├── .github └── workflows │ ├── codeql.yml │ └── go.yml ├── README.md ├── README_zh.md ├── future.go ├── gen ├── api.go ├── api_test.go ├── builder.go ├── builder_test.go ├── generate.go ├── generate_test.go ├── integration │ ├── sqlx │ │ ├── executor.go │ │ ├── go.mod │ │ └── main.go │ └── sqlx_test.go ├── method.go ├── method_test.go ├── sqlx.go ├── sqlx_test.go ├── template │ ├── api.tmpl │ └── sqlx.tmpl ├── testdata │ ├── api │ │ └── test.go │ ├── cycle │ │ ├── a │ │ │ └── test.go │ │ └── b │ │ │ └── test.go │ └── sqlx │ │ ├── test.go │ │ └── test.sql ├── tools.go └── tools_test.go ├── go.mod ├── legacy.go ├── main.go ├── runtime ├── init.go ├── init_test.go ├── merge.go ├── merge_test.go ├── pool.go ├── pool_test.go ├── request.go ├── request_test.go ├── response.go ├── response_test.go ├── split.go ├── split_test.go ├── sqlx.go ├── sqlx_test.go ├── token │ └── token.go └── version.go ├── sqlx ├── reflectx │ └── reflectx.gen.go ├── sqlx.refined.go └── sqlx_test.go └── test.sh /.github/workflows/codeql.yml: -------------------------------------------------------------------------------- 1 | # For most projects, this workflow file will not need changing; you simply need 2 | # to commit it to your repository. 3 | # 4 | # You may wish to alter this file to override the set of languages analyzed, 5 | # or to provide custom queries or build logic. 6 | # 7 | # ******** NOTE ******** 8 | # We have attempted to detect the languages in your repository. Please check 9 | # the `language` matrix defined below to confirm you have the correct set of 10 | # supported CodeQL languages. 11 | # 12 | name: "CodeQL" 13 | 14 | on: 15 | push: 16 | branches: [ "main" ] 17 | pull_request: 18 | branches: [ "main" ] 19 | 20 | jobs: 21 | analyze: 22 | name: Analyze (${{ matrix.language }}) 23 | # Runner size impacts CodeQL analysis time. 
To learn more, please see: 24 | # - https://gh.io/recommended-hardware-resources-for-running-codeql 25 | # - https://gh.io/supported-runners-and-hardware-resources 26 | # - https://gh.io/using-larger-runners (GitHub.com only) 27 | # Consider using larger runners or machines with greater resources for possible analysis time improvements. 28 | runs-on: ${{ (matrix.language == 'swift' && 'macos-latest') || 'ubuntu-latest' }} 29 | timeout-minutes: ${{ (matrix.language == 'swift' && 120) || 360 }} 30 | permissions: 31 | # required for all workflows 32 | security-events: write 33 | 34 | # only required for workflows in private repositories 35 | actions: read 36 | contents: read 37 | 38 | strategy: 39 | fail-fast: false 40 | matrix: 41 | include: 42 | - language: go 43 | build-mode: autobuild 44 | # CodeQL supports the following values keywords for 'language': 'c-cpp', 'csharp', 'go', 'java-kotlin', 'javascript-typescript', 'python', 'ruby', 'swift' 45 | # Use `c-cpp` to analyze code written in C, C++ or both 46 | # Use 'java-kotlin' to analyze code written in Java, Kotlin or both 47 | # Use 'javascript-typescript' to analyze code written in JavaScript, TypeScript or both 48 | # To learn more about changing the languages that are analyzed or customizing the build mode for your analysis, 49 | # see https://docs.github.com/en/code-security/code-scanning/creating-an-advanced-setup-for-code-scanning/customizing-your-advanced-setup-for-code-scanning. 50 | # If you are analyzing a compiled language, you can modify the 'build-mode' for that language to customize how 51 | # your codebase is analyzed, see https://docs.github.com/en/code-security/code-scanning/creating-an-advanced-setup-for-code-scanning/codeql-code-scanning-for-compiled-languages 52 | steps: 53 | - name: Checkout repository 54 | uses: actions/checkout@v4 55 | 56 | # Initializes the CodeQL tools for scanning. 
57 | - name: Initialize CodeQL 58 | uses: github/codeql-action/init@v3 59 | with: 60 | languages: ${{ matrix.language }} 61 | build-mode: ${{ matrix.build-mode }} 62 | # If you wish to specify custom queries, you can do so here or in a config file. 63 | # By default, queries listed here will override any specified in a config file. 64 | # Prefix the list here with "+" to use these queries and those in the config file. 65 | 66 | # For more details on CodeQL's query packs, refer to: https://docs.github.com/en/code-security/code-scanning/automatically-scanning-your-code-for-vulnerabilities-and-errors/configuring-code-scanning#using-queries-in-ql-packs 67 | # queries: security-extended,security-and-quality 68 | 69 | # If the analyze step fails for one of the languages you are analyzing with 70 | # "We were unable to automatically build your code", modify the matrix above 71 | # to set the build mode to "manual" for that language. Then modify this step 72 | # to build your code. 73 | # ℹ️ Command-line programs to run using the OS shell. 
74 | # 📚 See https://docs.github.com/en/actions/using-workflows/workflow-syntax-for-github-actions#jobsjob_idstepsrun 75 | - if: matrix.build-mode == 'manual' 76 | run: | 77 | echo 'If you are using a "manual" build mode for one or more of the' \ 78 | 'languages you are analyzing, replace this with the commands to build' \ 79 | 'your code, for example:' 80 | echo ' make bootstrap' 81 | echo ' make release' 82 | exit 1 83 | 84 | - name: Perform CodeQL Analysis 85 | uses: github/codeql-action/analyze@v3 86 | with: 87 | category: "/language:${{matrix.language}}" 88 | -------------------------------------------------------------------------------- /.github/workflows/go.yml: -------------------------------------------------------------------------------- 1 | # This workflow will build a golang project 2 | # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-go 3 | 4 | name: Go 5 | 6 | on: 7 | push: 8 | branches: [ "main" ] 9 | pull_request: 10 | branches: [ "main" ] 11 | 12 | jobs: 13 | 14 | build: 15 | runs-on: ubuntu-latest 16 | steps: 17 | - uses: actions/checkout@v3 18 | 19 | - name: Set up Go 20 | uses: actions/setup-go@v4 21 | with: 22 | go-version: '1.19' 23 | 24 | - name: Install dependencies 25 | run: go mod tidy 26 | 27 | - name: Build 28 | run: go build -v . 
29 | 30 | - name: Test 31 | run: ./test.sh 32 | -------------------------------------------------------------------------------- /future.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "github.com/x5iu/defc/gen" 5 | ) 6 | 7 | func onInitialize() { 8 | features = append(features, gen.FeatureApiFuture, gen.FeatureSqlxFuture) 9 | } 10 | -------------------------------------------------------------------------------- /gen/api.go: -------------------------------------------------------------------------------- 1 | package gen 2 | 3 | import ( 4 | "bytes" 5 | "fmt" 6 | "go/ast" 7 | "go/parser" 8 | "go/token" 9 | "io" 10 | "net/http" 11 | "sort" 12 | "strings" 13 | "text/template" 14 | 15 | _ "embed" 16 | ) 17 | 18 | const ( 19 | apiMethodOptions = "OPTIONS" 20 | apiMethodResponseHandler = "RESPONSEHANDLER" 21 | 22 | // Deprecated: use apiMethodOptions instead 23 | apiMethodInner = "INNER" 24 | // Deprecated: use apiMethodResponseHandler instead 25 | apiMethodResponse = "RESPONSE" 26 | 27 | FeatureApiCache = "api/cache" 28 | FeatureApiLog = "api/log" 29 | FeatureApiLogx = "api/logx" 30 | FeatureApiClient = "api/client" 31 | FeatureApiPage = "api/page" 32 | FeatureApiError = "api/error" 33 | FeatureApiNoRt = "api/nort" 34 | FeatureApiFuture = "api/future" 35 | FeatureApiIgnoreStatus = "api/ignore-status" 36 | FeatureApiGzip = "api/gzip" 37 | FeatureApiRetry = "api/retry" 38 | ) 39 | 40 | func (builder *CliBuilder) buildApi(w io.Writer) error { 41 | inspectCtx, err := builder.inspectApi() 42 | if err != nil { 43 | return fmt.Errorf("inspectApi(%s, %d): %w", quote(join(builder.pwd, builder.file)), builder.pos, err) 44 | } 45 | return inspectCtx.Build(w) 46 | } 47 | 48 | type apiContext struct { 49 | Package string 50 | BuildTags []string 51 | Ident string 52 | Generics map[string]ast.Expr 53 | Methods []*Method 54 | Features []string 55 | Imports []string 56 | Funcs []string 57 | Doc Doc 58 | Schema 
string 59 | } 60 | 61 | func (ctx *apiContext) Build(w io.Writer) error { 62 | if !checkResponse(ctx.Methods) { 63 | return fmt.Errorf("checkResponse: no '%s() T' method found in Interface", apiMethodResponse) 64 | } 65 | 66 | for _, method := range ctx.Methods { 67 | if !isResponse(method.Ident) && !isInner(method.Ident) { 68 | if l := len(method.Out); l == 0 || !checkErrorType(method.Out[l-1]) { 69 | return fmt.Errorf("checkErrorType: no 'error' found in method %s returned values", 70 | quote(method.Ident)) 71 | } 72 | } 73 | 74 | if (isResponse(method.Ident) || isInner(method.Ident)) && 75 | (len(method.In) != 0 || len(method.Out) != 1) { 76 | return fmt.Errorf( 77 | "%s method can only have no income params "+ 78 | "and 1 returned value", quote(method.Ident)) 79 | } 80 | 81 | if isResponse(method.Ident) { 82 | methodOut0 := method.Out[0] 83 | if !checkResponseType(methodOut0) { 84 | return fmt.Errorf( 85 | "checkResponseType: returned type of %s "+ 86 | "should be kind of "+ 87 | "*ast.Ident/"+ 88 | "*ast.StarExpr/"+ 89 | "*ast.SelectorExpr/"+ 90 | "*ast.IndexExpr/"+ 91 | "*ast.IndexListExpr"+ 92 | ", got %T", 93 | quote(apiMethodResponse), 94 | methodOut0) 95 | } 96 | } 97 | 98 | // [2023-06-11] we limit 2 returned values on v1.0.0, now it is time to cancel this limitation 99 | /* 100 | if len(method.Out) > 2 { 101 | return fmt.Errorf("%s method expects 2 returned value at most, got %d", 102 | quote(method.Ident), 103 | len(method.Out)) 104 | } 105 | */ 106 | } 107 | 108 | // When using the api/future feature without enabling the api/error feature, it may cause connections 109 | // to not be closed properly, potentially leading to memory leak risks. To prevent this from happening, 110 | // when the api/future feature is enabled, the api/error feature must also be enforced. 
111 | if in(ctx.Features, FeatureApiFuture) && !in(ctx.Features, FeatureApiError) { 112 | ctx.Features = append(ctx.Features, FeatureApiError) 113 | } 114 | 115 | // We do not allow the use of api/ignore-status without enabling the api/future feature, as this would 116 | // cause callers to miss out on determining exceptional response codes. 117 | if in(ctx.Features, FeatureApiIgnoreStatus) && !in(ctx.Features, FeatureApiFuture) { 118 | return fmt.Errorf("api/ignore-status feature requires api/future feature to be enabled") 119 | } 120 | 121 | if in(ctx.Features, FeatureApiGzip) && in(ctx.Features, FeatureApiNoRt) { 122 | return fmt.Errorf("api/gzip feature requires api/nort feature to be disabled") 123 | } 124 | 125 | if err := ctx.genApiCode(w); err != nil { 126 | return fmt.Errorf("genApiCode: %w", err) 127 | } 128 | 129 | return nil 130 | } 131 | 132 | func (ctx *apiContext) SortGenerics() []string { 133 | idents := make([]string, 0, len(ctx.Generics)) 134 | for k := range ctx.Generics { 135 | idents = append(idents, k) 136 | } 137 | sort.Slice(idents, func(i, j int) bool { 138 | return ctx.Generics[idents[i]].Pos() < ctx.Generics[idents[j]].Pos() 139 | }) 140 | return idents 141 | } 142 | 143 | func (ctx *apiContext) GenericsRepr(withType bool) string { 144 | if len(ctx.Generics) == 0 { 145 | return "" 146 | } 147 | 148 | var dst bytes.Buffer 149 | dst.WriteByte('[') 150 | for index, name := range ctx.SortGenerics() { 151 | expr := ctx.Generics[name] 152 | dst.WriteString(name) 153 | if withType { 154 | dst.WriteByte(' ') 155 | dst.WriteString(ctx.Doc.Repr(expr)) 156 | } 157 | if index < len(ctx.Generics)-1 { 158 | dst.WriteString(", ") 159 | } 160 | } 161 | dst.WriteByte(']') 162 | 163 | return dst.String() 164 | } 165 | 166 | func (ctx *apiContext) HasFeature(feature string) bool { 167 | for _, current := range ctx.Features { 168 | if current == feature { 169 | return true 170 | } 171 | } 172 | return false 173 | } 174 | 175 | func (ctx *apiContext) 
HasHeader() bool { 176 | for _, method := range ctx.Methods { 177 | if method.Header != "" { 178 | return true 179 | } 180 | } 181 | return false 182 | } 183 | 184 | func (ctx *apiContext) HasBody() bool { 185 | for _, method := range ctx.Methods { 186 | if httpMethodHasBody(method.MethodHTTP()) && headerHasBody(method.TmplHeader()) { 187 | return true 188 | } 189 | } 190 | return false 191 | } 192 | 193 | func (ctx *apiContext) HasInner() bool { 194 | return hasInner(ctx.Methods) 195 | } 196 | 197 | func (ctx *apiContext) InnerType() ast.Node { 198 | for _, method := range ctx.Methods { 199 | if isInner(method.Ident) { 200 | return method.Out[0] 201 | } 202 | } 203 | return nil 204 | } 205 | 206 | func (ctx *apiContext) MethodResponse() string { 207 | for _, method := range ctx.Methods { 208 | if isResponse(method.Ident) { 209 | return method.Ident 210 | } 211 | } 212 | return apiMethodResponse 213 | } 214 | 215 | func (ctx *apiContext) MethodInner() string { 216 | for _, method := range ctx.Methods { 217 | if isInner(method.Ident) { 218 | return method.Ident 219 | } 220 | } 221 | return apiMethodInner 222 | } 223 | 224 | func (ctx *apiContext) MergedImports() (imports []string) { 225 | imports = []string{ 226 | quote("fmt"), 227 | quote("io"), 228 | quote("net/http"), 229 | quote("text/template"), 230 | } 231 | 232 | if ctx.HasFeature(FeatureApiLog) || ctx.HasFeature(FeatureApiLogx) { 233 | imports = append(imports, quote("time")) 234 | imports = append(imports, quote("context")) 235 | } 236 | 237 | if ctx.HasFeature(FeatureApiNoRt) { 238 | imports = append(imports, 239 | quote("bytes"), 240 | quote("sync"), 241 | quote("reflect")) 242 | } else { 243 | imports = append(imports, parseImport("__rt github.com/x5iu/defc/runtime")) 244 | } 245 | 246 | if ctx.HasHeader() { 247 | imports = append(imports, quote("bufio")) 248 | imports = append(imports, quote("net/textproto")) 249 | if ctx.HasBody() && ctx.HasFeature(FeatureApiLogx) { 250 | imports = append(imports, 
quote("bytes")) 251 | } 252 | } 253 | 254 | if importContext(ctx.Methods) { 255 | imports = append(imports, quote("context")) 256 | } 257 | 258 | for _, imp := range ctx.Imports { 259 | if !in(imports, imp) { 260 | imports = append(imports, parseImport(imp)) 261 | } 262 | } 263 | 264 | return imports 265 | } 266 | 267 | func (ctx *apiContext) AdditionalFuncs() (funcMap map[string]string) { 268 | funcMap = make(map[string]string, len(ctx.Funcs)) 269 | for _, fn := range ctx.Funcs { 270 | if key, value, ok := cutkv(fn); ok { 271 | funcMap[key] = value 272 | } 273 | } 274 | return funcMap 275 | } 276 | 277 | func (builder *CliBuilder) inspectApi() (*apiContext, error) { 278 | fset := token.NewFileSet() 279 | 280 | f, err := parser.ParseFile(fset, builder.file, builder.doc.Bytes(), parser.ParseComments) 281 | if err != nil { 282 | return nil, err 283 | } 284 | 285 | var ( 286 | genDecl *ast.GenDecl 287 | typeSpec *ast.TypeSpec 288 | ifaceType *ast.InterfaceType 289 | ) 290 | 291 | line := builder.pos + 1 292 | inspectDecl: 293 | for _, declIface := range f.Decls { 294 | if surroundLine(fset, declIface, line) { 295 | if decl, ok := declIface.(*ast.GenDecl); ok && decl.Tok == token.TYPE { 296 | genDecl = decl 297 | break inspectDecl 298 | } 299 | } 300 | } 301 | 302 | if genDecl == nil { 303 | return nil, fmt.Errorf( 304 | "no available 'Interface' type declaration (*ast.GenDecl) found, "+ 305 | "available *ast.GenDecl are: \n\n"+ 306 | "%s\n\n", concat(nodeMap(f.Decls, fmtNode), "\n")) 307 | } 308 | 309 | inspectType: 310 | for _, specIface := range genDecl.Specs { 311 | if afterLine(fset, specIface, line) { 312 | if spec, ok := specIface.(*ast.TypeSpec); ok { 313 | if iface, ok := spec.Type.(*ast.InterfaceType); ok && afterLine(fset, iface, line) { 314 | typeSpec = spec 315 | ifaceType = iface 316 | break inspectType 317 | } 318 | } 319 | } 320 | } 321 | 322 | if ifaceType == nil { 323 | return nil, fmt.Errorf( 324 | "no available 'Interface' type declaration 
(*ast.InterfaceType) found, "+ 325 | "available *ast.GenDecl are: \n\n"+ 326 | "%s\n\n", concat(nodeMap(f.Decls, fmtNode), "\n")) 327 | } 328 | 329 | if !builder.disableAutoImport { 330 | imports, err := getImports(builder.pkg, builder.pwd, builder.file, func(node ast.Node) bool { 331 | switch x := node.(type) { 332 | case *ast.TypeSpec: 333 | return x.Name.Name == typeSpec.Name.Name 334 | } 335 | return false 336 | }) 337 | 338 | if err != nil { 339 | return nil, err 340 | } 341 | 342 | for _, spec := range f.Imports { 343 | path := spec.Path.Value[1 : len(spec.Path.Value)-1] 344 | for _, imported := range imports { 345 | if path == imported.Path { 346 | var name string 347 | if spec.Name != nil { 348 | name = spec.Name.Name 349 | } 350 | if importRepr := strings.TrimSpace(name + " " + path); !in(builder.imports, importRepr) { 351 | builder.imports = append(builder.imports, importRepr) 352 | } 353 | } 354 | } 355 | } 356 | } 357 | 358 | for _, method := range ifaceType.Methods.List { 359 | if funcType, ok := method.Type.(*ast.FuncType); ok { 360 | if !checkInput(funcType) { 361 | return nil, fmt.Errorf(""+ 362 | "input params for method %s should "+ 363 | "contain 'Name' and 'Type' both", 364 | quote(method.Names[0].Name)) 365 | } 366 | } 367 | } 368 | 369 | apiFeatures := make([]string, 0, len(builder.feats)) 370 | for _, feature := range builder.feats { 371 | if hasPrefix(feature, "api") { 372 | apiFeatures = append(apiFeatures, feature) 373 | } 374 | } 375 | 376 | generics := make(map[string]ast.Expr, 16) 377 | if typeSpec.TypeParams != nil { 378 | for _, param := range typeSpec.TypeParams.List { 379 | for _, name := range param.Names { 380 | generics[name.Name] = param.Type 381 | } 382 | } 383 | } 384 | 385 | return &apiContext{ 386 | Package: builder.pkg, 387 | BuildTags: parseBuildTags(builder.doc), 388 | Ident: typeSpec.Name.Name, 389 | Generics: generics, 390 | Methods: typeMap(ifaceType.Methods.List, builder.doc.InspectMethod), 391 | Features: 
apiFeatures, 392 | Imports: builder.imports, 393 | Funcs: builder.funcs, 394 | Doc: builder.doc, 395 | }, nil 396 | } 397 | 398 | func checkResponse(methods []*Method) bool { 399 | for _, method := range methods { 400 | if isResponse(method.Ident) { 401 | return true 402 | } 403 | } 404 | return false 405 | } 406 | 407 | func checkResponseType(node ast.Node) bool { 408 | node = getNode(node) 409 | switch node.(type) { 410 | case *ast.Ident, *ast.StarExpr, *ast.SelectorExpr, *ast.IndexExpr, *ast.IndexListExpr: 411 | return true 412 | default: 413 | return false 414 | } 415 | } 416 | 417 | func hasInner(methods []*Method) bool { 418 | for _, method := range methods { 419 | if isInner(method.Ident) { 420 | return true 421 | } 422 | } 423 | return false 424 | } 425 | 426 | func importContext(methods []*Method) bool { 427 | for _, method := range methods { 428 | if method.HasContext() { 429 | return true 430 | } 431 | } 432 | return false 433 | } 434 | 435 | func isResponse(ident string) bool { 436 | ident = toUpper(ident) 437 | return ident == apiMethodResponse || ident == apiMethodResponseHandler 438 | } 439 | 440 | func isInner(ident string) bool { 441 | ident = toUpper(ident) 442 | return ident == apiMethodInner || ident == apiMethodOptions 443 | } 444 | 445 | func httpMethodHasBody(method string) bool { 446 | switch method { 447 | case http.MethodGet: 448 | return false 449 | case http.MethodPost, http.MethodPut, http.MethodPatch: 450 | return true 451 | default: 452 | return false 453 | } 454 | } 455 | 456 | func headerHasBody(header string) bool { 457 | if idx := index(header, "\r\n\r\n"); idx != -1 { 458 | return len(trimSpace(header[idx+4:])) > 0 459 | } 460 | if idx := index(header, "\n\n"); idx != -1 { 461 | return len(trimSpace(header[idx+2:])) > 0 462 | } 463 | return false 464 | } 465 | 466 | //go:embed template/api.tmpl 467 | var apiTemplate string 468 | 469 | func (ctx *apiContext) genApiCode(w io.Writer) error { 470 | tmpl, err := template. 
471 | New("defc(api)"). 472 | Funcs(template.FuncMap{ 473 | "quote": quote, 474 | "isPointer": isPointer, 475 | "indirect": indirect, 476 | "importContext": importContext, 477 | "sub": func(x, y int) int { return x - y }, 478 | "getRepr": func(node ast.Node) string { return ctx.Doc.Repr(node) }, 479 | "isEllipsis": func(node ast.Node) bool { return hasPrefix(ctx.Doc.Repr(node), "...") }, 480 | "methodResp": ctx.MethodResponse, 481 | "methodInner": ctx.MethodInner, 482 | "isResponse": isResponse, 483 | "isInner": isInner, 484 | "httpMethodHasBody": httpMethodHasBody, 485 | "headerHasBody": headerHasBody, 486 | }). 487 | Parse(apiTemplate) 488 | 489 | if err != nil { 490 | return err 491 | } 492 | 493 | if ctx.Schema != "" { 494 | if tmpl, err = tmpl.Parse(sprintf(`{{ define "schema" }} %s {{ end }}`, ctx.Schema)); err != nil { 495 | return err 496 | } 497 | } 498 | 499 | return tmpl.Execute(w, ctx) 500 | } 501 | -------------------------------------------------------------------------------- /gen/api_test.go: -------------------------------------------------------------------------------- 1 | package gen 2 | 3 | import ( 4 | "bufio" 5 | "bytes" 6 | "os" 7 | "path/filepath" 8 | "strings" 9 | "testing" 10 | ) 11 | 12 | func TestBuildApi(t *testing.T) { 13 | const ( 14 | testPk = "test" 15 | testGo = testPk + ".go" 16 | ) 17 | var ( 18 | testDir = filepath.Join("testdata", "api") 19 | testFile = testGo 20 | genFile = testPk + "." 
+ strings.ReplaceAll(t.Name(), "/", "_") + ".go" 21 | ) 22 | pwd, err := os.Getwd() 23 | if err != nil { 24 | t.Errorf("getwd: %s", err) 25 | return 26 | } 27 | defer func() { 28 | if err = os.Chdir(pwd); err != nil { 29 | t.Errorf("chdir: %s", err) 30 | return 31 | } 32 | }() 33 | if err = os.Chdir(testDir); err != nil { 34 | t.Errorf("chdir: %s", err) 35 | return 36 | } 37 | newBuilder := func(t *testing.T) (*CliBuilder, bool) { 38 | doc, err := os.ReadFile(testFile) 39 | if err != nil { 40 | t.Errorf("build: error reading %s file => %s", testGo, err) 41 | return nil, false 42 | } 43 | var pos int 44 | lineScanner := bufio.NewScanner(bytes.NewReader(doc)) 45 | for i := 1; lineScanner.Scan(); i++ { 46 | text := lineScanner.Text() 47 | if strings.HasPrefix(text, "//go:generate") && 48 | strings.HasSuffix(text, t.Name()) { 49 | pos = i 50 | break 51 | } 52 | } 53 | if err = lineScanner.Err(); err != nil { 54 | t.Errorf("build: error scanning %s lines => %s", testGo, err) 55 | return nil, false 56 | } 57 | if pos == 0 { 58 | t.Errorf("build: unable to get pos in %s", testGo) 59 | return nil, false 60 | } 61 | testDirAbs, err := os.Getwd() 62 | if err != nil { 63 | t.Errorf("getwd: %s", err) 64 | return nil, false 65 | } 66 | return NewCliBuilder(ModeApi). 67 | WithFeats([]string{FeatureApiNoRt, FeatureApiFuture}). 68 | WithPkg(testPk). 69 | WithPwd(testDirAbs). 70 | WithFile(testGo, doc). 71 | WithPos(pos), true 72 | } 73 | t.Run("success", func(t *testing.T) { 74 | builder, ok := newBuilder(t) 75 | if !ok { 76 | return 77 | } 78 | if err := runTest(genFile, builder); err != nil { 79 | t.Errorf("build: %s", err) 80 | return 81 | } 82 | builder = builder.WithFeats([]string{FeatureApiLogx}). 83 | WithImports([]string{"url net/url"}, false). 
84 | WithFuncs([]string{"escape=url.QueryEscape"}) 85 | if err := runTest(genFile, builder); err != nil { 86 | t.Errorf("build: %s", err) 87 | return 88 | } 89 | t.Run("no_generics", func(t *testing.T) { 90 | builder, ok := newBuilder(t) 91 | if !ok { 92 | return 93 | } 94 | if err := runTest(genFile, builder); err != nil { 95 | t.Errorf("build: %s", err) 96 | return 97 | } 98 | }) 99 | }) 100 | t.Run("fail_no_response", func(t *testing.T) { 101 | builder, ok := newBuilder(t) 102 | if !ok { 103 | return 104 | } 105 | if err := runTest(genFile, builder); err == nil { 106 | t.Errorf("build: expects errors, got nil") 107 | return 108 | } else if !strings.Contains(err.Error(), "checkResponse: ") { 109 | t.Errorf("build: expects checkResponse error, got => %s", err) 110 | return 111 | } 112 | }) 113 | t.Run("fail_no_error", func(t *testing.T) { 114 | builder, ok := newBuilder(t) 115 | if !ok { 116 | return 117 | } 118 | if err := runTest(genFile, builder); err == nil { 119 | t.Errorf("build: expects errors, got nil") 120 | return 121 | } else if !strings.Contains(err.Error(), "checkErrorType: ") { 122 | t.Errorf("build: expects checkErrorType error, got => %s", err) 123 | return 124 | } 125 | }) 126 | t.Run("fail_no_name_type", func(t *testing.T) { 127 | builder, ok := newBuilder(t) 128 | if !ok { 129 | return 130 | } 131 | if err := runTest(genFile, builder); err == nil { 132 | t.Errorf("build: expects errors, got nil") 133 | return 134 | } else if !strings.Contains(err.Error(), 135 | "should contain 'Name' and 'Type' both") { 136 | t.Errorf("build: expects NoNameType error, got => %s", err) 137 | return 138 | } 139 | }) 140 | t.Run("fail_invalid_IR", func(t *testing.T) { 141 | t.Run("I", func(t *testing.T) { 142 | builder, ok := newBuilder(t) 143 | if !ok { 144 | return 145 | } 146 | if err := runTest(genFile, builder); err == nil { 147 | t.Errorf("build: expects errors, got nil") 148 | return 149 | } else if !strings.Contains(err.Error(), 150 | "method can only have 
no income params and 1 returned value") { 151 | t.Errorf("build: expects InvalidI error, got => %s", err) 152 | return 153 | } 154 | }) 155 | t.Run("R", func(t *testing.T) { 156 | builder, ok := newBuilder(t) 157 | if !ok { 158 | return 159 | } 160 | if err := runTest(genFile, builder); err == nil { 161 | t.Errorf("build: expects errors, got nil") 162 | return 163 | } else if !strings.Contains(err.Error(), 164 | "method can only have no income params and 1 returned value") { 165 | t.Errorf("build: expects InvalidR error, got => %s", err) 166 | return 167 | } 168 | t.Run("type", func(t *testing.T) { 169 | builder, ok := newBuilder(t) 170 | if !ok { 171 | return 172 | } 173 | if err := runTest(genFile, builder); err == nil { 174 | t.Errorf("build: expects errors, got nil") 175 | return 176 | } else if !strings.Contains(err.Error(), "checkResponseType: ") { 177 | t.Errorf("build: expects checkResponseType error, got => %s", err) 178 | return 179 | } 180 | }) 181 | }) 182 | }) 183 | t.Run("fail_no_type_decl", func(t *testing.T) { 184 | builder, ok := newBuilder(t) 185 | if !ok { 186 | return 187 | } 188 | if err := runTest(genFile, builder); err == nil { 189 | t.Errorf("build: expects errors, got nil") 190 | return 191 | } else if !strings.Contains(err.Error(), 192 | "no available 'Interface' type declaration (*ast.GenDecl) found, ") { 193 | t.Errorf("build: expects NoTypeDecl error, got => %s", err) 194 | return 195 | } 196 | }) 197 | t.Run("fail_no_iface_type", func(t *testing.T) { 198 | builder, ok := newBuilder(t) 199 | if !ok { 200 | return 201 | } 202 | if err := runTest(genFile, builder); err == nil { 203 | t.Errorf("build: expects errors, got nil") 204 | return 205 | } else if !strings.Contains(err.Error(), 206 | "no available 'Interface' type declaration (*ast.InterfaceType) found, ") { 207 | t.Errorf("build: expects NoIfaceType error, got => %s", err) 208 | return 209 | } 210 | }) 211 | } 212 | 
-------------------------------------------------------------------------------- /gen/builder.go: -------------------------------------------------------------------------------- 1 | package gen 2 | 3 | import ( 4 | "go/ast" 5 | "io" 6 | ) 7 | 8 | type Mode int 9 | 10 | const ( 11 | ModeStart Mode = iota 12 | ModeApi 13 | ModeSqlx 14 | ModeEnd 15 | ) 16 | 17 | func (mode Mode) String() string { 18 | switch mode { 19 | case ModeApi: 20 | return "api" 21 | case ModeSqlx: 22 | return "sqlx" 23 | default: 24 | return sprintf("Mode(%d)", mode) 25 | } 26 | } 27 | 28 | func (mode Mode) IsValid() bool { 29 | return ModeStart < mode && mode < ModeEnd 30 | } 31 | 32 | type Doc []byte 33 | 34 | func (doc Doc) Bytes() []byte { 35 | return doc 36 | } 37 | 38 | func (doc Doc) Repr(node ast.Node) string { 39 | return getRepr(node, doc) 40 | } 41 | 42 | func (doc Doc) InspectMethod(node *ast.Field) *Method { 43 | return inspectMethod(node, doc) 44 | } 45 | 46 | func (doc Doc) IsContextType(ident string, expr ast.Expr) bool { 47 | return isContextType(ident, expr, doc) 48 | } 49 | 50 | func NewCliBuilder(mode Mode) *CliBuilder { 51 | assert(mode.IsValid(), "invalid mode") 52 | return &CliBuilder{ 53 | mode: mode, 54 | } 55 | } 56 | 57 | type CliBuilder struct { 58 | // mode 59 | mode Mode 60 | 61 | // feats 62 | feats []string 63 | 64 | // imports 65 | imports []string 66 | 67 | // disableAutoImport 68 | disableAutoImport bool 69 | 70 | // funcs 71 | funcs []string 72 | 73 | // pkg package name 74 | pkg string 75 | 76 | // pwd current working directory 77 | pwd string 78 | 79 | // file current file 80 | file string 81 | 82 | // doc total content of current file 83 | doc Doc 84 | 85 | // pos position of `go generate` command 86 | pos int 87 | 88 | // template 89 | template string 90 | } 91 | 92 | func (builder *CliBuilder) WithFeats(feats []string) *CliBuilder { 93 | builder.feats = feats 94 | return builder 95 | } 96 | 97 | func (builder *CliBuilder) WithImports(imports []string, 
disableAutoImport bool) *CliBuilder { 98 | builder.imports = imports 99 | builder.disableAutoImport = disableAutoImport 100 | return builder 101 | } 102 | 103 | func (builder *CliBuilder) WithFuncs(funcs []string) *CliBuilder { 104 | builder.funcs = funcs 105 | return builder 106 | } 107 | 108 | func (builder *CliBuilder) WithPkg(pkg string) *CliBuilder { 109 | builder.pkg = pkg 110 | return builder 111 | } 112 | 113 | func (builder *CliBuilder) WithPwd(pwd string) *CliBuilder { 114 | builder.pwd = pwd 115 | return builder 116 | } 117 | 118 | func (builder *CliBuilder) WithFile(file string, doc []byte) *CliBuilder { 119 | builder.file = file 120 | builder.doc = doc 121 | return builder 122 | } 123 | 124 | func (builder *CliBuilder) WithPos(pos int) *CliBuilder { 125 | builder.pos = pos 126 | return builder 127 | } 128 | 129 | func (builder *CliBuilder) WithTemplate(template string) *CliBuilder { 130 | builder.template = template 131 | return builder 132 | } 133 | 134 | func (builder *CliBuilder) Build(w io.Writer) error { 135 | switch builder.mode { 136 | case ModeApi: 137 | return builder.buildApi(w) 138 | case ModeSqlx: 139 | return builder.buildSqlx(w) 140 | default: 141 | } 142 | return nil 143 | } 144 | -------------------------------------------------------------------------------- /gen/builder_test.go: -------------------------------------------------------------------------------- 1 | package gen 2 | 3 | import ( 4 | "bytes" 5 | "os" 6 | "testing" 7 | 8 | goformat "go/format" 9 | goimport "golang.org/x/tools/imports" 10 | ) 11 | 12 | func runTest(path string, builder Builder) (err error) { 13 | var bf bytes.Buffer 14 | if err = builder.Build(&bf); err != nil { 15 | return err 16 | } 17 | code := bf.Bytes() 18 | code, err = goformat.Source(code) 19 | if err != nil { 20 | return err 21 | } 22 | if err = os.WriteFile(path, code, 0644); err != nil { 23 | return err 24 | } 25 | code, err = goimport.Process(path, code, nil) 26 | if err != nil { 27 | return err 28 
| } 29 | if err = os.WriteFile(path, code, 0644); err != nil { 30 | return err 31 | } 32 | if err = os.Remove(path); err != nil { 33 | return err 34 | } 35 | return nil 36 | } 37 | 38 | func TestMode(t *testing.T) { 39 | t.Run("string", func(t *testing.T) { 40 | type TestCase struct { 41 | Mode Mode 42 | String string 43 | IsValid bool 44 | } 45 | var testcases = []*TestCase{ 46 | {Mode: 0, String: "Mode(0)", IsValid: false}, 47 | {Mode: 1, String: "api", IsValid: true}, 48 | {Mode: 2, String: "sqlx", IsValid: true}, 49 | {Mode: 3, String: "Mode(3)", IsValid: false}, 50 | {Mode: 999, String: "Mode(999)", IsValid: false}, 51 | } 52 | for _, testcase := range testcases { 53 | if testcase.Mode.String() != testcase.String { 54 | t.Errorf("mode: %q != %q", testcase.Mode.String(), testcase.String) 55 | return 56 | } 57 | if testcase.Mode.IsValid() != testcase.IsValid { 58 | t.Errorf("mode: %v != %v", testcase.Mode.IsValid(), testcase.IsValid) 59 | return 60 | } 61 | } 62 | }) 63 | } 64 | -------------------------------------------------------------------------------- /gen/generate.go: -------------------------------------------------------------------------------- 1 | package gen 2 | 3 | import ( 4 | "bytes" 5 | "fmt" 6 | "go/ast" 7 | "go/token" 8 | "io" 9 | ) 10 | 11 | type ( 12 | Builder interface { 13 | Build(w io.Writer) error 14 | } 15 | 16 | Config struct { 17 | Package string `json:"package" toml:"package" yaml:"package"` 18 | Tags []string `json:"tags" toml:"tags" yaml:"tags"` 19 | Ident string `json:"ident" toml:"ident" yaml:"ident"` 20 | Features []string `json:"features" toml:"features" yaml:"features"` 21 | Imports []string `json:"imports" toml:"imports" yaml:"imports"` 22 | Funcs []string `json:"funcs" toml:"funcs" yaml:"funcs"` 23 | Schemas []*Schema `json:"schemas" toml:"schemas" yaml:"schemas"` 24 | Include string `json:"include" toml:"include" yaml:"include"` 25 | Declare []*Declare `json:"declare" toml:"declare" yaml:"declare"` 26 | } 27 | 28 | Schema 
struct { 29 | Meta string `json:"meta" toml:"meta" yaml:"meta"` 30 | Header string `json:"header" toml:"header" yaml:"header"` 31 | In []*Param `json:"in" toml:"in" yaml:"in"` 32 | Out []*Param `json:"out" toml:"out" yaml:"out"` 33 | } 34 | 35 | Param struct { 36 | Ident string `json:"ident" toml:"ident" yaml:"ident"` 37 | Type string `json:"type" toml:"type" yaml:"type"` 38 | } 39 | 40 | Declare struct { 41 | Ident string `json:"ident" toml:"ident" yaml:"ident"` 42 | Fields []*Field `json:"fields" toml:"fields" yaml:"fields"` 43 | } 44 | 45 | Field struct { 46 | Ident string `json:"ident" toml:"ident" yaml:"ident"` 47 | Type string `json:"type" toml:"type" yaml:"type"` 48 | Tag string `json:"tag" toml:"tag" yaml:"tag"` 49 | } 50 | ) 51 | 52 | func Generate(w io.Writer, mode Mode, cfg *Config) error { 53 | builder, err := toBuilder(mode, cfg) 54 | if err != nil { 55 | return err 56 | } 57 | return builder.Build(w) 58 | } 59 | 60 | func toBuilder(mode Mode, cfg *Config) (Builder, error) { 61 | methods := make([]*Method, len(cfg.Schemas)) 62 | doc := make(Doc, 0, len(cfg.Schemas)*(3+2)*7) 63 | 64 | for i := 0; i < len(cfg.Schemas); i++ { 65 | schema := cfg.Schemas[i] 66 | method := &Method{ 67 | Meta: schema.Meta, 68 | Header: schema.Header, 69 | Ident: getIdent(schema.Meta), 70 | OrderedIn: make([]string, len(schema.In)), 71 | In: make(map[string]ast.Expr, len(schema.In)), 72 | UnnamedIn: make([]ast.Expr, 0, 3), 73 | Out: make([]ast.Expr, len(schema.Out)), 74 | } 75 | for j := 0; j < len(schema.In); j++ { 76 | in := schema.In[j] 77 | expr, err := parseExpr(in.Type) 78 | if err != nil { 79 | return nil, fmt.Errorf("invalid expr: %w", err) 80 | } 81 | wrapped := &Expr{ 82 | Expr: expr, 83 | Offset: len(doc), 84 | Repr: in.Type, 85 | } 86 | if in.Ident == "" { 87 | method.UnnamedIn = append(method.UnnamedIn, wrapped) 88 | } else { 89 | method.OrderedIn[j] = in.Ident 90 | method.In[in.Ident] = wrapped 91 | } 92 | doc = append(doc, in.Type...) 
93 | } 94 | for k := 0; k < len(schema.Out); k++ { 95 | out := schema.Out[k] 96 | expr, err := parseExpr(out.Type) 97 | if err != nil { 98 | return nil, fmt.Errorf("invalid expr: %w", err) 99 | } 100 | method.Out[k] = &Expr{ 101 | Expr: expr, 102 | Offset: len(doc), 103 | Repr: out.Type, 104 | } 105 | doc = append(doc, out.Type...) 106 | } 107 | methods[i] = method 108 | } 109 | 110 | // lazy update 111 | defer func() { 112 | for _, method := range methods { 113 | method.Source = doc 114 | } 115 | }() 116 | 117 | switch mode { 118 | case ModeApi: 119 | const ( 120 | ResponseIdent = "response" 121 | ResponseType = "T" 122 | ) 123 | 124 | var ( 125 | ResponseExpr = "__rt.Response" 126 | ) 127 | 128 | if in(cfg.Features, FeatureApiNoRt) { 129 | ResponseExpr = sprintf("%sResponseInterface", cfg.Ident) 130 | } 131 | 132 | hasResponse := func(schemas []*Schema) bool { 133 | for _, schema := range schemas { 134 | if isResponse(getIdent(schema.Meta)) { 135 | return true 136 | } 137 | } 138 | return false 139 | } 140 | 141 | // hack generic decl and schema def 142 | hackCfg := *cfg 143 | generics := make(map[string]ast.Expr) 144 | if !hasResponse(hackCfg.Schemas) { 145 | hackCfg.Ident = sprintf("%s[%s %s]", cfg.Ident, ResponseType, ResponseExpr) 146 | hackCfg.Schemas = append([]*Schema{ 147 | { 148 | Meta: ResponseIdent, 149 | Out: []*Param{ 150 | { 151 | Type: ResponseType, 152 | }, 153 | }, 154 | }, 155 | }, hackCfg.Schemas...) 156 | 157 | // generic interface for `Response() T` 158 | expr, _ := parseExpr(ResponseExpr) 159 | generics = map[string]ast.Expr{ 160 | ResponseType: &Expr{ 161 | Expr: expr, 162 | Offset: len(doc), 163 | Repr: ResponseExpr, 164 | }, 165 | } 166 | doc = append(doc, ResponseExpr...) 
167 | 168 | // response type 169 | expr, _ = parseExpr(ResponseType) 170 | methods = append(methods, &Method{ 171 | Ident: ResponseIdent, 172 | Out: []ast.Expr{ 173 | &Expr{ 174 | Expr: expr, 175 | Offset: len(doc), 176 | Repr: ResponseType, 177 | }, 178 | }, 179 | }) 180 | doc = append(doc, ResponseType...) 181 | } 182 | 183 | return &apiContext{ 184 | Package: cfg.Package, 185 | BuildTags: cfg.Tags, 186 | Ident: cfg.Ident, 187 | Generics: generics, 188 | Methods: methods, 189 | Features: cfg.Features, 190 | Imports: cfg.Imports, 191 | Funcs: cfg.Funcs, 192 | Doc: doc, 193 | Schema: format(&hackCfg), 194 | }, nil 195 | case ModeSqlx: 196 | return &sqlxContext{ 197 | Package: cfg.Package, 198 | BuildTags: cfg.Tags, 199 | Ident: cfg.Ident, 200 | Methods: methods, 201 | Features: cfg.Features, 202 | Imports: cfg.Imports, 203 | Funcs: cfg.Funcs, 204 | Doc: doc, 205 | Schema: format(cfg), 206 | }, nil 207 | } 208 | 209 | return nil, fmt.Errorf("unimplemented mode %q", mode.String()) 210 | } 211 | 212 | func format(cfg *Config) string { 213 | var buf bytes.Buffer 214 | buf.WriteString("type " + cfg.Ident + " interface {") 215 | buf.WriteByte('\n') 216 | for _, schema := range cfg.Schemas { 217 | buf.WriteString(getIdent(schema.Meta)) 218 | buf.WriteByte('(') 219 | for _, in := range schema.In { 220 | buf.WriteString(in.Ident + " " + in.Type + ", ") 221 | } 222 | buf.WriteByte(')') 223 | buf.WriteByte('(') 224 | for _, out := range schema.Out { 225 | buf.WriteString(out.Ident + " " + out.Type + ", ") 226 | } 227 | buf.WriteByte(')') 228 | buf.WriteByte('\n') 229 | } 230 | buf.WriteByte('}') 231 | buf.WriteByte('\n') 232 | buf.WriteByte('\n') 233 | buf.WriteString(cfg.Include) 234 | buf.WriteByte('\n') 235 | buf.WriteByte('\n') 236 | for _, declare := range cfg.Declare { 237 | buf.WriteString("type " + declare.Ident + " struct {") 238 | buf.WriteByte('\n') 239 | for _, field := range declare.Fields { 240 | buf.WriteString(field.Ident) 241 | buf.WriteByte(' ') 242 | 
buf.WriteString(field.Type) 243 | buf.WriteByte(' ') 244 | buf.WriteString("`" + field.Tag + "`") 245 | buf.WriteByte('\n') 246 | } 247 | buf.WriteByte('}') 248 | buf.WriteByte('\n') 249 | buf.WriteByte('\n') 250 | } 251 | return buf.String() 252 | } 253 | 254 | type Expr struct { 255 | ast.Expr 256 | Offset int 257 | Repr string 258 | } 259 | 260 | func (expr *Expr) Pos() token.Pos { 261 | return token.Pos(expr.Offset + 1) 262 | } 263 | 264 | func (expr *Expr) End() token.Pos { 265 | return expr.Pos() + (expr.Expr.End() - expr.Expr.Pos()) 266 | } 267 | 268 | func (expr *Expr) String() string { 269 | return expr.Repr 270 | } 271 | 272 | func (expr *Expr) Unwrap() ast.Node { 273 | return expr.Expr 274 | } 275 | -------------------------------------------------------------------------------- /gen/generate_test.go: -------------------------------------------------------------------------------- 1 | package gen 2 | 3 | import ( 4 | "io" 5 | "testing" 6 | ) 7 | 8 | func TestGenerate(t *testing.T) { 9 | declareType := &Declare{ 10 | Ident: "Type", 11 | Fields: []*Field{ 12 | { 13 | Ident: "ID", 14 | Type: "int64", 15 | Tag: `json:"id" db:"id"`, 16 | }, 17 | }, 18 | } 19 | t.Run("api", func(t *testing.T) { 20 | cfg := &Config{ 21 | Features: []string{ 22 | FeatureApiNoRt, 23 | }, 24 | Imports: []string{ 25 | "gofmt \"fmt\"", 26 | }, 27 | Schemas: []*Schema{ 28 | { 29 | Meta: "Run POST https://localhost:port/path?{{ $.query }}", 30 | Header: "- Content-Type: application/json; charset=utf-8\n" + 31 | "- Authorization: Bearer XXXXXXXX-XXXX-XXXX-XXXX-XXXXXXXXXXXX\n\n" + 32 | "{{ $.body }}", 33 | In: []*Param{ 34 | { 35 | Ident: "ctx", 36 | Type: "context.Context", 37 | }, 38 | { 39 | Ident: "query", 40 | Type: "gofmt.Stringer", 41 | }, 42 | { 43 | Ident: "body", 44 | Type: "gofmt.Stringer", 45 | }, 46 | }, 47 | Out: []*Param{ 48 | {Type: "*Type"}, 49 | {Type: "error"}, 50 | }, 51 | }, 52 | }, 53 | Declare: []*Declare{ 54 | declareType, 55 | }, 56 | } 57 | if err := 
Generate(io.Discard, ModeApi, cfg); err != nil { 58 | t.Errorf("generate: %s", err) 59 | return 60 | } 61 | }) 62 | t.Run("sqlx", func(t *testing.T) { 63 | cfg := &Config{ 64 | Features: []string{ 65 | FeatureSqlxNoRt, 66 | }, 67 | Imports: []string{ 68 | "gofmt \"fmt\"", 69 | }, 70 | Schemas: []*Schema{ 71 | { 72 | Meta: "Run query many bind", 73 | Header: "{{ $.query }};", 74 | In: []*Param{ 75 | { 76 | Ident: "ctx", 77 | Type: "context.Context", 78 | }, 79 | { 80 | Ident: "query", 81 | Type: "gofmt.Stringer", 82 | }, 83 | }, 84 | Out: []*Param{ 85 | {Type: "[]Type"}, 86 | {Type: "error"}, 87 | }, 88 | }, 89 | }, 90 | Declare: []*Declare{ 91 | declareType, 92 | }, 93 | } 94 | if err := Generate(io.Discard, ModeSqlx, cfg); err != nil { 95 | t.Errorf("generate: %s", err) 96 | return 97 | } 98 | }) 99 | t.Run("unknown", func(t *testing.T) { 100 | if err := Generate(io.Discard, 999, &Config{}); err == nil { 101 | t.Errorf("generate: expects errors, got nil") 102 | return 103 | } else if err.Error() != "unimplemented mode \"Mode(999)\"" { 104 | t.Errorf("generate: expects UnimplementedError, got => %s", err) 105 | return 106 | } 107 | }) 108 | } 109 | -------------------------------------------------------------------------------- /gen/integration/sqlx/executor.go: -------------------------------------------------------------------------------- 1 | //go:build !test || no_test 2 | // +build !test no_test 3 | 4 | package main 5 | 6 | func NewExecutorFromCore(ExecutorCoreInterface) Executor { 7 | panic("Please use `go run -tags test ...` to enable testing; " + 8 | "this is just a placeholder function for static analysis to proceed.") 9 | } 10 | 11 | type ExecutorCoreInterface interface{} 12 | -------------------------------------------------------------------------------- /gen/integration/sqlx/go.mod: -------------------------------------------------------------------------------- 1 | module github.com/x5iu/defc/gen/integration/sqlx 2 | 3 | go 1.19 4 | 5 | require ( 6 | 
github.com/mattn/go-sqlite3 v1.14.23 7 | github.com/x5iu/defc v0.0.0 8 | ) 9 | 10 | require github.com/hashicorp/golang-lru/v2 v2.0.7 // indirect 11 | 12 | replace github.com/x5iu/defc => ../../.. 13 | -------------------------------------------------------------------------------- /gen/integration/sqlx/main.go: -------------------------------------------------------------------------------- 1 | //go:build test 2 | // +build test 3 | 4 | package main 5 | 6 | import ( 7 | "context" 8 | "database/sql" 9 | "encoding/json" 10 | "errors" 11 | "fmt" 12 | "io" 13 | "log" 14 | "reflect" 15 | "strings" 16 | "time" 17 | 18 | _ "github.com/mattn/go-sqlite3" 19 | defc "github.com/x5iu/defc/runtime" 20 | ) 21 | 22 | var executor Executor 23 | 24 | func init() { 25 | log.SetFlags(log.Lshortfile | log.Lmsgprefix) 26 | log.SetPrefix("[defc] ") 27 | } 28 | 29 | func main() { 30 | ctx := context.Background() 31 | ctx, cancel := context.WithTimeout(ctx, time.Second*5) 32 | defer cancel() 33 | db := defc.MustOpen("sqlite3", ":memory:") 34 | defer db.Close() 35 | executor = NewExecutorFromCore(&sqlc{db}) 36 | defer executor.(io.Closer).Close() 37 | if err := executor.InitTable(ctx); err != nil { 38 | log.Fatalln(err) 39 | } 40 | var id int64 41 | err := executor.WithTx(func(tx Executor) error { 42 | r, errTx := tx.CreateUser(ctx, 43 | &User{name: "defc_test_0001"}, 44 | &User{name: "defc_test_0002"}, 45 | ) 46 | if errTx != nil { 47 | return errTx 48 | } 49 | r, errTx = tx.CreateUser(ctx, 50 | &User{name: "defc_test_0003"}, 51 | &User{name: "defc_test_0004"}, 52 | ) 53 | if errTx != nil { 54 | return errTx 55 | } 56 | r, errTx = tx.CreateUser(ctx, &User{ 57 | name: "defc_test_0005", 58 | projects: []*Project{ 59 | {name: "defc_test_0005_project_01"}, 60 | {name: "defc_test_0005_project_02"}, 61 | }, 62 | }) 63 | if errTx != nil { 64 | return errTx 65 | } 66 | if id, errTx = r.LastInsertId(); errTx != nil { 67 | return errTx 68 | } 69 | return nil 70 | }) 71 | if err != nil { 72 | 
log.Fatalln(err) 73 | } 74 | user, err := executor.GetUserByID(ctx, id) 75 | if err != nil { 76 | log.Fatalln(err) 77 | } 78 | if !(user.id == id && user.name == fmt.Sprintf("defc_test_%04d", id)) { 79 | log.Fatalf("unexpected user: User(id=%d, name=%q)\n", 80 | user.id, 81 | user.name) 82 | } 83 | if len(user.projects) != 2 { 84 | log.Fatalf("unexpected projects: %v\n", user.projects) 85 | } 86 | if !reflect.DeepEqual(user.projects, []*Project{ 87 | {id: 1, name: "defc_test_0005_project_01", userID: 5}, 88 | {id: 2, name: "defc_test_0005_project_02", userID: 5}, 89 | }) { 90 | log.Fatalf("unexpected projects: %v\n", user.projects) 91 | } 92 | users, err := executor.QueryUsers("defc_test_0001", "defc_test_0004") 93 | if err != nil { 94 | log.Fatalln(err) 95 | } 96 | if len(users) != 2 || users[0].id != 1 || users[1].id != 4 { 97 | var msg strings.Builder 98 | msg.WriteString("unexpected users: [") 99 | for i, unexpected := range users { 100 | if i > 0 { 101 | msg.WriteString(", ") 102 | } 103 | fmt.Fprintf(&msg, "User(id=%d, name=%q)", 104 | unexpected.id, 105 | unexpected.name) 106 | } 107 | msg.WriteString("]") 108 | log.Fatalln(msg.String()) 109 | } 110 | userIDs, err := executor.QueryUserIDs("defc_test_0001", "defc_test_0004") 111 | if err != nil { 112 | log.Fatalln(err) 113 | } 114 | if !reflect.DeepEqual(userIDs, UserIDs{{1}, {4}}) { 115 | log.Fatalf("unexpected userIDs: %v\n", userIDs) 116 | } 117 | } 118 | 119 | type sqlc struct { 120 | *defc.DB 121 | } 122 | 123 | func (c *sqlc) Log( 124 | _ context.Context, 125 | name string, 126 | query string, 127 | args any, 128 | elapse time.Duration, 129 | ) { 130 | argsjson, _ := json.Marshal(args) 131 | fmt.Printf("=== %s\n query: %s \n args: %v \nelapse: %s\n", 132 | name, 133 | query, 134 | string(argsjson), 135 | elapse, 136 | ) 137 | if !strings.HasPrefix(strings.TrimSpace(query), `/* {"name": "defc", "action": "test"} */`) { 138 | log.Fatalf("%q query not starts with sqlcomment header\n", name) 139 | } 140 | } 
141 | 142 | var cmTemplate = `{{ define "sqlcomment" }}{{ sqlcomment . }}{{ end }}` 143 | 144 | //go:generate defc generate -T Executor -o executor.gen.go --features sqlx/future,sqlx/log,sqlx/callback --template :cmTemplate --function sqlcomment=sqlComment 145 | type Executor interface { 146 | // WithTx isolation=7 147 | WithTx(func(Executor) error) error 148 | 149 | // InitTable exec 150 | /* 151 | {{ template "sqlcomment" .ctx }} 152 | create table if not exists user 153 | ( 154 | id integer not null 155 | constraint user_pk 156 | primary key autoincrement, 157 | name text not null 158 | ); 159 | {{ template "sqlcomment" .ctx }} 160 | create table if not exists project 161 | ( 162 | id integer not null 163 | constraint project_pk 164 | primary key autoincrement, 165 | name text not null, 166 | user_id integer not null 167 | ); 168 | */ 169 | InitTable(ctx context.Context) error 170 | 171 | // CreateUser exec bind isolation=sql.LevelLinearizable 172 | /* 173 | {{ $context := .ctx }} 174 | {{ range $index, $user := .users }} 175 | {{ if $user.Projects }} 176 | {{ template "sqlcomment" $context }} 177 | insert into project ( name, user_id ) values 178 | {{ range $index, $project := $user.Projects }} 179 | {{ if gt $index 0 }},{{ end }} 180 | ( 181 | {{ bind $project.Name }}, 182 | 0 183 | ) 184 | {{ end }} 185 | ; 186 | {{ end }} 187 | {{ template "sqlcomment" $context }} 188 | insert into user ( name ) values ( {{ bind $user.Name }} ); 189 | {{ if $user.Projects }} 190 | {{ template "sqlcomment" $context }} 191 | update project set user_id = last_insert_rowid() where user_id = 0; 192 | {{ end }} 193 | {{ end }} 194 | */ 195 | CreateUser(ctx context.Context, users ...*User) (sql.Result, error) 196 | 197 | // GetUserByID query named 198 | // {{ template "sqlcomment" .ctx }} 199 | // select id, name from user where id = :id; 200 | GetUserByID(ctx context.Context, id int64) (*User, error) 201 | 202 | // QueryUsers query named const 203 | // /* {"name":: "defc", 
"action":: "test"} */ 204 | // select id, name from user where name in (:names); 205 | QueryUsers(names ...string) ([]*User, error) 206 | 207 | // QueryUserIDs query many named const 208 | // /* {"name":: "defc", "action":: "test"} */ 209 | // select id, name from user where name in (:names) order by id asc; 210 | QueryUserIDs(names ...string) (UserIDs, error) 211 | 212 | // GetProjectsByUserID query const 213 | // /* {"name": "defc", "action": "test"} */ 214 | // select id, name, user_id from project where user_id = ? and id != 0 order by id asc; 215 | GetProjectsByUserID(userID int64) ([]*Project, error) 216 | } 217 | 218 | type UserID struct { 219 | UserID int64 220 | } 221 | 222 | type UserIDs []UserID 223 | 224 | func (ids *UserIDs) FromRows(rows defc.Rows) error { 225 | if ids == nil { 226 | return errors.New("UserIDs.FromRows: nil pointer") 227 | } 228 | for rows.Next() { 229 | var id UserID 230 | if err := defc.ScanRow(rows, "id", &id.UserID); err != nil { 231 | return err 232 | } 233 | *ids = append(*ids, id) 234 | } 235 | return nil 236 | } 237 | 238 | type User struct { 239 | id int64 240 | name string 241 | projects []*Project 242 | } 243 | 244 | func (user *User) Name() string { return user.name } 245 | func (user *User) Projects() []*Project { return user.projects } 246 | 247 | func (user *User) Callback(ctx context.Context, e Executor) (err error) { 248 | user.projects, err = e.GetProjectsByUserID(user.id) 249 | return err 250 | } 251 | 252 | func (user *User) FromRow(row defc.Row) error { 253 | const ( 254 | FieldID = "id" 255 | FieldName = "name" 256 | ) 257 | columns, err := row.Columns() 258 | if err != nil { 259 | return err 260 | } 261 | scanner := make([]any, 0, 2) 262 | for _, column := range columns { 263 | switch column { 264 | case FieldID: 265 | scanner = append(scanner, &user.id) 266 | case FieldName: 267 | scanner = append(scanner, &user.name) 268 | default: 269 | scanner = append(scanner, new(sql.RawBytes)) 270 | } 271 | } 272 | if err 
= row.Scan(scanner...); err != nil { 273 | return err 274 | } 275 | return nil 276 | } 277 | 278 | type Project struct { 279 | id int64 280 | name string 281 | userID int64 282 | } 283 | 284 | func (project *Project) Name() string { return project.name } 285 | 286 | func (project *Project) FromRow(row defc.Row) error { 287 | return defc.ScanRow(row, 288 | "id", &project.id, 289 | "name", &project.name, 290 | "user_id", &project.userID, 291 | ) 292 | } 293 | 294 | func sqlComment(context.Context) string { 295 | return `/* {"name": "defc", "action": "test"} */` 296 | } 297 | -------------------------------------------------------------------------------- /gen/integration/sqlx_test.go: -------------------------------------------------------------------------------- 1 | package integration 2 | 3 | import ( 4 | "bufio" 5 | "bytes" 6 | "os" 7 | "os/exec" 8 | "path/filepath" 9 | "regexp" 10 | "strings" 11 | "testing" 12 | 13 | "github.com/x5iu/defc/gen" 14 | goimport "golang.org/x/tools/imports" 15 | ) 16 | 17 | func TestSqlx(t *testing.T) { 18 | var ( 19 | testPk = "main" 20 | testDir = "sqlx" 21 | testFile = "main.go" 22 | testGenFile = "executor.gen.go" 23 | ) 24 | pwd, err := os.Getwd() 25 | if err != nil { 26 | t.Errorf("getwd: %s", err) 27 | return 28 | } 29 | defer func() { 30 | if err = os.Chdir(pwd); err != nil { 31 | t.Errorf("chdir: %s", err) 32 | return 33 | } 34 | }() 35 | if err = os.Chdir(testDir); err != nil { 36 | t.Errorf("chdir: %s", err) 37 | return 38 | } 39 | defer os.Remove(testGenFile) 40 | doc, err := os.ReadFile(testFile) 41 | if err != nil { 42 | t.Errorf("read %s: %s", testFile, err) 43 | return 44 | } 45 | var ( 46 | featReg = regexp.MustCompile(`--features(?:\s|=)([\w,/]+)`) 47 | tmplReg = regexp.MustCompile(`--template(?:\s|=)([\w:]+)`) 48 | funcReg = regexp.MustCompile(`--function(?:\s|=)([\w=]+)`) 49 | 50 | pos int 51 | features []string 52 | template string 53 | functions []string 54 | ) 55 | lineScanner := 
bufio.NewScanner(bytes.NewReader(doc)) 56 | for i := 1; lineScanner.Scan(); i++ { 57 | text := lineScanner.Text() 58 | if strings.HasPrefix(text, "//go:generate") { 59 | pos = i 60 | featureList := featReg.FindAllStringSubmatch(text, -1) 61 | for _, sublist := range featureList { 62 | features = append(features, strings.Split(sublist[1], ",")...) 63 | } 64 | templateList := tmplReg.FindAllStringSubmatch(text, -1) 65 | for _, sublist := range templateList { 66 | template = strings.TrimPrefix(sublist[1], ":") 67 | } 68 | functionList := funcReg.FindAllStringSubmatch(text, -1) 69 | for _, sublist := range functionList { 70 | functions = append(functions, sublist[1]) 71 | } 72 | break 73 | } 74 | } 75 | if err = lineScanner.Err(); err != nil { 76 | t.Errorf("scan %s: %s", testFile, err) 77 | return 78 | } 79 | runTest := func(t *testing.T, feats ...string) { 80 | generator := gen.NewCliBuilder(gen.ModeSqlx). 81 | WithPkg(testPk). 82 | WithPwd(pwd). 83 | WithFile(testFile, doc). 84 | WithPos(pos). 85 | WithImports(nil, true). 86 | WithFeats(append(features, feats...)). 87 | WithTemplate(template). 
88 | WithFuncs(functions) 89 | var buf bytes.Buffer 90 | if err = generator.Build(&buf); err != nil { 91 | t.Errorf("build: %s", err) 92 | return 93 | } 94 | if err = os.WriteFile(testGenFile, buf.Bytes(), 0644); err != nil { 95 | t.Errorf("write %s: %s", testGenFile, err) 96 | return 97 | } 98 | code, err := goimport.Process(testGenFile, buf.Bytes(), nil) 99 | if err != nil { 100 | t.Errorf("fix import %s: %s", testGenFile, err) 101 | return 102 | } 103 | if err = os.WriteFile(testGenFile, code, 0644); err != nil { 104 | t.Errorf("write %s: %s", testGenFile, err) 105 | return 106 | } 107 | if !runCommand(t, "go", "mod", "tidy") { 108 | return 109 | } 110 | if !runCommand(t, "go", "run", "-tags", "test", filepath.Join(pwd, testDir)) { 111 | return 112 | } 113 | } 114 | t.Run("rt", func(t *testing.T) { runTest(t) }) 115 | t.Run("nort", func(t *testing.T) { runTest(t, gen.FeatureSqlxNoRt) }) 116 | } 117 | 118 | func runCommand(t *testing.T, name string, args ...string) (success bool) { 119 | var ( 120 | stdout bytes.Buffer 121 | stderr bytes.Buffer 122 | ) 123 | cmd := exec.Command(name, args...) 
124 | cmd.Stdout = &stdout 125 | cmd.Stderr = &stderr 126 | if err := cmd.Run(); err != nil { 127 | t.Logf("the integration test program encountered an error, "+ 128 | "some information is shown below: \n%s\n", stdout.String()) 129 | t.Errorf("run `%s %s`: \n%s", name, strings.Join(args, " "), stderr.String()) 130 | return false 131 | } 132 | if stdout.Len() > 0 { 133 | t.Logf("the integration test program has been successfully completed, "+ 134 | "with detailed information as follows: \n%s\n", stdout.String()) 135 | } 136 | return true 137 | } 138 | -------------------------------------------------------------------------------- /gen/method.go: -------------------------------------------------------------------------------- 1 | package gen 2 | 3 | import ( 4 | "bytes" 5 | "fmt" 6 | "go/ast" 7 | "net/http" 8 | "regexp" 9 | ) 10 | 11 | // Method represents a method declaration in an interface 12 | type Method struct { 13 | // Meta represents first-line comment of this method, who 14 | // looks like a command in cli, the first argument should 15 | // always be the name of this method, which is 'Ident' 16 | // field below 17 | Meta string 18 | 19 | // Header represents contents after first-line comment, 20 | // who is HTTP header message with '--mode=api' arg or 21 | // literal sql string with '--mode=sqlx' arg 22 | Header string 23 | 24 | Ident string 25 | OrderedIn []string // to make In sorted 26 | In map[string]ast.Expr 27 | UnnamedIn []ast.Expr 28 | Out []ast.Expr 29 | 30 | // Source represents the raw file content 31 | Source []byte 32 | } 33 | 34 | func (method *Method) TxType() (ast.Expr, error) { 35 | var lastType ast.Expr 36 | if len(method.OrderedIn) > 0 { 37 | lastIn := method.OrderedIn[len(method.OrderedIn)-1] 38 | lastType = method.In[lastIn] 39 | } else if len(method.UnnamedIn) > 0 { 40 | lastType = method.UnnamedIn[len(method.UnnamedIn)-1] 41 | } else { 42 | return nil, fmt.Errorf("method %s expects at least one argument", method.Ident) 43 | } 44 | if 
funcType, ok := lastType.(*ast.FuncType); ok { 45 | if len(funcType.Params.List) != 1 { 46 | return nil, fmt.Errorf( 47 | "method %s expects an *ast.FuncType as arguments, who has and only has one argument", 48 | method.Ident, 49 | ) 50 | } 51 | fnIn := funcType.Params.List[0] 52 | return fnIn.Type, nil 53 | } else { 54 | return nil, fmt.Errorf("method %s expects a function as the last argument", method.Ident) 55 | } 56 | } 57 | 58 | func (method *Method) SortIn() []string { 59 | return method.OrderedIn 60 | } 61 | 62 | var backslashRe = regexp.MustCompile(`\\[ \t\r]*?\n[ \t\r]*`) 63 | 64 | func (method *Method) MetaArgs() []string { 65 | rawArgs := splitArgs(backslashRe.ReplaceAllString(method.Meta, "")) 66 | args := make([]string, 0, len(rawArgs)) 67 | for i := 0; i < len(rawArgs); i++ { 68 | if rawArgs[i] != "" && rawArgs[i] != " " { 69 | args = append(args, rawArgs[i]) 70 | } 71 | } 72 | return args 73 | } 74 | 75 | // TmplURL should only be used with '--mode=api' arg 76 | func (method *Method) TmplURL() string { 77 | args := method.MetaArgs() 78 | if len(args) >= 1 { 79 | return args[len(args)-1] 80 | } 81 | return "" 82 | } 83 | 84 | var minusRe = regexp.MustCompile(`(?m)^[ \t]*?-[ \t]*`) 85 | 86 | // TmplHeader should only be used with '--mode=api' arg 87 | func (method *Method) TmplHeader() string { 88 | var ( 89 | header = method.Header 90 | body string 91 | ) 92 | if idx := index(header, "\r\n\r\n"); idx != -1 { 93 | body = trimSpace(header[idx+4:]) 94 | header = trimSpace(header[:idx]) 95 | } 96 | if idx := index(header, "\n\n"); idx != -1 { 97 | body = trimSpace(header[idx+2:]) 98 | header = trimSpace(header[:idx]) 99 | } 100 | header = minusRe.ReplaceAllString(header, "") + "\r\n\r\n" 101 | if len(body) > 0 { 102 | header += body 103 | } 104 | return header 105 | } 106 | 107 | var availableMethods = []string{ 108 | http.MethodGet, 109 | http.MethodHead, 110 | http.MethodPost, 111 | http.MethodPut, 112 | http.MethodPatch, 113 | http.MethodDelete, 114 | 
http.MethodConnect, 115 | http.MethodOptions, 116 | http.MethodTrace, 117 | } 118 | 119 | // MethodHTTP should only be used with '--mode=api' arg 120 | func (method *Method) MethodHTTP() string { 121 | args := method.MetaArgs() 122 | if len(args) >= 2 { 123 | for _, httpMethod := range availableMethods { 124 | if toUpper(args[1]) == httpMethod { 125 | return httpMethod 126 | } 127 | } 128 | } 129 | return "" 130 | } 131 | 132 | var availableOperations = []string{ 133 | sqlxOpExec, 134 | sqlxOpQuery, 135 | } 136 | 137 | // SqlxOperation should only be used with '--mode=sqlx' arg 138 | func (method *Method) SqlxOperation() string { 139 | args := method.MetaArgs() 140 | if len(args) >= 2 { 141 | for _, operation := range availableOperations { 142 | if toUpper(args[1]) == operation { 143 | return operation 144 | } 145 | } 146 | } 147 | return "" 148 | } 149 | 150 | // SqlxOptions should only be used with '--mode=sqlx' arg 151 | func (method *Method) SqlxOptions() []string { 152 | args := method.MetaArgs() 153 | if len(args) >= 3 { 154 | opts := make([]string, 0, len(args[2:])) 155 | for _, opt := range args[2:] { 156 | opts = append(opts, toUpper(opt)) 157 | } 158 | return opts 159 | } 160 | return nil 161 | } 162 | 163 | func (method *Method) HasContext() bool { 164 | for ident, ty := range method.In { 165 | if isContextType(ident, ty, method.Source) { 166 | return true 167 | } 168 | } 169 | 170 | // for sqlx WithTxContext, we should consider unnamed arguments 171 | for _, ty := range method.UnnamedIn { 172 | if isContextType("", ty, method.Source) { 173 | return true 174 | } 175 | } 176 | 177 | return false 178 | } 179 | 180 | // ExtraScan should only be used with '--mode=api' arg 181 | func (method *Method) ExtraScan() []string { 182 | if args := method.MetaArgs(); len(args) >= 3 { 183 | extra := make([]string, 0, 2) 184 | for _, arg := range args[2:] { 185 | if len(arg) > 6 && toUpper(arg[0:5]) == "SCAN(" && arg[len(arg)-1] == ')' { 186 | extra = append(extra, 
split(arg[5:len(arg)-1], ",")...) 187 | } 188 | } 189 | return extra 190 | } 191 | return nil 192 | } 193 | 194 | // SingleScan should only be used with '--mode=sqlx' arg 195 | func (method *Method) SingleScan() string { 196 | if args := method.MetaArgs(); len(args) >= 3 { 197 | for _, opt := range args[2:] { 198 | if len(opt) > 6 && toUpper(opt[0:5]) == "SCAN(" && opt[len(opt)-1] == ')' { 199 | expressions := split(opt[5:len(opt)-1], ",") 200 | for _, expr := range expressions { 201 | return expr 202 | } 203 | } 204 | } 205 | } 206 | return "" 207 | } 208 | 209 | // WrapFunc should only be used with '--mode=sqlx' arg 210 | func (method *Method) WrapFunc() string { 211 | const prefix = "WRAP=" 212 | if args := method.MetaArgs(); len(args) >= 3 { 213 | for _, opt := range args[2:] { 214 | if len(opt) > len(prefix) && toUpper(opt[:len(prefix)]) == prefix { 215 | return opt[len(prefix):] 216 | } 217 | } 218 | } 219 | return "" 220 | } 221 | 222 | // IsolationLv should only be used with '--mode=sqlx' arg 223 | func (method *Method) IsolationLv() string { 224 | const prefix = "ISOLATION=" 225 | if args := method.MetaArgs(); len(args) >= 3 { 226 | for _, opt := range args[2:] { 227 | if len(opt) > len(prefix) && toUpper(opt[:len(prefix)]) == prefix { 228 | return opt[len(prefix):] 229 | } 230 | } 231 | } 232 | return "" 233 | } 234 | 235 | // TxIsolationLv should only be used with '--mode=sqlx' arg 236 | func (method *Method) TxIsolationLv() string { 237 | const prefix = "ISOLATION=" 238 | if args := method.MetaArgs(); len(args) >= 2 { 239 | for _, opt := range args[1:] { 240 | if len(opt) > len(prefix) && toUpper(opt[:len(prefix)]) == prefix { 241 | return opt[len(prefix):] 242 | } 243 | } 244 | } 245 | return "" 246 | } 247 | 248 | // ArgumentsVar should only be used with '--mode=sqlx' arg 249 | func (method *Method) ArgumentsVar() string { 250 | const prefix = "ARGUMENTS=" 251 | if args := method.MetaArgs(); len(args) >= 3 { 252 | for _, opt := range args[2:] { 253 | 
if len(opt) > len(prefix) && toUpper(opt[:len(prefix)]) == prefix { 254 | return opt[len(prefix):] 255 | } 256 | } 257 | } 258 | return "" 259 | } 260 | 261 | // ReturnSlice should only be used with '--mode=api' arg 262 | func (method *Method) ReturnSlice() bool { 263 | if args := method.MetaArgs(); len(args) >= 3 { 264 | for _, arg := range args[2:] { 265 | switch toUpper(arg) { 266 | case "ONE": 267 | return false 268 | case "MANY": 269 | return true 270 | } 271 | } 272 | } 273 | return len(method.Out) > 1 && isSlice(method.Out[0]) 274 | } 275 | 276 | // MaxRetry should only be used with '--mode=api' arg 277 | func (method *Method) MaxRetry() string { 278 | const prefix = "RETRY=" 279 | if args := method.MetaArgs(); len(args) >= 3 { 280 | for _, arg := range args[2:] { 281 | if len(arg) > len(prefix) && toUpper(arg[:len(prefix)]) == prefix { 282 | return arg[len(prefix):] 283 | } 284 | } 285 | } 286 | // defaults to 2 287 | return "2" 288 | } 289 | 290 | // RequestOptions should only be used with '--mode=api' arg 291 | func (method *Method) RequestOptions() string { 292 | if args := method.MetaArgs(); len(args) >= 3 { 293 | for _, arg := range args[2:] { 294 | if len(arg) > 9 && toUpper(arg[0:8]) == "OPTIONS(" && arg[len(arg)-1] == ')' { 295 | return arg[8 : len(arg)-1] 296 | } 297 | } 298 | } 299 | return "" 300 | } 301 | 302 | func inspectMethod(node *ast.Field, source []byte) (method *Method) { 303 | field := node 304 | method = new(Method) 305 | method.Source = source 306 | if field.Doc != nil { 307 | method.Meta = trimSlash(field.Doc.List[0].Text) 308 | var buffer bytes.Buffer 309 | for _, header := range field.Doc.List[1:] { 310 | buffer.WriteString(trimSlash(header.Text)) 311 | buffer.WriteString("\r\n") 312 | } 313 | method.Header = buffer.String() 314 | switch len(method.Header) { 315 | default: 316 | if method.Header[len(method.Header)-4:] == "\r\n\r\n" { 317 | break 318 | } 319 | fallthrough 320 | case 2, 3: 321 | if 
method.Header[len(method.Header)-2:] == "\r\n" { 322 | method.Header += "\r\n" 323 | } else { 324 | method.Header += "\r\n\r\n" 325 | } 326 | case 1: 327 | method.Header += "\r\n\r\n" 328 | case 0: 329 | } 330 | } 331 | method.Ident = field.Names[0].Name 332 | if funcType, ok := field.Type.(*ast.FuncType); ok { 333 | inParams := funcType.Params.List 334 | method.In = make(map[string]ast.Expr, len(inParams)) 335 | for _, param := range inParams { 336 | if param.Names != nil { 337 | for _, name := range param.Names { 338 | method.OrderedIn = append(method.OrderedIn, name.Name) 339 | method.In[name.Name] = param.Type 340 | } 341 | } else { 342 | method.UnnamedIn = append(method.UnnamedIn, param.Type) 343 | } 344 | } 345 | if funcType.Results != nil { 346 | outParams := funcType.Results.List 347 | method.Out = make([]ast.Expr, 0, len(outParams)) 348 | for _, param := range outParams { 349 | method.Out = append(method.Out, param.Type) 350 | } 351 | } 352 | } 353 | return method 354 | } 355 | -------------------------------------------------------------------------------- /gen/method_test.go: --------------------------------------------------------------------------------
package gen

import (
	"reflect"
	"testing"
)

// TestMethod exercises the option accessors of Method.
//
// First, a zero-value Method must report an empty result ("" or nil) from
// every accessor. Then a Method whose Meta carries explicit options
// (Scan(obj), wrap=, isolation=, arguments=, retry=, options(...)) must
// expose each option through the matching accessor.
func TestMethod(t *testing.T) {
	// Zero value: no Meta, so every accessor falls back to its empty result.
	m := &Method{}
	if tmplUrl := m.TmplURL(); tmplUrl != "" {
		t.Errorf("method: %q != \"\"", tmplUrl)
		return
	}
	if sqlxOp := m.SqlxOperation(); sqlxOp != "" {
		t.Errorf("method: %q != \"\"", sqlxOp)
		return
	}
	if sqlxOpt := m.SqlxOptions(); sqlxOpt != nil {
		t.Errorf("method: %v != nil", sqlxOpt)
		return
	}
	if exScan := m.ExtraScan(); exScan != nil {
		t.Errorf("method: %v != nil", exScan)
		return
	}
	if wrapFn := m.WrapFunc(); wrapFn != "" {
		t.Errorf("method: %q != \"\"", wrapFn)
		return
	}
	if isoLv := m.IsolationLv(); isoLv != "" {
		t.Errorf("method: %q != \"\"", isoLv)
		return
	}
	if argVar := m.ArgumentsVar(); argVar != "" {
		t.Errorf("method: %q != \"\"", argVar)
		return
	}
	// Meta with every option set; each assertion below checks one option.
	m = &Method{Meta: "Test Query One Scan(obj) wrap=fn isolation=sql.LevelDefault arguments=sqlArguments retry=3 options(reqOpts)"}
	if exScan := m.ExtraScan(); !reflect.DeepEqual(exScan, []string{"obj"}) {
		t.Errorf("method: %v != [obj]", exScan)
		return
	}
	if wrapFn := m.WrapFunc(); wrapFn != "fn" {
		t.Errorf("method: %q != \"fn\"", wrapFn)
		return
	}
	if isoLv := m.IsolationLv(); isoLv != "sql.LevelDefault" {
		t.Errorf("method: %q != \"sql.LevelDefault\"", isoLv)
		return
	}
	// The return type is unset here, so ReturnSlice must be false.
	if rtnSlice := m.ReturnSlice(); rtnSlice != false {
		t.Errorf("method: %v != false", rtnSlice)
		return
	}
	if argVar := m.ArgumentsVar(); argVar != "sqlArguments" {
		t.Errorf("method: %q != \"sqlArguments\"", argVar)
		return
	}
	if maxRetry := m.MaxRetry(); maxRetry != "3" {
		t.Errorf("method: %q != \"3\"", maxRetry)
		return
	}
	if options := m.RequestOptions(); options != "reqOpts" {
		t.Errorf("method: %q != \"reqOpts\"", options)
		return
	}
}
--------------------------------------------------------------------------------
/gen/sqlx.go:
--------------------------------------------------------------------------------
package gen

import (
	"bufio"
	"bytes"
	"fmt"
	"go/ast"
	"go/parser"
	"go/token"
	"io"
	"strings"
	"text/template"

	_ "embed"
)

const (
	// Operation keywords compared by the isExec/isQuery template helpers
	// registered in genSqlxCode.
	sqlxOpExec  = "EXEC"
	sqlxOpQuery = "QUERY"

	// sqlxMethodWithTx is the reserved interface method name that Build turns
	// into transaction support (WithTx* fields) instead of a SQL method.
	sqlxMethodWithTx = "WithTx"

	// Header commands recognised by readHeader; the first argument of a header
	// line is upper-cased before comparison, so they are case-insensitive.
	sqlxCmdInclude = "#INCLUDE"
	sqlxCmdScript  = "#SCRIPT"

	// Feature flags for sqlx mode. inspectSqlx forwards only features whose
	// name starts with "sqlx" to the template context.
	FeatureSqlxIn          = "sqlx/in"           // NOTE(review): not referenced in this file; presumably consumed by template/sqlx.tmpl — confirm
	FeatureSqlxLog         = "sqlx/log"          // MergedImports adds "time"; presumably enables query logging in the template — confirm
	FeatureSqlxRebind      = "sqlx/rebind"       // NOTE(review): not referenced in this file; presumably consumed by template/sqlx.tmpl — confirm
	FeatureSqlxNoRt        = "sqlx/nort"         // MergedImports imports stdlib helpers instead of github.com/x5iu/defc/runtime
	FeatureSqlxFuture      = "sqlx/future"       // MergedImports imports github.com/x5iu/defc/sqlx instead of github.com/jmoiron/sqlx
	FeatureSqlxCallback    = "sqlx/callback"     // NOTE(review): not referenced in this file; presumably consumed by template/sqlx.tmpl — confirm
	FeatureSqlxAnyCallback = "sqlx/any-callback" // NOTE(review): not referenced in this file; presumably consumed by template/sqlx.tmpl — confirm
)

// buildSqlx locates the annotated interface via inspectSqlx and renders the
// generated sqlx code into w. Inspection errors are wrapped with the source
// file path and the //go:generate position.
func (builder
*CliBuilder) buildSqlx(w io.Writer) error {
	// Collect the interface declaration, its methods and the import/feature
	// configuration first; errors are annotated with the source location so
	// the user can find the offending //go:generate line.
	inspectCtx, err := builder.inspectSqlx()
	if err != nil {
		return fmt.Errorf("inspectSqlx(%s, %d): %w", quote(join(builder.pwd, builder.file)), builder.pos, err)
	}
	return inspectCtx.Build(w)
}

// sqlxContext is the data model handed to the embedded sqlx template
// (template/sqlx.tmpl) by genSqlxCode.
type sqlxContext struct {
	Package   string     // package clause of the generated file
	BuildTags []string   // build constraints parsed from the source document
	Ident     string     // name of the annotated interface
	Methods   []*Method  // SQL methods (the WithTx method is removed by Build)
	Embeds    []ast.Expr // embedded (unnamed, non-func) interface elements

	// WithTx* are filled in by Build when the interface declares a method
	// named "WithTx".
	WithTx          bool
	WithTxType      ast.Expr
	WithTxContext   bool
	WithTxIsolation string

	Features []string // enabled features whose names start with "sqlx"
	Imports  []string // raw import specs: "path" or "name path"
	Funcs    []string // extra template funcs as "name: expr" (see AdditionalFuncs)
	Pwd      string   // directory used by readHeader to resolve relative #include patterns
	Doc      Doc
	Schema   string // optional body of the "schema" template definition
	Template string // user-supplied template text (--template/-t)
}

// Build validates the inspected methods, extracts the optional WithTx method,
// enables the Bind option where the user template needs it, and renders the
// generated code into w.
//
// Every method must return an error as its last result. A method using the
// `scan(expr)` option must return only that error. Any other method may
// return at most two values.
func (ctx *sqlxContext) Build(w io.Writer) error {
	var fixedMethods []*Method
	for i, method := range ctx.Methods {
		if l := len(method.Out); l == 0 || !checkErrorType(method.Out[l-1]) {
			return fmt.Errorf("checkErrorType: no 'error' found in method %s returned values",
				quote(method.Ident))
		}

		if method.SingleScan() != "" {
			if len(method.Out) != 1 {
				return fmt.Errorf("%s method expects only error returned value when `scan(expr)` option has been specified",
					quote(method.Ident))
			}
		} else if len(method.Out) > 2 {
			return fmt.Errorf("%s method expects 2 returned value at most, got %d",
				quote(method.Ident),
				len(method.Out))
		}

		if method.Ident == sqlxMethodWithTx {
			txType, err := method.TxType()
			if err != nil {
				return err
			}
			ctx.WithTx = true
			ctx.WithTxType = txType
			ctx.WithTxContext = method.HasContext()
			ctx.WithTxIsolation = method.TxIsolationLv()
			fixedMethods = make([]*Method, 0, len(ctx.Methods)-1)
			fixedMethods = append(fixedMethods, ctx.Methods[:i]...)
			fixedMethods = append(fixedMethods, ctx.Methods[i+1:]...)
		}
	}

	// Modifying the value of Methods within the loop can cause the loop to skip the check for one of the methods.
	// To avoid this issue, we assign the modified Methods value to fixedMethods and then assign it back to the
	// original Methods after the loop ends.
	if fixedMethods != nil {
		ctx.Methods = fixedMethods
	}

	// Small hack: When the --template/-t option is enabled, and "bind" function has been invoked, the Bind option
	// is enabled by default for all methods.
	if ctx.Template != "" && templateCallsFunc(ctx.Template, "bind") {
		const (
			bindOption  = "BIND"
			namedOption = "NAMED"
		)
		// [2024-05-07]
		// Eventually, it was realized that arbitrarily adding a Bind option to each method was a foolish act.
		// Bind would require parsing the template content every time the method is called, which is very slow.
		// In some scenarios, there is simply a need for some common templates without wanting this heavy burden.
		// Therefore, today we will disable this unwise behavior.
		//
		// [2024-05-11]
		// When the situation becomes that one method includes a Bind option, but other methods do not include a
		// Bind option, the best strategy should be to add a Bind option to all methods. This is because the
		// template may contain calls to bind, and if you do not add a Bind option for the method, it will cause
		// an error in rendering the template.
		var useBind bool
		for _, method := range ctx.Methods {
			if hasOption(method.SqlxOptions(), bindOption) {
				useBind = true
				break
			}
		}
		if useBind {
			for _, method := range ctx.Methods {
				if !hasOption(method.SqlxOptions(), bindOption) && !hasOption(method.SqlxOptions(), namedOption) {
					method.Meta += " " + bindOption
				}
			}
		}
	}

	if err := ctx.genSqlxCode(w); err != nil {
		return fmt.Errorf("genSqlxCode: %w", err)
	}

	return nil
}

// templateCallsFunc reports whether text, parsed as a text/template, invokes
// the function named fn.
//
// text/template has no structured error for undefined functions, so the check
// relies on the parser's `function "NAME" not defined` message. The parser
// stops at the first undefined function. A different undefined function that
// appears before fn (for example one registered via --func) is therefore
// replaced with a stub, and parsing is retried. Without this, a later call to
// fn would go unnoticed. A template that fails to parse for any other reason
// reports false.
func templateCallsFunc(text, fn string) bool {
	const (
		errPrefix = `function "`
		errSuffix = `" not defined`
	)
	stubs := template.FuncMap{}
	for {
		_, err := template.New("detect_bind_function").Funcs(stubs).Parse(text)
		if err == nil {
			return false
		}
		msg := err.Error()
		start := strings.Index(msg, errPrefix)
		if start < 0 {
			return false
		}
		name, _, found := strings.Cut(msg[start+len(errPrefix):], errSuffix)
		if !found {
			return false
		}
		if name == fn {
			return true
		}
		if _, stubbed := stubs[name]; stubbed {
			// Defensive: a stubbed name should never be reported again; stop
			// rather than loop forever.
			return false
		}
		stubs[name] = func(...any) any { return nil }
	}
}

// HasFeature reports whether feature is among the enabled sqlx features.
func (ctx *sqlxContext) HasFeature(feature string) bool {
	for _, current := range ctx.Features {
		if current == feature {
			return true
		}
	}
	return false
}

// MergedImports returns the import specs of the generated file. It starts
// with the packages the template always needs, adds the packages required by
// the enabled features, and then appends the user/auto-detected imports that
// are not already present.
func (ctx *sqlxContext) MergedImports() (imports []string) {
	imports = []string{
		quote("fmt"),
		quote("strconv"),
		quote("database/sql"),
		quote("context"),
		quote("text/template"),
	}

	if ctx.HasFeature(FeatureSqlxFuture) {
		imports = append(imports, quote("github.com/x5iu/defc/sqlx"))
	} else {
		imports = append(imports, quote("github.com/jmoiron/sqlx"))
	}

	if ctx.HasFeature(FeatureSqlxLog) {
		imports = append(imports, quote("time"))
	}

	if ctx.HasFeature(FeatureSqlxNoRt) {
		imports = append(imports,
			quote("errors"),
			quote("strings"),
			quote("reflect"),
			quote("sync"),
			quote("bytes"),
			quote("database/sql/driver"))
	} else if len(ctx.Methods) > 0 {
		imports = append(imports, parseImport("__rt github.com/x5iu/defc/runtime"))
	}

	for _, imp := range ctx.Imports {
		// ctx.Imports holds raw specs such as "json encoding/json", while the
		// entries above are already in parsed form. Compare both forms so an
		// import that repeats a default (or an earlier user import) is not
		// emitted twice.
		parsed := parseImport(imp)
		if !in(imports, imp) && !in(imports, parsed) {
			imports = append(imports, parsed)
		}
	}

	return imports
}

// AdditionalFuncs parses the "name: expr" entries of ctx.Funcs (via cutkv)
// into a name -> expression map; malformed entries are skipped.
func (ctx *sqlxContext) AdditionalFuncs() (funcMap map[string]string) {
	funcMap = make(map[string]string, len(ctx.Funcs))
	for _, fn := range ctx.Funcs {
		if key, value, ok := cutkv(fn); ok {
			funcMap[key] = value
		}
	}
	return funcMap
}

// inspectSqlx parses builder.file and collects the data needed to generate
// code for the interface declared right after the //go:generate line
// (builder.pos). That data includes the methods, embedded interfaces,
// imports, sqlx features and template settings.
func (builder *CliBuilder) inspectSqlx() (*sqlxContext, error) {
	fset := token.NewFileSet()

	f, err := parser.ParseFile(fset, builder.file, builder.doc.Bytes(), parser.ParseComments)
	if err != nil {
		return nil, err
	}

	var (
		genDecl   *ast.GenDecl
		typeSpec  *ast.TypeSpec
		ifaceType *ast.InterfaceType
	)

	// The declaration of interest starts on the line after //go:generate.
	line := builder.pos + 1
inspectDecl:
	for _, declIface := range f.Decls {
		if surroundLine(fset, declIface, line) {
			if decl, ok := declIface.(*ast.GenDecl); ok && decl.Tok == token.TYPE {
				genDecl = decl
				break inspectDecl
			}
		}
	}

	if genDecl == nil {
		return nil, fmt.Errorf(
			"no available 'Interface' type declaration (*ast.GenDecl) found, "+
				"available *ast.GenDecl are: \n\n"+
				"%s\n\n", concat(nodeMap(f.Decls, fmtNode), "\n"))
	}

inspectType:
	for _, specIface := range genDecl.Specs {
		if afterLine(fset, specIface, line) {
			if spec, ok := specIface.(*ast.TypeSpec); ok {
				if iface, ok := spec.Type.(*ast.InterfaceType); ok && afterLine(fset, iface, line) {
					typeSpec = spec
					ifaceType = iface
					break inspectType
				}
			}
		}
	}

	if ifaceType == nil {
		return nil, fmt.Errorf(
			"no available 'Interface' type declaration (*ast.InterfaceType) found, "+
				"available *ast.GenDecl are: \n\n"+
				"%s\n\n", concat(nodeMap(f.Decls, fmtNode), "\n"))
	}

	if !builder.disableAutoImport {
		// Keep only the file imports that the interface declaration refers to.
		imports, err := getImports(builder.pkg, builder.pwd, builder.file, func(node ast.Node) bool {
			switch x := node.(type) {
			case *ast.TypeSpec:
				return x.Name.Name == typeSpec.Name.Name
			}
			return false
		})

		if err != nil {
			return nil, err
		}

		for _, spec := range f.Imports {
			path := spec.Path.Value[1 : len(spec.Path.Value)-1]
			for _, imported := range imports {
				if path == imported.Path {
					var name string
					if spec.Name != nil {
						name = spec.Name.Name
					}
					if importRepr := strings.TrimSpace(name + " " + path); !in(builder.imports, importRepr) {
						builder.imports = append(builder.imports, importRepr)
					}
				}
			}
		}
	}

	var (
		methods = make([]*ast.Field, 0, len(ifaceType.Methods.List))
		embeds  = make([]ast.Expr, 0, len(ifaceType.Methods.List))
	)

	for _, method := range ifaceType.Methods.List {
		if _, ok := method.Type.(*ast.FuncType); ok {
			methods = append(methods, method)
		} else if method.Names == nil {
			embeds = append(embeds, method.Type)
		}
	}

	// Template rendering refers to parameters by name, so every parameter of a
	// SQL method must be named (WithTx is exempt).
	for _, method := range methods {
		if name := method.Names[0].Name; name != sqlxMethodWithTx {
			if funcType, ok := method.Type.(*ast.FuncType); ok && !checkInput(funcType) {
				return nil, fmt.Errorf(""+
					"input params for method %s should "+
					"contain 'Name' and 'Type' both",
					quote(name))
			}
		}
	}

	sqlxFeatures := make([]string, 0, len(builder.feats))
	for _, feature := range builder.feats {
		if hasPrefix(feature, "sqlx") {
			sqlxFeatures = append(sqlxFeatures, feature)
		}
	}

	return &sqlxContext{
		Package:   builder.pkg,
		BuildTags: parseBuildTags(builder.doc),
		Ident:     typeSpec.Name.Name,
		Methods:   typeMap(methods, builder.doc.InspectMethod),
		Embeds:    embeds,
		Features:  sqlxFeatures,
		Imports:   builder.imports,
		Funcs:     builder.funcs,
		// readHeader resolves relative #include patterns against Pwd; leaving
		// it empty silently made them relative to the process working directory.
		Pwd:      builder.pwd,
		Doc:      builder.doc,
		Template: builder.template,
	}, nil
}

// readHeader expands a method header into the text used by the generated
// code. Every output line is terminated with "\r\n".
//
//   - A line starting with a space or tab continues the previous line; the
//     two are joined with a single space.
//   - `#include <pattern>` (exactly one argument, optionally quoted) is
//     replaced by the contents of every file matching the glob pattern. A
//     relative pattern is resolved against pwd.
//   - `#script <cmd> [args...]` is replaced by the output of runCommand.
//   - Empty lines are skipped.
//
// An error is returned if a pattern is malformed, a matched file cannot be
// read, the script fails, or a header line is too long for bufio.Scanner.
func readHeader(header string, pwd string) (string, error) {
	var buf bytes.Buffer
	scanner := bufio.NewScanner(strings.NewReader(header))
	var text string
	for {
		if text == "" {
			if !scanner.Scan() {
				break
			}
			text = scanner.Text()
		}

		// Fold indented continuation lines into text; the first
		// non-continuation line is kept in next for the following round.
		var next string
		for {
			if !scanner.Scan() {
				break
			}
			next = scanner.Text()
			if len(next) > 0 && (next[0] == ' ' || next[0] == '\t') {
				text += " " + trimSpace(next)
				next = "" // next is consumed here
			} else {
				break
			}
		}

		text = trimSpace(text)
		args := splitArgs(text)

		// parse #include/#script command which should be placed in a new line
		if len(args) == 2 && toUpper(args[0]) == sqlxCmdInclude {
			// unquote path pattern if it is quoted
			pattern := unquote(args[1])
			if !isAbs(pattern) {
				pattern = join(pwd, pattern)
			}
			// Glob returns names carrying the pattern's directory prefix, so
			// the matches must not be joined with pwd again (doing so doubled
			// the prefix whenever pwd was relative).
			matches, err := glob(pattern)
			if err != nil {
				return "", err
			}
			for _, path := range matches {
				content, err := read(path)
				if err != nil {
					return "", fmt.Errorf("os.ReadFile(%s): %w", quote(path), err)
				}
				buf.Write(content)
			}
		} else if len(args) > 1 && toUpper(args[0]) == sqlxCmdScript {
			output, err := runCommand(args[1:])
			if err != nil {
				return "", err
			}
			buf.WriteString(output)
		} else {
			buf.WriteString(text)
		}
		buf.WriteString("\r\n")

		// now next becomes the current line
		text = next
	}
	// Scan also stops on errors such as bufio.ErrTooLong; without this check
	// an over-long line would silently truncate the header.
	if err := scanner.Err(); err != nil {
		return "", fmt.Errorf("readHeader: %w", err)
	}
	return buf.String(), nil
}

// hasOption reports whether opts contains opt upper-cased. The entries of
// opts are compared as-is, so they are expected to be upper-case already
// (NOTE(review): presumably guaranteed by Method.SqlxOptions — confirm).
func hasOption(opts []string, opt string) bool {
	want := toUpper(opt)
	for _, o := range opts {
		if o == want {
			return true
		}
	}
	return false
}

// sqlxTemplate is the text/template source used by genSqlxCode.
//
//go:embed template/sqlx.tmpl
var sqlxTemplate string

// genSqlxCode renders sqlxTemplate with ctx into w. The template gets helper
// functions for AST inspection and header expansion. If ctx.Schema is set,
// it is added as the "schema" template definition.
func (ctx *sqlxContext) genSqlxCode(w io.Writer) error {
	tmpl, err := template.
		New("defc(sqlx)").
		Funcs(template.FuncMap{
			"quote":         quote,
			"hasOption":     hasOption,
			"isSlice":       isSlice,
			"isPointer":     isPointer,
			"indirect":      indirect,
			"deselect":      deselect,
			"readHeader":    func(header string) (string, error) { return readHeader(header, ctx.Pwd) },
			"isContextType": func(ident string, expr ast.Expr) bool { return ctx.Doc.IsContextType(ident, expr) },
			"sub":           func(x, y int) int { return x - y },
			"getRepr":       func(node ast.Node) string { return ctx.Doc.Repr(node) },
			"isQuery":       func(op string) bool { return op == sqlxOpQuery },
			"isExec":        func(op string) bool { return op == sqlxOpExec },
		}).
		Parse(sqlxTemplate)

	if err != nil {
		return err
	}

	if ctx.Schema != "" {
		if tmpl, err = tmpl.Parse(sprintf(`{{ define "schema" }} %s {{ end }}`, ctx.Schema)); err != nil {
			return err
		}
	}

	return tmpl.Execute(w, ctx)
}
-------------------------------------------------------------------------------- /gen/sqlx_test.go: -------------------------------------------------------------------------------- 1 | package gen 2 | 3 | import ( 4 | "bufio" 5 | "bytes" 6 | "os" 7 | "path/filepath" 8 | "strings" 9 | "testing" 10 | ) 11 | 12 | func TestBuildSqlx(t *testing.T) { 13 | const ( 14 | testPk = "test" 15 | testGo = testPk + ".go" 16 | ) 17 | var ( 18 | testDir = filepath.Join("testdata", "sqlx") 19 | testFile = testGo 20 | genFile = testPk + "."
+ strings.ReplaceAll(t.Name(), "/", "_") + ".go" 21 | ) 22 | pwd, err := os.Getwd() 23 | if err != nil { 24 | t.Errorf("getwd: %s", err) 25 | return 26 | } 27 | defer func() { 28 | if err = os.Chdir(pwd); err != nil { 29 | t.Errorf("chdir: %s", err) 30 | return 31 | } 32 | }() 33 | if err = os.Chdir(testDir); err != nil { 34 | t.Errorf("chdir: %s", err) 35 | return 36 | } 37 | newBuilder := func(t *testing.T) (*CliBuilder, bool) { 38 | doc, err := os.ReadFile(testFile) 39 | if err != nil { 40 | t.Errorf("build: error reading %s file => %s", testGo, err) 41 | return nil, false 42 | } 43 | var pos int 44 | lineScanner := bufio.NewScanner(bytes.NewReader(doc)) 45 | for i := 1; lineScanner.Scan(); i++ { 46 | text := lineScanner.Text() 47 | if strings.HasPrefix(text, "//go:generate") && 48 | strings.HasSuffix(text, t.Name()) { 49 | pos = i 50 | break 51 | } 52 | } 53 | if err = lineScanner.Err(); err != nil { 54 | t.Errorf("build: error scanning %s lines => %s", testGo, err) 55 | return nil, false 56 | } 57 | if pos == 0 { 58 | t.Errorf("build: unable to get pos in %s", testGo) 59 | return nil, false 60 | } 61 | testDirAbs, err := os.Getwd() 62 | if err != nil { 63 | t.Errorf("getwd: %s", err) 64 | return nil, false 65 | } 66 | return NewCliBuilder(ModeSqlx). 67 | WithFeats([]string{FeatureSqlxNoRt}). 68 | WithPkg(testPk). 69 | WithPwd(testDirAbs). 70 | WithFile(testGo, doc). 71 | WithPos(pos). 72 | WithTemplate(quote(`{{ define "test_template" }} test_template {{ end }}`)), true 73 | } 74 | t.Run("success", func(t *testing.T) { 75 | builder, ok := newBuilder(t) 76 | if !ok { 77 | return 78 | } 79 | if err := runTest(genFile, builder); err != nil { 80 | t.Errorf("build: %s", err) 81 | return 82 | } 83 | builder = builder.WithFeats([]string{FeatureSqlxFuture, FeatureSqlxLog}). 84 | WithImports([]string{"C", "json encoding/json"}, false). 
85 | WithFuncs([]string{"marshal: json.Marshal"}) 86 | if err := runTest(genFile, builder); err != nil { 87 | t.Errorf("build: %s", err) 88 | return 89 | } 90 | }) 91 | t.Run("success_named_tx", func(t *testing.T) { 92 | builder, ok := newBuilder(t) 93 | if !ok { 94 | return 95 | } 96 | if err := runTest(genFile, builder); err != nil { 97 | t.Errorf("build: %s", err) 98 | return 99 | } 100 | builder = builder.WithFeats([]string{FeatureSqlxFuture, FeatureSqlxLog}). 101 | WithImports([]string{"C", "json encoding/json"}, false). 102 | WithFuncs([]string{"marshal: json.Marshal"}) 103 | if err := runTest(genFile, builder); err != nil { 104 | t.Errorf("build: %s", err) 105 | return 106 | } 107 | }) 108 | t.Run("fail_no_error", func(t *testing.T) { 109 | builder, ok := newBuilder(t) 110 | if !ok { 111 | return 112 | } 113 | if err := runTest(genFile, builder); err == nil { 114 | t.Errorf("build: expects errors, got nil") 115 | return 116 | } else if !strings.Contains(err.Error(), "checkErrorType: ") { 117 | t.Errorf("build: expects checkErrorType error, got => %s", err) 118 | return 119 | } 120 | }) 121 | t.Run("fail_single_scan", func(t *testing.T) { 122 | builder, ok := newBuilder(t) 123 | if !ok { 124 | return 125 | } 126 | if err := runTest(genFile, builder); err == nil { 127 | t.Errorf("build: expects errors, got nil") 128 | return 129 | } else if !strings.Contains(err.Error(), 130 | " expects only error returned value when `scan(expr)` option has been specified") { 131 | t.Errorf("build: expects SingleScan error, got => %s", err) 132 | return 133 | } 134 | }) 135 | t.Run("fail_2_values", func(t *testing.T) { 136 | builder, ok := newBuilder(t) 137 | if !ok { 138 | return 139 | } 140 | if err := runTest(genFile, builder); err == nil { 141 | t.Errorf("build: expects errors, got nil") 142 | return 143 | } else if !strings.Contains(err.Error(), 144 | " method expects 2 returned value at most") { 145 | t.Errorf("build: expects 2ValuesAtMost error, got => %s", err) 146 | 
return 147 | } 148 | }) 149 | t.Run("fail_no_name_type", func(t *testing.T) { 150 | builder, ok := newBuilder(t) 151 | if !ok { 152 | return 153 | } 154 | if err := runTest(genFile, builder); err == nil { 155 | t.Errorf("build: expects errors, got nil") 156 | return 157 | } else if !strings.Contains(err.Error(), 158 | "should contain 'Name' and 'Type' both") { 159 | t.Errorf("build: expects NoNameType error, got => %s", err) 160 | return 161 | } 162 | }) 163 | t.Run("fail_no_type_decl", func(t *testing.T) { 164 | builder, ok := newBuilder(t) 165 | if !ok { 166 | return 167 | } 168 | if err := runTest(genFile, builder); err == nil { 169 | t.Errorf("build: expects errors, got nil") 170 | return 171 | } else if !strings.Contains(err.Error(), 172 | "no available 'Interface' type declaration (*ast.GenDecl) found, ") { 173 | t.Errorf("build: expects NoTypeDecl error, got => %s", err) 174 | return 175 | } 176 | }) 177 | t.Run("fail_no_iface_type", func(t *testing.T) { 178 | builder, ok := newBuilder(t) 179 | if !ok { 180 | return 181 | } 182 | if err := runTest(genFile, builder); err == nil { 183 | t.Errorf("build: expects errors, got nil") 184 | return 185 | } else if !strings.Contains(err.Error(), 186 | "no available 'Interface' type declaration (*ast.InterfaceType) found, ") { 187 | t.Errorf("build: expects NoIfaceType error, got => %s", err) 188 | return 189 | } 190 | }) 191 | } 192 | -------------------------------------------------------------------------------- /gen/testdata/api/test.go: -------------------------------------------------------------------------------- 1 | //go:build !no_test 2 | // +build !no_test 3 | 4 | package test 5 | 6 | import "C" 7 | import ( 8 | "context" 9 | "net/http" 10 | 11 | gofmt "fmt" 12 | defc "github.com/x5iu/defc/runtime" 13 | 14 | _ "unsafe" 15 | ) 16 | 17 | //go:generate defc [mode] [output] [features...] 
TestBuildApi/success 18 | type Success[I any, R interface { 19 | Err() error 20 | ScanValues(...any) error 21 | FromBytes(string, []byte) error 22 | FromResponse(string, *http.Response) error 23 | Break() bool 24 | }] interface { 25 | Inner() I 26 | Response() Generic[I, R] 27 | 28 | // Run POST https://localhost:port/path?{{ $.query }} 29 | /* 30 | - Content-Type: application/json; charset=utf-8 31 | - Authorization: Bearer XXXXXXXX-XXXX-XXXX-XXXX-XXXXXXXXXXXX 32 | 33 | {{ $.body }} 34 | */ 35 | Run(ctx context.Context, query, body gofmt.Stringer) (*struct{}, error) 36 | } 37 | 38 | //go:generate defc [mode] [output] [features...] TestBuildApi/success/no_generics 39 | type SuccessNoGenerics interface { 40 | Response() Generic[defc.Response, defc.FutureResponse] 41 | 42 | // Run GET MANY https://localhost:port/path?{{ $.query }} 43 | Run(ctx context.Context, query gofmt.Stringer) ([]struct{}, error) 44 | 45 | // Crawl GET https://localhost:port/path?{{ $.query }} 46 | Crawl(ctx context.Context, query gofmt.Stringer) ([]struct{}, error) 47 | } 48 | 49 | //go:generate defc [mode] [output] [features...] TestBuildApi/fail_no_response 50 | type FailNoResponse interface { 51 | // Run POST https://localhost:port/path?{{ $.query }} 52 | /* 53 | - Content-Type: application/json; charset=utf-8 54 | - Authorization: Bearer XXXXXXXX-XXXX-XXXX-XXXX-XXXXXXXXXXXX 55 | 56 | {{ $.body }} 57 | */ 58 | Run(ctx context.Context, query, body gofmt.Stringer) error 59 | } 60 | 61 | //go:generate defc [mode] [output] [features...] TestBuildApi/fail_no_error 62 | type FailNoError[I any, R defc.Response] interface { 63 | Inner() I 64 | Response() R 65 | 66 | // Run POST https://localhost:port/path?{{ $.query }} 67 | /* 68 | - Content-Type: application/json; charset=utf-8 69 | - Authorization: Bearer XXXXXXXX-XXXX-XXXX-XXXX-XXXXXXXXXXXX 70 | 71 | {{ $.body }} 72 | */ 73 | Run(ctx context.Context, query, body gofmt.Stringer) 74 | } 75 | 76 | //go:generate defc [mode] [output] [features...] 
TestBuildApi/fail_no_name_type 77 | type FailNoNameType[I any, R defc.Response] interface { 78 | Inner() I 79 | Response() R 80 | 81 | // Run POST https://localhost:port/path?{{ $.query }} 82 | /* 83 | - Content-Type: application/json; charset=utf-8 84 | - Authorization: Bearer XXXXXXXX-XXXX-XXXX-XXXX-XXXXXXXXXXXX 85 | 86 | {{ $.body }} 87 | */ 88 | Run(context.Context, gofmt.Stringer) error 89 | } 90 | 91 | //go:generate defc [mode] [output] [features...] TestBuildApi/fail_invalid_IR/I 92 | type FailInvalidI[I any, R defc.Response] interface { 93 | Inner(_ struct{}) I 94 | Response() R 95 | } 96 | 97 | //go:generate defc [mode] [output] [features...] TestBuildApi/fail_invalid_IR/R 98 | type FailInvalidR[I any, R defc.Response] interface { 99 | Inner() I 100 | Response(_ struct{}) R 101 | } 102 | 103 | //go:generate defc [mode] [output] [features...] TestBuildApi/fail_invalid_IR/R/type 104 | type FailInvalidRType[I any, R defc.Response] interface { 105 | Inner() I 106 | Response() struct{} 107 | } 108 | 109 | //go:generate defc [mode] [output] [features...] TestBuildApi/fail_no_type_decl 110 | var FailNoTypeDecl struct{} 111 | 112 | //go:generate defc [mode] [output] [features...] 
TestBuildApi/fail_no_iface_type 113 | type FailNoIfaceType struct{} 114 | 115 | type Generic[T any, U any] struct{} 116 | -------------------------------------------------------------------------------- /gen/testdata/cycle/a/test.go: -------------------------------------------------------------------------------- 1 | package a 2 | 3 | import "github.com/x5iu/defc/gen/testdata/cycle/b" 4 | 5 | var _ b.Type 6 | 7 | type Type struct{} 8 | -------------------------------------------------------------------------------- /gen/testdata/cycle/b/test.go: -------------------------------------------------------------------------------- 1 | package b 2 | 3 | import "github.com/x5iu/defc/gen/testdata/cycle/a" 4 | 5 | var _ a.Type 6 | 7 | type Type struct{} 8 | -------------------------------------------------------------------------------- /gen/testdata/sqlx/test.go: -------------------------------------------------------------------------------- 1 | //go:build !no_test 2 | // +build !no_test 3 | 4 | package test 5 | 6 | import "C" 7 | import ( 8 | "context" 9 | "database/sql" 10 | 11 | gofmt "fmt" 12 | 13 | _ "unsafe" 14 | ) 15 | 16 | //go:generate defc [mode] [output] [features...] TestBuildSqlx/success 17 | type Success interface { 18 | gofmt.GoStringer 19 | WithTx(context.Context, func(tx Success) error) error 20 | 21 | // Run exec bind 22 | // #include "test.sql" 23 | /* 24 | #script cat 25 | "test.sql" 26 | */ 27 | // {{ $.query }} 28 | Run(ctx context.Context, query gofmt.Stringer) error 29 | 30 | // C query one bind 31 | // SELECT * FROM C WHERE type = {{ bind $.c }}; 32 | C(c *C.char) (struct{}, error) 33 | } 34 | 35 | //go:generate defc [mode] [output] [features...] 
TestBuildSqlx/success_named_tx 36 | type SuccessNamedTx interface { 37 | gofmt.GoStringer 38 | WithTx(ctx context.Context, f func(tx Success) error) error 39 | 40 | // Run exec bind 41 | // #include "test.sql" 42 | /* 43 | #script cat 44 | "test.sql" 45 | */ 46 | // {{ $.query }} 47 | Run(ctx context.Context, query gofmt.Stringer) error 48 | 49 | // C query one bind 50 | // SELECT * FROM C WHERE type = {{ bind $.c }}; 51 | C(c *C.char) (struct{}, error) 52 | } 53 | 54 | //go:generate defc [mode] [output] [features...] TestBuildSqlx/fail_no_error 55 | type FailNoError interface { 56 | // Run exec bind 57 | // {{ $.query }} 58 | Run(ctx context.Context, query gofmt.Stringer) sql.Result 59 | } 60 | 61 | //go:generate defc [mode] [output] [features...] TestBuildSqlx/fail_single_scan 62 | type FailSingleScan interface { 63 | // Run exec bind scan(obj) 64 | // {{ $.query }} 65 | Run(ctx context.Context, obj any, query gofmt.Stringer) (sql.Result, error) 66 | } 67 | 68 | //go:generate defc [mode] [output] [features...] TestBuildSqlx/fail_2_values 69 | type Fail2Values interface { 70 | // Run exec bind 71 | // {{ $.query }} 72 | Run(ctx context.Context, query gofmt.Stringer) (sql.Result, struct{}, error) 73 | } 74 | 75 | //go:generate defc [mode] [output] [features...] TestBuildSqlx/fail_no_name_type 76 | type FailNoNameType interface { 77 | // Run exec bind 78 | // {{ $.query }} 79 | Run(context.Context, gofmt.Stringer) (sql.Result, error) 80 | } 81 | 82 | //go:generate defc [mode] [output] [features...] TestBuildSqlx/fail_no_type_decl 83 | var FailNoTypeDecl struct{} 84 | 85 | //go:generate defc [mode] [output] [features...] 
TestBuildSqlx/fail_no_iface_type 86 | type FailNoIfaceType struct{} 87 | -------------------------------------------------------------------------------- /gen/testdata/sqlx/test.sql: -------------------------------------------------------------------------------- 1 | SELECT CURRENT_TIMESTAMP; -------------------------------------------------------------------------------- /gen/tools.go: -------------------------------------------------------------------------------- 1 | package gen 2 | 3 | import ( 4 | "bufio" 5 | "bytes" 6 | "errors" 7 | "fmt" 8 | "go/ast" 9 | "go/build" 10 | "go/importer" 11 | "go/parser" 12 | "go/token" 13 | "go/types" 14 | "net/http" 15 | "os" 16 | "os/exec" 17 | "path/filepath" 18 | "strconv" 19 | "strings" 20 | ) 21 | 22 | func assert(expr bool, msg string) { 23 | if !expr { 24 | panic(msg) 25 | } 26 | } 27 | 28 | const ( 29 | ExprErrorIdent = "error" 30 | ExprContextIdent = "Context" 31 | ) 32 | 33 | var ( 34 | sprintf = fmt.Sprintf 35 | errorf = fmt.Errorf 36 | quote = strconv.Quote 37 | trimPrefix = strings.TrimPrefix 38 | trimSuffix = strings.TrimSuffix 39 | trimSpace = strings.TrimSpace 40 | hasPrefix = strings.HasPrefix 41 | hasSuffix = strings.HasSuffix 42 | split = strings.Split 43 | concat = strings.Join 44 | toUpper = strings.ToUpper 45 | index = strings.Index 46 | cut = strings.Cut 47 | contains = strings.Contains 48 | join = filepath.Join 49 | isAbs = filepath.IsAbs 50 | glob = filepath.Glob 51 | base = filepath.Base 52 | stat = os.Stat 53 | read = os.ReadFile 54 | list = os.ReadDir 55 | ) 56 | 57 | func getPosRepr(src []byte, pos, end token.Pos) string { 58 | return string(src[pos-1 : end-1]) 59 | } 60 | 61 | func getRepr(node ast.Node, src []byte) string { 62 | return getPosRepr(src, node.Pos(), node.End()) 63 | } 64 | 65 | func surroundLine(fset *token.FileSet, node ast.Node, line int) bool { 66 | pos, end := fset.Position(node.Pos()), fset.Position(node.End()) 67 | return pos.Line <= line && end.Line >= line 68 | } 69 | 70 | 
func afterLine(fset *token.FileSet, node ast.Node, line int) bool { 71 | _, end := fset.Position(node.Pos()), fset.Position(node.End()) 72 | return end.Line >= line 73 | } 74 | 75 | func indirect(node ast.Node) ast.Node { 76 | if ptr, ok := getNode(node).(*ast.StarExpr); ok { 77 | // hack for compatibility 78 | return &Expr{ 79 | Expr: ptr.X, 80 | // `ptr` may come from parser.ParseExpr, so we calculate offset by field offsets 81 | Offset: int(node.Pos() + (ptr.X.Pos() - ptr.Star) - 1), 82 | } 83 | } 84 | return node 85 | } 86 | 87 | func deselect(node ast.Node) ast.Node { 88 | if sel, ok := getNode(node).(*ast.SelectorExpr); ok { 89 | // hack for compatibility 90 | return &Expr{ 91 | Expr: sel.Sel, 92 | // `sel` may come from parser.ParseExpr, so we calculate offset by field offsets 93 | Offset: int(node.Pos() + (sel.Sel.Pos() - sel.X.Pos()) - 1), 94 | } 95 | } 96 | return node 97 | } 98 | 99 | func isPointer(node ast.Node) bool { 100 | node = getNode(node) 101 | _, ok := node.(*ast.StarExpr) 102 | return ok 103 | } 104 | 105 | func isSlice(node ast.Node) bool { 106 | node = getNode(node) 107 | typ, ok := node.(*ast.ArrayType) 108 | if !ok { 109 | return false 110 | } 111 | 112 | // []byte is a special slice, equivalent to type string 113 | eltIsByte := false 114 | if elt, ok := typ.Elt.(*ast.Ident); ok { 115 | eltIsByte = elt.Name == "byte" 116 | } 117 | 118 | return typ.Len == nil && !eltIsByte 119 | } 120 | 121 | func checkInput(method *ast.FuncType) bool { 122 | for _, param := range method.Params.List { 123 | if len(param.Names) == 0 { 124 | return false 125 | } 126 | } 127 | return true 128 | } 129 | 130 | func checkErrorType(node ast.Node) bool { 131 | node = getNode(node) 132 | ident, ok := node.(*ast.Ident) 133 | return ok && ident.Name == ExprErrorIdent 134 | } 135 | 136 | func isContextType(ident string, expr ast.Expr, src []byte) bool { 137 | return ident == "ctx" || contains(getRepr(expr, src), ExprContextIdent) 138 | } 139 | 140 | func typeMap[T any, 
U any](src []T, f func(T) U) []U { 141 | dst := make([]U, len(src)) 142 | for i := 0; i < len(dst); i++ { 143 | dst[i] = f(src[i]) 144 | } 145 | return dst 146 | } 147 | 148 | func nodeMap[T ast.Node, U any](src []T, f func(ast.Node) U) []U { 149 | dst := make([]U, len(src)) 150 | for i := 0; i < len(dst); i++ { 151 | dst[i] = f(src[i]) 152 | } 153 | return dst 154 | } 155 | 156 | func fmtNode(node ast.Node) string { 157 | if stringer, ok := node.(fmt.Stringer); ok { 158 | return stringer.String() 159 | } 160 | return fmt.Sprintf("%#v", node) 161 | } 162 | 163 | func splitArgs(line string) (args []string) { 164 | line = trimSpace(line) 165 | if len(line) == 0 { 166 | return nil 167 | } 168 | 169 | var ( 170 | parenthesisStack int 171 | curlyBraceStack int 172 | doubleQuoted bool 173 | singleQuoted bool 174 | backQuoted bool 175 | arg []byte 176 | ) 177 | 178 | for i := 0; i < len(line); i++ { 179 | switch ch := line[i]; ch { 180 | case ' ', '\t', '\n', '\r': 181 | if doubleQuoted || singleQuoted || backQuoted || 182 | parenthesisStack > 0 || curlyBraceStack > 0 { 183 | arg = append(arg, ch) 184 | } else if len(arg) > 0 { 185 | args = append(args, string(arg)) 186 | arg = arg[:0] 187 | } 188 | case '"': 189 | if (i > 0 && line[i-1] == '\\') || singleQuoted || backQuoted { 190 | arg = append(arg, ch) 191 | } else { 192 | doubleQuoted = !doubleQuoted 193 | arg = append(arg, ch) 194 | } 195 | case '\'': 196 | if (i > 0 && line[i-1] == '\\') || doubleQuoted || backQuoted { 197 | arg = append(arg, ch) 198 | } else { 199 | singleQuoted = !singleQuoted 200 | arg = append(arg, ch) 201 | } 202 | case '`': 203 | if (i > 0 && line[i-1] == '\\') || doubleQuoted || singleQuoted { 204 | arg = append(arg, ch) 205 | } else { 206 | backQuoted = !backQuoted 207 | arg = append(arg, ch) 208 | } 209 | case '(': 210 | if !(doubleQuoted || singleQuoted || backQuoted) { 211 | parenthesisStack++ 212 | } 213 | arg = append(arg, ch) 214 | case ')': 215 | if !(doubleQuoted || singleQuoted || 
backQuoted) { 216 | parenthesisStack-- 217 | } 218 | arg = append(arg, ch) 219 | case '{': 220 | if !(doubleQuoted || singleQuoted || backQuoted) { 221 | curlyBraceStack++ 222 | } 223 | arg = append(arg, ch) 224 | case '}': 225 | if !(doubleQuoted || singleQuoted || backQuoted) { 226 | curlyBraceStack-- 227 | } 228 | arg = append(arg, ch) 229 | default: 230 | arg = append(arg, ch) 231 | } 232 | } 233 | 234 | if len(arg) > 0 { 235 | args = append(args, string(arg)) 236 | } 237 | 238 | return args 239 | } 240 | 241 | func trimSlash(comment string) string { 242 | if hasPrefix(comment, "//") { 243 | comment = trimPrefix(comment, "//") 244 | } else if hasPrefix(comment, "/*") { 245 | comment = trimPrefix(comment, "/*") 246 | if hasSuffix(comment, "*/") { 247 | comment = trimSuffix(comment, "*/") 248 | } 249 | } 250 | return trimSpace(comment) 251 | } 252 | 253 | func in[T comparable](list []T, item T) bool { 254 | for _, ele := range list { 255 | if ele == item { 256 | return true 257 | } 258 | } 259 | return false 260 | } 261 | 262 | func parseImport(imp string) string { 263 | elements := splitArgs(imp) 264 | if len(elements) == 1 { 265 | pkg := elements[0] 266 | if hasPrefix(pkg, "\"") && hasSuffix(pkg, "\"") { 267 | return pkg 268 | } 269 | return quote(pkg) 270 | } else { 271 | alias, pkg := elements[0], elements[1] 272 | if hasPrefix(pkg, "\"") && hasSuffix(pkg, "\"") { 273 | return alias + " " + pkg 274 | } 275 | return alias + " " + quote(pkg) 276 | } 277 | } 278 | 279 | var seps = []rune{ 280 | '=', 281 | ':', 282 | } 283 | 284 | func cutkv(kv string) (string, string, bool) { 285 | for _, ch := range kv { 286 | if in(seps, ch) { 287 | k, v, ok := cut(kv, string(ch)) 288 | if !ok { 289 | return kv, "", false 290 | } 291 | return trimSpace(k), trimSpace(v), true 292 | } 293 | } 294 | return kv, "", false 295 | } 296 | 297 | func getIdent(s string) string { 298 | if i := index(s, " "); i >= 0 { 299 | return s[:i] 300 | } 301 | return s 302 | } 303 | 304 | func 
parseExpr(input string) (expr ast.Expr, err error) { 305 | return parser.ParseExpr(input) 306 | } 307 | 308 | func getNode(node ast.Node) ast.Node { 309 | // NOTE: compatible with `defc generate` command 310 | for { 311 | if wrapper, ok := node.(interface{ Unwrap() ast.Node }); ok { 312 | node = wrapper.Unwrap() 313 | continue 314 | } 315 | break 316 | } 317 | return node 318 | } 319 | 320 | const ( 321 | addBuild = "+build" 322 | goBuild = "go:build" 323 | ) 324 | 325 | // parseBuildTags uses source []byte instead of ast.CommentGroup to parse build tags, 326 | // since parser.ParseFile removes commands like "//go:build" or "//go:generate", we 327 | // can't get build tags from ast.CommentGroup. 328 | func parseBuildTags(src []byte) (tags []string) { 329 | scanner := bufio.NewScanner(bytes.NewReader(src)) 330 | for scanner.Scan() { 331 | text := trimSlash(scanner.Text()) 332 | if hasPrefix(text, addBuild) || hasPrefix(text, goBuild) { 333 | tags = append(tags, text) 334 | } 335 | } 336 | return tags 337 | } 338 | 339 | func unquote(quoted string) (unquoted string) { 340 | if len(quoted) == 0 { 341 | return "" 342 | } 343 | if (hasPrefix(quoted, "\"") && hasSuffix(quoted, "\"")) || 344 | (hasPrefix(quoted, "'") && hasSuffix(quoted, "'")) || 345 | isBackQuoted(quoted) { 346 | return quoted[1 : len(quoted)-1] 347 | } 348 | return quoted 349 | } 350 | 351 | func isBackQuoted(s string) bool { 352 | return hasPrefix(s, "`") && hasSuffix(s, "`") 353 | } 354 | 355 | func runCommand(args []string) (string, error) { 356 | assert(len(args) > 0, "empty command") 357 | repl := make([]string, 0, len(args)) 358 | for i := 0; i < len(args); i++ { 359 | arg := args[i] 360 | if isBackQuoted(arg) { 361 | if innerArgs := splitArgs(unquote(arg)); len(innerArgs) > 0 { 362 | innerOutput, err := runCommand(innerArgs) 363 | if err != nil { 364 | return "", err 365 | } 366 | repl = append(repl, innerOutput) 367 | } 368 | } else if (hasPrefix(arg, "$(") && hasSuffix(arg, ")")) || 369 | 
(hasPrefix(arg, "${") && hasSuffix(arg, "}")) { 370 | if innerArgs := splitArgs(arg[2 : len(arg)-1]); len(innerArgs) > 0 { 371 | innerOutput, err := runCommand(innerArgs) 372 | if err != nil { 373 | return "", err 374 | } 375 | repl = append(repl, innerOutput) 376 | } 377 | } else { 378 | repl = append(repl, unquote(arg)) 379 | } 380 | } 381 | if len(repl) == 0 { 382 | return "", nil 383 | } 384 | var output bytes.Buffer 385 | command := exec.Command(repl[0], repl[1:]...) 386 | command.Stdout = &output 387 | command.Stderr = os.Stderr 388 | if err := command.Run(); err != nil { 389 | return "", err 390 | } 391 | return trimSpace(output.String()), nil 392 | } 393 | 394 | type Import struct { 395 | Name string 396 | Path string 397 | } 398 | 399 | func getImports(pkg string, dir string, name string, isIt func(ast.Node) bool) (imports []*Import, err error) { 400 | imports = make([]*Import, 0, 8) 401 | 402 | var ( 403 | fset = token.NewFileSet() 404 | files = make([]*ast.File, 0, 8) 405 | target *ast.File 406 | ) 407 | 408 | filenames, err := glob(join(dir, "*.go")) 409 | if err != nil { 410 | return nil, err 411 | } 412 | 413 | for _, filename := range filenames { 414 | file, err := parser.ParseFile(fset, filename, nil, 0) 415 | if err != nil { 416 | return nil, err 417 | } 418 | files = append(files, file) 419 | if base(name) == base(filename) { 420 | target = file 421 | } 422 | } 423 | 424 | conf := types.Config{ 425 | IgnoreFuncBodies: true, 426 | FakeImportC: true, 427 | Importer: &Importer{ 428 | imported: map[string]*types.Package{}, 429 | tokenFileSet: fset, 430 | defaultImport: importer.Default(), 431 | }, 432 | } 433 | 434 | info := &types.Info{ 435 | Types: map[ast.Expr]types.TypeAndValue{}, 436 | Instances: map[*ast.Ident]types.Instance{}, 437 | Defs: map[*ast.Ident]types.Object{}, 438 | Uses: map[*ast.Ident]types.Object{}, 439 | Implicits: map[ast.Node]types.Object{}, 440 | Selections: map[*ast.SelectorExpr]*types.Selection{}, 441 | Scopes: 
map[ast.Node]*types.Scope{}, 442 | InitOrder: []*types.Initializer{}, 443 | } 444 | 445 | _, err = conf.Check(pkg, fset, files, info) 446 | if err != nil { 447 | return nil, err 448 | } 449 | 450 | if target != nil { 451 | ast.Inspect(target, func(x ast.Node) bool { 452 | if x != nil && isIt(x) { 453 | ast.Inspect(x, func(n ast.Node) bool { 454 | switch node := n.(type) { 455 | case ast.Expr: 456 | if named, ok := info.TypeOf(node).(*types.Named); ok { 457 | if objPkg := named.Obj().Pkg(); objPkg != nil { 458 | imports = append(imports, &Import{ 459 | Name: objPkg.Name(), 460 | Path: objPkg.Path(), 461 | }) 462 | } 463 | } 464 | } 465 | return true 466 | }) 467 | } 468 | return true 469 | }) 470 | } 471 | 472 | return imports, nil 473 | } 474 | 475 | type Importer struct { 476 | imported map[string]*types.Package 477 | tokenFileSet *token.FileSet 478 | defaultImport types.Importer 479 | } 480 | 481 | var importing types.Package 482 | 483 | func (importer *Importer) ImportFrom(path, dir string, _ types.ImportMode) (*types.Package, error) { 484 | if path == "unsafe" { 485 | return types.Unsafe, nil 486 | } 487 | if path == "C" { 488 | return nil, errorf("unreachable: %s", "import \"C\"") 489 | } 490 | goroot := join(build.Default.GOROOT, "src") 491 | if _, err := stat(join(goroot, path)); err != nil { 492 | if os.IsNotExist(err) { 493 | target := importer.imported[path] 494 | if target != nil { 495 | if target == &importing { 496 | return nil, errors.New("cycle importing " + path) 497 | } 498 | return target, nil 499 | } 500 | importer.imported[path] = &importing 501 | pkg, err := build.Import(path, dir, 0) 502 | if err != nil { 503 | return nil, err 504 | } 505 | var files []*ast.File 506 | for _, name := range append(pkg.GoFiles, pkg.CgoFiles...) 
{ 507 | name = join(pkg.Dir, name) 508 | file, err := parser.ParseFile(importer.tokenFileSet, name, nil, 0) 509 | if err != nil { 510 | return nil, err 511 | } 512 | files = append(files, file) 513 | } 514 | conf := types.Config{ 515 | Importer: importer, 516 | FakeImportC: true, 517 | IgnoreFuncBodies: true, 518 | } 519 | target, err = conf.Check(path, importer.tokenFileSet, files, nil) 520 | if err != nil { 521 | return nil, err 522 | } 523 | importer.imported[path] = target 524 | return target, nil 525 | } 526 | } 527 | if importerFrom, ok := importer.defaultImport.(types.ImporterFrom); ok { 528 | return importerFrom.ImportFrom(path, dir, 0) 529 | } 530 | return importer.defaultImport.Import(path) 531 | } 532 | 533 | func (importer *Importer) Import(path string) (*types.Package, error) { 534 | return importer.ImportFrom(path, "", 0) 535 | } 536 | 537 | var ErrNoTargetDeclFound = errors.New("no target decl found") 538 | 539 | func DetectTargetDecl(file string, src []byte, target string) (string, Mode, int, error) { 540 | fset := token.NewFileSet() 541 | f, err := parser.ParseFile(fset, file, src, parser.ParseComments) 542 | if err != nil { 543 | return "", 0, 0, err 544 | } 545 | for _, decl := range f.Decls { 546 | if genDecl, ok := decl.(*ast.GenDecl); ok && genDecl.Tok == token.TYPE { 547 | for _, spec := range genDecl.Specs { 548 | if typeSpec, ok := spec.(*ast.TypeSpec); ok { 549 | if target != "" && typeSpec.Name.String() != target { 550 | continue 551 | } 552 | if ifaceType, ok := typeSpec.Type.(*ast.InterfaceType); ok && ifaceType.Methods != nil { 553 | for _, field := range ifaceType.Methods.List { 554 | if _, ok := field.Type.(*ast.FuncType); ok { 555 | if field.Doc != nil && len(field.Doc.List) > 0 { 556 | firstLine := field.Doc.List[0] 557 | firstLineArgs := splitArgs(trimSlash(firstLine.Text)) 558 | if len(firstLineArgs) > 1 { 559 | switch opArg := firstLineArgs[1]; toUpper(opArg) { 560 | case sqlxOpExec, sqlxOpQuery: 561 | return f.Name.String(), 
ModeSqlx, fset.Position(typeSpec.Pos()).Line - 1, nil 562 | case http.MethodGet, 563 | http.MethodHead, 564 | http.MethodPost, 565 | http.MethodPut, 566 | http.MethodPatch, 567 | http.MethodDelete, 568 | http.MethodConnect, 569 | http.MethodOptions, 570 | http.MethodTrace: 571 | return f.Name.String(), ModeApi, fset.Position(typeSpec.Pos()).Line - 1, nil 572 | } 573 | } 574 | } 575 | if len(field.Names) > 0 { 576 | if funcName := field.Names[0]; funcName.String() == sqlxMethodWithTx { 577 | return f.Name.String(), ModeSqlx, fset.Position(typeSpec.Pos()).Line - 1, nil 578 | } else if funcNameString := funcName.String(); isInner(funcNameString) || isResponse(funcNameString) { 579 | return f.Name.String(), ModeApi, fset.Position(typeSpec.Pos()).Line - 1, nil 580 | } 581 | } 582 | } 583 | } 584 | } 585 | } 586 | } 587 | } 588 | } 589 | return "", 0, 0, ErrNoTargetDeclFound 590 | } 591 | -------------------------------------------------------------------------------- /gen/tools_test.go: -------------------------------------------------------------------------------- 1 | package gen 2 | 3 | import ( 4 | "crypto/rand" 5 | "encoding/hex" 6 | "errors" 7 | "fmt" 8 | "go/ast" 9 | "go/importer" 10 | "go/token" 11 | "go/types" 12 | "path/filepath" 13 | "reflect" 14 | "strings" 15 | "testing" 16 | "unsafe" 17 | ) 18 | 19 | func randStr() string { 20 | b := make([]byte, 8) 21 | _, err := rand.Read(b) 22 | if err != nil { 23 | panic(fmt.Errorf("randStr: %w", err)) 24 | } 25 | return hex.EncodeToString(b) 26 | } 27 | 28 | func TestRunCommand(t *testing.T) { 29 | t.Run("backquoted", func(t *testing.T) { 30 | commandOutput, err := runCommand([]string{ 31 | "echo", 32 | "`echo test`", 33 | }) 34 | if err != nil { 35 | t.Errorf("runCommand: %s", err) 36 | return 37 | } 38 | if commandOutput != "test" { 39 | t.Errorf("runCommand: %q != %q", commandOutput, "test") 40 | return 41 | } 42 | t.Run("error", func(t *testing.T) { 43 | commandOutput, err := runCommand([]string{ 44 | "echo", 45 | 
"`a_binary_name_that_will_never_appear_in_syspath_" + randStr() + "`", 46 | }) 47 | if err == nil || commandOutput != "" { 48 | t.Errorf("runCommand: expects errors, got nil") 49 | return 50 | } else if !strings.HasPrefix(err.Error(), "exec: ") || 51 | !strings.Contains(err.Error(), "executable file not found in $PATH") { 52 | t.Errorf("runCommand: expects NotFoundError, got => %s", err) 53 | return 54 | } 55 | }) 56 | }) 57 | t.Run("paren", func(t *testing.T) { 58 | commandOutput, err := runCommand([]string{ 59 | "echo", 60 | "$(echo test)", 61 | }) 62 | if err != nil { 63 | t.Errorf("runCommand: %s", err) 64 | return 65 | } 66 | if commandOutput != "test" { 67 | t.Errorf("runCommand: %q != %q", commandOutput, "test") 68 | return 69 | } 70 | t.Run("error", func(t *testing.T) { 71 | commandOutput, err := runCommand([]string{ 72 | "echo", 73 | "$(a_binary_name_that_will_never_appear_in_syspath_" + randStr() + ")", 74 | }) 75 | if err == nil || commandOutput != "" { 76 | t.Errorf("runCommand: expects errors, got nil") 77 | return 78 | } else if !strings.HasPrefix(err.Error(), "exec: ") || 79 | !strings.Contains(err.Error(), "executable file not found in $PATH") { 80 | t.Errorf("runCommand: expects NotFoundError, got => %s", err) 81 | return 82 | } 83 | }) 84 | }) 85 | t.Run("braces", func(t *testing.T) { 86 | commandOutput, err := runCommand([]string{ 87 | "echo", 88 | "${echo test}", 89 | }) 90 | if err != nil { 91 | t.Errorf("runCommand: %s", err) 92 | return 93 | } 94 | if commandOutput != "test" { 95 | t.Errorf("runCommand: %q != %q", commandOutput, "test") 96 | return 97 | } 98 | t.Run("error", func(t *testing.T) { 99 | commandOutput, err := runCommand([]string{ 100 | "echo", 101 | "${a_binary_name_that_will_never_appear_in_syspath_" + randStr() + "}", 102 | }) 103 | if err == nil || commandOutput != "" { 104 | t.Errorf("runCommand: expects errors, got nil") 105 | return 106 | } else if !strings.HasPrefix(err.Error(), "exec: ") || 107 | 
!strings.Contains(err.Error(), "executable file not found in $PATH") { 108 | t.Errorf("runCommand: expects NotFoundError, got => %s", err) 109 | return 110 | } 111 | }) 112 | }) 113 | t.Run("nested", func(t *testing.T) { 114 | commandOutput, err := runCommand([]string{ 115 | "echo", 116 | "${echo $(echo `echo \"test\"`)}", 117 | }) 118 | if err != nil { 119 | t.Errorf("runCommand: %s", err) 120 | return 121 | } 122 | if commandOutput != "test" { 123 | t.Errorf("runCommand: %q != %q", commandOutput, "test") 124 | return 125 | } 126 | t.Run("error", func(t *testing.T) { 127 | commandOutput, err := runCommand([]string{ 128 | "echo", 129 | "${$(`a_binary_name_that_will_never_appear_in_syspath_" + randStr() + "`)}", 130 | }) 131 | if err == nil || commandOutput != "" { 132 | t.Errorf("runCommand: expects errors, got nil") 133 | return 134 | } else if !strings.HasPrefix(err.Error(), "exec: ") || 135 | !strings.Contains(err.Error(), "executable file not found in $PATH") { 136 | t.Errorf("runCommand: expects NotFoundError, got => %s", err) 137 | return 138 | } 139 | }) 140 | }) 141 | t.Run("empty", func(t *testing.T) { 142 | commandOutput, err := runCommand([]string{ 143 | "${}", 144 | }) 145 | if err != nil { 146 | t.Errorf("runCommand: %s", err) 147 | return 148 | } 149 | if commandOutput != "" { 150 | t.Errorf("runCommand: expects empty output, got %q", commandOutput) 151 | return 152 | } 153 | }) 154 | } 155 | 156 | func TestSplitArgs(t *testing.T) { 157 | type TestCase struct { 158 | Name string 159 | Data string 160 | Expect []string 161 | } 162 | var testcases = []*TestCase{ 163 | { 164 | Name: "single_quote", 165 | Data: "'test' \\'", 166 | Expect: []string{"'test'", "\\'"}, 167 | }, 168 | { 169 | Name: "back_quote", 170 | Data: "`test` \\`", 171 | Expect: []string{"`test`", "\\`"}, 172 | }, 173 | } 174 | for _, testcase := range testcases { 175 | t.Run(testcase.Name, func(t *testing.T) { 176 | if args := splitArgs(testcase.Data); !reflect.DeepEqual(args, 
testcase.Expect) { 177 | t.Errorf("split: %v != %v", args, testcase.Expect) 178 | return 179 | } 180 | }) 181 | } 182 | } 183 | 184 | func TestGetIdent(t *testing.T) { 185 | type TestCase struct { 186 | Name string 187 | Data string 188 | Expect string 189 | } 190 | var testcases = []*TestCase{ 191 | { 192 | Name: "with_space", 193 | Data: "test", 194 | Expect: "test", 195 | }, 196 | { 197 | Name: "without_space", 198 | Data: "test 0327", 199 | Expect: "test", 200 | }, 201 | } 202 | for _, testcase := range testcases { 203 | t.Run(testcase.Name, func(t *testing.T) { 204 | if ident := getIdent(testcase.Data); ident != testcase.Expect { 205 | t.Errorf("ident: %q != %q", ident, testcase.Expect) 206 | return 207 | } 208 | }) 209 | } 210 | } 211 | 212 | func TestParseExpr(t *testing.T) { 213 | if expr, err := parseExpr("json.RawMessage"); err != nil { 214 | t.Errorf("expr: %s", err) 215 | return 216 | } else if _, ok := expr.(*ast.SelectorExpr); !ok { 217 | t.Errorf("expr: expects *ast.SelectorExpr, got %T", expr) 218 | return 219 | } 220 | } 221 | 222 | func TestImporter(t *testing.T) { 223 | testImporter := &Importer{ 224 | imported: map[string]*types.Package{}, 225 | tokenFileSet: token.NewFileSet(), 226 | defaultImport: importer.Default(), 227 | } 228 | t.Run("unsafe", func(t *testing.T) { 229 | if pkg, err := testImporter.Import("unsafe"); err != nil { 230 | t.Errorf("import: %s", err) 231 | return 232 | } else if pkg == nil { 233 | t.Errorf("import: expects non-nil *types.Package, got nil") 234 | return 235 | } 236 | }) 237 | t.Run("C", func(t *testing.T) { 238 | if pkg, err := testImporter.Import("C"); err == nil || pkg != nil { 239 | t.Errorf("import: expects errors, got nil") 240 | return 241 | } else if err.Error() != "unreachable: import \"C\"" { 242 | t.Errorf("import: expects unreachable error, got => %s", err) 243 | return 244 | } 245 | }) 246 | t.Run("twice", func(t *testing.T) { 247 | pkg1, err := testImporter.Import("github.com/x5iu/defc/runtime") 248 | 
if err != nil { 249 | t.Errorf("import: %s", err) 250 | return 251 | } else if pkg1 == nil { 252 | t.Errorf("import: expects non-nil *types.Package, got nil") 253 | return 254 | } 255 | pkg2, err := testImporter.Import("github.com/x5iu/defc/runtime") 256 | if err != nil { 257 | t.Errorf("import: %s", err) 258 | return 259 | } else if pkg2 == nil { 260 | t.Errorf("import: expects non-nil *types.Package, got nil") 261 | return 262 | } 263 | if uintptr(unsafe.Pointer(pkg1)) != uintptr(unsafe.Pointer(pkg2)) { 264 | t.Errorf("import: cache not effective") 265 | return 266 | } 267 | }) 268 | t.Run("cycle_import", func(t *testing.T) { 269 | for _, dir := range [][]string{ 270 | {"testdata", "cycle", "a"}, 271 | {"testdata", "cycle", "b"}, 272 | } { 273 | cycleImporter := &Importer{ 274 | imported: map[string]*types.Package{}, 275 | tokenFileSet: token.NewFileSet(), 276 | defaultImport: importer.Default(), 277 | } 278 | _, err := cycleImporter.ImportFrom( 279 | "github.com/x5iu/defc/gen/"+strings.Join(dir, "/"), 280 | filepath.Join(dir...), 281 | 0, 282 | ) 283 | if err == nil { 284 | t.Errorf("import: expects errors, got nil") 285 | return 286 | } else if !strings.Contains(err.Error(), "cycle importing ") { 287 | t.Errorf("import: expects CycleImporting error, got => %s", err) 288 | return 289 | } 290 | } 291 | }) 292 | } 293 | 294 | func TestDetectTargetDecl(t *testing.T) { 295 | var src = []byte(` 296 | package test 297 | 298 | type TestApi1[I any] interface { 299 | Inner() I 300 | } 301 | 302 | type TestApi2[I any] interface { 303 | // Test POST https://localhost:port/test 304 | Test(r any) error 305 | 306 | Inner() I 307 | } 308 | 309 | type TestSqlx1 interface { 310 | WithTx(func(TestSqlx) error) error 311 | } 312 | 313 | type TestSqlx2 interface { 314 | // Select Query Scan(obj) 315 | Select(obj any) error 316 | } 317 | `) 318 | pkg, mod, pos, err := DetectTargetDecl("test.go", src, "") 319 | if err != nil { 320 | t.Errorf("detect: %s", err) 321 | return 322 | } 
else if pkg != "test" || mod != ModeApi || pos != 3 { 323 | t.Errorf("detect: pkg = %q; mod = %q; pos = %d", pkg, mod, pos) 324 | return 325 | } 326 | pkg, mod, pos, err = DetectTargetDecl("test.go", src, "TestApi2") 327 | if err != nil { 328 | t.Errorf("detect: %s", err) 329 | return 330 | } else if pkg != "test" || mod != ModeApi || pos != 7 { 331 | t.Errorf("detect: pkg = %q; mod = %q; pos = %d", pkg, mod, pos) 332 | return 333 | } 334 | pkg, mod, pos, err = DetectTargetDecl("test.go", src, "TestSqlx1") 335 | if err != nil { 336 | t.Errorf("detect: %s", err) 337 | return 338 | } else if pkg != "test" || mod != ModeSqlx || pos != 14 { 339 | t.Errorf("detect: pkg = %q; mod = %q; pos = %d", pkg, mod, pos) 340 | return 341 | } 342 | pkg, mod, pos, err = DetectTargetDecl("test.go", src, "TestSqlx2") 343 | if err != nil { 344 | t.Errorf("detect: %s", err) 345 | return 346 | } else if pkg != "test" || mod != ModeSqlx || pos != 18 { 347 | t.Errorf("detect: pkg = %q; mod = %q; pos = %d", pkg, mod, pos) 348 | return 349 | } 350 | _, _, _, err = DetectTargetDecl("test.go", src, "Test") 351 | if err == nil { 352 | t.Errorf("detect: expects errors, got nil") 353 | return 354 | } else if !errors.Is(err, ErrNoTargetDeclFound) { 355 | t.Errorf("detect: expects ErrNoTargetDeclFound, got => %s", err) 356 | return 357 | } 358 | } 359 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/x5iu/defc 2 | 3 | go 1.19 4 | 5 | require ( 6 | github.com/BurntSushi/toml v1.4.0 7 | github.com/hashicorp/golang-lru/v2 v2.0.7 8 | github.com/spf13/cobra v1.8.1 9 | golang.org/x/tools v0.17.0 10 | gopkg.in/yaml.v3 v3.0.1 11 | ) 12 | 13 | require ( 14 | github.com/inconshreveable/mousetrap v1.1.0 // indirect 15 | github.com/spf13/pflag v1.0.5 // indirect 16 | golang.org/x/mod v0.14.0 // indirect 17 | ) 18 | 
-------------------------------------------------------------------------------- /legacy.go: -------------------------------------------------------------------------------- 1 | //go:build legacy 2 | // +build legacy 3 | 4 | package main 5 | 6 | func onInitialize() {} 7 | -------------------------------------------------------------------------------- /main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "bytes" 5 | "encoding/json" 6 | "errors" 7 | "fmt" 8 | "go/format" 9 | "os" 10 | "path/filepath" 11 | "strconv" 12 | "strings" 13 | 14 | "github.com/BurntSushi/toml" 15 | "github.com/spf13/cobra" 16 | goimport "golang.org/x/tools/imports" 17 | "gopkg.in/yaml.v3" 18 | 19 | "github.com/x5iu/defc/gen" 20 | runtime "github.com/x5iu/defc/runtime" 21 | ) 22 | 23 | const ( 24 | EnvPWD = "PWD" 25 | EnvGoPackage = "GOPACKAGE" 26 | EnvGoFile = "GOFILE" 27 | EnvGoLine = "GOLINE" 28 | ) 29 | 30 | var ( 31 | modeMap map[string]gen.Mode 32 | validModes []string 33 | validFeatures = []string{ 34 | gen.FeatureApiCache, 35 | gen.FeatureApiLog, 36 | gen.FeatureApiLogx, 37 | gen.FeatureApiClient, 38 | gen.FeatureApiPage, 39 | gen.FeatureApiError, 40 | gen.FeatureApiNoRt, 41 | gen.FeatureApiFuture, 42 | gen.FeatureApiIgnoreStatus, 43 | gen.FeatureApiGzip, 44 | gen.FeatureApiRetry, 45 | gen.FeatureSqlxIn, 46 | gen.FeatureSqlxLog, 47 | gen.FeatureSqlxRebind, 48 | gen.FeatureSqlxNoRt, 49 | gen.FeatureSqlxFuture, 50 | gen.FeatureSqlxCallback, 51 | gen.FeatureSqlxAnyCallback, 52 | } 53 | ) 54 | 55 | func init() { 56 | modeMap = make(map[string]gen.Mode, gen.ModeEnd-gen.ModeStart-1) 57 | validModes = make([]string, 0, gen.ModeEnd-gen.ModeStart-1) 58 | for m := gen.ModeStart + 1; m < gen.ModeEnd; m++ { 59 | modeMap[m.String()] = m 60 | validModes = append(validModes, m.String()) 61 | } 62 | cobra.OnInitialize(onInitialize) 63 | } 64 | 65 | var ( 66 | mode string 67 | output string 68 | features []string 69 | imports 
[]string 70 | disableAutoImport bool 71 | funcs []string 72 | targetType string 73 | template string 74 | ) 75 | 76 | var ( 77 | defc = &cobra.Command{ 78 | Use: "defc --mode MODE --output FILE [--features LIST] [--import PACKAGE]... [--func FUNCTION]...", 79 | Short: "By defining the Schema, use go generate to generate database CRUD or HTTP request code.", 80 | Long: `defc originates from the tedium of repetitively writing code for "create, read, update, delete" (CRUD) 81 | operations and "network interface integration" in our daily work and life. 82 | 83 | For example, for database queries, we often need to: 84 | 85 | 1. Define a new function or method; 86 | 2. Write a new SQL query; 87 | 3. Execute the query, handle errors, and map the results to a structure; 88 | 4. If there are multiple SQL statements, initiate a transaction, and perform commit or rollback; 89 | 5. Log the query; 90 | 6. ... 91 | 92 | Similarly, for network interface integration, for a new interface, we often: 93 | 94 | 1. Define a new function or method; 95 | 2. Set the interface URL, configure parameters (such as Headers, Query, Body in HTTP requests); 96 | 3. Make the request, handle errors, and map the response to a structure; 97 | 4. If it involves pagination, concatenate the results of multiple paginated queries into the final result; 98 | 5. Log the request; 99 | 6. ... 100 | 101 | All of the above are repeated several times when writing new requirements or scenarios. Especially the parts related 102 | to queries, requests, error handling, transaction commit/rollback, data mapping, list concatenation, and log recording, 103 | which are all logically identical repetitive codes. 
Writing them is very annoying; some codes are very long, and 104 | copying and pasting require various changes to variable names, method names, and configuration information, which 105 | greatly affects development efficiency; 106 | 107 | Unfortunately, the Go language does not provide official macro features, and we cannot use macros to complete these 108 | complex repetitive codes like Rust does (of course, macros also have their limitations; they are devastating to code 109 | readability when not expanded and also affect IDE completion). However, fortunately, Go provides a workaround with go 110 | generate. Through go generate, we can approximately provide macro functionality, that is, code generation capabilities. 111 | 112 | Based on the above background, I wanted to implement a code generation tool. By defining the Schema of a query or 113 | request, it is possible to automatically generate code for the related CRUD operations or HTTP requests, which includes 114 | parameter construction, error handling, result mapping, and log recording logic. defc is my experimental attempt at 115 | such a schema-to-code generation; "def" stands for "define", indicating the behavior of setting up a Schema. 
Currently, 116 | defc provides the following two scenarios of code generation features: 117 | 118 | * CRUD code generation based on sqlx for databases 119 | * HTTP interface request code generation based on the net/http package in the Golang standard library`, 120 | Version: runtime.Version, 121 | SilenceUsage: true, 122 | SilenceErrors: true, 123 | CompletionOptions: cobra.CompletionOptions{ 124 | DisableDefaultCmd: true, 125 | DisableNoDescFlag: true, 126 | DisableDescriptions: true, 127 | HiddenDefaultCmd: true, 128 | }, 129 | PersistentPreRunE: func(cmd *cobra.Command, args []string) (err error) { 130 | // parent == nil means root command 131 | if cmd.Parent() == nil { 132 | if cmd.Flags().NFlag() == 0 && len(args) == 0 { 133 | defer os.Exit(0) 134 | return cmd.Usage() 135 | } 136 | } 137 | return nil 138 | }, 139 | RunE: func(cmd *cobra.Command, args []string) (err error) { 140 | if err = checkFlags(); err != nil { 141 | return err 142 | } 143 | 144 | var ( 145 | pwd = os.Getenv(EnvPWD) 146 | file = os.Getenv(EnvGoFile) 147 | doc []byte 148 | pos int 149 | ) 150 | 151 | if pwd == "" { 152 | pwd, err = os.Getwd() 153 | if err != nil { 154 | return fmt.Errorf("get current working directory: %w", err) 155 | } 156 | } 157 | 158 | if !filepath.IsAbs(file) { 159 | file = filepath.Join(pwd, file) 160 | } 161 | 162 | if doc, err = os.ReadFile(file); err != nil { 163 | return fmt.Errorf("$GOFILE: os.ReadFile(%q): %w", file, err) 164 | } 165 | 166 | if pos, err = strconv.Atoi(os.Getenv(EnvGoLine)); err != nil { 167 | return fmt.Errorf("$GOLINE: strconv.Atoi(%s): %w", os.Getenv(EnvGoLine), err) 168 | } 169 | 170 | builder := gen.NewCliBuilder(modeMap[mode]). 171 | WithFeats(features). 172 | // Since we are using the golang.org/x/tools/imports 173 | // package to handle imports, there is no need to 174 | // use the auto-import feature. 175 | // 176 | // disableAutoImport = true 177 | WithImports(imports, true). 178 | WithFuncs(funcs). 
179 | WithPkg(os.Getenv(EnvGoPackage)). 180 | WithPwd(pwd). 181 | WithFile(file, doc). 182 | WithPos(pos) 183 | 184 | var buffer bytes.Buffer 185 | if err = builder.Build(&buffer); err != nil { 186 | return err 187 | } 188 | 189 | if !filepath.IsAbs(output) { 190 | output = filepath.Join(pwd, output) 191 | } 192 | return save(output, buffer.Bytes()) 193 | }, 194 | } 195 | 196 | generate = &cobra.Command{ 197 | Use: "generate FILE", 198 | Short: "Generate code from schema file", 199 | Long: `When the target file is a .go file, defc will analyze the file content, automatically determine the type 200 | representing the schema, and match the corresponding mode. This means you don't have to specify the corresponding mode 201 | using the '--mode/-m' parameter. You can also ignore the '--output' parameter, and defc will use the current file's name 202 | with a .gen suffix as the generated code file's name. This allows you to generate the corresponding code by only 203 | providing a filename without any flags. 
If your .go file contains multiple types that meet the criteria, you can also 204 | manually specify the type that defc should handle using the '--type/-T' parameter to avoid generating incorrect code.`, 205 | Args: cobra.MaximumNArgs(1), 206 | SilenceUsage: true, 207 | SilenceErrors: true, 208 | RunE: func(cmd *cobra.Command, args []string) (err error) { 209 | var file string 210 | if len(args) > 0 { 211 | file = args[0] 212 | } else if goFile := os.Getenv(EnvGoFile); goFile != "" { 213 | file = goFile 214 | } else { 215 | return fmt.Errorf("unable to retrieve schema file from the $GOFILE environment variable or positional arguments") 216 | } 217 | ext := filepath.Ext(file) 218 | if ext == ".go" { 219 | var ( 220 | pwd = os.Getenv(EnvPWD) 221 | doc []byte 222 | pos int 223 | pkg string 224 | mod gen.Mode 225 | out = output 226 | ) 227 | if pwd == "" { 228 | pwd, err = os.Getwd() 229 | if err != nil { 230 | return fmt.Errorf("get current working directory: %w", err) 231 | } 232 | } 233 | if !filepath.IsAbs(file) { 234 | file = filepath.Join(pwd, file) 235 | } 236 | if doc, err = os.ReadFile(file); err != nil { 237 | return fmt.Errorf("os.ReadFile(%q): %w", file, err) 238 | } 239 | var declNotFoundErr error 240 | pkg, mod, pos, declNotFoundErr = gen.DetectTargetDecl(file, doc, targetType) 241 | specifyManually := len(args) > 0 || targetType != "" 242 | if goLine := os.Getenv(EnvGoLine); goLine != "" && !specifyManually { 243 | if pos, err = strconv.Atoi(goLine); err != nil { 244 | return fmt.Errorf("strconv.Atoi(%s): %w", goLine, err) 245 | } 246 | } else { 247 | if declNotFoundErr != nil { 248 | return fmt.Errorf("gen.DetectTargetDecl: %w", declNotFoundErr) 249 | } 250 | } 251 | if goPackage := os.Getenv(EnvGoPackage); goPackage != "" { 252 | pkg = goPackage 253 | } 254 | if mode != "" { 255 | if mod = modeMap[mode]; !mod.IsValid() { 256 | return fmt.Errorf("invalid mode %q, available modes are: [%s]", mode, printStrings(validModes)) 257 | } 258 | } 259 | if out == 
"" { 260 | out = strings.TrimSuffix(file, ext) + ".gen" + ext 261 | } 262 | mode, output = mod.String(), out 263 | if err = checkFlags(); err != nil { 264 | return err 265 | } 266 | if template != "" { 267 | if mod == gen.ModeApi { 268 | return errors.New("the --template/-t option is not supported in the current mode=api scenario") 269 | } 270 | // The --template option supports two types of parameters. The first type is the path of a template 271 | // file, the program will read the content string of the file and generate a template. The second 272 | // type starts with a colon followed by an expression string. The program will remove the colon and 273 | // use the expression after the colon as the template string, generating a template based on the 274 | // value of that expression. 275 | if strings.HasPrefix(template, ":") { 276 | template = template[1:] 277 | if template == "" { 278 | return errors.New("invalid empty template") 279 | } 280 | } else { 281 | if !filepath.IsAbs(template) { 282 | template = filepath.Join(pwd, template) 283 | } 284 | templateBytes, err := os.ReadFile(template) 285 | if err != nil { 286 | return fmt.Errorf("os.ReadFile(%q): %w", template, err) 287 | } 288 | template = strconv.Quote(string(templateBytes)) 289 | } 290 | } 291 | builder := gen.NewCliBuilder(mod). 292 | WithFeats(features). 293 | WithImports(imports, true). 294 | WithFuncs(funcs). 295 | WithPkg(pkg). 296 | WithPwd(pwd). 297 | WithFile(file, doc). 298 | WithPos(pos). 
299 | WithTemplate(template) 300 | var buffer bytes.Buffer 301 | if err = builder.Build(&buffer); err != nil { 302 | return err 303 | } 304 | if !filepath.IsAbs(output) { 305 | output = filepath.Join(pwd, output) 306 | } 307 | return save(output, buffer.Bytes()) 308 | } else { 309 | if err = checkFlags(); err != nil { 310 | return err 311 | } 312 | 313 | schema, err := os.ReadFile(file) 314 | if err != nil { 315 | return fmt.Errorf("os.ReadFile(%s): %w", args[0], err) 316 | } 317 | 318 | var cfg gen.Config 319 | switch ext := filepath.Ext(file); ext { 320 | case ".json": 321 | if err = json.Unmarshal(schema, &cfg); err != nil { 322 | return fmt.Errorf("json.Unmarshal: %w", err) 323 | } 324 | case ".toml": 325 | if err = toml.Unmarshal(schema, &cfg); err != nil { 326 | return fmt.Errorf("toml.Unmarshal: %w", err) 327 | } 328 | case ".yaml", ".yml": 329 | if err = yaml.Unmarshal(schema, &cfg); err != nil { 330 | return fmt.Errorf("yaml.Unmarshal: %w", err) 331 | } 332 | default: 333 | return fmt.Errorf("%s currently does not support schema extension %q", cmd.Root().Name(), ext) 334 | } 335 | 336 | cfg.Features = append(cfg.Features, features...) 337 | cfg.Imports = append(cfg.Imports, imports...) 338 | cfg.Funcs = append(cfg.Funcs, funcs...) 
339 | 340 | var buffer bytes.Buffer 341 | if err = gen.Generate(&buffer, modeMap[mode], &cfg); err != nil { 342 | return err 343 | } 344 | 345 | return save(output, buffer.Bytes()) 346 | } 347 | }, 348 | } 349 | ) 350 | 351 | func save(name string, code []byte) (err error) { 352 | oriCode := code 353 | code, err = format.Source(code) 354 | if err != nil { 355 | return fmt.Errorf("format.Source: \n\n%s\n\n%w", oriCode, err) 356 | } 357 | if err = os.WriteFile(name, code, 0644); err != nil { 358 | return fmt.Errorf("os.WriteFile(%q, 0644): %w", name, err) 359 | } 360 | if !disableAutoImport { 361 | code, err = goimport.Process(name, code, nil) 362 | if err != nil { 363 | return fmt.Errorf("imports.Process: \n\n%s\n\n%w", oriCode, err) 364 | } 365 | 366 | if err = os.WriteFile(name, code, 0644); err != nil { 367 | return fmt.Errorf("os.WriteFile(%q, 0644): %w", name, err) 368 | } 369 | } 370 | return nil 371 | } 372 | 373 | func checkFlags() (err error) { 374 | if len(mode) == 0 { 375 | return fmt.Errorf("`-m/--mode` required, available modes are: [%s]", printStrings(validModes)) 376 | } 377 | if len(output) == 0 { 378 | return fmt.Errorf("`-o/--output` required") 379 | } 380 | if genMode := modeMap[mode]; !genMode.IsValid() { 381 | return fmt.Errorf("invalid mode %q, available modes are: [%s]", mode, printStrings(validModes)) 382 | } 383 | if err = checkFeatures(features); err != nil { 384 | return err 385 | } 386 | return nil 387 | } 388 | 389 | func checkFeatures(features []string) error { 390 | if len(features) == 0 { 391 | return nil 392 | } 393 | 394 | Check: 395 | for _, feature := range features { 396 | for _, valid := range validFeatures { 397 | if feature == valid { 398 | continue Check 399 | } 400 | } 401 | 402 | return fmt.Errorf("checkFeatures: invalid feature %s, available features are: \n\n%s\n\n", 403 | strconv.Quote(feature), 404 | printStrings(validFeatures)) 405 | } 406 | 407 | return nil 408 | } 409 | 410 | func printStrings(strings []string) 
string { 411 | var buf bytes.Buffer 412 | for i, s := range strings { 413 | buf.WriteString(strconv.Quote(s)) 414 | if i < len(strings)-1 { 415 | buf.WriteString(", ") 416 | } 417 | } 418 | return buf.String() 419 | } 420 | 421 | func init() { 422 | defc.AddCommand(generate) 423 | defc.SetHelpCommand(&cobra.Command{Hidden: true}) 424 | 425 | flags := defc.PersistentFlags() 426 | flags.StringVarP(&mode, "mode", "m", "", fmt.Sprintf("mode=[%s]", printStrings(validModes))) 427 | flags.StringVarP(&output, "output", "o", "", "output file name") 428 | flags.StringSliceVarP(&features, "features", "f", nil, fmt.Sprintf("features=[%s]", printStrings(validFeatures))) 429 | flags.StringArrayVar(&imports, "import", nil, "additional imports") 430 | flags.BoolVar(&disableAutoImport, "disable-auto-import", false, "disable auto import and import packages manually by '--import' option") 431 | flags.StringArrayVar(&funcs, "func", nil, "additional funcs") 432 | flags.StringArrayVar(&funcs, "function", nil, "additional funcs") 433 | 434 | // [2024-04-07] 435 | // Since we use the `checkFlags` function to validate required parameters, 436 | // we can disable Cobra's check for required flags. 437 | /* 438 | defc.MarkPersistentFlagRequired("mode") 439 | defc.MarkPersistentFlagRequired("output") 440 | */ 441 | 442 | genFlags := generate.PersistentFlags() 443 | genFlags.StringVarP(&targetType, "type", "T", "", "the type representing the schema definition") 444 | // --template/-t is an experimental parameter, during the experimental phase 445 | // it will only be applied to the generate command. 
446 | genFlags.StringVarP(&template, "template", "t", "", "only applicable to additional template content under the sqlx mode") 447 | 448 | defc.MarkPersistentFlagFilename("output") 449 | defc.MarkFlagsMutuallyExclusive("func", "function") 450 | } 451 | 452 | func main() { 453 | cobra.CheckErr(defc.Execute()) 454 | } 455 | -------------------------------------------------------------------------------- /runtime/init.go: -------------------------------------------------------------------------------- 1 | package defc 2 | 3 | import "reflect" 4 | 5 | func New[T any]() (v T) { 6 | val := reflect.ValueOf(&v).Elem() 7 | switch val.Kind() { 8 | case reflect.Slice, reflect.Map, reflect.Chan, reflect.Func, reflect.Pointer: 9 | val.Set(newType(val.Type())) 10 | } 11 | return v 12 | } 13 | 14 | func newType(typ reflect.Type) reflect.Value { 15 | switch typ.Kind() { 16 | case reflect.Slice: 17 | return reflect.MakeSlice(typ, 0, 0) 18 | case reflect.Map: 19 | return reflect.MakeMap(typ) 20 | case reflect.Chan: 21 | return reflect.MakeChan(typ, 0) 22 | case reflect.Func: 23 | return reflect.MakeFunc(typ, func(_ []reflect.Value) (results []reflect.Value) { 24 | results = make([]reflect.Value, typ.NumOut()) 25 | for i := 0; i < typ.NumOut(); i++ { 26 | results[i] = newType(typ.Out(i)) 27 | } 28 | return results 29 | }) 30 | case reflect.Pointer: 31 | return reflect.New(typ.Elem()) 32 | default: 33 | return reflect.Zero(typ) 34 | } 35 | } 36 | -------------------------------------------------------------------------------- /runtime/init_test.go: -------------------------------------------------------------------------------- 1 | package defc 2 | 3 | import ( 4 | "errors" 5 | "reflect" 6 | "testing" 7 | ) 8 | 9 | func TestNew(t *testing.T) { 10 | t.Run("default", func(t *testing.T) { 11 | var boolVal bool 12 | newDefault := newType(reflect.TypeOf(boolVal)) 13 | if yes, ok := newDefault.Interface().(bool); !ok || yes { 14 | t.Errorf("default: bool = %v", newDefault) 15 | return 16 | 
} 17 | reflect.ValueOf(&boolVal).Elem().Set(newDefault) 18 | }) 19 | t.Run("pointer", func(t *testing.T) { 20 | newPointer := New[*string]() 21 | if newPointer == nil { 22 | t.Errorf("pointer: expects non-nil value, got nil") 23 | return 24 | } 25 | if *newPointer != "" { 26 | t.Errorf("pointer: *string != \"\"") 27 | return 28 | } 29 | *newPointer = "test" 30 | }) 31 | t.Run("slice", func(t *testing.T) { 32 | newSlice := New[[]int]() 33 | if newSlice == nil || len(newSlice) != 0 || cap(newSlice) != 0 { 34 | t.Errorf("slice: []int = %v", newSlice) 35 | return 36 | } 37 | newSlice = append(newSlice, 1, 2, 3, 4) 38 | }) 39 | t.Run("map", func(t *testing.T) { 40 | newMap := New[map[string]int]() 41 | if newMap == nil || len(newMap) != 0 { 42 | t.Errorf("map: map[string]int = %v", newMap) 43 | return 44 | } 45 | newMap["test"] = 0313 46 | }) 47 | t.Run("chan", func(t *testing.T) { 48 | newChan := New[chan struct{}]() 49 | if newChan == nil { 50 | t.Errorf("chan: chan struct{} = %v", newChan) 51 | return 52 | } 53 | go func() { <-newChan }() 54 | newChan <- struct{}{} 55 | close(newChan) 56 | }) 57 | t.Run("func", func(t *testing.T) { 58 | newFunc := New[func(int) float64]() 59 | if newFunc == nil { 60 | t.Errorf("func: func(int) float64 = nil") 61 | return 62 | } 63 | ret := newFunc(0313) 64 | if ret != 0.0 { 65 | t.Errorf("func: return %v != 0.0", ret) 66 | return 67 | } 68 | }) 69 | t.Run("interface", func(t *testing.T) { 70 | newInterface := New[error]() 71 | if newInterface != nil { 72 | t.Errorf("interface: error = %v", newInterface) 73 | return 74 | } 75 | newInterface = errors.New("error") 76 | }) 77 | } 78 | -------------------------------------------------------------------------------- /runtime/merge.go: -------------------------------------------------------------------------------- 1 | package defc 2 | 3 | import ( 4 | "database/sql/driver" 5 | "errors" 6 | "reflect" 7 | 8 | tok "github.com/x5iu/defc/runtime/token" 9 | ) 10 | 11 | type NotAnArg interface { 
12 | NotAnArg() 13 | } 14 | 15 | type ToArgs interface { 16 | ToArgs() []any 17 | } 18 | 19 | type ToNamedArgs interface { 20 | ToNamedArgs() map[string]any 21 | } 22 | 23 | var bytesType = reflect.TypeOf([]byte{}) 24 | 25 | func MergeArgs(args ...any) []any { 26 | dst := make([]any, 0, len(args)) 27 | for _, arg := range args { 28 | rv := reflect.ValueOf(arg) 29 | if _, notAnArg := arg.(NotAnArg); notAnArg { 30 | continue 31 | } else if toArgs, ok := arg.(ToArgs); ok { 32 | dst = append(dst, MergeArgs(toArgs.ToArgs()...)...) 33 | } else if _, ok = arg.(driver.Valuer); ok { 34 | dst = append(dst, arg) 35 | } else if (rv.Kind() == reflect.Slice && !rv.Type().AssignableTo(bytesType)) || 36 | rv.Kind() == reflect.Array { 37 | for i := 0; i < rv.Len(); i++ { 38 | dst = append(dst, MergeArgs(rv.Index(i).Interface())...) 39 | } 40 | } else { 41 | dst = append(dst, arg) 42 | } 43 | } 44 | return dst 45 | } 46 | 47 | func MergeNamedArgs(argsMap map[string]any) map[string]any { 48 | namedMap := make(map[string]any, len(argsMap)) 49 | for name, arg := range argsMap { 50 | rv := reflect.ValueOf(arg) 51 | if _, notAnArg := arg.(NotAnArg); notAnArg { 52 | continue 53 | } else if toNamedArgs, ok := arg.(ToNamedArgs); ok { 54 | for k, v := range toNamedArgs.ToNamedArgs() { 55 | namedMap[k] = v 56 | } 57 | } else if _, ok = arg.(driver.Valuer); ok { 58 | namedMap[name] = arg 59 | } else if _, ok = arg.(ToArgs); ok { 60 | namedMap[name] = arg 61 | } else if rv.Kind() == reflect.Map { 62 | iter := rv.MapRange() 63 | for iter.Next() { 64 | k, v := iter.Key(), iter.Value() 65 | if k.Kind() == reflect.String { 66 | namedMap[k.String()] = v.Interface() 67 | } 68 | } 69 | } else if rv.Kind() == reflect.Struct || 70 | (rv.Kind() == reflect.Pointer && rv.Elem().Kind() == reflect.Struct) { 71 | rv = reflect.Indirect(rv) 72 | rt := rv.Type() 73 | for i := 0; i < rt.NumField(); i++ { 74 | if sf := rt.Field(i); sf.Anonymous { 75 | sft := sf.Type 76 | if sft.Kind() == reflect.Pointer { 77 | sft 
= sft.Elem() 78 | } 79 | for j := 0; j < sft.NumField(); j++ { 80 | if tag, exists := sft.Field(j).Tag.Lookup("db"); exists { 81 | for pos, char := range tag { 82 | if !(('0' <= char && char <= '9') || ('a' <= char && char <= 'z') || ('A' <= char && char <= 'Z') || char == '_') { 83 | tag = tag[:pos] 84 | break 85 | } 86 | } 87 | namedMap[tag] = rv.FieldByIndex([]int{i, j}).Interface() 88 | } 89 | } 90 | } else if tag, exists := sf.Tag.Lookup("db"); exists { 91 | for pos, char := range tag { 92 | if !(('0' <= char && char <= '9') || ('a' <= char && char <= 'z') || ('A' <= char && char <= 'Z') || char == '_') { 93 | tag = tag[:pos] 94 | break 95 | } 96 | } 97 | namedMap[tag] = rv.Field(i).Interface() 98 | } 99 | } 100 | } else { 101 | namedMap[name] = arg 102 | } 103 | } 104 | return namedMap 105 | } 106 | 107 | func BindVars(data any) string { 108 | var n int 109 | switch rv := reflect.ValueOf(data); rv.Kind() { 110 | case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: 111 | n = int(rv.Int()) 112 | case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: 113 | n = int(rv.Uint()) 114 | case reflect.Slice: 115 | if rv.Type().AssignableTo(bytesType) { 116 | n = 1 117 | } else { 118 | n = rv.Len() 119 | } 120 | default: 121 | n = 1 122 | } 123 | maxInt := func(a, b int) int { 124 | if a > b { 125 | return a 126 | } 127 | return b 128 | } 129 | bindvars := make([]string, 0, maxInt(2*n-1, 2)) 130 | for i := 0; i < n; i++ { 131 | if i > 0 { 132 | bindvars = append(bindvars, tok.Comma) 133 | } 134 | bindvars = append(bindvars, tok.Question) 135 | } 136 | return tok.MergeSqlTokens(bindvars) 137 | } 138 | 139 | func In[S ~[]any](query string, args S) (string, S, error) { 140 | tokens := tok.SplitTokens(query) 141 | targetArgs := make(S, 0, len(args)) 142 | targetQuery := make([]string, 0, len(tokens)) 143 | n := 0 144 | for _, token := range tokens { 145 | switch token { 146 | case tok.Question: 147 | if n >= len(args) { 148 
| return "", nil, errors.New("number of bind-vars exceeds arguments") 149 | } 150 | nested := MergeArgs(args[n]) 151 | if len(nested) == 0 { 152 | return "", nil, errors.New("empty slice passed to 'in' query") 153 | } 154 | targetArgs = append(targetArgs, nested...) 155 | targetQuery = append(targetQuery, BindVars(len(nested))) 156 | n++ 157 | default: 158 | targetQuery = append(targetQuery, token) 159 | } 160 | } 161 | if n < len(args) { 162 | return "", nil, errors.New("number of bind-vars less than number arguments") 163 | } 164 | return tok.MergeSqlTokens(targetQuery), targetArgs, nil 165 | } 166 | 167 | // in is a special function designed to allow the sqlx package to reference it without using import, 168 | // but instead through go:linkname, in order to avoid circular references. 169 | func in(query string, args ...any) (string, []any, error) { 170 | return In[[]any](query, args) 171 | } 172 | 173 | type Arguments []any 174 | 175 | func (arguments *Arguments) add(argument any) string { 176 | merged := MergeArgs(argument) 177 | *arguments = append(*arguments, merged...) 
178 | return BindVars(len(merged)) 179 | } 180 | 181 | func (arguments *Arguments) Add(argument any) string { return arguments.add(argument) } 182 | func (arguments *Arguments) Bind(argument any) string { return arguments.add(argument) } 183 | func (arguments *Arguments) Push(argument any) string { return arguments.add(argument) } 184 | func (arguments *Arguments) Append(argument any) string { return arguments.add(argument) } 185 | -------------------------------------------------------------------------------- /runtime/merge_test.go: -------------------------------------------------------------------------------- 1 | package defc 2 | 3 | import ( 4 | "database/sql" 5 | "encoding/json" 6 | "reflect" 7 | "strings" 8 | "testing" 9 | ) 10 | 11 | type implNotAnArg struct{} 12 | 13 | func (implNotAnArg) NotAnArg() {} 14 | 15 | type implToArgs struct{} 16 | 17 | func (implToArgs) ToArgs() []any { 18 | return []any{ 19 | "test", 20 | 0314, 21 | true, 22 | } 23 | } 24 | 25 | type implToNamedArgs struct{} 26 | 27 | func (implToNamedArgs) ToNamedArgs() map[string]any { 28 | return map[string]any{ 29 | "s": "test", 30 | "i": 0315, 31 | "b": true, 32 | } 33 | } 34 | 35 | type nested struct { 36 | Nested string `db:"nested; charset=utf-8"` 37 | } 38 | 39 | func TestMergeArgs(t *testing.T) { 40 | type TestCase struct { 41 | Name string 42 | Data []any 43 | Expect []any 44 | } 45 | var testcases = []*TestCase{ 46 | { 47 | Name: "ints", 48 | Data: []any{1, 2, 3}, 49 | Expect: []any{1, 2, 3}, 50 | }, 51 | { 52 | Name: "valuer", 53 | Data: []any{sql.NullString{String: "test", Valid: true}}, 54 | Expect: []any{sql.NullString{String: "test", Valid: true}}, 55 | }, 56 | { 57 | Name: "list", 58 | Data: []any{1, 2, []int{3, 4, 5}, [2]int{6, 7}, &implNotAnArg{}}, 59 | Expect: []any{1, 2, 3, 4, 5, 6, 7}, 60 | }, 61 | { 62 | Name: "naa", 63 | Data: []any{&implNotAnArg{}}, 64 | Expect: []any{}, 65 | }, 66 | { 67 | Name: "nested", 68 | Data: []any{1, 2, []any{3, []any{4, 5, &implToArgs{}}}}, 
69 | Expect: []any{1, 2, 3, 4, 5, "test", 0314, true}, 70 | }, 71 | { 72 | Name: "bytes", 73 | Data: []any{"test", []byte("test")}, 74 | Expect: []any{"test", []byte("test")}, 75 | }, 76 | } 77 | for _, testcase := range testcases { 78 | t.Run(testcase.Name, func(t *testing.T) { 79 | merged := MergeArgs(testcase.Data...) 80 | if len(merged) != len(testcase.Expect) { 81 | t.Errorf("merge: %d != %d", len(merged), len(testcase.Expect)) 82 | return 83 | } 84 | for i, arg := range merged { 85 | if !reflect.DeepEqual(arg, testcase.Expect[i]) { 86 | t.Errorf("merge: %v != %v", arg, testcase.Expect[i]) 87 | return 88 | } 89 | } 90 | }) 91 | } 92 | } 93 | 94 | func TestMergeNamedArgs(t *testing.T) { 95 | type TestCase struct { 96 | Name string 97 | Data map[string]any 98 | Expect map[string]any 99 | } 100 | var testcases = []*TestCase{ 101 | { 102 | Name: "ints", 103 | Data: map[string]any{ 104 | "one": 1, 105 | "two": 2, 106 | "three": 3, 107 | }, 108 | Expect: map[string]any{ 109 | "one": 1, 110 | "two": 2, 111 | "three": 3, 112 | }, 113 | }, 114 | { 115 | Name: "naa", 116 | Data: map[string]any{ 117 | "naa": &implNotAnArg{}, 118 | }, 119 | Expect: map[string]any{}, 120 | }, 121 | { 122 | Name: "tna", 123 | Data: map[string]any{ 124 | "tna": &implToNamedArgs{}, 125 | }, 126 | Expect: map[string]any{ 127 | "s": "test", 128 | "i": 0315, 129 | "b": true, 130 | }, 131 | }, 132 | { 133 | Name: "args", 134 | Data: map[string]any{ 135 | "args": &implToArgs{}, 136 | }, 137 | Expect: map[string]any{ 138 | "args": &implToArgs{}, 139 | }, 140 | }, 141 | { 142 | Name: "map", 143 | Data: map[string]any{ 144 | "map": map[string]any{ 145 | "one": 1, 146 | "two": 2, 147 | "three": 3, 148 | }, 149 | }, 150 | Expect: map[string]any{ 151 | "one": 1, 152 | "two": 2, 153 | "three": 3, 154 | }, 155 | }, 156 | { 157 | Name: "valuer", 158 | Data: map[string]any{ 159 | "valuer": sql.NullInt64{ 160 | Int64: 0315, 161 | Valid: true, 162 | }, 163 | }, 164 | Expect: map[string]any{ 165 | "valuer": 
sql.NullInt64{ 166 | Int64: 0315, 167 | Valid: true, 168 | }, 169 | }, 170 | }, 171 | { 172 | Name: "struct", 173 | Data: map[string]any{ 174 | "struct": struct { 175 | One, Two, Three int 176 | Name string `db:"name; charset=utf-8"` 177 | *nested 178 | }{ 179 | One: 1, 180 | Two: 2, 181 | Three: 3, 182 | Name: "test", 183 | nested: &nested{ 184 | Nested: "nested", 185 | }, 186 | }, 187 | }, 188 | Expect: map[string]any{ 189 | "name": "test", 190 | "nested": "nested", 191 | }, 192 | }, 193 | } 194 | for _, testcase := range testcases { 195 | t.Run(testcase.Name, func(t *testing.T) { 196 | merged := MergeNamedArgs(testcase.Data) 197 | if len(merged) != len(testcase.Expect) { 198 | t.Errorf("merge: %d != %d", len(merged), len(testcase.Expect)) 199 | return 200 | } 201 | for k, v := range merged { 202 | if !reflect.DeepEqual(v, testcase.Expect[k]) { 203 | t.Errorf("merge: %v != %v", v, testcase.Expect[k]) 204 | return 205 | } 206 | } 207 | }) 208 | } 209 | } 210 | 211 | func TestBindVars(t *testing.T) { 212 | type TestCase struct { 213 | Name string 214 | Input any 215 | N int 216 | } 217 | output := func(testcase *TestCase) string { 218 | var bf strings.Builder 219 | for i := 0; i < testcase.N; i++ { 220 | if i > 0 { 221 | bf.WriteString(",") 222 | } 223 | bf.WriteString("?") 224 | } 225 | return bf.String() 226 | } 227 | test := func(testcases []*TestCase) func(*testing.T) { 228 | return func(t *testing.T) { 229 | for _, testcase := range testcases { 230 | t.Run(testcase.Name, func(t *testing.T) { 231 | if bindvars, expect := BindVars(testcase.Input), output(testcase); bindvars != expect { 232 | t.Errorf("%T: %q != %q", testcase.Input, bindvars, expect) 233 | return 234 | } 235 | }) 236 | } 237 | } 238 | } 239 | t.Run("int", test([]*TestCase{ 240 | {Name: "zero", Input: 0, N: 0}, 241 | {Name: "int", Input: int(1), N: 1}, 242 | {Name: "int8", Input: int8(2), N: 2}, 243 | {Name: "int16", Input: int16(3), N: 3}, 244 | {Name: "int32", Input: int32(4), N: 4}, 245 | 
{Name: "int64", Input: int64(5), N: 5}, 246 | {Name: "uint", Input: uint(1), N: 1}, 247 | {Name: "uint8", Input: uint8(2), N: 2}, 248 | {Name: "uint16", Input: uint16(3), N: 3}, 249 | {Name: "uint32", Input: uint32(4), N: 4}, 250 | {Name: "uint64", Input: uint64(5), N: 5}, 251 | })) 252 | t.Run("slice", test([]*TestCase{ 253 | {Name: "three", Input: []int{1, 2, 3}, N: 3}, 254 | {Name: "zero", Input: []int{}, N: 0}, 255 | })) 256 | t.Run("bytes", test([]*TestCase{ 257 | {Name: "one", Input: []byte("test"), N: 1}, 258 | {Name: "empty", Input: []byte{}, N: 1}, 259 | {Name: "other", Input: json.RawMessage{}, N: 1}, 260 | })) 261 | t.Run("nil", test([]*TestCase{ 262 | {Name: "one", Input: nil, N: 1}, 263 | })) 264 | t.Run("other", test([]*TestCase{ 265 | {Name: "valuer", Input: sql.NullInt64{Int64: 0314, Valid: true}, N: 1}, 266 | })) 267 | } 268 | 269 | func TestIn(t *testing.T) { 270 | type TestCase struct { 271 | Name string 272 | Query string 273 | Args []any 274 | Expect string 275 | N int 276 | } 277 | var testcases = []*TestCase{ 278 | { 279 | Name: "mixin", 280 | Query: "(?) (?) (?)", 281 | Args: []any{ 282 | "test", 283 | [2]bool{true, false}, 284 | &implToArgs{}, 285 | }, 286 | Expect: "(?) (?,?) 
(?,?,?)", 287 | N: 6, 288 | }, 289 | { 290 | Name: "issue:2025-03-23-part-1", 291 | Query: "INSERT INTO migrate (version) VALUES (?);", 292 | Args: []any{"2025-03-23.sql"}, 293 | Expect: "INSERT INTO migrate (version) VALUES (?);", 294 | N: 1, 295 | }, 296 | { 297 | Name: "issue:2025-03-23-part-2", 298 | Query: "INSERT INTO migrate (version) VALUES (:version);", 299 | Args: []any{}, 300 | Expect: "INSERT INTO migrate (version) VALUES (:version);", 301 | N: 0, 302 | }, 303 | { 304 | Name: "issue:2025-03-23-part-3", 305 | Query: "INSERT INTO migrate (version) VALUES (?);", 306 | Args: []any{"2025-03-23.sql"}, 307 | Expect: "INSERT INTO migrate (version) VALUES (?);", 308 | N: 1, 309 | }, 310 | } 311 | for _, testcase := range testcases { 312 | t.Run(testcase.Name, func(t *testing.T) { 313 | query, args, err := In(testcase.Query, testcase.Args) 314 | if err != nil { 315 | t.Errorf("in: %s", err) 316 | return 317 | } 318 | if query != testcase.Expect { 319 | t.Errorf("in: %q != %q", query, testcase.Expect) 320 | return 321 | } 322 | if len(args) != testcase.N { 323 | t.Errorf("in: %d != %d", len(args), testcase.N) 324 | return 325 | } 326 | }) 327 | } 328 | t.Run("error", func(t *testing.T) { 329 | t.Run("more", func(t *testing.T) { 330 | _, _, err := In("?, ?", []any{1}) 331 | if err == nil { 332 | t.Errorf("errors: expects errors, got nil") 333 | return 334 | } 335 | if err.Error() != "number of bind-vars exceeds arguments" { 336 | t.Errorf("errors: unexpected error message => %q", err.Error()) 337 | return 338 | } 339 | }) 340 | t.Run("less", func(t *testing.T) { 341 | _, _, err := In("?", []any{1, 2}) 342 | if err == nil { 343 | t.Errorf("errors: expects errors, got nil") 344 | return 345 | } 346 | if err.Error() != "number of bind-vars less than number arguments" { 347 | t.Errorf("errors: unexpected error message => %q", err.Error()) 348 | return 349 | } 350 | }) 351 | t.Run("empty", func(t *testing.T) { 352 | _, _, err := In("?", []any{[]any{}}) 353 | if err == 
nil { 354 | t.Errorf("errors: expects errors, got nil") 355 | return 356 | } 357 | if err.Error() != "empty slice passed to 'in' query" { 358 | t.Errorf("errors: unexpected error message => %q", err.Error()) 359 | return 360 | } 361 | }) 362 | }) 363 | } 364 | 365 | func TestArguments(t *testing.T) { 366 | var arguments = make(Arguments, 0, 2) 367 | bindvars := arguments.Add([]int{1, 2, 3}) 368 | if bindvars != "?,?,?" { 369 | t.Errorf("arguments: %q != \"?,?,?\"", bindvars) 370 | return 371 | } 372 | if l := len(arguments); l != 3 { 373 | t.Errorf("arguments: len(arguments) != 3, got %d", l) 374 | return 375 | } 376 | if !reflect.DeepEqual(arguments, Arguments{1, 2, 3}) { 377 | t.Errorf("arguments: %v != [1, 2, 3]", arguments) 378 | return 379 | } 380 | } 381 | -------------------------------------------------------------------------------- /runtime/pool.go: -------------------------------------------------------------------------------- 1 | package defc 2 | 3 | import ( 4 | "bytes" 5 | "sync" 6 | ) 7 | 8 | var bufferPool = sync.Pool{ 9 | New: func() any { 10 | return new(bytes.Buffer) 11 | }, 12 | } 13 | 14 | func GetBuffer() *bytes.Buffer { 15 | return bufferPool.Get().(*bytes.Buffer) 16 | } 17 | 18 | func PutBuffer(buffer *bytes.Buffer) { 19 | bufferPool.Put(buffer) 20 | } 21 | -------------------------------------------------------------------------------- /runtime/pool_test.go: -------------------------------------------------------------------------------- 1 | package defc 2 | 3 | import ( 4 | "runtime" 5 | "testing" 6 | "unsafe" 7 | ) 8 | 9 | func TestPool(t *testing.T) { 10 | buffer := GetBuffer() 11 | if buffer == nil { 12 | t.Errorf("pool: buffer = %v", buffer) 13 | return 14 | } 15 | bufferAddress := (uintptr)(unsafe.Pointer(buffer)) 16 | if bufferAddress == 0 { 17 | t.Errorf("pool: uintptr = %d", bufferAddress) 18 | return 19 | } 20 | PutBuffer(buffer) 21 | nextBuffer := GetBuffer() 22 | if nextBuffer == nil { 23 | t.Errorf("pool: buffer = %v", 
nextBuffer) 24 | return 25 | } 26 | nextBufferAddress := (uintptr)(unsafe.Pointer(nextBuffer)) 27 | if nextBufferAddress == 0 { 28 | t.Errorf("pool: uintptr = %d", nextBufferAddress) 29 | return 30 | } 31 | if nextBufferAddress != bufferAddress { 32 | // sync/pool.go:121 33 | // Get may choose to ignore the pool and treat it as empty. 34 | // Callers should not assume any relation between values passed to Put and 35 | // the values returned by Get. 36 | t.Logf("pool: %d != %d", nextBufferAddress, bufferAddress) 37 | return 38 | } 39 | runtime.KeepAlive(buffer) 40 | } 41 | -------------------------------------------------------------------------------- /runtime/request.go: -------------------------------------------------------------------------------- 1 | package defc 2 | 3 | import ( 4 | "bytes" 5 | "crypto/rand" 6 | "encoding/hex" 7 | "encoding/json" 8 | "fmt" 9 | "io" 10 | "reflect" 11 | "strings" 12 | "sync" 13 | "unsafe" 14 | 15 | "github.com/x5iu/defc/runtime/token" 16 | ) 17 | 18 | // JSONBody is a shortcut type used for quickly constructing an io.Reader. 19 | // 20 | // Embed JSONBody as the first field in the struct, and set the generic parameter of JSONBody 21 | // to the type of the embedded struct (ensure to use a value type rather than a pointer type). 22 | // This will allow the struct to be converted to an io.Reader and used as the body of an HTTP 23 | // request. 24 | // 25 | // Note that incorrectly setting the value of the generic type T (for example, not setting it 26 | // to a type consistent with the embedding struct) can lead to severe errors. Please adhere 27 | // to the aforementioned rule. 
type JSONBody[T any] struct {
	// data holds the encoded JSON document that Read drains.
	data bytes.Buffer
	// once guards the one-time validation and encoding performed by Read.
	once sync.Once
	// err is the encoding error produced by the first Read, if any. It is
	// kept so that every later Read reports it instead of a premature io.EOF.
	err error
}

// Read implements io.Reader.
//
// On the first call it checks that T is a struct type whose first field is
// an embedded JSONBody[T] (panicking otherwise), reinterprets the receiver as
// the enclosing *T, and JSON-encodes it with json.Encoder (which appends a
// trailing newline) into an internal buffer. Every call, including the first,
// then reads from that buffer. If encoding failed, that error is returned by
// this and every subsequent call.
func (b *JSONBody[T]) Read(p []byte) (n int, err error) {
	b.once.Do(func() {
		// Validate using type information only, so the receiver is never
		// reinterpreted as *T before the embedding rule has been checked.
		rt := reflect.TypeOf((*T)(nil)).Elem()
		if rt.Kind() != reflect.Struct {
			panic("use the value type of a struct rather than a pointer type as the value for generics")
		}
		if rt.NumField() > 0 {
			if sf := rt.Field(0); !sf.Anonymous || sf.Type != reflect.TypeOf(b).Elem() {
				panic("JSONBody is not the first embedded field of struct type T")
			}
		}
		// JSONBody[T] sits at offset 0 of T, so b also points at the enclosing
		// T. Encoding through the pointer avoids copying T (and the sync.Once
		// embedded in it, which is in use right now).
		x := (*T)(unsafe.Pointer(b))
		// Estimate the size of the body in advance from the actual field values.
		rv := reflect.ValueOf(x).Elem()
		toGrow := 0
		for i := 1; i < rt.NumField(); i++ {
			toGrow += 8 // object keys
			switch fv := rv.Field(i); fv.Kind() {
			case reflect.String:
				toGrow += 2 + fv.Len()
			case reflect.Slice:
				toGrow += fv.Len() * 2
			default:
				toGrow += 4
			}
		}
		b.data.Grow(toGrow)
		b.err = json.NewEncoder(&b.data).Encode(x)
	})
	if b.err != nil {
		return 0, b.err
	}
	return b.data.Read(p)
}

// MultipartBody is a shortcut type used for quickly constructing an io.Reader.
//
// Embed MultipartBody as the first field in the struct, and set the generic parameter of
// MultipartBody to the type of the embedded struct (ensure to use a value type rather than a
// pointer type). This will allow the struct to be converted to an io.Reader and used as the
// body of an HTTP request.
//
// Use the ContentType method to obtain the Content-Type Header with a boundary. Use the "form"
// tag to specify the name of fields in multipart/form-data, its usage is similar to the
// encoding/json package, and it also supports the omitempty syntax.
82 | // 83 | // For file types, you can directly use os.File as a value, or you can use types that implement 84 | // the namedReader interface (os.File has implemented the namedReader interface). 85 | // 86 | // Note that incorrectly setting the value of the generic type T (for example, not setting it 87 | // to a type consistent with the embedding struct) can lead to severe errors. Please adhere 88 | // to the aforementioned rule. 89 | type MultipartBody[T any] struct { 90 | reader io.Reader 91 | boundary string 92 | once sync.Once 93 | } 94 | 95 | func (b *MultipartBody[T]) getBoundary() string { 96 | if b.boundary == "" { 97 | var buf [32]byte 98 | io.ReadFull(rand.Reader, buf[:]) 99 | b.boundary = hex.EncodeToString(buf[:]) 100 | } 101 | return b.boundary 102 | } 103 | 104 | func (b *MultipartBody[T]) ContentType() string { 105 | boundary := b.getBoundary() 106 | if strings.ContainsAny(boundary, `()<>@,;:\"/[]?= `) { 107 | boundary = `"` + boundary + `"` 108 | } 109 | return "multipart/form-data; boundary=" + boundary 110 | } 111 | 112 | type namedReader interface { 113 | io.Reader 114 | Name() string 115 | } 116 | 117 | var quoteEscaper = strings.NewReplacer("\\", "\\\\", `"`, "\\\"") 118 | 119 | func escapeQuotes(s string) string { 120 | return quoteEscaper.Replace(s) 121 | } 122 | 123 | func (b *MultipartBody[T]) Read(p []byte) (n int, err error) { 124 | b.once.Do(func() { 125 | var x T 126 | x = *(*T)(unsafe.Pointer(b)) 127 | s := &fieldScanner{tag: "form", val: reflect.ValueOf(x)} 128 | if s.val.Kind() != reflect.Struct { 129 | panic("use the value type of a struct rather than a pointer type as the value for generics") 130 | } 131 | if !s.CheckFirstEmbedType(reflect.TypeOf(b).Elem()) { 132 | panic("MultipartBody is not the first embedded field of struct type T") 133 | } 134 | readers := make([]io.Reader, 0, s.val.NumField()) 135 | for i := 0; s.Scan(); i++ { 136 | tag := s.Tag() 137 | if tag != "-" && s.Exported() { 138 | if tagContains(tag, "omitempty") 
&& s.Empty() { 139 | continue 140 | } 141 | var buf bytes.Buffer 142 | if i == 0 { 143 | buf.WriteString("--" + b.getBoundary() + "\r\n") 144 | } else { 145 | buf.WriteString("\r\n--" + b.getBoundary() + "\r\n") 146 | } 147 | val := s.Val() 148 | if file, ok := val.(namedReader); ok { 149 | buf.WriteString(fmt.Sprintf(`Content-Disposition: form-data; name="%s"; filename="%s"`+"\r\n", 150 | escapeQuotes(getTag(tag)), 151 | escapeQuotes(file.Name()))) 152 | buf.WriteString("Content-Type: application/octet-stream\r\n\r\n") 153 | readers = append(readers, &buf) 154 | readers = append(readers, file) 155 | } else { 156 | var fieldvalue string 157 | if fieldvalue, ok = val.(string); !ok { 158 | fieldvalue = fmt.Sprintf("%v", val) 159 | } 160 | buf.WriteString(fmt.Sprintf(`Content-Disposition: form-data; name="%s"`+"\r\n\r\n", 161 | escapeQuotes(getTag(tag)))) 162 | buf.WriteString(fieldvalue) 163 | readers = append(readers, &buf) 164 | } 165 | } 166 | } 167 | readers = append(readers, strings.NewReader("\r\n--"+b.getBoundary()+"--\r\n")) 168 | b.reader = io.MultiReader(readers...) 
169 | }) 170 | return b.reader.Read(p) 171 | } 172 | 173 | type fieldScanner struct { 174 | tag string 175 | val reflect.Value 176 | typ reflect.Type 177 | idx int 178 | } 179 | 180 | func (s *fieldScanner) CheckFirstEmbedType(target reflect.Type) bool { 181 | if s.typ == nil { 182 | s.typ = s.val.Type() 183 | } 184 | if s.typ.NumField() == 0 { 185 | return false 186 | } 187 | sf := s.typ.Field(0) 188 | return sf.Anonymous && sf.Type == target 189 | } 190 | 191 | func (s *fieldScanner) pos() int { 192 | return s.idx - 1 193 | } 194 | 195 | func (s *fieldScanner) Scan() bool { 196 | if s.val.Kind() != reflect.Struct { 197 | return false 198 | } 199 | if s.typ == nil { 200 | s.typ = s.val.Type() 201 | } 202 | s.idx++ 203 | if s.pos() >= s.typ.NumField() { 204 | return false 205 | } 206 | if sf := s.typ.Field(s.pos()); sf.Anonymous { 207 | s.idx++ 208 | } 209 | return s.pos() < s.typ.NumField() 210 | } 211 | 212 | func (s *fieldScanner) Tag() string { 213 | return s.typ.Field(s.pos()).Tag.Get(s.tag) 214 | } 215 | 216 | func (s *fieldScanner) Val() any { 217 | return s.val.Field(s.pos()).Interface() 218 | } 219 | 220 | func (s *fieldScanner) Empty() bool { 221 | return s.val.Field(s.pos()).IsZero() 222 | } 223 | 224 | func (s *fieldScanner) Exported() bool { 225 | return s.typ.Field(s.pos()).IsExported() 226 | } 227 | 228 | func getTag(tag string) string { 229 | tag, _, _ = strings.Cut(tag, token.Comma) 230 | return tag 231 | } 232 | 233 | func tagContains(tag string, option string) bool { 234 | parts := strings.Split(tag, token.Comma) 235 | if len(parts) <= 1 { 236 | return false 237 | } 238 | options := parts[1:] 239 | for _, o := range options { 240 | if strings.TrimSpace(o) == option { 241 | return true 242 | } 243 | } 244 | return false 245 | } 246 | -------------------------------------------------------------------------------- /runtime/request_test.go: -------------------------------------------------------------------------------- 1 | package defc 2 | 3 | 
import ( 4 | "bytes" 5 | "encoding/json" 6 | "errors" 7 | "fmt" 8 | "io" 9 | "mime/multipart" 10 | "os" 11 | "path/filepath" 12 | "reflect" 13 | "strings" 14 | "testing" 15 | ) 16 | 17 | type testRequestJSONBody struct { 18 | JSONBody[testRequestJSONBody] 19 | Code string `json:"code"` 20 | Message string `json:"message"` 21 | Success bool `json:"success"` 22 | List []int `json:"list"` 23 | } 24 | 25 | type testErrRequestJSONBody struct { 26 | JSONBody[testErrRequestJSONBody] 27 | } 28 | 29 | func (*testErrRequestJSONBody) MarshalJSON() ([]byte, error) { 30 | return nil, errors.New("test_err_request_body_unmarshal_error") 31 | } 32 | 33 | type testPanicRequestJSONBody struct { 34 | JSONBody[*testPanicRequestJSONBody] 35 | } 36 | 37 | type testDislocationRequestJSONBody struct { 38 | Code string `json:"code"` 39 | JSONBody[testDislocationRequestJSONBody] 40 | } 41 | 42 | func TestJSONBody(t *testing.T) { 43 | want := map[string]any{ 44 | "code": "test", 45 | "message": "0408", 46 | "success": true, 47 | "list": []any{1.0, 2.0, 3.0}, 48 | } 49 | var r io.Reader = &testRequestJSONBody{ 50 | Code: "test", 51 | Message: "0408", 52 | Success: true, 53 | List: []int{1, 2, 3}, 54 | } 55 | raw, _ := io.ReadAll(r) 56 | var got map[string]any 57 | if err := json.Unmarshal(raw, &got); err != nil { 58 | t.Errorf("json_body: %s", err) 59 | return 60 | } 61 | if !reflect.DeepEqual(got, want) { 62 | t.Errorf("json_body: got => %v", got) 63 | return 64 | } 65 | r = &testErrRequestJSONBody{} 66 | if unexpected, err := io.ReadAll(r); err == nil { 67 | t.Errorf("json_body: expects errors, got nil, unexpected => %s", string(unexpected)) 68 | return 69 | } else if !strings.Contains(err.Error(), "test_err_request_body_unmarshal_error") { 70 | t.Errorf("json_body: expects Error, got => %s", err) 71 | return 72 | } 73 | t.Run("kind_panic", func(t *testing.T) { 74 | var r io.Reader 75 | r = &testPanicRequestJSONBody{} 76 | defer func() { 77 | if rec := recover(); rec == nil { 78 | 
t.Errorf("json_body: expects Panic, got nil") 79 | return 80 | } else if lit, ok := rec.(string); !ok || lit != "use the value type of a struct rather than a pointer type as the value for generics" { 81 | t.Errorf("json_body: unexpected Panic literal => %s", rec) 82 | return 83 | } 84 | }() 85 | _, _ = io.ReadAll(r) 86 | }) 87 | t.Run("loc_panic", func(t *testing.T) { 88 | var r io.Reader 89 | r = &testDislocationRequestJSONBody{Code: "test"} 90 | defer func() { 91 | if rec := recover(); rec == nil { 92 | t.Errorf("json_body: expects Panic, got nil") 93 | return 94 | } else if lit, ok := rec.(string); !ok || lit != "JSONBody is not the first embedded field of struct type T" { 95 | t.Errorf("json_body: unexpected Panic literal => %s", rec) 96 | return 97 | } 98 | }() 99 | _, _ = io.ReadAll(r) 100 | }) 101 | } 102 | 103 | type testRequestMultipartBody struct { 104 | MultipartBody[testRequestMultipartBody] 105 | File *os.File `form:"file"` 106 | Name string `form:"name"` 107 | N int `form:"n,test"` 108 | Skip string `form:"-"` 109 | Exclude string `form:"exclude,omitempty"` 110 | unexported string 111 | } 112 | 113 | type testPanicRequestMultipartBody struct { 114 | MultipartBody[*testPanicRequestMultipartBody] 115 | } 116 | 117 | type testDislocationRequestMultipartBody struct { 118 | Name string `form:"name"` 119 | MultipartBody[testDislocationRequestMultipartBody] 120 | } 121 | 122 | func TestMultipartBody(t *testing.T) { 123 | file, err := os.Open("request_test.go") 124 | if err != nil { 125 | t.Errorf("multipart_body: %s", err) 126 | return 127 | } 128 | defer file.Close() 129 | r := &testRequestMultipartBody{ 130 | File: file, 131 | Name: "test", 132 | N: 1, 133 | } 134 | var buf bytes.Buffer 135 | multipartWriter := multipart.NewWriter(&buf) 136 | if err = multipartWriter.SetBoundary(r.getBoundary()); err != nil { 137 | t.Errorf("multipart_body: %s", err) 138 | return 139 | } 140 | w, err := multipartWriter.CreateFormFile("file", filepath.Base(file.Name())) 141 
| if err != nil { 142 | t.Errorf("multipart_body: %s", err) 143 | return 144 | } 145 | io.Copy(w, file) 146 | if err = multipartWriter.WriteField("name", "test"); err != nil { 147 | t.Errorf("multipart_body: %s", err) 148 | return 149 | } 150 | if err = multipartWriter.WriteField("n", "1"); err != nil { 151 | t.Errorf("multipart_body: %s", err) 152 | return 153 | } 154 | multipartWriter.Close() 155 | if _, err = file.Seek(0, io.SeekStart); err != nil { 156 | t.Errorf("multipart_body: %s", err) 157 | return 158 | } 159 | requestBody, _ := io.ReadAll(r) 160 | if !bytes.Equal(requestBody, buf.Bytes()) { 161 | t.Errorf("multipart_body: unexptected body => \n\n%s", string(requestBody)) 162 | return 163 | } 164 | t.Run("content_type", func(t *testing.T) { 165 | r := &testRequestMultipartBody{} 166 | r.boundary = r.getBoundary() + "@" 167 | if want := fmt.Sprintf(`multipart/form-data; boundary="%s"`, r.boundary); r.ContentType() != want { 168 | t.Errorf("multipart_body: unexptected Content-Type => \n\nwant: %s\ngot: %s", want, r.ContentType()) 169 | return 170 | } 171 | }) 172 | t.Run("kind_panic", func(t *testing.T) { 173 | r := &testPanicRequestMultipartBody{} 174 | defer func() { 175 | if rec := recover(); rec == nil { 176 | t.Errorf("multipart_body: expects Panic, got nil") 177 | return 178 | } else if lit, ok := rec.(string); !ok || lit != "use the value type of a struct rather than a pointer type as the value for generics" { 179 | t.Errorf("multipart_body: unexpected Panic literal => %s", rec) 180 | return 181 | } 182 | }() 183 | _, _ = io.ReadAll(r) 184 | }) 185 | t.Run("loc_panic", func(t *testing.T) { 186 | r := &testDislocationRequestMultipartBody{Name: "test"} 187 | defer func() { 188 | if rec := recover(); rec == nil { 189 | t.Errorf("multipart_body: expects Panic, got nil") 190 | return 191 | } else if lit, ok := rec.(string); !ok || lit != "MultipartBody is not the first embedded field of struct type T" { 192 | t.Errorf("multipart_body: unexpected Panic 
literal => %s", rec) 193 | return 194 | } 195 | }() 196 | _, _ = io.ReadAll(r) 197 | }) 198 | } 199 | 200 | func TestFieldScanner(t *testing.T) { 201 | var s *fieldScanner 202 | s = &fieldScanner{tag: "test", val: reflect.ValueOf("test")} 203 | for i := 0; i < 8; i++ { 204 | if s.Scan() != false { 205 | t.Errorf("field_scanner: expects s.Scan() == false, got true") 206 | return 207 | } 208 | } 209 | s = &fieldScanner{tag: "test", val: reflect.ValueOf(struct{}{})} 210 | if s.Scan() != false { 211 | t.Errorf("field_scanner: expects s.Scan() == false, got true") 212 | return 213 | } 214 | if s.CheckFirstEmbedType(nil) != false { 215 | t.Errorf("field_scanner: expects s.CheckFirstEmbedType() == false, got true") 216 | return 217 | } 218 | } 219 | -------------------------------------------------------------------------------- /runtime/response.go: -------------------------------------------------------------------------------- 1 | package defc 2 | 3 | import ( 4 | "compress/gzip" 5 | "encoding/json" 6 | "fmt" 7 | "io" 8 | "net/http" 9 | "strings" 10 | ) 11 | 12 | type Response interface { 13 | Err() error 14 | ScanValues(...any) error 15 | FromBytes(string, []byte) error 16 | Break() bool 17 | } 18 | 19 | // FutureResponse represents Response interface which would be used in next 20 | // major version of defc, who may cause breaking changes. 
type FutureResponse interface {
	Err() error
	ScanValues(...any) error
	FromResponse(string, *http.Response) error
	Break() bool
}

// ResponseError is an error that also exposes the HTTP status code and the
// raw body of the response that caused it.
type ResponseError interface {
	error
	Status() int
	Body() []byte
}

// NewResponseError returns a ResponseError for the method named caller that
// received status and body. The body slice is stored without copying.
func NewResponseError(caller string, status int, body []byte) ResponseError {
	return &implResponseError{
		caller: caller,
		status: status,
		body:   body,
	}
}

// implResponseError is the ResponseError implementation returned by NewResponseError.
type implResponseError struct {
	caller string
	status int
	body   []byte
}

// Error reports the status code, the caller and the full response body.
func (e *implResponseError) Error() string {
	return fmt.Sprintf("response status code %d for '%s' with body: \n\n%s\n\n", e.status, e.caller, string(e.body))
}

// Status returns the HTTP status code.
func (e *implResponseError) Status() int {
	return e.status
}

// Body returns the raw response body.
func (e *implResponseError) Body() []byte {
	return e.body
}

// FutureResponseError is the response error interface planned for the next
// major version of defc; adopting it may cause breaking changes.
type FutureResponseError interface {
	error
	Response() *http.Response
}

// NewFutureResponseError returns a FutureResponseError for the method named
// caller. The response is stored as-is; this function neither reads nor
// closes its Body.
func NewFutureResponseError(caller string, response *http.Response) FutureResponseError {
	return &implFutureResponseError{
		caller:   caller,
		response: response,
	}
}

// implFutureResponseError is the implementation returned by NewFutureResponseError.
type implFutureResponseError struct {
	caller   string
	response *http.Response
}

// Error reports the status code and the caller. A nil response is reported
// explicitly instead of triggering a nil-pointer panic.
func (e *implFutureResponseError) Error() string {
	if e.response == nil {
		return fmt.Sprintf("nil response for '%s'", e.caller)
	}
	return fmt.Sprintf("response status code %d for '%s'", e.response.StatusCode, e.caller)
}

// Response returns the stored *http.Response (possibly nil).
func (e *implFutureResponseError) Response() *http.Response {
	return e.response
}

// JSON is a Response handler that quickly adapts to interfaces with Content-Type: application/json.
// You can directly use *JSON as the return type for response methods in the API Schema to handle
// the JSON data returned by the interface.
90 | // 91 | // NOTE: Not suitable for pagination query interfaces. If your interface involves pagination queries, 92 | // please implement a custom Response handler. 93 | type JSON struct { 94 | Raw json.RawMessage 95 | Res *http.Response 96 | } 97 | 98 | func (j *JSON) Err() (err error) { 99 | if j.Res != nil { 100 | if j.Res.StatusCode != http.StatusOK { 101 | defer j.Res.Body.Close() 102 | var body []byte 103 | if len(j.Raw) > 0 { 104 | body = j.Raw 105 | } else { 106 | body, err = io.ReadAll(j.Res.Body) 107 | if err != nil { 108 | return fmt.Errorf("error reading response body: %w", err) 109 | } 110 | } 111 | return fmt.Errorf("response status code %d with body: \n\n%s\n\n", j.Res.StatusCode, string(body)) 112 | } 113 | } 114 | return nil 115 | } 116 | 117 | func (j *JSON) FromBytes(_ string, bytes []byte) error { 118 | j.Raw = bytes 119 | return nil 120 | } 121 | 122 | func (j *JSON) FromResponse(_ string, r *http.Response) error { 123 | j.Res = r 124 | var ( 125 | ctt = r.Header.Get("Content-Type") 126 | idx = -1 127 | ) 128 | if idx = strings.IndexByte(ctt, ';'); idx < 0 { 129 | idx = len(ctt) 130 | } 131 | defer r.Body.Close() 132 | if ctt = strings.TrimSpace(ctt[:idx]); ctt == "application/json" { 133 | return json.NewDecoder(r.Body).Decode(&j.Raw) 134 | } else { 135 | return fmt.Errorf("response content type %q is not %q", ctt, "application/json") 136 | } 137 | } 138 | 139 | func (j *JSON) ScanValues(vs ...any) error { 140 | for _, v := range vs { 141 | if err := json.Unmarshal(j.Raw, v); err != nil { 142 | return err 143 | } 144 | } 145 | return nil 146 | } 147 | 148 | func (j *JSON) Break() bool { 149 | panic("JSON is not well-suited for pagination query requests") 150 | } 151 | 152 | type GzipReadCloser struct { 153 | R io.ReadCloser 154 | gzipReader *gzip.Reader 155 | } 156 | 157 | func (r *GzipReadCloser) Read(p []byte) (n int, err error) { 158 | if r.gzipReader == nil { 159 | r.gzipReader, err = gzip.NewReader(r.R) 160 | if err != nil { 161 | 
return 0, err 162 | } 163 | } 164 | return r.gzipReader.Read(p) 165 | } 166 | 167 | func (r *GzipReadCloser) Close() error { 168 | if r.gzipReader != nil { 169 | if err := r.gzipReader.Close(); err != nil { 170 | r.R.Close() 171 | return err 172 | } 173 | } 174 | return r.R.Close() 175 | } 176 | -------------------------------------------------------------------------------- /runtime/response_test.go: -------------------------------------------------------------------------------- 1 | package defc 2 | 3 | import ( 4 | "bytes" 5 | "io" 6 | "net/http" 7 | "strconv" 8 | "testing" 9 | ) 10 | 11 | func TestNewResponseError(t *testing.T) { 12 | err := NewResponseError("test", http.StatusInternalServerError, []byte("test")) 13 | if status := err.Status(); status != http.StatusInternalServerError { 14 | t.Errorf("status: %d != %d", status, http.StatusInternalServerError) 15 | return 16 | } 17 | if body := err.Body(); !bytes.Equal(body, []byte("test")) { 18 | t.Errorf("body: %s != %s", string(body), "test") 19 | return 20 | } 21 | expectedString := `response status code ` + 22 | strconv.Itoa(http.StatusInternalServerError) + 23 | ` for 'test' with body: 24 | 25 | test 26 | 27 | ` 28 | if errorString := err.Error(); errorString != expectedString { 29 | t.Errorf("error: unexpected error string => %s", errorString) 30 | return 31 | } 32 | } 33 | 34 | func TestNewFutureResponseError(t *testing.T) { 35 | err := NewFutureResponseError("test", &http.Response{ 36 | StatusCode: http.StatusInternalServerError, 37 | Body: io.NopCloser(bytes.NewReader([]byte("test"))), 38 | }) 39 | response := err.Response() 40 | defer response.Body.Close() 41 | body, _ := io.ReadAll(response.Body) 42 | if !bytes.Equal(body, []byte("test")) { 43 | t.Errorf("body: %s != %s", string(body), "test") 44 | return 45 | } 46 | expectedString := "response status code " + strconv.Itoa(http.StatusInternalServerError) + " for 'test'" 47 | if errorString := err.Error(); errorString != expectedString { 48 | 
t.Errorf("error: unexpected error string => %s", errorString) 49 | return 50 | } 51 | } 52 | 53 | func TestJSON(t *testing.T) { 54 | body := []byte(`{"code": 200}`) 55 | r := &http.Response{ 56 | Body: io.NopCloser(bytes.NewReader(body)), 57 | Header: http.Header{ 58 | "Content-Type": []string{"application/json; charset=utf-8"}, 59 | }, 60 | } 61 | j := new(JSON) 62 | if err := j.Err(); err != nil { 63 | t.Errorf("json: %s", err) 64 | return 65 | } 66 | if err := j.FromBytes("test", body); err != nil { 67 | t.Errorf("json: %s", err) 68 | return 69 | } 70 | if err := j.FromResponse("test", r); err != nil { 71 | t.Errorf("json: %s", err) 72 | return 73 | } 74 | if err := j.ScanValues([]any{}...); err != nil { 75 | t.Errorf("json: %s", err) 76 | return 77 | } 78 | var val struct { 79 | Code int `json:"code"` 80 | } 81 | if err := j.ScanValues(val); err == nil { 82 | t.Errorf("json: expects UnmarshalError, got nil") 83 | return 84 | } else if val.Code != 0 { 85 | t.Errorf("json: %v != 0", val.Code) 86 | return 87 | } 88 | if err := j.ScanValues(&val); err != nil { 89 | t.Errorf("json: %s", err) 90 | return 91 | } else if val.Code != 200 { 92 | t.Errorf("json: %v != 200", val.Code) 93 | return 94 | } 95 | defer func() { 96 | if err := recover(); err == nil { 97 | t.Errorf("json: expects Panic, got nil") 98 | return 99 | } 100 | }() 101 | if j.Break() { 102 | t.Errorf("json: unreachable") 103 | return 104 | } 105 | } 106 | -------------------------------------------------------------------------------- /runtime/split.go: -------------------------------------------------------------------------------- 1 | package defc 2 | 3 | import ( 4 | "strings" 5 | 6 | tok "github.com/x5iu/defc/runtime/token" 7 | ) 8 | 9 | func Count(sql string, ch string) (n int) { 10 | tokens := tok.SplitTokens(sql) 11 | for _, token := range tokens { 12 | if token == ch { 13 | n++ 14 | } 15 | } 16 | return n 17 | } 18 | 19 | func Split(sql string, sep string) (group []string) { 20 | tokens := 
tok.SplitTokens(sql) 21 | group = make([]string, 0, len(tokens)) 22 | last := 0 23 | for i, token := range tokens { 24 | if token == sep || i+1 == len(tokens) { 25 | if joint := tok.MergeSqlTokens(tokens[last : i+1]); len(strings.Trim(joint, sep)) > 0 { 26 | group = append(group, joint) 27 | } 28 | last = i + 1 29 | } 30 | } 31 | return group 32 | } 33 | -------------------------------------------------------------------------------- /runtime/split_test.go: -------------------------------------------------------------------------------- 1 | package defc 2 | 3 | import ( 4 | "reflect" 5 | "testing" 6 | 7 | "github.com/x5iu/defc/runtime/token" 8 | ) 9 | 10 | func TestCount(t *testing.T) { 11 | type TestCase struct { 12 | Name string 13 | Input string 14 | Token string 15 | N int 16 | } 17 | var testcases = []*TestCase{ 18 | { 19 | Name: "?", 20 | Input: "? '\t?' \"\n?\" \r\n? \t?", 21 | Token: "?", 22 | N: 3, 23 | }, 24 | } 25 | for _, testcase := range testcases { 26 | t.Run(testcase.Name, func(t *testing.T) { 27 | if n := Count(testcase.Input, testcase.Token); n != testcase.N { 28 | t.Errorf("count: %d != %d", n, testcase.N) 29 | return 30 | } 31 | }) 32 | } 33 | } 34 | 35 | func TestSplit(t *testing.T) { 36 | type TestCase struct { 37 | Name string 38 | Input string 39 | Sep string 40 | Expect []string 41 | } 42 | var testcases = []*TestCase{ 43 | { 44 | Name: "escape_quotes", 45 | Input: "part1;\r\n \\'part2;\r\n \"\\\"part3\";", 46 | Sep: ";", 47 | Expect: []string{ 48 | "part1;", 49 | " \\'part2;", 50 | " \"\\\"part3\";", 51 | }, 52 | }, 53 | { 54 | Name: "separate_comma", 55 | Input: "part1, part2, part3", 56 | Sep: ",", 57 | Expect: []string{ 58 | "part1,", 59 | " part2,", 60 | " part3", 61 | }, 62 | }, 63 | { 64 | Name: "comma_in_paren", 65 | Input: "(autoincrement,\n\t\t\tname);insert", 66 | Sep: ";", 67 | Expect: []string{ 68 | "(autoincrement, name);", 69 | "insert", 70 | }, 71 | }, 72 | { 73 | Name: "comma_query", 74 | Input: "select id, name from user 
where name in (?, ?);", 75 | Sep: ";", 76 | Expect: []string{ 77 | "select id, name from user where name in (?, ?);", 78 | }, 79 | }, 80 | { 81 | Name: "named_query", 82 | Input: "select id, name from user where id = :id and name = :name;", 83 | Sep: ";", 84 | Expect: []string{ 85 | "select id, name from user where id = :id and name = :name;", 86 | }, 87 | }, 88 | { 89 | Name: "comment_query", 90 | Input: "/* sqlcomment */ select id, name from user where id = :id and name = :name; -- comment; // comment;", 91 | Sep: ";", 92 | Expect: []string{ 93 | "/* sqlcomment */ select id, name from user where id = :id and name = :name;", 94 | " -- comment;", 95 | " // comment;", 96 | }, 97 | }, 98 | { 99 | Name: "at_query", 100 | Input: "select id, name from user where id = @id and name = @name;", 101 | Sep: ";", 102 | Expect: []string{ 103 | "select id, name from user where id = @id and name = @name;", 104 | }, 105 | }, 106 | { 107 | Name: "dollar_query", 108 | Input: "select id, name from user where id = $1 and name = $2;", 109 | Sep: ";", 110 | Expect: []string{ 111 | "select id, name from user where id = $1 and name = $2;", 112 | }, 113 | }, 114 | } 115 | for _, testcase := range testcases { 116 | t.Run(testcase.Name, func(t *testing.T) { 117 | if splitStrings := Split(testcase.Input, testcase.Sep); !reflect.DeepEqual(splitStrings, testcase.Expect) { 118 | t.Errorf("split: %v != %v", splitStrings, testcase.Expect) 119 | return 120 | } 121 | }) 122 | } 123 | } 124 | 125 | func TestSplitTokens(t *testing.T) { 126 | type TestCase struct { 127 | Name string 128 | Input string 129 | Expect []string 130 | } 131 | var testcases = []*TestCase{ 132 | { 133 | Name: "comma_token", 134 | Input: "autoincrement,\n\t\t\tname", 135 | Expect: []string{ 136 | "autoincrement", 137 | ",", 138 | " ", 139 | "name", 140 | }, 141 | }, 142 | { 143 | Name: "question_token", 144 | Input: "in(?,?);", 145 | Expect: []string{ 146 | "in", "(", "?", ",", "?", ")", ";", 147 | }, 148 | }, 149 | { 150 | 
Name: "comment_token", 151 | Input: "# // -- /* */", 152 | Expect: []string{ 153 | "#", " ", "/", "/", " ", "-", "-", " ", "/", "*", " ", "*", "/", 154 | }, 155 | }, 156 | { 157 | Name: "colon_token", 158 | Input: ":id, :name", 159 | Expect: []string{ 160 | ":", "id", 161 | ",", " ", 162 | ":", "name", 163 | }, 164 | }, 165 | } 166 | for _, testcase := range testcases { 167 | t.Run(testcase.Name, func(t *testing.T) { 168 | if tokens := token.SplitTokens(testcase.Input); !reflect.DeepEqual(tokens, testcase.Expect) { 169 | t.Errorf("tokens: %v != %v", tokens, testcase.Expect) 170 | return 171 | } 172 | }) 173 | } 174 | } 175 | -------------------------------------------------------------------------------- /runtime/sqlx.go: -------------------------------------------------------------------------------- 1 | package defc 2 | 3 | import ( 4 | "context" 5 | "database/sql" 6 | 7 | "github.com/x5iu/defc/sqlx" 8 | ) 9 | 10 | var ( 11 | Open = sqlx.Open 12 | MustOpen = sqlx.MustOpen 13 | Connect = sqlx.Connect 14 | ConnectContext = sqlx.ConnectContext 15 | MustConnect = sqlx.MustConnect 16 | NewDB = sqlx.NewDB 17 | 18 | // There is no need to import In from sqlx package, since sqlx.In references defc.In 19 | /* 20 | In = sqlx.In 21 | */ 22 | 23 | Named = sqlx.Named 24 | StructScan = sqlx.StructScan 25 | ScanStruct = sqlx.StructScan 26 | MapScan = sqlx.MapScan 27 | ScanMap = sqlx.MapScan 28 | SliceScan = sqlx.SliceScan 29 | ScanSlice = sqlx.SliceScan 30 | ScanRow = sqlx.ScanRow 31 | ) 32 | 33 | type ( 34 | DB = sqlx.DB 35 | Tx = sqlx.Tx 36 | Row = sqlx.IRow 37 | Rows = sqlx.IRows 38 | FromRow = sqlx.FromRow 39 | FromRows = sqlx.FromRows 40 | ) 41 | 42 | type TxInterface interface { 43 | ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) 44 | GetContext(ctx context.Context, dest any, query string, args ...any) error 45 | SelectContext(ctx context.Context, dest any, query string, args ...any) error 46 | Rollback() error 47 | Commit() error 48 | } 
49 | 50 | type TxRebindInterface interface { 51 | TxInterface 52 | Rebind(query string) string 53 | } 54 | -------------------------------------------------------------------------------- /runtime/sqlx_test.go: -------------------------------------------------------------------------------- 1 | package defc 2 | 3 | import ( 4 | "html/template" 5 | "io" 6 | "strings" 7 | "testing" 8 | ) 9 | 10 | // The bench results show that for ordinary template rendering tasks (such as a single CRUD SQL statement), 11 | // using `bind` takes 10 times longer than using `arguments`, but the time spent is still in the microsecond 12 | // range and can be ignored. However, as the content of the template rendering gradually increases, using 13 | // `bind` will take more and more time and memory; at its peak (i.e., Benchmark_1000), using `bind` requires 14 | // 700 milliseconds, while using `arguments` still maintains a rendering time of 1 microsecond. 15 | // 16 | // Use `arguments` to add parameters whenever possible instead of `bind`. 
17 | // 18 | // ``` 19 | // goos: darwin 20 | // goarch: arm64 21 | // pkg: github.com/x5iu/defc/runtime 22 | // cpu: Apple M1 23 | // BenchmarkBind1-8 64046 15810 ns/op 15786 B/op 215 allocs/op 24 | // BenchmarkArgumentsBind1-8 918460 1224 ns/op 920 B/op 19 allocs/op 25 | // BenchmarkBind10-8 6634 163205 ns/op 154695 B/op 3175 allocs/op 26 | // BenchmarkArgumentsBind10-8 904687 1258 ns/op 920 B/op 19 allocs/op 27 | // BenchmarkBind100-8 100 11570121 ns/op 9696021 B/op 176360 allocs/op 28 | // BenchmarkArgumentsBind100-8 814221 1264 ns/op 921 B/op 19 allocs/op 29 | // BenchmarkBind1000-8 2 722617750 ns/op 969184652 B/op 16180071 allocs/op 30 | // BenchmarkArgumentsBind1000-8 812119 1487 ns/op 927 B/op 19 allocs/op 31 | // ``` 32 | 33 | type Person struct { 34 | Name string 35 | Age int 36 | Gender string 37 | Address string 38 | } 39 | 40 | func BenchmarkBind1(b *testing.B) { benchmarkBind(b, 1) } 41 | func BenchmarkArgumentsBind1(b *testing.B) { benchmarkArgumentsBind(b, 1) } 42 | func BenchmarkBind10(b *testing.B) { benchmarkBind(b, 10) } 43 | func BenchmarkArgumentsBind10(b *testing.B) { benchmarkArgumentsBind(b, 10) } 44 | func BenchmarkBind100(b *testing.B) { benchmarkBind(b, 100) } 45 | func BenchmarkArgumentsBind100(b *testing.B) { benchmarkArgumentsBind(b, 100) } 46 | func BenchmarkBind1000(b *testing.B) { benchmarkBind(b, 1000) } 47 | func BenchmarkArgumentsBind1000(b *testing.B) { benchmarkArgumentsBind(b, 1000) } 48 | 49 | func benchmarkBind(b *testing.B, n int) { 50 | b.ReportAllocs() 51 | const tmplStr = ` 52 | {{ bind .Name }} 53 | {{ bind .Age }} 54 | {{ bind .Gender }} 55 | {{ bind .Address }} 56 | ` 57 | var largeTmplStr strings.Builder 58 | for i := 0; i < n; i++ { 59 | largeTmplStr.WriteString(tmplStr) 60 | } 61 | for i := 0; i < b.N; i++ { 62 | var argListBenchmarkBind []any 63 | bind := func(arg any) string { 64 | argListBenchmarkBind = append(argListBenchmarkBind, arg) 65 | return BindVars(len(MergeArgs(argListBenchmarkBind))) 66 | } 67 | 
funcMap := template.FuncMap{ 68 | "bind": bind, 69 | } 70 | t := template.Must(template.New("BenchmarkBind").Funcs(funcMap).Parse(largeTmplStr.String())) 71 | t.Execute(io.Discard, Person{ 72 | Name: "John", 73 | Age: 20, 74 | Gender: "Male", 75 | Address: "123 Main St, Anytown, USA", 76 | }) 77 | } 78 | } 79 | 80 | func benchmarkArgumentsBind(b *testing.B, n int) { 81 | b.ReportAllocs() 82 | const tmplStr = ` 83 | {{ .args.Bind .Name }} 84 | {{ .args.Bind .Age }} 85 | {{ .args.Bind .Gender }} 86 | {{ .args.Bind .Address }} 87 | ` 88 | var largeTmplStr strings.Builder 89 | for i := 0; i < n; i++ { 90 | largeTmplStr.WriteString(tmplStr) 91 | } 92 | t := template.Must(template.New("BenchmarkArgumentsBind").Parse(largeTmplStr.String())) 93 | for i := 0; i < b.N; i++ { 94 | var argListBenchmarkArgumentsBind Arguments 95 | t.Execute(io.Discard, map[string]any{ 96 | "args": argListBenchmarkArgumentsBind, 97 | "Name": "John", 98 | "Age": 20, 99 | "Gender": "Male", 100 | "Address": "123 Main St, Anytown, USA", 101 | }) 102 | } 103 | } 104 | -------------------------------------------------------------------------------- /runtime/token/token.go: -------------------------------------------------------------------------------- 1 | package token 2 | 3 | import ( 4 | "fmt" 5 | "strings" 6 | 7 | lru "github.com/hashicorp/golang-lru/v2" 8 | ) 9 | 10 | const ( 11 | Space = " " 12 | ) 13 | 14 | const ( 15 | Question = "?" 
16 | Comma = "," 17 | Colon = ":" 18 | Dollar = "$" 19 | At = "@" 20 | Dash = "-" 21 | Div = "/" 22 | Mul = "*" 23 | Underline = "_" 24 | ) 25 | 26 | type Lexer struct { 27 | Raw string 28 | 29 | index int 30 | atsep bool 31 | token string 32 | } 33 | 34 | func (l *Lexer) Next() (next bool) { 35 | l.token, next = l.parse() 36 | return next 37 | } 38 | 39 | func (l *Lexer) Token() string { 40 | return l.token 41 | } 42 | 43 | func (l *Lexer) parse() (string, bool) { 44 | line := l.Raw 45 | 46 | var ( 47 | singleQuoted bool 48 | doubleQuoted bool 49 | backQuoted bool 50 | arg []byte 51 | ) 52 | 53 | for ; l.index < len(line); l.index++ { 54 | switch ch := line[l.index]; ch { 55 | case ':', ';', ',', '(', ')', '[', ']', '{', '}', '.', '=', '?', '+', '-', '*', '/', '>', '<', '!', '~', '%', '@', '&', '|': 56 | if doubleQuoted || singleQuoted || backQuoted { 57 | if l.atsep { 58 | panic("in various quotation marks, `atsep` should not be set") 59 | } 60 | arg = append(arg, ch) 61 | } else { 62 | if len(arg) > 0 { 63 | if l.atsep { 64 | panic("when the symbol is immediately adjacent to other tokens, `atsep` should not be set") 65 | } 66 | return string(arg), true 67 | } 68 | if l.atsep { 69 | l.atsep = false 70 | return Space, true 71 | } 72 | l.index++ 73 | return string(ch), true 74 | } 75 | case ' ', '\t', '\n', '\r': 76 | if doubleQuoted || singleQuoted || backQuoted { 77 | if l.atsep { 78 | panic("in various quotation marks, `atsep` should not be set") 79 | } 80 | arg = append(arg, ch) 81 | } else if len(arg) > 0 { 82 | if l.atsep { 83 | panic("this is the first encounter with a space, `atsep` should not be set") 84 | } 85 | l.atsep = true 86 | return string(arg), true 87 | } else { 88 | l.atsep = true 89 | } 90 | case '"': 91 | if !(l.index > 0 && line[l.index-1] == '\\' || singleQuoted || backQuoted) { 92 | if !doubleQuoted { 93 | if l.atsep { 94 | l.atsep = false 95 | return Space, true 96 | } 97 | } 98 | doubleQuoted = !doubleQuoted 99 | } 100 | arg = append(arg, 
ch) 101 | if !doubleQuoted { 102 | l.index++ 103 | return string(arg), true 104 | } 105 | case '\'': 106 | if !(l.index > 0 && line[l.index-1] == '\\' || doubleQuoted || backQuoted) { 107 | if !singleQuoted { 108 | if l.atsep { 109 | l.atsep = false 110 | return Space, true 111 | } 112 | } 113 | singleQuoted = !singleQuoted 114 | } 115 | arg = append(arg, ch) 116 | if !singleQuoted { 117 | l.index++ 118 | return string(arg), true 119 | } 120 | case '`': 121 | if !(l.index > 0 && line[l.index-1] == '\\' || singleQuoted || doubleQuoted) { 122 | if !backQuoted { 123 | if l.atsep { 124 | l.atsep = false 125 | return Space, true 126 | } 127 | } 128 | backQuoted = !backQuoted 129 | } 130 | arg = append(arg, ch) 131 | if !backQuoted { 132 | l.index++ 133 | return string(arg), true 134 | } 135 | default: 136 | if l.atsep { 137 | l.atsep = false 138 | return Space, true 139 | } 140 | arg = append(arg, ch) 141 | } 142 | } 143 | 144 | if len(arg) > 0 { 145 | return string(arg), true 146 | } 147 | 148 | return "", false 149 | } 150 | 151 | func MergeSqlTokens(tokens []string) string { 152 | n := 0 153 | for _, token := range tokens { 154 | n += len(token) 155 | } 156 | var merged strings.Builder 157 | merged.Grow(n) 158 | for _, token := range tokens { 159 | merged.WriteString(token) 160 | } 161 | return merged.String() 162 | } 163 | 164 | var splitTokensCache *lru.TwoQueueCache[string, []string] 165 | 166 | func init() { 167 | var err error 168 | if splitTokensCache, err = lru.New2Q[string, []string](1024); err != nil { 169 | panic(fmt.Errorf("failed to init lru cache: %w", err)) 170 | } 171 | } 172 | 173 | func SplitTokens(line string) (tokens []string) { 174 | tokens, exists := splitTokensCache.Get(line) 175 | if exists { 176 | return tokens 177 | } 178 | l := Lexer{Raw: line} 179 | for l.Next() { 180 | tokens = append(tokens, l.Token()) 181 | } 182 | splitTokensCache.Add(line, tokens) 183 | return tokens 184 | } 185 | 
--------------------------------------------------------------------------------
/runtime/version.go:
--------------------------------------------------------------------------------
package defc

// Version is the current version string of defc.
const Version = "v1.39.0"

--------------------------------------------------------------------------------
/sqlx/reflectx/reflectx.gen.go:
--------------------------------------------------------------------------------
// Code generated by golang.org/x/tools/cmd/bundle. DO NOT EDIT.
//go:generate bundle -o reflectx.gen.go -pkg reflectx -prefix " " github.com/x5iu/sqlx/reflectx

// Package reflectx implements extensions to the standard reflect lib suitable
// for implementing marshalling and unmarshalling packages. The main Mapper type
// allows for Go-compatible named attribute access, including accessing embedded
// struct attributes and the ability to use functions and struct tags to
// customize field names.
//

package reflectx

import (
	"reflect"
	"runtime"
	"strings"
	"sync"
)

// A FieldInfo is metadata for a struct field.
type FieldInfo struct {
	// Index is the field-index traversal from the root struct to this field.
	Index []int
	// Path is the dot-separated chain of mapped names from the root. An
	// embedded struct without a tag contributes no path segment.
	Path string
	// Field is the reflected struct field itself.
	Field reflect.StructField
	// Zero is a zero value of the field's type.
	Zero reflect.Value
	// Name is the mapped name (from the tag, or the possibly mapped Go name).
	Name string
	// Options holds the tag options that follow the name, e.g. "omitempty".
	Options map[string]string
	// Embedded is true for anonymous (embedded) fields.
	Embedded bool
	// Children is indexed by field position; entries stay nil for skipped fields.
	Children []*FieldInfo
	// Parent is the FieldInfo of the enclosing struct (the root for top-level fields).
	Parent *FieldInfo
}

// A StructMap is an index of field metadata for a struct.
type StructMap struct {
	// Tree is the root node, representing the struct itself.
	Tree *FieldInfo
	// Index lists every mapped field in breadth-first discovery order.
	Index []*FieldInfo
	// Paths maps a field's Path to its FieldInfo.
	Paths map[string]*FieldInfo
	// Names also maps by Path (not by bare Name). It holds only
	// non-embedded fields with a non-empty Name.
	Names map[string]*FieldInfo
}

// GetByPath returns a *FieldInfo for a given string path, or nil if no field
// has that path.
func (f StructMap) GetByPath(path string) *FieldInfo {
	return f.Paths[path]
}

// GetByTraversal returns a *FieldInfo for a given integer path. It is
// analogous to reflect.FieldByIndex, but using the cached traversal
// rather than re-executing the reflect machinery each time.
// It returns nil when index is empty or when any step does not lead to a
// mapped child field.
49 | func (f StructMap) GetByTraversal(index []int) *FieldInfo { 50 | if len(index) == 0 { 51 | return nil 52 | } 53 | 54 | tree := f.Tree 55 | for _, i := range index { 56 | if i >= len(tree.Children) || tree.Children[i] == nil { 57 | return nil 58 | } 59 | tree = tree.Children[i] 60 | } 61 | return tree 62 | } 63 | 64 | // Mapper is a general purpose mapper of names to struct fields. A Mapper 65 | // behaves like most marshallers in the standard library, obeying a field tag 66 | // for name mapping but also providing a basic transform function. 67 | type Mapper struct { 68 | cache map[reflect.Type]*StructMap 69 | tagName string 70 | tagMapFunc func(string) string 71 | mapFunc func(string) string 72 | mutex sync.Mutex 73 | } 74 | 75 | // NewMapper returns a new mapper using the tagName as its struct field tag. 76 | // If tagName is the empty string, it is ignored. 77 | func NewMapper(tagName string) *Mapper { 78 | return &Mapper{ 79 | cache: make(map[reflect.Type]*StructMap), 80 | tagName: tagName, 81 | } 82 | } 83 | 84 | // NewMapperTagFunc returns a new mapper which contains a mapper for field names 85 | // AND a mapper for tag values. This is useful for tags like json which can 86 | // have values like "name,omitempty". 87 | func NewMapperTagFunc(tagName string, mapFunc, tagMapFunc func(string) string) *Mapper { 88 | return &Mapper{ 89 | cache: make(map[reflect.Type]*StructMap), 90 | tagName: tagName, 91 | mapFunc: mapFunc, 92 | tagMapFunc: tagMapFunc, 93 | } 94 | } 95 | 96 | // NewMapperFunc returns a new mapper which optionally obeys a field tag and 97 | // a struct field name mapper func given by f. 
Tags will take precedence, but 98 | // for any other field, the mapped name will be f(field.Name) 99 | func NewMapperFunc(tagName string, f func(string) string) *Mapper { 100 | return &Mapper{ 101 | cache: make(map[reflect.Type]*StructMap), 102 | tagName: tagName, 103 | mapFunc: f, 104 | } 105 | } 106 | 107 | // TypeMap returns a mapping of field strings to int slices representing 108 | // the traversal down the struct to reach the field. 109 | func (m *Mapper) TypeMap(t reflect.Type) *StructMap { 110 | m.mutex.Lock() 111 | mapping, ok := m.cache[t] 112 | if !ok { 113 | mapping = getMapping(t, m.tagName, m.mapFunc, m.tagMapFunc) 114 | m.cache[t] = mapping 115 | } 116 | m.mutex.Unlock() 117 | return mapping 118 | } 119 | 120 | // FieldMap returns the mapper's mapping of field names to reflect values. Panics 121 | // if v's Kind is not Struct, or v is not Indirectable to a struct kind. 122 | func (m *Mapper) FieldMap(v reflect.Value) map[string]reflect.Value { 123 | v = reflect.Indirect(v) 124 | mustBe(v, reflect.Struct) 125 | 126 | r := map[string]reflect.Value{} 127 | tm := m.TypeMap(v.Type()) 128 | for tagName, fi := range tm.Names { 129 | r[tagName] = FieldByIndexes(v, fi.Index) 130 | } 131 | return r 132 | } 133 | 134 | // FieldByName returns a field by its mapped name as a reflect.Value. 135 | // Panics if v's Kind is not Struct or v is not Indirectable to a struct Kind. 136 | // Returns zero Value if the name is not found. 137 | func (m *Mapper) FieldByName(v reflect.Value, name string) reflect.Value { 138 | v = reflect.Indirect(v) 139 | mustBe(v, reflect.Struct) 140 | 141 | tm := m.TypeMap(v.Type()) 142 | fi, ok := tm.Names[name] 143 | if !ok { 144 | return v 145 | } 146 | return FieldByIndexes(v, fi.Index) 147 | } 148 | 149 | // FieldsByName returns a slice of values corresponding to the slice of names 150 | // for the value. Panics if v's Kind is not Struct or v is not Indirectable 151 | // to a struct Kind. Returns zero Value for each name not found. 
152 | func (m *Mapper) FieldsByName(v reflect.Value, names []string) []reflect.Value { 153 | v = reflect.Indirect(v) 154 | mustBe(v, reflect.Struct) 155 | 156 | tm := m.TypeMap(v.Type()) 157 | vals := make([]reflect.Value, 0, len(names)) 158 | for _, name := range names { 159 | fi, ok := tm.Names[name] 160 | if !ok { 161 | vals = append(vals, *new(reflect.Value)) 162 | } else { 163 | vals = append(vals, FieldByIndexes(v, fi.Index)) 164 | } 165 | } 166 | return vals 167 | } 168 | 169 | // TraversalsByName returns a slice of int slices which represent the struct 170 | // traversals for each mapped name. Panics if t is not a struct or Indirectable 171 | // to a struct. Returns empty int slice for each name not found. 172 | func (m *Mapper) TraversalsByName(t reflect.Type, names []string) [][]int { 173 | r := make([][]int, 0, len(names)) 174 | m.TraversalsByNameFunc(t, names, func(_ int, i []int) error { 175 | if i == nil { 176 | r = append(r, []int{}) 177 | } else { 178 | r = append(r, i) 179 | } 180 | 181 | return nil 182 | }) 183 | return r 184 | } 185 | 186 | // TraversalsByNameFunc traverses the mapped names and calls fn with the index of 187 | // each name and the struct traversal represented by that name. Panics if t is not 188 | // a struct or Indirectable to a struct. Returns the first error returned by fn or nil. 189 | func (m *Mapper) TraversalsByNameFunc(t reflect.Type, names []string, fn func(int, []int) error) error { 190 | t = Deref(t) 191 | mustBe(t, reflect.Struct) 192 | tm := m.TypeMap(t) 193 | for i, name := range names { 194 | fi, ok := tm.Names[name] 195 | if !ok { 196 | if err := fn(i, nil); err != nil { 197 | return err 198 | } 199 | } else { 200 | if err := fn(i, fi.Index); err != nil { 201 | return err 202 | } 203 | } 204 | } 205 | return nil 206 | } 207 | 208 | // FieldByIndexes returns a value for the field given by the struct traversal 209 | // for the given value. 
210 | func FieldByIndexes(v reflect.Value, indexes []int) reflect.Value { 211 | for _, i := range indexes { 212 | v = reflect.Indirect(v).Field(i) 213 | // if this is a pointer and it's nil, allocate a new value and set it 214 | if v.Kind() == reflect.Ptr && v.IsNil() { 215 | alloc := reflect.New(Deref(v.Type())) 216 | v.Set(alloc) 217 | } 218 | if v.Kind() == reflect.Map && v.IsNil() { 219 | v.Set(reflect.MakeMap(v.Type())) 220 | } 221 | } 222 | return v 223 | } 224 | 225 | // FieldByIndexesReadOnly returns a value for a particular struct traversal, 226 | // but is not concerned with allocating nil pointers because the value is 227 | // going to be used for reading and not setting. 228 | func FieldByIndexesReadOnly(v reflect.Value, indexes []int) reflect.Value { 229 | for _, i := range indexes { 230 | v = reflect.Indirect(v).Field(i) 231 | } 232 | return v 233 | } 234 | 235 | // Deref is Indirect for reflect.Types 236 | func Deref(t reflect.Type) reflect.Type { 237 | if t.Kind() == reflect.Ptr { 238 | t = t.Elem() 239 | } 240 | return t 241 | } 242 | 243 | // -- helpers & utilities -- 244 | 245 | type kinder interface { 246 | Kind() reflect.Kind 247 | } 248 | 249 | // mustBe checks a value against a kind, panicing with a reflect.ValueError 250 | // if the kind isn't that which is required. 251 | func mustBe(v kinder, expected reflect.Kind) { 252 | if k := v.Kind(); k != expected { 253 | panic(&reflect.ValueError{Method: methodName(), Kind: k}) 254 | } 255 | } 256 | 257 | // methodName returns the caller of the function calling methodName 258 | func methodName() string { 259 | pc, _, _, _ := runtime.Caller(2) 260 | f := runtime.FuncForPC(pc) 261 | if f == nil { 262 | return "unknown method" 263 | } 264 | return f.Name() 265 | } 266 | 267 | type typeQueue struct { 268 | t reflect.Type 269 | fi *FieldInfo 270 | pp string // Parent path 271 | } 272 | 273 | // A copying append that creates a new slice each time. 
func apnd(is []int, i int) []int {
	x := make([]int, len(is)+1)
	copy(x, is)
	x[len(x)-1] = i
	return x
}

// mapf transforms a Go field name or a raw tag value into a target name.
type mapf func(string) string

// parseName parses the tag and the target name for the given field using
// the tagName (eg 'json' for `json:"foo"` tags), mapFunc for mapping the
// field's name to a target name, and tagMapFunc for mapping the tag to
// a target name.
//
// The returned tag is the full (possibly mapped) tag value, including
// options. It is "" when tagName is empty or the field does not carry that
// tag key. fieldName is the part of the tag before the first comma, or
// otherwise the field's Go name passed through mapFunc.
func parseName(field reflect.StructField, tagName string, mapFunc, tagMapFunc mapf) (tag, fieldName string) {
	// Start from the Go field name, optionally transformed by mapFunc.
	fieldName = field.Name
	if mapFunc != nil {
		fieldName = mapFunc(fieldName)
	}

	// No tag key configured: the (mapped) field name is the answer.
	if tagName == "" {
		return "", fieldName
	}

	// Use the tag only when the field really carries tagName in the
	// conventional `key:"value"` format. Lookup parses the tag properly. A
	// plain substring test for tagName+":" also matches other keys ending in
	// tagName (e.g. `mydb:"x"` when tagName is "db") and malformed tags. Get
	// then returned "", which wrongly produced an empty field name.
	tag, ok := field.Tag.Lookup(tagName)
	if !ok {
		return "", fieldName
	}

	// tagMapFunc sees the whole tag value, options included, before the
	// name is split off.
	if tagMapFunc != nil {
		tag = tagMapFunc(tag)
	}

	// The name is everything before the first comma; the rest are options.
	fieldName, _, _ = strings.Cut(tag, ",")

	return tag, fieldName
}

// parseOptions parses the comma-separated options that follow the name in a
// tag value; the name itself is skipped. A "key=value" option maps key to
// everything after the first '=', so values may contain '=' themselves
// ("default=a=b" yields {"default": "a=b"}). A bare option maps to "".
func parseOptions(tag string) map[string]string {
	// strings.Split always returns at least one element, so parts[1:] is safe.
	parts := strings.Split(tag, ",")
	options := make(map[string]string, len(parts))
	for _, opt := range parts[1:] {
		key, value, _ := strings.Cut(opt, "=")
		options[key] = value
	}
	return options
}

// getMapping returns a mapping for the t type, using the tagName, mapFunc and
// tagMapFunc to determine the canonical names of fields.
func getMapping(t reflect.Type, tagName string, mapFunc, tagMapFunc mapf) *StructMap {
	// m collects every mapped field in breadth-first discovery order.
	m := []*FieldInfo{}

	// root stands for the struct itself; its Field is the zero StructField.
	root := &FieldInfo{}
	queue := []typeQueue{}
	queue = append(queue, typeQueue{Deref(t), root, ""})

QueueLoop:
	for len(queue) != 0 {
		// pop the first item off of the queue
		tq := queue[0]
		queue = queue[1:]

		// ignore recursive field: stop expanding a node whose declared type
		// equals an ancestor's declared type. Otherwise self-referential
		// structs would keep the BFS running forever. The node itself was
		// already recorded by its parent; only its children are skipped.
		for p := tq.fi.Parent; p != nil; p = p.Parent {
			if tq.fi.Field.Type == p.Field.Type {
				continue QueueLoop
			}
		}

		nChildren := 0
		if tq.t.Kind() == reflect.Struct {
			nChildren = tq.t.NumField()
		}
		// Children is indexed by field position; skipped fields stay nil.
		tq.fi.Children = make([]*FieldInfo, nChildren)

		// iterate through all of its fields
		for fieldPos := 0; fieldPos < nChildren; fieldPos++ {

			f := tq.t.Field(fieldPos)

			// parse the tag and the target name using the mapping options for this field
			tag, name := parseName(f, tagName, mapFunc, tagMapFunc)

			// if the name is "-", disabled via a tag, skip it
			if name == "-" {
				continue
			}

			// fi is declared fresh on every iteration, so each &fi taken
			// below points to a distinct FieldInfo.
			fi := FieldInfo{
				Field:   f,
				Name:    name,
				Zero:    reflect.New(f.Type).Elem(),
				Options: parseOptions(tag),
			}

			// if the path is empty this path is just the name
			if tq.pp == "" {
				fi.Path = fi.Name
			} else {
				fi.Path = tq.pp + "." + fi.Name
			}

			// skip unexported fields (embedded ones are still walked so their
			// exported fields get promoted)
			if len(f.PkgPath) != 0 && !f.Anonymous {
				continue
			}

			// bfs search of anonymous embedded structs
			if f.Anonymous {
				// An untagged embedded struct adds no path segment, so its
				// fields appear at the parent's level. A tagged one nests
				// under its own path.
				pp := tq.pp
				if tag != "" {
					pp = fi.Path
				}

				fi.Embedded = true
				fi.Index = apnd(tq.fi.Index, fieldPos)
				// This nChildren shadows the outer one and sizes only this
				// embedded field's Children.
				nChildren := 0
				ft := Deref(f.Type)
				if ft.Kind() == reflect.Struct {
					nChildren = ft.NumField()
				}
				fi.Children = make([]*FieldInfo, nChildren)
				queue = append(queue, typeQueue{Deref(f.Type), &fi, pp})
			} else if fi.Zero.Kind() == reflect.Struct || (fi.Zero.Kind() == reflect.Ptr && fi.Zero.Type().Elem().Kind() == reflect.Struct) {
				// Named struct (or pointer-to-struct) fields are expanded under
				// their own path.
				fi.Index = apnd(tq.fi.Index, fieldPos)
				fi.Children = make([]*FieldInfo, Deref(f.Type).NumField())
				queue = append(queue, typeQueue{Deref(f.Type), &fi, fi.Path})
			}

			// Parent is assigned after &fi may have been queued above. That is
			// fine: the queue holds the pointer, and the entry is dequeued
			// only after this assignment.
			fi.Index = apnd(tq.fi.Index, fieldPos)
			fi.Parent = tq.fi
			tq.fi.Children[fieldPos] = &fi
			m = append(m, &fi)
		}
	}

	flds := &StructMap{Index: m, Tree: root, Paths: map[string]*FieldInfo{}, Names: map[string]*FieldInfo{}}
	for _, fi := range flds.Index {
		// check if nothing has already been pushed with the same path
		// sometimes you can choose to override a type using embedded struct
		// (BFS order means shallower fields are seen first and win, except
		// that an embedded entry may be replaced by a later one)
		fld, ok := flds.Paths[fi.Path]
		if !ok || fld.Embedded {
			flds.Paths[fi.Path] = fi
			// Names is keyed by the full Path as well, and only holds
			// non-embedded fields that have a name.
			if fi.Name != "" && !fi.Embedded {
				flds.Names[fi.Path] = fi
			}
		}
	}

	return flds
}

--------------------------------------------------------------------------------
/sqlx/sqlx_test.go:
--------------------------------------------------------------------------------
package sqlx

import (
	"fmt"
	"testing"
)

// TestRebind checks that Rebind rewrites '?' placeholders into the
// placeholder syntax of each bind type, numbering them left to right, and
// leaves queries unchanged for QUESTION and UNKNOWN.
func TestRebind(t *testing.T) {
	testCases := []struct {
		Name     string
		BindType int
		Query    string
		Want     string
	}{
		{
			Name:     "QuestionBindType",
			BindType: QUESTION,
			Query:    "SELECT * FROM users WHERE id = ?",
			Want:     "SELECT * FROM users WHERE id = ?",
		},
		{
			Name:     "UnknownBindType",
			BindType: UNKNOWN,
			Query:    "SELECT * FROM users WHERE id = ?",
			Want:     "SELECT * FROM users WHERE id = ?",
		},
		{
			Name:     "DollarBindType",
			BindType: DOLLAR,
			Query:    "SELECT * FROM users WHERE id = ? AND name = ?",
			Want:     "SELECT * FROM users WHERE id = $1 AND name = $2",
		},
		{
			Name:     "NamedBindType",
			BindType: NAMED,
			Query:    "SELECT * FROM users WHERE id = ? AND name = ?",
			Want:     "SELECT * FROM users WHERE id = :arg1 AND name = :arg2",
		},
		{
			Name:     "AtBindType",
			BindType: AT,
			Query:    "SELECT * FROM users WHERE id = ? AND name = ?",
			Want:     "SELECT * FROM users WHERE id = @p1 AND name = @p2",
		},
		{
			Name:     "ComplexQuery",
			BindType: DOLLAR,
			Query:    "SELECT * FROM users WHERE id IN (?, ?) AND name LIKE ? AND age > ?",
			Want:     "SELECT * FROM users WHERE id IN ($1, $2) AND name LIKE $3 AND age > $4",
		},
	}

	for _, tc := range testCases {
		// Subtests run sequentially (no t.Parallel), so capturing tc is safe
		// even with pre-Go 1.22 loop-variable semantics.
		t.Run(tc.Name, func(t *testing.T) {
			got := Rebind(tc.BindType, tc.Query)
			if got != tc.Want {
				t.Errorf("Rebind(%q, %q) = %q, want %q", bindTypeToString(tc.BindType), tc.Query, got, tc.Want)
			}
		})
	}
}

// bindTypeToString returns a short placeholder marker for bindType, used in
// failure messages. It panics on a bind type it does not know.
func bindTypeToString(bindType int) string {
	switch bindType {
	case QUESTION:
		return "?"
	case UNKNOWN:
		return "?"
	case DOLLAR:
		return "$"
	case NAMED:
		return ":"
	case AT:
		return "@"
	}
	panic(fmt.Sprintf("unknown bind type: %d", bindType))
}

--------------------------------------------------------------------------------
/test.sh:
--------------------------------------------------------------------------------
#!/bin/bash

# Run the test suites in order; each && stops the script at the first failure.
# The integration tests are built only with the "test" build tag.
go test -cover ./gen && \
go test -cover ./runtime && \
go test -cover ./sqlx && \
go test -tags=test ./gen/integration
--------------------------------------------------------------------------------