├── go.mod ├── main.go ├── .gitignore ├── main_go112.go ├── passes └── sqlrows │ ├── sqlrows_test.go │ ├── testdata │ └── src │ │ ├── b │ │ └── b.go │ │ └── a │ │ └── a.go │ └── sqlrows.go ├── .golangci.yml ├── go.sum ├── LICENSE ├── .circleci └── config.yml ├── README.md └── sqlrowsutil └── util.go /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/gostaticanalysis/sqlrows 2 | 3 | go 1.12 4 | 5 | require ( 6 | github.com/gostaticanalysis/analysisutil v0.0.0-20190329151158-56bca42c7635 7 | golang.org/x/tools v0.0.0-20190401205534-4c644d7e323d 8 | ) 9 | -------------------------------------------------------------------------------- /main.go: -------------------------------------------------------------------------------- 1 | // +build !go1.12 2 | 3 | package main 4 | 5 | import ( 6 | "github.com/gostaticanalysis/sqlrows/passes/sqlrows" 7 | "golang.org/x/tools/go/analysis/singlechecker" 8 | ) 9 | // main runs sqlrows as a standalone checker (this file is built only with Go older than 1.12). 10 | func main() { singlechecker.Main(sqlrows.Analyzer) } 11 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Binaries for programs and plugins 2 | *.exe 3 | *.exe~ 4 | *.dll 5 | *.so 6 | *.dylib 7 | 8 | # Test binary, build with `go test -c` 9 | *.test 10 | 11 | # Output of the go coverage tool, specifically when used with LiteIDE 12 | *.out 13 | 14 | vendor/ 15 | -------------------------------------------------------------------------------- /main_go112.go: -------------------------------------------------------------------------------- 1 | // +build go1.12 2 | 3 | package main 4 | 5 | import ( 6 | "github.com/gostaticanalysis/sqlrows/passes/sqlrows" 7 | "golang.org/x/tools/go/analysis" 8 | "golang.org/x/tools/go/analysis/unitchecker" 9 | ) 10 | 11 | // analyzers returns the analyzers this vet tool runs: only sqlrows.Analyzer.
12 | func analyzers() []*analysis.Analyzer { 13 | return []*analysis.Analyzer{ 14 | sqlrows.Analyzer, 15 | } 16 | } 17 | 18 | func main() { 19 | unitchecker.Main(analyzers()...) 20 | } 21 | -------------------------------------------------------------------------------- /passes/sqlrows/sqlrows_test.go: -------------------------------------------------------------------------------- 1 | package sqlrows_test 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/gostaticanalysis/sqlrows/passes/sqlrows" 7 | "golang.org/x/tools/go/analysis/analysistest" 8 | ) 9 | 10 | func Test(t *testing.T) { 11 | testdata := analysistest.TestData() 12 | analysistest.Run(t, testdata, sqlrows.Analyzer, "a") 13 | } 14 | 15 | func TestWithCheckErr(t *testing.T) { 16 | testdata := analysistest.TestData() 17 | 18 | analyzer := sqlrows.Analyzer 19 | analyzer.Flags.Set("checkerr", "true") 20 | 21 | analysistest.Run(t, testdata, analyzer, "b") 22 | } 23 | -------------------------------------------------------------------------------- /passes/sqlrows/testdata/src/b/b.go: -------------------------------------------------------------------------------- 1 | package b 2 | 3 | import ( 4 | "context" 5 | "database/sql" 6 | "fmt" 7 | "log" 8 | ) 9 | 10 | func goodQueryContext() { 11 | var ctx context.Context 12 | var db *sql.DB 13 | rows, err := db.QueryContext(ctx, "SELECT * FROM users") 14 | if err != nil { 15 | log.Fatal(err) 16 | } 17 | defer rows.Close() 18 | 19 | for rows.Next() { 20 | var id int64 21 | if err := rows.Scan(&id); err != nil { 22 | panic(err) 23 | } 24 | } 25 | 26 | if err := rows.Err(); err != nil { 27 | panic(err) 28 | } 29 | } 30 | 31 | func badQueryContext() { 32 | var ctx context.Context 33 | var db *sql.DB 34 | 35 | rows, err := db.QueryContext(ctx, "SELECT * FROM users") // want "rows.Err must be called" 36 | if err != nil { 37 | log.Fatal(err) 38 | } 39 | defer rows.Close() 40 | } 41 | 42 | func skip() { 43 | fmt.Print("skip") 44 | } 45 | 
-------------------------------------------------------------------------------- /.golangci.yml: -------------------------------------------------------------------------------- 1 | linters-settings: 2 | govet: 3 | check-shadowing: true 4 | golint: 5 | min-confidence: 0 6 | maligned: 7 | suggest-new: true 8 | dupl: 9 | threshold: 100 10 | goconst: 11 | min-len: 2 12 | min-occurrences: 2 13 | misspell: 14 | locale: US 15 | lll: 16 | line-length: 140 17 | gocritic: 18 | enabled-tags: 19 | - performance 20 | - style 21 | - experimental 22 | disabled-checks: 23 | - wrapperFunc 24 | 25 | linters: 26 | enable-all: true 27 | disable: 28 | - maligned 29 | - prealloc 30 | - gochecknoglobals 31 | 32 | run: 33 | skip-dirs: 34 | - passes/bodyclose/testdata 35 | 36 | issues: 37 | exclude-rules: 38 | - text: "weak cryptographic primitive" 39 | linters: 40 | - gosec 41 | 42 | service: 43 | golangci-lint-version: 1.15.x # use the fixed version to not introduce new linters unexpectedly -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | github.com/gostaticanalysis/analysisutil v0.0.0-20190329151158-56bca42c7635 h1:I/ckdXlVHde3unRCAcN/Tcpu7LFwgvyHqnFTeklC9oA= 2 | github.com/gostaticanalysis/analysisutil v0.0.0-20190329151158-56bca42c7635/go.mod h1:eEOZF4jCKGi+aprrirO9e7WKB3beBRtWgqGunKl6pKE= 3 | golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= 4 | golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= 5 | golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= 6 | golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= 7 | golang.org/x/tools v0.0.0-20190311215038-5c2858a9cfe5/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= 8 | golang.org/x/tools 
v0.0.0-20190401205534-4c644d7e323d h1:OhjAiv3biHZzq8VoCB0zBsrzEhwLgp4Gpeby/w9ZGxU= 9 | golang.org/x/tools v0.0.0-20190401205534-4c644d7e323d/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= 10 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 GoStaticAnalysis 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /.circleci/config.yml: -------------------------------------------------------------------------------- 1 | version: 2.1 2 | 3 | orbs: 4 | go-module: timakin/go-module@0.3.0 5 | golangci-lint: timakin/golangci-lint@0.1.0 6 | 7 | executors: 8 | default: 9 | working_directory: /go/src/github.com/gostaticanalysis/sqlrows 10 | docker: 11 | - image: circleci/golang:1.12 12 | environment: 13 | - GO111MODULE: "on" 14 | 15 | jobs: 16 | test: 17 | executor: 18 | name: default 19 | steps: 20 | - attach_workspace: 21 | at: /go/src/github.com/gostaticanalysis/sqlrows 22 | - run: go test ./... 23 | 24 | workflows: 25 | version: 2 26 | setup_and_deploy: 27 | jobs: 28 | - go-module/download: 29 | executor: default 30 | checkout: true 31 | vendoring: true 32 | persist-to-workspace: true 33 | workspace-root: /go/src/github.com/gostaticanalysis/sqlrows 34 | - test: 35 | requires: 36 | - go-module/download 37 | - golangci-lint/lint: 38 | checkout: false 39 | attach-workspace: true 40 | workspace-root: /go/src/github.com/gostaticanalysis/sqlrows 41 | working-directory: /go/src/github.com/gostaticanalysis/sqlrows 42 | requires: 43 | - go-module/download 44 | filters: 45 | branches: 46 | ignore: 47 | - master 48 | -------------------------------------------------------------------------------- /passes/sqlrows/testdata/src/a/a.go: -------------------------------------------------------------------------------- 1 | package a 2 | 3 | import ( 4 | "context" 5 | "database/sql" 6 | "fmt" 7 | "log" 8 | ) 9 | // goodQueryContext checks err before deferring rows.Close, so nothing is reported. 10 | func goodQueryContext() { 11 | var ctx context.Context 12 | var db *sql.DB 13 | rows, err := db.QueryContext(ctx, "SELECT * FROM users") 14 | if err != nil { 15 | log.Fatal(err) 16 | } 17 | defer rows.Close() 18 | } 19 | // badQueryContext defers rows.Close before checking err, which is reported. 20 | func badQueryContext() { 21 | var ctx context.Context 22 | var db *sql.DB 23 | 24 | rows, err := db.QueryContext(ctx, "SELECT * FROM users") 25 | defer rows.Close() // want
"using rows before checking for errors" 26 | if err != nil { 27 | log.Fatal(err) 28 | } 29 | } 30 | // closeNotCalled never closes its rows, whether they are used or discarded; both are reported. 31 | func closeNotCalled() { 32 | var ctx context.Context 33 | var db *sql.DB 34 | 35 | rows, err := db.QueryContext(ctx, "SELECT * FROM users") // want "rows.Close must be called" 36 | if err != nil { 37 | log.Fatal(err) 38 | } 39 | fmt.Print(rows) 40 | 41 | _, err = db.QueryContext(ctx, "SELECT * FROM users") // want "rows.Close must be called" 42 | if err != nil { 43 | log.Fatal(err) 44 | } 45 | } 46 | // issue1: deferring rows.Close after the error check is accepted, so nothing is reported. 47 | func issue1() { 48 | readDB, err := sql.Open("mysql", "root:root@tcp(localhost:3306)/mysql?parseTime=true&charset=utf8mb4") 49 | if err != nil { 50 | panic(err.Error()) 51 | } 52 | 53 | rows, err := readDB.Query("SELECT 1") 54 | if err != nil { 55 | panic(err) 56 | } 57 | defer rows.Close() // OK 58 | } 59 | // issue3 closes its rows without defer, which is reported. 60 | func issue3() { 61 | readDB, err := sql.Open("mysql", "root:root@tcp(localhost:3306)/mysql?parseTime=true&charset=utf8mb4") 62 | if err != nil { 63 | panic(err.Error()) 64 | } 65 | 66 | rows, err := readDB.Query("SELECT 1") // want "rows.Close must be called in defer function" 67 | rows.Close() 68 | } 69 | // skip does not use database/sql and must produce no diagnostics. 70 | func skip() { 71 | fmt.Print("skip") 72 | } 73 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # sqlrows 2 | 3 | [![CircleCI](https://circleci.com/gh/gostaticanalysis/sqlrows.svg?style=svg)](https://circleci.com/gh/gostaticanalysis/sqlrows) 4 | 5 | `sqlrows` is a static code analyzer which helps uncover bugs by reporting a diagnostic for mistakes of `sql.Rows` usage. 6 | 7 | ## Install 8 | 9 | You can get `sqlrows` by `go get` command. 10 | 11 | ```bash 12 | $ go get -u github.com/gostaticanalysis/sqlrows 13 | ``` 14 | 15 | ## QuickStart 16 | 17 | `sqlrows` run with `go vet` as below when Go is 1.12 and higher. 18 | 19 | ```bash 20 | $ go vet -vettool=$(which sqlrows) github.com/you/sample_api/...
21 | ```
22 | 
23 | With Go older than 1.12, run the `sqlrows` command directly with the package name (import path).
24 | 
25 | In this mode, some options such as `--tags` are not supported.
26 | 
27 | ```bash
28 | $ sqlrows github.com/you/sample_api/...
29 | ```
30 | 
31 | ## Analyzer
32 | 
33 | `sqlrows` checks for common mistakes when using `*sql.Rows`.
34 | 
35 | First, always close `*sql.Rows` with `defer rows.Close()`. Otherwise, if scanning a record fails and the function returns early, the rows are never closed and their connection is not released for reuse.
36 | 
37 | ```go
38 | rows, err := db.QueryContext(ctx, "SELECT * FROM users")
39 | if err != nil {
40 | 	return nil, err
41 | }
42 | 
43 | for rows.Next() {
44 | 	err = rows.Scan(...)
45 | 	if err != nil {
46 | 		return nil, err // NG: this return will not release a connection.
47 | 	}
48 | }
49 | ```
50 | 
51 | Second, do not defer closing the `*sql.Rows` before checking the error that tells you whether the returned rows are valid.
52 | 
53 | ```go
54 | rows, err := db.QueryContext(ctx, "SELECT * FROM users")
55 | defer rows.Close() // NG: using rows before checking for errors
56 | if err != nil {
57 | 	return nil, err
58 | }
59 | ```
60 | 
61 | If the query failed, `rows` may be nil, so the deferred `rows.Close()` can panic with a nil-pointer dereference, and the panic does not make it obvious that the unchecked error is the real cause.
--------------------------------------------------------------------------------
/sqlrowsutil/util.go:
--------------------------------------------------------------------------------
1 | package sqlrowsutil
2 | 
3 | import (
4 | 	"go/types"
5 | 
6 | 	"golang.org/x/tools/go/ssa"
7 | )
8 | 
9 | // CalledChecker reports whether functions are called; see From and Func.
10 | // Ignore, if non-nil, marks instructions to exclude from the checks.
11 | type CalledChecker struct {
12 | 	Ignore func(instr ssa.Instruction) bool
13 | }
14 | 
15 | // Func reports whether instr calls f. If recv is not nil and the callee is
16 | // a method, Func also requires recv to be the call's receiver (first argument).
17 | func (c *CalledChecker) Func(instr ssa.Instruction, recv ssa.Value, f *types.Func) bool {
18 | 	// Instructions selected by Ignore never count as calls.
19 | 	if c.Ignore != nil && c.Ignore(instr) {
20 | 		return false
21 | 	}
22 | 
23 | 	call, ok := instr.(ssa.CallInstruction)
24 | 	if !ok {
25 | 		return false
26 | 	}
27 | 
28 | 	common := call.Common()
29 | 	if common == nil {
30 | 		return false
31 | 	}
32 | 
33 | 	callee := common.StaticCallee()
34 | 	if callee == nil {
35 | 		return false
36 | 	}
37 | 
38 | 	fn, ok := callee.Object().(*types.Func)
39 | 	if !ok {
40 | 		return false
41 | 	}
42 | 	// A statically called method receives its receiver as Args[0].
43 | 	if recv != nil &&
44 | 		common.Signature().Recv() != nil &&
45 | 		(len(common.Args) == 0 || common.Args[0] != recv) {
46 | 		return false
47 | 	}
48 | 
49 | 	return fn == f
50 | }
51 | 
52 | // From reports whether one of methods is called on the value defined by the
53 | // i-th instruction of b, either in b.Instrs[i:] or in a block reachable from b.
54 | // ok (and called) is false when the arguments are invalid, when that
55 | // instruction is not a value of type receiver, or when one of the value's uses
56 | // returns it, passes it as an argument, or is matched by Ignore.
57 | func (c *CalledChecker) From(b *ssa.BasicBlock, i int, receiver types.Type, methods ...*types.Func) (called, ok bool) {
58 | 	if b == nil || i < 0 || i >= len(b.Instrs) ||
59 | 		receiver == nil || len(methods) == 0 {
60 | 		return false, false
61 | 	}
62 | 
63 | 	v, ok := b.Instrs[i].(ssa.Value)
64 | 	if !ok {
65 | 		return false, false
66 | 	}
67 | 
68 | 	if !identical(v.Type(), receiver) {
69 | 		return false, false
70 | 	}
71 | 
72 | 	from := &calledFrom{recv: v, fs: methods, ignore: c.Ignore}
73 | 	if from.ignored() {
74 | 		return false, false
75 | 	}
76 | 
77 | 	if from.instrs(b.Instrs[i:]) ||
78 | 		from.succs(b) {
79 | 		return true, true
80 | 	}
81 | 
82 | 	return false, true
83 | }
84 | // calledFrom holds one From search: the value, the methods sought, and the visited blocks.
85 | type calledFrom struct {
86 | 	recv   ssa.Value
87 | 	fs     []*types.Func
88 | 	done   map[*ssa.BasicBlock]bool
89 | 	ignore func(ssa.Instruction) bool
90 | }
91 | // ignored reports whether a use of recv (other than recv itself) matches ignore, returns recv, or passes it as an argument.
92 | func (c *calledFrom) ignored() bool {
93 | 	refs := c.recv.Referrers()
94 | 	if refs == nil {
95 | 		return false
96 | 	}
97 | 
98 | 	for _, ref := range *refs {
99 | 		if !c.isOwn(ref) &&
100 | 			((c.ignore != nil && c.ignore(ref)) ||
101 | 				c.isRet(ref) || c.isArg(ref)) {
102 | 			return true
103 | 		}
104 | 	}
105 | 
106 | 	return false
107 | }
108 | // isOwn reports whether instr is recv itself.
109 | func (c *calledFrom) isOwn(instr ssa.Instruction) bool {
110 | 	v, ok := instr.(ssa.Value)
111 | 	if !ok {
112 | 		return false
113 | 	}
114 | 	return v == c.recv
115 | }
116 | // isRet reports whether instr is a return that has recv among its results.
117 | func (c *calledFrom) isRet(instr ssa.Instruction) bool {
118 | 
119 | 	ret, ok := instr.(*ssa.Return)
120 | 	if !ok {
121 | 		return false
122 | 	}
123 | 
124 | 	for _, r := range ret.Results {
125 | 		if r == c.recv {
126 | 			return true
127 | 		}
128 | 	}
129 | 
130 | 	return false
131 | }
132 | // isArg reports whether instr is a call that passes recv as a non-receiver argument.
133 | func (c *calledFrom) isArg(instr ssa.Instruction) bool {
134 | 
135 | 	call, ok := instr.(ssa.CallInstruction)
136 | 	if !ok {
137 | 		return false
138 | 	}
139 | 
140 | 	common := call.Common()
141 | 	if common == nil {
142 | 		return false
143 | 	}
144 | 
145 | 	args := common.Args
146 | 	if !common.IsInvoke() && common.Signature().Recv() != nil {
147 | 		args = args[1:] // call-mode method: Args[0] is the receiver (invoke-mode Args exclude it)
148 | 	}
149 | 
150 | 	for i := range args {
151 | 		if args[i] == c.recv {
152 | 			return true
153 | 		}
154 | 	}
155 | 
156 | 	return false
157 | }
158 | // instrs reports whether any of instrs calls one of fs on recv.
159 | func (c *calledFrom) instrs(instrs []ssa.Instruction) bool {
160 | 	for _, instr := range instrs {
161 | 		for _, f := range c.fs {
162 | 			if Called(instr, c.recv, f) {
163 | 				return true
164 | 			}
165 | 		}
166 | 	}
167 | 	return false
168 | }
169 | // succs reports whether one of fs is called on recv in a successor of b, transitively, visiting each block once.
170 | func (c *calledFrom) succs(b *ssa.BasicBlock) bool {
171 | 	if c.done == nil {
172 | 		c.done = map[*ssa.BasicBlock]bool{}
173 | 	}
174 | 
175 | 	if c.done[b] {
176 | 		return false
177 | 	}
178 | 	c.done[b] = true
179 | 
180 | 	if len(b.Succs) == 0 {
181 | 		return false
182 | 	}
183 | 
184 | 	var called bool
185 | 	for _, s := range b.Succs {
186 | 		if c.instrs(s.Instrs) || c.succs(s) {
187 | 			called = true
188 | 		}
189 | 	}
190 | 
191 | 	return called
192 | }
193 | 
194 | // CalledFrom is like (*CalledChecker).From with no Ignore function: it
195 | // reports whether one of methods is called on the value defined by the i-th
196 | // instruction of b, in b.Instrs[i:] or in a block reachable from b. ok is
197 | // false when that instruction is not a value of type receiver, when the value
198 | // is returned or passed as an argument, or when the arguments are invalid.
199 | func CalledFrom(b *ssa.BasicBlock, i int, receiver types.Type, methods ...*types.Func) (called, ok bool) {
200 | 	return new(CalledChecker).From(b, i, receiver, methods...)
201 | }
202 | 
203 | // Called returns true when f is called in the instr.
204 | // If recv is not nil, Called also checks the receiver.
205 | func Called(instr ssa.Instruction, recv ssa.Value, f *types.Func) bool {
206 | 	return new(CalledChecker).Func(instr, recv, f)
207 | }
208 | 
209 | // identical reports whether x and y are identical types, like
210 | // types.Identical, but returns false when types.Identical panics with
211 | // "unreachable" (see https://github.com/golang/go/issues/19670).
212 | // Any other panic is propagated.
213 | func identical(x, y types.Type) (ret bool) {
214 | 	defer func() {
215 | 		switch r := recover().(type) {
216 | 		case nil:
217 | 			// types.Identical returned normally.
218 | 		case string:
219 | 			if r != "unreachable" {
220 | 				panic(r)
221 | 			}
222 | 			ret = false
223 | 		default:
224 | 			panic(r)
225 | 		}
226 | 	}()
227 | 	return types.Identical(x, y)
228 | }
229 | 
--------------------------------------------------------------------------------
/passes/sqlrows/sqlrows.go:
--------------------------------------------------------------------------------
1 | // Copyright 2019 The Go Authors. All rights reserved.
2 | 
3 | // Package sqlrows defines an Analyzer that checks for mistakes using sql.Rows.
4 | package sqlrows
5 | 
6 | import (
7 | 	"flag"
8 | 	"go/ast"
9 | 	"go/token"
10 | 	"go/types"
11 | 
12 | 	"github.com/gostaticanalysis/analysisutil"
13 | 	"github.com/gostaticanalysis/sqlrows/sqlrowsutil"
14 | 	"golang.org/x/tools/go/analysis"
15 | 	"golang.org/x/tools/go/analysis/passes/buildssa"
16 | 	"golang.org/x/tools/go/analysis/passes/inspect"
17 | 	"golang.org/x/tools/go/ast/inspector"
18 | 	"golang.org/x/tools/go/ssa"
19 | )
20 | 
21 | // Doc is the documentation string of Analyzer.
22 | const Doc = `check for mistakes using Rows iterator of database/sql
23 | A common mistake when using the database/sql package is to defer a function
24 | call to close the Rows before checking the error that
25 | determines whether the returned records are valid:
26 | rows, err := db.QueryContext(ctx, "SELECT name FROM users WHERE age=?", age)
27 | defer rows.Close()
28 | if err != nil {
29 | log.Fatal(err)
30 | }
31 | // (defer statement belongs here)
32 | This checker helps uncover latent nil dereference bugs by reporting a
33 | diagnostic for such mistakes.`
34 | 
35 | // flagCheckErr is bound to Analyzer's -checkerr flag. When it is true, the
36 | // analyzer also reports *sql.Rows values whose Err method is never called.
37 | var (
38 | 	flagCheckErr *bool
39 | )
40 | 
41 | // Analyzer reports misuse of *sql.Rows: rows that are never closed, rows
42 | // closed outside a defer, rows used in a defer before the query error is
43 | // checked, and (with -checkerr) rows whose Err method is never called.
44 | var Analyzer = &analysis.Analyzer{
45 | 	Name: "sqlrows",
46 | 	Doc:  Doc,
47 | 	Requires: []*analysis.Analyzer{
48 | 		inspect.Analyzer,
49 | 		buildssa.Analyzer,
50 | 	},
51 | 	Flags: func() flag.FlagSet {
52 | 		flagSet := flag.NewFlagSet("sqlrows", flag.ExitOnError)
53 | 		flagCheckErr = flagSet.Bool("checkerr", false, "check whether rows.Err() has been called")
54 | 		return *flagSet
55 | 	}(),
56 | 	Run: run,
57 | }
58 | 
59 | // run skips packages that do not import database/sql and otherwise runs
60 | // the SSA-based (checkRows) and AST-based (checkDeferBeforeErr) checks.
61 | func run(pass *analysis.Pass) (interface{}, error) {
62 | 	// Fast path: if the package doesn't import database/sql,
63 | 	// skip the traversal.
64 | 	if !imports(pass.Pkg, "database/sql") {
65 | 		return nil, nil
66 | 	}
67 | 
68 | 	rowsType := analysisutil.TypeOf(pass, "database/sql", "*Rows")
69 | 	if rowsType == nil {
70 | 		// skip checking
71 | 		return nil, nil
72 | 	}
73 | 
74 | 	checkRows(pass, rowsType)
75 | 	checkDeferBeforeErr(pass)
76 | 
77 | 	return nil, nil
78 | }
79 | 
80 | // checkRows visits every SSA value of type rowsType (*sql.Rows) in the
81 | // package's source functions and reports rows that are never closed, rows
82 | // that are closed but not via defer, and (with -checkerr) rows whose Err
83 | // method is never called.
84 | func checkRows(pass *analysis.Pass, rowsType types.Type) {
85 | 	var closeMethods []*types.Func
86 | 	if m := analysisutil.MethodOf(rowsType, "Close"); m != nil {
87 | 		closeMethods = append(closeMethods, m)
88 | 	}
89 | 	errMethod := analysisutil.MethodOf(rowsType, "Err")
90 | 
91 | 	funcs := pass.ResultOf[buildssa.Analyzer].(*buildssa.SSA).SrcFuncs
92 | 	for _, f := range funcs {
93 | 		for _, b := range f.Blocks {
94 | 			for i, instr := range b.Instrs {
95 | 				pos, refs := posAndReferrers(instr)
96 | 
97 | 				// CalledFrom only returns ok for instructions that define a
98 | 				// non-ignored *sql.Rows value; other instructions are skipped.
99 | 				called, ok := sqlrowsutil.CalledFrom(b, i, rowsType, closeMethods...)
100 | 				if called && !hasDeferReferrer(refs) {
101 | 					pass.Reportf(pos, "rows.Close must be called in defer function")
102 | 				}
103 | 				if ok && !called {
104 | 					pass.Reportf(pos, "rows.Close must be called")
105 | 				}
106 | 
107 | 				// A nil errMethod can never match a call, so passing it would
108 | 				// make every rows value look unchecked.
109 | 				if *flagCheckErr && errMethod != nil {
110 | 					errCalled, errOK := sqlrowsutil.CalledFrom(b, i, rowsType, errMethod)
111 | 					if errOK && !errCalled {
112 | 						pass.Reportf(pos, "rows.Err must be called")
113 | 					}
114 | 				}
115 | 			}
116 | 		}
117 | 	}
118 | }
119 | 
120 | // posAndReferrers returns the position at which to report instr and, if
121 | // instr is an SSA value, the instructions that refer to it. For an
122 | // *ssa.Extract, such as the rows taken from a (rows, err) result, the
123 | // position of the extracted-from tuple (typically the query call) is used.
124 | func posAndReferrers(instr ssa.Instruction) (token.Pos, *[]ssa.Instruction) {
125 | 	switch v := instr.(type) {
126 | 	case *ssa.Extract:
127 | 		return v.Tuple.Pos(), v.Referrers()
128 | 	case ssa.Value:
129 | 		return instr.Pos(), v.Referrers()
130 | 	default:
131 | 		return instr.Pos(), nil
132 | 	}
133 | }
134 | 
135 | // hasDeferReferrer reports whether any instruction in refs is a defer
136 | // statement, as in `defer rows.Close()`. A nil refs yields false.
137 | func hasDeferReferrer(refs *[]ssa.Instruction) bool {
138 | 	if refs == nil {
139 | 		return false
140 | 	}
141 | 	for _, ref := range *refs {
142 | 		if _, isDefer := ref.(*ssa.Defer); isDefer {
143 | 			return true
144 | 		}
145 | 	}
146 | 	return false
147 | }
148 | 
149 | // checkDeferBeforeErr reports a defer statement that directly follows an
150 | // assignment from a call returning (*sql.Rows, error) when the deferred call
151 | // is rooted at the assigned variable, e.g.
152 | //
153 | //	rows, err := db.QueryContext(ctx, q)
154 | //	defer rows.Close() // reported: err has not been checked yet
155 | func checkDeferBeforeErr(pass *analysis.Pass) {
156 | 	insp := pass.ResultOf[inspect.Analyzer].(*inspector.Inspector)
157 | 	nodeFilter := []ast.Node{
158 | 		(*ast.CallExpr)(nil),
159 | 	}
160 | 	insp.WithStack(nodeFilter, func(n ast.Node, push bool, stack []ast.Node) bool {
161 | 		if !push {
162 | 			return true
163 | 		}
164 | 		call := n.(*ast.CallExpr)
165 | 		if !hasRowsSignature(pass.TypesInfo, call) {
166 | 			return true // the function call is not related to this check.
167 | 		}
168 | 
169 | 		// Find the innermost containing block, and get the list
170 | 		// of statements starting with the one containing call.
171 | 		stmts := restOfBlock(stack)
172 | 		if len(stmts) < 2 {
173 | 			return true // the call to the sql function is the last statement of the block.
174 | 		}
175 | 
176 | 		asg, ok := stmts[0].(*ast.AssignStmt)
177 | 		if !ok {
178 | 			return true // the first statement is not assignment.
179 | 		}
180 | 		resp := rootIdent(asg.Lhs[0])
181 | 		if resp == nil {
182 | 			return true // could not find the sql.Rows in the assignment.
183 | 		}
184 | 
185 | 		def, ok := stmts[1].(*ast.DeferStmt)
186 | 		if !ok {
187 | 			return true // the following statement is not a defer.
188 | 		}
189 | 		root := rootIdent(def.Call.Fun)
190 | 		if root == nil {
191 | 			return true // could not find the receiver of the defer call.
192 | 		}
193 | 
194 | 		// Compare type-checker objects, not ast.Ident.Obj: Obj is nil for
195 | 		// identifiers the parser does not resolve (e.g. "_" in a plain
196 | 		// assignment or an imported package name), and two nil Objs would
197 | 		// wrongly count as the same variable.
198 | 		obj := pass.TypesInfo.ObjectOf(resp)
199 | 		if obj != nil && obj == pass.TypesInfo.ObjectOf(root) {
200 | 			pass.Reportf(root.Pos(), "using %s before checking for errors", resp.Name)
201 | 		}
202 | 		return true
203 | 	})
204 | }
205 | 
206 | // hasRowsSignature reports whether expr is a call of the form x.f(...) whose
207 | // callee returns exactly (*database/sql.Rows, error).
208 | func hasRowsSignature(info *types.Info, expr *ast.CallExpr) bool {
209 | 	fun, ok := expr.Fun.(*ast.SelectorExpr)
210 | 	if !ok {
211 | 		return false // the call is not of the form x.f()
212 | 	}
213 | 	sig, _ := info.Types[fun].Type.(*types.Signature)
214 | 	if sig == nil {
215 | 		return false // the callee's type is unknown or is not a function.
216 | 	}
217 | 
218 | 	res := sig.Results()
219 | 	if res.Len() != 2 {
220 | 		return false // the function called does not return two values.
221 | 	}
222 | 	ptr, isPtr := res.At(0).Type().(*types.Pointer)
223 | 	if !isPtr || !isNamedType(ptr.Elem(), "database/sql", "Rows") {
224 | 		return false // the first return type is not *sql.Rows.
225 | 	}
226 | 
227 | 	errorType := types.Universe.Lookup("error").Type()
228 | 	if !types.Identical(res.At(1).Type(), errorType) {
229 | 		return false // the second return type is not error
230 | 	}
231 | 
232 | 	return true
233 | }
234 | 
235 | // restOfBlock, given a traversal stack whose last element is the current
236 | // node, finds the innermost block enclosing that node and returns the suffix
237 | // of the block's statements starting with the statement that contains it,
238 | // or nil if none is found.
239 | func restOfBlock(stack []ast.Node) []ast.Stmt {
240 | 	// Skip the current node itself so that stack[i+1] is always in range.
241 | 	for i := len(stack) - 2; i >= 0; i-- {
242 | 		if b, ok := stack[i].(*ast.BlockStmt); ok {
243 | 			for j, v := range b.List {
244 | 				if v == stack[i+1] {
245 | 					return b.List[j:]
246 | 				}
247 | 			}
248 | 			break
249 | 		}
250 | 	}
251 | 	return nil
252 | }
253 | 
254 | // rootIdent finds the root identifier x in a chain of selections x.y.z, or nil if not found.
255 | func rootIdent(n ast.Node) *ast.Ident {
256 | 	switch n := n.(type) {
257 | 	case *ast.SelectorExpr:
258 | 		return rootIdent(n.X)
259 | 	case *ast.Ident:
260 | 		return n
261 | 	default:
262 | 		return nil
263 | 	}
264 | }
265 | 
266 | // isNamedType reports whether t is the named type path.name.
267 | func isNamedType(t types.Type, path, name string) bool {
268 | 	n, ok := t.(*types.Named)
269 | 	if !ok {
270 | 		return false
271 | 	}
272 | 	obj := n.Obj()
273 | 	return obj.Name() == name && obj.Pkg() != nil && obj.Pkg().Path() == path
274 | }
275 | 
276 | // imports reports whether pkg directly imports the package with the given path.
277 | func imports(pkg *types.Package, path string) bool {
278 | 	for _, imp := range pkg.Imports() {
279 | 		if imp.Path() == path {
280 | 			return true
281 | 		}
282 | 	}
283 | 	return false
284 | }
285 | 
--------------------------------------------------------------------------------