├── .gitignore
├── .travis.yml
├── LICENSE
├── README.md
├── driver
├── client.go
├── client_test.go
├── mock.go
├── rows.go
├── rows_test.go
├── server.go
├── server_test.go
├── session.go
├── session_test.go
├── statement.go
└── statement_test.go
├── examples
├── client.go
└── mysqld.go
├── go.mod
├── go.sum
├── makefile
├── packet
├── error.go
├── mock.go
├── packets.go
├── packets_test.go
├── stream.go
└── stream_test.go
├── proto
├── auth.go
├── auth_test.go
├── column.go
├── column_test.go
├── const.go
├── eof.go
├── eof_test.go
├── err.go
├── err_test.go
├── greeting.go
├── greeting_test.go
├── ok.go
├── ok_test.go
├── statement.go
└── statement_test.go
├── sqldb
├── constants.go
├── constants_test.go
├── sql_error.go
└── sql_error_test.go
├── sqlparser
├── .idea
│ ├── misc.xml
│ ├── modules.xml
│ ├── sqlparser.iml
│ ├── vcs.xml
│ └── workspace.xml
├── Makefile
├── analyzer.go
├── analyzer_test.go
├── ast.go
├── ast_funcs.go
├── ast_test.go
├── checksum_test.go
├── comments.go
├── comments_test.go
├── constants.go
├── ddl_test.go
├── depends
│ ├── bytes2
│ │ ├── buffer.go
│ │ └── buffer_test.go
│ ├── common
│ │ ├── buffer.go
│ │ ├── buffer_test.go
│ │ ├── hash_table.go
│ │ ├── hash_table_test.go
│ │ ├── unsafe.go
│ │ └── unsafe_test.go
│ ├── query
│ │ └── query.pb.go
│ └── sqltypes
│ │ ├── aggregation.go
│ │ ├── aggregation_test.go
│ │ ├── arithmetic.go
│ │ ├── arithmetic_test.go
│ │ ├── bind_variables.go
│ │ ├── bind_variables_test.go
│ │ ├── column.go
│ │ ├── column_test.go
│ │ ├── const.go
│ │ ├── limit.go
│ │ ├── limit_test.go
│ │ ├── plan_value.go
│ │ ├── result.go
│ │ ├── result_test.go
│ │ ├── row.go
│ │ ├── time.go
│ │ ├── time_test.go
│ │ ├── type.go
│ │ ├── type_test.go
│ │ ├── value.go
│ │ └── value_test.go
├── encodable.go
├── encodable_test.go
├── explain_test.go
├── impossible_query.go
├── impossible_query_test.go
├── kill_test.go
├── normalizer.go
├── normalizer_test.go
├── parse_test.go
├── parsed_query.go
├── parsed_query_test.go
├── parser.go
├── precedence_test.go
├── radon_test.go
├── rewriter.go
├── rewriter_api.go
├── select_test.go
├── set_test.go
├── show_test.go
├── sql.go
├── sql.y
├── token.go
├── token_test.go
├── tracked_buffer.go
├── tracked_buffer_test.go
├── txn_test.go
├── visitorgen
│ ├── ast_walker.go
│ ├── ast_walker_test.go
│ ├── main
│ │ └── main.go
│ ├── sast.go
│ ├── struct_producer.go
│ ├── struct_producer_test.go
│ ├── transformer.go
│ ├── transformer_test.go
│ ├── visitor_emitter.go
│ ├── visitor_emitter_test.go
│ └── visitorgen.go
└── xa_test.go
└── xlog
├── options.go
├── syslog.go
├── syslog_windows.go
├── xlog.go
└── xlog_test.go
/.gitignore:
--------------------------------------------------------------------------------
1 | tags
2 | bin/*
3 | *.output
4 | coverage.*
5 |
--------------------------------------------------------------------------------
/.travis.yml:
--------------------------------------------------------------------------------
1 | language: go
2 | sudo: required
3 | go:
4 | - 1.x
5 |
6 | before_install:
7 | - go get github.com/shopspring/decimal
8 | - go get github.com/pierrre/gotestcover
9 | - go get github.com/stretchr/testify/assert
10 |
11 | script:
12 | - make test
13 | - make coverage
14 |
15 | after_success:
16 | # send coverage reports to Codecov
17 | - bash <(curl -s https://codecov.io/bash) -f "!mock.go"
18 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | BSD 3-Clause License
2 |
3 | Copyright (c) 2021, xelabs
4 | All rights reserved.
5 |
6 | Redistribution and use in source and binary forms, with or without
7 | modification, are permitted provided that the following conditions are met:
8 |
9 | 1. Redistributions of source code must retain the above copyright notice, this
10 | list of conditions and the following disclaimer.
11 |
12 | 2. Redistributions in binary form must reproduce the above copyright notice,
13 | this list of conditions and the following disclaimer in the documentation
14 | and/or other materials provided with the distribution.
15 |
16 | 3. Neither the name of the copyright holder nor the names of its
17 | contributors may be used to endorse or promote products derived from
18 | this software without specific prior written permission.
19 |
20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | [Build Status](https://travis-ci.org/xelabs/go-mysqlstack) [Go Report Card](https://goreportcard.com/report/github.com/xelabs/go-mysqlstack) [codecov](https://codecov.io/gh/xelabs/go-mysqlstack/branch/master)
2 |
3 | # go-mysqlstack
4 |
5 | ***go-mysqlstack*** is a MySQL protocol library implemented in Go (golang).
6 |
7 | The protocol implementation is based on [mysqlproto-go](https://github.com/pubnative/mysqlproto-go) and [go-sql-driver](https://github.com/go-sql-driver/mysql).
8 |
9 | ## Running Tests
10 |
11 | ```
12 | $ mkdir src
13 | $ export GOPATH=`pwd`
14 | $ go get -u github.com/xelabs/go-mysqlstack/driver
15 | $ cd src/github.com/xelabs/go-mysqlstack/
16 | $ make test
17 | ```
18 |
19 | ## Examples
20 |
21 | 1. ***examples/mysqld.go*** starts a mock MySQL server; run it with:
22 |
23 | ```
24 | $ go run examples/mysqld.go
25 | 2018/01/26 16:02:02.304376 mysqld.go:52: [INFO] mysqld.server.start.address[:4407]
26 | ```
27 |
28 | 2. ***examples/client.go*** acts as a client and queries the mock MySQL server:
29 |
30 | ```
31 | $ go run examples/client.go
32 | 2018/01/26 16:06:10.779340 client.go:32: [INFO] results:[[[10 nice name]]]
33 | ```
34 |
35 | ## Status
36 |
37 | go-mysqlstack is production ready.
38 |
39 | ## License
40 |
41 | go-mysqlstack is released under the BSD-3-Clause License. See [LICENSE](https://github.com/xelabs/go-mysqlstack/blob/master/LICENSE)
42 |
--------------------------------------------------------------------------------
/driver/rows.go:
--------------------------------------------------------------------------------
1 | /*
2 | * go-mysqlstack
3 | * xelabs.org
4 | *
5 | * Copyright (c) XeLabs
6 | * GPL License
7 | *
8 | */
9 |
10 | package driver
11 |
12 | import (
13 | "errors"
14 | "fmt"
15 |
16 | "github.com/xelabs/go-mysqlstack/proto"
17 |
18 | "github.com/xelabs/go-mysqlstack/sqlparser/depends/common"
19 | querypb "github.com/xelabs/go-mysqlstack/sqlparser/depends/query"
20 | "github.com/xelabs/go-mysqlstack/sqlparser/depends/sqltypes"
21 | )
22 |
23 | var _ Rows = &TextRows{}
24 |
// RowMode names the wire encoding of result-set rows: the text protocol
// (decoded by TextRows) or the binary protocol used for prepared statements
// (decoded by BinaryRows).
// NOTE(review): the code that switches on RowMode lives outside this file.
type RowMode int

const (
	// TextRowMode selects text-protocol rows.
	TextRowMode RowMode = iota
	// BinaryRowMode selects binary-protocol rows.
	BinaryRowMode
)
31 |
32 | // Rows presents row cursor interface.
33 | type Rows interface {
34 | Next() bool
35 | Close() error
36 | Datas() []byte
37 | Bytes() int
38 | RowsAffected() uint64
39 | LastInsertID() uint64
40 | LastError() error
41 | Fields() []*querypb.Field
42 | RowValues() ([]sqltypes.Value, error)
43 | }
44 |
45 | // BaseRows --
46 | type BaseRows struct {
47 | c Conn
48 | end bool
49 | err error
50 | data []byte
51 | bytes int
52 | rowsAffected uint64
53 | insertID uint64
54 | buffer *common.Buffer
55 | fields []*querypb.Field
56 | }
57 |
58 | // TextRows presents row tuple.
59 | type TextRows struct {
60 | BaseRows
61 | }
62 |
63 | // BinaryRows presents binary row tuple.
64 | type BinaryRows struct {
65 | BaseRows
66 | }
67 |
68 | // Next implements the Rows interface.
69 | // http://dev.mysql.com/doc/internals/en/com-query-response.html#packet-ProtocolText::ResultsetRow
70 | func (r *BaseRows) Next() bool {
71 | defer func() {
72 | if r.err != nil {
73 | r.c.Cleanup()
74 | }
75 | }()
76 |
77 | if r.end {
78 | return false
79 | }
80 |
81 | // if fields count is 0
82 | // the packet is OK-Packet without Resultset.
83 | if len(r.fields) == 0 {
84 | r.end = true
85 | return false
86 | }
87 |
88 | if r.data, r.err = r.c.NextPacket(); r.err != nil {
89 | r.end = true
90 | return false
91 | }
92 |
93 | switch r.data[0] {
94 | case proto.EOF_PACKET:
95 | // This packet may be one of two kinds:
96 | // - an EOF packet,
97 | // - an OK packet with an EOF header if
98 | // sqldb.CLIENT_DEPRECATE_EOF is set.
99 | r.end = true
100 | return false
101 |
102 | case proto.ERR_PACKET:
103 | r.err = proto.UnPackERR(r.data)
104 | r.end = true
105 | return false
106 | }
107 | r.buffer.Reset(r.data)
108 | return true
109 | }
110 |
111 | // Close drain the rest packets and check the error.
112 | func (r *BaseRows) Close() error {
113 | for r.Next() {
114 | }
115 | return r.LastError()
116 | }
117 |
118 | // RowValues implements the Rows interface.
119 | // https://dev.mysql.com/doc/internals/en/com-query-response.html#packet-ProtocolText::ResultsetRow
120 | func (r *BaseRows) RowValues() ([]sqltypes.Value, error) {
121 | if r.fields == nil {
122 | return nil, errors.New("rows.fields is NIL")
123 | }
124 |
125 | colNumber := len(r.fields)
126 | result := make([]sqltypes.Value, colNumber)
127 | for i := 0; i < colNumber; i++ {
128 | v, err := r.buffer.ReadLenEncodeBytes()
129 | if err != nil {
130 | r.c.Cleanup()
131 | return nil, err
132 | }
133 |
134 | if v != nil {
135 | r.bytes += len(v)
136 | result[i] = sqltypes.MakeTrusted(r.fields[i].Type, v)
137 | }
138 | }
139 | return result, nil
140 | }
141 |
142 | // Datas implements the Rows interface.
143 | func (r *BaseRows) Datas() []byte {
144 | return r.buffer.Datas()
145 | }
146 |
147 | // Fields implements the Rows interface.
148 | func (r *BaseRows) Fields() []*querypb.Field {
149 | return r.fields
150 | }
151 |
152 | // Bytes returns all the memory usage which read by this row cursor.
153 | func (r *BaseRows) Bytes() int {
154 | return r.bytes
155 | }
156 |
157 | // RowsAffected implements the Rows interface.
158 | func (r *BaseRows) RowsAffected() uint64 {
159 | return r.rowsAffected
160 | }
161 |
162 | // LastInsertID implements the Rows interface.
163 | func (r *BaseRows) LastInsertID() uint64 {
164 | return r.insertID
165 | }
166 |
167 | // LastError implements the Rows interface.
168 | func (r *BaseRows) LastError() error {
169 | return r.err
170 | }
171 |
172 | // NewTextRows creates TextRows.
173 | func NewTextRows(c Conn) *TextRows {
174 | textRows := &TextRows{}
175 | textRows.c = c
176 | textRows.buffer = common.NewBuffer(8)
177 | return textRows
178 | }
179 |
180 | // NewBinaryRows creates BinaryRows.
181 | func NewBinaryRows(c Conn) *BinaryRows {
182 | binaryRows := &BinaryRows{}
183 | binaryRows.c = c
184 | binaryRows.buffer = common.NewBuffer(8)
185 | return binaryRows
186 | }
187 |
188 | // RowValues implements the Rows interface.
189 | // https://dev.mysql.com/doc/internals/en/binary-protocol-resultset-row.html
190 | func (r *BinaryRows) RowValues() ([]sqltypes.Value, error) {
191 | if r.fields == nil {
192 | return nil, errors.New("rows.fields is NIL")
193 | }
194 |
195 | header, err := r.buffer.ReadU8()
196 | if err != nil {
197 | return nil, err
198 | }
199 | if header != proto.OK_PACKET {
200 | return nil, fmt.Errorf("binary.rows.header.is.not.ok[%v]", header)
201 | }
202 |
203 | colCount := len(r.fields)
204 | // NULL-bitmap, [(column-count + 7 + 2) / 8 bytes]
205 | nullMask, err := r.buffer.ReadBytes(int((colCount + 7 + 2) / 8))
206 | if err != nil {
207 | return nil, err
208 | }
209 |
210 | result := make([]sqltypes.Value, colCount)
211 | for i := 0; i < colCount; i++ {
212 | // Field is NULL
213 | // (byte >> bit-pos) % 2 == 1
214 | if ((nullMask[(i+2)>>3] >> uint((i+2)&7)) & 1) == 1 {
215 | result[i] = sqltypes.Value{}
216 | continue
217 | }
218 |
219 | v, err := sqltypes.ParseMySQLValues(r.buffer, r.fields[i].Type)
220 | if err != nil {
221 | r.c.Cleanup()
222 | return nil, err
223 | }
224 |
225 | if v != nil {
226 | val, err := sqltypes.BuildValue(v)
227 | if err != nil {
228 | r.c.Cleanup()
229 | return nil, err
230 | }
231 | r.bytes += val.Len()
232 | result[i] = val
233 | } else {
234 | result[i] = sqltypes.Value{}
235 | }
236 | }
237 | return result, nil
238 | }
239 |
--------------------------------------------------------------------------------
/driver/rows_test.go:
--------------------------------------------------------------------------------
1 | /*
2 | * go-mysqlstack
3 | * xelabs.org
4 | *
5 | * Copyright (c) XeLabs
6 | * GPL License
7 | *
8 | */
9 |
10 | package driver
11 |
12 | import (
13 | "testing"
14 |
15 | "github.com/stretchr/testify/assert"
16 |
17 | querypb "github.com/xelabs/go-mysqlstack/sqlparser/depends/query"
18 | "github.com/xelabs/go-mysqlstack/sqlparser/depends/sqltypes"
19 | "github.com/xelabs/go-mysqlstack/xlog"
20 | )
21 |
22 | func TestRows(t *testing.T) {
23 | result1 := &sqltypes.Result{
24 | Fields: []*querypb.Field{
25 | {
26 | Name: "id",
27 | Type: querypb.Type_INT32,
28 | },
29 | {
30 | Name: "name",
31 | Type: querypb.Type_VARCHAR,
32 | },
33 | },
34 | Rows: [][]sqltypes.Value{
35 | {
36 | sqltypes.MakeTrusted(querypb.Type_INT32, []byte("10")),
37 | sqltypes.MakeTrusted(querypb.Type_VARCHAR, []byte("nice name")),
38 | },
39 | {
40 | sqltypes.MakeTrusted(querypb.Type_INT32, []byte("20")),
41 | sqltypes.NULL,
42 | },
43 | },
44 | }
45 | result2 := &sqltypes.Result{
46 | RowsAffected: 123,
47 | InsertID: 123456789,
48 | }
49 | result3 := &sqltypes.Result{
50 | Fields: []*querypb.Field{
51 | {
52 | Name: "name",
53 | Type: querypb.Type_VARCHAR,
54 | },
55 | },
56 | Rows: [][]sqltypes.Value{
57 | {
58 | sqltypes.NULL,
59 | },
60 | },
61 | }
62 |
63 | log := xlog.NewStdLog(xlog.Level(xlog.ERROR))
64 | th := NewTestHandler(log)
65 | svr, err := MockMysqlServer(log, th)
66 | assert.Nil(t, err)
67 | defer svr.Close()
68 | address := svr.Addr()
69 |
70 | // query
71 | {
72 | client, err := NewConn("mock", "mock", address, "test", "")
73 | assert.Nil(t, err)
74 | defer client.Close()
75 |
76 | th.AddQuery("SELECT2", result2)
77 | rows, err := client.Query("SELECT2")
78 | assert.Nil(t, err)
79 |
80 | assert.Equal(t, uint64(123), rows.RowsAffected())
81 | assert.Equal(t, uint64(123456789), rows.LastInsertID())
82 | }
83 |
84 | // query
85 | {
86 | client, err := NewConn("mock", "mock", address, "test", "")
87 | assert.Nil(t, err)
88 | defer client.Close()
89 |
90 | th.AddQuery("SELECT1", result1)
91 | rows, err := client.Query("SELECT1")
92 | assert.Nil(t, err)
93 | assert.Equal(t, result1.Fields, rows.Fields())
94 | for rows.Next() {
95 | _ = rows.Datas()
96 | _, _ = rows.RowValues()
97 | }
98 |
99 | want := 13
100 | got := int(rows.Bytes())
101 | assert.Equal(t, want, got)
102 | }
103 |
104 | // query
105 | {
106 | client, err := NewConn("mock", "mock", address, "test", "")
107 | assert.Nil(t, err)
108 | defer client.Close()
109 |
110 | th.AddQuery("SELECT3", result3)
111 | rows, err := client.Query("SELECT3")
112 | assert.Nil(t, err)
113 | assert.Equal(t, result3.Fields, rows.Fields())
114 | }
115 | }
116 |
--------------------------------------------------------------------------------
/driver/session_test.go:
--------------------------------------------------------------------------------
1 | /*
2 | * go-mysqlstack
3 | * xelabs.org
4 | *
5 | * Copyright (c) XeLabs
6 | * GPL License
7 | *
8 | */
9 |
10 | package driver
11 |
12 | import (
13 | "testing"
14 | "time"
15 |
16 | "github.com/stretchr/testify/assert"
17 | "github.com/xelabs/go-mysqlstack/xlog"
18 | )
19 |
20 | func TestSession(t *testing.T) {
21 | log := xlog.NewStdLog(xlog.Level(xlog.DEBUG))
22 | th := NewTestHandler(log)
23 | svr, err := MockMysqlServer(log, th)
24 | assert.Nil(t, err)
25 | address := svr.Addr()
26 |
27 | // create session 1
28 | client, err := NewConn("mock", "mock", address, "test", "")
29 | assert.Nil(t, err)
30 | defer client.Close()
31 |
32 | var sessions []*Session
33 | for _, s := range th.ss {
34 | sessions = append(sessions, s.session)
35 | }
36 |
37 | {
38 | session1 := sessions[0]
39 |
40 | // Session ID.
41 | {
42 | log.Debug("--id:%v", session1.ID())
43 | log.Debug("--addr:%v", session1.Addr())
44 | log.Debug("--salt:%v", session1.Salt())
45 | log.Debug("--scramble:%v", session1.Scramble())
46 | }
47 |
48 | // schema.
49 | {
50 | want := "xx"
51 | session1.SetSchema(want)
52 | got := session1.Schema()
53 | assert.Equal(t, want, got)
54 | }
55 |
56 | // charset.
57 | {
58 | want := uint8(0x21)
59 | got := session1.Charset()
60 | assert.Equal(t, want, got)
61 | }
62 |
63 | // UpdateTime.
64 | {
65 | want := time.Now()
66 | session1.updateLastQueryTime(want)
67 | got := session1.LastQueryTime()
68 | assert.Equal(t, want, got)
69 | }
70 | }
71 | }
72 |
--------------------------------------------------------------------------------
/driver/statement.go:
--------------------------------------------------------------------------------
1 | /*
2 | * go-mysqlstack
3 | * xelabs.org
4 | *
5 | * Copyright (c) XeLabs
6 | * GPL License
7 | *
8 | */
9 |
10 | package driver
11 |
12 | import (
13 | "github.com/xelabs/go-mysqlstack/proto"
14 | "github.com/xelabs/go-mysqlstack/sqldb"
15 | querypb "github.com/xelabs/go-mysqlstack/sqlparser/depends/query"
16 | "github.com/xelabs/go-mysqlstack/sqlparser/depends/sqltypes"
17 | )
18 |
// Statement -- a prepared statement bound to one client connection.
// ComStatementExecute, ComStatementQuery, ComStatementReset and
// ComStatementClose all send ID to the server over conn.
type Statement struct {
	// conn is the connection the statement was prepared on; every command goes through it.
	conn *conn
	// ID is the statement id, written as a 4-byte little-endian integer by Reset/Close.
	ID uint32
	// NOTE(review): presumably the number of "?" placeholders from the prepare response — confirm.
	ParamCount uint16
	// NOTE(review): presumably the SQL text that was prepared — confirm with the caller.
	PrepareStmt string
	// NOTE(review): not used in this file; presumably per-parameter MySQL type codes.
	ParamsType []int32
	// NOTE(review): not used in this file; presumably result column names.
	ColumnNames []string
	// NOTE(review): not used in this file; presumably bind variables keyed by name.
	BindVars map[string]*querypb.BindVariable
}
29 |
30 | // ComStatementExecute -- statement execute write.
31 | func (s *Statement) ComStatementExecute(parameters []sqltypes.Value) error {
32 | var err error
33 | var datas []byte
34 | var iRows Rows
35 |
36 | if datas, err = proto.PackStatementExecute(s.ID, parameters); err != nil {
37 | return err
38 | }
39 |
40 | if iRows, err = s.conn.stmtQuery(sqldb.COM_STMT_EXECUTE, datas); err != nil {
41 | return err
42 | }
43 | for iRows.Next() {
44 | if _, err := iRows.RowValues(); err != nil {
45 | s.conn.Cleanup()
46 | return err
47 | }
48 | }
49 | // Drain the results and check last error.
50 | if err := iRows.Close(); err != nil {
51 | s.conn.Cleanup()
52 | return err
53 | }
54 | return nil
55 | }
56 |
57 | // ComStatementExecute -- statement execute write.
58 | func (s *Statement) ComStatementQuery(parameters []sqltypes.Value) (*sqltypes.Result, error) {
59 | var err error
60 | var datas []byte
61 | var iRows Rows
62 | var qrRow []sqltypes.Value
63 | var qrRows [][]sqltypes.Value
64 |
65 | if datas, err = proto.PackStatementExecute(s.ID, parameters); err != nil {
66 | return nil, err
67 | }
68 |
69 | if iRows, err = s.conn.stmtQuery(sqldb.COM_STMT_EXECUTE, datas); err != nil {
70 | return nil, err
71 | }
72 | for iRows.Next() {
73 | if qrRow, err = iRows.RowValues(); err != nil {
74 | s.conn.Cleanup()
75 | return nil, err
76 | }
77 | if qrRow != nil {
78 | qrRows = append(qrRows, qrRow)
79 | }
80 | }
81 | // Drain the results and check last error.
82 | if err := iRows.Close(); err != nil {
83 | s.conn.Cleanup()
84 | return nil, err
85 | }
86 |
87 | rowsAffected := iRows.RowsAffected()
88 | if rowsAffected == 0 {
89 | rowsAffected = uint64(len(qrRows))
90 | }
91 | qr := &sqltypes.Result{
92 | Fields: iRows.Fields(),
93 | RowsAffected: rowsAffected,
94 | InsertID: iRows.LastInsertID(),
95 | Rows: qrRows,
96 | }
97 | return qr, err
98 | }
99 |
100 | // ComStatementReset -- reset the stmt.
101 | func (s *Statement) ComStatementReset() error {
102 | var data [4]byte
103 |
104 | // Add arg [32 bit]
105 | data[0] = byte(s.ID)
106 | data[1] = byte(s.ID >> 8)
107 | data[2] = byte(s.ID >> 16)
108 | data[3] = byte(s.ID >> 24)
109 | if err := s.conn.packets.WriteCommand(sqldb.COM_STMT_RESET, data[:]); err != nil {
110 | return err
111 | }
112 | return s.conn.packets.ReadOK()
113 | }
114 |
115 | // ComStatementClose -- close the stmt.
116 | func (s *Statement) ComStatementClose() error {
117 | var data [4]byte
118 |
119 | // Add arg [32 bit]
120 | data[0] = byte(s.ID)
121 | data[1] = byte(s.ID >> 8)
122 | data[2] = byte(s.ID >> 16)
123 | data[3] = byte(s.ID >> 24)
124 | if err := s.conn.packets.WriteCommand(sqldb.COM_STMT_CLOSE, data[:]); err != nil {
125 | return err
126 | }
127 | return nil
128 | }
129 |
--------------------------------------------------------------------------------
/driver/statement_test.go:
--------------------------------------------------------------------------------
1 | /*
2 | * go-mysqlstack
3 | * xelabs.org
4 | *
5 | * Copyright (c) XeLabs
6 | * GPL License
7 | *
8 | */
9 |
10 | package driver
11 |
12 | import (
13 | "testing"
14 | "time"
15 |
16 | "github.com/stretchr/testify/assert"
17 | "github.com/xelabs/go-mysqlstack/xlog"
18 |
19 | querypb "github.com/xelabs/go-mysqlstack/sqlparser/depends/query"
20 | "github.com/xelabs/go-mysqlstack/sqlparser/depends/sqltypes"
21 | )
22 |
23 | func TestStatement(t *testing.T) {
24 | log := xlog.NewStdLog(xlog.Level(xlog.DEBUG))
25 | th := NewTestHandler(log)
26 | svr, err := MockMysqlServer(log, th)
27 | assert.Nil(t, err)
28 | defer svr.Close()
29 | address := svr.Addr()
30 |
31 | result1 := &sqltypes.Result{
32 | Fields: []*querypb.Field{
33 | {
34 | Name: "a",
35 | Type: sqltypes.Int32,
36 | },
37 | {
38 | Name: "b",
39 | Type: sqltypes.VarChar,
40 | },
41 | {
42 | Name: "c",
43 | Type: sqltypes.Datetime,
44 | },
45 | {
46 | Name: "d",
47 | Type: sqltypes.Time,
48 | },
49 | {
50 | Name: "e",
51 | Type: sqltypes.VarChar,
52 | },
53 | },
54 | Rows: [][]sqltypes.Value{
55 | {
56 | sqltypes.MakeTrusted(sqltypes.Int32, []byte("10")),
57 | sqltypes.MakeTrusted(sqltypes.VarChar, []byte("xx10xx")),
58 | sqltypes.MakeTrusted(sqltypes.Datetime, []byte(time.Now().Format("2006-01-02 15:04:05"))),
59 | sqltypes.MakeTrusted(sqltypes.Time, []byte("15:04:05")),
60 | sqltypes.MakeTrusted(sqltypes.VarChar, nil),
61 | },
62 | },
63 | }
64 | result2 := &sqltypes.Result{}
65 | th.AddQueryPattern("drop table if .*", result2)
66 | th.AddQueryPattern("create table if .*", result2)
67 | th.AddQueryPattern("insert .*", result2)
68 | th.AddQueryPattern("select .*", result1)
69 |
70 | // query
71 | {
72 | client, err := NewConn("mock", "mock", address, "test", "")
73 | //client, err := NewConn("root", "", "127.0.0.1:3307", "test", "")
74 | assert.Nil(t, err)
75 | defer client.Close()
76 |
77 | query := "drop table if exists t1"
78 | err = client.Exec(query)
79 | assert.Nil(t, err)
80 |
81 | query = "create table if not exists t1 (a int, b varchar(20), c datetime, d time, e varchar(20))"
82 | err = client.Exec(query)
83 | assert.Nil(t, err)
84 |
85 | // Prepare Insert.
86 | {
87 | query = "insert into t1(a, b, c, d, e) values(?,?,?,?,?)"
88 | stmt, err := client.ComStatementPrepare(query)
89 | assert.Nil(t, err)
90 | log.Debug("stmt:%+v", stmt)
91 |
92 | params := []sqltypes.Value{
93 | sqltypes.NewInt32(11),
94 | sqltypes.NewVarChar("xx10xx"),
95 | sqltypes.MakeTrusted(sqltypes.Datetime, []byte(time.Now().Format("2006-01-02 15:04:05"))),
96 | sqltypes.MakeTrusted(sqltypes.Time, []byte("15:04:05")),
97 | sqltypes.MakeTrusted(sqltypes.VarChar, nil),
98 | }
99 | err = stmt.ComStatementExecute(params)
100 | assert.Nil(t, err)
101 | stmt.ComStatementClose()
102 | }
103 |
104 | // Normal Select int.
105 | {
106 | query = "select * from t1 where a=10"
107 | qr, err := client.FetchAll(query, -1)
108 | assert.Nil(t, err)
109 | log.Debug("normal:%+v", qr)
110 | }
111 |
112 | {
113 | query = "select * from t1 where a=10"
114 | qr, err := client.FetchAll(query, -1)
115 | assert.Nil(t, err)
116 | log.Debug("normal:%+v", qr)
117 | }
118 |
119 | // Prepare Select int.
120 | {
121 | query = "select * from t1 where a=?"
122 | stmt, err := client.ComStatementPrepare(query)
123 | assert.Nil(t, err)
124 | assert.NotNil(t, stmt)
125 | log.Debug("stmt:%+v", stmt)
126 |
127 | params := []sqltypes.Value{
128 | sqltypes.NewInt32(11),
129 | }
130 | qr, err := stmt.ComStatementQuery(params)
131 | assert.Nil(t, err)
132 | log.Debug("%+v", qr)
133 | stmt.ComStatementClose()
134 | }
135 |
136 | // Prepare Select int.
137 | {
138 | query = "select * from t1 where a=?"
139 | stmt, err := client.ComStatementPrepare(query)
140 | assert.Nil(t, err)
141 | log.Debug("stmt:%+v", stmt)
142 |
143 | params := []sqltypes.Value{
144 | sqltypes.NewInt32(11),
145 | }
146 | qr, err := stmt.ComStatementQuery(params)
147 | assert.Nil(t, err)
148 | log.Debug("%+v", qr)
149 | stmt.ComStatementClose()
150 | }
151 |
152 | // Prepare Select time.
153 | {
154 | query = "select a,b,c,d,e from t1 where c=?"
155 | stmt, err := client.ComStatementPrepare(query)
156 | assert.Nil(t, err)
157 | log.Debug("stmt:%+v", stmt)
158 |
159 | params := []sqltypes.Value{
160 | sqltypes.MakeTrusted(sqltypes.Datetime, []byte(time.Now().Format("2006-01-02 15:04:05"))),
161 | }
162 | qr, err := stmt.ComStatementQuery(params)
163 | assert.Nil(t, err)
164 | log.Debug("%+v", qr)
165 | stmt.ComStatementReset()
166 | stmt.ComStatementClose()
167 | }
168 | }
169 | }
170 |
--------------------------------------------------------------------------------
/examples/client.go:
--------------------------------------------------------------------------------
1 | /*
2 | * go-mysqlstack
3 | * xelabs.org
4 | *
5 | * Copyright (c) XeLabs
6 | * GPL License
7 | *
8 | */
9 |
10 | package main
11 |
12 | import (
13 | "fmt"
14 |
15 | "github.com/xelabs/go-mysqlstack/driver"
16 | "github.com/xelabs/go-mysqlstack/xlog"
17 | )
18 |
19 | func main() {
20 | log := xlog.NewStdLog(xlog.Level(xlog.INFO))
21 | address := fmt.Sprintf(":4407")
22 | client, err := driver.NewConn("mock", "mock", address, "", "")
23 | if err != nil {
24 | log.Panic("client.new.connection.error:%+v", err)
25 | }
26 | defer client.Close()
27 |
28 | qr, err := client.FetchAll("SELECT * FROM MOCK", -1)
29 | if err != nil {
30 | log.Panic("client.query.error:%+v", err)
31 | }
32 | log.Info("results:[%+v]", qr.Rows)
33 | }
34 |
--------------------------------------------------------------------------------
/examples/mysqld.go:
--------------------------------------------------------------------------------
1 | /*
2 | * go-mysqlstack
3 | * xelabs.org
4 | *
5 | * Copyright (c) XeLabs
6 | * GPL License
7 | *
8 | */
9 |
10 | package main
11 |
12 | import (
13 | "os"
14 | "os/signal"
15 | "syscall"
16 |
17 | "github.com/xelabs/go-mysqlstack/driver"
18 | querypb "github.com/xelabs/go-mysqlstack/sqlparser/depends/query"
19 | "github.com/xelabs/go-mysqlstack/sqlparser/depends/sqltypes"
20 | "github.com/xelabs/go-mysqlstack/xlog"
21 | )
22 |
23 | func main() {
24 | result1 := &sqltypes.Result{
25 | Fields: []*querypb.Field{
26 | {
27 | Name: "id",
28 | Type: querypb.Type_INT32,
29 | },
30 | {
31 | Name: "name",
32 | Type: querypb.Type_VARCHAR,
33 | },
34 | },
35 | Rows: [][]sqltypes.Value{
36 | {
37 | sqltypes.MakeTrusted(querypb.Type_INT32, []byte("10")),
38 | sqltypes.MakeTrusted(querypb.Type_VARCHAR, []byte("nice name")),
39 | },
40 | },
41 | }
42 |
43 | log := xlog.NewStdLog(xlog.Level(xlog.INFO))
44 | th := driver.NewTestHandler(log)
45 | th.AddQuery("SELECT * FROM MOCK", result1)
46 |
47 | mysqld, err := driver.MockMysqlServerWithPort(log, 4407, th)
48 | if err != nil {
49 | log.Panic("mysqld.start.error:%+v", err)
50 | }
51 | defer mysqld.Close()
52 | log.Info("mysqld.server.start.address[%v]", mysqld.Addr())
53 |
54 | // Handle SIGINT and SIGTERM.
55 | ch := make(chan os.Signal)
56 | signal.Notify(ch, syscall.SIGINT, syscall.SIGTERM)
57 | <-ch
58 | }
59 |
--------------------------------------------------------------------------------
/go.mod:
--------------------------------------------------------------------------------
1 | module github.com/xelabs/go-mysqlstack
2 |
3 | require (
4 | github.com/shopspring/decimal v1.2.0
5 | github.com/stretchr/testify v1.7.0
6 | )
7 |
--------------------------------------------------------------------------------
/go.sum:
--------------------------------------------------------------------------------
1 | github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8=
2 | github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
3 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
4 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
5 | github.com/shopspring/decimal v1.2.0 h1:abSATXmQEYyShuxI4/vyW3tV1MrKAJzCZ/0zLUXYbsQ=
6 | github.com/shopspring/decimal v1.2.0/go.mod h1:DKyhrW/HYNuLGql+MJL6WCR6knT2jwCFRcu2hWCYk4o=
7 | github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
8 | github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY=
9 | github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
10 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
11 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
12 | gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo=
13 | gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
14 |
--------------------------------------------------------------------------------
/makefile:
--------------------------------------------------------------------------------
export PATH := $(GOPATH)/bin:$(PATH)

fmt:
	go fmt ./...
	go vet ./...

test:
	go get github.com/stretchr/testify/assert
	@echo "--> Testing..."
	@$(MAKE) testxlog
	@$(MAKE) testsqlparser
	@$(MAKE) testsqldb
	@$(MAKE) testproto
	@$(MAKE) testpacket
	@$(MAKE) testdriver

testxlog:
	go test -v ./xlog
testsqlparser:
	go test -v ./sqlparser/...
testsqldb:
	go test -v ./sqldb
testproto:
	go test -v ./proto
testpacket:
	go test -v ./packet
testdriver:
	go test -v ./driver

COVPKGS = ./sqlparser/... ./sqldb ./proto ./packet ./driver
coverage:
	go get github.com/pierrre/gotestcover
	gotestcover -coverprofile=coverage.out -v $(COVPKGS)
	go tool cover -html=coverage.out

# Every target above is a command, not a file. List them all so that a stray
# file with a target's name (e.g. "test") cannot make make skip the recipe.
.PHONY: fmt test testxlog testsqlparser testsqldb testproto testpacket testdriver coverage
37 |
--------------------------------------------------------------------------------
/packet/error.go:
--------------------------------------------------------------------------------
1 | /*
2 | * go-mysqlstack
3 | * xelabs.org
4 | *
5 | * Copyright (c) XeLabs
6 | * GPL License
7 | *
8 | */
9 |
10 | package packet
11 |
12 | import (
13 | "errors"
14 | )
15 |
// ErrBadConn reports that the connection was found to be bad.
var ErrBadConn = errors.New("connection.was.bad")

// ErrMalformPacket reports a packet that could not be decoded.
var ErrMalformPacket = errors.New("Malform.packet.error")
22 |
--------------------------------------------------------------------------------
/packet/mock.go:
--------------------------------------------------------------------------------
1 | /*
2 | * go-mysqlstack
3 | * xelabs.org
4 | *
5 | * Copyright (c) XeLabs
6 | * Copyright 2016 The Go-MySQL-Driver Authors. All rights reserved.
7 | * GPL License
8 | *
9 | */
10 |
11 | package packet
12 |
13 | import (
14 | "io"
15 | "net"
16 | "time"
17 | )
18 |
var _ net.Conn = &MockConn{}

// MockConn is an in-memory net.Conn for tests. Bytes passed to Write are
// appended to one internal buffer, and Read consumes that same buffer from
// the front, so a test can write packets and read them back through a single
// value. Deadlines are accepted and ignored.
type MockConn struct {
	laddr  net.Addr // returned by LocalAddr; NewMockConn leaves it nil
	raddr  net.Addr // returned by RemoteAddr; NewMockConn leaves it nil
	data   []byte   // bytes written but not yet read
	closed bool     // set by Close; Read and Write do not consult it
	read   int      // total number of bytes returned by Read so far
}

// NewMockConn creates an empty mock connection.
func NewMockConn() *MockConn {
	return &MockConn{}
}

// Read implements the net.Conn interface.
// It copies as many pending bytes as fit into b and returns io.EOF once no
// bytes are pending.
func (m *MockConn) Read(b []byte) (n int, err error) {
	// handle the EOF
	if len(m.data) == 0 {
		err = io.EOF
		return
	}

	n = copy(b, m.data)
	m.read += n
	m.data = m.data[n:]
	return
}

// Write implements the net.Conn interface.
// It appends b to the pending data and never fails.
func (m *MockConn) Write(b []byte) (n int, err error) {
	m.data = append(m.data, b...)
	return len(b), nil
}

// Datas returns the bytes that have been written but not yet read.
// It is a test helper, not part of net.Conn; the returned slice aliases the
// connection's internal buffer.
func (m *MockConn) Datas() []byte {
	return m.data
}

// Close implements the net.Conn interface.
// It only marks the connection closed and always returns nil.
func (m *MockConn) Close() error {
	m.closed = true
	return nil
}

// LocalAddr implements the net.Conn interface.
func (m *MockConn) LocalAddr() net.Addr {
	return m.laddr
}

// RemoteAddr implements the net.Conn interface.
func (m *MockConn) RemoteAddr() net.Addr {
	return m.raddr
}

// SetDeadline implements the net.Conn interface; it is a no-op.
func (m *MockConn) SetDeadline(t time.Time) error {
	return nil
}

// SetReadDeadline implements the net.Conn interface; it is a no-op.
func (m *MockConn) SetReadDeadline(t time.Time) error {
	return nil
}

// SetWriteDeadline implements the net.Conn interface; it is a no-op.
func (m *MockConn) SetWriteDeadline(t time.Time) error {
	return nil
}
90 |
--------------------------------------------------------------------------------
/packet/stream.go:
--------------------------------------------------------------------------------
1 | /*
2 | * go-mysqlstack
3 | * xelabs.org
4 | *
5 | * Copyright (c) XeLabs
6 | * GPL License
7 | *
8 | */
9 |
10 | package packet
11 |
12 | import (
13 | "bufio"
14 | "io"
15 | "net"
16 | )
17 |
18 | const (
19 | // PACKET_BUFFER_SIZE is how much we buffer for reading.
20 | PACKET_BUFFER_SIZE = 32 * 1024
21 | )
22 |
23 | // Stream represents the stream tuple.
24 | type Stream struct {
25 | pktMaxSize int
26 | header []byte
27 | reader *bufio.Reader
28 | writer *bufio.Writer
29 | }
30 |
31 | // NewStream creates a new stream.
32 | func NewStream(conn net.Conn, pktMaxSize int) *Stream {
33 | return &Stream{
34 | pktMaxSize: pktMaxSize,
35 | header: []byte{0, 0, 0, 0},
36 | reader: bufio.NewReaderSize(conn, PACKET_BUFFER_SIZE),
37 | writer: bufio.NewWriterSize(conn, PACKET_BUFFER_SIZE),
38 | }
39 | }
40 |
41 | // Read reads the next packet from the reader
42 | // The returned pkt.Datas is only guaranteed to be valid until the next read
43 | func (s *Stream) Read() (*Packet, error) {
44 | // Header.
45 | if _, err := io.ReadFull(s.reader, s.header); err != nil {
46 | return nil, err
47 | }
48 |
49 | // Length.
50 | pkt := &Packet{}
51 | pkt.SequenceID = s.header[3]
52 | length := int(uint32(s.header[0]) | uint32(s.header[1])<<8 | uint32(s.header[2])<<16)
53 | if length == 0 {
54 | return pkt, nil
55 | }
56 |
57 | // Datas.
58 | data := make([]byte, length)
59 | if _, err := io.ReadFull(s.reader, data); err != nil {
60 | return nil, err
61 | }
62 | pkt.Datas = data
63 |
64 | // Single packet.
65 | if length < s.pktMaxSize {
66 | return pkt, nil
67 | }
68 |
69 | // There is more than one packet, read them all.
70 | next, err := s.Read()
71 | if err != nil {
72 | return nil, err
73 | }
74 | pkt.SequenceID = next.SequenceID
75 | pkt.Datas = append(pkt.Datas, next.Datas...)
76 | return pkt, nil
77 | }
78 |
79 | // Write writes the packet to writer
80 | func (s *Stream) Write(data []byte) error {
81 | if err := s.Append(data); err != nil {
82 | return err
83 | }
84 | return s.Flush()
85 | }
86 |
87 | // Append used to append data to write buffer.
88 | func (s *Stream) Append(data []byte) error {
89 | payLen := len(data) - 4
90 | sequence := data[3]
91 |
92 | for {
93 | var size int
94 | if payLen < s.pktMaxSize {
95 | size = payLen
96 | } else {
97 | size = s.pktMaxSize
98 | }
99 | data[0] = byte(size)
100 | data[1] = byte(size >> 8)
101 | data[2] = byte(size >> 16)
102 | data[3] = sequence
103 |
104 | // append to buffer
105 | s.writer.Write(data[:4+size])
106 | if size < s.pktMaxSize {
107 | break
108 | }
109 |
110 | payLen -= size
111 | data = data[size:]
112 | sequence++
113 | }
114 | return nil
115 | }
116 |
117 | // Flush used to flush the writer.
118 | func (s *Stream) Flush() error {
119 | return s.writer.Flush()
120 | }
121 |
--------------------------------------------------------------------------------
/packet/stream_test.go:
--------------------------------------------------------------------------------
1 | /*
2 | * go-mysqlstack
3 | * xelabs.org
4 | *
5 | * Copyright (c) XeLabs
6 | * GPL License
7 | *
8 | */
9 |
10 | package packet
11 |
12 | import (
13 | "testing"
14 |
15 | "github.com/stretchr/testify/assert"
16 |
17 | "github.com/xelabs/go-mysqlstack/sqlparser/depends/common"
18 | )
19 |
// TEST EFFECTS:
// writes a normal packet that fits in a single chunk
//
// TEST PROCESSES:
// 1. build a packet with a 1234-byte payload and sequence id 1
//    (assumes PACKET_MAX_SIZE > 1234, so no splitting happens)
// 2. write checks: the bytes on the wire equal the input packet
// 3. read checks: reading the wire bytes back yields the same payload and id
func TestStream(t *testing.T) {
	rBuf := NewMockConn()
	defer rBuf.Close()

	wBuf := NewMockConn()
	defer wBuf.Close()

	rStream := NewStream(rBuf, PACKET_MAX_SIZE)
	wStream := NewStream(wBuf, PACKET_MAX_SIZE)

	packet := common.NewBuffer(PACKET_BUFFER_SIZE)
	payload := common.NewBuffer(PACKET_BUFFER_SIZE)

	for i := 0; i < 1234; i++ {
		payload.WriteU8(byte(i))
	}

	// 4-byte header: 3-byte payload length, 1-byte sequence id.
	packet.WriteU24(uint32(payload.Length()))
	packet.WriteU8(1)
	packet.WriteBytes(payload.Datas())

	// write checks
	{
		err := wStream.Write(packet.Datas())
		assert.Nil(t, err)

		want := packet.Datas()
		got := wBuf.Datas()
		assert.Equal(t, want, got)
	}

	// read checks: feed the written bytes into the reader's connection.
	{
		rBuf.Write(wBuf.Datas())
		ptk, err := rStream.Read()
		assert.Nil(t, err)

		assert.Equal(t, byte(0x01), ptk.SequenceID)
		assert.Equal(t, payload.Datas(), ptk.Datas)
	}
}
68 |
// TEST EFFECTS:
// write a packet whose payload length equals pktMaxSize
//
// TEST PROCESSES:
// 1. write a payload whose length equals pktMaxSize (16 uint32s = 64 bytes)
// 2. write checks: one full chunk (seq 1) followed by an empty terminating
//    chunk (seq 2)
// 3. read checks: the chunks are joined back into the original payload and
//    the packet reports the last sequence id (2)
func TestStreamWriteMax(t *testing.T) {
	rBuf := NewMockConn()
	defer rBuf.Close()

	wBuf := NewMockConn()
	defer wBuf.Close()

	pktMaxSize := 64
	rStream := NewStream(rBuf, pktMaxSize)
	wStream := NewStream(wBuf, pktMaxSize)

	packet := common.NewBuffer(PACKET_BUFFER_SIZE)
	expect := common.NewBuffer(PACKET_BUFFER_SIZE)
	payload := common.NewBuffer(PACKET_BUFFER_SIZE)

	{
		// (64+1)/4 = 16 iterations of 4 bytes: exactly pktMaxSize bytes.
		for i := 0; i < (pktMaxSize+1)/4; i++ {
			payload.WriteU32(uint32(i))
		}
	}
	packet.WriteU24(uint32(payload.Length()))
	packet.WriteU8(1)
	packet.WriteBytes(payload.Datas())

	// write checks
	{
		err := wStream.Write(packet.Datas())
		assert.Nil(t, err)

		// check length: the extra 4 bytes are the empty chunk's header.
		{
			want := packet.Length() + 4
			got := len(wBuf.Datas())
			assert.Equal(t, want, got)
		}

		// check chunks
		{
			// first chunk
			expect.WriteU24(uint32(pktMaxSize))
			expect.WriteU8(1)
			expect.WriteBytes(payload.Datas()[:pktMaxSize])

			// second chunk: empty terminator
			expect.WriteU24(0)
			expect.WriteU8(2)

			want := expect.Datas()
			got := wBuf.Datas()
			assert.Equal(t, want, got)
		}
	}

	// read checks
	{
		rBuf.Write(wBuf.Datas())
		ptk, err := rStream.Read()
		assert.Nil(t, err)

		assert.Equal(t, byte(0x02), ptk.SequenceID)
		assert.Equal(t, payload.Datas(), ptk.Datas)
	}
}
139 |
// TEST EFFECTS:
// write a packet whose payload length is more than pktMaxSize
//
// TEST PROCESSES:
// 1. write a payload of (pktMaxSize/4)*4 + 8 bytes (68 bytes for pktMaxSize 63)
// 2. write checks: one full chunk (seq 1) followed by the 5-byte remainder
//    (seq 2)
// 3. read checks: the chunks are joined back into the original payload, and a
//    second read fails because the connection is drained
func TestStreamWriteOverMax(t *testing.T) {
	rBuf := NewMockConn()
	defer rBuf.Close()

	wBuf := NewMockConn()
	defer wBuf.Close()

	pktMaxSize := 63
	rStream := NewStream(rBuf, pktMaxSize)
	wStream := NewStream(wBuf, pktMaxSize)

	packet := common.NewBuffer(PACKET_BUFFER_SIZE)
	expect := common.NewBuffer(PACKET_BUFFER_SIZE)
	payload := common.NewBuffer(PACKET_BUFFER_SIZE)

	{
		// 63/4 = 15 iterations of 4 bytes: 60 bytes.
		for i := 0; i < pktMaxSize/4; i++ {
			payload.WriteU32(uint32(i))
		}
	}
	// fill with 8bytes
	payload.WriteU32(32)
	payload.WriteU32(32)

	packet.WriteU24(uint32(payload.Length()))
	packet.WriteU8(1)
	packet.WriteBytes(payload.Datas())

	// write checks
	{
		err := wStream.Write(packet.Datas())
		assert.Nil(t, err)

		// check length: the extra 4 bytes are the second chunk's header.
		{
			want := packet.Length() + 4
			got := len(wBuf.Datas())
			assert.Equal(t, want, got)
		}

		// check chunks
		{
			// first chunk
			expect.WriteU24(uint32(pktMaxSize))
			expect.WriteU8(1)
			expect.WriteBytes(payload.Datas()[:pktMaxSize])

			// second chunk: the remaining payload bytes
			left := (packet.Length() - 4) - pktMaxSize
			expect.WriteU24(uint32(left))
			expect.WriteU8(2)
			expect.WriteBytes(payload.Datas()[pktMaxSize:])

			want := expect.Datas()
			got := wBuf.Datas()
			assert.Equal(t, want, got)
		}
	}

	// read checks
	{
		rBuf.Write(wBuf.Datas())
		ptk, err := rStream.Read()
		assert.Nil(t, err)

		assert.Equal(t, byte(0x02), ptk.SequenceID)
		assert.Equal(t, payload.Datas(), ptk.Datas)
		// Nothing left on the mock connection, so this read must fail.
		_, err = rStream.Read()
		assert.NotNil(t, err)
	}
}
218 |
--------------------------------------------------------------------------------
/proto/column.go:
--------------------------------------------------------------------------------
1 | /*
2 | * go-mysqlstack
3 | * xelabs.org
4 | *
5 | * Copyright (c) XeLabs
6 | * GPL License
7 | *
8 | */
9 |
10 | package proto
11 |
12 | import (
13 | "github.com/xelabs/go-mysqlstack/sqldb"
14 |
15 | "github.com/xelabs/go-mysqlstack/sqlparser/depends/common"
16 | querypb "github.com/xelabs/go-mysqlstack/sqlparser/depends/query"
17 | "github.com/xelabs/go-mysqlstack/sqlparser/depends/sqltypes"
18 | )
19 |
20 | // ColumnCount returns the column count.
21 | func ColumnCount(payload []byte) (count uint64, err error) {
22 | buff := common.ReadBuffer(payload)
23 | if count, err = buff.ReadLenEncode(); err != nil {
24 | return 0, sqldb.NewSQLError(sqldb.ER_MALFORMED_PACKET, "extracting column count failed")
25 | }
26 | return
27 | }
28 |
29 | // UnpackColumn used to unpack the column packet.
30 | // http://dev.mysql.com/doc/internals/en/com-query-response.html#packet-Protocol::ColumnDefinition41
31 | func UnpackColumn(payload []byte) (*querypb.Field, error) {
32 | var err error
33 | field := &querypb.Field{}
34 | buff := common.ReadBuffer(payload)
35 | // Catalog is ignored, always set to "def"
36 | if _, err = buff.ReadLenEncodeString(); err != nil {
37 | return nil, sqldb.NewSQLError(sqldb.ER_MALFORMED_PACKET, "skipping col catalog failed")
38 | }
39 |
40 | // lenenc_str Schema
41 | if field.Database, err = buff.ReadLenEncodeString(); err != nil {
42 | return nil, sqldb.NewSQLError(sqldb.ER_MALFORMED_PACKET, "extracting col schema failed")
43 | }
44 |
45 | // lenenc_str Table
46 | if field.Table, err = buff.ReadLenEncodeString(); err != nil {
47 | return nil, sqldb.NewSQLError(sqldb.ER_MALFORMED_PACKET, "extracting col table failed")
48 | }
49 |
50 | // lenenc_str Org_Table
51 | if field.OrgTable, err = buff.ReadLenEncodeString(); err != nil {
52 | return nil, sqldb.NewSQLError(sqldb.ER_MALFORMED_PACKET, "extracting col org_table failed")
53 | }
54 |
55 | // lenenc_str Name
56 | if field.Name, err = buff.ReadLenEncodeString(); err != nil {
57 | return nil, sqldb.NewSQLError(sqldb.ER_MALFORMED_PACKET, "extracting col name failed")
58 | }
59 |
60 | // lenenc_str Org_Name
61 | if field.OrgName, err = buff.ReadLenEncodeString(); err != nil {
62 | return nil, sqldb.NewSQLError(sqldb.ER_MALFORMED_PACKET, "extracting col org_name failed")
63 | }
64 |
65 | // lenenc_int length of fixed-length fields [0c], skip
66 | if _, err = buff.ReadLenEncode(); err != nil {
67 | return nil, sqldb.NewSQLError(sqldb.ER_MALFORMED_PACKET, "extracting col 0c failed")
68 | }
69 |
70 | // 2 character set
71 | charset, err := buff.ReadU16()
72 | if err != nil {
73 | return nil, sqldb.NewSQLError(sqldb.ER_MALFORMED_PACKET, "extracting col charset failed")
74 | }
75 | field.Charset = uint32(charset)
76 |
77 | // 4 column length
78 | if field.ColumnLength, err = buff.ReadU32(); err != nil {
79 | return nil, sqldb.NewSQLError(sqldb.ER_MALFORMED_PACKET, "extracting col columnlength failed")
80 | }
81 |
82 | // 1 type
83 | t, err := buff.ReadU8()
84 | if err != nil {
85 | return nil, sqldb.NewSQLError(sqldb.ER_MALFORMED_PACKET, "extracting col type failed")
86 | }
87 |
88 | // 2 flags
89 | flags, err := buff.ReadU16()
90 | if err != nil {
91 | return nil, sqldb.NewSQLError(sqldb.ER_MALFORMED_PACKET, "extracting col flags failed")
92 | }
93 | field.Flags = uint32(flags)
94 |
95 | // Convert MySQL type
96 | if field.Type, err = sqltypes.MySQLToType(int64(t), int64(field.Flags)); err != nil {
97 | return nil, sqldb.NewSQLErrorf(sqldb.ER_MALFORMED_PACKET, "MySQLToType(%v,%v) failed: %v", t, field.Flags, err)
98 | }
99 |
100 | // 1 Decimals
101 | decimals, err := buff.ReadU8()
102 | if err != nil {
103 | return nil, sqldb.NewSQLError(sqldb.ER_MALFORMED_PACKET, "extracting col type failed")
104 | }
105 | field.Decimals = uint32(decimals)
106 |
107 | // 2 Filler and Default Values is ignored
108 | //
109 | return field, nil
110 | }
111 |
112 | // PackColumn used to pack the column packet.
113 | func PackColumn(field *querypb.Field) []byte {
114 | typ, flags := sqltypes.TypeToMySQL(field.Type)
115 | if field.Flags != 0 {
116 | flags = int64(field.Flags)
117 | }
118 |
119 | buf := common.NewBuffer(256)
120 |
121 | // lenenc_str Catalog, always 'def'
122 | buf.WriteLenEncodeString("def")
123 |
124 | // lenenc_str Schema
125 | buf.WriteLenEncodeString(field.Database)
126 |
127 | // lenenc_str Table
128 | buf.WriteLenEncodeString(field.Table)
129 |
130 | // lenenc_str Org_Table
131 | buf.WriteLenEncodeString(field.OrgTable)
132 |
133 | // lenenc_str Name
134 | buf.WriteLenEncodeString(field.Name)
135 |
136 | // lenenc_str Org_Name
137 | buf.WriteLenEncodeString(field.OrgName)
138 |
139 | // lenenc_int length of fixed-length fields [0c]
140 | buf.WriteLenEncode(uint64(0x0c))
141 |
142 | // 2 character set
143 | buf.WriteU16(uint16(field.Charset))
144 |
145 | // 4 column length
146 | buf.WriteU32(field.ColumnLength)
147 |
148 | // 1 type
149 | buf.WriteU8(byte(typ))
150 |
151 | // 2 flags
152 | buf.WriteU16(uint16(flags))
153 |
154 | //1 Decimals
155 | buf.WriteU8(uint8(field.Decimals))
156 |
157 | // 2 filler [00] [00]
158 | buf.WriteU16(uint16(0))
159 | return buf.Datas()
160 | }
161 |
--------------------------------------------------------------------------------
/proto/column_test.go:
--------------------------------------------------------------------------------
1 | /*
2 | * go-mysqlstack
3 | * xelabs.org
4 | *
5 | * Copyright (c) XeLabs
6 | * GPL License
7 | *
8 | */
9 |
10 | package proto
11 |
12 | import (
13 | "testing"
14 |
15 | "github.com/stretchr/testify/assert"
16 | "github.com/xelabs/go-mysqlstack/sqlparser/depends/sqltypes"
17 |
18 | "github.com/xelabs/go-mysqlstack/sqlparser/depends/common"
19 | querypb "github.com/xelabs/go-mysqlstack/sqlparser/depends/query"
20 | )
21 |
22 | func TestColumnCount(t *testing.T) {
23 | payload := []byte{
24 | 0x02,
25 | }
26 |
27 | want := uint64(2)
28 | got, err := ColumnCount(payload)
29 | assert.Nil(t, err)
30 | assert.Equal(t, want, got)
31 | }
32 |
33 | func TestColumn(t *testing.T) {
34 | want := &querypb.Field{
35 | Database: "test",
36 | Table: "t1",
37 | OrgTable: "t1",
38 | Name: "a",
39 | OrgName: "a",
40 | Charset: 11,
41 | ColumnLength: 11,
42 | Type: sqltypes.Int32,
43 | Flags: 11,
44 | }
45 |
46 | datas := PackColumn(want)
47 | got, err := UnpackColumn(datas)
48 | assert.Nil(t, err)
49 | assert.Equal(t, want, got)
50 | }
51 |
// TestColumnUnPackError feeds UnpackColumn progressively longer prefixes of a
// column-definition payload and expects every one of them to be rejected.
//
// Before each step the current (incomplete) buffer is decoded and must fail.
// The step then appends the next field. After the last step the buffer holds
// every field up to and including the flags but no decimals byte, so the
// final decode must still fail.
func TestColumnUnPackError(t *testing.T) {
	// NULL
	f0 := func(buff *common.Buffer) {
	}

	// Write catalog.
	f1 := func(buff *common.Buffer) {
		buff.WriteLenEncodeString("def")
	}

	// Write schema.
	f2 := func(buff *common.Buffer) {
		buff.WriteLenEncodeString("sbtest")
	}

	// Write table.
	f3 := func(buff *common.Buffer) {
		buff.WriteLenEncodeString("table1")
	}

	// Write org table.
	f4 := func(buff *common.Buffer) {
		buff.WriteLenEncodeString("orgtable1")
	}

	// Write Name.
	f5 := func(buff *common.Buffer) {
		buff.WriteLenEncodeString("name")
	}

	// Write Org Name.
	f6 := func(buff *common.Buffer) {
		buff.WriteLenEncodeString("name")
	}

	// Write length of the fixed-length fields.
	f7 := func(buff *common.Buffer) {
		buff.WriteLenEncode(0x0c)
	}

	// Write Charset.
	f8 := func(buff *common.Buffer) {
		buff.WriteU16(uint16(1))
	}

	// Write Column length.
	f9 := func(buff *common.Buffer) {
		buff.WriteU32(uint32(1))
	}

	// Write type.
	f10 := func(buff *common.Buffer) {
		buff.WriteU8(0x01)
	}

	// Write flags
	f11 := func(buff *common.Buffer) {
		buff.WriteU16(uint16(1))
	}

	buff := common.NewBuffer(32)
	fs := []func(buff *common.Buffer){f0, f1, f2, f3, f4, f5, f6, f7, f8, f9, f10, f11}
	for i := 0; i < len(fs); i++ {
		// The buffer is still missing field i (and everything after it).
		_, err := UnpackColumn(buff.Datas())
		assert.NotNil(t, err)
		fs[i](buff)
	}

	// Everything through flags is present; the decimals byte is not.
	{
		_, err := UnpackColumn(buff.Datas())
		assert.NotNil(t, err)
	}
}
125 |
--------------------------------------------------------------------------------
/proto/const.go:
--------------------------------------------------------------------------------
1 | /*
2 | * go-mysqlstack
3 | * xelabs.org
4 | *
5 | * Copyright (c) XeLabs
6 | * GPL License
7 | *
8 | */
9 |
10 | package proto
11 |
12 | import (
13 | "github.com/xelabs/go-mysqlstack/sqldb"
14 | )
15 |
const (
	// DefaultAuthPluginName is the default plugin name.
	DefaultAuthPluginName = "mysql_native_password"

	// DefaultServerCapability is the default server capability bit set
	// (presumably what a server advertises in its greeting; confirm in
	// greeting.go).
	DefaultServerCapability = sqldb.CLIENT_LONG_PASSWORD |
		sqldb.CLIENT_LONG_FLAG |
		sqldb.CLIENT_CONNECT_WITH_DB |
		sqldb.CLIENT_PROTOCOL_41 |
		sqldb.CLIENT_TRANSACTIONS |
		sqldb.CLIENT_MULTI_STATEMENTS |
		sqldb.CLIENT_PLUGIN_AUTH |
		sqldb.CLIENT_DEPRECATE_EOF |
		sqldb.CLIENT_SECURE_CONNECTION

	// DefaultClientCapability is the default client capability bit set.
	// It lists the same flags as DefaultServerCapability except
	// CLIENT_CONNECT_WITH_DB. Presumably the client adds that flag only when a
	// database is requested; TODO confirm against the client/auth code.
	DefaultClientCapability = sqldb.CLIENT_LONG_PASSWORD |
		sqldb.CLIENT_LONG_FLAG |
		sqldb.CLIENT_PROTOCOL_41 |
		sqldb.CLIENT_TRANSACTIONS |
		sqldb.CLIENT_MULTI_STATEMENTS |
		sqldb.CLIENT_PLUGIN_AUTH |
		sqldb.CLIENT_DEPRECATE_EOF |
		sqldb.CLIENT_SECURE_CONNECTION
)

var (
	// DefaultSalt is the default salt bytes: 20 bytes in total, an 8-byte
	// first part followed by a 12-byte second part (presumably matching the
	// greeting's auth-plugin-data part 1 / part 2 split; cf. the salt[8]
	// step in greeting_test.go).
	// NOTE(review): exported mutable slice; callers must not modify it.
	DefaultSalt = []byte{
		0x77, 0x63, 0x6a, 0x6d, 0x61, 0x22, 0x23, 0x27, // first part
		0x38, 0x26, 0x55, 0x58, 0x3b, 0x5d, 0x44, 0x78, 0x53, 0x73, 0x6b, 0x41}
)
48 |
--------------------------------------------------------------------------------
/proto/eof.go:
--------------------------------------------------------------------------------
1 | /*
2 | * go-mysqlstack
3 | * xelabs.org
4 | *
5 | * Copyright (c) XeLabs
6 | * GPL License
7 | *
8 | */
9 |
10 | package proto
11 |
12 | import (
13 | "github.com/xelabs/go-mysqlstack/sqldb"
14 | "github.com/xelabs/go-mysqlstack/sqlparser/depends/common"
15 | )
16 |
const (
	// EOF_PACKET is the header byte (0xfe) that identifies an EOF packet.
	EOF_PACKET byte = 0xfe
)

// EOF holds the fields of a MySQL EOF packet, in wire order.
type EOF struct {
	Header      byte // always EOF_PACKET (0xfe); UnPackEOF rejects anything else
	Warnings    uint16
	StatusFlags uint16
}
28 |
29 | // UnPackEOF used to unpack the EOF packet.
30 | // https://dev.mysql.com/doc/internals/en/packet-EOF_Packet.html
31 | // This method unsed.
32 | func UnPackEOF(data []byte) (*EOF, error) {
33 | var err error
34 | e := &EOF{}
35 | buf := common.ReadBuffer(data)
36 |
37 | // header
38 | if e.Header, err = buf.ReadU8(); err != nil {
39 | return nil, sqldb.NewSQLErrorf(sqldb.ER_MALFORMED_PACKET, "invalid eof packet header: %v", data)
40 | }
41 | if e.Header != EOF_PACKET {
42 | return nil, sqldb.NewSQLErrorf(sqldb.ER_MALFORMED_PACKET, "invalid oeof packet header: %v", e.Header)
43 | }
44 |
45 | // Warnings
46 | if e.Warnings, err = buf.ReadU16(); err != nil {
47 | return nil, sqldb.NewSQLErrorf(sqldb.ER_MALFORMED_PACKET, "invalid eof packet warnings: %v", data)
48 | }
49 |
50 | // Status
51 | if e.StatusFlags, err = buf.ReadU16(); err != nil {
52 | return nil, sqldb.NewSQLErrorf(sqldb.ER_MALFORMED_PACKET, "invalid eof packet statusflags: %v", data)
53 | }
54 | return e, nil
55 | }
56 |
57 | // PackEOF used to pack the EOF packet.
58 | func PackEOF(e *EOF) []byte {
59 | buf := common.NewBuffer(64)
60 |
61 | // EOF
62 | buf.WriteU8(EOF_PACKET)
63 |
64 | // warnings
65 | buf.WriteU16(e.Warnings)
66 |
67 | // status
68 | buf.WriteU16(e.StatusFlags)
69 | return buf.Datas()
70 | }
71 |
--------------------------------------------------------------------------------
/proto/eof_test.go:
--------------------------------------------------------------------------------
1 | /*
2 | * go-mysqlstack
3 | * xelabs.org
4 | *
5 | * Copyright (c) XeLabs
6 | * GPL License
7 | *
8 | */
9 |
10 | package proto
11 |
12 | import (
13 | "testing"
14 |
15 | "github.com/stretchr/testify/assert"
16 |
17 | "github.com/xelabs/go-mysqlstack/sqlparser/depends/common"
18 | )
19 |
20 | func TestEOF(t *testing.T) {
21 | want := &EOF{}
22 | want.Header = EOF_PACKET
23 | want.StatusFlags = 1
24 | want.Warnings = 2
25 | data := PackEOF(want)
26 |
27 | got, err := UnPackEOF(data)
28 | assert.Nil(t, err)
29 | assert.Equal(t, want, got)
30 | }
31 |
// TestEOFUnPackError checks that UnPackEOF rejects a wrong header byte and
// every truncated prefix of an EOF packet.
func TestEOFUnPackError(t *testing.T) {
	// header error: 0x99 is not EOF_PACKET.
	{
		buff := common.NewBuffer(32)
		// header
		buff.WriteU8(0x99)
		_, err := UnPackEOF(buff.Datas())
		assert.NotNil(t, err)
	}

	// NULL
	f0 := func(buff *common.Buffer) {
	}

	// Write EOF header.
	f1 := func(buff *common.Buffer) {
		buff.WriteU8(0xfe)
	}

	// Write Warnings (the first 2-byte field UnPackEOF reads).
	f2 := func(buff *common.Buffer) {
		buff.WriteU16(0x01)
	}

	buff := common.NewBuffer(32)
	fs := []func(buff *common.Buffer){f0, f1, f2}
	for i := 0; i < len(fs); i++ {
		// The buffer is still missing field i (and everything after it).
		_, err := UnPackEOF(buff.Datas())
		assert.NotNil(t, err)
		fs[i](buff)
	}

	// Header and warnings are present; the status flags are not.
	{
		_, err := UnPackEOF(buff.Datas())
		assert.NotNil(t, err)
	}
}
69 |
--------------------------------------------------------------------------------
/proto/err.go:
--------------------------------------------------------------------------------
1 | /*
2 | * go-mysqlstack
3 | * xelabs.org
4 | *
5 | * Copyright (c) XeLabs
6 | * GPL License
7 | *
8 | */
9 |
10 | package proto
11 |
12 | import (
13 | "github.com/xelabs/go-mysqlstack/sqldb"
14 | "github.com/xelabs/go-mysqlstack/sqlparser/depends/common"
15 | )
16 |
const (
	// ERR_PACKET is the header byte (0xff) that identifies an error packet.
	ERR_PACKET byte = 0xff
)

// ERR holds the fields of a MySQL error packet.
type ERR struct {
	Header       byte // always 0xff
	ErrorCode    uint16
	SQLState     string // 5 ASCII characters; PackERR sets it to "HY000" when empty
	ErrorMessage string
}
29 |
30 | // UnPackERR parses the error packet and returns a sqldb.SQLError.
31 | // https://dev.mysql.com/doc/internals/en/packet-ERR_Packet.html
32 | func UnPackERR(data []byte) error {
33 | var err error
34 | e := &ERR{}
35 | buf := common.ReadBuffer(data)
36 | if e.Header, err = buf.ReadU8(); err != nil {
37 | return sqldb.NewSQLErrorf(sqldb.ER_MALFORMED_PACKET, "invalid error packet header: %v", data)
38 | }
39 | if e.Header != ERR_PACKET {
40 | return sqldb.NewSQLErrorf(sqldb.ER_MALFORMED_PACKET, "invalid error packet header: %v", e.Header)
41 | }
42 | if e.ErrorCode, err = buf.ReadU16(); err != nil {
43 | return sqldb.NewSQLErrorf(sqldb.ER_MALFORMED_PACKET, "invalid error packet code: %v", data)
44 | }
45 |
46 | // Skip SQLStateMarker
47 | if _, err = buf.ReadString(1); err != nil {
48 | return sqldb.NewSQLErrorf(sqldb.ER_MALFORMED_PACKET, "invalid error packet marker: %v", data)
49 | }
50 | if e.SQLState, err = buf.ReadString(5); err != nil {
51 | return sqldb.NewSQLErrorf(sqldb.ER_MALFORMED_PACKET, "invalid error packet sqlstate: %v", data)
52 | }
53 | msgLen := len(data) - buf.Seek()
54 | if e.ErrorMessage, err = buf.ReadString(msgLen); err != nil {
55 | return sqldb.NewSQLErrorf(sqldb.ER_MALFORMED_PACKET, "invalid error packet message: %v", data)
56 | }
57 | return sqldb.NewSQLError1(e.ErrorCode, e.SQLState, "%s", e.ErrorMessage)
58 | }
59 |
60 | // PackERR used to pack the error packet.
61 | func PackERR(e *ERR) []byte {
62 | buf := common.NewBuffer(64)
63 |
64 | buf.WriteU8(ERR_PACKET)
65 |
66 | // error code
67 | buf.WriteU16(e.ErrorCode)
68 |
69 | // sql-state marker #
70 | buf.WriteU8('#')
71 |
72 | // sql-state (?) 5 ascii bytes
73 | if e.SQLState == "" {
74 | e.SQLState = "HY000"
75 | }
76 | if len(e.SQLState) != 5 {
77 | panic("sqlState has to be 5 characters long")
78 | }
79 | buf.WriteString(e.SQLState)
80 |
81 | // error msg
82 | buf.WriteString(e.ErrorMessage)
83 | return buf.Datas()
84 | }
85 |
--------------------------------------------------------------------------------
/proto/err_test.go:
--------------------------------------------------------------------------------
1 | /*
2 | * go-mysqlstack
3 | * xelabs.org
4 | *
5 | * Copyright (c) XeLabs
6 | * GPL License
7 | *
8 | */
9 |
10 | package proto
11 |
12 | import (
13 | "testing"
14 |
15 | "github.com/stretchr/testify/assert"
16 | "github.com/xelabs/go-mysqlstack/sqldb"
17 |
18 | "github.com/xelabs/go-mysqlstack/sqlparser/depends/common"
19 | )
20 |
// TestERR checks UnPackERR against a hand-built packet and against PackERR
// output.
func TestERR(t *testing.T) {
	// Hand-built packet with an explicit SQL state.
	{
		buff := common.NewBuffer(32)

		// header
		buff.WriteU8(0xff)
		// error_code
		buff.WriteU16(0x01)
		// sql_state_marker
		buff.WriteString("#")
		// sql_state
		buff.WriteString("ABCDE")
		// error message
		buff.WriteString("ERROR")

		e := &ERR{}
		e.Header = 0xff
		e.ErrorCode = 0x1
		e.SQLState = "ABCDE"
		e.ErrorMessage = "ERROR"
		want := sqldb.NewSQLError1(e.ErrorCode, e.SQLState, "%s", e.ErrorMessage)
		got := UnPackERR(buff.Datas())
		assert.Equal(t, want, got)
	}

	// PackERR round trip with an empty SQL state.
	{
		e := &ERR{}
		e.Header = 0xff
		e.ErrorCode = 0x1
		e.ErrorMessage = "ERROR"
		// PackERR fills the empty e.SQLState with "HY000" in place, so want
		// (built afterwards) carries the same state as the packet.
		datas := PackERR(e)
		want := sqldb.NewSQLError1(e.ErrorCode, e.SQLState, "%s", e.ErrorMessage)
		got := UnPackERR(datas)
		assert.Equal(t, want, got)
	}
}

// TestERRUnPackError checks that UnPackERR rejects a wrong header byte and
// each truncated prefix of an error packet up to the SQL state.
func TestERRUnPackError(t *testing.T) {
	// header error: 0x01 is not ERR_PACKET.
	{
		buff := common.NewBuffer(32)

		// header
		buff.WriteU8(0x01)

		err := UnPackERR(buff.Datas())
		assert.NotNil(t, err)
	}

	// NULL
	f0 := func(buff *common.Buffer) {
	}

	// Write error header.
	f1 := func(buff *common.Buffer) {
		buff.WriteU8(0xff)
	}

	// Write error code.
	f2 := func(buff *common.Buffer) {
		buff.WriteU16(0x01)
	}

	// Write SQLStateMarker.
	f3 := func(buff *common.Buffer) {
		buff.WriteU8('#')
	}

	// Write SQLState.
	f4 := func(buff *common.Buffer) {
		buff.WriteString("xxxxx")
	}

	buff := common.NewBuffer(32)
	fs := []func(buff *common.Buffer){f0, f1, f2, f3, f4}
	for i := 0; i < len(fs); i++ {
		// The buffer is still missing field i (and everything after it).
		// No check follows the last step: a packet with an empty message
		// is valid.
		err := UnPackERR(buff.Datas())
		assert.NotNil(t, err)
		fs[i](buff)
	}
}
101 |
--------------------------------------------------------------------------------
/proto/greeting_test.go:
--------------------------------------------------------------------------------
1 | /*
2 | * go-mysqlstack
3 | * xelabs.org
4 | *
5 | * Copyright (c) XeLabs
6 | * GPL License
7 | *
8 | */
9 |
10 | package proto
11 |
12 | import (
13 | "testing"
14 |
15 | "github.com/stretchr/testify/assert"
16 | "github.com/xelabs/go-mysqlstack/sqldb"
17 |
18 | "github.com/xelabs/go-mysqlstack/sqlparser/depends/common"
19 | )
20 |
21 | func TestGreetingUnPack(t *testing.T) {
22 | want := NewGreeting(4, "")
23 | got := NewGreeting(4, "")
24 |
25 | // normal
26 | {
27 | want.authPluginName = "mysql_native_password"
28 | err := got.UnPack(want.Pack())
29 | assert.Nil(t, err)
30 | assert.Equal(t, want, got)
31 | assert.Equal(t, sqldb.SERVER_STATUS_AUTOCOMMIT, int(got.Status()))
32 | }
33 |
34 | // 1. off sqldb.CLIENT_PLUGIN_AUTH
35 | {
36 | want.Capability = want.Capability &^ sqldb.CLIENT_PLUGIN_AUTH
37 | want.authPluginName = "mysql_native_password"
38 | err := got.UnPack(want.Pack())
39 | assert.Nil(t, err)
40 | assert.Equal(t, want, got)
41 | }
42 |
43 | // 2. off sqldb.CLIENT_SECURE_CONNECTION
44 | {
45 | want.Capability &= ^sqldb.CLIENT_SECURE_CONNECTION
46 | want.authPluginName = "mysql_native_password"
47 | err := got.UnPack(want.Pack())
48 | assert.Nil(t, err)
49 | assert.Equal(t, want, got)
50 | }
51 |
52 | // 3. off sqldb.CLIENT_PLUGIN_AUTH && sqldb.CLIENT_SECURE_CONNECTION
53 | {
54 | want.Capability &= (^sqldb.CLIENT_PLUGIN_AUTH ^ sqldb.CLIENT_SECURE_CONNECTION)
55 | want.authPluginName = "mysql_native_password"
56 | err := got.UnPack(want.Pack())
57 | assert.Nil(t, err)
58 | assert.Equal(t, want, got)
59 | }
60 | }
61 |
// TestGreetingUnPackError feeds Greeting.UnPack progressively longer prefixes
// of a greeting and expects each of them to be rejected.
//
// Before each step the current buffer is decoded and must fail; the step then
// appends the next field. After the last step the buffer still must not
// decode. Presumably this is because the auth-plugin-data part 2 ends in 0x01
// instead of a NUL terminator and/or no auth plugin name follows; confirm
// against greeting.go.
func TestGreetingUnPackError(t *testing.T) {
	// NULL
	f0 := func(buff *common.Buffer) {
	}

	// Write protocol version.
	f1 := func(buff *common.Buffer) {
		buff.WriteU8(0x01)
	}

	// Write server version (NUL-terminated).
	f2 := func(buff *common.Buffer) {
		buff.WriteString("5.7.17-11")
		buff.WriteZero(1)
	}

	// Write connection ID.
	f3 := func(buff *common.Buffer) {
		buff.WriteU32(uint32(1))
	}

	// Write salt[8].
	f4 := func(buff *common.Buffer) {
		salt8 := make([]byte, 8)
		buff.WriteBytes(salt8)
	}

	// Write filler.
	f5 := func(buff *common.Buffer) {
		buff.WriteZero(1)
	}

	// Split the 32-bit capability set into its lower and upper 16-bit halves.
	capability := DefaultServerCapability
	capLower := uint16(capability)
	capUpper := uint16(uint32(capability) >> 16)

	// Write capability lower 2 bytes
	f6 := func(buff *common.Buffer) {
		buff.WriteU16(capLower)
	}

	// Write charset.
	f7 := func(buff *common.Buffer) {
		buff.WriteU8(0x01)
	}

	// Write status flags
	f8 := func(buff *common.Buffer) {
		buff.WriteU16(uint16(1))
	}

	// Write capability upper 2 bytes
	f9 := func(buff *common.Buffer) {
		buff.WriteU16(capUpper)
	}

	// Write length of auth-plugin data
	f10 := func(buff *common.Buffer) {
		buff.WriteU8(0x01)
	}

	// Write reserved.
	f11 := func(buff *common.Buffer) {
		buff.WriteZero(10)
	}

	// Write auth plugin data part 2 (13 bytes, last byte 0x01 rather than NUL)
	f12 := func(buff *common.Buffer) {
		data2 := make([]byte, 13)
		data2[12] = 0x01
		buff.WriteBytes(data2)
	}

	buff := common.NewBuffer(32)
	fs := []func(buff *common.Buffer){f0, f1, f2, f3, f4, f5, f6, f7, f8, f9, f10, f11, f12}
	for i := 0; i < len(fs); i++ {
		// The buffer is still missing field i (and everything after it).
		greeting := NewGreeting(0, "")
		err := greeting.UnPack(buff.Datas())
		assert.NotNil(t, err)
		fs[i](buff)
	}

	{
		greeting := NewGreeting(0, "")
		err := greeting.UnPack(buff.Datas())
		assert.NotNil(t, err)
	}
}
150 |
--------------------------------------------------------------------------------
/proto/ok.go:
--------------------------------------------------------------------------------
1 | /*
2 | * go-mysqlstack
3 | * xelabs.org
4 | *
5 | * Copyright (c) XeLabs
6 | * GPL License
7 | *
8 | */
9 |
10 | package proto
11 |
12 | import (
13 | "github.com/xelabs/go-mysqlstack/sqldb"
14 | "github.com/xelabs/go-mysqlstack/sqlparser/depends/common"
15 | )
16 |
const (
	// OK_PACKET is the OK byte: the header value UnPackOK requires as the
	// first byte of an OK packet and the first byte PackOK writes.
	OK_PACKET byte = 0x00
)

// OK used for OK packet.
// It holds the fields UnPackOK decodes and PackOK encodes, in wire order.
type OK struct {
	Header       byte   // 0x00; PackOK ignores this field and always writes OK_PACKET
	AffectedRows uint64 // length-encoded integer on the wire
	LastInsertID uint64 // length-encoded integer on the wire
	StatusFlags  uint16 // 2-byte field
	Warnings     uint16 // 2-byte field
}
30 |
31 | // UnPackOK used to unpack the OK packet.
32 | // https://dev.mysql.com/doc/internals/en/packet-OK_Packet.html
33 | func UnPackOK(data []byte) (*OK, error) {
34 | var err error
35 | o := &OK{}
36 | buf := common.ReadBuffer(data)
37 |
38 | // header
39 | if o.Header, err = buf.ReadU8(); err != nil {
40 | return nil, sqldb.NewSQLErrorf(sqldb.ER_MALFORMED_PACKET, "invalid ok packet header: %v", data)
41 | }
42 | if o.Header != OK_PACKET {
43 | return nil, sqldb.NewSQLErrorf(sqldb.ER_MALFORMED_PACKET, "invalid ok packet header: %v", o.Header)
44 | }
45 |
46 | // AffectedRows
47 | if o.AffectedRows, err = buf.ReadLenEncode(); err != nil {
48 | return nil, sqldb.NewSQLErrorf(sqldb.ER_MALFORMED_PACKET, "invalid ok packet affectedrows: %v", data)
49 | }
50 |
51 | // LastInsertID
52 | if o.LastInsertID, err = buf.ReadLenEncode(); err != nil {
53 | return nil, sqldb.NewSQLErrorf(sqldb.ER_MALFORMED_PACKET, "invalid ok packet lastinsertid: %v", data)
54 | }
55 |
56 | // Status
57 | if o.StatusFlags, err = buf.ReadU16(); err != nil {
58 | return nil, sqldb.NewSQLErrorf(sqldb.ER_MALFORMED_PACKET, "invalid ok packet statusflags: %v", data)
59 | }
60 |
61 | // Warnings
62 | if o.Warnings, err = buf.ReadU16(); err != nil {
63 | return nil, sqldb.NewSQLErrorf(sqldb.ER_MALFORMED_PACKET, "invalid ok packet warnings: %v", data)
64 | }
65 | return o, nil
66 | }
67 |
68 | // PackOK used to pack the OK packet.
69 | func PackOK(o *OK) []byte {
70 | buf := common.NewBuffer(64)
71 |
72 | // OK
73 | buf.WriteU8(OK_PACKET)
74 |
75 | // affected rows
76 | buf.WriteLenEncode(o.AffectedRows)
77 |
78 | // last insert id
79 | buf.WriteLenEncode(o.LastInsertID)
80 |
81 | // status
82 | buf.WriteU16(o.StatusFlags)
83 |
84 | // warnings
85 | buf.WriteU16(o.Warnings)
86 | return buf.Datas()
87 | }
88 |
--------------------------------------------------------------------------------
/proto/ok_test.go:
--------------------------------------------------------------------------------
1 | /*
2 | * go-mysqlstack
3 | * xelabs.org
4 | *
5 | * Copyright (c) XeLabs
6 | * GPL License
7 | *
8 | */
9 |
10 | package proto
11 |
12 | import (
13 | "testing"
14 |
15 | "github.com/stretchr/testify/assert"
16 |
17 | "github.com/xelabs/go-mysqlstack/sqlparser/depends/common"
18 | )
19 |
20 | func TestOK(t *testing.T) {
21 | {
22 | buff := common.NewBuffer(32)
23 |
24 | // header
25 | buff.WriteU8(0x00)
26 | // affected_rows
27 | buff.WriteLenEncode(uint64(3))
28 | // last_insert_id
29 | buff.WriteLenEncode(uint64(40000000000))
30 |
31 | // status_flags
32 | buff.WriteU16(0x01)
33 | // warnings
34 | buff.WriteU16(0x02)
35 |
36 | want := &OK{}
37 | want.AffectedRows = 3
38 | want.LastInsertID = 40000000000
39 | want.StatusFlags = 1
40 | want.Warnings = 2
41 |
42 | got, err := UnPackOK(buff.Datas())
43 | assert.Nil(t, err)
44 | assert.Equal(t, want, got)
45 | }
46 |
47 | {
48 | want := &OK{}
49 | want.AffectedRows = 3
50 | want.LastInsertID = 40000000000
51 | want.StatusFlags = 1
52 | want.Warnings = 2
53 | datas := PackOK(want)
54 |
55 | got, err := UnPackOK(datas)
56 | assert.Nil(t, err)
57 | assert.Equal(t, want, got)
58 | }
59 | }
60 |
61 | func TestOKUnPackError(t *testing.T) {
62 | // header error
63 | {
64 | buff := common.NewBuffer(32)
65 | // header
66 | buff.WriteU8(0x99)
67 | _, err := UnPackOK(buff.Datas())
68 | assert.NotNil(t, err)
69 | }
70 |
71 | // NULL
72 | f0 := func(buff *common.Buffer) {
73 | }
74 |
75 | // Write OK header.
76 | f1 := func(buff *common.Buffer) {
77 | buff.WriteU8(0x00)
78 | }
79 |
80 | // Write AffectedRows.
81 | f2 := func(buff *common.Buffer) {
82 | buff.WriteLenEncode(uint64(3))
83 | }
84 |
85 | // Write LastInsertID.
86 | f3 := func(buff *common.Buffer) {
87 | buff.WriteLenEncode(uint64(3))
88 | }
89 |
90 | // Write Status.
91 | f4 := func(buff *common.Buffer) {
92 | buff.WriteU16(0x01)
93 | }
94 |
95 | buff := common.NewBuffer(32)
96 | fs := []func(buff *common.Buffer){f0, f1, f2, f3, f4}
97 | for i := 0; i < len(fs); i++ {
98 | _, err := UnPackOK(buff.Datas())
99 | assert.NotNil(t, err)
100 | fs[i](buff)
101 | }
102 |
103 | {
104 | _, err := UnPackOK(buff.Datas())
105 | assert.NotNil(t, err)
106 | }
107 | }
108 |
--------------------------------------------------------------------------------
/proto/statement_test.go:
--------------------------------------------------------------------------------
1 | /*
2 | * go-mysqlstack
3 | * xelabs.org
4 | *
5 | * Copyright (c) XeLabs
6 | * GPL License
7 | *
8 | */
9 |
10 | package proto
11 |
12 | import (
13 | "errors"
14 | "testing"
15 | "time"
16 |
17 | "github.com/stretchr/testify/assert"
18 |
19 | "github.com/xelabs/go-mysqlstack/sqlparser/depends/common"
20 | querypb "github.com/xelabs/go-mysqlstack/sqlparser/depends/query"
21 | "github.com/xelabs/go-mysqlstack/sqlparser/depends/sqltypes"
22 | )
23 |
24 | func TestStatementPrepare(t *testing.T) {
25 | want := &Statement{
26 | ID: 5,
27 | ColumnCount: 2,
28 | ParamCount: 3,
29 | Warnings: 1,
30 | }
31 | datas := PackStatementPrepare(want)
32 | got, err := UnPackStatementPrepare(datas)
33 | assert.Nil(t, err)
34 | assert.Equal(t, want, got)
35 | }
36 |
37 | func TestStatementPrepareUnPackError(t *testing.T) {
38 | // NULL
39 | f0 := func(buff *common.Buffer) {
40 | }
41 |
42 | // Write ok.
43 | f1 := func(buff *common.Buffer) {
44 | buff.WriteU8(OK_PACKET)
45 | }
46 |
47 | // Write ID.
48 | f2 := func(buff *common.Buffer) {
49 | buff.WriteU32(1)
50 | }
51 |
52 | // Write Column count.
53 | f3 := func(buff *common.Buffer) {
54 | buff.WriteU16(1)
55 | }
56 |
57 | // Write param count.
58 | f4 := func(buff *common.Buffer) {
59 | buff.WriteU16(2)
60 | }
61 |
62 | // Write reserved.
63 | f5 := func(buff *common.Buffer) {
64 | buff.WriteU8(2)
65 | }
66 |
67 | f6 := func(buff *common.Buffer) {
68 | buff.WriteU8(2)
69 | }
70 |
71 | buff := common.NewBuffer(32)
72 | fs := []func(buff *common.Buffer){f0, f1, f2, f3, f4, f5, f6}
73 | for i := 0; i < len(fs); i++ {
74 | _, err := UnPackStatementPrepare(buff.Datas())
75 | assert.NotNil(t, err)
76 | fs[i](buff)
77 | }
78 | }
79 |
80 | func TestStatementExecute(t *testing.T) {
81 | id := uint32(11)
82 | values := []sqltypes.Value{
83 | sqltypes.MakeTrusted(sqltypes.Int32, []byte("10")),
84 | sqltypes.MakeTrusted(sqltypes.VarChar, []byte("xx10xx")),
85 | sqltypes.MakeTrusted(sqltypes.Null, nil),
86 | sqltypes.MakeTrusted(sqltypes.Text, []byte{}),
87 | sqltypes.MakeTrusted(sqltypes.Datetime, []byte(time.Now().Format("2006-01-02 15:04:05"))),
88 | }
89 |
90 | datas, err := PackStatementExecute(id, values)
91 | assert.Nil(t, err)
92 |
93 | parseFn := func(*common.Buffer, querypb.Type) (interface{}, error) {
94 | return nil, nil
95 | }
96 |
97 | protoStmt := &Statement{
98 | ID: id,
99 | ParamCount: uint16(len(values)),
100 | ParamsType: make([]int32, len(values)),
101 | BindVars: make(map[string]*querypb.BindVariable, len(values)),
102 | }
103 | err = UnPackStatementExecute(datas, protoStmt, parseFn)
104 | assert.Nil(t, err)
105 | }
106 |
107 | func TestStatementExecuteUnPackError(t *testing.T) {
108 | // NULL
109 | f0 := func(buff *common.Buffer) {
110 | }
111 |
112 | // Write ID.
113 | f1 := func(buff *common.Buffer) {
114 | buff.WriteU32(1)
115 | }
116 |
117 | // Cursor type.
118 | f2 := func(buff *common.Buffer) {
119 | buff.WriteU8(1)
120 | }
121 |
122 | // Iteration count.
123 | f3 := func(buff *common.Buffer) {
124 | buff.WriteU32(1)
125 | }
126 |
127 | // Write param count.
128 | f4 := func(buff *common.Buffer) {
129 | buff.WriteU16(2)
130 | }
131 |
132 | // Write null bits.
133 | f5 := func(buff *common.Buffer) {
134 | buff.WriteBytes([]byte{0x00})
135 | }
136 |
137 | // newParameterBoundFlag.
138 | f6 := func(buff *common.Buffer) {
139 | buff.WriteU8(0x01)
140 | }
141 |
142 | parseFn := func(*common.Buffer, querypb.Type) (interface{}, error) {
143 | return nil, errors.New("mock.error")
144 | }
145 |
146 | buff := common.NewBuffer(32)
147 | fs := []func(buff *common.Buffer){f0, f1, f2, f3, f4, f5, f6}
148 | for i := 0; i < len(fs); i++ {
149 |
150 | protoStmt := &Statement{
151 | ID: 1,
152 | ParamCount: 2,
153 | ParamsType: make([]int32, 2),
154 | BindVars: make(map[string]*querypb.BindVariable, 2),
155 | }
156 |
157 | err := UnPackStatementExecute(buff.Datas(), protoStmt, parseFn)
158 | assert.NotNil(t, err)
159 | fs[i](buff)
160 | }
161 | }
162 |
163 | // issue 462.
164 | // https://dev.mysql.com/doc/internals/en/com-stmt-execute.html
165 | // test about new-params-bound-flag about 0 1
166 | func TestStatementExecuteBatchUnPackStatementExecute(t *testing.T) {
167 | data := []byte{ /*23,*/ 18, 0, 0, 0, 128, 1, 0, 0, 0, 0, 1, 1, 128, 1}
168 | data2 := []byte{ /*23,*/ 18, 0, 0, 0, 128, 1, 0, 0, 0, 0, 0, 1, 128, 1}
169 |
170 | var dataBatch [][]byte
171 | dataBatch = append(dataBatch, data)
172 | dataBatch = append(dataBatch, data2)
173 |
174 | parseFn := func(*common.Buffer, querypb.Type) (interface{}, error) {
175 | return nil, nil
176 | }
177 |
178 | protoStmt := &Statement{
179 | ID: 23,
180 | ParamCount: 1,
181 | ParamsType: make([]int32, 1),
182 | BindVars: make(map[string]*querypb.BindVariable, 1),
183 | }
184 | err := UnPackStatementExecute(dataBatch[0], protoStmt, parseFn)
185 | assert.Nil(t, err)
186 |
187 | err = UnPackStatementExecute(dataBatch[1], protoStmt, parseFn)
188 | assert.Nil(t, err)
189 | }
190 |
--------------------------------------------------------------------------------
/sqldb/constants_test.go:
--------------------------------------------------------------------------------
1 | /*
2 | * go-mysqlstack
3 | * xelabs.org
4 | *
5 | * Copyright (c) XeLabs
6 | * GPL License
7 | *
8 | */
9 |
10 | package sqldb
11 |
12 | import (
13 | "testing"
14 | )
15 |
16 | func TestConstants(t *testing.T) {
17 | var i byte
18 | for i = 0; i < COM_RESET_CONNECTION+2; i++ {
19 | CommandString(i)
20 | }
21 | }
22 |
--------------------------------------------------------------------------------
/sqldb/sql_error.go:
--------------------------------------------------------------------------------
1 | // Copyright 2015, Google Inc. All rights reserved.
2 | // Use of this source code is governed by a BSD-style
3 | // license that can be found in the LICENSE file.
4 |
5 | package sqldb
6 |
7 | import (
8 | "bytes"
9 | "fmt"
10 | "regexp"
11 | "strconv"
12 | )
13 |
const (
	// SQLStateGeneral is the SQLSTATE value for "general error".
	SQLStateGeneral = "HY000"
)

// SQLError is the error structure returned from calling a db library function.
// Error() renders it as "<Message> (errno <Num>) (sqlstate <State>)",
// followed by " during query: <Query>" when Query is non-empty.
type SQLError struct {
	Num     uint16 // error number, printed after "errno"
	State   string // SQLSTATE code, printed after "sqlstate"
	Message string // human-readable text, printed first
	Query   string // optional; appended by Error() only when non-empty
}
26 |
27 | // NewSQLError creates new sql error.
28 | func NewSQLError(number uint16, args ...interface{}) *SQLError {
29 | sqlErr := &SQLError{}
30 | err, ok := SQLErrors[number]
31 | if !ok {
32 | unknow := SQLErrors[ER_UNKNOWN_ERROR]
33 | sqlErr.Num = unknow.Num
34 | sqlErr.State = unknow.State
35 | err = unknow
36 | } else {
37 | sqlErr.Num = err.Num
38 | sqlErr.State = err.State
39 | }
40 | sqlErr.Message = fmt.Sprintf(err.Message, args...)
41 | return sqlErr
42 | }
43 |
44 | func NewSQLErrorf(number uint16, format string, args ...interface{}) *SQLError {
45 | sqlErr := &SQLError{}
46 | err, ok := SQLErrors[number]
47 | if !ok {
48 | unknow := SQLErrors[ER_UNKNOWN_ERROR]
49 | sqlErr.Num = unknow.Num
50 | sqlErr.State = unknow.State
51 | } else {
52 | sqlErr.Num = err.Num
53 | sqlErr.State = err.State
54 | }
55 | sqlErr.Message = fmt.Sprintf(format, args...)
56 | return sqlErr
57 | }
58 |
59 | // NewSQLError1 creates new sql error with state.
60 | func NewSQLError1(number uint16, state string, format string, args ...interface{}) *SQLError {
61 | return &SQLError{
62 | Num: number,
63 | State: state,
64 | Message: fmt.Sprintf(format, args...),
65 | }
66 | }
67 |
68 | // Error implements the error interface
69 | func (se *SQLError) Error() string {
70 | buf := &bytes.Buffer{}
71 | buf.WriteString(se.Message)
72 |
73 | // Add MySQL errno and SQLSTATE in a format that we can later parse.
74 | // There's no avoiding string parsing because all errors
75 | // are converted to strings anyway at RPC boundaries.
76 | // See NewSQLErrorFromError.
77 | fmt.Fprintf(buf, " (errno %v) (sqlstate %v)", se.Num, se.State)
78 |
79 | if se.Query != "" {
80 | fmt.Fprintf(buf, " during query: %s", se.Query)
81 | }
82 | return buf.String()
83 | }
84 |
85 | var errExtract = regexp.MustCompile(`.*\(errno ([0-9]*)\) \(sqlstate ([0-9a-zA-Z]{5})\).*`)
86 |
87 | // NewSQLErrorFromError returns a *SQLError from the provided error.
88 | // If it's not the right type, it still tries to get it from a regexp.
89 | func NewSQLErrorFromError(err error) error {
90 | if err == nil {
91 | return nil
92 | }
93 |
94 | if serr, ok := err.(*SQLError); ok {
95 | return serr
96 | }
97 |
98 | msg := err.Error()
99 | match := errExtract.FindStringSubmatch(msg)
100 | if len(match) < 2 {
101 | // Not found, build a generic SQLError.
102 | // TODO(alainjobart) maybe we can also check the canonical
103 | // error code, and translate that into the right error.
104 |
105 | // FIXME(alainjobart): 1105 is unknown error. Will
106 | // merge with sqlconn later.
107 | unknow := SQLErrors[ER_UNKNOWN_ERROR]
108 | return &SQLError{
109 | Num: unknow.Num,
110 | State: unknow.State,
111 | Message: msg,
112 | }
113 | }
114 |
115 | num, err := strconv.Atoi(match[1])
116 | if err != nil {
117 | unknow := SQLErrors[ER_UNKNOWN_ERROR]
118 | return &SQLError{
119 | Num: unknow.Num,
120 | State: unknow.State,
121 | Message: msg,
122 | }
123 | }
124 |
125 | serr := &SQLError{
126 | Num: uint16(num),
127 | State: match[2],
128 | Message: msg,
129 | }
130 | return serr
131 | }
132 |
--------------------------------------------------------------------------------
/sqldb/sql_error_test.go:
--------------------------------------------------------------------------------
1 | /*
2 | * go-mysqlstack
3 | * xelabs.org
4 | *
5 | * Copyright (c) XeLabs
6 | * GPL License
7 | *
8 | */
9 |
10 | package sqldb
11 |
12 | import (
13 | "testing"
14 |
15 | "errors"
16 | "github.com/stretchr/testify/assert"
17 | )
18 |
19 | func TestSqlError(t *testing.T) {
20 | {
21 | sqlerr := NewSQLError(1, "i.am.error.man")
22 | assert.Equal(t, "i.am.error.man (errno 1105) (sqlstate HY000)", sqlerr.Error())
23 | }
24 |
25 | {
26 | sqlerr := NewSQLErrorf(1, "i.am.error.man%s", "xx")
27 | assert.Equal(t, "i.am.error.manxx (errno 1105) (sqlstate HY000)", sqlerr.Error())
28 | }
29 |
30 | {
31 | sqlerr := NewSQLError(ER_NO_DB_ERROR)
32 | assert.Equal(t, "No database selected (errno 1046) (sqlstate 3D000)", sqlerr.Error())
33 | }
34 | }
35 |
36 | func TestSqlErrorFromErr(t *testing.T) {
37 | {
38 | err := errors.New("errorman")
39 | sqlerr := NewSQLErrorFromError(err)
40 | assert.NotNil(t, sqlerr)
41 | }
42 |
43 | {
44 | err := errors.New("i.am.error.man (errno 1) (sqlstate HY000)")
45 | sqlerr := NewSQLErrorFromError(err)
46 | assert.NotNil(t, sqlerr)
47 | }
48 |
49 | {
50 | err := errors.New("No database selected (errno 1046) (sqlstate 3D000)")
51 | want := &SQLError{Num: 1046, State: "3D000", Message: "No database selected (errno 1046) (sqlstate 3D000)"}
52 | got := NewSQLErrorFromError(err)
53 | assert.Equal(t, want, got)
54 | }
55 |
56 | {
57 | err := NewSQLError1(10086, "xx", "i.am.the.error.man.%s", "xx")
58 | want := &SQLError{Num: 10086, State: "xx", Message: "i.am.the.error.man.xx"}
59 | got := NewSQLErrorFromError(err)
60 | assert.Equal(t, want, got)
61 | }
62 | }
63 |
--------------------------------------------------------------------------------
/sqlparser/.idea/misc.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/sqlparser/.idea/modules.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
--------------------------------------------------------------------------------
/sqlparser/.idea/sqlparser.iml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
--------------------------------------------------------------------------------
/sqlparser/.idea/vcs.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/sqlparser/Makefile:
--------------------------------------------------------------------------------
1 | # Copyright 2012, Google Inc. All rights reserved.
2 | # Use of this source code is governed by a BSD-style license that can
3 | # be found in the LICENSE file.
4 |
# -s: do not echo recipe commands as they run.
MAKEFLAGS = -s

# Generate the parser source sql.go from the grammar sql.y with goyacc.
sql.go: sql.y
	goyacc -o sql.go sql.y

# Run the go:generate directives found in rewriter.go.
visitor:
	go generate rewriter.go

# Remove y.output (goyacc's report file) and the generated sql.go.
clean:
	rm -f y.output sql.go
15 |
--------------------------------------------------------------------------------
/sqlparser/analyzer.go:
--------------------------------------------------------------------------------
1 | /*
2 | Copyright 2017 Google Inc.
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | http://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | */
16 |
17 | package sqlparser
18 |
19 | // analyzer.go contains utility analysis functions.
20 |
21 | import (
22 | "errors"
23 | "fmt"
24 | "strings"
25 | "unicode"
26 |
27 | "github.com/xelabs/go-mysqlstack/sqlparser/depends/sqltypes"
28 | )
29 |
// These constants are used to identify the SQL statement type.
// Preview returns one of them. Their values come from iota starting at 0,
// so inserting or reordering entries changes the numeric values.
const (
	StmtSelect = iota
	StmtInsert
	StmtReplace
	StmtUpdate
	StmtDelete
	StmtDDL      // create, alter, rename, drop
	StmtBegin    // whole statement is "begin" or "start transaction"
	StmtCommit   // whole statement is "commit"
	StmtRollback // whole statement is "rollback"
	StmtSet
	StmtShow
	StmtUse
	StmtOther   // analyze, describe, desc, explain, repair, optimize, truncate
	StmtUnknown // Preview's result when nothing else matches
)
47 |
48 | // Preview analyzes the beginning of the query using a simpler and faster
49 | // textual comparison to identify the statement type.
50 | func Preview(sql string) int {
51 | trimmed := StripLeadingComments(sql)
52 |
53 | firstWord := trimmed
54 | if end := strings.IndexFunc(trimmed, unicode.IsSpace); end != -1 {
55 | firstWord = trimmed[:end]
56 | }
57 |
58 | // Comparison is done in order of priority.
59 | loweredFirstWord := strings.ToLower(firstWord)
60 | switch loweredFirstWord {
61 | case "select":
62 | return StmtSelect
63 | case "insert":
64 | return StmtInsert
65 | case "replace":
66 | return StmtReplace
67 | case "update":
68 | return StmtUpdate
69 | case "delete":
70 | return StmtDelete
71 | }
72 | switch strings.ToLower(trimmed) {
73 | case "begin", "start transaction":
74 | return StmtBegin
75 | case "commit":
76 | return StmtCommit
77 | case "rollback":
78 | return StmtRollback
79 | }
80 | switch loweredFirstWord {
81 | case "create", "alter", "rename", "drop":
82 | return StmtDDL
83 | case "set":
84 | return StmtSet
85 | case "show":
86 | return StmtShow
87 | case "use":
88 | return StmtUse
89 | case "analyze", "describe", "desc", "explain", "repair", "optimize", "truncate":
90 | return StmtOther
91 | }
92 | return StmtUnknown
93 | }
94 |
95 | // IsDML returns true if the query is an INSERT, UPDATE or DELETE statement.
96 | func IsDML(sql string) bool {
97 | switch Preview(sql) {
98 | case StmtInsert, StmtReplace, StmtUpdate, StmtDelete:
99 | return true
100 | }
101 | return false
102 | }
103 |
104 | // GetTableName returns the table name from the SimpleTableExpr
105 | // only if it's a simple expression. Otherwise, it returns "".
106 | func GetTableName(node SimpleTableExpr) TableIdent {
107 | if n, ok := node.(TableName); ok && n.Qualifier.IsEmpty() {
108 | return n.Name
109 | }
110 | // sub-select or '.' expression
111 | return NewTableIdent("")
112 | }
113 |
114 | // IsColName returns true if the Expr is a *ColName.
115 | func IsColName(node Expr) bool {
116 | _, ok := node.(*ColName)
117 | return ok
118 | }
119 |
120 | // IsValue returns true if the Expr is a string, integral or value arg.
121 | // NULL is not considered to be a value.
122 | func IsValue(node Expr) bool {
123 | switch v := node.(type) {
124 | case *SQLVal:
125 | switch v.Type {
126 | case StrVal, HexVal, IntVal, ValArg:
127 | return true
128 | }
129 | case *ValuesFuncExpr:
130 | if v.Resolved != nil {
131 | return IsValue(v.Resolved)
132 | }
133 | }
134 | return false
135 | }
136 |
137 | // IsNull returns true if the Expr is SQL NULL
138 | func IsNull(node Expr) bool {
139 | switch node.(type) {
140 | case *NullVal:
141 | return true
142 | }
143 | return false
144 | }
145 |
146 | // IsSimpleTuple returns true if the Expr is a ValTuple that
147 | // contains simple values or if it's a list arg.
148 | func IsSimpleTuple(node Expr) bool {
149 | switch vals := node.(type) {
150 | case ValTuple:
151 | for _, n := range vals {
152 | if !IsValue(n) {
153 | return false
154 | }
155 | }
156 | return true
157 | case ListArg:
158 | return true
159 | }
160 | // It's a subquery
161 | return false
162 | }
163 |
164 | // NewPlanValue builds a sqltypes.PlanValue from an Expr.
165 | func NewPlanValue(node Expr) (sqltypes.PlanValue, error) {
166 | switch node := node.(type) {
167 | case *SQLVal:
168 | switch node.Type {
169 | case ValArg:
170 | return sqltypes.PlanValue{Key: string(node.Val[1:])}, nil
171 | case IntVal:
172 | n, err := sqltypes.NewIntegral(string(node.Val))
173 | if err != nil {
174 | return sqltypes.PlanValue{}, err
175 | }
176 | return sqltypes.PlanValue{Value: n}, nil
177 | case StrVal:
178 | return sqltypes.PlanValue{Value: sqltypes.MakeTrusted(sqltypes.VarBinary, node.Val)}, nil
179 | case HexVal:
180 | v, err := node.HexDecode()
181 | if err != nil {
182 | return sqltypes.PlanValue{}, err
183 | }
184 | return sqltypes.PlanValue{Value: sqltypes.MakeTrusted(sqltypes.VarBinary, v)}, nil
185 | }
186 | case ListArg:
187 | return sqltypes.PlanValue{ListKey: string(node[2:])}, nil
188 | case ValTuple:
189 | pv := sqltypes.PlanValue{
190 | Values: make([]sqltypes.PlanValue, 0, len(node)),
191 | }
192 | for _, val := range node {
193 | innerpv, err := NewPlanValue(val)
194 | if err != nil {
195 | return sqltypes.PlanValue{}, err
196 | }
197 | if innerpv.ListKey != "" || innerpv.Values != nil {
198 | return sqltypes.PlanValue{}, errors.New("unsupported: nested lists")
199 | }
200 | pv.Values = append(pv.Values, innerpv)
201 | }
202 | return pv, nil
203 | case *ValuesFuncExpr:
204 | if node.Resolved != nil {
205 | return NewPlanValue(node.Resolved)
206 | }
207 | case *NullVal:
208 | return sqltypes.PlanValue{}, nil
209 | }
210 | return sqltypes.PlanValue{}, fmt.Errorf("expression is too complex '%v'", String(node))
211 | }
212 |
// StringIn reports whether str is equal to at least one of values.
// With no values it returns false.
func StringIn(str string, values ...string) bool {
	for i := range values {
		if values[i] == str {
			return true
		}
	}
	return false
}
223 |
--------------------------------------------------------------------------------
/sqlparser/checksum_test.go:
--------------------------------------------------------------------------------
1 | /*
2 | Copyright 2017 Google Inc.
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | http://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | */
16 |
17 | package sqlparser
18 |
19 | import "strings"
20 | import "testing"
21 |
22 | func TestChecksumTable(t *testing.T) {
23 | validSQL := []struct {
24 | input string
25 | output string
26 | }{
27 | {
28 | input: "checksum table test.t1",
29 | output: "checksum table test.t1",
30 | },
31 |
32 | {
33 | input: "checksum table t1",
34 | output: "checksum table t1",
35 | },
36 | }
37 |
38 | for _, s := range validSQL {
39 | sql := strings.TrimSpace(s.input)
40 | tree, err := Parse(sql)
41 | if err != nil {
42 | t.Errorf("input: %s, err: %v", sql, err)
43 | continue
44 | }
45 |
46 | // Walk.
47 | Walk(func(node SQLNode) (bool, error) {
48 | return true, nil
49 | }, tree)
50 |
51 | got := String(tree.(*Checksum))
52 | if s.output != got {
53 | t.Errorf("want:\n%s\ngot:\n%s", s.output, got)
54 | }
55 | }
56 | }
57 |
--------------------------------------------------------------------------------
/sqlparser/depends/bytes2/buffer.go:
--------------------------------------------------------------------------------
1 | /*
2 | Copyright 2017 Google Inc.
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | http://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | */
16 |
17 | package bytes2
18 |
// Buffer implements a subset of the write portion of bytes.Buffer, but more
// efficiently. It is meant for very high QPS operations, especially
// WriteByte, and is deliberately not abstracted behind an io.Writer.
// Method signatures include error results for compatibility with
// bytes.Buffer, but every returned error is nil.
type Buffer struct {
	bytes []byte
}

// NewBuffer is equivalent to bytes.NewBuffer: b becomes the initial
// contents. b is not copied, so later writes may append into its spare
// capacity.
func NewBuffer(b []byte) *Buffer {
	return &Buffer{bytes: b}
}

// Write is equivalent to bytes.Buffer.Write. It appends b and always
// returns len(b), nil.
func (buf *Buffer) Write(b []byte) (int, error) {
	buf.bytes = append(buf.bytes, b...)
	return len(b), nil
}

// WriteString is equivalent to bytes.Buffer.WriteString. It appends s and
// always returns len(s), nil.
func (buf *Buffer) WriteString(s string) (int, error) {
	buf.bytes = append(buf.bytes, s...)
	return len(s), nil
}

// WriteByte is equivalent to bytes.Buffer.WriteByte. It appends b and
// always returns nil.
func (buf *Buffer) WriteByte(b byte) error {
	buf.bytes = append(buf.bytes, b)
	return nil
}

// Bytes is equivalent to bytes.Buffer.Bytes. The returned slice shares the
// buffer's storage; it is not a copy.
func (buf *Buffer) Bytes() []byte {
	return buf.bytes
}

// String is equivalent to bytes.Buffer.String: it returns a copy of the
// accumulated bytes as a string.
func (buf *Buffer) String() string {
	return string(buf.bytes)
}

// Len is equivalent to bytes.Buffer.Len: the number of bytes held,
// including any initial contents passed to NewBuffer.
func (buf *Buffer) Len() int {
	return len(buf.bytes)
}
66 |
--------------------------------------------------------------------------------
/sqlparser/depends/bytes2/buffer_test.go:
--------------------------------------------------------------------------------
1 | /*
2 | Copyright 2017 Google Inc.
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | http://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | */
16 |
17 | package bytes2
18 |
19 | import "testing"
20 |
21 | func TestBuffer(t *testing.T) {
22 | b := NewBuffer(nil)
23 | b.Write([]byte("ab"))
24 | b.WriteString("cd")
25 | b.WriteByte('e')
26 | want := "abcde"
27 | if got := string(b.Bytes()); got != want {
28 | t.Errorf("b.Bytes(): %s, want %s", got, want)
29 | }
30 | if got := b.String(); got != want {
31 | t.Errorf("b.String(): %s, want %s", got, want)
32 | }
33 | if got := b.Len(); got != 5 {
34 | t.Errorf("b.Len(): %d, want 5", got)
35 | }
36 | }
37 |
--------------------------------------------------------------------------------
/sqlparser/depends/common/hash_table.go:
--------------------------------------------------------------------------------
1 | /*
2 | * go-mysqlstack
3 | * xelabs.org
4 | *
5 | * Copyright (c) XeLabs
6 | * GPL License
7 | *
8 | */
9 |
10 | package common
11 |
12 | import (
13 | "bytes"
14 | "hash/fnv"
15 | )
16 |
// hash64a returns the 64-bit FNV-1a hash of data; it selects a key's bucket.
func hash64a(data []byte) uint64 {
	hasher := fnv.New64a()
	hasher.Write(data)
	return hasher.Sum64()
}

// entry is one node in a bucket's singly linked chain. Each distinct key in
// the bucket owns one entry; repeated puts of that key append to value.
type entry struct {
	key   []byte        // key bytes, stored as given (not copied)
	value []interface{} // every value put under key, in insertion order
	next  *entry        // next entry in the chain, or nil
}

// put appends value to the entry whose key equals key, or adds a new entry
// at the tail of the chain when none matches. It returns the chain head
// (a fresh entry when e is nil).
func (e *entry) put(key []byte, value interface{}) *entry {
	if e == nil {
		return &entry{key: key, value: []interface{}{value}}
	}
	for cur := e; ; cur = cur.next {
		if bytes.Equal(cur.key, key) {
			cur.value = append(cur.value, value)
			return e
		}
		if cur.next == nil {
			cur.next = &entry{key: key, value: []interface{}{value}}
			return e
		}
	}
}

// get walks the chain starting at e and returns the values stored under
// key. found is false (and values nil) when no entry matches.
func (e *entry) get(key []byte) (bool, []interface{}) {
	for cur := e; cur != nil; cur = cur.next {
		if bytes.Equal(cur.key, key) {
			return true, cur.value
		}
	}
	return false, nil
}

// HashTable the hash table: a multimap from byte-slice keys to the list of
// values put under each key.
type HashTable struct {
	// chain heads, one per bucket, in order of bucket creation.
	hashEntry []*entry
	// bucket hash -> index into hashEntry.
	hashMap map[uint64]int
	// total number of Put calls (values, not distinct keys).
	size int
}

// NewHashTable create hash table.
func NewHashTable() *HashTable {
	return &HashTable{hashMap: make(map[uint64]int)}
}

// Size returns the number of values stored, counting every Put.
func (h *HashTable) Size() int {
	return h.size
}

// Put appends value to the list stored under key.
func (h *HashTable) Put(key []byte, value interface{}) {
	bucket := hash64a(key)
	if index, ok := h.hashMap[bucket]; ok {
		h.hashEntry[index] = h.hashEntry[index].put(key, value)
	} else {
		h.hashMap[bucket] = len(h.hashEntry)
		h.hashEntry = append(h.hashEntry, &entry{key: key, value: []interface{}{value}})
	}
	h.size++
}

// Get reports whether key is present and returns its values in insertion
// order.
func (h *HashTable) Get(key []byte) (bool, []interface{}) {
	index, ok := h.hashMap[hash64a(key)]
	if !ok {
		return false, nil
	}
	return h.hashEntry[index].get(key)
}

// Iterator used to iterate the HashTable. Each call yields one key with its
// values plus the iterator for the next key; a nil key marks the end.
type Iterator func() (key []byte, value []interface{}, next Iterator)

// Next returns an iterator over the table. Buckets are visited in creation
// order and each bucket's chain from head to tail.
func (h *HashTable) Next() Iterator {
	chains := h.hashEntry
	pos := 0
	var cur *entry
	var it Iterator
	it = func() ([]byte, []interface{}, Iterator) {
		for cur == nil {
			if pos >= len(chains) {
				return nil, nil, nil
			}
			cur = chains[pos]
			pos++
		}
		k, v := cur.key, cur.value
		cur = cur.next
		return k, v, it
	}
	return it
}
128 |
--------------------------------------------------------------------------------
/sqlparser/depends/common/hash_table_test.go:
--------------------------------------------------------------------------------
1 | /*
2 | * go-mysqlstack
3 | * xelabs.org
4 | *
5 | * Copyright (c) XeLabs
6 | * GPL License
7 | *
8 | */
9 |
10 | package common
11 |
12 | import (
13 | "bytes"
14 | crand "crypto/rand"
15 | "testing"
16 | )
17 |
18 | func TestHashKey(t *testing.T) {
19 | a := []byte("asdf")
20 | b := []byte("asdf")
21 | c := []byte("csfd")
22 | if !bytes.Equal(a, b) {
23 | t.Error("a != b")
24 | }
25 | if hash64a(a) != hash64a(b) {
26 | t.Error("hash64a(a) != hash64a(b)")
27 | }
28 | if bytes.Equal(a, c) {
29 | t.Error("a == c")
30 | }
31 | if hash64a(a) == hash64a(c) {
32 | t.Error("hash64a(a) == hash64a(c)")
33 | }
34 | }
35 |
// randSlice returns length bytes read from crypto/rand. It panics if the
// randomness source fails, which is acceptable in test helpers.
func randSlice(length int) []byte {
	buf := make([]byte, length)
	_, err := crand.Read(buf)
	if err != nil {
		panic(err)
	}
	return buf
}
43 |
44 | func TestPutHasGetRemove(t *testing.T) {
45 |
46 | type record struct {
47 | key []byte
48 | val []byte
49 | }
50 |
51 | ranrec := func() *record {
52 | return &record{
53 | randSlice(20),
54 | randSlice(20),
55 | }
56 | }
57 |
58 | table := NewHashTable()
59 | records := make([]*record, 400)
60 | var i int
61 | for i = range records {
62 | r := ranrec()
63 | records[i] = r
64 | table.Put(r.key, []byte(""))
65 | table.Put(r.key, r.val)
66 |
67 | if table.Size() != 2*(i+1) {
68 | t.Error("size was wrong", table.Size(), i+1)
69 | }
70 | }
71 |
72 | for _, r := range records {
73 | if has, val := table.Get(r.key); !has {
74 | t.Error(table, "Missing key")
75 | } else if !bytes.Equal(val[1].([]byte), r.val) {
76 | t.Error("wrong value")
77 | }
78 | if has, _ := table.Get(randSlice(12)); has {
79 | t.Error("Table has extra key")
80 | }
81 | }
82 | }
83 |
84 | func TestIterate(t *testing.T) {
85 | table := NewHashTable()
86 | t.Logf("%T", table)
87 | for k, v, next := table.Next()(); next != nil; k, v, next = next() {
88 | t.Errorf("Should never reach here %v %v %v", k, v, next)
89 | }
90 | records := make(map[string][]byte)
91 | for i := 0; i < 100; i++ {
92 | v := randSlice(8)
93 | keySlice := []byte{0x01}
94 | keySlice = append(keySlice, v...)
95 | keySlice = append(keySlice, 0x02)
96 | k := BytesToString(keySlice)
97 | records[k] = v
98 | table.Put(v, k)
99 | if table.Size() != (i + 1) {
100 | t.Error("size was wrong", table.Size(), i+1)
101 | }
102 | }
103 | count := 0
104 | for k, v, next := table.Next()(); next != nil; k, v, next = next() {
105 | if v1, has := records[v[0].(string)]; !has {
106 | t.Error("bad key in table")
107 | } else if !bytes.Equal(k, v1) {
108 | t.Error("values don't equal")
109 | }
110 | count++
111 | }
112 | if len(records) != count {
113 | t.Error("iterate missed records")
114 | }
115 | }
116 |
--------------------------------------------------------------------------------
/sqlparser/depends/common/unsafe.go:
--------------------------------------------------------------------------------
1 | /*
2 | * go-mysqlstack
3 | * xelabs.org
4 | *
5 | * Copyright (c) XeLabs
6 | * GPL License
7 | *
8 | */
9 |
10 | package common
11 |
12 | import (
13 | "reflect"
14 | "unsafe"
15 | )
16 |
// BytesToString returns a string that shares b's backing memory, without
// copying. Because Go strings are assumed immutable, the caller must not
// modify b while the returned string is in use. An empty or nil b yields "".
func BytesToString(b []byte) (s string) {
	if len(b) == 0 {
		return ""
	}
	// A slice header starts with the same (data pointer, length) pair as a
	// string header, so the slice variable can be reinterpreted in place.
	// Unlike building a reflect.StringHeader value (whose Data is a plain
	// uintptr), this keeps the data pointer visible to the garbage collector.
	return *(*string)(unsafe.Pointer(&b))
}
28 |
// StringToBytes returns a []byte that shares s's backing memory, without
// copying. The result must be treated as read-only: writing through it would
// mutate an immutable string and can fault if the data is read-only. An empty
// s yields a non-nil, zero-length slice.
func StringToBytes(s string) []byte {
	if len(s) == 0 {
		return []byte{}
	}
	// Fill in the header of a real slice variable through *reflect.SliceHeader,
	// the pattern the unsafe package documents as valid. Constructing a
	// reflect.SliceHeader value and casting it would carry the data pointer
	// only as a uintptr, invisible to the garbage collector.
	var b []byte
	bh := (*reflect.SliceHeader)(unsafe.Pointer(&b))
	sh := (*reflect.StringHeader)(unsafe.Pointer(&s))
	bh.Data = sh.Data
	bh.Len = sh.Len
	bh.Cap = sh.Len
	return b
}
40 |
--------------------------------------------------------------------------------
/sqlparser/depends/common/unsafe_test.go:
--------------------------------------------------------------------------------
1 | /*
2 | * go-mysqlstack
3 | * xelabs.org
4 | *
5 | * Copyright (c) XeLabs
6 | * GPL License
7 | *
8 | */
9 |
10 | package common
11 |
12 | import (
13 | "github.com/stretchr/testify/assert"
14 | "testing"
15 | )
16 |
17 | func TestBytesToString(t *testing.T) {
18 | {
19 | bs := []byte{0x61, 0x62}
20 | want := "ab"
21 | got := BytesToString(bs)
22 | assert.Equal(t, want, got)
23 | }
24 |
25 | {
26 | bs := []byte{}
27 | want := ""
28 | got := BytesToString(bs)
29 | assert.Equal(t, want, got)
30 | }
31 | }
32 |
33 | func TestSting(t *testing.T) {
34 | {
35 | want := []byte{0x61, 0x62}
36 | got := StringToBytes("ab")
37 | assert.Equal(t, want, got)
38 | }
39 |
40 | {
41 | want := []byte{}
42 | got := StringToBytes("")
43 | assert.Equal(t, want, got)
44 | }
45 | }
46 |
47 | func TestStingToBytes(t *testing.T) {
48 | {
49 | want := []byte{0x53, 0x45, 0x4c, 0x45, 0x43, 0x54, 0x20, 0x2a, 0x20, 0x46, 0x52, 0x4f, 0x4d, 0x20, 0x74, 0x32}
50 | got := StringToBytes("SELECT * FROM t2")
51 | assert.Equal(t, want, got)
52 | }
53 | }
54 |
--------------------------------------------------------------------------------
/sqlparser/depends/sqltypes/column.go:
--------------------------------------------------------------------------------
1 | // Copyright 2015, Google Inc. All rights reserved.
2 | // Use of this source code is governed by a BSD-style
3 | // license that can be found in the LICENSE file.
4 | //
5 | // Copyright (c) XeLabs
6 | // BohuTANG
7 |
8 | package sqltypes
9 |
10 | import (
11 | querypb "github.com/xelabs/go-mysqlstack/sqlparser/depends/query"
12 | )
13 |
14 | // RemoveColumns used to remove columns who in the idxs.
15 | func (result *Result) RemoveColumns(idxs ...int) {
16 | c := len(idxs)
17 | if c == 0 {
18 | return
19 | }
20 |
21 | if result.Fields != nil {
22 | var fields []*querypb.Field
23 | for i, f := range result.Fields {
24 | in := false
25 | for _, idx := range idxs {
26 | if i == idx {
27 | in = true
28 | break
29 | }
30 | }
31 | if !in {
32 | fields = append(fields, f)
33 | }
34 | }
35 | result.Fields = fields
36 | }
37 |
38 | if result.Rows != nil {
39 | for i, r := range result.Rows {
40 | var row []Value
41 | for i, v := range r {
42 | in := false
43 | for _, idx := range idxs {
44 | if i == idx {
45 | in = true
46 | break
47 | }
48 | }
49 | if !in {
50 | row = append(row, v)
51 | }
52 | }
53 | result.Rows[i] = row
54 | }
55 | }
56 | }
57 |
--------------------------------------------------------------------------------
/sqlparser/depends/sqltypes/column_test.go:
--------------------------------------------------------------------------------
1 | // Copyright 2015, Google Inc. All rights reserved.
2 | // Use of this source code is governed by a BSD-style
3 | // license that can be found in the LICENSE file.
4 | //
5 | // Copyright (c) XeLabs
6 | // BohuTANG
7 |
8 | package sqltypes
9 |
10 | import (
11 | "fmt"
12 | "reflect"
13 | "testing"
14 |
15 | querypb "github.com/xelabs/go-mysqlstack/sqlparser/depends/query"
16 | )
17 |
18 | func TestColumnRemove(t *testing.T) {
19 | rt := &Result{
20 | Fields: []*querypb.Field{{
21 | Name: "a",
22 | Type: Int32,
23 | }, {
24 | Name: "b",
25 | Type: Uint24,
26 | }, {
27 | Name: "c",
28 | Type: Float32,
29 | },
30 | },
31 | Rows: [][]Value{
32 | {testVal(Int32, "-5"), testVal(Uint64, "10"), testVal(Float32, "3.1415926")},
33 | {testVal(Int32, "-4"), testVal(Uint64, "9"), testVal(Float32, "3.1415927")},
34 | {testVal(Int32, "-3"), testVal(Uint64, "8"), testVal(Float32, "3.1415928")},
35 | {testVal(Int32, "1"), testVal(Uint64, "1"), testVal(Float32, "3.1415926")},
36 | {testVal(Int32, "1"), testVal(Uint64, "1"), testVal(Float32, "3.1415925")},
37 | },
38 | }
39 |
40 | {
41 | rs := rt.Copy()
42 | rs.RemoveColumns(0)
43 | {
44 | want := []*querypb.Field{
45 | {
46 | Name: "b",
47 | Type: Uint24,
48 | }, {
49 | Name: "c",
50 | Type: Float32,
51 | },
52 | }
53 | got := rs.Fields
54 | if !reflect.DeepEqual(want, got) {
55 | t.Errorf("want:%+v\n, got:%+v", want, got)
56 | }
57 | }
58 |
59 | {
60 | want := "[[10 3.1415926] [9 3.1415927] [8 3.1415928] [1 3.1415926] [1 3.1415925]]"
61 | got := fmt.Sprintf("%+v", rs.Rows)
62 | if want != got {
63 | t.Errorf("want:%s\n, got:%+s", want, got)
64 | }
65 | }
66 | }
67 |
68 | {
69 | rs := rt.Copy()
70 | rs.RemoveColumns(2)
71 | {
72 | want := []*querypb.Field{
73 | {
74 | Name: "a",
75 | Type: Int32,
76 | }, {
77 | Name: "b",
78 | Type: Uint24,
79 | },
80 | }
81 | got := rs.Fields
82 | if !reflect.DeepEqual(want, got) {
83 | t.Errorf("want:%+v\n, got:%+v", want, got)
84 | }
85 | }
86 |
87 | {
88 | want := "[[-5 10] [-4 9] [-3 8] [1 1] [1 1]]"
89 | got := fmt.Sprintf("%+v", rs.Rows)
90 | if want != got {
91 | t.Errorf("want:%s\n, got:%s", want, got)
92 | }
93 | }
94 | }
95 |
96 | {
97 | rs := rt.Copy()
98 | rs.RemoveColumns(0, 1)
99 | {
100 | want := []*querypb.Field{
101 | {
102 | Name: "c",
103 | Type: Float32,
104 | },
105 | }
106 | got := rs.Fields
107 | if !reflect.DeepEqual(want, got) {
108 | t.Errorf("want:%+v\n, got:%+v", want, got)
109 | }
110 | }
111 |
112 | {
113 | want := "[[3.1415926] [3.1415927] [3.1415928] [3.1415926] [3.1415925]]"
114 | got := fmt.Sprintf("%+v", rs.Rows)
115 | if want != got {
116 | t.Errorf("want:%s\n, got:%s", want, got)
117 | }
118 | }
119 | }
120 | }
121 |
--------------------------------------------------------------------------------
/sqlparser/depends/sqltypes/const.go:
--------------------------------------------------------------------------------
1 | /*
2 | * Radon
3 | *
4 | * Copyright 2019 The Radon Authors.
5 | * Code is licensed under the GPLv3.
6 | *
7 | */
8 |
9 | package sqltypes
10 |
// Digit counts used when rendering numeric values.
const (
	// DoubleDigits is the number of decimal digits of precision for double values.
	DoubleDigits = 15
	// FloatDigits is the number of decimal digits of precision for float values.
	FloatDigits = 6
	// DecimalLongLongDigits is the number of decimal digits for a longlong value.
	DecimalLongLongDigits = 22
)
19 |
--------------------------------------------------------------------------------
/sqlparser/depends/sqltypes/limit.go:
--------------------------------------------------------------------------------
1 | // Copyright 2015, Google Inc. All rights reserved.
2 | // Use of this source code is governed by a BSD-style
3 | // license that can be found in the LICENSE file.
4 | //
5 | // Copyright (c) XeLabs
6 | // BohuTANG
7 |
8 | package sqltypes
9 |
10 | // Limit used to cutoff the rows based on the MySQL LIMIT and OFFSET clauses.
11 | func (result *Result) Limit(offset, limit int) {
12 | count := len(result.Rows)
13 | start := offset
14 | end := offset + limit
15 | if start > count {
16 | start = count
17 | }
18 | if end > count {
19 | end = count
20 | }
21 | result.Rows = result.Rows[start:end]
22 | }
23 |
--------------------------------------------------------------------------------
/sqlparser/depends/sqltypes/limit_test.go:
--------------------------------------------------------------------------------
1 | // Copyright 2015, Google Inc. All rights reserved.
2 | // Use of this source code is governed by a BSD-style
3 | // license that can be found in the LICENSE file.
4 | //
5 | // Copyright (c) XeLabs
6 | // BohuTANG
7 |
8 | package sqltypes
9 |
10 | import (
11 | "reflect"
12 | "testing"
13 | )
14 |
15 | func TestLimit(t *testing.T) {
16 | rs := &Result{
17 | Rows: [][]Value{
18 | {testVal(VarChar, "1")}, {testVal(VarChar, "2")}, {testVal(VarChar, "3")}, {testVal(VarChar, "4")}, {testVal(VarChar, "5")},
19 | },
20 | }
21 |
22 | // normal: offset 0, limit 1.
23 | {
24 | rs1 := rs.Copy()
25 | rs1.Limit(0, 1)
26 | want := rs.Rows[0:1]
27 | got := rs1.Rows
28 |
29 | if !reflect.DeepEqual(want, got) {
30 | t.Errorf("want:\n%#v, got\n%#v", want, got)
31 | }
32 | }
33 |
34 | // normal: offset 0, limit 5.
35 | {
36 | rs1 := rs.Copy()
37 | rs1.Limit(0, 5)
38 | want := rs.Rows
39 | got := rs1.Rows
40 |
41 | if !reflect.DeepEqual(want, got) {
42 | t.Errorf("want:\n%#v, got\n%#v", want, got)
43 | }
44 | }
45 |
46 | // normal: offset 1, limit 4.
47 | {
48 | rs1 := rs.Copy()
49 | rs1.Limit(1, 4)
50 | want := rs.Rows[1:5]
51 | got := rs1.Rows
52 |
53 | if !reflect.DeepEqual(want, got) {
54 | t.Errorf("want:\n%#v, got\n%#v", want, got)
55 | }
56 | }
57 |
58 | // limit overflow: offset 0, limit 6.
59 | {
60 | rs1 := rs.Copy()
61 | rs1.Limit(0, 6)
62 | want := rs.Rows
63 | got := rs1.Rows
64 |
65 | if !reflect.DeepEqual(want, got) {
66 | t.Errorf("want:\n%#v, got\n%#v", want, got)
67 | }
68 | }
69 |
70 | // offset overflow: offset 5, limit 0.
71 | {
72 | rs1 := rs.Copy()
73 | rs1.Limit(5, 0)
74 | want := rs.Rows[5:5]
75 | got := rs1.Rows
76 |
77 | if !reflect.DeepEqual(want, got) {
78 | t.Errorf("want:\n%#v, got\n%#v", want, got)
79 | }
80 | }
81 |
82 | // (offset+limit) overflow: offset 3, limit 6.
83 | {
84 | rs1 := rs.Copy()
85 | rs1.Limit(3, 6)
86 | want := rs.Rows[3:5]
87 | got := rs1.Rows
88 |
89 | if !reflect.DeepEqual(want, got) {
90 | t.Errorf("want:\n%#v, got\n%#v", want, got)
91 | }
92 | }
93 |
94 | // Empty test.
95 | {
96 | rs1 := &Result{
97 | Rows: [][]Value{
98 | {},
99 | },
100 | }
101 |
102 | rs1.Limit(3, 6)
103 | want := rs.Rows[0:0]
104 | got := rs1.Rows
105 |
106 | if !reflect.DeepEqual(want, got) {
107 | t.Errorf("want:\n%#v, got\n%#v", want, got)
108 | }
109 | }
110 | }
111 |
--------------------------------------------------------------------------------
/sqlparser/depends/sqltypes/plan_value.go:
--------------------------------------------------------------------------------
1 | /*
2 | Copyright 2017 Google Inc.
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | http://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | */
16 |
17 | package sqltypes
18 |
19 | import ()
20 |
// PlanValue represents a value or a list of values for
// a column that will later be resolved using bind vars and used
// to perform plan actions like generating the final query or
// deciding on a route.
//
// Plan values are typically used as a slice ([]PlanValue)
// where each entry is for one column. For situations where
// the required output is a list of rows (like in the case
// of multi-value inserts), the representation is pivoted.
// For example, a statement like this:
//	INSERT INTO t VALUES (1, 2), (3, 4)
// will be represented as follows:
//	[]PlanValue{
//		Values: {1, 3},
//		Values: {2, 4},
//	}
//
// For WHERE clause items that contain a combination of
// equality expressions and IN clauses like this:
//	WHERE pk1 = 1 AND pk2 IN (2, 3, 4)
// The plan values will be represented as follows:
//	[]PlanValue{
//		Value: 1,
//		Values: {2, 3, 4},
//	}
// When converted into rows, columns with single values
// are replicated as the same for all rows:
//	[][]Value{
//		{1, 2},
//		{1, 3},
//		{1, 4},
//	}
type PlanValue struct {
	// Key presumably names a bind variable that supplies a single value
	// (TODO: confirm against the resolver).
	Key string
	// Value holds a single value for the column, as in "Value: 1" above.
	Value Value
	// ListKey presumably names a bind variable that supplies a list of
	// values (TODO: confirm against the resolver).
	ListKey string
	// Values holds a list of values for the column, as in "Values: {2, 3, 4}" above.
	Values []PlanValue
}
59 |
--------------------------------------------------------------------------------
/sqlparser/depends/sqltypes/result.go:
--------------------------------------------------------------------------------
1 | // Copyright 2015, Google Inc. All rights reserved.
2 | // Use of this source code is governed by a BSD-style
3 | // license that can be found in the LICENSE file.
4 |
5 | package sqltypes
6 |
7 | import (
8 | querypb "github.com/xelabs/go-mysqlstack/sqlparser/depends/query"
9 | )
10 |
// ResultState is the value type of Result.State. Judging by the constant
// names, a Result moves through none, fields, rows and finished; the code
// that sets State is not shown here, so confirm with its producers.
type ResultState int

// ResultState values, numbered explicitly from zero.
const (
	// RStateNone is the zero ResultState.
	RStateNone ResultState = 0
	// RStateFields is ResultState 1.
	RStateFields ResultState = 1
	// RStateRows is ResultState 2.
	RStateRows ResultState = 2
	// RStateFinished is ResultState 3.
	RStateFinished ResultState = 3
)
24 |
// Result represents a query result.
//
// Rows holds one []Value per row; Repair indexes each row by field position,
// so rows are expected to line up with Fields. State has no json tag, so
// encoding/json writes it under the key "State".
type Result struct {
	Fields       []*querypb.Field      `json:"fields"`
	RowsAffected uint64                `json:"rows_affected"`
	InsertID     uint64                `json:"insert_id"`
	Warnings     uint16                `json:"warnings"`
	Rows         [][]Value             `json:"rows"`
	Extras       *querypb.ResultExtras `json:"extras"`
	State        ResultState
}
35 |
// ResultStream is an interface for receiving Result. It is used for
// RPC interfaces.
// Callers keep calling Recv until it returns an error.
type ResultStream interface {
	// Recv returns the next result on the stream.
	// It will return io.EOF if the stream ended.
	Recv() (*Result, error)
}
43 |
44 | // Repair fixes the type info in the rows
45 | // to conform to the supplied field types.
46 | func (result *Result) Repair(fields []*querypb.Field) {
47 | // Usage of j is intentional.
48 | for j, f := range fields {
49 | for _, r := range result.Rows {
50 | if r[j].typ != Null {
51 | r[j].typ = f.Type
52 | }
53 | }
54 | }
55 | }
56 |
57 | // Copy creates a deep copy of Result.
58 | func (result *Result) Copy() *Result {
59 | out := &Result{
60 | InsertID: result.InsertID,
61 | RowsAffected: result.RowsAffected,
62 | }
63 | if result.Fields != nil {
64 | fieldsp := make([]*querypb.Field, len(result.Fields))
65 | fields := make([]querypb.Field, len(result.Fields))
66 | for i, f := range result.Fields {
67 | fields[i] = *f
68 | fieldsp[i] = &fields[i]
69 | }
70 | out.Fields = fieldsp
71 | }
72 | if result.Rows != nil {
73 | rows := make([][]Value, len(result.Rows))
74 | for i, r := range result.Rows {
75 | rows[i] = make([]Value, len(r))
76 | totalLen := 0
77 | for _, c := range r {
78 | totalLen += len(c.val)
79 | }
80 | arena := make([]byte, 0, totalLen)
81 | for j, c := range r {
82 | start := len(arena)
83 | arena = append(arena, c.val...)
84 | rows[i][j] = MakeTrusted(c.typ, arena[start:start+len(c.val)])
85 | }
86 | }
87 | out.Rows = rows
88 | }
89 | return out
90 | }
91 |
92 | // StripFieldNames will return a new Result that has the same Rows,
93 | // but the Field objects will have their Name emptied. Note we don't
94 | // proto.Copy each Field for performance reasons, but we only copy the
95 | // individual fields.
96 | func (result *Result) StripFieldNames() *Result {
97 | if len(result.Fields) == 0 {
98 | return result
99 | }
100 | r := *result
101 | r.Fields = make([]*querypb.Field, len(result.Fields))
102 | newFieldsArray := make([]querypb.Field, len(result.Fields))
103 | for i, f := range result.Fields {
104 | r.Fields[i] = &newFieldsArray[i]
105 | newFieldsArray[i].Type = f.Type
106 | }
107 | return &r
108 | }
109 |
110 | // AppendResult will combine the Results Objects of one result
111 | // to another result.Note currently it doesn't handle cases like
112 | // if two results have different fields.We will enhance this function.
113 | func (result *Result) AppendResult(src *Result) {
114 | if src.RowsAffected == 0 && len(src.Fields) == 0 {
115 | return
116 | }
117 | if result.Fields == nil {
118 | result.Fields = src.Fields
119 | }
120 | result.RowsAffected += src.RowsAffected
121 | if src.InsertID != 0 {
122 | result.InsertID = src.InsertID
123 | }
124 | if len(src.Rows) != 0 {
125 | result.Rows = append(result.Rows, src.Rows...)
126 | }
127 | }
128 |
--------------------------------------------------------------------------------
/sqlparser/depends/sqltypes/result_test.go:
--------------------------------------------------------------------------------
1 | // Copyright 2015, Google Inc. All rights reserved.
2 | // Use of this source code is governed by a BSD-style
3 | // license that can be found in the LICENSE file.
4 |
5 | package sqltypes
6 |
7 | import (
8 | "reflect"
9 | "testing"
10 |
11 | querypb "github.com/xelabs/go-mysqlstack/sqlparser/depends/query"
12 | )
13 |
14 | func TestRepair(t *testing.T) {
15 | fields := []*querypb.Field{{
16 | Type: Int64,
17 | }, {
18 | Type: VarChar,
19 | }}
20 | in := Result{
21 | Rows: [][]Value{
22 | {testVal(VarBinary, "1"), testVal(VarBinary, "aa")},
23 | {testVal(VarBinary, "2"), testVal(VarBinary, "bb")},
24 | },
25 | }
26 | want := Result{
27 | Rows: [][]Value{
28 | {testVal(Int64, "1"), testVal(VarChar, "aa")},
29 | {testVal(Int64, "2"), testVal(VarChar, "bb")},
30 | },
31 | }
32 | in.Repair(fields)
33 | if !reflect.DeepEqual(in, want) {
34 | t.Errorf("Repair:\n%#v, want\n%#v", in, want)
35 | }
36 | }
37 |
38 | func TestCopy(t *testing.T) {
39 | in := &Result{
40 | Fields: []*querypb.Field{{
41 | Type: Int64,
42 | }, {
43 | Type: VarChar,
44 | }},
45 | InsertID: 1,
46 | RowsAffected: 2,
47 | Rows: [][]Value{
48 | {testVal(Int64, "1"), NULL},
49 | {testVal(Int64, "2"), MakeTrusted(VarChar, nil)},
50 | {testVal(Int64, "3"), testVal(VarChar, "")},
51 | },
52 | }
53 | want := &Result{
54 | Fields: []*querypb.Field{{
55 | Type: Int64,
56 | }, {
57 | Type: VarChar,
58 | }},
59 | InsertID: 1,
60 | RowsAffected: 2,
61 | Rows: [][]Value{
62 | {testVal(Int64, "1"), NULL},
63 | {testVal(Int64, "2"), testVal(VarChar, "")},
64 | {testVal(Int64, "3"), testVal(VarChar, "")},
65 | },
66 | }
67 | out := in.Copy()
68 | // Change in so we're sure out got actually copied
69 | in.Fields[0].Type = VarChar
70 | in.Rows[0][0] = testVal(VarChar, "aa")
71 | if !reflect.DeepEqual(out, want) {
72 | t.Errorf("Copy:\n%#v, want\n%#v", out, want)
73 | }
74 | }
75 |
76 | func TestStripFieldNames(t *testing.T) {
77 | testcases := []struct {
78 | name string
79 | in *Result
80 | expected *Result
81 | }{{
82 | name: "no fields",
83 | in: &Result{},
84 | expected: &Result{},
85 | }, {
86 | name: "empty fields",
87 | in: &Result{
88 | Fields: []*querypb.Field{},
89 | },
90 | expected: &Result{
91 | Fields: []*querypb.Field{},
92 | },
93 | }, {
94 | name: "no name",
95 | in: &Result{
96 | Fields: []*querypb.Field{{
97 | Type: Int64,
98 | }, {
99 | Type: VarChar,
100 | }},
101 | },
102 | expected: &Result{
103 | Fields: []*querypb.Field{{
104 | Type: Int64,
105 | }, {
106 | Type: VarChar,
107 | }},
108 | },
109 | }, {
110 | name: "names",
111 | in: &Result{
112 | Fields: []*querypb.Field{{
113 | Name: "field1",
114 | Type: Int64,
115 | }, {
116 | Name: "field2",
117 | Type: VarChar,
118 | }},
119 | },
120 | expected: &Result{
121 | Fields: []*querypb.Field{{
122 | Type: Int64,
123 | }, {
124 | Type: VarChar,
125 | }},
126 | },
127 | }}
128 | for _, tcase := range testcases {
129 | inCopy := tcase.in.Copy()
130 | out := inCopy.StripFieldNames()
131 | if !reflect.DeepEqual(out, tcase.expected) {
132 | t.Errorf("StripFieldNames unexpected result for %v: %v", tcase.name, out)
133 | }
134 | if len(tcase.in.Fields) > 0 {
135 | // check the out array is different than the in array.
136 | if out.Fields[0] == inCopy.Fields[0] {
137 | t.Errorf("StripFieldNames modified original Field for %v", tcase.name)
138 | }
139 | }
140 | // check we didn't change the original result.
141 | if !reflect.DeepEqual(tcase.in, inCopy) {
142 | t.Errorf("StripFieldNames modified original result")
143 | }
144 | }
145 | }
146 |
147 | func TestAppendResult(t *testing.T) {
148 | r1 := &Result{
149 | RowsAffected: 3,
150 | Rows: [][]Value{
151 | {testVal(VarBinary, "1"), testVal(VarBinary, "aa")},
152 | },
153 | }
154 | r2 := &Result{
155 | RowsAffected: 5,
156 | Rows: [][]Value{
157 | {testVal(VarBinary, "2"), testVal(VarBinary, "aa2")},
158 | },
159 | }
160 | r1.AppendResult(r2)
161 |
162 | got := r1
163 | want := &Result{
164 | RowsAffected: 8,
165 | Rows: [][]Value{
166 | {testVal(VarBinary, "1"), testVal(VarBinary, "aa")},
167 | {testVal(VarBinary, "2"), testVal(VarBinary, "aa2")},
168 | },
169 | }
170 | if !reflect.DeepEqual(got, want) {
171 | t.Errorf("Append:\n%#v, want\n%#v", got, want)
172 | }
173 | }
174 |
--------------------------------------------------------------------------------
/sqlparser/depends/sqltypes/row.go:
--------------------------------------------------------------------------------
1 | // Copyright 2015, Google Inc. All rights reserved.
2 | // Use of this source code is governed by a BSD-style
3 | // license that can be found in the LICENSE file.
4 | //
5 | // Copyright (c) XeLabs
6 | // BohuTANG
7 |
8 | package sqltypes
9 |
10 | // Row operations.
11 | type Row []Value
12 |
13 | // Copy used to clone the new value.
14 | func (r Row) Copy() []Value {
15 | ret := make([]Value, len(r))
16 | for i, v := range r {
17 | ret[i] = v
18 | }
19 | return ret
20 | }
21 |
--------------------------------------------------------------------------------
/sqlparser/depends/sqltypes/time.go:
--------------------------------------------------------------------------------
1 | /*
2 | * go-mysqlstack
3 | * xelabs.org
4 | *
5 | * Copyright (c) XeLabs
6 | * GPL License
7 | */
8 |
9 | package sqltypes
10 |
11 | import (
12 | "errors"
13 | "strconv"
14 | "strings"
15 |
16 | querypb "github.com/xelabs/go-mysqlstack/sqlparser/depends/query"
17 | )
18 |
19 | // timeToNumeric used to cast time type to numeric.
20 | func timeToNumeric(v Value) (numeric, error) {
21 | switch v.Type() {
22 | case querypb.Type_TIMESTAMP, querypb.Type_DATETIME:
23 | var i int64
24 | year, err := strconv.ParseInt(string(v.val[0:4]), 10, 16)
25 | if err != nil {
26 | return numeric{}, err
27 | }
28 | month, err := strconv.ParseInt(string(v.val[5:7]), 10, 8)
29 | if err != nil {
30 | return numeric{}, err
31 | }
32 | day, err := strconv.ParseInt(string(v.val[8:10]), 10, 8)
33 | if err != nil {
34 | return numeric{}, err
35 | }
36 | hour, err := strconv.ParseInt(string(v.val[11:13]), 10, 8)
37 | if err != nil {
38 | return numeric{}, err
39 | }
40 | minute, err := strconv.ParseInt(string(v.val[14:16]), 10, 8)
41 | if err != nil {
42 | return numeric{}, err
43 | }
44 | second, err := strconv.ParseInt(string(v.val[17:19]), 10, 8)
45 | if err != nil {
46 | return numeric{}, err
47 | }
48 |
49 | i = (year*10000+month*100+day)*1000000 + (hour*10000 + minute*100 + second)
50 | if len(v.val) > 19 {
51 | var f float64
52 | microSecond, err := strconv.ParseUint(string(v.val[20:]), 10, 32)
53 | if err != nil {
54 | return numeric{}, err
55 | }
56 |
57 | microSec := float64(microSecond)
58 | n := len(v.val[20:])
59 | for n != 0 {
60 | microSec *= 0.1
61 | n--
62 | }
63 | f = float64(i) + microSec
64 | return numeric{fval: f, typ: Float64}, nil
65 | }
66 | return numeric{ival: i, typ: Int64}, nil
67 | case querypb.Type_DATE:
68 | var i int64
69 | year, err := strconv.ParseInt(string(v.val[0:4]), 10, 16)
70 | if err != nil {
71 | return numeric{}, err
72 | }
73 | month, err := strconv.ParseInt(string(v.val[5:7]), 10, 8)
74 | if err != nil {
75 | return numeric{}, err
76 | }
77 | day, err := strconv.ParseInt(string(v.val[8:]), 10, 8)
78 | if err != nil {
79 | return numeric{}, err
80 | }
81 | i = year*10000 + month*100 + day
82 | return numeric{ival: i, typ: Int64}, nil
83 | case querypb.Type_TIME:
84 | var i int64
85 | sub := strings.Split(string(v.val), ":")
86 | if len(sub) != 3 {
87 | return numeric{}, errors.New("incorrect.time.value,':'.is.not.found")
88 | }
89 |
90 | pre := int64(1)
91 | if strings.HasPrefix(sub[0], "-") {
92 | pre = -1
93 | sub[0] = sub[0][1:]
94 | }
95 |
96 | hour, err := strconv.ParseInt(string(sub[0]), 10, 32)
97 | if err != nil {
98 | return numeric{}, err
99 | }
100 | minute, err := strconv.ParseInt(string(sub[1]), 10, 8)
101 | if err != nil {
102 | return numeric{}, err
103 | }
104 |
105 | if strings.Contains(sub[2], ".") {
106 | second, err := strconv.ParseFloat(sub[2], 64)
107 | if err != nil {
108 | return numeric{}, err
109 | }
110 | f := float64(pre) * (float64(hour)*10000 + float64(minute)*100 + second)
111 | return numeric{fval: f, typ: Float64}, nil
112 | }
113 |
114 | second, err := strconv.ParseInt(sub[2], 10, 8)
115 | if err != nil {
116 | return numeric{}, err
117 | }
118 | i = pre * (hour*10000 + minute*100 + second)
119 | return numeric{ival: i, typ: Int64}, nil
120 | case querypb.Type_YEAR:
121 | val, err := strconv.ParseUint(v.ToString(), 10, 16)
122 | if err != nil {
123 | return numeric{}, err
124 | }
125 | return numeric{uval: val, typ: Uint64}, nil
126 | }
127 | return numeric{}, errors.New("unsupport: unknown.type")
128 | }
129 |
--------------------------------------------------------------------------------
/sqlparser/depends/sqltypes/time_test.go:
--------------------------------------------------------------------------------
1 | /*
2 | * go-mysqlstack
3 | * xelabs.org
4 | *
5 | * Copyright (c) XeLabs
6 | * GPL License
7 | */
8 |
9 | package sqltypes
10 |
11 | import (
12 | "reflect"
13 | "testing"
14 |
15 | "github.com/stretchr/testify/assert"
16 | )
17 |
18 | func TesttimeToNumeric(t *testing.T) {
19 | testcases := []struct {
20 | in Value
21 | out interface{}
22 | }{
23 | {
24 | in: testVal(Timestamp, "2012-02-24 23:19:43"),
25 | out: int64(20120224231943),
26 | },
27 | {
28 | in: testVal(Timestamp, "2012-02-24 23:19:43.120"),
29 | out: float64(20120224231943.120),
30 | },
31 | {
32 | in: testVal(Time, "-23:19:43.120"),
33 | out: float64(-231943.120),
34 | },
35 | {
36 | in: testVal(Time, "-63:19:43"),
37 | out: int64(-631943),
38 | },
39 | {
40 | in: testVal(Datetime, "0000-00-00 00:00:00"),
41 | out: int64(0),
42 | },
43 | {
44 | in: testVal(Datetime, "2012-02-24 23:19:43.000012"),
45 | out: float64(20120224231943.000012),
46 | },
47 | {
48 | in: testVal(Date, "0000-00-00"),
49 | out: int64(0),
50 | },
51 | {
52 | in: testVal(Date, "2012-02-24"),
53 | out: int64(20120224),
54 | },
55 | {
56 | in: testVal(Year, "2012"),
57 | out: uint64(2012),
58 | },
59 | {
60 | in: testVal(Year, "12"),
61 | out: uint64(12),
62 | },
63 | }
64 |
65 | for _, tcase := range testcases {
66 | got, err := timeToNumeric(tcase.in)
67 | assert.Nil(t, err)
68 |
69 | var v interface{}
70 | switch got.typ {
71 | case Uint64:
72 | v = got.uval
73 | case Float64:
74 | v = got.fval
75 | case Int64:
76 | v = got.ival
77 | }
78 |
79 | if !reflect.DeepEqual(v, tcase.out) {
80 | t.Errorf("%v.ToNative = %#v, want %#v", makePretty(tcase.in), v, tcase.out)
81 | }
82 | }
83 | }
84 |
--------------------------------------------------------------------------------
/sqlparser/encodable.go:
--------------------------------------------------------------------------------
1 | /*
2 | Copyright 2017 Google Inc.
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | http://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | */
16 |
17 | package sqlparser
18 |
19 | import (
20 | "strings"
21 |
22 | "github.com/xelabs/go-mysqlstack/sqlparser/depends/sqltypes"
23 | )
24 |
25 | // This file contains types that are 'Encodable'.
26 |
// Encodable defines the interface for types that can
// be custom-encoded into SQL.
// EncodeSQL writes the type's SQL text into buf.
type Encodable interface {
	EncodeSQL(buf *strings.Builder)
}
32 |
33 | // InsertValues is a custom SQL encoder for the values of
34 | // an insert statement.
35 | type InsertValues [][]sqltypes.Value
36 |
37 | // EncodeSQL performs the SQL encoding for InsertValues.
38 | func (iv InsertValues) EncodeSQL(buf *strings.Builder) {
39 | for i, rows := range iv {
40 | if i != 0 {
41 | buf.WriteString(", ")
42 | }
43 | buf.WriteByte('(')
44 | for j, bv := range rows {
45 | if j != 0 {
46 | buf.WriteString(", ")
47 | }
48 | bv.EncodeSQL(buf)
49 | }
50 | buf.WriteByte(')')
51 | }
52 | }
53 |
// TupleEqualityList is for generating equality constraints
// for tables that have composite primary keys.
//
// Each entry of Rows is expected to hold one value per entry of Columns, in
// the same order. The encoders index rows by column position and do not
// check lengths.
type TupleEqualityList struct {
	// Columns are the columns compared against each row's values.
	Columns []ColIdent
	// Rows holds the value tuples; Rows[i][j] is compared with Columns[j].
	Rows [][]sqltypes.Value
}
60 |
61 | // EncodeSQL generates the where clause constraints for the tuple
62 | // equality.
63 | func (tpl *TupleEqualityList) EncodeSQL(buf *strings.Builder) {
64 | if len(tpl.Columns) == 1 {
65 | tpl.encodeAsIn(buf)
66 | return
67 | }
68 | tpl.encodeAsEquality(buf)
69 | }
70 |
71 | func (tpl *TupleEqualityList) encodeAsIn(buf *strings.Builder) {
72 | Append(buf, tpl.Columns[0])
73 | buf.WriteString(" in (")
74 | for i, r := range tpl.Rows {
75 | if i != 0 {
76 | buf.WriteString(", ")
77 | }
78 | r[0].EncodeSQL(buf)
79 | }
80 | buf.WriteByte(')')
81 | }
82 |
83 | func (tpl *TupleEqualityList) encodeAsEquality(buf *strings.Builder) {
84 | for i, r := range tpl.Rows {
85 | if i != 0 {
86 | buf.WriteString(" or ")
87 | }
88 | buf.WriteString("(")
89 | for j, c := range tpl.Columns {
90 | if j != 0 {
91 | buf.WriteString(" and ")
92 | }
93 | Append(buf, c)
94 | buf.WriteString(" = ")
95 | r[j].EncodeSQL(buf)
96 | }
97 | buf.WriteByte(')')
98 | }
99 | }
100 |
--------------------------------------------------------------------------------
/sqlparser/encodable_test.go:
--------------------------------------------------------------------------------
1 | /*
2 | Copyright 2017 Google Inc.
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | http://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | */
16 |
17 | package sqlparser
18 |
19 | import (
20 | "strings"
21 | "testing"
22 |
23 | "github.com/xelabs/go-mysqlstack/sqlparser/depends/sqltypes"
24 | )
25 |
26 | func TestEncodable(t *testing.T) {
27 | tcases := []struct {
28 | in Encodable
29 | out string
30 | }{{
31 | in: InsertValues{{
32 | sqltypes.NewInt64(1),
33 | sqltypes.NewVarBinary("foo('a')"),
34 | }, {
35 | sqltypes.NewInt64(2),
36 | sqltypes.NewVarBinary("bar(`b`)"),
37 | }},
38 | out: "(1, 'foo(\\'a\\')'), (2, 'bar(`b`)')",
39 | }, {
40 | // Single column.
41 | in: &TupleEqualityList{
42 | Columns: []ColIdent{NewColIdent("pk")},
43 | Rows: [][]sqltypes.Value{
44 | {sqltypes.NewInt64(1)},
45 | {sqltypes.NewVarBinary("aa")},
46 | },
47 | },
48 | out: "pk in (1, 'aa')",
49 | }, {
50 | // Multiple columns.
51 | in: &TupleEqualityList{
52 | Columns: []ColIdent{NewColIdent("pk1"), NewColIdent("pk2")},
53 | Rows: [][]sqltypes.Value{
54 | {
55 | sqltypes.NewInt64(1),
56 | sqltypes.NewVarBinary("aa"),
57 | },
58 | {
59 | sqltypes.NewInt64(2),
60 | sqltypes.NewVarBinary("bb"),
61 | },
62 | },
63 | },
64 | out: "(pk1 = 1 and pk2 = 'aa') or (pk1 = 2 and pk2 = 'bb')",
65 | }}
66 | for _, tcase := range tcases {
67 | buf := new(strings.Builder)
68 | tcase.in.EncodeSQL(buf)
69 | if out := buf.String(); out != tcase.out {
70 | t.Errorf("EncodeSQL(%v): %s, want %s", tcase.in, out, tcase.out)
71 | }
72 | }
73 | }
74 |
--------------------------------------------------------------------------------
/sqlparser/explain_test.go:
--------------------------------------------------------------------------------
1 | /*
2 | Copyright 2017 Google Inc.
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | http://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | */
16 |
17 | package sqlparser
18 |
19 | import "strings"
20 | import "testing"
21 |
22 | func TestExplain(t *testing.T) {
23 | validSQL := []struct {
24 | input string
25 | output string
26 | }{
27 | {
28 | input: "explain select * from 1",
29 | output: "explain",
30 | },
31 | }
32 |
33 | for _, exp := range validSQL {
34 | sql := strings.TrimSpace(exp.input)
35 | tree, err := Parse(sql)
36 | if err != nil {
37 | t.Errorf("input: %s, err: %v", sql, err)
38 | continue
39 | }
40 |
41 | // Walk.
42 | Walk(func(node SQLNode) (bool, error) {
43 | return true, nil
44 | }, tree)
45 |
46 | got := String(tree.(*Explain))
47 | if exp.output != got {
48 | t.Errorf("want:\n%s\ngot:\n%s", exp.output, got)
49 | }
50 | }
51 | }
52 |
--------------------------------------------------------------------------------
/sqlparser/impossible_query.go:
--------------------------------------------------------------------------------
1 | /*
2 | Copyright 2017 Google Inc.
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | http://www.apache.org/licenses/LICENSE-2.0
9 |
Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | */
16 |
17 | package sqlparser
18 |
19 | // FormatImpossibleQuery creates an impossible query in a TrackedBuffer.
20 | // An impossible query is a modified version of a query where all selects have where clauses that are
21 | // impossible for mysql to resolve. This is used in the vtgate and vttablet:
22 | //
23 | // - In the vtgate it's used for joins: if the first query returns no result, then vtgate uses the impossible
24 | // query just to fetch field info from vttablet
25 | // - In the vttablet, it's just an optimization: the field info is fetched once form MySQL, cached and reused
26 | // for subsequent queries
27 | func FormatImpossibleQuery(buf *TrackedBuffer, node SQLNode) {
28 | switch node := node.(type) {
29 | case *Select:
30 | buf.Myprintf("select %v from %v where 1 != 1", node.SelectExprs, node.From)
31 | if node.GroupBy != nil {
32 | node.GroupBy.Format(buf)
33 | }
34 | case *Union:
35 | buf.Myprintf("%v %s %v", node.Left, node.Type, node.Right)
36 | default:
37 | node.Format(buf)
38 | }
39 | }
40 |
--------------------------------------------------------------------------------
/sqlparser/impossible_query_test.go:
--------------------------------------------------------------------------------
1 | /*
2 | * Radon
3 | *
4 | * Copyright 2019 The Radon Authors.
5 | * Code is licensed under the GPLv3.
6 | *
7 | */
8 |
9 | package sqlparser
10 |
11 | import (
12 | "testing"
13 |
14 | "github.com/stretchr/testify/assert"
15 | )
16 |
17 | func TestFormatImpossibleQuery(t *testing.T) {
18 | querys := []string{"select a,b from A where A.id>1 group by a order by a limit 1",
19 | "select id,a from A union select name,a from B order by a",
20 | "insert into A(a,b) values(1,'a')"}
21 | wants := []string{"select a, b from A where 1 != 1 group by a",
22 | "select id, a from A union select name, a from B",
23 | "insert into A(a, b) values (1, 'a')"}
24 | for i, query := range querys {
25 | node, err := Parse(query)
26 | assert.Nil(t, err)
27 | buf := NewTrackedBuffer(nil)
28 | FormatImpossibleQuery(buf, node)
29 | got := buf.String()
30 | assert.Equal(t, wants[i], got)
31 | }
32 | }
33 |
--------------------------------------------------------------------------------
/sqlparser/kill_test.go:
--------------------------------------------------------------------------------
1 | /*
2 | Copyright 2017 Google Inc.
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | http://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | */
16 |
17 | package sqlparser
18 |
19 | import (
20 | "strings"
21 | "testing"
22 | )
23 |
24 | func TestKill(t *testing.T) {
25 | validSQL := []struct {
26 | input string
27 | output string
28 | }{
29 | {
30 | input: "kill 1",
31 | output: "kill 1",
32 | },
33 |
34 | {
35 | input: "kill 10000000000000000000000000000000",
36 | output: "kill 10000000000000000000000000000000",
37 | },
38 |
39 | {
40 | input: "kill query 1",
41 | output: "kill 1",
42 | },
43 | }
44 |
45 | for _, exp := range validSQL {
46 | sql := strings.TrimSpace(exp.input)
47 | tree, err := Parse(sql)
48 | if err != nil {
49 | t.Errorf("input: %s, err: %v", sql, err)
50 | continue
51 | }
52 |
53 | // Walk.
54 | Walk(func(node SQLNode) (bool, error) {
55 | return true, nil
56 | }, tree)
57 |
58 | node := tree.(*Kill)
59 | node.QueryID.AsUint64()
60 |
61 | // Format.
62 | got := String(node)
63 | if exp.output != got {
64 | t.Errorf("want:\n%s\ngot:\n%s", exp.output, got)
65 | }
66 | }
67 | }
68 |
--------------------------------------------------------------------------------
/sqlparser/normalizer.go:
--------------------------------------------------------------------------------
1 | // Copyright 2016, Google Inc. All rights reserved.
2 | // Use of this source code is governed by a BSD-style
3 | // license that can be found in the LICENSE file.
4 |
5 | package sqlparser
6 |
7 | import (
8 | "fmt"
9 |
10 | querypb "github.com/xelabs/go-mysqlstack/sqlparser/depends/query"
11 | "github.com/xelabs/go-mysqlstack/sqlparser/depends/sqltypes"
12 | )
13 |
14 | // Normalize changes the statement to use bind values, and
15 | // updates the bind vars to those values. The supplied prefix
16 | // is used to generate the bind var names. The function ensures
17 | // that there are no collisions with existing bind vars.
18 | func Normalize(stmt Statement, bindVars map[string]*querypb.BindVariable, prefix string) {
19 | reserved := GetBindvars(stmt)
20 | // vals allows us to reuse bindvars for
21 | // identical values.
22 | counter := 1
23 | vals := make(map[string]string)
24 | _ = Walk(func(node SQLNode) (kontinue bool, err error) {
25 | switch node := node.(type) {
26 | case *SQLVal:
27 | // Make the bindvar
28 | bval := sqlToBindvar(node)
29 | if bval == nil {
30 | // If unsuccessful continue.
31 | return true, nil
32 | }
33 | // Check if there's a bindvar for that value already.
34 | var key string
35 | if bval.Type == sqltypes.VarBinary {
36 | // Prefixing strings with "'" ensures that a string
37 | // and number that have the same representation don't
38 | // collide.
39 | key = "'" + string(node.Val)
40 | } else {
41 | key = string(node.Val)
42 | }
43 | bvname, ok := vals[key]
44 | if !ok {
45 | // If there's no such bindvar, make a new one.
46 | bvname, counter = newName(prefix, counter, reserved)
47 | vals[key] = bvname
48 | bindVars[bvname] = bval
49 | }
50 | // Modify the AST node to a bindvar.
51 | node.Type = ValArg
52 | node.Val = append([]byte(":"), bvname...)
53 | case *ComparisonExpr:
54 | switch node.Operator {
55 | case InStr, NotInStr:
56 | default:
57 | return true, nil
58 | }
59 | // It's either IN or NOT IN.
60 | tupleVals, ok := node.Right.(ValTuple)
61 | if !ok {
62 | return true, nil
63 | }
64 | // The RHS is a tuple of values.
65 | // Make a list bindvar.
66 | bvals := &querypb.BindVariable{
67 | Type: sqltypes.Tuple,
68 | }
69 | for _, val := range tupleVals {
70 | bval := sqlToBindvar(val)
71 | if bval == nil {
72 | return true, nil
73 | }
74 | bvals.Values = append(bvals.Values, &querypb.Value{
75 | Type: bval.Type,
76 | Value: bval.Value,
77 | })
78 | }
79 | var bvname string
80 | bvname, counter = newName(prefix, counter, reserved)
81 | bindVars[bvname] = bvals
82 | // Modify RHS to be a list bindvar.
83 | node.Right = ListArg(append([]byte("::"), bvname...))
84 | }
85 | return true, nil
86 | }, stmt)
87 | }
88 |
89 | func sqlToBindvar(node SQLNode) *querypb.BindVariable {
90 | if node, ok := node.(*SQLVal); ok {
91 | switch node.Type {
92 | case StrVal:
93 | return &querypb.BindVariable{Type: sqltypes.VarBinary, Value: node.Val}
94 | case IntVal:
95 | return &querypb.BindVariable{Type: sqltypes.Int64, Value: node.Val}
96 | case FloatVal:
97 | return &querypb.BindVariable{Type: sqltypes.Float64, Value: node.Val}
98 | }
99 | }
100 | return nil
101 | }
102 |
// newName returns the first name of the form prefix+N, with N starting at
// counter, that is not already in reserved. It adds the chosen name to
// reserved and also returns N+1, the counter for the next allocation.
func newName(prefix string, counter int, reserved map[string]struct{}) (string, int) {
	for ; ; counter++ {
		candidate := fmt.Sprintf("%s%d", prefix, counter)
		if _, taken := reserved[candidate]; taken {
			continue
		}
		reserved[candidate] = struct{}{}
		return candidate, counter + 1
	}
}
113 |
114 | // GetBindvars returns a map of the bind vars referenced in the statement.
115 | // TODO(sougou); This function gets called again from vtgate/planbuilder.
116 | // Ideally, this should be done only once.
117 | func GetBindvars(stmt Statement) map[string]struct{} {
118 | bindvars := make(map[string]struct{})
119 | _ = Walk(func(node SQLNode) (kontinue bool, err error) {
120 | switch node := node.(type) {
121 | case *SQLVal:
122 | if node.Type == ValArg {
123 | bindvars[string(node.Val[1:])] = struct{}{}
124 | }
125 | case ListArg:
126 | bindvars[string(node[2:])] = struct{}{}
127 | }
128 | return true, nil
129 | }, stmt)
130 | return bindvars
131 | }
132 |
--------------------------------------------------------------------------------
/sqlparser/normalizer_test.go:
--------------------------------------------------------------------------------
1 | // Copyright 2016, Google Inc. All rights reserved.
2 | // Use of this source code is governed by a BSD-style
3 | // license that can be found in the LICENSE file.
4 |
5 | package sqlparser
6 |
7 | import (
8 | "reflect"
9 | "testing"
10 |
11 | querypb "github.com/xelabs/go-mysqlstack/sqlparser/depends/query"
12 | "github.com/xelabs/go-mysqlstack/sqlparser/depends/sqltypes"
13 | )
14 |
15 | func TestNormalize(t *testing.T) {
16 | prefix := "bv"
17 | testcases := []struct {
18 | in string
19 | outstmt string
20 | outbv map[string]*querypb.BindVariable
21 | }{{
22 | // str val
23 | in: "select * from t where v1 = 'aa'",
24 | outstmt: "select * from t where v1 = :bv1",
25 | outbv: map[string]*querypb.BindVariable{
26 | "bv1": &querypb.BindVariable{
27 | Type: sqltypes.VarBinary,
28 | Value: []byte("aa"),
29 | },
30 | },
31 | }, {
32 | // int val
33 | in: "select * from t where v1 = 1",
34 | outstmt: "select * from t where v1 = :bv1",
35 | outbv: map[string]*querypb.BindVariable{
36 | "bv1": &querypb.BindVariable{
37 | Type: sqltypes.Int64,
38 | Value: []byte("1"),
39 | },
40 | },
41 | }, {
42 | // float val
43 | in: "select * from t where v1 = 1.2",
44 | outstmt: "select * from t where v1 = :bv1",
45 | outbv: map[string]*querypb.BindVariable{
46 | "bv1": &querypb.BindVariable{
47 | Type: sqltypes.Float64,
48 | Value: []byte("1.2"),
49 | },
50 | },
51 | }, {
52 | // multiple vals
53 | in: "select * from t where v1 = 1.2 and v2 = 2",
54 | outstmt: "select * from t where v1 = :bv1 and v2 = :bv2",
55 | outbv: map[string]*querypb.BindVariable{
56 | "bv1": &querypb.BindVariable{
57 | Type: sqltypes.Float64,
58 | Value: []byte("1.2"),
59 | },
60 | "bv2": &querypb.BindVariable{
61 | Type: sqltypes.Int64,
62 | Value: []byte("2"),
63 | },
64 | },
65 | }, {
66 | // bv collision
67 | in: "select * from t where v1 = :bv1 and v2 = 1",
68 | outstmt: "select * from t where v1 = :bv1 and v2 = :bv2",
69 | outbv: map[string]*querypb.BindVariable{
70 | "bv2": &querypb.BindVariable{
71 | Type: sqltypes.Int64,
72 | Value: []byte("1"),
73 | },
74 | },
75 | }, {
76 | // val reuse
77 | in: "select * from t where v1 = 1 and v2 = 1",
78 | outstmt: "select * from t where v1 = :bv1 and v2 = :bv1",
79 | outbv: map[string]*querypb.BindVariable{
80 | "bv1": &querypb.BindVariable{
81 | Type: sqltypes.Int64,
82 | Value: []byte("1"),
83 | },
84 | },
85 | }, {
86 | // ints and strings are different
87 | in: "select * from t where v1 = 1 and v2 = '1'",
88 | outstmt: "select * from t where v1 = :bv1 and v2 = :bv2",
89 | outbv: map[string]*querypb.BindVariable{
90 | "bv1": &querypb.BindVariable{
91 | Type: sqltypes.Int64,
92 | Value: []byte("1"),
93 | },
94 | "bv2": &querypb.BindVariable{
95 | Type: sqltypes.VarBinary,
96 | Value: []byte("1"),
97 | },
98 | },
99 | }, {
100 | // comparison with no vals
101 | in: "select * from t where v1 = v2",
102 | outstmt: "select * from t where v1 = v2",
103 | outbv: map[string]*querypb.BindVariable{},
104 | }, {
105 | // IN clause with existing bv
106 | in: "select * from t where v1 in ::list",
107 | outstmt: "select * from t where v1 in ::list",
108 | outbv: map[string]*querypb.BindVariable{},
109 | }, {
110 | // IN clause with non-val values
111 | in: "select * from t where v1 in (1, a)",
112 | outstmt: "select * from t where v1 in (:bv1, a)",
113 | outbv: map[string]*querypb.BindVariable{
114 | "bv1": &querypb.BindVariable{
115 | Type: sqltypes.Int64,
116 | Value: []byte("1"),
117 | },
118 | },
119 | }, {
120 | // IN clause with vals
121 | in: "select * from t where v1 in (1, '2')",
122 | outstmt: "select * from t where v1 in ::bv1",
123 | outbv: map[string]*querypb.BindVariable{
124 | "bv1": &querypb.BindVariable{
125 | Type: sqltypes.Tuple,
126 | Values: []*querypb.Value{{
127 | Type: sqltypes.Int64,
128 | Value: []byte("1"),
129 | }, {
130 | Type: sqltypes.VarBinary,
131 | Value: []byte("2"),
132 | }},
133 | },
134 | },
135 | }, {
136 | // NOT IN clause
137 | in: "select * from t where v1 not in (1, '2')",
138 | outstmt: "select * from t where v1 not in ::bv1",
139 | outbv: map[string]*querypb.BindVariable{
140 | "bv1": &querypb.BindVariable{
141 | Type: sqltypes.Tuple,
142 | Values: []*querypb.Value{{
143 | Type: sqltypes.Int64,
144 | Value: []byte("1"),
145 | }, {
146 | Type: sqltypes.VarBinary,
147 | Value: []byte("2"),
148 | }},
149 | },
150 | },
151 | }}
152 | for _, tc := range testcases {
153 | stmt, err := Parse(tc.in)
154 | if err != nil {
155 | t.Error(err)
156 | continue
157 | }
158 | bv := make(map[string]*querypb.BindVariable)
159 | Normalize(stmt, bv, prefix)
160 | outstmt := String(stmt)
161 | if outstmt != tc.outstmt {
162 | t.Errorf("Query:\n%s:\n%s, want\n%s", tc.in, outstmt, tc.outstmt)
163 | }
164 | if !reflect.DeepEqual(tc.outbv, bv) {
165 | t.Errorf("Query:\n%s:\n%v, want\n%v", tc.in, bv, tc.outbv)
166 | }
167 | }
168 | }
169 |
170 | func TestGetBindVars(t *testing.T) {
171 | stmt, err := Parse("select * from t where :v1 = :v2 and :v2 = :v3 and :v4 in ::v5")
172 | if err != nil {
173 | t.Fatal(err)
174 | }
175 | got := GetBindvars(stmt)
176 | want := map[string]struct{}{
177 | "v1": {},
178 | "v2": {},
179 | "v3": {},
180 | "v4": {},
181 | "v5": {},
182 | }
183 | if !reflect.DeepEqual(got, want) {
184 | t.Errorf("GetBindVars: %v, want %v", got, want)
185 | }
186 | }
187 |
--------------------------------------------------------------------------------
/sqlparser/parsed_query.go:
--------------------------------------------------------------------------------
1 | /*
2 | Copyright 2017 Google Inc.
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | http://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | */
16 |
17 | package sqlparser
18 |
19 | import (
20 | "encoding/json"
21 | "fmt"
22 | "strings"
23 |
24 | querypb "github.com/xelabs/go-mysqlstack/sqlparser/depends/query"
25 | "github.com/xelabs/go-mysqlstack/sqlparser/depends/sqltypes"
26 | )
27 |
// ParsedQuery represents a parsed query where
// bind locations are precomputed for fast substitutions.
type ParsedQuery struct {
	// Query is the SQL text, including any bind placeholders.
	Query string
	// bindLocations are the byte ranges of Query that hold placeholders.
	// GenerateQuery walks them in slice order, so they must be sorted by
	// offset and must not overlap.
	bindLocations []bindLocation
}

// bindLocation is the byte range [offset, offset+length) of one bind
// placeholder (e.g. ":id" or "::list") within ParsedQuery.Query.
type bindLocation struct {
	offset, length int
}
38 |
39 | // NewParsedQuery returns a ParsedQuery of the ast.
40 | func NewParsedQuery(node SQLNode) *ParsedQuery {
41 | buf := NewTrackedBuffer(nil)
42 | buf.Myprintf("%v", node)
43 | return buf.ParsedQuery()
44 | }
45 |
46 | // GenerateQuery generates a query by substituting the specified
47 | // bindVariables. The extras parameter specifies special parameters
48 | // that can perform custom encoding.
49 | func (pq *ParsedQuery) GenerateQuery(bindVariables map[string]*querypb.BindVariable, extras map[string]Encodable) (string, error) {
50 | if len(pq.bindLocations) == 0 {
51 | return pq.Query, nil
52 | }
53 | var buf strings.Builder
54 | buf.Grow(len(pq.Query))
55 | current := 0
56 | for _, loc := range pq.bindLocations {
57 | buf.WriteString(pq.Query[current:loc.offset])
58 | name := pq.Query[loc.offset : loc.offset+loc.length]
59 | if encodable, ok := extras[name[1:]]; ok {
60 | encodable.EncodeSQL(&buf)
61 | } else {
62 | supplied, _, err := FetchBindVar(name, bindVariables)
63 | if err != nil {
64 | return "", err
65 | }
66 | EncodeValue(&buf, supplied)
67 | }
68 | current = loc.offset + loc.length
69 | }
70 | buf.WriteString(pq.Query[current:])
71 | return buf.String(), nil
72 | }
73 |
74 | // MarshalJSON is a custom JSON marshaler for ParsedQuery.
75 | // Note that any queries longer that 512 bytes will be truncated.
76 | func (pq *ParsedQuery) MarshalJSON() ([]byte, error) {
77 | return json.Marshal(pq.Query)
78 | }
79 |
80 | // EncodeValue encodes one bind variable value into the query.
81 | func EncodeValue(buf *strings.Builder, value *querypb.BindVariable) {
82 | if value.Type != querypb.Type_TUPLE {
83 | // Since we already check for TUPLE, we don't expect an error.
84 | v, _ := sqltypes.BindVariableToValue(value)
85 | v.EncodeSQL(buf)
86 | return
87 | }
88 |
89 | // It's a TUPLE.
90 | buf.WriteByte('(')
91 | for i, bv := range value.Values {
92 | if i != 0 {
93 | buf.WriteString(", ")
94 | }
95 | sqltypes.ProtoToValue(bv).EncodeSQL(buf)
96 | }
97 | buf.WriteByte(')')
98 | }
99 |
100 | // FetchBindVar resolves the bind variable by fetching it from bindVariables.
101 | func FetchBindVar(name string, bindVariables map[string]*querypb.BindVariable) (val *querypb.BindVariable, isList bool, err error) {
102 | name = name[1:]
103 | if name[0] == ':' {
104 | name = name[1:]
105 | isList = true
106 | }
107 | supplied, ok := bindVariables[name]
108 | if !ok {
109 | return nil, false, fmt.Errorf("missing bind var %s", name)
110 | }
111 |
112 | if isList {
113 | if supplied.Type != querypb.Type_TUPLE {
114 | return nil, false, fmt.Errorf("unexpected list arg type (%v) for key %s", supplied.Type, name)
115 | }
116 | if len(supplied.Values) == 0 {
117 | return nil, false, fmt.Errorf("empty list supplied for %s", name)
118 | }
119 | return supplied, true, nil
120 | }
121 |
122 | if supplied.Type == querypb.Type_TUPLE {
123 | return nil, false, fmt.Errorf("unexpected arg type (TUPLE) for non-list key %s", name)
124 | }
125 |
126 | return supplied, false, nil
127 | }
128 |
--------------------------------------------------------------------------------
/sqlparser/parsed_query_test.go:
--------------------------------------------------------------------------------
1 | /*
2 | Copyright 2017 Google Inc.
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | http://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | */
16 |
17 | package sqlparser
18 |
19 | import (
20 | "reflect"
21 | "testing"
22 |
23 | querypb "github.com/xelabs/go-mysqlstack/sqlparser/depends/query"
24 | "github.com/xelabs/go-mysqlstack/sqlparser/depends/sqltypes"
25 | )
26 |
27 | func TestNewParsedQuery(t *testing.T) {
28 | stmt, err := Parse("select * from a where id =:id")
29 | if err != nil {
30 | t.Error(err)
31 | return
32 | }
33 | pq := NewParsedQuery(stmt)
34 | want := &ParsedQuery{
35 | Query: "select * from a where id = :id",
36 | bindLocations: []bindLocation{{offset: 27, length: 3}},
37 | }
38 | if !reflect.DeepEqual(pq, want) {
39 | t.Errorf("GenerateParsedQuery: %+v, want %+v", pq, want)
40 | }
41 | }
42 |
43 | func TestGenerateQuery(t *testing.T) {
44 | tcases := []struct {
45 | desc string
46 | query string
47 | bindVars map[string]*querypb.BindVariable
48 | extras map[string]Encodable
49 | output string
50 | }{
51 | {
52 | desc: "no substitutions",
53 | query: "select * from a where id = 2",
54 | bindVars: map[string]*querypb.BindVariable{
55 | "id": sqltypes.Int64BindVariable(1),
56 | },
57 | output: "select * from a where id = 2",
58 | }, {
59 | desc: "missing bind var",
60 | query: "select * from a where id1 = :id1 and id2 = :id2",
61 | bindVars: map[string]*querypb.BindVariable{
62 | "id1": sqltypes.Int64BindVariable(1),
63 | },
64 | output: "missing bind var id2",
65 | }, {
66 | desc: "simple bindvar substitution",
67 | query: "select * from a where id1 = :id1 and id2 = :id2",
68 | bindVars: map[string]*querypb.BindVariable{
69 | "id1": sqltypes.Int64BindVariable(1),
70 | "id2": sqltypes.NullBindVariable,
71 | },
72 | output: "select * from a where id1 = 1 and id2 = null",
73 | }, {
74 | desc: "non-list bind var supplied",
75 | query: "select * from a where id in ::vals",
76 | bindVars: map[string]*querypb.BindVariable{
77 | "vals": sqltypes.Int64BindVariable(1),
78 | },
79 | output: "unexpected list arg type (INT64) for key vals",
80 | }, {
81 | desc: "single column tuple equality",
82 | query: "select * from a where b = :equality",
83 | extras: map[string]Encodable{
84 | "equality": &TupleEqualityList{
85 | Columns: []ColIdent{NewColIdent("pk")},
86 | Rows: [][]sqltypes.Value{
87 | {sqltypes.NewInt64(1)},
88 | {sqltypes.NewVarBinary("aa")},
89 | },
90 | },
91 | },
92 | output: "select * from a where b = pk in (1, 'aa')",
93 | }, {
94 | desc: "multi column tuple equality",
95 | query: "select * from a where b = :equality",
96 | extras: map[string]Encodable{
97 | "equality": &TupleEqualityList{
98 | Columns: []ColIdent{NewColIdent("pk1"), NewColIdent("pk2")},
99 | Rows: [][]sqltypes.Value{
100 | {
101 | sqltypes.NewInt64(1),
102 | sqltypes.NewVarBinary("aa"),
103 | },
104 | {
105 | sqltypes.NewInt64(2),
106 | sqltypes.NewVarBinary("bb"),
107 | },
108 | },
109 | },
110 | },
111 | output: "select * from a where b = (pk1 = 1 and pk2 = 'aa') or (pk1 = 2 and pk2 = 'bb')",
112 | },
113 | }
114 |
115 | for _, tcase := range tcases {
116 | tree, err := Parse(tcase.query)
117 | if err != nil {
118 | t.Errorf("parse failed for %s: %v", tcase.desc, err)
119 | continue
120 | }
121 | buf := NewTrackedBuffer(nil)
122 | buf.Myprintf("%v", tree)
123 | pq := buf.ParsedQuery()
124 | bytes, err := pq.GenerateQuery(tcase.bindVars, tcase.extras)
125 | var got string
126 | if err != nil {
127 | got = err.Error()
128 | } else {
129 | got = string(bytes)
130 | }
131 | if got != tcase.output {
132 | t.Errorf("for test case: %s, got: '%s', want '%s'", tcase.desc, got, tcase.output)
133 | }
134 | }
135 | }
136 |
--------------------------------------------------------------------------------
/sqlparser/parser.go:
--------------------------------------------------------------------------------
1 | /*
2 | Copyright 2019 The Vitess Authors.
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | http://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | */
16 |
17 | package sqlparser
18 |
19 | import (
20 | "errors"
21 | "sync"
22 | )
23 |
24 | // parserPool is a pool for parser objects.
25 | var parserPool = sync.Pool{}
26 |
27 | // zeroParser is a zero-initialized parser to help reinitialize the parser for pooling.
28 | var zeroParser = *(yyNewParser().(*yyParserImpl))
29 |
30 | // yyParsePooled is a wrapper around yyParse that pools the parser objects. There isn't a
31 | // particularly good reason to use yyParse directly, since it immediately discards its parser. What
32 | // would be ideal down the line is to actually pool the stacks themselves rather than the parser
33 | // objects, as per https://github.com/cznic/goyacc/blob/master/main.go. However, absent an upstream
34 | // change to goyacc, this is the next best option.
35 | //
36 | // N.B: Parser pooling means that you CANNOT take references directly to parse stack variables (e.g.
37 | // $$ = &$4) in sql.y rules. You must instead add an intermediate reference like so:
38 | // showCollationFilterOpt := $4
39 | // $$ = &Show{Type: string($2), ShowCollationFilterOpt: &showCollationFilterOpt}
40 | func yyParsePooled(yylex yyLexer) int {
41 | // Being very particular about using the base type and not an interface type b/c we depend on
42 | // the implementation to know how to reinitialize the parser.
43 | var parser *yyParserImpl
44 |
45 | i := parserPool.Get()
46 | if i != nil {
47 | parser = i.(*yyParserImpl)
48 | } else {
49 | parser = yyNewParser().(*yyParserImpl)
50 | }
51 |
52 | defer func() {
53 | *parser = zeroParser
54 | parserPool.Put(parser)
55 | }()
56 | return parser.Parse(yylex)
57 | }
58 |
59 | // Instructions for creating new types: If a type
60 | // needs to satisfy an interface, declare that function
61 | // along with that interface. This will help users
62 | // identify the list of types to which they can assert
63 | // those interfaces.
64 | // If the member of a type has a string with a predefined
65 | // list of values, declare those values as const following
66 | // the type.
67 | // For interfaces that define dummy functions to consolidate
68 | // a set of types, define the function as iTypeName.
69 | // This will help avoid name collisions.
70 |
// Parse parses the sql and returns a Statement, which is the AST
// representation of the query. If the parser reports a failure, Parse
// returns a nil Statement and an error built from the tokenizer's last
// error message.
//
// NOTE(review): an earlier version of this comment said partially parsed
// DDL is returned even when it contains a syntax error. Nothing in this
// function does that; if such recovery exists it must come from the
// grammar actions — confirm before relying on it.
func Parse(sql string) (Statement, error) {
	tokenizer := NewStringTokenizer(sql)
	// Use the pooled parser, exactly as ParseStrictDDL does, so each call
	// reuses a parser instead of allocating a new yacc parser per query.
	if yyParsePooled(tokenizer) != 0 {
		return nil, errors.New(tokenizer.LastError)
	}
	return tokenizer.ParseTree, nil
}
82 |
// ParseStrictDDL parses sql with the pooled parser. If the parser reports a
// failure it returns a nil Statement and an error built from the
// tokenizer's last error message. Otherwise it returns the parse tree
// (which may be nil) with a nil error.
//
// NOTE(review): this is documented as "Parse, but erroring on partially
// parsed DDL"; no DDL-specific check is visible here — confirm where that
// strictness is enforced.
func ParseStrictDDL(sql string) (Statement, error) {
	tok := NewStringTokenizer(sql)
	if rc := yyParsePooled(tok); rc != 0 {
		return nil, errors.New(tok.LastError)
	}
	// A nil ParseTree comes back as (nil, nil).
	return tok.ParseTree, nil
}
95 |
96 | // String returns a string representation of an SQLNode.
97 | func String(node SQLNode) string {
98 | buf := NewTrackedBuffer(nil)
99 | buf.Myprintf("%v", node)
100 | return buf.String()
101 | }
102 |
--------------------------------------------------------------------------------
/sqlparser/precedence_test.go:
--------------------------------------------------------------------------------
1 | /*
2 | Copyright 2017 Google Inc.
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | http://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | */
16 |
17 | package sqlparser
18 |
19 | import (
20 | "fmt"
21 | "testing"
22 | )
23 |
24 | func readable(node Expr) string {
25 | switch node := node.(type) {
26 | case *OrExpr:
27 | return fmt.Sprintf("(%s or %s)", readable(node.Left), readable(node.Right))
28 | case *AndExpr:
29 | return fmt.Sprintf("(%s and %s)", readable(node.Left), readable(node.Right))
30 | case *BinaryExpr:
31 | return fmt.Sprintf("(%s %s %s)", readable(node.Left), node.Operator, readable(node.Right))
32 | case *IsExpr:
33 | return fmt.Sprintf("(%s %s)", readable(node.Expr), node.Operator)
34 | default:
35 | return String(node)
36 | }
37 | }
38 |
func TestAndOrPrecedence(t *testing.T) {
	// AND must bind tighter than OR, whichever comes first in the text.
	cases := []struct {
		input  string
		output string
	}{
		{
			input:  "select * from a where a=b and c=d or e=f",
			output: "((a = b and c = d) or e = f)",
		},
		{
			input:  "select * from a where a=b or c=d and e=f",
			output: "(a = b or (c = d and e = f))",
		},
	}
	for i := range cases {
		tc := &cases[i]
		tree, err := Parse(tc.input)
		if err != nil {
			t.Error(err)
			continue
		}
		if got := readable(tree.(*Select).Where.Expr); got != tc.output {
			t.Errorf("Parse: \n%s, want: \n%s", got, tc.output)
		}
	}
}
62 |
func TestPlusStarPrecedence(t *testing.T) {
	// Multiplication must bind tighter than addition.
	cases := []struct {
		input  string
		output string
	}{
		{
			input:  "select 1+2*3 from a",
			output: "(1 + (2 * 3))",
		},
		{
			input:  "select 1*2+3 from a",
			output: "((1 * 2) + 3)",
		},
	}
	for i := range cases {
		tc := &cases[i]
		tree, err := Parse(tc.input)
		if err != nil {
			t.Error(err)
			continue
		}
		first := tree.(*Select).SelectExprs[0].(*AliasedExpr)
		if got := readable(first.Expr); got != tc.output {
			t.Errorf("Parse: \n%s, want: \n%s", got, tc.output)
		}
	}
}
86 |
func TestIsPrecedence(t *testing.T) {
	// IS binds tighter than AND but looser than arithmetic and comparison.
	cases := []struct {
		input  string
		output string
	}{
		{
			input:  "select * from a where a+b is true",
			output: "((a + b) is true)",
		},
		{
			input:  "select * from a where a=1 and b=2 is true",
			output: "(a = 1 and (b = 2 is true))",
		},
		{
			input:  "select * from a where (a=1 and b=2) is true",
			output: "((a = 1 and b = 2) is true)",
		},
	}
	for i := range cases {
		tc := &cases[i]
		tree, err := Parse(tc.input)
		if err != nil {
			t.Error(err)
			continue
		}
		if got := readable(tree.(*Select).Where.Expr); got != tc.output {
			t.Errorf("Parse: \n%s, want: \n%s", got, tc.output)
		}
	}
}
113 |
--------------------------------------------------------------------------------
/sqlparser/radon_test.go:
--------------------------------------------------------------------------------
1 | package sqlparser
2 |
3 | import (
4 | "strings"
5 | "testing"
6 | )
7 |
func TestRadon(t *testing.T) {
	// Each case pairs a RADON statement with its canonical rendering.
	cases := []struct {
		input  string
		output string
	}{
		// name, address, user, password.
		{
			input:  "radon attach ('attach1', '127.0.0.1:6000', 'root', '123456')",
			output: "radon attach ('attach1', '127.0.0.1:6000', 'root', '123456')",
		},
		{
			input:  "radon attachlist",
			output: "radon attachlist",
		},
		{
			input:  "radon detach('attach1')",
			output: "radon detach ('attach1')",
		},
		{
			input:  "radon reshard db.t db.tt",
			output: "radon reshard db.t to db.tt",
		},
		{
			input:  "radon reshard db.t to a.tt",
			output: "radon reshard db.t to a.tt",
		},
		{
			input:  "radon reshard db.t as b.tt",
			output: "radon reshard db.t to b.tt",
		},
		{
			input:  "radon cleanup",
			output: "radon cleanup",
		},
	}

	visitAll := func(node SQLNode) (bool, error) {
		return true, nil
	}
	for _, tc := range cases {
		sql := strings.TrimSpace(tc.input)
		tree, err := Parse(sql)
		if err != nil {
			t.Errorf("input: %s, err: %v", sql, err)
			continue
		}

		// Exercise Walk over the parsed tree.
		Walk(visitAll, tree)

		if got := String(tree.(*Radon)); got != tc.output {
			t.Errorf("want:\n%s\ngot:\n%s", tc.output, got)
		}
	}
}
63 |
--------------------------------------------------------------------------------
/sqlparser/rewriter_api.go:
--------------------------------------------------------------------------------
1 | /*
2 | Copyright 2019 The Vitess Authors.
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | http://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | */
16 |
17 | package sqlparser
18 |
19 | // The rewriter was heavily inspired by https://github.com/golang/tools/blob/master/go/ast/astutil/rewrite.go
20 |
21 | // Rewrite traverses a syntax tree recursively, starting with root,
22 | // and calling pre and post for each node as described below.
23 | // Rewrite returns the syntax tree, possibly modified.
24 | //
25 | // If pre is not nil, it is called for each node before the node's
26 | // children are traversed (pre-order). If pre returns false, no
27 | // children are traversed, and post is not called for that node.
28 | //
29 | // If post is not nil, and a prior call of pre didn't return false,
30 | // post is called for each node after its children are traversed
31 | // (post-order). If post returns false, traversal is terminated and
32 | // Apply returns immediately.
33 | //
34 | // Only fields that refer to AST nodes are considered children;
35 | // i.e., fields of basic types (strings, []byte, etc.) are ignored.
36 | //
37 | func Rewrite(node SQLNode, pre, post ApplyFunc) (result SQLNode) {
38 | parent := &struct{ SQLNode }{node}
39 | defer func() {
40 | if r := recover(); r != nil && r != abort {
41 | panic(r)
42 | }
43 | result = parent.SQLNode
44 | }()
45 |
46 | a := &application{
47 | pre: pre,
48 | post: post,
49 | cursor: Cursor{},
50 | }
51 |
52 | // this is the root-replacer, used when the user replaces the root of the ast
53 | replacer := func(newNode SQLNode, _ SQLNode) {
54 | parent.SQLNode = newNode
55 | }
56 |
57 | a.apply(parent, node, replacer)
58 |
59 | return parent.SQLNode
60 | }
61 |
62 | // An ApplyFunc is invoked by Rewrite for each node n, even if n is nil,
63 | // before and/or after the node's children, using a Cursor describing
64 | // the current node and providing operations on it.
65 | //
66 | // The return value of ApplyFunc controls the syntax tree traversal.
67 | // See Rewrite for details.
68 | type ApplyFunc func(*Cursor) bool
69 |
70 | var abort = new(int) // singleton, to signal termination of Apply
71 |
72 | // A Cursor describes a node encountered during Apply.
73 | // Information about the node and its parent is available
74 | // from the Node and Parent methods.
75 | type Cursor struct {
76 | parent SQLNode
77 | replacer replacerFunc
78 | node SQLNode
79 | }
80 |
81 | // Node returns the current Node.
82 | func (c *Cursor) Node() SQLNode { return c.node }
83 |
84 | // Parent returns the parent of the current Node.
85 | func (c *Cursor) Parent() SQLNode { return c.parent }
86 |
87 | // Replace replaces the current node in the parent field with this new object. The use needs to make sure to not
88 | // replace the object with something of the wrong type, or the visitor will panic.
89 | func (c *Cursor) Replace(newNode SQLNode) {
90 | c.replacer(newNode, c.parent)
91 | }
92 |
--------------------------------------------------------------------------------
/sqlparser/select_test.go:
--------------------------------------------------------------------------------
1 | /*
2 | Copyright 2017 Google Inc.
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | http://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | */
16 |
17 | package sqlparser
18 |
19 | import "strings"
20 | import "testing"
21 |
func TestSelect1(t *testing.T) {
	cases := []struct {
		input  string
		output string
	}{
		{
			input:  "select * from xx",
			output: "select * from xx",
		},
		{
			input:  "select * from xx where id=1",
			output: "select * from xx where id = 1",
		},
	}

	for i := range cases {
		tc := cases[i]
		sql := strings.TrimSpace(tc.input)
		tree, err := Parse(sql)
		if err != nil {
			t.Errorf("input: %s, err: %v", sql, err)
			continue
		}
		if got := String(tree.(*Select)); got != tc.output {
			t.Errorf("want:\n%s\ngot:\n%s", tc.output, got)
		}
	}
}
50 |
--------------------------------------------------------------------------------
/sqlparser/set_test.go:
--------------------------------------------------------------------------------
1 | /*
2 | Copyright 2017 Google Inc.
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | http://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | */
16 |
17 | package sqlparser
18 |
19 | import (
20 | "strings"
21 | "testing"
22 | )
23 |
func TestSet(t *testing.T) {
	// Each case pairs a SET statement with its canonical rendering.
	cases := []struct {
		input  string
		output string
	}{
		{
			input:  "SET @@session.s1= 'ON', @@session.s2='OFF'",
			output: "set @@session.s1 = 'ON', @@session.s2 = 'OFF'",
		},
		{
			input:  "SET @@session.radon_stream_fetching= 'OFF'",
			output: "set @@session.radon_stream_fetching = 'OFF'",
		},
		{
			input:  "SET radon_stream_fetching= false",
			output: "set radon_stream_fetching = false",
		},
		{
			input:  "SET SESSION wait_timeout = 2147483",
			output: "set session wait_timeout = 2147483",
		},
		{
			input:  "SET NAMES utf8",
			output: "set names 'utf8'",
		},
		{
			input:  "SET NAMES latin1 COLLATE latin1_german2_ci",
			output: "set names 'latin1' collate latin1_german2_ci",
		},
		{
			input:  "set session autocommit = ON, global wait_timeout = 2147483",
			output: "set session autocommit = 'on', global wait_timeout = 2147483",
		},
		{
			input:  "SET LOCAL TRANSACTION ISOLATION LEVEL READ COMMITTED",
			output: "set session transaction isolation level read committed",
		},
		{
			input:  "SET GLOBAL TRANSACTION ISOLATION LEVEL SERIALIZABLE, READ WRITE",
			output: "set global transaction isolation level serializable, read write",
		},
		{
			input:  "SET SESSION TRANSACTION READ ONLY",
			output: "set session transaction read only",
		},
		{
			input:  "SET TRANSACTION READ WRITE, ISOLATION LEVEL SERIALIZABLE",
			output: "set transaction isolation level serializable, read write",
		},
	}

	visitAll := func(node SQLNode) (bool, error) {
		return true, nil
	}
	for _, tc := range cases {
		sql := strings.TrimSpace(tc.input)
		tree, err := Parse(sql)
		if err != nil {
			t.Errorf("input: %s, err: %v", sql, err)
			continue
		}

		// Exercise Walk over the parsed tree.
		Walk(visitAll, tree)

		if got := String(tree.(*Set)); got != tc.output {
			t.Errorf("want:\n%s\ngot:\n%s", tc.output, got)
		}
	}
}
95 |
--------------------------------------------------------------------------------
/sqlparser/show_test.go:
--------------------------------------------------------------------------------
1 | /*
2 | Copyright 2017 Google Inc.
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | http://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | */
16 |
17 | package sqlparser
18 |
19 | import "strings"
20 | import "testing"
21 |
func TestShow1(t *testing.T) {
	// Each case pairs a SHOW statement with its canonical rendering;
	// FIELDS is normalized to COLUMNS.
	cases := []struct {
		input  string
		output string
	}{
		{input: "show table status", output: "show table status"},
		{input: "show table status from sbtest", output: "show table status from sbtest"},
		{input: "show create table t1", output: "show create table t1"},
		{input: "show tables", output: "show tables"},
		{input: "show full tables", output: "show full tables"},
		{input: "show full tables from t1", output: "show full tables from t1"},
		{input: "show full tables from t1 like '%mysql%'", output: "show full tables from t1 like '%mysql%'"},
		{input: "show full tables where Table_type != 'VIEW'", output: "show full tables where Table_type != 'VIEW'"},
		{input: "show tables from t1", output: "show tables from t1"},
		{input: "show tables from t1 like '%mysql%'", output: "show tables from t1 like '%mysql%'"},
		{input: "show databases", output: "show databases"},
		{input: "show create database sbtest", output: "show create database sbtest"},
		{input: "show engines", output: "show engines"},
		{input: "show status", output: "show status"},
		{input: "show versions", output: "show versions"},
		{input: "show processlist", output: "show processlist"},
		{input: "show queryz", output: "show queryz"},
		{input: "show txnz", output: "show txnz"},
		{input: "show warnings", output: "show warnings"},
		{input: "show variables", output: "show variables"},
		{input: "show binlog events", output: "show binlog events"},
		{input: "show binlog events limit 10", output: "show binlog events limit 10"},
		{input: "show binlog events from gtid '20171225083823'", output: "show binlog events from gtid '20171225083823'"},
		{input: "show binlog events from gtid '20171225083823' limit 1", output: "show binlog events from gtid '20171225083823' limit 1"},
		{input: "show columns from t1", output: "show columns from t1"},
		{input: "show columns from t1 like '%'", output: "show columns from t1 like '%'"},
		{input: "show columns from t1 where `Key` = 'PRI'", output: "show columns from t1 where `Key` = 'PRI'"},
		{input: "show full columns from t1", output: "show full columns from t1"},
		{input: "show full columns from t1 like '%'", output: "show full columns from t1 like '%'"},
		{input: "show full columns from t1 where `Key` = 'PRI'", output: "show full columns from t1 where `Key` = 'PRI'"},
		{input: "show fields from t1", output: "show columns from t1"},
		{input: "show fields from t1 like '%'", output: "show columns from t1 like '%'"},
		{input: "show fields from t1 where `Key` = 'PRI'", output: "show columns from t1 where `Key` = 'PRI'"},
		{input: "show full fields from t1", output: "show full columns from t1"},
		{input: "show full fields from t1 like '%'", output: "show full columns from t1 like '%'"},
		{input: "show full fields from t1 where `Key` = 'PRI'", output: "show full columns from t1 where `Key` = 'PRI'"},
	}

	visitAll := func(node SQLNode) (bool, error) {
		return true, nil
	}
	for _, tc := range cases {
		sql := strings.TrimSpace(tc.input)
		tree, err := Parse(sql)
		if err != nil {
			t.Errorf("input: %s, err: %v", sql, err)
			continue
		}

		// Exercise Walk over the parsed tree.
		Walk(visitAll, tree)

		if got := String(tree.(*Show)); got != tc.output {
			t.Errorf("want:\n%s\ngot:\n%s", tc.output, got)
		}
	}
}
192 |
--------------------------------------------------------------------------------
/sqlparser/token_test.go:
--------------------------------------------------------------------------------
1 | /*
2 | Copyright 2017 Google Inc.
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | http://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | */
16 |
17 | package sqlparser
18 |
19 | import "testing"
20 |
21 | func TestLiteralID(t *testing.T) {
22 | testcases := []struct {
23 | in string
24 | id int
25 | out string
26 | }{{
27 | in: "`aa`",
28 | id: ID,
29 | out: "aa",
30 | }, {
31 | in: "```a```",
32 | id: ID,
33 | out: "`a`",
34 | }, {
35 | in: "`a``b`",
36 | id: ID,
37 | out: "a`b",
38 | }, {
39 | in: "`a``b`c",
40 | id: ID,
41 | out: "a`b",
42 | }, {
43 | in: "`a``b",
44 | id: LEX_ERROR,
45 | out: "a`b",
46 | }, {
47 | in: "`a``b``",
48 | id: LEX_ERROR,
49 | out: "a`b`",
50 | }, {
51 | in: "``",
52 | id: LEX_ERROR,
53 | out: "",
54 | }}
55 |
56 | for _, tcase := range testcases {
57 | tkn := NewStringTokenizer(tcase.in)
58 | id, out := tkn.Scan()
59 | if tcase.id != id || string(out) != tcase.out {
60 | t.Errorf("Scan(%s): %d, %s, want %d, %s", tcase.in, id, out, tcase.id, tcase.out)
61 | }
62 | }
63 | }
64 |
--------------------------------------------------------------------------------
/sqlparser/tracked_buffer.go:
--------------------------------------------------------------------------------
1 | /*
2 | Copyright 2017 Google Inc.
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | http://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | */
16 |
17 | package sqlparser
18 |
19 | import (
20 | "fmt"
21 | "strings"
22 | )
23 |
24 | // NodeFormatter defines the signature of a custom node formatter
25 | // function that can be given to TrackedBuffer for code generation.
26 | type NodeFormatter func(buf *TrackedBuffer, node SQLNode)
27 |
28 | // TrackedBuffer is used to rebuild a query from the ast.
29 | // bindLocations keeps track of locations in the buffer that
30 | // use bind variables for efficient future substitutions.
31 | // nodeFormatter is the formatting function the buffer will
32 | // use to format a node. By default(nil), it's FormatNode.
33 | // But you can supply a different formatting function if you
34 | // want to generate a query that's different from the default.
35 | type TrackedBuffer struct {
36 | *strings.Builder
37 | bindLocations []bindLocation
38 | nodeFormatter NodeFormatter
39 | }
40 |
// NewTrackedBuffer returns an empty TrackedBuffer that formats nodes with
// nodeFormatter. Pass nil to use each node's own Format method.
func NewTrackedBuffer(nodeFormatter NodeFormatter) *TrackedBuffer {
	tb := &TrackedBuffer{nodeFormatter: nodeFormatter}
	tb.Builder = &strings.Builder{}
	return tb
}
48 |
49 | // WriteNode function, initiates the writing of a single SQLNode tree by passing
50 | // through to Myprintf with a default format string
51 | func (buf *TrackedBuffer) WriteNode(node SQLNode) *TrackedBuffer {
52 | buf.Myprintf("%v", node)
53 | return buf
54 | }
55 |
// Myprintf is a restricted fmt.Fprintf that writes into buf. Each verb
// consumes the next value:
//
//	%c  a byte or rune, written as one character
//	%s  a string or []byte, written verbatim
//	%v  an SQLNode, rendered by nodeFormatter if set, else node.Format
//	%a  a bind-variable argument (string), written via WriteArg so its
//	    position is tracked for later substitution
//
// Any other verb panics, and so does a lone trailing '%' (there is no %%
// escape). A value of the wrong type for its verb also panics.
//
// The name deliberately differs from Printf so that "go vet" does not
// complain about the custom verbs.
func (buf *TrackedBuffer) Myprintf(format string, values ...interface{}) {
	argIdx := 0
	pos := 0
	for pos < len(format) {
		// Copy the literal run up to the next '%'.
		start := pos
		for pos < len(format) && format[pos] != '%' {
			pos++
		}
		if pos > start {
			buf.WriteString(format[start:pos])
		}
		if pos >= len(format) {
			return
		}

		// Step past '%' and dispatch on the verb character.
		pos++
		switch verb := format[pos]; verb {
		case 'c':
			switch ch := values[argIdx].(type) {
			case byte:
				buf.WriteByte(ch)
			case rune:
				buf.WriteRune(ch)
			default:
				panic(fmt.Sprintf("unexpected TrackedBuffer type %T", ch))
			}
		case 's':
			switch text := values[argIdx].(type) {
			case []byte:
				buf.Write(text)
			case string:
				buf.WriteString(text)
			default:
				panic(fmt.Sprintf("unexpected TrackedBuffer type %T", text))
			}
		case 'v':
			node := values[argIdx].(SQLNode)
			if formatter := buf.nodeFormatter; formatter != nil {
				formatter(buf, node)
			} else {
				node.Format(buf)
			}
		case 'a':
			buf.WriteArg(values[argIdx].(string))
		default:
			panic("unexpected")
		}
		argIdx++
		pos++
	}
}
112 |
113 | // WriteArg writes a value argument into the buffer along with
114 | // tracking information for future substitutions. arg must contain
115 | // the ":" or "::" prefix.
116 | func (buf *TrackedBuffer) WriteArg(arg string) {
117 | buf.bindLocations = append(buf.bindLocations, bindLocation{
118 | offset: buf.Len(),
119 | length: len(arg),
120 | })
121 | buf.WriteString(arg)
122 | }
123 |
124 | // ParsedQuery returns a ParsedQuery that contains bind
125 | // locations for easy substitution.
126 | func (buf *TrackedBuffer) ParsedQuery() *ParsedQuery {
127 | return &ParsedQuery{Query: buf.String(), bindLocations: buf.bindLocations}
128 | }
129 |
130 | // HasBindVars returns true if the parsed query uses bind vars.
131 | func (buf *TrackedBuffer) HasBindVars() bool {
132 | return len(buf.bindLocations) != 0
133 | }
134 |
135 | // BuildParsedQuery builds a ParsedQuery from the input.
136 | func BuildParsedQuery(in string, vars ...interface{}) *ParsedQuery {
137 | buf := NewTrackedBuffer(nil)
138 | buf.Myprintf(in, vars...)
139 | return buf.ParsedQuery()
140 | }
141 |
--------------------------------------------------------------------------------
/sqlparser/tracked_buffer_test.go:
--------------------------------------------------------------------------------
1 | /*
2 | Copyright 2017 Google Inc.
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | http://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | */
16 |
17 | package sqlparser
18 |
19 | import "testing"
20 |
// TestTrackedBuffer checks that Myprintf renders %c, %s and %a correctly.
// It also checks that only %a arguments are recorded as bind locations.
func TestTrackedBuffer(t *testing.T) {
	buf := NewTrackedBuffer(nil)
	buf.Myprintf("%c,%s,%a", 'a', "a", "a")

	if got, want := buf.String(), "a,a,a"; got != want {
		t.Errorf("Myprintf output: got %q, want %q", got, want)
	}
	if !buf.HasBindVars() {
		t.Error("HasBindVars: got false, want true")
	}
	if got := len(buf.bindLocations); got != 1 {
		t.Fatalf("bindLocations: got %d entries, want 1", got)
	}
	// "a,a," precedes the %a argument, so it starts at byte offset 4.
	if loc := buf.bindLocations[0]; loc.offset != 4 || loc.length != 1 {
		t.Errorf("bindLocations[0]: got %+v, want offset 4, length 1", loc)
	}

	// A byte for %c and a []byte for %s take their own code paths.
	other := NewTrackedBuffer(nil)
	other.Myprintf("x%c%s", byte('b'), []byte("cd"))
	if got, want := other.String(), "xbcd"; got != want {
		t.Errorf("Myprintf output: got %q, want %q", got, want)
	}
	if other.HasBindVars() {
		t.Error("HasBindVars: got true, want false")
	}
}
25 |
--------------------------------------------------------------------------------
/sqlparser/txn_test.go:
--------------------------------------------------------------------------------
1 | /*
2 | Copyright 2017 Google Inc.
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | http://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | */
16 |
17 | package sqlparser
18 |
19 | import "strings"
20 | import "testing"
21 |
func TestTxn(t *testing.T) {
	// Transaction-control statements must round-trip unchanged.
	cases := []struct {
		input  string
		output string
	}{
		{input: "start transaction", output: "start transaction"},
		{input: "begin", output: "begin"},
		{input: "rollback", output: "rollback"},
		{input: "commit", output: "commit"},
	}

	visitAll := func(node SQLNode) (bool, error) {
		return true, nil
	}
	for _, tc := range cases {
		sql := strings.TrimSpace(tc.input)
		tree, err := Parse(sql)
		if err != nil {
			t.Errorf("input: %s, err: %v", sql, err)
			continue
		}

		// Exercise Walk over the parsed tree.
		Walk(visitAll, tree)

		if got := String(tree.(*Transaction)); got != tc.output {
			t.Errorf("want:\n%s\ngot:\n%s", tc.output, got)
		}
	}
}
64 |
--------------------------------------------------------------------------------
/sqlparser/visitorgen/ast_walker.go:
--------------------------------------------------------------------------------
1 | /*
2 | Copyright 2019 The Vitess Authors.
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | http://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | */
16 |
17 | package visitorgen
18 |
19 | import (
20 | "go/ast"
21 | "reflect"
22 | )
23 |
24 | var _ ast.Visitor = (*walker)(nil)
25 |
26 | type walker struct {
27 | result SourceFile
28 | }
29 |
30 | // Walk walks the given AST and translates it to the simplified AST used by the next steps
31 | func Walk(node ast.Node) *SourceFile {
32 | var w walker
33 | ast.Walk(&w, node)
34 | return &w.result
35 | }
36 |
// Visit implements the ast.Visitor interface.
//
// For a *ast.TypeSpec it records one of the following:
//   - an InterfaceDeclaration;
//   - a StructDeclaration, with one Field per named field (embedded fields
//     have no names and are not recorded);
//   - a TypeAlias, for array and identifier types.
//
// Any other kind of type spec panics with its reflect.Type.
//
// For a *ast.FuncDecl that is a method it records a FuncDeclaration
// carrying the receiver. Plain functions have no receiver and are skipped.
//
// Visit always returns w, so the walk continues into every child node.
func (w *walker) Visit(node ast.Node) ast.Visitor {
	switch n := node.(type) {
	case *ast.TypeSpec:
		switch t2 := n.Type.(type) {
		case *ast.InterfaceType:
			w.append(&InterfaceDeclaration{
				name:  n.Name.Name,
				block: "",
			})
		case *ast.StructType:
			var fields []*Field
			for _, f := range t2.Fields.List {
				// Embedded fields have no Names and are therefore skipped.
				for _, name := range f.Names {
					fields = append(fields, &Field{
						name: name.Name,
						typ:  sastType(f.Type),
					})
				}
			}
			w.append(&StructDeclaration{
				name:   n.Name.Name,
				fields: fields,
			})
		case *ast.ArrayType:
			w.append(&TypeAlias{
				name: n.Name.Name,
				typ:  &Array{inner: sastType(t2.Elt)},
			})
		case *ast.Ident:
			w.append(&TypeAlias{
				name: n.Name.Name,
				typ:  &TypeString{t2.Name},
			})
		default:
			panic(reflect.TypeOf(t2))
		}
	case *ast.FuncDecl:
		// n.Recv is nil for plain (non-method) functions, so it must be
		// checked before indexing into it. Such functions are not methods
		// of any type, so there is nothing to record for them.
		if n.Recv == nil || len(n.Recv.List) == 0 {
			break
		}
		if len(n.Recv.List) > 1 || len(n.Recv.List[0].Names) > 1 {
			panic("don't know what to do!")
		}
		recv := n.Recv.List[0]
		// An unnamed receiver (func (T) M()) keeps an empty name.
		f := &Field{typ: sastType(recv.Type)}
		if len(recv.Names) == 1 {
			f.name = recv.Names[0].Name
		}
		w.append(&FuncDeclaration{
			receiver:  f,
			name:      n.Name.Name,
			block:     "",
			arguments: nil,
		})
	}

	return w
}
110 |
111 | func (w *walker) append(line Sast) {
112 | w.result.lines = append(w.result.lines, line)
113 | }
114 |
// sastType converts a go/ast type expression into the simplified Type used
// by the generator. Pointers become Ref and slices/arrays become Array,
// converted recursively. Identifiers keep their name, and any interface or
// struct literal collapses to "interface{}" / "struct{}". Other expressions
// panic with their reflect.Type.
func sastType(e ast.Expr) Type {
	switch n := e.(type) {
	case *ast.Ident:
		return &TypeString{n.Name}
	case *ast.StarExpr:
		return &Ref{sastType(n.X)}
	case *ast.ArrayType:
		return &Array{inner: sastType(n.Elt)}
	case *ast.StructType:
		return &TypeString{"struct{}"}
	case *ast.InterfaceType:
		return &TypeString{"interface{}"}
	default:
		panic(reflect.TypeOf(e))
	}
}
131 |
--------------------------------------------------------------------------------
/sqlparser/visitorgen/ast_walker_test.go:
--------------------------------------------------------------------------------
1 | /*
2 | Copyright 2019 The Vitess Authors.
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | http://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | */
16 |
17 | package visitorgen
18 |
19 | import (
20 | "go/parser"
21 | "go/token"
22 | "testing"
23 |
24 | "github.com/stretchr/testify/assert"
25 |
26 | "github.com/stretchr/testify/require"
27 | )
28 |
29 | func TestSingleInterface(t *testing.T) {
30 | input := `
31 | package sqlparser
32 |
33 | type Nodeiface interface {
34 | iNode()
35 | }
36 | `
37 |
38 | fset := token.NewFileSet()
39 | ast, err := parser.ParseFile(fset, "ast.go", input, 0)
40 | require.NoError(t, err)
41 |
42 | result := Walk(ast)
43 | expected := SourceFile{
44 | lines: []Sast{&InterfaceDeclaration{
45 | name: "Nodeiface",
46 | block: "",
47 | }},
48 | }
49 | assert.Equal(t, expected.String(), result.String())
50 | }
51 |
52 | func TestEmptyStruct(t *testing.T) {
53 | input := `
54 | package sqlparser
55 |
56 | type Empty struct {}
57 | `
58 |
59 | fset := token.NewFileSet()
60 | ast, err := parser.ParseFile(fset, "ast.go", input, 0)
61 | require.NoError(t, err)
62 |
63 | result := Walk(ast)
64 | expected := SourceFile{
65 | lines: []Sast{&StructDeclaration{
66 | name: "Empty",
67 | fields: []*Field{},
68 | }},
69 | }
70 | assert.Equal(t, expected.String(), result.String())
71 | }
72 |
73 | func TestStructWithStringField(t *testing.T) {
74 | input := `
75 | package sqlparser
76 |
77 | type Struct struct {
78 | field string
79 | }
80 | `
81 |
82 | fset := token.NewFileSet()
83 | ast, err := parser.ParseFile(fset, "ast.go", input, 0)
84 | require.NoError(t, err)
85 |
86 | result := Walk(ast)
87 | expected := SourceFile{
88 | lines: []Sast{&StructDeclaration{
89 | name: "Struct",
90 | fields: []*Field{{
91 | name: "field",
92 | typ: &TypeString{typName: "string"},
93 | }},
94 | }},
95 | }
96 | assert.Equal(t, expected.String(), result.String())
97 | }
98 |
99 | func TestStructWithDifferentTypes(t *testing.T) {
100 | input := `
101 | package sqlparser
102 |
103 | type Struct struct {
104 | field string
105 | reference *string
106 | array []string
107 | arrayOfRef []*string
108 | }
109 | `
110 |
111 | fset := token.NewFileSet()
112 | ast, err := parser.ParseFile(fset, "ast.go", input, 0)
113 | require.NoError(t, err)
114 |
115 | result := Walk(ast)
116 | expected := SourceFile{
117 | lines: []Sast{&StructDeclaration{
118 | name: "Struct",
119 | fields: []*Field{{
120 | name: "field",
121 | typ: &TypeString{typName: "string"},
122 | }, {
123 | name: "reference",
124 | typ: &Ref{&TypeString{typName: "string"}},
125 | }, {
126 | name: "array",
127 | typ: &Array{&TypeString{typName: "string"}},
128 | }, {
129 | name: "arrayOfRef",
130 | typ: &Array{&Ref{&TypeString{typName: "string"}}},
131 | }},
132 | }},
133 | }
134 | assert.Equal(t, expected.String(), result.String())
135 | }
136 |
137 | func TestStructWithTwoStringFieldInOneLine(t *testing.T) {
138 | input := `
139 | package sqlparser
140 |
141 | type Struct struct {
142 | left, right string
143 | }
144 | `
145 |
146 | fset := token.NewFileSet()
147 | ast, err := parser.ParseFile(fset, "ast.go", input, 0)
148 | require.NoError(t, err)
149 |
150 | result := Walk(ast)
151 | expected := SourceFile{
152 | lines: []Sast{&StructDeclaration{
153 | name: "Struct",
154 | fields: []*Field{{
155 | name: "left",
156 | typ: &TypeString{typName: "string"},
157 | }, {
158 | name: "right",
159 | typ: &TypeString{typName: "string"},
160 | }},
161 | }},
162 | }
163 | assert.Equal(t, expected.String(), result.String())
164 | }
165 |
166 | func TestStructWithSingleMethod(t *testing.T) {
167 | input := `
168 | package sqlparser
169 |
170 | type Empty struct {}
171 |
172 | func (*Empty) method() {}
173 | `
174 |
175 | fset := token.NewFileSet()
176 | ast, err := parser.ParseFile(fset, "ast.go", input, 0)
177 | require.NoError(t, err)
178 |
179 | result := Walk(ast)
180 | expected := SourceFile{
181 | lines: []Sast{
182 | &StructDeclaration{
183 | name: "Empty",
184 | fields: []*Field{}},
185 | &FuncDeclaration{
186 | receiver: &Field{
187 | name: "",
188 | typ: &Ref{&TypeString{"Empty"}},
189 | },
190 | name: "method",
191 | block: "",
192 | arguments: []*Field{},
193 | },
194 | },
195 | }
196 | assert.Equal(t, expected.String(), result.String())
197 | }
198 |
199 | func TestSingleArrayType(t *testing.T) {
200 | input := `
201 | package sqlparser
202 |
203 | type Strings []string
204 | `
205 |
206 | fset := token.NewFileSet()
207 | ast, err := parser.ParseFile(fset, "ast.go", input, 0)
208 | require.NoError(t, err)
209 |
210 | result := Walk(ast)
211 | expected := SourceFile{
212 | lines: []Sast{&TypeAlias{
213 | name: "Strings",
214 | typ: &Array{&TypeString{"string"}},
215 | }},
216 | }
217 | assert.Equal(t, expected.String(), result.String())
218 | }
219 |
220 | func TestSingleTypeAlias(t *testing.T) {
221 | input := `
222 | package sqlparser
223 |
224 | type String string
225 | `
226 |
227 | fset := token.NewFileSet()
228 | ast, err := parser.ParseFile(fset, "ast.go", input, 0)
229 | require.NoError(t, err)
230 |
231 | result := Walk(ast)
232 | expected := SourceFile{
233 | lines: []Sast{&TypeAlias{
234 | name: "String",
235 | typ: &TypeString{"string"},
236 | }},
237 | }
238 | assert.Equal(t, expected.String(), result.String())
239 | }
240 |
--------------------------------------------------------------------------------
/sqlparser/visitorgen/main/main.go:
--------------------------------------------------------------------------------
1 | /*
2 | Copyright 2019 The Vitess Authors.
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | http://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | */
16 |
17 | package main
18 |
19 | import (
20 | "bytes"
21 | "flag"
22 | "fmt"
23 | "go/parser"
24 | "go/token"
25 | "io/ioutil"
26 | "os"
27 |
28 | "github.com/xelabs/go-mysqlstack/sqlparser/visitorgen"
29 | "github.com/xelabs/go-mysqlstack/xlog"
30 | )
31 |
// inputFile is the Go source (normally sqlparser's ast.go) to generate a visitor for. Required.
var inputFile = flag.String("input", "", "input file to use")

// outputFile is where the generated rewriter is written, or, with -compareOnly,
// the existing file it is compared against. Required.
var outputFile = flag.String("output", "", "output file")

// compare switches the tool from writing outputFile to checking that it is up to date.
var compare = flag.Bool("compareOnly", false, "instead of writing to the output file, compare if the generated visitor is still valid for this ast.go")
37 |
// usage is the banner printed by printUsage ahead of the flag defaults.
const usage = `Usage of visitorgen:

go run /path/to/visitorgen/main -input=/path/to/ast.go -output=/path/to/rewriter.go
`
42 |
43 | func main() {
44 | log := xlog.NewStdLog(xlog.Level(xlog.INFO))
45 | defer func() {
46 | if x := recover(); x != nil {
47 | log.Error("sqlparser.visitorgen.failed: %v", x)
48 | }
49 | }()
50 |
51 | flag.Usage = printUsage
52 | flag.Parse()
53 | if *inputFile == "" || *outputFile == "" {
54 | printUsage()
55 | os.Exit(0)
56 | }
57 | log.Info("sqlparser.visitorgen.inputFile[%s].outputFile[%s].compare[%t].start...\n", *inputFile, *outputFile, *compare)
58 |
59 | fs := token.NewFileSet()
60 | file, err := parser.ParseFile(fs, *inputFile, nil, parser.DeclarationErrors)
61 | if err != nil {
62 | panic(fmt.Sprintf("parse.file[%s].error[%v]", *inputFile, err))
63 | }
64 |
65 | astWalkResult := visitorgen.Walk(file)
66 | vp := visitorgen.Transform(astWalkResult)
67 | vd := visitorgen.ToVisitorPlan(vp)
68 |
69 | replacementMethods := visitorgen.EmitReplacementMethods(vd)
70 | typeSwitch := visitorgen.EmitTypeSwitches(vd)
71 |
72 | b := &bytes.Buffer{}
73 | fmt.Fprint(b, fileHeader)
74 | fmt.Fprintln(b)
75 | fmt.Fprintln(b, replacementMethods)
76 | fmt.Fprint(b, applyHeader)
77 | fmt.Fprintln(b, typeSwitch)
78 | fmt.Fprintln(b, fileFooter)
79 |
80 | if *compare {
81 | currentFile, err := ioutil.ReadFile(*outputFile)
82 | if err != nil {
83 | panic(fmt.Sprintf("read.file[%s].error[%v]", *outputFile, err))
84 | }
85 | if !bytes.Equal(b.Bytes(), currentFile) {
86 | fmt.Println("rewriter needs to be re-generated: go generate " + *outputFile)
87 | os.Exit(1)
88 | }
89 | log.Info("sqlparser.visitorgen.compare.success")
90 | } else {
91 | err = ioutil.WriteFile(*outputFile, b.Bytes(), 0644)
92 | if err != nil {
93 | panic(fmt.Sprintf("write.file[%s].error[%v]", *outputFile, err))
94 | }
95 | }
96 | log.Info("sqlparser.visitorgen.finish...")
97 | }
98 |
99 | func printUsage() {
100 | os.Stderr.WriteString(usage)
101 | os.Stderr.WriteString("\nOptions:\n")
102 | flag.PrintDefaults()
103 | }
104 |
// fileHeader opens the generated rewriter.go. It contains the generated-code
// marker, the package clause, the go:generate directive that re-runs this tool,
// the reflect import, and the replacerFunc/application declarations used by the
// generated code. Its bytes end up verbatim in the output, so any edit here
// requires regenerating rewriter.go.
const fileHeader = `// Code generated by visitorgen/main/main.go. DO NOT EDIT.

package sqlparser

//go:generate go run ./visitorgen/main -input=ast.go -output=rewriter.go

import (
"reflect"
)

type replacerFunc func(newNode, parent SQLNode)

// application carries all the shared data so we can pass it around cheaply.
type application struct {
pre, post ApplyFunc
cursor Cursor
}
`
123 |
// applyHeader starts the generated apply method. The method:
//   - returns early for nil nodes;
//   - saves a.cursor and reuses it for this node;
//   - runs the pre callback, restoring the cursor and returning when it
//     reports false;
//   - opens the type switch.
// main writes the cases produced by EmitTypeSwitches right after it, and
// fileFooter closes the switch.
const applyHeader = `
// apply is where the visiting happens. Here is where we keep the big switch-case that will be used
// to do the actual visiting of SQLNodes
func (a *application) apply(parent, node SQLNode, replacer replacerFunc) {
if node == nil || isNilValue(node) {
return
}

// avoid heap-allocating a new cursor for each apply call; reuse a.cursor instead
saved := a.cursor
a.cursor.replacer = replacer
a.cursor.node = node
a.cursor.parent = parent

if a.pre != nil && !a.pre(&a.cursor) {
a.cursor = saved
return
}

// walk children
// (the order of the cases is alphabetical)
switch n := node.(type) {
case nil:
`
148 |
// fileFooter closes the generated apply method. The generated code:
//   - panics on node types the generator did not know about;
//   - runs the post callback, panicking with abort when it reports false;
//   - restores the saved cursor.
// It then defines the isNilValue helper used by apply.
//
// isNilValue treats only pointer and slice kinds as nullable.
// reflect.Value.IsNil panics for array kinds, and an array value can never
// be nil, so arrays must not reach IsNil.
// NOTE: this text is copied verbatim into rewriter.go. After changing it,
// regenerate the file (go generate) so that -compareOnly keeps passing.
const fileFooter = `
default:
panic("unknown ast type " + reflect.TypeOf(node).String())
}

if a.post != nil && !a.post(&a.cursor) {
panic(abort)
}

a.cursor = saved
}

func isNilValue(i interface{}) bool {
valueOf := reflect.ValueOf(i)
kind := valueOf.Kind()
isNullable := kind == reflect.Ptr || kind == reflect.Slice
return isNullable && valueOf.IsNil()
}`
167 |
--------------------------------------------------------------------------------
/sqlparser/visitorgen/sast.go:
--------------------------------------------------------------------------------
1 | /*
2 | Copyright 2019 The Vitess Authors.
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | http://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | */
16 |
17 | package visitorgen
18 |
// The declarations below form a simplified AST ("sast"). When the golang AST of
// an ast.go file is walked, its objects are translated into this much smaller
// model, which keeps only what the visitor generator needs.

// SourceFile holds every ast.go line the generator is interested in.
type SourceFile struct {
	lines []Sast
}

// Sast (simplified AST) is one ast.go declaration the generator cares about.
type Sast interface {
	toSastString() string
}

// InterfaceDeclaration is a declared interface. Interfaces tell the generator
// which types the visitor framework has to handle.
type InterfaceDeclaration struct {
	name  string
	block string
}

// TypeAlias records a `type XXX YYY` declaration, where XXX is the new name for
// YYY. YYY may itself be an array or a reference.
type TypeAlias struct {
	name string
	typ  Type
}

// FuncDeclaration is a function or method declaration. Methods are tracked to
// learn which types implement the interfaces.
type FuncDeclaration struct {
	receiver  *Field
	name      string
	block     string
	arguments []*Field
}

// StructDeclaration is a struct together with its fields and their types.
type StructDeclaration struct {
	name   string
	fields []*Field
}

// Field is a name/type pair, used for struct fields, receivers and arguments.
type Field struct {
	name string
	typ  Type
}

// Type is a type in the golang type system. It is used both for the types the
// visitor must handle and for the types of fields.
type Type interface {
	toTypString() string
	rawTypeName() string
}

// TypeString is a plain named type, such as `string`.
type TypeString struct {
	typName string
}

// Ref is a pointer to another type, such as `*string`.
type Ref struct {
	inner Type
}

// Array is a slice of another type, such as `[]string`.
type Array struct {
	inner Type
}
86 |
87 | var _ Sast = (*InterfaceDeclaration)(nil)
88 | var _ Sast = (*StructDeclaration)(nil)
89 | var _ Sast = (*FuncDeclaration)(nil)
90 | var _ Sast = (*TypeAlias)(nil)
91 |
92 | var _ Type = (*TypeString)(nil)
93 | var _ Type = (*Ref)(nil)
94 | var _ Type = (*Array)(nil)
95 |
96 | // String returns a textual representation of the SourceFile. This is for testing purposed
97 | func (t *SourceFile) String() string {
98 | var result string
99 | for _, l := range t.lines {
100 | result += l.toSastString()
101 | result += "\n"
102 | }
103 |
104 | return result
105 | }
106 |
107 | func (t *Ref) toTypString() string {
108 | return "*" + t.inner.toTypString()
109 | }
110 |
111 | func (t *Array) toTypString() string {
112 | return "[]" + t.inner.toTypString()
113 | }
114 |
115 | func (t *TypeString) toTypString() string {
116 | return t.typName
117 | }
118 |
119 | func (f *FuncDeclaration) toSastString() string {
120 | var receiver string
121 | if f.receiver != nil {
122 | receiver = "(" + f.receiver.String() + ") "
123 | }
124 | var args string
125 | for i, arg := range f.arguments {
126 | if i > 0 {
127 | args += ", "
128 | }
129 | args += arg.String()
130 | }
131 |
132 | return "func " + receiver + f.name + "(" + args + ") {" + blockInNewLines(f.block) + "}"
133 | }
134 |
135 | func (i *InterfaceDeclaration) toSastString() string {
136 | return "type " + i.name + " interface {" + blockInNewLines(i.block) + "}"
137 | }
138 |
139 | func (a *TypeAlias) toSastString() string {
140 | return "type " + a.name + " " + a.typ.toTypString()
141 | }
142 |
143 | func (s *StructDeclaration) toSastString() string {
144 | var block string
145 | for _, f := range s.fields {
146 | block += "\t" + f.String() + "\n"
147 | }
148 |
149 | return "type " + s.name + " struct {" + blockInNewLines(block) + "}"
150 | }
151 |
// blockInNewLines wraps a non-empty block in leading and trailing newlines. An
// empty block stays empty, so empty bodies render as "{}".
func blockInNewLines(block string) string {
	if len(block) > 0 {
		return "\n" + block + "\n"
	}
	return ""
}
158 |
159 | // String returns a string representation of a field
160 | func (f *Field) String() string {
161 | if f.name != "" {
162 | return f.name + " " + f.typ.toTypString()
163 | }
164 |
165 | return f.typ.toTypString()
166 | }
167 |
168 | func (t *TypeString) rawTypeName() string {
169 | return t.typName
170 | }
171 |
172 | func (t *Ref) rawTypeName() string {
173 | return t.inner.rawTypeName()
174 | }
175 |
176 | func (t *Array) rawTypeName() string {
177 | return t.inner.rawTypeName()
178 | }
179 |
--------------------------------------------------------------------------------
/sqlparser/visitorgen/transformer.go:
--------------------------------------------------------------------------------
1 | /*
2 | Copyright 2019 The Vitess Authors.
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | http://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | */
16 |
17 | package visitorgen
18 |
import (
	"fmt"
	"sort"
	"strings"
)
20 |
21 | // Transform takes an input file and collects the information into an easier to consume format
22 | func Transform(input *SourceFile) *SourceInformation {
23 | interestingTypes := make(map[string]Type)
24 | interfaces := make(map[string]bool)
25 | structs := make(map[string]*StructDeclaration)
26 | typeAliases := make(map[string]*TypeAlias)
27 |
28 | for _, l := range input.lines {
29 | switch line := l.(type) {
30 | case *FuncDeclaration:
31 | interestingTypes[line.receiver.typ.toTypString()] = line.receiver.typ
32 | case *StructDeclaration:
33 | structs[line.name] = line
34 | case *TypeAlias:
35 | typeAliases[line.name] = line
36 | case *InterfaceDeclaration:
37 | interfaces[line.name] = true
38 | }
39 | }
40 |
41 | return &SourceInformation{
42 | interfaces: interfaces,
43 | interestingTypes: interestingTypes,
44 | structs: structs,
45 | typeAliases: typeAliases,
46 | }
47 | }
48 |
49 | // SourceInformation contains the information from the ast.go file, but in a format that is easier to consume
50 | type SourceInformation struct {
51 | interestingTypes map[string]Type
52 | interfaces map[string]bool
53 | structs map[string]*StructDeclaration
54 | typeAliases map[string]*TypeAlias
55 | }
56 |
57 | func (v *SourceInformation) String() string {
58 | var types string
59 | for _, k := range v.interestingTypes {
60 | types += k.toTypString() + "\n"
61 | }
62 | var structs string
63 | for _, k := range v.structs {
64 | structs += k.toSastString() + "\n"
65 | }
66 | var typeAliases string
67 | for _, k := range v.typeAliases {
68 | typeAliases += k.toSastString() + "\n"
69 | }
70 |
71 | return fmt.Sprintf("Types to build visitor for:\n%s\nStructs with fields: \n%s\nTypeAliases with type: \n%s\n", types, structs, typeAliases)
72 | }
73 |
74 | // getItemTypeOfArray will return nil if the given type is not pointing to a array type.
75 | // If it is an array type, the type of it's items will be returned
76 | func (v *SourceInformation) getItemTypeOfArray(typ Type) Type {
77 | alias := v.typeAliases[typ.rawTypeName()]
78 | if alias == nil {
79 | return nil
80 | }
81 | arrTyp, isArray := alias.typ.(*Array)
82 | if !isArray {
83 | return v.getItemTypeOfArray(alias.typ)
84 | }
85 | return arrTyp.inner
86 | }
87 |
88 | func (v *SourceInformation) isSQLNode(typ Type) bool {
89 | _, isInteresting := v.interestingTypes[typ.toTypString()]
90 | if isInteresting {
91 | return true
92 | }
93 | _, isInterface := v.interfaces[typ.toTypString()]
94 | return isInterface
95 | }
96 |
--------------------------------------------------------------------------------
/sqlparser/visitorgen/transformer_test.go:
--------------------------------------------------------------------------------
1 | /*
2 | Copyright 2019 The Vitess Authors.
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | http://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | */
16 |
17 | package visitorgen
18 |
19 | import (
20 | "testing"
21 |
22 | "github.com/stretchr/testify/assert"
23 | )
24 |
25 | func TestSimplestAst(t *testing.T) {
26 | /*
27 | type NodeInterface interface {
28 | iNode()
29 | }
30 |
31 | type NodeStruct struct {}
32 |
33 | func (*NodeStruct) iNode{}
34 | */
35 | input := &SourceFile{
36 | lines: []Sast{
37 | &InterfaceDeclaration{
38 | name: "NodeInterface",
39 | block: "// an interface lives here"},
40 | &StructDeclaration{
41 | name: "NodeStruct",
42 | fields: []*Field{}},
43 | &FuncDeclaration{
44 | receiver: &Field{
45 | name: "",
46 | typ: &Ref{&TypeString{"NodeStruct"}},
47 | },
48 | name: "iNode",
49 | block: "",
50 | arguments: []*Field{}},
51 | },
52 | }
53 |
54 | expected := &SourceInformation{
55 | interestingTypes: map[string]Type{
56 | "*NodeStruct": &Ref{&TypeString{"NodeStruct"}}},
57 | structs: map[string]*StructDeclaration{
58 | "NodeStruct": {
59 | name: "NodeStruct",
60 | fields: []*Field{}}},
61 | }
62 |
63 | assert.Equal(t, expected.String(), Transform(input).String())
64 | }
65 |
66 | func TestAstWithArray(t *testing.T) {
67 | /*
68 | type NodeInterface interface {
69 | iNode()
70 | }
71 |
72 | func (*NodeArray) iNode{}
73 |
74 | type NodeArray []NodeInterface
75 | */
76 | input := &SourceFile{
77 | lines: []Sast{
78 | &InterfaceDeclaration{
79 | name: "NodeInterface"},
80 | &TypeAlias{
81 | name: "NodeArray",
82 | typ: &Array{&TypeString{"NodeInterface"}},
83 | },
84 | &FuncDeclaration{
85 | receiver: &Field{
86 | name: "",
87 | typ: &Ref{&TypeString{"NodeArray"}},
88 | },
89 | name: "iNode",
90 | block: "",
91 | arguments: []*Field{}},
92 | },
93 | }
94 |
95 | expected := &SourceInformation{
96 | interestingTypes: map[string]Type{
97 | "*NodeArray": &Ref{&TypeString{"NodeArray"}}},
98 | structs: map[string]*StructDeclaration{},
99 | typeAliases: map[string]*TypeAlias{
100 | "NodeArray": {
101 | name: "NodeArray",
102 | typ: &Array{&TypeString{"NodeInterface"}},
103 | },
104 | },
105 | }
106 |
107 | result := Transform(input)
108 |
109 | assert.Equal(t, expected.String(), result.String())
110 | }
111 |
--------------------------------------------------------------------------------
/sqlparser/visitorgen/visitor_emitter.go:
--------------------------------------------------------------------------------
1 | /*
2 | Copyright 2019 The Vitess Authors.
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | http://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | */
16 |
17 | package visitorgen
18 |
19 | import (
20 | "fmt"
21 | "strings"
22 | )
23 |
24 | // EmitReplacementMethods is an anti-parser (a.k.a prettifier) - it takes a struct that is much like an AST,
25 | // and produces a string from it. This method will produce the replacement methods that make it possible to
26 | // replace objects in fields or in slices.
27 | func EmitReplacementMethods(vd *VisitorPlan) string {
28 | var sb builder
29 | for _, s := range vd.Switches {
30 | for _, k := range s.Fields {
31 | sb.appendF(k.asReplMethod())
32 | sb.newLine()
33 | }
34 | }
35 |
36 | return sb.String()
37 | }
38 |
39 | // EmitTypeSwitches is an anti-parser (a.k.a prettifier) - it takes a struct that is much like an AST,
40 | // and produces a string from it. This method will produce the switch cases needed to cover the Vitess AST.
41 | func EmitTypeSwitches(vd *VisitorPlan) string {
42 | var sb builder
43 | for _, s := range vd.Switches {
44 | sb.newLine()
45 | sb.appendF(" case %s:", s.Type.toTypString())
46 | for _, k := range s.Fields {
47 | sb.appendF(k.asSwitchCase())
48 | }
49 | }
50 |
51 | return sb.String()
52 | }
53 |
// builder accumulates generated source text line by line.
type builder struct {
	sb strings.Builder
}

// appendF formats data according to format, appends the result followed by a
// newline, and returns the builder so calls can be chained.
func (b *builder) appendF(format string, data ...interface{}) *builder {
	if _, err := fmt.Fprintf(&b.sb, format, data...); err != nil {
		panic(err)
	}
	b.newLine()
	return b
}

// newLine appends a single line break.
func (b *builder) newLine() {
	if err := b.sb.WriteByte('\n'); err != nil {
		panic(err)
	}
}

// String returns the accumulated text with leading and trailing whitespace removed.
func (b *builder) String() string {
	return strings.TrimSpace(b.sb.String())
}
77 |
--------------------------------------------------------------------------------
/sqlparser/visitorgen/visitor_emitter_test.go:
--------------------------------------------------------------------------------
1 | /*
2 | Copyright 2019 The Vitess Authors.
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | http://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | */
16 |
17 | package visitorgen
18 |
19 | import (
20 | "testing"
21 |
22 | "github.com/stretchr/testify/require"
23 | )
24 |
25 | func TestSingleItem(t *testing.T) {
26 | sfi := SingleFieldItem{
27 | StructType: &Ref{&TypeString{"Struct"}},
28 | FieldType: &TypeString{"string"},
29 | FieldName: "Field",
30 | }
31 |
32 | expectedReplacer := `func replaceStructField(newNode, parent SQLNode) {
33 | parent.(*Struct).Field = newNode.(string)
34 | }`
35 |
36 | expectedSwitch := ` a.apply(node, n.Field, replaceStructField)`
37 | require.Equal(t, expectedReplacer, sfi.asReplMethod())
38 | require.Equal(t, expectedSwitch, sfi.asSwitchCase())
39 | }
40 |
41 | func TestArrayFieldItem(t *testing.T) {
42 | sfi := ArrayFieldItem{
43 | StructType: &Ref{&TypeString{"Struct"}},
44 | ItemType: &TypeString{"string"},
45 | FieldName: "Field",
46 | }
47 |
48 | expectedReplacer := `type replaceStructField int
49 |
50 | func (r *replaceStructField) replace(newNode, container SQLNode) {
51 | container.(*Struct).Field[int(*r)] = newNode.(string)
52 | }
53 |
54 | func (r *replaceStructField) inc() {
55 | *r++
56 | }`
57 |
58 | expectedSwitch := ` replacerField := replaceStructField(0)
59 | replacerFieldB := &replacerField
60 | for _, item := range n.Field {
61 | a.apply(node, item, replacerFieldB.replace)
62 | replacerFieldB.inc()
63 | }`
64 | require.Equal(t, expectedReplacer, sfi.asReplMethod())
65 | require.Equal(t, expectedSwitch, sfi.asSwitchCase())
66 | }
67 |
68 | func TestArrayItem(t *testing.T) {
69 | sfi := ArrayItem{
70 | StructType: &Ref{&TypeString{"Struct"}},
71 | ItemType: &TypeString{"string"},
72 | }
73 |
74 | expectedReplacer := `type replaceStructItems int
75 |
76 | func (r *replaceStructItems) replace(newNode, container SQLNode) {
77 | container.(*Struct)[int(*r)] = newNode.(string)
78 | }
79 |
80 | func (r *replaceStructItems) inc() {
81 | *r++
82 | }`
83 |
84 | expectedSwitch := ` replacer := replaceStructItems(0)
85 | replacerRef := &replacer
86 | for _, item := range n {
87 | a.apply(node, item, replacerRef.replace)
88 | replacerRef.inc()
89 | }`
90 | require.Equal(t, expectedReplacer, sfi.asReplMethod())
91 | require.Equal(t, expectedSwitch, sfi.asSwitchCase())
92 | }
93 |
--------------------------------------------------------------------------------
/sqlparser/visitorgen/visitorgen.go:
--------------------------------------------------------------------------------
1 | /*
2 | Copyright 2019 The Vitess Authors.
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | http://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | */
16 |
// Package visitorgen takes the sqlparser ast.go (originally from Vitess) and
// produces the visitor infrastructure for it.
//
// This is accomplished in a few steps.
//
// Step 1: Walk the AST and collect the interesting information into a format
// that is easy to consume in the next step. The output is a *SourceFile, which
// holds the needed information in a form close to the golang ast, but simplified.
//
// Step 2: The SourceFile is packaged into a SourceInformation. SourceInformation
// is still concerned with the input ast; it is just an even more distilled,
// easier-to-consume format for the final steps. This step is performed by the
// code in transformer.go.
//
// Step 3: Using the SourceInformation, the code in struct_producer.go produces
// the final data structure, a VisitorPlan. It is focused on the output: it lists
// every field or array that the generated visitor needs to handle.
//
// Step 4: Finally, the VisitorPlan is turned into source text (visitor_emitter.go),
// which is written as the output of this whole process.
33 | package visitorgen
34 |
--------------------------------------------------------------------------------
/sqlparser/xa_test.go:
--------------------------------------------------------------------------------
1 | /*
2 | Copyright 2017 Google Inc.
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | http://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | */
16 |
17 | package sqlparser
18 |
19 | import (
20 | "strings"
21 | "testing"
22 | )
23 |
24 | func TestXA(t *testing.T) {
25 | validSQL := []struct {
26 | input string
27 | output string
28 | }{
29 | {
30 | input: "xa begin 'x1'",
31 | output: "XA",
32 | },
33 |
34 | {
35 | input: "xa prepare 'x1'",
36 | output: "XA",
37 | },
38 |
39 | {
40 | input: "xa end 'x1'",
41 | output: "XA",
42 | },
43 |
44 | {
45 | input: "xa commit 'x1'",
46 | output: "XA",
47 | },
48 |
49 | {
50 | input: "xa rollback 'x1'",
51 | output: "XA",
52 | },
53 | }
54 |
55 | for _, exp := range validSQL {
56 | sql := strings.TrimSpace(exp.input)
57 | tree, err := Parse(sql)
58 | if err != nil {
59 | t.Errorf("input: %s, err: %v", sql, err)
60 | continue
61 | }
62 |
63 | // Walk.
64 | Walk(func(node SQLNode) (bool, error) {
65 | return true, nil
66 | }, tree)
67 |
68 | node := tree.(*Xa)
69 |
70 | // Format.
71 | got := String(node)
72 | if exp.output != got {
73 | t.Errorf("want:\n%s\ngot:\n%s", exp.output, got)
74 | }
75 | }
76 | }
77 |
--------------------------------------------------------------------------------
/xlog/options.go:
--------------------------------------------------------------------------------
1 | /*
2 | * go-mysqlstack
3 | * xelabs.org
4 | *
5 | * Copyright (c) XeLabs
6 | * GPL License
7 | *
8 | */
9 |
10 | package xlog
11 |
12 | var (
13 | defaultName = " "
14 | defaultLevel = DEBUG
15 | )
16 |
17 | // Options used for the options of the xlog.
18 | type Options struct {
19 | Name string
20 | Level LogLevel
21 | }
22 |
23 | // Option func.
24 | type Option func(*Options)
25 |
26 | func newOptions(opts ...Option) *Options {
27 | opt := &Options{}
28 | for _, o := range opts {
29 | o(opt)
30 | }
31 |
32 | if len(opt.Name) == 0 {
33 | opt.Name = defaultName
34 | }
35 |
36 | if opt.Level == 0 {
37 | opt.Level = defaultLevel
38 | }
39 | return opt
40 | }
41 |
42 | // Name used to set the name.
43 | func Name(v string) Option {
44 | return func(o *Options) {
45 | o.Name = v
46 | }
47 | }
48 |
49 | // Level used to set the log level.
50 | func Level(v LogLevel) Option {
51 | return func(o *Options) {
52 | o.Level = v
53 | }
54 | }
55 |
--------------------------------------------------------------------------------
/xlog/syslog.go:
--------------------------------------------------------------------------------
1 | // +build linux darwin dragonfly freebsd netbsd openbsd solaris
2 |
3 | package xlog
4 |
5 | import "log/syslog"
6 |
7 | // NewSysLog creates a new sys log.
8 | func NewSysLog(opts ...Option) *Log {
9 | w, err := syslog.New(syslog.LOG_DEBUG, "")
10 | if err != nil {
11 | panic(err)
12 | }
13 | return NewXLog(w, opts...)
14 | }
15 |
--------------------------------------------------------------------------------
/xlog/syslog_windows.go:
--------------------------------------------------------------------------------
1 | // +build windows
2 |
3 | package xlog
4 |
5 | import (
6 | "os"
7 | )
8 |
9 | // NewSysLog creates a new sys log. Because there is no syslog support for
10 | // Windows, we output to os.Stdout.
11 | func NewSysLog(opts ...Option) *Log {
12 | return NewXLog(os.Stdout, opts...)
13 | }
14 |
--------------------------------------------------------------------------------
/xlog/xlog.go:
--------------------------------------------------------------------------------
1 | /*
2 | * go-mysqlstack
3 | * xelabs.org
4 | *
5 | * Copyright (c) XeLabs
6 | * GPL License
7 | *
8 | */
9 |
10 | package xlog
11 |
12 | import (
13 | "fmt"
14 | "io"
15 | "log"
16 | "os"
17 | "strings"
18 | )
19 |
20 | var (
21 | defaultlog *Log
22 | )
23 |
// LogLevel is the severity of a log message. Every level is a distinct power
// of two, increasing from DEBUG (least severe) to PANIC (most severe).
type LogLevel int

// Severity levels, in increasing order.
const (
	// DEBUG enum.
	DEBUG LogLevel = 1
	// INFO enum.
	INFO LogLevel = 2
	// WARNING enum.
	WARNING LogLevel = 4
	// ERROR enum.
	ERROR LogLevel = 8
	// FATAL enum.
	FATAL LogLevel = 16
	// PANIC enum.
	PANIC LogLevel = 32
)

// LevelNames is indexed by LogLevel value and holds each level's name.
// Indexes that are not a defined level contain the empty string.
var LevelNames = [PANIC + 1]string{
	DEBUG:   "DEBUG",
	INFO:    "INFO",
	WARNING: "WARNING",
	ERROR:   "ERROR",
	FATAL:   "FATAL",
	PANIC:   "PANIC",
}

// D_LOG_FLAGS is the default log flags: date, time with microseconds, and
// the short file:line of the call site.
const D_LOG_FLAGS = int(log.LstdFlags | log.Lmicroseconds | log.Lshortfile)
56 |
57 | // Log struct.
58 | type Log struct {
59 | opts *Options
60 | *log.Logger
61 | }
62 |
63 | // NewStdLog creates a new std log.
64 | func NewStdLog(opts ...Option) *Log {
65 | return NewXLog(os.Stdout, opts...)
66 | }
67 |
68 | // NewXLog creates a new xlog.
69 | func NewXLog(w io.Writer, opts ...Option) *Log {
70 | options := newOptions(opts...)
71 |
72 | l := &Log{
73 | opts: options,
74 | }
75 | l.Logger = log.New(w, l.opts.Name, D_LOG_FLAGS)
76 | defaultlog = l
77 | return l
78 | }
79 |
80 | // NewLog creates the new log.
81 | func NewLog(w io.Writer, prefix string, flag int) *Log {
82 | l := &Log{}
83 | l.Logger = log.New(w, prefix, flag)
84 | return l
85 | }
86 |
87 | // GetLog returns Log.
88 | func GetLog() *Log {
89 | if defaultlog == nil {
90 | log := NewStdLog(Level(INFO))
91 | defaultlog = log
92 | }
93 | return defaultlog
94 | }
95 |
96 | // SetLevel used to set the log level.
97 | func (t *Log) SetLevel(level string) {
98 | for i, v := range LevelNames {
99 | if level == v {
100 | t.opts.Level = LogLevel(i)
101 | return
102 | }
103 | }
104 | }
105 |
106 | // Debug used to log debug msg.
107 | func (t *Log) Debug(format string, v ...interface{}) {
108 | if DEBUG < t.opts.Level {
109 | return
110 | }
111 | t.log("\t [DEBUG] \t%s", fmt.Sprintf(format, v...))
112 | }
113 |
114 | // Info used to log info msg.
115 | func (t *Log) Info(format string, v ...interface{}) {
116 | if INFO < t.opts.Level {
117 | return
118 | }
119 | t.log("\t [INFO] \t%s", fmt.Sprintf(format, v...))
120 | }
121 |
122 | // Warning used to log warning msg.
123 | func (t *Log) Warning(format string, v ...interface{}) {
124 | if WARNING < t.opts.Level {
125 | return
126 | }
127 | t.log("\t [WARNING] \t%s", fmt.Sprintf(format, v...))
128 | }
129 |
130 | // Error used to log error msg.
131 | func (t *Log) Error(format string, v ...interface{}) {
132 | if ERROR < t.opts.Level {
133 | return
134 | }
135 | t.log("\t [ERROR] \t%s", fmt.Sprintf(format, v...))
136 | }
137 |
138 | // Fatal used to log faltal msg.
139 | func (t *Log) Fatal(format string, v ...interface{}) {
140 | if FATAL < t.opts.Level {
141 | return
142 | }
143 | t.log("\t [FATAL+EXIT] \t%s", fmt.Sprintf(format, v...))
144 | os.Exit(1)
145 | }
146 |
147 | // Panic used to log panic msg.
148 | func (t *Log) Panic(format string, v ...interface{}) {
149 | if PANIC < t.opts.Level {
150 | return
151 | }
152 | msg := fmt.Sprintf("\t [PANIC] \t%s", fmt.Sprintf(format, v...))
153 | t.log(msg)
154 | panic(msg)
155 | }
156 |
157 | // Close used to close the log.
158 | func (t *Log) Close() {
159 | // nothing
160 | }
161 |
162 | func (t *Log) log(format string, v ...interface{}) {
163 | t.Output(3, strings.Repeat(" ", 3)+fmt.Sprintf(format, v...)+"\n")
164 | }
165 |
--------------------------------------------------------------------------------
/xlog/xlog_test.go:
--------------------------------------------------------------------------------
1 | /*
2 | * go-mysqlstack
3 | * xelabs.org
4 | *
5 | * Copyright (c) XeLabs
6 | * GPL License
7 | *
8 | */
9 |
10 | package xlog
11 |
12 | import (
13 | "testing"
14 | )
15 |
// Assert fails the test immediately when condition is false. It reports msg
// formatted with v (fmt.Sprintf-style), so the failure explains itself.
func Assert(tb testing.TB, condition bool, msg string, v ...interface{}) {
	// Attribute the failure to the caller's line, not to this helper.
	tb.Helper()
	if !condition {
		tb.Fatalf(msg, v...)
	}
}
22 |
23 | func TestGetLog(t *testing.T) {
24 | GetLog().Debug("DEBUG")
25 | log := NewStdLog()
26 | log.SetLevel("INFO")
27 | GetLog().Debug("DEBUG")
28 | GetLog().Info("INFO")
29 | }
30 |
31 | func TestSysLog(t *testing.T) {
32 | log := NewSysLog()
33 |
34 | log.Debug("DEBUG")
35 | log.Info("INFO")
36 | log.Warning("WARNING")
37 | log.Error("ERROR")
38 |
39 | log.SetLevel("DEBUG")
40 | log.Debug("DEBUG")
41 | log.Info("INFO")
42 | log.Warning("WARNING")
43 | log.Error("ERROR")
44 |
45 | log.SetLevel("INFO")
46 | log.Debug("DEBUG")
47 | log.Info("INFO")
48 | log.Warning("WARNING")
49 | log.Error("ERROR")
50 |
51 | log.SetLevel("WARNING")
52 | log.Debug("DEBUG")
53 | log.Info("INFO")
54 | log.Warning("WARNING")
55 | log.Error("ERROR")
56 |
57 | log.SetLevel("ERROR")
58 | log.Debug("DEBUG")
59 | log.Info("INFO")
60 | log.Warning("WARNING")
61 | log.Error("ERROR")
62 | }
63 |
64 | func TestStdLog(t *testing.T) {
65 | log := NewStdLog()
66 |
67 | log.Println("........DEFAULT........")
68 | log.Debug("DEBUG")
69 | log.Info("INFO")
70 | log.Warning("WARNING")
71 | log.Error("ERROR")
72 |
73 | log.Println("........DEBUG........")
74 | log.SetLevel("DEBUG")
75 | log.Debug("DEBUG")
76 | log.Info("INFO")
77 | log.Warning("WARNING")
78 | log.Error("ERROR")
79 |
80 | log.Println("........INFO........")
81 | log.SetLevel("INFO")
82 | log.Debug("DEBUG")
83 | log.Info("INFO")
84 | log.Warning("WARNING")
85 | log.Error("ERROR")
86 |
87 | log.Println("........WARNING........")
88 | log.SetLevel("WARNING")
89 | log.Debug("DEBUG")
90 | log.Info("INFO")
91 | log.Warning("WARNING")
92 | log.Error("ERROR")
93 |
94 | log.Println("........ERROR........")
95 | log.SetLevel("ERROR")
96 | log.Debug("DEBUG")
97 | log.Info("INFO")
98 | log.Warning("WARNING")
99 | log.Error("ERROR")
100 | }
101 |
102 | func TestLogLevel(t *testing.T) {
103 | log := NewStdLog()
104 | {
105 | log.SetLevel("DEBUG")
106 | want := DEBUG
107 | got := log.opts.Level
108 | Assert(t, want == got, "want[%v]!=got[%v]", want, got)
109 | }
110 |
111 | {
112 | log.SetLevel("DEBUGX")
113 | want := DEBUG
114 | got := log.opts.Level
115 | Assert(t, want == got, "want[%v]!=got[%v]", want, got)
116 | }
117 |
118 | {
119 | log.SetLevel("PANIC")
120 | want := PANIC
121 | got := log.opts.Level
122 | Assert(t, want == got, "want[%v]!=got[%v]", want, got)
123 | }
124 |
125 | {
126 | log.SetLevel("WARNING")
127 | want := WARNING
128 | got := log.opts.Level
129 | Assert(t, want == got, "want[%v]!=got[%v]", want, got)
130 | }
131 | }
132 |
--------------------------------------------------------------------------------