├── .github ├── ISSUE_TEMPLATE │ ├── bug_report.md │ ├── code-review.md │ └── feature_request.md └── workflows │ ├── go-main.yml │ └── go-pr.yml ├── .gitignore ├── .golangci.yml ├── CHANGELOG.md ├── LICENSE ├── README.md ├── _config.yml ├── go.mod ├── go.sum └── v1 ├── api.go ├── api_server.go ├── api_server_test.go ├── apollows ├── proto.go └── proto_test.go ├── ast.go ├── ast_test.go ├── callbacks.go ├── compat.go ├── compat ├── gorillaws │ └── api.go └── otelwsgraphql │ └── api.go ├── context.go ├── context_test.go ├── error.go ├── examples ├── minimal-graphql-transport-ws │ ├── README.md │ └── main.go ├── minimal-graphql-ws │ ├── README.md │ └── main.go └── simpleserver │ ├── README.md │ ├── main.go │ └── playground.html ├── interceptors.go ├── mutable ├── api.go └── mutcontext_test.go ├── server.go ├── server_plain.go ├── server_plain_test.go ├── server_test.go ├── server_websocket.go └── server_websocket_test.go /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: Bug report default 4 | title: '' 5 | labels: bug 6 | assignees: iamtakingiteasy 7 | 8 | --- 9 | 10 | **Describe the bug** 11 | A clear and concise description of what the bug is. 
12 | 13 | **To Reproduce** 14 | ``` 15 | Minimal code reproducing the issue 16 | ``` 17 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/code-review.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Code review 3 | about: Review existing code regarding possible improvements 4 | title: '' 5 | labels: enhancement 6 | assignees: iamtakingiteasy 7 | 8 | --- 9 | 10 | 11 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Feature request 3 | about: Suggest an idea for this project 4 | title: '' 5 | labels: enhancement 6 | assignees: iamtakingiteasy 7 | 8 | --- 9 | 10 | **Is your feature request related to a problem? Please describe.** 11 | A clear and concise description of what the problem is. Ex. I'm always frustrated when [...] 12 | 13 | **Describe the solution you'd like** 14 | A clear and concise description of what you want to happen. 15 | 16 | **Describe alternatives you've considered** 17 | A clear and concise description of any alternative solutions or features you've considered. 
18 | -------------------------------------------------------------------------------- /.github/workflows/go-main.yml: -------------------------------------------------------------------------------- 1 | name: go-main 2 | on: 3 | push: 4 | jobs: 5 | test: 6 | name: Test 7 | runs-on: ubuntu-latest 8 | steps: 9 | - uses: actions/checkout@v3 10 | - uses: actions/setup-go@v3 11 | with: 12 | go-version: "1.20.4" 13 | - name: golangci-lint 14 | uses: golangci/golangci-lint-action@v3 15 | with: 16 | version: "v1.52" 17 | - name: Test & publish code coverage 18 | uses: paambaati/codeclimate-action@v3.0.0 19 | env: 20 | CC_TEST_REPORTER_ID: ${{secrets.CC_TEST_REPORTER_ID}} 21 | with: 22 | coverageCommand: go test -race -coverprofile c.out -covermode=atomic -v -bench=. ./... 23 | prefix: github.com/eientei/wsgraphql 24 | coverageLocations: ${{github.workspace}}/c.out:gocov 25 | -------------------------------------------------------------------------------- /.github/workflows/go-pr.yml: -------------------------------------------------------------------------------- 1 | name: golangci-lint 2 | on: 3 | pull_request: 4 | permissions: 5 | contents: read 6 | jobs: 7 | golangci: 8 | name: lint 9 | runs-on: ubuntu-latest 10 | steps: 11 | - uses: actions/checkout@v3 12 | - uses: actions/setup-go@v3 13 | with: 14 | go-version: 1.17 15 | - name: golangci-lint 16 | uses: golangci/golangci-lint-action@v3 17 | with: 18 | version: v1.48 19 | - name: test 20 | run: go test -v ./... 
21 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .idea 2 | *.iml 3 | vendor 4 | -------------------------------------------------------------------------------- /.golangci.yml: -------------------------------------------------------------------------------- 1 | run: 2 | concurrency: 4 3 | timeout: 1m 4 | issues-exit-code: 1 5 | tests: true 6 | output: 7 | format: colored-line-number 8 | print-issued-lines: true 9 | print-linter-name: true 10 | uniq-by-line: true 11 | 12 | linters-settings: 13 | errcheck: 14 | check-type-assertions: true 15 | check-blank: false 16 | gocognit: 17 | min-complexity: 25 18 | goconst: 19 | min-len: 3 20 | min-occurrences: 3 21 | gocyclo: 22 | min-complexity: 15 23 | gofmt: 24 | simplify: true 25 | revive: 26 | min-confidence: 0.8 27 | govet: 28 | check-shadowing: true 29 | enable-all: true 30 | lll: 31 | line-length: 120 32 | tab-width: 1 33 | wsl: 34 | strict-append: true 35 | allow-assign-and-call: true 36 | allow-multiline-assign: true 37 | allow-cuddle-declarations: false 38 | allow-trailing-comment: false 39 | force-case-trailing-whitespace: 0 40 | 41 | linters: 42 | disable-all: true 43 | enable: 44 | - govet 45 | - errcheck 46 | - unused 47 | - gosimple 48 | - ineffassign 49 | - typecheck 50 | - bodyclose 51 | - stylecheck 52 | - revive 53 | - unconvert 54 | - goconst 55 | - gocyclo 56 | - gocognit 57 | - gofmt 58 | - goimports 59 | - godox 60 | - lll 61 | - unparam 62 | - gocritic 63 | - wsl 64 | - goprintffuncname 65 | - whitespace 66 | 67 | issues: 68 | exclude-use-default: false 69 | max-issues-per-linter: 0 70 | max-same-issues: 0 71 | new: false 72 | -------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | v1.5.1 2 | ------ 3 | - [otelwsgraphql] Allow specifying tracer 
provider 4 | 5 | v1.5.0 6 | ------ 7 | - Added `WithInterceptors` option to allow context-based third-party instrumentation 8 | - `WithCallbacks` is deprecated in favor of `WithInterceptors` and will be removed by 2024. 9 | - Added OpenTelemetry instrumentation in compat package. 10 | - Bumped graphql-go dependency to v0.8.1 with critical security fixes 11 | 12 | v1.4.2 13 | ------ 14 | - Support setting headers in websocket upgrade response 15 | 16 | v1.4.1 17 | ------ 18 | - Added a work-around for ExtendedError extensions not rendering in subscriptions 19 | 20 | v1.4.0 21 | ------ 22 | - Added support for per-request protocol selection for websocket subscriptions using websocket 23 | subprotocol negotiation. (#2) 24 | - Added Stringer implementation to `apollows.Protocol`, to avoid the explicit type casts. 25 | In 1.5.0 underlying Protocol type will be replaced with an integer, migration to .String() 26 | is advised. 27 | 28 | v1.3.4 29 | ------ 30 | - Fixed serialization of empty data/payloads (#1) 31 | 32 | v1.3.0 33 | ------ 34 | - Breaking change: root object moved to functional options parametrization 35 | - Added support for graphql-ws (graphql-transport-ws subprotocol) 36 | - Ensured only pre-execution operation errors are returned as `error` type per apollows spec 37 | - Fixed incorrect OnConnect/OnOperation callback sequence 38 | 39 | v1.2.3 40 | ------ 41 | - Added OnDisconnect handler without responsibility to handle error, callback sequence diagram 42 | 43 | v1.2.2 44 | ------ 45 | - Correct termination request handling 46 | 47 | v1.2.1 48 | ------ 49 | - Fixes, clarifications for websocket request teardown sequence 50 | 51 | - Added CHANGELOG.md 52 | 53 | - Added READMEs to examples 54 | 55 | - Updated LICENSE year 56 | 57 | v1.0.0-v1.2.0 58 | ------ 59 | Major refactor, cleaned up implementation 60 | Complete test coverage, versioned package scheme 61 | 62 | v0.0.1-v0.5.0 63 | --- 64 | Initial implementation 65 |
-------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Eientei Team 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [](https://godoc.org/github.com/eientei/wsgraphql/v1) 2 | [](https://goreportcard.com/report/github.com/eientei/wsgraphql) 3 | [](https://codeclimate.com/github/eientei/wsgraphql) 4 | [](https://codeclimate.com/github/eientei/wsgraphql) 5 | 6 | An implementation of websocket transport for 7 | [graphql-go](https://github.com/graphql-go/graphql). 
8 | 9 | Currently, the following flavors are supported: 10 | 11 | - `graphql-ws` subprotocol, older spec: https://github.com/apollographql/subscriptions-transport-ws/blob/master/PROTOCOL.md 12 | - `graphql-transport-ws` subprotocol, newer spec: https://github.com/enisdenjo/graphql-ws/blob/master/PROTOCOL.md 13 | 14 | Inspired by [graphqlws](https://github.com/functionalfoundry/graphqlws) 15 | 16 | Key features: 17 | 18 | - Subscription support 19 | - Interceptors at every stage of the communication process for easy customization 20 | - Supports both websockets and plain http queries, with http chunked response for plain http subscriptions 21 | - [Mutable context](https://godoc.org/github.com/eientei/wsgraphql/v1/mutable) for keeping request-scoped 22 | connection/authentication data and operation-scoped state 23 | 24 | Usage 25 | ----- 26 | 27 | Assuming the [gorilla websocket](https://github.com/gorilla/websocket) upgrader: 28 | 29 | ```go 30 | import ( 31 | "net/http" 32 | 33 | "github.com/eientei/wsgraphql/v1" 34 | "github.com/eientei/wsgraphql/v1/compat/gorillaws" 35 | "github.com/gorilla/websocket" 36 | "github.com/graphql-go/graphql" 37 | ) 38 | ``` 39 | 40 | ```go 41 | schema, err := graphql.NewSchema(...)
42 | if err != nil { 43 | panic(err) 44 | } 45 | 46 | srv, err := wsgraphql.NewServer( 47 | schema, 48 | wsgraphql.WithUpgrader(gorillaws.Wrap(&websocket.Upgrader{ 49 | Subprotocols: []string{ 50 | wsgraphql.WebsocketSubprotocolGraphqlWS.String(), 51 | wsgraphql.WebsocketSubprotocolGraphqlTransportWS.String(), 52 | }, 53 | })), 54 | ) 55 | if err != nil { 56 | panic(err) 57 | } 58 | 59 | http.Handle("/query", srv) 60 | 61 | err = http.ListenAndServe(":8080", nil) 62 | if err != nil { 63 | panic(err) 64 | } 65 | ``` 66 | 67 | Examples 68 | -------- 69 | 70 | See [/v1/examples](/v1/examples) 71 | - [minimal-graphql-ws](/v1/examples/minimal-graphql-ws) `graphql-ws` / older subscriptions-transport-ws server setup 72 | - [minimal-graphql-transport-ws](/v1/examples/minimal-graphql-transport-ws) `graphql-transport-ws` / newer graphql-ws server setup 73 | - [simpleserver](/v1/examples/simpleserver) complete example with subscriptions, mutations and queries 74 | -------------------------------------------------------------------------------- /_config.yml: -------------------------------------------------------------------------------- 1 | theme: jekyll-theme-slate -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/eientei/wsgraphql 2 | 3 | go 1.17 4 | 5 | require ( 6 | github.com/gorilla/websocket v1.5.0 7 | github.com/graphql-go/graphql v0.8.1 8 | github.com/stretchr/testify v1.8.4 9 | go.opentelemetry.io/otel v1.16.0 10 | go.opentelemetry.io/otel/trace v1.16.0 11 | ) 12 | 13 | require ( 14 | github.com/davecgh/go-spew v1.1.1 // indirect 15 | github.com/go-logr/logr v1.2.4 // indirect 16 | github.com/go-logr/stdr v1.2.2 // indirect 17 | github.com/pmezard/go-difflib v1.0.0 // indirect 18 | go.opentelemetry.io/otel/metric v1.16.0 // indirect 19 | gopkg.in/yaml.v3 v3.0.1 // indirect 20 | ) 21 | 
-------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 2 | github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= 3 | github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 4 | github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= 5 | github.com/go-logr/logr v1.2.4 h1:g01GSCwiDw2xSZfjJ2/T9M+S6pFdcNtFYsp+Y43HYDQ= 6 | github.com/go-logr/logr v1.2.4/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= 7 | github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag= 8 | github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE= 9 | github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= 10 | github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= 11 | github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc= 12 | github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= 13 | github.com/graphql-go/graphql v0.8.1 h1:p7/Ou/WpmulocJeEx7wjQy611rtXGQaAcXGqanuMMgc= 14 | github.com/graphql-go/graphql v0.8.1/go.mod h1:nKiHzRM0qopJEwCITUuIsxk9PlVlwIiiI8pnJEhordQ= 15 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= 16 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 17 | github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= 18 | github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= 19 | github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= 20 | github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= 21 | 
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= 22 | github.com/stretchr/testify v1.8.3/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= 23 | github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= 24 | github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= 25 | go.opentelemetry.io/otel v1.16.0 h1:Z7GVAX/UkAXPKsy94IU+i6thsQS4nb7LviLpnaNeW8s= 26 | go.opentelemetry.io/otel v1.16.0/go.mod h1:vl0h9NUa1D5s1nv3A5vZOYWn8av4K8Ml6JDeHrT/bx4= 27 | go.opentelemetry.io/otel/metric v1.16.0 h1:RbrpwVG1Hfv85LgnZ7+txXioPDoh6EdbZHo26Q3hqOo= 28 | go.opentelemetry.io/otel/metric v1.16.0/go.mod h1:QE47cpOmkwipPiefDwo2wDzwJrlfxxNYodqc4xnGCo4= 29 | go.opentelemetry.io/otel/trace v1.16.0 h1:8JRpaObFoW0pxuVPapkgH8UhHQj+bJW8jJsCZEu5MQs= 30 | go.opentelemetry.io/otel/trace v1.16.0/go.mod h1:Yt9vYq1SdNz3xdjZZK7wcXv1qv2pwLkqr2QVwea0ef0= 31 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= 32 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= 33 | gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 34 | gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= 35 | gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 36 | -------------------------------------------------------------------------------- /v1/api.go: -------------------------------------------------------------------------------- 1 | // Package wsgraphql provides interfaces for server and client 2 | package wsgraphql 3 | 4 | import ( 5 | "github.com/eientei/wsgraphql/v1/apollows" 6 | ) 7 | 8 | // WebsocketSubprotocolGraphqlWS websocket subprotocol expected by subscriptions-transport-ws implementations 9 | const WebsocketSubprotocolGraphqlWS = apollows.WebsocketSubprotocolGraphqlWS 10 | 11 | // 
WebsocketSubprotocolGraphqlTransportWS websocket subprotocol expected by graphql-ws implementations 12 | const WebsocketSubprotocolGraphqlTransportWS = apollows.WebsocketSubprotocolGraphqlTransportWS 13 | -------------------------------------------------------------------------------- /v1/api_server.go: -------------------------------------------------------------------------------- 1 | package wsgraphql 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | "net/http" 7 | "reflect" 8 | "strconv" 9 | "time" 10 | "unsafe" 11 | 12 | "github.com/eientei/wsgraphql/v1/apollows" 13 | "github.com/graphql-go/graphql" 14 | "github.com/graphql-go/graphql/gqlerrors" 15 | ) 16 | 17 | // Server implements graphql http handler with websocket support (if upgrader is provided with WithUpgrader) 18 | type Server interface { 19 | http.Handler 20 | } 21 | 22 | // NewServer returns a new Server for schema; options are applied in order and the first option error is returned 23 | func NewServer( 24 | schema graphql.Schema, 25 | options ...ServerOption, 26 | ) (Server, error) { 27 | var c serverConfig 28 | 29 | c.subscriptionProtocols = make(map[apollows.Protocol]struct{}) 30 | 31 | for _, o := range options { 32 | err := o(&c) 33 | if err != nil { 34 | return nil, err 35 | } 36 | } 37 | 38 | if len(c.subscriptionProtocols) == 0 { 39 | c.subscriptionProtocols[apollows.WebsocketSubprotocolGraphqlWS] = struct{}{} 40 | c.subscriptionProtocols[apollows.WebsocketSubprotocolGraphqlTransportWS] = struct{}{} 41 | } 42 | 43 | initInterceptors(&c) 44 | 45 | if c.resultProcessor == nil { 46 | c.resultProcessor = identityResultProcessor 47 | } 48 | 49 | f := reflect.ValueOf(&schema).Elem().FieldByName("extensions") 50 | 51 | exts, ok := reflect.NewAt(f.Type(), unsafe.Pointer(f.UnsafeAddr())).Elem().Interface().([]graphql.Extension) 52 | if !ok { 53 | return nil, errReflectExtensions 54 | } 55 | 56 | return &serverImpl{ 57 | schema: schema, 58 | extensions: exts, 59 | serverConfig: c, 60 | }, nil 61 | } 62 | 63 | // ServerOption configures a Server; NewServer aborts construction if an option returns an error 64 | type ServerOption
func(config *serverConfig) error 65 | 66 | // WithUpgrader option sets the Upgrader (an interface modeled after the gorilla websocket upgrader) 67 | func WithUpgrader(upgrader Upgrader) ServerOption { 68 | return func(config *serverConfig) error { 69 | config.upgrader = upgrader 70 | 71 | return nil 72 | } 73 | } 74 | 75 | // WithInterceptors option sets interceptors around various stages of requests, replacing any set before 76 | func WithInterceptors(interceptors Interceptors) ServerOption { 77 | return func(config *serverConfig) error { 78 | config.interceptors = interceptors 79 | 80 | return nil 81 | } 82 | } 83 | 84 | // WithExtraInterceptors option chains the non-nil interceptors after existing ones instead of replacing them 85 | func WithExtraInterceptors(interceptors Interceptors) ServerOption { 86 | return func(config *serverConfig) error { 87 | if interceptors.HTTPRequest != nil { 88 | config.interceptors.HTTPRequest = InterceptorHTTPRequestChain( 89 | config.interceptors.HTTPRequest, 90 | interceptors.HTTPRequest, 91 | ) 92 | } 93 | 94 | if interceptors.Init != nil { 95 | config.interceptors.Init = InterceptorInitChain( 96 | config.interceptors.Init, 97 | interceptors.Init, 98 | ) 99 | } 100 | 101 | if interceptors.Operation != nil { 102 | config.interceptors.Operation = InterceptorOperationChain( 103 | config.interceptors.Operation, 104 | interceptors.Operation, 105 | ) 106 | } 107 | 108 | if interceptors.OperationParse != nil { 109 | config.interceptors.OperationParse = InterceptorOperationParseChain( 110 | config.interceptors.OperationParse, 111 | interceptors.OperationParse, 112 | ) 113 | } 114 | 115 | if interceptors.OperationExecute != nil { 116 | config.interceptors.OperationExecute = InterceptorOperationExecuteChain( 117 | config.interceptors.OperationExecute, 118 | interceptors.OperationExecute, 119 | ) 120 | } 121 | 122 | return nil 123 | } 124 | } 125 | 126 | // WithKeepalive option enables sending keepalive messages at the provided interval 127 | func WithKeepalive(interval time.Duration) ServerOption { 128 | return
func(config *serverConfig) error { 129 | config.keepalive = interval 130 | 131 | return nil 132 | } 133 | } 134 | 135 | // WithoutHTTPQueries option prevents HTTP queries from being handled, allowing only websocket queries 136 | func WithoutHTTPQueries() ServerOption { 137 | return func(config *serverConfig) error { 138 | config.rejectHTTPQueries = true 139 | 140 | return nil 141 | } 142 | } 143 | 144 | // WithProtocol option adds a protocol for this server to use. May be specified multiple times. 145 | func WithProtocol(protocol apollows.Protocol) ServerOption { 146 | return func(config *serverConfig) error { 147 | config.subscriptionProtocols[protocol] = struct{}{} 148 | 149 | return nil 150 | } 151 | } 152 | 153 | // WithConnectTimeout option sets the duration within which a client is allowed to initialize the connection before 154 | // being disconnected 155 | func WithConnectTimeout(timeout time.Duration) ServerOption { 156 | return func(config *serverConfig) error { 157 | config.connectTimeout = timeout 158 | 159 | return nil 160 | } 161 | } 162 | 163 | // WithRootObject provides the root object that will be used in root resolvers 164 | func WithRootObject(rootObject map[string]interface{}) ServerOption { 165 | return func(config *serverConfig) error { 166 | config.rootObject = rootObject 167 | 168 | return nil 169 | } 170 | } 171 | 172 | // ResultProcessor allows post-processing of resolved values 173 | type ResultProcessor func( 174 | ctx context.Context, 175 | payload *apollows.PayloadOperation, 176 | result *graphql.Result, 177 | ) *graphql.Result 178 | 179 | // WithResultProcessor provides ResultProcessor to post-process resolved values 180 | func WithResultProcessor(proc ResultProcessor) ServerOption { 181 | return func(config *serverConfig) error { 182 | config.resultProcessor = proc 183 | 184 | return nil 185 | } 186 | } 187 | 188 | // WriteError writes err to w as a graphql result with HTTP 400 status; no-op if err is nil or response started 189 | func WriteError(ctx context.Context, w http.ResponseWriter,
err error) { 190 | if err == nil || ContextHTTPResponseStarted(ctx) { 191 | return 192 | } 193 | 194 | var res ResultError 195 | 196 | if !errors.As(err, &res) { 197 | res = ResultError{ 198 | Result: &graphql.Result{ 199 | Errors: []gqlerrors.FormattedError{ 200 | gqlerrors.FormatError(err), 201 | }, 202 | }, 203 | } 204 | } 205 | 206 | bs := []byte(res.Error()) // render res, not err: a wrapping error's message must not prefix the body 207 | 208 | w.Header().Set("content-length", strconv.Itoa(len(bs))) 209 | w.WriteHeader(http.StatusBadRequest) 210 | 211 | _, _ = w.Write(bs) 212 | } 213 | 214 | func identityResultProcessor( 215 | _ context.Context, 216 | _ *apollows.PayloadOperation, 217 | result *graphql.Result, 218 | ) *graphql.Result { 219 | return result 220 | } 221 | -------------------------------------------------------------------------------- /v1/api_server_test.go: -------------------------------------------------------------------------------- 1 | package wsgraphql 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | "io" 7 | "net/http" 8 | "net/http/httptest" 9 | "testing" 10 | "time" 11 | 12 | "github.com/eientei/wsgraphql/v1/apollows" 13 | "github.com/eientei/wsgraphql/v1/mutable" 14 | "github.com/graphql-go/graphql" 15 | "github.com/graphql-go/graphql/gqlerrors" 16 | "github.com/stretchr/testify/assert" 17 | ) 18 | 19 | func TestWithCallbacksOnRequest(t *testing.T) { 20 | opctx := mutable.NewMutableContext(context.Background()) 21 | 22 | var c serverConfig 23 | 24 | var encountered error 25 | 26 | target := errors.New("123") 27 | 28 | assert.NoError(t, WithCallbacks(Callbacks{ 29 | OnRequest: func(reqctx mutable.Context, r *http.Request, w http.ResponseWriter) error { 30 | return target 31 | }, 32 | OnRequestDone: func(reqctx mutable.Context, r *http.Request, w http.ResponseWriter, origerr error) { 33 | encountered = origerr 34 | }, 35 | OnConnect: nil, 36 | OnDisconnect: nil, 37 | OnOperation: nil, 38 | OnOperationValidation: nil, 39 | OnOperationResult: nil, 40 | OnOperationDone: nil, 41 | })(&c)) 42 | 43 | w :=
httptest.NewRecorder() 44 | 45 | r, _ := http.NewRequest(http.MethodGet, "", nil) 46 | 47 | _ = c.interceptors.HTTPRequest( 48 | opctx, 49 | w, 50 | r, 51 | func(reqctx context.Context, w http.ResponseWriter, r *http.Request) error { 52 | return nil 53 | }, 54 | ) 55 | 56 | assert.Equal(t, target, encountered) 57 | } 58 | 59 | func TestWithCallbacksOnRequestHandler(t *testing.T) { 60 | opctx := mutable.NewMutableContext(context.Background()) 61 | 62 | var c serverConfig 63 | 64 | var encountered error 65 | 66 | target := errors.New("123") 67 | 68 | assert.NoError(t, WithCallbacks(Callbacks{ 69 | OnRequest: nil, 70 | OnRequestDone: func(reqctx mutable.Context, r *http.Request, w http.ResponseWriter, origerr error) { 71 | encountered = origerr 72 | }, 73 | OnConnect: nil, 74 | OnDisconnect: nil, 75 | OnOperation: nil, 76 | OnOperationValidation: nil, 77 | OnOperationResult: nil, 78 | OnOperationDone: nil, 79 | })(&c)) 80 | 81 | w := httptest.NewRecorder() 82 | 83 | r, _ := http.NewRequest(http.MethodGet, "", nil) 84 | 85 | _ = c.interceptors.HTTPRequest( 86 | opctx, 87 | w, 88 | r, 89 | func(reqctx context.Context, w http.ResponseWriter, r *http.Request) error { 90 | return target 91 | }, 92 | ) 93 | 94 | assert.Equal(t, target, encountered) 95 | } 96 | 97 | func TestWithCallbacksOnConnect(t *testing.T) { 98 | opctx := mutable.NewMutableContext(context.Background()) 99 | 100 | var c serverConfig 101 | 102 | var encountered error 103 | 104 | target := errors.New("123") 105 | 106 | assert.NoError(t, WithCallbacks(Callbacks{ 107 | OnRequest: nil, 108 | OnRequestDone: nil, 109 | OnConnect: func(reqctx mutable.Context, init apollows.PayloadInit) error { 110 | return target 111 | }, 112 | OnDisconnect: func(reqctx mutable.Context, origerr error) error { 113 | encountered = origerr 114 | 115 | return origerr 116 | }, 117 | OnOperation: nil, 118 | OnOperationValidation: nil, 119 | OnOperationResult: nil, 120 | OnOperationDone: nil, 121 | })(&c)) 122 | 123 | assert.Equal(t, 
target, c.interceptors.Init( 124 | opctx, 125 | nil, 126 | func(ctx context.Context, init apollows.PayloadInit) error { 127 | return nil 128 | }, 129 | )) 130 | 131 | assert.Equal(t, target, encountered) 132 | } 133 | 134 | func TestWithCallbacksOnConnectHandler(t *testing.T) { 135 | opctx := mutable.NewMutableContext(context.Background()) 136 | 137 | var c serverConfig 138 | 139 | var encountered error 140 | 141 | target := errors.New("123") 142 | 143 | assert.NoError(t, WithCallbacks(Callbacks{ 144 | OnRequest: nil, 145 | OnRequestDone: nil, 146 | OnConnect: nil, 147 | OnDisconnect: func(reqctx mutable.Context, origerr error) error { 148 | encountered = origerr 149 | 150 | return origerr 151 | }, 152 | OnOperation: nil, 153 | OnOperationValidation: nil, 154 | OnOperationResult: nil, 155 | OnOperationDone: nil, 156 | })(&c)) 157 | 158 | assert.Equal(t, target, c.interceptors.Init( 159 | opctx, 160 | nil, 161 | func(ctx context.Context, init apollows.PayloadInit) error { 162 | return target 163 | }, 164 | )) 165 | 166 | assert.Equal(t, target, encountered) 167 | } 168 | 169 | func TestWithCallbacksOnOperation(t *testing.T) { 170 | opctx := mutable.NewMutableContext(context.Background()) 171 | 172 | var c serverConfig 173 | 174 | var encountered error 175 | 176 | target := errors.New("123") 177 | 178 | assert.NoError(t, WithCallbacks(Callbacks{ 179 | OnRequest: nil, 180 | OnRequestDone: nil, 181 | OnConnect: nil, 182 | OnDisconnect: nil, 183 | OnOperation: func(opctx mutable.Context, payload *apollows.PayloadOperation) error { 184 | return target 185 | }, 186 | OnOperationValidation: nil, 187 | OnOperationResult: nil, 188 | OnOperationDone: func(opctx mutable.Context, payload *apollows.PayloadOperation, origerr error) error { 189 | encountered = origerr 190 | 191 | return origerr 192 | }, 193 | })(&c)) 194 | 195 | assert.Equal(t, target, c.interceptors.Operation( 196 | opctx, 197 | nil, 198 | func(ctx context.Context, payload *apollows.PayloadOperation) error { 199 
| return nil 200 | }, 201 | )) 202 | 203 | assert.Equal(t, target, encountered) 204 | } 205 | 206 | func TestWithCallbacksOnOperationHandler(t *testing.T) { 207 | opctx := mutable.NewMutableContext(context.Background()) 208 | 209 | var c serverConfig 210 | 211 | var encountered error 212 | 213 | target := errors.New("123") 214 | 215 | assert.NoError(t, WithCallbacks(Callbacks{ 216 | OnRequest: nil, 217 | OnRequestDone: nil, 218 | OnConnect: nil, 219 | OnDisconnect: nil, 220 | OnOperation: nil, 221 | OnOperationValidation: nil, 222 | OnOperationResult: nil, 223 | OnOperationDone: func(opctx mutable.Context, payload *apollows.PayloadOperation, origerr error) error { 224 | encountered = origerr 225 | 226 | return origerr 227 | }, 228 | })(&c)) 229 | 230 | assert.Equal(t, target, c.interceptors.Operation( 231 | opctx, 232 | nil, 233 | func(ctx context.Context, payload *apollows.PayloadOperation) error { 234 | return target 235 | }, 236 | )) 237 | 238 | assert.Equal(t, target, encountered) 239 | } 240 | 241 | func TestWithCallbacksOnOperationValidation(t *testing.T) { 242 | opctx := mutable.NewMutableContext(context.Background()) 243 | 244 | var c serverConfig 245 | 246 | target := errors.New("123") 247 | 248 | assert.NoError(t, WithCallbacks(Callbacks{ 249 | OnRequest: nil, 250 | OnRequestDone: nil, 251 | OnConnect: nil, 252 | OnDisconnect: nil, 253 | OnOperation: nil, 254 | OnOperationValidation: func( 255 | opctx mutable.Context, 256 | payload *apollows.PayloadOperation, 257 | result *graphql.Result, 258 | ) error { 259 | return target 260 | }, 261 | OnOperationResult: nil, 262 | OnOperationDone: nil, 263 | })(&c)) 264 | 265 | assert.Equal(t, target, c.interceptors.OperationParse( 266 | opctx, 267 | nil, 268 | func(ctx context.Context, payload *apollows.PayloadOperation) error { 269 | return nil 270 | }, 271 | )) 272 | } 273 | 274 | func TestWithCallbacksOnOperationExecution(t *testing.T) { 275 | opctx := mutable.NewMutableContext(context.Background()) 276 | 277 | 
var c serverConfig 278 | 279 | target := errors.New("123") 280 | 281 | assert.NoError(t, WithCallbacks(Callbacks{ 282 | OnRequest: nil, 283 | OnRequestDone: nil, 284 | OnConnect: nil, 285 | OnDisconnect: nil, 286 | OnOperation: nil, 287 | OnOperationValidation: nil, 288 | OnOperationResult: func( 289 | opctx mutable.Context, 290 | payload *apollows.PayloadOperation, 291 | result *graphql.Result, 292 | ) error { 293 | return target 294 | }, 295 | OnOperationDone: nil, 296 | })(&c)) 297 | 298 | ch, err := c.interceptors.OperationExecute( 299 | opctx, 300 | nil, 301 | func(ctx context.Context, payload *apollows.PayloadOperation) (chan *graphql.Result, error) { 302 | tch := make(chan *graphql.Result, 1) 303 | 304 | tch <- &graphql.Result{} 305 | close(tch) 306 | 307 | return tch, nil 308 | }, 309 | ) 310 | 311 | assert.Nil(t, err) 312 | 313 | res := <-ch 314 | 315 | assert.NotNil(t, target, res) 316 | assert.Equal(t, target, res.Data) 317 | } 318 | 319 | func TestWithCallbacksOnOperationProcess(t *testing.T) { 320 | opctx := mutable.NewMutableContext(context.Background()) 321 | 322 | var c serverConfig 323 | 324 | target := errors.New("123") 325 | 326 | assert.NoError(t, WithCallbacks(Callbacks{ 327 | OnRequest: nil, 328 | OnRequestDone: nil, 329 | OnConnect: nil, 330 | OnDisconnect: nil, 331 | OnOperation: nil, 332 | OnOperationValidation: nil, 333 | OnOperationResult: func( 334 | opctx mutable.Context, 335 | payload *apollows.PayloadOperation, 336 | result *graphql.Result, 337 | ) error { 338 | result.Data = target 339 | 340 | return nil 341 | }, 342 | OnOperationDone: nil, 343 | })(&c)) 344 | 345 | ch, err := c.interceptors.OperationExecute( 346 | opctx, 347 | nil, 348 | func(ctx context.Context, payload *apollows.PayloadOperation) (chan *graphql.Result, error) { 349 | tch := make(chan *graphql.Result, 1) 350 | 351 | tch <- &graphql.Result{} 352 | close(tch) 353 | 354 | return tch, nil 355 | }, 356 | ) 357 | 358 | assert.Nil(t, err) 359 | 360 | res := <-ch 361 | 362 
| assert.NotNil(t, target, res) 363 | assert.Equal(t, target, res.Data) 364 | } 365 | 366 | func TestWithCallbacksOnOperationExecutionHandler(t *testing.T) { 367 | opctx := mutable.NewMutableContext(context.Background()) 368 | 369 | var c serverConfig 370 | 371 | target := errors.New("123") 372 | 373 | assert.NoError(t, WithCallbacks(Callbacks{ 374 | OnRequest: nil, 375 | OnRequestDone: nil, 376 | OnConnect: nil, 377 | OnDisconnect: nil, 378 | OnOperation: nil, 379 | OnOperationValidation: nil, 380 | OnOperationResult: nil, 381 | OnOperationDone: nil, 382 | })(&c)) 383 | 384 | ch, err := c.interceptors.OperationExecute( 385 | opctx, 386 | nil, 387 | func(ctx context.Context, payload *apollows.PayloadOperation) (chan *graphql.Result, error) { 388 | return nil, target 389 | }, 390 | ) 391 | 392 | assert.Nil(t, ch) 393 | assert.Equal(t, target, err) 394 | } 395 | 396 | func TestWithInterceptorChain(t *testing.T) { 397 | opctx := context.Background() 398 | 399 | var c serverConfig 400 | 401 | type keyT struct{} 402 | 403 | key := keyT{} 404 | 405 | assert.NoError(t, WithInterceptors(Interceptors{ 406 | HTTPRequest: InterceptorHTTPRequestChain( 407 | func(ctx context.Context, w http.ResponseWriter, r *http.Request, handler HandlerHTTPRequest) error { 408 | return handler(context.WithValue(ctx, key, 1), w, r) 409 | }, 410 | func(ctx context.Context, w http.ResponseWriter, r *http.Request, handler HandlerHTTPRequest) error { 411 | return handler(context.WithValue(ctx, key, ctx.Value(key).(int)+1), w, r) 412 | }, 413 | ), 414 | Init: InterceptorInitChain( 415 | func(ctx context.Context, init apollows.PayloadInit, handler HandlerInit) error { 416 | return handler(context.WithValue(ctx, key, 2), init) 417 | }, 418 | func(ctx context.Context, init apollows.PayloadInit, handler HandlerInit) error { 419 | return handler(context.WithValue(ctx, key, ctx.Value(key).(int)+1), init) 420 | }, 421 | ), 422 | Operation: InterceptorOperationChain( 423 | func(ctx context.Context, payload 
*apollows.PayloadOperation, handler HandlerOperation) error { 424 | return handler(context.WithValue(ctx, key, 3), payload) 425 | }, 426 | func(ctx context.Context, payload *apollows.PayloadOperation, handler HandlerOperation) error { 427 | return handler(context.WithValue(ctx, key, ctx.Value(key).(int)+1), payload) 428 | }, 429 | ), 430 | OperationParse: InterceptorOperationParseChain( 431 | func(ctx context.Context, payload *apollows.PayloadOperation, handler HandlerOperationParse) error { 432 | return handler(context.WithValue(ctx, key, 4), payload) 433 | }, 434 | func(ctx context.Context, payload *apollows.PayloadOperation, handler HandlerOperationParse) error { 435 | return handler(context.WithValue(ctx, key, ctx.Value(key).(int)+1), payload) 436 | }, 437 | ), 438 | OperationExecute: InterceptorOperationExecuteChain( 439 | func( 440 | ctx context.Context, 441 | payload *apollows.PayloadOperation, 442 | handler HandlerOperationExecute, 443 | ) (chan *graphql.Result, error) { 444 | return handler(context.WithValue(ctx, key, 5), payload) 445 | }, 446 | func( 447 | ctx context.Context, 448 | payload *apollows.PayloadOperation, 449 | handler HandlerOperationExecute, 450 | ) (chan *graphql.Result, error) { 451 | return handler(context.WithValue(ctx, key, ctx.Value(key).(int)+1), payload) 452 | }, 453 | ), 454 | })(&c)) 455 | 456 | r, _ := http.NewRequest(http.MethodGet, "", nil) 457 | 458 | var res []int 459 | 460 | _ = c.interceptors.HTTPRequest( 461 | opctx, 462 | httptest.NewRecorder(), 463 | r, 464 | func(ctx context.Context, w http.ResponseWriter, r *http.Request) error { 465 | res = append(res, ctx.Value(key).(int)) 466 | 467 | return nil 468 | }, 469 | ) 470 | 471 | _ = c.interceptors.Init( 472 | opctx, 473 | nil, 474 | func(ctx context.Context, init apollows.PayloadInit) error { 475 | res = append(res, ctx.Value(key).(int)) 476 | 477 | return nil 478 | }, 479 | ) 480 | 481 | _ = c.interceptors.Operation( 482 | opctx, 483 | nil, 484 | func(ctx 
context.Context, payload *apollows.PayloadOperation) error { 485 | res = append(res, ctx.Value(key).(int)) 486 | 487 | return nil 488 | }, 489 | ) 490 | 491 | _ = c.interceptors.OperationParse( 492 | opctx, 493 | nil, 494 | func(ctx context.Context, payload *apollows.PayloadOperation) error { 495 | res = append(res, ctx.Value(key).(int)) 496 | 497 | return nil 498 | }, 499 | ) 500 | 501 | _, _ = c.interceptors.OperationExecute( 502 | opctx, 503 | nil, 504 | func(ctx context.Context, payload *apollows.PayloadOperation) (chan *graphql.Result, error) { 505 | res = append(res, ctx.Value(key).(int)) 506 | 507 | return nil, nil 508 | }, 509 | ) 510 | 511 | assert.Equal(t, []int{2, 3, 4, 5, 6}, res) 512 | } 513 | 514 | func TestWithKeepalive(t *testing.T) { 515 | var c serverConfig 516 | 517 | assert.NoError(t, WithKeepalive(123)(&c)) 518 | 519 | assert.Equal(t, time.Duration(123), c.keepalive) 520 | } 521 | 522 | func TestWithoutHTTPQueries(t *testing.T) { 523 | var c serverConfig 524 | 525 | assert.NoError(t, WithoutHTTPQueries()(&c)) 526 | 527 | assert.Equal(t, true, c.rejectHTTPQueries) 528 | } 529 | 530 | func TestWithRootObject(t *testing.T) { 531 | var c serverConfig 532 | 533 | obj := make(map[string]interface{}) 534 | 535 | assert.NoError(t, WithRootObject(obj)(&c)) 536 | 537 | assert.Equal(t, obj, c.rootObject) 538 | } 539 | 540 | func TestWriteError(t *testing.T) { 541 | mutctx := mutable.NewMutableContext(context.Background()) 542 | 543 | rec := httptest.NewRecorder() 544 | 545 | WriteError(mutctx, rec, errors.New("123")) 546 | 547 | resp := rec.Result() 548 | 549 | bs, err := io.ReadAll(resp.Body) 550 | 551 | assert.NoError(t, err) 552 | assert.Equal(t, ResultError{ 553 | Result: &graphql.Result{ 554 | Errors: []gqlerrors.FormattedError{ 555 | gqlerrors.FormatError(errors.New("123")), 556 | }, 557 | }, 558 | }.Error(), string(bs)) 559 | 560 | assert.NoError(t, resp.Body.Close()) 561 | } 562 | 563 | func TestWriteErrorResponseStarted(t *testing.T) { 564 | 
mutctx := mutable.NewMutableContext(context.Background()) 565 | 566 | mutctx.Set(ContextKeyHTTPResponseStarted, true) 567 | 568 | rec := httptest.NewRecorder() 569 | 570 | WriteError(mutctx, rec, errors.New("123")) 571 | 572 | resp := rec.Result() 573 | 574 | bs, err := io.ReadAll(resp.Body) 575 | 576 | assert.NoError(t, err) 577 | assert.Equal(t, "", string(bs)) 578 | 579 | assert.NoError(t, resp.Body.Close()) 580 | } 581 | 582 | func TestOptError(t *testing.T) { 583 | srv, err := NewServer(testNewSchema(t), func(config *serverConfig) error { 584 | return errors.New("123") 585 | }) 586 | 587 | assert.Error(t, err) 588 | assert.Nil(t, srv) 589 | } 590 | -------------------------------------------------------------------------------- /v1/apollows/proto.go: -------------------------------------------------------------------------------- 1 | // Package apollows provides implementation of GraphQL over WebSocket Protocol as defined by 2 | // https://github.com/apollographql/subscriptions-transport-ws/blob/master/PROTOCOL.md [GWS] 3 | // https://github.com/enisdenjo/graphql-ws/blob/master/PROTOCOL.md [GTWS] 4 | package apollows 5 | 6 | import ( 7 | "encoding/json" 8 | "errors" 9 | "io" 10 | ) 11 | 12 | // Protocol websocket subprotocol defining server behavior 13 | type Protocol string 14 | 15 | // String conversion 16 | func (p Protocol) String() string { 17 | return string(p) 18 | } 19 | 20 | const ( 21 | // WebsocketSubprotocolGraphqlWS websocket subprotocol expected by subscriptions-transport-ws implementations 22 | WebsocketSubprotocolGraphqlWS Protocol = "graphql-ws" 23 | 24 | // WebsocketSubprotocolGraphqlTransportWS websocket subprotocol exepected by graphql-ws implementations 25 | WebsocketSubprotocolGraphqlTransportWS Protocol = "graphql-transport-ws" 26 | ) 27 | 28 | // ErrUnknownProtocol indicates that unknown subscription protocol was requested 29 | var ErrUnknownProtocol = errors.New("unknown subscription protocol") 30 | 31 | // Operation type is used to 
enumerate possible apollo message types 32 | type Operation string 33 | 34 | const ( 35 | // OperationConnectionInit [GWS,GTWS] 36 | // is set by the connecting client to initialize the websocket state with connection params (if any) 37 | OperationConnectionInit Operation = "connection_init" 38 | 39 | // OperationStart [GWS] 40 | // client request initiates new operation, each operation may have 0-N OperationData responses before being 41 | // terminated by either OperationComplete or OperationError 42 | OperationStart Operation = "start" 43 | 44 | // OperationSubscribe [GTWS] 45 | // client request initiates new operation, each operation may have 0-N OperationNext responses before being 46 | // terminated by either OperationComplete or OperationError 47 | OperationSubscribe Operation = "subscribe" 48 | 49 | // OperationTerminate [GWS] 50 | // client request to gracefully close the connection, equivalent to closing the websocket 51 | OperationTerminate Operation = "connection_terminate" 52 | 53 | // OperationConnectionError [GWS] 54 | // server response to unsuccessful OperationConnectionInit attempt 55 | OperationConnectionError Operation = "connection_error" 56 | 57 | // OperationConnectionAck [GWS,GTWS] 58 | // server response to successful OperationConnectionInit attempt 59 | OperationConnectionAck Operation = "connection_ack" 60 | 61 | // OperationData [GWS] 62 | // server response to previously initiated operation with OperationStart may be multiple within same operation, 63 | // specifically for subscriptions 64 | OperationData Operation = "data" 65 | 66 | // OperationNext [GTWS] 67 | // server response to previously initiated operation with OperationSubscribe may be multiple within same operation, 68 | // specifically for subscriptions 69 | OperationNext Operation = "next" 70 | 71 | // OperationError [GWS,GTWS] 72 | // server response to previously initiated operation with OperationStart/OperationSubscribe 73 | OperationError Operation = "error" 74 | 75 | 
// OperationStop [GWS] 76 | // client request to stop previously initiated operation with OperationStart 77 | OperationStop Operation = "stop" 78 | 79 | // OperationComplete [GWS,GTWS] 80 | // GWS: server response indicating previously initiated operation is complete 81 | // GTWS: server response indicating previously initiated operation is complete 82 | // GTWS: client request to stop previously initiated operation with OperationSubscribe 83 | OperationComplete Operation = "complete" 84 | 85 | // OperationKeepAlive [GWS] 86 | // server response sent periodically to maintain websocket connection open 87 | OperationKeepAlive Operation = "ka" 88 | 89 | // OperationPing [GTWS] 90 | // sever/client request for OperationPong response 91 | OperationPing Operation = "ping" 92 | 93 | // OperationPong [GTWS] 94 | // sever/client response for OperationPing request 95 | // can be sent at any time (without prior OperationPing) to maintain wesocket connection 96 | OperationPong Operation = "pong" 97 | ) 98 | 99 | // Error providing MessageType to close websocket with 100 | type Error interface { 101 | error 102 | EventMessageType() MessageType 103 | } 104 | 105 | type errorImpl struct { 106 | error 107 | message string 108 | messageType MessageType 109 | } 110 | 111 | func (e errorImpl) EventMessageType() MessageType { 112 | return e.messageType 113 | } 114 | 115 | func (e errorImpl) Unwrap() error { 116 | return e.error 117 | } 118 | 119 | func (e errorImpl) Error() string { 120 | return e.message 121 | } 122 | 123 | // WrapError wraps provided error into Error 124 | func WrapError(err error, messageType MessageType) Error { 125 | message := err.Error() 126 | 127 | if messageType.Error() != "" { 128 | message = messageType.Error() + ": " + message 129 | } 130 | 131 | return errorImpl{ 132 | error: err, 133 | message: message, 134 | messageType: messageType, 135 | } 136 | } 137 | 138 | // NewSubscriberAlreadyExistsError constructs new Error using subscriber id as part of the 
message 139 | func NewSubscriberAlreadyExistsError(id string) Error { 140 | return errorImpl{ 141 | error: nil, 142 | message: "Subscriber for " + id + " already exists", 143 | messageType: EventSubscriberAlreadyExists, 144 | } 145 | } 146 | 147 | // MessageType websocket message types / status codes used to indicate protocol-level events following closing the 148 | // websocket 149 | type MessageType int 150 | 151 | const ( 152 | // EventCloseNormal standard websocket message type 153 | EventCloseNormal MessageType = 1000 154 | 155 | // EventCloseError standard websocket message type 156 | EventCloseError MessageType = 1006 157 | 158 | // EventInvalidMessage indicates invalid protocol message 159 | EventInvalidMessage MessageType = 4400 160 | 161 | // EventUnauthorized indicated attempt to subscribe to an operation before receiving OperationConnectionAck 162 | EventUnauthorized MessageType = 4401 163 | 164 | // EventInitializationTimeout indicates timeout occurring before client sending OperationConnectionInit 165 | EventInitializationTimeout MessageType = 4408 166 | 167 | // EventTooManyInitializationRequests indicates receiving more than one OperationConnectionInit 168 | EventTooManyInitializationRequests MessageType = 4429 169 | 170 | // EventSubscriberAlreadyExists indicates subscribed operation ID already being in use 171 | // (not yet terminated by either OperationComplete or OperationError) 172 | EventSubscriberAlreadyExists MessageType = 4409 173 | ) 174 | 175 | var messageTypeDescriptions = map[MessageType]string{ 176 | EventCloseNormal: "Termination requested", 177 | EventInvalidMessage: "Invalid message", 178 | EventUnauthorized: "Unauthorized", 179 | EventInitializationTimeout: "Connection initialisation timeout", 180 | EventTooManyInitializationRequests: "Too many initialisation requests", 181 | } 182 | 183 | // EventMessageType implementation 184 | func (m MessageType) EventMessageType() MessageType { 185 | return m 186 | } 187 | 188 | // Error 
implementation 189 | func (m MessageType) Error() string { 190 | return messageTypeDescriptions[m] 191 | } 192 | 193 | // Data encapsulates both client and server json payload, combining json.RawMessage for decoding and 194 | // arbitrary interface{} type for encoding 195 | type Data struct { 196 | Value interface{} 197 | json.RawMessage 198 | } 199 | 200 | // Ptr returns non-nil pointer to Data if it is not empty 201 | func (payload *Data) Ptr() *Data { 202 | if payload == nil || payload.Value == nil { 203 | return nil 204 | } 205 | 206 | return payload 207 | } 208 | 209 | // ReadPayloadData client-side method to parse server response 210 | func (payload *Data) ReadPayloadData() (*PayloadDataResponse, error) { 211 | if payload == nil { 212 | return nil, io.ErrUnexpectedEOF 213 | } 214 | 215 | var pd PayloadDataResponse 216 | 217 | err := json.Unmarshal(payload.RawMessage, &pd) 218 | if err != nil { 219 | return nil, err 220 | } 221 | 222 | payload.Value = pd 223 | 224 | return &pd, nil 225 | } 226 | 227 | // ReadPayloadError client-side method to parse server error response 228 | func (payload *Data) ReadPayloadError() (*PayloadError, error) { 229 | if payload == nil { 230 | return nil, io.ErrUnexpectedEOF 231 | } 232 | 233 | var pd PayloadError 234 | 235 | err := json.Unmarshal(payload.RawMessage, &pd) 236 | if err != nil { 237 | return nil, err 238 | } 239 | 240 | payload.Value = pd 241 | 242 | return &pd, nil 243 | } 244 | 245 | // ReadPayloadErrors client-side method to parse server error response 246 | func (payload *Data) ReadPayloadErrors() (pds []*PayloadError, err error) { 247 | if payload == nil { 248 | return nil, io.ErrUnexpectedEOF 249 | } 250 | 251 | err = json.Unmarshal(payload.RawMessage, &pds) 252 | if err != nil { 253 | return nil, err 254 | } 255 | 256 | payload.Value = pds 257 | 258 | return pds, nil 259 | } 260 | 261 | // UnmarshalJSON stores provided json as a RawMessage, as well as initializing Value to same RawMessage 262 | // to support 
both identity re-serialization and modification of Value after initialization 263 | func (payload *Data) UnmarshalJSON(bs []byte) (err error) { 264 | payload.RawMessage = bs 265 | payload.Value = payload.RawMessage 266 | 267 | return nil 268 | } 269 | 270 | // MarshalJSON marshals either provided or deserialized Value as json 271 | func (payload Data) MarshalJSON() (bs []byte, err error) { 272 | return json.Marshal(payload.Value) 273 | } 274 | 275 | // PayloadInit provides connection params 276 | type PayloadInit map[string]interface{} 277 | 278 | // PayloadOperation provides description for client-side operation initiation 279 | type PayloadOperation struct { 280 | Variables map[string]interface{} `json:"variables"` 281 | Extensions map[string]interface{} `json:"extensions"` 282 | Query string `json:"query"` 283 | OperationName string `json:"operationName"` 284 | } 285 | 286 | // PayloadDataRaw provides server-side response for previously started operation 287 | type PayloadDataRaw struct { 288 | Data Data `json:"data"` 289 | 290 | // see https://github.com/graph-gophers/graphql-go#custom-errors 291 | // for adding custom error attributes 292 | Errors []error `json:"errors,omitempty"` 293 | } 294 | 295 | // PayloadData type-alias for serialization 296 | type PayloadData PayloadDataRaw 297 | 298 | // MarshalJSON serializes PayloadData to JSON, excluding empty data 299 | func (payload PayloadData) MarshalJSON() (bs []byte, err error) { 300 | return json.Marshal(struct { 301 | Data *Data `json:"data,omitempty"` 302 | PayloadDataRaw 303 | }{ 304 | PayloadDataRaw: PayloadDataRaw(payload), 305 | Data: payload.Data.Ptr(), 306 | }) 307 | } 308 | 309 | // PayloadErrorLocation error location in originating request 310 | type PayloadErrorLocation struct { 311 | Line int `json:"line"` 312 | Column int `json:"column"` 313 | } 314 | 315 | // PayloadError client-side error representation 316 | type PayloadError struct { 317 | Extensions map[string]interface{} `json:"extensions"` 
318 | Message string `json:"message"` 319 | Locations []PayloadErrorLocation `json:"locations"` 320 | Path []string `json:"path"` 321 | } 322 | 323 | // PayloadDataResponse provides client-side payload representation 324 | type PayloadDataResponse struct { 325 | Data map[string]interface{} `json:"data,omitempty"` 326 | Errors []PayloadError `json:"errors,omitempty"` 327 | } 328 | 329 | // MessageRaw encapsulates every message within apollows protocol in both directions 330 | type MessageRaw struct { 331 | ID string `json:"id,omitempty"` 332 | Type Operation `json:"type"` 333 | Payload Data `json:"payload"` 334 | } 335 | 336 | // Message type-alias for (de-)serialization 337 | type Message MessageRaw 338 | 339 | // MarshalJSON serializes Message to JSON, excluding empty id or payload from serialized fields. 340 | func (message Message) MarshalJSON() (bs []byte, err error) { 341 | return json.Marshal(struct { 342 | Payload *Data `json:"payload,omitempty"` 343 | MessageRaw 344 | }{ 345 | MessageRaw: MessageRaw(message), 346 | Payload: message.Payload.Ptr(), 347 | }) 348 | } 349 | -------------------------------------------------------------------------------- /v1/apollows/proto_test.go: -------------------------------------------------------------------------------- 1 | package apollows 2 | 3 | import ( 4 | "encoding/json" 5 | "errors" 6 | "io" 7 | "testing" 8 | 9 | "github.com/stretchr/testify/assert" 10 | ) 11 | 12 | func TestDataMarshal(t *testing.T) { 13 | data := Data{ 14 | Value: 123, 15 | } 16 | 17 | bs, err := json.Marshal(data) 18 | 19 | assert.NoError(t, err) 20 | assert.Equal(t, "123", string(bs)) 21 | } 22 | 23 | func TestDataUnmarshal(t *testing.T) { 24 | var data Data 25 | 26 | err := json.Unmarshal([]byte(`123`), &data) 27 | 28 | assert.NoError(t, err) 29 | assert.Equal(t, "123", string(data.RawMessage)) 30 | } 31 | 32 | func TestDataUnmarshalCycle(t *testing.T) { 33 | var data Data 34 | 35 | err := json.Unmarshal([]byte(`123`), &data) 36 | 37 | 
assert.NoError(t, err) 38 | assert.Equal(t, "123", string(data.RawMessage)) 39 | 40 | bs, err := json.Marshal(data) 41 | 42 | assert.NoError(t, err) 43 | assert.Equal(t, "123", string(bs)) 44 | } 45 | 46 | func TestDataReadPayloadData(t *testing.T) { 47 | data := Data{ 48 | Value: PayloadData{ 49 | Data: Data{ 50 | Value: map[string]interface{}{ 51 | "foo": "123", 52 | }, 53 | }, 54 | }, 55 | } 56 | 57 | bs, err := json.Marshal(data) 58 | 59 | assert.NoError(t, err) 60 | 61 | var ndata Data 62 | 63 | err = json.Unmarshal(bs, &ndata) 64 | 65 | assert.NoError(t, err) 66 | 67 | pd, err := ndata.ReadPayloadData() 68 | 69 | assert.NoError(t, err) 70 | 71 | assert.Equal(t, "123", pd.Data["foo"]) 72 | } 73 | 74 | func TestDataReadPayloadDataError(t *testing.T) { 75 | var ndata *Data 76 | 77 | pd, err := ndata.ReadPayloadData() 78 | 79 | assert.Error(t, err) 80 | assert.Nil(t, pd) 81 | 82 | ndata = &Data{ 83 | Value: nil, 84 | RawMessage: json.RawMessage(`foo`), 85 | } 86 | 87 | pd, err = ndata.ReadPayloadData() 88 | 89 | assert.Error(t, err) 90 | assert.Nil(t, pd) 91 | } 92 | 93 | func TestDataReadPayloadError(t *testing.T) { 94 | data := Data{ 95 | Value: PayloadError{ 96 | Message: "123", 97 | }, 98 | } 99 | 100 | bs, err := json.Marshal(data) 101 | 102 | assert.NoError(t, err) 103 | 104 | var ndata Data 105 | 106 | err = json.Unmarshal(bs, &ndata) 107 | 108 | assert.NoError(t, err) 109 | 110 | pd, err := ndata.ReadPayloadError() 111 | 112 | assert.NoError(t, err) 113 | assert.Equal(t, "123", pd.Message) 114 | } 115 | 116 | func TestDataReadPayloadErrorError(t *testing.T) { 117 | var ndata *Data 118 | 119 | pd, err := ndata.ReadPayloadError() 120 | 121 | assert.Error(t, err) 122 | assert.Nil(t, pd) 123 | 124 | ndata = &Data{ 125 | Value: nil, 126 | RawMessage: json.RawMessage(`foo`), 127 | } 128 | 129 | pd, err = ndata.ReadPayloadError() 130 | 131 | assert.Error(t, err) 132 | assert.Nil(t, pd) 133 | } 134 | 135 | func TestDataReadPayloadErrors(t *testing.T) { 136 | data 
:= Data{ 137 | Value: []PayloadError{ 138 | { 139 | Message: "123", 140 | }, 141 | }, 142 | } 143 | 144 | bs, err := json.Marshal(data) 145 | 146 | assert.NoError(t, err) 147 | 148 | var ndata Data 149 | 150 | err = json.Unmarshal(bs, &ndata) 151 | 152 | assert.NoError(t, err) 153 | 154 | pd, err := ndata.ReadPayloadErrors() 155 | 156 | assert.NoError(t, err) 157 | assert.Len(t, pd, 1) 158 | assert.Equal(t, "123", pd[0].Message) 159 | } 160 | 161 | func TestDataReadPayloadErrorsError(t *testing.T) { 162 | var ndata *Data 163 | 164 | pd, err := ndata.ReadPayloadErrors() 165 | 166 | assert.Error(t, err) 167 | assert.Nil(t, pd) 168 | 169 | ndata = &Data{ 170 | Value: nil, 171 | RawMessage: json.RawMessage(`foo`), 172 | } 173 | 174 | pd, err = ndata.ReadPayloadErrors() 175 | 176 | assert.Error(t, err) 177 | assert.Nil(t, pd) 178 | } 179 | 180 | func TestWrapError(t *testing.T) { 181 | err := WrapError(io.EOF, EventUnauthorized) 182 | 183 | assert.True(t, errors.Is(err, io.EOF)) 184 | } 185 | 186 | func TestMessageMarshal(t *testing.T) { 187 | data := Message{ 188 | ID: "123", 189 | Type: OperationConnectionAck, 190 | Payload: Data{ 191 | Value: "foo", 192 | }, 193 | } 194 | 195 | bs, err := json.Marshal(data) 196 | 197 | assert.NoError(t, err) 198 | assert.JSONEq(t, `{"id":"123","type":"connection_ack","payload":"foo"}`, string(bs)) 199 | } 200 | 201 | func TestMessageMarshalEmpty(t *testing.T) { 202 | data := Message{ 203 | ID: "123", 204 | Type: OperationConnectionAck, 205 | Payload: Data{}, 206 | } 207 | 208 | bs, err := json.Marshal(data) 209 | 210 | assert.NoError(t, err) 211 | assert.JSONEq(t, `{"id":"123","type":"connection_ack"}`, string(bs)) 212 | } 213 | 214 | func TestPayloadDataMarshal(t *testing.T) { 215 | data := PayloadData{ 216 | Data: Data{ 217 | Value: "foo", 218 | }, 219 | } 220 | 221 | bs, err := json.Marshal(data) 222 | 223 | assert.NoError(t, err) 224 | assert.JSONEq(t, `{"data":"foo"}`, string(bs)) 225 | } 226 | 227 | func 
TestPayloadDataMarshalEmpty(t *testing.T) { 228 | data := PayloadData{ 229 | Data: Data{}, 230 | } 231 | 232 | bs, err := json.Marshal(data) 233 | 234 | assert.NoError(t, err) 235 | assert.JSONEq(t, `{}`, string(bs)) 236 | } 237 | -------------------------------------------------------------------------------- /v1/ast.go: -------------------------------------------------------------------------------- 1 | package wsgraphql 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | 7 | "github.com/eientei/wsgraphql/v1/apollows" 8 | "github.com/graphql-go/graphql" 9 | "github.com/graphql-go/graphql/gqlerrors" 10 | "github.com/graphql-go/graphql/language/ast" 11 | "github.com/graphql-go/graphql/language/parser" 12 | "github.com/graphql-go/graphql/language/source" 13 | ) 14 | 15 | func (server *serverImpl) handleExtensionsInits(p *graphql.Params) *graphql.Result { 16 | var errs gqlerrors.FormattedErrors 17 | 18 | for _, ext := range server.extensions { 19 | func() { 20 | defer func() { 21 | if r := recover(); r != nil { 22 | errs = append(errs, gqlerrors.FormatError(fmt.Errorf("%s.Init: %v", ext.Name(), r))) 23 | } 24 | }() 25 | 26 | p.Context = ext.Init(p.Context, p) 27 | }() 28 | } 29 | 30 | if len(errs) == 0 { 31 | return nil 32 | } 33 | 34 | return &graphql.Result{ 35 | Errors: errs, 36 | } 37 | } 38 | 39 | func (server *serverImpl) handleExtensionsParseDidStart( 40 | p *graphql.Params, 41 | ) (res *graphql.Result, endfn func(err error) *graphql.Result) { 42 | fs := make(map[string]graphql.ParseFinishFunc) 43 | 44 | var errs gqlerrors.FormattedErrors 45 | 46 | for _, ext := range server.extensions { 47 | var ( 48 | ctx context.Context 49 | finishFn graphql.ParseFinishFunc 50 | ) 51 | 52 | func() { 53 | defer func() { 54 | if r := recover(); r != nil { 55 | errs = append( 56 | errs, 57 | gqlerrors.FormatError(fmt.Errorf("%s.ParseDidStart: %v", ext.Name(), r)), 58 | ) 59 | } 60 | }() 61 | 62 | ctx, finishFn = ext.ParseDidStart(p.Context) 63 | 64 | p.Context = ctx 65 | 66 | 
fs[ext.Name()] = finishFn 67 | }() 68 | } 69 | 70 | endfn = func(err error) *graphql.Result { 71 | var inerrs gqlerrors.FormattedErrors 72 | 73 | if err != nil { 74 | inerrs = append(inerrs, gqlerrors.FormatError(err)) 75 | } 76 | 77 | for name, fn := range fs { 78 | func() { 79 | defer func() { 80 | if r := recover(); r != nil { 81 | inerrs = append( 82 | inerrs, 83 | gqlerrors.FormatError(fmt.Errorf("%s.ParseFinishFunc: %v", name, r)), 84 | ) 85 | } 86 | }() 87 | 88 | fn(err) 89 | }() 90 | } 91 | 92 | if len(inerrs) == 0 { 93 | return nil 94 | } 95 | 96 | return &graphql.Result{ 97 | Errors: inerrs, 98 | } 99 | } 100 | 101 | if len(errs) > 0 { 102 | res = &graphql.Result{ 103 | Errors: errs, 104 | } 105 | } 106 | 107 | return 108 | } 109 | 110 | func (server *serverImpl) handleExtensionsValidationDidStart( 111 | p *graphql.Params, 112 | ) (errs []gqlerrors.FormattedError, endfn func(errs []gqlerrors.FormattedError) []gqlerrors.FormattedError) { 113 | fs := make(map[string]graphql.ValidationFinishFunc) 114 | 115 | for _, ext := range server.extensions { 116 | var ( 117 | ctx context.Context 118 | finishFn graphql.ValidationFinishFunc 119 | ) 120 | 121 | func() { 122 | defer func() { 123 | if r := recover(); r != nil { 124 | errs = append( 125 | errs, 126 | gqlerrors.FormatError(fmt.Errorf("%s.ValidationDidStart: %v", ext.Name(), r)), 127 | ) 128 | } 129 | }() 130 | 131 | ctx, finishFn = ext.ValidationDidStart(p.Context) 132 | 133 | p.Context = ctx 134 | fs[ext.Name()] = finishFn 135 | }() 136 | } 137 | 138 | endfn = func(errs []gqlerrors.FormattedError) (inerrs []gqlerrors.FormattedError) { 139 | inerrs = append(inerrs, errs...) 
140 | 141 | for name, finishFn := range fs { 142 | func() { 143 | defer func() { 144 | if r := recover(); r != nil { 145 | inerrs = append( 146 | inerrs, 147 | gqlerrors.FormatError(fmt.Errorf("%s.ValidationFinishFunc: %v", name, r)), 148 | ) 149 | } 150 | }() 151 | 152 | finishFn(errs) 153 | }() 154 | } 155 | 156 | return 157 | } 158 | 159 | return 160 | } 161 | 162 | // parseAST parses and validates payload.Query against the server schema, invoking the schema extensions' 163 | // init, parse and validation hooks, and stores the operation params, AST and subscription flag in the 164 | // operation context. A failure result produced by the hooks or by validation is returned as ResultError. 165 | func (server *serverImpl) parseAST( 166 | ctx context.Context, 167 | payload *apollows.PayloadOperation, 168 | ) (err error) { 169 | opctx := OperationContext(ctx) 170 | 171 | var result *graphql.Result 172 | 173 | defer func() { 174 | if result != nil { 175 | err = ResultError{Result: result} 176 | } 177 | }() 178 | 179 | src := source.NewSource(&source.Source{ 180 | Body: []byte(payload.Query), 181 | Name: "GraphQL request", 182 | }) 183 | 184 | params := graphql.Params{ 185 | Schema: server.schema, 186 | RequestString: payload.Query, 187 | RootObject: server.rootObject, 188 | VariableValues: payload.Variables, 189 | OperationName: payload.OperationName, 190 | Context: ctx, 191 | } 192 | 193 | opctx.Set(ContextKeyOperationParams, &params) 194 | 195 | result = server.handleExtensionsInits(&params) 196 | if result != nil { 197 | return 198 | } 199 | 200 | var parseFinishFn func(err error) *graphql.Result 201 | 202 | result, parseFinishFn = server.handleExtensionsParseDidStart(&params) 203 | if result != nil { 204 | return 205 | } 206 | 207 | astdoc, err := parser.Parse(parser.ParseParams{Source: src}) 208 | 209 | opctx.Set(ContextKeyAST, astdoc) 210 | 211 | result = parseFinishFn(err) 212 | if result != nil { 213 | return 214 | } 215 | 216 | if err != nil { 217 | // No extension turned the parse failure into a result: report the parse error itself instead of 218 | // validating a document that failed to parse. 219 | result = &graphql.Result{Errors: gqlerrors.FormatErrors(err)} 220 | 221 | return 222 | } 223 | 224 | errs, validationFinishFn := server.handleExtensionsValidationDidStart(&params) 225 | 226 | validationResult := graphql.ValidateDocument(&params.Schema, astdoc, nil) 227 | 228 | errs = append(errs, validationFinishFn(validationResult.Errors)...)
229 | 230 | // NOTE(review): errs holds only what the extension hooks returned; confirm validationFinishFn passes 231 | // validationResult.Errors through, otherwise validation messages are missing from this result. 232 | if len(errs) > 0 || !validationResult.IsValid { 233 | result = &graphql.Result{ 234 | Errors: errs, 235 | } 236 | 237 | return 238 | } 239 | 240 | var subscription bool 241 | 242 | for _, definition := range astdoc.Definitions { 243 | op, ok := definition.(*ast.OperationDefinition) 244 | if !ok { 245 | continue 246 | } 247 | 248 | if op.Operation == ast.OperationTypeSubscription { 249 | subscription = true 250 | 251 | break 252 | } 253 | } 254 | 255 | opctx.Set(ContextKeySubscription, subscription) 256 | 257 | return 258 | } 259 | -------------------------------------------------------------------------------- /v1/ast_test.go: -------------------------------------------------------------------------------- 1 | package wsgraphql 2 | 3 | import ( 4 | "context" 5 | "testing" 6 | 7 | "github.com/eientei/wsgraphql/v1/apollows" 8 | "github.com/eientei/wsgraphql/v1/mutable" 9 | "github.com/graphql-go/graphql" 10 | "github.com/graphql-go/graphql/gqlerrors" 11 | "github.com/stretchr/testify/assert" 12 | ) 13 | 14 | func TestASTParse(t *testing.T) { 15 | schema, err := graphql.NewSchema(graphql.SchemaConfig{ 16 | Query: graphql.NewObject(graphql.ObjectConfig{ 17 | Name: "QueryRoot", 18 | Interfaces: nil, 19 | Fields: graphql.Fields{ 20 | "foo": &graphql.Field{ 21 | Name: "FooType", 22 | Type: graphql.Int, 23 | Resolve: func(p graphql.ResolveParams) (interface{}, error) { 24 | return 123, nil 25 | }, 26 | }, 27 | }, 28 | }), 29 | }) 30 | 31 | assert.NoError(t, err) 32 | assert.NotNil(t, schema) 33 | 34 | server, err := NewServer(schema) 35 | 36 | assert.NoError(t, err) 37 | assert.NotNil(t, server) 38 | 39 | impl, ok := server.(*serverImpl) 40 | 41 | assert.True(t, ok) 42 | 43 | opctx := mutable.NewMutableContext(context.Background()) 44 | opctx.Set(ContextKeyOperationContext, opctx) 45 | 46 | err = impl.parseAST(opctx, &apollows.PayloadOperation{ 47 | Query: `query { foo }`, 48 | Variables: nil, 49 | OperationName: "", 50 | }) 51 | 52 | assert.Nil(t, err) 53 |
assert.False(t, ContextSubscription(opctx)) 54 | assert.NotNil(t, ContextAST(opctx)) 55 | assert.NotNil(t, ContextOperationParams(opctx)) 56 | } 57 | 58 | func TestASTParseSubscription(t *testing.T) { 59 | schema, err := graphql.NewSchema(graphql.SchemaConfig{ 60 | Query: graphql.NewObject(graphql.ObjectConfig{ 61 | Name: "QueryRoot", 62 | Interfaces: nil, 63 | Fields: graphql.Fields{ 64 | "foo": &graphql.Field{ 65 | Name: "FooType", 66 | Type: graphql.Int, 67 | Resolve: func(p graphql.ResolveParams) (interface{}, error) { 68 | return 123, nil 69 | }, 70 | }, 71 | }, 72 | }), 73 | Subscription: graphql.NewObject(graphql.ObjectConfig{ 74 | Name: "SubscriptionRoot", 75 | Interfaces: nil, 76 | Fields: graphql.Fields{ 77 | "foo": &graphql.Field{ 78 | Name: "FooType", 79 | Type: graphql.Int, 80 | Subscribe: func(p graphql.ResolveParams) (interface{}, error) { 81 | return 123, nil 82 | }, 83 | }, 84 | }, 85 | }), 86 | }) 87 | 88 | assert.NoError(t, err) 89 | assert.NotNil(t, schema) 90 | 91 | server, err := NewServer(schema) 92 | 93 | assert.NoError(t, err) 94 | assert.NotNil(t, server) 95 | 96 | impl, ok := server.(*serverImpl) 97 | 98 | assert.True(t, ok) 99 | 100 | opctx := mutable.NewMutableContext(context.Background()) 101 | opctx.Set(ContextKeyOperationContext, opctx) 102 | 103 | err = impl.parseAST(opctx, &apollows.PayloadOperation{ 104 | Query: `subscription { foo }`, 105 | Variables: nil, 106 | OperationName: "", 107 | }) 108 | 109 | assert.Nil(t, err) 110 | assert.True(t, ContextSubscription(opctx)) 111 | assert.NotNil(t, ContextAST(opctx)) 112 | assert.NotNil(t, ContextOperationParams(opctx)) 113 | } 114 | 115 | type testExt struct { 116 | initFn func(ctx context.Context, p *graphql.Params) context.Context 117 | hasResultFn func() bool 118 | getResultFn func(context.Context) interface{} 119 | parseDidStartFn func(ctx context.Context) (context.Context, graphql.ParseFinishFunc) 120 | validationDidStartFn func(ctx context.Context) (context.Context,
graphql.ValidationFinishFunc) 121 | executionDidStartFn func(ctx context.Context) (context.Context, graphql.ExecutionFinishFunc) 122 | resolveFieldDidStartFn func( 123 | ctx context.Context, 124 | i *graphql.ResolveInfo, 125 | ) (context.Context, graphql.ResolveFieldFinishFunc) 126 | name string 127 | } 128 | 129 | func (t *testExt) Init(ctx context.Context, p *graphql.Params) context.Context { 130 | return t.initFn(ctx, p) 131 | } 132 | 133 | func (t *testExt) Name() string { 134 | return t.name 135 | } 136 | 137 | func (t *testExt) HasResult() bool { 138 | return t.hasResultFn() 139 | } 140 | 141 | func (t *testExt) GetResult(ctx context.Context) interface{} { 142 | return t.getResultFn(ctx) 143 | } 144 | 145 | func (t *testExt) ParseDidStart(ctx context.Context) (context.Context, graphql.ParseFinishFunc) { 146 | return t.parseDidStartFn(ctx) 147 | } 148 | 149 | func (t *testExt) ValidationDidStart(ctx context.Context) (context.Context, graphql.ValidationFinishFunc) { 150 | return t.validationDidStartFn(ctx) 151 | } 152 | 153 | func (t *testExt) ExecutionDidStart(ctx context.Context) (context.Context, graphql.ExecutionFinishFunc) { 154 | return t.executionDidStartFn(ctx) 155 | } 156 | 157 | func (t *testExt) ResolveFieldDidStart( 158 | ctx context.Context, 159 | i *graphql.ResolveInfo, 160 | ) (context.Context, graphql.ResolveFieldFinishFunc) { 161 | return t.resolveFieldDidStartFn(ctx, i) 162 | } 163 | 164 | func testAstParseExtensions( 165 | t *testing.T, 166 | opctx mutable.Context, 167 | f func(ext *testExt), 168 | ) (err error) { 169 | text := &testExt{ 170 | name: "foo", 171 | initFn: func(ctx context.Context, p *graphql.Params) context.Context { 172 | return ctx 173 | }, 174 | hasResultFn: func() bool { 175 | return true 176 | }, 177 | getResultFn: func(ctx context.Context) interface{} { 178 | return nil 179 | }, 180 | parseDidStartFn: func(ctx context.Context) (context.Context, graphql.ParseFinishFunc) { 181 | return ctx, func(err error) { 182 | 183 | }
184 | }, 185 | validationDidStartFn: func(ctx context.Context) (context.Context, graphql.ValidationFinishFunc) { 186 | return ctx, func(errors []gqlerrors.FormattedError) { 187 | 188 | } 189 | }, 190 | executionDidStartFn: func(ctx context.Context) (context.Context, graphql.ExecutionFinishFunc) { 191 | return ctx, func(result *graphql.Result) { 192 | 193 | } 194 | }, 195 | resolveFieldDidStartFn: func( 196 | ctx context.Context, 197 | i *graphql.ResolveInfo, 198 | ) (context.Context, graphql.ResolveFieldFinishFunc) { 199 | return ctx, func(i interface{}, err error) { 200 | 201 | } 202 | }, 203 | } 204 | 205 | f(text) 206 | 207 | schema, err := graphql.NewSchema(graphql.SchemaConfig{ 208 | Query: graphql.NewObject(graphql.ObjectConfig{ 209 | Name: "QueryRoot", 210 | Interfaces: nil, 211 | Fields: graphql.Fields{ 212 | "foo": &graphql.Field{ 213 | Name: "FooType", 214 | Type: graphql.Int, 215 | Resolve: func(p graphql.ResolveParams) (interface{}, error) { 216 | return 123, nil 217 | }, 218 | }, 219 | }, 220 | }), 221 | Extensions: []graphql.Extension{ 222 | text, 223 | }, 224 | }) 225 | 226 | assert.NoError(t, err) 227 | assert.NotNil(t, schema) 228 | 229 | server, err := NewServer(schema) 230 | 231 | assert.NoError(t, err) 232 | assert.NotNil(t, server) 233 | 234 | impl, ok := server.(*serverImpl) 235 | 236 | assert.True(t, ok) 237 | 238 | return impl.parseAST(opctx, &apollows.PayloadOperation{ 239 | Query: `query { foo }`, 240 | Variables: nil, 241 | OperationName: "", 242 | }) 243 | } 244 | 245 | func TestASTParseExtensions(t *testing.T) { 246 | opctx := mutable.NewMutableContext(context.Background()) 247 | opctx.Set(ContextKeyOperationContext, opctx) 248 | 249 | err := testAstParseExtensions(t, opctx, func(ext *testExt) { 250 | 251 | }) 252 | 253 | assert.Nil(t, err) 254 | assert.False(t, ContextSubscription(opctx)) 255 | assert.NotNil(t, ContextAST(opctx)) 256 | assert.NotNil(t, ContextOperationParams(opctx)) 257 | } 258 | 259 | func
TestASTParseExtensionsPanicInit(t *testing.T) { 260 | opctx := mutable.NewMutableContext(context.Background()) 261 | opctx.Set(ContextKeyOperationContext, opctx) 262 | 263 | err := testAstParseExtensions(t, opctx, func(ext *testExt) { 264 | ext.initFn = func(ctx context.Context, p *graphql.Params) context.Context { 265 | panic(1) 266 | } 267 | }) 268 | 269 | assert.NotNil(t, err) 270 | assert.False(t, ContextSubscription(opctx)) 271 | assert.Nil(t, ContextAST(opctx)) 272 | assert.NotNil(t, ContextOperationParams(opctx)) 273 | } 274 | 275 | func TestASTParseExtensionsPanicValidation(t *testing.T) { 276 | opctx := mutable.NewMutableContext(context.Background()) 277 | opctx.Set(ContextKeyOperationContext, opctx) 278 | 279 | err := testAstParseExtensions(t, opctx, func(ext *testExt) { 280 | ext.validationDidStartFn = func(ctx context.Context) (context.Context, graphql.ValidationFinishFunc) { 281 | panic(1) 282 | } 283 | }) 284 | 285 | assert.NotNil(t, err) 286 | assert.False(t, ContextSubscription(opctx)) 287 | assert.NotNil(t, ContextAST(opctx)) 288 | assert.NotNil(t, ContextOperationParams(opctx)) 289 | } 290 | 291 | func TestASTParseExtensionsPanicValidationCb(t *testing.T) { 292 | opctx := mutable.NewMutableContext(context.Background()) 293 | opctx.Set(ContextKeyOperationContext, opctx) 294 | 295 | err := testAstParseExtensions(t, opctx, func(ext *testExt) { 296 | ext.validationDidStartFn = func(ctx context.Context) (context.Context, graphql.ValidationFinishFunc) { 297 | return ctx, func(errors []gqlerrors.FormattedError) { 298 | panic(1) 299 | } 300 | } 301 | }) 302 | 303 | assert.NotNil(t, err) 304 | assert.False(t, ContextSubscription(opctx)) 305 | assert.NotNil(t, ContextAST(opctx)) 306 | assert.NotNil(t, ContextOperationParams(opctx)) 307 | } 308 | 309 | func TestASTParseExtensionsPanicParse(t *testing.T) { 310 | opctx := mutable.NewMutableContext(context.Background()) 311 | opctx.Set(ContextKeyOperationContext, opctx) 312 | 313 | err :=
testAstParseExtensions(t, opctx, func(ext *testExt) { 314 | ext.parseDidStartFn = func(ctx context.Context) (context.Context, graphql.ParseFinishFunc) { 315 | panic(1) 316 | } 317 | }) 318 | 319 | assert.NotNil(t, err) 320 | assert.False(t, ContextSubscription(opctx)) 321 | assert.Nil(t, ContextAST(opctx)) 322 | assert.NotNil(t, ContextOperationParams(opctx)) 323 | } 324 | 325 | func TestASTParseExtensionsPanicParseCb(t *testing.T) { 326 | opctx := mutable.NewMutableContext(context.Background()) 327 | opctx.Set(ContextKeyOperationContext, opctx) 328 | 329 | err := testAstParseExtensions(t, opctx, func(ext *testExt) { 330 | ext.parseDidStartFn = func(ctx context.Context) (context.Context, graphql.ParseFinishFunc) { 331 | return ctx, func(err error) { 332 | panic(1) 333 | } 334 | } 335 | }) 336 | 337 | assert.NotNil(t, err) 338 | assert.False(t, ContextSubscription(opctx)) 339 | assert.NotNil(t, ContextAST(opctx)) 340 | assert.NotNil(t, ContextOperationParams(opctx)) 341 | } 342 | -------------------------------------------------------------------------------- /v1/callbacks.go: -------------------------------------------------------------------------------- 1 | package wsgraphql 2 | 3 | import ( 4 | "context" 5 | "net/http" 6 | 7 | "github.com/eientei/wsgraphql/v1/apollows" 8 | "github.com/eientei/wsgraphql/v1/mutable" 9 | "github.com/graphql-go/graphql" 10 | ) 11 | 12 | // WithCallbacks option sets callbacks handling various stages of requests 13 | // Deprecated: use WithInterceptors / WithResultProcessor 14 | func WithCallbacks(callbacks Callbacks) ServerOption { 15 | return WithExtraInterceptors(Interceptors{ 16 | HTTPRequest: legacyCallbackHTTPRequest(callbacks), 17 | Init: legacyCallbackInit(callbacks), 18 | Operation: legacyCallbackOperation(callbacks), 19 | OperationParse: legacyCallbackOperationParse(callbacks), 20 | OperationExecute: legacyCallbackExecute(callbacks), 21 | }) 22 | } 23 | 24 | // Callbacks supported by the server 25 | // use
wsgraphql.ContextHTTPRequest / wsgraphql.ContextHTTPResponseWriter to access underlying 26 | // http.Request and http.ResponseWriter 27 | // Sequence: 28 | // OnRequest -> OnConnect -> 29 | // [ OnOperation -> OnOperationValidation -> OnOperationResult -> OnOperationDone ]* -> 30 | // OnDisconnect -> OnRequestDone 31 | // Deprecated: use Interceptors / ResultProcessor 32 | type Callbacks struct { 33 | // OnRequest called once HTTP request is received, before attempting to do websocket upgrade or plain request 34 | // execution, consequently before OnConnect as well. 35 | OnRequest func(reqctx mutable.Context, r *http.Request, w http.ResponseWriter) error 36 | 37 | // OnRequestDone called once HTTP request is finished, regardless of request type, with error occurred during 38 | // request execution (if any). 39 | // By default, if error is present, will write error text and return 400 code. 40 | OnRequestDone func(reqctx mutable.Context, r *http.Request, w http.ResponseWriter, origerr error) 41 | 42 | // OnConnect is called once per HTTP request, after websocket upgrade and init message received in case of 43 | // websocket request, or before execution in case of plain request 44 | OnConnect func(reqctx mutable.Context, init apollows.PayloadInit) error 45 | 46 | // OnDisconnect is called once per HTTP request, before OnRequestDone, without responsibility to handle errors 47 | OnDisconnect func(reqctx mutable.Context, origerr error) error 48 | 49 | // OnOperation is called before each operation with original payload, allowing to modify it or terminate 50 | // the operation by returning an error. 51 | OnOperation func(opctx mutable.Context, payload *apollows.PayloadOperation) error 52 | 53 | // OnOperationValidation is called after parsing an operation payload with any immediate validation result, if 54 | // available. AST will be available in context with ContextAST if parsing succeeded.
55 | OnOperationValidation func(opctx mutable.Context, payload *apollows.PayloadOperation, result *graphql.Result) error 56 | 57 | // OnOperationResult is called after operation result is received, allowing to postprocess it or terminate the 58 | // operation before returning the result with error. AST is available in context with ContextAST. 59 | OnOperationResult func(opctx mutable.Context, payload *apollows.PayloadOperation, result *graphql.Result) error 60 | 61 | // OnOperationDone is called once operation is finished, with error occurred during the execution (if any) 62 | // error returned from this handler will close the websocket / terminate HTTP request with error response. 63 | // By default, will pass through any error occurred. AST will be available in context with ContextAST if can be 64 | // parsed. 65 | OnOperationDone func(opctx mutable.Context, payload *apollows.PayloadOperation, origerr error) error 66 | } 67 | 68 | func legacyCallbackExecute(callbacks Callbacks) InterceptorOperationExecute { 69 | return func( 70 | ctx context.Context, 71 | payload *apollows.PayloadOperation, 72 | handler HandlerOperationExecute, 73 | ) (chan *graphql.Result, error) { 74 | cres, err := handler(ctx, payload) 75 | if err != nil { 76 | return nil, err 77 | } 78 | 79 | if callbacks.OnOperationResult == nil { 80 | return cres, nil 81 | } 82 | 83 | ch := make(chan *graphql.Result) 84 | 85 | go func() { 86 | defer close(ch) 87 | 88 | for res := range cres { 89 | cberr := callbacks.OnOperationResult(OperationContext(ctx), payload, res) 90 | if cberr != nil { 91 | ch <- &graphql.Result{ 92 | Data: cberr, 93 | } 94 | 95 | return 96 | } 97 | 98 | ch <- res 99 | } 100 | }() 101 | 102 | return ch, nil 103 | } 104 | } 105 | 106 | func legacyCallbackOperationParse(callbacks Callbacks) InterceptorOperationParse { 107 | return func( 108 | ctx context.Context, 109 | payload *apollows.PayloadOperation, 110 | handler HandlerOperationParse, 111 | ) error { 112 | err := handler(ctx, payload)
113 | 114 | if callbacks.OnOperationValidation != nil { 115 | result, _ := err.(ResultError) 116 | 117 | nerr := callbacks.OnOperationValidation(OperationContext(ctx), payload, result.Result) 118 | if nerr != nil { 119 | return nerr 120 | } 121 | } 122 | 123 | return err 124 | } 125 | } 126 | 127 | func legacyCallbackOperation(callbacks Callbacks) InterceptorOperation { 128 | return func( 129 | ctx context.Context, 130 | payload *apollows.PayloadOperation, 131 | handler HandlerOperation, 132 | ) (err error) { 133 | if callbacks.OnOperation != nil { 134 | err = callbacks.OnOperation(OperationContext(ctx), payload) 135 | } 136 | 137 | if err == nil { 138 | err = handler(ctx, payload) 139 | } 140 | 141 | if callbacks.OnOperationDone != nil { 142 | err = callbacks.OnOperationDone(OperationContext(ctx), payload, err) 143 | } 144 | 145 | return err 146 | } 147 | } 148 | 149 | func legacyCallbackInit(callbacks Callbacks) InterceptorInit { 150 | return func( 151 | ctx context.Context, 152 | init apollows.PayloadInit, 153 | handler HandlerInit, 154 | ) (err error) { 155 | if callbacks.OnConnect != nil { 156 | err = callbacks.OnConnect(RequestContext(ctx), init) 157 | } 158 | 159 | if err == nil { 160 | err = handler(ctx, init) 161 | } 162 | 163 | if callbacks.OnDisconnect != nil { 164 | err = callbacks.OnDisconnect(RequestContext(ctx), err) 165 | } 166 | 167 | return err 168 | } 169 | } 170 | 171 | // legacyCallbackHTTPRequest adapts Callbacks.OnRequest / OnRequestDone to an HTTPRequest interceptor. 172 | func legacyCallbackHTTPRequest(callbacks Callbacks) InterceptorHTTPRequest { 173 | return func( 174 | ctx context.Context, 175 | w http.ResponseWriter, 176 | r *http.Request, 177 | handler HandlerHTTPRequest, 178 | ) error { 179 | var err error 180 | 181 | if callbacks.OnRequest != nil { 182 | err = callbacks.OnRequest(RequestContext(ctx), r, w) 183 | } 184 | 185 | if err == nil { 186 | err = handler(ctx, w, r) 187 | } 188 | 189 | // OnRequestDone, when set, owns error reporting (see Callbacks.OnRequestDone); otherwise the error, 190 | // if any, is written here exactly once. 191 | if callbacks.OnRequestDone != nil {
192 | callbacks.OnRequestDone(RequestContext(ctx), r, w, err) 193 | } else if err != nil { 194 | WriteError(ctx, w, err) 195 | } 196 | 197 | return err 198 | } 199 | } 200 | -------------------------------------------------------------------------------- /v1/compat.go: -------------------------------------------------------------------------------- 1 | package wsgraphql 2 | 3 | import "net/http" 4 | 5 | // Upgrader interface used to upgrade HTTP request/response pair into a Conn 6 | type Upgrader interface { 7 | Upgrade(w http.ResponseWriter, r *http.Request, responseHeader http.Header) (Conn, error) 8 | } 9 | 10 | // Conn interface is used to abstract connection returned from Upgrader 11 | type Conn interface { 12 | ReadJSON(v interface{}) error 13 | WriteJSON(v interface{}) error 14 | Close(code int, message string) error 15 | Subprotocol() string 16 | } 17 | -------------------------------------------------------------------------------- /v1/compat/gorillaws/api.go: -------------------------------------------------------------------------------- 1 | // Package gorillaws provides compatibility for gorilla websocket upgrader 2 | package gorillaws 3 | 4 | import ( 5 | "net/http" 6 | 7 | "github.com/eientei/wsgraphql/v1" 8 | "github.com/gorilla/websocket" 9 | ) 10 | 11 | // Wrapper for gorilla websocket upgrader 12 | type Wrapper struct { 13 | *websocket.Upgrader 14 | } 15 | 16 | type conn struct { 17 | *websocket.Conn 18 | } 19 | 20 | func (conn conn) Close(code int, message string) error { 21 | origerr := conn.Conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(code, message)) 22 | 23 | err := conn.Conn.Close() 24 | if err == nil { 25 | err = origerr 26 | } 27 | 28 | return err 29 | } 30 | 31 | func (conn conn) Subprotocol() string { 32 | return conn.Conn.Subprotocol() 33 | } 34 | 35 | // Upgrade implementation 36 | func (g Wrapper) Upgrade(w http.ResponseWriter, r *http.Request, responseHeader http.Header) (wsgraphql.Conn, error) { 37 | c, err :=
g.Upgrader.Upgrade(w, r, responseHeader) 38 | if err != nil { 39 | return nil, err 40 | } 41 | 42 | return conn{ 43 | Conn: c, 44 | }, nil 45 | } 46 | 47 | // Wrap gorilla upgrader into wsgraphql-compatible interface 48 | func Wrap(upgrader *websocket.Upgrader) Wrapper { 49 | return Wrapper{ 50 | Upgrader: upgrader, 51 | } 52 | } 53 | -------------------------------------------------------------------------------- /v1/compat/otelwsgraphql/api.go: -------------------------------------------------------------------------------- 1 | // Package otelwsgraphql provides opentelemetry instrumentation for wsgraphql 2 | package otelwsgraphql 3 | 4 | import ( 5 | "context" 6 | "regexp" 7 | 8 | "github.com/eientei/wsgraphql/v1" 9 | "github.com/eientei/wsgraphql/v1/apollows" 10 | "go.opentelemetry.io/otel" 11 | "go.opentelemetry.io/otel/attribute" 12 | "go.opentelemetry.io/otel/codes" 13 | semconv "go.opentelemetry.io/otel/semconv/v1.20.0" 14 | "go.opentelemetry.io/otel/trace" 15 | ) 16 | 17 | const ( 18 | instrumentationName = "github.com/eientei/wsgraphql/v1/compat/otelwsgraphql" 19 | instrumentationVersion = "1.0.0" 20 | ) 21 | 22 | const ( 23 | operationQuery = "query" 24 | operationMutation = "mutation" 25 | operationSubscription = "subscription" 26 | ) 27 | 28 | // OperationOption provides customizations for operation interceptor 29 | type OperationOption interface { 30 | applyOperation(*operationConfig) 31 | } 32 | 33 | // NewOperationInterceptor returns new otel-span reporting wsgraphql operation interceptor 34 | func NewOperationInterceptor(options ...OperationOption) wsgraphql.InterceptorOperation { 35 | var c operationConfig 36 | 37 | defaultOptions := []OperationOption{ 38 | WithSpanNameResolver(DefaultSpanNameResolver), 39 | WithSpanAttributesResolver(DefaultSpanAttributesResolver), 40 | WithStartSpanOptions(trace.WithSpanKind(trace.SpanKindServer)), 41 | } 42 | 43 | for _, o := range append(defaultOptions, options...)
{ 44 | o.applyOperation(&c) 45 | } 46 | 47 | return func(ctx context.Context, payload *apollows.PayloadOperation, handler wsgraphql.HandlerOperation) error { 48 | tracer := c.tracer 49 | 50 | if tracer == nil { 51 | if c.tracerProvider != nil { 52 | tracer = newTracer(c.tracerProvider) 53 | } else if span := trace.SpanFromContext(ctx); span.SpanContext().IsValid() { 54 | tracer = newTracer(span.TracerProvider()) 55 | } else { 56 | tracer = newTracer(otel.GetTracerProvider()) 57 | } 58 | } 59 | 60 | opts := append( 61 | []trace.SpanStartOption{trace.WithAttributes(c.attributesResolver(ctx, payload)...)}, 62 | c.startSpanOptions..., 63 | ) 64 | 65 | ctx, span := tracer.Start(ctx, c.nameResolver(ctx, payload), opts...) 66 | // End the span even if the handler panics. 67 | defer span.End() 68 | 69 | err := handler(ctx, payload) 70 | 71 | if err != nil { 72 | span.RecordError(err) 73 | span.SetStatus(codes.Error, err.Error()) 74 | } else { 75 | span.SetStatus(codes.Ok, "") 76 | } 77 | 78 | return err 79 | } 80 | } 81 | 82 | // SpanNameResolver determines the span name from payload operation 83 | type SpanNameResolver func(ctx context.Context, payload *apollows.PayloadOperation) string 84 | 85 | // SpanAttributesResolver determines span attributes from payload operation 86 | type SpanAttributesResolver func(ctx context.Context, payload *apollows.PayloadOperation) []attribute.KeyValue 87 | 88 | // SpanASTAttributesResolver determines span attributes from payload operation after AST parsing 89 | type SpanASTAttributesResolver func(ctx context.Context, payload *apollows.PayloadOperation) []attribute.KeyValue 90 | 91 | // WithTracer provides predefined tracer instance 92 | func WithTracer(tracer trace.Tracer) OperationOption { 93 | return optionFunc(func(c *operationConfig) { 94 | c.tracer = tracer 95 | }) 96 | } 97 | 98 | // WithTracerProvider sets predefined tracer provider instance 99 | func WithTracerProvider(tracerProvider trace.TracerProvider) OperationOption { 100 | return optionFunc(func(c *operationConfig) { 101 |
c.tracerProvider = tracerProvider 102 | }) 103 | } 104 | 105 | // WithStartSpanOptions sets the span start options, replacing the default trace.WithSpanKind(trace.SpanKindServer) 106 | func WithStartSpanOptions(spanOptions ...trace.SpanStartOption) OperationOption { 107 | return optionFunc(func(c *operationConfig) { 108 | c.startSpanOptions = spanOptions 109 | }) 110 | } 111 | 112 | // WithSpanNameResolver provides custom name resolver 113 | func WithSpanNameResolver(resolver SpanNameResolver) OperationOption { 114 | return optionFunc(func(c *operationConfig) { 115 | c.nameResolver = resolver 116 | }) 117 | } 118 | 119 | // WithSpanAttributesResolver provides custom attribute resolver 120 | func WithSpanAttributesResolver(resolver SpanAttributesResolver) OperationOption { 121 | return optionFunc(func(c *operationConfig) { 122 | c.attributesResolver = resolver 123 | }) 124 | } 125 | 126 | var queryRegex = regexp.MustCompile(`\b(query|mutation|subscription)\b\s*(\w*)`) // \b: field names like mutationCount are not keywords 127 | 128 | // DefaultSpanNameResolver returns "gql.<kind>" or "gql.<kind>.<name>", taking the name from payload.OperationName or, failing that, the query text 129 | func DefaultSpanNameResolver(_ context.Context, payload *apollows.PayloadOperation) string { 130 | parts := queryRegex.FindStringSubmatch(payload.Query) 131 | name := payload.OperationName 132 | kind := operationQuery 133 | 134 | if len(parts) == 3 { 135 | kind = parts[1] 136 | 137 | if name == "" { 138 | name = parts[2] 139 | } 140 | } 141 | 142 | if name == "" { 143 | return "gql." + kind 144 | } 145 | 146 | return "gql." + kind + "."
+ name 147 | } 148 | 149 | // DefaultSpanAttributesResolver reports the operation type, operation name (when known) and query document as span attributes 150 | func DefaultSpanAttributesResolver( 151 | _ context.Context, 152 | payload *apollows.PayloadOperation, 153 | ) (attrs []attribute.KeyValue) { 154 | parts := queryRegex.FindStringSubmatch(payload.Query) 155 | operationName := payload.OperationName 156 | kind := operationQuery 157 | 158 | if len(parts) == 3 { 159 | switch parts[1] { 160 | case operationSubscription, operationMutation: 161 | kind = parts[1] 162 | } 163 | 164 | if operationName == "" { 165 | operationName = parts[2] 166 | } 167 | } 168 | 169 | switch kind { 170 | case operationSubscription: 171 | attrs = append(attrs, semconv.GraphqlOperationTypeSubscription) 172 | case operationMutation: 173 | attrs = append(attrs, semconv.GraphqlOperationTypeMutation) 174 | case operationQuery: 175 | attrs = append(attrs, semconv.GraphqlOperationTypeQuery) 176 | } 177 | 178 | if operationName != "" { 179 | attrs = append(attrs, semconv.GraphqlOperationName(operationName)) 180 | } 181 | 182 | attrs = append(attrs, semconv.GraphqlDocument(payload.Query)) 183 | 184 | return 185 | } 186 | 187 | type operationConfig struct { 188 | nameResolver SpanNameResolver 189 | attributesResolver SpanAttributesResolver 190 | tracer trace.Tracer 191 | tracerProvider trace.TracerProvider 192 | startSpanOptions []trace.SpanStartOption 193 | } 194 | 195 | type optionFunc func(c *operationConfig) 196 | 197 | func (o optionFunc) applyOperation(c *operationConfig) { 198 | o(c) 199 | } 200 | 201 | func newTracer(provider trace.TracerProvider) trace.Tracer { 202 | return provider.Tracer(instrumentationName, trace.WithInstrumentationVersion(instrumentationVersion)) 203 | } 204 | -------------------------------------------------------------------------------- /v1/context.go: -------------------------------------------------------------------------------- 1 | package wsgraphql 2 | 3 | import ( 4 | "context" 5 | "net/http" 6 | 7 |
"github.com/graphql-go/graphql" 8 | 9 | "github.com/eientei/wsgraphql/v1/mutable" 10 | "github.com/graphql-go/graphql/language/ast" 11 | ) 12 | 13 | type ( 14 | contextKeyRequestContextT struct{} 15 | contextKeyOperationContextT struct{} 16 | contextKeyOperationStoppedT struct{} 17 | contextKeyOperationExecutedT struct{} 18 | contextKeyOperationIDT struct{} 19 | contextKeyOperationParamsT struct{} 20 | contextKeyAstT struct{} 21 | contextKeySubscriptionT struct{} 22 | contextKeyHTTPRequestT struct{} 23 | contextKeyHTTPResponseWriterT struct{} 24 | contextKeyHTTPResponseStartedT struct{} 25 | contextKeyWebsocketConnectionT struct{} 26 | ) 27 | 28 | var ( 29 | // ContextKeyRequestContext used to store HTTP request-scoped mutable.Context 30 | ContextKeyRequestContext = contextKeyRequestContextT{} 31 | 32 | // ContextKeyOperationContext used to store graphql operation-scoped mutable.Context 33 | ContextKeyOperationContext = contextKeyOperationContextT{} 34 | 35 | // ContextKeyOperationStopped indicates the operation was stopped on client request 36 | ContextKeyOperationStopped = contextKeyOperationStoppedT{} 37 | 38 | // ContextKeyOperationExecuted indicates the operation was executed 39 | ContextKeyOperationExecuted = contextKeyOperationExecutedT{} 40 | 41 | // ContextKeyOperationID indicates the operation ID 42 | ContextKeyOperationID = contextKeyOperationIDT{} 43 | 44 | // ContextKeyOperationParams used to store operation params 45 | ContextKeyOperationParams = contextKeyOperationParamsT{} 46 | 47 | // ContextKeyAST used to store operation's ast.Document (abstract syntax tree) 48 | ContextKeyAST = contextKeyAstT{} 49 | 50 | // ContextKeySubscription used to store operation subscription flag 51 | ContextKeySubscription = contextKeySubscriptionT{} 52 | 53 | // ContextKeyHTTPRequest used to store HTTP request 54 | ContextKeyHTTPRequest = contextKeyHTTPRequestT{} 55 | 56 | // ContextKeyHTTPResponseWriter used to store HTTP response 57 | ContextKeyHTTPResponseWriter =
contextKeyHTTPResponseWriterT{} 58 | 59 | // ContextKeyHTTPResponseStarted used to indicate HTTP response already has headers sent 60 | ContextKeyHTTPResponseStarted = contextKeyHTTPResponseStartedT{} 61 | 62 | // ContextKeyWebsocketConnection used to store websocket connection 63 | ContextKeyWebsocketConnection = contextKeyWebsocketConnectionT{} 64 | ) 65 | 66 | func defaultMutcontext(ctx context.Context, mutctx mutable.Context) mutable.Context { 67 | if mutctx != nil { 68 | return mutctx 69 | } 70 | 71 | return mutable.NewMutableContext(ctx) 72 | } 73 | 74 | // RequestContext returns the HTTP request-scoped mutable context stored in ctx, or a new (unstored) mutable context wrapping ctx if none is present; it never returns nil 75 | func RequestContext(ctx context.Context) (mutctx mutable.Context) { 76 | defer func() { 77 | mutctx = defaultMutcontext(ctx, mutctx) 78 | }() 79 | 80 | v := ctx.Value(ContextKeyRequestContext) 81 | if v == nil { 82 | return nil 83 | } 84 | 85 | mutctx, ok := v.(mutable.Context) 86 | if !ok { 87 | return nil 88 | } 89 | 90 | return mutctx 91 | } 92 | 93 | // OperationContext returns the graphql operation-scoped mutable context stored in ctx, or a new (unstored) mutable context wrapping ctx if none is present; it never returns nil 94 | func OperationContext(ctx context.Context) (mutctx mutable.Context) { 95 | defer func() { 96 | mutctx = defaultMutcontext(ctx, mutctx) 97 | }() 98 | 99 | v := ctx.Value(ContextKeyOperationContext) 100 | if v == nil { 101 | return nil 102 | } 103 | 104 | mutctx, ok := v.(mutable.Context) 105 | if !ok { 106 | return nil 107 | } 108 | 109 | return mutctx 110 | } 111 | 112 | // ContextOperationStopped returns true if user requested operation stop 113 | func ContextOperationStopped(ctx context.Context) bool { 114 | v := ctx.Value(ContextKeyOperationStopped) 115 | if v == nil { 116 | return false 117 | } 118 | 119 | res, ok := v.(bool) 120 | if !ok { 121 | return false 122 | } 123 | 124 | return res 125 | } 126 | 127 | // ContextOperationExecuted returns true if the operation was executed 128 | func ContextOperationExecuted(ctx
context.Context) bool { 129 | v := ctx.Value(ContextKeyOperationExecuted) 130 | if v == nil { 131 | return false 132 | } 133 | 134 | res, ok := v.(bool) 135 | if !ok { 136 | return false 137 | } 138 | 139 | return res 140 | } 141 | 142 | // ContextOperationID returns operation ID stored in the context 143 | func ContextOperationID(ctx context.Context) string { 144 | v := ctx.Value(ContextKeyOperationID) 145 | if v == nil { 146 | return "" 147 | } 148 | 149 | res, ok := v.(string) 150 | if !ok { 151 | return "" 152 | } 153 | 154 | return res 155 | } 156 | 157 | // ContextOperationParams returns operation params stored in the context, or empty non-nil params if none are stored 158 | func ContextOperationParams(ctx context.Context) (res *graphql.Params) { 159 | v := ctx.Value(ContextKeyOperationParams) 160 | if v == nil { 161 | return &graphql.Params{} 162 | } 163 | 164 | r, ok := v.(*graphql.Params) 165 | if !ok || r == nil { 166 | return &graphql.Params{} 167 | } 168 | 169 | return r 170 | } 171 | 172 | // ContextAST returns operation's abstract syntax tree document 173 | func ContextAST(ctx context.Context) *ast.Document { 174 | v := ctx.Value(ContextKeyAST) 175 | if v == nil { 176 | return nil 177 | } 178 | 179 | astdoc, ok := v.(*ast.Document) 180 | if !ok { 181 | return nil 182 | } 183 | 184 | return astdoc 185 | } 186 | 187 | // ContextSubscription returns operation's subscription flag 188 | func ContextSubscription(ctx context.Context) bool { 189 | v := ctx.Value(ContextKeySubscription) 190 | if v == nil { 191 | return false 192 | } 193 | 194 | sub, ok := v.(bool) 195 | if !ok { 196 | return false 197 | } 198 | 199 | return sub 200 | } 201 | 202 | // ContextHTTPRequest returns http request stored in a context 203 | func ContextHTTPRequest(ctx context.Context) *http.Request { 204 | v := ctx.Value(ContextKeyHTTPRequest) 205 | if v == nil { 206 | return nil 207 | } 208 | 209 | req, ok := v.(*http.Request) 210 | if !ok { 211 | return nil 212 | } 213 | 214 | return req 215 | } 216 | 217 | //
ContextHTTPResponseWriter returns http response writer stored in a context 218 | func ContextHTTPResponseWriter(ctx context.Context) http.ResponseWriter { 219 | v := ctx.Value(ContextKeyHTTPResponseWriter) 220 | if v == nil { 221 | return nil 222 | } 223 | 224 | req, ok := v.(http.ResponseWriter) 225 | if !ok { 226 | return nil 227 | } 228 | 229 | return req 230 | } 231 | 232 | // ContextHTTPResponseStarted returns true if the HTTP response headers have already been sent 233 | func ContextHTTPResponseStarted(ctx context.Context) bool { 234 | v := ctx.Value(ContextKeyHTTPResponseStarted) 235 | if v == nil { 236 | return false 237 | } 238 | 239 | val, ok := v.(bool) 240 | if !ok { 241 | return false 242 | } 243 | 244 | return val 245 | } 246 | 247 | // ContextWebsocketConnection returns websocket connection stored in a context 248 | func ContextWebsocketConnection(ctx context.Context) Conn { 249 | v := ctx.Value(ContextKeyWebsocketConnection) 250 | if v == nil { 251 | return nil 252 | } 253 | 254 | conn, ok := v.(Conn) 255 | if !ok { 256 | return nil 257 | } 258 | 259 | return conn 260 | } 261 | -------------------------------------------------------------------------------- /v1/context_test.go: -------------------------------------------------------------------------------- 1 | package wsgraphql 2 | 3 | import ( 4 | "context" 5 | "net/http" 6 | "testing" 7 | 8 | "github.com/eientei/wsgraphql/v1/mutable" 9 | "github.com/graphql-go/graphql/language/ast" 10 | "github.com/stretchr/testify/assert" 11 | ) 12 | 13 | func TestRequestContext(t *testing.T) { 14 | reqctx := mutable.NewMutableContext(context.Background()) 15 | 16 | mutctx := mutable.NewMutableContext(context.Background()) 17 | 18 | r := RequestContext(mutctx) 19 | 20 | assert.NotNil(t, r) 21 | assert.NotEqual(t, reqctx, r) 22 | 23 | mutctx.Set(ContextKeyRequestContext, reqctx) 24 | 25 | assert.Equal(t, reqctx, RequestContext(mutctx)) 26 | 27 | mutctx.Set(ContextKeyRequestContext, 123) 28 | 29 | r = RequestContext(mutctx)
30 | 31 | assert.NotNil(t, r) 32 | assert.NotEqual(t, reqctx, r) 33 | } 34 | 35 | func TestOperationContext(t *testing.T) { 36 | reqctx := mutable.NewMutableContext(context.Background()) 37 | 38 | mutctx := mutable.NewMutableContext(context.Background()) 39 | 40 | r := OperationContext(mutctx) 41 | 42 | assert.NotNil(t, r) 43 | assert.NotEqual(t, reqctx, r) 44 | 45 | mutctx.Set(ContextKeyOperationContext, reqctx) 46 | 47 | assert.Equal(t, reqctx, OperationContext(mutctx)) 48 | 49 | mutctx.Set(ContextKeyOperationContext, 123) 50 | 51 | r = OperationContext(mutctx) 52 | 53 | assert.NotNil(t, r) 54 | assert.NotEqual(t, reqctx, r) 55 | } 56 | 57 | func TestContextHTTPResponseWriter(t *testing.T) { 58 | var h struct { 59 | http.ResponseWriter 60 | } 61 | 62 | mutctx := mutable.NewMutableContext(context.Background()) 63 | 64 | assert.Nil(t, ContextHTTPResponseWriter(mutctx)) 65 | 66 | mutctx.Set(ContextKeyHTTPResponseWriter, h) 67 | 68 | assert.Equal(t, h, ContextHTTPResponseWriter(mutctx)) 69 | 70 | mutctx.Set(ContextKeyHTTPResponseWriter, 123) 71 | 72 | assert.Nil(t, ContextHTTPResponseWriter(mutctx)) 73 | } 74 | 75 | func TestContextHTTPResponseStarted(t *testing.T) { 76 | mutctx := mutable.NewMutableContext(context.Background()) 77 | 78 | assert.Equal(t, false, ContextHTTPResponseStarted(mutctx)) 79 | 80 | mutctx.Set(ContextKeyHTTPResponseStarted, true) 81 | 82 | assert.Equal(t, true, ContextHTTPResponseStarted(mutctx)) 83 | 84 | mutctx.Set(ContextKeyHTTPResponseStarted, 123) 85 | 86 | assert.Equal(t, false, ContextHTTPResponseStarted(mutctx)) 87 | } 88 | 89 | func TestContextHTTPRequest(t *testing.T) { 90 | h := &http.Request{} 91 | 92 | mutctx := mutable.NewMutableContext(context.Background()) 93 | 94 | assert.Nil(t, ContextHTTPRequest(mutctx)) 95 | 96 | mutctx.Set(ContextKeyHTTPRequest, h) 97 | 98 | assert.Equal(t, h, ContextHTTPRequest(mutctx)) 99 | 100 | mutctx.Set(ContextKeyHTTPRequest, 123) 101 | 102 | assert.Nil(t, ContextHTTPRequest(mutctx)) 103 | } 104 |
105 | func TestContextSubscription(t *testing.T) { 106 | mutctx := mutable.NewMutableContext(context.Background()) 107 | 108 | assert.Equal(t, false, ContextSubscription(mutctx)) 109 | 110 | mutctx.Set(ContextKeySubscription, true) 111 | 112 | assert.Equal(t, true, ContextSubscription(mutctx)) 113 | 114 | mutctx.Set(ContextKeySubscription, 123) 115 | 116 | assert.Equal(t, false, ContextSubscription(mutctx)) 117 | } 118 | 119 | func TestContextAST(t *testing.T) { 120 | doc := &ast.Document{} 121 | 122 | mutctx := mutable.NewMutableContext(context.Background()) 123 | 124 | assert.Nil(t, ContextAST(mutctx)) 125 | 126 | mutctx.Set(ContextKeyAST, doc) 127 | 128 | assert.Equal(t, doc, ContextAST(mutctx)) 129 | 130 | mutctx.Set(ContextKeyAST, 123) 131 | 132 | assert.Nil(t, ContextAST(mutctx)) 133 | } 134 | 135 | func TestContextWebsocketConnection(t *testing.T) { 136 | var conn struct { 137 | Conn 138 | } 139 | 140 | mutctx := mutable.NewMutableContext(context.Background()) 141 | 142 | assert.Nil(t, ContextWebsocketConnection(mutctx)) 143 | 144 | mutctx.Set(ContextKeyWebsocketConnection, conn) 145 | 146 | assert.Equal(t, conn, ContextWebsocketConnection(mutctx)) 147 | 148 | mutctx.Set(ContextKeyWebsocketConnection, 123) 149 | 150 | assert.Nil(t, ContextWebsocketConnection(mutctx)) 151 | } 152 | 153 | func TestContextOperationStopped(t *testing.T) { 154 | mutctx := mutable.NewMutableContext(context.Background()) 155 | 156 | assert.Equal(t, false, ContextOperationStopped(mutctx)) 157 | 158 | mutctx.Set(ContextKeyOperationStopped, true) 159 | 160 | assert.Equal(t, true, ContextOperationStopped(mutctx)) 161 | 162 | mutctx.Set(ContextKeyOperationStopped, 123) 163 | 164 | assert.Equal(t, false, ContextOperationStopped(mutctx)) 165 | } 166 | -------------------------------------------------------------------------------- /v1/error.go: -------------------------------------------------------------------------------- 1 | package wsgraphql 2 | 3 | import ( 4 | 
"github.com/graphql-go/graphql/gqlerrors" 5 | "github.com/graphql-go/graphql/language/location" 6 | ) 7 | 8 | func wrapExtendedError(err error, loc []location.SourceLocation) error { 9 | _, ok := err.(gqlerrors.ExtendedError) 10 | if ok { 11 | return &gqlerrors.Error{ 12 | Message: err.Error(), 13 | OriginalError: err, 14 | Locations: loc, 15 | } 16 | } 17 | 18 | return err 19 | } 20 | 21 | // FormatError returns error formatted as graphql error 22 | func FormatError(err error) gqlerrors.FormattedError { 23 | var loc []location.SourceLocation 24 | 25 | fmterr, ok := err.(gqlerrors.FormattedError) 26 | if ok { 27 | err = fmterr.OriginalError() 28 | loc = fmterr.Locations 29 | } 30 | 31 | _, ok = err.(*gqlerrors.Error) 32 | if ok { 33 | return gqlerrors.FormatError(err) 34 | } 35 | 36 | return gqlerrors.FormatError(wrapExtendedError(err, loc)) 37 | } 38 | -------------------------------------------------------------------------------- /v1/examples/minimal-graphql-transport-ws/README.md: -------------------------------------------------------------------------------- 1 | Minimal server example 2 | 3 | Running 4 | ------- 5 | 6 | ```go 7 | go run main.go -addr :8080 8 | ``` 9 | 10 | Graphql endpoint will be available at http://127.0.0.1:8080/query 11 | 12 | There is no playground in this example, but you can try following query with graphql client of your choice: 13 | 14 | ```graphql 15 | query { 16 | getFoo 17 | } 18 | ``` 19 | -------------------------------------------------------------------------------- /v1/examples/minimal-graphql-transport-ws/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "flag" 5 | "net/http" 6 | 7 | "github.com/eientei/wsgraphql/v1" 8 | "github.com/eientei/wsgraphql/v1/compat/gorillaws" 9 | "github.com/gorilla/websocket" 10 | "github.com/graphql-go/graphql" 11 | ) 12 | 13 | func main() { 14 | var addr string 15 | 16 | flag.StringVar(&addr, "addr", ":8080", 
"Address to listen on") 17 | flag.Parse() 18 | 19 | schema, err := graphql.NewSchema(graphql.SchemaConfig{ 20 | Query: graphql.NewObject(graphql.ObjectConfig{ 21 | Name: "QueryRoot", 22 | Fields: graphql.Fields{ 23 | "getFoo": &graphql.Field{ 24 | Description: "Returns most recent foo value", 25 | Type: graphql.NewNonNull(graphql.Int), 26 | Resolve: func(p graphql.ResolveParams) (interface{}, error) { 27 | return 123, nil 28 | }, 29 | }, 30 | }, 31 | }), 32 | }) 33 | if err != nil { 34 | panic(err) 35 | } 36 | 37 | srv, err := wsgraphql.NewServer( 38 | schema, 39 | wsgraphql.WithProtocol(wsgraphql.WebsocketSubprotocolGraphqlTransportWS), 40 | wsgraphql.WithUpgrader(gorillaws.Wrap(&websocket.Upgrader{ 41 | Subprotocols: []string{wsgraphql.WebsocketSubprotocolGraphqlTransportWS.String()}, 42 | })), 43 | ) 44 | if err != nil { 45 | panic(err) 46 | } 47 | 48 | http.Handle("/query", srv) 49 | 50 | err = http.ListenAndServe(addr, nil) 51 | if err != nil { 52 | panic(err) 53 | } 54 | } 55 | -------------------------------------------------------------------------------- /v1/examples/minimal-graphql-ws/README.md: -------------------------------------------------------------------------------- 1 | Minimal server example 2 | 3 | Running 4 | ------- 5 | 6 | ```go 7 | go run main.go -addr :8080 8 | ``` 9 | 10 | Graphql endpoint will be available at http://127.0.0.1:8080/query 11 | 12 | There is no playground in this example, but you can try following query with graphql client of your choice: 13 | 14 | ```graphql 15 | query { 16 | getFoo 17 | } 18 | ``` 19 | -------------------------------------------------------------------------------- /v1/examples/minimal-graphql-ws/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "flag" 5 | "net/http" 6 | 7 | "github.com/eientei/wsgraphql/v1" 8 | "github.com/eientei/wsgraphql/v1/compat/gorillaws" 9 | "github.com/gorilla/websocket" 10 | 
"github.com/graphql-go/graphql" 11 | ) 12 | 13 | func main() { 14 | var addr string 15 | 16 | flag.StringVar(&addr, "addr", ":8080", "Address to listen on") 17 | flag.Parse() 18 | 19 | schema, err := graphql.NewSchema(graphql.SchemaConfig{ 20 | Query: graphql.NewObject(graphql.ObjectConfig{ 21 | Name: "QueryRoot", 22 | Fields: graphql.Fields{ 23 | "getFoo": &graphql.Field{ 24 | Description: "Returns most recent foo value", 25 | Type: graphql.NewNonNull(graphql.Int), 26 | Resolve: func(p graphql.ResolveParams) (interface{}, error) { 27 | return 123, nil 28 | }, 29 | }, 30 | }, 31 | }), 32 | }) 33 | if err != nil { 34 | panic(err) 35 | } 36 | 37 | srv, err := wsgraphql.NewServer( 38 | schema, 39 | wsgraphql.WithProtocol(wsgraphql.WebsocketSubprotocolGraphqlWS), 40 | wsgraphql.WithUpgrader(gorillaws.Wrap(&websocket.Upgrader{ 41 | Subprotocols: []string{wsgraphql.WebsocketSubprotocolGraphqlWS.String()}, 42 | })), 43 | ) 44 | if err != nil { 45 | panic(err) 46 | } 47 | 48 | http.Handle("/query", srv) 49 | 50 | err = http.ListenAndServe(addr, nil) 51 | if err != nil { 52 | panic(err) 53 | } 54 | } 55 | -------------------------------------------------------------------------------- /v1/examples/simpleserver/README.md: -------------------------------------------------------------------------------- 1 | Complete server example, with query, mutation and subscriptions. 
2 | 3 | Running 4 | ------- 5 | 6 | ```go 7 | go run main.go -addr :8080 8 | ``` 9 | 10 | GraphQL endpoint will be available at http://127.0.0.1:8080/query 11 | 12 | Navigate to playground on http://127.0.0.1:8080 13 | 14 | And try following queries: 15 | 16 | ```graphql 17 | subscription { 18 | fooUpdates 19 | } 20 | ``` 21 | 22 | Then in new tab(s) 23 | 24 | ```graphql 25 | query { 26 | getFoo 27 | } 28 | ``` 29 | 30 | ```graphql 31 | mutation { 32 | setFoo(value: 123) 33 | } 34 | ``` 35 | 36 | ```graphql 37 | query { 38 | getFoo 39 | } 40 | ``` 41 | -------------------------------------------------------------------------------- /v1/examples/simpleserver/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "bytes" 5 | "context" 6 | _ "embed" 7 | "flag" 8 | "fmt" 9 | "net/http" 10 | "sync/atomic" 11 | "time" 12 | 13 | "github.com/eientei/wsgraphql/v1" 14 | "github.com/eientei/wsgraphql/v1/compat/gorillaws" 15 | "github.com/gorilla/websocket" 16 | "github.com/graphql-go/graphql" 17 | ) 18 | 19 | //go:embed playground.html 20 | var playgroundFile []byte 21 | 22 | func main() { 23 | var addr string 24 | 25 | flag.StringVar(&addr, "addr", ":8080", "Address to listen on") 26 | flag.Parse() 27 | 28 | var foo int 29 | 30 | fooupdates := make(chan int, 1) 31 | 32 | var subscriberID uint64 33 | 34 | type subscriber struct { 35 | subscription chan interface{} 36 | ctx context.Context 37 | id uint64 38 | } 39 | 40 | subscribers := make(map[uint64]*subscriber) 41 | subscriberadd := make(chan *subscriber, 1) 42 | subscriberrem := make(chan uint64, 1) 43 | 44 | go func() { 45 | for { 46 | select { 47 | case upd := <-fooupdates: 48 | foo = upd 49 | 50 | fmt.Println("broadcasting update, new value:", upd) 51 | 52 | for _, sub := range subscribers { 53 | select { 54 | case sub.subscription <- upd: 55 | case <-sub.ctx.Done(): 56 | } 57 | } 58 | case add := <-subscriberadd: 59 | subscribers[add.id] = add 60 
| 61 | fmt.Println("added subscriber", add.id) 62 | case rem := <-subscriberrem: 63 | close(subscribers[rem].subscription) 64 | 65 | delete(subscribers, rem) 66 | 67 | fmt.Println("removed subscriber", rem) 68 | } 69 | } 70 | }() 71 | 72 | schema, err := graphql.NewSchema(graphql.SchemaConfig{ 73 | Query: graphql.NewObject(graphql.ObjectConfig{ 74 | Name: "QueryRoot", 75 | Fields: graphql.Fields{ 76 | "getFoo": &graphql.Field{ 77 | Description: "Returns most recent foo value", 78 | Type: graphql.NewNonNull(graphql.Int), 79 | Resolve: func(p graphql.ResolveParams) (interface{}, error) { 80 | return foo, nil 81 | }, 82 | }, 83 | }, 84 | }), 85 | Mutation: graphql.NewObject(graphql.ObjectConfig{ 86 | Name: "MutationRoot", 87 | Fields: graphql.Fields{ 88 | "setFoo": &graphql.Field{ 89 | Args: graphql.FieldConfigArgument{ 90 | "value": &graphql.ArgumentConfig{ 91 | Type: graphql.Int, 92 | }, 93 | }, 94 | Description: "Updates foo value; generating an update to subscribers of fooUpdates", 95 | Type: graphql.NewNonNull(graphql.Boolean), 96 | Resolve: func(p graphql.ResolveParams) (interface{}, error) { 97 | v, ok := p.Args["value"].(int) 98 | if ok { 99 | select { 100 | case <-p.Context.Done(): 101 | return nil, p.Context.Err() 102 | case fooupdates <- v: 103 | } 104 | } 105 | 106 | return ok, nil 107 | }, 108 | }, 109 | }, 110 | }), 111 | Subscription: graphql.NewObject(graphql.ObjectConfig{ 112 | Name: "SubscriptionRoot", 113 | Fields: graphql.Fields{ 114 | "fooUpdates": &graphql.Field{ 115 | Description: "Updates generated by setFoo mutation", 116 | Type: graphql.NewNonNull(graphql.Int), 117 | Resolve: func(p graphql.ResolveParams) (interface{}, error) { 118 | // values sent on channel, that were returned from `Subscribe`, will be available here as 119 | // `p.Source` 120 | return p.Source, nil 121 | }, 122 | Subscribe: func(p graphql.ResolveParams) (interface{}, error) { 123 | // per graphql-go contract, channel returned from `Subscribe` function must have 124 | // 
interface{} values 125 | ch := make(chan interface{}, 1) 126 | id := atomic.AddUint64(&subscriberID, 1) 127 | 128 | subscriberadd <- &subscriber{ 129 | id: id, 130 | subscription: ch, 131 | ctx: p.Context, 132 | } 133 | 134 | go func() { 135 | <-p.Context.Done() 136 | 137 | subscriberrem <- id 138 | }() 139 | 140 | return ch, nil 141 | }, 142 | }, 143 | }, 144 | }), 145 | }) 146 | if err != nil { 147 | panic(err) 148 | } 149 | 150 | srv, err := wsgraphql.NewServer( 151 | schema, 152 | wsgraphql.WithKeepalive(time.Second*30), 153 | wsgraphql.WithConnectTimeout(time.Second*30), 154 | wsgraphql.WithUpgrader(gorillaws.Wrap(&websocket.Upgrader{ 155 | ReadBufferSize: 1024, 156 | WriteBufferSize: 1024, 157 | Subprotocols: []string{ 158 | wsgraphql.WebsocketSubprotocolGraphqlWS.String(), 159 | wsgraphql.WebsocketSubprotocolGraphqlTransportWS.String(), 160 | }, 161 | })), 162 | ) 163 | if err != nil { 164 | panic(err) 165 | } 166 | 167 | http.Handle("/query", srv) 168 | http.HandleFunc("/", func(writer http.ResponseWriter, request *http.Request) { 169 | http.ServeContent(writer, request, "playground.html", time.Time{}, bytes.NewReader(playgroundFile)) 170 | }) 171 | 172 | err = http.ListenAndServe(addr, nil) 173 | if err != nil { 174 | panic(err) 175 | } 176 | } 177 | -------------------------------------------------------------------------------- /v1/examples/simpleserver/playground.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 |
5 | 6 | 7 |