├── .gitignore ├── LICENSE ├── README.md ├── dependencies.go ├── directives.go ├── directives_test.go ├── go.mod ├── go.sum ├── graphqlws ├── connections.go ├── logger.go ├── manager.go └── wshandler.go ├── handler ├── README.md ├── graphiql.go ├── handler.go └── playground.go ├── helpers.go ├── registry.go ├── resolvers.go ├── scalars ├── bool_string.go ├── json.go ├── query_document.go └── string_set.go ├── schema.go ├── schema_test.go ├── server ├── graphiql.go ├── graphiql.tmpl ├── graphqlws.go ├── graphqlws │ └── connections.go ├── handler.go ├── logger │ └── logger.go ├── manager.go ├── playground.go ├── playground.tmpl └── server.go ├── typedefs.go ├── typedefs_test.go ├── types.go └── values.go /.gitignore: -------------------------------------------------------------------------------- 1 | # Binaries for programs and plugins 2 | *.exe 3 | *.exe~ 4 | *.dll 5 | *.so 6 | *.dylib 7 | 8 | # Test binary, build with `go test -c` 9 | *.test 10 | 11 | # Output of the go coverage tool, specifically when used with LiteIDE 12 | *.out 13 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Branden Horiuchi 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # graphql-go-tools 2 | Like apollo-tools for graphql-go 3 | 4 | 5 | ## Current Tools 6 | 7 | ### `MakeExecutableSchema` 8 | 9 | **Currently supports:** 10 | 11 | * Merge multiple graphql documents 12 | * Object type extending 13 | * Custom Directives 14 | * Import types and directives 15 | 16 | **Planned:** 17 | 18 | * Schema-stitching 19 | 20 | **Limitations:** 21 | 22 | * Only types and directives defined in the `TypeDefs` with schema language can be extended and have custom directives applied. 23 | 24 | ## Example 25 | 26 | ```go 27 | func main() { 28 | schema, err := tools.MakeExecutableSchema(tools.ExecutableSchema{ 29 | TypeDefs: ` 30 | directive @description(value: String!) on FIELD_DEFINITION 31 | 32 | type Foo { 33 | id: ID! 34 | name: String! 
35 | description: String 36 | } 37 | 38 | type Query { 39 | foo(id: ID!): Foo @description(value: "bazqux") 40 | }`, 41 | Resolvers: tools.ResolverMap{ 42 | "Query": &tools.ObjectResolver{ 43 | Fields: tools.FieldResolveMap{ 44 | "foo": &tools.FieldResolve{ 45 | Resolve: func(p graphql.ResolveParams) (interface{}, error) { 46 | // lookup data 47 | return foo, nil 48 | }, 49 | }, 50 | }, 51 | }, 52 | }, 53 | SchemaDirectives: tools.SchemaDirectiveVisitorMap{ 54 | "description": &tools.SchemaDirectiveVisitor{ 55 | VisitFieldDefinition: func(v tools.VisitFieldDefinitionParams) error { 56 | resolveFunc := v.Config.Resolve 57 | v.Config.Resolve = func(p graphql.ResolveParams) (interface{}, error) { 58 | result, err := resolveFunc(p) 59 | if err != nil { 60 | return result, err 61 | } 62 | data := result.(map[string]interface{}) 63 | data["description"] = v.Args["value"] 64 | return data, nil 65 | } 66 | return nil 67 | }, 68 | }, 69 | }, 70 | }) 71 | 72 | if err != nil { 73 | log.Fatalf("Failed to build schema, error: %v", err) 74 | } 75 | 76 | params := graphql.Params{ 77 | Schema: schema, 78 | RequestString: ` 79 | query GetFoo { 80 | foo(id: "5cffbf1ccecefcfff659cea8") { 81 | description 82 | } 83 | }`, 84 | } 85 | 86 | r := graphql.Do(params) 87 | if r.HasErrors() { 88 | log.Fatalf("failed to execute graphql operation, errors: %+v", r.Errors) 89 | } 90 | rJSON, _ := json.Marshal(r) 91 | fmt.Printf("%s \n", rJSON) 92 | } 93 | 94 | ``` 95 | 96 | ### Handler 97 | 98 | Modified `graphql-go/handler` with updated GraphiQL and Playground 99 | 100 | See [handler package](handler) -------------------------------------------------------------------------------- /dependencies.go: -------------------------------------------------------------------------------- 1 | package tools 2 | 3 | import ( 4 | "fmt" 5 | 6 | "github.com/graphql-go/graphql/language/ast" 7 | "github.com/graphql-go/graphql/language/kinds" 8 | ) 9 | 10 | type DependencyMap map[string]map[string]interface{} 11 | 12 | func (r
*registry) IdentifyDependencies() (DependencyMap, error) { 13 | m := DependencyMap{} 14 | 15 | // get list of initial types, all dependencies should be resolved 16 | for _, t := range r.types { 17 | m[t.Name()] = map[string]interface{}{} 18 | } 19 | 20 | for _, def := range r.unresolvedDefs { 21 | switch nodeKind := def.GetKind(); nodeKind { 22 | case kinds.DirectiveDefinition: 23 | if err := identifyDirectiveDependencies(m, def.(*ast.DirectiveDefinition)); err != nil { 24 | return nil, err 25 | } 26 | case kinds.ScalarDefinition: 27 | scalar := def.(*ast.ScalarDefinition) 28 | m[scalar.Name.Value] = map[string]interface{}{} 29 | case kinds.EnumDefinition: 30 | enum := def.(*ast.EnumDefinition) 31 | m[enum.Name.Value] = map[string]interface{}{} 32 | case kinds.InputObjectDefinition: 33 | if err := identifyInputDependencies(m, def.(*ast.InputObjectDefinition)); err != nil { 34 | return nil, err 35 | } 36 | case kinds.ObjectDefinition: 37 | if err := identifyObjectDependencies(m, def.(*ast.ObjectDefinition)); err != nil { 38 | return nil, err 39 | } 40 | case kinds.InterfaceDefinition: 41 | if err := identifyInterfaceDependencies(m, def.(*ast.InterfaceDefinition)); err != nil { 42 | return nil, err 43 | } 44 | case kinds.UnionDefinition: 45 | if err := identifyUnionDependencies(m, def.(*ast.UnionDefinition)); err != nil { 46 | return nil, err 47 | } 48 | case kinds.SchemaDefinition: 49 | identifySchemaDependencies(m, def.(*ast.SchemaDefinition)) 50 | } 51 | } 52 | 53 | // attempt to resolve 54 | resolved := map[string]interface{}{} 55 | maxIteration := len(m) + 1 56 | count := 0 57 | 58 | for count <= maxIteration { 59 | count++ 60 | if len(m) == 0 { 61 | break 62 | } 63 | 64 | for t, deps := range m { 65 | for dep := range deps { 66 | if _, ok := resolved[dep]; ok { 67 | delete(deps, dep) 68 | } 69 | } 70 | 71 | if len(deps) == 0 { 72 | resolved[t] = nil 73 | delete(m, t) 74 | } 75 | } 76 | } 77 | 78 | return m, nil 79 | } 80 | 81 | func isPrimitiveType(t string) 
bool { 82 | switch t { 83 | case "String", "Int", "Float", "Boolean", "ID": 84 | return true 85 | } 86 | return false 87 | } 88 | 89 | func identifyUnionDependencies(m DependencyMap, def *ast.UnionDefinition) error { 90 | name := def.Name.Value 91 | deps, ok := m[name] 92 | if !ok { 93 | deps = map[string]interface{}{} 94 | } 95 | 96 | for _, t := range def.Types { 97 | typeName, err := identifyRootType(t) 98 | if err != nil { 99 | return err 100 | } 101 | 102 | if !isPrimitiveType(typeName) { 103 | deps[typeName] = nil 104 | } 105 | } 106 | 107 | m[name] = deps 108 | return nil 109 | } 110 | 111 | func identifyInterfaceDependencies(m DependencyMap, def *ast.InterfaceDefinition) error { 112 | name := def.Name.Value 113 | deps, ok := m[name] 114 | if !ok { 115 | deps = map[string]interface{}{} 116 | } 117 | 118 | for _, field := range def.Fields { 119 | for _, arg := range field.Arguments { 120 | typeName, err := identifyRootType(arg.Type) 121 | if err != nil { 122 | return err 123 | } 124 | 125 | if !isPrimitiveType(typeName) { 126 | deps[typeName] = nil 127 | } 128 | } 129 | typeName, err := identifyRootType(field.Type) 130 | if err != nil { 131 | return err 132 | } 133 | 134 | if !isPrimitiveType(typeName) { 135 | deps[typeName] = nil 136 | } 137 | } 138 | 139 | m[name] = deps 140 | return nil 141 | } 142 | 143 | // schema dependencies 144 | func identifySchemaDependencies(m DependencyMap, def *ast.SchemaDefinition) { 145 | deps, ok := m["schema"] 146 | if !ok { 147 | deps = map[string]interface{}{} 148 | } 149 | 150 | for _, op := range def.OperationTypes { 151 | switch op.Operation { 152 | case ast.OperationTypeQuery: 153 | deps[op.Type.Name.Value] = nil 154 | case ast.OperationTypeMutation: 155 | deps[op.Type.Name.Value] = nil 156 | case ast.OperationTypeSubscription: 157 | deps[op.Type.Name.Value] = nil 158 | } 159 | } 160 | 161 | m["schema"] = deps 162 | } 163 | 164 | func identifyRootType(astType ast.Type) (string, error) { 165 | switch kind := 
astType.GetKind(); kind { 166 | case kinds.List: 167 | t, err := identifyRootType(astType.(*ast.List).Type) 168 | if err != nil { 169 | return "", err 170 | } 171 | return t, nil 172 | case kinds.NonNull: 173 | t, err := identifyRootType(astType.(*ast.NonNull).Type) 174 | if err != nil { 175 | return "", err 176 | } 177 | return t, nil 178 | case kinds.Named: 179 | t := astType.(*ast.Named) 180 | return t.Name.Value, nil 181 | } 182 | 183 | return "", fmt.Errorf("unknown type %v", astType) 184 | } 185 | 186 | // directive dependencies 187 | func identifyDirectiveDependencies(m DependencyMap, def *ast.DirectiveDefinition) error { 188 | name := "@" + def.Name.Value 189 | deps, ok := m[name] 190 | if !ok { 191 | deps = map[string]interface{}{} 192 | } 193 | 194 | for _, arg := range def.Arguments { 195 | typeName, err := identifyRootType(arg.Type) 196 | if err != nil { 197 | return err 198 | } 199 | if !isPrimitiveType(typeName) { 200 | deps[typeName] = nil 201 | } 202 | } 203 | 204 | m[name] = deps 205 | return nil 206 | } 207 | 208 | // gets input object depdendencies 209 | func identifyInputDependencies(m DependencyMap, def *ast.InputObjectDefinition) error { 210 | name := def.Name.Value 211 | deps, ok := m[name] 212 | if !ok { 213 | deps = map[string]interface{}{} 214 | } 215 | 216 | for _, field := range def.Fields { 217 | typeName, err := identifyRootType(field.Type) 218 | if err != nil { 219 | return err 220 | } 221 | 222 | if !isPrimitiveType(typeName) { 223 | deps[typeName] = nil 224 | } 225 | } 226 | 227 | m[name] = deps 228 | return nil 229 | } 230 | 231 | // get object dependencies 232 | func identifyObjectDependencies(m DependencyMap, def *ast.ObjectDefinition) error { 233 | name := def.Name.Value 234 | deps, ok := m[name] 235 | if !ok { 236 | deps = map[string]interface{}{} 237 | } 238 | 239 | for _, field := range def.Fields { 240 | for _, arg := range field.Arguments { 241 | typeName, err := identifyRootType(arg.Type) 242 | if err != nil { 243 | return 
err 244 | } 245 | 246 | if !isPrimitiveType(typeName) { 247 | deps[typeName] = nil 248 | } 249 | } 250 | typeName, err := identifyRootType(field.Type) 251 | if err != nil { 252 | return err 253 | } 254 | 255 | if !isPrimitiveType(typeName) { 256 | deps[typeName] = nil 257 | } 258 | } 259 | 260 | m[name] = deps 261 | return nil 262 | } 263 | -------------------------------------------------------------------------------- /directives.go: -------------------------------------------------------------------------------- 1 | package tools 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | 7 | "github.com/graphql-go/graphql" 8 | "github.com/graphql-go/graphql/language/ast" 9 | ) 10 | 11 | const ( 12 | directiveHide = "hide" 13 | ) 14 | 15 | // HideDirective hides a define field 16 | var HideDirective = graphql.NewDirective(graphql.DirectiveConfig{ 17 | Name: directiveHide, 18 | Description: "Hide a field, useful when generating types from the AST where the backend type has more fields than the graphql type", 19 | Locations: []string{graphql.DirectiveLocationFieldDefinition}, 20 | Args: graphql.FieldConfigArgument{}, 21 | }) 22 | 23 | // SchemaDirectiveVisitor defines a schema visitor. 
24 | // This attempts to provide similar functionality to Apollo graphql-tools 25 | // https://www.apollographql.com/docs/graphql-tools/schema-directives/ 26 | type SchemaDirectiveVisitor struct { 27 | VisitSchema func(p VisitSchemaParams) error 28 | VisitScalar func(p VisitScalarParams) error 29 | VisitObject func(p VisitObjectParams) error 30 | VisitFieldDefinition func(p VisitFieldDefinitionParams) error 31 | VisitArgumentDefinition func(p VisitArgumentDefinitionParams) error 32 | VisitInterface func(p VisitInterfaceParams) error 33 | VisitUnion func(p VisitUnionParams) error 34 | VisitEnum func(p VisitEnumParams) error 35 | VisitEnumValue func(p VisitEnumValueParams) error 36 | VisitInputObject func(p VisitInputObjectParams) error 37 | VisitInputFieldDefinition func(p VisitInputFieldDefinitionParams) error 38 | } 39 | 40 | // VisitSchemaParams params 41 | type VisitSchemaParams struct { 42 | Context context.Context 43 | Config *graphql.SchemaConfig 44 | Node *ast.SchemaDefinition 45 | Args map[string]interface{} 46 | } 47 | 48 | // VisitScalarParams params 49 | type VisitScalarParams struct { 50 | Context context.Context 51 | Config *graphql.ScalarConfig 52 | Node *ast.ScalarDefinition 53 | Args map[string]interface{} 54 | } 55 | 56 | // VisitObjectParams params 57 | type VisitObjectParams struct { 58 | Context context.Context 59 | Config *graphql.ObjectConfig 60 | Node *ast.ObjectDefinition 61 | Extensions []*ast.ObjectDefinition 62 | Args map[string]interface{} 63 | } 64 | 65 | // VisitFieldDefinitionParams params 66 | type VisitFieldDefinitionParams struct { 67 | Context context.Context 68 | Config *graphql.Field 69 | Node *ast.FieldDefinition 70 | Args map[string]interface{} 71 | ParentName string 72 | ParentKind string 73 | } 74 | 75 | // VisitArgumentDefinitionParams params 76 | type VisitArgumentDefinitionParams struct { 77 | Context context.Context 78 | Config *graphql.ArgumentConfig 79 | Node *ast.InputValueDefinition 80 | Args map[string]interface{} 
81 | } 82 | 83 | // VisitInterfaceParams params 84 | type VisitInterfaceParams struct { 85 | Context context.Context 86 | Config *graphql.InterfaceConfig 87 | Node *ast.InterfaceDefinition 88 | Args map[string]interface{} 89 | } 90 | 91 | // VisitUnionParams params 92 | type VisitUnionParams struct { 93 | Context context.Context 94 | Config *graphql.UnionConfig 95 | Node *ast.UnionDefinition 96 | Args map[string]interface{} 97 | } 98 | 99 | // VisitEnumParams params 100 | type VisitEnumParams struct { 101 | Context context.Context 102 | Config *graphql.EnumConfig 103 | Node *ast.EnumDefinition 104 | Args map[string]interface{} 105 | } 106 | 107 | // VisitEnumValueParams params 108 | type VisitEnumValueParams struct { 109 | Context context.Context 110 | Config *graphql.EnumValueConfig 111 | Node *ast.EnumValueDefinition 112 | Args map[string]interface{} 113 | } 114 | 115 | // VisitInputObjectParams params 116 | type VisitInputObjectParams struct { 117 | Context context.Context 118 | Config *graphql.InputObjectConfig 119 | Node *ast.InputObjectDefinition 120 | Args map[string]interface{} 121 | } 122 | 123 | // VisitInputFieldDefinitionParams params 124 | type VisitInputFieldDefinitionParams struct { 125 | Context context.Context 126 | Config *graphql.InputObjectFieldConfig 127 | Node *ast.InputValueDefinition 128 | Args map[string]interface{} 129 | } 130 | 131 | // SchemaDirectiveVisitorMap a map of schema directive visitors 132 | type SchemaDirectiveVisitorMap map[string]*SchemaDirectiveVisitor 133 | 134 | // DirectiveMap a map of directives 135 | type DirectiveMap map[string]*graphql.Directive 136 | 137 | // converts the directive map to an array 138 | func (c *registry) directiveArray() []*graphql.Directive { 139 | a := make([]*graphql.Directive, 0) 140 | for _, d := range c.directives { 141 | a = append(a, d) 142 | } 143 | return a 144 | } 145 | 146 | // builds directives from ast 147 | func (c *registry) buildDirectiveFromAST(definition *ast.DirectiveDefinition) 
error { 148 | name := definition.Name.Value 149 | directiveConfig := graphql.DirectiveConfig{ 150 | Name: name, 151 | Description: getDescription(definition), 152 | Args: graphql.FieldConfigArgument{}, 153 | Locations: []string{}, 154 | } 155 | 156 | for _, arg := range definition.Arguments { 157 | if argValue, err := c.buildArgFromAST(arg); err == nil { 158 | directiveConfig.Args[arg.Name.Value] = argValue 159 | } else { 160 | return err 161 | } 162 | } 163 | 164 | for _, loc := range definition.Locations { 165 | directiveConfig.Locations = append(directiveConfig.Locations, loc.Value) 166 | } 167 | 168 | c.directives[name] = graphql.NewDirective(directiveConfig) 169 | return nil 170 | } 171 | 172 | type applyDirectiveParams struct { 173 | config interface{} 174 | directives []*ast.Directive 175 | node interface{} 176 | extensions []*ast.ObjectDefinition 177 | parentName string 178 | parentKind string 179 | } 180 | 181 | // applies directives 182 | func (c *registry) applyDirectives(p applyDirectiveParams) error { 183 | if c.directiveMap == nil { 184 | return nil 185 | } 186 | 187 | for _, def := range p.directives { 188 | name := def.Name.Value 189 | visitor, hasVisitor := c.directiveMap[name] 190 | if !hasVisitor { 191 | continue 192 | } 193 | 194 | directive, err := c.getDirective(name) 195 | if err != nil { 196 | return err 197 | } 198 | 199 | args, err := GetArgumentValues(directive.Args, def.Arguments, map[string]interface{}{}) 200 | if err != nil { 201 | return fmt.Errorf("applyDirective %v error: %s", p.config, err) 202 | } 203 | 204 | switch p.config.(type) { 205 | case *graphql.SchemaConfig: 206 | if visitor.VisitSchema != nil { 207 | if err := visitor.VisitSchema(VisitSchemaParams{ 208 | Context: c.ctx, 209 | Config: p.config.(*graphql.SchemaConfig), 210 | Args: args, 211 | Node: p.node.(*ast.SchemaDefinition), 212 | }); err != nil { 213 | return err 214 | } 215 | } 216 | case *graphql.ScalarConfig: 217 | if visitor.VisitScalar != nil { 218 | if err := 
visitor.VisitScalar(VisitScalarParams{ 219 | Context: c.ctx, 220 | Config: p.config.(*graphql.ScalarConfig), 221 | Args: args, 222 | Node: p.node.(*ast.ScalarDefinition), 223 | }); err != nil { 224 | return err 225 | } 226 | } 227 | case *graphql.ObjectConfig: 228 | if visitor.VisitObject != nil { 229 | if err := visitor.VisitObject(VisitObjectParams{ 230 | Context: c.ctx, 231 | Config: p.config.(*graphql.ObjectConfig), 232 | Args: args, 233 | Node: p.node.(*ast.ObjectDefinition), 234 | Extensions: p.extensions, 235 | }); err != nil { 236 | return err 237 | } 238 | } 239 | case *graphql.Field: 240 | if visitor.VisitFieldDefinition != nil { 241 | if err := visitor.VisitFieldDefinition(VisitFieldDefinitionParams{ 242 | Context: c.ctx, 243 | Config: p.config.(*graphql.Field), 244 | Args: args, 245 | Node: p.node.(*ast.FieldDefinition), 246 | ParentName: p.parentName, 247 | ParentKind: p.parentKind, 248 | }); err != nil { 249 | return err 250 | } 251 | } 252 | case *graphql.ArgumentConfig: 253 | if visitor.VisitArgumentDefinition != nil { 254 | if err := visitor.VisitArgumentDefinition(VisitArgumentDefinitionParams{ 255 | Context: c.ctx, 256 | Config: p.config.(*graphql.ArgumentConfig), 257 | Args: args, 258 | Node: p.node.(*ast.InputValueDefinition), 259 | }); err != nil { 260 | return err 261 | } 262 | } 263 | case *graphql.InterfaceConfig: 264 | if visitor.VisitInterface != nil { 265 | if err := visitor.VisitInterface(VisitInterfaceParams{ 266 | Context: c.ctx, 267 | Config: p.config.(*graphql.InterfaceConfig), 268 | Args: args, 269 | Node: p.node.(*ast.InterfaceDefinition), 270 | }); err != nil { 271 | return err 272 | } 273 | } 274 | case *graphql.UnionConfig: 275 | if visitor.VisitUnion != nil { 276 | if err := visitor.VisitUnion(VisitUnionParams{ 277 | Context: c.ctx, 278 | Config: p.config.(*graphql.UnionConfig), 279 | Args: args, 280 | Node: p.node.(*ast.UnionDefinition), 281 | }); err != nil { 282 | return err 283 | } 284 | } 285 | case *graphql.EnumConfig: 
286 | if visitor.VisitEnum != nil { 287 | if err := visitor.VisitEnum(VisitEnumParams{ 288 | Context: c.ctx, 289 | Config: p.config.(*graphql.EnumConfig), 290 | Args: args, 291 | Node: p.node.(*ast.EnumDefinition), 292 | }); err != nil { 293 | return err 294 | } 295 | } 296 | case *graphql.EnumValueConfig: 297 | if visitor.VisitEnumValue != nil { 298 | if err := visitor.VisitEnumValue(VisitEnumValueParams{ 299 | Context: c.ctx, 300 | Config: p.config.(*graphql.EnumValueConfig), 301 | Args: args, 302 | Node: p.node.(*ast.EnumValueDefinition), 303 | }); err != nil { 304 | return err 305 | } 306 | } 307 | case *graphql.InputObjectConfig: 308 | if visitor.VisitInputObject != nil { 309 | if err := visitor.VisitInputObject(VisitInputObjectParams{ 310 | Context: c.ctx, 311 | Config: p.config.(*graphql.InputObjectConfig), 312 | Args: args, 313 | Node: p.node.(*ast.InputObjectDefinition), 314 | }); err != nil { 315 | return err 316 | } 317 | } 318 | case *graphql.InputObjectFieldConfig: 319 | if visitor.VisitInputFieldDefinition != nil { 320 | if err := visitor.VisitInputFieldDefinition(VisitInputFieldDefinitionParams{ 321 | Context: c.ctx, 322 | Config: p.config.(*graphql.InputObjectFieldConfig), 323 | Args: args, 324 | Node: p.node.(*ast.InputValueDefinition), 325 | }); err != nil { 326 | return err 327 | } 328 | } 329 | } 330 | } 331 | 332 | return nil 333 | } 334 | -------------------------------------------------------------------------------- /directives_test.go: -------------------------------------------------------------------------------- 1 | package tools 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/graphql-go/graphql" 7 | ) 8 | 9 | func TestDirectives(t *testing.T) { 10 | typeDefs := ` 11 | directive @test(message: String) on FIELD_DEFINITION 12 | 13 | type Foo { 14 | name: String! 
15 | description: String 16 | } 17 | 18 | type Query { 19 | foos( 20 | name: String 21 | ): [Foo] @test(message: "foobar") 22 | } 23 | ` 24 | 25 | // create some data 26 | foos := []map[string]interface{}{ 27 | { 28 | "name": "foo", 29 | "description": "a foo", 30 | }, 31 | } 32 | 33 | // make the schema 34 | schema, err := MakeExecutableSchema(ExecutableSchema{ 35 | TypeDefs: typeDefs, 36 | Resolvers: ResolverMap{ 37 | "Query": &ObjectResolver{ 38 | Fields: FieldResolveMap{ 39 | "foos": &FieldResolve{ 40 | Resolve: func(p graphql.ResolveParams) (interface{}, error) { 41 | return foos, nil 42 | }, 43 | }, 44 | }, 45 | }, 46 | }, 47 | SchemaDirectives: SchemaDirectiveVisitorMap{ 48 | "test": &SchemaDirectiveVisitor{ 49 | VisitFieldDefinition: func(v VisitFieldDefinitionParams) error { 50 | resolveFunc := v.Config.Resolve 51 | v.Config.Resolve = func(p graphql.ResolveParams) (interface{}, error) { 52 | result, err := resolveFunc(p) 53 | if err != nil { 54 | return result, err 55 | } 56 | res := result.([]map[string]interface{}) 57 | res0 := res[0] 58 | res0["description"] = v.Args["message"] 59 | return res, nil 60 | } 61 | 62 | return nil 63 | }, 64 | }, 65 | }, 66 | }) 67 | 68 | if err != nil { 69 | t.Error(err) 70 | return 71 | } 72 | 73 | // perform a query 74 | r := graphql.Do(graphql.Params{ 75 | Schema: schema, 76 | RequestString: `query Query { 77 | foos(name:"foo") { 78 | name 79 | description 80 | } 81 | }`, 82 | }) 83 | 84 | if r.HasErrors() { 85 | t.Error(r.Errors) 86 | return 87 | } 88 | 89 | d := r.Data.(map[string]interface{}) 90 | fooResult := d["foos"] 91 | foos0 := fooResult.([]interface{})[0] 92 | foos0Desc := foos0.(map[string]interface{})["description"] 93 | if foos0Desc.(string) != "foobar" { 94 | t.Error("failed to set field with directive") 95 | return 96 | } 97 | } 98 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | 
module github.com/bhoriuchi/graphql-go-tools 2 | 3 | go 1.13 4 | 5 | require ( 6 | github.com/google/uuid v1.3.0 7 | github.com/gorilla/websocket v1.4.2 8 | github.com/graphql-go/graphql v0.8.0 9 | ) 10 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= 2 | github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= 3 | github.com/gorilla/websocket v1.4.2 h1:+/TMaTYc4QFitKJxsQ7Yye35DkWvkdLcvGKqM+x0Ufc= 4 | github.com/gorilla/websocket v1.4.2/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= 5 | github.com/graphql-go/graphql v0.8.0 h1:JHRQMeQjofwqVvGwYnr8JnPTY0AxgVy1HpHSGPLdH0I= 6 | github.com/graphql-go/graphql v0.8.0/go.mod h1:nKiHzRM0qopJEwCITUuIsxk9PlVlwIiiI8pnJEhordQ= 7 | -------------------------------------------------------------------------------- /graphqlws/connections.go: -------------------------------------------------------------------------------- 1 | package graphqlws 2 | 3 | import ( 4 | "context" 5 | "encoding/json" 6 | "errors" 7 | "fmt" 8 | "sync" 9 | "time" 10 | 11 | "github.com/google/uuid" 12 | "github.com/gorilla/websocket" 13 | ) 14 | 15 | const ( 16 | // Constants for operation message types 17 | gqlConnectionAuth = "connection_auth" 18 | gqlConnectionInit = "connection_init" 19 | gqlConnectionAck = "connection_ack" 20 | gqlConnectionKeepAlive = "ka" 21 | gqlConnectionError = "connection_error" 22 | gqlConnectionTerminate = "connection_terminate" 23 | gqlStart = "start" 24 | gqlData = "data" 25 | gqlError = "error" 26 | gqlComplete = "complete" 27 | gqlStop = "stop" 28 | 29 | // Maximum size of incoming messages 30 | readLimit = 4096 31 | 32 | // Timeout for outgoing messages 33 | writeTimeout = 10 * time.Second 34 | ) 35 | 36 | // InitMessagePayload defines the parameters of a connection 37 | // init message. 
38 | type InitMessagePayload struct { 39 | AuthToken string `json:"authToken"` 40 | Authorization string `json:"Authorization"` 41 | } 42 | 43 | // StartMessagePayload defines the parameters of an operation that 44 | // a client requests to be started. 45 | type StartMessagePayload struct { 46 | Query string `json:"query"` 47 | Variables map[string]interface{} `json:"variables"` 48 | OperationName string `json:"operationName"` 49 | } 50 | 51 | // DataMessagePayload defines the result data of an operation. 52 | type DataMessagePayload struct { 53 | Data interface{} `json:"data"` 54 | Errors []error `json:"errors"` 55 | } 56 | 57 | // OperationMessage represents a GraphQL WebSocket message. 58 | type OperationMessage struct { 59 | ID string `json:"id"` 60 | Type string `json:"type"` 61 | Payload interface{} `json:"payload"` 62 | } 63 | 64 | func (msg OperationMessage) String() string { 65 | s, _ := json.Marshal(msg) 66 | if s != nil { 67 | return string(s) 68 | } 69 | return "" 70 | } 71 | 72 | // AuthenticateFunc is a function that resolves an auth token 73 | // into a user (or returns an error if that isn't possible). 74 | type AuthenticateFunc func(data map[string]interface{}, conn Connection) (context.Context, error) 75 | 76 | // ConnectionEventHandlers define the event handlers for a connection. 77 | // Event handlers allow other system components to react to events such 78 | // as the connection closing or an operation being started or stopped. 79 | type ConnectionEventHandlers struct { 80 | // Close is called whenever the connection is closed, regardless of 81 | // whether this happens because of an error or a deliberate termination 82 | // by the client. 83 | Close func(Connection) 84 | 85 | // StartOperation is called whenever the client demands that a GraphQL 86 | // operation be started (typically a subscription). 
Event handlers 87 | // are expected to take the necessary steps to register the operation 88 | // and send data back to the client with the results eventually. 89 | StartOperation func(Connection, string, *StartMessagePayload) []error 90 | 91 | // StopOperation is called whenever the client stops a previously 92 | // started GraphQL operation (typically a subscription). Event handlers 93 | // are expected to unregister the operation and stop sending result 94 | // data to the client. 95 | StopOperation func(Connection, string) 96 | } 97 | 98 | // ConnectionConfig defines the configuration parameters of a 99 | // GraphQL WebSocket connection. 100 | type ConnectionConfig struct { 101 | Logger Logger 102 | Authenticate AuthenticateFunc 103 | EventHandlers ConnectionEventHandlers 104 | } 105 | 106 | // Connection is an interface to represent GraphQL WebSocket connections. 107 | // Each connection is associated with an ID that is unique to the server. 108 | type Connection interface { 109 | // ID returns the unique ID of the connection. 110 | ID() string 111 | 112 | // Context returns the context for the connection 113 | Context() context.Context 114 | 115 | // WS the websocket 116 | WS() *websocket.Conn 117 | 118 | // SendData sends results of executing an operation (typically a 119 | // subscription) to the client. 120 | SendData(string, *DataMessagePayload) 121 | 122 | // SendError sends an error to the client. 123 | SendError(error) 124 | } 125 | 126 | /** 127 | * The default implementation of the Connection interface. 
128 | */ 129 | // connection is the Connection implementation returned by NewConnection. 130 | type connection struct { 131 | id string 132 | ws *websocket.Conn 133 | config ConnectionConfig 134 | logger Logger 135 | outgoing chan OperationMessage // drained by writeLoop; closed by close() 136 | closeMutex *sync.Mutex // serializes close() with SendData, SendError and sendOperationErrors 137 | closed bool 138 | context context.Context // replaced by the context returned from config.Authenticate 139 | } 140 | // operationMessageForType returns an OperationMessage with only its Type set. 141 | func operationMessageForType(messageType string) OperationMessage { 142 | return OperationMessage{ 143 | Type: messageType, 144 | } 145 | } 146 | 147 | // NewConnection establishes a GraphQL WebSocket connection. It implements 148 | // the GraphQL WebSocket protocol by managing its internal state and handling 149 | // the client-server communication. 150 | func NewConnection(ws *websocket.Conn, config ConnectionConfig) Connection { 151 | conn := new(connection) 152 | conn.id = uuid.New().String() 153 | conn.ws = ws 154 | conn.context = context.Background() 155 | conn.config = config 156 | conn.logger = config.Logger 157 | if conn.logger == nil { conn.logger = &noopLogger{} } // a nil Logger would panic on the first log call 158 | conn.closeMutex = &sync.Mutex{} 159 | conn.outgoing = make(chan OperationMessage) 160 | 161 | go conn.writeLoop() 162 | go conn.readLoop() 163 | conn.logger.Infof("Created connection") 164 | 165 | return conn 166 | } 167 | // ID returns the unique identifier generated for the connection. 168 | func (conn *connection) ID() string { 169 | return conn.id 170 | } 171 | // Context returns context.Background() until config.Authenticate succeeds, then the context it returned. 172 | func (conn *connection) Context() context.Context { 173 | return conn.context 174 | } 175 | // WS returns the underlying WebSocket connection. 176 | func (conn *connection) WS() *websocket.Conn { 177 | return conn.ws 178 | } 179 | // SendData queues a data message for operation opID unless the connection has been closed. 180 | func (conn *connection) SendData(opID string, data *DataMessagePayload) { 181 | msg := operationMessageForType(gqlData) 182 | msg.ID = opID 183 | msg.Payload = data 184 | conn.closeMutex.Lock() 185 | if !conn.closed { 186 | conn.outgoing <- msg 187 | } 188 | conn.closeMutex.Unlock() 189 | } 190 | // SendError queues a connection-level error message (no operation ID) unless the connection has been closed. 191 | func (conn *connection) SendError(err error) { 192 | msg := operationMessageForType(gqlError) 193 | msg.Payload = err.Error() 194 | conn.closeMutex.Lock() 195 | if !conn.closed { 196 | conn.outgoing <- msg 197 | } 198 | conn.closeMutex.Unlock() 199 | } 200 | // sendOperationErrors queues an error message carrying errs for operation opID unless the connection has been closed. 201 | func 
(conn *connection) sendOperationErrors(opID string, errs []error) { 202 | if conn.closed { 203 | return 204 | } 205 | 206 | msg := operationMessageForType(gqlError) 207 | msg.ID = opID 208 | msg.Payload = errs 209 | conn.closeMutex.Lock() 210 | if !conn.closed { 211 | conn.outgoing <- msg 212 | } 213 | 214 | conn.closeMutex.Unlock() 215 | } 216 | // close marks the connection closed, closes the outgoing channel (ending the write loop) and notifies the Close event handler. 217 | func (conn *connection) close() { 218 | // Close the write loop by closing the outgoing messages channel 219 | conn.closeMutex.Lock() 220 | conn.closed = true 221 | close(conn.outgoing) 222 | conn.closeMutex.Unlock() 223 | 224 | // Notify event handlers 225 | if conn.config.EventHandlers.Close != nil { 226 | conn.config.EventHandlers.Close(conn) 227 | } 228 | 229 | conn.logger.Infof("closed connection") 230 | } 231 | // writeLoop writes queued messages to the client until close() closes the outgoing channel. 232 | func (conn *connection) writeLoop() { 233 | // Close the WebSocket connection when leaving the write loop; 234 | // this ensures the read loop is also terminated and the connection 235 | // closed cleanly 236 | defer conn.ws.Close() 237 | failed := false 238 | 239 | for msg := range conn.outgoing { 240 | // After a failed write keep draining (and discarding) messages until close() closes the 241 | // channel, so senders blocked on it while holding closeMutex cannot deadlock close() 242 | if failed { 243 | continue 244 | } 245 | conn.logger.Debugf("send message: %s", msg.String()) 246 | conn.ws.SetWriteDeadline(time.Now().Add(writeTimeout)) 247 | 248 | // Send the message to the client; if this fails or times out the WebSocket 249 | // connection is corrupt, so close it, which makes the read loop fail and call close() 250 | if err := conn.ws.WriteJSON(msg); err != nil { 251 | conn.logger.Warnf("sending message failed: %s", err) 252 | failed = true 253 | conn.ws.Close() 254 | } 255 | } 256 | } 257 | 258 | // readLoop reads and dispatches client messages until reading fails or the client terminates the connection. 259 | func (conn *connection) readLoop() { 260 | // Close the WebSocket connection when leaving the read loop 261 | defer conn.ws.Close() 262 | conn.ws.SetReadLimit(readLimit) 263 | 264 | for { 265 | // Read the next message received from the client 266 | rawPayload := json.RawMessage{} 267 | msg := 
OperationMessage{ 268 | Payload: &rawPayload, 269 | } 270 | err := conn.ws.ReadJSON(&msg) 271 | 272 | // If this causes an error, close the connection and read loop immediately; 273 | // see https://github.com/gorilla/websocket/blob/master/conn.go#L924 for 274 | // more information on why this is necessary 275 | if err != nil { 276 | conn.logger.Warnf("force closing connection: %s", err) 277 | conn.close() 278 | return 279 | } 280 | 281 | conn.logger.Debugf("received message (%s): %s", msg.ID, msg.Type) 282 | 283 | switch msg.Type { 284 | case gqlConnectionAuth: 285 | data := map[string]interface{}{} 286 | if err := json.Unmarshal(rawPayload, &data); err != nil { 287 | conn.logger.Debugf("Invalid %s data: %v", msg.Type, err) 288 | conn.SendError(errors.New("invalid GQL_CONNECTION_AUTH payload")) 289 | } else { 290 | if conn.config.Authenticate != nil { 291 | ctx, err := conn.config.Authenticate(data, conn) 292 | if err != nil { 293 | msg := operationMessageForType(gqlConnectionError) 294 | msg.Payload = fmt.Sprintf("Failed to authenticate user: %v", err) 295 | conn.outgoing <- msg 296 | } else { 297 | conn.context = ctx 298 | } 299 | } 300 | } 301 | 302 | // When the GraphQL WS connection is initiated, send an ACK back 303 | case gqlConnectionInit: 304 | data := map[string]interface{}{} 305 | if err := json.Unmarshal(rawPayload, &data); err != nil { 306 | conn.logger.Debugf("Invalid %s data: %v", msg.Type, err) 307 | conn.SendError(errors.New("invalid GQL_CONNECTION_INIT payload")) 308 | } else { 309 | if conn.config.Authenticate != nil { 310 | ctx, err := conn.config.Authenticate(data, conn) 311 | if err != nil { 312 | msg := operationMessageForType(gqlConnectionError) 313 | msg.Payload = fmt.Sprintf("Failed to authenticate user: %v", err) 314 | conn.outgoing <- msg 315 | } else { 316 | conn.context = ctx 317 | conn.outgoing <- operationMessageForType(gqlConnectionAck) 318 | } 319 | } else { 320 | conn.outgoing <- operationMessageForType(gqlConnectionAck) 321 | }
322 | } 323 | 324 | // Let event handlers deal with starting operations 325 | case gqlStart: 326 | if conn.config.EventHandlers.StartOperation != nil { 327 | data := StartMessagePayload{} 328 | if err := json.Unmarshal(rawPayload, &data); err != nil { 329 | conn.SendError(errors.New("invalid GQL_START payload")) 330 | } else { 331 | errs := conn.config.EventHandlers.StartOperation(conn, msg.ID, &data) 332 | if errs != nil { 333 | conn.sendOperationErrors(msg.ID, errs) 334 | } 335 | } 336 | } 337 | 338 | // Let event handlers deal with stopping operations 339 | case gqlStop: 340 | if conn.config.EventHandlers.StopOperation != nil { 341 | conn.config.EventHandlers.StopOperation(conn, msg.ID) 342 | } 343 | 344 | // When the GraphQL WS connection is terminated by the client, 345 | // close the connection and close the read loop 346 | case gqlConnectionTerminate: 347 | conn.logger.Debugf("connection terminated by client") 348 | conn.close() 349 | return 350 | 351 | // GraphQL WS protocol messages that are not handled represent 352 | // a bug in our implementation; make this very obvious by logging 353 | // an error 354 | default: 355 | conn.logger.Errorf("unhandled message: %s", msg.String()) 356 | } 357 | } 358 | } 359 | -------------------------------------------------------------------------------- /graphqlws/logger.go: -------------------------------------------------------------------------------- 1 | package graphqlws 2 | // Logger is the printf-style logger used by connections and the WebSocket handler. 3 | type Logger interface { 4 | Infof(format string, data ...interface{}) 5 | Debugf(format string, data ...interface{}) 6 | Errorf(format string, data ...interface{}) 7 | Warnf(format string, data ...interface{}) 8 | } 9 | // noopLogger discards all log output; it is the fallback when no Logger is configured. 10 | type noopLogger struct{} 11 | 12 | func (n *noopLogger) Infof(format string, data ...interface{}) {} 13 | func (n *noopLogger) Debugf(format string, data ...interface{}) {} 14 | func (n *noopLogger) Errorf(format string, data ...interface{}) {} 15 | func (n *noopLogger) Warnf(format string, data ...interface{}) {} 16 | 
-------------------------------------------------------------------------------- /graphqlws/manager.go: -------------------------------------------------------------------------------- 1 | package graphqlws 2 | 3 | import ( 4 | "sync" 5 | 6 | "github.com/graphql-go/graphql" 7 | ) 8 | // ChanMgr tracks the active subscription operations, keyed by connection ID and then operation ID. 9 | type ChanMgr struct { 10 | mx sync.Mutex 11 | conns map[string]map[string]*ResultChan 12 | } 13 | // ResultChan holds the result channel of an operation and the function that stops it. 14 | type ResultChan struct { 15 | ch chan *graphql.Result 16 | cancel func() // nil when the operation was registered through Add 17 | } 18 | // Add registers the result channel ch of operation oid on connection cid without a way to stop it; prefer AddWithCancel. 19 | func (c *ChanMgr) Add(cid, oid string, ch chan *graphql.Result) { c.AddWithCancel(cid, oid, ch, nil) } 20 | 21 | // AddWithCancel registers operation oid of connection cid together with cancel, which Del and DelConn call to stop it. 22 | func (c *ChanMgr) AddWithCancel(cid, oid string, ch chan *graphql.Result, cancel func()) { 23 | c.mx.Lock() 24 | defer c.mx.Unlock() 25 | conn, ok := c.conns[cid] 26 | if !ok { 27 | conn = make(map[string]*ResultChan) 28 | c.conns[cid] = conn 29 | } 30 | conn[oid] = &ResultChan{ch: ch, cancel: cancel} 31 | } 32 | // DelConn stops and removes every operation of connection cid and reports whether cid was registered. 33 | func (c *ChanMgr) DelConn(cid string) bool { 34 | c.mx.Lock() 35 | defer c.mx.Unlock() 36 | 37 | conn, ok := c.conns[cid] 38 | if !ok { 39 | return false 40 | } 41 | 42 | for _, rc := range conn { 43 | rc.stop() 44 | } 45 | 46 | delete(c.conns, cid) 47 | return true 48 | } 49 | // Del stops and removes operation oid of connection cid and reports whether it was registered. 50 | func (c *ChanMgr) Del(cid, oid string) bool { 51 | c.mx.Lock() 52 | defer c.mx.Unlock() 53 | 54 | conn, ok := c.conns[cid] 55 | if !ok { 56 | return false 57 | } 58 | 59 | rc, ok := conn[oid] 60 | if !ok { 61 | return false 62 | } 63 | rc.stop() 64 | delete(conn, oid) 65 | if len(conn) == 0 { 66 | delete(c.conns, cid) 67 | } 68 | 69 | return true 70 | } 71 | 72 | // stop cancels the operation when a cancel function was registered for it. 73 | func (r *ResultChan) stop() { 74 | if r.cancel != nil { 75 | r.cancel() 76 | } 77 | } 78 | -------------------------------------------------------------------------------- /graphqlws/wshandler.go: -------------------------------------------------------------------------------- 1 | package graphqlws 2 | 3 | import ( 4 | "context" 5 | "net/http" 6 | 7 | "github.com/gorilla/websocket" 8 | "github.com/graphql-go/graphql" 9 | ) 10 | 11 | // ConnKey is the context key under which StartOperation stores the Connection in each subscription's context. 12 | var ConnKey interface{} = "conn" 13 | 14 | // HandlerConfig configures the handler returned by NewHandler. 15 | type HandlerConfig struct { 16 | Logger Logger 17 | Authenticate AuthenticateFunc 18 | Schema graphql.Schema 19 | RootValue
map[string]interface{} 20 | } 21 | 22 | // NewHandler returns an http.Handler that upgrades requests to graphql-ws WebSocket connections and runs the subscriptions they start against config.Schema. 23 | func NewHandler(config HandlerConfig) http.Handler { 24 | var upgrader = websocket.Upgrader{ 25 | CheckOrigin: func(r *http.Request) bool { return true }, // NOTE(review): accepts any Origin (cross-site WebSocket requests); restrict this if the endpoint relies on cookies 26 | Subprotocols: []string{"graphql-ws"}, 27 | } 28 | 29 | mgr := &ChanMgr{ 30 | conns: make(map[string]map[string]*ResultChan), 31 | } 32 | 33 | if config.Logger == nil { 34 | config.Logger = &noopLogger{} 35 | } 36 | 37 | return http.HandlerFunc( 38 | func(w http.ResponseWriter, r *http.Request) { 39 | // Establish a WebSocket connection 40 | var ws, err = upgrader.Upgrade(w, r, nil) 41 | 42 | // Bail out if the WebSocket connection could not be established 43 | if err != nil { 44 | config.Logger.Warnf("Failed to establish WebSocket connection: %v", err) 45 | return 46 | } 47 | 48 | // Close the connection early if it doesn't implement the graphql-ws protocol 49 | if ws.Subprotocol() != "graphql-ws" { 50 | config.Logger.Warnf("Connection does not implement the GraphQL WS protocol") 51 | ws.Close() 52 | return 53 | } 54 | 55 | // Establish a GraphQL WebSocket connection 56 | NewConnection(ws, ConnectionConfig{ 57 | Authenticate: config.Authenticate, 58 | Logger: config.Logger, 59 | EventHandlers: ConnectionEventHandlers{ 60 | Close: func(conn Connection) { 61 | config.Logger.Debugf("closing websocket: %s", conn.ID()) 62 | mgr.DelConn(conn.ID()) 63 | }, 64 | StartOperation: func( 65 | conn Connection, 66 | opID string, 67 | data *StartMessagePayload, 68 | ) []error { 69 | config.Logger.Debugf("start operations %s on connection %s", opID, conn.ID()) 70 | // cancel is stored in mgr so that stopping the operation or closing the connection ends the subscription 71 | ctx, cancel := context.WithCancel(context.WithValue(context.Background(), ConnKey, conn)) 72 | resultChannel := graphql.Subscribe(graphql.Params{ 73 | Schema: config.Schema, 74 | RequestString: data.Query, 75 | VariableValues: data.Variables, 76 | OperationName: data.OperationName, 77 | Context: ctx, 78 | RootObject: config.RootValue, 79 | }) 80 | 81 | mgr.AddWithCancel(conn.ID(), opID, resultChannel, cancel) 82 | 83 | go func()
{ 84 | // Drain resultChannel until graphql.Subscribe closes it. Results that arrive after the 85 | // operation was stopped (ctx cancelled) are discarded rather than sent, and draining keeps 86 | // the producer from blocking on a send that nobody would receive 87 | for res := range resultChannel { 88 | if ctx.Err() != nil { 89 | continue 90 | } 91 | 92 | 93 | 94 | errs := []error{} 95 | 96 | if res.HasErrors() { 97 | for _, err := range res.Errors { 98 | config.Logger.Debugf("subscription_error: %v", err) 99 | errs = append(errs, err.OriginalError()) 100 | } 101 | } 102 | 103 | conn.SendData(opID, &DataMessagePayload{ 104 | Data: res.Data, 105 | Errors: errs, 106 | }) 107 | } 108 | 109 | }() 110 | 111 | return nil 112 | }, 113 | StopOperation: func(conn Connection, opID string) { 114 | config.Logger.Debugf("stop operation %s on connection %s", opID, conn.ID()) 115 | mgr.Del(conn.ID(), opID) 116 | }, 117 | }, 118 | }) 119 | }, 120 | ) 121 | } 122 | -------------------------------------------------------------------------------- /handler/README.md: -------------------------------------------------------------------------------- 1 | # graphql-go-tools-handler 2 | 3 | Fork of [https://github.com/graphql-go/handler](https://github.com/graphql-go/handler) with some changes 4 | 5 | ### Usage 6 | 7 | ```go 8 | func main() { 9 | schema, _ := graphql.NewSchema(...) 10 | 11 | h := handler.New(&handler.Config{ 12 | Schema: &schema, 13 | Pretty: true, 14 | GraphiQLConfig: handler.NewDefaultGraphiQLConfig(), 15 | }) 16 | 17 | http.Handle("/graphql", h) 18 | http.ListenAndServe(":8080", nil) 19 | } 20 | ``` 21 | 22 | ### Using Playground 23 | ```go 24 | h := handler.New(&handler.Config{ 25 | Schema: &schema, 26 | Pretty: true, 27 | PlaygroundConfig: handler.NewDefaultPlaygroundConfig(), 28 | }) 29 | ``` 30 | 31 | ### Details 32 | 33 | The handler will accept requests with 34 | the parameters: 35 | 36 | * **`query`**: A string GraphQL document to be executed. 37 | 38 | * **`variables`**: The runtime values to use for any GraphQL query variables 39 | as a JSON object.
40 | 41 | * **`operationName`**: If the provided `query` contains multiple named 42 | operations, this specifies which operation should be executed. If it is not 43 | provided and the `query` contains multiple named operations, the result 44 | will contain an error. 45 | 46 | The handler first looks for the parameters in the URL's query-string and uses them when `query` is present there: 47 | 48 | ``` 49 | /graphql?query=query+getUser($id:ID){user(id:$id){name}}&variables={"id":"4"} 50 | ``` 51 | 52 | Otherwise, for POST requests, it looks in the request body. 53 | The `handler` will interpret it 54 | depending on the provided `Content-Type` header. 55 | 56 | * **`application/json`** (also used for any other content type): the POST body will be parsed as a JSON 57 | object of parameters. 58 | 59 | * **`application/x-www-form-urlencoded`**: this POST body will be 60 | parsed as a url-encoded string of key-value pairs. 61 | 62 | * **`application/graphql`**: The POST body will be parsed as a GraphQL 63 | query string, which provides the `query` parameter. 64 | -------------------------------------------------------------------------------- /handler/graphiql.go: -------------------------------------------------------------------------------- 1 | package handler 2 | 3 | import ( 4 | "encoding/json" 5 | "fmt" 6 | "html/template" 7 | "net/http" 8 | "net/url" 9 | "path" 10 | 11 | "github.com/graphql-go/graphql" 12 | ) 13 | 14 | // GraphiQLConfig configures the GraphiQL page; empty fields fall back to GraphiqlVersion and to endpoints derived from the request. 15 | type GraphiQLConfig struct { 16 | Version string 17 | Endpoint string 18 | SubscriptionEndpoint string 19 | } 20 | 21 | // NewDefaultGraphiQLConfig creates a new default config 22 | func NewDefaultGraphiQLConfig() *GraphiQLConfig { 23 | return &GraphiQLConfig{ 24 | Version: GraphiqlVersion, 25 | Endpoint: "", 26 | SubscriptionEndpoint: "", 27 | } 28 | } 29 | 30 | // graphiqlData is the page data structure of the rendered GraphiQL page 31 | type graphiqlData struct { 32 | Endpoint string 33 | SubscriptionEndpoint string 34 | GraphiqlVersion string 35 | QueryString string 36 | 
VariablesString string 37 | OperationName string 38 | ResultString string 39 | } 40 | 41 | // renderGraphiQL renders the GraphiQL GUI pre-filled with the request's query, variables and operation name, executing the query (when present) to show its result. 42 | func renderGraphiQL(config *GraphiQLConfig, w http.ResponseWriter, r *http.Request, params graphql.Params) { 43 | t := template.New("GraphiQL") 44 | t, err := t.Parse(graphiqlTemplate) 45 | if err != nil { 46 | http.Error(w, err.Error(), http.StatusInternalServerError) 47 | return 48 | } 49 | 50 | // Create variables string 51 | vars, err := json.MarshalIndent(params.VariableValues, "", " ") 52 | if err != nil { 53 | http.Error(w, err.Error(), http.StatusInternalServerError) 54 | return 55 | } 56 | varsString := string(vars) 57 | if varsString == "null" { 58 | varsString = "" 59 | } 60 | 61 | // Create result string 62 | var resString string 63 | if params.RequestString != "" { 64 | result, err := json.MarshalIndent(graphql.Do(params), "", " ") 65 | if err != nil { 66 | http.Error(w, err.Error(), http.StatusInternalServerError) 67 | return 68 | } 69 | resString = string(result) 70 | } 71 | version := GraphiqlVersion 72 | if config.Version != "" { 73 | version = config.Version 74 | } 75 | endpoint := r.URL.Path 76 | if config.Endpoint != "" { 77 | endpoint = config.Endpoint 78 | } 79 | wsScheme := "ws" 80 | if r.TLS != nil { 81 | wsScheme = "wss" // an https page cannot open a plain ws:// socket (mixed content) 82 | } 83 | subscriptionEndpoint := fmt.Sprintf("%s://%v%s", wsScheme, r.Host, path.Join(path.Dir(r.URL.Path), "subscriptions")) 84 | if config.SubscriptionEndpoint != "" { 85 | if u, err := url.Parse(config.SubscriptionEndpoint); err == nil && u.IsAbs() { 86 | subscriptionEndpoint = config.SubscriptionEndpoint 87 | } else { 88 | // A path on the current host: only the path is joined, since path.Join would collapse the "//" after the scheme 89 | subscriptionEndpoint = fmt.Sprintf("%s://%v%s", wsScheme, r.Host, path.Join("/", config.SubscriptionEndpoint)) 90 | } 91 | } 92 | d := graphiqlData{ 93 | GraphiqlVersion: version, 94 | QueryString: params.RequestString, 95 | ResultString: resString, 96 | VariablesString: varsString, 97 | OperationName: params.OperationName, 98 | Endpoint: endpoint, 99 | SubscriptionEndpoint: subscriptionEndpoint, 100 | } 101 | err = t.ExecuteTemplate(w, "index",
d) 102 | if err != nil { 103 | http.Error(w, err.Error(), http.StatusInternalServerError) 104 | } 105 | 106 | 107 | } 108 | 109 | // GraphiqlVersion is the default GraphiQL version, used when GraphiQLConfig.Version is empty 110 | var GraphiqlVersion = "0.13.2" 111 | 112 | const graphiqlTemplate = ` 113 | {{ define "index" }} 114 | 121 | 122 | 123 | 124 | 135 | 136 | 143 | 144 | 145 | 146 | 147 | 148 | 149 | 154 | 155 | 156 | 157 | 158 | 159 |
Loading...
160 | 299 | 300 | 301 | {{ end }} 302 | ` 303 | -------------------------------------------------------------------------------- /handler/handler.go: -------------------------------------------------------------------------------- 1 | package handler 2 | 3 | import ( 4 | "context" 5 | "encoding/json" 6 | "io/ioutil" 7 | "net/http" 8 | "net/url" 9 | "strings" 10 | 11 | "github.com/graphql-go/graphql" 12 | "github.com/graphql-go/graphql/gqlerrors" 13 | ) 14 | 15 | // Content types understood by NewRequestOptions 16 | const ( 17 | ContentTypeJSON = "application/json" 18 | ContentTypeGraphQL = "application/graphql" 19 | ContentTypeFormURLEncoded = "application/x-www-form-urlencoded" 20 | ) 21 | 22 | // ResultCallbackFn is called after a JSON response is written, with the executed params, the result and the encoded response body 23 | type ResultCallbackFn func(ctx context.Context, params *graphql.Params, result *graphql.Result, responseBody []byte) 24 | 25 | // Handler executes GraphQL requests over HTTP and can serve GraphiQL or GraphQL Playground to browsers 26 | type Handler struct { 27 | Schema *graphql.Schema 28 | pretty bool 29 | graphiqlConfig *GraphiQLConfig 30 | playgroundConfig *PlaygroundConfig 31 | rootObjectFn RootObjectFn 32 | resultCallbackFn ResultCallbackFn 33 | formatErrorFn func(err error) gqlerrors.FormattedError 34 | } 35 | 36 | // RequestOptions holds the parameters of a GraphQL request 37 | type RequestOptions struct { 38 | Query string `json:"query" url:"query" schema:"query"` 39 | Variables map[string]interface{} `json:"variables" url:"variables" schema:"variables"` 40 | OperationName string `json:"operationName" url:"operationName" schema:"operationName"` 41 | } 42 | 43 | // a workaround for getting `variables` as a JSON string 44 | type requestOptionsCompatibility struct { 45 | Query string `json:"query" url:"query" schema:"query"` 46 | Variables string `json:"variables" url:"variables" schema:"variables"` 47 | OperationName string `json:"operationName" url:"operationName" schema:"operationName"` 48 | } 49 | // getFromForm builds request options from form values; it returns nil when they contain no query. 50 | func getFromForm(values url.Values) *RequestOptions { 51 | query := values.Get("query") 52 | if query != "" { 53 | // get variables map 54 | variables :=
make(map[string]interface{}, len(values)) 55 | variablesStr := values.Get("variables") 56 | json.Unmarshal([]byte(variablesStr), &variables) // best effort: missing or malformed variables are ignored 57 | 58 | return &RequestOptions{ 59 | Query: query, 60 | Variables: variables, 61 | OperationName: values.Get("operationName"), 62 | } 63 | } 64 | 65 | return nil 66 | } 67 | 68 | // NewRequestOptions parses an http.Request into GraphQL request options: the URL query string is used when it contains a query, otherwise the body of a POST request is decoded according to its Content-Type. 69 | func NewRequestOptions(r *http.Request) *RequestOptions { 70 | if reqOpt := getFromForm(r.URL.Query()); reqOpt != nil { 71 | return reqOpt 72 | } 73 | 74 | if r.Method != http.MethodPost { 75 | return &RequestOptions{} 76 | } 77 | 78 | if r.Body == nil { 79 | return &RequestOptions{} 80 | } 81 | 82 | // Only the media type matters; parameters such as charset are ignored 83 | contentTypeStr := r.Header.Get("Content-Type") 84 | contentTypeTokens := strings.Split(contentTypeStr, ";") 85 | contentType := strings.ToLower(strings.TrimSpace(contentTypeTokens[0])) // media types are case-insensitive 86 | 87 | switch contentType { 88 | case ContentTypeGraphQL: 89 | body, err := ioutil.ReadAll(r.Body) 90 | if err != nil { 91 | return &RequestOptions{} 92 | } 93 | return &RequestOptions{ 94 | Query: string(body), 95 | } 96 | case ContentTypeFormURLEncoded: 97 | if err := r.ParseForm(); err != nil { 98 | return &RequestOptions{} 99 | } 100 | 101 | if reqOpt := getFromForm(r.PostForm); reqOpt != nil { 102 | return reqOpt 103 | } 104 | 105 | return &RequestOptions{} 106 | 107 | case ContentTypeJSON: 108 | fallthrough 109 | default: 110 | var opts RequestOptions 111 | body, err := ioutil.ReadAll(r.Body) 112 | if err != nil { 113 | return &opts 114 | } 115 | err = json.Unmarshal(body, &opts) 116 | if err != nil { 117 | // Probably `variables` was sent as a string instead of an object.
118 | // So, we try to be polite and try to parse that as a JSON string 119 | var optsCompatible requestOptionsCompatibility 120 | json.Unmarshal(body, &optsCompatible) 121 | json.Unmarshal([]byte(optsCompatible.Variables), &opts.Variables) 122 | } 123 | return &opts 124 | } 125 | } 126 | 127 | // ContextHandler provides an entrypoint into executing graphQL queries with a 128 | // user-provided context; browsers that ask for HTML get GraphiQL or Playground instead when configured. 129 | func (h *Handler) ContextHandler(ctx context.Context, w http.ResponseWriter, r *http.Request) { 130 | // get query 131 | opts := NewRequestOptions(r) 132 | 133 | // build the graphql parameters 134 | params := graphql.Params{ 135 | Schema: *h.Schema, 136 | RequestString: opts.Query, 137 | VariableValues: opts.Variables, 138 | OperationName: opts.OperationName, 139 | Context: ctx, 140 | } 141 | if h.rootObjectFn != nil { 142 | params.RootObject = h.rootObjectFn(ctx, r) 143 | } 144 | 145 | // Browsers asking for HTML get an IDE instead of JSON. This is decided before 146 | // executing the query: renderGraphiQL executes it itself and the Playground needs 147 | // no result, so the query (possibly a mutation) is never executed twice 148 | acceptHeader := r.Header.Get("Accept") 149 | _, raw := r.URL.Query()["raw"] 150 | wantsHTML := !raw && !strings.Contains(acceptHeader, "application/json") && strings.Contains(acceptHeader, "text/html") 151 | if wantsHTML && h.graphiqlConfig != nil { 152 | renderGraphiQL(h.graphiqlConfig, w, r, params) 153 | return 154 | } 155 | if wantsHTML && h.playgroundConfig != nil { 156 | renderPlayground(h.playgroundConfig, w, r) 157 | return 158 | } 159 | 160 | // execute graphql query 161 | result := graphql.Do(params) 162 | 163 | if formatErrorFn := h.formatErrorFn; formatErrorFn != nil && len(result.Errors) > 0 { 164 | formatted := make([]gqlerrors.FormattedError, len(result.Errors)) 165 | for i, formattedError := range result.Errors { 166 | formatted[i] = formatErrorFn(formattedError.OriginalError()) 167 | } 168 | result.Errors = formatted 169 | } 170 | 171 | var buff []byte 172 | var err error
173 | if h.pretty { 174 | buff, err = json.MarshalIndent(result, "", "\t") 175 | } else { 176 | buff, err = json.Marshal(result) 177 | } 178 | if err != nil { 179 | http.Error(w, err.Error(), http.StatusInternalServerError) 180 | return 181 | } 182 | 183 | // use proper JSON Header 184 | w.Header().Set("Content-Type", "application/json; charset=utf-8") 185 | w.WriteHeader(http.StatusOK) 186 | w.Write(buff) 187 | 188 | if h.resultCallbackFn != nil { 189 | h.resultCallbackFn(ctx, &params, result, buff) 190 | } 191 | } 192 | 193 | // ServeHTTP provides an entrypoint into executing graphQL queries. 194 | func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { 195 | h.ContextHandler(r.Context(), w, r) 196 | } 197 | 198 | // RootObjectFn allows a user to generate a RootObject per request 199 | type RootObjectFn func(ctx context.Context, r *http.Request) map[string]interface{} 200 | 201 | // Config configures the Handler created by New 202 | type Config struct { 203 | Schema *graphql.Schema 204 | Pretty bool 205 | GraphiQLConfig *GraphiQLConfig 206 | PlaygroundConfig *PlaygroundConfig 207 | RootObjectFn RootObjectFn 208 | ResultCallbackFn ResultCallbackFn 209 | FormatErrorFn func(err error) gqlerrors.FormattedError 210 | } 211 | 212 | // NewConfig returns a new default config (pretty output, Playground enabled, no schema) 213 | func NewConfig() *Config { 214 | return &Config{ 215 | Schema: nil, 216 | Pretty: true, 217 | GraphiQLConfig: nil, 218 | PlaygroundConfig: NewDefaultPlaygroundConfig(), 219 | } 220 | } 221 | 222 | // New creates a Handler from p, or from NewConfig() when p is nil; it panics when no Schema is configured. 223 | func New(p *Config) *Handler { 224 | if p == nil { 225 | p = NewConfig() 226 | } 227 | 228 | if p.Schema == nil { 229 | panic("undefined GraphQL schema") 230 | } 231 | 232 | return &Handler{ 233 | Schema: p.Schema, 234 | pretty: p.Pretty, 235 | graphiqlConfig: p.GraphiQLConfig, 236 | playgroundConfig: p.PlaygroundConfig, 237 | rootObjectFn: p.RootObjectFn, 238 | resultCallbackFn: p.ResultCallbackFn, 239 | formatErrorFn: p.FormatErrorFn, 240 | } 241 | } 242 | -------------------------------------------------------------------------------- 
/handler/playground.go: -------------------------------------------------------------------------------- 1 | package handler 2 | 3 | import ( 4 | "fmt" 5 | "html/template" 6 | "net/http" 7 | "net/url" 8 | "path" 9 | ) 10 | 11 | // PlaygroundConfig configures the GraphQL Playground page; empty fields fall back to PlaygroundVersion and to endpoints derived from the request. 12 | type PlaygroundConfig struct { 13 | Endpoint string 14 | SubscriptionEndpoint string 15 | Version string 16 | } 17 | 18 | // NewDefaultPlaygroundConfig creates a new default config 19 | func NewDefaultPlaygroundConfig() *PlaygroundConfig { 20 | return &PlaygroundConfig{ 21 | Endpoint: "", 22 | SubscriptionEndpoint: "", 23 | Version: PlaygroundVersion, 24 | } 25 | } 26 | // playgroundData is the template data of the rendered Playground page 27 | type playgroundData struct { 28 | PlaygroundVersion string 29 | Endpoint string 30 | SubscriptionEndpoint string 31 | SetTitle bool 32 | } 33 | 34 | // renderPlayground renders the Playground GUI 35 | func renderPlayground(config *PlaygroundConfig, w http.ResponseWriter, r *http.Request) { 36 | t := template.New("Playground") 37 | t, err := t.Parse(graphcoolPlaygroundTemplate) 38 | if err != nil { 39 | http.Error(w, err.Error(), http.StatusInternalServerError) 40 | return 41 | } 42 | 43 | endpoint := r.URL.Path 44 | if config.Endpoint != "" { 45 | endpoint = config.Endpoint 46 | } 47 | wsScheme := "ws" 48 | if r.TLS != nil { 49 | wsScheme = "wss" // an https page cannot open a plain ws:// socket (mixed content) 50 | } 51 | subscriptionEndpoint := fmt.Sprintf("%s://%v%s", wsScheme, r.Host, path.Join(path.Dir(r.URL.Path), "subscriptions")) 52 | if config.SubscriptionEndpoint != "" { 53 | if u, err := url.Parse(config.SubscriptionEndpoint); err == nil && u.IsAbs() { 54 | subscriptionEndpoint = config.SubscriptionEndpoint 55 | } else { 56 | // A path on the current host: only the path is joined, since path.Join would collapse the "//" after the scheme 57 | subscriptionEndpoint = fmt.Sprintf("%s://%v%s", wsScheme, r.Host, path.Join("/", config.SubscriptionEndpoint)) 58 | } 59 | } 60 | 61 | version := PlaygroundVersion 62 | if config.Version != "" { 63 | version = config.Version 64 | } 65 | 66 | d := playgroundData{ 67 | PlaygroundVersion: version, 68 | Endpoint: endpoint, 69 | SubscriptionEndpoint: subscriptionEndpoint, 70 | SetTitle: true, 71 | 
} 72 | err = t.ExecuteTemplate(w, "index", d) 73 | if err != nil { 74 | http.Error(w, err.Error(), http.StatusInternalServerError) 75 | } 76 | 77 | 78 | } 79 | 80 | // PlaygroundVersion is the default GraphQL Playground version, used when PlaygroundConfig.Version is empty 81 | var PlaygroundVersion = "1.7.20" 82 | 83 | const graphcoolPlaygroundTemplate = ` 84 | {{ define "index" }} 85 | 93 | 94 | 95 | 96 | 97 | 98 | 99 | GraphQL Playground 100 | 101 | 102 | 103 | 104 | 105 | 106 |
107 | 134 | 135 |
Loading 136 | GraphQL Playground 137 |
138 |
139 | 147 | 148 | 149 | 150 | {{ end }} 151 | ` 152 | -------------------------------------------------------------------------------- /helpers.go: -------------------------------------------------------------------------------- 1 | package tools 2 | 3 | import ( 4 | "fmt" 5 | "io/ioutil" 6 | "os" 7 | "path/filepath" 8 | "strings" 9 | 10 | "github.com/graphql-go/graphql" 11 | "github.com/graphql-go/graphql/language/ast" 12 | "github.com/graphql-go/graphql/language/kinds" 13 | ) 14 | 15 | // gets the field resolve function for a field, falling back to graphql.DefaultResolveFn when no resolver field is registered 16 | func (c *registry) getFieldResolveFn(kind, typeName, fieldName string) graphql.FieldResolveFn { 17 | if r := c.getResolver(typeName); r != nil && kind == r.getKind() { 18 | switch kind { 19 | case kinds.ObjectDefinition: 20 | if fn, ok := r.(*ObjectResolver).Fields[fieldName]; ok { 21 | return fn.Resolve 22 | } 23 | case kinds.InterfaceDefinition: 24 | if fn, ok := r.(*InterfaceResolver).Fields[fieldName]; ok { 25 | return fn.Resolve 26 | } 27 | } 28 | } 29 | return graphql.DefaultResolveFn 30 | } 31 | // gets the field subscribe function for a field, or nil when no resolver field is registered 32 | func (c *registry) getFieldSubscribeFn(kind, typeName, fieldName string) graphql.FieldResolveFn { 33 | if r := c.getResolver(typeName); r != nil && kind == r.getKind() { 34 | switch kind { 35 | case kinds.ObjectDefinition: 36 | if fieldResolve, ok := r.(*ObjectResolver).Fields[fieldName]; ok { 37 | return fieldResolve.Subscribe 38 | } 39 | case kinds.InterfaceDefinition: 40 | if fieldResolve, ok := r.(*InterfaceResolver).Fields[fieldName]; ok { 41 | return fieldResolve.Subscribe 42 | } 43 | } 44 | } 45 | return nil 46 | } 47 | 48 | // Recursively builds a complex type 49 | func (c *registry) buildComplexType(astType ast.Type) (graphql.Type, error) { 50 | switch kind := astType.GetKind(); kind { 51 | case kinds.List: 52 | t, err := c.buildComplexType(astType.(*ast.List).Type) 53 | if err != nil { 54 | return nil, err 55 | } 56 | return graphql.NewList(t), nil 57 | 58 | case kinds.NonNull: 59 | t, err :=
c.buildComplexType(astType.(*ast.NonNull).Type) 60 | if err != nil { 61 | return nil, err 62 | } 63 | return graphql.NewNonNull(t), nil 64 | 65 | case kinds.Named: 66 | t := astType.(*ast.Named) 67 | return c.getType(t.Name.Value) 68 | default: 69 | return nil, fmt.Errorf("invalid kind %q", kind) 70 | } 71 | } 72 | 73 | // gets the description or defaults to an empty string 74 | func getDescription(node ast.DescribableNode) string { 75 | if desc := node.GetDescription(); desc != nil { 76 | return desc.Value 77 | } 78 | return "" 79 | } 80 | // parseDefaultValue converts an AST default value for inputType: non-null types are unwrapped, lists are parsed element-wise, built-in scalars are coerced and anything else is returned unchanged 81 | func parseDefaultValue(inputType ast.Type, value interface{}) (interface{}, error) { 82 | if value == nil { 83 | return nil, nil 84 | } 85 | 86 | switch t := inputType.(type) { 87 | // non-null call parse on type 88 | case *ast.NonNull: 89 | return parseDefaultValue(t.Type, value) 90 | 91 | // list parse each item in the list 92 | case *ast.List: 93 | switch a := value.(type) { 94 | case []ast.Value: 95 | arr := []interface{}{} 96 | for _, v := range a { 97 | val, err := parseDefaultValue(t.Type, v.GetValue()) 98 | if err != nil { 99 | return nil, err 100 | } 101 | 102 | arr = append(arr, val) 103 | } 104 | return arr, nil 105 | } 106 | 107 | // parse the specific type 108 | case *ast.Named: 109 | switch t.Name.Value { 110 | case "Int": 111 | value = graphql.Int.ParseValue(value) 112 | case "Float": 113 | value = graphql.Float.ParseValue(value) 114 | case "Boolean": 115 | value = graphql.Boolean.ParseValue(value) 116 | case "ID": 117 | value = graphql.ID.ParseValue(value) 118 | case "String": 119 | value = graphql.String.ParseValue(value) 120 | } 121 | } 122 | 123 | return value, nil 124 | } 125 | 126 | // gets the default value or defaults to nil 127 | func getDefaultValue(input *ast.InputValueDefinition) (interface{}, error) { 128 | if input.DefaultValue == nil { 129 | return nil, nil 130 | } 131 | 132 | defaultValue, err := parseDefaultValue(input.Type, input.DefaultValue.GetValue()) 133 | if err != nil { 134 | return nil, err 135 | } 136 |
137 | return defaultValue, nil 138 | } 139 | 140 | // ReadSourceFiles reads the .gql and .graphql files in directory p (and, when recursive[0] is true, in its subdirectories) and returns their contents joined with newlines. 141 | func ReadSourceFiles(p string, recursive ...bool) (string, error) { 142 | typeDefs := []string{} 143 | abs, err := filepath.Abs(p) 144 | if err != nil { 145 | return "", err 146 | } 147 | 148 | var readFunc = func(p string, info os.FileInfo, err error) error { 149 | if err != nil || info.IsDir() { // filepath.Walk passes a non-nil err (and possibly a nil info) when it cannot visit p 150 | return err 151 | } 152 | 153 | switch ext := strings.ToLower(filepath.Ext(info.Name())); ext { 154 | case ".gql", ".graphql": 155 | data, err := ioutil.ReadFile(p) 156 | if err != nil { 157 | return err 158 | } 159 | typeDefs = append(typeDefs, string(data)) 160 | return nil 161 | default: 162 | return nil 163 | } 164 | } 165 | 166 | if len(recursive) > 0 && recursive[0] { 167 | if err := filepath.Walk(abs, readFunc); err != nil { 168 | return "", err 169 | } 170 | } else { 171 | files, err := ioutil.ReadDir(abs) 172 | if err != nil { 173 | return "", err 174 | } 175 | for _, file := range files { 176 | if err := readFunc(filepath.Join(abs, file.Name()), file, nil); err != nil { // readFunc reads the file's own path, not the directory's 177 | return "", err 178 | } 179 | } 180 | } 181 | 182 | result := strings.Join(typeDefs, "\n") 183 | return result, nil 184 | } 185 | 186 | // UnaliasedPathArray returns the response path of info with every alias replaced by its field name 187 | func UnaliasedPathArray(info graphql.ResolveInfo) []interface{} { 188 | return unaliasedPathArray(info.Operation.GetSelectionSet(), info.Path.AsArray(), []interface{}{}) 189 | } 190 | 191 | // gets the actual field path for a selection by replacing each response key (the alias, or the name when there is none) with the field name; list indices are kept 192 | func unaliasedPathArray(set *ast.SelectionSet, remaining []interface{}, current []interface{}) []interface{} { 193 | if len(remaining) == 0 || set == nil { 194 | return current 195 | } 196 | if _, isIndex := remaining[0].(int); isIndex { // a list index addresses an item of the current field: keep it and stay in the same selection set 197 | return unaliasedPathArray(set, remaining[1:], append(current, remaining[0])) 198 | } 199 | for _, sel := range set.Selections {
200 | if field, ok := sel.(*ast.Field); ok { 201 | key := field.Name.Value 202 | if field.Alias != nil { 203 | key = field.Alias.Value // an aliased field appears in the path under its alias 204 | } 205 | if key == remaining[0] { 206 | return unaliasedPathArray(field.GetSelectionSet(), remaining[1:], append(current, field.Name.Value)) 207 | } 208 | } 209 | } 210 | return current 211 | } 212 | 213 | // GetPathFieldSubSelections returns the names of the fields selected below the sub-field reached by following the `field` names from the resolved field; inline fragments are flattened, and an empty list is returned when a path segment is not selected or has no sub-selections 214 | func GetPathFieldSubSelections(info graphql.ResolveInfo, field ...string) (names []string, err error) { 215 | names = []string{} 216 | if len(info.FieldASTs) == 0 { 217 | return 218 | } 219 | fieldAST := info.FieldASTs[0] 220 | if fieldAST.GetSelectionSet() == nil { 221 | return 222 | } 223 | // descend one level per path segment, stopping at the first matching field 224 | for _, f := range field { 225 | var next *ast.Field 226 | search: 227 | for _, sel := range fieldAST.GetSelectionSet().Selections { 228 | switch s := sel.(type) { 229 | case *ast.InlineFragment: 230 | for _, ss := range s.GetSelectionSet().Selections { 231 | if subField, ok := ss.(*ast.Field); ok && subField.Name.Value == f { 232 | next = subField 233 | break search 234 | } 235 | } 236 | case *ast.Field: 237 | if s.Name.Value == f { 238 | next = s 239 | break search 240 | } 241 | } 242 | } 243 | // the segment is not selected, or it is a leaf without sub-selections 244 | if next == nil || next.GetSelectionSet() == nil { 245 | return 246 | } 247 | fieldAST = next 248 | } 249 | // collect the sub-selection names, flattening inline fragments 250 | for _, sel := range fieldAST.GetSelectionSet().Selections { 251 | switch s := sel.(type) { 252 | case *ast.InlineFragment: 253 | for _, ss := range s.GetSelectionSet().Selections { 254 | if subField, ok := ss.(*ast.Field); ok { 255 | names = append(names, subField.Name.Value) 256 | } 257 | } 258 | case *ast.Field: 259 | names = append(names, s.Name.Value) 260 | } 261 | } 262 | return 263 | } 264 | 265 | // determines if a field is hidden, i.e. carries the directiveHide directive 266 | func isHiddenField(field *ast.FieldDefinition) bool { 267 | for _, dir := range field.Directives { 268 | if dir.Name.Value == directiveHide { 269 | return true 270 | } 271 | } 272 | return false 273 | } 274 | 275 | // MergeExtensions returns a new object definition holding obj's interfaces, 276 | // directives and fields followed by those of each extension, in order; the 277 | // slices of obj itself are not modified.
278 | func MergeExtensions(obj *ast.ObjectDefinition, extensions ...*ast.ObjectDefinition) *ast.ObjectDefinition { 279 | merged := &ast.ObjectDefinition{ 280 | Kind: obj.Kind, 281 | Loc: obj.Loc, 282 | Name: obj.Name, 283 | Description: obj.Description, 284 | Interfaces: append([]*ast.Named{}, obj.Interfaces...), 285 | Directives: append([]*ast.Directive{}, obj.Directives...), 286 | Fields: append([]*ast.FieldDefinition{}, obj.Fields...), 287 | } 288 | 289 | for _, ext := range extensions { 290 | merged.Interfaces = append(merged.Interfaces, ext.Interfaces...) 291 | merged.Directives = append(merged.Directives, ext.Directives...) 292 | merged.Fields = append(merged.Fields, ext.Fields...) 293 | } 294 | 295 | return merged 296 | } 297 | 298 | const IntrospectionQuery = `query IntrospectionQuery { 299 | __schema { 300 | queryType { 301 | name 302 | } 303 | mutationType { 304 | name 305 | } 306 | subscriptionType { 307 | name 308 | } 309 | types { 310 | ...FullType 311 | } 312 | directives { 313 | name 314 | description 315 | locations 316 | args { 317 | ...InputValue 318 | } 319 | } 320 | } 321 | } 322 | 323 | fragment FullType on __Type { 324 | kind 325 | name 326 | description 327 | fields(includeDeprecated: true) { 328 | name 329 | description 330 | args { 331 | ...InputValue 332 | } 333 | type { 334 | ...TypeRef 335 | } 336 | isDeprecated 337 | deprecationReason 338 | } 339 | inputFields { 340 | ...InputValue 341 | } 342 | interfaces { 343 | ...TypeRef 344 | } 345 | enumValues(includeDeprecated: true) { 346 | name 347 | description 348 | isDeprecated 349 | deprecationReason 350 | } 351 | possibleTypes { 352 | ...TypeRef 353 | } 354 | } 355 | 356 | fragment InputValue on __InputValue { 357 | name 358 | description 359 | type { 360 | ...TypeRef 361 | } 362 | defaultValue 363 | } 364 | 365 | fragment TypeRef on __Type { 366 | kind 367 | name 368 | ofType { 369 | kind 370 | name 371 | ofType { 372 | kind 373 | name 374 | ofType { 375 | kind 376 | name 377 | ofType { 
378 | kind 379 | name 380 | ofType { 381 | kind 382 | name 383 | ofType { 384 | kind 385 | name 386 | ofType { 387 | kind 388 | name 389 | } 390 | } 391 | } 392 | } 393 | } 394 | } 395 | } 396 | }` 397 | -------------------------------------------------------------------------------- /registry.go: -------------------------------------------------------------------------------- 1 | package tools 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | "fmt" 7 | "strings" 8 | 9 | "github.com/graphql-go/graphql" 10 | "github.com/graphql-go/graphql/language/ast" 11 | "github.com/graphql-go/graphql/language/kinds" 12 | ) 13 | 14 | var errUnresolvedDependencies = errors.New("unresolved dependencies") 15 | 16 | // registry the registry holds all of the types 17 | type registry struct { 18 | ctx context.Context 19 | types map[string]graphql.Type 20 | directives map[string]*graphql.Directive 21 | schema *graphql.Schema 22 | resolverMap resolverMap 23 | directiveMap SchemaDirectiveVisitorMap 24 | schemaDirectives []*ast.Directive 25 | document *ast.Document 26 | extensions []graphql.Extension 27 | unresolvedDefs []ast.Node 28 | maxIterations int 29 | iterations int 30 | dependencyMap DependencyMap 31 | } 32 | 33 | // newRegistry creates a new registry 34 | func newRegistry( 35 | ctx context.Context, 36 | resolvers map[string]interface{}, 37 | directiveMap SchemaDirectiveVisitorMap, 38 | extensions []graphql.Extension, 39 | document *ast.Document, 40 | ) (*registry, error) { 41 | if ctx == nil { 42 | ctx = context.Background() 43 | } 44 | 45 | r := ®istry{ 46 | ctx: ctx, 47 | types: map[string]graphql.Type{ 48 | "ID": graphql.ID, 49 | "String": graphql.String, 50 | "Int": graphql.Int, 51 | "Float": graphql.Float, 52 | "Boolean": graphql.Boolean, 53 | "DateTime": graphql.DateTime, 54 | }, 55 | directives: map[string]*graphql.Directive{ 56 | "include": graphql.IncludeDirective, 57 | "skip": graphql.SkipDirective, 58 | "deprecated": graphql.DeprecatedDirective, 59 | "hide": 
HideDirective, 60 | }, 61 | resolverMap: resolverMap{}, 62 | directiveMap: directiveMap, 63 | schemaDirectives: []*ast.Directive{}, 64 | document: document, 65 | extensions: extensions, 66 | unresolvedDefs: document.Definitions, 67 | iterations: 0, 68 | maxIterations: len(document.Definitions), 69 | } 70 | 71 | // import each resolver to the correct location 72 | for name, resolver := range resolvers { 73 | if err := r.importResolver(name, resolver); err != nil { 74 | return nil, err 75 | } 76 | } 77 | 78 | return r, nil 79 | } 80 | 81 | // looks up a resolver by name or returns nil 82 | func (c *registry) getResolver(name string) Resolver { 83 | if c.resolverMap != nil { 84 | if resolver, ok := c.resolverMap[name]; ok { 85 | return resolver 86 | } 87 | } 88 | return nil 89 | } 90 | 91 | // gets an object from the registry 92 | func (c *registry) getObject(name string) (*graphql.Object, error) { 93 | obj, err := c.getType(name) 94 | if err != nil { 95 | return nil, err 96 | } 97 | switch o := obj.(type) { 98 | case *graphql.Object: 99 | return o, nil 100 | } 101 | return nil, nil 102 | } 103 | 104 | // converts the type map to an array 105 | func (c *registry) typeArray() []graphql.Type { 106 | a := make([]graphql.Type, 0) 107 | for _, t := range c.types { 108 | a = append(a, t) 109 | } 110 | return a 111 | } 112 | 113 | // Get gets a type from the registry 114 | func (c *registry) getType(name string) (graphql.Type, error) { 115 | if val, ok := c.types[name]; ok { 116 | return val, nil 117 | } 118 | 119 | if !c.willResolve(name) { 120 | return nil, fmt.Errorf("no definition found for type %q", name) 121 | } 122 | 123 | return nil, errUnresolvedDependencies 124 | } 125 | 126 | // Get gets a directive from the registry 127 | func (c *registry) getDirective(name string) (*graphql.Directive, error) { 128 | if val, ok := c.directives[name]; ok { 129 | return val, nil 130 | } 131 | return nil, errUnresolvedDependencies 132 | } 133 | 134 | // gets the extensions for the 
current type 135 | func (c *registry) getExtensions(name, kind string) []*ast.ObjectDefinition { 136 | extensions := []*ast.ObjectDefinition{} 137 | 138 | for _, def := range c.document.Definitions { 139 | if def.GetKind() == kinds.TypeExtensionDefinition { 140 | extDef := def.(*ast.TypeExtensionDefinition).Definition 141 | if extDef.Name.Value == name && extDef.GetKind() == kind { 142 | extensions = append(extensions, extDef) 143 | } 144 | } 145 | } 146 | 147 | return extensions 148 | } 149 | 150 | // imports a resolver from an interface 151 | func (c *registry) importResolver(name string, resolver interface{}) error { 152 | switch res := resolver.(type) { 153 | case *graphql.Directive: 154 | // allow @ to be prefixed to a directive in the event there is a type with the same 155 | // name to allow both to be defined in the resolver map but strip it from the 156 | // directive before adding it to the registry 157 | name = strings.TrimLeft(name, "@") 158 | if _, ok := c.directives[name]; !ok { 159 | c.directives[name] = res 160 | } 161 | 162 | case *graphql.InputObject: 163 | if _, ok := c.types[name]; !ok { 164 | c.types[name] = res 165 | } 166 | 167 | case *graphql.Scalar: 168 | if _, ok := c.types[name]; !ok { 169 | c.types[name] = res 170 | } 171 | 172 | case *graphql.Enum: 173 | if _, ok := c.types[name]; !ok { 174 | c.types[name] = res 175 | } 176 | 177 | case *graphql.Object: 178 | if _, ok := c.types[name]; !ok { 179 | c.types[name] = res 180 | } 181 | 182 | case *graphql.Interface: 183 | if _, ok := c.types[name]; !ok { 184 | c.types[name] = res 185 | } 186 | 187 | case *graphql.Union: 188 | if _, ok := c.types[name]; !ok { 189 | c.types[name] = res 190 | } 191 | 192 | case *ScalarResolver: 193 | if _, ok := c.resolverMap[name]; !ok { 194 | c.resolverMap[name] = res 195 | } 196 | 197 | case *EnumResolver: 198 | if _, ok := c.resolverMap[name]; !ok { 199 | c.resolverMap[name] = res 200 | } 201 | 202 | case *ObjectResolver: 203 | if _, ok := 
c.resolverMap[name]; !ok { 204 | c.resolverMap[name] = res 205 | } 206 | 207 | case *InterfaceResolver: 208 | if _, ok := c.resolverMap[name]; !ok { 209 | c.resolverMap[name] = res 210 | } 211 | 212 | case *UnionResolver: 213 | if _, ok := c.resolverMap[name]; !ok { 214 | c.resolverMap[name] = res 215 | } 216 | default: 217 | return fmt.Errorf("invalid resolver type for %s", name) 218 | } 219 | 220 | return nil 221 | } 222 | 223 | func getNodeName(node ast.Node) string { 224 | switch node.GetKind() { 225 | case kinds.ObjectDefinition: 226 | return node.(*ast.ObjectDefinition).Name.Value 227 | case kinds.ScalarDefinition: 228 | return node.(*ast.ScalarDefinition).Name.Value 229 | case kinds.EnumDefinition: 230 | return node.(*ast.EnumDefinition).Name.Value 231 | case kinds.InputObjectDefinition: 232 | return node.(*ast.InputObjectDefinition).Name.Value 233 | case kinds.InterfaceDefinition: 234 | return node.(*ast.InterfaceDefinition).Name.Value 235 | case kinds.UnionDefinition: 236 | return node.(*ast.UnionDefinition).Name.Value 237 | case kinds.DirectiveDefinition: 238 | return node.(*ast.DirectiveDefinition).Name.Value 239 | } 240 | 241 | return "" 242 | } 243 | 244 | // determines if a node will resolve eventually or with a thunk 245 | // false if there is no possibility 246 | func (c *registry) willResolve(name string) bool { 247 | if _, ok := c.types[name]; ok { 248 | return true 249 | } 250 | for _, n := range c.unresolvedDefs { 251 | if getNodeName(n) == name { 252 | return true 253 | } 254 | } 255 | return false 256 | } 257 | 258 | // iteratively resolves dependencies until all types are resolved 259 | func (c *registry) resolveDefinitions() error { 260 | unresolved := []ast.Node{} 261 | 262 | for len(c.unresolvedDefs) > 0 && c.iterations < c.maxIterations { 263 | c.iterations = c.iterations + 1 264 | 265 | for _, definition := range c.unresolvedDefs { 266 | switch nodeKind := definition.GetKind(); nodeKind { 267 | case kinds.DirectiveDefinition: 268 | if 
err := c.buildDirectiveFromAST(definition.(*ast.DirectiveDefinition)); err != nil { 269 | if err == errUnresolvedDependencies { 270 | unresolved = append(unresolved, definition) 271 | } else { 272 | return err 273 | } 274 | } 275 | case kinds.ScalarDefinition: 276 | if err := c.buildScalarFromAST(definition.(*ast.ScalarDefinition)); err != nil { 277 | if err == errUnresolvedDependencies { 278 | unresolved = append(unresolved, definition) 279 | } else { 280 | return err 281 | } 282 | } 283 | case kinds.EnumDefinition: 284 | if err := c.buildEnumFromAST(definition.(*ast.EnumDefinition)); err != nil { 285 | if err == errUnresolvedDependencies { 286 | unresolved = append(unresolved, definition) 287 | } else { 288 | return err 289 | } 290 | } 291 | case kinds.InputObjectDefinition: 292 | if err := c.buildInputObjectFromAST(definition.(*ast.InputObjectDefinition)); err != nil { 293 | if err == errUnresolvedDependencies { 294 | unresolved = append(unresolved, definition) 295 | } else { 296 | return err 297 | } 298 | } 299 | case kinds.ObjectDefinition: 300 | if err := c.buildObjectFromAST(definition.(*ast.ObjectDefinition)); err != nil { 301 | if err == errUnresolvedDependencies { 302 | unresolved = append(unresolved, definition) 303 | } else { 304 | return err 305 | } 306 | } 307 | case kinds.InterfaceDefinition: 308 | if err := c.buildInterfaceFromAST(definition.(*ast.InterfaceDefinition)); err != nil { 309 | if err == errUnresolvedDependencies { 310 | unresolved = append(unresolved, definition) 311 | } else { 312 | return err 313 | } 314 | } 315 | case kinds.UnionDefinition: 316 | if err := c.buildUnionFromAST(definition.(*ast.UnionDefinition)); err != nil { 317 | if err == errUnresolvedDependencies { 318 | unresolved = append(unresolved, definition) 319 | } else { 320 | return err 321 | } 322 | } 323 | case kinds.SchemaDefinition: 324 | if err := c.buildSchemaFromAST(definition.(*ast.SchemaDefinition)); err != nil { 325 | if err == errUnresolvedDependencies { 326 | 
unresolved = append(unresolved, definition) 327 | } else { 328 | return err 329 | } 330 | } 331 | } 332 | } 333 | 334 | // check if everything has been resolved 335 | if len(unresolved) == 0 { 336 | return nil 337 | } 338 | 339 | // prepare the next loop 340 | c.unresolvedDefs = unresolved 341 | 342 | if c.iterations < c.maxIterations { 343 | unresolved = []ast.Node{} 344 | } 345 | } 346 | 347 | if len(unresolved) > 0 { 348 | names := []string{} 349 | for _, n := range unresolved { 350 | if name := getNodeName(n); name != "" { 351 | names = append(names, name) 352 | } else { 353 | names = append(names, n.GetKind()) 354 | } 355 | } 356 | return fmt.Errorf("failed to resolve all type definitions: %v", names) 357 | } 358 | 359 | return nil 360 | } 361 | -------------------------------------------------------------------------------- /resolvers.go: -------------------------------------------------------------------------------- 1 | package tools 2 | 3 | import ( 4 | "github.com/graphql-go/graphql" 5 | "github.com/graphql-go/graphql/language/kinds" 6 | ) 7 | 8 | // Resolver interface to a resolver configuration 9 | type Resolver interface { 10 | getKind() string 11 | } 12 | 13 | // ResolverMap a map of resolver configurations. 
14 | // Accept generic interfaces and identify types at build 15 | type ResolverMap map[string]interface{} 16 | 17 | // internal resolver map 18 | type resolverMap map[string]Resolver 19 | 20 | // FieldResolveMap map of field resolve functions 21 | type FieldResolveMap map[string]*FieldResolve 22 | 23 | // FieldResolve field resolver 24 | type FieldResolve struct { 25 | Resolve graphql.FieldResolveFn 26 | Subscribe graphql.FieldResolveFn 27 | } 28 | 29 | // ObjectResolver config for object resolver map 30 | type ObjectResolver struct { 31 | IsTypeOf graphql.IsTypeOfFn 32 | Fields FieldResolveMap 33 | } 34 | 35 | // GetKind gets the kind 36 | func (c *ObjectResolver) getKind() string { 37 | return kinds.ObjectDefinition 38 | } 39 | 40 | // ScalarResolver config for a scalar resolve map 41 | type ScalarResolver struct { 42 | Serialize graphql.SerializeFn 43 | ParseValue graphql.ParseValueFn 44 | ParseLiteral graphql.ParseLiteralFn 45 | } 46 | 47 | // GetKind gets the kind 48 | func (c *ScalarResolver) getKind() string { 49 | return kinds.ScalarDefinition 50 | } 51 | 52 | // InterfaceResolver config for interface resolve 53 | type InterfaceResolver struct { 54 | ResolveType graphql.ResolveTypeFn 55 | Fields FieldResolveMap 56 | } 57 | 58 | // GetKind gets the kind 59 | func (c *InterfaceResolver) getKind() string { 60 | return kinds.InterfaceDefinition 61 | } 62 | 63 | // UnionResolver config for interface resolve 64 | type UnionResolver struct { 65 | ResolveType graphql.ResolveTypeFn 66 | } 67 | 68 | // GetKind gets the kind 69 | func (c *UnionResolver) getKind() string { 70 | return kinds.UnionDefinition 71 | } 72 | 73 | // EnumResolver config for enum values 74 | type EnumResolver struct { 75 | Values map[string]interface{} 76 | } 77 | 78 | // GetKind gets the kind 79 | func (c *EnumResolver) getKind() string { 80 | return kinds.EnumDefinition 81 | } 82 | -------------------------------------------------------------------------------- /scalars/bool_string.go: 
-------------------------------------------------------------------------------- 1 | package scalars 2 | 3 | import ( 4 | "fmt" 5 | 6 | "github.com/graphql-go/graphql" 7 | "github.com/graphql-go/graphql/language/ast" 8 | ) 9 | 10 | // ScalarBoolString converts boolean to a string 11 | var ScalarBoolString = graphql.NewScalar( 12 | graphql.ScalarConfig{ 13 | Name: "BoolString", 14 | Description: "BoolString converts a boolean to/from a string", 15 | Serialize: func(value interface{}) interface{} { 16 | valStr := fmt.Sprintf("%v", value) 17 | return valStr == "true" || valStr == "1" 18 | }, 19 | ParseValue: func(value interface{}) interface{} { 20 | b, ok := value.(bool) 21 | if !ok { 22 | return "false" 23 | } else if b { 24 | return "true" 25 | } 26 | return "false" 27 | }, 28 | ParseLiteral: func(astValue ast.Value) interface{} { 29 | value := astValue.GetValue() 30 | b, ok := value.(bool) 31 | if !ok { 32 | return "false" 33 | } else if b { 34 | return "true" 35 | } 36 | return "false" 37 | }, 38 | }, 39 | ) 40 | -------------------------------------------------------------------------------- /scalars/json.go: -------------------------------------------------------------------------------- 1 | package scalars 2 | 3 | import ( 4 | "github.com/graphql-go/graphql" 5 | "github.com/graphql-go/graphql/language/ast" 6 | "github.com/graphql-go/graphql/language/kinds" 7 | ) 8 | 9 | // ScalarJSON a scalar JSON type 10 | var ScalarJSON = graphql.NewScalar( 11 | graphql.ScalarConfig{ 12 | Name: "JSON", 13 | Description: "The `JSON` scalar type represents JSON values as specified by [ECMA-404](http://www.ecma-international.org/publications/files/ECMA-ST/ECMA-404.pdf)", 14 | Serialize: func(value interface{}) interface{} { 15 | return value 16 | }, 17 | ParseValue: func(value interface{}) interface{} { 18 | return value 19 | }, 20 | ParseLiteral: parseLiteralJSONFn, 21 | }, 22 | ) 23 | 24 | // recursively parse ast 25 | func parseLiteralJSONFn(astValue ast.Value) interface{} { 
26 | switch kind := astValue.GetKind(); kind { 27 | // get value for primitive types 28 | case kinds.StringValue, kinds.BooleanValue, kinds.IntValue, kinds.FloatValue: 29 | return astValue.GetValue() 30 | 31 | // make a map for objects 32 | case kinds.ObjectValue: 33 | obj := make(map[string]interface{}) 34 | for _, v := range astValue.GetValue().([]*ast.ObjectField) { 35 | obj[v.Name.Value] = parseLiteralJSONFn(v.Value) 36 | } 37 | return obj 38 | 39 | // make a slice for lists 40 | case kinds.ListValue: 41 | list := make([]interface{}, 0) 42 | for _, v := range astValue.GetValue().([]ast.Value) { 43 | list = append(list, parseLiteralJSONFn(v)) 44 | } 45 | return list 46 | 47 | // default to nil 48 | default: 49 | return nil 50 | } 51 | } 52 | -------------------------------------------------------------------------------- /scalars/query_document.go: -------------------------------------------------------------------------------- 1 | package scalars 2 | 3 | import ( 4 | "encoding/json" 5 | "regexp" 6 | 7 | "github.com/graphql-go/graphql" 8 | "github.com/graphql-go/graphql/language/ast" 9 | "github.com/graphql-go/graphql/language/kinds" 10 | ) 11 | 12 | var queryDocOperatorRx = regexp.MustCompile(`^\$`) 13 | var storedQueryDocOperatorRx = regexp.MustCompile(`^_`) 14 | 15 | func replacePrefixedKeys(obj interface{}, prefixRx *regexp.Regexp, replacement string) interface{} { 16 | switch obj.(type) { 17 | case map[string]interface{}: 18 | result := map[string]interface{}{} 19 | for k, v := range obj.(map[string]interface{}) { 20 | newKey := prefixRx.ReplaceAllString(k, replacement) 21 | result[newKey] = replacePrefixedKeys(v, prefixRx, replacement) 22 | } 23 | return result 24 | 25 | case []interface{}: 26 | result := []interface{}{} 27 | for _, v := range obj.([]interface{}) { 28 | result = append(result, replacePrefixedKeys(v, prefixRx, replacement)) 29 | } 30 | return result 31 | 32 | default: 33 | return obj 34 | } 35 | } 36 | 37 | func serializeQueryDocFn(value 
interface{}) interface{} { 38 | return replacePrefixedKeys(value, storedQueryDocOperatorRx, "$") 39 | } 40 | 41 | func parseValueQueryDocFn(value interface{}) interface{} { 42 | return replacePrefixedKeys(value, queryDocOperatorRx, "_") 43 | } 44 | 45 | func parseLiteralQueryDocFn(astValue ast.Value) interface{} { 46 | var val interface{} 47 | switch astValue.GetKind() { 48 | case kinds.StringValue: 49 | bvalue := []byte(astValue.GetValue().(string)) 50 | if err := json.Unmarshal(bvalue, &val); err != nil { 51 | return nil 52 | } 53 | return replacePrefixedKeys(val, queryDocOperatorRx, "_") 54 | case kinds.ObjectValue: 55 | return parseLiteralJSONFn(astValue) 56 | } 57 | return nil 58 | } 59 | 60 | // ScalarQueryDocument a mongodb style query document 61 | var ScalarQueryDocument = graphql.NewScalar( 62 | graphql.ScalarConfig{ 63 | Name: "QueryDocument", 64 | Description: "MongoDB style query document", 65 | Serialize: serializeQueryDocFn, 66 | ParseValue: parseValueQueryDocFn, 67 | ParseLiteral: parseLiteralQueryDocFn, 68 | }, 69 | ) 70 | -------------------------------------------------------------------------------- /scalars/string_set.go: -------------------------------------------------------------------------------- 1 | package scalars 2 | 3 | import ( 4 | "reflect" 5 | 6 | "github.com/graphql-go/graphql" 7 | "github.com/graphql-go/graphql/language/ast" 8 | ) 9 | // ensureArray wraps a single non-nil value in a one-element list; slices and arrays pass through unchanged, and nil (including typed nil pointers, maps, channels and funcs) yields nil 10 | func ensureArray(value interface{}) interface{} { if value == nil { return nil } 11 | switch kind := reflect.TypeOf(value).Kind(); kind { 12 | case reflect.Slice, reflect.Array: 13 | return value 14 | default: 15 | if (kind == reflect.Ptr || kind == reflect.Map || kind == reflect.Chan || kind == reflect.Func || kind == reflect.UnsafePointer) && reflect.ValueOf(value).IsNil() { // IsNil panics for non-nillable kinds such as string 16 | return nil 17 | } 18 | return []interface{}{value} 19 | } 20 | } 21 | // serializeStringSetFn unwraps a one-element slice/array to its element and returns longer ones unchanged; any other non-nil value serializes as an empty list 22 | func serializeStringSetFn(value interface{}) interface{} { if value == nil { return nil } 23 | switch kind := reflect.TypeOf(value).Kind(); kind { 24 | case reflect.Slice, reflect.Array: 25 | v := reflect.ValueOf(value) 26 | if v.Len() == 1 { 27 | return v.Index(0).Interface() 28 | } 29 | return value 30 | default: 31 | return
[]interface{}{} 32 | } 33 | } 34 | 35 | // ScalarStringSet allows string or array of strings 36 | // stores as an array of strings 37 | var ScalarStringSet = graphql.NewScalar( 38 | graphql.ScalarConfig{ 39 | Name: "StringSet", 40 | Description: "StringSet allows either a string or list of strings", 41 | Serialize: serializeStringSetFn, 42 | ParseValue: func(value interface{}) interface{} { 43 | return ensureArray(value) 44 | }, 45 | ParseLiteral: func(astValue ast.Value) interface{} { 46 | return ensureArray(parseLiteralJSONFn(astValue)) 47 | }, 48 | }, 49 | ) 50 | -------------------------------------------------------------------------------- /schema.go: -------------------------------------------------------------------------------- 1 | package tools 2 | 3 | import ( 4 | "context" 5 | "encoding/json" 6 | "fmt" 7 | 8 | "github.com/graphql-go/graphql" 9 | "github.com/graphql-go/graphql/language/ast" 10 | ) 11 | 12 | // default root type names 13 | const ( 14 | DefaultRootQueryName = "Query" 15 | DefaultRootMutationName = "Mutation" 16 | DefaultRootSubscriptionName = "Subscription" 17 | ) 18 | 19 | // MakeExecutableSchema is shorthand for ExecutableSchema{}.Make(ctx context.Context) 20 | func MakeExecutableSchema(config ExecutableSchema) (graphql.Schema, error) { 21 | return config.Make(context.Background()) 22 | } 23 | 24 | // MakeExecutableSchemaWithContext make a schema and supply a context 25 | func MakeExecutableSchemaWithContext(ctx context.Context, config ExecutableSchema) (graphql.Schema, error) { 26 | return config.Make(ctx) 27 | } 28 | 29 | // ExecutableSchema configuration for making an executable schema 30 | // this attempts to provide similar functionality to Apollo graphql-tools 31 | // https://www.apollographql.com/docs/graphql-tools/generate-schema 32 | type ExecutableSchema struct { 33 | document *ast.Document 34 | TypeDefs interface{} // a string, []string, or func() []string 35 | Resolvers map[string]interface{} // a map of Resolver, Directive, 
Scalar, Enum, Object, InputObject, Union, or Interface 36 | SchemaDirectives SchemaDirectiveVisitorMap // Map of SchemaDirectiveVisitor 37 | Extensions []graphql.Extension // GraphQL extensions 38 | Debug bool // Prints debug messages during compile 39 | } 40 | 41 | // Document returns the document 42 | func (c *ExecutableSchema) Document() *ast.Document { 43 | return c.document 44 | } 45 | 46 | // Make creates a graphql schema config, this struct maintains intact the types and does not require the use of a non empty Query 47 | func (c *ExecutableSchema) Make(ctx context.Context) (graphql.Schema, error) { 48 | // combine the TypeDefs 49 | document, err := c.ConcatenateTypeDefs() 50 | if err != nil { 51 | return graphql.Schema{}, err 52 | } 53 | 54 | c.document = document 55 | 56 | // create a new registry 57 | registry, err := newRegistry(ctx, c.Resolvers, c.SchemaDirectives, c.Extensions, document) 58 | if err != nil { 59 | return graphql.Schema{}, err 60 | } 61 | 62 | if registry.dependencyMap, err = registry.IdentifyDependencies(); err != nil { 63 | return graphql.Schema{}, err 64 | } 65 | 66 | // resolve the document definitions 67 | if err := registry.resolveDefinitions(); err != nil { 68 | return graphql.Schema{}, err 69 | } 70 | 71 | // check if schema was created by definition 72 | if registry.schema != nil { 73 | return *registry.schema, nil 74 | } 75 | 76 | // otherwise build a schema from default object names 77 | query, err := registry.getObject(DefaultRootQueryName) 78 | if err != nil { 79 | return graphql.Schema{}, err 80 | } 81 | 82 | mutation, _ := registry.getObject(DefaultRootMutationName) 83 | subscription, _ := registry.getObject(DefaultRootSubscriptionName) 84 | 85 | // create a new schema config 86 | schemaConfig := &graphql.SchemaConfig{ 87 | Query: query, 88 | Mutation: mutation, 89 | Subscription: subscription, 90 | Types: registry.typeArray(), 91 | Directives: registry.directiveArray(), 92 | Extensions: c.Extensions, 93 | } 94 | 95 | 
schema, err := graphql.NewSchema(*schemaConfig) 96 | if err != nil && c.Debug { 97 | j, _ := json.MarshalIndent(registry.dependencyMap, "", " ") 98 | fmt.Println("Unresolved types, thunks will be used") 99 | fmt.Println(string(j)) 100 | } 101 | 102 | // create a new schema 103 | return schema, nil 104 | } 105 | 106 | // build a schema from an ast 107 | func (c *registry) buildSchemaFromAST(definition *ast.SchemaDefinition) error { 108 | schemaConfig := &graphql.SchemaConfig{ 109 | Types: c.typeArray(), 110 | Directives: c.directiveArray(), 111 | Extensions: c.extensions, 112 | } 113 | 114 | // add operations 115 | for _, op := range definition.OperationTypes { 116 | switch op.Operation { 117 | case ast.OperationTypeQuery: 118 | if object, err := c.getObject(op.Type.Name.Value); err == nil { 119 | schemaConfig.Query = object 120 | } else { 121 | return err 122 | } 123 | case ast.OperationTypeMutation: 124 | if object, err := c.getObject(op.Type.Name.Value); err == nil { 125 | schemaConfig.Mutation = object 126 | } else { 127 | return err 128 | } 129 | case ast.OperationTypeSubscription: 130 | if object, err := c.getObject(op.Type.Name.Value); err == nil { 131 | schemaConfig.Subscription = object 132 | } else { 133 | return err 134 | } 135 | } 136 | } 137 | 138 | // apply schema directives 139 | if err := c.applyDirectives(applyDirectiveParams{ 140 | config: schemaConfig, 141 | directives: definition.Directives, 142 | node: definition, 143 | }); err != nil { 144 | return err 145 | } 146 | 147 | // build the schema 148 | schema, err := graphql.NewSchema(*schemaConfig) 149 | if err != nil { 150 | return err 151 | } 152 | 153 | c.schema = &schema 154 | return nil 155 | } 156 | -------------------------------------------------------------------------------- /schema_test.go: -------------------------------------------------------------------------------- 1 | package tools 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/graphql-go/graphql" 7 | ) 8 | 9 | func 
TestInterface(t *testing.T) { 10 | typeDefs := ` 11 | interface User { 12 | id: ID 13 | type: String 14 | name: String 15 | } 16 | 17 | type UserAccount implements User { 18 | id: ID 19 | type: String 20 | name: String 21 | username: String 22 | } 23 | 24 | type ServiceAccount implements User { 25 | id: ID 26 | type: String 27 | name: String 28 | client_id: String 29 | } 30 | 31 | type Query { 32 | users: [User] 33 | } 34 | ` 35 | users := []map[string]interface{}{ 36 | { 37 | "id": "1", 38 | "type": "user", 39 | "name": "User1", 40 | "username": "user1", 41 | }, 42 | { 43 | "id": "1", 44 | "type": "service", 45 | "name": "Service1", 46 | "client_id": "1234567890", 47 | }, 48 | } 49 | 50 | schema, err := MakeExecutableSchema(ExecutableSchema{ 51 | TypeDefs: typeDefs, 52 | Resolvers: map[string]interface{}{ 53 | "User": &InterfaceResolver{ 54 | ResolveType: func(p graphql.ResolveTypeParams) *graphql.Object { 55 | value := p.Value.(map[string]interface{}) 56 | typ := value["type"].(string) 57 | if typ == "user" { 58 | return p.Info.Schema.Type("UserAccount").(*graphql.Object) 59 | } else if typ == "service" { 60 | return p.Info.Schema.Type("ServiceAccount").(*graphql.Object) 61 | } 62 | 63 | return nil 64 | }, 65 | }, 66 | "Query": &ObjectResolver{ 67 | Fields: FieldResolveMap{ 68 | "users": &FieldResolve{ 69 | Resolve: func(p graphql.ResolveParams) (interface{}, error) { 70 | return users, nil 71 | }, 72 | }, 73 | }, 74 | }, 75 | }, 76 | }) 77 | 78 | if err != nil { 79 | t.Errorf("failed to make schema: %v", err) 80 | return 81 | } 82 | 83 | r := graphql.Do(graphql.Params{ 84 | Schema: schema, 85 | RequestString: `query { 86 | users { 87 | id 88 | type 89 | name 90 | ... on UserAccount { 91 | username 92 | } 93 | ... 
on ServiceAccount { 94 | client_id 95 | } 96 | } 97 | }`, 98 | }) 99 | 100 | if r.HasErrors() { 101 | t.Error(r.Errors) 102 | return 103 | } 104 | 105 | // j, _ := json.MarshalIndent(r.Data, "", " ") 106 | // fmt.Printf("%s\n", j) 107 | } 108 | 109 | func TestMissingType(t *testing.T) { 110 | typeDefs := ` 111 | type Foo { 112 | name: String! 113 | meta: JSON 114 | } 115 | 116 | input Cyclic { 117 | name: String 118 | cyclic: Cyclic 119 | } 120 | 121 | type Query { 122 | foos: [Foo] 123 | }` 124 | 125 | // create some data 126 | foos := []map[string]interface{}{ 127 | { 128 | "name": "foo", 129 | "meta": map[string]interface{}{ 130 | "bar": "baz", 131 | }, 132 | }, 133 | } 134 | 135 | // make the schema 136 | _, err := MakeExecutableSchema(ExecutableSchema{ 137 | TypeDefs: typeDefs, 138 | Resolvers: map[string]interface{}{ 139 | "Query": &ObjectResolver{ 140 | Fields: FieldResolveMap{ 141 | "foos": &FieldResolve{ 142 | Resolve: func(p graphql.ResolveParams) (interface{}, error) { 143 | return foos, nil 144 | }, 145 | }, 146 | }, 147 | }, 148 | }, 149 | }) 150 | 151 | if err != nil { 152 | t.Error("failed to use tunks for cyclic type") 153 | return 154 | } 155 | } 156 | 157 | func TestMakeExecutableSchema(t *testing.T) { 158 | typeDefs := ` 159 | type Foo { 160 | name: String! 
161 | description: String 162 | } 163 | 164 | type Query1 { 165 | foos( 166 | name: String 167 | ): [Foo] 168 | } 169 | 170 | schema { 171 | query: Query1 172 | } 173 | ` 174 | 175 | // create some data 176 | foos := []map[string]interface{}{ 177 | { 178 | "name": "foo", 179 | "description": "a foo", 180 | }, 181 | } 182 | 183 | // make the schema 184 | schema, err := MakeExecutableSchema(ExecutableSchema{ 185 | TypeDefs: typeDefs, 186 | Resolvers: map[string]interface{}{ 187 | "Query": &ObjectResolver{ 188 | Fields: FieldResolveMap{ 189 | "foos": &FieldResolve{ 190 | Resolve: func(p graphql.ResolveParams) (interface{}, error) { 191 | return foos, nil 192 | }, 193 | }, 194 | }, 195 | }, 196 | }, 197 | }) 198 | 199 | if err != nil { 200 | t.Error(err) 201 | return 202 | } 203 | 204 | // perform a query 205 | r := graphql.Do(graphql.Params{ 206 | Schema: schema, 207 | RequestString: `query Query { 208 | foos(name:"foo") { 209 | name 210 | description 211 | } 212 | }`, 213 | }) 214 | 215 | if r.HasErrors() { 216 | t.Error(r.Errors) 217 | return 218 | } 219 | } 220 | -------------------------------------------------------------------------------- /server/graphiql.go: -------------------------------------------------------------------------------- 1 | package server 2 | 3 | import ( 4 | "encoding/json" 5 | "fmt" 6 | "html/template" 7 | "net/http" 8 | 9 | "github.com/graphql-go/graphql" 10 | ) 11 | 12 | // GraphiqlVersion is the current version of GraphiQL 13 | var GraphiqlVersion = "1.4.1" 14 | 15 | type GraphiQLOptions struct { 16 | Version string 17 | SSL bool 18 | Endpoint string 19 | SubscriptionEndpoint string 20 | } 21 | 22 | func NewDefaultGraphiQLOptions() *GraphiQLOptions { 23 | return &GraphiQLOptions{ 24 | Version: GraphiqlVersion, 25 | } 26 | } 27 | 28 | func NewDefaultSSLGraphiQLOption() *GraphiQLOptions { 29 | return &GraphiQLOptions{ 30 | Version: GraphiqlVersion, 31 | SSL: true, 32 | } 33 | } 34 | 35 | // graphiqlData is the page data structure of the 
rendered GraphiQL page 36 | type graphiqlData struct { 37 | Endpoint string 38 | SubscriptionEndpoint string 39 | GraphiqlVersion string 40 | QueryString string 41 | VariablesString string 42 | OperationName string 43 | ResultString string 44 | } 45 | 46 | // renderGraphiQL renders the GraphiQL GUI 47 | func renderGraphiQL(config *GraphiQLOptions, w http.ResponseWriter, r *http.Request, params graphql.Params) { 48 | t := template.New("GraphiQL") 49 | t, err := t.Parse(graphiqlTemplate) 50 | if err != nil { 51 | http.Error(w, err.Error(), http.StatusInternalServerError) 52 | return 53 | } 54 | 55 | // Create variables string 56 | vars, err := json.MarshalIndent(params.VariableValues, "", " ") 57 | if err != nil { 58 | http.Error(w, err.Error(), http.StatusInternalServerError) 59 | return 60 | } 61 | varsString := string(vars) 62 | if varsString == "null" { 63 | varsString = "" 64 | } 65 | 66 | // Create result string 67 | var resString string 68 | if params.RequestString == "" { 69 | resString = "" 70 | } else { 71 | result, err := json.MarshalIndent(graphql.Do(params), "", " ") 72 | if err != nil { 73 | http.Error(w, err.Error(), http.StatusInternalServerError) 74 | return 75 | } 76 | resString = string(result) 77 | } 78 | 79 | endpoint := r.URL.Path 80 | if config.Endpoint != "" { 81 | endpoint = config.Endpoint 82 | } 83 | 84 | wsScheme := "ws:" 85 | if config.SSL { 86 | wsScheme = "wss:" 87 | } 88 | 89 | subscriptionEndpoint := fmt.Sprintf("%s//%v%s", wsScheme, r.Host, r.URL.Path) 90 | if config.SubscriptionEndpoint != "" { 91 | subscriptionEndpoint = config.SubscriptionEndpoint 92 | } 93 | 94 | d := graphiqlData{ 95 | GraphiqlVersion: GraphiqlVersion, 96 | QueryString: params.RequestString, 97 | ResultString: resString, 98 | VariablesString: varsString, 99 | OperationName: params.OperationName, 100 | Endpoint: endpoint, 101 | SubscriptionEndpoint: subscriptionEndpoint, 102 | } 103 | err = t.ExecuteTemplate(w, "index", d) 104 | if err != nil { 105 | 
http.Error(w, err.Error(), http.StatusInternalServerError) 106 | } 107 | } 108 | 109 | const graphiqlTemplate = ` 110 | {{ define "index" }} 111 | 112 | 113 | Simple GraphiQL Example 114 | 115 | 119 | 123 | 127 | 128 | 129 | 130 |
131 | 143 | 144 | 145 | {{end}} 146 | ` 147 | -------------------------------------------------------------------------------- /server/graphiql.tmpl: -------------------------------------------------------------------------------- 1 | {{ define "index" }} 2 | 3 | 4 | Simple GraphiQL Example 5 | 6 | 7 | 11 | 15 | 19 | 20 | 21 | 22 |
23 | 41 | 42 | 43 | {{end}} -------------------------------------------------------------------------------- /server/graphqlws.go: -------------------------------------------------------------------------------- 1 | package server 2 | 3 | import ( 4 | "context" 5 | "net/http" 6 | 7 | "github.com/bhoriuchi/graphql-go-tools/server/graphqlws" 8 | "github.com/gorilla/websocket" 9 | "github.com/graphql-go/graphql" 10 | ) 11 | 12 | func (s *Server) newGraphQLWSConnection(ctx context.Context, r *http.Request, ws *websocket.Conn) { 13 | // Establish a GraphQL WebSocket connection 14 | graphqlws.NewConnection(ws, graphqlws.ConnectionConfig{ 15 | Authenticate: s.options.WS.AuthenticateFunc, 16 | Logger: s.log, 17 | EventHandlers: graphqlws.ConnectionEventHandlers{ 18 | Close: func(conn graphqlws.Connection) { 19 | s.log.Debugf("closing websocket: %s", conn.ID()) 20 | s.mgr.DelConn(conn.ID()) 21 | }, 22 | StartOperation: func( 23 | conn graphqlws.Connection, 24 | opID string, 25 | data *graphqlws.StartMessagePayload, 26 | ) []error { 27 | s.log.Debugf("start operations %s on connection %s", opID, conn.ID()) 28 | 29 | rootObject := map[string]interface{}{} 30 | if s.options.RootValueFunc != nil { 31 | rootObject = s.options.RootValueFunc(ctx, r) 32 | } 33 | ctx, cancelFunc := context.WithCancel(context.WithValue(context.Background(), ConnKey, conn)) 34 | resultChannel := graphql.Subscribe(graphql.Params{ 35 | Schema: s.schema, 36 | RequestString: data.Query, 37 | VariableValues: data.Variables, 38 | OperationName: data.OperationName, 39 | Context: ctx, 40 | RootObject: rootObject, 41 | }) 42 | 43 | s.mgr.Add(&ResultChan{ 44 | ch: resultChannel, 45 | cancelFunc: cancelFunc, 46 | ctx: ctx, 47 | cid: conn.ID(), 48 | oid: opID, 49 | }) 50 | 51 | go func() { 52 | for { 53 | select { 54 | case <-ctx.Done(): 55 | s.mgr.Del(conn.ID(), opID) 56 | return 57 | case res, more := <-resultChannel: 58 | if !more { 59 | return 60 | } 61 | 62 | errs := []error{} 63 | 64 | if res.HasErrors() { 
65 | for _, err := range res.Errors { 66 | s.log.Debugf("subscription_error: %v", err) 67 | errs = append(errs, err.OriginalError()) 68 | } 69 | } 70 | 71 | conn.SendData(opID, &graphqlws.DataMessagePayload{ 72 | Data: res.Data, 73 | Errors: errs, 74 | }) 75 | } 76 | } 77 | }() 78 | 79 | return nil 80 | }, 81 | StopOperation: func(conn graphqlws.Connection, opID string) { 82 | s.log.Debugf("stop operation %s on connection %s", opID, conn.ID()) 83 | s.mgr.Del(conn.ID(), opID) 84 | }, 85 | }, 86 | }) 87 | } 88 | -------------------------------------------------------------------------------- /server/graphqlws/connections.go: -------------------------------------------------------------------------------- 1 | package graphqlws 2 | 3 | import ( 4 | "context" 5 | "encoding/json" 6 | "errors" 7 | "fmt" 8 | "sync" 9 | "time" 10 | 11 | "github.com/bhoriuchi/graphql-go-tools/server/logger" 12 | "github.com/google/uuid" 13 | "github.com/gorilla/websocket" 14 | ) 15 | 16 | const ( 17 | // Constants for operation message types 18 | gqlConnectionAuth = "connection_auth" 19 | gqlConnectionInit = "connection_init" 20 | gqlConnectionAck = "connection_ack" 21 | gqlConnectionKeepAlive = "ka" 22 | gqlConnectionError = "connection_error" 23 | gqlConnectionTerminate = "connection_terminate" 24 | gqlStart = "start" 25 | gqlData = "data" 26 | gqlError = "error" 27 | gqlComplete = "complete" 28 | gqlStop = "stop" 29 | 30 | // Maximum size of incoming messages 31 | readLimit = 4096 32 | 33 | // Timeout for outgoing messages 34 | writeTimeout = 10 * time.Second 35 | ) 36 | 37 | // InitMessagePayload defines the parameters of a connection 38 | // init message. 39 | type InitMessagePayload struct { 40 | AuthToken string `json:"authToken"` 41 | Authorization string `json:"Authorization"` 42 | } 43 | 44 | // StartMessagePayload defines the parameters of an operation that 45 | // a client requests to be started. 
46 | type StartMessagePayload struct { 47 | Query string `json:"query"` 48 | Variables map[string]interface{} `json:"variables"` 49 | OperationName string `json:"operationName"` 50 | } 51 | 52 | // DataMessagePayload defines the result data of an operation. 53 | type DataMessagePayload struct { 54 | Data interface{} `json:"data"` 55 | Errors []error `json:"errors"` 56 | } 57 | 58 | // OperationMessage represents a GraphQL WebSocket message. 59 | type OperationMessage struct { 60 | ID string `json:"id"` 61 | Type string `json:"type"` 62 | Payload interface{} `json:"payload"` 63 | } 64 | 65 | func (msg OperationMessage) String() string { 66 | s, _ := json.Marshal(msg) 67 | if s != nil { 68 | return string(s) 69 | } 70 | return "" 71 | } 72 | 73 | // AuthenticateFunc is a function that resolves an auth token 74 | // into a user (or returns an error if that isn't possible). 75 | type AuthenticateFunc func(data map[string]interface{}, conn Connection) (context.Context, error) 76 | 77 | // ConnectionEventHandlers define the event handlers for a connection. 78 | // Event handlers allow other system components to react to events such 79 | // as the connection closing or an operation being started or stopped. 80 | type ConnectionEventHandlers struct { 81 | // Close is called whenever the connection is closed, regardless of 82 | // whether this happens because of an error or a deliberate termination 83 | // by the client. 84 | Close func(Connection) 85 | 86 | // StartOperation is called whenever the client demands that a GraphQL 87 | // operation be started (typically a subscription). Event handlers 88 | // are expected to take the necessary steps to register the operation 89 | // and send data back to the client with the results eventually. 90 | StartOperation func(Connection, string, *StartMessagePayload) []error 91 | 92 | // StopOperation is called whenever the client stops a previously 93 | // started GraphQL operation (typically a subscription). 
Event handlers 94 | // are expected to unregister the operation and stop sending result 95 | // data to the client. 96 | StopOperation func(Connection, string) 97 | } 98 | 99 | // ConnectionConfig defines the configuration parameters of a 100 | // GraphQL WebSocket connection. 101 | type ConnectionConfig struct { 102 | Logger logger.Logger 103 | Authenticate AuthenticateFunc 104 | EventHandlers ConnectionEventHandlers 105 | } 106 | 107 | // Connection is an interface to represent GraphQL WebSocket connections. 108 | // Each connection is associated with an ID that is unique to the server. 109 | type Connection interface { 110 | // ID returns the unique ID of the connection. 111 | ID() string 112 | 113 | // Context returns the context for the connection 114 | Context() context.Context 115 | 116 | // WS the websocket 117 | WS() *websocket.Conn 118 | 119 | // SendData sends results of executing an operation (typically a 120 | // subscription) to the client. 121 | SendData(string, *DataMessagePayload) 122 | 123 | // SendError sends an error to the client. 124 | SendError(error) 125 | } 126 | 127 | /** 128 | * The default implementation of the Connection interface. 129 | */ 130 | 131 | type connection struct { 132 | id string 133 | ws *websocket.Conn 134 | config ConnectionConfig 135 | logger logger.Logger 136 | outgoing chan OperationMessage 137 | closeMutex *sync.Mutex 138 | closed bool 139 | context context.Context 140 | } 141 | 142 | func operationMessageForType(messageType string) OperationMessage { 143 | return OperationMessage{ 144 | Type: messageType, 145 | } 146 | } 147 | 148 | // NewConnection establishes a GraphQL WebSocket connection. It implements 149 | // the GraphQL WebSocket protocol by managing its internal state and handling 150 | // the client-server communication. 
151 | func NewConnection(ws *websocket.Conn, config ConnectionConfig) Connection { 152 | conn := new(connection) 153 | conn.id = uuid.New().String() 154 | conn.ws = ws 155 | conn.context = context.Background() 156 | conn.config = config 157 | conn.logger = config.Logger 158 | conn.closed = false 159 | conn.closeMutex = &sync.Mutex{} 160 | conn.outgoing = make(chan OperationMessage) 161 | 162 | go conn.writeLoop() 163 | go conn.readLoop() 164 | conn.logger.Infof("Created connection") 165 | 166 | return conn 167 | } 168 | 169 | func (conn *connection) ID() string { 170 | return conn.id 171 | } 172 | 173 | func (conn *connection) Context() context.Context { 174 | return conn.context 175 | } 176 | 177 | func (conn *connection) WS() *websocket.Conn { 178 | return conn.ws 179 | } 180 | 181 | func (conn *connection) SendData(opID string, data *DataMessagePayload) { 182 | msg := operationMessageForType(gqlData) 183 | msg.ID = opID 184 | msg.Payload = data 185 | conn.closeMutex.Lock() 186 | if !conn.closed { 187 | conn.outgoing <- msg 188 | } 189 | conn.closeMutex.Unlock() 190 | } 191 | 192 | func (conn *connection) SendError(err error) { 193 | msg := operationMessageForType(gqlError) 194 | msg.Payload = err.Error() 195 | conn.closeMutex.Lock() 196 | if !conn.closed { 197 | conn.outgoing <- msg 198 | } 199 | conn.closeMutex.Unlock() 200 | } 201 | 202 | func (conn *connection) sendOperationErrors(opID string, errs []error) { 203 | if conn.closed { 204 | return 205 | } 206 | 207 | msg := operationMessageForType(gqlError) 208 | msg.ID = opID 209 | msg.Payload = errs 210 | conn.closeMutex.Lock() 211 | if !conn.closed { 212 | conn.outgoing <- msg 213 | } 214 | 215 | conn.closeMutex.Unlock() 216 | } 217 | 218 | func (conn *connection) close() { 219 | // Close the write loop by closing the outgoing messages channels 220 | conn.closeMutex.Lock() 221 | conn.closed = true 222 | close(conn.outgoing) 223 | conn.closeMutex.Unlock() 224 | 225 | // Notify event handlers 226 | if 
conn.config.EventHandlers.Close != nil { 227 | conn.config.EventHandlers.Close(conn) 228 | } 229 | 230 | conn.logger.Infof("closed connection") 231 | } 232 | 233 | func (conn *connection) writeLoop() { 234 | // Close the WebSocket connection when leaving the write loop; 235 | // this ensures the read loop is also terminated and the connection 236 | // closed cleanly 237 | defer conn.ws.Close() 238 | 239 | for { 240 | msg, ok := <-conn.outgoing 241 | // Close the write loop when the outgoing messages channel is closed; 242 | // this will close the connection 243 | if !ok { 244 | return 245 | } 246 | 247 | // conn.logger.Debugf("send message: %s", msg.String()) 248 | conn.ws.SetWriteDeadline(time.Now().Add(writeTimeout)) 249 | 250 | // Send the message to the client; if this times out, the WebSocket 251 | // connection will be corrupt, hence we need to close the write loop 252 | // and the connection immediately 253 | if err := conn.ws.WriteJSON(msg); err != nil { 254 | conn.logger.Warnf("sending message failed: %s", err) 255 | return 256 | } 257 | } 258 | } 259 | 260 | func (conn *connection) readLoop() { 261 | // Close the WebSocket connection when leaving the read loop 262 | defer conn.ws.Close() 263 | conn.ws.SetReadLimit(readLimit) 264 | 265 | for { 266 | // Read the next message received from the client 267 | rawPayload := json.RawMessage{} 268 | msg := OperationMessage{ 269 | Payload: &rawPayload, 270 | } 271 | err := conn.ws.ReadJSON(&msg) 272 | 273 | // If this causes an error, close the connection and read loop immediately; 274 | // see https://github.com/gorilla/websocket/blob/master/conn.go#L924 for 275 | // more information on why this is necessary 276 | if err != nil { 277 | conn.logger.Warnf("force closing connection: %s", err) 278 | conn.close() 279 | return 280 | } 281 | 282 | // conn.logger.Debugf("received message (%s): %s", msg.ID, msg.Type) 283 | 284 | switch msg.Type { 285 | case gqlConnectionAuth: 286 | data := map[string]interface{}{} 287 | if 
err := json.Unmarshal(rawPayload, &data); err != nil { 288 | conn.logger.Errorf("Invalid %s data: %v", msg.Type, err) 289 | conn.SendError(errors.New("invalid GQL_CONNECTION_AUTH payload")) 290 | } else { 291 | if conn.config.Authenticate != nil { 292 | ctx, err := conn.config.Authenticate(data, conn) 293 | if err != nil { 294 | msg := operationMessageForType(gqlConnectionError) 295 | msg.Payload = fmt.Sprintf("Failed to authenticate user: %v", err) 296 | conn.outgoing <- msg 297 | } else { 298 | conn.context = ctx 299 | } 300 | } 301 | } 302 | 303 | // When the GraphQL WS connection is initiated, send an ACK back 304 | case gqlConnectionInit: 305 | data := map[string]interface{}{} 306 | if err := json.Unmarshal(rawPayload, &data); err != nil { 307 | conn.logger.Errorf("Invalid %s data: %v", msg.Type, err) 308 | conn.SendError(errors.New("invalid GQL_CONNECTION_INIT payload")) 309 | } else { 310 | if conn.config.Authenticate != nil { 311 | ctx, err := conn.config.Authenticate(data, conn) 312 | if err != nil { 313 | msg := operationMessageForType(gqlConnectionError) 314 | msg.Payload = fmt.Sprintf("Failed to authenticate user: %v", err) 315 | conn.outgoing <- msg 316 | } else { 317 | conn.context = ctx 318 | conn.outgoing <- operationMessageForType(gqlConnectionAck) 319 | } 320 | } else { 321 | conn.outgoing <- operationMessageForType(gqlConnectionAck) 322 | } 323 | } 324 | 325 | // Let event handlers deal with starting operations 326 | case gqlStart: 327 | if conn.config.EventHandlers.StartOperation != nil { 328 | data := StartMessagePayload{} 329 | if err := json.Unmarshal(rawPayload, &data); err != nil { 330 | conn.SendError(errors.New("invalid GQL_START payload")) 331 | } else { 332 | errs := conn.config.EventHandlers.StartOperation(conn, msg.ID, &data) 333 | if errs != nil { 334 | conn.sendOperationErrors(msg.ID, errs) 335 | } 336 | } 337 | } 338 | 339 | // Let event handlers deal with stopping operations 340 | case gqlStop: 341 | if 
conn.config.EventHandlers.StopOperation != nil { 342 | conn.config.EventHandlers.StopOperation(conn, msg.ID) 343 | } 344 | 345 | // When the GraphQL WS connection is terminated by the client, 346 | // close the connection and close the read loop 347 | case gqlConnectionTerminate: 348 | // conn.logger.Debugf("connection terminated by client") 349 | conn.close() 350 | return 351 | 352 | // GraphQL WS protocol messages that are not handled represent 353 | // a bug in our implementation; make this very obvious by logging 354 | // an error 355 | default: 356 | conn.logger.Errorf("unhandled message: %s", msg.String()) 357 | } 358 | } 359 | } 360 | -------------------------------------------------------------------------------- /server/handler.go: -------------------------------------------------------------------------------- 1 | package server 2 | 3 | import ( 4 | "bytes" 5 | "context" 6 | "encoding/json" 7 | "io/ioutil" 8 | "net/http" 9 | "net/url" 10 | "strings" 11 | 12 | "github.com/graphql-go/graphql" 13 | "github.com/graphql-go/graphql/gqlerrors" 14 | ) 15 | 16 | // RequestOptions options 17 | type RequestOptions struct { 18 | Query string `json:"query" url:"query" schema:"query"` 19 | Variables map[string]interface{} `json:"variables" url:"variables" schema:"variables"` 20 | OperationName string `json:"operationName" url:"operationName" schema:"operationName"` 21 | } 22 | 23 | // a workaround for getting`variables` as a JSON string 24 | type requestOptionsCompatibility struct { 25 | Query string `json:"query" url:"query" schema:"query"` 26 | Variables string `json:"variables" url:"variables" schema:"variables"` 27 | OperationName string `json:"operationName" url:"operationName" schema:"operationName"` 28 | } 29 | 30 | func getFromForm(values url.Values) *RequestOptions { 31 | query := values.Get("query") 32 | if query != "" { 33 | // get variables map 34 | variables := make(map[string]interface{}, len(values)) 35 | variablesStr := values.Get("variables") 36 | 
json.Unmarshal([]byte(variablesStr), &variables) 37 | 38 | return &RequestOptions{ 39 | Query: query, 40 | Variables: variables, 41 | OperationName: values.Get("operationName"), 42 | } 43 | } 44 | 45 | return nil 46 | } 47 | 48 | // NewRequestOptions Parses a http.Request into GraphQL request options struct 49 | func NewRequestOptions(r *http.Request) *RequestOptions { 50 | if reqOpt := getFromForm(r.URL.Query()); reqOpt != nil { 51 | return reqOpt 52 | } 53 | 54 | if r.Method != http.MethodPost { 55 | return &RequestOptions{} 56 | } 57 | 58 | if r.Body == nil { 59 | return &RequestOptions{} 60 | } 61 | 62 | // TODO: improve Content-Type handling 63 | contentTypeStr := r.Header.Get("Content-Type") 64 | contentTypeTokens := strings.Split(contentTypeStr, ";") 65 | contentType := contentTypeTokens[0] 66 | 67 | switch contentType { 68 | case ContentTypeGraphQL: 69 | body, err := ioutil.ReadAll(r.Body) 70 | if err != nil { 71 | return &RequestOptions{} 72 | } 73 | return &RequestOptions{ 74 | Query: string(body), 75 | } 76 | case ContentTypeFormURLEncoded: 77 | if err := r.ParseForm(); err != nil { 78 | return &RequestOptions{} 79 | } 80 | 81 | if reqOpt := getFromForm(r.PostForm); reqOpt != nil { 82 | return reqOpt 83 | } 84 | 85 | return &RequestOptions{} 86 | 87 | case ContentTypeJSON: 88 | fallthrough 89 | default: 90 | var opts RequestOptions 91 | body, err := ioutil.ReadAll(r.Body) 92 | if err != nil { 93 | return &opts 94 | } 95 | err = json.Unmarshal(body, &opts) 96 | if err != nil { 97 | // Probably `variables` was sent as a string instead of an object. 
98 | // So, we try to be polite and try to parse that as a JSON string 99 | var optsCompatible requestOptionsCompatibility 100 | json.Unmarshal(body, &optsCompatible) 101 | json.Unmarshal([]byte(optsCompatible.Variables), &opts.Variables) 102 | } 103 | return &opts 104 | } 105 | } 106 | 107 | // GetRequestOptions Parses a http.Request into GraphQL request options struct without clearning the body 108 | func GetRequestOptions(r *http.Request) *RequestOptions { 109 | if reqOpt := getFromForm(r.URL.Query()); reqOpt != nil { 110 | return reqOpt 111 | } 112 | 113 | if r.Method != http.MethodPost { 114 | return &RequestOptions{} 115 | } 116 | 117 | if r.Body == nil { 118 | return &RequestOptions{} 119 | } 120 | 121 | // TODO: improve Content-Type handling 122 | contentTypeStr := r.Header.Get("Content-Type") 123 | contentTypeTokens := strings.Split(contentTypeStr, ";") 124 | contentType := contentTypeTokens[0] 125 | 126 | switch contentType { 127 | case ContentTypeGraphQL: 128 | body, err := ioutil.ReadAll(r.Body) 129 | if err != nil { 130 | return &RequestOptions{} 131 | } 132 | r.Body = ioutil.NopCloser(bytes.NewBuffer(body)) 133 | return &RequestOptions{ 134 | Query: string(body), 135 | } 136 | case ContentTypeFormURLEncoded: 137 | if err := r.ParseForm(); err != nil { 138 | return &RequestOptions{} 139 | } 140 | 141 | if reqOpt := getFromForm(r.PostForm); reqOpt != nil { 142 | return reqOpt 143 | } 144 | 145 | return &RequestOptions{} 146 | 147 | case ContentTypeJSON: 148 | fallthrough 149 | default: 150 | var opts RequestOptions 151 | body, err := ioutil.ReadAll(r.Body) 152 | if err != nil { 153 | return &opts 154 | } 155 | r.Body = ioutil.NopCloser(bytes.NewBuffer(body)) 156 | err = json.Unmarshal(body, &opts) 157 | if err != nil { 158 | // Probably `variables` was sent as a string instead of an object. 
159 | // So, we try to be polite and try to parse that as a JSON string 160 | var optsCompatible requestOptionsCompatibility 161 | json.Unmarshal(body, &optsCompatible) 162 | json.Unmarshal([]byte(optsCompatible.Variables), &opts.Variables) 163 | } 164 | return &opts 165 | } 166 | } 167 | 168 | // ContextHandler provides an entrypoint into executing graphQL queries with a 169 | // user-provided context. 170 | func (s *Server) ContextHandler(ctx context.Context, w http.ResponseWriter, r *http.Request) { 171 | // get query 172 | opts := NewRequestOptions(r) 173 | 174 | // execute graphql query 175 | params := graphql.Params{ 176 | Schema: s.schema, 177 | RequestString: opts.Query, 178 | VariableValues: opts.Variables, 179 | OperationName: opts.OperationName, 180 | Context: ctx, 181 | } 182 | if s.options.RootValueFunc != nil { 183 | params.RootObject = s.options.RootValueFunc(ctx, r) 184 | } 185 | result := graphql.Do(params) 186 | 187 | if formatErrorFunc := s.options.FormatErrorFunc; formatErrorFunc != nil && len(result.Errors) > 0 { 188 | formatted := make([]gqlerrors.FormattedError, len(result.Errors)) 189 | for i, formattedError := range result.Errors { 190 | formatted[i] = formatErrorFunc(formattedError.OriginalError()) 191 | } 192 | result.Errors = formatted 193 | } 194 | 195 | if s.options.GraphiQL != nil { 196 | acceptHeader := r.Header.Get("Accept") 197 | _, raw := r.URL.Query()["raw"] 198 | if !raw && !strings.Contains(acceptHeader, "application/json") && strings.Contains(acceptHeader, "text/html") { 199 | renderGraphiQL(s.options.GraphiQL, w, r, params) 200 | return 201 | } 202 | } else if s.options.Playground != nil { 203 | acceptHeader := r.Header.Get("Accept") 204 | _, raw := r.URL.Query()["raw"] 205 | if !raw && !strings.Contains(acceptHeader, "application/json") && strings.Contains(acceptHeader, "text/html") { 206 | renderPlayground(s.options.Playground, w, r) 207 | return 208 | } 209 | } 210 | 211 | // use proper JSON Header 212 | 
w.Header().Add("Content-Type", "application/json; charset=utf-8") 213 | 214 | var buff []byte 215 | if s.options.Pretty { 216 | w.WriteHeader(http.StatusOK) 217 | buff, _ = json.MarshalIndent(result, "", "\t") 218 | 219 | w.Write(buff) 220 | } else { 221 | w.WriteHeader(http.StatusOK) 222 | buff, _ = json.Marshal(result) 223 | 224 | w.Write(buff) 225 | } 226 | 227 | if s.options.ResultCallbackFunc != nil { 228 | s.options.ResultCallbackFunc(ctx, ¶ms, result, buff) 229 | } 230 | } 231 | 232 | func (s *Server) WSHandler(ctx context.Context, w http.ResponseWriter, r *http.Request) { 233 | // Establish a WebSocket connection 234 | var ws, err = s.upgrader.Upgrade(w, r, nil) 235 | 236 | // Bail out if the WebSocket connection could not be established 237 | if err != nil { 238 | s.log.Warnf("Failed to establish WebSocket connection", err) 239 | return 240 | } 241 | 242 | // Close the connection early if it doesn't implement the graphql-ws protocol 243 | if ws.Subprotocol() == "graphql-ws" { 244 | s.newGraphQLWSConnection(ctx, r, ws) 245 | return 246 | } 247 | 248 | // TODO: support other popular protocols 249 | s.log.Warnf("Connection does not implement the GraphQL WS protocol. 
Subprotocol: %s", ws.Subprotocol()) 250 | ws.Close() 251 | } 252 | -------------------------------------------------------------------------------- /server/logger/logger.go: -------------------------------------------------------------------------------- 1 | package logger 2 | 3 | type Logger interface { 4 | Infof(format string, data ...interface{}) 5 | Debugf(format string, data ...interface{}) 6 | Errorf(format string, data ...interface{}) 7 | Warnf(format string, data ...interface{}) 8 | } 9 | 10 | type NoopLogger struct{} 11 | 12 | func (n *NoopLogger) Infof(format string, data ...interface{}) {} 13 | func (n *NoopLogger) Debugf(format string, data ...interface{}) {} 14 | func (n *NoopLogger) Errorf(format string, data ...interface{}) {} 15 | func (n *NoopLogger) Warnf(format string, data ...interface{}) {} 16 | -------------------------------------------------------------------------------- /server/manager.go: -------------------------------------------------------------------------------- 1 | package server 2 | 3 | import ( 4 | "context" 5 | "sync" 6 | 7 | "github.com/graphql-go/graphql" 8 | ) 9 | 10 | type ChanMgr struct { 11 | mx sync.Mutex 12 | conns map[string]map[string]*ResultChan 13 | } 14 | 15 | type ResultChan struct { 16 | ch chan *graphql.Result 17 | cancelFunc context.CancelFunc 18 | ctx context.Context 19 | cid string 20 | oid string 21 | } 22 | 23 | func (c *ChanMgr) Add(rc *ResultChan) { // Add(cid, oid string, ch chan *graphql.Result) { 24 | c.mx.Lock() 25 | defer c.mx.Unlock() 26 | 27 | conn, ok := c.conns[rc.cid] 28 | if !ok { 29 | conn = make(map[string]*ResultChan) 30 | c.conns[rc.cid] = conn 31 | } 32 | 33 | conn[rc.oid] = rc 34 | } 35 | 36 | func (c *ChanMgr) DelConn(cid string) bool { 37 | c.mx.Lock() 38 | defer c.mx.Unlock() 39 | 40 | conn, ok := c.conns[cid] 41 | if !ok { 42 | return false 43 | } 44 | 45 | for oid, rc := range conn { 46 | if rc.cancelFunc != nil { 47 | rc.cancelFunc() 48 | } 49 | delete(conn, oid) 50 | } 51 | 52 | 
delete(conn, cid) 53 | return true 54 | } 55 | 56 | func (c *ChanMgr) Del(cid, oid string) bool { 57 | c.mx.Lock() 58 | defer c.mx.Unlock() 59 | 60 | conn, ok := c.conns[cid] 61 | if !ok { 62 | return false 63 | } 64 | 65 | rc, ok := conn[oid] 66 | if !ok { 67 | return false 68 | } 69 | 70 | if rc.cancelFunc != nil { 71 | rc.cancelFunc() 72 | } 73 | delete(conn, oid) 74 | 75 | if len(c.conns[cid]) == 0 { 76 | delete(c.conns, cid) 77 | } 78 | 79 | return true 80 | } 81 | -------------------------------------------------------------------------------- /server/playground.go: -------------------------------------------------------------------------------- 1 | package server 2 | 3 | import ( 4 | "fmt" 5 | "html/template" 6 | "net/http" 7 | ) 8 | 9 | // PlaygroundVersion the default version to use 10 | var PlaygroundVersion = "1.7.27" 11 | 12 | type PlaygroundOptions struct { 13 | Version string 14 | SSL bool 15 | Endpoint string 16 | SubscriptionEndpoint string 17 | } 18 | 19 | func NewDefaultPlaygroundOptions() *PlaygroundOptions { 20 | return &PlaygroundOptions{ 21 | Version: PlaygroundVersion, 22 | } 23 | } 24 | 25 | func NewDefaultSSLPlaygroundOptions() *PlaygroundOptions { 26 | return &PlaygroundOptions{ 27 | Version: PlaygroundVersion, 28 | SSL: true, 29 | } 30 | } 31 | 32 | type playgroundData struct { 33 | PlaygroundVersion string 34 | Endpoint string 35 | SubscriptionEndpoint string 36 | SetTitle bool 37 | } 38 | 39 | // renderPlayground renders the Playground GUI 40 | func renderPlayground(config *PlaygroundOptions, w http.ResponseWriter, r *http.Request) { 41 | fmt.Printf("CONFIG %+v\n", config) 42 | 43 | t := template.New("Playground") 44 | t, err := t.Parse(graphcoolPlaygroundTemplate) 45 | if err != nil { 46 | http.Error(w, err.Error(), http.StatusInternalServerError) 47 | return 48 | } 49 | 50 | endpoint := r.URL.Path 51 | if config.Endpoint != "" { 52 | endpoint = config.Endpoint 53 | } 54 | 55 | wsScheme := "ws:" 56 | if config.SSL { 57 | wsScheme = 
"wss:" 58 | } 59 | 60 | subscriptionEndpoint := fmt.Sprintf("%s//%v%s", wsScheme, r.Host, r.URL.Path) 61 | if config.SubscriptionEndpoint != "" { 62 | subscriptionEndpoint = config.SubscriptionEndpoint 63 | } 64 | 65 | version := PlaygroundVersion 66 | if config.Version != "" { 67 | version = config.Version 68 | } 69 | 70 | d := playgroundData{ 71 | PlaygroundVersion: version, 72 | Endpoint: endpoint, 73 | SubscriptionEndpoint: subscriptionEndpoint, 74 | SetTitle: true, 75 | } 76 | err = t.ExecuteTemplate(w, "index", d) 77 | if err != nil { 78 | http.Error(w, err.Error(), http.StatusInternalServerError) 79 | } 80 | } 81 | 82 | const graphcoolPlaygroundTemplate = ` 83 | {{ define "index" }} 84 | 85 | 86 | 87 | 88 | 89 | 90 | 91 | GraphQL Playground 92 | 93 | 94 | 95 | 96 | 97 | 98 | 99 | 145 | 146 | 567 |
568 | 601 |
Loading 602 | GraphQL Playground 603 |
604 |
605 | 606 |
607 | 624 | 625 | 626 | {{ end }} 627 | ` 628 | -------------------------------------------------------------------------------- /server/playground.tmpl: -------------------------------------------------------------------------------- 1 | {{ define "index" }} 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | GraphQL Playground 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 63 | 64 | 485 |
486 | 519 |
Loading 520 | GraphQL Playground 521 |
522 |
523 | 524 |
525 | 542 | 543 | 544 | {{ end }} -------------------------------------------------------------------------------- /server/server.go: -------------------------------------------------------------------------------- 1 | package server 2 | 3 | import ( 4 | "context" 5 | "net/http" 6 | "strings" 7 | 8 | "github.com/bhoriuchi/graphql-go-tools/server/graphqlws" 9 | "github.com/bhoriuchi/graphql-go-tools/server/logger" 10 | "github.com/gorilla/websocket" 11 | "github.com/graphql-go/graphql" 12 | "github.com/graphql-go/graphql/gqlerrors" 13 | ) 14 | 15 | // Constants 16 | const ( 17 | ContentTypeJSON = "application/json" 18 | ContentTypeGraphQL = "application/graphql" 19 | ContentTypeFormURLEncoded = "application/x-www-form-urlencoded" 20 | ) 21 | 22 | // ConnKey the connection key 23 | var ConnKey interface{} = "conn" 24 | 25 | type Server struct { 26 | schema graphql.Schema 27 | log logger.Logger 28 | options *Options 29 | upgrader websocket.Upgrader 30 | mgr *ChanMgr 31 | } 32 | 33 | func New(schema graphql.Schema, options *Options) *Server { 34 | if options.Logger == nil { 35 | options.Logger = &logger.NoopLogger{} 36 | } 37 | 38 | return &Server{ 39 | schema: schema, 40 | log: options.Logger, 41 | options: options, 42 | upgrader: websocket.Upgrader{ 43 | CheckOrigin: func(r *http.Request) bool { return true }, 44 | Subprotocols: []string{"graphql-ws"}, 45 | }, 46 | mgr: &ChanMgr{ 47 | conns: make(map[string]map[string]*ResultChan), 48 | }, 49 | } 50 | } 51 | 52 | type RootValueFunc func(ctx context.Context, r *http.Request) map[string]interface{} 53 | 54 | type FormatErrorFunc func(err error) gqlerrors.FormattedError 55 | 56 | type ContextFunc func(r *http.Request) context.Context 57 | 58 | type ResultCallbackFunc func(ctx context.Context, params *graphql.Params, result *graphql.Result, responseBody []byte) 59 | 60 | type Options struct { 61 | Pretty bool 62 | RootValueFunc RootValueFunc 63 | FormatErrorFunc FormatErrorFunc 64 | ContextFunc ContextFunc 65 | 
WSContextFunc ContextFunc 66 | ResultCallbackFunc ResultCallbackFunc 67 | Logger logger.Logger 68 | WS *WSOptions 69 | Playground *PlaygroundOptions 70 | GraphiQL *GraphiQLOptions 71 | } 72 | 73 | type WSOptions struct { 74 | AuthenticateFunc graphqlws.AuthenticateFunc 75 | } 76 | 77 | func IsWSUpgrade(r *http.Request) bool { 78 | connection := strings.ToLower(r.Header.Get("Connection")) 79 | upgrade := strings.ToLower(r.Header.Get("Upgrade")) 80 | return connection == "upgrade" && upgrade == "websocket" 81 | } 82 | 83 | // ServeHTTP provides an entrypoint into executing graphQL queries. 84 | func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { 85 | if IsWSUpgrade(r) { 86 | s.log.Debugf("Upgrading connection to websocket") 87 | ctx := r.Context() 88 | if s.options.WSContextFunc != nil { 89 | ctx = s.options.WSContextFunc(r) 90 | } 91 | s.WSHandler(ctx, w, r) 92 | } else { 93 | ctx := r.Context() 94 | if s.options.ContextFunc != nil { 95 | ctx = s.options.ContextFunc(r) 96 | } 97 | s.ContextHandler(ctx, w, r) 98 | } 99 | } 100 | -------------------------------------------------------------------------------- /typedefs.go: -------------------------------------------------------------------------------- 1 | package tools 2 | 3 | import ( 4 | "fmt" 5 | "strings" 6 | 7 | "github.com/graphql-go/graphql/language/ast" 8 | "github.com/graphql-go/graphql/language/parser" 9 | "github.com/graphql-go/graphql/language/printer" 10 | "github.com/graphql-go/graphql/language/source" 11 | ) 12 | 13 | // ConcatenateTypeDefs combines one ore more typeDefs into an ast Document 14 | func (c *ExecutableSchema) ConcatenateTypeDefs() (*ast.Document, error) { 15 | switch c.TypeDefs.(type) { 16 | case string: 17 | return c.concatenateTypeDefs([]string{c.TypeDefs.(string)}) 18 | case []string: 19 | return c.concatenateTypeDefs(c.TypeDefs.([]string)) 20 | case func() []string: 21 | return c.concatenateTypeDefs(c.TypeDefs.(func() []string)()) 22 | } 23 | return nil, 
fmt.Errorf("unsupported TypeDefs value. Must be one of string, []string, or func() []string") 24 | } 25 | 26 | // performs the actual concatenation of the types by parsing each 27 | // typeDefs string and converting each definition into a string 28 | // then creating a unique list of all definitions and finally 29 | // printing them as a single definition and returning the parsed document 30 | func (c *ExecutableSchema) concatenateTypeDefs(typeDefs []string) (*ast.Document, error) { 31 | resolvedTypes := map[string]interface{}{} 32 | for _, defs := range typeDefs { 33 | doc, err := parser.Parse(parser.ParseParams{ 34 | Source: &source.Source{ 35 | Body: []byte(defs), 36 | Name: "GraphQL", 37 | }, 38 | }) 39 | if err != nil { 40 | return nil, err 41 | } 42 | 43 | // if there is only 1 typedef, no de-duplication needs to happen 44 | if len(typeDefs) == 1 { 45 | return doc, nil 46 | } 47 | 48 | for _, typeDef := range doc.Definitions { 49 | if def := printer.Print(typeDef); def != nil { 50 | stringDef := strings.TrimSpace(def.(string)) 51 | resolvedTypes[stringDef] = nil 52 | } 53 | } 54 | } 55 | 56 | typeArray := []string{} 57 | for def := range resolvedTypes { 58 | typeArray = append(typeArray, def) 59 | } 60 | 61 | doc, err := parser.Parse(parser.ParseParams{ 62 | Source: &source.Source{ 63 | Body: []byte(strings.Join(typeArray, "\n")), 64 | Name: "GraphQL", 65 | }, 66 | }) 67 | 68 | if err != nil { 69 | return nil, err 70 | } 71 | 72 | return doc, nil 73 | } 74 | -------------------------------------------------------------------------------- /typedefs_test.go: -------------------------------------------------------------------------------- 1 | package tools 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/graphql-go/graphql" 7 | ) 8 | 9 | func TestConcatenateTypeDefs(t *testing.T) { 10 | config := ExecutableSchema{ 11 | TypeDefs: []string{ 12 | "type Query{}", 13 | ` 14 | # a foo 15 | type Foo { 16 | name: String! 
17 | description: String 18 | } 19 | 20 | extend type Query { 21 | foo: Foo 22 | }`, 23 | ` 24 | interface Named { 25 | name: String! 26 | } 27 | 28 | type Bar implements Named { 29 | name: String! 30 | description: String 31 | } 32 | 33 | extend type Query { 34 | bar: Bar 35 | }`, 36 | }, 37 | } 38 | 39 | schema, err := MakeExecutableSchema(config) 40 | if err != nil { 41 | t.Errorf("failed to make schema from concatenated TypeDefs: %v", err) 42 | return 43 | } 44 | 45 | // perform a query 46 | r := graphql.Do(graphql.Params{ 47 | Schema: schema, 48 | RequestString: `query Query { 49 | foo { 50 | name 51 | } 52 | bar { 53 | name 54 | } 55 | }`, 56 | }) 57 | 58 | if r.HasErrors() { 59 | t.Error(r.Errors) 60 | return 61 | } 62 | } 63 | 64 | func TestObjectIsTypeOf(t *testing.T) { 65 | config := ExecutableSchema{ 66 | TypeDefs: []string{ 67 | "type Query{}", 68 | ` 69 | # a foo 70 | type A { 71 | name: String! 72 | } 73 | type B { 74 | description: String 75 | } 76 | union Foo = A | B 77 | 78 | extend type Query { 79 | foo: Foo 80 | }`, 81 | }, 82 | Resolvers: map[string]interface{}{ 83 | "A": &ObjectResolver{ 84 | IsTypeOf: func(p graphql.IsTypeOfParams) bool { 85 | return true 86 | }, 87 | }, 88 | "B": &ObjectResolver{ 89 | IsTypeOf: func(p graphql.IsTypeOfParams) bool { 90 | return false 91 | }, 92 | }, 93 | }, 94 | } 95 | 96 | schema, err := MakeExecutableSchema(config) 97 | if err != nil { 98 | t.Errorf("failed to make schema from concatenated TypeDefs: %v", err) 99 | return 100 | } 101 | 102 | // perform a query 103 | r := graphql.Do(graphql.Params{ 104 | Schema: schema, 105 | RequestString: `query Query { 106 | foo { 107 | ...on A { 108 | name 109 | } 110 | ...on B { 111 | description 112 | } 113 | } 114 | }`, 115 | }) 116 | 117 | if r.HasErrors() { 118 | t.Error(r.Errors) 119 | return 120 | } 121 | } 122 | 123 | func TestUnionResolveType(t *testing.T) { 124 | config := ExecutableSchema{ 125 | TypeDefs: []string{ 126 | "type Query{}", 127 | ` 128 | # a foo 129 
| type A { 130 | name: String! 131 | } 132 | type B { 133 | description: String 134 | } 135 | union Foo = A | B 136 | 137 | extend type Query { 138 | foo: Foo 139 | }`, 140 | }, 141 | Resolvers: map[string]interface{}{ 142 | "Foo": &UnionResolver{ 143 | ResolveType: func(p graphql.ResolveTypeParams) *graphql.Object { 144 | return p.Info.Schema.TypeMap()["A"].(*graphql.Object) 145 | }, 146 | }, 147 | }, 148 | } 149 | 150 | schema, err := MakeExecutableSchema(config) 151 | if err != nil { 152 | t.Errorf("failed to make schema from concatenated TypeDefs: %v", err) 153 | return 154 | } 155 | 156 | // perform a query 157 | r := graphql.Do(graphql.Params{ 158 | Schema: schema, 159 | RequestString: `query Query { 160 | foo { 161 | ...on A { 162 | name 163 | } 164 | ...on B { 165 | description 166 | } 167 | } 168 | }`, 169 | }) 170 | 171 | if r.HasErrors() { 172 | t.Error(r.Errors) 173 | return 174 | } 175 | } 176 | -------------------------------------------------------------------------------- /types.go: -------------------------------------------------------------------------------- 1 | package tools 2 | 3 | import ( 4 | "fmt" 5 | 6 | "github.com/graphql-go/graphql" 7 | "github.com/graphql-go/graphql/language/ast" 8 | "github.com/graphql-go/graphql/language/kinds" 9 | ) 10 | 11 | // builds a scalar from ast 12 | func (c *registry) buildScalarFromAST(definition *ast.ScalarDefinition) error { 13 | name := definition.Name.Value 14 | scalarConfig := graphql.ScalarConfig{ 15 | Name: name, 16 | Description: getDescription(definition), 17 | } 18 | 19 | if r := c.getResolver(name); r != nil && r.getKind() == kinds.ScalarDefinition { 20 | scalarConfig.ParseLiteral = r.(*ScalarResolver).ParseLiteral 21 | scalarConfig.ParseValue = r.(*ScalarResolver).ParseValue 22 | scalarConfig.Serialize = r.(*ScalarResolver).Serialize 23 | } 24 | 25 | if err := c.applyDirectives(applyDirectiveParams{ 26 | config: &scalarConfig, 27 | directives: definition.Directives, 28 | node: definition, 29 | 
}); err != nil { 30 | return err 31 | } 32 | 33 | c.types[name] = graphql.NewScalar(scalarConfig) 34 | return nil 35 | } 36 | 37 | // builds an enum from ast 38 | func (c *registry) buildEnumFromAST(definition *ast.EnumDefinition) error { 39 | name := definition.Name.Value 40 | enumConfig := graphql.EnumConfig{ 41 | Name: name, 42 | Description: getDescription(definition), 43 | Values: graphql.EnumValueConfigMap{}, 44 | } 45 | 46 | for _, value := range definition.Values { 47 | if value != nil { 48 | val, err := c.buildEnumValueFromAST(value, name) 49 | if err != nil { 50 | return err 51 | } 52 | enumConfig.Values[value.Name.Value] = val 53 | } 54 | } 55 | 56 | if err := c.applyDirectives(applyDirectiveParams{ 57 | config: &enumConfig, 58 | directives: definition.Directives, 59 | node: definition, 60 | }); err != nil { 61 | return err 62 | } 63 | 64 | c.types[name] = graphql.NewEnum(enumConfig) 65 | return nil 66 | } 67 | 68 | // builds an enum value from an ast 69 | func (c *registry) buildEnumValueFromAST(definition *ast.EnumValueDefinition, enumName string) (*graphql.EnumValueConfig, error) { 70 | var value interface{} 71 | value = definition.Name.Value 72 | 73 | if r := c.getResolver(enumName); r != nil && r.getKind() == kinds.EnumDefinition { 74 | if val, ok := r.(*EnumResolver).Values[definition.Name.Value]; ok { 75 | value = val 76 | } 77 | } 78 | 79 | valueConfig := graphql.EnumValueConfig{ 80 | Value: value, 81 | Description: getDescription(definition), 82 | } 83 | 84 | if err := c.applyDirectives(applyDirectiveParams{ 85 | config: &valueConfig, 86 | directives: definition.Directives, 87 | node: definition, 88 | }); err != nil { 89 | return nil, err 90 | } 91 | 92 | return &valueConfig, nil 93 | } 94 | 95 | // builds an input from ast 96 | func (c *registry) buildInputObjectFromAST(definition *ast.InputObjectDefinition) error { 97 | var fields interface{} 98 | name := definition.Name.Value 99 | inputConfig := graphql.InputObjectConfig{ 100 | Name: name, 
101 | Description: getDescription(definition), 102 | Fields: fields, 103 | } 104 | 105 | // use thunks only when allowed 106 | if _, ok := c.dependencyMap[name]; ok { 107 | var fields graphql.InputObjectConfigFieldMapThunk = func() graphql.InputObjectConfigFieldMap { 108 | fieldMap, err := c.buildInputObjectFieldMapFromAST(definition.Fields) 109 | if err != nil { 110 | return nil 111 | } 112 | return fieldMap 113 | } 114 | inputConfig.Fields = fields 115 | } else { 116 | fieldMap, err := c.buildInputObjectFieldMapFromAST(definition.Fields) 117 | if err != nil { 118 | return err 119 | } 120 | inputConfig.Fields = fieldMap 121 | } 122 | 123 | if err := c.applyDirectives(applyDirectiveParams{ 124 | config: &inputConfig, 125 | directives: definition.Directives, 126 | node: definition, 127 | }); err != nil { 128 | return err 129 | } 130 | 131 | c.types[name] = graphql.NewInputObject(inputConfig) 132 | return nil 133 | } 134 | 135 | // builds an input object field map from ast 136 | func (c *registry) buildInputObjectFieldMapFromAST(fields []*ast.InputValueDefinition) (graphql.InputObjectConfigFieldMap, error) { 137 | fieldMap := graphql.InputObjectConfigFieldMap{} 138 | for _, fieldDef := range fields { 139 | field, err := c.buildInputObjectFieldFromAST(fieldDef) 140 | if err != nil { 141 | return nil, err 142 | } 143 | fieldMap[fieldDef.Name.Value] = field 144 | } 145 | return fieldMap, nil 146 | } 147 | 148 | // builds an input object field from an AST 149 | func (c *registry) buildInputObjectFieldFromAST(definition *ast.InputValueDefinition) (*graphql.InputObjectFieldConfig, error) { 150 | inputType, err := c.buildComplexType(definition.Type) 151 | if err != nil { 152 | return nil, err 153 | } 154 | 155 | defaultValue, err := getDefaultValue(definition) 156 | if err != nil { 157 | return nil, err 158 | } 159 | 160 | field := graphql.InputObjectFieldConfig{ 161 | Type: inputType, 162 | Description: getDescription(definition), 163 | DefaultValue: defaultValue, 164 | } 
165 | 166 | if err := c.applyDirectives(applyDirectiveParams{ 167 | config: &field, 168 | directives: definition.Directives, 169 | node: definition, 170 | }); err != nil { 171 | return nil, err 172 | } 173 | 174 | return &field, nil 175 | } 176 | 177 | // builds an object from an AST 178 | func (c *registry) buildObjectFromAST(definition *ast.ObjectDefinition) error { 179 | name := definition.Name.Value 180 | extensions := c.getExtensions(name, definition.GetKind()) 181 | objectConfig := graphql.ObjectConfig{ 182 | Name: name, 183 | Description: getDescription(definition), 184 | } 185 | 186 | if _, ok := c.dependencyMap[name]; ok { 187 | // get interfaces thunk 188 | var ifaces graphql.InterfacesThunk = func() []*graphql.Interface { 189 | ifaceArr, err := c.buildInterfacesArrayFromAST(definition, extensions) 190 | if err != nil { 191 | return nil 192 | } 193 | return ifaceArr 194 | } 195 | 196 | // get fields thunk 197 | var fields graphql.FieldsThunk = func() graphql.Fields { 198 | fieldMap, err := c.buildFieldMapFromAST(definition.Fields, definition.GetKind(), name, extensions) 199 | if err != nil { 200 | return nil 201 | } 202 | return fieldMap 203 | } 204 | 205 | objectConfig.Interfaces = ifaces 206 | objectConfig.Fields = fields 207 | 208 | } else { 209 | // get interfaces 210 | ifaceArr, err := c.buildInterfacesArrayFromAST(definition, extensions) 211 | if err != nil { 212 | return err 213 | } 214 | 215 | // get fields 216 | fieldMap, err := c.buildFieldMapFromAST(definition.Fields, definition.GetKind(), name, extensions) 217 | if err != nil { 218 | return err 219 | } 220 | 221 | objectConfig.Interfaces = ifaceArr 222 | objectConfig.Fields = fieldMap 223 | } 224 | 225 | // set IsTypeOf from resolvers 226 | if r := c.getResolver(name); r != nil { 227 | if resolver, ok := r.(*ObjectResolver); ok { 228 | objectConfig.IsTypeOf = resolver.IsTypeOf 229 | } 230 | } 231 | 232 | // update description from extensions if none 233 | for _, extDef := range extensions { 
234 | if objectConfig.Description != "" { 235 | break 236 | } 237 | objectConfig.Description = getDescription(extDef) 238 | } 239 | 240 | // create a combined directives array 241 | directiveDefs := append([]*ast.Directive{}, definition.Directives...) 242 | for _, extDef := range extensions { 243 | directiveDefs = append(directiveDefs, extDef.Directives...) 244 | } 245 | 246 | if err := c.applyDirectives(applyDirectiveParams{ 247 | config: &objectConfig, 248 | directives: directiveDefs, 249 | extensions: extensions, 250 | node: definition, 251 | }); err != nil { 252 | return err 253 | } 254 | 255 | c.types[name] = graphql.NewObject(objectConfig) 256 | return nil 257 | } 258 | 259 | func (c *registry) buildInterfacesArrayFromAST(definition *ast.ObjectDefinition, extensions []*ast.ObjectDefinition) ([]*graphql.Interface, error) { 260 | imap := map[string]bool{} 261 | ifaces := []*graphql.Interface{} 262 | 263 | // build list of interfaces and append extensions 264 | ifaceDefs := append([]*ast.Named{}, definition.Interfaces...) 265 | for _, extDef := range extensions { 266 | ifaceDefs = append(ifaceDefs, extDef.Interfaces...) 267 | } 268 | 269 | // add defined interfaces 270 | for _, ifaceDef := range ifaceDefs { 271 | if _, ok := imap[ifaceDef.Name.Value]; !ok { 272 | iface, err := c.getType(ifaceDef.Name.Value) 273 | if err != nil { 274 | return nil, err 275 | } 276 | ifaces = append(ifaces, iface.(*graphql.Interface)) 277 | imap[ifaceDef.Name.Value] = true 278 | } 279 | } 280 | 281 | return ifaces, nil 282 | } 283 | 284 | func (c *registry) buildFieldMapFromAST(fields []*ast.FieldDefinition, kind, typeName string, extensions []*ast.ObjectDefinition) (graphql.Fields, error) { 285 | fieldMap := graphql.Fields{} 286 | 287 | // build list of fields and append extensions 288 | fieldDefs := append([]*ast.FieldDefinition{}, fields...) 289 | for _, extDef := range extensions { 290 | fieldDefs = append(fieldDefs, extDef.Fields...) 
291 | } 292 | 293 | // add defined fields 294 | for _, fieldDef := range fieldDefs { 295 | if _, ok := fieldMap[fieldDef.Name.Value]; !ok { 296 | if field, err := c.buildFieldFromAST(fieldDef, kind, typeName); err == nil { 297 | if !isHiddenField(fieldDef) { 298 | fieldMap[fieldDef.Name.Value] = field 299 | } 300 | } else { 301 | return nil, err 302 | } 303 | } 304 | } 305 | 306 | return fieldMap, nil 307 | } 308 | 309 | // builds an interfacefrom ast 310 | func (c *registry) buildInterfaceFromAST(definition *ast.InterfaceDefinition) error { 311 | extensions := []*ast.ObjectDefinition{} 312 | name := definition.Name.Value 313 | ifaceConfig := graphql.InterfaceConfig{ 314 | Name: name, 315 | Description: getDescription(definition), 316 | } 317 | 318 | if _, ok := c.dependencyMap[name]; ok { 319 | var fields graphql.FieldsThunk = func() graphql.Fields { 320 | fieldMap, err := c.buildFieldMapFromAST(definition.Fields, definition.GetKind(), name, extensions) 321 | if err != nil { 322 | return nil 323 | } 324 | return fieldMap 325 | } 326 | ifaceConfig.Fields = fields 327 | } else { 328 | fieldMap, err := c.buildFieldMapFromAST(definition.Fields, definition.GetKind(), name, extensions) 329 | if err != nil { 330 | return err 331 | } 332 | ifaceConfig.Fields = fieldMap 333 | } 334 | 335 | if r := c.getResolver(name); r != nil && r.getKind() == kinds.InterfaceDefinition { 336 | ifaceConfig.ResolveType = r.(*InterfaceResolver).ResolveType 337 | } 338 | 339 | if err := c.applyDirectives(applyDirectiveParams{ 340 | config: &ifaceConfig, 341 | directives: definition.Directives, 342 | node: definition, 343 | }); err != nil { 344 | return err 345 | } 346 | 347 | c.types[name] = graphql.NewInterface(ifaceConfig) 348 | return nil 349 | } 350 | 351 | // builds an arg from an ast 352 | func (c *registry) buildArgFromAST(definition *ast.InputValueDefinition) (*graphql.ArgumentConfig, error) { 353 | inputType, err := c.buildComplexType(definition.Type) 354 | if err != nil { 355 | 
return nil, err 356 | } 357 | 358 | defaultValue, err := getDefaultValue(definition) 359 | if err != nil { 360 | return nil, err 361 | } 362 | 363 | arg := graphql.ArgumentConfig{ 364 | Type: inputType, 365 | Description: getDescription(definition), 366 | DefaultValue: defaultValue, 367 | } 368 | 369 | if err := c.applyDirectives(applyDirectiveParams{ 370 | config: &arg, 371 | directives: definition.Directives, 372 | node: definition, 373 | }); err != nil { 374 | return nil, err 375 | } 376 | 377 | return &arg, nil 378 | } 379 | 380 | // builds a field from an ast 381 | func (c *registry) buildFieldFromAST(definition *ast.FieldDefinition, kind, typeName string) (*graphql.Field, error) { 382 | fieldType, err := c.buildComplexType(definition.Type) 383 | if err != nil { 384 | return nil, err 385 | } 386 | 387 | field := graphql.Field{ 388 | Name: definition.Name.Value, 389 | Description: getDescription(definition), 390 | Type: fieldType, 391 | Args: graphql.FieldConfigArgument{}, 392 | Resolve: c.getFieldResolveFn(kind, typeName, definition.Name.Value), 393 | Subscribe: c.getFieldSubscribeFn(kind, typeName, definition.Name.Value), 394 | } 395 | 396 | for _, arg := range definition.Arguments { 397 | if arg != nil { 398 | argValue, err := c.buildArgFromAST(arg) 399 | if err != nil { 400 | return nil, err 401 | } 402 | field.Args[arg.Name.Value] = argValue 403 | } 404 | } 405 | 406 | if err := c.applyDirectives(applyDirectiveParams{ 407 | config: &field, 408 | directives: definition.Directives, 409 | node: definition, 410 | parentName: typeName, 411 | parentKind: kind, 412 | }); err != nil { 413 | return nil, err 414 | } 415 | 416 | return &field, nil 417 | } 418 | 419 | // builds a union from ast 420 | func (c *registry) buildUnionFromAST(definition *ast.UnionDefinition) error { 421 | name := definition.Name.Value 422 | unionConfig := graphql.UnionConfig{ 423 | Name: name, 424 | Types: []*graphql.Object{}, 425 | Description: getDescription(definition), 426 | } 427 | 428 
| // add types 429 | for _, unionType := range definition.Types { 430 | object, err := c.getType(unionType.Name.Value) 431 | if err != nil { 432 | return err 433 | } 434 | if object != nil { 435 | switch o := object.(type) { 436 | case *graphql.Object: 437 | unionConfig.Types = append(unionConfig.Types, o) 438 | continue 439 | } 440 | } 441 | return fmt.Errorf("build Union failed: no Object type %q found", unionType.Name.Value) 442 | } 443 | 444 | // set ResolveType from resolvers 445 | if r := c.getResolver(name); r != nil { 446 | if resolver, ok := r.(*UnionResolver); ok { 447 | unionConfig.ResolveType = resolver.ResolveType 448 | } 449 | } 450 | 451 | if err := c.applyDirectives(applyDirectiveParams{ 452 | config: &unionConfig, 453 | directives: definition.Directives, 454 | node: definition, 455 | }); err != nil { 456 | return err 457 | } 458 | 459 | c.types[name] = graphql.NewUnion(unionConfig) 460 | return nil 461 | } 462 | -------------------------------------------------------------------------------- /values.go: -------------------------------------------------------------------------------- 1 | package tools 2 | 3 | // taken from https://github.com/graphql-go/graphql/values.go 4 | // since none of these functions are exported 5 | 6 | import ( 7 | "fmt" 8 | "math" 9 | "reflect" 10 | 11 | "github.com/graphql-go/graphql" 12 | "github.com/graphql-go/graphql/language/ast" 13 | "github.com/graphql-go/graphql/language/kinds" 14 | ) 15 | 16 | // Prepares an object map of argument values given a list of argument 17 | // definitions and list of argument AST nodes. 
18 | func GetArgumentValues(argDefs []*graphql.Argument, argASTs []*ast.Argument, variableVariables map[string]interface{}) (map[string]interface{}, error) { 19 | 20 | argASTMap := map[string]*ast.Argument{} 21 | for _, argAST := range argASTs { 22 | if argAST.Name != nil { 23 | argASTMap[argAST.Name.Value] = argAST 24 | } 25 | } 26 | results := map[string]interface{}{} 27 | for _, argDef := range argDefs { 28 | 29 | name := argDef.PrivateName 30 | var valueAST ast.Value 31 | if argAST, ok := argASTMap[name]; ok { 32 | valueAST = argAST.Value 33 | } 34 | 35 | value := valueFromAST(valueAST, argDef.Type, variableVariables) 36 | 37 | if isNullish(value) { 38 | value = argDef.DefaultValue 39 | } 40 | 41 | // fix for checking that non nulls are not null 42 | typeString := argDef.Type.String() 43 | isNonNull := typeString[len(typeString)-1:] == "!" 44 | if isNonNull && isNullish(value) { 45 | locs := []string{} 46 | 47 | for _, a := range argASTs { 48 | locs = append(locs, fmt.Sprintf("%d:%d", a.Loc.Start, a.Loc.End)) 49 | } 50 | return nil, fmt.Errorf(`graphql input %q @ %q cannot be null`, name, locs) 51 | } 52 | 53 | if !isNullish(value) { 54 | results[name] = value 55 | } 56 | } 57 | return results, nil 58 | } 59 | 60 | // Returns true if a value is null, undefined, or NaN. 61 | func isNullish(src interface{}) bool { 62 | if src == nil { 63 | return true 64 | } 65 | value := reflect.ValueOf(src) 66 | if value.Kind() == reflect.Ptr { 67 | value = value.Elem() 68 | } 69 | switch value.Kind() { 70 | case reflect.String: 71 | // if src is ptr type and len(string)=0, it returns false 72 | if !value.IsValid() { 73 | return true 74 | } 75 | case reflect.Int: 76 | return math.IsNaN(float64(value.Int())) 77 | case reflect.Float32, reflect.Float64: 78 | return math.IsNaN(float64(value.Float())) 79 | } 80 | return false 81 | } 82 | 83 | /** 84 | * Produces a value given a GraphQL Value AST. 
85 | * 86 | * A GraphQL type must be provided, which will be used to interpret different 87 | * GraphQL Value literals. 88 | * 89 | * | GraphQL Value | JSON Value | 90 | * | -------------------- | ------------- | 91 | * | Input Object | Object | 92 | * | List | Array | 93 | * | Boolean | Boolean | 94 | * | String / Enum Value | String | 95 | * | Int / Float | Number | 96 | * 97 | */ 98 | func valueFromAST(valueAST ast.Value, ttype graphql.Input, variables map[string]interface{}) interface{} { 99 | 100 | if ttype, ok := ttype.(*graphql.NonNull); ok { 101 | val := valueFromAST(valueAST, ttype.OfType, variables) 102 | return val 103 | } 104 | 105 | if valueAST == nil { 106 | return nil 107 | } 108 | 109 | if valueAST, ok := valueAST.(*ast.Variable); ok && valueAST.Kind == kinds.Variable { 110 | if valueAST.Name == nil { 111 | return nil 112 | } 113 | if variables == nil { 114 | return nil 115 | } 116 | variableName := valueAST.Name.Value 117 | variableVal, ok := variables[variableName] 118 | if !ok { 119 | return nil 120 | } 121 | // Note: we're not doing any checking that this variable is correct. We're 122 | // assuming that this query has been validated and the variable usage here 123 | // is of the correct type. 
124 | return variableVal 125 | } 126 | 127 | if ttype, ok := ttype.(*graphql.List); ok { 128 | itemType := ttype.OfType 129 | if valueAST, ok := valueAST.(*ast.ListValue); ok && valueAST.Kind == kinds.ListValue { 130 | values := []interface{}{} 131 | for _, itemAST := range valueAST.Values { 132 | v := valueFromAST(itemAST, itemType, variables) 133 | values = append(values, v) 134 | } 135 | return values 136 | } 137 | v := valueFromAST(valueAST, itemType, variables) 138 | return []interface{}{v} 139 | } 140 | 141 | if ttype, ok := ttype.(*graphql.InputObject); ok { 142 | valueAST, ok := valueAST.(*ast.ObjectValue) 143 | if !ok { 144 | return nil 145 | } 146 | fieldASTs := map[string]*ast.ObjectField{} 147 | for _, fieldAST := range valueAST.Fields { 148 | if fieldAST.Name == nil { 149 | continue 150 | } 151 | fieldName := fieldAST.Name.Value 152 | fieldASTs[fieldName] = fieldAST 153 | 154 | } 155 | obj := map[string]interface{}{} 156 | for fieldName, field := range ttype.Fields() { 157 | fieldAST, ok := fieldASTs[fieldName] 158 | fieldValue := field.DefaultValue 159 | if !ok || fieldAST == nil { 160 | if fieldValue == nil { 161 | continue 162 | } 163 | } else { 164 | fieldValue = valueFromAST(fieldAST.Value, field.Type, variables) 165 | } 166 | if isNullish(fieldValue) { 167 | fieldValue = field.DefaultValue 168 | } 169 | if !isNullish(fieldValue) { 170 | obj[fieldName] = fieldValue 171 | } 172 | } 173 | return obj 174 | } 175 | 176 | switch ttype := ttype.(type) { 177 | case *graphql.Scalar: 178 | parsed := ttype.ParseLiteral(valueAST) 179 | if !isNullish(parsed) { 180 | return parsed 181 | } 182 | case *graphql.Enum: 183 | parsed := ttype.ParseLiteral(valueAST) 184 | if !isNullish(parsed) { 185 | return parsed 186 | } 187 | } 188 | return nil 189 | } 190 | --------------------------------------------------------------------------------