├── .github ├── ISSUE_TEMPLATE │ ├── bug_report.md │ ├── code-review.md │ └── feature_request.md └── workflows │ ├── go-main.yml │ └── go-pr.yml ├── .gitignore ├── .golangci.yml ├── CHANGELOG.md ├── LICENSE ├── README.md ├── _config.yml ├── go.mod ├── go.sum └── v1 ├── api.go ├── api_server.go ├── api_server_test.go ├── apollows ├── proto.go └── proto_test.go ├── ast.go ├── ast_test.go ├── callbacks.go ├── compat.go ├── compat ├── gorillaws │ └── api.go └── otelwsgraphql │ └── api.go ├── context.go ├── context_test.go ├── error.go ├── examples ├── minimal-graphql-transport-ws │ ├── README.md │ └── main.go ├── minimal-graphql-ws │ ├── README.md │ └── main.go └── simpleserver │ ├── README.md │ ├── main.go │ └── playground.html ├── interceptors.go ├── mutable ├── api.go └── mutcontext_test.go ├── server.go ├── server_plain.go ├── server_plain_test.go ├── server_test.go ├── server_websocket.go └── server_websocket_test.go /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: Bug report default 4 | title: '' 5 | labels: bug 6 | assignees: iamtakingiteasy 7 | 8 | --- 9 | 10 | **Describe the bug** 11 | A clear and concise description of what the bug is. 
12 | 13 | **To Reproduce** 14 | ``` 15 | Minimal code reproducing the issue 16 | ``` 17 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/code-review.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Code review 3 | about: Review existing code regarding possible improvements 4 | title: '' 5 | labels: enhancement 6 | assignees: iamtakingiteasy 7 | 8 | --- 9 | 10 | 11 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Feature request 3 | about: Suggest an idea for this project 4 | title: '' 5 | labels: enhancement 6 | assignees: iamtakingiteasy 7 | 8 | --- 9 | 10 | **Is your feature request related to a problem? Please describe.** 11 | A clear and concise description of what the problem is. Ex. I'm always frustrated when [...] 12 | 13 | **Describe the solution you'd like** 14 | A clear and concise description of what you want to happen. 15 | 16 | **Describe alternatives you've considered** 17 | A clear and concise description of any alternative solutions or features you've considered. 
18 | -------------------------------------------------------------------------------- /.github/workflows/go-main.yml: -------------------------------------------------------------------------------- 1 | name: go-main 2 | on: 3 | push: 4 | jobs: 5 | test: 6 | name: Test 7 | runs-on: ubuntu-latest 8 | steps: 9 | - uses: actions/checkout@v3 10 | - uses: actions/setup-go@v3 11 | with: 12 | go-version: "1.20.4" 13 | - name: golangci-lint 14 | uses: golangci/golangci-lint-action@v3 15 | with: 16 | version: "v1.52" 17 | - name: Test & publish code coverage 18 | uses: paambaati/codeclimate-action@v3.0.0 19 | env: 20 | CC_TEST_REPORTER_ID: ${{secrets.CC_TEST_REPORTER_ID}} 21 | with: 22 | coverageCommand: go test -race -coverprofile c.out -covermode=atomic -v -bench=. ./... 23 | prefix: github.com/eientei/wsgraphql 24 | coverageLocations: ${{github.workspace}}/c.out:gocov 25 | -------------------------------------------------------------------------------- /.github/workflows/go-pr.yml: -------------------------------------------------------------------------------- 1 | name: golangci-lint 2 | on: 3 | pull_request: 4 | permissions: 5 | contents: read 6 | jobs: 7 | golangci: 8 | name: lint 9 | runs-on: ubuntu-latest 10 | steps: 11 | - uses: actions/checkout@v3 12 | - uses: actions/setup-go@v3 13 | with: 14 | go-version: 1.17 15 | - name: golangci-lint 16 | uses: golangci/golangci-lint-action@v3 17 | with: 18 | version: v1.48 19 | - name: test 20 | run: go test -v ./... 
21 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .idea 2 | *.iml 3 | vendor 4 | -------------------------------------------------------------------------------- /.golangci.yml: -------------------------------------------------------------------------------- 1 | run: 2 | concurrency: 4 3 | timeout: 1m 4 | issues-exit-code: 1 5 | tests: true 6 | output: 7 | format: colored-line-number 8 | print-issued-lines: true 9 | print-linter-name: true 10 | uniq-by-line: true 11 | 12 | linters-settings: 13 | errcheck: 14 | check-type-assertions: true 15 | check-blank: false 16 | gocognit: 17 | min-complexity: 25 18 | goconst: 19 | min-len: 3 20 | min-occurrences: 3 21 | gocyclo: 22 | min-complexity: 15 23 | gofmt: 24 | simplify: true 25 | revive: 26 | min-confidence: 0.8 27 | govet: 28 | check-shadowing: true 29 | enable-all: true 30 | lll: 31 | line-length: 120 32 | tab-width: 1 33 | wsl: 34 | strict-append: true 35 | allow-assign-and-call: true 36 | allow-multiline-assign: true 37 | allow-cuddle-declarations: false 38 | allow-trailing-comment: false 39 | force-case-trailing-whitespace: 0 40 | 41 | linters: 42 | disable-all: true 43 | enable: 44 | - govet 45 | - errcheck 46 | - unused 47 | - gosimple 48 | - ineffassign 49 | - typecheck 50 | - bodyclose 51 | - stylecheck 52 | - revive 53 | - unconvert 54 | - goconst 55 | - gocyclo 56 | - gocognit 57 | - gofmt 58 | - goimports 59 | - godox 60 | - lll 61 | - unparam 62 | - gocritic 63 | - wsl 64 | - goprintffuncname 65 | - whitespace 66 | 67 | issues: 68 | exclude-use-default: false 69 | max-issues-per-linter: 0 70 | max-same-issues: 0 71 | new: false 72 | -------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | v1.5.1 2 | ------ 3 | - [otelwsgraphql] Allow specifying tracer 
provider 4 | 5 | v1.5.0 6 | ------ 7 | - Added `WithInterceptors` option to allow context-based third-party instrumentation 8 | - `WithCallbacks` is deprecated in favor of `WithInterceptors` and will be removed by 2024. 9 | - Added OpenTelemetry instrumentation in compat package. 10 | - Bumped graphql-go dependency to v0.8.1 with critical security fixes 11 | 12 | v1.4.2 13 | ------ 14 | - Support setting headers in websocket upgrade response 15 | 16 | v1.4.1 17 | ------ 18 | - Added a work-around for ExtendedError extensions not rendering in subscriptions 19 | 20 | v1.4.0 21 | ------ 22 | - Added support for per-request protocol selection for websocket subscriptions using websocket 23 | subprotocol negotiation. (#2) 24 | - Added Stringer implementation to `apollows.Protocol`, to avoid the explicit type casts. 25 | In 1.5.0 the underlying Protocol type will be replaced with an integer, so migrating to .String() 26 | is advised. 27 | 28 | v1.3.4 29 | ------ 30 | - Fixed serialization of empty data/payloads (#1) 31 | 32 | v1.3.0 33 | ------ 34 | - Breaking change: root object moved to functional options parametrization 35 | - Added support for graphql-ws (graphql-transport-ws subprotocol) 36 | - Ensured only pre-execution operation errors are returned as `error` type per apollows spec 37 | - Fixed incorrect OnConnect/OnOperation callback sequence 38 | 39 | v1.2.3 40 | ------ 41 | - Added OnDisconnect handler without the responsibility to handle errors; added callback sequence diagram 42 | 43 | v1.2.2 44 | ------ 45 | - Corrected termination request handling 46 | 47 | v1.2.1 48 | ------ 49 | - Fixes, clarifications for websocket request teardown sequence 50 | 51 | - Added CHANGELOG.md 52 | 53 | - Added READMEs to examples 54 | 55 | - Updated LICENSE year 56 | 57 | v1.0.0-v1.2.0 58 | ------ 59 | Major refactor, cleaned up implementation 60 | Complete test coverage, versioned package scheme 61 | 62 | v0.0.1-v0.5.0 63 | --- 64 | Initial implementation 65 |
-------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Eientei Team 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [![Go Doc Reference](https://godoc.org/github.com/eientei/wsgraphql/v1?status.svg)](https://godoc.org/github.com/eientei/wsgraphql/v1) 2 | [![Go Report Card](https://goreportcard.com/badge/github.com/eientei/wsgraphql)](https://goreportcard.com/report/github.com/eientei/wsgraphql) 3 | [![Maintainability](https://api.codeclimate.com/v1/badges/c626b5f2399b044bdebf/maintainability)](https://codeclimate.com/github/eientei/wsgraphql) 4 | [![Test Coverage](https://api.codeclimate.com/v1/badges/c626b5f2399b044bdebf/test_coverage)](https://codeclimate.com/github/eientei/wsgraphql) 5 | 6 | An implementation of websocket transport for 7 | [graphql-go](https://github.com/graphql-go/graphql). 8 | 9 | The following flavors are currently supported: 10 | 11 | - `graphql-ws` subprotocol, older spec: https://github.com/apollographql/subscriptions-transport-ws/blob/master/PROTOCOL.md 12 | - `graphql-transport-ws` subprotocol, newer spec: https://github.com/enisdenjo/graphql-ws/blob/master/PROTOCOL.md 13 | 14 | Inspired by [graphqlws](https://github.com/functionalfoundry/graphqlws) 15 | 16 | Key features: 17 | 18 | - Subscription support 19 | - Interceptors at every stage of the communication process for easy customization 20 | - Supports both websockets and plain http queries, with http chunked response for plain http subscriptions 21 | - [Mutable context](https://godoc.org/github.com/eientei/wsgraphql/v1/mutable) for keeping request-scoped 22 | connection/authentication data and operation-scoped state 23 | 24 | Usage 25 | ----- 26 | 27 | Assuming the [gorilla websocket](https://github.com/gorilla/websocket) upgrader: 28 | 29 | ```go 30 | import ( 31 | "net/http" 32 | 33 | "github.com/eientei/wsgraphql/v1" 34 | "github.com/eientei/wsgraphql/v1/compat/gorillaws" 35 | "github.com/gorilla/websocket" 36 |
"github.com/graphql-go/graphql" 37 | ) 38 | ``` 39 | 40 | ```go 41 | schema, err := graphql.NewSchema(...) 42 | if err != nil { 43 | panic(err) 44 | } 45 | 46 | srv, err := wsgraphql.NewServer( 47 | schema, 48 | wsgraphql.WithUpgrader(gorillaws.Wrap(&websocket.Upgrader{ 49 | Subprotocols: []string{ 50 | wsgraphql.WebsocketSubprotocolGraphqlWS.String(), 51 | wsgraphql.WebsocketSubprotocolGraphqlTransportWS.String(), 52 | }, 53 | })), 54 | ) 55 | if err != nil { 56 | panic(err) 57 | } 58 | 59 | http.Handle("/query", srv) 60 | 61 | err = http.ListenAndServe(":8080", nil) 62 | if err != nil { 63 | panic(err) 64 | } 65 | ``` 66 | 67 | Examples 68 | -------- 69 | 70 | See [/v1/examples](/v1/examples) 71 | - [minimal-graphql-ws](/v1/examples/minimal-graphql-ws) `graphql-ws` / older subscriptions-transport-ws server setup 72 | - [minimal-graphql-transport-ws](/v1/examples/minimal-graphql-transport-ws) `graphql-transport-ws` / newer graphql-ws server setup 73 | - [simpleserver](/v1/examples/simpleserver) complete example with subscriptions, mutations and queries 74 | -------------------------------------------------------------------------------- /_config.yml: -------------------------------------------------------------------------------- 1 | theme: jekyll-theme-slate -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/eientei/wsgraphql 2 | 3 | go 1.17 4 | 5 | require ( 6 | github.com/gorilla/websocket v1.5.0 7 | github.com/graphql-go/graphql v0.8.1 8 | github.com/stretchr/testify v1.8.4 9 | go.opentelemetry.io/otel v1.16.0 10 | go.opentelemetry.io/otel/trace v1.16.0 11 | ) 12 | 13 | require ( 14 | github.com/davecgh/go-spew v1.1.1 // indirect 15 | github.com/go-logr/logr v1.2.4 // indirect 16 | github.com/go-logr/stdr v1.2.2 // indirect 17 | github.com/pmezard/go-difflib v1.0.0 // indirect 18 | go.opentelemetry.io/otel/metric v1.16.0 
// indirect 19 | gopkg.in/yaml.v3 v3.0.1 // indirect 20 | ) 21 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 2 | github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= 3 | github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 4 | github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= 5 | github.com/go-logr/logr v1.2.4 h1:g01GSCwiDw2xSZfjJ2/T9M+S6pFdcNtFYsp+Y43HYDQ= 6 | github.com/go-logr/logr v1.2.4/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= 7 | github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag= 8 | github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE= 9 | github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= 10 | github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= 11 | github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc= 12 | github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= 13 | github.com/graphql-go/graphql v0.8.1 h1:p7/Ou/WpmulocJeEx7wjQy611rtXGQaAcXGqanuMMgc= 14 | github.com/graphql-go/graphql v0.8.1/go.mod h1:nKiHzRM0qopJEwCITUuIsxk9PlVlwIiiI8pnJEhordQ= 15 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= 16 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 17 | github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= 18 | github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= 19 | github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= 20 | github.com/stretchr/testify 
v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= 21 | github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= 22 | github.com/stretchr/testify v1.8.3/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= 23 | github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= 24 | github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= 25 | go.opentelemetry.io/otel v1.16.0 h1:Z7GVAX/UkAXPKsy94IU+i6thsQS4nb7LviLpnaNeW8s= 26 | go.opentelemetry.io/otel v1.16.0/go.mod h1:vl0h9NUa1D5s1nv3A5vZOYWn8av4K8Ml6JDeHrT/bx4= 27 | go.opentelemetry.io/otel/metric v1.16.0 h1:RbrpwVG1Hfv85LgnZ7+txXioPDoh6EdbZHo26Q3hqOo= 28 | go.opentelemetry.io/otel/metric v1.16.0/go.mod h1:QE47cpOmkwipPiefDwo2wDzwJrlfxxNYodqc4xnGCo4= 29 | go.opentelemetry.io/otel/trace v1.16.0 h1:8JRpaObFoW0pxuVPapkgH8UhHQj+bJW8jJsCZEu5MQs= 30 | go.opentelemetry.io/otel/trace v1.16.0/go.mod h1:Yt9vYq1SdNz3xdjZZK7wcXv1qv2pwLkqr2QVwea0ef0= 31 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= 32 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= 33 | gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 34 | gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= 35 | gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 36 | -------------------------------------------------------------------------------- /v1/api.go: -------------------------------------------------------------------------------- 1 | // Package wsgraphql provides interfaces for server and client 2 | package wsgraphql 3 | 4 | import ( 5 | "github.com/eientei/wsgraphql/v1/apollows" 6 | ) 7 | 8 | // WebsocketSubprotocolGraphqlWS websocket subprotocol expected by subscriptions-transport-ws implementations 9 | const WebsocketSubprotocolGraphqlWS = 
apollows.WebsocketSubprotocolGraphqlWS 10 | 11 | // WebsocketSubprotocolGraphqlTransportWS websocket subprotocol expected by graphql-ws implementations 12 | const WebsocketSubprotocolGraphqlTransportWS = apollows.WebsocketSubprotocolGraphqlTransportWS 13 | -------------------------------------------------------------------------------- /v1/api_server.go: -------------------------------------------------------------------------------- 1 | package wsgraphql 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | "net/http" 7 | "reflect" 8 | "strconv" 9 | "time" 10 | "unsafe" 11 | 12 | "github.com/eientei/wsgraphql/v1/apollows" 13 | "github.com/graphql-go/graphql" 14 | "github.com/graphql-go/graphql/gqlerrors" 15 | ) 16 | 17 | // Server implements graphql http handler with websocket support (if upgrader is provided with WithUpgrader) 18 | type Server interface { 19 | http.Handler 20 | } 21 | 22 | // NewServer returns new Server instance 23 | func NewServer( 24 | schema graphql.Schema, 25 | options ...ServerOption, 26 | ) (Server, error) { 27 | var c serverConfig 28 | 29 | c.subscriptionProtocols = make(map[apollows.Protocol]struct{}) 30 | 31 | for _, o := range options { 32 | err := o(&c) 33 | if err != nil { 34 | return nil, err 35 | } 36 | } 37 | 38 | if len(c.subscriptionProtocols) == 0 { 39 | c.subscriptionProtocols[apollows.WebsocketSubprotocolGraphqlWS] = struct{}{} 40 | c.subscriptionProtocols[apollows.WebsocketSubprotocolGraphqlTransportWS] = struct{}{} 41 | } 42 | 43 | initInterceptors(&c) 44 | 45 | if c.resultProcessor == nil { 46 | c.resultProcessor = identityResultProcessor 47 | } 48 | 49 | f := reflect.ValueOf(&schema).Elem().FieldByName("extensions") 50 | 51 | exts, ok := reflect.NewAt(f.Type(), unsafe.Pointer(f.UnsafeAddr())).Elem().Interface().([]graphql.Extension) 52 | if !ok { 53 | return nil, errReflectExtensions 54 | } 55 | 56 | return &serverImpl{ 57 | schema: schema, 58 | extensions: exts, 59 | serverConfig: c, 60 | }, nil 61 | } 62 | 63 | // 
ServerOption configures a Server; NewServer applies options in order and stops at the first returned error 64 | type ServerOption func(config *serverConfig) error 65 | 66 | // WithUpgrader option sets the Upgrader (an interface modeled after the gorilla websocket upgrader), enabling websocket support 67 | func WithUpgrader(upgrader Upgrader) ServerOption { 68 | return func(config *serverConfig) error { 69 | config.upgrader = upgrader 70 | 71 | return nil 72 | } 73 | } 74 | 75 | // WithInterceptors option replaces all interceptors around the various stages of request processing 76 | func WithInterceptors(interceptors Interceptors) ServerOption { 77 | return func(config *serverConfig) error { 78 | config.interceptors = interceptors 79 | 80 | return nil 81 | } 82 | } 83 | 84 | // WithExtraInterceptors option chains each non-nil interceptor onto the already configured one instead of replacing it 85 | func WithExtraInterceptors(interceptors Interceptors) ServerOption { 86 | return func(config *serverConfig) error { 87 | if interceptors.HTTPRequest != nil { 88 | config.interceptors.HTTPRequest = InterceptorHTTPRequestChain( 89 | config.interceptors.HTTPRequest, 90 | interceptors.HTTPRequest, 91 | ) 92 | } 93 | 94 | if interceptors.Init != nil { 95 | config.interceptors.Init = InterceptorInitChain( 96 | config.interceptors.Init, 97 | interceptors.Init, 98 | ) 99 | } 100 | 101 | if interceptors.Operation != nil { 102 | config.interceptors.Operation = InterceptorOperationChain( 103 | config.interceptors.Operation, 104 | interceptors.Operation, 105 | ) 106 | } 107 | 108 | if interceptors.OperationParse != nil { 109 | config.interceptors.OperationParse = InterceptorOperationParseChain( 110 | config.interceptors.OperationParse, 111 | interceptors.OperationParse, 112 | ) 113 | } 114 | 115 | if interceptors.OperationExecute != nil { 116 | config.interceptors.OperationExecute = InterceptorOperationExecuteChain( 117 | config.interceptors.OperationExecute, 118 | interceptors.OperationExecute, 119 | ) 120 | } 121 | 122 | return nil 123 | } 124 | } 125 | 126 | // WithKeepalive option enables sending keepalive messages at the provided interval 127 | func
WithKeepalive(interval time.Duration) ServerOption { 128 | return func(config *serverConfig) error { 129 | config.keepalive = interval 130 | 131 | return nil 132 | } 133 | } 134 | 135 | // WithoutHTTPQueries option prevents HTTP queries from being handled, allowing only websocket queries 136 | func WithoutHTTPQueries() ServerOption { 137 | return func(config *serverConfig) error { 138 | config.rejectHTTPQueries = true 139 | 140 | return nil 141 | } 142 | } 143 | 144 | // WithProtocol option adds a protocol for this server to use. May be specified multiple times; if never specified, both graphql-ws and graphql-transport-ws are enabled. 145 | func WithProtocol(protocol apollows.Protocol) ServerOption { 146 | return func(config *serverConfig) error { 147 | config.subscriptionProtocols[protocol] = struct{}{} 148 | 149 | return nil 150 | } 151 | } 152 | 153 | // WithConnectTimeout option sets the duration within which a client is allowed to initialize the connection before being 154 | // disconnected 155 | func WithConnectTimeout(timeout time.Duration) ServerOption { 156 | return func(config *serverConfig) error { 157 | config.connectTimeout = timeout 158 | 159 | return nil 160 | } 161 | } 162 | 163 | // WithRootObject provides the root object that will be used in root resolvers 164 | func WithRootObject(rootObject map[string]interface{}) ServerOption { 165 | return func(config *serverConfig) error { 166 | config.rootObject = rootObject 167 | 168 | return nil 169 | } 170 | } 171 | 172 | // ResultProcessor allows post-processing of resolved values 173 | type ResultProcessor func( 174 | ctx context.Context, 175 | payload *apollows.PayloadOperation, 176 | result *graphql.Result, 177 | ) *graphql.Result 178 | 179 | // WithResultProcessor provides a ResultProcessor to post-process resolved values; when unset, results pass through unchanged 180 | func WithResultProcessor(proc ResultProcessor) ServerOption { 181 | return func(config *serverConfig) error { 182 | config.resultProcessor = proc 183 | 184 | return nil 185 | } 186 | } 187 | 188 | // WriteError writes err to w as a 400 Bad Request response; it is a no-op for a nil err or an already started response
189 | func WriteError(ctx context.Context, w http.ResponseWriter, err error) { 190 | if err == nil || ContextHTTPResponseStarted(ctx) { 191 | return 192 | } 193 | 194 | var res ResultError 195 | 196 | if !errors.As(err, &res) { 197 | err = ResultError{ 198 | Result: &graphql.Result{ 199 | Errors: []gqlerrors.FormattedError{ 200 | gqlerrors.FormatError(err), 201 | }, 202 | }, 203 | } 204 | } 205 | 206 | bs := []byte(err.Error()) 207 | 208 | w.Header().Set("content-length", strconv.Itoa(len(bs))) 209 | w.WriteHeader(http.StatusBadRequest) 210 | 211 | _, _ = w.Write(bs) 212 | } 213 | 214 | func identityResultProcessor( 215 | _ context.Context, 216 | _ *apollows.PayloadOperation, 217 | result *graphql.Result, 218 | ) *graphql.Result { 219 | return result 220 | } 221 | -------------------------------------------------------------------------------- /v1/api_server_test.go: -------------------------------------------------------------------------------- 1 | package wsgraphql 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | "io" 7 | "net/http" 8 | "net/http/httptest" 9 | "testing" 10 | "time" 11 | 12 | "github.com/eientei/wsgraphql/v1/apollows" 13 | "github.com/eientei/wsgraphql/v1/mutable" 14 | "github.com/graphql-go/graphql" 15 | "github.com/graphql-go/graphql/gqlerrors" 16 | "github.com/stretchr/testify/assert" 17 | ) 18 | 19 | func TestWithCallbacksOnRequest(t *testing.T) { 20 | opctx := mutable.NewMutableContext(context.Background()) 21 | 22 | var c serverConfig 23 | 24 | var encountered error 25 | 26 | target := errors.New("123") 27 | 28 | assert.NoError(t, WithCallbacks(Callbacks{ 29 | OnRequest: func(reqctx mutable.Context, r *http.Request, w http.ResponseWriter) error { 30 | return target 31 | }, 32 | OnRequestDone: func(reqctx mutable.Context, r *http.Request, w http.ResponseWriter, origerr error) { 33 | encountered = origerr 34 | }, 35 | OnConnect: nil, 36 | OnDisconnect: nil, 37 | OnOperation: nil, 38 | OnOperationValidation: nil, 39 | OnOperationResult: nil, 40 
| OnOperationDone: nil, 41 | })(&c)) 42 | 43 | w := httptest.NewRecorder() 44 | 45 | r, _ := http.NewRequest(http.MethodGet, "", nil) 46 | 47 | _ = c.interceptors.HTTPRequest( 48 | opctx, 49 | w, 50 | r, 51 | func(reqctx context.Context, w http.ResponseWriter, r *http.Request) error { 52 | return nil 53 | }, 54 | ) 55 | 56 | assert.Equal(t, target, encountered) 57 | } 58 | 59 | func TestWithCallbacksOnRequestHandler(t *testing.T) { 60 | opctx := mutable.NewMutableContext(context.Background()) 61 | 62 | var c serverConfig 63 | 64 | var encountered error 65 | 66 | target := errors.New("123") 67 | 68 | assert.NoError(t, WithCallbacks(Callbacks{ 69 | OnRequest: nil, 70 | OnRequestDone: func(reqctx mutable.Context, r *http.Request, w http.ResponseWriter, origerr error) { 71 | encountered = origerr 72 | }, 73 | OnConnect: nil, 74 | OnDisconnect: nil, 75 | OnOperation: nil, 76 | OnOperationValidation: nil, 77 | OnOperationResult: nil, 78 | OnOperationDone: nil, 79 | })(&c)) 80 | 81 | w := httptest.NewRecorder() 82 | 83 | r, _ := http.NewRequest(http.MethodGet, "", nil) 84 | 85 | _ = c.interceptors.HTTPRequest( 86 | opctx, 87 | w, 88 | r, 89 | func(reqctx context.Context, w http.ResponseWriter, r *http.Request) error { 90 | return target 91 | }, 92 | ) 93 | 94 | assert.Equal(t, target, encountered) 95 | } 96 | 97 | func TestWithCallbacksOnConnect(t *testing.T) { 98 | opctx := mutable.NewMutableContext(context.Background()) 99 | 100 | var c serverConfig 101 | 102 | var encountered error 103 | 104 | target := errors.New("123") 105 | 106 | assert.NoError(t, WithCallbacks(Callbacks{ 107 | OnRequest: nil, 108 | OnRequestDone: nil, 109 | OnConnect: func(reqctx mutable.Context, init apollows.PayloadInit) error { 110 | return target 111 | }, 112 | OnDisconnect: func(reqctx mutable.Context, origerr error) error { 113 | encountered = origerr 114 | 115 | return origerr 116 | }, 117 | OnOperation: nil, 118 | OnOperationValidation: nil, 119 | OnOperationResult: nil, 120 | 
OnOperationDone: nil, 121 | })(&c)) 122 | 123 | assert.Equal(t, target, c.interceptors.Init( 124 | opctx, 125 | nil, 126 | func(ctx context.Context, init apollows.PayloadInit) error { 127 | return nil 128 | }, 129 | )) 130 | 131 | assert.Equal(t, target, encountered) 132 | } 133 | 134 | func TestWithCallbacksOnConnectHandler(t *testing.T) { 135 | opctx := mutable.NewMutableContext(context.Background()) 136 | 137 | var c serverConfig 138 | 139 | var encountered error 140 | 141 | target := errors.New("123") 142 | 143 | assert.NoError(t, WithCallbacks(Callbacks{ 144 | OnRequest: nil, 145 | OnRequestDone: nil, 146 | OnConnect: nil, 147 | OnDisconnect: func(reqctx mutable.Context, origerr error) error { 148 | encountered = origerr 149 | 150 | return origerr 151 | }, 152 | OnOperation: nil, 153 | OnOperationValidation: nil, 154 | OnOperationResult: nil, 155 | OnOperationDone: nil, 156 | })(&c)) 157 | 158 | assert.Equal(t, target, c.interceptors.Init( 159 | opctx, 160 | nil, 161 | func(ctx context.Context, init apollows.PayloadInit) error { 162 | return target 163 | }, 164 | )) 165 | 166 | assert.Equal(t, target, encountered) 167 | } 168 | 169 | func TestWithCallbacksOnOperation(t *testing.T) { 170 | opctx := mutable.NewMutableContext(context.Background()) 171 | 172 | var c serverConfig 173 | 174 | var encountered error 175 | 176 | target := errors.New("123") 177 | 178 | assert.NoError(t, WithCallbacks(Callbacks{ 179 | OnRequest: nil, 180 | OnRequestDone: nil, 181 | OnConnect: nil, 182 | OnDisconnect: nil, 183 | OnOperation: func(opctx mutable.Context, payload *apollows.PayloadOperation) error { 184 | return target 185 | }, 186 | OnOperationValidation: nil, 187 | OnOperationResult: nil, 188 | OnOperationDone: func(opctx mutable.Context, payload *apollows.PayloadOperation, origerr error) error { 189 | encountered = origerr 190 | 191 | return origerr 192 | }, 193 | })(&c)) 194 | 195 | assert.Equal(t, target, c.interceptors.Operation( 196 | opctx, 197 | nil, 198 | func(ctx 
context.Context, payload *apollows.PayloadOperation) error { 199 | return nil 200 | }, 201 | )) 202 | 203 | assert.Equal(t, target, encountered) 204 | } 205 | 206 | func TestWithCallbacksOnOperationHandler(t *testing.T) { 207 | opctx := mutable.NewMutableContext(context.Background()) 208 | 209 | var c serverConfig 210 | 211 | var encountered error 212 | 213 | target := errors.New("123") 214 | 215 | assert.NoError(t, WithCallbacks(Callbacks{ 216 | OnRequest: nil, 217 | OnRequestDone: nil, 218 | OnConnect: nil, 219 | OnDisconnect: nil, 220 | OnOperation: nil, 221 | OnOperationValidation: nil, 222 | OnOperationResult: nil, 223 | OnOperationDone: func(opctx mutable.Context, payload *apollows.PayloadOperation, origerr error) error { 224 | encountered = origerr 225 | 226 | return origerr 227 | }, 228 | })(&c)) 229 | 230 | assert.Equal(t, target, c.interceptors.Operation( 231 | opctx, 232 | nil, 233 | func(ctx context.Context, payload *apollows.PayloadOperation) error { 234 | return target 235 | }, 236 | )) 237 | 238 | assert.Equal(t, target, encountered) 239 | } 240 | 241 | func TestWithCallbacksOnOperationValidation(t *testing.T) { 242 | opctx := mutable.NewMutableContext(context.Background()) 243 | 244 | var c serverConfig 245 | 246 | target := errors.New("123") 247 | 248 | assert.NoError(t, WithCallbacks(Callbacks{ 249 | OnRequest: nil, 250 | OnRequestDone: nil, 251 | OnConnect: nil, 252 | OnDisconnect: nil, 253 | OnOperation: nil, 254 | OnOperationValidation: func( 255 | opctx mutable.Context, 256 | payload *apollows.PayloadOperation, 257 | result *graphql.Result, 258 | ) error { 259 | return target 260 | }, 261 | OnOperationResult: nil, 262 | OnOperationDone: nil, 263 | })(&c)) 264 | 265 | assert.Equal(t, target, c.interceptors.OperationParse( 266 | opctx, 267 | nil, 268 | func(ctx context.Context, payload *apollows.PayloadOperation) error { 269 | return nil 270 | }, 271 | )) 272 | } 273 | 274 | func TestWithCallbacksOnOperationExecution(t *testing.T) { 275 | opctx 
:= mutable.NewMutableContext(context.Background()) 276 | 277 | var c serverConfig 278 | 279 | target := errors.New("123") 280 | 281 | assert.NoError(t, WithCallbacks(Callbacks{ 282 | OnRequest: nil, 283 | OnRequestDone: nil, 284 | OnConnect: nil, 285 | OnDisconnect: nil, 286 | OnOperation: nil, 287 | OnOperationValidation: nil, 288 | OnOperationResult: func( 289 | opctx mutable.Context, 290 | payload *apollows.PayloadOperation, 291 | result *graphql.Result, 292 | ) error { 293 | return target 294 | }, 295 | OnOperationDone: nil, 296 | })(&c)) 297 | 298 | ch, err := c.interceptors.OperationExecute( 299 | opctx, 300 | nil, 301 | func(ctx context.Context, payload *apollows.PayloadOperation) (chan *graphql.Result, error) { 302 | tch := make(chan *graphql.Result, 1) 303 | 304 | tch <- &graphql.Result{} 305 | close(tch) 306 | 307 | return tch, nil 308 | }, 309 | ) 310 | 311 | assert.Nil(t, err) 312 | 313 | res := <-ch 314 | 315 | assert.NotNil(t, res) 316 | assert.Equal(t, target, res.Data) 317 | } 318 | 319 | func TestWithCallbacksOnOperationProcess(t *testing.T) { 320 | opctx := mutable.NewMutableContext(context.Background()) 321 | 322 | var c serverConfig 323 | 324 | target := errors.New("123") 325 | 326 | assert.NoError(t, WithCallbacks(Callbacks{ 327 | OnRequest: nil, 328 | OnRequestDone: nil, 329 | OnConnect: nil, 330 | OnDisconnect: nil, 331 | OnOperation: nil, 332 | OnOperationValidation: nil, 333 | OnOperationResult: func( 334 | opctx mutable.Context, 335 | payload *apollows.PayloadOperation, 336 | result *graphql.Result, 337 | ) error { 338 | result.Data = target 339 | 340 | return nil 341 | }, 342 | OnOperationDone: nil, 343 | })(&c)) 344 | 345 | ch, err := c.interceptors.OperationExecute( 346 | opctx, 347 | nil, 348 | func(ctx context.Context, payload *apollows.PayloadOperation) (chan *graphql.Result, error) { 349 | tch := make(chan *graphql.Result, 1) 350 | 351 | tch <- &graphql.Result{} 352 | close(tch) 353 | 354 | return tch, nil 355 | }, 356 | )
357 | 358 | assert.Nil(t, err) 359 | 360 | res := <-ch 361 | 362 | assert.NotNil(t, res) 363 | assert.Equal(t, target, res.Data) 364 | } 365 | 366 | func TestWithCallbacksOnOperationExecutionHandler(t *testing.T) { 367 | opctx := mutable.NewMutableContext(context.Background()) 368 | 369 | var c serverConfig 370 | 371 | target := errors.New("123") 372 | 373 | assert.NoError(t, WithCallbacks(Callbacks{ 374 | OnRequest: nil, 375 | OnRequestDone: nil, 376 | OnConnect: nil, 377 | OnDisconnect: nil, 378 | OnOperation: nil, 379 | OnOperationValidation: nil, 380 | OnOperationResult: nil, 381 | OnOperationDone: nil, 382 | })(&c)) 383 | 384 | ch, err := c.interceptors.OperationExecute( 385 | opctx, 386 | nil, 387 | func(ctx context.Context, payload *apollows.PayloadOperation) (chan *graphql.Result, error) { 388 | return nil, target 389 | }, 390 | ) 391 | 392 | assert.Nil(t, ch) 393 | assert.Equal(t, target, err) 394 | } 395 | 396 | func TestWithInterceptorChain(t *testing.T) { 397 | opctx := context.Background() 398 | 399 | var c serverConfig 400 | 401 | type keyT struct{} 402 | 403 | key := keyT{} 404 | 405 | assert.NoError(t, WithInterceptors(Interceptors{ 406 | HTTPRequest: InterceptorHTTPRequestChain( 407 | func(ctx context.Context, w http.ResponseWriter, r *http.Request, handler HandlerHTTPRequest) error { 408 | return handler(context.WithValue(ctx, key, 1), w, r) 409 | }, 410 | func(ctx context.Context, w http.ResponseWriter, r *http.Request, handler HandlerHTTPRequest) error { 411 | return handler(context.WithValue(ctx, key, ctx.Value(key).(int)+1), w, r) 412 | }, 413 | ), 414 | Init: InterceptorInitChain( 415 | func(ctx context.Context, init apollows.PayloadInit, handler HandlerInit) error { 416 | return handler(context.WithValue(ctx, key, 2), init) 417 | }, 418 | func(ctx context.Context, init apollows.PayloadInit, handler HandlerInit) error { 419 | return handler(context.WithValue(ctx, key, ctx.Value(key).(int)+1), init) 420 | }, 421 | ), 422 | Operation:
InterceptorOperationChain( 423 | func(ctx context.Context, payload *apollows.PayloadOperation, handler HandlerOperation) error { 424 | return handler(context.WithValue(ctx, key, 3), payload) 425 | }, 426 | func(ctx context.Context, payload *apollows.PayloadOperation, handler HandlerOperation) error { 427 | return handler(context.WithValue(ctx, key, ctx.Value(key).(int)+1), payload) 428 | }, 429 | ), 430 | OperationParse: InterceptorOperationParseChain( 431 | func(ctx context.Context, payload *apollows.PayloadOperation, handler HandlerOperationParse) error { 432 | return handler(context.WithValue(ctx, key, 4), payload) 433 | }, 434 | func(ctx context.Context, payload *apollows.PayloadOperation, handler HandlerOperationParse) error { 435 | return handler(context.WithValue(ctx, key, ctx.Value(key).(int)+1), payload) 436 | }, 437 | ), 438 | OperationExecute: InterceptorOperationExecuteChain( 439 | func( 440 | ctx context.Context, 441 | payload *apollows.PayloadOperation, 442 | handler HandlerOperationExecute, 443 | ) (chan *graphql.Result, error) { 444 | return handler(context.WithValue(ctx, key, 5), payload) 445 | }, 446 | func( 447 | ctx context.Context, 448 | payload *apollows.PayloadOperation, 449 | handler HandlerOperationExecute, 450 | ) (chan *graphql.Result, error) { 451 | return handler(context.WithValue(ctx, key, ctx.Value(key).(int)+1), payload) 452 | }, 453 | ), 454 | })(&c)) 455 | 456 | r, _ := http.NewRequest(http.MethodGet, "", nil) 457 | 458 | var res []int 459 | 460 | _ = c.interceptors.HTTPRequest( 461 | opctx, 462 | httptest.NewRecorder(), 463 | r, 464 | func(ctx context.Context, w http.ResponseWriter, r *http.Request) error { 465 | res = append(res, ctx.Value(key).(int)) 466 | 467 | return nil 468 | }, 469 | ) 470 | 471 | _ = c.interceptors.Init( 472 | opctx, 473 | nil, 474 | func(ctx context.Context, init apollows.PayloadInit) error { 475 | res = append(res, ctx.Value(key).(int)) 476 | 477 | return nil 478 | }, 479 | ) 480 | 481 | _ = 
c.interceptors.Operation( 482 | opctx, 483 | nil, 484 | func(ctx context.Context, payload *apollows.PayloadOperation) error { 485 | res = append(res, ctx.Value(key).(int)) 486 | 487 | return nil 488 | }, 489 | ) 490 | 491 | _ = c.interceptors.OperationParse( 492 | opctx, 493 | nil, 494 | func(ctx context.Context, payload *apollows.PayloadOperation) error { 495 | res = append(res, ctx.Value(key).(int)) 496 | 497 | return nil 498 | }, 499 | ) 500 | 501 | _, _ = c.interceptors.OperationExecute( 502 | opctx, 503 | nil, 504 | func(ctx context.Context, payload *apollows.PayloadOperation) (chan *graphql.Result, error) { 505 | res = append(res, ctx.Value(key).(int)) 506 | 507 | return nil, nil 508 | }, 509 | ) 510 | 511 | assert.Equal(t, []int{2, 3, 4, 5, 6}, res) 512 | } 513 | 514 | func TestWithKeepalive(t *testing.T) { 515 | var c serverConfig 516 | 517 | assert.NoError(t, WithKeepalive(123)(&c)) 518 | 519 | assert.Equal(t, time.Duration(123), c.keepalive) 520 | } 521 | 522 | func TestWithoutHTTPQueries(t *testing.T) { 523 | var c serverConfig 524 | 525 | assert.NoError(t, WithoutHTTPQueries()(&c)) 526 | 527 | assert.Equal(t, true, c.rejectHTTPQueries) 528 | } 529 | 530 | func TestWithRootObject(t *testing.T) { 531 | var c serverConfig 532 | 533 | obj := make(map[string]interface{}) 534 | 535 | assert.NoError(t, WithRootObject(obj)(&c)) 536 | 537 | assert.Equal(t, obj, c.rootObject) 538 | } 539 | 540 | func TestWriteError(t *testing.T) { 541 | mutctx := mutable.NewMutableContext(context.Background()) 542 | 543 | rec := httptest.NewRecorder() 544 | 545 | WriteError(mutctx, rec, errors.New("123")) 546 | 547 | resp := rec.Result() 548 | 549 | bs, err := io.ReadAll(resp.Body) 550 | 551 | assert.NoError(t, err) 552 | assert.Equal(t, ResultError{ 553 | Result: &graphql.Result{ 554 | Errors: []gqlerrors.FormattedError{ 555 | gqlerrors.FormatError(errors.New("123")), 556 | }, 557 | }, 558 | }.Error(), string(bs)) 559 | 560 | assert.NoError(t, resp.Body.Close()) 561 | } 562 | 
563 | func TestWriteErrorResponseStarted(t *testing.T) { 564 | mutctx := mutable.NewMutableContext(context.Background()) 565 | 566 | mutctx.Set(ContextKeyHTTPResponseStarted, true) 567 | 568 | rec := httptest.NewRecorder() 569 | 570 | WriteError(mutctx, rec, errors.New("123")) 571 | 572 | resp := rec.Result() 573 | 574 | bs, err := io.ReadAll(resp.Body) 575 | 576 | assert.NoError(t, err) 577 | assert.Equal(t, "", string(bs)) 578 | 579 | assert.NoError(t, resp.Body.Close()) 580 | } 581 | 582 | func TestOptError(t *testing.T) { 583 | srv, err := NewServer(testNewSchema(t), func(config *serverConfig) error { 584 | return errors.New("123") 585 | }) 586 | 587 | assert.Error(t, err) 588 | assert.Nil(t, srv) 589 | } 590 | -------------------------------------------------------------------------------- /v1/apollows/proto.go: -------------------------------------------------------------------------------- 1 | // Package apollows provides implementation of GraphQL over WebSocket Protocol as defined by 2 | // https://github.com/apollographql/subscriptions-transport-ws/blob/master/PROTOCOL.md [GWS] 3 | // https://github.com/enisdenjo/graphql-ws/blob/master/PROTOCOL.md [GTWS] 4 | package apollows 5 | 6 | import ( 7 | "encoding/json" 8 | "errors" 9 | "io" 10 | ) 11 | 12 | // Protocol websocket subprotocol defining server behavior 13 | type Protocol string 14 | 15 | // String conversion 16 | func (p Protocol) String() string { 17 | return string(p) 18 | } 19 | 20 | const ( 21 | // WebsocketSubprotocolGraphqlWS websocket subprotocol expected by subscriptions-transport-ws implementations 22 | WebsocketSubprotocolGraphqlWS Protocol = "graphql-ws" 23 | 24 | // WebsocketSubprotocolGraphqlTransportWS websocket subprotocol exepected by graphql-ws implementations 25 | WebsocketSubprotocolGraphqlTransportWS Protocol = "graphql-transport-ws" 26 | ) 27 | 28 | // ErrUnknownProtocol indicates that unknown subscription protocol was requested 29 | var ErrUnknownProtocol = errors.New("unknown 
subscription protocol") 30 | 31 | // Operation type is used to enumerate possible apollo message types 32 | type Operation string 33 | 34 | const ( 35 | // OperationConnectionInit [GWS,GTWS] 36 | // is set by the connecting client to initialize the websocket state with connection params (if any) 37 | OperationConnectionInit Operation = "connection_init" 38 | 39 | // OperationStart [GWS] 40 | // client request initiates new operation, each operation may have 0-N OperationData responses before being 41 | // terminated by either OperationComplete or OperationError 42 | OperationStart Operation = "start" 43 | 44 | // OperationSubscribe [GTWS] 45 | // client request initiates new operation, each operation may have 0-N OperationNext responses before being 46 | // terminated by either OperationComplete or OperationError 47 | OperationSubscribe Operation = "subscribe" 48 | 49 | // OperationTerminate [GWS] 50 | // client request to gracefully close the connection, equivalent to closing the websocket 51 | OperationTerminate Operation = "connection_terminate" 52 | 53 | // OperationConnectionError [GWS] 54 | // server response to unsuccessful OperationConnectionInit attempt 55 | OperationConnectionError Operation = "connection_error" 56 | 57 | // OperationConnectionAck [GWS,GTWS] 58 | // server response to successful OperationConnectionInit attempt 59 | OperationConnectionAck Operation = "connection_ack" 60 | 61 | // OperationData [GWS] 62 | // server response to previously initiated operation with OperationStart may be multiple within same operation, 63 | // specifically for subscriptions 64 | OperationData Operation = "data" 65 | 66 | // OperationNext [GTWS] 67 | // server response to previously initiated operation with OperationSubscribe may be multiple within same operation, 68 | // specifically for subscriptions 69 | OperationNext Operation = "next" 70 | 71 | // OperationError [GWS,GTWS] 72 | // server response to previously initiated operation with 
OperationStart/OperationSubscribe 73 | OperationError Operation = "error" 74 | 75 | // OperationStop [GWS] 76 | // client request to stop previously initiated operation with OperationStart 77 | OperationStop Operation = "stop" 78 | 79 | // OperationComplete [GWS,GTWS] 80 | // GWS: server response indicating previously initiated operation is complete 81 | // GTWS: server response indicating previously initiated operation is complete 82 | // GTWS: client request to stop previously initiated operation with OperationSubscribe 83 | OperationComplete Operation = "complete" 84 | 85 | // OperationKeepAlive [GWS] 86 | // server response sent periodically to maintain websocket connection open 87 | OperationKeepAlive Operation = "ka" 88 | 89 | // OperationPing [GTWS] 90 | // sever/client request for OperationPong response 91 | OperationPing Operation = "ping" 92 | 93 | // OperationPong [GTWS] 94 | // sever/client response for OperationPing request 95 | // can be sent at any time (without prior OperationPing) to maintain wesocket connection 96 | OperationPong Operation = "pong" 97 | ) 98 | 99 | // Error providing MessageType to close websocket with 100 | type Error interface { 101 | error 102 | EventMessageType() MessageType 103 | } 104 | 105 | type errorImpl struct { 106 | error 107 | message string 108 | messageType MessageType 109 | } 110 | 111 | func (e errorImpl) EventMessageType() MessageType { 112 | return e.messageType 113 | } 114 | 115 | func (e errorImpl) Unwrap() error { 116 | return e.error 117 | } 118 | 119 | func (e errorImpl) Error() string { 120 | return e.message 121 | } 122 | 123 | // WrapError wraps provided error into Error 124 | func WrapError(err error, messageType MessageType) Error { 125 | message := err.Error() 126 | 127 | if messageType.Error() != "" { 128 | message = messageType.Error() + ": " + message 129 | } 130 | 131 | return errorImpl{ 132 | error: err, 133 | message: message, 134 | messageType: messageType, 135 | } 136 | } 137 | 138 | // 
NewSubscriberAlreadyExistsError constructs new Error using subscriber id as part of the message 139 | func NewSubscriberAlreadyExistsError(id string) Error { 140 | return errorImpl{ 141 | error: nil, 142 | message: "Subscriber for " + id + " already exists", 143 | messageType: EventSubscriberAlreadyExists, 144 | } 145 | } 146 | 147 | // MessageType websocket message types / status codes used to indicate protocol-level events following closing the 148 | // websocket 149 | type MessageType int 150 | 151 | const ( 152 | // EventCloseNormal standard websocket message type 153 | EventCloseNormal MessageType = 1000 154 | 155 | // EventCloseError standard websocket message type 156 | EventCloseError MessageType = 1006 157 | 158 | // EventInvalidMessage indicates invalid protocol message 159 | EventInvalidMessage MessageType = 4400 160 | 161 | // EventUnauthorized indicated attempt to subscribe to an operation before receiving OperationConnectionAck 162 | EventUnauthorized MessageType = 4401 163 | 164 | // EventInitializationTimeout indicates timeout occurring before client sending OperationConnectionInit 165 | EventInitializationTimeout MessageType = 4408 166 | 167 | // EventTooManyInitializationRequests indicates receiving more than one OperationConnectionInit 168 | EventTooManyInitializationRequests MessageType = 4429 169 | 170 | // EventSubscriberAlreadyExists indicates subscribed operation ID already being in use 171 | // (not yet terminated by either OperationComplete or OperationError) 172 | EventSubscriberAlreadyExists MessageType = 4409 173 | ) 174 | 175 | var messageTypeDescriptions = map[MessageType]string{ 176 | EventCloseNormal: "Termination requested", 177 | EventInvalidMessage: "Invalid message", 178 | EventUnauthorized: "Unauthorized", 179 | EventInitializationTimeout: "Connection initialisation timeout", 180 | EventTooManyInitializationRequests: "Too many initialisation requests", 181 | } 182 | 183 | // EventMessageType implementation 184 | func (m 
MessageType) EventMessageType() MessageType { 185 | return m 186 | } 187 | 188 | // Error implementation 189 | func (m MessageType) Error() string { 190 | return messageTypeDescriptions[m] 191 | } 192 | 193 | // Data encapsulates both client and server json payload, combining json.RawMessage for decoding and 194 | // arbitrary interface{} type for encoding 195 | type Data struct { 196 | Value interface{} 197 | json.RawMessage 198 | } 199 | 200 | // Ptr returns non-nil pointer to Data if it is not empty 201 | func (payload *Data) Ptr() *Data { 202 | if payload == nil || payload.Value == nil { 203 | return nil 204 | } 205 | 206 | return payload 207 | } 208 | 209 | // ReadPayloadData client-side method to parse server response 210 | func (payload *Data) ReadPayloadData() (*PayloadDataResponse, error) { 211 | if payload == nil { 212 | return nil, io.ErrUnexpectedEOF 213 | } 214 | 215 | var pd PayloadDataResponse 216 | 217 | err := json.Unmarshal(payload.RawMessage, &pd) 218 | if err != nil { 219 | return nil, err 220 | } 221 | 222 | payload.Value = pd 223 | 224 | return &pd, nil 225 | } 226 | 227 | // ReadPayloadError client-side method to parse server error response 228 | func (payload *Data) ReadPayloadError() (*PayloadError, error) { 229 | if payload == nil { 230 | return nil, io.ErrUnexpectedEOF 231 | } 232 | 233 | var pd PayloadError 234 | 235 | err := json.Unmarshal(payload.RawMessage, &pd) 236 | if err != nil { 237 | return nil, err 238 | } 239 | 240 | payload.Value = pd 241 | 242 | return &pd, nil 243 | } 244 | 245 | // ReadPayloadErrors client-side method to parse server error response 246 | func (payload *Data) ReadPayloadErrors() (pds []*PayloadError, err error) { 247 | if payload == nil { 248 | return nil, io.ErrUnexpectedEOF 249 | } 250 | 251 | err = json.Unmarshal(payload.RawMessage, &pds) 252 | if err != nil { 253 | return nil, err 254 | } 255 | 256 | payload.Value = pds 257 | 258 | return pds, nil 259 | } 260 | 261 | // UnmarshalJSON stores provided 
json as a RawMessage, as well as initializing Value to same RawMessage 262 | // to support both identity re-serialization and modification of Value after initialization 263 | func (payload *Data) UnmarshalJSON(bs []byte) (err error) { 264 | payload.RawMessage = bs 265 | payload.Value = payload.RawMessage 266 | 267 | return nil 268 | } 269 | 270 | // MarshalJSON marshals either provided or deserialized Value as json 271 | func (payload Data) MarshalJSON() (bs []byte, err error) { 272 | return json.Marshal(payload.Value) 273 | } 274 | 275 | // PayloadInit provides connection params 276 | type PayloadInit map[string]interface{} 277 | 278 | // PayloadOperation provides description for client-side operation initiation 279 | type PayloadOperation struct { 280 | Variables map[string]interface{} `json:"variables"` 281 | Extensions map[string]interface{} `json:"extensions"` 282 | Query string `json:"query"` 283 | OperationName string `json:"operationName"` 284 | } 285 | 286 | // PayloadDataRaw provides server-side response for previously started operation 287 | type PayloadDataRaw struct { 288 | Data Data `json:"data"` 289 | 290 | // see https://github.com/graph-gophers/graphql-go#custom-errors 291 | // for adding custom error attributes 292 | Errors []error `json:"errors,omitempty"` 293 | } 294 | 295 | // PayloadData type-alias for serialization 296 | type PayloadData PayloadDataRaw 297 | 298 | // MarshalJSON serializes PayloadData to JSON, excluding empty data 299 | func (payload PayloadData) MarshalJSON() (bs []byte, err error) { 300 | return json.Marshal(struct { 301 | Data *Data `json:"data,omitempty"` 302 | PayloadDataRaw 303 | }{ 304 | PayloadDataRaw: PayloadDataRaw(payload), 305 | Data: payload.Data.Ptr(), 306 | }) 307 | } 308 | 309 | // PayloadErrorLocation error location in originating request 310 | type PayloadErrorLocation struct { 311 | Line int `json:"line"` 312 | Column int `json:"column"` 313 | } 314 | 315 | // PayloadError client-side error representation 
316 | type PayloadError struct { 317 | Extensions map[string]interface{} `json:"extensions"` 318 | Message string `json:"message"` 319 | Locations []PayloadErrorLocation `json:"locations"` 320 | Path []string `json:"path"` 321 | } 322 | 323 | // PayloadDataResponse provides client-side payload representation 324 | type PayloadDataResponse struct { 325 | Data map[string]interface{} `json:"data,omitempty"` 326 | Errors []PayloadError `json:"errors,omitempty"` 327 | } 328 | 329 | // MessageRaw encapsulates every message within apollows protocol in both directions 330 | type MessageRaw struct { 331 | ID string `json:"id,omitempty"` 332 | Type Operation `json:"type"` 333 | Payload Data `json:"payload"` 334 | } 335 | 336 | // Message type-alias for (de-)serialization 337 | type Message MessageRaw 338 | 339 | // MarshalJSON serializes Message to JSON, excluding empty id or payload from serialized fields. 340 | func (message Message) MarshalJSON() (bs []byte, err error) { 341 | return json.Marshal(struct { 342 | Payload *Data `json:"payload,omitempty"` 343 | MessageRaw 344 | }{ 345 | MessageRaw: MessageRaw(message), 346 | Payload: message.Payload.Ptr(), 347 | }) 348 | } 349 | -------------------------------------------------------------------------------- /v1/apollows/proto_test.go: -------------------------------------------------------------------------------- 1 | package apollows 2 | 3 | import ( 4 | "encoding/json" 5 | "errors" 6 | "io" 7 | "testing" 8 | 9 | "github.com/stretchr/testify/assert" 10 | ) 11 | 12 | func TestDataMarshal(t *testing.T) { 13 | data := Data{ 14 | Value: 123, 15 | } 16 | 17 | bs, err := json.Marshal(data) 18 | 19 | assert.NoError(t, err) 20 | assert.Equal(t, "123", string(bs)) 21 | } 22 | 23 | func TestDataUnmarshal(t *testing.T) { 24 | var data Data 25 | 26 | err := json.Unmarshal([]byte(`123`), &data) 27 | 28 | assert.NoError(t, err) 29 | assert.Equal(t, "123", string(data.RawMessage)) 30 | } 31 | 32 | func TestDataUnmarshalCycle(t *testing.T) 
{ 33 | var data Data 34 | 35 | err := json.Unmarshal([]byte(`123`), &data) 36 | 37 | assert.NoError(t, err) 38 | assert.Equal(t, "123", string(data.RawMessage)) 39 | 40 | bs, err := json.Marshal(data) 41 | 42 | assert.NoError(t, err) 43 | assert.Equal(t, "123", string(bs)) 44 | } 45 | 46 | func TestDataReadPayloadData(t *testing.T) { 47 | data := Data{ 48 | Value: PayloadData{ 49 | Data: Data{ 50 | Value: map[string]interface{}{ 51 | "foo": "123", 52 | }, 53 | }, 54 | }, 55 | } 56 | 57 | bs, err := json.Marshal(data) 58 | 59 | assert.NoError(t, err) 60 | 61 | var ndata Data 62 | 63 | err = json.Unmarshal(bs, &ndata) 64 | 65 | assert.NoError(t, err) 66 | 67 | pd, err := ndata.ReadPayloadData() 68 | 69 | assert.NoError(t, err) 70 | 71 | assert.Equal(t, "123", pd.Data["foo"]) 72 | } 73 | 74 | func TestDataReadPayloadDataError(t *testing.T) { 75 | var ndata *Data 76 | 77 | pd, err := ndata.ReadPayloadData() 78 | 79 | assert.Error(t, err) 80 | assert.Nil(t, pd) 81 | 82 | ndata = &Data{ 83 | Value: nil, 84 | RawMessage: json.RawMessage(`foo`), 85 | } 86 | 87 | pd, err = ndata.ReadPayloadData() 88 | 89 | assert.Error(t, err) 90 | assert.Nil(t, pd) 91 | } 92 | 93 | func TestDataReadPayloadError(t *testing.T) { 94 | data := Data{ 95 | Value: PayloadError{ 96 | Message: "123", 97 | }, 98 | } 99 | 100 | bs, err := json.Marshal(data) 101 | 102 | assert.NoError(t, err) 103 | 104 | var ndata Data 105 | 106 | err = json.Unmarshal(bs, &ndata) 107 | 108 | assert.NoError(t, err) 109 | 110 | pd, err := ndata.ReadPayloadError() 111 | 112 | assert.NoError(t, err) 113 | assert.Equal(t, "123", pd.Message) 114 | } 115 | 116 | func TestDataReadPayloadErrorError(t *testing.T) { 117 | var ndata *Data 118 | 119 | pd, err := ndata.ReadPayloadError() 120 | 121 | assert.Error(t, err) 122 | assert.Nil(t, pd) 123 | 124 | ndata = &Data{ 125 | Value: nil, 126 | RawMessage: json.RawMessage(`foo`), 127 | } 128 | 129 | pd, err = ndata.ReadPayloadError() 130 | 131 | assert.Error(t, err) 132 | 
assert.Nil(t, pd) 133 | } 134 | 135 | func TestDataReadPayloadErrors(t *testing.T) { 136 | data := Data{ 137 | Value: []PayloadError{ 138 | { 139 | Message: "123", 140 | }, 141 | }, 142 | } 143 | 144 | bs, err := json.Marshal(data) 145 | 146 | assert.NoError(t, err) 147 | 148 | var ndata Data 149 | 150 | err = json.Unmarshal(bs, &ndata) 151 | 152 | assert.NoError(t, err) 153 | 154 | pd, err := ndata.ReadPayloadErrors() 155 | 156 | assert.NoError(t, err) 157 | assert.Len(t, pd, 1) 158 | assert.Equal(t, "123", pd[0].Message) 159 | } 160 | 161 | func TestDataReadPayloadErrorsError(t *testing.T) { 162 | var ndata *Data 163 | 164 | pd, err := ndata.ReadPayloadErrors() 165 | 166 | assert.Error(t, err) 167 | assert.Nil(t, pd) 168 | 169 | ndata = &Data{ 170 | Value: nil, 171 | RawMessage: json.RawMessage(`foo`), 172 | } 173 | 174 | pd, err = ndata.ReadPayloadErrors() 175 | 176 | assert.Error(t, err) 177 | assert.Nil(t, pd) 178 | } 179 | 180 | func TestWrapError(t *testing.T) { 181 | err := WrapError(io.EOF, EventUnauthorized) 182 | 183 | assert.True(t, errors.Is(err, io.EOF)) 184 | } 185 | 186 | func TestMessageMarshal(t *testing.T) { 187 | data := Message{ 188 | ID: "123", 189 | Type: OperationConnectionAck, 190 | Payload: Data{ 191 | Value: "foo", 192 | }, 193 | } 194 | 195 | bs, err := json.Marshal(data) 196 | 197 | assert.NoError(t, err) 198 | assert.JSONEq(t, `{"id":"123","type":"connection_ack","payload":"foo"}`, string(bs)) 199 | } 200 | 201 | func TestMessageMarshalEmpty(t *testing.T) { 202 | data := Message{ 203 | ID: "123", 204 | Type: OperationConnectionAck, 205 | Payload: Data{}, 206 | } 207 | 208 | bs, err := json.Marshal(data) 209 | 210 | assert.NoError(t, err) 211 | assert.JSONEq(t, `{"id":"123","type":"connection_ack"}`, string(bs)) 212 | } 213 | 214 | func TestPayloadDataMarshal(t *testing.T) { 215 | data := PayloadData{ 216 | Data: Data{ 217 | Value: "foo", 218 | }, 219 | } 220 | 221 | bs, err := json.Marshal(data) 222 | 223 | assert.NoError(t, err) 224 | 
assert.JSONEq(t, `{"data":"foo"}`, string(bs)) 225 | } 226 | 227 | func TestPayloadDataMarshalEmpty(t *testing.T) { 228 | data := PayloadData{ 229 | Data: Data{}, 230 | } 231 | 232 | bs, err := json.Marshal(data) 233 | 234 | assert.NoError(t, err) 235 | assert.JSONEq(t, `{}`, string(bs)) 236 | } 237 | -------------------------------------------------------------------------------- /v1/ast.go: -------------------------------------------------------------------------------- 1 | package wsgraphql 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | 7 | "github.com/eientei/wsgraphql/v1/apollows" 8 | "github.com/graphql-go/graphql" 9 | "github.com/graphql-go/graphql/gqlerrors" 10 | "github.com/graphql-go/graphql/language/ast" 11 | "github.com/graphql-go/graphql/language/parser" 12 | "github.com/graphql-go/graphql/language/source" 13 | ) 14 | 15 | func (server *serverImpl) handleExtensionsInits(p *graphql.Params) *graphql.Result { 16 | var errs gqlerrors.FormattedErrors 17 | 18 | for _, ext := range server.extensions { 19 | func() { 20 | defer func() { 21 | if r := recover(); r != nil { 22 | errs = append(errs, gqlerrors.FormatError(fmt.Errorf("%s.Init: %v", ext.Name(), r))) 23 | } 24 | }() 25 | 26 | p.Context = ext.Init(p.Context, p) 27 | }() 28 | } 29 | 30 | if len(errs) == 0 { 31 | return nil 32 | } 33 | 34 | return &graphql.Result{ 35 | Errors: errs, 36 | } 37 | } 38 | 39 | func (server *serverImpl) handleExtensionsParseDidStart( 40 | p *graphql.Params, 41 | ) (res *graphql.Result, endfn func(err error) *graphql.Result) { 42 | fs := make(map[string]graphql.ParseFinishFunc) 43 | 44 | var errs gqlerrors.FormattedErrors 45 | 46 | for _, ext := range server.extensions { 47 | var ( 48 | ctx context.Context 49 | finishFn graphql.ParseFinishFunc 50 | ) 51 | 52 | func() { 53 | defer func() { 54 | if r := recover(); r != nil { 55 | errs = append( 56 | errs, 57 | gqlerrors.FormatError(fmt.Errorf("%s.ParseDidStart: %v", ext.Name(), r)), 58 | ) 59 | } 60 | }() 61 | 62 | ctx, finishFn = 
ext.ParseDidStart(p.Context) 63 | 64 | p.Context = ctx 65 | 66 | fs[ext.Name()] = finishFn 67 | }() 68 | } 69 | 70 | endfn = func(err error) *graphql.Result { 71 | var inerrs gqlerrors.FormattedErrors 72 | 73 | if err != nil { 74 | inerrs = append(inerrs, gqlerrors.FormatError(err)) 75 | } 76 | 77 | for name, fn := range fs { 78 | func() { 79 | defer func() { 80 | if r := recover(); r != nil { 81 | inerrs = append( 82 | inerrs, 83 | gqlerrors.FormatError(fmt.Errorf("%s.ParseFinishFunc: %v", name, r)), 84 | ) 85 | } 86 | }() 87 | 88 | fn(err) 89 | }() 90 | } 91 | 92 | if len(inerrs) == 0 { 93 | return nil 94 | } 95 | 96 | return &graphql.Result{ 97 | Errors: inerrs, 98 | } 99 | } 100 | 101 | if len(errs) > 0 { 102 | res = &graphql.Result{ 103 | Errors: errs, 104 | } 105 | } 106 | 107 | return 108 | } 109 | 110 | func (server *serverImpl) handleExtensionsValidationDidStart( 111 | p *graphql.Params, 112 | ) (errs []gqlerrors.FormattedError, endfn func(errs []gqlerrors.FormattedError) []gqlerrors.FormattedError) { 113 | fs := make(map[string]graphql.ValidationFinishFunc) 114 | 115 | for _, ext := range server.extensions { 116 | var ( 117 | ctx context.Context 118 | finishFn graphql.ValidationFinishFunc 119 | ) 120 | 121 | func() { 122 | defer func() { 123 | if r := recover(); r != nil { 124 | errs = append( 125 | errs, 126 | gqlerrors.FormatError(fmt.Errorf("%s.ValidationDidStart: %v", ext.Name(), r)), 127 | ) 128 | } 129 | }() 130 | 131 | ctx, finishFn = ext.ValidationDidStart(p.Context) 132 | 133 | p.Context = ctx 134 | fs[ext.Name()] = finishFn 135 | }() 136 | } 137 | 138 | endfn = func(errs []gqlerrors.FormattedError) (inerrs []gqlerrors.FormattedError) { 139 | inerrs = append(inerrs, errs...) 
140 | 141 | for name, finishFn := range fs { 142 | func() { 143 | defer func() { 144 | if r := recover(); r != nil { 145 | inerrs = append( 146 | inerrs, 147 | gqlerrors.FormatError(fmt.Errorf("%s.ValidationFinishFunc: %v", name, r)), 148 | ) 149 | } 150 | }() 151 | 152 | finishFn(errs) 153 | }() 154 | } 155 | 156 | return 157 | } 158 | 159 | return 160 | } 161 | 162 | func (server *serverImpl) parseAST( 163 | ctx context.Context, 164 | payload *apollows.PayloadOperation, 165 | ) (err error) { 166 | opctx := OperationContext(ctx) 167 | 168 | var result *graphql.Result 169 | 170 | defer func() { 171 | if result != nil { 172 | err = ResultError{Result: result} 173 | } 174 | }() 175 | 176 | src := source.NewSource(&source.Source{ 177 | Body: []byte(payload.Query), 178 | Name: "GraphQL request", 179 | }) 180 | 181 | params := graphql.Params{ 182 | Schema: server.schema, 183 | RequestString: payload.Query, 184 | RootObject: server.rootObject, 185 | VariableValues: payload.Variables, 186 | OperationName: payload.OperationName, 187 | Context: ctx, 188 | } 189 | 190 | opctx.Set(ContextKeyOperationParams, ¶ms) 191 | 192 | result = server.handleExtensionsInits(¶ms) 193 | if result != nil { 194 | return 195 | } 196 | 197 | var parseFinishFn func(err error) *graphql.Result 198 | 199 | result, parseFinishFn = server.handleExtensionsParseDidStart(¶ms) 200 | if result != nil { 201 | return 202 | } 203 | 204 | astdoc, err := parser.Parse(parser.ParseParams{Source: src}) 205 | 206 | opctx.Set(ContextKeyAST, astdoc) 207 | 208 | result = parseFinishFn(err) 209 | if result != nil { 210 | return 211 | } 212 | 213 | errs, validationFinishFn := server.handleExtensionsValidationDidStart(¶ms) 214 | 215 | validationResult := graphql.ValidateDocument(¶ms.Schema, astdoc, nil) 216 | 217 | errs = append(errs, validationFinishFn(validationResult.Errors)...) 
218 | 219 | if len(errs) > 0 || !validationResult.IsValid { 220 | result = &graphql.Result{ 221 | Errors: errs, 222 | } 223 | 224 | return 225 | } 226 | 227 | var subscription bool 228 | 229 | for _, definition := range astdoc.Definitions { 230 | op, ok := definition.(*ast.OperationDefinition) 231 | if !ok { 232 | continue 233 | } 234 | 235 | if op.Operation == ast.OperationTypeSubscription { 236 | subscription = true 237 | 238 | break 239 | } 240 | } 241 | 242 | opctx.Set(ContextKeySubscription, subscription) 243 | 244 | return 245 | } 246 | -------------------------------------------------------------------------------- /v1/ast_test.go: -------------------------------------------------------------------------------- 1 | package wsgraphql 2 | 3 | import ( 4 | "context" 5 | "testing" 6 | 7 | "github.com/eientei/wsgraphql/v1/apollows" 8 | "github.com/eientei/wsgraphql/v1/mutable" 9 | "github.com/graphql-go/graphql" 10 | "github.com/graphql-go/graphql/gqlerrors" 11 | "github.com/stretchr/testify/assert" 12 | ) 13 | 14 | func TestASTParse(t *testing.T) { 15 | schema, err := graphql.NewSchema(graphql.SchemaConfig{ 16 | Query: graphql.NewObject(graphql.ObjectConfig{ 17 | Name: "QueryRoot", 18 | Interfaces: nil, 19 | Fields: graphql.Fields{ 20 | "foo": &graphql.Field{ 21 | Name: "FooType", 22 | Type: graphql.Int, 23 | Resolve: func(p graphql.ResolveParams) (interface{}, error) { 24 | return 123, nil 25 | }, 26 | }, 27 | }, 28 | }), 29 | }) 30 | 31 | assert.NoError(t, err) 32 | assert.NotNil(t, schema) 33 | 34 | server, err := NewServer(schema) 35 | 36 | assert.NoError(t, err) 37 | assert.NotNil(t, server) 38 | 39 | impl, ok := server.(*serverImpl) 40 | 41 | assert.True(t, ok) 42 | 43 | opctx := mutable.NewMutableContext(context.Background()) 44 | opctx.Set(ContextKeyOperationContext, opctx) 45 | 46 | err = impl.parseAST(opctx, &apollows.PayloadOperation{ 47 | Query: `query { foo }`, 48 | Variables: nil, 49 | OperationName: "", 50 | }) 51 | 52 | assert.Nil(t, err) 53 | 
assert.False(t, ContextSubscription(opctx)) 54 | assert.NotNil(t, ContextAST(opctx)) 55 | assert.NotNil(t, ContextOperationParams(opctx)) 56 | } 57 | 58 | func TestASTParseSubscription(t *testing.T) { 59 | schema, err := graphql.NewSchema(graphql.SchemaConfig{ 60 | Query: graphql.NewObject(graphql.ObjectConfig{ 61 | Name: "QueryRoot", 62 | Interfaces: nil, 63 | Fields: graphql.Fields{ 64 | "foo": &graphql.Field{ 65 | Name: "FooType", 66 | Type: graphql.Int, 67 | Resolve: func(p graphql.ResolveParams) (interface{}, error) { 68 | return 123, nil 69 | }, 70 | }, 71 | }, 72 | }), 73 | Subscription: graphql.NewObject(graphql.ObjectConfig{ 74 | Name: "SubscriptionRoot", 75 | Interfaces: nil, 76 | Fields: graphql.Fields{ 77 | "foo": &graphql.Field{ 78 | Name: "FooType", 79 | Type: graphql.Int, 80 | Subscribe: func(p graphql.ResolveParams) (interface{}, error) { 81 | return 123, nil 82 | }, 83 | }, 84 | }, 85 | }), 86 | }) 87 | 88 | assert.NoError(t, err) 89 | assert.NotNil(t, schema) 90 | 91 | server, err := NewServer(schema) 92 | 93 | assert.NoError(t, err) 94 | assert.NotNil(t, server) 95 | 96 | impl, ok := server.(*serverImpl) 97 | 98 | assert.True(t, ok) 99 | 100 | opctx := mutable.NewMutableContext(context.Background()) 101 | opctx.Set(ContextKeyOperationContext, opctx) 102 | 103 | err = impl.parseAST(opctx, &apollows.PayloadOperation{ 104 | Query: `subscription { foo }`, 105 | Variables: nil, 106 | OperationName: "", 107 | }) 108 | 109 | assert.Nil(t, err) 110 | assert.True(t, ContextSubscription(opctx)) 111 | assert.NotNil(t, ContextAST(opctx)) 112 | assert.NotNil(t, ContextOperationParams(opctx)) 113 | } 114 | 115 | type testExt struct { 116 | initFn func(ctx context.Context, p *graphql.Params) context.Context 117 | hasResultFn func() bool 118 | getResultFn func(context.Context) interface{} 119 | parseDidStartFn func(ctx context.Context) (context.Context, graphql.ParseFinishFunc) 120 | validationDidStartFn func(ctx context.Context) (context.Context, 
graphql.ValidationFinishFunc) 121 | executionDidStartFn func(ctx context.Context) (context.Context, graphql.ExecutionFinishFunc) 122 | resolveFieldDidStartFn func( 123 | ctx context.Context, 124 | i *graphql.ResolveInfo, 125 | ) (context.Context, graphql.ResolveFieldFinishFunc) 126 | name string 127 | } 128 | 129 | func (t *testExt) Init(ctx context.Context, p *graphql.Params) context.Context { 130 | return t.initFn(ctx, p) 131 | } 132 | 133 | func (t *testExt) Name() string { 134 | return t.name 135 | } 136 | 137 | func (t *testExt) HasResult() bool { 138 | return t.hasResultFn() 139 | } 140 | 141 | func (t *testExt) GetResult(ctx context.Context) interface{} { 142 | return t.getResultFn(ctx) 143 | } 144 | 145 | func (t *testExt) ParseDidStart(ctx context.Context) (context.Context, graphql.ParseFinishFunc) { 146 | return t.parseDidStartFn(ctx) 147 | } 148 | 149 | func (t *testExt) ValidationDidStart(ctx context.Context) (context.Context, graphql.ValidationFinishFunc) { 150 | return t.validationDidStartFn(ctx) 151 | } 152 | 153 | func (t *testExt) ExecutionDidStart(ctx context.Context) (context.Context, graphql.ExecutionFinishFunc) { 154 | return t.executionDidStartFn(ctx) 155 | } 156 | 157 | func (t *testExt) ResolveFieldDidStart( 158 | ctx context.Context, 159 | i *graphql.ResolveInfo, 160 | ) (context.Context, graphql.ResolveFieldFinishFunc) { 161 | return t.resolveFieldDidStartFn(ctx, i) 162 | } 163 | 164 | func testAstParseExtensions( 165 | t *testing.T, 166 | opctx mutable.Context, 167 | f func(ext *testExt), 168 | ) (err error) { 169 | text := &testExt{ 170 | name: "foo", 171 | initFn: func(ctx context.Context, p *graphql.Params) context.Context { 172 | return ctx 173 | }, 174 | hasResultFn: func() bool { 175 | return true 176 | }, 177 | getResultFn: func(ctx context.Context) interface{} { 178 | return nil 179 | }, 180 | parseDidStartFn: func(ctx context.Context) (context.Context, graphql.ParseFinishFunc) { 181 | return ctx, func(err error) { 182 | 183 | } 
184 | }, 185 | validationDidStartFn: func(ctx context.Context) (context.Context, graphql.ValidationFinishFunc) { 186 | return ctx, func(errors []gqlerrors.FormattedError) { 187 | 188 | } 189 | }, 190 | executionDidStartFn: func(ctx context.Context) (context.Context, graphql.ExecutionFinishFunc) { 191 | return ctx, func(result *graphql.Result) { 192 | 193 | } 194 | }, 195 | resolveFieldDidStartFn: func( 196 | ctx context.Context, 197 | i *graphql.ResolveInfo, 198 | ) (context.Context, graphql.ResolveFieldFinishFunc) { 199 | return ctx, func(i interface{}, err error) { 200 | 201 | } 202 | }, 203 | } 204 | 205 | f(text) 206 | 207 | schema, err := graphql.NewSchema(graphql.SchemaConfig{ 208 | Query: graphql.NewObject(graphql.ObjectConfig{ 209 | Name: "QueryRoot", 210 | Interfaces: nil, 211 | Fields: graphql.Fields{ 212 | "foo": &graphql.Field{ 213 | Name: "FooType", 214 | Type: graphql.Int, 215 | Resolve: func(p graphql.ResolveParams) (interface{}, error) { 216 | return 123, nil 217 | }, 218 | }, 219 | }, 220 | }), 221 | Extensions: []graphql.Extension{ 222 | text, 223 | }, 224 | }) 225 | 226 | assert.NoError(t, err) 227 | assert.NotNil(t, schema) 228 | 229 | server, err := NewServer(schema) 230 | 231 | assert.NoError(t, err) 232 | assert.NotNil(t, server) 233 | 234 | impl, ok := server.(*serverImpl) 235 | 236 | assert.True(t, ok) 237 | 238 | return impl.parseAST(opctx, &apollows.PayloadOperation{ 239 | Query: `query { foo }`, 240 | Variables: nil, 241 | OperationName: "", 242 | }) 243 | } 244 | 245 | func TestASTParseExtensions(t *testing.T) { 246 | opctx := mutable.NewMutableContext(context.Background()) 247 | opctx.Set(ContextKeyOperationContext, opctx) 248 | 249 | err := testAstParseExtensions(t, opctx, func(ext *testExt) { 250 | 251 | }) 252 | 253 | assert.Nil(t, err) 254 | assert.False(t, ContextSubscription(opctx)) 255 | assert.NotNil(t, ContextAST(opctx)) 256 | assert.NotNil(t, ContextOperationParams(opctx)) 257 | } 258 | 259 | func 
TestASTParseExtensionsPanicInit(t *testing.T) { 260 | opctx := mutable.NewMutableContext(context.Background()) 261 | opctx.Set(ContextKeyOperationContext, opctx) 262 | 263 | err := testAstParseExtensions(t, opctx, func(ext *testExt) { 264 | ext.initFn = func(ctx context.Context, p *graphql.Params) context.Context { 265 | panic(1) 266 | } 267 | }) 268 | 269 | assert.NotNil(t, err) 270 | assert.False(t, ContextSubscription(opctx)) 271 | assert.Nil(t, ContextAST(opctx)) 272 | assert.NotNil(t, ContextOperationParams(opctx)) 273 | } 274 | 275 | func TestASTParseExtensionsPanicValidation(t *testing.T) { 276 | opctx := mutable.NewMutableContext(context.Background()) 277 | opctx.Set(ContextKeyOperationContext, opctx) 278 | 279 | err := testAstParseExtensions(t, opctx, func(ext *testExt) { 280 | ext.validationDidStartFn = func(ctx context.Context) (context.Context, graphql.ValidationFinishFunc) { 281 | panic(1) 282 | } 283 | }) 284 | 285 | assert.NotNil(t, err) 286 | assert.False(t, ContextSubscription(opctx)) 287 | assert.NotNil(t, ContextAST(opctx)) 288 | assert.NotNil(t, ContextOperationParams(opctx)) 289 | } 290 | 291 | func TestASTParseExtensionsPanicValidationCb(t *testing.T) { 292 | opctx := mutable.NewMutableContext(context.Background()) 293 | opctx.Set(ContextKeyOperationContext, opctx) 294 | 295 | err := testAstParseExtensions(t, opctx, func(ext *testExt) { 296 | ext.validationDidStartFn = func(ctx context.Context) (context.Context, graphql.ValidationFinishFunc) { 297 | return ctx, func(errors []gqlerrors.FormattedError) { 298 | panic(1) 299 | } 300 | } 301 | }) 302 | 303 | assert.NotNil(t, err) 304 | assert.False(t, ContextSubscription(opctx)) 305 | assert.NotNil(t, ContextAST(opctx)) 306 | assert.NotNil(t, ContextOperationParams(opctx)) 307 | } 308 | 309 | func TestASTParseExtensionsPanicParse(t *testing.T) { 310 | opctx := mutable.NewMutableContext(context.Background()) 311 | opctx.Set(ContextKeyOperationContext, opctx) 312 | 313 | err := 
testAstParseExtensions(t, opctx, func(ext *testExt) { 314 | ext.parseDidStartFn = func(ctx context.Context) (context.Context, graphql.ParseFinishFunc) { 315 | panic(1) 316 | } 317 | }) 318 | 319 | assert.NotNil(t, err) 320 | assert.False(t, ContextSubscription(opctx)) 321 | assert.Nil(t, ContextAST(opctx)) 322 | assert.NotNil(t, ContextOperationParams(opctx)) 323 | } 324 | 325 | func TestASTParseExtensionsPanicParseCb(t *testing.T) { 326 | opctx := mutable.NewMutableContext(context.Background()) 327 | opctx.Set(ContextKeyOperationContext, opctx) 328 | 329 | err := testAstParseExtensions(t, opctx, func(ext *testExt) { 330 | ext.parseDidStartFn = func(ctx context.Context) (context.Context, graphql.ParseFinishFunc) { 331 | return ctx, func(err error) { 332 | panic(1) 333 | } 334 | } 335 | }) 336 | 337 | assert.NotNil(t, err) 338 | assert.False(t, ContextSubscription(opctx)) 339 | assert.NotNil(t, ContextAST(opctx)) 340 | assert.NotNil(t, ContextOperationParams(opctx)) 341 | } 342 | -------------------------------------------------------------------------------- /v1/callbacks.go: -------------------------------------------------------------------------------- 1 | package wsgraphql 2 | 3 | import ( 4 | "context" 5 | "net/http" 6 | 7 | "github.com/eientei/wsgraphql/v1/apollows" 8 | "github.com/eientei/wsgraphql/v1/mutable" 9 | "github.com/graphql-go/graphql" 10 | ) 11 | 12 | // WithCallbacks option sets callbacks handling various stages of requests 13 | // Deprecated: use WithInterceptors / WithResultProcessor 14 | func WithCallbacks(callbacks Callbacks) ServerOption { 15 | return WithExtraInterceptors(Interceptors{ 16 | HTTPRequest: legacyCallbackHTTPRequest(callbacks), 17 | Init: legacyCallbackInit(callbacks), 18 | Operation: legacyCallbackOperation(callbacks), 19 | OperationParse: legacyCallbackOperationParse(callbacks), 20 | OperationExecute: legacyCallbackExecute(callbacks), 21 | }) 22 | } 23 | 24 | // Callbacks supported by the server 25 | // use 
wsgraphql.ContextHTTPRequest / wsgraphql.ContextHTTPResponseWriter to access underlying 26 | // http.Request and http.ResponseWriter 27 | // Sequence: 28 | // OnRequest -> OnConnect -> 29 | // [ OnOperation -> OnOperationValidation -> OnOperationResult -> OnOperationDone ]* -> 30 | // OnDisconnect -> OnRequestDone 31 | // Deprecated: use Interceptors / ResultProcessor 32 | type Callbacks struct { 33 | // OnRequest called once HTTP request is received, before attempting to do websocket upgrade or plain request 34 | // execution, consequently before OnConnect as well. 35 | OnRequest func(reqctx mutable.Context, r *http.Request, w http.ResponseWriter) error 36 | 37 | // OnRequestDone called once HTTP request is finished, regardless of request type, with error occurred during 38 | // request execution (if any). 39 | // By default, if error is present, will write error text and return 400 code. 40 | OnRequestDone func(reqctx mutable.Context, r *http.Request, w http.ResponseWriter, origerr error) 41 | 42 | // OnConnect is called once per HTTP request, after websocket upgrade and init message received in case of 43 | // websocket request, or before execution in case of plain request 44 | OnConnect func(reqctx mutable.Context, init apollows.PayloadInit) error 45 | 46 | // OnDisconnect is called once per HTTP request, before OnRequestDone, without responsibility to handle errors 47 | OnDisconnect func(reqctx mutable.Context, origerr error) error 48 | 49 | // OnOperation is called before each operation with original payload, allowing to modify it or terminate 50 | // the operation by returning an error. 51 | OnOperation func(opctx mutable.Context, payload *apollows.PayloadOperation) error 52 | 53 | // OnOperationValidation is called after parsing an operation payload with any immediate validation result, if 54 | // available. AST will be available in context with ContextAST if parsing succeeded. 
55 | OnOperationValidation func(opctx mutable.Context, payload *apollows.PayloadOperation, result *graphql.Result) error 56 | 57 | // OnOperationResult is called after operation result is received, allowing to postprocess it or terminate the 58 | // operation before returning the result with error. AST is available in context with ContextAST. 59 | OnOperationResult func(opctx mutable.Context, payload *apollows.PayloadOperation, result *graphql.Result) error 60 | 61 | // OnOperationDone is called once operation is finished, with error occurred during the execution (if any) 62 | // error returned from this handler will close the websocket / terminate HTTP request with error response. 63 | // By default, will pass through any error occurred. AST will be available in context with ContextAST if can be 64 | // parsed. 65 | OnOperationDone func(opctx mutable.Context, payload *apollows.PayloadOperation, origerr error) error 66 | } 67 | 68 | func legacyCallbackExecute(callbacks Callbacks) InterceptorOperationExecute { 69 | return func( 70 | ctx context.Context, 71 | payload *apollows.PayloadOperation, 72 | handler HandlerOperationExecute, 73 | ) (chan *graphql.Result, error) { 74 | cres, err := handler(ctx, payload) 75 | if err != nil { 76 | return nil, err 77 | } 78 | 79 | if callbacks.OnOperationResult == nil { 80 | return cres, nil 81 | } 82 | 83 | ch := make(chan *graphql.Result) 84 | 85 | go func() { 86 | defer close(ch) 87 | 88 | for res := range cres { 89 | err = callbacks.OnOperationResult(OperationContext(ctx), payload, res) 90 | if err != nil { 91 | ch <- &graphql.Result{ 92 | Data: err, 93 | } 94 | 95 | return 96 | } 97 | 98 | ch <- res 99 | } 100 | }() 101 | 102 | return ch, nil 103 | } 104 | } 105 | 106 | func legacyCallbackOperationParse(callbacks Callbacks) InterceptorOperationParse { 107 | return func( 108 | ctx context.Context, 109 | payload *apollows.PayloadOperation, 110 | handler HandlerOperationParse, 111 | ) error { 112 | err := handler(ctx, payload) 
113 | 114 | if callbacks.OnOperationValidation != nil { 115 | result, _ := err.(ResultError) 116 | 117 | nerr := callbacks.OnOperationValidation(OperationContext(ctx), payload, result.Result) 118 | if nerr != nil { 119 | return nerr 120 | } 121 | } 122 | 123 | return err 124 | } 125 | } 126 | 127 | func legacyCallbackOperation(callbacks Callbacks) InterceptorOperation { 128 | return func( 129 | ctx context.Context, 130 | payload *apollows.PayloadOperation, 131 | handler HandlerOperation, 132 | ) (err error) { 133 | if callbacks.OnOperation != nil { 134 | err = callbacks.OnOperation(OperationContext(ctx), payload) 135 | } 136 | 137 | if err == nil { 138 | err = handler(ctx, payload) 139 | } 140 | 141 | if callbacks.OnOperationDone != nil { 142 | err = callbacks.OnOperationDone(OperationContext(ctx), payload, err) 143 | } 144 | 145 | return err 146 | } 147 | } 148 | 149 | func legacyCallbackInit(callbacks Callbacks) InterceptorInit { 150 | return func( 151 | ctx context.Context, 152 | init apollows.PayloadInit, 153 | handler HandlerInit, 154 | ) (err error) { 155 | if callbacks.OnConnect != nil { 156 | err = callbacks.OnConnect(RequestContext(ctx), init) 157 | } 158 | 159 | if err == nil { 160 | err = handler(ctx, init) 161 | } 162 | 163 | if callbacks.OnDisconnect != nil { 164 | err = callbacks.OnDisconnect(RequestContext(ctx), err) 165 | } 166 | 167 | return err 168 | } 169 | } 170 | 171 | func legacyCallbackHTTPRequest(callbacks Callbacks) InterceptorHTTPRequest { 172 | return func( 173 | ctx context.Context, 174 | w http.ResponseWriter, 175 | r *http.Request, 176 | handler HandlerHTTPRequest, 177 | ) error { 178 | var err error 179 | 180 | defer func() { 181 | if err != nil { 182 | WriteError(ctx, w, err) 183 | } 184 | }() 185 | 186 | if callbacks.OnRequest != nil { 187 | err = callbacks.OnRequest(RequestContext(ctx), r, w) 188 | } 189 | 190 | if err == nil { 191 | err = handler(ctx, w, r) 192 | } 193 | 194 | if callbacks.OnRequestDone != nil { 195 | 
callbacks.OnRequestDone(RequestContext(ctx), r, w, err) 196 | } else { 197 | WriteError(ctx, w, err) 198 | } 199 | 200 | return err 201 | } 202 | } 203 | -------------------------------------------------------------------------------- /v1/compat.go: -------------------------------------------------------------------------------- 1 | package wsgraphql 2 | 3 | import "net/http" 4 | 5 | // Upgrader interface used to upgrade HTTP request/response pair into a Conn 6 | type Upgrader interface { 7 | Upgrade(w http.ResponseWriter, r *http.Request, responseHeader http.Header) (Conn, error) 8 | } 9 | 10 | // Conn interface is used to abstract connection returned from Upgrader 11 | type Conn interface { 12 | ReadJSON(v interface{}) error 13 | WriteJSON(v interface{}) error 14 | Close(code int, message string) error 15 | Subprotocol() string 16 | } 17 | -------------------------------------------------------------------------------- /v1/compat/gorillaws/api.go: -------------------------------------------------------------------------------- 1 | // Package gorillaws provides compatibility for gorilla websocket upgrader 2 | package gorillaws 3 | 4 | import ( 5 | "net/http" 6 | 7 | "github.com/eientei/wsgraphql/v1" 8 | "github.com/gorilla/websocket" 9 | ) 10 | 11 | // Wrapper for gorilla websocket upgrader 12 | type Wrapper struct { 13 | *websocket.Upgrader 14 | } 15 | 16 | type conn struct { 17 | *websocket.Conn 18 | } 19 | 20 | func (conn conn) Close(code int, message string) error { 21 | origerr := conn.Conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(code, message)) 22 | 23 | err := conn.Conn.Close() 24 | if err == nil { 25 | err = origerr 26 | } 27 | 28 | return err 29 | } 30 | 31 | func (conn conn) Subprotocol() string { 32 | return conn.Conn.Subprotocol() 33 | } 34 | 35 | // Upgrade implementation 36 | func (g Wrapper) Upgrade(w http.ResponseWriter, r *http.Request, responseHeader http.Header) (wsgraphql.Conn, error) { 37 | c, err := 
g.Upgrader.Upgrade(w, r, responseHeader) 38 | if err != nil { 39 | return nil, err 40 | } 41 | 42 | return conn{ 43 | Conn: c, 44 | }, nil 45 | } 46 | 47 | // Wrap gorilla upgrader into wsgraphql-compatible interface 48 | func Wrap(upgrader *websocket.Upgrader) Wrapper { 49 | return Wrapper{ 50 | Upgrader: upgrader, 51 | } 52 | } 53 | -------------------------------------------------------------------------------- /v1/compat/otelwsgraphql/api.go: -------------------------------------------------------------------------------- 1 | // Package otelwsgraphql provides opentelemetry instrumentation for wsgraphql 2 | package otelwsgraphql 3 | 4 | import ( 5 | "context" 6 | "regexp" 7 | 8 | "github.com/eientei/wsgraphql/v1" 9 | "github.com/eientei/wsgraphql/v1/apollows" 10 | "go.opentelemetry.io/otel" 11 | "go.opentelemetry.io/otel/attribute" 12 | "go.opentelemetry.io/otel/codes" 13 | semconv "go.opentelemetry.io/otel/semconv/v1.20.0" 14 | "go.opentelemetry.io/otel/trace" 15 | ) 16 | 17 | const ( 18 | instrumentationName = "github.com/eientei/wsgraphql/v1/compat/otelwsgraphql" 19 | instrumentationVersion = "1.0.0" 20 | ) 21 | 22 | const ( 23 | operationQuery = "query" 24 | operationMutation = "mutation" 25 | operationSubscription = "subscription" 26 | ) 27 | 28 | // OperationOption provides customizations for operation interceptor 29 | type OperationOption interface { 30 | applyOperation(*operationConfig) 31 | } 32 | 33 | // NewOperationInterceptor returns new otel-span reporting wsgraphql operation interceptor 34 | func NewOperationInterceptor(options ...OperationOption) wsgraphql.InterceptorOperation { 35 | var c operationConfig 36 | 37 | defaultOptions := []OperationOption{ 38 | WithSpanNameResolver(DefaultSpanNameResolver), 39 | WithSpanAttributesResolver(DefaultSpanAttributesResolver), 40 | WithStartSpanOptions(trace.WithSpanKind(trace.SpanKindServer)), 41 | } 42 | 43 | for _, o := range append(defaultOptions, options...) 
{ 44 | o.applyOperation(&c) 45 | } 46 | 47 | return func(ctx context.Context, payload *apollows.PayloadOperation, handler wsgraphql.HandlerOperation) error { 48 | tracer := c.tracer 49 | 50 | if tracer == nil { 51 | if c.tracerProvider != nil { 52 | tracer = newTracer(c.tracerProvider) 53 | } else if span := trace.SpanFromContext(ctx); span.SpanContext().IsValid() { 54 | tracer = newTracer(span.TracerProvider()) 55 | } else { 56 | tracer = newTracer(otel.GetTracerProvider()) 57 | } 58 | } 59 | 60 | opts := append( 61 | []trace.SpanStartOption{trace.WithAttributes(c.attributesResolver(ctx, payload)...)}, 62 | c.startSpanOptions..., 63 | ) 64 | 65 | ctx, span := tracer.Start(ctx, c.nameResolver(ctx, payload), opts...) 66 | 67 | err := handler(ctx, payload) 68 | 69 | if err != nil { 70 | span.RecordError(err) 71 | span.SetStatus(codes.Error, err.Error()) 72 | } else { 73 | span.SetStatus(codes.Ok, "") 74 | } 75 | 76 | span.End() 77 | 78 | return err 79 | } 80 | } 81 | 82 | // SpanNameResolver determined span name from payload operation 83 | type SpanNameResolver func(ctx context.Context, payload *apollows.PayloadOperation) string 84 | 85 | // SpanAttributesResolver determines span attributes from payload operation 86 | type SpanAttributesResolver func(ctx context.Context, payload *apollows.PayloadOperation) []attribute.KeyValue 87 | 88 | // SpanASTAttributesResolver determines span attributes from payload operation after AST parsing 89 | type SpanASTAttributesResolver func(ctx context.Context, payload *apollows.PayloadOperation) []attribute.KeyValue 90 | 91 | // WithTracer provides predefined tracer instance 92 | func WithTracer(tracer trace.Tracer) OperationOption { 93 | return optionFunc(func(c *operationConfig) { 94 | c.tracer = tracer 95 | }) 96 | } 97 | 98 | // WithTracerProvider sets predefined tracer provider instance 99 | func WithTracerProvider(tracerProvider trace.TracerProvider) OperationOption { 100 | return optionFunc(func(c *operationConfig) { 101 | 
c.tracerProvider = tracerProvider 102 | }) 103 | } 104 | 105 | // WithStartSpanOptions provides extra starting span options 106 | func WithStartSpanOptions(spanOptions ...trace.SpanStartOption) OperationOption { 107 | return optionFunc(func(c *operationConfig) { 108 | c.startSpanOptions = spanOptions 109 | }) 110 | } 111 | 112 | // WithSpanNameResolver provides custom name resolver 113 | func WithSpanNameResolver(resolver SpanNameResolver) OperationOption { 114 | return optionFunc(func(c *operationConfig) { 115 | c.nameResolver = resolver 116 | }) 117 | } 118 | 119 | // WithSpanAttributesResolver provides custom attribute resolver 120 | func WithSpanAttributesResolver(resolver SpanAttributesResolver) OperationOption { 121 | return optionFunc(func(c *operationConfig) { 122 | c.attributesResolver = resolver 123 | }) 124 | } 125 | 126 | var queryRegex = regexp.MustCompile(`(query|mutation|subscription)\s*(\w*)`) 127 | 128 | // DefaultSpanNameResolver default span name resolver function 129 | func DefaultSpanNameResolver(_ context.Context, payload *apollows.PayloadOperation) string { 130 | parts := queryRegex.FindStringSubmatch(payload.Query) 131 | name := payload.OperationName 132 | kind := operationQuery 133 | 134 | if len(parts) == 3 { 135 | kind = parts[1] 136 | 137 | if name == "" { 138 | name = parts[2] 139 | } 140 | } 141 | 142 | if name == "" { 143 | return "gql." + kind 144 | } 145 | 146 | return "gql." + kind + "." 
+ name 147 | } 148 | 149 | // DefaultSpanAttributesResolver default span attributes resolver function 150 | func DefaultSpanAttributesResolver( 151 | _ context.Context, 152 | payload *apollows.PayloadOperation, 153 | ) (attrs []attribute.KeyValue) { 154 | parts := queryRegex.FindStringSubmatch(payload.Query) 155 | operationName := payload.OperationName 156 | kind := operationQuery 157 | 158 | if len(parts) == 3 { 159 | switch parts[1] { 160 | case operationSubscription, operationMutation: 161 | kind = parts[1] 162 | } 163 | 164 | if operationName == "" { 165 | operationName = parts[2] 166 | } 167 | } 168 | 169 | switch kind { 170 | case operationSubscription: 171 | attrs = append(attrs, semconv.GraphqlOperationTypeSubscription) 172 | case operationMutation: 173 | attrs = append(attrs, semconv.GraphqlOperationTypeMutation) 174 | case operationQuery: 175 | attrs = append(attrs, semconv.GraphqlOperationTypeQuery) 176 | } 177 | 178 | if operationName != "" { 179 | attrs = append(attrs, semconv.GraphqlOperationName(operationName)) 180 | } 181 | 182 | attrs = append(attrs, semconv.GraphqlDocument(payload.Query)) 183 | 184 | return 185 | } 186 | 187 | type operationConfig struct { 188 | nameResolver SpanNameResolver 189 | attributesResolver SpanAttributesResolver 190 | tracer trace.Tracer 191 | tracerProvider trace.TracerProvider 192 | startSpanOptions []trace.SpanStartOption 193 | } 194 | 195 | type optionFunc func(c *operationConfig) 196 | 197 | func (o optionFunc) applyOperation(c *operationConfig) { 198 | o(c) 199 | } 200 | 201 | func newTracer(provider trace.TracerProvider) trace.Tracer { 202 | return provider.Tracer(instrumentationName, trace.WithInstrumentationVersion(instrumentationVersion)) 203 | } 204 | -------------------------------------------------------------------------------- /v1/context.go: -------------------------------------------------------------------------------- 1 | package wsgraphql 2 | 3 | import ( 4 | "context" 5 | "net/http" 6 | 7 | 
"github.com/graphql-go/graphql" 8 | 9 | "github.com/eientei/wsgraphql/v1/mutable" 10 | "github.com/graphql-go/graphql/language/ast" 11 | ) 12 | 13 | type ( 14 | contextKeyRequestContextT struct{} 15 | contextKeyOperationContextT struct{} 16 | contextKeyOperationStoppedT struct{} 17 | contextKeyOperationExecutedT struct{} 18 | contextKeyOperationIDT struct{} 19 | contextKeyOperationParamsT struct{} 20 | contextKeyAstT struct{} 21 | contextKeySubscriptionT struct{} 22 | contextKeyHTTPRequestT struct{} 23 | contextKeyHTTPResponseWriterT struct{} 24 | contextKeyHTTPResponseStartedT struct{} 25 | contextKeyWebsocketConnectionT struct{} 26 | ) 27 | 28 | var ( 29 | // ContextKeyRequestContext used to store HTTP request-scoped mutable.Context 30 | ContextKeyRequestContext = contextKeyRequestContextT{} 31 | 32 | // ContextKeyOperationContext used to store graphql operation-scoped mutable.Context 33 | ContextKeyOperationContext = contextKeyOperationContextT{} 34 | 35 | // ContextKeyOperationStopped indicates the operation was stopped on client request 36 | ContextKeyOperationStopped = contextKeyOperationStoppedT{} 37 | 38 | // ContextKeyOperationExecuted indicates the operation was executed 39 | ContextKeyOperationExecuted = contextKeyOperationExecutedT{} 40 | 41 | // ContextKeyOperationID indicates the operation ID 42 | ContextKeyOperationID = contextKeyOperationIDT{} 43 | 44 | // ContextKeyOperationParams used to store operation params 45 | ContextKeyOperationParams = contextKeyOperationParamsT{} 46 | 47 | // ContextKeyAST used to store operation's ast.Document (abstract syntax tree) 48 | ContextKeyAST = contextKeyAstT{} 49 | 50 | // ContextKeySubscription used to store operation subscription flag 51 | ContextKeySubscription = contextKeySubscriptionT{} 52 | 53 | // ContextKeyHTTPRequest used to store HTTP request 54 | ContextKeyHTTPRequest = contextKeyHTTPRequestT{} 55 | 56 | // ContextKeyHTTPResponseWriter used to store HTTP response 57 | ContextKeyHTTPResponseWriter = 
contextKeyHTTPResponseWriterT{} 58 | 59 | // ContextKeyHTTPResponseStarted used to indicate HTTP response already has headers sent 60 | ContextKeyHTTPResponseStarted = contextKeyHTTPResponseStartedT{} 61 | 62 | // ContextKeyWebsocketConnection used to store websocket connection 63 | ContextKeyWebsocketConnection = contextKeyWebsocketConnectionT{} 64 | ) 65 | 66 | func defaultMutcontext(ctx context.Context, mutctx mutable.Context) mutable.Context { 67 | if mutctx != nil { 68 | return mutctx 69 | } 70 | 71 | return mutable.NewMutableContext(ctx) 72 | } 73 | 74 | // RequestContext returns HTTP request-scoped v1.mutable context from provided context or nil if none present 75 | func RequestContext(ctx context.Context) (mutctx mutable.Context) { 76 | defer func() { 77 | mutctx = defaultMutcontext(ctx, mutctx) 78 | }() 79 | 80 | v := ctx.Value(ContextKeyRequestContext) 81 | if v == nil { 82 | return nil 83 | } 84 | 85 | mutctx, ok := v.(mutable.Context) 86 | if !ok { 87 | return nil 88 | } 89 | 90 | return mutctx 91 | } 92 | 93 | // OperationContext returns graphql operation-scoped v1.mutable context from provided context or nil if none present 94 | func OperationContext(ctx context.Context) (mutctx mutable.Context) { 95 | defer func() { 96 | mutctx = defaultMutcontext(ctx, mutctx) 97 | }() 98 | 99 | v := ctx.Value(ContextKeyOperationContext) 100 | if v == nil { 101 | return nil 102 | } 103 | 104 | mutctx, ok := v.(mutable.Context) 105 | if !ok { 106 | return nil 107 | } 108 | 109 | return mutctx 110 | } 111 | 112 | // ContextOperationStopped returns true if user requested operation stop 113 | func ContextOperationStopped(ctx context.Context) bool { 114 | v := ctx.Value(ContextKeyOperationStopped) 115 | if v == nil { 116 | return false 117 | } 118 | 119 | res, ok := v.(bool) 120 | if !ok { 121 | return false 122 | } 123 | 124 | return res 125 | } 126 | 127 | // ContextOperationExecuted returns true if user requested operation stop 128 | func ContextOperationExecuted(ctx 
context.Context) bool { 129 | v := ctx.Value(ContextKeyOperationExecuted) 130 | if v == nil { 131 | return false 132 | } 133 | 134 | res, ok := v.(bool) 135 | if !ok { 136 | return false 137 | } 138 | 139 | return res 140 | } 141 | 142 | // ContextOperationID returns operaion ID stored in the context 143 | func ContextOperationID(ctx context.Context) string { 144 | v := ctx.Value(ContextKeyOperationID) 145 | if v == nil { 146 | return "" 147 | } 148 | 149 | res, ok := v.(string) 150 | if !ok { 151 | return "" 152 | } 153 | 154 | return res 155 | } 156 | 157 | // ContextOperationParams returns operation params stored in the context 158 | func ContextOperationParams(ctx context.Context) (res *graphql.Params) { 159 | v := ctx.Value(ContextKeyOperationParams) 160 | if v == nil { 161 | return &graphql.Params{} 162 | } 163 | 164 | r, ok := v.(*graphql.Params) 165 | if !ok || r == nil { 166 | return &graphql.Params{} 167 | } 168 | 169 | return r 170 | } 171 | 172 | // ContextAST returns operation's abstract syntax tree document 173 | func ContextAST(ctx context.Context) *ast.Document { 174 | v := ctx.Value(ContextKeyAST) 175 | if v == nil { 176 | return nil 177 | } 178 | 179 | astdoc, ok := v.(*ast.Document) 180 | if !ok { 181 | return nil 182 | } 183 | 184 | return astdoc 185 | } 186 | 187 | // ContextSubscription returns operation's subscription flag 188 | func ContextSubscription(ctx context.Context) bool { 189 | v := ctx.Value(ContextKeySubscription) 190 | if v == nil { 191 | return false 192 | } 193 | 194 | sub, ok := v.(bool) 195 | if !ok { 196 | return false 197 | } 198 | 199 | return sub 200 | } 201 | 202 | // ContextHTTPRequest returns http request stored in a context 203 | func ContextHTTPRequest(ctx context.Context) *http.Request { 204 | v := ctx.Value(ContextKeyHTTPRequest) 205 | if v == nil { 206 | return nil 207 | } 208 | 209 | req, ok := v.(*http.Request) 210 | if !ok { 211 | return nil 212 | } 213 | 214 | return req 215 | } 216 | 217 | // 
ContextHTTPResponseWriter returns http response writer stored in a context 218 | func ContextHTTPResponseWriter(ctx context.Context) http.ResponseWriter { 219 | v := ctx.Value(ContextKeyHTTPResponseWriter) 220 | if v == nil { 221 | return nil 222 | } 223 | 224 | req, ok := v.(http.ResponseWriter) 225 | if !ok { 226 | return nil 227 | } 228 | 229 | return req 230 | } 231 | 232 | // ContextHTTPResponseStarted returns true if HTTP response has already headers sent 233 | func ContextHTTPResponseStarted(ctx context.Context) bool { 234 | v := ctx.Value(ContextKeyHTTPResponseStarted) 235 | if v == nil { 236 | return false 237 | } 238 | 239 | val, ok := v.(bool) 240 | if !ok { 241 | return false 242 | } 243 | 244 | return val 245 | } 246 | 247 | // ContextWebsocketConnection returns websocket connection stored in a context 248 | func ContextWebsocketConnection(ctx context.Context) Conn { 249 | v := ctx.Value(ContextKeyWebsocketConnection) 250 | if v == nil { 251 | return nil 252 | } 253 | 254 | conn, ok := v.(Conn) 255 | if !ok { 256 | return nil 257 | } 258 | 259 | return conn 260 | } 261 | -------------------------------------------------------------------------------- /v1/context_test.go: -------------------------------------------------------------------------------- 1 | package wsgraphql 2 | 3 | import ( 4 | "context" 5 | "net/http" 6 | "testing" 7 | 8 | "github.com/eientei/wsgraphql/v1/mutable" 9 | "github.com/graphql-go/graphql/language/ast" 10 | "github.com/stretchr/testify/assert" 11 | ) 12 | 13 | func TestRequestContext(t *testing.T) { 14 | reqctx := mutable.NewMutableContext(context.Background()) 15 | 16 | mutctx := mutable.NewMutableContext(context.Background()) 17 | 18 | r := RequestContext(mutctx) 19 | 20 | assert.NotNil(t, r) 21 | assert.NotEqual(t, reqctx, r) 22 | 23 | mutctx.Set(ContextKeyRequestContext, reqctx) 24 | 25 | assert.Equal(t, reqctx, RequestContext(mutctx)) 26 | 27 | mutctx.Set(ContextKeyRequestContext, 123) 28 | 29 | r = RequestContext(mutctx) 
30 | 31 | assert.NotNil(t, r) 32 | assert.NotEqual(t, reqctx, r) 33 | } 34 | 35 | func TestOperationContext(t *testing.T) { 36 | reqctx := mutable.NewMutableContext(context.Background()) 37 | 38 | mutctx := mutable.NewMutableContext(context.Background()) 39 | 40 | r := OperationContext(mutctx) 41 | 42 | assert.NotNil(t, r) 43 | assert.NotEqual(t, reqctx, r) 44 | 45 | mutctx.Set(ContextKeyOperationContext, reqctx) 46 | 47 | assert.Equal(t, reqctx, OperationContext(mutctx)) 48 | 49 | mutctx.Set(ContextKeyOperationContext, 123) 50 | 51 | r = OperationContext(mutctx) 52 | 53 | assert.NotNil(t, r) 54 | assert.NotEqual(t, reqctx, r) 55 | } 56 | 57 | func TestContextHTTPResponseWriter(t *testing.T) { 58 | var h struct { 59 | http.ResponseWriter 60 | } 61 | 62 | mutctx := mutable.NewMutableContext(context.Background()) 63 | 64 | assert.Nil(t, ContextHTTPResponseWriter(mutctx)) 65 | 66 | mutctx.Set(ContextKeyHTTPResponseWriter, h) 67 | 68 | assert.Equal(t, h, ContextHTTPResponseWriter(mutctx)) 69 | 70 | mutctx.Set(ContextKeyHTTPResponseWriter, 123) 71 | 72 | assert.Nil(t, ContextHTTPResponseWriter(mutctx)) 73 | } 74 | 75 | func TestContextHTTPResponseStarted(t *testing.T) { 76 | mutctx := mutable.NewMutableContext(context.Background()) 77 | 78 | assert.Equal(t, false, ContextHTTPResponseStarted(mutctx)) 79 | 80 | mutctx.Set(ContextKeyHTTPResponseStarted, true) 81 | 82 | assert.Equal(t, true, ContextHTTPResponseStarted(mutctx)) 83 | 84 | mutctx.Set(ContextKeyHTTPResponseStarted, 123) 85 | 86 | assert.Equal(t, false, ContextHTTPResponseStarted(mutctx)) 87 | } 88 | 89 | func TestContextHTTPRequest(t *testing.T) { 90 | h := &http.Request{} 91 | 92 | mutctx := mutable.NewMutableContext(context.Background()) 93 | 94 | assert.Nil(t, ContextHTTPRequest(mutctx)) 95 | 96 | mutctx.Set(ContextKeyHTTPRequest, h) 97 | 98 | assert.Equal(t, h, ContextHTTPRequest(mutctx)) 99 | 100 | mutctx.Set(ContextKeyHTTPRequest, 123) 101 | 102 | assert.Nil(t, ContextHTTPRequest(mutctx)) 103 | } 104 | 
105 | func TestContextSubscription(t *testing.T) { 106 | mutctx := mutable.NewMutableContext(context.Background()) 107 | 108 | assert.Equal(t, false, ContextSubscription(mutctx)) 109 | 110 | mutctx.Set(ContextKeySubscription, true) 111 | 112 | assert.Equal(t, true, ContextSubscription(mutctx)) 113 | 114 | mutctx.Set(ContextKeySubscription, 123) 115 | 116 | assert.Equal(t, false, ContextSubscription(mutctx)) 117 | } 118 | 119 | func TestContextAST(t *testing.T) { 120 | doc := &ast.Document{} 121 | 122 | mutctx := mutable.NewMutableContext(context.Background()) 123 | 124 | assert.Nil(t, ContextAST(mutctx)) 125 | 126 | mutctx.Set(ContextKeyAST, doc) 127 | 128 | assert.Equal(t, doc, ContextAST(mutctx)) 129 | 130 | mutctx.Set(ContextKeyAST, 123) 131 | 132 | assert.Nil(t, ContextAST(mutctx)) 133 | } 134 | 135 | func TestContextWebsocketConnection(t *testing.T) { 136 | var conn struct { 137 | Conn 138 | } 139 | 140 | mutctx := mutable.NewMutableContext(context.Background()) 141 | 142 | assert.Nil(t, ContextWebsocketConnection(mutctx)) 143 | 144 | mutctx.Set(ContextKeyWebsocketConnection, conn) 145 | 146 | assert.Equal(t, conn, ContextWebsocketConnection(mutctx)) 147 | 148 | mutctx.Set(ContextKeyWebsocketConnection, 123) 149 | 150 | assert.Nil(t, ContextWebsocketConnection(mutctx)) 151 | } 152 | 153 | func TestContextOperationStopped(t *testing.T) { 154 | mutctx := mutable.NewMutableContext(context.Background()) 155 | 156 | assert.Equal(t, false, ContextOperationStopped(mutctx)) 157 | 158 | mutctx.Set(ContextKeyOperationStopped, true) 159 | 160 | assert.Equal(t, true, ContextOperationStopped(mutctx)) 161 | 162 | mutctx.Set(ContextKeyOperationStopped, 123) 163 | 164 | assert.Equal(t, false, ContextOperationStopped(mutctx)) 165 | } 166 | -------------------------------------------------------------------------------- /v1/error.go: -------------------------------------------------------------------------------- 1 | package wsgraphql 2 | 3 | import ( 4 | 
"github.com/graphql-go/graphql/gqlerrors" 5 | "github.com/graphql-go/graphql/language/location" 6 | ) 7 | 8 | func wrapExtendedError(err error, loc []location.SourceLocation) error { 9 | _, ok := err.(gqlerrors.ExtendedError) 10 | if ok { 11 | return &gqlerrors.Error{ 12 | Message: err.Error(), 13 | OriginalError: err, 14 | Locations: loc, 15 | } 16 | } 17 | 18 | return err 19 | } 20 | 21 | // FormatError returns error formatted as graphql error 22 | func FormatError(err error) gqlerrors.FormattedError { 23 | var loc []location.SourceLocation 24 | 25 | fmterr, ok := err.(gqlerrors.FormattedError) 26 | if ok { 27 | err = fmterr.OriginalError() 28 | loc = fmterr.Locations 29 | } 30 | 31 | _, ok = err.(*gqlerrors.Error) 32 | if ok { 33 | return gqlerrors.FormatError(err) 34 | } 35 | 36 | return gqlerrors.FormatError(wrapExtendedError(err, loc)) 37 | } 38 | -------------------------------------------------------------------------------- /v1/examples/minimal-graphql-transport-ws/README.md: -------------------------------------------------------------------------------- 1 | Minimal server example 2 | 3 | Running 4 | ------- 5 | 6 | ```go 7 | go run main.go -addr :8080 8 | ``` 9 | 10 | Graphql endpoint will be available at http://127.0.0.1:8080/query 11 | 12 | There is no playground in this example, but you can try following query with graphql client of your choice: 13 | 14 | ```graphql 15 | query { 16 | getFoo 17 | } 18 | ``` 19 | -------------------------------------------------------------------------------- /v1/examples/minimal-graphql-transport-ws/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "flag" 5 | "net/http" 6 | 7 | "github.com/eientei/wsgraphql/v1" 8 | "github.com/eientei/wsgraphql/v1/compat/gorillaws" 9 | "github.com/gorilla/websocket" 10 | "github.com/graphql-go/graphql" 11 | ) 12 | 13 | func main() { 14 | var addr string 15 | 16 | flag.StringVar(&addr, "addr", ":8080", 
"Address to listen on") 17 | flag.Parse() 18 | 19 | schema, err := graphql.NewSchema(graphql.SchemaConfig{ 20 | Query: graphql.NewObject(graphql.ObjectConfig{ 21 | Name: "QueryRoot", 22 | Fields: graphql.Fields{ 23 | "getFoo": &graphql.Field{ 24 | Description: "Returns most recent foo value", 25 | Type: graphql.NewNonNull(graphql.Int), 26 | Resolve: func(p graphql.ResolveParams) (interface{}, error) { 27 | return 123, nil 28 | }, 29 | }, 30 | }, 31 | }), 32 | }) 33 | if err != nil { 34 | panic(err) 35 | } 36 | 37 | srv, err := wsgraphql.NewServer( 38 | schema, 39 | wsgraphql.WithProtocol(wsgraphql.WebsocketSubprotocolGraphqlTransportWS), 40 | wsgraphql.WithUpgrader(gorillaws.Wrap(&websocket.Upgrader{ 41 | Subprotocols: []string{wsgraphql.WebsocketSubprotocolGraphqlTransportWS.String()}, 42 | })), 43 | ) 44 | if err != nil { 45 | panic(err) 46 | } 47 | 48 | http.Handle("/query", srv) 49 | 50 | err = http.ListenAndServe(addr, nil) 51 | if err != nil { 52 | panic(err) 53 | } 54 | } 55 | -------------------------------------------------------------------------------- /v1/examples/minimal-graphql-ws/README.md: -------------------------------------------------------------------------------- 1 | Minimal server example 2 | 3 | Running 4 | ------- 5 | 6 | ```go 7 | go run main.go -addr :8080 8 | ``` 9 | 10 | Graphql endpoint will be available at http://127.0.0.1:8080/query 11 | 12 | There is no playground in this example, but you can try following query with graphql client of your choice: 13 | 14 | ```graphql 15 | query { 16 | getFoo 17 | } 18 | ``` 19 | -------------------------------------------------------------------------------- /v1/examples/minimal-graphql-ws/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "flag" 5 | "net/http" 6 | 7 | "github.com/eientei/wsgraphql/v1" 8 | "github.com/eientei/wsgraphql/v1/compat/gorillaws" 9 | "github.com/gorilla/websocket" 10 | 
"github.com/graphql-go/graphql" 11 | ) 12 | 13 | func main() { 14 | var addr string 15 | 16 | flag.StringVar(&addr, "addr", ":8080", "Address to listen on") 17 | flag.Parse() 18 | 19 | schema, err := graphql.NewSchema(graphql.SchemaConfig{ 20 | Query: graphql.NewObject(graphql.ObjectConfig{ 21 | Name: "QueryRoot", 22 | Fields: graphql.Fields{ 23 | "getFoo": &graphql.Field{ 24 | Description: "Returns most recent foo value", 25 | Type: graphql.NewNonNull(graphql.Int), 26 | Resolve: func(p graphql.ResolveParams) (interface{}, error) { 27 | return 123, nil 28 | }, 29 | }, 30 | }, 31 | }), 32 | }) 33 | if err != nil { 34 | panic(err) 35 | } 36 | 37 | srv, err := wsgraphql.NewServer( 38 | schema, 39 | wsgraphql.WithProtocol(wsgraphql.WebsocketSubprotocolGraphqlWS), 40 | wsgraphql.WithUpgrader(gorillaws.Wrap(&websocket.Upgrader{ 41 | Subprotocols: []string{wsgraphql.WebsocketSubprotocolGraphqlWS.String()}, 42 | })), 43 | ) 44 | if err != nil { 45 | panic(err) 46 | } 47 | 48 | http.Handle("/query", srv) 49 | 50 | err = http.ListenAndServe(addr, nil) 51 | if err != nil { 52 | panic(err) 53 | } 54 | } 55 | -------------------------------------------------------------------------------- /v1/examples/simpleserver/README.md: -------------------------------------------------------------------------------- 1 | Complete server example, with query, mutation and subscriptions. 
2 | 3 | Running 4 | ------- 5 | 6 | ```go 7 | go run main.go -addr :8080 8 | ``` 9 | 10 | GraphQL endpoint will be available at http://127.0.0.1:8080/query 11 | 12 | Navigate to playground on http://127.0.0.1:8080 13 | 14 | And try following queries: 15 | 16 | ```graphql 17 | subscription { 18 | fooUpdates 19 | } 20 | ``` 21 | 22 | Then in new tab(s) 23 | 24 | ```graphql 25 | query { 26 | getFoo 27 | } 28 | ``` 29 | 30 | ```graphql 31 | mutation { 32 | setFoo(value: 123) 33 | } 34 | ``` 35 | 36 | ```graphql 37 | query { 38 | getFoo 39 | } 40 | ``` 41 | -------------------------------------------------------------------------------- /v1/examples/simpleserver/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "bytes" 5 | "context" 6 | _ "embed" 7 | "flag" 8 | "fmt" 9 | "net/http" 10 | "sync/atomic" 11 | "time" 12 | 13 | "github.com/eientei/wsgraphql/v1" 14 | "github.com/eientei/wsgraphql/v1/compat/gorillaws" 15 | "github.com/gorilla/websocket" 16 | "github.com/graphql-go/graphql" 17 | ) 18 | 19 | //go:embed playground.html 20 | var playgroundFile []byte 21 | 22 | func main() { 23 | var addr string 24 | 25 | flag.StringVar(&addr, "addr", ":8080", "Address to listen on") 26 | flag.Parse() 27 | 28 | var foo int 29 | 30 | fooupdates := make(chan int, 1) 31 | 32 | var subscriberID uint64 33 | 34 | type subscriber struct { 35 | subscription chan interface{} 36 | ctx context.Context 37 | id uint64 38 | } 39 | 40 | subscribers := make(map[uint64]*subscriber) 41 | subscriberadd := make(chan *subscriber, 1) 42 | subscriberrem := make(chan uint64, 1) 43 | 44 | go func() { 45 | for { 46 | select { 47 | case upd := <-fooupdates: 48 | foo = upd 49 | 50 | fmt.Println("broadcasting update, new value:", upd) 51 | 52 | for _, sub := range subscribers { 53 | select { 54 | case sub.subscription <- upd: 55 | case <-sub.ctx.Done(): 56 | } 57 | } 58 | case add := <-subscriberadd: 59 | subscribers[add.id] = add 60 
| 61 | fmt.Println("added subscriber", add.id) 62 | case rem := <-subscriberrem: 63 | close(subscribers[rem].subscription) 64 | 65 | delete(subscribers, rem) 66 | 67 | fmt.Println("removed subscriber", rem) 68 | } 69 | } 70 | }() 71 | 72 | schema, err := graphql.NewSchema(graphql.SchemaConfig{ 73 | Query: graphql.NewObject(graphql.ObjectConfig{ 74 | Name: "QueryRoot", 75 | Fields: graphql.Fields{ 76 | "getFoo": &graphql.Field{ 77 | Description: "Returns most recent foo value", 78 | Type: graphql.NewNonNull(graphql.Int), 79 | Resolve: func(p graphql.ResolveParams) (interface{}, error) { 80 | return foo, nil 81 | }, 82 | }, 83 | }, 84 | }), 85 | Mutation: graphql.NewObject(graphql.ObjectConfig{ 86 | Name: "MutationRoot", 87 | Fields: graphql.Fields{ 88 | "setFoo": &graphql.Field{ 89 | Args: graphql.FieldConfigArgument{ 90 | "value": &graphql.ArgumentConfig{ 91 | Type: graphql.Int, 92 | }, 93 | }, 94 | Description: "Updates foo value; generating an update to subscribers of fooUpdates", 95 | Type: graphql.NewNonNull(graphql.Boolean), 96 | Resolve: func(p graphql.ResolveParams) (interface{}, error) { 97 | v, ok := p.Args["value"].(int) 98 | if ok { 99 | select { 100 | case <-p.Context.Done(): 101 | return nil, p.Context.Err() 102 | case fooupdates <- v: 103 | } 104 | } 105 | 106 | return ok, nil 107 | }, 108 | }, 109 | }, 110 | }), 111 | Subscription: graphql.NewObject(graphql.ObjectConfig{ 112 | Name: "SubscriptionRoot", 113 | Fields: graphql.Fields{ 114 | "fooUpdates": &graphql.Field{ 115 | Description: "Updates generated by setFoo mutation", 116 | Type: graphql.NewNonNull(graphql.Int), 117 | Resolve: func(p graphql.ResolveParams) (interface{}, error) { 118 | // values sent on channel, that were returned from `Subscribe`, will be available here as 119 | // `p.Source` 120 | return p.Source, nil 121 | }, 122 | Subscribe: func(p graphql.ResolveParams) (interface{}, error) { 123 | // per graphql-go contract, channel returned from `Subscribe` function must have 124 | // 
interface{} values 125 | ch := make(chan interface{}, 1) 126 | id := atomic.AddUint64(&subscriberID, 1) 127 | 128 | subscriberadd <- &subscriber{ 129 | id: id, 130 | subscription: ch, 131 | ctx: p.Context, 132 | } 133 | 134 | go func() { 135 | <-p.Context.Done() 136 | 137 | subscriberrem <- id 138 | }() 139 | 140 | return ch, nil 141 | }, 142 | }, 143 | }, 144 | }), 145 | }) 146 | if err != nil { 147 | panic(err) 148 | } 149 | 150 | srv, err := wsgraphql.NewServer( 151 | schema, 152 | wsgraphql.WithKeepalive(time.Second*30), 153 | wsgraphql.WithConnectTimeout(time.Second*30), 154 | wsgraphql.WithUpgrader(gorillaws.Wrap(&websocket.Upgrader{ 155 | ReadBufferSize: 1024, 156 | WriteBufferSize: 1024, 157 | Subprotocols: []string{ 158 | wsgraphql.WebsocketSubprotocolGraphqlWS.String(), 159 | wsgraphql.WebsocketSubprotocolGraphqlTransportWS.String(), 160 | }, 161 | })), 162 | ) 163 | if err != nil { 164 | panic(err) 165 | } 166 | 167 | http.Handle("/query", srv) 168 | http.HandleFunc("/", func(writer http.ResponseWriter, request *http.Request) { 169 | http.ServeContent(writer, request, "playground.html", time.Time{}, bytes.NewReader(playgroundFile)) 170 | }) 171 | 172 | err = http.ListenAndServe(addr, nil) 173 | if err != nil { 174 | panic(err) 175 | } 176 | } 177 | -------------------------------------------------------------------------------- /v1/examples/simpleserver/playground.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | GraphQL Playground 8 | 9 | 10 | 11 | 12 | 13 | 14 |
15 | 42 | 43 |
Loading 44 | GraphQL Playground 45 |
46 |
47 | 58 | 59 | 60 | 61 | -------------------------------------------------------------------------------- /v1/interceptors.go: -------------------------------------------------------------------------------- 1 | package wsgraphql 2 | 3 | import ( 4 | "context" 5 | "net/http" 6 | 7 | "github.com/eientei/wsgraphql/v1/apollows" 8 | "github.com/graphql-go/graphql" 9 | ) 10 | 11 | // Interceptors allow to customize request processing 12 | // Sequence: 13 | // HTTPRequest -> Init -> [ Operation -> OperationParse -> OperationExecute ]* 14 | type Interceptors struct { 15 | HTTPRequest InterceptorHTTPRequest 16 | Init InterceptorInit 17 | Operation InterceptorOperation 18 | OperationParse InterceptorOperationParse 19 | OperationExecute InterceptorOperationExecute 20 | } 21 | 22 | type ( 23 | // HandlerHTTPRequest handler 24 | HandlerHTTPRequest func(ctx context.Context, w http.ResponseWriter, r *http.Request) error 25 | // HandlerInit handler 26 | HandlerInit func(ctx context.Context, init apollows.PayloadInit) error 27 | // HandlerOperation handler 28 | HandlerOperation func(ctx context.Context, payload *apollows.PayloadOperation) error 29 | // HandlerOperationParse handler 30 | HandlerOperationParse func(ctx context.Context, payload *apollows.PayloadOperation) error 31 | // HandlerOperationExecute handler 32 | HandlerOperationExecute func(ctx context.Context, payload *apollows.PayloadOperation) (chan *graphql.Result, error) 33 | ) 34 | 35 | type ( 36 | // InterceptorHTTPRequest interceptor 37 | InterceptorHTTPRequest func( 38 | ctx context.Context, 39 | w http.ResponseWriter, 40 | r *http.Request, 41 | handler HandlerHTTPRequest, 42 | ) error 43 | // InterceptorInit interceptor 44 | InterceptorInit func( 45 | ctx context.Context, 46 | init apollows.PayloadInit, 47 | handler HandlerInit, 48 | ) error 49 | // InterceptorOperation interceptor 50 | InterceptorOperation func( 51 | ctx context.Context, 52 | payload *apollows.PayloadOperation, 53 | handler HandlerOperation, 54 | 
) error 55 | // InterceptorOperationParse interceptor 56 | InterceptorOperationParse func( 57 | ctx context.Context, 58 | payload *apollows.PayloadOperation, 59 | handler HandlerOperationParse, 60 | ) error 61 | // InterceptorOperationExecute interceptor 62 | InterceptorOperationExecute func( 63 | ctx context.Context, 64 | payload *apollows.PayloadOperation, 65 | handler HandlerOperationExecute, 66 | ) (chan *graphql.Result, error) 67 | ) 68 | 69 | func interceptorHTTPRequestChain( 70 | interceptors []InterceptorHTTPRequest, 71 | idx int, 72 | handler HandlerHTTPRequest, 73 | ) HandlerHTTPRequest { 74 | for idx < len(interceptors) && interceptors[idx] == nil { 75 | idx++ 76 | } 77 | 78 | if idx == len(interceptors) { 79 | return handler 80 | } 81 | 82 | return func(ctx context.Context, w http.ResponseWriter, r *http.Request) error { 83 | return interceptors[idx](ctx, w, r, interceptorHTTPRequestChain(interceptors, idx+1, handler)) 84 | } 85 | } 86 | 87 | // InterceptorHTTPRequestChain returns interceptor composed of the provided list 88 | func InterceptorHTTPRequestChain(interceptors ...InterceptorHTTPRequest) InterceptorHTTPRequest { 89 | return func(ctx context.Context, w http.ResponseWriter, r *http.Request, handler HandlerHTTPRequest) error { 90 | return interceptorHTTPRequestChain(interceptors, 0, handler)(ctx, w, r) 91 | } 92 | } 93 | 94 | func interceptorInitChain( 95 | interceptors []InterceptorInit, 96 | idx int, 97 | handler HandlerInit, 98 | ) HandlerInit { 99 | for idx < len(interceptors) && interceptors[idx] == nil { 100 | idx++ 101 | } 102 | 103 | if idx == len(interceptors) { 104 | return handler 105 | } 106 | 107 | return func(ctx context.Context, init apollows.PayloadInit) error { 108 | return interceptors[idx](ctx, init, interceptorInitChain(interceptors, idx+1, handler)) 109 | } 110 | } 111 | 112 | // InterceptorInitChain returns interceptor composed of the provided list 113 | func InterceptorInitChain(interceptors ...InterceptorInit) 
InterceptorInit { 114 | return func(ctx context.Context, init apollows.PayloadInit, handler HandlerInit) error { 115 | return interceptorInitChain(interceptors, 0, handler)(ctx, init) 116 | } 117 | } 118 | 119 | func interceptorOperationChain( 120 | interceptors []InterceptorOperation, 121 | idx int, 122 | handler HandlerOperation, 123 | ) HandlerOperation { 124 | for idx < len(interceptors) && interceptors[idx] == nil { 125 | idx++ 126 | } 127 | 128 | if idx == len(interceptors) { 129 | return handler 130 | } 131 | 132 | return func(ctx context.Context, payload *apollows.PayloadOperation) error { 133 | return interceptors[idx](ctx, payload, interceptorOperationChain(interceptors, idx+1, handler)) 134 | } 135 | } 136 | 137 | // InterceptorOperationChain returns interceptor composed of the provided list 138 | func InterceptorOperationChain(interceptors ...InterceptorOperation) InterceptorOperation { 139 | return func(ctx context.Context, payload *apollows.PayloadOperation, handler HandlerOperation) error { 140 | return interceptorOperationChain(interceptors, 0, handler)(ctx, payload) 141 | } 142 | } 143 | 144 | func interceptorOperationParseChain( 145 | interceptors []InterceptorOperationParse, 146 | idx int, 147 | handler HandlerOperationParse, 148 | ) HandlerOperationParse { 149 | for idx < len(interceptors) && interceptors[idx] == nil { 150 | idx++ 151 | } 152 | 153 | if idx == len(interceptors) { 154 | return handler 155 | } 156 | 157 | return func(ctx context.Context, payload *apollows.PayloadOperation) error { 158 | return interceptors[idx](ctx, payload, interceptorOperationParseChain(interceptors, idx+1, handler)) 159 | } 160 | } 161 | 162 | // InterceptorOperationParseChain returns interceptor composed of the provided list 163 | func InterceptorOperationParseChain(interceptors ...InterceptorOperationParse) InterceptorOperationParse { 164 | return func(ctx context.Context, payload *apollows.PayloadOperation, handler HandlerOperationParse) error { 165 | return 
interceptorOperationParseChain(interceptors, 0, handler)(ctx, payload) 166 | } 167 | } 168 | 169 | func interceptorOperationExecuteChain( 170 | interceptors []InterceptorOperationExecute, 171 | idx int, 172 | handler HandlerOperationExecute, 173 | ) HandlerOperationExecute { 174 | for idx < len(interceptors) && interceptors[idx] == nil { 175 | idx++ 176 | } 177 | 178 | if idx == len(interceptors) { 179 | return handler 180 | } 181 | 182 | return func(ctx context.Context, payload *apollows.PayloadOperation) (chan *graphql.Result, error) { 183 | return interceptors[idx](ctx, payload, interceptorOperationExecuteChain(interceptors, idx+1, handler)) 184 | } 185 | } 186 | 187 | // InterceptorOperationExecuteChain returns interceptor composed of the provided list 188 | func InterceptorOperationExecuteChain(interceptors ...InterceptorOperationExecute) InterceptorOperationExecute { 189 | return func( 190 | ctx context.Context, 191 | payload *apollows.PayloadOperation, 192 | handler HandlerOperationExecute, 193 | ) (chan *graphql.Result, error) { 194 | return interceptorOperationExecuteChain(interceptors, 0, handler)(ctx, payload) 195 | } 196 | } 197 | 198 | func initInterceptors(c *serverConfig) { 199 | if c.interceptors.HTTPRequest == nil { 200 | c.interceptors.HTTPRequest = func( 201 | ctx context.Context, 202 | w http.ResponseWriter, 203 | r *http.Request, 204 | handler HandlerHTTPRequest, 205 | ) error { 206 | err := handler(ctx, w, r) 207 | if err != nil { 208 | WriteError(ctx, w, err) 209 | } 210 | 211 | return err 212 | } 213 | } 214 | 215 | if c.interceptors.Init == nil { 216 | c.interceptors.Init = func( 217 | ctx context.Context, 218 | init apollows.PayloadInit, 219 | handler HandlerInit, 220 | ) error { 221 | return handler(ctx, init) 222 | } 223 | } 224 | 225 | if c.interceptors.Operation == nil { 226 | c.interceptors.Operation = func( 227 | ctx context.Context, 228 | payload *apollows.PayloadOperation, 229 | handler HandlerOperation, 230 | ) error { 231 | 
return handler(ctx, payload) 232 | } 233 | } 234 | 235 | if c.interceptors.OperationParse == nil { 236 | c.interceptors.OperationParse = func( 237 | ctx context.Context, 238 | payload *apollows.PayloadOperation, 239 | handler HandlerOperationParse, 240 | ) error { 241 | return handler(ctx, payload) 242 | } 243 | } 244 | 245 | if c.interceptors.OperationExecute == nil { 246 | c.interceptors.OperationExecute = func( 247 | ctx context.Context, 248 | payload *apollows.PayloadOperation, 249 | handler HandlerOperationExecute, 250 | ) (chan *graphql.Result, error) { 251 | return handler(ctx, payload) 252 | } 253 | } 254 | } 255 | -------------------------------------------------------------------------------- /v1/mutable/api.go: -------------------------------------------------------------------------------- 1 | // Package mutable provides v1.mutable context, that can store multiple values and be updated after creation 2 | package mutable 3 | 4 | import ( 5 | "context" 6 | "sync" 7 | ) 8 | 9 | // Context interface, provides additional Set method to change values of the context after creation 10 | type Context interface { 11 | context.Context 12 | Set(key, value interface{}) 13 | Cancel() 14 | } 15 | 16 | type mutableContext struct { 17 | context.Context 18 | values map[interface{}]interface{} 19 | cancel context.CancelFunc 20 | mutex sync.RWMutex 21 | } 22 | 23 | func (mctx *mutableContext) Set(key, value interface{}) { 24 | mctx.mutex.Lock() 25 | 26 | mctx.values[key] = value 27 | 28 | mctx.mutex.Unlock() 29 | } 30 | 31 | func (mctx *mutableContext) Value(key interface{}) (res interface{}) { 32 | var ok bool 33 | 34 | mctx.mutex.RLock() 35 | 36 | res, ok = mctx.values[key] 37 | 38 | mctx.mutex.RUnlock() 39 | 40 | if !ok { 41 | res = mctx.Context.Value(key) 42 | } 43 | 44 | return 45 | } 46 | 47 | func (mctx *mutableContext) Cancel() { 48 | mctx.cancel() 49 | } 50 | 51 | // NewMutableContext returns new Context instance 52 | func NewMutableContext(parent context.Context) 
Context { 53 | ctx, cancel := context.WithCancel(parent) 54 | 55 | return &mutableContext{ 56 | Context: ctx, 57 | values: make(map[interface{}]interface{}), 58 | cancel: cancel, 59 | } 60 | } 61 | -------------------------------------------------------------------------------- /v1/mutable/mutcontext_test.go: -------------------------------------------------------------------------------- 1 | package mutable 2 | 3 | import ( 4 | "context" 5 | "testing" 6 | 7 | "github.com/stretchr/testify/assert" 8 | ) 9 | 10 | type ( 11 | testFooKeyT struct{} 12 | testQuxKeyT struct{} 13 | testDagKeyT struct{} 14 | ) 15 | 16 | var ( 17 | testFooKey = testFooKeyT{} 18 | testQuxKey = testQuxKeyT{} 19 | testDagKey = testDagKeyT{} 20 | ) 21 | 22 | func TestNew(t *testing.T) { 23 | mctx := NewMutableContext(context.Background()) 24 | 25 | assert.NotNil(t, mctx) 26 | assert.Nil(t, mctx.Err()) 27 | 28 | assert.Equal(t, nil, mctx.Value(testFooKey)) 29 | assert.Equal(t, nil, mctx.Value(testQuxKey)) 30 | assert.Equal(t, nil, mctx.Value(testDagKey)) 31 | 32 | mctx.Set(testFooKey, "bar") 33 | mctx.Set(testQuxKey, "baz") 34 | 35 | assert.Equal(t, "bar", mctx.Value(testFooKey)) 36 | assert.Equal(t, "baz", mctx.Value(testQuxKey)) 37 | assert.Equal(t, nil, mctx.Value(testDagKey)) 38 | } 39 | 40 | func TestParent(t *testing.T) { 41 | mctx := NewMutableContext(context.WithValue(context.Background(), testFooKey, "123")) 42 | 43 | assert.NotNil(t, mctx) 44 | assert.Nil(t, mctx.Err()) 45 | 46 | assert.Equal(t, "123", mctx.Value(testFooKey)) 47 | assert.Equal(t, nil, mctx.Value(testQuxKey)) 48 | assert.Equal(t, nil, mctx.Value(testDagKey)) 49 | 50 | mctx.Set(testFooKey, "bar") 51 | mctx.Set(testQuxKey, "baz") 52 | 53 | assert.Equal(t, "bar", mctx.Value(testFooKey)) 54 | assert.Equal(t, "baz", mctx.Value(testQuxKey)) 55 | assert.Equal(t, nil, mctx.Value(testDagKey)) 56 | } 57 | 58 | func TestCancel(t *testing.T) { 59 | ctx, cancel := context.WithCancel(context.Background()) 60 | 61 | mctx := 
NewMutableContext(ctx) 62 | 63 | assert.Nil(t, mctx.Err()) 64 | 65 | cancel() 66 | 67 | assert.Error(t, mctx.Err()) 68 | } 69 | -------------------------------------------------------------------------------- /v1/server.go: -------------------------------------------------------------------------------- 1 | package wsgraphql 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | "net/http" 7 | "time" 8 | 9 | "github.com/graphql-go/graphql/gqlerrors" 10 | 11 | "github.com/eientei/wsgraphql/v1/apollows" 12 | "github.com/eientei/wsgraphql/v1/mutable" 13 | "github.com/graphql-go/graphql" 14 | ) 15 | 16 | var ( 17 | errHTTPQueryRejected = errors.New("HTTP query rejected") 18 | errReflectExtensions = errors.New("could not reflect schema extensions") 19 | ) 20 | 21 | type serverConfig struct { 22 | upgrader Upgrader 23 | interceptors Interceptors 24 | resultProcessor ResultProcessor 25 | rootObject map[string]interface{} 26 | subscriptionProtocols map[apollows.Protocol]struct{} 27 | keepalive time.Duration 28 | connectTimeout time.Duration 29 | rejectHTTPQueries bool 30 | } 31 | 32 | type serverImpl struct { 33 | extensions []graphql.Extension 34 | schema graphql.Schema 35 | serverConfig 36 | } 37 | 38 | func (server *serverImpl) handleHTTPRequest(ctx context.Context, w http.ResponseWriter, r *http.Request) (err error) { 39 | if r.Header.Get("connection") != "" && r.Header.Get("upgrade") != "" && server.upgrader != nil { 40 | err = server.serveWebsocketRequest(ctx, w, r) 41 | } else { 42 | err = server.servePlainRequest(ctx) 43 | } 44 | 45 | return 46 | } 47 | 48 | func (server *serverImpl) ServeHTTP(w http.ResponseWriter, r *http.Request) { 49 | reqctx := mutable.NewMutableContext(r.Context()) 50 | 51 | reqctx.Set(ContextKeyRequestContext, reqctx) 52 | reqctx.Set(ContextKeyHTTPRequest, r) 53 | reqctx.Set(ContextKeyHTTPResponseWriter, w) 54 | 55 | _ = server.interceptors.HTTPRequest(reqctx, w, r, server.handleHTTPRequest) 56 | 57 | reqctx.Cancel() 58 | } 59 | 60 | func (server 
*serverImpl) processResults( 61 | ctx context.Context, 62 | payload *apollows.PayloadOperation, 63 | cres chan *graphql.Result, 64 | write func(ctx context.Context, result *graphql.Result) error, 65 | ) (err error) { 66 | OperationContext(ctx).Set(ContextKeyOperationExecuted, true) 67 | 68 | for { 69 | select { 70 | case <-ctx.Done(): 71 | if !ContextOperationStopped(ctx) { 72 | err = ctx.Err() 73 | } 74 | 75 | return 76 | case result, ok := <-cres: 77 | if !ok { 78 | return 79 | } 80 | 81 | err = server.processResult(ctx, payload, result, write) 82 | if err != nil { 83 | return err 84 | } 85 | } 86 | } 87 | } 88 | 89 | func (server *serverImpl) processResult( 90 | ctx context.Context, 91 | payload *apollows.PayloadOperation, 92 | result *graphql.Result, 93 | write func(ctx context.Context, result *graphql.Result) error, 94 | ) error { 95 | result = server.resultProcessor(ctx, payload, result) 96 | 97 | var tgterrs []gqlerrors.FormattedError 98 | 99 | err, ok := result.Data.(error) 100 | if ok { 101 | tgterrs = append(tgterrs, FormatError(err)) 102 | } 103 | 104 | for _, src := range result.Errors { 105 | tgterrs = append(tgterrs, FormatError(src)) 106 | } 107 | 108 | result.Errors = tgterrs 109 | 110 | err = write(ctx, result) 111 | if err != nil { 112 | return err 113 | } 114 | 115 | if result.HasErrors() { 116 | return ResultError{ 117 | Result: result, 118 | } 119 | } 120 | 121 | return nil 122 | } 123 | -------------------------------------------------------------------------------- /v1/server_plain.go: -------------------------------------------------------------------------------- 1 | package wsgraphql 2 | 3 | import ( 4 | "context" 5 | "encoding/json" 6 | "net/http" 7 | "strconv" 8 | 9 | "github.com/eientei/wsgraphql/v1/apollows" 10 | "github.com/eientei/wsgraphql/v1/mutable" 11 | "github.com/graphql-go/graphql" 12 | ) 13 | 14 | // ResultError passes error result as error 15 | type ResultError struct { 16 | *graphql.Result 17 | } 18 | 19 | // Error 
implementation 20 | func (r ResultError) Error() string { 21 | bs, _ := json.Marshal(r.Result) 22 | 23 | return string(bs) 24 | } 25 | 26 | func (server *serverImpl) operationExecute( 27 | ctx context.Context, 28 | payload *apollows.PayloadOperation, 29 | ) (cres chan *graphql.Result, err error) { 30 | subscription := ContextSubscription(ctx) 31 | astdoc := ContextAST(ctx) 32 | 33 | if subscription { 34 | cres = graphql.ExecuteSubscription(graphql.ExecuteParams{ 35 | Schema: server.schema, 36 | Root: server.rootObject, 37 | AST: astdoc, 38 | OperationName: payload.OperationName, 39 | Args: payload.Variables, 40 | Context: ctx, 41 | }) 42 | } else { 43 | cres = make(chan *graphql.Result, 1) 44 | cres <- graphql.Execute(graphql.ExecuteParams{ 45 | Schema: server.schema, 46 | Root: server.rootObject, 47 | AST: astdoc, 48 | OperationName: payload.OperationName, 49 | Args: payload.Variables, 50 | Context: ctx, 51 | }) 52 | close(cres) 53 | } 54 | 55 | return 56 | } 57 | 58 | func (server *serverImpl) operationParse( 59 | ctx context.Context, 60 | payload *apollows.PayloadOperation, 61 | ) (err error) { 62 | err = server.parseAST(ctx, payload) 63 | if err != nil { 64 | return err 65 | } 66 | 67 | return 68 | } 69 | 70 | func (server *serverImpl) plainRequestOperation( 71 | ctx context.Context, 72 | payload *apollows.PayloadOperation, 73 | ) (err error) { 74 | err = server.interceptors.OperationParse(ctx, payload, server.operationParse) 75 | if err != nil { 76 | return err 77 | } 78 | 79 | w := ContextHTTPResponseWriter(ctx) 80 | 81 | w.Header().Set("content-type", "application/json") 82 | 83 | cres, err := server.interceptors.OperationExecute( 84 | ctx, 85 | payload, 86 | server.operationExecute, 87 | ) 88 | if err != nil { 89 | return err 90 | } 91 | 92 | var flusher http.Flusher 93 | 94 | if ContextSubscription(ctx) { 95 | flusher, _ = w.(http.Flusher) 96 | w.Header().Set("x-content-type-options", "nosniff") 97 | w.Header().Set("connection", "keep-alive") 98 | } 99 | 
100 | return server.processResults(ctx, payload, cres, func(ctx context.Context, result *graphql.Result) error { 101 | return server.writePlainResult(ctx, result, w, flusher) 102 | }) 103 | } 104 | 105 | func (server *serverImpl) servePlainRequest(reqctx context.Context) (err error) { 106 | if server.rejectHTTPQueries { 107 | return errHTTPQueryRejected 108 | } 109 | 110 | err = server.interceptors.Init(reqctx, nil, func(nctx context.Context, init apollows.PayloadInit) error { 111 | reqctx = nctx 112 | 113 | return nil 114 | }) 115 | if err != nil { 116 | return err 117 | } 118 | 119 | var payload apollows.PayloadOperation 120 | 121 | err = json.NewDecoder(ContextHTTPRequest(reqctx).Body).Decode(&payload) 122 | if err != nil { 123 | return 124 | } 125 | 126 | opctx := mutable.NewMutableContext(reqctx) 127 | opctx.Set(ContextKeyOperationContext, opctx) 128 | 129 | defer opctx.Cancel() 130 | 131 | return server.interceptors.Operation(opctx, &payload, server.plainRequestOperation) 132 | } 133 | 134 | func (server *serverImpl) writePlainResult( 135 | reqctx context.Context, 136 | result *graphql.Result, 137 | w http.ResponseWriter, 138 | flusher http.Flusher, 139 | ) (err error) { 140 | if result == nil { 141 | return nil 142 | } 143 | 144 | bs, err := json.Marshal(result) 145 | if err != nil { 146 | return 147 | } 148 | 149 | bs = append(bs, '\n') 150 | 151 | if !ContextHTTPResponseStarted(reqctx) && flusher == nil { 152 | w.Header().Set("content-length", strconv.Itoa(len(bs))) 153 | } 154 | 155 | _, err = w.Write(bs) 156 | if err != nil { 157 | return 158 | } 159 | 160 | if flusher != nil { 161 | flusher.Flush() 162 | } 163 | 164 | RequestContext(reqctx).Set(ContextKeyHTTPResponseStarted, true) 165 | 166 | return nil 167 | } 168 | -------------------------------------------------------------------------------- /v1/server_plain_test.go: -------------------------------------------------------------------------------- 1 | package wsgraphql 2 | 3 | import ( 4 | "bufio" 5 
| "bytes" 6 | "encoding/json" 7 | "testing" 8 | 9 | "github.com/eientei/wsgraphql/v1/apollows" 10 | "github.com/stretchr/testify/assert" 11 | ) 12 | 13 | func TestNewServerPlain(t *testing.T) { 14 | srv := testNewServer(t, apollows.WebsocketSubprotocolGraphqlWS) 15 | 16 | defer srv.Close() 17 | 18 | client := srv.Client() 19 | 20 | query := srv.URL + "/query" 21 | 22 | var pd apollows.PayloadDataResponse 23 | 24 | bs, err := json.Marshal(apollows.PayloadOperation{ 25 | Query: `query { getFoo }`, 26 | }) 27 | 28 | assert.NoError(t, err) 29 | 30 | resp, err := client.Post(query, "application/json", bytes.NewReader(bs)) 31 | 32 | assert.NoError(t, err) 33 | 34 | assert.NoError(t, json.NewDecoder(resp.Body).Decode(&pd)) 35 | 36 | assert.NoError(t, resp.Body.Close()) 37 | 38 | assert.NoError(t, err) 39 | assert.Len(t, pd.Errors, 0) 40 | assert.EqualValues(t, 123, pd.Data["getFoo"]) 41 | 42 | bs, err = json.Marshal(apollows.PayloadOperation{ 43 | Query: `mutation { setFoo(value: 3) }`, 44 | }) 45 | 46 | assert.NoError(t, err) 47 | 48 | resp, err = client.Post(query, "application/json", bytes.NewReader(bs)) 49 | 50 | assert.NoError(t, err) 51 | 52 | assert.NoError(t, json.NewDecoder(resp.Body).Decode(&pd)) 53 | 54 | assert.NoError(t, resp.Body.Close()) 55 | 56 | assert.NoError(t, err) 57 | assert.Len(t, pd.Errors, 0) 58 | assert.EqualValues(t, true, pd.Data["setFoo"]) 59 | 60 | bs, err = json.Marshal(apollows.PayloadOperation{ 61 | Query: `mutation { setFoo }`, 62 | }) 63 | 64 | assert.NoError(t, err) 65 | 66 | resp, err = client.Post(query, "application/json", bytes.NewReader(bs)) 67 | 68 | assert.NoError(t, err) 69 | 70 | assert.NoError(t, json.NewDecoder(resp.Body).Decode(&pd)) 71 | 72 | assert.NoError(t, resp.Body.Close()) 73 | 74 | assert.NoError(t, err) 75 | assert.Len(t, pd.Errors, 0) 76 | assert.EqualValues(t, false, pd.Data["setFoo"]) 77 | 78 | bs, err = json.Marshal(apollows.PayloadOperation{ 79 | Query: `mutation { bar }`, 80 | }) 81 | 82 | assert.NoError(t, 
err) 83 | 84 | resp, err = client.Post(query, "application/json", bytes.NewReader(bs)) 85 | 86 | assert.NoError(t, err) 87 | 88 | assert.NoError(t, json.NewDecoder(resp.Body).Decode(&pd)) 89 | 90 | assert.NoError(t, resp.Body.Close()) 91 | 92 | assert.NoError(t, err) 93 | assert.Greater(t, len(pd.Errors), 0) 94 | 95 | bs, err = json.Marshal(apollows.PayloadOperation{ 96 | Query: `subscription { fooUpdates }`, 97 | }) 98 | 99 | assert.NoError(t, err) 100 | 101 | resp, err = client.Post(query, "application/json", bytes.NewReader(bs)) 102 | 103 | assert.NoError(t, err) 104 | 105 | scanner := bufio.NewScanner(resp.Body) 106 | 107 | idx := 1 108 | 109 | for scanner.Scan() { 110 | if len(scanner.Bytes()) > 0 { 111 | pd = apollows.PayloadDataResponse{} 112 | 113 | assert.NoError(t, json.Unmarshal(scanner.Bytes(), &pd)) 114 | 115 | assert.Len(t, pd.Errors, 0) 116 | assert.EqualValues(t, idx, pd.Data["fooUpdates"]) 117 | idx++ 118 | } 119 | } 120 | 121 | assert.NoError(t, resp.Body.Close()) 122 | } 123 | -------------------------------------------------------------------------------- /v1/server_test.go: -------------------------------------------------------------------------------- 1 | package wsgraphql 2 | 3 | import ( 4 | "errors" 5 | "net/http" 6 | "net/http/httptest" 7 | "testing" 8 | 9 | "github.com/eientei/wsgraphql/v1/apollows" 10 | "github.com/gorilla/websocket" 11 | "github.com/graphql-go/graphql" 12 | "github.com/stretchr/testify/assert" 13 | ) 14 | 15 | type testWrapper struct { 16 | *websocket.Upgrader 17 | } 18 | 19 | type testConn struct { 20 | *websocket.Conn 21 | } 22 | 23 | func (conn testConn) Close(code int, message string) error { 24 | origerr := conn.Conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(code, message)) 25 | 26 | err := conn.Conn.Close() 27 | if err == nil { 28 | err = origerr 29 | } 30 | 31 | return err 32 | } 33 | 34 | // Upgrade implementation 35 | func (g testWrapper) Upgrade(w http.ResponseWriter, r *http.Request, 
responseHeader http.Header) (Conn, error) { 36 | c, err := g.Upgrader.Upgrade(w, r, responseHeader) 37 | if err != nil { 38 | return nil, err 39 | } 40 | 41 | return testConn{ 42 | Conn: c, 43 | }, nil 44 | } 45 | 46 | type extendedError struct { 47 | error 48 | extensions map[string]interface{} 49 | } 50 | 51 | func (ext *extendedError) Extensions() map[string]interface{} { 52 | return ext.extensions 53 | } 54 | 55 | func testNewSchema(t *testing.T) graphql.Schema { 56 | schema, err := graphql.NewSchema(graphql.SchemaConfig{ 57 | Query: graphql.NewObject(graphql.ObjectConfig{ 58 | Name: "QueryRoot", 59 | Interfaces: nil, 60 | Fields: graphql.Fields{ 61 | "getFoo": &graphql.Field{ 62 | Type: graphql.Int, 63 | Resolve: func(p graphql.ResolveParams) (interface{}, error) { 64 | return 123, nil 65 | }, 66 | }, 67 | "getError": &graphql.Field{ 68 | Type: graphql.Int, 69 | Resolve: func(p graphql.ResolveParams) (interface{}, error) { 70 | return nil, &extendedError{ 71 | error: errors.New("someerr"), 72 | extensions: map[string]interface{}{"foo": "bar"}, 73 | } 74 | }, 75 | }, 76 | }, 77 | }), 78 | Mutation: graphql.NewObject(graphql.ObjectConfig{ 79 | Name: "MutationRoot", 80 | Interfaces: nil, 81 | Fields: graphql.Fields{ 82 | "setFoo": &graphql.Field{ 83 | Args: graphql.FieldConfigArgument{ 84 | "value": &graphql.ArgumentConfig{ 85 | Type: graphql.Int, 86 | }, 87 | }, 88 | Type: graphql.Boolean, 89 | Resolve: func(p graphql.ResolveParams) (interface{}, error) { 90 | _, ok := p.Args["value"].(int) 91 | 92 | return ok, nil 93 | }, 94 | }, 95 | }, 96 | }), 97 | Subscription: graphql.NewObject(graphql.ObjectConfig{ 98 | Name: "SubscriptionRoot", 99 | Interfaces: nil, 100 | Fields: graphql.Fields{ 101 | "fooUpdates": &graphql.Field{ 102 | Type: graphql.Int, 103 | Resolve: func(p graphql.ResolveParams) (interface{}, error) { 104 | return p.Source, nil 105 | }, 106 | Subscribe: func(p graphql.ResolveParams) (interface{}, error) { 107 | ch := make(chan interface{}, 3) 108 | 
109 | ch <- 1 110 | ch <- 2 111 | ch <- 3 112 | 113 | close(ch) 114 | 115 | return ch, nil 116 | }, 117 | }, 118 | "forever": &graphql.Field{ 119 | Type: graphql.Int, 120 | Resolve: func(p graphql.ResolveParams) (interface{}, error) { 121 | return p.Source, nil 122 | }, 123 | Subscribe: func(p graphql.ResolveParams) (interface{}, error) { 124 | ch := make(chan interface{}) 125 | 126 | return ch, nil 127 | }, 128 | }, 129 | "errors": &graphql.Field{ 130 | Type: graphql.Int, 131 | Resolve: func(p graphql.ResolveParams) (interface{}, error) { 132 | return p.Source, nil 133 | }, 134 | Subscribe: func(p graphql.ResolveParams) (interface{}, error) { 135 | return nil, &extendedError{ 136 | error: errors.New("someerr"), 137 | extensions: map[string]interface{}{"foo": "bar"}, 138 | } 139 | }, 140 | }, 141 | }, 142 | }), 143 | }) 144 | 145 | assert.NoError(t, err) 146 | assert.NotNil(t, schema) 147 | 148 | return schema 149 | } 150 | 151 | func testNewServerProtocols(t *testing.T, protocols []apollows.Protocol, opts ...ServerOption) *httptest.Server { 152 | var strprotocols []string 153 | 154 | for _, p := range protocols { 155 | strprotocols = append(strprotocols, p.String()) 156 | opts = append(opts, WithProtocol(p)) 157 | } 158 | 159 | opts = append(opts, WithUpgrader(testWrapper{ 160 | Upgrader: &websocket.Upgrader{ 161 | ReadBufferSize: 1024, 162 | WriteBufferSize: 1024, 163 | Subprotocols: strprotocols, 164 | CheckOrigin: func(r *http.Request) bool { 165 | return true 166 | }, 167 | }, 168 | })) 169 | 170 | server, err := NewServer(testNewSchema(t), opts...) 171 | 172 | assert.NoError(t, err) 173 | assert.NotNil(t, server) 174 | 175 | return httptest.NewServer(server) 176 | } 177 | 178 | func testNewServer(t *testing.T, protocol apollows.Protocol, opts ...ServerOption) *httptest.Server { 179 | return testNewServerProtocols(t, []apollows.Protocol{protocol}, opts...) 
180 | } 181 | -------------------------------------------------------------------------------- /v1/server_websocket.go: -------------------------------------------------------------------------------- 1 | package wsgraphql 2 | 3 | import ( 4 | "context" 5 | "encoding/json" 6 | "net/http" 7 | "strings" 8 | "sync" 9 | "time" 10 | 11 | "github.com/eientei/wsgraphql/v1/apollows" 12 | "github.com/eientei/wsgraphql/v1/mutable" 13 | "github.com/graphql-go/graphql" 14 | "github.com/graphql-go/graphql/gqlerrors" 15 | ) 16 | 17 | type websocketRequest struct { 18 | ctx context.Context 19 | outgoing chan outgoingMessage 20 | operations map[string]mutable.Context 21 | ws Conn 22 | server *serverImpl 23 | protocol apollows.Protocol 24 | wg sync.WaitGroup 25 | m sync.RWMutex 26 | init bool 27 | } 28 | 29 | type outgoingMessage struct { 30 | *apollows.Message 31 | apollows.Error 32 | } 33 | 34 | func (server *serverImpl) serveWebsocketRequest( 35 | ctx context.Context, 36 | w http.ResponseWriter, 37 | r *http.Request, 38 | ) (err error) { 39 | ws, err := server.upgrader.Upgrade(w, r, w.Header()) 40 | if err != nil { 41 | return 42 | } 43 | 44 | reqctx := RequestContext(ctx) 45 | 46 | reqctx.Set(ContextKeyWebsocketConnection, ws) 47 | reqctx.Set(ContextKeyHTTPResponseStarted, true) 48 | 49 | protocol := apollows.Protocol(ws.Subprotocol()) 50 | 51 | _, known := server.subscriptionProtocols[protocol] 52 | if !known { 53 | if ws != nil { 54 | _ = ws.Close(int(apollows.EventCloseNormal), apollows.ErrUnknownProtocol.Error()) 55 | } 56 | 57 | return apollows.ErrUnknownProtocol 58 | } 59 | 60 | req := &websocketRequest{ 61 | protocol: protocol, 62 | ctx: ctx, 63 | outgoing: make(chan outgoingMessage, 1), 64 | operations: make(map[string]mutable.Context), 65 | ws: ws, 66 | server: server, 67 | } 68 | 69 | var tickerType apollows.Operation 70 | 71 | switch req.protocol { 72 | case apollows.WebsocketSubprotocolGraphqlWS: 73 | tickerType = apollows.OperationKeepAlive 74 | case 
apollows.WebsocketSubprotocolGraphqlTransportWS: 75 | tickerType = apollows.OperationPong 76 | } 77 | 78 | var tickerch <-chan time.Time 79 | 80 | if server.keepalive > 0 { 81 | ticker := time.NewTicker(server.keepalive) 82 | 83 | defer func() { 84 | ticker.Stop() 85 | }() 86 | 87 | tickerch = ticker.C 88 | } 89 | 90 | go req.readWebsocket() 91 | 92 | // req.outgoing is read to completion to avoid any potential blocking 93 | // readWebsocket exit is ensured by closing a websocket on any error, this causes req.ws.ReadJSON() to return 94 | for { 95 | select { 96 | case msg, ok := <-req.outgoing: 97 | if !ok { 98 | return 99 | } 100 | 101 | switch { 102 | case msg.Message != nil: 103 | err = ws.WriteJSON(msg.Message) 104 | case msg.Error != nil: 105 | err = ws.Close(int(msg.Error.EventMessageType()), msg.Error.Error()) 106 | } 107 | case <-tickerch: 108 | err = ws.WriteJSON(&apollows.Message{ 109 | Type: tickerType, 110 | }) 111 | } 112 | 113 | if err != nil { 114 | _ = ws.Close(int(apollows.EventCloseNormal), err.Error()) 115 | } 116 | } 117 | } 118 | 119 | func combineErrors(errs []gqlerrors.FormattedError) gqlerrors.FormattedError { 120 | if len(errs) == 1 { 121 | return errs[0] 122 | } 123 | 124 | errmsg := "preparing operation" 125 | 126 | var errmsgs []string 127 | 128 | combinedext := make(map[string]interface{}) 129 | 130 | for _, err := range errs { 131 | fmterr := FormatError(err) 132 | 133 | for k, v := range fmterr.Extensions { 134 | combinedext[k] = v 135 | } 136 | 137 | errmsgs = append(errmsgs, fmterr.Error()) 138 | } 139 | 140 | if len(errmsgs) > 0 { 141 | errmsg += ": " + strings.Join(errmsgs, "; ") 142 | } 143 | 144 | rooterr := gqlerrors.NewFormattedError(errmsg) 145 | 146 | if len(errs) > 0 { 147 | combinedext["errors"] = errs 148 | rooterr.Extensions = combinedext 149 | } 150 | 151 | return rooterr 152 | } 153 | 154 | func (req *websocketRequest) handleError(ctx context.Context, err error) { 155 | awerr, ok := err.(apollows.Error) 156 | if ok { 
157 | if req.protocol == apollows.WebsocketSubprotocolGraphqlWS { 158 | req.writeWebsocketMessage( 159 | ctx, 160 | apollows.OperationConnectionError, 161 | gqlerrors.FormatError(awerr), 162 | ) 163 | } 164 | 165 | req.outgoing <- outgoingMessage{ 166 | Error: awerr, 167 | } 168 | 169 | return 170 | } 171 | 172 | res, ok := err.(ResultError) 173 | 174 | if ok { 175 | switch { 176 | case req.protocol == apollows.WebsocketSubprotocolGraphqlWS: 177 | req.writeWebsocketMessage(ctx, apollows.OperationError, combineErrors(res.Result.Errors)) 178 | default: 179 | req.writeWebsocketMessage(ctx, apollows.OperationError, res.Result.Errors) 180 | } 181 | 182 | return 183 | } 184 | 185 | req.writeWebsocketMessage(ctx, apollows.OperationError, gqlerrors.FormatError(err)) 186 | } 187 | 188 | func (req *websocketRequest) writeWebsocketData(ctx context.Context, data *graphql.Result) { 189 | var t apollows.Operation 190 | 191 | switch req.protocol { 192 | case apollows.WebsocketSubprotocolGraphqlWS: 193 | t = apollows.OperationData 194 | case apollows.WebsocketSubprotocolGraphqlTransportWS: 195 | t = apollows.OperationNext 196 | 197 | if ContextOperationStopped(ctx) { 198 | return 199 | } 200 | } 201 | 202 | req.writeWebsocketMessage(ctx, t, data) 203 | } 204 | 205 | func (req *websocketRequest) writeWebsocketMessage(ctx context.Context, t apollows.Operation, data interface{}) { 206 | if t == apollows.OperationError { 207 | OperationContext(ctx).Set(ContextKeyOperationStopped, true) 208 | } 209 | 210 | select { 211 | case req.outgoing <- outgoingMessage{ 212 | Message: &apollows.Message{ 213 | ID: ContextOperationID(ctx), 214 | Type: t, 215 | Payload: apollows.Data{ 216 | Value: data, 217 | }, 218 | }, 219 | }: 220 | case <-RequestContext(ctx).Done(): 221 | } 222 | } 223 | 224 | func (req *websocketRequest) readWebsocketInit(msg *apollows.Message) (err error) { 225 | init := make(apollows.PayloadInit) 226 | 227 | if len(msg.Payload.RawMessage) > 0 { 228 | err = 
json.Unmarshal(msg.Payload.RawMessage, &init) 229 | if err != nil { 230 | return 231 | } 232 | } 233 | 234 | err = req.server.interceptors.Init(req.ctx, init, func(nctx context.Context, ninit apollows.PayloadInit) error { 235 | req.ctx, init = nctx, ninit 236 | 237 | return nil 238 | }) 239 | if err != nil { 240 | return 241 | } 242 | 243 | req.writeWebsocketMessage(req.ctx, apollows.OperationConnectionAck, nil) 244 | 245 | return 246 | } 247 | 248 | func (req *websocketRequest) readWebsocketStart(msg *apollows.Message) (err error) { 249 | if !req.init && req.protocol == apollows.WebsocketSubprotocolGraphqlTransportWS { 250 | return apollows.EventUnauthorized 251 | } 252 | 253 | req.m.RLock() 254 | _, ok := req.operations[msg.ID] 255 | req.m.RUnlock() 256 | 257 | if ok { 258 | return apollows.NewSubscriberAlreadyExistsError(msg.ID) 259 | } 260 | 261 | opctx := mutable.NewMutableContext(req.ctx) 262 | 263 | opctx.Set(ContextKeyOperationContext, opctx) 264 | opctx.Set(ContextKeyOperationID, msg.ID) 265 | 266 | req.m.Lock() 267 | req.operations[msg.ID] = opctx 268 | req.m.Unlock() 269 | 270 | req.wg.Add(1) 271 | 272 | go func() { 273 | var payload apollows.PayloadOperation 274 | 275 | operr := json.Unmarshal(msg.Payload.RawMessage, &payload) 276 | if operr != nil { 277 | if req.protocol == apollows.WebsocketSubprotocolGraphqlTransportWS { 278 | operr = apollows.WrapError(operr, apollows.EventInvalidMessage) 279 | } 280 | } else { 281 | operr = req.server.interceptors.Operation(opctx, &payload, req.serveWebsocketOperation) 282 | } 283 | 284 | if operr != nil && !ContextOperationExecuted(opctx) { 285 | req.handleError(opctx, operr) 286 | } 287 | 288 | if !ContextOperationStopped(opctx) || 289 | req.protocol == apollows.WebsocketSubprotocolGraphqlWS { 290 | req.writeWebsocketMessage(opctx, apollows.OperationComplete, nil) 291 | } 292 | 293 | opctx.Cancel() 294 | 295 | req.m.Lock() 296 | delete(req.operations, msg.ID) 297 | req.m.Unlock() 298 | 299 | req.wg.Done() 300 | 
}() 301 | 302 | return 303 | } 304 | 305 | func (req *websocketRequest) readWebsocketStop(msg *apollows.Message) (err error) { 306 | if !req.init && req.protocol == apollows.WebsocketSubprotocolGraphqlTransportWS { 307 | return apollows.EventUnauthorized 308 | } 309 | 310 | req.m.RLock() 311 | prev, ok := req.operations[msg.ID] 312 | req.m.RUnlock() 313 | 314 | if ok { 315 | prev.Set(ContextKeyOperationStopped, true) 316 | prev.Cancel() 317 | } 318 | 319 | return 320 | } 321 | 322 | func (req *websocketRequest) readWebsocketTerminate() (err error) { 323 | if req.protocol == apollows.WebsocketSubprotocolGraphqlTransportWS { 324 | return apollows.EventUnauthorized 325 | } 326 | 327 | RequestContext(req.ctx).Set(ContextKeyOperationStopped, true) 328 | 329 | req.outgoing <- outgoingMessage{ 330 | Error: apollows.EventCloseNormal, 331 | } 332 | 333 | return 334 | } 335 | 336 | func (req *websocketRequest) readWebsocketPing(msg *apollows.Message) { 337 | req.writeWebsocketMessage(req.ctx, apollows.OperationPong, msg.Payload.Value) 338 | } 339 | 340 | func (req *websocketRequest) backgroundTimeout(timeout time.Duration, connectSuccessful chan struct{}) { 341 | timer := time.NewTimer(timeout) 342 | 343 | select { 344 | case <-timer.C: 345 | req.handleError(req.ctx, apollows.EventInitializationTimeout) 346 | case <-connectSuccessful: 347 | case <-req.ctx.Done(): 348 | } 349 | 350 | timer.Stop() 351 | } 352 | 353 | func (req *websocketRequest) readWebsocket() { 354 | var err error 355 | 356 | defer func() { 357 | if err != nil { 358 | req.handleError(req.ctx, err) 359 | } 360 | 361 | // cancel request context and consequently all pending operation contexts 362 | RequestContext(req.ctx).Cancel() 363 | 364 | // await for all operations to complete, so nothing will write to req.outgoing from this point 365 | req.wg.Wait() 366 | 367 | close(req.outgoing) 368 | }() 369 | 370 | var connectSuccessful chan struct{} 371 | 372 | if req.server.connectTimeout > 0 { 373 | 
connectSuccessful = make(chan struct{}) 374 | 375 | go req.backgroundTimeout(req.server.connectTimeout, connectSuccessful) 376 | } 377 | 378 | for { 379 | var msg apollows.Message 380 | 381 | err = req.ws.ReadJSON(&msg) 382 | if err != nil { 383 | return 384 | } 385 | 386 | switch msg.Type { 387 | case apollows.OperationConnectionInit: 388 | if req.init { 389 | err = apollows.EventTooManyInitializationRequests 390 | 391 | return 392 | } 393 | 394 | req.init = true 395 | 396 | if connectSuccessful != nil { 397 | connectSuccessful <- struct{}{} 398 | close(connectSuccessful) 399 | } 400 | 401 | err = req.readWebsocketInit(&msg) 402 | case apollows.OperationStart, apollows.OperationSubscribe: 403 | err = req.readWebsocketStart(&msg) 404 | case apollows.OperationStop, apollows.OperationComplete: 405 | err = req.readWebsocketStop(&msg) 406 | case apollows.OperationTerminate: 407 | err = req.readWebsocketTerminate() 408 | case apollows.OperationPing: 409 | req.readWebsocketPing(&msg) 410 | } 411 | 412 | if err != nil { 413 | return 414 | } 415 | } 416 | } 417 | 418 | func (req *websocketRequest) serveWebsocketOperation( 419 | ctx context.Context, 420 | payload *apollows.PayloadOperation, 421 | ) (err error) { 422 | err = req.server.interceptors.OperationParse(ctx, payload, req.server.operationParse) 423 | if err != nil { 424 | return 425 | } 426 | 427 | cres, err := req.server.interceptors.OperationExecute(ctx, payload, req.server.operationExecute) 428 | if err != nil { 429 | return 430 | } 431 | 432 | return req.server.processResults(ctx, payload, cres, func(ctx context.Context, result *graphql.Result) error { 433 | req.writeWebsocketData(ctx, result) 434 | 435 | return nil 436 | }) 437 | } 438 | -------------------------------------------------------------------------------- /v1/server_websocket_test.go: -------------------------------------------------------------------------------- 1 | package wsgraphql 2 | 3 | import ( 4 | "encoding/json" 5 | "errors" 6 | "net/http" 
7 | "net/http/httptest" 8 | "strings" 9 | "testing" 10 | "time" 11 | 12 | "github.com/eientei/wsgraphql/v1/apollows" 13 | "github.com/eientei/wsgraphql/v1/mutable" 14 | "github.com/gorilla/websocket" 15 | "github.com/graphql-go/graphql" 16 | "github.com/stretchr/testify/assert" 17 | ) 18 | 19 | func testNewServerWebsocketGWS(t *testing.T, srv *httptest.Server) { 20 | u := "ws" + strings.TrimPrefix(srv.URL, "http") 21 | 22 | conn, resp, err := websocket.DefaultDialer.Dial(u, http.Header{ 23 | "sec-websocket-protocol": []string{apollows.WebsocketSubprotocolGraphqlWS.String()}, 24 | }) 25 | 26 | assert.NoError(t, err) 27 | 28 | defer func() { 29 | _ = conn.Close() 30 | _ = resp.Body.Close() 31 | }() 32 | 33 | err = conn.WriteJSON(apollows.Message{ 34 | ID: "", 35 | Type: apollows.OperationConnectionInit, 36 | Payload: apollows.Data{}, 37 | }) 38 | 39 | assert.NoError(t, err) 40 | 41 | var msg apollows.Message 42 | 43 | err = conn.ReadJSON(&msg) 44 | 45 | assert.NoError(t, err) 46 | assert.Equal(t, apollows.OperationConnectionAck, msg.Type) 47 | 48 | err = conn.WriteJSON(apollows.Message{ 49 | ID: "1", 50 | Type: apollows.OperationStart, 51 | Payload: apollows.Data{ 52 | Value: apollows.PayloadOperation{ 53 | Query: `query { getFoo }`, 54 | }, 55 | }, 56 | }) 57 | 58 | assert.NoError(t, err) 59 | 60 | err = conn.ReadJSON(&msg) 61 | 62 | assert.NoError(t, err) 63 | assert.Equal(t, "1", msg.ID) 64 | assert.Equal(t, apollows.OperationData, msg.Type) 65 | 66 | pd, err := msg.Payload.ReadPayloadData() 67 | 68 | assert.NoError(t, err) 69 | assert.Len(t, pd.Errors, 0) 70 | assert.EqualValues(t, 123, pd.Data["getFoo"]) 71 | 72 | err = conn.ReadJSON(&msg) 73 | 74 | assert.NoError(t, err) 75 | assert.Equal(t, "1", msg.ID) 76 | assert.Equal(t, apollows.OperationComplete, msg.Type) 77 | 78 | err = conn.WriteJSON(apollows.Message{ 79 | ID: "2", 80 | Type: apollows.OperationStart, 81 | Payload: apollows.Data{ 82 | Value: apollows.PayloadOperation{ 83 | Query: `mutation { 
setFoo(value: 3) }`, 84 | }, 85 | }, 86 | }) 87 | 88 | assert.NoError(t, err) 89 | 90 | err = conn.ReadJSON(&msg) 91 | 92 | assert.NoError(t, err) 93 | assert.Equal(t, "2", msg.ID) 94 | assert.Equal(t, apollows.OperationData, msg.Type) 95 | 96 | pd, err = msg.Payload.ReadPayloadData() 97 | 98 | assert.NoError(t, err) 99 | assert.Len(t, pd.Errors, 0) 100 | assert.EqualValues(t, map[string]interface{}{ 101 | "setFoo": true, 102 | }, pd.Data) 103 | 104 | err = conn.ReadJSON(&msg) 105 | 106 | assert.NoError(t, err) 107 | assert.Equal(t, "2", msg.ID) 108 | assert.Equal(t, apollows.OperationComplete, msg.Type) 109 | 110 | err = conn.WriteJSON(apollows.Message{ 111 | ID: "3", 112 | Type: apollows.OperationStart, 113 | Payload: apollows.Data{ 114 | Value: apollows.PayloadOperation{ 115 | Query: `mutation { setFoo }`, 116 | }, 117 | }, 118 | }) 119 | 120 | assert.NoError(t, err) 121 | 122 | err = conn.ReadJSON(&msg) 123 | 124 | assert.NoError(t, err) 125 | assert.Equal(t, "3", msg.ID) 126 | assert.Equal(t, apollows.OperationData, msg.Type) 127 | 128 | pd, err = msg.Payload.ReadPayloadData() 129 | 130 | assert.NoError(t, err) 131 | assert.Len(t, pd.Errors, 0) 132 | assert.EqualValues(t, map[string]interface{}{ 133 | "setFoo": false, 134 | }, pd.Data) 135 | 136 | err = conn.ReadJSON(&msg) 137 | 138 | assert.NoError(t, err) 139 | assert.Equal(t, "3", msg.ID) 140 | assert.Equal(t, apollows.OperationComplete, msg.Type) 141 | 142 | err = conn.WriteJSON(apollows.Message{ 143 | ID: "4", 144 | Type: apollows.OperationStart, 145 | Payload: apollows.Data{ 146 | Value: apollows.PayloadOperation{ 147 | Query: `mutation { bar }`, 148 | }, 149 | }, 150 | }) 151 | 152 | assert.NoError(t, err) 153 | 154 | err = conn.ReadJSON(&msg) 155 | 156 | assert.NoError(t, err) 157 | assert.Equal(t, "4", msg.ID) 158 | assert.Equal(t, apollows.OperationError, msg.Type) 159 | 160 | pde, err := msg.Payload.ReadPayloadError() 161 | 162 | assert.NoError(t, err) 163 | assert.Contains(t, pde.Message, `Cannot 
query field "bar"`) 164 | 165 | err = conn.ReadJSON(&msg) 166 | 167 | assert.NoError(t, err) 168 | assert.Equal(t, "4", msg.ID) 169 | assert.Equal(t, apollows.OperationComplete, msg.Type) 170 | 171 | err = conn.WriteJSON(apollows.Message{ 172 | ID: "5", 173 | Type: apollows.OperationStart, 174 | Payload: apollows.Data{ 175 | Value: apollows.PayloadOperation{ 176 | Query: `subscription { forever }`, 177 | }, 178 | }, 179 | }) 180 | 181 | assert.NoError(t, err) 182 | 183 | err = conn.WriteJSON(apollows.Message{ 184 | ID: "5", 185 | Type: apollows.OperationStop, 186 | }) 187 | 188 | assert.NoError(t, err) 189 | 190 | err = conn.ReadJSON(&msg) 191 | 192 | assert.NoError(t, err) 193 | assert.Equal(t, "5", msg.ID) 194 | assert.Equal(t, apollows.OperationComplete, msg.Type) 195 | 196 | err = conn.WriteJSON(apollows.Message{ 197 | ID: "6", 198 | Type: apollows.OperationStart, 199 | Payload: apollows.Data{ 200 | Value: apollows.PayloadOperation{ 201 | Query: `subscription { fooUpdates }`, 202 | }, 203 | }, 204 | }) 205 | 206 | assert.NoError(t, err) 207 | 208 | err = conn.ReadJSON(&msg) 209 | 210 | assert.NoError(t, err) 211 | assert.Equal(t, "6", msg.ID) 212 | assert.Equal(t, apollows.OperationData, msg.Type) 213 | 214 | pd, err = msg.Payload.ReadPayloadData() 215 | 216 | assert.NoError(t, err) 217 | assert.Len(t, pd.Errors, 0) 218 | assert.EqualValues(t, 1, pd.Data["fooUpdates"]) 219 | 220 | err = conn.ReadJSON(&msg) 221 | 222 | assert.NoError(t, err) 223 | assert.Equal(t, "6", msg.ID) 224 | assert.Equal(t, apollows.OperationData, msg.Type) 225 | 226 | pd, err = msg.Payload.ReadPayloadData() 227 | 228 | assert.NoError(t, err) 229 | assert.Len(t, pd.Errors, 0) 230 | assert.EqualValues(t, 2, pd.Data["fooUpdates"]) 231 | 232 | err = conn.ReadJSON(&msg) 233 | 234 | assert.NoError(t, err) 235 | assert.Equal(t, "6", msg.ID) 236 | assert.Equal(t, apollows.OperationData, msg.Type) 237 | 238 | pd, err = msg.Payload.ReadPayloadData() 239 | 240 | assert.NoError(t, err) 241 | 
assert.Len(t, pd.Errors, 0) 242 | assert.EqualValues(t, 3, pd.Data["fooUpdates"]) 243 | 244 | err = conn.ReadJSON(&msg) 245 | 246 | assert.NoError(t, err) 247 | assert.Equal(t, "6", msg.ID) 248 | assert.Equal(t, apollows.OperationComplete, msg.Type) 249 | 250 | err = conn.WriteJSON(apollows.Message{ 251 | ID: "7", 252 | Type: apollows.OperationStart, 253 | Payload: apollows.Data{ 254 | Value: apollows.PayloadOperation{ 255 | Query: `subscription { errors }`, 256 | }, 257 | }, 258 | }) 259 | 260 | assert.NoError(t, err) 261 | 262 | err = conn.ReadJSON(&msg) 263 | 264 | assert.NoError(t, err) 265 | assert.Equal(t, "7", msg.ID) 266 | assert.Equal(t, apollows.OperationData, msg.Type) 267 | 268 | pd, err = msg.Payload.ReadPayloadData() 269 | 270 | assert.NoError(t, err) 271 | assert.Len(t, pd.Errors, 1) 272 | assert.EqualValues(t, nil, pd.Data["errors"]) 273 | assert.EqualValues(t, map[string]interface{}{"foo": "bar"}, pd.Errors[0].Extensions) 274 | 275 | err = conn.ReadJSON(&msg) 276 | 277 | assert.NoError(t, err) 278 | assert.Equal(t, "7", msg.ID) 279 | assert.Equal(t, apollows.OperationComplete, msg.Type) 280 | 281 | assert.NoError(t, conn.Close()) 282 | } 283 | 284 | func testNewServerWebsocketGWTS(t *testing.T, srv *httptest.Server) { 285 | u := "ws" + strings.TrimPrefix(srv.URL, "http") 286 | 287 | conn, resp, err := websocket.DefaultDialer.Dial(u, http.Header{ 288 | "sec-websocket-protocol": []string{apollows.WebsocketSubprotocolGraphqlTransportWS.String()}, 289 | }) 290 | 291 | assert.NoError(t, err) 292 | 293 | defer func() { 294 | _ = conn.Close() 295 | _ = resp.Body.Close() 296 | }() 297 | 298 | err = conn.WriteJSON(apollows.Message{ 299 | ID: "", 300 | Type: apollows.OperationConnectionInit, 301 | Payload: apollows.Data{}, 302 | }) 303 | 304 | assert.NoError(t, err) 305 | 306 | var msg apollows.Message 307 | 308 | err = conn.ReadJSON(&msg) 309 | 310 | assert.NoError(t, err) 311 | assert.Equal(t, apollows.OperationConnectionAck, msg.Type) 312 | 313 | err = 
conn.WriteJSON(apollows.Message{ 314 | ID: "1", 315 | Type: apollows.OperationStart, 316 | Payload: apollows.Data{ 317 | Value: apollows.PayloadOperation{ 318 | Query: `query { getFoo }`, 319 | }, 320 | }, 321 | }) 322 | 323 | assert.NoError(t, err) 324 | 325 | err = conn.ReadJSON(&msg) 326 | 327 | assert.NoError(t, err) 328 | assert.Equal(t, "1", msg.ID) 329 | assert.Equal(t, apollows.OperationNext, msg.Type) 330 | 331 | pd, err := msg.Payload.ReadPayloadData() 332 | 333 | assert.NoError(t, err) 334 | assert.Len(t, pd.Errors, 0) 335 | assert.EqualValues(t, 123, pd.Data["getFoo"]) 336 | 337 | err = conn.ReadJSON(&msg) 338 | 339 | assert.NoError(t, err) 340 | assert.Equal(t, "1", msg.ID) 341 | assert.Equal(t, apollows.OperationComplete, msg.Type) 342 | 343 | err = conn.WriteJSON(apollows.Message{ 344 | ID: "2", 345 | Type: apollows.OperationStart, 346 | Payload: apollows.Data{ 347 | Value: apollows.PayloadOperation{ 348 | Query: `mutation { setFoo(value: 3) }`, 349 | }, 350 | }, 351 | }) 352 | 353 | assert.NoError(t, err) 354 | 355 | err = conn.ReadJSON(&msg) 356 | 357 | assert.NoError(t, err) 358 | assert.Equal(t, "2", msg.ID) 359 | assert.Equal(t, apollows.OperationNext, msg.Type) 360 | 361 | pd, err = msg.Payload.ReadPayloadData() 362 | 363 | assert.NoError(t, err) 364 | assert.Len(t, pd.Errors, 0) 365 | assert.EqualValues(t, map[string]interface{}{ 366 | "setFoo": true, 367 | }, pd.Data) 368 | 369 | err = conn.ReadJSON(&msg) 370 | 371 | assert.NoError(t, err) 372 | assert.Equal(t, "2", msg.ID) 373 | assert.Equal(t, apollows.OperationComplete, msg.Type) 374 | 375 | err = conn.WriteJSON(apollows.Message{ 376 | ID: "3", 377 | Type: apollows.OperationStart, 378 | Payload: apollows.Data{ 379 | Value: apollows.PayloadOperation{ 380 | Query: `mutation { setFoo }`, 381 | }, 382 | }, 383 | }) 384 | 385 | assert.NoError(t, err) 386 | 387 | err = conn.ReadJSON(&msg) 388 | 389 | assert.NoError(t, err) 390 | assert.Equal(t, "3", msg.ID) 391 | assert.Equal(t, 
apollows.OperationNext, msg.Type) 392 | 393 | pd, err = msg.Payload.ReadPayloadData() 394 | 395 | assert.NoError(t, err) 396 | assert.Len(t, pd.Errors, 0) 397 | assert.EqualValues(t, map[string]interface{}{ 398 | "setFoo": false, 399 | }, pd.Data) 400 | 401 | err = conn.ReadJSON(&msg) 402 | 403 | assert.NoError(t, err) 404 | assert.Equal(t, "3", msg.ID) 405 | assert.Equal(t, apollows.OperationComplete, msg.Type) 406 | 407 | err = conn.WriteJSON(apollows.Message{ 408 | ID: "4", 409 | Type: apollows.OperationStart, 410 | Payload: apollows.Data{ 411 | Value: apollows.PayloadOperation{ 412 | Query: `mutation { bar }`, 413 | }, 414 | }, 415 | }) 416 | 417 | assert.NoError(t, err) 418 | 419 | err = conn.ReadJSON(&msg) 420 | 421 | assert.NoError(t, err) 422 | assert.Equal(t, "4", msg.ID) 423 | assert.Equal(t, apollows.OperationError, msg.Type) 424 | 425 | pde, err := msg.Payload.ReadPayloadErrors() 426 | 427 | assert.NoError(t, err) 428 | assert.Greater(t, len(pde), 0) 429 | assert.Contains(t, pde[0].Message, `Cannot query field "bar"`) 430 | 431 | err = conn.WriteJSON(apollows.Message{ 432 | ID: "5", 433 | Type: apollows.OperationStart, 434 | Payload: apollows.Data{ 435 | Value: apollows.PayloadOperation{ 436 | Query: `subscription { forever }`, 437 | }, 438 | }, 439 | }) 440 | 441 | assert.NoError(t, err) 442 | 443 | err = conn.WriteJSON(apollows.Message{ 444 | ID: "5", 445 | Type: apollows.OperationStop, 446 | }) 447 | 448 | assert.NoError(t, err) 449 | 450 | err = conn.WriteJSON(apollows.Message{ 451 | ID: "6", 452 | Type: apollows.OperationStart, 453 | Payload: apollows.Data{ 454 | Value: apollows.PayloadOperation{ 455 | Query: `subscription { fooUpdates }`, 456 | }, 457 | }, 458 | }) 459 | 460 | assert.NoError(t, err) 461 | 462 | err = conn.ReadJSON(&msg) 463 | 464 | assert.NoError(t, err) 465 | assert.Equal(t, "6", msg.ID) 466 | assert.Equal(t, apollows.OperationNext, msg.Type) 467 | 468 | pd, err = msg.Payload.ReadPayloadData() 469 | 470 | assert.NoError(t, err) 
471 | assert.Len(t, pd.Errors, 0) 472 | assert.EqualValues(t, 1, pd.Data["fooUpdates"]) 473 | 474 | err = conn.ReadJSON(&msg) 475 | 476 | assert.NoError(t, err) 477 | assert.Equal(t, "6", msg.ID) 478 | assert.Equal(t, apollows.OperationNext, msg.Type) 479 | 480 | pd, err = msg.Payload.ReadPayloadData() 481 | 482 | assert.NoError(t, err) 483 | assert.Len(t, pd.Errors, 0) 484 | assert.EqualValues(t, 2, pd.Data["fooUpdates"]) 485 | 486 | err = conn.ReadJSON(&msg) 487 | 488 | assert.NoError(t, err) 489 | assert.Equal(t, "6", msg.ID) 490 | assert.Equal(t, apollows.OperationNext, msg.Type) 491 | 492 | pd, err = msg.Payload.ReadPayloadData() 493 | 494 | assert.NoError(t, err) 495 | assert.Len(t, pd.Errors, 0) 496 | assert.EqualValues(t, 3, pd.Data["fooUpdates"]) 497 | 498 | err = conn.ReadJSON(&msg) 499 | 500 | assert.NoError(t, err) 501 | assert.Equal(t, "6", msg.ID) 502 | assert.Equal(t, apollows.OperationComplete, msg.Type) 503 | 504 | err = conn.WriteJSON(apollows.Message{ 505 | ID: "7", 506 | Type: apollows.OperationStart, 507 | Payload: apollows.Data{ 508 | Value: apollows.PayloadOperation{ 509 | Query: `subscription { errors }`, 510 | }, 511 | }, 512 | }) 513 | 514 | assert.NoError(t, err) 515 | 516 | err = conn.ReadJSON(&msg) 517 | 518 | assert.NoError(t, err) 519 | assert.Equal(t, "7", msg.ID) 520 | assert.Equal(t, apollows.OperationNext, msg.Type) 521 | 522 | pd, err = msg.Payload.ReadPayloadData() 523 | 524 | assert.NoError(t, err) 525 | assert.Len(t, pd.Errors, 1) 526 | assert.EqualValues(t, nil, pd.Data["errors"]) 527 | assert.EqualValues(t, map[string]interface{}{"foo": "bar"}, pd.Errors[0].Extensions) 528 | 529 | err = conn.ReadJSON(&msg) 530 | 531 | assert.NoError(t, err) 532 | assert.Equal(t, "7", msg.ID) 533 | assert.Equal(t, apollows.OperationComplete, msg.Type) 534 | 535 | assert.NoError(t, conn.Close()) 536 | } 537 | 538 | func TestNewServerWebsocketGWS(t *testing.T) { 539 | srv := testNewServer(t, apollows.WebsocketSubprotocolGraphqlWS, 
WithConnectTimeout(time.Second)) 540 | 541 | defer srv.Close() 542 | 543 | testNewServerWebsocketGWS(t, srv) 544 | } 545 | 546 | func TestNewServerWebsocketGWTS(t *testing.T) { 547 | srv := testNewServer(t, apollows.WebsocketSubprotocolGraphqlTransportWS, WithConnectTimeout(time.Second)) 548 | 549 | defer srv.Close() 550 | 551 | testNewServerWebsocketGWTS(t, srv) 552 | } 553 | 554 | func TestNewServerWebsocketGWSGWTS(t *testing.T) { 555 | srv := testNewServerProtocols( 556 | t, 557 | []apollows.Protocol{apollows.WebsocketSubprotocolGraphqlWS, apollows.WebsocketSubprotocolGraphqlTransportWS}, 558 | WithProtocol(apollows.WebsocketSubprotocolGraphqlTransportWS), 559 | WithConnectTimeout(time.Second), 560 | ) 561 | 562 | defer srv.Close() 563 | 564 | testNewServerWebsocketGWS(t, srv) 565 | testNewServerWebsocketGWTS(t, srv) 566 | } 567 | 568 | func TestNewServerWebsocketProtocolMismatch(t *testing.T) { 569 | srv := testNewServerProtocols( 570 | t, 571 | []apollows.Protocol{apollows.WebsocketSubprotocolGraphqlWS, apollows.WebsocketSubprotocolGraphqlTransportWS}, 572 | WithProtocol(apollows.WebsocketSubprotocolGraphqlTransportWS), 573 | WithConnectTimeout(time.Second), 574 | ) 575 | 576 | defer srv.Close() 577 | 578 | u := "ws" + strings.TrimPrefix(srv.URL, "http") 579 | 580 | conn, resp, err := websocket.DefaultDialer.Dial(u, http.Header{ 581 | "sec-websocket-protocol": []string{"foo"}, 582 | }) 583 | 584 | assert.NoError(t, err) 585 | 586 | defer func() { 587 | _ = conn.Close() 588 | _ = resp.Body.Close() 589 | }() 590 | 591 | var msg apollows.Message 592 | 593 | err = conn.ReadJSON(&msg) 594 | 595 | assert.ErrorContains(t, err, apollows.ErrUnknownProtocol.Error()) 596 | } 597 | 598 | func TestNewServerWebsocketKeepalive(t *testing.T) { 599 | srv := testNewServer(t, apollows.WebsocketSubprotocolGraphqlWS, WithKeepalive(time.Millisecond*10)) 600 | 601 | defer srv.Close() 602 | 603 | u := "ws" + strings.TrimPrefix(srv.URL, "http") 604 | 605 | conn, resp, err := 
websocket.DefaultDialer.Dial(u, http.Header{ 606 | "sec-websocket-protocol": []string{apollows.WebsocketSubprotocolGraphqlWS.String()}, 607 | }) 608 | 609 | assert.NoError(t, err) 610 | 611 | defer func() { 612 | _ = conn.Close() 613 | _ = resp.Body.Close() 614 | }() 615 | 616 | err = conn.WriteJSON(apollows.Message{ 617 | ID: "", 618 | Type: apollows.OperationConnectionInit, 619 | Payload: apollows.Data{}, 620 | }) 621 | 622 | assert.NoError(t, err) 623 | 624 | var msg apollows.Message 625 | 626 | err = conn.ReadJSON(&msg) 627 | 628 | assert.NoError(t, err) 629 | assert.Equal(t, apollows.OperationConnectionAck, msg.Type) 630 | 631 | err = conn.ReadJSON(&msg) 632 | 633 | assert.NoError(t, err) 634 | assert.Equal(t, apollows.OperationKeepAlive, msg.Type) 635 | } 636 | 637 | func TestNewServerWebsocketTerminateGWS(t *testing.T) { 638 | srv := testNewServer(t, apollows.WebsocketSubprotocolGraphqlWS) 639 | 640 | defer srv.Close() 641 | 642 | u := "ws" + strings.TrimPrefix(srv.URL, "http") 643 | 644 | conn, resp, err := websocket.DefaultDialer.Dial(u, http.Header{ 645 | "sec-websocket-protocol": []string{apollows.WebsocketSubprotocolGraphqlWS.String()}, 646 | }) 647 | 648 | assert.NoError(t, err) 649 | 650 | defer func() { 651 | _ = conn.Close() 652 | _ = resp.Body.Close() 653 | }() 654 | 655 | err = conn.WriteJSON(apollows.Message{ 656 | ID: "", 657 | Type: apollows.OperationConnectionInit, 658 | Payload: apollows.Data{}, 659 | }) 660 | 661 | assert.NoError(t, err) 662 | 663 | var msg apollows.Message 664 | 665 | err = conn.ReadJSON(&msg) 666 | 667 | assert.NoError(t, err) 668 | assert.Equal(t, apollows.OperationConnectionAck, msg.Type) 669 | 670 | err = conn.WriteJSON(apollows.Message{ 671 | ID: "", 672 | Type: apollows.OperationTerminate, 673 | }) 674 | 675 | assert.NoError(t, err) 676 | 677 | err = conn.ReadJSON(&msg) 678 | 679 | assert.ErrorContains(t, err, "requested") 680 | } 681 | 682 | func TestNewServerWebsocketTerminateGTWS(t *testing.T) { 683 | srv := 
testNewServer(t, apollows.WebsocketSubprotocolGraphqlTransportWS) 684 | 685 | defer srv.Close() 686 | 687 | u := "ws" + strings.TrimPrefix(srv.URL, "http") 688 | 689 | conn, resp, err := websocket.DefaultDialer.Dial(u, http.Header{ 690 | "sec-websocket-protocol": []string{apollows.WebsocketSubprotocolGraphqlTransportWS.String()}, 691 | }) 692 | 693 | assert.NoError(t, err) 694 | 695 | defer func() { 696 | _ = conn.Close() 697 | _ = resp.Body.Close() 698 | }() 699 | 700 | err = conn.WriteJSON(apollows.Message{ 701 | ID: "", 702 | Type: apollows.OperationConnectionInit, 703 | Payload: apollows.Data{}, 704 | }) 705 | 706 | assert.NoError(t, err) 707 | 708 | var msg apollows.Message 709 | 710 | err = conn.ReadJSON(&msg) 711 | 712 | assert.NoError(t, err) 713 | assert.Equal(t, apollows.OperationConnectionAck, msg.Type) 714 | 715 | err = conn.WriteJSON(apollows.Message{ 716 | ID: "", 717 | Type: apollows.OperationTerminate, 718 | }) 719 | 720 | assert.NoError(t, err) 721 | 722 | err = conn.ReadJSON(&msg) 723 | 724 | assert.ErrorContains(t, err, "Unauthorized") 725 | } 726 | 727 | func TestNewServerWebsocketTimeoutGWS(t *testing.T) { 728 | srv := testNewServer(t, apollows.WebsocketSubprotocolGraphqlWS, WithConnectTimeout(1)) 729 | 730 | defer srv.Close() 731 | 732 | u := "ws" + strings.TrimPrefix(srv.URL, "http") 733 | 734 | conn, resp, err := websocket.DefaultDialer.Dial(u, http.Header{ 735 | "sec-websocket-protocol": []string{apollows.WebsocketSubprotocolGraphqlWS.String()}, 736 | }) 737 | 738 | assert.NoError(t, err) 739 | 740 | defer func() { 741 | _ = conn.Close() 742 | _ = resp.Body.Close() 743 | }() 744 | 745 | time.Sleep(time.Millisecond * 10) 746 | 747 | err = conn.WriteJSON(apollows.Message{ 748 | ID: "", 749 | Type: apollows.OperationConnectionInit, 750 | Payload: apollows.Data{}, 751 | }) 752 | 753 | assert.NoError(t, err) 754 | 755 | var msg apollows.Message 756 | 757 | err = conn.ReadJSON(&msg) 758 | 759 | assert.NoError(t, err) 760 | assert.Equal(t, 
apollows.OperationConnectionError, msg.Type) 761 | } 762 | 763 | func TestNewServerWebsocketTimeoutGTWS(t *testing.T) { 764 | srv := testNewServer(t, apollows.WebsocketSubprotocolGraphqlTransportWS, WithConnectTimeout(1)) 765 | 766 | defer srv.Close() 767 | 768 | u := "ws" + strings.TrimPrefix(srv.URL, "http") 769 | 770 | conn, resp, err := websocket.DefaultDialer.Dial(u, http.Header{ 771 | "sec-websocket-protocol": []string{apollows.WebsocketSubprotocolGraphqlTransportWS.String()}, 772 | }) 773 | 774 | assert.NoError(t, err) 775 | 776 | defer func() { 777 | _ = conn.Close() 778 | _ = resp.Body.Close() 779 | }() 780 | 781 | time.Sleep(time.Millisecond * 10) 782 | 783 | err = conn.WriteJSON(apollows.Message{ 784 | ID: "", 785 | Type: apollows.OperationConnectionInit, 786 | Payload: apollows.Data{}, 787 | }) 788 | 789 | assert.NoError(t, err) 790 | 791 | var msg apollows.Message 792 | 793 | err = conn.ReadJSON(&msg) 794 | 795 | assert.ErrorContains(t, err, "4408: Connection initialisation timeout") 796 | } 797 | 798 | func TestNewServerWebsocketReinitGWS(t *testing.T) { 799 | srv := testNewServer(t, apollows.WebsocketSubprotocolGraphqlWS) 800 | 801 | defer srv.Close() 802 | 803 | u := "ws" + strings.TrimPrefix(srv.URL, "http") 804 | 805 | conn, resp, err := websocket.DefaultDialer.Dial(u, http.Header{ 806 | "sec-websocket-protocol": []string{apollows.WebsocketSubprotocolGraphqlWS.String()}, 807 | }) 808 | 809 | assert.NoError(t, err) 810 | 811 | defer func() { 812 | _ = conn.Close() 813 | _ = resp.Body.Close() 814 | }() 815 | 816 | err = conn.WriteJSON(apollows.Message{ 817 | ID: "", 818 | Type: apollows.OperationConnectionInit, 819 | Payload: apollows.Data{}, 820 | }) 821 | 822 | assert.NoError(t, err) 823 | 824 | var msg apollows.Message 825 | 826 | err = conn.ReadJSON(&msg) 827 | 828 | assert.NoError(t, err) 829 | assert.Equal(t, apollows.OperationConnectionAck, msg.Type) 830 | 831 | err = conn.WriteJSON(apollows.Message{ 832 | ID: "", 833 | Type: 
apollows.OperationConnectionInit, 834 | Payload: apollows.Data{}, 835 | }) 836 | 837 | assert.NoError(t, err) 838 | 839 | err = conn.ReadJSON(&msg) 840 | 841 | assert.NoError(t, err) 842 | assert.Equal(t, apollows.OperationConnectionError, msg.Type) 843 | 844 | pde, err := msg.Payload.ReadPayloadError() 845 | assert.NoError(t, err) 846 | 847 | assert.Contains(t, pde.Message, "Too many initialisation requests") 848 | } 849 | 850 | func TestNewServerWebsocketReinitGTWS(t *testing.T) { 851 | srv := testNewServer(t, apollows.WebsocketSubprotocolGraphqlTransportWS) 852 | 853 | defer srv.Close() 854 | 855 | u := "ws" + strings.TrimPrefix(srv.URL, "http") 856 | 857 | conn, resp, err := websocket.DefaultDialer.Dial(u, http.Header{ 858 | "sec-websocket-protocol": []string{apollows.WebsocketSubprotocolGraphqlTransportWS.String()}, 859 | }) 860 | 861 | assert.NoError(t, err) 862 | 863 | defer func() { 864 | _ = conn.Close() 865 | _ = resp.Body.Close() 866 | }() 867 | 868 | err = conn.WriteJSON(apollows.Message{ 869 | ID: "", 870 | Type: apollows.OperationConnectionInit, 871 | Payload: apollows.Data{}, 872 | }) 873 | 874 | assert.NoError(t, err) 875 | 876 | var msg apollows.Message 877 | 878 | err = conn.ReadJSON(&msg) 879 | 880 | assert.NoError(t, err) 881 | assert.Equal(t, apollows.OperationConnectionAck, msg.Type) 882 | 883 | err = conn.WriteJSON(apollows.Message{ 884 | ID: "", 885 | Type: apollows.OperationConnectionInit, 886 | Payload: apollows.Data{}, 887 | }) 888 | 889 | assert.NoError(t, err) 890 | 891 | err = conn.ReadJSON(&msg) 892 | 893 | assert.ErrorContains(t, err, "4429: Too many initialisation requests") 894 | } 895 | 896 | func TestNewServerWebsocketOperationRestartGWS(t *testing.T) { 897 | srv := testNewServer(t, apollows.WebsocketSubprotocolGraphqlWS, WithConnectTimeout(time.Second)) 898 | 899 | defer srv.Close() 900 | 901 | u := "ws" + strings.TrimPrefix(srv.URL, "http") 902 | 903 | conn, resp, err := websocket.DefaultDialer.Dial(u, http.Header{ 904 | 
"sec-websocket-protocol": []string{apollows.WebsocketSubprotocolGraphqlWS.String()}, 905 | }) 906 | 907 | assert.NoError(t, err) 908 | 909 | defer func() { 910 | _ = conn.Close() 911 | _ = resp.Body.Close() 912 | }() 913 | 914 | err = conn.WriteJSON(apollows.Message{ 915 | ID: "", 916 | Type: apollows.OperationConnectionInit, 917 | Payload: apollows.Data{}, 918 | }) 919 | 920 | assert.NoError(t, err) 921 | 922 | var msg apollows.Message 923 | 924 | err = conn.ReadJSON(&msg) 925 | 926 | assert.NoError(t, err) 927 | assert.Equal(t, apollows.OperationConnectionAck, msg.Type) 928 | 929 | err = conn.WriteJSON(apollows.Message{ 930 | ID: "1", 931 | Type: apollows.OperationStart, 932 | Payload: apollows.Data{ 933 | Value: apollows.PayloadOperation{ 934 | Query: `subscription { forever }`, 935 | }, 936 | }, 937 | }) 938 | 939 | assert.NoError(t, err) 940 | 941 | err = conn.WriteJSON(apollows.Message{ 942 | ID: "1", 943 | Type: apollows.OperationStart, 944 | Payload: apollows.Data{ 945 | Value: apollows.PayloadOperation{ 946 | Query: `subscription { forever }`, 947 | }, 948 | }, 949 | }) 950 | 951 | assert.NoError(t, err) 952 | 953 | err = conn.ReadJSON(&msg) 954 | 955 | assert.NoError(t, err) 956 | assert.Equal(t, apollows.OperationConnectionError, msg.Type) 957 | 958 | pde, err := msg.Payload.ReadPayloadError() 959 | 960 | assert.NoError(t, err) 961 | assert.Contains(t, pde.Message, "Subscriber for 1 already exists") 962 | } 963 | 964 | func TestNewServerWebsocketOperationRestartGTWS(t *testing.T) { 965 | srv := testNewServer(t, apollows.WebsocketSubprotocolGraphqlTransportWS, WithConnectTimeout(time.Second)) 966 | 967 | defer srv.Close() 968 | 969 | u := "ws" + strings.TrimPrefix(srv.URL, "http") 970 | 971 | conn, resp, err := websocket.DefaultDialer.Dial(u, http.Header{ 972 | "sec-websocket-protocol": []string{apollows.WebsocketSubprotocolGraphqlTransportWS.String()}, 973 | }) 974 | 975 | assert.NoError(t, err) 976 | 977 | defer func() { 978 | _ = conn.Close() 979 | _ = 
resp.Body.Close() 980 | }() 981 | 982 | err = conn.WriteJSON(apollows.Message{ 983 | ID: "", 984 | Type: apollows.OperationConnectionInit, 985 | Payload: apollows.Data{}, 986 | }) 987 | 988 | assert.NoError(t, err) 989 | 990 | var msg apollows.Message 991 | 992 | err = conn.ReadJSON(&msg) 993 | 994 | assert.NoError(t, err) 995 | assert.Equal(t, apollows.OperationConnectionAck, msg.Type) 996 | 997 | err = conn.WriteJSON(apollows.Message{ 998 | ID: "1", 999 | Type: apollows.OperationStart, 1000 | Payload: apollows.Data{ 1001 | Value: apollows.PayloadOperation{ 1002 | Query: `subscription { forever }`, 1003 | }, 1004 | }, 1005 | }) 1006 | 1007 | assert.NoError(t, err) 1008 | 1009 | err = conn.WriteJSON(apollows.Message{ 1010 | ID: "1", 1011 | Type: apollows.OperationStart, 1012 | Payload: apollows.Data{ 1013 | Value: apollows.PayloadOperation{ 1014 | Query: `subscription { forever }`, 1015 | }, 1016 | }, 1017 | }) 1018 | 1019 | assert.NoError(t, err) 1020 | 1021 | err = conn.ReadJSON(&msg) 1022 | 1023 | assert.ErrorContains(t, err, "4409: Subscriber for 1 already exists") 1024 | } 1025 | 1026 | func TestNewServerWebsocketOperationInvalidGWS(t *testing.T) { 1027 | srv := testNewServer(t, apollows.WebsocketSubprotocolGraphqlWS, WithConnectTimeout(time.Second)) 1028 | 1029 | defer srv.Close() 1030 | 1031 | u := "ws" + strings.TrimPrefix(srv.URL, "http") 1032 | 1033 | conn, resp, err := websocket.DefaultDialer.Dial(u, http.Header{ 1034 | "sec-websocket-protocol": []string{apollows.WebsocketSubprotocolGraphqlWS.String()}, 1035 | }) 1036 | 1037 | assert.NoError(t, err) 1038 | 1039 | defer func() { 1040 | _ = conn.Close() 1041 | _ = resp.Body.Close() 1042 | }() 1043 | 1044 | err = conn.WriteJSON(apollows.Message{ 1045 | ID: "", 1046 | Type: apollows.OperationConnectionInit, 1047 | Payload: apollows.Data{}, 1048 | }) 1049 | 1050 | assert.NoError(t, err) 1051 | 1052 | var msg apollows.Message 1053 | 1054 | err = conn.ReadJSON(&msg) 1055 | 1056 | assert.NoError(t, err) 1057 | 
assert.Equal(t, apollows.OperationConnectionAck, msg.Type) 1058 | 1059 | err = conn.WriteJSON(apollows.Message{ 1060 | ID: "1", 1061 | Type: apollows.OperationStart, 1062 | Payload: apollows.Data{ 1063 | Value: "foo", 1064 | }, 1065 | }) 1066 | 1067 | assert.NoError(t, err) 1068 | 1069 | err = conn.ReadJSON(&msg) 1070 | 1071 | assert.NoError(t, err) 1072 | assert.Equal(t, "1", msg.ID) 1073 | assert.Equal(t, apollows.OperationError, msg.Type) 1074 | 1075 | err = conn.ReadJSON(&msg) 1076 | 1077 | assert.NoError(t, err) 1078 | assert.Equal(t, "1", msg.ID) 1079 | assert.Equal(t, apollows.OperationComplete, msg.Type) 1080 | } 1081 | 1082 | func TestNewServerWebsocketOperationInvalidGTWS(t *testing.T) { 1083 | srv := testNewServer(t, apollows.WebsocketSubprotocolGraphqlTransportWS, WithConnectTimeout(time.Second)) 1084 | 1085 | defer srv.Close() 1086 | 1087 | u := "ws" + strings.TrimPrefix(srv.URL, "http") 1088 | 1089 | conn, resp, err := websocket.DefaultDialer.Dial(u, http.Header{ 1090 | "sec-websocket-protocol": []string{apollows.WebsocketSubprotocolGraphqlTransportWS.String()}, 1091 | }) 1092 | 1093 | assert.NoError(t, err) 1094 | 1095 | defer func() { 1096 | _ = conn.Close() 1097 | _ = resp.Body.Close() 1098 | }() 1099 | 1100 | err = conn.WriteJSON(apollows.Message{ 1101 | ID: "", 1102 | Type: apollows.OperationConnectionInit, 1103 | Payload: apollows.Data{}, 1104 | }) 1105 | 1106 | assert.NoError(t, err) 1107 | 1108 | var msg apollows.Message 1109 | 1110 | err = conn.ReadJSON(&msg) 1111 | 1112 | assert.NoError(t, err) 1113 | assert.Equal(t, apollows.OperationConnectionAck, msg.Type) 1114 | 1115 | err = conn.WriteJSON(apollows.Message{ 1116 | ID: "1", 1117 | Type: apollows.OperationStart, 1118 | Payload: apollows.Data{ 1119 | Value: "foo", 1120 | }, 1121 | }) 1122 | 1123 | assert.NoError(t, err) 1124 | 1125 | err = conn.ReadJSON(&msg) 1126 | 1127 | assert.ErrorContains(t, err, "4400: Invalid message") 1128 | } 1129 | 1130 | func 
TestNewServerWebsocketOperationErrorGWS(t *testing.T) { 1131 | srv := testNewServer(t, apollows.WebsocketSubprotocolGraphqlWS, WithConnectTimeout(time.Second)) 1132 | 1133 | defer srv.Close() 1134 | 1135 | u := "ws" + strings.TrimPrefix(srv.URL, "http") 1136 | 1137 | conn, resp, err := websocket.DefaultDialer.Dial(u, http.Header{ 1138 | "sec-websocket-protocol": []string{apollows.WebsocketSubprotocolGraphqlWS.String()}, 1139 | }) 1140 | 1141 | assert.NoError(t, err) 1142 | 1143 | defer func() { 1144 | _ = conn.Close() 1145 | _ = resp.Body.Close() 1146 | }() 1147 | 1148 | err = conn.WriteJSON(apollows.Message{ 1149 | ID: "", 1150 | Type: apollows.OperationConnectionInit, 1151 | Payload: apollows.Data{}, 1152 | }) 1153 | 1154 | assert.NoError(t, err) 1155 | 1156 | var msg apollows.Message 1157 | 1158 | err = conn.ReadJSON(&msg) 1159 | 1160 | assert.NoError(t, err) 1161 | assert.Equal(t, apollows.OperationConnectionAck, msg.Type) 1162 | 1163 | err = conn.WriteJSON(apollows.Message{ 1164 | ID: "1", 1165 | Type: apollows.OperationStart, 1166 | Payload: apollows.Data{ 1167 | Value: apollows.PayloadOperation{ 1168 | Query: `query { getError }`, 1169 | }, 1170 | }, 1171 | }) 1172 | 1173 | assert.NoError(t, err) 1174 | 1175 | err = conn.ReadJSON(&msg) 1176 | 1177 | assert.NoError(t, err) 1178 | assert.Equal(t, "1", msg.ID) 1179 | assert.Equal(t, apollows.OperationData, msg.Type) 1180 | 1181 | pd, err := msg.Payload.ReadPayloadData() 1182 | 1183 | assert.NoError(t, err) 1184 | assert.Len(t, pd.Errors, 1) 1185 | assert.Contains(t, pd.Errors[0].Message, "someerr") 1186 | assert.EqualValues(t, map[string]interface{}{"foo": "bar"}, pd.Errors[0].Extensions) 1187 | 1188 | err = conn.ReadJSON(&msg) 1189 | 1190 | assert.NoError(t, err) 1191 | assert.Equal(t, "1", msg.ID) 1192 | assert.Equal(t, apollows.OperationComplete, msg.Type) 1193 | } 1194 | 1195 | func TestNewServerWebsocketOperationErrorGTWS(t *testing.T) { 1196 | srv := testNewServer(t, 
apollows.WebsocketSubprotocolGraphqlTransportWS, WithConnectTimeout(time.Second)) 1197 | 1198 | defer srv.Close() 1199 | 1200 | u := "ws" + strings.TrimPrefix(srv.URL, "http") 1201 | 1202 | conn, resp, err := websocket.DefaultDialer.Dial(u, http.Header{ 1203 | "sec-websocket-protocol": []string{apollows.WebsocketSubprotocolGraphqlTransportWS.String()}, 1204 | }) 1205 | 1206 | assert.NoError(t, err) 1207 | 1208 | defer func() { 1209 | _ = conn.Close() 1210 | _ = resp.Body.Close() 1211 | }() 1212 | 1213 | err = conn.WriteJSON(apollows.Message{ 1214 | ID: "", 1215 | Type: apollows.OperationConnectionInit, 1216 | Payload: apollows.Data{}, 1217 | }) 1218 | 1219 | assert.NoError(t, err) 1220 | 1221 | var msg apollows.Message 1222 | 1223 | err = conn.ReadJSON(&msg) 1224 | 1225 | assert.NoError(t, err) 1226 | assert.Equal(t, apollows.OperationConnectionAck, msg.Type) 1227 | 1228 | err = conn.WriteJSON(apollows.Message{ 1229 | ID: "1", 1230 | Type: apollows.OperationStart, 1231 | Payload: apollows.Data{ 1232 | Value: apollows.PayloadOperation{ 1233 | Query: `query { getError }`, 1234 | }, 1235 | }, 1236 | }) 1237 | 1238 | assert.NoError(t, err) 1239 | 1240 | err = conn.ReadJSON(&msg) 1241 | 1242 | assert.NoError(t, err) 1243 | assert.Equal(t, "1", msg.ID) 1244 | assert.Equal(t, apollows.OperationNext, msg.Type) 1245 | 1246 | pd, err := msg.Payload.ReadPayloadData() 1247 | 1248 | assert.NoError(t, err) 1249 | assert.Len(t, pd.Errors, 1) 1250 | assert.Contains(t, pd.Errors[0].Message, "someerr") 1251 | assert.EqualValues(t, map[string]interface{}{"foo": "bar"}, pd.Errors[0].Extensions) 1252 | 1253 | err = conn.ReadJSON(&msg) 1254 | 1255 | assert.NoError(t, err) 1256 | assert.Equal(t, "1", msg.ID) 1257 | assert.Equal(t, apollows.OperationComplete, msg.Type) 1258 | } 1259 | 1260 | func TestNewServerWebsocketPingGTWS(t *testing.T) { 1261 | srv := testNewServer(t, apollows.WebsocketSubprotocolGraphqlTransportWS, WithConnectTimeout(time.Second)) 1262 | 1263 | defer srv.Close() 
1264 | 1265 | u := "ws" + strings.TrimPrefix(srv.URL, "http") 1266 | 1267 | conn, resp, err := websocket.DefaultDialer.Dial(u, http.Header{ 1268 | "sec-websocket-protocol": []string{apollows.WebsocketSubprotocolGraphqlTransportWS.String()}, 1269 | }) 1270 | 1271 | assert.NoError(t, err) 1272 | 1273 | defer func() { 1274 | _ = conn.Close() 1275 | _ = resp.Body.Close() 1276 | }() 1277 | 1278 | err = conn.WriteJSON(apollows.Message{ 1279 | ID: "", 1280 | Type: apollows.OperationConnectionInit, 1281 | Payload: apollows.Data{}, 1282 | }) 1283 | 1284 | assert.NoError(t, err) 1285 | 1286 | var msg apollows.Message 1287 | 1288 | err = conn.ReadJSON(&msg) 1289 | 1290 | assert.NoError(t, err) 1291 | assert.Equal(t, apollows.OperationConnectionAck, msg.Type) 1292 | 1293 | err = conn.WriteJSON(apollows.Message{ 1294 | Type: apollows.OperationPing, 1295 | Payload: apollows.Data{ 1296 | Value: map[string]interface{}{ 1297 | "foo": 123, 1298 | }, 1299 | }, 1300 | }) 1301 | 1302 | assert.NoError(t, err) 1303 | 1304 | err = conn.ReadJSON(&msg) 1305 | 1306 | assert.NoError(t, err) 1307 | assert.Equal(t, apollows.OperationPong, msg.Type) 1308 | 1309 | var m map[string]interface{} 1310 | 1311 | err = json.Unmarshal(msg.Payload.RawMessage, &m) 1312 | 1313 | assert.NoError(t, err) 1314 | assert.EqualValues(t, 123, m["foo"]) 1315 | } 1316 | 1317 | func TestNewServerWebsocketCombineErrorsGWS(t *testing.T) { 1318 | ex1 := &testExt{} 1319 | ex2 := &testExt{} 1320 | 1321 | var opts []ServerOption 1322 | 1323 | opts = append(opts, WithUpgrader(testWrapper{ 1324 | Upgrader: &websocket.Upgrader{ 1325 | ReadBufferSize: 1024, 1326 | WriteBufferSize: 1024, 1327 | Subprotocols: []string{apollows.WebsocketSubprotocolGraphqlWS.String()}, 1328 | CheckOrigin: func(r *http.Request) bool { 1329 | return true 1330 | }, 1331 | }, 1332 | }), WithConnectTimeout(time.Second)) 1333 | 1334 | opts = append(opts, WithProtocol(apollows.WebsocketSubprotocolGraphqlWS)) 1335 | 1336 | schema, err := 
graphql.NewSchema(graphql.SchemaConfig{ 1337 | Query: graphql.NewObject(graphql.ObjectConfig{ 1338 | Name: "QueryRoot", 1339 | Interfaces: nil, 1340 | Fields: graphql.Fields{ 1341 | "getFoo": &graphql.Field{ 1342 | Type: graphql.Int, 1343 | Resolve: func(p graphql.ResolveParams) (interface{}, error) { 1344 | return 123, nil 1345 | }, 1346 | }, 1347 | "getError": &graphql.Field{ 1348 | Type: graphql.Int, 1349 | Resolve: func(p graphql.ResolveParams) (interface{}, error) { 1350 | return nil, errors.New("someerr") 1351 | }, 1352 | }, 1353 | }, 1354 | }), 1355 | Extensions: []graphql.Extension{ 1356 | ex1, 1357 | ex2, 1358 | }, 1359 | }) 1360 | 1361 | assert.NoError(t, err) 1362 | 1363 | server, err := NewServer(schema, opts...) 1364 | 1365 | assert.NoError(t, err) 1366 | assert.NotNil(t, server) 1367 | 1368 | srv := httptest.NewServer(server) 1369 | 1370 | defer srv.Close() 1371 | 1372 | u := "ws" + strings.TrimPrefix(srv.URL, "http") 1373 | 1374 | conn, resp, err := websocket.DefaultDialer.Dial(u, http.Header{ 1375 | "sec-websocket-protocol": []string{apollows.WebsocketSubprotocolGraphqlWS.String()}, 1376 | }) 1377 | 1378 | assert.NoError(t, err) 1379 | 1380 | defer func() { 1381 | _ = conn.Close() 1382 | _ = resp.Body.Close() 1383 | }() 1384 | 1385 | err = conn.WriteJSON(apollows.Message{ 1386 | ID: "", 1387 | Type: apollows.OperationConnectionInit, 1388 | Payload: apollows.Data{}, 1389 | }) 1390 | 1391 | assert.NoError(t, err) 1392 | 1393 | var msg apollows.Message 1394 | 1395 | err = conn.ReadJSON(&msg) 1396 | 1397 | assert.NoError(t, err) 1398 | assert.Equal(t, apollows.OperationConnectionAck, msg.Type) 1399 | 1400 | err = conn.WriteJSON(apollows.Message{ 1401 | ID: "1", 1402 | Type: apollows.OperationStart, 1403 | Payload: apollows.Data{ 1404 | Value: apollows.PayloadOperation{ 1405 | Query: `query { getError }`, 1406 | }, 1407 | }, 1408 | }) 1409 | 1410 | assert.NoError(t, err) 1411 | 1412 | err = conn.ReadJSON(&msg) 1413 | 1414 | assert.NoError(t, err) 1415 | 
assert.Equal(t, "1", msg.ID) 1416 | assert.Equal(t, apollows.OperationError, msg.Type) 1417 | 1418 | pd, err := msg.Payload.ReadPayloadError() 1419 | 1420 | assert.NoError(t, err) 1421 | assert.NotNil(t, pd.Extensions["errors"]) 1422 | assert.Len(t, pd.Extensions["errors"], 2) 1423 | } 1424 | 1425 | func TestNewServerWebsocketHeaders(t *testing.T) { 1426 | var opts []ServerOption 1427 | 1428 | opts = append(opts, WithUpgrader(testWrapper{ 1429 | Upgrader: &websocket.Upgrader{ 1430 | ReadBufferSize: 1024, 1431 | WriteBufferSize: 1024, 1432 | Subprotocols: []string{apollows.WebsocketSubprotocolGraphqlWS.String()}, 1433 | CheckOrigin: func(r *http.Request) bool { 1434 | return true 1435 | }, 1436 | }, 1437 | }), WithConnectTimeout(time.Second)) 1438 | 1439 | opts = append( 1440 | opts, 1441 | WithProtocol(apollows.WebsocketSubprotocolGraphqlWS), 1442 | WithCallbacks(Callbacks{ 1443 | OnRequest: func(reqctx mutable.Context, r *http.Request, w http.ResponseWriter) error { 1444 | w.Header().Set("foo", "bar") 1445 | 1446 | return nil 1447 | }, 1448 | }), 1449 | ) 1450 | 1451 | schema, err := graphql.NewSchema(graphql.SchemaConfig{ 1452 | Query: graphql.NewObject(graphql.ObjectConfig{ 1453 | Name: "QueryRoot", 1454 | Interfaces: nil, 1455 | Fields: graphql.Fields{ 1456 | "getFoo": &graphql.Field{ 1457 | Type: graphql.Int, 1458 | Resolve: func(p graphql.ResolveParams) (interface{}, error) { 1459 | return 123, nil 1460 | }, 1461 | }, 1462 | }, 1463 | }), 1464 | }) 1465 | 1466 | assert.NoError(t, err) 1467 | 1468 | server, err := NewServer(schema, opts...) 
1469 | 1470 | assert.NoError(t, err) 1471 | assert.NotNil(t, server) 1472 | 1473 | srv := httptest.NewServer(server) 1474 | 1475 | defer srv.Close() 1476 | 1477 | u := "ws" + strings.TrimPrefix(srv.URL, "http") 1478 | 1479 | conn, resp, err := websocket.DefaultDialer.Dial(u, http.Header{ 1480 | "sec-websocket-protocol": []string{apollows.WebsocketSubprotocolGraphqlWS.String()}, 1481 | }) 1482 | 1483 | assert.NoError(t, err) 1484 | 1485 | defer func() { 1486 | _ = conn.Close() 1487 | _ = resp.Body.Close() 1488 | }() 1489 | 1490 | assert.Equal(t, "bar", resp.Header.Get("foo")) 1491 | } 1492 | --------------------------------------------------------------------------------