├── .gitignore ├── LICENSE ├── README.md ├── backend ├── conn.go ├── conn_test.go ├── db.go ├── stmt.go └── stmt_test.go ├── bootstrap.sh ├── cmd └── proxy │ ├── .gitignore │ └── proxy.go ├── config ├── config.go └── config_test.go ├── docs ├── internal │ └── Design.md ├── mysql-proxy │ ├── protocol.txt │ └── scripting.txt └── protocol.txt ├── hack ├── hack.go └── hack_test.go ├── log ├── logger.go ├── milog.go └── milog_test.go ├── makefile ├── mysql ├── charset.go ├── const.go ├── debug.go ├── errcode.go ├── errname.go ├── error.go ├── field.go ├── packet.go ├── packetio.go ├── result.go ├── resultset.go ├── resultset_sort.go ├── resultset_sort_test.go ├── state.go └── util.go ├── pool ├── .gitignore ├── slice.go ├── slice1.go ├── slice1_test.go ├── slice_test.go └── utils.go ├── proxy ├── auth.go ├── conn.go ├── conn_auth.go ├── conn_query.go ├── conn_resultset.go ├── conn_select.go ├── conn_set.go ├── conn_show.go ├── conn_stmt.go ├── conn_stmt_test.go ├── conn_test.go ├── conn_tx.go ├── node.go ├── schema.go ├── server.go ├── server_test.go └── signal.go ├── run.sh ├── sql ├── .gitignore ├── Makefile ├── ast.go ├── ast_alter.go ├── ast_compound.go ├── ast_create.go ├── ast_dal.go ├── ast_ddl.go ├── ast_dml.go ├── ast_drop.go ├── ast_expr.go ├── ast_prepare.go ├── ast_replication.go ├── ast_show.go ├── ast_table.go ├── ast_trans.go ├── ast_util.go ├── bin │ └── yacc ├── charset │ ├── charset.go │ ├── charset_test.go │ └── utf8_general_cli.go ├── debug.go ├── lex.go ├── lex_ident.go ├── lex_ident_test.go ├── lex_keywords.go ├── lex_keywords_test.go ├── lex_nchar.go ├── lex_number.go ├── lex_number_test.go ├── lex_test.go ├── lex_text.go ├── lex_text_test.go ├── lex_var_test.go ├── parser.go ├── parser_dal_test.go ├── parser_ddl_test.go ├── parser_dml_test.go ├── parser_test.go ├── parser_token.go ├── parser_trans_test.go ├── parser_util_test.go ├── sql_yacc.go ├── sql_yacc.prf ├── sql_yacc.yy ├── state │ └── state.go └── test.sh ├── sqltypes ├── sqltypes.go └── 
type_test.go └── wercker.yml /.gitignore: -------------------------------------------------------------------------------- 1 | # Compiled Object files, Static and Dynamic libs (Shared Objects) 2 | *.o 3 | *.a 4 | *.so 5 | 6 | # Folders 7 | _obj 8 | _test 9 | 10 | # Architecture specific extensions/prefixes 11 | *.[568vq] 12 | [568vq].out 13 | 14 | *.cgo1.go 15 | *.cgo2.c 16 | _cgo_defun.c 17 | _cgo_gotypes.go 18 | _cgo_export.* 19 | 20 | _testmain.go 21 | 22 | *.exe 23 | *.test 24 | *.prof 25 | 26 | *.swp 27 | 28 | *.out 29 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright (c) 2015 Wang Jing wangjild@gmail.com 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in 13 | all copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 21 | THE SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | go-mysql-proxy 2 | ============== 3 | -------------------------------------------------------------------------------- /backend/conn_test.go: -------------------------------------------------------------------------------- 1 | package client 2 | 3 | import ( 4 | "fmt" 5 | . "github.com/wangjild/go-mysql-proxy/mysql" 6 | "testing" 7 | ) 8 | 9 | func newTestConn() *Conn { 10 | c := new(Conn) 11 | 12 | if err := c.Connect("127.0.0.1:4306", "root", "", "go_proxy"); err != nil { 13 | panic(err) 14 | } 15 | 16 | return c 17 | } 18 | 19 | func TestConn_Connect(t *testing.T) { 20 | c := newTestConn() 21 | defer c.Close() 22 | } 23 | 24 | func TestConn_Ping(t *testing.T) { 25 | c := newTestConn() 26 | defer c.Close() 27 | 28 | if err := c.Ping(); err != nil { 29 | t.Fatal(err) 30 | } 31 | } 32 | 33 | func TestConn_DeleteTable(t *testing.T) { 34 | c := newTestConn() 35 | defer c.Close() 36 | 37 | if _, err := c.Execute("drop table if exists go_proxy_test_conn"); err != nil { 38 | t.Fatal(err) 39 | } 40 | } 41 | 42 | func TestConn_CreateTable(t *testing.T) { 43 | s := `CREATE TABLE IF NOT EXISTS go_proxy_test_conn ( 44 | id BIGINT(64) UNSIGNED NOT NULL, 45 | str VARCHAR(256), 46 | f DOUBLE, 47 | e enum("test1", "test2"), 48 | u tinyint unsigned, 49 | i tinyint, 50 | PRIMARY KEY (id) 51 | ) ENGINE=InnoDB DEFAULT CHARSET=utf8` 52 | 53 | c := newTestConn() 54 | defer c.Close() 55 | 56 | if _, err := c.Execute(s); err != nil { 57 | t.Fatal(err) 58 | } 59 | } 60 | 61 | func TestConn_Insert(t *testing.T) { 62 | s := `insert into go_proxy_test_conn (id, str, f, e) values(1, "a", 3.14, "test1")` 63 | 64 | c := newTestConn() 65 | defer c.Close() 66 | 67 | if pkg, err := c.Execute(s); err != nil { 68 | t.Fatal(err) 69 | } else { 70 | if pkg.AffectedRows != 1 { 71 | t.Fatal(pkg.AffectedRows) 72 | } 73 | } 
74 | } 75 | 76 | func TestConn_Select(t *testing.T) { 77 | s := `select str, f, e from go_proxy_test_conn where id = 1` 78 | 79 | c := newTestConn() 80 | defer c.Close() 81 | 82 | if result, err := c.Execute(s); err != nil { 83 | t.Fatal(err) 84 | } else { 85 | if len(result.Fields) != 3 { 86 | t.Fatal(len(result.Fields)) 87 | } 88 | 89 | if len(result.Values) != 1 { 90 | t.Fatal(len(result.Values)) 91 | } 92 | 93 | if str, _ := result.GetString(0, 0); str != "a" { 94 | t.Fatal("invalid str", str) 95 | } 96 | 97 | if f, _ := result.GetFloat(0, 1); f != float64(3.14) { 98 | t.Fatal("invalid f", f) 99 | } 100 | 101 | if e, _ := result.GetString(0, 2); e != "test1" { 102 | t.Fatal("invalid e", e) 103 | } 104 | 105 | if str, _ := result.GetStringByName(0, "str"); str != "a" { 106 | t.Fatal("invalid str", str) 107 | } 108 | 109 | if f, _ := result.GetFloatByName(0, "f"); f != float64(3.14) { 110 | t.Fatal("invalid f", f) 111 | } 112 | 113 | if e, _ := result.GetStringByName(0, "e"); e != "test1" { 114 | t.Fatal("invalid e", e) 115 | } 116 | 117 | } 118 | } 119 | 120 | func TestConn_Escape(t *testing.T) { 121 | c := newTestConn() 122 | defer c.Close() 123 | 124 | e := `""''\abc` 125 | s := fmt.Sprintf(`insert into go_proxy_test_conn (id, str) values(5, "%s")`, 126 | Escape(e)) 127 | 128 | if _, err := c.Execute(s); err != nil { 129 | t.Fatal(err) 130 | } 131 | 132 | s = `select str from go_proxy_test_conn where id = ?` 133 | 134 | if r, err := c.Execute(s, 5); err != nil { 135 | t.Fatal(err) 136 | } else { 137 | str, _ := r.GetString(0, 0) 138 | if str != e { 139 | t.Fatal(str) 140 | } 141 | } 142 | } 143 | 144 | func TestConn_SetCharset(t *testing.T) { 145 | c := newTestConn() 146 | defer c.Close() 147 | 148 | if err := c.SetCharset("gb2312"); err != nil { 149 | t.Fatal(err) 150 | } 151 | } 152 | -------------------------------------------------------------------------------- /backend/db.go: 
-------------------------------------------------------------------------------- 1 | package client 2 | 3 | import ( 4 | "container/list" 5 | "fmt" 6 | "math/rand" 7 | . "github.com/wangjild/go-mysql-proxy/mysql" 8 | "runtime" 9 | "sync" 10 | "sync/atomic" 11 | ) 12 | 13 | type DB struct { 14 | sync.Mutex 15 | 16 | addr string 17 | user string 18 | password string 19 | db string 20 | maxIdleConns int 21 | 22 | idleConns []*list.List 23 | 24 | connNum int32 25 | barrel int 26 | } 27 | 28 | func Open(addr string, user string, password string, dbName string) (*DB, error) { 29 | db := new(DB) 30 | 31 | db.addr = addr 32 | db.user = user 33 | db.password = password 34 | db.db = dbName 35 | 36 | db.barrel = runtime.NumCPU() 37 | db.idleConns = make([]*list.List, db.barrel, db.barrel) 38 | for i := range db.idleConns { 39 | db.idleConns[i] = list.New() 40 | } 41 | 42 | db.connNum = 0 43 | return db, nil 44 | } 45 | 46 | func (db *DB) Addr() string { 47 | return db.addr 48 | } 49 | 50 | func (db *DB) String() string { 51 | return fmt.Sprintf("%s:%s@%s/%s?maxIdleConns=%v", 52 | db.user, db.password, db.addr, db.db, db.maxIdleConns) 53 | } 54 | 55 | func (db *DB) Close() error { 56 | db.Lock() 57 | defer db.Unlock() 58 | 59 | for i := range db.idleConns { 60 | if db.idleConns[i].Len() > 0 { 61 | v := db.idleConns[i].Back() 62 | co := v.Value.(*Conn) 63 | db.idleConns[i].Remove(v) 64 | co.Close() 65 | } else { 66 | break 67 | } 68 | } 69 | 70 | db.connNum = 0 71 | return nil 72 | } 73 | 74 | func (db *DB) Ping() error { 75 | c, err := db.PopConn() 76 | if err != nil { 77 | return err 78 | } 79 | 80 | err = c.Ping() 81 | db.PushConn(c, err) 82 | return err 83 | } 84 | 85 | func (db *DB) SetMaxIdleConnNum(num int) { 86 | db.maxIdleConns = num 87 | } 88 | 89 | func (db *DB) GetConnNum() int { 90 | return int(db.connNum) 91 | } 92 | 93 | func (db *DB) newConn() (*Conn, error) { 94 | co := new(Conn) 95 | 96 | if err := co.Connect(db.addr, db.user, db.password, db.db); err != nil { 
97 | return nil, err 98 | } 99 | 100 | return co, nil 101 | } 102 | 103 | func (db *DB) tryReuse(co *Conn) error { 104 | if co.IsInTransaction() { 105 | //we can not reuse a connection in transaction status 106 | if err := co.Rollback(); err != nil { 107 | return err 108 | } 109 | } 110 | 111 | if !co.IsAutoCommit() { 112 | //we can not reuse a connection not in autocomit 113 | if _, err := co.exec("set autocommit = 1"); err != nil { 114 | return err 115 | } 116 | } 117 | 118 | //connection may be set names early 119 | //we must use default utf8 120 | if co.GetCharset() != DEFAULT_CHARSET { 121 | if err := co.SetCharset(DEFAULT_CHARSET); err != nil { 122 | return err 123 | } 124 | } 125 | 126 | return nil 127 | } 128 | 129 | func (db *DB) PopConn() (co *Conn, err error) { 130 | idx := rand.Intn(db.barrel) 131 | 132 | db.Lock() 133 | if db.idleConns[idx].Len() > 0 { 134 | v := db.idleConns[idx].Front() 135 | co = v.Value.(*Conn) 136 | db.idleConns[idx].Remove(v) 137 | } 138 | db.Unlock() 139 | 140 | if co != nil { 141 | if err := co.Ping(); err == nil { 142 | if err := db.tryReuse(co); err == nil { 143 | //connection may alive 144 | return co, nil 145 | } 146 | } 147 | co.Close() 148 | } 149 | 150 | co, err = db.newConn() 151 | if err == nil { 152 | atomic.AddInt32(&db.connNum, 1) 153 | } 154 | return 155 | } 156 | 157 | func (db *DB) PushConn(co *Conn, err error) { 158 | var closeConn *Conn = nil 159 | 160 | if err != nil { 161 | closeConn = co 162 | } else { 163 | if db.maxIdleConns > 0 { 164 | idx := rand.Intn(db.barrel) 165 | db.Lock() 166 | if db.idleConns[idx].Len() >= db.maxIdleConns { 167 | v := db.idleConns[idx].Front() 168 | closeConn = v.Value.(*Conn) 169 | db.idleConns[idx].Remove(v) 170 | } 171 | db.idleConns[idx].PushBack(co) 172 | db.Unlock() 173 | 174 | } else { 175 | closeConn = co 176 | } 177 | 178 | } 179 | 180 | if closeConn != nil { 181 | atomic.AddInt32(&db.connNum, -1) 182 | 183 | closeConn.Close() 184 | } 185 | } 186 | 187 | type SqlConn 
struct { 188 | *Conn 189 | 190 | db *DB 191 | } 192 | 193 | func (p *SqlConn) Close() { 194 | if p.Conn != nil { 195 | p.db.PushConn(p.Conn, p.Conn.pkgErr) 196 | p.Conn = nil 197 | } 198 | } 199 | 200 | func (db *DB) GetConn() (*SqlConn, error) { 201 | c, err := db.PopConn() 202 | return &SqlConn{c, db}, err 203 | } 204 | -------------------------------------------------------------------------------- /backend/stmt.go: -------------------------------------------------------------------------------- 1 | package client 2 | 3 | import ( 4 | "encoding/binary" 5 | . "github.com/wangjild/go-mysql-proxy/mysql" 6 | ) 7 | 8 | type Stmt struct { 9 | conn *Conn 10 | id uint32 11 | query string 12 | 13 | params int 14 | ParamDefs [][]byte 15 | 16 | columns int 17 | ColDefs [][]byte 18 | 19 | flag byte 20 | } 21 | 22 | func (s *Stmt) ID() uint32 { 23 | return s.id 24 | } 25 | 26 | func (s *Stmt) ParamNum() int { 27 | return s.params 28 | } 29 | 30 | func (s *Stmt) ColumnNum() int { 31 | return s.columns 32 | } 33 | 34 | func (s *Stmt) Execute(data []byte) (*Result, error) { 35 | if err := s.write(data); err != nil { 36 | return nil, err 37 | } 38 | 39 | return s.conn.readResult(true) 40 | } 41 | 42 | func (s *Stmt) Close(closeConn bool) error { 43 | if err := s.conn.writeCommandUint32(COM_STMT_CLOSE, s.id); err != nil { 44 | return err 45 | } 46 | 47 | if closeConn { 48 | s.conn.Close() 49 | s.conn = nil 50 | } 51 | 52 | return nil 53 | } 54 | 55 | func (s *Stmt) write(param []byte) error { 56 | 57 | data := make([]byte, 4, 4+9+len(param)) 58 | 59 | data = append(data, COM_STMT_EXECUTE) 60 | 61 | data = append(data, byte(s.id), byte(s.id>>8), byte(s.id>>16), byte(s.id>>24)) 62 | 63 | //flag: CURSOR_TYPE_NO_CURSOR 64 | data = append(data, s.flag) 65 | 66 | data = append(data, 1, 0, 0, 0) 67 | 68 | data = append(data, param...) 
69 | 70 | s.conn.pkg.Sequence = 0 71 | return s.conn.writePacket(data) 72 | } 73 | 74 | func (s *Stmt) SendLongData(pid uint16, payload []byte) error { 75 | 76 | data := make([]byte, 4, 4+7+len(payload)) 77 | data = append(data, COM_STMT_SEND_LONG_DATA) 78 | data = append(data, byte(s.id), byte(s.id>>8), byte(s.id>>16), byte(s.id>>24)) 79 | data = append(data, byte(pid), byte(pid>>8)) 80 | 81 | data = append(data, payload...) 82 | 83 | s.conn.pkg.Sequence = 0 84 | return s.conn.writePacket(data) 85 | } 86 | 87 | func (s *Stmt) Reset() (*Result, error) { 88 | if err := s.conn.writeCommandUint32(COM_STMT_RESET, s.id); err != nil { 89 | return nil, err 90 | } 91 | 92 | s.flag = CURSOR_TYPE_NO_CURSOR 93 | return s.conn.readOK() 94 | } 95 | 96 | func (c *Conn) Prepare(query string) (*Stmt, error) { 97 | if err := c.writeCommandStr(COM_STMT_PREPARE, query); err != nil { 98 | return nil, err 99 | } 100 | 101 | data, err := c.readPacket() 102 | if err != nil { 103 | return nil, err 104 | } 105 | 106 | if data[0] == ERR_HEADER { 107 | return nil, c.handleErrorPacket(data) 108 | } else if data[0] != OK_HEADER { 109 | return nil, ErrMalformPacket 110 | } 111 | 112 | s := new(Stmt) 113 | s.conn = c 114 | 115 | pos := 1 116 | 117 | //for statement id 118 | s.id = binary.LittleEndian.Uint32(data[pos:]) 119 | pos += 4 120 | 121 | //number columns 122 | s.columns = int(binary.LittleEndian.Uint16(data[pos:])) 123 | pos += 2 124 | 125 | //number params 126 | s.params = int(binary.LittleEndian.Uint16(data[pos:])) 127 | pos += 2 128 | 129 | //warnings 130 | //warnings = binary.LittleEndian.Uint16(data[pos:]) 131 | 132 | if s.params > 0 { 133 | if ps, err := s.conn.readUntilEOF(s.params); err != nil { 134 | return nil, err 135 | } else { 136 | s.ParamDefs = ps 137 | } 138 | } 139 | 140 | if s.columns > 0 { 141 | if cs, err := s.conn.readUntilEOF(s.columns); err != nil { 142 | return nil, err 143 | } else { 144 | s.ColDefs = cs 145 | } 146 | } 147 | 148 | s.query = query 149 | return s, 
nil 150 | } 151 | 152 | func (s *Stmt) SetAttr(f byte) { 153 | s.flag = f 154 | } 155 | -------------------------------------------------------------------------------- /backend/stmt_test.go: -------------------------------------------------------------------------------- 1 | package client 2 | 3 | import ( 4 | "testing" 5 | ) 6 | 7 | func TestStmt_DropTable(t *testing.T) { 8 | str := `drop table if exists go_proxy_test_stmt` 9 | 10 | c := newTestConn() 11 | 12 | s, err := c.Prepare(str) 13 | if err != nil { 14 | t.Fatal(err) 15 | } 16 | 17 | if _, err := s.Execute(); err != nil { 18 | t.Fatal(err) 19 | } 20 | 21 | s.Close() 22 | } 23 | 24 | func TestStmt_CreateTable(t *testing.T) { 25 | str := `CREATE TABLE IF NOT EXISTS go_proxy_test_stmt ( 26 | id BIGINT(64) UNSIGNED NOT NULL, 27 | str VARCHAR(256), 28 | f DOUBLE, 29 | e enum("test1", "test2"), 30 | u tinyint unsigned, 31 | i tinyint, 32 | PRIMARY KEY (id) 33 | ) ENGINE=InnoDB DEFAULT CHARSET=utf8` 34 | 35 | c := newTestConn() 36 | defer c.Close() 37 | 38 | s, err := c.Prepare(str) 39 | 40 | if err != nil { 41 | t.Fatal(err) 42 | } 43 | 44 | if _, err = s.Execute(); err != nil { 45 | t.Fatal(err) 46 | } 47 | 48 | s.Close() 49 | } 50 | 51 | func TestStmt_Delete(t *testing.T) { 52 | str := `delete from go_proxy_test_stmt` 53 | 54 | c := newTestConn() 55 | defer c.Close() 56 | 57 | s, err := c.Prepare(str) 58 | 59 | if err != nil { 60 | t.Fatal(err) 61 | } 62 | 63 | if _, err := s.Execute(); err != nil { 64 | t.Fatal(err) 65 | } 66 | 67 | s.Close() 68 | } 69 | 70 | func TestStmt_Insert(t *testing.T) { 71 | str := `insert into go_proxy_test_stmt (id, str, f, e, u, i) values (?, ?, ?, ?, ?, ?)` 72 | 73 | c := newTestConn() 74 | defer c.Close() 75 | 76 | s, err := c.Prepare(str) 77 | 78 | if err != nil { 79 | t.Fatal(err) 80 | } 81 | 82 | if pkg, err := s.Execute(1, "a", 3.14, "test1", 255, -127); err != nil { 83 | t.Fatal(err) 84 | } else { 85 | if pkg.AffectedRows != 1 { 86 | t.Fatal(pkg.AffectedRows) 87 | } 88 | } 
89 | 90 | s.Close() 91 | } 92 | 93 | func TestStmt_Select(t *testing.T) { 94 | str := `select str, f, e from go_proxy_test_stmt where id = ?` 95 | 96 | c := newTestConn() 97 | defer c.Close() 98 | 99 | s, err := c.Prepare(str) 100 | if err != nil { 101 | t.Fatal(err) 102 | } 103 | 104 | if result, err := s.Execute(1); err != nil { 105 | t.Fatal(err) 106 | } else { 107 | if len(result.Values) != 1 { 108 | t.Fatal(len(result.Values)) 109 | } 110 | 111 | if len(result.Fields) != 3 { 112 | t.Fatal(len(result.Fields)) 113 | } 114 | 115 | if str, _ := result.GetString(0, 0); str != "a" { 116 | t.Fatal("invalid str", str) 117 | } 118 | 119 | if f, _ := result.GetFloat(0, 1); f != float64(3.14) { 120 | t.Fatal("invalid f", f) 121 | } 122 | 123 | if e, _ := result.GetString(0, 2); e != "test1" { 124 | t.Fatal("invalid e", e) 125 | } 126 | 127 | if str, _ := result.GetStringByName(0, "str"); str != "a" { 128 | t.Fatal("invalid str", str) 129 | } 130 | 131 | if f, _ := result.GetFloatByName(0, "f"); f != float64(3.14) { 132 | t.Fatal("invalid f", f) 133 | } 134 | 135 | if e, _ := result.GetStringByName(0, "e"); e != "test1" { 136 | t.Fatal("invalid e", e) 137 | } 138 | 139 | } 140 | 141 | s.Close() 142 | } 143 | 144 | func TestStmt_NULL(t *testing.T) { 145 | str := `insert into go_proxy_test_stmt (id, str, f, e) values (?, ?, ?, ?)` 146 | 147 | c := newTestConn() 148 | defer c.Close() 149 | 150 | s, err := c.Prepare(str) 151 | 152 | if err != nil { 153 | t.Fatal(err) 154 | } 155 | 156 | if pkg, err := s.Execute(2, nil, 3.14, nil); err != nil { 157 | t.Fatal(err) 158 | } else { 159 | if pkg.AffectedRows != 1 { 160 | t.Fatal(pkg.AffectedRows) 161 | } 162 | } 163 | 164 | s.Close() 165 | 166 | str = `select * from go_proxy_test_stmt where id = ?` 167 | s, err = c.Prepare(str) 168 | 169 | if err != nil { 170 | t.Fatal(err) 171 | } 172 | 173 | if r, err := s.Execute(2); err != nil { 174 | t.Fatal(err) 175 | } else { 176 | if b, err := r.IsNullByName(0, "id"); err != nil { 177 | 
t.Fatal(err) 178 | } else if b == true { 179 | t.Fatal(b) 180 | } 181 | 182 | if b, err := r.IsNullByName(0, "str"); err != nil { 183 | t.Fatal(err) 184 | } else if b == false { 185 | t.Fatal(b) 186 | } 187 | 188 | if b, err := r.IsNullByName(0, "f"); err != nil { 189 | t.Fatal(err) 190 | } else if b == true { 191 | t.Fatal(b) 192 | } 193 | 194 | if b, err := r.IsNullByName(0, "e"); err != nil { 195 | t.Fatal(err) 196 | } else if b == false { 197 | t.Fatal(b) 198 | } 199 | } 200 | 201 | s.Close() 202 | } 203 | 204 | func TestStmt_Unsigned(t *testing.T) { 205 | str := `insert into go_proxy_test_stmt (id, u) values (?, ?)` 206 | 207 | c := newTestConn() 208 | defer c.Close() 209 | 210 | s, err := c.Prepare(str) 211 | 212 | if err != nil { 213 | t.Fatal(err) 214 | } 215 | 216 | if pkg, err := s.Execute(3, uint8(255)); err != nil { 217 | t.Fatal(err) 218 | } else { 219 | if pkg.AffectedRows != 1 { 220 | t.Fatal(pkg.AffectedRows) 221 | } 222 | } 223 | 224 | s.Close() 225 | 226 | str = `select u from go_proxy_test_stmt where id = ?` 227 | 228 | s, err = c.Prepare(str) 229 | if err != nil { 230 | t.Fatal(err) 231 | } 232 | 233 | if r, err := s.Execute(3); err != nil { 234 | t.Fatal(err) 235 | } else { 236 | if u, err := r.GetUint(0, 0); err != nil { 237 | t.Fatal(err) 238 | } else if u != uint64(255) { 239 | t.Fatal(u) 240 | } 241 | } 242 | 243 | s.Close() 244 | } 245 | 246 | func TestStmt_Signed(t *testing.T) { 247 | str := `insert into go_proxy_test_stmt (id, i) values (?, ?)` 248 | 249 | c := newTestConn() 250 | defer c.Close() 251 | 252 | s, err := c.Prepare(str) 253 | 254 | if err != nil { 255 | t.Fatal(err) 256 | } 257 | 258 | if _, err := s.Execute(4, 127); err != nil { 259 | t.Fatal(err) 260 | } 261 | 262 | if _, err := s.Execute(uint64(18446744073709551516), int8(-128)); err != nil { 263 | t.Fatal(err) 264 | } 265 | 266 | s.Close() 267 | 268 | } 269 | 270 | func TestStmt_Trans(t *testing.T) { 271 | c := newTestConn() 272 | defer c.Close() 273 | 274 | if _, err := 
c.Execute(`insert into go_proxy_test_stmt (id, str) values (1002, "abc")`); err != nil { 275 | t.Fatal(err) 276 | } 277 | 278 | if err := c.Begin(); err != nil { 279 | t.Fatal(err) 280 | } 281 | 282 | str := `select str from go_proxy_test_stmt where id = ?` 283 | 284 | s, err := c.Prepare(str) 285 | if err != nil { 286 | t.Fatal(err) 287 | } 288 | 289 | if _, err := s.Execute(1002); err != nil { 290 | t.Fatal(err) 291 | } 292 | 293 | if err := c.Commit(); err != nil { 294 | t.Fatal(err) 295 | } 296 | 297 | if r, err := s.Execute(1002); err != nil { 298 | t.Fatal(err) 299 | } else { 300 | if str, _ := r.GetString(0, 0); str != `abc` { 301 | t.Fatal(str) 302 | } 303 | } 304 | 305 | if err := s.Close(); err != nil { 306 | t.Fatal(err) 307 | } 308 | } 309 | -------------------------------------------------------------------------------- /bootstrap.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | if [ ! -f bootstrap.sh ]; then 4 | echo "bootstrap.sh must be run from its current directory" 1>&2 5 | exit 1 6 | fi 7 | 8 | source ./dev.env 9 | 10 | go get gopkg.in/yaml.v2 11 | -------------------------------------------------------------------------------- /cmd/proxy/.gitignore: -------------------------------------------------------------------------------- 1 | proxy 2 | -------------------------------------------------------------------------------- /cmd/proxy/proxy.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "flag" 5 | "github.com/wangjild/go-mysql-proxy/config" 6 | . 
"github.com/wangjild/go-mysql-proxy/log" 7 | "github.com/wangjild/go-mysql-proxy/proxy" 8 | "net/http" 9 | _ "net/http/pprof" 10 | "os" 11 | "os/signal" 12 | "runtime" 13 | "syscall" 14 | ) 15 | 16 | var configFile *string = flag.String("config", "etc/proxy.yaml", "go mysql proxy config file") 17 | var logLevel *int = flag.Int("loglevel", 0, "0-debug| 1-notice|2-warn|3-fatal") 18 | var logFile *string = flag.String("logfile", "log/proxy.log", "go mysql proxy logfile") 19 | 20 | func main() { 21 | runtime.GOMAXPROCS(runtime.NumCPU()) 22 | 23 | flag.Parse() 24 | 25 | if len(*configFile) == 0 { 26 | SysLog.Fatal("must use a config file") 27 | return 28 | } 29 | 30 | cfg, err := config.ParseConfigFile(*configFile) 31 | if err != nil { 32 | SysLog.Fatal(err.Error()) 33 | return 34 | } 35 | 36 | //Init(&Config{FilePath: *logFile, LogLevel: *logLevel}, 37 | // &Config{FilePath: *logFile, LogLevel: *logLevel}) 38 | 39 | sc := make(chan os.Signal, 1) 40 | signal.Notify(sc, 41 | syscall.SIGHUP, 42 | syscall.SIGINT, 43 | syscall.SIGTERM, 44 | syscall.SIGQUIT) 45 | 46 | var svr *proxy.Server 47 | svr, err = proxy.NewServer(cfg) 48 | if err != nil { 49 | SysLog.Fatal(err.Error()) 50 | return 51 | } 52 | 53 | go func() { 54 | http.ListenAndServe(":11888", nil) 55 | }() 56 | 57 | go func() { 58 | sig := <-sc 59 | SysLog.Notice("Got signal [%d] to exit.", sig) 60 | svr.Close() 61 | }() 62 | 63 | svr.Run() 64 | } 65 | -------------------------------------------------------------------------------- /config/config.go: -------------------------------------------------------------------------------- 1 | package config 2 | 3 | import ( 4 | "gopkg.in/yaml.v2" 5 | "io/ioutil" 6 | ) 7 | 8 | type NodeConfig struct { 9 | Name string `yaml:"name"` 10 | DownAfterNoAlive int `yaml:"down_after_noalive"` 11 | IdleConns int `yaml:"idle_conns"` 12 | RWSplit bool `yaml:"rw_split"` 13 | 14 | User string `yaml:"user"` 15 | Password string `yaml:"password"` 16 | 17 | Master string `yaml:"master"` 18 | 
Slave string `yaml:"slave"` 19 | } 20 | 21 | type SchemaConfig struct { 22 | DB string `yaml:"db"` 23 | Node string `yaml:"node"` 24 | Auths []Auth `yaml:"auths"` 25 | } 26 | 27 | type Auth struct { 28 | User string `yaml:"user"` 29 | Passwd string `yaml:"passwd"` 30 | } 31 | 32 | type Config struct { 33 | Addr string `yaml:"addr"` 34 | User string `yaml:"user"` 35 | Password string `yaml:"password"` 36 | LogLevel string `yaml:"log_level"` 37 | 38 | Nodes []NodeConfig `yaml:"nodes"` 39 | 40 | Schemas []SchemaConfig `yaml:"schemas"` 41 | } 42 | 43 | func ParseConfigData(data []byte) (*Config, error) { 44 | var cfg Config 45 | if err := yaml.Unmarshal([]byte(data), &cfg); err != nil { 46 | return nil, err 47 | } 48 | return &cfg, nil 49 | } 50 | 51 | func ParseConfigFile(fileName string) (*Config, error) { 52 | data, err := ioutil.ReadFile(fileName) 53 | if err != nil { 54 | return nil, err 55 | } 56 | 57 | return ParseConfigData(data) 58 | } 59 | -------------------------------------------------------------------------------- /config/config_test.go: -------------------------------------------------------------------------------- 1 | package config 2 | 3 | import ( 4 | "fmt" 5 | "reflect" 6 | "testing" 7 | ) 8 | 9 | func TestConfig(t *testing.T) { 10 | var testConfigData = []byte( 11 | ` 12 | addr : 127.0.0.1:4000 13 | user : root 14 | password : 15 | log_level : error 16 | 17 | nodes : 18 | - 19 | name : node1 20 | down_after_noalive : 300 21 | idle_conns : 16 22 | rw_split: true 23 | user: root 24 | password: 25 | master : 127.0.0.1:3306 26 | slave : 127.0.0.1:4306 27 | - 28 | name : node2 29 | user: root 30 | master : 127.0.0.1:3307 31 | 32 | - 33 | name : node3 34 | down_after_noalive : 300 35 | idle_conns : 16 36 | rw_split: false 37 | user: root 38 | password: 39 | master : 127.0.0.1:3308 40 | 41 | schemas : 42 | - 43 | db : go_proxy 44 | node: node1 45 | auths : 46 | - 47 | user: xm_test 48 | passwd: xiaomi 49 | - 50 | user: xm_test1 51 | passwd: xiaomi 52 | 
53 | `) 54 | 55 | cfg, err := ParseConfigData(testConfigData) 56 | if err != nil { 57 | t.Fatal(err) 58 | } 59 | 60 | if len(cfg.Nodes) != 3 { 61 | t.Fatal(len(cfg.Nodes)) 62 | } 63 | 64 | if len(cfg.Schemas) != 1 { 65 | t.Fatal(len(cfg.Schemas)) 66 | } 67 | 68 | testNode := NodeConfig{ 69 | Name: "node1", 70 | DownAfterNoAlive: 300, 71 | IdleConns: 16, 72 | RWSplit: true, 73 | 74 | User: "root", 75 | Password: "", 76 | 77 | Master: "127.0.0.1:3306", 78 | Slave: "127.0.0.1:4306", 79 | } 80 | 81 | if !reflect.DeepEqual(cfg.Nodes[0], testNode) { 82 | fmt.Printf("%v\n", cfg.Nodes[0]) 83 | t.Fatal("node1 must equal") 84 | } 85 | 86 | testNode_2 := NodeConfig{ 87 | Name: "node2", 88 | User: "root", 89 | Master: "127.0.0.1:3307", 90 | } 91 | 92 | if !reflect.DeepEqual(cfg.Nodes[1], testNode_2) { 93 | t.Fatal("node2 must equal") 94 | } 95 | 96 | testSchema := SchemaConfig{ 97 | DB: "go_proxy", 98 | Node: "node1", 99 | Auths: []Auth{Auth{User: "xm_test", Passwd: "xiaomi"}, Auth{User: "xm_test1", Passwd: "xiaomi"}}, 100 | } 101 | 102 | fmt.Println(cfg.Schemas[0]) 103 | if !reflect.DeepEqual(cfg.Schemas[0], testSchema) { 104 | t.Fatal("schema must equal") 105 | } 106 | 107 | if cfg.LogLevel != "error" || cfg.User != "root" || cfg.Password != "" || cfg.Addr != "127.0.0.1:4000" { 108 | t.Fatal("Top Config not equal.") 109 | } 110 | } 111 | -------------------------------------------------------------------------------- /docs/internal/Design.md: -------------------------------------------------------------------------------- 1 | ## 1. 
背景概述 2 | 3 | MySQL服务是一种基于多线程模型的服务,每个客户端连接需要一个对应的线程处理请求。 4 | 5 | 在互联网的应用场景下,MySQL调用方一般为php-fpm或者uwsgi等采用多进程模型的服务。这种场景下,不论是短连接还是长连接,均会对MySQL服务产生不良影响,影响服务稳定性。 6 | 7 | 再者,由于上述基于多进程模型的服务,在较高的并发和流量情况下,都会将实例部署到多台物理服务器上面,因此直接暴露MySQL的配置信息,会导致多个物理机器上面的配置需要同时变更,存在运维上面的一些不便。 8 | 9 | 又有,MySQL本身没有流量控制的功能,在服务因各种内外因素而请求突增的情况下,很容易导致MySQL服务器被压垮,从而引起雪崩效应,影响其他服务。 10 | 11 | 因此,在Web服务和最终的MySQL服务之间,我们需要有一层抽象中间层,来处理上述技术问题。抽象中间层、应用服务、MySQL服务三者的层次结构如下: 12 | 13 | +----------------------+ +----------------------+ 14 | | | | | 15 | | PHP-FPM/uwsgi | | PHP-FPM/uwsgi | 16 | | | | | 17 | +----------+-----------+ +-----------+----------+ 18 | | | 19 | +----------------+ +---------------+ 20 | | | 21 | | | 22 | +---+----------+---+ 23 | | | 24 | | Abstract Layer | 25 | | | 26 | +--------+---------+ 27 | | 28 | | 29 | +--------+---------+ 30 | | | 31 | | MySQL DB | 32 | | | 33 | +------------------+ 34 | 35 | 36 | ## 2. 名词解释 37 | SQL指纹: 38 | 39 | 针对参数做规约化之后的SQL语句; 40 | 形如 select id, col1, col2 from table where cond1 = ? and cond2 = ?; 41 | 通常由同一组业务服务发出; 42 | 43 | ## 3. 设计目标 44 | 45 | ### 3.1 实现的功能 46 | 47 | 1. MySQL主要协议兼容 48 | 2. MySQL服务 High Availability & Failover 49 | 3. LB(Load Balance) & 资源分配 & 流量控制 50 | 4. 内部信息统计 51 | 5. 运维管理支持 52 | 53 | ### 3.2 需求详细描述 54 | 55 | #### 3.2.1 MySQL主要协议兼容 56 | 57 | * 完全支持 [MySQL Internals](https://dev.mysql.com/doc/internals/en/client-server-protocol.html) 中Chapter 14 中 1-4,6-7中的用户认证,包格式,Text Protocol和Prepared Statement协议 58 | * 完全支持 [SQL Syntax](http://dev.mysql.com/doc/refman/5.6/en/sql-syntax.html) 规范中所有SQL语句,支持大部分CRUD操作的SQL指纹识别和聚类 59 | * 支持各语言常用的Client Library,保证客户端库的兼容性 60 | 61 | #### 3.2.2 MySQL HA & Failover 62 | * 主库宕机不影响读 63 | * 多从库其一宕机不影响服务 64 | * 主从库平滑切换,平滑上下线 65 | * 数据库故障自动检测 66 | 67 | #### 3.2.3 Load Balance & 资源分配 & 流量控制 68 | * 读写分离 69 | * 注释式主读 70 | * IP白名单/黑名单做资源分配 71 | * 基于SQL指纹的动态流量控制 72 | 73 | #### 3.2.4 内部信息统计 74 | * 库表级别的QPS统计 75 | * SQL指纹分类统计 76 | * 慢查询统计 77 | 78 | #### 3.2.5 运维管理支持 79 | * 管理端口,直接支持SQL语言进行管理 80 | 81 | 82 | ## 4. 
系统方案 83 | 84 | ### 4.1 相关调研 85 | 86 | #### 4.1.1 相关开源项目 87 | 88 | 1. [MySQL Proxy](https://dev.mysql.com/doc/mysql-proxy/en/) 89 | 90 | Oracle公司官方出品。目前已经停止维护。没有GA版本,因此不推荐在生产环境中使用。 91 | 92 | 2. [Qihoo360 Atlas](https://github.com/Qihoo360/Atlas) 93 | 94 | 基于 MySQL Proxy 0.8.2 版本进行二次开发的Proxy。具备读写分离,连接池,负载均衡,IP白名单等功能 95 | 96 | 3. [SOHU-DBProxy](https://github.com/SOHUDBA/SOHU-DBProxy) 97 | 98 | 基于 MySQL Proxy 0.8.3 版本进行二次开发的Proxy。除了具备Atlas同等的基本功能外,还添加了基于配置文件的SQL语句的过滤和审核功能 99 | 100 | 4. [Youtube Vitess](http://vitess.io/) 101 | 102 | Youtube开源的一个MySQL机器集群方案。不兼容MySQL Client-Server Protocol,必须使用官方提供的Client Library 103 | 104 | 5. [MyCAT](https://github.com/MyCATApache/Mycat-Server) 105 | 106 | 开源社区维护的,支持分库分表的MySQL中间层代理。但是由于本身是分布式的,有一定事务及跨库查询上面的限制 107 | 108 | #### 4.1.2 开源可选方案的功能完备性调研 109 | 110 | 111 | | 需求 | Atlas | SOHU-DBProxy | Vitess | MyCAT | 112 | | ------------- |:--:| :-----:|:----:|:---:| 113 | | 协议兼容性 | ✔︎ | ✔︎ | ✖︎ | ✔︎ | 114 | | Load Balance | ✔︎ | ✔︎ | ✔︎ | ✔︎ | 115 | | High Availability | ✔︎ | ✔︎ | ✔︎ | ✔︎ | 116 | | Failover | ✔︎ | ✔︎ | ✔︎ | ✔︎ | 117 | | 资源分配 | O | O | ✔︎ | ✖︎ | 118 | | 流量控制 | ✖︎ | O | ✖︎ | ✖︎ | 119 | | 内部信息统计 | ✖︎ | O | ✖︎ | O | 120 | | 运维管理 | O | O | ✔︎ | O | 121 | 122 | 注: 123 | 124 | ✔︎ : 完全满足需求 125 | ✖︎ : 完全不满足需求 126 | O : 部分满足需求 127 | 128 | ### 4.2 总体思路 129 | 130 | 以架构的角度来看,中间层实现可以分为2大类: 131 | 132 | 1. 客户端 133 | 即实现为语言相关的 Library. 在Library内部去提供相关能力。淘宝的TDDL即为这一类 134 | 2. 服务端 135 | 即以一个独立的服务,提供完备的MySQL Server同等能力的服务 136 | 137 | 由于客户端的实现是语言相关的,因此我们不考虑以客户端的方式来实现。 138 | 139 | ### 4.3 技术选型 140 | 141 | * 语言: Golang 1.5 + 142 | * 平台: 支持Linux, Mac 10.9+ 143 | * 框架: 无 144 | 145 | 技术选型的理由: 研发技术储备能够cover + 主要研发者语言方面的兴趣 146 | 147 | ### 4.4 运维及容灾考虑 148 | 待补充 149 | 150 | ## 5. 
系统设计 151 | 152 | ### 5.1 总体架构 153 | 154 | 总体架构图如[第一节]()所列 155 | 156 | ### 5.2 模块拆分 157 | 158 | 根据分层模型来看,系统从前到后主要分为如下几个模块: 159 | 160 | +--------------------------------+ 161 | | Network Module | 162 | +--------------------------------+ 163 | | Package Module | 164 | +--------------------+-----------+ 165 | | Auth Module | | 166 | +--------------------+ | 167 | | Protocol Layer | 168 | +--------------------------------+ 169 | | Load Balance | 170 | +--------------------------------+ 171 | | Backend Manager Module | 172 | +--------------+-----------------+ 173 | | Failover | Connection Pool | 174 | +--------------+-----------------+ 175 | 176 | 177 | 根据包的处理流程,我们有如下流程图: 178 | 179 | 180 | 181 | ### 5.3 容灾及降级方案 182 | 待补充 183 | 184 | ## 6. 详细设计 185 | ### 6.1 模块与层次划分 186 | 187 | #### 网络模型 188 | 网络IO模型,基于IO事件触发多路IO复用非阻塞。 189 | 多Routine模型,区分管理Routine和工作Routine。 190 | 191 | #### MySQL协议交互模块 192 | 大部分时候PROXY并不解包,只进行数据的透传。 193 | 准入控制时候需要解包以及构造授权包。 194 | 管理进程返回PROXY状态构造MySQL结果包。 195 | 196 | #### 权限与准入模块 197 | PROXY独立的授权用户名和密码。 198 | IP白名单和黑名单。 199 | 前端连接数限制。 200 | 服务授权限制。 201 | 202 | #### 配置管理模块 203 | 配置文件结构设计。 204 | 配置文件的热加载。 205 | 206 | #### 连接处理模块 207 | 后端连接池。 208 | 负载均衡策略。 209 | 主动屏蔽不可用Server。 210 | 连接超时检测处理。 211 | 连接失败重连。 212 | 213 | #### 读写分离模块 214 | SQL解析。 215 | 读写分离策略。 216 | 217 | #### 日志模块 218 | 请求的整个链路记录。 219 | 支持构建日志平台,开放给RD自己查错。 220 | 221 | 222 | ## 7. 
部署及运维设计 223 | ### 7.1 服务架构 224 | #### 7.1.1 服务分布 225 | 包括新上线服务的机器数量,服务的模块以及模块如何分布在服务器上;对其他服务是否有关联或数据交互 226 | 227 | #### 7.1.2 数据流向 228 | 说明线上服务的数据流向,即简要说明线上服务模块是如何工作的 229 | 230 | #### 7.1.3 服务类型 231 | 标明服务各模块的类型是cpu消耗型、io消耗型、大容量磁盘空间消耗型 232 | 233 | #### 7.1.4 资源情况 234 | 包括预计流量、数据量、第三方软件使用情况 235 | 236 | ### 7.2 运维设计 237 | #### 7.2.1 服务冗余性 238 | 服务各模块的冗余考虑,如果某个模块所在服务器出现问题如何实现冗余来保证服务不受影响 239 | 240 | #### 7.2.2 服务可维护性 241 | 包括数据损坏后如何修复,减少单点服务,如果无法避免单点尽量减少单点功能,动态数据正确性的检查 242 | 243 | #### 7.2.3 服务可扩展性 244 | 当服务访问量或数据量达到上限,服务如何继续保持可扩展性 245 | 246 | #### 7.2.4 服务监控 247 | 服务需要的监控方式及dashboard地址等 248 | 249 | ### 7.3 上线及回滚方案 250 | 给出上线方案分析,并分析哪些环节可能导致回滚,给出回滚的风险评估,给出如何避免、处理回滚的具体措施 251 | 252 | ## 8. FAQ -------------------------------------------------------------------------------- /docs/mysql-proxy/scripting.txt: -------------------------------------------------------------------------------- 1 | Hooks 2 | ===== 3 | 4 | connect_server 5 | -------------- 6 | 7 | read_auth 8 | --------- 9 | 10 | read_auth_result 11 | ---------------- 12 | 13 | read_query 14 | ---------- 15 | 16 | read_query_result 17 | ----------------- 18 | 19 | disconnect_client 20 | ----------------- 21 | 22 | Modules 23 | ======= 24 | 25 | mysql.proto 26 | ----------- 27 | 28 | The ``mysql.proto`` module provides encoders and decoders for the packets exchanged between client and server 29 | 30 | 31 | from_err_packet 32 | ............... 33 | 34 | Decodes a ERR-packet into a table. 35 | 36 | Parameters: 37 | 38 | ``packet`` 39 | (string) mysql packet 40 | 41 | 42 | On success it returns a table containing: 43 | 44 | ``errmsg`` 45 | (string) 46 | 47 | ``sqlstate`` 48 | (string) 49 | 50 | ``errcode`` 51 | (int) 52 | 53 | Otherwise it raises an error. 54 | 55 | to_err_packet 56 | ............. 57 | 58 | Encode a table containing a ERR packet into a MySQL packet. 
59 | 60 | Parameters: 61 | 62 | ``err`` 63 | (table) 64 | 65 | ``errmsg`` 66 | (string) 67 | 68 | ``sqlstate`` 69 | (string) 70 | 71 | ``errcode`` 72 | (int) 73 | 74 | into a MySQL packet. 75 | 76 | Returns a string. 77 | 78 | from_ok_packet 79 | .............. 80 | 81 | Decodes a OK-packet 82 | 83 | ``packet`` 84 | (string) mysql packet 85 | 86 | 87 | On success it returns a table containing: 88 | 89 | ``server_status`` 90 | (int) bit-mask of the connection status 91 | 92 | ``insert_id`` 93 | (int) last used insert id 94 | 95 | ``warnings`` 96 | (int) number of warnings for the last executed statement 97 | 98 | ``affected_rows`` 99 | (int) rows affected by the last statement 100 | 101 | Otherwise it raises an error. 102 | 103 | 104 | to_ok_packet 105 | ............ 106 | 107 | Encode a OK packet 108 | 109 | from_eof_packet 110 | ............... 111 | 112 | Decodes a EOF-packet 113 | 114 | Parameters: 115 | 116 | ``packet`` 117 | (string) mysql packet 118 | 119 | 120 | On success it returns a table containing: 121 | 122 | ``server_status`` 123 | (int) bit-mask of the connection status 124 | 125 | ``warnings`` 126 | (int) 127 | 128 | Otherwise it raises an error. 129 | 130 | 131 | to_eof_packet 132 | ............. 133 | 134 | from_challenge_packet 135 | ..................... 
136 | 137 | Decodes a auth-challenge-packet 138 | 139 | Parameters: 140 | 141 | ``packet`` 142 | (string) mysql packet 143 | 144 | On success it returns a table containing: 145 | 146 | ``protocol_version`` 147 | (int) version of the mysql protocol, usually 10 148 | 149 | ``server_version`` 150 | (int) version of the server as integer: 50506 is MySQL 5.5.6 151 | 152 | ``thread_id`` 153 | (int) connection id 154 | 155 | ``capabilities`` 156 | (int) bit-mask of the server capabilities 157 | 158 | ``charset`` 159 | (int) server default character-set 160 | 161 | ``server_status`` 162 | (int) bit-mask of the connection-status 163 | 164 | ``challenge`` 165 | (string) password challenge 166 | 167 | 168 | to_challenge_packet 169 | ................... 170 | 171 | Encode a auth-response-packet 172 | 173 | from_response_packet 174 | .................... 175 | 176 | Decodes a auth-response-packet 177 | 178 | Parameters: 179 | 180 | ``packet`` 181 | (string) mysql packet 182 | 183 | 184 | to_response_packet 185 | .................. 186 | 187 | from_masterinfo_string 188 | ...................... 189 | 190 | Decodes the content of the ``master.info`` file. 191 | 192 | 193 | to_masterinfo_string 194 | .................... 195 | 196 | from_stmt_prepare_packet 197 | ........................ 198 | 199 | Decodes a COM_STMT_PREPARE-packet 200 | 201 | Parameters: 202 | 203 | ``packet`` 204 | (string) mysql packet 205 | 206 | 207 | On success it returns a table containing: 208 | 209 | ``stmt_text`` 210 | (string) 211 | text of the prepared statement 212 | 213 | Otherwise it raises an error. 214 | 215 | from_stmt_prepare_ok_packet 216 | ........................... 
217 | 218 | Decodes a COM_STMT_PREPARE OK-packet 219 | 220 | Parameters: 221 | 222 | ``packet`` 223 | (string) mysql packet 224 | 225 | 226 | On success it returns a table containing: 227 | 228 | ``stmt_id`` 229 | (int) statement-id 230 | 231 | ``num_columns`` 232 | (int) number of columns in the resultset 233 | 234 | ``num_params`` 235 | (int) number of parameters 236 | 237 | ``warnings`` 238 | (int) warnings generated by the prepare statement 239 | 240 | Otherwise it raises an error. 241 | 242 | 243 | from_stmt_execute_packet 244 | ........................ 245 | 246 | Decodes a COM_STMT_EXECUTE-packet 247 | 248 | Parameters: 249 | 250 | ``packet`` 251 | (string) mysql packet 252 | 253 | ``num_params`` 254 | (int) number of parameters of the corresponding prepared statement 255 | 256 | On success it returns a table containing: 257 | 258 | ``stmt_id`` 259 | (int) statement-id 260 | 261 | ``flags`` 262 | (int) flags describing the kind of cursor used 263 | 264 | ``iteration_count`` 265 | (int) iteration count: always 1 266 | 267 | ``new_params_bound`` 268 | (bool) 269 | 270 | ``params`` 271 | (nil, table) 272 | number-index array of parameters if ``new_params_bound`` is ``true`` 273 | 274 | Each param is a table of: 275 | 276 | ``type`` 277 | (int) 278 | MYSQL_TYPE_INT, MYSQL_TYPE_STRING ... and so on 279 | 280 | ``value`` 281 | (nil, number, string) 282 | if the value is a NULL, it is ``nil`` 283 | if it is a number (_INT, _DOUBLE, ...) it is a ``number`` 284 | otherwise it is a ``string`` 285 | 286 | If decoding fails it raises an error. 287 | 288 | To get the ``num_params`` for this function, you have to track the number of parameters as returned 289 | by the `from_stmt_prepare_ok_packet`_. Use `stmt_id_from_stmt_execute_packet`_ to get the ``statement-id`` from 290 | the COM_STMT_EXECUTE packet and lookup your tracked information. 291 | 292 | stmt_id_from_stmt_execute_packet 293 | ................................ 
// String reinterprets b as a string without copying the bytes.
//
// The returned string shares b's backing array: a later write through b is
// visible in the string, while an append that reallocates b is not. Use at
// your own risk.
func String(b []byte) (s string) {
	// A string header (data pointer, length) has the same layout as the
	// first two words of a slice header, so the slice can be read as one.
	return *(*string)(unsafe.Pointer(&b))
}

// Slice reinterprets s as a byte slice without copying the bytes.
//
// The returned slice has len == cap == len(s) and points at the string's
// own storage, so it must be treated as read-only. Use at your own risk.
func Slice(s string) (b []byte) {
	src := (*reflect.StringHeader)(unsafe.Pointer(&s))
	dst := (*reflect.SliceHeader)(unsafe.Pointer(&b))
	dst.Data = src.Data
	dst.Len = src.Len
	dst.Cap = src.Len
	return
}
// comMessage builds the structured payload for one log record.
//
// The returned map has three keys:
//   - "file": base name and line of the call site ("?:0" if unavailable)
//   - "func": name of the calling function followed by "()" ("?()" if unknown)
//   - "msg":  strfmt expanded with args via fmt.Sprintf
//
// The call site is taken two frames above comMessage, i.e. the caller of
// the ProxyLogger method that invoked comMessage.
func comMessage(strfmt string, args ...interface{}) map[string]string {
	pc, file, line, ok := runtime.Caller(2)
	if !ok {
		file, line = "?", 0
	}

	fnName := "?()"
	if fn := runtime.FuncForPC(pc); fn != nil {
		// fn.Name() is package-qualified; keep only the part after the
		// final dot.
		fnName = strings.TrimLeft(filepath.Ext(fn.Name()), ".") + "()"
	}

	return map[string]string{
		"file": filepath.Base(file) + ":" + strconv.Itoa(line),
		"func": fnName,
		"msg":  fmt.Sprintf(strfmt, args...),
	}
}
// Log severities, ordered from least to most severe. A Logger emits a
// record only when the record's level is >= the logger's configured level.
const (
	LevelTrace = iota
	LevelDebug
	LevelNotice
	LevelWarn
	LevelFatal
)

// Fallback values used when the caller supplies none.
const (
	DefaultKey    = "DefaultKey"
	DefaultLogId  = "000000000000"
	DefaultAppTag = "DefaultAppTag"
)

// Level maps the numeric levels 0..4 to the names written in each record:
// TRACE, DEBUG, NOTICE, WARN, FATAL.
var (
	Level = []string{"TRACE", "DEBUG", "NOTICE", "WARN", "FATAL"}
)

// Logger appends JSON-bodied records to a single file.
type Logger struct {
	logfd    *os.File
	level    int
	apptag   string
	hostname string
	lock     *sync.Mutex // serializes writes to logfd
}

// NewLogger returns a Logger appending to filename, creating the file if
// needed. apptag is optional; DefaultAppTag is used when it is omitted.
// The initial level is LevelNotice.
//
// NewLogger panics if the hostname cannot be determined or the file cannot
// be opened.
func NewLogger(filename string, apptag ...string) *Logger {
	realAppTag := DefaultAppTag
	if len(apptag) > 0 {
		realAppTag = apptag[0]
	}

	// panic when can not get hostname
	hostname, err := os.Hostname()
	if err != nil {
		panic(err)
	}

	// panic when can not open log file
	logfd, err := os.OpenFile(
		filename,
		os.O_CREATE|os.O_APPEND|os.O_WRONLY,
		0644,
	)
	if err != nil {
		panic(err)
	}

	return &Logger{
		logfd:    logfd,
		level:    LevelNotice,
		apptag:   realAppTag,
		hostname: hostname,
		lock:     new(sync.Mutex),
	}
}

// SetLevel sets the minimum level that is written; the default is
// LevelNotice.
func (l *Logger) SetLevel(level int) {
	l.level = level
}

// SetAppTag sets the application tag written in every record.
func (l *Logger) SetAppTag(apptag string) {
	l.apptag = apptag
}

// Level returns the logger's current minimum level.
func (l *Logger) Level() int {
	return l.level
}

// AppTag returns the logger's application tag.
func (l *Logger) AppTag() string {
	return l.apptag
}

// write appends msg to the log file, adding a trailing newline if msg lacks
// one. It takes l.lock itself, so callers must NOT hold the lock.
func (l *Logger) write(msg string) error {
	if !strings.HasSuffix(msg, "\n") {
		msg += "\n"
	}

	l.lock.Lock()
	defer l.lock.Unlock()
	_, err := l.logfd.WriteString(msg)
	return err
}

// Flush syncs the log file to disk and then CLOSES it. The Logger must not
// be used for writing afterwards: later writes return an error.
func (l *Logger) Flush() {
	// Errors are deliberately dropped: Flush has no error return and is
	// called on shutdown paths where nothing more can be done.
	_ = l.logfd.Sync()
	_ = l.logfd.Close()
}

// emit writes v at the given level unless the level is filtered out.
func (l *Logger) emit(level int, v interface{}, logid []string) error {
	if !l.suitLevel(level) {
		return nil
	}
	return l.write(l.format(level, v, logid))
}

// Trace logs v at LevelTrace. The optional first logid tags the record.
func (l *Logger) Trace(v interface{}, logid ...string) error {
	return l.emit(LevelTrace, v, logid)
}

// Debug logs v at LevelDebug. The optional first logid tags the record.
func (l *Logger) Debug(v interface{}, logid ...string) error {
	return l.emit(LevelDebug, v, logid)
}

// Notice logs v at LevelNotice. The optional first logid tags the record.
func (l *Logger) Notice(v interface{}, logid ...string) error {
	return l.emit(LevelNotice, v, logid)
}

// Warn logs v at LevelWarn. The optional first logid tags the record.
func (l *Logger) Warn(v interface{}, logid ...string) error {
	return l.emit(LevelWarn, v, logid)
}

// Fatal logs v at LevelFatal. The optional first logid tags the record.
// It does not terminate the process.
func (l *Logger) Fatal(v interface{}, logid ...string) error {
	return l.emit(LevelFatal, v, logid)
}

// suitLevel reports whether a record at level passes the logger's filter.
func (l *Logger) suitLevel(level int) bool {
	return level >= l.level
}

// format renders one record as
//
//	[time] [apptag] [hostname] [LEVEL] [logid] <json body>
//
// The body is v encoded as JSON. If v cannot be encoded, the body is
// {"DefaultKey": "<encoding error>"} so the record is never silently empty.
func (l *Logger) format(level int, v interface{}, logid []string) string {
	id := DefaultLogId
	if len(logid) > 0 {
		id = logid[0]
	}

	body, err := json.Marshal(v)
	if err != nil {
		// Re-marshaling v inside a wrapper map would fail the same way, so
		// record the error text instead. A map[string]string always
		// encodes, hence the ignored error.
		body, _ = json.Marshal(map[string]string{DefaultKey: err.Error()})
	}

	return "[" + time.Now().Format("2006-01-02 15:04:05") + "] " +
		"[" + l.apptag + "] " +
		"[" + l.hostname + "] " +
		"[" + Level[level] + "] " +
		"[" + id + "] " +
		string(body)
}

// StdContent accumulates key/value pairs that are later written as a
// single record through the Logger that created it.
type StdContent struct {
	data   map[string]interface{}
	logger *Logger
	lock   *sync.Mutex // guards data
}

// NewStdContent returns an empty StdContent bound to this logger.
func (l *Logger) NewStdContent() *StdContent {
	return &StdContent{
		data:   make(map[string]interface{}),
		logger: l,
		lock:   &sync.Mutex{},
	}
}

// SetVal adds or replaces the (key, val) pair in the pending record.
func (sc *StdContent) SetVal(key string, val interface{}) {
	sc.lock.Lock()
	sc.data[key] = val
	sc.lock.Unlock()
}

// emit writes the accumulated pairs at the given level unless filtered out.
// The map is encoded while holding sc.lock so a concurrent SetVal cannot
// race with the read.
func (sc *StdContent) emit(level int, logid []string) error {
	if !sc.logger.suitLevel(level) {
		return nil
	}

	sc.lock.Lock()
	line := sc.logger.format(level, sc.data, logid)
	sc.lock.Unlock()

	return sc.logger.write(line)
}

// Trace writes the accumulated pairs at LevelTrace.
func (sc *StdContent) Trace(logid ...string) error {
	return sc.emit(LevelTrace, logid)
}

// Debug writes the accumulated pairs at LevelDebug.
func (sc *StdContent) Debug(logid ...string) error {
	return sc.emit(LevelDebug, logid)
}

// Notice writes the accumulated pairs at LevelNotice.
func (sc *StdContent) Notice(logid ...string) error {
	return sc.emit(LevelNotice, logid)
}

// Warn writes the accumulated pairs at LevelWarn.
func (sc *StdContent) Warn(logid ...string) error {
	return sc.emit(LevelWarn, logid)
}

// Fatal writes the accumulated pairs at LevelFatal.
func (sc *StdContent) Fatal(logid ...string) error {
	return sc.emit(LevelFatal, logid)
}
logger := NewLogger("test.log") 56 | logger.SetLevel(LevelDebug) 57 | 58 | sc := logger.NewStdContent() 59 | t.Log("Test NewStdContent With Id Pass") 60 | 61 | mkContent(sc) 62 | 63 | scNoId := logger.NewStdContent() 64 | t.Log("Test NewStdContent Without Id Pass") 65 | mkContent(scNoId) 66 | 67 | for i := 0; i <= LevelFatal; i++ { 68 | logger.SetLevel(i) 69 | stdOuts(sc, "LogIddddddd") 70 | stdOuts(scNoId) 71 | } 72 | logger.Flush() 73 | } 74 | 75 | func stdOuts(sc *StdContent, logId ...string) { 76 | sc.Trace(logId...) 77 | sc.Debug(logId...) 78 | sc.Notice(logId...) 79 | sc.Warn(logId...) 80 | sc.Fatal(logId...) 81 | } 82 | 83 | func BenchmarkLog(b *testing.B) { 84 | logger := NewLogger("test.log") 85 | logger.SetLevel(LevelDebug) 86 | for i := 0; i < b.N; i++ { 87 | sc := logger.NewStdContent() 88 | mkBenchContent(sc) 89 | sc.Trace("LogIddddddd") 90 | sc.Debug("LogIddddddd") 91 | sc.Notice("LogIddddddd") 92 | sc.Warn("LogIddddddd") 93 | sc.Fatal("LogIddddddd") 94 | } 95 | logger.Flush() 96 | } 97 | 98 | func mkBenchContent(sc *StdContent) { 99 | sc.SetVal("key0", "val0") 100 | sc.SetVal("nullval", nil) 101 | sc.SetVal("123", 123) 102 | } 103 | 104 | func mkContent(sc *StdContent) { 105 | sc.SetVal("key0", "val0") 106 | sc.SetVal("arr", []string{"val0", "val1", "val2"}) 107 | sc.SetVal("hash", map[string]string{ 108 | "v0": "val0", 109 | "v1": "val1", 110 | "v2": "val2", 111 | }) 112 | sc.SetVal("nullval", nil) 113 | var interfaceVal interface{} 114 | interfaceVal = []interface{}{"123", 123, map[string]string{"v4": "val4"}} 115 | sc.SetVal("interface", interfaceVal) 116 | sc.SetVal("123", 123) 117 | } 118 | -------------------------------------------------------------------------------- /makefile: -------------------------------------------------------------------------------- 1 | all: build run 2 | 3 | build: 4 | go install ./... 
5 | 6 | run: 7 | cd cmd/proxy && go build && cd - 8 | GOGC=1000 ./cmd/proxy/proxy --config=etc/proxy_single.yaml --logfile=log/proxy.log --loglevel=0 9 | 10 | clean: 11 | go clean -i ./... 12 | 13 | test: 14 | go test ./... 15 | 16 | package: build 17 | tar cvf output.tar etc/ cmd/proxy/proxy run.sh 18 | -------------------------------------------------------------------------------- /mysql/const.go: -------------------------------------------------------------------------------- 1 | package mysql 2 | 3 | const ( 4 | MinProtocolVersion byte = 10 5 | MaxPayloadLen int = 1<<24 - 1 6 | TimeFormat string = "2006-01-02 15:04:05" 7 | ServerVersion string = "5.5.31-go_proxy-0.1" 8 | ) 9 | 10 | const ( 11 | OK_HEADER byte = 0x00 12 | ERR_HEADER byte = 0xff 13 | EOF_HEADER byte = 0xfe 14 | LocalInFile_HEADER byte = 0xfb 15 | ) 16 | 17 | const ( 18 | SERVER_STATUS_IN_TRANS uint16 = 0x0001 19 | SERVER_STATUS_AUTOCOMMIT uint16 = 0x0002 20 | SERVER_MORE_RESULTS_EXISTS uint16 = 0x0008 21 | SERVER_STATUS_NO_GOOD_INDEX_USED uint16 = 0x0010 22 | SERVER_STATUS_NO_INDEX_USED uint16 = 0x0020 23 | SERVER_STATUS_CURSOR_EXISTS uint16 = 0x0040 24 | SERVER_STATUS_LAST_ROW_SEND uint16 = 0x0080 25 | SERVER_STATUS_DB_DROPPED uint16 = 0x0100 26 | SERVER_STATUS_NO_BACKSLASH_ESCAPED uint16 = 0x0200 27 | SERVER_STATUS_METADATA_CHANGED uint16 = 0x0400 28 | SERVER_QUERY_WAS_SLOW uint16 = 0x0800 29 | SERVER_PS_OUT_PARAMS uint16 = 0x1000 30 | ) 31 | 32 | const ( 33 | COM_SLEEP byte = iota 34 | COM_QUIT 35 | COM_INIT_DB 36 | COM_QUERY 37 | COM_FIELD_LIST 38 | COM_CREATE_DB 39 | COM_DROP_DB 40 | COM_REFRESH 41 | COM_SHUTDOWN 42 | COM_STATISTICS 43 | COM_PROCESS_INFO 44 | COM_CONNECT 45 | COM_PROCESS_KILL 46 | COM_DEBUG 47 | COM_PING 48 | COM_TIME 49 | COM_DELAYED_INSERT 50 | COM_CHANGE_USER 51 | COM_BINLOG_DUMP 52 | COM_TABLE_DUMP 53 | COM_CONNECT_OUT 54 | COM_REGISTER_SLAVE 55 | COM_STMT_PREPARE 56 | COM_STMT_EXECUTE // 23 57 | COM_STMT_SEND_LONG_DATA // 24 58 | COM_STMT_CLOSE 59 | COM_STMT_RESET 60 | 
// Column flag bits. The values are powers of two so they can be OR-ed
// together (Field.Flag in this package is a uint16 holding such a set).
//
// NOTE(review): the names match the column flags in MySQL's headers;
// presumably the values are copied from there — confirm against the MySQL
// source before changing any of them.
const (
	NOT_NULL_FLAG       = 1
	PRI_KEY_FLAG        = 2
	UNIQUE_KEY_FLAG     = 4
	// Bit value 8 is not assigned a name in this block.
	BLOB_FLAG           = 16
	UNSIGNED_FLAG       = 32
	ZEROFILL_FLAG       = 64
	BINARY_FLAG         = 128
	ENUM_FLAG           = 256
	AUTO_INCREMENT_FLAG = 512
	TIMESTAMP_FLAG      = 1024
	SET_FLAG            = 2048
	// NUM_FLAG and GROUP_FLAG below share the value 32768, and the list is
	// not in ascending order at this point (32768, 16384, 32768, 65536).
	NUM_FLAG            = 32768
	PART_KEY_FLAG       = 16384
	GROUP_FLAG          = 32768
	// UNIQUE_FLAG (65536) does not fit in a uint16.
	UNIQUE_FLAG         = 65536
)
// Packet type names accepted by Dump.
const (
	InitialHandshake = "Initial Handshake Packet"
)

// Dump prints a human-readable description of a MySQL packet to stdout.
//
// types selects the decoder (only InitialHandshake is supported), action
// and cId are echoed as-is, and data is the packet payload. Unknown packet
// types are reported but not decoded.
func Dump(types string, action string, cId uint32, capability uint32, data []byte) {
	fmt.Println("---------------mysql packet dump----------------")
	fmt.Println("Action:", action)
	fmt.Println("ConnectionId:", cId)
	fmt.Println("PacketType:", types)

	switch types {
	case InitialHandshake:
		dumpInitialHandshake(data, capability)
	default:
		fmt.Println("Unsupport packet type")
	}
}

// dumpInitialHandshake prints the protocol version byte and the
// NUL-terminated server version string found at the start of data.
//
// Truncated input (empty payload, or a version string without a NUL
// terminator) is reported instead of causing an out-of-range panic.
// capability is currently unused.
func dumpInitialHandshake(data []byte, capability uint32) {
	if len(data) == 0 {
		fmt.Println("\tTruncated handshake packet: empty payload")
		return
	}
	fmt.Println("\tProtocal Version:", data[0])

	strlen := bytes.IndexByte(data[1:], 0x00)
	if strlen < 0 {
		fmt.Println("\tTruncated handshake packet: unterminated server version", string(data[1:]))
		return
	}
	fmt.Println("\tServer:", string(data[1:1+strlen]), "00")

	// The fields after the server version are not decoded; the original
	// code only advanced an offset past them without printing anything.
}
44 | } 45 | 46 | return e 47 | } 48 | 49 | func NewError(errCode uint16, message string) *SqlError { 50 | e := new(SqlError) 51 | e.Code = errCode 52 | 53 | if s, ok := MySQLState[errCode]; ok { 54 | e.State = s 55 | } else { 56 | e.State = DEFAULT_MYSQL_STATE 57 | } 58 | 59 | e.Message = message 60 | 61 | return e 62 | } 63 | -------------------------------------------------------------------------------- /mysql/field.go: -------------------------------------------------------------------------------- 1 | package mysql 2 | 3 | import ( 4 | "encoding/binary" 5 | ) 6 | 7 | type FieldData []byte 8 | 9 | type Field struct { 10 | Data FieldData 11 | Schema []byte 12 | Table []byte 13 | OrgTable []byte 14 | Name []byte 15 | OrgName []byte 16 | Charset uint16 17 | ColumnLength uint32 18 | Type uint8 19 | Flag uint16 20 | Decimal uint8 21 | 22 | DefaultValueLength uint64 23 | DefaultValue []byte 24 | } 25 | 26 | func (p FieldData) Parse() (f *Field, err error) { 27 | f = new(Field) 28 | 29 | f.Data = p 30 | 31 | var n int 32 | pos := 0 33 | //skip catelog, always def 34 | n, err = SkipLengthEnodedString(p) 35 | if err != nil { 36 | return 37 | } 38 | pos += n 39 | 40 | //schema 41 | f.Schema, _, n, err = LengthEnodedString(p[pos:]) 42 | if err != nil { 43 | return 44 | } 45 | pos += n 46 | 47 | //table 48 | f.Table, _, n, err = LengthEnodedString(p[pos:]) 49 | if err != nil { 50 | return 51 | } 52 | pos += n 53 | 54 | //org_table 55 | f.OrgTable, _, n, err = LengthEnodedString(p[pos:]) 56 | if err != nil { 57 | return 58 | } 59 | pos += n 60 | 61 | //name 62 | f.Name, _, n, err = LengthEnodedString(p[pos:]) 63 | if err != nil { 64 | return 65 | } 66 | pos += n 67 | 68 | //org_name 69 | f.OrgName, _, n, err = LengthEnodedString(p[pos:]) 70 | if err != nil { 71 | return 72 | } 73 | pos += n 74 | 75 | //skip oc 76 | pos += 1 77 | 78 | //charset 79 | f.Charset = binary.LittleEndian.Uint16(p[pos:]) 80 | pos += 2 81 | 82 | //column length 83 | f.ColumnLength = 
binary.LittleEndian.Uint32(p[pos:]) 84 | pos += 4 85 | 86 | //type 87 | f.Type = p[pos] 88 | pos++ 89 | 90 | //flag 91 | f.Flag = binary.LittleEndian.Uint16(p[pos:]) 92 | pos += 2 93 | 94 | //decimals 1 95 | f.Decimal = p[pos] 96 | pos++ 97 | 98 | //filter [0x00][0x00] 99 | pos += 2 100 | 101 | f.DefaultValue = nil 102 | //if more data, command was field list 103 | if len(p) > pos { 104 | //length of default value lenenc-int 105 | f.DefaultValueLength, _, n = LengthEncodedInt(p[pos:]) 106 | pos += n 107 | 108 | if pos+int(f.DefaultValueLength) > len(p) { 109 | err = ErrMalformPacket 110 | return 111 | } 112 | 113 | //default value string[$len] 114 | f.DefaultValue = p[pos:(pos + int(f.DefaultValueLength))] 115 | } 116 | 117 | return 118 | } 119 | 120 | func (f *Field) Dump() []byte { 121 | if f.Data != nil { 122 | return []byte(f.Data) 123 | } 124 | 125 | l := len(f.Schema) + len(f.Table) + len(f.OrgTable) + len(f.Name) + len(f.OrgName) + len(f.DefaultValue) + 48 126 | 127 | data := make([]byte, 0, l) 128 | 129 | data = append(data, PutLengthEncodedString([]byte("def"))...) 130 | 131 | data = append(data, PutLengthEncodedString(f.Schema)...) 132 | 133 | data = append(data, PutLengthEncodedString(f.Table)...) 134 | data = append(data, PutLengthEncodedString(f.OrgTable)...) 135 | 136 | data = append(data, PutLengthEncodedString(f.Name)...) 137 | data = append(data, PutLengthEncodedString(f.OrgName)...) 138 | 139 | data = append(data, 0x0c) 140 | 141 | data = append(data, Uint16ToBytes(f.Charset)...) 142 | data = append(data, Uint32ToBytes(f.ColumnLength)...) 143 | data = append(data, f.Type) 144 | data = append(data, Uint16ToBytes(f.Flag)...) 145 | data = append(data, f.Decimal) 146 | data = append(data, 0, 0) 147 | 148 | if f.DefaultValue != nil { 149 | data = append(data, Uint64ToBytes(f.DefaultValueLength)...) 150 | data = append(data, f.DefaultValue...) 
151 | } 152 | 153 | return data 154 | } 155 | -------------------------------------------------------------------------------- /mysql/packet.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) All Rights Reserved 2 | // @file package.go 3 | // @author 王靖 (wangjild@gmail.com) 4 | // @date 14-11-27 14:40:10 5 | // @version $Revision: 1.0 $ 6 | // @brief 7 | 8 | package mysql 9 | 10 | import () 11 | 12 | // Packets documentation: 13 | // http://dev.mysql.com/doc/internals/en/mysql-packet.html 14 | const MaxPacketSize = 1<<24 - 1 15 | const PacketHeadSize = 4 16 | 17 | /* vim: set expandtab ts=4 sw=4 */ 18 | -------------------------------------------------------------------------------- /mysql/packetio.go: -------------------------------------------------------------------------------- 1 | package mysql 2 | 3 | import ( 4 | "bufio" 5 | "io" 6 | . "github.com/wangjild/go-mysql-proxy/log" 7 | "net" 8 | ) 9 | 10 | type PacketIO struct { 11 | reader io.Reader 12 | writer io.Writer 13 | 14 | Sequence uint8 15 | } 16 | 17 | func NewPacketIO(conn net.Conn) *PacketIO { 18 | p := new(PacketIO) 19 | 20 | p.reader = bufio.NewReader(conn) 21 | p.writer = conn 22 | 23 | p.Sequence = 0 24 | 25 | return p 26 | } 27 | 28 | func (p *PacketIO) ReadPacket() ([]byte, error) { 29 | 30 | var payload []byte 31 | for { 32 | 33 | var header [PacketHeadSize]byte 34 | if n, err := io.ReadFull(p.reader, header[:]); err != nil { 35 | AppLog.Warn("wrong packet format, head size is %d", n) 36 | return nil, ErrBadConn 37 | } 38 | 39 | length := int(uint32(header[0]) | uint32(header[1])<<8 | uint32(header[2])<<16) 40 | if length < 1 { 41 | AppLog.Warn("wrong packet length, size is %d", length) 42 | return nil, ErrBadPkgLen 43 | } 44 | 45 | if uint8(header[3]) != p.Sequence { 46 | if uint8(header[3]) > p.Sequence { 47 | return nil, ErrPktSyncMul 48 | } else { 49 | return nil, ErrPktSync 50 | } 51 | } 52 | 53 | p.Sequence++ 54 | 55 | data := 
make([]byte, length, length) 56 | var err error 57 | if _, err = io.ReadFull(p.reader, data); err != nil { 58 | AppLog.Warn("read packet from conn error: %s", err.Error()) 59 | return nil, ErrBadConn 60 | } 61 | 62 | lastPacket := (length < MaxPacketSize) 63 | 64 | if lastPacket && payload == nil { 65 | return data, nil 66 | } 67 | 68 | payload = append(payload, data...) 69 | 70 | if lastPacket { 71 | return payload, nil 72 | } 73 | 74 | } 75 | } 76 | 77 | //data already have header 78 | func (p *PacketIO) WritePacket(data []byte) error { 79 | length := len(data) - 4 80 | 81 | for length >= MaxPayloadLen { 82 | 83 | data[0] = 0xff 84 | data[1] = 0xff 85 | data[2] = 0xff 86 | 87 | data[3] = p.Sequence 88 | 89 | if n, err := p.writer.Write(data[:4+MaxPayloadLen]); err != nil { 90 | return ErrBadConn 91 | } else if n != (4 + MaxPayloadLen) { 92 | return ErrBadConn 93 | } else { 94 | p.Sequence++ 95 | length -= MaxPayloadLen 96 | data = data[MaxPayloadLen:] 97 | } 98 | } 99 | 100 | data[0] = byte(length) 101 | data[1] = byte(length >> 8) 102 | data[2] = byte(length >> 16) 103 | data[3] = p.Sequence 104 | 105 | if n, err := p.writer.Write(data); err != nil { 106 | return ErrBadConn 107 | } else if n != len(data) { 108 | return ErrBadConn 109 | } else { 110 | p.Sequence++ 111 | return nil 112 | } 113 | } 114 | -------------------------------------------------------------------------------- /mysql/result.go: -------------------------------------------------------------------------------- 1 | package mysql 2 | 3 | type Result struct { 4 | AffectedRows uint64 5 | InsertId uint64 6 | 7 | Status uint16 8 | Warnings uint16 9 | 10 | *Resultset 11 | } 12 | -------------------------------------------------------------------------------- /mysql/resultset_sort.go: -------------------------------------------------------------------------------- 1 | package mysql 2 | 3 | import ( 4 | "bytes" 5 | "fmt" 6 | "github.com/wangjild/go-mysql-proxy/hack" 7 | "sort" 8 | ) 9 | 10 | const ( 11 
| SortAsc = "asc" 12 | SortDesc = "desc" 13 | ) 14 | 15 | type SortKey struct { 16 | //name of the field 17 | Name string 18 | 19 | Direction string 20 | 21 | //column index of the field 22 | column int 23 | } 24 | 25 | type resultsetSorter struct { 26 | *Resultset 27 | 28 | sk []SortKey 29 | } 30 | 31 | func newResultsetSorter(r *Resultset, sk []SortKey) (*resultsetSorter, error) { 32 | s := new(resultsetSorter) 33 | 34 | s.Resultset = r 35 | 36 | for i, k := range sk { 37 | if column, ok := r.FieldNames[k.Name]; ok { 38 | sk[i].column = column 39 | } else { 40 | return nil, fmt.Errorf("key %s not in resultset fields, can not sort", k.Name) 41 | } 42 | } 43 | 44 | s.sk = sk 45 | 46 | return s, nil 47 | } 48 | 49 | func (r *resultsetSorter) Len() int { 50 | return r.RowNumber() 51 | } 52 | 53 | func (r *resultsetSorter) Less(i, j int) bool { 54 | v1 := r.Values[i] 55 | v2 := r.Values[j] 56 | 57 | for _, k := range r.sk { 58 | v := cmpValue(v1[k.column], v2[k.column]) 59 | 60 | if k.Direction == SortDesc { 61 | v = -v 62 | } 63 | 64 | if v < 0 { 65 | return true 66 | } else if v > 0 { 67 | return false 68 | } 69 | 70 | //equal, cmp next key 71 | } 72 | 73 | return false 74 | } 75 | 76 | //compare value using asc 77 | func cmpValue(v1 interface{}, v2 interface{}) int { 78 | if v1 == nil && v2 == nil { 79 | return 0 80 | } else if v1 == nil { 81 | return -1 82 | } else if v2 == nil { 83 | return 1 84 | } 85 | 86 | switch v := v1.(type) { 87 | case string: 88 | s := v2.(string) 89 | return bytes.Compare(hack.Slice(v), hack.Slice(s)) 90 | case []byte: 91 | s := v2.([]byte) 92 | return bytes.Compare(v, s) 93 | case int64: 94 | s := v2.(int64) 95 | if v < s { 96 | return -1 97 | } else if v > s { 98 | return 1 99 | } else { 100 | return 0 101 | } 102 | case uint64: 103 | s := v2.(uint64) 104 | if v < s { 105 | return -1 106 | } else if v > s { 107 | return 1 108 | } else { 109 | return 0 110 | } 111 | case float64: 112 | s := v2.(float64) 113 | if v < s { 114 | return -1 
115 | } else if v > s { 116 | return 1 117 | } else { 118 | return 0 119 | } 120 | default: 121 | //can not go here 122 | panic(fmt.Sprintf("invalid type %T", v)) 123 | } 124 | } 125 | 126 | func (r *resultsetSorter) Swap(i, j int) { 127 | r.Values[i], r.Values[j] = r.Values[j], r.Values[i] 128 | 129 | r.RowDatas[i], r.RowDatas[j] = r.RowDatas[j], r.RowDatas[i] 130 | } 131 | 132 | func (r *Resultset) Sort(sk []SortKey) error { 133 | s, err := newResultsetSorter(r, sk) 134 | 135 | if err != nil { 136 | return err 137 | } 138 | 139 | sort.Sort(s) 140 | 141 | return nil 142 | } 143 | -------------------------------------------------------------------------------- /mysql/resultset_sort_test.go: -------------------------------------------------------------------------------- 1 | package mysql 2 | 3 | import ( 4 | "fmt" 5 | "reflect" 6 | "sort" 7 | "testing" 8 | ) 9 | 10 | func TestResultsetSort(t *testing.T) { 11 | r1 := new(Resultset) 12 | r2 := new(Resultset) 13 | 14 | r1.Values = [][]interface{}{ 15 | []interface{}{int64(1), "a", []byte("aa")}, 16 | []interface{}{int64(2), "a", []byte("bb")}, 17 | []interface{}{int64(3), "c", []byte("bb")}, 18 | } 19 | 20 | r1.RowDatas = []RowData{ 21 | RowData([]byte("1")), 22 | RowData([]byte("2")), 23 | RowData([]byte("3")), 24 | } 25 | 26 | s := new(resultsetSorter) 27 | 28 | s.Resultset = r1 29 | 30 | s.sk = []SortKey{ 31 | SortKey{column: 0, Direction: SortDesc}, 32 | } 33 | 34 | sort.Sort(s) 35 | 36 | r2.Values = [][]interface{}{ 37 | []interface{}{int64(3), "c", []byte("bb")}, 38 | []interface{}{int64(2), "a", []byte("bb")}, 39 | []interface{}{int64(1), "a", []byte("aa")}, 40 | } 41 | 42 | r2.RowDatas = []RowData{ 43 | RowData([]byte("3")), 44 | RowData([]byte("2")), 45 | RowData([]byte("1")), 46 | } 47 | 48 | if !reflect.DeepEqual(r1, r2) { 49 | t.Fatal(fmt.Sprintf("%v %v", r1, r2)) 50 | } 51 | 52 | s.sk = []SortKey{ 53 | SortKey{column: 1, Direction: SortAsc}, 54 | SortKey{column: 2, Direction: SortDesc}, 55 | } 56 | 57 | 
// Pstack returns the formatted stack trace of the calling goroutine.
//
// runtime.Stack truncates its output to the size of the buffer it is given, so
// the buffer is doubled until the trace fits completely instead of silently
// cutting it off at a fixed size.
func Pstack() string {
	buf := make([]byte, 1024)
	for {
		n := runtime.Stack(buf, false)
		if n < len(buf) {
			return string(buf[:n])
		}
		// The trace filled the whole buffer and may have been truncated.
		buf = make([]byte, 2*len(buf))
	}
}
// LengthEncodedInt decodes a MySQL length-encoded integer from the start of b.
//
// It returns the decoded value, whether the encoding denotes NULL (first byte
// 0xfb), and the number of bytes consumed:
//
//	0x00-0xfa  the byte itself is the value            (1 byte)
//	0xfb       NULL                                    (1 byte)
//	0xfc       value in the following 2 bytes, LE      (3 bytes)
//	0xfd       value in the following 3 bytes, LE      (4 bytes)
//	0xfe       value in the following 8 bytes, LE      (9 bytes)
//
// Any other first byte (0xff) is returned as a plain one-byte value.
//
// If b is empty or ends before the integer is complete, nothing is consumed:
// the result is (0, true, 0) rather than an index-out-of-range panic, so a
// consumed length of 0 tells the caller the input was incomplete.
func LengthEncodedInt(b []byte) (num uint64, isNull bool, n int) {
	if len(b) == 0 {
		return 0, true, 0
	}

	var width int
	switch b[0] {
	case 0xfb:
		return 0, true, 1
	case 0xfc:
		width = 2
	case 0xfd:
		width = 3
	case 0xfe:
		width = 8
	default:
		return uint64(b[0]), false, 1
	}

	if len(b) < 1+width {
		// Truncated input: the marker promises more bytes than are present.
		return 0, true, 0
	}

	// Little-endian accumulation of the bytes following the marker.
	for i := 0; i < width; i++ {
		num |= uint64(b[1+i]) << (8 * uint(i))
	}
	return num, false, 1 + width
}
// FormatBinaryDate renders a binary-protocol DATE value as "YYYY-MM-DD".
//
// n is the length announced for the value and data holds the value bytes:
// a little-endian 2-byte year followed by month and day. A length of 0 denotes
// the zero date. Any other length than 0 or 4 is rejected, as is a data slice
// shorter than the announced length (which would otherwise panic).
func FormatBinaryDate(n int, data []byte) ([]byte, error) {
	switch n {
	case 0:
		return []byte("0000-00-00"), nil
	case 4:
		if len(data) < n {
			return nil, fmt.Errorf("invalid date packet: need %d bytes, got %d", n, len(data))
		}
		return []byte(fmt.Sprintf("%04d-%02d-%02d",
			binary.LittleEndian.Uint16(data[:2]),
			data[2],
			data[3])), nil
	default:
		return nil, fmt.Errorf("invalid date packet length %d", n)
	}
}
// FormatBinaryTime renders a binary-protocol TIME value as "[-]HH:MM:SS" or
// "[-]HH:MM:SS.ffffff".
//
// n is the length announced for the value and data holds the value bytes:
// a sign flag (1 means negative), a little-endian 4-byte day count, then
// hour, minute and second, optionally followed by a little-endian 4-byte
// microsecond count. Days are folded into the hour field. A length of 0
// denotes the zero time. Lengths other than 0, 8 and 12, or a data slice
// shorter than the announced length, produce an error.
func FormatBinaryTime(n int, data []byte) ([]byte, error) {
	if n == 0 {
		return []byte("00:00:00"), nil
	}
	if n != 8 && n != 12 {
		return nil, fmt.Errorf("invalid time packet length %d", n)
	}
	if len(data) < n {
		return nil, fmt.Errorf("invalid time packet: need %d bytes, got %d", n, len(data))
	}

	// A string is used so that positive values get no prefix at all.
	sign := ""
	if data[0] == 1 {
		sign = "-"
	}

	hours := uint64(binary.LittleEndian.Uint32(data[1:5]))*24 + uint64(data[5])

	if n == 8 {
		return []byte(fmt.Sprintf("%s%02d:%02d:%02d", sign, hours, data[6], data[7])), nil
	}
	return []byte(fmt.Sprintf(
		"%s%02d:%02d:%02d.%06d",
		sign,
		hours,
		data[6],
		data[7],
		binary.LittleEndian.Uint32(data[8:12]),
	)), nil
}

var (
	// DONTESCAPE marks an EncodeMap entry whose byte is copied unchanged.
	DONTESCAPE = byte(255)

	// EncodeMap maps every byte either to DONTESCAPE or to the character that
	// follows the backslash in its escaped form. It is filled in by init.
	EncodeMap [256]byte
)

// Escape returns sql with the characters listed in encodeRef replaced by
// their backslash escape sequences.
//
// Only single-byte (ASCII) characters are ever escaped. Multi-byte runes and
// invalid UTF-8 bytes are copied through unchanged; looking them up by the low
// byte of their code point would rewrite unrelated characters (for example
// U+0127, whose low byte is the quote character).
func Escape(sql string) string {
	dest := make([]byte, 0, 2*len(sql))

	for i := 0; i < len(sql); {
		r, width := utf8.DecodeRuneInString(sql[i:])
		if r < utf8.RuneSelf {
			if c := EncodeMap[byte(r)]; c != DONTESCAPE {
				dest = append(dest, '\\', c)
				i += width
				continue
			}
		}
		dest = append(dest, sql[i:i+width]...)
		i += width
	}

	return string(dest)
}

// encodeRef lists the bytes that must be escaped and the character written
// after the backslash for each of them.
var encodeRef = map[byte]byte{
	'\x00': '0',
	'\'':   '\'',
	'"':    '"',
	'\b':   'b',
	'\n':   'n',
	'\r':   'r',
	'\t':   't',
	26:     'Z', // ctl-Z
	'\\':   '\\',
}

// init fills EncodeMap: every byte defaults to DONTESCAPE and the entries of
// encodeRef override that default.
func init() {
	for i := range EncodeMap {
		EncodeMap[i] = DONTESCAPE
	}
	for from, to := range encodeRef {
		EncodeMap[from] = to
	}
}
19 | syncPool struct { 20 | capV int 21 | lenV int 22 | *sync.Pool 23 | } 24 | 25 | SliceSyncPool struct { 26 | pools []*syncPool 27 | 28 | New func(l int, c int) interface{} 29 | checkType func(interface{}) bool 30 | } 31 | ) 32 | 33 | func newSyncPool(NewFunc func(l int, c int) interface{}, lv int, cv int) *syncPool { 34 | p := new(syncPool) 35 | p.capV = cv 36 | p.lenV = lv 37 | p.Pool = &sync.Pool{New: func() interface{} { return NewFunc(p.lenV, p.capV) }} 38 | return p 39 | } 40 | 41 | func NewSliceSyncPool(NewFunc func(l int, c int) interface{}, check func(interface{}) bool) *SliceSyncPool { 42 | p := new(SliceSyncPool) 43 | 44 | p.New = NewFunc 45 | p.checkType = check 46 | 47 | p.pools = make([]*syncPool, maxSliceType+1) 48 | min := floorlog2(minSliceSize) 49 | max := floorlog2(maxSliceSize) 50 | for i := min; i <= max; i++ { 51 | // return 2^i size slice 52 | p.pools[i] = newSyncPool(NewFunc, 0, 1< maxSliceSize { 64 | return p.New(size, 2*size) 65 | } else if size < minSliceSize { 66 | // small than 8 len's slice all return 8 cap interface{} 67 | ret = p.borrow(floorlog2(minSliceSize)) 68 | } else { 69 | idx := floorlog2(uint(size)) 70 | if 1< maxSliceSize || v.Cap() < minSliceSize { 95 | return // too big or too small, let it go 96 | } 97 | 98 | idx := floorlog2(uint(v.Cap())) 99 | rs := 1 << uint(idx) 100 | p.pools[idx].Put(v.Slice3(0, rs, rs).Interface()) 101 | } 102 | -------------------------------------------------------------------------------- /pool/slice1.go: -------------------------------------------------------------------------------- 1 | package pool 2 | 3 | import ( 4 | "reflect" 5 | ) 6 | 7 | const cacheSliceCap = 10240 8 | 9 | type ( 10 | // SlicePool holds bufs. 
11 | SlicePool struct { 12 | pools []chan interface{} 13 | 14 | New func(l int, c int) interface{} 15 | checkType func(interface{}) bool 16 | } 17 | ) 18 | 19 | func NewSlicePool(NewFunc func(l int, c int) interface{}, check func(i interface{}) bool) *SlicePool { 20 | p := new(SlicePool) 21 | 22 | p.New = NewFunc 23 | p.checkType = check 24 | 25 | p.pools = make([]chan interface{}, maxSliceType+1) 26 | min := floorlog2(minSliceSize) 27 | max := floorlog2(maxSliceSize) 28 | for i := min; i <= max; i++ { 29 | // return 2^i size slice 30 | p.pools[i] = make(chan interface{}, cacheSliceCap) 31 | } 32 | 33 | return p 34 | } 35 | 36 | // borrow a buf from the pool. 37 | func (p *SlicePool) Borrow(size int) interface{} { 38 | 39 | var ret interface{} 40 | 41 | if size > maxSliceSize { 42 | return p.New(size, 2*size) 43 | } else if size < minSliceSize { 44 | // small than 8 len's slice all return 8 cap interface{} 45 | ret = p.borrow(floorlog2(minSliceSize)) 46 | } else { 47 | idx := floorlog2(uint(size)) 48 | if 1< maxSliceSize || v.Cap() < minSliceSize { 78 | return // too big or too small, let it go 79 | } 80 | 81 | idx := floorlog2(uint(v.Cap())) 82 | rs := 1 << uint(idx) 83 | select { 84 | case p.pools[idx] <- v.Slice3(0, rs, rs).Interface(): 85 | default: 86 | // let it go, let it go 87 | } 88 | } 89 | -------------------------------------------------------------------------------- /pool/slice1_test.go: -------------------------------------------------------------------------------- 1 | package pool 2 | 3 | import ( 4 | "math/rand" 5 | "testing" 6 | ) 7 | 8 | func TestChanPoolConsist(t *testing.T) { 9 | 10 | isPool := NewSlicePool( 11 | func(l int, c int) interface{} { return make([]int, l, c) }, 12 | checkInts, 13 | ) 14 | 15 | testPoolConsist(isPool, t) 16 | } 17 | 18 | func TestHugePool(t *testing.T) { 19 | isPool := NewSlicePool( 20 | func(l int, c int) interface{} { return make([]int, l, c) }, 21 | checkInts, 22 | ) 23 | 24 | testHugePool(isPool, t) 25 | } 26 | 
27 | func TestPoolEdgeCondition(t *testing.T) { 28 | bsPool := NewSlicePool( 29 | func(l int, c int) interface{} { return make([]byte, l, c) }, 30 | checkBytes, 31 | ) 32 | 33 | testPoolEdgeCondition(bsPool, t) 34 | } 35 | 36 | func TestDifferentTypePanic(t *testing.T) { 37 | 38 | bsPool := NewSlicePool( 39 | func(l int, c int) interface{} { return make([]byte, l, c) }, 40 | checkBytes, 41 | ) 42 | 43 | testDifferentTypePanic(bsPool, t) 44 | } 45 | func TestPoolFull(t *testing.T) { 46 | 47 | bsPool := NewSlicePool( 48 | func(l int, c int) interface{} { return make([]byte, l, c) }, 49 | checkBytes, 50 | ) 51 | 52 | for i := 0; i <= cacheSliceCap+1; i++ { 53 | bsPool.Return(make([]byte, 0, 8)) 54 | } 55 | } 56 | 57 | func BenchmarkSliceBorrowReturn(t *testing.B) { 58 | 59 | bytesPool := NewSlicePool( 60 | func(l int, c int) interface{} { return make([]byte, l, c) }, 61 | checkBytes, 62 | ) 63 | 64 | for i := 0; i < t.N; i++ { 65 | size := rand.Intn(maxSliceSize) 66 | if size == 0 { 67 | continue 68 | } 69 | 70 | v := bytesPool.Borrow(size) 71 | b, ok := v.([]byte) 72 | if !ok { 73 | t.Fatal(v, "is not slice type!") 74 | } 75 | 76 | if len(b) != size || cap(b) < len(b) { 77 | t.Fatal("length:", len(b), "is less than cap:", cap(b)) 78 | } else { 79 | bytesPool.Return(b) 80 | } 81 | } 82 | } 83 | -------------------------------------------------------------------------------- /pool/slice_test.go: -------------------------------------------------------------------------------- 1 | package pool 2 | 3 | import ( 4 | "math/rand" 5 | "reflect" 6 | "testing" 7 | "unsafe" 8 | ) 9 | 10 | func getBytes(i interface{}, tb testing.TB) []byte { 11 | var b []byte 12 | var ok bool 13 | if b, ok = i.([]byte); !ok { 14 | tb.Fatal(i, "is not bytes slice type!") 15 | } 16 | return b 17 | } 18 | 19 | func getInts(i interface{}, tb testing.TB) []int { 20 | var b []int 21 | var ok bool 22 | if b, ok = i.([]int); !ok { 23 | tb.Fatal(i, "is not int slice type!") 24 | } 25 | return b 26 | } 27 
| 28 | func checkBytes(i interface{}) bool { 29 | _, ok := i.([]byte) 30 | return ok 31 | } 32 | 33 | func checkInts(i interface{}) bool { 34 | _, ok := i.([]int) 35 | return ok 36 | } 37 | 38 | func TestSyncPoolConsist(t *testing.T) { 39 | 40 | isSyncPool := NewSliceSyncPool( 41 | func(l int, c int) interface{} { return make([]int, l, c) }, 42 | checkInts, 43 | ) 44 | 45 | testPoolConsist(isSyncPool, t) 46 | } 47 | 48 | func testPoolConsist(isPool PoolI, t testing.TB) { 49 | 50 | contents := [8]int{1, 2, 3, 4, 5, 6, 7, 8} 51 | b := getInts(isPool.Borrow(0), t) 52 | b = append(b, contents[:]...) 53 | isPool.Return(b) 54 | 55 | nb := getInts(isPool.Borrow(8), t) 56 | 57 | if (*reflect.SliceHeader)(unsafe.Pointer(&nb)).Data != (*reflect.SliceHeader)(unsafe.Pointer(&b)).Data { 58 | t.Fatal("not the same underly buffer!") 59 | } 60 | } 61 | 62 | func TestSyncHugePool(t *testing.T) { 63 | isSyncPool := NewSliceSyncPool( 64 | func(l int, c int) interface{} { return make([]int, l, c) }, 65 | checkInts, 66 | ) 67 | 68 | testHugePool(isSyncPool, t) 69 | } 70 | 71 | func testHugePool(isPool PoolI, tb testing.TB) { 72 | b := getInts(isPool.Borrow(maxSliceSize+1), tb) 73 | isPool.Return(b) // should not pool this really big buffer 74 | 75 | nb := getInts(isPool.Borrow(maxSliceSize), tb) 76 | if (*reflect.SliceHeader)(unsafe.Pointer(&nb)).Data == (*reflect.SliceHeader)(unsafe.Pointer(&b)).Data { 77 | tb.Fatal("these two buffer should be different underly array!") 78 | } 79 | } 80 | 81 | func TestSyncPoolEdgeCondition(t *testing.T) { 82 | bsSyncPool := NewSliceSyncPool( 83 | func(l int, c int) interface{} { return make([]byte, l, c) }, 84 | checkBytes, 85 | ) 86 | 87 | testPoolEdgeCondition(bsSyncPool, t) 88 | } 89 | 90 | func testPoolEdgeCondition(bsSyncPool PoolI, t testing.TB) { 91 | 92 | for i := 1; i <= minSliceSize; i++ { 93 | s := bsSyncPool.Borrow(i) 94 | b := getBytes(s, t) 95 | 96 | if len(b) != i { 97 | t.Fatal("len:", len(b), "not match required size:", i) 98 | } 99 
// passDB maps a password to the database that a user logging in with that
// password is attached to.
type passDB struct {
	DB map[string]string
}

// newPassDB returns an empty passDB ready for use.
func newPassDB() *passDB {
	pass := new(passDB)
	pass.DB = make(map[string]string)
	return pass
}

// add registers db as the database reached with passwd.
//
// A password may be bound to one database only; adding a password that is
// already present returns an error naming user and leaves the existing
// binding untouched.
func (s *passDB) add(user string, passwd string, db string) error {
	if _, ok := s.DB[passwd]; ok {
		return fmt.Errorf("user[%s] with same passwd has multi db, this is forbidden!", user)
	}

	s.DB[passwd] = db
	return nil
}
-------------------------------------------------------------------------------- 1 | package proxy 2 | 3 | import ( 4 | "bytes" 5 | . "github.com/wangjild/go-mysql-proxy/log" 6 | . "github.com/wangjild/go-mysql-proxy/mysql" 7 | ) 8 | 9 | func (c *Conn) checkAuth(auth []byte) error { 10 | AppLog.Debug("checkAuth") 11 | auths := c.server.getUserAuth(c.user) 12 | if auths == nil { 13 | AppLog.Warn("connect without db, auths is nil") 14 | return NewDefaultError(ER_ACCESS_DENIED_ERROR, c.c.RemoteAddr().String(), c.user, "Yes") 15 | } 16 | 17 | for passwd, db := range auths.DB { 18 | if bytes.Equal(auth, CalcPassword(c.salt, []byte(passwd))) { 19 | // gotcha!!! 20 | c.db = db 21 | return nil 22 | } 23 | } 24 | return NewDefaultError(ER_ACCESS_DENIED_ERROR, c.c.RemoteAddr().String(), c.user, "Yes") 25 | } 26 | 27 | func (c *Conn) checkAuthWithDB(auth []byte, db string) error { 28 | var s *Schema 29 | if s = c.server.getSchema(db); s == nil { 30 | return NewDefaultError(ER_BAD_DB_ERROR, db) 31 | } 32 | 33 | if passwd, ok := s.auths[c.user]; !ok { 34 | return NewDefaultError(ER_ACCESS_DENIED_ERROR, c.c.RemoteAddr().String(), c.user, "Yes") 35 | } else if !bytes.Equal(auth, CalcPassword(c.salt, []byte(passwd))) { 36 | return NewDefaultError(ER_ACCESS_DENIED_ERROR, c.c.RemoteAddr().String(), c.user, "Yes") 37 | } 38 | 39 | if err := c.useDB(db); err != nil { 40 | return err 41 | } 42 | 43 | return nil 44 | } 45 | -------------------------------------------------------------------------------- /proxy/conn_query.go: -------------------------------------------------------------------------------- 1 | package proxy 2 | 3 | import ( 4 | "fmt" 5 | "github.com/wangjild/go-mysql-proxy/client" 6 | "github.com/wangjild/go-mysql-proxy/hack" 7 | . 
"github.com/wangjild/go-mysql-proxy/mysql" 8 | "github.com/wangjild/go-mysql-proxy/sql" 9 | ) 10 | 11 | func (c *Conn) handleQuery(sqlstmt string) (err error) { 12 | /*defer func() { 13 | if e := recover(); e != nil { 14 | err = fmt.Errorf("execute %s error %v", sql, e) 15 | return 16 | } 17 | }()*/ 18 | 19 | var stmt sql.IStatement 20 | stmt, err = sql.Parse(sqlstmt) 21 | if err != nil { 22 | return fmt.Errorf(`parse sql "%s" error "%s"`, sqlstmt, err.Error()) 23 | } 24 | 25 | switch v := stmt.(type) { 26 | case sql.ISelect: 27 | return c.handleSelect(v, sqlstmt) 28 | case *sql.Insert: 29 | return c.handleExec(stmt, sqlstmt, false) 30 | case *sql.Update: 31 | return c.handleExec(stmt, sqlstmt, false) 32 | case *sql.Delete: 33 | return c.handleExec(stmt, sqlstmt, false) 34 | case *sql.Replace: 35 | return c.handleExec(stmt, sqlstmt, false) 36 | case *sql.Set: 37 | return c.handleSet(v, sqlstmt) 38 | case *sql.Begin: 39 | return c.handleBegin() 40 | case *sql.Commit: 41 | return c.handleCommit() 42 | case *sql.Rollback: 43 | return c.handleRollback() 44 | case sql.IShow: 45 | return c.handleShow(sqlstmt, v) 46 | case sql.IDDLStatement: 47 | return c.handleExec(stmt, sqlstmt, false) 48 | case *sql.Do: 49 | return c.handleExec(stmt, sqlstmt, false) 50 | case *sql.Call: 51 | return c.handleExec(stmt, sqlstmt, false) 52 | case *sql.Use: 53 | if err := c.useDB(hack.String(stmt.(*sql.Use).DB)); err != nil { 54 | return err 55 | } else { 56 | return c.writeOK(nil) 57 | } 58 | 59 | default: 60 | return fmt.Errorf("statement %T[%s] not support now", stmt, sqlstmt) 61 | } 62 | 63 | return nil 64 | } 65 | 66 | func (c *Conn) getConn(n *Node, isSelect bool) (co *client.SqlConn, err error) { 67 | if !c.needBeginTx() { 68 | if isSelect { 69 | co, err = n.getSelectConn() 70 | } else { 71 | co, err = n.getMasterConn() 72 | } 73 | if err != nil { 74 | return 75 | } 76 | } else { 77 | var ok bool 78 | c.Lock() 79 | co, ok = c.txConns[n] 80 | c.Unlock() 81 | 82 | if !ok { 83 | if co, 
err = n.getMasterConn(); err != nil { 84 | return 85 | } 86 | 87 | if err = co.SetAutocommit(c.IsAutoCommit()); err != nil { 88 | return 89 | } 90 | 91 | if err = co.Begin(); err != nil { 92 | return 93 | } 94 | 95 | c.Lock() 96 | c.txConns[n] = co 97 | c.Unlock() 98 | } 99 | } 100 | 101 | //todo, set conn charset, etc... 102 | if err = co.UseDB(c.schema.db); err != nil { 103 | return 104 | } 105 | 106 | if err = co.SetCharset(c.charset); err != nil { 107 | return 108 | } 109 | 110 | return 111 | } 112 | 113 | func (c *Conn) closeDBConn(co *client.SqlConn, rollback bool) { 114 | // since we have DDL, and when server is not in autoCommit, 115 | // we do not release the connection and will reuse it later 116 | if c.isInTransaction() || !c.isAutoCommit() { 117 | return 118 | } 119 | 120 | if rollback { 121 | co.Rollback() 122 | } 123 | 124 | co.Close() 125 | } 126 | 127 | func makeBindVars(args []interface{}) map[string]interface{} { 128 | bindVars := make(map[string]interface{}, len(args)) 129 | 130 | for i, v := range args { 131 | bindVars[fmt.Sprintf("v%d", i+1)] = v 132 | } 133 | 134 | return bindVars 135 | } 136 | 137 | func (c *Conn) handleExec(stmt sql.IStatement, sqlstmt string, isread bool) error { 138 | 139 | if err := c.checkDB(); err != nil { 140 | return err 141 | } 142 | 143 | conn, err := c.getConn(c.schema.node, isread) 144 | if err != nil { 145 | return err 146 | } else if conn == nil { 147 | return fmt.Errorf("no available connection") 148 | } 149 | 150 | var rs *Result 151 | rs, err = conn.Execute(sqlstmt) 152 | 153 | c.closeDBConn(conn, err != nil) 154 | 155 | if err == nil { 156 | err = c.writeOK(rs) 157 | } 158 | 159 | return err 160 | } 161 | 162 | func (c *Conn) mergeSelectResult(rs *Result) error { 163 | r := rs.Resultset 164 | status := c.status | rs.Status 165 | return c.writeResultset(status, r) 166 | } 167 | -------------------------------------------------------------------------------- /proxy/conn_resultset.go: 
-------------------------------------------------------------------------------- 1 | package proxy 2 | 3 | import ( 4 | "fmt" 5 | "github.com/wangjild/go-mysql-proxy/hack" 6 | . "github.com/wangjild/go-mysql-proxy/mysql" 7 | "strconv" 8 | ) 9 | 10 | func formatValue(value interface{}) ([]byte, error) { 11 | switch v := value.(type) { 12 | case int8: 13 | return strconv.AppendInt(nil, int64(v), 10), nil 14 | case int16: 15 | return strconv.AppendInt(nil, int64(v), 10), nil 16 | case int32: 17 | return strconv.AppendInt(nil, int64(v), 10), nil 18 | case int64: 19 | return strconv.AppendInt(nil, int64(v), 10), nil 20 | case int: 21 | return strconv.AppendInt(nil, int64(v), 10), nil 22 | case uint8: 23 | return strconv.AppendUint(nil, uint64(v), 10), nil 24 | case uint16: 25 | return strconv.AppendUint(nil, uint64(v), 10), nil 26 | case uint32: 27 | return strconv.AppendUint(nil, uint64(v), 10), nil 28 | case uint64: 29 | return strconv.AppendUint(nil, uint64(v), 10), nil 30 | case uint: 31 | return strconv.AppendUint(nil, uint64(v), 10), nil 32 | case float32: 33 | return strconv.AppendFloat(nil, float64(v), 'f', -1, 64), nil 34 | case float64: 35 | return strconv.AppendFloat(nil, float64(v), 'f', -1, 64), nil 36 | case []byte: 37 | return v, nil 38 | case string: 39 | return hack.Slice(v), nil 40 | default: 41 | return nil, fmt.Errorf("invalid type %T", value) 42 | } 43 | } 44 | 45 | func formatField(field *Field, value interface{}) error { 46 | switch value.(type) { 47 | case int8, int16, int32, int64, int: 48 | field.Charset = 63 49 | field.Type = MYSQL_TYPE_LONGLONG 50 | field.Flag = BINARY_FLAG | NOT_NULL_FLAG 51 | case uint8, uint16, uint32, uint64, uint: 52 | field.Charset = 63 53 | field.Type = MYSQL_TYPE_LONGLONG 54 | field.Flag = BINARY_FLAG | NOT_NULL_FLAG | UNSIGNED_FLAG 55 | case string, []byte: 56 | field.Charset = 33 57 | field.Type = MYSQL_TYPE_VAR_STRING 58 | default: 59 | return fmt.Errorf("unsupport type %T for resultset", value) 60 | } 61 | return 
nil 62 | } 63 | 64 | func (c *Conn) buildResultset(names []string, values [][]interface{}) (*Resultset, error) { 65 | r := new(Resultset) 66 | 67 | r.Fields = make([]*Field, len(names)) 68 | 69 | var b []byte 70 | var err error 71 | 72 | for i, vs := range values { 73 | if len(vs) != len(r.Fields) { 74 | return nil, fmt.Errorf("row %d has %d column not equal %d", i, len(vs), len(r.Fields)) 75 | } 76 | 77 | var row []byte 78 | for j, value := range vs { 79 | if i == 0 { 80 | field := &Field{} 81 | r.Fields[j] = field 82 | field.Name = hack.Slice(names[j]) 83 | 84 | if err = formatField(field, value); err != nil { 85 | return nil, err 86 | } 87 | } 88 | b, err = formatValue(value) 89 | 90 | if err != nil { 91 | return nil, err 92 | } 93 | 94 | row = append(row, PutLengthEncodedString(b)...) 95 | } 96 | 97 | r.RowDatas = append(r.RowDatas, row) 98 | } 99 | 100 | return r, nil 101 | } 102 | 103 | func (c *Conn) writeResultset(status uint16, r *Resultset) error { 104 | c.affectedRows = int64(-1) 105 | 106 | columnLen := PutLengthEncodedInt(uint64(len(r.Fields))) 107 | 108 | data := make([]byte, 4, 1024) 109 | 110 | data = append(data, columnLen...) 111 | if err := c.writePacket(data); err != nil { 112 | return err 113 | } 114 | 115 | for _, v := range r.Fields { 116 | data = data[0:4] 117 | data = append(data, v.Dump()...) 118 | if err := c.writePacket(data); err != nil { 119 | return err 120 | } 121 | } 122 | 123 | if err := c.writeEOF(status); err != nil { 124 | return err 125 | } 126 | 127 | for _, v := range r.RowDatas { 128 | data = data[0:4] 129 | data = append(data, v...) 
130 | if err := c.writePacket(data); err != nil { 131 | return err 132 | } 133 | } 134 | 135 | if err := c.writeEOF(status); err != nil { 136 | return err 137 | } 138 | 139 | return nil 140 | } 141 | -------------------------------------------------------------------------------- /proxy/conn_select.go: -------------------------------------------------------------------------------- 1 | package proxy 2 | 3 | import ( 4 | "bytes" 5 | "fmt" 6 | . "github.com/wangjild/go-mysql-proxy/mysql" 7 | "github.com/wangjild/go-mysql-proxy/sql" 8 | ) 9 | 10 | func (c *Conn) buildSimpleSelectResult(value interface{}, name []byte, asName []byte) (*Resultset, error) { 11 | field := &Field{} 12 | 13 | field.Name = name 14 | 15 | if asName != nil { 16 | field.Name = asName 17 | } 18 | 19 | field.OrgName = name 20 | 21 | formatField(field, value) 22 | 23 | r := &Resultset{Fields: []*Field{field}} 24 | row, err := formatValue(value) 25 | if err != nil { 26 | return nil, err 27 | } 28 | r.RowDatas = append(r.RowDatas, PutLengthEncodedString(row)) 29 | 30 | return r, nil 31 | } 32 | 33 | func (c *Conn) handleFieldList(data []byte) error { 34 | index := bytes.IndexByte(data, 0x00) 35 | table := string(data[0:index]) 36 | wildcard := string(data[index+1:]) 37 | 38 | if c.schema == nil { 39 | return NewDefaultError(ER_NO_DB_ERROR) 40 | } 41 | 42 | co, err := c.schema.node.getMasterConn() 43 | if err != nil { 44 | return err 45 | } 46 | defer co.Close() 47 | 48 | if err = co.UseDB(c.schema.db); err != nil { 49 | return err 50 | } 51 | 52 | if fs, err := co.FieldList(table, wildcard); err != nil { 53 | return err 54 | } else { 55 | return c.writeFieldList(c.status, fs) 56 | } 57 | } 58 | 59 | func (c *Conn) writeFieldList(status uint16, fs []*Field) error { 60 | c.affectedRows = int64(-1) 61 | 62 | data := make([]byte, 4, 1024) 63 | 64 | for _, v := range fs { 65 | data = data[0:4] 66 | data = append(data, v.Dump()...) 
67 | if err := c.writePacket(data); err != nil { 68 | return err 69 | } 70 | } 71 | 72 | if err := c.writeEOF(status); err != nil { 73 | return err 74 | } 75 | return nil 76 | } 77 | 78 | func (c *Conn) handleSelect(stmt sql.IStatement, sqlstmt string) error { 79 | 80 | if err := c.checkDB(); err != nil { 81 | return err 82 | } 83 | 84 | isread := false 85 | if s, ok := stmt.(sql.ISelect); ok { 86 | isread = !s.IsLocked() 87 | } else if _, sok := stmt.(sql.IShow); sok { 88 | isread = true 89 | } 90 | 91 | conn, err := c.getConn(c.schema.node, isread) 92 | 93 | if err != nil { 94 | return err 95 | } else if conn == nil { 96 | // r := c.newEmptyResultset(stmt) 97 | // return c.writeResultset(c.status, r) 98 | return fmt.Errorf("no available connection") 99 | } 100 | 101 | var res *Result 102 | res, err = conn.Execute(sqlstmt) 103 | 104 | c.closeDBConn(conn, false) 105 | 106 | if err == nil { 107 | err = c.mergeSelectResult(res) 108 | } 109 | 110 | return err 111 | } 112 | -------------------------------------------------------------------------------- /proxy/conn_set.go: -------------------------------------------------------------------------------- 1 | package proxy 2 | 3 | import ( 4 | "fmt" 5 | . "github.com/wangjild/go-mysql-proxy/log" 6 | . 
"github.com/wangjild/go-mysql-proxy/mysql" 7 | "github.com/wangjild/go-mysql-proxy/sql" 8 | "strings" 9 | ) 10 | 11 | func (c *Conn) handleSet(stmt *sql.Set, sql string) error { 12 | if len(stmt.VarList) < 1 { 13 | return fmt.Errorf("must set one item at least") 14 | } 15 | 16 | var err error 17 | for _, v := range stmt.VarList { 18 | if strings.ToUpper(v.Name) == "AUTOCOMMIT" { 19 | AppLog.Debug("handle autocommit") 20 | err = c.handleSetAutoCommit(v.Value) 21 | } 22 | } 23 | 24 | if err != nil { 25 | return err 26 | } 27 | return c.handleOtherSet(stmt, sql) 28 | } 29 | 30 | func (c *Conn) handleSetAutoCommit(val sql.IExpr) error { 31 | 32 | var stmt *sql.Predicate 33 | var ok bool 34 | if stmt, ok = val.(*sql.Predicate); !ok { 35 | return fmt.Errorf("set autocommit is not support for complicate expressions") 36 | } 37 | 38 | switch value := stmt.Expr.(type) { 39 | case sql.NumVal: 40 | if i, err := value.ParseInt(); err != nil { 41 | return err 42 | } else if i == 1 { 43 | c.status |= SERVER_STATUS_AUTOCOMMIT 44 | AppLog.Debug("autocommit is set") 45 | } else if i == 0 { 46 | c.status &= ^SERVER_STATUS_AUTOCOMMIT 47 | AppLog.Debug("auto commit is unset") 48 | } else { 49 | return fmt.Errorf("Variable 'autocommit' can't be set to the value of '%s'", i) 50 | } 51 | case sql.StrVal: 52 | if s := value.Trim(); s == "" { 53 | return fmt.Errorf("Variable 'autocommit' can't be set to the value of ''") 54 | } else if us := strings.ToUpper(s); us == `ON` { 55 | c.status |= SERVER_STATUS_AUTOCOMMIT 56 | AppLog.Debug("auto commit is set") 57 | } else if us == `OFF` { 58 | c.status &= ^SERVER_STATUS_AUTOCOMMIT 59 | AppLog.Debug("auto commit is unset") 60 | } else { 61 | return fmt.Errorf("Variable 'autocommit' can't be set to the value of '%s'", us) 62 | } 63 | default: 64 | return fmt.Errorf("set autocommit error, value type is %T", val) 65 | } 66 | 67 | return nil 68 | } 69 | 70 | func (c *Conn) handleSetNames(val sql.IValExpr) error { 71 | value, ok := val.(sql.StrVal) 72 
| if !ok { 73 | return fmt.Errorf("set names charset error") 74 | } 75 | 76 | charset := strings.ToLower(string(value)) 77 | cid, ok := CharsetIds[charset] 78 | if !ok { 79 | return fmt.Errorf("invalid charset %s", charset) 80 | } 81 | 82 | c.charset = charset 83 | c.collation = cid 84 | 85 | return c.writeOK(nil) 86 | } 87 | 88 | func (c *Conn) handleOtherSet(stmt sql.IStatement, sql string) error { 89 | return c.handleExec(stmt, sql, false) 90 | } 91 | -------------------------------------------------------------------------------- /proxy/conn_show.go: -------------------------------------------------------------------------------- 1 | package proxy 2 | 3 | import ( 4 | "github.com/wangjild/go-mysql-proxy/hack" 5 | . "github.com/wangjild/go-mysql-proxy/mysql" 6 | "github.com/wangjild/go-mysql-proxy/sql" 7 | ) 8 | 9 | func (c *Conn) handleShow(strsql string, stmt sql.IShow) error { 10 | var err error 11 | 12 | switch stmt.(type) { 13 | case *sql.ShowDatabases: 14 | err = c.handleShowDatabases() 15 | default: 16 | err = c.handleSelect(stmt, strsql) 17 | } 18 | 19 | return err 20 | 21 | } 22 | 23 | func (c *Conn) handleShowDatabases() error { 24 | dbs := make([]interface{}, 0, len(c.server.schemas)) 25 | for key := range c.server.schemas { 26 | dbs = append(dbs, key) 27 | } 28 | 29 | if r, err := c.buildSimpleShowResultset(dbs, "Database"); err != nil { 30 | return err 31 | } else { 32 | return c.writeResultset(c.status, r) 33 | } 34 | } 35 | 36 | func (c *Conn) buildSimpleShowResultset(values []interface{}, name string) (*Resultset, error) { 37 | 38 | r := new(Resultset) 39 | 40 | field := &Field{} 41 | 42 | field.Name = hack.Slice(name) 43 | field.Charset = 33 44 | field.Type = MYSQL_TYPE_VAR_STRING 45 | 46 | r.Fields = []*Field{field} 47 | 48 | var row []byte 49 | var err error 50 | 51 | for _, value := range values { 52 | row, err = formatValue(value) 53 | if err != nil { 54 | return nil, err 55 | } 56 | r.RowDatas = append(r.RowDatas, 57 | 
PutLengthEncodedString(row)) 58 | } 59 | 60 | return r, nil 61 | } 62 | -------------------------------------------------------------------------------- /proxy/conn_stmt.go: -------------------------------------------------------------------------------- 1 | package proxy 2 | 3 | import ( 4 | "encoding/binary" 5 | "fmt" 6 | "github.com/wangjild/go-mysql-proxy/client" 7 | . "github.com/wangjild/go-mysql-proxy/log" 8 | . "github.com/wangjild/go-mysql-proxy/mysql" 9 | "github.com/wangjild/go-mysql-proxy/sql" 10 | "strconv" 11 | ) 12 | 13 | var paramFieldData []byte 14 | var columnFieldData []byte 15 | 16 | func init() { 17 | var p = &Field{Name: []byte("?")} 18 | var c = &Field{} 19 | 20 | paramFieldData = p.Dump() 21 | columnFieldData = c.Dump() 22 | } 23 | 24 | type Stmt struct { 25 | id uint32 26 | bid uint32 27 | 28 | params int 29 | types []byte 30 | columns int 31 | 32 | args []interface{} 33 | 34 | s sql.IStatement 35 | 36 | sqlstmt string 37 | 38 | cstmt *client.Stmt 39 | } 40 | 41 | func (s *Stmt) ClearParams() { 42 | s.args = make([]interface{}, s.params) 43 | } 44 | 45 | func (s *Stmt) Close() { 46 | s.cstmt.Close(true) 47 | } 48 | 49 | func (c *Conn) handleComStmtPrepare(sqlstmt string) error { 50 | if c.schema == nil { 51 | return NewDefaultError(ER_NO_DB_ERROR) 52 | } 53 | 54 | s := new(Stmt) 55 | 56 | var err error 57 | s.s, err = sql.Parse(sqlstmt) 58 | if err != nil { 59 | return fmt.Errorf(`prepare parse sql "%s" error`, sqlstmt) 60 | } 61 | 62 | s.sqlstmt = sqlstmt 63 | 64 | var co *client.SqlConn 65 | co, err = c.schema.node.getMasterConn() 66 | // TODO tablename for select 67 | if err != nil { 68 | return fmt.Errorf("prepare error %s", err) 69 | } 70 | 71 | if err = co.UseDB(c.schema.db); err != nil { 72 | co.Close() 73 | return fmt.Errorf("parepre error %s", err) 74 | } 75 | 76 | if t, err := co.Prepare(sqlstmt); err != nil { 77 | co.Close() 78 | return fmt.Errorf("parepre error %s", err) 79 | } else { 80 | s.params = t.ParamNum() 81 | s.types = 
make([]byte, 0, s.params*2) 82 | s.columns = t.ColumnNum() 83 | s.bid = t.ID() 84 | s.cstmt = t 85 | } 86 | 87 | s.id = c.stmtId 88 | c.stmtId++ 89 | 90 | if err = c.writePrepare(s); err != nil { 91 | return err 92 | } 93 | 94 | s.ClearParams() 95 | 96 | c.stmts[s.id] = s 97 | 98 | return nil 99 | } 100 | 101 | func (c *Conn) writePrepare(s *Stmt) error { 102 | data := make([]byte, 4, 128) 103 | 104 | //status ok 105 | data = append(data, 0) 106 | //stmt id 107 | data = append(data, Uint32ToBytes(s.id)...) 108 | //number columns 109 | data = append(data, Uint16ToBytes(uint16(s.columns))...) 110 | //number params 111 | data = append(data, Uint16ToBytes(uint16(s.params))...) 112 | //filter [00] 113 | data = append(data, 0) 114 | //warning count 115 | data = append(data, 0, 0) 116 | 117 | if err := c.writePacket(data); err != nil { 118 | return err 119 | } 120 | 121 | if s.params > 0 { 122 | for i := 0; i < s.params; i++ { 123 | data = data[0:4] 124 | data = append(data, []byte(s.cstmt.ParamDefs[i])...) 125 | 126 | if err := c.writePacket(data); err != nil { 127 | return err 128 | } 129 | } 130 | 131 | if err := c.writeEOF(c.status); err != nil { 132 | return err 133 | } 134 | } 135 | 136 | if s.columns > 0 { 137 | for i := 0; i < s.columns; i++ { 138 | data = data[0:4] 139 | data = append(data, []byte(s.cstmt.ColDefs[i])...) 
140 | 141 | if err := c.writePacket(data); err != nil { 142 | return err 143 | } 144 | } 145 | 146 | if err := c.writeEOF(c.status); err != nil { 147 | return err 148 | } 149 | 150 | } 151 | return nil 152 | } 153 | 154 | func (c *Conn) handleComStmtExecute(data []byte) error { 155 | if len(data) < 9 { 156 | AppLog.Warn("ErrMalFormPacket: length %d", len(data)) 157 | return ErrMalformPacket 158 | } 159 | 160 | pos := 0 161 | id := binary.LittleEndian.Uint32(data[0:4]) 162 | pos += 4 163 | 164 | s, ok := c.stmts[id] 165 | if !ok { 166 | return NewDefaultError(ER_UNKNOWN_STMT_HANDLER, 167 | strconv.FormatUint(uint64(id), 10), "stmt_execute") 168 | } 169 | 170 | flag := data[pos] 171 | pos++ 172 | 173 | //now we only support CURSOR_TYPE_NO_CURSOR flag 174 | if flag != 0 { 175 | return NewError(ER_UNKNOWN_ERROR, fmt.Sprintf("unsupported flag %d", flag)) 176 | } 177 | 178 | s.cstmt.SetAttr(flag) 179 | 180 | //skip iteration-count, always 1 181 | pos += 4 182 | 183 | st, isread := s.s.(sql.ISelect) 184 | if isread { 185 | isread = (!st.IsLocked()) 186 | } 187 | err := c.handleStmtExec(s, data[pos:], isread) 188 | 189 | s.ClearParams() 190 | 191 | return err 192 | } 193 | 194 | func (c *Conn) handleComStmtSendLongData(data []byte) error { 195 | if len(data) < 6 { 196 | AppLog.Warn("ErrMalFormPacket") 197 | return ErrMalformPacket 198 | } 199 | 200 | id := binary.LittleEndian.Uint32(data[0:4]) 201 | 202 | s, ok := c.stmts[id] 203 | if !ok { 204 | return NewDefaultError(ER_UNKNOWN_STMT_HANDLER, 205 | strconv.FormatUint(uint64(id), 10), "stmt_send_longdata") 206 | } 207 | 208 | paramId := binary.LittleEndian.Uint16(data[4:6]) 209 | if paramId >= uint16(s.params) { 210 | return NewDefaultError(ER_WRONG_ARGUMENTS, "stmt_send_longdata") 211 | } 212 | 213 | s.cstmt.SendLongData(paramId, data[6:]) 214 | return nil 215 | } 216 | 217 | func (c *Conn) handleComStmtReset(data []byte) error { 218 | if len(data) < 4 { 219 | AppLog.Warn("ErrMalFormPacket") 220 | return ErrMalformPacket 
221 | } 222 | 223 | id := binary.LittleEndian.Uint32(data[0:4]) 224 | 225 | s, ok := c.stmts[id] 226 | if !ok { 227 | return NewDefaultError(ER_UNKNOWN_STMT_HANDLER, 228 | strconv.FormatUint(uint64(id), 10), "stmt_reset") 229 | } 230 | 231 | if r, err := s.cstmt.Reset(); err != nil { 232 | return err 233 | } else { 234 | s.ClearParams() 235 | return c.writeOK(r) 236 | } 237 | } 238 | 239 | func (c *Conn) handleComStmtClose(data []byte) error { 240 | if len(data) < 4 { 241 | return nil 242 | } 243 | 244 | id := binary.LittleEndian.Uint32(data[0:4]) 245 | 246 | if cstmt, ok := c.stmts[id]; ok { 247 | cstmt.Close() 248 | } 249 | 250 | delete(c.stmts, id) 251 | 252 | return nil 253 | } 254 | 255 | // 256 | func (c *Conn) handleStmtExec(prepared *Stmt, data []byte, resultSet bool) error { 257 | 258 | res, err := prepared.cstmt.Execute(data) 259 | if err != nil { 260 | return err 261 | } 262 | 263 | if resultSet { 264 | err = c.mergeSelectResult(res) 265 | } else { 266 | err = c.writeOK(res) 267 | } 268 | return err 269 | } 270 | -------------------------------------------------------------------------------- /proxy/conn_stmt_test.go: -------------------------------------------------------------------------------- 1 | package proxy 2 | 3 | import ( 4 | "testing" 5 | ) 6 | 7 | func TestStmt_DropTable(t *testing.T) { 8 | server := newTestServer(t) 9 | n := server.nodes["node1"] 10 | c, err := n.getMasterConn() 11 | if err != nil { 12 | t.Fatal(err) 13 | } 14 | c.UseDB("go_proxy") 15 | if _, err := c.Execute(`drop table if exists go_proxy_test_proxy_stmt`); err != nil { 16 | t.Fatal(err) 17 | } 18 | c.Close() 19 | } 20 | 21 | func TestStmt_CreateTable(t *testing.T) { 22 | str := `CREATE TABLE IF NOT EXISTS go_proxy_test_proxy_stmt ( 23 | id BIGINT(64) UNSIGNED NOT NULL, 24 | str VARCHAR(256), 25 | f DOUBLE, 26 | e enum("test1", "test2"), 27 | u tinyint unsigned, 28 | i tinyint, 29 | PRIMARY KEY (id) 30 | ) ENGINE=InnoDB DEFAULT CHARSET=utf8` 31 | 32 | server := 
newTestServer(t) 33 | n := server.nodes["node1"] 34 | c, err := n.getMasterConn() 35 | if err != nil { 36 | t.Fatal(err) 37 | } 38 | 39 | c.UseDB("go_proxy") 40 | defer c.Close() 41 | if _, err := c.Execute(str); err != nil { 42 | t.Fatal(err) 43 | } 44 | } 45 | 46 | func TestStmt_Insert(t *testing.T) { 47 | str := `insert into go_proxy_test_proxy_stmt (id, str, f, e, u, i) values (?, ?, ?, ?, ?, ?)` 48 | 49 | c := newTestDBConn(t) 50 | defer c.Close() 51 | 52 | s, err := c.Prepare(str) 53 | 54 | if err != nil { 55 | t.Fatal(err) 56 | } 57 | 58 | if pkg, err := s.Execute(1, "a", 3.14, "test1", 255, -127); err != nil { 59 | t.Fatal(err) 60 | } else { 61 | if pkg.AffectedRows != 1 { 62 | t.Fatal(pkg.AffectedRows) 63 | } 64 | } 65 | 66 | s.Close() 67 | } 68 | 69 | func TestStmt_Select(t *testing.T) { 70 | str := `select str, f, e from go_proxy_test_proxy_stmt where id = ?` 71 | 72 | c := newTestDBConn(t) 73 | defer c.Close() 74 | 75 | s, err := c.Prepare(str) 76 | 77 | if err != nil { 78 | t.Fatal(err) 79 | } 80 | 81 | if result, err := s.Execute(1); err != nil { 82 | t.Fatal(err) 83 | } else { 84 | if len(result.Values) != 1 { 85 | t.Fatal(len(result.Values)) 86 | } 87 | 88 | if len(result.Fields) != 3 { 89 | t.Fatal(len(result.Fields)) 90 | } 91 | 92 | if str, _ := result.GetString(0, 0); str != "a" { 93 | t.Fatal("invalid str", str) 94 | } 95 | 96 | if f, _ := result.GetFloat(0, 1); f != float64(3.14) { 97 | t.Fatal("invalid f", f) 98 | } 99 | 100 | if e, _ := result.GetString(0, 2); e != "test1" { 101 | t.Fatal("invalid e", e) 102 | } 103 | 104 | if str, _ := result.GetStringByName(0, "str"); str != "a" { 105 | t.Fatal("invalid str", str) 106 | } 107 | 108 | if f, _ := result.GetFloatByName(0, "f"); f != float64(3.14) { 109 | t.Fatal("invalid f", f) 110 | } 111 | 112 | if e, _ := result.GetStringByName(0, "e"); e != "test1" { 113 | t.Fatal("invalid e", e) 114 | } 115 | 116 | } 117 | 118 | s.Close() 119 | } 120 | 121 | func TestStmt_NULL(t *testing.T) { 122 | str 
:= `insert into go_proxy_test_proxy_stmt (id, str, f, e) values (?, ?, ?, ?)` 123 | 124 | c := newTestDBConn(t) 125 | defer c.Close() 126 | 127 | s, err := c.Prepare(str) 128 | 129 | if err != nil { 130 | t.Fatal(err) 131 | } 132 | 133 | if pkg, err := s.Execute(2, nil, 3.14, nil); err != nil { 134 | t.Fatal(err) 135 | } else { 136 | if pkg.AffectedRows != 1 { 137 | t.Fatal(pkg.AffectedRows) 138 | } 139 | } 140 | 141 | s.Close() 142 | 143 | str = `select * from go_proxy_test_proxy_stmt where id = ?` 144 | s, err = c.Prepare(str) 145 | 146 | if err != nil { 147 | t.Fatal(err) 148 | } 149 | 150 | if r, err := s.Execute(2); err != nil { 151 | t.Fatal(err) 152 | } else { 153 | if b, err := r.IsNullByName(0, "id"); err != nil { 154 | t.Fatal(err) 155 | } else if b == true { 156 | t.Fatal(b) 157 | } 158 | 159 | if b, err := r.IsNullByName(0, "str"); err != nil { 160 | t.Fatal(err) 161 | } else if b == false { 162 | t.Fatal(b) 163 | } 164 | 165 | if b, err := r.IsNullByName(0, "f"); err != nil { 166 | t.Fatal(err) 167 | } else if b == true { 168 | t.Fatal(b) 169 | } 170 | 171 | if b, err := r.IsNullByName(0, "e"); err != nil { 172 | t.Fatal(err) 173 | } else if b == false { 174 | t.Fatal(b) 175 | } 176 | } 177 | 178 | s.Close() 179 | } 180 | 181 | func TestStmt_Unsigned(t *testing.T) { 182 | str := `insert into go_proxy_test_proxy_stmt (id, u) values (?, ?)` 183 | 184 | c := newTestDBConn(t) 185 | defer c.Close() 186 | 187 | s, err := c.Prepare(str) 188 | 189 | if err != nil { 190 | t.Fatal(err) 191 | } 192 | 193 | if pkg, err := s.Execute(3, uint8(255)); err != nil { 194 | t.Fatal(err) 195 | } else { 196 | if pkg.AffectedRows != 1 { 197 | t.Fatal(pkg.AffectedRows) 198 | } 199 | } 200 | 201 | s.Close() 202 | 203 | str = `select u from go_proxy_test_proxy_stmt where id = ?` 204 | 205 | s, err = c.Prepare(str) 206 | if err != nil { 207 | t.Fatal(err) 208 | } 209 | 210 | if r, err := s.Execute(3); err != nil { 211 | t.Fatal(err) 212 | } else { 213 | if u, err := r.GetUint(0, 
0); err != nil { 214 | t.Fatal(err) 215 | } else if u != uint64(255) { 216 | t.Fatal(u) 217 | } 218 | } 219 | 220 | s.Close() 221 | } 222 | 223 | func TestStmt_Signed(t *testing.T) { 224 | str := `insert into go_proxy_test_proxy_stmt (id, i) values (?, ?)` 225 | 226 | c := newTestDBConn(t) 227 | defer c.Close() 228 | 229 | s, err := c.Prepare(str) 230 | 231 | if err != nil { 232 | t.Fatal(err) 233 | } 234 | 235 | if _, err := s.Execute(4, 127); err != nil { 236 | t.Fatal(err) 237 | } 238 | 239 | if _, err := s.Execute(uint64(18446744073709551516), int8(-128)); err != nil { 240 | t.Fatal(err) 241 | } 242 | 243 | s.Close() 244 | 245 | } 246 | 247 | func TestStmt_Trans(t *testing.T) { 248 | c1 := newTestDBConn(t) 249 | defer c1.Close() 250 | 251 | if _, err := c1.Execute(`insert into go_proxy_test_proxy_stmt (id, str) values (1002, "abc")`); err != nil { 252 | t.Fatal(err) 253 | } 254 | 255 | var err error 256 | if err = c1.Begin(); err != nil { 257 | t.Fatal(err) 258 | } 259 | 260 | str := `select str from go_proxy_test_proxy_stmt where id = ?` 261 | 262 | s, err := c1.Prepare(str) 263 | if err != nil { 264 | t.Fatal(err) 265 | } 266 | 267 | if _, err := s.Execute(1002); err != nil { 268 | t.Fatal(err) 269 | } 270 | 271 | if err := c1.Commit(); err != nil { 272 | t.Fatal(err) 273 | } 274 | 275 | if r, err := s.Execute(1002); err != nil { 276 | t.Fatal(err) 277 | } else { 278 | if str, _ := r.GetString(0, 0); str != `abc` { 279 | t.Fatal(str) 280 | } 281 | } 282 | 283 | if err := s.Close(); err != nil { 284 | t.Fatal(err) 285 | } 286 | } 287 | -------------------------------------------------------------------------------- /proxy/conn_tx.go: -------------------------------------------------------------------------------- 1 | package proxy 2 | 3 | import ( 4 | "github.com/wangjild/go-mysql-proxy/client" 5 | . 
"github.com/wangjild/go-mysql-proxy/mysql" 6 | ) 7 | 8 | func (c *Conn) isInTransaction() bool { 9 | return c.status&SERVER_STATUS_IN_TRANS > 0 10 | } 11 | 12 | func (c *Conn) isAutoCommit() bool { 13 | return c.status&SERVER_STATUS_AUTOCOMMIT > 0 14 | } 15 | 16 | func (c *Conn) handleBegin() error { 17 | c.status |= SERVER_STATUS_IN_TRANS 18 | return c.writeOK(nil) 19 | } 20 | 21 | func (c *Conn) handleCommit() (err error) { 22 | if err := c.commit(); err != nil { 23 | return err 24 | } else { 25 | return c.writeOK(nil) 26 | } 27 | } 28 | 29 | func (c *Conn) handleRollback() (err error) { 30 | if err := c.rollback(); err != nil { 31 | return err 32 | } 33 | 34 | return c.writeOK(nil) 35 | } 36 | 37 | func (c *Conn) commit() (err error) { 38 | c.status &= ^SERVER_STATUS_IN_TRANS 39 | 40 | for _, co := range c.txConns { 41 | if e := co.Commit(); e != nil { 42 | err = e 43 | } 44 | co.Close() 45 | } 46 | 47 | c.txConns = map[*Node]*client.SqlConn{} 48 | 49 | return 50 | } 51 | 52 | func (c *Conn) rollback() (err error) { 53 | c.status &= ^SERVER_STATUS_IN_TRANS 54 | 55 | for _, co := range c.txConns { 56 | if e := co.Rollback(); e != nil { 57 | err = e 58 | } 59 | co.Close() 60 | } 61 | 62 | c.txConns = map[*Node]*client.SqlConn{} 63 | 64 | return 65 | } 66 | 67 | //if status is in_trans, need 68 | //else if status is not autocommit, need 69 | //else no need 70 | func (c *Conn) needBeginTx() bool { 71 | return c.isInTransaction() || !c.isAutoCommit() 72 | } 73 | -------------------------------------------------------------------------------- /proxy/node.go: -------------------------------------------------------------------------------- 1 | package proxy 2 | 3 | import ( 4 | "fmt" 5 | "github.com/wangjild/go-mysql-proxy/client" 6 | "github.com/wangjild/go-mysql-proxy/config" 7 | . 
"github.com/wangjild/go-mysql-proxy/log" 8 | "sync" 9 | "time" 10 | ) 11 | 12 | const ( 13 | Master = "master" 14 | Slave = "slave" 15 | ) 16 | 17 | type Node struct { 18 | sync.Mutex 19 | 20 | server *Server 21 | 22 | cfg config.NodeConfig 23 | 24 | //running master db 25 | db *client.DB 26 | 27 | master *client.DB 28 | slave *client.DB 29 | 30 | downAfterNoAlive time.Duration 31 | 32 | lastMasterPing int64 33 | lastSlavePing int64 34 | } 35 | 36 | func (n *Node) run() { 37 | //to do 38 | //1 check connection alive 39 | //2 check remove mysql server alive 40 | 41 | t := time.NewTicker(3000 * time.Second) 42 | defer t.Stop() 43 | 44 | n.lastMasterPing = time.Now().Unix() 45 | n.lastSlavePing = n.lastMasterPing 46 | for { 47 | select { 48 | case <-t.C: 49 | n.checkMaster() 50 | n.checkSlave() 51 | } 52 | } 53 | } 54 | 55 | func (n *Node) String() string { 56 | return n.cfg.Name 57 | } 58 | 59 | func (n *Node) getMasterConn() (*client.SqlConn, error) { 60 | n.Lock() 61 | db := n.db 62 | n.Unlock() 63 | 64 | if db == nil { 65 | return nil, fmt.Errorf("master is down") 66 | } 67 | 68 | return db.GetConn() 69 | } 70 | 71 | func (n *Node) getSelectConn() (*client.SqlConn, error) { 72 | var db *client.DB 73 | 74 | n.Lock() 75 | if n.cfg.RWSplit && n.slave != nil { 76 | db = n.slave 77 | } else { 78 | db = n.db 79 | } 80 | n.Unlock() 81 | 82 | if db == nil { 83 | return nil, fmt.Errorf("no alive mysql server") 84 | } 85 | 86 | return db.GetConn() 87 | } 88 | 89 | func (n *Node) checkMaster() { 90 | n.Lock() 91 | db := n.db 92 | n.Unlock() 93 | 94 | if db == nil { 95 | AppLog.Notice("no master avaliable") 96 | return 97 | } 98 | 99 | if err := db.Ping(); err != nil { 100 | AppLog.Warn("%s ping master %s error %s", n, db.Addr(), err.Error()) 101 | } else { 102 | n.lastMasterPing = time.Now().Unix() 103 | return 104 | } 105 | 106 | if int64(n.downAfterNoAlive) > 0 && time.Now().Unix()-n.lastMasterPing > int64(n.downAfterNoAlive) { 107 | AppLog.Warn("%s down master db %s", n, 
n.master.Addr()) 108 | 109 | n.downMaster() 110 | } 111 | } 112 | 113 | func (n *Node) checkSlave() { 114 | if n.slave == nil { 115 | return 116 | } 117 | 118 | db := n.slave 119 | if err := db.Ping(); err != nil { 120 | AppLog.Warn("%s ping slave %s error %s", n, db.Addr(), err.Error()) 121 | } else { 122 | n.lastSlavePing = time.Now().Unix() 123 | } 124 | 125 | if int64(n.downAfterNoAlive) > 0 && time.Now().Unix()-n.lastSlavePing > int64(n.downAfterNoAlive) { 126 | AppLog.Warn("%s slave db %s not alive over %ds, down it", 127 | n, db.Addr(), int64(n.downAfterNoAlive/time.Second)) 128 | 129 | n.downSlave() 130 | } 131 | } 132 | 133 | func (n *Node) openDB(addr string) (*client.DB, error) { 134 | db, err := client.Open(addr, n.cfg.User, n.cfg.Password, "") 135 | if err != nil { 136 | return nil, err 137 | } 138 | 139 | db.SetMaxIdleConnNum(n.cfg.IdleConns) 140 | return db, nil 141 | } 142 | 143 | func (n *Node) checkUpDB(addr string) (*client.DB, error) { 144 | db, err := n.openDB(addr) 145 | if err != nil { 146 | return nil, err 147 | } 148 | 149 | if err := db.Ping(); err != nil { 150 | db.Close() 151 | return nil, err 152 | } 153 | 154 | return db, nil 155 | } 156 | 157 | func (n *Node) upMaster(addr string) error { 158 | n.Lock() 159 | if n.master != nil { 160 | n.Unlock() 161 | return fmt.Errorf("%s master must be down first", n) 162 | } 163 | n.Unlock() 164 | 165 | db, err := n.checkUpDB(addr) 166 | if err != nil { 167 | return err 168 | } 169 | 170 | n.Lock() 171 | n.master = db 172 | n.db = db 173 | n.Unlock() 174 | 175 | return nil 176 | } 177 | 178 | func (n *Node) upSlave(addr string) error { 179 | n.Lock() 180 | if n.slave != nil { 181 | n.Unlock() 182 | return fmt.Errorf("%s, slave must be down first", n) 183 | } 184 | n.Unlock() 185 | 186 | db, err := n.checkUpDB(addr) 187 | if err != nil { 188 | return err 189 | } 190 | 191 | n.Lock() 192 | n.slave = db 193 | n.Unlock() 194 | 195 | return nil 196 | } 197 | 198 | func (n *Node) downMaster() error { 
199 | n.Lock() 200 | if n.master != nil { 201 | n.master = nil 202 | } 203 | return nil 204 | } 205 | 206 | func (n *Node) downSlave() error { 207 | n.Lock() 208 | db := n.slave 209 | n.slave = nil 210 | n.Unlock() 211 | 212 | if db != nil { 213 | db.Close() 214 | } 215 | 216 | return nil 217 | } 218 | 219 | func (s *Server) UpMaster(node string, addr string) error { 220 | n := s.getNode(node) 221 | if n == nil { 222 | return fmt.Errorf("invalid node %s", node) 223 | } 224 | 225 | return n.upMaster(addr) 226 | } 227 | 228 | func (s *Server) UpSlave(node string, addr string) error { 229 | n := s.getNode(node) 230 | if n == nil { 231 | return fmt.Errorf("invalid node %s", node) 232 | } 233 | 234 | return n.upSlave(addr) 235 | } 236 | func (s *Server) DownMaster(node string) error { 237 | n := s.getNode(node) 238 | if n == nil { 239 | return fmt.Errorf("invalid node %s", node) 240 | } 241 | n.db = nil 242 | return n.downMaster() 243 | } 244 | 245 | func (s *Server) DownSlave(node string) error { 246 | n := s.getNode(node) 247 | if n == nil { 248 | return fmt.Errorf("invalid node [%s].", node) 249 | } 250 | return n.downSlave() 251 | } 252 | 253 | func (s *Server) getNode(name string) *Node { 254 | return s.nodes[name] 255 | } 256 | 257 | func (s *Server) parseNodes() error { 258 | cfg := s.cfg 259 | s.nodes = make(map[string]*Node, len(cfg.Nodes)) 260 | 261 | for _, v := range cfg.Nodes { 262 | if _, ok := s.nodes[v.Name]; ok { 263 | return fmt.Errorf("duplicate node [%s].", v.Name) 264 | } 265 | 266 | n, err := s.parseNode(v) 267 | if err != nil { 268 | return err 269 | } 270 | 271 | s.nodes[v.Name] = n 272 | } 273 | 274 | return nil 275 | } 276 | 277 | func (s *Server) parseNode(cfg config.NodeConfig) (*Node, error) { 278 | n := new(Node) 279 | n.server = s 280 | n.cfg = cfg 281 | 282 | n.downAfterNoAlive = time.Duration(cfg.DownAfterNoAlive) * time.Second 283 | 284 | if len(cfg.Master) == 0 { 285 | return nil, fmt.Errorf("must setting master MySQL node.") 286 | } 
287 | 288 | var err error 289 | if n.master, err = n.openDB(cfg.Master); err != nil { 290 | return nil, err 291 | } 292 | 293 | n.db = n.master 294 | 295 | if len(cfg.Slave) > 0 { 296 | if n.slave, err = n.openDB(cfg.Slave); err != nil { 297 | AppLog.Warn(err.Error()) 298 | n.slave = nil 299 | } 300 | } 301 | 302 | go n.run() 303 | 304 | return n, nil 305 | } 306 | -------------------------------------------------------------------------------- /proxy/schema.go: -------------------------------------------------------------------------------- 1 | package proxy 2 | 3 | import ( 4 | "fmt" 5 | "github.com/wangjild/go-mysql-proxy/config" 6 | ) 7 | 8 | type Schema struct { 9 | db string 10 | node *Node 11 | auths map[string]string 12 | } 13 | 14 | func (s *Server) parseSchemas() error { 15 | s.schemas = make(map[string]*Schema) 16 | 17 | for _, schemaCfg := range s.cfg.Schemas { 18 | if _, ok := s.schemas[schemaCfg.DB]; ok { 19 | return fmt.Errorf("duplicate schema [%s].", schemaCfg.DB) 20 | } 21 | 22 | n := s.getNode(schemaCfg.Node) 23 | if n == nil { 24 | return fmt.Errorf("schema [%s] node [%s] config is not exists.", schemaCfg.DB, schemaCfg.Node) 25 | } 26 | 27 | auths, err := s.getAuths(schemaCfg) 28 | if err != nil { 29 | return err 30 | } 31 | 32 | s.schemas[schemaCfg.DB] = &Schema{ 33 | db: schemaCfg.DB, 34 | node: n, 35 | auths: auths, 36 | } 37 | } 38 | 39 | return nil 40 | } 41 | 42 | func (s *Server) getAuths(schema config.SchemaConfig) (map[string]string, error) { 43 | if len(schema.Auths) == 0 { 44 | return nil, fmt.Errorf("schema [%s]'s auth is empty.", schema.DB) 45 | } 46 | 47 | auth := make(map[string]string) 48 | 49 | for _, v := range schema.Auths { 50 | if _, ok := auth[v.User]; ok { 51 | return nil, fmt.Errorf("schema [%s] has duplicate user[%s]", schema.DB, v.User) 52 | } 53 | 54 | auth[v.User] = v.Passwd 55 | } 56 | 57 | return auth, nil 58 | } 59 | 60 | func (s *Server) getSchema(db string) *Schema { 61 | return s.schemas[db] 62 | } 63 | 
-------------------------------------------------------------------------------- /proxy/server.go: -------------------------------------------------------------------------------- 1 | package proxy 2 | 3 | import ( 4 | "github.com/wangjild/go-mysql-proxy/config" 5 | . "github.com/wangjild/go-mysql-proxy/log" 6 | "net" 7 | // "runtime" 8 | ) 9 | 10 | type Server struct { 11 | cfg *config.Config 12 | 13 | addr string 14 | user string 15 | password string 16 | 17 | running bool 18 | 19 | listener net.Listener 20 | 21 | nodes map[string]*Node 22 | 23 | schemas map[string]*Schema 24 | 25 | users *userAuth 26 | } 27 | 28 | func NewServer(cfg *config.Config) (*Server, error) { 29 | s := new(Server) 30 | 31 | s.cfg = cfg 32 | 33 | s.addr = cfg.Addr 34 | s.user = cfg.User 35 | s.password = cfg.Password 36 | 37 | if err := s.parseNodes(); err != nil { 38 | return nil, err 39 | } 40 | 41 | if err := s.parseSchemas(); err != nil { 42 | return nil, err 43 | } 44 | 45 | if err := s.parseUserAuths(); err != nil { 46 | return nil, err 47 | } 48 | 49 | var err error 50 | s.listener, err = net.Listen("tcp4", s.addr) 51 | if err != nil { 52 | return nil, err 53 | } 54 | 55 | SysLog.Notice("Go-MySQL-Proxy Listen(tcp4) at [%s]", s.addr) 56 | return s, nil 57 | } 58 | 59 | func (s *Server) Run() error { 60 | s.running = true 61 | 62 | for s.running { 63 | conn, err := s.listener.Accept() 64 | if err != nil { 65 | SysLog.Warn("accept error %s", err.Error()) 66 | continue 67 | } 68 | 69 | go s.onConn(conn) 70 | } 71 | 72 | return nil 73 | } 74 | 75 | func (s *Server) Close() { 76 | s.running = false 77 | if s.listener != nil { 78 | s.listener.Close() 79 | } 80 | } 81 | 82 | func (s *Server) onConn(c net.Conn) { 83 | conn := s.newConn(c) 84 | 85 | defer func() { 86 | /*if err := recover(); err != nil { 87 | const size = 4096 88 | buf := make([]byte, size) 89 | buf = buf[:runtime.Stack(buf, false)] 90 | AppLog.Warn("onConn panic %v: %v\n%s", c.RemoteAddr().String(), err, buf) 91 | }*/ 92 | 
93 | conn.Close() 94 | }() 95 | 96 | if err := conn.Handshake(); err != nil { 97 | AppLog.Warn("handshake error %s", err.Error()) 98 | c.Close() 99 | return 100 | } 101 | 102 | conn.Run() 103 | } 104 | -------------------------------------------------------------------------------- /proxy/server_test.go: -------------------------------------------------------------------------------- 1 | package proxy 2 | 3 | import ( 4 | "github.com/wangjild/go-mysql-proxy/client" 5 | "github.com/wangjild/go-mysql-proxy/config" 6 | "sync" 7 | "testing" 8 | "time" 9 | ) 10 | 11 | var testServerOnce sync.Once 12 | var testServer *Server 13 | var testDBOnce sync.Once 14 | var testDB *client.DB 15 | 16 | var testConfigData = []byte(` 17 | addr : 127.0.0.1:4000 18 | user : root 19 | password : root 20 | 21 | nodes : 22 | - 23 | name : node1 24 | down_after_noalive : 300 25 | idle_conns : 16 26 | rw_split: false 27 | user: root 28 | password: 29 | master : 127.0.0.1:4306 30 | slave : 31 | - 32 | name : node2 33 | down_after_noalive : 300 34 | idle_conns : 16 35 | rw_split: false 36 | user: root 37 | password: 38 | master : 127.0.0.1:4307 39 | 40 | - 41 | name : node3 42 | down_after_noalive : 300 43 | idle_conns : 16 44 | rw_split: false 45 | user: root 46 | password: 47 | master : 127.0.0.1:4308 48 | 49 | schemas : 50 | - 51 | db : proxy_test 52 | nodes: [node1, node2, node3] 53 | rules: 54 | default: node1 55 | shard: 56 | - 57 | table: proxy_test_shard_hash 58 | key: id 59 | nodes: [node2, node3] 60 | type: hash 61 | 62 | - 63 | table: proxy_test_shard_range 64 | key: id 65 | nodes: [node2, node3] 66 | range: -10000- 67 | type: range 68 | `) 69 | 70 | func newTestServer(t *testing.T) *Server { 71 | f := func() { 72 | cfg, err := config.ParseConfigData(testConfigData) 73 | if err != nil { 74 | t.Fatal(err.Error()) 75 | } 76 | 77 | testServer, err = NewServer(cfg) 78 | if err != nil { 79 | t.Fatal(err) 80 | } 81 | 82 | go testServer.Run() 83 | 84 | time.Sleep(1 * time.Second) 85 | } 86 
| 87 | testServerOnce.Do(f) 88 | 89 | return testServer 90 | } 91 | 92 | func newTestDB(t *testing.T) *client.DB { 93 | newTestServer(t) 94 | 95 | f := func() { 96 | var err error 97 | testDB, err = client.Open("127.0.0.1:4000", "root", "", "go_proxy") 98 | 99 | if err != nil { 100 | t.Fatal(err) 101 | } 102 | 103 | testDB.SetMaxIdleConnNum(4) 104 | } 105 | 106 | testDBOnce.Do(f) 107 | return testDB 108 | } 109 | 110 | func newTestDBConn(t *testing.T) *client.SqlConn { 111 | db := newTestDB(t) 112 | 113 | c, err := db.GetConn() 114 | 115 | if err != nil { 116 | t.Fatal(err) 117 | } 118 | 119 | return c 120 | } 121 | 122 | func TestServer(t *testing.T) { 123 | newTestServer(t) 124 | } 125 | -------------------------------------------------------------------------------- /proxy/signal.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) All Rights Reserved 2 | // @file signal.go 3 | // @author 王靖 (wangjild@gmail.com) 4 | // @date 14-11-26 15:36:24 5 | // @version $Revision: 1.0 $ 6 | // @brief 7 | 8 | package proxy 9 | 10 | import ( 11 | . 
"github.com/wangjild/go-mysql-proxy/log" 12 | "os" 13 | ) 14 | 15 | type SignalHandler func(s os.Signal, arg interface{}) error 16 | 17 | type SignalSet struct { 18 | M map[os.Signal]SignalHandler 19 | } 20 | 21 | func NewSignalSet() *SignalSet { 22 | s := new(SignalSet) 23 | s.M = make(map[os.Signal]SignalHandler) 24 | return s 25 | } 26 | 27 | func (s *SignalSet) Register(sig os.Signal, handler SignalHandler) { 28 | if _, exist := s.M[sig]; !exist { 29 | s.M[sig] = handler 30 | } 31 | } 32 | 33 | func (s *SignalSet) Handle(sig os.Signal, arg interface{}) error { 34 | if handler, exist := s.M[sig]; exist { 35 | return handler(sig, arg) 36 | } else { 37 | SysLog.Warn("no available handler for signal %v, ignore!", sig) 38 | return nil 39 | } 40 | } 41 | 42 | func init() { 43 | 44 | } 45 | 46 | /* vim: set expandtab ts=4 sw=4 */ 47 | -------------------------------------------------------------------------------- /run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | GOGC=1000 ./cmd/proxy/proxy --config=etc/proxy_single.yaml --logfile=log/proxy.log --loglevel=0 -------------------------------------------------------------------------------- /sql/.gitignore: -------------------------------------------------------------------------------- 1 | *.output 2 | -------------------------------------------------------------------------------- /sql/Makefile: -------------------------------------------------------------------------------- 1 | # Copyright 2012, Google Inc. All rights reserved. 2 | # Use of this source code is governed by a BSD-style license that can 3 | # be found in the LICENSE file. 
4 | 5 | # MAKEFLAGS = -s 6 | 7 | sql.go: sql_yacc.yy 8 | ./bin/yacc -o sql_yacc.go -p MySQL sql_yacc.yy 9 | gofmt -w sql_yacc.go 10 | 11 | clean: 12 | rm -f y.output sql_yacc.go 13 | -------------------------------------------------------------------------------- /sql/ast.go: -------------------------------------------------------------------------------- 1 | package sql 2 | 3 | type IStatement interface { 4 | IStatement() 5 | } 6 | 7 | func SetParseTree(yylex interface{}, stmt IStatement) { 8 | yylex.(*SQLLexer).ParseTree = stmt 9 | } 10 | -------------------------------------------------------------------------------- /sql/ast_alter.go: -------------------------------------------------------------------------------- 1 | package sql 2 | 3 | type AlterTable struct { 4 | Table ISimpleTable 5 | } 6 | 7 | func (*AlterTable) IStatement() {} 8 | func (*AlterTable) IDDLStatement() {} 9 | 10 | type AlterDatabase struct { 11 | Schema []byte 12 | } 13 | 14 | func (*AlterDatabase) IStatement() {} 15 | func (*AlterDatabase) IDDLStatement() {} 16 | 17 | type AlterProcedure struct { 18 | Procedure *Spname 19 | } 20 | 21 | func (*AlterProcedure) IStatement() {} 22 | func (*AlterProcedure) IDDLStatement() {} 23 | 24 | type AlterFunction struct { 25 | Function *Spname 26 | } 27 | 28 | func (*AlterFunction) IStatement() {} 29 | func (*AlterFunction) IDDLStatement() {} 30 | 31 | /************************* 32 | * Alter View Statement 33 | *************************/ 34 | func (*AlterView) IStatement() {} 35 | func (*AlterView) IDDLStatement() {} 36 | 37 | type AlterView struct { 38 | View ISimpleTable 39 | As ISelect 40 | } 41 | 42 | type viewTail struct { 43 | View ISimpleTable 44 | As ISelect 45 | } 46 | 47 | func (av *AlterView) GetSchemas() []string { 48 | d := av.View.GetSchemas() 49 | p := av.As.GetSchemas() 50 | if d != nil && p != nil { 51 | d = append(d, p...) 
52 | } 53 | 54 | return d 55 | } 56 | 57 | /************************* 58 | * Alter Event Statement 59 | *************************/ 60 | func (*AlterEvent) IStatement() {} 61 | func (*AlterEvent) IDDLStatement() {} 62 | func (*AlterEvent) HasDDLSchemas() {} 63 | func (a *AlterEvent) GetSchemas() []string { 64 | if a.Rename == nil { 65 | return a.Event.GetSchemas() 66 | } 67 | 68 | return GetSchemas(a.Event.GetSchemas(), a.Rename.GetSchemas()) 69 | } 70 | 71 | type AlterEvent struct { 72 | Event *Spname 73 | Rename *Spname 74 | } 75 | 76 | type AlterTablespace struct{} 77 | 78 | func (*AlterTablespace) IStatement() {} 79 | func (*AlterTablespace) IDDLStatement() {} 80 | 81 | type AlterLogfile struct{} 82 | 83 | func (*AlterLogfile) IStatement() {} 84 | func (*AlterLogfile) IDDLStatement() {} 85 | 86 | type AlterServer struct{} 87 | 88 | func (*AlterServer) IStatement() {} 89 | func (*AlterServer) IDDLStatement() {} 90 | -------------------------------------------------------------------------------- /sql/ast_compound.go: -------------------------------------------------------------------------------- 1 | package sql 2 | 3 | type Signal struct{} 4 | 5 | func (*Signal) IStatement() {} 6 | 7 | type Resignal struct{} 8 | 9 | func (*Resignal) IStatement() {} 10 | 11 | type Diagnostics struct{} 12 | 13 | func (*Diagnostics) IStatement() {} 14 | -------------------------------------------------------------------------------- /sql/ast_create.go: -------------------------------------------------------------------------------- 1 | package sql 2 | 3 | func (*CreateTable) IStatement() {} 4 | func (*CreateTable) IDDLStatement() {} 5 | func (*CreateTable) HasDDLSchemas() {} 6 | 7 | func (c *CreateTable) GetSchemas() []string { 8 | return c.Table.GetSchemas() 9 | } 10 | 11 | type CreateTable struct { 12 | Table ISimpleTable 13 | } 14 | 15 | func (*CreateIndex) IStatement() {} 16 | func (*CreateIndex) IDDLStatement() {} 17 | 18 | type CreateIndex struct{} 19 | 20 | 
// GetSchemas passes the schemas reported by the view being created and the
// schemas reported by its defining select to the package-level GetSchemas
// helper and returns that helper's result.
// NOTE(review): the helper is defined outside this file; how it combines the
// two lists (ordering, de-duplication) should be confirmed there.
func (c *CreateView) GetSchemas() []string {
	return GetSchemas(c.View.GetSchemas(), c.As.GetSchemas())
}
{} 95 | 96 | type CreateFunction struct { 97 | Function ISimpleTable 98 | } 99 | type sfTail struct { 100 | Function ISimpleTable 101 | } 102 | 103 | func (c *CreateFunction) GetSchemas() []string { 104 | return c.Function.GetSchemas() 105 | } 106 | 107 | func (*CreateTrigger) IStatement() {} 108 | func (*CreateTrigger) IDDLStatement() {} 109 | func (*CreateTrigger) HasDDLSchemas() {} 110 | 111 | type CreateTrigger struct { 112 | Trigger ISimpleTable 113 | } 114 | type triggerTail struct { 115 | Trigger ISimpleTable 116 | } 117 | 118 | func (c *CreateTrigger) GetSchemas() []string { 119 | return c.Trigger.GetSchemas() 120 | } 121 | -------------------------------------------------------------------------------- /sql/ast_dal.go: -------------------------------------------------------------------------------- 1 | package sql 2 | 3 | func (*Set) IStatement() {} 4 | 5 | type Set struct { 6 | VarList Vars 7 | } 8 | 9 | type Vars []*Variable 10 | 11 | type Variable struct { 12 | Type VarType 13 | Life LifeType 14 | Name string 15 | Value IExpr 16 | } 17 | 18 | type VarType int 19 | type LifeType int 20 | 21 | const ( 22 | Type_Sys = 1 23 | Type_Usr = 2 24 | 25 | Life_Unknown = 0 26 | Life_Global = 1 27 | Life_Local = 2 28 | Life_Session = 3 29 | ) 30 | 31 | type IAccountMgrStmt interface { 32 | IsAccountMgrStmt() 33 | IStatement 34 | } 35 | 36 | type Partition struct{} 37 | 38 | func (*Partition) IStatement() {} 39 | 40 | /******************************* 41 | * Table Maintenance Statements 42 | ******************************/ 43 | type ITableMtStmt interface { 44 | IStatement 45 | IsTableMtStmt() 46 | GetSchemas() []string 47 | } 48 | 49 | func (*Check) IStatement() {} 50 | func (*Check) IsTableMtStmt() {} 51 | func (*CheckSum) IStatement() {} 52 | func (*CheckSum) IsTableMtStmt() {} 53 | func (*Repair) IStatement() {} 54 | func (*Repair) IsTableMtStmt() {} 55 | func (*Analyze) IStatement() {} 56 | func (*Analyze) IsTableMtStmt() {} 57 | func (*Optimize) IStatement() {} 
58 | func (*Optimize) IsTableMtStmt() {} 59 | 60 | func (c *Check) GetSchemas() []string { 61 | return c.Tables.GetSchemas() 62 | } 63 | 64 | func (c *CheckSum) GetSchemas() []string { 65 | return c.Tables.GetSchemas() 66 | } 67 | 68 | func (r *Repair) GetSchemas() []string { 69 | return r.Tables.GetSchemas() 70 | } 71 | 72 | func (a *Analyze) GetSchemas() []string { 73 | return a.Tables.GetSchemas() 74 | } 75 | 76 | func (o *Optimize) GetSchemas() []string { 77 | return o.Tables.GetSchemas() 78 | } 79 | 80 | type Check struct { 81 | Tables ISimpleTables 82 | } 83 | 84 | type CheckSum struct { 85 | Tables ISimpleTables 86 | } 87 | 88 | type Repair struct { 89 | Tables ISimpleTables 90 | } 91 | 92 | type Analyze struct { 93 | Tables ISimpleTables 94 | } 95 | 96 | type Optimize struct { 97 | Tables ISimpleTables 98 | } 99 | 100 | /**************************** 101 | * Cache Index Statement 102 | ***************************/ 103 | func (*CacheIndex) IStatement() {} 104 | 105 | type CacheIndex struct { 106 | TableIndexList TableIndexes 107 | } 108 | 109 | func (c *CacheIndex) GetSchemas() []string { 110 | if c.TableIndexList == nil || len(c.TableIndexList) == 0 { 111 | return nil 112 | } 113 | return c.TableIndexList.GetSchemas() 114 | } 115 | 116 | func (*LoadIndex) IStatement() {} 117 | 118 | type LoadIndex struct { 119 | TableIndexList TableIndexes 120 | } 121 | 122 | func (l *LoadIndex) GetSchemas() []string { 123 | if l.TableIndexList == nil || len(l.TableIndexList) == 0 { 124 | return nil 125 | } 126 | return l.TableIndexList.GetSchemas() 127 | } 128 | 129 | type TableIndexes []*TableIndex 130 | 131 | func (tis TableIndexes) GetSchemas() []string { 132 | var rt []string 133 | for _, v := range tis { 134 | if v == nil { 135 | continue 136 | } 137 | 138 | if r := v.Table.GetSchemas(); r != nil && len(r) != 0 { 139 | rt = append(rt, r...) 
// IStatement marks SetPassword as a SQL statement node.
func (*SetPassword) IStatement() {}

// IsAccountStmt is the marker this type originally carried. It is kept so
// that any existing caller relying on it keeps compiling.
func (*SetPassword) IsAccountStmt() {}

// IsAccountMgrStmt marks SetPassword as an account management statement.
// IAccountMgrStmt requires this method, and the other account statements in
// this file (Grant, RenameUser, Revoke, CreateUser, AlterUser, DropUser) all
// provide it; SetPassword only had the differently named IsAccountStmt and
// therefore did not satisfy the interface.
func (*SetPassword) IsAccountMgrStmt() {}

// SetPassword is the AST node for a SET PASSWORD statement.
type SetPassword struct{}
228 | 229 | func (*Revoke) IStatement() {} 230 | func (*Revoke) IsAccountMgrStmt() {} 231 | 232 | type Revoke struct{} 233 | 234 | func (*CreateUser) IStatement() {} 235 | func (*CreateUser) IDDLStatement() {} 236 | func (*CreateUser) IsAccountMgrStmt() {} 237 | 238 | type CreateUser struct{} 239 | 240 | func (*AlterUser) IStatement() {} 241 | func (*AlterUser) IDDLStatement() {} 242 | func (*AlterUser) IsAccountMgrStmt() {} 243 | 244 | type AlterUser struct{} 245 | 246 | func (*DropUser) IStatement() {} 247 | func (*DropUser) IDDLStatement() {} 248 | func (*DropUser) IsAccountMgrStmt() {} 249 | 250 | type DropUser struct{} 251 | -------------------------------------------------------------------------------- /sql/ast_ddl.go: -------------------------------------------------------------------------------- 1 | package sql 2 | 3 | type IDDLStatement interface { 4 | IDDLStatement() 5 | IStatement() 6 | } 7 | 8 | type IDDLSchemas interface { 9 | GetSchemas() []string 10 | HasDDLSchemas() 11 | } 12 | 13 | type RenameTable struct { 14 | ToList []*TableToTable 15 | } 16 | 17 | func (*RenameTable) IStatement() {} 18 | func (*RenameTable) IDDLStatement() {} 19 | 20 | func (*TruncateTable) IStatement() {} 21 | func (*TruncateTable) IDDLStatement() {} 22 | func (*TruncateTable) HasDDLSchemas() {} 23 | func (t *TruncateTable) GetSchemas() []string { 24 | return t.Table.GetSchemas() 25 | } 26 | 27 | type TruncateTable struct { 28 | Table ISimpleTable 29 | } 30 | -------------------------------------------------------------------------------- /sql/ast_dml.go: -------------------------------------------------------------------------------- 1 | package sql 2 | 3 | /*********************************** 4 | * Select Clause 5 | ***********************************/ 6 | 7 | type ISelect interface { 8 | ISelect() 9 | IsLocked() bool 10 | GetSchemas() []string 11 | IStatement 12 | } 13 | 14 | func (*Select) ISelect() {} 15 | func (*ParenSelect) ISelect() {} 16 | func (*Union) ISelect() {} 
17 | func (*SubQuery) ISelect() {} 18 | 19 | func (*Select) IStatement() {} 20 | func (*ParenSelect) IStatement() {} 21 | func (*Union) IStatement() {} 22 | func (*SubQuery) IStatement() {} 23 | 24 | type Union struct { 25 | Left, Right ISelect 26 | } 27 | 28 | func (u *Union) IsLocked() bool { 29 | return u.Left.IsLocked() || u.Right.IsLocked() 30 | } 31 | 32 | func (u *Union) GetSchemas() []string { 33 | if u.Left == nil { 34 | panic("union must have left select statement") 35 | } 36 | 37 | if u.Right == nil { 38 | panic("union must have right select statement") 39 | } 40 | 41 | l := u.Left.GetSchemas() 42 | r := u.Right.GetSchemas() 43 | 44 | if l == nil && r == nil { 45 | return nil 46 | } else if l == nil { 47 | return r 48 | } else if r == nil { 49 | return l 50 | } 51 | return append(l, r...) 52 | } 53 | 54 | // SubQuery --------- 55 | type SubQuery struct { 56 | SelectStatement ISelect 57 | } 58 | 59 | func (s *SubQuery) IsLocked() bool { 60 | return s.SelectStatement.IsLocked() 61 | } 62 | 63 | func (s *SubQuery) GetSchemas() []string { 64 | if s.SelectStatement == nil { 65 | panic("subquery has no content") 66 | } 67 | 68 | return s.SelectStatement.GetSchemas() 69 | } 70 | 71 | // Select ----------- 72 | type Select struct { 73 | From ITables 74 | LockType LockType 75 | } 76 | 77 | func (s *Select) IsLocked() bool { 78 | return s.LockType != LockType_NoLock 79 | } 80 | 81 | func (s *Select) GetSchemas() []string { 82 | if s.From == nil { 83 | return nil 84 | } 85 | 86 | ret := make([]string, 0, 8) 87 | for _, v := range s.From { 88 | r := v.GetSchemas() 89 | if r != nil || len(r) != 0 { 90 | ret = append(ret, r...) 
91 | } 92 | } 93 | 94 | return ret 95 | } 96 | 97 | // ParenSelect ------ 98 | type ParenSelect struct { 99 | Select ISelect 100 | } 101 | 102 | func (p *ParenSelect) IsLocked() bool { 103 | return p.Select.IsLocked() 104 | } 105 | 106 | func (p *ParenSelect) GetSchemas() []string { 107 | return p.Select.GetSchemas() 108 | } 109 | 110 | type LockType int 111 | 112 | const ( 113 | LockType_NoLock = iota 114 | LockType_ForUpdate 115 | LockType_LockInShareMode 116 | ) 117 | 118 | /********************************* 119 | * Insert Clause 120 | * - http://dev.mysql.com/doc/refman/5.7/en/insert.html 121 | ********************************/ 122 | func (*Insert) IStatement() {} 123 | func (i *Insert) HasISelect() bool { 124 | if i.InsertFields == nil { 125 | return false 126 | } 127 | 128 | if _, ok := i.InsertFields.(ISelect); !ok { 129 | return false 130 | } 131 | 132 | return true 133 | } 134 | 135 | func (i *Insert) GetSchemas() []string { 136 | ret := i.Table.GetSchemas() 137 | var s []string = nil 138 | if i.HasISelect() { 139 | s = i.InsertFields.(*Select).GetSchemas() 140 | } 141 | 142 | if ret == nil || len(ret) == 0 { 143 | return s 144 | } 145 | 146 | if s == nil || len(s) == 0 { 147 | return ret 148 | } 149 | 150 | return append(ret, s...) 
151 | } 152 | 153 | type Insert struct { 154 | Table ISimpleTable 155 | // can be `values(x,y,z)` list or `select` statement 156 | InsertFields interface{} 157 | } 158 | 159 | /********************************* 160 | * Update Clause 161 | * - http://dev.mysql.com/doc/refman/5.7/en/update.html 162 | ********************************/ 163 | func (*Update) IStatement() {} 164 | func (u *Update) GetSchemas() []string { 165 | if u.Tables == nil { 166 | panic("update must have table identifier") 167 | } 168 | 169 | return u.Tables.GetSchemas() 170 | } 171 | 172 | type Update struct { 173 | Tables ITables 174 | } 175 | 176 | /********************************* 177 | * Delete Clause 178 | ********************************/ 179 | func (*Delete) IStatement() {} 180 | 181 | type Delete struct { 182 | Tables ITables 183 | } 184 | 185 | func (d *Delete) GetSchemas() []string { 186 | if d.Tables == nil || len(d.Tables) == 0 { 187 | return nil 188 | } 189 | return d.Tables.GetSchemas() 190 | } 191 | 192 | /*********************************************** 193 | * Replace Clause 194 | **********************************************/ 195 | func (*Replace) IStatement() {} 196 | func (r *Replace) HasISelect() bool { 197 | if r.ReplaceFields == nil { 198 | return false 199 | } 200 | 201 | if _, ok := r.ReplaceFields.(ISelect); !ok { 202 | return false 203 | } 204 | 205 | return true 206 | } 207 | func (r *Replace) GetSchemas() []string { 208 | ret := r.Table.GetSchemas() 209 | var s []string = nil 210 | if r.HasISelect() { 211 | s = r.ReplaceFields.(*Select).GetSchemas() 212 | } 213 | 214 | if ret == nil || len(ret) == 0 { 215 | return s 216 | } 217 | 218 | if s == nil || len(s) == 0 { 219 | return ret 220 | } 221 | 222 | return append(ret, s...) 
223 | } 224 | 225 | type Replace struct { 226 | Table ITable 227 | // can be `values(x,y,z)` list or `select` statement 228 | ReplaceFields interface{} 229 | } 230 | 231 | type Call struct { 232 | Spname *Spname 233 | } 234 | 235 | func (*Call) IStatement() {} 236 | 237 | type Do struct{} 238 | 239 | func (*Do) IStatement() {} 240 | 241 | type Load struct{} 242 | 243 | func (*Load) IStatement() {} 244 | 245 | type Handler struct{} 246 | 247 | func (*Handler) IStatement() {} 248 | -------------------------------------------------------------------------------- /sql/ast_drop.go: -------------------------------------------------------------------------------- 1 | package sql 2 | 3 | func (*DropTables) IStatement() {} 4 | func (*DropTables) IDDLStatement() {} 5 | func (*DropTables) HasDDLSchemas() {} 6 | func (d *DropTables) GetSchemas() []string { 7 | return d.Tables.GetSchemas() 8 | } 9 | 10 | type DropTables struct { 11 | Tables ISimpleTables 12 | } 13 | 14 | func (*DropIndex) IStatement() {} 15 | func (*DropIndex) IDDLStatement() {} 16 | func (*DropIndex) HasDDLSchemas() {} 17 | func (d *DropIndex) GetSchemas() []string { 18 | return d.On.GetSchemas() 19 | } 20 | 21 | type DropIndex struct { 22 | On ISimpleTable 23 | } 24 | 25 | type DropDatabase struct{} 26 | 27 | func (*DropDatabase) IStatement() {} 28 | func (*DropDatabase) IDDLStatement() {} 29 | 30 | func (*DropFunction) IStatement() {} 31 | func (*DropFunction) IDDLStatement() {} 32 | func (*DropFunction) HasDDLSchemas() {} 33 | func (d *DropFunction) GetSchemas() []string { 34 | return d.Function.GetSchemas() 35 | } 36 | 37 | type DropFunction struct { 38 | Function *Spname 39 | } 40 | 41 | func (*DropProcedure) IStatement() {} 42 | func (*DropProcedure) IDDLStatement() {} 43 | func (*DropProcedure) HasDDLSchemas() {} 44 | func (d *DropProcedure) GetSchemas() []string { 45 | return d.Procedure.GetSchemas() 46 | } 47 | 48 | type DropProcedure struct { 49 | Procedure *Spname 50 | } 51 | 52 | type DropView 
struct{} 53 | 54 | func (*DropView) IStatement() {} 55 | func (*DropView) IDDLStatement() {} 56 | 57 | func (*DropTrigger) IStatement() {} 58 | func (*DropTrigger) IDDLStatement() {} 59 | func (*DropTrigger) HasDDLSchemas() {} 60 | func (d *DropTrigger) GetSchemas() []string { 61 | return d.Trigger.GetSchemas() 62 | } 63 | 64 | type DropTrigger struct { 65 | Trigger *Spname 66 | } 67 | 68 | func (*DropTablespace) IStatement() {} 69 | func (*DropTablespace) IDDLStatement() {} 70 | 71 | type DropTablespace struct{} 72 | 73 | func (*DropLogfile) IStatement() {} 74 | func (*DropLogfile) IDDLStatement() {} 75 | 76 | type DropLogfile struct{} 77 | 78 | func (*DropServer) IStatement() {} 79 | func (*DropServer) IDDLStatement() {} 80 | 81 | type DropServer struct{} 82 | 83 | func (*DropEvent) IStatement() {} 84 | func (*DropEvent) IDDLStatement() {} 85 | func (*DropEvent) HasDDLSchemas() {} 86 | func (d *DropEvent) GetSchemas() []string { 87 | return d.Event.GetSchemas() 88 | } 89 | 90 | type DropEvent struct { 91 | Event *Spname 92 | } 93 | -------------------------------------------------------------------------------- /sql/ast_prepare.go: -------------------------------------------------------------------------------- 1 | package sql 2 | 3 | type Deallocate struct{} 4 | 5 | func (*Deallocate) IStatement() {} 6 | 7 | type Prepare struct{} 8 | 9 | func (*Prepare) IStatement() {} 10 | 11 | type Execute struct{} 12 | 13 | func (*Execute) IStatement() {} 14 | -------------------------------------------------------------------------------- /sql/ast_replication.go: -------------------------------------------------------------------------------- 1 | package sql 2 | 3 | type Change struct{} 4 | 5 | func (*Change) IStatement() {} 6 | 7 | type Purge struct{} 8 | 9 | func (*Purge) IStatement() {} 10 | 11 | type StartSlave struct{} 12 | 13 | func (*StartSlave) IStatement() {} 14 | 15 | type StopSlave struct{} 16 | 17 | func (*StopSlave) IStatement() {} 18 | 
-------------------------------------------------------------------------------- /sql/ast_show.go: -------------------------------------------------------------------------------- 1 | package sql 2 | 3 | type IShow interface { 4 | IShow() 5 | IStatement 6 | } 7 | 8 | type IShowSchemas interface { 9 | IShow 10 | GetSchemas() []string 11 | } 12 | 13 | func (*ShowLogs) IStatement() {} 14 | func (*ShowLogs) IShow() {} 15 | func (*ShowLogEvents) IStatement() {} 16 | func (*ShowLogEvents) IShow() {} 17 | func (*ShowCharset) IStatement() {} 18 | func (*ShowCharset) IShow() {} 19 | func (*ShowCollation) IStatement() {} 20 | func (*ShowCollation) IShow() {} 21 | 22 | // SHOW CREATE [event|procedure|table|trigger|view] 23 | func (*ShowCreate) IStatement() {} 24 | func (*ShowCreate) IShow() {} 25 | func (*ShowCreateDatabase) IStatement() {} 26 | func (*ShowCreateDatabase) IShow() {} 27 | 28 | func (*ShowColumns) IStatement() {} 29 | func (*ShowColumns) IShow() {} 30 | 31 | func (*ShowDatabases) IStatement() {} 32 | func (*ShowDatabases) IShow() {} 33 | 34 | func (*ShowEngines) IStatement() {} 35 | func (*ShowEngines) IShow() {} 36 | 37 | func (*ShowErrors) IStatement() {} 38 | func (*ShowErrors) IShow() {} 39 | func (*ShowWarnings) IStatement() {} 40 | func (*ShowWarnings) IShow() {} 41 | 42 | func (*ShowEvents) IStatement() {} 43 | func (*ShowEvents) IShow() {} 44 | 45 | func (*ShowFunction) IStatement() {} 46 | func (*ShowFunction) IShow() {} 47 | 48 | func (*ShowGrants) IStatement() {} 49 | func (*ShowGrants) IShow() {} 50 | 51 | func (*ShowIndex) IStatement() {} 52 | func (*ShowIndex) IShow() {} 53 | 54 | func (*ShowStatus) IStatement() {} 55 | func (*ShowStatus) IShow() {} 56 | 57 | func (*ShowOpenTables) IStatement() {} 58 | func (*ShowOpenTables) IShow() {} 59 | func (*ShowTables) IStatement() {} 60 | func (*ShowTables) IShow() {} 61 | func (*ShowTableStatus) IStatement() {} 62 | func (*ShowTableStatus) IShow() {} 63 | 64 | func (*ShowPlugins) IStatement() {} 65 | func 
(*ShowPlugins) IShow() {} 66 | 67 | func (*ShowPrivileges) IStatement() {} 68 | func (*ShowPrivileges) IShow() {} 69 | 70 | func (*ShowProcedure) IStatement() {} 71 | func (*ShowProcedure) IShow() {} 72 | 73 | func (*ShowProcessList) IStatement() {} 74 | func (*ShowProcessList) IShow() {} 75 | 76 | func (*ShowProfiles) IStatement() {} 77 | func (*ShowProfiles) IShow() {} 78 | 79 | func (*ShowSlaveHosts) IStatement() {} 80 | func (*ShowSlaveHosts) IShow() {} 81 | func (*ShowSlaveStatus) IStatement() {} 82 | func (*ShowSlaveStatus) IShow() {} 83 | func (*ShowMasterStatus) IStatement() {} 84 | func (*ShowMasterStatus) IShow() {} 85 | 86 | func (*ShowTriggers) IStatement() {} 87 | func (*ShowTriggers) IShow() {} 88 | 89 | func (*ShowVariables) IStatement() {} 90 | func (*ShowVariables) IShow() {} 91 | 92 | // currently we use only like for `show databases` syntax 93 | type LikeOrWhere struct { 94 | Like string 95 | } 96 | 97 | type ShowDatabases struct { 98 | LikeOrWhere *LikeOrWhere 99 | } 100 | 101 | func (s *ShowTables) GetSchemas() []string { 102 | if s.From == nil || len(s.From) == 0 { 103 | return nil 104 | } 105 | 106 | return []string{string(s.From)} 107 | } 108 | 109 | type ShowTables struct { 110 | From []byte 111 | } 112 | 113 | func (s *ShowTriggers) GetSchemas() []string { 114 | if s.From == nil || len(s.From) == 0 { 115 | return nil 116 | } 117 | 118 | return []string{string(s.From)} 119 | } 120 | 121 | type ShowTriggers struct { 122 | From []byte 123 | } 124 | 125 | func (s *ShowEvents) GetSchemas() []string { 126 | if s.From == nil || len(s.From) == 0 { 127 | return nil 128 | } 129 | 130 | return []string{string(s.From)} 131 | } 132 | 133 | type ShowEvents struct { 134 | From []byte 135 | } 136 | 137 | func (s *ShowTableStatus) GetSchemas() []string { 138 | if s.From == nil || len(s.From) == 0 { 139 | return nil 140 | } 141 | 142 | return []string{string(s.From)} 143 | } 144 | 145 | type ShowTableStatus struct { 146 | From []byte 147 | } 148 | 149 | 
func (s *ShowOpenTables) GetSchemas() []string { 150 | if s.From == nil || len(s.From) == 0 { 151 | return nil 152 | } 153 | 154 | return []string{string(s.From)} 155 | } 156 | 157 | type ShowOpenTables struct { 158 | From []byte 159 | } 160 | 161 | func (s *ShowColumns) GetSchemas() []string { 162 | if s.From == nil || len(s.From) == 0 { 163 | return s.Table.GetSchemas() 164 | } 165 | 166 | return []string{string(s.From)} 167 | } 168 | 169 | type ShowColumns struct { 170 | Table ISimpleTable 171 | From []byte 172 | } 173 | 174 | func (s *ShowIndex) GetSchemas() []string { 175 | if s.From == nil || len(s.From) == 0 { 176 | return s.Table.GetSchemas() 177 | } 178 | 179 | return []string{string(s.From)} 180 | } 181 | 182 | type ShowIndex struct { 183 | Table ISimpleTable 184 | From []byte 185 | } 186 | 187 | func (s *ShowProcedure) GetSchemas() []string { 188 | return s.Procedure.GetSchemas() 189 | } 190 | 191 | type ShowProcedure struct { 192 | Procedure *Spname 193 | } 194 | 195 | func (s *ShowFunction) GetSchemas() []string { 196 | return s.Function.GetSchemas() 197 | } 198 | 199 | type ShowFunction struct { 200 | Function *Spname 201 | } 202 | 203 | func (s *ShowCreate) GetSchemas() []string { 204 | return s.Table.GetSchemas() 205 | } 206 | 207 | type ShowCreate struct { 208 | Prefix []byte 209 | Table ISimpleTable 210 | } 211 | 212 | func (s *ShowCreateDatabase) GetSchemas() []string { 213 | if s.Schema == nil || len(s.Schema) == 0 { 214 | return nil 215 | } 216 | 217 | return []string{string(s.Schema)} 218 | } 219 | 220 | type ShowCreateDatabase struct { 221 | Schema []byte 222 | } 223 | 224 | type ShowGrants struct{} 225 | type ShowCollation struct{} 226 | type ShowCharset struct{} 227 | type ShowVariables struct{} 228 | type ShowProcessList struct{} 229 | type ShowStatus struct{} 230 | type ShowProfiles struct{} 231 | type ShowPrivileges struct{} 232 | type ShowWarnings struct{} 233 | type ShowErrors struct{} 234 | type ShowLogEvents struct{} 235 | type 
ShowSlaveHosts struct{} 236 | type ShowSlaveStatus struct{} 237 | type ShowMasterStatus struct{} 238 | type ShowLogs struct{} 239 | type ShowPlugins struct{} 240 | type ShowEngines struct{} 241 | -------------------------------------------------------------------------------- /sql/ast_table.go: -------------------------------------------------------------------------------- 1 | package sql 2 | 3 | import ( 4 | "fmt" 5 | ) 6 | 7 | /******************************************* 8 | * Table Interfaces and Structs 9 | * doc: 10 | * - table_references http://dev.mysql.com/doc/refman/5.7/en/join.html 11 | * - table_factor http://dev.mysql.com/doc/refman/5.7/en/join.html 12 | * - join_table http://dev.mysql.com/doc/refman/5.7/en/join.html 13 | ******************************************/ 14 | type ITable interface { 15 | IsTable() 16 | GetSchemas() []string 17 | } 18 | 19 | type ITables []ITable 20 | 21 | func (ts ITables) GetSchemas() []string { 22 | if ts == nil && len(ts) == 0 { 23 | return nil 24 | } 25 | 26 | var ret []string 27 | for _, v := range ts { 28 | if r := v.GetSchemas(); r != nil && len(r) != 0 { 29 | ret = append(ret, r...) 30 | } 31 | } 32 | 33 | if len(ret) == 0 { 34 | return nil 35 | } 36 | 37 | return ret 38 | } 39 | 40 | func (*JoinTable) IsTable() {} 41 | func (*ParenTable) IsTable() {} 42 | func (*AliasedTable) IsTable() {} 43 | 44 | type JoinTable struct { 45 | Left ITable 46 | Join []byte 47 | Right ITable 48 | // TODO On BoolExpr 49 | } 50 | 51 | func (j *JoinTable) GetSchemas() []string { 52 | 53 | if j.Left == nil { 54 | panic("join table must have left value") 55 | } 56 | 57 | if j.Right == nil { 58 | panic("join table must have right value") 59 | } 60 | 61 | l := j.Left.GetSchemas() 62 | r := j.Right.GetSchemas() 63 | 64 | if l == nil && r == nil { 65 | return nil 66 | } else if l == nil { 67 | return r 68 | } else if r == nil { 69 | return l 70 | } 71 | 72 | return append(l, r...) 
73 | } 74 | 75 | type ParenTable struct { 76 | Table ITable 77 | } 78 | 79 | func (p *ParenTable) GetSchemas() []string { 80 | if p.Table == nil { 81 | return nil 82 | } 83 | return p.Table.GetSchemas() 84 | } 85 | 86 | type AliasedTable struct { 87 | TableOrSubQuery interface{} // here may be the table_ident or subquery 88 | As []byte 89 | // TODO IndexHints 90 | } 91 | 92 | func (a *AliasedTable) GetSchemas() []string { 93 | if t, ok := a.TableOrSubQuery.(ITable); ok { 94 | return t.GetSchemas() 95 | } else if s, can := a.TableOrSubQuery.(*SubQuery); can { 96 | return s.SelectStatement.GetSchemas() 97 | } else { 98 | panic(fmt.Sprintf("alias table has no table_factor or subquery, element type[%T]", a.TableOrSubQuery)) 99 | } 100 | } 101 | 102 | // SimpleTable contains only qualifier, name and a column field 103 | func (*SimpleTable) IsSimpleTable() {} 104 | func (*SimpleTable) IsTable() {} 105 | 106 | type ISimpleTable interface { 107 | IsSimpleTable() 108 | ITable 109 | } 110 | 111 | type SimpleTable struct { 112 | Qualifier []byte 113 | Name []byte 114 | Column []byte 115 | } 116 | 117 | func (s *SimpleTable) GetSchemas() []string { 118 | if s.Qualifier == nil || len(s.Qualifier) == 0 { 119 | return nil 120 | } 121 | return []string{string(s.Qualifier)} 122 | } 123 | 124 | type ISimpleTables []ISimpleTable 125 | 126 | func (ts ISimpleTables) GetSchemas() []string { 127 | if ts == nil && len(ts) == 0 { 128 | return nil 129 | } 130 | 131 | var ret []string 132 | for _, v := range ts { 133 | if r := v.GetSchemas(); r != nil && len(r) != 0 { 134 | ret = append(ret, r...) 
135 | } 136 | } 137 | 138 | if len(ret) == 0 { 139 | return nil 140 | } 141 | 142 | return ret 143 | } 144 | 145 | func (*Spname) IsSimpleTable() {} 146 | func (*Spname) IsTable() {} 147 | 148 | func (s *Spname) GetSchemas() []string { 149 | if s.Qualifier == nil || len(s.Qualifier) == 0 { 150 | return nil 151 | } 152 | 153 | return []string{string(s.Qualifier)} 154 | } 155 | 156 | type Spname struct { 157 | Qualifier []byte 158 | Name []byte 159 | } 160 | 161 | type SchemaInfo struct { 162 | Name []byte 163 | } 164 | 165 | func GetSchemas(params ...[]string) []string { 166 | var dst []string 167 | for _, arr := range params { 168 | if arr != nil { 169 | dst = append(dst, arr...) 170 | } 171 | } 172 | 173 | if len(dst) == 0 { 174 | return nil 175 | } 176 | 177 | return dst 178 | } 179 | 180 | type TableToTable struct { 181 | From ISimpleTable 182 | To ISimpleTable 183 | } 184 | -------------------------------------------------------------------------------- /sql/ast_trans.go: -------------------------------------------------------------------------------- 1 | package sql 2 | 3 | func (*StartTrans) IStatement() {} 4 | func (*Lock) IStatement() {} 5 | func (*Unlock) IStatement() {} 6 | func (*Begin) IStatement() {} 7 | func (*Commit) IStatement() {} 8 | func (*Rollback) IStatement() {} 9 | func (*XA) IStatement() {} 10 | func (*SavePoint) IStatement() {} 11 | func (*Release) IStatement() {} 12 | func (*SetTrans) IStatement() {} 13 | 14 | type StartTrans struct{} 15 | 16 | func (l *Lock) GetSchemas() []string { 17 | return l.Tables.GetSchemas() 18 | } 19 | 20 | type Lock struct { 21 | Tables ISimpleTables 22 | } 23 | 24 | type Unlock struct{} 25 | 26 | type Begin struct{} 27 | 28 | type Commit struct{} 29 | 30 | type Rollback struct { 31 | Point []byte 32 | } 33 | 34 | type XA struct{} 35 | 36 | type SavePoint struct { 37 | Point []byte 38 | } 39 | 40 | type Release struct { 41 | Point []byte 42 | } 43 | 44 | type SetTrans struct{} 45 | 
-------------------------------------------------------------------------------- /sql/ast_util.go: -------------------------------------------------------------------------------- 1 | package sql 2 | 3 | import ( 4 | "fmt" 5 | ) 6 | 7 | func (*Help) IStatement() {} 8 | func (*DescribeTable) IStatement() {} 9 | func (*DescribeStmt) IStatement() {} 10 | func (*Use) IStatement() {} 11 | 12 | type Help struct{} 13 | 14 | func (d *DescribeTable) GetSchemas() []string { 15 | return d.Table.GetSchemas() 16 | } 17 | 18 | type DescribeTable struct { 19 | Table ISimpleTable 20 | } 21 | 22 | func (d *DescribeStmt) GetSchemas() []string { 23 | switch st := d.Stmt.(type) { 24 | case *Select: 25 | return st.GetSchemas() 26 | case *Insert: 27 | return st.GetSchemas() 28 | case *Update: 29 | return st.GetSchemas() 30 | case *Replace: 31 | return st.GetSchemas() 32 | case *Delete: 33 | return st.GetSchemas() 34 | default: 35 | panic(fmt.Sprintf("statement type %T is not explainable", st)) 36 | } 37 | } 38 | 39 | type DescribeStmt struct { 40 | Stmt IStatement 41 | } 42 | 43 | type Use struct { 44 | DB []byte 45 | } 46 | -------------------------------------------------------------------------------- /sql/bin/yacc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangjild/go-mysql-proxy/cda78f8c6e14d6983c577ab0d183a0ae12ce6fb5/sql/bin/yacc -------------------------------------------------------------------------------- /sql/charset/charset.go: -------------------------------------------------------------------------------- 1 | package charset 2 | 3 | import ( 4 | "bytes" 5 | . 
"github.com/wangjild/go-mysql-proxy/sql/state" 6 | ) 7 | 8 | type ( 9 | CharsetInfo struct { 10 | Number int 11 | PrimaryNumber int 12 | BinaryNumber int 13 | 14 | CSName string 15 | Name string 16 | 17 | CType []byte 18 | 19 | StateMap []uint 20 | IdentMap []uint 21 | } 22 | ) 23 | 24 | func init() { 25 | ValidCharsets = make(map[string]*CharsetInfo) 26 | ValidCharsets["utf8_general_cli"] = CSUtf8GeneralCli 27 | 28 | for _, v := range ValidCharsets { 29 | initStateMaps(v) 30 | } 31 | } 32 | 33 | var ValidCharsets map[string]*CharsetInfo 34 | 35 | func IsValidCharsets(cs []byte) bool { 36 | if _, ok := ValidCharsets[string(bytes.ToLower(cs))]; ok { 37 | return true 38 | } 39 | 40 | return false 41 | } 42 | 43 | func initStateMaps(cs *CharsetInfo) { 44 | 45 | var state_map [256]uint 46 | 47 | for i := 0; i < 256; i++ { 48 | if cs.IsAlpha(byte(i)) == true { 49 | state_map[i] = (MY_LEX_IDENT) 50 | } else if cs.IsDigit(byte(i)) { 51 | state_map[i] = MY_LEX_NUMBER_IDENT 52 | } else if cs.IsSpace(byte(i)) { 53 | state_map[i] = MY_LEX_SKIP 54 | } else { 55 | state_map[i] = MY_LEX_CHAR 56 | } 57 | } 58 | state_map[0] = MY_LEX_EOL 59 | state_map['_'] = MY_LEX_IDENT 60 | state_map['$'] = MY_LEX_IDENT 61 | state_map['\''] = MY_LEX_STRING 62 | state_map['.'] = MY_LEX_REAL_OR_POINT 63 | state_map['>'] = MY_LEX_CMP_OP 64 | state_map['='] = MY_LEX_CMP_OP 65 | state_map['!'] = MY_LEX_CMP_OP 66 | state_map['<'] = MY_LEX_LONG_CMP_OP 67 | state_map['&'] = MY_LEX_BOOL 68 | state_map['|'] = MY_LEX_BOOL 69 | state_map['#'] = MY_LEX_COMMENT 70 | state_map[';'] = MY_LEX_SEMICOLON 71 | state_map[':'] = MY_LEX_SET_VAR 72 | state_map['\\'] = MY_LEX_ESCAPE 73 | state_map['/'] = MY_LEX_LONG_COMMENT 74 | state_map['*'] = MY_LEX_END_LONG_COMMENT 75 | state_map['@'] = MY_LEX_USER_END 76 | state_map['`'] = MY_LEX_USER_VARIABLE_DELIMITER 77 | state_map['"'] = MY_LEX_STRING_OR_DELIMITER 78 | 79 | var ident_map [256]uint 80 | for i := 0; i < 256; i++ { 81 | ident_map[i] = func() uint { 82 | if 
state_map[i] == MY_LEX_IDENT || state_map[i] == MY_LEX_NUMBER_IDENT { 83 | return 1 84 | } 85 | return 0 86 | }() 87 | } 88 | 89 | state_map['x'] = MY_LEX_IDENT_OR_HEX 90 | state_map['X'] = MY_LEX_IDENT_OR_HEX 91 | state_map['b'] = MY_LEX_IDENT_OR_BIN 92 | state_map['B'] = MY_LEX_IDENT_OR_BIN 93 | state_map['n'] = (MY_LEX_IDENT_OR_NCHAR) 94 | state_map['N'] = (MY_LEX_IDENT_OR_NCHAR) 95 | 96 | cs.IdentMap = ident_map[:] 97 | cs.StateMap = state_map[:] 98 | } 99 | // IsAlpha reports whether c has the _MY_U or _MY_L flag set in cs.CType. CType is indexed by byte value plus one, so c is widened to int first; in byte arithmetic 0xFF+1 would wrap to index 0. 100 | func (cs *CharsetInfo) IsAlpha(c byte) bool { 101 | if cs.CType[int(c)+1]&(_MY_U|_MY_L) == 0 { 102 | return false 103 | } 104 | return true 105 | } 106 | // IsDigit reports whether c has the _MY_NMR (numeral) flag set in cs.CType; the index is computed in int so 0xFF does not wrap. 107 | func (cs *CharsetInfo) IsDigit(c byte) bool { 108 | if cs.CType[int(c)+1]&_MY_NMR == 0 { 109 | return false 110 | } 111 | 112 | return true 113 | } 114 | // IsSpace reports whether c has the _MY_SPC (spacing) flag set in cs.CType; the index is computed in int so 0xFF does not wrap. 115 | func (cs *CharsetInfo) IsSpace(c byte) bool { 116 | if cs.CType[int(c)+1]&_MY_SPC == 0 { 117 | return false 118 | } 119 | 120 | return true 121 | } 122 | // IsCntrl reports whether c has the _MY_CTR (control) flag set in cs.CType; the index is computed in int so 0xFF does not wrap. 123 | func (cs *CharsetInfo) IsCntrl(c byte) bool { 124 | if cs.CType[int(c)+1]&_MY_CTR == 0 { 125 | return false 126 | } 127 | 128 | return true 129 | } 130 | // IsXdigit reports whether c has the _MY_X (hexadecimal digit) flag set in cs.CType; the index is computed in int so 0xFF does not wrap. 131 | func (cs *CharsetInfo) IsXdigit(c byte) bool { 132 | if cs.CType[int(c)+1]&_MY_X == 0 { 133 | return false 134 | } 135 | return true 136 | } 137 | // IsAlnum reports whether c has any of the _MY_U, _MY_L or _MY_NMR flags set in cs.CType; the index is computed in int so 0xFF does not wrap. 138 | func (cs *CharsetInfo) IsAlnum(c byte) bool { 139 | if cs.CType[int(c)+1]&(_MY_U|_MY_L|_MY_NMR) == 0 { 140 | return false 141 | } 142 | 143 | return true 144 | } 145 | 146 | const ( 147 | _MY_U = 01 148 | _MY_L = 02 149 | _MY_NMR = 04 /* Numeral (digit) */ 150 | _MY_SPC = 010 /* Spacing character */ 151 | _MY_PNT = 020 /* Punctuation */ 152 | _MY_CTR = 040 /* Control character */ 153 | _MY_B = 0100 /* Blank */ 154 | _MY_X = 0200 /* heXadecimal digit */ 155 | ) 156 | -------------------------------------------------------------------------------- /sql/charset/charset_test.go: -------------------------------------------------------------------------------- 1 | package charset 2 | 3 | import ( 4 | "testing" 5 | ) 6 | 7 | func TestUtf8(t *testing.T) { 8 | 9
| func() { // TEST for utf8 digit 10 | for i := 0; i < 10; i++ { 11 | b := byte('0') + byte(i) 12 | if CSUtf8GeneralCli.IsDigit(b) == false { 13 | t.Fatalf("%v is not digit type", b) 14 | } 15 | } 16 | }() 17 | 18 | func() { // TEST for utf8 alpha (upper- and lower-case ASCII letters) 19 | for i := 0; i < 26; i++ { 20 | b := byte('A') + byte(i) 21 | if CSUtf8GeneralCli.IsAlpha(b) == false { 22 | t.Fatalf("%v is not alpha type", b) 23 | } 24 | 25 | b = byte('a') + byte(i) 26 | if CSUtf8GeneralCli.IsAlpha(b) == false { 27 | t.Fatalf("%v is not alpha type", b) 28 | } 29 | } 30 | }() 31 | } 32 | -------------------------------------------------------------------------------- /sql/charset/utf8_general_cli.go: -------------------------------------------------------------------------------- 1 | package charset 2 | 3 | var CSUtf8GeneralCli *CharsetInfo = &CharsetInfo{ 4 | 33, 5 | 0, 6 | 0, 7 | 8 | "utf8", 9 | "utf8_general_ci", 10 | 11 | ctype_utf8, 12 | nil, 13 | nil, 14 | } 15 | 16 | var ctype_utf8 []byte = []byte{ 17 | 0, 18 | 32, 32, 32, 32, 32, 32, 32, 32, 32, 40, 40, 40, 40, 40, 32, 32, // 0 - 15 19 | 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, // 16 - 31 20 | 72, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, // 32 - 47 21 | //0, 1, 2 ......
9 22 | 132, 132, 132, 132, 132, 132, 132, 132, 132, 132, 16, 16, 16, 16, 16, 16, 23 | 16, 129, 129, 129, 129, 129, 129, 1, 1, 1, 1, 1, 1, 1, 1, 1, 24 | 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 16, 16, 16, 16, 16, 25 | 16, 130, 130, 130, 130, 130, 130, 2, 2, 2, 2, 2, 2, 2, 2, 2, 26 | 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 16, 16, 16, 16, 32, 27 | 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 28 | 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 29 | 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 30 | 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 31 | 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 32 | 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 33 | 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 34 | 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 0, 35 | } 36 | -------------------------------------------------------------------------------- /sql/debug.go: -------------------------------------------------------------------------------- 1 | package sql 2 | 3 | import "fmt" 4 | 5 | var debug bool = false 6 | 7 | func DEBUG(i interface{}) { 8 | if debug { 9 | fmt.Printf("%v", i) 10 | } 11 | } 12 | 13 | func setDebug(dbg bool) { 14 | debug = dbg 15 | } 16 | -------------------------------------------------------------------------------- /sql/lex_ident.go: -------------------------------------------------------------------------------- 1 | package sql 2 | 3 | import ( 4 | "fmt" 5 | "github.com/wangjild/go-mysql-proxy/sql/charset" 6 | . "github.com/wangjild/go-mysql-proxy/sql/state" 7 | ) 8 | 9 | func (lex *SQLLexer) getPureIdentifier() (int, []byte) { 10 | ident_map := lex.cs.IdentMap 11 | c := lex.yyPeek() 12 | rs := int(c) 13 | 14 | for ident_map[lex.yyPeek()] != 0 { 15 | rs |= int(c) 16 | c = lex.yyNext() 17 | } 18 | 19 | if rs&0x80 != 0 { 20 | rs = IDENT_QUOTED 21 | } else { 22 | rs = IDENT 23 | } 24 | 25 | if lex.yyPeek() == '.' 
&& ident_map[int(lex.yyPeek2())] != 0 { 26 | lex.next_state = MY_LEX_IDENT_SEP 27 | } 28 | 29 | return rs, lex.buf[lex.tok_start:lex.ptr] 30 | } 31 | // getIdentifier scans the identifier that starts at tok_start and returns its token type (a keyword token from findKeywords, UNDERSCORE_CHARSET, IDENT_QUOTED or IDENT) together with the raw identifier bytes. 32 | func (lex *SQLLexer) getIdentifier() (int, []byte) { 33 | 34 | ident_map := lex.cs.IdentMap 35 | 36 | c := lex.yyPeek() 37 | rs := int(c) 38 | 39 | for ident_map[lex.yyPeek()] != 0 { 40 | rs |= int(c) 41 | c = lex.yyNext() 42 | } 43 | 44 | if rs&0x80 != 0 { 45 | rs = IDENT_QUOTED 46 | } else { 47 | rs = IDENT 48 | } 49 | 50 | idc := lex.buf[lex.tok_start:lex.ptr] 51 | DEBUG(fmt.Sprintf("idc:[%s]\n", idc)) 52 | 53 | start := lex.ptr 54 | 55 | /* 56 | for ; lex.ignore_space && state_map[c] == MY_LEX_SKIP; c = lex.yyNext() { 57 | }*/ 58 | // "ident." followed by an identifier byte is a qualified name, so skip the keyword lookup. This mirrors getPureIdentifier, which checks yyPeek2() here; checking yyPeek() again would test the '.' itself, which is never in ident_map. 59 | c = lex.yyPeek() 60 | if start == lex.ptr && lex.yyPeek() == '.' && ident_map[int(lex.yyPeek2())] != 0 { 61 | lex.next_state = MY_LEX_IDENT_SEP 62 | } else if ret, ok := findKeywords(idc, c == '('); ok { 63 | lex.next_state = MY_LEX_START 64 | return ret, idc 65 | } 66 | 67 | if idc[0] == '_' && charset.IsValidCharsets(idc[1:]) { 68 | return UNDERSCORE_CHARSET, idc 69 | } 70 | 71 | return rs, idc 72 | } 73 | -------------------------------------------------------------------------------- /sql/lex_ident_test.go: -------------------------------------------------------------------------------- 1 | package sql 2 | 3 | import ( 4 | "testing" 5 | ) 6 | 7 | func TestIdentifier(t *testing.T) { 8 | testMatchReturn(t, "`test ` ", IDENT_QUOTED, false) 9 | } 10 | 11 | func TestMultiIdentifier(t *testing.T) { 12 | str := "SELECT INSERT 'string ' UPDATE DELEte `SELECT` `Update`" 13 | lex, lval := getLexer(str) 14 | 15 | lexExpect(t, lex, lval, SELECT_SYM) 16 | lexExpect(t, lex, lval, INSERT) 17 | 18 | lexExpect(t, lex, lval, TEXT_STRING) 19 | lvalExpect(t, lval, "'string '") 20 | 21 | lexExpect(t, lex, lval, UPDATE_SYM) 22 | lexExpect(t, lex, lval, DELETE_SYM) 23 | 24 | lexExpect(t, lex, lval, IDENT_QUOTED) 25 | lvalExpect(t, lval, "`SELECT`") 26 | 27 | lexExpect(t, lex, lval,
IDENT_QUOTED) 28 | lvalExpect(t, lval, "`Update`") 29 | 30 | lexExpect(t, lex, lval, END_OF_INPUT) 31 | } 32 | 33 | func TestParamMarker(t *testing.T) { 34 | str := "select ?,?,? from t1;" 35 | lex, lval := getLexer(str) 36 | 37 | lexExpect(t, lex, lval, SELECT_SYM) 38 | lexExpect(t, lex, lval, PARAM_MARKER) 39 | lexExpect(t, lex, lval, ',') 40 | lexExpect(t, lex, lval, PARAM_MARKER) 41 | lexExpect(t, lex, lval, ',') 42 | lexExpect(t, lex, lval, PARAM_MARKER) 43 | } 44 | 45 | func TestMultiIdentifier1(t *testing.T) { 46 | str := "s n insert `s` `` s" 47 | lex, lval := getLexer(str) 48 | 49 | lexExpect(t, lex, lval, IDENT) 50 | lvalExpect(t, lval, `s`) 51 | 52 | lexExpect(t, lex, lval, IDENT) 53 | lvalExpect(t, lval, `n`) 54 | 55 | lexExpect(t, lex, lval, INSERT) 56 | 57 | lexExpect(t, lex, lval, IDENT_QUOTED) 58 | lvalExpect(t, lval, "`s`") 59 | 60 | lexExpect(t, lex, lval, IDENT_QUOTED) 61 | lvalExpect(t, lval, "``") 62 | 63 | lexExpect(t, lex, lval, IDENT) 64 | lvalExpect(t, lval, `s`) 65 | } 66 | 67 | func TestMultiIdentifier2(t *testing.T) { 68 | str := `table1.column_name=table2.column_name` 69 | lex, lval := getLexer(str) 70 | lexExpect(t, lex, lval, IDENT) 71 | lvalExpect(t, lval, "table1") 72 | 73 | lexExpect(t, lex, lval, '.') 74 | 75 | lexExpect(t, lex, lval, IDENT) 76 | lvalExpect(t, lval, "column_name") 77 | 78 | lexExpect(t, lex, lval, EQ) 79 | 80 | lexExpect(t, lex, lval, IDENT) 81 | lvalExpect(t, lval, "table2") 82 | 83 | lexExpect(t, lex, lval, '.') 84 | 85 | lexExpect(t, lex, lval, IDENT) 86 | lvalExpect(t, lval, "column_name") 87 | } 88 | -------------------------------------------------------------------------------- /sql/lex_keywords_test.go: -------------------------------------------------------------------------------- 1 | package sql 2 | 3 | import ( 4 | "testing" 5 | ) 6 | 7 | func TestKeywords(t *testing.T) { 8 | testMatchReturn(t, `SELECT`, SELECT_SYM, false) 9 | } 10 | 11 | func TestFunctions(t *testing.T) { 12 | testMatchReturn(t, 
`CURTIME()`, CURTIME, false) 13 | } 14 | 15 | func TestCharsetName(t *testing.T) { 16 | testMatchReturn(t, `_utf8_general_cli`, UNDERSCORE_CHARSET, false) 17 | } 18 | 19 | func TestIdent(t *testing.T) { 20 | testMatchReturn(t, `thisisaident`, IDENT, false) 21 | } 22 | 23 | func TestBoolOp(t *testing.T) { 24 | testMatchReturn(t, `&&`, AND_AND_SYM, false) 25 | testMatchReturn(t, `||`, OR_OR_SYM, false) 26 | testMatchReturn(t, `<`, LT, false) 27 | testMatchReturn(t, `<=`, LE, false) 28 | testMatchReturn(t, `<>`, NE, false) 29 | testMatchReturn(t, `!=`, NE, false) 30 | testMatchReturn(t, `=`, EQ, false) 31 | testMatchReturn(t, `>`, GT_SYM, false) 32 | testMatchReturn(t, `>=`, GE, false) 33 | testMatchReturn(t, `<<`, SHIFT_LEFT, false) 34 | testMatchReturn(t, `>>`, SHIFT_RIGHT, false) 35 | testMatchReturn(t, `<=>`, EQUAL_SYM, false) 36 | 37 | testMatchReturn(t, `:=`, SET_VAR, false) 38 | } 39 | 40 | func TestChar(t *testing.T) { 41 | testMatchReturn(t, `& `, '&', false) 42 | } 43 | 44 | func TestMultiKeywords(t *testing.T) { 45 | lexer, lval := getLexer(`SELECT SHOW Databases SELECT `) 46 | 47 | lexExpect(t, lexer, lval, SELECT_SYM) 48 | lexExpect(t, lexer, lval, SHOW) 49 | lexExpect(t, lexer, lval, DATABASES) 50 | 51 | lexExpect(t, lexer, lval, SELECT_SYM) 52 | } 53 | -------------------------------------------------------------------------------- /sql/lex_nchar.go: -------------------------------------------------------------------------------- 1 | package sql 2 | 3 | import () 4 | 5 | func (lexer *SQLLexer) scanNChar(lval *MySQLSymType) (int, byte) { 6 | 7 | // found N'string' 8 | lexer.yyNext() // Skip ' 9 | 10 | // Skip any char except ' 11 | var c byte 12 | for c = lexer.yyNext(); c != 0 && c != '\''; c = lexer.yyNext() { 13 | } 14 | 15 | if c != '\'' { 16 | return ABORT_SYM, c 17 | } 18 | 19 | lval.bytes = lexer.buf[lexer.tok_start:lexer.ptr] 20 | 21 | return NCHAR_STRING, c 22 | } 23 | 
-------------------------------------------------------------------------------- /sql/lex_number.go: -------------------------------------------------------------------------------- 1 | package sql 2 | 3 | import ( 4 | "fmt" 5 | ) 6 | 7 | const ( 8 | LONG_LEN = 10 9 | LONGLONG_LEN = 19 10 | SIGNED_LONGLONG_LEN = 19 11 | UNSIGNED_LONGLONG_LEN = 20 12 | ) 13 | 14 | var ( 15 | LONG []byte = []byte{'2', '1', '4', '7', '4', '8', '3', '6', '4', '7'} 16 | SIGNED_LONG []byte = []byte{'-', '2', '1', '4', '7', '4', '8', '3', '6', '4', '8'} 17 | LONGLONG []byte = []byte{'9', '2', '2', '3', '3', '7', '2', '0', '3', '6', '8', '5', '4', '7', '7', '5', '8', '0', '7'} 18 | SIGNED_LONGLONG []byte = []byte{'-', '9', '2', '2', '3', '3', '7', '2', '0', '3', '6', '8', '5', '4', '7', '7', '5', '8', '0', '8'} 19 | UNSIGNED_LONGLONG []byte = []byte{'1', '8', '4', '4', '6', '7', '4', '4', '0', '7', '3', '7', '0', '9', '5', '5', '1', '6', '1', '5'} 20 | ) 21 | 22 | func (lex *SQLLexer) scanInt(lval *MySQLSymType) int { 23 | length := lex.ptr - lex.tok_start 24 | lval.bytes = lex.buf[lex.tok_start:lex.ptr] 25 | 26 | if length < LONG_LEN { 27 | return NUM 28 | } 29 | 30 | neg := false 31 | start := lex.tok_start 32 | if lex.buf[start] == '+' { 33 | start += 1 34 | length -= 1 35 | } else if lex.buf[start] == '-' { 36 | start += 1 37 | length -= 1 38 | neg = true 39 | } 40 | 41 | // ignore any '0' character 42 | for start < lex.ptr && lex.buf[start] == '0' { 43 | start += 1 44 | length -= 1 45 | } 46 | 47 | if length < LONG_LEN { 48 | return NUM 49 | } 50 | 51 | var cmp []byte 52 | var smaller int 53 | var bigger int 54 | if neg { 55 | if length == LONG_LEN { 56 | cmp = SIGNED_LONG[1:len(SIGNED_LONG)] 57 | smaller = NUM 58 | bigger = LONG_NUM 59 | } else if length < SIGNED_LONGLONG_LEN { 60 | return LONG_NUM 61 | } else if length > SIGNED_LONGLONG_LEN { 62 | return DECIMAL_NUM 63 | } else { 64 | cmp = SIGNED_LONGLONG[1:len(SIGNED_LONGLONG)] 65 | smaller = LONG_NUM 66 | bigger = DECIMAL_NUM 67 
| } 68 | } else { 69 | if length == LONG_LEN { 70 | cmp = LONG 71 | smaller = NUM 72 | bigger = LONG_NUM 73 | } else if length < LONGLONG_LEN { 74 | return LONG_NUM 75 | } else if length > LONGLONG_LEN { 76 | if length > UNSIGNED_LONGLONG_LEN { 77 | return DECIMAL_NUM 78 | } 79 | cmp = UNSIGNED_LONGLONG 80 | smaller = ULONGLONG_NUM 81 | bigger = DECIMAL_NUM 82 | } else { 83 | cmp = LONGLONG 84 | smaller = LONG_NUM 85 | bigger = ULONGLONG_NUM 86 | } 87 | } 88 | 89 | idx := 0 90 | for idx < len(cmp) && cmp[idx] == lex.buf[start] { 91 | DEBUG(fmt.Sprintf("cmp:[%c] buf[%c]\n", cmp[idx], lex.buf[start])) 92 | idx += 1 93 | start += 1 94 | } 95 | 96 | if idx == len(cmp) { 97 | return smaller 98 | } 99 | 100 | if lex.buf[start] <= cmp[idx] { 101 | return smaller 102 | } 103 | return bigger 104 | } 105 | 106 | func (lex *SQLLexer) scanFloat(lval *MySQLSymType, c *byte) (int, bool) { 107 | cs := lex.cs 108 | 109 | // try match (+|-)? digit+ 110 | if lex.yyPeek() == '+' || lex.yyPeek() == '-' { 111 | lex.yySkip() // ignore this char 112 | } 113 | 114 | // at least we have 1 digit-char 115 | if cs.IsDigit(lex.yyPeek()) { 116 | for ; cs.IsDigit(lex.yyPeek()); lex.yySkip() { 117 | } 118 | 119 | lval.bytes = lex.buf[lex.tok_start:lex.ptr] 120 | return FLOAT_NUM, true 121 | } 122 | 123 | return 0, false 124 | } 125 | -------------------------------------------------------------------------------- /sql/lex_number_test.go: -------------------------------------------------------------------------------- 1 | package sql 2 | 3 | import ( 4 | "testing" 5 | ) 6 | 7 | func TestInt(t *testing.T) { 8 | testMatchReturn(t, `123456`, NUM, false) 9 | testMatchReturn(t, `0000000000000000000000000123456`, NUM, false) 10 | testMatchReturn(t, `2147483646`, NUM, false) // NUM 11 | testMatchReturn(t, `2147483647`, NUM, false) // 2^31 - 1 12 | testMatchReturn(t, `2147483648`, LONG_NUM, false) // 2^31 13 | testMatchReturn(t, `0000000000000000000002147483648`, LONG_NUM, false) // 2^31 14 | 
testMatchReturn(t, `2147483648`, LONG_NUM, false) // 2^31 15 | testMatchReturn(t, `2147483648`, LONG_NUM, false) // 2^31 16 | testMatchReturn(t, `2147483648`, LONG_NUM, false) // 2^31 17 | 18 | testMatchReturn(t, `9223372036854775807`, LONG_NUM, false) 19 | testMatchReturn(t, `9223372036854775808`, ULONGLONG_NUM, false) 20 | testMatchReturn(t, `18446744073709551615`, ULONGLONG_NUM, false) 21 | testMatchReturn(t, `18446744073709551616`, DECIMAL_NUM, false) 22 | } 23 | 24 | func TestNum(t *testing.T) { 25 | testMatchReturn(t, `0x1234`, HEX_NUM, false) 26 | testMatchReturn(t, `0xa4234`, HEX_NUM, false) 27 | testMatchReturn(t, `0b0110`, BIN_NUM, false) 28 | } 29 | 30 | func TestFloatNum(t *testing.T) { 31 | testMatchReturn(t, " 10e-10", FLOAT_NUM, false) 32 | testMatchReturn(t, " 10E+10", FLOAT_NUM, false) 33 | testMatchReturn(t, " 10E10", FLOAT_NUM, false) 34 | testMatchReturn(t, "1.20E10", FLOAT_NUM, false) 35 | testMatchReturn(t, "1.20E-10", FLOAT_NUM, false) 36 | } 37 | 38 | func TestDecimalNum(t *testing.T) { 39 | testMatchReturn(t, `.21`, DECIMAL_NUM, false) 40 | testMatchReturn(t, `72.21`, DECIMAL_NUM, false) 41 | } 42 | 43 | func TestHex(t *testing.T) { 44 | testMatchReturn(t, `X'4D7953514C'`, HEX_NUM, false) 45 | 46 | testMatchReturn(t, `x'D34F2X`, ABORT_SYM, false) 47 | testMatchReturn(t, `x'`, ABORT_SYM, false) 48 | 49 | } 50 | 51 | func TestBin(t *testing.T) { 52 | testMatchReturn(t, `b'0101010111000'`, BIN_NUM, false) 53 | testMatchReturn(t, `b'0S01010111000'`, ABORT_SYM, false) 54 | testMatchReturn(t, `b'12312351123`, ABORT_SYM, false) 55 | } 56 | 57 | func TestMultiNum(t *testing.T) { 58 | str := `123 'string1' 18446744073709551616 1.20E-10 .312 x'4D7953514C' ` 59 | lex, lval := getLexer(str) 60 | 61 | lexExpect(t, lex, lval, NUM) 62 | lvalExpect(t, lval, `123`) 63 | 64 | lexExpect(t, lex, lval, TEXT_STRING) 65 | lvalExpect(t, lval, `'string1'`) 66 | 67 | lexExpect(t, lex, lval, DECIMAL_NUM) 68 | lvalExpect(t, lval, `18446744073709551616`) 69 | 70 | 
lexExpect(t, lex, lval, FLOAT_NUM) 71 | lvalExpect(t, lval, `1.20E-10`) 72 | 73 | lexExpect(t, lex, lval, DECIMAL_NUM) 74 | lvalExpect(t, lval, `.312`) 75 | 76 | lexExpect(t, lex, lval, HEX_NUM) 77 | lvalExpect(t, lval, `x'4D7953514C'`) 78 | 79 | lexExpect(t, lex, lval, END_OF_INPUT) 80 | } 81 | 82 | func TestNumberInPlacehold(t *testing.T) { 83 | str := ` (5)` 84 | lex, lval := getLexer(str) 85 | lexExpect(t, lex, lval, '(') 86 | lexExpect(t, lex, lval, NUM) 87 | lexExpect(t, lex, lval, ')') 88 | } 89 | -------------------------------------------------------------------------------- /sql/lex_test.go: -------------------------------------------------------------------------------- 1 | package sql 2 | 3 | import ( 4 | "testing" 5 | ) 6 | 7 | func getLexer(str string) (lexer *SQLLexer, lval *MySQLSymType) { 8 | lval = new(MySQLSymType) 9 | lexer = NewSQLLexer(str) 10 | 11 | return 12 | } 13 | 14 | func testMatchReturn(t *testing.T, str string, match int, dbg bool) (*SQLLexer, *MySQLSymType) { 15 | setDebug(dbg) 16 | lexer, lval := getLexer(str) 17 | ret := lexer.Lex(lval) 18 | if ret != match { 19 | t.Fatalf("test failed! expect[%s] return[%s]", TokenName(match), TokenName(ret)) 20 | } 21 | 22 | return lexer, lval 23 | } 24 | 25 | func TestNULLEscape(t *testing.T) { 26 | lexer, lval := getLexer("\\N") 27 | if lexer.Lex(lval) != NULL_SYM { 28 | t.Fatal("test failed") 29 | } 30 | } 31 | 32 | func TestSingleComment(t *testing.T) { 33 | lexer, lval := getLexer(" -- Single Line Comment. 
\r\n") 34 | 35 | if lexer.Lex(lval) != END_OF_INPUT { 36 | t.Fatal("test failed") 37 | } 38 | } 39 | 40 | func TestSingleComment2(t *testing.T) { 41 | } 42 | -------------------------------------------------------------------------------- /sql/lex_text.go: -------------------------------------------------------------------------------- 1 | package sql 2 | 3 | import ( 4 | "errors" 5 | ) 6 | 7 | /** 8 | * For ANTLR3 Definition: 9 | 10 | SINGLE_QUOTED_TEXT 11 | @init { int escape_count = 0; }: 12 | SINGLE_QUOTE 13 | ( 14 | SINGLE_QUOTE SINGLE_QUOTE { escape_count++; } 15 | | {!SQL_MODE_ACTIVE(SQL_MODE_NO_BACKSLASH_ESCAPES)}? => ESCAPE_OPERATOR . { escape_count++; } 16 | | {SQL_MODE_ACTIVE(SQL_MODE_NO_BACKSLASH_ESCAPES)}? => ~(SINGLE_QUOTE) 17 | | {!SQL_MODE_ACTIVE(SQL_MODE_NO_BACKSLASH_ESCAPES)}? => ~(SINGLE_QUOTE | ESCAPE_OPERATOR) 18 | )* 19 | SINGLE_QUOTE 20 | { EMIT(); LTOKEN->user1 = escape_count; } 21 | ; 22 | 23 | DOUBLE_QUOTED_TEXT 24 | @init { int escape_count = 0; }: 25 | DOUBLE_QUOTE 26 | ( 27 | DOUBLE_QUOTE DOUBLE_QUOTE { escape_count++; } 28 | | {!SQL_MODE_ACTIVE(SQL_MODE_NO_BACKSLASH_ESCAPES)}? => ESCAPE_OPERATOR . { escape_count++; } 29 | | {SQL_MODE_ACTIVE(SQL_MODE_NO_BACKSLASH_ESCAPES)}? => ~(DOUBLE_QUOTE) 30 | | {!SQL_MODE_ACTIVE(SQL_MODE_NO_BACKSLASH_ESCAPES)}?
=> ~(DOUBLE_QUOTE | ESCAPE_OPERATOR) 31 | )* 32 | DOUBLE_QUOTE 33 | { EMIT(); LTOKEN->user1 = escape_count; } 34 | ; 35 | */ 36 | 37 | var StringFormatError error = errors.New("text string format error") 38 | 39 | func (lexer *SQLLexer) getQuotedText() ([]byte, error) { 40 | var dq bool 41 | var sep byte 42 | 43 | if sep = lexer.yyLookHead(); sep == '"' { 44 | dq = true 45 | } 46 | 47 | for lexer.ptr < uint(len(lexer.buf)) { 48 | c := lexer.yyNext() 49 | 50 | if c == '\\' && !lexer.sqlMode.MODE_NO_BACKSLASH_ESCAPES { 51 | if lexer.yyPeek() == EOF { 52 | return nil, StringFormatError 53 | } 54 | 55 | lexer.yySkip() // skip next char 56 | } else if matchQuote(c, dq) { 57 | if matchQuote(lexer.yyPeek(), dq) { 58 | // found a escape quote. Eg. '' "" 59 | lexer.yySkip() // skip for the second quote 60 | continue 61 | } 62 | // we have found the last quote 63 | return lexer.buf[lexer.tok_start:lexer.ptr], nil 64 | } 65 | } 66 | 67 | return nil, StringFormatError 68 | } 69 | 70 | func matchQuote(c byte, double_quote bool) bool { 71 | if double_quote { 72 | return c == '"' 73 | } 74 | 75 | return c == '\'' 76 | } 77 | -------------------------------------------------------------------------------- /sql/lex_text_test.go: -------------------------------------------------------------------------------- 1 | package sql 2 | 3 | import ( 4 | "testing" 5 | ) 6 | 7 | func testTextParse(t *testing.T, str string, mode SQLMode) { 8 | lexer, lval := getLexer(str) 9 | lexer.sqlMode = mode 10 | if r := lexer.Lex(lval); r != TEXT_STRING { 11 | t.Fatalf("parse text failed. 
return[%s]", TokenName(r)) 12 | } 13 | 14 | if string(lval.bytes) != str { 15 | t.Fatalf("orgin[%s] not match parsed[%s]", str, string(lval.bytes)) 16 | } 17 | } 18 | 19 | func TestSingleQuoteString(t *testing.T) { 20 | testMatchReturn(t, `'single Quoted string'`, TEXT_STRING, false) 21 | } 22 | 23 | func TestDoubleQuoteString(t *testing.T) { 24 | testMatchReturn(t, `"double quoted string"`, TEXT_STRING, false) 25 | } 26 | 27 | func TestAnsiQuotesSQLModeString(t *testing.T) { 28 | str := `'a' ' ' 'string'` 29 | lexer, lval := getLexer(str) 30 | lexer.sqlMode.MODE_ANSI_QUOTES = true 31 | 32 | if lexer.Lex(lval) != TEXT_STRING { 33 | t.Fatalf("parse ansi quotes string failed!") 34 | } 35 | 36 | } 37 | 38 | func TestSingleQuoteString3(t *testing.T) { 39 | testTextParse(t, `'afasgasdgasg'`, SQLMode{}) 40 | testTextParse(t, `'''afasgasdgasg'`, SQLMode{}) 41 | testTextParse(t, `''`, SQLMode{}) 42 | testTextParse(t, `""`, SQLMode{}) 43 | 44 | testTextParse(t, `'""hello""'`, SQLMode{}) 45 | testTextParse(t, `'hel''lo'`, SQLMode{}) 46 | testTextParse(t, `'\'hello'`, SQLMode{}) 47 | 48 | testTextParse(t, `'\''`, SQLMode{}) 49 | testTextParse(t, `'\'`, SQLMode{MODE_NO_BACKSLASH_ESCAPES: true}) 50 | } 51 | 52 | func TestStringException(t *testing.T) { 53 | str := `'\'` 54 | lexer, lval := getLexer(str) 55 | if r := lexer.Lex(lval); r != ABORT_SYM { 56 | t.Fatalf("parse text failed. return[%s]", MySQLToknames[r-ABORT_SYM]) 57 | } 58 | 59 | lexer, lval = getLexer(`"\`) 60 | if r := lexer.Lex(lval); r != ABORT_SYM { 61 | t.Fatalf("parse text failed. 
return[%s]", MySQLToknames[r-ABORT_SYM]) 62 | } 63 | } 64 | 65 | func TestNChar(t *testing.T) { 66 | testMatchReturn(t, `n'some text'`, NCHAR_STRING, false) 67 | testMatchReturn(t, `N'some text'`, NCHAR_STRING, false) 68 | 69 | testMatchReturn(t, `N'`, ABORT_SYM, false) 70 | } 71 | 72 | func lexExpect(t *testing.T, lexer *SQLLexer, lval *MySQLSymType, expect int) { 73 | if ret := lexer.Lex(lval); ret != expect { 74 | t.Fatalf("expect[%s] return[%s]", TokenName(expect), TokenName(ret)) 75 | } 76 | } 77 | 78 | func lvalExpect(t *testing.T, lval *MySQLSymType, expect string) { 79 | if string(lval.bytes) != expect { 80 | t.Fatalf("expect[%s] return[%s]", expect, string(lval.bytes)) 81 | } 82 | } 83 | 84 | func TestMultiString(t *testing.T) { 85 | str := `"string1" 'string2' 'string3' n'string 4' ` 86 | lex, lval := getLexer(str) 87 | 88 | lexExpect(t, lex, lval, TEXT_STRING) 89 | lvalExpect(t, lval, `"string1"`) 90 | 91 | lexExpect(t, lex, lval, TEXT_STRING) 92 | lvalExpect(t, lval, `'string2'`) 93 | 94 | lexExpect(t, lex, lval, TEXT_STRING) 95 | lvalExpect(t, lval, `'string3'`) 96 | 97 | lexExpect(t, lex, lval, NCHAR_STRING) 98 | lvalExpect(t, lval, `n'string 4'`) 99 | 100 | lexExpect(t, lex, lval, END_OF_INPUT) 101 | } 102 | -------------------------------------------------------------------------------- /sql/lex_var_test.go: -------------------------------------------------------------------------------- 1 | package sql 2 | 3 | import ( 4 | "testing" 5 | ) 6 | 7 | func TestHostName(t *testing.T) { 8 | // testMatchReturn(t, `user@hostname`, LEX_HOSTNAME, true) 9 | } 10 | 11 | func TestSystemVariables(t *testing.T) { 12 | lexer, lval := testMatchReturn(t, `@@uservar`, '@', false) 13 | ret := lexer.Lex(lval) 14 | if ret != '@' { 15 | t.Fatalf("expect[IDENT_QUOTED] unexpect %s", TokenName(ret)) 16 | } 17 | 18 | ret = lexer.Lex(lval) 19 | if ret != IDENT { 20 | t.Fatalf("expect[IDENT] unexpect %s", TokenName(ret)) 21 | } 22 | } 23 | 24 | func TestUserDefinedVariables(t 
*testing.T) { 25 | lexer, lval := testMatchReturn(t, "@`uservar`", '@', false) 26 | ret := lexer.Lex(lval) 27 | if ret != IDENT_QUOTED { 28 | t.Fatalf("expect[IDENT_QUOTED] unexpect %s", TokenName(ret)) 29 | } 30 | } 31 | -------------------------------------------------------------------------------- /sql/parser.go: -------------------------------------------------------------------------------- 1 | package sql 2 | 3 | import ( 4 | "errors" 5 | ) 6 | 7 | func Parse(sql string) (IStatement, error) { 8 | lexer := NewSQLLexer(sql) 9 | if MySQLParse(lexer) != 0 { 10 | return nil, errors.New(lexer.LastError) 11 | } 12 | 13 | return lexer.ParseTree, nil 14 | } 15 | -------------------------------------------------------------------------------- /sql/parser_dal_test.go: -------------------------------------------------------------------------------- 1 | package sql 2 | 3 | import ( 4 | "testing" 5 | ) 6 | 7 | func TestSet(t *testing.T) { 8 | var st IStatement 9 | 10 | st = testParse(`set global autocommit = 1`, t, false) 11 | matchType(t, st, &Set{}) 12 | 13 | st = testParse(`set global autocommit = 1, sysvar = 2`, t, false) 14 | set := st.(*Set) 15 | 16 | v := set.VarList[0] 17 | if v.Life != Life_Global { 18 | t.Fatal("missed life type") 19 | } 20 | 21 | if v.Name != "autocommit" { 22 | t.Fatal("missed varname") 23 | } 24 | 25 | v = set.VarList[1] 26 | if v.Life != Life_Global { 27 | t.Fatal("missed life type") 28 | } 29 | 30 | if v.Name != "sysvar" { 31 | t.Fatal("missed varname") 32 | } 33 | 34 | } 35 | 36 | func TestShow(t *testing.T) { 37 | var st IStatement 38 | 39 | st = testParse(`show session variables like 'autocommit'`, t, false) 40 | matchType(t, st, &ShowVariables{}) 41 | 42 | st = testParse(`show full tables in test`, t, false) 43 | matchSchemas(t, st, `test`) 44 | 45 | st = testParse(`show table status in test`, t, false) 46 | matchType(t, st, &ShowTableStatus{}) 47 | matchSchemas(t, st, `test`) 48 | 49 | st = testParse(`show global status`, t, false) 50 
| matchType(t, st, &ShowStatus{}) 51 | 52 | st = testParse(`SHOW SLAVE STATUS`, t, false) 53 | matchType(t, st, &ShowSlaveStatus{}) 54 | 55 | st = testParse(`SHOW SLAVE HOSTS`, t, false) 56 | matchType(t, st, &ShowSlaveHosts{}) 57 | 58 | st = testParse(`SHOW Profiles`, t, false) 59 | matchType(t, st, &ShowProfiles{}) 60 | 61 | st = testParse(`SHOW FULL PROCESSLIST`, t, false) 62 | matchType(t, st, &ShowProcessList{}) 63 | 64 | st = testParse(`SHOW PLUGINS`, t, false) 65 | matchType(t, st, &ShowPlugins{}) 66 | 67 | st = testParse(`SHOW PRIVILEGES`, t, false) 68 | matchType(t, st, &ShowPrivileges{}) 69 | 70 | st = testParse(`SHOW OPEN TABLES IN test like 'tables_%'`, t, false) 71 | matchType(t, st, &ShowOpenTables{}) 72 | matchSchemas(t, st, `test`) 73 | 74 | st = testParse(`SHOW MASTER STATUS`, t, false) 75 | matchType(t, st, &ShowMasterStatus{}) 76 | 77 | st = testParse(`SHOW INDEX FROM mytable FROM mydb;`, t, false) 78 | matchType(t, st, &ShowIndex{}) 79 | matchSchemas(t, st, `mydb`) 80 | 81 | st = testParse(`SHOW GRANTS FOR 'root'@'localhost';`, t, false) 82 | matchType(t, st, &ShowGrants{}) 83 | 84 | st = testParse(`SHOW FUNCTION STATUS`, t, false) 85 | matchType(t, st, &ShowFunction{}) 86 | 87 | st = testParse(`SHOW FUNCTION CODE dbname.func_name`, t, false) 88 | matchSchemas(t, st, `dbname`) 89 | 90 | st = testParse(`SHOW EVENTS FROM test;`, t, false) 91 | matchSchemas(t, st, `test`) 92 | 93 | st = testParse(`SHOW ERRORS`, t, false) 94 | matchType(t, st, &ShowErrors{}) 95 | 96 | st = testParse(`SHOW COUNT(*) ERRORS`, t, false) 97 | matchType(t, st, &ShowErrors{}) 98 | 99 | st = testParse(`Show STORAGE ENGINES`, t, false) 100 | matchType(t, st, &ShowEngines{}) 101 | 102 | st = testParse(`SHOW ENGINE PERFORMANCE_SCHEMA STATUS`, t, false) 103 | matchType(t, st, &ShowEngines{}) 104 | 105 | st = testParse(`SHOW Databases like '%presale%'`, t, false) 106 | matchType(t, st, &ShowDatabases{}) 107 | 108 | st = testParse(`SHOW CREATE View test.view`, t, false) 109 | 
matchType(t, st, &ShowCreate{}) 110 | matchSchemas(t, st, `test`) 111 | 112 | st = testParse(`SHOW CREATE TRIGGER test.trigger`, t, false) 113 | matchType(t, st, &ShowCreate{}) 114 | matchSchemas(t, st, `test`) 115 | 116 | st = testParse(`SHOW CREATE TABLE test.table`, t, false) 117 | matchType(t, st, &ShowCreate{}) 118 | matchSchemas(t, st, `test`) 119 | 120 | st = testParse(`SHOW CREATE EVENT test.e_daily`, t, false) 121 | matchType(t, st, &ShowCreate{}) 122 | matchSchemas(t, st, `test`) 123 | 124 | st = testParse(`SHOW CREATE PROCEDURE test.simpleproc`, t, false) 125 | matchType(t, st, &ShowCreate{}) 126 | matchSchemas(t, st, `test`) 127 | 128 | st = testParse(`SHOW CREATE DATABASE test`, t, false) 129 | matchType(t, st, &ShowCreateDatabase{}) 130 | matchSchemas(t, st, `test`) 131 | 132 | st = testParse(`SHOW CHARACTER SET LIKE 'latin%';`, t, false) 133 | matchType(t, st, &ShowCharset{}) 134 | 135 | st = testParse(`SHOW COLUMNS FROM mytable FROM mydb;`, t, false) 136 | matchSchemas(t, st, `mydb`) 137 | 138 | st = testParse(`SHOW COLUMNS FROM mydb.mytable;`, t, false) 139 | matchSchemas(t, st, `mydb`) 140 | 141 | st = testParse(`SHOW COLLATION LIKE 'latin1%';`, t, false) 142 | matchType(t, st, &ShowCollation{}) 143 | 144 | st = testParse(`SHOW Binary LOGS;`, t, false) 145 | matchType(t, st, &ShowLogs{}) 146 | 147 | st = testParse(`show binlog events in 'log1' from 123 limit 2, 4`, t, false) 148 | matchType(t, st, &ShowLogEvents{}) 149 | } 150 | 151 | func TestTableMtStmt(t *testing.T) { 152 | st := testParse(`analyze table db1.tb1`, t, false) 153 | matchType(t, st, &Analyze{}) 154 | matchSchemas(t, st, `db1`) 155 | 156 | st = testParse(`CHECK TABLE test.test_table FAST QUICK;`, t, false) 157 | matchType(t, st, &Check{}) 158 | matchSchemas(t, st, `test`) 159 | 160 | st = testParse(`CHECKSUM TABLE test.test_table QUICK;`, t, false) 161 | matchType(t, st, &CheckSum{}) 162 | matchSchemas(t, st, `test`) 163 | 164 | st = testParse(`OPTIMIZE TABLE foo.bar`, t, false) 
165 | matchType(t, st, &Optimize{}) 166 | matchSchemas(t, st, `foo`) 167 | 168 | st = testParse(`REPAIR NO_WRITE_TO_BINLOG TABLE foo.bar quick`, t, false) 169 | matchType(t, st, &Repair{}) 170 | matchSchemas(t, st, `foo`) 171 | 172 | } 173 | 174 | func TestPluginAndUdf(t *testing.T) { 175 | st := testParse(`CREATE AGGREGATE FUNCTION function_name RETURNS DECIMAL SONAME 'shared_library_name'`, t, false) 176 | matchType(t, st, &CreateUDF{}) 177 | 178 | st = testParse(`INSTALL PLUGIN plugin_name SONAME 'shared_library_name'`, t, false) 179 | matchType(t, st, &Install{}) 180 | if _, ok := st.(IPluginAndUdf); !ok { 181 | t.Fatalf("type[%T] is not a instance of IPluginAndUdf", st) 182 | } 183 | 184 | st = testParse(`UNINSTALL PLUGIN plugin_name`, t, false) 185 | matchType(t, st, &Uninstall{}) 186 | if _, ok := st.(IPluginAndUdf); !ok { 187 | t.Fatalf("type[%T] is not a instance of IPluginAndUdf", st) 188 | } 189 | } 190 | 191 | func TestAccountMgrStmt(t *testing.T) { 192 | st := testParse(`ALTER USER 'jeffrey'@'localhost' PASSWORD EXPIRE;`, t, false) 193 | matchType(t, st, &AlterUser{}) 194 | 195 | st = testParse(`CREATE USER 'jeffrey'@'localhost' IDENTIFIED BY 'mypass';`, t, false) 196 | matchType(t, st, &CreateUser{}) 197 | 198 | st = testParse(`DROP USER 'jeffrey'@'localhost';`, t, false) 199 | matchType(t, st, &DropUser{}) 200 | 201 | st = testParse(`GRANT SELECT ON db2.invoice TO 'jeffrey'@'localhost';`, t, false) 202 | matchType(t, st, &Grant{}) 203 | 204 | st = testParse(`RENAME USER 'jeffrey'@'localhost' TO 'jeff'@'127.0.0.1';`, t, false) 205 | matchType(t, st, &RenameUser{}) 206 | 207 | st = testParse(`REVOKE INSERT ON *.* FROM 'jeffrey'@'localhost';`, t, false) 208 | matchType(t, st, &Revoke{}) 209 | 210 | st = testParse(`SET PASSWORD FOR 'jeffrey'@'localhost' = PASSWORD('cleartext password');`, t, false) 211 | // matchType(t, st, &SetPassword{}) 212 | } 213 | 214 | func TestBinlog(t *testing.T) { 215 | st := testParse(`BINLOG 'str'`, t, false) 216 | 
matchType(t, st, &Binlog{}) 217 | } 218 | 219 | func TestCacheIndex(t *testing.T) { 220 | st := testParse(`CACHE INDEX d1.t1, d2.t2, d3.t3 IN hot_cache;`, t, false) 221 | matchType(t, st, &CacheIndex{}) 222 | matchSchemas(t, st, `d1`, `d2`, `d3`) 223 | 224 | st = testParse(`LOAD INDEX INTO CACHE pt PARTITION (p1, p3);`, t, false) 225 | matchType(t, st, &LoadIndex{}) 226 | matchSchemas(t, st) 227 | 228 | st = testParse(`LOAD INDEX INTO CACHE db1.t1, db2.t2 IGNORE LEAVES;`, t, false) 229 | matchSchemas(t, st, `db1`, `db2`) 230 | } 231 | 232 | func TestFlush(t *testing.T) { 233 | st := testParse(`FLUSH TABLES db1.tbl_name , db2.tbl_name WITH READ LOCK`, t, false) 234 | matchType(t, st, &FlushTables{}) 235 | matchSchemas(t, st, `db1`, `db2`) 236 | 237 | st = testParse(`flush logs`, t, false) 238 | matchType(t, st, &Flush{}) 239 | } 240 | 241 | func TestKill(t *testing.T) { 242 | st := testParse(`kill connection 1234`, t, false) 243 | matchType(t, st, &Kill{}) 244 | } 245 | 246 | func TestReset(t *testing.T) { 247 | st := testParse(`reset master, query cache, slave`, t, false) 248 | matchType(t, st, &Reset{}) 249 | } 250 | -------------------------------------------------------------------------------- /sql/parser_ddl_test.go: -------------------------------------------------------------------------------- 1 | package sql 2 | 3 | import ( 4 | "testing" 5 | ) 6 | 7 | func TestAlter(t *testing.T) { 8 | st := testParse(`alter view d1.v1 as select * from t2;`, t, false) 9 | matchSchemas(t, st, `d1`) 10 | 11 | st = testParse( 12 | `ALTER EVENT myschema.myevent 13 | ON SCHEDULE 14 | AT CURRENT_TIMESTAMP + INTERVAL 1 DAY 15 | DO 16 | TRUNCATE TABLE myschema.mytable;`, t, false) 17 | matchSchemas(t, st, `myschema`) 18 | 19 | st = testParse(`ALTER EVENT olddb.myevent RENAME TO newdb.myevent;`, t, false) 20 | matchSchemas(t, st, `olddb`, `newdb`) 21 | 22 | st = testParse(`ALTER SERVER s OPTIONS (USER 'sally');`, t, false) 23 | 24 | } 25 | 26 | func TestCreate(t *testing.T) { 27 | 
st := testParse(`CREATE DATABASE IF NOT EXISTS my_db default charset utf8 COLLATE utf8_general_ci;`, t, false) 28 | 29 | st = testParse(`CREATE EVENT mydb.myevent 30 | ON SCHEDULE AT CURRENT_TIMESTAMP + INTERVAL 1 HOUR 31 | DO 32 | UPDATE myschema.mytable SET mycol = mycol + 1;`, t, false) 33 | matchSchemas(t, st, `mydb`) 34 | 35 | st = testParse(`CREATE FUNCTION thisdb.hello (s CHAR(20)) RETURNS CHAR(50) DETERMINISTIC RETURN CONCAT('Hello, ',s,'!');`, t, false) 36 | matchSchemas(t, st, `thisdb`) 37 | 38 | st = testParse( 39 | `CREATE DEFINER = 'admin'@'localhost' PROCEDURE db1.account_count() 40 | SQL SECURITY INVOKER 41 | BEGIN 42 | SELECT 'Number of accounts:', COUNT(*) FROM mysql.user; 43 | END;`, t, false) 44 | matchSchemas(t, st, `db1`) 45 | 46 | st = testParse(`CREATE INDEX part_of_name ON customer (name(10));`, t, false) 47 | st = testParse(`CREATE INDEX id_index ON lookup (id) USING BTREE;`, t, false) 48 | st = testParse(`CREATE INDEX id_index ON t1 (id) COMMENT 'MERGE_THRESHOLD=40';`, t, false) 49 | 50 | st = testParse( 51 | `CREATE SERVER s FOREIGN DATA WRAPPER mysql 52 | OPTIONS (USER 'Remote', HOST '192.168.1.106', DATABASE 'test');`, t, false) 53 | } 54 | 55 | func TestCreateTable(t *testing.T) { 56 | st := testParse(`CREATE TABLE db1.t1 (col1 INT, col2 CHAR(5)) 57 | PARTITION BY HASH(col1);`, t, false) 58 | matchSchemas(t, st, `db1`) 59 | 60 | testParse(`CREATE TABLE t1 (col1 INT, col2 CHAR(5), col3 DATETIME) 61 | PARTITION BY HASH ( YEAR(col3) );`, t, false) 62 | testParse(`CREATE /*!32302 TEMPORARY */ TABLE t (a INT);`, t, false) 63 | 64 | testParse(`SELECT /*! 
STRAIGHT_JOIN */ col1 FROM table1,table2`, t, false) 65 | } 66 | 67 | func TestDrop(t *testing.T) { 68 | st := testParse(`DROP EVENT IF EXISTS db1.event_name`, t, false) 69 | matchSchemas(t, st, `db1`) 70 | 71 | st = testParse(`Drop Procedure If exists db1.sp_name`, t, false) 72 | matchSchemas(t, st, `db1`) 73 | 74 | st = testParse("DROP INDEX `PRIMARY` ON db1.t1;", t, false) 75 | matchSchemas(t, st, `db1`) 76 | 77 | testParse("Drop server if exists server_name", t, false) 78 | 79 | st = testParse("DROP TABLE IF EXISTS B.B, C.C, A.A;", t, false) 80 | matchSchemas(t, st, `B`, `C`, `A`) 81 | 82 | st = testParse("DROP TRIGGER schema_name.trigger_name;", t, false) 83 | matchSchemas(t, st, `schema_name`) 84 | } 85 | 86 | func TestOthers(t *testing.T) { 87 | st := testParse(`Truncate db1.table1`, t, false) 88 | matchSchemas(t, st, `db1`) 89 | 90 | testParse(`RENAME TABLE current_db.tbl_name TO other_db.tbl_name;`, t, false) 91 | } 92 | -------------------------------------------------------------------------------- /sql/parser_dml_test.go: -------------------------------------------------------------------------------- 1 | package sql 2 | 3 | import ( 4 | "fmt" 5 | "reflect" 6 | "testing" 7 | ) 8 | 9 | func fmtimport() { 10 | fmt.Println() 11 | } 12 | 13 | func matchType(t *testing.T, st IStatement, ref interface{}) { 14 | if reflect.TypeOf(st) != reflect.TypeOf(ref) { 15 | t.Fatalf("expect type[%v] not match[%v]", reflect.TypeOf(ref), reflect.TypeOf(st)) 16 | } 17 | } 18 | 19 | func matchSchemas(t *testing.T, st IStatement, tables ...string) { 20 | var ts []string 21 | 22 | switch ast := st.(type) { 23 | case *Select: 24 | ts = ast.GetSchemas() 25 | case *Union: 26 | ts = ast.GetSchemas() 27 | case *Insert: 28 | ts = ast.GetSchemas() 29 | case *Delete: 30 | ts = ast.GetSchemas() 31 | case *Update: 32 | ts = ast.GetSchemas() 33 | case *Replace: 34 | ts = ast.GetSchemas() 35 | case *AlterView: 36 | ts = ast.GetSchemas() 37 | case IDDLSchemas: 38 | ts = ast.GetSchemas() 39 
| case *Lock: 40 | ts = ast.GetSchemas() 41 | case *DescribeTable: 42 | ts = ast.GetSchemas() 43 | case *DescribeStmt: 44 | ts = ast.GetSchemas() 45 | case ITableMtStmt: 46 | ts = ast.GetSchemas() 47 | case *CacheIndex: 48 | ts = ast.GetSchemas() 49 | case *LoadIndex: 50 | ts = ast.GetSchemas() 51 | case *FlushTables: 52 | ts = ast.GetSchemas() 53 | case IShowSchemas: 54 | ts = ast.GetSchemas() 55 | default: 56 | t.Fatalf("unknow statement type: %T", ast) 57 | } 58 | 59 | if len(tables) == 0 && len(ts) == 0 { 60 | return 61 | } else if len(tables) != len(ts) { 62 | t.Fatalf("expect table number[%d] not match return[%d]", len(tables), len(ts)) 63 | } 64 | 65 | for k, v := range ts { 66 | if v != tables[k] { 67 | t.Fatalf("expect table[%s] not match return[%s]", tables[k], v) 68 | } 69 | } 70 | 71 | } 72 | 73 | func TestSelect(t *testing.T) { 74 | st := testParse("SELECT * FROM table1;", t, false) 75 | matchSchemas(t, st) 76 | 77 | st = testParse("SELECT t1.* FROM (select * from db1.table1) as t1;", t, false) 78 | matchSchemas(t, st, "db1") 79 | 80 | st = testParse("SELECT sb1,sb2,sb3 \n FROM (SELECT s1 AS sb1, s2 AS sb2, s3*2 AS sb3 FROM db1.t1) AS sb \n WHERE sb1 > 1;", t, false) 81 | matchSchemas(t, st, "db1") 82 | 83 | st = testParse("SELECT AVG(SUM(column1)) FROM t1 GROUP BY column1;", t, false) 84 | matchSchemas(t, st) 85 | 86 | st = testParse("SELECT REPEAT('a',1) UNION SELECT REPEAT('b',10);", t, false) 87 | matchSchemas(t, st) 88 | 89 | st = testParse(`(SELECT a FROM db1.t1 WHERE a=10 AND B=1 ORDER BY a LIMIT 10) 90 | UNION 91 | (SELECT a FROM db2.t2 WHERE a=11 AND B=2 ORDER BY a LIMIT 10);`, t, false) 92 | matchSchemas(t, st, "db1", "db2") 93 | 94 | st = testParse(`SELECT funcs(s) 95 | FROM db1.table1 96 | LEFT OUTER JOIN db2.table2 97 | ON db1.table1.column_name=db2.table2.column_name;`, t, false) 98 | matchSchemas(t, st, "db1", "db2") 99 | 100 | st = testParse("SELECT * FROM db1.table1 LEFT JOIN db2.table2 ON table1.id=table2.id LEFT JOIN db3.table3 ON 
table2.id = table3.id for update", t, false) 101 | matchSchemas(t, st, "db1", "db2", "db3") 102 | 103 | if st.(*Select).LockType != LockType_ForUpdate { 104 | t.Fatalf("lock type is not For Update") 105 | } 106 | 107 | st = testParse(`select last_insert_id() as a`, t, false) 108 | st = testParse(`SELECT substr('''a''bc',0,3) FROM dual`, t, false) 109 | testParse(`SELECT /*mark for picman*/ * FROM filterd limit 1;`, t, false) 110 | 111 | testParse(`SELECT ?,?,? from t1;`, t, false) 112 | } 113 | 114 | func TestInsert(t *testing.T) { 115 | st := testParse(`INSERT INTO db1.tbl_temp2 (fld_id) 116 | SELECT tempdb.tbl_temp1.fld_order_id 117 | FROM tempdb.tbl_temp1 WHERE tbl_temp1.fld_order_id > 100;`, t, false) 118 | matchSchemas(t, st, "db1", "tempdb") 119 | } 120 | 121 | func TestUpdate(t *testing.T) { 122 | st := testParse(`UPDATE t1 SET col1 = col1 + 1, col2 = col1;`, t, false) 123 | matchSchemas(t, st) 124 | 125 | st = testParse("UPDATE `Table A`,`Table B` SET `Table A`.`text`=concat_ws('',`Table A`.`text`,`Table B`.`B-num`,\" from \",`Table B`.`date`,'/') WHERE `Table A`.`A-num` = `Table B`.`A-num`", t, false) 126 | matchSchemas(t, st) 127 | 128 | st = testParse(`UPDATE db1.items,db2.month SET items.price=month.price 129 | WHERE items.id=month.id;`, t, false) 130 | matchSchemas(t, st, "db1", "db2") 131 | } 132 | 133 | func TestDelete(t *testing.T) { 134 | st := testParse(`DELETE FROM db.somelog WHERE user = 'jcole' 135 | ORDER BY timestamp_column LIMIT 1;`, t, false) 136 | matchSchemas(t, st, "db") 137 | 138 | st = testParse(`DELETE FROM db1.t1, db2.t2 USING t1 INNER JOIN t2 INNER JOIN db3.t3 139 | WHERE t1.id=t2.id AND t2.id=t3.id;`, t, false) 140 | matchSchemas(t, st, "db1", "db2", "db3") 141 | 142 | st = testParse(`DELETE FROM a1, a2 USING db1.t1 AS a1 INNER JOIN t2 AS a2 143 | WHERE a1.id=a2.id;`, t, false) 144 | matchSchemas(t, st, "db1") 145 | } 146 | 147 | func TestReplace(t *testing.T) { 148 | st := testParse(`REPLACE INTO test2 VALUES (1, 'Old', 
'2014-08-20 18:47:00');`, t, false) 149 | matchSchemas(t, st) 150 | 151 | st = testParse(`REPLACE INTO dbname2.test2 VALUES (1, 'Old', '2014-08-20 18:47:00');`, t, false) 152 | matchSchemas(t, st, "dbname2") 153 | } 154 | -------------------------------------------------------------------------------- /sql/parser_test.go: -------------------------------------------------------------------------------- 1 | package sql 2 | 3 | import ( 4 | "fmt" 5 | "testing" 6 | ) 7 | 8 | func PrintTree(statement IStatement) { 9 | if statement == nil { 10 | fmt.Println(`(nil)`) 11 | } 12 | 13 | switch st := statement.(type) { 14 | case *Union: 15 | fmt.Printf("left: %+v right: %+v\n", st.Right) 16 | case *Select: 17 | fmt.Printf("From: %+v Lock: %+v\n", st.From, st.LockType) 18 | default: 19 | fmt.Println("Yet Unknow Statement:", st) 20 | } 21 | } 22 | 23 | func testParse(sql string, t *testing.T, dbg bool) IStatement { 24 | setDebug(dbg) 25 | if st, err := Parse(sql); err != nil { 26 | setDebug(false) 27 | t.Fatalf("%v", err) 28 | return nil 29 | } else { 30 | setDebug(false) 31 | return st 32 | } 33 | } 34 | 35 | func TestExplain(t *testing.T) { 36 | testParse("EXPLAIN SELECT f1(5)", t, false) 37 | testParse("EXPLAIN SELECT * FROM t1 AS a1, (SELECT BENCHMARK(1000000, MD5(NOW())));", t, false) 38 | } 39 | 40 | func TestParse(t *testing.T) { 41 | setDebug(false) 42 | if _, err := Parse("Select version()"); err != nil { 43 | t.Fatalf("%v", err) 44 | } 45 | } 46 | 47 | func TestTokenName(t *testing.T) { 48 | if name := MySQLTokname(ABORT_SYM); name == "" { 49 | t.Fatal("get token name error") 50 | } 51 | } 52 | -------------------------------------------------------------------------------- /sql/parser_token.go: -------------------------------------------------------------------------------- 1 | package sql 2 | 3 | import ( 4 | "fmt" 5 | ) 6 | 7 | func TokenName(tok int) string { 8 | 9 | if tok == 0 { 10 | return "EOF" 11 | } 12 | 13 | if tok > 31 && tok < 126 { 14 | return 
fmt.Sprintf("%c", tok) 15 | } 16 | 17 | if (tok-ABORT_SYM) < 0 || (tok-ABORT_SYM) > len(MySQLToknames) { 18 | return fmt.Sprintf("Unknown Token:%d", tok) 19 | } 20 | 21 | return MySQLToknames[tok-ABORT_SYM] 22 | } 23 | -------------------------------------------------------------------------------- /sql/parser_trans_test.go: -------------------------------------------------------------------------------- 1 | package sql 2 | 3 | import ( 4 | "testing" 5 | ) 6 | 7 | func TestTransaction(t *testing.T) { 8 | st := testParse(`Start Transaction WITH CONSISTENT SNAPSHOT`, t, false) 9 | matchType(t, st, &StartTrans{}) 10 | 11 | st = testParse(`BEGIN`, t, false) 12 | matchType(t, st, &Begin{}) 13 | 14 | st = testParse(`COMMIT WORk NO RELEASE`, t, false) 15 | matchType(t, st, &Commit{}) 16 | 17 | st = testParse(`rollback`, t, false) 18 | matchType(t, st, &Rollback{}) 19 | } 20 | 21 | func TestSavePoint(t *testing.T) { 22 | st := testParse(`Savepoint identifier`, t, false) 23 | matchType(t, st, &SavePoint{}) 24 | 25 | st = testParse(`rollback to identifier`, t, false) 26 | matchType(t, st, &Rollback{}) 27 | 28 | st = testParse(`release savepoint identifier`, t, false) 29 | matchType(t, st, &Release{}) 30 | } 31 | 32 | func TestLockTables(t *testing.T) { 33 | st := testParse(`LOCK TABLES tb1 AS alias1 read, db2.tb2 low_priority write`, t, false) 34 | matchType(t, st, &Lock{}) 35 | matchSchemas(t, st, `db2`) 36 | 37 | st = testParse(`UNLOCK TABLES`, t, false) 38 | matchType(t, st, &Unlock{}) 39 | } 40 | -------------------------------------------------------------------------------- /sql/parser_util_test.go: -------------------------------------------------------------------------------- 1 | package sql 2 | 3 | import ( 4 | "testing" 5 | ) 6 | 7 | func TestDesc(t *testing.T) { 8 | st := testParse(` DESCRIBE db1.tb1;`, t, false) 9 | matchType(t, st, &DescribeTable{}) 10 | matchSchemas(t, st, `db1`) 11 | 12 | st = testParse(`explain select * from db1.table1`, t, false) 13 | 
// MY_LEX_* enumerate the lexer states of this package. They are untyped
// integer constants numbered consecutively from 0 (MY_LEX_START) by iota,
// so inserting or reordering entries changes the numeric value of every
// constant that follows. statusMap holds the printable name of each state.
const (
	MY_LEX_START = iota
	MY_LEX_CHAR
	MY_LEX_IDENT
	MY_LEX_IDENT_SEP
	MY_LEX_IDENT_START
	MY_LEX_REAL
	MY_LEX_HEX_NUMBER
	MY_LEX_BIN_NUMBER
	MY_LEX_CMP_OP
	MY_LEX_LONG_CMP_OP
	MY_LEX_STRING
	MY_LEX_COMMENT
	MY_LEX_END
	MY_LEX_OPERATOR_OR_IDENT
	MY_LEX_NUMBER_IDENT
	MY_LEX_INT_OR_REAL
	MY_LEX_REAL_OR_POINT
	MY_LEX_BOOL
	MY_LEX_EOL
	MY_LEX_ESCAPE
	MY_LEX_LONG_COMMENT
	MY_LEX_END_LONG_COMMENT
	MY_LEX_SEMICOLON
	MY_LEX_SET_VAR
	MY_LEX_USER_END
	MY_LEX_HOSTNAME
	MY_LEX_SKIP
	MY_LEX_USER_VARIABLE_DELIMITER
	MY_LEX_SYSTEM_VAR
	MY_LEX_IDENT_OR_KEYWORD
	MY_LEX_IDENT_OR_HEX
	MY_LEX_IDENT_OR_BIN
	MY_LEX_IDENT_OR_NCHAR
	MY_LEX_STRING_OR_DELIMITER
)
MY_LEX_OPERATOR_OR_IDENT: "MY_LEX_OPERATOR_OR_IDENT", 58 | MY_LEX_NUMBER_IDENT: "MY_LEX_NUMBER_IDENT", 59 | MY_LEX_INT_OR_REAL: "MY_LEX_INT_OR_REAL", 60 | MY_LEX_REAL_OR_POINT: "MY_LEX_REAL_OR_POINT", 61 | MY_LEX_BOOL: "MY_LEX_BOOL", 62 | MY_LEX_EOL: "MY_LEX_EOL", 63 | MY_LEX_ESCAPE: "MY_LEX_ESCAPE", 64 | MY_LEX_LONG_COMMENT: "MY_LEX_LONG_COMMENT", 65 | MY_LEX_END_LONG_COMMENT: "MY_LEX_END_LONG_COMMENT", 66 | MY_LEX_SEMICOLON: "MY_LEX_SEMICOLON", 67 | MY_LEX_SET_VAR: "MY_LEX_SET_VAR", 68 | MY_LEX_USER_END: "MY_LEX_USER_END", 69 | MY_LEX_HOSTNAME: "MY_LEX_HOSTNAME", 70 | MY_LEX_SKIP: "MY_LEX_SKIP", 71 | MY_LEX_USER_VARIABLE_DELIMITER: "MY_LEX_USER_VARIABLE_DELIMITER", 72 | MY_LEX_SYSTEM_VAR: "MY_LEX_SYSTEM_VAR", 73 | MY_LEX_IDENT_OR_KEYWORD: "MY_LEX_IDENT_OR_KEYWORD", 74 | MY_LEX_IDENT_OR_HEX: "MY_LEX_IDENT_OR_HEX", 75 | MY_LEX_IDENT_OR_BIN: "MY_LEX_IDENT_OR_BIN", 76 | MY_LEX_IDENT_OR_NCHAR: "MY_LEX_IDENT_OR_NCHAR", 77 | MY_LEX_STRING_OR_DELIMITER: "MY_LEX_STRING_OR_DELIMITER", 78 | } 79 | 80 | func GetLexStatus(which uint) string { 81 | if v, ok := statusMap[which]; ok { 82 | return v 83 | } 84 | 85 | return fmt.Sprint("Unknow Status[%d]", which) 86 | } 87 | -------------------------------------------------------------------------------- /sql/test.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | go test -coverprofile=$HOME/cover/coverage.out 4 | cd ~/cover/ && go tool cover -html=coverage.out -o coverage.html 5 | cp coverage.html /mnt/hgfs/ubuntu 6 | cd - 7 | -------------------------------------------------------------------------------- /wercker.yml: -------------------------------------------------------------------------------- 1 | # This references the default golang container from 2 | # the Docker Hub: https://registry.hub.docker.com/u/library/golang/ 3 | # If you want Google's container you would reference google/golang 4 | # Read more about containers on our dev center 5 | # 
http://devcenter.wercker.com/docs/containers/index.html 6 | box: golang 7 | # This is the build pipeline. Pipelines are the core of wercker 8 | # Read more about pipelines on our dev center 9 | # http://devcenter.wercker.com/docs/pipelines/index.html 10 | 11 | # You can also use services such as databases. Read more on our dev center: 12 | # http://devcenter.wercker.com/docs/services/index.html 13 | # services: 14 | # - postgres 15 | # http://devcenter.wercker.com/docs/services/postgresql.html 16 | 17 | # - mongodb 18 | # http://devcenter.wercker.com/docs/services/mongodb.html 19 | build: 20 | # The steps that will be executed on build 21 | # Steps make up the actions in your pipeline 22 | # Read more about steps on our dev center: 23 | # http://devcenter.wercker.com/docs/steps/index.html 24 | steps: 25 | # Sets the go workspace and places your package 26 | # at the right place in the workspace tree 27 | - setup-go-workspace 28 | 29 | # Gets the dependencies 30 | - script: 31 | name: hello world 32 | code: | 33 | echo "Hello world!" 34 | 35 | test: 36 | steps: 37 | # Build the project 38 | - script: 39 | name: hello test 40 | code: | 41 | echo "Hello test!" 42 | 43 | 44 | parallel-test: 45 | steps: 46 | # Build the project 47 | - script: 48 | name: hello test 49 | code: | 50 | echo "Hello test!" 51 | 52 | deploy: 53 | steps: 54 | # Build the project 55 | - script: 56 | name: hello test 57 | code: | 58 | echo "Hello test!" 59 | 60 | deploy-test: 61 | steps: 62 | # Build the project 63 | - script: 64 | name: hello test 65 | code: | 66 | echo "Hello test!" 67 | 68 | 69 | --------------------------------------------------------------------------------