├── .gitignore ├── .travis.yml ├── LICENSE ├── README.md ├── gif └── sql.gif ├── go.mod ├── go.sum ├── sqlfmt.go └── sqlfmt ├── ast.go ├── errors.go ├── format.go ├── format_test.go ├── lexer ├── token.go ├── tokenizer.go └── tokenizer_test.go ├── parser ├── group │ ├── and_group.go │ ├── and_group_test.go │ ├── case_group.go │ ├── case_group_test.go │ ├── delete_group.go │ ├── delete_group_test.go │ ├── from_group.go │ ├── from_group_test.go │ ├── function_group.go │ ├── function_group_test.go │ ├── group_by_group.go │ ├── group_by_group_test.go │ ├── having_group.go │ ├── having_group_test.go │ ├── insert_group.go │ ├── insert_group_test.go │ ├── join_group.go │ ├── join_group_test.go │ ├── limit_clause.go │ ├── limit_clause_test.go │ ├── lock_group.go │ ├── lock_group_test.go │ ├── or_group.go │ ├── or_group_test.go │ ├── order_by_group.go │ ├── order_by_group_test.go │ ├── parenthesis_group.go │ ├── reindenter.go │ ├── returning_group.go │ ├── returning_group_test.go │ ├── select_group.go │ ├── select_group_test.go │ ├── set_group.go │ ├── set_group_test.go │ ├── subquery_group.go │ ├── subquery_group_test.go │ ├── tie_clause_group.go │ ├── tie_clause_test.go │ ├── type_cast_group.go │ ├── update_group.go │ ├── update_group_test.go │ ├── util.go │ ├── values_group.go │ ├── values_group_test.go │ ├── where_group.go │ ├── where_group_test.go │ └── with_group.go ├── parser.go ├── parser_test.go ├── retriever.go └── retriever_test.go ├── sqlfmt.go └── testdata ├── testing_gofile.go └── testing_gofile_url_query.go /.gitignore: -------------------------------------------------------------------------------- 1 | dummy.go 2 | vendor 3 | .DS_Store 4 | 5 | 6 | # for GoLand users 7 | .idea/**/workspace.xml 8 | .idea/**/tasks.xml 9 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: go 2 | 3 | go: 4 | - "1.11" 5 | - "1.12" 6 | - tip 7 | 8 | script: 9 | - go fmt ./... 10 | - go test -v ./sqlfmt 11 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright (c) 2019 Yu Tanaka 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE.
-------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # sqlfmt 2 | 3 | [![Build Status](https://travis-ci.org/kanmu/go-sqlfmt.svg?branch=master)](https://travis-ci.org/kanmu/go-sqlfmt) 4 | [![Go Report Card](https://goreportcard.com/badge/github.com/kanmu/go-sqlfmt)](https://goreportcard.com/report/github.com/kanmu/go-sqlfmt) 5 | 6 | ## Description 7 | 8 | sqlfmt formats PostgreSQL statements in `.go` files into a consistent style. 9 | 10 | ## Example 11 | 12 | _Unformatted SQL in a `.go` file_ 13 | 14 | ```go 15 | package main 16 | 17 | import ( 18 | "database/sql" 19 | ) 20 | 21 | 22 | func sendSQL() int { 23 | var id int 24 | var db *sql.DB 25 | db.QueryRow(` 26 | select xxx ,xxx ,xxx 27 | , case 28 | when xxx is null then xxx 29 | else true 30 | end as xxx 31 | from xxx as xxx join xxx on xxx = xxx join xxx as xxx on xxx = xxx 32 | left outer join xxx as xxx 33 | on xxx = xxx 34 | where xxx in ( select xxx from ( select xxx from xxx ) as xxx where xxx = xxx ) 35 | and xxx in ($2, $3) order by xxx`).Scan(&id) 36 | return id 37 | } 38 | ``` 39 | 40 | The above will be formatted into the following: 41 | 42 | ```go 43 | package main 44 | 45 | import ( 46 | "database/sql" 47 | ) 48 | 49 | func sendSQL() int { 50 | var id int 51 | var db *sql.DB 52 | db.QueryRow(` 53 | SELECT 54 | xxx 55 | , xxx 56 | , xxx 57 | , CASE 58 | WHEN xxx IS NULL THEN xxx 59 | ELSE true 60 | END AS xxx 61 | FROM xxx AS xxx 62 | JOIN xxx 63 | ON xxx = xxx 64 | JOIN xxx AS xxx 65 | ON xxx = xxx 66 | LEFT OUTER JOIN xxx AS xxx 67 | ON xxx = xxx 68 | WHERE xxx IN ( 69 | SELECT 70 | xxx 71 | FROM ( 72 | SELECT 73 | xxx 74 | FROM xxx 75 | ) AS xxx 76 | WHERE xxx = xxx 77 | ) 78 | AND xxx IN ($2, $3) 79 | ORDER BY 80 | xxx`).Scan(&id) 81 | return id 82 | } 83 | ``` 84 | 85 | ## Installation 86 | 87 | ```bash 88 | $ git clone https://github.com/kanmu/go-sqlfmt.git && cd go-sqlfmt && go build -o sqlfmt . 89 | ``` 90 | ## Usage 91 | 92 | - Provide flags and input files or a directory 93 | ```bash 94 | $ sqlfmt -w input_file.go 95 | ``` 96 | 97 | ## Flags 98 | ``` 99 | -l 100 | Do not print reformatted sources to standard output. 101 | If a file's formatting is different from src, print its name 102 | to standard output. 103 | -d 104 | Do not print reformatted sources to standard output. 105 | If a file's formatting is different from src, print diffs 106 | to standard output. 107 | -w 108 | Do not print reformatted sources to standard output. 109 | If a file's formatting is different from src, overwrite it 110 | with the formatted version. 111 | -distance 112 | Write the distance from the edge to the beginning of SQL statements 113 | ``` 114 | 115 | ## Limitations 116 | 117 | - sqlfmt can only format SQL statements that are surrounded with **back quotes** and passed to the **`QueryRow`**, **`Query`**, and **`Exec`** functions from the `"database/sql"` package.
118 | 119 | The following SQL statements will be formatted: 120 | 121 | ```go 122 | func sendSQL() int { 123 | var id int 124 | var db *sql.DB 125 | db.QueryRow(`select xxx from xxx`).Scan(&id) 126 | return id 127 | } 128 | ``` 129 | 130 | The following SQL statements will NOT be formatted: 131 | 132 | ```go 133 | // values in fmt.Println() are not formatting targets 134 | func sendSQL() int { 135 | fmt.Println(`select * from xxx`) 136 | } 137 | 138 | // nor are statements surrounded with double quotes 139 | func sendSQL() int { 140 | var id int 141 | var db *sql.DB 142 | db.QueryRow("select xxx from xxx").Scan(&id) 143 | return id 144 | } 145 | ``` 146 | 147 | ## Not Supported 148 | 149 | - `IS DISTINCT FROM` 150 | - `WITHIN GROUP` 151 | - `DISTINCT ON(xxx)` 152 | - `select(array)` 153 | - Comments after commas, such as 154 | `select xxxx, --comment 155 | xxxx 156 | ` 157 | - Nested square brackets or braces such as `[[xx], xx]` 158 | - Currently being formatted into this: `[[ xx], xx]` 159 | - Ideally, it should be formatted into this: `[[xx], xx]` 160 | 161 | - Nested functions such as `sum(average(xxx))` 162 | - Currently being formatted into this: `SUM( AVERAGE(xxx))` 163 | - Ideally, it should be formatted into this: `SUM(AVERAGE(xxx))` 164 | 165 | 166 | 167 | ## Future Work 168 | 169 | - [ ] Refactor 170 | - [ ] Turn it into a plug-in or an extension for editors 171 | 172 | ## Contribution 173 | 174 | Thank you for thinking of contributing to sqlfmt! 175 | Pull Requests are welcome! 176 | 177 | 1. Fork the repository (https://github.com/kanmu/go-sqlfmt) 178 | 2. Create a feature branch 179 | 3. Commit your changes 180 | 4. Rebase your local changes against the master branch 181 | 5. Create a new Pull Request 182 | 183 | ## License 184 | 185 | MIT 186 | -------------------------------------------------------------------------------- /gif/sql.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kanmu/go-sqlfmt/d1e63e2ee5eb36cbbc28c9d9471ab05786b5dae7/gif/sql.gif -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/kanmu/go-sqlfmt 2 | 3 | go 1.13 4 | 5 | require github.com/pkg/errors v0.8.1 6 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I= 2 | github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= 3 | -------------------------------------------------------------------------------- /sqlfmt.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "bytes" 5 | "flag" 6 | "fmt" 7 | "io" 8 | "io/ioutil" 9 | "log" 10 | "os" 11 | "os/exec" 12 | "path/filepath" 13 | "runtime" 14 | "strings" 15 | 16 | "github.com/pkg/errors" 17 | 18 | "github.com/kanmu/go-sqlfmt/sqlfmt" 19 | ) 20 | 21 | var ( 22 | // main operation modes 23 | list = flag.Bool("l", false, "list files whose formatting differs from sqlfmt's") 24 | write = flag.Bool("w", false, "write result to (source) file instead of stdout") 25 | doDiff = flag.Bool("d", false, "display diffs instead of rewriting files") 26 | options = &sqlfmt.Options{} 27 | ) 28 | 29 | func init() { 30 | flag.IntVar(&options.Distance, "distance", 0, "write the distance from
the edge to the beginning of SQL statements") 31 | } 32 | 33 | func usage() { 34 | fmt.Fprintf(os.Stderr, "usage: sqlfmt [flags] [path ...]\n") 35 | flag.PrintDefaults() 36 | } 37 | 38 | func isGoFile(info os.FileInfo) bool { 39 | name := info.Name() 40 | return !info.IsDir() && !strings.HasPrefix(name, ".") && strings.HasSuffix(name, ".go") 41 | } 42 | 43 | func visitFile(path string, info os.FileInfo, err error) error { 44 | if err == nil && isGoFile(info) { 45 | err = processFile(path, nil, os.Stdout) 46 | } 47 | if err != nil { 48 | processError(errors.Wrap(err, "visit file failed")) 49 | 50 | } 51 | return nil 52 | } 53 | 54 | func walkDir(path string) { 55 | filepath.Walk(path, visitFile) 56 | } 57 | 58 | func processFile(filename string, in io.Reader, out io.Writer) error { 59 | if in == nil { 60 | f, err := os.Open(filename) 61 | if err != nil { 62 | return errors.Wrap(err, "os.Open failed") 63 | } 64 | in = f 65 | } 66 | 67 | src, err := ioutil.ReadAll(in) 68 | if err != nil { 69 | return errors.Wrap(err, "ioutil.ReadAll failed") 70 | } 71 | 72 | res, err := sqlfmt.Process(filename, src, options) 73 | if err != nil { 74 | return errors.Wrap(err, "sqlfmt.Process failed") 75 | } 76 | 77 | if !bytes.Equal(src, res) { 78 | if *list { 79 | fmt.Fprintln(out, filename) 80 | } 81 | if *write { 82 | if err = ioutil.WriteFile(filename, res, 0); err != nil { 83 | return errors.Wrap(err, "ioutil.WriteFile failed") 84 | } 85 | } 86 | if *doDiff { 87 | data, err := diff(src, res) 88 | if err != nil { 89 | return errors.Wrap(err, "diff failed") 90 | } 91 | fmt.Printf("diff %s sqlfmt/%s\n", filename, filename) 92 | out.Write(data) 93 | } 94 | if !*list && !*write && !*doDiff { 95 | _, err = out.Write(res) 96 | if err != nil { 97 | return errors.Wrap(err, "out.Write failed") 98 | } 99 | } 100 | } 101 | return nil 102 | } 103 | 104 | func sqlfmtMain() { 105 | flag.Usage = usage 106 | flag.Parse() 107 | 108 | // the user is piping their source into go-sqlfmt 109 | if flag.NArg() == 0 { 110 | if *write { 111 | log.Fatal("cannot use -w with standard input") 112 | } 113 | if err := processFile("", os.Stdin, os.Stdout); err != nil { 114 | processError(errors.Wrap(err, "processFile failed")) 115 | } 116 | return 117 | } 118 | 119 | for i := 0; i < flag.NArg(); i++ { 120 | path := flag.Arg(i) 121 | switch dir, err := os.Stat(path); { 122 | case err != nil: 123 | processError(err) 124 | case dir.IsDir(): 125 | walkDir(path) 126 | default: 127 | info, err := os.Stat(path) 128 | if err != nil { 129 | processError(err) 130 | } 131 | if isGoFile(info) { 132 | err = processFile(path, nil, os.Stdout) 133 | if err != nil { 134 | processError(err) 135 | } 136 | } 137 | } 138 | } 139 | } 140 | 141 | func main() { 142 | runtime.GOMAXPROCS(runtime.NumCPU()) 143 | sqlfmtMain() 144 | } 145 | 146 | func diff(b1, b2 []byte) (data []byte, err error) { 147 | f1, err := ioutil.TempFile("", "sqlfmt") 148 | if err != nil { 149 | return 150 | } 151 | defer os.Remove(f1.Name()) 152 | defer f1.Close() 153 | 154 | f2, err := ioutil.TempFile("", "sqlfmt") 155 | if err != nil { 156 | return 157 | } 158 | defer os.Remove(f2.Name()) 159 | defer f2.Close() 160 | 161 | f1.Write(b1) 162 | f2.Write(b2) 163 | 164 | data, err = exec.Command("diff", "-u", f1.Name(), f2.Name()).CombinedOutput() 165 | if len(data) > 0 { 166 | // diff exits with a non-zero status when the files don't match. 167 | // Ignore that failure as long as we get output.
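// (On most systems diff exits 0 when the inputs match, 1 when they
// differ, and greater than 1 on a real error, so a non-nil err paired
// with output here normally just signals a difference, not a failure.)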
168 | err = nil 169 | } 170 | return 171 | } 172 | 173 | func processError(err error) { 174 | switch err.(type) { 175 | case *sqlfmt.FormatError: 176 | log.Println(err) 177 | default: 178 | log.Fatal(err) 179 | } 180 | } 181 | -------------------------------------------------------------------------------- /sqlfmt/ast.go: -------------------------------------------------------------------------------- 1 | package sqlfmt 2 | 3 | import ( 4 | "fmt" 5 | "go/ast" 6 | "go/token" 7 | "log" 8 | "strings" 9 | 10 | "github.com/kanmu/go-sqlfmt/sqlfmt/parser/group" 11 | ) 12 | 13 | // sqlfmt retrieves all strings from "Query", "QueryRow", and "Exec" functions in a .go file 14 | const ( 15 | QUERY = "Query" 16 | QUERYROW = "QueryRow" 17 | EXEC = "Exec" 18 | ) 19 | 20 | // replaceAst replaces each AST node that holds a SQL statement with its formatted version 21 | func replaceAst(f *ast.File, fset *token.FileSet, options *Options) { 22 | ast.Inspect(f, func(n ast.Node) bool { 23 | if x, ok := n.(*ast.CallExpr); ok { 24 | if fun, ok := x.Fun.(*ast.SelectorExpr); ok { 25 | funcName := fun.Sel.Name 26 | if funcName == QUERY || funcName == QUERYROW || funcName == EXEC { 27 | // skip calls such as url.Query() that take no arguments 28 | if len(x.Args) > 0 { 29 | if arg, ok := x.Args[0].(*ast.BasicLit); ok { 30 | sqlStmt := arg.Value 31 | if !strings.HasPrefix(sqlStmt, "`") { 32 | return true 33 | } 34 | src := strings.Trim(sqlStmt, "`") 35 | res, err := Format(src, options) 36 | if err != nil { 37 | log.Println(fmt.Sprintf("Format failed at %s: %v", fset.Position(arg.Pos()), err)) 38 | return true 39 | } 40 | // FIXME 41 | // restore the backquotes more elegantly 42 | arg.Value = "`" + res + strings.Repeat(group.WhiteSpace, options.Distance) + "`" 43 | } 44 | } 45 | } 46 | } 47 | } 48 | return true 49 | }) 50 | } 51 | -------------------------------------------------------------------------------- /sqlfmt/errors.go: -------------------------------------------------------------------------------- 1 | package sqlfmt 2 | 3 | import ( 4 | "fmt" 5 | ) 6 | 7 | // FormatError is an error that occurred during sqlfmt.Process 8 | type FormatError struct { 9 | msg string 10 | } 11 | 12 | func (e *FormatError) Error() string { 13 | return fmt.Sprint(e.msg) 14 | } 15 | -------------------------------------------------------------------------------- /sqlfmt/format.go: -------------------------------------------------------------------------------- 1 | package sqlfmt 2 | 3 | import ( 4 | "bufio" 5 | "bytes" 6 | "fmt" 7 | "strings" 8 | 9 | "github.com/kanmu/go-sqlfmt/sqlfmt/lexer" 10 | "github.com/kanmu/go-sqlfmt/sqlfmt/parser" 11 | "github.com/kanmu/go-sqlfmt/sqlfmt/parser/group" 12 | "github.com/pkg/errors" 13 | ) 14 | 15 | // Format formats src in 3 steps 16 | // 1: tokenize src 17 | // 2: parse tokens by SQL clause group 18 | // 3: for each clause group (Reindenter), add indentation or new line in the correct position 19 | func Format(src string, options *Options) (string, error) { 20 | t := lexer.NewTokenizer(src) 21 | tokens, err := t.GetTokens() 22 | if err != nil { 23 | return src, errors.Wrap(err, "Tokenize failed") 24 | } 25 | 26 | rs, err := parser.ParseTokens(tokens) 27 | if err != nil { 28 | return src, errors.Wrap(err, "ParseTokens failed") 29 | } 30 | 31 | res, err := getFormattedStmt(rs, options.Distance) 32 | if err != nil { 33 | return src, errors.Wrap(err, "getFormattedStmt failed") 34 | } 35 | 36 | if !compare(src, res) { 37 | return src, fmt.Errorf("the formatted statement differs from the source") 38 | } 39 | return res, nil 40 | } 41 | 42 | func getFormattedStmt(rs
[]group.Reindenter, distance int) (string, error) { 43 | var buf bytes.Buffer 44 | 45 | for _, r := range rs { 46 | if err := r.Reindent(&buf); err != nil { 47 | return "", errors.Wrap(err, "Reindent failed") 48 | } 49 | } 50 | 51 | if distance != 0 { 52 | return putDistance(buf.String(), distance), nil 53 | } 54 | return buf.String(), nil 55 | } 56 | 57 | func putDistance(src string, distance int) string { 58 | scanner := bufio.NewScanner(strings.NewReader(src)) 59 | 60 | var result string 61 | for scanner.Scan() { 62 | result += fmt.Sprintf("%s%s%s", strings.Repeat(group.WhiteSpace, distance), scanner.Text(), "\n") 63 | } 64 | return result 65 | } 66 | 67 | // compare returns false if the formatted statement (ignoring whitespace) differs from the source statement 68 | func compare(src string, res string) bool { 69 | before := removeSpace(src) 70 | after := removeSpace(res) 71 | 72 | if v := strings.Compare(before, after); v != 0 { 73 | return false 74 | } 75 | return true 76 | } 77 | 78 | // removeSpace removes whitespace and new lines from src 79 | func removeSpace(src string) string { 80 | var result []rune 81 | for _, r := range src { 82 | if string(r) == "\n" || string(r) == " " || string(r) == "\t" || string(r) == "　" { 83 | continue 84 | } 85 | result = append(result, r) 86 | } 87 | return strings.ToLower(string(result)) 88 | } 89 | -------------------------------------------------------------------------------- /sqlfmt/format_test.go: -------------------------------------------------------------------------------- 1 | package sqlfmt 2 | 3 | import ( 4 | "testing" 5 | ) 6 | 7 | func TestCompare(t *testing.T) { 8 | test := struct { 9 | before string 10 | after string 11 | want bool 12 | }{ 13 | before: "select * from xxx", 14 | after: "select\n *\nFROM xxx", 15 | want: true, 16 | } 17 | if got := compare(test.before, test.after); got != test.want { 18 | t.Errorf("want %#v got %#v", test.want, got) 19 | } 20 | } 21 | 22 | func TestRemove(t *testing.T) { 23 | got := removeSpace("select xxx from xxx") 24 | want := "selectxxxfromxxx" 25 | if got != want { 26 | t.Errorf("want %#v, got %#v", want, got) 27 | } 28 | } 29 | 30 | func TestFormat(t *testing.T) { 31 | for _, tt := range formatTestingData { 32 | opt := &Options{} 33 | t.Run(tt.src, func(t *testing.T) { 34 | got, err := Format(tt.src, opt) 35 | if err != nil { 36 | t.Errorf("should be nil, got %v", err) 37 | } 38 | if tt.want != got { 39 | t.Errorf("\nwant %#v, \ngot %#v", tt.want, got) 40 | } 41 | }) 42 | } 43 | } 44 | 45 | var formatTestingData = []struct { 46 | src string 47 | want string 48 | }{ 49 | { 50 | src: `select name, age, id from user join transaction on a = b`, 51 | want: ` 52 | SELECT 53 | name 54 | , age 55 | , id 56 | FROM user 57 | JOIN transaction 58 | ON a = b`, 59 | }, 60 | { 61 | src: `select 62 | xxx 63 | , xxx 64 | , xxx, 65 | xxx 66 | 67 | from xxx a 68 | join xxx b 69 | on xxx = xxx 70 | join xxx xxx 71 | on xxx = xxx 72 | left outer join xxx 73 | on xxx = xxx 74 | where xxx = xxx 75 | and xxx = true 76 | and xxx is null 77 | `, 78 | want: ` 79 | SELECT 80 | xxx 81 | , xxx 82 | , xxx 83 | , xxx 84 | FROM xxx a 85 | JOIN xxx b 86 | ON xxx = xxx 87 | JOIN xxx xxx 88 | ON xxx = xxx 89 | LEFT OUTER JOIN xxx 90 | ON xxx = xxx 91 | WHERE xxx = xxx 92 | AND xxx = true 93 | AND xxx IS NULL`, 94 | }, 95 | { 96 | src: `select 97 | count(xxx) 98 | from xxx as xxx 99 | join xxx on xxx = xxx 100 | where xxx = $1 101 | and xxx = $2`, 102 | want: ` 103 | SELECT 104 | COUNT(xxx) 105 | FROM xxx AS xxx 106 | JOIN xxx 107 | ON xxx = xxx
108 | WHERE xxx = $1 109 | AND xxx = $2`, 110 | }, { 111 | src: `select 112 | xxx 113 | from xxx 114 | and xxx in ($3, $4, $5, $6)`, 115 | want: ` 116 | SELECT 117 | xxx 118 | FROM xxx 119 | AND xxx IN ($3, $4, $5, $6)`, 120 | }, 121 | { 122 | src: `select 123 | xxx 124 | , xxx 125 | , xxx 126 | , case 127 | when xxx is null THEN $1 128 | ELSE $2 129 | end as xxx 130 | , case 131 | when xxx is null THEN 0 132 | ELSE xxx 133 | end as xxx 134 | from xxx`, 135 | want: ` 136 | SELECT 137 | xxx 138 | , xxx 139 | , xxx 140 | , CASE 141 | WHEN xxx IS NULL THEN $1 142 | ELSE $2 143 | END AS xxx 144 | , CASE 145 | WHEN xxx IS NULL THEN 0 146 | ELSE xxx 147 | END AS xxx 148 | FROM xxx`, 149 | }, 150 | { 151 | src: `select 152 | xxx 153 | from xxx 154 | order by xxx`, 155 | want: ` 156 | SELECT 157 | xxx 158 | FROM xxx 159 | ORDER BY 160 | xxx`, 161 | }, 162 | { 163 | src: `select 164 | xxx 165 | from xxx 166 | where xxx in ($1, $2) 167 | order by id`, 168 | want: ` 169 | SELECT 170 | xxx 171 | FROM xxx 172 | WHERE xxx IN ($1, $2) 173 | ORDER BY 174 | id`, 175 | }, 176 | { 177 | src: `UPDATE xxx 178 | SET 179 | first_name = $1 180 | , last_name = $2 181 | WHERE id = $8`, 182 | want: ` 183 | UPDATE 184 | xxx 185 | SET 186 | first_name = $1 187 | , last_name = $2 188 | WHERE id = $8`, 189 | }, 190 | { 191 | src: `UPDATE xxx 192 | SET 193 | xxx = $1 194 | , xxx = $2 195 | WHERE xxx = $3 196 | RETURNING 197 | xxx 198 | , xxx`, 199 | want: ` 200 | UPDATE 201 | xxx 202 | SET 203 | xxx = $1 204 | , xxx = $2 205 | WHERE xxx = $3 206 | RETURNING 207 | xxx 208 | , xxx`, 209 | }, 210 | { 211 | src: `SELECT 212 | xxx 213 | from xxx 214 | where xxx = $1 215 | and xxx != $3`, 216 | want: ` 217 | SELECT 218 | xxx 219 | FROM xxx 220 | WHERE xxx = $1 221 | AND xxx != $3`, 222 | }, 223 | { 224 | src: `SELECT 225 | xxx 226 | , xxx 227 | , CASE 228 | WHEN xxx IS NULL and xxx IS NULL THEN $1 229 | WHEN xxx IS NULL and xxx IS NOT NULL THEN xxx 230 | ELSE $4 231 | END 232 | , CASE 233 | WHEN xxx IS NULL xxx IS NULL THEN xxx 234 | WHEN xxx IS NULL THEN xxx 235 | ELSE xxx 236 | END AS xxx 237 | FROM xxx 238 | LEFT OUTER JOIN xxx 239 | ON xxx = xxx 240 | LEFT OUTER JOIN ( 241 | select 242 | xxx, xxx, xxx ,xxx 243 | from xxx 244 | join xxx 245 | on xxx = xxx 246 | join xxx 247 | on xxx = xxx 248 | where xxx = $5 249 | ) as xxx 250 | ON xxx = xxx 251 | where xxx`, 252 | want: ` 253 | SELECT 254 | xxx 255 | , xxx 256 | , CASE 257 | WHEN xxx IS NULL AND xxx IS NULL THEN $1 258 | WHEN xxx IS NULL AND xxx IS NOT NULL THEN xxx 259 | ELSE $4 260 | END 261 | , CASE 262 | WHEN xxx IS NULL xxx IS NULL THEN xxx 263 | WHEN xxx IS NULL THEN xxx 264 | ELSE xxx 265 | END AS xxx 266 | FROM xxx 267 | LEFT OUTER JOIN xxx 268 | ON xxx = xxx 269 | LEFT OUTER JOIN ( 270 | SELECT 271 | xxx 272 | , xxx 273 | , xxx 274 | , xxx 275 | FROM xxx 276 | JOIN xxx 277 | ON xxx = xxx 278 | JOIN xxx 279 | ON xxx = xxx 280 | WHERE xxx = $5 281 | ) AS xxx 282 | ON xxx = xxx 283 | WHERE xxx`, 284 | }, 285 | { 286 | src: `SELECT 287 | xxx 288 | , xxx 289 | , CASE 290 | WHEN xxx IS NULL AND xxx IS NULL THEN xxx 291 | WHEN xxx > xxx THEN xxx 292 | WHEN xxx <= xxx THEN xxx 293 | END as xxx 294 | , xxx 295 | , xxx 296 | , xxx 297 | , xxx 298 | , xxx 299 | , sum(xxx, 0) as xxx 300 | , SUM(xxx, 0) as xxx 301 | , avg(xxx, 0) as xxx 302 | , AVG(xxx, 0) as xxx 303 | , max(xxx, 0) as xxx 304 | , min(xxx, 0) as xxx 305 | , extract(xxx, 0) as xxx 306 | , EXTRACT(xxx, 0) as xxx 307 | , cast(xxx, 0) as xxx 308 | , TRIM(xxx, 0) as xxx 309 | , xmlforest(xxx, 0) as xxx 310 | FROM 
xxx 311 | LEFT OUTER JOIN ( 312 | SELECT 313 | xxx 314 | , xxx 315 | , xxx 316 | FROM ( 317 | select xxx from xxx 318 | ) 319 | WHERE xxx = $1 320 | ) xxx 321 | ON xxx 322 | WHERE xxx = xxx`, 323 | want: ` 324 | SELECT 325 | xxx 326 | , xxx 327 | , CASE 328 | WHEN xxx IS NULL AND xxx IS NULL THEN xxx 329 | WHEN xxx > xxx THEN xxx 330 | WHEN xxx <= xxx THEN xxx 331 | END AS xxx 332 | , xxx 333 | , xxx 334 | , xxx 335 | , xxx 336 | , xxx 337 | , SUM(xxx, 0) AS xxx 338 | , SUM(xxx, 0) AS xxx 339 | , AVG(xxx, 0) AS xxx 340 | , AVG(xxx, 0) AS xxx 341 | , MAX(xxx, 0) AS xxx 342 | , MIN(xxx, 0) AS xxx 343 | , EXTRACT(xxx, 0) AS xxx 344 | , EXTRACT(xxx, 0) AS xxx 345 | , CAST(xxx, 0) AS xxx 346 | , TRIM(xxx, 0) AS xxx 347 | , XMLFOREST(xxx, 0) AS xxx 348 | FROM xxx 349 | LEFT OUTER JOIN ( 350 | SELECT 351 | xxx 352 | , xxx 353 | , xxx 354 | FROM ( 355 | SELECT 356 | xxx 357 | FROM xxx 358 | ) 359 | WHERE xxx = $1 360 | ) xxx 361 | ON xxx 362 | WHERE xxx = xxx`, 363 | }, 364 | { 365 | src: `select 1 + 1, 2 - 1, 3 * 2, 8 / 2, 366 | 1 + 1 * 3, 3 + 8 / 7, 367 | 1+1*3, 312+8/7, 368 | 4%3, 7^5 from xxx`, 369 | want: ` 370 | SELECT 371 | 1 + 1 372 | , 2 - 1 373 | , 3 * 2 374 | , 8 / 2 375 | , 1 + 1 * 3 376 | , 3 + 8 / 7 377 | , 1+1*3 378 | , 312+8/7 379 | , 4%3 380 | , 7^5 381 | FROM xxx`, 382 | }, 383 | { 384 | src: `select 385 | array[], 386 | array[1] 387 | from 388 | baz`, 389 | want: ` 390 | SELECT 391 | array [] 392 | , array [1] 393 | FROM baz`, 394 | }, 395 | 396 | { 397 | src: `select 398 | foo, 399 | array ( select 400 | bar 401 | from 402 | quz 403 | where 404 | baz.foo = quz.foo 405 | ) 406 | from 407 | baz`, 408 | want: ` 409 | SELECT 410 | foo 411 | , array ( 412 | SELECT 413 | bar 414 | FROM quz 415 | WHERE baz.foo = quz.foo 416 | ) 417 | FROM baz`, 418 | }, 419 | { 420 | src: ` 421 | select 422 | '{1,2,3}'::int[], 423 | '{{1,2}, {3,4}}'::int[][], 424 | '{{1,2}, {3,4}}'::int[][2] 425 | from xxx 426 | `, 427 | want: ` 428 | SELECT 429 | '{1,2,3}'::int [] 430 | , '{{1,2}, {3,4}}'::int [] [] 431 | , '{{1,2}, {3,4}}'::int [] [2] 432 | FROM xxx`, 433 | }, 434 | { 435 | src: ` 436 | select 437 | '2015-01-01 00:00:00-09'::timestamptz at time zone 'America/Chicago' 438 | from xxx`, 439 | want: ` 440 | SELECT 441 | '2015-01-01 00:00:00-09'::timestamptz AT TIME ZONE 'America/Chicago' 442 | FROM xxx`, 443 | }, 444 | { 445 | src: `select 446 | foo between bexpr::text and bar, 447 | foo between -42 and bar, 448 | foo between +3 and bar, 449 | foo between 1 + 1 and bar, 450 | foo between 1 - 1 and bar, 451 | foo between 1 * 1 and bar, 452 | foo between 1 / 1 and bar, 453 | foo between 1 % 1 and bar, 454 | foo between 1 ^ 1 and bar, 455 | foo between 1 < 1 and bar, 456 | foo between 1 > 1 and bar, 457 | foo between 1 = 1 and bar, 458 | foo between 1 <= 1 and bar, 459 | foo between 1 >= 1 and bar, 460 | foo between 1 != 1 and bar, 461 | foo between 1 @> 1 and bar, 462 | foo between @1 and bar, 463 | foo between 5 ! 
and bar, 464 | false between foo is document and bar, 465 | false between foo is not document and bar 466 | from 467 | baz`, 468 | want: ` 469 | SELECT 470 | foo BETWEEN bexpr::text AND bar 471 | , foo BETWEEN -42 AND bar 472 | , foo BETWEEN +3 AND bar 473 | , foo BETWEEN 1 + 1 AND bar 474 | , foo BETWEEN 1 - 1 AND bar 475 | , foo BETWEEN 1 * 1 AND bar 476 | , foo BETWEEN 1 / 1 AND bar 477 | , foo BETWEEN 1 % 1 AND bar 478 | , foo BETWEEN 1 ^ 1 AND bar 479 | , foo BETWEEN 1 < 1 AND bar 480 | , foo BETWEEN 1 > 1 AND bar 481 | , foo BETWEEN 1 = 1 AND bar 482 | , foo BETWEEN 1 <= 1 AND bar 483 | , foo BETWEEN 1 >= 1 AND bar 484 | , foo BETWEEN 1 != 1 AND bar 485 | , foo BETWEEN 1 @> 1 AND bar 486 | , foo BETWEEN @1 AND bar 487 | , foo BETWEEN 5 ! AND bar 488 | , false BETWEEN foo IS document AND bar 489 | , false BETWEEN foo IS NOT document AND bar 490 | FROM baz`, 491 | }, 492 | { 493 | src: ` 494 | select 495 | b'10101', 496 | x'0123456789abcdefABCDEF' from xxx 497 | `, 498 | want: ` 499 | SELECT 500 | b '10101' 501 | , x '0123456789abcdefABCDEF' 502 | FROM xxx`, 503 | }, 504 | { 505 | src: ` 506 | select 507 | foo and bar, 508 | baz or quz 509 | from 510 | t 511 | `, 512 | want: ` 513 | SELECT 514 | foo AND bar 515 | , baz OR quz 516 | FROM t`, 517 | }, 518 | { 519 | src: ` 520 | select 521 | not foo, 522 | not true, 523 | not false, 524 | case 525 | when foo = bar then 526 | 7 527 | when foo > bar then 528 | 42 529 | else 530 | 1 531 | end 532 | from 533 | t 534 | `, 535 | want: ` 536 | SELECT 537 | NOT foo 538 | , NOT true 539 | , NOT false 540 | , CASE 541 | WHEN foo = bar THEN 7 542 | WHEN foo > bar THEN 42 543 | ELSE 1 544 | END 545 | FROM t`, 546 | }, 547 | { 548 | src: ` 549 | select 550 | case foo 551 | when 4 then 552 | 'A' 553 | when 3 then 554 | 'B' 555 | else 556 | 'C' 557 | end 558 | from 559 | baz 560 | `, 561 | want: ` 562 | SELECT 563 | CASE foo 564 | WHEN 4 THEN 'A' 565 | WHEN 3 THEN 'B' 566 | ELSE 'C' 567 | END 568 | FROM baz`, 569 | }, 570 | { 571 | src: ` 572 | select 573 | CAST('{1,2,3}' as int[]), 574 | 'Foo' collate "C", 575 | 'Bar' collate "en_US" 576 | from xxx`, 577 | want: ` 578 | SELECT 579 | CAST('{1,2,3}' AS int []) 580 | , 'Foo' COLLATE "C" 581 | , 'Bar' COLLATE "en_US" 582 | FROM xxx`, 583 | }, 584 | { 585 | src: `select 586 | 1 = 1, 587 | 2 > 1, 588 | 2 < 8, 589 | 1 != 2, 590 | 1 != 2, 591 | 3 >= 2, 592 | 2 <= 7 593 | from xxx 594 | `, 595 | want: ` 596 | SELECT 597 | 1 = 1 598 | , 2 > 1 599 | , 2 < 8 600 | , 1 != 2 601 | , 1 != 2 602 | , 3 >= 2 603 | , 2 <= 7 604 | FROM xxx`, 605 | }, 606 | { 607 | src: ` 608 | SELECT 609 | CHAR 'hi' 610 | , CHAR(2) 'hi' 611 | , VARCHAR 'hi' 612 | , VARCHAR(2) 'hi' 613 | , TIMESTAMP(4) '2000-01-01 00:00:00'`, 614 | want: ` 615 | SELECT 616 | CHAR 'hi' 617 | , CHAR(2) 'hi' 618 | , VARCHAR 'hi' 619 | , VARCHAR(2) 'hi' 620 | , TIMESTAMP(4) '2000-01-01 00:00:00'`, 621 | }, 622 | { 623 | src: ` 624 | select 625 | foo @> bar, 626 | @foo, 627 | 'foo' || 'bar' 628 | `, 629 | want: ` 630 | SELECT 631 | foo @> bar 632 | , @foo 633 | , 'foo' || 'bar'`, 634 | }, 635 | { 636 | src: ` 637 | select distinct 638 | foo, 639 | bar 640 | from 641 | baz 642 | `, 643 | want: ` 644 | SELECT DISTINCT 645 | foo 646 | , bar 647 | FROM baz`, 648 | }, 649 | { 650 | src: `select 651 | foo, 652 | bar 653 | from 654 | baz 655 | except 656 | select 657 | a, 658 | b 659 | from 660 | quz 661 | `, 662 | want: ` 663 | SELECT 664 | foo 665 | , bar 666 | FROM baz 667 | EXCEPT 668 | SELECT 669 | a 670 | , b 671 | FROM quz`, 672 | }, 673 | { 674 | src: `select 
675 | foo, 676 | bar 677 | from 678 | baz 679 | where 680 | exists ( select 681 | 1 682 | from 683 | quz 684 | ) 685 | `, 686 | want: ` 687 | SELECT 688 | foo 689 | , bar 690 | FROM baz 691 | WHERE EXISTS ( 692 | SELECT 693 | 1 694 | FROM quz 695 | )`, 696 | }, 697 | { 698 | src: `select 699 | extract(year from '2000-01-01 12:34:56'::timestamptz), 700 | extract(month from '2000-01-01 12:34:56'::timestamptz) 701 | `, 702 | want: ` 703 | SELECT 704 | EXTRACT(year FROM '2000-01-01 12:34:56'::timestamptz) 705 | , EXTRACT(month FROM '2000-01-01 12:34:56'::timestamptz)`, 706 | }, 707 | { 708 | src: `select 709 | coalesce(a, b, c), 710 | greatest(d, e, f), 711 | least(g, h, i), 712 | xmlconcat(j, k, l) 713 | from 714 | foo 715 | `, 716 | want: ` 717 | SELECT 718 | COALESCE(a, b, c) 719 | , GREATEST(d, e, f) 720 | , LEAST(g, h, i) 721 | , XMLCONCAT(j, k, l) 722 | FROM foo`, 723 | }, 724 | { 725 | src: `select 726 | foo, 727 | bar 728 | from 729 | baz 730 | intersect 731 | select 732 | a, 733 | b 734 | from 735 | quz 736 | `, 737 | want: ` 738 | SELECT 739 | foo 740 | , bar 741 | FROM baz 742 | INTERSECT 743 | SELECT 744 | a 745 | , b 746 | FROM quz`, 747 | }, 748 | { 749 | src: `select 750 | interval '5', 751 | interval '5' hour, 752 | interval '5' hour to minute, 753 | interval '5' second(5), 754 | interval(2) '10.324' 755 | `, 756 | want: ` 757 | SELECT 758 | INTERVAL '5' 759 | , INTERVAL '5' hour 760 | , INTERVAL '5' hour to minute 761 | , INTERVAL '5' SECOND(5) 762 | , INTERVAL(2) '10.324'`, 763 | }, 764 | { 765 | src: `select 766 | foo, 767 | bar 768 | from 769 | baz 770 | where 771 | foo like 'abd%' 772 | or foo like 'ada%' escape '!' 773 | `, 774 | want: ` 775 | SELECT 776 | foo 777 | , bar 778 | FROM baz 779 | WHERE foo LIKE 'abd%' 780 | OR foo LIKE 'ada%' escape '!'`, 781 | }, 782 | { 783 | src: `select 784 | foo, 785 | bar 786 | from 787 | baz 788 | limit 7 789 | offset 42`, 790 | want: ` 791 | SELECT 792 | foo 793 | , bar 794 | FROM baz 795 | LIMIT 7 796 | OFFSET 42`, 797 | }, 798 | { 799 | src: `select foo, bar from baz offset 42 rows fetch next 7 rows only 800 | `, 801 | want: ` 802 | SELECT 803 | foo 804 | , bar 805 | FROM baz 806 | OFFSET 42 ROWS 807 | FETCH next 7 ROWS only`, 808 | }, 809 | { 810 | src: `select 811 | foo, 812 | bar 813 | from 814 | baz 815 | order by 816 | foo desc nulls first, 817 | quz asc nulls last, 818 | abc nulls last 819 | `, 820 | want: ` 821 | SELECT 822 | foo 823 | , bar 824 | FROM baz 825 | ORDER BY 826 | foo DESC NULLS FIRST 827 | , quz ASC NULLS LAST 828 | , abc NULLS LAST`, 829 | }, 830 | { 831 | src: `select 832 | (date '2000-01-01', date '2000-01-31') overlaps (date '2000-01-15', date '2000-02-15') 833 | `, 834 | want: ` 835 | SELECT 836 | (date '2000-01-01', date '2000-01-31') OVERLAPS (date '2000-01-15', date '2000-02-15')`, 837 | }, 838 | { 839 | src: `select 840 | foo, 841 | row_number() over(range unbounded preceding) 842 | from 843 | baz 844 | `, 845 | want: ` 846 | SELECT 847 | foo 848 | , ROW_NUMBER() OVER(range unbounded preceding) 849 | FROM baz`, 850 | }, 851 | { 852 | src: `select xxx from xxx union all select xxx from xxx`, 853 | want: ` 854 | SELECT 855 | xxx 856 | FROM xxx 857 | UNION ALL 858 | SELECT 859 | xxx 860 | FROM xxx`, 861 | }, 862 | { 863 | src: `lock table in xxx`, 864 | want: ` 865 | LOCK table 866 | IN xxx`, 867 | }, 868 | } 869 | -------------------------------------------------------------------------------- /sqlfmt/lexer/token.go: -------------------------------------------------------------------------------- 1 | 
package lexer 2 | 3 | import ( 4 | "bytes" 5 | ) 6 | 7 | // Token types 8 | const ( 9 | EOF TokenType = 1 + iota // eof 10 | WS // white space 11 | NEWLINE 12 | FUNCTION 13 | COMMA 14 | STARTPARENTHESIS 15 | ENDPARENTHESIS 16 | STARTBRACKET 17 | ENDBRACKET 18 | STARTBRACE 19 | ENDBRACE 20 | TYPE 21 | IDENT // field or table name 22 | STRING // values surrounded with single quotes 23 | SELECT 24 | FROM 25 | WHERE 26 | CASE 27 | ORDER 28 | BY 29 | AS 30 | JOIN 31 | LEFT 32 | RIGHT 33 | INNER 34 | OUTER 35 | ON 36 | WHEN 37 | END 38 | GROUP 39 | DESC 40 | ASC 41 | LIMIT 42 | AND 43 | ANDGROUP 44 | OR 45 | ORGROUP 46 | IN 47 | IS 48 | NOT 49 | NULL 50 | DISTINCT 51 | LIKE 52 | BETWEEN 53 | UNION 54 | ALL 55 | HAVING 56 | OVER 57 | EXISTS 58 | UPDATE 59 | SET 60 | RETURNING 61 | DELETE 62 | INSERT 63 | INTO 64 | DO 65 | VALUES 66 | FOR 67 | THEN 68 | ELSE 69 | DISTINCTROW 70 | FILTER 71 | WITHIN 72 | COLLATE 73 | INTERVAL 74 | INTERSECT 75 | EXCEPT 76 | OFFSET 77 | FETCH 78 | FIRST 79 | ROWS 80 | USING 81 | OVERLAPS 82 | NATURAL 83 | CROSS 84 | ZONE 85 | NULLS 86 | LAST 87 | AT 88 | LOCK 89 | WITH 90 | 91 | QUOTEAREA 92 | SURROUNDING 93 | ) 94 | 95 | // TokenType is an alias type that represents a kind of token 96 | type TokenType int 97 | 98 | // Token is a token struct 99 | type Token struct { 100 | Type TokenType 101 | Value string 102 | } 103 | 104 | // Reindent is a placeholder for implementing Reindenter interface 105 | func (t Token) Reindent(buf *bytes.Buffer) error { return nil } 106 | 107 | // IncrementIndentLevel is a placeholder implementing Reindenter interface 108 | func (t Token) IncrementIndentLevel(lev int) {} 109 | 110 | // end keywords of each clause 111 | var ( 112 | EndOfSelect = []TokenType{FROM, UNION, EOF} 113 | EndOfCase = []TokenType{END} 114 | EndOfFrom = []TokenType{WHERE, INNER, OUTER, LEFT, RIGHT, JOIN, NATURAL, CROSS, ORDER, GROUP, UNION, OFFSET, LIMIT, FETCH, EXCEPT, INTERSECT, EOF, ENDPARENTHESIS} 115 | EndOfJoin = []TokenType{WHERE, ORDER, GROUP, LIMIT, OFFSET, FETCH, ANDGROUP, ORGROUP, LEFT, RIGHT, INNER, OUTER, NATURAL, CROSS, UNION, EXCEPT, INTERSECT, EOF, ENDPARENTHESIS} 116 | EndOfWhere = []TokenType{GROUP, ORDER, LIMIT, OFFSET, FETCH, ANDGROUP, OR, UNION, EXCEPT, INTERSECT, RETURNING, EOF, ENDPARENTHESIS} 117 | EndOfAndGroup = []TokenType{GROUP, ORDER, LIMIT, OFFSET, FETCH, UNION, EXCEPT, INTERSECT, ANDGROUP, ORGROUP, EOF, ENDPARENTHESIS} 118 | EndOfOrGroup = []TokenType{GROUP, ORDER, LIMIT, OFFSET, FETCH, UNION, EXCEPT, INTERSECT, ANDGROUP, ORGROUP, EOF, ENDPARENTHESIS} 119 | EndOfGroupBy = []TokenType{ORDER, LIMIT, FETCH, OFFSET, UNION, EXCEPT, INTERSECT, HAVING, EOF, ENDPARENTHESIS} 120 | EndOfHaving = []TokenType{LIMIT, OFFSET, FETCH, ORDER, UNION, EXCEPT, INTERSECT, EOF, ENDPARENTHESIS} 121 | EndOfOrderBy = []TokenType{LIMIT, FETCH, OFFSET, UNION, EXCEPT, INTERSECT, EOF, ENDPARENTHESIS} 122 | EndOfLimitClause = []TokenType{UNION, EXCEPT, INTERSECT, EOF, ENDPARENTHESIS} 123 | EndOfParenthesis = []TokenType{ENDPARENTHESIS} 124 | EndOfTieClause = []TokenType{SELECT} 125 | EndOfUpdate = []TokenType{WHERE, SET, RETURNING, EOF} 126 | EndOfSet = []TokenType{WHERE, RETURNING, EOF} 127 | EndOfReturning = []TokenType{EOF} 128 | EndOfDelete = []TokenType{WHERE, FROM, EOF} 129 | EndOfInsert = []TokenType{VALUES, EOF} 130 | EndOfValues = []TokenType{UPDATE, RETURNING, EOF} 131 | EndOfFunction = []TokenType{ENDPARENTHESIS} 132 | EndOfTypeCast = []TokenType{ENDPARENTHESIS} 133 | EndOfLock = []TokenType{EOF} 134 | EndOfWith = []TokenType{EOF} 135 | ) 136 | 137 | // 
token types whose keywords start a subgroup 138 | var ( 139 | TokenTypesOfGroupMaker = []TokenType{SELECT, CASE, FROM, WHERE, ORDER, GROUP, LIMIT, ANDGROUP, ORGROUP, HAVING, UNION, EXCEPT, INTERSECT, FUNCTION, STARTPARENTHESIS, TYPE} 140 | TokenTypesOfJoinMaker = []TokenType{JOIN, INNER, OUTER, LEFT, RIGHT, NATURAL, CROSS} 141 | TokenTypeOfTieClause = []TokenType{UNION, INTERSECT, EXCEPT} 142 | TokenTypeOfLimitClause = []TokenType{LIMIT, FETCH, OFFSET} 143 | ) 144 | 145 | // IsJoinStart determines if ttype is included in TokenTypesOfJoinMaker 146 | func (t Token) IsJoinStart() bool { 147 | for _, v := range TokenTypesOfJoinMaker { 148 | if t.Type == v { 149 | return true 150 | } 151 | } 152 | return false 153 | } 154 | 155 | // IsTieClauseStart determines if ttype is included in TokenTypeOfTieClause 156 | func (t Token) IsTieClauseStart() bool { 157 | for _, v := range TokenTypeOfTieClause { 158 | if t.Type == v { 159 | return true 160 | } 161 | } 162 | return false 163 | } 164 | 165 | // IsLimitClauseStart determines if ttype is included in TokenTypeOfLimitClause 166 | func (t Token) IsLimitClauseStart() bool { 167 | for _, v := range TokenTypeOfLimitClause { 168 | if t.Type == v { 169 | return true 170 | } 171 | } 172 | return false 173 | } 174 | 175 | // IsNeedNewLineBefore returns true if the token needs a new line before it is written to the buffer 176 | func (t Token) IsNeedNewLineBefore() bool { 177 | var ttypes = []TokenType{SELECT, UPDATE, INSERT, DELETE, ANDGROUP, FROM, GROUP, ORGROUP, ORDER, HAVING, LIMIT, OFFSET, FETCH, RETURNING, SET, UNION, INTERSECT, EXCEPT, VALUES, WHERE, ON, USING} 178 | for _, v := range ttypes { 179 | if t.Type == v { 180 | return true 181 | } 182 | } 183 | return false 184 | } 185 | 186 | // IsKeyWordInSelect returns true if token is a keyword in a select group 187 | func (t Token) IsKeyWordInSelect() bool { 188 | return t.Type == SELECT || t.Type == EXISTS || t.Type == DISTINCT || t.Type == DISTINCTROW || t.Type == INTO || t.Type == AS || t.Type == GROUP || t.Type == ORDER || t.Type == BY || t.Type == ON || t.Type == RETURNING || t.Type == SET || t.Type == UPDATE 189 | } 190 | -------------------------------------------------------------------------------- /sqlfmt/lexer/tokenizer.go: -------------------------------------------------------------------------------- 1 | package lexer 2 | 3 | import ( 4 | "bufio" 5 | "bytes" 6 | "strings" 7 | 8 | "github.com/pkg/errors" 9 | ) 10 | 11 | // Tokenizer tokenizes SQL statements 12 | type Tokenizer struct { 13 | r *bufio.Reader 14 | w *bytes.Buffer // w buffers the current token value.
It is reset when the end of a token is reached 15 | result []Token 16 | } 17 | 18 | // eof is a rune that cannot appear in a SQL statement 19 | // TODO: find a better way to represent eof instead of using '∂' 20 | var eof = '∂' 21 | 22 | // value of literal 23 | const ( 24 | Comma = "," 25 | StartParenthesis = "(" 26 | EndParenthesis = ")" 27 | StartBracket = "[" 28 | EndBracket = "]" 29 | StartBrace = "{" 30 | EndBrace = "}" 31 | SingleQuote = "'" 32 | NewLine = "\n" 33 | ) 34 | 35 | // NewTokenizer creates Tokenizer 36 | func NewTokenizer(src string) *Tokenizer { 37 | return &Tokenizer{ 38 | r: bufio.NewReader(strings.NewReader(src)), 39 | w: &bytes.Buffer{}, 40 | } 41 | } 42 | 43 | // GetTokens returns tokens for parsing 44 | func (t *Tokenizer) GetTokens() ([]Token, error) { 45 | var result []Token 46 | 47 | tokens, err := t.Tokenize() 48 | if err != nil { 49 | return nil, errors.Wrap(err, "Tokenize failed") 50 | } 51 | // copy tokens, dropping whitespace and new line tokens 52 | // if "AND" or "OR" appears right after a new line, its type becomes ANDGROUP or ORGROUP 53 | for i, tok := range tokens { 54 | if i > 0 && tok.Type == AND && tokens[i-1].Type == NEWLINE { 55 | andGroupToken := Token{Type: ANDGROUP, Value: tok.Value} 56 | result = append(result, andGroupToken) 57 | continue 58 | } 59 | if i > 0 && tok.Type == OR && tokens[i-1].Type == NEWLINE { 60 | orGroupToken := Token{Type: ORGROUP, Value: tok.Value} 61 | result = append(result, orGroupToken) 62 | continue 63 | } 64 | if tok.Type == WS || tok.Type == NEWLINE { 65 | continue 66 | } 67 | result = append(result, tok) 68 | } 69 | return result, nil 70 | } 71 | 72 | // Tokenize analyses every rune in a SQL statement 73 | // every token is identified when whitespace appears 74 | func (t *Tokenizer) Tokenize() ([]Token, error) { 75 | for { 76 | isEOF, err := t.scan() 77 | 78 | if err != nil { 79 | return nil, err 80 | } 81 | if isEOF { 82 | break 83 | } 84 | } 85 | return t.result, nil 86 | } 87 | 88 | // unread undoes the last ReadRune call to get the previous character back 89 | func (t *Tokenizer) unread() { t.r.UnreadRune() } 90 | 91 | func isWhiteSpace(ch rune) bool { 92 | return ch == ' ' || ch == '\t' || ch == '\n' || ch == '　' 93 | } 94 | 95 | func isComma(ch rune) bool { 96 | return ch == ',' 97 | } 98 | 99 | func isStartParenthesis(ch rune) bool { 100 | return ch == '(' 101 | } 102 | 103 | func isEndParenthesis(ch rune) bool { 104 | return ch == ')' 105 | } 106 | 107 | func isSingleQuote(ch rune) bool { 108 | return ch == '\'' 109 | } 110 | 111 | func isStartBracket(ch rune) bool { 112 | return ch == '[' 113 | } 114 | 115 | func isEndBracket(ch rune) bool { 116 | return ch == ']' 117 | } 118 | 119 | func isStartBrace(ch rune) bool { 120 | return ch == '{' 121 | } 122 | 123 | func isEndBrace(ch rune) bool { 124 | return ch == '}' 125 | } 126 | 127 | // scan scans each character and appends to result until "eof" appears 128 | // when it finishes scanning all characters, it returns true 129 | func (t *Tokenizer) scan() (bool, error) { 130 | ch, _, err := t.r.ReadRune() 131 | if err != nil { 132 | if err.Error() == "EOF" { 133 | ch = eof 134 | } else { 135 | return false, errors.Wrap(err, "read rune failed") 136 | } 137 | } 138 | 139 | switch { 140 | case ch == eof: 141 | tok := Token{Type: EOF, Value: "EOF"} 142 | t.result = append(t.result, tok) 143 | return true, nil 144 | case isWhiteSpace(ch): 145 | if err := t.scanWhiteSpace(); err != nil { 146 | return false, err 147 | } 148 | return false, nil 149 | // extract string 150 | case 
isSingleQuote(ch): 151 | if err := t.scanString(); err != nil { 152 | return false, err 153 | } 154 | return false, nil 155 | case isComma(ch): 156 | token := Token{Type: COMMA, Value: Comma} 157 | t.result = append(t.result, token) 158 | return false, nil 159 | case isStartParenthesis(ch): 160 | token := Token{Type: STARTPARENTHESIS, Value: StartParenthesis} 161 | t.result = append(t.result, token) 162 | return false, nil 163 | case isEndParenthesis(ch): 164 | token := Token{Type: ENDPARENTHESIS, Value: EndParenthesis} 165 | t.result = append(t.result, token) 166 | return false, nil 167 | case isStartBracket(ch): 168 | token := Token{Type: STARTBRACKET, Value: StartBracket} 169 | t.result = append(t.result, token) 170 | return false, nil 171 | case isEndBracket(ch): 172 | token := Token{Type: ENDBRACKET, Value: EndBracket} 173 | t.result = append(t.result, token) 174 | return false, nil 175 | case isStartBrace(ch): 176 | token := Token{Type: STARTBRACE, Value: StartBrace} 177 | t.result = append(t.result, token) 178 | return false, nil 179 | case isEndBrace(ch): 180 | token := Token{Type: ENDBRACE, Value: EndBrace} 181 | t.result = append(t.result, token) 182 | return false, nil 183 | default: 184 | if err := t.scanIdent(); err != nil { 185 | return false, err 186 | } 187 | return false, nil 188 | } 189 | } 190 | 191 | func (t *Tokenizer) scanWhiteSpace() error { 192 | t.unread() 193 | 194 | for { 195 | ch, _, err := t.r.ReadRune() 196 | if err != nil { 197 | if err.Error() == "EOF" { 198 | break 199 | } else { 200 | return err 201 | } 202 | } 203 | if !isWhiteSpace(ch) { 204 | t.unread() 205 | break 206 | } else { 207 | t.w.WriteRune(ch) 208 | } 209 | } 210 | 211 | if strings.Contains(t.w.String(), "\n") { 212 | tok := Token{Type: NEWLINE, Value: "\n"} 213 | t.result = append(t.result, tok) 214 | } else { 215 | tok := Token{Type: WS, Value: t.w.String()} 216 | t.result = append(t.result, tok) 217 | } 218 | t.w.Reset() 219 | return nil 220 | } 221 | 222 | // scan string token including single quotes 223 | func (t *Tokenizer) scanString() error { 224 | var counter int 225 | t.unread() 226 | 227 | for { 228 | ch, _, err := t.r.ReadRune() 229 | if err != nil { 230 | if err.Error() == "EOF" { 231 | break 232 | } else { 233 | return err 234 | } 235 | } 236 | // ignore the first single quote 237 | if counter != 0 && isSingleQuote(ch) { 238 | t.w.WriteRune(ch) 239 | break 240 | } else { 241 | t.w.WriteRune(ch) 242 | } 243 | counter++ 244 | } 245 | tok := Token{Type: STRING, Value: t.w.String()} 246 | t.result = append(t.result, tok) 247 | t.w.Reset() 248 | return nil 249 | } 250 | 251 | // append all ch to result until ch is a white space 252 | // if ident is keyword, Type will be the keyword and value will be the uppercase keyword 253 | func (t *Tokenizer) scanIdent() error { 254 | t.unread() 255 | 256 | for { 257 | ch, _, err := t.r.ReadRune() 258 | if err != nil { 259 | if err.Error() == "EOF" { 260 | break 261 | } else { 262 | return err 263 | } 264 | } 265 | if isWhiteSpace(ch) { 266 | t.unread() 267 | break 268 | } else if isComma(ch) { 269 | t.unread() 270 | break 271 | } else if isStartParenthesis(ch) { 272 | t.unread() 273 | break 274 | } else if isEndParenthesis(ch) { 275 | t.unread() 276 | break 277 | } else if isSingleQuote(ch) { 278 | t.unread() 279 | break 280 | } else if isStartBracket(ch) { 281 | t.unread() 282 | break 283 | } else if isEndBracket(ch) { 284 | t.unread() 285 | break 286 | } else if isStartBrace(ch) { 287 | t.unread() 288 | break 289 | } else if isEndBrace(ch) { 
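// any delimiter ends the current identifier; unread it so the
// next call to scan emits it as its own token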
290 | t.unread() 291 | break 292 | } else { 293 | t.w.WriteRune(ch) 294 | } 295 | } 296 | t.append(t.w.String()) 297 | return nil 298 | } 299 | 300 | func (t *Tokenizer) append(v string) { 301 | upperValue := strings.ToUpper(v) 302 | 303 | if ttype, ok := t.isSQLKeyWord(upperValue); ok { 304 | t.result = append(t.result, Token{ 305 | Type: ttype, 306 | Value: upperValue, 307 | }) 308 | } else { 309 | t.result = append(t.result, Token{ 310 | Type: ttype, 311 | Value: v, 312 | }) 313 | } 314 | t.w.Reset() 315 | } 316 | 317 | func (t *Tokenizer) isSQLKeyWord(v string) (TokenType, bool) { 318 | if ttype, ok := sqlKeywordMap[v]; ok { 319 | return ttype, ok 320 | } else if ttype, ok := typeWithParenMap[v]; ok { 321 | if r, _, err := t.r.ReadRune(); err == nil && string(r) == StartParenthesis { 322 | t.unread() 323 | return ttype, ok 324 | } 325 | t.unread() 326 | return IDENT, ok 327 | } 328 | return IDENT, false 329 | } 330 | 331 | var sqlKeywordMap = map[string]TokenType{ 332 | "SELECT": SELECT, 333 | "FROM": FROM, 334 | "WHERE": WHERE, 335 | "CASE": CASE, 336 | "ORDER": ORDER, 337 | "BY": BY, 338 | "AS": AS, 339 | "JOIN": JOIN, 340 | "LEFT": LEFT, 341 | "RIGHT": RIGHT, 342 | "INNER": INNER, 343 | "OUTER": OUTER, 344 | "ON": ON, 345 | "WHEN": WHEN, 346 | "END": END, 347 | "GROUP": GROUP, 348 | "DESC": DESC, 349 | "ASC": ASC, 350 | "LIMIT": LIMIT, 351 | "AND": AND, 352 | "OR": OR, 353 | "IN": IN, 354 | "IS": IS, 355 | "NOT": NOT, 356 | "NULL": NULL, 357 | "DISTINCT": DISTINCT, 358 | "LIKE": LIKE, 359 | "BETWEEN": BETWEEN, 360 | "UNION": UNION, 361 | "ALL": ALL, 362 | "HAVING": HAVING, 363 | "EXISTS": EXISTS, 364 | "UPDATE": UPDATE, 365 | "SET": SET, 366 | "RETURNING": RETURNING, 367 | "DELETE": DELETE, 368 | "INSERT": INSERT, 369 | "INTO": INTO, 370 | "DO": DO, 371 | "VALUES": VALUES, 372 | "FOR": FOR, 373 | "THEN": THEN, 374 | "ELSE": ELSE, 375 | "DISTINCTROW": DISTINCTROW, 376 | "FILTER": FILTER, 377 | "WITHIN": WITHIN, 378 | "COLLATE": COLLATE, 379 | "INTERSECT": INTERSECT, 380 | "EXCEPT": EXCEPT, 381 | "OFFSET": OFFSET, 382 | "FETCH": FETCH, 383 | "FIRST": FIRST, 384 | "ROWS": ROWS, 385 | "USING": USING, 386 | "OVERLAPS": OVERLAPS, 387 | "NATURAL": NATURAL, 388 | "CROSS": CROSS, 389 | "ZONE": ZONE, 390 | "NULLS": NULLS, 391 | "LAST": LAST, 392 | "AT": AT, 393 | "LOCK": LOCK, 394 | "WITH": WITH, 395 | } 396 | 397 | var typeWithParenMap = map[string]TokenType{ 398 | "SUM": FUNCTION, 399 | "AVG": FUNCTION, 400 | "MAX": FUNCTION, 401 | "MIN": FUNCTION, 402 | "COUNT": FUNCTION, 403 | "COALESCE": FUNCTION, 404 | "EXTRACT": FUNCTION, 405 | "OVERLAY": FUNCTION, 406 | "POSITION": FUNCTION, 407 | "CAST": FUNCTION, 408 | "SUBSTRING": FUNCTION, 409 | "TRIM": FUNCTION, 410 | "XMLELEMENT": FUNCTION, 411 | "XMLFOREST": FUNCTION, 412 | "XMLCONCAT": FUNCTION, 413 | "RANDOM": FUNCTION, 414 | "DATE_PART": FUNCTION, 415 | "DATE_TRUNC": FUNCTION, 416 | "ARRAY_AGG": FUNCTION, 417 | "PERCENTILE_DISC": FUNCTION, 418 | "GREATEST": FUNCTION, 419 | "LEAST": FUNCTION, 420 | "OVER": FUNCTION, 421 | "ROW_NUMBER": FUNCTION, 422 | "BIG": TYPE, 423 | "BIGSERIAL": TYPE, 424 | "BOOLEAN": TYPE, 425 | "CHAR": TYPE, 426 | "BIT": TYPE, 427 | "TEXT": TYPE, 428 | "INTEGER": TYPE, 429 | "NUMERIC": TYPE, 430 | "DECIMAL": TYPE, 431 | "DEC": TYPE, 432 | "FLOAT": TYPE, 433 | "CUSTOMTYPE": TYPE, 434 | "VARCHAR": TYPE, 435 | "VARBIT": TYPE, 436 | "TIMESTAMP": TYPE, 437 | "TIME": TYPE, 438 | "SECOND": TYPE, 439 | "INTERVAL": TYPE, 440 | } 441 | -------------------------------------------------------------------------------- 
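A small, self-contained sketch of the lookahead behavior above: words in typeWithParenMap only become FUNCTION or TYPE tokens when the next rune is an opening parenthesis; otherwise they fall back to IDENT, with the value still uppercased because isSQLKeyWord reports ok for them. This mirrors the expectations in the test file below and uses only the exported tokenizer API:

```go
package main

import (
	"fmt"

	"github.com/kanmu/go-sqlfmt/sqlfmt/lexer"
)

func main() {
	// the bare "sum" is tokenized as IDENT("SUM"), while "sum(" becomes
	// FUNCTION("SUM") because the tokenizer peeks one rune ahead
	tokens, err := lexer.NewTokenizer("select sum, sum(xxx) from xxx").GetTokens()
	if err != nil {
		panic(err)
	}
	for _, tok := range tokens {
		fmt.Printf("%d %q\n", tok.Type, tok.Value)
	}
}
```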
/sqlfmt/lexer/tokenizer_test.go: -------------------------------------------------------------------------------- 1 | package lexer 2 | 3 | import ( 4 | "reflect" 5 | "strings" 6 | "testing" 7 | ) 8 | 9 | func TestGetTokens(t *testing.T) { 10 | var testingSQLStatement = strings.Trim(`select name, age,sum, sum(case xxx) from user where name xxx and age = 'xxx' limit 100 except 100`, "`") 11 | want := []Token{ 12 | {Type: SELECT, Value: "SELECT"}, 13 | {Type: IDENT, Value: "name"}, 14 | {Type: COMMA, Value: ","}, 15 | {Type: IDENT, Value: "age"}, 16 | {Type: COMMA, Value: ","}, 17 | {Type: IDENT, Value: "SUM"}, 18 | {Type: COMMA, Value: ","}, 19 | {Type: FUNCTION, Value: "SUM"}, 20 | {Type: STARTPARENTHESIS, Value: "("}, 21 | {Type: CASE, Value: "CASE"}, 22 | {Type: IDENT, Value: "xxx"}, 23 | {Type: ENDPARENTHESIS, Value: ")"}, 24 | 25 | {Type: FROM, Value: "FROM"}, 26 | {Type: IDENT, Value: "user"}, 27 | {Type: WHERE, Value: "WHERE"}, 28 | {Type: IDENT, Value: "name"}, 29 | {Type: IDENT, Value: "xxx"}, 30 | {Type: AND, Value: "AND"}, 31 | {Type: IDENT, Value: "age"}, 32 | {Type: IDENT, Value: "="}, 33 | {Type: STRING, Value: "'xxx'"}, 34 | {Type: LIMIT, Value: "LIMIT"}, 35 | {Type: IDENT, Value: "100"}, 36 | {Type: EXCEPT, Value: "EXCEPT"}, 37 | {Type: IDENT, Value: "100"}, 38 | 39 | {Type: EOF, Value: "EOF"}, 40 | } 41 | tnz := NewTokenizer(testingSQLStatement) 42 | got, err := tnz.GetTokens() 43 | if err != nil { 44 | t.Fatalf("\nERROR: %#v", err) 45 | } else if !reflect.DeepEqual(want, got) { 46 | t.Errorf("\nwant %#v, \ngot %#v", want, got) 47 | } 48 | } 49 | 50 | func TestIsWhiteSpace(t *testing.T) { 51 | tests := []struct { 52 | name string 53 | src rune 54 | want bool 55 | }{ 56 | { 57 | name: "normal test case 1", 58 | src: '\n', 59 | want: true, 60 | }, 61 | { 62 | name: "normal test case 2", 63 | src: '\t', 64 | want: true, 65 | }, 66 | { 67 | name: "normal test case 3", 68 | src: ' ', 69 | want: true, 70 | }, 71 | { 72 | name: "abnormal case", 73 | src: 'a', 74 | want: false, 75 | }, 76 | } 77 | for _, tt := range tests { 78 | t.Run(tt.name, func(t *testing.T) { 79 | if got := isWhiteSpace(tt.src); got != tt.want { 80 | t.Errorf("\nwant %v, \ngot %v", tt.want, got) 81 | } 82 | }) 83 | } 84 | } 85 | 86 | func TestScan(t *testing.T) { 87 | tests := []struct { 88 | name string 89 | src string 90 | want bool 91 | }{ 92 | { 93 | name: "normal test case 1", 94 | src: `select`, 95 | want: false, 96 | }, 97 | { 98 | name: "normal test case 2", 99 | src: `table`, 100 | want: false, 101 | }, 102 | { 103 | name: "normal test case 3", 104 | src: ` `, 105 | want: false, 106 | }, 107 | } 108 | for _, tt := range tests { 109 | t.Run(tt.name, func(t *testing.T) { 110 | tnz := NewTokenizer(tt.src) 111 | 112 | got, err := tnz.scan() 113 | if err != nil { 114 | t.Errorf("\nERROR: %#v", err) 115 | } 116 | if got != tt.want { 117 | t.Errorf("\nwant %v, \ngot %v", tt.want, got) 118 | } 119 | }) 120 | } 121 | } 122 | 123 | func TestScanWhiteSpace(t *testing.T) { 124 | tests := []struct { 125 | name string 126 | src string 127 | want Token 128 | }{ 129 | { 130 | name: "normal test case 1", 131 | src: ` `, 132 | want: Token{Type: WS, Value: " "}, 133 | }, 134 | { 135 | name: "normal test case 2", 136 | src: "\n", 137 | want: Token{Type: NEWLINE, Value: "\n"}, 138 | }, 139 | } 140 | for _, tt := range tests { 141 | t.Run(tt.name, func(t *testing.T) { 142 | tnz := NewTokenizer(tt.src) 143 | tnz.scanWhiteSpace() 144 | 145 | if got := tnz.result[0]; got != tt.want { 146 | t.Errorf("\nwant %v, \ngot %v", 
tt.want, got) 147 | } 148 | }) 149 | } 150 | } 151 | 152 | func TestScanIdent(t *testing.T) { 153 | tests := []struct { 154 | name string 155 | src string 156 | want Token 157 | }{ 158 | { 159 | name: "normal test case 1", 160 | src: `select`, 161 | want: Token{Type: SELECT, Value: "SELECT"}, 162 | }, 163 | { 164 | name: "normal test case 2", 165 | src: "table", 166 | want: Token{Type: IDENT, Value: "table"}, 167 | }, 168 | } 169 | for _, tt := range tests { 170 | t.Run(tt.name, func(t *testing.T) { 171 | tnz := NewTokenizer(tt.src) 172 | tnz.scanIdent() 173 | 174 | if got := tnz.result[0]; got != tt.want { 175 | t.Errorf("\nwant %v, \ngot %v", tt.want, got) 176 | } 177 | }) 178 | } 179 | } 180 | -------------------------------------------------------------------------------- /sqlfmt/parser/group/and_group.go: -------------------------------------------------------------------------------- 1 | package group 2 | 3 | import ( 4 | "bytes" 5 | 6 | "github.com/kanmu/go-sqlfmt/sqlfmt/lexer" 7 | ) 8 | 9 | // AndGroup is an AND clause, not an AND operator 10 | // An AndGroup is created when AND appears after a new line 11 | //// select xxx and xxx <= this is not an AndGroup 12 | //// select xxx from xxx where xxx 13 | //// and xxx <= this is an AndGroup 14 | type AndGroup struct { 15 | Element []Reindenter 16 | IndentLevel int 17 | } 18 | 19 | // Reindent reindents its elements 20 | func (a *AndGroup) Reindent(buf *bytes.Buffer) error { 21 | elements, err := processPunctuation(a.Element) 22 | if err != nil { 23 | return err 24 | } 25 | 26 | for _, el := range elements { 27 | if token, ok := el.(lexer.Token); ok { 28 | write(buf, token, a.IndentLevel) 29 | } else if err := el.Reindent(buf); err != nil { 30 | return err 31 | } 32 | } 33 | return nil 34 | } 35 | 36 | // IncrementIndentLevel increments by its specified indent level 37 | func (a *AndGroup) IncrementIndentLevel(lev int) { 38 | a.IndentLevel += lev 39 | } 40 | -------------------------------------------------------------------------------- /sqlfmt/parser/group/and_group_test.go: -------------------------------------------------------------------------------- 1 | package group 2 | 3 | import ( 4 | "bytes" 5 | "testing" 6 | 7 | "github.com/kanmu/go-sqlfmt/sqlfmt/lexer" 8 | ) 9 | 10 | func TestReindentAndGroup(t *testing.T) { 11 | tests := []struct { 12 | name string 13 | tokenSource []Reindenter 14 | want string 15 | }{ 16 | { 17 | name: "normal test", 18 | tokenSource: []Reindenter{ 19 | lexer.Token{Type: lexer.ANDGROUP, Value: "AND"}, 20 | lexer.Token{Type: lexer.IDENT, Value: "something1"}, 21 | lexer.Token{Type: lexer.IDENT, Value: "something2"}, 22 | }, 23 | want: "\nAND something1 something2", 24 | }, 25 | } 26 | for _, tt := range tests { 27 | buf := &bytes.Buffer{} 28 | andGroup := &AndGroup{Element: tt.tokenSource} 29 | 30 | if err := andGroup.Reindent(buf); err != nil { 31 | t.Errorf("error %#v", err) 32 | } 33 | got := buf.String() 34 | if tt.want != got { 35 | t.Errorf("want %#v, got %#v", tt.want, got) 36 | } 37 | } 38 | } 39 | -------------------------------------------------------------------------------- /sqlfmt/parser/group/case_group.go: -------------------------------------------------------------------------------- 1 | package group 2 | 3 | import ( 4 | "bytes" 5 | 6 | "github.com/kanmu/go-sqlfmt/sqlfmt/lexer" 7 | ) 8 | 9 | // Case is a CASE clause 10 | type Case struct { 11 | Element []Reindenter 12 | IndentLevel int 13 | hasCommaBefore bool 14 | } 15 | 16 | // Reindent reindents its elements 17 | func (c *Case) Reindent(buf *bytes.Buffer) error { 18 | elements, err := processPunctuation(c.Element)
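// processPunctuation is defined in util.go, which is not part of this dump;
// judging by its use at the top of every group's Reindent, it appears to
// normalize comma and parenthesis tokens before the group writes itself out.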
19 | if err != nil { 20 | return err 21 | } 22 | for _, v := range elements { 23 | if token, ok := v.(lexer.Token); ok { 24 | writeCase(buf, token, c.IndentLevel, c.hasCommaBefore) 25 | } else { 26 | v.Reindent(buf) 27 | } 28 | } 29 | return nil 30 | } 31 | 32 | // IncrementIndentLevel increments by its specified indent level 33 | func (c *Case) IncrementIndentLevel(lev int) { 34 | c.IndentLevel += lev 35 | } 36 | -------------------------------------------------------------------------------- /sqlfmt/parser/group/case_group_test.go: -------------------------------------------------------------------------------- 1 | package group 2 | 3 | import ( 4 | "bytes" 5 | "testing" 6 | 7 | "github.com/kanmu/go-sqlfmt/sqlfmt/lexer" 8 | ) 9 | 10 | func TestReindentCaseGroup(t *testing.T) { 11 | tests := []struct { 12 | name string 13 | tokenSource []Reindenter 14 | want string 15 | }{ 16 | { 17 | name: "normal case", 18 | tokenSource: []Reindenter{ 19 | lexer.Token{Type: lexer.CASE, Value: "CASE"}, 20 | lexer.Token{Type: lexer.WHEN, Value: "WHEN"}, 21 | lexer.Token{Type: lexer.IDENT, Value: "something"}, 22 | lexer.Token{Type: lexer.IDENT, Value: "something"}, 23 | lexer.Token{Type: lexer.END, Value: "END"}, 24 | }, 25 | want: "\n CASE\n WHEN something something\n END", 26 | }, 27 | } 28 | for _, tt := range tests { 29 | buf := &bytes.Buffer{} 30 | caseGroup := &Case{Element: tt.tokenSource} 31 | 32 | caseGroup.Reindent(buf) 33 | got := buf.String() 34 | if tt.want != got { 35 | t.Errorf("want%#v, got %#v", tt.want, got) 36 | } 37 | } 38 | } 39 | -------------------------------------------------------------------------------- /sqlfmt/parser/group/delete_group.go: -------------------------------------------------------------------------------- 1 | package group 2 | 3 | import ( 4 | "bytes" 5 | 6 | "github.com/kanmu/go-sqlfmt/sqlfmt/lexer" 7 | ) 8 | 9 | // Delete clause 10 | type Delete struct { 11 | Element []Reindenter 12 | IndentLevel int 13 | } 14 | 15 | // Reindent reindents its elements 16 | func (d *Delete) Reindent(buf *bytes.Buffer) error { 17 | elements, err := processPunctuation(d.Element) 18 | if err != nil { 19 | return err 20 | } 21 | for _, el := range elements { 22 | if token, ok := el.(lexer.Token); ok { 23 | write(buf, token, d.IndentLevel) 24 | } else { 25 | el.Reindent(buf) 26 | } 27 | } 28 | return nil 29 | } 30 | 31 | // IncrementIndentLevel increments by its specified indent level 32 | func (d *Delete) IncrementIndentLevel(lev int) { 33 | d.IndentLevel += lev 34 | } 35 | -------------------------------------------------------------------------------- /sqlfmt/parser/group/delete_group_test.go: -------------------------------------------------------------------------------- 1 | package group 2 | 3 | import ( 4 | "bytes" 5 | "testing" 6 | 7 | "github.com/kanmu/go-sqlfmt/sqlfmt/lexer" 8 | ) 9 | 10 | func TestReindentDeleteGroup(t *testing.T) { 11 | tests := []struct { 12 | name string 13 | tokenSource []Reindenter 14 | want string 15 | }{ 16 | { 17 | name: "normal case", 18 | tokenSource: []Reindenter{ 19 | lexer.Token{Type: lexer.DELETE, Value: "DELETE"}, 20 | lexer.Token{Type: lexer.FROM, Value: "FROM"}, 21 | lexer.Token{Type: lexer.IDENT, Value: "xxxxxx"}, 22 | }, 23 | want: "\nDELETE\nFROM xxxxxx", 24 | }, 25 | } 26 | for _, tt := range tests { 27 | buf := &bytes.Buffer{} 28 | deleteGroup := &Delete{Element: tt.tokenSource} 29 | 30 | deleteGroup.Reindent(buf) 31 | got := buf.String() 32 | if tt.want != got { 33 | t.Errorf("want%#v, got %#v", tt.want, got) 34 | } 35 | } 36 | }
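// Usage sketch (illustrative only, not part of the original test suite):
// building a Delete group by hand and reindenting it mirrors what the
// formatter does for a `DELETE FROM xxxxxx` statement; the expected output
// matches the test case above.
//
//	buf := &bytes.Buffer{}
//	del := &Delete{Element: []Reindenter{
//		lexer.Token{Type: lexer.DELETE, Value: "DELETE"},
//		lexer.Token{Type: lexer.FROM, Value: "FROM"},
//		lexer.Token{Type: lexer.IDENT, Value: "xxxxxx"},
//	}}
//	_ = del.Reindent(buf) // buf.String() == "\nDELETE\nFROM xxxxxx"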
37 | -------------------------------------------------------------------------------- /sqlfmt/parser/group/from_group.go: -------------------------------------------------------------------------------- 1 | package group 2 | 3 | import ( 4 | "bytes" 5 | 6 | "github.com/kanmu/go-sqlfmt/sqlfmt/lexer" 7 | ) 8 | 9 | // From clause 10 | type From struct { 11 | Element []Reindenter 12 | IndentLevel int 13 | } 14 | 15 | // Reindent reindents its elements 16 | func (f *From) Reindent(buf *bytes.Buffer) error { 17 | elements, err := processPunctuation(f.Element) 18 | if err != nil { 19 | return err 20 | } 21 | for _, el := range elements { 22 | if token, ok := el.(lexer.Token); ok { 23 | write(buf, token, f.IndentLevel) 24 | } else { 25 | el.Reindent(buf) 26 | } 27 | } 28 | return nil 29 | } 30 | 31 | // IncrementIndentLevel increments by its specified indent level 32 | func (f *From) IncrementIndentLevel(lev int) { 33 | f.IndentLevel += lev 34 | } 35 | -------------------------------------------------------------------------------- /sqlfmt/parser/group/from_group_test.go: -------------------------------------------------------------------------------- 1 | package group 2 | 3 | import ( 4 | "bytes" 5 | "testing" 6 | 7 | "github.com/kanmu/go-sqlfmt/sqlfmt/lexer" 8 | ) 9 | 10 | func TestReindentFromGroup(t *testing.T) { 11 | tests := []struct { 12 | name string 13 | tokenSource []Reindenter 14 | want string 15 | }{ 16 | { 17 | name: "normal case", 18 | tokenSource: []Reindenter{ 19 | lexer.Token{Type: lexer.FROM, Value: "FROM"}, 20 | lexer.Token{Type: lexer.IDENT, Value: "sometable"}, 21 | }, 22 | want: "\nFROM sometable", 23 | }, 24 | } 25 | for _, tt := range tests { 26 | buf := &bytes.Buffer{} 27 | fromGroup := &From{Element: tt.tokenSource} 28 | 29 | fromGroup.Reindent(buf) 30 | got := buf.String() 31 | if tt.want != got { 32 | t.Errorf("want%#v, got %#v", tt.want, got) 33 | } 34 | } 35 | } 36 | -------------------------------------------------------------------------------- /sqlfmt/parser/group/function_group.go: -------------------------------------------------------------------------------- 1 | package group 2 | 3 | import ( 4 | "bytes" 5 | 6 | "github.com/kanmu/go-sqlfmt/sqlfmt/lexer" 7 | ) 8 | 9 | // Function clause 10 | type Function struct { 11 | Element []Reindenter 12 | IndentLevel int 13 | InColumnArea bool 14 | ColumnCount int 15 | } 16 | 17 | // Reindent reindents its elements 18 | func (f *Function) Reindent(buf *bytes.Buffer) error { 19 | elements, err := processPunctuation(f.Element) 20 | if err != nil { 21 | return err 22 | } 23 | 24 | for i, el := range elements { 25 | if token, ok := el.(lexer.Token); ok { 26 | var prev lexer.Token 27 | 28 | if i > 0 { 29 | if preToken, ok := elements[i-1].(lexer.Token); ok { 30 | prev = preToken 31 | } 32 | } 33 | writeFunction(buf, token, prev, f.IndentLevel, f.ColumnCount, f.InColumnArea) 34 | } else { 35 | el.Reindent(buf) 36 | } 37 | } 38 | return nil 39 | } 40 | 41 | // IncrementIndentLevel increments by its specified indent level 42 | func (f *Function) IncrementIndentLevel(lev int) { 43 | f.IndentLevel += lev 44 | } 45 | -------------------------------------------------------------------------------- /sqlfmt/parser/group/function_group_test.go: -------------------------------------------------------------------------------- 1 | package group 2 | 3 | import ( 4 | "bytes" 5 | "testing" 6 | 7 | "github.com/kanmu/go-sqlfmt/sqlfmt/lexer" 8 | ) 9 | 10 | func TestReindentFunctionGroup(t *testing.T) { 11 | tests := []struct { 12 | name string 13 |
tokenSource []Reindenter 14 | want string 15 | }{ 16 | { 17 | name: "normal case", 18 | tokenSource: []Reindenter{ 19 | lexer.Token{Type: lexer.FUNCTION, Value: "SUM"}, 20 | lexer.Token{Type: lexer.STARTPARENTHESIS, Value: "("}, 21 | lexer.Token{Type: lexer.IDENT, Value: "xxx"}, 22 | lexer.Token{Type: lexer.ENDPARENTHESIS, Value: ")"}, 23 | }, 24 | want: " SUM(xxx)", 25 | }, 26 | } 27 | for _, tt := range tests { 28 | buf := &bytes.Buffer{} 29 | functionGroup := &Function{Element: tt.tokenSource} 30 | 31 | functionGroup.Reindent(buf) 32 | got := buf.String() 33 | if tt.want != got { 34 | t.Errorf("want%#v, got %#v", tt.want, got) 35 | } 36 | } 37 | } 38 | -------------------------------------------------------------------------------- /sqlfmt/parser/group/group_by_group.go: -------------------------------------------------------------------------------- 1 | package group 2 | 3 | import ( 4 | "bytes" 5 | 6 | "github.com/kanmu/go-sqlfmt/sqlfmt/lexer" 7 | ) 8 | 9 | // GroupBy clause 10 | type GroupBy struct { 11 | Element []Reindenter 12 | IndentLevel int 13 | } 14 | 15 | // Reindent reindents its elements 16 | func (g *GroupBy) Reindent(buf *bytes.Buffer) error { 17 | columnCount = 0 18 | 19 | elements, err := processPunctuation(g.Element) 20 | if err != nil { 21 | return err 22 | } 23 | 24 | for _, el := range separate(elements) { 25 | switch v := el.(type) { 26 | case lexer.Token, string: 27 | if err := writeWithComma(buf, v, g.IndentLevel); err != nil { 28 | return err 29 | } 30 | case Reindenter: 31 | v.Reindent(buf) 32 | } 33 | } 34 | return nil 35 | } 36 | 37 | // IncrementIndentLevel increments by its specified indent level 38 | func (g *GroupBy) IncrementIndentLevel(lev int) { 39 | g.IndentLevel += lev 40 | } 41 | -------------------------------------------------------------------------------- /sqlfmt/parser/group/group_by_group_test.go: -------------------------------------------------------------------------------- 1 | package group 2 | 3 | import ( 4 | "bytes" 5 | "testing" 6 | 7 | "github.com/kanmu/go-sqlfmt/sqlfmt/lexer" 8 | ) 9 | 10 | func TestReindentGroupByGroup(t *testing.T) { 11 | tests := []struct { 12 | name string 13 | tokenSource []Reindenter 14 | want string 15 | }{ 16 | { 17 | name: "normal case", 18 | tokenSource: []Reindenter{ 19 | lexer.Token{Type: lexer.GROUP, Value: "GROUP"}, 20 | lexer.Token{Type: lexer.BY, Value: "BY"}, 21 | lexer.Token{Type: lexer.IDENT, Value: "xxxxxx"}, 22 | }, 23 | want: "\nGROUP BY\n xxxxxx", 24 | }, 25 | } 26 | for _, tt := range tests { 27 | buf := &bytes.Buffer{} 28 | groupByGroup := &GroupBy{Element: tt.tokenSource} 29 | 30 | groupByGroup.Reindent(buf) 31 | got := buf.String() 32 | if tt.want != got { 33 | t.Errorf("want%#v, got %#v", tt.want, got) 34 | } 35 | } 36 | } 37 | -------------------------------------------------------------------------------- /sqlfmt/parser/group/having_group.go: -------------------------------------------------------------------------------- 1 | package group 2 | 3 | import ( 4 | "bytes" 5 | 6 | "github.com/kanmu/go-sqlfmt/sqlfmt/lexer" 7 | ) 8 | 9 | // Having clause 10 | type Having struct { 11 | Element []Reindenter 12 | IndentLevel int 13 | } 14 | 15 | // Reindent reindents its elements 16 | func (h *Having) Reindent(buf *bytes.Buffer) error { 17 | elements, err := processPunctuation(h.Element) 18 | if err != nil { 19 | return err 20 | } 21 | for _, el := range elements { 22 | if token, ok := el.(lexer.Token); ok { 23 | write(buf, token, h.IndentLevel) 24 | } else { 25 | el.Reindent(buf) 26 | } 27 | } 28 
| return nil 29 | } 30 | 31 | // IncrementIndentLevel increments by its specified indent level 32 | func (h *Having) IncrementIndentLevel(lev int) { 33 | h.IndentLevel += lev 34 | } 35 | -------------------------------------------------------------------------------- /sqlfmt/parser/group/having_group_test.go: -------------------------------------------------------------------------------- 1 | package group 2 | 3 | import ( 4 | "bytes" 5 | "testing" 6 | 7 | "github.com/kanmu/go-sqlfmt/sqlfmt/lexer" 8 | ) 9 | 10 | func TestReindentHavingGroup(t *testing.T) { 11 | tests := []struct { 12 | name string 13 | tokenSource []Reindenter 14 | want string 15 | }{ 16 | { 17 | name: "normal case", 18 | tokenSource: []Reindenter{ 19 | lexer.Token{Type: lexer.HAVING, Value: "HAVING"}, 20 | lexer.Token{Type: lexer.IDENT, Value: "xxxxxxxx"}, 21 | }, 22 | want: "\nHAVING xxxxxxxx", 23 | }, 24 | } 25 | for _, tt := range tests { 26 | buf := &bytes.Buffer{} 27 | havingGroup := &Having{Element: tt.tokenSource} 28 | 29 | havingGroup.Reindent(buf) 30 | got := buf.String() 31 | if tt.want != got { 32 | t.Errorf("want%#v, got %#v", tt.want, got) 33 | } 34 | } 35 | } 36 | -------------------------------------------------------------------------------- /sqlfmt/parser/group/insert_group.go: -------------------------------------------------------------------------------- 1 | package group 2 | 3 | import ( 4 | "bytes" 5 | 6 | "github.com/kanmu/go-sqlfmt/sqlfmt/lexer" 7 | ) 8 | 9 | // Insert clause 10 | type Insert struct { 11 | Element []Reindenter 12 | IndentLevel int 13 | } 14 | 15 | // Reindent reindents its elements 16 | func (insert *Insert) Reindent(buf *bytes.Buffer) error { 17 | elements, err := processPunctuation(insert.Element) 18 | if err != nil { 19 | return err 20 | } 21 | for _, el := range elements { 22 | if token, ok := el.(lexer.Token); ok { 23 | write(buf, token, insert.IndentLevel) 24 | } else { 25 | el.Reindent(buf) 26 | } 27 | } 28 | return nil 29 | } 30 | 31 | // IncrementIndentLevel increments by its specified indent level 32 | func (insert *Insert) IncrementIndentLevel(lev int) { 33 | insert.IndentLevel += lev 34 | } 35 | -------------------------------------------------------------------------------- /sqlfmt/parser/group/insert_group_test.go: -------------------------------------------------------------------------------- 1 | package group 2 | 3 | import ( 4 | "bytes" 5 | "testing" 6 | 7 | "github.com/kanmu/go-sqlfmt/sqlfmt/lexer" 8 | ) 9 | 10 | func TestReindentInsertGroup(t *testing.T) { 11 | tests := []struct { 12 | name string 13 | tokenSource []Reindenter 14 | want string 15 | }{ 16 | { 17 | name: "normalcase", 18 | tokenSource: []Reindenter{ 19 | lexer.Token{Type: lexer.INSERT, Value: "INSERT"}, 20 | lexer.Token{Type: lexer.INTO, Value: "INTO"}, 21 | lexer.Token{Type: lexer.IDENT, Value: "xxxxxx"}, 22 | lexer.Token{Type: lexer.IDENT, Value: "xxxxxx"}, 23 | }, 24 | want: "\nINSERT INTO xxxxxx xxxxxx", 25 | }, 26 | } 27 | for _, tt := range tests { 28 | buf := &bytes.Buffer{} 29 | insertGroup := &Insert{Element: tt.tokenSource} 30 | 31 | insertGroup.Reindent(buf) 32 | got := buf.String() 33 | if tt.want != got { 34 | t.Errorf("want%#v, got %#v", tt.want, got) 35 | } 36 | } 37 | } 38 | -------------------------------------------------------------------------------- /sqlfmt/parser/group/join_group.go: -------------------------------------------------------------------------------- 1 | package group 2 | 3 | import ( 4 | "bytes" 5 | 6 | "github.com/kanmu/go-sqlfmt/sqlfmt/lexer" 7 | ) 8 | 9 | // 
Join clause 10 | type Join struct { 11 | Element []Reindenter 12 | IndentLevel int 13 | } 14 | 15 | // Reindent reindents its elements 16 | func (j *Join) Reindent(buf *bytes.Buffer) error { 17 | elements, err := processPunctuation(j.Element) 18 | if err != nil { 19 | return err 20 | } 21 | for i, v := range elements { 22 | if token, ok := v.(lexer.Token); ok { 23 | writeJoin(buf, token, j.IndentLevel, i == 0) 24 | } else { 25 | v.Reindent(buf) 26 | } 27 | } 28 | return nil 29 | } 30 | 31 | // IncrementIndentLevel increments by its specified indent level 32 | func (j *Join) IncrementIndentLevel(lev int) { 33 | j.IndentLevel += lev 34 | } 35 | -------------------------------------------------------------------------------- /sqlfmt/parser/group/join_group_test.go: -------------------------------------------------------------------------------- 1 | package group 2 | 3 | import ( 4 | "bytes" 5 | "testing" 6 | 7 | "github.com/kanmu/go-sqlfmt/sqlfmt/lexer" 8 | ) 9 | 10 | func TestReindentJoinGroup(t *testing.T) { 11 | tests := []struct { 12 | name string 13 | tokenSource []Reindenter 14 | want string 15 | }{ 16 | { 17 | name: "normalcase", 18 | tokenSource: []Reindenter{ 19 | lexer.Token{Type: lexer.LEFT, Value: "LEFT"}, 20 | lexer.Token{Type: lexer.OUTER, Value: "OUTER"}, 21 | lexer.Token{Type: lexer.JOIN, Value: "JOIN"}, 22 | lexer.Token{Type: lexer.IDENT, Value: "sometable"}, 23 | lexer.Token{Type: lexer.ON, Value: "ON"}, 24 | lexer.Token{Type: lexer.IDENT, Value: "status1"}, 25 | lexer.Token{Type: lexer.IDENT, Value: "="}, 26 | lexer.Token{Type: lexer.IDENT, Value: "status2"}, 27 | }, 28 | 29 | want: "\nLEFT OUTER JOIN sometable\nON status1 = status2", 30 | }, 31 | } 32 | for _, tt := range tests { 33 | buf := &bytes.Buffer{} 34 | joinGroup := &Join{Element: tt.tokenSource} 35 | 36 | joinGroup.Reindent(buf) 37 | got := buf.String() 38 | if tt.want != got { 39 | t.Errorf("want%#v, got %#v", tt.want, got) 40 | } 41 | } 42 | } 43 | -------------------------------------------------------------------------------- /sqlfmt/parser/group/limit_clause.go: -------------------------------------------------------------------------------- 1 | package group 2 | 3 | import ( 4 | "bytes" 5 | 6 | "github.com/kanmu/go-sqlfmt/sqlfmt/lexer" 7 | ) 8 | 9 | // LimitClause such as LIMIT, OFFSET, FETCH FIRST 10 | type LimitClause struct { 11 | Element []Reindenter 12 | IndentLevel int 13 | } 14 | 15 | // Reindent reindents its elements 16 | func (l *LimitClause) Reindent(buf *bytes.Buffer) error { 17 | elements, err := processPunctuation(l.Element) 18 | if err != nil { 19 | return err 20 | } 21 | for _, el := range elements { 22 | if token, ok := el.(lexer.Token); ok { 23 | write(buf, token, l.IndentLevel) 24 | } else { 25 | el.Reindent(buf) 26 | } 27 | } 28 | return nil 29 | } 30 | 31 | // IncrementIndentLevel increments by its specified indent level 32 | func (l *LimitClause) IncrementIndentLevel(lev int) { 33 | l.IndentLevel += lev 34 | } 35 | -------------------------------------------------------------------------------- /sqlfmt/parser/group/limit_clause_test.go: -------------------------------------------------------------------------------- 1 | package group 2 | 3 | import ( 4 | "bytes" 5 | "testing" 6 | 7 | "github.com/kanmu/go-sqlfmt/sqlfmt/lexer" 8 | ) 9 | 10 | func TestReindentLimitGroup(t *testing.T) { 11 | tests := []struct { 12 | name string 13 | tokenSource []Reindenter 14 | want string 15 | }{ 16 | { 17 | name: "normalcase", 18 | tokenSource: []Reindenter{ 19 | lexer.Token{Type: lexer.LIMIT, Value:
"LIMIT"}, 20 | lexer.Token{Type: lexer.IDENT, Value: "123"}, 21 | }, 22 | want: "\nLIMIT 123", 23 | }, 24 | { 25 | name: "normalcase", 26 | tokenSource: []Reindenter{ 27 | lexer.Token{Type: lexer.OFFSET, Value: "OFFSET"}, 28 | }, 29 | want: "\nOFFSET", 30 | }, 31 | { 32 | name: "normalcase", 33 | tokenSource: []Reindenter{ 34 | lexer.Token{Type: lexer.FETCH, Value: "FETCH"}, 35 | lexer.Token{Type: lexer.FIRST, Value: "FIRST"}, 36 | }, 37 | want: "\nFETCH FIRST", 38 | }, 39 | } 40 | for _, tt := range tests { 41 | buf := &bytes.Buffer{} 42 | limitGroup := &LimitClause{Element: tt.tokenSource} 43 | 44 | limitGroup.Reindent(buf) 45 | got := buf.String() 46 | if tt.want != got { 47 | t.Errorf("want%#v, got %#v", tt.want, got) 48 | } 49 | } 50 | } 51 | -------------------------------------------------------------------------------- /sqlfmt/parser/group/lock_group.go: -------------------------------------------------------------------------------- 1 | package group 2 | 3 | import ( 4 | "bytes" 5 | 6 | "github.com/kanmu/go-sqlfmt/sqlfmt/lexer" 7 | ) 8 | 9 | // Lock clause 10 | type Lock struct { 11 | Element []Reindenter 12 | IndentLevel int 13 | } 14 | 15 | // Reindent reindent its elements 16 | func (l *Lock) Reindent(buf *bytes.Buffer) error { 17 | for _, v := range l.Element { 18 | if token, ok := v.(lexer.Token); ok { 19 | writeLock(buf, token) 20 | } else { 21 | v.Reindent(buf) 22 | } 23 | } 24 | return nil 25 | } 26 | 27 | // IncrementIndentLevel increments by its specified increment level 28 | func (l *Lock) IncrementIndentLevel(lev int) { 29 | l.IndentLevel += lev 30 | } 31 | -------------------------------------------------------------------------------- /sqlfmt/parser/group/lock_group_test.go: -------------------------------------------------------------------------------- 1 | package group 2 | 3 | import ( 4 | "bytes" 5 | "testing" 6 | 7 | "github.com/kanmu/go-sqlfmt/sqlfmt/lexer" 8 | ) 9 | 10 | func TestReindentLockGroup(t *testing.T) { 11 | tests := []struct { 12 | name string 13 | tokenSource []Reindenter 14 | want string 15 | }{ 16 | { 17 | name: "normalcase", 18 | tokenSource: []Reindenter{ 19 | lexer.Token{Type: lexer.LOCK, Value: "LOCK"}, 20 | lexer.Token{Type: lexer.IDENT, Value: "table"}, 21 | lexer.Token{Type: lexer.IN, Value: "IN"}, 22 | lexer.Token{Type: lexer.IDENT, Value: "xxx"}, 23 | }, 24 | want: "\nLOCK table\nIN xxx", 25 | }, 26 | } 27 | for _, tt := range tests { 28 | buf := &bytes.Buffer{} 29 | lock := &Lock{Element: tt.tokenSource} 30 | 31 | lock.Reindent(buf) 32 | got := buf.String() 33 | if tt.want != got { 34 | t.Errorf("want%#v, got %#v", tt.want, got) 35 | } 36 | } 37 | } 38 | -------------------------------------------------------------------------------- /sqlfmt/parser/group/or_group.go: -------------------------------------------------------------------------------- 1 | package group 2 | 3 | import ( 4 | "bytes" 5 | 6 | "github.com/kanmu/go-sqlfmt/sqlfmt/lexer" 7 | ) 8 | 9 | // OrGroup clause 10 | type OrGroup struct { 11 | Element []Reindenter 12 | IndentLevel int 13 | } 14 | 15 | // Reindent reindents its elements 16 | func (o *OrGroup) Reindent(buf *bytes.Buffer) error { 17 | elements, err := processPunctuation(o.Element) 18 | if err != nil { 19 | return err 20 | } 21 | 22 | for _, el := range elements { 23 | if token, ok := el.(lexer.Token); ok { 24 | write(buf, token, o.IndentLevel) 25 | } else { 26 | el.Reindent(buf) 27 | } 28 | } 29 | return nil 30 | } 31 | 32 | // IncrementIndentLevel increments by its specified increment level 33 | func (o 
*OrGroup) IncrementIndentLevel(lev int) { 34 | o.IndentLevel += lev 35 | } 36 | -------------------------------------------------------------------------------- /sqlfmt/parser/group/or_group_test.go: -------------------------------------------------------------------------------- 1 | package group 2 | 3 | import ( 4 | "bytes" 5 | "testing" 6 | 7 | "github.com/kanmu/go-sqlfmt/sqlfmt/lexer" 8 | ) 9 | 10 | func TestReindentOrGroup(t *testing.T) { 11 | tests := []struct { 12 | name string 13 | tokenSource []Reindenter 14 | want string 15 | }{ 16 | { 17 | name: "normalcase", 18 | tokenSource: []Reindenter{ 19 | lexer.Token{Type: lexer.ORGROUP, Value: "OR"}, 20 | lexer.Token{Type: lexer.IDENT, Value: "something1"}, 21 | lexer.Token{Type: lexer.IDENT, Value: "something2"}, 22 | }, 23 | want: "\nOR something1 something2", 24 | }, 25 | } 26 | for _, tt := range tests { 27 | buf := &bytes.Buffer{} 28 | orGroup := &OrGroup{Element: tt.tokenSource} 29 | 30 | orGroup.Reindent(buf) 31 | got := buf.String() 32 | if tt.want != got { 33 | t.Errorf("want%#v, got %#v", tt.want, got) 34 | } 35 | } 36 | } 37 | -------------------------------------------------------------------------------- /sqlfmt/parser/group/order_by_group.go: -------------------------------------------------------------------------------- 1 | package group 2 | 3 | import ( 4 | "bytes" 5 | 6 | "github.com/kanmu/go-sqlfmt/sqlfmt/lexer" 7 | ) 8 | 9 | // OrderBy clause 10 | type OrderBy struct { 11 | Element []Reindenter 12 | IndentLevel int 13 | } 14 | 15 | // Reindent reindents its elements 16 | func (o *OrderBy) Reindent(buf *bytes.Buffer) error { 17 | columnCount = 0 18 | 19 | src, err := processPunctuation(o.Element) 20 | if err != nil { 21 | return err 22 | } 23 | 24 | for _, el := range separate(src) { 25 | switch v := el.(type) { 26 | case lexer.Token, string: 27 | if err := writeWithComma(buf, v, o.IndentLevel); err != nil { 28 | return err 29 | } 30 | case Reindenter: 31 | v.Reindent(buf) 32 | } 33 | } 34 | return nil 35 | } 36 | 37 | // IncrementIndentLevel increments by its specified indent level 38 | func (o *OrderBy) IncrementIndentLevel(lev int) { 39 | o.IndentLevel += lev 40 | } 41 | -------------------------------------------------------------------------------- /sqlfmt/parser/group/order_by_group_test.go: -------------------------------------------------------------------------------- 1 | package group 2 | 3 | import ( 4 | "bytes" 5 | "testing" 6 | 7 | "github.com/kanmu/go-sqlfmt/sqlfmt/lexer" 8 | ) 9 | 10 | func TestReindentOrderByGroup(t *testing.T) { 11 | tests := []struct { 12 | name string 13 | tokenSource []Reindenter 14 | want string 15 | }{ 16 | { 17 | name: "normalcase", 18 | tokenSource: []Reindenter{ 19 | lexer.Token{Type: lexer.ORDER, Value: "ORDER"}, 20 | lexer.Token{Type: lexer.BY, Value: "BY"}, 21 | lexer.Token{Type: lexer.IDENT, Value: "xxxxxx"}, 22 | }, 23 | want: "\nORDER BY\n xxxxxx", 24 | }, 25 | } 26 | for _, tt := range tests { 27 | buf := &bytes.Buffer{} 28 | orderByGroup := &OrderBy{Element: tt.tokenSource} 29 | 30 | orderByGroup.Reindent(buf) 31 | got := buf.String() 32 | if tt.want != got { 33 | t.Errorf("want%#v, got %#v", tt.want, got) 34 | } 35 | } 36 | } 37 | -------------------------------------------------------------------------------- /sqlfmt/parser/group/parenthesis_group.go: -------------------------------------------------------------------------------- 1 | package group 2 | 3 | import ( 4 | "bytes" 5 | 6 | "github.com/kanmu/go-sqlfmt/sqlfmt/lexer" 7 | ) 8 | 9 | // Parenthesis clause 10 | 
type Parenthesis struct { 11 | Element []Reindenter 12 | IndentLevel int 13 | InColumnArea bool 14 | ColumnCount int 15 | } 16 | 17 | // Reindent reindents its elements 18 | func (p *Parenthesis) Reindent(buf *bytes.Buffer) error { 19 | var hasStartBefore bool 20 | 21 | elements, err := processPunctuation(p.Element) 22 | if err != nil { 23 | return err 24 | } 25 | for i, el := range elements { 26 | if token, ok := el.(lexer.Token); ok { 27 | hasStartBefore = (i == 1) 28 | writeParenthesis(buf, token, p.IndentLevel, p.ColumnCount, p.InColumnArea, hasStartBefore) 29 | } else { 30 | el.Reindent(buf) 31 | } 32 | } 33 | 34 | return nil 35 | } 36 | 37 | // IncrementIndentLevel increments by its specified indent level 38 | func (p *Parenthesis) IncrementIndentLevel(lev int) { 39 | p.IndentLevel += lev 40 | } 41 | -------------------------------------------------------------------------------- /sqlfmt/parser/group/reindenter.go: -------------------------------------------------------------------------------- 1 | package group 2 | 3 | import ( 4 | "bytes" 5 | "fmt" 6 | "strings" 7 | 8 | "github.com/kanmu/go-sqlfmt/sqlfmt/lexer" 9 | ) 10 | 11 | // Reindenter interface 12 | // specific values of Reindenter would be clause group or token 13 | type Reindenter interface { 14 | Reindent(buf *bytes.Buffer) error 15 | IncrementIndentLevel(lev int) 16 | } 17 | 18 | // count of idents appearing in the column area 19 | var columnCount int 20 | 21 | // to reindent 22 | const ( 23 | NewLine = "\n" 24 | WhiteSpace = " " 25 | DoubleWhiteSpace = " " 26 | ) 27 | 28 | func write(buf *bytes.Buffer, token lexer.Token, indent int) { 29 | switch { 30 | case token.IsNeedNewLineBefore(): 31 | buf.WriteString(fmt.Sprintf("%s%s%s", NewLine, strings.Repeat(DoubleWhiteSpace, indent), token.Value)) 32 | case token.Type == lexer.COMMA: 33 | buf.WriteString(fmt.Sprintf("%s", token.Value)) 34 | case token.Type == lexer.DO: 35 | buf.WriteString(fmt.Sprintf("%s%s%s", NewLine, token.Value, WhiteSpace)) 36 | case strings.HasPrefix(token.Value, "::"): 37 | buf.WriteString(fmt.Sprintf("%s", token.Value)) 38 | case token.Type == lexer.WITH: 39 | buf.WriteString(fmt.Sprintf("%s%s", NewLine, token.Value)) 40 | default: 41 | buf.WriteString(fmt.Sprintf("%s%s", WhiteSpace, token.Value)) 42 | } 43 | } 44 | 45 | func writeWithComma(buf *bytes.Buffer, v interface{}, indent int) error { 46 | if token, ok := v.(lexer.Token); ok { 47 | switch { 48 | case token.IsNeedNewLineBefore(): 49 | buf.WriteString(fmt.Sprintf("%s%s%s", NewLine, strings.Repeat(DoubleWhiteSpace, indent), token.Value)) 50 | case token.Type == lexer.BY: 51 | buf.WriteString(fmt.Sprintf("%s%s", WhiteSpace, token.Value)) 52 | case token.Type == lexer.COMMA: 53 | buf.WriteString(fmt.Sprintf("%s%s%s%s", NewLine, strings.Repeat(DoubleWhiteSpace, indent), DoubleWhiteSpace, token.Value)) 54 | default: 55 | return fmt.Errorf("can not reindent %#v", token.Value) 56 | } 57 | } else if str, ok := v.(string); ok { 58 | str = strings.TrimRight(str, " ") 59 | if columnCount == 0 { 60 | buf.WriteString(fmt.Sprintf("%s%s%s%s", NewLine, strings.Repeat(DoubleWhiteSpace, indent), DoubleWhiteSpace, str)) 61 | } else if strings.HasPrefix(str, "::") { // check the string itself; token is the zero value in this branch 62 | buf.WriteString(fmt.Sprintf("%s", str)) 63 | } else { 64 | buf.WriteString(fmt.Sprintf("%s%s", WhiteSpace, str)) 65 | } 66 | columnCount++ 67 | } 68 | return nil 69 | } 70 | 71 | func writeSelect(buf *bytes.Buffer, el interface{}, indent int) error { 72 | if token, ok := el.(lexer.Token); ok { 73 | switch token.Type { 74 | case lexer.SELECT,
lexer.INTO: 75 | buf.WriteString(fmt.Sprintf("%s%s%s", NewLine, strings.Repeat(DoubleWhiteSpace, indent), token.Value)) 76 | case lexer.AS, lexer.DISTINCT, lexer.DISTINCTROW, lexer.GROUP, lexer.ON: 77 | buf.WriteString(fmt.Sprintf("%s%s", WhiteSpace, token.Value)) 78 | case lexer.EXISTS: 79 | buf.WriteString(fmt.Sprintf("%s%s", WhiteSpace, token.Value)) 80 | columnCount++ 81 | case lexer.COMMA: 82 | buf.WriteString(fmt.Sprintf("%s%s%s%s", NewLine, strings.Repeat(DoubleWhiteSpace, indent), DoubleWhiteSpace, token.Value)) 83 | default: 84 | return fmt.Errorf("can not reindent %#v", token.Value) 85 | } 86 | } else if str, ok := el.(string); ok { 87 | str = strings.Trim(str, WhiteSpace) 88 | if columnCount == 0 { 89 | buf.WriteString(fmt.Sprintf("%s%s%s%s", NewLine, strings.Repeat(DoubleWhiteSpace, indent), DoubleWhiteSpace, str)) 90 | } else { 91 | buf.WriteString(fmt.Sprintf("%s%s", WhiteSpace, str)) 92 | } 93 | columnCount++ 94 | } 95 | return nil 96 | } 97 | 98 | func writeCase(buf *bytes.Buffer, token lexer.Token, indent int, hasCommaBefore bool) { 99 | if hasCommaBefore { 100 | switch token.Type { 101 | case lexer.CASE: 102 | buf.WriteString(fmt.Sprintf("%s%s", WhiteSpace, token.Value)) 103 | case lexer.WHEN, lexer.ELSE: 104 | buf.WriteString(fmt.Sprintf("%s%s%s%s%s%s%s", NewLine, strings.Repeat(DoubleWhiteSpace, indent), DoubleWhiteSpace, WhiteSpace, WhiteSpace, DoubleWhiteSpace, token.Value)) 105 | case lexer.END: 106 | buf.WriteString(fmt.Sprintf("%s%s%s%s%s%s", NewLine, strings.Repeat(DoubleWhiteSpace, indent), DoubleWhiteSpace, WhiteSpace, WhiteSpace, token.Value)) 107 | case lexer.COMMA: 108 | buf.WriteString(fmt.Sprintf("%s", token.Value)) 109 | default: 110 | if strings.HasPrefix(token.Value, "::") { 111 | buf.WriteString(fmt.Sprintf("%s", token.Value)) 112 | } else { 113 | buf.WriteString(fmt.Sprintf("%s%s", WhiteSpace, token.Value)) 114 | } 115 | } 116 | } else { 117 | switch token.Type { 118 | case lexer.CASE, lexer.END: 119 | buf.WriteString(fmt.Sprintf("%s%s%s%s", NewLine, strings.Repeat(DoubleWhiteSpace, indent), DoubleWhiteSpace, token.Value)) 120 | case lexer.WHEN, lexer.ELSE: 121 | buf.WriteString(fmt.Sprintf("%s%s%s%s%s%s", NewLine, strings.Repeat(DoubleWhiteSpace, indent), DoubleWhiteSpace, WhiteSpace, DoubleWhiteSpace, token.Value)) 122 | case lexer.COMMA: 123 | buf.WriteString(fmt.Sprintf("%s", token.Value)) 124 | default: 125 | if strings.HasPrefix(token.Value, "::") { 126 | buf.WriteString(fmt.Sprintf("%s", token.Value)) 127 | } else { 128 | buf.WriteString(fmt.Sprintf("%s%s", WhiteSpace, token.Value)) 129 | } 130 | } 131 | } 132 | } 133 | 134 | func writeJoin(buf *bytes.Buffer, token lexer.Token, indent int, isFirst bool) { 135 | switch { 136 | case isFirst && token.IsJoinStart(): 137 | buf.WriteString(fmt.Sprintf("%s%s%s", NewLine, strings.Repeat(DoubleWhiteSpace, indent), token.Value)) 138 | case token.Type == lexer.ON || token.Type == lexer.USING: 139 | buf.WriteString(fmt.Sprintf("%s%s%s", NewLine, strings.Repeat(DoubleWhiteSpace, indent), token.Value)) 140 | case strings.HasPrefix(token.Value, "::"): 141 | buf.WriteString(fmt.Sprintf("%s", token.Value)) 142 | default: 143 | buf.WriteString(fmt.Sprintf("%s%s", WhiteSpace, token.Value)) 144 | } 145 | } 146 | 147 | func writeFunction(buf *bytes.Buffer, token, prev lexer.Token, indent, columnCount int, inColumnArea bool) { 148 | switch { 149 | case prev.Type == lexer.STARTPARENTHESIS || token.Type == lexer.STARTPARENTHESIS || token.Type == lexer.ENDPARENTHESIS: 150 | buf.WriteString(fmt.Sprintf("%s", 
token.Value)) 151 | case token.Type == lexer.FUNCTION && columnCount == 0 && inColumnArea: 152 | buf.WriteString(fmt.Sprintf("%s%s%s%s", NewLine, strings.Repeat(DoubleWhiteSpace, indent), DoubleWhiteSpace, token.Value)) 153 | case token.Type == lexer.FUNCTION: 154 | buf.WriteString(fmt.Sprintf("%s%s", WhiteSpace, token.Value)) 155 | case token.Type == lexer.COMMA: 156 | buf.WriteString(fmt.Sprintf("%s", token.Value)) 157 | case strings.HasPrefix(token.Value, "::"): 158 | buf.WriteString(fmt.Sprintf("%s", token.Value)) 159 | default: 160 | buf.WriteString(fmt.Sprintf("%s%s", WhiteSpace, token.Value)) 161 | } 162 | } 163 | 164 | func writeParenthesis(buf *bytes.Buffer, token lexer.Token, indent, columnCount int, inColumnArea, hasStartBefore bool) { 165 | switch { 166 | case token.Type == lexer.STARTPARENTHESIS && columnCount == 0 && inColumnArea: 167 | buf.WriteString(fmt.Sprintf("%s%s%s%s", NewLine, strings.Repeat(DoubleWhiteSpace, indent), DoubleWhiteSpace, token.Value)) 168 | case token.Type == lexer.STARTPARENTHESIS: 169 | buf.WriteString(fmt.Sprintf("%s%s", WhiteSpace, token.Value)) 170 | case token.Type == lexer.ENDPARENTHESIS: 171 | buf.WriteString(fmt.Sprintf("%s", token.Value)) 172 | case token.Type == lexer.COMMA: 173 | buf.WriteString(fmt.Sprintf("%s", token.Value)) 174 | case hasStartBefore: 175 | buf.WriteString(fmt.Sprintf("%s", token.Value)) 176 | case strings.HasPrefix(token.Value, "::"): 177 | buf.WriteString(fmt.Sprintf("%s", token.Value)) 178 | default: 179 | buf.WriteString(fmt.Sprintf("%s%s", WhiteSpace, token.Value)) 180 | } 181 | } 182 | 183 | func writeSubquery(buf *bytes.Buffer, token lexer.Token, indent, columnCount int, inColumnArea bool) { 184 | switch { 185 | case token.Type == lexer.STARTPARENTHESIS && columnCount == 0 && inColumnArea: 186 | buf.WriteString(fmt.Sprintf("%s%s%s", NewLine, strings.Repeat(DoubleWhiteSpace, indent), token.Value)) 187 | case token.Type == lexer.STARTPARENTHESIS: 188 | buf.WriteString(fmt.Sprintf("%s%s", WhiteSpace, token.Value)) 189 | case token.Type == lexer.ENDPARENTHESIS && columnCount > 0: 190 | buf.WriteString(fmt.Sprintf("%s%s%s", NewLine, strings.Repeat(DoubleWhiteSpace, indent), token.Value)) 191 | case token.Type == lexer.ENDPARENTHESIS: 192 | buf.WriteString(fmt.Sprintf("%s%s%s", NewLine, strings.Repeat(DoubleWhiteSpace, indent-1), token.Value)) 193 | case strings.HasPrefix(token.Value, "::"): 194 | buf.WriteString(fmt.Sprintf("%s", token.Value)) 195 | default: 196 | buf.WriteString(fmt.Sprintf("%s%s", WhiteSpace, token.Value)) 197 | } 198 | } 199 | 200 | func writeTypeCast(buf *bytes.Buffer, token lexer.Token) { 201 | switch token.Type { 202 | case lexer.TYPE: 203 | buf.WriteString(fmt.Sprintf("%s%s", WhiteSpace, token.Value)) 204 | case lexer.COMMA: 205 | buf.WriteString(fmt.Sprintf("%s%s", token.Value, WhiteSpace)) 206 | default: 207 | buf.WriteString(fmt.Sprintf("%s", token.Value)) 208 | } 209 | } 210 | 211 | func writeLock(buf *bytes.Buffer, token lexer.Token) { 212 | switch token.Type { 213 | case lexer.LOCK: 214 | buf.WriteString(fmt.Sprintf("%s%s", NewLine, token.Value)) 215 | case lexer.IN: 216 | buf.WriteString(fmt.Sprintf("%s%s", NewLine, token.Value)) 217 | default: 218 | buf.WriteString(fmt.Sprintf("%s%s", WhiteSpace, token.Value)) 219 | } 220 | } 221 | -------------------------------------------------------------------------------- /sqlfmt/parser/group/returning_group.go: -------------------------------------------------------------------------------- 1 | package group 2 | 3 | import ( 4 | "bytes" 5 | 6 | 
"github.com/kanmu/go-sqlfmt/sqlfmt/lexer" 7 | ) 8 | 9 | // Returning clause 10 | type Returning struct { 11 | Element []Reindenter 12 | IndentLevel int 13 | } 14 | 15 | // Reindent reindents its elements 16 | func (r *Returning) Reindent(buf *bytes.Buffer) error { 17 | columnCount = 0 18 | 19 | src, err := processPunctuation(r.Element) 20 | if err != nil { 21 | return err 22 | } 23 | 24 | for _, el := range separate(src) { 25 | switch v := el.(type) { 26 | case lexer.Token, string: 27 | if err := writeWithComma(buf, v, r.IndentLevel); err != nil { 28 | return err 29 | } 30 | case Reindenter: 31 | v.Reindent(buf) 32 | } 33 | } 34 | return nil 35 | } 36 | 37 | // IncrementIndentLevel increments by its specified indent level 38 | func (r *Returning) IncrementIndentLevel(lev int) { 39 | r.IndentLevel += lev 40 | } 41 | -------------------------------------------------------------------------------- /sqlfmt/parser/group/returning_group_test.go: -------------------------------------------------------------------------------- 1 | package group 2 | 3 | import ( 4 | "bytes" 5 | "testing" 6 | 7 | "github.com/kanmu/go-sqlfmt/sqlfmt/lexer" 8 | ) 9 | 10 | func TestReindentReturningGroup(t *testing.T) { 11 | tests := []struct { 12 | name string 13 | tokenSource []Reindenter 14 | want string 15 | }{ 16 | { 17 | name: "normal case", 18 | tokenSource: []Reindenter{ 19 | lexer.Token{Type: lexer.RETURNING, Value: "RETURNING"}, 20 | lexer.Token{Type: lexer.IDENT, Value: "something1"}, 21 | lexer.Token{Type: lexer.COMMA, Value: ","}, 22 | lexer.Token{Type: lexer.IDENT, Value: "something1"}, 23 | }, 24 | want: "\nRETURNING\n something1\n , something1", 25 | }, 26 | } 27 | for _, tt := range tests { 28 | buf := &bytes.Buffer{} 29 | returningGroup := &Returning{Element: tt.tokenSource} 30 | 31 | returningGroup.Reindent(buf) 32 | got := buf.String() 33 | if tt.want != got { 34 | t.Errorf("want%#v, got %#v", tt.want, got) 35 | } 36 | } 37 | } 38 | -------------------------------------------------------------------------------- /sqlfmt/parser/group/select_group.go: -------------------------------------------------------------------------------- 1 | package group 2 | 3 | import ( 4 | "bytes" 5 | "fmt" 6 | 7 | "github.com/kanmu/go-sqlfmt/sqlfmt/lexer" 8 | "github.com/pkg/errors" 9 | ) 10 | 11 | // Select clause 12 | type Select struct { 13 | Element []Reindenter 14 | IndentLevel int 15 | } 16 | 17 | // Reindent reindens its elements 18 | func (s *Select) Reindent(buf *bytes.Buffer) error { 19 | columnCount = 0 20 | 21 | src, err := processPunctuation(s.Element) 22 | if err != nil { 23 | return err 24 | } 25 | elements := separate(src) 26 | 27 | for i, element := range elements { 28 | switch v := element.(type) { 29 | case lexer.Token, string: 30 | if err := writeSelect(buf, element, s.IndentLevel); err != nil { 31 | return errors.Wrap(err, "writeSelect failed") 32 | } 33 | case *Case: 34 | if tok, ok := elements[i-1].(lexer.Token); ok { 35 | if tok.Type == lexer.COMMA { 36 | v.hasCommaBefore = true 37 | } 38 | } 39 | v.Reindent(buf) 40 | // Case group in Select clause must be in column area 41 | columnCount++ 42 | case *Parenthesis: 43 | v.InColumnArea = true 44 | v.ColumnCount = columnCount 45 | v.Reindent(buf) 46 | columnCount++ 47 | case *Subquery: 48 | if token, ok := elements[i-1].(lexer.Token); ok { 49 | if token.Type == lexer.EXISTS { 50 | v.Reindent(buf) 51 | continue 52 | } 53 | } 54 | v.InColumnArea = true 55 | v.ColumnCount = columnCount 56 | v.Reindent(buf) 57 | case *Function: 58 | v.InColumnArea = true 
59 | v.ColumnCount = columnCount 60 | v.Reindent(buf) 61 | columnCount++ 62 | case Reindenter: 63 | v.Reindent(buf) 64 | columnCount++ 65 | default: 66 | return fmt.Errorf("can not reindent %#v", v) 67 | } 68 | } 69 | return nil 70 | } 71 | 72 | // IncrementIndentLevel increments by its specified indent level 73 | func (s *Select) IncrementIndentLevel(lev int) { 74 | s.IndentLevel += lev 75 | } 76 | -------------------------------------------------------------------------------- /sqlfmt/parser/group/select_group_test.go: -------------------------------------------------------------------------------- 1 | package group 2 | 3 | import ( 4 | "bytes" 5 | "testing" 6 | 7 | "github.com/kanmu/go-sqlfmt/sqlfmt/lexer" 8 | ) 9 | 10 | func TestReindentSelectGroup(t *testing.T) { 11 | tests := []struct { 12 | name string 13 | tokenSource []Reindenter 14 | want string 15 | }{ 16 | { 17 | name: "normal case", 18 | tokenSource: []Reindenter{ 19 | lexer.Token{Type: lexer.SELECT, Value: "SELECT"}, 20 | lexer.Token{Type: lexer.IDENT, Value: "name"}, 21 | lexer.Token{Type: lexer.COMMA, Value: ","}, 22 | lexer.Token{Type: lexer.IDENT, Value: "age"}, 23 | }, 24 | want: "\nSELECT\n name\n , age", 25 | }, 26 | } 27 | for _, tt := range tests { 28 | buf := &bytes.Buffer{} 29 | selectGroup := &Select{Element: tt.tokenSource} 30 | 31 | selectGroup.Reindent(buf) 32 | got := buf.String() 33 | if tt.want != got { 34 | t.Errorf("want%#v, got %#v", tt.want, got) 35 | } 36 | } 37 | } 38 | 39 | func TestIncrementIndentLevel(t *testing.T) { 40 | s := &Select{} 41 | s.IncrementIndentLevel(1) 42 | got := s.IndentLevel 43 | want := 1 44 | if got != want { 45 | t.Errorf("want %#v got %#v", want, got) 46 | } 47 | } 48 | -------------------------------------------------------------------------------- /sqlfmt/parser/group/set_group.go: -------------------------------------------------------------------------------- 1 | package group 2 | 3 | import ( 4 | "bytes" 5 | 6 | "github.com/kanmu/go-sqlfmt/sqlfmt/lexer" 7 | ) 8 | 9 | // Set clause 10 | type Set struct { 11 | Element []Reindenter 12 | IndentLevel int 13 | } 14 | 15 | // Reindent reindents its elements 16 | func (s *Set) Reindent(buf *bytes.Buffer) error { 17 | columnCount = 0 18 | 19 | src, err := processPunctuation(s.Element) 20 | if err != nil { 21 | return err 22 | } 23 | 24 | for _, el := range separate(src) { 25 | switch v := el.(type) { 26 | case lexer.Token, string: 27 | if err := writeWithComma(buf, v, s.IndentLevel); err != nil { 28 | return err 29 | } 30 | case Reindenter: 31 | v.Reindent(buf) 32 | } 33 | } 34 | return nil 35 | } 36 | 37 | // IncrementIndentLevel increments by its specified indent level 38 | func (s *Set) IncrementIndentLevel(lev int) { 39 | s.IndentLevel += lev 40 | } 41 | -------------------------------------------------------------------------------- /sqlfmt/parser/group/set_group_test.go: -------------------------------------------------------------------------------- 1 | package group 2 | 3 | import ( 4 | "bytes" 5 | "testing" 6 | 7 | "github.com/kanmu/go-sqlfmt/sqlfmt/lexer" 8 | ) 9 | 10 | func TestReindentSetGroup(t *testing.T) { 11 | tests := []struct { 12 | name string 13 | tokenSource []Reindenter 14 | want string 15 | }{ 16 | { 17 | name: "normal case", 18 | tokenSource: []Reindenter{ 19 | lexer.Token{Type: lexer.SET, Value: "SET"}, 20 | lexer.Token{Type: lexer.IDENT, Value: "something1"}, 21 | lexer.Token{Type: lexer.IDENT, Value: "="}, 22 | lexer.Token{Type: lexer.IDENT, Value: "$1"}, 23 | }, 24 | want: "\nSET\n something1 = $1", 25 | 
}, 26 | } 27 | for _, tt := range tests { 28 | buf := &bytes.Buffer{} 29 | setGroup := &Set{Element: tt.tokenSource} 30 | 31 | setGroup.Reindent(buf) 32 | got := buf.String() 33 | if tt.want != got { 34 | t.Errorf("want%#v, got %#v", tt.want, got) 35 | } 36 | } 37 | } 38 | -------------------------------------------------------------------------------- /sqlfmt/parser/group/subquery_group _test.go: -------------------------------------------------------------------------------- 1 | package group 2 | 3 | import ( 4 | "bytes" 5 | "testing" 6 | 7 | "github.com/kanmu/go-sqlfmt/sqlfmt/lexer" 8 | ) 9 | 10 | func TestReindentSubqueryGroup(t *testing.T) { 11 | tests := []struct { 12 | name string 13 | src []Reindenter 14 | want string 15 | }{ 16 | { 17 | name: "normalcase", 18 | src: []Reindenter{ 19 | lexer.Token{Type: lexer.STARTPARENTHESIS, Value: "("}, 20 | &Select{ 21 | Element: []Reindenter{ 22 | lexer.Token{Type: lexer.SELECT, Value: "SELECT"}, 23 | lexer.Token{Type: lexer.IDENT, Value: "xxxxxx"}, 24 | }, 25 | IndentLevel: 1, 26 | }, 27 | &From{ 28 | Element: []Reindenter{ 29 | lexer.Token{Type: lexer.FROM, Value: "FROM"}, 30 | lexer.Token{Type: lexer.IDENT, Value: "xxxxxx"}, 31 | }, 32 | IndentLevel: 1, 33 | }, 34 | lexer.Token{Type: lexer.ENDPARENTHESIS, Value: ")"}, 35 | }, 36 | want: " (\n SELECT\n xxxxxx\n FROM xxxxxx)", 37 | }, 38 | } 39 | for _, tt := range tests { 40 | buf := &bytes.Buffer{} 41 | parenGroup := &Parenthesis{Element: tt.src, IndentLevel: 1} 42 | 43 | parenGroup.Reindent(buf) 44 | got := buf.String() 45 | if tt.want != got { 46 | t.Errorf("want%#v, got %#v", tt.want, got) 47 | } 48 | } 49 | } 50 | -------------------------------------------------------------------------------- /sqlfmt/parser/group/subquery_group.go: -------------------------------------------------------------------------------- 1 | package group 2 | 3 | import ( 4 | "bytes" 5 | 6 | "github.com/kanmu/go-sqlfmt/sqlfmt/lexer" 7 | ) 8 | 9 | // Subquery group 10 | type Subquery struct { 11 | Element []Reindenter 12 | IndentLevel int 13 | InColumnArea bool 14 | ColumnCount int 15 | } 16 | 17 | // Reindent reindents its elements 18 | func (s *Subquery) Reindent(buf *bytes.Buffer) error { 19 | elements, err := processPunctuation(s.Element) 20 | if err != nil { 21 | return err 22 | } 23 | for _, el := range elements { 24 | if token, ok := el.(lexer.Token); ok { 25 | writeSubquery(buf, token, s.IndentLevel, s.ColumnCount, s.InColumnArea) 26 | } else { 27 | if s.InColumnArea { 28 | el.IncrementIndentLevel(1) 29 | el.Reindent(buf) 30 | } else { 31 | el.Reindent(buf) 32 | } 33 | } 34 | } 35 | return nil 36 | } 37 | 38 | // IncrementIndentLevel increments by its specified indent level 39 | func (s *Subquery) IncrementIndentLevel(lev int) { 40 | s.IndentLevel += lev 41 | } 42 | -------------------------------------------------------------------------------- /sqlfmt/parser/group/tie_clause_group.go: -------------------------------------------------------------------------------- 1 | package group 2 | 3 | import ( 4 | "bytes" 5 | 6 | "github.com/kanmu/go-sqlfmt/sqlfmt/lexer" 7 | ) 8 | 9 | // TieClause such as UNION, EXCEPT, INTERSECT 10 | type TieClause struct { 11 | Element []Reindenter 12 | IndentLevel int 13 | } 14 | 15 | // Reindent reindents its elements 16 | func (tie *TieClause) Reindent(buf *bytes.Buffer) error { 17 | elements, err := processPunctuation(tie.Element) 18 | if err != nil { 19 | return err 20 | } 21 | for _, el := range elements { 22 | if token, ok := el.(lexer.Token); ok { 23 | write(buf, token, 
tie.IndentLevel) 24 | } else { 25 | el.Reindent(buf) 26 | } 27 | } 28 | return nil 29 | } 30 | 31 | // IncrementIndentLevel increments by its specified indent level 32 | func (tie *TieClause) IncrementIndentLevel(lev int) { 33 | tie.IndentLevel += lev 34 | } 35 | -------------------------------------------------------------------------------- /sqlfmt/parser/group/tie_clause_test.go: -------------------------------------------------------------------------------- 1 | package group 2 | 3 | import ( 4 | "bytes" 5 | "testing" 6 | 7 | "github.com/kanmu/go-sqlfmt/sqlfmt/lexer" 8 | ) 9 | 10 | func TestReindentUnionGroup(t *testing.T) { 11 | tests := []struct { 12 | name string 13 | tokenSource []Reindenter 14 | want string 15 | }{ 16 | { 17 | name: "normal case1", 18 | tokenSource: []Reindenter{ 19 | lexer.Token{Type: lexer.UNION, Value: "UNION"}, 20 | lexer.Token{Type: lexer.ALL, Value: "ALL"}, 21 | }, 22 | want: "\nUNION ALL", 23 | }, 24 | { 25 | name: "normal case2", 26 | tokenSource: []Reindenter{ 27 | lexer.Token{Type: lexer.INTERSECT, Value: "INTERSECT"}, 28 | }, 29 | want: "\nINTERSECT", 30 | }, 31 | { 32 | name: "normal case3", 33 | tokenSource: []Reindenter{ 34 | lexer.Token{Type: lexer.EXCEPT, Value: "EXCEPT"}, 35 | }, 36 | want: "\nEXCEPT", 37 | }, 38 | } 39 | for _, tt := range tests { 40 | buf := &bytes.Buffer{} 41 | unionGroup := &TieClause{Element: tt.tokenSource} 42 | 43 | unionGroup.Reindent(buf) 44 | got := buf.String() 45 | if tt.want != got { 46 | t.Errorf("want%#v, got %#v", tt.want, got) 47 | } 48 | } 49 | } 50 | -------------------------------------------------------------------------------- /sqlfmt/parser/group/type_cast_group.go: -------------------------------------------------------------------------------- 1 | package group 2 | 3 | import ( 4 | "bytes" 5 | 6 | "github.com/kanmu/go-sqlfmt/sqlfmt/lexer" 7 | ) 8 | 9 | // TypeCast group 10 | type TypeCast struct { 11 | Element []Reindenter 12 | IndentLevel int 13 | } 14 | 15 | // Reindent reindents its elements 16 | func (t *TypeCast) Reindent(buf *bytes.Buffer) error { 17 | elements, err := processPunctuation(t.Element) 18 | if err != nil { 19 | return err 20 | } 21 | for _, el := range elements { 22 | if token, ok := el.(lexer.Token); ok { 23 | writeTypeCast(buf, token) 24 | } 25 | } 26 | return nil 27 | } 28 | 29 | // IncrementIndentLevel increments by its specified indent level 30 | func (t *TypeCast) IncrementIndentLevel(lev int) { 31 | t.IndentLevel += lev 32 | } 33 | -------------------------------------------------------------------------------- /sqlfmt/parser/group/update_group.go: -------------------------------------------------------------------------------- 1 | package group 2 | 3 | import ( 4 | "bytes" 5 | 6 | "github.com/kanmu/go-sqlfmt/sqlfmt/lexer" 7 | ) 8 | 9 | // Update clause 10 | type Update struct { 11 | Element []Reindenter 12 | IndentLevel int 13 | } 14 | 15 | // Reindent reindents its elements 16 | func (u *Update) Reindent(buf *bytes.Buffer) error { 17 | columnCount = 0 18 | 19 | src, err := processPunctuation(u.Element) 20 | if err != nil { 21 | return err 22 | } 23 | 24 | for _, el := range separate(src) { 25 | switch v := el.(type) { 26 | case lexer.Token, string: 27 | if err := writeWithComma(buf, v, u.IndentLevel); err != nil { 28 | return err 29 | } 30 | case Reindenter: 31 | v.Reindent(buf) 32 | } 33 | } 34 | return nil 35 | } 36 | 37 | // IncrementIndentLevel increments by its specified indent level 38 | func (u *Update) IncrementIndentLevel(lev int) { 39 | u.IndentLevel += lev 40 | } 41 | 
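// Usage sketch (illustrative only, mirroring the test file below; the column
// is indented with DoubleWhiteSpace by writeWithComma):
//
//	buf := &bytes.Buffer{}
//	up := &Update{Element: []Reindenter{
//		lexer.Token{Type: lexer.UPDATE, Value: "UPDATE"},
//		lexer.Token{Type: lexer.IDENT, Value: "something1"},
//	}}
//	_ = up.Reindent(buf) // buf.String() == "\nUPDATE\n  something1"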
-------------------------------------------------------------------------------- /sqlfmt/parser/group/update_group_test.go: -------------------------------------------------------------------------------- 1 | package group 2 | 3 | import ( 4 | "bytes" 5 | "testing" 6 | 7 | "github.com/kanmu/go-sqlfmt/sqlfmt/lexer" 8 | ) 9 | 10 | func TestReindentUpdateGroup(t *testing.T) { 11 | tests := []struct { 12 | name string 13 | tokenSource []Reindenter 14 | want string 15 | }{ 16 | { 17 | name: "normal case", 18 | tokenSource: []Reindenter{ 19 | lexer.Token{Type: lexer.UPDATE, Value: "UPDATE"}, 20 | lexer.Token{Type: lexer.IDENT, Value: "something1"}, 21 | }, 22 | want: "\nUPDATE\n something1", 23 | }, 24 | } 25 | for _, tt := range tests { 26 | buf := &bytes.Buffer{} 27 | updateGroup := &Update{Element: tt.tokenSource} 28 | 29 | updateGroup.Reindent(buf) 30 | got := buf.String() 31 | if tt.want != got { 32 | t.Errorf("want%#v, got %#v", tt.want, got) 33 | } 34 | } 35 | } 36 | -------------------------------------------------------------------------------- /sqlfmt/parser/group/util.go: -------------------------------------------------------------------------------- 1 | package group 2 | 3 | import ( 4 | "bytes" 5 | "fmt" 6 | "strings" 7 | 8 | "github.com/pkg/errors" 9 | 10 | "github.com/kanmu/go-sqlfmt/sqlfmt/lexer" 11 | ) 12 | 13 | // separate splits elements on commas and on the reserved words in the select clause 14 | func separate(rs []Reindenter) []interface{} { 15 | var ( 16 | result []interface{} 17 | skipRange, count int 18 | ) 19 | buf := &bytes.Buffer{} 20 | 21 | for _, r := range rs { 22 | if token, ok := r.(lexer.Token); !ok { 23 | if buf.String() != "" { 24 | result = append(result, buf.String()) 25 | buf.Reset() 26 | } 27 | result = append(result, r) 28 | } else { 29 | switch { 30 | case skipRange > 0: 31 | skipRange-- 32 | // TODO: more elegant 33 | case token.IsKeyWordInSelect(): 34 | if buf.String() != "" { 35 | result = append(result, buf.String()) 36 | buf.Reset() 37 | } 38 | result = append(result, token) 39 | case token.Type == lexer.COMMA: 40 | if buf.String() != "" { 41 | result = append(result, buf.String()) 42 | } 43 | result = append(result, token) 44 | buf.Reset() 45 | count = 0 46 | case strings.HasPrefix(token.Value, "::"): 47 | buf.WriteString(token.Value) 48 | default: 49 | if count == 0 { 50 | buf.WriteString(token.Value) 51 | } else { 52 | buf.WriteString(WhiteSpace + token.Value) 53 | } 54 | count++ 55 | } 56 | } 57 | } 58 | // append the last element in buf 59 | if buf.String() != "" { 60 | result = append(result, buf.String()) 61 | } 62 | return result 63 | } 64 | 65 | // processPunctuation collapses bracket and brace runs (with their contents) into single SURROUNDING tokens 66 | // TODO: more elegant 67 | func processPunctuation(rs []Reindenter) ([]Reindenter, error) { 68 | var ( 69 | result []Reindenter 70 | skipRange int 71 | ) 72 | 73 | for i, v := range rs { 74 | if token, ok := v.(lexer.Token); ok { 75 | switch { 76 | case skipRange > 0: 77 | skipRange-- 78 | case token.Type == lexer.STARTBRACE || token.Type == lexer.STARTBRACKET: 79 | surrounding, sr, err := extractSurroundingArea(rs[i:]) 80 | if err != nil { 81 | return nil, err 82 | } 83 | result = append(result, lexer.Token{ 84 | Type: lexer.SURROUNDING, 85 | Value: surrounding, 86 | }) 87 | skipRange += sr 88 | default: 89 | result = append(result, token) 90 | } 91 | } else { 92 | result = append(result, v) 93 | } 94 | } 95 | return result, nil 96 | } 97 | 98 | // returns surrounding area including punctuation such as {xxx, xxx} 99 | func extractSurroundingArea(rs []Reindenter)
(string, int, error) { 100 | var ( 101 | countOfStart int 102 | countOfEnd int 103 | result string 104 | skipRange int 105 | ) 106 | for i, r := range rs { 107 | if token, ok := r.(lexer.Token); ok { 108 | switch { 109 | case token.Type == lexer.COMMA || token.Type == lexer.STARTBRACKET || token.Type == lexer.STARTBRACE || token.Type == lexer.ENDBRACKET || token.Type == lexer.ENDBRACE: 110 | result += fmt.Sprint(token.Value) 111 | // the token right after the start token needs no leading space 112 | case i == 1: 113 | result += fmt.Sprint(token.Value) 114 | default: 115 | result += fmt.Sprint(WhiteSpace + token.Value) 116 | } 117 | 118 | if token.Type == lexer.STARTBRACKET || token.Type == lexer.STARTBRACE || token.Type == lexer.STARTPARENTHESIS { 119 | countOfStart++ 120 | } 121 | if token.Type == lexer.ENDBRACKET || token.Type == lexer.ENDBRACE || token.Type == lexer.ENDPARENTHESIS { 122 | countOfEnd++ 123 | } 124 | if countOfStart == countOfEnd { 125 | break 126 | } 127 | skipRange++ 128 | } else { 129 | // TODO: should support group type in surrounding area? 130 | // I have not encountered any groups in surrounding area so far 131 | return "", -1, errors.New("group type is not supposed to be here") 132 | } 133 | } 134 | return result, skipRange, nil 135 | } 136 | -------------------------------------------------------------------------------- /sqlfmt/parser/group/values_group.go: -------------------------------------------------------------------------------- 1 | package group 2 | 3 | import ( 4 | "bytes" 5 | 6 | "github.com/kanmu/go-sqlfmt/sqlfmt/lexer" 7 | ) 8 | 9 | // Values clause 10 | type Values struct { 11 | Element []Reindenter 12 | IndentLevel int 13 | } 14 | 15 | // Reindent reindents its elements 16 | func (val *Values) Reindent(buf *bytes.Buffer) error { 17 | elements, err := processPunctuation(val.Element) 18 | if err != nil { 19 | return err 20 | } 21 | for _, el := range elements { 22 | if token, ok := el.(lexer.Token); ok { 23 | write(buf, token, val.IndentLevel) 24 | } else { 25 | el.Reindent(buf) 26 | } 27 | } 28 | return nil 29 | } 30 | 31 | // IncrementIndentLevel increments by its specified indent level 32 | func (val *Values) IncrementIndentLevel(lev int) { 33 | val.IndentLevel += lev 34 | } 35 | -------------------------------------------------------------------------------- /sqlfmt/parser/group/values_group_test.go: -------------------------------------------------------------------------------- 1 | package group 2 | 3 | import ( 4 | "bytes" 5 | "testing" 6 | 7 | "github.com/kanmu/go-sqlfmt/sqlfmt/lexer" 8 | ) 9 | 10 | func TestReindentValuesGroup(t *testing.T) { 11 | tests := []struct { 12 | name string 13 | tokenSource []Reindenter 14 | want string 15 | }{ 16 | { 17 | name: "normal case", 18 | tokenSource: []Reindenter{ 19 | lexer.Token{Type: lexer.VALUES, Value: "VALUES"}, 20 | lexer.Token{Type: lexer.IDENT, Value: "xxxxx"}, 21 | lexer.Token{Type: lexer.ON, Value: "ON"}, 22 | lexer.Token{Type: lexer.IDENT, Value: "xxxxx"}, 23 | lexer.Token{Type: lexer.DO, Value: "DO"}, 24 | }, 25 | want: "\nVALUES xxxxx\nON xxxxx\nDO ", 26 | }, 27 | } 28 | for _, tt := range tests { 29 | buf := &bytes.Buffer{} 30 | valuesGroup := &Values{Element: tt.tokenSource} 31 | 32 | valuesGroup.Reindent(buf) 33 | got := buf.String() 34 | if tt.want != got { 35 | t.Errorf("want%#v, got %#v", tt.want, got) 36 | } 37 | } 38 | } 39 | -------------------------------------------------------------------------------- /sqlfmt/parser/group/where_group.go:
-------------------------------------------------------------------------------- 1 | package group 2 | 3 | import ( 4 | "bytes" 5 | 6 | "github.com/kanmu/go-sqlfmt/sqlfmt/lexer" 7 | ) 8 | 9 | // Where clause 10 | type Where struct { 11 | Element []Reindenter 12 | IndentLevel int 13 | } 14 | 15 | // Reindent reindents its elements 16 | func (w *Where) Reindent(buf *bytes.Buffer) error { 17 | elements, err := processPunctuation(w.Element) 18 | if err != nil { 19 | return err 20 | } 21 | for _, el := range elements { 22 | if token, ok := el.(lexer.Token); ok { 23 | write(buf, token, w.IndentLevel) 24 | } else if err := el.Reindent(buf); err != nil { 25 | return err 26 | } 27 | } 28 | return nil 29 | } 30 | 31 | // IncrementIndentLevel increments the indent level by the given amount 32 | func (w *Where) IncrementIndentLevel(lev int) { 33 | w.IndentLevel += lev 34 | } 35 | -------------------------------------------------------------------------------- /sqlfmt/parser/group/where_group_test.go: -------------------------------------------------------------------------------- 1 | package group 2 | 3 | import ( 4 | "bytes" 5 | "testing" 6 | 7 | "github.com/kanmu/go-sqlfmt/sqlfmt/lexer" 8 | ) 9 | 10 | func TestReindentWhereGroup(t *testing.T) { 11 | tests := []struct { 12 | name string 13 | tokenSource []Reindenter 14 | want string 15 | }{ 16 | { 17 | name: "normal case", 18 | tokenSource: []Reindenter{ 19 | lexer.Token{Type: lexer.WHERE, Value: "WHERE"}, 20 | lexer.Token{Type: lexer.IDENT, Value: "something1"}, 21 | lexer.Token{Type: lexer.IDENT, Value: "="}, 22 | lexer.Token{Type: lexer.IDENT, Value: "something2"}, 23 | }, 24 | want: "\nWHERE something1 = something2", 25 | }, 26 | } 27 | for _, tt := range tests { 28 | buf := &bytes.Buffer{} 29 | whereGroup := &Where{Element: tt.tokenSource} 30 | 31 | whereGroup.Reindent(buf) 32 | got := buf.String() 33 | if tt.want != got { 34 | t.Errorf("want %#v, got %#v", tt.want, got) 35 | } 36 | } 37 | } 38 | -------------------------------------------------------------------------------- /sqlfmt/parser/group/with_group.go: -------------------------------------------------------------------------------- 1 | package group 2 | 3 | import ( 4 | "bytes" 5 | 6 | "github.com/kanmu/go-sqlfmt/sqlfmt/lexer" 7 | ) 8 | 9 | // With clause 10 | type With struct { 11 | Element []Reindenter 12 | IndentLevel int 13 | } 14 | 15 | // Reindent reindents its elements 16 | func (w *With) Reindent(buf *bytes.Buffer) error { 17 | elements, err := processPunctuation(w.Element) 18 | if err != nil { 19 | return err 20 | } 21 | for _, el := range elements { 22 | if token, ok := el.(lexer.Token); ok { 23 | write(buf, token, w.IndentLevel) 24 | } else if err := el.Reindent(buf); err != nil { 25 | return err 26 | } 27 | } 28 | return nil 29 | } 30 | 31 | // IncrementIndentLevel increments the indent level by the given amount 32 | func (w *With) IncrementIndentLevel(lev int) { 33 | w.IndentLevel += lev 34 | } 35 | -------------------------------------------------------------------------------- /sqlfmt/parser/parser.go: -------------------------------------------------------------------------------- 1 | package parser 2 | 3 | import ( 4 | "github.com/kanmu/go-sqlfmt/sqlfmt/lexer" 5 | "github.com/kanmu/go-sqlfmt/sqlfmt/parser/group" 6 | "github.com/pkg/errors" 7 | ) 8 | 9 | // TODO: calling each Retrieve function one by one is clumsy and should be refactored 10 | 11 | // parser parses a token source 12 | type parser struct { 13 | offset int 14 | result []group.Reindenter 15 | err error 16 | } 17 | 18 | // ParseTokens parses tokens, creating a slice of Reindenters; 19 | // each
Reindenter is a group of SQL clauses such as SelectGroup, FromGroup, etc. 20 | func ParseTokens(tokens []lexer.Token) ([]group.Reindenter, error) { 21 | if !isSQL(tokens[0].Type) { 22 | return nil, errors.New("cannot parse: the token source is not a SQL statement") 23 | } 24 | 25 | var ( 26 | offset int 27 | result []group.Reindenter 28 | ) 29 | 30 | for { 31 | if tokens[offset].Type == lexer.EOF { 32 | break 33 | } 34 | 35 | r := NewRetriever(tokens[offset:]) 36 | element, endIdx, err := r.Retrieve() 37 | if err != nil { 38 | return nil, errors.Wrap(err, "ParseTokens failed") 39 | } 40 | 41 | g := createGroup(element) 42 | result = append(result, g) 43 | 44 | offset += endIdx 45 | } 46 | return result, nil 47 | } 48 | 49 | func isSQL(ttype lexer.TokenType) bool { 50 | return ttype == lexer.SELECT || ttype == lexer.UPDATE || ttype == lexer.DELETE || ttype == lexer.INSERT || ttype == lexer.LOCK || ttype == lexer.WITH 51 | } 52 | -------------------------------------------------------------------------------- /sqlfmt/parser/parser_test.go: -------------------------------------------------------------------------------- 1 | package parser 2 | 3 | import ( 4 | "reflect" 5 | "testing" 6 | 7 | "github.com/kanmu/go-sqlfmt/sqlfmt/lexer" 8 | "github.com/kanmu/go-sqlfmt/sqlfmt/parser/group" 9 | ) 10 | 11 | func TestParseTokens(t *testing.T) { 12 | testingData := []lexer.Token{ 13 | {Type: lexer.SELECT, Value: "SELECT"}, 14 | {Type: lexer.IDENT, Value: "name"}, 15 | {Type: lexer.COMMA, Value: ","}, 16 | {Type: lexer.IDENT, Value: "age"}, 17 | {Type: lexer.COMMA, Value: ","}, 18 | 19 | {Type: lexer.FUNCTION, Value: "SUM"}, 20 | {Type: lexer.STARTPARENTHESIS, Value: "("}, 21 | {Type: lexer.IDENT, Value: "xxx"}, 22 | {Type: lexer.ENDPARENTHESIS, Value: ")"}, 23 | 24 | {Type: lexer.STARTPARENTHESIS, Value: "("}, 25 | {Type: lexer.IDENT, Value: "xxx"}, 26 | {Type: lexer.ENDPARENTHESIS, Value: ")"}, 27 | 28 | {Type: lexer.TYPE, Value: "TEXT"}, 29 | {Type: lexer.STARTPARENTHESIS, Value: "("}, 30 | {Type: lexer.IDENT, Value: "xxx"}, 31 | {Type: lexer.ENDPARENTHESIS, Value: ")"}, 32 | 33 | {Type: lexer.FROM, Value: "FROM"}, 34 | {Type: lexer.IDENT, Value: "user"}, 35 | {Type: lexer.WHERE, Value: "WHERE"}, 36 | {Type: lexer.IDENT, Value: "name"}, 37 | {Type: lexer.IDENT, Value: "="}, 38 | {Type: lexer.STRING, Value: "'xxx'"}, 39 | {Type: lexer.EOF, Value: "EOF"}, 40 | } 41 | testingData2 := []lexer.Token{ 42 | {Type: lexer.SELECT, Value: "SELECT"}, 43 | {Type: lexer.IDENT, Value: "xxx"}, 44 | {Type: lexer.FROM, Value: "FROM"}, 45 | {Type: lexer.IDENT, Value: "xxx"}, 46 | {Type: lexer.WHERE, Value: "WHERE"}, 47 | {Type: lexer.IDENT, Value: "xxx"}, 48 | {Type: lexer.IN, Value: "IN"}, 49 | {Type: lexer.STARTPARENTHESIS, Value: "("}, 50 | {Type: lexer.SELECT, Value: "SELECT"}, 51 | {Type: lexer.IDENT, Value: "xxx"}, 52 | {Type: lexer.FROM, Value: "FROM"}, 53 | {Type: lexer.IDENT, Value: "xxx"}, 54 | {Type: lexer.JOIN, Value: "JOIN"}, 55 | {Type: lexer.IDENT, Value: "xxx"}, 56 | {Type: lexer.ON, Value: "ON"}, 57 | {Type: lexer.IDENT, Value: "xxx"}, 58 | {Type: lexer.IDENT, Value: "="}, 59 | {Type: lexer.IDENT, Value: "xxx"}, 60 | {Type: lexer.ENDPARENTHESIS, Value: ")"}, 61 | {Type: lexer.GROUP, Value: "GROUP"}, 62 | {Type: lexer.BY, Value: "BY"}, 63 | {Type: lexer.IDENT, Value: "xxx"}, 64 | {Type: lexer.ORDER, Value: "ORDER"}, 65 | {Type: lexer.BY, Value: "BY"}, 66 | {Type: lexer.IDENT, Value: "xxx"}, 67 | {Type: lexer.LIMIT, Value: "LIMIT"}, 68 | {Type: lexer.IDENT, Value: "xxx"}, 69 | {Type: lexer.UNION, Value: "UNION"}, 70 | 
{Type: lexer.ALL, Value: "ALL"}, 71 | {Type: lexer.SELECT, Value: "SELECT"}, 72 | {Type: lexer.IDENT, Value: "xxx"}, 73 | {Type: lexer.FROM, Value: "FROM"}, 74 | {Type: lexer.IDENT, Value: "xxx"}, 75 | {Type: lexer.EOF, Value: "EOF"}, 76 | } 77 | testingData3 := []lexer.Token{ 78 | {Type: lexer.UPDATE, Value: "UPDATE"}, 79 | {Type: lexer.IDENT, Value: "user"}, 80 | {Type: lexer.SET, Value: "SET"}, 81 | {Type: lexer.IDENT, Value: "point"}, 82 | {Type: lexer.IDENT, Value: "="}, 83 | {Type: lexer.IDENT, Value: "0"}, 84 | {Type: lexer.EOF, Value: "EOF"}, 85 | } 86 | 87 | tests := []struct { 88 | name string 89 | tokenSource []lexer.Token 90 | want []group.Reindenter 91 | }{ 92 | { 93 | name: "normal test case 1", 94 | tokenSource: testingData, 95 | want: []group.Reindenter{ 96 | &group.Select{ 97 | Element: []group.Reindenter{ 98 | lexer.Token{Type: lexer.SELECT, Value: "SELECT"}, 99 | lexer.Token{Type: lexer.IDENT, Value: "name"}, 100 | lexer.Token{Type: lexer.COMMA, Value: ","}, 101 | lexer.Token{Type: lexer.IDENT, Value: "age"}, 102 | lexer.Token{Type: lexer.COMMA, Value: ","}, 103 | &group.Function{ 104 | Element: []group.Reindenter{ 105 | lexer.Token{Type: lexer.FUNCTION, Value: "SUM"}, 106 | lexer.Token{Type: lexer.STARTPARENTHESIS, Value: "("}, 107 | lexer.Token{Type: lexer.IDENT, Value: "xxx"}, 108 | lexer.Token{Type: lexer.ENDPARENTHESIS, Value: ")"}, 109 | }, 110 | }, 111 | &group.Parenthesis{ 112 | Element: []group.Reindenter{ 113 | lexer.Token{Type: lexer.STARTPARENTHESIS, Value: "("}, 114 | lexer.Token{Type: lexer.IDENT, Value: "xxx"}, 115 | lexer.Token{Type: lexer.ENDPARENTHESIS, Value: ")"}, 116 | }, 117 | }, 118 | &group.TypeCast{ 119 | Element: []group.Reindenter{ 120 | lexer.Token{Type: lexer.TYPE, Value: "TEXT"}, 121 | lexer.Token{Type: lexer.STARTPARENTHESIS, Value: "("}, 122 | lexer.Token{Type: lexer.IDENT, Value: "xxx"}, 123 | lexer.Token{Type: lexer.ENDPARENTHESIS, Value: ")"}, 124 | }, 125 | }, 126 | }, 127 | }, 128 | &group.From{ 129 | Element: []group.Reindenter{ 130 | lexer.Token{Type: lexer.FROM, Value: "FROM"}, 131 | lexer.Token{Type: lexer.IDENT, Value: "user"}, 132 | }, 133 | }, 134 | &group.Where{ 135 | Element: []group.Reindenter{ 136 | lexer.Token{Type: lexer.WHERE, Value: "WHERE"}, 137 | lexer.Token{Type: lexer.IDENT, Value: "name"}, 138 | lexer.Token{Type: lexer.IDENT, Value: "="}, 139 | lexer.Token{Type: lexer.STRING, Value: "'xxx'"}, 140 | }, 141 | }, 142 | }, 143 | }, 144 | { 145 | name: "normal test case 2", 146 | tokenSource: testingData2, 147 | want: []group.Reindenter{ 148 | &group.Select{ 149 | Element: []group.Reindenter{ 150 | lexer.Token{Type: lexer.SELECT, Value: "SELECT"}, 151 | lexer.Token{Type: lexer.IDENT, Value: "xxx"}, 152 | }, 153 | }, 154 | &group.From{ 155 | Element: []group.Reindenter{ 156 | lexer.Token{Type: lexer.FROM, Value: "FROM"}, 157 | lexer.Token{Type: lexer.IDENT, Value: "xxx"}, 158 | }, 159 | }, 160 | &group.Where{ 161 | Element: []group.Reindenter{ 162 | lexer.Token{Type: lexer.WHERE, Value: "WHERE"}, 163 | lexer.Token{Type: lexer.IDENT, Value: "xxx"}, 164 | lexer.Token{Type: lexer.IN, Value: "IN"}, 165 | &group.Subquery{ 166 | Element: []group.Reindenter{ 167 | lexer.Token{Type: lexer.STARTPARENTHESIS, Value: "("}, 168 | &group.Select{ 169 | Element: []group.Reindenter{ 170 | lexer.Token{Type: lexer.SELECT, Value: "SELECT"}, 171 | lexer.Token{Type: lexer.IDENT, Value: "xxx"}, 172 | }, 173 | IndentLevel: 1, 174 | }, 175 | &group.From{ 176 | Element: []group.Reindenter{ 177 | lexer.Token{Type: lexer.FROM, Value: "FROM"}, 178 | 
lexer.Token{Type: lexer.IDENT, Value: "xxx"}, 179 | }, 180 | IndentLevel: 1, 181 | }, 182 | &group.Join{ 183 | Element: []group.Reindenter{ 184 | lexer.Token{Type: lexer.JOIN, Value: "JOIN"}, 185 | lexer.Token{Type: lexer.IDENT, Value: "xxx"}, 186 | lexer.Token{Type: lexer.ON, Value: "ON"}, 187 | lexer.Token{Type: lexer.IDENT, Value: "xxx"}, 188 | lexer.Token{Type: lexer.IDENT, Value: "="}, 189 | lexer.Token{Type: lexer.IDENT, Value: "xxx"}, 190 | }, 191 | IndentLevel: 1, 192 | }, 193 | lexer.Token{Type: lexer.ENDPARENTHESIS, Value: ")"}, 194 | }, 195 | IndentLevel: 1, 196 | }, 197 | }, 198 | }, 199 | &group.GroupBy{ 200 | Element: []group.Reindenter{ 201 | lexer.Token{Type: lexer.GROUP, Value: "GROUP"}, 202 | lexer.Token{Type: lexer.BY, Value: "BY"}, 203 | lexer.Token{Type: lexer.IDENT, Value: "xxx"}, 204 | }, 205 | }, 206 | &group.OrderBy{ 207 | Element: []group.Reindenter{ 208 | lexer.Token{Type: lexer.ORDER, Value: "ORDER"}, 209 | lexer.Token{Type: lexer.BY, Value: "BY"}, 210 | lexer.Token{Type: lexer.IDENT, Value: "xxx"}, 211 | }, 212 | }, 213 | &group.LimitClause{ 214 | Element: []group.Reindenter{ 215 | lexer.Token{Type: lexer.LIMIT, Value: "LIMIT"}, 216 | lexer.Token{Type: lexer.IDENT, Value: "xxx"}, 217 | }, 218 | }, 219 | &group.TieClause{ 220 | Element: []group.Reindenter{ 221 | lexer.Token{Type: lexer.UNION, Value: "UNION"}, 222 | lexer.Token{Type: lexer.ALL, Value: "ALL"}, 223 | }, 224 | }, 225 | &group.Select{ 226 | Element: []group.Reindenter{ 227 | lexer.Token{Type: lexer.SELECT, Value: "SELECT"}, 228 | lexer.Token{Type: lexer.IDENT, Value: "xxx"}, 229 | }, 230 | }, 231 | &group.From{ 232 | Element: []group.Reindenter{ 233 | lexer.Token{Type: lexer.FROM, Value: "FROM"}, 234 | lexer.Token{Type: lexer.IDENT, Value: "xxx"}, 235 | }, 236 | }, 237 | }, 238 | }, 239 | { 240 | name: "normal test case 3", 241 | tokenSource: testingData3, 242 | want: []group.Reindenter{ 243 | &group.Update{ 244 | Element: []group.Reindenter{ 245 | lexer.Token{Type: lexer.UPDATE, Value: "UPDATE"}, 246 | lexer.Token{Type: lexer.IDENT, Value: "user"}, 247 | }, 248 | }, 249 | &group.Set{ 250 | Element: []group.Reindenter{ 251 | lexer.Token{Type: lexer.SET, Value: "SET"}, 252 | lexer.Token{Type: lexer.IDENT, Value: "point"}, 253 | lexer.Token{Type: lexer.IDENT, Value: "="}, 254 | lexer.Token{Type: lexer.IDENT, Value: "0"}, 255 | }, 256 | }, 257 | }, 258 | }, 259 | } 260 | for _, tt := range tests { 261 | got, err := ParseTokens(tt.tokenSource) 262 | if err != nil { 263 | t.Errorf("ERROR: %#v", err) 264 | } 265 | if !reflect.DeepEqual(got, tt.want) { 266 | t.Errorf("\nwant %#v, \ngot %#v", tt.want, got) 267 | } 268 | } 269 | } 270 | -------------------------------------------------------------------------------- /sqlfmt/parser/retriever.go: -------------------------------------------------------------------------------- 1 | package parser 2 | 3 | import ( 4 | "fmt" 5 | 6 | "github.com/kanmu/go-sqlfmt/sqlfmt/lexer" 7 | "github.com/kanmu/go-sqlfmt/sqlfmt/parser/group" 8 | "github.com/pkg/errors" 9 | ) 10 | 11 | // Retriever retrieves the target SQL clause group from the token source 12 | type Retriever struct { 13 | TokenSource []lexer.Token 14 | result []group.Reindenter 15 | indentLevel int 16 | endTokenTypes []lexer.TokenType 17 | endIdx int 18 | } 19 | 20 | // NewRetriever creates a Retriever that retrieves each target SQL clause 21 | // Each Retriever has endTokenTypes so that it knows where to stop retrieving 22 | func NewRetriever(tokenSource []lexer.Token) *Retriever { 23 | firstTokenType := tokenSource[0].Type 24 | switch
firstTokenType { 25 | case lexer.SELECT: 26 | return &Retriever{TokenSource: tokenSource, endTokenTypes: lexer.EndOfSelect} 27 | case lexer.FROM: 28 | return &Retriever{TokenSource: tokenSource, endTokenTypes: lexer.EndOfFrom} 29 | case lexer.CASE: 30 | return &Retriever{TokenSource: tokenSource, endTokenTypes: lexer.EndOfCase} 31 | case lexer.JOIN, lexer.INNER, lexer.OUTER, lexer.LEFT, lexer.RIGHT, lexer.NATURAL, lexer.CROSS: 32 | return &Retriever{TokenSource: tokenSource, endTokenTypes: lexer.EndOfJoin} 33 | case lexer.WHERE: 34 | return &Retriever{TokenSource: tokenSource, endTokenTypes: lexer.EndOfWhere} 35 | case lexer.ANDGROUP: 36 | return &Retriever{TokenSource: tokenSource, endTokenTypes: lexer.EndOfAndGroup} 37 | case lexer.ORGROUP: 38 | return &Retriever{TokenSource: tokenSource, endTokenTypes: lexer.EndOfOrGroup} 39 | case lexer.GROUP: 40 | return &Retriever{TokenSource: tokenSource, endTokenTypes: lexer.EndOfGroupBy} 41 | case lexer.HAVING: 42 | return &Retriever{TokenSource: tokenSource, endTokenTypes: lexer.EndOfHaving} 43 | case lexer.ORDER: 44 | return &Retriever{TokenSource: tokenSource, endTokenTypes: lexer.EndOfOrderBy} 45 | case lexer.LIMIT, lexer.FETCH, lexer.OFFSET: 46 | return &Retriever{TokenSource: tokenSource, endTokenTypes: lexer.EndOfLimitClause} 47 | case lexer.STARTPARENTHESIS: 48 | return &Retriever{TokenSource: tokenSource, endTokenTypes: lexer.EndOfParenthesis} 49 | case lexer.UNION, lexer.INTERSECT, lexer.EXCEPT: 50 | return &Retriever{TokenSource: tokenSource, endTokenTypes: lexer.EndOfTieClause} 51 | case lexer.UPDATE: 52 | return &Retriever{TokenSource: tokenSource, endTokenTypes: lexer.EndOfUpdate} 53 | case lexer.SET: 54 | return &Retriever{TokenSource: tokenSource, endTokenTypes: lexer.EndOfSet} 55 | case lexer.RETURNING: 56 | return &Retriever{TokenSource: tokenSource, endTokenTypes: lexer.EndOfReturning} 57 | case lexer.DELETE: 58 | return &Retriever{TokenSource: tokenSource, endTokenTypes: lexer.EndOfDelete} 59 | case lexer.INSERT: 60 | return &Retriever{TokenSource: tokenSource, endTokenTypes: lexer.EndOfInsert} 61 | case lexer.VALUES: 62 | return &Retriever{TokenSource: tokenSource, endTokenTypes: lexer.EndOfValues} 63 | case lexer.FUNCTION: 64 | return &Retriever{TokenSource: tokenSource, endTokenTypes: lexer.EndOfFunction} 65 | case lexer.TYPE: 66 | return &Retriever{TokenSource: tokenSource, endTokenTypes: lexer.EndOfTypeCast} 67 | case lexer.LOCK: 68 | return &Retriever{TokenSource: tokenSource, endTokenTypes: lexer.EndOfLock} 69 | case lexer.WITH: 70 | return &Retriever{TokenSource: tokenSource, endTokenTypes: lexer.EndOfWith} 71 | default: 72 | return nil 73 | } 74 | } 75 | 76 | // Retrieve retrieves a group of SQL clauses 77 | // It returns the clause group as a slice of the Reindenter interface, plus endIdx for setting the offset 78 | func (r *Retriever) Retrieve() ([]group.Reindenter, int, error) { 79 | if err := r.appendGroupsToResult(); err != nil { 80 | return nil, -1, errors.Wrap(err, "appendGroupsToResult failed") 81 | } 82 | return r.result, r.endIdx, nil 83 | } 84 | 85 | // appendGroupsToResult appends tokens to the result as Reindenters until an endTokenType appears 86 | // if a subGroup is found in the target group, it is appended to the result as a Reindenter by calling this method recursively 87 | // it returns an error if no endTokenType can be found 88 | func (r *Retriever) appendGroupsToResult() error { 89 | var ( 90 | idx int 91 | token lexer.Token 92 | ) 93 | for { 94 | if idx >= len(r.TokenSource) { 95 | return fmt.Errorf("the retriever could not
find the end token") 96 | } 97 | 98 | token = r.TokenSource[idx] 99 | 100 | if r.isEndGroup(token, r.endTokenTypes, idx) { 101 | r.endIdx = idx 102 | return nil 103 | } 104 | if subGroupRetriever := r.getSubGroupRetriever(idx); subGroupRetriever != nil { 105 | if !containsEndToken(subGroupRetriever.TokenSource, subGroupRetriever.endTokenTypes) { 106 | return fmt.Errorf("sub group %s has no end keyword", subGroupRetriever.TokenSource[0].Value) 107 | } 108 | if err := subGroupRetriever.appendGroupsToResult(); err != nil { return err } 109 | if err := r.appendSubGroupToResult(subGroupRetriever.result, subGroupRetriever.indentLevel); err != nil { 110 | return err 111 | } 112 | idx = subGroupRetriever.getNextTokenIdx(token.Type, idx) 113 | continue 114 | } 115 | r.result = append(r.result, token) 116 | idx++ 117 | } 118 | 119 | } 120 | 121 | // containsEndToken reports whether the token source contains one of the endTokenTypes 122 | func containsEndToken(tokenSource []lexer.Token, endTokenTypes []lexer.TokenType) bool { 123 | for _, tok := range tokenSource { 124 | for _, endttype := range endTokenTypes { 125 | if tok.Type == endttype { 126 | return true 127 | } 128 | } 129 | } 130 | return false 131 | } 132 | 133 | // isEndGroup determines whether the token is the end token of the group 134 | func (r *Retriever) isEndGroup(token lexer.Token, endTokenTypes []lexer.TokenType, idx int) bool { 135 | for _, endTokenType := range r.endTokenTypes { 136 | // an end token at idx 0 is ignored, because the group's first token may itself be an endTokenType, for example "AND" or "OR" 137 | // isRangeOfJoinStart ignores endTokenTypes that appear at the start of a join clause, such as LEFT OUTER JOIN or INNER JOIN 138 | if idx == 0 || r.isRangeOfJoinStart(idx) { 139 | return false 140 | } else if token.Type == endTokenType || token.Type == lexer.EOF { 141 | return true 142 | } 143 | } 144 | return false 145 | } 146 | 147 | // getSubGroupRetriever creates a Retriever that retrieves the sub group inside the target group, starting from the tokens sliced at idx 148 | func (r *Retriever) getSubGroupRetriever(idx int) *Retriever { 149 | // when idx equals 0, the target group itself would become its own sub group, which causes an error 150 | if idx == 0 { 151 | return nil 152 | } 153 | 154 | token := r.TokenSource[idx] 155 | nextToken := r.TokenSource[idx+1] 156 | 157 | if r.containIrregularGroupMaker(token.Type, idx) { 158 | return nil 159 | } 160 | if token.Type == lexer.STARTPARENTHESIS && nextToken.Type == lexer.SELECT { 161 | subR := NewRetriever(r.TokenSource[idx:]) 162 | subR.indentLevel = r.indentLevel 163 | 164 | // if a subquery is found, the indent level of every token up to ")" is incremented 165 | subR.indentLevel++ 166 | return subR 167 | } 168 | if token.IsJoinStart() { 169 | // if group keywords appear at the start of a join group such as LEFT OUTER JOIN, they are ignored 170 | // in that case "OUTER" and "JOIN" are group keywords, but must not form a sub group 171 | rangeOfJoinGroupStart := 3 172 | if idx < rangeOfJoinGroupStart { 173 | return nil 174 | } 175 | subR := NewRetriever(r.TokenSource[idx:]) 176 | subR.indentLevel = r.indentLevel 177 | return subR 178 | } 179 | for _, v := range lexer.TokenTypesOfGroupMaker { 180 | if token.Type == v { 181 | subR := NewRetriever(r.TokenSource[idx:]) 182 | subR.indentLevel = r.indentLevel 183 | return subR 184 | } 185 | } 186 | return nil 187 | } 188 | // containIrregularGroupMaker reports whether ttype must not start a sub group in the current context 189 | func (r *Retriever) containIrregularGroupMaker(ttype lexer.TokenType, idx int) bool { 190 | firstTokenOfCurrentGroup := r.TokenSource[0] 191 | 192 | // do not make an ORDER BY sub group inside a Function group;
// this handles ORDER BY inside window functions 193 | 194 | if firstTokenOfCurrentGroup.Type == lexer.FUNCTION && ttype == lexer.ORDER { 195 | return true 196 | } 197 | // ignore "(" in a TypeCast group 198 | if firstTokenOfCurrentGroup.Type == lexer.TYPE && ttype == lexer.STARTPARENTHESIS { 199 | return true 200 | } 201 | 202 | // ignore ORDER BY in a window function 203 | if firstTokenOfCurrentGroup.Type == lexer.STARTPARENTHESIS && ttype == lexer.ORDER { 204 | return true 205 | } 206 | 207 | if firstTokenOfCurrentGroup.Type == lexer.FUNCTION && (ttype == lexer.STARTPARENTHESIS || ttype == lexer.FROM) { 208 | return true 209 | } 210 | 211 | if ttype == lexer.TYPE && !(r.TokenSource[idx+1].Type == lexer.STARTPARENTHESIS) { 212 | return true 213 | } 214 | 215 | return false 216 | } 217 | 218 | // isRangeOfJoinStart reports whether idx falls within the opening keywords of a join group; join keywords such as "LEFT" or "OUTER" found there must be ignored rather than made into a sub group 219 | func (r *Retriever) isRangeOfJoinStart(idx int) bool { 220 | firstTokenType := r.TokenSource[0].Type 221 | for _, v := range lexer.TokenTypesOfJoinMaker { 222 | joinStartRange := 3 223 | if v == firstTokenType && idx < joinStartRange { 224 | return true 225 | } 226 | } 227 | return false 228 | } 229 | 230 | // appendSubGroupToResult makes a Reindenter from the subGroup result and appends it to the result 231 | func (r *Retriever) appendSubGroupToResult(result []group.Reindenter, lev int) error { 232 | if subGroup := createGroup(result); subGroup != nil { 233 | subGroup.IncrementIndentLevel(lev) 234 | r.result = append(r.result, subGroup) 235 | } else { 236 | return fmt.Errorf("cannot make sub group from result: %#v", result) 237 | } 238 | return nil 239 | } 240 | 241 | // getNextTokenIdx prepares idx for the next token value 242 | func (r *Retriever) getNextTokenIdx(ttype lexer.TokenType, idx int) int { 243 | // if the subGroup is a parenthesis, case, function, or type cast group, endIdx is the index of "END" or ")" 244 | // in that case the next token starts after the end keyword, so 1 is added to idx 245 | 246 | switch ttype { 247 | case lexer.STARTPARENTHESIS, lexer.CASE, lexer.FUNCTION, lexer.TYPE: 248 | idx += r.endIdx + 1 249 | default: 250 | idx += r.endIdx 251 | } 252 | return idx 253 | } 254 | 255 | // createGroup creates a clause group from a slice of tokens, returning it as the Reindenter interface 256 | func createGroup(tokenSource []group.Reindenter) group.Reindenter { 257 | firstToken, _ := tokenSource[0].(lexer.Token) 258 | 259 | switch firstToken.Type { 260 | case lexer.SELECT: 261 | return &group.Select{Element: tokenSource} 262 | case lexer.FROM: 263 | return &group.From{Element: tokenSource} 264 | case lexer.JOIN, lexer.INNER, lexer.OUTER, lexer.LEFT, lexer.RIGHT, lexer.NATURAL, lexer.CROSS: 265 | return &group.Join{Element: tokenSource} 266 | case lexer.WHERE: 267 | return &group.Where{Element: tokenSource} 268 | case lexer.ANDGROUP: 269 | return &group.AndGroup{Element: tokenSource} 270 | case lexer.ORGROUP: 271 | return &group.OrGroup{Element: tokenSource} 272 | case lexer.GROUP: 273 | return &group.GroupBy{Element: tokenSource} 274 | case lexer.ORDER: 275 | return &group.OrderBy{Element: tokenSource} 276 | case lexer.HAVING: 277 | return &group.Having{Element: tokenSource} 278 | case lexer.LIMIT, lexer.OFFSET, lexer.FETCH: 279 | return &group.LimitClause{Element: tokenSource} 280 | case lexer.UNION, lexer.INTERSECT, lexer.EXCEPT: 281 | return &group.TieClause{Element: tokenSource} 282 | case lexer.UPDATE: 283 | return &group.Update{Element:
tokenSource} 284 | case lexer.SET: 285 | return &group.Set{Element: tokenSource} 286 | case lexer.RETURNING: 287 | return &group.Returning{Element: tokenSource} 288 | case lexer.LOCK: 289 | return &group.Lock{Element: tokenSource} 290 | case lexer.INSERT: 291 | return &group.Insert{Element: tokenSource} 292 | case lexer.VALUES: 293 | return &group.Values{Element: tokenSource} 294 | case lexer.DELETE: 295 | return &group.Delete{Element: tokenSource} 296 | case lexer.WITH: 297 | return &group.With{Element: tokenSource} 298 | // the end keyword of a CASE group ("END") has to be included in the group, so it is appended to the result 299 | case lexer.CASE: 300 | endToken := lexer.Token{Type: lexer.END, Value: "END"} 301 | tokenSource = append(tokenSource, endToken) 302 | 303 | return &group.Case{Element: tokenSource} 304 | // the end keyword of a subquery group (")") has to be included in the group, so it is appended to the result 305 | case lexer.STARTPARENTHESIS: 306 | endToken := lexer.Token{Type: lexer.ENDPARENTHESIS, Value: ")"} 307 | tokenSource = append(tokenSource, endToken) 308 | 309 | if _, isSubQuery := tokenSource[1].(*group.Select); isSubQuery { 310 | return &group.Subquery{Element: tokenSource} 311 | } 312 | return &group.Parenthesis{Element: tokenSource} 313 | case lexer.FUNCTION: 314 | endToken := lexer.Token{Type: lexer.ENDPARENTHESIS, Value: ")"} 315 | tokenSource = append(tokenSource, endToken) 316 | 317 | return &group.Function{Element: tokenSource} 318 | case lexer.TYPE: 319 | endToken := lexer.Token{Type: lexer.ENDPARENTHESIS, Value: ")"} 320 | tokenSource = append(tokenSource, endToken) 321 | 322 | return &group.TypeCast{Element: tokenSource} 323 | } 324 | return nil 325 | } 326 | -------------------------------------------------------------------------------- /sqlfmt/parser/retriever_test.go: -------------------------------------------------------------------------------- 1 | package parser 2 | 3 | import ( 4 | "reflect" 5 | "testing" 6 | 7 | "github.com/kanmu/go-sqlfmt/sqlfmt/lexer" 8 | ) 9 | 10 | func TestNewRetriever(t *testing.T) { 11 | testingData := []lexer.Token{ 12 | {Type: lexer.SELECT, Value: "SELECT"}, 13 | {Type: lexer.IDENT, Value: "name"}, 14 | {Type: lexer.COMMA, Value: ","}, 15 | {Type: lexer.IDENT, Value: "age"}, 16 | {Type: lexer.FROM, Value: "FROM"}, 17 | {Type: lexer.IDENT, Value: "user"}, 18 | {Type: lexer.EOF, Value: "EOF"}, 19 | } 20 | r := NewRetriever(testingData) 21 | want := []lexer.Token{ 22 | {Type: lexer.SELECT, Value: "SELECT"}, 23 | {Type: lexer.IDENT, Value: "name"}, 24 | {Type: lexer.COMMA, Value: ","}, 25 | {Type: lexer.IDENT, Value: "age"}, 26 | {Type: lexer.FROM, Value: "FROM"}, 27 | {Type: lexer.IDENT, Value: "user"}, 28 | {Type: lexer.EOF, Value: "EOF"}, 29 | } 30 | got := r.TokenSource 31 | 32 | if !reflect.DeepEqual(want, got) { 33 | t.Fatalf("initialize retriever failed: want %#v, got %#v", want, got) 34 | } 35 | } 36 | 37 | func TestRetrieve(t *testing.T) { 38 | type want struct { 39 | stmt []string 40 | lastIdx int 41 | } 42 | 43 | tests := []struct { 44 | name string 45 | source []lexer.Token 46 | endTokenTypes []lexer.TokenType 47 | want *want 48 | }{ 49 | { 50 | name: "normal_test", 51 | source: []lexer.Token{ 52 | {Type: lexer.SELECT, Value: "SELECT"}, 53 | {Type: lexer.IDENT, Value: "name"}, 54 | {Type: lexer.COMMA, Value: ","}, 55 | {Type: lexer.IDENT, Value: "age"}, 56 | {Type: lexer.FROM, Value: "FROM"}, 57 | {Type: lexer.IDENT, Value: "user"}, 58 | {Type: lexer.EOF, Value: "EOF"}, 59 | }, 60 | endTokenTypes: []lexer.TokenType{lexer.FROM}, 61
| want: &want{ 62 | stmt: []string{"SELECT", "name", ",", "age"}, 63 | lastIdx: 4, 64 | }, 65 | }, 66 | { 67 | name: "normal_test3", 68 | source: []lexer.Token{ 69 | {Type: lexer.LEFT, Value: "LEFT"}, 70 | {Type: lexer.JOIN, Value: "JOIN"}, 71 | {Type: lexer.IDENT, Value: "xxx"}, 72 | {Type: lexer.ON, Value: "ON"}, 73 | {Type: lexer.IDENT, Value: "xxx"}, 74 | {Type: lexer.IDENT, Value: "="}, 75 | {Type: lexer.IDENT, Value: "xxx"}, 76 | {Type: lexer.WHERE, Value: "WHERE"}, 77 | }, 78 | endTokenTypes: []lexer.TokenType{lexer.WHERE}, 79 | want: &want{ 80 | stmt: []string{"LEFT", "JOIN", "xxx", "ON", "xxx", "=", "xxx"}, 81 | lastIdx: 7, 82 | }, 83 | }, 84 | { 85 | name: "normal_test4", 86 | source: []lexer.Token{ 87 | {Type: lexer.UPDATE, Value: "UPDATE"}, 88 | {Type: lexer.IDENT, Value: "xxx"}, 89 | {Type: lexer.SET, Value: "SET"}, 90 | }, 91 | endTokenTypes: []lexer.TokenType{lexer.SET}, 92 | want: &want{ 93 | stmt: []string{"UPDATE", "xxx"}, 94 | lastIdx: 2, 95 | }, 96 | }, 97 | { 98 | name: "normal_test5", 99 | source: []lexer.Token{ 100 | {Type: lexer.INSERT, Value: "INSERT"}, 101 | {Type: lexer.INTO, Value: "INTO"}, 102 | {Type: lexer.IDENT, Value: "xxx"}, 103 | {Type: lexer.VALUES, Value: "VALUES"}, 104 | }, 105 | endTokenTypes: []lexer.TokenType{lexer.VALUES}, 106 | want: &want{ 107 | stmt: []string{"INSERT", "INTO", "xxx"}, 108 | lastIdx: 3, 109 | }, 110 | }, 111 | } 112 | for _, tt := range tests { 113 | t.Run(tt.name, func(t *testing.T) { 114 | var ( 115 | gotStmt []string 116 | gotLastIdx int 117 | ) 118 | r := &Retriever{TokenSource: tt.source, endTokenTypes: tt.endTokenTypes} 119 | reindenters, gotLastIdx, err := r.Retrieve() 120 | if err != nil { 121 | t.Errorf("ERROR: %#v", err) 122 | } 123 | 124 | for _, v := range reindenters { 125 | if tok, ok := v.(lexer.Token); ok { 126 | gotStmt = append(gotStmt, tok.Value) 127 | } 128 | } 129 | 130 | if !reflect.DeepEqual(gotStmt, tt.want.stmt) { 131 | t.Errorf("want %v, got %v", tt.want.stmt, gotStmt) 132 | } else if gotLastIdx != tt.want.lastIdx { 133 | t.Errorf("want %v, got %v", tt.want.lastIdx, gotLastIdx) 134 | } 135 | }) 136 | } 137 | } 138 | -------------------------------------------------------------------------------- /sqlfmt/sqlfmt.go: -------------------------------------------------------------------------------- 1 | package sqlfmt 2 | 3 | import ( 4 | "bytes" 5 | "go/format" 6 | "go/parser" 7 | "go/printer" 8 | "go/token" 9 | 10 | "github.com/pkg/errors" 11 | ) 12 | 13 | // Options for go-sqlfmt 14 | type Options struct { 15 | Distance int 16 | } 17 | 18 | // Process formats the SQL statements found in a .go file 19 | func Process(filename string, src []byte, options *Options) ([]byte, error) { 20 | fset := token.NewFileSet() 21 | parserMode := parser.ParseComments 22 | 23 | astFile, err := parser.ParseFile(fset, filename, src, parserMode) 24 | if err != nil { 25 | return nil, formatErr(errors.Wrap(err, "parser.ParseFile failed")) 26 | } 27 | 28 | replaceAst(astFile, fset, options) 29 | 30 | var buf bytes.Buffer 31 | 32 | if err = printer.Fprint(&buf, fset, astFile); err != nil { 33 | return nil, formatErr(errors.Wrap(err, "printer.Fprint failed")) 34 | } 35 | 36 | out, err := format.Source(buf.Bytes()) 37 | if err != nil { 38 | return nil, formatErr(errors.Wrap(err, "format.Source failed")) 39 | } 40 | return out, nil 41 | } 42 | 43 | func formatErr(err error) error { 44 | return &FormatError{msg: err.Error()} 45 | } 46 | --------------------------------------------------------------------------------
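For callers who want to drive the formatter from their own Go code rather than through the CLI, the following is a minimal sketch built only on the `Process` function and `Options` struct shown in sqlfmt.go above; the input path is a placeholder, and reading the source with `ioutil.ReadFile` (still idiomatic on the Go 1.11/1.12 toolchains listed in .travis.yml) is an assumption about how a caller obtains `src`:

```go
package main

import (
	"io/ioutil"
	"log"
	"os"

	"github.com/kanmu/go-sqlfmt/sqlfmt"
)

func main() {
	// hypothetical input; any .go file with back-quoted SQL in
	// QueryRow/Query/Exec calls from "database/sql" qualifies
	filename := "input_file.go"

	src, err := ioutil.ReadFile(filename)
	if err != nil {
		log.Fatal(err)
	}

	// Distance mirrors the -distance CLI flag: the offset from the edge
	// to the beginning of the SQL statements; 0 leaves it at the default
	out, err := sqlfmt.Process(filename, src, &sqlfmt.Options{Distance: 0})
	if err != nil {
		log.Fatal(err)
	}

	os.Stdout.Write(out)
}
```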
/sqlfmt/testdata/testing_gofile.go: -------------------------------------------------------------------------------- 1 | package sqlfmt 2 | 3 | import ( 4 | "database/sql" 5 | ) 6 | 7 | func sendSQL() int { 8 | var id int 9 | var db *sql.DB 10 | db.QueryRow(` 11 | select any ( select xxx from xxx ) from xxx where xxx limit xxx `).Scan(&id) 12 | return id 13 | } 14 | -------------------------------------------------------------------------------- /sqlfmt/testdata/testing_gofile_url_query.go: -------------------------------------------------------------------------------- 1 | package sqlformatter 2 | 3 | import ( 4 | "net/url" 5 | ) 6 | 7 | func parseQuery() { 8 | u, _ := url.Parse("https://example.org/?a=1&a=2&b=&=3&&&&") 9 | u.Query() 10 | } 11 | --------------------------------------------------------------------------------
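To make the parser layer above concrete, here is a hedged end-to-end sketch of `ParseTokens` and the `Reindenter` interface. It hand-builds the token stream the way the package's own tests do (the tokenizer API is not shown in this listing, so constructing `lexer.Token` values directly is the safe route); the exact whitespace of the output is determined by the individual group implementations:

```go
package main

import (
	"bytes"
	"fmt"
	"log"

	"github.com/kanmu/go-sqlfmt/sqlfmt/lexer"
	"github.com/kanmu/go-sqlfmt/sqlfmt/parser"
)

func main() {
	// token stream for "SELECT name FROM user"; ParseTokens expects
	// an EOF token to terminate the source
	tokens := []lexer.Token{
		{Type: lexer.SELECT, Value: "SELECT"},
		{Type: lexer.IDENT, Value: "name"},
		{Type: lexer.FROM, Value: "FROM"},
		{Type: lexer.IDENT, Value: "user"},
		{Type: lexer.EOF, Value: "EOF"},
	}

	// ParseTokens yields one Reindenter per clause group
	// (here a Select group followed by a From group)
	groups, err := parser.ParseTokens(tokens)
	if err != nil {
		log.Fatal(err)
	}

	// each group writes its reindented clause into the shared buffer
	var buf bytes.Buffer
	for _, g := range groups {
		if err := g.Reindent(&buf); err != nil {
			log.Fatal(err)
		}
	}
	fmt.Println(buf.String())
}
```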