├── .github └── workflows │ └── go.yml ├── .gitignore ├── LICENSE ├── Makefile ├── README.md ├── cmd └── astprinter │ └── main.go ├── dialect ├── dialect.go ├── keywords.go ├── mysql.go └── postgres.go ├── e2e ├── astutil_test.go ├── commentmap_test.go ├── e2e_test.go └── testdata │ ├── alter │ ├── add_column.sql │ ├── drop_column.sql │ ├── drop_default.sql │ ├── drop_not_null.sql │ ├── edit_type.sql │ ├── set_default.sql │ └── set_not_null.sql │ ├── create_index │ ├── partial_index.sql │ ├── unique_index.sql │ └── using_index.sql │ ├── create_table │ ├── create.sql │ ├── my_create.sql │ └── pg_dump.sql │ ├── drop_index │ ├── drop.sql │ └── single.sql │ ├── drop_table │ └── drop_table.sql │ ├── insert │ ├── multi_values.sql │ ├── on_duplicate_key.sql │ ├── simple.sql │ └── with_select.sql │ └── select │ ├── between_in.sql │ ├── distinct.sql │ ├── exists.sql │ ├── group_by_join.sql │ ├── having_orderby.sql │ ├── inner_join.sql │ ├── lateral_join.sql │ ├── left_join.sql │ ├── like.sql │ ├── limit_and_offset.sql │ ├── moving_average.sql │ ├── multi_join.sql │ ├── multi_table.sql │ ├── notin.sql │ ├── right_join.sql │ ├── row_number.sql │ ├── running_sum.sql │ ├── sql_between.sql │ ├── union_all_where.sql │ ├── where_and_or.sql │ └── window.sql ├── example └── main.go ├── go.mod ├── go.sum ├── parser.go ├── parser_test.go ├── sqlast ├── alter_column_action_gen.go ├── alter_table_action_gen.go ├── ast.go ├── comment.go ├── commentmap.go ├── insert_source_gen.go ├── join_element_gen.go ├── join_spec_gen.go ├── my_data_type_decoration_gen.go ├── operator.go ├── query.go ├── query_test.go ├── sql_select_item_gen.go ├── sql_set_expr_gen.go ├── sql_set_operator_gen.go ├── sql_window_frame_bound_gen.go ├── stmt.go ├── stmt_gen.go ├── stmt_test.go ├── table_constraint_spec_gen.go ├── table_element_gen.go ├── table_factor_gen.go ├── table_option.go ├── table_option_gen.go ├── table_reference_gen.go ├── type.go ├── value.go ├── walk.go └── writer.go ├── sqlastutil ├── rewrite.go 
└── rewrite_test.go ├── sqltoken ├── kind.go ├── kind_string.go ├── tokenizer.go └── tokenizer_test.go ├── testhelper.go └── tools └── genmark ├── main.go └── main_test.go /.github/workflows/go.yml: -------------------------------------------------------------------------------- 1 | name: Go 2 | on: 3 | push: 4 | branches: 5 | - master 6 | pull_request: 7 | 8 | jobs: 9 | 10 | build: 11 | name: Build 12 | runs-on: ubuntu-latest 13 | strategy: 14 | matrix: 15 | go: [1.16, 1.17, 1.18] 16 | steps: 17 | 18 | - name: Set up Go 19 | uses: actions/setup-go@v1 20 | with: 21 | go-version: ${{ matrix.go }} 22 | id: go 23 | 24 | - name: Check out code into the Go module directory 25 | uses: actions/checkout@v2 26 | 27 | - name: test 28 | run: go test -race -coverprofile=coverage.txt -covermode=atomic -coverpkg=.,./sqlast/...,./sqlastutil/...,./sqltoken/...,./dialect/... ./... 29 | 30 | - name: Upload Coverage report to CodeCov 31 | uses: codecov/codecov-action@v1.0.0 32 | with: 33 | token: ${{secrets.CODECOV_TOKEN}} 34 | file: ./coverage.txt 35 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | 2 | # Created by https://www.gitignore.io/api/go 3 | # Edit at https://www.gitignore.io/?templates=go 4 | 5 | ### Go ### 6 | # Binaries for programs and plugins 7 | *.exe 8 | *.exe~ 9 | *.dll 10 | *.so 11 | *.dylib 12 | 13 | # Test binary, built with `go test -c` 14 | *.test 15 | 16 | # Output of the go coverage tool, specifically when used with LiteIDE 17 | *.out 18 | 19 | ### Go Patch ### 20 | /vendor/ 21 | /Godeps/ 22 | 23 | bin 24 | tools/bin/ 25 | # End of https://www.gitignore.io/api/go 26 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 
| TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. 
For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. 
Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. 
You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 
123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. 
In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. 
We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [2020] Akito Ito 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 203 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | SHELL := PATH="$(PWD)/tools/bin:$(PATH)" $(SHELL) 2 | # build: compile bin/astprinter. 3 | .PHONY: build 4 | build: bin/astprinter 5 | # bin/astprinter: regenerate code, then compile the astprinter command. 6 | .PHONY: bin/astprinter 7 | bin/astprinter: generate 8 | go build -o bin/astprinter cmd/astprinter/main.go 9 | # tools/bin/genmark: compile the genmark tool that the generate target depends on. 10 | .PHONY: tools/bin/genmark 11 | tools/bin/genmark: 12 | go build -o tools/bin/genmark tools/genmark/main.go 13 | # generate: run go generate; SHELL above prepends tools/bin to PATH (presumably so go:generate directives can run genmark). 14 | .PHONY: generate 15 | generate: tools/bin/genmark 16 | go generate ./... 17 | # test: run all tests verbosely with coverage; -count=1 bypasses the test cache. 18 | .PHONY: test 19 | test: 20 | go test ./... -cover -count=1 -v 21 | # install: install the commands under ./cmd. No prerequisites: nothing here builds vendor/ and /vendor/ is git-ignored. 22 | .PHONY: install 23 | install: 24 | go install ./cmd/...
25 | 26 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # xsqlparser 2 | 3 | [![GoDoc](https://godoc.org/github.com/akito0107/xsqlparser?status.svg)](https://godoc.org/github.com/akito0107/xsqlparser) 4 | [![Actions Status](https://github.com/akito0107/xsqlparser/workflows/Go/badge.svg)](https://github.com/akito0107/xsqlparser/actions) 5 | [![Go Report Card](https://goreportcard.com/badge/github.com/akito0107/xsqlparser)](https://goreportcard.com/report/github.com/akito0107/xsqlparser) 6 | [![codecov](https://codecov.io/gh/akito0107/xsqlparser/branch/master/graph/badge.svg)](https://codecov.io/gh/akito0107/xsqlparser) 7 | 8 | A SQL parser for Go. 9 | 10 | This repository is a Go port of [sqlparser-rs](https://github.com/andygrove/sqlparser-rs). 11 | 12 | 13 | ## Getting Started 14 | 15 | ### Prerequisites 16 | - Go 1.16+ 17 | 18 | ### Installing 19 | ``` 20 | $ go get -u github.com/akito0107/xsqlparser/... 21 | ``` 22 | 23 | ### How to use 24 | 25 | #### Parser 26 | 27 | __Currently supports `SELECT`, `CREATE TABLE`, `DROP TABLE`, `CREATE VIEW`, `INSERT`, `UPDATE`, `DELETE`, `ALTER TABLE`, `CREATE INDEX`, `DROP INDEX`, and `EXPLAIN`.__ 28 | 29 | - simple case 30 | ```go 31 | package main 32 | 33 | import ( 34 | "bytes" 35 | "log" 36 | 37 | "github.com/k0kubun/pp" 38 | 39 | "github.com/akito0107/xsqlparser" 40 | "github.com/akito0107/xsqlparser/dialect" 41 | ) 42 | 43 | ...
44 | str := "SELECT * from test_table" 45 | parser, err := xsqlparser.NewParser(bytes.NewBufferString(str), &dialect.GenericSQLDialect{}) 46 | if err != nil { 47 | log.Fatal(err) 48 | } 49 | 50 | stmt, err := parser.ParseStatement() 51 | if err != nil { 52 | log.Fatal(err) 53 | } 54 | pp.Println(stmt) 55 | ``` 56 | 57 | got: 58 | ``` 59 | &sqlast.Query{ 60 | stmt: sqlast.stmt{}, 61 | CTEs: []*sqlast.CTE{}, 62 | Body: &sqlast.SQLSelect{ 63 | sqlSetExpr: sqlast.sqlSetExpr{}, 64 | Distinct: false, 65 | Projection: []sqlast.SQLSelectItem{ 66 | &sqlast.UnnamedSelectItem{ 67 | sqlSelectItem: sqlast.sqlSelectItem{}, 68 | Node: &sqlast.Wildcard{}, 69 | }, 70 | }, 71 | FromClause: []sqlast.TableReference{ 72 | &sqlast.Table{ 73 | tableFactor: sqlast.tableFactor{}, 74 | tableReference: sqlast.tableReference{}, 75 | Name: &sqlast.ObjectName{ 76 | Idents: []*sqlast.Ident{ 77 | &"test_table", 78 | }, 79 | }, 80 | Alias: (*sqlast.Ident)(nil), 81 | Args: []sqlast.Node{}, 82 | WithHints: []sqlast.Node{}, 83 | }, 84 | }, 85 | WhereClause: nil, 86 | GroupByClause: []sqlast.Node{}, 87 | HavingClause: nil, 88 | }, 89 | OrderBy: []*sqlast.OrderByExpr{}, 90 | Limit: (*sqlast.LimitExpr)(nil), 91 | } 92 | ``` 93 | 94 | You can also render the AST back to SQL with `ToSQLString()`.
95 | ```go 96 | log.Println(stmt.ToSQLString()) 97 | ``` 98 | 99 | got: 100 | ``` 101 | 2019/05/07 11:59:36 SELECT * FROM test_table 102 | ``` 103 | 104 | - complicated select 105 | ```go 106 | str := "SELECT orders.product, SUM(orders.quantity) AS product_units, accounts.* " + 107 | "FROM orders LEFT JOIN accounts ON orders.account_id = accounts.id " + 108 | "WHERE orders.region IN (SELECT region FROM top_regions) " + 109 | "ORDER BY product_units LIMIT 100" 110 | 111 | parser, err := xsqlparser.NewParser(bytes.NewBufferString(str), &dialect.GenericSQLDialect{}) 112 | if err != nil { 113 | log.Fatal(err) 114 | } 115 | 116 | stmt, err := parser.ParseStatement() 117 | if err != nil { 118 | log.Fatal(err) 119 | } 120 | pp.Println(stmt) 121 | ``` 122 | 123 | got: 124 | ``` 125 | &sqlast.Query{ 126 | stmt: sqlast.stmt{}, 127 | CTEs: []*sqlast.CTE{}, 128 | Body: &sqlast.SQLSelect{ 129 | sqlSetExpr: sqlast.sqlSetExpr{}, 130 | Distinct: false, 131 | Projection: []sqlast.SQLSelectItem{ 132 | &sqlast.UnnamedSelectItem{ 133 | sqlSelectItem: sqlast.sqlSelectItem{}, 134 | Node: &sqlast.CompoundIdent{ 135 | Idents: []*sqlast.Ident{ 136 | &"orders", 137 | &"product", 138 | }, 139 | }, 140 | }, 141 | &sqlast.AliasSelectItem{ 142 | sqlSelectItem: sqlast.sqlSelectItem{}, 143 | Expr: &sqlast.Function{ 144 | Name: &sqlast.ObjectName{ 145 | Idents: []*sqlast.Ident{ 146 | &"SUM", 147 | }, 148 | }, 149 | Args: []sqlast.Node{ 150 | &sqlast.CompoundIdent{ 151 | Idents: []*sqlast.Ident{ 152 | &"orders", 153 | &"quantity", 154 | }, 155 | }, 156 | }, 157 | Over: (*sqlast.WindowSpec)(nil), 158 | }, 159 | Alias: &"product_units", 160 | }, 161 | &sqlast.QualifiedWildcardSelectItem{ 162 | sqlSelectItem: sqlast.sqlSelectItem{}, 163 | Prefix: &sqlast.ObjectName{ 164 | Idents: []*sqlast.Ident{ 165 | &"accounts", 166 | }, 167 | }, 168 | }, 169 | }, 170 | FromClause: []sqlast.TableReference{ 171 | &sqlast.QualifiedJoin{ 172 | tableReference: sqlast.tableReference{}, 173 | LeftElement: 
&sqlast.TableJoinElement{ 174 | joinElement: sqlast.joinElement{}, 175 | Ref: &sqlast.Table{ 176 | tableFactor: sqlast.tableFactor{}, 177 | tableReference: sqlast.tableReference{}, 178 | Name: &sqlast.ObjectName{ 179 | Idents: []*sqlast.Ident{ 180 | &"orders", 181 | }, 182 | }, 183 | Alias: (*sqlast.Ident)(nil), 184 | Args: []sqlast.Node{}, 185 | WithHints: []sqlast.Node{}, 186 | }, 187 | }, 188 | Type: 1, 189 | RightElement: &sqlast.TableJoinElement{ 190 | joinElement: sqlast.joinElement{}, 191 | Ref: &sqlast.Table{ 192 | tableFactor: sqlast.tableFactor{}, 193 | tableReference: sqlast.tableReference{}, 194 | Name: &sqlast.ObjectName{ 195 | Idents: []*sqlast.Ident{ 196 | &"accounts", 197 | }, 198 | }, 199 | Alias: (*sqlast.Ident)(nil), 200 | Args: []sqlast.Node{}, 201 | WithHints: []sqlast.Node{}, 202 | }, 203 | }, 204 | Spec: &sqlast.JoinCondition{ 205 | joinSpec: sqlast.joinSpec{}, 206 | SearchCondition: &sqlast.BinaryExpr{ 207 | Left: &sqlast.CompoundIdent{ 208 | Idents: []*sqlast.Ident{ 209 | &"orders", 210 | &"account_id", 211 | }, 212 | }, 213 | Op: 9, 214 | Right: &sqlast.CompoundIdent{ 215 | Idents: []*sqlast.Ident{ 216 | &"accounts", 217 | &"id", 218 | }, 219 | }, 220 | }, 221 | }, 222 | }, 223 | }, 224 | WhereClause: &sqlast.InSubQuery{ 225 | Expr: &sqlast.CompoundIdent{ 226 | Idents: []*sqlast.Ident{ 227 | &"orders", 228 | &"region", 229 | }, 230 | }, 231 | SubQuery: &sqlast.Query{ 232 | stmt: sqlast.stmt{}, 233 | CTEs: []*sqlast.CTE{}, 234 | Body: &sqlast.SQLSelect{ 235 | sqlSetExpr: sqlast.sqlSetExpr{}, 236 | Distinct: false, 237 | Projection: []sqlast.SQLSelectItem{ 238 | &sqlast.UnnamedSelectItem{ 239 | sqlSelectItem: sqlast.sqlSelectItem{}, 240 | Node: &"region", 241 | }, 242 | }, 243 | FromClause: []sqlast.TableReference{ 244 | &sqlast.Table{ 245 | tableFactor: sqlast.tableFactor{}, 246 | tableReference: sqlast.tableReference{}, 247 | Name: &sqlast.ObjectName{ 248 | Idents: []*sqlast.Ident{ 249 | &"top_regions", 250 | }, 251 | }, 252 | Alias: 
(*sqlast.Ident)(nil), 253 | Args: []sqlast.Node{}, 254 | WithHints: []sqlast.Node{}, 255 | }, 256 | }, 257 | WhereClause: nil, 258 | GroupByClause: []sqlast.Node{}, 259 | HavingClause: nil, 260 | }, 261 | OrderBy: []*sqlast.OrderByExpr{}, 262 | Limit: (*sqlast.LimitExpr)(nil), 263 | }, 264 | Negated: false, 265 | }, 266 | GroupByClause: []sqlast.Node{}, 267 | HavingClause: nil, 268 | }, 269 | OrderBy: []*sqlast.OrderByExpr{ 270 | &sqlast.OrderByExpr{ 271 | Expr: &"product_units", 272 | ASC: (*bool)(nil), 273 | }, 274 | }, 275 | Limit: &sqlast.LimitExpr{ 276 | All: false, 277 | LimitValue: &100, 278 | OffsetValue: (*sqlast.LongValue)(nil), 279 | }, 280 | } 281 | ``` 282 | 283 | - with CTE 284 | ```go 285 | str := "WITH regional_sales AS (" + 286 | "SELECT region, SUM(amount) AS total_sales " + 287 | "FROM orders GROUP BY region) " + 288 | "SELECT product, SUM(quantity) AS product_units " + 289 | "FROM orders " + 290 | "WHERE region IN (SELECT region FROM top_regions) " + 291 | "GROUP BY region, product" 292 | 293 | parser, err := xsqlparser.NewParser(bytes.NewBufferString(str), &dialect.GenericSQLDialect{}) 294 | if err != nil { 295 | log.Fatal(err) 296 | } 297 | 298 | stmt, err := parser.ParseStatement() 299 | if err != nil { 300 | log.Fatal(err) 301 | } 302 | pp.Println(stmt) 303 | ``` 304 | 305 | got: 306 | ``` 307 | &sqlast.Query{ 308 | stmt: sqlast.stmt{}, 309 | CTEs: []*sqlast.CTE{ 310 | &sqlast.CTE{ 311 | Alias: &"regional_sales", 312 | Query: &sqlast.Query{ 313 | stmt: sqlast.stmt{}, 314 | CTEs: []*sqlast.CTE{}, 315 | Body: &sqlast.SQLSelect{ 316 | sqlSetExpr: sqlast.sqlSetExpr{}, 317 | Distinct: false, 318 | Projection: []sqlast.SQLSelectItem{ 319 | &sqlast.UnnamedSelectItem{ 320 | sqlSelectItem: sqlast.sqlSelectItem{}, 321 | Node: &"region", 322 | }, 323 | &sqlast.AliasSelectItem{ 324 | sqlSelectItem: sqlast.sqlSelectItem{}, 325 | Expr: &sqlast.Function{ 326 | Name: &sqlast.ObjectName{ 327 | Idents: []*sqlast.Ident{ 328 | &"SUM", 329 | }, 330 | }, 331 | 
Args: []sqlast.Node{ 332 | &"amount", 333 | }, 334 | Over: (*sqlast.WindowSpec)(nil), 335 | }, 336 | Alias: &"total_sales", 337 | }, 338 | }, 339 | FromClause: []sqlast.TableReference{ 340 | &sqlast.Table{ 341 | tableFactor: sqlast.tableFactor{}, 342 | tableReference: sqlast.tableReference{}, 343 | Name: &sqlast.ObjectName{ 344 | Idents: []*sqlast.Ident{ 345 | &"orders", 346 | }, 347 | }, 348 | Alias: (*sqlast.Ident)(nil), 349 | Args: []sqlast.Node{}, 350 | WithHints: []sqlast.Node{}, 351 | }, 352 | }, 353 | WhereClause: nil, 354 | GroupByClause: []sqlast.Node{ 355 | &"region", 356 | }, 357 | HavingClause: nil, 358 | }, 359 | OrderBy: []*sqlast.OrderByExpr{}, 360 | Limit: (*sqlast.LimitExpr)(nil), 361 | }, 362 | }, 363 | }, 364 | Body: &sqlast.SQLSelect{ 365 | sqlSetExpr: sqlast.sqlSetExpr{}, 366 | Distinct: false, 367 | Projection: []sqlast.SQLSelectItem{ 368 | &sqlast.UnnamedSelectItem{ 369 | sqlSelectItem: sqlast.sqlSelectItem{}, 370 | Node: &"product", 371 | }, 372 | &sqlast.AliasSelectItem{ 373 | sqlSelectItem: sqlast.sqlSelectItem{}, 374 | Expr: &sqlast.Function{ 375 | Name: &sqlast.ObjectName{ 376 | Idents: []*sqlast.Ident{ 377 | &"SUM", 378 | }, 379 | }, 380 | Args: []sqlast.Node{ 381 | &"quantity", 382 | }, 383 | Over: (*sqlast.WindowSpec)(nil), 384 | }, 385 | Alias: &"product_units", 386 | }, 387 | }, 388 | FromClause: []sqlast.TableReference{ 389 | &sqlast.Table{ 390 | tableFactor: sqlast.tableFactor{}, 391 | tableReference: sqlast.tableReference{}, 392 | Name: &sqlast.ObjectName{ 393 | Idents: []*sqlast.Ident{ 394 | &"orders", 395 | }, 396 | }, 397 | Alias: (*sqlast.Ident)(nil), 398 | Args: []sqlast.Node{}, 399 | WithHints: []sqlast.Node{}, 400 | }, 401 | }, 402 | WhereClause: &sqlast.InSubQuery{ 403 | Expr: &"region", 404 | SubQuery: &sqlast.Query{ 405 | stmt: sqlast.stmt{}, 406 | CTEs: []*sqlast.CTE{}, 407 | Body: &sqlast.SQLSelect{ 408 | sqlSetExpr: sqlast.sqlSetExpr{}, 409 | Distinct: false, 410 | Projection: []sqlast.SQLSelectItem{ 411 | 
&sqlast.UnnamedSelectItem{ 412 | sqlSelectItem: sqlast.sqlSelectItem{}, 413 | Node: &"region", 414 | }, 415 | }, 416 | FromClause: []sqlast.TableReference{ 417 | &sqlast.Table{ 418 | tableFactor: sqlast.tableFactor{}, 419 | tableReference: sqlast.tableReference{}, 420 | Name: &sqlast.ObjectName{ 421 | Idents: []*sqlast.Ident{ 422 | &"top_regions", 423 | }, 424 | }, 425 | Alias: (*sqlast.Ident)(nil), 426 | Args: []sqlast.Node{}, 427 | WithHints: []sqlast.Node{}, 428 | }, 429 | }, 430 | WhereClause: nil, 431 | GroupByClause: []sqlast.Node{}, 432 | HavingClause: nil, 433 | }, 434 | OrderBy: []*sqlast.OrderByExpr{}, 435 | Limit: (*sqlast.LimitExpr)(nil), 436 | }, 437 | Negated: false, 438 | }, 439 | GroupByClause: []sqlast.Node{ 440 | &"region", 441 | &"product", 442 | }, 443 | HavingClause: nil, 444 | }, 445 | OrderBy: []*sqlast.OrderByExpr{}, 446 | Limit: (*sqlast.LimitExpr)(nil), 447 | } 448 | ``` 449 | 450 | #### Visitor(s) 451 | 452 | - Using `Inspect` 453 | 454 | create AST List 455 | ```go 456 | package main 457 | 458 | 459 | import ( 460 | "bytes" 461 | "log" 462 | 463 | "github.com/k0kubun/pp" 464 | 465 | "github.com/akito0107/xsqlparser" 466 | "github.com/akito0107/xsqlparser/sqlast" 467 | "github.com/akito0107/xsqlparser/dialect" 468 | ) 469 | 470 | func main() { 471 | src := `WITH regional_sales AS ( 472 | SELECT region, SUM(amount) AS total_sales 473 | FROM orders GROUP BY region) 474 | SELECT product, SUM(quantity) AS product_units 475 | FROM orders 476 | WHERE region IN (SELECT region FROM top_regions) 477 | GROUP BY region, product;` 478 | 479 | parser, err := xsqlparser.NewParser(bytes.NewBufferString(src), &dialect.GenericSQLDialect{}) 480 | if err != nil { 481 | log.Fatal(err) 482 | } 483 | 484 | stmt, err := parser.ParseStatement() 485 | if err != nil { 486 | log.Fatal(err) 487 | } 488 | var list []sqlast.Node 489 | 490 | sqlast.Inspect(stmt, func(node sqlast.Node) bool { 491 | switch node.(type) { 492 | case nil: 493 | return false 494 | default: 
495 | list = append(list, node) 496 | return true 497 | } 498 | }) 499 | pp.Println(list) 500 | } 501 | ``` 502 | 503 | `Walk()` is also available. 504 | 505 | #### CommentMap 506 | 507 | __Experimental Feature__ 508 | 509 | ```go 510 | package main 511 | 512 | import ( 513 | "bytes" 514 | "log" 515 | 516 | "github.com/k0kubun/pp" 517 | 518 | "github.com/akito0107/xsqlparser" 519 | "github.com/akito0107/xsqlparser/sqlast" 520 | "github.com/akito0107/xsqlparser/dialect" 521 | ) 522 | 523 | func main() { 524 | src := ` 525 | /*associate with stmts1*/ 526 | CREATE TABLE test ( 527 | /*associate with columndef*/ 528 | col0 int primary key, --columndef 529 | /*with constraints*/ 530 | col1 integer constraint test_constraint check (10 < col1 and col1 < 100), 531 | foreign key (col0, col1) references test2(col1, col2), --table constraints1 532 | --table constraints2 533 | CONSTRAINT test_constraint check(col1 > 10) 534 | ); --associate with stmts2 535 | ` 536 | 537 | parser, err := xsqlparser.NewParser(bytes.NewBufferString(src), &dialect.GenericSQLDialect{}, xsqlparser.ParseComment) 538 | if err != nil { 539 | log.Fatal(err) 540 | } 541 | 542 | file, err := parser.ParseFile() 543 | if err != nil { 544 | log.Fatal(err) 545 | } 546 | 547 | m := sqlast.NewCommentMap(file) 548 | 549 | createTable := file.Stmts[0].(*sqlast.CreateTableStmt) 550 | 551 | pp.Println(m[createTable.Elements[0]]) // prints the `associate with columndef` and `columndef` comments 552 | } 553 | 554 | ``` 555 | 556 | ## License 557 | This project is licensed under the Apache License 2.0; see the [LICENSE](LICENSE) file for details. 558 | -------------------------------------------------------------------------------- /cmd/astprinter/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "flag" 5 | "io" 6 | "log" 7 | "os" 8 | 9 | "github.com/k0kubun/pp" 10 | 11 | "github.com/akito0107/xsqlparser" 12 |
"github.com/akito0107/xsqlparser/dialect" 13 | ) 14 | 15 | // f is the path of the SQL file to parse; the special value "stdin" reads standard input. 16 | var f = flag.String("f", "stdin", "path of the input sql file, or stdin") 17 | 18 | // main parses a single SQL statement from the selected input, pretty-prints 19 | // its AST, and logs the statement rendered back to SQL. 20 | func main() { 21 | flag.Parse() 22 | 23 | var src io.Reader 24 | if *f == "stdin" { 25 | src = os.Stdin 26 | } else { 27 | file, err := os.Open(*f) 28 | if err != nil { 29 | log.Fatal(err) 30 | } 31 | defer file.Close() 32 | src = file 33 | } 34 | 35 | // By Go convention the parser is not meaningful when NewParser returns an 36 | // error, so stop before using it. 37 | parser, err := xsqlparser.NewParser(src, &dialect.GenericSQLDialect{}) 38 | if err != nil { 39 | log.Fatal(err) 40 | } 41 | 42 | stmt, err := parser.ParseStatement() 43 | if err != nil { 44 | log.Fatal(err) 45 | } 46 | 47 | pp.Println(stmt) 48 | 49 | log.Println(stmt.ToSQLString()) 50 | } 51 | -------------------------------------------------------------------------------- /dialect/dialect.go: -------------------------------------------------------------------------------- 1 | package dialect 2 | 3 | // Dialect supplies the character-classification rules used when scanning identifiers. 4 | type Dialect interface { 5 | // IsIdentifierStart reports whether r may begin an identifier. 6 | IsIdentifierStart(r rune) bool 7 | // IsIdentifierPart reports whether r may appear after the first character of an identifier. 8 | IsIdentifierPart(r rune) bool 9 | // IsDelimitedIdentifierStart reports whether r opens a delimited (quoted) identifier. 10 | IsDelimitedIdentifierStart(r rune) bool 11 | } 12 | 13 | // GenericSQLDialect accepts ASCII identifiers (plus '@') and double-quoted 14 | // delimited identifiers. 15 | type GenericSQLDialect struct { 16 | } 17 | 18 | // IsIdentifierStart accepts ASCII letters and '@'; digits and '_' cannot 19 | // start an identifier in this dialect. 20 | func (*GenericSQLDialect) IsIdentifierStart(r rune) bool { 21 | return (r >= 'a' && r <= 'z') || (r >= 'A' && r <= 'Z') || r == '@' 22 | } 23 | 24 | // IsIdentifierPart accepts ASCII letters, digits, '@' and '_'. 25 | func (*GenericSQLDialect) IsIdentifierPart(r rune) bool { 26 | return (r >= 'a' && r <= 'z') || (r >= 'A' && r <= 'Z') || (r >= '0' && r <= '9') || r == '@' || r == '_' 27 | } 28 | 29 | // IsDelimitedIdentifierStart accepts only the double quote. 30 | func (*GenericSQLDialect) IsDelimitedIdentifierStart(r rune) bool { 31 | return r == '"' 32 | } 33 | 34 | // Compile-time check that *GenericSQLDialect implements Dialect. 35 | var _ Dialect = &GenericSQLDialect{} 36 | -------------------------------------------------------------------------------- /dialect/mysql.go: -------------------------------------------------------------------------------- 1 | package dialect 2 | 3 | // MySQLDialect embeds GenericSQLDialect for its identifier rules and 4 | // additionally accepts backtick-quoted delimited identifiers. 5 | type MySQLDialect struct { 6 | GenericSQLDialect 7 | } 8 | 9 | // IsDelimitedIdentifierStart accepts both '"' and '`'. 10 | func (*MySQLDialect) IsDelimitedIdentifierStart(r rune) bool { 11 | return r == '"' || r == '`' 12 | } 13 | 14 | // Compile-time check that *MySQLDialect implements Dialect. 15 | var _ Dialect = &MySQLDialect{} 16 |
-------------------------------------------------------------------------------- /dialect/postgres.go: -------------------------------------------------------------------------------- 1 | package dialect 2 | 3 | type PostgresqlDialect struct { 4 | } 5 | 6 | func (*PostgresqlDialect) IsIdentifierStart(r rune) bool { 7 | return (r >= 'a' && r <= 'z') || (r >= 'A' && r <= 'Z') || r == '_' 8 | } 9 | 10 | func (*PostgresqlDialect) IsIdentifierPart(r rune) bool { 11 | return (r >= 'a' && r <= 'z') || (r >= 'A' && r <= 'Z') || (r >= '0' && r <= '9') || r == '$' || r == '_' 12 | } 13 | 14 | func (*PostgresqlDialect) IsDelimitedIdentifierStart(r rune) bool { 15 | return r == '"' || r == '`' 16 | } 17 | 18 | var _ Dialect = &PostgresqlDialect{} 19 | -------------------------------------------------------------------------------- /e2e/astutil_test.go: -------------------------------------------------------------------------------- 1 | package e2e_test 2 | 3 | import ( 4 | "fmt" 5 | "io/ioutil" 6 | "os" 7 | "strings" 8 | "testing" 9 | 10 | "github.com/akito0107/xsqlparser" 11 | "github.com/akito0107/xsqlparser/sqlastutil" 12 | "github.com/akito0107/xsqlparser/dialect" 13 | "github.com/akito0107/xsqlparser/sqlast" 14 | ) 15 | 16 | func TestInspect(t *testing.T) { 17 | cases := []struct { 18 | name string 19 | dir string 20 | }{ 21 | { 22 | name: "SELECT", 23 | dir: "select", 24 | }, 25 | { 26 | name: "CREATE TABLE", 27 | dir: "create_table", 28 | }, 29 | { 30 | name: "ALTER TABLE", 31 | dir: "alter", 32 | }, 33 | { 34 | name: "DROP TABLE", 35 | dir: "drop_table", 36 | }, 37 | { 38 | name: "CREATE INDEX", 39 | dir: "create_index", 40 | }, 41 | { 42 | name: "DROP INDEX", 43 | dir: "drop_index", 44 | }, 45 | { 46 | name: "INSERT", 47 | dir: "insert", 48 | }, 49 | } 50 | 51 | for _, c := range cases { 52 | t.Run(c.name, func(t *testing.T) { 53 | fname := fmt.Sprintf("./testdata/%s/", c.dir) 54 | files, err := ioutil.ReadDir(fname) 55 | if err != nil { 56 | t.Fatalf("%+v", err) 57 | 
} 58 | 59 | for _, f := range files { 60 | if !strings.HasSuffix(f.Name(), ".sql") { 61 | continue 62 | } 63 | t.Run(f.Name(), func(t *testing.T) { 64 | fi, err := os.Open(fname + f.Name()) 65 | if err != nil { 66 | t.Fatalf("%+v", err) 67 | } 68 | defer fi.Close() 69 | parser, err := xsqlparser.NewParser(fi, &dialect.GenericSQLDialect{}) 70 | if err != nil { 71 | t.Fatalf("%+v", err) 72 | } 73 | 74 | stmt, err := parser.ParseStatement() 75 | if err != nil { 76 | t.Fatalf("%+v", err) 77 | } 78 | sqlast.Inspect(stmt, func(node sqlast.Node) bool { 79 | // fmt.Printf("%T\n", node) 80 | return true 81 | }) 82 | }) 83 | } 84 | }) 85 | } 86 | } 87 | 88 | func TestApply(t *testing.T) { 89 | cases := []struct { 90 | name string 91 | dir string 92 | }{ 93 | { 94 | name: "SELECT", 95 | dir: "select", 96 | }, 97 | { 98 | name: "CREATE TABLE", 99 | dir: "create_table", 100 | }, 101 | { 102 | name: "ALTER TABLE", 103 | dir: "alter", 104 | }, 105 | { 106 | name: "DROP TABLE", 107 | dir: "drop_table", 108 | }, 109 | { 110 | name: "CREATE INDEX", 111 | dir: "create_index", 112 | }, 113 | { 114 | name: "DROP INDEX", 115 | dir: "drop_index", 116 | }, 117 | { 118 | name: "INSERT", 119 | dir: "insert", 120 | }, 121 | } 122 | 123 | for _, c := range cases { 124 | t.Run(c.name, func(t *testing.T) { 125 | fname := fmt.Sprintf("./testdata/%s/", c.dir) 126 | files, err := ioutil.ReadDir(fname) 127 | if err != nil { 128 | t.Fatalf("%+v", err) 129 | } 130 | 131 | for _, f := range files { 132 | if !strings.HasSuffix(f.Name(), ".sql") { 133 | continue 134 | } 135 | t.Run(f.Name(), func(t *testing.T) { 136 | fi, err := os.Open(fname + f.Name()) 137 | if err != nil { 138 | t.Fatalf("%+v", err) 139 | } 140 | defer fi.Close() 141 | parser, err := xsqlparser.NewParser(fi, &dialect.GenericSQLDialect{}) 142 | if err != nil { 143 | t.Fatalf("%+v", err) 144 | } 145 | 146 | stmt, err := parser.ParseStatement() 147 | if err != nil { 148 | t.Fatalf("%+v", err) 149 | } 150 | sqlastutil.Apply(stmt, func(c 
*sqlastutil.Cursor) bool { 151 | // fmt.Printf("%T\n", node) 152 | return true 153 | }, nil) 154 | }) 155 | } 156 | }) 157 | } 158 | } 159 | -------------------------------------------------------------------------------- /e2e/commentmap_test.go: -------------------------------------------------------------------------------- 1 | package e2e_test 2 | 3 | import ( 4 | "strings" 5 | "testing" 6 | 7 | "github.com/google/go-cmp/cmp" 8 | 9 | "github.com/akito0107/xsqlparser" 10 | "github.com/akito0107/xsqlparser/dialect" 11 | "github.com/akito0107/xsqlparser/sqlast" 12 | "github.com/akito0107/xsqlparser/sqltoken" 13 | ) 14 | 15 | func parseFile(t *testing.T, src string) *sqlast.File { 16 | t.Helper() 17 | parser, err := xsqlparser.NewParser(strings.NewReader(src), &dialect.GenericSQLDialect{}, xsqlparser.ParseComment()) 18 | if err != nil { 19 | t.Fatal(err) 20 | } 21 | 22 | f, err := parser.ParseFile() 23 | if err != nil { 24 | t.Fatal(err) 25 | } 26 | return f 27 | } 28 | 29 | func compareComment(t *testing.T, expect, actual []*sqlast.CommentGroup) { 30 | t.Helper() 31 | if diff := cmp.Diff(expect, actual); diff != "" { 32 | t.Error(diff) 33 | } 34 | } 35 | 36 | func TestNewCommentMap(t *testing.T) { 37 | 38 | t.Run("associate with single statement", func(t *testing.T) { 39 | f := parseFile(t, ` 40 | --test 41 | SELECT * from test; 42 | `) 43 | 44 | m := sqlast.NewCommentMap(f) 45 | compareComment(t, m[f.Stmts[0]], []*sqlast.CommentGroup{ 46 | { 47 | List: []*sqlast.Comment{ 48 | { 49 | Text: "test", 50 | From: sqltoken.NewPos(2, 1), 51 | To: sqltoken.NewPos(2, 7), 52 | }, 53 | }, 54 | }, 55 | }) 56 | }) 57 | 58 | t.Run("associate with multi statements", func(t *testing.T) { 59 | 60 | f := parseFile(t, ` 61 | --select 62 | SELECT * from test; 63 | 64 | /* 65 | insert 66 | */ 67 | INSERT INTO tbl_name (col1,col2) VALUES(15,col1*2); 68 | `) 69 | m := sqlast.NewCommentMap(f) 70 | 71 | compareComment(t, m[f.Stmts[0]], []*sqlast.CommentGroup{ 72 | { 73 | List: 
[]*sqlast.Comment{ 74 | { 75 | Text: "select", 76 | From: sqltoken.NewPos(2, 1), 77 | To: sqltoken.NewPos(2, 9), 78 | }, 79 | }, 80 | }, 81 | }) 82 | 83 | compareComment(t, m[f.Stmts[1]], []*sqlast.CommentGroup{ 84 | { 85 | List: []*sqlast.Comment{ 86 | { 87 | Text: "\ninsert\n", 88 | From: sqltoken.NewPos(5, 1), 89 | To: sqltoken.NewPos(7, 3), 90 | }, 91 | }, 92 | }, 93 | }) 94 | }) 95 | 96 | t.Run("create table", func(t *testing.T) { 97 | 98 | f := parseFile(t, ` 99 | /*associate with stmts1*/ 100 | CREATE TABLE test ( 101 | /*associate with columndef*/ 102 | col0 int primary key, --columndef 103 | /*with constraints*/ 104 | col1 integer constraint test_constraint check (10 < col1 and col1 < 100), 105 | foreign key (col0, col1) references test2(col1, col2), --table constraints1 106 | --table constraints2 107 | CONSTRAINT test_constraint check(col1 > 10) 108 | ); --associate with stmts2 109 | `) 110 | 111 | m := sqlast.NewCommentMap(f) 112 | ct := f.Stmts[0].(*sqlast.CreateTableStmt) 113 | compareComment(t, m[ct], []*sqlast.CommentGroup{ 114 | { 115 | List: []*sqlast.Comment{ 116 | { 117 | Text: "associate with stmts1", 118 | From: sqltoken.NewPos(2, 1), 119 | To: sqltoken.NewPos(2, 26), 120 | }, 121 | }, 122 | }, 123 | { 124 | List: []*sqlast.Comment{ 125 | { 126 | Text: "associate with stmts2", 127 | From: sqltoken.NewPos(11, 4), 128 | To: sqltoken.NewPos(11, 27), 129 | }, 130 | }, 131 | }, 132 | }) 133 | 134 | compareComment(t, m[ct.Elements[0]], []*sqlast.CommentGroup{ 135 | { 136 | List: []*sqlast.Comment{ 137 | { 138 | Text: "associate with columndef", 139 | From: sqltoken.NewPos(4, 5), 140 | To: sqltoken.NewPos(4, 33), 141 | }, 142 | }, 143 | }, 144 | { 145 | List: []*sqlast.Comment{ 146 | { 147 | Text: "columndef", 148 | From: sqltoken.NewPos(5, 27), 149 | To: sqltoken.NewPos(5, 38), 150 | }, 151 | }, 152 | }, 153 | }) 154 | 155 | compareComment(t, m[ct.Elements[1]], []*sqlast.CommentGroup{ 156 | { 157 | List: []*sqlast.Comment{ 158 | { 159 | Text: "with 
constraints", 160 | From: sqltoken.NewPos(6, 5), 161 | To: sqltoken.NewPos(6, 25), 162 | }, 163 | }, 164 | }, 165 | }) 166 | 167 | compareComment(t, m[ct.Elements[2]], []*sqlast.CommentGroup{ 168 | { 169 | List: []*sqlast.Comment{ 170 | { 171 | Text: "table constraints1", 172 | From: sqltoken.NewPos(8, 60), 173 | To: sqltoken.NewPos(8, 80), 174 | }, 175 | }, 176 | }, 177 | }) 178 | 179 | compareComment(t, m[ct.Elements[3]], []*sqlast.CommentGroup{ 180 | { 181 | List: []*sqlast.Comment{ 182 | { 183 | Text: "table constraints2", 184 | From: sqltoken.NewPos(9, 5), 185 | To: sqltoken.NewPos(9, 25), 186 | }, 187 | }, 188 | }, 189 | }) 190 | }) 191 | } 192 | -------------------------------------------------------------------------------- /e2e/e2e_test.go: -------------------------------------------------------------------------------- 1 | package e2e_test 2 | 3 | // All queries are from https://www.w3schools.com/sql/sql_examples.asp 4 | 5 | import ( 6 | "bytes" 7 | "fmt" 8 | "io/ioutil" 9 | "os" 10 | "strings" 11 | "testing" 12 | 13 | "github.com/akito0107/xsqlparser" 14 | "github.com/akito0107/xsqlparser/dialect" 15 | ) 16 | 17 | func TestParseQuery(t *testing.T) { 18 | 19 | cases := []struct { 20 | name string 21 | dir string 22 | }{ 23 | { 24 | name: "SELECT", 25 | dir: "select", 26 | }, 27 | { 28 | name: "CREATE TABLE", 29 | dir: "create_table", 30 | }, 31 | { 32 | name: "ALTER TABLE", 33 | dir: "alter", 34 | }, 35 | { 36 | name: "DROP TABLE", 37 | dir: "drop_table", 38 | }, 39 | { 40 | name: "CREATE INDEX", 41 | dir: "create_index", 42 | }, 43 | { 44 | name: "DROP INDEX", 45 | dir: "drop_index", 46 | }, 47 | { 48 | name: "INSERT", 49 | dir: "insert", 50 | }, 51 | } 52 | 53 | for _, c := range cases { 54 | t.Run(c.name, func(t *testing.T) { 55 | fname := fmt.Sprintf("testdata/%s/", c.dir) 56 | files, err := ioutil.ReadDir(fname) 57 | if err != nil { 58 | t.Fatalf("%+v", err) 59 | } 60 | 61 | for _, f := range files { 62 | if !strings.HasSuffix(f.Name(), ".sql") { 63 
| continue 64 | } 65 | t.Run(f.Name(), func(t *testing.T) { 66 | fi, err := os.Open(fname + f.Name()) 67 | if err != nil { 68 | t.Fatalf("%+v", err) 69 | } 70 | defer fi.Close() 71 | parser, err := xsqlparser.NewParser(fi, &dialect.GenericSQLDialect{}) 72 | if err != nil { 73 | t.Fatalf("%+v", err) 74 | } 75 | 76 | orig, err := parser.ParseStatement() 77 | if err != nil { 78 | t.Fatalf("%+v", err) 79 | } 80 | recovered := orig.ToSQLString() 81 | 82 | parser, err = xsqlparser.NewParser(bytes.NewBufferString(recovered), &dialect.GenericSQLDialect{}) 83 | if err != nil { 84 | t.Log(recovered) 85 | t.Fatalf("%+v", err) 86 | } 87 | 88 | stmt2, err := parser.ParseStatement() 89 | if err != nil { 90 | t.Fatalf("%+v", err) 91 | } 92 | 93 | recovered2 := stmt2.ToSQLString() 94 | 95 | parser, err = xsqlparser.NewParser(bytes.NewBufferString(recovered2), &dialect.GenericSQLDialect{}) 96 | if err != nil { 97 | t.Log(recovered) 98 | t.Fatalf("%+v", err) 99 | } 100 | 101 | stmt3, err := parser.ParseStatement() 102 | if err != nil { 103 | t.Fatalf("%+v", err) 104 | } 105 | 106 | if astdiff := xsqlparser.CompareWithoutMarker(stmt2, stmt3); astdiff != "" { 107 | t.Logf(recovered) 108 | t.Errorf("should be same ast but diff:\n %s", astdiff) 109 | } 110 | }) 111 | } 112 | }) 113 | } 114 | } 115 | -------------------------------------------------------------------------------- /e2e/testdata/alter/add_column.sql: -------------------------------------------------------------------------------- 1 | ALTER TABLE test1 ADD COLUMN NAME VARCHAR NOT NULL; -------------------------------------------------------------------------------- /e2e/testdata/alter/drop_column.sql: -------------------------------------------------------------------------------- 1 | ALTER TABLE TEST1 DROP COLUMN name; -------------------------------------------------------------------------------- /e2e/testdata/alter/drop_default.sql: -------------------------------------------------------------------------------- 1 | 
ALTER TABLE test1 ALTER COLUMN NAME DROP DEFAULT; -------------------------------------------------------------------------------- /e2e/testdata/alter/drop_not_null.sql: -------------------------------------------------------------------------------- 1 | ALTER TABLE test1 ALTER COLUMN NAME DROP NOT NULL; -------------------------------------------------------------------------------- /e2e/testdata/alter/edit_type.sql: -------------------------------------------------------------------------------- 1 | ALTER TABLE test1 ALTER COLUMN NAME TYPE TEXT; -------------------------------------------------------------------------------- /e2e/testdata/alter/set_default.sql: -------------------------------------------------------------------------------- 1 | ALTER TABLE test1 ALTER COLUMN NAME SET DEFAULT 10; -------------------------------------------------------------------------------- /e2e/testdata/alter/set_not_null.sql: -------------------------------------------------------------------------------- 1 | ALTER TABLE test1 ALTER COLUMN NAME SET NOT NULL; -------------------------------------------------------------------------------- /e2e/testdata/create_index/partial_index.sql: -------------------------------------------------------------------------------- 1 | CREATE UNIQUE INDEX customers_idx ON customers USING gist (name) WHERE name = 'test'; -------------------------------------------------------------------------------- /e2e/testdata/create_index/unique_index.sql: -------------------------------------------------------------------------------- 1 | CREATE UNIQUE INDEX customers_idx ON customers USING gist (name); -------------------------------------------------------------------------------- /e2e/testdata/create_index/using_index.sql: -------------------------------------------------------------------------------- 1 | CREATE UNIQUE INDEX customers_idx ON customers (name, email); -------------------------------------------------------------------------------- 
/e2e/testdata/create_table/create.sql: -------------------------------------------------------------------------------- 1 | CREATE TABLE test ( 2 | col1 int primary key, 3 | col2 char(10), 4 | col3 VARCHAR, 5 | col4 VARCHAR(255), 6 | col5 uuid NOT NULl, 7 | col6 smallint check(col6 < 10), 8 | col7 bigint UNIQUE, 9 | col8 integer constraint test_constraint check (10 < col8 and col8 < 100), 10 | col9 serial, 11 | col10 character varying, 12 | col11 real references test2(col1), 13 | col12 double precision, 14 | col13 date, 15 | col14 time, 16 | col15 timestamp default current_timestamp, 17 | col16 boolean default false, 18 | col17 numeric(10, 10), 19 | col18 text, 20 | foreign key (col1, col2) references test2(col1, col2), 21 | unique key (col1, col2), 22 | CONSTRAINT test_constraint check(col1 > 10) 23 | ) -------------------------------------------------------------------------------- /e2e/testdata/create_table/my_create.sql: -------------------------------------------------------------------------------- 1 | -- from https://github.com/isucon/isucon8-qualify/blob/master/db/schema.sql 2 | 3 | CREATE TABLE IF NOT EXISTS users ( 4 | id INTEGER UNSIGNED PRIMARY KEY AUTO_INCREMENT, 5 | nickname VARCHAR(128) NOT NULL, 6 | login_name VARCHAR(128) NOT NULL, 7 | pass_hash VARCHAR(128) NOT NULL 8 | ) ENGINE=InnoDB DEFAULT CHARSET utf8mb4; 9 | -------------------------------------------------------------------------------- /e2e/testdata/create_table/pg_dump.sql: -------------------------------------------------------------------------------- 1 | CREATE TABLE public.version ( 2 | version_id uuid DEFAULT public.uuid_generate_v4() NOT NULL, 3 | project_id uuid NOT NULL, 4 | name character varying(256) NOT NULL, 5 | created_at timestamp without time zone DEFAULT now() NOT NULL 6 | ); -------------------------------------------------------------------------------- /e2e/testdata/drop_index/drop.sql: -------------------------------------------------------------------------------- 1 | 
DROP INDEX title_idx,title_idx2,title_idx3; -------------------------------------------------------------------------------- /e2e/testdata/drop_index/single.sql: -------------------------------------------------------------------------------- 1 | DROP INDEX title_idx; -------------------------------------------------------------------------------- /e2e/testdata/drop_table/drop_table.sql: -------------------------------------------------------------------------------- 1 | DROP TABLE IF EXISTS TEST CASCADE; -------------------------------------------------------------------------------- /e2e/testdata/insert/multi_values.sql: -------------------------------------------------------------------------------- 1 | -- from https://dev.mysql.com/doc/refman/8.0/en/insert.html 2 | INSERT INTO tbl_name (a,b,c) VALUES(1,2,3),(4,5,6),(7,8,9); -------------------------------------------------------------------------------- /e2e/testdata/insert/on_duplicate_key.sql: -------------------------------------------------------------------------------- 1 | -- from: https://dev.mysql.com/doc/refman/8.0/en/insert-on-duplicate.html 2 | INSERT INTO t1 (a,b,c) VALUES (1,2,3) ON DUPLICATE KEY UPDATE c=c+1; 3 | -------------------------------------------------------------------------------- /e2e/testdata/insert/simple.sql: -------------------------------------------------------------------------------- 1 | -- from https://dev.mysql.com/doc/refman/8.0/en/insert.html 2 | INSERT INTO tbl_name (col1,col2) VALUES(15,col1*2); -------------------------------------------------------------------------------- /e2e/testdata/insert/with_select.sql: -------------------------------------------------------------------------------- 1 | INSERT INTO tbl_name (a,b,c) SELECT * from tbl_name2; -------------------------------------------------------------------------------- /e2e/testdata/select/between_in.sql: -------------------------------------------------------------------------------- 1 | SELECT * FROM Products 
2 | WHERE Price BETWEEN 10 AND 20 3 | AND NOT CategoryID IN (1,2,3); -------------------------------------------------------------------------------- /e2e/testdata/select/distinct.sql: -------------------------------------------------------------------------------- 1 | SELECT Count(*) AS DistinctCountries 2 | FROM (SELECT DISTINCT Country FROM Customers); 3 | -------------------------------------------------------------------------------- /e2e/testdata/select/exists.sql: -------------------------------------------------------------------------------- 1 | SELECT column_name(s) 2 | FROM table_name 3 | WHERE EXISTS 4 | (SELECT column_name FROM table_name WHERE condition); -------------------------------------------------------------------------------- /e2e/testdata/select/group_by_join.sql: -------------------------------------------------------------------------------- 1 | SELECT Shippers.ShipperName,COUNT(Orders.OrderID) AS NumberOfOrders FROM Orders 2 | LEFT JOIN Shippers ON Orders.ShipperID = Shippers.ShipperID 3 | GROUP BY ShipperName; 4 | -------------------------------------------------------------------------------- /e2e/testdata/select/having_orderby.sql: -------------------------------------------------------------------------------- 1 | SELECT COUNT(CustomerID), Country 2 | FROM Customers 3 | GROUP BY Country 4 | HAVING COUNT(CustomerID) > 5 5 | ORDER BY COUNT(CustomerID) DESC; -------------------------------------------------------------------------------- /e2e/testdata/select/inner_join.sql: -------------------------------------------------------------------------------- 1 | SELECT Orders.OrderID, Customers.CustomerName 2 | FROM Orders 3 | INNER JOIN Customers ON Orders.CustomerID = Customers.CustomerID; 4 | -------------------------------------------------------------------------------- /e2e/testdata/select/lateral_join.sql: -------------------------------------------------------------------------------- 1 | SELECT n.id, n.name, t.max, t.min FROM node n, 
2 | LATERAL ( 3 | SELECT max(usage) as max, min(usage) as min 4 | FROM node_mon m 5 | WHERE m.id = n.id 6 | ) t; 7 | -------------------------------------------------------------------------------- /e2e/testdata/select/left_join.sql: -------------------------------------------------------------------------------- 1 | SELECT Customers.CustomerName, Orders.OrderID 2 | FROM Customers 3 | LEFT JOIN Orders 4 | ON Customers.CustomerID=Orders.CustomerID 5 | ORDER BY Customers.CustomerName; 6 | -------------------------------------------------------------------------------- /e2e/testdata/select/like.sql: -------------------------------------------------------------------------------- 1 | SELECT * FROM Customers 2 | WHERE ContactName LIKE 'a%o'; -------------------------------------------------------------------------------- /e2e/testdata/select/limit_and_offset.sql: -------------------------------------------------------------------------------- 1 | SELECT product, SUM(quantity) AS product_units FROM orders 2 | WHERE region IN (SELECT region FROM top_regions) 3 | ORDER BY product_units LIMIT 100 OFFSET 20; -------------------------------------------------------------------------------- /e2e/testdata/select/moving_average.sql: -------------------------------------------------------------------------------- 1 | -- from https://support.treasuredata.com/hc/ja/articles/216392117-Window-%E9%96%A2%E6%95%B0%E3%82%92%E4%BD%BF%E3%81%84%E3%81%93%E3%81%AA%E3%81%99-%E9%9B%86%E7%B4%84%E9%96%A2%E6%95%B0%E7%B3%BB- 2 | SELECT m, d, goods_id, sales, AVG(sales) OVER (PARTITION BY goods_id,m ORDER BY d ASC ROWS BETWEEN 4 PRECEDING AND CURRENT ROW) as sales_moving_avg 3 | FROM 4 | ( 5 | SELECT 6 | TD_TIME_FORMAT(time,'yyyy-MM-dd','JST') AS d, TD_TIME_FORMAT(time,'yyyy-MM','JST') AS m, goods_id, SUM(price*amount) AS sales 7 | FROM sales_slip 8 | GROUP BY TD_TIME_FORMAT(time,'yyyy-MM-dd','JST'), TD_TIME_FORMAT(time,'yyyy-MM','JST'), goods_id 9 | ) t 10 | ORDER BY goods_id, m, d 
-------------------------------------------------------------------------------- /e2e/testdata/select/multi_join.sql: -------------------------------------------------------------------------------- 1 | select count(*) from project 2 | join scenario ON scenario.project_id = project.project_id 3 | right outer join scenario_version ON scenario_version.scenario_id = scenario.scenario_id 4 | left outer join compare_log ON compare_log.left_scenario_version_id = scenario_version.scenario_version_id 5 | full outer join scenario_version scenario_version2 ON scenario_version2.scenario_version_id = compare_log.right_scenario_version_id 6 | inner join snapshot ON snapshot.scenario_version_id = scenario_version2.scenario_version_id 7 | natural join scenario_version as s3; -------------------------------------------------------------------------------- /e2e/testdata/select/multi_table.sql: -------------------------------------------------------------------------------- 1 | SELECT * FROM Customers, Items, Blogs 2 | WHERE ContactName LIKE 'a%o'; -------------------------------------------------------------------------------- /e2e/testdata/select/notin.sql: -------------------------------------------------------------------------------- 1 | SELECT * FROM Customers 2 | WHERE Country NOT IN ('Germany', 'France', 'UK'); -------------------------------------------------------------------------------- /e2e/testdata/select/right_join.sql: -------------------------------------------------------------------------------- 1 | SELECT Orders.OrderID, Employees.LastName, Employees.FirstName 2 | FROM Orders 3 | RIGHT JOIN Employees 4 | ON Orders.EmployeeID = Employees.EmployeeID 5 | ORDER BY Orders.OrderID; 6 | -------------------------------------------------------------------------------- /e2e/testdata/select/row_number.sql: -------------------------------------------------------------------------------- 1 | -- from https://mode.com/sql-tutorial/sql-window-functions/ 2 | SELECT 
start_terminal, 3 | start_time, 4 | duration_seconds, 5 | ROW_NUMBER() OVER (ORDER BY start_time) 6 | AS row_number 7 | FROM tutorial.dc_bikeshare_q1_2012 8 | WHERE start_time < '2012-01-08' -------------------------------------------------------------------------------- /e2e/testdata/select/running_sum.sql: -------------------------------------------------------------------------------- 1 | -- from https://support.treasuredata.com/hc/ja/articles/216392117-Window-%E9%96%A2%E6%95%B0%E3%82%92%E4%BD%BF%E3%81%84%E3%81%93%E3%81%AA%E3%81%99-%E9%9B%86%E7%B4%84%E9%96%A2%E6%95%B0%E7%B3%BB- 2 | SELECT m, d, goods_id, sales, SUM(sales) OVER (PARTITION BY goods_id,m ORDER BY d ASC ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) as sales_running_sum 3 | FROM 4 | ( 5 | SELECT 6 | TD_TIME_FORMAT(time,'yyyy-MM-dd','JST') AS d, TD_TIME_FORMAT(time,'yyyy-MM','JST') AS m, goods_id, SUM(price*amount) AS sales 7 | FROM sales_slip 8 | GROUP BY TD_TIME_FORMAT(time,'yyyy-MM-dd','JST'), TD_TIME_FORMAT(time,'yyyy-MM','JST'), goods_id 9 | ) t 10 | ORDER BY goods_id, m, d -------------------------------------------------------------------------------- /e2e/testdata/select/sql_between.sql: -------------------------------------------------------------------------------- 1 | SELECT a, SUM(b) OVER (PARTITION BY c ORDER BY d ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) 2 | FROM T; -------------------------------------------------------------------------------- /e2e/testdata/select/union_all_where.sql: -------------------------------------------------------------------------------- 1 | SELECT City, Country FROM Customers 2 | WHERE Country='Germany' 3 | UNION ALL 4 | SELECT City, Country FROM Suppliers 5 | WHERE Country='Germany' 6 | ORDER BY City; -------------------------------------------------------------------------------- /e2e/testdata/select/where_and_or.sql: -------------------------------------------------------------------------------- 1 | SELECT * FROM Customers 2 | WHERE 
Country='Germany' AND (City='Berlin' OR City='München'); 3 | -------------------------------------------------------------------------------- /e2e/testdata/select/window.sql: -------------------------------------------------------------------------------- 1 | -- from https://mode.com/sql-tutorial/sql-window-functions/ 2 | SELECT start_terminal, 3 | duration_seconds, 4 | SUM(duration_seconds) OVER 5 | (PARTITION BY start_terminal ORDER BY start_time) 6 | AS running_total 7 | FROM tutorial.dc_bikeshare_q1_2012 8 | WHERE start_time < '2012-01-08' -------------------------------------------------------------------------------- /example/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "bytes" 5 | "log" 6 | 7 | "github.com/k0kubun/pp" 8 | 9 | "github.com/akito0107/xsqlparser" 10 | "github.com/akito0107/xsqlparser/dialect" 11 | "github.com/akito0107/xsqlparser/sqlast" 12 | ) 13 | 14 | func main() { 15 | simpleSelect() 16 | complicatedSelect() 17 | withCTE() 18 | createASTList() 19 | commentMap() 20 | } 21 | 22 | func simpleSelect() { 23 | str := "SELECT * from test_table" 24 | parser, err := xsqlparser.NewParser(bytes.NewBufferString(str), &dialect.GenericSQLDialect{}) 25 | if err != nil { 26 | log.Fatal(err) 27 | } 28 | 29 | stmt, err := parser.ParseStatement() 30 | if err != nil { 31 | log.Fatal(err) 32 | } 33 | pp.Println(stmt) 34 | 35 | log.Println(stmt.ToSQLString()) 36 | } 37 | 38 | func complicatedSelect() { 39 | str := "SELECT orders.product, SUM(orders.quantity) AS product_units, accounts.* " + 40 | "FROM orders LEFT JOIN accounts ON orders.account_id = accounts.id " + 41 | "WHERE orders.region IN (SELECT region FROM top_regions) " + 42 | "ORDER BY product_units LIMIT 100" 43 | 44 | parser, err := xsqlparser.NewParser(bytes.NewBufferString(str), &dialect.GenericSQLDialect{}) 45 | if err != nil { 46 | log.Fatal(err) 47 | } 48 | 49 | stmt, err := parser.ParseStatement() 50 | if err != 
nil { 51 | log.Fatal(err) 52 | } 53 | pp.Println(stmt) 54 | 55 | log.Println(stmt.ToSQLString()) 56 | } 57 | 58 | func withCTE() { 59 | str := "WITH regional_sales AS (" + 60 | "SELECT region, SUM(amount) AS total_sales " + 61 | "FROM orders GROUP BY region) " + 62 | "SELECT product, SUM(quantity) AS product_units " + 63 | "FROM orders " + 64 | "WHERE region IN (SELECT region FROM top_regions) " + 65 | "GROUP BY region, product" 66 | 67 | parser, err := xsqlparser.NewParser(bytes.NewBufferString(str), &dialect.GenericSQLDialect{}) 68 | if err != nil { 69 | log.Fatal(err) 70 | } 71 | 72 | stmt, err := parser.ParseStatement() 73 | if err != nil { 74 | log.Fatal(err) 75 | } 76 | pp.Println(stmt) 77 | 78 | log.Println(stmt.ToSQLString()) 79 | } 80 | 81 | func createASTList() { 82 | src := `WITH regional_sales AS ( 83 | SELECT region, SUM(amount) AS total_sales 84 | FROM orders GROUP BY region) 85 | SELECT product, SUM(quantity) AS product_units 86 | FROM orders 87 | WHERE region IN (SELECT region FROM top_regions) 88 | GROUP BY region, product;` 89 | 90 | parser, err := xsqlparser.NewParser(bytes.NewBufferString(src), &dialect.GenericSQLDialect{}) 91 | if err != nil { 92 | log.Fatal(err) 93 | } 94 | 95 | stmt, err := parser.ParseStatement() 96 | if err != nil { 97 | log.Fatal(err) 98 | } 99 | var list []sqlast.Node 100 | 101 | sqlast.Inspect(stmt, func(node sqlast.Node) bool { 102 | switch node.(type) { 103 | case nil: 104 | return false 105 | default: 106 | list = append(list, node) 107 | return true 108 | } 109 | }) 110 | 111 | pp.Println(list) 112 | } 113 | 114 | func commentMap() { 115 | 116 | src := ` 117 | /*associate with stmts1*/ 118 | CREATE TABLE test ( 119 | /*associate with columndef*/ 120 | col0 int primary key, --columndef 121 | /*with constraints*/ 122 | col1 integer constraint test_constraint check (10 < col1 and col1 < 100), 123 | foreign key (col0, col1) references test2(col1, col2), --table constraints1 124 | --table constraints2 125 | CONSTRAINT 
test_constraint check(col1 > 10) 126 | ); --associate with stmts2 127 | ` 128 | 129 | parser, err := xsqlparser.NewParser(bytes.NewBufferString(src), &dialect.GenericSQLDialect{}, xsqlparser.ParseComment()) 130 | if err != nil { 131 | log.Fatal(err) 132 | } 133 | 134 | file, err := parser.ParseFile() 135 | if err != nil { 136 | log.Fatal(err) 137 | } 138 | 139 | m := sqlast.NewCommentMap(file) 140 | 141 | createTable := file.Stmts[0].(*sqlast.CreateTableStmt) 142 | 143 | pp.Println(m[createTable.Elements[0]]) // you can show `associate with columndef` and `columndef` comments 144 | } 145 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/akito0107/xsqlparser 2 | 3 | go 1.16 4 | 5 | require ( 6 | github.com/andreyvit/diff v0.0.0-20170406064948-c7f18ee00883 7 | github.com/google/go-cmp v0.3.1 8 | github.com/k0kubun/colorstring v0.0.0-20150214042306-9440f1994b88 // indirect 9 | github.com/k0kubun/pp v3.0.1+incompatible 10 | github.com/mattn/go-colorable v0.1.2 // indirect 11 | github.com/mattn/go-isatty v0.0.9 // indirect 12 | github.com/sergi/go-diff v1.0.0 // indirect 13 | github.com/stretchr/testify v1.7.1 // indirect 14 | golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7 15 | ) 16 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | github.com/andreyvit/diff v0.0.0-20170406064948-c7f18ee00883 h1:bvNMNQO63//z+xNgfBlViaCIJKLlCJ6/fmUseuG0wVQ= 2 | github.com/andreyvit/diff v0.0.0-20170406064948-c7f18ee00883/go.mod h1:rCTlJbsFo29Kk6CurOXKm700vrz8f0KW0JNfpkRJY/8= 3 | github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= 4 | github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 5 | github.com/google/go-cmp v0.3.1 
h1:Xye71clBPdm5HgqGwUkwhbynsUJZhDbS20FvLhQ2izg= 6 | github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= 7 | github.com/k0kubun/colorstring v0.0.0-20150214042306-9440f1994b88 h1:uC1QfSlInpQF+M0ao65imhwqKnz3Q2z/d8PWZRMQvDM= 8 | github.com/k0kubun/colorstring v0.0.0-20150214042306-9440f1994b88/go.mod h1:3w7q1U84EfirKl04SVQ/s7nPm1ZPhiXd34z40TNz36k= 9 | github.com/k0kubun/pp v3.0.1+incompatible h1:3tqvf7QgUnZ5tXO6pNAZlrvHgl6DvifjDrd9g2S9Z40= 10 | github.com/k0kubun/pp v3.0.1+incompatible/go.mod h1:GWse8YhT0p8pT4ir3ZgBbfZild3tgzSScAn6HmfYukg= 11 | github.com/mattn/go-colorable v0.1.2 h1:/bC9yWikZXAL9uJdulbSfyVNIR3n3trXl+v8+1sx8mU= 12 | github.com/mattn/go-colorable v0.1.2/go.mod h1:U0ppj6V5qS13XJ6of8GYAs25YV2eR4EVcfRqFIhoBtE= 13 | github.com/mattn/go-isatty v0.0.8/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s= 14 | github.com/mattn/go-isatty v0.0.9 h1:d5US/mDsogSGW37IV293h//ZFaeajb69h+EHFsv2xGg= 15 | github.com/mattn/go-isatty v0.0.9/go.mod h1:YNRxwqDuOph6SZLI9vUUz6OYw3QyUt7WiY2yME+cCiQ= 16 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= 17 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 18 | github.com/sergi/go-diff v1.0.0 h1:Kpca3qRNrduNnOQeazBd0ysaKrUJiIuISHxogkT9RPQ= 19 | github.com/sergi/go-diff v1.0.0/go.mod h1:0CfEIISq7TuYL3j771MWULgwwjU+GofnZX9QAmXWZgo= 20 | github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= 21 | github.com/stretchr/testify v1.7.1 h1:5TQK59W5E3v0r2duFAb7P95B6hEeOyEnHRa8MjYSMTY= 22 | github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= 23 | golang.org/x/sys v0.0.0-20190222072716-a9d3bda3a223/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= 24 | golang.org/x/sys v0.0.0-20190813064441-fde4db37ae7a h1:aYOabOQFp6Vj6W1F80affTUvO9UxmJRx8K0gsfABByQ= 25 | golang.org/x/sys v0.0.0-20190813064441-fde4db37ae7a/go.mod 
h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= 26 | golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7 h1:9zdDQZ7Thm29KFXgAX/+yaf3eVbP7djjWp/dXAppNCc= 27 | golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= 28 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= 29 | gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo= 30 | gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 31 | -------------------------------------------------------------------------------- /sqlast/alter_column_action_gen.go: -------------------------------------------------------------------------------- 1 | package sqlast 2 | 3 | // Code generated by genmark. DO NOT EDIT. 4 | 5 | type AlterColumnAction interface { 6 | alterColumnActionMarker() 7 | Node 8 | } 9 | type alterColumnAction struct{} 10 | 11 | func (alterColumnAction) alterColumnActionMarker() {} 12 | -------------------------------------------------------------------------------- /sqlast/alter_table_action_gen.go: -------------------------------------------------------------------------------- 1 | package sqlast 2 | 3 | // Code generated by genmark. DO NOT EDIT. 
4 | 5 | type AlterTableAction interface { 6 | alterTableActionMarker() 7 | Node 8 | } 9 | type alterTableAction struct{} 10 | 11 | func (alterTableAction) alterTableActionMarker() {} 12 | -------------------------------------------------------------------------------- /sqlast/comment.go: -------------------------------------------------------------------------------- 1 | package sqlast 2 | 3 | import ( 4 | "io" 5 | 6 | "github.com/akito0107/xsqlparser/sqltoken" 7 | ) 8 | 9 | type CommentGroup struct { 10 | List []*Comment 11 | } 12 | 13 | func (c *CommentGroup) ToSQLString() string { 14 | return toSQLString(c) 15 | } 16 | 17 | func (c *CommentGroup) WriteTo(w io.Writer) (n int64, err error) { 18 | sw := newSQLWriter(w) 19 | for i, comment := range c.List { 20 | sw.JoinNewLine(i, comment) 21 | } 22 | return sw.End() 23 | } 24 | 25 | func (c *CommentGroup) Pos() sqltoken.Pos { 26 | return c.List[0].Pos() 27 | } 28 | 29 | func (c *CommentGroup) End() sqltoken.Pos { 30 | return c.List[len(c.List)-1].End() 31 | } 32 | 33 | type Comment struct { 34 | Text string 35 | From, To sqltoken.Pos 36 | } 37 | 38 | func (c *Comment) ToSQLString() string { 39 | return c.Text 40 | } 41 | 42 | func (c *Comment) WriteTo(w io.Writer) (int64, error) { 43 | return writeSingleString(w, c.Text) 44 | } 45 | 46 | func (c *Comment) Pos() sqltoken.Pos { 47 | return c.From 48 | } 49 | 50 | func (c *Comment) End() sqltoken.Pos { 51 | return c.To 52 | } 53 | -------------------------------------------------------------------------------- /sqlast/commentmap.go: -------------------------------------------------------------------------------- 1 | package sqlast 2 | 3 | import ( 4 | "log" 5 | 6 | "github.com/akito0107/xsqlparser/sqltoken" 7 | ) 8 | 9 | type CommentMap map[Node][]*CommentGroup 10 | 11 | func (cmap CommentMap) addComment(n Node, c *CommentGroup) { 12 | list := cmap[n] 13 | 14 | if len(list) == 0 { 15 | list = []*CommentGroup{c} 16 | } else { 17 | list = append(list, c) 18 | } 19 | 20 | 
cmap[n] = list 21 | } 22 | 23 | func nodeList(file *File) []Node { 24 | var list []Node 25 | 26 | Inspect(file, func(node Node) bool { 27 | switch node.(type) { 28 | case nil: 29 | return false 30 | default: 31 | list = append(list, node) 32 | return true 33 | } 34 | }) 35 | return list 36 | } 37 | 38 | type commentListReader struct { 39 | list []*CommentGroup 40 | comment *CommentGroup 41 | idx int 42 | pos, end sqltoken.Pos 43 | } 44 | 45 | func (r *commentListReader) eol() bool { 46 | return len(r.list) <= r.idx 47 | } 48 | 49 | func (r *commentListReader) next() { 50 | if !r.eol() { 51 | r.comment = r.list[r.idx] 52 | r.pos = r.comment.Pos() 53 | r.end = r.comment.End() 54 | r.idx++ 55 | } 56 | } 57 | 58 | type nodeStack []Node 59 | 60 | func (s *nodeStack) push(n Node) { 61 | s.pop(n.Pos()) 62 | *s = append(*s, n) 63 | } 64 | 65 | func (s *nodeStack) pop(pos sqltoken.Pos) (top Node) { 66 | i := len(*s) 67 | 68 | for i > 0 && sqltoken.ComparePos((*s)[i-1].End(), pos) != 1 { 69 | top = (*s)[i-1] 70 | i-- 71 | } 72 | *s = (*s)[0:i] 73 | 74 | return top 75 | } 76 | 77 | func NewCommentMap(file *File) CommentMap { 78 | if len(file.Comments) == 0 { 79 | return nil 80 | } 81 | 82 | cmap := make(CommentMap) 83 | 84 | nodes := nodeList(file) 85 | nodes = append(nodes, nil) 86 | 87 | tmp := make([]*CommentGroup, len(file.Comments)) 88 | copy(tmp, file.Comments) 89 | r := commentListReader{list: tmp} 90 | r.next() 91 | 92 | var ( 93 | p Node 94 | pend sqltoken.Pos 95 | pg Node 96 | pgend sqltoken.Pos 97 | stack nodeStack 98 | ) 99 | 100 | for _, q := range nodes { 101 | var qpos sqltoken.Pos 102 | if q != nil { 103 | qpos = q.Pos() 104 | } else { 105 | const infinity = 1 << 30 106 | qpos = sqltoken.NewPos(infinity, infinity) 107 | } 108 | 109 | for sqltoken.ComparePos(qpos, r.end) != -1 { 110 | if top := stack.pop(r.comment.Pos()); top != nil { 111 | pg = top 112 | pgend = pg.End() 113 | } 114 | 115 | var assoc Node 116 | switch { 117 | case pg != nil && 118 | 
(pgend.Line == r.pos.Line || 119 | pgend.Line+1 == r.pos.Line && r.end.Line+1 < qpos.Line): 120 | assoc = pg 121 | case p != nil && 122 | (pend.Line == r.pos.Line || 123 | pend.Line+1 == r.pos.Line && r.end.Line+1 < qpos.Line || 124 | q == nil): 125 | assoc = p 126 | default: 127 | if q == nil { 128 | log.Panic("internal error") 129 | } 130 | assoc = q 131 | } 132 | cmap.addComment(assoc, r.comment) 133 | 134 | if r.eol() { 135 | return cmap 136 | } 137 | r.next() 138 | } 139 | 140 | p = q 141 | pend = p.End() 142 | 143 | switch q.(type) { 144 | // Stmts 145 | case *QueryStmt, *InsertStmt, *UpdateStmt, *DeleteStmt, *CreateViewStmt, *CreateTableStmt, *AlterTableStmt, *DropTableStmt, *CreateIndexStmt, *DropIndexStmt, *ExplainStmt: 146 | stack.push(q) 147 | // table element 148 | case *ColumnDef, *TableConstraint: 149 | stack.push(q) 150 | } 151 | } 152 | 153 | return cmap 154 | } 155 | -------------------------------------------------------------------------------- /sqlast/insert_source_gen.go: -------------------------------------------------------------------------------- 1 | package sqlast 2 | 3 | // Code generated by genmark. DO NOT EDIT. 4 | 5 | type InsertSource interface { 6 | insertSourceMarker() 7 | Node 8 | } 9 | type insertSource struct{} 10 | 11 | func (insertSource) insertSourceMarker() {} 12 | -------------------------------------------------------------------------------- /sqlast/join_element_gen.go: -------------------------------------------------------------------------------- 1 | package sqlast 2 | 3 | // Code generated by genmark. DO NOT EDIT. 4 | 5 | type JoinElement interface { 6 | joinElementMarker() 7 | Node 8 | } 9 | type joinElement struct{} 10 | 11 | func (joinElement) joinElementMarker() {} 12 | -------------------------------------------------------------------------------- /sqlast/join_spec_gen.go: -------------------------------------------------------------------------------- 1 | package sqlast 2 | 3 | // Code generated by genmark. 
DO NOT EDIT. 4 | 5 | type JoinSpec interface { 6 | joinSpecMarker() 7 | Node 8 | } 9 | type joinSpec struct{} 10 | 11 | func (joinSpec) joinSpecMarker() {} 12 | -------------------------------------------------------------------------------- /sqlast/my_data_type_decoration_gen.go: -------------------------------------------------------------------------------- 1 | package sqlast 2 | 3 | // Code generated by genmark. DO NOT EDIT. 4 | 5 | type MyDataTypeDecoration interface { 6 | myDataTypeDecorationMarker() 7 | Node 8 | } 9 | type myDataTypeDecoration struct{} 10 | 11 | func (myDataTypeDecoration) myDataTypeDecorationMarker() {} 12 | -------------------------------------------------------------------------------- /sqlast/operator.go: -------------------------------------------------------------------------------- 1 | package sqlast 2 | 3 | import ( 4 | "io" 5 | 6 | "github.com/akito0107/xsqlparser/sqltoken" 7 | ) 8 | 9 | type Operator struct { 10 | Type OperatorType 11 | From, To sqltoken.Pos 12 | } 13 | 14 | func (o *Operator) Pos() sqltoken.Pos { 15 | return o.From 16 | } 17 | 18 | func (o *Operator) End() sqltoken.Pos { 19 | return o.To 20 | } 21 | 22 | type OperatorType int 23 | 24 | const ( 25 | Plus OperatorType = iota 26 | Minus 27 | Multiply 28 | Divide 29 | Modulus 30 | Gt 31 | Lt 32 | GtEq 33 | LtEq 34 | Eq 35 | NotEq 36 | And 37 | Or 38 | Not 39 | Like 40 | NotLike 41 | None 42 | ) 43 | 44 | func (o *Operator) ToSQLString() string { 45 | switch o.Type { 46 | case Plus: 47 | return "+" 48 | case Minus: 49 | return "-" 50 | case Multiply: 51 | return "*" 52 | case Divide: 53 | return "/" 54 | case Modulus: 55 | return "%" 56 | case Gt: 57 | return ">" 58 | case Lt: 59 | return "<" 60 | case GtEq: 61 | return ">=" 62 | case LtEq: 63 | return "<=" 64 | case Eq: 65 | return "=" 66 | case NotEq: 67 | return "!=" 68 | case And: 69 | return "AND" 70 | case Or: 71 | return "OR" 72 | case Not: 73 | return "NOT" 74 | case Like: 75 | return "LIKE" 76 | case NotLike: 
77 | return "NOT LIKE" 78 | } 79 | return "" 80 | } 81 | 82 | func (o *Operator) WriteTo(w io.Writer) (int64, error) { 83 | switch o.Type { 84 | case Plus: 85 | return writeSingleBytes(w, []byte("+")) 86 | case Minus: 87 | return writeSingleBytes(w, []byte("-")) 88 | case Multiply: 89 | return writeSingleBytes(w, []byte("*")) 90 | case Divide: 91 | return writeSingleBytes(w, []byte("/")) 92 | case Modulus: 93 | return writeSingleBytes(w, []byte("%")) 94 | case Gt: 95 | return writeSingleBytes(w, []byte(">")) 96 | case Lt: 97 | return writeSingleBytes(w, []byte("<")) 98 | case GtEq: 99 | return writeSingleBytes(w, []byte(">=")) 100 | case LtEq: 101 | return writeSingleBytes(w, []byte("<=")) 102 | case Eq: 103 | return writeSingleBytes(w, []byte("=")) 104 | case NotEq: 105 | return writeSingleBytes(w, []byte("!=")) 106 | case And: 107 | return writeSingleBytes(w, []byte("AND")) 108 | case Or: 109 | return writeSingleBytes(w, []byte("OR")) 110 | case Not: 111 | return writeSingleBytes(w, []byte("NOT")) 112 | case Like: 113 | return writeSingleBytes(w, []byte("LIKE")) 114 | case NotLike: 115 | return writeSingleBytes(w, []byte("NOT LIKE")) 116 | } 117 | return 0, nil 118 | } -------------------------------------------------------------------------------- /sqlast/query.go: -------------------------------------------------------------------------------- 1 | package sqlast 2 | 3 | import ( 4 | "io" 5 | "log" 6 | 7 | "github.com/akito0107/xsqlparser/sqltoken" 8 | ) 9 | 10 | // QueryStmt stmt 11 | type QueryStmt struct { 12 | stmt 13 | With sqltoken.Pos // first char position of WITH if CTEs is not blank 14 | CTEs []*CTE 15 | Body SQLSetExpr 16 | OrderBy []*OrderByExpr 17 | Limit *LimitExpr 18 | } 19 | 20 | func (q *QueryStmt) Pos() sqltoken.Pos { 21 | if len(q.CTEs) != 0 { 22 | return q.With 23 | } 24 | 25 | return q.Body.Pos() 26 | } 27 | 28 | func (q *QueryStmt) End() sqltoken.Pos { 29 | if q.Limit != nil { 30 | return q.Limit.End() 31 | } 32 | 33 | if len(q.OrderBy) != 
0 { 34 | return q.OrderBy[len(q.OrderBy)-1].End() 35 | } 36 | 37 | return q.Body.End() 38 | } 39 | 40 | func (q *QueryStmt) ToSQLString() string { 41 | return toSQLString(q) 42 | } 43 | 44 | func (q *QueryStmt) WriteTo(w io.Writer) (int64, error) { 45 | sw := newSQLWriter(w) 46 | if len(q.CTEs) != 0 { 47 | sw.Bytes([]byte("WITH ")) 48 | for i, cte := range q.CTEs { 49 | sw.JoinComma(i, cte) 50 | } 51 | sw.Space() 52 | } 53 | if sw.Err() == nil { 54 | sw.Direct(q.Body.WriteTo(w)) 55 | } 56 | if len(q.OrderBy) != 0 { 57 | sw.Bytes([]byte(" ORDER BY ")) 58 | for i, col := range q.OrderBy { 59 | sw.JoinComma(i, col) 60 | } 61 | } 62 | if q.Limit != nil { 63 | sw.Space().Node(q.Limit) 64 | } 65 | return sw.End() 66 | } 67 | 68 | // CTE 69 | type CTE struct { 70 | Alias *Ident 71 | Query *QueryStmt 72 | RParen sqltoken.Pos 73 | } 74 | 75 | func (c *CTE) Pos() sqltoken.Pos { 76 | return c.Alias.Pos() 77 | } 78 | 79 | func (c *CTE) End() sqltoken.Pos { 80 | return c.RParen 81 | } 82 | 83 | func (c *CTE) ToSQLString() string { 84 | return toSQLString(c) 85 | } 86 | 87 | func (c *CTE) WriteTo(w io.Writer) (int64, error) { 88 | return newSQLWriter(w). 89 | Node(c.Alias).As().LParen().Node(c.Query).RParen(). 
90 | End() 91 | } 92 | 93 | //go:generate genmark -t SQLSetExpr -e Node 94 | 95 | // Select 96 | type SelectExpr struct { 97 | sqlSetExpr 98 | Select *SQLSelect 99 | } 100 | 101 | func (s *SelectExpr) Pos() sqltoken.Pos { 102 | return s.Select.Pos() 103 | } 104 | 105 | func (s *SelectExpr) End() sqltoken.Pos { 106 | return s.Select.End() 107 | } 108 | 109 | func (s *SelectExpr) ToSQLString() string { 110 | return toSQLString(s) 111 | } 112 | 113 | func (s *SelectExpr) WriteTo(w io.Writer) (int64, error) { 114 | return s.Select.WriteTo(w) 115 | } 116 | 117 | // (QueryStmt) 118 | type QueryExpr struct { 119 | sqlSetExpr 120 | LParen, RParen sqltoken.Pos 121 | Query *QueryStmt 122 | } 123 | 124 | func (q *QueryExpr) Pos() sqltoken.Pos { 125 | return q.LParen 126 | } 127 | 128 | func (q *QueryExpr) End() sqltoken.Pos { 129 | return q.RParen 130 | } 131 | 132 | func (q *QueryExpr) ToSQLString() string { 133 | return toSQLString(q) 134 | } 135 | 136 | func (q *QueryExpr) WriteTo(w io.Writer) (int64, error) { 137 | return newSQLWriter(w).LParen().Node(q.Query).RParen().End() 138 | } 139 | 140 | type SetOperationExpr struct { 141 | sqlSetExpr 142 | Op SQLSetOperator 143 | All bool 144 | Left SQLSetExpr 145 | Right SQLSetExpr 146 | } 147 | 148 | func (s *SetOperationExpr) Pos() sqltoken.Pos { 149 | return s.Left.Pos() 150 | } 151 | 152 | func (s *SetOperationExpr) End() sqltoken.Pos { 153 | return s.Right.End() 154 | } 155 | 156 | func (s *SetOperationExpr) ToSQLString() string { 157 | return toSQLString(s) 158 | } 159 | 160 | func (s *SetOperationExpr) WriteTo(w io.Writer) (n int64, err error) { 161 | return newSQLWriter(w). 162 | Node(s.Left).Space().Node(s.Op).If(s.All, []byte(" ALL")).Space().Node(s.Right). 
163 | End() 164 | } 165 | 166 | //go:generate genmark -t SQLSetOperator -e Node 167 | 168 | type UnionOperator struct { 169 | sqlSetOperator 170 | From, To sqltoken.Pos 171 | } 172 | 173 | func (u *UnionOperator) Pos() sqltoken.Pos { 174 | return u.From 175 | } 176 | 177 | func (u *UnionOperator) End() sqltoken.Pos { 178 | return u.To 179 | } 180 | 181 | func (u *UnionOperator) ToSQLString() string { 182 | return "UNION" 183 | } 184 | 185 | func (u *UnionOperator) WriteTo(w io.Writer) (int64, error) { 186 | return writeSingleBytes(w, []byte("UNION")) 187 | } 188 | 189 | type ExceptOperator struct { 190 | sqlSetOperator 191 | From, To sqltoken.Pos 192 | } 193 | 194 | func (e *ExceptOperator) Pos() sqltoken.Pos { 195 | return e.From 196 | } 197 | 198 | func (e *ExceptOperator) End() sqltoken.Pos { 199 | return e.To 200 | } 201 | 202 | func (*ExceptOperator) ToSQLString() string { 203 | return "EXCEPT" 204 | } 205 | 206 | func (e *ExceptOperator) WriteTo(w io.Writer) (n int64, err error) { 207 | return writeSingleBytes(w, []byte("EXCEPT")) 208 | } 209 | 210 | type IntersectOperator struct { 211 | sqlSetOperator 212 | From, To sqltoken.Pos 213 | } 214 | 215 | func (i *IntersectOperator) Pos() sqltoken.Pos { 216 | return i.From 217 | } 218 | 219 | func (i *IntersectOperator) End() sqltoken.Pos { 220 | return i.To 221 | } 222 | 223 | func (IntersectOperator) ToSQLString() string { 224 | return "INTERSECT" 225 | } 226 | 227 | func (i *IntersectOperator) WriteTo(w io.Writer) (n int64, err error) { 228 | return writeSingleBytes(w, []byte("INTERSECT")) 229 | } 230 | 231 | type SQLSelect struct { 232 | sqlSetExpr 233 | Distinct bool 234 | Projection []SQLSelectItem 235 | FromClause []TableReference 236 | WhereClause Node 237 | GroupByClause []Node 238 | HavingClause Node 239 | Select sqltoken.Pos // first position of SELECT 240 | } 241 | 242 | func (s *SQLSelect) Pos() sqltoken.Pos { 243 | return s.Select 244 | } 245 | 246 | func (s *SQLSelect) End() sqltoken.Pos { 247 | if 
s.HavingClause != nil { 248 | return s.HavingClause.End() 249 | } 250 | 251 | if len(s.GroupByClause) != 0 { 252 | return s.GroupByClause[len(s.GroupByClause)-1].End() 253 | } 254 | 255 | if s.WhereClause != nil { 256 | return s.WhereClause.End() 257 | } 258 | 259 | if len(s.FromClause) != 0 { 260 | return s.FromClause[len(s.FromClause)-1].End() 261 | } 262 | 263 | return s.Projection[len(s.Projection)-1].End() 264 | } 265 | 266 | func (s *SQLSelect) ToSQLString() string { 267 | return toSQLString(s) 268 | } 269 | 270 | func (s *SQLSelect) WriteTo(w io.Writer) (int64, error) { 271 | sw := newSQLWriter(w) 272 | sw.Bytes(selectBytes) 273 | if s.Distinct { 274 | sw.Bytes([]byte("DISTINCT ")) 275 | } 276 | for i, projection := range s.Projection { 277 | sw.JoinComma(i, projection) 278 | } 279 | if len(s.FromClause) != 0 { 280 | sw.Bytes(fromBytes) 281 | for i, from := range s.FromClause { 282 | sw.JoinComma(i, from) 283 | } 284 | } 285 | if s.WhereClause != nil { 286 | sw.Bytes(whereBytes) 287 | if sw.Err() == nil { 288 | sw.Direct(s.WhereClause.WriteTo(w)) 289 | } 290 | } 291 | if len(s.GroupByClause) != 0 { 292 | sw.Bytes([]byte(" GROUP BY ")).Nodes(s.GroupByClause) 293 | } 294 | if s.HavingClause != nil { 295 | sw.Bytes([]byte(" HAVING ")).Node(s.HavingClause) 296 | } 297 | return sw.End() 298 | } 299 | 300 | //go:generate genmark -t TableReference -e Node 301 | 302 | //go:generate genmark -t TableFactor -e TableReference 303 | 304 | // Table 305 | type Table struct { 306 | tableFactor 307 | tableReference 308 | Name *ObjectName 309 | Alias *Ident 310 | Args []Node 311 | ArgsRParen sqltoken.Pos 312 | WithHints []Node 313 | WithHintsRParen sqltoken.Pos 314 | } 315 | 316 | func (t *Table) Pos() sqltoken.Pos { 317 | return t.Name.Pos() 318 | } 319 | 320 | func (t *Table) End() sqltoken.Pos { 321 | if len(t.WithHints) != 0 { 322 | return t.WithHintsRParen 323 | } 324 | 325 | if t.Alias != nil { 326 | return t.Alias.End() 327 | } 328 | 329 | if len(t.Args) != 0 { 330 | 
return t.ArgsRParen 331 | } 332 | 333 | return t.Name.End() 334 | } 335 | 336 | func (t *Table) ToSQLString() string { 337 | return toSQLString(t) 338 | } 339 | 340 | func (t *Table) WriteTo(w io.Writer) (int64, error) { 341 | sw := newSQLWriter(w) 342 | sw.Node(t.Name) 343 | if len(t.Args) != 0 { 344 | sw.LParen().Nodes(t.Args).RParen() 345 | } 346 | if t.Alias != nil { 347 | sw.As().Node(t.Alias) 348 | } 349 | if len(t.WithHints) != 0 { 350 | sw.Bytes([]byte(" WITH ")).LParen().Nodes(t.WithHints).RParen() 351 | } 352 | return sw.End() 353 | } 354 | 355 | type Derived struct { 356 | tableFactor 357 | tableReference 358 | Lateral bool 359 | LateralPos sqltoken.Pos // last position of LATERAL keyword if Lateral is true 360 | LParen sqltoken.Pos 361 | RParen sqltoken.Pos 362 | SubQuery *QueryStmt 363 | Alias *Ident 364 | } 365 | 366 | func (d *Derived) Pos() sqltoken.Pos { 367 | if d.Lateral { 368 | return d.LateralPos 369 | } 370 | return d.LParen 371 | } 372 | 373 | func (d *Derived) End() sqltoken.Pos { 374 | if d.Alias != nil { 375 | return d.Alias.End() 376 | } 377 | 378 | return d.RParen // no alias: the derived table ends at the ')' closing its subquery 379 | } 380 | 381 | func (d *Derived) ToSQLString() string { 382 | return toSQLString(d) 383 | } 384 | 385 | func (d *Derived) WriteTo(w io.Writer) (int64, error) { 386 | sw := newSQLWriter(w) 387 | sw.If(d.Lateral, []byte("LATERAL ")) 388 | sw.LParen().Node(d.SubQuery).RParen() 389 | if d.Alias != nil { 390 | sw.As().Node(d.Alias) 391 | } 392 | return sw.End() 393 | } 394 | 395 | //go:generate genmark -t SQLSelectItem -e Node 396 | 397 | type UnnamedSelectItem struct { 398 | sqlSelectItem 399 | Node Node 400 | } 401 | 402 | func (u *UnnamedSelectItem) Pos() sqltoken.Pos { 403 | return u.Node.Pos() 404 | } 405 | 406 | func (u *UnnamedSelectItem) End() sqltoken.Pos { 407 | return u.Node.End() 408 | } 409 | 410 | func (u *UnnamedSelectItem) ToSQLString() string { 411 | return toSQLString(u) 412 | } 413 | 414 | func (u *UnnamedSelectItem) WriteTo(w io.Writer) (int64, error) { 415
| return u.Node.WriteTo(w) 416 | } 417 | 418 | type AliasSelectItem struct { 419 | sqlSelectItem 420 | Expr Node 421 | Alias *Ident 422 | } 423 | 424 | func (a *AliasSelectItem) Pos() sqltoken.Pos { 425 | return a.Expr.Pos() 426 | } 427 | 428 | func (a *AliasSelectItem) End() sqltoken.Pos { 429 | return a.Alias.End() 430 | } 431 | 432 | func (a *AliasSelectItem) ToSQLString() string { 433 | return toSQLString(a) 434 | } 435 | 436 | func (a *AliasSelectItem) WriteTo(w io.Writer) (int64, error) { 437 | return newSQLWriter(w).Node(a.Expr).As().Node(a.Alias).End() 438 | } 439 | 440 | // schema.* 441 | type QualifiedWildcardSelectItem struct { 442 | sqlSelectItem 443 | Prefix *ObjectName 444 | } 445 | 446 | func (q *QualifiedWildcardSelectItem) Pos() sqltoken.Pos { 447 | return q.Prefix.Pos() 448 | } 449 | 450 | func (q *QualifiedWildcardSelectItem) End() sqltoken.Pos { 451 | return sqltoken.Pos{ 452 | Line: q.Prefix.End().Line, 453 | Col: q.Prefix.End().Col + 2, 454 | } 455 | } 456 | 457 | func (q *QualifiedWildcardSelectItem) ToSQLString() string { 458 | return toSQLString(q) 459 | } 460 | 461 | func (q *QualifiedWildcardSelectItem) WriteTo(w io.Writer) (int64, error) { 462 | return newSQLWriter(w).Node(q.Prefix).Bytes([]byte(".*")).End() 463 | } 464 | 465 | type WildcardSelectItem struct { 466 | sqlSelectItem 467 | From, To sqltoken.Pos 468 | } 469 | 470 | func (w *WildcardSelectItem) Pos() sqltoken.Pos { 471 | return w.From 472 | } 473 | 474 | func (w *WildcardSelectItem) End() sqltoken.Pos { 475 | return w.To 476 | } 477 | 478 | func (w *WildcardSelectItem) ToSQLString() string { 479 | return "*" 480 | } 481 | 482 | func (*WildcardSelectItem) WriteTo(w io.Writer) (int64, error) { 483 | return writeSingleBytes(w, []byte("*")) 484 | } 485 | 486 | type CrossJoin struct { 487 | tableReference 488 | Reference TableReference 489 | Factor TableFactor 490 | } 491 | 492 | func (c *CrossJoin) Pos() sqltoken.Pos { 493 | return c.Reference.Pos() 494 | } 495 | 496 | func (c 
*CrossJoin) End() sqltoken.Pos { 497 | return c.Factor.End() 498 | } 499 | 500 | func (c *CrossJoin) ToSQLString() string { 501 | return toSQLString(c) 502 | } 503 | 504 | func (c *CrossJoin) WriteTo(w io.Writer) (int64, error) { 505 | return newSQLWriter(w). 506 | Node(c.Reference).Bytes([]byte(" CROSS JOIN ")).Node(c.Factor). 507 | End() 508 | } 509 | 510 | //go:generate genmark -t JoinElement -e Node 511 | 512 | type TableJoinElement struct { 513 | joinElement 514 | Ref TableReference 515 | } 516 | 517 | func (t *TableJoinElement) Pos() sqltoken.Pos { 518 | return t.Ref.Pos() 519 | } 520 | 521 | func (t *TableJoinElement) End() sqltoken.Pos { 522 | return t.Ref.End() 523 | } 524 | 525 | func (t *TableJoinElement) ToSQLString() string { 526 | return toSQLString(t) 527 | } 528 | 529 | func (t *TableJoinElement) WriteTo(w io.Writer) (int64, error) { 530 | return t.Ref.WriteTo(w) 531 | } 532 | 533 | type PartitionedJoinTable struct { 534 | joinElement 535 | tableReference 536 | Factor TableFactor 537 | ColumnList []*Ident 538 | RParen sqltoken.Pos 539 | } 540 | 541 | func (p *PartitionedJoinTable) Pos() sqltoken.Pos { 542 | return p.Factor.Pos() 543 | } 544 | 545 | func (p *PartitionedJoinTable) End() sqltoken.Pos { 546 | return p.RParen 547 | } 548 | 549 | func (p *PartitionedJoinTable) ToSQLString() string { 550 | return toSQLString(p) 551 | } 552 | 553 | func (p *PartitionedJoinTable) WriteTo(w io.Writer) (int64, error) { 554 | return newSQLWriter(w). 555 | Node(p.Factor).Bytes([]byte(" PARTITION BY ")). 556 | LParen().Idents(p.ColumnList, []byte(", ")).RParen(). 
557 | End() 558 | } 559 | 560 | type QualifiedJoin struct { 561 | tableReference 562 | LeftElement *TableJoinElement 563 | Type *JoinType 564 | RightElement *TableJoinElement 565 | Spec JoinSpec 566 | } 567 | 568 | func (q *QualifiedJoin) Pos() sqltoken.Pos { 569 | return q.LeftElement.Pos() 570 | } 571 | 572 | func (q *QualifiedJoin) End() sqltoken.Pos { 573 | return q.Spec.End() 574 | } 575 | 576 | func (q *QualifiedJoin) ToSQLString() string { 577 | return toSQLString(q) 578 | } 579 | 580 | func (q *QualifiedJoin) WriteTo(w io.Writer) (int64, error) { 581 | return newSQLWriter(w). 582 | Node(q.LeftElement).Space(). 583 | Node(q.Type).Bytes([]byte("JOIN ")). 584 | Node(q.RightElement).Space().Node(q.Spec). 585 | End() 586 | } 587 | 588 | type NaturalJoin struct { 589 | tableReference 590 | LeftElement *TableJoinElement 591 | Type *JoinType 592 | RightElement *TableJoinElement 593 | } 594 | 595 | func (n *NaturalJoin) Pos() sqltoken.Pos { 596 | return n.LeftElement.Pos() 597 | } 598 | 599 | func (n *NaturalJoin) End() sqltoken.Pos { 600 | return n.RightElement.End() 601 | } 602 | 603 | func (n *NaturalJoin) ToSQLString() string { 604 | return toSQLString(n) 605 | } 606 | 607 | func (n *NaturalJoin) WriteTo(w io.Writer) (int64, error) { 608 | return newSQLWriter(w). 609 | Node(n.LeftElement). 610 | Bytes([]byte(" NATURAL ")).Node(n.Type).Bytes([]byte("JOIN ")). 611 | Node(n.RightElement). 
612 | End() 613 | } 614 | 615 | //go:generate genmark -t JoinSpec -e Node 616 | 617 | type NamedColumnsJoin struct { 618 | joinSpec 619 | ColumnList []*Ident 620 | Using sqltoken.Pos 621 | RParen sqltoken.Pos 622 | } 623 | 624 | func (n *NamedColumnsJoin) Pos() sqltoken.Pos { 625 | return n.Using 626 | } 627 | 628 | func (n *NamedColumnsJoin) End() sqltoken.Pos { 629 | return n.RParen 630 | } 631 | 632 | func (n *NamedColumnsJoin) ToSQLString() string { 633 | return toSQLString(n) 634 | } 635 | 636 | func (n *NamedColumnsJoin) WriteTo(w io.Writer) (int64, error) { 637 | return newSQLWriter(w). 638 | Bytes([]byte("USING ")). 639 | LParen().Idents(n.ColumnList, []byte(", ")).RParen(). 640 | End() 641 | } 642 | 643 | type JoinCondition struct { 644 | joinSpec 645 | SearchCondition Node 646 | On sqltoken.Pos 647 | } 648 | 649 | func (j *JoinCondition) Pos() sqltoken.Pos { 650 | return j.On 651 | } 652 | 653 | func (j *JoinCondition) End() sqltoken.Pos { 654 | return j.SearchCondition.End() 655 | } 656 | 657 | func (j *JoinCondition) ToSQLString() string { 658 | return toSQLString(j) 659 | } 660 | 661 | func (j *JoinCondition) WriteTo(w io.Writer) (int64, error) { 662 | return newSQLWriter(w).Bytes([]byte("ON ")).Node(j.SearchCondition).End() 663 | } 664 | 665 | type JoinType struct { 666 | Condition JoinTypeCondition 667 | From, To sqltoken.Pos 668 | } 669 | 670 | func (j *JoinType) Pos() sqltoken.Pos { 671 | return j.From 672 | } 673 | 674 | func (j *JoinType) End() sqltoken.Pos { 675 | return j.To 676 | } 677 | 678 | type JoinTypeCondition int 679 | 680 | const ( 681 | INNER JoinTypeCondition = iota 682 | LEFT 683 | RIGHT 684 | FULL 685 | LEFTOUTER 686 | RIGHTOUTER 687 | FULLOUTER 688 | IMPLICIT 689 | ) 690 | 691 | func (j *JoinType) ToSQLString() string { 692 | switch j.Condition { 693 | case INNER: 694 | return "INNER " 695 | case LEFT: 696 | return "LEFT " 697 | case RIGHT: 698 | return "RIGHT " 699 | case FULL: 700 | return "FULL " 701 | case LEFTOUTER: 702 | 
return "LEFT OUTER " 703 | case RIGHTOUTER: 704 | return "RIGHT OUTER " 705 | case FULLOUTER: 706 | return "FULL OUTER " 707 | case IMPLICIT: 708 | return "" 709 | default: 710 | log.Panicf("unknown join type %d", j.Condition) // format the enum value; %d cannot format the *JoinType receiver 711 | } 712 | return "" 713 | } 714 | 715 | func (j *JoinType) WriteTo(w io.Writer) (int64, error) { 716 | return writeSingleBytes(w, []byte(j.ToSQLString())) 717 | } 718 | 719 | // ORDER BY Expr [ASC | DESC] 720 | type OrderByExpr struct { 721 | Expr Node 722 | OrderingPos sqltoken.Pos // ASC / DESC keyword position if ASC != nil 723 | ASC *bool 724 | } 725 | 726 | func (o *OrderByExpr) Pos() sqltoken.Pos { 727 | return o.Expr.Pos() 728 | } 729 | 730 | func (o *OrderByExpr) End() sqltoken.Pos { 731 | if o.ASC != nil { 732 | return o.OrderingPos 733 | } 734 | 735 | return o.Expr.End() 736 | } 737 | 738 | func (o *OrderByExpr) ToSQLString() string { 739 | return toSQLString(o) 740 | } 741 | 742 | func (o *OrderByExpr) WriteTo(w io.Writer) (int64, error) { 743 | sw := newSQLWriter(w) 744 | sw.Node(o.Expr) 745 | if o.ASC != nil { 746 | if *o.ASC { 747 | sw.Bytes([]byte(" ASC")) 748 | } else { 749 | sw.Bytes([]byte(" DESC")) 750 | } 751 | } 752 | return sw.End() 753 | } 754 | 755 | // LIMIT [ALL | LimitValue ] [ OFFSET OffsetValue] 756 | type LimitExpr struct { 757 | All bool 758 | AllPos sqltoken.Pos // ALL keyword position if All is true 759 | Limit sqltoken.Pos // Limit keyword position 760 | LimitValue *LongValue 761 | OffsetValue *LongValue 762 | } 763 | 764 | func (l *LimitExpr) Pos() sqltoken.Pos { 765 | return l.Limit 766 | } 767 | 768 | func (l *LimitExpr) End() sqltoken.Pos { 769 | if l.All { 770 | return l.AllPos 771 | } 772 | 773 | if l.OffsetValue != nil { 774 | return l.OffsetValue.To 775 | } 776 | return l.LimitValue.To 777 | } 778 | 779 | func (l *LimitExpr) ToSQLString() string { 780 | return toSQLString(l) 781 | } 782 | 783 | func (l *LimitExpr) WriteTo(w io.Writer) (int64, error) { 784 | sw := newSQLWriter(w) 785 | sw.Bytes([]byte("LIMIT
")) 786 | if l.All { 787 | sw.Bytes([]byte("ALL")) 788 | } else { 789 | sw.Node(l.LimitValue) 790 | } 791 | if l.OffsetValue != nil { 792 | sw.Bytes([]byte(" OFFSET ")).Node(l.OffsetValue) 793 | } 794 | return sw.End() 795 | } 796 | -------------------------------------------------------------------------------- /sqlast/query_test.go: -------------------------------------------------------------------------------- 1 | package sqlast 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/andreyvit/diff" 7 | ) 8 | 9 | func TestSQLSelect_ToSQLString(t *testing.T) { 10 | cases := []struct { 11 | name string 12 | in *SQLSelect 13 | out string 14 | }{ 15 | { 16 | name: "simple select", 17 | in: &SQLSelect{ 18 | Projection: []SQLSelectItem{ 19 | &UnnamedSelectItem{ 20 | Node: NewIdent("test"), 21 | }, 22 | }, 23 | FromClause: []TableReference{ 24 | &Table{ 25 | Name: NewObjectName("test_table"), 26 | }, 27 | }, 28 | }, 29 | out: "SELECT test FROM test_table", 30 | }, 31 | { 32 | name: "join", 33 | in: &SQLSelect{ 34 | Projection: []SQLSelectItem{ 35 | &UnnamedSelectItem{ 36 | Node: NewObjectName("test"), 37 | }, 38 | }, 39 | FromClause: []TableReference{ 40 | &NaturalJoin{ 41 | LeftElement: &TableJoinElement{ 42 | Ref: &Table{ 43 | Name: NewObjectName("test_table"), 44 | }, 45 | }, 46 | Type: &JoinType{ 47 | Condition: IMPLICIT, 48 | }, 49 | RightElement: &TableJoinElement{ 50 | Ref: &Table{ 51 | Name: NewObjectName("test_table2"), 52 | }, 53 | }, 54 | }, 55 | }, 56 | }, 57 | out: "SELECT test FROM test_table NATURAL JOIN test_table2", 58 | }, 59 | { 60 | name: "where", 61 | in: &SQLSelect{ 62 | Projection: []SQLSelectItem{ 63 | &UnnamedSelectItem{ 64 | Node: NewIdent("test"), 65 | }, 66 | }, 67 | FromClause: []TableReference{ 68 | &Table{ 69 | Name: NewObjectName("test_table"), 70 | }, 71 | }, 72 | WhereClause: &BinaryExpr{ 73 | Left: &CompoundIdent{ 74 | Idents: []*Ident{NewIdent("test_table"), NewIdent("column1")}, 75 | }, 76 | Op: &Operator{ 77 | Type: Eq, 78 | }, 79 | 
Right: NewSingleQuotedString("test"), 80 | }, 81 | }, 82 | out: "SELECT test FROM test_table WHERE test_table.column1 = 'test'", 83 | }, 84 | { 85 | name: "count and join", 86 | in: &SQLSelect{ 87 | Projection: []SQLSelectItem{ 88 | &AliasSelectItem{ 89 | Expr: &Function{ 90 | Name: NewObjectName("COUNT"), 91 | Args: []Node{&CompoundIdent{ 92 | Idents: []*Ident{NewIdent("t1"), NewIdent("id")}, 93 | }}, 94 | }, 95 | Alias: NewIdent("c"), 96 | }, 97 | }, 98 | FromClause: []TableReference{ 99 | &QualifiedJoin{ 100 | LeftElement: &TableJoinElement{ 101 | Ref: &Table{ 102 | Name: NewObjectName("test_table"), 103 | Alias: NewIdent("t1"), 104 | }, 105 | }, 106 | Type: &JoinType{ 107 | Condition: LEFT, 108 | }, 109 | RightElement: &TableJoinElement{ 110 | Ref: &Table{ 111 | Name: NewObjectName("test_table2"), 112 | Alias: NewIdent("t2"), 113 | }, 114 | }, 115 | Spec: &JoinCondition{ 116 | SearchCondition: &BinaryExpr{ 117 | Left: &CompoundIdent{ 118 | Idents: []*Ident{NewIdent("t1"), NewIdent("id")}, 119 | }, 120 | Op: &Operator{ 121 | Type: Eq, 122 | }, 123 | Right: &CompoundIdent{ 124 | Idents: []*Ident{NewIdent("t2"), NewIdent("test_table_id")}, 125 | }, 126 | }, 127 | }, 128 | }, 129 | }, 130 | }, 131 | out: "SELECT COUNT(t1.id) AS c FROM test_table AS t1 LEFT JOIN test_table2 AS t2 ON t1.id = t2.test_table_id", 132 | }, 133 | { 134 | name: "group by", 135 | in: &SQLSelect{ 136 | Projection: []SQLSelectItem{ 137 | &UnnamedSelectItem{ 138 | Node: &Function{ 139 | Name: NewObjectName("COUNT"), 140 | Args: []Node{NewIdent("customer_id")}, 141 | }, 142 | }, 143 | &QualifiedWildcardSelectItem{ 144 | Prefix: NewObjectName("country"), 145 | }, 146 | }, 147 | FromClause: []TableReference{ 148 | &Table{ 149 | Name: NewObjectName("customers"), 150 | }, 151 | }, 152 | GroupByClause: []Node{NewIdent("country")}, 153 | }, 154 | out: "SELECT COUNT(customer_id), country.* FROM customers GROUP BY country", 155 | }, 156 | { 157 | name: "having", 158 | in: &SQLSelect{ 159 | Projection: 
[]SQLSelectItem{ 160 | &UnnamedSelectItem{ 161 | Node: &Function{ 162 | Name: NewObjectName("COUNT"), 163 | Args: []Node{NewIdent("customer_id")}, 164 | }, 165 | }, 166 | &UnnamedSelectItem{ 167 | Node: NewIdent("country"), 168 | }, 169 | }, 170 | FromClause: []TableReference{ 171 | &Table{ 172 | Name: NewObjectName("customers"), 173 | }, 174 | }, 175 | GroupByClause: []Node{NewIdent("country")}, 176 | HavingClause: &BinaryExpr{ 177 | Op: &Operator{Type: Gt}, 178 | Left: &Function{ 179 | Name: NewObjectName("COUNT"), 180 | Args: []Node{NewIdent("customer_id")}, 181 | }, 182 | Right: NewLongValue(3), 183 | }, 184 | }, 185 | out: "SELECT COUNT(customer_id), country FROM customers GROUP BY country HAVING COUNT(customer_id) > 3", 186 | }, 187 | } 188 | 189 | for _, c := range cases { 190 | t.Run(c.name, func(t *testing.T) { 191 | act := c.in.ToSQLString() 192 | 193 | if act != c.out { 194 | t.Errorf("must be \n%s but \n%s \n diff: %s", c.out, act, diff.CharacterDiff(c.out, act)) 195 | } 196 | }) 197 | } 198 | 199 | } 200 | 201 | func TestSQLQuery_ToSQLString(t *testing.T) { 202 | cases := []struct { 203 | name string 204 | in *QueryStmt 205 | out string 206 | }{ 207 | { 208 | // from https://www.postgresql.jp/document/9.3/html/queries-with.html 209 | name: "with cte", 210 | in: &QueryStmt{ 211 | CTEs: []*CTE{ 212 | { 213 | Alias: NewIdent("regional_sales"), 214 | Query: &QueryStmt{ 215 | Body: &SQLSelect{ 216 | Projection: []SQLSelectItem{ 217 | &UnnamedSelectItem{Node: NewIdent("region")}, 218 | &AliasSelectItem{ 219 | Alias: NewIdent("total_sales"), 220 | Expr: &Function{ 221 | Name: NewObjectName("SUM"), 222 | Args: []Node{NewIdent("amount")}, 223 | }, 224 | }, 225 | }, 226 | FromClause: []TableReference{ 227 | &Table{ 228 | Name: NewObjectName("orders"), 229 | }, 230 | }, 231 | GroupByClause: []Node{NewIdent("region")}, 232 | }, 233 | }, 234 | }, 235 | }, 236 | Body: &SQLSelect{ 237 | Projection: []SQLSelectItem{ 238 | &UnnamedSelectItem{Node: NewIdent("product")}, 
239 | &AliasSelectItem{ 240 | Alias: NewIdent("product_units"), 241 | Expr: &Function{ 242 | Name: NewObjectName("SUM"), 243 | Args: []Node{NewIdent("quantity")}, 244 | }, 245 | }, 246 | }, 247 | FromClause: []TableReference{ 248 | &Table{ 249 | Name: NewObjectName("orders"), 250 | }, 251 | }, 252 | WhereClause: &InSubQuery{ 253 | Expr: NewIdent("region"), 254 | SubQuery: &QueryStmt{ 255 | Body: &SQLSelect{ 256 | Projection: []SQLSelectItem{ 257 | &UnnamedSelectItem{Node: NewIdent("region")}, 258 | }, 259 | FromClause: []TableReference{ 260 | &Table{ 261 | Name: NewObjectName("top_regions"), 262 | }, 263 | }, 264 | }, 265 | }, 266 | }, 267 | GroupByClause: []Node{NewIdent("region"), NewIdent("product")}, 268 | }, 269 | }, 270 | out: "WITH regional_sales AS (" + 271 | "SELECT region, SUM(amount) AS total_sales " + 272 | "FROM orders GROUP BY region) " + 273 | "SELECT product, SUM(quantity) AS product_units " + 274 | "FROM orders " + 275 | "WHERE region IN (SELECT region FROM top_regions) " + 276 | "GROUP BY region, product", 277 | }, 278 | { 279 | name: "order by and limit", 280 | in: &QueryStmt{ 281 | Body: &SQLSelect{ 282 | Projection: []SQLSelectItem{ 283 | &UnnamedSelectItem{Node: NewIdent("product")}, 284 | &AliasSelectItem{ 285 | Alias: NewIdent("product_units"), 286 | Expr: &Function{ 287 | Name: NewObjectName("SUM"), 288 | Args: []Node{NewIdent("quantity")}, 289 | }, 290 | }, 291 | }, 292 | FromClause: []TableReference{ 293 | &Table{ 294 | Name: NewObjectName("orders"), 295 | }, 296 | }, 297 | WhereClause: &InSubQuery{ 298 | Expr: NewIdent("region"), 299 | SubQuery: &QueryStmt{ 300 | Body: &SQLSelect{ 301 | Projection: []SQLSelectItem{ 302 | &UnnamedSelectItem{Node: NewIdent("region")}, 303 | }, 304 | FromClause: []TableReference{ 305 | &Table{ 306 | Name: NewObjectName("top_regions"), 307 | }, 308 | }, 309 | }, 310 | }, 311 | }, 312 | }, 313 | OrderBy: []*OrderByExpr{ 314 | {Expr: NewIdent("product_units")}, 315 | }, 316 | Limit: &LimitExpr{LimitValue: 
NewLongValue(100)}, 317 | }, 318 | out: "SELECT product, SUM(quantity) AS product_units " + 319 | "FROM orders " + 320 | "WHERE region IN (SELECT region FROM top_regions) " + 321 | "ORDER BY product_units LIMIT 100", 322 | }, 323 | { 324 | name: "exists", 325 | in: &QueryStmt{ 326 | Body: &SQLSelect{ 327 | Projection: []SQLSelectItem{ 328 | &UnnamedSelectItem{ 329 | Node: &Wildcard{}, 330 | }, 331 | }, 332 | FromClause: []TableReference{ 333 | &Table{ 334 | Name: NewObjectName("user"), 335 | }, 336 | }, 337 | WhereClause: &Exists{ 338 | Negated: true, 339 | Query: &QueryStmt{ 340 | Body: &SQLSelect{ 341 | Projection: []SQLSelectItem{ 342 | &UnnamedSelectItem{ 343 | Node: &Wildcard{}, 344 | }, 345 | }, 346 | FromClause: []TableReference{ 347 | &Table{ 348 | Name: NewObjectName("user_sub"), 349 | }, 350 | }, 351 | WhereClause: &BinaryExpr{ 352 | Op: &Operator{Type: And}, 353 | Left: &BinaryExpr{ 354 | Op: &Operator{Type: Eq}, 355 | Left: &CompoundIdent{ 356 | Idents: []*Ident{ 357 | NewIdent("user"), 358 | NewIdent("id"), 359 | }, 360 | }, 361 | Right: &CompoundIdent{ 362 | Idents: []*Ident{ 363 | NewIdent("user_sub"), 364 | NewIdent("id"), 365 | }, 366 | }, 367 | }, 368 | Right: &BinaryExpr{ 369 | Op: &Operator{Type: Eq}, 370 | Left: &CompoundIdent{ 371 | Idents: []*Ident{ 372 | NewIdent("user_sub"), 373 | NewIdent("job"), 374 | }, 375 | }, 376 | Right: NewSingleQuotedString("job"), 377 | }, 378 | }, 379 | }, 380 | }, 381 | }, 382 | }, 383 | }, 384 | out: "SELECT * FROM user WHERE NOT EXISTS (" + 385 | "SELECT * FROM user_sub WHERE user.id = user_sub.id AND user_sub.job = 'job'" + 386 | ")", 387 | }, 388 | { 389 | name: "between / case", 390 | in: &QueryStmt{ 391 | Body: &SQLSelect{ 392 | Projection: []SQLSelectItem{ 393 | &AliasSelectItem{ 394 | Expr: &CaseExpr{ 395 | Conditions: []Node{ 396 | &BinaryExpr{ 397 | Op: &Operator{Type: Eq}, 398 | Left: NewIdent("expr1"), 399 | Right: NewSingleQuotedString("1"), 400 | }, 401 | &BinaryExpr{ 402 | Op: &Operator{Type: Eq}, 
403 | Left: NewIdent("expr2"), 404 | Right: NewSingleQuotedString("2"), 405 | }, 406 | }, 407 | Results: []Node{ 408 | NewSingleQuotedString("test1"), 409 | NewSingleQuotedString("test2"), 410 | }, 411 | ElseResult: NewSingleQuotedString("other"), 412 | }, 413 | Alias: NewIdent("alias"), 414 | }, 415 | }, 416 | FromClause: []TableReference{ 417 | &Table{ 418 | Name: NewObjectName("user"), 419 | }, 420 | }, 421 | WhereClause: &Between{ 422 | Expr: NewIdent("id"), 423 | High: NewLongValue(2), 424 | Low: NewLongValue(1), 425 | }, 426 | }, 427 | }, 428 | out: "SELECT CASE WHEN expr1 = '1' THEN 'test1' WHEN expr2 = '2' THEN 'test2' ELSE 'other' END AS alias " + 429 | "FROM user WHERE id BETWEEN 1 AND 2", 430 | }, 431 | } 432 | 433 | for _, c := range cases { 434 | t.Run(c.name, func(t *testing.T) { 435 | act := c.in.ToSQLString() 436 | 437 | if act != c.out { 438 | t.Errorf("must be \n%s but \n%s \n diff: %s", c.out, act, diff.CharacterDiff(c.out, act)) 439 | } 440 | }) 441 | } 442 | 443 | } 444 | 445 | func BenchmarkSQLQuery_ToSQLString(b *testing.B) { 446 | cases := []struct { 447 | name string 448 | in *QueryStmt 449 | }{ 450 | { 451 | // from https://www.postgresql.jp/document/9.3/html/queries-with.html 452 | name: "with cte", 453 | in: &QueryStmt{ 454 | CTEs: []*CTE{ 455 | { 456 | Alias: NewIdent("regional_sales"), 457 | Query: &QueryStmt{ 458 | Body: &SQLSelect{ 459 | Projection: []SQLSelectItem{ 460 | &UnnamedSelectItem{Node: NewIdent("region")}, 461 | &AliasSelectItem{ 462 | Alias: NewIdent("total_sales"), 463 | Expr: &Function{ 464 | Name: NewObjectName("SUM"), 465 | Args: []Node{NewIdent("amount")}, 466 | }, 467 | }, 468 | }, 469 | FromClause: []TableReference{ 470 | &Table{ 471 | Name: NewObjectName("orders"), 472 | }, 473 | }, 474 | GroupByClause: []Node{NewIdent("region")}, 475 | }, 476 | }, 477 | }, 478 | }, 479 | Body: &SQLSelect{ 480 | Projection: []SQLSelectItem{ 481 | &UnnamedSelectItem{Node: NewIdent("product")}, 482 | &AliasSelectItem{ 483 | Alias: 
NewIdent("product_units"), 484 | Expr: &Function{ 485 | Name: NewObjectName("SUM"), 486 | Args: []Node{NewIdent("quantity")}, 487 | }, 488 | }, 489 | }, 490 | FromClause: []TableReference{ 491 | &Table{ 492 | Name: NewObjectName("orders"), 493 | }, 494 | }, 495 | WhereClause: &InSubQuery{ 496 | Expr: NewIdent("region"), 497 | SubQuery: &QueryStmt{ 498 | Body: &SQLSelect{ 499 | Projection: []SQLSelectItem{ 500 | &UnnamedSelectItem{Node: NewIdent("region")}, 501 | }, 502 | FromClause: []TableReference{ 503 | &Table{ 504 | Name: NewObjectName("top_regions"), 505 | }, 506 | }, 507 | }, 508 | }, 509 | }, 510 | GroupByClause: []Node{NewIdent("region"), NewIdent("product")}, 511 | }, 512 | }, 513 | }, 514 | { 515 | name: "order by and limit", 516 | in: &QueryStmt{ 517 | Body: &SQLSelect{ 518 | Projection: []SQLSelectItem{ 519 | &UnnamedSelectItem{Node: NewIdent("product")}, 520 | &AliasSelectItem{ 521 | Alias: NewIdent("product_units"), 522 | Expr: &Function{ 523 | Name: NewObjectName("SUM"), 524 | Args: []Node{NewIdent("quantity")}, 525 | }, 526 | }, 527 | }, 528 | FromClause: []TableReference{ 529 | &Table{ 530 | Name: NewObjectName("orders"), 531 | }, 532 | }, 533 | WhereClause: &InSubQuery{ 534 | Expr: NewIdent("region"), 535 | SubQuery: &QueryStmt{ 536 | Body: &SQLSelect{ 537 | Projection: []SQLSelectItem{ 538 | &UnnamedSelectItem{Node: NewIdent("region")}, 539 | }, 540 | FromClause: []TableReference{ 541 | &Table{ 542 | Name: NewObjectName("top_regions"), 543 | }, 544 | }, 545 | }, 546 | }, 547 | }, 548 | }, 549 | OrderBy: []*OrderByExpr{ 550 | {Expr: NewIdent("product_units")}, 551 | }, 552 | Limit: &LimitExpr{LimitValue: NewLongValue(100)}, 553 | }, 554 | }, 555 | { 556 | name: "exists", 557 | in: &QueryStmt{ 558 | Body: &SQLSelect{ 559 | Projection: []SQLSelectItem{ 560 | &UnnamedSelectItem{ 561 | Node: &Wildcard{}, 562 | }, 563 | }, 564 | FromClause: []TableReference{ 565 | &Table{ 566 | Name: NewObjectName("user"), 567 | }, 568 | }, 569 | WhereClause: &Exists{ 
570 | Negated: true, 571 | Query: &QueryStmt{ 572 | Body: &SQLSelect{ 573 | Projection: []SQLSelectItem{ 574 | &UnnamedSelectItem{ 575 | Node: &Wildcard{}, 576 | }, 577 | }, 578 | FromClause: []TableReference{ 579 | &Table{ 580 | Name: NewObjectName("user_sub"), 581 | }, 582 | }, 583 | WhereClause: &BinaryExpr{ 584 | Op: &Operator{Type: And}, 585 | Left: &BinaryExpr{ 586 | Op: &Operator{Type: Eq}, 587 | Left: &CompoundIdent{ 588 | Idents: []*Ident{ 589 | NewIdent("user"), 590 | NewIdent("id"), 591 | }, 592 | }, 593 | Right: &CompoundIdent{ 594 | Idents: []*Ident{ 595 | NewIdent("user_sub"), 596 | NewIdent("id"), 597 | }, 598 | }, 599 | }, 600 | Right: &BinaryExpr{ 601 | Op: &Operator{Type: Eq}, 602 | Left: &CompoundIdent{ 603 | Idents: []*Ident{ 604 | NewIdent("user_sub"), 605 | NewIdent("job"), 606 | }, 607 | }, 608 | Right: NewSingleQuotedString("job"), 609 | }, 610 | }, 611 | }, 612 | }, 613 | }, 614 | }, 615 | }, 616 | }, 617 | { 618 | name: "between / case", 619 | in: &QueryStmt{ 620 | Body: &SQLSelect{ 621 | Projection: []SQLSelectItem{ 622 | &AliasSelectItem{ 623 | Expr: &CaseExpr{ 624 | Conditions: []Node{ 625 | &BinaryExpr{ 626 | Op: &Operator{Type: Eq}, 627 | Left: NewIdent("expr1"), 628 | Right: NewSingleQuotedString("1"), 629 | }, 630 | &BinaryExpr{ 631 | Op: &Operator{Type: Eq}, 632 | Left: NewIdent("expr2"), 633 | Right: NewSingleQuotedString("2"), 634 | }, 635 | }, 636 | Results: []Node{ 637 | NewSingleQuotedString("test1"), 638 | NewSingleQuotedString("test2"), 639 | }, 640 | ElseResult: NewSingleQuotedString("other"), 641 | }, 642 | Alias: NewIdent("alias"), 643 | }, 644 | }, 645 | FromClause: []TableReference{ 646 | &Table{ 647 | Name: NewObjectName("user"), 648 | }, 649 | }, 650 | WhereClause: &Between{ 651 | Expr: NewIdent("id"), 652 | High: NewLongValue(2), 653 | Low: NewLongValue(1), 654 | }, 655 | }, 656 | }, 657 | }, 658 | } 659 | 660 | for _, c := range cases { 661 | b.Run(c.name, func(b *testing.B) { 662 | b.ReportAllocs() 663 | for i := 
0; i < b.N; i++ { 664 | c.in.ToSQLString() 665 | } 666 | }) 667 | } 668 | 669 | } 670 | -------------------------------------------------------------------------------- /sqlast/sql_select_item_gen.go: -------------------------------------------------------------------------------- 1 | package sqlast 2 | 3 | // Code generated by genmark. DO NOT EDIT. 4 | 5 | type SQLSelectItem interface { 6 | sqlSelectItemMarker() 7 | Node 8 | } 9 | type sqlSelectItem struct{} 10 | 11 | func (sqlSelectItem) sqlSelectItemMarker() {} 12 | -------------------------------------------------------------------------------- /sqlast/sql_set_expr_gen.go: -------------------------------------------------------------------------------- 1 | package sqlast 2 | 3 | // Code generated by genmark. DO NOT EDIT. 4 | 5 | type SQLSetExpr interface { 6 | sqlSetExprMarker() 7 | Node 8 | } 9 | type sqlSetExpr struct{} 10 | 11 | func (sqlSetExpr) sqlSetExprMarker() {} 12 | -------------------------------------------------------------------------------- /sqlast/sql_set_operator_gen.go: -------------------------------------------------------------------------------- 1 | package sqlast 2 | 3 | // Code generated by genmark. DO NOT EDIT. 4 | 5 | type SQLSetOperator interface { 6 | sqlSetOperatorMarker() 7 | Node 8 | } 9 | type sqlSetOperator struct{} 10 | 11 | func (sqlSetOperator) sqlSetOperatorMarker() {} 12 | -------------------------------------------------------------------------------- /sqlast/sql_window_frame_bound_gen.go: -------------------------------------------------------------------------------- 1 | package sqlast 2 | 3 | // Code generated by genmark. DO NOT EDIT. 
4 | 5 | type SQLWindowFrameBound interface { 6 | sqlWindowFrameBoundMarker() 7 | Node 8 | } 9 | type sqlWindowFrameBound struct{} 10 | 11 | func (sqlWindowFrameBound) sqlWindowFrameBoundMarker() {} 12 | -------------------------------------------------------------------------------- /sqlast/stmt_gen.go: -------------------------------------------------------------------------------- 1 | package sqlast 2 | 3 | // Code generated by genmark. DO NOT EDIT. 4 | 5 | type Stmt interface { 6 | stmtMarker() 7 | Node 8 | } 9 | type stmt struct{} 10 | 11 | func (stmt) stmtMarker() {} 12 | -------------------------------------------------------------------------------- /sqlast/stmt_test.go: -------------------------------------------------------------------------------- 1 | package sqlast 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/andreyvit/diff" 7 | ) 8 | 9 | // from https://www.w3schools.com/sql/sql_insert.asp 10 | 11 | func TestSQLInsert_ToSQLString(t *testing.T) { 12 | cases := []struct { 13 | name string 14 | in *InsertStmt 15 | out string 16 | }{ 17 | { 18 | name: "simple case", 19 | in: &InsertStmt{ 20 | TableName: NewObjectName("customers"), 21 | Columns: []*Ident{ 22 | NewIdent("customer_name"), 23 | NewIdent("contract_name"), 24 | }, 25 | Source: &ConstructorSource{ 26 | Rows: []*RowValueExpr{ 27 | { 28 | Values: []Node{ 29 | NewSingleQuotedString("Cardinal"), 30 | NewSingleQuotedString("Tom B. Erichsen"), 31 | }, 32 | }, 33 | }, 34 | }, 35 | }, 36 | out: "INSERT INTO customers (customer_name, contract_name) VALUES ('Cardinal', 'Tom B. Erichsen')", 37 | }, 38 | { 39 | name: "multi row case", 40 | in: &InsertStmt{ 41 | TableName: NewObjectName("customers"), 42 | Columns: []*Ident{ 43 | NewIdent("customer_name"), 44 | NewIdent("contract_name"), 45 | }, 46 | Source: &ConstructorSource{ 47 | Rows: []*RowValueExpr{ 48 | { 49 | Values: []Node{ 50 | NewSingleQuotedString("Cardinal"), 51 | NewSingleQuotedString("Tom B. 
Erichsen"), 52 | }, 53 | }, 54 | { 55 | Values: []Node{ 56 | NewSingleQuotedString("Cardinal2"), 57 | NewSingleQuotedString("Tom B. Erichsen2"), 58 | }, 59 | }, 60 | { 61 | Values: []Node{ 62 | NewSingleQuotedString("Cardinal3"), 63 | NewSingleQuotedString("Tom B. Erichsen3"), 64 | }, 65 | }, 66 | }, 67 | }, 68 | }, 69 | out: "INSERT INTO customers (customer_name, contract_name) VALUES ('Cardinal', 'Tom B. Erichsen'), ('Cardinal2', 'Tom B. Erichsen2'), ('Cardinal3', 'Tom B. Erichsen3')", 70 | }, 71 | { 72 | name: "insert sub query", 73 | in: &InsertStmt{ 74 | TableName: NewObjectName("customers"), 75 | Columns: []*Ident{ 76 | NewIdent("customer_name"), 77 | NewIdent("contract_name"), 78 | }, 79 | Source: &SubQuerySource{ 80 | SubQuery: &QueryStmt{ 81 | Body: &SelectExpr{ 82 | Select: &SQLSelect{ 83 | Projection: []SQLSelectItem{ 84 | &WildcardSelectItem{}, 85 | }, 86 | FromClause: []TableReference{ 87 | &Table{ 88 | Name: NewObjectName("customers2"), 89 | }, 90 | }, 91 | }, 92 | }, 93 | }, 94 | }, 95 | }, 96 | out: "INSERT INTO customers (customer_name, contract_name) SELECT * FROM customers2", 97 | }, 98 | } 99 | for _, c := range cases { 100 | t.Run(c.name, func(t *testing.T) { 101 | act := c.in.ToSQLString() 102 | 103 | if act != c.out { 104 | t.Errorf("must be \n%s but \n%s \n diff: %s", c.out, act, diff.CharacterDiff(c.out, act)) 105 | } 106 | }) 107 | } 108 | } 109 | 110 | func TestSQLUpdate_ToSQLString(t *testing.T) { 111 | cases := []struct { 112 | name string 113 | in *UpdateStmt 114 | out string 115 | }{ 116 | { 117 | name: "simple case", 118 | in: &UpdateStmt{ 119 | TableName: NewObjectName("customers"), 120 | Assignments: []*Assignment{ 121 | { 122 | ID: NewIdent("contract_name"), 123 | Value: NewSingleQuotedString("Alfred Schmidt"), 124 | }, 125 | { 126 | ID: NewIdent("city"), 127 | Value: NewSingleQuotedString("Frankfurt"), 128 | }, 129 | }, 130 | Selection: &BinaryExpr{ 131 | Op: &Operator{Type: Eq}, 132 | Left: NewIdent("customer_id"), 133 | Right: 
NewLongValue(1), 134 | }, 135 | }, 136 | out: "UPDATE customers SET contract_name = 'Alfred Schmidt', city = 'Frankfurt' WHERE customer_id = 1", 137 | }, 138 | } 139 | for _, c := range cases { 140 | t.Run(c.name, func(t *testing.T) { 141 | act := c.in.ToSQLString() 142 | 143 | if act != c.out { 144 | t.Errorf("must be \n%s but \n%s \n diff: %s", c.out, act, diff.CharacterDiff(c.out, act)) 145 | } 146 | }) 147 | } 148 | } 149 | 150 | func TestSQLDelete_ToSQLString(t *testing.T) { 151 | cases := []struct { 152 | name string 153 | in *DeleteStmt 154 | out string 155 | }{ 156 | { 157 | name: "simple case", 158 | in: &DeleteStmt{ 159 | TableName: NewObjectName("customers"), 160 | Selection: &BinaryExpr{ 161 | Op: &Operator{Type: Eq}, 162 | Left: NewIdent("customer_id"), 163 | Right: NewLongValue(1), 164 | }, 165 | }, 166 | out: "DELETE FROM customers WHERE customer_id = 1", 167 | }, 168 | } 169 | for _, c := range cases { 170 | t.Run(c.name, func(t *testing.T) { 171 | act := c.in.ToSQLString() 172 | 173 | if act != c.out { 174 | t.Errorf("must be \n%s but \n%s \n diff: %s", c.out, act, diff.CharacterDiff(c.out, act)) 175 | } 176 | }) 177 | } 178 | } 179 | 180 | func TestSQLCreateView_ToSQLString(t *testing.T) { 181 | cases := []struct { 182 | name string 183 | in *CreateViewStmt 184 | out string 185 | }{ 186 | { 187 | name: "simple case", 188 | in: &CreateViewStmt{ 189 | Name: NewObjectName("customers_view"), 190 | Query: &QueryStmt{ 191 | Body: &SelectExpr{ 192 | Select: &SQLSelect{ 193 | Projection: []SQLSelectItem{ 194 | &UnnamedSelectItem{ 195 | Node: NewIdent("customer_name"), 196 | }, 197 | &UnnamedSelectItem{ 198 | Node: NewIdent("contract_name"), 199 | }, 200 | }, 201 | FromClause: []TableReference{ 202 | &Table{ 203 | Name: &ObjectName{ 204 | Idents: []*Ident{ 205 | NewIdent("customers"), 206 | }, 207 | }, 208 | }, 209 | }, 210 | WhereClause: &BinaryExpr{ 211 | Op: &Operator{Type: Eq}, 212 | Left: NewIdent("country"), 213 | Right: 
NewSingleQuotedString("Brazil"), 214 | }, 215 | }, 216 | }, 217 | }, 218 | }, 219 | out: "CREATE VIEW customers_view AS " + 220 | "SELECT customer_name, contract_name " + 221 | "FROM customers " + 222 | "WHERE country = 'Brazil'", 223 | }, 224 | } 225 | for _, c := range cases { 226 | t.Run(c.name, func(t *testing.T) { 227 | act := c.in.ToSQLString() 228 | 229 | if act != c.out { 230 | t.Errorf("must be \n%s but \n%s \n diff: %s", c.out, act, diff.CharacterDiff(c.out, act)) 231 | } 232 | }) 233 | } 234 | } 235 | 236 | func TestSQLCreateTable_ToSQLString(t *testing.T) { 237 | cases := []struct { 238 | name string 239 | in *CreateTableStmt 240 | out string 241 | }{ 242 | { 243 | name: "simple case", 244 | in: &CreateTableStmt{ 245 | Name: NewObjectName("persons"), 246 | Elements: []TableElement{ 247 | &ColumnDef{ 248 | Name: NewIdent("person_id"), 249 | DataType: &Int{}, 250 | Constraints: []*ColumnConstraint{ 251 | { 252 | Spec: &UniqueColumnSpec{ 253 | IsPrimaryKey: true, 254 | }, 255 | }, 256 | { 257 | Spec: &NotNullColumnSpec{}, 258 | }, 259 | }, 260 | }, 261 | &ColumnDef{ 262 | Name: NewIdent("last_name"), 263 | DataType: &VarcharType{ 264 | Size: NewSize(255), 265 | }, 266 | Constraints: []*ColumnConstraint{ 267 | { 268 | Spec: &NotNullColumnSpec{}, 269 | }, 270 | }, 271 | }, 272 | &ColumnDef{ 273 | Name: NewIdent("test_id"), 274 | DataType: &Int{}, 275 | Constraints: []*ColumnConstraint{ 276 | { 277 | Spec: &NotNullColumnSpec{}, 278 | }, 279 | { 280 | Spec: &ReferencesColumnSpec{ 281 | TableName: NewObjectName("test"), 282 | Columns: []*Ident{NewIdent("id1"), NewIdent("id2")}, 283 | }, 284 | }, 285 | }, 286 | }, 287 | &ColumnDef{ 288 | Name: NewIdent("email"), 289 | DataType: &VarcharType{ 290 | Size: NewSize(255), 291 | }, 292 | Constraints: []*ColumnConstraint{ 293 | { 294 | Spec: &UniqueColumnSpec{}, 295 | }, 296 | { 297 | Spec: &NotNullColumnSpec{}, 298 | }, 299 | }, 300 | }, 301 | &ColumnDef{ 302 | Name: NewIdent("age"), 303 | DataType: &Int{}, 304 | 
Constraints: []*ColumnConstraint{ 305 | { 306 | Spec: &NotNullColumnSpec{}, 307 | }, 308 | { 309 | Spec: &CheckColumnSpec{ 310 | Expr: &BinaryExpr{ 311 | Op: &Operator{Type: And}, 312 | Left: &BinaryExpr{ 313 | Op: &Operator{Type: Gt}, 314 | Left: NewIdent("age"), 315 | Right: NewLongValue(0), 316 | }, 317 | Right: &BinaryExpr{ 318 | Op: &Operator{Type: Lt}, 319 | Left: NewIdent("age"), 320 | Right: NewLongValue(100), 321 | }, 322 | }, 323 | }, 324 | }, 325 | }, 326 | }, 327 | &ColumnDef{ 328 | Name: NewIdent("created_at"), 329 | DataType: &Timestamp{}, 330 | Default: NewIdent("CURRENT_TIMESTAMP"), 331 | Constraints: []*ColumnConstraint{ 332 | { 333 | Spec: &NotNullColumnSpec{}, 334 | }, 335 | }, 336 | }, 337 | }, 338 | }, 339 | out: "CREATE TABLE persons (" + 340 | "person_id int PRIMARY KEY NOT NULL, " + 341 | "last_name character varying(255) NOT NULL, " + 342 | "test_id int NOT NULL REFERENCES test(id1, id2), " + 343 | "email character varying(255) UNIQUE NOT NULL, " + 344 | "age int NOT NULL CHECK(age > 0 AND age < 100), " + 345 | "created_at timestamp DEFAULT CURRENT_TIMESTAMP NOT NULL)", 346 | }, 347 | { 348 | name: "with table constraint", 349 | in: &CreateTableStmt{ 350 | Name: NewObjectName("persons"), 351 | Elements: []TableElement{ 352 | &ColumnDef{ 353 | Name: NewIdent("person_id"), 354 | DataType: &Int{}, 355 | }, 356 | &TableConstraint{ 357 | Name: NewIdent("production"), 358 | Spec: &UniqueTableConstraint{ 359 | Columns: []*Ident{NewIdent("test_column")}, 360 | }, 361 | }, 362 | &TableConstraint{ 363 | Spec: &UniqueTableConstraint{ 364 | Columns: []*Ident{NewIdent("person_id")}, 365 | IsPrimary: true, 366 | }, 367 | }, 368 | &TableConstraint{ 369 | Spec: &CheckTableConstraint{ 370 | Expr: &BinaryExpr{ 371 | Left: NewIdent("id"), 372 | Op: &Operator{Type: Gt}, 373 | Right: NewLongValue(100), 374 | }, 375 | }, 376 | }, 377 | &TableConstraint{ 378 | Spec: &ReferentialTableConstraint{ 379 | Columns: []*Ident{NewIdent("test_id")}, 380 | KeyExpr: 
&ReferenceKeyExpr{ 381 | TableName: NewIdent("other_table"), 382 | Columns: []*Ident{NewIdent("col1"), NewIdent("col2")}, 383 | }, 384 | }, 385 | }, 386 | }, 387 | }, 388 | out: "CREATE TABLE persons (" + 389 | "person_id int, " + 390 | "CONSTRAINT production UNIQUE(test_column), " + 391 | "PRIMARY KEY(person_id), " + 392 | "CHECK(id > 100), " + 393 | "FOREIGN KEY(test_id) REFERENCES other_table(col1, col2)" + 394 | ")", 395 | }, 396 | { 397 | name: "NotExists", 398 | in: &CreateTableStmt{ 399 | Name: NewObjectName("persons"), 400 | NotExists: true, 401 | Elements: []TableElement{ 402 | &ColumnDef{ 403 | Name: NewIdent("person_id"), 404 | DataType: &Int{}, 405 | Constraints: []*ColumnConstraint{ 406 | { 407 | Spec: &UniqueColumnSpec{ 408 | IsPrimaryKey: true, 409 | }, 410 | }, 411 | { 412 | Spec: &NotNullColumnSpec{}, 413 | }, 414 | }, 415 | }, 416 | &ColumnDef{ 417 | Name: NewIdent("last_name"), 418 | DataType: &VarcharType{ 419 | Size: NewSize(255), 420 | }, 421 | Constraints: []*ColumnConstraint{ 422 | { 423 | Spec: &NotNullColumnSpec{}, 424 | }, 425 | }, 426 | }, 427 | &ColumnDef{ 428 | Name: NewIdent("created_at"), 429 | DataType: &Timestamp{}, 430 | Default: NewIdent("CURRENT_TIMESTAMP"), 431 | Constraints: []*ColumnConstraint{ 432 | { 433 | Spec: &NotNullColumnSpec{}, 434 | }, 435 | }, 436 | }, 437 | }, 438 | }, 439 | out: "CREATE TABLE IF NOT EXISTS persons (" + 440 | "person_id int PRIMARY KEY NOT NULL, " + 441 | "last_name character varying(255) NOT NULL, " + 442 | "created_at timestamp DEFAULT CURRENT_TIMESTAMP NOT NULL)", 443 | }, 444 | } 445 | for _, c := range cases { 446 | t.Run(c.name, func(t *testing.T) { 447 | act := c.in.ToSQLString() 448 | 449 | if act != c.out { 450 | t.Errorf("must be \n%s but \n%s \n diff: %s", c.out, act, diff.CharacterDiff(c.out, act)) 451 | } 452 | }) 453 | } 454 | } 455 | 456 | func TestSQLAlterTable_ToSQLString(t *testing.T) { 457 | cases := []struct { 458 | name string 459 | in *AlterTableStmt 460 | out string 461 | }{ 
462 | { 463 | name: "add column", 464 | in: &AlterTableStmt{ 465 | TableName: NewObjectName("customers"), 466 | Action: &AddColumnTableAction{ 467 | Column: &ColumnDef{ 468 | Name: NewIdent("email"), 469 | DataType: &VarcharType{ 470 | Size: NewSize(255), 471 | }, 472 | }, 473 | }, 474 | }, 475 | out: "ALTER TABLE customers " + 476 | "ADD COLUMN email character varying(255)", 477 | }, 478 | { 479 | name: "add column over uint8", 480 | in: &AlterTableStmt{ 481 | TableName: NewObjectName("customers"), 482 | Action: &AddColumnTableAction{ 483 | Column: &ColumnDef{ 484 | Name: NewIdent("email"), 485 | DataType: &VarcharType{ 486 | Size: NewSize(256), 487 | }, 488 | }, 489 | }, 490 | }, 491 | out: "ALTER TABLE customers " + 492 | "ADD COLUMN email character varying(256)", 493 | }, 494 | { 495 | name: "remove column", 496 | in: &AlterTableStmt{ 497 | TableName: NewObjectName("products"), 498 | Action: &RemoveColumnTableAction{ 499 | Name: NewIdent("description"), 500 | Cascade: true, 501 | }, 502 | }, 503 | out: "ALTER TABLE products " + 504 | "DROP COLUMN description CASCADE", 505 | }, 506 | { 507 | name: "add constraint", 508 | in: &AlterTableStmt{ 509 | TableName: NewObjectName("products"), 510 | Action: &AddConstraintTableAction{ 511 | Constraint: &TableConstraint{ 512 | Spec: &ReferentialTableConstraint{ 513 | Columns: []*Ident{NewIdent("test_id")}, 514 | KeyExpr: &ReferenceKeyExpr{ 515 | TableName: NewIdent("other_table"), 516 | Columns: []*Ident{NewIdent("col1"), NewIdent("col2")}, 517 | }, 518 | }, 519 | }, 520 | }, 521 | }, 522 | out: "ALTER TABLE products " + 523 | "ADD FOREIGN KEY(test_id) REFERENCES other_table(col1, col2)", 524 | }, 525 | { 526 | name: "alter column", 527 | in: &AlterTableStmt{ 528 | TableName: NewObjectName("products"), 529 | Action: &AlterColumnTableAction{ 530 | ColumnName: NewIdent("created_at"), 531 | Action: &SetDefaultColumnAction{ 532 | Default: NewIdent("current_timestamp"), 533 | }, 534 | }, 535 | }, 536 | out: "ALTER TABLE 
products " + 537 | "ALTER COLUMN created_at SET DEFAULT current_timestamp", 538 | }, 539 | { 540 | name: "pg change type", 541 | in: &AlterTableStmt{ 542 | TableName: NewObjectName("products"), 543 | Action: &AlterColumnTableAction{ 544 | ColumnName: NewIdent("number"), 545 | Action: &PGAlterDataTypeColumnAction{ 546 | DataType: &Decimal{ 547 | Scale: NewSize(10), 548 | Precision: NewSize(255), 549 | }, 550 | }, 551 | }, 552 | }, 553 | out: "ALTER TABLE products " + 554 | "ALTER COLUMN number TYPE numeric(255,10)", 555 | }, 556 | } 557 | for _, c := range cases { 558 | t.Run(c.name, func(t *testing.T) { 559 | act := c.in.ToSQLString() 560 | 561 | if act != c.out { 562 | t.Errorf("must be \n%s but \n%s \n diff: %s", c.out, act, diff.CharacterDiff(c.out, act)) 563 | } 564 | }) 565 | } 566 | } 567 | 568 | func TestSQLCreateIndex_ToSQLString(t *testing.T) { 569 | cases := []struct { 570 | name string 571 | in *CreateIndexStmt 572 | out string 573 | }{ 574 | { 575 | name: "create index", 576 | in: &CreateIndexStmt{ 577 | TableName: NewObjectName("customers"), 578 | ColumnNames: []*Ident{NewIdent("name")}, 579 | }, 580 | out: "CREATE INDEX ON customers (name)", 581 | }, 582 | { 583 | name: "create unique index", 584 | in: &CreateIndexStmt{ 585 | TableName: NewObjectName("customers"), 586 | IsUnique: true, 587 | ColumnNames: []*Ident{NewIdent("name")}, 588 | }, 589 | out: "CREATE UNIQUE INDEX ON customers (name)", 590 | }, 591 | { 592 | name: "create index with name", 593 | in: &CreateIndexStmt{ 594 | TableName: NewObjectName("customers"), 595 | IndexName: NewIdent("customers_idx"), 596 | IsUnique: true, 597 | ColumnNames: []*Ident{NewIdent("name"), NewIdent("email")}, 598 | }, 599 | out: "CREATE UNIQUE INDEX customers_idx ON customers (name, email)", 600 | }, 601 | { 602 | name: "create index with name", 603 | in: &CreateIndexStmt{ 604 | TableName: NewObjectName("customers"), 605 | IndexName: NewIdent("customers_idx"), 606 | IsUnique: true, 607 | MethodName: 
NewIdent("gist"), 608 | ColumnNames: []*Ident{NewIdent("name")}, 609 | }, 610 | out: "CREATE UNIQUE INDEX customers_idx ON customers USING gist (name)", 611 | }, 612 | { 613 | name: "create partial index with name", 614 | in: &CreateIndexStmt{ 615 | TableName: NewObjectName("customers"), 616 | IndexName: NewIdent("customers_idx"), 617 | IsUnique: true, 618 | MethodName: NewIdent("gist"), 619 | ColumnNames: []*Ident{NewIdent("name")}, 620 | Selection: &BinaryExpr{ 621 | Left: NewIdent("name"), 622 | Op: &Operator{Type: Eq}, 623 | Right: NewSingleQuotedString("test"), 624 | }, 625 | }, 626 | out: "CREATE UNIQUE INDEX customers_idx ON customers USING gist (name) WHERE name = 'test'", 627 | }, 628 | } 629 | for _, c := range cases { 630 | t.Run(c.name, func(t *testing.T) { 631 | act := c.in.ToSQLString() 632 | 633 | if act != c.out { 634 | t.Errorf("must be \n%s but \n%s \n diff: %s", c.out, act, diff.CharacterDiff(c.out, act)) 635 | } 636 | }) 637 | } 638 | } 639 | -------------------------------------------------------------------------------- /sqlast/table_constraint_spec_gen.go: -------------------------------------------------------------------------------- 1 | package sqlast 2 | 3 | // Code generated by genmark. DO NOT EDIT. 4 | 5 | type TableConstraintSpec interface { 6 | tableConstraintSpecMarker() 7 | Node 8 | } 9 | type tableConstraintSpec struct{} 10 | 11 | func (tableConstraintSpec) tableConstraintSpecMarker() {} 12 | -------------------------------------------------------------------------------- /sqlast/table_element_gen.go: -------------------------------------------------------------------------------- 1 | package sqlast 2 | 3 | // Code generated by genmark. DO NOT EDIT. 
4 | 5 | type TableElement interface { 6 | tableElementMarker() 7 | Node 8 | } 9 | type tableElement struct{} 10 | 11 | func (tableElement) tableElementMarker() {} 12 | -------------------------------------------------------------------------------- /sqlast/table_factor_gen.go: -------------------------------------------------------------------------------- 1 | package sqlast 2 | 3 | // Code generated by genmark. DO NOT EDIT. 4 | 5 | type TableFactor interface { 6 | tableFactorMarker() 7 | TableReference 8 | } 9 | type tableFactor struct{} 10 | 11 | func (tableFactor) tableFactorMarker() {} 12 | -------------------------------------------------------------------------------- /sqlast/table_option.go: -------------------------------------------------------------------------------- 1 | package sqlast 2 | 3 | import ( 4 | "io" 5 | 6 | "github.com/akito0107/xsqlparser/sqltoken" 7 | ) 8 | 9 | //go:generate genmark -t TableOption -e Node 10 | 11 | //ENGINE option ( = InnoDB, MyISAM ...) 12 | type MyEngine struct { 13 | tableOption 14 | Engine sqltoken.Pos 15 | Equal bool 16 | Name *Ident 17 | } 18 | 19 | func (m *MyEngine) ToSQLString() string { 20 | return toSQLString(m) 21 | } 22 | 23 | func (m *MyEngine) WriteTo(w io.Writer) (int64, error) { 24 | sw := newSQLWriter(w) 25 | sw.Bytes([]byte("ENGINE ")).If(m.Equal, []byte("= ")).Node(m.Name) 26 | return sw.End() 27 | } 28 | 29 | func (m *MyEngine) Pos() sqltoken.Pos { 30 | return m.Engine 31 | } 32 | 33 | func (m *MyEngine) End() sqltoken.Pos { 34 | return m.Name.To 35 | } 36 | 37 | type MyCharset struct { 38 | tableOption 39 | IsDefault bool 40 | Default sqltoken.Pos 41 | Charset sqltoken.Pos 42 | Equal bool 43 | Name *Ident 44 | } 45 | 46 | func (m *MyCharset) ToSQLString() string { 47 | return toSQLString(m) 48 | } 49 | 50 | func (m *MyCharset) WriteTo(w io.Writer) (int64, error) { 51 | sw := newSQLWriter(w) 52 | sw.If(m.IsDefault, []byte("DEFAULT ")).Bytes([]byte("CHARSET ")) 53 | sw.If(m.Equal, []byte("= 
")).Node(m.Name) 54 | return sw.End() 55 | } 56 | 57 | func (m *MyCharset) Pos() sqltoken.Pos { 58 | if m.IsDefault { 59 | return m.Default 60 | } 61 | return m.Charset 62 | } 63 | 64 | func (m *MyCharset) End() sqltoken.Pos { 65 | return m.Name.To 66 | } 67 | -------------------------------------------------------------------------------- /sqlast/table_option_gen.go: -------------------------------------------------------------------------------- 1 | package sqlast 2 | 3 | // Code generated by genmark. DO NOT EDIT. 4 | 5 | type TableOption interface { 6 | tableOptionMarker() 7 | Node 8 | } 9 | type tableOption struct{} 10 | 11 | func (tableOption) tableOptionMarker() {} 12 | -------------------------------------------------------------------------------- /sqlast/table_reference_gen.go: -------------------------------------------------------------------------------- 1 | package sqlast 2 | 3 | // Code generated by genmark. DO NOT EDIT. 4 | 5 | type TableReference interface { 6 | tableReferenceMarker() 7 | Node 8 | } 9 | type tableReference struct{} 10 | 11 | func (tableReference) tableReferenceMarker() {} 12 | -------------------------------------------------------------------------------- /sqlast/type.go: -------------------------------------------------------------------------------- 1 | package sqlast 2 | 3 | import ( 4 | "io" 5 | 6 | "github.com/akito0107/xsqlparser/sqltoken" 7 | ) 8 | 9 | type Type interface { 10 | Node 11 | } 12 | 13 | type CharType struct { 14 | Size *uint 15 | From, To, RParen sqltoken.Pos 16 | } 17 | 18 | func (c *CharType) Pos() sqltoken.Pos { 19 | return c.From 20 | } 21 | 22 | func (c *CharType) End() sqltoken.Pos { 23 | if c.Size != nil { 24 | return c.RParen 25 | } 26 | return c.To 27 | } 28 | 29 | func (c *CharType) ToSQLString() string { 30 | return toSQLString(c) 31 | } 32 | 33 | func (c *CharType) WriteTo(w io.Writer) (int64, error) { 34 | return newSQLWriter(w).TypeWithOptionalLength([]byte("char"), c.Size).End() 35 | } 36 | 37 | 
type VarcharType struct { 38 | Size *uint 39 | Character, Varying, RParen sqltoken.Pos 40 | } 41 | 42 | func (v *VarcharType) Pos() sqltoken.Pos { 43 | return v.Character 44 | } 45 | 46 | func (v *VarcharType) End() sqltoken.Pos { 47 | if v.Size != nil { 48 | return v.RParen 49 | } 50 | return v.Varying 51 | } 52 | 53 | func (v *VarcharType) ToSQLString() string { 54 | return toSQLString(v) 55 | } 56 | 57 | func (v *VarcharType) WriteTo(w io.Writer) (int64, error) { 58 | return newSQLWriter(w).TypeWithOptionalLength([]byte("character varying"), v.Size).End() 59 | } 60 | 61 | type UUID struct { 62 | From, To sqltoken.Pos 63 | } 64 | 65 | func (u *UUID) Pos() sqltoken.Pos { 66 | return u.From 67 | } 68 | 69 | func (u *UUID) End() sqltoken.Pos { 70 | return u.To 71 | } 72 | 73 | func (*UUID) ToSQLString() string { 74 | return "uuid" 75 | } 76 | 77 | func (u *UUID) WriteTo(w io.Writer) (int64, error) { 78 | return writeSingleBytes(w, []byte("uuid")) 79 | } 80 | 81 | type Clob struct { 82 | Size uint 83 | Clob, RParen sqltoken.Pos 84 | } 85 | 86 | func (c *Clob) Pos() sqltoken.Pos { 87 | return c.Clob 88 | } 89 | 90 | func (c *Clob) End() sqltoken.Pos { 91 | return c.RParen 92 | } 93 | 94 | func (c *Clob) ToSQLString() string { 95 | return toSQLString(c) 96 | } 97 | 98 | func (c *Clob) WriteTo(w io.Writer) (int64, error) { 99 | return newSQLWriter(w).TypeWithOptionalLength([]byte("clob"), &c.Size).End() 100 | } 101 | 102 | type Binary struct { 103 | Size uint 104 | Binary, RParen sqltoken.Pos 105 | } 106 | 107 | func (b *Binary) Pos() sqltoken.Pos { 108 | return b.Binary 109 | } 110 | 111 | func (b *Binary) End() sqltoken.Pos { 112 | return b.RParen 113 | } 114 | 115 | func (b *Binary) ToSQLString() string { 116 | return toSQLString(b) 117 | } 118 | 119 | func (b *Binary) WriteTo(w io.Writer) (int64, error) { 120 | return newSQLWriter(w).TypeWithOptionalLength([]byte("binary"), &b.Size).End() 121 | } 122 | 123 | type Varbinary struct { 124 | Size uint 125 | Varbinary, 
RParen sqltoken.Pos 126 | } 127 | 128 | func (v *Varbinary) Pos() sqltoken.Pos { 129 | return v.Varbinary 130 | } 131 | 132 | func (v *Varbinary) End() sqltoken.Pos { 133 | return v.RParen 134 | } 135 | 136 | func (v *Varbinary) ToSQLString() string { 137 | return toSQLString(v) 138 | } 139 | 140 | func (v *Varbinary) WriteTo(w io.Writer) (int64, error) { 141 | return newSQLWriter(w).TypeWithOptionalLength([]byte("varbinary"), &v.Size).End() 142 | } 143 | 144 | type Blob struct { 145 | Size uint 146 | Blob, RParen sqltoken.Pos 147 | } 148 | 149 | func (b *Blob) Pos() sqltoken.Pos { 150 | return b.Blob 151 | } 152 | 153 | func (b *Blob) End() sqltoken.Pos { 154 | return b.RParen 155 | } 156 | 157 | func (b *Blob) ToSQLString() string { 158 | return toSQLString(b) 159 | } 160 | 161 | func (b *Blob) WriteTo(w io.Writer) (int64, error) { 162 | return newSQLWriter(w).TypeWithOptionalLength([]byte("blob"), &b.Size).End() 163 | } 164 | 165 | // All unsigned props are only available on MySQL 166 | 167 | type Decimal struct { 168 | Precision *uint 169 | Scale *uint 170 | Numeric, RParen sqltoken.Pos 171 | IsUnsigned bool 172 | Unsigned sqltoken.Pos 173 | } 174 | 175 | func (d *Decimal) Pos() sqltoken.Pos { 176 | return d.Numeric 177 | } 178 | 179 | func (d *Decimal) End() sqltoken.Pos { 180 | if d.IsUnsigned { 181 | return d.Unsigned 182 | } 183 | return d.RParen 184 | } 185 | 186 | func (d *Decimal) ToSQLString() string { 187 | return toSQLString(d) 188 | } 189 | 190 | func (d *Decimal) WriteTo(w io.Writer) (int64, error) { 191 | sw := newSQLWriter(w) 192 | sw.Bytes([]byte("numeric")) 193 | if d.Precision != nil { 194 | sw.LParen() 195 | sw.Int(int(*d.Precision)) 196 | if d.Scale != nil { 197 | sw.Bytes([]byte(",")) 198 | sw.Int(int(*d.Scale)) 199 | } 200 | sw.RParen() 201 | } 202 | sw.If(d.IsUnsigned, []byte(" unsigned")) 203 | return sw.End() 204 | } 205 | 206 | type Float struct { 207 | Size *uint 208 | From, To, RParen sqltoken.Pos 209 | IsUnsigned bool 210 | Unsigned 
sqltoken.Pos 211 | } 212 | 213 | func (f *Float) Pos() sqltoken.Pos { 214 | return f.From 215 | } 216 | 217 | func (f *Float) End() sqltoken.Pos { 218 | if f.IsUnsigned { 219 | return f.Unsigned 220 | } 221 | if f.Size != nil { 222 | return f.RParen 223 | } 224 | return f.To 225 | } 226 | 227 | func (f *Float) ToSQLString() string { 228 | return toSQLString(f) 229 | } 230 | 231 | func (f *Float) WriteTo(w io.Writer) (int64, error) { 232 | sw := newSQLWriter(w) 233 | sw.TypeWithOptionalLength([]byte("float"), f.Size).If(f.IsUnsigned, []byte(" unsigned")) 234 | return sw.End() 235 | } 236 | 237 | type SmallInt struct { 238 | From, To sqltoken.Pos 239 | IsUnsigned bool 240 | Unsigned sqltoken.Pos 241 | } 242 | 243 | func (s *SmallInt) Pos() sqltoken.Pos { 244 | return s.From 245 | } 246 | 247 | func (s *SmallInt) End() sqltoken.Pos { 248 | if s.IsUnsigned { 249 | return s.Unsigned 250 | } 251 | return s.To 252 | } 253 | 254 | func (s *SmallInt) ToSQLString() string { 255 | return toSQLString(s) 256 | } 257 | 258 | func (s *SmallInt) WriteTo(w io.Writer) (int64, error) { 259 | sw := newSQLWriter(w) 260 | sw.Bytes([]byte("smallint")).If(s.IsUnsigned, []byte(" unsigned")) 261 | return sw.End() 262 | } 263 | 264 | type Int struct { 265 | From, To sqltoken.Pos 266 | IsUnsigned bool 267 | Unsigned sqltoken.Pos 268 | } 269 | 270 | func (i *Int) Pos() sqltoken.Pos { 271 | return i.From 272 | } 273 | 274 | func (i *Int) End() sqltoken.Pos { 275 | if i.IsUnsigned { 276 | return i.Unsigned 277 | } 278 | return i.To 279 | } 280 | 281 | func (i *Int) ToSQLString() string { 282 | return toSQLString(i) 283 | } 284 | 285 | func (i *Int) WriteTo(w io.Writer) (int64, error) { 286 | sw := newSQLWriter(w) 287 | sw.Bytes([]byte("int")).If(i.IsUnsigned, []byte(" unsigned")) 288 | return sw.End() 289 | } 290 | 291 | type BigInt struct { 292 | From, To sqltoken.Pos 293 | IsUnsigned bool 294 | Unsigned sqltoken.Pos 295 | } 296 | 297 | func (b *BigInt) Pos() sqltoken.Pos { 298 | return b.From 
299 | } 300 | 301 | func (b *BigInt) End() sqltoken.Pos { 302 | if b.IsUnsigned { 303 | return b.Unsigned 304 | } 305 | return b.To 306 | } 307 | 308 | func (b *BigInt) ToSQLString() string { 309 | return toSQLString(b) 310 | } 311 | 312 | func (b *BigInt) WriteTo(w io.Writer) (int64, error) { 313 | sw := newSQLWriter(w) 314 | sw.Bytes([]byte("bigint")).If(b.IsUnsigned, []byte(" unsigned")) 315 | return sw.End() 316 | } 317 | 318 | type Real struct { 319 | From, To sqltoken.Pos 320 | IsUnsigned bool 321 | Unsigned sqltoken.Pos 322 | } 323 | 324 | func (r *Real) Pos() sqltoken.Pos { 325 | return r.From 326 | } 327 | 328 | func (r *Real) End() sqltoken.Pos { 329 | if r.IsUnsigned { 330 | return r.Unsigned 331 | } 332 | return r.To 333 | } 334 | 335 | func (r *Real) ToSQLString() string { 336 | return toSQLString(r) 337 | } 338 | 339 | func (r *Real) WriteTo(w io.Writer) (int64, error) { 340 | sw := newSQLWriter(w) 341 | sw.Bytes([]byte("real")).If(r.IsUnsigned, []byte(" unsigned")) 342 | return sw.End() 343 | } 344 | 345 | type Double struct { 346 | From, To sqltoken.Pos 347 | } 348 | 349 | func (d *Double) Pos() sqltoken.Pos { 350 | return d.From 351 | } 352 | 353 | func (d *Double) End() sqltoken.Pos { 354 | return d.To 355 | } 356 | 357 | func (*Double) ToSQLString() string { 358 | return "double precision" 359 | } 360 | 361 | func (*Double) WriteTo(w io.Writer) (int64, error) { 362 | return writeSingleBytes(w, []byte("double precision")) 363 | } 364 | 365 | type Boolean struct { 366 | From, To sqltoken.Pos 367 | } 368 | 369 | func (b *Boolean) Pos() sqltoken.Pos { 370 | return b.From 371 | } 372 | 373 | func (b *Boolean) End() sqltoken.Pos { 374 | return b.To 375 | } 376 | 377 | func (*Boolean) ToSQLString() string { 378 | return "boolean" 379 | } 380 | 381 | func (*Boolean) WriteTo(w io.Writer) (int64, error) { 382 | return writeSingleBytes(w, []byte("boolean")) 383 | } 384 | 385 | type Date struct { 386 | From, To sqltoken.Pos 387 | } 388 | 389 | func (d *Date) 
Pos() sqltoken.Pos { 390 | return d.From 391 | } 392 | 393 | func (d *Date) End() sqltoken.Pos { 394 | return d.To 395 | } 396 | 397 | func (*Date) ToSQLString() string { 398 | return "date" 399 | } 400 | 401 | func (*Date) WriteTo(w io.Writer) (int64, error) { 402 | return writeSingleBytes(w, []byte("date")) 403 | } 404 | 405 | type Time struct { 406 | From, To sqltoken.Pos 407 | } 408 | 409 | func (t *Time) Pos() sqltoken.Pos { 410 | return t.From 411 | } 412 | 413 | func (t *Time) End() sqltoken.Pos { 414 | return t.To 415 | } 416 | 417 | func (*Time) ToSQLString() string { 418 | return "time" 419 | } 420 | 421 | func (*Time) WriteTo(w io.Writer) (int64, error) { 422 | return writeSingleBytes(w, []byte("time")) 423 | } 424 | 425 | type Timestamp struct { 426 | WithTimeZone bool 427 | Timestamp sqltoken.Pos 428 | Zone sqltoken.Pos 429 | } 430 | 431 | func (t *Timestamp) Pos() sqltoken.Pos { 432 | return t.Timestamp 433 | } 434 | 435 | func (t *Timestamp) End() sqltoken.Pos { 436 | if t.WithTimeZone { 437 | return t.Zone 438 | } 439 | 440 | return sqltoken.Pos{ 441 | Line: t.Timestamp.Line, 442 | Col: t.Timestamp.Col + 9, 443 | } 444 | } 445 | 446 | func (t *Timestamp) ToSQLString() string { 447 | return toSQLString(t) 448 | } 449 | 450 | func (t *Timestamp) WriteTo(w io.Writer) (int64, error) { 451 | sw := newSQLWriter(w) 452 | sw.Bytes([]byte("timestamp")).If(t.WithTimeZone, []byte(" with time zone")) 453 | return sw.End() 454 | } 455 | 456 | type Regclass struct { 457 | From, To sqltoken.Pos 458 | } 459 | 460 | func (r *Regclass) Pos() sqltoken.Pos { 461 | return r.From 462 | } 463 | 464 | func (r *Regclass) End() sqltoken.Pos { 465 | return r.To 466 | } 467 | 468 | func (*Regclass) ToSQLString() string { 469 | return "regclass" 470 | } 471 | 472 | func (*Regclass) WriteTo(w io.Writer) (int64, error) { 473 | return writeSingleBytes(w, []byte("regclass")) 474 | } 475 | 476 | type Text struct { 477 | From, To sqltoken.Pos 478 | } 479 | 480 | func (t *Text) Pos() 
sqltoken.Pos { 481 | return t.From 482 | } 483 | 484 | func (t *Text) End() sqltoken.Pos { 485 | return t.To 486 | } 487 | 488 | func (*Text) ToSQLString() string { 489 | return "text" 490 | } 491 | 492 | func (*Text) WriteTo(w io.Writer) (int64, error) { 493 | return writeSingleBytes(w, []byte("text")) 494 | } 495 | 496 | type Bytea struct { 497 | From, To sqltoken.Pos 498 | } 499 | 500 | func (b *Bytea) Pos() sqltoken.Pos { 501 | return b.From 502 | } 503 | 504 | func (b *Bytea) End() sqltoken.Pos { 505 | return b.To 506 | } 507 | 508 | func (*Bytea) ToSQLString() string { 509 | return "bytea" 510 | } 511 | 512 | func (*Bytea) WriteTo(w io.Writer) (int64, error) { 513 | return writeSingleBytes(w, []byte("bytea")) 514 | } 515 | 516 | type Array struct { 517 | Ty Type 518 | RParen sqltoken.Pos 519 | } 520 | 521 | func (a *Array) Pos() sqltoken.Pos { 522 | return a.Ty.Pos() 523 | } 524 | 525 | func (a *Array) End() sqltoken.Pos { 526 | return a.RParen 527 | } 528 | 529 | func (a *Array) ToSQLString() string { 530 | return toSQLString(a) 531 | } 532 | 533 | func (a *Array) WriteTo(w io.Writer) (int64, error) { 534 | return newSQLWriter(w).Node(a.Ty).Bytes([]byte("[]")).End() 535 | } 536 | 537 | type Custom struct { 538 | Ty *ObjectName 539 | } 540 | 541 | func (c *Custom) Pos() sqltoken.Pos { 542 | return c.Ty.Pos() 543 | } 544 | 545 | func (c *Custom) End() sqltoken.Pos { 546 | return c.Ty.End() 547 | } 548 | 549 | func (c *Custom) ToSQLString() string { 550 | return c.Ty.ToSQLString() 551 | } 552 | 553 | func (c *Custom) WriteTo(w io.Writer) (int64, error) { 554 | return c.Ty.WriteTo(w) 555 | } 556 | 557 | func NewSize(s uint) *uint { 558 | return &s 559 | } 560 | -------------------------------------------------------------------------------- /sqlast/value.go: -------------------------------------------------------------------------------- 1 | package sqlast 2 | 3 | import ( 4 | "fmt" 5 | "io" 6 | "strconv" 7 | "time" 8 | 9 | 
"github.com/akito0107/xsqlparser/sqltoken" 10 | ) 11 | 12 | type Value interface { 13 | Value() interface{} 14 | Node 15 | } 16 | 17 | type LongValue struct { 18 | From, To sqltoken.Pos 19 | Long int64 20 | } 21 | 22 | func NewLongValue(i int64) *LongValue { 23 | return &LongValue{ 24 | Long: i, 25 | } 26 | } 27 | 28 | func (l *LongValue) Pos() sqltoken.Pos { 29 | return l.From 30 | } 31 | 32 | func (l *LongValue) End() sqltoken.Pos { 33 | return l.To 34 | } 35 | 36 | func (l *LongValue) Value() interface{} { 37 | return l 38 | } 39 | 40 | func (l *LongValue) ToSQLString() string { 41 | return toSQLString(l) 42 | } 43 | 44 | func (l *LongValue) WriteTo(w io.Writer) (int64, error) { 45 | n, err := io.WriteString(w, strconv.FormatInt(l.Long, 10)) 46 | return int64(n), err 47 | } 48 | 49 | type DoubleValue struct { 50 | From, To sqltoken.Pos 51 | Double float64 52 | } 53 | 54 | func NewDoubleValue(f float64) *DoubleValue { 55 | return &DoubleValue{ 56 | Double: f, 57 | } 58 | } 59 | 60 | func (d *DoubleValue) Pos() sqltoken.Pos { 61 | return d.From 62 | } 63 | 64 | func (d *DoubleValue) End() sqltoken.Pos { 65 | return d.To 66 | } 67 | 68 | func (d *DoubleValue) Value() interface{} { 69 | return d.Double 70 | } 71 | 72 | func (d *DoubleValue) ToSQLString() string { 73 | return toSQLString(d) 74 | } 75 | 76 | func (d *DoubleValue) WriteTo(w io.Writer) (int64, error) { 77 | var b [32] byte 78 | buf := strconv.AppendFloat(b[:0], d.Double, 'f', -1, 64) 79 | n, err := w.Write(buf) 80 | return int64(n), err 81 | } 82 | 83 | type SingleQuotedString struct { 84 | From, To sqltoken.Pos 85 | String string 86 | } 87 | 88 | func NewSingleQuotedString(str string) *SingleQuotedString { 89 | return &SingleQuotedString{ 90 | String: str, 91 | } 92 | } 93 | 94 | func (s *SingleQuotedString) Pos() sqltoken.Pos { 95 | return s.From 96 | } 97 | 98 | func (s *SingleQuotedString) End() sqltoken.Pos { 99 | return s.To 100 | } 101 | 102 | func (s *SingleQuotedString) Value() interface{} { 
103 | return s.String 104 | } 105 | 106 | func (s *SingleQuotedString) ToSQLString() string { 107 | return toSQLString(s) 108 | } 109 | 110 | func (s *SingleQuotedString) WriteTo(w io.Writer) (int64, error) { 111 | n, err := w.Write([]byte("'")) 112 | if err != nil { 113 | return int64(n), err 114 | } 115 | n1, err := io.WriteString(w, s.String) 116 | if err != nil { 117 | return int64(n + n1), err 118 | } 119 | n2, err := w.Write([]byte("'")) 120 | return int64(n + n1 + n2), err 121 | } 122 | 123 | type NationalStringLiteral struct { 124 | From, To sqltoken.Pos 125 | String string 126 | } 127 | 128 | func NewNationalStringLiteral(str string) *NationalStringLiteral { 129 | return &NationalStringLiteral{ 130 | String: str, 131 | } 132 | } 133 | 134 | func (n *NationalStringLiteral) Pos() sqltoken.Pos { 135 | return n.From 136 | } 137 | 138 | func (n *NationalStringLiteral) End() sqltoken.Pos { 139 | return n.To 140 | } 141 | 142 | func (n *NationalStringLiteral) Value() interface{} { 143 | return n.String 144 | } 145 | 146 | func (n *NationalStringLiteral) ToSQLString() string { 147 | return fmt.Sprintf("N'%s'", n.String) 148 | } 149 | 150 | func (n *NationalStringLiteral) WriteTo(w io.Writer) (int64, error) { 151 | n0, err := w.Write([]byte("N'")) 152 | if err != nil { 153 | return int64(n0), err 154 | } 155 | n1, err := io.WriteString(w, n.String) 156 | if err != nil { 157 | return int64(n0 + n1), err 158 | } 159 | n2, err := w.Write([]byte("'")) 160 | return int64(n0 + n1 + n2), err 161 | } 162 | 163 | type BooleanValue struct { 164 | From, To sqltoken.Pos 165 | Boolean bool 166 | } 167 | 168 | func NewBooleanValue(b bool) *BooleanValue { 169 | return &BooleanValue{ 170 | Boolean: b, 171 | } 172 | } 173 | 174 | func (b *BooleanValue) Pos() sqltoken.Pos { 175 | return b.From 176 | } 177 | 178 | func (b *BooleanValue) End() sqltoken.Pos { 179 | return b.To 180 | } 181 | 182 | func (b *BooleanValue) Value() interface{} { 183 | return b.Boolean 184 | } 185 | 186 | 
func (b *BooleanValue) ToSQLString() string { 187 | return toSQLString(b) 188 | } 189 | 190 | func (b *BooleanValue) WriteTo(w io.Writer) (int64, error) { 191 | if b.Boolean { 192 | return writeSingleBytes(w, []byte("true")) 193 | } else { 194 | return writeSingleBytes(w, []byte("false")) 195 | } 196 | } 197 | 198 | type DateValue struct { 199 | From, To sqltoken.Pos 200 | Date time.Time 201 | } 202 | 203 | func (d *DateValue) Pos() sqltoken.Pos { 204 | return d.From 205 | } 206 | 207 | func (d *DateValue) End() sqltoken.Pos { 208 | return d.To 209 | } 210 | 211 | func (d *DateValue) Value() interface{} { 212 | return d.Date 213 | } 214 | 215 | func (d *DateValue) ToSQLString() string { 216 | return toSQLString(d) 217 | } 218 | 219 | func (d *DateValue) WriteTo(w io.Writer) (int64, error) { 220 | var b [16]byte 221 | buf := d.Date.AppendFormat(b[:0], "2006-01-02") 222 | n, err := w.Write(buf) 223 | return int64(n), err 224 | } 225 | 226 | type TimeValue struct { 227 | From, To sqltoken.Pos 228 | Time time.Time 229 | } 230 | 231 | func NewTimeValue(t time.Time) *TimeValue { 232 | return &TimeValue{ 233 | Time: t, 234 | } 235 | } 236 | 237 | func (t *TimeValue) Pos() sqltoken.Pos { 238 | return t.From 239 | } 240 | 241 | func (t *TimeValue) End() sqltoken.Pos { 242 | return t.To 243 | } 244 | 245 | func (t *TimeValue) Value() interface{} { 246 | return t.Time 247 | } 248 | 249 | func (t *TimeValue) ToSQLString() string { 250 | return toSQLString(t) 251 | } 252 | 253 | func (t *TimeValue) WriteTo(w io.Writer) (int64, error) { 254 | var b [16]byte 255 | buf := t.Time.AppendFormat(b[:0], "15:04:05") 256 | n, err := w.Write(buf) 257 | return int64(n), err 258 | } 259 | 260 | type DateTimeValue struct { 261 | From, To sqltoken.Pos 262 | DateTime time.Time 263 | } 264 | 265 | func NewDateTimeValue(t time.Time) *DateTimeValue { 266 | return &DateTimeValue{ 267 | DateTime: t, 268 | } 269 | } 270 | 271 | func (d *DateTimeValue) Pos() sqltoken.Pos { 272 | return d.From 273 | } 
274 | 275 | func (d *DateTimeValue) End() sqltoken.Pos { 276 | return d.To 277 | } 278 | 279 | func (d *DateTimeValue) Value() interface{} { 280 | return d.DateTime 281 | } 282 | 283 | func (d *DateTimeValue) ToSQLString() string { 284 | return d.DateTime.Format("2006-01-02 15:04:05") 285 | } 286 | 287 | func (d *DateTimeValue) WriteTo(w io.Writer) (int64, error) { 288 | var b [32]byte 289 | buf := d.DateTime.AppendFormat(b[:0], "2006-01-02 15:04:05") 290 | n, err := w.Write(buf) 291 | return int64(n), err 292 | } 293 | 294 | type TimestampValue struct { 295 | From, To sqltoken.Pos 296 | Timestamp time.Time 297 | } 298 | 299 | func NewTimestampValue(t time.Time) *TimestampValue { 300 | return &TimestampValue{Timestamp: t} 301 | } 302 | 303 | func (t *TimestampValue) Pos() sqltoken.Pos { 304 | return t.From 305 | } 306 | 307 | func (t *TimestampValue) End() sqltoken.Pos { 308 | return t.To 309 | } 310 | 311 | func (t *TimestampValue) Value() interface{} { 312 | return t.Timestamp 313 | } 314 | 315 | func (t *TimestampValue) ToSQLString() string { 316 | return toSQLString(t) 317 | } 318 | 319 | func (t *TimestampValue) WriteTo(w io.Writer) (int64, error) { 320 | var b [32]byte 321 | buf := t.Timestamp.AppendFormat(b[:0], "2006-01-02 15:04:05") 322 | n, err := w.Write(buf) 323 | return int64(n), err 324 | } 325 | 326 | type NullValue struct { 327 | From, To sqltoken.Pos 328 | } 329 | 330 | func NewNullValue() *NullValue { 331 | return &NullValue{} 332 | } 333 | 334 | func (n *NullValue) Pos() sqltoken.Pos { 335 | return n.From 336 | } 337 | 338 | func (n *NullValue) End() sqltoken.Pos { 339 | return n.To 340 | } 341 | 342 | func (n *NullValue) Value() interface{} { 343 | return nil 344 | } 345 | 346 | func (n *NullValue) ToSQLString() string { 347 | return "NULL" 348 | } 349 | 350 | func (*NullValue) WriteTo(w io.Writer) (int64, error) { 351 | return writeSingleBytes(w, []byte("NULL")) 352 | } 353 | 
-------------------------------------------------------------------------------- /sqlast/walk.go: -------------------------------------------------------------------------------- 1 | package sqlast 2 | 3 | import ( 4 | "log" 5 | ) 6 | 7 | type Visitor interface { 8 | Visit(node Node) Visitor 9 | } 10 | 11 | func walkIdentLists(v Visitor, list []*Ident) { 12 | for _, i := range list { 13 | Walk(v, i) 14 | } 15 | } 16 | 17 | func walkASTNodeLists(v Visitor, list []Node) { 18 | for _, l := range list { 19 | Walk(v, l) 20 | } 21 | } 22 | 23 | func Walk(v Visitor, node Node) { 24 | if v := v.Visit(node); v == nil { 25 | return 26 | } 27 | 28 | switch n := node.(type) { 29 | case *File: 30 | for _, stmt := range n.Stmts { 31 | Walk(v, stmt) 32 | } 33 | case *Ident: 34 | // nothing to do 35 | case *Wildcard: 36 | // nothing to do 37 | case *QualifiedWildcard: 38 | walkIdentLists(v, n.Idents) 39 | case *CompoundIdent: 40 | walkIdentLists(v, n.Idents) 41 | case *IsNull: 42 | Walk(v, n.X) 43 | case *IsNotNull: 44 | Walk(v, n.X) 45 | case *InList: 46 | Walk(v, n.Expr) 47 | walkASTNodeLists(v, n.List) 48 | case *InSubQuery: 49 | Walk(v, n.Expr) 50 | Walk(v, n.SubQuery) 51 | case *Between: 52 | Walk(v, n.Expr) 53 | Walk(v, n.Low) 54 | Walk(v, n.High) 55 | case *BinaryExpr: 56 | Walk(v, n.Left) 57 | Walk(v, n.Op) 58 | Walk(v, n.Right) 59 | case *Cast: 60 | Walk(v, n.Expr) 61 | Walk(v, n.DataType) 62 | case *Nested: 63 | Walk(v, n.AST) 64 | case *UnaryExpr: 65 | Walk(v, n.Op) 66 | Walk(v, n.Expr) 67 | case *Function: 68 | Walk(v, n.Name) 69 | walkASTNodeLists(v, n.Args) 70 | if n.Over != nil { 71 | Walk(v, n.Over) 72 | } 73 | case *CaseExpr: 74 | Walk(v, n.Operand) 75 | case *Exists: 76 | Walk(v, n.Query) 77 | case *SubQuery: 78 | Walk(v, n.Query) 79 | case *ObjectName: 80 | walkIdentLists(v, n.Idents) 81 | case *WindowSpec: 82 | walkASTNodeLists(v, n.PartitionBy) 83 | for _, o := range n.OrderBy { 84 | Walk(v, o) 85 | } 86 | if n.WindowsFrame != nil { 87 | Walk(v, 
n.WindowsFrame) 88 | } 89 | case *WindowFrame: 90 | Walk(v, n.Units) 91 | Walk(v, n.StartBound) 92 | if n.EndBound != nil { 93 | Walk(v, n.EndBound) 94 | } 95 | case *WindowFrameUnit: 96 | // nothing to do 97 | case *CurrentRow: 98 | // nothing to do 99 | case *UnboundedPreceding: 100 | // nothing to do 101 | case *UnboundedFollowing: 102 | // nothing to do 103 | case *Preceding: 104 | // nothing to do 105 | case *Following: 106 | // nothing to do 107 | case *QueryStmt: 108 | for _, c := range n.CTEs { 109 | Walk(v, c) 110 | } 111 | Walk(v, n.Body) 112 | for _, o := range n.OrderBy { 113 | Walk(v, o) 114 | } 115 | if n.Limit != nil { 116 | Walk(v, n.Limit) 117 | } 118 | case *CTE: 119 | Walk(v, n.Query) 120 | Walk(v, n.Alias) 121 | case *SelectExpr: 122 | Walk(v, n.Select) 123 | case *QueryExpr: 124 | Walk(v, n.Query) 125 | case *SetOperationExpr: 126 | Walk(v, n.Op) 127 | Walk(v, n.Left) 128 | Walk(v, n.Right) 129 | case *UnionOperator: 130 | // nothing to do 131 | case *ExceptOperator: 132 | // nothing to do 133 | case *IntersectOperator: 134 | // nothing to do 135 | case *SQLSelect: 136 | for _, p := range n.Projection { 137 | Walk(v, p) 138 | } 139 | if len(n.FromClause) != 0 { 140 | for _, f := range n.FromClause { 141 | Walk(v, f) 142 | } 143 | } 144 | if n.WhereClause != nil { 145 | Walk(v, n.WhereClause) 146 | } 147 | walkASTNodeLists(v, n.GroupByClause) 148 | if n.HavingClause != nil { 149 | Walk(v, n.HavingClause) 150 | } 151 | case *QualifiedJoin: 152 | Walk(v, n.LeftElement) 153 | Walk(v, n.Type) 154 | Walk(v, n.RightElement) 155 | Walk(v, n.Spec) 156 | case *TableJoinElement: 157 | Walk(v, n.Ref) 158 | case *JoinType: 159 | // nothing to do 160 | case *NamedColumnsJoin: 161 | // nothing to do 162 | case *JoinCondition: 163 | Walk(v, n.SearchCondition) 164 | case *NaturalJoin: 165 | Walk(v, n.LeftElement) 166 | Walk(v, n.Type) 167 | Walk(v, n.RightElement) 168 | case *CrossJoin: 169 | Walk(v, n.Factor) 170 | Walk(v, n.Reference) 171 | case *Table: 172 | 
Walk(v, n.Name) 173 | if n.Alias != nil { 174 | Walk(v, n.Alias) 175 | } 176 | walkASTNodeLists(v, n.Args) 177 | walkASTNodeLists(v, n.WithHints) 178 | case *Derived: 179 | Walk(v, n.SubQuery) 180 | if n.Alias != nil { 181 | Walk(v, n.Alias) 182 | } 183 | case *UnnamedSelectItem: 184 | Walk(v, n.Node) 185 | case *AliasSelectItem: 186 | Walk(v, n.Expr) 187 | Walk(v, n.Alias) 188 | case *QualifiedWildcardSelectItem: 189 | Walk(v, n.Prefix) 190 | case *WildcardSelectItem: 191 | // nothing to do 192 | case *OrderByExpr: 193 | Walk(v, n.Expr) 194 | case *LimitExpr: 195 | if !n.All { 196 | Walk(v, n.LimitValue) 197 | } 198 | if n.OffsetValue != nil { 199 | Walk(v, n.OffsetValue) 200 | } 201 | case *CharType: 202 | // nothing to do 203 | case *VarcharType: 204 | // nothing to do 205 | case *UUID: 206 | // nothing to do 207 | case *Clob: 208 | // nothing to do 209 | case *Binary: 210 | // nothing to do 211 | case *Varbinary: 212 | // nothing to do 213 | case *Blob: 214 | // nothing to do 215 | case *Decimal: 216 | // nothing to do 217 | case *Float: 218 | // nothing to do 219 | case *SmallInt: 220 | // nothing to do 221 | case *Int: 222 | // nothing to do 223 | case *BigInt: 224 | // nothing to do 225 | case *Real: 226 | // nothing to do 227 | case *Double: 228 | // nothing to do 229 | case *Boolean: 230 | // nothing to do 231 | case *Date: 232 | // nothing to do 233 | case *Time: 234 | // nothing to do 235 | case *Timestamp: 236 | // nothing to do 237 | case *Regclass: 238 | // nothing to do 239 | case *Text: 240 | // nothing to do 241 | case *Bytea: 242 | // nothing to do 243 | case *Array: 244 | // nothing to do 245 | case *Custom: 246 | // nothing to do 247 | case *InsertStmt: 248 | Walk(v, n.TableName) 249 | walkIdentLists(v, n.Columns) 250 | Walk(v, n.Source) 251 | 252 | for _, a := range n.UpdateAssignments { 253 | Walk(v, a) 254 | } 255 | 256 | case *ConstructorSource: 257 | for _, r := range n.Rows { 258 | Walk(v, r) 259 | } 260 | case *RowValueExpr: 261 | for _, 
r := range n.Values { 262 | Walk(v, r) 263 | } 264 | case *SubQuerySource: 265 | Walk(v, n.SubQuery) 266 | case *CopyStmt: 267 | Walk(v, n.TableName) 268 | walkIdentLists(v, n.Columns) 269 | case *UpdateStmt: 270 | Walk(v, n.TableName) 271 | for _, a := range n.Assignments { 272 | Walk(v, a) 273 | } 274 | Walk(v, n.Selection) 275 | case *DeleteStmt: 276 | Walk(v, n.TableName) 277 | if n.Selection != nil { 278 | Walk(v, n.Selection) 279 | } 280 | case *CreateViewStmt: 281 | Walk(v, n.Name) 282 | Walk(v, n.Query) 283 | case *CreateTableStmt: 284 | Walk(v, n.Name) 285 | for _, e := range n.Elements { 286 | Walk(v, e) 287 | } 288 | case *Assignment: 289 | Walk(v, n.ID) 290 | Walk(v, n.Value) 291 | case *TableConstraint: 292 | if n.Name != nil { 293 | Walk(v, n.Name) 294 | } 295 | Walk(v, n.Spec) 296 | case *UniqueTableConstraint: 297 | walkIdentLists(v, n.Columns) 298 | case *ReferentialTableConstraint: 299 | walkIdentLists(v, n.Columns) 300 | Walk(v, n.KeyExpr) 301 | case *ReferenceKeyExpr: 302 | Walk(v, n.TableName) 303 | walkIdentLists(v, n.Columns) 304 | case *CheckTableConstraint: 305 | Walk(v, n.Expr) 306 | case *ColumnDef: 307 | Walk(v, n.Name) 308 | Walk(v, n.DataType) 309 | if n.Default != nil { 310 | Walk(v, n.Default) 311 | } 312 | for _, c := range n.Constraints { 313 | Walk(v, c) 314 | } 315 | case *ColumnConstraint: 316 | if n.Name != nil { 317 | Walk(v, n.Name) 318 | } 319 | Walk(v, n.Spec) 320 | case *NotNullColumnSpec: 321 | // nothing to do 322 | case *UniqueColumnSpec: 323 | // nothing to do 324 | case *ReferencesColumnSpec: 325 | Walk(v, n.TableName) 326 | walkIdentLists(v, n.Columns) 327 | case *CheckColumnSpec: 328 | Walk(v, n.Expr) 329 | case *AlterTableStmt: 330 | Walk(v, n.TableName) 331 | Walk(v, n.Action) 332 | case *AddColumnTableAction: 333 | Walk(v, n.Column) 334 | case *AlterColumnTableAction: 335 | Walk(v, n.ColumnName) 336 | Walk(v, n.Action) 337 | case *SetDefaultColumnAction: 338 | Walk(v, n.Default) 339 | case 
*DropDefaultColumnAction: 340 | // nothing to do 341 | case *PGAlterDataTypeColumnAction: 342 | Walk(v, n.DataType) 343 | case *PGSetNotNullColumnAction: 344 | // nothing to do 345 | case *PGDropNotNullColumnAction: 346 | // nothing to do 347 | case *RemoveColumnTableAction: 348 | Walk(v, n.Name) 349 | case *AddConstraintTableAction: 350 | Walk(v, n.Constraint) 351 | case *DropConstraintTableAction: 352 | Walk(v, n.Name) 353 | case *DropTableStmt: 354 | for _, t := range n.TableNames { 355 | Walk(v, t) 356 | } 357 | case *CreateIndexStmt: 358 | Walk(v, n.TableName) 359 | if n.IndexName != nil { 360 | Walk(v, n.IndexName) 361 | } 362 | if n.MethodName != nil { 363 | Walk(v, n.MethodName) 364 | } 365 | walkIdentLists(v, n.ColumnNames) 366 | if n.Selection != nil { 367 | Walk(v, n.Selection) 368 | } 369 | case *DropIndexStmt: 370 | walkIdentLists(v, n.IndexNames) 371 | case *ExplainStmt: 372 | Walk(v, n.Stmt) 373 | case *Operator: 374 | // nothing to do 375 | case *NullValue, 376 | *LongValue, 377 | *DoubleValue, 378 | *SingleQuotedString, 379 | *NationalStringLiteral, 380 | *BooleanValue, 381 | *DateValue, 382 | *TimeValue, 383 | *DateTimeValue, 384 | *TimestampValue: 385 | // nothing to do 386 | default: 387 | log.Panicf("not implemented type %T: %+v", node, node) 388 | } 389 | 390 | v.Visit(nil) 391 | } 392 | 393 | type inspector func(node Node) bool 394 | 395 | func (f inspector) Visit(node Node) Visitor { 396 | if f(node) { 397 | return f 398 | } 399 | return nil 400 | } 401 | 402 | func Inspect(node Node, f func(node Node) bool) { 403 | Walk(inspector(f), node) 404 | } 405 | -------------------------------------------------------------------------------- /sqlast/writer.go: -------------------------------------------------------------------------------- 1 | package sqlast 2 | 3 | import ( 4 | "io" 5 | "strconv" 6 | "strings" 7 | ) 8 | 9 | type sqlWriter struct { 10 | w io.Writer 11 | n int64 12 | err error 13 | } 14 | 15 | func newSQLWriter(w io.Writer) 
*sqlWriter { 16 | return &sqlWriter{w: w} 17 | } 18 | 19 | var selectBytes = []byte("SELECT ") 20 | var fromBytes = []byte(" FROM ") 21 | var whereBytes = []byte(" WHERE ") 22 | var wildcardBytes = []byte("*") 23 | var dotBytes = []byte(".") 24 | var spaceBytes = []byte(" ") 25 | 26 | func (w *sqlWriter) Bytes(b []byte) *sqlWriter { 27 | if w.err != nil { 28 | return w 29 | } 30 | n, err := w.w.Write(b) 31 | w.n += int64(n) 32 | if err != nil { 33 | w.err = err 34 | } 35 | return w 36 | } 37 | 38 | func (w *sqlWriter) Space() *sqlWriter { 39 | return w.Bytes(spaceBytes) 40 | } 41 | 42 | func (w *sqlWriter) LParen() *sqlWriter { 43 | return w.Bytes([]byte("(")) 44 | } 45 | 46 | func (w *sqlWriter) RParen() *sqlWriter { 47 | return w.Bytes([]byte(")")) 48 | } 49 | 50 | func (w *sqlWriter) Int(i int) *sqlWriter { 51 | if w.err != nil { 52 | return w 53 | } 54 | var buf [32]byte 55 | b := buf[:0] 56 | b = strconv.AppendInt(b, int64(i), 10) 57 | n, err := w.w.Write(b) 58 | w.n += int64(n) 59 | if err != nil { 60 | w.err = err 61 | } 62 | return w 63 | } 64 | 65 | func (w *sqlWriter) Node(wt io.WriterTo) *sqlWriter { 66 | if w.err != nil { 67 | return w 68 | } 69 | n, err := wt.WriteTo(w.w) 70 | w.n += n 71 | if err != nil { 72 | w.err = err 73 | } 74 | return w 75 | } 76 | 77 | func (w *sqlWriter) Join(i int, wt io.WriterTo, sep []byte) *sqlWriter { 78 | if i > 0 { 79 | w.Bytes(sep) 80 | } 81 | return w.Node(wt) 82 | } 83 | 84 | func (w *sqlWriter) JoinComma(i int, wt io.WriterTo) *sqlWriter { 85 | if i > 0 { 86 | w.Bytes([]byte(", ")) 87 | } 88 | return w.Node(wt) 89 | } 90 | 91 | func (w *sqlWriter) JoinNewLine(i int, wt io.WriterTo) *sqlWriter { 92 | if i > 0 { 93 | w.Bytes([]byte("\n")) 94 | } 95 | return w.Node(wt) 96 | } 97 | 98 | func (w *sqlWriter) Idents(idents []*Ident, sep []byte) *sqlWriter { 99 | if w.err != nil { 100 | return w 101 | } 102 | sw, ok := w.w.(io.StringWriter) 103 | if ok { 104 | for i, ident := range idents { 105 | if i > 0 { 106 | 
w.Bytes(sep) 107 | } 108 | if w.err == nil { 109 | w.Direct(ident.WriteStringTo(sw)) 110 | } 111 | } 112 | return w 113 | } 114 | for i, ident := range idents { 115 | w.Join(i, ident, sep) 116 | } 117 | return w 118 | } 119 | 120 | func (w *sqlWriter) Nodes(nodes []Node) *sqlWriter { 121 | if w.err != nil { 122 | return w 123 | } 124 | for i, node := range nodes { 125 | w.Join(i, node, []byte(", ")) 126 | } 127 | return w 128 | } 129 | 130 | func (w *sqlWriter) TypeWithOptionalLength(sqltype []byte, size *uint) *sqlWriter { 131 | w.Bytes(sqltype) 132 | if size != nil { 133 | w.Bytes([]byte("(")).Int(int(*size)).Bytes([]byte(")")) 134 | } 135 | return w 136 | } 137 | 138 | func (w *sqlWriter) Negated(negated bool) *sqlWriter { 139 | return w.If(negated, []byte("NOT ")) 140 | } 141 | 142 | func (w *sqlWriter) If(ok bool, b []byte) *sqlWriter { 143 | if ok { 144 | w.Bytes(b) 145 | } 146 | return w 147 | } 148 | 149 | func (w *sqlWriter) As() *sqlWriter { 150 | return w.Bytes([]byte(" AS ")) 151 | } 152 | 153 | func (w *sqlWriter) End() (int64, error) { 154 | return w.n, w.err 155 | } 156 | 157 | func (w *sqlWriter) Err() error { 158 | return w.err 159 | } 160 | 161 | func (w *sqlWriter) Direct(n int64, err error) *sqlWriter { 162 | w.n += n 163 | if err != nil { 164 | w.err = err 165 | } 166 | return w 167 | } 168 | 169 | func writeSingleBytes(w io.Writer, b []byte) (int64, error) { 170 | n, err := w.Write(b) 171 | return int64(n), err 172 | } 173 | 174 | func writeSingleString(w io.Writer, s string) (int64, error) { 175 | n, err := io.WriteString(w, s) 176 | return int64(n), err 177 | } 178 | 179 | func toSQLString(n Node) string { 180 | var b strings.Builder 181 | _, _ = n.WriteTo(&b) 182 | return b.String() 183 | } 184 | -------------------------------------------------------------------------------- /sqlastutil/rewrite.go: -------------------------------------------------------------------------------- 1 | package sqlastutil 2 | 3 | import ( 4 | "log" 5 | 
"reflect" 6 | 7 | "github.com/akito0107/xsqlparser/sqlast" 8 | ) 9 | 10 | type ApplyFunc func(*Cursor) bool 11 | 12 | var abort = new(int) 13 | 14 | func Apply(root sqlast.Node, pre, post ApplyFunc) (result sqlast.Node) { 15 | parent := &struct { 16 | sqlast.Node 17 | }{root} 18 | 19 | defer func() { 20 | if r := recover(); r != nil && r != abort { 21 | panic(r) 22 | } 23 | result = parent.Node 24 | }() 25 | a := &application{pre: pre, post: post} 26 | a.apply(parent, "Node", nil, root) 27 | return 28 | } 29 | 30 | type Cursor struct { 31 | parent sqlast.Node 32 | name string 33 | iter *iterator 34 | node sqlast.Node 35 | } 36 | 37 | func (c *Cursor) Node() sqlast.Node { return c.node } 38 | 39 | func (c *Cursor) Parent() sqlast.Node { return c.parent } 40 | 41 | func (c *Cursor) Name() string { return c.name } 42 | 43 | func (c *Cursor) Index() int { 44 | if c.iter != nil { 45 | return c.iter.index 46 | } 47 | return -1 48 | } 49 | 50 | func (c *Cursor) field() reflect.Value { 51 | return reflect.Indirect(reflect.ValueOf(c.parent)).FieldByName(c.name) 52 | } 53 | 54 | func (c *Cursor) Replace(n sqlast.Node) { 55 | v := c.field() 56 | if i := c.Index(); i >= 0 { 57 | v = v.Index(i) 58 | } 59 | v.Set(reflect.ValueOf(n)) 60 | } 61 | 62 | func (c *Cursor) Delete() { 63 | i := c.Index() 64 | if i < 0 { 65 | log.Panicln("delete node not contained in slice") 66 | } 67 | v := c.field() 68 | l := v.Len() 69 | 70 | reflect.Copy(v.Slice(i, l), v.Slice(i+1, l)) 71 | v.Index(l - 1).Set(reflect.Zero(v.Type().Elem())) 72 | v.SetLen(l - 1) 73 | c.iter.step-- 74 | } 75 | 76 | func (c *Cursor) InsertAfter(n sqlast.Node) { 77 | i := c.Index() 78 | if i < 0 { 79 | log.Panicln("InsertAfter node not contained in slice") 80 | } 81 | v := c.field() 82 | v.Set(reflect.Append(v, reflect.Zero(v.Type().Elem()))) 83 | l := v.Len() 84 | reflect.Copy(v.Slice(i+2, l), v.Slice(i+1, l)) 85 | v.Index(i + 1).Set(reflect.ValueOf(n)) 86 | c.iter.step++ 87 | } 88 | 89 | func (c *Cursor) InsertBefore(n 
sqlast.Node) { 90 | i := c.Index() 91 | if i < 0 { 92 | log.Panicln("InsertBefore node not contained in slice") 93 | } 94 | v := c.field() 95 | v.Set(reflect.Append(v, reflect.Zero(v.Type().Elem()))) 96 | l := v.Len() 97 | reflect.Copy(v.Slice(i+1, l), v.Slice(i, l)) 98 | v.Index(i).Set(reflect.ValueOf(n)) 99 | c.iter.index++ 100 | } 101 | 102 | type iterator struct { 103 | index, step int 104 | } 105 | 106 | type application struct { 107 | pre, post ApplyFunc 108 | cursor Cursor 109 | iter iterator 110 | } 111 | 112 | func (a *application) apply(parent sqlast.Node, name string, iter *iterator, n sqlast.Node) { 113 | if v := reflect.ValueOf(n); v.Kind() == reflect.Ptr && v.IsNil() { 114 | n = nil 115 | } 116 | 117 | saved := a.cursor 118 | a.cursor.parent = parent 119 | a.cursor.name = name 120 | a.cursor.iter = iter 121 | a.cursor.node = n 122 | 123 | if a.pre != nil && !a.pre(&a.cursor) { 124 | a.cursor = saved 125 | return 126 | } 127 | 128 | switch n := n.(type) { 129 | case *sqlast.File: 130 | a.applyList(n, "Stmts") 131 | case *sqlast.Ident: 132 | // nothing to do 133 | case *sqlast.Wildcard: 134 | // nothing to do 135 | case *sqlast.QualifiedWildcard: 136 | a.applyList(n, "Idents") 137 | case *sqlast.CompoundIdent: 138 | a.applyList(n, "Idents") 139 | case *sqlast.IsNull: 140 | a.apply(n, "X", nil, n.X) 141 | case *sqlast.IsNotNull: 142 | a.apply(n, "X", nil, n.X) 143 | case *sqlast.InList: 144 | a.apply(n, "Expr", nil, n.Expr) 145 | a.applyList(n, "List") 146 | case *sqlast.InSubQuery: 147 | a.apply(n, "Expr", nil, n.Expr) 148 | a.apply(n, "SubQuery", nil, n.SubQuery) 149 | case *sqlast.Between: 150 | a.apply(n, "Expr", nil, n.Expr) 151 | a.apply(n, "Low", nil, n.Low) 152 | a.apply(n, "High", nil, n.High) 153 | case *sqlast.BinaryExpr: 154 | a.apply(n, "Left", nil, n.Left) 155 | a.apply(n, "Op", nil, n.Op) 156 | a.apply(n, "Right", nil, n.Right) 157 | case *sqlast.Cast: 158 | a.apply(n, "Expr", nil, n.Expr) 159 | a.apply(n, "DataType", nil, n.DataType) 160 
| case *sqlast.Nested: 161 | a.apply(n, "AST", nil, n.AST) 162 | case *sqlast.UnaryExpr: 163 | a.apply(n, "Op", nil, n.Op) 164 | a.apply(n, "Expr", nil, n.Expr) 165 | case *sqlast.Function: 166 | a.apply(n, "Name", nil, n.Name) 167 | a.applyList(n, "Args") 168 | if n.Over != nil { 169 | a.apply(n, "Over", nil, n.Over) 170 | } 171 | case *sqlast.CaseExpr: 172 | a.apply(n, "Operand", nil, n.Operand) 173 | case *sqlast.Exists: 174 | a.apply(n, "QueryStmt", nil, n.Query) 175 | case *sqlast.SubQuery: 176 | a.apply(n, "QueryStmt", nil, n.Query) 177 | case *sqlast.ObjectName: 178 | a.applyList(n, "Idents") 179 | case *sqlast.WindowSpec: 180 | a.applyList(n, "PartitionBy") 181 | a.applyList(n, "OrderBy") 182 | if n.WindowsFrame != nil { 183 | a.apply(n, "WindowsFrame", nil, n.WindowsFrame) 184 | } 185 | case *sqlast.WindowFrame: 186 | a.apply(n, "Units", nil, n.Units) 187 | a.apply(n, "StartBound", nil, n.StartBound) 188 | if n.EndBound != nil { 189 | a.apply(n, "EndBound", nil, n.EndBound) 190 | } 191 | case *sqlast.WindowFrameUnit, 192 | *sqlast.CurrentRow, 193 | *sqlast.UnboundedPreceding, 194 | *sqlast.UnboundedFollowing, 195 | *sqlast.Preceding, 196 | *sqlast.Following: 197 | // nothing to do 198 | case *sqlast.QueryStmt: 199 | a.applyList(n, "CTEs") 200 | a.apply(n, "Body", nil, n.Body) 201 | a.applyList(n, "OrderBy") 202 | if n.Limit != nil { 203 | a.apply(n, "Limit", nil, n.Limit) 204 | } 205 | case *sqlast.CTE: 206 | a.apply(n, "QueryStmt", nil, n.Query) 207 | a.apply(n, "Alias", nil, n.Alias) 208 | case *sqlast.SelectExpr: 209 | a.apply(n, "Select", nil, n.Select) 210 | case *sqlast.QueryExpr: 211 | a.apply(n, "QueryStmt", nil, n.Query) 212 | case *sqlast.SetOperationExpr: 213 | a.apply(n, "Op", nil, n.Op) 214 | a.apply(n, "Left", nil, n.Left) 215 | a.apply(n, "Right", nil, n.Right) 216 | case *sqlast.UnionOperator: 217 | // nothing to do 218 | case *sqlast.ExceptOperator: 219 | // nothing to do 220 | case *sqlast.IntersectOperator: 221 | // nothing to do 222 | 
case *sqlast.SQLSelect: 223 | a.applyList(n, "Projection") 224 | a.applyList(n, "FromClause") 225 | if n.WhereClause != nil { 226 | a.apply(n, "WhereClause", nil, n.WhereClause) 227 | } 228 | a.applyList(n, "GroupByClause") 229 | if n.HavingClause != nil { 230 | a.apply(n, "HavingClause", nil, n.HavingClause) 231 | } 232 | case *sqlast.QualifiedJoin: 233 | a.apply(n, "LeftElement", nil, n.LeftElement) 234 | a.apply(n, "Type", nil, n.Type) 235 | a.apply(n, "RightElement", nil, n.RightElement) 236 | a.apply(n, "Spec", nil, n.Spec) 237 | case *sqlast.TableJoinElement: 238 | a.apply(n, "Ref", nil, n.Ref) 239 | case *sqlast.JoinType: 240 | // nothing to do 241 | case *sqlast.JoinCondition: 242 | a.apply(n, "SearchCondition", nil, n.SearchCondition) 243 | case *sqlast.NaturalJoin: 244 | a.apply(n, "LeftElement", nil, n.LeftElement) 245 | a.apply(n, "Type", nil, n.Type) 246 | a.apply(n, "RightElement", nil, n.RightElement) 247 | case *sqlast.CrossJoin: 248 | a.apply(n, "Factor", nil, n.Factor) 249 | a.apply(n, "Reference", nil, n.Reference) 250 | case *sqlast.Table: 251 | a.apply(n, "Name", nil, n.Name) 252 | if n.Alias != nil { 253 | a.apply(n, "Alias", nil, n.Alias) 254 | } 255 | a.applyList(n, "Args") 256 | a.applyList(n, "WithHints") 257 | case *sqlast.Derived: 258 | a.apply(n, "SubQuery", nil, n.SubQuery) 259 | if n.Alias != nil { 260 | a.apply(n, "Alias", nil, n.Alias) 261 | } 262 | case *sqlast.UnnamedSelectItem: 263 | a.apply(n, "Node", nil, n.Node) 264 | case *sqlast.AliasSelectItem: 265 | a.apply(n, "Expr", nil, n.Expr) 266 | a.apply(n, "Alias", nil, n.Alias) 267 | case *sqlast.QualifiedWildcardSelectItem: 268 | a.apply(n, "Prefix", nil, n.Prefix) 269 | case *sqlast.WildcardSelectItem: 270 | // nothing to do 271 | case *sqlast.OrderByExpr: 272 | a.apply(n, "Expr", nil, n.Expr) 273 | case *sqlast.LimitExpr: 274 | if !n.All { 275 | a.apply(n, "LimitValue", nil, n.LimitValue) 276 | } 277 | if n.OffsetValue != nil { 278 | a.apply(n, "OffsetValue", nil, 
n.OffsetValue) 279 | } 280 | case *sqlast.CharType: 281 | // nothing to do 282 | case *sqlast.VarcharType: 283 | // nothing to do 284 | case *sqlast.UUID: 285 | // nothing to do 286 | case *sqlast.Clob: 287 | // nothing to do 288 | case *sqlast.Binary: 289 | // nothing to do 290 | case *sqlast.Varbinary: 291 | // nothing to do 292 | case *sqlast.Blob: 293 | // nothing to do 294 | case *sqlast.Decimal: 295 | // nothing to do 296 | case *sqlast.Float: 297 | // nothing to do 298 | case *sqlast.SmallInt: 299 | // nothing to do 300 | case *sqlast.Int: 301 | // nothing to do 302 | case *sqlast.BigInt: 303 | // nothing to do 304 | case *sqlast.Real: 305 | // nothing to do 306 | case *sqlast.Double: 307 | // nothing to do 308 | case *sqlast.Boolean: 309 | // nothing to do 310 | case *sqlast.Date: 311 | // nothing to do 312 | case *sqlast.Time: 313 | // nothing to do 314 | case *sqlast.Timestamp: 315 | // nothing to do 316 | case *sqlast.Regclass: 317 | // nothing to do 318 | case *sqlast.Text: 319 | // nothing to do 320 | case *sqlast.Bytea: 321 | // nothing to do 322 | case *sqlast.Array: 323 | // nothing to do 324 | case *sqlast.Custom: 325 | // nothing to do 326 | case *sqlast.InsertStmt: 327 | a.apply(n, "TableName", nil, n.TableName) 328 | a.applyList(n, "Columns") 329 | a.apply(n, "Source", nil, n.Source) 330 | a.applyList(n, "UpdateAssignments") 331 | case *sqlast.ConstructorSource: 332 | a.applyList(n, "Rows") 333 | case *sqlast.RowValueExpr: 334 | a.applyList(n, "Values") 335 | case *sqlast.SubQuerySource: 336 | a.apply(n, "SubQuery", nil, n.SubQuery) 337 | case *sqlast.CopyStmt: 338 | a.apply(n, "TableName", nil, n.TableName) 339 | a.applyList(n, "Columns") 340 | case *sqlast.UpdateStmt: 341 | a.apply(n, "TableName", nil, n.TableName) 342 | a.applyList(n, "Assignments") 343 | a.apply(n, "Selection", nil, n.Selection) 344 | case *sqlast.DeleteStmt: 345 | a.apply(n, "TableName", nil, n.TableName) 346 | if n.Selection != nil { 347 | a.apply(n, "Selection", nil, 
n.Selection) 348 | } 349 | case *sqlast.CreateViewStmt: 350 | a.apply(n, "Name", nil, n.Name) 351 | a.apply(n, "QueryStmt", nil, n.Query) 352 | case *sqlast.CreateTableStmt: 353 | a.apply(n, "Name", nil, n.Name) 354 | a.applyList(n, "Elements") 355 | case *sqlast.Assignment: 356 | a.apply(n, "ID", nil, n.ID) 357 | a.apply(n, "Value", nil, n.Value) 358 | case *sqlast.TableConstraint: 359 | if n.Name != nil { 360 | a.apply(n, "Name", nil, n.Name) 361 | } 362 | a.apply(n, "Spec", nil, n.Spec) 363 | case *sqlast.UniqueTableConstraint: 364 | a.applyList(n, "Columns") 365 | case *sqlast.ReferentialTableConstraint: 366 | a.applyList(n, "Columns") 367 | a.apply(n, "KeyExpr", nil, n.KeyExpr) 368 | case *sqlast.ReferenceKeyExpr: 369 | a.apply(n, "TableName", nil, n.TableName) 370 | a.applyList(n, "Columns") 371 | case *sqlast.CheckTableConstraint: 372 | a.apply(n, "Expr", nil, n.Expr) 373 | case *sqlast.ColumnDef: 374 | a.apply(n, "Name", nil, n.Name) 375 | a.apply(n, "DataType", nil, n.DataType) 376 | if n.Default != nil { 377 | a.apply(n, "Default", nil, n.Default) 378 | } 379 | a.applyList(n, "Constraints") 380 | case *sqlast.ColumnConstraint: 381 | if n.Name != nil { 382 | a.apply(n, "Name", nil, n.Name) 383 | } 384 | a.apply(n, "Spec", nil, n.Spec) 385 | case *sqlast.NotNullColumnSpec: 386 | // nothing to do 387 | case *sqlast.UniqueColumnSpec: 388 | // nothing to do 389 | case *sqlast.ReferencesColumnSpec: 390 | a.apply(n, "TableName", nil, n.TableName) 391 | a.applyList(n, "Columns") 392 | case *sqlast.CheckColumnSpec: 393 | a.apply(n, "Expr", nil, n.Expr) 394 | case *sqlast.AlterTableStmt: 395 | a.apply(n, "TableName", nil, n.TableName) 396 | a.apply(n, "Action", nil, n.Action) 397 | case *sqlast.AddColumnTableAction: 398 | a.apply(n, "Column", nil, n.Column) 399 | case *sqlast.AlterColumnTableAction: 400 | a.apply(n, "ColumnName", nil, n.ColumnName) 401 | a.apply(n, "Action", nil, n.Action) 402 | case *sqlast.SetDefaultColumnAction: 403 | a.apply(n, "Default", nil, 
n.Default) 404 | case *sqlast.DropDefaultColumnAction: 405 | // nothing to do 406 | case *sqlast.PGAlterDataTypeColumnAction: 407 | a.apply(n, "DataType", nil, n.DataType) 408 | case *sqlast.PGSetNotNullColumnAction: 409 | // nothing to do 410 | case *sqlast.PGDropNotNullColumnAction: 411 | // nothing to do 412 | case *sqlast.RemoveColumnTableAction: 413 | a.apply(n, "Name", nil, n.Name) 414 | case *sqlast.AddConstraintTableAction: 415 | a.apply(n, "Constraint", nil, n.Constraint) 416 | case *sqlast.DropConstraintTableAction: 417 | a.apply(n, "Name", nil, n.Name) 418 | case *sqlast.DropTableStmt: 419 | a.applyList(n, "TableNames") 420 | case *sqlast.CreateIndexStmt: 421 | a.apply(n, "TableName", nil, n.TableName) 422 | if n.IndexName != nil { 423 | a.apply(n, "IndexName", nil, n.IndexName) 424 | } 425 | if n.MethodName != nil { 426 | a.apply(n, "MethodName", nil, n.MethodName) 427 | } 428 | a.applyList(n, "ColumnNames") 429 | if n.Selection != nil { 430 | a.apply(n, "Selection", nil, n.Selection) 431 | } 432 | case *sqlast.DropIndexStmt: 433 | a.applyList(n, "IndexNames") 434 | case *sqlast.ExplainStmt: 435 | a.apply(n, "Stmt", nil, n.Stmt) 436 | case *sqlast.Operator: 437 | // nothing to do 438 | case *sqlast.NullValue, 439 | *sqlast.LongValue, 440 | *sqlast.DoubleValue, 441 | *sqlast.SingleQuotedString, 442 | *sqlast.NationalStringLiteral, 443 | *sqlast.BooleanValue, 444 | *sqlast.DateValue, 445 | *sqlast.TimeValue, 446 | *sqlast.DateTimeValue, 447 | *sqlast.TimestampValue: 448 | // nothing to do 449 | default: 450 | log.Panicf("not implemented type %T: %+v", n, n) 451 | } 452 | 453 | if a.post != nil && !a.post(&a.cursor) { 454 | panic(abort) 455 | } 456 | a.cursor = saved 457 | } 458 | 459 | func (a *application) applyList(parent sqlast.Node, name string) { 460 | saved := a.iter 461 | a.iter.index = 0 462 | for { 463 | v := reflect.Indirect(reflect.ValueOf(parent)).FieldByName(name) 464 | if a.iter.index >= v.Len() { 465 | break 466 | } 467 | 468 | var x 
sqlast.Node 469 | if e := v.Index(a.iter.index); e.IsValid() { 470 | x = e.Interface().(sqlast.Node) 471 | } 472 | 473 | a.iter.step = 1 474 | a.apply(parent, name, &a.iter, x) 475 | a.iter.index += a.iter.step 476 | } 477 | a.iter = saved 478 | } 479 | -------------------------------------------------------------------------------- /sqlastutil/rewrite_test.go: -------------------------------------------------------------------------------- 1 | package sqlastutil 2 | 3 | import ( 4 | "bytes" 5 | "testing" 6 | 7 | "github.com/akito0107/xsqlparser" 8 | "github.com/akito0107/xsqlparser/dialect" 9 | "github.com/akito0107/xsqlparser/sqlast" 10 | ) 11 | 12 | func TestApply(t *testing.T) { 13 | 14 | cases := []struct { 15 | name string 16 | src string 17 | expect string 18 | preFunc ApplyFunc 19 | postFunc ApplyFunc 20 | }{ 21 | { 22 | name: "replace long value", 23 | src: `SELECT * FROM table_a WHERE id = 1`, 24 | expect: `SELECT * FROM table_a WHERE id = 2`, 25 | preFunc: func(cursor *Cursor) bool { 26 | switch cursor.node.(type) { 27 | case *sqlast.LongValue: 28 | cursor.Replace(sqlast.NewLongValue(2)) 29 | } 30 | return true 31 | }, 32 | }, 33 | { 34 | name: "remove select item", 35 | src: "SELECT a, b, c FROM table_a", 36 | expect: "SELECT a, b FROM table_a", 37 | preFunc: func(cursor *Cursor) bool { 38 | switch cursor.node.(type) { 39 | case *sqlast.UnnamedSelectItem: 40 | if cursor.Index() == 2 { 41 | cursor.Delete() 42 | } 43 | } 44 | return true 45 | }, 46 | }, 47 | { 48 | name: "insert after", 49 | src: "SELECT a, b FROM table_a", 50 | expect: "SELECT a, b, c FROM table_a", 51 | preFunc: func(cursor *Cursor) bool { 52 | switch cursor.node.(type) { 53 | case *sqlast.UnnamedSelectItem: 54 | if cursor.Index() == 1 { 55 | cursor.InsertAfter(&sqlast.UnnamedSelectItem{ 56 | Node: sqlast.NewIdent("c"), 57 | }) 58 | } 59 | } 60 | return true 61 | }, 62 | }, 63 | { 64 | name: "insert before", 65 | src: "SELECT a, b FROM table_a", 66 | expect: "SELECT c, a, b FROM 
table_a", 67 | preFunc: func(cursor *Cursor) bool { 68 | switch cursor.node.(type) { 69 | case *sqlast.UnnamedSelectItem: 70 | if cursor.Index() == 0 { 71 | cursor.InsertBefore(&sqlast.UnnamedSelectItem{ 72 | Node: sqlast.NewIdent("c"), 73 | }) 74 | } 75 | } 76 | return true 77 | }, 78 | }, 79 | } 80 | 81 | for _, c := range cases { 82 | t.Run(c.name, func(t *testing.T) { 83 | parser, err := xsqlparser.NewParser(bytes.NewBufferString(c.src), &dialect.GenericSQLDialect{}) 84 | if err != nil { 85 | t.Fatalf("%+v", err) 86 | } 87 | ast, err := parser.ParseStatement() 88 | if err != nil { 89 | t.Fatalf("%+v", err) 90 | } 91 | 92 | res := Apply(ast, c.preFunc, c.postFunc) 93 | if c.expect != res.ToSQLString() { 94 | t.Errorf("should be \n %s but \n %s", c.expect, res.ToSQLString()) 95 | } 96 | }) 97 | } 98 | } 99 | -------------------------------------------------------------------------------- /sqltoken/kind.go: -------------------------------------------------------------------------------- 1 | package sqltoken 2 | 3 | type Kind int 4 | 5 | //go:generate stringer -type Kind kind.go 6 | const ( 7 | // A keyword (like SELECT) 8 | SQLKeyword Kind = iota 9 | // Numeric literal 10 | Number 11 | // A character that cloud not be tokenized 12 | Char 13 | // Single quoted string i.e: 'string' 14 | SingleQuotedString 15 | // National string i.e: N'string' 16 | NationalStringLiteral 17 | // Comma 18 | Comma 19 | // Whitespace 20 | Whitespace 21 | // comment node 22 | Comment 23 | // = operator 24 | Eq 25 | // != or <> operator 26 | Neq 27 | // < operator 28 | Lt 29 | // > operator 30 | Gt 31 | // <= operator 32 | LtEq 33 | // >= operator 34 | GtEq 35 | // + operator 36 | Plus 37 | // - operator 38 | Minus 39 | // * operator 40 | Mult 41 | // / operator 42 | Div 43 | // % operator 44 | Mod 45 | // Left parenthesis `(` 46 | LParen 47 | // Right parenthesis `)` 48 | RParen 49 | // Period 50 | Period 51 | // Colon 52 | Colon 53 | // DoubleColon 54 | DoubleColon 55 | // Semicolon 56 
| Semicolon 57 | // Backslash 58 | Backslash 59 | // Left bracket `]` 60 | LBracket 61 | // Right bracket `[` 62 | RBracket 63 | // & 64 | Ampersand 65 | // Left brace `{` 66 | LBrace 67 | // Right brace `}` 68 | RBrace 69 | // ILLEGAL sqltoken 70 | ILLEGAL 71 | ) 72 | -------------------------------------------------------------------------------- /sqltoken/kind_string.go: -------------------------------------------------------------------------------- 1 | // Code generated by "stringer -type Kind kind.go"; DO NOT EDIT. 2 | 3 | package sqltoken 4 | 5 | import "strconv" 6 | 7 | func _() { 8 | // An "invalid array index" compiler error signifies that the constant values have changed. 9 | // Re-run the stringer command to generate them again. 10 | var x [1]struct{} 11 | _ = x[SQLKeyword-0] 12 | _ = x[Number-1] 13 | _ = x[Char-2] 14 | _ = x[SingleQuotedString-3] 15 | _ = x[NationalStringLiteral-4] 16 | _ = x[Comma-5] 17 | _ = x[Whitespace-6] 18 | _ = x[Comment-7] 19 | _ = x[Eq-8] 20 | _ = x[Neq-9] 21 | _ = x[Lt-10] 22 | _ = x[Gt-11] 23 | _ = x[LtEq-12] 24 | _ = x[GtEq-13] 25 | _ = x[Plus-14] 26 | _ = x[Minus-15] 27 | _ = x[Mult-16] 28 | _ = x[Div-17] 29 | _ = x[Mod-18] 30 | _ = x[LParen-19] 31 | _ = x[RParen-20] 32 | _ = x[Period-21] 33 | _ = x[Colon-22] 34 | _ = x[DoubleColon-23] 35 | _ = x[Semicolon-24] 36 | _ = x[Backslash-25] 37 | _ = x[LBracket-26] 38 | _ = x[RBracket-27] 39 | _ = x[Ampersand-28] 40 | _ = x[LBrace-29] 41 | _ = x[RBrace-30] 42 | _ = x[ILLEGAL-31] 43 | } 44 | 45 | const _Kind_name = "SQLKeywordNumberCharSingleQuotedStringNationalStringLiteralCommaWhitespaceCommentEqNeqLtGtLtEqGtEqPlusMinusMultDivModLParenRParenPeriodColonDoubleColonSemicolonBackslashLBracketRBracketAmpersandLBraceRBraceILLEGAL" 46 | 47 | var _Kind_index = [...]uint8{0, 10, 16, 20, 38, 59, 64, 74, 81, 83, 86, 88, 90, 94, 98, 102, 107, 111, 114, 117, 123, 129, 135, 140, 151, 160, 169, 177, 185, 194, 200, 206, 213} 48 | 49 | func (i Kind) String() string { 50 | if i < 0 || i >= 
Kind(len(_Kind_index)-1) { 51 | return "Kind(" + strconv.FormatInt(int64(i), 10) + ")" 52 | } 53 | return _Kind_name[_Kind_index[i]:_Kind_index[i+1]] 54 | } 55 | -------------------------------------------------------------------------------- /sqltoken/tokenizer.go: -------------------------------------------------------------------------------- 1 | package sqltoken 2 | 3 | import ( 4 | "fmt" 5 | "io" 6 | "strings" 7 | "text/scanner" 8 | 9 | errors "golang.org/x/xerrors" 10 | 11 | "github.com/akito0107/xsqlparser/dialect" 12 | ) 13 | 14 | type SQLWord struct { 15 | Value string 16 | QuoteStyle rune 17 | Keyword string 18 | } 19 | 20 | func (s *SQLWord) String() string { 21 | if s.QuoteStyle == '"' || s.QuoteStyle == '[' || s.QuoteStyle == '`' { 22 | return string(s.QuoteStyle) + s.Value + string(matchingEndQuote(s.QuoteStyle)) 23 | } else if s.QuoteStyle == 0 { 24 | return s.Value 25 | } 26 | return "" 27 | } 28 | 29 | func matchingEndQuote(quoteStyle rune) rune { 30 | switch quoteStyle { 31 | case '"': 32 | return '"' 33 | case '[': 34 | return ']' 35 | case '`': 36 | return '`' 37 | } 38 | return 0 39 | } 40 | 41 | var keywordCache = map[string]*SQLWord{} 42 | 43 | func init() { 44 | for keyword := range dialect.Keywords { 45 | keywordCache[keyword] = &SQLWord{ 46 | Value: keyword, 47 | Keyword: keyword, 48 | } 49 | lower := strings.ToLower(keyword) 50 | keywordCache[lower] = &SQLWord{ 51 | Value: lower, 52 | Keyword: keyword, 53 | } 54 | } 55 | } 56 | 57 | func MakeKeyword(word string, quoteStyle rune) *SQLWord { 58 | if quoteStyle == 0 { 59 | if w, ok := keywordCache[word]; ok { 60 | return w 61 | } 62 | } 63 | w := strings.ToUpper(word) 64 | _, ok := dialect.Keywords[w] 65 | 66 | if quoteStyle == 0 && ok { 67 | return &SQLWord{ 68 | Value: word, 69 | Keyword: w, 70 | } 71 | } else { 72 | return &SQLWord{ 73 | Value: word, 74 | Keyword: w, 75 | QuoteStyle: quoteStyle, 76 | } 77 | } 78 | } 79 | 80 | type Token struct { 81 | Kind Kind 82 | Value interface{} 83 | 
From Pos 84 | To Pos 85 | } 86 | 87 | func NewPos(line, col int) Pos { 88 | return Pos{ 89 | Line: line, 90 | Col: col, 91 | } 92 | } 93 | 94 | type Pos struct { 95 | Line int 96 | Col int 97 | } 98 | 99 | func (p *Pos) String() string { 100 | return fmt.Sprintf("{Line: %d Col: %d}", p.Line, p.Col) 101 | } 102 | 103 | func ComparePos(x, y Pos) int { 104 | if x.Line == y.Line && x.Col == y.Col { 105 | return 0 106 | } 107 | 108 | if x.Line > y.Line { 109 | return 1 110 | } else if x.Line < y.Line { 111 | return -1 112 | } 113 | 114 | if x.Col > y.Col { 115 | return 1 116 | } 117 | 118 | return -1 119 | } 120 | 121 | type Tokenizer struct { 122 | Dialect dialect.Dialect 123 | Scanner *scanner.Scanner 124 | Line int 125 | Col int 126 | parseComment bool 127 | } 128 | 129 | func NewTokenizer(src io.Reader, dialect dialect.Dialect) *Tokenizer { 130 | var scan scanner.Scanner 131 | return &Tokenizer{ 132 | Dialect: dialect, 133 | Scanner: scan.Init(src), 134 | Line: 1, 135 | Col: 1, 136 | parseComment: true, 137 | } 138 | } 139 | 140 | type TokenizerOption func(*Tokenizer) 141 | 142 | func Dialect(dialect dialect.Dialect) TokenizerOption { 143 | return func(tokenizer *Tokenizer) { 144 | tokenizer.Dialect = dialect 145 | } 146 | } 147 | 148 | func DisableParseComment() TokenizerOption { 149 | return func(tokenizer *Tokenizer) { 150 | tokenizer.parseComment = false 151 | } 152 | } 153 | 154 | func NewTokenizerWithOptions(src io.Reader, options ...TokenizerOption) *Tokenizer { 155 | tokenizer := NewTokenizer(src, &dialect.GenericSQLDialect{}) 156 | for _, o := range options { 157 | o(tokenizer) 158 | } 159 | return tokenizer 160 | } 161 | 162 | func (t *Tokenizer) Tokenize() ([]*Token, error) { 163 | var tokenset []*Token 164 | 165 | for { 166 | t, err := t.NextToken() 167 | if err == io.EOF { 168 | break 169 | } 170 | if err != nil { 171 | return nil, err 172 | } 173 | 174 | if t == nil { 175 | continue 176 | } 177 | tokenset = append(tokenset, t) 178 | } 179 | 180 | 
return tokenset, nil 181 | } 182 | 183 | func (t *Tokenizer) NextToken() (*Token, error) { 184 | var tok Token 185 | return t.Scan(&tok) 186 | } 187 | 188 | func (t *Tokenizer) Scan(token *Token) (*Token, error) { 189 | pos := t.Pos() 190 | tok, str, err := t.next() 191 | if err == io.EOF { 192 | return nil, io.EOF 193 | } 194 | if err != nil { 195 | token.Kind = ILLEGAL 196 | token.Value = "" 197 | token.From = pos 198 | token.To = t.Pos() 199 | return token, errors.Errorf("tokenize failed: %w", err) 200 | } 201 | 202 | if !t.parseComment && (tok == Whitespace || tok == Comment) { 203 | return nil, nil 204 | } 205 | 206 | token.Kind = tok 207 | token.Value = str 208 | token.From = pos 209 | token.To = t.Pos() 210 | return token, nil 211 | } 212 | 213 | func (t *Tokenizer) Pos() Pos { 214 | return Pos{ 215 | Line: t.Line, 216 | Col: t.Col, 217 | } 218 | } 219 | 220 | func (t *Tokenizer) next() (Kind, interface{}, error) { 221 | r := t.Scanner.Peek() 222 | switch { 223 | case ' ' == r: 224 | t.Scanner.Next() 225 | t.Col += 1 226 | return Whitespace, " ", nil 227 | 228 | case '\t' == r: 229 | t.Scanner.Next() 230 | t.Col += 4 231 | return Whitespace, "\t", nil 232 | 233 | case '\n' == r: 234 | t.Scanner.Next() 235 | t.Line += 1 236 | t.Col = 1 237 | return Whitespace, "\n", nil 238 | 239 | case '\r' == r: 240 | t.Scanner.Next() 241 | n := t.Scanner.Peek() 242 | if n == '\n' { 243 | t.Scanner.Next() 244 | } 245 | t.Line += 1 246 | t.Col = 1 247 | return Whitespace, "\n", nil 248 | 249 | case 'N' == r: 250 | t.Scanner.Next() 251 | n := t.Scanner.Peek() 252 | if n == '\'' { 253 | t.Col += 1 254 | str, err := t.tokenizeSingleQuotedString() 255 | if err != nil { 256 | return ILLEGAL, "", err 257 | } 258 | return NationalStringLiteral, str, nil 259 | } 260 | s := t.tokenizeWord('N') 261 | v := MakeKeyword(s, 0) 262 | return SQLKeyword, v, nil 263 | 264 | case t.Dialect.IsIdentifierStart(r): 265 | t.Scanner.Next() 266 | s := t.tokenizeWord(r) 267 | return SQLKeyword, 
MakeKeyword(s, 0), nil 268 | 269 | case '\'' == r: 270 | s, err := t.tokenizeSingleQuotedString() 271 | if err != nil { 272 | return ILLEGAL, "", err 273 | } 274 | return SingleQuotedString, s, nil 275 | 276 | case t.Dialect.IsDelimitedIdentifierStart(r): 277 | t.Scanner.Next() 278 | end := matchingEndQuote(r) 279 | 280 | var s []rune 281 | for { 282 | n := t.Scanner.Next() 283 | if n == end { 284 | break 285 | } 286 | s = append(s, n) 287 | } 288 | t.Col += 2 + len(s) 289 | 290 | return SQLKeyword, MakeKeyword(string(s), r), nil 291 | 292 | case '0' <= r && r <= '9': 293 | var s []rune 294 | for { 295 | n := t.Scanner.Peek() 296 | if ('0' <= n && n <= '9') || n == '.' { 297 | s = append(s, n) 298 | t.Scanner.Next() 299 | } else { 300 | break 301 | } 302 | } 303 | t.Col += len(s) 304 | return Number, string(s), nil 305 | 306 | case '(' == r: 307 | t.Scanner.Next() 308 | t.Col += 1 309 | return LParen, "(", nil 310 | 311 | case ')' == r: 312 | t.Scanner.Next() 313 | t.Col += 1 314 | return RParen, ")", nil 315 | 316 | case ',' == r: 317 | t.Scanner.Next() 318 | t.Col += 1 319 | return Comma, ",", nil 320 | 321 | case '-' == r: 322 | t.Scanner.Next() 323 | 324 | if '-' == t.Scanner.Peek() { 325 | t.Scanner.Next() 326 | 327 | var s []rune 328 | for { 329 | ch := t.Scanner.Peek() 330 | if ch != scanner.EOF && ch != '\n' { 331 | t.Scanner.Next() 332 | s = append(s, ch) 333 | } else { 334 | t.Col += len(s) + 2 335 | return Comment, string(s), nil // Comment Node 336 | } 337 | } 338 | } 339 | t.Col += 1 340 | return Minus, "-", nil 341 | 342 | case '/' == r: 343 | t.Scanner.Next() 344 | 345 | if '*' == t.Scanner.Peek() { 346 | t.Scanner.Next() 347 | str, err := t.tokenizeMultilineComment() 348 | if err != nil { 349 | return ILLEGAL, str, err 350 | } 351 | return Comment, str, nil 352 | } 353 | t.Col += 1 354 | return Div, "/", nil 355 | 356 | case '+' == r: 357 | t.Scanner.Next() 358 | t.Col += 1 359 | return Plus, "+", nil 360 | case '*' == r: 361 | t.Scanner.Next() 362 
| t.Col += 1 363 | return Mult, "*", nil 364 | case '%' == r: 365 | t.Scanner.Next() 366 | t.Col += 1 367 | return Mod, "%", nil 368 | case '=' == r: 369 | t.Scanner.Next() 370 | t.Col += 1 371 | return Eq, "=", nil 372 | case '.' == r: 373 | t.Scanner.Next() 374 | t.Col += 1 375 | return Period, ".", nil 376 | 377 | case '!' == r: 378 | t.Scanner.Next() 379 | n := t.Scanner.Peek() 380 | if n == '=' { 381 | t.Scanner.Next() 382 | t.Col += 2 383 | return Neq, "!=", nil 384 | } 385 | return ILLEGAL, "", errors.Errorf("tokenizer error: illegal sequence %s%s", string(r), string(n)) 386 | 387 | case '<' == r: 388 | t.Scanner.Next() 389 | switch t.Scanner.Peek() { 390 | case '=': 391 | t.Scanner.Next() 392 | t.Col += 2 393 | return LtEq, "<=", nil 394 | case '>': 395 | t.Scanner.Next() 396 | t.Col += 2 397 | return Neq, "<>", nil 398 | default: 399 | t.Col += 1 400 | return Lt, "<", nil 401 | } 402 | case '>' == r: 403 | t.Scanner.Next() 404 | switch t.Scanner.Peek() { 405 | case '=': 406 | t.Scanner.Next() 407 | t.Col += 2 408 | return GtEq, ">=", nil 409 | default: 410 | t.Col += 1 411 | return Gt, ">", nil 412 | } 413 | case ':' == r: 414 | t.Scanner.Next() 415 | n := t.Scanner.Peek() 416 | if n == ':' { 417 | t.Scanner.Next() 418 | t.Col += 2 419 | return DoubleColon, "::", nil 420 | } 421 | t.Col += 1 422 | return Colon, ":", nil 423 | case ';' == r: 424 | t.Scanner.Next() 425 | t.Col += 1 426 | return Semicolon, ";", nil 427 | case '\\' == r: 428 | t.Scanner.Next() 429 | t.Col += 1 430 | return Backslash, "\\", nil 431 | case '[' == r: 432 | t.Scanner.Next() 433 | t.Col += 1 434 | return LBracket, "[", nil 435 | case ']' == r: 436 | t.Scanner.Next() 437 | t.Col += 1 438 | return RBracket, "]", nil 439 | case '&' == r: 440 | t.Scanner.Next() 441 | t.Col += 1 442 | return Ampersand, "&", nil 443 | case '{' == r: 444 | t.Scanner.Next() 445 | t.Col += 1 446 | return LBrace, "{", nil 447 | case '}' == r: 448 | t.Scanner.Next() 449 | t.Col += 1 450 | return RBrace, "}", 
nil 451 | case scanner.EOF == r: 452 | return ILLEGAL, "", io.EOF 453 | default: 454 | t.Scanner.Next() 455 | t.Col += 1 456 | return Char, string(r), nil 457 | } 458 | } 459 | 460 | func (t *Tokenizer) tokenizeWord(f rune) string { 461 | var builder strings.Builder 462 | builder.WriteRune(f) 463 | for { 464 | r := t.Scanner.Peek() 465 | if t.Dialect.IsIdentifierPart(r) { 466 | t.Scanner.Next() 467 | builder.WriteRune(r) 468 | } else { 469 | break 470 | } 471 | } 472 | 473 | str := builder.String() 474 | t.Col += len(str) 475 | return str 476 | } 477 | 478 | func (t *Tokenizer) tokenizeSingleQuotedString() (string, error) { 479 | var builder strings.Builder 480 | t.Scanner.Next() 481 | for { 482 | n := t.Scanner.Peek() 483 | if n == '\'' { 484 | t.Scanner.Next() 485 | if t.Scanner.Peek() == '\'' { 486 | // str = append(str, '\'') 487 | builder.WriteRune('\'') 488 | t.Scanner.Next() 489 | } else { 490 | break 491 | } 492 | continue 493 | } 494 | if n == scanner.EOF { 495 | return "", errors.Errorf("unclosed single quoted string: %s at %+v", builder.String(), t.Pos()) 496 | } 497 | 498 | t.Scanner.Next() 499 | builder.WriteRune(n) 500 | // str = append(str, n) 501 | } 502 | str := builder.String() 503 | t.Col += 2 + len(str) 504 | 505 | return str, nil 506 | } 507 | 508 | func (t *Tokenizer) tokenizeMultilineComment() (string, error) { 509 | var str []rune 510 | var mayBeClosingComment bool 511 | t.Col += 2 512 | for { 513 | n := t.Scanner.Next() 514 | 515 | if n == '\r' { 516 | if t.Scanner.Peek() == '\n' { 517 | t.Scanner.Next() 518 | } 519 | t.Col = 1 520 | t.Line += 1 521 | } else if n == '\n' { 522 | t.Col = 1 523 | t.Line += 1 524 | } else if n == scanner.EOF { 525 | return "", errors.Errorf("unclosed multiline comment: %s at %+v", string(str), t.Pos()) 526 | } else { 527 | t.Col += 1 528 | } 529 | 530 | if mayBeClosingComment { 531 | if n == '/' { 532 | break 533 | } else { 534 | str = append(str, n) 535 | } 536 | } 537 | mayBeClosingComment = n == '*' 538 | 
if !mayBeClosingComment { 539 | str = append(str, n) 540 | } 541 | } 542 | 543 | return string(str), nil 544 | } 545 | -------------------------------------------------------------------------------- /sqltoken/tokenizer_test.go: -------------------------------------------------------------------------------- 1 | package sqltoken 2 | 3 | import ( 4 | "bytes" 5 | "fmt" 6 | "reflect" 7 | "strings" 8 | "testing" 9 | 10 | "github.com/google/go-cmp/cmp" 11 | 12 | "github.com/akito0107/xsqlparser/dialect" 13 | ) 14 | 15 | func TestTokenizer_Tokenize(t *testing.T) { 16 | cases := []struct { 17 | name string 18 | in string 19 | out []*Token 20 | }{ 21 | { 22 | name: "whitespace", 23 | in: " ", 24 | out: []*Token{ 25 | { 26 | Kind: Whitespace, 27 | Value: " ", 28 | From: Pos{Line: 1, Col: 1}, 29 | To: Pos{Line: 1, Col: 2}, 30 | }, 31 | }, 32 | }, 33 | { 34 | name: "whitespace and new line", 35 | in: ` 36 | `, 37 | out: []*Token{ 38 | { 39 | Kind: Whitespace, 40 | Value: "\n", 41 | From: Pos{Line: 1, Col: 1}, 42 | To: Pos{Line: 2, Col: 1}, 43 | }, 44 | { 45 | Kind: Whitespace, 46 | Value: " ", 47 | From: Pos{Line: 2, Col: 1}, 48 | To: Pos{Line: 2, Col: 2}, 49 | }, 50 | }, 51 | }, 52 | { 53 | name: "whitespace and tab", 54 | in: "\r\n ", 55 | out: []*Token{ 56 | { 57 | Kind: Whitespace, 58 | Value: "\n", 59 | From: Pos{Line: 1, Col: 1}, 60 | To: Pos{Line: 2, Col: 1}, 61 | }, 62 | { 63 | Kind: Whitespace, 64 | Value: "\t", 65 | From: Pos{Line: 2, Col: 1}, 66 | To: Pos{Line: 2, Col: 5}, 67 | }, 68 | }, 69 | }, 70 | { 71 | name: "N string", 72 | in: "N'string'", 73 | out: []*Token{ 74 | { 75 | Kind: NationalStringLiteral, 76 | Value: "string", 77 | From: Pos{Line: 1, Col: 1}, 78 | To: Pos{Line: 1, Col: 10}, 79 | }, 80 | }, 81 | }, 82 | { 83 | name: "N string with keyword", 84 | in: "N'string' NOT", 85 | out: []*Token{ 86 | { 87 | Kind: NationalStringLiteral, 88 | Value: "string", 89 | From: Pos{Line: 1, Col: 1}, 90 | To: Pos{Line: 1, Col: 10}, 91 | }, 92 | { 93 | Kind: 
Whitespace, 94 | Value: " ", 95 | From: Pos{Line: 1, Col: 10}, 96 | To: Pos{Line: 1, Col: 11}, 97 | }, 98 | { 99 | Kind: SQLKeyword, 100 | Value: &SQLWord{ 101 | Value: "NOT", 102 | Keyword: "NOT", 103 | }, 104 | From: Pos{Line: 1, Col: 11}, 105 | To: Pos{Line: 1, Col: 14}, 106 | }, 107 | }, 108 | }, 109 | { 110 | name: "Ident", 111 | in: "select", 112 | out: []*Token{ 113 | { 114 | Kind: SQLKeyword, 115 | Value: &SQLWord{ 116 | Value: "select", 117 | Keyword: "SELECT", 118 | }, 119 | From: Pos{Line: 1, Col: 1}, 120 | To: Pos{Line: 1, Col: 7}, 121 | }, 122 | }, 123 | }, 124 | { 125 | name: "single quote string", 126 | in: "'test'", 127 | out: []*Token{ 128 | { 129 | Kind: SingleQuotedString, 130 | Value: "test", 131 | From: Pos{Line: 1, Col: 1}, 132 | To: Pos{Line: 1, Col: 7}, 133 | }, 134 | }, 135 | }, 136 | { 137 | name: "quoted string", 138 | in: "\"SELECT\"", 139 | out: []*Token{ 140 | { 141 | Kind: SQLKeyword, 142 | Value: &SQLWord{ 143 | Value: "SELECT", 144 | Keyword: "SELECT", 145 | QuoteStyle: '"', 146 | }, 147 | From: Pos{Line: 1, Col: 1}, 148 | To: Pos{Line: 1, Col: 9}, 149 | }, 150 | }, 151 | }, 152 | { 153 | name: "parents with number", 154 | in: "(123),", 155 | out: []*Token{ 156 | { 157 | Kind: LParen, 158 | Value: "(", 159 | From: Pos{Line: 1, Col: 1}, 160 | To: Pos{Line: 1, Col: 2}, 161 | }, 162 | { 163 | Kind: Number, 164 | Value: "123", 165 | From: Pos{Line: 1, Col: 2}, 166 | To: Pos{Line: 1, Col: 5}, 167 | }, 168 | { 169 | Kind: RParen, 170 | Value: ")", 171 | From: Pos{Line: 1, Col: 5}, 172 | To: Pos{Line: 1, Col: 6}, 173 | }, 174 | { 175 | Kind: Comma, 176 | Value: ",", 177 | From: Pos{Line: 1, Col: 6}, 178 | To: Pos{Line: 1, Col: 7}, 179 | }, 180 | }, 181 | }, 182 | { 183 | name: "minus comment", 184 | in: "-- test", 185 | out: []*Token{ 186 | { 187 | Kind: Comment, 188 | Value: " test", 189 | From: Pos{Line: 1, Col: 1}, 190 | To: Pos{Line: 1, Col: 8}, 191 | }, 192 | }, 193 | }, 194 | { 195 | name: "minus operator", 196 | in: "1-3", 197 | 
out: []*Token{ 198 | { 199 | Kind: Number, 200 | Value: "1", 201 | From: Pos{Line: 1, Col: 1}, 202 | To: Pos{Line: 1, Col: 2}, 203 | }, 204 | { 205 | Kind: Minus, 206 | Value: "-", 207 | From: Pos{Line: 1, Col: 2}, 208 | To: Pos{Line: 1, Col: 3}, 209 | }, 210 | { 211 | Kind: Number, 212 | Value: "3", 213 | From: Pos{Line: 1, Col: 3}, 214 | To: Pos{Line: 1, Col: 4}, 215 | }, 216 | }, 217 | }, 218 | { 219 | name: "/* comment", 220 | in: `/* test 221 | multiline 222 | comment */`, 223 | out: []*Token{ 224 | { 225 | Kind: Comment, 226 | Value: " test\nmultiline\ncomment ", 227 | From: Pos{Line: 1, Col: 1}, 228 | To: Pos{Line: 3, Col: 11}, 229 | }, 230 | }, 231 | }, 232 | { 233 | name: "operators", 234 | in: "1/1*1+1%1=1.1-.", 235 | out: []*Token{ 236 | { 237 | Kind: Number, 238 | Value: "1", 239 | From: Pos{Line: 1, Col: 1}, 240 | To: Pos{Line: 1, Col: 2}, 241 | }, 242 | { 243 | Kind: Div, 244 | Value: "/", 245 | From: Pos{Line: 1, Col: 2}, 246 | To: Pos{Line: 1, Col: 3}, 247 | }, 248 | { 249 | Kind: Number, 250 | Value: "1", 251 | From: Pos{Line: 1, Col: 3}, 252 | To: Pos{Line: 1, Col: 4}, 253 | }, 254 | { 255 | Kind: Mult, 256 | Value: "*", 257 | From: Pos{Line: 1, Col: 4}, 258 | To: Pos{Line: 1, Col: 5}, 259 | }, 260 | { 261 | Kind: Number, 262 | Value: "1", 263 | From: Pos{Line: 1, Col: 5}, 264 | To: Pos{Line: 1, Col: 6}, 265 | }, 266 | { 267 | Kind: Plus, 268 | Value: "+", 269 | From: Pos{Line: 1, Col: 6}, 270 | To: Pos{Line: 1, Col: 7}, 271 | }, 272 | { 273 | Kind: Number, 274 | Value: "1", 275 | From: Pos{Line: 1, Col: 7}, 276 | To: Pos{Line: 1, Col: 8}, 277 | }, 278 | { 279 | Kind: Mod, 280 | Value: "%", 281 | From: Pos{Line: 1, Col: 8}, 282 | To: Pos{Line: 1, Col: 9}, 283 | }, 284 | { 285 | Kind: Number, 286 | Value: "1", 287 | From: Pos{Line: 1, Col: 9}, 288 | To: Pos{Line: 1, Col: 10}, 289 | }, 290 | { 291 | Kind: Eq, 292 | Value: "=", 293 | From: Pos{Line: 1, Col: 10}, 294 | To: Pos{Line: 1, Col: 11}, 295 | }, 296 | { 297 | Kind: Number, 298 | Value: "1.1", 
299 | From: Pos{Line: 1, Col: 11}, 300 | To: Pos{Line: 1, Col: 14}, 301 | }, 302 | { 303 | Kind: Minus, 304 | Value: "-", 305 | From: Pos{Line: 1, Col: 14}, 306 | To: Pos{Line: 1, Col: 15}, 307 | }, 308 | { 309 | Kind: Period, 310 | Value: ".", 311 | From: Pos{Line: 1, Col: 15}, 312 | To: Pos{Line: 1, Col: 16}, 313 | }, 314 | }, 315 | }, 316 | { 317 | name: "Neq", 318 | in: "1!=2", 319 | out: []*Token{ 320 | { 321 | Kind: Number, 322 | Value: "1", 323 | From: Pos{Line: 1, Col: 1}, 324 | To: Pos{Line: 1, Col: 2}, 325 | }, 326 | { 327 | Kind: Neq, 328 | Value: "!=", 329 | From: Pos{Line: 1, Col: 2}, 330 | To: Pos{Line: 1, Col: 4}, 331 | }, 332 | { 333 | Kind: Number, 334 | Value: "2", 335 | From: Pos{Line: 1, Col: 4}, 336 | To: Pos{Line: 1, Col: 5}, 337 | }, 338 | }, 339 | }, 340 | { 341 | name: "Lts", 342 | in: "<<=<>", 343 | out: []*Token{ 344 | { 345 | Kind: Lt, 346 | Value: "<", 347 | From: Pos{Line: 1, Col: 1}, 348 | To: Pos{Line: 1, Col: 2}, 349 | }, 350 | { 351 | Kind: LtEq, 352 | Value: "<=", 353 | From: Pos{Line: 1, Col: 2}, 354 | To: Pos{Line: 1, Col: 4}, 355 | }, 356 | { 357 | Kind: Neq, 358 | Value: "<>", 359 | From: Pos{Line: 1, Col: 4}, 360 | To: Pos{Line: 1, Col: 6}, 361 | }, 362 | }, 363 | }, 364 | { 365 | name: "Gts", 366 | in: ">>=", 367 | out: []*Token{ 368 | { 369 | Kind: Gt, 370 | Value: ">", 371 | From: Pos{Line: 1, Col: 1}, 372 | To: Pos{Line: 1, Col: 2}, 373 | }, 374 | { 375 | Kind: GtEq, 376 | Value: ">=", 377 | From: Pos{Line: 1, Col: 2}, 378 | To: Pos{Line: 1, Col: 4}, 379 | }, 380 | }, 381 | }, 382 | { 383 | name: "colons", 384 | in: ":1::1;", 385 | out: []*Token{ 386 | { 387 | Kind: Colon, 388 | Value: ":", 389 | From: Pos{Line: 1, Col: 1}, 390 | To: Pos{Line: 1, Col: 2}, 391 | }, 392 | { 393 | Kind: Number, 394 | Value: "1", 395 | From: Pos{Line: 1, Col: 2}, 396 | To: Pos{Line: 1, Col: 3}, 397 | }, 398 | { 399 | Kind: DoubleColon, 400 | Value: "::", 401 | From: Pos{Line: 1, Col: 3}, 402 | To: Pos{Line: 1, Col: 5}, 403 | }, 404 | { 405 | 
Kind: Number, 406 | Value: "1", 407 | From: Pos{Line: 1, Col: 5}, 408 | To: Pos{Line: 1, Col: 6}, 409 | }, 410 | { 411 | Kind: Semicolon, 412 | Value: ";", 413 | From: Pos{Line: 1, Col: 6}, 414 | To: Pos{Line: 1, Col: 7}, 415 | }, 416 | }, 417 | }, 418 | { 419 | name: "others", 420 | in: "\\[{&}]", 421 | out: []*Token{ 422 | { 423 | Kind: Backslash, 424 | Value: "\\", 425 | From: Pos{Line: 1, Col: 1}, 426 | To: Pos{Line: 1, Col: 2}, 427 | }, 428 | { 429 | Kind: LBracket, 430 | Value: "[", 431 | From: Pos{Line: 1, Col: 2}, 432 | To: Pos{Line: 1, Col: 3}, 433 | }, 434 | { 435 | Kind: LBrace, 436 | Value: "{", 437 | From: Pos{Line: 1, Col: 3}, 438 | To: Pos{Line: 1, Col: 4}, 439 | }, 440 | { 441 | Kind: Ampersand, 442 | Value: "&", 443 | From: Pos{Line: 1, Col: 4}, 444 | To: Pos{Line: 1, Col: 5}, 445 | }, 446 | { 447 | Kind: RBrace, 448 | Value: "}", 449 | From: Pos{Line: 1, Col: 5}, 450 | To: Pos{Line: 1, Col: 6}, 451 | }, 452 | { 453 | Kind: RBracket, 454 | Value: "]", 455 | From: Pos{Line: 1, Col: 6}, 456 | To: Pos{Line: 1, Col: 7}, 457 | }, 458 | }, 459 | }, 460 | } 461 | 462 | for _, c := range cases { 463 | t.Run(c.name, func(t *testing.T) { 464 | src := strings.NewReader(c.in) 465 | tokenizer := NewTokenizer(src, &dialect.GenericSQLDialect{}) 466 | 467 | tok, err := tokenizer.Tokenize() 468 | if err != nil { 469 | t.Errorf("should be no error %v", err) 470 | } 471 | 472 | if len(tok) != len(c.out) { 473 | t.Fatalf("should be same length but %d, %d", len(tok), len(c.out)) 474 | } 475 | 476 | for i := 0; i < len(tok); i++ { 477 | if tok[i].Kind != c.out[i].Kind { 478 | t.Errorf("%d, expected sqltoken: %d, but got %d", i, c.out[i].Kind, tok[i].Kind) 479 | } 480 | if !reflect.DeepEqual(tok[i].Value, c.out[i].Value) { 481 | t.Errorf("%d, expected value: %+v, but got %+v", i, c.out[i].Value, tok[i].Value) 482 | } 483 | if !reflect.DeepEqual(tok[i].From, c.out[i].From) { 484 | t.Errorf("%d, expected value: %+v, but got %+v", i, c.out[i].From, tok[i].From) 485 | } 486 
| if !reflect.DeepEqual(tok[i].To, c.out[i].To) { 487 | t.Errorf("%d, expected value: %+v, but got %+v", i, c.out[i].To, tok[i].To) 488 | } 489 | } 490 | }) 491 | } 492 | } 493 | 494 | func TestTokenizer_Pos(t *testing.T) { 495 | t.Run("operators", func(t *testing.T) { 496 | cases := []struct { 497 | operator string 498 | add int 499 | }{ 500 | { 501 | operator: "+", 502 | }, 503 | { 504 | operator: "-", 505 | }, 506 | { 507 | operator: "%", 508 | }, 509 | { 510 | operator: "*", 511 | }, 512 | { 513 | operator: "/", 514 | }, 515 | { 516 | operator: ">", 517 | }, 518 | { 519 | operator: "=", 520 | }, 521 | { 522 | operator: "<", 523 | }, 524 | { 525 | operator: "<=", 526 | add: 1, 527 | }, 528 | { 529 | operator: "<>", 530 | add: 1, 531 | }, 532 | { 533 | operator: ">=", 534 | add: 1, 535 | }, 536 | } 537 | 538 | for _, c := range cases { 539 | t.Run(c.operator, func(t *testing.T) { 540 | src := fmt.Sprintf("1 %s 1", c.operator) 541 | tokenizer := NewTokenizer(bytes.NewBufferString(src), &dialect.GenericSQLDialect{}) 542 | 543 | if _, err := tokenizer.Tokenize(); err != nil { 544 | t.Fatal(err) 545 | } 546 | 547 | if d := cmp.Diff(tokenizer.Pos(), Pos{Line: 1, Col: 6 + c.add}); d != "" { 548 | t.Errorf("must be same but diff: %s", d) 549 | } 550 | }) 551 | } 552 | }) 553 | t.Run("other expressions", func(t *testing.T) { 554 | cases := []struct { 555 | name string 556 | src string 557 | expect Pos 558 | }{ 559 | { 560 | name: "multiline ", 561 | src: `1+1 562 | asdf`, 563 | expect: Pos{Line: 2, Col: 5}, 564 | }, 565 | { 566 | name: "single line comment", 567 | src: `-- comments`, 568 | expect: Pos{Line: 1, Col: 12}, 569 | }, 570 | { 571 | name: "statements", 572 | src: `select count(id) from account`, 573 | expect: Pos{Line: 1, Col: 30}, 574 | }, 575 | { 576 | name: "multiline statements", 577 | src: `select count(id) 578 | from account 579 | where name like '%test%'`, 580 | expect: Pos{Line: 3, Col: 25}, 581 | }, 582 | { 583 | name: "multiline comment", 584 | src: 
`/* 585 | test comment 586 | test comment 587 | */`, 588 | expect: Pos{Line: 4, Col: 3}, 589 | }, 590 | { 591 | name: "single line comment", 592 | src: "/* asdf */", 593 | expect: Pos{Line: 1, Col: 11}, 594 | }, 595 | { 596 | name: "comment inside sql", 597 | src: "select * from /* test table */ test_table where id != 123", 598 | expect: Pos{Line: 1, Col: 58}, 599 | }, 600 | } 601 | 602 | for _, c := range cases { 603 | t.Run(c.name, func(t *testing.T) { 604 | tokenizer := NewTokenizer(bytes.NewBufferString(c.src), &dialect.GenericSQLDialect{}) 605 | 606 | if _, err := tokenizer.Tokenize(); err != nil { 607 | t.Fatal(err) 608 | } 609 | 610 | if d := cmp.Diff(tokenizer.Pos(), c.expect); d != "" { 611 | t.Errorf("must be same but diff: %s", d) 612 | } 613 | }) 614 | } 615 | }) 616 | 617 | t.Run("illegal cases", func(t *testing.T) { 618 | cases := []struct { 619 | name string 620 | src string 621 | }{ 622 | { 623 | name: "incomplete quoted string", 624 | src: "'test", 625 | }, 626 | { 627 | name: "unclosed multiline comment", 628 | src: ` 629 | /* test 630 | test 631 | `, 632 | }, 633 | } 634 | 635 | for _, c := range cases { 636 | t.Run(c.name, func(t *testing.T) { 637 | tokenizer := NewTokenizer(bytes.NewBufferString(c.src), &dialect.GenericSQLDialect{}) 638 | 639 | _, err := tokenizer.Tokenize() 640 | if err == nil { 641 | t.Errorf("must be error but blank") 642 | } 643 | t.Logf("%+v", err) 644 | 645 | }) 646 | } 647 | }) 648 | } 649 | 650 | func BenchmarkTokenizer_Tokenize(b *testing.B) { 651 | cases := []struct { 652 | name string 653 | src string 654 | }{ 655 | { 656 | name: "select", 657 | src: `SELECT COUNT(customer_id), country 658 | FROM customers 659 | GROUP BY country 660 | HAVING COUNT(customer_id) > 3`, 661 | }, 662 | { 663 | name: "complex select", 664 | src: `SELECT start_terminal, 665 | start_time, 666 | duration_seconds, 667 | ROW_NUMBER() OVER (ORDER BY start_time) 668 | AS row_number 669 | FROM tutorial.dc_bikeshare_q1_2012 670 | WHERE start_time < 
'2012-01-08'`, 671 | }, 672 | { 673 | name: "insert", 674 | src: `INSERT INTO tbl_name (a,b,c) VALUES(1,2,3),(4,5,6),(7,8,9);`, 675 | }, 676 | 677 | { 678 | 679 | name: "multi line comment", 680 | src: ` 681 | create table account ( 682 | account_id serial primary key, --aaa 683 | /*bbb*/ 684 | name varchar(255) not null, 685 | email /*ccc*/ varchar(255) unique not null --ddd 686 | ); 687 | 688 | --eee 689 | 690 | /*fff 691 | ggg 692 | */ 693 | select 1 from test; --hhh 694 | /*jjj*/ --kkk 695 | select 1 from test; /*lll*/ --mmm 696 | --nnn 697 | `, 698 | }, 699 | } 700 | 701 | for _, c := range cases { 702 | b.Run(c.name, func(b *testing.B) { 703 | b.ResetTimer() 704 | 705 | for i := 0; i < b.N; i++ { 706 | in := bytes.NewBufferString(c.src) 707 | tokenizer := NewTokenizer(in, &dialect.GenericSQLDialect{}) 708 | 709 | if _, err := tokenizer.Tokenize(); err != nil { 710 | b.Fatal(err) 711 | } 712 | } 713 | }) 714 | } 715 | } 716 | 717 | func BenchmarkTokenizer_Tokenize_WithoutComment(b *testing.B) { 718 | cases := []struct { 719 | name string 720 | src string 721 | }{ 722 | { 723 | name: "select", 724 | src: `SELECT COUNT(customer_id), country 725 | FROM customers 726 | GROUP BY country 727 | HAVING COUNT(customer_id) > 3`, 728 | }, 729 | { 730 | name: "complex select", 731 | src: `SELECT start_terminal, 732 | start_time, 733 | duration_seconds, 734 | ROW_NUMBER() OVER (ORDER BY start_time) 735 | AS row_number 736 | FROM tutorial.dc_bikeshare_q1_2012 737 | WHERE start_time < '2012-01-08'`, 738 | }, 739 | { 740 | name: "insert", 741 | src: `INSERT INTO tbl_name (a,b,c) VALUES(1,2,3),(4,5,6),(7,8,9);`, 742 | }, 743 | } 744 | 745 | for _, c := range cases { 746 | b.Run(c.name, func(b *testing.B) { 747 | b.ResetTimer() 748 | 749 | for i := 0; i < b.N; i++ { 750 | in := bytes.NewBufferString(c.src) 751 | tokenizer := NewTokenizerWithOptions(in, Dialect(&dialect.GenericSQLDialect{}), DisableParseComment()) 752 | 753 | if _, err := tokenizer.Tokenize(); err != nil { 754 | 
b.Fatal(err) 755 | } 756 | } 757 | }) 758 | } 759 | } 760 | -------------------------------------------------------------------------------- /testhelper.go: -------------------------------------------------------------------------------- 1 | package xsqlparser 2 | 3 | import ( 4 | "reflect" 5 | "unicode" 6 | 7 | "github.com/google/go-cmp/cmp" 8 | ) 9 | 10 | var IgnoreMarker = cmp.FilterPath(func(paths cmp.Path) bool { 11 | s := paths.Last().Type() 12 | name := s.Name() 13 | r := []rune(name) 14 | return s.Kind() == reflect.Struct && len(r) > 0 && unicode.IsLower(r[0]) 15 | }, cmp.Ignore()) 16 | 17 | func CompareWithoutMarker(a, b interface{}) string { 18 | return cmp.Diff(a, b, IgnoreMarker) 19 | } -------------------------------------------------------------------------------- /tools/genmark/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "bytes" 5 | "flag" 6 | "fmt" 7 | "go/format" 8 | "io/ioutil" 9 | "log" 10 | "os" 11 | "regexp" 12 | "strings" 13 | "unicode" 14 | ) 15 | 16 | func main() { 17 | log.SetFlags(0) 18 | log.SetPrefix("genmark: ") 19 | 20 | if err := run(); err != nil { 21 | if _, ok := err.(*flagError); ok { 22 | flag.Usage() 23 | } 24 | log.Fatal(err) 25 | } 26 | } 27 | 28 | type flagError struct { 29 | message string 30 | } 31 | 32 | func (e *flagError) Error() string { 33 | return e.message 34 | } 35 | 36 | func FlagError(format string, a ...interface{}) *flagError { 37 | return &flagError{message: fmt.Sprintf(format, a...)} 38 | } 39 | 40 | var snakeRegex = regexp.MustCompile("(^|[a-z])([A-Z])") 41 | 42 | func toSnake(s string) string { 43 | s = strings.Replace(s, "SQL", "Sql", -1) 44 | return snakeRegex.ReplaceAllStringFunc(s, func(s string) string { 45 | r := []rune(s) 46 | if len(r) == 1 { 47 | r[0] = unicode.ToLower(r[0]) 48 | } else { 49 | r = append(r[:1], '_', unicode.ToLower(r[1])) 50 | } 51 | return string(r) 52 | }) 53 | } 54 | 55 | func toPrivate(s string) 
string { 56 | if strings.HasPrefix(s, "SQL") { 57 | s = strings.Replace(s, "SQL", "sql", 1) 58 | } 59 | r := []rune(s) 60 | r[0] = unicode.ToLower(r[0]) 61 | return string(r) 62 | } 63 | 64 | func run() error { 65 | var flags struct { 66 | MarkerTypeName string 67 | OutputName string 68 | Embedded string 69 | Package string 70 | } 71 | 72 | flag.StringVar(&flags.MarkerTypeName, "t", "", "marker interface type name (required)") 73 | flag.StringVar(&flags.OutputName, "o", "", "output filename (automatically add '_gen.go')") 74 | flag.StringVar(&flags.Embedded, "e", "", "embedded struct list (comma separated)") 75 | flag.StringVar(&flags.Package, "pkg", os.Getenv("GOPACKAGE"), "package name") 76 | flag.Parse() 77 | 78 | markerTypeName := flags.MarkerTypeName 79 | if markerTypeName == "" { 80 | return FlagError("-t is must be required") 81 | } 82 | if !unicode.IsUpper([]rune(markerTypeName)[0]) { 83 | return FlagError("-t is must be public") 84 | } 85 | implTypeName := toPrivate(markerTypeName) 86 | 87 | outputName := flags.OutputName 88 | if outputName == "" { 89 | outputName = toSnake(markerTypeName) 90 | } 91 | outputName += "_gen.go" 92 | 93 | var embedded []string 94 | if len(flags.Embedded) != 0 { 95 | embedded = strings.Split(flags.Embedded, ",") 96 | } 97 | 98 | buf := &bytes.Buffer{} 99 | fmt.Fprintf(buf, "package %s\n", flags.Package) 100 | fmt.Fprintf(buf, "// Code generated by genmark. 
DO NOT EDIT.\n\n") 101 | fmt.Fprintf(buf, "type %s interface {\n", markerTypeName) 102 | fmt.Fprintf(buf, "%sMarker()\n", implTypeName) 103 | for _, e := range embedded { 104 | fmt.Fprintln(buf, e) 105 | } 106 | fmt.Fprintf(buf, "}\n") 107 | fmt.Fprintf(buf, "type %s struct {}\n", implTypeName) 108 | fmt.Fprintf(buf, "func (%s) %sMarker() {}\n", implTypeName, implTypeName) 109 | 110 | src, err := format.Source(buf.Bytes()) 111 | if err != nil { 112 | return fmt.Errorf("failed to format source code: %s", err.Error()) 113 | } 114 | 115 | err = ioutil.WriteFile(outputName, src, 0666) 116 | if err != nil { 117 | return fmt.Errorf("failed to write generate code: %s", err.Error()) 118 | } 119 | return nil 120 | } 121 | -------------------------------------------------------------------------------- /tools/genmark/main_test.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "testing" 5 | ) 6 | 7 | func TestToSnake(t *testing.T) { 8 | ts := []struct { 9 | in string 10 | expected string 11 | } { 12 | {in: "Foo", expected: "foo"}, 13 | {in: "FooBar", expected: "foo_bar"}, 14 | {in: "FooB", expected: "foo_b"}, 15 | {in: "SQLBar", expected: "sql_bar"}, 16 | {in: "BarSQL", expected: "bar_sql"}, 17 | } 18 | for _, tc := range ts { 19 | got := toSnake(tc.in) 20 | if got != tc.expected { 21 | t.Errorf("unexpected snake case. expected: %v, but got: %v", tc.expected, got) 22 | } 23 | } 24 | } 25 | --------------------------------------------------------------------------------