├── .gitignore ├── .travis.yml ├── cmd └── sqlargs │ └── main.go ├── LICENSE ├── testdata └── src │ ├── embed │ ├── embed2.go │ └── embed.go │ ├── sqlx │ ├── embed.go │ ├── inherited.go │ ├── sqlx.go │ └── sqlx2.go │ └── basic │ └── basic.go ├── sqlargs_test.go ├── README.md ├── query_analyzer.go └── sqlargs.go /.gitignore: -------------------------------------------------------------------------------- 1 | /sqlargs 2 | testdata/pkg/ 3 | testdata/src/github.com/ -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: go 2 | 3 | go: 4 | - 1.11.x 5 | - 1.12.x 6 | - tip 7 | 8 | script: 9 | - GOPATH=`pwd`/testdata go get github.com/jmoiron/sqlx 10 | - go test -v ./... 11 | -------------------------------------------------------------------------------- /cmd/sqlargs/main.go: -------------------------------------------------------------------------------- 1 | // Command sqlargs runs the sqlargs analyzer, either standalone (sqlargs ./...) or as a go vet tool (go vet -vettool=$(which sqlargs) ./...).
2 | package main 3 | 4 | import ( 5 | "github.com/agnivade/sqlargs" 6 | "golang.org/x/tools/go/analysis/singlechecker" 7 | ) 8 | // main delegates to singlechecker.Main, the command-line driver for a single go/analysis analyzer. 9 | func main() { singlechecker.Main(sqlargs.Analyzer) } 10 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Agniva De Sarker 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE.
22 | -------------------------------------------------------------------------------- /testdata/src/embed/embed2.go: -------------------------------------------------------------------------------- 1 | package embed 2 | 3 | // Test taken from https://github.com/danjac/kanban/blob/master/database/tasks.go. Queries followed by a `want` comment are deliberately broken and must be flagged by sqlargs. 4 | import ( 5 | "database/sql" 6 | ) 7 | 8 | type Task struct { 9 | ID int64 `db:"id" json:"id,string"` 10 | CardID int64 `db:"card_id" json:"-"` 11 | Text string `db:"label" json:"text" binding:"required,max=60"` 12 | } 13 | 14 | type TaskDB interface { 15 | Delete(int64) error 16 | Create(*Task) error 17 | Move(int64, int64) error 18 | } 19 | 20 | type defaultTaskDB struct { 21 | *sql.DB 22 | } 23 | 24 | func (db *defaultTaskDB) Create(task *Task) { 25 | db.Exec("insertinto tasks(card_id, label) values (?, ?)", task.CardID, task.Text) // want `Invalid query: syntax error at or near "insertinto"` 26 | 27 | db.Exec("insert into tasks(card_id, label) values ($1, $2)", task.CardID) // // want `No. of args \(1\) is less than no. of params \(2\)` 28 | } 29 | 30 | func (db *defaultTaskDB) Move(taskID int64, newCardID int64) { 31 | db.Exec("update tasks set card_id=$1 where id=$2", newCardID) // not flagged: the arg-count check only covers INSERT statements 32 | } 33 | 34 | // Testing non-pointer receiver.
35 | func (db defaultTaskDB) Delete(taskID int64) { 36 | db.Exec("delete from taskswhere id=$1", taskID) // want `Invalid query: syntax error at or near "="` 37 | db.Exec("delete from tasks where id=$1", taskID) 38 | } 39 | -------------------------------------------------------------------------------- /testdata/src/sqlx/embed.go: -------------------------------------------------------------------------------- 1 | package sqlx 2 | 3 | // Test taken from https://github.com/danjac/kanban/blob/master/database/tasks.go. Queries followed by a `want` comment are deliberately broken and must be flagged by sqlargs. 4 | import ( 5 | "github.com/jmoiron/sqlx" 6 | ) 7 | 8 | type Task struct { 9 | ID int64 `db:"id" json:"id,string"` 10 | CardID int64 `db:"card_id" json:"-"` 11 | Text string `db:"label" json:"text" binding:"required,max=60"` 12 | } 13 | 14 | type TaskDB interface { 15 | Delete(int64) error 16 | Create(*Task) error 17 | Move(int64, int64) error 18 | } 19 | 20 | type defaultTaskDB struct { 21 | *sqlx.DB 22 | } 23 | 24 | func (db *defaultTaskDB) Create(task *Task) { 25 | db.Exec("insertinto tasks(card_id, label) values (?, ?)", task.CardID, task.Text) // want `Invalid query: syntax error at or near "insertinto"` 26 | 27 | db.Exec("insert into tasks(card_id, label) values ($1, $2)", task.CardID) // // want `No. of args \(1\) is less than no. of params \(2\)` 28 | } 29 | 30 | func (db *defaultTaskDB) Move(taskID int64, newCardID int64) { 31 | db.Exec("update tasks set card_id=$1 where id=$2", newCardID) // not flagged: the arg-count check only covers INSERT statements 32 | } 33 | 34 | // Testing non-pointer receiver.
35 | func (db defaultTaskDB) Delete(taskID int64) { 36 | db.Exec("delete from taskswhere id=$1", taskID) // want `Invalid query: syntax error at or near "="` 37 | db.Exec("delete from tasks where id=$1", taskID) 38 | } 39 | -------------------------------------------------------------------------------- /sqlargs_test.go: -------------------------------------------------------------------------------- 1 | package sqlargs 2 | 3 | import ( 4 | "testing" 5 | 6 | "golang.org/x/tools/go/analysis/analysistest" 7 | ) 8 | // TestBasic runs Analyzer over testdata/src/basic and compares its diagnostics with the `want` comments there. 9 | func TestBasic(t *testing.T) { 10 | testdata := analysistest.TestData() 11 | analysistest.Run(t, testdata, Analyzer, "basic") // loads testdata/src/basic 12 | } 13 | // TestEmbed runs Analyzer over testdata/src/embed, whose types embed *sql.DB. 14 | func TestEmbed(t *testing.T) { 15 | testdata := analysistest.TestData() 16 | analysistest.Run(t, testdata, Analyzer, "embed") 17 | } 18 | // TestSqlx runs Analyzer over testdata/src/sqlx; github.com/jmoiron/sqlx must be fetched into the testdata GOPATH first (see .travis.yml). 19 | func TestSqlx(t *testing.T) { 20 | testdata := analysistest.TestData() 21 | analysistest.Run(t, testdata, Analyzer, "sqlx") 22 | } 23 | // TestStripVendor checks that stripVendor keeps only the import path after the last "/vendor/" element and leaves other paths unchanged. 24 | func TestStripVendor(t *testing.T) { 25 | tests := []struct { 26 | name string 27 | input string 28 | want string 29 | }{ 30 | { 31 | name: "Strip", 32 | input: "github.com/godwhoa/upboat/vendor/github.com/jmoiron/sqlx", 33 | want: "github.com/jmoiron/sqlx", 34 | }, 35 | { 36 | name: "Ignore", 37 | input: "github.com/jmoiron/sqlx", 38 | want: "github.com/jmoiron/sqlx", 39 | }, 40 | { 41 | name: "Vendor in pkg URL", 42 | input: "github.com/vendor/upboat/vendor/github.com/jmoiron/sqlx", 43 | want: "github.com/jmoiron/sqlx", 44 | }, 45 | } 46 | for _, tt := range tests { 47 | t.Run(tt.name, func(t *testing.T) { 48 | if got := stripVendor(tt.input); got != tt.want { 49 | t.Errorf("stripVendor() = %v, want %v", got, tt.want) 50 | } 51 | }) 52 | } 53 | } 54 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # sqlargs [![Build
Status](https://travis-ci.org/agnivade/sqlargs.svg?branch=master)](https://travis-ci.org/agnivade/sqlargs) 2 | A vet analyzer which checks SQL queries (Postgres only!) for correctness. 3 | 4 | ### Background 5 | 6 | Let's assume you have a query like: 7 | 8 | `db.Exec("insert into t (c1, c2, c3, c4) values ($1, $2, $3, $4)", p1, p2, p3, p4)`. 9 | 10 | It's the middle of the night and you need to add a new column. You quickly change the query to: 11 | 12 | `db.Exec("insert into t (c1, c2, c3, c4, c5) values ($1, $2, $3, $4)", p1, p2, p3, p4, p5)`. 13 | 14 | Everything compiles fine. But the query is wrong! A `$5` is missing. It can even go the other way round; you add the `$5` but forget to add `c5`. 15 | 16 | This is a semantic error which will eventually get caught while running the app, or sooner if there are tests. But sometimes I get lazy and don't write tests for _all_ my SQL queries. :sweat_smile: 17 | 18 | `sqlargs` will statically check for semantic errors like these and flag them beforehand. 19 | 20 | ### Quick start 21 | 22 | This is written using the `go/analysis` API, so you can plug it directly into `go vet`, or run it as a standalone tool. 23 | 24 | Install: 25 | ``` 26 | go get github.com/agnivade/sqlargs/cmd/sqlargs 27 | ``` 28 | 29 | And then run it on your repo: 30 | ``` 31 | go vet -vettool $(which sqlargs) ./... # requires Go >= 1.12 32 | OR 33 | sqlargs ./... 34 | ``` 35 | 36 | __P.S.: This only works for Postgres queries.
So if your codebase has queries which do not match with the postgres query parser, it might flag incorrect errors.__ 37 | -------------------------------------------------------------------------------- /query_analyzer.go: -------------------------------------------------------------------------------- 1 | package sqlargs 2 | 3 | import ( 4 | "go/ast" 5 | 6 | pg_query "github.com/lfittl/pg_query_go" 7 | nodes "github.com/lfittl/pg_query_go/nodes" 8 | "golang.org/x/tools/go/analysis" 9 | ) 10 | 11 | func analyzeQuery(query string, call *ast.CallExpr, pass *analysis.Pass, checkArgs bool) { 12 | tree, err := pg_query.Parse(query) 13 | if err != nil { 14 | pass.Reportf(call.Lparen, "Invalid query: %v", err) 15 | return 16 | } 17 | // Analyze the parse tree for semantic errors. 18 | if len(tree.Statements) == 0 { 19 | return 20 | } 21 | rawStmt, ok := tree.Statements[0].(nodes.RawStmt) 22 | if !ok { 23 | return 24 | } 25 | switch stmt := rawStmt.Stmt.(type) { 26 | // 1. For insert statements, the no. of columns(if present) should be equal to no. of values. 27 | case nodes.InsertStmt: 28 | numCols := len(stmt.Cols.Items) 29 | if numCols == 0 { 30 | return 31 | } 32 | selStmt, ok := stmt.SelectStmt.(nodes.SelectStmt) 33 | if !ok { 34 | return 35 | } 36 | if len(selStmt.ValuesLists) == 0 { 37 | return 38 | } 39 | numValues := len(selStmt.ValuesLists[0]) 40 | if numCols != numValues { 41 | pass.Reportf(call.Lparen, "No. of columns (%d) not equal to no. of values (%d)", numCols, numValues) 42 | } 43 | if !checkArgs { 44 | return 45 | } 46 | numParams := numParams(selStmt.ValuesLists[0]) 47 | args := len(call.Args[1:]) 48 | // A safe check is to just check if args are less than no. of params. If this is true, 49 | // then there has to be an error somewhere. On the contrary, if there are less params 50 | // found than args, then it just means we haven't parsed the query well enough and there are 51 | // other parts of the query which use the other arguments. 
52 | if args < numParams { 53 | pass.Reportf(call.Lparen, "No. of args (%d) is less than no. of params (%d)", args, numParams) 54 | } 55 | } 56 | } 57 | 58 | //numParams returns the count of unique paramters. 59 | func numParams(params []nodes.Node) int { 60 | num := 0 61 | // posMap is used to keep track of unique positional parameters. 62 | posMap := make(map[int]bool) 63 | for _, p := range params { 64 | switch t := p.(type) { 65 | case nodes.ParamRef: 66 | if !posMap[t.Number] { 67 | num++ 68 | posMap[t.Number] = true 69 | } 70 | case nodes.TypeCast: 71 | if pRef, ok := t.Arg.(nodes.ParamRef); ok { 72 | if !posMap[pRef.Number] { 73 | num++ 74 | posMap[pRef.Number] = true 75 | } 76 | } 77 | } 78 | } 79 | return num 80 | } 81 | -------------------------------------------------------------------------------- /testdata/src/embed/embed.go: -------------------------------------------------------------------------------- 1 | package embed 2 | 3 | import ( 4 | "context" 5 | "database/sql" 6 | "fmt" 7 | "time" 8 | ) 9 | 10 | type myDB struct { 11 | *sql.DB 12 | } 13 | 14 | func runDB() { 15 | var db myDB 16 | defer db.Close() 17 | var p1, p2 string 18 | 19 | // Execs 20 | db.Exec(`DELETE FROM t`) 21 | 22 | queryStr := fmt.Sprintf(`INSERT INTO t VALUES ($1, $%d) `, 1) // Should not crash 23 | db.Exec(queryStr, p1, p2) 24 | 25 | db.Exec(`INSERT INTO t VALUES ($1, $2)`, p1, p2) 26 | 27 | const q = `INSERT INTO t(c1 c2) VALUES ($1, $2)` 28 | db.Exec(q, p1, p2) // want `Invalid query: syntax error at or near "c2"` 29 | 30 | db.Exec(`INSERT INTO t (c1, c2) VALUES ($1, $2)`, p1, "const") 31 | 32 | db.Exec(`INSERT INTO t (c1) VALUES ($1::uuid, $2)`, p1, p2) // want `No. of columns \(1\) not equal to no. of values \(2\)` 33 | 34 | db.Exec(`INSERT INTO t (c1, c2, c3, c4, c5) values ('o', $1, $1, 1, '{"duration": "1440h00m00s"}')`, time.Now()) 35 | 36 | db.Exec(`INSERT INTO t (c1, c2, c3, c4, c5) values ('o', $1, $1, 1, '{"duration": "1440h00m00s"}')`) // // want `No. 
of args \(0\) is less than no. of params \(1\)` 37 | 38 | // QueryRow 39 | db.QueryRow(`INSERT INTO t (c1, c2) VALUES ($1) RETURNING c1`, p1, p2) // want `No. of columns \(2\) not equal to no. of values \(1\)` 40 | 41 | db.QueryRow(`INSERT INTO t (c1, c2, c3, c4) VALUES ('o', $1, 'epoch'::timestamp, $2) RETURNING c1`, p1, p2) 42 | 43 | db.QueryRow(`INSERT INTO t (c1, c2, c3, c4) VALUES ('o', $1, 'epoch'::timestamp, $2) RETURNING c1`, p1) // want `No. of args \(1\) is less than no. of params \(2\)` 44 | 45 | ctx := context.Background() 46 | db.ExecContext(ctx, `INSERT INTO t(c1 c2) VALUES ($1, $2)`, p1, p2) // want `Invalid query: syntax error at or near "c2"` 47 | 48 | db.QueryRowContext(ctx, `INSERT INTO t(c1 c2) VALUES ($1, $2) RETURNING c2`, p1, p2) // want `Invalid query: syntax error at or near "c2"` 49 | } 50 | 51 | func runTx() { 52 | var db myDB 53 | tx, _ := db.Begin() 54 | defer tx.Commit() 55 | var p1, p2 string 56 | 57 | // Execs 58 | tx.Exec(`INSERT INTO t VALUES ($1, $2)`, p1, p2) 59 | 60 | tx.Exec(`INSERT INTO t (c1, c2) VALUES ($1, $2)`, p1, p2) 61 | 62 | tx.Exec(`INSERT INTO t (c1, c2) VALUES ($1, $2)`, p1, "const") 63 | 64 | tx.Exec(`INSERT INTO t (c1) VALUES ($1::uuid, $2)`, p1, p2) // want `No. of columns \(1\) not equal to no. of values \(2\)` 65 | 66 | tx.Exec(`INSERT INTO t (c1, c2, c3, c4, c5) values ('o', $1, $1, 1, '{"duration": "1440h00m00s"}')`, time.Now()) 67 | 68 | // QueryRow 69 | tx.QueryRow(`INSERT INTO t (c1, c2) VALUES ($1) RETURNING c1`, p1, p2) // want `No. of columns \(2\) not equal to no. of values \(1\)` 70 | 71 | tx.QueryRow(`INSERT INTO t (c1, c2, c3, c4) VALUES ('o', $1, 'epoch'::timestamp, $2) RETURNING c1`, p1, p2) 72 | 73 | tx.QueryRow(`INSERT INTO t (c1, c2, c3, c4) VALUES ('o', $1, 'epoch'::timestamp, $2) RETURNING c1`, p1) // want `No. of args \(1\) is less than no.
of params \(2\)` 74 | 75 | ctx := context.Background() 76 | tx.ExecContext(ctx, `INSERT INTO t(c1 c2) VALUES ($1, $2)`, p1, p2) // want `Invalid query: syntax error at or near "c2"` 77 | 78 | tx.QueryRowContext(ctx, `INSERT INTO t(c1 c2) VALUES ($1, $2) RETURNING c2`, p1, p2) // want `Invalid query: syntax error at or near "c2"` 79 | } 80 | -------------------------------------------------------------------------------- /testdata/src/sqlx/inherited.go: -------------------------------------------------------------------------------- 1 | package sqlx 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "time" 7 | 8 | "github.com/jmoiron/sqlx" 9 | ) 10 | 11 | func runDB() { 12 | var db *sqlx.DB 13 | defer db.Close() 14 | var p1, p2 string 15 | 16 | // Execs 17 | db.Exec(`DELETE FROM t`) 18 | 19 | queryStr := fmt.Sprintf(`INSERT INTO t VALUES ($1, $%d) `, 1) // Should not crash 20 | db.Exec(queryStr, p1, p2) 21 | 22 | db.Exec(`INSERT INTO t VALUES ($1, $2)`, p1, p2) 23 | 24 | const q = `INSERT INTO t(c1 c2) VALUES ($1, $2)` 25 | db.Exec(q, p1, p2) // want `Invalid query: syntax error at or near "c2"` 26 | 27 | db.Exec(`INSERT INTO t (c1, c2) VALUES ($1, $2)`, p1, "const") 28 | 29 | db.Exec(`INSERT INTO t (c1) VALUES ($1::uuid, $2)`, p1, p2) // want `No. of columns \(1\) not equal to no. of values \(2\)` 30 | 31 | db.Exec(`INSERT INTO t (c1, c2, c3, c4, c5) values ('o', $1, $1, 1, '{"duration": "1440h00m00s"}')`, time.Now()) 32 | 33 | db.Exec(`INSERT INTO t (c1, c2, c3, c4, c5) values ('o', $1, $1, 1, '{"duration": "1440h00m00s"}')`) // // want `No. of args \(0\) is less than no. of params \(1\)` 34 | 35 | // QueryRow 36 | db.QueryRow(`INSERT INTO t (c1, c2) VALUES ($1) RETURNING c1`, p1, p2) // want `No. of columns \(2\) not equal to no.
of values \(1\)` 37 | 38 | db.QueryRow(`INSERT INTO t (c1, c2, c3, c4) VALUES ('o', $1, 'epoch'::timestamp, $2) RETURNING c1`, p1, p2) 39 | 40 | db.QueryRow(`INSERT INTO t (c1, c2, c3, c4) VALUES ('o', $1, 'epoch'::timestamp, $2) RETURNING c1`, p1) // want `No. of args \(1\) is less than no. of params \(2\)` 41 | 42 | ctx := context.Background() 43 | db.ExecContext(ctx, `INSERT INTO t(c1 c2) VALUES ($1, $2)`, p1, p2) // want `Invalid query: syntax error at or near "c2"` 44 | 45 | db.QueryRowContext(ctx, `INSERT INTO t(c1 c2) VALUES ($1, $2) RETURNING c2`, p1, p2) // want `Invalid query: syntax error at or near "c2"` 46 | } 47 | 48 | func runTx() { 49 | // Doing a non-pointer check with transactions. 50 | var tx sqlx.Tx 51 | defer tx.Commit() 52 | var p1, p2 string 53 | 54 | // Execs 55 | tx.Exec(`INSERT INTO t VALUES ($1, $2)`, p1, p2) 56 | 57 | tx.Exec(`INSERT INTO t (c1, c2) VALUES ($1, $2)`, p1, p2) 58 | 59 | tx.Exec(`INSERT INTO t (c1, c2) VALUES ($1, $2)`, p1, "const") 60 | 61 | tx.Exec(`INSERT INTO t (c1) VALUES ($1::uuid, $2)`, p1, p2) // want `No. of columns \(1\) not equal to no. of values \(2\)` 62 | 63 | tx.Exec(`INSERT INTO t (c1, c2, c3, c4, c5) values ('o', $1, $1, 1, '{"duration": "1440h00m00s"}')`, time.Now()) 64 | 65 | // QueryRow 66 | tx.QueryRow(`INSERT INTO t (c1, c2) VALUES ($1) RETURNING c1`, p1, p2) // want `No. of columns \(2\) not equal to no. of values \(1\)` 67 | 68 | tx.QueryRow(`INSERT INTO t (c1, c2, c3, c4) VALUES ('o', $1, 'epoch'::timestamp, $2) RETURNING c1`, p1, p2) 69 | 70 | tx.QueryRow(`INSERT INTO t (c1, c2, c3, c4) VALUES ('o', $1, 'epoch'::timestamp, $2) RETURNING c1`, p1) // want `No. of args \(1\) is less than no.
of params \(2\)` 71 | 72 | ctx := context.Background() 73 | tx.ExecContext(ctx, `INSERT INTO t(c1 c2) VALUES ($1, $2)`, p1, p2) // want `Invalid query: syntax error at or near "c2"` 74 | 75 | tx.QueryRowContext(ctx, `INSERT INTO t(c1 c2) VALUES ($1, $2) RETURNING c2`, p1, p2) // want `Invalid query: syntax error at or near "c2"` 76 | } 77 | -------------------------------------------------------------------------------- /testdata/src/basic/basic.go: -------------------------------------------------------------------------------- 1 | package basic 2 | 3 | import ( 4 | "context" 5 | "database/sql" 6 | "fmt" 7 | "time" 8 | ) 9 | 10 | func runDB() { 11 | var db *sql.DB 12 | defer db.Close() 13 | var p1, p2 string 14 | 15 | // Execs 16 | db.Exec(`DELETE FROM t`) 17 | 18 | queryStr := fmt.Sprintf(`INSERT INTO t VALUES ($1, $%d) `, 1) // Should not crash 19 | db.Exec(queryStr, p1, p2) 20 | 21 | db.Exec(`INSERT INTO t VALUES ($1, $2)`, p1, p2) 22 | 23 | const q = `INSERT INTO t(c1 c2) VALUES ($1, $2)` 24 | db.Exec(q, p1, p2) // want `Invalid query: syntax error at or near "c2"` 25 | 26 | r := `INSERT INTO t(c1 c2) VALUES ($1, $2)` // non-constant variable holding a literal query 27 | db.Exec(r, p1, p2) // want `Invalid query: syntax error at or near "c2"` 28 | 29 | db.Exec(`INSERT INTO t (c1, c2) VALUES ($1, $2)`, p1, "const") 30 | 31 | db.Exec(`INSERT INTO t (c1) VALUES ($1::uuid, $2)`, p1, p2) // want `No. of columns \(1\) not equal to no. of values \(2\)` 32 | 33 | db.Exec(`INSERT INTO t (c1, c2, c3, c4, c5) values ('o', $1, $1, 1, '{"duration": "1440h00m00s"}')`, time.Now()) 34 | 35 | db.Exec(`INSERT INTO t (c1, c2, c3, c4, c5) values ('o', $1, $1, 1, '{"duration": "1440h00m00s"}')`) // // want `No. of args \(0\) is less than no. of params \(1\)` 36 | 37 | // QueryRow 38 | db.QueryRow(`INSERT INTO t (c1, c2) VALUES ($1) RETURNING c1`, p1, p2) // want `No. of columns \(2\) not equal to no.
of values \(1\)` 39 | 40 | db.QueryRow(`INSERT INTO t (c1, c2, c3, c4) VALUES ('o', $1, 'epoch'::timestamp, $2) RETURNING c1`, p1, p2) 41 | 42 | db.QueryRow(`INSERT INTO t (c1, c2, c3, c4) VALUES ('o', $1, 'epoch'::timestamp, $2) RETURNING c1`, p1) // want `No. of args \(1\) is less than no. of params \(2\)` 43 | 44 | ctx := context.Background() 45 | db.ExecContext(ctx, `INSERT INTO t(c1 c2) VALUES ($1, $2)`, p1, p2) // want `Invalid query: syntax error at or near "c2"` 46 | 47 | r = `INSERT INTO t(c1 c2) VALUES ($1, $2)` 48 | db.ExecContext(ctx, r, p1, p2) // want `Invalid query: syntax error at or near "c2"` 49 | 50 | db.QueryRowContext(ctx, `INSERT INTO t(c1 c2) VALUES ($1, $2) RETURNING c2`, p1, p2) // want `Invalid query: syntax error at or near "c2"` 51 | } 52 | 53 | func runTx() { 54 | // Doing a non-pointer check with transactions. 55 | var tx sql.Tx 56 | defer tx.Commit() 57 | var p1, p2 string 58 | 59 | // Execs 60 | tx.Exec(`INSERT INTO t VALUES ($1, $2)`, p1, p2) 61 | 62 | tx.Exec(`INSERT INTO t (c1, c2) VALUES ($1, $2)`, p1, p2) 63 | 64 | tx.Exec(`INSERT INTO t (c1, c2) VALUES ($1, $2)`, p1, "const") 65 | 66 | tx.Exec(`INSERT INTO t (c1) VALUES ($1::uuid, $2)`, p1, p2) // want `No. of columns \(1\) not equal to no. of values \(2\)` 67 | 68 | tx.Exec(`INSERT INTO t (c1, c2, c3, c4, c5) values ('o', $1, $1, 1, '{"duration": "1440h00m00s"}')`, time.Now()) 69 | 70 | // QueryRow 71 | tx.QueryRow(`INSERT INTO t (c1, c2) VALUES ($1) RETURNING c1`, p1, p2) // want `No. of columns \(2\) not equal to no. of values \(1\)` 72 | 73 | tx.QueryRow(`INSERT INTO t (c1, c2, c3, c4) VALUES ('o', $1, 'epoch'::timestamp, $2) RETURNING c1`, p1, p2) 74 | 75 | tx.QueryRow(`INSERT INTO t (c1, c2, c3, c4) VALUES ('o', $1, 'epoch'::timestamp, $2) RETURNING c1`, p1) // want `No. of args \(1\) is less than no.
of params \(2\)` 76 | 77 | ctx := context.Background() 78 | tx.ExecContext(ctx, `INSERT INTO t(c1 c2) VALUES ($1, $2)`, p1, p2) // want `Invalid query: syntax error at or near "c2"` 79 | 80 | tx.QueryRowContext(ctx, `INSERT INTO t(c1 c2) VALUES ($1, $2) RETURNING c2`, p1, p2) // want `Invalid query: syntax error at or near "c2"` 81 | } 82 | -------------------------------------------------------------------------------- /testdata/src/sqlx/sqlx.go: -------------------------------------------------------------------------------- 1 | package sqlx 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "time" 7 | 8 | "github.com/jmoiron/sqlx" 9 | ) 10 | 11 | func runSqlxDB() { 12 | var db *sqlx.DB 13 | defer db.Close() 14 | var p1, p2 string 15 | 16 | // Execs 17 | db.MustExec(`DELETE FROM t`) 18 | 19 | queryStr := fmt.Sprintf(`INSERT INTO t VALUES ($1, $%d) `, 1) // Should not crash 20 | db.MustExec(queryStr, p1, p2) 21 | 22 | db.MustExec(`INSERT INTO t VALUES ($1, $2)`, p1, p2) 23 | 24 | const q = `INSERT INTO t(c1 c2) VALUES ($1, $2)` 25 | db.MustExec(q, p1, p2) // want `Invalid query: syntax error at or near "c2"` 26 | 27 | db.MustExec(`INSERT INTO t (c1, c2) VALUES ($1, $2)`, p1, "const") 28 | 29 | db.MustExec(`INSERT INTO t (c1) VALUES ($1::uuid, $2)`, p1, p2) // want `No. of columns \(1\) not equal to no. of values \(2\)` 30 | 31 | db.MustExec(`INSERT INTO t (c1, c2, c3, c4, c5) values ('o', $1, $1, 1, '{"duration": "1440h00m00s"}')`, time.Now()) 32 | 33 | db.MustExec(`INSERT INTO t (c1, c2, c3, c4, c5) values ('o', $1, $1, 1, '{"duration": "1440h00m00s"}')`) // // want `No. of args \(0\) is less than no. of params \(1\)` 34 | 35 | // Queryx 36 | db.Queryx(`SELECT * FROM students`) 37 | db.Queryx(`INSERT INTO t (c1, c2) VALUES ($1) RETURNING c1`, p1, p2) // want `No. of columns \(2\) not equal to no. of values \(1\)` 38 | 39 | // QueryRowx 40 | db.QueryRowx(`INSERT INTO t (c1, c2) VALUES ($1) RETURNING c1`, p1, p2) // want `No. of columns \(2\) not equal to no.
of values \(1\)` 41 | 42 | db.QueryRowx(`INSERT INTO t (c1, c2, c3, c4) VALUES ('o', $1, 'epoch'::timestamp, $2) RETURNING c1`, p1, p2) 43 | 44 | db.QueryRowx(`INSERT INTO t (c1, c2, c3, c4) VALUES ('o', $1, 'epoch'::timestamp, $2) RETURNING c1`, p1) // want `No. of args \(1\) is less than no. of params \(2\)` 45 | 46 | // Context 47 | ctx := context.Background() 48 | db.MustExecContext(ctx, `INSERT INTO t(c1 c2) VALUES ($1, $2)`, p1, p2) // want `Invalid query: syntax error at or near "c2"` 49 | db.QueryxContext(ctx, `INSERT INTO t(c1 c2) VALUES ($1, $2) RETURNING c2`, p1, p2) // want `Invalid query: syntax error at or near "c2"` 50 | db.QueryRowxContext(ctx, `INSERT INTO t(c1 c2) VALUES ($1, $2) RETURNING c2`, p1, p2) // want `Invalid query: syntax error at or near "c2"` 51 | } 52 | 53 | func runSqlxTx() { 54 | // Doing a non-pointer check with transactions. 55 | var tx sqlx.Tx 56 | defer tx.Commit() 57 | var p1, p2 string 58 | 59 | // Execs 60 | tx.MustExec(`INSERT INTO t VALUES ($1, $2)`, p1, p2) 61 | 62 | tx.MustExec(`INSERT INTO t (c1, c2) VALUES ($1, $2)`, p1, p2) 63 | 64 | tx.MustExec(`INSERT INTO t (c1, c2) VALUES ($1, $2)`, p1, "const") 65 | 66 | tx.MustExec(`INSERT INTO t (c1) VALUES ($1::uuid, $2)`, p1, p2) // want `No. of columns \(1\) not equal to no. of values \(2\)` 67 | 68 | tx.MustExec(`INSERT INTO t (c1, c2, c3, c4, c5) values ('o', $1, $1, 1, '{"duration": "1440h00m00s"}')`, time.Now()) 69 | 70 | // QueryRowx 71 | tx.QueryRowx(`INSERT INTO t (c1, c2) VALUES ($1) RETURNING c1`, p1, p2) // want `No. of columns \(2\) not equal to no. of values \(1\)` 72 | 73 | tx.QueryRowx(`INSERT INTO t (c1, c2, c3, c4) VALUES ('o', $1, 'epoch'::timestamp, $2) RETURNING c1`, p1, p2) 74 | 75 | tx.QueryRowx(`INSERT INTO t (c1, c2, c3, c4) VALUES ('o', $1, 'epoch'::timestamp, $2) RETURNING c1`, p1) // want `No. of args \(1\) is less than no.
of params \(2\)` 76 | 77 | ctx := context.Background() 78 | tx.MustExecContext(ctx, `INSERT INTO t(c1 c2) VALUES ($1, $2)`, p1, p2) // want `Invalid query: syntax error at or near "c2"` 79 | tx.QueryxContext(ctx, `INSERT INTO t(c1 c2) VALUES ($1, $2) RETURNING c2`, p1, p2) // want `Invalid query: syntax error at or near "c2"` 80 | tx.QueryRowxContext(ctx, `INSERT INTO t(c1 c2) VALUES ($1, $2) RETURNING c2`, p1, p2) // want `Invalid query: syntax error at or near "c2"` 81 | } 82 | -------------------------------------------------------------------------------- /testdata/src/sqlx/sqlx2.go: -------------------------------------------------------------------------------- 1 | // Test taken from https://github.com/croll/arkeogis-server/blob/master/model/database.go 2 | package sqlx 3 | 4 | import ( 5 | "bytes" 6 | 7 | "github.com/jmoiron/sqlx" 8 | ) 9 | 10 | const Database_handle_InsertStr = "\"database_id\", \"import_id\", \"identifier\", \"url\", \"declared_creation_date\", \"created_at\"" 11 | const Database_handle_InsertValuesStr = "$1, $2, $3, $4, $5, now()" 12 | 13 | type Database struct { 14 | Id int `db:"id" json:"id"` 15 | Name string `db:"name" json:"name" min:"1" max:"255" error:"DATABASE.FIELD_NAME.T_CHECK_MANDATORY"` 16 | Owner int `db:"owner" json:"owner"` // User.Id 17 | Editor string `db:"editor" json:"editor"` 18 | Contributor string `db:"contributor" json:"contributor"` 19 | Default_language string `db:"default_language" json:"default_language"` // Lang.Isocode 20 | State string `db:"state" json:"state" enum:"undefined,in-progress,finished" error:"DATABASE.FIELD_STATE.T_CHECK_INCORRECT"` 21 | License_id int `db:"license_id" json:"license_id"` // License.Id 22 | } 23 | 24 | // AnotherExistsWithSameName checks if database already exists with same name and owned by another user 25 | func (d *Database) AnotherExistsWithSameName(tx *sqlx.Tx) (exists bool, err error) { 26 | tx.QueryRowx("SELECT id FROM \"database\" WHERE name = $1 AND owner != $2", d.Name,
d.Owner).Scan(&d.Id) 27 | return true, nil 28 | } 29 | 30 | // Get retrieves information about a database stored in the main table 31 | func (d *Database) Get(tx *sqlx.Tx) (err error) { 32 | stmt, err := tx.PrepareNamed("SELECT * from \"database\" WHERE id=$1") 33 | defer stmt.Close() 34 | return stmt.Get(d, d) 35 | } 36 | 37 | // AddHandle links a handle to a database 38 | func (d *Database) AddHandle(tx *sqlx.Tx) (id int, err error) { 39 | stmt, err := tx.PrepareNamed("INSERT INTO \"database_handle\" (" + Database_handle_InsertStr + ") VALUES (" + Database_handle_InsertValuesStr + ") RETURNING id") 40 | tx.PrepareNamed("INSERT INTOdatabase_handle\" (" + Database_handle_InsertStr + ") VALUES (" + Database_handle_InsertValuesStr + ") RETURNING id") // want `Invalid query: syntax error at or near "INTOdatabase_handle"` 41 | defer stmt.Close() 42 | return 43 | } 44 | 45 | // DeleteSpecificHandle unlinks the handle with the given identifier from the database 46 | func (d *Database) DeleteSpecificHandle(tx *sqlx.Tx, id int) error { 47 | _, err := tx.Exec("DELETE FROM \"database_handle\" WHERE identifier = $1", id) 48 | return err 49 | } 50 | 51 | // SetContexts links users as contexts to a database 52 | func (d *Database) SetContexts(tx *sqlx.Tx, contexts []string) error { 53 | for _, cname := range contexts { 54 | tx.Exec("INSERT INTO \"database_context\" (database_id) VALUES ($1, $2)", d.Id, cname) // want `No. of columns \(1\) not equal to no. of values \(2\)` 55 | tx.Exec("INSERT INTO \"database_context\" (database_id, context) VALUES ($1)", d.Id, cname) // // want `No. of columns \(2\) not equal to no.
of values \(1\)` 56 | } 57 | return nil 58 | } 59 | 60 | // DeleteContexts deletes the context linked to a database 61 | func (d *Database) DeleteContexts(tx *sqlx.Tx) error { 62 | _, err := tx.NamedExec("DELETE FROM \"\"database_context\" WHERE database_id=$1", d) // want `Invalid query: zero-length delimited identifier at or near """"` 63 | return err 64 | } 65 | // SetTranslations inserts a database_tr row for each translation whose language has none yet; the field parameter is unused. 66 | func (d *Database) SetTranslations(tx *sqlx.Tx, field string, translations []struct { 67 | Lang_Isocode string 68 | Text string 69 | }) (err error) { 70 | var transID int 71 | for _, tr := range translations { 72 | err = tx.QueryRow("SELECT count(database_id) FROM database_tr WHERE database_id = $1 AND lang_isocode = $2", d.Id, tr.Lang_Isocode).Scan(&transID) 73 | if transID == 0 { 74 | _, err = tx.Exec("INSERT INTO database_tr (database_id, lang_isocode, description, geographical_limit, bibliography, context_description, source_description, source_relation, copyright, subject) VALUES ($1, $2, '', '', '', '', '', '', '', '')", d.Id, tr.Lang_Isocode) 75 | _, err = tx.Exec("INSERT INTO database_tr (database_id, lang_isocode, description, geographical_limit, bibliography, context_description, source_description, source_relation, copyright, subject) VALUES ($1, $2, '', '', '', '', '', '', '', '')", d.Id) // want `No. of args \(1\) is less than no. of params \(2\)` 76 | } 77 | } 78 | return 79 | } 80 | 81 | // UpdateFields updates "database" fields (crazy isn't it ?)
82 | func (d *Database) UpdateFields(tx *sqlx.Tx, params interface{}, fields ...string) (err error) { 83 | var upd string 84 | query := "UPDATE \"database\" SET " + upd + " WHERE id = :id" 85 | _, err = tx.NamedExec(query, params) 86 | return 87 | } 88 | 89 | // CacheGeom caches the extent of the database's sites in geographical_extent_geom: the envelope of their geometries when there are more than two distinct ones, otherwise a buffer around one site geometry 90 | func (d *Database) CacheGeom(tx *sqlx.Tx) (err error) { 91 | var c int 92 | err = tx.Get(&c, "SELECT COUNT(*) FROM (SELECT DISTINCT geom FROM site WHERE database_id = $1) AS temp", d.Id) 93 | // Envelope 94 | if c > 2 { 95 | _, err = tx.NamedExec("UPDATE database SET geographical_extent_geom = (SELECT (ST_Envelope((SELECT ST_Multi(ST_Collect(f.geom)) as singlegeom FROM (SELECT (ST_Dump(geom::::geometry)).geom As geom FROM site WHERE database_id = $1) As f)))) WHERE id =$1", d) // want `Invalid query: syntax error at or near "::"` 96 | } else { 97 | _, err = tx.NamedExec("UPDATE database SET geographical_extent_geom = (SELECT ST_Buffer((SELECT geom FROM site WHERE database_id = $1 AND geom IS NOT NULL LIMIT 1), 1)) WHERE id = $1", d) 98 | } 99 | return 100 | } 101 | 102 | // CacheDates caches the earliest start date and the latest end date of the database's sites in start_date and end_date 103 | func (d *Database) CacheDates(tx *sqlx.Tx) (err error) { 104 | _, err = tx.NamedExec("UPDATE database SET start_date = (SELECT COALESCE(min(start_date1),-2147483648) FROM site_range WHERE site_id IN (SELECT id FROM site where database_id = $1) AND start_date1 != -2147483648), end_date = (SELECT COALESCE(max(end_date2),2147483647) FROM site_range WHERE site_id IN (SELECT id FROM site where database_id = $1) AND end_date2 != 2147483647) WHERE id = $1", d) 105 | return 106 | } 107 | 108 | // LinkToUserProject links database to project 109 | func (d *Database) LinkToUserProject(tx *sqlx.Tx, project_ID int) (err error) { 110 | _, err = tx.Exec("INSERT INTO project__database (project_id, database_id) VALUES ($1, $2)", project_ID, d.Id) 111 | return 112 | } 113 | 114 | // ExportCSV exports database and sites as a CSV file
115 | func (d *Database) ExportCSV(tx *sqlx.Tx, siteIDs ...[]int) (outp string, err error) { 116 | var buff bytes.Buffer 117 | const q = "WITH RECURSIVE nodes_cte(id, path) AS (SELECT ca.id, cat.name::TEXT AS path FROM charac AS ca LEFT JOIN charac_tr cat ON ca.id = cat.charac_id LEFT JOIN lang ON cat.lang_isocode = lang.isocode WHERE lang.isocode = $1 AND ca.parent_id = 0 UNION ALL SELECT ca.id, (p.path || ';' cat.name) FROM nodes_cte AS p, charac AS ca LEFT JOIN charac_tr cat ON ca.id = cat.charac_id LEFT JOIN lang ON cat.lang_isocode = lang.isocode WHERE lang.isocode = $1 AND ca.parent_id = p.id) SELECT * FROM nodes_cte AS n ORDER BY n.id ASC" 118 | 119 | tx.Query(q, d.Default_language) // want `Invalid query: syntax error at or near "cat"` 120 | 121 | tx.Query(`SELECT s.code, s.name, s.city_name, s.city_geonameid, ST_X(s.geom::geometry) as longitude, ST_Y(s.geom::geometry) as latitude, ST_X(s.geom_3d::geometry) as longitude_3d, ST_Y(s.geom_3d::geometry) as latitude3d, ST_Z(s.geom_3d::geometry) as altitude, s.centroid, s.occupation, sr.start_date1, sr.start_date2, sr.end_date1, sr.end_date2, src.exceptional, src.knowledge_type, srctr.bibliography, srctr.comment, c.id as charac_id FROM site s LEFT JOIN site_range sr ON s.id = sr.site_id LEFT JOIN site_tr str ON s.id = str.site_id LEFT JOIN site_range__charac src ON sr.id = src.site_range_id LEFT JOIN site_range__charac_tr srctr ON src.id = srctr.site_range__charac_id LEFT JOIN charac c ON src.charac_id = c.id WHERE s.database_id = $1 AND str.lang_isocode IS NULL OR str.lang_isocode = $2 ORDER BY s.id, sr.id`, d.Id, d.Default_language) 122 | 123 | return buff.String(), nil 124 | } 125 | -------------------------------------------------------------------------------- /sqlargs.go: -------------------------------------------------------------------------------- 1 | package sqlargs 2 | 3 | import ( 4 | "go/ast" 5 | "go/constant" 6 | "go/types" 7 | "strconv" 8 | "strings" 9 | 10 | "golang.org/x/tools/go/analysis" 11 |
"golang.org/x/tools/go/analysis/passes/inspect" 12 | "golang.org/x/tools/go/ast/inspector" 13 | ) 14 | 15 | const Doc = `check sql query strings for correctness 16 | 17 | The sqlargs analyser checks the parameters passed to sql queries 18 | and the actual number of parameters written in the query string 19 | and reports any mismatches. 20 | 21 | This is a common occurence when updating a sql query to add/remove 22 | a column.` 23 | 24 | var Analyzer = &analysis.Analyzer{ 25 | Name: "sqlargs", 26 | Doc: Doc, 27 | Run: run, 28 | Requires: []*analysis.Analyzer{inspect.Analyzer}, 29 | RunDespiteErrors: true, 30 | } 31 | 32 | // validExprs contain all the valid selector expressions to check in the code, 33 | // keyed by their package import path. 34 | var validExprs = map[string]map[string]bool{ 35 | "database/sql": { 36 | "DB.Exec": true, 37 | "DB.ExecContext": true, 38 | "DB.QueryRow": true, 39 | "DB.QueryRowContext": true, 40 | "DB.Query": true, 41 | "DB.QueryContext": true, 42 | "DB.Prepare": true, 43 | "DB.PrepareContext": true, 44 | "Tx.Exec": true, 45 | "Tx.ExecContext": true, 46 | "Tx.QueryRow": true, 47 | "Tx.QueryRowContext": true, 48 | "Tx.Query": true, 49 | "Tx.QueryContext": true, 50 | "Stmt.Exec": true, 51 | "Stmt.ExecContext": true, 52 | "Stmt.QueryRow": true, 53 | "Stmt.QueryRowContext": true, 54 | "Stmt.Query": true, 55 | "Stmt.QueryContext": true, 56 | }, 57 | "github.com/jmoiron/sqlx": { 58 | // inherited 59 | "DB.Exec": true, 60 | "DB.ExecContext": true, 61 | "DB.QueryRow": true, 62 | "DB.QueryRowContext": true, 63 | "DB.Query": true, 64 | "DB.QueryContext": true, 65 | "Tx.Exec": true, 66 | "Tx.ExecContext": true, 67 | "Tx.QueryRow": true, 68 | "Tx.QueryRowContext": true, 69 | "Tx.Query": true, 70 | "Tx.QueryContext": true, 71 | "Stmt.Exec": true, 72 | "Stmt.ExecContext": true, 73 | "Stmt.QueryRow": true, 74 | "Stmt.QueryRowContext": true, 75 | "Stmt.Query": true, 76 | "Stmt.QueryContext": true, 77 | // extensions 78 | "DB.MustExec": true, 79 | 
"DB.MustExecContext": true, 80 | "DB.NamedExec": true, 81 | "DB.NamedExecContext": true, 82 | "DB.QueryRowx": true, 83 | "DB.QueryRowxContext": true, 84 | "DB.Queryx": true, 85 | "DB.QueryxContext": true, 86 | "DB.PrepareNamed": true, 87 | "DB.PrepareNamedContext": true, 88 | "Tx.MustExec": true, 89 | "Tx.MustExecContext": true, 90 | "Tx.QueryRowx": true, 91 | "Tx.QueryRowxContext": true, 92 | "Tx.Queryx": true, 93 | "Tx.QueryxContext": true, 94 | "Tx.PrepareNamed": true, 95 | "Tx.PrepareNamedContext": true, 96 | "Tx.NamedExec": true, 97 | "Tx.NamedExecContext": true, 98 | "Stmt.MustExec": true, 99 | "Stmt.MustExecContext": true, 100 | "Stmt.QueryRowx": true, 101 | "Stmt.QueryRowxContext": true, 102 | "Stmt.Queryx": true, 103 | "Stmt.QueryxContext": true, 104 | }, 105 | } 106 | 107 | func run(pass *analysis.Pass) (interface{}, error) { 108 | // Getting the list of import paths. 109 | var pkgs []string 110 | for pkg := range validExprs { 111 | pkgs = append(pkgs, pkg) 112 | } 113 | 114 | // We ignore packages that do not import the required paths. 115 | if !imports(pass.Pkg, true, pkgs...) { 116 | return nil, nil 117 | } 118 | 119 | inspect := pass.ResultOf[inspect.Analyzer].(*inspector.Inspector) 120 | // We filter only function calls. 121 | nodeFilter := []ast.Node{ 122 | (*ast.CallExpr)(nil), 123 | } 124 | 125 | inspect.Preorder(nodeFilter, func(n ast.Node) { 126 | call := n.(*ast.CallExpr) 127 | // Now we need to find expressions like these in the source code. 128 | // db.Exec(`INSERT INTO <> (foo, bar) VALUES ($1, $2)`, param1, param2) 129 | 130 | // A CallExpr has 2 parts - Fun and Args. 131 | // A Fun can either be an Ident (Fun()) or a SelectorExpr (foo.Fun()). 132 | // Since we are looking for patterns like db.Exec, we need to filter only SelectorExpr 133 | // We will ignore dot imported functions. 
134 | sel, ok := call.Fun.(*ast.SelectorExpr) 135 | if !ok { 136 | return 137 | } 138 | 139 | // A SelectorExpr(db.Exec) has 2 parts - X (db) and Sel (Exec/Query/QueryRow). 140 | // Now that we are inside the SelectorExpr, we need to verify 2 things - 141 | // 1. The function name is Exec, Query or QueryRow; because that is what we are interested in. 142 | // 2. The type of the selector is sql.DB, sql.Tx or sql.Stmt. 143 | if !isProperSelExpr(sel, pass.TypesInfo) { 144 | return 145 | } 146 | // Length of args has to be minimum of 1 because we only take Exec, Query or QueryRow; 147 | // all of which have atleast 1 argument. But still writing a sanity check. 148 | if len(call.Args) == 0 { 149 | return 150 | } 151 | 152 | // Check if it is a Context call, then re-slice the first item which is a context. 153 | // XXX: This is a heuristic. Most DB code which takes context always ends with "Context" 154 | // and takes the ctx as the first param. But there is no guarantee for this. 155 | if strings.HasSuffix(sel.Sel.Name, "Context") { 156 | call.Args = call.Args[1:] 157 | } 158 | 159 | // Another heuristic: if a function begins with Prepare, it usually returns 160 | // a prepared statement; in which case we don't need to check for arguments. 161 | checkArgs := !strings.HasPrefix(sel.Sel.Name, "Prepare") 162 | 163 | arg0 := call.Args[0] 164 | typ, ok := pass.TypesInfo.Types[arg0] 165 | if !ok { 166 | return 167 | } 168 | query := "" 169 | if typ.Value != nil { 170 | query = constant.StringVal(typ.Value) 171 | } else { // query is a variable. 
172 | ident, ok := arg0.(*ast.Ident) 173 | if !ok { 174 | return 175 | } 176 | if ident.Obj == nil { 177 | return 178 | } 179 | assign, ok := ident.Obj.Decl.(*ast.AssignStmt) 180 | if !ok { 181 | return 182 | } 183 | basic, ok := assign.Rhs[0].(*ast.BasicLit) 184 | if !ok { 185 | return 186 | } 187 | query, _ = strconv.Unquote(basic.Value) 188 | } 189 | analyzeQuery(query, call, pass, checkArgs) 190 | }) 191 | 192 | return nil, nil 193 | } 194 | 195 | func isProperSelExpr(sel *ast.SelectorExpr, typesInfo *types.Info) bool { 196 | // Get the type info of X of the selector. 197 | typ, ok := typesInfo.Types[sel.X] 198 | if !ok { 199 | return false 200 | } 201 | 202 | t := typ.Type 203 | // If it is a pointer, get the element. 204 | if ptr, ok := t.(*types.Pointer); ok { 205 | t = ptr.Elem() 206 | } 207 | named, ok := t.(*types.Named) 208 | if !ok { 209 | return false 210 | } 211 | 212 | fnName := sel.Sel.Name 213 | objName := named.Obj().Name() 214 | 215 | // Check valid selector expressions for a match. 216 | for path, obj := range validExprs { 217 | // If the object is a direct match. 218 | if imports(named.Obj().Pkg(), false, path) && obj[objName+"."+fnName] { 219 | return true 220 | } 221 | 222 | // Otherwise, it can be a struct which embeds *sql.DB 223 | u := named.Underlying() 224 | st, ok := u.(*types.Struct) 225 | if !ok { 226 | continue 227 | } 228 | for i := 0; i < st.NumFields(); i++ { 229 | f := st.Field(i) 230 | // check if the embedded field is *sql.DB-ish or not. 
231 | if f.Embedded() && imports(f.Pkg(), true, path) && obj[f.Name()+"."+fnName] { 232 | return true 233 | } 234 | } 235 | } 236 | return false 237 | } 238 | 239 | func imports(pkg *types.Package, checkImports bool, paths ...string) bool { 240 | if pkg == nil { 241 | return false 242 | } 243 | if checkImports { 244 | for _, imp := range pkg.Imports() { 245 | for _, p := range paths { 246 | if stripVendor(imp.Path()) == p { 247 | return true 248 | } 249 | } 250 | } 251 | } else { 252 | for _, p := range paths { 253 | if stripVendor(pkg.Path()) == p { 254 | return true 255 | } 256 | } 257 | } 258 | return false 259 | } 260 | 261 | // stripVendor strips out the vendor path prefix 262 | func stripVendor(pkgPath string) string { 263 | idx := strings.LastIndex(pkgPath, "vendor/") 264 | if idx < 0 { 265 | return pkgPath 266 | } 267 | // len("vendor/") == 7 268 | return pkgPath[idx+7:] 269 | } 270 | --------------------------------------------------------------------------------