├── .gitignore ├── .travis.yml ├── CONTRIBUTORS.md ├── LICENSE.md ├── Makefile ├── README.md ├── analyzer.go ├── analyzer_test.go ├── ast.go ├── ast_test.go ├── comments.go ├── comments_test.go ├── dependency ├── bytes2 │ ├── buffer.go │ └── buffer_test.go ├── hack │ ├── hack.go │ └── hack_test.go ├── querypb │ └── query.pb.go └── sqltypes │ ├── bind_variables.go │ ├── bind_variables_test.go │ ├── plan_value.go │ ├── plan_value_test.go │ ├── testing.go │ ├── type.go │ ├── type_test.go │ ├── value.go │ └── value_test.go ├── encodable.go ├── encodable_test.go ├── github_test.go ├── impossible_query.go ├── normalizer.go ├── normalizer_test.go ├── parse_next_test.go ├── parse_test.go ├── parsed_query.go ├── parsed_query_test.go ├── patches ├── bytes2.patch ├── querypb.patch ├── sqlparser.patch └── sqltypes.patch ├── precedence_test.go ├── redact_query.go ├── redact_query_test.go ├── sql.go ├── sql.y ├── token.go ├── token_test.go └── tracked_buffer.go /.gitignore: -------------------------------------------------------------------------------- 1 | y.output 2 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | dist: trusty 2 | sudo: false 3 | 4 | language: go 5 | go: 6 | - 1.6.x 7 | - 1.7.x 8 | - 1.8.x 9 | - 1.9.x 10 | 11 | before_install: 12 | - go get github.com/mattn/goveralls 13 | - go get golang.org/x/tools/cmd/cover 14 | - go install golang.org/x/tools/cmd/goyacc 15 | 16 | script: 17 | - travis_retry $HOME/gopath/bin/goveralls -service=travis-ci 18 | -------------------------------------------------------------------------------- /CONTRIBUTORS.md: -------------------------------------------------------------------------------- 1 | This project is originally a fork of [https://github.com/youtube/vitess](https://github.com/youtube/vitess) 2 | Copyright Google Inc 3 | 4 | # Contributors 5 | Wenbin Xiao 2015 6 | Started this project and 
maintained it. 7 | 8 | Andrew Brampton 2017 9 | Merged in multiple upstream fixes/changes. -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 
34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. 
Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | MAKEFLAGS = -s 16 | 17 | sql.go: sql.y 18 | goyacc -o sql.go sql.y 19 | gofmt -w sql.go 20 | 21 | clean: 22 | rm -f y.output sql.go 23 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # sqlparser [![Build Status](https://img.shields.io/travis/xwb1989/sqlparser.svg)](https://travis-ci.org/xwb1989/sqlparser) [![Coverage](https://img.shields.io/coveralls/xwb1989/sqlparser.svg)](https://coveralls.io/github/xwb1989/sqlparser) [![Report card](https://goreportcard.com/badge/github.com/xwb1989/sqlparser)](https://goreportcard.com/report/github.com/xwb1989/sqlparser) [![GoDoc](https://godoc.org/github.com/xwb1989/sqlparser?status.svg)](https://godoc.org/github.com/xwb1989/sqlparser) 2 | 3 | Go package for parsing MySQL SQL queries. 4 | 5 | ## Notice 6 | 7 | The backbone of this repo is extracted from [vitessio/vitess](https://github.com/vitessio/vitess). 8 | 9 | Inside vitessio/vitess there is a very nicely written SQL parser. However, since it is not a self-contained package, I extracted it into this one. 10 | It is distributed under the same license (Apache 2.0) as vitessio/vitess.
11 | 12 | ## Usage 13 | 14 | ```go 15 | import ( 16 | "github.com/xwb1989/sqlparser" 17 | ) 18 | ``` 19 | 20 | Then use: 21 | 22 | ```go 23 | sql := "SELECT * FROM table WHERE a = 'abc'" 24 | stmt, err := sqlparser.Parse(sql) 25 | if err != nil { 26 | // Do something with the err 27 | } 28 | 29 | // Otherwise do something with stmt 30 | switch stmt := stmt.(type) { 31 | case *sqlparser.Select: 32 | _ = stmt 33 | case *sqlparser.Insert: 34 | } 35 | ``` 36 | 37 | Alternatively, to read many queries from an io.Reader: 38 | 39 | ```go 40 | r := strings.NewReader("INSERT INTO table1 VALUES (1, 'a'); INSERT INTO table2 VALUES (3, 4);") 41 | 42 | tokens := sqlparser.NewTokenizer(r) 43 | for { 44 | stmt, err := sqlparser.ParseNext(tokens) 45 | if err == io.EOF { 46 | break 47 | } 48 | // Do something with stmt or err. 49 | } 50 | ``` 51 | 52 | See [parse_test.go](https://github.com/xwb1989/sqlparser/blob/master/parse_test.go) for more examples, or read the [godoc](https://godoc.org/github.com/xwb1989/sqlparser). 53 | 54 | 55 | ## Porting Instructions 56 | 57 | You only need the following if you plan to keep this library up to date with [vitessio/vitess](https://github.com/vitessio/vitess). 58 | 59 | ### Keeping up to date 60 | 61 | ```bash 62 | shopt -s nullglob 63 | VITESS=${GOPATH?}/src/vitess.io/vitess/go/ 64 | XWB1989=${GOPATH?}/src/github.com/xwb1989/sqlparser/ 65 | 66 | # Create patches for everything that changed 67 | LASTIMPORT=1b7879cb91f1dfe1a2dfa06fea96e951e3a7aec5 68 | for path in ${VITESS?}/{vt/sqlparser,sqltypes,bytes2,hack}; do 69 | cd ${path} 70 | git format-patch ${LASTIMPORT?} .
71 | done; 72 | 73 | # Apply patches to the dependencies 74 | cd ${XWB1989?} 75 | git am --directory dependency -p2 ${VITESS?}/{sqltypes,bytes2,hack}/*.patch 76 | 77 | # Apply the main patches to the repo 78 | cd ${XWB1989?} 79 | git am -p4 ${VITESS?}/vt/sqlparser/*.patch 80 | 81 | # If you encounter diff failures, manually fix them with 82 | patch -p4 < .git/rebase-apply/patch 83 | ... 84 | git add name_of_files 85 | git am --continue 86 | 87 | # Cleanup 88 | rm ${VITESS?}/{sqltypes,bytes2,hack}/*.patch ${VITESS?}/vt/sqlparser/*.patch 89 | 90 | # Finally, update LASTIMPORT in this README to the last imported vitess commit. 91 | ``` 92 | 93 | ### Fresh install 94 | 95 | TODO: Change these instructions to use git to copy the files; that will make later patching easier. 96 | 97 | ```bash 98 | VITESS=${GOPATH?}/src/vitess.io/vitess/go/ 99 | XWB1989=${GOPATH?}/src/github.com/xwb1989/sqlparser/ 100 | 101 | cd ${XWB1989?} 102 | 103 | # Copy all the code 104 | cp -pr ${VITESS?}/vt/sqlparser/ . 105 | cp -pr ${VITESS?}/sqltypes dependency 106 | cp -pr ${VITESS?}/bytes2 dependency 107 | cp -pr ${VITESS?}/hack dependency 108 | 109 | # Delete some code we haven't ported 110 | rm dependency/sqltypes/arithmetic.go dependency/sqltypes/arithmetic_test.go dependency/sqltypes/event_token.go dependency/sqltypes/event_token_test.go dependency/sqltypes/proto3.go dependency/sqltypes/proto3_test.go dependency/sqltypes/query_response.go dependency/sqltypes/result.go dependency/sqltypes/result_test.go 111 | 112 | # Some automated fixes 113 | 114 | # Fix imports 115 | sed -i '.bak' 's_vitess.io/vitess/go/vt/proto/query_github.com/xwb1989/sqlparser/dependency/querypb_g' *.go dependency/sqltypes/*.go 116 | sed -i '.bak' 's_vitess.io/vitess/go/_github.com/xwb1989/sqlparser/dependency/_g' *.go dependency/sqltypes/*.go 117 | 118 | # Copy the proto, but basically drop everything we don't want 119 | cp -pr ${VITESS?}/vt/proto/query dependency/querypb 120 | 121 | sed -i '.bak' 's_.*Descriptor.*__g' dependency/querypb/*.go 122 | sed -i
'.bak' 's_.*ProtoMessage.*__g' dependency/querypb/*.go 123 | 124 | sed -i '.bak' 's/proto.CompactTextString(m)/"TODO"/g' dependency/querypb/*.go 125 | sed -i '.bak' 's/proto.EnumName/EnumName/g' dependency/querypb/*.go 126 | 127 | sed -i '.bak' 's/proto.Equal/reflect.DeepEqual/g' dependency/sqltypes/*.go 128 | 129 | # Remove the error library 130 | sed -i '.bak' 's/vterrors.Errorf([^,]*, /fmt.Errorf(/g' *.go dependency/sqltypes/*.go 131 | sed -i '.bak' 's/vterrors.New([^,]*, /errors.New(/g' *.go dependency/sqltypes/*.go 132 | ``` 133 | 134 | ### Testing 135 | 136 | ```bash 137 | VITESS=${GOPATH?}/src/vitess.io/vitess/go/ 138 | XWB1989=${GOPATH?}/src/github.com/xwb1989/sqlparser/ 139 | 140 | cd ${XWB1989?} 141 | 142 | # Test, fix and repeat 143 | go test ./... 144 | 145 | # Finally make some diffs (for later reference) 146 | diff -u ${VITESS?}/sqltypes/ ${XWB1989?}/dependency/sqltypes/ > ${XWB1989?}/patches/sqltypes.patch 147 | diff -u ${VITESS?}/bytes2/ ${XWB1989?}/dependency/bytes2/ > ${XWB1989?}/patches/bytes2.patch 148 | diff -u ${VITESS?}/vt/proto/query/ ${XWB1989?}/dependency/querypb/ > ${XWB1989?}/patches/querypb.patch 149 | diff -u ${VITESS?}/vt/sqlparser/ ${XWB1989?}/ > ${XWB1989?}/patches/sqlparser.patch 150 | ``` -------------------------------------------------------------------------------- /analyzer.go: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2017 Google Inc. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | */ 16 | 17 | package sqlparser 18 | 19 | // analyzer.go contains utility analysis functions. 20 | 21 | import ( 22 | "errors" 23 | "fmt" 24 | "strconv" 25 | "strings" 26 | "unicode" 27 | 28 | "github.com/xwb1989/sqlparser/dependency/sqltypes" 29 | ) 30 | 31 | // These constants are used to identify the SQL statement type. 32 | const ( 33 | StmtSelect = iota 34 | StmtStream 35 | StmtInsert 36 | StmtReplace 37 | StmtUpdate 38 | StmtDelete 39 | StmtDDL 40 | StmtBegin 41 | StmtCommit 42 | StmtRollback 43 | StmtSet 44 | StmtShow 45 | StmtUse 46 | StmtOther 47 | StmtUnknown 48 | StmtComment 49 | ) 50 | 51 | // Preview analyzes the beginning of the query using a simpler and faster 52 | // textual comparison to identify the statement type. 53 | func Preview(sql string) int { 54 | trimmed := StripLeadingComments(sql) 55 | 56 | firstWord := trimmed 57 | if end := strings.IndexFunc(trimmed, unicode.IsSpace); end != -1 { 58 | firstWord = trimmed[:end] 59 | } 60 | firstWord = strings.TrimLeftFunc(firstWord, func(r rune) bool { return !unicode.IsLetter(r) }) 61 | // Comparison is done in order of priority. 62 | loweredFirstWord := strings.ToLower(firstWord) 63 | switch loweredFirstWord { 64 | case "select": 65 | return StmtSelect 66 | case "stream": 67 | return StmtStream 68 | case "insert": 69 | return StmtInsert 70 | case "replace": 71 | return StmtReplace 72 | case "update": 73 | return StmtUpdate 74 | case "delete": 75 | return StmtDelete 76 | } 77 | // For the following statements it is not sufficient to rely 78 | // on loweredFirstWord. This is because they are not statements 79 | // in the grammar and we are relying on Preview to parse them. 80 | // For instance, we don't want: "BEGIN JUNK" to be parsed 81 | // as StmtBegin. 
82 | trimmedNoComments, _ := SplitMarginComments(trimmed) 83 | switch strings.ToLower(trimmedNoComments) { 84 | case "begin", "start transaction": 85 | return StmtBegin 86 | case "commit": 87 | return StmtCommit 88 | case "rollback": 89 | return StmtRollback 90 | } 91 | switch loweredFirstWord { 92 | case "create", "alter", "rename", "drop", "truncate": 93 | return StmtDDL 94 | case "set": 95 | return StmtSet 96 | case "show": 97 | return StmtShow 98 | case "use": 99 | return StmtUse 100 | case "analyze", "describe", "desc", "explain", "repair", "optimize": 101 | return StmtOther 102 | } 103 | if strings.Index(trimmed, "/*!") == 0 { 104 | return StmtComment 105 | } 106 | return StmtUnknown 107 | } 108 | 109 | // StmtType returns the statement type as a string 110 | func StmtType(stmtType int) string { 111 | switch stmtType { 112 | case StmtSelect: 113 | return "SELECT" 114 | case StmtStream: 115 | return "STREAM" 116 | case StmtInsert: 117 | return "INSERT" 118 | case StmtReplace: 119 | return "REPLACE" 120 | case StmtUpdate: 121 | return "UPDATE" 122 | case StmtDelete: 123 | return "DELETE" 124 | case StmtDDL: 125 | return "DDL" 126 | case StmtBegin: 127 | return "BEGIN" 128 | case StmtCommit: 129 | return "COMMIT" 130 | case StmtRollback: 131 | return "ROLLBACK" 132 | case StmtSet: 133 | return "SET" 134 | case StmtShow: 135 | return "SHOW" 136 | case StmtUse: 137 | return "USE" 138 | case StmtOther: 139 | return "OTHER" 140 | default: 141 | return "UNKNOWN" 142 | } 143 | } 144 | 145 | // IsDML returns true if the query is an INSERT, UPDATE or DELETE statement. 146 | func IsDML(sql string) bool { 147 | switch Preview(sql) { 148 | case StmtInsert, StmtReplace, StmtUpdate, StmtDelete: 149 | return true 150 | } 151 | return false 152 | } 153 | 154 | // GetTableName returns the table name from the SimpleTableExpr 155 | // only if it's a simple expression. Otherwise, it returns "". 
156 | func GetTableName(node SimpleTableExpr) TableIdent { 157 | if n, ok := node.(TableName); ok && n.Qualifier.IsEmpty() { 158 | return n.Name 159 | } 160 | // sub-select or '.' expression 161 | return NewTableIdent("") 162 | } 163 | 164 | // IsColName returns true if the Expr is a *ColName. 165 | func IsColName(node Expr) bool { 166 | _, ok := node.(*ColName) 167 | return ok 168 | } 169 | 170 | // IsValue returns true if the Expr is a string, integral or value arg. 171 | // NULL is not considered to be a value. 172 | func IsValue(node Expr) bool { 173 | switch v := node.(type) { 174 | case *SQLVal: 175 | switch v.Type { 176 | case StrVal, HexVal, IntVal, ValArg: 177 | return true 178 | } 179 | } 180 | return false 181 | } 182 | 183 | // IsNull returns true if the Expr is SQL NULL 184 | func IsNull(node Expr) bool { 185 | switch node.(type) { 186 | case *NullVal: 187 | return true 188 | } 189 | return false 190 | } 191 | 192 | // IsSimpleTuple returns true if the Expr is a ValTuple that 193 | // contains simple values or if it's a list arg. 194 | func IsSimpleTuple(node Expr) bool { 195 | switch vals := node.(type) { 196 | case ValTuple: 197 | for _, n := range vals { 198 | if !IsValue(n) { 199 | return false 200 | } 201 | } 202 | return true 203 | case ListArg: 204 | return true 205 | } 206 | // It's a subquery 207 | return false 208 | } 209 | 210 | // NewPlanValue builds a sqltypes.PlanValue from an Expr. 
211 | func NewPlanValue(node Expr) (sqltypes.PlanValue, error) { 212 | switch node := node.(type) { 213 | case *SQLVal: 214 | switch node.Type { 215 | case ValArg: 216 | return sqltypes.PlanValue{Key: string(node.Val[1:])}, nil 217 | case IntVal: 218 | n, err := sqltypes.NewIntegral(string(node.Val)) 219 | if err != nil { 220 | return sqltypes.PlanValue{}, fmt.Errorf("%v", err) 221 | } 222 | return sqltypes.PlanValue{Value: n}, nil 223 | case StrVal: 224 | return sqltypes.PlanValue{Value: sqltypes.MakeTrusted(sqltypes.VarBinary, node.Val)}, nil 225 | case HexVal: 226 | v, err := node.HexDecode() 227 | if err != nil { 228 | return sqltypes.PlanValue{}, fmt.Errorf("%v", err) 229 | } 230 | return sqltypes.PlanValue{Value: sqltypes.MakeTrusted(sqltypes.VarBinary, v)}, nil 231 | } 232 | case ListArg: 233 | return sqltypes.PlanValue{ListKey: string(node[2:])}, nil 234 | case ValTuple: 235 | pv := sqltypes.PlanValue{ 236 | Values: make([]sqltypes.PlanValue, 0, len(node)), 237 | } 238 | for _, val := range node { 239 | innerpv, err := NewPlanValue(val) 240 | if err != nil { 241 | return sqltypes.PlanValue{}, err 242 | } 243 | if innerpv.ListKey != "" || innerpv.Values != nil { 244 | return sqltypes.PlanValue{}, errors.New("unsupported: nested lists") 245 | } 246 | pv.Values = append(pv.Values, innerpv) 247 | } 248 | return pv, nil 249 | case *NullVal: 250 | return sqltypes.PlanValue{}, nil 251 | } 252 | return sqltypes.PlanValue{}, fmt.Errorf("expression is too complex '%v'", String(node)) 253 | } 254 | 255 | // StringIn is a convenience function that returns 256 | // true if str matches any of the values. 
257 | func StringIn(str string, values ...string) bool { 258 | for _, val := range values { 259 | if str == val { 260 | return true 261 | } 262 | } 263 | return false 264 | } 265 | 266 | // SetKey is the extracted key from one SetExpr 267 | type SetKey struct { 268 | Key string 269 | Scope string 270 | } 271 | 272 | // ExtractSetValues returns a map of key-value pairs 273 | // if the query is a SET statement. Values can be bool, int64 or string. 274 | // Since set variable names are case insensitive, all keys are returned 275 | // as lower case. 276 | func ExtractSetValues(sql string) (keyValues map[SetKey]interface{}, scope string, err error) { 277 | stmt, err := Parse(sql) 278 | if err != nil { 279 | return nil, "", err 280 | } 281 | setStmt, ok := stmt.(*Set) 282 | if !ok { 283 | return nil, "", fmt.Errorf("ast did not yield *sqlparser.Set: %T", stmt) 284 | } 285 | result := make(map[SetKey]interface{}) 286 | for _, expr := range setStmt.Exprs { 287 | scope := SessionStr 288 | key := expr.Name.Lowered() 289 | switch { 290 | case strings.HasPrefix(key, "@@global."): 291 | scope = GlobalStr 292 | key = strings.TrimPrefix(key, "@@global.") 293 | case strings.HasPrefix(key, "@@session."): 294 | key = strings.TrimPrefix(key, "@@session.") 295 | case strings.HasPrefix(key, "@@"): 296 | key = strings.TrimPrefix(key, "@@") 297 | } 298 | 299 | if strings.HasPrefix(expr.Name.Lowered(), "@@") { 300 | if setStmt.Scope != "" && scope != "" { 301 | return nil, "", fmt.Errorf("unsupported in set: mixed using of variable scope") 302 | } 303 | _, out := NewStringTokenizer(key).Scan() 304 | key = string(out) 305 | } 306 | 307 | setKey := SetKey{ 308 | Key: key, 309 | Scope: scope, 310 | } 311 | 312 | switch expr := expr.Expr.(type) { 313 | case *SQLVal: 314 | switch expr.Type { 315 | case StrVal: 316 | result[setKey] = strings.ToLower(string(expr.Val)) 317 | case IntVal: 318 | num, err := strconv.ParseInt(string(expr.Val), 0, 64) 319 | if err != nil { 320 | return nil, "", err 
321 | } 322 | result[setKey] = num 323 | default: 324 | return nil, "", fmt.Errorf("invalid value type: %v", String(expr)) 325 | } 326 | case BoolVal: 327 | var val int64 328 | if expr { 329 | val = 1 330 | } 331 | result[setKey] = val 332 | case *ColName: 333 | result[setKey] = expr.Name.String() 334 | case *NullVal: 335 | result[setKey] = nil 336 | case *Default: 337 | result[setKey] = "default" 338 | default: 339 | return nil, "", fmt.Errorf("invalid syntax: %s", String(expr)) 340 | } 341 | } 342 | return result, strings.ToLower(setStmt.Scope), nil 343 | } 344 | -------------------------------------------------------------------------------- /comments.go: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2017 Google Inc. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | */ 16 | 17 | package sqlparser 18 | 19 | import ( 20 | "strconv" 21 | "strings" 22 | "unicode" 23 | ) 24 | 25 | const ( 26 | // DirectiveMultiShardAutocommit is the query comment directive to allow 27 | // single round trip autocommit with a multi-shard statement. 28 | DirectiveMultiShardAutocommit = "MULTI_SHARD_AUTOCOMMIT" 29 | // DirectiveSkipQueryPlanCache skips query plan cache when set. 30 | DirectiveSkipQueryPlanCache = "SKIP_QUERY_PLAN_CACHE" 31 | // DirectiveQueryTimeout sets a query timeout in vtgate. Only supported for SELECTS. 
32 | DirectiveQueryTimeout = "QUERY_TIMEOUT_MS" 33 | ) 34 | 35 | func isNonSpace(r rune) bool { 36 | return !unicode.IsSpace(r) 37 | } 38 | 39 | // leadingCommentEnd returns the first index after all leading comments, or 40 | // 0 if there are no leading comments. 41 | func leadingCommentEnd(text string) (end int) { 42 | hasComment := false 43 | pos := 0 44 | for pos < len(text) { 45 | // Eat up any whitespace. Trailing whitespace will be considered part of 46 | // the leading comments. 47 | nextVisibleOffset := strings.IndexFunc(text[pos:], isNonSpace) 48 | if nextVisibleOffset < 0 { 49 | break 50 | } 51 | pos += nextVisibleOffset 52 | remainingText := text[pos:] 53 | 54 | // Found visible characters. Look for '/*' at the beginning 55 | // and '*/' somewhere after that. 56 | if len(remainingText) < 4 || remainingText[:2] != "/*" { 57 | break 58 | } 59 | commentLength := 4 + strings.Index(remainingText[2:], "*/") 60 | if commentLength < 4 { 61 | // Missing end comment :/ 62 | break 63 | } 64 | 65 | hasComment = true 66 | pos += commentLength 67 | } 68 | 69 | if hasComment { 70 | return pos 71 | } 72 | return 0 73 | } 74 | 75 | // trailingCommentStart returns the first index of trailing comments. 76 | // If there are no trailing comments, returns the length of the input string. 77 | func trailingCommentStart(text string) (start int) { 78 | hasComment := false 79 | reducedLen := len(text) 80 | for reducedLen > 0 { 81 | // Eat up any whitespace. Leading whitespace will be considered part of 82 | // the trailing comments. 
83 | nextReducedLen := strings.LastIndexFunc(text[:reducedLen], isNonSpace) + 1 84 | if nextReducedLen == 0 { 85 | break 86 | } 87 | reducedLen = nextReducedLen 88 | if reducedLen < 4 || text[reducedLen-2:reducedLen] != "*/" { 89 | break 90 | } 91 | 92 | // Find the beginning of the comment 93 | startCommentPos := strings.LastIndex(text[:reducedLen-2], "/*") 94 | if startCommentPos < 0 { 95 | // Badly formatted sql :/ 96 | break 97 | } 98 | 99 | hasComment = true 100 | reducedLen = startCommentPos 101 | } 102 | 103 | if hasComment { 104 | return reducedLen 105 | } 106 | return len(text) 107 | } 108 | 109 | // MarginComments holds the leading and trailing comments that surround a query. 110 | type MarginComments struct { 111 | Leading string 112 | Trailing string 113 | } 114 | 115 | // SplitMarginComments pulls out any leading or trailing comments from a raw sql query. 116 | // This function also trims leading (if there's a comment) and trailing whitespace. 117 | func SplitMarginComments(sql string) (query string, comments MarginComments) { 118 | trailingStart := trailingCommentStart(sql) 119 | leadingEnd := leadingCommentEnd(sql[:trailingStart]) 120 | comments = MarginComments{ 121 | Leading: strings.TrimLeftFunc(sql[:leadingEnd], unicode.IsSpace), 122 | Trailing: strings.TrimRightFunc(sql[trailingStart:], unicode.IsSpace), 123 | } 124 | return strings.TrimFunc(sql[leadingEnd:trailingStart], unicode.IsSpace), comments 125 | } 126 | 127 | // StripLeadingComments trims the SQL string and removes any leading comments 128 | func StripLeadingComments(sql string) string { 129 | sql = strings.TrimFunc(sql, unicode.IsSpace) 130 | 131 | for hasCommentPrefix(sql) { 132 | switch sql[0] { 133 | case '/': 134 | // Multi line comment 135 | index := strings.Index(sql, "*/") 136 | if index <= 1 { 137 | return sql 138 | } 139 | // don't strip /*! ... */ or /*!50700 ... */ 140 | if len(sql) > 2 && sql[2] == '!' 
{ 141 | return sql 142 | } 143 | sql = sql[index+2:] 144 | case '-': 145 | // Single line comment 146 | index := strings.Index(sql, "\n") 147 | if index == -1 { 148 | return sql 149 | } 150 | sql = sql[index+1:] 151 | } 152 | 153 | sql = strings.TrimFunc(sql, unicode.IsSpace) 154 | } 155 | 156 | return sql 157 | } 158 | 159 | func hasCommentPrefix(sql string) bool { 160 | return len(sql) > 1 && ((sql[0] == '/' && sql[1] == '*') || (sql[0] == '-' && sql[1] == '-')) 161 | } 162 | 163 | // ExtractMysqlComment extracts the version and SQL from a comment-only query 164 | // such as /*!50708 sql here */ 165 | func ExtractMysqlComment(sql string) (version string, innerSQL string) { 166 | sql = sql[3 : len(sql)-2] 167 | 168 | digitCount := 0 169 | endOfVersionIndex := strings.IndexFunc(sql, func(c rune) bool { 170 | digitCount++ 171 | return !unicode.IsDigit(c) || digitCount == 6 172 | }) 173 | version = sql[0:endOfVersionIndex] 174 | innerSQL = strings.TrimFunc(sql[endOfVersionIndex:], unicode.IsSpace) 175 | 176 | return version, innerSQL 177 | } 178 | 179 | const commentDirectivePreamble = "/*vt+" 180 | 181 | // CommentDirectives is the parsed representation for execution directives 182 | // conveyed in query comments 183 | type CommentDirectives map[string]interface{} 184 | 185 | // ExtractCommentDirectives parses the comment list for any execution directives 186 | // of the form: 187 | // 188 | // /*vt+ OPTION_ONE=1 OPTION_TWO OPTION_THREE=abcd */ 189 | // 190 | // It returns the map of the directive values or nil if there aren't any. 
191 | func ExtractCommentDirectives(comments Comments) CommentDirectives { 192 | if comments == nil { 193 | return nil 194 | } 195 | 196 | var vals map[string]interface{} 197 | 198 | for _, comment := range comments { 199 | commentStr := string(comment) 200 | if commentStr[0:5] != commentDirectivePreamble { 201 | continue 202 | } 203 | 204 | if vals == nil { 205 | vals = make(map[string]interface{}) 206 | } 207 | 208 | // Split on whitespace and ignore the first and last directive 209 | // since they contain the comment start/end 210 | directives := strings.Fields(commentStr) 211 | for i := 1; i < len(directives)-1; i++ { 212 | directive := directives[i] 213 | sep := strings.IndexByte(directive, '=') 214 | 215 | // No value is equivalent to a true boolean 216 | if sep == -1 { 217 | vals[directive] = true 218 | continue 219 | } 220 | 221 | strVal := directive[sep+1:] 222 | directive = directive[:sep] 223 | 224 | intVal, err := strconv.Atoi(strVal) 225 | if err == nil { 226 | vals[directive] = intVal 227 | continue 228 | } 229 | 230 | boolVal, err := strconv.ParseBool(strVal) 231 | if err == nil { 232 | vals[directive] = boolVal 233 | continue 234 | } 235 | 236 | vals[directive] = strVal 237 | } 238 | } 239 | return vals 240 | } 241 | 242 | // IsSet checks the directive map for the named directive and returns 243 | // true if the directive is set and has a true/false or 0/1 value 244 | func (d CommentDirectives) IsSet(key string) bool { 245 | if d == nil { 246 | return false 247 | } 248 | 249 | val, ok := d[key] 250 | if !ok { 251 | return false 252 | } 253 | 254 | boolVal, ok := val.(bool) 255 | if ok { 256 | return boolVal 257 | } 258 | 259 | intVal, ok := val.(int) 260 | if ok { 261 | return intVal == 1 262 | } 263 | return false 264 | } 265 | 266 | // SkipQueryPlanCacheDirective returns true if skip query plan cache directive is set to true in query. 
267 | func SkipQueryPlanCacheDirective(stmt Statement) bool { 268 | switch stmt := stmt.(type) { // directives are only inspected on SELECT, INSERT, UPDATE and DELETE statements 269 | case *Select: 270 | directives := ExtractCommentDirectives(stmt.Comments) 271 | if directives.IsSet(DirectiveSkipQueryPlanCache) { 272 | return true 273 | } 274 | case *Insert: 275 | directives := ExtractCommentDirectives(stmt.Comments) 276 | if directives.IsSet(DirectiveSkipQueryPlanCache) { 277 | return true 278 | } 279 | case *Update: 280 | directives := ExtractCommentDirectives(stmt.Comments) 281 | if directives.IsSet(DirectiveSkipQueryPlanCache) { 282 | return true 283 | } 284 | case *Delete: 285 | directives := ExtractCommentDirectives(stmt.Comments) 286 | if directives.IsSet(DirectiveSkipQueryPlanCache) { 287 | return true 288 | } 289 | default: // any other statement type never skips the plan cache 290 | return false 291 | } 292 | return false 293 | } 294 | -------------------------------------------------------------------------------- /comments_test.go: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2017 Google Inc. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License.
15 | */ 16 | 17 | package sqlparser 18 | 19 | import ( 20 | "reflect" 21 | "testing" 22 | ) 23 | 24 | func TestSplitComments(t *testing.T) { 25 | var testCases = []struct { 26 | input, outSQL, outLeadingComments, outTrailingComments string 27 | }{{ 28 | input: "/", 29 | outSQL: "/", 30 | outLeadingComments: "", 31 | outTrailingComments: "", 32 | }, { 33 | input: "*/", 34 | outSQL: "*/", 35 | outLeadingComments: "", 36 | outTrailingComments: "", 37 | }, { 38 | input: "/*/", 39 | outSQL: "/*/", 40 | outLeadingComments: "", 41 | outTrailingComments: "", 42 | }, { 43 | input: "a*/", 44 | outSQL: "a*/", 45 | outLeadingComments: "", 46 | outTrailingComments: "", 47 | }, { 48 | input: "*a*/", 49 | outSQL: "*a*/", 50 | outLeadingComments: "", 51 | outTrailingComments: "", 52 | }, { 53 | input: "**a*/", 54 | outSQL: "**a*/", 55 | outLeadingComments: "", 56 | outTrailingComments: "", 57 | }, { 58 | input: "/*b**a*/", 59 | outSQL: "", 60 | outLeadingComments: "", 61 | outTrailingComments: "/*b**a*/", 62 | }, { 63 | input: "/*a*/", 64 | outSQL: "", 65 | outLeadingComments: "", 66 | outTrailingComments: "/*a*/", 67 | }, { 68 | input: "/**/", 69 | outSQL: "", 70 | outLeadingComments: "", 71 | outTrailingComments: "/**/", 72 | }, { 73 | input: "/*b*/ /*a*/", 74 | outSQL: "", 75 | outLeadingComments: "", 76 | outTrailingComments: "/*b*/ /*a*/", 77 | }, { 78 | input: "/* before */ foo /* bar */", 79 | outSQL: "foo", 80 | outLeadingComments: "/* before */ ", 81 | outTrailingComments: " /* bar */", 82 | }, { 83 | input: "/* before1 */ /* before2 */ foo /* after1 */ /* after2 */", 84 | outSQL: "foo", 85 | outLeadingComments: "/* before1 */ /* before2 */ ", 86 | outTrailingComments: " /* after1 */ /* after2 */", 87 | }, { 88 | input: "/** before */ foo /** bar */", 89 | outSQL: "foo", 90 | outLeadingComments: "/** before */ ", 91 | outTrailingComments: " /** bar */", 92 | }, { 93 | input: "/*** before */ foo /*** bar */", 94 | outSQL: "foo", 95 | outLeadingComments: "/*** before */ ",
96 | outTrailingComments: " /*** bar */", 97 | }, { 98 | input: "/** before **/ foo /** bar **/", 99 | outSQL: "foo", 100 | outLeadingComments: "/** before **/ ", 101 | outTrailingComments: " /** bar **/", 102 | }, { 103 | input: "/*** before ***/ foo /*** bar ***/", 104 | outSQL: "foo", 105 | outLeadingComments: "/*** before ***/ ", 106 | outTrailingComments: " /*** bar ***/", 107 | }, { 108 | input: " /*** before ***/ foo /*** bar ***/ ", 109 | outSQL: "foo", 110 | outLeadingComments: "/*** before ***/ ", 111 | outTrailingComments: " /*** bar ***/", 112 | }, { 113 | input: "*** bar ***/", 114 | outSQL: "*** bar ***/", 115 | outLeadingComments: "", 116 | outTrailingComments: "", 117 | }, { 118 | input: " foo ", 119 | outSQL: "foo", 120 | outLeadingComments: "", 121 | outTrailingComments: "", 122 | }} 123 | for _, testCase := range testCases { 124 | gotSQL, gotComments := SplitMarginComments(testCase.input) 125 | gotLeadingComments, gotTrailingComments := gotComments.Leading, gotComments.Trailing 126 | 127 | if gotSQL != testCase.outSQL { 128 | t.Errorf("test input: '%s', got SQL\n%+v, want\n%+v", testCase.input, gotSQL, testCase.outSQL) 129 | } 130 | if gotLeadingComments != testCase.outLeadingComments { 131 | t.Errorf("test input: '%s', got LeadingComments\n%+v, want\n%+v", testCase.input, gotLeadingComments, testCase.outLeadingComments) 132 | } 133 | if gotTrailingComments != testCase.outTrailingComments { 134 | t.Errorf("test input: '%s', got TrailingComments\n%+v, want\n%+v", testCase.input, gotTrailingComments, testCase.outTrailingComments) 135 | } 136 | } 137 | } 138 | 139 | func TestStripLeadingComments(t *testing.T) { 140 | var testCases = []struct { 141 | input, outSQL string 142 | }{{ 143 | input: "/", 144 | outSQL: "/", 145 | }, { 146 | input: "*/", 147 | outSQL: "*/", 148 | }, { 149 | input: "/*/", 150 | outSQL: "/*/", 151 | }, { 152 | input: "/*a", 153 | outSQL: "/*a", 154 | }, { 155 | input: "/*a*", 156 | outSQL: "/*a*", 157 | }, { 158 | input:
"/*a**", 159 | outSQL: "/*a**", 160 | }, { 161 | input: "/*b**a*/", 162 | outSQL: "", 163 | }, { 164 | input: "/*a*/", 165 | outSQL: "", 166 | }, { 167 | input: "/**/", 168 | outSQL: "", 169 | }, { 170 | input: "/*!*/", 171 | outSQL: "/*!*/", 172 | }, { 173 | input: "/*!a*/", 174 | outSQL: "/*!a*/", 175 | }, { 176 | input: "/*b*/ /*a*/", 177 | outSQL: "", 178 | }, { 179 | input: `/*b*/ --foo 180 | bar`, 181 | outSQL: "bar", 182 | }, { 183 | input: "foo /* bar */", 184 | outSQL: "foo /* bar */", 185 | }, { 186 | input: "/* foo */ bar", 187 | outSQL: "bar", 188 | }, { 189 | input: "-- /* foo */ bar", 190 | outSQL: "-- /* foo */ bar", 191 | }, { 192 | input: "foo -- bar */", 193 | outSQL: "foo -- bar */", 194 | }, { 195 | input: `/* 196 | foo */ bar`, 197 | outSQL: "bar", 198 | }, { 199 | input: `-- foo bar 200 | a`, 201 | outSQL: "a", 202 | }, { 203 | input: `-- foo bar`, 204 | outSQL: "-- foo bar", 205 | }} 206 | for _, testCase := range testCases { 207 | gotSQL := StripLeadingComments(testCase.input) 208 | 209 | if gotSQL != testCase.outSQL { 210 | t.Errorf("test input: '%s', got SQL\n%+v, want\n%+v", testCase.input, gotSQL, testCase.outSQL) 211 | } 212 | } 213 | } 214 | 215 | func TestExtractMysqlComment(t *testing.T) { 216 | var testCases = []struct { 217 | input, outSQL, outVersion string 218 | }{{ 219 | input: "/*!50708SET max_execution_time=5000 */", 220 | outSQL: "SET max_execution_time=5000", 221 | outVersion: "50708", 222 | }, { 223 | input: "/*!50708 SET max_execution_time=5000*/", 224 | outSQL: "SET max_execution_time=5000", 225 | outVersion: "50708", 226 | }, { 227 | input: "/*!50708* from*/", 228 | outSQL: "* from", 229 | outVersion: "50708", 230 | }, { 231 | input: "/*!
SET max_execution_time=5000*/", 232 | outSQL: "SET max_execution_time=5000", 233 | outVersion: "", 234 | }} 235 | for _, testCase := range testCases { 236 | gotVersion, gotSQL := ExtractMysqlComment(testCase.input) 237 | 238 | if gotVersion != testCase.outVersion { 239 | t.Errorf("test input: '%s', got version\n%+v, want\n%+v", testCase.input, gotVersion, testCase.outVersion) 240 | } 241 | if gotSQL != testCase.outSQL { 242 | t.Errorf("test input: '%s', got SQL\n%+v, want\n%+v", testCase.input, gotSQL, testCase.outSQL) 243 | } 244 | } 245 | } 246 | 247 | func TestExtractCommentDirectives(t *testing.T) { 248 | var testCases = []struct { 249 | input string 250 | vals CommentDirectives 251 | }{{ 252 | input: "", 253 | vals: nil, 254 | }, { 255 | input: "/* not a vt comment */", 256 | vals: nil, 257 | }, { 258 | input: "/*vt+ */", 259 | vals: CommentDirectives{}, 260 | }, { 261 | input: "/*vt+ SINGLE_OPTION */", 262 | vals: CommentDirectives{ 263 | "SINGLE_OPTION": true, 264 | }, 265 | }, { 266 | input: "/*vt+ ONE_OPT TWO_OPT */", 267 | vals: CommentDirectives{ 268 | "ONE_OPT": true, 269 | "TWO_OPT": true, 270 | }, 271 | }, { 272 | input: "/*vt+ ONE_OPT */ /* other comment */ /*vt+ TWO_OPT */", 273 | vals: CommentDirectives{ 274 | "ONE_OPT": true, 275 | "TWO_OPT": true, 276 | }, 277 | }, { 278 | input: "/*vt+ ONE_OPT=abc TWO_OPT=def */", 279 | vals: CommentDirectives{ 280 | "ONE_OPT": "abc", 281 | "TWO_OPT": "def", 282 | }, 283 | }, { 284 | input: "/*vt+ ONE_OPT=true TWO_OPT=false */", 285 | vals: CommentDirectives{ 286 | "ONE_OPT": true, 287 | "TWO_OPT": false, 288 | }, 289 | }, { 290 | input: "/*vt+ ONE_OPT=true TWO_OPT=\"false\" */", 291 | vals: CommentDirectives{ 292 | "ONE_OPT": true, 293 | "TWO_OPT": "\"false\"", 294 | }, 295 | }, { 296 | input: "/*vt+ RANGE_OPT=[a:b] ANOTHER ANOTHER_WITH_VALEQ=val= AND_ONE_WITH_EQ== */", 297 | vals: CommentDirectives{ 298 | "RANGE_OPT": "[a:b]", 299 | "ANOTHER": true, 300 | "ANOTHER_WITH_VALEQ": "val=", 301 | "AND_ONE_WITH_EQ":
"=", 302 | }, 303 | }} 304 | 305 | for _, testCase := range testCases { 306 | sql := "select " + testCase.input + " 1 from dual" 307 | stmt, _ := Parse(sql) // the parse error is discarded; if Parse fails, stmt is presumably nil and the type assertion below panics 308 | comments := stmt.(*Select).Comments 309 | vals := ExtractCommentDirectives(comments) 310 | 311 | if !reflect.DeepEqual(vals, testCase.vals) { 312 | t.Errorf("test input: '%v', got vals:\n%+v, want\n%+v", testCase.input, vals, testCase.vals) 313 | } 314 | } 315 | 316 | d := CommentDirectives{ 317 | "ONE_OPT": true, 318 | "TWO_OPT": false, 319 | "three": 1, 320 | "four": 2, 321 | "five": 0, 322 | "six": "true", 323 | } 324 | 325 | if !d.IsSet("ONE_OPT") { 326 | t.Errorf("d.IsSet(ONE_OPT) should be true") 327 | } 328 | 329 | if d.IsSet("TWO_OPT") { 330 | t.Errorf("d.IsSet(TWO_OPT) should be false") 331 | } 332 | 333 | if !d.IsSet("three") { 334 | t.Errorf("d.IsSet(three) should be true") 335 | } 336 | 337 | if d.IsSet("four") { 338 | t.Errorf("d.IsSet(four) should be false") 339 | } 340 | 341 | if d.IsSet("five") { 342 | t.Errorf("d.IsSet(five) should be false") 343 | } 344 | 345 | if d.IsSet("six") { 346 | t.Errorf("d.IsSet(six) should be false") 347 | } 348 | } 349 | 350 | func TestSkipQueryPlanCacheDirective(t *testing.T) { 351 | stmt, _ := Parse("insert /*vt+ SKIP_QUERY_PLAN_CACHE=1 */ into user(id) values (1), (2)") 352 | if !SkipQueryPlanCacheDirective(stmt) { 353 | t.Errorf("d.SkipQueryPlanCacheDirective(stmt) should be true") 354 | } 355 | 356 | stmt, _ = Parse("insert into user(id) values (1), (2)") 357 | if SkipQueryPlanCacheDirective(stmt) { 358 | t.Errorf("d.SkipQueryPlanCacheDirective(stmt) should be false") 359 | } 360 | 361 | stmt, _ = Parse("update /*vt+ SKIP_QUERY_PLAN_CACHE=1 */ users set name=1") 362 | if !SkipQueryPlanCacheDirective(stmt) { 363 | t.Errorf("d.SkipQueryPlanCacheDirective(stmt) should be true") 364 | } 365 | 366 | stmt, _ = Parse("select /*vt+ SKIP_QUERY_PLAN_CACHE=1 */ * from users") 367 | if !SkipQueryPlanCacheDirective(stmt) { 368 |
t.Errorf("d.SkipQueryPlanCacheDirective(stmt) should be true") 369 | } 370 | 371 | stmt, _ = Parse("delete /*vt+ SKIP_QUERY_PLAN_CACHE=1 */ from users") 372 | if !SkipQueryPlanCacheDirective(stmt) { 373 | t.Errorf("d.SkipQueryPlanCacheDirective(stmt) should be true") 374 | } 375 | } 376 | -------------------------------------------------------------------------------- /dependency/bytes2/buffer.go: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2017 Google Inc. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | */ 16 | 17 | package bytes2 18 | 19 | // Buffer implements a subset of the write portion of 20 | // bytes.Buffer, but more efficiently. This is meant to 21 | // be used in very high QPS operations, especially for 22 | // WriteByte, and without abstracting it as a Writer. 23 | // Function signatures contain errors for compatibility, 24 | // but they do not return errors. 25 | type Buffer struct { 26 | bytes []byte 27 | } 28 | 29 | // NewBuffer is equivalent to bytes.NewBuffer: b becomes the initial contents and later writes append to it. 30 | func NewBuffer(b []byte) *Buffer { 31 | return &Buffer{bytes: b} 32 | } 33 | 34 | // Write is equivalent to bytes.Buffer.Write. 35 | func (buf *Buffer) Write(b []byte) (int, error) { 36 | buf.bytes = append(buf.bytes, b...) 37 | return len(b), nil 38 | } 39 | 40 | // WriteString is equivalent to bytes.Buffer.WriteString.
41 | func (buf *Buffer) WriteString(s string) (int, error) { 42 | buf.bytes = append(buf.bytes, s...) 43 | return len(s), nil 44 | } 45 | 46 | // WriteByte is equivalent to bytes.Buffer.WriteByte. 47 | func (buf *Buffer) WriteByte(b byte) error { 48 | buf.bytes = append(buf.bytes, b) 49 | return nil 50 | } 51 | 52 | // Bytes is equivalent to bytes.Buffer.Bytes: the returned slice aliases the buffer's storage rather than copying it. 53 | func (buf *Buffer) Bytes() []byte { 54 | return buf.bytes 55 | } 56 | 57 | // String is equivalent to bytes.Buffer.String: it returns a copy of the buffered bytes. 58 | func (buf *Buffer) String() string { 59 | return string(buf.bytes) 60 | } 61 | 62 | // Len is equivalent to bytes.Buffer.Len. 63 | func (buf *Buffer) Len() int { 64 | return len(buf.bytes) 65 | } 66 | -------------------------------------------------------------------------------- /dependency/bytes2/buffer_test.go: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2017 Google Inc. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License.
15 | */ 16 | 17 | package bytes2 18 | 19 | import "testing" 20 | 21 | func TestBuffer(t *testing.T) { 22 | b := NewBuffer(nil) 23 | b.Write([]byte("ab")) 24 | b.WriteString("cd") 25 | b.WriteByte('e') 26 | want := "abcde" 27 | if got := string(b.Bytes()); got != want { 28 | t.Errorf("b.Bytes(): %s, want %s", got, want) 29 | } 30 | if got := b.String(); got != want { 31 | t.Errorf("b.String(): %s, want %s", got, want) 32 | } 33 | if got := b.Len(); got != 5 { 34 | t.Errorf("b.Len(): %d, want 5", got) 35 | } 36 | } 37 | -------------------------------------------------------------------------------- /dependency/hack/hack.go: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2017 Google Inc. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | */ 16 | 17 | // Package hack gives you some efficient functionality at the cost of 18 | // breaking some Go rules. 19 | package hack 20 | 21 | import ( 22 | "reflect" 23 | "unsafe" 24 | ) 25 | 26 | // StringArena lets you consolidate allocations for a group of strings 27 | // that have similar lifetimes. 28 | type StringArena struct { 29 | buf []byte 30 | str string 31 | } 32 | 33 | // NewStringArena creates an arena that can hold size bytes of string data.
34 | func NewStringArena(size int) *StringArena { 35 | sa := &StringArena{buf: make([]byte, 0, size)} 36 | pbytes := (*reflect.SliceHeader)(unsafe.Pointer(&sa.buf)) 37 | pstring := (*reflect.StringHeader)(unsafe.Pointer(&sa.str)) 38 | pstring.Data = pbytes.Data 39 | pstring.Len = pbytes.Cap // str spans the full capacity up front; NewString hands out substrings of it as buf fills 40 | return sa 41 | } 42 | 43 | // NewString copies a byte slice into the arena and returns it as a string. 44 | // If b does not fit in the space left, it returns an ordinary copy (string(b)) instead; empty input yields "". 45 | func (sa *StringArena) NewString(b []byte) string { 46 | if len(b) == 0 { 47 | return "" 48 | } 49 | if len(sa.buf)+len(b) > cap(sa.buf) { 50 | return string(b) 51 | } 52 | start := len(sa.buf) 53 | sa.buf = append(sa.buf, b...) 54 | return sa.str[start : start+len(b)] 55 | } 56 | 57 | // SpaceLeft returns the amount of space left in the arena. 58 | func (sa *StringArena) SpaceLeft() int { 59 | return cap(sa.buf) - len(sa.buf) 60 | } 61 | 62 | // String force casts a []byte to a string without copying: the result shares b's memory, so b must not be modified while the string is in use. 63 | // USE AT YOUR OWN RISK 64 | func String(b []byte) (s string) { 65 | if len(b) == 0 { 66 | return "" 67 | } 68 | pbytes := (*reflect.SliceHeader)(unsafe.Pointer(&b)) 69 | pstring := (*reflect.StringHeader)(unsafe.Pointer(&s)) 70 | pstring.Data = pbytes.Data 71 | pstring.Len = pbytes.Len 72 | return 73 | } 74 | 75 | // StringPointer returns &s[0], which is not allowed in go. NOTE(review): reflect.SliceHeader/StringHeader are deprecated in newer Go releases; unsafe.String/unsafe.StringData (Go 1.20+) are the supported replacements once the minimum Go version allows. 76 | func StringPointer(s string) unsafe.Pointer { 77 | pstring := (*reflect.StringHeader)(unsafe.Pointer(&s)) 78 | return unsafe.Pointer(pstring.Data) 79 | } 80 | -------------------------------------------------------------------------------- /dependency/hack/hack_test.go: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2017 Google Inc. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | */ 16 | 17 | package hack 18 | 19 | import "testing" 20 | 21 | func TestStringArena(t *testing.T) { 22 | sarena := NewStringArena(10) 23 | 24 | s0 := sarena.NewString(nil) 25 | checkint(t, len(sarena.buf), 0) 26 | checkint(t, sarena.SpaceLeft(), 10) 27 | checkstring(t, s0, "") 28 | 29 | s1 := sarena.NewString([]byte("01234")) 30 | checkint(t, len(sarena.buf), 5) 31 | checkint(t, sarena.SpaceLeft(), 5) 32 | checkstring(t, s1, "01234") 33 | 34 | s2 := sarena.NewString([]byte("5678")) 35 | checkint(t, len(sarena.buf), 9) 36 | checkint(t, sarena.SpaceLeft(), 1) 37 | checkstring(t, s2, "5678") 38 | 39 | // s3 needs 2 bytes but only 1 is left, so it will be allocated outside of sarena 40 | s3 := sarena.NewString([]byte("ab")) 41 | checkint(t, len(sarena.buf), 9) 42 | checkint(t, sarena.SpaceLeft(), 1) 43 | checkstring(t, s3, "ab") 44 | 45 | // s4 should still fit in sarena 46 | s4 := sarena.NewString([]byte("9")) 47 | checkint(t, len(sarena.buf), 10) 48 | checkint(t, sarena.SpaceLeft(), 0) 49 | checkstring(t, s4, "9") 50 | 51 | sarena.buf[0] = 'A' 52 | checkstring(t, s1, "A1234") 53 | 54 | sarena.buf[5] = 'B' 55 | checkstring(t, s2, "B678") 56 | 57 | sarena.buf[9] = 'C' 58 | // s3 will not change because it does not share sarena's memory 59 | checkstring(t, s3, "ab") 60 | checkstring(t, s4, "C") 61 | checkstring(t, sarena.str, "A1234B678C") 62 | } 63 | 64 | func checkstring(t *testing.T, actual, expected string) { 65 | if actual != expected { 66 | t.Errorf("received %s, expecting %s", actual, expected) 67 | } 68 | } 69 | 70 | func checkint(t *testing.T, actual, expected int) { 71 | if actual != expected { 72 |
t.Errorf("received %d, expecting %d", actual, expected) 73 | } 74 | } 75 | 76 | func TestByteToString(t *testing.T) { 77 | v1 := []byte("1234") 78 | if s := String(v1); s != "1234" { 79 | t.Errorf("String(\"1234\"): %q, want 1234", s) 80 | } 81 | 82 | v1 = []byte("") 83 | if s := String(v1); s != "" { 84 | t.Errorf("String(\"\"): %q, want empty", s) 85 | } 86 | 87 | v1 = nil 88 | if s := String(v1); s != "" { 89 | t.Errorf("String(\"\"): %q, want empty", s) 90 | } 91 | } 92 | -------------------------------------------------------------------------------- /dependency/sqltypes/bind_variables.go: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2017 Google Inc. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | */ 16 | 17 | package sqltypes 18 | 19 | import ( 20 | "errors" 21 | "fmt" 22 | "reflect" 23 | "strconv" 24 | 25 | "github.com/xwb1989/sqlparser/dependency/querypb" 26 | ) 27 | 28 | // NullBindVariable is a bindvar with NULL value. 29 | var NullBindVariable = &querypb.BindVariable{Type: querypb.Type_NULL_TYPE} 30 | 31 | // ValueToProto converts Value to a *querypb.Value. The result shares v's byte slice; it is not copied. 32 | func ValueToProto(v Value) *querypb.Value { 33 | return &querypb.Value{Type: v.typ, Value: v.val} 34 | } 35 | 36 | // ProtoToValue converts a *querypb.Value to a Value using MakeTrusted.
37 | func ProtoToValue(v *querypb.Value) Value { 38 | return MakeTrusted(v.Type, v.Value) 39 | } 40 | 41 | // BuildBindVariables builds a map[string]*querypb.BindVariable from a map[string]interface{}. 42 | func BuildBindVariables(in map[string]interface{}) (map[string]*querypb.BindVariable, error) { 43 | if len(in) == 0 { 44 | return nil, nil 45 | } 46 | 47 | out := make(map[string]*querypb.BindVariable, len(in)) 48 | for k, v := range in { 49 | bv, err := BuildBindVariable(v) 50 | if err != nil { 51 | return nil, fmt.Errorf("%s: %v", k, err) 52 | } 53 | out[k] = bv 54 | } 55 | return out, nil 56 | } 57 | 58 | // Int32BindVariable converts an int32 to a bind var. 59 | func Int32BindVariable(v int32) *querypb.BindVariable { 60 | return ValueBindVariable(NewInt32(v)) 61 | } 62 | 63 | // Int64BindVariable converts an int64 to a bind var. 64 | func Int64BindVariable(v int64) *querypb.BindVariable { 65 | return ValueBindVariable(NewInt64(v)) 66 | } 67 | 68 | // Uint64BindVariable converts a uint64 to a bind var. 69 | func Uint64BindVariable(v uint64) *querypb.BindVariable { 70 | return ValueBindVariable(NewUint64(v)) 71 | } 72 | 73 | // Float64BindVariable converts a float64 to a bind var. 74 | func Float64BindVariable(v float64) *querypb.BindVariable { 75 | return ValueBindVariable(NewFloat64(v)) 76 | } 77 | 78 | // StringBindVariable converts a string to a bind var. 79 | func StringBindVariable(v string) *querypb.BindVariable { 80 | return ValueBindVariable(NewVarChar(v)) 81 | } 82 | 83 | // BytesBindVariable converts a []byte to a bind var. 84 | func BytesBindVariable(v []byte) *querypb.BindVariable { 85 | return &querypb.BindVariable{Type: VarBinary, Value: v} 86 | } 87 | 88 | // ValueBindVariable converts a Value to a bind var. 89 | func ValueBindVariable(v Value) *querypb.BindVariable { 90 | return &querypb.BindVariable{Type: v.typ, Value: v.val} 91 | } 92 | 93 | // BuildBindVariable builds a *querypb.BindVariable from a valid input type. 
94 | func BuildBindVariable(v interface{}) (*querypb.BindVariable, error) { 95 | switch v := v.(type) { 96 | case string: 97 | return StringBindVariable(v), nil 98 | case []byte: 99 | return BytesBindVariable(v), nil 100 | case int: 101 | return &querypb.BindVariable{ 102 | Type: querypb.Type_INT64, 103 | Value: strconv.AppendInt(nil, int64(v), 10), 104 | }, nil 105 | case int64: 106 | return Int64BindVariable(v), nil 107 | case uint64: 108 | return Uint64BindVariable(v), nil 109 | case float64: 110 | return Float64BindVariable(v), nil 111 | case nil: 112 | return NullBindVariable, nil 113 | case Value: 114 | return ValueBindVariable(v), nil 115 | case *querypb.BindVariable: 116 | return v, nil 117 | case []interface{}: 118 | bv := &querypb.BindVariable{ 119 | Type: querypb.Type_TUPLE, 120 | Values: make([]*querypb.Value, len(v)), 121 | } 122 | values := make([]querypb.Value, len(v)) 123 | for i, lv := range v { 124 | lbv, err := BuildBindVariable(lv) 125 | if err != nil { 126 | return nil, err 127 | } 128 | values[i].Type = lbv.Type 129 | values[i].Value = lbv.Value 130 | bv.Values[i] = &values[i] 131 | } 132 | return bv, nil 133 | case []string: 134 | bv := &querypb.BindVariable{ 135 | Type: querypb.Type_TUPLE, 136 | Values: make([]*querypb.Value, len(v)), 137 | } 138 | values := make([]querypb.Value, len(v)) 139 | for i, lv := range v { 140 | values[i].Type = querypb.Type_VARCHAR 141 | values[i].Value = []byte(lv) 142 | bv.Values[i] = &values[i] 143 | } 144 | return bv, nil 145 | case [][]byte: 146 | bv := &querypb.BindVariable{ 147 | Type: querypb.Type_TUPLE, 148 | Values: make([]*querypb.Value, len(v)), 149 | } 150 | values := make([]querypb.Value, len(v)) 151 | for i, lv := range v { 152 | values[i].Type = querypb.Type_VARBINARY 153 | values[i].Value = lv 154 | bv.Values[i] = &values[i] 155 | } 156 | return bv, nil 157 | case []int: 158 | bv := &querypb.BindVariable{ 159 | Type: querypb.Type_TUPLE, 160 | Values: make([]*querypb.Value, len(v)), 161 | } 162 | 
values := make([]querypb.Value, len(v)) 163 | for i, lv := range v { 164 | values[i].Type = querypb.Type_INT64 165 | values[i].Value = strconv.AppendInt(nil, int64(lv), 10) 166 | bv.Values[i] = &values[i] 167 | } 168 | return bv, nil 169 | case []int64: 170 | bv := &querypb.BindVariable{ 171 | Type: querypb.Type_TUPLE, 172 | Values: make([]*querypb.Value, len(v)), 173 | } 174 | values := make([]querypb.Value, len(v)) 175 | for i, lv := range v { 176 | values[i].Type = querypb.Type_INT64 177 | values[i].Value = strconv.AppendInt(nil, lv, 10) 178 | bv.Values[i] = &values[i] 179 | } 180 | return bv, nil 181 | case []uint64: 182 | bv := &querypb.BindVariable{ 183 | Type: querypb.Type_TUPLE, 184 | Values: make([]*querypb.Value, len(v)), 185 | } 186 | values := make([]querypb.Value, len(v)) 187 | for i, lv := range v { 188 | values[i].Type = querypb.Type_UINT64 189 | values[i].Value = strconv.AppendUint(nil, lv, 10) 190 | bv.Values[i] = &values[i] 191 | } 192 | return bv, nil 193 | case []float64: 194 | bv := &querypb.BindVariable{ 195 | Type: querypb.Type_TUPLE, 196 | Values: make([]*querypb.Value, len(v)), 197 | } 198 | values := make([]querypb.Value, len(v)) 199 | for i, lv := range v { 200 | values[i].Type = querypb.Type_FLOAT64 201 | values[i].Value = strconv.AppendFloat(nil, lv, 'g', -1, 64) 202 | bv.Values[i] = &values[i] 203 | } 204 | return bv, nil 205 | } 206 | return nil, fmt.Errorf("type %T not supported as bind var: %v", v, v) 207 | } 208 | 209 | // ValidateBindVariables validates a map[string]*querypb.BindVariable. 210 | func ValidateBindVariables(bv map[string]*querypb.BindVariable) error { 211 | for k, v := range bv { 212 | if err := ValidateBindVariable(v); err != nil { 213 | return fmt.Errorf("%s: %v", k, err) 214 | } 215 | } 216 | return nil 217 | } 218 | 219 | // ValidateBindVariable returns an error if the bind variable has inconsistent 220 | // fields. 
221 | func ValidateBindVariable(bv *querypb.BindVariable) error { 222 | if bv == nil { 223 | return errors.New("bind variable is nil") 224 | } 225 | 226 | if bv.Type == querypb.Type_TUPLE { 227 | if len(bv.Values) == 0 { 228 | return errors.New("empty tuple is not allowed") 229 | } 230 | for _, val := range bv.Values { 231 | if val.Type == querypb.Type_TUPLE { 232 | return errors.New("tuple not allowed inside another tuple") 233 | } 234 | if err := ValidateBindVariable(&querypb.BindVariable{Type: val.Type, Value: val.Value}); err != nil { 235 | return err 236 | } 237 | } 238 | return nil 239 | } 240 | 241 | // If NewValue succeeds, the value is valid. 242 | _, err := NewValue(bv.Type, bv.Value) 243 | return err 244 | } 245 | 246 | // BindVariableToValue converts a bind var into a Value. 247 | func BindVariableToValue(bv *querypb.BindVariable) (Value, error) { 248 | if bv.Type == querypb.Type_TUPLE { 249 | return NULL, errors.New("cannot convert a TUPLE bind var into a value") 250 | } 251 | return MakeTrusted(bv.Type, bv.Value), nil 252 | } 253 | 254 | // BindVariablesEqual compares two maps of bind variables. 255 | func BindVariablesEqual(x, y map[string]*querypb.BindVariable) bool { 256 | return reflect.DeepEqual(&querypb.BoundQuery{BindVariables: x}, &querypb.BoundQuery{BindVariables: y}) 257 | } 258 | 259 | // CopyBindVariables returns a shallow-copy of the given bindVariables map. 260 | func CopyBindVariables(bindVariables map[string]*querypb.BindVariable) map[string]*querypb.BindVariable { 261 | result := make(map[string]*querypb.BindVariable, len(bindVariables)) 262 | for key, value := range bindVariables { 263 | result[key] = value 264 | } 265 | return result 266 | } 267 | -------------------------------------------------------------------------------- /dependency/sqltypes/bind_variables_test.go: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2017 Google Inc. 
3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | */ 16 | 17 | package sqltypes 18 | 19 | import ( 20 | "reflect" 21 | "strings" 22 | "testing" 23 | 24 | "github.com/xwb1989/sqlparser/dependency/querypb" 25 | ) 26 | 27 | func TestProtoConversions(t *testing.T) { 28 | v := TestValue(Int64, "1") 29 | got := ValueToProto(v) 30 | want := &querypb.Value{Type: Int64, Value: []byte("1")} 31 | if !reflect.DeepEqual(got, want) { 32 | t.Errorf("ValueToProto: %v, want %v", got, want) 33 | } 34 | gotback := ProtoToValue(got) 35 | if !reflect.DeepEqual(gotback, v) { 36 | t.Errorf("ProtoToValue: %v, want %v", gotback, v) 37 | } 38 | } 39 | 40 | func TestBuildBindVariables(t *testing.T) { 41 | tcases := []struct { 42 | in map[string]interface{} 43 | out map[string]*querypb.BindVariable 44 | err string 45 | }{{ 46 | in: nil, 47 | out: nil, 48 | }, { 49 | in: map[string]interface{}{ 50 | "k": int64(1), 51 | }, 52 | out: map[string]*querypb.BindVariable{ 53 | "k": Int64BindVariable(1), 54 | }, 55 | }, { 56 | in: map[string]interface{}{ 57 | "k": byte(1), 58 | }, 59 | err: "k: type uint8 not supported as bind var: 1", 60 | }} 61 | for _, tcase := range tcases { 62 | bindVars, err := BuildBindVariables(tcase.in) 63 | if err != nil { 64 | if err.Error() != tcase.err { 65 | t.Errorf("MapToBindVars(%v) error: %v, want %s", tcase.in, err, tcase.err) 66 | } 67 | continue 68 | } 69 | if tcase.err != "" { 70 | t.Errorf("MapToBindVars(%v) error: nil, want %s", tcase.in, 
tcase.err) 71 | continue 72 | } 73 | if !BindVariablesEqual(bindVars, tcase.out) { 74 | t.Errorf("MapToBindVars(%v): %v, want %s", tcase.in, bindVars, tcase.out) 75 | } 76 | } 77 | } 78 | 79 | func TestBuildBindVariable(t *testing.T) { 80 | tcases := []struct { 81 | in interface{} 82 | out *querypb.BindVariable 83 | err string 84 | }{{ 85 | in: "aa", 86 | out: &querypb.BindVariable{ 87 | Type: querypb.Type_VARCHAR, 88 | Value: []byte("aa"), 89 | }, 90 | }, { 91 | in: []byte("aa"), 92 | out: &querypb.BindVariable{ 93 | Type: querypb.Type_VARBINARY, 94 | Value: []byte("aa"), 95 | }, 96 | }, { 97 | in: int(1), 98 | out: &querypb.BindVariable{ 99 | Type: querypb.Type_INT64, 100 | Value: []byte("1"), 101 | }, 102 | }, { 103 | in: int64(1), 104 | out: &querypb.BindVariable{ 105 | Type: querypb.Type_INT64, 106 | Value: []byte("1"), 107 | }, 108 | }, { 109 | in: uint64(1), 110 | out: &querypb.BindVariable{ 111 | Type: querypb.Type_UINT64, 112 | Value: []byte("1"), 113 | }, 114 | }, { 115 | in: float64(1), 116 | out: &querypb.BindVariable{ 117 | Type: querypb.Type_FLOAT64, 118 | Value: []byte("1"), 119 | }, 120 | }, { 121 | in: nil, 122 | out: NullBindVariable, 123 | }, { 124 | in: MakeTrusted(Int64, []byte("1")), 125 | out: &querypb.BindVariable{ 126 | Type: querypb.Type_INT64, 127 | Value: []byte("1"), 128 | }, 129 | }, { 130 | in: &querypb.BindVariable{ 131 | Type: querypb.Type_INT64, 132 | Value: []byte("1"), 133 | }, 134 | out: &querypb.BindVariable{ 135 | Type: querypb.Type_INT64, 136 | Value: []byte("1"), 137 | }, 138 | }, { 139 | in: []interface{}{"aa", int64(1)}, 140 | out: &querypb.BindVariable{ 141 | Type: querypb.Type_TUPLE, 142 | Values: []*querypb.Value{{ 143 | Type: querypb.Type_VARCHAR, 144 | Value: []byte("aa"), 145 | }, { 146 | Type: querypb.Type_INT64, 147 | Value: []byte("1"), 148 | }}, 149 | }, 150 | }, { 151 | in: []string{"aa", "bb"}, 152 | out: &querypb.BindVariable{ 153 | Type: querypb.Type_TUPLE, 154 | Values: []*querypb.Value{{ 155 | Type: 
querypb.Type_VARCHAR, 156 | Value: []byte("aa"), 157 | }, { 158 | Type: querypb.Type_VARCHAR, 159 | Value: []byte("bb"), 160 | }}, 161 | }, 162 | }, { 163 | in: [][]byte{[]byte("aa"), []byte("bb")}, 164 | out: &querypb.BindVariable{ 165 | Type: querypb.Type_TUPLE, 166 | Values: []*querypb.Value{{ 167 | Type: querypb.Type_VARBINARY, 168 | Value: []byte("aa"), 169 | }, { 170 | Type: querypb.Type_VARBINARY, 171 | Value: []byte("bb"), 172 | }}, 173 | }, 174 | }, { 175 | in: []int{1, 2}, 176 | out: &querypb.BindVariable{ 177 | Type: querypb.Type_TUPLE, 178 | Values: []*querypb.Value{{ 179 | Type: querypb.Type_INT64, 180 | Value: []byte("1"), 181 | }, { 182 | Type: querypb.Type_INT64, 183 | Value: []byte("2"), 184 | }}, 185 | }, 186 | }, { 187 | in: []int64{1, 2}, 188 | out: &querypb.BindVariable{ 189 | Type: querypb.Type_TUPLE, 190 | Values: []*querypb.Value{{ 191 | Type: querypb.Type_INT64, 192 | Value: []byte("1"), 193 | }, { 194 | Type: querypb.Type_INT64, 195 | Value: []byte("2"), 196 | }}, 197 | }, 198 | }, { 199 | in: []uint64{1, 2}, 200 | out: &querypb.BindVariable{ 201 | Type: querypb.Type_TUPLE, 202 | Values: []*querypb.Value{{ 203 | Type: querypb.Type_UINT64, 204 | Value: []byte("1"), 205 | }, { 206 | Type: querypb.Type_UINT64, 207 | Value: []byte("2"), 208 | }}, 209 | }, 210 | }, { 211 | in: []float64{1, 2}, 212 | out: &querypb.BindVariable{ 213 | Type: querypb.Type_TUPLE, 214 | Values: []*querypb.Value{{ 215 | Type: querypb.Type_FLOAT64, 216 | Value: []byte("1"), 217 | }, { 218 | Type: querypb.Type_FLOAT64, 219 | Value: []byte("2"), 220 | }}, 221 | }, 222 | }, { 223 | in: byte(1), 224 | err: "type uint8 not supported as bind var: 1", 225 | }, { 226 | in: []interface{}{1, byte(1)}, 227 | err: "type uint8 not supported as bind var: 1", 228 | }} 229 | for _, tcase := range tcases { 230 | bv, err := BuildBindVariable(tcase.in) 231 | if err != nil { 232 | if err.Error() != tcase.err { 233 | t.Errorf("ToBindVar(%T(%v)) error: %v, want %s", tcase.in, tcase.in, err, 
tcase.err) 234 | } 235 | continue 236 | } 237 | if tcase.err != "" { 238 | t.Errorf("ToBindVar(%T(%v)) error: nil, want %s", tcase.in, tcase.in, tcase.err) 239 | continue 240 | } 241 | if !reflect.DeepEqual(bv, tcase.out) { 242 | t.Errorf("ToBindVar(%T(%v)): %v, want %s", tcase.in, tcase.in, bv, tcase.out) 243 | } 244 | } 245 | } 246 | 247 | func TestValidateBindVarables(t *testing.T) { 248 | tcases := []struct { 249 | in map[string]*querypb.BindVariable 250 | err string 251 | }{{ 252 | in: map[string]*querypb.BindVariable{ 253 | "v": { 254 | Type: querypb.Type_INT64, 255 | Value: []byte("1"), 256 | }, 257 | }, 258 | err: "", 259 | }, { 260 | in: map[string]*querypb.BindVariable{ 261 | "v": { 262 | Type: querypb.Type_INT64, 263 | Value: []byte("a"), 264 | }, 265 | }, 266 | err: `v: strconv.ParseInt: parsing "a": invalid syntax`, 267 | }, { 268 | in: map[string]*querypb.BindVariable{ 269 | "v": { 270 | Type: querypb.Type_TUPLE, 271 | Values: []*querypb.Value{{ 272 | Type: Int64, 273 | Value: []byte("a"), 274 | }}, 275 | }, 276 | }, 277 | err: `v: strconv.ParseInt: parsing "a": invalid syntax`, 278 | }} 279 | for _, tcase := range tcases { 280 | err := ValidateBindVariables(tcase.in) 281 | if tcase.err != "" { 282 | if err == nil || err.Error() != tcase.err { 283 | t.Errorf("ValidateBindVars(%v): %v, want %s", tcase.in, err, tcase.err) 284 | } 285 | continue 286 | } 287 | if err != nil { 288 | t.Errorf("ValidateBindVars(%v): %v, want nil", tcase.in, err) 289 | } 290 | } 291 | } 292 | 293 | func TestValidateBindVariable(t *testing.T) { 294 | testcases := []struct { 295 | in *querypb.BindVariable 296 | err string 297 | }{{ 298 | in: &querypb.BindVariable{ 299 | Type: querypb.Type_INT8, 300 | Value: []byte("1"), 301 | }, 302 | }, { 303 | in: &querypb.BindVariable{ 304 | Type: querypb.Type_INT16, 305 | Value: []byte("1"), 306 | }, 307 | }, { 308 | in: &querypb.BindVariable{ 309 | Type: querypb.Type_INT24, 310 | Value: []byte("1"), 311 | }, 312 | }, { 313 | in: 
&querypb.BindVariable{ 314 | Type: querypb.Type_INT32, 315 | Value: []byte("1"), 316 | }, 317 | }, { 318 | in: &querypb.BindVariable{ 319 | Type: querypb.Type_INT64, 320 | Value: []byte("1"), 321 | }, 322 | }, { 323 | in: &querypb.BindVariable{ 324 | Type: querypb.Type_UINT8, 325 | Value: []byte("1"), 326 | }, 327 | }, { 328 | in: &querypb.BindVariable{ 329 | Type: querypb.Type_UINT16, 330 | Value: []byte("1"), 331 | }, 332 | }, { 333 | in: &querypb.BindVariable{ 334 | Type: querypb.Type_UINT24, 335 | Value: []byte("1"), 336 | }, 337 | }, { 338 | in: &querypb.BindVariable{ 339 | Type: querypb.Type_UINT32, 340 | Value: []byte("1"), 341 | }, 342 | }, { 343 | in: &querypb.BindVariable{ 344 | Type: querypb.Type_UINT64, 345 | Value: []byte("1"), 346 | }, 347 | }, { 348 | in: &querypb.BindVariable{ 349 | Type: querypb.Type_FLOAT32, 350 | Value: []byte("1.00"), 351 | }, 352 | }, { 353 | in: &querypb.BindVariable{ 354 | Type: querypb.Type_FLOAT64, 355 | Value: []byte("1.00"), 356 | }, 357 | }, { 358 | in: &querypb.BindVariable{ 359 | Type: querypb.Type_DECIMAL, 360 | Value: []byte("1.00"), 361 | }, 362 | }, { 363 | in: &querypb.BindVariable{ 364 | Type: querypb.Type_TIMESTAMP, 365 | Value: []byte("2012-02-24 23:19:43"), 366 | }, 367 | }, { 368 | in: &querypb.BindVariable{ 369 | Type: querypb.Type_DATE, 370 | Value: []byte("2012-02-24"), 371 | }, 372 | }, { 373 | in: &querypb.BindVariable{ 374 | Type: querypb.Type_TIME, 375 | Value: []byte("23:19:43"), 376 | }, 377 | }, { 378 | in: &querypb.BindVariable{ 379 | Type: querypb.Type_DATETIME, 380 | Value: []byte("2012-02-24 23:19:43"), 381 | }, 382 | }, { 383 | in: &querypb.BindVariable{ 384 | Type: querypb.Type_YEAR, 385 | Value: []byte("1"), 386 | }, 387 | }, { 388 | in: &querypb.BindVariable{ 389 | Type: querypb.Type_TEXT, 390 | Value: []byte("a"), 391 | }, 392 | }, { 393 | in: &querypb.BindVariable{ 394 | Type: querypb.Type_BLOB, 395 | Value: []byte("a"), 396 | }, 397 | }, { 398 | in: &querypb.BindVariable{ 399 | Type: 
querypb.Type_VARCHAR, 400 | Value: []byte("a"), 401 | }, 402 | }, { 403 | in: &querypb.BindVariable{ 404 | Type: querypb.Type_BINARY, 405 | Value: []byte("a"), 406 | }, 407 | }, { 408 | in: &querypb.BindVariable{ 409 | Type: querypb.Type_CHAR, 410 | Value: []byte("a"), 411 | }, 412 | }, { 413 | in: &querypb.BindVariable{ 414 | Type: querypb.Type_BIT, 415 | Value: []byte("1"), 416 | }, 417 | }, { 418 | in: &querypb.BindVariable{ 419 | Type: querypb.Type_ENUM, 420 | Value: []byte("a"), 421 | }, 422 | }, { 423 | in: &querypb.BindVariable{ 424 | Type: querypb.Type_SET, 425 | Value: []byte("a"), 426 | }, 427 | }, { 428 | in: &querypb.BindVariable{ 429 | Type: querypb.Type_VARBINARY, 430 | Value: []byte("a"), 431 | }, 432 | }, { 433 | in: &querypb.BindVariable{ 434 | Type: querypb.Type_INT64, 435 | Value: []byte(InvalidNeg), 436 | }, 437 | err: "out of range", 438 | }, { 439 | in: &querypb.BindVariable{ 440 | Type: querypb.Type_INT64, 441 | Value: []byte(InvalidPos), 442 | }, 443 | err: "out of range", 444 | }, { 445 | in: &querypb.BindVariable{ 446 | Type: querypb.Type_UINT64, 447 | Value: []byte("-1"), 448 | }, 449 | err: "invalid syntax", 450 | }, { 451 | in: &querypb.BindVariable{ 452 | Type: querypb.Type_UINT64, 453 | Value: []byte(InvalidPos), 454 | }, 455 | err: "out of range", 456 | }, { 457 | in: &querypb.BindVariable{ 458 | Type: querypb.Type_FLOAT64, 459 | Value: []byte("a"), 460 | }, 461 | err: "invalid syntax", 462 | }, { 463 | in: &querypb.BindVariable{ 464 | Type: querypb.Type_EXPRESSION, 465 | Value: []byte("a"), 466 | }, 467 | err: "invalid type specified for MakeValue: EXPRESSION", 468 | }, { 469 | in: &querypb.BindVariable{ 470 | Type: querypb.Type_TUPLE, 471 | Values: []*querypb.Value{{ 472 | Type: querypb.Type_INT64, 473 | Value: []byte("1"), 474 | }}, 475 | }, 476 | }, { 477 | in: &querypb.BindVariable{ 478 | Type: querypb.Type_TUPLE, 479 | }, 480 | err: "empty tuple is not allowed", 481 | }, { 482 | in: &querypb.BindVariable{ 483 | Type: 
querypb.Type_TUPLE, 484 | Values: []*querypb.Value{{ 485 | Type: querypb.Type_TUPLE, 486 | }}, 487 | }, 488 | err: "tuple not allowed inside another tuple", 489 | }} 490 | for _, tcase := range testcases { 491 | err := ValidateBindVariable(tcase.in) 492 | if tcase.err != "" { 493 | if err == nil || !strings.Contains(err.Error(), tcase.err) { 494 | t.Errorf("ValidateBindVar(%v) error: %v, must contain %v", tcase.in, err, tcase.err) 495 | } 496 | continue 497 | } 498 | if err != nil { 499 | t.Errorf("ValidateBindVar(%v) error: %v", tcase.in, err) 500 | } 501 | } 502 | 503 | // Special case: nil bind var. 504 | err := ValidateBindVariable(nil) 505 | want := "bind variable is nil" 506 | if err == nil || err.Error() != want { 507 | t.Errorf("ValidateBindVar(nil) error: %v, want %s", err, want) 508 | } 509 | } 510 | 511 | func TestBindVariableToValue(t *testing.T) { 512 | v, err := BindVariableToValue(Int64BindVariable(1)) 513 | if err != nil { 514 | t.Error(err) 515 | } 516 | want := MakeTrusted(querypb.Type_INT64, []byte("1")) 517 | if !reflect.DeepEqual(v, want) { 518 | t.Errorf("BindVarToValue(1): %v, want %v", v, want) 519 | } 520 | 521 | v, err = BindVariableToValue(&querypb.BindVariable{Type: querypb.Type_TUPLE}) 522 | wantErr := "cannot convert a TUPLE bind var into a value" 523 | if err == nil || err.Error() != wantErr { 524 | t.Errorf(" BindVarToValue(TUPLE): %v, want %s", err, wantErr) 525 | } 526 | } 527 | 528 | func TestBindVariablesEqual(t *testing.T) { 529 | bv1 := map[string]*querypb.BindVariable{ 530 | "k": { 531 | Type: querypb.Type_INT64, 532 | Value: []byte("1"), 533 | }, 534 | } 535 | bv2 := map[string]*querypb.BindVariable{ 536 | "k": { 537 | Type: querypb.Type_INT64, 538 | Value: []byte("1"), 539 | }, 540 | } 541 | bv3 := map[string]*querypb.BindVariable{ 542 | "k": { 543 | Type: querypb.Type_INT64, 544 | Value: []byte("1"), 545 | }, 546 | } 547 | if !BindVariablesEqual(bv1, bv2) { 548 | t.Errorf("%v != %v, want equal", bv1, bv2) 549 | } 550 | if 
!BindVariablesEqual(bv1, bv3) { 551 | t.Errorf("%v = %v, want not equal", bv1, bv3) 552 | } 553 | } 554 | -------------------------------------------------------------------------------- /dependency/sqltypes/plan_value.go: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2017 Google Inc. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | */ 16 | 17 | package sqltypes 18 | 19 | import ( 20 | "encoding/json" 21 | "errors" 22 | "fmt" 23 | 24 | "github.com/xwb1989/sqlparser/dependency/querypb" 25 | ) 26 | 27 | // PlanValue represents a value or a list of values for 28 | // a column that will later be resolved using bind vars and used 29 | // to perform plan actions like generating the final query or 30 | // deciding on a route. 31 | // 32 | // Plan values are typically used as a slice ([]planValue) 33 | // where each entry is for one column. For situations where 34 | // the required output is a list of rows (like in the case 35 | // of multi-value inserts), the representation is pivoted. 
36 | // For example, a statement like this: 37 | // INSERT INTO t VALUES (1, 2), (3, 4) 38 | // will be represented as follows: 39 | // []PlanValue{ 40 | // Values: {1, 3}, 41 | // Values: {2, 4}, 42 | // } 43 | // 44 | // For WHERE clause items that contain a combination of 45 | // equality expressions and IN clauses like this: 46 | // WHERE pk1 = 1 AND pk2 IN (2, 3, 4) 47 | // The plan values will be represented as follows: 48 | // []PlanValue{ 49 | // Value: 1, 50 | // Values: {2, 3, 4}, 51 | // } 52 | // When converted into rows, columns with single values 53 | // are replicated as the same for all rows: 54 | // [][]Value{ 55 | // {1, 2}, 56 | // {1, 3}, 57 | // {1, 4}, 58 | // } 59 | type PlanValue struct { 60 | Key string 61 | Value Value 62 | ListKey string 63 | Values []PlanValue 64 | } 65 | 66 | // IsNull returns true if the PlanValue is NULL. 67 | func (pv PlanValue) IsNull() bool { 68 | return pv.Key == "" && pv.Value.IsNull() && pv.ListKey == "" && pv.Values == nil 69 | } 70 | 71 | // IsList returns true if the PlanValue is a list. 72 | func (pv PlanValue) IsList() bool { 73 | return pv.ListKey != "" || pv.Values != nil 74 | } 75 | 76 | // ResolveValue resolves a PlanValue as a single value based on the supplied bindvars. 77 | func (pv PlanValue) ResolveValue(bindVars map[string]*querypb.BindVariable) (Value, error) { 78 | switch { 79 | case pv.Key != "": 80 | bv, err := pv.lookupValue(bindVars) 81 | if err != nil { 82 | return NULL, err 83 | } 84 | return MakeTrusted(bv.Type, bv.Value), nil 85 | case !pv.Value.IsNull(): 86 | return pv.Value, nil 87 | case pv.ListKey != "" || pv.Values != nil: 88 | // This code is unreachable because the parser does not allow 89 | // multi-value constructs where a single value is expected. 
90 | return NULL, errors.New("a list was supplied where a single value was expected") 91 | } 92 | return NULL, nil 93 | } 94 | 95 | func (pv PlanValue) lookupValue(bindVars map[string]*querypb.BindVariable) (*querypb.BindVariable, error) { 96 | bv, ok := bindVars[pv.Key] 97 | if !ok { 98 | return nil, fmt.Errorf("missing bind var %s", pv.Key) 99 | } 100 | if bv.Type == querypb.Type_TUPLE { 101 | return nil, fmt.Errorf("TUPLE was supplied for single value bind var %s", pv.ListKey) 102 | } 103 | return bv, nil 104 | } 105 | 106 | // ResolveList resolves a PlanValue as a list of values based on the supplied bindvars. 107 | func (pv PlanValue) ResolveList(bindVars map[string]*querypb.BindVariable) ([]Value, error) { 108 | switch { 109 | case pv.ListKey != "": 110 | bv, err := pv.lookupList(bindVars) 111 | if err != nil { 112 | return nil, err 113 | } 114 | values := make([]Value, 0, len(bv.Values)) 115 | for _, val := range bv.Values { 116 | values = append(values, MakeTrusted(val.Type, val.Value)) 117 | } 118 | return values, nil 119 | case pv.Values != nil: 120 | values := make([]Value, 0, len(pv.Values)) 121 | for _, val := range pv.Values { 122 | v, err := val.ResolveValue(bindVars) 123 | if err != nil { 124 | return nil, err 125 | } 126 | values = append(values, v) 127 | } 128 | return values, nil 129 | } 130 | // This code is unreachable because the parser does not allow 131 | // single value constructs where multiple values are expected. 
132 | return nil, errors.New("a single value was supplied where a list was expected") 133 | } 134 | 135 | func (pv PlanValue) lookupList(bindVars map[string]*querypb.BindVariable) (*querypb.BindVariable, error) { 136 | bv, ok := bindVars[pv.ListKey] 137 | if !ok { 138 | return nil, fmt.Errorf("missing bind var %s", pv.ListKey) 139 | } 140 | if bv.Type != querypb.Type_TUPLE { 141 | return nil, fmt.Errorf("single value was supplied for TUPLE bind var %s", pv.ListKey) 142 | } 143 | return bv, nil 144 | } 145 | 146 | // MarshalJSON should be used only for testing. 147 | func (pv PlanValue) MarshalJSON() ([]byte, error) { 148 | switch { 149 | case pv.Key != "": 150 | return json.Marshal(":" + pv.Key) 151 | case !pv.Value.IsNull(): 152 | if pv.Value.IsIntegral() { 153 | return pv.Value.ToBytes(), nil 154 | } 155 | return json.Marshal(pv.Value.ToString()) 156 | case pv.ListKey != "": 157 | return json.Marshal("::" + pv.ListKey) 158 | case pv.Values != nil: 159 | return json.Marshal(pv.Values) 160 | } 161 | return []byte("null"), nil 162 | } 163 | 164 | func rowCount(pvs []PlanValue, bindVars map[string]*querypb.BindVariable) (int, error) { 165 | count := -1 166 | setCount := func(l int) error { 167 | switch count { 168 | case -1: 169 | count = l 170 | return nil 171 | case l: 172 | return nil 173 | default: 174 | return errors.New("mismatch in number of column values") 175 | } 176 | } 177 | 178 | for _, pv := range pvs { 179 | switch { 180 | case pv.Key != "" || !pv.Value.IsNull(): 181 | continue 182 | case pv.Values != nil: 183 | if err := setCount(len(pv.Values)); err != nil { 184 | return 0, err 185 | } 186 | case pv.ListKey != "": 187 | bv, err := pv.lookupList(bindVars) 188 | if err != nil { 189 | return 0, err 190 | } 191 | if err := setCount(len(bv.Values)); err != nil { 192 | return 0, err 193 | } 194 | } 195 | } 196 | 197 | if count == -1 { 198 | // If there were no lists inside, it was a single row. 
199 | // Note that count can never be 0 because there is enough 200 | // protection at the top level: list bind vars must have 201 | // at least one value (enforced by vtgate), and AST lists 202 | // must have at least one value (enforced by the parser). 203 | // Also lists created internally after vtgate validation 204 | // ensure at least one value. 205 | // TODO(sougou): verify and change API to enforce this. 206 | return 1, nil 207 | } 208 | return count, nil 209 | } 210 | 211 | // ResolveRows resolves a []PlanValue as rows based on the supplied bindvars. 212 | func ResolveRows(pvs []PlanValue, bindVars map[string]*querypb.BindVariable) ([][]Value, error) { 213 | count, err := rowCount(pvs, bindVars) 214 | if err != nil { 215 | return nil, err 216 | } 217 | 218 | // Allocate the rows. 219 | rows := make([][]Value, count) 220 | for i := range rows { 221 | rows[i] = make([]Value, len(pvs)) 222 | } 223 | 224 | // Using j becasue we're resolving by columns. 225 | for j, pv := range pvs { 226 | switch { 227 | case pv.Key != "": 228 | bv, err := pv.lookupValue(bindVars) 229 | if err != nil { 230 | return nil, err 231 | } 232 | for i := range rows { 233 | rows[i][j] = MakeTrusted(bv.Type, bv.Value) 234 | } 235 | case !pv.Value.IsNull(): 236 | for i := range rows { 237 | rows[i][j] = pv.Value 238 | } 239 | case pv.ListKey != "": 240 | bv, err := pv.lookupList(bindVars) 241 | if err != nil { 242 | // This code is unreachable because pvRowCount already checks this. 243 | return nil, err 244 | } 245 | for i := range rows { 246 | rows[i][j] = MakeTrusted(bv.Values[i].Type, bv.Values[i].Value) 247 | } 248 | case pv.Values != nil: 249 | for i := range rows { 250 | rows[i][j], err = pv.Values[i].ResolveValue(bindVars) 251 | if err != nil { 252 | return nil, err 253 | } 254 | } 255 | // default case is a NULL value, which the row values are already initialized to. 
256 | } 257 | } 258 | return rows, nil 259 | } 260 | -------------------------------------------------------------------------------- /dependency/sqltypes/plan_value_test.go: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2017 Google Inc. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | */ 16 | 17 | package sqltypes 18 | 19 | import ( 20 | "reflect" 21 | "strings" 22 | "testing" 23 | 24 | "github.com/xwb1989/sqlparser/dependency/querypb" 25 | ) 26 | 27 | func TestPlanValueIsNull(t *testing.T) { 28 | tcases := []struct { 29 | in PlanValue 30 | out bool 31 | }{{ 32 | in: PlanValue{}, 33 | out: true, 34 | }, { 35 | in: PlanValue{Key: "aa"}, 36 | out: false, 37 | }, { 38 | in: PlanValue{Value: NewVarBinary("aa")}, 39 | out: false, 40 | }, { 41 | in: PlanValue{ListKey: "aa"}, 42 | out: false, 43 | }, { 44 | in: PlanValue{Values: []PlanValue{}}, 45 | out: false, 46 | }} 47 | for _, tc := range tcases { 48 | got := tc.in.IsNull() 49 | if got != tc.out { 50 | t.Errorf("IsNull(%v): %v, want %v", tc.in, got, tc.out) 51 | } 52 | } 53 | } 54 | 55 | func TestPlanValueIsList(t *testing.T) { 56 | tcases := []struct { 57 | in PlanValue 58 | out bool 59 | }{{ 60 | in: PlanValue{}, 61 | out: false, 62 | }, { 63 | in: PlanValue{Key: "aa"}, 64 | out: false, 65 | }, { 66 | in: PlanValue{Value: NewVarBinary("aa")}, 67 | out: false, 68 | }, { 69 | in: PlanValue{ListKey: "aa"}, 70 | out: true, 71 | }, { 72 | in: 
PlanValue{Values: []PlanValue{}}, 73 | out: true, 74 | }} 75 | for _, tc := range tcases { 76 | got := tc.in.IsList() 77 | if got != tc.out { 78 | t.Errorf("IsList(%v): %v, want %v", tc.in, got, tc.out) 79 | } 80 | } 81 | } 82 | 83 | func TestResolveRows(t *testing.T) { 84 | testBindVars := map[string]*querypb.BindVariable{ 85 | "int": Int64BindVariable(10), 86 | "intstr": TestBindVariable([]interface{}{10, "aa"}), 87 | } 88 | intValue := MakeTrusted(Int64, []byte("10")) 89 | strValue := MakeTrusted(VarChar, []byte("aa")) 90 | tcases := []struct { 91 | in []PlanValue 92 | out [][]Value 93 | err string 94 | }{{ 95 | // Simple cases. 96 | in: []PlanValue{ 97 | {Key: "int"}, 98 | }, 99 | out: [][]Value{ 100 | {intValue}, 101 | }, 102 | }, { 103 | in: []PlanValue{ 104 | {Value: intValue}, 105 | }, 106 | out: [][]Value{ 107 | {intValue}, 108 | }, 109 | }, { 110 | in: []PlanValue{ 111 | {ListKey: "intstr"}, 112 | }, 113 | out: [][]Value{ 114 | {intValue}, 115 | {strValue}, 116 | }, 117 | }, { 118 | in: []PlanValue{ 119 | {Values: []PlanValue{{Value: intValue}, {Value: strValue}}}, 120 | }, 121 | out: [][]Value{ 122 | {intValue}, 123 | {strValue}, 124 | }, 125 | }, { 126 | in: []PlanValue{ 127 | {Values: []PlanValue{{Key: "int"}, {Value: strValue}}}, 128 | }, 129 | out: [][]Value{ 130 | {intValue}, 131 | {strValue}, 132 | }, 133 | }, { 134 | in: []PlanValue{{}}, 135 | out: [][]Value{ 136 | {NULL}, 137 | }, 138 | }, { 139 | // Cases with varying rowcounts. 140 | // All types of input.. 141 | in: []PlanValue{ 142 | {Key: "int"}, 143 | {Value: strValue}, 144 | {ListKey: "intstr"}, 145 | {Values: []PlanValue{{Value: strValue}, {Value: intValue}}}, 146 | }, 147 | out: [][]Value{ 148 | {intValue, strValue, intValue, strValue}, 149 | {intValue, strValue, strValue, intValue}, 150 | }, 151 | }, { 152 | // list, val, list. 
153 | in: []PlanValue{ 154 | {Value: strValue}, 155 | {Key: "int"}, 156 | {Values: []PlanValue{{Value: strValue}, {Value: intValue}}}, 157 | }, 158 | out: [][]Value{ 159 | {strValue, intValue, strValue}, 160 | {strValue, intValue, intValue}, 161 | }, 162 | }, { 163 | // list, list 164 | in: []PlanValue{ 165 | {ListKey: "intstr"}, 166 | {Values: []PlanValue{{Value: strValue}, {Value: intValue}}}, 167 | }, 168 | out: [][]Value{ 169 | {intValue, strValue}, 170 | {strValue, intValue}, 171 | }, 172 | }, { 173 | // Error cases 174 | in: []PlanValue{ 175 | {ListKey: "intstr"}, 176 | {Values: []PlanValue{{Value: strValue}}}, 177 | }, 178 | err: "mismatch in number of column values", 179 | }, { 180 | // This is a different code path for a similar validation. 181 | in: []PlanValue{ 182 | {Values: []PlanValue{{Value: strValue}}}, 183 | {ListKey: "intstr"}, 184 | }, 185 | err: "mismatch in number of column values", 186 | }, { 187 | in: []PlanValue{ 188 | {Key: "absent"}, 189 | }, 190 | err: "missing bind var", 191 | }, { 192 | in: []PlanValue{ 193 | {ListKey: "absent"}, 194 | }, 195 | err: "missing bind var", 196 | }, { 197 | in: []PlanValue{ 198 | {Values: []PlanValue{{Key: "absent"}}}, 199 | }, 200 | err: "missing bind var", 201 | }} 202 | 203 | for _, tc := range tcases { 204 | got, err := ResolveRows(tc.in, testBindVars) 205 | if err != nil { 206 | if !strings.Contains(err.Error(), tc.err) { 207 | t.Errorf("ResolveRows(%v) error: %v, want '%s'", tc.in, err, tc.err) 208 | } 209 | continue 210 | } 211 | if tc.err != "" { 212 | t.Errorf("ResolveRows(%v) error: nil, want '%s'", tc.in, tc.err) 213 | continue 214 | } 215 | if !reflect.DeepEqual(got, tc.out) { 216 | t.Errorf("ResolveRows(%v): %v, want %v", tc.in, got, tc.out) 217 | } 218 | } 219 | } 220 | 221 | func TestResolveList(t *testing.T) { 222 | testBindVars := map[string]*querypb.BindVariable{ 223 | "int": Int64BindVariable(10), 224 | "intstr": TestBindVariable([]interface{}{10, "aa"}), 225 | } 226 | intValue := 
MakeTrusted(Int64, []byte("10")) 227 | strValue := MakeTrusted(VarChar, []byte("aa")) 228 | tcases := []struct { 229 | in PlanValue 230 | out []Value 231 | err string 232 | }{{ 233 | in: PlanValue{ListKey: "intstr"}, 234 | out: []Value{intValue, strValue}, 235 | }, { 236 | in: PlanValue{Values: []PlanValue{{Value: intValue}, {Value: strValue}}}, 237 | out: []Value{intValue, strValue}, 238 | }, { 239 | in: PlanValue{Values: []PlanValue{{Key: "int"}, {Value: strValue}}}, 240 | out: []Value{intValue, strValue}, 241 | }, { 242 | in: PlanValue{ListKey: "absent"}, 243 | err: "missing bind var", 244 | }, { 245 | in: PlanValue{Values: []PlanValue{{Key: "absent"}, {Value: strValue}}}, 246 | err: "missing bind var", 247 | }, { 248 | in: PlanValue{ListKey: "int"}, 249 | err: "single value was supplied for TUPLE bind var", 250 | }, { 251 | in: PlanValue{Key: "int"}, 252 | err: "a single value was supplied where a list was expected", 253 | }} 254 | 255 | for _, tc := range tcases { 256 | got, err := tc.in.ResolveList(testBindVars) 257 | if err != nil { 258 | if !strings.Contains(err.Error(), tc.err) { 259 | t.Errorf("ResolveList(%v) error: %v, want '%s'", tc.in, err, tc.err) 260 | } 261 | continue 262 | } 263 | if tc.err != "" { 264 | t.Errorf("ResolveList(%v) error: nil, want '%s'", tc.in, tc.err) 265 | continue 266 | } 267 | if !reflect.DeepEqual(got, tc.out) { 268 | t.Errorf("ResolveList(%v): %v, want %v", tc.in, got, tc.out) 269 | } 270 | } 271 | } 272 | 273 | func TestResolveValue(t *testing.T) { 274 | testBindVars := map[string]*querypb.BindVariable{ 275 | "int": Int64BindVariable(10), 276 | "intstr": TestBindVariable([]interface{}{10, "aa"}), 277 | } 278 | intValue := MakeTrusted(Int64, []byte("10")) 279 | tcases := []struct { 280 | in PlanValue 281 | out Value 282 | err string 283 | }{{ 284 | in: PlanValue{Key: "int"}, 285 | out: intValue, 286 | }, { 287 | in: PlanValue{Value: intValue}, 288 | out: intValue, 289 | }, { 290 | in: PlanValue{}, 291 | out: NULL, 292 | }, { 
293 | in: PlanValue{Key: "absent"}, 294 | err: "missing bind var", 295 | }, { 296 | in: PlanValue{Key: "intstr"}, 297 | err: "TUPLE was supplied for single value bind var", 298 | }, { 299 | in: PlanValue{ListKey: "intstr"}, 300 | err: "a list was supplied where a single value was expected", 301 | }} 302 | 303 | for _, tc := range tcases { 304 | got, err := tc.in.ResolveValue(testBindVars) 305 | if err != nil { 306 | if !strings.Contains(err.Error(), tc.err) { 307 | t.Errorf("ResolveValue(%v) error: %v, want '%s'", tc.in, err, tc.err) 308 | } 309 | continue 310 | } 311 | if tc.err != "" { 312 | t.Errorf("ResolveValue(%v) error: nil, want '%s'", tc.in, tc.err) 313 | continue 314 | } 315 | if !reflect.DeepEqual(got, tc.out) { 316 | t.Errorf("ResolveValue(%v): %v, want %v", tc.in, got, tc.out) 317 | } 318 | } 319 | } 320 | -------------------------------------------------------------------------------- /dependency/sqltypes/testing.go: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2017 Google Inc. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | */ 16 | 17 | package sqltypes 18 | 19 | import ( 20 | querypb "github.com/xwb1989/sqlparser/dependency/querypb" 21 | ) 22 | 23 | // Functions in this file should only be used for testing. 24 | // This is an experiment to see if test code bloat can be 25 | // reduced and readability improved. 
26 | 27 | /* 28 | // MakeTestFields builds a []*querypb.Field for testing. 29 | // fields := sqltypes.MakeTestFields( 30 | // "a|b", 31 | // "int64|varchar", 32 | // ) 33 | // The field types are as defined in querypb and are case 34 | // insensitive. Column delimiters must be used only to separate 35 | // strings and not at the beginning or the end. 36 | func MakeTestFields(names, types string) []*querypb.Field { 37 | n := split(names) 38 | t := split(types) 39 | var fields []*querypb.Field 40 | for i := range n { 41 | fields = append(fields, &querypb.Field{ 42 | Name: n[i], 43 | Type: querypb.Type(querypb.Type_value[strings.ToUpper(t[i])]), 44 | }) 45 | } 46 | return fields 47 | } 48 | 49 | // MakeTestResult builds a *sqltypes.Result object for testing. 50 | // result := sqltypes.MakeTestResult( 51 | // fields, 52 | // " 1|a", 53 | // "10|abcd", 54 | // ) 55 | // The field type values are set as the types for the rows built. 56 | // Spaces are trimmed from row values. "null" is treated as NULL. 57 | func MakeTestResult(fields []*querypb.Field, rows ...string) *Result { 58 | result := &Result{ 59 | Fields: fields, 60 | } 61 | if len(rows) > 0 { 62 | result.Rows = make([][]Value, len(rows)) 63 | } 64 | for i, row := range rows { 65 | result.Rows[i] = make([]Value, len(fields)) 66 | for j, col := range split(row) { 67 | if col == "null" { 68 | continue 69 | } 70 | result.Rows[i][j] = MakeTrusted(fields[j].Type, []byte(col)) 71 | } 72 | } 73 | result.RowsAffected = uint64(len(result.Rows)) 74 | return result 75 | } 76 | 77 | // MakeTestStreamingResults builds a list of results for streaming. 78 | // results := sqltypes.MakeTestStreamingResults( 79 | // fields, 80 | // "1|a", 81 | // "2|b", 82 | // "---", 83 | // "c|c", 84 | // ) 85 | // The first result contains only the fields. Subsequent results 86 | // are built using the field types. Every input that starts with a "-" 87 | // is treated as a streaming delimiter for one result.
A final 88 | delimiter must not be supplied. 89 | func MakeTestStreamingResults(fields []*querypb.Field, rows ...string) []*Result { 90 | var results []*Result 91 | results = append(results, &Result{Fields: fields}) 92 | start := 0 93 | cur := 0 94 | // Add a final streaming delimiter to simplify the loop below. 95 | rows = append(rows, "-") 96 | for cur < len(rows) { 97 | if rows[cur][0] != '-' { 98 | cur++ 99 | continue 100 | } 101 | result := MakeTestResult(fields, rows[start:cur]...) 102 | result.Fields = nil 103 | result.RowsAffected = 0 104 | results = append(results, result) 105 | start = cur + 1 106 | cur = start 107 | } 108 | return results 109 | } 110 | */ 111 | 112 | // TestBindVariable makes a *querypb.BindVariable from 113 | // an interface{}. It panics on invalid input. 114 | // This function should only be used for testing. 115 | func TestBindVariable(v interface{}) *querypb.BindVariable { 116 | if v == nil { 117 | return NullBindVariable 118 | } 119 | bv, err := BuildBindVariable(v) 120 | if err != nil { 121 | panic(err) 122 | } 123 | return bv 124 | } 125 | 126 | // TestValue builds a Value from typ and val. 127 | // This function should only be used for testing. 128 | func TestValue(typ querypb.Type, val string) Value { 129 | return MakeTrusted(typ, []byte(val)) 130 | } 131 | 132 | /* 133 | // PrintResults prints []*Result into a string. 134 | // This function should only be used for testing.
135 | func PrintResults(results []*Result) string { 136 | b := new(bytes.Buffer) 137 | for i, r := range results { 138 | if i == 0 { 139 | fmt.Fprintf(b, "%v", r) 140 | continue 141 | } 142 | fmt.Fprintf(b, ", %v", r) 143 | } 144 | return b.String() 145 | } 146 | 147 | func split(str string) []string { 148 | splits := strings.Split(str, "|") 149 | for i, v := range splits { 150 | splits[i] = strings.TrimSpace(v) 151 | } 152 | return splits 153 | } 154 | */ 155 | -------------------------------------------------------------------------------- /dependency/sqltypes/type.go: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2017 Google Inc. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | */ 16 | 17 | package sqltypes 18 | 19 | import ( 20 | "fmt" 21 | 22 | "github.com/xwb1989/sqlparser/dependency/querypb" 23 | ) 24 | 25 | // This file provides wrappers and support 26 | // functions for querypb.Type. 27 | 28 | // These bit flags can be used to query on the 29 | // common properties of types. 
30 | const ( 31 | flagIsIntegral = int(querypb.Flag_ISINTEGRAL) 32 | flagIsUnsigned = int(querypb.Flag_ISUNSIGNED) 33 | flagIsFloat = int(querypb.Flag_ISFLOAT) 34 | flagIsQuoted = int(querypb.Flag_ISQUOTED) 35 | flagIsText = int(querypb.Flag_ISTEXT) 36 | flagIsBinary = int(querypb.Flag_ISBINARY) 37 | ) 38 | 39 | // IsIntegral returns true if querypb.Type is an integral 40 | // (signed/unsigned) that can be represented using 41 | // up to 64 binary bits. 42 | // If you have a Value object, use its member function. 43 | func IsIntegral(t querypb.Type) bool { 44 | return int(t)&flagIsIntegral == flagIsIntegral 45 | } 46 | 47 | // IsSigned returns true if querypb.Type is a signed integral. 48 | // If you have a Value object, use its member function. 49 | func IsSigned(t querypb.Type) bool { 50 | return int(t)&(flagIsIntegral|flagIsUnsigned) == flagIsIntegral 51 | } 52 | 53 | // IsUnsigned returns true if querypb.Type is an unsigned integral. 54 | // Caution: this is not the same as !IsSigned. 55 | // If you have a Value object, use its member function. 56 | func IsUnsigned(t querypb.Type) bool { 57 | return int(t)&(flagIsIntegral|flagIsUnsigned) == flagIsIntegral|flagIsUnsigned 58 | } 59 | 60 | // IsFloat returns true if querypb.Type is a floating point. 61 | // If you have a Value object, use its member function. 62 | func IsFloat(t querypb.Type) bool { 63 | return int(t)&flagIsFloat == flagIsFloat 64 | } 65 | 66 | // IsQuoted returns true if querypb.Type is a quoted text or binary. 67 | // If you have a Value object, use its member function. 68 | func IsQuoted(t querypb.Type) bool { 69 | return int(t)&flagIsQuoted == flagIsQuoted 70 | } 71 | 72 | // IsText returns true if querypb.Type is a text. 73 | // If you have a Value object, use its member function. 74 | func IsText(t querypb.Type) bool { 75 | return int(t)&flagIsText == flagIsText 76 | } 77 | 78 | // IsBinary returns true if querypb.Type is a binary. 79 | // If you have a Value object, use its member function.
80 | func IsBinary(t querypb.Type) bool { 81 | return int(t)&flagIsBinary == flagIsBinary 82 | } 83 | 84 | // isNumber returns true if the type is any type of number. 85 | func isNumber(t querypb.Type) bool { 86 | return IsIntegral(t) || IsFloat(t) || t == Decimal 87 | } 88 | 89 | // Vitess data types. These are idiomatically 90 | // named synonyms for the querypb.Type values. 91 | // Although these constants are interchangeable, 92 | // they should be treated as different from querypb.Type. 93 | // Use the synonyms only to refer to the type in Value. 94 | // For proto variables, use the querypb.Type constants 95 | // instead. 96 | // The following conditions are non-overlapping 97 | // and cover all types: IsSigned(), IsUnsigned(), 98 | // IsFloat(), IsQuoted(), Null, Decimal, Expression. 99 | // Also, IsIntegral() == (IsSigned()||IsUnsigned()). 100 | // TestCategory needs to be updated accordingly if 101 | // you add a new type. 102 | // If IsBinary or IsText is true, then IsQuoted is 103 | // also true. But there are IsQuoted types that are 104 | // neither binary nor text. 105 | // querypb.Type_TUPLE is not included in this list 106 | // because it's not a valid Value type. 107 | // TODO(sougou): provide a categorization function 108 | // that returns enums, which will allow for cleaner 109 | // switch statements for those who want to cover types 110 | // by their category.
111 | const ( 112 | Null = querypb.Type_NULL_TYPE 113 | Int8 = querypb.Type_INT8 114 | Uint8 = querypb.Type_UINT8 115 | Int16 = querypb.Type_INT16 116 | Uint16 = querypb.Type_UINT16 117 | Int24 = querypb.Type_INT24 118 | Uint24 = querypb.Type_UINT24 119 | Int32 = querypb.Type_INT32 120 | Uint32 = querypb.Type_UINT32 121 | Int64 = querypb.Type_INT64 122 | Uint64 = querypb.Type_UINT64 123 | Float32 = querypb.Type_FLOAT32 124 | Float64 = querypb.Type_FLOAT64 125 | Timestamp = querypb.Type_TIMESTAMP 126 | Date = querypb.Type_DATE 127 | Time = querypb.Type_TIME 128 | Datetime = querypb.Type_DATETIME 129 | Year = querypb.Type_YEAR 130 | Decimal = querypb.Type_DECIMAL 131 | Text = querypb.Type_TEXT 132 | Blob = querypb.Type_BLOB 133 | VarChar = querypb.Type_VARCHAR 134 | VarBinary = querypb.Type_VARBINARY 135 | Char = querypb.Type_CHAR 136 | Binary = querypb.Type_BINARY 137 | Bit = querypb.Type_BIT 138 | Enum = querypb.Type_ENUM 139 | Set = querypb.Type_SET 140 | Geometry = querypb.Type_GEOMETRY 141 | TypeJSON = querypb.Type_JSON 142 | Expression = querypb.Type_EXPRESSION 143 | ) 144 | 145 | // bit-shift the mysql flags by two byte so we 146 | // can merge them with the mysql or vitess types. 147 | const ( 148 | mysqlUnsigned = 32 149 | mysqlBinary = 128 150 | mysqlEnum = 256 151 | mysqlSet = 2048 152 | ) 153 | 154 | // If you add to this map, make sure you add a test case 155 | // in tabletserver/endtoend. 156 | var mysqlToType = map[int64]querypb.Type{ 157 | 1: Int8, 158 | 2: Int16, 159 | 3: Int32, 160 | 4: Float32, 161 | 5: Float64, 162 | 6: Null, 163 | 7: Timestamp, 164 | 8: Int64, 165 | 9: Int24, 166 | 10: Date, 167 | 11: Time, 168 | 12: Datetime, 169 | 13: Year, 170 | 16: Bit, 171 | 245: TypeJSON, 172 | 246: Decimal, 173 | 249: Text, 174 | 250: Text, 175 | 251: Text, 176 | 252: Text, 177 | 253: VarChar, 178 | 254: Char, 179 | 255: Geometry, 180 | } 181 | 182 | // modifyType modifies the vitess type based on the 183 | // mysql flag. 
The function checks specific flags based 184 | // on the type. This allows us to ignore stray flags 185 | // that MySQL occasionally sets. 186 | func modifyType(typ querypb.Type, flags int64) querypb.Type { 187 | switch typ { 188 | case Int8: 189 | if flags&mysqlUnsigned != 0 { 190 | return Uint8 191 | } 192 | return Int8 193 | case Int16: 194 | if flags&mysqlUnsigned != 0 { 195 | return Uint16 196 | } 197 | return Int16 198 | case Int32: 199 | if flags&mysqlUnsigned != 0 { 200 | return Uint32 201 | } 202 | return Int32 203 | case Int64: 204 | if flags&mysqlUnsigned != 0 { 205 | return Uint64 206 | } 207 | return Int64 208 | case Int24: 209 | if flags&mysqlUnsigned != 0 { 210 | return Uint24 211 | } 212 | return Int24 213 | case Text: 214 | if flags&mysqlBinary != 0 { 215 | return Blob 216 | } 217 | return Text 218 | case VarChar: 219 | if flags&mysqlBinary != 0 { 220 | return VarBinary 221 | } 222 | return VarChar 223 | case Char: 224 | if flags&mysqlBinary != 0 { 225 | return Binary 226 | } 227 | if flags&mysqlEnum != 0 { 228 | return Enum 229 | } 230 | if flags&mysqlSet != 0 { 231 | return Set 232 | } 233 | return Char 234 | } 235 | return typ 236 | } 237 | 238 | // MySQLToType computes the vitess type from mysql type and flags. 239 | func MySQLToType(mysqlType, flags int64) (typ querypb.Type, err error) { 240 | result, ok := mysqlToType[mysqlType] 241 | if !ok { 242 | return 0, fmt.Errorf("unsupported type: %d", mysqlType) 243 | } 244 | return modifyType(result, flags), nil 245 | } 246 | 247 | // typeToMySQL is the reverse of mysqlToType. 
248 | var typeToMySQL = map[querypb.Type]struct { 249 | typ int64 250 | flags int64 251 | }{ 252 | Int8: {typ: 1}, 253 | Uint8: {typ: 1, flags: mysqlUnsigned}, 254 | Int16: {typ: 2}, 255 | Uint16: {typ: 2, flags: mysqlUnsigned}, 256 | Int32: {typ: 3}, 257 | Uint32: {typ: 3, flags: mysqlUnsigned}, 258 | Float32: {typ: 4}, 259 | Float64: {typ: 5}, 260 | Null: {typ: 6, flags: mysqlBinary}, 261 | Timestamp: {typ: 7}, 262 | Int64: {typ: 8}, 263 | Uint64: {typ: 8, flags: mysqlUnsigned}, 264 | Int24: {typ: 9}, 265 | Uint24: {typ: 9, flags: mysqlUnsigned}, 266 | Date: {typ: 10, flags: mysqlBinary}, 267 | Time: {typ: 11, flags: mysqlBinary}, 268 | Datetime: {typ: 12, flags: mysqlBinary}, 269 | Year: {typ: 13, flags: mysqlUnsigned}, 270 | Bit: {typ: 16, flags: mysqlUnsigned}, 271 | TypeJSON: {typ: 245}, 272 | Decimal: {typ: 246}, 273 | Text: {typ: 252}, 274 | Blob: {typ: 252, flags: mysqlBinary}, 275 | VarChar: {typ: 253}, 276 | VarBinary: {typ: 253, flags: mysqlBinary}, 277 | Char: {typ: 254}, 278 | Binary: {typ: 254, flags: mysqlBinary}, 279 | Enum: {typ: 254, flags: mysqlEnum}, 280 | Set: {typ: 254, flags: mysqlSet}, 281 | Geometry: {typ: 255}, 282 | } 283 | 284 | // TypeToMySQL returns the equivalent mysql type and flag for a vitess type. 285 | func TypeToMySQL(typ querypb.Type) (mysqlType, flags int64) { 286 | val := typeToMySQL[typ] 287 | return val.typ, val.flags 288 | } 289 | -------------------------------------------------------------------------------- /dependency/sqltypes/type_test.go: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2017 Google Inc. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 
6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | */ 16 | 17 | package sqltypes 18 | 19 | import ( 20 | "testing" 21 | 22 | "github.com/xwb1989/sqlparser/dependency/querypb" 23 | ) 24 | 25 | func TestTypeValues(t *testing.T) { 26 | testcases := []struct { 27 | defined querypb.Type 28 | expected int 29 | }{{ 30 | defined: Null, 31 | expected: 0, 32 | }, { 33 | defined: Int8, 34 | expected: 1 | flagIsIntegral, 35 | }, { 36 | defined: Uint8, 37 | expected: 2 | flagIsIntegral | flagIsUnsigned, 38 | }, { 39 | defined: Int16, 40 | expected: 3 | flagIsIntegral, 41 | }, { 42 | defined: Uint16, 43 | expected: 4 | flagIsIntegral | flagIsUnsigned, 44 | }, { 45 | defined: Int24, 46 | expected: 5 | flagIsIntegral, 47 | }, { 48 | defined: Uint24, 49 | expected: 6 | flagIsIntegral | flagIsUnsigned, 50 | }, { 51 | defined: Int32, 52 | expected: 7 | flagIsIntegral, 53 | }, { 54 | defined: Uint32, 55 | expected: 8 | flagIsIntegral | flagIsUnsigned, 56 | }, { 57 | defined: Int64, 58 | expected: 9 | flagIsIntegral, 59 | }, { 60 | defined: Uint64, 61 | expected: 10 | flagIsIntegral | flagIsUnsigned, 62 | }, { 63 | defined: Float32, 64 | expected: 11 | flagIsFloat, 65 | }, { 66 | defined: Float64, 67 | expected: 12 | flagIsFloat, 68 | }, { 69 | defined: Timestamp, 70 | expected: 13 | flagIsQuoted, 71 | }, { 72 | defined: Date, 73 | expected: 14 | flagIsQuoted, 74 | }, { 75 | defined: Time, 76 | expected: 15 | flagIsQuoted, 77 | }, { 78 | defined: Datetime, 79 | expected: 16 | flagIsQuoted, 80 | }, { 81 | defined: Year, 82 | expected: 17 | flagIsIntegral | flagIsUnsigned, 83 | }, { 84 | defined: Decimal, 85 | 
expected: 18, 86 | }, { 87 | defined: Text, 88 | expected: 19 | flagIsQuoted | flagIsText, 89 | }, { 90 | defined: Blob, 91 | expected: 20 | flagIsQuoted | flagIsBinary, 92 | }, { 93 | defined: VarChar, 94 | expected: 21 | flagIsQuoted | flagIsText, 95 | }, { 96 | defined: VarBinary, 97 | expected: 22 | flagIsQuoted | flagIsBinary, 98 | }, { 99 | defined: Char, 100 | expected: 23 | flagIsQuoted | flagIsText, 101 | }, { 102 | defined: Binary, 103 | expected: 24 | flagIsQuoted | flagIsBinary, 104 | }, { 105 | defined: Bit, 106 | expected: 25 | flagIsQuoted, 107 | }, { 108 | defined: Enum, 109 | expected: 26 | flagIsQuoted, 110 | }, { 111 | defined: Set, 112 | expected: 27 | flagIsQuoted, 113 | }, { 114 | defined: Geometry, 115 | expected: 29 | flagIsQuoted, 116 | }, { 117 | defined: TypeJSON, 118 | expected: 30 | flagIsQuoted, 119 | }, { 120 | defined: Expression, 121 | expected: 31, 122 | }} 123 | for _, tcase := range testcases { 124 | if int(tcase.defined) != tcase.expected { 125 | t.Errorf("Type %s: %d, want: %d", tcase.defined, int(tcase.defined), tcase.expected) 126 | } 127 | } 128 | } 129 | 130 | // TestCategory verifies that the type categorizations 131 | // are non-overlapping and complete. 
132 | func TestCategory(t *testing.T) { 133 | alltypes := []querypb.Type{ 134 | Null, 135 | Int8, 136 | Uint8, 137 | Int16, 138 | Uint16, 139 | Int24, 140 | Uint24, 141 | Int32, 142 | Uint32, 143 | Int64, 144 | Uint64, 145 | Float32, 146 | Float64, 147 | Timestamp, 148 | Date, 149 | Time, 150 | Datetime, 151 | Year, 152 | Decimal, 153 | Text, 154 | Blob, 155 | VarChar, 156 | VarBinary, 157 | Char, 158 | Binary, 159 | Bit, 160 | Enum, 161 | Set, 162 | Geometry, 163 | TypeJSON, 164 | Expression, 165 | } 166 | for _, typ := range alltypes { 167 | matched := false 168 | if IsSigned(typ) { 169 | if !IsIntegral(typ) { 170 | t.Errorf("Signed type %v is not an integral", typ) 171 | } 172 | matched = true 173 | } 174 | if IsUnsigned(typ) { 175 | if !IsIntegral(typ) { 176 | t.Errorf("Unsigned type %v is not an integral", typ) 177 | } 178 | if matched { 179 | t.Errorf("%v matched more than one category", typ) 180 | } 181 | matched = true 182 | } 183 | if IsFloat(typ) { 184 | if matched { 185 | t.Errorf("%v matched more than one category", typ) 186 | } 187 | matched = true 188 | } 189 | if IsQuoted(typ) { 190 | if matched { 191 | t.Errorf("%v matched more than one category", typ) 192 | } 193 | matched = true 194 | } 195 | if typ == Null || typ == Decimal || typ == Expression { 196 | if matched { 197 | t.Errorf("%v matched more than one category", typ) 198 | } 199 | matched = true 200 | } 201 | if !matched { 202 | t.Errorf("%v matched no category", typ) 203 | } 204 | } 205 | } 206 | 207 | func TestIsFunctions(t *testing.T) { 208 | if IsIntegral(Null) { 209 | t.Error("Null: IsIntegral, must be false") 210 | } 211 | if !IsIntegral(Int64) { 212 | t.Error("Int64: !IsIntegral, must be true") 213 | } 214 | if IsSigned(Uint64) { 215 | t.Error("Uint64: IsSigned, must be false") 216 | } 217 | if !IsSigned(Int64) { 218 | t.Error("Int64: !IsSigned, must be true") 219 | } 220 | if IsUnsigned(Int64) { 221 | t.Error("Int64: IsUnsigned, must be false") 222 | } 223 | if !IsUnsigned(Uint64) { 
224 | t.Error("Uint64: !IsUnsigned, must be true") 225 | } 226 | if IsFloat(Int64) { 227 | t.Error("Int64: IsFloat, must be false") 228 | } 229 | if !IsFloat(Float64) { 230 | t.Error("Uint64: !IsFloat, must be true") 231 | } 232 | if IsQuoted(Int64) { 233 | t.Error("Int64: IsQuoted, must be false") 234 | } 235 | if !IsQuoted(Binary) { 236 | t.Error("Binary: !IsQuoted, must be true") 237 | } 238 | if IsText(Int64) { 239 | t.Error("Int64: IsText, must be false") 240 | } 241 | if !IsText(Char) { 242 | t.Error("Char: !IsText, must be true") 243 | } 244 | if IsBinary(Int64) { 245 | t.Error("Int64: IsBinary, must be false") 246 | } 247 | if !IsBinary(Binary) { 248 | t.Error("Char: !IsBinary, must be true") 249 | } 250 | if !isNumber(Int64) { 251 | t.Error("Int64: !isNumber, must be true") 252 | } 253 | } 254 | 255 | func TestTypeToMySQL(t *testing.T) { 256 | v, f := TypeToMySQL(Bit) 257 | if v != 16 { 258 | t.Errorf("Bit: %d, want 16", v) 259 | } 260 | if f != mysqlUnsigned { 261 | t.Errorf("Bit flag: %x, want %x", f, mysqlUnsigned) 262 | } 263 | v, f = TypeToMySQL(Date) 264 | if v != 10 { 265 | t.Errorf("Bit: %d, want 10", v) 266 | } 267 | if f != mysqlBinary { 268 | t.Errorf("Bit flag: %x, want %x", f, mysqlBinary) 269 | } 270 | } 271 | 272 | func TestMySQLToType(t *testing.T) { 273 | testcases := []struct { 274 | intype int64 275 | inflags int64 276 | outtype querypb.Type 277 | }{{ 278 | intype: 1, 279 | outtype: Int8, 280 | }, { 281 | intype: 1, 282 | inflags: mysqlUnsigned, 283 | outtype: Uint8, 284 | }, { 285 | intype: 2, 286 | outtype: Int16, 287 | }, { 288 | intype: 2, 289 | inflags: mysqlUnsigned, 290 | outtype: Uint16, 291 | }, { 292 | intype: 3, 293 | outtype: Int32, 294 | }, { 295 | intype: 3, 296 | inflags: mysqlUnsigned, 297 | outtype: Uint32, 298 | }, { 299 | intype: 4, 300 | outtype: Float32, 301 | }, { 302 | intype: 5, 303 | outtype: Float64, 304 | }, { 305 | intype: 6, 306 | outtype: Null, 307 | }, { 308 | intype: 7, 309 | outtype: Timestamp, 310 | }, { 
311 | intype: 8, 312 | outtype: Int64, 313 | }, { 314 | intype: 8, 315 | inflags: mysqlUnsigned, 316 | outtype: Uint64, 317 | }, { 318 | intype: 9, 319 | outtype: Int24, 320 | }, { 321 | intype: 9, 322 | inflags: mysqlUnsigned, 323 | outtype: Uint24, 324 | }, { 325 | intype: 10, 326 | outtype: Date, 327 | }, { 328 | intype: 11, 329 | outtype: Time, 330 | }, { 331 | intype: 12, 332 | outtype: Datetime, 333 | }, { 334 | intype: 13, 335 | outtype: Year, 336 | }, { 337 | intype: 16, 338 | outtype: Bit, 339 | }, { 340 | intype: 245, 341 | outtype: TypeJSON, 342 | }, { 343 | intype: 246, 344 | outtype: Decimal, 345 | }, { 346 | intype: 249, 347 | outtype: Text, 348 | }, { 349 | intype: 250, 350 | outtype: Text, 351 | }, { 352 | intype: 251, 353 | outtype: Text, 354 | }, { 355 | intype: 252, 356 | outtype: Text, 357 | }, { 358 | intype: 252, 359 | inflags: mysqlBinary, 360 | outtype: Blob, 361 | }, { 362 | intype: 253, 363 | outtype: VarChar, 364 | }, { 365 | intype: 253, 366 | inflags: mysqlBinary, 367 | outtype: VarBinary, 368 | }, { 369 | intype: 254, 370 | outtype: Char, 371 | }, { 372 | intype: 254, 373 | inflags: mysqlBinary, 374 | outtype: Binary, 375 | }, { 376 | intype: 254, 377 | inflags: mysqlEnum, 378 | outtype: Enum, 379 | }, { 380 | intype: 254, 381 | inflags: mysqlSet, 382 | outtype: Set, 383 | }, { 384 | intype: 255, 385 | outtype: Geometry, 386 | }, { 387 | // Binary flag must be ignored. 
388 | intype: 8, 389 | inflags: mysqlUnsigned | mysqlBinary, 390 | outtype: Uint64, 391 | }, { 392 | // Unsigned flag must be ignored 393 | intype: 252, 394 | inflags: mysqlUnsigned | mysqlBinary, 395 | outtype: Blob, 396 | }} 397 | for _, tcase := range testcases { 398 | got, err := MySQLToType(tcase.intype, tcase.inflags) 399 | if err != nil { 400 | t.Error(err) 401 | } 402 | if got != tcase.outtype { 403 | t.Errorf("MySQLToType(%d, %x): %v, want %v", tcase.intype, tcase.inflags, got, tcase.outtype) 404 | } 405 | } 406 | } 407 | 408 | func TestTypeError(t *testing.T) { 409 | _, err := MySQLToType(15, 0) 410 | want := "unsupported type: 15" 411 | if err == nil || err.Error() != want { 412 | t.Errorf("MySQLToType: %v, want %s", err, want) 413 | } 414 | } 415 | -------------------------------------------------------------------------------- /dependency/sqltypes/value.go: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2017 Google Inc. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | */ 16 | 17 | // Package sqltypes implements interfaces and types that represent SQL values. 
18 | package sqltypes 19 | 20 | import ( 21 | "encoding/base64" 22 | "encoding/json" 23 | "fmt" 24 | "strconv" 25 | 26 | "github.com/xwb1989/sqlparser/dependency/bytes2" 27 | "github.com/xwb1989/sqlparser/dependency/hack" 28 | 29 | "github.com/xwb1989/sqlparser/dependency/querypb" 30 | ) 31 | 32 | var ( 33 | // NULL represents the NULL value. 34 | NULL = Value{} 35 | 36 | // DontEscape tells you if a character should not be escaped. 37 | DontEscape = byte(255) 38 | 39 | nullstr = []byte("null") 40 | ) 41 | 42 | // BinWriter interface is used for encoding values. 43 | // Types like bytes.Buffer conform to this interface. 44 | // We expect the writer objects to be in-memory buffers. 45 | // So, we don't expect the write operations to fail. 46 | type BinWriter interface { 47 | Write([]byte) (int, error) 48 | } 49 | 50 | // Value can store any SQL value. If the value represents 51 | // an integral type, the bytes are always stored as a canonical 52 | // representation that matches how MySQL returns such values. 53 | type Value struct { 54 | typ querypb.Type 55 | val []byte 56 | } 57 | 58 | // NewValue builds a Value using typ and val. If the value and typ 59 | // don't match, it returns an error. 60 | func NewValue(typ querypb.Type, val []byte) (v Value, err error) { 61 | switch { 62 | case IsSigned(typ): 63 | if _, err := strconv.ParseInt(string(val), 0, 64); err != nil { 64 | return NULL, err 65 | } 66 | return MakeTrusted(typ, val), nil 67 | case IsUnsigned(typ): 68 | if _, err := strconv.ParseUint(string(val), 0, 64); err != nil { 69 | return NULL, err 70 | } 71 | return MakeTrusted(typ, val), nil 72 | case IsFloat(typ) || typ == Decimal: 73 | if _, err := strconv.ParseFloat(string(val), 64); err != nil { 74 | return NULL, err 75 | } 76 | return MakeTrusted(typ, val), nil 77 | case IsQuoted(typ) || typ == Null: 78 | return MakeTrusted(typ, val), nil 79 | } 80 | // All other types are unsafe or invalid.
81 | return NULL, fmt.Errorf("invalid type specified for MakeValue: %v", typ) 82 | } 83 | 84 | // MakeTrusted makes a new Value based on the type. 85 | // This function should only be used if you know the value 86 | // and type conform to the rules. Every place this function is 87 | // called, a comment is needed that explains why it's justified. 88 | // Exceptions: The current package and mysql package do not need 89 | // comments. Other packages can also use the function to create 90 | // VarBinary or VarChar values. 91 | func MakeTrusted(typ querypb.Type, val []byte) Value { 92 | if typ == Null { 93 | return NULL 94 | } 95 | return Value{typ: typ, val: val} 96 | } 97 | 98 | // NewInt64 builds an Int64 Value. 99 | func NewInt64(v int64) Value { 100 | return MakeTrusted(Int64, strconv.AppendInt(nil, v, 10)) 101 | } 102 | 103 | // NewInt32 builds an Int32 Value. 104 | func NewInt32(v int32) Value { 105 | return MakeTrusted(Int32, strconv.AppendInt(nil, int64(v), 10)) 106 | } 107 | 108 | // NewUint64 builds a Uint64 Value. 109 | func NewUint64(v uint64) Value { 110 | return MakeTrusted(Uint64, strconv.AppendUint(nil, v, 10)) 111 | } 112 | 113 | // NewFloat64 builds a Float64 Value. 114 | func NewFloat64(v float64) Value { 115 | return MakeTrusted(Float64, strconv.AppendFloat(nil, v, 'g', -1, 64)) 116 | } 117 | 118 | // NewVarChar builds a VarChar Value. 119 | func NewVarChar(v string) Value { 120 | return MakeTrusted(VarChar, []byte(v)) 121 | } 122 | 123 | // NewVarBinary builds a VarBinary Value. 124 | // The input is a string because it's the most common use case. 125 | func NewVarBinary(v string) Value { 126 | return MakeTrusted(VarBinary, []byte(v)) 127 | } 128 | 129 | // NewIntegral builds an integral type from a string representation. 130 | // The type will be Int64 or Uint64. Int64 will be preferred where possible.
131 | func NewIntegral(val string) (n Value, err error) { 132 | signed, err := strconv.ParseInt(val, 0, 64) 133 | if err == nil { 134 | return MakeTrusted(Int64, strconv.AppendInt(nil, signed, 10)), nil 135 | } 136 | unsigned, err := strconv.ParseUint(val, 0, 64) 137 | if err != nil { 138 | return Value{}, err 139 | } 140 | return MakeTrusted(Uint64, strconv.AppendUint(nil, unsigned, 10)), nil 141 | } 142 | 143 | // InterfaceToValue builds a value from a go type. 144 | // Supported types are nil, int64, uint64, float64, 145 | // string and []byte. 146 | // This function is deprecated. Use the type-specific 147 | // functions instead. 148 | func InterfaceToValue(goval interface{}) (Value, error) { 149 | switch goval := goval.(type) { 150 | case nil: 151 | return NULL, nil 152 | case []byte: 153 | return MakeTrusted(VarBinary, goval), nil 154 | case int64: 155 | return NewInt64(goval), nil 156 | case uint64: 157 | return NewUint64(goval), nil 158 | case float64: 159 | return NewFloat64(goval), nil 160 | case string: 161 | return NewVarChar(goval), nil 162 | default: 163 | return NULL, fmt.Errorf("unexpected type %T: %v", goval, goval) 164 | } 165 | } 166 | 167 | // Type returns the type of Value. 168 | func (v Value) Type() querypb.Type { 169 | return v.typ 170 | } 171 | 172 | // Raw returns the internal representation of the value. For newer types, 173 | // this may not match MySQL's representation. 174 | func (v Value) Raw() []byte { 175 | return v.val 176 | } 177 | 178 | // ToBytes returns the value as MySQL would return it as []byte. 179 | // In contrast, Raw returns the internal representation of the Value, which may not 180 | // match MySQL's representation for newer types. 181 | // If the value is not convertible like in the case of Expression, it returns nil. 182 | func (v Value) ToBytes() []byte { 183 | if v.typ == Expression { 184 | return nil 185 | } 186 | return v.val 187 | } 188 | 189 | // Len returns the length.
190 | func (v Value) Len() int { 191 | return len(v.val) 192 | } 193 | 194 | // ToString returns the value as MySQL would return it as string. 195 | // If the value is not convertible like in the case of Expression, it returns an empty string. 196 | func (v Value) ToString() string { 197 | if v.typ == Expression { 198 | return "" 199 | } 200 | return hack.String(v.val) 201 | } 202 | 203 | // String returns a printable version of the value. 204 | func (v Value) String() string { 205 | if v.typ == Null { 206 | return "NULL" 207 | } 208 | if v.IsQuoted() { 209 | return fmt.Sprintf("%v(%q)", v.typ, v.val) 210 | } 211 | return fmt.Sprintf("%v(%s)", v.typ, v.val) 212 | } 213 | 214 | // EncodeSQL encodes the value into an SQL statement. Can be binary. 215 | func (v Value) EncodeSQL(b BinWriter) { 216 | switch { 217 | case v.typ == Null: 218 | b.Write(nullstr) 219 | case v.IsQuoted(): 220 | encodeBytesSQL(v.val, b) 221 | default: 222 | b.Write(v.val) 223 | } 224 | } 225 | 226 | // EncodeASCII encodes the value using 7-bit clean ascii bytes. 227 | func (v Value) EncodeASCII(b BinWriter) { 228 | switch { 229 | case v.typ == Null: 230 | b.Write(nullstr) 231 | case v.IsQuoted(): 232 | encodeBytesASCII(v.val, b) 233 | default: 234 | b.Write(v.val) 235 | } 236 | } 237 | 238 | // IsNull returns true if Value is null. 239 | func (v Value) IsNull() bool { 240 | return v.typ == Null 241 | } 242 | 243 | // IsIntegral returns true if Value is an integral. 244 | func (v Value) IsIntegral() bool { 245 | return IsIntegral(v.typ) 246 | } 247 | 248 | // IsSigned returns true if Value is a signed integral. 249 | func (v Value) IsSigned() bool { 250 | return IsSigned(v.typ) 251 | } 252 | 253 | // IsUnsigned returns true if Value is an unsigned integral. 254 | func (v Value) IsUnsigned() bool { 255 | return IsUnsigned(v.typ) 256 | } 257 | 258 | // IsFloat returns true if Value is a float.
259 | func (v Value) IsFloat() bool { 260 | return IsFloat(v.typ) 261 | } 262 | 263 | // IsQuoted returns true if Value must be SQL-quoted. 264 | func (v Value) IsQuoted() bool { 265 | return IsQuoted(v.typ) 266 | } 267 | 268 | // IsText returns true if Value is a collatable text. 269 | func (v Value) IsText() bool { 270 | return IsText(v.typ) 271 | } 272 | 273 | // IsBinary returns true if Value is binary. 274 | func (v Value) IsBinary() bool { 275 | return IsBinary(v.typ) 276 | } 277 | 278 | // MarshalJSON should only be used for testing. 279 | // It's not a complete implementation. 280 | func (v Value) MarshalJSON() ([]byte, error) { 281 | switch { 282 | case v.IsQuoted(): 283 | return json.Marshal(v.ToString()) 284 | case v.typ == Null: 285 | return nullstr, nil 286 | } 287 | return v.val, nil 288 | } 289 | 290 | // UnmarshalJSON should only be used for testing. 291 | // It's not a complete implementation. 292 | func (v *Value) UnmarshalJSON(b []byte) error { 293 | if len(b) == 0 { 294 | return fmt.Errorf("error unmarshaling empty bytes") 295 | } 296 | var val interface{} 297 | var err error 298 | switch b[0] { 299 | case '-': 300 | var ival int64 301 | err = json.Unmarshal(b, &ival) 302 | val = ival 303 | case '"': 304 | var bval []byte 305 | err = json.Unmarshal(b, &bval) 306 | val = bval 307 | case 'n': // null 308 | err = json.Unmarshal(b, &val) 309 | default: 310 | var uval uint64 311 | err = json.Unmarshal(b, &uval) 312 | val = uval 313 | } 314 | if err != nil { 315 | return err 316 | } 317 | *v, err = InterfaceToValue(val) 318 | return err 319 | } 320 | 321 | func encodeBytesSQL(val []byte, b BinWriter) { 322 | buf := &bytes2.Buffer{} 323 | buf.WriteByte('\'') 324 | for _, ch := range val { 325 | if encodedChar := SQLEncodeMap[ch]; encodedChar == DontEscape { 326 | buf.WriteByte(ch) 327 | } else { 328 | buf.WriteByte('\\') 329 | buf.WriteByte(encodedChar) 330 | } 331 | } 332 | buf.WriteByte('\'') 333 | b.Write(buf.Bytes()) 334 | } 335 | 336 | func 
encodeBytesASCII(val []byte, b BinWriter) { 337 | buf := &bytes2.Buffer{} 338 | buf.WriteByte('\'') 339 | encoder := base64.NewEncoder(base64.StdEncoding, buf) 340 | encoder.Write(val) 341 | encoder.Close() 342 | buf.WriteByte('\'') 343 | b.Write(buf.Bytes()) 344 | } 345 | 346 | // SQLEncodeMap specifies how to escape binary data with '\'. 347 | // Complies to http://dev.mysql.com/doc/refman/5.1/en/string-syntax.html 348 | var SQLEncodeMap [256]byte 349 | 350 | // SQLDecodeMap is the reverse of SQLEncodeMap 351 | var SQLDecodeMap [256]byte 352 | 353 | var encodeRef = map[byte]byte{ 354 | '\x00': '0', 355 | '\'': '\'', 356 | '"': '"', 357 | '\b': 'b', 358 | '\n': 'n', 359 | '\r': 'r', 360 | '\t': 't', 361 | 26: 'Z', // ctl-Z 362 | '\\': '\\', 363 | } 364 | 365 | func init() { 366 | for i := range SQLEncodeMap { 367 | SQLEncodeMap[i] = DontEscape 368 | SQLDecodeMap[i] = DontEscape 369 | } 370 | for i := range SQLEncodeMap { 371 | if to, ok := encodeRef[byte(i)]; ok { 372 | SQLEncodeMap[byte(i)] = to 373 | SQLDecodeMap[to] = byte(i) 374 | } 375 | } 376 | } 377 | -------------------------------------------------------------------------------- /dependency/sqltypes/value_test.go: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2017 Google Inc. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 
15 | */ 16 | 17 | package sqltypes 18 | 19 | import ( 20 | "bytes" 21 | "reflect" 22 | "strings" 23 | "testing" 24 | 25 | "github.com/xwb1989/sqlparser/dependency/querypb" 26 | ) 27 | 28 | const ( 29 | InvalidNeg = "-9223372036854775809" 30 | MinNeg = "-9223372036854775808" 31 | MinPos = "18446744073709551615" 32 | InvalidPos = "18446744073709551616" 33 | ) 34 | 35 | func TestNewValue(t *testing.T) { 36 | testcases := []struct { 37 | inType querypb.Type 38 | inVal string 39 | outVal Value 40 | outErr string 41 | }{{ 42 | inType: Null, 43 | inVal: "", 44 | outVal: NULL, 45 | }, { 46 | inType: Int8, 47 | inVal: "1", 48 | outVal: TestValue(Int8, "1"), 49 | }, { 50 | inType: Int16, 51 | inVal: "1", 52 | outVal: TestValue(Int16, "1"), 53 | }, { 54 | inType: Int24, 55 | inVal: "1", 56 | outVal: TestValue(Int24, "1"), 57 | }, { 58 | inType: Int32, 59 | inVal: "1", 60 | outVal: TestValue(Int32, "1"), 61 | }, { 62 | inType: Int64, 63 | inVal: "1", 64 | outVal: TestValue(Int64, "1"), 65 | }, { 66 | inType: Uint8, 67 | inVal: "1", 68 | outVal: TestValue(Uint8, "1"), 69 | }, { 70 | inType: Uint16, 71 | inVal: "1", 72 | outVal: TestValue(Uint16, "1"), 73 | }, { 74 | inType: Uint24, 75 | inVal: "1", 76 | outVal: TestValue(Uint24, "1"), 77 | }, { 78 | inType: Uint32, 79 | inVal: "1", 80 | outVal: TestValue(Uint32, "1"), 81 | }, { 82 | inType: Uint64, 83 | inVal: "1", 84 | outVal: TestValue(Uint64, "1"), 85 | }, { 86 | inType: Float32, 87 | inVal: "1.00", 88 | outVal: TestValue(Float32, "1.00"), 89 | }, { 90 | inType: Float64, 91 | inVal: "1.00", 92 | outVal: TestValue(Float64, "1.00"), 93 | }, { 94 | inType: Decimal, 95 | inVal: "1.00", 96 | outVal: TestValue(Decimal, "1.00"), 97 | }, { 98 | inType: Timestamp, 99 | inVal: "2012-02-24 23:19:43", 100 | outVal: TestValue(Timestamp, "2012-02-24 23:19:43"), 101 | }, { 102 | inType: Date, 103 | inVal: "2012-02-24", 104 | outVal: TestValue(Date, "2012-02-24"), 105 | }, { 106 | inType: Time, 107 | inVal: "23:19:43", 108 | outVal: 
TestValue(Time, "23:19:43"), 109 | }, { 110 | inType: Datetime, 111 | inVal: "2012-02-24 23:19:43", 112 | outVal: TestValue(Datetime, "2012-02-24 23:19:43"), 113 | }, { 114 | inType: Year, 115 | inVal: "1", 116 | outVal: TestValue(Year, "1"), 117 | }, { 118 | inType: Text, 119 | inVal: "a", 120 | outVal: TestValue(Text, "a"), 121 | }, { 122 | inType: Blob, 123 | inVal: "a", 124 | outVal: TestValue(Blob, "a"), 125 | }, { 126 | inType: VarChar, 127 | inVal: "a", 128 | outVal: TestValue(VarChar, "a"), 129 | }, { 130 | inType: Binary, 131 | inVal: "a", 132 | outVal: TestValue(Binary, "a"), 133 | }, { 134 | inType: Char, 135 | inVal: "a", 136 | outVal: TestValue(Char, "a"), 137 | }, { 138 | inType: Bit, 139 | inVal: "1", 140 | outVal: TestValue(Bit, "1"), 141 | }, { 142 | inType: Enum, 143 | inVal: "a", 144 | outVal: TestValue(Enum, "a"), 145 | }, { 146 | inType: Set, 147 | inVal: "a", 148 | outVal: TestValue(Set, "a"), 149 | }, { 150 | inType: VarBinary, 151 | inVal: "a", 152 | outVal: TestValue(VarBinary, "a"), 153 | }, { 154 | inType: Int64, 155 | inVal: InvalidNeg, 156 | outErr: "out of range", 157 | }, { 158 | inType: Int64, 159 | inVal: InvalidPos, 160 | outErr: "out of range", 161 | }, { 162 | inType: Uint64, 163 | inVal: "-1", 164 | outErr: "invalid syntax", 165 | }, { 166 | inType: Uint64, 167 | inVal: InvalidPos, 168 | outErr: "out of range", 169 | }, { 170 | inType: Float64, 171 | inVal: "a", 172 | outErr: "invalid syntax", 173 | }, { 174 | inType: Expression, 175 | inVal: "a", 176 | outErr: "invalid type specified for MakeValue: EXPRESSION", 177 | }} 178 | for _, tcase := range testcases { 179 | v, err := NewValue(tcase.inType, []byte(tcase.inVal)) 180 | if tcase.outErr != "" { 181 | if err == nil || !strings.Contains(err.Error(), tcase.outErr) { 182 | t.Errorf("ValueFromBytes(%v, %v) error: %v, must contain %v", tcase.inType, tcase.inVal, err, tcase.outErr) 183 | } 184 | continue 185 | } 186 | if err != nil { 187 | t.Errorf("ValueFromBytes(%v, %v) error: 
%v", tcase.inType, tcase.inVal, err) 188 | continue 189 | } 190 | if !reflect.DeepEqual(v, tcase.outVal) { 191 | t.Errorf("ValueFromBytes(%v, %v) = %v, want %v", tcase.inType, tcase.inVal, v, tcase.outVal) 192 | } 193 | } 194 | } 195 | 196 | // TestNew tests 'New' functions that are not tested 197 | // through other code paths. 198 | func TestNew(t *testing.T) { 199 | got := NewInt32(1) 200 | want := MakeTrusted(Int32, []byte("1")) 201 | if !reflect.DeepEqual(got, want) { 202 | t.Errorf("NewInt32(aa): %v, want %v", got, want) 203 | } 204 | 205 | got = NewVarBinary("aa") 206 | want = MakeTrusted(VarBinary, []byte("aa")) 207 | if !reflect.DeepEqual(got, want) { 208 | t.Errorf("NewVarBinary(aa): %v, want %v", got, want) 209 | } 210 | } 211 | 212 | func TestMakeTrusted(t *testing.T) { 213 | v := MakeTrusted(Null, []byte("abcd")) 214 | if !reflect.DeepEqual(v, NULL) { 215 | t.Errorf("MakeTrusted(Null...) = %v, want null", v) 216 | } 217 | v = MakeTrusted(Int64, []byte("1")) 218 | want := TestValue(Int64, "1") 219 | if !reflect.DeepEqual(v, want) { 220 | t.Errorf("MakeTrusted(Int64, \"1\") = %v, want %v", v, want) 221 | } 222 | } 223 | 224 | func TestIntegralValue(t *testing.T) { 225 | testcases := []struct { 226 | in string 227 | outVal Value 228 | outErr string 229 | }{{ 230 | in: MinNeg, 231 | outVal: TestValue(Int64, MinNeg), 232 | }, { 233 | in: "1", 234 | outVal: TestValue(Int64, "1"), 235 | }, { 236 | in: MinPos, 237 | outVal: TestValue(Uint64, MinPos), 238 | }, { 239 | in: InvalidPos, 240 | outErr: "out of range", 241 | }} 242 | for _, tcase := range testcases { 243 | v, err := NewIntegral(tcase.in) 244 | if tcase.outErr != "" { 245 | if err == nil || !strings.Contains(err.Error(), tcase.outErr) { 246 | t.Errorf("BuildIntegral(%v) error: %v, must contain %v", tcase.in, err, tcase.outErr) 247 | } 248 | continue 249 | } 250 | if err != nil { 251 | t.Errorf("BuildIntegral(%v) error: %v", tcase.in, err) 252 | continue 253 | } 254 | if !reflect.DeepEqual(v, 
tcase.outVal) { 255 | t.Errorf("BuildIntegral(%v) = %v, want %v", tcase.in, v, tcase.outVal) 256 | } 257 | } 258 | } 259 | 260 | func TestInerfaceValue(t *testing.T) { 261 | testcases := []struct { 262 | in interface{} 263 | out Value 264 | }{{ 265 | in: nil, 266 | out: NULL, 267 | }, { 268 | in: []byte("a"), 269 | out: TestValue(VarBinary, "a"), 270 | }, { 271 | in: int64(1), 272 | out: TestValue(Int64, "1"), 273 | }, { 274 | in: uint64(1), 275 | out: TestValue(Uint64, "1"), 276 | }, { 277 | in: float64(1.2), 278 | out: TestValue(Float64, "1.2"), 279 | }, { 280 | in: "a", 281 | out: TestValue(VarChar, "a"), 282 | }} 283 | for _, tcase := range testcases { 284 | v, err := InterfaceToValue(tcase.in) 285 | if err != nil { 286 | t.Errorf("BuildValue(%#v) error: %v", tcase.in, err) 287 | continue 288 | } 289 | if !reflect.DeepEqual(v, tcase.out) { 290 | t.Errorf("BuildValue(%#v) = %v, want %v", tcase.in, v, tcase.out) 291 | } 292 | } 293 | 294 | _, err := InterfaceToValue(make(chan bool)) 295 | want := "unexpected" 296 | if err == nil || !strings.Contains(err.Error(), want) { 297 | t.Errorf("BuildValue(chan): %v, want %v", err, want) 298 | } 299 | } 300 | 301 | func TestAccessors(t *testing.T) { 302 | v := TestValue(Int64, "1") 303 | if v.Type() != Int64 { 304 | t.Errorf("v.Type=%v, want Int64", v.Type()) 305 | } 306 | if !bytes.Equal(v.Raw(), []byte("1")) { 307 | t.Errorf("v.Raw=%s, want 1", v.Raw()) 308 | } 309 | if v.Len() != 1 { 310 | t.Errorf("v.Len=%d, want 1", v.Len()) 311 | } 312 | if v.ToString() != "1" { 313 | t.Errorf("v.String=%s, want 1", v.ToString()) 314 | } 315 | if v.IsNull() { 316 | t.Error("v.IsNull: true, want false") 317 | } 318 | if !v.IsIntegral() { 319 | t.Error("v.IsIntegral: false, want true") 320 | } 321 | if !v.IsSigned() { 322 | t.Error("v.IsSigned: false, want true") 323 | } 324 | if v.IsUnsigned() { 325 | t.Error("v.IsUnsigned: true, want false") 326 | } 327 | if v.IsFloat() { 328 | t.Error("v.IsFloat: true, want false") 329 | } 330 | if 
v.IsQuoted() { 331 | t.Error("v.IsQuoted: true, want false") 332 | } 333 | if v.IsText() { 334 | t.Error("v.IsText: true, want false") 335 | } 336 | if v.IsBinary() { 337 | t.Error("v.IsBinary: true, want false") 338 | } 339 | } 340 | 341 | func TestToBytesAndString(t *testing.T) { 342 | for _, v := range []Value{ 343 | NULL, 344 | TestValue(Int64, "1"), 345 | TestValue(Int64, "12"), 346 | } { 347 | if b := v.ToBytes(); bytes.Compare(b, v.Raw()) != 0 { 348 | t.Errorf("%v.ToBytes: %s, want %s", v, b, v.Raw()) 349 | } 350 | if s := v.ToString(); s != string(v.Raw()) { 351 | t.Errorf("%v.ToString: %s, want %s", v, s, v.Raw()) 352 | } 353 | } 354 | 355 | tv := TestValue(Expression, "aa") 356 | if b := tv.ToBytes(); b != nil { 357 | t.Errorf("%v.ToBytes: %s, want nil", tv, b) 358 | } 359 | if s := tv.ToString(); s != "" { 360 | t.Errorf("%v.ToString: %s, want \"\"", tv, s) 361 | } 362 | } 363 | 364 | func TestEncode(t *testing.T) { 365 | testcases := []struct { 366 | in Value 367 | outSQL string 368 | outASCII string 369 | }{{ 370 | in: NULL, 371 | outSQL: "null", 372 | outASCII: "null", 373 | }, { 374 | in: TestValue(Int64, "1"), 375 | outSQL: "1", 376 | outASCII: "1", 377 | }, { 378 | in: TestValue(VarChar, "foo"), 379 | outSQL: "'foo'", 380 | outASCII: "'Zm9v'", 381 | }, { 382 | in: TestValue(VarChar, "\x00'\"\b\n\r\t\x1A\\"), 383 | outSQL: "'\\0\\'\\\"\\b\\n\\r\\t\\Z\\\\'", 384 | outASCII: "'ACciCAoNCRpc'", 385 | }} 386 | for _, tcase := range testcases { 387 | buf := &bytes.Buffer{} 388 | tcase.in.EncodeSQL(buf) 389 | if tcase.outSQL != buf.String() { 390 | t.Errorf("%v.EncodeSQL = %q, want %q", tcase.in, buf.String(), tcase.outSQL) 391 | } 392 | buf = &bytes.Buffer{} 393 | tcase.in.EncodeASCII(buf) 394 | if tcase.outASCII != buf.String() { 395 | t.Errorf("%v.EncodeASCII = %q, want %q", tcase.in, buf.String(), tcase.outASCII) 396 | } 397 | } 398 | } 399 | 400 | // TestEncodeMap ensures DontEscape is not escaped 401 | func TestEncodeMap(t *testing.T) { 402 | if 
SQLEncodeMap[DontEscape] != DontEscape { 403 | t.Errorf("SQLEncodeMap[DontEscape] = %v, want %v", SQLEncodeMap[DontEscape], DontEscape) 404 | } 405 | if SQLDecodeMap[DontEscape] != DontEscape { 406 | t.Errorf("SQLDecodeMap[DontEscape] = %v, want %v", SQLEncodeMap[DontEscape], DontEscape) 407 | } 408 | } 409 | -------------------------------------------------------------------------------- /encodable.go: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2017 Google Inc. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | */ 16 | 17 | package sqlparser 18 | 19 | import ( 20 | "bytes" 21 | 22 | "github.com/xwb1989/sqlparser/dependency/sqltypes" 23 | ) 24 | 25 | // This file contains types that are 'Encodable'. 26 | 27 | // Encodable defines the interface for types that can 28 | // be custom-encoded into SQL. 29 | type Encodable interface { 30 | EncodeSQL(buf *bytes.Buffer) 31 | } 32 | 33 | // InsertValues is a custom SQL encoder for the values of 34 | // an insert statement. 35 | type InsertValues [][]sqltypes.Value 36 | 37 | // EncodeSQL performs the SQL encoding for InsertValues. 
38 | func (iv InsertValues) EncodeSQL(buf *bytes.Buffer) { 39 | for i, rows := range iv { 40 | if i != 0 { 41 | buf.WriteString(", ") 42 | } 43 | buf.WriteByte('(') 44 | for j, bv := range rows { 45 | if j != 0 { 46 | buf.WriteString(", ") 47 | } 48 | bv.EncodeSQL(buf) 49 | } 50 | buf.WriteByte(')') 51 | } 52 | } 53 | 54 | // TupleEqualityList is for generating equality constraints 55 | // for tables that have composite primary keys. 56 | type TupleEqualityList struct { 57 | Columns []ColIdent 58 | Rows [][]sqltypes.Value 59 | } 60 | 61 | // EncodeSQL generates the where clause constraints for the tuple 62 | // equality. 63 | func (tpl *TupleEqualityList) EncodeSQL(buf *bytes.Buffer) { 64 | if len(tpl.Columns) == 1 { 65 | tpl.encodeAsIn(buf) 66 | return 67 | } 68 | tpl.encodeAsEquality(buf) 69 | } 70 | 71 | func (tpl *TupleEqualityList) encodeAsIn(buf *bytes.Buffer) { 72 | Append(buf, tpl.Columns[0]) 73 | buf.WriteString(" in (") 74 | for i, r := range tpl.Rows { 75 | if i != 0 { 76 | buf.WriteString(", ") 77 | } 78 | r[0].EncodeSQL(buf) 79 | } 80 | buf.WriteByte(')') 81 | } 82 | 83 | func (tpl *TupleEqualityList) encodeAsEquality(buf *bytes.Buffer) { 84 | for i, r := range tpl.Rows { 85 | if i != 0 { 86 | buf.WriteString(" or ") 87 | } 88 | buf.WriteString("(") 89 | for j, c := range tpl.Columns { 90 | if j != 0 { 91 | buf.WriteString(" and ") 92 | } 93 | Append(buf, c) 94 | buf.WriteString(" = ") 95 | r[j].EncodeSQL(buf) 96 | } 97 | buf.WriteByte(')') 98 | } 99 | } 100 | -------------------------------------------------------------------------------- /encodable_test.go: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2017 Google Inc. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 
6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | */ 16 | 17 | package sqlparser 18 | 19 | import ( 20 | "bytes" 21 | "testing" 22 | 23 | "github.com/xwb1989/sqlparser/dependency/sqltypes" 24 | ) 25 | 26 | func TestEncodable(t *testing.T) { 27 | tcases := []struct { 28 | in Encodable 29 | out string 30 | }{{ 31 | in: InsertValues{{ 32 | sqltypes.NewInt64(1), 33 | sqltypes.NewVarBinary("foo('a')"), 34 | }, { 35 | sqltypes.NewInt64(2), 36 | sqltypes.NewVarBinary("bar(`b`)"), 37 | }}, 38 | out: "(1, 'foo(\\'a\\')'), (2, 'bar(`b`)')", 39 | }, { 40 | // Single column. 41 | in: &TupleEqualityList{ 42 | Columns: []ColIdent{NewColIdent("pk")}, 43 | Rows: [][]sqltypes.Value{ 44 | {sqltypes.NewInt64(1)}, 45 | {sqltypes.NewVarBinary("aa")}, 46 | }, 47 | }, 48 | out: "pk in (1, 'aa')", 49 | }, { 50 | // Multiple columns. 
51 | in: &TupleEqualityList{ 52 | Columns: []ColIdent{NewColIdent("pk1"), NewColIdent("pk2")}, 53 | Rows: [][]sqltypes.Value{ 54 | { 55 | sqltypes.NewInt64(1), 56 | sqltypes.NewVarBinary("aa"), 57 | }, 58 | { 59 | sqltypes.NewInt64(2), 60 | sqltypes.NewVarBinary("bb"), 61 | }, 62 | }, 63 | }, 64 | out: "(pk1 = 1 and pk2 = 'aa') or (pk1 = 2 and pk2 = 'bb')", 65 | }} 66 | for _, tcase := range tcases { 67 | buf := new(bytes.Buffer) 68 | tcase.in.EncodeSQL(buf) 69 | if out := buf.String(); out != tcase.out { 70 | t.Errorf("EncodeSQL(%v): %s, want %s", tcase.in, out, tcase.out) 71 | } 72 | } 73 | } 74 | -------------------------------------------------------------------------------- /github_test.go: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2017 Google Inc. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreedto in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | */ 16 | package sqlparser 17 | 18 | // Additional tests to address the GitHub issues for this fork. 
19 | 20 | import ( 21 | "io" 22 | "strings" 23 | "testing" 24 | ) 25 | 26 | func TestParsing(t *testing.T) { 27 | tests := []struct { 28 | id int // Github issue ID 29 | sql string 30 | skip string 31 | }{ 32 | {id: 9, sql: "select 1 as 测试 from dual", skip: "Broken due to ReadByte()"}, 33 | {id: 12, sql: "SELECT * FROM AccessToken LIMIT 10 OFFSET 13"}, 34 | {id: 14, sql: "SELECT DATE_SUB(NOW(), INTERVAL 1 MONTH)"}, 35 | {id: 15, sql: "select STRAIGHT_JOIN t1.* FROM t1 INNER JOIN t2 ON t1.CommonID = t2.CommonID WHERE t1.FilterID = 1"}, 36 | {id: 16, sql: "SELECT a FROM t WHERE FUNC(a) = 1"}, // Doesn't seem broken, need better example 37 | {id: 21, sql: `CREATE TABLE t (UpdateDatetime TIMESTAMP NULL DEFAULT CURRENT_TIMESTAMP)`}, 38 | {id: 21, sql: `CREATE TABLE t (UpdateDatetime TIMESTAMP NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT '更新时间')`}, 39 | {id: 24, sql: `select * from t1 join t2 using(id)`}, 40 | } 41 | 42 | for _, test := range tests { 43 | if test.skip != "" { 44 | continue 45 | } 46 | 47 | if _, err := Parse(test.sql); err != nil { 48 | t.Errorf("https://github.com/xwb1989/sqlparser/issues/%d:\nParse(%q) err = %s, want nil", test.id, test.sql, err) 49 | } 50 | } 51 | } 52 | 53 | // ExampleParse is the first example from the README.md. 54 | func ExampleParse() { 55 | sql := "SELECT * FROM table WHERE a = 'abc'" 56 | stmt, err := Parse(sql) 57 | if err != nil { 58 | // Do something with the err 59 | } 60 | 61 | // Otherwise do something with stmt 62 | switch stmt := stmt.(type) { 63 | case *Select: 64 | _ = stmt 65 | case *Insert: 66 | } 67 | } 68 | 69 | // ExampleParseNext is the second example from the README.md. 70 | func ExampleParseNext() { 71 | r := strings.NewReader("INSERT INTO table1 VALUES (1, 'a'); INSERT INTO table2 VALUES (3, 4);") 72 | 73 | tokens := NewTokenizer(r) 74 | for { 75 | stmt, err := ParseNext(tokens) 76 | if err == io.EOF { 77 | break 78 | } 79 | // Do something with stmt or err. 
80 | _ = stmt 81 | } 82 | } 83 | -------------------------------------------------------------------------------- /impossible_query.go: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2017 Google Inc. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreedto in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | */ 16 | 17 | package sqlparser 18 | 19 | // FormatImpossibleQuery creates an impossible query in a TrackedBuffer. 20 | // An impossible query is a modified version of a query where all selects have where clauses that are 21 | // impossible for mysql to resolve. 
This is used in the vtgate and vttablet: 22 | // 23 | // - In the vtgate it's used for joins: if the first query returns no result, then vtgate uses the impossible 24 | // query just to fetch field info from vttablet 25 | // - In the vttablet, it's just an optimization: the field info is fetched once form MySQL, cached and reused 26 | // for subsequent queries 27 | func FormatImpossibleQuery(buf *TrackedBuffer, node SQLNode) { 28 | switch node := node.(type) { 29 | case *Select: 30 | buf.Myprintf("select %v from %v where 1 != 1", node.SelectExprs, node.From) 31 | if node.GroupBy != nil { 32 | node.GroupBy.Format(buf) 33 | } 34 | case *Union: 35 | buf.Myprintf("%v %s %v", node.Left, node.Type, node.Right) 36 | default: 37 | node.Format(buf) 38 | } 39 | } 40 | -------------------------------------------------------------------------------- /normalizer.go: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2017 Google Inc. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | */ 16 | 17 | package sqlparser 18 | 19 | import ( 20 | "fmt" 21 | 22 | "github.com/xwb1989/sqlparser/dependency/sqltypes" 23 | 24 | "github.com/xwb1989/sqlparser/dependency/querypb" 25 | ) 26 | 27 | // Normalize changes the statement to use bind values, and 28 | // updates the bind vars to those values. The supplied prefix 29 | // is used to generate the bind var names. 
The function ensures 30 | // that there are no collisions with existing bind vars. 31 | // Within Select constructs, bind vars are deduped. This allows 32 | // us to identify vindex equality. Otherwise, every value is 33 | // treated as distinct. 34 | func Normalize(stmt Statement, bindVars map[string]*querypb.BindVariable, prefix string) { 35 | nz := newNormalizer(stmt, bindVars, prefix) 36 | _ = Walk(nz.WalkStatement, stmt) 37 | } 38 | 39 | type normalizer struct { 40 | stmt Statement 41 | bindVars map[string]*querypb.BindVariable 42 | prefix string 43 | reserved map[string]struct{} 44 | counter int 45 | vals map[string]string 46 | } 47 | 48 | func newNormalizer(stmt Statement, bindVars map[string]*querypb.BindVariable, prefix string) *normalizer { 49 | return &normalizer{ 50 | stmt: stmt, 51 | bindVars: bindVars, 52 | prefix: prefix, 53 | reserved: GetBindvars(stmt), 54 | counter: 1, 55 | vals: make(map[string]string), 56 | } 57 | } 58 | 59 | // WalkStatement is the top level walk function. 60 | // If it encounters a Select, it switches to a mode 61 | // where variables are deduped. 62 | func (nz *normalizer) WalkStatement(node SQLNode) (bool, error) { 63 | switch node := node.(type) { 64 | case *Select: 65 | _ = Walk(nz.WalkSelect, node) 66 | // Don't continue 67 | return false, nil 68 | case *SQLVal: 69 | nz.convertSQLVal(node) 70 | case *ComparisonExpr: 71 | nz.convertComparison(node) 72 | } 73 | return true, nil 74 | } 75 | 76 | // WalkSelect normalizes the AST in Select mode. 77 | func (nz *normalizer) WalkSelect(node SQLNode) (bool, error) { 78 | switch node := node.(type) { 79 | case *SQLVal: 80 | nz.convertSQLValDedup(node) 81 | case *ComparisonExpr: 82 | nz.convertComparison(node) 83 | } 84 | return true, nil 85 | } 86 | 87 | func (nz *normalizer) convertSQLValDedup(node *SQLVal) { 88 | // If value is too long, don't dedup. 89 | // Such values are most likely not for vindexes. 
90 | // We save a lot of CPU because we avoid building 91 | // the key for them. 92 | if len(node.Val) > 256 { 93 | nz.convertSQLVal(node) 94 | return 95 | } 96 | 97 | // Make the bindvar 98 | bval := nz.sqlToBindvar(node) 99 | if bval == nil { 100 | return 101 | } 102 | 103 | // Check if there's a bindvar for that value already. 104 | var key string 105 | if bval.Type == sqltypes.VarBinary { 106 | // Prefixing strings with "'" ensures that a string 107 | // and number that have the same representation don't 108 | // collide. 109 | key = "'" + string(node.Val) 110 | } else { 111 | key = string(node.Val) 112 | } 113 | bvname, ok := nz.vals[key] 114 | if !ok { 115 | // If there's no such bindvar, make a new one. 116 | bvname = nz.newName() 117 | nz.vals[key] = bvname 118 | nz.bindVars[bvname] = bval 119 | } 120 | 121 | // Modify the AST node to a bindvar. 122 | node.Type = ValArg 123 | node.Val = append([]byte(":"), bvname...) 124 | } 125 | 126 | // convertSQLVal converts an SQLVal without the dedup. 127 | func (nz *normalizer) convertSQLVal(node *SQLVal) { 128 | bval := nz.sqlToBindvar(node) 129 | if bval == nil { 130 | return 131 | } 132 | 133 | bvname := nz.newName() 134 | nz.bindVars[bvname] = bval 135 | 136 | node.Type = ValArg 137 | node.Val = append([]byte(":"), bvname...) 138 | } 139 | 140 | // convertComparison attempts to convert IN clauses to 141 | // use the list bind var construct. If it fails, it returns 142 | // with no change made. The walk function will then continue 143 | // and iterate on converting each individual value into separate 144 | // bind vars. 145 | func (nz *normalizer) convertComparison(node *ComparisonExpr) { 146 | if node.Operator != InStr && node.Operator != NotInStr { 147 | return 148 | } 149 | tupleVals, ok := node.Right.(ValTuple) 150 | if !ok { 151 | return 152 | } 153 | // The RHS is a tuple of values. 154 | // Make a list bindvar. 
155 | bvals := &querypb.BindVariable{ 156 | Type: querypb.Type_TUPLE, 157 | } 158 | for _, val := range tupleVals { 159 | bval := nz.sqlToBindvar(val) 160 | if bval == nil { 161 | return 162 | } 163 | bvals.Values = append(bvals.Values, &querypb.Value{ 164 | Type: bval.Type, 165 | Value: bval.Value, 166 | }) 167 | } 168 | bvname := nz.newName() 169 | nz.bindVars[bvname] = bvals 170 | // Modify RHS to be a list bindvar. 171 | node.Right = ListArg(append([]byte("::"), bvname...)) 172 | } 173 | 174 | func (nz *normalizer) sqlToBindvar(node SQLNode) *querypb.BindVariable { 175 | if node, ok := node.(*SQLVal); ok { 176 | var v sqltypes.Value 177 | var err error 178 | switch node.Type { 179 | case StrVal: 180 | v, err = sqltypes.NewValue(sqltypes.VarBinary, node.Val) 181 | case IntVal: 182 | v, err = sqltypes.NewValue(sqltypes.Int64, node.Val) 183 | case FloatVal: 184 | v, err = sqltypes.NewValue(sqltypes.Float64, node.Val) 185 | default: 186 | return nil 187 | } 188 | if err != nil { 189 | return nil 190 | } 191 | return sqltypes.ValueBindVariable(v) 192 | } 193 | return nil 194 | } 195 | 196 | func (nz *normalizer) newName() string { 197 | for { 198 | newName := fmt.Sprintf("%s%d", nz.prefix, nz.counter) 199 | if _, ok := nz.reserved[newName]; !ok { 200 | nz.reserved[newName] = struct{}{} 201 | return newName 202 | } 203 | nz.counter++ 204 | } 205 | } 206 | 207 | // GetBindvars returns a map of the bind vars referenced in the statement. 208 | // TODO(sougou); This function gets called again from vtgate/planbuilder. 209 | // Ideally, this should be done only once. 
210 | func GetBindvars(stmt Statement) map[string]struct{} { 211 | bindvars := make(map[string]struct{}) 212 | _ = Walk(func(node SQLNode) (kontinue bool, err error) { 213 | switch node := node.(type) { 214 | case *SQLVal: 215 | if node.Type == ValArg { 216 | bindvars[string(node.Val[1:])] = struct{}{} 217 | } 218 | case ListArg: 219 | bindvars[string(node[2:])] = struct{}{} 220 | } 221 | return true, nil 222 | }, stmt) 223 | return bindvars 224 | } 225 | -------------------------------------------------------------------------------- /normalizer_test.go: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2017 Google Inc. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 
15 | */ 16 | 17 | package sqlparser 18 | 19 | import ( 20 | "fmt" 21 | "reflect" 22 | "testing" 23 | 24 | "github.com/xwb1989/sqlparser/dependency/querypb" 25 | "github.com/xwb1989/sqlparser/dependency/sqltypes" 26 | ) 27 | 28 | func TestNormalize(t *testing.T) { 29 | prefix := "bv" 30 | testcases := []struct { 31 | in string 32 | outstmt string 33 | outbv map[string]*querypb.BindVariable 34 | }{{ 35 | // str val 36 | in: "select * from t where v1 = 'aa'", 37 | outstmt: "select * from t where v1 = :bv1", 38 | outbv: map[string]*querypb.BindVariable{ 39 | "bv1": sqltypes.BytesBindVariable([]byte("aa")), 40 | }, 41 | }, { 42 | // int val 43 | in: "select * from t where v1 = 1", 44 | outstmt: "select * from t where v1 = :bv1", 45 | outbv: map[string]*querypb.BindVariable{ 46 | "bv1": sqltypes.Int64BindVariable(1), 47 | }, 48 | }, { 49 | // float val 50 | in: "select * from t where v1 = 1.2", 51 | outstmt: "select * from t where v1 = :bv1", 52 | outbv: map[string]*querypb.BindVariable{ 53 | "bv1": sqltypes.Float64BindVariable(1.2), 54 | }, 55 | }, { 56 | // multiple vals 57 | in: "select * from t where v1 = 1.2 and v2 = 2", 58 | outstmt: "select * from t where v1 = :bv1 and v2 = :bv2", 59 | outbv: map[string]*querypb.BindVariable{ 60 | "bv1": sqltypes.Float64BindVariable(1.2), 61 | "bv2": sqltypes.Int64BindVariable(2), 62 | }, 63 | }, { 64 | // bv collision 65 | in: "select * from t where v1 = :bv1 and v2 = 1", 66 | outstmt: "select * from t where v1 = :bv1 and v2 = :bv2", 67 | outbv: map[string]*querypb.BindVariable{ 68 | "bv2": sqltypes.Int64BindVariable(1), 69 | }, 70 | }, { 71 | // val reuse 72 | in: "select * from t where v1 = 1 and v2 = 1", 73 | outstmt: "select * from t where v1 = :bv1 and v2 = :bv1", 74 | outbv: map[string]*querypb.BindVariable{ 75 | "bv1": sqltypes.Int64BindVariable(1), 76 | }, 77 | }, { 78 | // ints and strings are different 79 | in: "select * from t where v1 = 1 and v2 = '1'", 80 | outstmt: "select * from t where v1 = :bv1 and v2 = :bv2", 81 
| outbv: map[string]*querypb.BindVariable{ 82 | "bv1": sqltypes.Int64BindVariable(1), 83 | "bv2": sqltypes.BytesBindVariable([]byte("1")), 84 | }, 85 | }, { 86 | // val should not be reused for non-select statements 87 | in: "insert into a values(1, 1)", 88 | outstmt: "insert into a values (:bv1, :bv2)", 89 | outbv: map[string]*querypb.BindVariable{ 90 | "bv1": sqltypes.Int64BindVariable(1), 91 | "bv2": sqltypes.Int64BindVariable(1), 92 | }, 93 | }, { 94 | // val should be reused only in subqueries of DMLs 95 | in: "update a set v1=(select 5 from t), v2=5, v3=(select 5 from t), v4=5", 96 | outstmt: "update a set v1 = (select :bv1 from t), v2 = :bv2, v3 = (select :bv1 from t), v4 = :bv3", 97 | outbv: map[string]*querypb.BindVariable{ 98 | "bv1": sqltypes.Int64BindVariable(5), 99 | "bv2": sqltypes.Int64BindVariable(5), 100 | "bv3": sqltypes.Int64BindVariable(5), 101 | }, 102 | }, { 103 | // list vars should work for DMLs also 104 | in: "update a set v1=5 where v2 in (1, 4, 5)", 105 | outstmt: "update a set v1 = :bv1 where v2 in ::bv2", 106 | outbv: map[string]*querypb.BindVariable{ 107 | "bv1": sqltypes.Int64BindVariable(5), 108 | "bv2": sqltypes.TestBindVariable([]interface{}{1, 4, 5}), 109 | }, 110 | }, { 111 | // Hex value does not convert 112 | in: "select * from t where v1 = 0x1234", 113 | outstmt: "select * from t where v1 = 0x1234", 114 | outbv: map[string]*querypb.BindVariable{}, 115 | }, { 116 | // Hex value does not convert for DMLs 117 | in: "update a set v1 = 0x1234", 118 | outstmt: "update a set v1 = 0x1234", 119 | outbv: map[string]*querypb.BindVariable{}, 120 | }, { 121 | // Values up to len 256 will reuse. 122 | in: fmt.Sprintf("select * from t where v1 = '%256s' and v2 = '%256s'", "a", "a"), 123 | outstmt: "select * from t where v1 = :bv1 and v2 = :bv1", 124 | outbv: map[string]*querypb.BindVariable{ 125 | "bv1": sqltypes.BytesBindVariable([]byte(fmt.Sprintf("%256s", "a"))), 126 | }, 127 | }, { 128 | // Values greater than len 256 will not reuse. 
129 | in: fmt.Sprintf("select * from t where v1 = '%257s' and v2 = '%257s'", "b", "b"), 130 | outstmt: "select * from t where v1 = :bv1 and v2 = :bv2", 131 | outbv: map[string]*querypb.BindVariable{ 132 | "bv1": sqltypes.BytesBindVariable([]byte(fmt.Sprintf("%257s", "b"))), 133 | "bv2": sqltypes.BytesBindVariable([]byte(fmt.Sprintf("%257s", "b"))), 134 | }, 135 | }, { 136 | // bad int 137 | in: "select * from t where v1 = 12345678901234567890", 138 | outstmt: "select * from t where v1 = 12345678901234567890", 139 | outbv: map[string]*querypb.BindVariable{}, 140 | }, { 141 | // comparison with no vals 142 | in: "select * from t where v1 = v2", 143 | outstmt: "select * from t where v1 = v2", 144 | outbv: map[string]*querypb.BindVariable{}, 145 | }, { 146 | // IN clause with existing bv 147 | in: "select * from t where v1 in ::list", 148 | outstmt: "select * from t where v1 in ::list", 149 | outbv: map[string]*querypb.BindVariable{}, 150 | }, { 151 | // IN clause with non-val values 152 | in: "select * from t where v1 in (1, a)", 153 | outstmt: "select * from t where v1 in (:bv1, a)", 154 | outbv: map[string]*querypb.BindVariable{ 155 | "bv1": sqltypes.Int64BindVariable(1), 156 | }, 157 | }, { 158 | // IN clause with vals 159 | in: "select * from t where v1 in (1, '2')", 160 | outstmt: "select * from t where v1 in ::bv1", 161 | outbv: map[string]*querypb.BindVariable{ 162 | "bv1": sqltypes.TestBindVariable([]interface{}{1, []byte("2")}), 163 | }, 164 | }, { 165 | // NOT IN clause 166 | in: "select * from t where v1 not in (1, '2')", 167 | outstmt: "select * from t where v1 not in ::bv1", 168 | outbv: map[string]*querypb.BindVariable{ 169 | "bv1": sqltypes.TestBindVariable([]interface{}{1, []byte("2")}), 170 | }, 171 | }} 172 | for _, tc := range testcases { 173 | stmt, err := Parse(tc.in) 174 | if err != nil { 175 | t.Error(err) 176 | continue 177 | } 178 | bv := make(map[string]*querypb.BindVariable) 179 | Normalize(stmt, bv, prefix) 180 | outstmt := String(stmt) 181 
| if outstmt != tc.outstmt { 182 | t.Errorf("Query:\n%s:\n%s, want\n%s", tc.in, outstmt, tc.outstmt) 183 | } 184 | if !reflect.DeepEqual(tc.outbv, bv) { 185 | t.Errorf("Query:\n%s:\n%v, want\n%v", tc.in, bv, tc.outbv) 186 | } 187 | } 188 | } 189 | 190 | func TestGetBindVars(t *testing.T) { 191 | stmt, err := Parse("select * from t where :v1 = :v2 and :v2 = :v3 and :v4 in ::v5") 192 | if err != nil { 193 | t.Fatal(err) 194 | } 195 | got := GetBindvars(stmt) 196 | want := map[string]struct{}{ 197 | "v1": {}, 198 | "v2": {}, 199 | "v3": {}, 200 | "v4": {}, 201 | "v5": {}, 202 | } 203 | if !reflect.DeepEqual(got, want) { 204 | t.Errorf("GetBindVars: %v, want: %v", got, want) 205 | } 206 | } 207 | -------------------------------------------------------------------------------- /parse_next_test.go: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2017 Google Inc. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | */ 16 | 17 | package sqlparser 18 | 19 | import ( 20 | "bytes" 21 | "io" 22 | "strings" 23 | "testing" 24 | ) 25 | 26 | // TestParseNextValid concatenates all the valid SQL test cases and check it can read 27 | // them as one long string. 
28 | func TestParseNextValid(t *testing.T) { 29 | var sql bytes.Buffer 30 | for _, tcase := range validSQL { 31 | sql.WriteString(strings.TrimSuffix(tcase.input, ";")) 32 | sql.WriteRune(';') 33 | } 34 | 35 | tokens := NewTokenizer(&sql) 36 | for i, tcase := range validSQL { 37 | input := tcase.input + ";" 38 | want := tcase.output 39 | if want == "" { 40 | want = tcase.input 41 | } 42 | 43 | tree, err := ParseNext(tokens) 44 | if err != nil { 45 | t.Fatalf("[%d] ParseNext(%q) err: %q, want nil", i, input, err) 46 | continue 47 | } 48 | 49 | if got := String(tree); got != want { 50 | t.Fatalf("[%d] ParseNext(%q) = %q, want %q", i, input, got, want) 51 | } 52 | } 53 | 54 | // Read once more and it should be EOF. 55 | if tree, err := ParseNext(tokens); err != io.EOF { 56 | t.Errorf("ParseNext(tokens) = (%q, %v) want io.EOF", String(tree), err) 57 | } 58 | } 59 | 60 | // TestParseNextErrors tests all the error cases, and ensures a valid 61 | // SQL statement can be passed afterwards. 62 | func TestParseNextErrors(t *testing.T) { 63 | for _, tcase := range invalidSQL { 64 | if tcase.excludeMulti { 65 | // Skip tests which leave unclosed strings, or comments. 66 | continue 67 | } 68 | 69 | sql := tcase.input + "; select 1 from t" 70 | tokens := NewStringTokenizer(sql) 71 | 72 | // The first statement should be an error 73 | _, err := ParseNext(tokens) 74 | if err == nil || err.Error() != tcase.output { 75 | t.Fatalf("[0] ParseNext(%q) err: %q, want %q", sql, err, tcase.output) 76 | continue 77 | } 78 | 79 | // The second should be valid 80 | tree, err := ParseNext(tokens) 81 | if err != nil { 82 | t.Fatalf("[1] ParseNext(%q) err: %q, want nil", sql, err) 83 | continue 84 | } 85 | 86 | want := "select 1 from t" 87 | if got := String(tree); got != want { 88 | t.Fatalf("[1] ParseNext(%q) = %q, want %q", sql, got, want) 89 | } 90 | 91 | // Read once more and it should be EOF. 
92 | if tree, err := ParseNext(tokens); err != io.EOF { 93 | t.Errorf("ParseNext(tokens) = (%q, %v) want io.EOF", String(tree), err) 94 | } 95 | } 96 | } 97 | 98 | // TestParseNextEdgeCases tests various ParseNext edge cases. 99 | func TestParseNextEdgeCases(t *testing.T) { 100 | tests := []struct { 101 | name string 102 | input string 103 | want []string 104 | }{{ 105 | name: "Trailing ;", 106 | input: "select 1 from a; update a set b = 2;", 107 | want: []string{"select 1 from a", "update a set b = 2"}, 108 | }, { 109 | name: "No trailing ;", 110 | input: "select 1 from a; update a set b = 2", 111 | want: []string{"select 1 from a", "update a set b = 2"}, 112 | }, { 113 | name: "Trailing whitespace", 114 | input: "select 1 from a; update a set b = 2 ", 115 | want: []string{"select 1 from a", "update a set b = 2"}, 116 | }, { 117 | name: "Trailing whitespace and ;", 118 | input: "select 1 from a; update a set b = 2 ; ", 119 | want: []string{"select 1 from a", "update a set b = 2"}, 120 | }, { 121 | name: "Handle ForceEOF statements", 122 | input: "set character set utf8; select 1 from a", 123 | want: []string{"set charset 'utf8'", "select 1 from a"}, 124 | }, { 125 | name: "Semicolin inside a string", 126 | input: "set character set ';'; select 1 from a", 127 | want: []string{"set charset ';'", "select 1 from a"}, 128 | }, { 129 | name: "Partial DDL", 130 | input: "create table a; select 1 from a", 131 | want: []string{"create table a", "select 1 from a"}, 132 | }, { 133 | name: "Partial DDL", 134 | input: "create table a ignore me this is garbage; select 1 from a", 135 | want: []string{"create table a", "select 1 from a"}, 136 | }} 137 | 138 | for _, test := range tests { 139 | tokens := NewStringTokenizer(test.input) 140 | 141 | for i, want := range test.want { 142 | tree, err := ParseNext(tokens) 143 | if err != nil { 144 | t.Fatalf("[%d] ParseNext(%q) err = %q, want nil", i, test.input, err) 145 | continue 146 | } 147 | 148 | if got := String(tree); got != want 
{ 149 | t.Fatalf("[%d] ParseNext(%q) = %q, want %q", i, test.input, got, want) 150 | } 151 | } 152 | 153 | // Read once more and it should be EOF. 154 | if tree, err := ParseNext(tokens); err != io.EOF { 155 | t.Errorf("ParseNext(%q) = (%q, %v) want io.EOF", test.input, String(tree), err) 156 | } 157 | 158 | // And again, once more should be EOF. 159 | if tree, err := ParseNext(tokens); err != io.EOF { 160 | t.Errorf("ParseNext(%q) = (%q, %v) want io.EOF", test.input, String(tree), err) 161 | } 162 | } 163 | } 164 | -------------------------------------------------------------------------------- /parsed_query.go: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2017 Google Inc. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | */ 16 | 17 | package sqlparser 18 | 19 | import ( 20 | "bytes" 21 | "fmt" 22 | 23 | "github.com/xwb1989/sqlparser/dependency/querypb" 24 | "github.com/xwb1989/sqlparser/dependency/sqltypes" 25 | ) 26 | 27 | // ParsedQuery represents a parsed query where 28 | // bind locations are precompued for fast substitutions. 29 | type ParsedQuery struct { 30 | Query string 31 | bindLocations []bindLocation 32 | } 33 | 34 | type bindLocation struct { 35 | offset, length int 36 | } 37 | 38 | // NewParsedQuery returns a ParsedQuery of the ast. 
func NewParsedQuery(node SQLNode) *ParsedQuery {
	buf := NewTrackedBuffer(nil)
	buf.Myprintf("%v", node)
	return buf.ParsedQuery()
}

// GenerateQuery generates a query by substituting the specified
// bindVariables. The extras parameter specifies special parameters
// that can perform custom encoding.
//
// For each bind location, extras is consulted first, keyed by the bind
// name with its first ':' removed. Otherwise the value is resolved from
// bindVariables via FetchBindVar, and any error it reports (missing
// variable, wrong list/non-list shape) is returned unchanged.
func (pq *ParsedQuery) GenerateQuery(bindVariables map[string]*querypb.BindVariable, extras map[string]Encodable) ([]byte, error) {
	if len(pq.bindLocations) == 0 {
		return []byte(pq.Query), nil
	}
	buf := bytes.NewBuffer(make([]byte, 0, len(pq.Query)))
	current := 0
	for _, loc := range pq.bindLocations {
		// Copy the literal SQL between the previous bind var and this one.
		buf.WriteString(pq.Query[current:loc.offset])
		name := pq.Query[loc.offset : loc.offset+loc.length]
		if encodable, ok := extras[name[1:]]; ok {
			encodable.EncodeSQL(buf)
		} else {
			supplied, _, err := FetchBindVar(name, bindVariables)
			if err != nil {
				return nil, err
			}
			EncodeValue(buf, supplied)
		}
		current = loc.offset + loc.length
	}
	buf.WriteString(pq.Query[current:])
	return buf.Bytes(), nil
}

// EncodeValue encodes one bind variable value into the query.
// A TUPLE is written as a parenthesized, comma-separated list of its
// elements; any other type is written as a single SQL value.
func EncodeValue(buf *bytes.Buffer, value *querypb.BindVariable) {
	if value.Type != querypb.Type_TUPLE {
		// Since we already check for TUPLE, we don't expect an error.
		v, _ := sqltypes.BindVariableToValue(value)
		v.EncodeSQL(buf)
		return
	}

	// It's a TUPLE.
	buf.WriteByte('(')
	for i, bv := range value.Values {
		if i != 0 {
			buf.WriteString(", ")
		}
		sqltypes.ProtoToValue(bv).EncodeSQL(buf)
	}
	buf.WriteByte(')')
}

// FetchBindVar resolves the bind variable by fetching it from bindVariables.
//
// name is the bind reference as written in the query. Its first byte
// (normally ':') is dropped. If the next byte is also ':', the reference is
// a list ("::name") and isList is true. An error is returned if name is too
// short to hold a variable name, if the variable is missing, if a list
// reference is bound to a non-TUPLE or empty TUPLE, or if a scalar
// reference is bound to a TUPLE.
func FetchBindVar(name string, bindVariables map[string]*querypb.BindVariable) (val *querypb.BindVariable, isList bool, err error) {
	// Without this guard, "" panics on name[1:] and a 1-byte name panics on
	// name[0] below.
	if len(name) < 2 {
		return nil, false, fmt.Errorf("invalid bind var name %q", name)
	}
	name = name[1:]
	if name[0] == ':' {
		name = name[1:]
		isList = true
	}
	supplied, ok := bindVariables[name]
	if !ok {
		return nil, false, fmt.Errorf("missing bind var %s", name)
	}

	if isList {
		if supplied.Type != querypb.Type_TUPLE {
			return nil, false, fmt.Errorf("unexpected list arg type (%v) for key %s", supplied.Type, name)
		}
		if len(supplied.Values) == 0 {
			return nil, false, fmt.Errorf("empty list supplied for %s", name)
		}
		return supplied, true, nil
	}

	if supplied.Type == querypb.Type_TUPLE {
		return nil, false, fmt.Errorf("unexpected arg type (TUPLE) for non-list key %s", name)
	}

	return supplied, false, nil
}

--------------------------------------------------------------------------------
/parsed_query_test.go:
--------------------------------------------------------------------------------
/*
Copyright 2017 Google Inc.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
15 | */ 16 | 17 | package sqlparser 18 | 19 | import ( 20 | "reflect" 21 | "testing" 22 | 23 | "github.com/xwb1989/sqlparser/dependency/sqltypes" 24 | 25 | "github.com/xwb1989/sqlparser/dependency/querypb" 26 | ) 27 | 28 | func TestNewParsedQuery(t *testing.T) { 29 | stmt, err := Parse("select * from a where id =:id") 30 | if err != nil { 31 | t.Error(err) 32 | return 33 | } 34 | pq := NewParsedQuery(stmt) 35 | want := &ParsedQuery{ 36 | Query: "select * from a where id = :id", 37 | bindLocations: []bindLocation{{offset: 27, length: 3}}, 38 | } 39 | if !reflect.DeepEqual(pq, want) { 40 | t.Errorf("GenerateParsedQuery: %+v, want %+v", pq, want) 41 | } 42 | } 43 | 44 | func TestGenerateQuery(t *testing.T) { 45 | tcases := []struct { 46 | desc string 47 | query string 48 | bindVars map[string]*querypb.BindVariable 49 | extras map[string]Encodable 50 | output string 51 | }{ 52 | { 53 | desc: "no substitutions", 54 | query: "select * from a where id = 2", 55 | bindVars: map[string]*querypb.BindVariable{ 56 | "id": sqltypes.Int64BindVariable(1), 57 | }, 58 | output: "select * from a where id = 2", 59 | }, { 60 | desc: "missing bind var", 61 | query: "select * from a where id1 = :id1 and id2 = :id2", 62 | bindVars: map[string]*querypb.BindVariable{ 63 | "id1": sqltypes.Int64BindVariable(1), 64 | }, 65 | output: "missing bind var id2", 66 | }, { 67 | desc: "simple bindvar substitution", 68 | query: "select * from a where id1 = :id1 and id2 = :id2", 69 | bindVars: map[string]*querypb.BindVariable{ 70 | "id1": sqltypes.Int64BindVariable(1), 71 | "id2": sqltypes.NullBindVariable, 72 | }, 73 | output: "select * from a where id1 = 1 and id2 = null", 74 | }, { 75 | desc: "tuple *querypb.BindVariable", 76 | query: "select * from a where id in ::vals", 77 | bindVars: map[string]*querypb.BindVariable{ 78 | "vals": sqltypes.TestBindVariable([]interface{}{1, "aa"}), 79 | }, 80 | output: "select * from a where id in (1, 'aa')", 81 | }, { 82 | desc: "list bind vars 0 arguments", 83 | 
query: "select * from a where id in ::vals", 84 | bindVars: map[string]*querypb.BindVariable{ 85 | "vals": sqltypes.TestBindVariable([]interface{}{}), 86 | }, 87 | output: "empty list supplied for vals", 88 | }, { 89 | desc: "non-list bind var supplied", 90 | query: "select * from a where id in ::vals", 91 | bindVars: map[string]*querypb.BindVariable{ 92 | "vals": sqltypes.Int64BindVariable(1), 93 | }, 94 | output: "unexpected list arg type (INT64) for key vals", 95 | }, { 96 | desc: "list bind var for non-list", 97 | query: "select * from a where id = :vals", 98 | bindVars: map[string]*querypb.BindVariable{ 99 | "vals": sqltypes.TestBindVariable([]interface{}{1}), 100 | }, 101 | output: "unexpected arg type (TUPLE) for non-list key vals", 102 | }, { 103 | desc: "single column tuple equality", 104 | query: "select * from a where b = :equality", 105 | extras: map[string]Encodable{ 106 | "equality": &TupleEqualityList{ 107 | Columns: []ColIdent{NewColIdent("pk")}, 108 | Rows: [][]sqltypes.Value{ 109 | {sqltypes.NewInt64(1)}, 110 | {sqltypes.NewVarBinary("aa")}, 111 | }, 112 | }, 113 | }, 114 | output: "select * from a where b = pk in (1, 'aa')", 115 | }, { 116 | desc: "multi column tuple equality", 117 | query: "select * from a where b = :equality", 118 | extras: map[string]Encodable{ 119 | "equality": &TupleEqualityList{ 120 | Columns: []ColIdent{NewColIdent("pk1"), NewColIdent("pk2")}, 121 | Rows: [][]sqltypes.Value{ 122 | { 123 | sqltypes.NewInt64(1), 124 | sqltypes.NewVarBinary("aa"), 125 | }, 126 | { 127 | sqltypes.NewInt64(2), 128 | sqltypes.NewVarBinary("bb"), 129 | }, 130 | }, 131 | }, 132 | }, 133 | output: "select * from a where b = (pk1 = 1 and pk2 = 'aa') or (pk1 = 2 and pk2 = 'bb')", 134 | }, 135 | } 136 | 137 | for _, tcase := range tcases { 138 | tree, err := Parse(tcase.query) 139 | if err != nil { 140 | t.Errorf("parse failed for %s: %v", tcase.desc, err) 141 | continue 142 | } 143 | buf := NewTrackedBuffer(nil) 144 | buf.Myprintf("%v", tree) 145 | 
pq := buf.ParsedQuery() 146 | bytes, err := pq.GenerateQuery(tcase.bindVars, tcase.extras) 147 | var got string 148 | if err != nil { 149 | got = err.Error() 150 | } else { 151 | got = string(bytes) 152 | } 153 | if got != tcase.output { 154 | t.Errorf("for test case: %s, got: '%s', want '%s'", tcase.desc, got, tcase.output) 155 | } 156 | } 157 | } 158 | -------------------------------------------------------------------------------- /patches/bytes2.patch: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xwb1989/sqlparser/120387863bf27d04bc07db8015110a6e96d0146c/patches/bytes2.patch -------------------------------------------------------------------------------- /patches/sqlparser.patch: -------------------------------------------------------------------------------- 1 | Only in /Users/bramp/go/src/github.com/xwb1989/sqlparser//: .git 2 | Only in /Users/bramp/go/src/github.com/xwb1989/sqlparser//: .gitignore 3 | Only in /Users/bramp/go/src/github.com/xwb1989/sqlparser//: .travis.yml 4 | Only in /Users/bramp/go/src/github.com/xwb1989/sqlparser//: CONTRIBUTORS.md 5 | Only in /Users/bramp/go/src/github.com/xwb1989/sqlparser//: LICENSE.md 6 | Only in /Users/bramp/go/src/github.com/xwb1989/sqlparser//: README.md 7 | diff -u /Users/bramp/go/src/vitess.io/vitess/go//vt/sqlparser/analyzer.go /Users/bramp/go/src/github.com/xwb1989/sqlparser//analyzer.go 8 | --- /Users/bramp/go/src/vitess.io/vitess/go//vt/sqlparser/analyzer.go 2018-06-05 08:45:47.000000000 -0700 9 | +++ /Users/bramp/go/src/github.com/xwb1989/sqlparser//analyzer.go 2018-06-06 07:45:09.000000000 -0700 10 | @@ -19,15 +19,13 @@ 11 | // analyzer.go contains utility analysis functions. 
12 | 13 | import ( 14 | + "errors" 15 | "fmt" 16 | "strconv" 17 | "strings" 18 | "unicode" 19 | 20 | - "vitess.io/vitess/go/sqltypes" 21 | - "vitess.io/vitess/go/vt/vterrors" 22 | - 23 | - vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc" 24 | + "github.com/xwb1989/sqlparser/dependency/sqltypes" 25 | ) 26 | 27 | // These constants are used to identify the SQL statement type. 28 | @@ -219,7 +217,7 @@ 29 | case IntVal: 30 | n, err := sqltypes.NewIntegral(string(node.Val)) 31 | if err != nil { 32 | - return sqltypes.PlanValue{}, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "%v", err) 33 | + return sqltypes.PlanValue{}, fmt.Errorf("%v", err) 34 | } 35 | return sqltypes.PlanValue{Value: n}, nil 36 | case StrVal: 37 | @@ -227,7 +225,7 @@ 38 | case HexVal: 39 | v, err := node.HexDecode() 40 | if err != nil { 41 | - return sqltypes.PlanValue{}, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "%v", err) 42 | + return sqltypes.PlanValue{}, fmt.Errorf("%v", err) 43 | } 44 | return sqltypes.PlanValue{Value: sqltypes.MakeTrusted(sqltypes.VarBinary, v)}, nil 45 | } 46 | @@ -243,7 +241,7 @@ 47 | return sqltypes.PlanValue{}, err 48 | } 49 | if innerpv.ListKey != "" || innerpv.Values != nil { 50 | - return sqltypes.PlanValue{}, vterrors.New(vtrpcpb.Code_UNIMPLEMENTED, "unsupported: nested lists") 51 | + return sqltypes.PlanValue{}, errors.New("unsupported: nested lists") 52 | } 53 | pv.Values = append(pv.Values, innerpv) 54 | } 55 | @@ -251,7 +249,7 @@ 56 | case *NullVal: 57 | return sqltypes.PlanValue{}, nil 58 | } 59 | - return sqltypes.PlanValue{}, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "expression is too complex '%v'", String(node)) 60 | + return sqltypes.PlanValue{}, fmt.Errorf("expression is too complex '%v'", String(node)) 61 | } 62 | 63 | // StringIn is a convenience function that returns 64 | diff -u /Users/bramp/go/src/vitess.io/vitess/go//vt/sqlparser/analyzer_test.go /Users/bramp/go/src/github.com/xwb1989/sqlparser//analyzer_test.go 65 | --- 
/Users/bramp/go/src/vitess.io/vitess/go//vt/sqlparser/analyzer_test.go 2018-06-05 08:45:47.000000000 -0700 66 | +++ /Users/bramp/go/src/github.com/xwb1989/sqlparser//analyzer_test.go 2018-06-06 07:45:09.000000000 -0700 67 | @@ -21,7 +21,7 @@ 68 | "strings" 69 | "testing" 70 | 71 | - "vitess.io/vitess/go/sqltypes" 72 | + "github.com/xwb1989/sqlparser/dependency/sqltypes" 73 | ) 74 | 75 | func TestPreview(t *testing.T) { 76 | diff -u /Users/bramp/go/src/vitess.io/vitess/go//vt/sqlparser/ast.go /Users/bramp/go/src/github.com/xwb1989/sqlparser//ast.go 77 | --- /Users/bramp/go/src/vitess.io/vitess/go//vt/sqlparser/ast.go 2018-06-05 08:45:47.000000000 -0700 78 | +++ /Users/bramp/go/src/github.com/xwb1989/sqlparser//ast.go 2018-06-06 07:45:09.000000000 -0700 79 | @@ -22,14 +22,11 @@ 80 | "encoding/json" 81 | "fmt" 82 | "io" 83 | + "log" 84 | "strings" 85 | 86 | - "vitess.io/vitess/go/sqltypes" 87 | - "vitess.io/vitess/go/vt/log" 88 | - "vitess.io/vitess/go/vt/vterrors" 89 | - 90 | - querypb "vitess.io/vitess/go/vt/proto/query" 91 | - vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc" 92 | + "github.com/xwb1989/sqlparser/dependency/querypb" 93 | + "github.com/xwb1989/sqlparser/dependency/sqltypes" 94 | ) 95 | 96 | // Instructions for creating new types: If a type 97 | @@ -52,11 +49,11 @@ 98 | tokenizer := NewStringTokenizer(sql) 99 | if yyParse(tokenizer) != 0 { 100 | if tokenizer.partialDDL != nil { 101 | - log.Warningf("ignoring error parsing DDL '%s': %v", sql, tokenizer.LastError) 102 | + log.Printf("ignoring error parsing DDL '%s': %v", sql, tokenizer.LastError) 103 | tokenizer.ParseTree = tokenizer.partialDDL 104 | return tokenizer.ParseTree, nil 105 | } 106 | - return nil, vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, tokenizer.LastError.Error()) 107 | + return nil, tokenizer.LastError 108 | } 109 | return tokenizer.ParseTree, nil 110 | } 111 | @@ -2249,7 +2246,7 @@ 112 | return NewStrVal(value.ToBytes()), nil 113 | default: 114 | // We cannot support sqltypes.Expression, 
or any other invalid type. 115 | - return nil, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "cannot convert value %v to AST", value) 116 | + return nil, fmt.Errorf("cannot convert value %v to AST", value) 117 | } 118 | } 119 | 120 | @@ -3394,6 +3391,20 @@ 121 | return nil 122 | } 123 | 124 | +// Backtick produces a backticked literal given an input string. 125 | +func Backtick(in string) string { 126 | + var buf bytes.Buffer 127 | + buf.WriteByte('`') 128 | + for _, c := range in { 129 | + buf.WriteRune(c) 130 | + if c == '`' { 131 | + buf.WriteByte('`') 132 | + } 133 | + } 134 | + buf.WriteByte('`') 135 | + return buf.String() 136 | +} 137 | + 138 | func formatID(buf *TrackedBuffer, original, lowered string) { 139 | isDbSystemVariable := false 140 | if len(original) > 1 && original[:2] == "@@" { 141 | diff -u /Users/bramp/go/src/vitess.io/vitess/go//vt/sqlparser/ast_test.go /Users/bramp/go/src/github.com/xwb1989/sqlparser//ast_test.go 142 | --- /Users/bramp/go/src/vitess.io/vitess/go//vt/sqlparser/ast_test.go 2018-06-05 08:45:47.000000000 -0700 143 | +++ /Users/bramp/go/src/github.com/xwb1989/sqlparser//ast_test.go 2018-06-05 07:41:09.000000000 -0700 144 | @@ -24,7 +24,7 @@ 145 | "testing" 146 | "unsafe" 147 | 148 | - "vitess.io/vitess/go/sqltypes" 149 | + "github.com/xwb1989/sqlparser/dependency/sqltypes" 150 | ) 151 | 152 | func TestAppend(t *testing.T) { 153 | diff -u /Users/bramp/go/src/vitess.io/vitess/go//vt/sqlparser/comments.go /Users/bramp/go/src/github.com/xwb1989/sqlparser//comments.go 154 | --- /Users/bramp/go/src/vitess.io/vitess/go//vt/sqlparser/comments.go 2018-06-05 08:45:47.000000000 -0700 155 | +++ /Users/bramp/go/src/github.com/xwb1989/sqlparser//comments.go 2018-06-06 07:45:09.000000000 -0700 156 | @@ -145,7 +145,7 @@ 157 | // Single line comment 158 | index := strings.Index(sql, "\n") 159 | if index == -1 { 160 | - return "" 161 | + return sql 162 | } 163 | sql = sql[index+1:] 164 | } 165 | diff -u 
/Users/bramp/go/src/vitess.io/vitess/go//vt/sqlparser/comments_test.go /Users/bramp/go/src/github.com/xwb1989/sqlparser//comments_test.go 166 | --- /Users/bramp/go/src/vitess.io/vitess/go//vt/sqlparser/comments_test.go 2018-06-05 08:45:47.000000000 -0700 167 | +++ /Users/bramp/go/src/github.com/xwb1989/sqlparser//comments_test.go 2018-06-06 07:45:08.000000000 -0700 168 | @@ -187,7 +187,7 @@ 169 | outSQL: "bar", 170 | }, { 171 | input: "-- /* foo */ bar", 172 | - outSQL: "", 173 | + outSQL: "-- /* foo */ bar", 174 | }, { 175 | input: "foo -- bar */", 176 | outSQL: "foo -- bar */", 177 | @@ -201,7 +201,7 @@ 178 | outSQL: "a", 179 | }, { 180 | input: `-- foo bar`, 181 | - outSQL: "", 182 | + outSQL: "-- foo bar", 183 | }} 184 | for _, testCase := range testCases { 185 | gotSQL := StripLeadingComments(testCase.input) 186 | Only in /Users/bramp/go/src/github.com/xwb1989/sqlparser//: dependency 187 | diff -u /Users/bramp/go/src/vitess.io/vitess/go//vt/sqlparser/encodable.go /Users/bramp/go/src/github.com/xwb1989/sqlparser//encodable.go 188 | --- /Users/bramp/go/src/vitess.io/vitess/go//vt/sqlparser/encodable.go 2018-06-05 08:45:47.000000000 -0700 189 | +++ /Users/bramp/go/src/github.com/xwb1989/sqlparser//encodable.go 2017-10-18 18:06:33.000000000 -0700 190 | @@ -19,7 +19,7 @@ 191 | import ( 192 | "bytes" 193 | 194 | - "vitess.io/vitess/go/sqltypes" 195 | + "github.com/xwb1989/sqlparser/dependency/sqltypes" 196 | ) 197 | 198 | // This file contains types that are 'Encodable'. 
199 | diff -u /Users/bramp/go/src/vitess.io/vitess/go//vt/sqlparser/encodable_test.go /Users/bramp/go/src/github.com/xwb1989/sqlparser//encodable_test.go 200 | --- /Users/bramp/go/src/vitess.io/vitess/go//vt/sqlparser/encodable_test.go 2018-06-05 08:45:47.000000000 -0700 201 | +++ /Users/bramp/go/src/github.com/xwb1989/sqlparser//encodable_test.go 2017-10-18 18:06:33.000000000 -0700 202 | @@ -20,7 +20,7 @@ 203 | "bytes" 204 | "testing" 205 | 206 | - "vitess.io/vitess/go/sqltypes" 207 | + "github.com/xwb1989/sqlparser/dependency/sqltypes" 208 | ) 209 | 210 | func TestEncodable(t *testing.T) { 211 | Only in /Users/bramp/go/src/github.com/xwb1989/sqlparser//: github_test.go 212 | diff -u /Users/bramp/go/src/vitess.io/vitess/go//vt/sqlparser/normalizer.go /Users/bramp/go/src/github.com/xwb1989/sqlparser//normalizer.go 213 | --- /Users/bramp/go/src/vitess.io/vitess/go//vt/sqlparser/normalizer.go 2018-06-05 08:45:47.000000000 -0700 214 | +++ /Users/bramp/go/src/github.com/xwb1989/sqlparser//normalizer.go 2017-10-18 18:06:33.000000000 -0700 215 | @@ -19,9 +19,9 @@ 216 | import ( 217 | "fmt" 218 | 219 | - "vitess.io/vitess/go/sqltypes" 220 | + "github.com/xwb1989/sqlparser/dependency/sqltypes" 221 | 222 | - querypb "vitess.io/vitess/go/vt/proto/query" 223 | + "github.com/xwb1989/sqlparser/dependency/querypb" 224 | ) 225 | 226 | // Normalize changes the statement to use bind values, and 227 | diff -u /Users/bramp/go/src/vitess.io/vitess/go//vt/sqlparser/normalizer_test.go /Users/bramp/go/src/github.com/xwb1989/sqlparser//normalizer_test.go 228 | --- /Users/bramp/go/src/vitess.io/vitess/go//vt/sqlparser/normalizer_test.go 2018-06-05 08:45:47.000000000 -0700 229 | +++ /Users/bramp/go/src/github.com/xwb1989/sqlparser//normalizer_test.go 2017-11-27 22:10:51.000000000 -0800 230 | @@ -21,8 +21,8 @@ 231 | "reflect" 232 | "testing" 233 | 234 | - "vitess.io/vitess/go/sqltypes" 235 | - querypb "vitess.io/vitess/go/vt/proto/query" 236 | + 
"github.com/xwb1989/sqlparser/dependency/querypb" 237 | + "github.com/xwb1989/sqlparser/dependency/sqltypes" 238 | ) 239 | 240 | func TestNormalize(t *testing.T) { 241 | diff -u /Users/bramp/go/src/vitess.io/vitess/go//vt/sqlparser/parsed_query.go /Users/bramp/go/src/github.com/xwb1989/sqlparser//parsed_query.go 242 | --- /Users/bramp/go/src/vitess.io/vitess/go//vt/sqlparser/parsed_query.go 2018-06-05 08:45:47.000000000 -0700 243 | +++ /Users/bramp/go/src/github.com/xwb1989/sqlparser//parsed_query.go 2017-10-22 13:30:37.000000000 -0700 244 | @@ -18,12 +18,10 @@ 245 | 246 | import ( 247 | "bytes" 248 | - "encoding/json" 249 | "fmt" 250 | 251 | - "vitess.io/vitess/go/sqltypes" 252 | - 253 | - querypb "vitess.io/vitess/go/vt/proto/query" 254 | + "github.com/xwb1989/sqlparser/dependency/querypb" 255 | + "github.com/xwb1989/sqlparser/dependency/sqltypes" 256 | ) 257 | 258 | // ParsedQuery represents a parsed query where 259 | @@ -71,12 +69,6 @@ 260 | return buf.Bytes(), nil 261 | } 262 | 263 | -// MarshalJSON is a custom JSON marshaler for ParsedQuery. 264 | -// Note that any queries longer that 512 bytes will be truncated. 265 | -func (pq *ParsedQuery) MarshalJSON() ([]byte, error) { 266 | - return json.Marshal(TruncateForUI(pq.Query)) 267 | -} 268 | - 269 | // EncodeValue encodes one bind variable value into the query. 
270 | func EncodeValue(buf *bytes.Buffer, value *querypb.BindVariable) { 271 | if value.Type != querypb.Type_TUPLE { 272 | diff -u /Users/bramp/go/src/vitess.io/vitess/go//vt/sqlparser/parsed_query_test.go /Users/bramp/go/src/github.com/xwb1989/sqlparser//parsed_query_test.go 273 | --- /Users/bramp/go/src/vitess.io/vitess/go//vt/sqlparser/parsed_query_test.go 2018-06-05 08:45:47.000000000 -0700 274 | +++ /Users/bramp/go/src/github.com/xwb1989/sqlparser//parsed_query_test.go 2017-10-18 18:06:33.000000000 -0700 275 | @@ -20,9 +20,9 @@ 276 | "reflect" 277 | "testing" 278 | 279 | - "vitess.io/vitess/go/sqltypes" 280 | + "github.com/xwb1989/sqlparser/dependency/sqltypes" 281 | 282 | - querypb "vitess.io/vitess/go/vt/proto/query" 283 | + "github.com/xwb1989/sqlparser/dependency/querypb" 284 | ) 285 | 286 | func TestNewParsedQuery(t *testing.T) { 287 | Only in /Users/bramp/go/src/github.com/xwb1989/sqlparser//: patches 288 | Only in /Users/bramp/go/src/github.com/xwb1989/sqlparser//: quick 289 | diff -u /Users/bramp/go/src/vitess.io/vitess/go//vt/sqlparser/redact_query.go /Users/bramp/go/src/github.com/xwb1989/sqlparser//redact_query.go 290 | --- /Users/bramp/go/src/vitess.io/vitess/go//vt/sqlparser/redact_query.go 2018-06-05 08:45:47.000000000 -0700 291 | +++ /Users/bramp/go/src/github.com/xwb1989/sqlparser//redact_query.go 2018-06-06 07:42:56.000000000 -0700 292 | @@ -1,6 +1,6 @@ 293 | package sqlparser 294 | 295 | -import querypb "vitess.io/vitess/go/vt/proto/query" 296 | +import querypb "github.com/xwb1989/sqlparser/dependency/querypb" 297 | 298 | // RedactSQLQuery returns a sql string with the params stripped out for display 299 | func RedactSQLQuery(sql string) (string, error) { 300 | Only in /Users/bramp/go/src/github.com/xwb1989/sqlparser//: tests 301 | diff -u /Users/bramp/go/src/vitess.io/vitess/go//vt/sqlparser/token.go /Users/bramp/go/src/github.com/xwb1989/sqlparser//token.go 302 | --- /Users/bramp/go/src/vitess.io/vitess/go//vt/sqlparser/token.go 2018-06-05 
08:45:47.000000000 -0700 303 | +++ /Users/bramp/go/src/github.com/xwb1989/sqlparser//token.go 2018-06-06 07:45:09.000000000 -0700 304 | @@ -22,8 +22,8 @@ 305 | "fmt" 306 | "io" 307 | 308 | - "vitess.io/vitess/go/bytes2" 309 | - "vitess.io/vitess/go/sqltypes" 310 | + "github.com/xwb1989/sqlparser/dependency/bytes2" 311 | + "github.com/xwb1989/sqlparser/dependency/sqltypes" 312 | ) 313 | 314 | const ( 315 | Only in /Users/bramp/go/src/vitess.io/vitess/go//vt/sqlparser/: truncate_query.go 316 | Only in /Users/bramp/go/src/github.com/xwb1989/sqlparser//: y.output 317 | -------------------------------------------------------------------------------- /patches/sqltypes.patch: -------------------------------------------------------------------------------- 1 | Only in /Users/bramp/go/src/vitess.io/vitess/go//sqltypes/: arithmetic.go 2 | Only in /Users/bramp/go/src/vitess.io/vitess/go//sqltypes/: arithmetic_test.go 3 | diff -u /Users/bramp/go/src/vitess.io/vitess/go//sqltypes/bind_variables.go /Users/bramp/go/src/github.com/xwb1989/sqlparser//dependency/sqltypes/bind_variables.go 4 | --- /Users/bramp/go/src/vitess.io/vitess/go//sqltypes/bind_variables.go 2018-06-05 08:45:47.000000000 -0700 5 | +++ /Users/bramp/go/src/github.com/xwb1989/sqlparser//dependency/sqltypes/bind_variables.go 2018-06-04 08:05:24.000000000 -0700 6 | @@ -19,11 +19,10 @@ 7 | import ( 8 | "errors" 9 | "fmt" 10 | + "reflect" 11 | "strconv" 12 | 13 | - "github.com/golang/protobuf/proto" 14 | - 15 | - querypb "vitess.io/vitess/go/vt/proto/query" 16 | + "github.com/xwb1989/sqlparser/dependency/querypb" 17 | ) 18 | 19 | // NullBindVariable is a bindvar with NULL value. 20 | @@ -253,9 +252,8 @@ 21 | } 22 | 23 | // BindVariablesEqual compares two maps of bind variables. 24 | -// For protobuf messages we have to use "proto.Equal". 
25 | func BindVariablesEqual(x, y map[string]*querypb.BindVariable) bool { 26 | - return proto.Equal(&querypb.BoundQuery{BindVariables: x}, &querypb.BoundQuery{BindVariables: y}) 27 | + return reflect.DeepEqual(&querypb.BoundQuery{BindVariables: x}, &querypb.BoundQuery{BindVariables: y}) 28 | } 29 | 30 | // CopyBindVariables returns a shallow-copy of the given bindVariables map. 31 | diff -u /Users/bramp/go/src/vitess.io/vitess/go//sqltypes/bind_variables_test.go /Users/bramp/go/src/github.com/xwb1989/sqlparser//dependency/sqltypes/bind_variables_test.go 32 | --- /Users/bramp/go/src/vitess.io/vitess/go//sqltypes/bind_variables_test.go 2018-06-05 08:45:47.000000000 -0700 33 | +++ /Users/bramp/go/src/github.com/xwb1989/sqlparser//dependency/sqltypes/bind_variables_test.go 2018-06-04 08:05:24.000000000 -0700 34 | @@ -21,16 +21,14 @@ 35 | "strings" 36 | "testing" 37 | 38 | - "github.com/golang/protobuf/proto" 39 | - 40 | - querypb "vitess.io/vitess/go/vt/proto/query" 41 | + "github.com/xwb1989/sqlparser/dependency/querypb" 42 | ) 43 | 44 | func TestProtoConversions(t *testing.T) { 45 | v := TestValue(Int64, "1") 46 | got := ValueToProto(v) 47 | want := &querypb.Value{Type: Int64, Value: []byte("1")} 48 | - if !proto.Equal(got, want) { 49 | + if !reflect.DeepEqual(got, want) { 50 | t.Errorf("ValueToProto: %v, want %v", got, want) 51 | } 52 | gotback := ProtoToValue(got) 53 | @@ -240,7 +238,7 @@ 54 | t.Errorf("ToBindVar(%T(%v)) error: nil, want %s", tcase.in, tcase.in, tcase.err) 55 | continue 56 | } 57 | - if !proto.Equal(bv, tcase.out) { 58 | + if !reflect.DeepEqual(bv, tcase.out) { 59 | t.Errorf("ToBindVar(%T(%v)): %v, want %s", tcase.in, tcase.in, bv, tcase.out) 60 | } 61 | } 62 | @@ -523,7 +521,7 @@ 63 | v, err = BindVariableToValue(&querypb.BindVariable{Type: querypb.Type_TUPLE}) 64 | wantErr := "cannot convert a TUPLE bind var into a value" 65 | if err == nil || err.Error() != wantErr { 66 | - t.Errorf(" BindVarToValue(TUPLE): (%v, %v), want %s", v, err, wantErr) 
67 | + t.Errorf(" BindVarToValue(TUPLE): %v, want %s", err, wantErr) 68 | } 69 | } 70 | 71 | Only in /Users/bramp/go/src/vitess.io/vitess/go//sqltypes/: event_token.go 72 | Only in /Users/bramp/go/src/vitess.io/vitess/go//sqltypes/: event_token_test.go 73 | diff -u /Users/bramp/go/src/vitess.io/vitess/go//sqltypes/plan_value.go /Users/bramp/go/src/github.com/xwb1989/sqlparser//dependency/sqltypes/plan_value.go 74 | --- /Users/bramp/go/src/vitess.io/vitess/go//sqltypes/plan_value.go 2018-06-05 08:45:47.000000000 -0700 75 | +++ /Users/bramp/go/src/github.com/xwb1989/sqlparser//dependency/sqltypes/plan_value.go 2018-06-04 08:05:24.000000000 -0700 76 | @@ -18,10 +18,10 @@ 77 | 78 | import ( 79 | "encoding/json" 80 | + "errors" 81 | + "fmt" 82 | 83 | - querypb "vitess.io/vitess/go/vt/proto/query" 84 | - vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc" 85 | - "vitess.io/vitess/go/vt/vterrors" 86 | + "github.com/xwb1989/sqlparser/dependency/querypb" 87 | ) 88 | 89 | // PlanValue represents a value or a list of values for 90 | @@ -87,7 +87,7 @@ 91 | case pv.ListKey != "" || pv.Values != nil: 92 | // This code is unreachable because the parser does not allow 93 | // multi-value constructs where a single value is expected. 
94 | - return NULL, vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "a list was supplied where a single value was expected") 95 | + return NULL, errors.New("a list was supplied where a single value was expected") 96 | } 97 | return NULL, nil 98 | } 99 | @@ -95,10 +95,10 @@ 100 | func (pv PlanValue) lookupValue(bindVars map[string]*querypb.BindVariable) (*querypb.BindVariable, error) { 101 | bv, ok := bindVars[pv.Key] 102 | if !ok { 103 | - return nil, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "missing bind var %s", pv.Key) 104 | + return nil, fmt.Errorf("missing bind var %s", pv.Key) 105 | } 106 | if bv.Type == querypb.Type_TUPLE { 107 | - return nil, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "TUPLE was supplied for single value bind var %s", pv.ListKey) 108 | + return nil, fmt.Errorf("TUPLE was supplied for single value bind var %s", pv.ListKey) 109 | } 110 | return bv, nil 111 | } 112 | @@ -129,16 +129,16 @@ 113 | } 114 | // This code is unreachable because the parser does not allow 115 | // single value constructs where multiple values are expected. 
116 | - return nil, vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "a single value was supplied where a list was expected") 117 | + return nil, errors.New("a single value was supplied where a list was expected") 118 | } 119 | 120 | func (pv PlanValue) lookupList(bindVars map[string]*querypb.BindVariable) (*querypb.BindVariable, error) { 121 | bv, ok := bindVars[pv.ListKey] 122 | if !ok { 123 | - return nil, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "missing bind var %s", pv.ListKey) 124 | + return nil, fmt.Errorf("missing bind var %s", pv.ListKey) 125 | } 126 | if bv.Type != querypb.Type_TUPLE { 127 | - return nil, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "single value was supplied for TUPLE bind var %s", pv.ListKey) 128 | + return nil, fmt.Errorf("single value was supplied for TUPLE bind var %s", pv.ListKey) 129 | } 130 | return bv, nil 131 | } 132 | @@ -171,7 +171,7 @@ 133 | case l: 134 | return nil 135 | default: 136 | - return vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "mismatch in number of column values") 137 | + return errors.New("mismatch in number of column values") 138 | } 139 | } 140 | 141 | @@ -221,7 +221,7 @@ 142 | rows[i] = make([]Value, len(pvs)) 143 | } 144 | 145 | - // Using j because we're resolving by columns. 146 | + // Using j becasue we're resolving by columns. 
147 | for j, pv := range pvs { 148 | switch { 149 | case pv.Key != "": 150 | diff -u /Users/bramp/go/src/vitess.io/vitess/go//sqltypes/plan_value_test.go /Users/bramp/go/src/github.com/xwb1989/sqlparser//dependency/sqltypes/plan_value_test.go 151 | --- /Users/bramp/go/src/vitess.io/vitess/go//sqltypes/plan_value_test.go 2018-06-05 08:45:47.000000000 -0700 152 | +++ /Users/bramp/go/src/github.com/xwb1989/sqlparser//dependency/sqltypes/plan_value_test.go 2018-06-04 08:05:24.000000000 -0700 153 | @@ -21,7 +21,7 @@ 154 | "strings" 155 | "testing" 156 | 157 | - querypb "vitess.io/vitess/go/vt/proto/query" 158 | + "github.com/xwb1989/sqlparser/dependency/querypb" 159 | ) 160 | 161 | func TestPlanValueIsNull(t *testing.T) { 162 | Only in /Users/bramp/go/src/vitess.io/vitess/go//sqltypes/: proto3.go 163 | Only in /Users/bramp/go/src/vitess.io/vitess/go//sqltypes/: proto3_test.go 164 | Only in /Users/bramp/go/src/vitess.io/vitess/go//sqltypes/: query_response.go 165 | Only in /Users/bramp/go/src/vitess.io/vitess/go//sqltypes/: result.go 166 | Only in /Users/bramp/go/src/vitess.io/vitess/go//sqltypes/: result_test.go 167 | diff -u /Users/bramp/go/src/vitess.io/vitess/go//sqltypes/testing.go /Users/bramp/go/src/github.com/xwb1989/sqlparser//dependency/sqltypes/testing.go 168 | --- /Users/bramp/go/src/vitess.io/vitess/go//sqltypes/testing.go 2018-06-05 08:45:47.000000000 -0700 169 | +++ /Users/bramp/go/src/github.com/xwb1989/sqlparser//dependency/sqltypes/testing.go 2018-06-04 08:06:27.000000000 -0700 170 | @@ -17,17 +17,14 @@ 171 | package sqltypes 172 | 173 | import ( 174 | - "bytes" 175 | - "fmt" 176 | - "strings" 177 | - 178 | - querypb "vitess.io/vitess/go/vt/proto/query" 179 | + querypb "github.com/xwb1989/sqlparser/dependency/querypb" 180 | ) 181 | 182 | // Functions in this file should only be used for testing. 183 | // This is an experiment to see if test code bloat can be 184 | // reduced and readability improved. 
185 | 186 | +/* 187 | // MakeTestFields builds a []*querypb.Field for testing. 188 | // fields := sqltypes.MakeTestFields( 189 | // "a|b", 190 | @@ -110,6 +107,7 @@ 191 | } 192 | return results 193 | } 194 | +*/ 195 | 196 | // TestBindVariable makes a *querypb.BindVariable from 197 | // an interface{}.It panics on invalid input. 198 | @@ -131,6 +129,7 @@ 199 | return MakeTrusted(typ, []byte(val)) 200 | } 201 | 202 | +/* 203 | // PrintResults prints []*Results into a string. 204 | // This function should only be used for testing. 205 | func PrintResults(results []*Result) string { 206 | @@ -152,3 +151,4 @@ 207 | } 208 | return splits 209 | } 210 | +*/ 211 | diff -u /Users/bramp/go/src/vitess.io/vitess/go//sqltypes/type.go /Users/bramp/go/src/github.com/xwb1989/sqlparser//dependency/sqltypes/type.go 212 | --- /Users/bramp/go/src/vitess.io/vitess/go//sqltypes/type.go 2018-06-05 08:45:47.000000000 -0700 213 | +++ /Users/bramp/go/src/github.com/xwb1989/sqlparser//dependency/sqltypes/type.go 2018-06-04 08:05:24.000000000 -0700 214 | @@ -19,7 +19,7 @@ 215 | import ( 216 | "fmt" 217 | 218 | - querypb "vitess.io/vitess/go/vt/proto/query" 219 | + "github.com/xwb1989/sqlparser/dependency/querypb" 220 | ) 221 | 222 | // This file provides wrappers and support 223 | diff -u /Users/bramp/go/src/vitess.io/vitess/go//sqltypes/type_test.go /Users/bramp/go/src/github.com/xwb1989/sqlparser//dependency/sqltypes/type_test.go 224 | --- /Users/bramp/go/src/vitess.io/vitess/go//sqltypes/type_test.go 2018-06-05 08:45:47.000000000 -0700 225 | +++ /Users/bramp/go/src/github.com/xwb1989/sqlparser//dependency/sqltypes/type_test.go 2018-06-04 08:05:24.000000000 -0700 226 | @@ -19,7 +19,7 @@ 227 | import ( 228 | "testing" 229 | 230 | - querypb "vitess.io/vitess/go/vt/proto/query" 231 | + "github.com/xwb1989/sqlparser/dependency/querypb" 232 | ) 233 | 234 | func TestTypeValues(t *testing.T) { 235 | diff -u /Users/bramp/go/src/vitess.io/vitess/go//sqltypes/value.go 
/Users/bramp/go/src/github.com/xwb1989/sqlparser//dependency/sqltypes/value.go 236 | --- /Users/bramp/go/src/vitess.io/vitess/go//sqltypes/value.go 2018-06-05 08:45:47.000000000 -0700 237 | +++ /Users/bramp/go/src/github.com/xwb1989/sqlparser//dependency/sqltypes/value.go 2018-06-04 08:05:24.000000000 -0700 238 | @@ -23,10 +23,10 @@ 239 | "fmt" 240 | "strconv" 241 | 242 | - "vitess.io/vitess/go/bytes2" 243 | - "vitess.io/vitess/go/hack" 244 | + "github.com/xwb1989/sqlparser/dependency/bytes2" 245 | + "github.com/xwb1989/sqlparser/dependency/hack" 246 | 247 | - querypb "vitess.io/vitess/go/vt/proto/query" 248 | + "github.com/xwb1989/sqlparser/dependency/querypb" 249 | ) 250 | 251 | var ( 252 | @@ -48,7 +48,7 @@ 253 | } 254 | 255 | // Value can store any SQL value. If the value represents 256 | -// an integral type, the bytes are always stored as a canonical 257 | +// an integral type, the bytes are always stored as a cannonical 258 | // representation that matches how MySQL returns such values. 259 | type Value struct { 260 | typ querypb.Type 261 | @@ -126,7 +126,7 @@ 262 | return MakeTrusted(VarBinary, []byte(v)) 263 | } 264 | 265 | -// NewIntegral builds an integral type from a string representation. 266 | +// NewIntegral builds an integral type from a string representaion. 267 | // The type will be Int64 or Uint64. Int64 will be preferred where possible. 268 | func NewIntegral(val string) (n Value, err error) { 269 | signed, err := strconv.ParseInt(val, 0, 64) 270 | @@ -169,7 +169,7 @@ 271 | return v.typ 272 | } 273 | 274 | -// Raw returns the internal representation of the value. For newer types, 275 | +// Raw returns the internal represenation of the value. For newer types, 276 | // this may not match MySQL's representation. 
277 | func (v Value) Raw() []byte { 278 | return v.val 279 | diff -u /Users/bramp/go/src/vitess.io/vitess/go//sqltypes/value_test.go /Users/bramp/go/src/github.com/xwb1989/sqlparser//dependency/sqltypes/value_test.go 280 | --- /Users/bramp/go/src/vitess.io/vitess/go//sqltypes/value_test.go 2018-06-05 08:45:47.000000000 -0700 281 | +++ /Users/bramp/go/src/github.com/xwb1989/sqlparser//dependency/sqltypes/value_test.go 2018-06-04 08:05:24.000000000 -0700 282 | @@ -22,7 +22,7 @@ 283 | "strings" 284 | "testing" 285 | 286 | - querypb "vitess.io/vitess/go/vt/proto/query" 287 | + "github.com/xwb1989/sqlparser/dependency/querypb" 288 | ) 289 | 290 | const ( 291 | -------------------------------------------------------------------------------- /precedence_test.go: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2017 Google Inc. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 
15 | */ 16 | 17 | package sqlparser 18 | 19 | import ( 20 | "fmt" 21 | "testing" 22 | ) 23 | // readable renders an expression with explicit parentheses around every OrExpr, AndExpr, BinaryExpr and IsExpr so tests can see how the parser grouped operators; any other node is rendered with String. 24 | func readable(node Expr) string { 25 | switch node := node.(type) { 26 | case *OrExpr: 27 | return fmt.Sprintf("(%s or %s)", readable(node.Left), readable(node.Right)) 28 | case *AndExpr: 29 | return fmt.Sprintf("(%s and %s)", readable(node.Left), readable(node.Right)) 30 | case *BinaryExpr: 31 | return fmt.Sprintf("(%s %s %s)", readable(node.Left), node.Operator, readable(node.Right)) 32 | case *IsExpr: 33 | return fmt.Sprintf("(%s %s)", readable(node.Expr), node.Operator) 34 | default: 35 | return String(node) 36 | } 37 | } 38 | // TestAndOrPrecedence checks that AND binds tighter than OR regardless of operand order. 39 | func TestAndOrPrecedence(t *testing.T) { 40 | validSQL := []struct { 41 | input string 42 | output string 43 | }{{ 44 | input: "select * from a where a=b and c=d or e=f", 45 | output: "((a = b and c = d) or e = f)", 46 | }, { 47 | input: "select * from a where a=b or c=d and e=f", 48 | output: "(a = b or (c = d and e = f))", 49 | }} 50 | for _, tcase := range validSQL { 51 | tree, err := Parse(tcase.input) 52 | if err != nil { 53 | t.Error(err) 54 | continue 55 | } 56 | expr := readable(tree.(*Select).Where.Expr) 57 | if expr != tcase.output { 58 | t.Errorf("Parse: \n%s, want: \n%s", expr, tcase.output) 59 | } 60 | } 61 | } 62 | // TestPlusStarPrecedence checks that * binds tighter than + regardless of operand order. 63 | func TestPlusStarPrecedence(t *testing.T) { 64 | validSQL := []struct { 65 | input string 66 | output string 67 | }{{ 68 | input: "select 1+2*3 from a", 69 | output: "(1 + (2 * 3))", 70 | }, { 71 | input: "select 1*2+3 from a", 72 | output: "((1 * 2) + 3)", 73 | }} 74 | for _, tcase := range validSQL { 75 | tree, err := Parse(tcase.input) 76 | if err != nil { 77 | t.Error(err) 78 | continue 79 | } 80 | expr := readable(tree.(*Select).SelectExprs[0].(*AliasedExpr).Expr) 81 | if expr != tcase.output { 82 | t.Errorf("Parse: \n%s, want: \n%s", expr, tcase.output) 83 | } 84 | } 85 | } 86 | // TestIsPrecedence checks that IS binds looser than + and =, tighter than AND, and that parentheses override it. 87 | func TestIsPrecedence(t *testing.T) { 88 | validSQL := []struct { 89 | input string 90 | output string 91 | }{{ 92 |
input: "select * from a where a+b is true", 93 | output: "((a + b) is true)", 94 | }, { 95 | input: "select * from a where a=1 and b=2 is true", 96 | output: "(a = 1 and (b = 2 is true))", 97 | }, { 98 | input: "select * from a where (a=1 and b=2) is true", 99 | output: "((a = 1 and b = 2) is true)", 100 | }} 101 | for _, tcase := range validSQL { 102 | tree, err := Parse(tcase.input) 103 | if err != nil { 104 | t.Error(err) 105 | continue 106 | } 107 | expr := readable(tree.(*Select).Where.Expr) 108 | if expr != tcase.output { 109 | t.Errorf("Parse: \n%s, want: \n%s", expr, tcase.output) 110 | } 111 | } 112 | } 113 | -------------------------------------------------------------------------------- /redact_query.go: -------------------------------------------------------------------------------- 1 | package sqlparser 2 | 3 | import querypb "github.com/xwb1989/sqlparser/dependency/querypb" 4 | 5 | // RedactSQLQuery returns a sql string with the params stripped out for display 6 | func RedactSQLQuery(sql string) (string, error) { 7 | bv := map[string]*querypb.BindVariable{} 8 | sqlStripped, comments := SplitMarginComments(sql) 9 | 10 | stmt, err := Parse(sqlStripped) 11 | if err != nil { 12 | return "", err 13 | } 14 | 15 | prefix := "redacted" 16 | Normalize(stmt, bv, prefix) 17 | 18 | return comments.Leading + String(stmt) + comments.Trailing, nil 19 | } 20 | -------------------------------------------------------------------------------- /redact_query_test.go: -------------------------------------------------------------------------------- 1 | package sqlparser 2 | 3 | import ( 4 | "testing" 5 | ) 6 | // TestRedactSQLStatements checks that RedactSQLQuery replaces literals with :redactedN bind variables, reusing one name for equal literals. 7 | func TestRedactSQLStatements(t *testing.T) { 8 | sql := "select a,b,c from t where x = 1234 and y = 1234 and z = 'apple'" 9 | redactedSQL, err := RedactSQLQuery(sql) 10 | if err != nil { 11 | t.Fatalf("redacting sql failed: %v", err) 12 | } 13 | 14 | if redactedSQL != "select a, b, c from t where x = :redacted1 and y = :redacted1 and z = :redacted2" { 15 |
t.Fatalf("Unknown sql redaction: %v", redactedSQL) 16 | } 17 | } 18 | -------------------------------------------------------------------------------- /token_test.go: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2017 Google Inc. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | */ 16 | 17 | package sqlparser 18 | 19 | import ( 20 | "fmt" 21 | "testing" 22 | ) 23 | // TestLiteralID checks Scan on backtick-quoted identifiers, including doubled backticks used as escapes and unterminated or empty identifiers, which yield LEX_ERROR. 24 | func TestLiteralID(t *testing.T) { 25 | testcases := []struct { 26 | in string 27 | id int 28 | out string 29 | }{{ 30 | in: "`aa`", 31 | id: ID, 32 | out: "aa", 33 | }, { 34 | in: "```a```", 35 | id: ID, 36 | out: "`a`", 37 | }, { 38 | in: "`a``b`", 39 | id: ID, 40 | out: "a`b", 41 | }, { 42 | in: "`a``b`c", 43 | id: ID, 44 | out: "a`b", 45 | }, { 46 | in: "`a``b", 47 | id: LEX_ERROR, 48 | out: "a`b", 49 | }, { 50 | in: "`a``b``", 51 | id: LEX_ERROR, 52 | out: "a`b`", 53 | }, { 54 | in: "``", 55 | id: LEX_ERROR, 56 | out: "", 57 | }} 58 | 59 | for _, tcase := range testcases { 60 | tkn := NewStringTokenizer(tcase.in) 61 | id, out := tkn.Scan() 62 | if tcase.id != id || string(out) != tcase.out { 63 | t.Errorf("Scan(%s): %d, %s, want %d, %s", tcase.in, id, out, tcase.id, tcase.out) 64 | } 65 | } 66 | } 67 | // tokenName returns "STRING" or "LEX_ERROR" for those token ids and the decimal id otherwise, for readable failure messages. 68 | func tokenName(id int) string { 69 | if id == STRING { 70 | return "STRING" 71 | } else if id == LEX_ERROR { 72 | return "LEX_ERROR" 73 | } 74 | return fmt.Sprintf("%d", id) 75 | } 76 | // TestString checks Scan on single-quoted string literals: doubled quotes and backslash escapes, plus unterminated literals that yield LEX_ERROR. 77 | func TestString(t *testing.T)
{ 78 | testcases := []struct { 79 | in string 80 | id int 81 | want string 82 | }{{ 83 | in: "''", 84 | id: STRING, 85 | want: "", 86 | }, { 87 | in: "''''", 88 | id: STRING, 89 | want: "'", 90 | }, { 91 | in: "'hello'", 92 | id: STRING, 93 | want: "hello", 94 | }, { 95 | in: "'\\n'", 96 | id: STRING, 97 | want: "\n", 98 | }, { 99 | in: "'\\nhello\\n'", 100 | id: STRING, 101 | want: "\nhello\n", 102 | }, { 103 | in: "'a''b'", 104 | id: STRING, 105 | want: "a'b", 106 | }, { 107 | in: "'a\\'b'", 108 | id: STRING, 109 | want: "a'b", 110 | }, { 111 | in: "'\\'", 112 | id: LEX_ERROR, 113 | want: "'", 114 | }, { 115 | in: "'", 116 | id: LEX_ERROR, 117 | want: "", 118 | }, { 119 | in: "'hello\\'", 120 | id: LEX_ERROR, 121 | want: "hello'", 122 | }, { 123 | in: "'hello", 124 | id: LEX_ERROR, 125 | want: "hello", 126 | }, { 127 | in: "'hello\\", 128 | id: LEX_ERROR, 129 | want: "hello", 130 | }} 131 | 132 | for _, tcase := range testcases { 133 | id, got := NewStringTokenizer(tcase.in).Scan() 134 | if tcase.id != id || string(got) != tcase.want { 135 | t.Errorf("Scan(%q) = (%s, %q), want (%s, %q)", tcase.in, tokenName(id), got, tokenName(tcase.id), tcase.want) 136 | } 137 | } 138 | } 139 | // TestSplitStatement checks that SplitStatement splits at the first ';' outside comments and string literals, returning the text after it as the remainder. 140 | func TestSplitStatement(t *testing.T) { 141 | testcases := []struct { 142 | in string 143 | sql string 144 | rem string 145 | }{{ 146 | in: "select * from table", 147 | sql: "select * from table", 148 | }, { 149 | in: "select * from table; ", 150 | sql: "select * from table", 151 | rem: " ", 152 | }, { 153 | in: "select * from table; select * from table2;", 154 | sql: "select * from table", 155 | rem: " select * from table2;", 156 | }, { 157 | in: "select * from /* comment */ table;", 158 | sql: "select * from /* comment */ table", 159 | }, { 160 | in: "select * from /* comment ; */ table;", 161 | sql: "select * from /* comment ; */ table", 162 | }, { 163 | in: "select * from table where semi = ';';", 164 | sql: "select * from table where semi = ';'", 165 | }, { 166 | in: "-- select
* from table", 167 | sql: "-- select * from table", 168 | }, { 169 | in: " ", 170 | sql: " ", 171 | }, { 172 | in: "", 173 | sql: "", 174 | }} 175 | // NOTE(review): the failure messages below still say EndOfStatementPosition, but the function under test is SplitStatement. 176 | for _, tcase := range testcases { 177 | sql, rem, err := SplitStatement(tcase.in) 178 | if err != nil { 179 | t.Errorf("EndOfStatementPosition(%s): ERROR: %v", tcase.in, err) 180 | continue 181 | } 182 | 183 | if tcase.sql != sql { 184 | t.Errorf("EndOfStatementPosition(%s) got sql \"%s\" want \"%s\"", tcase.in, sql, tcase.sql) 185 | } 186 | 187 | if tcase.rem != rem { 188 | t.Errorf("EndOfStatementPosition(%s) got remainder \"%s\" want \"%s\"", tcase.in, rem, tcase.rem) 189 | } 190 | } 191 | } 192 | -------------------------------------------------------------------------------- /tracked_buffer.go: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2017 Google Inc. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | */ 16 | 17 | package sqlparser 18 | 19 | import ( 20 | "bytes" 21 | "fmt" 22 | ) 23 | 24 | // NodeFormatter defines the signature of a custom node formatter 25 | // function that can be given to TrackedBuffer for code generation. 26 | type NodeFormatter func(buf *TrackedBuffer, node SQLNode) 27 | 28 | // TrackedBuffer is used to rebuild a query from the ast. 29 | // bindLocations keeps track of locations in the buffer that 30 | // use bind variables for efficient future substitutions.
31 | // nodeFormatter is the formatting function the buffer will 32 | // use to format a node. By default(nil), it's FormatNode. 33 | // But you can supply a different formatting function if you 34 | // want to generate a query that's different from the default. 35 | type TrackedBuffer struct { 36 | *bytes.Buffer 37 | bindLocations []bindLocation 38 | nodeFormatter NodeFormatter 39 | } 40 | 41 | // NewTrackedBuffer creates a new TrackedBuffer. 42 | func NewTrackedBuffer(nodeFormatter NodeFormatter) *TrackedBuffer { 43 | return &TrackedBuffer{ 44 | Buffer: new(bytes.Buffer), 45 | nodeFormatter: nodeFormatter, 46 | } 47 | } 48 | 49 | // WriteNode function, initiates the writing of a single SQLNode tree by passing 50 | // through to Myprintf with a default format string 51 | func (buf *TrackedBuffer) WriteNode(node SQLNode) *TrackedBuffer { 52 | buf.Myprintf("%v", node) 53 | return buf 54 | } 55 | 56 | // Myprintf mimics fmt.Fprintf(buf, ...), but limited to Node(%v), 57 | // Node.Value(%s) and string(%s). It also allows a %a for a value argument, in 58 | // which case it adds tracking info for future substitutions. 59 | // 60 | // The name must be something other than the usual Printf() to avoid "go vet" 61 | // warnings due to our custom format specifiers. 
62 | func (buf *TrackedBuffer) Myprintf(format string, values ...interface{}) { 63 | end := len(format) 64 | fieldnum := 0 65 | for i := 0; i < end; { 66 | lasti := i 67 | for i < end && format[i] != '%' { 68 | i++ 69 | } 70 | if i > lasti { 71 | buf.WriteString(format[lasti:i]) 72 | } 73 | if i >= end { 74 | break 75 | } 76 | i++ // '%' 77 | switch format[i] { 78 | case 'c': 79 | switch v := values[fieldnum].(type) { 80 | case byte: 81 | buf.WriteByte(v) 82 | case rune: 83 | buf.WriteRune(v) 84 | default: 85 | panic(fmt.Sprintf("unexpected TrackedBuffer type %T", v)) 86 | } 87 | case 's': 88 | switch v := values[fieldnum].(type) { 89 | case []byte: 90 | buf.Write(v) 91 | case string: 92 | buf.WriteString(v) 93 | default: 94 | panic(fmt.Sprintf("unexpected TrackedBuffer type %T", v)) 95 | } 96 | case 'v': 97 | node := values[fieldnum].(SQLNode) 98 | if buf.nodeFormatter == nil { 99 | node.Format(buf) 100 | } else { 101 | buf.nodeFormatter(buf, node) 102 | } 103 | case 'a': 104 | buf.WriteArg(values[fieldnum].(string)) 105 | default: 106 | panic("unexpected") 107 | } 108 | fieldnum++ 109 | i++ 110 | } 111 | } 112 | 113 | // WriteArg writes a value argument into the buffer along with 114 | // tracking information for future substitutions. arg must contain 115 | // the ":" or "::" prefix. 116 | func (buf *TrackedBuffer) WriteArg(arg string) { 117 | buf.bindLocations = append(buf.bindLocations, bindLocation{ 118 | offset: buf.Len(), 119 | length: len(arg), 120 | }) 121 | buf.WriteString(arg) 122 | } 123 | 124 | // ParsedQuery returns a ParsedQuery that contains bind 125 | // locations for easy substitution. 126 | func (buf *TrackedBuffer) ParsedQuery() *ParsedQuery { 127 | return &ParsedQuery{Query: buf.String(), bindLocations: buf.bindLocations} 128 | } 129 | 130 | // HasBindVars returns true if the parsed query uses bind vars. 
131 | func (buf *TrackedBuffer) HasBindVars() bool { 132 | return len(buf.bindLocations) != 0 133 | } 134 | 135 | // BuildParsedQuery builds a ParsedQuery from the input. 136 | func BuildParsedQuery(in string, vars ...interface{}) *ParsedQuery { 137 | buf := NewTrackedBuffer(nil) 138 | buf.Myprintf(in, vars...) 139 | return buf.ParsedQuery() 140 | } 141 | --------------------------------------------------------------------------------