├── .gitignore ├── .travis.yml ├── LICENSE ├── README.md ├── driver ├── client.go ├── client_test.go ├── mock.go ├── rows.go ├── rows_test.go ├── server.go ├── server_test.go ├── session.go ├── session_test.go ├── statement.go └── statement_test.go ├── examples ├── client.go └── mysqld.go ├── go.mod ├── go.sum ├── makefile ├── packet ├── error.go ├── mock.go ├── packets.go ├── packets_test.go ├── stream.go └── stream_test.go ├── proto ├── auth.go ├── auth_test.go ├── column.go ├── column_test.go ├── const.go ├── eof.go ├── eof_test.go ├── err.go ├── err_test.go ├── greeting.go ├── greeting_test.go ├── ok.go ├── ok_test.go ├── statement.go └── statement_test.go ├── sqldb ├── constants.go ├── constants_test.go ├── sql_error.go └── sql_error_test.go ├── sqlparser ├── .idea │ ├── misc.xml │ ├── modules.xml │ ├── sqlparser.iml │ ├── vcs.xml │ └── workspace.xml ├── Makefile ├── analyzer.go ├── analyzer_test.go ├── ast.go ├── ast_funcs.go ├── ast_test.go ├── checksum_test.go ├── comments.go ├── comments_test.go ├── constants.go ├── ddl_test.go ├── depends │ ├── bytes2 │ │ ├── buffer.go │ │ └── buffer_test.go │ ├── common │ │ ├── buffer.go │ │ ├── buffer_test.go │ │ ├── hash_table.go │ │ ├── hash_table_test.go │ │ ├── unsafe.go │ │ └── unsafe_test.go │ ├── query │ │ └── query.pb.go │ └── sqltypes │ │ ├── aggregation.go │ │ ├── aggregation_test.go │ │ ├── arithmetic.go │ │ ├── arithmetic_test.go │ │ ├── bind_variables.go │ │ ├── bind_variables_test.go │ │ ├── column.go │ │ ├── column_test.go │ │ ├── const.go │ │ ├── limit.go │ │ ├── limit_test.go │ │ ├── plan_value.go │ │ ├── result.go │ │ ├── result_test.go │ │ ├── row.go │ │ ├── time.go │ │ ├── time_test.go │ │ ├── type.go │ │ ├── type_test.go │ │ ├── value.go │ │ └── value_test.go ├── encodable.go ├── encodable_test.go ├── explain_test.go ├── impossible_query.go ├── impossible_query_test.go ├── kill_test.go ├── normalizer.go ├── normalizer_test.go ├── parse_test.go ├── parsed_query.go ├── parsed_query_test.go ├── 
parser.go ├── precedence_test.go ├── radon_test.go ├── rewriter.go ├── rewriter_api.go ├── select_test.go ├── set_test.go ├── show_test.go ├── sql.go ├── sql.y ├── token.go ├── token_test.go ├── tracked_buffer.go ├── tracked_buffer_test.go ├── txn_test.go ├── visitorgen │ ├── ast_walker.go │ ├── ast_walker_test.go │ ├── main │ │ └── main.go │ ├── sast.go │ ├── struct_producer.go │ ├── struct_producer_test.go │ ├── transformer.go │ ├── transformer_test.go │ ├── visitor_emitter.go │ ├── visitor_emitter_test.go │ └── visitorgen.go └── xa_test.go └── xlog ├── options.go ├── syslog.go ├── syslog_windows.go ├── xlog.go └── xlog_test.go /.gitignore: -------------------------------------------------------------------------------- 1 | tags 2 | bin/* 3 | *.output 4 | coverage.* 5 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: go 2 | sudo: required 3 | go: 4 | - 1.x 5 | 6 | before_install: 7 | - go get github.com/shopspring/decimal 8 | - go get github.com/pierrre/gotestcover 9 | - go get github.com/stretchr/testify/assert 10 | 11 | script: 12 | - make test 13 | - make coverage 14 | 15 | after_success: 16 | # send coverage reports to Codecov 17 | - bash <(curl -s https://codecov.io/bash) -f "!mock.go" 18 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2021, xelabs 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | 1. Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | 2. 
Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | 3. Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 30 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [![Build Status](https://travis-ci.org/xelabs/go-mysqlstack.png)](https://travis-ci.org/xelabs/go-mysqlstack) [![Go Report Card](https://goreportcard.com/badge/github.com/xelabs/go-mysqlstack)](https://goreportcard.com/report/github.com/xelabs/go-mysqlstack) [![codecov.io](https://codecov.io/gh/xelabs/go-mysqlstack/graphs/badge.svg)](https://codecov.io/gh/xelabs/go-mysqlstack/branch/master) 2 | 3 | # go-mysqlstack 4 | 5 | ***go-mysqlstack*** is a MySQL protocol library implemented in Go (golang).
6 | 7 | Protocol is based on [mysqlproto-go](https://github.com/pubnative/mysqlproto-go) and [go-sql-driver](https://github.com/go-sql-driver/mysql) 8 | 9 | ## Running Tests 10 | 11 | ``` 12 | $ mkdir src 13 | $ export GOPATH=`pwd` 14 | $ go get -u github.com/xelabs/go-mysqlstack/driver 15 | $ cd src/github.com/xelabs/go-mysqlstack/ 16 | $ make test 17 | ``` 18 | 19 | ## Examples 20 | 21 | 1. ***examples/mysqld.go*** mocks a MySQL server by running: 22 | 23 | ``` 24 | $ go run example/mysqld.go 25 | 2018/01/26 16:02:02.304376 mysqld.go:52: [INFO] mysqld.server.start.address[:4407] 26 | ``` 27 | 28 | 2. ***examples/client.go*** mocks a client and query from the mock MySQL server: 29 | 30 | ``` 31 | $ go run example/client.go 32 | 2018/01/26 16:06:10.779340 client.go:32: [INFO] results:[[[10 nice name]]] 33 | ``` 34 | 35 | ## Status 36 | 37 | go-mysqlstack is production ready. 38 | 39 | ## License 40 | 41 | go-mysqlstack is released under the BSD-3-Clause License. See [LICENSE](https://github.com/xelabs/go-mysqlstack/blob/master/LICENSE) 42 | -------------------------------------------------------------------------------- /driver/rows.go: -------------------------------------------------------------------------------- 1 | /* 2 | * go-mysqlstack 3 | * xelabs.org 4 | * 5 | * Copyright (c) XeLabs 6 | * GPL License 7 | * 8 | */ 9 | 10 | package driver 11 | 12 | import ( 13 | "errors" 14 | "fmt" 15 | 16 | "github.com/xelabs/go-mysqlstack/proto" 17 | 18 | "github.com/xelabs/go-mysqlstack/sqlparser/depends/common" 19 | querypb "github.com/xelabs/go-mysqlstack/sqlparser/depends/query" 20 | "github.com/xelabs/go-mysqlstack/sqlparser/depends/sqltypes" 21 | ) 22 | 23 | var _ Rows = &TextRows{} 24 | 25 | type RowMode int 26 | 27 | const ( 28 | TextRowMode RowMode = iota 29 | BinaryRowMode 30 | ) 31 | 32 | // Rows presents row cursor interface. 
33 | type Rows interface { 34 | Next() bool 35 | Close() error 36 | Datas() []byte 37 | Bytes() int 38 | RowsAffected() uint64 39 | LastInsertID() uint64 40 | LastError() error 41 | Fields() []*querypb.Field 42 | RowValues() ([]sqltypes.Value, error) 43 | } 44 | 45 | // BaseRows -- 46 | type BaseRows struct { 47 | c Conn 48 | end bool 49 | err error 50 | data []byte 51 | bytes int 52 | rowsAffected uint64 53 | insertID uint64 54 | buffer *common.Buffer 55 | fields []*querypb.Field 56 | } 57 | 58 | // TextRows presents row tuple. 59 | type TextRows struct { 60 | BaseRows 61 | } 62 | 63 | // BinaryRows presents binary row tuple. 64 | type BinaryRows struct { 65 | BaseRows 66 | } 67 | 68 | // Next implements the Rows interface. 69 | // http://dev.mysql.com/doc/internals/en/com-query-response.html#packet-ProtocolText::ResultsetRow 70 | func (r *BaseRows) Next() bool { 71 | defer func() { 72 | if r.err != nil { 73 | r.c.Cleanup() 74 | } 75 | }() 76 | 77 | if r.end { 78 | return false 79 | } 80 | 81 | // if fields count is 0 82 | // the packet is OK-Packet without Resultset. 83 | if len(r.fields) == 0 { 84 | r.end = true 85 | return false 86 | } 87 | 88 | if r.data, r.err = r.c.NextPacket(); r.err != nil { 89 | r.end = true 90 | return false 91 | } 92 | 93 | switch r.data[0] { 94 | case proto.EOF_PACKET: 95 | // This packet may be one of two kinds: 96 | // - an EOF packet, 97 | // - an OK packet with an EOF header if 98 | // sqldb.CLIENT_DEPRECATE_EOF is set. 99 | r.end = true 100 | return false 101 | 102 | case proto.ERR_PACKET: 103 | r.err = proto.UnPackERR(r.data) 104 | r.end = true 105 | return false 106 | } 107 | r.buffer.Reset(r.data) 108 | return true 109 | } 110 | 111 | // Close drain the rest packets and check the error. 112 | func (r *BaseRows) Close() error { 113 | for r.Next() { 114 | } 115 | return r.LastError() 116 | } 117 | 118 | // RowValues implements the Rows interface. 
119 | // https://dev.mysql.com/doc/internals/en/com-query-response.html#packet-ProtocolText::ResultsetRow 120 | func (r *BaseRows) RowValues() ([]sqltypes.Value, error) { 121 | if r.fields == nil { 122 | return nil, errors.New("rows.fields is NIL") 123 | } 124 | 125 | colNumber := len(r.fields) 126 | result := make([]sqltypes.Value, colNumber) 127 | for i := 0; i < colNumber; i++ { 128 | v, err := r.buffer.ReadLenEncodeBytes() 129 | if err != nil { 130 | r.c.Cleanup() 131 | return nil, err 132 | } 133 | 134 | if v != nil { 135 | r.bytes += len(v) 136 | result[i] = sqltypes.MakeTrusted(r.fields[i].Type, v) 137 | } 138 | } 139 | return result, nil 140 | } 141 | 142 | // Datas implements the Rows interface. 143 | func (r *BaseRows) Datas() []byte { 144 | return r.buffer.Datas() 145 | } 146 | 147 | // Fields implements the Rows interface. 148 | func (r *BaseRows) Fields() []*querypb.Field { 149 | return r.fields 150 | } 151 | 152 | // Bytes returns all the memory usage which read by this row cursor. 153 | func (r *BaseRows) Bytes() int { 154 | return r.bytes 155 | } 156 | 157 | // RowsAffected implements the Rows interface. 158 | func (r *BaseRows) RowsAffected() uint64 { 159 | return r.rowsAffected 160 | } 161 | 162 | // LastInsertID implements the Rows interface. 163 | func (r *BaseRows) LastInsertID() uint64 { 164 | return r.insertID 165 | } 166 | 167 | // LastError implements the Rows interface. 168 | func (r *BaseRows) LastError() error { 169 | return r.err 170 | } 171 | 172 | // NewTextRows creates TextRows. 173 | func NewTextRows(c Conn) *TextRows { 174 | textRows := &TextRows{} 175 | textRows.c = c 176 | textRows.buffer = common.NewBuffer(8) 177 | return textRows 178 | } 179 | 180 | // NewBinaryRows creates BinaryRows. 181 | func NewBinaryRows(c Conn) *BinaryRows { 182 | binaryRows := &BinaryRows{} 183 | binaryRows.c = c 184 | binaryRows.buffer = common.NewBuffer(8) 185 | return binaryRows 186 | } 187 | 188 | // RowValues implements the Rows interface. 
189 | // https://dev.mysql.com/doc/internals/en/binary-protocol-resultset-row.html 190 | func (r *BinaryRows) RowValues() ([]sqltypes.Value, error) { 191 | if r.fields == nil { 192 | return nil, errors.New("rows.fields is NIL") 193 | } 194 | 195 | header, err := r.buffer.ReadU8() 196 | if err != nil { 197 | return nil, err 198 | } 199 | if header != proto.OK_PACKET { 200 | return nil, fmt.Errorf("binary.rows.header.is.not.ok[%v]", header) 201 | } 202 | 203 | colCount := len(r.fields) 204 | // NULL-bitmap, [(column-count + 7 + 2) / 8 bytes] 205 | nullMask, err := r.buffer.ReadBytes(int((colCount + 7 + 2) / 8)) 206 | if err != nil { 207 | return nil, err 208 | } 209 | 210 | result := make([]sqltypes.Value, colCount) 211 | for i := 0; i < colCount; i++ { 212 | // Field is NULL 213 | // (byte >> bit-pos) % 2 == 1 214 | if ((nullMask[(i+2)>>3] >> uint((i+2)&7)) & 1) == 1 { 215 | result[i] = sqltypes.Value{} 216 | continue 217 | } 218 | 219 | v, err := sqltypes.ParseMySQLValues(r.buffer, r.fields[i].Type) 220 | if err != nil { 221 | r.c.Cleanup() 222 | return nil, err 223 | } 224 | 225 | if v != nil { 226 | val, err := sqltypes.BuildValue(v) 227 | if err != nil { 228 | r.c.Cleanup() 229 | return nil, err 230 | } 231 | r.bytes += val.Len() 232 | result[i] = val 233 | } else { 234 | result[i] = sqltypes.Value{} 235 | } 236 | } 237 | return result, nil 238 | } 239 | -------------------------------------------------------------------------------- /driver/rows_test.go: -------------------------------------------------------------------------------- 1 | /* 2 | * go-mysqlstack 3 | * xelabs.org 4 | * 5 | * Copyright (c) XeLabs 6 | * GPL License 7 | * 8 | */ 9 | 10 | package driver 11 | 12 | import ( 13 | "testing" 14 | 15 | "github.com/stretchr/testify/assert" 16 | 17 | querypb "github.com/xelabs/go-mysqlstack/sqlparser/depends/query" 18 | "github.com/xelabs/go-mysqlstack/sqlparser/depends/sqltypes" 19 | "github.com/xelabs/go-mysqlstack/xlog" 20 | ) 21 | 22 | func TestRows(t 
*testing.T) { 23 | result1 := &sqltypes.Result{ 24 | Fields: []*querypb.Field{ 25 | { 26 | Name: "id", 27 | Type: querypb.Type_INT32, 28 | }, 29 | { 30 | Name: "name", 31 | Type: querypb.Type_VARCHAR, 32 | }, 33 | }, 34 | Rows: [][]sqltypes.Value{ 35 | { 36 | sqltypes.MakeTrusted(querypb.Type_INT32, []byte("10")), 37 | sqltypes.MakeTrusted(querypb.Type_VARCHAR, []byte("nice name")), 38 | }, 39 | { 40 | sqltypes.MakeTrusted(querypb.Type_INT32, []byte("20")), 41 | sqltypes.NULL, 42 | }, 43 | }, 44 | } 45 | result2 := &sqltypes.Result{ 46 | RowsAffected: 123, 47 | InsertID: 123456789, 48 | } 49 | result3 := &sqltypes.Result{ 50 | Fields: []*querypb.Field{ 51 | { 52 | Name: "name", 53 | Type: querypb.Type_VARCHAR, 54 | }, 55 | }, 56 | Rows: [][]sqltypes.Value{ 57 | { 58 | sqltypes.NULL, 59 | }, 60 | }, 61 | } 62 | 63 | log := xlog.NewStdLog(xlog.Level(xlog.ERROR)) 64 | th := NewTestHandler(log) 65 | svr, err := MockMysqlServer(log, th) 66 | assert.Nil(t, err) 67 | defer svr.Close() 68 | address := svr.Addr() 69 | 70 | // query 71 | { 72 | client, err := NewConn("mock", "mock", address, "test", "") 73 | assert.Nil(t, err) 74 | defer client.Close() 75 | 76 | th.AddQuery("SELECT2", result2) 77 | rows, err := client.Query("SELECT2") 78 | assert.Nil(t, err) 79 | 80 | assert.Equal(t, uint64(123), rows.RowsAffected()) 81 | assert.Equal(t, uint64(123456789), rows.LastInsertID()) 82 | } 83 | 84 | // query 85 | { 86 | client, err := NewConn("mock", "mock", address, "test", "") 87 | assert.Nil(t, err) 88 | defer client.Close() 89 | 90 | th.AddQuery("SELECT1", result1) 91 | rows, err := client.Query("SELECT1") 92 | assert.Nil(t, err) 93 | assert.Equal(t, result1.Fields, rows.Fields()) 94 | for rows.Next() { 95 | _ = rows.Datas() 96 | _, _ = rows.RowValues() 97 | } 98 | 99 | want := 13 100 | got := int(rows.Bytes()) 101 | assert.Equal(t, want, got) 102 | } 103 | 104 | // query 105 | { 106 | client, err := NewConn("mock", "mock", address, "test", "") 107 | assert.Nil(t, err) 108 | 
defer client.Close() 109 | 110 | th.AddQuery("SELECT3", result3) 111 | rows, err := client.Query("SELECT3") 112 | assert.Nil(t, err) 113 | assert.Equal(t, result3.Fields, rows.Fields()) 114 | } 115 | } 116 | -------------------------------------------------------------------------------- /driver/session_test.go: -------------------------------------------------------------------------------- 1 | /* 2 | * go-mysqlstack 3 | * xelabs.org 4 | * 5 | * Copyright (c) XeLabs 6 | * GPL License 7 | * 8 | */ 9 | 10 | package driver 11 | 12 | import ( 13 | "testing" 14 | "time" 15 | 16 | "github.com/stretchr/testify/assert" 17 | "github.com/xelabs/go-mysqlstack/xlog" 18 | ) 19 | 20 | func TestSession(t *testing.T) { 21 | log := xlog.NewStdLog(xlog.Level(xlog.DEBUG)) 22 | th := NewTestHandler(log) 23 | svr, err := MockMysqlServer(log, th) 24 | assert.Nil(t, err) 25 | address := svr.Addr() 26 | 27 | // create session 1 28 | client, err := NewConn("mock", "mock", address, "test", "") 29 | assert.Nil(t, err) 30 | defer client.Close() 31 | 32 | var sessions []*Session 33 | for _, s := range th.ss { 34 | sessions = append(sessions, s.session) 35 | } 36 | 37 | { 38 | session1 := sessions[0] 39 | 40 | // Session ID. 41 | { 42 | log.Debug("--id:%v", session1.ID()) 43 | log.Debug("--addr:%v", session1.Addr()) 44 | log.Debug("--salt:%v", session1.Salt()) 45 | log.Debug("--scramble:%v", session1.Scramble()) 46 | } 47 | 48 | // schema. 49 | { 50 | want := "xx" 51 | session1.SetSchema(want) 52 | got := session1.Schema() 53 | assert.Equal(t, want, got) 54 | } 55 | 56 | // charset. 57 | { 58 | want := uint8(0x21) 59 | got := session1.Charset() 60 | assert.Equal(t, want, got) 61 | } 62 | 63 | // UpdateTime. 
64 | { 65 | want := time.Now() 66 | session1.updateLastQueryTime(want) 67 | got := session1.LastQueryTime() 68 | assert.Equal(t, want, got) 69 | } 70 | } 71 | } 72 | -------------------------------------------------------------------------------- /driver/statement.go: -------------------------------------------------------------------------------- 1 | /* 2 | * go-mysqlstack 3 | * xelabs.org 4 | * 5 | * Copyright (c) XeLabs 6 | * GPL License 7 | * 8 | */ 9 | 10 | package driver 11 | 12 | import ( 13 | "github.com/xelabs/go-mysqlstack/proto" 14 | "github.com/xelabs/go-mysqlstack/sqldb" 15 | querypb "github.com/xelabs/go-mysqlstack/sqlparser/depends/query" 16 | "github.com/xelabs/go-mysqlstack/sqlparser/depends/sqltypes" 17 | ) 18 | 19 | // Statement -- 20 | type Statement struct { 21 | conn *conn 22 | ID uint32 23 | ParamCount uint16 24 | PrepareStmt string 25 | ParamsType []int32 26 | ColumnNames []string 27 | BindVars map[string]*querypb.BindVariable 28 | } 29 | 30 | // ComStatementExecute -- statement execute write. 31 | func (s *Statement) ComStatementExecute(parameters []sqltypes.Value) error { 32 | var err error 33 | var datas []byte 34 | var iRows Rows 35 | 36 | if datas, err = proto.PackStatementExecute(s.ID, parameters); err != nil { 37 | return err 38 | } 39 | 40 | if iRows, err = s.conn.stmtQuery(sqldb.COM_STMT_EXECUTE, datas); err != nil { 41 | return err 42 | } 43 | for iRows.Next() { 44 | if _, err := iRows.RowValues(); err != nil { 45 | s.conn.Cleanup() 46 | return err 47 | } 48 | } 49 | // Drain the results and check last error. 50 | if err := iRows.Close(); err != nil { 51 | s.conn.Cleanup() 52 | return err 53 | } 54 | return nil 55 | } 56 | 57 | // ComStatementExecute -- statement execute write. 
58 | func (s *Statement) ComStatementQuery(parameters []sqltypes.Value) (*sqltypes.Result, error) { 59 | var err error 60 | var datas []byte 61 | var iRows Rows 62 | var qrRow []sqltypes.Value 63 | var qrRows [][]sqltypes.Value 64 | 65 | if datas, err = proto.PackStatementExecute(s.ID, parameters); err != nil { 66 | return nil, err 67 | } 68 | 69 | if iRows, err = s.conn.stmtQuery(sqldb.COM_STMT_EXECUTE, datas); err != nil { 70 | return nil, err 71 | } 72 | for iRows.Next() { 73 | if qrRow, err = iRows.RowValues(); err != nil { 74 | s.conn.Cleanup() 75 | return nil, err 76 | } 77 | if qrRow != nil { 78 | qrRows = append(qrRows, qrRow) 79 | } 80 | } 81 | // Drain the results and check last error. 82 | if err := iRows.Close(); err != nil { 83 | s.conn.Cleanup() 84 | return nil, err 85 | } 86 | 87 | rowsAffected := iRows.RowsAffected() 88 | if rowsAffected == 0 { 89 | rowsAffected = uint64(len(qrRows)) 90 | } 91 | qr := &sqltypes.Result{ 92 | Fields: iRows.Fields(), 93 | RowsAffected: rowsAffected, 94 | InsertID: iRows.LastInsertID(), 95 | Rows: qrRows, 96 | } 97 | return qr, err 98 | } 99 | 100 | // ComStatementReset -- reset the stmt. 101 | func (s *Statement) ComStatementReset() error { 102 | var data [4]byte 103 | 104 | // Add arg [32 bit] 105 | data[0] = byte(s.ID) 106 | data[1] = byte(s.ID >> 8) 107 | data[2] = byte(s.ID >> 16) 108 | data[3] = byte(s.ID >> 24) 109 | if err := s.conn.packets.WriteCommand(sqldb.COM_STMT_RESET, data[:]); err != nil { 110 | return err 111 | } 112 | return s.conn.packets.ReadOK() 113 | } 114 | 115 | // ComStatementClose -- close the stmt. 
116 | func (s *Statement) ComStatementClose() error { 117 | var data [4]byte 118 | 119 | // Add arg [32 bit] 120 | data[0] = byte(s.ID) 121 | data[1] = byte(s.ID >> 8) 122 | data[2] = byte(s.ID >> 16) 123 | data[3] = byte(s.ID >> 24) 124 | if err := s.conn.packets.WriteCommand(sqldb.COM_STMT_CLOSE, data[:]); err != nil { 125 | return err 126 | } 127 | return nil 128 | } 129 | -------------------------------------------------------------------------------- /driver/statement_test.go: -------------------------------------------------------------------------------- 1 | /* 2 | * go-mysqlstack 3 | * xelabs.org 4 | * 5 | * Copyright (c) XeLabs 6 | * GPL License 7 | * 8 | */ 9 | 10 | package driver 11 | 12 | import ( 13 | "testing" 14 | "time" 15 | 16 | "github.com/stretchr/testify/assert" 17 | "github.com/xelabs/go-mysqlstack/xlog" 18 | 19 | querypb "github.com/xelabs/go-mysqlstack/sqlparser/depends/query" 20 | "github.com/xelabs/go-mysqlstack/sqlparser/depends/sqltypes" 21 | ) 22 | 23 | func TestStatement(t *testing.T) { 24 | log := xlog.NewStdLog(xlog.Level(xlog.DEBUG)) 25 | th := NewTestHandler(log) 26 | svr, err := MockMysqlServer(log, th) 27 | assert.Nil(t, err) 28 | defer svr.Close() 29 | address := svr.Addr() 30 | 31 | result1 := &sqltypes.Result{ 32 | Fields: []*querypb.Field{ 33 | { 34 | Name: "a", 35 | Type: sqltypes.Int32, 36 | }, 37 | { 38 | Name: "b", 39 | Type: sqltypes.VarChar, 40 | }, 41 | { 42 | Name: "c", 43 | Type: sqltypes.Datetime, 44 | }, 45 | { 46 | Name: "d", 47 | Type: sqltypes.Time, 48 | }, 49 | { 50 | Name: "e", 51 | Type: sqltypes.VarChar, 52 | }, 53 | }, 54 | Rows: [][]sqltypes.Value{ 55 | { 56 | sqltypes.MakeTrusted(sqltypes.Int32, []byte("10")), 57 | sqltypes.MakeTrusted(sqltypes.VarChar, []byte("xx10xx")), 58 | sqltypes.MakeTrusted(sqltypes.Datetime, []byte(time.Now().Format("2006-01-02 15:04:05"))), 59 | sqltypes.MakeTrusted(sqltypes.Time, []byte("15:04:05")), 60 | sqltypes.MakeTrusted(sqltypes.VarChar, nil), 61 | }, 62 | }, 63 | } 64 | 
result2 := &sqltypes.Result{} 65 | th.AddQueryPattern("drop table if .*", result2) 66 | th.AddQueryPattern("create table if .*", result2) 67 | th.AddQueryPattern("insert .*", result2) 68 | th.AddQueryPattern("select .*", result1) 69 | 70 | // query 71 | { 72 | client, err := NewConn("mock", "mock", address, "test", "") 73 | //client, err := NewConn("root", "", "127.0.0.1:3307", "test", "") 74 | assert.Nil(t, err) 75 | defer client.Close() 76 | 77 | query := "drop table if exists t1" 78 | err = client.Exec(query) 79 | assert.Nil(t, err) 80 | 81 | query = "create table if not exists t1 (a int, b varchar(20), c datetime, d time, e varchar(20))" 82 | err = client.Exec(query) 83 | assert.Nil(t, err) 84 | 85 | // Prepare Insert. 86 | { 87 | query = "insert into t1(a, b, c, d, e) values(?,?,?,?,?)" 88 | stmt, err := client.ComStatementPrepare(query) 89 | assert.Nil(t, err) 90 | log.Debug("stmt:%+v", stmt) 91 | 92 | params := []sqltypes.Value{ 93 | sqltypes.NewInt32(11), 94 | sqltypes.NewVarChar("xx10xx"), 95 | sqltypes.MakeTrusted(sqltypes.Datetime, []byte(time.Now().Format("2006-01-02 15:04:05"))), 96 | sqltypes.MakeTrusted(sqltypes.Time, []byte("15:04:05")), 97 | sqltypes.MakeTrusted(sqltypes.VarChar, nil), 98 | } 99 | err = stmt.ComStatementExecute(params) 100 | assert.Nil(t, err) 101 | stmt.ComStatementClose() 102 | } 103 | 104 | // Normal Select int. 105 | { 106 | query = "select * from t1 where a=10" 107 | qr, err := client.FetchAll(query, -1) 108 | assert.Nil(t, err) 109 | log.Debug("normal:%+v", qr) 110 | } 111 | 112 | { 113 | query = "select * from t1 where a=10" 114 | qr, err := client.FetchAll(query, -1) 115 | assert.Nil(t, err) 116 | log.Debug("normal:%+v", qr) 117 | } 118 | 119 | // Prepare Select int. 120 | { 121 | query = "select * from t1 where a=?" 
122 | stmt, err := client.ComStatementPrepare(query) 123 | assert.Nil(t, err) 124 | assert.NotNil(t, stmt) 125 | log.Debug("stmt:%+v", stmt) 126 | 127 | params := []sqltypes.Value{ 128 | sqltypes.NewInt32(11), 129 | } 130 | qr, err := stmt.ComStatementQuery(params) 131 | assert.Nil(t, err) 132 | log.Debug("%+v", qr) 133 | stmt.ComStatementClose() 134 | } 135 | 136 | // Prepare Select int. 137 | { 138 | query = "select * from t1 where a=?" 139 | stmt, err := client.ComStatementPrepare(query) 140 | assert.Nil(t, err) 141 | log.Debug("stmt:%+v", stmt) 142 | 143 | params := []sqltypes.Value{ 144 | sqltypes.NewInt32(11), 145 | } 146 | qr, err := stmt.ComStatementQuery(params) 147 | assert.Nil(t, err) 148 | log.Debug("%+v", qr) 149 | stmt.ComStatementClose() 150 | } 151 | 152 | // Prepare Select time. 153 | { 154 | query = "select a,b,c,d,e from t1 where c=?" 155 | stmt, err := client.ComStatementPrepare(query) 156 | assert.Nil(t, err) 157 | log.Debug("stmt:%+v", stmt) 158 | 159 | params := []sqltypes.Value{ 160 | sqltypes.MakeTrusted(sqltypes.Datetime, []byte(time.Now().Format("2006-01-02 15:04:05"))), 161 | } 162 | qr, err := stmt.ComStatementQuery(params) 163 | assert.Nil(t, err) 164 | log.Debug("%+v", qr) 165 | stmt.ComStatementReset() 166 | stmt.ComStatementClose() 167 | } 168 | } 169 | } 170 | -------------------------------------------------------------------------------- /examples/client.go: -------------------------------------------------------------------------------- 1 | /* 2 | * go-mysqlstack 3 | * xelabs.org 4 | * 5 | * Copyright (c) XeLabs 6 | * GPL License 7 | * 8 | */ 9 | 10 | package main 11 | 12 | import ( 13 | "fmt" 14 | 15 | "github.com/xelabs/go-mysqlstack/driver" 16 | "github.com/xelabs/go-mysqlstack/xlog" 17 | ) 18 | 19 | func main() { 20 | log := xlog.NewStdLog(xlog.Level(xlog.INFO)) 21 | address := fmt.Sprintf(":4407") 22 | client, err := driver.NewConn("mock", "mock", address, "", "") 23 | if err != nil { 24 | 
log.Panic("client.new.connection.error:%+v", err) 25 | } 26 | defer client.Close() 27 | 28 | qr, err := client.FetchAll("SELECT * FROM MOCK", -1) 29 | if err != nil { 30 | log.Panic("client.query.error:%+v", err) 31 | } 32 | log.Info("results:[%+v]", qr.Rows) 33 | } 34 | -------------------------------------------------------------------------------- /examples/mysqld.go: -------------------------------------------------------------------------------- 1 | /* 2 | * go-mysqlstack 3 | * xelabs.org 4 | * 5 | * Copyright (c) XeLabs 6 | * GPL License 7 | * 8 | */ 9 | 10 | package main 11 | 12 | import ( 13 | "os" 14 | "os/signal" 15 | "syscall" 16 | 17 | "github.com/xelabs/go-mysqlstack/driver" 18 | querypb "github.com/xelabs/go-mysqlstack/sqlparser/depends/query" 19 | "github.com/xelabs/go-mysqlstack/sqlparser/depends/sqltypes" 20 | "github.com/xelabs/go-mysqlstack/xlog" 21 | ) 22 | 23 | func main() { 24 | result1 := &sqltypes.Result{ 25 | Fields: []*querypb.Field{ 26 | { 27 | Name: "id", 28 | Type: querypb.Type_INT32, 29 | }, 30 | { 31 | Name: "name", 32 | Type: querypb.Type_VARCHAR, 33 | }, 34 | }, 35 | Rows: [][]sqltypes.Value{ 36 | { 37 | sqltypes.MakeTrusted(querypb.Type_INT32, []byte("10")), 38 | sqltypes.MakeTrusted(querypb.Type_VARCHAR, []byte("nice name")), 39 | }, 40 | }, 41 | } 42 | 43 | log := xlog.NewStdLog(xlog.Level(xlog.INFO)) 44 | th := driver.NewTestHandler(log) 45 | th.AddQuery("SELECT * FROM MOCK", result1) 46 | 47 | mysqld, err := driver.MockMysqlServerWithPort(log, 4407, th) 48 | if err != nil { 49 | log.Panic("mysqld.start.error:%+v", err) 50 | } 51 | defer mysqld.Close() 52 | log.Info("mysqld.server.start.address[%v]", mysqld.Addr()) 53 | 54 | // Handle SIGINT and SIGTERM. 
55 | ch := make(chan os.Signal) 56 | signal.Notify(ch, syscall.SIGINT, syscall.SIGTERM) 57 | <-ch 58 | } 59 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/xelabs/go-mysqlstack 2 | 3 | require ( 4 | github.com/shopspring/decimal v1.2.0 5 | github.com/stretchr/testify v1.7.0 6 | ) 7 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= 2 | github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 3 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= 4 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 5 | github.com/shopspring/decimal v1.2.0 h1:abSATXmQEYyShuxI4/vyW3tV1MrKAJzCZ/0zLUXYbsQ= 6 | github.com/shopspring/decimal v1.2.0/go.mod h1:DKyhrW/HYNuLGql+MJL6WCR6knT2jwCFRcu2hWCYk4o= 7 | github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= 8 | github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY= 9 | github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= 10 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= 11 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= 12 | gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo= 13 | gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 14 | -------------------------------------------------------------------------------- /makefile: 
--------------------------------------------------------------------------------
export PATH := $(GOPATH)/bin:$(PATH)

fmt:
	go fmt ./...
	go vet ./...

test:
	go get github.com/stretchr/testify/assert
	@echo "--> Testing..."
	@$(MAKE) testxlog
	@$(MAKE) testsqlparser
	@$(MAKE) testsqldb
	@$(MAKE) testproto
	@$(MAKE) testpacket
	@$(MAKE) testdriver

testxlog:
	go test -v ./xlog
testsqlparser:
	go test -v ./sqlparser/...
testsqldb:
	go test -v ./sqldb
testproto:
	go test -v ./proto
testpacket:
	go test -v ./packet
testdriver:
	go test -v ./driver

COVPKGS = ./sqlparser/... ./sqldb ./proto ./packet ./driver
coverage:
	go get github.com/pierrre/gotestcover
	gotestcover -coverprofile=coverage.out -v $(COVPKGS)
	go tool cover -html=coverage.out

# Every target above is a command, not a file; keep this list in sync with them.
.PHONY: fmt test testxlog testsqlparser testsqldb testproto testpacket testdriver coverage

--------------------------------------------------------------------------------
/packet/error.go:
--------------------------------------------------------------------------------
/*
 * go-mysqlstack
 * xelabs.org
 *
 * Copyright (c) XeLabs
 * GPL License
 *
 */

package packet

import (
	"errors"
)

var (
	// ErrBadConn reports that the connection is in a bad state.
	ErrBadConn = errors.New("connection.was.bad")
	// ErrMalformPacket reports a packet that could not be parsed or framed.
	ErrMalformPacket = errors.New("Malform.packet.error")
)

--------------------------------------------------------------------------------
/packet/mock.go:
--------------------------------------------------------------------------------
/*
 * go-mysqlstack
 * xelabs.org
 *
 * Copyright (c) XeLabs
 * Copyright 2016 The Go-MySQL-Driver Authors. All rights reserved.
7 | * GPL License 8 | * 9 | */ 10 | 11 | package packet 12 | 13 | import ( 14 | "io" 15 | "net" 16 | "time" 17 | ) 18 | 19 | var _ net.Conn = &MockConn{} 20 | 21 | // MockConn used to mock a net.Conn for testing purposes. 22 | type MockConn struct { 23 | laddr net.Addr 24 | raddr net.Addr 25 | data []byte 26 | closed bool 27 | read int 28 | } 29 | 30 | // NewMockConn creates new mock connection. 31 | func NewMockConn() *MockConn { 32 | return &MockConn{} 33 | } 34 | 35 | // Read implements the net.Conn interface. 36 | func (m *MockConn) Read(b []byte) (n int, err error) { 37 | // handle the EOF 38 | if len(m.data) == 0 { 39 | err = io.EOF 40 | return 41 | } 42 | 43 | n = copy(b, m.data) 44 | m.read += n 45 | m.data = m.data[n:] 46 | return 47 | } 48 | 49 | // Write implements the net.Conn interface. 50 | func (m *MockConn) Write(b []byte) (n int, err error) { 51 | m.data = append(m.data, b...) 52 | return len(b), nil 53 | } 54 | 55 | // Datas implements the net.Conn interface. 56 | func (m *MockConn) Datas() []byte { 57 | return m.data 58 | } 59 | 60 | // Close implements the net.Conn interface. 61 | func (m *MockConn) Close() error { 62 | m.closed = true 63 | return nil 64 | } 65 | 66 | // LocalAddr implements the net.Conn interface. 67 | func (m *MockConn) LocalAddr() net.Addr { 68 | return m.laddr 69 | } 70 | 71 | // RemoteAddr implements the net.Conn interface. 72 | func (m *MockConn) RemoteAddr() net.Addr { 73 | return m.raddr 74 | } 75 | 76 | // SetDeadline implements the net.Conn interface. 77 | func (m *MockConn) SetDeadline(t time.Time) error { 78 | return nil 79 | } 80 | 81 | // SetReadDeadline implements the net.Conn interface. 82 | func (m *MockConn) SetReadDeadline(t time.Time) error { 83 | return nil 84 | } 85 | 86 | // SetWriteDeadline implements the net.Conn interface. 
87 | func (m *MockConn) SetWriteDeadline(t time.Time) error { 88 | return nil 89 | } 90 | -------------------------------------------------------------------------------- /packet/stream.go: -------------------------------------------------------------------------------- 1 | /* 2 | * go-mysqlstack 3 | * xelabs.org 4 | * 5 | * Copyright (c) XeLabs 6 | * GPL License 7 | * 8 | */ 9 | 10 | package packet 11 | 12 | import ( 13 | "bufio" 14 | "io" 15 | "net" 16 | ) 17 | 18 | const ( 19 | // PACKET_BUFFER_SIZE is the size of the bufio read and write buffers of a Stream. 20 | PACKET_BUFFER_SIZE = 32 * 1024 21 | ) 22 | 23 | // Stream reads and writes length-prefixed packets over a buffered connection. 24 | type Stream struct { 25 | pktMaxSize int 26 | header []byte 27 | reader *bufio.Reader 28 | writer *bufio.Writer 29 | } 30 | 31 | // NewStream wraps conn in PACKET_BUFFER_SIZE read/write buffers; pktMaxSize is the largest payload a single packet carries. 32 | func NewStream(conn net.Conn, pktMaxSize int) *Stream { 33 | return &Stream{ 34 | pktMaxSize: pktMaxSize, 35 | header: []byte{0, 0, 0, 0}, 36 | reader: bufio.NewReaderSize(conn, PACKET_BUFFER_SIZE), 37 | writer: bufio.NewWriterSize(conn, PACKET_BUFFER_SIZE), 38 | } 39 | } 40 | 41 | // Read reads the next packet from the reader. 42 | // The returned pkt.Datas is only guaranteed to be valid until the next read. 43 | // A payload of pktMaxSize bytes or more arrives split across consecutive packets; 44 | // the chunks are reassembled here and SequenceID is that of the last packet read. 45 | func (s *Stream) Read() (*Packet, error) { 46 | pkt := &Packet{} 47 | for { 48 | // Header. 49 | if _, err := io.ReadFull(s.reader, s.header); err != nil { 50 | return nil, err 51 | } 52 | 53 | // Length. 54 | pkt.SequenceID = s.header[3] 55 | length := int(uint32(s.header[0]) | uint32(s.header[1])<<8 | uint32(s.header[2])<<16) 56 | if length == 0 { 57 | return pkt, nil 58 | } 59 | 60 | // Datas. 61 | data := make([]byte, length) 62 | if _, err := io.ReadFull(s.reader, data); err != nil { 63 | return nil, err 64 | } 65 | if pkt.Datas == nil { 66 | // First chunk: use the buffer directly, no copy needed. 67 | pkt.Datas = data 68 | } else { 69 | // Continuation chunk: appending in a loop (instead of recursing) keeps reassembly linear. 70 | pkt.Datas = append(pkt.Datas, data...) 71 | } 72 | 73 | // A chunk shorter than pktMaxSize is the last one. 74 | if length < s.pktMaxSize { 75 | return pkt, nil 76 | } 77 | } 78 | }
79 | 80 | // Write appends data to the write buffer (see Append) and flushes it. 81 | func (s *Stream) Write(data []byte) error { 82 | if err := s.Append(data); err != nil { 83 | return err 84 | } 85 | return s.Flush() 86 | } 87 | 88 | // Append splits data into packets carrying at most pktMaxSize payload bytes each and 89 | // appends them to the write buffer. data[0:4] is a header slot whose last byte holds the 90 | // first sequence ID; the payload follows. data is modified in place: each chunk header is 91 | // written into it, over the tail of the previous chunk once that chunk has been written. 92 | // It returns ErrMalformPacket if data is shorter than a header, or the first write error. 93 | func (s *Stream) Append(data []byte) error { 94 | if len(data) < 4 { 95 | return ErrMalformPacket 96 | } 97 | payLen := len(data) - 4 98 | sequence := data[3] 99 | 100 | for { 101 | var size int 102 | if payLen < s.pktMaxSize { 103 | size = payLen 104 | } else { 105 | size = s.pktMaxSize 106 | } 107 | data[0] = byte(size) 108 | data[1] = byte(size >> 8) 109 | data[2] = byte(size >> 16) 110 | data[3] = sequence 111 | 112 | // append to buffer; a failed write is reported rather than silently dropped 113 | if _, err := s.writer.Write(data[:4+size]); err != nil { 114 | return err 115 | } 116 | if size < s.pktMaxSize { 117 | break 118 | } 119 | 120 | payLen -= size 121 | data = data[size:] 122 | sequence++ 123 | } 124 | return nil 125 | } 126 | 127 | // Flush writes any buffered packets to the underlying connection. 128 | func (s *Stream) Flush() error { 129 | return s.writer.Flush() 130 | } 131 | -------------------------------------------------------------------------------- /packet/stream_test.go: -------------------------------------------------------------------------------- 1 | /* 2 | * go-mysqlstack 3 | * xelabs.org 4 | * 5 | * Copyright (c) XeLabs 6 | * GPL License 7 | * 8 | */ 9 | 10 | package packet 11 | 12 | import ( 13 | "testing" 14 | 15 | "github.com/stretchr/testify/assert" 16 | 17 | "github.com/xelabs/go-mysqlstack/sqlparser/depends/common" 18 | ) 19 | 20 | // TEST EFFECTS: 21 | // writes normal packet 22 | // 23 | // TEST PROCESSES: 24 | // 1. write a 1234-byte payload 25 | // 2. write checks 26 | // 3.
read checks 27 | func TestStream(t *testing.T) { 28 | rBuf := NewMockConn() 29 | defer rBuf.Close() 30 | 31 | wBuf := NewMockConn() 32 | defer wBuf.Close() 33 | 34 | rStream := NewStream(rBuf, PACKET_MAX_SIZE) 35 | wStream := NewStream(wBuf, PACKET_MAX_SIZE) 36 | 37 | packet := common.NewBuffer(PACKET_BUFFER_SIZE) 38 | payload := common.NewBuffer(PACKET_BUFFER_SIZE) 39 | 40 | for i := 0; i < 1234; i++ { 41 | payload.WriteU8(byte(i)) 42 | } 43 | 44 | packet.WriteU24(uint32(payload.Length())) 45 | packet.WriteU8(1) 46 | packet.WriteBytes(payload.Datas()) 47 | 48 | // write checks 49 | { 50 | err := wStream.Write(packet.Datas()) 51 | assert.Nil(t, err) 52 | 53 | want := packet.Datas() 54 | got := wBuf.Datas() 55 | assert.Equal(t, want, got) 56 | } 57 | 58 | // read checks 59 | { 60 | rBuf.Write(wBuf.Datas()) 61 | ptk, err := rStream.Read() 62 | assert.Nil(t, err) 63 | 64 | assert.Equal(t, byte(0x01), ptk.SequenceID) 65 | assert.Equal(t, payload.Datas(), ptk.Datas) 66 | } 67 | } 68 | 69 | // TEST EFFECTS: 70 | // write packet whose payload length equals pktMaxSize (an empty trailing packet must follow) 71 | // 72 | // TEST PROCESSES: 73 | // 1. write payload whose length equals pktMaxSize 74 | // 2. write checks 75 | // 3.
read checks 76 | func TestStreamWriteMax(t *testing.T) { 77 | rBuf := NewMockConn() 78 | defer rBuf.Close() 79 | 80 | wBuf := NewMockConn() 81 | defer wBuf.Close() 82 | 83 | pktMaxSize := 64 84 | rStream := NewStream(rBuf, pktMaxSize) 85 | wStream := NewStream(wBuf, pktMaxSize) 86 | 87 | packet := common.NewBuffer(PACKET_BUFFER_SIZE) 88 | expect := common.NewBuffer(PACKET_BUFFER_SIZE) 89 | payload := common.NewBuffer(PACKET_BUFFER_SIZE) 90 | 91 | { 92 | for i := 0; i < (pktMaxSize+1)/4; i++ { 93 | payload.WriteU32(uint32(i)) 94 | } 95 | } 96 | packet.WriteU24(uint32(payload.Length())) 97 | packet.WriteU8(1) 98 | packet.WriteBytes(payload.Datas()) 99 | 100 | // write checks 101 | { 102 | err := wStream.Write(packet.Datas()) 103 | assert.Nil(t, err) 104 | 105 | // check length 106 | { 107 | want := packet.Length() + 4 108 | got := len(wBuf.Datas()) 109 | assert.Equal(t, want, got) 110 | } 111 | 112 | // check chunks 113 | { 114 | // first chunk 115 | expect.WriteU24(uint32(pktMaxSize)) 116 | expect.WriteU8(1) 117 | expect.WriteBytes(payload.Datas()[:pktMaxSize]) 118 | 119 | // second chunk 120 | expect.WriteU24(0) 121 | expect.WriteU8(2) 122 | 123 | want := expect.Datas() 124 | got := wBuf.Datas() 125 | assert.Equal(t, want, got) 126 | } 127 | } 128 | 129 | // read checks 130 | { 131 | rBuf.Write(wBuf.Datas()) 132 | ptk, err := rStream.Read() 133 | assert.Nil(t, err) 134 | 135 | assert.Equal(t, byte(0x02), ptk.SequenceID) 136 | assert.Equal(t, payload.Datas(), ptk.Datas) 137 | } 138 | } 139 | 140 | // TEST EFFECTS: 141 | // write packet whose payload length is more than pktMaxSize 142 | // 143 | // TEST PROCESSES: 144 | // 1. write payload whose length (60 + 8 = 68) exceeds pktMaxSize (63) 145 | // 2. write checks 146 | // 3.
read checks 147 | func TestStreamWriteOverMax(t *testing.T) { 148 | rBuf := NewMockConn() 149 | defer rBuf.Close() 150 | 151 | wBuf := NewMockConn() 152 | defer wBuf.Close() 153 | 154 | pktMaxSize := 63 155 | rStream := NewStream(rBuf, pktMaxSize) 156 | wStream := NewStream(wBuf, pktMaxSize) 157 | 158 | packet := common.NewBuffer(PACKET_BUFFER_SIZE) 159 | expect := common.NewBuffer(PACKET_BUFFER_SIZE) 160 | payload := common.NewBuffer(PACKET_BUFFER_SIZE) 161 | 162 | { 163 | for i := 0; i < pktMaxSize/4; i++ { 164 | payload.WriteU32(uint32(i)) 165 | } 166 | } 167 | // fill with 8 more bytes 168 | payload.WriteU32(32) 169 | payload.WriteU32(32) 170 | 171 | packet.WriteU24(uint32(payload.Length())) 172 | packet.WriteU8(1) 173 | packet.WriteBytes(payload.Datas()) 174 | 175 | // write checks 176 | { 177 | err := wStream.Write(packet.Datas()) 178 | assert.Nil(t, err) 179 | 180 | // check length 181 | { 182 | want := packet.Length() + 4 183 | got := len(wBuf.Datas()) 184 | assert.Equal(t, want, got) 185 | } 186 | 187 | // check chunks 188 | { 189 | // first chunk 190 | expect.WriteU24(uint32(pktMaxSize)) 191 | expect.WriteU8(1) 192 | expect.WriteBytes(payload.Datas()[:pktMaxSize]) 193 | 194 | // second chunk 195 | left := (packet.Length() - 4) - pktMaxSize 196 | expect.WriteU24(uint32(left)) 197 | expect.WriteU8(2) 198 | expect.WriteBytes(payload.Datas()[pktMaxSize:]) 199 | 200 | want := expect.Datas() 201 | got := wBuf.Datas() 202 | assert.Equal(t, want, got) 203 | } 204 | } 205 | 206 | // read checks 207 | { 208 | rBuf.Write(wBuf.Datas()) 209 | ptk, err := rStream.Read() 210 | assert.Nil(t, err) 211 | 212 | assert.Equal(t, byte(0x02), ptk.SequenceID) 213 | assert.Equal(t, payload.Datas(), ptk.Datas) 214 | _, err = rStream.Read() 215 | assert.NotNil(t, err) 216 | } 217 | } 218 | -------------------------------------------------------------------------------- /proto/column.go: -------------------------------------------------------------------------------- 1 | /* 2 | *
go-mysqlstack 3 | * xelabs.org 4 | * 5 | * Copyright (c) XeLabs 6 | * GPL License 7 | * 8 | */ 9 | 10 | package proto 11 | 12 | import ( 13 | "github.com/xelabs/go-mysqlstack/sqldb" 14 | 15 | "github.com/xelabs/go-mysqlstack/sqlparser/depends/common" 16 | querypb "github.com/xelabs/go-mysqlstack/sqlparser/depends/query" 17 | "github.com/xelabs/go-mysqlstack/sqlparser/depends/sqltypes" 18 | ) 19 | 20 | // ColumnCount returns the column count. 21 | func ColumnCount(payload []byte) (count uint64, err error) { 22 | buff := common.ReadBuffer(payload) 23 | if count, err = buff.ReadLenEncode(); err != nil { 24 | return 0, sqldb.NewSQLError(sqldb.ER_MALFORMED_PACKET, "extracting column count failed") 25 | } 26 | return 27 | } 28 | 29 | // UnpackColumn parses a ColumnDefinition41 payload into a querypb.Field; every read failure is reported as ER_MALFORMED_PACKET. 30 | // http://dev.mysql.com/doc/internals/en/com-query-response.html#packet-Protocol::ColumnDefinition41 31 | func UnpackColumn(payload []byte) (*querypb.Field, error) { 32 | var err error 33 | field := &querypb.Field{} 34 | buff := common.ReadBuffer(payload) 35 | // Catalog is ignored, always set to "def" 36 | if _, err = buff.ReadLenEncodeString(); err != nil { 37 | return nil, sqldb.NewSQLError(sqldb.ER_MALFORMED_PACKET, "skipping col catalog failed") 38 | } 39 | 40 | // lenenc_str Schema 41 | if field.Database, err = buff.ReadLenEncodeString(); err != nil { 42 | return nil, sqldb.NewSQLError(sqldb.ER_MALFORMED_PACKET, "extracting col schema failed") 43 | } 44 | 45 | // lenenc_str Table 46 | if field.Table, err = buff.ReadLenEncodeString(); err != nil { 47 | return nil, sqldb.NewSQLError(sqldb.ER_MALFORMED_PACKET, "extracting col table failed") 48 | } 49 | 50 | // lenenc_str Org_Table 51 | if field.OrgTable, err = buff.ReadLenEncodeString(); err != nil { 52 | return nil, sqldb.NewSQLError(sqldb.ER_MALFORMED_PACKET, "extracting col org_table failed") 53 | } 54 | 55 | // lenenc_str Name 56 | if field.Name, err = buff.ReadLenEncodeString(); err != nil { 57 | return nil,
sqldb.NewSQLError(sqldb.ER_MALFORMED_PACKET, "extracting col name failed") 58 | } 59 | 60 | // lenenc_str Org_Name 61 | if field.OrgName, err = buff.ReadLenEncodeString(); err != nil { 62 | return nil, sqldb.NewSQLError(sqldb.ER_MALFORMED_PACKET, "extracting col org_name failed") 63 | } 64 | 65 | // lenenc_int length of fixed-length fields [0c], skip 66 | if _, err = buff.ReadLenEncode(); err != nil { 67 | return nil, sqldb.NewSQLError(sqldb.ER_MALFORMED_PACKET, "extracting col 0c failed") 68 | } 69 | 70 | // 2 character set 71 | charset, err := buff.ReadU16() 72 | if err != nil { 73 | return nil, sqldb.NewSQLError(sqldb.ER_MALFORMED_PACKET, "extracting col charset failed") 74 | } 75 | field.Charset = uint32(charset) 76 | 77 | // 4 column length 78 | if field.ColumnLength, err = buff.ReadU32(); err != nil { 79 | return nil, sqldb.NewSQLError(sqldb.ER_MALFORMED_PACKET, "extracting col columnlength failed") 80 | } 81 | 82 | // 1 type 83 | t, err := buff.ReadU8() 84 | if err != nil { 85 | return nil, sqldb.NewSQLError(sqldb.ER_MALFORMED_PACKET, "extracting col type failed") 86 | } 87 | 88 | // 2 flags 89 | flags, err := buff.ReadU16() 90 | if err != nil { 91 | return nil, sqldb.NewSQLError(sqldb.ER_MALFORMED_PACKET, "extracting col flags failed") 92 | } 93 | field.Flags = uint32(flags) 94 | 95 | // Convert MySQL type 96 | if field.Type, err = sqltypes.MySQLToType(int64(t), int64(field.Flags)); err != nil { 97 | return nil, sqldb.NewSQLErrorf(sqldb.ER_MALFORMED_PACKET, "MySQLToType(%v,%v) failed: %v", t, field.Flags, err) 98 | } 99 | 100 | // 1 Decimals 101 | decimals, err := buff.ReadU8() 102 | if err != nil { 103 | return nil, sqldb.NewSQLError(sqldb.ER_MALFORMED_PACKET, "extracting col decimals failed") 104 | } 105 | field.Decimals = uint32(decimals) 106 | 107 | // 2 filler bytes and any trailing default values are not read 108 | // 109 | return field, nil 110 | } 111 | 112 | // PackColumn serializes field as a ColumnDefinition41 payload; a non-zero field.Flags overrides the flags derived from field.Type.
113 | func PackColumn(field *querypb.Field) []byte { 114 | typ, flags := sqltypes.TypeToMySQL(field.Type) 115 | if field.Flags != 0 { 116 | flags = int64(field.Flags) 117 | } 118 | 119 | buf := common.NewBuffer(256) 120 | 121 | // lenenc_str Catalog, always 'def' 122 | buf.WriteLenEncodeString("def") 123 | 124 | // lenenc_str Schema 125 | buf.WriteLenEncodeString(field.Database) 126 | 127 | // lenenc_str Table 128 | buf.WriteLenEncodeString(field.Table) 129 | 130 | // lenenc_str Org_Table 131 | buf.WriteLenEncodeString(field.OrgTable) 132 | 133 | // lenenc_str Name 134 | buf.WriteLenEncodeString(field.Name) 135 | 136 | // lenenc_str Org_Name 137 | buf.WriteLenEncodeString(field.OrgName) 138 | 139 | // lenenc_int length of fixed-length fields [0c] 140 | buf.WriteLenEncode(uint64(0x0c)) 141 | 142 | // 2 character set 143 | buf.WriteU16(uint16(field.Charset)) 144 | 145 | // 4 column length 146 | buf.WriteU32(field.ColumnLength) 147 | 148 | // 1 type 149 | buf.WriteU8(byte(typ)) 150 | 151 | // 2 flags 152 | buf.WriteU16(uint16(flags)) 153 | 154 | // 1 decimals 155 | buf.WriteU8(uint8(field.Decimals)) 156 | 157 | // 2 filler [00] [00] 158 | buf.WriteU16(uint16(0)) 159 | return buf.Datas() 160 | } 161 | -------------------------------------------------------------------------------- /proto/column_test.go: -------------------------------------------------------------------------------- 1 | /* 2 | * go-mysqlstack 3 | * xelabs.org 4 | * 5 | * Copyright (c) XeLabs 6 | * GPL License 7 | * 8 | */ 9 | 10 | package proto 11 | 12 | import ( 13 | "testing" 14 | 15 | "github.com/stretchr/testify/assert" 16 | "github.com/xelabs/go-mysqlstack/sqlparser/depends/sqltypes" 17 | 18 | "github.com/xelabs/go-mysqlstack/sqlparser/depends/common" 19 | querypb "github.com/xelabs/go-mysqlstack/sqlparser/depends/query" 20 | ) 21 | 22 | func TestColumnCount(t *testing.T) { 23 | payload := []byte{ 24 | 0x02, 25 | } 26 | 27 | want := uint64(2) 28 | got, err := ColumnCount(payload) 29 | assert.Nil(t,
err) 30 | assert.Equal(t, want, got) 31 | } 32 | 33 | func TestColumn(t *testing.T) { 34 | want := &querypb.Field{ 35 | Database: "test", 36 | Table: "t1", 37 | OrgTable: "t1", 38 | Name: "a", 39 | OrgName: "a", 40 | Charset: 11, 41 | ColumnLength: 11, 42 | Type: sqltypes.Int32, 43 | Flags: 11, 44 | } 45 | 46 | datas := PackColumn(want) 47 | got, err := UnpackColumn(datas) 48 | assert.Nil(t, err) 49 | assert.Equal(t, want, got) 50 | } 51 | 52 | func TestColumnUnPackError(t *testing.T) { 53 | // NULL 54 | f0 := func(buff *common.Buffer) { 55 | } 56 | 57 | // Write catalog. 58 | f1 := func(buff *common.Buffer) { 59 | buff.WriteLenEncodeString("def") 60 | } 61 | 62 | // Write schema. 63 | f2 := func(buff *common.Buffer) { 64 | buff.WriteLenEncodeString("sbtest") 65 | } 66 | 67 | // Write table. 68 | f3 := func(buff *common.Buffer) { 69 | buff.WriteLenEncodeString("table1") 70 | } 71 | 72 | // Write org table. 73 | f4 := func(buff *common.Buffer) { 74 | buff.WriteLenEncodeString("orgtable1") 75 | } 76 | 77 | // Write Name. 78 | f5 := func(buff *common.Buffer) { 79 | buff.WriteLenEncodeString("name") 80 | } 81 | 82 | // Write Org Name. 83 | f6 := func(buff *common.Buffer) { 84 | buff.WriteLenEncodeString("name") 85 | } 86 | 87 | // Write length. 88 | f7 := func(buff *common.Buffer) { 89 | buff.WriteLenEncode(0x0c) 90 | } 91 | 92 | // Write Charset. 93 | f8 := func(buff *common.Buffer) { 94 | buff.WriteU16(uint16(1)) 95 | } 96 | 97 | // Write Column length. 98 | f9 := func(buff *common.Buffer) { 99 | buff.WriteU32(uint32(1)) 100 | } 101 | 102 | // Write type.
103 | f10 := func(buff *common.Buffer) { 104 | buff.WriteU8(0x01) 105 | } 106 | 107 | // Write flags 108 | f11 := func(buff *common.Buffer) { 109 | buff.WriteU16(uint16(1)) 110 | } 111 | 112 | buff := common.NewBuffer(32) 113 | fs := []func(buff *common.Buffer){f0, f1, f2, f3, f4, f5, f6, f7, f8, f9, f10, f11} 114 | for i := 0; i < len(fs); i++ { 115 | _, err := UnpackColumn(buff.Datas()) 116 | assert.NotNil(t, err) 117 | fs[i](buff) 118 | } 119 | 120 | { 121 | _, err := UnpackColumn(buff.Datas()) 122 | assert.NotNil(t, err) 123 | } 124 | } 125 | -------------------------------------------------------------------------------- /proto/const.go: -------------------------------------------------------------------------------- 1 | /* 2 | * go-mysqlstack 3 | * xelabs.org 4 | * 5 | * Copyright (c) XeLabs 6 | * GPL License 7 | * 8 | */ 9 | 10 | package proto 11 | 12 | import ( 13 | "github.com/xelabs/go-mysqlstack/sqldb" 14 | ) 15 | 16 | const ( 17 | // DefaultAuthPluginName is the default plugin name. 18 | DefaultAuthPluginName = "mysql_native_password" 19 | 20 | // DefaultServerCapability is the default server capability. 21 | DefaultServerCapability = sqldb.CLIENT_LONG_PASSWORD | 22 | sqldb.CLIENT_LONG_FLAG | 23 | sqldb.CLIENT_CONNECT_WITH_DB | 24 | sqldb.CLIENT_PROTOCOL_41 | 25 | sqldb.CLIENT_TRANSACTIONS | 26 | sqldb.CLIENT_MULTI_STATEMENTS | 27 | sqldb.CLIENT_PLUGIN_AUTH | 28 | sqldb.CLIENT_DEPRECATE_EOF | 29 | sqldb.CLIENT_SECURE_CONNECTION 30 | 31 | // DefaultClientCapability is the default client capability. 32 | DefaultClientCapability = sqldb.CLIENT_LONG_PASSWORD | 33 | sqldb.CLIENT_LONG_FLAG | 34 | sqldb.CLIENT_PROTOCOL_41 | 35 | sqldb.CLIENT_TRANSACTIONS | 36 | sqldb.CLIENT_MULTI_STATEMENTS | 37 | sqldb.CLIENT_PLUGIN_AUTH | 38 | sqldb.CLIENT_DEPRECATE_EOF | 39 | sqldb.CLIENT_SECURE_CONNECTION 40 | ) 41 | 42 | var ( 43 | // DefaultSalt is the default salt bytes.
44 | DefaultSalt = []byte{ 45 | 0x77, 0x63, 0x6a, 0x6d, 0x61, 0x22, 0x23, 0x27, // first part 46 | 0x38, 0x26, 0x55, 0x58, 0x3b, 0x5d, 0x44, 0x78, 0x53, 0x73, 0x6b, 0x41} 47 | ) 48 | -------------------------------------------------------------------------------- /proto/eof.go: -------------------------------------------------------------------------------- 1 | /* 2 | * go-mysqlstack 3 | * xelabs.org 4 | * 5 | * Copyright (c) XeLabs 6 | * GPL License 7 | * 8 | */ 9 | 10 | package proto 11 | 12 | import ( 13 | "github.com/xelabs/go-mysqlstack/sqldb" 14 | "github.com/xelabs/go-mysqlstack/sqlparser/depends/common" 15 | ) 16 | 17 | const ( 18 | // EOF_PACKET is the header byte of an EOF packet. 19 | EOF_PACKET byte = 0xfe 20 | ) 21 | 22 | // EOF holds the fields of an EOF packet. 23 | type EOF struct { 24 | Header byte // always 0xfe (EOF_PACKET) 25 | Warnings uint16 26 | StatusFlags uint16 27 | } 28 | 29 | // UnPackEOF used to unpack the EOF packet. 30 | // https://dev.mysql.com/doc/internals/en/packet-EOF_Packet.html 31 | // This method is unused. 32 | func UnPackEOF(data []byte) (*EOF, error) { 33 | var err error 34 | e := &EOF{} 35 | buf := common.ReadBuffer(data) 36 | 37 | // header 38 | if e.Header, err = buf.ReadU8(); err != nil { 39 | return nil, sqldb.NewSQLErrorf(sqldb.ER_MALFORMED_PACKET, "invalid eof packet header: %v", data) 40 | } 41 | if e.Header != EOF_PACKET { 42 | return nil, sqldb.NewSQLErrorf(sqldb.ER_MALFORMED_PACKET, "invalid eof packet header: %v", e.Header) 43 | } 44 | 45 | // Warnings 46 | if e.Warnings, err = buf.ReadU16(); err != nil { 47 | return nil, sqldb.NewSQLErrorf(sqldb.ER_MALFORMED_PACKET, "invalid eof packet warnings: %v", data) 48 | } 49 | 50 | // Status 51 | if e.StatusFlags, err = buf.ReadU16(); err != nil { 52 | return nil, sqldb.NewSQLErrorf(sqldb.ER_MALFORMED_PACKET, "invalid eof packet statusflags: %v", data) 53 | } 54 | return e, nil 55 | } 56 | 57 | // PackEOF serializes e as an EOF packet (header, warnings, status flags); e.Header is not used.
58 | func PackEOF(e *EOF) []byte { 59 | buf := common.NewBuffer(64) 60 | 61 | // EOF header (0xfe) 62 | buf.WriteU8(EOF_PACKET) 63 | 64 | // warnings 65 | buf.WriteU16(e.Warnings) 66 | 67 | // status 68 | buf.WriteU16(e.StatusFlags) 69 | return buf.Datas() 70 | } 71 | -------------------------------------------------------------------------------- /proto/eof_test.go: -------------------------------------------------------------------------------- 1 | /* 2 | * go-mysqlstack 3 | * xelabs.org 4 | * 5 | * Copyright (c) XeLabs 6 | * GPL License 7 | * 8 | */ 9 | 10 | package proto 11 | 12 | import ( 13 | "testing" 14 | 15 | "github.com/stretchr/testify/assert" 16 | 17 | "github.com/xelabs/go-mysqlstack/sqlparser/depends/common" 18 | ) 19 | 20 | func TestEOF(t *testing.T) { 21 | want := &EOF{} 22 | want.Header = EOF_PACKET 23 | want.StatusFlags = 1 24 | want.Warnings = 2 25 | data := PackEOF(want) 26 | 27 | got, err := UnPackEOF(data) 28 | assert.Nil(t, err) 29 | assert.Equal(t, want, got) 30 | } 31 | 32 | func TestEOFUnPackError(t *testing.T) { 33 | // header error 34 | { 35 | buff := common.NewBuffer(32) 36 | // header 37 | buff.WriteU8(0x99) 38 | _, err := UnPackEOF(buff.Datas()) 39 | assert.NotNil(t, err) 40 | } 41 | 42 | // NULL 43 | f0 := func(buff *common.Buffer) { 44 | } 45 | 46 | // Write EOF header. 47 | f1 := func(buff *common.Buffer) { 48 | buff.WriteU8(0xfe) 49 | } 50 | 51 | // Write warnings (status flags are never written, so every step must fail).
52 | f2 := func(buff *common.Buffer) { 53 | buff.WriteU16(0x01) 54 | } 55 | 56 | buff := common.NewBuffer(32) 57 | fs := []func(buff *common.Buffer){f0, f1, f2} 58 | for i := 0; i < len(fs); i++ { 59 | _, err := UnPackEOF(buff.Datas()) 60 | assert.NotNil(t, err) 61 | fs[i](buff) 62 | } 63 | 64 | { 65 | _, err := UnPackEOF(buff.Datas()) 66 | assert.NotNil(t, err) 67 | } 68 | } 69 | -------------------------------------------------------------------------------- /proto/err.go: -------------------------------------------------------------------------------- 1 | /* 2 | * go-mysqlstack 3 | * xelabs.org 4 | * 5 | * Copyright (c) XeLabs 6 | * GPL License 7 | * 8 | */ 9 | 10 | package proto 11 | 12 | import ( 13 | "github.com/xelabs/go-mysqlstack/sqldb" 14 | "github.com/xelabs/go-mysqlstack/sqlparser/depends/common" 15 | ) 16 | 17 | const ( 18 | // ERR_PACKET is the error packet byte. 19 | ERR_PACKET byte = 0xff 20 | ) 21 | 22 | // ERR is the error packet. 23 | type ERR struct { 24 | Header byte // always 0xff 25 | ErrorCode uint16 26 | SQLState string 27 | ErrorMessage string 28 | } 29 | 30 | // UnPackERR parses the error packet and returns a sqldb.SQLError.
31 | // https://dev.mysql.com/doc/internals/en/packet-ERR_Packet.html 32 | func UnPackERR(data []byte) error { 33 | var err error 34 | e := &ERR{} 35 | buf := common.ReadBuffer(data) 36 | if e.Header, err = buf.ReadU8(); err != nil { 37 | return sqldb.NewSQLErrorf(sqldb.ER_MALFORMED_PACKET, "invalid error packet header: %v", data) 38 | } 39 | if e.Header != ERR_PACKET { 40 | return sqldb.NewSQLErrorf(sqldb.ER_MALFORMED_PACKET, "invalid error packet header: %v", e.Header) 41 | } 42 | if e.ErrorCode, err = buf.ReadU16(); err != nil { 43 | return sqldb.NewSQLErrorf(sqldb.ER_MALFORMED_PACKET, "invalid error packet code: %v", data) 44 | } 45 | 46 | // Skip SQLStateMarker 47 | if _, err = buf.ReadString(1); err != nil { 48 | return sqldb.NewSQLErrorf(sqldb.ER_MALFORMED_PACKET, "invalid error packet marker: %v", data) 49 | } 50 | if e.SQLState, err = buf.ReadString(5); err != nil { 51 | return sqldb.NewSQLErrorf(sqldb.ER_MALFORMED_PACKET, "invalid error packet sqlstate: %v", data) 52 | } 53 | msgLen := len(data) - buf.Seek() 54 | if e.ErrorMessage, err = buf.ReadString(msgLen); err != nil { 55 | return sqldb.NewSQLErrorf(sqldb.ER_MALFORMED_PACKET, "invalid error packet message: %v", data) 56 | } 57 | return sqldb.NewSQLError1(e.ErrorCode, e.SQLState, "%s", e.ErrorMessage) 58 | } 59 | 60 | // PackERR packs e into an error packet. An empty e.SQLState is replaced by "HY000" on e itself, and a SQLState that is not 5 bytes long panics. 61 | func PackERR(e *ERR) []byte { 62 | buf := common.NewBuffer(64) 63 | 64 | buf.WriteU8(ERR_PACKET) 65 | 66 | // error code 67 | buf.WriteU16(e.ErrorCode) 68 | 69 | // sql-state marker # 70 | buf.WriteU8('#') 71 | 72 | // sql-state (?)
5 ascii bytes 73 | if e.SQLState == "" { 74 | e.SQLState = "HY000" 75 | } 76 | if len(e.SQLState) != 5 { 77 | panic("sqlState has to be 5 characters long") 78 | } 79 | buf.WriteString(e.SQLState) 80 | 81 | // error msg 82 | buf.WriteString(e.ErrorMessage) 83 | return buf.Datas() 84 | } 85 | -------------------------------------------------------------------------------- /proto/err_test.go: -------------------------------------------------------------------------------- 1 | /* 2 | * go-mysqlstack 3 | * xelabs.org 4 | * 5 | * Copyright (c) XeLabs 6 | * GPL License 7 | * 8 | */ 9 | 10 | package proto 11 | 12 | import ( 13 | "testing" 14 | 15 | "github.com/stretchr/testify/assert" 16 | "github.com/xelabs/go-mysqlstack/sqldb" 17 | 18 | "github.com/xelabs/go-mysqlstack/sqlparser/depends/common" 19 | ) 20 | 21 | func TestERR(t *testing.T) { 22 | { 23 | buff := common.NewBuffer(32) 24 | 25 | // header 26 | buff.WriteU8(0xff) 27 | // error_code 28 | buff.WriteU16(0x01) 29 | // sql_state_marker 30 | buff.WriteString("#") 31 | // sql_state, then error message 32 | buff.WriteString("ABCDE") 33 | buff.WriteString("ERROR") 34 | 35 | e := &ERR{} 36 | e.Header = 0xff 37 | e.ErrorCode = 0x1 38 | e.SQLState = "ABCDE" 39 | e.ErrorMessage = "ERROR" 40 | want := sqldb.NewSQLError1(e.ErrorCode, e.SQLState, "%s", e.ErrorMessage) 41 | got := UnPackERR(buff.Datas()) 42 | assert.Equal(t, want, got) 43 | } 44 | 45 | { 46 | e := &ERR{} 47 | e.Header = 0xff 48 | e.ErrorCode = 0x1 49 | e.ErrorMessage = "ERROR" 50 | datas := PackERR(e) 51 | want := sqldb.NewSQLError1(e.ErrorCode, e.SQLState, "%s", e.ErrorMessage) 52 | got := UnPackERR(datas) 53 | assert.Equal(t, want, got) 54 | } 55 | } 56 | 57 | func TestERRUnPackError(t *testing.T) { 58 | // header error 59 | { 60 | buff := common.NewBuffer(32) 61 | 62 | // header 63 | buff.WriteU8(0x01) 64 | 65 | err := UnPackERR(buff.Datas()) 66 | assert.NotNil(t, err) 67 | } 68 | 69 | // NULL 70 | f0 := func(buff *common.Buffer) { 71 | } 72 | 73 | // Write error header.
74 | f1 := func(buff *common.Buffer) { 75 | buff.WriteU8(0xff) 76 | } 77 | 78 | // Write error code. 79 | f2 := func(buff *common.Buffer) { 80 | buff.WriteU16(0x01) 81 | } 82 | 83 | // Write SQLStateMarker. 84 | f3 := func(buff *common.Buffer) { 85 | buff.WriteU8('#') 86 | } 87 | 88 | // Write SQLState. 89 | f4 := func(buff *common.Buffer) { 90 | buff.WriteString("xxxxx") 91 | } 92 | 93 | buff := common.NewBuffer(32) 94 | fs := []func(buff *common.Buffer){f0, f1, f2, f3, f4} 95 | for i := 0; i < len(fs); i++ { 96 | err := UnPackERR(buff.Datas()) 97 | assert.NotNil(t, err) 98 | fs[i](buff) 99 | } 100 | } 101 | -------------------------------------------------------------------------------- /proto/greeting_test.go: -------------------------------------------------------------------------------- 1 | /* 2 | * go-mysqlstack 3 | * xelabs.org 4 | * 5 | * Copyright (c) XeLabs 6 | * GPL License 7 | * 8 | */ 9 | 10 | package proto 11 | 12 | import ( 13 | "testing" 14 | 15 | "github.com/stretchr/testify/assert" 16 | "github.com/xelabs/go-mysqlstack/sqldb" 17 | 18 | "github.com/xelabs/go-mysqlstack/sqlparser/depends/common" 19 | ) 20 | 21 | func TestGreetingUnPack(t *testing.T) { 22 | want := NewGreeting(4, "") 23 | got := NewGreeting(4, "") 24 | 25 | // normal 26 | { 27 | want.authPluginName = "mysql_native_password" 28 | err := got.UnPack(want.Pack()) 29 | assert.Nil(t, err) 30 | assert.Equal(t, want, got) 31 | assert.Equal(t, sqldb.SERVER_STATUS_AUTOCOMMIT, int(got.Status())) 32 | } 33 | 34 | // 1. off sqldb.CLIENT_PLUGIN_AUTH 35 | { 36 | want.Capability = want.Capability &^ sqldb.CLIENT_PLUGIN_AUTH 37 | want.authPluginName = "mysql_native_password" 38 | err := got.UnPack(want.Pack()) 39 | assert.Nil(t, err) 40 | assert.Equal(t, want, got) 41 | } 42 | 43 | // 2.
off sqldb.CLIENT_SECURE_CONNECTION 44 | { 45 | want.Capability &= ^sqldb.CLIENT_SECURE_CONNECTION 46 | want.authPluginName = "mysql_native_password" 47 | err := got.UnPack(want.Pack()) 48 | assert.Nil(t, err) 49 | assert.Equal(t, want, got) 50 | } 51 | 52 | // 3. off sqldb.CLIENT_PLUGIN_AUTH && sqldb.CLIENT_SECURE_CONNECTION 53 | { 54 | want.Capability &= (^sqldb.CLIENT_PLUGIN_AUTH ^ sqldb.CLIENT_SECURE_CONNECTION) 55 | want.authPluginName = "mysql_native_password" 56 | err := got.UnPack(want.Pack()) 57 | assert.Nil(t, err) 58 | assert.Equal(t, want, got) 59 | } 60 | } 61 | 62 | func TestGreetingUnPackError(t *testing.T) { 63 | // NULL 64 | f0 := func(buff *common.Buffer) { 65 | } 66 | 67 | // Write protocol version. 68 | f1 := func(buff *common.Buffer) { 69 | buff.WriteU8(0x01) 70 | } 71 | 72 | // Write server version. 73 | f2 := func(buff *common.Buffer) { 74 | buff.WriteString("5.7.17-11") 75 | buff.WriteZero(1) 76 | } 77 | 78 | // Write connection ID. 79 | f3 := func(buff *common.Buffer) { 80 | buff.WriteU32(uint32(1)) 81 | } 82 | 83 | // Write salt[8]. 84 | f4 := func(buff *common.Buffer) { 85 | salt8 := make([]byte, 8) 86 | buff.WriteBytes(salt8) 87 | } 88 | 89 | // Write filler. 90 | f5 := func(buff *common.Buffer) { 91 | buff.WriteZero(1) 92 | } 93 | 94 | capability := DefaultServerCapability 95 | capLower := uint16(capability) 96 | capUpper := uint16(uint32(capability) >> 16) 97 | 98 | // Write capability lower 2 bytes 99 | f6 := func(buff *common.Buffer) { 100 | buff.WriteU16(capLower) 101 | } 102 | 103 | // Write charset.
104 | f7 := func(buff *common.Buffer) { 105 | buff.WriteU8(0x01) 106 | } 107 | 108 | // Write status flags 109 | f8 := func(buff *common.Buffer) { 110 | buff.WriteU16(uint16(1)) 111 | } 112 | 113 | // Write capability upper 2 bytes 114 | f9 := func(buff *common.Buffer) { 115 | buff.WriteU16(capUpper) 116 | } 117 | 118 | // Write length of auth-plugin 119 | f10 := func(buff *common.Buffer) { 120 | buff.WriteU8(0x01) 121 | } 122 | 123 | // Write reserved. 124 | f11 := func(buff *common.Buffer) { 125 | buff.WriteZero(10) 126 | } 127 | 128 | // Write auth plugin data part 2 129 | f12 := func(buff *common.Buffer) { 130 | data2 := make([]byte, 13) 131 | data2[12] = 0x01 132 | buff.WriteBytes(data2) 133 | } 134 | 135 | buff := common.NewBuffer(32) 136 | fs := []func(buff *common.Buffer){f0, f1, f2, f3, f4, f5, f6, f7, f8, f9, f10, f11, f12} 137 | for i := 0; i < len(fs); i++ { 138 | greeting := NewGreeting(0, "") 139 | err := greeting.UnPack(buff.Datas()) 140 | assert.NotNil(t, err) 141 | fs[i](buff) 142 | } 143 | 144 | { 145 | greeting := NewGreeting(0, "") 146 | err := greeting.UnPack(buff.Datas()) 147 | assert.NotNil(t, err) 148 | } 149 | } 150 | -------------------------------------------------------------------------------- /proto/ok.go: -------------------------------------------------------------------------------- 1 | /* 2 | * go-mysqlstack 3 | * xelabs.org 4 | * 5 | * Copyright (c) XeLabs 6 | * GPL License 7 | * 8 | */ 9 | 10 | package proto 11 | 12 | import ( 13 | "github.com/xelabs/go-mysqlstack/sqldb" 14 | "github.com/xelabs/go-mysqlstack/sqlparser/depends/common" 15 | ) 16 | 17 | const ( 18 | // OK_PACKET is the header byte of an OK packet. 19 | OK_PACKET byte = 0x00 20 | ) 21 | 22 | // OK holds the fields of an OK packet. 23 | type OK struct { 24 | Header byte // 0x00 25 | AffectedRows uint64 26 | LastInsertID uint64 27 | StatusFlags uint16 28 | Warnings uint16 29 | } 30 | 31 | // UnPackOK parses an OK packet payload; it fails if the header byte is not OK_PACKET or a field is truncated.
32 | // https://dev.mysql.com/doc/internals/en/packet-OK_Packet.html 33 | func UnPackOK(data []byte) (*OK, error) { 34 | var err error 35 | o := &OK{} 36 | buf := common.ReadBuffer(data) 37 | 38 | // header, must be OK_PACKET 39 | if o.Header, err = buf.ReadU8(); err != nil { 40 | return nil, sqldb.NewSQLErrorf(sqldb.ER_MALFORMED_PACKET, "invalid ok packet header: %v", data) 41 | } 42 | if o.Header != OK_PACKET { 43 | return nil, sqldb.NewSQLErrorf(sqldb.ER_MALFORMED_PACKET, "invalid ok packet header: %v", o.Header) 44 | } 45 | 46 | // AffectedRows (length-encoded) 47 | if o.AffectedRows, err = buf.ReadLenEncode(); err != nil { 48 | return nil, sqldb.NewSQLErrorf(sqldb.ER_MALFORMED_PACKET, "invalid ok packet affectedrows: %v", data) 49 | } 50 | 51 | // LastInsertID (length-encoded) 52 | if o.LastInsertID, err = buf.ReadLenEncode(); err != nil { 53 | return nil, sqldb.NewSQLErrorf(sqldb.ER_MALFORMED_PACKET, "invalid ok packet lastinsertid: %v", data) 54 | } 55 | 56 | // Status 57 | if o.StatusFlags, err = buf.ReadU16(); err != nil { 58 | return nil, sqldb.NewSQLErrorf(sqldb.ER_MALFORMED_PACKET, "invalid ok packet statusflags: %v", data) 59 | } 60 | 61 | // Warnings 62 | if o.Warnings, err = buf.ReadU16(); err != nil { 63 | return nil, sqldb.NewSQLErrorf(sqldb.ER_MALFORMED_PACKET, "invalid ok packet warnings: %v", data) 64 | } 65 | return o, nil 66 | } 67 | 68 | // PackOK serializes o as an OK packet: header, affected rows, last insert id, status flags and warnings.
69 | func PackOK(o *OK) []byte { 70 | buf := common.NewBuffer(64) 71 | 72 | // OK 73 | buf.WriteU8(OK_PACKET) 74 | 75 | // affected rows 76 | buf.WriteLenEncode(o.AffectedRows) 77 | 78 | // last insert id 79 | buf.WriteLenEncode(o.LastInsertID) 80 | 81 | // status 82 | buf.WriteU16(o.StatusFlags) 83 | 84 | // warnings 85 | buf.WriteU16(o.Warnings) 86 | return buf.Datas() 87 | } 88 | -------------------------------------------------------------------------------- /proto/ok_test.go: -------------------------------------------------------------------------------- 1 | /* 2 | * go-mysqlstack 3 | * xelabs.org 4 | * 5 | * Copyright (c) XeLabs 6 | * GPL License 7 | * 8 | */ 9 | 10 | package proto 11 | 12 | import ( 13 | "testing" 14 | 15 | "github.com/stretchr/testify/assert" 16 | 17 | "github.com/xelabs/go-mysqlstack/sqlparser/depends/common" 18 | ) 19 | 20 | func TestOK(t *testing.T) { 21 | { 22 | buff := common.NewBuffer(32) 23 | 24 | // header 25 | buff.WriteU8(0x00) 26 | // affected_rows 27 | buff.WriteLenEncode(uint64(3)) 28 | // last_insert_id 29 | buff.WriteLenEncode(uint64(40000000000)) 30 | 31 | // status_flags 32 | buff.WriteU16(0x01) 33 | // warnings 34 | buff.WriteU16(0x02) 35 | 36 | want := &OK{} 37 | want.AffectedRows = 3 38 | want.LastInsertID = 40000000000 39 | want.StatusFlags = 1 40 | want.Warnings = 2 41 | 42 | got, err := UnPackOK(buff.Datas()) 43 | assert.Nil(t, err) 44 | assert.Equal(t, want, got) 45 | } 46 | 47 | { 48 | want := &OK{} 49 | want.AffectedRows = 3 50 | want.LastInsertID = 40000000000 51 | want.StatusFlags = 1 52 | want.Warnings = 2 53 | datas := PackOK(want) 54 | 55 | got, err := UnPackOK(datas) 56 | assert.Nil(t, err) 57 | assert.Equal(t, want, got) 58 | } 59 | } 60 | 61 | func TestOKUnPackError(t *testing.T) { 62 | // header error 63 | { 64 | buff := common.NewBuffer(32) 65 | // header 66 | buff.WriteU8(0x99) 67 | _, err := UnPackOK(buff.Datas()) 68 | assert.NotNil(t, err) 69 | } 70 | 71 | // NULL 72 | f0 := func(buff *common.Buffer) 
{ 73 | } 74 | 75 | // Write OK header. 76 | f1 := func(buff *common.Buffer) { 77 | buff.WriteU8(0x00) 78 | } 79 | 80 | // Write AffectedRows. 81 | f2 := func(buff *common.Buffer) { 82 | buff.WriteLenEncode(uint64(3)) 83 | } 84 | 85 | // Write LastInsertID. 86 | f3 := func(buff *common.Buffer) { 87 | buff.WriteLenEncode(uint64(3)) 88 | } 89 | 90 | // Write Status. 91 | f4 := func(buff *common.Buffer) { 92 | buff.WriteU16(0x01) 93 | } 94 | 95 | buff := common.NewBuffer(32) 96 | fs := []func(buff *common.Buffer){f0, f1, f2, f3, f4} 97 | for i := 0; i < len(fs); i++ { 98 | _, err := UnPackOK(buff.Datas()) 99 | assert.NotNil(t, err) 100 | fs[i](buff) 101 | } 102 | 103 | { 104 | _, err := UnPackOK(buff.Datas()) 105 | assert.NotNil(t, err) 106 | } 107 | } 108 | -------------------------------------------------------------------------------- /proto/statement_test.go: -------------------------------------------------------------------------------- 1 | /* 2 | * go-mysqlstack 3 | * xelabs.org 4 | * 5 | * Copyright (c) XeLabs 6 | * GPL License 7 | * 8 | */ 9 | 10 | package proto 11 | 12 | import ( 13 | "errors" 14 | "testing" 15 | "time" 16 | 17 | "github.com/stretchr/testify/assert" 18 | 19 | "github.com/xelabs/go-mysqlstack/sqlparser/depends/common" 20 | querypb "github.com/xelabs/go-mysqlstack/sqlparser/depends/query" 21 | "github.com/xelabs/go-mysqlstack/sqlparser/depends/sqltypes" 22 | ) 23 | 24 | func TestStatementPrepare(t *testing.T) { 25 | want := &Statement{ 26 | ID: 5, 27 | ColumnCount: 2, 28 | ParamCount: 3, 29 | Warnings: 1, 30 | } 31 | datas := PackStatementPrepare(want) 32 | got, err := UnPackStatementPrepare(datas) 33 | assert.Nil(t, err) 34 | assert.Equal(t, want, got) 35 | } 36 | 37 | func TestStatementPrepareUnPackError(t *testing.T) { 38 | // NULL 39 | f0 := func(buff *common.Buffer) { 40 | } 41 | 42 | // Write ok. 43 | f1 := func(buff *common.Buffer) { 44 | buff.WriteU8(OK_PACKET) 45 | } 46 | 47 | // Write ID. 
48 | f2 := func(buff *common.Buffer) { 49 | buff.WriteU32(1) 50 | } 51 | 52 | // Write Column count. 53 | f3 := func(buff *common.Buffer) { 54 | buff.WriteU16(1) 55 | } 56 | 57 | // Write param count. 58 | f4 := func(buff *common.Buffer) { 59 | buff.WriteU16(2) 60 | } 61 | 62 | // Write reserved. 63 | f5 := func(buff *common.Buffer) { 64 | buff.WriteU8(2) 65 | } 66 | 67 | f6 := func(buff *common.Buffer) { 68 | buff.WriteU8(2) 69 | } 70 | 71 | buff := common.NewBuffer(32) 72 | fs := []func(buff *common.Buffer){f0, f1, f2, f3, f4, f5, f6} 73 | for i := 0; i < len(fs); i++ { 74 | _, err := UnPackStatementPrepare(buff.Datas()) 75 | assert.NotNil(t, err) 76 | fs[i](buff) 77 | } 78 | } 79 | 80 | func TestStatementExecute(t *testing.T) { 81 | id := uint32(11) 82 | values := []sqltypes.Value{ 83 | sqltypes.MakeTrusted(sqltypes.Int32, []byte("10")), 84 | sqltypes.MakeTrusted(sqltypes.VarChar, []byte("xx10xx")), 85 | sqltypes.MakeTrusted(sqltypes.Null, nil), 86 | sqltypes.MakeTrusted(sqltypes.Text, []byte{}), 87 | sqltypes.MakeTrusted(sqltypes.Datetime, []byte(time.Now().Format("2006-01-02 15:04:05"))), 88 | } 89 | 90 | datas, err := PackStatementExecute(id, values) 91 | assert.Nil(t, err) 92 | 93 | parseFn := func(*common.Buffer, querypb.Type) (interface{}, error) { 94 | return nil, nil 95 | } 96 | 97 | protoStmt := &Statement{ 98 | ID: id, 99 | ParamCount: uint16(len(values)), 100 | ParamsType: make([]int32, len(values)), 101 | BindVars: make(map[string]*querypb.BindVariable, len(values)), 102 | } 103 | err = UnPackStatementExecute(datas, protoStmt, parseFn) 104 | assert.Nil(t, err) 105 | } 106 | 107 | func TestStatementExecuteUnPackError(t *testing.T) { 108 | // NULL 109 | f0 := func(buff *common.Buffer) { 110 | } 111 | 112 | // Write ID. 113 | f1 := func(buff *common.Buffer) { 114 | buff.WriteU32(1) 115 | } 116 | 117 | // Cursor type. 118 | f2 := func(buff *common.Buffer) { 119 | buff.WriteU8(1) 120 | } 121 | 122 | // Iteration count. 
123 | f3 := func(buff *common.Buffer) { 124 | buff.WriteU32(1) 125 | } 126 | 127 | // Write param count. 128 | f4 := func(buff *common.Buffer) { 129 | buff.WriteU16(2) 130 | } 131 | 132 | // Write null bits. 133 | f5 := func(buff *common.Buffer) { 134 | buff.WriteBytes([]byte{0x00}) 135 | } 136 | 137 | // newParameterBoundFlag. 138 | f6 := func(buff *common.Buffer) { 139 | buff.WriteU8(0x01) 140 | } 141 | 142 | parseFn := func(*common.Buffer, querypb.Type) (interface{}, error) { 143 | return nil, errors.New("mock.error") 144 | } 145 | 146 | buff := common.NewBuffer(32) 147 | fs := []func(buff *common.Buffer){f0, f1, f2, f3, f4, f5, f6} 148 | for i := 0; i < len(fs); i++ { 149 | 150 | protoStmt := &Statement{ 151 | ID: 1, 152 | ParamCount: 2, 153 | ParamsType: make([]int32, 2), 154 | BindVars: make(map[string]*querypb.BindVariable, 2), 155 | } 156 | 157 | err := UnPackStatementExecute(buff.Datas(), protoStmt, parseFn) 158 | assert.NotNil(t, err) 159 | fs[i](buff) 160 | } 161 | } 162 | 163 | // issue 462. 
164 | // https://dev.mysql.com/doc/internals/en/com-stmt-execute.html 165 | // test about new-params-bound-flag about 0 1 166 | func TestStatementExecuteBatchUnPackStatementExecute(t *testing.T) { 167 | data := []byte{ /*23,*/ 18, 0, 0, 0, 128, 1, 0, 0, 0, 0, 1, 1, 128, 1} 168 | data2 := []byte{ /*23,*/ 18, 0, 0, 0, 128, 1, 0, 0, 0, 0, 0, 1, 128, 1} 169 | 170 | var dataBatch [][]byte 171 | dataBatch = append(dataBatch, data) 172 | dataBatch = append(dataBatch, data2) 173 | 174 | parseFn := func(*common.Buffer, querypb.Type) (interface{}, error) { 175 | return nil, nil 176 | } 177 | 178 | protoStmt := &Statement{ 179 | ID: 23, 180 | ParamCount: 1, 181 | ParamsType: make([]int32, 1), 182 | BindVars: make(map[string]*querypb.BindVariable, 1), 183 | } 184 | err := UnPackStatementExecute(dataBatch[0], protoStmt, parseFn) 185 | assert.Nil(t, err) 186 | 187 | err = UnPackStatementExecute(dataBatch[1], protoStmt, parseFn) 188 | assert.Nil(t, err) 189 | } 190 | -------------------------------------------------------------------------------- /sqldb/constants_test.go: -------------------------------------------------------------------------------- 1 | /* 2 | * go-mysqlstack 3 | * xelabs.org 4 | * 5 | * Copyright (c) XeLabs 6 | * GPL License 7 | * 8 | */ 9 | 10 | package sqldb 11 | 12 | import ( 13 | "testing" 14 | ) 15 | 16 | func TestConstants(t *testing.T) { 17 | var i byte 18 | for i = 0; i < COM_RESET_CONNECTION+2; i++ { 19 | CommandString(i) 20 | } 21 | } 22 | -------------------------------------------------------------------------------- /sqldb/sql_error.go: -------------------------------------------------------------------------------- 1 | // Copyright 2015, Google Inc. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 
4 | 5 | package sqldb 6 | 7 | import ( 8 | "bytes" 9 | "fmt" 10 | "regexp" 11 | "strconv" 12 | ) 13 | 14 | const ( 15 | // SQLStateGeneral is the SQLSTATE value for "general error". 16 | SQLStateGeneral = "HY000" 17 | ) 18 | 19 | // SQLError is the error structure returned from calling a db library function 20 | type SQLError struct { 21 | Num uint16 22 | State string 23 | Message string 24 | Query string 25 | } 26 | 27 | // NewSQLError creates new sql error. 28 | func NewSQLError(number uint16, args ...interface{}) *SQLError { 29 | sqlErr := &SQLError{} 30 | err, ok := SQLErrors[number] 31 | if !ok { 32 | unknow := SQLErrors[ER_UNKNOWN_ERROR] 33 | sqlErr.Num = unknow.Num 34 | sqlErr.State = unknow.State 35 | err = unknow 36 | } else { 37 | sqlErr.Num = err.Num 38 | sqlErr.State = err.State 39 | } 40 | sqlErr.Message = fmt.Sprintf(err.Message, args...) 41 | return sqlErr 42 | } 43 | 44 | func NewSQLErrorf(number uint16, format string, args ...interface{}) *SQLError { 45 | sqlErr := &SQLError{} 46 | err, ok := SQLErrors[number] 47 | if !ok { 48 | unknow := SQLErrors[ER_UNKNOWN_ERROR] 49 | sqlErr.Num = unknow.Num 50 | sqlErr.State = unknow.State 51 | } else { 52 | sqlErr.Num = err.Num 53 | sqlErr.State = err.State 54 | } 55 | sqlErr.Message = fmt.Sprintf(format, args...) 56 | return sqlErr 57 | } 58 | 59 | // NewSQLError1 creates new sql error with state. 60 | func NewSQLError1(number uint16, state string, format string, args ...interface{}) *SQLError { 61 | return &SQLError{ 62 | Num: number, 63 | State: state, 64 | Message: fmt.Sprintf(format, args...), 65 | } 66 | } 67 | 68 | // Error implements the error interface 69 | func (se *SQLError) Error() string { 70 | buf := &bytes.Buffer{} 71 | buf.WriteString(se.Message) 72 | 73 | // Add MySQL errno and SQLSTATE in a format that we can later parse. 74 | // There's no avoiding string parsing because all errors 75 | // are converted to strings anyway at RPC boundaries. 76 | // See NewSQLErrorFromError. 
77 | fmt.Fprintf(buf, " (errno %v) (sqlstate %v)", se.Num, se.State) 78 | 79 | if se.Query != "" { 80 | fmt.Fprintf(buf, " during query: %s", se.Query) 81 | } 82 | return buf.String() 83 | } 84 | 85 | var errExtract = regexp.MustCompile(`.*\(errno ([0-9]*)\) \(sqlstate ([0-9a-zA-Z]{5})\).*`) 86 | 87 | // NewSQLErrorFromError returns a *SQLError from the provided error. 88 | // If it's not the right type, it still tries to get it from a regexp. 89 | func NewSQLErrorFromError(err error) error { 90 | if err == nil { 91 | return nil 92 | } 93 | 94 | if serr, ok := err.(*SQLError); ok { 95 | return serr 96 | } 97 | 98 | msg := err.Error() 99 | match := errExtract.FindStringSubmatch(msg) 100 | if len(match) < 2 { 101 | // Not found, build a generic SQLError. 102 | // TODO(alainjobart) maybe we can also check the canonical 103 | // error code, and translate that into the right error. 104 | 105 | // FIXME(alainjobart): 1105 is unknown error. Will 106 | // merge with sqlconn later. 107 | unknow := SQLErrors[ER_UNKNOWN_ERROR] 108 | return &SQLError{ 109 | Num: unknow.Num, 110 | State: unknow.State, 111 | Message: msg, 112 | } 113 | } 114 | 115 | num, err := strconv.Atoi(match[1]) 116 | if err != nil { 117 | unknow := SQLErrors[ER_UNKNOWN_ERROR] 118 | return &SQLError{ 119 | Num: unknow.Num, 120 | State: unknow.State, 121 | Message: msg, 122 | } 123 | } 124 | 125 | serr := &SQLError{ 126 | Num: uint16(num), 127 | State: match[2], 128 | Message: msg, 129 | } 130 | return serr 131 | } 132 | -------------------------------------------------------------------------------- /sqldb/sql_error_test.go: -------------------------------------------------------------------------------- 1 | /* 2 | * go-mysqlstack 3 | * xelabs.org 4 | * 5 | * Copyright (c) XeLabs 6 | * GPL License 7 | * 8 | */ 9 | 10 | package sqldb 11 | 12 | import ( 13 | "testing" 14 | 15 | "errors" 16 | "github.com/stretchr/testify/assert" 17 | ) 18 | 19 | func TestSqlError(t *testing.T) { 20 | { 21 | sqlerr := 
NewSQLError(1, "i.am.error.man") 22 | assert.Equal(t, "i.am.error.man (errno 1105) (sqlstate HY000)", sqlerr.Error()) 23 | } 24 | 25 | { 26 | sqlerr := NewSQLErrorf(1, "i.am.error.man%s", "xx") 27 | assert.Equal(t, "i.am.error.manxx (errno 1105) (sqlstate HY000)", sqlerr.Error()) 28 | } 29 | 30 | { 31 | sqlerr := NewSQLError(ER_NO_DB_ERROR) 32 | assert.Equal(t, "No database selected (errno 1046) (sqlstate 3D000)", sqlerr.Error()) 33 | } 34 | } 35 | 36 | func TestSqlErrorFromErr(t *testing.T) { 37 | { 38 | err := errors.New("errorman") 39 | sqlerr := NewSQLErrorFromError(err) 40 | assert.NotNil(t, sqlerr) 41 | } 42 | 43 | { 44 | err := errors.New("i.am.error.man (errno 1) (sqlstate HY000)") 45 | sqlerr := NewSQLErrorFromError(err) 46 | assert.NotNil(t, sqlerr) 47 | } 48 | 49 | { 50 | err := errors.New("No database selected (errno 1046) (sqlstate 3D000)") 51 | want := &SQLError{Num: 1046, State: "3D000", Message: "No database selected (errno 1046) (sqlstate 3D000)"} 52 | got := NewSQLErrorFromError(err) 53 | assert.Equal(t, want, got) 54 | } 55 | 56 | { 57 | err := NewSQLError1(10086, "xx", "i.am.the.error.man.%s", "xx") 58 | want := &SQLError{Num: 10086, State: "xx", Message: "i.am.the.error.man.xx"} 59 | got := NewSQLErrorFromError(err) 60 | assert.Equal(t, want, got) 61 | } 62 | } 63 | -------------------------------------------------------------------------------- /sqlparser/.idea/misc.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 6 | -------------------------------------------------------------------------------- /sqlparser/.idea/modules.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /sqlparser/.idea/sqlparser.iml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 
-------------------------------------------------------------------------------- /sqlparser/.idea/vcs.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /sqlparser/Makefile: -------------------------------------------------------------------------------- 1 | # Copyright 2012, Google Inc. All rights reserved. 2 | # Use of this source code is governed by a BSD-style license that can 3 | # be found in the LICENSE file. 4 | 5 | MAKEFLAGS = -s 6 | 7 | sql.go: sql.y 8 | goyacc -o sql.go sql.y 9 | 10 | visitor: 11 | go generate rewriter.go 12 | 13 | clean: 14 | rm -f y.output sql.go 15 | -------------------------------------------------------------------------------- /sqlparser/analyzer.go: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2017 Google Inc. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | */ 16 | 17 | package sqlparser 18 | 19 | // analyzer.go contains utility analysis functions. 20 | 21 | import ( 22 | "errors" 23 | "fmt" 24 | "strings" 25 | "unicode" 26 | 27 | "github.com/xelabs/go-mysqlstack/sqlparser/depends/sqltypes" 28 | ) 29 | 30 | // These constants are used to identify the SQL statement type. 
31 | const ( 32 | StmtSelect = iota 33 | StmtInsert 34 | StmtReplace 35 | StmtUpdate 36 | StmtDelete 37 | StmtDDL 38 | StmtBegin 39 | StmtCommit 40 | StmtRollback 41 | StmtSet 42 | StmtShow 43 | StmtUse 44 | StmtOther 45 | StmtUnknown 46 | ) 47 | 48 | // Preview analyzes the beginning of the query using a simpler and faster 49 | // textual comparison to identify the statement type. 50 | func Preview(sql string) int { 51 | trimmed := StripLeadingComments(sql) 52 | 53 | firstWord := trimmed 54 | if end := strings.IndexFunc(trimmed, unicode.IsSpace); end != -1 { 55 | firstWord = trimmed[:end] 56 | } 57 | 58 | // Comparison is done in order of priority. 59 | loweredFirstWord := strings.ToLower(firstWord) 60 | switch loweredFirstWord { 61 | case "select": 62 | return StmtSelect 63 | case "insert": 64 | return StmtInsert 65 | case "replace": 66 | return StmtReplace 67 | case "update": 68 | return StmtUpdate 69 | case "delete": 70 | return StmtDelete 71 | } 72 | switch strings.ToLower(trimmed) { 73 | case "begin", "start transaction": 74 | return StmtBegin 75 | case "commit": 76 | return StmtCommit 77 | case "rollback": 78 | return StmtRollback 79 | } 80 | switch loweredFirstWord { 81 | case "create", "alter", "rename", "drop": 82 | return StmtDDL 83 | case "set": 84 | return StmtSet 85 | case "show": 86 | return StmtShow 87 | case "use": 88 | return StmtUse 89 | case "analyze", "describe", "desc", "explain", "repair", "optimize", "truncate": 90 | return StmtOther 91 | } 92 | return StmtUnknown 93 | } 94 | 95 | // IsDML returns true if the query is an INSERT, UPDATE or DELETE statement. 96 | func IsDML(sql string) bool { 97 | switch Preview(sql) { 98 | case StmtInsert, StmtReplace, StmtUpdate, StmtDelete: 99 | return true 100 | } 101 | return false 102 | } 103 | 104 | // GetTableName returns the table name from the SimpleTableExpr 105 | // only if it's a simple expression. Otherwise, it returns "". 
106 | func GetTableName(node SimpleTableExpr) TableIdent { 107 | if n, ok := node.(TableName); ok && n.Qualifier.IsEmpty() { 108 | return n.Name 109 | } 110 | // sub-select or '.' expression 111 | return NewTableIdent("") 112 | } 113 | 114 | // IsColName returns true if the Expr is a *ColName. 115 | func IsColName(node Expr) bool { 116 | _, ok := node.(*ColName) 117 | return ok 118 | } 119 | 120 | // IsValue returns true if the Expr is a string, integral or value arg. 121 | // NULL is not considered to be a value. 122 | func IsValue(node Expr) bool { 123 | switch v := node.(type) { 124 | case *SQLVal: 125 | switch v.Type { 126 | case StrVal, HexVal, IntVal, ValArg: 127 | return true 128 | } 129 | case *ValuesFuncExpr: 130 | if v.Resolved != nil { 131 | return IsValue(v.Resolved) 132 | } 133 | } 134 | return false 135 | } 136 | 137 | // IsNull returns true if the Expr is SQL NULL 138 | func IsNull(node Expr) bool { 139 | switch node.(type) { 140 | case *NullVal: 141 | return true 142 | } 143 | return false 144 | } 145 | 146 | // IsSimpleTuple returns true if the Expr is a ValTuple that 147 | // contains simple values or if it's a list arg. 148 | func IsSimpleTuple(node Expr) bool { 149 | switch vals := node.(type) { 150 | case ValTuple: 151 | for _, n := range vals { 152 | if !IsValue(n) { 153 | return false 154 | } 155 | } 156 | return true 157 | case ListArg: 158 | return true 159 | } 160 | // It's a subquery 161 | return false 162 | } 163 | 164 | // NewPlanValue builds a sqltypes.PlanValue from an Expr. 
165 | func NewPlanValue(node Expr) (sqltypes.PlanValue, error) { 166 | switch node := node.(type) { 167 | case *SQLVal: 168 | switch node.Type { 169 | case ValArg: 170 | return sqltypes.PlanValue{Key: string(node.Val[1:])}, nil 171 | case IntVal: 172 | n, err := sqltypes.NewIntegral(string(node.Val)) 173 | if err != nil { 174 | return sqltypes.PlanValue{}, err 175 | } 176 | return sqltypes.PlanValue{Value: n}, nil 177 | case StrVal: 178 | return sqltypes.PlanValue{Value: sqltypes.MakeTrusted(sqltypes.VarBinary, node.Val)}, nil 179 | case HexVal: 180 | v, err := node.HexDecode() 181 | if err != nil { 182 | return sqltypes.PlanValue{}, err 183 | } 184 | return sqltypes.PlanValue{Value: sqltypes.MakeTrusted(sqltypes.VarBinary, v)}, nil 185 | } 186 | case ListArg: 187 | return sqltypes.PlanValue{ListKey: string(node[2:])}, nil 188 | case ValTuple: 189 | pv := sqltypes.PlanValue{ 190 | Values: make([]sqltypes.PlanValue, 0, len(node)), 191 | } 192 | for _, val := range node { 193 | innerpv, err := NewPlanValue(val) 194 | if err != nil { 195 | return sqltypes.PlanValue{}, err 196 | } 197 | if innerpv.ListKey != "" || innerpv.Values != nil { 198 | return sqltypes.PlanValue{}, errors.New("unsupported: nested lists") 199 | } 200 | pv.Values = append(pv.Values, innerpv) 201 | } 202 | return pv, nil 203 | case *ValuesFuncExpr: 204 | if node.Resolved != nil { 205 | return NewPlanValue(node.Resolved) 206 | } 207 | case *NullVal: 208 | return sqltypes.PlanValue{}, nil 209 | } 210 | return sqltypes.PlanValue{}, fmt.Errorf("expression is too complex '%v'", String(node)) 211 | } 212 | 213 | // StringIn is a convenience function that returns 214 | // true if str matches any of the values. 
215 | func StringIn(str string, values ...string) bool { 216 | for _, val := range values { 217 | if str == val { 218 | return true 219 | } 220 | } 221 | return false 222 | } 223 | -------------------------------------------------------------------------------- /sqlparser/checksum_test.go: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2017 Google Inc. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | */ 16 | 17 | package sqlparser 18 | 19 | import "strings" 20 | import "testing" 21 | 22 | func TestChecksumTable(t *testing.T) { 23 | validSQL := []struct { 24 | input string 25 | output string 26 | }{ 27 | { 28 | input: "checksum table test.t1", 29 | output: "checksum table test.t1", 30 | }, 31 | 32 | { 33 | input: "checksum table t1", 34 | output: "checksum table t1", 35 | }, 36 | } 37 | 38 | for _, s := range validSQL { 39 | sql := strings.TrimSpace(s.input) 40 | tree, err := Parse(sql) 41 | if err != nil { 42 | t.Errorf("input: %s, err: %v", sql, err) 43 | continue 44 | } 45 | 46 | // Walk. 
47 | Walk(func(node SQLNode) (bool, error) { 48 | return true, nil 49 | }, tree) 50 | 51 | got := String(tree.(*Checksum)) 52 | if s.output != got { 53 | t.Errorf("want:\n%s\ngot:\n%s", s.output, got) 54 | } 55 | } 56 | } 57 | -------------------------------------------------------------------------------- /sqlparser/depends/bytes2/buffer.go: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2017 Google Inc. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | */ 16 | 17 | package bytes2 18 | 19 | // Buffer implements a subset of the write portion of 20 | // bytes.Buffer, but more efficiently. This is meant to 21 | // be used in very high QPS operations, especially for 22 | // WriteByte, and without abstracting it as a Writer. 23 | // Function signatures contain errors for compatibility, 24 | // but they do not return errors. 25 | type Buffer struct { 26 | bytes []byte 27 | } 28 | 29 | // NewBuffer is equivalent to bytes.NewBuffer. 30 | func NewBuffer(b []byte) *Buffer { 31 | return &Buffer{bytes: b} 32 | } 33 | 34 | // Write is equivalent to bytes.Buffer.Write. 35 | func (buf *Buffer) Write(b []byte) (int, error) { 36 | buf.bytes = append(buf.bytes, b...) 37 | return len(b), nil 38 | } 39 | 40 | // WriteString is equivalent to bytes.Buffer.WriteString. 41 | func (buf *Buffer) WriteString(s string) (int, error) { 42 | buf.bytes = append(buf.bytes, s...) 
43 | return len(s), nil 44 | } 45 | 46 | // WriteByte is equivalent to bytes.Buffer.WriteByte. 47 | func (buf *Buffer) WriteByte(b byte) error { 48 | buf.bytes = append(buf.bytes, b) 49 | return nil 50 | } 51 | 52 | // Bytes is equivalent to bytes.Buffer.Bytes. 53 | func (buf *Buffer) Bytes() []byte { 54 | return buf.bytes 55 | } 56 | 57 | // Strings is equivalent to bytes.Buffer.Strings. 58 | func (buf *Buffer) String() string { 59 | return string(buf.bytes) 60 | } 61 | 62 | // Len is equivalent to bytes.Buffer.Len. 63 | func (buf *Buffer) Len() int { 64 | return len(buf.bytes) 65 | } 66 | -------------------------------------------------------------------------------- /sqlparser/depends/bytes2/buffer_test.go: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2017 Google Inc. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 
15 | */ 16 | 17 | package bytes2 18 | 19 | import "testing" 20 | 21 | func TestBuffer(t *testing.T) { 22 | b := NewBuffer(nil) 23 | b.Write([]byte("ab")) 24 | b.WriteString("cd") 25 | b.WriteByte('e') 26 | want := "abcde" 27 | if got := string(b.Bytes()); got != want { 28 | t.Errorf("b.Bytes(): %s, want %s", got, want) 29 | } 30 | if got := b.String(); got != want { 31 | t.Errorf("b.String(): %s, want %s", got, want) 32 | } 33 | if got := b.Len(); got != 5 { 34 | t.Errorf("b.Len(): %d, want 5", got) 35 | } 36 | } 37 | -------------------------------------------------------------------------------- /sqlparser/depends/common/hash_table.go: -------------------------------------------------------------------------------- 1 | /* 2 | * go-mysqlstack 3 | * xelabs.org 4 | * 5 | * Copyright (c) XeLabs 6 | * GPL License 7 | * 8 | */ 9 | 10 | package common 11 | 12 | import ( 13 | "bytes" 14 | "hash/fnv" 15 | ) 16 | 17 | // hash64a used to get bucket. 18 | func hash64a(data []byte) uint64 { 19 | h := fnv.New64a() 20 | h.Write(data) 21 | return h.Sum64() 22 | } 23 | 24 | type entry struct { 25 | // key slice. 26 | key []byte 27 | // value interface. 28 | value []interface{} 29 | // point to the next entry. 30 | next *entry 31 | } 32 | 33 | func (e *entry) put(key []byte, value interface{}) *entry { 34 | if e == nil { 35 | return &entry{key, []interface{}{value}, nil} 36 | } 37 | if bytes.Equal(e.key, key) { 38 | e.value = append(e.value, value) 39 | return e 40 | } 41 | 42 | e.next = e.next.put(key, value) 43 | return e 44 | } 45 | 46 | func (e *entry) get(key []byte) (bool, []interface{}) { 47 | if e == nil { 48 | return false, nil 49 | } else if bytes.Equal(e.key, key) { 50 | return true, e.value 51 | } else { 52 | return e.next.get(key) 53 | } 54 | } 55 | 56 | // HashTable the hash table. 57 | type HashTable struct { 58 | // stores value for a given key. 59 | hashEntry []*entry 60 | // k: bucket. v: index in the hashEntry. 61 | hashMap map[uint64]int 62 | // size of entrys. 
63 | size int 64 | } 65 | 66 | // NewHashTable create hash table. 67 | func NewHashTable() *HashTable { 68 | return &HashTable{ 69 | hashMap: make(map[uint64]int), 70 | size: 0, 71 | } 72 | } 73 | 74 | // Size used to get the hashtable size. 75 | func (h *HashTable) Size() int { 76 | return h.size 77 | } 78 | 79 | // Put puts the key/value pairs to the HashTable. 80 | func (h *HashTable) Put(key []byte, value interface{}) { 81 | var table *entry 82 | bucket := hash64a(key) 83 | index, ok := h.hashMap[bucket] 84 | if !ok { 85 | table = &entry{key, []interface{}{value}, nil} 86 | h.hashMap[bucket] = len(h.hashEntry) 87 | h.hashEntry = append(h.hashEntry, table) 88 | } else { 89 | h.hashEntry[index] = h.hashEntry[index].put(key, value) 90 | } 91 | h.size++ 92 | } 93 | 94 | // Get gets the values of the "key". 95 | func (h *HashTable) Get(key []byte) (bool, []interface{}) { 96 | bucket := hash64a(key) 97 | index, ok := h.hashMap[bucket] 98 | if !ok { 99 | return false, nil 100 | } 101 | return h.hashEntry[index].get(key) 102 | } 103 | 104 | // Iterator used to iterate the HashTable. 105 | type Iterator func() (key []byte, value []interface{}, next Iterator) 106 | 107 | // Next used to iterate the HashTable. 
108 | func (h *HashTable) Next() Iterator { 109 | var e *entry 110 | var iter Iterator 111 | table := h.hashEntry 112 | i := -1 113 | iter = func() (key []byte, val []interface{}, next Iterator) { 114 | for e == nil { 115 | i++ 116 | if i >= len(table) { 117 | return nil, nil, nil 118 | } 119 | e = table[i] 120 | } 121 | key = e.key 122 | val = e.value 123 | e = e.next 124 | return key, val, iter 125 | } 126 | return iter 127 | } 128 | -------------------------------------------------------------------------------- /sqlparser/depends/common/hash_table_test.go: -------------------------------------------------------------------------------- 1 | /* 2 | * go-mysqlstack 3 | * xelabs.org 4 | * 5 | * Copyright (c) XeLabs 6 | * GPL License 7 | * 8 | */ 9 | 10 | package common 11 | 12 | import ( 13 | "bytes" 14 | crand "crypto/rand" 15 | "testing" 16 | ) 17 | 18 | func TestHashKey(t *testing.T) { 19 | a := []byte("asdf") 20 | b := []byte("asdf") 21 | c := []byte("csfd") 22 | if !bytes.Equal(a, b) { 23 | t.Error("a != b") 24 | } 25 | if hash64a(a) != hash64a(b) { 26 | t.Error("hash64a(a) != hash64a(b)") 27 | } 28 | if bytes.Equal(a, c) { 29 | t.Error("a == c") 30 | } 31 | if hash64a(a) == hash64a(c) { 32 | t.Error("hash64a(a) == hash64a(c)") 33 | } 34 | } 35 | 36 | func randSlice(length int) []byte { 37 | slice := make([]byte, length) 38 | if _, err := crand.Read(slice); err != nil { 39 | panic(err) 40 | } 41 | return slice 42 | } 43 | 44 | func TestPutHasGetRemove(t *testing.T) { 45 | 46 | type record struct { 47 | key []byte 48 | val []byte 49 | } 50 | 51 | ranrec := func() *record { 52 | return &record{ 53 | randSlice(20), 54 | randSlice(20), 55 | } 56 | } 57 | 58 | table := NewHashTable() 59 | records := make([]*record, 400) 60 | var i int 61 | for i = range records { 62 | r := ranrec() 63 | records[i] = r 64 | table.Put(r.key, []byte("")) 65 | table.Put(r.key, r.val) 66 | 67 | if table.Size() != 2*(i+1) { 68 | t.Error("size was wrong", table.Size(), i+1) 69 | } 70 | } 
71 | 72 | for _, r := range records { 73 | if has, val := table.Get(r.key); !has { 74 | t.Error(table, "Missing key") 75 | } else if !bytes.Equal(val[1].([]byte), r.val) { 76 | t.Error("wrong value") 77 | } 78 | if has, _ := table.Get(randSlice(12)); has { 79 | t.Error("Table has extra key") 80 | } 81 | } 82 | } 83 | 84 | func TestIterate(t *testing.T) { 85 | table := NewHashTable() 86 | t.Logf("%T", table) 87 | for k, v, next := table.Next()(); next != nil; k, v, next = next() { 88 | t.Errorf("Should never reach here %v %v %v", k, v, next) 89 | } 90 | records := make(map[string][]byte) 91 | for i := 0; i < 100; i++ { 92 | v := randSlice(8) 93 | keySlice := []byte{0x01} 94 | keySlice = append(keySlice, v...) 95 | keySlice = append(keySlice, 0x02) 96 | k := BytesToString(keySlice) 97 | records[k] = v 98 | table.Put(v, k) 99 | if table.Size() != (i + 1) { 100 | t.Error("size was wrong", table.Size(), i+1) 101 | } 102 | } 103 | count := 0 104 | for k, v, next := table.Next()(); next != nil; k, v, next = next() { 105 | if v1, has := records[v[0].(string)]; !has { 106 | t.Error("bad key in table") 107 | } else if !bytes.Equal(k, v1) { 108 | t.Error("values don't equal") 109 | } 110 | count++ 111 | } 112 | if len(records) != count { 113 | t.Error("iterate missed records") 114 | } 115 | } 116 | -------------------------------------------------------------------------------- /sqlparser/depends/common/unsafe.go: -------------------------------------------------------------------------------- 1 | /* 2 | * go-mysqlstack 3 | * xelabs.org 4 | * 5 | * Copyright (c) XeLabs 6 | * GPL License 7 | * 8 | */ 9 | 10 | package common 11 | 12 | import ( 13 | "reflect" 14 | "unsafe" 15 | ) 16 | 17 | // BytesToString casts slice to string without copy 18 | func BytesToString(b []byte) (s string) { 19 | if len(b) == 0 { 20 | return "" 21 | } 22 | 23 | bh := (*reflect.SliceHeader)(unsafe.Pointer(&b)) 24 | sh := reflect.StringHeader{Data: bh.Data, Len: bh.Len} 25 | 26 | return 
*(*string)(unsafe.Pointer(&sh)) 27 | } 28 | 29 | // StringToBytes casts string to slice without copy 30 | func StringToBytes(s string) []byte { 31 | if len(s) == 0 { 32 | return []byte{} 33 | } 34 | 35 | sh := (*reflect.StringHeader)(unsafe.Pointer(&s)) 36 | bh := reflect.SliceHeader{Data: sh.Data, Len: sh.Len, Cap: sh.Len} 37 | 38 | return *(*[]byte)(unsafe.Pointer(&bh)) 39 | } 40 | -------------------------------------------------------------------------------- /sqlparser/depends/common/unsafe_test.go: -------------------------------------------------------------------------------- 1 | /* 2 | * go-mysqlstack 3 | * xelabs.org 4 | * 5 | * Copyright (c) XeLabs 6 | * GPL License 7 | * 8 | */ 9 | 10 | package common 11 | 12 | import ( 13 | "github.com/stretchr/testify/assert" 14 | "testing" 15 | ) 16 | 17 | func TestBytesToString(t *testing.T) { 18 | { 19 | bs := []byte{0x61, 0x62} 20 | want := "ab" 21 | got := BytesToString(bs) 22 | assert.Equal(t, want, got) 23 | } 24 | 25 | { 26 | bs := []byte{} 27 | want := "" 28 | got := BytesToString(bs) 29 | assert.Equal(t, want, got) 30 | } 31 | } 32 | 33 | func TestSting(t *testing.T) { 34 | { 35 | want := []byte{0x61, 0x62} 36 | got := StringToBytes("ab") 37 | assert.Equal(t, want, got) 38 | } 39 | 40 | { 41 | want := []byte{} 42 | got := StringToBytes("") 43 | assert.Equal(t, want, got) 44 | } 45 | } 46 | 47 | func TestStingToBytes(t *testing.T) { 48 | { 49 | want := []byte{0x53, 0x45, 0x4c, 0x45, 0x43, 0x54, 0x20, 0x2a, 0x20, 0x46, 0x52, 0x4f, 0x4d, 0x20, 0x74, 0x32} 50 | got := StringToBytes("SELECT * FROM t2") 51 | assert.Equal(t, want, got) 52 | } 53 | } 54 | -------------------------------------------------------------------------------- /sqlparser/depends/sqltypes/column.go: -------------------------------------------------------------------------------- 1 | // Copyright 2015, Google Inc. All rights reserved. 
2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | // 5 | // Copyright (c) XeLabs 6 | // BohuTANG 7 | 8 | package sqltypes 9 | 10 | import ( 11 | querypb "github.com/xelabs/go-mysqlstack/sqlparser/depends/query" 12 | ) 13 | 14 | // RemoveColumns used to remove columns who in the idxs. 15 | func (result *Result) RemoveColumns(idxs ...int) { 16 | c := len(idxs) 17 | if c == 0 { 18 | return 19 | } 20 | 21 | if result.Fields != nil { 22 | var fields []*querypb.Field 23 | for i, f := range result.Fields { 24 | in := false 25 | for _, idx := range idxs { 26 | if i == idx { 27 | in = true 28 | break 29 | } 30 | } 31 | if !in { 32 | fields = append(fields, f) 33 | } 34 | } 35 | result.Fields = fields 36 | } 37 | 38 | if result.Rows != nil { 39 | for i, r := range result.Rows { 40 | var row []Value 41 | for i, v := range r { 42 | in := false 43 | for _, idx := range idxs { 44 | if i == idx { 45 | in = true 46 | break 47 | } 48 | } 49 | if !in { 50 | row = append(row, v) 51 | } 52 | } 53 | result.Rows[i] = row 54 | } 55 | } 56 | } 57 | -------------------------------------------------------------------------------- /sqlparser/depends/sqltypes/column_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2015, Google Inc. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 
4 | // 5 | // Copyright (c) XeLabs 6 | // BohuTANG 7 | 8 | package sqltypes 9 | 10 | import ( 11 | "fmt" 12 | "reflect" 13 | "testing" 14 | 15 | querypb "github.com/xelabs/go-mysqlstack/sqlparser/depends/query" 16 | ) 17 | 18 | func TestColumnRemove(t *testing.T) { 19 | rt := &Result{ 20 | Fields: []*querypb.Field{{ 21 | Name: "a", 22 | Type: Int32, 23 | }, { 24 | Name: "b", 25 | Type: Uint24, 26 | }, { 27 | Name: "c", 28 | Type: Float32, 29 | }, 30 | }, 31 | Rows: [][]Value{ 32 | {testVal(Int32, "-5"), testVal(Uint64, "10"), testVal(Float32, "3.1415926")}, 33 | {testVal(Int32, "-4"), testVal(Uint64, "9"), testVal(Float32, "3.1415927")}, 34 | {testVal(Int32, "-3"), testVal(Uint64, "8"), testVal(Float32, "3.1415928")}, 35 | {testVal(Int32, "1"), testVal(Uint64, "1"), testVal(Float32, "3.1415926")}, 36 | {testVal(Int32, "1"), testVal(Uint64, "1"), testVal(Float32, "3.1415925")}, 37 | }, 38 | } 39 | 40 | { 41 | rs := rt.Copy() 42 | rs.RemoveColumns(0) 43 | { 44 | want := []*querypb.Field{ 45 | { 46 | Name: "b", 47 | Type: Uint24, 48 | }, { 49 | Name: "c", 50 | Type: Float32, 51 | }, 52 | } 53 | got := rs.Fields 54 | if !reflect.DeepEqual(want, got) { 55 | t.Errorf("want:%+v\n, got:%+v", want, got) 56 | } 57 | } 58 | 59 | { 60 | want := "[[10 3.1415926] [9 3.1415927] [8 3.1415928] [1 3.1415926] [1 3.1415925]]" 61 | got := fmt.Sprintf("%+v", rs.Rows) 62 | if want != got { 63 | t.Errorf("want:%s\n, got:%+s", want, got) 64 | } 65 | } 66 | } 67 | 68 | { 69 | rs := rt.Copy() 70 | rs.RemoveColumns(2) 71 | { 72 | want := []*querypb.Field{ 73 | { 74 | Name: "a", 75 | Type: Int32, 76 | }, { 77 | Name: "b", 78 | Type: Uint24, 79 | }, 80 | } 81 | got := rs.Fields 82 | if !reflect.DeepEqual(want, got) { 83 | t.Errorf("want:%+v\n, got:%+v", want, got) 84 | } 85 | } 86 | 87 | { 88 | want := "[[-5 10] [-4 9] [-3 8] [1 1] [1 1]]" 89 | got := fmt.Sprintf("%+v", rs.Rows) 90 | if want != got { 91 | t.Errorf("want:%s\n, got:%s", want, got) 92 | } 93 | } 94 | } 95 | 96 | { 97 | rs := 
rt.Copy() 98 | rs.RemoveColumns(0, 1) 99 | { 100 | want := []*querypb.Field{ 101 | { 102 | Name: "c", 103 | Type: Float32, 104 | }, 105 | } 106 | got := rs.Fields 107 | if !reflect.DeepEqual(want, got) { 108 | t.Errorf("want:%+v\n, got:%+v", want, got) 109 | } 110 | } 111 | 112 | { 113 | want := "[[3.1415926] [3.1415927] [3.1415928] [3.1415926] [3.1415925]]" 114 | got := fmt.Sprintf("%+v", rs.Rows) 115 | if want != got { 116 | t.Errorf("want:%s\n, got:%s", want, got) 117 | } 118 | } 119 | } 120 | } 121 | -------------------------------------------------------------------------------- /sqlparser/depends/sqltypes/const.go: -------------------------------------------------------------------------------- 1 | /* 2 | * Radon 3 | * 4 | * Copyright 2019 The Radon Authors. 5 | * Code is licensed under the GPLv3. 6 | * 7 | */ 8 | 9 | package sqltypes 10 | 11 | const ( 12 | // DecimalLongLongDigits decimal longlong digits. 13 | DecimalLongLongDigits = 22 14 | // FloatDigits float decimal precision. 15 | FloatDigits = 6 16 | // DoubleDigits double decimal precision. 17 | DoubleDigits = 15 18 | ) 19 | -------------------------------------------------------------------------------- /sqlparser/depends/sqltypes/limit.go: -------------------------------------------------------------------------------- 1 | // Copyright 2015, Google Inc. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | // 5 | // Copyright (c) XeLabs 6 | // BohuTANG 7 | 8 | package sqltypes 9 | 10 | // Limit used to cutoff the rows based on the MySQL LIMIT and OFFSET clauses. 
11 | func (result *Result) Limit(offset, limit int) { 12 | count := len(result.Rows) 13 | start := offset 14 | end := offset + limit 15 | if start > count { 16 | start = count 17 | } 18 | if end > count { 19 | end = count 20 | } 21 | result.Rows = result.Rows[start:end] 22 | } 23 | -------------------------------------------------------------------------------- /sqlparser/depends/sqltypes/limit_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2015, Google Inc. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | // 5 | // Copyright (c) XeLabs 6 | // BohuTANG 7 | 8 | package sqltypes 9 | 10 | import ( 11 | "reflect" 12 | "testing" 13 | ) 14 | 15 | func TestLimit(t *testing.T) { 16 | rs := &Result{ 17 | Rows: [][]Value{ 18 | {testVal(VarChar, "1")}, {testVal(VarChar, "2")}, {testVal(VarChar, "3")}, {testVal(VarChar, "4")}, {testVal(VarChar, "5")}, 19 | }, 20 | } 21 | 22 | // normal: offset 0, limit 1. 23 | { 24 | rs1 := rs.Copy() 25 | rs1.Limit(0, 1) 26 | want := rs.Rows[0:1] 27 | got := rs1.Rows 28 | 29 | if !reflect.DeepEqual(want, got) { 30 | t.Errorf("want:\n%#v, got\n%#v", want, got) 31 | } 32 | } 33 | 34 | // normal: offset 0, limit 5. 35 | { 36 | rs1 := rs.Copy() 37 | rs1.Limit(0, 5) 38 | want := rs.Rows 39 | got := rs1.Rows 40 | 41 | if !reflect.DeepEqual(want, got) { 42 | t.Errorf("want:\n%#v, got\n%#v", want, got) 43 | } 44 | } 45 | 46 | // normal: offset 1, limit 4. 47 | { 48 | rs1 := rs.Copy() 49 | rs1.Limit(1, 4) 50 | want := rs.Rows[1:5] 51 | got := rs1.Rows 52 | 53 | if !reflect.DeepEqual(want, got) { 54 | t.Errorf("want:\n%#v, got\n%#v", want, got) 55 | } 56 | } 57 | 58 | // limit overflow: offset 0, limit 6. 
59 | { 60 | rs1 := rs.Copy() 61 | rs1.Limit(0, 6) 62 | want := rs.Rows 63 | got := rs1.Rows 64 | 65 | if !reflect.DeepEqual(want, got) { 66 | t.Errorf("want:\n%#v, got\n%#v", want, got) 67 | } 68 | } 69 | 70 | // offset overflow: offset 5, limit 0. 71 | { 72 | rs1 := rs.Copy() 73 | rs1.Limit(5, 0) 74 | want := rs.Rows[5:5] 75 | got := rs1.Rows 76 | 77 | if !reflect.DeepEqual(want, got) { 78 | t.Errorf("want:\n%#v, got\n%#v", want, got) 79 | } 80 | } 81 | 82 | // (offset+limit) overflow: offset 3, limit 6. 83 | { 84 | rs1 := rs.Copy() 85 | rs1.Limit(3, 6) 86 | want := rs.Rows[3:5] 87 | got := rs1.Rows 88 | 89 | if !reflect.DeepEqual(want, got) { 90 | t.Errorf("want:\n%#v, got\n%#v", want, got) 91 | } 92 | } 93 | 94 | // Empty test. 95 | { 96 | rs1 := &Result{ 97 | Rows: [][]Value{ 98 | {}, 99 | }, 100 | } 101 | 102 | rs1.Limit(3, 6) 103 | want := rs.Rows[0:0] 104 | got := rs1.Rows 105 | 106 | if !reflect.DeepEqual(want, got) { 107 | t.Errorf("want:\n%#v, got\n%#v", want, got) 108 | } 109 | } 110 | } 111 | -------------------------------------------------------------------------------- /sqlparser/depends/sqltypes/plan_value.go: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2017 Google Inc. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 
15 | */ 16 | 17 | package sqltypes 18 | 19 | import () 20 | 21 | // PlanValue represents a value or a list of values for 22 | // a column that will later be resolved using bind vars and used 23 | // to perform plan actions like generating the final query or 24 | // deciding on a route. 25 | // 26 | // Plan values are typically used as a slice ([]planValue) 27 | // where each entry is for one column. For situations where 28 | // the required output is a list of rows (like in the case 29 | // of multi-value inserts), the representation is pivoted. 30 | // For example, a statement like this: 31 | // INSERT INTO t VALUES (1, 2), (3, 4) 32 | // will be represented as follows: 33 | // []PlanValue{ 34 | // Values: {1, 3}, 35 | // Values: {2, 4}, 36 | // } 37 | // 38 | // For WHERE clause items that contain a combination of 39 | // equality expressions and IN clauses like this: 40 | // WHERE pk1 = 1 AND pk2 IN (2, 3, 4) 41 | // The plan values will be represented as follows: 42 | // []PlanValue{ 43 | // Value: 1, 44 | // Values: {2, 3, 4}, 45 | // } 46 | // When converted into rows, columns with single values 47 | // are replicated as the same for all rows: 48 | // [][]Value{ 49 | // {1, 2}, 50 | // {1, 3}, 51 | // {1, 4}, 52 | // } 53 | type PlanValue struct { 54 | Key string 55 | Value Value 56 | ListKey string 57 | Values []PlanValue 58 | } 59 | -------------------------------------------------------------------------------- /sqlparser/depends/sqltypes/result.go: -------------------------------------------------------------------------------- 1 | // Copyright 2015, Google Inc. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | 5 | package sqltypes 6 | 7 | import ( 8 | querypb "github.com/xelabs/go-mysqlstack/sqlparser/depends/query" 9 | ) 10 | 11 | // ResultState enum. 12 | type ResultState int 13 | 14 | const ( 15 | // RStateNone enum. 
16 | RStateNone ResultState = iota 17 | // RStateFields enum. 18 | RStateFields 19 | // RStateRows enum. 20 | RStateRows 21 | // RStateFinished enum. 22 | RStateFinished 23 | ) 24 | 25 | // Result represents a query result. 26 | type Result struct { 27 | Fields []*querypb.Field `json:"fields"` 28 | RowsAffected uint64 `json:"rows_affected"` 29 | InsertID uint64 `json:"insert_id"` 30 | Warnings uint16 `json:"warnings"` 31 | Rows [][]Value `json:"rows"` 32 | Extras *querypb.ResultExtras `json:"extras"` 33 | State ResultState 34 | } 35 | 36 | // ResultStream is an interface for receiving Result. It is used for 37 | // RPC interfaces. 38 | type ResultStream interface { 39 | // Recv returns the next result on the stream. 40 | // It will return io.EOF if the stream ended. 41 | Recv() (*Result, error) 42 | } 43 | 44 | // Repair fixes the type info in the rows 45 | // to conform to the supplied field types. 46 | func (result *Result) Repair(fields []*querypb.Field) { 47 | // Usage of j is intentional. 48 | for j, f := range fields { 49 | for _, r := range result.Rows { 50 | if r[j].typ != Null { 51 | r[j].typ = f.Type 52 | } 53 | } 54 | } 55 | } 56 | 57 | // Copy creates a deep copy of Result. 58 | func (result *Result) Copy() *Result { 59 | out := &Result{ 60 | InsertID: result.InsertID, 61 | RowsAffected: result.RowsAffected, 62 | } 63 | if result.Fields != nil { 64 | fieldsp := make([]*querypb.Field, len(result.Fields)) 65 | fields := make([]querypb.Field, len(result.Fields)) 66 | for i, f := range result.Fields { 67 | fields[i] = *f 68 | fieldsp[i] = &fields[i] 69 | } 70 | out.Fields = fieldsp 71 | } 72 | if result.Rows != nil { 73 | rows := make([][]Value, len(result.Rows)) 74 | for i, r := range result.Rows { 75 | rows[i] = make([]Value, len(r)) 76 | totalLen := 0 77 | for _, c := range r { 78 | totalLen += len(c.val) 79 | } 80 | arena := make([]byte, 0, totalLen) 81 | for j, c := range r { 82 | start := len(arena) 83 | arena = append(arena, c.val...) 
84 | rows[i][j] = MakeTrusted(c.typ, arena[start:start+len(c.val)]) 85 | } 86 | } 87 | out.Rows = rows 88 | } 89 | return out 90 | } 91 | 92 | // StripFieldNames will return a new Result that has the same Rows, 93 | // but the Field objects will have their Name emptied. Note we don't 94 | // proto.Copy each Field for performance reasons, but we only copy the 95 | // individual fields. 96 | func (result *Result) StripFieldNames() *Result { 97 | if len(result.Fields) == 0 { 98 | return result 99 | } 100 | r := *result 101 | r.Fields = make([]*querypb.Field, len(result.Fields)) 102 | newFieldsArray := make([]querypb.Field, len(result.Fields)) 103 | for i, f := range result.Fields { 104 | r.Fields[i] = &newFieldsArray[i] 105 | newFieldsArray[i].Type = f.Type 106 | } 107 | return &r 108 | } 109 | 110 | // AppendResult will combine the Results Objects of one result 111 | // to another result.Note currently it doesn't handle cases like 112 | // if two results have different fields.We will enhance this function. 113 | func (result *Result) AppendResult(src *Result) { 114 | if src.RowsAffected == 0 && len(src.Fields) == 0 { 115 | return 116 | } 117 | if result.Fields == nil { 118 | result.Fields = src.Fields 119 | } 120 | result.RowsAffected += src.RowsAffected 121 | if src.InsertID != 0 { 122 | result.InsertID = src.InsertID 123 | } 124 | if len(src.Rows) != 0 { 125 | result.Rows = append(result.Rows, src.Rows...) 126 | } 127 | } 128 | -------------------------------------------------------------------------------- /sqlparser/depends/sqltypes/result_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2015, Google Inc. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 

package sqltypes

import (
	"reflect"
	"testing"

	querypb "github.com/xelabs/go-mysqlstack/sqlparser/depends/query"
)

// TestRepair checks that Repair retypes each column to the matching field's
// type.
func TestRepair(t *testing.T) {
	fields := []*querypb.Field{{
		Type: Int64,
	}, {
		Type: VarChar,
	}}
	in := Result{
		Rows: [][]Value{
			{testVal(VarBinary, "1"), testVal(VarBinary, "aa")},
			{testVal(VarBinary, "2"), testVal(VarBinary, "bb")},
		},
	}
	want := Result{
		Rows: [][]Value{
			{testVal(Int64, "1"), testVal(VarChar, "aa")},
			{testVal(Int64, "2"), testVal(VarChar, "bb")},
		},
	}
	in.Repair(fields)
	if !reflect.DeepEqual(in, want) {
		t.Errorf("Repair:\n%#v, want\n%#v", in, want)
	}
}

// TestCopy checks that Copy preserves NULL and empty values and that the
// copy is unaffected by later mutation of the source.
func TestCopy(t *testing.T) {
	in := &Result{
		Fields: []*querypb.Field{{
			Type: Int64,
		}, {
			Type: VarChar,
		}},
		InsertID:     1,
		RowsAffected: 2,
		Rows: [][]Value{
			{testVal(Int64, "1"), NULL},
			{testVal(Int64, "2"), MakeTrusted(VarChar, nil)},
			{testVal(Int64, "3"), testVal(VarChar, "")},
		},
	}
	want := &Result{
		Fields: []*querypb.Field{{
			Type: Int64,
		}, {
			Type: VarChar,
		}},
		InsertID:     1,
		RowsAffected: 2,
		Rows: [][]Value{
			{testVal(Int64, "1"), NULL},
			{testVal(Int64, "2"), testVal(VarChar, "")},
			{testVal(Int64, "3"), testVal(VarChar, "")},
		},
	}
	out := in.Copy()
	// Change in so we're sure out got actually copied
	in.Fields[0].Type = VarChar
	in.Rows[0][0] = testVal(VarChar, "aa")
	if !reflect.DeepEqual(out, want) {
		t.Errorf("Copy:\n%#v, want\n%#v", out, want)
	}
}

// TestStripFieldNames checks that names are cleared on fresh Field structs
// while the input result is left untouched.
func TestStripFieldNames(t *testing.T) {
	testcases := []struct {
		name     string
		in       *Result
		expected *Result
	}{{
		name:     "no fields",
		in:       &Result{},
		expected: &Result{},
	}, {
		name: "empty fields",
		in: &Result{
			Fields: []*querypb.Field{},
		},
		expected: &Result{
			Fields: []*querypb.Field{},
		},
	}, {
		name: "no name",
		in: &Result{
			Fields: []*querypb.Field{{
				Type: Int64,
			}, {
				Type: VarChar,
			}},
		},
		expected: &Result{
			Fields: []*querypb.Field{{
				Type: Int64,
			}, {
				Type: VarChar,
			}},
		},
	}, {
		name: "names",
		in: &Result{
			Fields: []*querypb.Field{{
				Name: "field1",
				Type: Int64,
			}, {
				Name: "field2",
				Type: VarChar,
			}},
		},
		expected: &Result{
			Fields: []*querypb.Field{{
				Type: Int64,
			}, {
				Type: VarChar,
			}},
		},
	}}
	for _, tcase := range testcases {
		inCopy := tcase.in.Copy()
		out := inCopy.StripFieldNames()
		if !reflect.DeepEqual(out, tcase.expected) {
			t.Errorf("StripFieldNames unexpected result for %v: %v", tcase.name, out)
		}
		if len(tcase.in.Fields) > 0 {
			// check the out array is different than the in array.
			if out.Fields[0] == inCopy.Fields[0] {
				t.Errorf("StripFieldNames modified original Field for %v", tcase.name)
			}
		}
		// check we didn't change the original result.
		if !reflect.DeepEqual(tcase.in, inCopy) {
			t.Errorf("StripFieldNames modified original result")
		}
	}
}

// TestAppendResult checks that AppendResult sums RowsAffected and
// concatenates rows.
func TestAppendResult(t *testing.T) {
	r1 := &Result{
		RowsAffected: 3,
		Rows: [][]Value{
			{testVal(VarBinary, "1"), testVal(VarBinary, "aa")},
		},
	}
	r2 := &Result{
		RowsAffected: 5,
		Rows: [][]Value{
			{testVal(VarBinary, "2"), testVal(VarBinary, "aa2")},
		},
	}
	r1.AppendResult(r2)

	got := r1
	want := &Result{
		RowsAffected: 8,
		Rows: [][]Value{
			{testVal(VarBinary, "1"), testVal(VarBinary, "aa")},
			{testVal(VarBinary, "2"), testVal(VarBinary, "aa2")},
		},
	}
	if !reflect.DeepEqual(got, want) {
		t.Errorf("Append:\n%#v, want\n%#v", got, want)
	}
}
--------------------------------------------------------------------------------
/sqlparser/depends/sqltypes/row.go:
--------------------------------------------------------------------------------
// Copyright 2015, Google Inc. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//
// Copyright (c) XeLabs
// BohuTANG

package sqltypes

// Row is a single result row, held as a slice of Values.
type Row []Value

// Copy returns a new slice holding the same Values as the row. The Value
// structs are copied one by one; any byte data they reference is shared.
14 | func (r Row) Copy() []Value { 15 | ret := make([]Value, len(r)) 16 | for i, v := range r { 17 | ret[i] = v 18 | } 19 | return ret 20 | } 21 | -------------------------------------------------------------------------------- /sqlparser/depends/sqltypes/time.go: -------------------------------------------------------------------------------- 1 | /* 2 | * go-mysqlstack 3 | * xelabs.org 4 | * 5 | * Copyright (c) XeLabs 6 | * GPL License 7 | */ 8 | 9 | package sqltypes 10 | 11 | import ( 12 | "errors" 13 | "strconv" 14 | "strings" 15 | 16 | querypb "github.com/xelabs/go-mysqlstack/sqlparser/depends/query" 17 | ) 18 | 19 | // timeToNumeric used to cast time type to numeric. 20 | func timeToNumeric(v Value) (numeric, error) { 21 | switch v.Type() { 22 | case querypb.Type_TIMESTAMP, querypb.Type_DATETIME: 23 | var i int64 24 | year, err := strconv.ParseInt(string(v.val[0:4]), 10, 16) 25 | if err != nil { 26 | return numeric{}, err 27 | } 28 | month, err := strconv.ParseInt(string(v.val[5:7]), 10, 8) 29 | if err != nil { 30 | return numeric{}, err 31 | } 32 | day, err := strconv.ParseInt(string(v.val[8:10]), 10, 8) 33 | if err != nil { 34 | return numeric{}, err 35 | } 36 | hour, err := strconv.ParseInt(string(v.val[11:13]), 10, 8) 37 | if err != nil { 38 | return numeric{}, err 39 | } 40 | minute, err := strconv.ParseInt(string(v.val[14:16]), 10, 8) 41 | if err != nil { 42 | return numeric{}, err 43 | } 44 | second, err := strconv.ParseInt(string(v.val[17:19]), 10, 8) 45 | if err != nil { 46 | return numeric{}, err 47 | } 48 | 49 | i = (year*10000+month*100+day)*1000000 + (hour*10000 + minute*100 + second) 50 | if len(v.val) > 19 { 51 | var f float64 52 | microSecond, err := strconv.ParseUint(string(v.val[20:]), 10, 32) 53 | if err != nil { 54 | return numeric{}, err 55 | } 56 | 57 | microSec := float64(microSecond) 58 | n := len(v.val[20:]) 59 | for n != 0 { 60 | microSec *= 0.1 61 | n-- 62 | } 63 | f = float64(i) + microSec 64 | return numeric{fval: f, typ: Float64}, 
nil 65 | } 66 | return numeric{ival: i, typ: Int64}, nil 67 | case querypb.Type_DATE: 68 | var i int64 69 | year, err := strconv.ParseInt(string(v.val[0:4]), 10, 16) 70 | if err != nil { 71 | return numeric{}, err 72 | } 73 | month, err := strconv.ParseInt(string(v.val[5:7]), 10, 8) 74 | if err != nil { 75 | return numeric{}, err 76 | } 77 | day, err := strconv.ParseInt(string(v.val[8:]), 10, 8) 78 | if err != nil { 79 | return numeric{}, err 80 | } 81 | i = year*10000 + month*100 + day 82 | return numeric{ival: i, typ: Int64}, nil 83 | case querypb.Type_TIME: 84 | var i int64 85 | sub := strings.Split(string(v.val), ":") 86 | if len(sub) != 3 { 87 | return numeric{}, errors.New("incorrect.time.value,':'.is.not.found") 88 | } 89 | 90 | pre := int64(1) 91 | if strings.HasPrefix(sub[0], "-") { 92 | pre = -1 93 | sub[0] = sub[0][1:] 94 | } 95 | 96 | hour, err := strconv.ParseInt(string(sub[0]), 10, 32) 97 | if err != nil { 98 | return numeric{}, err 99 | } 100 | minute, err := strconv.ParseInt(string(sub[1]), 10, 8) 101 | if err != nil { 102 | return numeric{}, err 103 | } 104 | 105 | if strings.Contains(sub[2], ".") { 106 | second, err := strconv.ParseFloat(sub[2], 64) 107 | if err != nil { 108 | return numeric{}, err 109 | } 110 | f := float64(pre) * (float64(hour)*10000 + float64(minute)*100 + second) 111 | return numeric{fval: f, typ: Float64}, nil 112 | } 113 | 114 | second, err := strconv.ParseInt(sub[2], 10, 8) 115 | if err != nil { 116 | return numeric{}, err 117 | } 118 | i = pre * (hour*10000 + minute*100 + second) 119 | return numeric{ival: i, typ: Int64}, nil 120 | case querypb.Type_YEAR: 121 | val, err := strconv.ParseUint(v.ToString(), 10, 16) 122 | if err != nil { 123 | return numeric{}, err 124 | } 125 | return numeric{uval: val, typ: Uint64}, nil 126 | } 127 | return numeric{}, errors.New("unsupport: unknown.type") 128 | } 129 | -------------------------------------------------------------------------------- /sqlparser/depends/sqltypes/time_test.go: 
-------------------------------------------------------------------------------- 1 | /* 2 | * go-mysqlstack 3 | * xelabs.org 4 | * 5 | * Copyright (c) XeLabs 6 | * GPL License 7 | */ 8 | 9 | package sqltypes 10 | 11 | import ( 12 | "reflect" 13 | "testing" 14 | 15 | "github.com/stretchr/testify/assert" 16 | ) 17 | 18 | func TesttimeToNumeric(t *testing.T) { 19 | testcases := []struct { 20 | in Value 21 | out interface{} 22 | }{ 23 | { 24 | in: testVal(Timestamp, "2012-02-24 23:19:43"), 25 | out: int64(20120224231943), 26 | }, 27 | { 28 | in: testVal(Timestamp, "2012-02-24 23:19:43.120"), 29 | out: float64(20120224231943.120), 30 | }, 31 | { 32 | in: testVal(Time, "-23:19:43.120"), 33 | out: float64(-231943.120), 34 | }, 35 | { 36 | in: testVal(Time, "-63:19:43"), 37 | out: int64(-631943), 38 | }, 39 | { 40 | in: testVal(Datetime, "0000-00-00 00:00:00"), 41 | out: int64(0), 42 | }, 43 | { 44 | in: testVal(Datetime, "2012-02-24 23:19:43.000012"), 45 | out: float64(20120224231943.000012), 46 | }, 47 | { 48 | in: testVal(Date, "0000-00-00"), 49 | out: int64(0), 50 | }, 51 | { 52 | in: testVal(Date, "2012-02-24"), 53 | out: int64(20120224), 54 | }, 55 | { 56 | in: testVal(Year, "2012"), 57 | out: uint64(2012), 58 | }, 59 | { 60 | in: testVal(Year, "12"), 61 | out: uint64(12), 62 | }, 63 | } 64 | 65 | for _, tcase := range testcases { 66 | got, err := timeToNumeric(tcase.in) 67 | assert.Nil(t, err) 68 | 69 | var v interface{} 70 | switch got.typ { 71 | case Uint64: 72 | v = got.uval 73 | case Float64: 74 | v = got.fval 75 | case Int64: 76 | v = got.ival 77 | } 78 | 79 | if !reflect.DeepEqual(v, tcase.out) { 80 | t.Errorf("%v.ToNative = %#v, want %#v", makePretty(tcase.in), v, tcase.out) 81 | } 82 | } 83 | } 84 | -------------------------------------------------------------------------------- /sqlparser/encodable.go: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2017 Google Inc. 
3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | */ 16 | 17 | package sqlparser 18 | 19 | import ( 20 | "strings" 21 | 22 | "github.com/xelabs/go-mysqlstack/sqlparser/depends/sqltypes" 23 | ) 24 | 25 | // This file contains types that are 'Encodable'. 26 | 27 | // Encodable defines the interface for types that can 28 | // be custom-encoded into SQL. 29 | type Encodable interface { 30 | EncodeSQL(buf *strings.Builder) 31 | } 32 | 33 | // InsertValues is a custom SQL encoder for the values of 34 | // an insert statement. 35 | type InsertValues [][]sqltypes.Value 36 | 37 | // EncodeSQL performs the SQL encoding for InsertValues. 38 | func (iv InsertValues) EncodeSQL(buf *strings.Builder) { 39 | for i, rows := range iv { 40 | if i != 0 { 41 | buf.WriteString(", ") 42 | } 43 | buf.WriteByte('(') 44 | for j, bv := range rows { 45 | if j != 0 { 46 | buf.WriteString(", ") 47 | } 48 | bv.EncodeSQL(buf) 49 | } 50 | buf.WriteByte(')') 51 | } 52 | } 53 | 54 | // TupleEqualityList is for generating equality constraints 55 | // for tables that have composite primary keys. 56 | type TupleEqualityList struct { 57 | Columns []ColIdent 58 | Rows [][]sqltypes.Value 59 | } 60 | 61 | // EncodeSQL generates the where clause constraints for the tuple 62 | // equality. 
63 | func (tpl *TupleEqualityList) EncodeSQL(buf *strings.Builder) { 64 | if len(tpl.Columns) == 1 { 65 | tpl.encodeAsIn(buf) 66 | return 67 | } 68 | tpl.encodeAsEquality(buf) 69 | } 70 | 71 | func (tpl *TupleEqualityList) encodeAsIn(buf *strings.Builder) { 72 | Append(buf, tpl.Columns[0]) 73 | buf.WriteString(" in (") 74 | for i, r := range tpl.Rows { 75 | if i != 0 { 76 | buf.WriteString(", ") 77 | } 78 | r[0].EncodeSQL(buf) 79 | } 80 | buf.WriteByte(')') 81 | } 82 | 83 | func (tpl *TupleEqualityList) encodeAsEquality(buf *strings.Builder) { 84 | for i, r := range tpl.Rows { 85 | if i != 0 { 86 | buf.WriteString(" or ") 87 | } 88 | buf.WriteString("(") 89 | for j, c := range tpl.Columns { 90 | if j != 0 { 91 | buf.WriteString(" and ") 92 | } 93 | Append(buf, c) 94 | buf.WriteString(" = ") 95 | r[j].EncodeSQL(buf) 96 | } 97 | buf.WriteByte(')') 98 | } 99 | } 100 | -------------------------------------------------------------------------------- /sqlparser/encodable_test.go: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2017 Google Inc. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 
15 | */ 16 | 17 | package sqlparser 18 | 19 | import ( 20 | "strings" 21 | "testing" 22 | 23 | "github.com/xelabs/go-mysqlstack/sqlparser/depends/sqltypes" 24 | ) 25 | 26 | func TestEncodable(t *testing.T) { 27 | tcases := []struct { 28 | in Encodable 29 | out string 30 | }{{ 31 | in: InsertValues{{ 32 | sqltypes.NewInt64(1), 33 | sqltypes.NewVarBinary("foo('a')"), 34 | }, { 35 | sqltypes.NewInt64(2), 36 | sqltypes.NewVarBinary("bar(`b`)"), 37 | }}, 38 | out: "(1, 'foo(\\'a\\')'), (2, 'bar(`b`)')", 39 | }, { 40 | // Single column. 41 | in: &TupleEqualityList{ 42 | Columns: []ColIdent{NewColIdent("pk")}, 43 | Rows: [][]sqltypes.Value{ 44 | {sqltypes.NewInt64(1)}, 45 | {sqltypes.NewVarBinary("aa")}, 46 | }, 47 | }, 48 | out: "pk in (1, 'aa')", 49 | }, { 50 | // Multiple columns. 51 | in: &TupleEqualityList{ 52 | Columns: []ColIdent{NewColIdent("pk1"), NewColIdent("pk2")}, 53 | Rows: [][]sqltypes.Value{ 54 | { 55 | sqltypes.NewInt64(1), 56 | sqltypes.NewVarBinary("aa"), 57 | }, 58 | { 59 | sqltypes.NewInt64(2), 60 | sqltypes.NewVarBinary("bb"), 61 | }, 62 | }, 63 | }, 64 | out: "(pk1 = 1 and pk2 = 'aa') or (pk1 = 2 and pk2 = 'bb')", 65 | }} 66 | for _, tcase := range tcases { 67 | buf := new(strings.Builder) 68 | tcase.in.EncodeSQL(buf) 69 | if out := buf.String(); out != tcase.out { 70 | t.Errorf("EncodeSQL(%v): %s, want %s", tcase.in, out, tcase.out) 71 | } 72 | } 73 | } 74 | -------------------------------------------------------------------------------- /sqlparser/explain_test.go: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2017 Google Inc. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 
6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | */ 16 | 17 | package sqlparser 18 | 19 | import "strings" 20 | import "testing" 21 | 22 | func TestExplain(t *testing.T) { 23 | validSQL := []struct { 24 | input string 25 | output string 26 | }{ 27 | { 28 | input: "explain select * from 1", 29 | output: "explain", 30 | }, 31 | } 32 | 33 | for _, exp := range validSQL { 34 | sql := strings.TrimSpace(exp.input) 35 | tree, err := Parse(sql) 36 | if err != nil { 37 | t.Errorf("input: %s, err: %v", sql, err) 38 | continue 39 | } 40 | 41 | // Walk. 42 | Walk(func(node SQLNode) (bool, error) { 43 | return true, nil 44 | }, tree) 45 | 46 | got := String(tree.(*Explain)) 47 | if exp.output != got { 48 | t.Errorf("want:\n%s\ngot:\n%s", exp.output, got) 49 | } 50 | } 51 | } 52 | -------------------------------------------------------------------------------- /sqlparser/impossible_query.go: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2017 Google Inc. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreedto in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 
15 | */ 16 | 17 | package sqlparser 18 | 19 | // FormatImpossibleQuery creates an impossible query in a TrackedBuffer. 20 | // An impossible query is a modified version of a query where all selects have where clauses that are 21 | // impossible for mysql to resolve. This is used in the vtgate and vttablet: 22 | // 23 | // - In the vtgate it's used for joins: if the first query returns no result, then vtgate uses the impossible 24 | // query just to fetch field info from vttablet 25 | // - In the vttablet, it's just an optimization: the field info is fetched once form MySQL, cached and reused 26 | // for subsequent queries 27 | func FormatImpossibleQuery(buf *TrackedBuffer, node SQLNode) { 28 | switch node := node.(type) { 29 | case *Select: 30 | buf.Myprintf("select %v from %v where 1 != 1", node.SelectExprs, node.From) 31 | if node.GroupBy != nil { 32 | node.GroupBy.Format(buf) 33 | } 34 | case *Union: 35 | buf.Myprintf("%v %s %v", node.Left, node.Type, node.Right) 36 | default: 37 | node.Format(buf) 38 | } 39 | } 40 | -------------------------------------------------------------------------------- /sqlparser/impossible_query_test.go: -------------------------------------------------------------------------------- 1 | /* 2 | * Radon 3 | * 4 | * Copyright 2019 The Radon Authors. 5 | * Code is licensed under the GPLv3. 
6 | * 7 | */ 8 | 9 | package sqlparser 10 | 11 | import ( 12 | "testing" 13 | 14 | "github.com/stretchr/testify/assert" 15 | ) 16 | 17 | func TestFormatImpossibleQuery(t *testing.T) { 18 | querys := []string{"select a,b from A where A.id>1 group by a order by a limit 1", 19 | "select id,a from A union select name,a from B order by a", 20 | "insert into A(a,b) values(1,'a')"} 21 | wants := []string{"select a, b from A where 1 != 1 group by a", 22 | "select id, a from A union select name, a from B", 23 | "insert into A(a, b) values (1, 'a')"} 24 | for i, query := range querys { 25 | node, err := Parse(query) 26 | assert.Nil(t, err) 27 | buf := NewTrackedBuffer(nil) 28 | FormatImpossibleQuery(buf, node) 29 | got := buf.String() 30 | assert.Equal(t, wants[i], got) 31 | } 32 | } 33 | -------------------------------------------------------------------------------- /sqlparser/kill_test.go: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2017 Google Inc. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 
15 | */ 16 | 17 | package sqlparser 18 | 19 | import ( 20 | "strings" 21 | "testing" 22 | ) 23 | 24 | func TestKill(t *testing.T) { 25 | validSQL := []struct { 26 | input string 27 | output string 28 | }{ 29 | { 30 | input: "kill 1", 31 | output: "kill 1", 32 | }, 33 | 34 | { 35 | input: "kill 10000000000000000000000000000000", 36 | output: "kill 10000000000000000000000000000000", 37 | }, 38 | 39 | { 40 | input: "kill query 1", 41 | output: "kill 1", 42 | }, 43 | } 44 | 45 | for _, exp := range validSQL { 46 | sql := strings.TrimSpace(exp.input) 47 | tree, err := Parse(sql) 48 | if err != nil { 49 | t.Errorf("input: %s, err: %v", sql, err) 50 | continue 51 | } 52 | 53 | // Walk. 54 | Walk(func(node SQLNode) (bool, error) { 55 | return true, nil 56 | }, tree) 57 | 58 | node := tree.(*Kill) 59 | node.QueryID.AsUint64() 60 | 61 | // Format. 62 | got := String(node) 63 | if exp.output != got { 64 | t.Errorf("want:\n%s\ngot:\n%s", exp.output, got) 65 | } 66 | } 67 | } 68 | -------------------------------------------------------------------------------- /sqlparser/normalizer.go: -------------------------------------------------------------------------------- 1 | // Copyright 2016, Google Inc. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | 5 | package sqlparser 6 | 7 | import ( 8 | "fmt" 9 | 10 | querypb "github.com/xelabs/go-mysqlstack/sqlparser/depends/query" 11 | "github.com/xelabs/go-mysqlstack/sqlparser/depends/sqltypes" 12 | ) 13 | 14 | // Normalize changes the statement to use bind values, and 15 | // updates the bind vars to those values. The supplied prefix 16 | // is used to generate the bind var names. The function ensures 17 | // that there are no collisions with existing bind vars. 
18 | func Normalize(stmt Statement, bindVars map[string]*querypb.BindVariable, prefix string) { 19 | reserved := GetBindvars(stmt) 20 | // vals allows us to reuse bindvars for 21 | // identical values. 22 | counter := 1 23 | vals := make(map[string]string) 24 | _ = Walk(func(node SQLNode) (kontinue bool, err error) { 25 | switch node := node.(type) { 26 | case *SQLVal: 27 | // Make the bindvar 28 | bval := sqlToBindvar(node) 29 | if bval == nil { 30 | // If unsuccessful continue. 31 | return true, nil 32 | } 33 | // Check if there's a bindvar for that value already. 34 | var key string 35 | if bval.Type == sqltypes.VarBinary { 36 | // Prefixing strings with "'" ensures that a string 37 | // and number that have the same representation don't 38 | // collide. 39 | key = "'" + string(node.Val) 40 | } else { 41 | key = string(node.Val) 42 | } 43 | bvname, ok := vals[key] 44 | if !ok { 45 | // If there's no such bindvar, make a new one. 46 | bvname, counter = newName(prefix, counter, reserved) 47 | vals[key] = bvname 48 | bindVars[bvname] = bval 49 | } 50 | // Modify the AST node to a bindvar. 51 | node.Type = ValArg 52 | node.Val = append([]byte(":"), bvname...) 53 | case *ComparisonExpr: 54 | switch node.Operator { 55 | case InStr, NotInStr: 56 | default: 57 | return true, nil 58 | } 59 | // It's either IN or NOT IN. 60 | tupleVals, ok := node.Right.(ValTuple) 61 | if !ok { 62 | return true, nil 63 | } 64 | // The RHS is a tuple of values. 65 | // Make a list bindvar. 66 | bvals := &querypb.BindVariable{ 67 | Type: sqltypes.Tuple, 68 | } 69 | for _, val := range tupleVals { 70 | bval := sqlToBindvar(val) 71 | if bval == nil { 72 | return true, nil 73 | } 74 | bvals.Values = append(bvals.Values, &querypb.Value{ 75 | Type: bval.Type, 76 | Value: bval.Value, 77 | }) 78 | } 79 | var bvname string 80 | bvname, counter = newName(prefix, counter, reserved) 81 | bindVars[bvname] = bvals 82 | // Modify RHS to be a list bindvar. 
83 | node.Right = ListArg(append([]byte("::"), bvname...)) 84 | } 85 | return true, nil 86 | }, stmt) 87 | } 88 | 89 | func sqlToBindvar(node SQLNode) *querypb.BindVariable { 90 | if node, ok := node.(*SQLVal); ok { 91 | switch node.Type { 92 | case StrVal: 93 | return &querypb.BindVariable{Type: sqltypes.VarBinary, Value: node.Val} 94 | case IntVal: 95 | return &querypb.BindVariable{Type: sqltypes.Int64, Value: node.Val} 96 | case FloatVal: 97 | return &querypb.BindVariable{Type: sqltypes.Float64, Value: node.Val} 98 | } 99 | } 100 | return nil 101 | } 102 | 103 | func newName(prefix string, counter int, reserved map[string]struct{}) (string, int) { 104 | for { 105 | newName := fmt.Sprintf("%s%d", prefix, counter) 106 | if _, ok := reserved[newName]; !ok { 107 | reserved[newName] = struct{}{} 108 | return newName, counter + 1 109 | } 110 | counter++ 111 | } 112 | } 113 | 114 | // GetBindvars returns a map of the bind vars referenced in the statement. 115 | // TODO(sougou); This function gets called again from vtgate/planbuilder. 116 | // Ideally, this should be done only once. 117 | func GetBindvars(stmt Statement) map[string]struct{} { 118 | bindvars := make(map[string]struct{}) 119 | _ = Walk(func(node SQLNode) (kontinue bool, err error) { 120 | switch node := node.(type) { 121 | case *SQLVal: 122 | if node.Type == ValArg { 123 | bindvars[string(node.Val[1:])] = struct{}{} 124 | } 125 | case ListArg: 126 | bindvars[string(node[2:])] = struct{}{} 127 | } 128 | return true, nil 129 | }, stmt) 130 | return bindvars 131 | } 132 | -------------------------------------------------------------------------------- /sqlparser/normalizer_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2016, Google Inc. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 
4 | 5 | package sqlparser 6 | 7 | import ( 8 | "reflect" 9 | "testing" 10 | 11 | querypb "github.com/xelabs/go-mysqlstack/sqlparser/depends/query" 12 | "github.com/xelabs/go-mysqlstack/sqlparser/depends/sqltypes" 13 | ) 14 | 15 | func TestNormalize(t *testing.T) { 16 | prefix := "bv" 17 | testcases := []struct { 18 | in string 19 | outstmt string 20 | outbv map[string]*querypb.BindVariable 21 | }{{ 22 | // str val 23 | in: "select * from t where v1 = 'aa'", 24 | outstmt: "select * from t where v1 = :bv1", 25 | outbv: map[string]*querypb.BindVariable{ 26 | "bv1": &querypb.BindVariable{ 27 | Type: sqltypes.VarBinary, 28 | Value: []byte("aa"), 29 | }, 30 | }, 31 | }, { 32 | // int val 33 | in: "select * from t where v1 = 1", 34 | outstmt: "select * from t where v1 = :bv1", 35 | outbv: map[string]*querypb.BindVariable{ 36 | "bv1": &querypb.BindVariable{ 37 | Type: sqltypes.Int64, 38 | Value: []byte("1"), 39 | }, 40 | }, 41 | }, { 42 | // float val 43 | in: "select * from t where v1 = 1.2", 44 | outstmt: "select * from t where v1 = :bv1", 45 | outbv: map[string]*querypb.BindVariable{ 46 | "bv1": &querypb.BindVariable{ 47 | Type: sqltypes.Float64, 48 | Value: []byte("1.2"), 49 | }, 50 | }, 51 | }, { 52 | // multiple vals 53 | in: "select * from t where v1 = 1.2 and v2 = 2", 54 | outstmt: "select * from t where v1 = :bv1 and v2 = :bv2", 55 | outbv: map[string]*querypb.BindVariable{ 56 | "bv1": &querypb.BindVariable{ 57 | Type: sqltypes.Float64, 58 | Value: []byte("1.2"), 59 | }, 60 | "bv2": &querypb.BindVariable{ 61 | Type: sqltypes.Int64, 62 | Value: []byte("2"), 63 | }, 64 | }, 65 | }, { 66 | // bv collision 67 | in: "select * from t where v1 = :bv1 and v2 = 1", 68 | outstmt: "select * from t where v1 = :bv1 and v2 = :bv2", 69 | outbv: map[string]*querypb.BindVariable{ 70 | "bv2": &querypb.BindVariable{ 71 | Type: sqltypes.Int64, 72 | Value: []byte("1"), 73 | }, 74 | }, 75 | }, { 76 | // val reuse 77 | in: "select * from t where v1 = 1 and v2 = 1", 78 | outstmt: "select 
* from t where v1 = :bv1 and v2 = :bv1", 79 | outbv: map[string]*querypb.BindVariable{ 80 | "bv1": &querypb.BindVariable{ 81 | Type: sqltypes.Int64, 82 | Value: []byte("1"), 83 | }, 84 | }, 85 | }, { 86 | // ints and strings are different 87 | in: "select * from t where v1 = 1 and v2 = '1'", 88 | outstmt: "select * from t where v1 = :bv1 and v2 = :bv2", 89 | outbv: map[string]*querypb.BindVariable{ 90 | "bv1": &querypb.BindVariable{ 91 | Type: sqltypes.Int64, 92 | Value: []byte("1"), 93 | }, 94 | "bv2": &querypb.BindVariable{ 95 | Type: sqltypes.VarBinary, 96 | Value: []byte("1"), 97 | }, 98 | }, 99 | }, { 100 | // comparison with no vals 101 | in: "select * from t where v1 = v2", 102 | outstmt: "select * from t where v1 = v2", 103 | outbv: map[string]*querypb.BindVariable{}, 104 | }, { 105 | // IN clause with existing bv 106 | in: "select * from t where v1 in ::list", 107 | outstmt: "select * from t where v1 in ::list", 108 | outbv: map[string]*querypb.BindVariable{}, 109 | }, { 110 | // IN clause with non-val values 111 | in: "select * from t where v1 in (1, a)", 112 | outstmt: "select * from t where v1 in (:bv1, a)", 113 | outbv: map[string]*querypb.BindVariable{ 114 | "bv1": &querypb.BindVariable{ 115 | Type: sqltypes.Int64, 116 | Value: []byte("1"), 117 | }, 118 | }, 119 | }, { 120 | // IN clause with vals 121 | in: "select * from t where v1 in (1, '2')", 122 | outstmt: "select * from t where v1 in ::bv1", 123 | outbv: map[string]*querypb.BindVariable{ 124 | "bv1": &querypb.BindVariable{ 125 | Type: sqltypes.Tuple, 126 | Values: []*querypb.Value{{ 127 | Type: sqltypes.Int64, 128 | Value: []byte("1"), 129 | }, { 130 | Type: sqltypes.VarBinary, 131 | Value: []byte("2"), 132 | }}, 133 | }, 134 | }, 135 | }, { 136 | // NOT IN clause 137 | in: "select * from t where v1 not in (1, '2')", 138 | outstmt: "select * from t where v1 not in ::bv1", 139 | outbv: map[string]*querypb.BindVariable{ 140 | "bv1": &querypb.BindVariable{ 141 | Type: sqltypes.Tuple, 142 | Values: 
[]*querypb.Value{{ 143 | Type: sqltypes.Int64, 144 | Value: []byte("1"), 145 | }, { 146 | Type: sqltypes.VarBinary, 147 | Value: []byte("2"), 148 | }}, 149 | }, 150 | }, 151 | }} 152 | for _, tc := range testcases { 153 | stmt, err := Parse(tc.in) 154 | if err != nil { 155 | t.Error(err) 156 | continue 157 | } 158 | bv := make(map[string]*querypb.BindVariable) 159 | Normalize(stmt, bv, prefix) 160 | outstmt := String(stmt) 161 | if outstmt != tc.outstmt { 162 | t.Errorf("Query:\n%s:\n%s, want\n%s", tc.in, outstmt, tc.outstmt) 163 | } 164 | if !reflect.DeepEqual(tc.outbv, bv) { 165 | t.Errorf("Query:\n%s:\n%v, want\n%v", tc.in, bv, tc.outbv) 166 | } 167 | } 168 | } 169 | 170 | func TestGetBindVars(t *testing.T) { 171 | stmt, err := Parse("select * from t where :v1 = :v2 and :v2 = :v3 and :v4 in ::v5") 172 | if err != nil { 173 | t.Fatal(err) 174 | } 175 | got := GetBindvars(stmt) 176 | want := map[string]struct{}{ 177 | "v1": {}, 178 | "v2": {}, 179 | "v3": {}, 180 | "v4": {}, 181 | "v5": {}, 182 | } 183 | if !reflect.DeepEqual(got, want) { 184 | t.Errorf("GetBindVars: %v, want %v", got, want) 185 | } 186 | } 187 | -------------------------------------------------------------------------------- /sqlparser/parsed_query.go: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2017 Google Inc. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 
15 | */ 16 | 17 | package sqlparser 18 | 19 | import ( 20 | "encoding/json" 21 | "fmt" 22 | "strings" 23 | 24 | querypb "github.com/xelabs/go-mysqlstack/sqlparser/depends/query" 25 | "github.com/xelabs/go-mysqlstack/sqlparser/depends/sqltypes" 26 | ) 27 | 28 | // ParsedQuery represents a parsed query where 29 | // bind locations are precompued for fast substitutions. 30 | type ParsedQuery struct { 31 | Query string 32 | bindLocations []bindLocation 33 | } 34 | 35 | type bindLocation struct { 36 | offset, length int 37 | } 38 | 39 | // NewParsedQuery returns a ParsedQuery of the ast. 40 | func NewParsedQuery(node SQLNode) *ParsedQuery { 41 | buf := NewTrackedBuffer(nil) 42 | buf.Myprintf("%v", node) 43 | return buf.ParsedQuery() 44 | } 45 | 46 | // GenerateQuery generates a query by substituting the specified 47 | // bindVariables. The extras parameter specifies special parameters 48 | // that can perform custom encoding. 49 | func (pq *ParsedQuery) GenerateQuery(bindVariables map[string]*querypb.BindVariable, extras map[string]Encodable) (string, error) { 50 | if len(pq.bindLocations) == 0 { 51 | return pq.Query, nil 52 | } 53 | var buf strings.Builder 54 | buf.Grow(len(pq.Query)) 55 | current := 0 56 | for _, loc := range pq.bindLocations { 57 | buf.WriteString(pq.Query[current:loc.offset]) 58 | name := pq.Query[loc.offset : loc.offset+loc.length] 59 | if encodable, ok := extras[name[1:]]; ok { 60 | encodable.EncodeSQL(&buf) 61 | } else { 62 | supplied, _, err := FetchBindVar(name, bindVariables) 63 | if err != nil { 64 | return "", err 65 | } 66 | EncodeValue(&buf, supplied) 67 | } 68 | current = loc.offset + loc.length 69 | } 70 | buf.WriteString(pq.Query[current:]) 71 | return buf.String(), nil 72 | } 73 | 74 | // MarshalJSON is a custom JSON marshaler for ParsedQuery. 75 | // Note that any queries longer that 512 bytes will be truncated. 
76 | func (pq *ParsedQuery) MarshalJSON() ([]byte, error) { 77 | return json.Marshal(pq.Query) 78 | } 79 | 80 | // EncodeValue encodes one bind variable value into the query. 81 | func EncodeValue(buf *strings.Builder, value *querypb.BindVariable) { 82 | if value.Type != querypb.Type_TUPLE { 83 | // Since we already check for TUPLE, we don't expect an error. 84 | v, _ := sqltypes.BindVariableToValue(value) 85 | v.EncodeSQL(buf) 86 | return 87 | } 88 | 89 | // It's a TUPLE. 90 | buf.WriteByte('(') 91 | for i, bv := range value.Values { 92 | if i != 0 { 93 | buf.WriteString(", ") 94 | } 95 | sqltypes.ProtoToValue(bv).EncodeSQL(buf) 96 | } 97 | buf.WriteByte(')') 98 | } 99 | 100 | // FetchBindVar resolves the bind variable by fetching it from bindVariables. 101 | func FetchBindVar(name string, bindVariables map[string]*querypb.BindVariable) (val *querypb.BindVariable, isList bool, err error) { 102 | name = name[1:] 103 | if name[0] == ':' { 104 | name = name[1:] 105 | isList = true 106 | } 107 | supplied, ok := bindVariables[name] 108 | if !ok { 109 | return nil, false, fmt.Errorf("missing bind var %s", name) 110 | } 111 | 112 | if isList { 113 | if supplied.Type != querypb.Type_TUPLE { 114 | return nil, false, fmt.Errorf("unexpected list arg type (%v) for key %s", supplied.Type, name) 115 | } 116 | if len(supplied.Values) == 0 { 117 | return nil, false, fmt.Errorf("empty list supplied for %s", name) 118 | } 119 | return supplied, true, nil 120 | } 121 | 122 | if supplied.Type == querypb.Type_TUPLE { 123 | return nil, false, fmt.Errorf("unexpected arg type (TUPLE) for non-list key %s", name) 124 | } 125 | 126 | return supplied, false, nil 127 | } 128 | -------------------------------------------------------------------------------- /sqlparser/parsed_query_test.go: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2017 Google Inc. 
3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | */ 16 | 17 | package sqlparser 18 | 19 | import ( 20 | "reflect" 21 | "testing" 22 | 23 | querypb "github.com/xelabs/go-mysqlstack/sqlparser/depends/query" 24 | "github.com/xelabs/go-mysqlstack/sqlparser/depends/sqltypes" 25 | ) 26 | 27 | func TestNewParsedQuery(t *testing.T) { 28 | stmt, err := Parse("select * from a where id =:id") 29 | if err != nil { 30 | t.Error(err) 31 | return 32 | } 33 | pq := NewParsedQuery(stmt) 34 | want := &ParsedQuery{ 35 | Query: "select * from a where id = :id", 36 | bindLocations: []bindLocation{{offset: 27, length: 3}}, 37 | } 38 | if !reflect.DeepEqual(pq, want) { 39 | t.Errorf("GenerateParsedQuery: %+v, want %+v", pq, want) 40 | } 41 | } 42 | 43 | func TestGenerateQuery(t *testing.T) { 44 | tcases := []struct { 45 | desc string 46 | query string 47 | bindVars map[string]*querypb.BindVariable 48 | extras map[string]Encodable 49 | output string 50 | }{ 51 | { 52 | desc: "no substitutions", 53 | query: "select * from a where id = 2", 54 | bindVars: map[string]*querypb.BindVariable{ 55 | "id": sqltypes.Int64BindVariable(1), 56 | }, 57 | output: "select * from a where id = 2", 58 | }, { 59 | desc: "missing bind var", 60 | query: "select * from a where id1 = :id1 and id2 = :id2", 61 | bindVars: map[string]*querypb.BindVariable{ 62 | "id1": sqltypes.Int64BindVariable(1), 63 | }, 64 | output: "missing bind var id2", 65 | }, { 66 | desc: "simple bindvar substitution", 67 | 
query: "select * from a where id1 = :id1 and id2 = :id2", 68 | bindVars: map[string]*querypb.BindVariable{ 69 | "id1": sqltypes.Int64BindVariable(1), 70 | "id2": sqltypes.NullBindVariable, 71 | }, 72 | output: "select * from a where id1 = 1 and id2 = null", 73 | }, { 74 | desc: "non-list bind var supplied", 75 | query: "select * from a where id in ::vals", 76 | bindVars: map[string]*querypb.BindVariable{ 77 | "vals": sqltypes.Int64BindVariable(1), 78 | }, 79 | output: "unexpected list arg type (INT64) for key vals", 80 | }, { 81 | desc: "single column tuple equality", 82 | query: "select * from a where b = :equality", 83 | extras: map[string]Encodable{ 84 | "equality": &TupleEqualityList{ 85 | Columns: []ColIdent{NewColIdent("pk")}, 86 | Rows: [][]sqltypes.Value{ 87 | {sqltypes.NewInt64(1)}, 88 | {sqltypes.NewVarBinary("aa")}, 89 | }, 90 | }, 91 | }, 92 | output: "select * from a where b = pk in (1, 'aa')", 93 | }, { 94 | desc: "multi column tuple equality", 95 | query: "select * from a where b = :equality", 96 | extras: map[string]Encodable{ 97 | "equality": &TupleEqualityList{ 98 | Columns: []ColIdent{NewColIdent("pk1"), NewColIdent("pk2")}, 99 | Rows: [][]sqltypes.Value{ 100 | { 101 | sqltypes.NewInt64(1), 102 | sqltypes.NewVarBinary("aa"), 103 | }, 104 | { 105 | sqltypes.NewInt64(2), 106 | sqltypes.NewVarBinary("bb"), 107 | }, 108 | }, 109 | }, 110 | }, 111 | output: "select * from a where b = (pk1 = 1 and pk2 = 'aa') or (pk1 = 2 and pk2 = 'bb')", 112 | }, 113 | } 114 | 115 | for _, tcase := range tcases { 116 | tree, err := Parse(tcase.query) 117 | if err != nil { 118 | t.Errorf("parse failed for %s: %v", tcase.desc, err) 119 | continue 120 | } 121 | buf := NewTrackedBuffer(nil) 122 | buf.Myprintf("%v", tree) 123 | pq := buf.ParsedQuery() 124 | bytes, err := pq.GenerateQuery(tcase.bindVars, tcase.extras) 125 | var got string 126 | if err != nil { 127 | got = err.Error() 128 | } else { 129 | got = string(bytes) 130 | } 131 | if got != tcase.output { 132 | 
t.Errorf("for test case: %s, got: '%s', want '%s'", tcase.desc, got, tcase.output) 133 | } 134 | } 135 | } 136 | -------------------------------------------------------------------------------- /sqlparser/parser.go: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2019 The Vitess Authors. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | */ 16 | 17 | package sqlparser 18 | 19 | import ( 20 | "errors" 21 | "sync" 22 | ) 23 | 24 | // parserPool is a pool for parser objects. 25 | var parserPool = sync.Pool{} 26 | 27 | // zeroParser is a zero-initialized parser to help reinitialize the parser for pooling. 28 | var zeroParser = *(yyNewParser().(*yyParserImpl)) 29 | 30 | // yyParsePooled is a wrapper around yyParse that pools the parser objects. There isn't a 31 | // particularly good reason to use yyParse directly, since it immediately discards its parser. What 32 | // would be ideal down the line is to actually pool the stacks themselves rather than the parser 33 | // objects, as per https://github.com/cznic/goyacc/blob/master/main.go. However, absent an upstream 34 | // change to goyacc, this is the next best option. 35 | // 36 | // N.B: Parser pooling means that you CANNOT take references directly to parse stack variables (e.g. 37 | // $$ = &$4) in sql.y rules. 
You must instead add an intermediate reference like so: 38 | // showCollationFilterOpt := $4 39 | // $$ = &Show{Type: string($2), ShowCollationFilterOpt: &showCollationFilterOpt} 40 | func yyParsePooled(yylex yyLexer) int { 41 | // Being very particular about using the base type and not an interface type b/c we depend on 42 | // the implementation to know how to reinitialize the parser. 43 | var parser *yyParserImpl 44 | 45 | i := parserPool.Get() 46 | if i != nil { 47 | parser = i.(*yyParserImpl) 48 | } else { 49 | parser = yyNewParser().(*yyParserImpl) 50 | } 51 | 52 | defer func() { 53 | *parser = zeroParser 54 | parserPool.Put(parser) 55 | }() 56 | return parser.Parse(yylex) 57 | } 58 | 59 | // Instructions for creating new types: If a type 60 | // needs to satisfy an interface, declare that function 61 | // along with that interface. This will help users 62 | // identify the list of types to which they can assert 63 | // those interfaces. 64 | // If the member of a type has a string with a predefined 65 | // list of values, declare those values as const following 66 | // the type. 67 | // For interfaces that define dummy functions to consolidate 68 | // a set of types, define the function as iTypeName. 69 | // This will help avoid name collisions. 70 | 71 | // Parse parses the sql and returns a Statement, which 72 | // is the AST representation of the query. If a DDL statement 73 | // is partially parsed but still contains a syntax error, the 74 | // error is ignored and the DDL is returned anyway. 75 | func Parse(sql string) (Statement, error) { 76 | tokenizer := NewStringTokenizer(sql) 77 | if yyParse(tokenizer) != 0 { 78 | return nil, errors.New(tokenizer.LastError) 79 | } 80 | return tokenizer.ParseTree, nil 81 | } 82 | 83 | // ParseStrictDDL is the same as Parse except it errors on 84 | // partially parsed DDL statements. 
85 | func ParseStrictDDL(sql string) (Statement, error) { 86 | tokenizer := NewStringTokenizer(sql) 87 | if yyParsePooled(tokenizer) != 0 { 88 | return nil, errors.New(tokenizer.LastError) 89 | } 90 | if tokenizer.ParseTree == nil { 91 | return nil, nil 92 | } 93 | return tokenizer.ParseTree, nil 94 | } 95 | 96 | // String returns a string representation of an SQLNode. 97 | func String(node SQLNode) string { 98 | buf := NewTrackedBuffer(nil) 99 | buf.Myprintf("%v", node) 100 | return buf.String() 101 | } 102 | -------------------------------------------------------------------------------- /sqlparser/precedence_test.go: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2017 Google Inc. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 
15 | */ 16 | 17 | package sqlparser 18 | 19 | import ( 20 | "fmt" 21 | "testing" 22 | ) 23 | 24 | func readable(node Expr) string { 25 | switch node := node.(type) { 26 | case *OrExpr: 27 | return fmt.Sprintf("(%s or %s)", readable(node.Left), readable(node.Right)) 28 | case *AndExpr: 29 | return fmt.Sprintf("(%s and %s)", readable(node.Left), readable(node.Right)) 30 | case *BinaryExpr: 31 | return fmt.Sprintf("(%s %s %s)", readable(node.Left), node.Operator, readable(node.Right)) 32 | case *IsExpr: 33 | return fmt.Sprintf("(%s %s)", readable(node.Expr), node.Operator) 34 | default: 35 | return String(node) 36 | } 37 | } 38 | 39 | func TestAndOrPrecedence(t *testing.T) { 40 | validSQL := []struct { 41 | input string 42 | output string 43 | }{{ 44 | input: "select * from a where a=b and c=d or e=f", 45 | output: "((a = b and c = d) or e = f)", 46 | }, { 47 | input: "select * from a where a=b or c=d and e=f", 48 | output: "(a = b or (c = d and e = f))", 49 | }} 50 | for _, tcase := range validSQL { 51 | tree, err := Parse(tcase.input) 52 | if err != nil { 53 | t.Error(err) 54 | continue 55 | } 56 | expr := readable(tree.(*Select).Where.Expr) 57 | if expr != tcase.output { 58 | t.Errorf("Parse: \n%s, want: \n%s", expr, tcase.output) 59 | } 60 | } 61 | } 62 | 63 | func TestPlusStarPrecedence(t *testing.T) { 64 | validSQL := []struct { 65 | input string 66 | output string 67 | }{{ 68 | input: "select 1+2*3 from a", 69 | output: "(1 + (2 * 3))", 70 | }, { 71 | input: "select 1*2+3 from a", 72 | output: "((1 * 2) + 3)", 73 | }} 74 | for _, tcase := range validSQL { 75 | tree, err := Parse(tcase.input) 76 | if err != nil { 77 | t.Error(err) 78 | continue 79 | } 80 | expr := readable(tree.(*Select).SelectExprs[0].(*AliasedExpr).Expr) 81 | if expr != tcase.output { 82 | t.Errorf("Parse: \n%s, want: \n%s", expr, tcase.output) 83 | } 84 | } 85 | } 86 | 87 | func TestIsPrecedence(t *testing.T) { 88 | validSQL := []struct { 89 | input string 90 | output string 91 | }{{ 92 | 
input: "select * from a where a+b is true", 93 | output: "((a + b) is true)", 94 | }, { 95 | input: "select * from a where a=1 and b=2 is true", 96 | output: "(a = 1 and (b = 2 is true))", 97 | }, { 98 | input: "select * from a where (a=1 and b=2) is true", 99 | output: "((a = 1 and b = 2) is true)", 100 | }} 101 | for _, tcase := range validSQL { 102 | tree, err := Parse(tcase.input) 103 | if err != nil { 104 | t.Error(err) 105 | continue 106 | } 107 | expr := readable(tree.(*Select).Where.Expr) 108 | if expr != tcase.output { 109 | t.Errorf("Parse: \n%s, want: \n%s", expr, tcase.output) 110 | } 111 | } 112 | } 113 | -------------------------------------------------------------------------------- /sqlparser/radon_test.go: -------------------------------------------------------------------------------- 1 | package sqlparser 2 | 3 | import ( 4 | "strings" 5 | "testing" 6 | ) 7 | 8 | func TestRadon(t *testing.T) { 9 | validSQL := []struct { 10 | input string 11 | output string 12 | }{ 13 | // name, address, user, password. 14 | { 15 | input: "radon attach ('attach1', '127.0.0.1:6000', 'root', '123456')", 16 | output: "radon attach ('attach1', '127.0.0.1:6000', 'root', '123456')", 17 | }, 18 | { 19 | input: "radon attachlist", 20 | output: "radon attachlist", 21 | }, 22 | { 23 | input: "radon detach('attach1')", 24 | output: "radon detach ('attach1')", 25 | }, 26 | { 27 | input: "radon reshard db.t db.tt", 28 | output: "radon reshard db.t to db.tt", 29 | }, 30 | { 31 | input: "radon reshard db.t to a.tt", 32 | output: "radon reshard db.t to a.tt", 33 | }, 34 | { 35 | input: "radon reshard db.t as b.tt", 36 | output: "radon reshard db.t to b.tt", 37 | }, 38 | { 39 | input: "radon cleanup", 40 | output: "radon cleanup", 41 | }, 42 | } 43 | 44 | for _, exp := range validSQL { 45 | sql := strings.TrimSpace(exp.input) 46 | tree, err := Parse(sql) 47 | if err != nil { 48 | t.Errorf("input: %s, err: %v", sql, err) 49 | continue 50 | } 51 | 52 | // Walk. 
53 | Walk(func(node SQLNode) (bool, error) { 54 | return true, nil 55 | }, tree) 56 | 57 | got := String(tree.(*Radon)) 58 | if exp.output != got { 59 | t.Errorf("want:\n%s\ngot:\n%s", exp.output, got) 60 | } 61 | } 62 | } 63 | -------------------------------------------------------------------------------- /sqlparser/rewriter_api.go: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2019 The Vitess Authors. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | */ 16 | 17 | package sqlparser 18 | 19 | // The rewriter was heavily inspired by https://github.com/golang/tools/blob/master/go/ast/astutil/rewrite.go 20 | 21 | // Rewrite traverses a syntax tree recursively, starting with root, 22 | // and calling pre and post for each node as described below. 23 | // Rewrite returns the syntax tree, possibly modified. 24 | // 25 | // If pre is not nil, it is called for each node before the node's 26 | // children are traversed (pre-order). If pre returns false, no 27 | // children are traversed, and post is not called for that node. 28 | // 29 | // If post is not nil, and a prior call of pre didn't return false, 30 | // post is called for each node after its children are traversed 31 | // (post-order). If post returns false, traversal is terminated and 32 | // Apply returns immediately. 
33 | // 34 | // Only fields that refer to AST nodes are considered children; 35 | // i.e., fields of basic types (strings, []byte, etc.) are ignored. 36 | // 37 | func Rewrite(node SQLNode, pre, post ApplyFunc) (result SQLNode) { 38 | parent := &struct{ SQLNode }{node} 39 | defer func() { 40 | if r := recover(); r != nil && r != abort { 41 | panic(r) 42 | } 43 | result = parent.SQLNode 44 | }() 45 | 46 | a := &application{ 47 | pre: pre, 48 | post: post, 49 | cursor: Cursor{}, 50 | } 51 | 52 | // this is the root-replacer, used when the user replaces the root of the ast 53 | replacer := func(newNode SQLNode, _ SQLNode) { 54 | parent.SQLNode = newNode 55 | } 56 | 57 | a.apply(parent, node, replacer) 58 | 59 | return parent.SQLNode 60 | } 61 | 62 | // An ApplyFunc is invoked by Rewrite for each node n, even if n is nil, 63 | // before and/or after the node's children, using a Cursor describing 64 | // the current node and providing operations on it. 65 | // 66 | // The return value of ApplyFunc controls the syntax tree traversal. 67 | // See Rewrite for details. 68 | type ApplyFunc func(*Cursor) bool 69 | 70 | var abort = new(int) // singleton, to signal termination of Apply 71 | 72 | // A Cursor describes a node encountered during Apply. 73 | // Information about the node and its parent is available 74 | // from the Node and Parent methods. 75 | type Cursor struct { 76 | parent SQLNode 77 | replacer replacerFunc 78 | node SQLNode 79 | } 80 | 81 | // Node returns the current Node. 82 | func (c *Cursor) Node() SQLNode { return c.node } 83 | 84 | // Parent returns the parent of the current Node. 85 | func (c *Cursor) Parent() SQLNode { return c.parent } 86 | 87 | // Replace replaces the current node in the parent field with this new object. The use needs to make sure to not 88 | // replace the object with something of the wrong type, or the visitor will panic. 
89 | func (c *Cursor) Replace(newNode SQLNode) { 90 | c.replacer(newNode, c.parent) 91 | } 92 | -------------------------------------------------------------------------------- /sqlparser/select_test.go: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2017 Google Inc. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | */ 16 | 17 | package sqlparser 18 | 19 | import "strings" 20 | import "testing" 21 | 22 | func TestSelect1(t *testing.T) { 23 | validSQL := []struct { 24 | input string 25 | output string 26 | }{ 27 | { 28 | input: "select * from xx", 29 | output: "select * from xx", 30 | }, 31 | { 32 | input: "select * from xx where id=1", 33 | output: "select * from xx where id = 1", 34 | }, 35 | } 36 | 37 | for _, sel := range validSQL { 38 | sql := strings.TrimSpace(sel.input) 39 | tree, err := Parse(sql) 40 | if err != nil { 41 | t.Errorf("input: %s, err: %v", sql, err) 42 | continue 43 | } 44 | got := String(tree.(*Select)) 45 | if sel.output != got { 46 | t.Errorf("want:\n%s\ngot:\n%s", sel.output, got) 47 | } 48 | } 49 | } 50 | -------------------------------------------------------------------------------- /sqlparser/set_test.go: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2017 Google Inc. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 
6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | */ 16 | 17 | package sqlparser 18 | 19 | import ( 20 | "strings" 21 | "testing" 22 | ) 23 | 24 | func TestSet(t *testing.T) { 25 | validSQL := []struct { 26 | input string 27 | output string 28 | }{ 29 | { 30 | input: "SET @@session.s1= 'ON', @@session.s2='OFF'", 31 | output: "set @@session.s1 = 'ON', @@session.s2 = 'OFF'", 32 | }, 33 | 34 | { 35 | input: "SET @@session.radon_stream_fetching= 'OFF'", 36 | output: "set @@session.radon_stream_fetching = 'OFF'", 37 | }, 38 | { 39 | input: "SET radon_stream_fetching= false", 40 | output: "set radon_stream_fetching = false", 41 | }, 42 | { 43 | input: "SET SESSION wait_timeout = 2147483", 44 | output: "set session wait_timeout = 2147483", 45 | }, 46 | { 47 | input: "SET NAMES utf8", 48 | output: "set names 'utf8'", 49 | }, 50 | { 51 | input: "SET NAMES latin1 COLLATE latin1_german2_ci", 52 | output: "set names 'latin1' collate latin1_german2_ci", 53 | }, 54 | { 55 | input: "set session autocommit = ON, global wait_timeout = 2147483", 56 | output: "set session autocommit = 'on', global wait_timeout = 2147483", 57 | }, 58 | { 59 | input: "SET LOCAL TRANSACTION ISOLATION LEVEL READ COMMITTED", 60 | output: "set session transaction isolation level read committed", 61 | }, 62 | { 63 | input: "SET GLOBAL TRANSACTION ISOLATION LEVEL SERIALIZABLE, READ WRITE", 64 | output: "set global transaction isolation level serializable, read write", 65 | }, 66 | { 67 | input: "SET SESSION TRANSACTION READ ONLY", 68 | output: "set session transaction read only", 69 | }, 70 | { 71 | input: "SET TRANSACTION READ 
WRITE, ISOLATION LEVEL SERIALIZABLE", 72 | output: "set transaction isolation level serializable, read write", 73 | }, 74 | } 75 | 76 | for _, exp := range validSQL { 77 | sql := strings.TrimSpace(exp.input) 78 | tree, err := Parse(sql) 79 | if err != nil { 80 | t.Errorf("input: %s, err: %v", sql, err) 81 | continue 82 | } 83 | 84 | // Walk. 85 | Walk(func(node SQLNode) (bool, error) { 86 | return true, nil 87 | }, tree) 88 | 89 | got := String(tree.(*Set)) 90 | if exp.output != got { 91 | t.Errorf("want:\n%s\ngot:\n%s", exp.output, got) 92 | } 93 | } 94 | } 95 | -------------------------------------------------------------------------------- /sqlparser/show_test.go: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2017 Google Inc. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 
15 | */ 16 | 17 | package sqlparser 18 | 19 | import "strings" 20 | import "testing" 21 | 22 | func TestShow1(t *testing.T) { 23 | validSQL := []struct { 24 | input string 25 | output string 26 | }{ 27 | { 28 | input: "show table status", 29 | output: "show table status", 30 | }, 31 | { 32 | input: "show table status from sbtest", 33 | output: "show table status from sbtest", 34 | }, 35 | { 36 | input: "show create table t1", 37 | output: "show create table t1", 38 | }, 39 | { 40 | input: "show tables", 41 | output: "show tables", 42 | }, 43 | { 44 | input: "show full tables", 45 | output: "show full tables", 46 | }, 47 | { 48 | input: "show full tables from t1", 49 | output: "show full tables from t1", 50 | }, 51 | { 52 | input: "show full tables from t1 like '%mysql%'", 53 | output: "show full tables from t1 like '%mysql%'", 54 | }, 55 | { 56 | input: "show full tables where Table_type != 'VIEW'", 57 | output: "show full tables where Table_type != 'VIEW'", 58 | }, 59 | { 60 | input: "show tables from t1", 61 | output: "show tables from t1", 62 | }, 63 | { 64 | input: "show tables from t1 like '%mysql%'", 65 | output: "show tables from t1 like '%mysql%'", 66 | }, 67 | { 68 | input: "show databases", 69 | output: "show databases", 70 | }, 71 | { 72 | input: "show create database sbtest", 73 | output: "show create database sbtest", 74 | }, 75 | { 76 | input: "show engines", 77 | output: "show engines", 78 | }, 79 | { 80 | input: "show status", 81 | output: "show status", 82 | }, 83 | { 84 | input: "show versions", 85 | output: "show versions", 86 | }, 87 | { 88 | input: "show processlist", 89 | output: "show processlist", 90 | }, 91 | { 92 | input: "show queryz", 93 | output: "show queryz", 94 | }, 95 | { 96 | input: "show txnz", 97 | output: "show txnz", 98 | }, 99 | { 100 | input: "show warnings", 101 | output: "show warnings", 102 | }, 103 | { 104 | input: "show variables", 105 | output: "show variables", 106 | }, 107 | { 108 | input: "show binlog events", 109 | 
output: "show binlog events", 110 | }, 111 | { 112 | input: "show binlog events limit 10", 113 | output: "show binlog events limit 10", 114 | }, 115 | { 116 | input: "show binlog events from gtid '20171225083823'", 117 | output: "show binlog events from gtid '20171225083823'", 118 | }, 119 | { 120 | input: "show binlog events from gtid '20171225083823' limit 1", 121 | output: "show binlog events from gtid '20171225083823' limit 1", 122 | }, 123 | { 124 | input: "show columns from t1", 125 | output: "show columns from t1", 126 | }, 127 | { 128 | input: "show columns from t1 like '%'", 129 | output: "show columns from t1 like '%'", 130 | }, 131 | { 132 | input: "show columns from t1 where `Key` = 'PRI'", 133 | output: "show columns from t1 where `Key` = 'PRI'", 134 | }, 135 | { 136 | input: "show full columns from t1", 137 | output: "show full columns from t1", 138 | }, 139 | { 140 | input: "show full columns from t1 like '%'", 141 | output: "show full columns from t1 like '%'", 142 | }, 143 | { 144 | input: "show full columns from t1 where `Key` = 'PRI'", 145 | output: "show full columns from t1 where `Key` = 'PRI'", 146 | }, 147 | { 148 | input: "show fields from t1", 149 | output: "show columns from t1", 150 | }, 151 | { 152 | input: "show fields from t1 like '%'", 153 | output: "show columns from t1 like '%'", 154 | }, 155 | { 156 | input: "show fields from t1 where `Key` = 'PRI'", 157 | output: "show columns from t1 where `Key` = 'PRI'", 158 | }, 159 | { 160 | input: "show full fields from t1", 161 | output: "show full columns from t1", 162 | }, 163 | { 164 | input: "show full fields from t1 like '%'", 165 | output: "show full columns from t1 like '%'", 166 | }, 167 | { 168 | input: "show full fields from t1 where `Key` = 'PRI'", 169 | output: "show full columns from t1 where `Key` = 'PRI'", 170 | }, 171 | } 172 | 173 | for _, show := range validSQL { 174 | sql := strings.TrimSpace(show.input) 175 | tree, err := Parse(sql) 176 | if err != nil { 177 | 
t.Errorf("input: %s, err: %v", sql, err) 178 | continue 179 | } 180 | 181 | // Walk. 182 | Walk(func(node SQLNode) (bool, error) { 183 | return true, nil 184 | }, tree) 185 | 186 | got := String(tree.(*Show)) 187 | if show.output != got { 188 | t.Errorf("want:\n%s\ngot:\n%s", show.output, got) 189 | } 190 | } 191 | } 192 | -------------------------------------------------------------------------------- /sqlparser/token_test.go: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2017 Google Inc. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 
15 | */ 16 | 17 | package sqlparser 18 | 19 | import "testing" 20 | 21 | func TestLiteralID(t *testing.T) { 22 | testcases := []struct { 23 | in string 24 | id int 25 | out string 26 | }{{ 27 | in: "`aa`", 28 | id: ID, 29 | out: "aa", 30 | }, { 31 | in: "```a```", 32 | id: ID, 33 | out: "`a`", 34 | }, { 35 | in: "`a``b`", 36 | id: ID, 37 | out: "a`b", 38 | }, { 39 | in: "`a``b`c", 40 | id: ID, 41 | out: "a`b", 42 | }, { 43 | in: "`a``b", 44 | id: LEX_ERROR, 45 | out: "a`b", 46 | }, { 47 | in: "`a``b``", 48 | id: LEX_ERROR, 49 | out: "a`b`", 50 | }, { 51 | in: "``", 52 | id: LEX_ERROR, 53 | out: "", 54 | }} 55 | 56 | for _, tcase := range testcases { 57 | tkn := NewStringTokenizer(tcase.in) 58 | id, out := tkn.Scan() 59 | if tcase.id != id || string(out) != tcase.out { 60 | t.Errorf("Scan(%s): %d, %s, want %d, %s", tcase.in, id, out, tcase.id, tcase.out) 61 | } 62 | } 63 | } 64 | -------------------------------------------------------------------------------- /sqlparser/tracked_buffer.go: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2017 Google Inc. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | */ 16 | 17 | package sqlparser 18 | 19 | import ( 20 | "fmt" 21 | "strings" 22 | ) 23 | 24 | // NodeFormatter defines the signature of a custom node formatter 25 | // function that can be given to TrackedBuffer for code generation. 
26 | type NodeFormatter func(buf *TrackedBuffer, node SQLNode) 27 | 28 | // TrackedBuffer is used to rebuild a query from the ast. 29 | // bindLocations keeps track of locations in the buffer that 30 | // use bind variables for efficient future substitutions. 31 | // nodeFormatter is the formatting function the buffer will 32 | // use to format a node. By default(nil), it's FormatNode. 33 | // But you can supply a different formatting function if you 34 | // want to generate a query that's different from the default. 35 | type TrackedBuffer struct { 36 | *strings.Builder 37 | bindLocations []bindLocation 38 | nodeFormatter NodeFormatter 39 | } 40 | 41 | // NewTrackedBuffer creates a new TrackedBuffer. 42 | func NewTrackedBuffer(nodeFormatter NodeFormatter) *TrackedBuffer { 43 | return &TrackedBuffer{ 44 | Builder: new(strings.Builder), 45 | nodeFormatter: nodeFormatter, 46 | } 47 | } 48 | 49 | // WriteNode function, initiates the writing of a single SQLNode tree by passing 50 | // through to Myprintf with a default format string 51 | func (buf *TrackedBuffer) WriteNode(node SQLNode) *TrackedBuffer { 52 | buf.Myprintf("%v", node) 53 | return buf 54 | } 55 | 56 | // Myprintf mimics fmt.Fprintf(buf, ...), but limited to Node(%v), 57 | // Node.Value(%s) and string(%s). It also allows a %a for a value argument, in 58 | // which case it adds tracking info for future substitutions. 59 | // 60 | // The name must be something other than the usual Printf() to avoid "go vet" 61 | // warnings due to our custom format specifiers. 
62 | func (buf *TrackedBuffer) Myprintf(format string, values ...interface{}) { 63 | end := len(format) 64 | fieldnum := 0 65 | for i := 0; i < end; { 66 | lasti := i 67 | for i < end && format[i] != '%' { 68 | i++ 69 | } 70 | if i > lasti { 71 | buf.WriteString(format[lasti:i]) 72 | } 73 | if i >= end { 74 | break 75 | } 76 | i++ // '%' 77 | switch format[i] { 78 | case 'c': 79 | switch v := values[fieldnum].(type) { 80 | case byte: 81 | buf.WriteByte(v) 82 | case rune: 83 | buf.WriteRune(v) 84 | default: 85 | panic(fmt.Sprintf("unexpected TrackedBuffer type %T", v)) 86 | } 87 | case 's': 88 | switch v := values[fieldnum].(type) { 89 | case []byte: 90 | buf.Write(v) 91 | case string: 92 | buf.WriteString(v) 93 | default: 94 | panic(fmt.Sprintf("unexpected TrackedBuffer type %T", v)) 95 | } 96 | case 'v': 97 | node := values[fieldnum].(SQLNode) 98 | if buf.nodeFormatter == nil { 99 | node.Format(buf) 100 | } else { 101 | buf.nodeFormatter(buf, node) 102 | } 103 | case 'a': 104 | buf.WriteArg(values[fieldnum].(string)) 105 | default: 106 | panic("unexpected") 107 | } 108 | fieldnum++ 109 | i++ 110 | } 111 | } 112 | 113 | // WriteArg writes a value argument into the buffer along with 114 | // tracking information for future substitutions. arg must contain 115 | // the ":" or "::" prefix. 116 | func (buf *TrackedBuffer) WriteArg(arg string) { 117 | buf.bindLocations = append(buf.bindLocations, bindLocation{ 118 | offset: buf.Len(), 119 | length: len(arg), 120 | }) 121 | buf.WriteString(arg) 122 | } 123 | 124 | // ParsedQuery returns a ParsedQuery that contains bind 125 | // locations for easy substitution. 126 | func (buf *TrackedBuffer) ParsedQuery() *ParsedQuery { 127 | return &ParsedQuery{Query: buf.String(), bindLocations: buf.bindLocations} 128 | } 129 | 130 | // HasBindVars returns true if the parsed query uses bind vars. 
131 | func (buf *TrackedBuffer) HasBindVars() bool { 132 | return len(buf.bindLocations) != 0 133 | } 134 | 135 | // BuildParsedQuery builds a ParsedQuery from the input. 136 | func BuildParsedQuery(in string, vars ...interface{}) *ParsedQuery { 137 | buf := NewTrackedBuffer(nil) 138 | buf.Myprintf(in, vars...) 139 | return buf.ParsedQuery() 140 | } 141 | -------------------------------------------------------------------------------- /sqlparser/tracked_buffer_test.go: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2017 Google Inc. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | */ 16 | 17 | package sqlparser 18 | 19 | import "testing" 20 | 21 | func TestTrackedBuffer(t *testing.T) { 22 | buf := NewTrackedBuffer(nil) 23 | buf.Myprintf("%c,%s,%a", 'a', "a", "a") 24 | } 25 | -------------------------------------------------------------------------------- /sqlparser/txn_test.go: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2017 Google Inc. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 
6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | */ 16 | 17 | package sqlparser 18 | 19 | import "strings" 20 | import "testing" 21 | 22 | func TestTxn(t *testing.T) { 23 | validSQL := []struct { 24 | input string 25 | output string 26 | }{ 27 | { 28 | input: "start transaction", 29 | output: "start transaction", 30 | }, 31 | { 32 | input: "begin", 33 | output: "begin", 34 | }, 35 | { 36 | input: "rollback", 37 | output: "rollback", 38 | }, 39 | { 40 | input: "commit", 41 | output: "commit", 42 | }, 43 | } 44 | 45 | for _, exp := range validSQL { 46 | sql := strings.TrimSpace(exp.input) 47 | tree, err := Parse(sql) 48 | if err != nil { 49 | t.Errorf("input: %s, err: %v", sql, err) 50 | continue 51 | } 52 | 53 | // Walk. 54 | Walk(func(node SQLNode) (bool, error) { 55 | return true, nil 56 | }, tree) 57 | 58 | got := String(tree.(*Transaction)) 59 | if exp.output != got { 60 | t.Errorf("want:\n%s\ngot:\n%s", exp.output, got) 61 | } 62 | } 63 | } 64 | -------------------------------------------------------------------------------- /sqlparser/visitorgen/ast_walker.go: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2019 The Vitess Authors. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 
6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | */ 16 | 17 | package visitorgen 18 | 19 | import ( 20 | "go/ast" 21 | "reflect" 22 | ) 23 | 24 | var _ ast.Visitor = (*walker)(nil) 25 | 26 | type walker struct { 27 | result SourceFile 28 | } 29 | 30 | // Walk walks the given AST and translates it to the simplified AST used by the next steps 31 | func Walk(node ast.Node) *SourceFile { 32 | var w walker 33 | ast.Walk(&w, node) 34 | return &w.result 35 | } 36 | 37 | // Visit implements the ast.Visitor interface 38 | func (w *walker) Visit(node ast.Node) ast.Visitor { 39 | switch n := node.(type) { 40 | case *ast.TypeSpec: 41 | switch t2 := n.Type.(type) { 42 | case *ast.InterfaceType: 43 | w.append(&InterfaceDeclaration{ 44 | name: n.Name.Name, 45 | block: "", 46 | }) 47 | case *ast.StructType: 48 | var fields []*Field 49 | for _, f := range t2.Fields.List { 50 | for _, name := range f.Names { 51 | fields = append(fields, &Field{ 52 | name: name.Name, 53 | typ: sastType(f.Type), 54 | }) 55 | } 56 | 57 | } 58 | w.append(&StructDeclaration{ 59 | name: n.Name.Name, 60 | fields: fields, 61 | }) 62 | case *ast.ArrayType: 63 | w.append(&TypeAlias{ 64 | name: n.Name.Name, 65 | typ: &Array{inner: sastType(t2.Elt)}, 66 | }) 67 | case *ast.Ident: 68 | w.append(&TypeAlias{ 69 | name: n.Name.Name, 70 | typ: &TypeString{t2.Name}, 71 | }) 72 | 73 | default: 74 | panic(reflect.TypeOf(t2)) 75 | } 76 | case *ast.FuncDecl: 77 | if len(n.Recv.List) > 1 || len(n.Recv.List[0].Names) > 1 { 78 | panic("don't know what to do!") 79 | } 80 | var f *Field 81 | if len(n.Recv.List) == 1 { 82 | r := n.Recv.List[0] 83 | 
t := sastType(r.Type) 84 | if len(r.Names) > 1 { 85 | panic("don't know what to do!") 86 | } 87 | if len(r.Names) == 1 { 88 | f = &Field{ 89 | name: r.Names[0].Name, 90 | typ: t, 91 | } 92 | } else { 93 | f = &Field{ 94 | name: "", 95 | typ: t, 96 | } 97 | } 98 | } 99 | 100 | w.append(&FuncDeclaration{ 101 | receiver: f, 102 | name: n.Name.Name, 103 | block: "", 104 | arguments: nil, 105 | }) 106 | } 107 | 108 | return w 109 | } 110 | 111 | func (w *walker) append(line Sast) { 112 | w.result.lines = append(w.result.lines, line) 113 | } 114 | 115 | func sastType(e ast.Expr) Type { 116 | switch n := e.(type) { 117 | case *ast.StarExpr: 118 | return &Ref{sastType(n.X)} 119 | case *ast.Ident: 120 | return &TypeString{n.Name} 121 | case *ast.ArrayType: 122 | return &Array{inner: sastType(n.Elt)} 123 | case *ast.InterfaceType: 124 | return &TypeString{"interface{}"} 125 | case *ast.StructType: 126 | return &TypeString{"struct{}"} 127 | } 128 | 129 | panic(reflect.TypeOf(e)) 130 | } 131 | -------------------------------------------------------------------------------- /sqlparser/visitorgen/ast_walker_test.go: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2019 The Vitess Authors. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 
15 | */ 16 | 17 | package visitorgen 18 | 19 | import ( 20 | "go/parser" 21 | "go/token" 22 | "testing" 23 | 24 | "github.com/stretchr/testify/assert" 25 | 26 | "github.com/stretchr/testify/require" 27 | ) 28 | 29 | func TestSingleInterface(t *testing.T) { 30 | input := ` 31 | package sqlparser 32 | 33 | type Nodeiface interface { 34 | iNode() 35 | } 36 | ` 37 | 38 | fset := token.NewFileSet() 39 | ast, err := parser.ParseFile(fset, "ast.go", input, 0) 40 | require.NoError(t, err) 41 | 42 | result := Walk(ast) 43 | expected := SourceFile{ 44 | lines: []Sast{&InterfaceDeclaration{ 45 | name: "Nodeiface", 46 | block: "", 47 | }}, 48 | } 49 | assert.Equal(t, expected.String(), result.String()) 50 | } 51 | 52 | func TestEmptyStruct(t *testing.T) { 53 | input := ` 54 | package sqlparser 55 | 56 | type Empty struct {} 57 | ` 58 | 59 | fset := token.NewFileSet() 60 | ast, err := parser.ParseFile(fset, "ast.go", input, 0) 61 | require.NoError(t, err) 62 | 63 | result := Walk(ast) 64 | expected := SourceFile{ 65 | lines: []Sast{&StructDeclaration{ 66 | name: "Empty", 67 | fields: []*Field{}, 68 | }}, 69 | } 70 | assert.Equal(t, expected.String(), result.String()) 71 | } 72 | 73 | func TestStructWithStringField(t *testing.T) { 74 | input := ` 75 | package sqlparser 76 | 77 | type Struct struct { 78 | field string 79 | } 80 | ` 81 | 82 | fset := token.NewFileSet() 83 | ast, err := parser.ParseFile(fset, "ast.go", input, 0) 84 | require.NoError(t, err) 85 | 86 | result := Walk(ast) 87 | expected := SourceFile{ 88 | lines: []Sast{&StructDeclaration{ 89 | name: "Struct", 90 | fields: []*Field{{ 91 | name: "field", 92 | typ: &TypeString{typName: "string"}, 93 | }}, 94 | }}, 95 | } 96 | assert.Equal(t, expected.String(), result.String()) 97 | } 98 | 99 | func TestStructWithDifferentTypes(t *testing.T) { 100 | input := ` 101 | package sqlparser 102 | 103 | type Struct struct { 104 | field string 105 | reference *string 106 | array []string 107 | arrayOfRef []*string 108 | } 109 | ` 
110 | 111 | fset := token.NewFileSet() 112 | ast, err := parser.ParseFile(fset, "ast.go", input, 0) 113 | require.NoError(t, err) 114 | 115 | result := Walk(ast) 116 | expected := SourceFile{ 117 | lines: []Sast{&StructDeclaration{ 118 | name: "Struct", 119 | fields: []*Field{{ 120 | name: "field", 121 | typ: &TypeString{typName: "string"}, 122 | }, { 123 | name: "reference", 124 | typ: &Ref{&TypeString{typName: "string"}}, 125 | }, { 126 | name: "array", 127 | typ: &Array{&TypeString{typName: "string"}}, 128 | }, { 129 | name: "arrayOfRef", 130 | typ: &Array{&Ref{&TypeString{typName: "string"}}}, 131 | }}, 132 | }}, 133 | } 134 | assert.Equal(t, expected.String(), result.String()) 135 | } 136 | 137 | func TestStructWithTwoStringFieldInOneLine(t *testing.T) { 138 | input := ` 139 | package sqlparser 140 | 141 | type Struct struct { 142 | left, right string 143 | } 144 | ` 145 | 146 | fset := token.NewFileSet() 147 | ast, err := parser.ParseFile(fset, "ast.go", input, 0) 148 | require.NoError(t, err) 149 | 150 | result := Walk(ast) 151 | expected := SourceFile{ 152 | lines: []Sast{&StructDeclaration{ 153 | name: "Struct", 154 | fields: []*Field{{ 155 | name: "left", 156 | typ: &TypeString{typName: "string"}, 157 | }, { 158 | name: "right", 159 | typ: &TypeString{typName: "string"}, 160 | }}, 161 | }}, 162 | } 163 | assert.Equal(t, expected.String(), result.String()) 164 | } 165 | 166 | func TestStructWithSingleMethod(t *testing.T) { 167 | input := ` 168 | package sqlparser 169 | 170 | type Empty struct {} 171 | 172 | func (*Empty) method() {} 173 | ` 174 | 175 | fset := token.NewFileSet() 176 | ast, err := parser.ParseFile(fset, "ast.go", input, 0) 177 | require.NoError(t, err) 178 | 179 | result := Walk(ast) 180 | expected := SourceFile{ 181 | lines: []Sast{ 182 | &StructDeclaration{ 183 | name: "Empty", 184 | fields: []*Field{}}, 185 | &FuncDeclaration{ 186 | receiver: &Field{ 187 | name: "", 188 | typ: &Ref{&TypeString{"Empty"}}, 189 | }, 190 | name: "method", 
191 | block: "", 192 | arguments: []*Field{}, 193 | }, 194 | }, 195 | } 196 | assert.Equal(t, expected.String(), result.String()) 197 | } 198 | 199 | func TestSingleArrayType(t *testing.T) { 200 | input := ` 201 | package sqlparser 202 | 203 | type Strings []string 204 | ` 205 | 206 | fset := token.NewFileSet() 207 | ast, err := parser.ParseFile(fset, "ast.go", input, 0) 208 | require.NoError(t, err) 209 | 210 | result := Walk(ast) 211 | expected := SourceFile{ 212 | lines: []Sast{&TypeAlias{ 213 | name: "Strings", 214 | typ: &Array{&TypeString{"string"}}, 215 | }}, 216 | } 217 | assert.Equal(t, expected.String(), result.String()) 218 | } 219 | 220 | func TestSingleTypeAlias(t *testing.T) { 221 | input := ` 222 | package sqlparser 223 | 224 | type String string 225 | ` 226 | 227 | fset := token.NewFileSet() 228 | ast, err := parser.ParseFile(fset, "ast.go", input, 0) 229 | require.NoError(t, err) 230 | 231 | result := Walk(ast) 232 | expected := SourceFile{ 233 | lines: []Sast{&TypeAlias{ 234 | name: "String", 235 | typ: &TypeString{"string"}, 236 | }}, 237 | } 238 | assert.Equal(t, expected.String(), result.String()) 239 | } 240 | -------------------------------------------------------------------------------- /sqlparser/visitorgen/main/main.go: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2019 The Vitess Authors. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 
15 | */ 16 | 17 | package main 18 | 19 | import ( 20 | "bytes" 21 | "flag" 22 | "fmt" 23 | "go/parser" 24 | "go/token" 25 | "io/ioutil" 26 | "os" 27 | 28 | "github.com/xelabs/go-mysqlstack/sqlparser/visitorgen" 29 | "github.com/xelabs/go-mysqlstack/xlog" 30 | ) 31 | 32 | var ( 33 | inputFile = flag.String("input", "", "input file to use") 34 | outputFile = flag.String("output", "", "output file") 35 | compare = flag.Bool("compareOnly", false, "instead of writing to the output file, compare if the generated visitor is still valid for this ast.go") 36 | ) 37 | 38 | const usage = `Usage of visitorgen: 39 | 40 | go run /path/to/visitorgen/main -input=/path/to/ast.go -output=/path/to/rewriter.go 41 | ` 42 | 43 | func main() { 44 | log := xlog.NewStdLog(xlog.Level(xlog.INFO)) 45 | defer func() { 46 | if x := recover(); x != nil { 47 | log.Error("sqlparser.visitorgen.failed: %v", x) 48 | os.Exit(1) // a recovered panic must still fail the run so go generate / CI notice it 49 | } 50 | }() 51 | flag.Usage = printUsage 52 | flag.Parse() 53 | if *inputFile == "" || *outputFile == "" { 54 | printUsage() 55 | os.Exit(2) // missing required flags is a usage error; 2 matches the flag package convention 56 | } 57 | log.Info("sqlparser.visitorgen.inputFile[%s].outputFile[%s].compare[%t].start...", *inputFile, *outputFile, *compare) 58 | 59 | fs := token.NewFileSet() 60 | file, err := parser.ParseFile(fs, *inputFile, nil, parser.DeclarationErrors) 61 | if err != nil { 62 | panic(fmt.Sprintf("parse.file[%s].error[%v]", *inputFile, err)) 63 | } 64 | 65 | astWalkResult := visitorgen.Walk(file) 66 | vp := visitorgen.Transform(astWalkResult) 67 | vd := visitorgen.ToVisitorPlan(vp) 68 | 69 | replacementMethods := visitorgen.EmitReplacementMethods(vd) 70 | typeSwitch := visitorgen.EmitTypeSwitches(vd) 71 | 72 | b := &bytes.Buffer{} 73 | fmt.Fprint(b, fileHeader) 74 | fmt.Fprintln(b) 75 | fmt.Fprintln(b, replacementMethods) 76 | fmt.Fprint(b, applyHeader) 77 | fmt.Fprintln(b, typeSwitch) 78 | fmt.Fprintln(b, fileFooter) 79 | 80 | if *compare { 81 | currentFile, err := ioutil.ReadFile(*outputFile) 82 | if err != nil {
panic(fmt.Sprintf("read.file[%s].error[%v]", *outputFile, err)) 84 | } 85 | if !bytes.Equal(b.Bytes(), currentFile) { 86 | fmt.Println("rewriter needs to be re-generated: go generate " + *outputFile) 87 | os.Exit(1) 88 | } 89 | log.Info("sqlparser.visitorgen.compare.success") 90 | } else { 91 | err = ioutil.WriteFile(*outputFile, b.Bytes(), 0644) 92 | if err != nil { 93 | panic(fmt.Sprintf("write.file[%s].error[%v]", *outputFile, err)) 94 | } 95 | } 96 | log.Info("sqlparser.visitorgen.finish...") 97 | } 98 | 99 | func printUsage() { 100 | os.Stderr.WriteString(usage) 101 | os.Stderr.WriteString("\nOptions:\n") 102 | flag.PrintDefaults() 103 | } 104 | 105 | const fileHeader = `// Code generated by visitorgen/main/main.go. DO NOT EDIT. 106 | 107 | package sqlparser 108 | 109 | //go:generate go run ./visitorgen/main -input=ast.go -output=rewriter.go 110 | 111 | import ( 112 | "reflect" 113 | ) 114 | 115 | type replacerFunc func(newNode, parent SQLNode) 116 | 117 | // application carries all the shared data so we can pass it around cheaply. 118 | type application struct { 119 | pre, post ApplyFunc 120 | cursor Cursor 121 | } 122 | ` 123 | 124 | const applyHeader = ` 125 | // apply is where the visiting happens. 
Here is where we keep the big switch-case that will be used 126 | // to do the actual visiting of SQLNodes 127 | func (a *application) apply(parent, node SQLNode, replacer replacerFunc) { 128 | if node == nil || isNilValue(node) { 129 | return 130 | } 131 | 132 | // avoid heap-allocating a new cursor for each apply call; reuse a.cursor instead 133 | saved := a.cursor 134 | a.cursor.replacer = replacer 135 | a.cursor.node = node 136 | a.cursor.parent = parent 137 | 138 | if a.pre != nil && !a.pre(&a.cursor) { 139 | a.cursor = saved 140 | return 141 | } 142 | 143 | // walk children 144 | // (the order of the cases is alphabetical) 145 | switch n := node.(type) { 146 | case nil: 147 | ` 148 | 149 | const fileFooter = ` 150 | default: 151 | panic("unknown ast type " + reflect.TypeOf(node).String()) 152 | } 153 | 154 | if a.post != nil && !a.post(&a.cursor) { 155 | panic(abort) 156 | } 157 | 158 | a.cursor = saved 159 | } 160 | 161 | func isNilValue(i interface{}) bool { 162 | valueOf := reflect.ValueOf(i) 163 | kind := valueOf.Kind() 164 | isNullable := kind == reflect.Ptr || kind == reflect.Array || kind == reflect.Slice 165 | return isNullable && valueOf.IsNil() 166 | }` 167 | -------------------------------------------------------------------------------- /sqlparser/visitorgen/sast.go: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2019 The Vitess Authors. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | */ 16 | 17 | package visitorgen 18 | 19 | // simplified ast - when reading the golang ast of the ast.go file, we translate the golang ast objects 20 | // to this much simpler format, that contains only the necessary information and no more 21 | type ( 22 | // SourceFile contains all important lines from an ast.go file 23 | SourceFile struct { 24 | lines []Sast 25 | } 26 | 27 | // Sast or simplified AST, is a representation of the ast.go lines we are interested in 28 | Sast interface { 29 | toSastString() string 30 | } 31 | 32 | // InterfaceDeclaration represents a declaration of an interface. This is used to keep track of which types 33 | // need to be handled by the visitor framework 34 | InterfaceDeclaration struct { 35 | name, block string 36 | } 37 | 38 | // TypeAlias is used whenever we see a `type XXX YYY` - XXX is the new name for YYY. 39 | // Note that YYY could be an array or a reference 40 | TypeAlias struct { 41 | name string 42 | typ Type 43 | } 44 | 45 | // FuncDeclaration represents a function declaration. These are tracked to know which types implement interfaces. 46 | FuncDeclaration struct { 47 | receiver *Field 48 | name, block string 49 | arguments []*Field 50 | } 51 | 52 | // StructDeclaration represents a struct. It contains the fields and their types 53 | StructDeclaration struct { 54 | name string 55 | fields []*Field 56 | } 57 | 58 | // Field is a field in a struct - a name with a type tuple 59 | Field struct { 60 | name string 61 | typ Type 62 | } 63 | 64 | // Type represents a type in the golang type system. Used to keep track of type we need to handle, 65 | // and the types of fields. 
66 | Type interface { 67 | toTypString() string 68 | rawTypeName() string 69 | } 70 | 71 | // TypeString is a raw type name, such as `string` 72 | TypeString struct { 73 | typName string 74 | } 75 | 76 | // Ref is a reference to something, such as `*string` 77 | Ref struct { 78 | inner Type 79 | } 80 | 81 | // Array is an array of things, such as `[]string` 82 | Array struct { 83 | inner Type 84 | } 85 | ) 86 | 87 | var _ Sast = (*InterfaceDeclaration)(nil) 88 | var _ Sast = (*StructDeclaration)(nil) 89 | var _ Sast = (*FuncDeclaration)(nil) 90 | var _ Sast = (*TypeAlias)(nil) 91 | 92 | var _ Type = (*TypeString)(nil) 93 | var _ Type = (*Ref)(nil) 94 | var _ Type = (*Array)(nil) 95 | 96 | // String returns a textual representation of the SourceFile. This is for testing purposed 97 | func (t *SourceFile) String() string { 98 | var result string 99 | for _, l := range t.lines { 100 | result += l.toSastString() 101 | result += "\n" 102 | } 103 | 104 | return result 105 | } 106 | 107 | func (t *Ref) toTypString() string { 108 | return "*" + t.inner.toTypString() 109 | } 110 | 111 | func (t *Array) toTypString() string { 112 | return "[]" + t.inner.toTypString() 113 | } 114 | 115 | func (t *TypeString) toTypString() string { 116 | return t.typName 117 | } 118 | 119 | func (f *FuncDeclaration) toSastString() string { 120 | var receiver string 121 | if f.receiver != nil { 122 | receiver = "(" + f.receiver.String() + ") " 123 | } 124 | var args string 125 | for i, arg := range f.arguments { 126 | if i > 0 { 127 | args += ", " 128 | } 129 | args += arg.String() 130 | } 131 | 132 | return "func " + receiver + f.name + "(" + args + ") {" + blockInNewLines(f.block) + "}" 133 | } 134 | 135 | func (i *InterfaceDeclaration) toSastString() string { 136 | return "type " + i.name + " interface {" + blockInNewLines(i.block) + "}" 137 | } 138 | 139 | func (a *TypeAlias) toSastString() string { 140 | return "type " + a.name + " " + a.typ.toTypString() 141 | } 142 | 143 | func (s 
*StructDeclaration) toSastString() string { 144 | var block string 145 | for _, f := range s.fields { 146 | block += "\t" + f.String() + "\n" 147 | } 148 | 149 | return "type " + s.name + " struct {" + blockInNewLines(block) + "}" 150 | } 151 | 152 | func blockInNewLines(block string) string { 153 | if block == "" { 154 | return "" 155 | } 156 | return "\n" + block + "\n" 157 | } 158 | 159 | // String returns a string representation of a field 160 | func (f *Field) String() string { 161 | if f.name != "" { 162 | return f.name + " " + f.typ.toTypString() 163 | } 164 | 165 | return f.typ.toTypString() 166 | } 167 | 168 | func (t *TypeString) rawTypeName() string { 169 | return t.typName 170 | } 171 | 172 | func (t *Ref) rawTypeName() string { 173 | return t.inner.rawTypeName() 174 | } 175 | 176 | func (t *Array) rawTypeName() string { 177 | return t.inner.rawTypeName() 178 | } 179 | -------------------------------------------------------------------------------- /sqlparser/visitorgen/transformer.go: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2019 The Vitess Authors. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 
15 | */ 16 | 17 | package visitorgen 18 | 19 | import ( 20 | "fmt" 21 | "sort" 22 | ) 23 | 24 | // Transform takes an input file and collects the information into an easier to consume format. 25 | // Methods contribute their receiver type to the interesting types; a FuncDeclaration with a nil 26 | // receiver (which the sast model permits) has no type to contribute and is skipped. 27 | func Transform(input *SourceFile) *SourceInformation { 28 | interestingTypes := make(map[string]Type) 29 | interfaces := make(map[string]bool) 30 | structs := make(map[string]*StructDeclaration) 31 | typeAliases := make(map[string]*TypeAlias) 32 | 33 | for _, l := range input.lines { 34 | switch line := l.(type) { 35 | case *FuncDeclaration: 36 | if line.receiver != nil { 37 | interestingTypes[line.receiver.typ.toTypString()] = line.receiver.typ 38 | } 39 | case *StructDeclaration: 40 | structs[line.name] = line 41 | case *TypeAlias: 42 | typeAliases[line.name] = line 43 | case *InterfaceDeclaration: 44 | interfaces[line.name] = true 45 | } 46 | } 47 | 48 | return &SourceInformation{ 49 | interfaces: interfaces, 50 | interestingTypes: interestingTypes, 51 | structs: structs, 52 | typeAliases: typeAliases, 53 | } 54 | } 55 | 56 | // SourceInformation contains the information from the ast.go file, but in a format that is easier to consume 57 | type SourceInformation struct { 58 | interestingTypes map[string]Type 59 | interfaces map[string]bool 60 | structs map[string]*StructDeclaration 61 | typeAliases map[string]*TypeAlias 62 | } 63 | 64 | // String returns a textual summary of the collected information; it is used by tests. 65 | // Each section is sorted so the output does not depend on Go's randomized map iteration order. 66 | func (v *SourceInformation) String() string { 67 | // join sorts the entries and terminates each one with a newline. 68 | join := func(lines []string) string { 69 | sort.Strings(lines) 70 | var out string 71 | for _, l := range lines { 72 | out += l + "\n" 73 | } 74 | return out 75 | } 76 | 77 | types := make([]string, 0, len(v.interestingTypes)) 78 | for _, k := range v.interestingTypes { 79 | types = append(types, k.toTypString()) 80 | } 81 | structs := make([]string, 0, len(v.structs)) 82 | for _, k := range v.structs { 83 | structs = append(structs, k.toSastString()) 84 | } 85 | typeAliases := make([]string, 0, len(v.typeAliases)) 86 | for _, k := range v.typeAliases { 87 | typeAliases = append(typeAliases, k.toSastString()) 88 | } 89 | 90 | return fmt.Sprintf("Types to build visitor for:\n%s\nStructs with fields: \n%s\nTypeAliases with type: \n%s\n", 91 | join(types), join(structs), join(typeAliases)) 92 | } 93 | 94 | // getItemTypeOfArray will return nil if the given type is not pointing to an array type. 95 | // If it is an array type, the type of its items will be returned; aliases of aliases are followed recursively. 96 | func (v *SourceInformation) getItemTypeOfArray(typ Type) Type { 97 | alias := v.typeAliases[typ.rawTypeName()] 98 | if alias == nil { 99 | return nil 100 | } 101 | arrTyp, isArray := alias.typ.(*Array) 102 | if !isArray { 103 | return v.getItemTypeOfArray(alias.typ) 104 | } 105 | return arrTyp.inner 106 | } 107 | 108 | // isSQLNode reports whether typ is one of the interesting (receiver) types or a declared interface. 109 | func (v *SourceInformation) isSQLNode(typ Type) bool { 110 | _, isInteresting := v.interestingTypes[typ.toTypString()] 111 | if isInteresting { 112 | return true 113 | } 114 | _, isInterface := v.interfaces[typ.toTypString()] 115 | return isInterface 116 | } 117 | -------------------------------------------------------------------------------- /sqlparser/visitorgen/transformer_test.go: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2019 The Vitess Authors. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License.
15 | */ 16 | 17 | package visitorgen 18 | 19 | import ( 20 | "testing" 21 | 22 | "github.com/stretchr/testify/assert" 23 | ) 24 | 25 | func TestSimplestAst(t *testing.T) { 26 | /* 27 | type NodeInterface interface { 28 | iNode() 29 | } 30 | 31 | type NodeStruct struct {} 32 | 33 | func (*NodeStruct) iNode{} 34 | */ 35 | input := &SourceFile{ 36 | lines: []Sast{ 37 | &InterfaceDeclaration{ 38 | name: "NodeInterface", 39 | block: "// an interface lives here"}, 40 | &StructDeclaration{ 41 | name: "NodeStruct", 42 | fields: []*Field{}}, 43 | &FuncDeclaration{ 44 | receiver: &Field{ 45 | name: "", 46 | typ: &Ref{&TypeString{"NodeStruct"}}, 47 | }, 48 | name: "iNode", 49 | block: "", 50 | arguments: []*Field{}}, 51 | }, 52 | } 53 | 54 | expected := &SourceInformation{ 55 | interestingTypes: map[string]Type{ 56 | "*NodeStruct": &Ref{&TypeString{"NodeStruct"}}}, 57 | structs: map[string]*StructDeclaration{ 58 | "NodeStruct": { 59 | name: "NodeStruct", 60 | fields: []*Field{}}}, 61 | } 62 | 63 | assert.Equal(t, expected.String(), Transform(input).String()) 64 | } 65 | 66 | func TestAstWithArray(t *testing.T) { 67 | /* 68 | type NodeInterface interface { 69 | iNode() 70 | } 71 | 72 | func (*NodeArray) iNode{} 73 | 74 | type NodeArray []NodeInterface 75 | */ 76 | input := &SourceFile{ 77 | lines: []Sast{ 78 | &InterfaceDeclaration{ 79 | name: "NodeInterface"}, 80 | &TypeAlias{ 81 | name: "NodeArray", 82 | typ: &Array{&TypeString{"NodeInterface"}}, 83 | }, 84 | &FuncDeclaration{ 85 | receiver: &Field{ 86 | name: "", 87 | typ: &Ref{&TypeString{"NodeArray"}}, 88 | }, 89 | name: "iNode", 90 | block: "", 91 | arguments: []*Field{}}, 92 | }, 93 | } 94 | 95 | expected := &SourceInformation{ 96 | interestingTypes: map[string]Type{ 97 | "*NodeArray": &Ref{&TypeString{"NodeArray"}}}, 98 | structs: map[string]*StructDeclaration{}, 99 | typeAliases: map[string]*TypeAlias{ 100 | "NodeArray": { 101 | name: "NodeArray", 102 | typ: &Array{&TypeString{"NodeInterface"}}, 103 | }, 104 | }, 105 
| } 106 | 107 | result := Transform(input) 108 | 109 | assert.Equal(t, expected.String(), result.String()) 110 | } 111 | -------------------------------------------------------------------------------- /sqlparser/visitorgen/visitor_emitter.go: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2019 The Vitess Authors. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | */ 16 | 17 | package visitorgen 18 | 19 | import ( 20 | "fmt" 21 | "strings" 22 | ) 23 | 24 | // EmitReplacementMethods is an anti-parser (a.k.a prettifier) - it takes a struct that is much like an AST, 25 | // and produces a string from it. This method will produce the replacement methods that make it possible to 26 | // replace objects in fields or in slices. 27 | func EmitReplacementMethods(vd *VisitorPlan) string { 28 | var sb builder 29 | for _, s := range vd.Switches { 30 | for _, k := range s.Fields { 31 | sb.appendF(k.asReplMethod()) 32 | sb.newLine() 33 | } 34 | } 35 | 36 | return sb.String() 37 | } 38 | 39 | // EmitTypeSwitches is an anti-parser (a.k.a prettifier) - it takes a struct that is much like an AST, 40 | // and produces a string from it. This method will produce the switch cases needed to cover the Vitess AST. 
41 | func EmitTypeSwitches(vd *VisitorPlan) string { 42 | var sb builder 43 | for _, s := range vd.Switches { 44 | sb.newLine() 45 | sb.appendF(" case %s:", s.Type.toTypString()) 46 | for _, k := range s.Fields { 47 | sb.appendF(k.asSwitchCase()) 48 | } 49 | } 50 | 51 | return sb.String() 52 | } 53 | 54 | func (b *builder) String() string { 55 | return strings.TrimSpace(b.sb.String()) 56 | } 57 | 58 | type builder struct { 59 | sb strings.Builder 60 | } 61 | 62 | func (b *builder) appendF(format string, data ...interface{}) *builder { 63 | _, err := b.sb.WriteString(fmt.Sprintf(format, data...)) 64 | if err != nil { 65 | panic(err) 66 | } 67 | b.newLine() 68 | return b 69 | } 70 | 71 | func (b *builder) newLine() { 72 | _, err := b.sb.WriteString("\n") 73 | if err != nil { 74 | panic(err) 75 | } 76 | } 77 | -------------------------------------------------------------------------------- /sqlparser/visitorgen/visitor_emitter_test.go: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2019 The Vitess Authors. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 
15 | */ 16 | 17 | package visitorgen 18 | 19 | import ( 20 | "testing" 21 | 22 | "github.com/stretchr/testify/require" 23 | ) 24 | 25 | func TestSingleItem(t *testing.T) { 26 | sfi := SingleFieldItem{ 27 | StructType: &Ref{&TypeString{"Struct"}}, 28 | FieldType: &TypeString{"string"}, 29 | FieldName: "Field", 30 | } 31 | 32 | expectedReplacer := `func replaceStructField(newNode, parent SQLNode) { 33 | parent.(*Struct).Field = newNode.(string) 34 | }` 35 | 36 | expectedSwitch := ` a.apply(node, n.Field, replaceStructField)` 37 | require.Equal(t, expectedReplacer, sfi.asReplMethod()) 38 | require.Equal(t, expectedSwitch, sfi.asSwitchCase()) 39 | } 40 | 41 | func TestArrayFieldItem(t *testing.T) { 42 | sfi := ArrayFieldItem{ 43 | StructType: &Ref{&TypeString{"Struct"}}, 44 | ItemType: &TypeString{"string"}, 45 | FieldName: "Field", 46 | } 47 | 48 | expectedReplacer := `type replaceStructField int 49 | 50 | func (r *replaceStructField) replace(newNode, container SQLNode) { 51 | container.(*Struct).Field[int(*r)] = newNode.(string) 52 | } 53 | 54 | func (r *replaceStructField) inc() { 55 | *r++ 56 | }` 57 | 58 | expectedSwitch := ` replacerField := replaceStructField(0) 59 | replacerFieldB := &replacerField 60 | for _, item := range n.Field { 61 | a.apply(node, item, replacerFieldB.replace) 62 | replacerFieldB.inc() 63 | }` 64 | require.Equal(t, expectedReplacer, sfi.asReplMethod()) 65 | require.Equal(t, expectedSwitch, sfi.asSwitchCase()) 66 | } 67 | 68 | func TestArrayItem(t *testing.T) { 69 | sfi := ArrayItem{ 70 | StructType: &Ref{&TypeString{"Struct"}}, 71 | ItemType: &TypeString{"string"}, 72 | } 73 | 74 | expectedReplacer := `type replaceStructItems int 75 | 76 | func (r *replaceStructItems) replace(newNode, container SQLNode) { 77 | container.(*Struct)[int(*r)] = newNode.(string) 78 | } 79 | 80 | func (r *replaceStructItems) inc() { 81 | *r++ 82 | }` 83 | 84 | expectedSwitch := ` replacer := replaceStructItems(0) 85 | replacerRef := &replacer 86 | for _, item := 
range n { 87 | a.apply(node, item, replacerRef.replace) 88 | replacerRef.inc() 89 | }` 90 | require.Equal(t, expectedReplacer, sfi.asReplMethod()) 91 | require.Equal(t, expectedSwitch, sfi.asSwitchCase()) 92 | } 93 | -------------------------------------------------------------------------------- /sqlparser/visitorgen/visitorgen.go: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2019 The Vitess Authors. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | */ 16 | 17 | //Package visitorgen is responsible for taking the ast.go of Vitess 18 | //and producing visitor infrastructure for it. 19 | // 20 | //This is accomplished in a few steps. 21 | //Step 1: Walk the AST and collect the interesting information into a format that is 22 | // easy to consume for the next step. The output format is a *SourceFile, that 23 | // contains the needed information in a format that is pretty close to the golang ast, 24 | // but simplified 25 | //Step 2: A SourceFile is packaged into a SourceInformation. SourceInformation is still 26 | // concerned with the input ast - it's just an even more distilled and easy to 27 | // consume format for the last step. This step is performed by the code in transformer.go. 28 | //Step 3: Using the SourceInformation, the struct_producer.go code produces the final data structure 29 | // used, a VisitorPlan. 
This is focused on the output - it contains a list of all fields or 30 | // arrays that need to be handled by the visitor produced. 31 | //Step 4: The VisitorPlan is lastly turned into a string that is written as the output of 32 | // this whole process. 33 | package visitorgen 34 | -------------------------------------------------------------------------------- /sqlparser/xa_test.go: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2017 Google Inc. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | */ 16 | 17 | package sqlparser 18 | 19 | import ( 20 | "strings" 21 | "testing" 22 | ) 23 | 24 | func TestXA(t *testing.T) { 25 | validSQL := []struct { 26 | input string 27 | output string 28 | }{ 29 | { 30 | input: "xa begin 'x1'", 31 | output: "XA", 32 | }, 33 | 34 | { 35 | input: "xa prepare 'x1'", 36 | output: "XA", 37 | }, 38 | 39 | { 40 | input: "xa end 'x1'", 41 | output: "XA", 42 | }, 43 | 44 | { 45 | input: "xa commit 'x1'", 46 | output: "XA", 47 | }, 48 | 49 | { 50 | input: "xa rollback 'x1'", 51 | output: "XA", 52 | }, 53 | } 54 | 55 | for _, exp := range validSQL { 56 | sql := strings.TrimSpace(exp.input) 57 | tree, err := Parse(sql) 58 | if err != nil { 59 | t.Errorf("input: %s, err: %v", sql, err) 60 | continue 61 | } 62 | 63 | // Walk. 64 | Walk(func(node SQLNode) (bool, error) { 65 | return true, nil 66 | }, tree) 67 | 68 | node := tree.(*Xa) 69 | 70 | // Format. 
71 | got := String(node) 72 | if exp.output != got { 73 | t.Errorf("want:\n%s\ngot:\n%s", exp.output, got) 74 | } 75 | } 76 | } 77 | -------------------------------------------------------------------------------- /xlog/options.go: -------------------------------------------------------------------------------- 1 | /* 2 | * go-mysqlstack 3 | * xelabs.org 4 | * 5 | * Copyright (c) XeLabs 6 | * GPL License 7 | * 8 | */ 9 | 10 | package xlog 11 | 12 | var ( 13 | defaultName = " " 14 | defaultLevel = DEBUG 15 | ) 16 | 17 | // Options used for the options of the xlog. 18 | type Options struct { 19 | Name string 20 | Level LogLevel 21 | } 22 | 23 | // Option func. 24 | type Option func(*Options) 25 | 26 | func newOptions(opts ...Option) *Options { 27 | opt := &Options{} 28 | for _, o := range opts { 29 | o(opt) 30 | } 31 | 32 | if len(opt.Name) == 0 { 33 | opt.Name = defaultName 34 | } 35 | 36 | if opt.Level == 0 { 37 | opt.Level = defaultLevel 38 | } 39 | return opt 40 | } 41 | 42 | // Name used to set the name. 43 | func Name(v string) Option { 44 | return func(o *Options) { 45 | o.Name = v 46 | } 47 | } 48 | 49 | // Level used to set the log level. 50 | func Level(v LogLevel) Option { 51 | return func(o *Options) { 52 | o.Level = v 53 | } 54 | } 55 | -------------------------------------------------------------------------------- /xlog/syslog.go: -------------------------------------------------------------------------------- 1 | // +build linux darwin dragonfly freebsd netbsd openbsd solaris 2 | 3 | package xlog 4 | 5 | import "log/syslog" 6 | 7 | // NewSysLog creates a new sys log. 8 | func NewSysLog(opts ...Option) *Log { 9 | w, err := syslog.New(syslog.LOG_DEBUG, "") 10 | if err != nil { 11 | panic(err) 12 | } 13 | return NewXLog(w, opts...) 
14 | } 15 | -------------------------------------------------------------------------------- /xlog/syslog_windows.go: -------------------------------------------------------------------------------- 1 | // +build windows 2 | 3 | package xlog 4 | 5 | import ( 6 | "os" 7 | ) 8 | 9 | // NewSysLog creates a new sys log. Because there is no syslog support for 10 | // Windows, we output to os.Stdout. 11 | func NewSysLog(opts ...Option) *Log { 12 | return NewXLog(os.Stdout, opts...) 13 | } 14 | -------------------------------------------------------------------------------- /xlog/xlog.go: -------------------------------------------------------------------------------- 1 | /* 2 | * go-mysqlstack 3 | * xelabs.org 4 | * 5 | * Copyright (c) XeLabs 6 | * GPL License 7 | * 8 | */ 9 | 10 | package xlog 11 | 12 | import ( 13 | "fmt" 14 | "io" 15 | "log" 16 | "os" 17 | "strings" 18 | ) 19 | 20 | var ( 21 | defaultlog *Log 22 | ) 23 | 24 | // LogLevel used for log level. 25 | type LogLevel int 26 | 27 | const ( 28 | // DEBUG enum. 29 | DEBUG LogLevel = 1 << iota 30 | // INFO enum. 31 | INFO 32 | // WARNING enum. 33 | WARNING 34 | // ERROR enum. 35 | ERROR 36 | // FATAL enum. 37 | FATAL 38 | // PANIC enum. 39 | PANIC 40 | ) 41 | 42 | // LevelNames represents the string name of all levels. 43 | var LevelNames = [...]string{ 44 | DEBUG: "DEBUG", 45 | INFO: "INFO", 46 | WARNING: "WARNING", 47 | ERROR: "ERROR", 48 | FATAL: "FATAL", 49 | PANIC: "PANIC", 50 | } 51 | 52 | const ( 53 | // D_LOG_FLAGS is the default log flags. 54 | D_LOG_FLAGS int = log.LstdFlags | log.Lmicroseconds | log.Lshortfile 55 | ) 56 | 57 | // Log struct. 58 | type Log struct { 59 | opts *Options 60 | *log.Logger 61 | } 62 | 63 | // NewStdLog creates a new std log. 64 | func NewStdLog(opts ...Option) *Log { 65 | return NewXLog(os.Stdout, opts...) 66 | } 67 | 68 | // NewXLog creates a new xlog. 69 | func NewXLog(w io.Writer, opts ...Option) *Log { 70 | options := newOptions(opts...) 
71 | 72 | l := &Log{ 73 | opts: options, 74 | } 75 | l.Logger = log.New(w, l.opts.Name, D_LOG_FLAGS) 76 | defaultlog = l 77 | return l 78 | } 79 | 80 | // NewLog creates a log with the given prefix and flags; level filtering uses the default options (DEBUG). 81 | func NewLog(w io.Writer, prefix string, flag int) *Log { 82 | l := &Log{opts: newOptions()} // opts must be non-nil: every level method reads t.opts.Level 83 | l.Logger = log.New(w, prefix, flag) 84 | return l 85 | } 86 | 87 | // GetLog returns the default Log, creating an INFO-level stdout log if none exists yet. 88 | func GetLog() *Log { 89 | if defaultlog == nil { 90 | l := NewStdLog(Level(INFO)) 91 | defaultlog = l 92 | } 93 | return defaultlog 94 | } 95 | 96 | // SetLevel sets the log level by name (one of LevelNames); unknown names are ignored. 97 | func (t *Log) SetLevel(level string) { 98 | for i, v := range LevelNames { 99 | if v != "" && level == v { // LevelNames is sparse: skip the unnamed gaps between the bit-flag levels 100 | t.opts.Level = LogLevel(i) 101 | return 102 | } 103 | } 104 | } 105 | 106 | // Debug used to log debug msg. 107 | func (t *Log) Debug(format string, v ...interface{}) { 108 | if DEBUG < t.opts.Level { 109 | return 110 | } 111 | t.log("\t [DEBUG] \t%s", fmt.Sprintf(format, v...)) 112 | } 113 | 114 | // Info used to log info msg. 115 | func (t *Log) Info(format string, v ...interface{}) { 116 | if INFO < t.opts.Level { 117 | return 118 | } 119 | t.log("\t [INFO] \t%s", fmt.Sprintf(format, v...)) 120 | } 121 | 122 | // Warning used to log warning msg. 123 | func (t *Log) Warning(format string, v ...interface{}) { 124 | if WARNING < t.opts.Level { 125 | return 126 | } 127 | t.log("\t [WARNING] \t%s", fmt.Sprintf(format, v...)) 128 | } 129 | 130 | // Error used to log error msg. 131 | func (t *Log) Error(format string, v ...interface{}) { 132 | if ERROR < t.opts.Level { 133 | return 134 | } 135 | t.log("\t [ERROR] \t%s", fmt.Sprintf(format, v...)) 136 | } 137 | 138 | // Fatal used to log fatal msg, then exits the process with status 1. 139 | func (t *Log) Fatal(format string, v ...interface{}) { 140 | if FATAL < t.opts.Level { 141 | return 142 | } 143 | t.log("\t [FATAL+EXIT] \t%s", fmt.Sprintf(format, v...)) 144 | os.Exit(1) 145 | } 146 | 147 | // Panic used to log panic msg.
148 | func (t *Log) Panic(format string, v ...interface{}) { 149 | if PANIC < t.opts.Level { 150 | return 151 | } 152 | msg := fmt.Sprintf("\t [PANIC] \t%s", fmt.Sprintf(format, v...)) 153 | t.log("%s", msg) // msg is already formatted; using it as the format would misinterpret any '%' it contains 154 | panic(msg) 155 | } 156 | 157 | // Close used to close the log. 158 | func (t *Log) Close() { 159 | // nothing 160 | } 161 | 162 | func (t *Log) log(format string, v ...interface{}) { 163 | t.Output(3, strings.Repeat(" ", 3)+fmt.Sprintf(format, v...)+"\n") // calldepth 3 reports the caller of Debug/Info/.../Panic 164 | } 165 | -------------------------------------------------------------------------------- /xlog/xlog_test.go: -------------------------------------------------------------------------------- 1 | /* 2 | * go-mysqlstack 3 | * xelabs.org 4 | * 5 | * Copyright (c) XeLabs 6 | * GPL License 7 | * 8 | */ 9 | 10 | package xlog 11 | 12 | import ( 13 | "testing" 14 | ) 15 | 16 | // Assert fails the test immediately if condition is false, reporting msg formatted with v. 17 | func Assert(tb testing.TB, condition bool, msg string, v ...interface{}) { 18 | if !condition { 19 | tb.Fatalf(msg, v...) 20 | } 21 | } 22 | 23 | func TestGetLog(t *testing.T) { 24 | GetLog().Debug("DEBUG") 25 | log := NewStdLog() 26 | log.SetLevel("INFO") 27 | GetLog().Debug("DEBUG") 28 | GetLog().Info("INFO") 29 | } 30 | 31 | func TestSysLog(t *testing.T) { 32 | log := NewSysLog() 33 | 34 | log.Debug("DEBUG") 35 | log.Info("INFO") 36 | log.Warning("WARNING") 37 | log.Error("ERROR") 38 | 39 | log.SetLevel("DEBUG") 40 | log.Debug("DEBUG") 41 | log.Info("INFO") 42 | log.Warning("WARNING") 43 | log.Error("ERROR") 44 | 45 | log.SetLevel("INFO") 46 | log.Debug("DEBUG") 47 | log.Info("INFO") 48 | log.Warning("WARNING") 49 | log.Error("ERROR") 50 | 51 | log.SetLevel("WARNING") 52 | log.Debug("DEBUG") 53 | log.Info("INFO") 54 | log.Warning("WARNING") 55 | log.Error("ERROR") 56 | 57 | log.SetLevel("ERROR") 58 | log.Debug("DEBUG") 59 | log.Info("INFO") 60 | log.Warning("WARNING") 61 | log.Error("ERROR") 62 | } 63 | 64 | func TestStdLog(t *testing.T) { 65 | log := NewStdLog() 66 | 67 |
log.Println("........DEFAULT........") 68 | log.Debug("DEBUG") 69 | log.Info("INFO") 70 | log.Warning("WARNING") 71 | log.Error("ERROR") 72 | 73 | log.Println("........DEBUG........") 74 | log.SetLevel("DEBUG") 75 | log.Debug("DEBUG") 76 | log.Info("INFO") 77 | log.Warning("WARNING") 78 | log.Error("ERROR") 79 | 80 | log.Println("........INFO........") 81 | log.SetLevel("INFO") 82 | log.Debug("DEBUG") 83 | log.Info("INFO") 84 | log.Warning("WARNING") 85 | log.Error("ERROR") 86 | 87 | log.Println("........WARNING........") 88 | log.SetLevel("WARNING") 89 | log.Debug("DEBUG") 90 | log.Info("INFO") 91 | log.Warning("WARNING") 92 | log.Error("ERROR") 93 | 94 | log.Println("........ERROR........") 95 | log.SetLevel("ERROR") 96 | log.Debug("DEBUG") 97 | log.Info("INFO") 98 | log.Warning("WARNING") 99 | log.Error("ERROR") 100 | } 101 | 102 | func TestLogLevel(t *testing.T) { 103 | log := NewStdLog() 104 | { 105 | log.SetLevel("DEBUG") 106 | want := DEBUG 107 | got := log.opts.Level 108 | Assert(t, want == got, "want[%v]!=got[%v]", want, got) 109 | } 110 | 111 | { 112 | log.SetLevel("DEBUGX") 113 | want := DEBUG 114 | got := log.opts.Level 115 | Assert(t, want == got, "want[%v]!=got[%v]", want, got) 116 | } 117 | 118 | { 119 | log.SetLevel("PANIC") 120 | want := PANIC 121 | got := log.opts.Level 122 | Assert(t, want == got, "want[%v]!=got[%v]", want, got) 123 | } 124 | 125 | { 126 | log.SetLevel("WARNING") 127 | want := WARNING 128 | got := log.opts.Level 129 | Assert(t, want == got, "want[%v]!=got[%v]", want, got) 130 | } 131 | } 132 | --------------------------------------------------------------------------------