├── .editorconfig ├── .github ├── ISSUE_TEMPLATE │ ├── bug-report.md │ ├── feature-request.md │ └── question.md └── pull_request_template.md ├── .gitignore ├── LICENSE ├── Makefile ├── README.md ├── SECURITY.md ├── ast ├── advisor.go ├── ast.go ├── base.go ├── ddl.go ├── ddl_test.go ├── dml.go ├── dml_test.go ├── expressions.go ├── expressions_test.go ├── flag.go ├── flag_test.go ├── format_test.go ├── functions.go ├── functions_test.go ├── misc.go ├── misc_test.go ├── stats.go ├── util.go └── util_test.go ├── auth ├── auth.go ├── auth_test.go ├── caching_sha2.go ├── caching_sha2_test.go ├── mysql_native_password.go └── mysql_native_password_test.go ├── bench_test.go ├── charset ├── charset.go ├── charset_test.go ├── encoding.go ├── encoding_table.go └── encoding_test.go ├── checkout-pr-branch.sh ├── circle.yml ├── codecov.yml ├── consistent_test.go ├── digester.go ├── digester_test.go ├── docs ├── quickstart.md └── update-parser-for-tidb.md ├── export_test.go ├── format ├── format.go └── format_test.go ├── go.mod ├── go.sum ├── goyacc ├── format_yacc.go └── main.go ├── hintparser.go ├── hintparser.y ├── hintparser_test.go ├── hintparserimpl.go ├── lexer.go ├── lexer_test.go ├── misc.go ├── model ├── ddl.go ├── flags.go ├── model.go └── model_test.go ├── mysql ├── charset.go ├── const.go ├── const_test.go ├── errcode.go ├── errname.go ├── error.go ├── error_test.go ├── locale_format.go ├── privs.go ├── privs_test.go ├── state.go ├── type.go ├── type_test.go └── util.go ├── opcode ├── opcode.go └── opcode_test.go ├── parser.go ├── parser.y ├── parser_test.go ├── reserved_words_test.go ├── terror ├── terror.go └── terror_test.go ├── test.sh ├── test_driver ├── test_driver.go ├── test_driver_datum.go ├── test_driver_helper.go └── test_driver_mydecimal.go ├── tidb └── features.go ├── types ├── etc.go ├── eval_type.go ├── field_type.go └── field_type_test.go └── yy_parser.go /.editorconfig: -------------------------------------------------------------------------------- 
1 | root = true 2 | 3 | [*] 4 | end_of_line = lf 5 | insert_final_newline = true 6 | charset = utf-8 7 | 8 | # tab_size = 4 spaces 9 | [{*.go,*.y}] 10 | indent_style = tab 11 | indent_size = 4 12 | trim_trailing_whitespace = true 13 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug-report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: "\U0001F41B Bug Report" 3 | about: Something isn't working as expected 4 | 5 | --- 6 | 7 | ## Bug Report 8 | 9 | Please answer these questions before submitting your issue. Thanks! 10 | 11 | 1. What did you do? 12 | If possible, provide a recipe for reproducing the error. 13 | 14 | 15 | 2. What did you expect to see? 16 | 17 | 18 | 19 | 3. What did you see instead? 20 | 21 | 22 | 23 | 4. What version of TiDB SQL Parser are you using? 24 | 25 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature-request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: "\U0001F680 Feature Request" 3 | about: I have a suggestion 4 | 5 | --- 6 | 7 | ## Feature Request 8 | 9 | **Is your feature request related to a problem? Please describe:** 10 | 11 | 12 | **Describe the feature you'd like:** 13 | 14 | 15 | **Describe alternatives you've considered:** 16 | 17 | 18 | **Teachability, Documentation, Adoption, Migration Strategy:** 19 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/question.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: "\U0001F914 Question" 3 | about: Usage question that isn't answered in docs or discussion 4 | 5 | --- 6 | 7 | ## Question 8 | 9 | Before asking a question, make sure you have: 10 | 11 | - Searched existing Stack Overflow questions. 12 | - Googled your question. 
13 | - Searched open and closed [GitHub issues](https://github.com/pingcap/parser/issues?utf8=%E2%9C%93&q=is%3Aissue) 14 | - Read the documentation: 15 | - [TiDB SQL Parser Readme](https://github.com/pingcap/parser) 16 | - [TiDB Doc](https://github.com/pingcap/docs) -------------------------------------------------------------------------------- /.github/pull_request_template.md: -------------------------------------------------------------------------------- 1 | 4 | 5 | ### What problem does this PR solve? 6 | 7 | 8 | ### What is changed and how it works? 9 | 10 | 11 | ### Check List 12 | 13 | Tests 14 | 15 | - Unit test 16 | - Integration test 17 | - Manual test (add detailed scripts or steps below) 18 | - No code 19 | 20 | Code changes 21 | 22 | - Has exported function/method change 23 | - Has exported variable/fields change 24 | - Has interface methods change 25 | 26 | Side effects 27 | 28 | - Possible performance regression 29 | - Increased code complexity 30 | - Breaking backward compatibility 31 | 32 | Related changes 33 | 34 | - Need to cherry-pick to the release branch 35 | - Need to update the documentation 36 | - Need to be included in the release note 37 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | bin/ 2 | y.go 3 | *.output 4 | .idea/ 5 | .vscode/ 6 | coverage.txt 7 | -------------------------------------------------------------------------------- /Makefile: --------------------------------------------------------------------------------
# test and fmt never produce files with those names, so they are phony too;
# otherwise a stray file called "test" or "fmt" would silently skip the target.
.PHONY: all parser clean test fmt

all: fmt parser

test: fmt parser
	sh test.sh

parser: parser.go hintparser.go

# $(prefix) is "" for parser.go and "hint" for hintparser.go.
%arser.go: prefix = $(@:parser.go=)
%arser.go: %arser.y bin/goyacc
	@echo "bin/goyacc -o $@ -p yy$(prefix) -t $(prefix)Parser $<"
	@bin/goyacc -o $@ -p yy$(prefix) -t $(prefix)Parser $< || ( rm -f $@ && echo 'Please check y.output for more information' && exit 1 )
	@rm -f y.output

# Reformat the grammar; fail (after rewriting it) if formatting changed anything.
%arser_golden.y: %arser.y
	@bin/goyacc -fmt -fmtout $@ $<
	@(git diff --no-index --exit-code $< $@ && rm $@) || (mv $@ $< && >&2 echo "formatted $<" && exit 1)

bin/goyacc: goyacc/main.go goyacc/format_yacc.go
	GO111MODULE=on go build -o bin/goyacc goyacc/main.go goyacc/format_yacc.go

# gofmt -l prints every file it rewrote; any output makes the target fail.
fmt: bin/goyacc parser_golden.y hintparser_golden.y
	@echo "gofmt (simplify)"
	@gofmt -s -l -w . 2>&1 | awk '{print} END{if(NR>0) {exit 1}}'

clean:
	go clean -i ./...
	rm -rf *.out
	rm -f parser.go hintparser.go
-------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | This repository has been moved to https://github.com/pingcap/tidb/tree/master/pkg/parser. 2 | -------------------------------------------------------------------------------- /SECURITY.md: -------------------------------------------------------------------------------- 1 | # Security Vulnerability Disclosure and Response Process 2 | 3 | The primary goal of this process is to reduce the total exposure time of users to publicly known vulnerabilities. TiDB security team is responsible for the entire vulnerability management process, including internal communication and external disclosure. 4 | 5 | If you find a vulnerability or encounter a security incident involving vulnerabilities of this repository, please report it as soon as possible to the TiDB security team (security@tidb.io). 6 | 7 | Please kindly help provide as much vulnerability information as possible in the following format: 8 | 9 | - Issue title*: 10 | 11 | - Overview*: 12 | 13 | - Affected components and version number*: 14 | 15 | - CVE number (if any): 16 | 17 | - Vulnerability verification process*: 18 | 19 | - Contact information*: 20 | 21 | The asterisk (*) indicates the required field.
22 | 23 | # Response Time 24 | 25 | The TiDB security team will confirm the vulnerabilities and contact you within 2 working days after your submission. 26 | 27 | We will publicly thank you after fixing the security vulnerability. To avoid negative impact, please keep the vulnerability confidential until we fix it. We would appreciate it if you could obey the following code of conduct: 28 | 29 | The vulnerability will not be disclosed until a patch is released for it. 30 | 31 | The details of the vulnerability, for example, exploits code, will not be disclosed. 32 | -------------------------------------------------------------------------------- /ast/advisor.go: -------------------------------------------------------------------------------- 1 | // Copyright 2019 PingCAP, Inc. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // See the License for the specific language governing permissions and 12 | // limitations under the License. 13 | 14 | package ast 15 | 16 | import ( 17 | "github.com/pingcap/parser/format" 18 | ) 19 | 20 | var _ StmtNode = &IndexAdviseStmt{} 21 | 22 | // IndexAdviseStmt is used to advise indexes 23 | type IndexAdviseStmt struct { 24 | stmtNode 25 | 26 | IsLocal bool 27 | Path string 28 | MaxMinutes uint64 29 | MaxIndexNum *MaxIndexNumClause 30 | LinesInfo *LinesClause 31 | } 32 | 33 | // Restore implements Node Accept interface. 
34 | func (n *IndexAdviseStmt) Restore(ctx *format.RestoreCtx) error { 35 | ctx.WriteKeyWord("INDEX ADVISE ") 36 | if n.IsLocal { 37 | ctx.WriteKeyWord("LOCAL ") 38 | } 39 | ctx.WriteKeyWord("INFILE ") 40 | ctx.WriteString(n.Path) 41 | if n.MaxMinutes != UnspecifiedSize { 42 | ctx.WriteKeyWord(" MAX_MINUTES ") 43 | ctx.WritePlainf("%d", n.MaxMinutes) 44 | } 45 | if n.MaxIndexNum != nil { 46 | n.MaxIndexNum.Restore(ctx) 47 | } 48 | n.LinesInfo.Restore(ctx) 49 | return nil 50 | } 51 | 52 | // Accept implements Node Accept interface. 53 | func (n *IndexAdviseStmt) Accept(v Visitor) (Node, bool) { 54 | newNode, skipChildren := v.Enter(n) 55 | if skipChildren { 56 | return v.Leave(newNode) 57 | } 58 | n = newNode.(*IndexAdviseStmt) 59 | return v.Leave(n) 60 | } 61 | 62 | // MaxIndexNumClause represents 'maximum number of indexes' clause in index advise statement. 63 | type MaxIndexNumClause struct { 64 | PerTable uint64 65 | PerDB uint64 66 | } 67 | 68 | // Restore for max index num clause 69 | func (n *MaxIndexNumClause) Restore(ctx *format.RestoreCtx) error { 70 | ctx.WriteKeyWord(" MAX_IDXNUM") 71 | if n.PerTable != UnspecifiedSize { 72 | ctx.WriteKeyWord(" PER_TABLE ") 73 | ctx.WritePlainf("%d", n.PerTable) 74 | } 75 | if n.PerDB != UnspecifiedSize { 76 | ctx.WriteKeyWord(" PER_DB ") 77 | ctx.WritePlainf("%d", n.PerDB) 78 | } 79 | return nil 80 | } 81 | -------------------------------------------------------------------------------- /ast/ast.go: -------------------------------------------------------------------------------- 1 | // Copyright 2015 PingCAP, Inc. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 
5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // See the License for the specific language governing permissions and 12 | // limitations under the License. 13 |
// Package ast is the abstract syntax tree parsed from a SQL statement by parser.
// It can be analysed and transformed by optimizer.
package ast

import (
	"io"

	"github.com/pingcap/parser/format"
	"github.com/pingcap/parser/model"
	"github.com/pingcap/parser/types"
)

// Node is the basic element of the AST.
// Interfaces that embed Node should have a 'Node' name suffix.
type Node interface {
	// Restore writes the SQL text of this node into ctx. It returns a
	// non-nil error if the node cannot be restored.
	Restore(ctx *format.RestoreCtx) error
	// Accept accepts a Visitor to visit itself.
	// The returned node should replace the original node.
	// ok is false when visiting should stop.
	//
	// Implementations should first call v.Enter and assign the returned node
	// to the method receiver. If skipChildren is true, the children are skipped.
	// Otherwise the children are visited in an order such that later elements
	// may depend on earlier ones. Finally, return v.Leave.
	Accept(v Visitor) (node Node, ok bool)
	// Text returns the original text of the element.
	Text() string
	// SetText sets the original text of the Node.
	SetText(text string)
	// SetOriginTextPosition sets the start offset of this node in the origin text.
	SetOriginTextPosition(offset int)
	// OriginTextPosition returns the start offset of this node in the origin text.
	OriginTextPosition() int
}

// Flags indicate which kinds of sub-expressions an expression contains.
const (
	FlagConstant uint64 = 0
	// iota is already 1 on the next line, so FlagHasParamMarker is 1<<1 and each
	// following flag takes the next bit; bit 0 is never used. Reordering or
	// inserting entries changes the numeric value of every later flag.
	FlagHasParamMarker uint64 = 1 << iota
	FlagHasFunc
	FlagHasReference
	FlagHasAggregateFunc
	FlagHasSubquery
	FlagHasVariable
	FlagHasDefault
	FlagPreEvaluated
	FlagHasWindowFunc
)

// ExprNode is a node that can be evaluated.
// Names of implementations should have an 'Expr' suffix.
type ExprNode interface {
	// Node is embedded in ExprNode.
	Node
	// SetType sets the evaluation type of the expression.
	SetType(tp *types.FieldType)
	// GetType returns the evaluation type of the expression.
	GetType() *types.FieldType
	// SetFlag sets the flag of the expression.
	// The flag records whether the expression contains a
	// parameter marker, reference, aggregate function, and so on.
	SetFlag(flag uint64)
	// GetFlag returns the flag of the expression.
	GetFlag() uint64

	// Format formats the AST into a writer.
	Format(w io.Writer)
}

// OptBinary is used by the parser to carry an IsBinary flag and a Charset name.
type OptBinary struct {
	IsBinary bool
	Charset  string
}

// FuncNode represents a function call expression node.
type FuncNode interface {
	ExprNode
	// functionExpression is an unexported marker method.
	functionExpression()
}

// StmtNode represents a statement node.
// Names of implementations should have a 'Stmt' suffix.
type StmtNode interface {
	Node
	// statement is an unexported marker method.
	statement()
}

// DDLNode represents a DDL statement node.
type DDLNode interface {
	StmtNode
	ddlStatement()
}

// DMLNode represents a DML statement node.
type DMLNode interface {
	StmtNode
	dmlStatement()
}

// ResultField represents a result field, which can be a column from a table
// or an expression in a select field. It is a generated property during the
// binding process. ResultField is the key element used to evaluate a ColumnNameExpr.
// After the resolving process, every ColumnNameExpr is resolved to a ResultField.
// During execution, every row retrieved from a table sets its values on the
// ResultFields of that table, so a ColumnNameExpr resolved to such a ResultField
// can be evaluated easily.
type ResultField struct {
	Column       *model.ColumnInfo
	ColumnAsName model.CIStr
	Table        *model.TableInfo
	TableAsName  model.CIStr
	DBName       model.CIStr

	// Expr represents the expression for the result field. If it is generated from a select field, it would
	// be the expression of that select field, otherwise the type would be ValueExpr and value
	// will be set for every retrieved row.
	Expr      ExprNode
	TableName *TableName
	// Referenced indicates whether the result field has been referenced.
	// If not, we don't need to get the values.
	Referenced bool
}

// ResultSetNode represents a Node that returns a result set.
// Implementations include SelectStmt, SubqueryExpr, TableSource, TableName, Join and SetOprStmt.
type ResultSetNode interface {
	Node

	// resultSet is an unexported marker method.
	resultSet()
}

// SensitiveStmtNode extends StmtNode with a SecureText method.
type SensitiveStmtNode interface {
	StmtNode
	// SecureText differs from Text in that it hides password information.
	SecureText() string
}

// Visitor visits a Node.
type Visitor interface {
	// Enter is called before the children nodes are visited.
	// The returned node must have the same type as the input node n.
	// When skipChildren is true the children nodes are skipped; this is useful
	// when the work is done in Enter and there is no need to visit the children.
	Enter(n Node) (node Node, skipChildren bool)
	// Leave is called after the children nodes have been visited.
	// If the input node is an ExprNode, the returned node's type may differ from
	// the input's. A non-expression node must keep the same type as the input node n.
	// ok is false when visiting should stop.
	Leave(n Node) (node Node, ok bool)
}
--------------------------------------------------------------------------------
/ast/base.go:
--------------------------------------------------------------------------------
// Copyright 2015 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// See the License for the specific language governing permissions and
// limitations under the License.

package ast

import (
	"github.com/pingcap/parser/types"
)

// node implements the text and position methods of the Node interface
// (Text, SetText, OriginTextPosition, SetOriginTextPosition). Node
// implementations should embed it and supply Restore and Accept themselves.
type node struct {
	text   string
	offset int
}

// SetOriginTextPosition implements Node interface.
func (n *node) SetOriginTextPosition(offset int) {
	n.offset = offset
}

// OriginTextPosition implements Node interface.
func (n *node) OriginTextPosition() int {
	return n.offset
}

// SetText implements Node interface.
func (n *node) SetText(text string) {
	n.text = text
}

// Text implements Node interface.
func (n *node) Text() string {
	return n.text
}

// stmtNode implements the statement marker of the StmtNode interface.
// Statement implementations should embed it.
type stmtNode struct {
	node
}

// statement implements StmtNode interface.
54 | func (sn *stmtNode) statement() {} 55 | 56 | // ddlNode implements DDLNode interface. 57 | // DDL implementations should embed it in. 58 | type ddlNode struct { 59 | stmtNode 60 | } 61 | 62 | // ddlStatement implements DDLNode interface. 63 | func (dn *ddlNode) ddlStatement() {} 64 | 65 | // dmlNode is the struct implements DMLNode interface. 66 | // DML implementations should embed it in. 67 | type dmlNode struct { 68 | stmtNode 69 | } 70 | 71 | // dmlStatement implements DMLNode interface. 72 | func (dn *dmlNode) dmlStatement() {} 73 | 74 | // exprNode is the struct implements Expression interface. 75 | // Expression implementations should embed it in. 76 | type exprNode struct { 77 | node 78 | Type types.FieldType 79 | flag uint64 80 | } 81 | 82 | // TexprNode is exported for parser driver. 83 | type TexprNode = exprNode 84 | 85 | // SetType implements ExprNode interface. 86 | func (en *exprNode) SetType(tp *types.FieldType) { 87 | en.Type = *tp 88 | } 89 | 90 | // GetType implements ExprNode interface. 91 | func (en *exprNode) GetType() *types.FieldType { 92 | return &en.Type 93 | } 94 | 95 | // SetFlag implements ExprNode interface. 96 | func (en *exprNode) SetFlag(flag uint64) { 97 | en.flag = flag 98 | } 99 | 100 | // GetFlag implements ExprNode interface. 101 | func (en *exprNode) GetFlag() uint64 { 102 | return en.flag 103 | } 104 | 105 | type funcNode struct { 106 | exprNode 107 | } 108 | 109 | // functionExpression implements FunctionNode interface. 110 | func (fn *funcNode) functionExpression() {} 111 | -------------------------------------------------------------------------------- /ast/flag.go: -------------------------------------------------------------------------------- 1 | // Copyright 2015 PingCAP, Inc. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 
5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // See the License for the specific language governing permissions and 12 | // limitations under the License. 13 | 14 | package ast 15 | 16 | // HasAggFlag checks if the expr contains FlagHasAggregateFunc. 17 | func HasAggFlag(expr ExprNode) bool { 18 | return expr.GetFlag()&FlagHasAggregateFunc > 0 19 | } 20 | 21 | func HasWindowFlag(expr ExprNode) bool { 22 | return expr.GetFlag()&FlagHasWindowFunc > 0 23 | } 24 | 25 | // SetFlag sets flag for expression. 26 | func SetFlag(n Node) { 27 | var setter flagSetter 28 | n.Accept(&setter) 29 | } 30 | 31 | type flagSetter struct { 32 | } 33 | 34 | func (f *flagSetter) Enter(in Node) (Node, bool) { 35 | return in, false 36 | } 37 | 38 | func (f *flagSetter) Leave(in Node) (Node, bool) { 39 | if x, ok := in.(ParamMarkerExpr); ok { 40 | x.SetFlag(FlagHasParamMarker) 41 | } 42 | switch x := in.(type) { 43 | case *AggregateFuncExpr: 44 | f.aggregateFunc(x) 45 | case *WindowFuncExpr: 46 | f.windowFunc(x) 47 | case *BetweenExpr: 48 | x.SetFlag(x.Expr.GetFlag() | x.Left.GetFlag() | x.Right.GetFlag()) 49 | case *BinaryOperationExpr: 50 | x.SetFlag(x.L.GetFlag() | x.R.GetFlag()) 51 | case *CaseExpr: 52 | f.caseExpr(x) 53 | case *ColumnNameExpr: 54 | x.SetFlag(FlagHasReference) 55 | case *CompareSubqueryExpr: 56 | x.SetFlag(x.L.GetFlag() | x.R.GetFlag()) 57 | case *DefaultExpr: 58 | x.SetFlag(FlagHasDefault) 59 | case *ExistsSubqueryExpr: 60 | x.SetFlag(x.Sel.GetFlag()) 61 | case *FuncCallExpr: 62 | f.funcCall(x) 63 | case *FuncCastExpr: 64 | x.SetFlag(FlagHasFunc | x.Expr.GetFlag()) 65 | case *IsNullExpr: 66 | x.SetFlag(x.Expr.GetFlag()) 67 | case *IsTruthExpr: 68 | x.SetFlag(x.Expr.GetFlag()) 69 | case *ParenthesesExpr: 70 | x.SetFlag(x.Expr.GetFlag()) 71 | case 
*PatternInExpr: 72 | f.patternIn(x) 73 | case *PatternLikeExpr: 74 | f.patternLike(x) 75 | case *PatternRegexpExpr: 76 | f.patternRegexp(x) 77 | case *PositionExpr: 78 | x.SetFlag(FlagHasReference) 79 | case *RowExpr: 80 | f.row(x) 81 | case *SubqueryExpr: 82 | x.SetFlag(FlagHasSubquery) 83 | case *UnaryOperationExpr: 84 | x.SetFlag(x.V.GetFlag()) 85 | case *ValuesExpr: 86 | x.SetFlag(FlagHasReference) 87 | case *VariableExpr: 88 | if x.Value == nil { 89 | x.SetFlag(FlagHasVariable) 90 | } else { 91 | x.SetFlag(FlagHasVariable | x.Value.GetFlag()) 92 | } 93 | } 94 | 95 | return in, true 96 | } 97 | 98 | func (f *flagSetter) caseExpr(x *CaseExpr) { 99 | var flag uint64 100 | if x.Value != nil { 101 | flag |= x.Value.GetFlag() 102 | } 103 | for _, val := range x.WhenClauses { 104 | flag |= val.Expr.GetFlag() 105 | flag |= val.Result.GetFlag() 106 | } 107 | if x.ElseClause != nil { 108 | flag |= x.ElseClause.GetFlag() 109 | } 110 | x.SetFlag(flag) 111 | } 112 | 113 | func (f *flagSetter) patternIn(x *PatternInExpr) { 114 | flag := x.Expr.GetFlag() 115 | for _, val := range x.List { 116 | flag |= val.GetFlag() 117 | } 118 | if x.Sel != nil { 119 | flag |= x.Sel.GetFlag() 120 | } 121 | x.SetFlag(flag) 122 | } 123 | 124 | func (f *flagSetter) patternLike(x *PatternLikeExpr) { 125 | flag := x.Pattern.GetFlag() 126 | if x.Expr != nil { 127 | flag |= x.Expr.GetFlag() 128 | } 129 | x.SetFlag(flag) 130 | } 131 | 132 | func (f *flagSetter) patternRegexp(x *PatternRegexpExpr) { 133 | flag := x.Pattern.GetFlag() 134 | if x.Expr != nil { 135 | flag |= x.Expr.GetFlag() 136 | } 137 | x.SetFlag(flag) 138 | } 139 | 140 | func (f *flagSetter) row(x *RowExpr) { 141 | var flag uint64 142 | for _, val := range x.Values { 143 | flag |= val.GetFlag() 144 | } 145 | x.SetFlag(flag) 146 | } 147 | 148 | func (f *flagSetter) funcCall(x *FuncCallExpr) { 149 | flag := FlagHasFunc 150 | for _, val := range x.Args { 151 | flag |= val.GetFlag() 152 | } 153 | x.SetFlag(flag) 154 | } 155 | 156 | func 
(f *flagSetter) aggregateFunc(x *AggregateFuncExpr) { 157 | flag := FlagHasAggregateFunc 158 | for _, val := range x.Args { 159 | flag |= val.GetFlag() 160 | } 161 | x.SetFlag(flag) 162 | } 163 | 164 | func (f *flagSetter) windowFunc(x *WindowFuncExpr) { 165 | flag := FlagHasWindowFunc 166 | for _, val := range x.Args { 167 | flag |= val.GetFlag() 168 | } 169 | x.SetFlag(flag) 170 | } 171 | -------------------------------------------------------------------------------- /ast/flag_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2016 PingCAP, Inc. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // See the License for the specific language governing permissions and 12 | // limitations under the License. 13 | 14 | package ast_test 15 | 16 | import ( 17 | "testing" 18 | 19 | . 
"github.com/pingcap/check" 20 | "github.com/pingcap/parser" 21 | "github.com/pingcap/parser/ast" 22 | ) 23 | 24 | func TestT(t *testing.T) { 25 | CustomVerboseFlag = true 26 | TestingT(t) 27 | } 28 | 29 | var _ = Suite(&testFlagSuite{}) 30 | 31 | type testFlagSuite struct { 32 | *parser.Parser 33 | } 34 | 35 | func (ts *testFlagSuite) SetUpSuite(c *C) { 36 | ts.Parser = parser.New() 37 | } 38 | 39 | func (ts *testFlagSuite) TestHasAggFlag(c *C) { 40 | expr := &ast.BetweenExpr{} 41 | flagTests := []struct { 42 | flag uint64 43 | hasAgg bool 44 | }{ 45 | {ast.FlagHasAggregateFunc, true}, 46 | {ast.FlagHasAggregateFunc | ast.FlagHasVariable, true}, 47 | {ast.FlagHasVariable, false}, 48 | } 49 | for _, tt := range flagTests { 50 | expr.SetFlag(tt.flag) 51 | c.Assert(ast.HasAggFlag(expr), Equals, tt.hasAgg) 52 | } 53 | } 54 | 55 | func (ts *testFlagSuite) TestFlag(c *C) { 56 | flagTests := []struct { 57 | expr string 58 | flag uint64 59 | }{ 60 | { 61 | "1 between 0 and 2", 62 | ast.FlagConstant, 63 | }, 64 | { 65 | "case 1 when 1 then 1 else 0 end", 66 | ast.FlagConstant, 67 | }, 68 | { 69 | "case 1 when 1 then 1 else 0 end", 70 | ast.FlagConstant, 71 | }, 72 | { 73 | "case 1 when a > 1 then 1 else 0 end", 74 | ast.FlagConstant | ast.FlagHasReference, 75 | }, 76 | { 77 | "1 = ANY (select 1) OR exists (select 1)", 78 | ast.FlagHasSubquery, 79 | }, 80 | { 81 | "1 in (1) or 1 is true or null is null or 'abc' like 'abc' or 'abc' rlike 'abc'", 82 | ast.FlagConstant, 83 | }, 84 | { 85 | "row (1, 1) = row (1, 1)", 86 | ast.FlagConstant, 87 | }, 88 | { 89 | "(1 + a) > ?", 90 | ast.FlagHasReference | ast.FlagHasParamMarker, 91 | }, 92 | { 93 | "trim('abc ')", 94 | ast.FlagHasFunc, 95 | }, 96 | { 97 | "now() + EXTRACT(YEAR FROM '2009-07-02') + CAST(1 AS UNSIGNED)", 98 | ast.FlagHasFunc, 99 | }, 100 | { 101 | "substring('abc', 1)", 102 | ast.FlagHasFunc, 103 | }, 104 | { 105 | "sum(a)", 106 | ast.FlagHasAggregateFunc | ast.FlagHasReference, 107 | }, 108 | { 109 | "(select 1) as 
a", 110 | ast.FlagHasSubquery, 111 | }, 112 | { 113 | "@auto_commit", 114 | ast.FlagHasVariable, 115 | }, 116 | { 117 | "default(a)", 118 | ast.FlagHasDefault, 119 | }, 120 | { 121 | "a is null", 122 | ast.FlagHasReference, 123 | }, 124 | { 125 | "1 is true", 126 | ast.FlagConstant, 127 | }, 128 | { 129 | "a in (1, count(*), 3)", 130 | ast.FlagConstant | ast.FlagHasReference | ast.FlagHasAggregateFunc, 131 | }, 132 | { 133 | "'Michael!' REGEXP '.*'", 134 | ast.FlagConstant, 135 | }, 136 | { 137 | "a REGEXP '.*'", 138 | ast.FlagHasReference, 139 | }, 140 | { 141 | "-a", 142 | ast.FlagHasReference, 143 | }, 144 | } 145 | for _, tt := range flagTests { 146 | stmt, err := ts.ParseOneStmt("select "+tt.expr, "", "") 147 | c.Assert(err, IsNil) 148 | selectStmt := stmt.(*ast.SelectStmt) 149 | ast.SetFlag(selectStmt) 150 | expr := selectStmt.Fields.Fields[0].Expr 151 | c.Assert(expr.GetFlag(), Equals, tt.flag, Commentf("For %s", tt.expr)) 152 | } 153 | } 154 | -------------------------------------------------------------------------------- /ast/format_test.go: -------------------------------------------------------------------------------- 1 | package ast_test 2 | 3 | import ( 4 | "bytes" 5 | "fmt" 6 | 7 | . "github.com/pingcap/check" 8 | "github.com/pingcap/parser" 9 | "github.com/pingcap/parser/ast" 10 | ) 11 | 12 | var _ = Suite(&testAstFormatSuite{}) 13 | 14 | type testAstFormatSuite struct { 15 | } 16 | 17 | func getDefaultCharsetAndCollate() (string, string) { 18 | return "utf8", "utf8_bin" 19 | } 20 | 21 | func (ts *testAstFormatSuite) TestAstFormat(c *C) { 22 | var testcases = []struct { 23 | input string 24 | output string 25 | }{ 26 | // Literals. 27 | {`null`, `NULL`}, 28 | {`true`, `TRUE`}, 29 | {`350`, `350`}, 30 | {`001e-12`, `1e-12`}, // Float. 31 | {`345.678`, `345.678`}, 32 | {`00.0001000`, `0.0001000`}, // Decimal. 
33 | {`null`, `NULL`}, 34 | {`"Hello, world"`, `"Hello, world"`}, 35 | {`'Hello, world'`, `"Hello, world"`}, 36 | {`'Hello, "world"'`, `"Hello, \"world\""`}, 37 | {`_utf8'你好'`, `"你好"`}, 38 | {`x'bcde'`, "x'bcde'"}, 39 | {`x''`, "x''"}, 40 | {`x'0035'`, "x'0035'"}, // Shouldn't trim leading zero. 41 | {`b'00111111'`, `b'111111'`}, 42 | {`time'10:10:10.123'`, ast.TimeLiteral + `("10:10:10.123")`}, 43 | {`timestamp'1999-01-01 10:0:0.123'`, ast.TimestampLiteral + `("1999-01-01 10:0:0.123")`}, 44 | {`date '1700-01-01'`, ast.DateLiteral + `("1700-01-01")`}, 45 | 46 | // Expressions. 47 | {`f between 30 and 50`, "`f` BETWEEN 30 AND 50"}, 48 | {`f not between 30 and 50`, "`f` NOT BETWEEN 30 AND 50"}, 49 | {`345 + " hello "`, `345 + " hello "`}, 50 | {`"hello world" >= 'hello world'`, `"hello world" >= "hello world"`}, 51 | {`case 3 when 1 then false else true end`, `CASE 3 WHEN 1 THEN FALSE ELSE TRUE END`}, 52 | {`database.table.column`, "`database`.`table`.`column`"}, // ColumnNameExpr 53 | {`3 is null`, `3 IS NULL`}, 54 | {`3 is not null`, `3 IS NOT NULL`}, 55 | {`3 is true`, `3 IS TRUE`}, 56 | {`3 is not true`, `3 IS NOT TRUE`}, 57 | {`3 is false`, `3 IS FALSE`}, 58 | {` ( x is false )`, "(`x` IS FALSE)"}, 59 | {`3 in ( a,b,"h",6 )`, "3 IN (`a`,`b`,\"h\",6)"}, 60 | {`3 not in ( a,b,"h",6 )`, "3 NOT IN (`a`,`b`,\"h\",6)"}, 61 | {`"abc" like '%b%'`, `"abc" LIKE "%b%"`}, 62 | {`"abc" not like '%b%'`, `"abc" NOT LIKE "%b%"`}, 63 | {`"abc" like '%b%' escape '_'`, `"abc" LIKE "%b%" ESCAPE '_'`}, 64 | {`"abc" regexp '.*bc?'`, `"abc" REGEXP ".*bc?"`}, 65 | {`"abc" not regexp '.*bc?'`, `"abc" NOT REGEXP ".*bc?"`}, 66 | {`- 4`, `-4`}, 67 | {`- ( - 4 ) `, `-(-4)`}, 68 | {`a%b`, "`a` % `b`"}, 69 | {`a%b+6`, "`a` % `b` + 6"}, 70 | {`a%(b+6)`, "`a` % (`b` + 6)"}, 71 | // Functions. 
72 | {` json_extract ( a,'$.b',"$.\"c d\"" ) `, "json_extract(`a`, \"$.b\", \"$.\\\"c d\\\"\")"}, 73 | {` length ( a )`, "length(`a`)"}, 74 | {`a -> '$.a'`, "json_extract(`a`, \"$.a\")"}, 75 | {`a.b ->> '$.a'`, "json_unquote(json_extract(`a`.`b`, \"$.a\"))"}, 76 | {`DATE_ADD('1970-01-01', interval 3 second)`, `date_add("1970-01-01", INTERVAL 3 SECOND)`}, 77 | {`TIMESTAMPDIFF(month, '2001-01-01', '2001-02-02 12:03:05.123')`, `timestampdiff(MONTH, "2001-01-01", "2001-02-02 12:03:05.123")`}, 78 | // Cast, Convert and Binary. 79 | // There should not be spaces between 'cast' and '(' unless 'IGNORE_SPACE' mode is set. 80 | // see: https://dev.mysql.com/doc/refman/5.7/en/function-resolution.html 81 | {` cast( a as signed ) `, "CAST(`a` AS SIGNED)"}, 82 | {` cast( a as unsigned integer) `, "CAST(`a` AS UNSIGNED)"}, 83 | {` cast( a as char(3) binary) `, "CAST(`a` AS BINARY(3))"}, 84 | {` cast( a as decimal ) `, "CAST(`a` AS DECIMAL(10))"}, 85 | {` cast( a as decimal (3) ) `, "CAST(`a` AS DECIMAL(3))"}, 86 | {` cast( a as decimal (3,3) ) `, "CAST(`a` AS DECIMAL(3, 3))"}, 87 | {` ((case when (c0 = 0) then 0 when (c0 > 0) then (c1 / c0) end)) `, "((CASE WHEN (`c0` = 0) THEN 0 WHEN (`c0` > 0) THEN (`c1` / `c0`) END))"}, 88 | {` convert (a, signed) `, "CONVERT(`a`, SIGNED)"}, 89 | {` binary "hello"`, `BINARY "hello"`}, 90 | } 91 | for _, tt := range testcases { 92 | expr := fmt.Sprintf("select %s", tt.input) 93 | charset, collation := getDefaultCharsetAndCollate() 94 | stmts, _, err := parser.New().Parse(expr, charset, collation) 95 | node := stmts[0].(*ast.SelectStmt).Fields.Fields[0].Expr 96 | c.Assert(err, IsNil) 97 | 98 | writer := bytes.NewBufferString("") 99 | node.Format(writer) 100 | c.Assert(writer.String(), Equals, tt.output) 101 | } 102 | } 103 | -------------------------------------------------------------------------------- /ast/stats.go: -------------------------------------------------------------------------------- 1 | // Copyright 2017 PingCAP, Inc. 
2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // See the License for the specific language governing permissions and 12 | // limitations under the License. 13 | 14 | package ast 15 | 16 | import ( 17 | "github.com/pingcap/errors" 18 | "github.com/pingcap/parser/format" 19 | "github.com/pingcap/parser/model" 20 | ) 21 | 22 | var ( 23 | _ StmtNode = &AnalyzeTableStmt{} 24 | _ StmtNode = &DropStatsStmt{} 25 | _ StmtNode = &LoadStatsStmt{} 26 | ) 27 | 28 | // AnalyzeTableStmt is used to create table statistics. 29 | type AnalyzeTableStmt struct { 30 | stmtNode 31 | 32 | TableNames []*TableName 33 | PartitionNames []model.CIStr 34 | IndexNames []model.CIStr 35 | AnalyzeOpts []AnalyzeOpt 36 | 37 | // IndexFlag is true when we only analyze indices for a table. 38 | IndexFlag bool 39 | Incremental bool 40 | // HistogramOperation is set in "ANALYZE TABLE ... UPDATE/DROP HISTOGRAM ..." statement. 41 | HistogramOperation HistogramOperationType 42 | ColumnNames []*ColumnName 43 | } 44 | 45 | // AnalyzeOptType is the type for analyze options. 46 | type AnalyzeOptionType int 47 | 48 | // Analyze option types. 49 | const ( 50 | AnalyzeOptNumBuckets = iota 51 | AnalyzeOptNumTopN 52 | AnalyzeOptCMSketchDepth 53 | AnalyzeOptCMSketchWidth 54 | AnalyzeOptNumSamples 55 | ) 56 | 57 | // AnalyzeOptionString stores the string form of analyze options. 
58 | var AnalyzeOptionString = map[AnalyzeOptionType]string{ 59 | AnalyzeOptNumBuckets: "BUCKETS", 60 | AnalyzeOptNumTopN: "TOPN", 61 | AnalyzeOptCMSketchWidth: "CMSKETCH WIDTH", 62 | AnalyzeOptCMSketchDepth: "CMSKETCH DEPTH", 63 | AnalyzeOptNumSamples: "SAMPLES", 64 | } 65 | 66 | // HistogramOperationType is the type for histogram operation. 67 | type HistogramOperationType int 68 | 69 | // Histogram operation types. 70 | const ( 71 | // HistogramOperationNop shows no operation in histogram. Default value. 72 | HistogramOperationNop HistogramOperationType = iota 73 | HistogramOperationUpdate 74 | HistogramOperationDrop 75 | ) 76 | 77 | // String implements fmt.Stringer for HistogramOperationType. 78 | func (hot HistogramOperationType) String() string { 79 | switch hot { 80 | case HistogramOperationUpdate: 81 | return "UPDATE HISTOGRAM" 82 | case HistogramOperationDrop: 83 | return "DROP HISTOGRAM" 84 | } 85 | return "" 86 | } 87 | 88 | // AnalyzeOpt stores the analyze option type and value. 89 | type AnalyzeOpt struct { 90 | Type AnalyzeOptionType 91 | Value uint64 92 | } 93 | 94 | // Restore implements Node interface. 
95 | func (n *AnalyzeTableStmt) Restore(ctx *format.RestoreCtx) error { 96 | if n.Incremental { 97 | ctx.WriteKeyWord("ANALYZE INCREMENTAL TABLE ") 98 | } else { 99 | ctx.WriteKeyWord("ANALYZE TABLE ") 100 | } 101 | for i, table := range n.TableNames { 102 | if i != 0 { 103 | ctx.WritePlain(",") 104 | } 105 | if err := table.Restore(ctx); err != nil { 106 | return errors.Annotatef(err, "An error occurred while restore AnalyzeTableStmt.TableNames[%d]", i) 107 | } 108 | } 109 | if len(n.PartitionNames) != 0 { 110 | ctx.WriteKeyWord(" PARTITION ") 111 | } 112 | for i, partition := range n.PartitionNames { 113 | if i != 0 { 114 | ctx.WritePlain(",") 115 | } 116 | ctx.WriteName(partition.O) 117 | } 118 | if n.HistogramOperation != HistogramOperationNop { 119 | ctx.WritePlain(" ") 120 | ctx.WriteKeyWord(n.HistogramOperation.String()) 121 | ctx.WritePlain(" ") 122 | } 123 | if len(n.ColumnNames) > 0 { 124 | ctx.WriteKeyWord("ON ") 125 | for i, columnName := range n.ColumnNames { 126 | if i != 0 { 127 | ctx.WritePlain(",") 128 | } 129 | ctx.WriteName(columnName.Name.O) 130 | } 131 | } 132 | if n.IndexFlag { 133 | ctx.WriteKeyWord(" INDEX") 134 | } 135 | for i, index := range n.IndexNames { 136 | if i != 0 { 137 | ctx.WritePlain(",") 138 | } else { 139 | ctx.WritePlain(" ") 140 | } 141 | ctx.WriteName(index.O) 142 | } 143 | if len(n.AnalyzeOpts) != 0 { 144 | ctx.WriteKeyWord(" WITH") 145 | for i, opt := range n.AnalyzeOpts { 146 | if i != 0 { 147 | ctx.WritePlain(",") 148 | } 149 | ctx.WritePlainf(" %d ", opt.Value) 150 | ctx.WritePlain(AnalyzeOptionString[opt.Type]) 151 | } 152 | } 153 | return nil 154 | } 155 | 156 | // Accept implements Node Accept interface. 
157 | func (n *AnalyzeTableStmt) Accept(v Visitor) (Node, bool) { 158 | newNode, skipChildren := v.Enter(n) 159 | if skipChildren { 160 | return v.Leave(newNode) 161 | } 162 | n = newNode.(*AnalyzeTableStmt) 163 | for i, val := range n.TableNames { 164 | node, ok := val.Accept(v) 165 | if !ok { 166 | return n, false 167 | } 168 | n.TableNames[i] = node.(*TableName) 169 | } 170 | return v.Leave(n) 171 | } 172 | 173 | // DropStatsStmt is used to drop table statistics. 174 | type DropStatsStmt struct { 175 | stmtNode 176 | 177 | Table *TableName 178 | PartitionNames []model.CIStr 179 | IsGlobalStats bool 180 | } 181 | 182 | // Restore implements Node interface. 183 | func (n *DropStatsStmt) Restore(ctx *format.RestoreCtx) error { 184 | ctx.WriteKeyWord("DROP STATS ") 185 | if err := n.Table.Restore(ctx); err != nil { 186 | return errors.Annotate(err, "An error occurred while add table") 187 | } 188 | 189 | if n.IsGlobalStats { 190 | ctx.WriteKeyWord(" GLOBAL") 191 | return nil 192 | } 193 | 194 | if len(n.PartitionNames) != 0 { 195 | ctx.WriteKeyWord(" PARTITION ") 196 | } 197 | for i, partition := range n.PartitionNames { 198 | if i != 0 { 199 | ctx.WritePlain(",") 200 | } 201 | ctx.WriteName(partition.O) 202 | } 203 | return nil 204 | } 205 | 206 | // Accept implements Node Accept interface. 207 | func (n *DropStatsStmt) Accept(v Visitor) (Node, bool) { 208 | newNode, skipChildren := v.Enter(n) 209 | if skipChildren { 210 | return v.Leave(newNode) 211 | } 212 | n = newNode.(*DropStatsStmt) 213 | node, ok := n.Table.Accept(v) 214 | if !ok { 215 | return n, false 216 | } 217 | n.Table = node.(*TableName) 218 | return v.Leave(n) 219 | } 220 | 221 | // LoadStatsStmt is the statement node for loading statistic. 222 | type LoadStatsStmt struct { 223 | stmtNode 224 | 225 | Path string 226 | } 227 | 228 | // Restore implements Node interface. 
229 | func (n *LoadStatsStmt) Restore(ctx *format.RestoreCtx) error { 230 | ctx.WriteKeyWord("LOAD STATS ") 231 | ctx.WriteString(n.Path) 232 | return nil 233 | } 234 | 235 | // Accept implements Node Accept interface. 236 | func (n *LoadStatsStmt) Accept(v Visitor) (Node, bool) { 237 | newNode, skipChildren := v.Enter(n) 238 | if skipChildren { 239 | return v.Leave(newNode) 240 | } 241 | n = newNode.(*LoadStatsStmt) 242 | return v.Leave(n) 243 | } 244 | -------------------------------------------------------------------------------- /ast/util.go: -------------------------------------------------------------------------------- 1 | // Copyright 2018 PingCAP, Inc. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // See the License for the specific language governing permissions and 12 | // limitations under the License. 13 | 14 | package ast 15 | 16 | import "math" 17 | 18 | // UnspecifiedSize is unspecified size. 19 | const ( 20 | UnspecifiedSize = math.MaxUint64 21 | ) 22 | 23 | // IsReadOnly checks whether the input ast is readOnly. 
24 | func IsReadOnly(node Node) bool { 25 | switch st := node.(type) { 26 | case *SelectStmt: 27 | if st.LockInfo != nil { 28 | switch st.LockInfo.LockType { 29 | case SelectLockForUpdate, SelectLockForUpdateNoWait, SelectLockForUpdateWaitN: 30 | return false 31 | } 32 | } 33 | 34 | checker := readOnlyChecker{ 35 | readOnly: true, 36 | } 37 | 38 | node.Accept(&checker) 39 | return checker.readOnly 40 | case *ExplainStmt: 41 | return !st.Analyze || IsReadOnly(st.Stmt) 42 | case *DoStmt, *ShowStmt: 43 | return true 44 | case *SetOprStmt: 45 | for _, sel := range node.(*SetOprStmt).SelectList.Selects { 46 | if !IsReadOnly(sel) { 47 | return false 48 | } 49 | } 50 | return true 51 | case *SetOprSelectList: 52 | for _, sel := range node.(*SetOprSelectList).Selects { 53 | if !IsReadOnly(sel) { 54 | return false 55 | } 56 | } 57 | return true 58 | default: 59 | return false 60 | } 61 | } 62 | 63 | // readOnlyChecker checks whether a query's ast is readonly, if it satisfied 64 | // 1. selectstmt; 65 | // 2. need not to set var; 66 | // it is readonly statement. 67 | type readOnlyChecker struct { 68 | readOnly bool 69 | } 70 | 71 | // Enter implements Visitor interface. 72 | func (checker *readOnlyChecker) Enter(in Node) (out Node, skipChildren bool) { 73 | switch node := in.(type) { 74 | case *VariableExpr: 75 | // like func rewriteVariable(), this stands for SetVar. 76 | if !node.IsSystem && node.Value != nil { 77 | checker.readOnly = false 78 | return in, true 79 | } 80 | } 81 | return in, false 82 | } 83 | 84 | // Leave implements Visitor interface. 85 | func (checker *readOnlyChecker) Leave(in Node) (out Node, ok bool) { 86 | return in, checker.readOnly 87 | } 88 | -------------------------------------------------------------------------------- /ast/util_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2017 PingCAP, Inc. 
2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // See the License for the specific language governing permissions and 12 | // limitations under the License. 13 | 14 | package ast_test 15 | 16 | import ( 17 | "fmt" 18 | "strings" 19 | 20 | . "github.com/pingcap/check" 21 | "github.com/pingcap/parser" 22 | . "github.com/pingcap/parser/ast" 23 | . "github.com/pingcap/parser/format" 24 | "github.com/pingcap/parser/test_driver" 25 | ) 26 | 27 | var _ = Suite(&testCacheableSuite{}) 28 | 29 | type testCacheableSuite struct { 30 | } 31 | 32 | func (s *testCacheableSuite) TestCacheable(c *C) { 33 | // test non-SelectStmt 34 | var stmt Node = &DeleteStmt{} 35 | c.Assert(IsReadOnly(stmt), IsFalse) 36 | 37 | stmt = &InsertStmt{} 38 | c.Assert(IsReadOnly(stmt), IsFalse) 39 | 40 | stmt = &UpdateStmt{} 41 | c.Assert(IsReadOnly(stmt), IsFalse) 42 | 43 | stmt = &ExplainStmt{} 44 | c.Assert(IsReadOnly(stmt), IsTrue) 45 | 46 | stmt = &ExplainStmt{} 47 | c.Assert(IsReadOnly(stmt), IsTrue) 48 | 49 | stmt = &DoStmt{} 50 | c.Assert(IsReadOnly(stmt), IsTrue) 51 | 52 | stmt = &ExplainStmt{ 53 | Stmt: &InsertStmt{}, 54 | } 55 | c.Assert(IsReadOnly(stmt), IsTrue) 56 | 57 | stmt = &ExplainStmt{ 58 | Analyze: true, 59 | Stmt: &InsertStmt{}, 60 | } 61 | c.Assert(IsReadOnly(stmt), IsFalse) 62 | 63 | stmt = &ExplainStmt{ 64 | Stmt: &SelectStmt{}, 65 | } 66 | c.Assert(IsReadOnly(stmt), IsTrue) 67 | 68 | stmt = &ExplainStmt{ 69 | Analyze: true, 70 | Stmt: &SelectStmt{}, 71 | } 72 | c.Assert(IsReadOnly(stmt), IsTrue) 73 | 74 | stmt = &ShowStmt{} 75 | c.Assert(IsReadOnly(stmt), IsTrue) 76 | 77 | stmt = &ShowStmt{} 78 | 
c.Assert(IsReadOnly(stmt), IsTrue) 79 | } 80 | 81 | func (s *testCacheableSuite) TestUnionReadOnly(c *C) { 82 | selectReadOnly := &SelectStmt{} 83 | selectForUpdate := &SelectStmt{ 84 | LockInfo: &SelectLockInfo{LockType: SelectLockForUpdate}, 85 | } 86 | selectForUpdateNoWait := &SelectStmt{ 87 | LockInfo: &SelectLockInfo{LockType: SelectLockForUpdateNoWait}, 88 | } 89 | 90 | setOprStmt := &SetOprStmt{ 91 | SelectList: &SetOprSelectList{ 92 | Selects: []Node{selectReadOnly, selectReadOnly}, 93 | }, 94 | } 95 | c.Assert(IsReadOnly(setOprStmt), IsTrue) 96 | 97 | setOprStmt.SelectList.Selects = []Node{selectReadOnly, selectReadOnly, selectReadOnly} 98 | c.Assert(IsReadOnly(setOprStmt), IsTrue) 99 | 100 | setOprStmt.SelectList.Selects = []Node{selectReadOnly, selectForUpdate} 101 | c.Assert(IsReadOnly(setOprStmt), IsFalse) 102 | 103 | setOprStmt.SelectList.Selects = []Node{selectReadOnly, selectForUpdateNoWait} 104 | c.Assert(IsReadOnly(setOprStmt), IsFalse) 105 | 106 | setOprStmt.SelectList.Selects = []Node{selectForUpdate, selectForUpdateNoWait} 107 | c.Assert(IsReadOnly(setOprStmt), IsFalse) 108 | 109 | setOprStmt.SelectList.Selects = []Node{selectReadOnly, selectForUpdate, selectForUpdateNoWait} 110 | c.Assert(IsReadOnly(setOprStmt), IsFalse) 111 | } 112 | 113 | // CleanNodeText set the text of node and all child node empty. 114 | // For test only. 115 | func CleanNodeText(node Node) { 116 | var cleaner nodeTextCleaner 117 | node.Accept(&cleaner) 118 | } 119 | 120 | // nodeTextCleaner clean the text of a node and it's child node. 121 | // For test only. 122 | type nodeTextCleaner struct { 123 | } 124 | 125 | // Enter implements Visitor interface. 
126 | func (checker *nodeTextCleaner) Enter(in Node) (out Node, skipChildren bool) { 127 | in.SetText("") 128 | in.SetOriginTextPosition(0) 129 | switch node := in.(type) { 130 | case *Constraint: 131 | if node.Option != nil { 132 | if node.Option.KeyBlockSize == 0x0 && node.Option.Tp == 0 && node.Option.Comment == "" { 133 | node.Option = nil 134 | } 135 | } 136 | case *FuncCallExpr: 137 | node.FnName.O = strings.ToLower(node.FnName.O) 138 | switch node.FnName.L { 139 | case "convert": 140 | node.Args[1].(*test_driver.ValueExpr).Datum.SetBytes(nil) 141 | } 142 | case *AggregateFuncExpr: 143 | node.F = strings.ToLower(node.F) 144 | case *FieldList: 145 | for _, f := range node.Fields { 146 | f.Offset = 0 147 | } 148 | case *AlterTableSpec: 149 | for _, opt := range node.Options { 150 | opt.StrValue = strings.ToLower(opt.StrValue) 151 | } 152 | case *Join: 153 | node.ExplicitParens = false 154 | } 155 | return in, false 156 | } 157 | 158 | // Leave implements Visitor interface. 159 | func (checker *nodeTextCleaner) Leave(in Node) (out Node, ok bool) { 160 | return in, true 161 | } 162 | 163 | type NodeRestoreTestCase struct { 164 | sourceSQL string 165 | expectSQL string 166 | } 167 | 168 | func RunNodeRestoreTest(c *C, nodeTestCases []NodeRestoreTestCase, template string, extractNodeFunc func(node Node) Node) { 169 | RunNodeRestoreTestWithFlags(c, nodeTestCases, template, extractNodeFunc, DefaultRestoreFlags) 170 | } 171 | 172 | func RunNodeRestoreTestWithFlags(c *C, nodeTestCases []NodeRestoreTestCase, template string, extractNodeFunc func(node Node) Node, flags RestoreFlags) { 173 | parser := parser.New() 174 | parser.EnableWindowFunc(true) 175 | for _, testCase := range nodeTestCases { 176 | sourceSQL := fmt.Sprintf(template, testCase.sourceSQL) 177 | expectSQL := fmt.Sprintf(template, testCase.expectSQL) 178 | stmt, err := parser.ParseOneStmt(sourceSQL, "", "") 179 | comment := Commentf("source %#v", testCase) 180 | c.Assert(err, IsNil, comment) 181 | var sb 
strings.Builder 182 | err = extractNodeFunc(stmt).Restore(NewRestoreCtx(flags, &sb)) 183 | c.Assert(err, IsNil, comment) 184 | restoreSql := fmt.Sprintf(template, sb.String()) 185 | comment = Commentf("source %#v; restore %v", testCase, restoreSql) 186 | c.Assert(restoreSql, Equals, expectSQL, comment) 187 | stmt2, err := parser.ParseOneStmt(restoreSql, "", "") 188 | c.Assert(err, IsNil, comment) 189 | CleanNodeText(stmt) 190 | CleanNodeText(stmt2) 191 | c.Assert(stmt2, DeepEquals, stmt, comment) 192 | } 193 | } 194 | 195 | // RunNodeRestoreTestWithFlagsStmtChange likes RunNodeRestoreTestWithFlags but not check if the ASTs are same. 196 | // Sometimes the AST are different and it's expected. 197 | func RunNodeRestoreTestWithFlagsStmtChange(c *C, nodeTestCases []NodeRestoreTestCase, template string, extractNodeFunc func(node Node) Node) { 198 | par := parser.New() 199 | par.EnableWindowFunc(true) 200 | for _, testCase := range nodeTestCases { 201 | sourceSQL := fmt.Sprintf(template, testCase.sourceSQL) 202 | expectSQL := fmt.Sprintf(template, testCase.expectSQL) 203 | stmt, err := par.ParseOneStmt(sourceSQL, "", "") 204 | comment := Commentf("source %#v", testCase) 205 | c.Assert(err, IsNil, comment) 206 | var sb strings.Builder 207 | err = extractNodeFunc(stmt).Restore(NewRestoreCtx(DefaultRestoreFlags, &sb)) 208 | c.Assert(err, IsNil, comment) 209 | restoreSql := fmt.Sprintf(template, sb.String()) 210 | comment = Commentf("source %#v; restore %v", testCase, restoreSql) 211 | c.Assert(restoreSql, Equals, expectSQL, comment) 212 | } 213 | } 214 | -------------------------------------------------------------------------------- /auth/auth.go: -------------------------------------------------------------------------------- 1 | // Copyright 2015 PingCAP, Inc. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 
5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // See the License for the specific language governing permissions and 12 | // limitations under the License. 13 | 14 | package auth 15 | 16 | import ( 17 | "fmt" 18 | 19 | "github.com/pingcap/parser/format" 20 | ) 21 | 22 | // UserIdentity represents username and hostname. 23 | type UserIdentity struct { 24 | Username string 25 | Hostname string 26 | CurrentUser bool 27 | AuthUsername string // Username matched in privileges system 28 | AuthHostname string // Match in privs system (i.e. could be a wildcard) 29 | } 30 | 31 | // Restore implements Node interface. 32 | func (user *UserIdentity) Restore(ctx *format.RestoreCtx) error { 33 | if user.CurrentUser { 34 | ctx.WriteKeyWord("CURRENT_USER") 35 | } else { 36 | ctx.WriteName(user.Username) 37 | ctx.WritePlain("@") 38 | ctx.WriteName(user.Hostname) 39 | } 40 | return nil 41 | } 42 | 43 | // String converts UserIdentity to the format user@host. 44 | func (user *UserIdentity) String() string { 45 | // TODO: Escape username and hostname. 46 | if user == nil { 47 | return "" 48 | } 49 | return fmt.Sprintf("%s@%s", user.Username, user.Hostname) 50 | } 51 | 52 | // AuthIdentityString returns matched identity in user@host format 53 | func (user *UserIdentity) AuthIdentityString() string { 54 | // TODO: Escape username and hostname. 
55 | return fmt.Sprintf("%s@%s", user.AuthUsername, user.AuthHostname) 56 | } 57 | 58 | type RoleIdentity struct { 59 | Username string 60 | Hostname string 61 | } 62 | 63 | func (role *RoleIdentity) Restore(ctx *format.RestoreCtx) error { 64 | ctx.WriteName(role.Username) 65 | if role.Hostname != "" { 66 | ctx.WritePlain("@") 67 | ctx.WriteName(role.Hostname) 68 | } 69 | return nil 70 | } 71 | 72 | // String converts UserIdentity to the format user@host. 73 | func (role *RoleIdentity) String() string { 74 | // TODO: Escape username and hostname. 75 | return fmt.Sprintf("`%s`@`%s`", role.Username, role.Hostname) 76 | } 77 | -------------------------------------------------------------------------------- /auth/auth_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2015 PingCAP, Inc. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // See the License for the specific language governing permissions and 12 | // limitations under the License. 13 | 14 | package auth 15 | 16 | import ( 17 | "testing" 18 | 19 | . "github.com/pingcap/check" 20 | ) 21 | 22 | var _ = Suite(&testAuthSuite{}) 23 | 24 | type testAuthSuite struct { 25 | } 26 | 27 | func TestT(t *testing.T) { 28 | TestingT(t) 29 | } 30 | -------------------------------------------------------------------------------- /auth/caching_sha2.go: -------------------------------------------------------------------------------- 1 | // Copyright 2021 PingCAP, Inc. 
2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // See the License for the specific language governing permissions and 12 | // limitations under the License. 13 | 14 | package auth 15 | 16 | // Resources: 17 | // - https://dev.mysql.com/doc/refman/8.0/en/caching-sha2-pluggable-authentication.html 18 | // - https://dev.mysql.com/doc/dev/mysql-server/latest/page_caching_sha2_authentication_exchanges.html 19 | // - https://dev.mysql.com/doc/dev/mysql-server/latest/namespacesha2__password.html 20 | // - https://www.akkadia.org/drepper/SHA-crypt.txt 21 | // - https://dev.mysql.com/worklog/task/?id=9591 22 | // 23 | // CREATE USER 'foo'@'%' IDENTIFIED BY 'foobar'; 24 | // SELECT HEX(authentication_string) FROM mysql.user WHERE user='foo'; 25 | // 24412430303524031A69251C34295C4B35167C7F1E5A7B63091349503974624D34504B5A424679354856336868686F52485A736E4A733368786E427575516C73446469496537 26 | // 27 | // Format: 28 | // Split on '$': 29 | // - digest type ("A") 30 | // - iterations (divided by ITERATION_MULTIPLIER) 31 | // - salt+hash 32 | // 33 | 34 | import ( 35 | "bytes" 36 | "crypto/rand" 37 | "crypto/sha256" 38 | "errors" 39 | "fmt" 40 | "strconv" 41 | ) 42 | 43 | const ( 44 | MIXCHARS = 32 45 | SALT_LENGTH = 20 46 | ITERATION_MULTIPLIER = 1000 47 | ) 48 | 49 | func b64From24bit(b []byte, n int) []byte { 50 | b64t := []byte("./0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz") 51 | 52 | w := (int64(b[0]) << 16) | (int64(b[1]) << 8) | int64(b[2]) 53 | ret := make([]byte, 0, n) 54 | for n > 0 { 55 | n-- 56 | ret = append(ret, b64t[w&0x3f]) 57 | w >>= 6 58 | } 59 | 60 | return ret 61 | } 62 | 63 | 
func sha256crypt(plaintext string, salt []byte, iterations int) string { 64 | // Numbers in the comments refer to the description of the algorithm on https://www.akkadia.org/drepper/SHA-crypt.txt 65 | 66 | // 1, 2, 3 67 | tmpA := sha256.New() 68 | tmpA.Write([]byte(plaintext)) 69 | tmpA.Write(salt) 70 | 71 | // 4, 5, 6, 7, 8 72 | tmpB := sha256.New() 73 | tmpB.Write([]byte(plaintext)) 74 | tmpB.Write(salt) 75 | tmpB.Write([]byte(plaintext)) 76 | sumB := tmpB.Sum(nil) 77 | 78 | // 9, 10 79 | var i int 80 | for i = len(plaintext); i > MIXCHARS; i -= MIXCHARS { 81 | tmpA.Write(sumB[:MIXCHARS]) 82 | } 83 | tmpA.Write(sumB[:i]) 84 | 85 | // 11 86 | for i = len(plaintext); i > 0; i >>= 1 { 87 | if i%2 == 0 { 88 | tmpA.Write([]byte(plaintext)) 89 | } else { 90 | tmpA.Write(sumB) 91 | } 92 | } 93 | 94 | // 12 95 | sumA := tmpA.Sum(nil) 96 | 97 | // 13, 14, 15 98 | tmpDP := sha256.New() 99 | for range []byte(plaintext) { 100 | tmpDP.Write([]byte(plaintext)) 101 | } 102 | sumDP := tmpDP.Sum(nil) 103 | 104 | // 16 105 | p := make([]byte, 0, sha256.Size) 106 | for i = len(plaintext); i > 0; i -= MIXCHARS { 107 | if i > MIXCHARS { 108 | p = append(p, sumDP...) 109 | } else { 110 | p = append(p, sumDP[0:i]...) 111 | } 112 | } 113 | 114 | // 17, 18, 19 115 | tmpDS := sha256.New() 116 | for i = 0; i < 16+int(sumA[0]); i++ { 117 | tmpDS.Write(salt) 118 | } 119 | sumDS := tmpDS.Sum(nil) 120 | 121 | // 20 122 | s := []byte{} 123 | for i = len(salt); i > 0; i -= MIXCHARS { 124 | if i > MIXCHARS { 125 | s = append(s, sumDS...) 126 | } else { 127 | s = append(s, sumDS[0:i]...) 
128 | } 129 | } 130 | 131 | // 21 132 | tmpC := sha256.New() 133 | var sumC []byte 134 | for i = 0; i < iterations; i++ { 135 | tmpC.Reset() 136 | 137 | if i&1 != 0 { 138 | tmpC.Write(p) 139 | } else { 140 | tmpC.Write(sumA) 141 | } 142 | if i%3 != 0 { 143 | tmpC.Write(s) 144 | } 145 | if i%7 != 0 { 146 | tmpC.Write(p) 147 | } 148 | if i&1 != 0 { 149 | tmpC.Write(sumA) 150 | } else { 151 | tmpC.Write(p) 152 | } 153 | sumC = tmpC.Sum(nil) 154 | copy(sumA, tmpC.Sum(nil)) 155 | } 156 | 157 | // 22 158 | buf := bytes.Buffer{} 159 | buf.Grow(100) // FIXME 160 | buf.Write([]byte{'$', 'A', '$'}) 161 | rounds := fmt.Sprintf("%03d", iterations/ITERATION_MULTIPLIER) 162 | buf.Write([]byte(rounds)) 163 | buf.Write([]byte{'$'}) 164 | buf.Write(salt) 165 | 166 | buf.Write(b64From24bit([]byte{sumC[0], sumC[10], sumC[20]}, 4)) 167 | buf.Write(b64From24bit([]byte{sumC[21], sumC[1], sumC[11]}, 4)) 168 | buf.Write(b64From24bit([]byte{sumC[12], sumC[22], sumC[2]}, 4)) 169 | buf.Write(b64From24bit([]byte{sumC[3], sumC[13], sumC[23]}, 4)) 170 | buf.Write(b64From24bit([]byte{sumC[24], sumC[4], sumC[14]}, 4)) 171 | buf.Write(b64From24bit([]byte{sumC[15], sumC[25], sumC[5]}, 4)) 172 | buf.Write(b64From24bit([]byte{sumC[6], sumC[16], sumC[26]}, 4)) 173 | buf.Write(b64From24bit([]byte{sumC[27], sumC[7], sumC[17]}, 4)) 174 | buf.Write(b64From24bit([]byte{sumC[18], sumC[28], sumC[8]}, 4)) 175 | buf.Write(b64From24bit([]byte{sumC[9], sumC[19], sumC[29]}, 4)) 176 | buf.Write(b64From24bit([]byte{0, sumC[31], sumC[30]}, 3)) 177 | 178 | return buf.String() 179 | } 180 | 181 | // Checks if a MySQL style caching_sha2 authentication string matches a password 182 | func CheckShaPassword(pwhash []byte, password string) (bool, error) { 183 | pwhash_parts := bytes.Split(pwhash, []byte("$")) 184 | if len(pwhash_parts) != 4 { 185 | return false, errors.New("failed to decode hash parts") 186 | } 187 | 188 | hash_type := string(pwhash_parts[1]) 189 | if hash_type != "A" { 190 | return false, 
errors.New("digest type is incompatible") 191 | } 192 | 193 | iterations, err := strconv.Atoi(string(pwhash_parts[2])) 194 | if err != nil { 195 | return false, errors.New("failed to decode iterations") 196 | } 197 | iterations = iterations * ITERATION_MULTIPLIER 198 | salt := pwhash_parts[3][:SALT_LENGTH] 199 | 200 | newHash := sha256crypt(password, salt, iterations) 201 | 202 | return bytes.Equal(pwhash, []byte(newHash)), nil 203 | } 204 | 205 | func NewSha2Password(pwd string) string { 206 | salt := make([]byte, SALT_LENGTH) 207 | rand.Read(salt) 208 | 209 | // Restrict to 7-bit to avoid multi-byte UTF-8 210 | for i := range salt { 211 | salt[i] = salt[i] &^ 128 212 | for salt[i] == 36 || salt[i] == 0 { // '$' or NUL 213 | newval := make([]byte, 1) 214 | rand.Read(newval) 215 | salt[i] = newval[0] &^ 128 216 | } 217 | } 218 | 219 | return sha256crypt(pwd, salt, 5*ITERATION_MULTIPLIER) 220 | } 221 | -------------------------------------------------------------------------------- /auth/caching_sha2_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2021 PingCAP, Inc. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // See the License for the specific language governing permissions and 12 | // limitations under the License. 13 | 14 | package auth 15 | 16 | import ( 17 | "encoding/hex" 18 | 19 | . 
"github.com/pingcap/check" 20 | ) 21 | 22 | func (s *testAuthSuite) TestCheckShaPasswordGood(c *C) { 23 | pwd := "foobar" 24 | pwhash, _ := hex.DecodeString("24412430303524031A69251C34295C4B35167C7F1E5A7B63091349503974624D34504B5A424679354856336868686F52485A736E4A733368786E427575516C73446469496537") 25 | r, err := CheckShaPassword(pwhash, pwd) 26 | c.Assert(err, IsNil) 27 | c.Assert(r, IsTrue) 28 | } 29 | 30 | func (s *testAuthSuite) TestCheckShaPasswordBad(c *C) { 31 | pwd := "not_foobar" 32 | pwhash, _ := hex.DecodeString("24412430303524031A69251C34295C4B35167C7F1E5A7B63091349503974624D34504B5A424679354856336868686F52485A736E4A733368786E427575516C73446469496537") 33 | r, err := CheckShaPassword(pwhash, pwd) 34 | c.Assert(err, IsNil) 35 | c.Assert(r, IsFalse) 36 | } 37 | 38 | func (s *testAuthSuite) TestCheckShaPasswordShort(c *C) { 39 | pwd := "not_foobar" 40 | pwhash, _ := hex.DecodeString("aaaaaaaa") 41 | _, err := CheckShaPassword(pwhash, pwd) 42 | c.Assert(err, NotNil) 43 | } 44 | 45 | func (s *testAuthSuite) TestCheckShaPasswordDigetTypeIncompatible(c *C) { 46 | pwd := "not_foobar" 47 | pwhash, _ := hex.DecodeString("24422430303524031A69251C34295C4B35167C7F1E5A7B63091349503974624D34504B5A424679354856336868686F52485A736E4A733368786E427575516C73446469496537") 48 | _, err := CheckShaPassword(pwhash, pwd) 49 | c.Assert(err, NotNil) 50 | } 51 | 52 | func (s *testAuthSuite) TestCheckShaPasswordIterationsInvalid(c *C) { 53 | pwd := "not_foobar" 54 | pwhash, _ := hex.DecodeString("24412430304124031A69251C34295C4B35167C7F1E5A7B63091349503974624D34504B5A424679354856336868686F52485A736E4A733368786E427575516C73446469496537") 55 | _, err := CheckShaPassword(pwhash, pwd) 56 | c.Assert(err, NotNil) 57 | } 58 | 59 | // The output from NewSha2Password is not stable as the hash is based on the genrated salt. 60 | // This is why CheckShaPassword is used here. 
61 | func (s *testAuthSuite) TestNewSha2Password(c *C) { 62 | pwd := "testpwd" 63 | pwhash := NewSha2Password(pwd) 64 | r, err := CheckShaPassword([]byte(pwhash), pwd) 65 | c.Assert(err, IsNil) 66 | c.Assert(r, IsTrue) 67 | 68 | for i := 0; i < len(pwhash); i++ { 69 | c.Assert(pwhash[i], Less, uint8(128)) 70 | c.Assert(pwhash[i], Not(Equals), uint8(0)) // NUL; typed so Equals compares a byte with a byte (an untyped 0 boxes as int and never matches) 71 | c.Assert(pwhash[i], Not(Equals), uint8('$')) 72 | } 73 | } 74 | -------------------------------------------------------------------------------- /auth/mysql_native_password.go: -------------------------------------------------------------------------------- 1 | // Copyright 2015 PingCAP, Inc. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // See the License for the specific language governing permissions and 12 | // limitations under the License. 13 | 14 | package auth 15 | 16 | import ( 17 | "crypto/sha1" 18 | "crypto/subtle" 19 | "encoding/hex" 20 | "fmt" 21 | 22 | "github.com/pingcap/errors" 23 | "github.com/pingcap/parser/terror" 24 | ) 25 | 26 | // CheckScrambledPassword checks the scrambled password received from the client.
27 | // The new authentication is performed in the following manner: 28 | // SERVER: public_seed=create_random_string() 29 | // send(public_seed) 30 | // CLIENT: recv(public_seed) 31 | // hash_stage1=sha1("password") 32 | // hash_stage2=sha1(hash_stage1) 33 | // reply=xor(hash_stage1, sha1(public_seed,hash_stage2)) 34 | // // these three steps are done in scramble() 35 | // send(reply) 36 | // SERVER: recv(reply) 37 | // hash_stage1=xor(reply, sha1(public_seed,hash_stage2)) 38 | // candidate_hash2=sha1(hash_stage1) 39 | // check(candidate_hash2==hash_stage2) 40 | // // these three steps are done in check_scramble() 41 | func CheckScrambledPassword(salt, hpwd, auth []byte) bool { 42 | crypt := sha1.New() 43 | _, err := crypt.Write(salt) 44 | terror.Log(errors.Trace(err)) 45 | _, err = crypt.Write(hpwd) 46 | terror.Log(errors.Trace(err)) 47 | hash := crypt.Sum(nil) 48 | // token = scrambleHash XOR stage1Hash 49 | if len(auth) != len(hash) { 50 | return false 51 | } 52 | for i := range hash { 53 | hash[i] ^= auth[i] 54 | } 55 | // Compare in constant time so the response timing does not reveal how many bytes of the stored hash matched. 56 | return subtle.ConstantTimeCompare(hpwd, Sha1Hash(hash)) == 1 57 | } 58 | 59 | // Sha1Hash is a utility function to calculate the sha1 hash. 60 | func Sha1Hash(bs []byte) []byte { 61 | crypt := sha1.New() 62 | _, err := crypt.Write(bs) 63 | terror.Log(errors.Trace(err)) 64 | return crypt.Sum(nil) 65 | } 66 | 67 | // EncodePassword converts plaintext password to hashed hex string. 68 | func EncodePassword(pwd string) string { 69 | if len(pwd) == 0 { 70 | return "" 71 | } 72 | hash1 := Sha1Hash([]byte(pwd)) 73 | hash2 := Sha1Hash(hash1) 74 | 75 | return fmt.Sprintf("*%X", hash2) 76 | } 77 | 78 | // DecodePassword drops the leading '*' (as written by EncodePassword) from a hex string password and decodes the rest to a byte array.
79 | func DecodePassword(pwd string) ([]byte, error) { 80 | if len(pwd) == 0 { // guard: pwd[1:] below would panic on an empty string 81 | return nil, errors.New("invalid encoded password: empty string") 82 | } 83 | x, err := hex.DecodeString(pwd[1:]) 84 | if err != nil { 85 | return nil, errors.Trace(err) 86 | } 87 | return x, nil 88 | } 89 | -------------------------------------------------------------------------------- /auth/mysql_native_password_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2015 PingCAP, Inc. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // See the License for the specific language governing permissions and 12 | // limitations under the License. 13 | 14 | package auth 15 | 16 | import ( 17 | . "github.com/pingcap/check" 18 | ) 19 | 20 | func (s *testAuthSuite) TestEncodePassword(c *C) { 21 | pwd := "123" 22 | c.Assert(EncodePassword(pwd), Equals, "*23AE809DDACAF96AF0FD78ED04B6A265E05AA257") 23 | } 24 | 25 | func (s *testAuthSuite) TestDecodePassword(c *C) { 26 | x, err := DecodePassword(EncodePassword("123")) 27 | c.Assert(err, IsNil) 28 | c.Assert(x, DeepEquals, Sha1Hash(Sha1Hash([]byte("123")))) 29 | _, err = DecodePassword("") 30 | c.Assert(err, NotNil) 31 | } 32 | 33 | func (s *testAuthSuite) TestCheckScramble(c *C) { 34 | pwd := "abc" 35 | salt := []byte{85, 92, 45, 22, 58, 79, 107, 6, 122, 125, 58, 80, 12, 90, 103, 32, 90, 10, 74, 82} 36 | auth := []byte{24, 180, 183, 225, 166, 6, 81, 102, 70, 248, 199, 143, 91, 204, 169, 9, 161, 171, 203, 33} 37 | encodepwd := EncodePassword(pwd) 38 | hpwd, err := DecodePassword(encodepwd) 39 | c.Assert(err, IsNil) 40 | 41 | res := CheckScrambledPassword(salt, hpwd, auth) 42 | c.Assert(res, IsTrue) 43 | 44 | // Do not panic for invalid input.
45 | res = CheckScrambledPassword(salt, hpwd, []byte("xxyyzz")) 46 | c.Assert(res, IsFalse) 47 | } 48 | -------------------------------------------------------------------------------- /bench_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2017 PingCAP, Inc. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // See the License for the specific language governing permissions and 12 | // limitations under the License. 13 | 14 | package parser 15 | 16 | import ( 17 | "testing" 18 | ) 19 | 20 | func BenchmarkSysbenchSelect(b *testing.B) { 21 | parser := New() 22 | sql := "SELECT pad FROM sbtest1 WHERE id=1;" 23 | b.ResetTimer() 24 | for i := 0; i < b.N; i++ { 25 | _, _, err := parser.Parse(sql, "", "") 26 | if err != nil { 27 | b.Fatal(err) 28 | } 29 | } 30 | b.ReportAllocs() 31 | } 32 | 33 | func BenchmarkParseComplex(b *testing.B) { 34 | var table = []string{ 35 | `SELECT DISTINCT ca.l9_convergence_code AS atb2, cu.cust_sub_type AS account_type, cst.description AS account_type_desc, ss.prim_resource_val AS msisdn, ca.ban AS ban_key, To_char(mo.memo_date, 'YYYYMMDD') AS memo_date, cu.l9_identification AS thai_id, ss.subscriber_no AS subs_key, ss.dealer_code AS shop_code, cd.description AS shop_name, mot.short_desc, Regexp_substr(mo.attr1value, '[^ ;]+', 1, 3) staff_id, mo.operator_id AS user_id, mo.memo_system_text, co2.soc_name AS first_socname, co3.soc_name AS previous_socname, co.soc_name AS current_socname, Regexp_substr(mo.attr1value, '[^ ; ]+', 1, 1) NAME, co.soc_description AS current_pp_desc, co3.soc_description AS prev_pp_desc, co.soc_cd AS
soc_cd, ( SELECT Sum(br.amount) FROM bl1_rc_rates BR, customer CU, subscriber SS WHERE br.service_receiver_id = ss.subscriber_no AND br.receiver_customer = ss.customer_id AND br.effective_date <= br.expiration_date AND (( ss. sub_status <> 'C' AND ss. sub_status <> 'T' AND br.expiration_date IS NULL) OR ( ss. sub_status = 'C' AND br.expiration_date LIKE ss.effective_date)) AND br.pp_ind = 'Y' AND br.cycle_code = cu.bill_cycle) AS pp_rate, cu.bill_cycle AS cycle_code, To_char(Nvl(ss.l9_tmv_act_date, ss.init_act_date),'YYYYMMDD') AS activated_date, To_char(cd.effective_date, 'YYYYMMDD') AS shop_effective_date, cd.expiration_date AS shop_expired_date, ca.l9_company_code AS company_code FROM service_details S, product CO, csm_pay_channel CPC, account CA, subscriber SS, customer CU, customer_sub_type CST, csm_dealer CD, service_details S2, product CO2, service_details S3, product CO3, memo MO , memo_type MOT, logical_date LO, charge_details CHD WHERE ss.subscriber_no = chd.agreement_no AND cpc.pym_channel_no = chd.target_pcn AND chd.chg_split_type = 'DR' AND chd.expiration_date IS NULL AND s.soc = co.soc_cd AND co.soc_type = 'P' AND s.agreement_no = ss.subscriber_no AND ss.prim_resource_tp = 'C' AND cpc.payment_category = 'POST' AND ca.ban = cpc.ban AND ( ca.l9_company_code = 'RF' OR ca.l9_company_code = 'RM' OR ca.l9_company_code = 'TM') AND ss.customer_id = cu.customer_id AND cu.cust_sub_type = cst.cust_sub_type AND cu.customer_type = cst.customer_type AND ss.dealer_code = cd.dealer AND s2.effective_date= ( SELECT Max(sa1.effective_date) FROM service_details SA1, product o1 WHERE sa1.agreement_no = ss.subscriber_no AND co.soc_cd = sa1.soc AND co.soc_type = 'P' ) AND s2.agreement_no = s.agreement_no AND s2.soc = co2.soc_cd AND co2.soc_type = 'P' AND s2.effective_date = ( SELECT Min(sa1.effective_date) FROM service_details SA1, product o1 WHERE sa1.agreement_no = ss.subscriber_no AND co2.soc_cd = sa1.soc AND co.soc_type = 'P' ) AND s3.agreement_no = s.agreement_no AND
s3.soc = co3.soc_cd AND co3.soc_type = 'P' AND s3.effective_date = ( SELECT Max(sa1.effective_date) FROM service_details SA1, a product o1 WHERE sa1.agreement_no = ss.subscriber_no AND sa1.effective_date < ( SELECT Max(sa1.effective_date) FROM service_details SA1, product o1 WHERE sa1.agreement_no = ss.subscriber_no AND co3.soc_cd = sa1.soc AND co3.soc_type = 'P' ) AND co3.soc_cd = sa1.soc AND o1.soc_type = 'P' ) AND mo.entity_id = ss.subscriber_no AND mo.entity_type_id = 6 AND mo.memo_type_id = mot.memo_type_id AND Trunc(mo.sys_creation_date) = ( SELECT Trunc(lo.logical_date - 1) FROM lo) trunc(lo.logical_date - 1) AND lo.expiration_date IS NULL AND lo.logical_date_type = 'B' AND lo.expiration_date IS NULL AND ( mot.short_desc = 'BCN' OR mot.short_desc = 'BCNM' )`} 36 | parser := New() 37 | b.ResetTimer() 38 | for i := 0; i < b.N; i++ { 39 | for _, v := range table { 40 | _, _, err := parser.Parse(v, "", "") 41 | if err != nil { 42 | b.Failed() // NOTE(review): Failed only reports status, so parse errors are ignored here; the query contains fragments such as "a product o1" that look invalid, so switching to b.Fatal may break this benchmark — confirm intent. 43 | } 44 | } 45 | } 46 | b.ReportAllocs() 47 | } 48 | 49 | func BenchmarkParseSimple(b *testing.B) { 50 | var table = []string{ 51 | "insert into t values (1), (2), (3)", 52 | "insert into t values (4), (5), (6), (7)", 53 | "select c from t where c > 2", 54 | } 55 | parser := New() 56 | b.ResetTimer() 57 | for i := 0; i < b.N; i++ { 58 | for _, v := range table { 59 | _, _, err := parser.Parse(v, "", "") 60 | if err != nil { 61 | b.Fatal(err) 62 | } 63 | } 64 | } 65 | b.ReportAllocs() 66 | } 67 | -------------------------------------------------------------------------------- /charset/charset_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2015 PingCAP, Inc. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License.
5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // See the License for the specific language governing permissions and 12 | // limitations under the License. 13 | 14 | package charset 15 | 16 | import ( 17 | "math/rand" 18 | "testing" 19 | 20 | . "github.com/pingcap/check" 21 | ) 22 | 23 | func TestT(t *testing.T) { 24 | CustomVerboseFlag = true 25 | TestingT(t) 26 | } 27 | 28 | var _ = Suite(&testCharsetSuite{}) 29 | 30 | type testCharsetSuite struct { 31 | } 32 | 33 | func testValidCharset(c *C, charset string, collation string, expect bool) { 34 | b := ValidCharsetAndCollation(charset, collation) 35 | c.Assert(b, Equals, expect) 36 | } 37 | 38 | func (s *testCharsetSuite) TestValidCharset(c *C) { 39 | tests := []struct { 40 | cs string 41 | co string 42 | succ bool 43 | }{ 44 | {"utf8", "utf8_general_ci", true}, 45 | {"", "utf8_general_ci", true}, 46 | {"utf8mb4", "utf8mb4_bin", true}, 47 | {"latin1", "latin1_bin", true}, 48 | {"utf8", "utf8_invalid_ci", false}, 49 | {"utf16", "utf16_bin", false}, 50 | {"gb2312", "gb2312_chinese_ci", false}, 51 | {"UTF8", "UTF8_BIN", true}, 52 | {"UTF8", "utf8_bin", true}, 53 | {"UTF8MB4", "utf8mb4_bin", true}, 54 | {"UTF8MB4", "UTF8MB4_bin", true}, 55 | {"UTF8MB4", "UTF8MB4_general_ci", true}, 56 | {"Utf8", "uTf8_bIN", true}, 57 | } 58 | for _, tt := range tests { 59 | testValidCharset(c, tt.cs, tt.co, tt.succ) 60 | } 61 | } 62 | 63 | func (s *testCharsetSuite) TestValidCustomCharset(c *C) { 64 | AddCharset(&Charset{"custom", "custom_collation", make(map[string]*Collation), "Custom", 4}) 65 | AddCollation(&Collation{99999, "custom", "custom_collation", true}) 66 | 67 | tests := []struct { 68 | cs string 69 | co string 70 | succ bool 71 | }{ 72 | {"custom", "custom_collation", true}, 73 | {"utf8", "utf8_invalid_ci",
false}, 74 | } 75 | for _, tt := range tests { 76 | testValidCharset(c, tt.cs, tt.co, tt.succ) 77 | } 78 | } 79 | 80 | func testGetDefaultCollation(c *C, charset string, expectCollation string, succ bool) { 81 | b, err := GetDefaultCollation(charset) 82 | if !succ { 83 | c.Assert(err, NotNil) 84 | return 85 | } 86 | c.Assert(b, Equals, expectCollation) 87 | } 88 | 89 | func (s *testCharsetSuite) TestGetDefaultCollation(c *C) { 90 | tests := []struct { 91 | cs string 92 | co string 93 | succ bool 94 | }{ 95 | {"utf8", "utf8_bin", true}, 96 | {"UTF8", "utf8_bin", true}, 97 | {"utf8mb4", "utf8mb4_bin", true}, 98 | {"ascii", "ascii_bin", true}, 99 | {"binary", "binary", true}, 100 | {"latin1", "latin1_bin", true}, 101 | {"invalid_cs", "", false}, 102 | {"", "utf8_bin", false}, 103 | } 104 | for _, tt := range tests { 105 | testGetDefaultCollation(c, tt.cs, tt.co, tt.succ) 106 | } 107 | 108 | // Test the consistency of collations table and charset desc table 109 | charsetNum := 0 110 | for _, collate := range collations { 111 | if collate.IsDefault { 112 | if desc, ok := charsetInfos[collate.CharsetName]; ok { 113 | c.Assert(collate.Name, Equals, desc.DefaultCollation) 114 | charsetNum++ 115 | } 116 | } 117 | } 118 | c.Assert(charsetNum, Equals, len(charsetInfos)) 119 | } 120 | 121 | func (s *testCharsetSuite) TestSupportedCollations(c *C) { 122 | // Every name in supportedCollationNames must resolve to a defined collation (previously this compared the map with itself). 123 | c.Assert(len(GetSupportedCollations()), Equals, len(supportedCollationNames)) 124 | 125 | // The default collations of supported charsets must be a subset of the supported collations. 126 | errMsg := "Charset [%v] is supported but its default collation [%v] is not."
127 | for _, desc := range GetSupportedCharsets() { 128 | found := false 129 | for _, coll := range GetSupportedCollations() { 130 | if desc.DefaultCollation == coll.Name { 131 | found = true 132 | break 133 | } 134 | } 135 | c.Assert(found, IsTrue, Commentf(errMsg, desc.Name, desc.DefaultCollation)) 136 | } 137 | } 138 | 139 | func (s *testCharsetSuite) TestGetCharsetDesc(c *C) { 140 | tests := []struct { 141 | cs string 142 | result string 143 | succ bool 144 | }{ 145 | {"utf8", "utf8", true}, 146 | {"UTF8", "utf8", true}, 147 | {"utf8mb4", "utf8mb4", true}, 148 | {"ascii", "ascii", true}, 149 | {"binary", "binary", true}, 150 | {"latin1", "latin1", true}, 151 | {"invalid_cs", "", false}, 152 | {"", "utf8_bin", false}, 153 | } 154 | for _, tt := range tests { 155 | desc, err := GetCharsetInfo(tt.cs) 156 | if !tt.succ { 157 | c.Assert(err, NotNil) 158 | } else { 159 | c.Assert(desc.Name, Equals, tt.result) 160 | } 161 | } 162 | } 163 | 164 | func (s *testCharsetSuite) TestGetCollationByName(c *C) { 165 | 166 | for _, collation := range collations { 167 | coll, err := GetCollationByName(collation.Name) 168 | c.Assert(err, IsNil) 169 | c.Assert(coll, Equals, collation) 170 | } 171 | 172 | _, err := GetCollationByName("non_exist") 173 | c.Assert(err, ErrorMatches, "\\[ddl:1273\\]Unknown collation: 'non_exist'") 174 | } 175 | 176 | func BenchmarkGetCharsetDesc(b *testing.B) { 177 | charsets := []string{CharsetUTF8, CharsetUTF8MB4, CharsetASCII, CharsetLatin1, CharsetBin} 178 | index := rand.Intn(len(charsets)) 179 | cs := charsets[index] 180 | 181 | b.ResetTimer() // start timing only after the setup above 182 | for i := 0; i < b.N; i++ { 183 | GetCharsetInfo(cs) 184 | } 185 | } 186 | -------------------------------------------------------------------------------- /charset/encoding.go: -------------------------------------------------------------------------------- 1 | // Copyright 2021 PingCAP, Inc.
2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // See the License for the specific language governing permissions and 12 | // limitations under the License. 13 | 14 | package charset 15 | 16 | import ( 17 | "bytes" 18 | "fmt" 19 | "strings" 20 | 21 | "github.com/cznic/mathutil" 22 | "github.com/pingcap/parser/mysql" 23 | "github.com/pingcap/parser/terror" 24 | "golang.org/x/text/encoding" 25 | "golang.org/x/text/transform" 26 | ) 27 | 28 | const encodingLegacy = "utf-8" // utf-8 encoding is compatible with old default behavior. 29 | 30 | var errInvalidCharacterString = terror.ClassParser.NewStd(mysql.ErrInvalidCharacterString) 31 | // EncodingLabel is an encoding label; Format produces a trimmed, lower-cased one. 32 | type EncodingLabel string 33 | 34 | // Format trims surrounding whitespace from the label and converts it to lowercase. 35 | func Format(label string) EncodingLabel { 36 | return EncodingLabel(strings.ToLower(strings.Trim(label, "\t\n\r\f "))) 37 | } 38 | 39 | // Formatted is used when the label is already trimmed and it is lowercase. 40 | func Formatted(label string) EncodingLabel { 41 | return EncodingLabel(label) 42 | } 43 | 44 | // Encoding provides an interface to encode/decode a string with a specific encoding. 45 | type Encoding struct { 46 | enc encoding.Encoding 47 | name string 48 | charLength func([]byte) int 49 | } 50 | 51 | // Enabled indicates whether the non-utf8 encoding is used. 52 | func (e *Encoding) Enabled() bool { 53 | return e.enc != nil && e.charLength != nil 54 | } 55 | 56 | // Name returns the name of the current encoding. 57 | func (e *Encoding) Name() string { 58 | return e.name 59 | } 60 | 61 | // NewEncoding creates a new Encoding.
62 | func NewEncoding(label string) *Encoding { 63 | if len(label) == 0 { 64 | return &Encoding{} 65 | } 66 | e, name := Lookup(label) 67 | if e != nil && name != encodingLegacy { 68 | return &Encoding{ 69 | enc: e, 70 | name: name, 71 | charLength: FindNextCharacterLength(name), 72 | } 73 | } 74 | return &Encoding{name: name} 75 | } 76 | 77 | // UpdateEncoding updates to a new Encoding. 78 | func (e *Encoding) UpdateEncoding(label EncodingLabel) { 79 | enc, name := lookup(label) 80 | e.name = name 81 | if enc != nil && name != encodingLegacy { 82 | e.enc = enc 83 | e.charLength = FindNextCharacterLength(name) 84 | } else { 85 | e.enc = nil 86 | e.charLength = nil 87 | } 88 | } 89 | 90 | // Encode converts bytes from utf-8 charset to a specific charset. Only call it when Enabled() is true: it dereferences e.enc. 91 | func (e *Encoding) Encode(dest, src []byte) ([]byte, error) { 92 | return e.transform(e.enc.NewEncoder(), dest, src, false) 93 | } 94 | 95 | // Decode converts bytes from a specific charset to utf-8 charset. Only call it when Enabled() is true: it dereferences e.enc. 96 | func (e *Encoding) Decode(dest, src []byte) ([]byte, error) { 97 | return e.transform(e.enc.NewDecoder(), dest, src, true) 98 | } 99 | 100 | func (e *Encoding) transform(transformer transform.Transformer, dest, src []byte, isDecoding bool) ([]byte, error) { 101 | if len(dest) < len(src) { 102 | dest = make([]byte, len(src)*2) 103 | } 104 | var destOffset, srcOffset int 105 | var encodingErr error 106 | for { 107 | srcNextLen := e.nextCharLenInSrc(src[srcOffset:], isDecoding) 108 | srcEnd := mathutil.Min(srcOffset+srcNextLen, len(src)) 109 | nDest, nSrc, err := transformer.Transform(dest[destOffset:], src[srcOffset:srcEnd], false) 110 | if err == transform.ErrShortDst { 111 | dest = enlargeCapacity(dest) 112 | } else if err != nil || isDecoding && beginWithReplacementChar(dest[destOffset:destOffset+nDest]) { 113 | if encodingErr == nil { 114 | encodingErr = e.generateErr(src[srcOffset:], srcNextLen) 115 | } 116 | // A non-ErrShortDst error can arrive while dest is already full (destOffset == len(dest)); make room for the '?' placeholder. 117 | if destOffset >= len(dest) { 118 | dest = enlargeCapacity(dest) 119 | } 120 | dest[destOffset] = byte('?') 121 | nDest, nSrc = 1, srcNextLen // skip the source bytes that
cannot be decoded normally. 122 | } 123 | destOffset += nDest 124 | srcOffset += nSrc 125 | // The source bytes are exhausted. 126 | if srcOffset >= len(src) { 127 | return dest[:destOffset], encodingErr 128 | } 129 | } 130 | } 131 | 132 | func (e *Encoding) nextCharLenInSrc(srcRest []byte, isDecoding bool) int { 133 | if isDecoding && e.charLength != nil { 134 | return e.charLength(srcRest) 135 | } 136 | return len(srcRest) 137 | } 138 | 139 | func enlargeCapacity(dest []byte) []byte { 140 | newDest := make([]byte, len(dest)*2) 141 | copy(newDest, dest) 142 | return newDest 143 | } 144 | 145 | func (e *Encoding) generateErr(srcRest []byte, srcNextLen int) error { 146 | cutEnd := mathutil.Min(srcNextLen, len(srcRest)) 147 | invalidBytes := fmt.Sprintf("%X", string(srcRest[:cutEnd])) 148 | return errInvalidCharacterString.GenWithStackByArgs(e.name, invalidBytes) 149 | } 150 | 151 | // replacementBytes are bytes for the replacement rune 0xfffd. 152 | var replacementBytes = []byte{0xEF, 0xBF, 0xBD} 153 | 154 | // beginWithReplacementChar checks if dst has the prefix '0xEFBFBD'. 155 | func beginWithReplacementChar(dst []byte) bool { 156 | return bytes.HasPrefix(dst, replacementBytes) 157 | } 158 | -------------------------------------------------------------------------------- /charset/encoding_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2021 PingCAP, Inc. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // See the License for the specific language governing permissions and 12 | // limitations under the License.
13 | 14 | package charset_test 15 | 16 | import ( 17 | . "github.com/pingcap/check" 18 | "github.com/pingcap/parser/charset" 19 | "golang.org/x/text/transform" 20 | ) 21 | 22 | var _ = Suite(&testEncodingSuite{}) 23 | 24 | type testEncodingSuite struct { 25 | } 26 | 27 | func (s *testEncodingSuite) TestEncoding(c *C) { 28 | enc := charset.NewEncoding("gbk") 29 | c.Assert(enc.Name(), Equals, "gbk") 30 | c.Assert(enc.Enabled(), IsTrue) 31 | enc.UpdateEncoding("utf-8") 32 | c.Assert(enc.Name(), Equals, "utf-8") 33 | enc.UpdateEncoding("gbk") 34 | c.Assert(enc.Name(), Equals, "gbk") 35 | c.Assert(enc.Enabled(), IsTrue) 36 | 37 | txt := []byte("一二三四") 38 | e, _ := charset.Lookup("gbk") 39 | gbkEncodedTxt, _, err := transform.Bytes(e.NewEncoder(), txt) 40 | c.Assert(err, IsNil) 41 | result, err := enc.Decode(nil, gbkEncodedTxt) 42 | c.Assert(err, IsNil) 43 | c.Assert(result, DeepEquals, txt) 44 | 45 | gbkEncodedTxt2, err := enc.Encode(nil, txt) 46 | c.Assert(err, IsNil) 47 | c.Assert(gbkEncodedTxt, DeepEquals, gbkEncodedTxt2) 48 | result, err = enc.Decode(nil, gbkEncodedTxt2) 49 | c.Assert(err, IsNil) 50 | c.Assert(result, DeepEquals, txt) 51 | 52 | gbkCases := []struct { 53 | utf8Str string 54 | result string 55 | isValid bool 56 | }{ 57 | {"一二三", "涓?簩涓?", false}, // MySQL reports '涓?簩涓'.
58 | {"一二三123", "涓?簩涓?23", false}, 59 | {"案1案2", "妗?妗?", false}, 60 | {"焊䏷菡釬", "鐒婁彿鑿¢嚞", true}, 61 | {"鞍杏以伊位依", "闉嶆潖浠ヤ紛浣嶄緷", true}, 62 | {"移維緯胃萎衣謂違", "绉荤董绶?儍钀庤。璎傞仌", false}, 63 | {"仆仂仗仞仭仟价伉佚估", "浠嗕粋浠椾粸浠?粺浠蜂級浣氫及", false}, 64 | {"佝佗佇佶侈侏侘佻佩佰侑佯", "浣濅綏浣囦蕉渚堜緩渚樹交浣╀桨渚戜蒋", true}, 65 | } 66 | for _, tc := range gbkCases { 67 | cmt := Commentf("%v", tc) 68 | result, err = enc.Decode(nil, []byte(tc.utf8Str)) 69 | if tc.isValid { 70 | c.Assert(err, IsNil, cmt) 71 | } else { 72 | c.Assert(err, NotNil, cmt) 73 | } 74 | c.Assert(string(result), Equals, tc.result, cmt) 75 | } 76 | } 77 | -------------------------------------------------------------------------------- /checkout-pr-branch.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # This script is used to checkout a Parser PR branch in a forked repo. 4 | if test -z "$1"; then 5 | echo -e "Usage:\n" 6 | echo -e "\tcheckout-pr-branch.sh [github-username]:[pr-branch]\n" 7 | echo -e "The argument can be copied directly from github PR page." 8 | echo -e "The local branch name would be [github-username]/[pr-branch]."
9 | exit 1 10 | fi 11 | 12 | username=$(echo "$1" | cut -d':' -f1) 13 | branch=$(echo "$1" | cut -d':' -f2) 14 | local_branch="$username/$branch" 15 | fork="https://github.com/$username/parser" 16 | 17 | exists=$(git show-ref "refs/heads/$local_branch") 18 | if [ -n "$exists" ]; then 19 | git checkout "$local_branch" 20 | git pull "$fork" "$branch:$local_branch" 21 | else 22 | git fetch "$fork" "$branch:$local_branch" 23 | git checkout "$local_branch" 24 | fi 25 | -------------------------------------------------------------------------------- /circle.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | 3 | jobs: 4 | build-ut: 5 | docker: 6 | - image: golang:1.16 7 | working_directory: /go/src/github.com/pingcap/parser 8 | steps: 9 | - checkout 10 | - run: echo skip 11 | build-integration: 12 | docker: 13 | - image: golang:1.16 14 | working_directory: /go/src/github.com/pingcap/parser 15 | steps: 16 | - checkout 17 | - run: echo skip 18 | 19 | workflows: 20 | version: 2 21 | build_and_test: 22 | jobs: 23 | - build-ut 24 | - build-integration 25 | -------------------------------------------------------------------------------- /codecov.yml: -------------------------------------------------------------------------------- 1 | codecov: 2 | require_ci_to_pass: no 3 | notify: 4 | wait_for_ci: no 5 | 6 | coverage: 7 | status: 8 | project: 9 | default: 10 | threshold: 0.2 11 | patch: 12 | default: 13 | target: 0% # trial operation 14 | changes: no 15 | 16 | comment: 17 | layout: "header, diff" 18 | behavior: default 19 | require_changes: no 20 | -------------------------------------------------------------------------------- /consistent_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2017 PingCAP, Inc. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License.
5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // See the License for the specific language governing permissions and 12 | // limitations under the License. 13 | 14 | package parser 15 | 16 | import ( 17 | "io/ioutil" 18 | "os" 19 | "path" 20 | "runtime" 21 | "sort" 22 | "strings" 23 | 24 | . "github.com/pingcap/check" 25 | ) 26 | 27 | var _ = Suite(&testConsistentSuite{}) 28 | 29 | type testConsistentSuite struct { 30 | content string 31 | 32 | reservedKeywords []string 33 | unreservedKeywords []string 34 | notKeywordTokens []string 35 | tidbKeywords []string 36 | } 37 | 38 | func (s *testConsistentSuite) SetUpSuite(c *C) { 39 | _, filename, _, _ := runtime.Caller(0) 40 | parserFilename := path.Join(path.Dir(filename), "parser.y") 41 | parserFile, err := os.Open(parserFilename) 42 | c.Assert(err, IsNil) 43 | defer parserFile.Close() // release the file descriptor once parser.y has been read 44 | data, err := ioutil.ReadAll(parserFile) 45 | c.Assert(err, IsNil) 46 | s.content = string(data) 47 | 48 | reservedKeywordStartMarker := "\t/* The following tokens belong to ReservedKeyword. Notice: make sure these tokens are contained in ReservedKeyword. */" 49 | unreservedKeywordStartMarker := "\t/* The following tokens belong to UnReservedKeyword. Notice: make sure these tokens are contained in UnReservedKeyword. */" 50 | notKeywordTokenStartMarker := "\t/* The following tokens belong to NotKeywordToken. Notice: make sure these tokens are contained in NotKeywordToken. */" 51 | tidbKeywordStartMarker := "\t/* The following tokens belong to TiDBKeyword. Notice: make sure these tokens are contained in TiDBKeyword.
*/" 52 | identTokenEndMarker := "%token\t" 53 | s.reservedKeywords = extractKeywords(s.content, reservedKeywordStartMarker, unreservedKeywordStartMarker) 54 | s.unreservedKeywords = extractKeywords(s.content, unreservedKeywordStartMarker, notKeywordTokenStartMarker) 55 | s.notKeywordTokens = extractKeywords(s.content, notKeywordTokenStartMarker, tidbKeywordStartMarker) 56 | s.tidbKeywords = extractKeywords(s.content, tidbKeywordStartMarker, identTokenEndMarker) 57 | } 58 | 59 | func (s *testConsistentSuite) TestKeywordConsistent(c *C) { 60 | for k, v := range aliases { 61 | c.Assert(k, Not(Equals), v) 62 | c.Assert(tokenMap[k], Equals, tokenMap[v]) 63 | } 64 | keywordCount := len(s.reservedKeywords) + len(s.unreservedKeywords) + len(s.notKeywordTokens) + len(s.tidbKeywords) 65 | c.Assert(len(tokenMap)-len(aliases), Equals, keywordCount-len(windowFuncTokenMap)) 66 | 67 | unreservedCollectionDef := extractKeywordsFromCollectionDef(s.content, "\nUnReservedKeyword:") 68 | c.Assert(s.unreservedKeywords, DeepEquals, unreservedCollectionDef) 69 | 70 | notKeywordTokensCollectionDef := extractKeywordsFromCollectionDef(s.content, "\nNotKeywordToken:") 71 | c.Assert(s.notKeywordTokens, DeepEquals, notKeywordTokensCollectionDef) 72 | 73 | tidbKeywordsCollectionDef := extractKeywordsFromCollectionDef(s.content, "\nTiDBKeyword:") 74 | c.Assert(s.tidbKeywords, DeepEquals, tidbKeywordsCollectionDef) 75 | } 76 | 77 | func extractMiddle(str, startMarker, endMarker string) string { 78 | startIdx := strings.Index(str, startMarker) 79 | if startIdx == -1 { 80 | return "" 81 | } 82 | str = str[startIdx+len(startMarker):] 83 | endIdx := strings.Index(str, endMarker) 84 | if endIdx == -1 { 85 | return "" 86 | } 87 | return str[:endIdx] 88 | } 89 | 90 | func extractQuotedWords(strs []string) []string { 91 | var words []string 92 | for _, str := range strs { 93 | word := extractMiddle(str, "\"", "\"") 94 | if word == "" { 95 | continue 96 | } 97 | words = append(words, word) 98 | } 99
| sort.Strings(words) 100 | return words 101 | } 102 | 103 | func extractKeywords(content, startMarker, endMarker string) []string { 104 | keywordSection := extractMiddle(content, startMarker, endMarker) 105 | lines := strings.Split(keywordSection, "\n") 106 | return extractQuotedWords(lines) 107 | } 108 | 109 | func extractKeywordsFromCollectionDef(content, startMarker string) []string { 110 | keywordSection := extractMiddle(content, startMarker, "\n\n") 111 | words := strings.Split(keywordSection, "|") 112 | return extractQuotedWords(words) 113 | } 114 | -------------------------------------------------------------------------------- /digester_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2019 PingCAP, Inc. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // See the License for the specific language governing permissions and 12 | // limitations under the License. 13 | 14 | package parser_test 15 | 16 | import ( 17 | "crypto/sha256" 18 | "encoding/hex" 19 | "fmt" 20 | "testing" 21 | 22 | .
"github.com/pingcap/check" 23 | "github.com/pingcap/parser" 24 | ) 25 | 26 | var _ = Suite(&testSQLDigestSuite{}) 27 | 28 | type testSQLDigestSuite struct { 29 | } 30 | 31 | func (s *testSQLDigestSuite) TestNormalize(c *C) { 32 | tests := []struct { 33 | input string 34 | expect string 35 | }{ 36 | {"select _utf8mb4'123'", "select (_charset) ?"}, 37 | {"SELECT 1", "select ?"}, 38 | {"select null", "select ?"}, 39 | {"select \\N", "select ?"}, 40 | {"SELECT `null`", "select `null`"}, 41 | {"select * from b where id = 1", "select * from `b` where `id` = ?"}, 42 | {"select 1 from b where id in (1, 3, '3', 1, 2, 3, 4)", "select ? from `b` where `id` in ( ... )"}, 43 | {"select 1 from b where id in (1, a, 4)", "select ? from `b` where `id` in ( ? , `a` , ? )"}, 44 | {"select 1 from b order by 2", "select ? from `b` order by 2"}, 45 | {"select /*+ a hint */ 1", "select ?"}, 46 | {"select /* a hint */ 1", "select ?"}, 47 | {"select truncate(1, 2)", "select truncate ( ... )"}, 48 | {"select -1 + - 2 + b - c + 0.2 + (-2) from c where d in (1, -2, +3)", "select ? + ? + `b` - `c` + ? + ( ? ) from `c` where `d` in ( ... )"}, 49 | {"select * from t where a <= -1 and b < -2 and c = -3 and c > -4 and c >= -5 and e is 1", "select * from `t` where `a` <= ? and `b` < ? and `c` = ? and `c` > ? and `c` >= ? and `e` is ?"}, 50 | {"select count(a), b from t group by 2", "select count ( `a` ) , `b` from `t` group by 2"}, 51 | {"select count(a), b, c from t group by 2, 3", "select count ( `a` ) , `b` , `c` from `t` group by 2 , 3"}, 52 | {"select count(a), b, c from t group by (2, 3)", "select count ( `a` ) , `b` , `c` from `t` group by ( 2 , 3 )"}, 53 | {"select a, b from t order by 1, 2", "select `a` , `b` from `t` order by 1 , 2"}, 54 | {"select count(*) from t", "select count ( ? 
) from `t`"}, 55 | {"select * from t Force Index(kk)", "select * from `t`"}, 56 | {"select * from t USE Index(kk)", "select * from `t`"}, 57 | {"select * from t Ignore Index(kk)", "select * from `t`"}, 58 | {"select * from t1 straight_join t2 on t1.id=t2.id", "select * from `t1` join `t2` on `t1` . `id` = `t2` . `id`"}, 59 | {"select * from `table`", "select * from `table`"}, 60 | {"select * from `30`", "select * from `30`"}, 61 | {"select * from `select`", "select * from `select`"}, 62 | // test syntax error, it will be checked by parser, but it should not make normalize dead loop. 63 | {"select * from t ignore index(", "select * from `t` ignore index"}, 64 | {"select /*+ ", "select "}, 65 | {"select * from 🥳", "select * from"}, 66 | {"select 1 / 2", "select ? / ?"}, 67 | {"select * from t where a = 40 limit ?, ?", "select * from `t` where `a` = ? limit ..."}, 68 | {"select * from t where a > ?", "select * from `t` where `a` > ?"}, 69 | {"select @a=b from t", "select @a = `b` from `t`"}, 70 | {"select * from `table", "select * from"}, 71 | } 72 | for _, test := range tests { 73 | normalized := parser.Normalize(test.input) 74 | digest := parser.DigestNormalized(normalized) 75 | c.Assert(normalized, Equals, test.expect) 76 | 77 | normalized2, digest2 := parser.NormalizeDigest(test.input) 78 | c.Assert(normalized2, Equals, normalized) 79 | c.Assert(digest2.String(), Equals, digest.String(), Commentf("%+v", test)) 80 | } 81 | } 82 | 83 | func (s *testSQLDigestSuite) TestNormalizeDigest(c *C) { 84 | tests := []struct { 85 | sql string 86 | normalized string 87 | digest string 88 | }{ 89 | {"select 1 from b where id in (1, 3, '3', 1, 2, 3, 4)", "select ? from `b` where `id` in ( ... 
)", "e1c8cc2738f596dc24f15ef8eb55e0d902910d7298983496362a7b46dbc0b310"}, 90 | } 91 | for _, test := range tests { 92 | normalized, digest := parser.NormalizeDigest(test.sql) 93 | c.Assert(normalized, Equals, test.normalized) 94 | c.Assert(digest.String(), Equals, test.digest) 95 | 96 | normalized = parser.Normalize(test.sql) 97 | digest = parser.DigestNormalized(normalized) 98 | c.Assert(normalized, Equals, test.normalized) 99 | c.Assert(digest.String(), Equals, test.digest) 100 | } 101 | } 102 | 103 | func (s *testSQLDigestSuite) TestDigestHashEqForSimpleSQL(c *C) { 104 | sqlGroups := [][]string{ 105 | {"select * from b where id = 1", "select * from b where id = '1'", "select * from b where id =2"}, 106 | {"select 2 from b, c where c.id > 1", "select 4 from b, c where c.id > 23"}, 107 | {"Select 3", "select 1"}, 108 | } 109 | for _, sqlGroup := range sqlGroups { 110 | var d string 111 | for _, sql := range sqlGroup { 112 | dig := parser.DigestHash(sql) 113 | if d == "" { 114 | d = dig.String() 115 | continue 116 | } 117 | c.Assert(d, Equals, dig.String()) 118 | } 119 | } 120 | } 121 | 122 | func (s *testSQLDigestSuite) TestDigestHashNotEqForSimpleSQL(c *C) { 123 | sqlGroups := [][]string{ 124 | {"select * from b where id = 1", "select a from b where id = 1", "select * from d where bid =1"}, 125 | } 126 | for _, sqlGroup := range sqlGroups { 127 | var d string 128 | for _, sql := range sqlGroup { 129 | dig := parser.DigestHash(sql) 130 | if d == "" { 131 | d = dig.String() 132 | continue 133 | } 134 | c.Assert(d, Not(Equals), dig.String()) 135 | } 136 | } 137 | } 138 | 139 | func (s *testSQLDigestSuite) TestGenDigest(c *C) { 140 | hash := genRandDigest("abc") 141 | digest := parser.NewDigest(hash) 142 | c.Assert(digest.String(), Equals, fmt.Sprintf("%x", hash)) 143 | c.Assert(digest.Bytes(), DeepEquals, hash) 144 | digest = parser.NewDigest(nil) 145 | c.Assert(digest.String(), Equals, "") 146 | c.Assert(digest.Bytes(), IsNil) 147 | } 148 | 149 | func 
genRandDigest(str string) []byte { 150 | hasher := sha256.New() 151 | hasher.Write([]byte(str)) 152 | return hasher.Sum(nil) 153 | } 154 | 155 | func BenchmarkDigestHexEncode(b *testing.B) { 156 | digest1 := genRandDigest("abc") 157 | b.ResetTimer() 158 | for i := 0; i < b.N; i++ { 159 | hex.EncodeToString(digest1) 160 | } 161 | } 162 | 163 | func BenchmarkDigestSprintf(b *testing.B) { 164 | digest1 := genRandDigest("abc") 165 | b.ResetTimer() 166 | for i := 0; i < b.N; i++ { 167 | fmt.Sprintf("%x", digest1) 168 | } 169 | } 170 | -------------------------------------------------------------------------------- /docs/quickstart.md: -------------------------------------------------------------------------------- 1 | # Quickstart 2 | 3 | This parser is highly compatible with MySQL syntax. You can use it as a library, parse a text SQL into an AST tree, and traverse the AST nodes. 4 | 5 | In this example, you will build a project, which can extract all the column names from a text SQL. 6 | 7 | ## Prerequisites 8 | 9 | - [Golang](https://golang.org/dl/) version 1.13 or above. You can follow the instructions in the official [installation page](https://golang.org/doc/install) (check it by `go version`) 10 | 11 | ## Create a Project 12 | 13 | ```bash 14 | mkdir colx && cd colx 15 | go mod init colx && touch main.go 16 | ``` 17 | 18 | ## Import Dependencies 19 | 20 | First of all, you need to use `go get` to fetch the dependencies through git hash. The git hashes are available in [release page](https://github.com/pingcap/parser/releases). Take `v4.0.2` as an example: 21 | 22 | ```bash 23 | go get -v github.com/pingcap/parser@3a18f1e 24 | ``` 25 | 26 | > **NOTE** 27 | > 28 | > You may want to use advanced API on expressions (a kind of AST node), such as numbers, string literals, booleans, nulls, etc. 
It is strongly recommended that you use the `types/parser_driver` package from the TiDB repository for this. Fetch it with the following command: 29 | > 30 | > ```bash 31 | > go get -v github.com/pingcap/tidb/types/parser_driver@328b6d0 32 | > ``` 33 | > and import it in your Go source code: 34 | > ```go 35 | > import _ "github.com/pingcap/tidb/types/parser_driver" 36 | > ``` 37 | 38 | Your directory should contain the following three files: 39 | ``` 40 | . 41 | ├── go.mod 42 | ├── go.sum 43 | └── main.go 44 | ``` 45 | 46 | Now, open `main.go` with your favorite editor, and start coding! 47 | 48 | ## Parse SQL text 49 | 50 | To convert SQL text into an AST, you need to: 51 | 1. Use the [`parser.New()`](https://pkg.go.dev/github.com/pingcap/parser?tab=doc#New) function to instantiate a parser, and 52 | 2. Invoke the method [`Parse(sql, charset, collation)`](https://pkg.go.dev/github.com/pingcap/parser?tab=doc#Parser.Parse) on the parser. 53 | 54 | ```go 55 | package main 56 | 57 | import ( 58 | "fmt" 59 | "github.com/pingcap/parser" 60 | "github.com/pingcap/parser/ast" 61 | _ "github.com/pingcap/parser/test_driver" 62 | ) 63 | 64 | func parse(sql string) (*ast.StmtNode, error) { 65 | p := parser.New() 66 | 67 | stmtNodes, _, err := p.Parse(sql, "", "") 68 | if err != nil { 69 | return nil, err 70 | } 71 | 72 | return &stmtNodes[0], nil 73 | } 74 | 75 | func main() { 76 | astNode, err := parse("SELECT a, b FROM t") 77 | if err != nil { 78 | fmt.Printf("parse error: %v\n", err.Error()) 79 | return 80 | } 81 | fmt.Printf("%v\n", *astNode) 82 | } 83 | ``` 84 | 85 | Test the parser by running the following command: 86 | 87 | ```bash 88 | go run main.go 89 | ``` 90 | 91 | If the parser runs properly, you should get a result like this: 92 | 93 | ``` 94 | &{{{{SELECT a, b FROM t}}} {[]} 0xc0000a1980 false 0xc00000e7a0 0xc0000a19b0 [] none [] false false 0 } 95 | ``` 96 | 97 | > **NOTE** 98 | > 99 | > Here are a few things you might want to know: 100 | > - To use a parser, a `parser_driver` is required.
It decides how to parse the basic data types in SQL. 101 | > 102 | > You can use [`github.com/pingcap/parser/test_driver`](https://pkg.go.dev/github.com/pingcap/parser/test_driver) as the `parser_driver` for testing. Again, if you need advanced features, use the `parser_driver` in TiDB instead (run `go get -v github.com/pingcap/tidb/types/parser_driver@328b6d0` and import it). 103 | > - The instantiated parser object is not goroutine-safe. Use each instance from a single goroutine only. 104 | > - The instantiated parser object is not lightweight. Reuse it whenever possible. 105 | > - The 2nd and 3rd arguments of [`parser.Parse()`](https://pkg.go.dev/github.com/pingcap/parser?tab=doc#Parser.Parse) are the charset and collation, respectively. If you pass an empty string for either of them, a default value is used. 106 | 107 | 108 | ## Traverse AST Nodes 109 | 110 | Now you have the root node of the AST for a SQL statement. It is time to extract the column names by traversing it. 111 | 112 | Every kind of AST node, such as SelectStmt, TableName, and ColumnName, implements the interface [`ast.Node`](https://pkg.go.dev/github.com/pingcap/parser/ast?tab=doc#Node). [`ast.Node`](https://pkg.go.dev/github.com/pingcap/parser/ast?tab=doc#Node) provides a method `Accept(v Visitor) (node Node, ok bool)` that lets any type implementing [`ast.Visitor`](https://pkg.go.dev/github.com/pingcap/parser/ast?tab=doc#Visitor) traverse the node and its children.
113 | 114 | [`ast.Visitor`](https://pkg.go.dev/github.com/pingcap/parser/ast?tab=doc#Visitor) is defined as follows: 115 | ```go 116 | type Visitor interface { 117 | Enter(n Node) (node Node, skipChildren bool) 118 | Leave(n Node) (node Node, ok bool) 119 | } 120 | ``` 121 | 122 | Now you can define your own visitor, `colX` (column extractor): 123 | 124 | ```go 125 | type colX struct { 126 | colNames []string 127 | } 128 | 129 | func (v *colX) Enter(in ast.Node) (ast.Node, bool) { 130 | if name, ok := in.(*ast.ColumnName); ok { 131 | v.colNames = append(v.colNames, name.Name.O) 132 | } 133 | return in, false 134 | } 135 | 136 | func (v *colX) Leave(in ast.Node) (ast.Node, bool) { 137 | return in, true 138 | } 139 | ``` 140 | 141 | Finally, wrap `colX` in a simple function: 142 | 143 | ```go 144 | func extract(rootNode *ast.StmtNode) []string { 145 | v := &colX{} 146 | (*rootNode).Accept(v) 147 | return v.colNames 148 | } 149 | ``` 150 | 151 | And slightly modify the main function (it now reads `os.Args`, so also add `"os"` to the import list): 152 | 153 | ```go 154 | func main() { 155 | if len(os.Args) != 2 { 156 | fmt.Println("usage: colx 'SQL statement'") 157 | return 158 | } 159 | sql := os.Args[1] 160 | astNode, err := parse(sql) 161 | if err != nil { 162 | fmt.Printf("parse error: %v\n", err.Error()) 163 | return 164 | } 165 | fmt.Printf("%v\n", extract(astNode)) 166 | } 167 | ``` 168 | 169 | Test your program: 170 | 171 | ```bash 172 | go build && ./colx 'select a, b from t' 173 | ``` 174 | 175 | ``` 176 | [a b] 177 | ``` 178 | 179 | You can also try a different SQL statement as input. For example: 180 | 181 | ```console 182 | $ ./colx 'SELECT a, b FROM t GROUP BY (a, b) HAVING a > c ORDER BY b' 183 | [a b a b a c b] 184 | ``` 185 | If necessary, you can deduplicate the names yourself. An invalid statement produces a parse error: 186 | ```console 187 | $ ./colx 'SELECT a, b FROM t/invalid_str' 188 | parse error: line 1 column 19 near "/invalid_str" 189 | ``` 190 | 191 | Enjoy!
192 | -------------------------------------------------------------------------------- /docs/update-parser-for-tidb.md: -------------------------------------------------------------------------------- 1 | # How to update parser for TiDB 2 | 3 | If you want to file a PR (pull request) to TiDB that includes a change to the parser, follow these steps to update the parser in TiDB. 4 | 5 | ## Step 1: Make changes in your parser repository 6 | 7 | Fork this repository to your own account and commit the changes to your repository. 8 | 9 | > **Note:** 10 | > 11 | > - Don't forget to run `make test` before you commit! 12 | > - Make sure `parser.go` is updated (regenerated from `parser.y`). 13 | 14 | Suppose the forked repository is `https://github.com/your-repo/parser`. 15 | 16 | ## Step 2: Make your parser changes take effect in TiDB and run CI 17 | 18 | 1. In your TiDB repository, add a `replace` directive so that TiDB builds against your parser changes: 19 | 20 | ``` 21 | GO111MODULE=on go mod edit -replace github.com/pingcap/parser=github.com/your-repo/parser@your-branch 22 | ``` 23 | 24 | 2. Run `make dev` to run the CI checks in TiDB. 25 | 26 | 3. File a PR to TiDB. 27 | 28 | ## Step 3: Merge the parser PR into this repository 29 | 30 | File a PR to this repository. **Link the related PR in TiDB in your PR description or comment.** 31 | 32 | This PR will be reviewed, and if everything goes well, it will be merged. 33 | 34 | ## Step 4: Update TiDB to use the latest parser 35 | 36 | In your TiDB pull request, modify the `go.mod` file manually or use this command: 37 | 38 | ``` 39 | GO111MODULE=on go get -u github.com/pingcap/parser@master 40 | ``` 41 | 42 | Make sure the `replace` directive is removed, so that the `require` directive points to the latest parser version. 43 | -------------------------------------------------------------------------------- /export_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2019 PingCAP, Inc.
2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // See the License for the specific language governing permissions and 12 | // limitations under the License. 13 | 14 | package parser 15 | 16 | // WindowFuncTokenMapForTest exports windowFuncTokenMap for use in test cases. 17 | var WindowFuncTokenMapForTest = windowFuncTokenMap 18 | -------------------------------------------------------------------------------- /format/format_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2015 PingCAP, Inc. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // See the License for the specific language governing permissions and 12 | // limitations under the License. 13 | 14 | package format 15 | 16 | import ( 17 | "bytes" 18 | "io/ioutil" 19 | "strings" 20 | "testing" 21 | 22 | .
"github.com/pingcap/check" 23 | ) 24 | 25 | func TestT(t *testing.T) { 26 | CustomVerboseFlag = true 27 | TestingT(t) 28 | } 29 | 30 | var _ = Suite(&testFormatSuite{}) 31 | var _ = Suite(&testRestoreCtxSuite{}) 32 | 33 | type testFormatSuite struct { 34 | } 35 | 36 | func checkFormat(c *C, f Formatter, buf *bytes.Buffer, str, expect string) { // formats str via f with the single argument 3, then asserts that the unread contents of buf equal expect 37 | _, err := f.Format(str, 3) 38 | c.Assert(err, IsNil) 39 | b, err := ioutil.ReadAll(buf) 40 | c.Assert(err, IsNil) 41 | c.Assert(string(b), Equals, expect) 42 | } 43 | 44 | func (s *testFormatSuite) TestFormat(c *C) { 45 | str := "abc%d%%e%i\nx\ny\n%uz\n" 46 | buf := &bytes.Buffer{} 47 | f := IndentFormatter(buf, "\t") 48 | expect := `abc3%e 49 | x 50 | y 51 | z 52 | ` 53 | checkFormat(c, f, buf, str, expect) 54 | 55 | str = "abc%d%%e%i\nx\ny\n%uz\n%i\n" 56 | buf = &bytes.Buffer{} // fresh buffer so the FlatFormatter check does not see the indented output 57 | f = FlatFormatter(buf) 58 | expect = "abc3%e x y z\n " 59 | checkFormat(c, f, buf, str, expect) 60 | } 61 | 62 | type testRestoreCtxSuite struct { 63 | } 64 | 65 | func (s *testRestoreCtxSuite) TestRestoreCtx(c *C) { 66 | testCases := []struct { 67 | flag RestoreFlags 68 | expect string 69 | }{ 70 | {0, "key`.'\"Word\\ str`.'\"ing\\ na`.'\"Me\\"}, 71 | {RestoreStringSingleQuotes, "key`.'\"Word\\ 'str`.''\"ing\\' na`.'\"Me\\"}, 72 | {RestoreStringDoubleQuotes, "key`.'\"Word\\ \"str`.'\"\"ing\\\" na`.'\"Me\\"}, 73 | {RestoreStringEscapeBackslash, "key`.'\"Word\\ str`.'\"ing\\\\ na`.'\"Me\\"}, 74 | {RestoreKeyWordUppercase, "KEY`.'\"WORD\\ str`.'\"ing\\ na`.'\"Me\\"}, 75 | {RestoreKeyWordLowercase, "key`.'\"word\\ str`.'\"ing\\ na`.'\"Me\\"}, 76 | {RestoreNameUppercase, "key`.'\"Word\\ str`.'\"ing\\ NA`.'\"ME\\"}, 77 | {RestoreNameLowercase, "key`.'\"Word\\ str`.'\"ing\\ na`.'\"me\\"}, 78 | {RestoreNameDoubleQuotes, "key`.'\"Word\\ str`.'\"ing\\ \"na`.'\"\"Me\\\""}, 79 | {RestoreNameBackQuotes, "key`.'\"Word\\ str`.'\"ing\\ `na``.'\"Me\\`"}, 80 | {DefaultRestoreFlags, "KEY`.'\"WORD\\ 'str`.''\"ing\\' `na``.'\"Me\\`"}, 81 |
{RestoreStringSingleQuotes | RestoreStringDoubleQuotes, "key`.'\"Word\\ 'str`.''\"ing\\' na`.'\"Me\\"}, 82 | {RestoreKeyWordUppercase | RestoreKeyWordLowercase, "KEY`.'\"WORD\\ str`.'\"ing\\ na`.'\"Me\\"}, 83 | {RestoreNameUppercase | RestoreNameLowercase, "key`.'\"Word\\ str`.'\"ing\\ NA`.'\"ME\\"}, 84 | {RestoreNameDoubleQuotes | RestoreNameBackQuotes, "key`.'\"Word\\ str`.'\"ing\\ \"na`.'\"\"Me\\\""}, 85 | } 86 | var sb strings.Builder 87 | for _, testCase := range testCases { 88 | sb.Reset() // sb is shared across cases, so clear the previous case's output 89 | ctx := NewRestoreCtx(testCase.flag, &sb) // every case writes the same keyword/string/name inputs; only the flags differ 90 | ctx.WriteKeyWord("key`.'\"Word\\") 91 | ctx.WritePlain(" ") 92 | ctx.WriteString("str`.'\"ing\\") 93 | ctx.WritePlain(" ") 94 | ctx.WriteName("na`.'\"Me\\") 95 | c.Assert(sb.String(), Equals, testCase.expect, Commentf("case: %#v", testCase)) 96 | } 97 | } 98 | 99 | func (s *testRestoreCtxSuite) TestRestoreSpecialComment(c *C) { 100 | var sb strings.Builder 101 | sb.Reset() // NOTE(review): no-op here, sb was just declared 102 | ctx := NewRestoreCtx(RestoreTiDBSpecialComment, &sb) 103 | ctx.WriteWithSpecialComments("fea_id", func() { 104 | ctx.WritePlain("content") 105 | }) 106 | c.Assert(sb.String(), Equals, "/*T![fea_id] content */") 107 | 108 | sb.Reset() // clear the first check's output; ctx keeps writing into the same sb 109 | ctx.WriteWithSpecialComments("", func() { 110 | ctx.WritePlain("shard_row_id_bits") 111 | }) 112 | c.Assert(sb.String(), Equals, "/*T!
 shard_row_id_bits */") 113 | } 114 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/pingcap/parser 2 | 3 | require ( 4 | github.com/cznic/golex v0.0.0-20181122101858-9c343928389c // indirect 5 | github.com/cznic/mathutil v0.0.0-20181122101859-297441e03548 6 | github.com/cznic/parser v0.0.0-20160622100904-31edd927e5b1 7 | github.com/cznic/sortutil v0.0.0-20181122101858-f5f958428db8 8 | github.com/cznic/strutil v0.0.0-20171016134553-529a34b1c186 9 | github.com/cznic/y v0.0.0-20170802143616-045f81c6662a 10 | github.com/go-sql-driver/mysql v1.3.0 11 | github.com/pingcap/check v0.0.0-20190102082844-67f458068fc8 12 | github.com/pingcap/errors v0.11.5-0.20210425183316-da1aaba5fb63 13 | github.com/pingcap/log v0.0.0-20210625125904-98ed8e2eb1c7 14 | github.com/remyoudompheng/bigfft v0.0.0-20200410134404-eec4a21b6bb0 // indirect 15 | github.com/stretchr/testify v1.7.0 16 | go.uber.org/zap v1.18.1 17 | golang.org/x/text v0.3.6 18 | ) 19 | 20 | go 1.13 21 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | github.com/BurntSushi/toml v0.3.1 h1:WXkYYl6Yr3qBf1K79EBnL4mak0OimBfB0XUf9Vl28OQ= 2 | github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= 3 | github.com/benbjohnson/clock v1.1.0 h1:Q92kusRqC1XV2MjkWETPvjJVqKetz1OzxZB7mHJLju8= 4 | github.com/benbjohnson/clock v1.1.0/go.mod h1:J11/hYXuz8f4ySSvYwY0FKfm+ezbsZBKZxNJlLklBHA= 5 | github.com/cznic/golex v0.0.0-20181122101858-9c343928389c h1:G8zTsaqyVfIHpgMFcGgdbhHSFhlNc77rAKkhVbQ9kQg= 6 | github.com/cznic/golex v0.0.0-20181122101858-9c343928389c/go.mod h1:+bmmJDNmKlhWNG+gwWCkaBoTy39Fs+bzRxVBzoTQbIc= 7 | github.com/cznic/mathutil v0.0.0-20181122101859-297441e03548 h1:iwZdTE0PVqJCos1vaoKsclOGD3ADKpshg3SRtYBbwso= 8 |
github.com/cznic/mathutil v0.0.0-20181122101859-297441e03548/go.mod h1:e6NPNENfs9mPDVNRekM7lKScauxd5kXTr1Mfyig6TDM= 9 | github.com/cznic/parser v0.0.0-20160622100904-31edd927e5b1 h1:uWcWCkSP+E1w1z8r082miT+c+9vzg+5UdrgGCo15lMo= 10 | github.com/cznic/parser v0.0.0-20160622100904-31edd927e5b1/go.mod h1:2B43mz36vGZNZEwkWi8ayRSSUXLfjL8OkbzwW4NcPMM= 11 | github.com/cznic/sortutil v0.0.0-20181122101858-f5f958428db8 h1:LpMLYGyy67BoAFGda1NeOBQwqlv7nUXpm+rIVHGxZZ4= 12 | github.com/cznic/sortutil v0.0.0-20181122101858-f5f958428db8/go.mod h1:q2w6Bg5jeox1B+QkJ6Wp/+Vn0G/bo3f1uY7Fn3vivIQ= 13 | github.com/cznic/strutil v0.0.0-20171016134553-529a34b1c186 h1:0rkFMAbn5KBKNpJyHQ6Prb95vIKanmAe62KxsrN+sqA= 14 | github.com/cznic/strutil v0.0.0-20171016134553-529a34b1c186/go.mod h1:AHHPPPXTw0h6pVabbcbyGRK1DckRn7r/STdZEeIDzZc= 15 | github.com/cznic/y v0.0.0-20170802143616-045f81c6662a h1:N2rDAvHuM46OGscJkGX4Dw4BBqZgg6mGNGLYs5utVVo= 16 | github.com/cznic/y v0.0.0-20170802143616-045f81c6662a/go.mod h1:1rk5VM7oSnA4vjp+hrLQ3HWHa+Y4yPCa3/CsJrcNnvs= 17 | github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 18 | github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= 19 | github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 20 | github.com/go-sql-driver/mysql v1.3.0 h1:pgwjLi/dvffoP9aabwkT3AKpXQM93QARkjFhDDqC1UE= 21 | github.com/go-sql-driver/mysql v1.3.0/go.mod h1:zAC/RDZ24gD3HViQzih4MyKcchzm+sOG5ZlKdlhCg5w= 22 | github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= 23 | github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= 24 | github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= 25 | github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= 26 | github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= 27 | github.com/pingcap/check v0.0.0-20190102082844-67f458068fc8 
h1:USx2/E1bX46VG32FIw034Au6seQ2fY9NEILmNh/UlQg= 28 | github.com/pingcap/check v0.0.0-20190102082844-67f458068fc8/go.mod h1:B1+S9LNcuMyLH/4HMTViQOJevkGiik3wW2AN9zb2fNQ= 29 | github.com/pingcap/errors v0.11.0/go.mod h1:Oi8TUi2kEtXXLMJk9l1cGmz20kV3TaQ0usTwv5KuLY8= 30 | github.com/pingcap/errors v0.11.5-0.20210425183316-da1aaba5fb63 h1:+FZIDR/D97YOPik4N4lPDaUcLDF/EQPogxtlHB2ZZRM= 31 | github.com/pingcap/errors v0.11.5-0.20210425183316-da1aaba5fb63/go.mod h1:X2r9ueLEUZgtx2cIogM0v4Zj5uvvzhuuiu7Pn8HzMPg= 32 | github.com/pingcap/log v0.0.0-20210625125904-98ed8e2eb1c7 h1:k2BbABz9+TNpYRwsCCFS8pEEnFVOdbgEjL/kTlLuzZQ= 33 | github.com/pingcap/log v0.0.0-20210625125904-98ed8e2eb1c7/go.mod h1:8AanEdAHATuRurdGxZXBz0At+9avep+ub7U1AGYLIMM= 34 | github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I= 35 | github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= 36 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= 37 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 38 | github.com/remyoudompheng/bigfft v0.0.0-20200410134404-eec4a21b6bb0 h1:OdAsTTz6OkFY5QxjkYwrChwuRruF69c169dPK26NUlk= 39 | github.com/remyoudompheng/bigfft v0.0.0-20200410134404-eec4a21b6bb0/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= 40 | github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= 41 | github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= 42 | github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= 43 | github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY= 44 | github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= 45 | go.uber.org/atomic v1.3.2/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE= 46 | go.uber.org/atomic v1.6.0/go.mod h1:sABNBOSYdrvTF6hTgEIbc7YasKWGhgEQZyfxyTvoXHQ= 47 | go.uber.org/atomic 
v1.7.0 h1:ADUqmZGgLDDfbSL9ZmPxKTybcoEYHgpYfELNoN+7hsw= 48 | go.uber.org/atomic v1.7.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= 49 | go.uber.org/goleak v1.1.10 h1:z+mqJhf6ss6BSfSM671tgKyZBFPTTJM+HLxnhPC3wu0= 50 | go.uber.org/goleak v1.1.10/go.mod h1:8a7PlsEVH3e/a/GLqe5IIrQx6GzcnRmZEufDUTk4A7A= 51 | go.uber.org/multierr v1.1.0/go.mod h1:wR5kodmAFQ0UK8QlbwjlSNy0Z68gJhDJUG5sjR94q/0= 52 | go.uber.org/multierr v1.6.0 h1:y6IPFStTAIT5Ytl7/XYmHvzXQ7S3g/IeZW9hyZ5thw4= 53 | go.uber.org/multierr v1.6.0/go.mod h1:cdWPpRnG4AhwMwsgIHip0KRBQjJy5kYEpYjJxpXp9iU= 54 | go.uber.org/zap v1.9.1/go.mod h1:vwi/ZaCAaUcBkycHslxD9B2zi4UTXhF60s6SWpuDF0Q= 55 | go.uber.org/zap v1.18.1 h1:CSUJ2mjFszzEWt4CdKISEuChVIXGBn3lAPwkRGyVrc4= 56 | go.uber.org/zap v1.18.1/go.mod h1:xg/QME4nWcxGxrpdeYfq7UvYrLh66cuVKdrbD1XF/NI= 57 | golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= 58 | golang.org/x/lint v0.0.0-20190930215403-16217165b5de h1:5hukYrvBGR8/eNkX5mdUezrA6JiaEZDtJb9Ei+1LlBs= 59 | golang.org/x/lint v0.0.0-20190930215403-16217165b5de/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= 60 | golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= 61 | golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= 62 | golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= 63 | golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= 64 | golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= 65 | golang.org/x/text v0.3.6 h1:aRYxNxv6iGQlyVaZmk6ZgYEDa+Jg18DxebPSrd6bg1M= 66 | golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= 67 | golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= 68 | golang.org/x/tools 
v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= 69 | golang.org/x/tools v0.0.0-20191029041327-9cc4af7d6b2c/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= 70 | golang.org/x/tools v0.0.0-20191108193012-7d206e10da11 h1:Yq9t9jnGoR+dBuitxdo9l6Q7xh/zOyNnYUtDKaQ3x0E= 71 | golang.org/x/tools v0.0.0-20191108193012-7d206e10da11/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= 72 | golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= 73 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= 74 | gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY= 75 | gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= 76 | gopkg.in/natefinch/lumberjack.v2 v2.0.0 h1:1Lc07Kr7qY4U2YPouBjpCLxpiyxIVoxqXgkXLknAOE8= 77 | gopkg.in/natefinch/lumberjack.v2 v2.0.0/go.mod h1:l0ndWWf7gzL7RNwBG7wST/UCcT4T24xpD6X8LsfU/+k= 78 | gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= 79 | gopkg.in/yaml.v2 v2.2.8 h1:obN1ZagJSUGI0Ek/LBmuj4SNLPfIny3KsKFopxRdj10= 80 | gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= 81 | gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 82 | gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b h1:h8qDotaEPuJATrMmW04NCwg7v22aHH28wwpauUhK9Oo= 83 | gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 84 | -------------------------------------------------------------------------------- /hintparser_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2020 PingCAP, Inc. 
2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // See the License for the specific language governing permissions and 12 | // limitations under the License. 13 | 14 | package parser_test 15 | 16 | import ( 17 | . "github.com/pingcap/check" 18 | 19 | "github.com/pingcap/parser" 20 | "github.com/pingcap/parser/ast" 21 | "github.com/pingcap/parser/model" 22 | "github.com/pingcap/parser/mysql" 23 | ) 24 | 25 | var _ = Suite(&testHintParserSuite{}) 26 | 27 | type testHintParserSuite struct{} 28 | 29 | func (s *testHintParserSuite) TestParseHint(c *C) { 30 | testCases := []struct { 31 | input string 32 | mode mysql.SQLMode 33 | output []*ast.TableOptimizerHint 34 | errs []string 35 | }{ 36 | { 37 | input: "", 38 | errs: []string{`.*Optimizer hint syntax error at line 1 .*`}, 39 | }, 40 | { 41 | input: "MEMORY_QUOTA(8 MB) MEMORY_QUOTA(6 GB)", 42 | output: []*ast.TableOptimizerHint{ 43 | { 44 | HintName: model.NewCIStr("MEMORY_QUOTA"), 45 | HintData: int64(8 * 1024 * 1024), 46 | }, 47 | { 48 | HintName: model.NewCIStr("MEMORY_QUOTA"), 49 | HintData: int64(6 * 1024 * 1024 * 1024), 50 | }, 51 | }, 52 | }, 53 | { 54 | input: "QB_NAME(qb1) QB_NAME(`qb2`), QB_NAME(TRUE) QB_NAME(\"ANSI quoted\") QB_NAME(_utf8), QB_NAME(0b10) QB_NAME(0x1a)", 55 | mode: mysql.ModeANSIQuotes, 56 | output: []*ast.TableOptimizerHint{ 57 | { 58 | HintName: model.NewCIStr("QB_NAME"), 59 | QBName: model.NewCIStr("qb1"), 60 | }, 61 | { 62 | HintName: model.NewCIStr("QB_NAME"), 63 | QBName: model.NewCIStr("qb2"), 64 | }, 65 | { 66 | HintName: model.NewCIStr("QB_NAME"), 67 | QBName: model.NewCIStr("TRUE"), 68 | }, 69 | { 70 | HintName: 
model.NewCIStr("QB_NAME"), 71 | QBName: model.NewCIStr("ANSI quoted"), 72 | }, 73 | { 74 | HintName: model.NewCIStr("QB_NAME"), 75 | QBName: model.NewCIStr("_utf8"), 76 | }, 77 | { 78 | HintName: model.NewCIStr("QB_NAME"), 79 | QBName: model.NewCIStr("0b10"), 80 | }, 81 | { 82 | HintName: model.NewCIStr("QB_NAME"), 83 | QBName: model.NewCIStr("0x1a"), 84 | }, 85 | }, 86 | }, 87 | { 88 | input: "QB_NAME(1)", 89 | errs: []string{`.*Optimizer hint syntax error at line 1 .*`}, 90 | }, 91 | { 92 | input: "QB_NAME('string literal')", 93 | errs: []string{`.*Optimizer hint syntax error at line 1 .*`}, 94 | }, 95 | { 96 | input: "QB_NAME(many identifiers)", 97 | errs: []string{`.*Optimizer hint syntax error at line 1 .*`}, 98 | }, 99 | { 100 | input: "QB_NAME(@qb1)", 101 | errs: []string{`.*Optimizer hint syntax error at line 1 .*`}, 102 | }, 103 | { 104 | input: "QB_NAME(b'10')", 105 | errs: []string{ 106 | `.*Cannot use bit-value literal.*`, 107 | `.*Optimizer hint syntax error at line 1 .*`, 108 | }, 109 | }, 110 | { 111 | input: "QB_NAME(x'1a')", 112 | errs: []string{ 113 | `.*Cannot use hexadecimal literal.*`, 114 | `.*Optimizer hint syntax error at line 1 .*`, 115 | }, 116 | }, 117 | { 118 | input: "JOIN_FIXED_ORDER() BKA()", 119 | errs: []string{ 120 | `.*Optimizer hint JOIN_FIXED_ORDER is not supported.*`, 121 | `.*Optimizer hint BKA is not supported.*`, 122 | }, 123 | }, 124 | { 125 | input: "HASH_JOIN() TIDB_HJ(@qb1) INL_JOIN(x, `y y`.z) MERGE_JOIN(w@`First QB`)", 126 | output: []*ast.TableOptimizerHint{ 127 | { 128 | HintName: model.NewCIStr("HASH_JOIN"), 129 | }, 130 | { 131 | HintName: model.NewCIStr("TIDB_HJ"), 132 | QBName: model.NewCIStr("qb1"), 133 | }, 134 | { 135 | HintName: model.NewCIStr("INL_JOIN"), 136 | Tables: []ast.HintTable{ 137 | {TableName: model.NewCIStr("x")}, 138 | {DBName: model.NewCIStr("y y"), TableName: model.NewCIStr("z")}, 139 | }, 140 | }, 141 | { 142 | HintName: model.NewCIStr("MERGE_JOIN"), 143 | Tables: []ast.HintTable{ 144 | 
{TableName: model.NewCIStr("w"), QBName: model.NewCIStr("First QB")}, 145 | }, 146 | }, 147 | }, 148 | }, 149 | { 150 | input: "USE_INDEX_MERGE(@qb1 tbl1 x, y, z) IGNORE_INDEX(tbl2@qb2) USE_INDEX(tbl3 PRIMARY) FORCE_INDEX(tbl4@qb3 c1)", 151 | output: []*ast.TableOptimizerHint{ 152 | { 153 | HintName: model.NewCIStr("USE_INDEX_MERGE"), 154 | Tables: []ast.HintTable{{TableName: model.NewCIStr("tbl1")}}, 155 | QBName: model.NewCIStr("qb1"), 156 | Indexes: []model.CIStr{model.NewCIStr("x"), model.NewCIStr("y"), model.NewCIStr("z")}, 157 | }, 158 | { 159 | HintName: model.NewCIStr("IGNORE_INDEX"), 160 | Tables: []ast.HintTable{{TableName: model.NewCIStr("tbl2"), QBName: model.NewCIStr("qb2")}}, 161 | }, 162 | { 163 | HintName: model.NewCIStr("USE_INDEX"), 164 | Tables: []ast.HintTable{{TableName: model.NewCIStr("tbl3")}}, 165 | Indexes: []model.CIStr{model.NewCIStr("PRIMARY")}, 166 | }, 167 | { 168 | HintName: model.NewCIStr("FORCE_INDEX"), 169 | Tables: []ast.HintTable{{TableName: model.NewCIStr("tbl4"), QBName: model.NewCIStr("qb3")}}, 170 | Indexes: []model.CIStr{model.NewCIStr("c1")}, 171 | }, 172 | }, 173 | }, 174 | { 175 | input: "USE_INDEX(@qb1 tbl1 partition(p0) x) USE_INDEX_MERGE(@qb2 tbl2@qb2 partition(p0, p1) x, y, z)", 176 | output: []*ast.TableOptimizerHint{ 177 | { 178 | HintName: model.NewCIStr("USE_INDEX"), 179 | Tables: []ast.HintTable{{ 180 | TableName: model.NewCIStr("tbl1"), 181 | PartitionList: []model.CIStr{model.NewCIStr("p0")}, 182 | }}, 183 | QBName: model.NewCIStr("qb1"), 184 | Indexes: []model.CIStr{model.NewCIStr("x")}, 185 | }, 186 | { 187 | HintName: model.NewCIStr("USE_INDEX_MERGE"), 188 | Tables: []ast.HintTable{{ 189 | TableName: model.NewCIStr("tbl2"), 190 | QBName: model.NewCIStr("qb2"), 191 | PartitionList: []model.CIStr{model.NewCIStr("p0"), model.NewCIStr("p1")}, 192 | }}, 193 | QBName: model.NewCIStr("qb2"), 194 | Indexes: []model.CIStr{model.NewCIStr("x"), model.NewCIStr("y"), model.NewCIStr("z")}, 195 | }, 196 | }, 197 | }, 198 | 
{ 199 | input: `SET_VAR(sbs = 16M) SET_VAR(fkc=OFF) SET_VAR(os="mcb=off") set_var(abc=1) set_var(os2='mcb2=off')`, 200 | output: []*ast.TableOptimizerHint{ 201 | { 202 | HintName: model.NewCIStr("SET_VAR"), 203 | HintData: ast.HintSetVar{ 204 | VarName: "sbs", 205 | Value: "16M", 206 | }, 207 | }, 208 | { 209 | HintName: model.NewCIStr("SET_VAR"), 210 | HintData: ast.HintSetVar{ 211 | VarName: "fkc", 212 | Value: "OFF", 213 | }, 214 | }, 215 | { 216 | HintName: model.NewCIStr("SET_VAR"), 217 | HintData: ast.HintSetVar{ 218 | VarName: "os", 219 | Value: "mcb=off", 220 | }, 221 | }, 222 | { 223 | HintName: model.NewCIStr("set_var"), 224 | HintData: ast.HintSetVar{ 225 | VarName: "abc", 226 | Value: "1", 227 | }, 228 | }, 229 | { 230 | HintName: model.NewCIStr("set_var"), 231 | HintData: ast.HintSetVar{ 232 | VarName: "os2", 233 | Value: "mcb2=off", 234 | }, 235 | }, 236 | }, 237 | }, 238 | { 239 | input: "USE_TOJA(TRUE) IGNORE_PLAN_CACHE() USE_CASCADES(TRUE) QUERY_TYPE(@qb1 OLAP) QUERY_TYPE(OLTP) NO_INDEX_MERGE()", 240 | output: []*ast.TableOptimizerHint{ 241 | { 242 | HintName: model.NewCIStr("USE_TOJA"), 243 | HintData: true, 244 | }, 245 | { 246 | HintName: model.NewCIStr("IGNORE_PLAN_CACHE"), 247 | }, 248 | { 249 | HintName: model.NewCIStr("USE_CASCADES"), 250 | HintData: true, 251 | }, 252 | { 253 | HintName: model.NewCIStr("QUERY_TYPE"), 254 | QBName: model.NewCIStr("qb1"), 255 | HintData: model.NewCIStr("OLAP"), 256 | }, 257 | { 258 | HintName: model.NewCIStr("QUERY_TYPE"), 259 | HintData: model.NewCIStr("OLTP"), 260 | }, 261 | { 262 | HintName: model.NewCIStr("NO_INDEX_MERGE"), 263 | }, 264 | }, 265 | }, 266 | { 267 | input: "READ_FROM_STORAGE(@foo TIKV[a, b], TIFLASH[c, d]) HASH_AGG() READ_FROM_STORAGE(TIKV[e])", 268 | output: []*ast.TableOptimizerHint{ 269 | { 270 | HintName: model.NewCIStr("READ_FROM_STORAGE"), 271 | HintData: model.NewCIStr("TIKV"), 272 | QBName: model.NewCIStr("foo"), 273 | Tables: []ast.HintTable{ 274 | {TableName: model.NewCIStr("a")}, 
275 | {TableName: model.NewCIStr("b")}, 276 | }, 277 | }, 278 | { 279 | HintName: model.NewCIStr("READ_FROM_STORAGE"), 280 | HintData: model.NewCIStr("TIFLASH"), 281 | QBName: model.NewCIStr("foo"), 282 | Tables: []ast.HintTable{ 283 | {TableName: model.NewCIStr("c")}, 284 | {TableName: model.NewCIStr("d")}, 285 | }, 286 | }, 287 | { 288 | HintName: model.NewCIStr("HASH_AGG"), 289 | }, 290 | { 291 | HintName: model.NewCIStr("READ_FROM_STORAGE"), 292 | HintData: model.NewCIStr("TIKV"), 293 | Tables: []ast.HintTable{ 294 | {TableName: model.NewCIStr("e")}, 295 | }, 296 | }, 297 | }, 298 | }, 299 | { 300 | input: "unknown_hint()", 301 | errs: []string{`.*Optimizer hint syntax error at line 1 .*`}, 302 | }, 303 | { 304 | input: "set_var(timestamp = 1.5)", 305 | errs: []string{ 306 | `.*Cannot use decimal number.*`, 307 | `.*Optimizer hint syntax error at line 1 .*`, 308 | }, 309 | }, 310 | { 311 | input: "set_var(timestamp = _utf8mb4'1234')", // Optimizer hint doesn't recognize _charset'strings'. 
312 | errs: []string{`.*Optimizer hint syntax error at line 1 .*`}, 313 | }, 314 | { 315 | input: "set_var(timestamp = 9999999999999999999999999999999999999)", 316 | errs: []string{ 317 | `.*integer value is out of range.*`, 318 | `.*Optimizer hint syntax error at line 1 .*`, 319 | }, 320 | }, 321 | { 322 | input: "time_range('2020-02-20 12:12:12',456)", 323 | errs: []string{ 324 | `.*Optimizer hint syntax error at line 1 .*`, 325 | }, 326 | }, 327 | { 328 | input: "time_range(456,'2020-02-20 12:12:12')", 329 | errs: []string{ 330 | `.*Optimizer hint syntax error at line 1 .*`, 331 | }, 332 | }, 333 | { 334 | input: "TIME_RANGE('2020-02-20 12:12:12','2020-02-20 13:12:12')", 335 | output: []*ast.TableOptimizerHint{ 336 | { 337 | HintName: model.NewCIStr("TIME_RANGE"), 338 | HintData: ast.HintTimeRange{ 339 | From: "2020-02-20 12:12:12", 340 | To: "2020-02-20 13:12:12", 341 | }, 342 | }, 343 | }, 344 | }, 345 | } 346 | 347 | for _, tc := range testCases { 348 | output, errs := parser.ParseHint("/*+"+tc.input+"*/", tc.mode, parser.Pos{Line: 1}) 349 | c.Assert(errs, HasLen, len(tc.errs), Commentf("input = %s,\n... errs = %q", tc.input, errs)) 350 | for i, err := range errs { 351 | c.Assert(err, ErrorMatches, tc.errs[i], Commentf("input = %s, i = %d", tc.input, i)) 352 | } 353 | c.Assert(output, DeepEquals, tc.output, Commentf("input = %s,\n... output = %q", tc.input, output)) 354 | } 355 | } 356 | -------------------------------------------------------------------------------- /hintparserimpl.go: -------------------------------------------------------------------------------- 1 | // Copyright 2020 PingCAP, Inc. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 
5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // See the License for the specific language governing permissions and 12 | // limitations under the License. 13 | 14 | package parser 15 | 16 | import ( 17 | "strconv" 18 | "strings" 19 | "unicode" 20 | 21 | "github.com/pingcap/parser/ast" 22 | "github.com/pingcap/parser/mysql" 23 | "github.com/pingcap/parser/terror" 24 | ) 25 | 26 | var ( 27 | ErrWarnOptimizerHintUnsupportedHint = terror.ClassParser.NewStd(mysql.ErrWarnOptimizerHintUnsupportedHint) 28 | ErrWarnOptimizerHintInvalidToken = terror.ClassParser.NewStd(mysql.ErrWarnOptimizerHintInvalidToken) 29 | ErrWarnMemoryQuotaOverflow = terror.ClassParser.NewStd(mysql.ErrWarnMemoryQuotaOverflow) 30 | ErrWarnOptimizerHintParseError = terror.ClassParser.NewStd(mysql.ErrWarnOptimizerHintParseError) 31 | ErrWarnOptimizerHintInvalidInteger = terror.ClassParser.NewStd(mysql.ErrWarnOptimizerHintInvalidInteger) 32 | ) 33 | 34 | // hintScanner implements the yyhintLexer interface 35 | type hintScanner struct { 36 | Scanner 37 | } 38 | 39 | func (hs *hintScanner) Errorf(format string, args ...interface{}) error { 40 | inner := hs.Scanner.Errorf(format, args...) 
41 | return ErrWarnOptimizerHintParseError.GenWithStackByArgs(inner) 42 | } 43 | 44 | func (hs *hintScanner) Lex(lval *yyhintSymType) int { 45 | tok, pos, lit := hs.scan() 46 | hs.lastScanOffset = pos.Offset 47 | var errorTokenType string 48 | 49 | switch tok { 50 | case intLit: 51 | n, e := strconv.ParseUint(lit, 10, 64) 52 | if e != nil { 53 | hs.AppendError(ErrWarnOptimizerHintInvalidInteger.GenWithStackByArgs(lit)) 54 | return int(unicode.ReplacementChar) 55 | } 56 | lval.number = n 57 | return hintIntLit 58 | 59 | case singleAtIdentifier: 60 | lval.ident = lit 61 | return hintSingleAtIdentifier 62 | 63 | case identifier: 64 | lval.ident = lit 65 | if tok1, ok := hintTokenMap[strings.ToUpper(lit)]; ok { 66 | return tok1 67 | } 68 | return hintIdentifier 69 | 70 | case stringLit: 71 | lval.ident = lit 72 | if hs.sqlMode.HasANSIQuotesMode() && hs.r.s[pos.Offset] == '"' { 73 | return hintIdentifier 74 | } 75 | return hintStringLit 76 | 77 | case bitLit: 78 | if strings.HasPrefix(lit, "0b") { 79 | lval.ident = lit 80 | return hintIdentifier 81 | } 82 | errorTokenType = "bit-value literal" 83 | 84 | case hexLit: 85 | if strings.HasPrefix(lit, "0x") { 86 | lval.ident = lit 87 | return hintIdentifier 88 | } 89 | errorTokenType = "hexadecimal literal" 90 | 91 | case quotedIdentifier: 92 | lval.ident = lit 93 | return hintIdentifier 94 | 95 | case eq: 96 | return '=' 97 | 98 | case floatLit: 99 | errorTokenType = "floating point number" 100 | case decLit: 101 | errorTokenType = "decimal number" 102 | 103 | default: 104 | if tok <= 0x7f { 105 | return tok 106 | } 107 | errorTokenType = "unknown token" 108 | } 109 | 110 | hs.AppendError(ErrWarnOptimizerHintInvalidToken.GenWithStackByArgs(errorTokenType, lit, tok)) 111 | return int(unicode.ReplacementChar) 112 | } 113 | 114 | type hintParser struct { 115 | lexer hintScanner 116 | result []*ast.TableOptimizerHint 117 | 118 | // the following fields are used by yyParse to reduce allocation. 
119 | cache []yyhintSymType 120 | yylval yyhintSymType 121 | yyVAL *yyhintSymType 122 | } 123 | 124 | func newHintParser() *hintParser { 125 | return &hintParser{cache: make([]yyhintSymType, 50)} 126 | } 127 | 128 | func (hp *hintParser) parse(input string, sqlMode mysql.SQLMode, initPos Pos) ([]*ast.TableOptimizerHint, []error) { 129 | hp.result = nil 130 | hp.lexer.reset(input[3:]) 131 | hp.lexer.SetSQLMode(sqlMode) 132 | hp.lexer.r.updatePos(Pos{ 133 | Line: initPos.Line, 134 | Col: initPos.Col + 3, // skipped the initial '/*+' 135 | Offset: 0, 136 | }) 137 | hp.lexer.inBangComment = true // skip the final '*/' (we need the '*/' for reporting warnings) 138 | 139 | yyhintParse(&hp.lexer, hp) 140 | 141 | warns, errs := hp.lexer.Errors() 142 | if len(errs) == 0 { 143 | errs = warns 144 | } 145 | return hp.result, errs 146 | } 147 | 148 | // ParseHint parses an optimizer hint (the interior of `/*+ ... */`). 149 | func ParseHint(input string, sqlMode mysql.SQLMode, initPos Pos) ([]*ast.TableOptimizerHint, []error) { 150 | hp := newHintParser() 151 | return hp.parse(input, sqlMode, initPos) 152 | } 153 | 154 | func (hp *hintParser) warnUnsupportedHint(name string) { 155 | warn := ErrWarnOptimizerHintUnsupportedHint.GenWithStackByArgs(name) 156 | hp.lexer.warns = append(hp.lexer.warns, warn) 157 | } 158 | 159 | func (hp *hintParser) lastErrorAsWarn() { 160 | hp.lexer.lastErrorAsWarn() 161 | } 162 | -------------------------------------------------------------------------------- /model/flags.go: -------------------------------------------------------------------------------- 1 | // Copyright 2018 PingCAP, Inc. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 
5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // See the License for the specific language governing permissions and 12 | // limitations under the License. 13 | 14 | package model 15 | 16 | // Flags are used by tipb.SelectRequest.Flags to handle execution mode, like how to handle truncate error. 17 | const ( 18 | // FlagIgnoreTruncate indicates if truncate error should be ignored. 19 | // Read-only statements should ignore truncate error, write statements should not ignore truncate error. 20 | FlagIgnoreTruncate uint64 = 1 21 | // FlagTruncateAsWarning indicates if truncate error should be returned as warning. 22 | // This flag only matters if FlagIgnoreTruncate is not set, in strict sql mode, truncate error should 23 | // be returned as error, in non-strict sql mode, truncate error should be saved as warning. 24 | FlagTruncateAsWarning = 1 << 1 25 | // FlagPadCharToFullLength indicates if sql_mode 'PAD_CHAR_TO_FULL_LENGTH' is set. 26 | FlagPadCharToFullLength = 1 << 2 27 | // FlagInInsertStmt indicates if this is a INSERT statement. 28 | FlagInInsertStmt = 1 << 3 29 | // FlagInUpdateOrDeleteStmt indicates if this is a UPDATE statement or a DELETE statement. 30 | FlagInUpdateOrDeleteStmt = 1 << 4 31 | // FlagInSelectStmt indicates if this is a SELECT statement. 32 | FlagInSelectStmt = 1 << 5 33 | // FlagOverflowAsWarning indicates if overflow error should be returned as warning. 34 | // In strict sql mode, overflow error should be returned as error, 35 | // in non-strict sql mode, overflow error should be saved as warning. 36 | FlagOverflowAsWarning = 1 << 6 37 | // FlagIgnoreZeroInDate indicates if ZeroInDate error should be ignored. 38 | // Read-only statements should ignore ZeroInDate error. 
39 | // Write statements should not ignore ZeroInDate error in strict sql mode. 40 | FlagIgnoreZeroInDate = 1 << 7 41 | // FlagDividedByZeroAsWarning indicates if DividedByZero should be returned as warning. 42 | FlagDividedByZeroAsWarning = 1 << 8 43 | // FlagInSetOprStmt indicates if this is a UNION/EXCEPT/INTERSECT statement. 44 | FlagInSetOprStmt = 1 << 9 45 | // FlagInLoadDataStmt indicates if this is a LOAD DATA statement. 46 | FlagInLoadDataStmt = 1 << 10 47 | ) 48 | -------------------------------------------------------------------------------- /mysql/const_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2021 PingCAP, Inc. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // See the License for the specific language governing permissions and 12 | // limitations under the License. 13 | 14 | package mysql 15 | 16 | import ( 17 | "testing" 18 | 19 | . 
"github.com/pingcap/check" 20 | ) 21 | 22 | var _ = Suite(&testConstSuite{}) 23 | 24 | type testConstSuite struct{} 25 | 26 | func TestT(t *testing.T) { 27 | TestingT(t) 28 | } 29 | 30 | func (s *testConstSuite) TestSQLMode(c *C) { 31 | // ref https://dev.mysql.com/doc/internals/en/query-event.html#q-sql-mode-code, 32 | hardCode := []struct { 33 | code SQLMode 34 | value int 35 | }{{ 36 | ModeRealAsFloat, 0x00000001, 37 | }, { 38 | ModePipesAsConcat, 0x00000002, 39 | }, { 40 | ModeANSIQuotes, 0x00000004, 41 | }, { 42 | ModeIgnoreSpace, 0x00000008, 43 | }, { 44 | ModeNotUsed, 0x00000010, 45 | }, { 46 | ModeOnlyFullGroupBy, 0x00000020, 47 | }, { 48 | ModeNoUnsignedSubtraction, 0x00000040, 49 | }, { 50 | ModeNoDirInCreate, 0x00000080, 51 | }, { 52 | ModePostgreSQL, 0x00000100, 53 | }, { 54 | ModeOracle, 0x00000200, 55 | }, { 56 | ModeMsSQL, 0x00000400, 57 | }, { 58 | ModeDb2, 0x00000800, 59 | }, { 60 | ModeMaxdb, 0x00001000, 61 | }, { 62 | ModeNoKeyOptions, 0x00002000, 63 | }, { 64 | ModeNoTableOptions, 0x00004000, 65 | }, { 66 | ModeNoFieldOptions, 0x00008000, 67 | }, { 68 | ModeMySQL323, 0x00010000, 69 | }, { 70 | ModeMySQL40, 0x00020000, 71 | }, { 72 | ModeANSI, 0x00040000, 73 | }, { 74 | ModeNoAutoValueOnZero, 0x00080000, 75 | }, { 76 | ModeNoBackslashEscapes, 0x00100000, 77 | }, { 78 | ModeStrictTransTables, 0x00200000, 79 | }, { 80 | ModeStrictAllTables, 0x00400000, 81 | }, { 82 | ModeNoZeroInDate, 0x00800000, 83 | }, { 84 | ModeNoZeroDate, 0x01000000, 85 | }, { 86 | ModeInvalidDates, 0x02000000, 87 | }, { 88 | ModeErrorForDivisionByZero, 0x04000000, 89 | }, { 90 | ModeTraditional, 0x08000000, 91 | }, { 92 | ModeNoAutoCreateUser, 0x10000000, 93 | }, { 94 | ModeHighNotPrecedence, 0x20000000, 95 | }, { 96 | ModeNoEngineSubstitution, 0x40000000, 97 | }, { 98 | ModePadCharToFullLength, 0x80000000, 99 | }} 100 | 101 | for _, ca := range hardCode { 102 | c.Assert(int(ca.code), Equals, ca.value) 103 | } 104 | } 105 | 
-------------------------------------------------------------------------------- /mysql/error.go: -------------------------------------------------------------------------------- 1 | // Copyright 2015 PingCAP, Inc. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // See the License for the specific language governing permissions and 12 | // limitations under the License. 13 | 14 | package mysql 15 | 16 | import ( 17 | "fmt" 18 | 19 | "github.com/pingcap/errors" 20 | ) 21 | 22 | // Portable analogs of some common call errors. 23 | var ( 24 | ErrBadConn = errors.New("connection was bad") 25 | ErrMalformPacket = errors.New("malform packet error") 26 | ) 27 | 28 | // SQLError records an error information, from executing SQL. 29 | type SQLError struct { 30 | Code uint16 31 | Message string 32 | State string 33 | } 34 | 35 | // Error prints errors, with a formatted string. 36 | func (e *SQLError) Error() string { 37 | return fmt.Sprintf("ERROR %d (%s): %s", e.Code, e.State, e.Message) 38 | } 39 | 40 | // NewErr generates a SQL error, with an error code and default format specifier defined in MySQLErrName. 41 | func NewErr(errCode uint16, args ...interface{}) *SQLError { 42 | e := &SQLError{Code: errCode} 43 | 44 | if s, ok := MySQLState[errCode]; ok { 45 | e.State = s 46 | } else { 47 | e.State = DefaultMySQLState 48 | } 49 | 50 | if sqlErr, ok := MySQLErrName[errCode]; ok { 51 | errors.RedactErrorArg(args, sqlErr.RedactArgPos) 52 | e.Message = fmt.Sprintf(sqlErr.Raw, args...) 53 | } else { 54 | e.Message = fmt.Sprint(args...) 
55 | } 56 | 57 | return e 58 | } 59 | 60 | // NewErrf creates a SQL error, with an error code and a format specifier. 61 | func NewErrf(errCode uint16, format string, redactArgPos []int, args ...interface{}) *SQLError { 62 | e := &SQLError{Code: errCode} 63 | 64 | if s, ok := MySQLState[errCode]; ok { 65 | e.State = s 66 | } else { 67 | e.State = DefaultMySQLState 68 | } 69 | 70 | errors.RedactErrorArg(args, redactArgPos) 71 | e.Message = fmt.Sprintf(format, args...) 72 | 73 | return e 74 | } 75 | -------------------------------------------------------------------------------- /mysql/error_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2015 PingCAP, Inc. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // See the License for the specific language governing permissions and 12 | // limitations under the License. 13 | 14 | package mysql 15 | 16 | import ( 17 | . 
"github.com/pingcap/check" 18 | ) 19 | 20 | var _ = Suite(&testSQLErrorSuite{}) 21 | 22 | type testSQLErrorSuite struct { 23 | } 24 | 25 | func (s *testSQLErrorSuite) TestSQLError(c *C) { 26 | e := NewErrf(ErrNoDB, "no db error", nil) 27 | c.Assert(len(e.Error()), Greater, 0) 28 | 29 | e = NewErrf(0, "customized error", nil) 30 | c.Assert(len(e.Error()), Greater, 0) 31 | 32 | e = NewErr(ErrNoDB) 33 | c.Assert(len(e.Error()), Greater, 0) 34 | 35 | e = NewErr(0, "customized error", nil) 36 | c.Assert(len(e.Error()), Greater, 0) 37 | } 38 | -------------------------------------------------------------------------------- /mysql/locale_format.go: -------------------------------------------------------------------------------- 1 | package mysql 2 | 3 | import ( 4 | "bytes" 5 | "strconv" 6 | "strings" 7 | "unicode" 8 | 9 | "github.com/pingcap/errors" 10 | ) 11 | 12 | func formatENUS(number string, precision string) (string, error) { 13 | var buffer bytes.Buffer 14 | if unicode.IsDigit(rune(precision[0])) { 15 | for i, v := range precision { 16 | if unicode.IsDigit(v) { 17 | continue 18 | } 19 | precision = precision[:i] 20 | break 21 | } 22 | } else { 23 | precision = "0" 24 | } 25 | if number[0] == '-' && number[1] == '.' { 26 | number = strings.Replace(number, "-", "-0", 1) 27 | } else if number[0] == '.' { 28 | number = strings.Replace(number, ".", "0.", 1) 29 | } 30 | 31 | if (number[:1] == "-" && !unicode.IsDigit(rune(number[1]))) || 32 | (!unicode.IsDigit(rune(number[0])) && number[:1] != "-") { 33 | buffer.Write([]byte{'0'}) 34 | position, err := strconv.ParseUint(precision, 10, 64) 35 | if err == nil && position > 0 { 36 | buffer.Write([]byte{'.'}) 37 | buffer.WriteString(strings.Repeat("0", int(position))) 38 | } 39 | return buffer.String(), nil 40 | } else if number[:1] == "-" { 41 | buffer.Write([]byte{'-'}) 42 | number = number[1:] 43 | } 44 | 45 | for i, v := range number { 46 | if unicode.IsDigit(v) { 47 | continue 48 | } else if i == 1 && number[1] == '.' 
{ 49 | continue 50 | } else if v == '.' && number[1] != '.' { 51 | continue 52 | } else { 53 | number = number[:i] 54 | break 55 | } 56 | } 57 | 58 | comma := []byte{','} 59 | parts := strings.Split(number, ".") 60 | pos := 0 61 | if len(parts[0])%3 != 0 { 62 | pos += len(parts[0]) % 3 63 | buffer.WriteString(parts[0][:pos]) 64 | buffer.Write(comma) 65 | } 66 | for ; pos < len(parts[0]); pos += 3 { 67 | buffer.WriteString(parts[0][pos : pos+3]) 68 | buffer.Write(comma) 69 | } 70 | buffer.Truncate(buffer.Len() - 1) 71 | 72 | position, err := strconv.ParseUint(precision, 10, 64) 73 | if err == nil { 74 | if position > 0 { 75 | buffer.Write([]byte{'.'}) 76 | if len(parts) == 2 { 77 | if uint64(len(parts[1])) >= position { 78 | buffer.WriteString(parts[1][:position]) 79 | } else { 80 | buffer.WriteString(parts[1]) 81 | buffer.WriteString(strings.Repeat("0", int(position)-len(parts[1]))) 82 | } 83 | } else { 84 | buffer.WriteString(strings.Repeat("0", int(position))) 85 | } 86 | } 87 | } 88 | 89 | return buffer.String(), nil 90 | } 91 | 92 | func formatZHCN(number string, precision string) (string, error) { 93 | return "", errors.New("not implemented") 94 | } 95 | 96 | func formatNotSupport(number string, precision string) (string, error) { 97 | return "", errors.New("not support for the specific locale") 98 | } 99 | -------------------------------------------------------------------------------- /mysql/privs_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2021 PingCAP, Inc. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 
5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // See the License for the specific language governing permissions and 12 | // limitations under the License. 13 | 14 | package mysql 15 | 16 | import ( 17 | . "github.com/pingcap/check" 18 | ) 19 | 20 | var _ = Suite(&testPrivsSuite{}) 21 | 22 | type testPrivsSuite struct{} 23 | 24 | func (s *testPrivsSuite) TestPrivString(c *C) { 25 | for i := 0; ; i++ { 26 | p := PrivilegeType(1 << i) 27 | if p > AllPriv { 28 | break 29 | } 30 | c.Assert(p.String(), Not(Equals), "", Commentf("%d-th", i)) 31 | } 32 | } 33 | 34 | func (s *testPrivsSuite) TestPrivColumn(c *C) { 35 | for _, p := range AllGlobalPrivs { 36 | c.Assert(p.ColumnString(), Not(Equals), "", Commentf("%s", p)) 37 | np, ok := NewPrivFromColumn(p.ColumnString()) 38 | c.Assert(ok, IsTrue, Commentf("%s", p)) 39 | c.Assert(np, Equals, p) 40 | } 41 | for _, p := range StaticGlobalOnlyPrivs { 42 | c.Assert(p.ColumnString(), Not(Equals), "", Commentf("%s", p)) 43 | np, ok := NewPrivFromColumn(p.ColumnString()) 44 | c.Assert(ok, IsTrue, Commentf("%s", p)) 45 | c.Assert(np, Equals, p) 46 | } 47 | for _, p := range AllDBPrivs { 48 | c.Assert(p.ColumnString(), Not(Equals), "", Commentf("%s", p)) 49 | np, ok := NewPrivFromColumn(p.ColumnString()) 50 | c.Assert(ok, IsTrue, Commentf("%s", p)) 51 | c.Assert(np, Equals, p) 52 | } 53 | } 54 | 55 | func (s *testPrivsSuite) TestPrivSetString(c *C) { 56 | for _, p := range AllTablePrivs { 57 | c.Assert(p.SetString(), Not(Equals), "", Commentf("%s", p)) 58 | np, ok := NewPrivFromSetEnum(p.SetString()) 59 | c.Assert(ok, IsTrue, Commentf("%s", p)) 60 | c.Assert(np, Equals, p) 61 | } 62 | for _, p := range AllColumnPrivs { 63 | c.Assert(p.SetString(), Not(Equals), "", Commentf("%s", p)) 64 | np, ok := 
NewPrivFromSetEnum(p.SetString()) 65 | c.Assert(ok, IsTrue, Commentf("%s", p)) 66 | c.Assert(np, Equals, p) 67 | } 68 | } 69 | 70 | func (s *testPrivsSuite) TestPrivsHas(c *C) { 71 | // it is a simple helper, does not handle all&dynamic privs 72 | privs := Privileges{AllPriv} 73 | c.Assert(privs.Has(AllPriv), IsTrue) 74 | c.Assert(privs.Has(InsertPriv), IsFalse) 75 | 76 | // multiple privs 77 | privs = Privileges{InsertPriv, SelectPriv} 78 | c.Assert(privs.Has(SelectPriv), IsTrue) 79 | c.Assert(privs.Has(InsertPriv), IsTrue) 80 | c.Assert(privs.Has(DropPriv), IsFalse) 81 | } 82 | 83 | func (s *testPrivsSuite) TestPrivAllConsistency(c *C) { 84 | // AllPriv in mysql.user columns. 85 | for priv := PrivilegeType(CreatePriv); priv != AllPriv; priv = priv << 1 { 86 | _, ok := Priv2UserCol[priv] 87 | c.Assert(ok, IsTrue, Commentf("priv fail %d", priv)) 88 | } 89 | 90 | c.Assert(len(Priv2UserCol), Equals, len(AllGlobalPrivs)+1) 91 | 92 | // USAGE privilege doesn't have a column in Priv2UserCol 93 | // ALL privilege doesn't have a column in Priv2UserCol 94 | // so it's +2 95 | c.Assert(len(Priv2Str), Equals, len(Priv2UserCol)+2) 96 | } 97 | -------------------------------------------------------------------------------- /mysql/type.go: -------------------------------------------------------------------------------- 1 | // Copyright 2015 PingCAP, Inc. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // See the License for the specific language governing permissions and 12 | // limitations under the License. 13 | 14 | package mysql 15 | 16 | // MySQL type information. 
17 | const ( 18 | TypeUnspecified byte = 0 19 | TypeTiny byte = 1 20 | TypeShort byte = 2 21 | TypeLong byte = 3 22 | TypeFloat byte = 4 23 | TypeDouble byte = 5 24 | TypeNull byte = 6 25 | TypeTimestamp byte = 7 26 | TypeLonglong byte = 8 27 | TypeInt24 byte = 9 28 | TypeDate byte = 10 29 | /* TypeDuration original name was TypeTime, renamed to TypeDuration to resolve the conflict with Go type Time.*/ 30 | TypeDuration byte = 11 31 | TypeDatetime byte = 12 32 | TypeYear byte = 13 33 | TypeNewDate byte = 14 34 | TypeVarchar byte = 15 35 | TypeBit byte = 16 36 | 37 | TypeJSON byte = 0xf5 38 | TypeNewDecimal byte = 0xf6 39 | TypeEnum byte = 0xf7 40 | TypeSet byte = 0xf8 41 | TypeTinyBlob byte = 0xf9 42 | TypeMediumBlob byte = 0xfa 43 | TypeLongBlob byte = 0xfb 44 | TypeBlob byte = 0xfc 45 | TypeVarString byte = 0xfd 46 | TypeString byte = 0xfe 47 | TypeGeometry byte = 0xff 48 | ) 49 | 50 | // Flag information. 51 | const ( 52 | NotNullFlag uint = 1 << 0 /* Field can't be NULL */ 53 | PriKeyFlag uint = 1 << 1 /* Field is part of a primary key */ 54 | UniqueKeyFlag uint = 1 << 2 /* Field is part of a unique key */ 55 | MultipleKeyFlag uint = 1 << 3 /* Field is part of a key */ 56 | BlobFlag uint = 1 << 4 /* Field is a blob */ 57 | UnsignedFlag uint = 1 << 5 /* Field is unsigned */ 58 | ZerofillFlag uint = 1 << 6 /* Field is zerofill */ 59 | BinaryFlag uint = 1 << 7 /* Field is binary */ 60 | EnumFlag uint = 1 << 8 /* Field is an enum */ 61 | AutoIncrementFlag uint = 1 << 9 /* Field is an auto increment field */ 62 | TimestampFlag uint = 1 << 10 /* Field is a timestamp */ 63 | SetFlag uint = 1 << 11 /* Field is a set */ 64 | NoDefaultValueFlag uint = 1 << 12 /* Field doesn't have a default value */ 65 | OnUpdateNowFlag uint = 1 << 13 /* Field is set to NOW on UPDATE */ 66 | PartKeyFlag uint = 1 << 14 /* Intern: Part of some keys */ 67 | NumFlag uint = 1 << 15 /* Field is a num (for clients) */ 68 | 69 | GroupFlag uint = 1 << 15 /* Internal: Group field */ 70 | 
UniqueFlag uint = 1 << 16 /* Internal: Used by sql_yacc */ 71 | BinCmpFlag uint = 1 << 17 /* Internal: Used by sql_yacc */ 72 | ParseToJSONFlag uint = 1 << 18 /* Internal: Used when we want to parse string to JSON in CAST */ 73 | IsBooleanFlag uint = 1 << 19 /* Internal: Used for telling boolean literal from integer */ 74 | PreventNullInsertFlag uint = 1 << 20 /* Prevent this Field from inserting NULL values */ 75 | EnumSetAsIntFlag uint = 1 << 21 /* Internal: Used for inferring enum eval type. */ 76 | DropColumnIndexFlag uint = 1 << 22 /* Internal: Used for indicate the column is being dropped with index */ 77 | ) 78 | 79 | // TypeInt24 bounds. 80 | const ( 81 | MaxUint24 = 1<<24 - 1 82 | MaxInt24 = 1<<23 - 1 83 | MinInt24 = -1 << 23 84 | ) 85 | 86 | // HasDropColumnWithIndexFlag checks if DropColumnIndexFlag is set. 87 | func HasDropColumnWithIndexFlag(flag uint) bool { 88 | return (flag & DropColumnIndexFlag) > 0 89 | } 90 | 91 | // HasNotNullFlag checks if NotNullFlag is set. 92 | func HasNotNullFlag(flag uint) bool { 93 | return (flag & NotNullFlag) > 0 94 | } 95 | 96 | // HasNoDefaultValueFlag checks if NoDefaultValueFlag is set. 97 | func HasNoDefaultValueFlag(flag uint) bool { 98 | return (flag & NoDefaultValueFlag) > 0 99 | } 100 | 101 | // HasAutoIncrementFlag checks if AutoIncrementFlag is set. 102 | func HasAutoIncrementFlag(flag uint) bool { 103 | return (flag & AutoIncrementFlag) > 0 104 | } 105 | 106 | // HasUnsignedFlag checks if UnsignedFlag is set. 107 | func HasUnsignedFlag(flag uint) bool { 108 | return (flag & UnsignedFlag) > 0 109 | } 110 | 111 | // HasZerofillFlag checks if ZerofillFlag is set. 112 | func HasZerofillFlag(flag uint) bool { 113 | return (flag & ZerofillFlag) > 0 114 | } 115 | 116 | // HasBinaryFlag checks if BinaryFlag is set. 117 | func HasBinaryFlag(flag uint) bool { 118 | return (flag & BinaryFlag) > 0 119 | } 120 | 121 | // HasPriKeyFlag checks if PriKeyFlag is set. 
122 | func HasPriKeyFlag(flag uint) bool { 123 | return (flag & PriKeyFlag) > 0 124 | } 125 | 126 | // HasUniKeyFlag checks if UniqueKeyFlag is set. 127 | func HasUniKeyFlag(flag uint) bool { 128 | return (flag & UniqueKeyFlag) > 0 129 | } 130 | 131 | // HasMultipleKeyFlag checks if MultipleKeyFlag is set. 132 | func HasMultipleKeyFlag(flag uint) bool { 133 | return (flag & MultipleKeyFlag) > 0 134 | } 135 | 136 | // HasTimestampFlag checks if HasTimestampFlag is set. 137 | func HasTimestampFlag(flag uint) bool { 138 | return (flag & TimestampFlag) > 0 139 | } 140 | 141 | // HasOnUpdateNowFlag checks if OnUpdateNowFlag is set. 142 | func HasOnUpdateNowFlag(flag uint) bool { 143 | return (flag & OnUpdateNowFlag) > 0 144 | } 145 | 146 | // HasParseToJSONFlag checks if ParseToJSONFlag is set. 147 | func HasParseToJSONFlag(flag uint) bool { 148 | return (flag & ParseToJSONFlag) > 0 149 | } 150 | 151 | // HasIsBooleanFlag checks if IsBooleanFlag is set. 152 | func HasIsBooleanFlag(flag uint) bool { 153 | return (flag & IsBooleanFlag) > 0 154 | } 155 | 156 | // HasPreventNullInsertFlag checks if PreventNullInsertFlag is set. 157 | func HasPreventNullInsertFlag(flag uint) bool { 158 | return (flag & PreventNullInsertFlag) > 0 159 | } 160 | 161 | // HasEnumSetAsIntFlag checks if EnumSetAsIntFlag is set. 162 | func HasEnumSetAsIntFlag(flag uint) bool { 163 | return (flag & EnumSetAsIntFlag) > 0 164 | } 165 | -------------------------------------------------------------------------------- /mysql/type_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2015 PingCAP, Inc. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 
5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // See the License for the specific language governing permissions and 12 | // limitations under the License. 13 | 14 | package mysql 15 | 16 | import ( 17 | . "github.com/pingcap/check" 18 | ) 19 | 20 | var _ = Suite(&testTypeSuite{}) 21 | 22 | type testTypeSuite struct{} 23 | 24 | func (s *testTypeSuite) TestFlags(c *C) { 25 | c.Assert(HasNotNullFlag(NotNullFlag), IsTrue) 26 | c.Assert(HasUniKeyFlag(UniqueKeyFlag), IsTrue) 27 | c.Assert(HasNotNullFlag(NotNullFlag), IsTrue) 28 | c.Assert(HasNoDefaultValueFlag(NoDefaultValueFlag), IsTrue) 29 | c.Assert(HasAutoIncrementFlag(AutoIncrementFlag), IsTrue) 30 | c.Assert(HasUnsignedFlag(UnsignedFlag), IsTrue) 31 | c.Assert(HasZerofillFlag(ZerofillFlag), IsTrue) 32 | c.Assert(HasBinaryFlag(BinaryFlag), IsTrue) 33 | c.Assert(HasPriKeyFlag(PriKeyFlag), IsTrue) 34 | c.Assert(HasMultipleKeyFlag(MultipleKeyFlag), IsTrue) 35 | c.Assert(HasTimestampFlag(TimestampFlag), IsTrue) 36 | c.Assert(HasOnUpdateNowFlag(OnUpdateNowFlag), IsTrue) 37 | } 38 | -------------------------------------------------------------------------------- /mysql/util.go: -------------------------------------------------------------------------------- 1 | // Copyright 2015 PingCAP, Inc. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 
5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // See the License for the specific language governing permissions and 12 | // limitations under the License. 13 | 14 | package mysql 15 | 16 | type lengthAndDecimal struct { 17 | length int 18 | decimal int 19 | } 20 | 21 | // defaultLengthAndDecimal provides default Flen and Decimal for fields 22 | // from CREATE TABLE when they are unspecified. 23 | var defaultLengthAndDecimal = map[byte]lengthAndDecimal{ 24 | TypeBit: {1, 0}, 25 | TypeTiny: {4, 0}, 26 | TypeShort: {6, 0}, 27 | TypeInt24: {9, 0}, 28 | TypeLong: {11, 0}, 29 | TypeLonglong: {20, 0}, 30 | TypeDouble: {22, -1}, 31 | TypeFloat: {12, -1}, 32 | TypeNewDecimal: {10, 0}, 33 | TypeDuration: {10, 0}, 34 | TypeDate: {10, 0}, 35 | TypeTimestamp: {19, 0}, 36 | TypeDatetime: {19, 0}, 37 | TypeYear: {4, 0}, 38 | TypeString: {1, 0}, 39 | TypeVarchar: {5, 0}, 40 | TypeVarString: {5, 0}, 41 | TypeTinyBlob: {255, 0}, 42 | TypeBlob: {65535, 0}, 43 | TypeMediumBlob: {16777215, 0}, 44 | TypeLongBlob: {4294967295, 0}, 45 | TypeJSON: {4294967295, 0}, 46 | TypeNull: {0, 0}, 47 | TypeSet: {-1, 0}, 48 | TypeEnum: {-1, 0}, 49 | } 50 | 51 | // IsIntegerType indicate whether tp is an integer type. 52 | func IsIntegerType(tp byte) bool { 53 | switch tp { 54 | case TypeTiny, TypeShort, TypeInt24, TypeLong, TypeLonglong: 55 | return true 56 | } 57 | return false 58 | } 59 | 60 | // GetDefaultFieldLengthAndDecimal returns the default display length (flen) and decimal length for column. 61 | // Call this when no Flen assigned in ddl. 62 | // or column value is calculated from an expression. 63 | // For example: "select count(*) from t;", the column type is int64 and Flen in ResultField will be 21. 
64 | // See https://dev.mysql.com/doc/refman/5.7/en/storage-requirements.html 65 | func GetDefaultFieldLengthAndDecimal(tp byte) (flen int, decimal int) { 66 | val, ok := defaultLengthAndDecimal[tp] 67 | if ok { 68 | return val.length, val.decimal 69 | } 70 | return -1, -1 71 | } 72 | 73 | // defaultLengthAndDecimal provides default Flen and Decimal for fields 74 | // from CAST when they are unspecified. 75 | var defaultLengthAndDecimalForCast = map[byte]lengthAndDecimal{ 76 | TypeString: {0, -1}, // Flen & Decimal differs. 77 | TypeDate: {10, 0}, 78 | TypeDatetime: {19, 0}, 79 | TypeNewDecimal: {10, 0}, 80 | TypeDuration: {10, 0}, 81 | TypeLonglong: {22, 0}, 82 | TypeDouble: {22, -1}, 83 | TypeFloat: {12, -1}, 84 | TypeJSON: {4194304, 0}, // Flen differs. 85 | } 86 | 87 | // GetDefaultFieldLengthAndDecimalForCast returns the default display length (flen) and decimal length for casted column 88 | // when flen or decimal is not specified. 89 | func GetDefaultFieldLengthAndDecimalForCast(tp byte) (flen int, decimal int) { 90 | val, ok := defaultLengthAndDecimalForCast[tp] 91 | if ok { 92 | return val.length, val.decimal 93 | } 94 | return -1, -1 95 | } 96 | -------------------------------------------------------------------------------- /opcode/opcode.go: -------------------------------------------------------------------------------- 1 | // Copyright 2015 PingCAP, Inc. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // See the License for the specific language governing permissions and 12 | // limitations under the License. 
13 | 14 | package opcode 15 | 16 | import ( 17 | "io" 18 | 19 | "github.com/pingcap/parser/format" 20 | ) 21 | 22 | // Op is opcode type. 23 | type Op int 24 | 25 | // List operators. 26 | const ( 27 | LogicAnd Op = iota + 1 28 | LeftShift 29 | RightShift 30 | LogicOr 31 | GE 32 | LE 33 | EQ 34 | NE 35 | LT 36 | GT 37 | Plus 38 | Minus 39 | And 40 | Or 41 | Mod 42 | Xor 43 | Div 44 | Mul 45 | Not 46 | Not2 47 | BitNeg 48 | IntDiv 49 | LogicXor 50 | NullEQ 51 | In 52 | Like 53 | Case 54 | Regexp 55 | IsNull 56 | IsTruth 57 | IsFalsity 58 | ) 59 | 60 | var ops = [...]struct { 61 | name string 62 | literal string 63 | isKeyword bool 64 | }{ 65 | LogicAnd: { 66 | name: "and", 67 | literal: "AND", 68 | isKeyword: true, 69 | }, 70 | LogicOr: { 71 | name: "or", 72 | literal: "OR", 73 | isKeyword: true, 74 | }, 75 | LogicXor: { 76 | name: "xor", 77 | literal: "XOR", 78 | isKeyword: true, 79 | }, 80 | LeftShift: { 81 | name: "leftshift", 82 | literal: "<<", 83 | isKeyword: false, 84 | }, 85 | RightShift: { 86 | name: "rightshift", 87 | literal: ">>", 88 | isKeyword: false, 89 | }, 90 | GE: { 91 | name: "ge", 92 | literal: ">=", 93 | isKeyword: false, 94 | }, 95 | LE: { 96 | name: "le", 97 | literal: "<=", 98 | isKeyword: false, 99 | }, 100 | EQ: { 101 | name: "eq", 102 | literal: "=", 103 | isKeyword: false, 104 | }, 105 | NE: { 106 | name: "ne", 107 | literal: "!=", // perhaps should use `<>` here 108 | isKeyword: false, 109 | }, 110 | LT: { 111 | name: "lt", 112 | literal: "<", 113 | isKeyword: false, 114 | }, 115 | GT: { 116 | name: "gt", 117 | literal: ">", 118 | isKeyword: false, 119 | }, 120 | Plus: { 121 | name: "plus", 122 | literal: "+", 123 | isKeyword: false, 124 | }, 125 | Minus: { 126 | name: "minus", 127 | literal: "-", 128 | isKeyword: false, 129 | }, 130 | And: { 131 | name: "bitand", 132 | literal: "&", 133 | isKeyword: false, 134 | }, 135 | Or: { 136 | name: "bitor", 137 | literal: "|", 138 | isKeyword: false, 139 | }, 140 | Mod: { 141 | name: "mod", 142 | 
literal: "%", 143 | isKeyword: false, 144 | }, 145 | Xor: { 146 | name: "bitxor", 147 | literal: "^", 148 | isKeyword: false, 149 | }, 150 | Div: { 151 | name: "div", 152 | literal: "/", 153 | isKeyword: false, 154 | }, 155 | Mul: { 156 | name: "mul", 157 | literal: "*", 158 | isKeyword: false, 159 | }, 160 | Not: { 161 | name: "not", 162 | literal: "not ", 163 | isKeyword: true, 164 | }, 165 | Not2: { 166 | name: "!", 167 | literal: "!", 168 | isKeyword: false, 169 | }, 170 | BitNeg: { 171 | name: "bitneg", 172 | literal: "~", 173 | isKeyword: false, 174 | }, 175 | IntDiv: { 176 | name: "intdiv", 177 | literal: "DIV", 178 | isKeyword: true, 179 | }, 180 | NullEQ: { 181 | name: "nulleq", 182 | literal: "<=>", 183 | isKeyword: false, 184 | }, 185 | In: { 186 | name: "in", 187 | literal: "IN", 188 | isKeyword: true, 189 | }, 190 | Like: { 191 | name: "like", 192 | literal: "LIKE", 193 | isKeyword: true, 194 | }, 195 | Case: { 196 | name: "case", 197 | literal: "CASE", 198 | isKeyword: true, 199 | }, 200 | Regexp: { 201 | name: "regexp", 202 | literal: "REGEXP", 203 | isKeyword: true, 204 | }, 205 | IsNull: { 206 | name: "isnull", 207 | literal: "IS NULL", 208 | isKeyword: true, 209 | }, 210 | IsTruth: { 211 | name: "istrue", 212 | literal: "IS TRUE", 213 | isKeyword: true, 214 | }, 215 | IsFalsity: { 216 | name: "isfalse", 217 | literal: "IS FALSE", 218 | isKeyword: true, 219 | }, 220 | } 221 | 222 | // String implements Stringer interface. 223 | func (o Op) String() string { 224 | return ops[o].name 225 | } 226 | 227 | // Format the ExprNode into a Writer. 228 | func (o Op) Format(w io.Writer) { 229 | io.WriteString(w, ops[o].literal) 230 | } 231 | 232 | // IsKeyword returns whether the operator is a keyword. 
233 | func (o Op) IsKeyword() bool { 234 | return ops[o].isKeyword 235 | } 236 | 237 | // Restore the Op into a Writer 238 | func (o Op) Restore(ctx *format.RestoreCtx) error { 239 | info := &ops[o] 240 | if info.isKeyword { 241 | ctx.WriteKeyWord(info.literal) 242 | } else { 243 | ctx.WritePlain(info.literal) 244 | } 245 | return nil 246 | } 247 | -------------------------------------------------------------------------------- /opcode/opcode_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2015 PingCAP, Inc. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // See the License for the specific language governing permissions and 12 | // limitations under the License. 13 | 14 | package opcode 15 | 16 | import ( 17 | "bytes" 18 | "testing" 19 | ) 20 | 21 | func TestT(t *testing.T) { 22 | op := Plus 23 | if op.String() != "plus" { 24 | t.Fatalf("invalid op code") 25 | } 26 | 27 | var buf bytes.Buffer 28 | for i := range ops { 29 | op := Op(i) 30 | op.Format(&buf) 31 | if buf.String() != ops[op].literal { 32 | t.Error("format op fail", op) 33 | } 34 | buf.Reset() 35 | } 36 | 37 | // Test invalid opcode 38 | defer func() { 39 | recover() 40 | }() 41 | 42 | op = 0 43 | s := op.String() 44 | if len(s) > 0 { 45 | t.Fail() 46 | } 47 | } 48 | -------------------------------------------------------------------------------- /reserved_words_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2020 PingCAP, Inc. 
2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // See the License for the specific language governing permissions and 12 | // limitations under the License. 13 | 14 | //+build reserved_words_test 15 | 16 | // This file ensures that the set of reserved keywords is the same as that of 17 | // MySQL. To run: 18 | // 19 | // 1. Set up a MySQL server listening at 127.0.0.1:3306 using root and no password 20 | // 2. Run this test with: 21 | // 22 | // go test -tags reserved_words_test -check.f TestReservedWords 23 | package parser 24 | 25 | import ( 26 | dbsql "database/sql" 27 | 28 | // needed to connect to MySQL 29 | _ "github.com/go-sql-driver/mysql" 30 | . "github.com/pingcap/check" 31 | 32 | "github.com/pingcap/parser/ast" 33 | ) 34 | 35 | func (s *testConsistentSuite) TestCompareReservedWordsWithMySQL(c *C) { 36 | p := New() 37 | db, err := dbsql.Open("mysql", "root@tcp(127.0.0.1:3306)/") 38 | c.Assert(err, IsNil) 39 | defer db.Close() 40 | 41 | for _, kw := range s.reservedKeywords { 42 | switch kw { 43 | case "CURRENT_ROLE": 44 | // special case: we do reserve CURRENT_ROLE but MySQL didn't, 45 | // and unreservering it causes legit parser conflict. 46 | continue 47 | } 48 | 49 | query := "do (select 1 as " + kw + ")" 50 | errRegexp := ".*" + kw + ".*" 51 | 52 | var err error 53 | 54 | if _, ok := windowFuncTokenMap[kw]; !ok { 55 | // for some reason the query does parse even then the keyword is reserved in TiDB. 
56 | _, _, err = p.Parse(query, "", "") 57 | c.Assert(err, ErrorMatches, errRegexp) 58 | } 59 | _, err = db.Exec(query) 60 | c.Assert(err, ErrorMatches, errRegexp, Commentf("MySQL suggests that '%s' should *not* be reserved!", kw)) 61 | } 62 | 63 | for _, kws := range [][]string{s.unreservedKeywords, s.notKeywordTokens, s.tidbKeywords} { 64 | for _, kw := range kws { 65 | switch kw { 66 | case "FUNCTION", // reserved in 8.0.1 67 | "SEPARATOR": // ? 68 | continue 69 | } 70 | 71 | query := "do (select 1 as " + kw + ")" 72 | 73 | stmts, _, err := p.Parse(query, "", "") 74 | c.Assert(err, IsNil) 75 | c.Assert(stmts, HasLen, 1) 76 | c.Assert(stmts[0], FitsTypeOf, &ast.DoStmt{}) 77 | 78 | _, err = db.Exec(query) 79 | c.Assert(err, IsNil, Commentf("MySQL suggests that '%s' should be reserved!", kw)) 80 | } 81 | } 82 | } 83 | -------------------------------------------------------------------------------- /terror/terror.go: -------------------------------------------------------------------------------- 1 | // Copyright 2015 PingCAP, Inc. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // See the License for the specific language governing permissions and 12 | // limitations under the License. 13 | 14 | package terror 15 | 16 | import ( 17 | "fmt" 18 | "strconv" 19 | "strings" 20 | "sync" 21 | "sync/atomic" 22 | 23 | "github.com/pingcap/errors" 24 | "github.com/pingcap/log" 25 | "github.com/pingcap/parser/mysql" 26 | "go.uber.org/zap" 27 | ) 28 | 29 | // ErrCode represents a specific error type in a error class. 30 | // Same error code can be used in different error classes. 
31 | type ErrCode int 32 | 33 | const ( 34 | // Executor error codes. 35 | 36 | // CodeUnknown is for errors of unknown reason. 37 | CodeUnknown ErrCode = -1 38 | // CodeExecResultIsEmpty indicates execution result is empty. 39 | CodeExecResultIsEmpty ErrCode = 3 40 | 41 | // Expression error codes. 42 | 43 | // CodeMissConnectionID indicates connection id is missing. 44 | CodeMissConnectionID ErrCode = 1 45 | 46 | // Special error codes. 47 | 48 | // CodeResultUndetermined indicates the sql execution result is undetermined. 49 | CodeResultUndetermined ErrCode = 2 50 | ) 51 | 52 | // ErrClass represents a class of errors. 53 | type ErrClass int 54 | 55 | type Error = errors.Error 56 | 57 | // Error classes. 58 | var ( 59 | ClassAutoid = RegisterErrorClass(1, "autoid") 60 | ClassDDL = RegisterErrorClass(2, "ddl") 61 | ClassDomain = RegisterErrorClass(3, "domain") 62 | ClassEvaluator = RegisterErrorClass(4, "evaluator") 63 | ClassExecutor = RegisterErrorClass(5, "executor") 64 | ClassExpression = RegisterErrorClass(6, "expression") 65 | ClassAdmin = RegisterErrorClass(7, "admin") 66 | ClassKV = RegisterErrorClass(8, "kv") 67 | ClassMeta = RegisterErrorClass(9, "meta") 68 | ClassOptimizer = RegisterErrorClass(10, "planner") 69 | ClassParser = RegisterErrorClass(11, "parser") 70 | ClassPerfSchema = RegisterErrorClass(12, "perfschema") 71 | ClassPrivilege = RegisterErrorClass(13, "privilege") 72 | ClassSchema = RegisterErrorClass(14, "schema") 73 | ClassServer = RegisterErrorClass(15, "server") 74 | ClassStructure = RegisterErrorClass(16, "structure") 75 | ClassVariable = RegisterErrorClass(17, "variable") 76 | ClassXEval = RegisterErrorClass(18, "xeval") 77 | ClassTable = RegisterErrorClass(19, "table") 78 | ClassTypes = RegisterErrorClass(20, "types") 79 | ClassGlobal = RegisterErrorClass(21, "global") 80 | ClassMockTikv = RegisterErrorClass(22, "mocktikv") 81 | ClassJSON = RegisterErrorClass(23, "json") 82 | ClassTiKV = RegisterErrorClass(24, "tikv") 83 | 
ClassSession = RegisterErrorClass(25, "session") 84 | ClassPlugin = RegisterErrorClass(26, "plugin") 85 | ClassUtil = RegisterErrorClass(27, "util") 86 | // Add more as needed. 87 | ) 88 | 89 | var errClass2Desc = make(map[ErrClass]string) 90 | var rfcCode2errClass = newCode2ErrClassMap() 91 | 92 | type code2ErrClassMap struct { 93 | data sync.Map 94 | } 95 | 96 | func newCode2ErrClassMap() *code2ErrClassMap { 97 | return &code2ErrClassMap{ 98 | data: sync.Map{}, 99 | } 100 | } 101 | 102 | func (m *code2ErrClassMap) Get(key string) (ErrClass, bool) { 103 | ret, have := m.data.Load(key) 104 | return ret.(ErrClass), have 105 | } 106 | 107 | func (m *code2ErrClassMap) Put(key string, err ErrClass) { 108 | m.data.Store(key, err) 109 | } 110 | 111 | var registerFinish uint32 112 | 113 | // RegisterFinish makes the register of new error panic. 114 | // The use pattern should be register all the errors during initialization, and then call RegisterFinish. 115 | func RegisterFinish() { 116 | atomic.StoreUint32(®isterFinish, 1) 117 | } 118 | 119 | func frozen() bool { 120 | return atomic.LoadUint32(®isterFinish) != 0 121 | } 122 | 123 | // RegisterErrorClass registers new error class for terror. 124 | func RegisterErrorClass(classCode int, desc string) ErrClass { 125 | errClass := ErrClass(classCode) 126 | if _, exists := errClass2Desc[errClass]; exists { 127 | panic(fmt.Sprintf("duplicate register ClassCode %d - %s", classCode, desc)) 128 | } 129 | errClass2Desc[errClass] = desc 130 | return errClass 131 | } 132 | 133 | // String implements fmt.Stringer interface. 134 | func (ec ErrClass) String() string { 135 | if s, exists := errClass2Desc[ec]; exists { 136 | return s 137 | } 138 | return strconv.Itoa(int(ec)) 139 | } 140 | 141 | // EqualClass returns true if err is *Error with the same class. 
142 | func (ec ErrClass) EqualClass(err error) bool { 143 | e := errors.Cause(err) 144 | if e == nil { 145 | return false 146 | } 147 | if te, ok := e.(*Error); ok { 148 | rfcCode := te.RFCCode() 149 | if index := strings.Index(string(rfcCode), ":"); index > 0 { 150 | if class, has := rfcCode2errClass.Get(string(rfcCode)[:index]); has { 151 | return class == ec 152 | } 153 | } 154 | } 155 | return false 156 | } 157 | 158 | // NotEqualClass returns true if err is not *Error with the same class. 159 | func (ec ErrClass) NotEqualClass(err error) bool { 160 | return !ec.EqualClass(err) 161 | } 162 | 163 | func (ec ErrClass) initError(code ErrCode) string { 164 | if frozen() { 165 | panic("register error after initialized is prohibited") 166 | } 167 | clsMap, ok := ErrClassToMySQLCodes[ec] 168 | if !ok { 169 | clsMap = make(map[ErrCode]struct{}) 170 | ErrClassToMySQLCodes[ec] = clsMap 171 | } 172 | clsMap[code] = struct{}{} 173 | class := errClass2Desc[ec] 174 | rfcCode := fmt.Sprintf("%s:%d", class, code) 175 | rfcCode2errClass.Put(class, ec) 176 | return rfcCode 177 | } 178 | 179 | // New defines an *Error with an error code and an error message. 180 | // Usually used to create base *Error. 181 | // Attention: 182 | // this method is not goroutine-safe and 183 | // usually be used in global variable initializer 184 | // 185 | // Deprecated: use NewStd or NewStdErr instead. 186 | func (ec ErrClass) New(code ErrCode, message string) *Error { 187 | rfcCode := ec.initError(code) 188 | err := errors.Normalize(message, errors.MySQLErrorCode(int(code)), errors.RFCCodeText(rfcCode)) 189 | return err 190 | } 191 | 192 | // NewStdErr defines an *Error with an error code, an error 193 | // message and workaround to create standard error. 
194 | func (ec ErrClass) NewStdErr(code ErrCode, message *mysql.ErrMessage) *Error { 195 | rfcCode := ec.initError(code) 196 | err := errors.Normalize(message.Raw, errors.RedactArgs(message.RedactArgPos), errors.MySQLErrorCode(int(code)), errors.RFCCodeText(rfcCode)) 197 | return err 198 | } 199 | 200 | // NewStd calls New using the standard message for the error code 201 | // Attention: 202 | // this method is not goroutine-safe and 203 | // usually be used in global variable initializer 204 | func (ec ErrClass) NewStd(code ErrCode) *Error { 205 | return ec.NewStdErr(code, mysql.MySQLErrName[uint16(code)]) 206 | } 207 | 208 | // Synthesize synthesizes an *Error in the air 209 | // it didn't register error into ErrClassToMySQLCodes 210 | // so it's goroutine-safe 211 | // and often be used to create Error came from other systems like TiKV. 212 | func (ec ErrClass) Synthesize(code ErrCode, message string) *Error { 213 | return errors.Normalize(message, errors.MySQLErrorCode(int(code)), errors.RFCCodeText(fmt.Sprintf("%s:%d", errClass2Desc[ec], code))) 214 | } 215 | 216 | // ToSQLError convert Error to mysql.SQLError. 
217 | func ToSQLError(e *Error) *mysql.SQLError { 218 | code := getMySQLErrorCode(e) 219 | return mysql.NewErrf(code, "%s", nil, e.GetMsg()) 220 | } 221 | 222 | var defaultMySQLErrorCode uint16 223 | 224 | func getMySQLErrorCode(e *Error) uint16 { 225 | rfcCode := e.RFCCode() 226 | var class ErrClass 227 | if index := strings.Index(string(rfcCode), ":"); index > 0 { 228 | if ec, has := rfcCode2errClass.Get(string(rfcCode)[:index]); has { 229 | class = ec 230 | } else { 231 | log.Warn("Unknown error class", zap.String("class", string(rfcCode)[:index])) 232 | return defaultMySQLErrorCode 233 | } 234 | } 235 | codeMap, ok := ErrClassToMySQLCodes[class] 236 | if !ok { 237 | log.Warn("Unknown error class", zap.Int("class", int(class))) 238 | return defaultMySQLErrorCode 239 | } 240 | _, ok = codeMap[ErrCode(e.Code())] 241 | if !ok { 242 | log.Debug("Unknown error code", zap.Int("class", int(class)), zap.Int("code", int(e.Code()))) 243 | return defaultMySQLErrorCode 244 | } 245 | return uint16(e.Code()) 246 | } 247 | 248 | var ( 249 | // ErrClassToMySQLCodes is the map of ErrClass to code-set. 250 | ErrClassToMySQLCodes = make(map[ErrClass]map[ErrCode]struct{}) 251 | ErrCritical = ClassGlobal.NewStdErr(CodeExecResultIsEmpty, mysql.Message("critical error %v", nil)) 252 | ErrResultUndetermined = ClassGlobal.NewStdErr(CodeResultUndetermined, mysql.Message("execution result undetermined", nil)) 253 | ) 254 | 255 | func init() { 256 | defaultMySQLErrorCode = mysql.ErrUnknown 257 | } 258 | 259 | // ErrorEqual returns a boolean indicating whether err1 is equal to err2. 
260 | func ErrorEqual(err1, err2 error) bool { 261 | e1 := errors.Cause(err1) 262 | e2 := errors.Cause(err2) 263 | 264 | if e1 == e2 { 265 | return true 266 | } 267 | 268 | if e1 == nil || e2 == nil { 269 | return e1 == e2 270 | } 271 | 272 | te1, ok1 := e1.(*Error) 273 | te2, ok2 := e2.(*Error) 274 | if ok1 && ok2 { 275 | return te1.RFCCode() == te2.RFCCode() 276 | } 277 | 278 | return e1.Error() == e2.Error() 279 | } 280 | 281 | // ErrorNotEqual returns a boolean indicating whether err1 isn't equal to err2. 282 | func ErrorNotEqual(err1, err2 error) bool { 283 | return !ErrorEqual(err1, err2) 284 | } 285 | 286 | // MustNil cleans up and fatals if err is not nil. 287 | func MustNil(err error, closeFuns ...func()) { 288 | if err != nil { 289 | for _, f := range closeFuns { 290 | f() 291 | } 292 | log.Fatal("unexpected error", zap.Error(err), zap.Stack("stack")) 293 | } 294 | } 295 | 296 | // Call executes a function and checks the returned err. 297 | func Call(fn func() error) { 298 | err := fn() 299 | if err != nil { 300 | log.Error("function call errored", zap.Error(err), zap.Stack("stack")) 301 | } 302 | } 303 | 304 | // Log logs the error if it is not nil. 305 | func Log(err error) { 306 | if err != nil { 307 | log.Error("encountered error", zap.Error(err), zap.Stack("stack")) 308 | } 309 | } 310 | 311 | func GetErrClass(e *Error) ErrClass { 312 | rfcCode := e.RFCCode() 313 | if index := strings.Index(string(rfcCode), ":"); index > 0 { 314 | if class, has := rfcCode2errClass.Get(string(rfcCode)[:index]); has { 315 | return class 316 | } 317 | } 318 | return ErrClass(-1) 319 | } 320 | -------------------------------------------------------------------------------- /terror/terror_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2015 PingCAP, Inc. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// See the License for the specific language governing permissions and
// limitations under the License.

package terror

import (
	"encoding/json"
	"fmt"
	"os"
	"runtime"
	"strings"
	"testing"

	. "github.com/pingcap/check"
	"github.com/pingcap/errors"
)

// TestT hooks this package's gocheck suites into `go test` (via TestingT),
// with CustomVerboseFlag enabled.
func TestT(t *testing.T) {
	CustomVerboseFlag = true
	TestingT(t)
}

var _ = Suite(&testTErrorSuite{})

type testTErrorSuite struct {
}

// TestErrCode pins the numeric values of two exported error codes.
func (s *testTErrorSuite) TestErrCode(c *C) {
	c.Assert(CodeMissConnectionID, Equals, ErrCode(1))
	c.Assert(CodeResultUndetermined, Equals, ErrCode(2))
}

// TestTError covers class descriptions, EqualClass/NotEqualClass, *Error
// equality across GenWithStack/FastGen, and conversion via ToSQLError.
func (s *testTErrorSuite) TestTError(c *C) {
	c.Assert(ClassParser.String(), Not(Equals), "")
	c.Assert(ClassOptimizer.String(), Not(Equals), "")
	c.Assert(ClassKV.String(), Not(Equals), "")
	c.Assert(ClassServer.String(), Not(Equals), "")

	parserErr := ClassParser.New(ErrCode(100), "error 100")
	c.Assert(parserErr.Error(), Not(Equals), "")
	c.Assert(ClassParser.EqualClass(parserErr), IsTrue)
	c.Assert(ClassParser.NotEqualClass(parserErr), IsFalse)

	c.Assert(ClassOptimizer.EqualClass(parserErr), IsFalse)
	optimizerErr := ClassOptimizer.New(ErrCode(2), "abc")
	c.Assert(ClassOptimizer.EqualClass(errors.New("abc")), IsFalse)
	c.Assert(ClassOptimizer.EqualClass(nil), IsFalse)
	c.Assert(optimizerErr.Equal(optimizerErr.GenWithStack("def")), IsTrue)
	c.Assert(optimizerErr.Equal(nil), IsFalse)
	c.Assert(optimizerErr.Equal(errors.New("abc")), IsFalse)

	// Test case for FastGen.
	c.Assert(optimizerErr.Equal(optimizerErr.FastGen("def")), IsTrue)
	c.Assert(optimizerErr.Equal(optimizerErr.FastGen("def: %s", "def")), IsTrue)
	kvErr := ClassKV.New(1062, "key already exist")
	e := kvErr.FastGen("Duplicate entry '%d' for key 'PRIMARY'", 1)
	c.Assert(e.Error(), Equals, "[kv:1062]Duplicate entry '1' for key 'PRIMARY'")
	// 1062 was registered for ClassKV by New above, so ToSQLError keeps it.
	sqlErr := ToSQLError(errors.Cause(e).(*Error))
	c.Assert(sqlErr.Message, Equals, "Duplicate entry '1' for key 'PRIMARY'")
	c.Assert(sqlErr.Code, Equals, uint16(1062))

	err := errors.Trace(ErrCritical.GenWithStackByArgs("test"))
	c.Assert(ErrCritical.Equal(err), IsTrue)

	err = errors.Trace(ErrCritical)
	c.Assert(ErrCritical.Equal(err), IsTrue)
}

// TestJson checks that an *Error survives a JSON marshal/unmarshal round trip.
func (s *testTErrorSuite) TestJson(c *C) {
	prevTErr := errors.Normalize("json test", errors.MySQLErrorCode(int(CodeExecResultIsEmpty)))
	buf, err := json.Marshal(prevTErr)
	c.Assert(err, IsNil)
	var curTErr errors.Error
	err = json.Unmarshal(buf, &curTErr)
	c.Assert(err, IsNil)
	isEqual := prevTErr.Equal(&curTErr)
	c.Assert(isEqual, IsTrue)
}

var predefinedErr = ClassExecutor.New(ErrCode(123), "predefiend error")

// example and call form a fixed call chain that TestTraceAndLocation inspects.
// NOTE(review): the expected stack line count (15) presumably depends on this
// exact call depth; do not restructure without updating the assertion.
func example() error {
	err := call()
	return errors.Trace(err)
}

func call() error {
	return predefinedErr.GenWithStack("error message:%s", "abc")
}

// TestTraceAndLocation checks the size of the recorded stack and that it
// mentions this test file. Lines containing GOROOT are discounted two per
// match (presumably a frame's function line plus its file line) so that
// runtime frames do not affect the count.
func (s *testTErrorSuite) TestTraceAndLocation(c *C) {
	err := example()
	stack := errors.ErrorStack(err)
	lines := strings.Split(stack, "\n")
	goroot := strings.ReplaceAll(runtime.GOROOT(), string(os.PathSeparator), "/")
	var sysStack = 0
	for _, line := range lines {
		if strings.Contains(line, goroot) {
			sysStack++
		}
	}
	c.Assert(len(lines)-(2*sysStack), Equals, 15, Commentf("stack =\n%s", stack))
	var containTerr bool
	for _, v := range lines {
		if strings.Contains(v, "terror_test.go") {
			containTerr = true
			break
		}
	}
	c.Assert(containTerr, IsTrue)
}

// TestErrorEqual checks that ErrorEqual compares plain errors by message after
// unwrapping with errors.Cause, and *Error values by RFC code (so the same
// numeric code in different classes, or different codes in one class, differ).
func (s *testTErrorSuite) TestErrorEqual(c *C) {
	e1 := errors.New("test error")
	c.Assert(e1, NotNil)

	e2 := errors.Trace(e1)
	c.Assert(e2, NotNil)

	e3 := errors.Trace(e2)
	c.Assert(e3, NotNil)

	c.Assert(errors.Cause(e2), Equals, e1)
	c.Assert(errors.Cause(e3), Equals, e1)
	c.Assert(errors.Cause(e2), Equals, errors.Cause(e3))

	e4 := errors.New("test error")
	c.Assert(errors.Cause(e4), Not(Equals), e1)

	e5 := errors.Errorf("test error")
	c.Assert(errors.Cause(e5), Not(Equals), e1)

	c.Assert(ErrorEqual(e1, e2), IsTrue)
	c.Assert(ErrorEqual(e1, e3), IsTrue)
	c.Assert(ErrorEqual(e1, e4), IsTrue)
	c.Assert(ErrorEqual(e1, e5), IsTrue)

	var e6 error

	c.Assert(ErrorEqual(nil, nil), IsTrue)
	c.Assert(ErrorNotEqual(e1, e6), IsTrue)
	code1 := ErrCode(9001)
	code2 := ErrCode(9002)
	te1 := ClassParser.Synthesize(code1, "abc")
	te3 := ClassKV.New(code1, "abc")
	te4 := ClassKV.New(code2, "abc")
	c.Assert(ErrorEqual(te1, te3), IsFalse)
	c.Assert(ErrorEqual(te3, te4), IsFalse)
}

// TestLog only checks that Log does not panic on a non-nil error.
func (s *testTErrorSuite) TestLog(c *C) {
	err := fmt.Errorf("xxx")
	Log(err)
}

--------------------------------------------------------------------------------
/test.sh:
--------------------------------------------------------------------------------
#!/bin/sh

# If 'check.TestingT' is not used in any of the *_test.go files in a subdir no tests will run.

# For every top-level directory containing a file that imports pingcap/check,
# fail unless some file under it calls TestingT.
for f in $(git grep -l 'github.com/pingcap/check' | grep '/' | cut -d/ -f1 | uniq)
do
  if ! grep -r TestingT "$f" > /dev/null
  then
    echo "check.TestingT missing from $f"
    exit 1
  fi
done

# Run all packages serially (-p 1) with the race detector and merged coverage.
GO111MODULE=on go test -p 1 -race -covermode=atomic -coverprofile=coverage.txt -coverpkg=./... ./...

--------------------------------------------------------------------------------
/test_driver/test_driver.go:
--------------------------------------------------------------------------------
// Copyright 2019 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// See the License for the specific language governing permissions and
// limitations under the License.

//+build !codes

package test_driver

import (
	"fmt"
	"io"
	"strconv"

	"github.com/pingcap/parser/ast"
	"github.com/pingcap/parser/charset"
	"github.com/pingcap/parser/format"
	"github.com/pingcap/parser/mysql"
)

// init installs this package's constructors into the ast package's hook
// variables (NewValueExpr, NewParamMarkerExpr, NewDecimal, NewHexLiteral,
// NewBitLiteral), so ast builds values from this driver once it is imported.
func init() {
	ast.NewValueExpr = newValueExpr
	ast.NewParamMarkerExpr = newParamMarkerExpr
	ast.NewDecimal = func(str string) (interface{}, error) {
		dec := new(MyDecimal)
		err := dec.FromString([]byte(str))
		return dec, err
	}
	ast.NewHexLiteral = func(str string) (interface{}, error) {
		h, err := NewHexLiteral(str)
		return h, err
	}
	ast.NewBitLiteral = func(str string) (interface{}, error) {
		b, err := NewBitLiteral(str)
		return b, err
	}
}

// Compile-time checks that the driver types satisfy the ast interfaces.
var (
	_ ast.ParamMarkerExpr = &ParamMarkerExpr{}
	_ ast.ValueExpr       = &ValueExpr{}
)

// ValueExpr is the simple value expression.
53 | type ValueExpr struct { 54 | ast.TexprNode 55 | Datum 56 | projectionOffset int 57 | } 58 | 59 | // Restore implements Node interface. 60 | func (n *ValueExpr) Restore(ctx *format.RestoreCtx) error { 61 | switch n.Kind() { 62 | case KindNull: 63 | ctx.WriteKeyWord("NULL") 64 | case KindInt64: 65 | if n.Type.Flag&mysql.IsBooleanFlag != 0 { 66 | if n.GetInt64() > 0 { 67 | ctx.WriteKeyWord("TRUE") 68 | } else { 69 | ctx.WriteKeyWord("FALSE") 70 | } 71 | } else { 72 | ctx.WritePlain(strconv.FormatInt(n.GetInt64(), 10)) 73 | } 74 | case KindUint64: 75 | ctx.WritePlain(strconv.FormatUint(n.GetUint64(), 10)) 76 | case KindFloat32: 77 | ctx.WritePlain(strconv.FormatFloat(n.GetFloat64(), 'e', -1, 32)) 78 | case KindFloat64: 79 | ctx.WritePlain(strconv.FormatFloat(n.GetFloat64(), 'e', -1, 64)) 80 | case KindString: 81 | if n.Type.Charset != "" { 82 | ctx.WritePlain("_") 83 | ctx.WriteKeyWord(n.Type.Charset) 84 | } 85 | ctx.WriteString(n.GetString()) 86 | case KindBytes: 87 | ctx.WriteString(n.GetString()) 88 | case KindMysqlDecimal: 89 | ctx.WritePlain(n.GetMysqlDecimal().String()) 90 | case KindBinaryLiteral: 91 | if n.Type.Charset != "" && n.Type.Charset != mysql.DefaultCharset && 92 | n.Type.Charset != charset.CharsetBin { 93 | ctx.WritePlain("_") 94 | ctx.WriteKeyWord(n.Type.Charset + " ") 95 | } 96 | if n.Type.Flag&mysql.UnsignedFlag != 0 { 97 | ctx.WritePlainf("x'%x'", n.GetBytes()) 98 | } else { 99 | ctx.WritePlain(n.GetBinaryLiteral().ToBitLiteralString(true)) 100 | } 101 | case KindMysqlDuration, KindMysqlEnum, 102 | KindMysqlBit, KindMysqlSet, KindMysqlTime, 103 | KindInterface, KindMinNotNull, KindMaxValue, 104 | KindRaw, KindMysqlJSON: 105 | // TODO implement Restore function 106 | return fmt.Errorf("not implemented") 107 | default: 108 | return fmt.Errorf("can't format to string") 109 | } 110 | return nil 111 | } 112 | 113 | // GetDatumString implements the ValueExpr interface. 
114 | func (n *ValueExpr) GetDatumString() string { 115 | return n.GetString() 116 | } 117 | 118 | // Format the ExprNode into a Writer. 119 | func (n *ValueExpr) Format(w io.Writer) { 120 | var s string 121 | switch n.Kind() { 122 | case KindNull: 123 | s = "NULL" 124 | case KindInt64: 125 | if n.Type.Flag&mysql.IsBooleanFlag != 0 { 126 | if n.GetInt64() > 0 { 127 | s = "TRUE" 128 | } else { 129 | s = "FALSE" 130 | } 131 | } else { 132 | s = strconv.FormatInt(n.GetInt64(), 10) 133 | } 134 | case KindUint64: 135 | s = strconv.FormatUint(n.GetUint64(), 10) 136 | case KindFloat32: 137 | s = strconv.FormatFloat(n.GetFloat64(), 'e', -1, 32) 138 | case KindFloat64: 139 | s = strconv.FormatFloat(n.GetFloat64(), 'e', -1, 64) 140 | case KindString, KindBytes: 141 | s = strconv.Quote(n.GetString()) 142 | case KindMysqlDecimal: 143 | s = n.GetMysqlDecimal().String() 144 | case KindBinaryLiteral: 145 | if n.Type.Flag&mysql.UnsignedFlag != 0 { 146 | s = fmt.Sprintf("x'%x'", n.GetBytes()) 147 | } else { 148 | s = n.GetBinaryLiteral().ToBitLiteralString(true) 149 | } 150 | default: 151 | panic("Can't format to string") 152 | } 153 | _, _ = fmt.Fprint(w, s) 154 | } 155 | 156 | // newValueExpr creates a ValueExpr with value, and sets default field type. 157 | func newValueExpr(value interface{}, charset string, collate string) ast.ValueExpr { 158 | if ve, ok := value.(*ValueExpr); ok { 159 | return ve 160 | } 161 | ve := &ValueExpr{} 162 | ve.SetValue(value) 163 | DefaultTypeForValue(value, &ve.Type, charset, collate) 164 | ve.projectionOffset = -1 165 | return ve 166 | } 167 | 168 | // SetProjectionOffset sets ValueExpr.projectionOffset for logical plan builder. 169 | func (n *ValueExpr) SetProjectionOffset(offset int) { 170 | n.projectionOffset = offset 171 | } 172 | 173 | // GetProjectionOffset returns ValueExpr.projectionOffset. 174 | func (n *ValueExpr) GetProjectionOffset() int { 175 | return n.projectionOffset 176 | } 177 | 178 | // Accept implements Node interface. 
179 | func (n *ValueExpr) Accept(v ast.Visitor) (ast.Node, bool) { 180 | newNode, skipChildren := v.Enter(n) 181 | if skipChildren { 182 | return v.Leave(newNode) 183 | } 184 | n = newNode.(*ValueExpr) 185 | return v.Leave(n) 186 | } 187 | 188 | // ParamMarkerExpr is the `?` placeholder expression; it holds a place for another expression. 189 | // It is used when parsing prepared statements. 190 | type ParamMarkerExpr struct { 191 | ValueExpr 192 | Offset int 193 | Order int 194 | InExecute bool 195 | } 196 | 197 | // Restore implements Node interface; a parameter marker is always written as `?`. 198 | func (n *ParamMarkerExpr) Restore(ctx *format.RestoreCtx) error { 199 | ctx.WritePlain("?") 200 | return nil 201 | } 202 | // newParamMarkerExpr is installed as ast.NewParamMarkerExpr by init. NOTE(review): unlike newValueExpr it leaves projectionOffset at 0 instead of -1 — confirm that is intended. 203 | func newParamMarkerExpr(offset int) ast.ParamMarkerExpr { 204 | return &ParamMarkerExpr{ 205 | Offset: offset, 206 | } 207 | } 208 | 209 | // Format is not supported for parameter markers; it always panics. 210 | func (n *ParamMarkerExpr) Format(w io.Writer) { 211 | panic("Not implemented") 212 | } 213 | 214 | // Accept implements Node Accept interface; a ParamMarkerExpr has no children to visit. 215 | func (n *ParamMarkerExpr) Accept(v ast.Visitor) (ast.Node, bool) { 216 | newNode, skipChildren := v.Enter(n) 217 | if skipChildren { 218 | return v.Leave(newNode) 219 | } 220 | n = newNode.(*ParamMarkerExpr) 221 | return v.Leave(n) 222 | } 223 | 224 | // SetOrder implements the ParamMarkerExpr interface by storing order in n.Order. 225 | func (n *ParamMarkerExpr) SetOrder(order int) { 226 | n.Order = order 227 | } 228 | -------------------------------------------------------------------------------- /test_driver/test_driver_helper.go: -------------------------------------------------------------------------------- 1 | // Copyright 2019 PingCAP, Inc. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License.
5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // See the License for the specific language governing permissions and 12 | // limitations under the License. 13 | 14 | //+build !codes 15 | 16 | package test_driver 17 | 18 | import ( 19 | "math" 20 | ) 21 | // isSpace reports whether c is an ASCII space or horizontal tab; other whitespace such as '\n' is not matched. 22 | func isSpace(c byte) bool { 23 | return c == ' ' || c == '\t' 24 | } 25 | // isDigit reports whether c is an ASCII decimal digit. 26 | func isDigit(c byte) bool { 27 | return c >= '0' && c <= '9' 28 | } 29 | // myMin returns the smaller of a and b. 30 | func myMin(a, b int) int { 31 | if a < b { 32 | return a 33 | } 34 | return b 35 | } 36 | // pow10 returns 10^x truncated to int32; x must be at most 9 for the result to fit, and any negative x yields 0. 37 | func pow10(x int) int32 { 38 | return int32(math.Pow10(x)) 39 | } 40 | // Abs returns the absolute value of n using a branch-free sign mask. NOTE: Abs(math.MinInt64) overflows and returns math.MinInt64; converting that to uint64 still yields 2^63, so StrLenOfInt64Fast stays correct for it. 41 | func Abs(n int64) int64 { 42 | y := n >> 63 43 | return (n ^ y) - y 44 | } 45 | 46 | // uintSizeTable[i] is the largest uint64 with i decimal digits (index 0 is a placeholder); comparing against it is faster than repeatedly dividing by 10. 47 | var uintSizeTable = [21]uint64{ 48 | 0, // placeholder so StrLenOfUint64Fast can start at index 1 and return i directly 49 | 9, 99, 999, 9999, 99999, 50 | 999999, 9999999, 99999999, 999999999, 9999999999, 51 | 99999999999, 999999999999, 9999999999999, 99999999999999, 999999999999999, 52 | 9999999999999999, 99999999999999999, 999999999999999999, 9999999999999999999, 53 | math.MaxUint64, 54 | } // math.MaxUint64 is 18446744073709551615 and it has 20 digits 55 | 56 | // StrLenOfUint64Fast returns the number of decimal digits needed to print x (1 for x == 0). 57 | func StrLenOfUint64Fast(x uint64) int { 58 | for i := 1; ; i++ { 59 | if x <= uintSizeTable[i] { 60 | return i 61 | } 62 | } 63 | } 64 | 65 | // StrLenOfInt64Fast returns the length of the base-10 string form of x, counting the '-' sign for negative values. 66 | func StrLenOfInt64Fast(x int64) int { 67 | size := 0 68 | if x < 0 { 69 | size = 1 // add "-" sign on the length count 70 | } 71 | return size +
StrLenOfUint64Fast(uint64(Abs(x))) 72 | } 73 | -------------------------------------------------------------------------------- /test_driver/test_driver_mydecimal.go: -------------------------------------------------------------------------------- 1 | // Copyright 2019 PingCAP, Inc. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // See the License for the specific language governing permissions and 12 | // limitations under the License. 13 | 14 | //+build !codes 15 | 16 | package test_driver 17 | 18 | const panicInfo = "This branch is not implemented. " + 19 | "This is because you are trying to test something specific to TiDB's MyDecimal implementation. " + 20 | "It is recommended to do this in TiDB repository." 21 | 22 | // Word-layout constants: a MyDecimal stores its digits in int32 words of digitsPerWord decimal digits each. 23 | const ( 24 | maxWordBufLen = 9 // A MyDecimal holds 9 words. 25 | digitsPerWord = 9 // A word holds 9 digits. 26 | digMask = 100000000 // 10^8: dividing a 9-digit word by it yields the word's most significant digit 27 | ) 28 | // wordBufLen bounds how many words a parsed value may use (see fixWordCntError). 29 | var ( 30 | wordBufLen = 9 31 | ) 32 | 33 | // fixWordCntError panics with panicInfo when wordsInt+wordsFrac exceeds wordBufLen, since this port has no overflow/truncation handling; otherwise it returns the counts unchanged with a nil error. 34 | func fixWordCntError(wordsInt, wordsFrac int) (newWordsInt int, newWordsFrac int, err error) { 35 | if wordsInt+wordsFrac > wordBufLen { 36 | panic(panicInfo) 37 | } 38 | return wordsInt, wordsFrac, nil 39 | } 40 | 41 | /* 42 | countLeadingZeroes returns the number of leading zero digits in word when it is viewed as an (i+1)-digit field; removeLeadingZeros uses it to trim the integer part.
43 | 44 | @param i start index 45 | @param word value to compare against list of powers of 10 46 | */ 47 | func countLeadingZeroes(i int, word int32) int { 48 | leading := 0 49 | for word < pow10(i) { 50 | i-- 51 | leading++ 52 | } 53 | return leading 54 | } 55 | 56 | func digitsToWords(digits int) int { 57 | return (digits + digitsPerWord - 1) / digitsPerWord 58 | } 59 | 60 | // MyDecimal represents a decimal value stored as int32 words of up to 9 decimal digits each. Only parsing and printing are implemented in this test driver; unsupported paths panic with panicInfo. 61 | type MyDecimal struct { 62 | digitsInt int8 // the number of *decimal* digits before the point. 63 | 64 | digitsFrac int8 // the number of decimal digits after the point. 65 | 66 | resultFrac int8 // result fraction digits. 67 | 68 | negative bool // sign flag; FromString clears it when every word is zero 69 | 70 | // wordBuf is an array of int32 words. 71 | // Each word holds up to 9 decimal digits (0 <= word < 10^9); integer words come first, then fraction words. 72 | wordBuf [maxWordBufLen]int32 73 | } 74 | 75 | // String returns ToString's output for a copy of d as a string; no rounding to resultFrac is performed. 76 | func (d *MyDecimal) String() string { 77 | tmp := *d 78 | return string(tmp.ToString()) 79 | } 80 | // stringSize returns the buffer size ToString allocates: all stored digits plus room for a sign, a leading zero and the decimal point. 81 | func (d *MyDecimal) stringSize() int { 82 | // sign, zero integer and dot. 83 | return int(d.digitsInt + d.digitsFrac + 3) 84 | } 85 | // removeLeadingZeros skips leading all-zero words and then leading zero digits of the integer part, returning the index of the first word to print and the number of significant integer digits (0 if the integer part is zero). 86 | func (d *MyDecimal) removeLeadingZeros() (wordIdx int, digitsInt int) { 87 | digitsInt = int(d.digitsInt) 88 | i := ((digitsInt - 1) % digitsPerWord) + 1 89 | for digitsInt > 0 && d.wordBuf[wordIdx] == 0 { 90 | digitsInt -= i 91 | i = digitsPerWord 92 | wordIdx++ 93 | } 94 | if digitsInt > 0 { 95 | digitsInt -= countLeadingZeroes((digitsInt-1)%digitsPerWord, d.wordBuf[wordIdx]) 96 | } else { 97 | digitsInt = 0 98 | } 99 | return 100 | } 101 | 102 | // ToString converts decimal to its printable string representation without rounding.
103 | // 104 | // RETURN VALUE 105 | // 106 | // str - result string 107 | // (no error code is returned; str is allocated from stringSize() and trimmed to the exact length) 108 | // 109 | func (d *MyDecimal) ToString() (str []byte) { 110 | str = make([]byte, d.stringSize()) 111 | digitsFrac := int(d.digitsFrac) 112 | wordStartIdx, digitsInt := d.removeLeadingZeros() 113 | if digitsInt+digitsFrac == 0 { // zero with no fraction digits: emit a single integer digit from word 0 114 | digitsInt = 1 115 | wordStartIdx = 0 116 | } 117 | 118 | digitsIntLen := digitsInt 119 | if digitsIntLen == 0 { 120 | digitsIntLen = 1 121 | } 122 | digitsFracLen := digitsFrac 123 | length := digitsIntLen + digitsFracLen 124 | if d.negative { 125 | length++ 126 | } 127 | if digitsFrac > 0 { 128 | length++ 129 | } 130 | str = str[:length] 131 | strIdx := 0 132 | if d.negative { 133 | str[strIdx] = '-' 134 | strIdx++ 135 | } 136 | var fill int 137 | if digitsFrac > 0 { 138 | fracIdx := strIdx + digitsIntLen 139 | fill = digitsFracLen - digitsFrac 140 | wordIdx := wordStartIdx + digitsToWords(digitsInt) 141 | str[fracIdx] = '.' 142 | fracIdx++ 143 | for ; digitsFrac > 0; digitsFrac -= digitsPerWord { // fraction words, most significant digit first 144 | x := d.wordBuf[wordIdx] 145 | wordIdx++ 146 | for i := myMin(digitsFrac, digitsPerWord); i > 0; i-- { 147 | y := x / digMask 148 | str[fracIdx] = byte(y) + '0' 149 | fracIdx++ 150 | x -= y * digMask 151 | x *= 10 152 | } 153 | } 154 | for ; fill > 0; fill-- { 155 | str[fracIdx] = '0' 156 | fracIdx++ 157 | } 158 | } 159 | fill = digitsIntLen - digitsInt 160 | if digitsInt == 0 { 161 | fill-- /* the '0' before the decimal point is written separately below */ 162 | } 163 | for ; fill > 0; fill-- { 164 | str[strIdx] = '0' 165 | strIdx++ 166 | } 167 | if digitsInt > 0 { 168 | strIdx += digitsInt // integer digits are written right-to-left, ending here 169 | wordIdx := wordStartIdx + digitsToWords(digitsInt) 170 | for ; digitsInt > 0; digitsInt -= digitsPerWord { 171 | wordIdx-- 172 | x := d.wordBuf[wordIdx] 173 | for i := myMin(digitsInt, digitsPerWord); i > 0; i-- { 174 | y := x / 10 175 | strIdx-- 176 | str[strIdx] = '0' + byte(x-y*10) 177 | x = y 178 | } 179 | } 180 | } else { 181 | str[strIdx] = '0'
182 | } 183 | return 184 | } 185 | 186 | // FromString parses a plain decimal literal into d: leading spaces/tabs, an optional sign, integer digits and an optional '.'-fraction. Characters after the number are ignored unless they start an exponent. Empty input, input without digits, values needing more than wordBufLen words, and exponents panic with panicInfo; the returned error is always nil. Fields are not reset first, so d should be zero-valued (ast.NewDecimal passes new(MyDecimal)). 187 | func (d *MyDecimal) FromString(str []byte) error { 188 | for i := 0; i < len(str); i++ { 189 | if !isSpace(str[i]) { 190 | str = str[i:] 191 | break 192 | } 193 | } 194 | if len(str) == 0 { 195 | panic(panicInfo) 196 | } 197 | switch str[0] { 198 | case '-': 199 | d.negative = true 200 | fallthrough 201 | case '+': 202 | str = str[1:] 203 | } 204 | var strIdx int 205 | for strIdx < len(str) && isDigit(str[strIdx]) { 206 | strIdx++ 207 | } 208 | digitsInt := strIdx 209 | var digitsFrac int 210 | var endIdx int 211 | if strIdx < len(str) && str[strIdx] == '.' { 212 | endIdx = strIdx + 1 213 | for endIdx < len(str) && isDigit(str[endIdx]) { 214 | endIdx++ 215 | } 216 | digitsFrac = endIdx - strIdx - 1 217 | } else { 218 | digitsFrac = 0 219 | endIdx = strIdx 220 | } 221 | if digitsInt+digitsFrac == 0 { 222 | panic(panicInfo) 223 | } 224 | wordsInt := digitsToWords(digitsInt) 225 | wordsFrac := digitsToWords(digitsFrac) 226 | wordsInt, _, err := fixWordCntError(wordsInt, wordsFrac) // panics instead of returning an error in this port 227 | if err != nil { 228 | panic(panicInfo) 229 | } 230 | d.digitsInt = int8(digitsInt) 231 | d.digitsFrac = int8(digitsFrac) 232 | wordIdx := wordsInt 233 | strIdxTmp := strIdx 234 | var word int32 235 | var innerIdx int 236 | for digitsInt > 0 { // integer digits, scanned right-to-left, fill words backwards from wordsInt-1 237 | digitsInt-- 238 | strIdx-- 239 | word += int32(str[strIdx]-'0') * pow10(innerIdx) 240 | innerIdx++ 241 | if innerIdx == digitsPerWord { 242 | wordIdx-- 243 | d.wordBuf[wordIdx] = word 244 | word = 0 245 | innerIdx = 0 246 | } 247 | } 248 | if innerIdx != 0 { 249 | wordIdx-- 250 | d.wordBuf[wordIdx] = word 251 | } 252 | 253 | wordIdx = wordsInt 254 | strIdx = strIdxTmp 255 | word = 0 256 | innerIdx = 0 257 | for digitsFrac > 0 { // fraction digits, scanned left-to-right, fill words forward from wordsInt 258 | digitsFrac-- 259 | strIdx++ 260 | word = int32(str[strIdx]-'0') + word*10 261 | innerIdx++ 262 | if innerIdx == digitsPerWord { 263 | d.wordBuf[wordIdx] = word 264 | wordIdx++ 265 | word = 0 266 | innerIdx = 0 267 | } 268 |
} 269 | if innerIdx != 0 { 270 | d.wordBuf[wordIdx] = word * pow10(digitsPerWord-innerIdx) 271 | } 272 | if endIdx+1 <= len(str) && (str[endIdx] == 'e' || str[endIdx] == 'E') { 273 | panic(panicInfo) 274 | } 275 | allZero := true 276 | for i := 0; i < wordBufLen; i++ { 277 | if d.wordBuf[i] != 0 { 278 | allZero = false 279 | break 280 | } 281 | } 282 | if allZero { 283 | d.negative = false 284 | } 285 | d.resultFrac = d.digitsFrac 286 | return err 287 | } 288 | -------------------------------------------------------------------------------- /tidb/features.go: -------------------------------------------------------------------------------- 1 | // Copyright 2021 PingCAP, Inc. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // See the License for the specific language governing permissions and 12 | // limitations under the License. 13 | 14 | package tidb 15 | // Feature identifiers. NOTE(review): FeatureIDTiDB (the empty string) is not listed in featureIDs, so CanParseFeature(FeatureIDTiDB) reports false — confirm callers handle the empty ID separately. 16 | const ( 17 | // FeatureIDTiDB represents the general TiDB-specific features. 18 | FeatureIDTiDB = "" 19 | // FeatureIDAutoRandom is the `auto_random` feature (ID "auto_rand"). 20 | FeatureIDAutoRandom = "auto_rand" 21 | // FeatureIDAutoIDCache is the `auto_id_cache` feature. 22 | FeatureIDAutoIDCache = "auto_id_cache" 23 | // FeatureIDAutoRandomBase is the `auto_random_base` feature (ID "auto_rand_base"). 24 | FeatureIDAutoRandomBase = "auto_rand_base" 25 | // FeatureIDClusteredIndex is the `clustered_index` feature. 26 | FeatureIDClusteredIndex = "clustered_index" 27 | // FeatureIDForceAutoInc is the `force auto_increment` feature (ID "force_inc"). 28 | FeatureIDForceAutoInc = "force_inc" 29 | // FeatureIDPlacement is the `placement rule` feature.
30 | FeatureIDPlacement = "placement" 31 | ) 32 | 33 | var featureIDs = map[string]struct{}{ 34 | FeatureIDAutoRandom: {}, 35 | FeatureIDAutoIDCache: {}, 36 | FeatureIDAutoRandomBase: {}, 37 | FeatureIDClusteredIndex: {}, 38 | FeatureIDForceAutoInc: {}, 39 | FeatureIDPlacement: {}, 40 | } 41 | 42 | func CanParseFeature(fs ...string) bool { 43 | for _, f := range fs { 44 | if _, ok := featureIDs[f]; !ok { 45 | return false 46 | } 47 | } 48 | return true 49 | } 50 | -------------------------------------------------------------------------------- /types/etc.go: -------------------------------------------------------------------------------- 1 | // Copyright 2014 The ql Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSES/QL-LICENSE file. 4 | 5 | // Copyright 2015 PingCAP, Inc. 6 | // 7 | // Licensed under the Apache License, Version 2.0 (the "License"); 8 | // you may not use this file except in compliance with the License. 9 | // You may obtain a copy of the License at 10 | // 11 | // http://www.apache.org/licenses/LICENSE-2.0 12 | // 13 | // Unless required by applicable law or agreed to in writing, software 14 | // distributed under the License is distributed on an "AS IS" BASIS, 15 | // See the License for the specific language governing permissions and 16 | // limitations under the License. 17 | 18 | package types 19 | 20 | import ( 21 | "strings" 22 | 23 | "github.com/pingcap/parser/mysql" 24 | "github.com/pingcap/parser/terror" 25 | ) 26 | 27 | // IsTypeBlob reports whether tp is one of the four BLOB/TEXT storage types (tiny, regular, medium or long). 28 | func IsTypeBlob(tp byte) bool { 29 | switch tp { 30 | case mysql.TypeTinyBlob, mysql.TypeMediumBlob, mysql.TypeBlob, mysql.TypeLongBlob: 31 | return true 32 | default: 33 | return false 34 | } 35 | } 36 | 37 | // IsTypeChar reports whether tp is a CHAR (mysql.TypeString) or VARCHAR 38 | // (mysql.TypeVarchar) type; mysql.TypeVarString is not included.
39 | func IsTypeChar(tp byte) bool { 40 | return tp == mysql.TypeString || tp == mysql.TypeVarchar 41 | } 42 | 43 | var type2Str = map[byte]string{ 44 | mysql.TypeBit: "bit", 45 | mysql.TypeBlob: "text", 46 | mysql.TypeDate: "date", 47 | mysql.TypeDatetime: "datetime", 48 | mysql.TypeUnspecified: "unspecified", 49 | mysql.TypeNewDecimal: "decimal", 50 | mysql.TypeDouble: "double", 51 | mysql.TypeEnum: "enum", 52 | mysql.TypeFloat: "float", 53 | mysql.TypeGeometry: "geometry", 54 | mysql.TypeInt24: "mediumint", 55 | mysql.TypeJSON: "json", 56 | mysql.TypeLong: "int", 57 | mysql.TypeLonglong: "bigint", 58 | mysql.TypeLongBlob: "longtext", 59 | mysql.TypeMediumBlob: "mediumtext", 60 | mysql.TypeNull: "null", 61 | mysql.TypeSet: "set", 62 | mysql.TypeShort: "smallint", 63 | mysql.TypeString: "char", 64 | mysql.TypeDuration: "time", 65 | mysql.TypeTimestamp: "timestamp", 66 | mysql.TypeTiny: "tinyint", 67 | mysql.TypeTinyBlob: "tinytext", 68 | mysql.TypeVarchar: "varchar", 69 | mysql.TypeVarString: "var_string", 70 | mysql.TypeYear: "year", 71 | } 72 | 73 | // TypeStr converts tp to a string. 74 | func TypeStr(tp byte) (r string) { 75 | return type2Str[tp] 76 | } 77 | 78 | // TypeToStr converts a field to a string. 79 | // It is used for converting Text to Blob, 80 | // or converting Char to Binary. 81 | // Args: 82 | // tp: type enum 83 | // cs: charset 84 | func TypeToStr(tp byte, cs string) (r string) { 85 | ts := type2Str[tp] 86 | if cs != "binary" { 87 | return ts 88 | } 89 | if IsTypeBlob(tp) { 90 | ts = strings.Replace(ts, "text", "blob", 1) 91 | } else if IsTypeChar(tp) { 92 | ts = strings.Replace(ts, "char", "binary", 1) 93 | } 94 | return ts 95 | } 96 | 97 | var ( 98 | dig2bytes = [10]int{0, 1, 1, 2, 2, 3, 3, 4, 4, 4} 99 | ) 100 | 101 | // constant values. 102 | const ( 103 | digitsPerWord = 9 // A word holds 9 digits. 104 | wordSize = 4 // A word is 4 bytes int32. 105 | ) 106 | 107 | // ErrInvalidDefault is returned when meet a invalid default value. 
108 | var ErrInvalidDefault = terror.ClassTypes.NewStd(mysql.ErrInvalidDefault) 109 | -------------------------------------------------------------------------------- /types/eval_type.go: -------------------------------------------------------------------------------- 1 | // Copyright 2017 PingCAP, Inc. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // See the License for the specific language governing permissions and 12 | // limitations under the License. 13 | 14 | package types 15 | 16 | // EvalType is the type that a built-in function's arguments and result are evaluated as. 17 | type EvalType byte 18 | // The EvalType values are assigned by iota in declaration order, starting at 0. 19 | const ( 20 | // ETInt represents type INT in evaluation. 21 | ETInt EvalType = iota 22 | // ETReal represents type REAL in evaluation. 23 | ETReal 24 | // ETDecimal represents type DECIMAL in evaluation. 25 | ETDecimal 26 | // ETString represents type STRING in evaluation. 27 | ETString 28 | // ETDatetime represents type DATETIME in evaluation. 29 | ETDatetime 30 | // ETTimestamp represents type TIMESTAMP in evaluation. 31 | ETTimestamp 32 | // ETDuration represents type DURATION in evaluation. 33 | ETDuration 34 | // ETJson represents type JSON in evaluation. 35 | ETJson 36 | ) 37 | 38 | // IsStringKind reports whether et is evaluated as a string-like kind: ETString, ETDatetime, ETTimestamp, ETDuration or ETJson.
39 | func (et EvalType) IsStringKind() bool { 40 | return et == ETString || et == ETDatetime || 41 | et == ETTimestamp || et == ETDuration || et == ETJson 42 | } 43 | -------------------------------------------------------------------------------- /types/field_type_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2019 PingCAP, Inc. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // See the License for the specific language governing permissions and 12 | // limitations under the License. 13 | 14 | package types_test 15 | // check and types are dot-imported so the check assertions and the field-type API can be used unqualified. 16 | import ( 17 | "fmt" 18 | "testing" 19 | 20 | . "github.com/pingcap/check" 21 | "github.com/pingcap/parser" 22 | "github.com/pingcap/parser/ast" 23 | "github.com/pingcap/parser/charset" 24 | "github.com/pingcap/parser/mysql" 25 | .
"github.com/pingcap/parser/types" 26 | 27 | // Blank-import the test driver so its init() installs the ast value-expression constructors used by the parser. 28 | _ "github.com/pingcap/parser/test_driver" 29 | ) 30 | // TestT runs the registered check suites under `go test`. 31 | func TestT(t *testing.T) { 32 | CustomVerboseFlag = true 33 | TestingT(t) 34 | } 35 | 36 | var _ = Suite(&testFieldTypeSuite{}) 37 | 38 | type testFieldTypeSuite struct { 39 | } 40 | // TestFieldType checks FieldType.String, InfoSchemaStr and HasCharset across column types. 41 | func (s *testFieldTypeSuite) TestFieldType(c *C) { 42 | ft := NewFieldType(mysql.TypeDuration) 43 | c.Assert(ft.Flen, Equals, UnspecifiedLength) 44 | c.Assert(ft.Decimal, Equals, UnspecifiedLength) 45 | ft.Decimal = 5 46 | c.Assert(ft.String(), Equals, "time(5)") 47 | c.Assert(HasCharset(ft), IsFalse) 48 | 49 | ft = NewFieldType(mysql.TypeLong) 50 | ft.Flen = 5 51 | ft.Flag = mysql.UnsignedFlag | mysql.ZerofillFlag 52 | c.Assert(ft.String(), Equals, "int(5) UNSIGNED ZEROFILL") 53 | c.Assert(ft.InfoSchemaStr(), Equals, "int(5) unsigned") 54 | c.Assert(HasCharset(ft), IsFalse) 55 | 56 | ft = NewFieldType(mysql.TypeFloat) 57 | ft.Flen = 12 // Default 58 | ft.Decimal = 3 // Not Default 59 | c.Assert(ft.String(), Equals, "float(12,3)") 60 | ft = NewFieldType(mysql.TypeFloat) 61 | ft.Flen = 12 // Default 62 | ft.Decimal = -1 // Default 63 | c.Assert(ft.String(), Equals, "float") 64 | ft = NewFieldType(mysql.TypeFloat) 65 | ft.Flen = 5 // Not Default 66 | ft.Decimal = -1 // Default 67 | c.Assert(ft.String(), Equals, "float") // a non-default Flen is not printed when Decimal is -1 68 | ft = NewFieldType(mysql.TypeFloat) 69 | ft.Flen = 7 // Not Default 70 | ft.Decimal = 3 // Not Default 71 | c.Assert(ft.String(), Equals, "float(7,3)") 72 | c.Assert(HasCharset(ft), IsFalse) 73 | 74 | ft = NewFieldType(mysql.TypeDouble) 75 | ft.Flen = 22 // Default 76 | ft.Decimal = 3 // Not Default 77 | c.Assert(ft.String(), Equals, "double(22,3)") 78 | ft = 
NewFieldType(mysql.TypeDouble) 79 | ft.Flen = 22 // Default 80 | ft.Decimal = -1 // Default 81 | c.Assert(ft.String(), Equals, "double") 82 | ft = NewFieldType(mysql.TypeDouble) 83 | ft.Flen = 5 // Not Default 84 | ft.Decimal = -1 // Default 85 | c.Assert(ft.String(), Equals,
"double") 86 | ft = NewFieldType(mysql.TypeDouble) 87 | ft.Flen = 7 // Not Default 88 | ft.Decimal = 3 // Not Default 89 | c.Assert(ft.String(), Equals, "double(7,3)") 90 | c.Assert(HasCharset(ft), IsFalse) 91 | 92 | ft = NewFieldType(mysql.TypeBlob) 93 | ft.Flen = 10 94 | ft.Charset = "UTF8" 95 | ft.Collate = "UTF8_UNICODE_GI" 96 | c.Assert(ft.String(), Equals, "text CHARACTER SET UTF8 COLLATE UTF8_UNICODE_GI") 97 | c.Assert(HasCharset(ft), IsTrue) 98 | 99 | ft = NewFieldType(mysql.TypeVarchar) 100 | ft.Flen = 10 101 | ft.Flag |= mysql.BinaryFlag 102 | c.Assert(ft.String(), Equals, "varchar(10) BINARY") 103 | c.Assert(HasCharset(ft), IsFalse) 104 | 105 | ft = NewFieldType(mysql.TypeString) 106 | ft.Charset = charset.CollationBin // NOTE(review): charset.CharsetBin is presumably intended here; both are likely "binary" — confirm 107 | ft.Flag |= mysql.BinaryFlag 108 | c.Assert(ft.String(), Equals, "binary(1)") 109 | c.Assert(HasCharset(ft), IsFalse) 110 | 111 | ft = NewFieldType(mysql.TypeEnum) 112 | ft.Elems = []string{"a", "b"} 113 | c.Assert(ft.String(), Equals, "enum('a','b')") 114 | c.Assert(HasCharset(ft), IsTrue) 115 | 116 | ft = NewFieldType(mysql.TypeEnum) 117 | ft.Elems = []string{"'a'", "'b'"} 118 | c.Assert(ft.String(), Equals, "enum('''a''','''b''')") 119 | c.Assert(HasCharset(ft), IsTrue) 120 | 121 | ft = NewFieldType(mysql.TypeEnum) 122 | ft.Elems = []string{"a\nb", "a\tb", "a\rb"} // '\n' and '\r' are escaped in the output; '\t' is kept as-is 123 | c.Assert(ft.String(), Equals, "enum('a\\nb','a\tb','a\\rb')") 124 | c.Assert(HasCharset(ft), IsTrue) 125 | 126 | ft = NewFieldType(mysql.TypeEnum) 127 | ft.Elems = []string{"a\nb", "a'\t\r\nb", "a\rb"} 128 | c.Assert(ft.String(), Equals, "enum('a\\nb','a'' \\r\\nb','a\\rb')") 129 | c.Assert(HasCharset(ft), IsTrue) 130 | 131 | ft = NewFieldType(mysql.TypeSet) 132 | ft.Elems = []string{"a", "b"} 133 | c.Assert(ft.String(), Equals, "set('a','b')") 134 | c.Assert(HasCharset(ft), IsTrue) 135 | 136 | ft = NewFieldType(mysql.TypeSet) 137 | ft.Elems = []string{"'a'", "'b'"} 138 | 
c.Assert(ft.String(), Equals, "set('''a''','''b''')") 139 | c.Assert(HasCharset(ft), IsTrue)
140 | 141 | ft = NewFieldType(mysql.TypeSet) 142 | ft.Elems = []string{"a\nb", "a'\t\r\nb", "a\rb"} 143 | c.Assert(ft.String(), Equals, "set('a\\nb','a'' \\r\\nb','a\\rb')") 144 | c.Assert(HasCharset(ft), IsTrue) 145 | 146 | ft = NewFieldType(mysql.TypeSet) 147 | ft.Elems = []string{"a'\nb", "a'b\tc"} 148 | c.Assert(ft.String(), Equals, "set('a''\\nb','a''b c')") 149 | c.Assert(HasCharset(ft), IsTrue) 150 | 151 | ft = NewFieldType(mysql.TypeTimestamp) 152 | ft.Flen = 8 153 | ft.Decimal = 2 154 | c.Assert(ft.String(), Equals, "timestamp(2)") 155 | c.Assert(HasCharset(ft), IsFalse) 156 | ft = NewFieldType(mysql.TypeTimestamp) 157 | ft.Flen = 8 158 | ft.Decimal = 0 159 | c.Assert(ft.String(), Equals, "timestamp") 160 | c.Assert(HasCharset(ft), IsFalse) 161 | 162 | ft = NewFieldType(mysql.TypeDatetime) 163 | ft.Flen = 8 164 | ft.Decimal = 2 165 | c.Assert(ft.String(), Equals, "datetime(2)") 166 | c.Assert(HasCharset(ft), IsFalse) 167 | ft = NewFieldType(mysql.TypeDatetime) 168 | ft.Flen = 8 169 | ft.Decimal = 0 170 | c.Assert(ft.String(), Equals, "datetime") 171 | c.Assert(HasCharset(ft), IsFalse) 172 | 173 | ft = NewFieldType(mysql.TypeDate) 174 | ft.Flen = 8 175 | ft.Decimal = 2 176 | c.Assert(ft.String(), Equals, "date") 177 | c.Assert(HasCharset(ft), IsFalse) 178 | ft = NewFieldType(mysql.TypeDate) 179 | ft.Flen = 8 180 | ft.Decimal = 0 181 | c.Assert(ft.String(), Equals, "date") 182 | c.Assert(HasCharset(ft), IsFalse) 183 | 184 | ft = NewFieldType(mysql.TypeYear) 185 | ft.Flen = 4 186 | ft.Decimal = 0 187 | c.Assert(ft.String(), Equals, "year(4)") 188 | c.Assert(HasCharset(ft), IsFalse) 189 | ft = NewFieldType(mysql.TypeYear) 190 | ft.Flen = 2 191 | ft.Decimal = 2 192 | c.Assert(ft.String(), Equals, "year(2)") // Note: invalid year width; String() still prints Flen as given.
193 | c.Assert(HasCharset(ft), IsFalse) 194 | 195 | ft = NewFieldType(mysql.TypeVarchar) 196 | ft.Flen = 0 197 | ft.Decimal = 0 198 | c.Assert(ft.String(), Equals, "varchar(0)") 199 | c.Assert(HasCharset(ft), IsTrue) 200 | 201 | ft = NewFieldType(mysql.TypeString) 202 | ft.Flen = 0 203 | ft.Decimal = 0 204 | c.Assert(ft.String(), Equals, "char(0)") 205 | c.Assert(HasCharset(ft), IsTrue) 206 | } 207 | 208 | func (s *testFieldTypeSuite) TestHasCharsetFromStmt(c *C) { 209 | template := "CREATE TABLE t(a %s)" 210 | 211 | types := []struct { 212 | strType string 213 | hasCharset bool 214 | }{ 215 | {"int", false}, 216 | {"real", false}, 217 | {"float", false}, 218 | {"bit", false}, 219 | {"bool", false}, 220 | {"char(1)", true}, 221 | {"national char(1)", true}, 222 | {"binary", false}, 223 | {"varchar(1)", true}, 224 | {"national varchar(1)", true}, 225 | {"varbinary(1)", false}, 226 | {"year", false}, 227 | {"date", false}, 228 | {"time", false}, 229 | {"datetime", false}, 230 | {"timestamp", false}, 231 | {"blob", false}, 232 | {"tinyblob", false}, 233 | {"mediumblob", false}, 234 | {"longblob", false}, 235 | {"bit", false}, 236 | {"text", true}, 237 | {"tinytext", true}, 238 | {"mediumtext", true}, 239 | {"longtext", true}, 240 | {"json", false}, 241 | {"enum('1')", true}, 242 | {"set('1')", true}, 243 | } 244 | 245 | p := parser.New() 246 | for _, t := range types { 247 | sql := fmt.Sprintf(template, t.strType) 248 | stmt, err := p.ParseOneStmt(sql, "", "") 249 | c.Assert(err, IsNil) 250 | 251 | col := stmt.(*ast.CreateTableStmt).Cols[0] 252 | c.Assert(HasCharset(col.Tp), Equals, t.hasCharset) 253 | } 254 | } 255 | 256 | func (s *testFieldTypeSuite) TestEnumSetFlen(c *C) { 257 | p := parser.New() 258 | cases := []struct { 259 | sql string 260 | ex int 261 | }{ 262 | {"enum('a')", 1}, 263 | {"enum('a', 'b')", 1}, 264 | {"enum('a', 'bb')", 2}, 265 | {"enum('a', 'b', 'c')", 1}, 266 | {"enum('a', 'bb', 'c')", 2}, 267 | {"enum('a', 'bb', 'c')", 2}, 268 | {"enum('')",
0}, 269 | {"enum('a', '')", 1}, 270 | {"set('a')", 1}, 271 | {"set('a', 'b')", 3}, 272 | {"set('a', 'bb')", 4}, 273 | {"set('a', 'b', 'c')", 5}, 274 | {"set('a', 'bb', 'c')", 6}, 275 | {"set('')", 0}, 276 | {"set('a', '')", 2}, 277 | } 278 | 279 | for _, ca := range cases { 280 | stmt, err := p.ParseOneStmt(fmt.Sprintf("create table t (e %v)", ca.sql), "", "") 281 | c.Assert(err, IsNil) 282 | col := stmt.(*ast.CreateTableStmt).Cols[0] 283 | c.Assert(col.Tp.Flen, Equals, ca.ex) 284 | 285 | } 286 | } 287 | 288 | func (s *testFieldTypeSuite) TestFieldTypeEqual(c *C) { 289 | 290 | // Tp not equal 291 | ft1 := NewFieldType(mysql.TypeDouble) 292 | ft2 := NewFieldType(mysql.TypeFloat) 293 | c.Assert(ft1.Equal(ft2), Equals, false) 294 | 295 | // Decimal not equal 296 | ft2 = NewFieldType(mysql.TypeDouble) 297 | ft2.Decimal = 5 298 | c.Assert(ft1.Equal(ft2), Equals, false) 299 | 300 | // Flen not equal and Decimal not -1 301 | ft1.Decimal = 5 302 | ft1.Flen = 22 303 | c.Assert(ft1.Equal(ft2), Equals, false) 304 | 305 | // Flen equal 306 | ft2.Flen = 22 307 | c.Assert(ft1.Equal(ft2), Equals, true) 308 | 309 | // With Decimal == -1 on both sides, a differing Flen no longer makes the types unequal 310 | ft1.Decimal = -1 311 | ft2.Decimal = -1 312 | ft1.Flen = 23 313 | c.Assert(ft1.Equal(ft2), Equals, true) 314 | } 315 | --------------------------------------------------------------------------------