├── .gitignore ├── makefile ├── mysql ├── result.go ├── error.go ├── resultset_sort_test.go ├── packetio.go ├── resultset_sort.go ├── field.go ├── const.go ├── util.go └── resultset.go ├── bootstrap.sh ├── sqlparser ├── Makefile ├── sql_test.go ├── analyzer_test.go ├── parsed_query.go ├── tracked_buffer.go ├── analyzer.go └── parsed_query_test.go ├── dev.env ├── hack ├── hack_test.go └── hack.go ├── LICENSE ├── etc ├── mixer_multi.conf.yaml ├── mixer_single.conf.yaml └── mixer.conf.yaml ├── proxy ├── schema.go ├── conn_set.go ├── conn_tx.go ├── conn_admin.go ├── server.go ├── server_test.go ├── conn_select.go ├── conn_resultset.go ├── conn_show.go ├── conn_stmt_test.go ├── node.go ├── conn.go ├── conn_test.go ├── conn_shard_test.go ├── conn_stmt.go └── conn_query.go ├── license_vitess ├── cmd └── mixer-proxy │ └── main.go ├── config ├── config.go └── config_test.go ├── router ├── router_test.go ├── numkey.go ├── config.go ├── router.go ├── shard.go └── key.go ├── client ├── conn_test.go ├── db.go ├── stmt.go └── stmt_test.go ├── doc └── mysql-proxy │ └── scripting.txt ├── sqltypes └── sqltypes.go └── README.md /.gitignore: -------------------------------------------------------------------------------- 1 | bin 2 | pkg 3 | .DS_Store 4 | y.output 5 | -------------------------------------------------------------------------------- /makefile: -------------------------------------------------------------------------------- 1 | all: build 2 | 3 | build: 4 | go install ./... 5 | 6 | clean: 7 | go clean -i ./... 8 | 9 | test: 10 | go test ./... 
-------------------------------------------------------------------------------- /mysql/result.go: -------------------------------------------------------------------------------- 1 | package mysql 2 | 3 | type Result struct { 4 | Status uint16 5 | 6 | InsertId uint64 7 | AffectedRows uint64 8 | 9 | *Resultset 10 | } 11 | -------------------------------------------------------------------------------- /bootstrap.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | if [ ! -f bootstrap.sh ]; then 4 | echo "bootstrap.sh must be run from its current directory" 1>&2 5 | exit 1 6 | fi 7 | 8 | source ./dev.env 9 | 10 | go get github.com/siddontang/go-log/log 11 | go get github.com/siddontang/go-yaml/yaml -------------------------------------------------------------------------------- /sqlparser/Makefile: -------------------------------------------------------------------------------- 1 | # Copyright 2012, Google Inc. All rights reserved. 2 | # Use of this source code is governed by a BSD-style license that can 3 | # be found in the LICENSE file. 
4 | 5 | MAKEFLAGS = -s 6 | 7 | sql.go: sql.y 8 | go tool yacc -o sql.go sql.y 9 | gofmt -w sql.go 10 | 11 | clean: 12 | rm -f y.output sql.go 13 | -------------------------------------------------------------------------------- /dev.env: -------------------------------------------------------------------------------- 1 | export VTTOP=$(pwd) 2 | export VTROOT="${VTROOT:-${VTTOP/\/src\/github.com\/siddontang\/mixer/}}" 3 | # VTTOP sanity check 4 | if [[ "$VTTOP" == "${VTTOP/\/src\/github.com\/siddontang\/mixer/}" ]]; then 5 | echo "WARNING: VTTOP($VTTOP) does not contain src/github.com/siddontang/mixer" 6 | fi 7 | 8 | export GOTOP=$VTTOP 9 | 10 | function prepend_path() 11 | { 12 | # $1 path variable 13 | # $2 path to add 14 | if [ -d "$2" ] && [[ ":$1:" != *":$2:"* ]]; then 15 | echo "$2:$1" 16 | else 17 | echo "$1" 18 | fi 19 | } 20 | 21 | export GOPATH=$(prepend_path $GOPATH $VTROOT) 22 | -------------------------------------------------------------------------------- /hack/hack_test.go: -------------------------------------------------------------------------------- 1 | package hack 2 | 3 | import ( 4 | "bytes" 5 | "testing" 6 | ) 7 | 8 | func TestString(t *testing.T) { 9 | b := []byte("hello world") 10 | a := String(b) 11 | 12 | if a != "hello world" { 13 | t.Fatal(a) 14 | } 15 | 16 | b[0] = 'a' 17 | 18 | if a != "aello world" { 19 | t.Fatal(a) 20 | } 21 | 22 | b = append(b, "abc"...) 
23 | if a != "aello world" { 24 | t.Fatal(a) 25 | } 26 | } 27 | 28 | func TestByte(t *testing.T) { 29 | a := "hello world" 30 | 31 | b := Slice(a) 32 | 33 | if !bytes.Equal(b, []byte("hello world")) { 34 | t.Fatal(string(b)) 35 | } 36 | } 37 | -------------------------------------------------------------------------------- /hack/hack.go: -------------------------------------------------------------------------------- 1 | package hack 2 | 3 | import ( 4 | "reflect" 5 | "unsafe" 6 | ) 7 | 8 | // no copy to change slice to string 9 | // use your own risk 10 | func String(b []byte) (s string) { 11 | pbytes := (*reflect.SliceHeader)(unsafe.Pointer(&b)) 12 | pstring := (*reflect.StringHeader)(unsafe.Pointer(&s)) 13 | pstring.Data = pbytes.Data 14 | pstring.Len = pbytes.Len 15 | return 16 | } 17 | 18 | // no copy to change string to slice 19 | // use your own risk 20 | func Slice(s string) (b []byte) { 21 | pbytes := (*reflect.SliceHeader)(unsafe.Pointer(&b)) 22 | pstring := (*reflect.StringHeader)(unsafe.Pointer(&s)) 23 | pbytes.Data = pstring.Data 24 | pbytes.Len = pstring.Len 25 | pbytes.Cap = pstring.Len 26 | return 27 | } 28 | -------------------------------------------------------------------------------- /sqlparser/sql_test.go: -------------------------------------------------------------------------------- 1 | package sqlparser 2 | 3 | import ( 4 | "testing" 5 | ) 6 | 7 | func testParse(t *testing.T, sql string) { 8 | _, err := Parse(sql) 9 | if err != nil { 10 | t.Fatal(err) 11 | } 12 | 13 | } 14 | 15 | func TestSet(t *testing.T) { 16 | sql := "set names gbk" 17 | testParse(t, sql) 18 | } 19 | 20 | func TestSimpleSelect(t *testing.T) { 21 | sql := "select last_insert_id() as a" 22 | testParse(t, sql) 23 | } 24 | 25 | func TestMixer(t *testing.T) { 26 | sql := `admin upnode("node1", "master", "127.0.0.1")` 27 | testParse(t, sql) 28 | 29 | sql = "show databases" 30 | testParse(t, sql) 31 | 32 | sql = "show tables from abc" 33 | testParse(t, sql) 34 | 35 | sql = 
"show tables from abc like a" 36 | testParse(t, sql) 37 | 38 | sql = "show tables from abc where a = 1" 39 | testParse(t, sql) 40 | 41 | sql = "show proxy abc" 42 | testParse(t, sql) 43 | } 44 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright (c) 2014 siddontang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy of 6 | this software and associated documentation files (the "Software"), to deal in 7 | the Software without restriction, including without limitation the rights to 8 | use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of 9 | the Software, and to permit persons to whom the Software is furnished to do so, 10 | subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS 17 | FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR 18 | COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER 19 | IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN 20 | CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
21 | -------------------------------------------------------------------------------- /etc/mixer_multi.conf.yaml: -------------------------------------------------------------------------------- 1 | addr : 127.0.0.1:4000 2 | user : root 3 | password : 4 | log_level : error 5 | 6 | nodes : 7 | - 8 | name : node1 9 | down_after_noalive : 300 10 | idle_conns : 16 11 | rw_split: false 12 | user: root 13 | password: 14 | master : 127.0.0.1:3306 15 | slave : 16 | - 17 | name : node2 18 | down_after_noalive : 300 19 | idle_conns : 16 20 | rw_split: false 21 | user: root 22 | password: 23 | master : 127.0.0.1:3307 24 | 25 | - 26 | name : node3 27 | down_after_noalive : 300 28 | idle_conns : 16 29 | rw_split: false 30 | user: root 31 | password: 32 | master : 127.0.0.1:3308 33 | 34 | schemas : 35 | - 36 | db : mixer 37 | nodes: [node1,node2,node3] 38 | rules: 39 | default: node1 40 | shard: 41 | - 42 | table: mixer_test_shard_hash 43 | key: id 44 | nodes: [node2, node3] 45 | type: hash 46 | 47 | - 48 | table: mixer_test_shard_range 49 | key: id 50 | type: range 51 | nodes: [node2, node3] 52 | range: -10000- 53 | -------------------------------------------------------------------------------- /sqlparser/analyzer_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2012, Google Inc. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file.
4 | 5 | package sqlparser 6 | 7 | import "testing" 8 | 9 | func TestGetDBName(t *testing.T) { 10 | wantYes := []string{ 11 | "insert into a.b values(1)", 12 | "update a.b set c=1", 13 | "delete from a.b where c=d", 14 | } 15 | for _, stmt := range wantYes { 16 | result, err := GetDBName(stmt) 17 | if err != nil { 18 | t.Errorf("error %v on %s", err, stmt) 19 | continue 20 | } 21 | if result != "a" { 22 | t.Errorf("want a, got %s", result) 23 | } 24 | } 25 | 26 | wantNo := []string{ 27 | "insert into a values(1)", 28 | "update a set c=1", 29 | "delete from a where c=d", 30 | } 31 | for _, stmt := range wantNo { 32 | result, err := GetDBName(stmt) 33 | if err != nil { 34 | t.Errorf("error %v on %s", err, stmt) 35 | continue 36 | } 37 | if result != "" { 38 | t.Errorf("want '', got %s", result) 39 | } 40 | } 41 | 42 | wantErr := []string{ 43 | "select * from a", 44 | "syntax error", 45 | } 46 | for _, stmt := range wantErr { 47 | _, err := GetDBName(stmt) 48 | if err == nil { 49 | t.Errorf("want error, got nil") 50 | } 51 | } 52 | } 53 | -------------------------------------------------------------------------------- /etc/mixer_single.conf.yaml: -------------------------------------------------------------------------------- 1 | # Mixer configuration, using yaml format 2 | 3 | # server listen addr 4 | addr : 127.0.0.1:4000 5 | 6 | # server user and password 7 | user : root 8 | password : 9 | 10 | # log level[debug|info|warn|error],default error 11 | log_level : error 12 | 13 | # node is an agenda for real remote mysql server. 
14 | nodes : 15 | - 16 | name : node1 17 | 18 | # default max idle conns for mysql server 19 | idle_conns : 16 20 | 21 | # if rw_split is true, select will use slave server 22 | rw_split: true 23 | 24 | # all mysql in a node must have the same user and password 25 | user : root 26 | password : 27 | 28 | # master represents a real mysql master server 29 | master : 127.0.0.1:3306 30 | 31 | # slave represents a real mysql salve server 32 | slave : 33 | 34 | # down mysql after N seconds noalive 35 | # 0 will no down 36 | down_after_noalive : 0 37 | 38 | 39 | 40 | # schema defines which db can be used by client and this db's sql will be executed in which nodes 41 | schemas : 42 | - 43 | db : mixer 44 | nodes: [node1] 45 | rules: 46 | # any other table not set above will use default [node1] 47 | default: node1 48 | shard: 49 | # empty shard tables rule 50 | - -------------------------------------------------------------------------------- /mysql/error.go: -------------------------------------------------------------------------------- 1 | package mysql 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | ) 7 | 8 | var ( 9 | ErrBadConn = errors.New("connection was bad") 10 | ErrMalformPacket = errors.New("Malform packet error") 11 | 12 | ErrTxDone = errors.New("sql: Transaction has already been committed or rolled back") 13 | ) 14 | 15 | type SqlError struct { 16 | Code uint16 17 | Message string 18 | State string 19 | } 20 | 21 | func (e *SqlError) Error() string { 22 | return fmt.Sprintf("ERROR %d (%s): %s", e.Code, e.State, e.Message) 23 | } 24 | 25 | //default mysql error, must adapt errname message format 26 | func NewDefaultError(errCode uint16, args ...interface{}) *SqlError { 27 | e := new(SqlError) 28 | e.Code = errCode 29 | 30 | if s, ok := MySQLState[errCode]; ok { 31 | e.State = s 32 | } else { 33 | e.State = DEFAULT_MYSQL_STATE 34 | } 35 | 36 | if format, ok := MySQLErrName[errCode]; ok { 37 | e.Message = fmt.Sprintf(format, args...) 
38 | } else { 39 | e.Message = fmt.Sprint(args...) 40 | } 41 | 42 | return e 43 | } 44 | 45 | func NewError(errCode uint16, message string) *SqlError { 46 | e := new(SqlError) 47 | e.Code = errCode 48 | 49 | if s, ok := MySQLState[errCode]; ok { 50 | e.State = s 51 | } else { 52 | e.State = DEFAULT_MYSQL_STATE 53 | } 54 | 55 | e.Message = message 56 | 57 | return e 58 | } 59 | -------------------------------------------------------------------------------- /proxy/schema.go: -------------------------------------------------------------------------------- 1 | package proxy 2 | 3 | import ( 4 | "fmt" 5 | "github.com/siddontang/mixer/router" 6 | ) 7 | 8 | type Schema struct { 9 | db string 10 | 11 | nodes map[string]*Node 12 | 13 | rule *router.Router 14 | } 15 | 16 | func (s *Server) parseSchemas() error { 17 | s.schemas = make(map[string]*Schema) 18 | 19 | for _, schemaCfg := range s.cfg.Schemas { 20 | if _, ok := s.schemas[schemaCfg.DB]; ok { 21 | return fmt.Errorf("duplicate schema [%s].", schemaCfg.DB) 22 | } 23 | if len(schemaCfg.Nodes) == 0 { 24 | return fmt.Errorf("schema [%s] must have a node.", schemaCfg.DB) 25 | } 26 | 27 | nodes := make(map[string]*Node) 28 | for _, n := range schemaCfg.Nodes { 29 | if s.getNode(n) == nil { 30 | return fmt.Errorf("schema [%s] node [%s] config is not exists.", schemaCfg.DB, n) 31 | } 32 | 33 | if _, ok := nodes[n]; ok { 34 | return fmt.Errorf("schema [%s] node [%s] duplicate.", schemaCfg.DB, n) 35 | } 36 | 37 | nodes[n] = s.getNode(n) 38 | } 39 | 40 | rule, err := router.NewRouter(&schemaCfg) 41 | if err != nil { 42 | return err 43 | } 44 | 45 | s.schemas[schemaCfg.DB] = &Schema{ 46 | db: schemaCfg.DB, 47 | nodes: nodes, 48 | rule: rule, 49 | } 50 | } 51 | 52 | return nil 53 | } 54 | 55 | func (s *Server) getSchema(db string) *Schema { 56 | return s.schemas[db] 57 | } 58 | -------------------------------------------------------------------------------- /license_vitess: 
-------------------------------------------------------------------------------- 1 | Copyright 2012, Google Inc. 2 | All rights reserved. 3 | 4 | Redistribution and use in source and binary forms, with or without 5 | modification, are permitted provided that the following conditions are 6 | met: 7 | 8 | * Redistributions of source code must retain the above copyright 9 | notice, this list of conditions and the following disclaimer. 10 | * Redistributions in binary form must reproduce the above 11 | copyright notice, this list of conditions and the following disclaimer 12 | in the documentation and/or other materials provided with the 13 | distribution. 14 | * Neither the name of Google Inc. nor the names of its 15 | contributors may be used to endorse or promote products derived from 16 | this software without specific prior written permission. 17 | 18 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 19 | "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 20 | LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 21 | A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT 22 | OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 23 | SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 24 | LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 25 | DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 26 | THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 27 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 28 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 29 | -------------------------------------------------------------------------------- /proxy/conn_set.go: -------------------------------------------------------------------------------- 1 | package proxy 2 | 3 | import ( 4 | "fmt" 5 | . 
"github.com/siddontang/mixer/mysql" 6 | "github.com/siddontang/mixer/sqlparser" 7 | "strings" 8 | ) 9 | 10 | var nstring = sqlparser.String 11 | 12 | func (c *Conn) handleSet(stmt *sqlparser.Set) error { 13 | if len(stmt.Exprs) != 1 { 14 | return fmt.Errorf("must set one item once, not %s", nstring(stmt)) 15 | } 16 | 17 | k := string(stmt.Exprs[0].Name.Name) 18 | 19 | switch strings.ToUpper(k) { 20 | case `AUTOCOMMIT`: 21 | return c.handleSetAutoCommit(stmt.Exprs[0].Expr) 22 | case `NAMES`: 23 | return c.handleSetNames(stmt.Exprs[0].Expr) 24 | default: 25 | return fmt.Errorf("set %s is not supported now", k) 26 | } 27 | } 28 | 29 | func (c *Conn) handleSetAutoCommit(val sqlparser.ValExpr) error { 30 | value, ok := val.(sqlparser.NumVal) 31 | if !ok { 32 | return fmt.Errorf("set autocommit error") 33 | } 34 | switch value[0] { 35 | case '1': 36 | c.status |= SERVER_STATUS_AUTOCOMMIT 37 | case '0': 38 | c.status &= ^SERVER_STATUS_AUTOCOMMIT 39 | default: 40 | return fmt.Errorf("invalid autocommit flag %s", value) 41 | } 42 | 43 | return c.writeOK(nil) 44 | } 45 | 46 | func (c *Conn) handleSetNames(val sqlparser.ValExpr) error { 47 | value, ok := val.(sqlparser.StrVal) 48 | if !ok { 49 | return fmt.Errorf("set names charset error") 50 | } 51 | 52 | charset := strings.ToLower(string(value)) 53 | cid, ok := CharsetIds[charset] 54 | if !ok { 55 | return fmt.Errorf("invalid charset %s", charset) 56 | } 57 | 58 | c.charset = charset 59 | c.collation = cid 60 | 61 | return c.writeOK(nil) 62 | } 63 | -------------------------------------------------------------------------------- /proxy/conn_tx.go: -------------------------------------------------------------------------------- 1 | package proxy 2 | 3 | import ( 4 | "github.com/siddontang/mixer/client" 5 | . 
"github.com/siddontang/mixer/mysql" 6 | ) 7 | 8 | func (c *Conn) isInTransaction() bool { 9 | return c.status&SERVER_STATUS_IN_TRANS > 0 10 | } 11 | 12 | func (c *Conn) isAutoCommit() bool { 13 | return c.status&SERVER_STATUS_AUTOCOMMIT > 0 14 | } 15 | 16 | func (c *Conn) handleBegin() error { 17 | c.status |= SERVER_STATUS_IN_TRANS 18 | return c.writeOK(nil) 19 | } 20 | 21 | func (c *Conn) handleCommit() (err error) { 22 | if err := c.commit(); err != nil { 23 | return err 24 | } else { 25 | return c.writeOK(nil) 26 | } 27 | } 28 | 29 | func (c *Conn) handleRollback() (err error) { 30 | if err := c.rollback(); err != nil { 31 | return err 32 | } else { 33 | return c.writeOK(nil) 34 | } 35 | } 36 | 37 | func (c *Conn) commit() (err error) { 38 | c.status &= ^SERVER_STATUS_IN_TRANS 39 | 40 | for _, co := range c.txConns { 41 | if e := co.Commit(); e != nil { 42 | err = e 43 | } 44 | co.Close() 45 | } 46 | 47 | c.txConns = map[*Node]*client.SqlConn{} 48 | 49 | return 50 | } 51 | 52 | func (c *Conn) rollback() (err error) { 53 | c.status &= ^SERVER_STATUS_IN_TRANS 54 | 55 | for _, co := range c.txConns { 56 | if e := co.Rollback(); e != nil { 57 | err = e 58 | } 59 | co.Close() 60 | } 61 | 62 | c.txConns = map[*Node]*client.SqlConn{} 63 | 64 | return 65 | } 66 | 67 | //if status is in_trans, need 68 | //else if status is not autocommit, need 69 | //else no need 70 | func (c *Conn) needBeginTx() bool { 71 | return c.isInTransaction() || !c.isAutoCommit() 72 | } 73 | -------------------------------------------------------------------------------- /proxy/conn_admin.go: -------------------------------------------------------------------------------- 1 | package proxy 2 | 3 | import ( 4 | "fmt" 5 | "github.com/siddontang/mixer/sqlparser" 6 | "strings" 7 | ) 8 | 9 | func (c *Conn) handleAdmin(admin *sqlparser.Admin) error { 10 | name := string(admin.Name) 11 | 12 | var err error 13 | switch strings.ToLower(name) { 14 | case "upnode": 15 | err = 
c.adminUpNodeServer(admin.Values) 16 | case "downnode": 17 | err = c.adminDownNodeServer(admin.Values) 18 | default: 19 | return fmt.Errorf("admin %s not supported now", name) 20 | } 21 | 22 | if err != nil { 23 | return err 24 | } 25 | 26 | return c.writeOK(nil) 27 | } 28 | 29 | func (c *Conn) adminUpNodeServer(values sqlparser.ValExprs) error { 30 | if len(values) != 3 { 31 | return fmt.Errorf("upnode needs 3 args, not %d", len(values)) 32 | } 33 | 34 | nodeName := nstring(values[0]) 35 | sType := strings.ToLower(nstring(values[1])) 36 | addr := strings.ToLower(nstring(values[2])) 37 | 38 | switch sType { 39 | case Master: 40 | return c.server.UpMaster(nodeName, addr) 41 | case Slave: 42 | return c.server.UpSlave(nodeName, addr) 43 | default: 44 | return fmt.Errorf("invalid server type %s", sType) 45 | } 46 | } 47 | 48 | func (c *Conn) adminDownNodeServer(values sqlparser.ValExprs) error { 49 | if len(values) != 2 { 50 | return fmt.Errorf("downnode needs 2 args, not %d", len(values)) 51 | } 52 | 53 | nodeName := nstring(values[0]) 54 | sType := strings.ToLower(nstring(values[1])) 55 | 56 | switch sType { 57 | case Master: 58 | return c.server.DownMaster(nodeName) 59 | case Slave: 60 | return c.server.DownSlave(nodeName) 61 | default: 62 | return fmt.Errorf("invalid server type %s", sType) 63 | } 64 | } 65 | -------------------------------------------------------------------------------- /cmd/mixer-proxy/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "flag" 5 | "github.com/siddontang/go-log/log" 6 | "github.com/siddontang/mixer/config" 7 | "github.com/siddontang/mixer/proxy" 8 | "os" 9 | "os/signal" 10 | "runtime" 11 | "strings" 12 | "syscall" 13 | ) 14 | 15 | var configFile *string = flag.String("config", "/etc/mixer.conf", "mixer proxy config file") 16 | var logLevel *string = flag.String("log-level", "", "log level [debug|info|warn|error], default error") 17 | 18 | func main() { 19 |
runtime.GOMAXPROCS(runtime.NumCPU()) 20 | 21 | flag.Parse() 22 | 23 | if len(*configFile) == 0 { 24 | log.Error("must use a config file") 25 | return 26 | } 27 | 28 | cfg, err := config.ParseConfigFile(*configFile) 29 | if err != nil { 30 | log.Error(err.Error()) 31 | return 32 | } 33 | 34 | if *logLevel != "" { 35 | setLogLevel(*logLevel) 36 | } else { 37 | setLogLevel(cfg.LogLevel) 38 | } 39 | 40 | var svr *proxy.Server 41 | svr, err = proxy.NewServer(cfg) 42 | if err != nil { 43 | log.Error(err.Error()) 44 | return 45 | } 46 | 47 | sc := make(chan os.Signal, 1) 48 | signal.Notify(sc, 49 | syscall.SIGHUP, 50 | syscall.SIGINT, 51 | syscall.SIGTERM, 52 | syscall.SIGQUIT) 53 | 54 | go func() { 55 | sig := <-sc 56 | log.Info("Got signal [%d] to exit.", sig) 57 | svr.Close() 58 | }() 59 | 60 | svr.Run() 61 | } 62 | 63 | func setLogLevel(level string) { 64 | switch strings.ToLower(level) { 65 | case "debug": 66 | log.SetLevel(log.LevelDebug) 67 | case "info": 68 | log.SetLevel(log.LevelInfo) 69 | case "warn": 70 | log.SetLevel(log.LevelWarn) 71 | case "error": 72 | log.SetLevel(log.LevelError) 73 | default: 74 | log.SetLevel(log.LevelError) 75 | } 76 | } 77 | -------------------------------------------------------------------------------- /config/config.go: -------------------------------------------------------------------------------- 1 | package config 2 | 3 | import ( 4 | "github.com/siddontang/go-yaml/yaml" 5 | "io/ioutil" 6 | ) 7 | 8 | type NodeConfig struct { 9 | Name string `yaml:"name"` 10 | DownAfterNoAlive int `yaml:"down_after_noalive"` 11 | IdleConns int `yaml:"idle_conns"` 12 | RWSplit bool `yaml:"rw_split"` 13 | 14 | User string `yaml:"user"` 15 | Password string `yaml:"password"` 16 | 17 | Master string `yaml:"master"` 18 | Slave string `yaml:"slave"` 19 | } 20 | 21 | type SchemaConfig struct { 22 | DB string `yaml:"db"` 23 | Nodes []string `yaml:"nodes"` 24 | RulesConifg RulesConfig `yaml:"rules"` 25 | } 26 | 27 | type RulesConfig struct { 28 | Default 
string `yaml:"default"` 29 | ShardRule []ShardConfig `yaml:"shard"` 30 | } 31 | 32 | type ShardConfig struct { 33 | Table string `yaml:"table"` 34 | Key string `yaml:"key"` 35 | Nodes []string `yaml:"nodes"` 36 | Type string `yaml:"type"` 37 | Range string `yaml:"range"` 38 | } 39 | 40 | type Config struct { 41 | Addr string `yaml:"addr"` 42 | User string `yaml:"user"` 43 | Password string `yaml:"password"` 44 | LogLevel string `yaml:"log_level"` 45 | 46 | Nodes []NodeConfig `yaml:"nodes"` 47 | 48 | Schemas []SchemaConfig `yaml:"schemas"` 49 | } 50 | 51 | func ParseConfigData(data []byte) (*Config, error) { 52 | var cfg Config 53 | if err := yaml.Unmarshal([]byte(data), &cfg); err != nil { 54 | return nil, err 55 | } 56 | return &cfg, nil 57 | } 58 | 59 | func ParseConfigFile(fileName string) (*Config, error) { 60 | data, err := ioutil.ReadFile(fileName) 61 | if err != nil { 62 | return nil, err 63 | } 64 | 65 | return ParseConfigData(data) 66 | } 67 | -------------------------------------------------------------------------------- /etc/mixer.conf.yaml: -------------------------------------------------------------------------------- 1 | # Mixer configuration, using yaml format 2 | 3 | # server listen addr 4 | addr : 127.0.0.1:4000 5 | 6 | # server user and password 7 | user : root 8 | password : 9 | 10 | # log level[debug|info|warn|error],default error 11 | log_level : error 12 | 13 | # node is an agenda for real remote mysql server. 
14 | nodes : 15 | - 16 | name : node1 17 | 18 | # default max idle conns for mysql server 19 | idle_conns : 16 20 | 21 | # if rw_split is true, select will use slave server 22 | rw_split: true 23 | 24 | # all mysql in a node must have the same user and password 25 | user : root 26 | password: 27 | 28 | # master represents a real mysql master server 29 | master : 127.0.0.1:3306 30 | 31 | # slave represents a real mysql salve server 32 | slave : 127.0.0.1:4306 33 | 34 | # down mysql after N seconds noalive 35 | # 0 will no down 36 | down_after_noalive : 300 37 | 38 | - 39 | name : node2 40 | user: root 41 | password: 42 | master : 127.0.0.1:3308 43 | 44 | 45 | # schema defines which db can be used by client and this db's sql will be executed in which nodes 46 | schemas : 47 | - 48 | db : mixer 49 | nodes: [node1, node2] 50 | 51 | # rule defines how sql executed in nodes 52 | rules: 53 | # any other table not set above will use default [node1] 54 | default: node1 55 | shard: 56 | - 57 | table: test1 58 | key: id 59 | type: hash 60 | # node will be node1, node2 61 | nodes: [node1, node2] 62 | 63 | - 64 | table: test2 65 | key: name 66 | nodes: [node1, node2] 67 | type: range 68 | # range is left close and right open 69 | # node1 range (-inf, 10000) 70 | # node2 range [10000, 20000) 71 | range: -10000- 72 | -------------------------------------------------------------------------------- /router/router_test.go: -------------------------------------------------------------------------------- 1 | package router 2 | 3 | import ( 4 | "github.com/siddontang/go-yaml/yaml" 5 | "github.com/siddontang/mixer/config" 6 | "testing" 7 | ) 8 | 9 | func TestParseRule(t *testing.T) { 10 | var s = ` 11 | schemas : 12 | - 13 | db : mixer 14 | nodes: [node1, node2, node3] 15 | rules: 16 | default: node1 17 | shard: 18 | - 19 | table: mixer_test_shard_hash 20 | key: id 21 | nodes: [node2, node3] 22 | type: hash 23 | 24 | - 25 | table: mixer_test_shard_range 26 | key: id 27 | type: range 28 
| nodes: [node2, node3] 29 | range: -10000- 30 | ` 31 | var cfg config.Config 32 | if err := yaml.Unmarshal([]byte(s), &cfg); err != nil { 33 | t.Fatal(err) 34 | } 35 | 36 | rt, err := NewRouter(&cfg.Schemas[0]) 37 | if err != nil { 38 | t.Fatal(err) 39 | } 40 | if rt.DefaultRule.Nodes[0] != "node1" { 41 | t.Fatal("default rule parse not correct.") 42 | } 43 | 44 | hashRule := rt.GetRule("mixer_test_shard_hash") 45 | if hashRule.Type != HashRuleType { 46 | t.Fatal(hashRule.Type) 47 | } 48 | 49 | if len(hashRule.Nodes) != 2 || hashRule.Nodes[0] != "node2" || hashRule.Nodes[1] != "node3" { 50 | t.Fatal("parse nodes not correct.") 51 | } 52 | 53 | if n := hashRule.FindNode(uint64(11)); n != "node3" { 54 | t.Fatal(n) 55 | } 56 | 57 | rangeRule := rt.GetRule("mixer_test_shard_range") 58 | if rangeRule.Type != RangeRuleType { 59 | t.Fatal(rangeRule.Type) 60 | } 61 | 62 | if n := rangeRule.FindNode(10000 - 1); n != "node2" { 63 | t.Fatal(n) 64 | } 65 | 66 | defaultRule := rt.GetRule("mixer_defaultRule_table") 67 | if defaultRule == nil { 68 | t.Fatal("must not nil") 69 | } 70 | 71 | if defaultRule.Type != DefaultRuleType { 72 | t.Fatal(defaultRule.Type) 73 | } 74 | 75 | if defaultRule.Shard == nil { 76 | t.Fatal("nil error") 77 | } 78 | 79 | if n := defaultRule.FindNode(11); n != "node1" { 80 | t.Fatal(n) 81 | } 82 | } 83 | -------------------------------------------------------------------------------- /router/numkey.go: -------------------------------------------------------------------------------- 1 | package router 2 | 3 | import ( 4 | "fmt" 5 | "math" 6 | "strconv" 7 | "strings" 8 | ) 9 | 10 | const ( 11 | MinNumKey = math.MinInt64 12 | MaxNumKey = math.MaxInt64 13 | ) 14 | 15 | type NumKeyRange struct { 16 | Start int64 17 | End int64 18 | } 19 | 20 | func (kr NumKeyRange) MapKey() string { 21 | return fmt.Sprintf("%d-%d", kr.Start, kr.End) 22 | } 23 | 24 | func (kr NumKeyRange) Contains(i int64) bool { 25 | return kr.Start <= i && (kr.End == MaxNumKey || i <
kr.End) 26 | } 27 | 28 | func (kr NumKeyRange) String() string { 29 | return fmt.Sprintf("{Start: %d, End: %d}", kr.Start, kr.End) 30 | } 31 | 32 | // ParseNumShardingSpec parses a string that describes a sharding 33 | // specification. a-b-c-d will be parsed as a-b, b-c, c-d. The empty 34 | // string may serve both as the start and end of the keyspace: -a-b- 35 | // will be parsed as start-a, a-b, b-end. 36 | func ParseNumShardingSpec(spec string) ([]NumKeyRange, error) { 37 | parts := strings.Split(spec, "-") 38 | if len(parts) == 1 { 39 | return nil, fmt.Errorf("malformed spec: doesn't define a range: %q", spec) 40 | } 41 | var old int64 42 | var err error 43 | if len(parts[0]) != 0 { 44 | old, err = strconv.ParseInt(parts[0], 10, 64) 45 | if err != nil { 46 | return nil, err 47 | } 48 | } else { 49 | old = MinNumKey 50 | } 51 | 52 | ranges := make([]NumKeyRange, len(parts)-1) 53 | 54 | var n int64 55 | 56 | for i, p := range parts[1:] { 57 | if p == "" && i != (len(parts)-2) { 58 | return nil, fmt.Errorf("malformed spec: MinKey/MaxKey cannot be in the middle of the spec: %q", spec) 59 | } 60 | 61 | if p != "" { 62 | n, err = strconv.ParseInt(p, 10, 64) 63 | if err != nil { 64 | return nil, err 65 | } 66 | } else { 67 | n = MaxNumKey 68 | } 69 | if n <= old { 70 | return nil, fmt.Errorf("malformed spec: shard limits should be in order: %q", spec) 71 | } 72 | 73 | ranges[i] = NumKeyRange{Start: old, End: n} 74 | old = n 75 | } 76 | return ranges, nil 77 | } 78 | -------------------------------------------------------------------------------- /proxy/server.go: -------------------------------------------------------------------------------- 1 | package proxy 2 | 3 | import ( 4 | "github.com/siddontang/go-log/log" 5 | "github.com/siddontang/mixer/config" 6 | 7 | "net" 8 | "runtime" 9 | "strings" 10 | ) 11 | 12 | type Server struct { 13 | cfg *config.Config 14 | 15 | addr string 16 | user string 17 | password string 18 | 19 | running bool 20 | 21 | listener net.Listener
22 | 23 | nodes map[string]*Node 24 | 25 | schemas map[string]*Schema 26 | } 27 | 28 | func NewServer(cfg *config.Config) (*Server, error) { 29 | s := new(Server) 30 | 31 | s.cfg = cfg 32 | 33 | s.addr = cfg.Addr 34 | s.user = cfg.User 35 | s.password = cfg.Password 36 | 37 | if err := s.parseNodes(); err != nil { 38 | return nil, err 39 | } 40 | 41 | if err := s.parseSchemas(); err != nil { 42 | return nil, err 43 | } 44 | 45 | var err error 46 | netProto := "tcp" 47 | if strings.Contains(s.addr, "/") { 48 | netProto = "unix" 49 | } 50 | s.listener, err = net.Listen(netProto, s.addr) 51 | 52 | if err != nil { 53 | return nil, err 54 | } 55 | 56 | log.Info("Server run MySql Protocol Listen(%s) at [%s]", netProto, s.addr) 57 | return s, nil 58 | } 59 | 60 | func (s *Server) Run() error { 61 | s.running = true 62 | 63 | for s.running { 64 | conn, err := s.listener.Accept() 65 | if err != nil { 66 | log.Error("accept error %s", err.Error()) 67 | continue 68 | } 69 | 70 | go s.onConn(conn) 71 | } 72 | 73 | return nil 74 | } 75 | 76 | func (s *Server) Close() { 77 | s.running = false 78 | if s.listener != nil { 79 | s.listener.Close() 80 | } 81 | } 82 | 83 | func (s *Server) onConn(c net.Conn) { 84 | conn := s.newConn(c) 85 | 86 | defer func() { 87 | if err := recover(); err != nil { 88 | const size = 4096 89 | buf := make([]byte, size) 90 | buf = buf[:runtime.Stack(buf, false)] 91 | log.Error("onConn panic %v: %v\n%s", c.RemoteAddr().String(), err, buf) 92 | } 93 | 94 | conn.Close() 95 | }() 96 | 97 | if err := conn.Handshake(); err != nil { 98 | log.Error("handshake error %s", err.Error()) 99 | c.Close() 100 | return 101 | } 102 | 103 | conn.Run() 104 | 105 | } 106 | -------------------------------------------------------------------------------- /mysql/resultset_sort_test.go: -------------------------------------------------------------------------------- 1 | package mysql 2 | 3 | import ( 4 | "fmt" 5 | "reflect" 6 | "sort" 7 | "testing" 8 | ) 9 | 10 | func
TestResultsetSort(t *testing.T) { 11 | r1 := new(Resultset) 12 | r2 := new(Resultset) 13 | 14 | r1.Values = [][]interface{}{ 15 | []interface{}{int64(1), "a", []byte("aa")}, 16 | []interface{}{int64(2), "a", []byte("bb")}, 17 | []interface{}{int64(3), "c", []byte("bb")}, 18 | } 19 | 20 | r1.RowDatas = []RowData{ 21 | RowData([]byte("1")), 22 | RowData([]byte("2")), 23 | RowData([]byte("3")), 24 | } 25 | 26 | s := new(resultsetSorter) 27 | 28 | s.Resultset = r1 29 | 30 | s.sk = []SortKey{ 31 | SortKey{column: 0, Direction: SortDesc}, 32 | } 33 | 34 | sort.Sort(s) 35 | 36 | r2.Values = [][]interface{}{ 37 | []interface{}{int64(3), "c", []byte("bb")}, 38 | []interface{}{int64(2), "a", []byte("bb")}, 39 | []interface{}{int64(1), "a", []byte("aa")}, 40 | } 41 | 42 | r2.RowDatas = []RowData{ 43 | RowData([]byte("3")), 44 | RowData([]byte("2")), 45 | RowData([]byte("1")), 46 | } 47 | 48 | if !reflect.DeepEqual(r1, r2) { 49 | t.Fatal(fmt.Sprintf("%v %v", r1, r2)) 50 | } 51 | 52 | s.sk = []SortKey{ 53 | SortKey{column: 1, Direction: SortAsc}, 54 | SortKey{column: 2, Direction: SortDesc}, 55 | } 56 | 57 | sort.Sort(s) 58 | 59 | r2.Values = [][]interface{}{ 60 | []interface{}{int64(2), "a", []byte("bb")}, 61 | []interface{}{int64(1), "a", []byte("aa")}, 62 | []interface{}{int64(3), "c", []byte("bb")}, 63 | } 64 | 65 | r2.RowDatas = []RowData{ 66 | RowData([]byte("2")), 67 | RowData([]byte("1")), 68 | RowData([]byte("3")), 69 | } 70 | 71 | if !reflect.DeepEqual(r1, r2) { 72 | t.Fatal(fmt.Sprintf("%v %v", r1, r2)) 73 | } 74 | 75 | s.sk = []SortKey{ 76 | SortKey{column: 1, Direction: SortAsc}, 77 | SortKey{column: 2, Direction: SortAsc}, 78 | } 79 | 80 | sort.Sort(s) 81 | 82 | r2.Values = [][]interface{}{ 83 | []interface{}{int64(1), "a", []byte("aa")}, 84 | []interface{}{int64(2), "a", []byte("bb")}, 85 | []interface{}{int64(3), "c", []byte("bb")}, 86 | } 87 | 88 | r2.RowDatas = []RowData{ 89 | RowData([]byte("1")), 90 | RowData([]byte("2")), 91 | RowData([]byte("3")), 92 | } 
93 | 94 | if !reflect.DeepEqual(r1, r2) { 95 | t.Fatal(fmt.Sprintf("%v %v", r1, r2)) 96 | } 97 | 98 | } 99 | -------------------------------------------------------------------------------- /mysql/packetio.go: -------------------------------------------------------------------------------- 1 | package mysql 2 | 3 | import ( 4 | "bufio" 5 | "fmt" 6 | "io" 7 | "net" 8 | ) 9 | 10 | type PacketIO struct { 11 | rb *bufio.Reader 12 | wb io.Writer 13 | 14 | Sequence uint8 15 | } 16 | 17 | func NewPacketIO(conn net.Conn) *PacketIO { 18 | p := new(PacketIO) 19 | 20 | p.rb = bufio.NewReaderSize(conn, 1024) 21 | p.wb = conn 22 | 23 | p.Sequence = 0 24 | 25 | return p 26 | } 27 | 28 | func (p *PacketIO) ReadPacket() ([]byte, error) { 29 | header := []byte{0, 0, 0, 0} 30 | 31 | if _, err := io.ReadFull(p.rb, header); err != nil { 32 | return nil, ErrBadConn 33 | } 34 | 35 | length := int(uint32(header[0]) | uint32(header[1])<<8 | uint32(header[2])<<16) 36 | if length < 1 { 37 | return nil, fmt.Errorf("invalid payload length %d", length) 38 | } 39 | 40 | sequence := uint8(header[3]) 41 | 42 | if sequence != p.Sequence { 43 | return nil, fmt.Errorf("invalid sequence %d != %d", sequence, p.Sequence) 44 | } 45 | 46 | p.Sequence++ 47 | 48 | data := make([]byte, length) 49 | if _, err := io.ReadFull(p.rb, data); err != nil { 50 | return nil, ErrBadConn 51 | } else { 52 | if length < MaxPayloadLen { 53 | return data, nil 54 | } 55 | 56 | var buf []byte 57 | buf, err = p.ReadPacket() 58 | if err != nil { 59 | return nil, ErrBadConn 60 | } else { 61 | return append(data, buf...), nil 62 | } 63 | } 64 | } 65 | 66 | //data already have header 67 | func (p *PacketIO) WritePacket(data []byte) error { 68 | length := len(data) - 4 69 | 70 | for length >= MaxPayloadLen { 71 | 72 | data[0] = 0xff 73 | data[1] = 0xff 74 | data[2] = 0xff 75 | 76 | data[3] = p.Sequence 77 | 78 | if n, err := p.wb.Write(data[:4+MaxPayloadLen]); err != nil { 79 | return ErrBadConn 80 | } else if n != (4 + 
MaxPayloadLen) { 81 | return ErrBadConn 82 | } else { 83 | p.Sequence++ 84 | length -= MaxPayloadLen 85 | data = data[MaxPayloadLen:] 86 | } 87 | } 88 | 89 | data[0] = byte(length) 90 | data[1] = byte(length >> 8) 91 | data[2] = byte(length >> 16) 92 | data[3] = p.Sequence 93 | 94 | if n, err := p.wb.Write(data); err != nil { 95 | return ErrBadConn 96 | } else if n != len(data) { 97 | return ErrBadConn 98 | } else { 99 | p.Sequence++ 100 | return nil 101 | } 102 | } 103 | -------------------------------------------------------------------------------- /proxy/server_test.go: -------------------------------------------------------------------------------- 1 | package proxy 2 | 3 | import ( 4 | "github.com/siddontang/mixer/client" 5 | "github.com/siddontang/mixer/config" 6 | "sync" 7 | "testing" 8 | "time" 9 | ) 10 | 11 | var testServerOnce sync.Once 12 | var testServer *Server 13 | var testDBOnce sync.Once 14 | var testDB *client.DB 15 | 16 | var testConfigData = []byte(` 17 | addr : 127.0.0.1:4000 18 | user : root 19 | password : 20 | 21 | nodes : 22 | - 23 | name : node1 24 | down_after_noalive : 300 25 | idle_conns : 16 26 | rw_split: false 27 | user: root 28 | password: 29 | master : 127.0.0.1:3306 30 | slave : 31 | - 32 | name : node2 33 | down_after_noalive : 300 34 | idle_conns : 16 35 | rw_split: false 36 | user: root 37 | password: 38 | master : 127.0.0.1:3307 39 | 40 | - 41 | name : node3 42 | down_after_noalive : 300 43 | idle_conns : 16 44 | rw_split: false 45 | user: root 46 | password: 47 | master : 127.0.0.1:3308 48 | 49 | schemas : 50 | - 51 | db : mixer 52 | nodes: [node1, node2, node3] 53 | rules: 54 | default: node1 55 | shard: 56 | - 57 | table: mixer_test_shard_hash 58 | key: id 59 | nodes: [node2, node3] 60 | type: hash 61 | 62 | - 63 | table: mixer_test_shard_range 64 | key: id 65 | nodes: [node2, node3] 66 | range: -10000- 67 | type: range 68 | `) 69 | 70 | func newTestServer(t *testing.T) *Server { 71 | f := func() { 72 | cfg, err := 
config.ParseConfigData(testConfigData) 73 | if err != nil { 74 | t.Fatal(err.Error()) 75 | } 76 | 77 | testServer, err = NewServer(cfg) 78 | if err != nil { 79 | t.Fatal(err) 80 | } 81 | 82 | go testServer.Run() 83 | 84 | time.Sleep(1 * time.Second) 85 | } 86 | 87 | testServerOnce.Do(f) 88 | 89 | return testServer 90 | } 91 | 92 | func newTestDB(t *testing.T) *client.DB { 93 | newTestServer(t) 94 | 95 | f := func() { 96 | var err error 97 | testDB, err = client.Open("127.0.0.1:4000", "root", "", "mixer") 98 | 99 | if err != nil { 100 | t.Fatal(err) 101 | } 102 | 103 | testDB.SetMaxIdleConnNum(4) 104 | } 105 | 106 | testDBOnce.Do(f) 107 | return testDB 108 | } 109 | 110 | func newTestDBConn(t *testing.T) *client.SqlConn { 111 | db := newTestDB(t) 112 | 113 | c, err := db.GetConn() 114 | 115 | if err != nil { 116 | t.Fatal(err) 117 | } 118 | 119 | return c 120 | } 121 | 122 | func TestServer(t *testing.T) { 123 | newTestServer(t) 124 | } 125 | -------------------------------------------------------------------------------- /router/config.go: -------------------------------------------------------------------------------- 1 | package router 2 | 3 | import ( 4 | "fmt" 5 | "github.com/siddontang/mixer/config" 6 | "regexp" 7 | "strconv" 8 | "strings" 9 | ) 10 | 11 | var ( 12 | DefaultRuleType = "default" 13 | HashRuleType = "hash" 14 | RangeRuleType = "range" 15 | ) 16 | 17 | type RuleConfig struct { 18 | config.ShardConfig 19 | } 20 | 21 | func (c *RuleConfig) ParseRule(db string) (*Rule, error) { 22 | r := new(Rule) 23 | r.DB = db 24 | r.Table = c.Table 25 | r.Key = c.Key 26 | r.Type = c.Type 27 | r.Nodes = c.Nodes 28 | 29 | if err := c.parseShard(r); err != nil { 30 | return nil, err 31 | } 32 | 33 | return r, nil 34 | } 35 | 36 | func (c *RuleConfig) parseNodes(r *Rule) error { 37 | // Note: did not used yet, by HuangChuanTong 38 | reg, err := regexp.Compile(`(\w+)\((\d+)\-(\d+)\)`) 39 | if err != nil { 40 | return err 41 | } 42 | 43 | ns := c.Nodes // 
strings.Split(c.Nodes, ",") 44 | 45 | nodes := map[string]struct{}{} 46 | 47 | for _, n := range ns { 48 | n = strings.TrimSpace(n) 49 | if s := reg.FindStringSubmatch(n); s == nil { 50 | if _, ok := nodes[n]; ok { 51 | return fmt.Errorf("duplicate node %s", n) 52 | } 53 | 54 | nodes[n] = struct{}{} 55 | r.Nodes = append(r.Nodes, n) 56 | } else { 57 | var start, stop int 58 | if start, err = strconv.Atoi(s[2]); err != nil { 59 | return err 60 | } 61 | 62 | if stop, err = strconv.Atoi(s[3]); err != nil { 63 | return err 64 | } 65 | 66 | if start >= stop { 67 | return fmt.Errorf("invalid node format %s", n) 68 | } 69 | 70 | for i := start; i <= stop; i++ { 71 | n = fmt.Sprintf("%s%d", s[1], i) 72 | 73 | if _, ok := nodes[n]; ok { 74 | return fmt.Errorf("duplicate node %s", n) 75 | } 76 | 77 | nodes[n] = struct{}{} 78 | r.Nodes = append(r.Nodes, n) 79 | 80 | } 81 | } 82 | } 83 | 84 | if len(r.Nodes) == 0 { 85 | return fmt.Errorf("empty nodes info") 86 | } 87 | 88 | if r.Type == DefaultRuleType && len(r.Nodes) != 1 { 89 | return fmt.Errorf("default rule must have only one node") 90 | } 91 | 92 | return nil 93 | } 94 | 95 | func (c *RuleConfig) parseShard(r *Rule) error { 96 | if r.Type == HashRuleType { 97 | //hash shard 98 | r.Shard = &HashShard{ShardNum: len(r.Nodes)} 99 | } else if r.Type == RangeRuleType { 100 | rs, err := ParseNumShardingSpec(c.Range) 101 | if err != nil { 102 | return err 103 | } 104 | 105 | if len(rs) != len(r.Nodes) { 106 | return fmt.Errorf("range space %d not equal nodes %d", len(rs), len(r.Nodes)) 107 | } 108 | 109 | r.Shard = &NumRangeShard{Shards: rs} 110 | } else { 111 | r.Shard = &DefaultShard{} 112 | } 113 | 114 | return nil 115 | } 116 | -------------------------------------------------------------------------------- /router/router.go: -------------------------------------------------------------------------------- 1 | package router 2 | 3 | import ( 4 | "fmt" 5 | "github.com/siddontang/mixer/config" 6 | "strings" 7 | ) 8 | 9 | type 
Rule struct { 10 | DB string 11 | Table string 12 | Key string 13 | 14 | Type string 15 | 16 | Nodes []string 17 | Shard Shard 18 | } 19 | 20 | func (r *Rule) FindNode(key interface{}) string { 21 | i := r.Shard.FindForKey(key) 22 | return r.Nodes[i] 23 | } 24 | 25 | func (r *Rule) FindNodeIndex(key interface{}) int { 26 | return r.Shard.FindForKey(key) 27 | } 28 | 29 | func (r *Rule) String() string { 30 | return fmt.Sprintf("%s.%s?key=%v&shard=%s&nodes=%s", 31 | r.DB, r.Table, r.Key, r.Type, strings.Join(r.Nodes, ", ")) 32 | } 33 | 34 | func NewDefaultRule(db string, node string) *Rule { 35 | var r *Rule = &Rule{ 36 | DB: db, 37 | Type: DefaultRuleType, 38 | Nodes: []string{node}, 39 | Shard: new(DefaultShard), 40 | } 41 | return r 42 | } 43 | 44 | func (r *Router) GetRule(table string) *Rule { 45 | rule := r.Rules[table] 46 | if rule == nil { 47 | return r.DefaultRule 48 | } else { 49 | return rule 50 | } 51 | } 52 | 53 | type Router struct { 54 | DB string 55 | Rules map[string]*Rule //key is 56 | DefaultRule *Rule 57 | nodes []string //just for human saw 58 | } 59 | 60 | func NewRouter(schemaConfig *config.SchemaConfig) (*Router, error) { 61 | 62 | if !includeNode(schemaConfig.Nodes, schemaConfig.RulesConifg.Default) { 63 | return nil, fmt.Errorf("default node[%s] not in the nodes list.", 64 | schemaConfig.RulesConifg.Default) 65 | } 66 | 67 | rt := new(Router) 68 | rt.DB = schemaConfig.DB 69 | rt.nodes = schemaConfig.Nodes 70 | rt.Rules = make(map[string]*Rule, len(schemaConfig.RulesConifg.ShardRule)) 71 | rt.DefaultRule = NewDefaultRule(rt.DB, schemaConfig.RulesConifg.Default) 72 | 73 | for _, shard := range schemaConfig.RulesConifg.ShardRule { 74 | rc := &RuleConfig{shard} 75 | for _, node := range shard.Nodes { 76 | if !includeNode(rt.nodes, node) { 77 | return nil, fmt.Errorf("shard table[%s] node[%s] not in the schema.nodes list:[%s].", 78 | shard.Table, node, strings.Join(shard.Nodes, ",")) 79 | } 80 | } 81 | rule, err := rc.ParseRule(rt.DB) 82 | if err 
!= nil { 83 | return nil, err 84 | } 85 | 86 | if rule.Type == DefaultRuleType { 87 | return nil, fmt.Errorf("[default-rule] duplicate, must only one.") 88 | } else { 89 | if _, ok := rt.Rules[rule.Table]; ok { 90 | return nil, fmt.Errorf("table %s rule in %s duplicate", rule.Table, rule.DB) 91 | } 92 | rt.Rules[rule.Table] = rule 93 | } 94 | } 95 | return rt, nil 96 | } 97 | 98 | func includeNode(nodes []string, node string) bool { 99 | for _, n := range nodes { 100 | if n == node { 101 | return true 102 | } 103 | } 104 | return false 105 | } 106 | -------------------------------------------------------------------------------- /sqlparser/parsed_query.go: -------------------------------------------------------------------------------- 1 | // Copyright 2012, Google Inc. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | 5 | package sqlparser 6 | 7 | import ( 8 | "bytes" 9 | "encoding/json" 10 | "fmt" 11 | "strconv" 12 | 13 | "github.com/siddontang/mixer/sqltypes" 14 | ) 15 | 16 | type BindLocation struct { 17 | Offset, Length int 18 | } 19 | 20 | type ParsedQuery struct { 21 | Query string 22 | BindLocations []BindLocation 23 | } 24 | 25 | type EncoderFunc func(value interface{}) ([]byte, error) 26 | 27 | func (pq *ParsedQuery) GenerateQuery(bindVariables map[string]interface{}, listVariables []sqltypes.Value) ([]byte, error) { 28 | if len(pq.BindLocations) == 0 { 29 | return []byte(pq.Query), nil 30 | } 31 | buf := bytes.NewBuffer(make([]byte, 0, len(pq.Query))) 32 | current := 0 33 | for _, loc := range pq.BindLocations { 34 | buf.WriteString(pq.Query[current:loc.Offset]) 35 | varName := pq.Query[loc.Offset+1 : loc.Offset+loc.Length] 36 | var supplied interface{} 37 | if varName[0] >= '0' && varName[0] <= '9' { 38 | index, err := strconv.Atoi(varName) 39 | if err != nil { 40 | return nil, fmt.Errorf("unexpected: %v for %s", err, varName) 41 | } 42 | if index >= 
len(listVariables) { 43 | return nil, fmt.Errorf("index out of range: %d", index) 44 | } 45 | supplied = listVariables[index] 46 | } else if varName[0] == '*' { 47 | supplied = listVariables 48 | } else { 49 | var ok bool 50 | supplied, ok = bindVariables[varName] 51 | if !ok { 52 | return nil, fmt.Errorf("missing bind var %s", varName) 53 | } 54 | } 55 | if err := EncodeValue(buf, supplied); err != nil { 56 | return nil, err 57 | } 58 | current = loc.Offset + loc.Length 59 | } 60 | buf.WriteString(pq.Query[current:]) 61 | return buf.Bytes(), nil 62 | } 63 | 64 | func (pq *ParsedQuery) MarshalJSON() ([]byte, error) { 65 | return json.Marshal(pq.Query) 66 | } 67 | 68 | func EncodeValue(buf *bytes.Buffer, value interface{}) error { 69 | switch bindVal := value.(type) { 70 | case nil: 71 | buf.WriteString("null") 72 | case []sqltypes.Value: 73 | for i := 0; i < len(bindVal); i++ { 74 | if i != 0 { 75 | buf.WriteString(", ") 76 | } 77 | if err := EncodeValue(buf, bindVal[i]); err != nil { 78 | return err 79 | } 80 | } 81 | case [][]sqltypes.Value: 82 | for i := 0; i < len(bindVal); i++ { 83 | if i != 0 { 84 | buf.WriteString(", ") 85 | } 86 | buf.WriteByte('(') 87 | if err := EncodeValue(buf, bindVal[i]); err != nil { 88 | return err 89 | } 90 | buf.WriteByte(')') 91 | } 92 | default: 93 | v, err := sqltypes.BuildValue(bindVal) 94 | if err != nil { 95 | return err 96 | } 97 | v.EncodeSql(buf) 98 | } 99 | return nil 100 | } 101 | -------------------------------------------------------------------------------- /mysql/resultset_sort.go: -------------------------------------------------------------------------------- 1 | package mysql 2 | 3 | import ( 4 | "bytes" 5 | "fmt" 6 | "github.com/siddontang/mixer/hack" 7 | "sort" 8 | ) 9 | 10 | const ( 11 | SortAsc = "asc" 12 | SortDesc = "desc" 13 | ) 14 | 15 | type SortKey struct { 16 | //name of the field 17 | Name string 18 | 19 | Direction string 20 | 21 | //column index of the field 22 | column int 23 | } 24 | 25 | type 
resultsetSorter struct { 26 | *Resultset 27 | 28 | sk []SortKey 29 | } 30 | 31 | func newResultsetSorter(r *Resultset, sk []SortKey) (*resultsetSorter, error) { 32 | s := new(resultsetSorter) 33 | 34 | s.Resultset = r 35 | 36 | for i, k := range sk { 37 | if column, ok := r.FieldNames[k.Name]; ok { 38 | sk[i].column = column 39 | } else { 40 | return nil, fmt.Errorf("key %s not in resultset fields, can not sort", k.Name) 41 | } 42 | } 43 | 44 | s.sk = sk 45 | 46 | return s, nil 47 | } 48 | 49 | func (r *resultsetSorter) Len() int { 50 | return r.RowNumber() 51 | } 52 | 53 | func (r *resultsetSorter) Less(i, j int) bool { 54 | v1 := r.Values[i] 55 | v2 := r.Values[j] 56 | 57 | for _, k := range r.sk { 58 | v := cmpValue(v1[k.column], v2[k.column]) 59 | 60 | if k.Direction == SortDesc { 61 | v = -v 62 | } 63 | 64 | if v < 0 { 65 | return true 66 | } else if v > 0 { 67 | return false 68 | } 69 | 70 | //equal, cmp next key 71 | } 72 | 73 | return false 74 | } 75 | 76 | //compare value using asc 77 | func cmpValue(v1 interface{}, v2 interface{}) int { 78 | if v1 == nil && v2 == nil { 79 | return 0 80 | } else if v1 == nil { 81 | return -1 82 | } else if v2 == nil { 83 | return 1 84 | } 85 | 86 | switch v := v1.(type) { 87 | case string: 88 | s := v2.(string) 89 | return bytes.Compare(hack.Slice(v), hack.Slice(s)) 90 | case []byte: 91 | s := v2.([]byte) 92 | return bytes.Compare(v, s) 93 | case int64: 94 | s := v2.(int64) 95 | if v < s { 96 | return -1 97 | } else if v > s { 98 | return 1 99 | } else { 100 | return 0 101 | } 102 | case uint64: 103 | s := v2.(uint64) 104 | if v < s { 105 | return -1 106 | } else if v > s { 107 | return 1 108 | } else { 109 | return 0 110 | } 111 | case float64: 112 | s := v2.(float64) 113 | if v < s { 114 | return -1 115 | } else if v > s { 116 | return 1 117 | } else { 118 | return 0 119 | } 120 | default: 121 | //can not go here 122 | panic(fmt.Sprintf("invalid type %T", v)) 123 | } 124 | } 125 | 126 | func (r *resultsetSorter) Swap(i, 
j int) { 127 | r.Values[i], r.Values[j] = r.Values[j], r.Values[i] 128 | 129 | r.RowDatas[i], r.RowDatas[j] = r.RowDatas[j], r.RowDatas[i] 130 | } 131 | 132 | func (r *Resultset) Sort(sk []SortKey) error { 133 | s, err := newResultsetSorter(r, sk) 134 | 135 | if err != nil { 136 | return err 137 | } 138 | 139 | sort.Sort(s) 140 | 141 | return nil 142 | } 143 | -------------------------------------------------------------------------------- /proxy/conn_select.go: -------------------------------------------------------------------------------- 1 | package proxy 2 | 3 | import ( 4 | "bytes" 5 | "fmt" 6 | . "github.com/siddontang/mixer/mysql" 7 | "github.com/siddontang/mixer/sqlparser" 8 | "strings" 9 | ) 10 | 11 | func (c *Conn) handleSimpleSelect(sql string, stmt *sqlparser.SimpleSelect) error { 12 | if len(stmt.SelectExprs) != 1 { 13 | return fmt.Errorf("support select one informaction function, %s", sql) 14 | } 15 | 16 | expr, ok := stmt.SelectExprs[0].(*sqlparser.NonStarExpr) 17 | if !ok { 18 | return fmt.Errorf("support select informaction function, %s", sql) 19 | } 20 | 21 | var f *sqlparser.FuncExpr 22 | f, ok = expr.Expr.(*sqlparser.FuncExpr) 23 | if !ok { 24 | return fmt.Errorf("support select informaction function, %s", sql) 25 | } 26 | 27 | var r *Resultset 28 | var err error 29 | 30 | switch strings.ToLower(string(f.Name)) { 31 | case "last_insert_id": 32 | r, err = c.buildSimpleSelectResult(c.lastInsertId, f.Name, expr.As) 33 | case "row_count": 34 | r, err = c.buildSimpleSelectResult(c.affectedRows, f.Name, expr.As) 35 | case "version": 36 | r, err = c.buildSimpleSelectResult(ServerVersion, f.Name, expr.As) 37 | case "connection_id": 38 | r, err = c.buildSimpleSelectResult(c.connectionId, f.Name, expr.As) 39 | case "database": 40 | if c.schema != nil { 41 | r, err = c.buildSimpleSelectResult(c.schema.db, f.Name, expr.As) 42 | } else { 43 | r, err = c.buildSimpleSelectResult("NULL", f.Name, expr.As) 44 | } 45 | default: 46 | return fmt.Errorf("function %s 
not support", f.Name) 47 | } 48 | 49 | if err != nil { 50 | return err 51 | } 52 | 53 | return c.writeResultset(c.status, r) 54 | } 55 | 56 | func (c *Conn) buildSimpleSelectResult(value interface{}, name []byte, asName []byte) (*Resultset, error) { 57 | field := &Field{} 58 | 59 | field.Name = name 60 | 61 | if asName != nil { 62 | field.Name = asName 63 | } 64 | 65 | field.OrgName = name 66 | 67 | formatField(field, value) 68 | 69 | r := &Resultset{Fields: []*Field{field}} 70 | row, err := formatValue(value) 71 | if err != nil { 72 | return nil, err 73 | } 74 | r.RowDatas = append(r.RowDatas, PutLengthEncodedString(row)) 75 | 76 | return r, nil 77 | } 78 | 79 | func (c *Conn) handleFieldList(data []byte) error { 80 | index := bytes.IndexByte(data, 0x00) 81 | table := string(data[0:index]) 82 | wildcard := string(data[index+1:]) 83 | 84 | if c.schema == nil { 85 | return NewDefaultError(ER_NO_DB_ERROR) 86 | } 87 | 88 | nodeName := c.schema.rule.GetRule(table).Nodes[0] 89 | 90 | n := c.server.getNode(nodeName) 91 | 92 | co, err := n.getMasterConn() 93 | if err != nil { 94 | return err 95 | } 96 | defer co.Close() 97 | 98 | if err = co.UseDB(c.schema.db); err != nil { 99 | return err 100 | } 101 | 102 | if fs, err := co.FieldList(table, wildcard); err != nil { 103 | return err 104 | } else { 105 | return c.writeFieldList(c.status, fs) 106 | } 107 | } 108 | 109 | func (c *Conn) writeFieldList(status uint16, fs []*Field) error { 110 | c.affectedRows = int64(-1) 111 | 112 | data := make([]byte, 4, 1024) 113 | 114 | for _, v := range fs { 115 | data = data[0:4] 116 | data = append(data, v.Dump()...) 
117 | if err := c.writePacket(data); err != nil { 118 | return err 119 | } 120 | } 121 | 122 | if err := c.writeEOF(status); err != nil { 123 | return err 124 | } 125 | return nil 126 | } 127 | -------------------------------------------------------------------------------- /sqlparser/tracked_buffer.go: -------------------------------------------------------------------------------- 1 | // Copyright 2012, Google Inc. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | 5 | package sqlparser 6 | 7 | import ( 8 | "bytes" 9 | "fmt" 10 | ) 11 | 12 | // ParserError: To be deprecated. 13 | // TODO(sougou): deprecate. 14 | type ParserError struct { 15 | Message string 16 | } 17 | 18 | func NewParserError(format string, args ...interface{}) ParserError { 19 | return ParserError{fmt.Sprintf(format, args...)} 20 | } 21 | 22 | func (err ParserError) Error() string { 23 | return err.Message 24 | } 25 | 26 | func handleError(err *error) { 27 | if x := recover(); x != nil { 28 | *err = x.(error) 29 | } 30 | } 31 | 32 | // TrackedBuffer is used to rebuild a query from the ast. 33 | // bindLocations keeps track of locations in the buffer that 34 | // use bind variables for efficient future substitutions. 35 | // nodeFormatter is the formatting function the buffer will 36 | // use to format a node. By default(nil), it's FormatNode. 37 | // But you can supply a different formatting function if you 38 | // want to generate a query that's different from the default. 
39 | type TrackedBuffer struct { 40 | *bytes.Buffer 41 | bindLocations []BindLocation 42 | nodeFormatter func(buf *TrackedBuffer, node SQLNode) 43 | } 44 | 45 | func NewTrackedBuffer(nodeFormatter func(buf *TrackedBuffer, node SQLNode)) *TrackedBuffer { 46 | buf := &TrackedBuffer{ 47 | Buffer: bytes.NewBuffer(make([]byte, 0, 128)), 48 | bindLocations: make([]BindLocation, 0, 4), 49 | nodeFormatter: nodeFormatter, 50 | } 51 | return buf 52 | } 53 | 54 | // Fprintf mimics fmt.Fprintf, but limited to Node(%v), Node.Value(%s) and string(%s). 55 | // It also allows a %a for a value argument, in which case it adds tracking info for 56 | // future substitutions. 57 | func (buf *TrackedBuffer) Fprintf(format string, values ...interface{}) { 58 | end := len(format) 59 | fieldnum := 0 60 | for i := 0; i < end; { 61 | lasti := i 62 | for i < end && format[i] != '%' { 63 | i++ 64 | } 65 | if i > lasti { 66 | buf.WriteString(format[lasti:i]) 67 | } 68 | if i >= end { 69 | break 70 | } 71 | i++ // '%' 72 | switch format[i] { 73 | case 'c': 74 | switch v := values[fieldnum].(type) { 75 | case byte: 76 | buf.WriteByte(v) 77 | case rune: 78 | buf.WriteRune(v) 79 | default: 80 | panic(fmt.Sprintf("unexpected type %T", v)) 81 | } 82 | case 's': 83 | switch v := values[fieldnum].(type) { 84 | case []byte: 85 | buf.Write(v) 86 | case string: 87 | buf.WriteString(v) 88 | default: 89 | panic(fmt.Sprintf("unexpected type %T", v)) 90 | } 91 | case 'v': 92 | node := values[fieldnum].(SQLNode) 93 | if buf.nodeFormatter == nil { 94 | node.Format(buf) 95 | } else { 96 | buf.nodeFormatter(buf, node) 97 | } 98 | case 'a': 99 | buf.WriteArg(values[fieldnum].(string)) 100 | default: 101 | panic("unexpected") 102 | } 103 | fieldnum++ 104 | i++ 105 | } 106 | } 107 | 108 | // WriteArg writes a value argument into the buffer. arg should not contain 109 | // the ':' prefix. It also adds tracking info for future substitutions. 
110 | func (buf *TrackedBuffer) WriteArg(arg string) { 111 | buf.bindLocations = append(buf.bindLocations, BindLocation{buf.Len(), len(arg) + 1}) 112 | buf.WriteByte(':') 113 | buf.WriteString(arg) 114 | } 115 | 116 | func (buf *TrackedBuffer) ParsedQuery() *ParsedQuery { 117 | return &ParsedQuery{buf.String(), buf.bindLocations} 118 | } 119 | -------------------------------------------------------------------------------- /config/config_test.go: -------------------------------------------------------------------------------- 1 | package config 2 | 3 | import ( 4 | "fmt" 5 | "reflect" 6 | "testing" 7 | ) 8 | 9 | func TestConfig(t *testing.T) { 10 | var testConfigData = []byte( 11 | ` 12 | addr : 127.0.0.1:4000 13 | user : root 14 | password : 15 | log_level : error 16 | 17 | nodes : 18 | - 19 | name : node1 20 | down_after_noalive : 300 21 | idle_conns : 16 22 | rw_split: true 23 | user: root 24 | password: 25 | master : 127.0.0.1:3306 26 | slave : 127.0.0.1:4306 27 | - 28 | name : node2 29 | user: root 30 | master : 127.0.0.1:3307 31 | 32 | - 33 | name : node3 34 | down_after_noalive : 300 35 | idle_conns : 16 36 | rw_split: false 37 | user: root 38 | password: 39 | master : 127.0.0.1:3308 40 | 41 | schemas : 42 | - 43 | db : mixer 44 | nodes: [node1, node2, node3] 45 | rules: 46 | default: node1 47 | shard: 48 | - 49 | table: mixer_test_shard_hash 50 | key: id 51 | nodes: [node1, node2, node3] 52 | type: hash 53 | 54 | - 55 | table: mixer_test_shard_range 56 | key: id 57 | type: range 58 | nodes: [node2, node3] 59 | range: -10000- 60 | `) 61 | 62 | cfg, err := ParseConfigData(testConfigData) 63 | if err != nil { 64 | t.Fatal(err) 65 | } 66 | 67 | if len(cfg.Nodes) != 3 { 68 | t.Fatal(len(cfg.Nodes)) 69 | } 70 | 71 | if len(cfg.Schemas) != 1 { 72 | t.Fatal(len(cfg.Schemas)) 73 | } 74 | 75 | testNode := NodeConfig{ 76 | Name: "node1", 77 | DownAfterNoAlive: 300, 78 | IdleConns: 16, 79 | RWSplit: true, 80 | 81 | User: "root", 82 | Password: "", 83 | 84 | Master: 
"127.0.0.1:3306", 85 | Slave: "127.0.0.1:4306", 86 | } 87 | 88 | if !reflect.DeepEqual(cfg.Nodes[0], testNode) { 89 | fmt.Printf("%v\n", cfg.Nodes[0]) 90 | t.Fatal("node1 must equal") 91 | } 92 | 93 | testNode_2 := NodeConfig{ 94 | Name: "node2", 95 | User: "root", 96 | Master: "127.0.0.1:3307", 97 | } 98 | 99 | if !reflect.DeepEqual(cfg.Nodes[1], testNode_2) { 100 | t.Fatal("node2 must equal") 101 | } 102 | 103 | testShard_1 := ShardConfig{ 104 | Table: "mixer_test_shard_hash", 105 | Key: "id", 106 | Nodes: []string{"node1", "node2", "node3"}, 107 | Type: "hash", 108 | } 109 | if !reflect.DeepEqual(cfg.Schemas[0].RulesConifg.ShardRule[0], testShard_1) { 110 | t.Fatal("ShardConfig0 must equal") 111 | } 112 | 113 | testShard_2 := ShardConfig{ 114 | Table: "mixer_test_shard_range", 115 | Key: "id", 116 | Nodes: []string{"node2", "node3"}, 117 | Type: "range", 118 | Range: "-10000-", 119 | } 120 | if !reflect.DeepEqual(cfg.Schemas[0].RulesConifg.ShardRule[1], testShard_2) { 121 | t.Fatal("ShardConfig1 must equal") 122 | } 123 | 124 | if 2 != len(cfg.Schemas[0].RulesConifg.ShardRule) { 125 | t.Fatal("ShardRule must 2") 126 | } 127 | 128 | testRules := RulesConfig{ 129 | Default: "node1", 130 | ShardRule: []ShardConfig{testShard_1, testShard_2}, 131 | } 132 | if !reflect.DeepEqual(cfg.Schemas[0].RulesConifg, testRules) { 133 | t.Fatal("RulesConfig must equal") 134 | } 135 | 136 | testSchema := SchemaConfig{ 137 | DB: "mixer", 138 | Nodes: []string{"node1", "node2", "node3"}, 139 | RulesConifg: testRules, 140 | } 141 | 142 | if !reflect.DeepEqual(cfg.Schemas[0], testSchema) { 143 | t.Fatal("schema must equal") 144 | } 145 | 146 | if cfg.LogLevel != "error" || cfg.User != "root" || cfg.Password != "" || cfg.Addr != "127.0.0.1:4000" { 147 | t.Fatal("Top Config not equal.") 148 | } 149 | } 150 | -------------------------------------------------------------------------------- /client/conn_test.go: 
-------------------------------------------------------------------------------- 1 | package client 2 | 3 | import ( 4 | "fmt" 5 | . "github.com/siddontang/mixer/mysql" 6 | "testing" 7 | ) 8 | 9 | func newTestConn() *Conn { 10 | c := new(Conn) 11 | 12 | if err := c.Connect("127.0.0.1:3306", "root", "", "mixer"); err != nil { 13 | panic(err) 14 | } 15 | 16 | return c 17 | } 18 | 19 | func TestConn_Connect(t *testing.T) { 20 | c := newTestConn() 21 | defer c.Close() 22 | } 23 | 24 | func TestConn_Ping(t *testing.T) { 25 | c := newTestConn() 26 | defer c.Close() 27 | 28 | if err := c.Ping(); err != nil { 29 | t.Fatal(err) 30 | } 31 | } 32 | 33 | func TestConn_DeleteTable(t *testing.T) { 34 | c := newTestConn() 35 | defer c.Close() 36 | 37 | if _, err := c.Execute("drop table if exists mixer_test_conn"); err != nil { 38 | t.Fatal(err) 39 | } 40 | } 41 | 42 | func TestConn_CreateTable(t *testing.T) { 43 | s := `CREATE TABLE IF NOT EXISTS mixer_test_conn ( 44 | id BIGINT(64) UNSIGNED NOT NULL, 45 | str VARCHAR(256), 46 | f DOUBLE, 47 | e enum("test1", "test2"), 48 | u tinyint unsigned, 49 | i tinyint, 50 | PRIMARY KEY (id) 51 | ) ENGINE=InnoDB DEFAULT CHARSET=utf8` 52 | 53 | c := newTestConn() 54 | defer c.Close() 55 | 56 | if _, err := c.Execute(s); err != nil { 57 | t.Fatal(err) 58 | } 59 | } 60 | 61 | func TestConn_Insert(t *testing.T) { 62 | s := `insert into mixer_test_conn (id, str, f, e) values(1, "a", 3.14, "test1")` 63 | 64 | c := newTestConn() 65 | defer c.Close() 66 | 67 | if pkg, err := c.Execute(s); err != nil { 68 | t.Fatal(err) 69 | } else { 70 | if pkg.AffectedRows != 1 { 71 | t.Fatal(pkg.AffectedRows) 72 | } 73 | } 74 | } 75 | 76 | func TestConn_Select(t *testing.T) { 77 | s := `select str, f, e from mixer_test_conn where id = 1` 78 | 79 | c := newTestConn() 80 | defer c.Close() 81 | 82 | if result, err := c.Execute(s); err != nil { 83 | t.Fatal(err) 84 | } else { 85 | if len(result.Fields) != 3 { 86 | t.Fatal(len(result.Fields)) 87 | } 88 | 89 | if 
len(result.Values) != 1 { 90 | t.Fatal(len(result.Values)) 91 | } 92 | 93 | if str, _ := result.GetString(0, 0); str != "a" { 94 | t.Fatal("invalid str", str) 95 | } 96 | 97 | if f, _ := result.GetFloat(0, 1); f != float64(3.14) { 98 | t.Fatal("invalid f", f) 99 | } 100 | 101 | if e, _ := result.GetString(0, 2); e != "test1" { 102 | t.Fatal("invalid e", e) 103 | } 104 | 105 | if str, _ := result.GetStringByName(0, "str"); str != "a" { 106 | t.Fatal("invalid str", str) 107 | } 108 | 109 | if f, _ := result.GetFloatByName(0, "f"); f != float64(3.14) { 110 | t.Fatal("invalid f", f) 111 | } 112 | 113 | if e, _ := result.GetStringByName(0, "e"); e != "test1" { 114 | t.Fatal("invalid e", e) 115 | } 116 | 117 | } 118 | } 119 | 120 | func TestConn_Escape(t *testing.T) { 121 | c := newTestConn() 122 | defer c.Close() 123 | 124 | e := `""''\abc` 125 | s := fmt.Sprintf(`insert into mixer_test_conn (id, str) values(5, "%s")`, 126 | Escape(e)) 127 | 128 | if _, err := c.Execute(s); err != nil { 129 | t.Fatal(err) 130 | } 131 | 132 | s = `select str from mixer_test_conn where id = ?` 133 | 134 | if r, err := c.Execute(s, 5); err != nil { 135 | t.Fatal(err) 136 | } else { 137 | str, _ := r.GetString(0, 0) 138 | if str != e { 139 | t.Fatal(str) 140 | } 141 | } 142 | } 143 | 144 | func TestConn_SetCharset(t *testing.T) { 145 | c := newTestConn() 146 | defer c.Close() 147 | 148 | if err := c.SetCharset("gb2312"); err != nil { 149 | t.Fatal(err) 150 | } 151 | } 152 | -------------------------------------------------------------------------------- /mysql/field.go: -------------------------------------------------------------------------------- 1 | package mysql 2 | 3 | import ( 4 | "encoding/binary" 5 | ) 6 | 7 | type FieldData []byte 8 | 9 | type Field struct { 10 | Data FieldData 11 | Schema []byte 12 | Table []byte 13 | OrgTable []byte 14 | Name []byte 15 | OrgName []byte 16 | Charset uint16 17 | ColumnLength uint32 18 | Type uint8 19 | Flag uint16 20 | Decimal uint8 21 | 22 | 
DefaultValueLength uint64 23 | DefaultValue []byte 24 | } 25 | 26 | func (p FieldData) Parse() (f *Field, err error) { 27 | f = new(Field) 28 | 29 | f.Data = p 30 | 31 | var n int 32 | pos := 0 33 | //skip catelog, always def 34 | n, err = SkipLengthEnodedString(p) 35 | if err != nil { 36 | return 37 | } 38 | pos += n 39 | 40 | //schema 41 | f.Schema, _, n, err = LengthEnodedString(p[pos:]) 42 | if err != nil { 43 | return 44 | } 45 | pos += n 46 | 47 | //table 48 | f.Table, _, n, err = LengthEnodedString(p[pos:]) 49 | if err != nil { 50 | return 51 | } 52 | pos += n 53 | 54 | //org_table 55 | f.OrgTable, _, n, err = LengthEnodedString(p[pos:]) 56 | if err != nil { 57 | return 58 | } 59 | pos += n 60 | 61 | //name 62 | f.Name, _, n, err = LengthEnodedString(p[pos:]) 63 | if err != nil { 64 | return 65 | } 66 | pos += n 67 | 68 | //org_name 69 | f.OrgName, _, n, err = LengthEnodedString(p[pos:]) 70 | if err != nil { 71 | return 72 | } 73 | pos += n 74 | 75 | //skip oc 76 | pos += 1 77 | 78 | //charset 79 | f.Charset = binary.LittleEndian.Uint16(p[pos:]) 80 | pos += 2 81 | 82 | //column length 83 | f.ColumnLength = binary.LittleEndian.Uint32(p[pos:]) 84 | pos += 4 85 | 86 | //type 87 | f.Type = p[pos] 88 | pos++ 89 | 90 | //flag 91 | f.Flag = binary.LittleEndian.Uint16(p[pos:]) 92 | pos += 2 93 | 94 | //decimals 1 95 | f.Decimal = p[pos] 96 | pos++ 97 | 98 | //filter [0x00][0x00] 99 | pos += 2 100 | 101 | f.DefaultValue = nil 102 | //if more data, command was field list 103 | if len(p) > pos { 104 | //length of default value lenenc-int 105 | f.DefaultValueLength, _, n = LengthEncodedInt(p[pos:]) 106 | pos += n 107 | 108 | if pos+int(f.DefaultValueLength) > len(p) { 109 | err = ErrMalformPacket 110 | return 111 | } 112 | 113 | //default value string[$len] 114 | f.DefaultValue = p[pos:(pos + int(f.DefaultValueLength))] 115 | } 116 | 117 | return 118 | } 119 | 120 | func (f *Field) Dump() []byte { 121 | if f.Data != nil { 122 | return []byte(f.Data) 123 | } 124 | 125 | 
l := len(f.Schema) + len(f.Table) + len(f.OrgTable) + len(f.Name) + len(f.OrgName) + len(f.DefaultValue) + 48 126 | 127 | data := make([]byte, 0, l) 128 | 129 | data = append(data, PutLengthEncodedString([]byte("def"))...) 130 | 131 | data = append(data, PutLengthEncodedString(f.Schema)...) 132 | 133 | data = append(data, PutLengthEncodedString(f.Table)...) 134 | data = append(data, PutLengthEncodedString(f.OrgTable)...) 135 | 136 | data = append(data, PutLengthEncodedString(f.Name)...) 137 | data = append(data, PutLengthEncodedString(f.OrgName)...) 138 | 139 | data = append(data, 0x0c) 140 | 141 | data = append(data, Uint16ToBytes(f.Charset)...) 142 | data = append(data, Uint32ToBytes(f.ColumnLength)...) 143 | data = append(data, f.Type) 144 | data = append(data, Uint16ToBytes(f.Flag)...) 145 | data = append(data, f.Decimal) 146 | data = append(data, 0, 0) 147 | 148 | if f.DefaultValue != nil { 149 | data = append(data, Uint64ToBytes(f.DefaultValueLength)...) 150 | data = append(data, f.DefaultValue...) 
151 | } 152 | 153 | return data 154 | } 155 | -------------------------------------------------------------------------------- /mysql/const.go: -------------------------------------------------------------------------------- 1 | package mysql 2 | 3 | const ( 4 | MinProtocolVersion byte = 10 5 | MaxPayloadLen int = 1<<24 - 1 6 | TimeFormat string = "2006-01-02 15:04:05" 7 | ServerVersion string = "5.5.31-mixer-0.1" 8 | ) 9 | 10 | const ( 11 | OK_HEADER byte = 0x00 12 | ERR_HEADER byte = 0xff 13 | EOF_HEADER byte = 0xfe 14 | LocalInFile_HEADER byte = 0xfb 15 | ) 16 | 17 | const ( 18 | SERVER_STATUS_IN_TRANS uint16 = 0x0001 19 | SERVER_STATUS_AUTOCOMMIT uint16 = 0x0002 20 | SERVER_MORE_RESULTS_EXISTS uint16 = 0x0008 21 | SERVER_STATUS_NO_GOOD_INDEX_USED uint16 = 0x0010 22 | SERVER_STATUS_NO_INDEX_USED uint16 = 0x0020 23 | SERVER_STATUS_CURSOR_EXISTS uint16 = 0x0040 24 | SERVER_STATUS_LAST_ROW_SEND uint16 = 0x0080 25 | SERVER_STATUS_DB_DROPPED uint16 = 0x0100 26 | SERVER_STATUS_NO_BACKSLASH_ESCAPED uint16 = 0x0200 27 | SERVER_STATUS_METADATA_CHANGED uint16 = 0x0400 28 | SERVER_QUERY_WAS_SLOW uint16 = 0x0800 29 | SERVER_PS_OUT_PARAMS uint16 = 0x1000 30 | ) 31 | 32 | const ( 33 | COM_SLEEP byte = iota 34 | COM_QUIT 35 | COM_INIT_DB 36 | COM_QUERY 37 | COM_FIELD_LIST 38 | COM_CREATE_DB 39 | COM_DROP_DB 40 | COM_REFRESH 41 | COM_SHUTDOWN 42 | COM_STATISTICS 43 | COM_PROCESS_INFO 44 | COM_CONNECT 45 | COM_PROCESS_KILL 46 | COM_DEBUG 47 | COM_PING 48 | COM_TIME 49 | COM_DELAYED_INSERT 50 | COM_CHANGE_USER 51 | COM_BINLOG_DUMP 52 | COM_TABLE_DUMP 53 | COM_CONNECT_OUT 54 | COM_REGISTER_SLAVE 55 | COM_STMT_PREPARE 56 | COM_STMT_EXECUTE 57 | COM_STMT_SEND_LONG_DATA 58 | COM_STMT_CLOSE 59 | COM_STMT_RESET 60 | COM_SET_OPTION 61 | COM_STMT_FETCH 62 | COM_DAEMON 63 | COM_BINLOG_DUMP_GTID 64 | COM_RESET_CONNECTION 65 | ) 66 | 67 | const ( 68 | CLIENT_LONG_PASSWORD uint32 = 1 << iota 69 | CLIENT_FOUND_ROWS 70 | CLIENT_LONG_FLAG 71 | CLIENT_CONNECT_WITH_DB 72 | CLIENT_NO_SCHEMA 73 | 
CLIENT_COMPRESS 74 | CLIENT_ODBC 75 | CLIENT_LOCAL_FILES 76 | CLIENT_IGNORE_SPACE 77 | CLIENT_PROTOCOL_41 78 | CLIENT_INTERACTIVE 79 | CLIENT_SSL 80 | CLIENT_IGNORE_SIGPIPE 81 | CLIENT_TRANSACTIONS 82 | CLIENT_RESERVED 83 | CLIENT_SECURE_CONNECTION 84 | CLIENT_MULTI_STATEMENTS 85 | CLIENT_MULTI_RESULTS 86 | CLIENT_PS_MULTI_RESULTS 87 | CLIENT_PLUGIN_AUTH 88 | CLIENT_CONNECT_ATTRS 89 | CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA 90 | ) 91 | 92 | const ( 93 | MYSQL_TYPE_DECIMAL byte = iota 94 | MYSQL_TYPE_TINY 95 | MYSQL_TYPE_SHORT 96 | MYSQL_TYPE_LONG 97 | MYSQL_TYPE_FLOAT 98 | MYSQL_TYPE_DOUBLE 99 | MYSQL_TYPE_NULL 100 | MYSQL_TYPE_TIMESTAMP 101 | MYSQL_TYPE_LONGLONG 102 | MYSQL_TYPE_INT24 103 | MYSQL_TYPE_DATE 104 | MYSQL_TYPE_TIME 105 | MYSQL_TYPE_DATETIME 106 | MYSQL_TYPE_YEAR 107 | MYSQL_TYPE_NEWDATE 108 | MYSQL_TYPE_VARCHAR 109 | MYSQL_TYPE_BIT 110 | ) 111 | 112 | const ( 113 | MYSQL_TYPE_NEWDECIMAL byte = iota + 0xf6 114 | MYSQL_TYPE_ENUM 115 | MYSQL_TYPE_SET 116 | MYSQL_TYPE_TINY_BLOB 117 | MYSQL_TYPE_MEDIUM_BLOB 118 | MYSQL_TYPE_LONG_BLOB 119 | MYSQL_TYPE_BLOB 120 | MYSQL_TYPE_VAR_STRING 121 | MYSQL_TYPE_STRING 122 | MYSQL_TYPE_GEOMETRY 123 | ) 124 | 125 | const ( 126 | NOT_NULL_FLAG = 1 127 | PRI_KEY_FLAG = 2 128 | UNIQUE_KEY_FLAG = 4 129 | BLOB_FLAG = 16 130 | UNSIGNED_FLAG = 32 131 | ZEROFILL_FLAG = 64 132 | BINARY_FLAG = 128 133 | ENUM_FLAG = 256 134 | AUTO_INCREMENT_FLAG = 512 135 | TIMESTAMP_FLAG = 1024 136 | SET_FLAG = 2048 137 | NUM_FLAG = 32768 138 | PART_KEY_FLAG = 16384 139 | GROUP_FLAG = 32768 140 | UNIQUE_FLAG = 65536 141 | ) 142 | 143 | const ( 144 | AUTH_NAME = "mysql_native_password" 145 | ) 146 | -------------------------------------------------------------------------------- /proxy/conn_resultset.go: -------------------------------------------------------------------------------- 1 | package proxy 2 | 3 | import ( 4 | "fmt" 5 | "github.com/siddontang/mixer/hack" 6 | . 
"github.com/siddontang/mixer/mysql" 7 | "strconv" 8 | ) 9 | 10 | func formatValue(value interface{}) ([]byte, error) { 11 | switch v := value.(type) { 12 | case int8: 13 | return strconv.AppendInt(nil, int64(v), 10), nil 14 | case int16: 15 | return strconv.AppendInt(nil, int64(v), 10), nil 16 | case int32: 17 | return strconv.AppendInt(nil, int64(v), 10), nil 18 | case int64: 19 | return strconv.AppendInt(nil, int64(v), 10), nil 20 | case int: 21 | return strconv.AppendInt(nil, int64(v), 10), nil 22 | case uint8: 23 | return strconv.AppendUint(nil, uint64(v), 10), nil 24 | case uint16: 25 | return strconv.AppendUint(nil, uint64(v), 10), nil 26 | case uint32: 27 | return strconv.AppendUint(nil, uint64(v), 10), nil 28 | case uint64: 29 | return strconv.AppendUint(nil, uint64(v), 10), nil 30 | case uint: 31 | return strconv.AppendUint(nil, uint64(v), 10), nil 32 | case float32: 33 | return strconv.AppendFloat(nil, float64(v), 'f', -1, 64), nil 34 | case float64: 35 | return strconv.AppendFloat(nil, float64(v), 'f', -1, 64), nil 36 | case []byte: 37 | return v, nil 38 | case string: 39 | return hack.Slice(v), nil 40 | default: 41 | return nil, fmt.Errorf("invalid type %T", value) 42 | } 43 | } 44 | 45 | func formatField(field *Field, value interface{}) error { 46 | switch value.(type) { 47 | case int8, int16, int32, int64, int: 48 | field.Charset = 63 49 | field.Type = MYSQL_TYPE_LONGLONG 50 | field.Flag = BINARY_FLAG | NOT_NULL_FLAG 51 | case uint8, uint16, uint32, uint64, uint: 52 | field.Charset = 63 53 | field.Type = MYSQL_TYPE_LONGLONG 54 | field.Flag = BINARY_FLAG | NOT_NULL_FLAG | UNSIGNED_FLAG 55 | case string, []byte: 56 | field.Charset = 33 57 | field.Type = MYSQL_TYPE_VAR_STRING 58 | default: 59 | return fmt.Errorf("unsupport type %T for resultset", value) 60 | } 61 | return nil 62 | } 63 | 64 | func (c *Conn) buildResultset(names []string, values [][]interface{}) (*Resultset, error) { 65 | r := new(Resultset) 66 | 67 | r.Fields = make([]*Field, len(names)) 
68 | 69 | var b []byte 70 | var err error 71 | 72 | for i, vs := range values { 73 | if len(vs) != len(r.Fields) { 74 | return nil, fmt.Errorf("row %d has %d column not equal %d", i, len(vs), len(r.Fields)) 75 | } 76 | 77 | var row []byte 78 | for j, value := range vs { 79 | if i == 0 { 80 | field := &Field{} 81 | r.Fields[j] = field 82 | field.Name = hack.Slice(names[j]) 83 | 84 | if err = formatField(field, value); err != nil { 85 | return nil, err 86 | } 87 | } 88 | b, err = formatValue(value) 89 | 90 | if err != nil { 91 | return nil, err 92 | } 93 | 94 | row = append(row, PutLengthEncodedString(b)...) 95 | } 96 | 97 | r.RowDatas = append(r.RowDatas, row) 98 | } 99 | 100 | return r, nil 101 | } 102 | 103 | func (c *Conn) writeResultset(status uint16, r *Resultset) error { 104 | c.affectedRows = int64(-1) 105 | 106 | columnLen := PutLengthEncodedInt(uint64(len(r.Fields))) 107 | 108 | data := make([]byte, 4, 1024) 109 | 110 | data = append(data, columnLen...) 111 | if err := c.writePacket(data); err != nil { 112 | return err 113 | } 114 | 115 | for _, v := range r.Fields { 116 | data = data[0:4] 117 | data = append(data, v.Dump()...) 118 | if err := c.writePacket(data); err != nil { 119 | return err 120 | } 121 | } 122 | 123 | if err := c.writeEOF(status); err != nil { 124 | return err 125 | } 126 | 127 | for _, v := range r.RowDatas { 128 | data = data[0:4] 129 | data = append(data, v...) 130 | if err := c.writePacket(data); err != nil { 131 | return err 132 | } 133 | } 134 | 135 | if err := c.writeEOF(status); err != nil { 136 | return err 137 | } 138 | 139 | return nil 140 | } 141 | -------------------------------------------------------------------------------- /sqlparser/analyzer.go: -------------------------------------------------------------------------------- 1 | // Copyright 2012, Google Inc. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 
4 | 5 | package sqlparser 6 | 7 | // analyzer.go contains utility analysis functions. 8 | 9 | import ( 10 | "fmt" 11 | 12 | "github.com/siddontang/mixer/sqltypes" 13 | ) 14 | 15 | // GetDBName parses the specified DML and returns the 16 | // db name if it was used to qualify the table name. 17 | // It returns an error if parsing fails or if the statement 18 | // is not a DML. 19 | func GetDBName(sql string) (string, error) { 20 | statement, err := Parse(sql) 21 | if err != nil { 22 | return "", err 23 | } 24 | switch stmt := statement.(type) { 25 | case *Insert: 26 | return string(stmt.Table.Qualifier), nil 27 | case *Update: 28 | return string(stmt.Table.Qualifier), nil 29 | case *Delete: 30 | return string(stmt.Table.Qualifier), nil 31 | } 32 | return "", fmt.Errorf("statement '%s' is not a dml", sql) 33 | } 34 | 35 | // GetTableName returns the table name from the SimpleTableExpr 36 | // only if it's a simple expression. Otherwise, it returns "". 37 | func GetTableName(node SimpleTableExpr) string { 38 | if n, ok := node.(*TableName); ok && n.Qualifier == nil { 39 | return string(n.Name) 40 | } 41 | // sub-select or '.' expression 42 | return "" 43 | } 44 | 45 | // GetColName returns the column name, only if 46 | // it's a simple expression. Otherwise, it returns "". 47 | func GetColName(node Expr) string { 48 | if n, ok := node.(*ColName); ok { 49 | return string(n.Name) 50 | } 51 | return "" 52 | } 53 | 54 | // IsColName returns true if the ValExpr is a *ColName. 55 | func IsColName(node ValExpr) bool { 56 | _, ok := node.(*ColName) 57 | return ok 58 | } 59 | 60 | // IsVal returns true if the ValExpr is a string, number or value arg. 61 | // NULL is not considered to be a value. 62 | func IsValue(node ValExpr) bool { 63 | switch node.(type) { 64 | case StrVal, NumVal, ValArg: 65 | return true 66 | } 67 | return false 68 | } 69 | 70 | // HasINCaluse returns true if an yof the conditions has an IN clause. 
71 | func HasINClause(conditions []BoolExpr) bool { 72 | for _, node := range conditions { 73 | if c, ok := node.(*ComparisonExpr); ok && c.Operator == AST_IN { 74 | return true 75 | } 76 | } 77 | return false 78 | } 79 | 80 | // IsSimpleTuple returns true if the ValExpr is a ValTuple that 81 | // contains simple values. 82 | func IsSimpleTuple(node ValExpr) bool { 83 | list, ok := node.(ValTuple) 84 | if !ok { 85 | // It's a subquery. 86 | return false 87 | } 88 | for _, n := range list { 89 | if !IsValue(n) { 90 | return false 91 | } 92 | } 93 | return true 94 | } 95 | 96 | // AsInterface converts the ValExpr to an interface. It converts 97 | // ValTuple to []interface{}, ValArg to string, StrVal to sqltypes.String, 98 | // NumVal to sqltypes.Numeric. Otherwise, it returns an error. 99 | func AsInterface(node ValExpr) (interface{}, error) { 100 | switch node := node.(type) { 101 | case ValTuple: 102 | vals := make([]interface{}, 0, len(node)) 103 | for _, val := range node { 104 | v, err := AsInterface(val) 105 | if err != nil { 106 | return nil, err 107 | } 108 | vals = append(vals, v) 109 | } 110 | return vals, nil 111 | case ValArg: 112 | return string(node), nil 113 | case StrVal: 114 | return sqltypes.MakeString(node), nil 115 | case NumVal: 116 | n, err := sqltypes.BuildNumeric(string(node)) 117 | if err != nil { 118 | return nil, fmt.Errorf("type mismatch: %s", err) 119 | } 120 | return n, nil 121 | } 122 | return nil, fmt.Errorf("unexpected node %v", node) 123 | } 124 | 125 | // StringIn is a convenience function that returns 126 | // true if str matches any of the values. 
127 | func StringIn(str string, values ...string) bool { 128 | for _, val := range values { 129 | if str == val { 130 | return true 131 | } 132 | } 133 | return false 134 | } 135 | -------------------------------------------------------------------------------- /router/shard.go: -------------------------------------------------------------------------------- 1 | // Copyright 2012, Google Inc. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | 5 | package router 6 | 7 | import ( 8 | "fmt" 9 | "github.com/siddontang/mixer/hack" 10 | "hash/crc32" 11 | "strconv" 12 | ) 13 | 14 | type KeyError string 15 | 16 | func NewKeyError(format string, args ...interface{}) KeyError { 17 | return KeyError(fmt.Sprintf(format, args...)) 18 | } 19 | 20 | func (ke KeyError) Error() string { 21 | return string(ke) 22 | } 23 | 24 | func handleError(err *error) { 25 | if x := recover(); x != nil { 26 | *err = x.(KeyError) 27 | } 28 | } 29 | 30 | func EncodeValue(value interface{}) string { 31 | switch val := value.(type) { 32 | case int: 33 | return Uint64Key(val).String() 34 | case uint64: 35 | return Uint64Key(val).String() 36 | case int64: 37 | return Uint64Key(val).String() 38 | case string: 39 | return val 40 | case []byte: 41 | return hack.String(val) 42 | } 43 | panic(NewKeyError("Unexpected key variable type %T", value)) 44 | } 45 | 46 | func HashValue(value interface{}) uint64 { 47 | switch val := value.(type) { 48 | case int: 49 | return uint64(val) 50 | case uint64: 51 | return uint64(val) 52 | case int64: 53 | return uint64(val) 54 | case string: 55 | return uint64(crc32.ChecksumIEEE(hack.Slice(val))) 56 | case []byte: 57 | return uint64(crc32.ChecksumIEEE(val)) 58 | } 59 | panic(NewKeyError("Unexpected key variable type %T", value)) 60 | } 61 | 62 | func NumValue(value interface{}) int64 { 63 | switch val := value.(type) { 64 | case int: 65 | return int64(val) 66 | case uint64: 67 | return 
int64(val) 68 | case int64: 69 | return int64(val) 70 | case string: 71 | if v, err := strconv.ParseInt(val, 10, 64); err != nil { 72 | panic(NewKeyError("invalid num format %s", v)) 73 | } else { 74 | return v 75 | } 76 | case []byte: 77 | if v, err := strconv.ParseInt(hack.String(val), 10, 64); err != nil { 78 | panic(NewKeyError("invalid num format %s", v)) 79 | } else { 80 | return v 81 | } 82 | } 83 | panic(NewKeyError("Unexpected key variable type %T", value)) 84 | } 85 | 86 | type Shard interface { 87 | FindForKey(key interface{}) int 88 | } 89 | 90 | type RangeShard interface { 91 | Shard 92 | EqualStart(key interface{}, index int) bool 93 | EqualStop(key interface{}, index int) bool 94 | } 95 | 96 | type HashShard struct { 97 | ShardNum int 98 | } 99 | 100 | func (s *HashShard) FindForKey(key interface{}) int { 101 | h := HashValue(key) 102 | 103 | return int(h % uint64(s.ShardNum)) 104 | } 105 | 106 | type NumRangeShard struct { 107 | Shards []NumKeyRange 108 | } 109 | 110 | func (s *NumRangeShard) FindForKey(key interface{}) int { 111 | v := NumValue(key) 112 | for i, r := range s.Shards { 113 | if r.Contains(v) { 114 | return i 115 | } 116 | } 117 | panic(NewKeyError("Unexpected key %v, not in range", key)) 118 | } 119 | 120 | func (s *NumRangeShard) EqualStart(key interface{}, index int) bool { 121 | v := NumValue(key) 122 | return s.Shards[index].Start == v 123 | } 124 | func (s *NumRangeShard) EqualStop(key interface{}, index int) bool { 125 | v := NumValue(key) 126 | return s.Shards[index].End == v 127 | } 128 | 129 | type KeyRangeShard struct { 130 | Shards []KeyRange 131 | } 132 | 133 | func (s *KeyRangeShard) FindForKey(key interface{}) int { 134 | v := KeyspaceId(EncodeValue(key)) 135 | for i, r := range s.Shards { 136 | if r.Contains(v) { 137 | return i 138 | } 139 | } 140 | panic(NewKeyError("Unexpected key %v, not in range", key)) 141 | } 142 | 143 | func (s *KeyRangeShard) EqualStart(key interface{}, index int) bool { 144 | v := 
KeyspaceId(EncodeValue(key)) 145 | return s.Shards[index].Start == v 146 | } 147 | func (s *KeyRangeShard) EqualStop(key interface{}, index int) bool { 148 | v := KeyspaceId(EncodeValue(key)) 149 | return s.Shards[index].End == v 150 | } 151 | 152 | type DefaultShard struct { 153 | } 154 | 155 | func (s *DefaultShard) FindForKey(key interface{}) int { 156 | return 0 157 | } 158 | -------------------------------------------------------------------------------- /client/db.go: -------------------------------------------------------------------------------- 1 | package client 2 | 3 | import ( 4 | "container/list" 5 | "fmt" 6 | . "github.com/siddontang/mixer/mysql" 7 | "sync" 8 | "sync/atomic" 9 | ) 10 | 11 | type DB struct { 12 | sync.Mutex 13 | 14 | addr string 15 | user string 16 | password string 17 | db string 18 | maxIdleConns int 19 | 20 | idleConns *list.List 21 | 22 | connNum int32 23 | } 24 | 25 | func Open(addr string, user string, password string, dbName string) (*DB, error) { 26 | db := new(DB) 27 | 28 | db.addr = addr 29 | db.user = user 30 | db.password = password 31 | db.db = dbName 32 | 33 | db.idleConns = list.New() 34 | db.connNum = 0 35 | 36 | return db, nil 37 | } 38 | 39 | func (db *DB) Addr() string { 40 | return db.addr 41 | } 42 | 43 | func (db *DB) String() string { 44 | return fmt.Sprintf("%s:%s@%s/%s?maxIdleConns=%v", 45 | db.user, db.password, db.addr, db.db, db.maxIdleConns) 46 | } 47 | 48 | func (db *DB) Close() error { 49 | db.Lock() 50 | 51 | for { 52 | if db.idleConns.Len() > 0 { 53 | v := db.idleConns.Back() 54 | co := v.Value.(*Conn) 55 | db.idleConns.Remove(v) 56 | 57 | co.Close() 58 | 59 | } else { 60 | break 61 | } 62 | } 63 | 64 | db.Unlock() 65 | 66 | return nil 67 | } 68 | 69 | func (db *DB) Ping() error { 70 | c, err := db.PopConn() 71 | if err != nil { 72 | return err 73 | } 74 | 75 | err = c.Ping() 76 | db.PushConn(c, err) 77 | return err 78 | } 79 | 80 | func (db *DB) SetMaxIdleConnNum(num int) { 81 | db.maxIdleConns = num 
82 | } 83 | 84 | func (db *DB) GetIdleConnNum() int { 85 | return db.idleConns.Len() 86 | } 87 | 88 | func (db *DB) GetConnNum() int { 89 | return int(db.connNum) 90 | } 91 | 92 | func (db *DB) newConn() (*Conn, error) { 93 | co := new(Conn) 94 | 95 | if err := co.Connect(db.addr, db.user, db.password, db.db); err != nil { 96 | return nil, err 97 | } 98 | 99 | return co, nil 100 | } 101 | 102 | func (db *DB) tryReuse(co *Conn) error { 103 | if co.IsInTransaction() { 104 | //we can not reuse a connection in transaction status 105 | if err := co.Rollback(); err != nil { 106 | return err 107 | } 108 | } 109 | 110 | if !co.IsAutoCommit() { 111 | //we can not reuse a connection not in autocomit 112 | if _, err := co.exec("set autocommit = 1"); err != nil { 113 | return err 114 | } 115 | } 116 | 117 | //connection may be set names early 118 | //we must use default utf8 119 | if co.GetCharset() != DEFAULT_CHARSET { 120 | if err := co.SetCharset(DEFAULT_CHARSET); err != nil { 121 | return err 122 | } 123 | } 124 | 125 | return nil 126 | } 127 | 128 | func (db *DB) PopConn() (co *Conn, err error) { 129 | db.Lock() 130 | if db.idleConns.Len() > 0 { 131 | v := db.idleConns.Front() 132 | co = v.Value.(*Conn) 133 | db.idleConns.Remove(v) 134 | } 135 | db.Unlock() 136 | 137 | if co != nil { 138 | if err := co.Ping(); err == nil { 139 | if err := db.tryReuse(co); err == nil { 140 | //connection may alive 141 | return co, nil 142 | } 143 | } 144 | co.Close() 145 | } 146 | 147 | co, err = db.newConn() 148 | if err == nil { 149 | atomic.AddInt32(&db.connNum, 1) 150 | } 151 | return 152 | } 153 | 154 | func (db *DB) PushConn(co *Conn, err error) { 155 | var closeConn *Conn = nil 156 | 157 | if err != nil { 158 | closeConn = co 159 | } else { 160 | if db.maxIdleConns > 0 { 161 | db.Lock() 162 | 163 | if db.idleConns.Len() >= db.maxIdleConns { 164 | v := db.idleConns.Front() 165 | closeConn = v.Value.(*Conn) 166 | db.idleConns.Remove(v) 167 | } 168 | 169 | db.idleConns.PushBack(co) 170 
| 171 | db.Unlock() 172 | 173 | } else { 174 | closeConn = co 175 | } 176 | 177 | } 178 | 179 | if closeConn != nil { 180 | atomic.AddInt32(&db.connNum, -1) 181 | 182 | closeConn.Close() 183 | } 184 | } 185 | 186 | type SqlConn struct { 187 | *Conn 188 | 189 | db *DB 190 | } 191 | 192 | func (p *SqlConn) Close() { 193 | if p.Conn != nil { 194 | p.db.PushConn(p.Conn, p.Conn.pkgErr) 195 | p.Conn = nil 196 | } 197 | } 198 | 199 | func (db *DB) GetConn() (*SqlConn, error) { 200 | c, err := db.PopConn() 201 | return &SqlConn{c, db}, err 202 | } 203 | -------------------------------------------------------------------------------- /sqlparser/parsed_query_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2012, Google Inc. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | 5 | package sqlparser 6 | 7 | import ( 8 | "testing" 9 | 10 | "github.com/siddontang/mixer/sqltypes" 11 | ) 12 | 13 | func TestParsedQuery(t *testing.T) { 14 | tcases := []struct { 15 | desc string 16 | query string 17 | bindVars map[string]interface{} 18 | listVars []sqltypes.Value 19 | output string 20 | }{ 21 | { 22 | "no subs", 23 | "select * from a where id = 2", 24 | map[string]interface{}{ 25 | "id": 1, 26 | }, 27 | nil, 28 | "select * from a where id = 2", 29 | }, { 30 | "simple bindvar sub", 31 | "select * from a where id1 = :id1 and id2 = :id2", 32 | map[string]interface{}{ 33 | "id1": 1, 34 | "id2": nil, 35 | }, 36 | nil, 37 | "select * from a where id1 = 1 and id2 = null", 38 | }, { 39 | "missing bind var", 40 | "select * from a where id1 = :id1 and id2 = :id2", 41 | map[string]interface{}{ 42 | "id1": 1, 43 | }, 44 | nil, 45 | "missing bind var id2", 46 | }, { 47 | "unencodable bind var", 48 | "select * from a where id1 = :id", 49 | map[string]interface{}{ 50 | "id": make([]int, 1), 51 | }, 52 | nil, 53 | "unsupported bind variable type []int: [0]", 
54 | }, { 55 | "list var sub", 56 | "select * from a where id = :0 and name = :1", 57 | nil, 58 | []sqltypes.Value{ 59 | sqltypes.MakeNumeric([]byte("1")), 60 | sqltypes.MakeString([]byte("aa")), 61 | }, 62 | "select * from a where id = 1 and name = 'aa'", 63 | }, { 64 | "list inside bind vars", 65 | "select * from a where id in (:vals)", 66 | map[string]interface{}{ 67 | "vals": []sqltypes.Value{ 68 | sqltypes.MakeNumeric([]byte("1")), 69 | sqltypes.MakeString([]byte("aa")), 70 | }, 71 | }, 72 | nil, 73 | "select * from a where id in (1, 'aa')", 74 | }, { 75 | "two lists inside bind vars", 76 | "select * from a where id in (:vals)", 77 | map[string]interface{}{ 78 | "vals": [][]sqltypes.Value{ 79 | []sqltypes.Value{ 80 | sqltypes.MakeNumeric([]byte("1")), 81 | sqltypes.MakeString([]byte("aa")), 82 | }, 83 | []sqltypes.Value{ 84 | sqltypes.Value{}, 85 | sqltypes.MakeString([]byte("bb")), 86 | }, 87 | }, 88 | }, 89 | nil, 90 | "select * from a where id in ((1, 'aa'), (null, 'bb'))", 91 | }, { 92 | "illega list var name", 93 | "select * from a where id = :0a", 94 | nil, 95 | []sqltypes.Value{ 96 | sqltypes.MakeNumeric([]byte("1")), 97 | sqltypes.MakeString([]byte("aa")), 98 | }, 99 | `unexpected: strconv.ParseInt: parsing "0a": invalid syntax for 0a`, 100 | }, { 101 | "out of range list var index", 102 | "select * from a where id = :10", 103 | nil, 104 | []sqltypes.Value{ 105 | sqltypes.MakeNumeric([]byte("1")), 106 | sqltypes.MakeString([]byte("aa")), 107 | }, 108 | "index out of range: 10", 109 | }, 110 | } 111 | 112 | for _, tcase := range tcases { 113 | tree, err := Parse(tcase.query) 114 | if err != nil { 115 | t.Errorf("parse failed for %s: %v", tcase.desc, err) 116 | continue 117 | } 118 | buf := NewTrackedBuffer(nil) 119 | buf.Fprintf("%v", tree) 120 | pq := buf.ParsedQuery() 121 | bytes, err := pq.GenerateQuery(tcase.bindVars, tcase.listVars) 122 | var got string 123 | if err != nil { 124 | got = err.Error() 125 | } else { 126 | got = string(bytes) 127 | } 
128 | if got != tcase.output { 129 | t.Errorf("for test case: %s, got: '%s', want '%s'", tcase.desc, got, tcase.output) 130 | } 131 | } 132 | } 133 | 134 | func TestStarParam(t *testing.T) { 135 | buf := NewTrackedBuffer(nil) 136 | buf.Fprintf("select * from a where id in (%a)", "*") 137 | pq := buf.ParsedQuery() 138 | listvars := []sqltypes.Value{ 139 | sqltypes.MakeNumeric([]byte("1")), 140 | sqltypes.MakeString([]byte("aa")), 141 | } 142 | bytes, err := pq.GenerateQuery(nil, listvars) 143 | if err != nil { 144 | t.Errorf("generate failed: %v", err) 145 | return 146 | } 147 | got := string(bytes) 148 | want := "select * from a where id in (1, 'aa')" 149 | if got != want { 150 | t.Errorf("got %s, want %s", got, want) 151 | } 152 | } 153 | -------------------------------------------------------------------------------- /client/stmt.go: -------------------------------------------------------------------------------- 1 | package client 2 | 3 | import ( 4 | "encoding/binary" 5 | "fmt" 6 | . 
"github.com/siddontang/mixer/mysql" 7 | "math" 8 | ) 9 | 10 | type Stmt struct { 11 | conn *Conn 12 | id uint32 13 | query string 14 | 15 | params int 16 | columns int 17 | } 18 | 19 | func (s *Stmt) ParamNum() int { 20 | return s.params 21 | } 22 | 23 | func (s *Stmt) ColumnNum() int { 24 | return s.columns 25 | } 26 | 27 | func (s *Stmt) Execute(args ...interface{}) (*Result, error) { 28 | if err := s.write(args...); err != nil { 29 | return nil, err 30 | } 31 | 32 | return s.conn.readResult(true) 33 | } 34 | 35 | func (s *Stmt) Close() error { 36 | if err := s.conn.writeCommandUint32(COM_STMT_CLOSE, s.id); err != nil { 37 | return err 38 | } 39 | 40 | return nil 41 | } 42 | 43 | func (s *Stmt) write(args ...interface{}) error { 44 | paramsNum := s.params 45 | 46 | if len(args) != paramsNum { 47 | return fmt.Errorf("argument mismatch, need %d but got %d", s.params, len(args)) 48 | } 49 | 50 | paramTypes := make([]byte, paramsNum<<1) 51 | paramValues := make([][]byte, paramsNum) 52 | 53 | //NULL-bitmap, length: (num-params+7) 54 | nullBitmap := make([]byte, (paramsNum+7)>>3) 55 | 56 | var length int = int(1 + 4 + 1 + 4 + ((paramsNum + 7) >> 3) + 1 + (paramsNum << 1)) 57 | 58 | var newParamBoundFlag byte = 0 59 | 60 | for i := range args { 61 | if args[i] == nil { 62 | nullBitmap[i/8] |= (1 << (uint(i) % 8)) 63 | paramTypes[i<<1] = MYSQL_TYPE_NULL 64 | continue 65 | } 66 | 67 | newParamBoundFlag = 1 68 | 69 | switch v := args[i].(type) { 70 | case int8: 71 | paramTypes[i<<1] = MYSQL_TYPE_TINY 72 | paramValues[i] = []byte{byte(v)} 73 | case int16: 74 | paramTypes[i<<1] = MYSQL_TYPE_SHORT 75 | paramValues[i] = Uint16ToBytes(uint16(v)) 76 | case int32: 77 | paramTypes[i<<1] = MYSQL_TYPE_LONG 78 | paramValues[i] = Uint32ToBytes(uint32(v)) 79 | case int: 80 | paramTypes[i<<1] = MYSQL_TYPE_LONGLONG 81 | paramValues[i] = Uint64ToBytes(uint64(v)) 82 | case int64: 83 | paramTypes[i<<1] = MYSQL_TYPE_LONGLONG 84 | paramValues[i] = Uint64ToBytes(uint64(v)) 85 | case uint8: 86 
| paramTypes[i<<1] = MYSQL_TYPE_TINY 87 | paramTypes[(i<<1)+1] = 0x80 88 | paramValues[i] = []byte{v} 89 | case uint16: 90 | paramTypes[i<<1] = MYSQL_TYPE_SHORT 91 | paramTypes[(i<<1)+1] = 0x80 92 | paramValues[i] = Uint16ToBytes(uint16(v)) 93 | case uint32: 94 | paramTypes[i<<1] = MYSQL_TYPE_LONG 95 | paramTypes[(i<<1)+1] = 0x80 96 | paramValues[i] = Uint32ToBytes(uint32(v)) 97 | case uint: 98 | paramTypes[i<<1] = MYSQL_TYPE_LONGLONG 99 | paramTypes[(i<<1)+1] = 0x80 100 | paramValues[i] = Uint64ToBytes(uint64(v)) 101 | case uint64: 102 | paramTypes[i<<1] = MYSQL_TYPE_LONGLONG 103 | paramTypes[(i<<1)+1] = 0x80 104 | paramValues[i] = Uint64ToBytes(uint64(v)) 105 | case bool: 106 | paramTypes[i<<1] = MYSQL_TYPE_TINY 107 | if v { 108 | paramValues[i] = []byte{1} 109 | } else { 110 | paramValues[i] = []byte{0} 111 | 112 | } 113 | case float32: 114 | paramTypes[i<<1] = MYSQL_TYPE_FLOAT 115 | paramValues[i] = Uint32ToBytes(math.Float32bits(v)) 116 | case float64: 117 | paramTypes[i<<1] = MYSQL_TYPE_DOUBLE 118 | paramValues[i] = Uint64ToBytes(math.Float64bits(v)) 119 | case string: 120 | paramTypes[i<<1] = MYSQL_TYPE_STRING 121 | paramValues[i] = append(PutLengthEncodedInt(uint64(len(v))), v...) 122 | case []byte: 123 | paramTypes[i<<1] = MYSQL_TYPE_STRING 124 | paramValues[i] = append(PutLengthEncodedInt(uint64(len(v))), v...) 125 | default: 126 | return fmt.Errorf("invalid argument type %T", args[i]) 127 | } 128 | 129 | length += len(paramValues[i]) 130 | } 131 | 132 | data := make([]byte, 4, 4+length) 133 | 134 | data = append(data, COM_STMT_EXECUTE) 135 | data = append(data, byte(s.id), byte(s.id>>8), byte(s.id>>16), byte(s.id>>24)) 136 | 137 | //flag: CURSOR_TYPE_NO_CURSOR 138 | data = append(data, 0x00) 139 | 140 | //iteration-count, always 1 141 | data = append(data, 1, 0, 0, 0) 142 | 143 | if s.params > 0 { 144 | data = append(data, nullBitmap...) 
145 | 146 | //new-params-bound-flag 147 | data = append(data, newParamBoundFlag) 148 | 149 | if newParamBoundFlag == 1 { 150 | //type of each parameter, length: num-params * 2 151 | data = append(data, paramTypes...) 152 | 153 | //value of each parameter 154 | for _, v := range paramValues { 155 | data = append(data, v...) 156 | } 157 | } 158 | } 159 | 160 | s.conn.pkg.Sequence = 0 161 | 162 | return s.conn.writePacket(data) 163 | } 164 | 165 | func (c *Conn) Prepare(query string) (*Stmt, error) { 166 | if err := c.writeCommandStr(COM_STMT_PREPARE, query); err != nil { 167 | return nil, err 168 | } 169 | 170 | data, err := c.readPacket() 171 | if err != nil { 172 | return nil, err 173 | } 174 | 175 | if data[0] == ERR_HEADER { 176 | return nil, c.handleErrorPacket(data) 177 | } else if data[0] != OK_HEADER { 178 | return nil, ErrMalformPacket 179 | } 180 | 181 | s := new(Stmt) 182 | s.conn = c 183 | 184 | pos := 1 185 | 186 | //for statement id 187 | s.id = binary.LittleEndian.Uint32(data[pos:]) 188 | pos += 4 189 | 190 | //number columns 191 | s.columns = int(binary.LittleEndian.Uint16(data[pos:])) 192 | pos += 2 193 | 194 | //number params 195 | s.params = int(binary.LittleEndian.Uint16(data[pos:])) 196 | pos += 2 197 | 198 | //warnings 199 | //warnings = binary.LittleEndian.Uint16(data[pos:]) 200 | 201 | if s.params > 0 { 202 | if err := s.conn.readUntilEOF(); err != nil { 203 | return nil, err 204 | } 205 | } 206 | 207 | if s.columns > 0 { 208 | if err := s.conn.readUntilEOF(); err != nil { 209 | return nil, err 210 | } 211 | } 212 | 213 | return s, nil 214 | } 215 | -------------------------------------------------------------------------------- /proxy/conn_show.go: -------------------------------------------------------------------------------- 1 | package proxy 2 | 3 | import ( 4 | "fmt" 5 | "github.com/siddontang/go-log/log" 6 | "github.com/siddontang/mixer/hack" 7 | . 
"github.com/siddontang/mixer/mysql" 8 | "github.com/siddontang/mixer/sqlparser" 9 | "sort" 10 | "strings" 11 | "time" 12 | ) 13 | 14 | func (c *Conn) handleShow(sql string, stmt *sqlparser.Show) error { 15 | var err error 16 | var r *Resultset 17 | switch strings.ToLower(stmt.Section) { 18 | case "databases": 19 | r, err = c.handleShowDatabases() 20 | case "tables": 21 | r, err = c.handleShowTables(sql, stmt) 22 | case "proxy": 23 | r, err = c.handleShowProxy(sql, stmt) 24 | default: 25 | err = fmt.Errorf("unsupport show %s now", sql) 26 | } 27 | 28 | if err != nil { 29 | return err 30 | } 31 | 32 | return c.writeResultset(c.status, r) 33 | } 34 | 35 | func (c *Conn) handleShowDatabases() (*Resultset, error) { 36 | dbs := make([]interface{}, 0, len(c.server.schemas)) 37 | for key := range c.server.schemas { 38 | dbs = append(dbs, key) 39 | } 40 | 41 | return c.buildSimpleShowResultset(dbs, "Database") 42 | } 43 | 44 | func (c *Conn) handleShowTables(sql string, stmt *sqlparser.Show) (*Resultset, error) { 45 | s := c.schema 46 | if stmt.From != nil { 47 | db := nstring(stmt.From) 48 | s = c.server.getSchema(db) 49 | } 50 | 51 | if s == nil { 52 | return nil, NewDefaultError(ER_NO_DB_ERROR) 53 | } 54 | 55 | var tables []string 56 | tmap := map[string]struct{}{} 57 | for _, n := range s.nodes { 58 | co, err := n.getMasterConn() 59 | if err != nil { 60 | return nil, err 61 | } 62 | 63 | if err := co.UseDB(s.db); err != nil { 64 | co.Close() 65 | return nil, err 66 | } 67 | 68 | if r, err := co.Execute(sql); err != nil { 69 | co.Close() 70 | return nil, err 71 | } else { 72 | co.Close() 73 | for i := 0; i < r.RowNumber(); i++ { 74 | n, _ := r.GetString(i, 0) 75 | if _, ok := tmap[n]; !ok { 76 | tables = append(tables, n) 77 | } 78 | } 79 | } 80 | } 81 | 82 | sort.Strings(tables) 83 | 84 | values := make([]interface{}, len(tables)) 85 | for i := range tables { 86 | values[i] = tables[i] 87 | } 88 | 89 | return c.buildSimpleShowResultset(values, 
fmt.Sprintf("Tables_in_%s", s.db)) 90 | } 91 | 92 | func (c *Conn) handleShowProxy(sql string, stmt *sqlparser.Show) (*Resultset, error) { 93 | var err error 94 | var r *Resultset 95 | switch strings.ToLower(stmt.Key) { 96 | case "config": 97 | r, err = c.handleShowProxyConfig() 98 | case "status": 99 | r, err = c.handleShowProxyStatus(sql, stmt) 100 | default: 101 | err = fmt.Errorf("Unsupport show proxy [%v] yet, just support [config|status] now.", stmt.Key) 102 | log.Warn(err.Error()) 103 | return nil, err 104 | } 105 | return r, err 106 | } 107 | 108 | func (c *Conn) handleShowProxyConfig() (*Resultset, error) { 109 | var names []string = []string{"Section", "Key", "Value"} 110 | var rows [][]string 111 | const ( 112 | Column = 3 113 | ) 114 | 115 | rows = append(rows, []string{"Global_Config", "Addr", c.server.cfg.Addr}) 116 | rows = append(rows, []string{"Global_Config", "User", c.server.cfg.User}) 117 | rows = append(rows, []string{"Global_Config", "Password", c.server.cfg.Password}) 118 | rows = append(rows, []string{"Global_Config", "LogLevel", c.server.cfg.LogLevel}) 119 | rows = append(rows, []string{"Global_Config", "Schemas_Count", fmt.Sprintf("%d", len(c.server.schemas))}) 120 | rows = append(rows, []string{"Global_Config", "Nodes_Count", fmt.Sprintf("%d", len(c.server.nodes))}) 121 | 122 | for db, schema := range c.server.schemas { 123 | rows = append(rows, []string{"Schemas", "DB", db}) 124 | 125 | var nodeNames []string 126 | var nodeRows [][]string 127 | for name, node := range schema.nodes { 128 | nodeNames = append(nodeNames, name) 129 | var nodeSection = fmt.Sprintf("Schemas[%s]-Node[ %v ]", db, name) 130 | 131 | if node.master != nil { 132 | nodeRows = append(nodeRows, []string{nodeSection, "Master", node.master.String()}) 133 | } 134 | 135 | if node.slave != nil { 136 | nodeRows = append(nodeRows, []string{nodeSection, "Slave", node.slave.String()}) 137 | } 138 | nodeRows = append(nodeRows, []string{nodeSection, "Last_Master_Ping", 
fmt.Sprintf("%v", time.Unix(node.lastMasterPing, 0))}) 139 | 140 | nodeRows = append(nodeRows, []string{nodeSection, "Last_Slave_Ping", fmt.Sprintf("%v", time.Unix(node.lastSlavePing, 0))}) 141 | 142 | nodeRows = append(nodeRows, []string{nodeSection, "down_after_noalive", fmt.Sprintf("%v", node.downAfterNoAlive)}) 143 | 144 | } 145 | rows = append(rows, []string{fmt.Sprintf("Schemas[%s]", db), "Nodes_List", strings.Join(nodeNames, ",")}) 146 | 147 | var defaultRule = schema.rule.DefaultRule 148 | if defaultRule.DB == db { 149 | if defaultRule.DB == db { 150 | rows = append(rows, []string{fmt.Sprintf("Schemas[%s]_Rule_Default", db), 151 | "Default_Table", defaultRule.String()}) 152 | } 153 | } 154 | for tb, r := range schema.rule.Rules { 155 | if r.DB == db { 156 | rows = append(rows, []string{fmt.Sprintf("Schemas[%s]_Rule_Table", db), 157 | fmt.Sprintf("Table[ %s ]", tb), r.String()}) 158 | } 159 | } 160 | 161 | rows = append(rows, nodeRows...) 162 | 163 | } 164 | 165 | var values [][]interface{} = make([][]interface{}, len(rows)) 166 | for i := range rows { 167 | values[i] = make([]interface{}, Column) 168 | for j := range rows[i] { 169 | values[i][j] = rows[i][j] 170 | } 171 | } 172 | 173 | return c.buildResultset(names, values) 174 | } 175 | 176 | func (c *Conn) handleShowProxyStatus(sql string, stmt *sqlparser.Show) (*Resultset, error) { 177 | // TODO: handle like_or_where expr 178 | return nil, nil 179 | } 180 | 181 | func (c *Conn) buildSimpleShowResultset(values []interface{}, name string) (*Resultset, error) { 182 | 183 | r := new(Resultset) 184 | 185 | field := &Field{} 186 | 187 | field.Name = hack.Slice(name) 188 | field.Charset = 33 189 | field.Type = MYSQL_TYPE_VAR_STRING 190 | 191 | r.Fields = []*Field{field} 192 | 193 | var row []byte 194 | var err error 195 | 196 | for _, value := range values { 197 | row, err = formatValue(value) 198 | if err != nil { 199 | return nil, err 200 | } 201 | r.RowDatas = append(r.RowDatas, 202 | 
PutLengthEncodedString(row)) 203 | } 204 | 205 | return r, nil 206 | } 207 | -------------------------------------------------------------------------------- /proxy/conn_stmt_test.go: -------------------------------------------------------------------------------- 1 | package proxy 2 | 3 | import ( 4 | "testing" 5 | ) 6 | 7 | func TestStmt_DropTable(t *testing.T) { 8 | server := newTestServer(t) 9 | n := server.nodes["node1"] 10 | c, err := n.getMasterConn() 11 | if err != nil { 12 | t.Fatal(err) 13 | } 14 | c.UseDB("mixer") 15 | if _, err := c.Execute(`drop table if exists mixer_test_proxy_stmt`); err != nil { 16 | t.Fatal(err) 17 | } 18 | c.Close() 19 | } 20 | 21 | func TestStmt_CreateTable(t *testing.T) { 22 | str := `CREATE TABLE IF NOT EXISTS mixer_test_proxy_stmt ( 23 | id BIGINT(64) UNSIGNED NOT NULL, 24 | str VARCHAR(256), 25 | f DOUBLE, 26 | e enum("test1", "test2"), 27 | u tinyint unsigned, 28 | i tinyint, 29 | PRIMARY KEY (id) 30 | ) ENGINE=InnoDB DEFAULT CHARSET=utf8` 31 | 32 | server := newTestServer(t) 33 | n := server.nodes["node1"] 34 | c, err := n.getMasterConn() 35 | if err != nil { 36 | t.Fatal(err) 37 | } 38 | 39 | c.UseDB("mixer") 40 | defer c.Close() 41 | if _, err := c.Execute(str); err != nil { 42 | t.Fatal(err) 43 | } 44 | } 45 | 46 | func TestStmt_Insert(t *testing.T) { 47 | str := `insert into mixer_test_proxy_stmt (id, str, f, e, u, i) values (?, ?, ?, ?, ?, ?)` 48 | 49 | c := newTestDBConn(t) 50 | defer c.Close() 51 | 52 | s, err := c.Prepare(str) 53 | 54 | if err != nil { 55 | t.Fatal(err) 56 | } 57 | 58 | if pkg, err := s.Execute(1, "a", 3.14, "test1", 255, -127); err != nil { 59 | t.Fatal(err) 60 | } else { 61 | if pkg.AffectedRows != 1 { 62 | t.Fatal(pkg.AffectedRows) 63 | } 64 | } 65 | 66 | s.Close() 67 | } 68 | 69 | func TestStmt_Select(t *testing.T) { 70 | str := `select str, f, e from mixer_test_proxy_stmt where id = ?` 71 | 72 | c := newTestDBConn(t) 73 | defer c.Close() 74 | 75 | s, err := c.Prepare(str) 76 | 77 | if err != 
nil { 78 | t.Fatal(err) 79 | } 80 | 81 | if result, err := s.Execute(1); err != nil { 82 | t.Fatal(err) 83 | } else { 84 | if len(result.Values) != 1 { 85 | t.Fatal(len(result.Values)) 86 | } 87 | 88 | if len(result.Fields) != 3 { 89 | t.Fatal(len(result.Fields)) 90 | } 91 | 92 | if str, _ := result.GetString(0, 0); str != "a" { 93 | t.Fatal("invalid str", str) 94 | } 95 | 96 | if f, _ := result.GetFloat(0, 1); f != float64(3.14) { 97 | t.Fatal("invalid f", f) 98 | } 99 | 100 | if e, _ := result.GetString(0, 2); e != "test1" { 101 | t.Fatal("invalid e", e) 102 | } 103 | 104 | if str, _ := result.GetStringByName(0, "str"); str != "a" { 105 | t.Fatal("invalid str", str) 106 | } 107 | 108 | if f, _ := result.GetFloatByName(0, "f"); f != float64(3.14) { 109 | t.Fatal("invalid f", f) 110 | } 111 | 112 | if e, _ := result.GetStringByName(0, "e"); e != "test1" { 113 | t.Fatal("invalid e", e) 114 | } 115 | 116 | } 117 | 118 | s.Close() 119 | } 120 | 121 | func TestStmt_NULL(t *testing.T) { 122 | str := `insert into mixer_test_proxy_stmt (id, str, f, e) values (?, ?, ?, ?)` 123 | 124 | c := newTestDBConn(t) 125 | defer c.Close() 126 | 127 | s, err := c.Prepare(str) 128 | 129 | if err != nil { 130 | t.Fatal(err) 131 | } 132 | 133 | if pkg, err := s.Execute(2, nil, 3.14, nil); err != nil { 134 | t.Fatal(err) 135 | } else { 136 | if pkg.AffectedRows != 1 { 137 | t.Fatal(pkg.AffectedRows) 138 | } 139 | } 140 | 141 | s.Close() 142 | 143 | str = `select * from mixer_test_proxy_stmt where id = ?` 144 | s, err = c.Prepare(str) 145 | 146 | if err != nil { 147 | t.Fatal(err) 148 | } 149 | 150 | if r, err := s.Execute(2); err != nil { 151 | t.Fatal(err) 152 | } else { 153 | if b, err := r.IsNullByName(0, "id"); err != nil { 154 | t.Fatal(err) 155 | } else if b { 156 | t.Fatal(b) 157 | } 158 | 159 | if b, err := r.IsNullByName(0, "str"); err != nil { 160 | t.Fatal(err) 161 | } else if !b { 162 | t.Fatal(b) 163 | } 164 | 165 | if b, err := r.IsNullByName(0, "f"); err != 
nil { 166 | t.Fatal(err) 167 | } else if b { 168 | t.Fatal(b) 169 | } 170 | 171 | if b, err := r.IsNullByName(0, "e"); err != nil { 172 | t.Fatal(err) 173 | } else if !b { 174 | t.Fatal(b) 175 | } 176 | } 177 | 178 | s.Close() 179 | } 180 | 181 | func TestStmt_Unsigned(t *testing.T) { 182 | str := `insert into mixer_test_proxy_stmt (id, u) values (?, ?)` 183 | 184 | c := newTestDBConn(t) 185 | defer c.Close() 186 | 187 | s, err := c.Prepare(str) 188 | 189 | if err != nil { 190 | t.Fatal(err) 191 | } 192 | 193 | if pkg, err := s.Execute(3, uint8(255)); err != nil { 194 | t.Fatal(err) 195 | } else { 196 | if pkg.AffectedRows != 1 { 197 | t.Fatal(pkg.AffectedRows) 198 | } 199 | } 200 | 201 | s.Close() 202 | 203 | str = `select u from mixer_test_proxy_stmt where id = ?` 204 | 205 | s, err = c.Prepare(str) 206 | if err != nil { 207 | t.Fatal(err) 208 | } 209 | 210 | if r, err := s.Execute(3); err != nil { 211 | t.Fatal(err) 212 | } else { 213 | if u, err := r.GetUint(0, 0); err != nil { 214 | t.Fatal(err) 215 | } else if u != uint64(255) { 216 | t.Fatal(u) 217 | } 218 | } 219 | 220 | s.Close() 221 | } 222 | 223 | func TestStmt_Signed(t *testing.T) { 224 | str := `insert into mixer_test_proxy_stmt (id, i) values (?, ?)` 225 | 226 | c := newTestDBConn(t) 227 | defer c.Close() 228 | 229 | s, err := c.Prepare(str) 230 | 231 | if err != nil { 232 | t.Fatal(err) 233 | } 234 | 235 | if _, err := s.Execute(4, 127); err != nil { 236 | t.Fatal(err) 237 | } 238 | 239 | if _, err := s.Execute(uint64(18446744073709551516), int8(-128)); err != nil { 240 | t.Fatal(err) 241 | } 242 | 243 | s.Close() 244 | 245 | } 246 | 247 | func TestStmt_Trans(t *testing.T) { 248 | c1 := newTestDBConn(t) 249 | defer c1.Close() 250 | 251 | if _, err := c1.Execute(`insert into mixer_test_proxy_stmt (id, str) values (1002, "abc")`); err != nil { 252 | t.Fatal(err) 253 | } 254 | 255 | var err error 256 | if err = c1.Begin(); err != nil { 257 | t.Fatal(err) 258 | } 259 | 260 | str := 
`select str from mixer_test_proxy_stmt where id = ?` 261 | 262 | s, err := c1.Prepare(str) 263 | if err != nil { 264 | t.Fatal(err) 265 | } 266 | 267 | if _, err := s.Execute(1002); err != nil { 268 | t.Fatal(err) 269 | } 270 | 271 | if err := c1.Commit(); err != nil { 272 | t.Fatal(err) 273 | } 274 | 275 | if r, err := s.Execute(1002); err != nil { 276 | t.Fatal(err) 277 | } else { 278 | if str, _ := r.GetString(0, 0); str != `abc` { 279 | t.Fatal(str) 280 | } 281 | } 282 | 283 | if err := s.Close(); err != nil { 284 | t.Fatal(err) 285 | } 286 | } 287 | -------------------------------------------------------------------------------- /doc/mysql-proxy/scripting.txt: -------------------------------------------------------------------------------- 1 | Hooks 2 | ===== 3 | 4 | connect_server 5 | -------------- 6 | 7 | read_auth 8 | --------- 9 | 10 | read_auth_result 11 | ---------------- 12 | 13 | read_query 14 | ---------- 15 | 16 | read_query_result 17 | ----------------- 18 | 19 | disconnect_client 20 | ----------------- 21 | 22 | Modules 23 | ======= 24 | 25 | mysql.proto 26 | ----------- 27 | 28 | The ``mysql.proto`` module provides encoders and decoders for the packets exchanged between client and server 29 | 30 | 31 | from_err_packet 32 | ............... 33 | 34 | Decodes a ERR-packet into a table. 35 | 36 | Parameters: 37 | 38 | ``packet`` 39 | (string) mysql packet 40 | 41 | 42 | On success it returns a table containing: 43 | 44 | ``errmsg`` 45 | (string) 46 | 47 | ``sqlstate`` 48 | (string) 49 | 50 | ``errcode`` 51 | (int) 52 | 53 | Otherwise it raises an error. 54 | 55 | to_err_packet 56 | ............. 57 | 58 | Encode a table containing a ERR packet into a MySQL packet. 59 | 60 | Parameters: 61 | 62 | ``err`` 63 | (table) 64 | 65 | ``errmsg`` 66 | (string) 67 | 68 | ``sqlstate`` 69 | (string) 70 | 71 | ``errcode`` 72 | (int) 73 | 74 | into a MySQL packet. 75 | 76 | Returns a string. 77 | 78 | from_ok_packet 79 | .............. 
80 | 81 | Decodes an OK-packet. 82 | 83 | ``packet`` 84 | (string) mysql packet 85 | 86 | 87 | On success it returns a table containing: 88 | 89 | ``server_status`` 90 | (int) bit-mask of the connection status 91 | 92 | ``insert_id`` 93 | (int) last used insert id 94 | 95 | ``warnings`` 96 | (int) number of warnings for the last executed statement 97 | 98 | ``affected_rows`` 99 | (int) rows affected by the last statement 100 | 101 | Otherwise it raises an error. 102 | 103 | 104 | to_ok_packet 105 | ............ 106 | 107 | Encodes an OK packet. 108 | 109 | from_eof_packet 110 | ............... 111 | 112 | Decodes an EOF-packet. 113 | 114 | Parameters: 115 | 116 | ``packet`` 117 | (string) mysql packet 118 | 119 | 120 | On success it returns a table containing: 121 | 122 | ``server_status`` 123 | (int) bit-mask of the connection status 124 | 125 | ``warnings`` 126 | (int) 127 | 128 | Otherwise it raises an error. 129 | 130 | 131 | to_eof_packet 132 | ............. 133 | 134 | from_challenge_packet 135 | ..................... 136 | 137 | Decodes an auth-challenge-packet. 138 | 139 | Parameters: 140 | 141 | ``packet`` 142 | (string) mysql packet 143 | 144 | On success it returns a table containing: 145 | 146 | ``protocol_version`` 147 | (int) version of the mysql protocol, usually 10 148 | 149 | ``server_version`` 150 | (int) version of the server as integer: 50506 is MySQL 5.5.6 151 | 152 | ``thread_id`` 153 | (int) connection id 154 | 155 | ``capabilities`` 156 | (int) bit-mask of the server capabilities 157 | 158 | ``charset`` 159 | (int) server default character-set 160 | 161 | ``server_status`` 162 | (int) bit-mask of the connection-status 163 | 164 | ``challenge`` 165 | (string) password challenge 166 | 167 | 168 | to_challenge_packet 169 | ................... 170 | 171 | Encodes an auth-challenge-packet. 172 | 173 | from_response_packet 174 | .................... 
175 | 176 | Decodes an auth-response-packet. 177 | 178 | Parameters: 179 | 180 | ``packet`` 181 | (string) mysql packet 182 | 183 | 184 | to_response_packet 185 | .................. 186 | 187 | from_masterinfo_string 188 | ...................... 189 | 190 | Decodes the content of the ``master.info`` file. 191 | 192 | 193 | to_masterinfo_string 194 | .................... 195 | 196 | from_stmt_prepare_packet 197 | ........................ 198 | 199 | Decodes a COM_STMT_PREPARE-packet 200 | 201 | Parameters: 202 | 203 | ``packet`` 204 | (string) mysql packet 205 | 206 | 207 | On success it returns a table containing: 208 | 209 | ``stmt_text`` 210 | (string) 211 | text of the prepared statement 212 | 213 | Otherwise it raises an error. 214 | 215 | from_stmt_prepare_ok_packet 216 | ........................... 217 | 218 | Decodes the OK-packet that answers a COM_STMT_PREPARE. 219 | 220 | Parameters: 221 | 222 | ``packet`` 223 | (string) mysql packet 224 | 225 | 226 | On success it returns a table containing: 227 | 228 | ``stmt_id`` 229 | (int) statement-id 230 | 231 | ``num_columns`` 232 | (int) number of columns in the resultset 233 | 234 | ``num_params`` 235 | (int) number of parameters 236 | 237 | ``warnings`` 238 | (int) warnings generated by the prepare statement 239 | 240 | Otherwise it raises an error. 241 | 242 | 243 | from_stmt_execute_packet 244 | ........................ 
245 | 246 | Decodes a COM_STMT_EXECUTE-packet 247 | 248 | Parameters: 249 | 250 | ``packet`` 251 | (string) mysql packet 252 | 253 | ``num_params`` 254 | (int) number of parameters of the corresponding prepared statement 255 | 256 | On success it returns a table containing: 257 | 258 | ``stmt_id`` 259 | (int) statement-id 260 | 261 | ``flags`` 262 | (int) flags describing the kind of cursor used 263 | 264 | ``iteration_count`` 265 | (int) iteration count: always 1 266 | 267 | ``new_params_bound`` 268 | (bool) 269 | 270 | ``params`` 271 | (nil, table) 272 | number-indexed array of parameters if ``new_params_bound`` is ``true`` 273 | 274 | Each param is a table of: 275 | 276 | ``type`` 277 | (int) 278 | MYSQL_TYPE_INT, MYSQL_TYPE_STRING ... and so on 279 | 280 | ``value`` 281 | (nil, number, string) 282 | if the value is NULL, it is ``nil`` 283 | if it is a number (_INT, _DOUBLE, ...) it is a ``number`` 284 | otherwise it is a ``string`` 285 | 286 | If decoding fails it raises an error. 287 | 288 | To get the ``num_params`` for this function, you have to track the number of parameters as returned 289 | by the `from_stmt_prepare_ok_packet`_. Use `stmt_id_from_stmt_execute_packet`_ to get the ``statement-id`` from 290 | the COM_STMT_EXECUTE packet and look up your tracked information. 291 | 292 | stmt_id_from_stmt_execute_packet 293 | ................................ 294 | 295 | Decodes the statement-id from a COM_STMT_EXECUTE-packet 296 | 297 | Parameters: 298 | 299 | ``packet`` 300 | (string) mysql packet 301 | 302 | 303 | On success it returns the ``statement-id`` as ``int``. 304 | 305 | Otherwise it raises an error. 306 | 307 | from_stmt_close_packet 308 | ...................... 
309 | 310 | Decodes a COM_STMT_CLOSE-packet 311 | 312 | Parameters: 313 | 314 | ``packet`` 315 | (string) mysql packet 316 | 317 | 318 | On success it returns a table containing: 319 | 320 | ``stmt_id`` 321 | (int) 322 | statement-id that shall be closed 323 | 324 | Otherwise it raises an error. 325 | 326 | 327 | -------------------------------------------------------------------------------- /router/key.go: -------------------------------------------------------------------------------- 1 | // Copyright 2012, Google Inc. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | 5 | package router 6 | 7 | import ( 8 | "bytes" 9 | "encoding/binary" 10 | "encoding/hex" 11 | "fmt" 12 | "strings" 13 | ) 14 | 15 | // 16 | // KeyspaceId definitions 17 | // 18 | 19 | // MinKey is smaller than all KeyspaceId (the value really is). 20 | var MinKey = KeyspaceId("") 21 | 22 | // MaxKey is bigger than all KeyspaceId (by convention). 23 | var MaxKey = KeyspaceId("") 24 | 25 | // KeyspaceId is the type we base sharding on. 26 | type KeyspaceId string 27 | 28 | // Hex prints a KeyspaceId in capital hex. 29 | func (kid KeyspaceId) Hex() HexKeyspaceId { 30 | return HexKeyspaceId(strings.ToUpper(hex.EncodeToString([]byte(kid)))) 31 | } 32 | 33 | // 34 | // Uint64Key definitions 35 | // 36 | 37 | // Uint64Key is a uint64 that can be converted into a KeyspaceId. 38 | type Uint64Key uint64 39 | // String returns the 8-byte big-endian encoding of i as a (binary) string. 40 | func (i Uint64Key) String() string { 41 | buf := new(bytes.Buffer) 42 | binary.Write(buf, binary.BigEndian, uint64(i)) 43 | return buf.String() 44 | } 45 | 46 | // KeyspaceId returns the KeyspaceId associated with a Uint64Key. 47 | func (i Uint64Key) KeyspaceId() KeyspaceId { 48 | return KeyspaceId(i.String()) 49 | } 50 | 51 | // HexKeyspaceId is the hex representation of a KeyspaceId. 
52 | type HexKeyspaceId string 53 | 54 | // Unhex converts a HexKeyspaceId into a KeyspaceId (hex decoding). 
55 | func (hkid HexKeyspaceId) Unhex() (KeyspaceId, error) { 56 | b, err := hex.DecodeString(string(hkid)) 57 | if err != nil { 58 | return KeyspaceId(""), err 59 | } 60 | return KeyspaceId(string(b)), nil 61 | } 62 | 63 | // 64 | // KeyspaceIdType definitions 65 | // 66 | 67 | // KeyspaceIdType represents the type of the KeyspaceId. 68 | // Usually we don't care, but some parts of the code will need that info. 69 | type KeyspaceIdType string 70 | 71 | const ( 72 | // unset - no type for this KeyspaceId 73 | KIT_UNSET = KeyspaceIdType("") 74 | 75 | // uint64 - a uint64 value is used 76 | // this is represented as 'unsigned bigint' in mysql 77 | KIT_UINT64 = KeyspaceIdType("uint64") 78 | 79 | // bytes - a string of bytes is used 80 | // this is represented as 'varbinary' in mysql 81 | KIT_BYTES = KeyspaceIdType("bytes") 82 | ) 83 | // AllKeyspaceIdTypes lists every defined KeyspaceIdType, including KIT_UNSET. 84 | var AllKeyspaceIdTypes = []KeyspaceIdType{ 85 | KIT_UNSET, 86 | KIT_UINT64, 87 | KIT_BYTES, 88 | } 89 | 90 | // IsKeyspaceIdTypeInList returns true if the given type is in the list. 91 | // Use it with AllKeyspaceIdTypes for instance. 92 | func IsKeyspaceIdTypeInList(typ KeyspaceIdType, types []KeyspaceIdType) bool { 93 | for _, t := range types { 94 | if typ == t { 95 | return true 96 | } 97 | } 98 | return false 99 | } 100 | 101 | // 102 | // KeyRange definitions 103 | // 104 | 105 | // KeyRange is an interval of KeyspaceId values. It contains Start, 106 | // but excludes End. 
In other words, it is: [Start, End[ 107 | type KeyRange struct { 108 | Start KeyspaceId 109 | End KeyspaceId 110 | } 111 | // MapKey joins the raw Start and End keys with a "-" separator. 112 | func (kr KeyRange) MapKey() string { 113 | return string(kr.Start) + "-" + string(kr.End) 114 | } 115 | // Contains reports whether i lies in [Start, End); an End equal to MaxKey is treated as unbounded. 116 | func (kr KeyRange) Contains(i KeyspaceId) bool { 117 | return kr.Start <= i && (kr.End == MaxKey || i < kr.End) 118 | } 119 | // String formats the range with both bounds as upper-case hex. 120 | func (kr KeyRange) String() string { 121 | return fmt.Sprintf("{Start: %v, End: %v}", string(kr.Start.Hex()), string(kr.End.Hex())) 122 | } 123 | 124 | // ParseKeyRangeParts hex-decodes start and end and builds a KeyRange from them. 125 | func ParseKeyRangeParts(start, end string) (KeyRange, error) { 126 | s, err := HexKeyspaceId(start).Unhex() 127 | if err != nil { 128 | return KeyRange{}, err 129 | } 130 | e, err := HexKeyspaceId(end).Unhex() 131 | if err != nil { 132 | return KeyRange{}, err 133 | } 134 | return KeyRange{Start: s, End: e}, nil 135 | } 136 | 137 | // IsPartial returns true if the KeyRange does not cover the entire space. 138 | func (kr KeyRange) IsPartial() bool { 139 | return !(kr.Start == MinKey && kr.End == MaxKey) 140 | } 141 | 142 | // KeyRangesIntersect returns true if some Keyspace values exist in both ranges. 143 | // 144 | // See: http://stackoverflow.com/questions/4879315/what-is-a-tidy-algorithm-to-find-overlapping-intervals 145 | // two segments defined as (a,b) and (c,d) (with a<b and c<d): 146 | // intersects = (b > c) && (a < d) 147 | // overlap = min(b, d) - max(c, a) 148 | func KeyRangesIntersect(first, second KeyRange) bool { 149 | return (first.End == MaxKey || second.Start < first.End) && 150 | (second.End == MaxKey || first.Start < second.End) 151 | } 152 | 153 | // KeyRangesOverlap returns the overlap between two KeyRanges. 
154 | // They need to overlap, otherwise an error is returned. 
155 | func KeyRangesOverlap(first, second KeyRange) (KeyRange, error) { 156 | if !KeyRangesIntersect(first, second) { 157 | return KeyRange{}, fmt.Errorf("KeyRanges %v and %v don't overlap", first, second) 158 | } 159 | // compute max(c,a) and min(b,d) 160 | // start with (a,b) 161 | result := first 162 | // if c > a, then use c 163 | if second.Start > first.Start { 164 | result.Start = second.Start 165 | } 166 | // if b is maxed out, or 167 | // (d is not maxed out and d < b) 168 | // ^ valid test as neither b nor d are max 169 | // then use d 170 | if first.End == MaxKey || (second.End != MaxKey && second.End < first.End) { 171 | result.End = second.End 172 | } 173 | return result, nil 174 | } 175 | 176 | // ParseShardingSpec parses a string that describes a sharding 177 | // specification. a-b-c-d will be parsed as a-b, b-c, c-d. The empty 178 | // string may serve both as the start and end of the keyspace: -a-b- 179 | // will be parsed as start-a, a-b, b-end. 180 | func ParseShardingSpec(spec string) ([]KeyRange, error) { 181 | parts := strings.Split(spec, "-") 182 | if len(parts) == 1 { 183 | return nil, fmt.Errorf("malformed spec: doesn't define a range: %q", spec) 184 | } 185 | old := parts[0] 186 | ranges := make([]KeyRange, len(parts)-1) 187 | // NOTE(review): limits are ordered by comparing the hex strings directly (p <= old), which presumably assumes same-length, same-case hex — confirm 188 | for i, p := range parts[1:] { 189 | if p == "" && i != (len(parts)-2) { 190 | return nil, fmt.Errorf("malformed spec: MinKey/MaxKey cannot be in the middle of the spec: %q", spec) 191 | } 192 | if p != "" && p <= old { 193 | return nil, fmt.Errorf("malformed spec: shard limits should be in order: %q", spec) 194 | } 195 | s, err := HexKeyspaceId(old).Unhex() 196 | if err != nil { 197 | return nil, err 198 | } 199 | e, err := HexKeyspaceId(p).Unhex() 200 | if err != nil { 201 | return nil, err 202 | } 203 | ranges[i] = KeyRange{Start: s, End: e} 204 | old = p 205 | } 206 
| return ranges, nil 207 | } 208 | -------------------------------------------------------------------------------- /proxy/node.go: 
-------------------------------------------------------------------------------- 1 | package proxy 2 | 3 | import ( 4 | "fmt" 5 | "github.com/siddontang/go-log/log" 6 | "github.com/siddontang/mixer/client" 7 | "github.com/siddontang/mixer/config" 8 | "sync" 9 | "time" 10 | ) 11 | 12 | const ( 13 | Master = "master" 14 | Slave = "slave" 15 | ) 16 | 17 | type Node struct { 18 | sync.Mutex 19 | 20 | server *Server 21 | 22 | cfg config.NodeConfig 23 | 24 | //running master db 25 | db *client.DB 26 | 27 | master *client.DB 28 | slave *client.DB 29 | 30 | downAfterNoAlive time.Duration 31 | 32 | lastMasterPing int64 33 | lastSlavePing int64 34 | } 35 | 36 | func (n *Node) run() { 37 | //to do 38 | //1 check connection alive 39 | //2 check remove mysql server alive 40 | 41 | t := time.NewTicker(3000 * time.Second) 42 | defer t.Stop() 43 | 44 | n.lastMasterPing = time.Now().Unix() 45 | n.lastSlavePing = n.lastMasterPing 46 | for { 47 | select { 48 | case <-t.C: 49 | n.checkMaster() 50 | n.checkSlave() 51 | } 52 | } 53 | } 54 | 55 | func (n *Node) String() string { 56 | return n.cfg.Name 57 | } 58 | 59 | func (n *Node) getMasterConn() (*client.SqlConn, error) { 60 | n.Lock() 61 | db := n.db 62 | n.Unlock() 63 | 64 | if db == nil { 65 | return nil, fmt.Errorf("master is down") 66 | } 67 | 68 | return db.GetConn() 69 | } 70 | 71 | func (n *Node) getSelectConn() (*client.SqlConn, error) { 72 | var db *client.DB 73 | 74 | n.Lock() 75 | if n.cfg.RWSplit && n.slave != nil { 76 | db = n.slave 77 | } else { 78 | db = n.db 79 | } 80 | n.Unlock() 81 | 82 | if db == nil { 83 | return nil, fmt.Errorf("no alive mysql server") 84 | } 85 | 86 | return db.GetConn() 87 | } 88 | 89 | func (n *Node) checkMaster() { 90 | n.Lock() 91 | db := n.db 92 | n.Unlock() 93 | 94 | if db == nil { 95 | log.Info("no master avaliable") 96 | return 97 | } 98 | 99 | if err := db.Ping(); err != nil { 100 | log.Error("%s ping master %s error %s", n, db.Addr(), err.Error()) 101 | } else { 102 | n.lastMasterPing 
= time.Now().Unix() 103 | return 104 | } 105 | 106 | if int64(n.downAfterNoAlive) > 0 && time.Now().Unix()-n.lastMasterPing > int64(n.downAfterNoAlive) { 107 | log.Error("%s down master db %s", n, n.master.Addr()) 108 | 109 | n.downMaster() 110 | } 111 | } 112 | 113 | func (n *Node) checkSlave() { 114 | if n.slave == nil { 115 | return 116 | } 117 | 118 | db := n.slave 119 | if err := db.Ping(); err != nil { 120 | log.Error("%s ping slave %s error %s", n, db.Addr(), err.Error()) 121 | } else { 122 | n.lastSlavePing = time.Now().Unix() 123 | } 124 | 125 | if int64(n.downAfterNoAlive) > 0 && time.Now().Unix()-n.lastSlavePing > int64(n.downAfterNoAlive) { 126 | log.Error("%s slave db %s not alive over %ds, down it", 127 | n, db.Addr(), int64(n.downAfterNoAlive/time.Second)) 128 | 129 | n.downSlave() 130 | } 131 | } 132 | 133 | func (n *Node) openDB(addr string) (*client.DB, error) { 134 | db, err := client.Open(addr, n.cfg.User, n.cfg.Password, "") 135 | if err != nil { 136 | return nil, err 137 | } 138 | 139 | db.SetMaxIdleConnNum(n.cfg.IdleConns) 140 | return db, nil 141 | } 142 | 143 | func (n *Node) checkUpDB(addr string) (*client.DB, error) { 144 | db, err := n.openDB(addr) 145 | if err != nil { 146 | return nil, err 147 | } 148 | 149 | if err := db.Ping(); err != nil { 150 | db.Close() 151 | return nil, err 152 | } 153 | 154 | return db, nil 155 | } 156 | 157 | func (n *Node) upMaster(addr string) error { 158 | n.Lock() 159 | if n.master != nil { 160 | n.Unlock() 161 | return fmt.Errorf("%s master must be down first", n) 162 | } 163 | n.Unlock() 164 | 165 | db, err := n.checkUpDB(addr) 166 | if err != nil { 167 | return err 168 | } 169 | 170 | n.Lock() 171 | n.master = db 172 | n.db = db 173 | n.Unlock() 174 | 175 | return nil 176 | } 177 | 178 | func (n *Node) upSlave(addr string) error { 179 | n.Lock() 180 | if n.slave != nil { 181 | n.Unlock() 182 | return fmt.Errorf("%s, slave must be down first", n) 183 | } 184 | n.Unlock() 185 | 186 | db, err := 
n.checkUpDB(addr) 187 | if err != nil { 188 | return err 189 | } 190 | 191 | n.Lock() 192 | n.slave = db 193 | n.Unlock() 194 | 195 | return nil 196 | } 197 | 198 | func (n *Node) downMaster() error { // clears the master under the node lock 199 | n.Lock() 200 | n.master = nil 201 | n.Unlock() // must release: a held lock would deadlock every later n.Lock() (downSlave, upSlave, ...) 202 | 203 | return nil 204 | } 205 | 206 | func (n *Node) downSlave() error { 207 | n.Lock() 208 | db := n.slave 209 | n.slave = nil 210 | n.Unlock() 211 | 212 | if db != nil { 213 | db.Close() 214 | } 215 | 216 | return nil 217 | } 218 | 219 | func (s *Server) UpMaster(node string, addr string) error { 220 | n := s.getNode(node) 221 | if n == nil { 222 | return fmt.Errorf("invalid node %s", node) 223 | } 224 | 225 | return n.upMaster(addr) 226 | } 227 | 228 | func (s *Server) UpSlave(node string, addr string) error { 229 | n := s.getNode(node) 230 | if n == nil { 231 | return fmt.Errorf("invalid node %s", node) 232 | } 233 | 234 | return n.upSlave(addr) 235 | } 236 | func (s *Server) DownMaster(node string) error { 237 | n := s.getNode(node) 238 | if n == nil { 239 | return fmt.Errorf("invalid node %s", node) 240 | } 241 | n.db = nil // NOTE(review): unlike the master/slave updates, this write is not under n.Lock(); confirm how n.db is read 242 | return n.downMaster() 243 | } 244 | 245 | func (s *Server) DownSlave(node string) error { 246 | n := s.getNode(node) 247 | if n == nil { 248 | return fmt.Errorf("invalid node [%s].", node) 249 | } 250 | return n.downSlave() 251 | } 252 | 253 | func (s *Server) getNode(name string) *Node { 254 | return s.nodes[name] 255 | } 256 | 257 | func (s *Server) parseNodes() error { 258 | cfg := s.cfg 259 | s.nodes = make(map[string]*Node, len(cfg.Nodes)) 260 | 261 | for _, v := range cfg.Nodes { 262 | if _, ok := s.nodes[v.Name]; ok { 263 | return fmt.Errorf("duplicate node [%s].", v.Name) 264 | } 265 | 266 | n, err := s.parseNode(v) 267 | if err != nil { 268 | return err 269 | } 270 | 271 | s.nodes[v.Name] = n 272 | } 273 | 274 | return nil 275 | } 276 | 277 | func (s *Server) parseNode(cfg config.NodeConfig) (*Node, error) { 278 | n := new(Node) 279 | n.server = s 280 | n.cfg
= cfg 281 | 282 | n.downAfterNoAlive = time.Duration(cfg.DownAfterNoAlive) * time.Second 283 | 284 | if len(cfg.Master) == 0 { 285 | return nil, fmt.Errorf("must setting master MySQL node.") 286 | } 287 | 288 | var err error 289 | if n.master, err = n.openDB(cfg.Master); err != nil { 290 | return nil, err 291 | } 292 | 293 | n.db = n.master 294 | 295 | if len(cfg.Slave) > 0 { 296 | if n.slave, err = n.openDB(cfg.Slave); err != nil { 297 | log.Error(err.Error()) 298 | n.slave = nil 299 | } 300 | } 301 | 302 | go n.run() 303 | 304 | return n, nil 305 | } 306 | -------------------------------------------------------------------------------- /client/stmt_test.go: -------------------------------------------------------------------------------- 1 | package client 2 | 3 | import ( 4 | "testing" 5 | ) 6 | 7 | func TestStmt_DropTable(t *testing.T) { 8 | str := `drop table if exists mixer_test_stmt` 9 | 10 | c := newTestConn() 11 | defer c.Close() 12 | s, err := c.Prepare(str) 13 | if err != nil { 14 | t.Fatal(err) 15 | } 16 | 17 | if _, err := s.Execute(); err != nil { 18 | t.Fatal(err) 19 | } 20 | 21 | s.Close() 22 | } 23 | 24 | func TestStmt_CreateTable(t *testing.T) { 25 | str := `CREATE TABLE IF NOT EXISTS mixer_test_stmt ( 26 | id BIGINT(64) UNSIGNED NOT NULL, 27 | str VARCHAR(256), 28 | f DOUBLE, 29 | e enum("test1", "test2"), 30 | u tinyint unsigned, 31 | i tinyint, 32 | PRIMARY KEY (id) 33 | ) ENGINE=InnoDB DEFAULT CHARSET=utf8` 34 | 35 | c := newTestConn() 36 | defer c.Close() 37 | 38 | s, err := c.Prepare(str) 39 | 40 | if err != nil { 41 | t.Fatal(err) 42 | } 43 | 44 | if _, err = s.Execute(); err != nil { 45 | t.Fatal(err) 46 | } 47 | 48 | s.Close() 49 | } 50 | 51 | func TestStmt_Delete(t *testing.T) { 52 | str := `delete from mixer_test_stmt` 53 | 54 | c := newTestConn() 55 | defer c.Close() 56 | 57 | s, err := c.Prepare(str) 58 | 59 | if err != nil { 60 | t.Fatal(err) 61 | } 62 | 63 | if _, err := s.Execute(); err != nil { 64 | t.Fatal(err) 65 | } 66 | 67 | s.Close() 68 | }
69 | 70 | func TestStmt_Insert(t *testing.T) { 71 | str := `insert into mixer_test_stmt (id, str, f, e, u, i) values (?, ?, ?, ?, ?, ?)` 72 | 73 | c := newTestConn() 74 | defer c.Close() 75 | 76 | s, err := c.Prepare(str) 77 | 78 | if err != nil { 79 | t.Fatal(err) 80 | } 81 | 82 | if pkg, err := s.Execute(1, "a", 3.14, "test1", 255, -127); err != nil { 83 | t.Fatal(err) 84 | } else { 85 | if pkg.AffectedRows != 1 { 86 | t.Fatal(pkg.AffectedRows) 87 | } 88 | } 89 | 90 | s.Close() 91 | } 92 | 93 | func TestStmt_Select(t *testing.T) { 94 | str := `select str, f, e from mixer_test_stmt where id = ?` 95 | 96 | c := newTestConn() 97 | defer c.Close() 98 | 99 | s, err := c.Prepare(str) 100 | if err != nil { 101 | t.Fatal(err) 102 | } 103 | 104 | if result, err := s.Execute(1); err != nil { 105 | t.Fatal(err) 106 | } else { 107 | if len(result.Values) != 1 { 108 | t.Fatal(len(result.Values)) 109 | } 110 | 111 | if len(result.Fields) != 3 { 112 | t.Fatal(len(result.Fields)) 113 | } 114 | 115 | if str, _ := result.GetString(0, 0); str != "a" { 116 | t.Fatal("invalid str", str) 117 | } 118 | 119 | if f, _ := result.GetFloat(0, 1); f != float64(3.14) { 120 | t.Fatal("invalid f", f) 121 | } 122 | 123 | if e, _ := result.GetString(0, 2); e != "test1" { 124 | t.Fatal("invalid e", e) 125 | } 126 | 127 | if str, _ := result.GetStringByName(0, "str"); str != "a" { 128 | t.Fatal("invalid str", str) 129 | } 130 | 131 | if f, _ := result.GetFloatByName(0, "f"); f != float64(3.14) { 132 | t.Fatal("invalid f", f) 133 | } 134 | 135 | if e, _ := result.GetStringByName(0, "e"); e != "test1" { 136 | t.Fatal("invalid e", e) 137 | } 138 | 139 | } 140 | 141 | s.Close() 142 | } 143 | 144 | func TestStmt_NULL(t *testing.T) { 145 | str := `insert into mixer_test_stmt (id, str, f, e) values (?, ?, ?, ?)` 146 | 147 | c := newTestConn() 148 | defer c.Close() 149 | 150 | s, err := c.Prepare(str) 151 | 152 | if err != nil { 153 | t.Fatal(err) 154 | } 155 | 156 | if pkg, err := s.Execute(2, nil, 3.14,
nil); err != nil { 157 | t.Fatal(err) 158 | } else { 159 | if pkg.AffectedRows != 1 { 160 | t.Fatal(pkg.AffectedRows) 161 | } 162 | } 163 | 164 | s.Close() 165 | 166 | str = `select * from mixer_test_stmt where id = ?` 167 | s, err = c.Prepare(str) 168 | 169 | if err != nil { 170 | t.Fatal(err) 171 | } 172 | 173 | if r, err := s.Execute(2); err != nil { 174 | t.Fatal(err) 175 | } else { 176 | if b, err := r.IsNullByName(0, "id"); err != nil { 177 | t.Fatal(err) 178 | } else if b == true { 179 | t.Fatal(b) 180 | } 181 | 182 | if b, err := r.IsNullByName(0, "str"); err != nil { 183 | t.Fatal(err) 184 | } else if b == false { 185 | t.Fatal(b) 186 | } 187 | 188 | if b, err := r.IsNullByName(0, "f"); err != nil { 189 | t.Fatal(err) 190 | } else if b == true { 191 | t.Fatal(b) 192 | } 193 | 194 | if b, err := r.IsNullByName(0, "e"); err != nil { 195 | t.Fatal(err) 196 | } else if b == false { 197 | t.Fatal(b) 198 | } 199 | } 200 | 201 | s.Close() 202 | } 203 | 204 | func TestStmt_Unsigned(t *testing.T) { 205 | str := `insert into mixer_test_stmt (id, u) values (?, ?)` 206 | 207 | c := newTestConn() 208 | defer c.Close() 209 | 210 | s, err := c.Prepare(str) 211 | 212 | if err != nil { 213 | t.Fatal(err) 214 | } 215 | 216 | if pkg, err := s.Execute(3, uint8(255)); err != nil { 217 | t.Fatal(err) 218 | } else { 219 | if pkg.AffectedRows != 1 { 220 | t.Fatal(pkg.AffectedRows) 221 | } 222 | } 223 | 224 | s.Close() 225 | 226 | str = `select u from mixer_test_stmt where id = ?` 227 | 228 | s, err = c.Prepare(str) 229 | if err != nil { 230 | t.Fatal(err) 231 | } 232 | 233 | if r, err := s.Execute(3); err != nil { 234 | t.Fatal(err) 235 | } else { 236 | if u, err := r.GetUint(0, 0); err != nil { 237 | t.Fatal(err) 238 | } else if u != uint64(255) { 239 | t.Fatal(u) 240 | } 241 | } 242 | 243 | s.Close() 244 | } 245 | 246 | func TestStmt_Signed(t *testing.T) { 247 | str := `insert into mixer_test_stmt (id, i) values (?, ?)` 248 | 249 | c := newTestConn() 250 | defer c.Close() 251 |
252 | s, err := c.Prepare(str) 253 | 254 | if err != nil { 255 | t.Fatal(err) 256 | } 257 | 258 | if _, err := s.Execute(4, 127); err != nil { 259 | t.Fatal(err) 260 | } 261 | 262 | if _, err := s.Execute(uint64(18446744073709551516), int8(-128)); err != nil { 263 | t.Fatal(err) 264 | } 265 | 266 | s.Close() 267 | 268 | } 269 | 270 | func TestStmt_Trans(t *testing.T) { 271 | c := newTestConn() 272 | defer c.Close() 273 | 274 | if _, err := c.Execute(`insert into mixer_test_stmt (id, str) values (1002, "abc")`); err != nil { 275 | t.Fatal(err) 276 | } 277 | 278 | if err := c.Begin(); err != nil { 279 | t.Fatal(err) 280 | } 281 | 282 | str := `select str from mixer_test_stmt where id = ?` 283 | 284 | s, err := c.Prepare(str) 285 | if err != nil { 286 | t.Fatal(err) 287 | } 288 | 289 | if _, err := s.Execute(1002); err != nil { 290 | t.Fatal(err) 291 | } 292 | 293 | if err := c.Commit(); err != nil { 294 | t.Fatal(err) 295 | } 296 | 297 | if r, err := s.Execute(1002); err != nil { 298 | t.Fatal(err) 299 | } else { 300 | if str, _ := r.GetString(0, 0); str != `abc` { 301 | t.Fatal(str) 302 | } 303 | } 304 | 305 | if err := s.Close(); err != nil { 306 | t.Fatal(err) 307 | } 308 | } 309 | -------------------------------------------------------------------------------- /mysql/util.go: -------------------------------------------------------------------------------- 1 | package mysql 2 | 3 | import ( 4 | "crypto/sha1" 5 | "encoding/binary" 6 | "fmt" 7 | "io" 8 | "math/rand" 9 | "runtime" 10 | "time" 11 | "unicode/utf8" 12 | ) 13 | 14 | func Pstack() string { 15 | buf := make([]byte, 1024) 16 | n := runtime.Stack(buf, false) 17 | return string(buf[0:n]) 18 | } 19 | 20 | func CalcPassword(scramble, password []byte) []byte { 21 | if len(password) == 0 { 22 | return nil 23 | } 24 | 25 | // stage1Hash = SHA1(password) 26 | crypt := sha1.New() 27 | crypt.Write(password) 28 | stage1 := crypt.Sum(nil) 29 | 30 | // scrambleHash = SHA1(scramble + SHA1(stage1Hash)) 31 | // inner Hash
32 | crypt.Reset() 33 | crypt.Write(stage1) 34 | hash := crypt.Sum(nil) 35 | 36 | // outer Hash 37 | crypt.Reset() 38 | crypt.Write(scramble) 39 | crypt.Write(hash) 40 | scramble = crypt.Sum(nil) 41 | 42 | // token = scrambleHash XOR stage1Hash 43 | for i := range scramble { 44 | scramble[i] ^= stage1[i] 45 | } 46 | return scramble 47 | } 48 | 49 | func RandomBuf(size int) []byte { 50 | buf := make([]byte, size) 51 | rand.Seed(time.Now().UTC().UnixNano()) 52 | for i := 0; i < size; i++ { 53 | buf[i] = byte(rand.Intn(127)) 54 | if buf[i] == 0 || buf[i] == byte('$') { 55 | buf[i]++ 56 | } 57 | } 58 | return buf 59 | } 60 | 61 | func LengthEncodedInt(b []byte) (num uint64, isNull bool, n int) { 62 | switch b[0] { 63 | 64 | // 251: NULL 65 | case 0xfb: 66 | n = 1 67 | isNull = true 68 | return 69 | 70 | // 252: value of following 2 71 | case 0xfc: 72 | num = uint64(b[1]) | uint64(b[2])<<8 73 | n = 3 74 | return 75 | 76 | // 253: value of following 3 77 | case 0xfd: 78 | num = uint64(b[1]) | uint64(b[2])<<8 | uint64(b[3])<<16 79 | n = 4 80 | return 81 | 82 | // 254: value of following 8 83 | case 0xfe: 84 | num = uint64(b[1]) | uint64(b[2])<<8 | uint64(b[3])<<16 | 85 | uint64(b[4])<<24 | uint64(b[5])<<32 | uint64(b[6])<<40 | 86 | uint64(b[7])<<48 | uint64(b[8])<<56 87 | n = 9 88 | return 89 | } 90 | 91 | // 0-250: value of first byte 92 | num = uint64(b[0]) 93 | n = 1 94 | return 95 | } 96 | 97 | func PutLengthEncodedInt(n uint64) []byte { 98 | switch { 99 | case n <= 250: 100 | return []byte{byte(n)} 101 | 102 | case n <= 0xffff: 103 | return []byte{0xfc, byte(n), byte(n >> 8)} 104 | 105 | case n <= 0xffffff: 106 | return []byte{0xfd, byte(n), byte(n >> 8), byte(n >> 16)} 107 | 108 | case n <= 0xffffffffffffffff: 109 | return []byte{0xfe, byte(n), byte(n >> 8), byte(n >> 16), byte(n >> 24), 110 | byte(n >> 32), byte(n >> 40), byte(n >> 48), byte(n >> 56)} 111 | } 112 | return nil 113 | } 114 | 115 | func LengthEnodedString(b []byte) ([]byte, bool, int, error) {
116 | // Get length 117 | num, isNull, n := LengthEncodedInt(b) 118 | if num < 1 { 119 | return nil, isNull, n, nil 120 | } 121 | 122 | n += int(num) 123 | 124 | // Check data length 125 | if len(b) >= n { 126 | return b[n-int(num) : n], false, n, nil 127 | } 128 | return nil, false, n, io.EOF 129 | } 130 | 131 | func SkipLengthEnodedString(b []byte) (int, error) { 132 | // Get length 133 | num, _, n := LengthEncodedInt(b) 134 | if num < 1 { 135 | return n, nil 136 | } 137 | 138 | n += int(num) 139 | 140 | // Check data length 141 | if len(b) >= n { 142 | return n, nil 143 | } 144 | return n, io.EOF 145 | } 146 | 147 | func PutLengthEncodedString(b []byte) []byte { 148 | data := make([]byte, 0, len(b)+9) 149 | data = append(data, PutLengthEncodedInt(uint64(len(b)))...) 150 | data = append(data, b...) 151 | return data 152 | } 153 | 154 | func Uint16ToBytes(n uint16) []byte { 155 | return []byte{ 156 | byte(n), 157 | byte(n >> 8), 158 | } 159 | } 160 | 161 | func Uint32ToBytes(n uint32) []byte { 162 | return []byte{ 163 | byte(n), 164 | byte(n >> 8), 165 | byte(n >> 16), 166 | byte(n >> 24), 167 | } 168 | } 169 | 170 | func Uint64ToBytes(n uint64) []byte { 171 | return []byte{ 172 | byte(n), 173 | byte(n >> 8), 174 | byte(n >> 16), 175 | byte(n >> 24), 176 | byte(n >> 32), 177 | byte(n >> 40), 178 | byte(n >> 48), 179 | byte(n >> 56), 180 | } 181 | } 182 | 183 | func FormatBinaryDate(n int, data []byte) ([]byte, error) { 184 | switch n { 185 | case 0: 186 | return []byte("0000-00-00"), nil 187 | case 4: 188 | return []byte(fmt.Sprintf("%04d-%02d-%02d", 189 | binary.LittleEndian.Uint16(data[:2]), 190 | data[2], 191 | data[3])), nil 192 | default: 193 | return nil, fmt.Errorf("invalid date packet length %d", n) 194 | } 195 | } 196 | 197 | func FormatBinaryDateTime(n int, data []byte) ([]byte, error) { 198 | switch n { 199 | case 0: 200 | return []byte("0000-00-00 00:00:00"), nil 201 | case 4: 202 | return []byte(fmt.Sprintf("%04d-%02d-%02d 00:00:00",
203 | binary.LittleEndian.Uint16(data[:2]), 204 | data[2], 205 | data[3])), nil 206 | case 7: 207 | return []byte(fmt.Sprintf( 208 | "%04d-%02d-%02d %02d:%02d:%02d", 209 | binary.LittleEndian.Uint16(data[:2]), 210 | data[2], 211 | data[3], 212 | data[4], 213 | data[5], 214 | data[6])), nil 215 | case 11: 216 | return []byte(fmt.Sprintf( 217 | "%04d-%02d-%02d %02d:%02d:%02d.%06d", 218 | binary.LittleEndian.Uint16(data[:2]), 219 | data[2], 220 | data[3], 221 | data[4], 222 | data[5], 223 | data[6], 224 | binary.LittleEndian.Uint32(data[7:11]))), nil 225 | default: 226 | return nil, fmt.Errorf("invalid datetime packet length %d", n) 227 | } 228 | } 229 | 230 | func FormatBinaryTime(n int, data []byte) ([]byte, error) { 231 | if n == 0 { 232 | return []byte("00:00:00"), nil // a zero-length TIME value means 00:00:00, not a date 233 | } 234 | 235 | var sign string // "" unless negative; a zero byte printed with %c would emit a NUL character 236 | if data[0] == 1 { 237 | sign = "-" 238 | } 239 | 240 | switch n { 241 | case 8: 242 | return []byte(fmt.Sprintf( 243 | "%s%02d:%02d:%02d", 244 | sign, 245 | binary.LittleEndian.Uint32(data[1:5])*24+uint32(data[5]), // days (4-byte little-endian) * 24 + hours 246 | data[6], 247 | data[7], 248 | )), nil 249 | case 12: 250 | return []byte(fmt.Sprintf( 251 | "%s%02d:%02d:%02d.%06d", 252 | sign, 253 | binary.LittleEndian.Uint32(data[1:5])*24+uint32(data[5]), 254 | data[6], 255 | data[7], 256 | binary.LittleEndian.Uint32(data[8:12]), 257 | )), nil 258 | default: 259 | return nil, fmt.Errorf("invalid time packet length %d", n) 260 | } 261 | } 262 | 263 | var ( 264 | DONTESCAPE = byte(255) 265 | 266 | EncodeMap [256]byte 267 | ) 268 | 269 | func Escape(sql string) string { 270 | dest := make([]byte, 0, 2*len(sql)) 271 | 272 | for i, w := 0, 0; i < len(sql); i += w { 273 | runeValue, width := utf8.DecodeRuneInString(sql[i:]) 274 | if c := EncodeMap[byte(runeValue)]; runeValue >= utf8.RuneSelf || c == DONTESCAPE { // only ASCII can need escaping; byte(runeValue) of a multi-byte rune may alias an escape char 275 | dest = append(dest, sql[i:i+width]...)
276 | } else { 277 | dest = append(dest, '\\', c) 278 | } 279 | w = width 280 | } 281 | 282 | return string(dest) 283 | } 284 | 285 | var encodeRef = map[byte]byte{ 286 | '\x00': '0', 287 | '\'': '\'', 288 | '"': '"', 289 | '\b': 'b', 290 | '\n': 'n', 291 | '\r': 'r', 292 | '\t': 't', 293 | 26: 'Z', // ctl-Z 294 | '\\': '\\', 295 | } 296 | 297 | func init() { 298 | for i := range EncodeMap { 299 | EncodeMap[i] = DONTESCAPE 300 | } 301 | for i := range EncodeMap { 302 | if to, ok := encodeRef[byte(i)]; ok { 303 | EncodeMap[byte(i)] = to 304 | } 305 | } 306 | } 307 | -------------------------------------------------------------------------------- /proxy/conn.go: -------------------------------------------------------------------------------- 1 | package proxy 2 | 3 | import ( 4 | "bytes" 5 | "encoding/binary" 6 | "fmt" 7 | "github.com/siddontang/go-log/log" 8 | "github.com/siddontang/mixer/client" 9 | "github.com/siddontang/mixer/hack" 10 | . "github.com/siddontang/mixer/mysql" 11 | "net" 12 | "runtime" 13 | "sync" 14 | "sync/atomic" 15 | ) 16 | 17 | var DEFAULT_CAPABILITY uint32 = CLIENT_LONG_PASSWORD | CLIENT_LONG_FLAG | 18 | CLIENT_CONNECT_WITH_DB | CLIENT_PROTOCOL_41 | 19 | CLIENT_TRANSACTIONS | CLIENT_SECURE_CONNECTION 20 | 21 | //client <-> proxy 22 | type Conn struct { 23 | sync.Mutex 24 | 25 | pkg *PacketIO 26 | 27 | c net.Conn 28 | 29 | server *Server 30 | 31 | capability uint32 32 | 33 | connectionId uint32 34 | 35 | status uint16 36 | collation CollationId 37 | charset string 38 | 39 | user string 40 | db string 41 | 42 | salt []byte 43 | 44 | schema *Schema 45 | 46 | txConns map[*Node]*client.SqlConn 47 | 48 | closed bool 49 | 50 | lastInsertId int64 51 | affectedRows int64 52 | 53 | stmtId uint32 54 | 55 | stmts map[uint32]*Stmt 56 | } 57 | 58 | var baseConnId uint32 = 10000 59 | 60 | func (s *Server) newConn(co net.Conn) *Conn { 61 | c := new(Conn) 62 | 63 | c.c = co 64 | 65 | c.pkg = NewPacketIO(co) 66 | 67 | c.server = s 68 | 69 | c.c = co
70 | c.pkg.Sequence = 0 71 | 72 | c.connectionId = atomic.AddUint32(&baseConnId, 1) 73 | 74 | c.status = SERVER_STATUS_AUTOCOMMIT 75 | 76 | c.salt = RandomBuf(20) 77 | 78 | c.txConns = make(map[*Node]*client.SqlConn) 79 | 80 | c.closed = false 81 | 82 | c.collation = DEFAULT_COLLATION_ID 83 | c.charset = DEFAULT_CHARSET 84 | 85 | c.stmtId = 0 86 | c.stmts = make(map[uint32]*Stmt) 87 | 88 | return c 89 | } 90 | 91 | func (c *Conn) Handshake() error { 92 | if err := c.writeInitialHandshake(); err != nil { 93 | log.Error("send initial handshake error %s", err.Error()) 94 | return err 95 | } 96 | 97 | if err := c.readHandshakeResponse(); err != nil { 98 | log.Error("recv handshake response error %s", err.Error()) 99 | 100 | c.writeError(err) 101 | 102 | return err 103 | } 104 | 105 | if err := c.writeOK(nil); err != nil { 106 | log.Error("write ok fail %s", err.Error()) 107 | return err 108 | } 109 | 110 | c.pkg.Sequence = 0 111 | 112 | return nil 113 | } 114 | 115 | func (c *Conn) Close() error { 116 | if c.closed { 117 | return nil 118 | } 119 | 120 | c.c.Close() 121 | 122 | c.rollback() 123 | 124 | c.closed = true 125 | 126 | return nil 127 | } 128 | 129 | func (c *Conn) writeInitialHandshake() error { 130 | data := make([]byte, 4, 128) 131 | 132 | //min version 10 133 | data = append(data, 10) 134 | 135 | //server version[00] 136 | data = append(data, ServerVersion...) 137 | data = append(data, 0) 138 | 139 | //connection id 140 | data = append(data, byte(c.connectionId), byte(c.connectionId>>8), byte(c.connectionId>>16), byte(c.connectionId>>24)) 141 | 142 | //auth-plugin-data-part-1 143 | data = append(data, c.salt[0:8]...)
144 | 145 | //filter [00] 146 | data = append(data, 0) 147 | 148 | //capability flag lower 2 bytes, using default capability here 149 | data = append(data, byte(DEFAULT_CAPABILITY), byte(DEFAULT_CAPABILITY>>8)) 150 | 151 | //charset, utf-8 default 152 | data = append(data, uint8(DEFAULT_COLLATION_ID)) 153 | 154 | //status 155 | data = append(data, byte(c.status), byte(c.status>>8)) 156 | 157 | //below 13 byte may not be used 158 | //capability flag upper 2 bytes, using default capability here 159 | data = append(data, byte(DEFAULT_CAPABILITY>>16), byte(DEFAULT_CAPABILITY>>24)) 160 | 161 | //filter [0x15], for wireshark dump, value is 0x15 162 | data = append(data, 0x15) 163 | 164 | //reserved 10 [00] 165 | data = append(data, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0) 166 | 167 | //auth-plugin-data-part-2 168 | data = append(data, c.salt[8:]...) 169 | 170 | //filter [00] 171 | data = append(data, 0) 172 | 173 | return c.writePacket(data) 174 | } 175 | 176 | func (c *Conn) readPacket() ([]byte, error) { 177 | return c.pkg.ReadPacket() 178 | } 179 | 180 | func (c *Conn) writePacket(data []byte) error { 181 | return c.pkg.WritePacket(data) 182 | } 183 | 184 | func (c *Conn) readHandshakeResponse() error { 185 | data, err := c.readPacket() 186 | 187 | if err != nil { 188 | return err 189 | } 190 | 191 | pos := 0 192 | 193 | //capability. NOTE(review): the packet length is never checked, so a short or malformed client packet panics in the slicing below 194 | c.capability = binary.LittleEndian.Uint32(data[:4]) 195 | pos += 4 196 | 197 | //skip max packet size 198 | pos += 4 199 | 200 | //charset, skip, if you want to use another charset, use set names 201 | //c.collation = CollationId(data[pos]) 202 | pos++ 203 | 204 | //skip reserved 23[00] 205 | pos += 23 206 | 207 | //user name 208 | c.user = string(data[pos : pos+bytes.IndexByte(data[pos:], 0)]) 209 | pos += len(c.user) + 1 210 | 211 | //auth length and auth 212 | authLen := int(data[pos]) 213 | pos++ 214 | auth := data[pos : pos+authLen] 215 | 216 | checkAuth := CalcPassword(c.salt, []byte(c.server.cfg.Password)) 217 | 218 | if
!bytes.Equal(auth, checkAuth) { 219 | return NewDefaultError(ER_ACCESS_DENIED_ERROR, c.c.RemoteAddr().String(), c.user, "Yes") 220 | } 221 | 222 | pos += authLen 223 | 224 | if c.capability&CLIENT_CONNECT_WITH_DB > 0 { 225 | if len(data[pos:]) == 0 { 226 | return nil 227 | } 228 | 229 | db := string(data[pos : pos+bytes.IndexByte(data[pos:], 0)]) 230 | pos += len(db) + 1 231 | 232 | if err := c.useDB(db); err != nil { 233 | return err 234 | } 235 | } 236 | 237 | return nil 238 | } 239 | 240 | func (c *Conn) Run() { 241 | defer func() { 242 | r := recover() 243 | if r != nil { // log every panic value, not only ones that implement error 244 | const size = 4096 245 | buf := make([]byte, size) 246 | buf = buf[:runtime.Stack(buf, false)] 247 | 248 | log.Error("%v, %s", r, buf) 249 | } 250 | 251 | c.Close() 252 | }() 253 | 254 | for { 255 | data, err := c.readPacket() 256 | 257 | if err != nil { 258 | return 259 | } 260 | 261 | if err := c.dispatch(data); err != nil { 262 | log.Error("dispatch error %s", err.Error()) 263 | if err != ErrBadConn { 264 | c.writeError(err) 265 | } 266 | } 267 | 268 | if c.closed { 269 | return 270 | } 271 | 272 | c.pkg.Sequence = 0 273 | } 274 | } 275 | 276 | func (c *Conn) dispatch(data []byte) error { 277 | cmd := data[0] 278 | data = data[1:] 279 | 280 | switch cmd { 281 | case COM_QUIT: 282 | c.Close() 283 | return nil 284 | case COM_QUERY: 285 | return c.handleQuery(hack.String(data)) 286 | case COM_PING: 287 | return c.writeOK(nil) 288 | case COM_INIT_DB: 289 | if err := c.useDB(hack.String(data)); err != nil { 290 | return err 291 | } else { 292 | return c.writeOK(nil) 293 | } 294 | case COM_FIELD_LIST: 295 | return c.handleFieldList(data) 296 | case COM_STMT_PREPARE: 297 | return c.handleStmtPrepare(hack.String(data)) 298 | case COM_STMT_EXECUTE: 299 | return c.handleStmtExecute(data) 300 | case COM_STMT_CLOSE: 301 | return c.handleStmtClose(data) 302 | case COM_STMT_SEND_LONG_DATA: 303 | return c.handleStmtSendLongData(data) 304 | case COM_STMT_RESET: 305 | return
c.handleStmtReset(data) 306 | default: 307 | msg := fmt.Sprintf("command %d not supported now", cmd) 308 | return NewError(ER_UNKNOWN_ERROR, msg) 309 | } 310 | 311 | return nil 312 | } 313 | 314 | func (c *Conn) useDB(db string) error { 315 | if s := c.server.getSchema(db); s == nil { 316 | return NewDefaultError(ER_BAD_DB_ERROR, db) 317 | } else { 318 | c.schema = s 319 | c.db = db 320 | } 321 | return nil 322 | } 323 | 324 | func (c *Conn) writeOK(r *Result) error { 325 | if r == nil { 326 | r = &Result{Status: c.status} 327 | } 328 | data := make([]byte, 4, 32) 329 | 330 | data = append(data, OK_HEADER) 331 | 332 | data = append(data, PutLengthEncodedInt(r.AffectedRows)...) 333 | data = append(data, PutLengthEncodedInt(r.InsertId)...) 334 | 335 | if c.capability&CLIENT_PROTOCOL_41 > 0 { 336 | data = append(data, byte(r.Status), byte(r.Status>>8)) 337 | data = append(data, 0, 0) 338 | } 339 | 340 | return c.writePacket(data) 341 | } 342 | 343 | func (c *Conn) writeError(e error) error { 344 | var m *SqlError 345 | var ok bool 346 | if m, ok = e.(*SqlError); !ok { 347 | m = NewError(ER_UNKNOWN_ERROR, e.Error()) 348 | } 349 | 350 | data := make([]byte, 4, 16+len(m.Message)) 351 | 352 | data = append(data, ERR_HEADER) 353 | data = append(data, byte(m.Code), byte(m.Code>>8)) 354 | 355 | if c.capability&CLIENT_PROTOCOL_41 > 0 { 356 | data = append(data, '#') 357 | data = append(data, m.State...) 358 | } 359 | 360 | data = append(data, m.Message...)
361 | 362 | return c.writePacket(data) 363 | } 364 | 365 | func (c *Conn) writeEOF(status uint16) error { 366 | data := make([]byte, 4, 9) 367 | 368 | data = append(data, EOF_HEADER) 369 | if c.capability&CLIENT_PROTOCOL_41 > 0 { 370 | data = append(data, 0, 0) 371 | data = append(data, byte(status), byte(status>>8)) 372 | } 373 | 374 | return c.writePacket(data) 375 | } 376 | -------------------------------------------------------------------------------- /proxy/conn_test.go: -------------------------------------------------------------------------------- 1 | package proxy 2 | 3 | import ( 4 | . "github.com/siddontang/mixer/mysql" 5 | "testing" 6 | ) 7 | 8 | func TestConn_Handshake(t *testing.T) { 9 | c := newTestDBConn(t) 10 | 11 | if err := c.Ping(); err != nil { 12 | t.Fatal(err) 13 | } 14 | 15 | c.Close() 16 | } 17 | 18 | func TestConn_DeleteTable(t *testing.T) { 19 | server := newTestServer(t) 20 | n := server.nodes["node1"] 21 | c, err := n.getMasterConn() 22 | if err != nil { 23 | t.Fatal(err) 24 | } 25 | c.UseDB("mixer") 26 | if _, err := c.Execute(`drop table if exists mixer_test_proxy_conn`); err != nil { 27 | t.Fatal(err) 28 | } 29 | c.Close() 30 | } 31 | 32 | func TestConn_CreateTable(t *testing.T) { 33 | s := `CREATE TABLE IF NOT EXISTS mixer_test_proxy_conn ( 34 | id BIGINT(64) UNSIGNED NOT NULL, 35 | str VARCHAR(256), 36 | f DOUBLE, 37 | e enum("test1", "test2"), 38 | u tinyint unsigned, 39 | i tinyint, 40 | ni tinyint, 41 | PRIMARY KEY (id) 42 | ) ENGINE=InnoDB DEFAULT CHARSET=utf8` 43 | 44 | server := newTestServer(t) 45 | n := server.nodes["node1"] 46 | c, err := n.getMasterConn() 47 | if err != nil { 48 | t.Fatal(err) 49 | } 50 | 51 | c.UseDB("mixer") 52 | defer c.Close() 53 | if _, err := c.Execute(s); err != nil { 54 | t.Fatal(err) 55 | } 56 | } 57 | 58 | func TestConn_Insert(t *testing.T) { 59 | s := `insert into mixer_test_proxy_conn (id, str, f, e, u, i) values(1, "abc", 3.14, "test1", 255, -127)` 60 | 61 | c := newTestDBConn(t) 62 | defer
c.Close() 63 | 64 | if r, err := c.Execute(s); err != nil { 65 | t.Fatal(err) 66 | } else { 67 | if r.AffectedRows != 1 { 68 | t.Fatal(r.AffectedRows) 69 | } 70 | } 71 | } 72 | 73 | func TestConn_Select(t *testing.T) { 74 | s := `select str, f, e, u, i, ni from mixer_test_proxy_conn where id = 1` 75 | 76 | c := newTestDBConn(t) 77 | defer c.Close() 78 | 79 | if r, err := c.Execute(s); err != nil { 80 | t.Fatal(err) 81 | } else { 82 | if r.RowNumber() != 1 { 83 | t.Fatal(r.RowNumber()) 84 | } 85 | 86 | if r.ColumnNumber() != 6 { 87 | t.Fatal(r.ColumnNumber()) 88 | } 89 | 90 | if v, _ := r.GetString(0, 0); v != `abc` { 91 | t.Fatal(v) 92 | } 93 | 94 | if v, _ := r.GetFloat(0, 1); v != 3.14 { 95 | t.Fatal(v) 96 | } 97 | 98 | if v, _ := r.GetString(0, 2); v != `test1` { 99 | t.Fatal(v) 100 | } 101 | 102 | if v, _ := r.GetUint(0, 3); v != 255 { 103 | t.Fatal(v) 104 | } 105 | 106 | if v, _ := r.GetInt(0, 4); v != -127 { 107 | t.Fatal(v) 108 | } 109 | 110 | if v, _ := r.IsNull(0, 5); !v { 111 | t.Fatal("ni not null") 112 | } 113 | } 114 | } 115 | 116 | func TestConn_Update(t *testing.T) { 117 | s := `update mixer_test_proxy_conn set str = "123" where id = 1` 118 | 119 | c := newTestDBConn(t) 120 | defer c.Close() 121 | 122 | if _, err := c.Execute(s); err != nil { 123 | t.Fatal(err) 124 | } 125 | 126 | if r, err := c.Execute(`select str from mixer_test_proxy_conn where id = 1`); err != nil { 127 | t.Fatal(err) 128 | } else { 129 | if v, _ := r.GetString(0, 0); v != `123` { 130 | t.Fatal(v) 131 | } 132 | } 133 | } 134 | 135 | func TestConn_Replace(t *testing.T) { 136 | s := `replace into mixer_test_proxy_conn (id, str, f) values(1, "abc", 3.14159)` 137 | 138 | c := newTestDBConn(t) 139 | defer c.Close() 140 | 141 | if r, err := c.Execute(s); err != nil { 142 | t.Fatal(err) 143 | } else { 144 | if r.AffectedRows != 2 { 145 | t.Fatal(r.AffectedRows) 146 | } 147 | } 148 | 149 | s = `replace into mixer_test_proxy_conn (id, str) values(2, "abcb")` 150 | 151 | if r, err :=
c.Execute(s); err != nil { 152 | t.Fatal(err) 153 | } else { 154 | if r.AffectedRows != 1 { 155 | t.Fatal(r.AffectedRows) 156 | } 157 | } 158 | 159 | s = `select str, f from mixer_test_proxy_conn` 160 | 161 | if r, err := c.Execute(s); err != nil { 162 | t.Fatal(err) 163 | } else { 164 | if v, _ := r.GetString(0, 0); v != `abc` { 165 | t.Fatal(v) 166 | } 167 | 168 | if v, _ := r.GetString(1, 0); v != `abcb` { 169 | t.Fatal(v) 170 | } 171 | 172 | if v, _ := r.GetFloat(0, 1); v != 3.14159 { 173 | t.Fatal(v) 174 | } 175 | 176 | if v, _ := r.IsNull(1, 1); !v { 177 | t.Fatal(v) 178 | } 179 | } 180 | } 181 | 182 | func TestConn_Delete(t *testing.T) { 183 | s := `delete from mixer_test_proxy_conn where id = 100000` 184 | 185 | c := newTestDBConn(t) 186 | defer c.Close() 187 | 188 | if r, err := c.Execute(s); err != nil { 189 | t.Fatal(err) 190 | } else { 191 | if r.AffectedRows != 0 { 192 | t.Fatal(r.AffectedRows) 193 | } 194 | } 195 | } 196 | 197 | func TestConn_SetAutoCommit(t *testing.T) { 198 | c := newTestDBConn(t) 199 | defer c.Close() 200 | 201 | if r, err := c.Execute("set autocommit = 1"); err != nil { 202 | t.Fatal(err) 203 | } else { 204 | if !(r.Status&SERVER_STATUS_AUTOCOMMIT > 0) { 205 | t.Fatal(r.Status) 206 | } 207 | } 208 | 209 | if r, err := c.Execute("set autocommit = 0"); err != nil { 210 | t.Fatal(err) 211 | } else { 212 | if !(r.Status&SERVER_STATUS_AUTOCOMMIT == 0) { 213 | t.Fatal(r.Status) 214 | } 215 | } 216 | } 217 | 218 | func TestConn_Trans(t *testing.T) { 219 | c1 := newTestDBConn(t) 220 | defer c1.Close() 221 | 222 | c2 := newTestDBConn(t) 223 | defer c2.Close() 224 | 225 | var err error 226 | 227 | if err = c1.Begin(); err != nil { 228 | t.Fatal(err) 229 | } 230 | 231 | if err = c2.Begin(); err != nil { 232 | t.Fatal(err) 233 | } 234 | 235 | if _, err := c1.Execute(`insert into mixer_test_proxy_conn (id, str) values (111, "abc")`); err != nil { 236 | t.Fatal(err) 237 | } 238 | 239 | if r, err := c2.Execute(`select str from
mixer_test_proxy_conn where id = 111`); err != nil { 240 | t.Fatal(err) 241 | } else { 242 | if r.RowNumber() != 0 { 243 | t.Fatal(r.RowNumber()) 244 | } 245 | } 246 | 247 | if err := c1.Commit(); err != nil { 248 | t.Fatal(err) 249 | } 250 | 251 | if err := c2.Commit(); err != nil { 252 | t.Fatal(err) 253 | } 254 | 255 | if r, err := c1.Execute(`select str from mixer_test_proxy_conn where id = 111`); err != nil { 256 | t.Fatal(err) 257 | } else { 258 | if r.RowNumber() != 1 { 259 | t.Fatal(r.RowNumber()) 260 | } 261 | 262 | if v, _ := r.GetString(0, 0); v != `abc` { 263 | t.Fatal(v) 264 | } 265 | } 266 | } 267 | 268 | func TestConn_SetNames(t *testing.T) { 269 | c := newTestDBConn(t) 270 | defer c.Close() 271 | 272 | if err := c.SetCharset("gb2312"); err != nil { 273 | t.Fatal(err) 274 | } 275 | } 276 | 277 | func TestConn_LastInsertId(t *testing.T) { 278 | s := `CREATE TABLE IF NOT EXISTS mixer_test_conn_id ( 279 | id BIGINT(64) UNSIGNED AUTO_INCREMENT NOT NULL, 280 | str VARCHAR(256), 281 | PRIMARY KEY (id) 282 | ) ENGINE=InnoDB DEFAULT CHARSET=utf8` 283 | 284 | server := newTestServer(t) 285 | n := server.nodes["node1"] 286 | 287 | c1, err := n.getMasterConn() 288 | if err != nil { 289 | t.Fatal(err) 290 | } 291 | 292 | if _, err := c1.Execute(s); err != nil { 293 | t.Fatal(err) 294 | } 295 | 296 | c1.Close() 297 | 298 | c := newTestDBConn(t) 299 | defer c.Close() 300 | 301 | r, err := c.Execute(`insert into mixer_test_conn_id (str) values ("abc")`) 302 | if err != nil { 303 | t.Fatal(err) 304 | } 305 | 306 | lastId := r.InsertId 307 | if r, err := c.Execute(`select last_insert_id()`); err != nil { 308 | t.Fatal(err) 309 | } else { 310 | if r.ColumnNumber() != 1 { 311 | t.Fatal(r.ColumnNumber()) 312 | } 313 | 314 | if v, _ := r.GetUint(0, 0); v != lastId { 315 | t.Fatal(v) 316 | } 317 | } 318 | 319 | if r, err := c.Execute(`select last_insert_id() as a`); err != nil { 320 | t.Fatal(err) 321 | } else { 322 | if string(r.Fields[0].Name) != "a" { 323 |
t.Fatal(string(r.Fields[0].Name)) 324 | } 325 | 326 | if v, _ := r.GetUint(0, 0); v != lastId { 327 | t.Fatal(v) 328 | } 329 | } 330 | 331 | c1, _ = n.getMasterConn() 332 | 333 | if _, err := c1.Execute(`drop table if exists mixer_test_conn_id`); err != nil { 334 | t.Fatal(err) 335 | } 336 | 337 | c1.Close() 338 | } 339 | 340 | func TestConn_RowCount(t *testing.T) { 341 | c := newTestDBConn(t) 342 | defer c.Close() 343 | 344 | r, err := c.Execute(`insert into mixer_test_proxy_conn (id, str) values (1002, "abc")`) 345 | if err != nil { 346 | t.Fatal(err) 347 | } 348 | 349 | row := r.AffectedRows 350 | 351 | if r, err := c.Execute("select row_count()"); err != nil { 352 | t.Fatal(err) 353 | } else { 354 | if v, _ := r.GetUint(0, 0); v != row { 355 | t.Fatal(v) 356 | } 357 | } 358 | 359 | if r, err := c.Execute("select row_count() as b"); err != nil { 360 | t.Fatal(err) 361 | } else { 362 | if v, _ := r.GetInt(0, 0); v != -1 { 363 | t.Fatal(v) 364 | } 365 | } 366 | } 367 | 368 | func TestConn_SelectVersion(t *testing.T) { 369 | c := newTestDBConn(t) 370 | defer c.Close() 371 | 372 | if r, err := c.Execute("select version()"); err != nil { 373 | t.Fatal(err) 374 | } else { 375 | if v, _ := r.GetString(0, 0); v != ServerVersion { 376 | t.Fatal(v) 377 | } 378 | } 379 | } 380 | -------------------------------------------------------------------------------- /sqltypes/sqltypes.go: -------------------------------------------------------------------------------- 1 | // Copyright 2012, Google Inc. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | 5 | // Package sqltypes implements interfaces and types that represent SQL values.
6 | package sqltypes 7 | 8 | import ( 9 | "encoding/base64" 10 | "encoding/gob" 11 | "encoding/json" 12 | "fmt" 13 | "strconv" 14 | "time" 15 | 16 | "github.com/siddontang/mixer/hack" 17 | ) 18 | 19 | var ( 20 | NULL = Value{} 21 | DONTESCAPE = byte(255) 22 | nullstr = []byte("null") 23 | ) 24 | 25 | // BinWriter interface is used for encoding values. 26 | // Types like bytes.Buffer conform to this interface. 27 | // We expect the writer objects to be in-memory buffers. 28 | // So, we don't expect the write operations to fail. 29 | type BinWriter interface { 30 | Write([]byte) (int, error) 31 | WriteByte(byte) error 32 | } 33 | 34 | // Value can store any SQL value. NULL is stored as nil. 35 | type Value struct { 36 | Inner InnerValue 37 | } 38 | 39 | // Numeric represents non-fractional SQL number. 40 | type Numeric []byte 41 | 42 | // Fractional represents fractional types like float and decimal 43 | // It's functionally equivalent to Numeric other than how it's constructed 44 | type Fractional []byte 45 | 46 | // String represents any SQL type that needs to be represented using quotes. 47 | type String []byte 48 | 49 | // MakeNumeric makes a Numeric from a []byte without validation. 50 | func MakeNumeric(b []byte) Value { 51 | return Value{Numeric(b)} 52 | } 53 | 54 | // MakeFractional makes a Fractional value from a []byte without validation. 55 | func MakeFractional(b []byte) Value { 56 | return Value{Fractional(b)} 57 | } 58 | 59 | // MakeString makes a String value from a []byte. 60 | func MakeString(b []byte) Value { 61 | return Value{String(b)} 62 | } 63 | 64 | // Raw returns the raw bytes. All types are currently implemented as []byte. 
65 | func (v Value) Raw() []byte { 66 | if v.Inner == nil { 67 | return nil 68 | } 69 | return v.Inner.raw() 70 | } 71 | 72 | // String returns the raw value as a string 73 | func (v Value) String() string { 74 | if v.Inner == nil { 75 | return "" 76 | } 77 | return hack.String(v.Inner.raw()) 78 | } 79 | 80 | // ParseInt64 will parse a Numeric value into an int64 81 | func (v Value) ParseInt64() (val int64, err error) { 82 | if v.Inner == nil { 83 | return 0, fmt.Errorf("value is null") 84 | } 85 | n, ok := v.Inner.(Numeric) 86 | if !ok { 87 | return 0, fmt.Errorf("value is not Numeric") 88 | } 89 | return strconv.ParseInt(string(n.raw()), 10, 64) 90 | } 91 | 92 | // ParseUint64 will parse a Numeric value into a uint64 93 | func (v Value) ParseUint64() (val uint64, err error) { 94 | if v.Inner == nil { 95 | return 0, fmt.Errorf("value is null") 96 | } 97 | n, ok := v.Inner.(Numeric) 98 | if !ok { 99 | return 0, fmt.Errorf("value is not Numeric") 100 | } 101 | return strconv.ParseUint(string(n.raw()), 10, 64) 102 | } 103 | 104 | // EncodeSql encodes the value into an SQL statement. Can be binary. 105 | func (v Value) EncodeSql(b BinWriter) { 106 | if v.Inner == nil { 107 | if _, err := b.Write(nullstr); err != nil { 108 | panic(err) 109 | } 110 | } else { 111 | v.Inner.encodeSql(b) 112 | } 113 | } 114 | 115 | // EncodeAscii encodes the value using 7-bit clean ascii bytes. 
116 | func (v Value) EncodeAscii(b BinWriter) { 117 | if v.Inner == nil { 118 | if _, err := b.Write(nullstr); err != nil { 119 | panic(err) 120 | } 121 | } else { 122 | v.Inner.encodeAscii(b) 123 | } 124 | } 125 | 126 | func (v Value) IsNull() bool { 127 | return v.Inner == nil 128 | } 129 | 130 | func (v Value) IsNumeric() (ok bool) { 131 | if v.Inner != nil { 132 | _, ok = v.Inner.(Numeric) 133 | } 134 | return ok 135 | } 136 | 137 | func (v Value) IsFractional() (ok bool) { 138 | if v.Inner != nil { 139 | _, ok = v.Inner.(Fractional) 140 | } 141 | return ok 142 | } 143 | 144 | func (v Value) IsString() (ok bool) { 145 | if v.Inner != nil { 146 | _, ok = v.Inner.(String) 147 | } 148 | return ok 149 | } 150 | 151 | // MarshalJSON should only be used for testing. 152 | // It's not a complete implementation. 153 | func (v Value) MarshalJSON() ([]byte, error) { 154 | return json.Marshal(v.Inner) 155 | } 156 | 157 | // UnmarshalJSON should only be used for testing. 158 | // It's not a complete implementation. 159 | func (v *Value) UnmarshalJSON(b []byte) error { 160 | if len(b) == 0 { 161 | return fmt.Errorf("error unmarshaling empty bytes") 162 | } 163 | var val interface{} 164 | var err error 165 | switch b[0] { 166 | case '-': 167 | var ival int64 168 | err = json.Unmarshal(b, &ival) 169 | val = ival 170 | case '"': 171 | var bval []byte 172 | err = json.Unmarshal(b, &bval) 173 | val = bval 174 | case 'n': // null 175 | err = json.Unmarshal(b, &val) 176 | default: 177 | var uval uint64 178 | err = json.Unmarshal(b, &uval) 179 | val = uval 180 | } 181 | if err != nil { 182 | return err 183 | } 184 | *v, err = BuildValue(val) 185 | return err 186 | } 187 | 188 | // InnerValue defines methods that need to be supported by all non-null value types. 
189 | type InnerValue interface { 190 | raw() []byte 191 | encodeSql(BinWriter) 192 | encodeAscii(BinWriter) 193 | } 194 | 195 | func BuildValue(goval interface{}) (v Value, err error) { 196 | switch bindVal := goval.(type) { 197 | case nil: 198 | // no op 199 | case int: 200 | v = Value{Numeric(strconv.AppendInt(nil, int64(bindVal), 10))} 201 | case int32: 202 | v = Value{Numeric(strconv.AppendInt(nil, int64(bindVal), 10))} 203 | case int64: 204 | v = Value{Numeric(strconv.AppendInt(nil, int64(bindVal), 10))} 205 | case uint: 206 | v = Value{Numeric(strconv.AppendUint(nil, uint64(bindVal), 10))} 207 | case uint32: 208 | v = Value{Numeric(strconv.AppendUint(nil, uint64(bindVal), 10))} 209 | case uint64: 210 | v = Value{Numeric(strconv.AppendUint(nil, uint64(bindVal), 10))} 211 | case float64: 212 | v = Value{Fractional(strconv.AppendFloat(nil, bindVal, 'f', -1, 64))} 213 | case string: 214 | v = Value{String([]byte(bindVal))} 215 | case []byte: 216 | v = Value{String(bindVal)} 217 | case time.Time: 218 | v = Value{String([]byte(bindVal.Format("'2006-01-02 15:04:05'")))} 219 | case Numeric, Fractional, String: 220 | v = Value{bindVal.(InnerValue)} 221 | case Value: 222 | v = bindVal 223 | default: 224 | return Value{}, fmt.Errorf("unsupported bind variable type %T: %v", goval, goval) 225 | } 226 | return v, nil 227 | } 228 | 229 | // BuildNumeric builds a Numeric type that represents any whole number. 230 | // It normalizes the representation to ensure 1:1 mapping between the 231 | // number and its representation. 
232 | func BuildNumeric(val string) (n Value, err error) { 233 | if val[0] == '-' || val[0] == '+' { 234 | signed, err := strconv.ParseInt(val, 0, 64) 235 | if err != nil { 236 | return Value{}, err 237 | } 238 | n = Value{Numeric(strconv.AppendInt(nil, signed, 10))} 239 | } else { 240 | unsigned, err := strconv.ParseUint(val, 0, 64) 241 | if err != nil { 242 | return Value{}, err 243 | } 244 | n = Value{Numeric(strconv.AppendUint(nil, unsigned, 10))} 245 | } 246 | return n, nil 247 | } 248 | 249 | func (n Numeric) raw() []byte { 250 | return []byte(n) 251 | } 252 | 253 | func (n Numeric) encodeSql(b BinWriter) { 254 | if _, err := b.Write(n.raw()); err != nil { 255 | panic(err) 256 | } 257 | } 258 | 259 | func (n Numeric) encodeAscii(b BinWriter) { 260 | if _, err := b.Write(n.raw()); err != nil { 261 | panic(err) 262 | } 263 | } 264 | 265 | func (n Numeric) MarshalJSON() ([]byte, error) { 266 | return n.raw(), nil 267 | } 268 | 269 | func (f Fractional) raw() []byte { 270 | return []byte(f) 271 | } 272 | 273 | func (f Fractional) encodeSql(b BinWriter) { 274 | if _, err := b.Write(f.raw()); err != nil { 275 | panic(err) 276 | } 277 | } 278 | 279 | func (f Fractional) encodeAscii(b BinWriter) { 280 | if _, err := b.Write(f.raw()); err != nil { 281 | panic(err) 282 | } 283 | } 284 | 285 | func (s String) raw() []byte { 286 | return []byte(s) 287 | } 288 | 289 | func (s String) encodeSql(b BinWriter) { 290 | writebyte(b, '\'') 291 | for _, ch := range s.raw() { 292 | if encodedChar := SqlEncodeMap[ch]; encodedChar == DONTESCAPE { 293 | writebyte(b, ch) 294 | } else { 295 | writebyte(b, '\\') 296 | writebyte(b, encodedChar) 297 | } 298 | } 299 | writebyte(b, '\'') 300 | } 301 | 302 | func (s String) encodeAscii(b BinWriter) { 303 | writebyte(b, '\'') 304 | encoder := base64.NewEncoder(base64.StdEncoding, b) 305 | encoder.Write(s.raw()) 306 | encoder.Close() 307 | writebyte(b, '\'') 308 | } 309 | 310 | func writebyte(b BinWriter, c byte) { 311 | if err := 
b.WriteByte(c); err != nil { 312 | panic(err) 313 | } 314 | } 315 | 316 | // SqlEncodeMap specifies how to escape binary data with '\'. 317 | // Complies to http://dev.mysql.com/doc/refman/5.1/en/string-syntax.html 318 | var SqlEncodeMap [256]byte 319 | 320 | // SqlDecodeMap is the reverse of SqlEncodeMap 321 | var SqlDecodeMap [256]byte 322 | 323 | var encodeRef = map[byte]byte{ 324 | '\x00': '0', 325 | '\'': '\'', 326 | '"': '"', 327 | '\b': 'b', 328 | '\n': 'n', 329 | '\r': 'r', 330 | '\t': 't', 331 | 26: 'Z', // ctl-Z 332 | '\\': '\\', 333 | } 334 | 335 | func init() { 336 | for i := range SqlEncodeMap { 337 | SqlEncodeMap[i] = DONTESCAPE 338 | SqlDecodeMap[i] = DONTESCAPE 339 | } 340 | for i := range SqlEncodeMap { 341 | if to, ok := encodeRef[byte(i)]; ok { 342 | SqlEncodeMap[byte(i)] = to 343 | SqlDecodeMap[to] = byte(i) 344 | } 345 | } 346 | gob.Register(Numeric(nil)) 347 | gob.Register(Fractional(nil)) 348 | gob.Register(String(nil)) 349 | } 350 | -------------------------------------------------------------------------------- /mysql/resultset.go: -------------------------------------------------------------------------------- 1 | package mysql 2 | 3 | import ( 4 | "encoding/binary" 5 | "fmt" 6 | "github.com/siddontang/mixer/hack" 7 | "math" 8 | "strconv" 9 | ) 10 | 11 | type RowData []byte 12 | 13 | func (p RowData) Parse(f []*Field, binary bool) ([]interface{}, error) { 14 | if binary { 15 | return p.ParseBinary(f) 16 | } else { 17 | return p.ParseText(f) 18 | } 19 | } 20 | 21 | func (p RowData) ParseText(f []*Field) ([]interface{}, error) { 22 | data := make([]interface{}, len(f)) 23 | 24 | var err error 25 | var v []byte 26 | var isNull, isUnsigned bool 27 | var pos int = 0 28 | var n int = 0 29 | 30 | for i := range f { 31 | v, isNull, n, err = LengthEnodedString(p[pos:]) 32 | if err != nil { 33 | return nil, err 34 | } 35 | 36 | pos += n 37 | 38 | if isNull { 39 | data[i] = nil 40 | } else { 41 | isUnsigned = (f[i].Flag&UNSIGNED_FLAG > 0) 42 | 43 | 
switch f[i].Type { 44 | case MYSQL_TYPE_TINY, MYSQL_TYPE_SHORT, MYSQL_TYPE_INT24, 45 | MYSQL_TYPE_LONGLONG, MYSQL_TYPE_YEAR: 46 | if isUnsigned { 47 | data[i], err = strconv.ParseUint(string(v), 10, 64) 48 | } else { 49 | data[i], err = strconv.ParseInt(string(v), 10, 64) 50 | } 51 | case MYSQL_TYPE_FLOAT, MYSQL_TYPE_DOUBLE: 52 | data[i], err = strconv.ParseFloat(string(v), 64) 53 | default: 54 | data[i] = v 55 | } 56 | 57 | if err != nil { 58 | return nil, err 59 | } 60 | } 61 | } 62 | 63 | return data, nil 64 | } 65 | 66 | func (p RowData) ParseBinary(f []*Field) ([]interface{}, error) { 67 | data := make([]interface{}, len(f)) 68 | 69 | if p[0] != OK_HEADER { 70 | return nil, ErrMalformPacket 71 | } 72 | 73 | pos := 1 + ((len(f) + 7 + 2) >> 3) 74 | 75 | nullBitmap := p[1:pos] 76 | 77 | var isUnsigned bool 78 | var isNull bool 79 | var n int 80 | var err error 81 | var v []byte 82 | for i := range data { 83 | if nullBitmap[(i+2)/8]&(1<<(uint(i+2)%8)) > 0 { 84 | data[i] = nil 85 | continue 86 | } 87 | 88 | isUnsigned = f[i].Flag&UNSIGNED_FLAG > 0 89 | 90 | switch f[i].Type { 91 | case MYSQL_TYPE_NULL: 92 | data[i] = nil 93 | continue 94 | 95 | case MYSQL_TYPE_TINY: 96 | if isUnsigned { 97 | data[i] = uint64(p[pos]) 98 | } else { 99 | data[i] = int64(p[pos]) 100 | } 101 | pos++ 102 | continue 103 | 104 | case MYSQL_TYPE_SHORT, MYSQL_TYPE_YEAR: 105 | if isUnsigned { 106 | data[i] = uint64(binary.LittleEndian.Uint16(p[pos : pos+2])) 107 | } else { 108 | data[i] = int64((binary.LittleEndian.Uint16(p[pos : pos+2]))) 109 | } 110 | pos += 2 111 | continue 112 | 113 | case MYSQL_TYPE_INT24, MYSQL_TYPE_LONG: 114 | if isUnsigned { 115 | data[i] = uint64(binary.LittleEndian.Uint32(p[pos : pos+4])) 116 | } else { 117 | data[i] = int64(binary.LittleEndian.Uint32(p[pos : pos+4])) 118 | } 119 | pos += 4 120 | continue 121 | 122 | case MYSQL_TYPE_LONGLONG: 123 | if isUnsigned { 124 | data[i] = binary.LittleEndian.Uint64(p[pos : pos+8]) 125 | } else { 126 | data[i] = 
int64(binary.LittleEndian.Uint64(p[pos : pos+8])) 127 | } 128 | pos += 8 129 | continue 130 | 131 | case MYSQL_TYPE_FLOAT: 132 | data[i] = float64(math.Float32frombits(binary.LittleEndian.Uint32(p[pos : pos+4]))) 133 | pos += 4 134 | continue 135 | 136 | case MYSQL_TYPE_DOUBLE: 137 | data[i] = math.Float64frombits(binary.LittleEndian.Uint64(p[pos : pos+8])) 138 | pos += 8 139 | continue 140 | 141 | case MYSQL_TYPE_DECIMAL, MYSQL_TYPE_NEWDECIMAL, MYSQL_TYPE_VARCHAR, 142 | MYSQL_TYPE_BIT, MYSQL_TYPE_ENUM, MYSQL_TYPE_SET, MYSQL_TYPE_TINY_BLOB, 143 | MYSQL_TYPE_MEDIUM_BLOB, MYSQL_TYPE_LONG_BLOB, MYSQL_TYPE_BLOB, 144 | MYSQL_TYPE_VAR_STRING, MYSQL_TYPE_STRING, MYSQL_TYPE_GEOMETRY: 145 | v, isNull, n, err = LengthEnodedString(p[pos:]) 146 | pos += n 147 | if err != nil { 148 | return nil, err 149 | } 150 | 151 | if !isNull { 152 | data[i] = v 153 | continue 154 | } else { 155 | data[i] = nil 156 | continue 157 | } 158 | case MYSQL_TYPE_DATE, MYSQL_TYPE_NEWDATE: 159 | var num uint64 160 | num, isNull, n = LengthEncodedInt(p[pos:]) 161 | 162 | pos += n 163 | 164 | if isNull { 165 | data[i] = nil 166 | continue 167 | } 168 | 169 | data[i], err = FormatBinaryDate(int(num), p[pos:]) 170 | pos += int(num) 171 | 172 | if err != nil { 173 | return nil, err 174 | } 175 | 176 | case MYSQL_TYPE_TIMESTAMP, MYSQL_TYPE_DATETIME: 177 | var num uint64 178 | num, isNull, n = LengthEncodedInt(p[pos:]) 179 | 180 | pos += n 181 | 182 | if isNull { 183 | data[i] = nil 184 | continue 185 | } 186 | 187 | data[i], err = FormatBinaryDateTime(int(num), p[pos:]) 188 | pos += int(num) 189 | 190 | if err != nil { 191 | return nil, err 192 | } 193 | 194 | case MYSQL_TYPE_TIME: 195 | var num uint64 196 | num, isNull, n = LengthEncodedInt(p[pos:]) 197 | 198 | pos += n 199 | 200 | if isNull { 201 | data[i] = nil 202 | continue 203 | } 204 | 205 | data[i], err = FormatBinaryTime(int(num), p[pos:]) 206 | pos += int(num) 207 | 208 | if err != nil { 209 | return nil, err 210 | } 211 | 212 | default: 213 | 
return nil, fmt.Errorf("Stmt Unknown FieldType %d %s", f[i].Type, f[i].Name) 214 | } 215 | } 216 | 217 | return data, nil 218 | } 219 | 220 | type Resultset struct { 221 | Fields []*Field 222 | FieldNames map[string]int 223 | Values [][]interface{} 224 | 225 | RowDatas []RowData 226 | } 227 | 228 | func (r *Resultset) RowNumber() int { 229 | return len(r.Values) 230 | } 231 | 232 | func (r *Resultset) ColumnNumber() int { 233 | return len(r.Fields) 234 | } 235 | 236 | func (r *Resultset) GetValue(row, column int) (interface{}, error) { 237 | if row >= len(r.Values) || row < 0 { 238 | return nil, fmt.Errorf("invalid row index %d", row) 239 | } 240 | 241 | if column >= len(r.Fields) || column < 0 { 242 | return nil, fmt.Errorf("invalid column index %d", column) 243 | } 244 | 245 | return r.Values[row][column], nil 246 | } 247 | 248 | func (r *Resultset) NameIndex(name string) (int, error) { 249 | if column, ok := r.FieldNames[name]; ok { 250 | return column, nil 251 | } else { 252 | return 0, fmt.Errorf("invalid field name %s", name) 253 | } 254 | } 255 | 256 | func (r *Resultset) GetValueByName(row int, name string) (interface{}, error) { 257 | if column, err := r.NameIndex(name); err != nil { 258 | return nil, err 259 | } else { 260 | return r.GetValue(row, column) 261 | } 262 | } 263 | 264 | func (r *Resultset) IsNull(row, column int) (bool, error) { 265 | d, err := r.GetValue(row, column) 266 | if err != nil { 267 | return false, err 268 | } 269 | 270 | return d == nil, nil 271 | } 272 | 273 | func (r *Resultset) IsNullByName(row int, name string) (bool, error) { 274 | if column, err := r.NameIndex(name); err != nil { 275 | return false, err 276 | } else { 277 | return r.IsNull(row, column) 278 | } 279 | } 280 | 281 | func (r *Resultset) GetUint(row, column int) (uint64, error) { 282 | d, err := r.GetValue(row, column) 283 | if err != nil { 284 | return 0, err 285 | } 286 | 287 | switch v := d.(type) { 288 | case uint64: 289 | return v, nil 290 | case int64: 291 
| return uint64(v), nil 292 | case float64: 293 | return uint64(v), nil 294 | case string: 295 | return strconv.ParseUint(v, 10, 64) 296 | case []byte: 297 | return strconv.ParseUint(string(v), 10, 64) 298 | case nil: 299 | return 0, nil 300 | default: 301 | return 0, fmt.Errorf("data type is %T", v) 302 | } 303 | } 304 | 305 | func (r *Resultset) GetUintByName(row int, name string) (uint64, error) { 306 | if column, err := r.NameIndex(name); err != nil { 307 | return 0, err 308 | } else { 309 | return r.GetUint(row, column) 310 | } 311 | } 312 | 313 | func (r *Resultset) GetInt(row, column int) (int64, error) { 314 | v, err := r.GetUint(row, column) 315 | if err != nil { 316 | return 0, err 317 | } 318 | 319 | return int64(v), nil 320 | } 321 | 322 | func (r *Resultset) GetIntByName(row int, name string) (int64, error) { 323 | v, err := r.GetUintByName(row, name) 324 | if err != nil { 325 | return 0, err 326 | } 327 | 328 | return int64(v), nil 329 | } 330 | 331 | func (r *Resultset) GetFloat(row, column int) (float64, error) { 332 | d, err := r.GetValue(row, column) 333 | if err != nil { 334 | return 0, err 335 | } 336 | 337 | switch v := d.(type) { 338 | case float64: 339 | return v, nil 340 | case uint64: 341 | return float64(v), nil 342 | case int64: 343 | return float64(v), nil 344 | case string: 345 | return strconv.ParseFloat(v, 64) 346 | case []byte: 347 | return strconv.ParseFloat(string(v), 64) 348 | case nil: 349 | return 0, nil 350 | default: 351 | return 0, fmt.Errorf("data type is %T", v) 352 | } 353 | } 354 | 355 | func (r *Resultset) GetFloatByName(row int, name string) (float64, error) { 356 | if column, err := r.NameIndex(name); err != nil { 357 | return 0, err 358 | } else { 359 | return r.GetFloat(row, column) 360 | } 361 | } 362 | 363 | func (r *Resultset) GetString(row, column int) (string, error) { 364 | d, err := r.GetValue(row, column) 365 | if err != nil { 366 | return "", err 367 | } 368 | 369 | switch v := d.(type) { 370 | case string: 
371 | return v, nil 372 | case []byte: 373 | return hack.String(v), nil 374 | case int64: 375 | return strconv.FormatInt(v, 10), nil 376 | case uint64: 377 | return strconv.FormatUint(v, 10), nil 378 | case float64: 379 | return strconv.FormatFloat(v, 'f', -1, 64), nil 380 | case nil: 381 | return "", nil 382 | default: 383 | return "", fmt.Errorf("data type is %T", v) 384 | } 385 | } 386 | 387 | func (r *Resultset) GetStringByName(row int, name string) (string, error) { 388 | if column, err := r.NameIndex(name); err != nil { 389 | return "", err 390 | } else { 391 | return r.GetString(row, column) 392 | } 393 | } 394 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Mixer 2 | 3 | Mixer is a MySQL proxy powered by Go which aims to supply a simple solution for MySQL sharding. 4 | 5 | ## Features 6 | 7 | - Supports basic SQL statements (select, insert, update, replace, delete) 8 | - Supports transactions 9 | - Splits reads and writes (not fully tested) 10 | - MySQL HA 11 | - Basic SQL Routing 12 | - Supports prepared statement: `COM_STMT_PREPARE`, `COM_STMT_EXECUTE`, etc. 13 | 14 | ## TODO 15 | 16 | - Some admin commands 17 | - Some show command support, i.e. ```show databases```, etc. 18 | - Some select system variable, i.e. ```select @@version```, etc. 19 | - Enhance routing rules 20 | - Monitor 21 | - SQL validation check 22 | - Statistics 23 | - Many other things ... 24 | 25 | ## Install 26 | 27 | cd $WORKSPACE 28 | git clone git@github.com:siddontang/mixer.git src/github.com/siddontang/mixer 29 | 30 | cd src/github.com/siddontang/mixer 31 | 32 | ./bootstrap.sh 33 | 34 | . ./dev.env 35 | 36 | make 37 | make test 38 | 39 | ## Keywords 40 | 41 | ### proxy 42 | 43 | A proxy is the bridge connecting clients and the real MySQL servers. 44 | 45 | It acts as a MySQL server too, clients can communicate with it using the MySQL procotol. 
46 | 47 | ### node 48 | 49 | Mixer uses nodes to represent the real remote MySQL servers. A node can have two MySQL servers: 50 | 51 | + master: the main MySQL server. All write operations, and read operations unless both ```rw_split``` and a slave are set, will be executed here. 52 | All transactions will be executed here too. 53 | + slave: if ```rw_split``` is set, all select operations will be executed here. (optional) 54 | 55 | Notice: 56 | 57 | + You can use the ```admin upnode``` or ```admin downnode``` commands to bring a specified MySQL server up or down. 58 | + If the master goes down, you must bring it back up manually with an admin command. 59 | + You must set up MySQL replication yourself; mixer does not do it. 60 | 61 | ### schema 62 | 63 | A schema is like a MySQL database: if a client executes a ```use db``` command, ```db``` must exist in the schema. 64 | 65 | A schema contains one or more nodes. If a client uses a schema, its commands are only routed to the nodes that belong to that schema. 66 | 67 | ### rule 68 | 69 | You must set rules for a schema so that mixer can decide which nodes each SQL statement is routed to. 70 | 71 | Mixer uses ```table + key``` to route. Duplicate rules for a table are not allowed. 72 | 73 | When SQL needs to be routed, mixer does the following steps: 74 | 75 | + Parse the SQL and find the table it operates on 76 | + If there is no rule for the table, mixer uses the default rule 77 | + If a rule exists, mixer tries to route it with the specified key 78 | 79 | Rules have three types: default, hash and range. 80 | 81 | A schema must have a default rule with only one node assigned. 82 | 83 | For hash and range routing, see the examples below. 84 | 85 | ## admin commands 86 | 87 | Mixer supplies an `admin` statement for administration. The `admin` format is `admin func(arg, ...)`, like `select func(arg,...)`. An admin password may be added later for safety.
88 | 89 | Supported admin functions: 90 | 91 | - admin upnode(node, servertype, addr); 92 | - admin downnode(node, servertype); 93 | - show proxy config; 94 | 95 | ## Basic Example 96 | 97 | ``` 98 | #start mixer 99 | mixer-proxy -config=/etc/mixer.conf 100 | 101 | #another shell 102 | mysql -uroot -h127.0.0.1 -P4000 -p -Dmixer 103 | 104 | Welcome to the MySQL monitor. Commands end with ; or \g. 105 | Your MySQL connection id is 158 106 | Server version: 5.6.19 Homebrew 107 | 108 | mysql> use mixer; 109 | Database changed 110 | 111 | mysql> delete from mixer_test_conn; 112 | Query OK, 3 rows affected (0.04 sec) 113 | 114 | mysql> insert into mixer_test_conn (id, str) values (1, "a"); 115 | Query OK, 1 row affected (0.00 sec) 116 | 117 | mysql> insert into mixer_test_conn (id, str) values (2, "b"); 118 | Query OK, 1 row affected (0.00 sec) 119 | 120 | mysql> select id, str from mixer_test_conn; 121 | +----+------+ 122 | | id | str | 123 | +----+------+ 124 | | 1 | a | 125 | | 2 | b | 126 | +----+------+ 127 | ``` 128 | 129 | ## Hash Sharding Example 130 | 131 | ``` 132 | schemas : 133 | - 134 | db : mixer 135 | nodes: [node1, node2, node3] 136 | rules: 137 | default: node1 138 | shard: 139 | - 140 | table: mixer_test_shard_hash 141 | key: id 142 | nodes: [node2, node3] 143 | type: hash 144 | 145 | hash algorithm: value % len(nodes) 146 | 147 | table: mixer_test_shard_hash 148 | 149 | Node: node2, node3 150 | node2 mysql: 127.0.0.1:3307 151 | node3 mysql: 127.0.0.1:3308 152 | 153 | mixer-proxy: 127.0.0.1:4000 154 | 155 | proxy> mysql -uroot -h127.0.0.1 -P4000 -p -Dmixer 156 | node2> mysql -uroot -h127.0.0.1 -P3307 -p -Dmixer 157 | node3> mysql -uroot -h127.0.0.1 -P3308 -p -Dmixer 158 | 159 | proxy> insert into mixer_test_shard_hash (id, str) values (0, "a"); 160 | node2> select str from mixer_test_shard_hash where id = 0; 161 | +------+ 162 | | str | 163 | +------+ 164 | | a | 165 | +------+ 166 | 167 | proxy> insert into mixer_test_shard_hash (id, str) values (1,
"b"); 168 | node3> select str from mixer_test_shard_hash where id = 1; 169 | +------+ 170 | | str | 171 | +------+ 172 | | b | 173 | +------+ 174 | 175 | proxy> select str from mixer_test_shard_hash where id in (0, 1); 176 | +------+ 177 | | str | 178 | +------+ 179 | | a | 180 | | b | 181 | +------+ 182 | 183 | proxy> select str from mixer_test_shard_hash where id = 0 or id = 1; 184 | +------+ 185 | | str | 186 | +------+ 187 | | a | 188 | | b | 189 | +------+ 190 | 191 | proxy> select str from mixer_test_shard_hash where id = 0 and id = 1; 192 | Empty set 193 | ``` 194 | 195 | 196 | ## Range Sharding Example 197 | 198 | ``` 199 | schemas : 200 | - 201 | db : mixer 202 | nodes: [node1, node2, node3] 203 | rules: 204 | default: node1 205 | shard: 206 | - 207 | table: mixer_test_shard_range 208 | key: id 209 | nodes: [node2, node3] 210 | range: -10000- 211 | type: range 212 | 213 | range algorithm: node key start <= value < node key stop 214 | 215 | table: mixer_test_shard_range 216 | 217 | Node: node2, node3 218 | node2 range: (-inf, 10000) 219 | node3 range: [10000, +inf) 220 | node2 mysql: 127.0.0.1:3307 221 | node3 mysql: 127.0.0.1:3308 222 | 223 | mixer-proxy: 127.0.0.1:4000 224 | 225 | proxy> mysql -uroot -h127.0.0.1 -P4000 -p -Dmixer 226 | node2> mysql -uroot -h127.0.0.1 -P3307 -p -Dmixer 227 | node3> mysql -uroot -h127.0.0.1 -P3307 -p -Dmixer 228 | 229 | proxy> insert into mixer_test_shard_range (id, str) values (0, "a"); 230 | node2> select str from mixer_test_shard_range where id = 0; 231 | +------+ 232 | | str | 233 | +------+ 234 | | a | 235 | +------+ 236 | 237 | proxy> insert into mixer_test_shard_range (id, str) values (10000, "b"); 238 | node3> select str from mixer_test_shard_range where id = 10000; 239 | +------+ 240 | | str | 241 | +------+ 242 | | b | 243 | +------+ 244 | 245 | proxy> select str from mixer_test_shard_range where id in (0, 10000); 246 | +------+ 247 | | str | 248 | +------+ 249 | | a | 250 | | b | 251 | +------+ 252 | 253 | proxy> 
select str from mixer_test_shard_range where id = 0 or id = 10000; 254 | +------+ 255 | | str | 256 | +------+ 257 | | a | 258 | | b | 259 | +------+ 260 | 261 | proxy> select str from mixer_test_shard_range where id = 0 and id = 10000; 262 | Empty set 263 | 264 | proxy> select str from mixer_test_shard_range where id > 100; 265 | +------+ 266 | | str | 267 | +------+ 268 | | b | 269 | +------+ 270 | 271 | proxy> select str from mixer_test_shard_range where id < 100; 272 | +------+ 273 | | str | 274 | +------+ 275 | | a | 276 | +------+ 277 | 278 | proxy> select str from mixer_test_shard_range where id >=0 and id < 100000; 279 | +------+ 280 | | str | 281 | +------+ 282 | | a | 283 | | b | 284 | +------+ 285 | ``` 286 | 287 | ## Limitations 288 | 289 | ### Select 290 | 291 | + Joins are not supported yet; later, only cross-shard joins will be unsupported. 292 | + Subselects are not supported yet; later, only cross-shard subselects will be unsupported. 293 | + Cross-shard "group by" only works correctly when the "group by" key is the routing key 294 | + Cross-shard "order by" only takes effect when the "order by" key appears as a select expression field 295 | 296 | ```select id from t1 order by id``` is ok. 297 | 298 | ```select str from t1 order by id``` is not ok: mixer does not know how to sort because it cannot find the data to compare for `id` 299 | 300 | + Limit should be used with "order by", otherwise you may receive incorrect results 301 | 302 | ### Insert 303 | 304 | + "insert into select" is not supported yet; later, only cross-shard "insert into select" will be unsupported.
305 | + Multi insert values to different nodes not supported 306 | + "insert on duplicate key update" can not set the routing key 307 | 308 | ### Replace 309 | 310 | + Multi replace values to different nodes not supported 311 | 312 | ### Update 313 | 314 | + Update can not set the routing key 315 | 316 | ### Set 317 | 318 | + Set autocommit support 319 | + Set name charset support 320 | 321 | ### Range Rule 322 | 323 | + Only int64 number range supported 324 | 325 | ## Caveat 326 | 327 | + Mixer uses 2PC to handle write operations for multi nodes. You take the risk that data becomes corrupted if some nodes commit ok but others error. In that case, you must try to recover your data by yourself. 328 | + You must design your routing rule and write SQL carefully. (e.g. if your where condition contains no routing key, mixer will route the SQL to all nodes, maybe). 329 | 330 | ## Why not [vitess](https://github.com/youtube/vitess)? 331 | 332 | Vitess is very awesome, and I use some of its code like sqlparser. Why not use vitess directly? Maybe below: 333 | 334 | + Vitess is too huge for me, I need a simple proxy 335 | + Vitess uses an RPC protocol based on BSON, I want to use the MySQL protocol 336 | + Most likely, something has gone wrong in my head 337 | 338 | ## Status 339 | 340 | Mixer now is still in development and should not be used in production. 
341 | 342 | ## Feedback 343 | 344 | Email: , 345 | -------------------------------------------------------------------------------- /proxy/conn_shard_test.go: -------------------------------------------------------------------------------- 1 | package proxy 2 | 3 | import ( 4 | "fmt" 5 | "github.com/siddontang/mixer/mysql" 6 | "reflect" 7 | "testing" 8 | ) 9 | 10 | func testShard_Insert(t *testing.T, table string, node string, id int, str string) { 11 | conn := newTestDBConn(t) 12 | 13 | s := fmt.Sprintf(`insert into %s (id, str) values (%d, "%s")`, table, id, str) 14 | if r, err := conn.Execute(s); err != nil { 15 | t.Fatal(s, err) 16 | } else if r.AffectedRows != 1 { 17 | t.Fatal(r.AffectedRows) 18 | } 19 | s = fmt.Sprintf(`select str from %s where id = %d`, table, id) 20 | 21 | n := newTestServer(t).nodes[node] 22 | c, err := n.getMasterConn() 23 | if err != nil { 24 | t.Fatal(s, err) 25 | } else { 26 | if r, err := c.Execute(s); err != nil { 27 | t.Fatal(s, err) 28 | } else if v, _ := r.GetString(0, 0); v != str { 29 | t.Fatal(s, v) 30 | } 31 | } 32 | 33 | if r, err := conn.Execute(s); err != nil { 34 | t.Fatal(s, err) 35 | } else if v, _ := r.GetString(0, 0); v != str { 36 | t.Fatal(s, v) 37 | } 38 | } 39 | 40 | func testShard_Select(t *testing.T, table string, where string, strs ...string) { 41 | sql := fmt.Sprintf("select str from %s where %s", table, where) 42 | conn := newTestDBConn(t) 43 | 44 | r, err := conn.Execute(sql) 45 | if err != nil { 46 | t.Fatal(sql, err) 47 | } else if r.RowNumber() != len(strs) { 48 | t.Fatal(sql, r.RowNumber(), len(strs)) 49 | } 50 | 51 | m := map[string]struct{}{} 52 | for _, s := range strs { 53 | m[s] = struct{}{} 54 | } 55 | 56 | for i := 0; i < r.RowNumber(); i++ { 57 | if v, err := r.GetString(i, 0); err != nil { 58 | t.Fatal(sql, err) 59 | } else if _, ok := m[v]; !ok { 60 | t.Fatal(sql, v, "no in check strs") 61 | } else { 62 | delete(m, v) 63 | } 64 | } 65 | 66 | if len(m) != 0 { 67 | t.Fatal(sql, "invalid select") 
68 | } 69 | } 70 | 71 | func testShard_StmtInsert(t *testing.T, table string, node string, id int, str string) { 72 | conn := newTestDBConn(t) 73 | 74 | s := fmt.Sprintf(`insert into %s (id, str) values (?, ?)`, table) 75 | if r, err := conn.Execute(s, id, str); err != nil { 76 | t.Fatal(s, err) 77 | } else if r.AffectedRows != 1 { 78 | t.Fatal(r.AffectedRows) 79 | } 80 | s = fmt.Sprintf(`select str from %s where id = ?`, table) 81 | 82 | n := newTestServer(t).nodes[node] 83 | c, err := n.getMasterConn() 84 | if err != nil { 85 | t.Fatal(s, err) 86 | } else { 87 | if r, err := c.Execute(s, id); err != nil { 88 | t.Fatal(s, err) 89 | } else if v, _ := r.GetString(0, 0); v != str { 90 | t.Fatal(s, v) 91 | } 92 | } 93 | 94 | if r, err := conn.Execute(s, id); err != nil { 95 | t.Fatal(s, err) 96 | } else if v, _ := r.GetString(0, 0); v != str { 97 | t.Fatal(s, v) 98 | } 99 | } 100 | 101 | func testShard_StmtSelect(t *testing.T, table string, where string, args []interface{}, strs ...string) { 102 | sql := fmt.Sprintf("select str from %s where %s", table, where) 103 | conn := newTestDBConn(t) 104 | 105 | r, err := conn.Execute(sql, args...) 
106 | if err != nil { 107 | t.Fatal(sql, err) 108 | } else if r.RowNumber() != len(strs) { 109 | t.Fatal(sql, r.RowNumber(), len(strs)) 110 | } 111 | 112 | m := map[string]struct{}{} 113 | for _, s := range strs { 114 | m[s] = struct{}{} 115 | } 116 | 117 | for i := 0; i < r.RowNumber(); i++ { 118 | if v, err := r.GetString(i, 0); err != nil { 119 | t.Fatal(sql, err) 120 | } else if _, ok := m[v]; !ok { 121 | t.Fatal(sql, v, "no in check strs") 122 | } else { 123 | delete(m, v) 124 | } 125 | } 126 | 127 | if len(m) != 0 { 128 | t.Fatal(sql, "invalid select") 129 | } 130 | } 131 | 132 | func TestShard_DeleteHashTable(t *testing.T) { 133 | s := `drop table if exists mixer_test_shard_hash` 134 | 135 | server := newTestServer(t) 136 | 137 | for _, n := range server.nodes { 138 | if n.String() != "node2" && n.String() != "node3" { 139 | continue 140 | } 141 | c, err := n.getMasterConn() 142 | if err != nil { 143 | t.Fatal(err) 144 | } 145 | 146 | c.UseDB("mixer") 147 | defer c.Close() 148 | if _, err := c.Execute(s); err != nil { 149 | t.Fatal(err) 150 | } 151 | 152 | } 153 | } 154 | 155 | func TestShard_CreateHashTable(t *testing.T) { 156 | s := `CREATE TABLE IF NOT EXISTS mixer_test_shard_hash ( 157 | id BIGINT(64) UNSIGNED NOT NULL, 158 | str VARCHAR(256), 159 | PRIMARY KEY (id) 160 | ) ENGINE=InnoDB DEFAULT CHARSET=utf8` 161 | 162 | server := newTestServer(t) 163 | 164 | for _, n := range server.nodes { 165 | if n.String() != "node2" && n.String() != "node3" { 166 | continue 167 | } 168 | c, err := n.getMasterConn() 169 | if err != nil { 170 | t.Fatal(err) 171 | } 172 | 173 | c.UseDB("mixer") 174 | defer c.Close() 175 | if _, err := c.Execute(s); err != nil { 176 | t.Fatal(err) 177 | } 178 | } 179 | } 180 | 181 | func TestShard_Hash(t *testing.T) { 182 | table := "mixer_test_shard_hash" 183 | testShard_Insert(t, table, "node2", 0, "a") 184 | testShard_Insert(t, table, "node3", 1, "b") 185 | testShard_Insert(t, table, "node2", 2, "c") 186 | testShard_Insert(t, table, 
"node3", 3, "d") 187 | 188 | testShard_Select(t, table, "id = 2", "c") 189 | testShard_Select(t, table, "id = 2 or id = 3", "c", "d") 190 | testShard_Select(t, table, "id = 2 and id = 3") 191 | testShard_Select(t, table, "id in (0, 1, 3)", "a", "b", "d") 192 | 193 | testShard_StmtInsert(t, table, "node2", 10, "a") 194 | testShard_StmtInsert(t, table, "node3", 11, "b") 195 | testShard_StmtInsert(t, table, "node2", 12, "c") 196 | testShard_StmtInsert(t, table, "node3", 13, "d") 197 | 198 | testShard_StmtSelect(t, table, "id = ?", []interface{}{12}, "c") 199 | testShard_StmtSelect(t, table, "id = ? or id = ?", []interface{}{12, 13}, "c", "d") 200 | testShard_StmtSelect(t, table, "id = ? and id = ?", []interface{}{12, 13}) 201 | testShard_StmtSelect(t, table, "id in (?, ?, ?)", []interface{}{10, 11, 13}, "a", "b", "d") 202 | } 203 | 204 | func testExecute(t *testing.T, sql string) *mysql.Result { 205 | conn := newTestDBConn(t) 206 | 207 | r, err := conn.Execute(sql) 208 | if err != nil { 209 | t.Fatal(err) 210 | } 211 | 212 | return r 213 | } 214 | 215 | func testShared_SelectOrderBy(t *testing.T, table string, where string, v [][]interface{}) { 216 | sql := fmt.Sprintf("select id, str from %s where %s", table, where) 217 | 218 | r := testExecute(t, sql) 219 | 220 | if !reflect.DeepEqual(r.Values, v) { 221 | t.Fatal(fmt.Sprintf("%v != %v", r.Values, v)) 222 | } 223 | } 224 | 225 | func TestShard_HashOrderByLimit(t *testing.T) { 226 | table := "mixer_test_shard_hash" 227 | 228 | testShard_Insert(t, table, "node2", 4, "a") 229 | testShard_Insert(t, table, "node3", 5, "a") 230 | testShard_Insert(t, table, "node2", 6, "b") 231 | testShard_Insert(t, table, "node3", 7, "b") 232 | 233 | var v [][]interface{} 234 | v = [][]interface{}{ 235 | []interface{}{uint64(7), []byte("b")}, 236 | []interface{}{uint64(6), []byte("b")}, 237 | []interface{}{uint64(5), []byte("a")}, 238 | []interface{}{uint64(4), []byte("a")}, 239 | } 240 | 241 | testShared_SelectOrderBy(t, table, "id in 
(4,5,6,7) order by id desc", v) 242 | 243 | v = [][]interface{}{ 244 | []interface{}{uint64(6), []byte("b")}, 245 | []interface{}{uint64(7), []byte("b")}, 246 | []interface{}{uint64(4), []byte("a")}, 247 | []interface{}{uint64(5), []byte("a")}, 248 | } 249 | 250 | testShared_SelectOrderBy(t, table, "id in (4,5,6,7) order by str desc, id asc", v) 251 | 252 | v = [][]interface{}{ 253 | []interface{}{uint64(6), []byte("b")}, 254 | []interface{}{uint64(7), []byte("b")}, 255 | } 256 | 257 | testShared_SelectOrderBy(t, table, "id in (4,5,6,7) order by str desc, id asc limit 0, 2", v) 258 | 259 | v = [][]interface{}{ 260 | []interface{}{uint64(5), []byte("a")}, 261 | } 262 | 263 | testShared_SelectOrderBy(t, table, "id in (4,5,6,7) order by str desc, id asc limit 1, 2", v) 264 | 265 | } 266 | 267 | func TestShard_DeleteRangeTable(t *testing.T) { 268 | s := `drop table if exists mixer_test_shard_range` 269 | 270 | server := newTestServer(t) 271 | 272 | for _, n := range server.nodes { 273 | if n.String() != "node2" && n.String() != "node3" { 274 | continue 275 | } 276 | c, err := n.getMasterConn() 277 | if err != nil { 278 | t.Fatal(err) 279 | } 280 | 281 | c.UseDB("mixer") 282 | defer c.Close() 283 | if _, err := c.Execute(s); err != nil { 284 | t.Fatal(err) 285 | } 286 | 287 | } 288 | } 289 | 290 | func TestShard_CreateRangeTable(t *testing.T) { 291 | s := `CREATE TABLE IF NOT EXISTS mixer_test_shard_range ( 292 | id BIGINT(64) UNSIGNED NOT NULL, 293 | str VARCHAR(256), 294 | PRIMARY KEY (id) 295 | ) ENGINE=InnoDB DEFAULT CHARSET=utf8` 296 | 297 | server := newTestServer(t) 298 | 299 | for _, n := range server.nodes { 300 | if n.String() != "node2" && n.String() != "node3" { 301 | continue 302 | } 303 | c, err := n.getMasterConn() 304 | if err != nil { 305 | t.Fatal(err) 306 | } 307 | 308 | c.UseDB("mixer") 309 | defer c.Close() 310 | if _, err := c.Execute(s); err != nil { 311 | t.Fatal(err) 312 | } 313 | 314 | } 315 | } 316 | 317 | func TestShard_Range(t *testing.T) { 
318 | table := "mixer_test_shard_range" 319 | testShard_Insert(t, table, "node2", 0, "a") 320 | testShard_Insert(t, table, "node3", 10000, "b") 321 | testShard_StmtInsert(t, table, "node2", 2, "c") 322 | testShard_StmtInsert(t, table, "node3", 10001, "d") 323 | 324 | testShard_Select(t, table, "id = 2", "c") 325 | testShard_Select(t, table, "id = 2 or id = 10001", "c", "d") 326 | testShard_Select(t, table, "id = 2 and id = 10001") 327 | testShard_Select(t, table, "id in (0, 10000, 10001)", "a", "b", "d") 328 | testShard_Select(t, table, "id < 1 or id >= 10000", "a", "b", "d") 329 | testShard_Select(t, table, "id > 1 and id <= 10000", "b", "c") 330 | testShard_Select(t, table, "id < 1 and id >= 10000") 331 | 332 | testShard_StmtSelect(t, table, "id = ?", []interface{}{2}, "c") 333 | testShard_StmtSelect(t, table, "id = ? or id = ?", []interface{}{2, 10001}, "c", "d") 334 | testShard_StmtSelect(t, table, "id = ? and id = ?", []interface{}{2, 10001}) 335 | testShard_StmtSelect(t, table, "id in (?, ?, ?)", []interface{}{0, 10000, 10001}, "a", "b", "d") 336 | testShard_StmtSelect(t, table, "id < ? or id >= ?", []interface{}{1, 10000}, "a", "b", "d") 337 | testShard_StmtSelect(t, table, "id > ? and id <= ?", []interface{}{1, 10000}, "b", "c") 338 | testShard_StmtSelect(t, table, "id < ? and id >= ?", []interface{}{1, 10000}) 339 | } 340 | -------------------------------------------------------------------------------- /proxy/conn_stmt.go: -------------------------------------------------------------------------------- 1 | package proxy 2 | 3 | import ( 4 | "encoding/binary" 5 | "fmt" 6 | . 
"github.com/siddontang/mixer/mysql" 7 | "github.com/siddontang/mixer/sqlparser" 8 | "math" 9 | "strconv" 10 | "strings" 11 | ) 12 | 13 | var paramFieldData []byte 14 | var columnFieldData []byte 15 | 16 | func init() { 17 | var p = &Field{Name: []byte("?")} 18 | var c = &Field{} 19 | 20 | paramFieldData = p.Dump() 21 | columnFieldData = c.Dump() 22 | } 23 | 24 | type Stmt struct { 25 | id uint32 26 | 27 | params int 28 | columns int 29 | 30 | args []interface{} 31 | 32 | s sqlparser.Statement 33 | 34 | sql string 35 | } 36 | 37 | func (s *Stmt) ResetParams() { 38 | s.args = make([]interface{}, s.params) 39 | } 40 | 41 | func (c *Conn) handleStmtPrepare(sql string) error { 42 | if c.schema == nil { 43 | return NewDefaultError(ER_NO_DB_ERROR) 44 | } 45 | 46 | s := new(Stmt) 47 | 48 | sql = strings.TrimRight(sql, ";") 49 | 50 | var err error 51 | s.s, err = sqlparser.Parse(sql) 52 | if err != nil { 53 | return fmt.Errorf(`parse sql "%s" error`, sql) 54 | } 55 | 56 | s.sql = sql 57 | 58 | var tableName string 59 | switch s := s.s.(type) { 60 | case *sqlparser.Select: 61 | tableName = nstring(s.From) 62 | case *sqlparser.Insert: 63 | tableName = nstring(s.Table) 64 | case *sqlparser.Update: 65 | tableName = nstring(s.Table) 66 | case *sqlparser.Delete: 67 | tableName = nstring(s.Table) 68 | case *sqlparser.Replace: 69 | tableName = nstring(s.Table) 70 | default: 71 | return fmt.Errorf(`unsupport prepare sql "%s"`, sql) 72 | } 73 | 74 | r := c.schema.rule.GetRule(tableName) 75 | 76 | n := c.server.getNode(r.Nodes[0]) 77 | 78 | if co, err := n.getMasterConn(); err != nil { 79 | return fmt.Errorf("prepare error %s", err) 80 | } else { 81 | defer co.Close() 82 | 83 | if err = co.UseDB(c.schema.db); err != nil { 84 | return fmt.Errorf("parepre error %s", err) 85 | } 86 | 87 | if t, err := co.Prepare(sql); err != nil { 88 | return fmt.Errorf("parepre error %s", err) 89 | } else { 90 | 91 | s.params = t.ParamNum() 92 | s.columns = t.ColumnNum() 93 | } 94 | } 95 | 96 | s.id = 
c.stmtId 97 | c.stmtId++ 98 | 99 | if err = c.writePrepare(s); err != nil { 100 | return err 101 | } 102 | 103 | s.ResetParams() 104 | 105 | c.stmts[s.id] = s 106 | 107 | return nil 108 | } 109 | 110 | func (c *Conn) writePrepare(s *Stmt) error { 111 | data := make([]byte, 4, 128) 112 | 113 | //status ok 114 | data = append(data, 0) 115 | //stmt id 116 | data = append(data, Uint32ToBytes(s.id)...) 117 | //number columns 118 | data = append(data, Uint16ToBytes(uint16(s.columns))...) 119 | //number params 120 | data = append(data, Uint16ToBytes(uint16(s.params))...) 121 | //filter [00] 122 | data = append(data, 0) 123 | //warning count 124 | data = append(data, 0, 0) 125 | 126 | if err := c.writePacket(data); err != nil { 127 | return err 128 | } 129 | 130 | if s.params > 0 { 131 | for i := 0; i < s.params; i++ { 132 | data = data[0:4] 133 | data = append(data, []byte(paramFieldData)...) 134 | 135 | if err := c.writePacket(data); err != nil { 136 | return err 137 | } 138 | } 139 | 140 | if err := c.writeEOF(c.status); err != nil { 141 | return err 142 | } 143 | } 144 | 145 | if s.columns > 0 { 146 | for i := 0; i < s.columns; i++ { 147 | data = data[0:4] 148 | data = append(data, []byte(columnFieldData)...) 
149 | 150 | if err := c.writePacket(data); err != nil { 151 | return err 152 | } 153 | } 154 | 155 | if err := c.writeEOF(c.status); err != nil { 156 | return err 157 | } 158 | 159 | } 160 | return nil 161 | } 162 | 163 | func (c *Conn) handleStmtExecute(data []byte) error { 164 | if len(data) < 9 { 165 | return ErrMalformPacket 166 | } 167 | 168 | pos := 0 169 | id := binary.LittleEndian.Uint32(data[0:4]) 170 | pos += 4 171 | 172 | s, ok := c.stmts[id] 173 | if !ok { 174 | return NewDefaultError(ER_UNKNOWN_STMT_HANDLER, 175 | strconv.FormatUint(uint64(id), 10), "stmt_execute") 176 | } 177 | 178 | flag := data[pos] 179 | pos++ 180 | //now we only support CURSOR_TYPE_NO_CURSOR flag 181 | if flag != 0 { 182 | return NewError(ER_UNKNOWN_ERROR, fmt.Sprintf("unsupported flag %d", flag)) 183 | } 184 | 185 | //skip iteration-count, always 1 186 | pos += 4 187 | 188 | var nullBitmaps []byte 189 | var paramTypes []byte 190 | var paramValues []byte 191 | 192 | paramNum := s.params 193 | 194 | if paramNum > 0 { 195 | nullBitmapLen := (s.params + 7) >> 3 196 | if len(data) < (pos + nullBitmapLen + 1) { 197 | return ErrMalformPacket 198 | } 199 | nullBitmaps = data[pos : pos+nullBitmapLen] 200 | pos += nullBitmapLen 201 | 202 | //new param bound flag 203 | if data[pos] == 1 { 204 | pos++ 205 | if len(data) < (pos + (paramNum << 1)) { 206 | return ErrMalformPacket 207 | } 208 | 209 | paramTypes = data[pos : pos+(paramNum<<1)] 210 | pos += (paramNum << 1) 211 | 212 | paramValues = data[pos:] 213 | } 214 | 215 | if err := c.bindStmtArgs(s, nullBitmaps, paramTypes, paramValues); err != nil { 216 | return err 217 | } 218 | } 219 | 220 | var err error 221 | 222 | switch stmt := s.s.(type) { 223 | case *sqlparser.Select: 224 | err = c.handleSelect(stmt, s.sql, s.args) 225 | case *sqlparser.Insert: 226 | err = c.handleExec(s.s, s.sql, s.args) 227 | case *sqlparser.Update: 228 | err = c.handleExec(s.s, s.sql, s.args) 229 | case *sqlparser.Delete: 230 | err = c.handleExec(s.s, s.sql, 
s.args) 231 | case *sqlparser.Replace: 232 | err = c.handleExec(s.s, s.sql, s.args) 233 | default: 234 | err = fmt.Errorf("command %T not supported now", stmt) 235 | } 236 | 237 | s.ResetParams() 238 | 239 | return err 240 | } 241 | 242 | func (c *Conn) bindStmtArgs(s *Stmt, nullBitmap, paramTypes, paramValues []byte) error { 243 | args := s.args 244 | 245 | pos := 0 246 | 247 | var v []byte 248 | var n int = 0 249 | var isNull bool 250 | var err error 251 | 252 | for i := 0; i < s.params; i++ { 253 | if nullBitmap[i>>3]&(1<<(uint(i)%8)) > 0 { 254 | args[i] = nil 255 | continue 256 | } 257 | 258 | tp := paramTypes[i<<1] 259 | isUnsigned := (paramTypes[(i<<1)+1] & 0x80) > 0 260 | 261 | switch tp { 262 | case MYSQL_TYPE_NULL: 263 | args[i] = nil 264 | continue 265 | 266 | case MYSQL_TYPE_TINY: 267 | if len(paramValues) < (pos + 1) { 268 | return ErrMalformPacket 269 | } 270 | 271 | if isUnsigned { 272 | args[i] = uint8(paramValues[pos]) 273 | } else { 274 | args[i] = int8(paramValues[pos]) 275 | } 276 | 277 | pos++ 278 | continue 279 | 280 | case MYSQL_TYPE_SHORT, MYSQL_TYPE_YEAR: 281 | if len(paramValues) < (pos + 2) { 282 | return ErrMalformPacket 283 | } 284 | 285 | if isUnsigned { 286 | args[i] = uint16(binary.LittleEndian.Uint16(paramValues[pos : pos+2])) 287 | } else { 288 | args[i] = int16((binary.LittleEndian.Uint16(paramValues[pos : pos+2]))) 289 | } 290 | pos += 2 291 | continue 292 | 293 | case MYSQL_TYPE_INT24, MYSQL_TYPE_LONG: 294 | if len(paramValues) < (pos + 4) { 295 | return ErrMalformPacket 296 | } 297 | 298 | if isUnsigned { 299 | args[i] = uint32(binary.LittleEndian.Uint32(paramValues[pos : pos+4])) 300 | } else { 301 | args[i] = int32(binary.LittleEndian.Uint32(paramValues[pos : pos+4])) 302 | } 303 | pos += 4 304 | continue 305 | 306 | case MYSQL_TYPE_LONGLONG: 307 | if len(paramValues) < (pos + 8) { 308 | return ErrMalformPacket 309 | } 310 | 311 | if isUnsigned { 312 | args[i] = binary.LittleEndian.Uint64(paramValues[pos : pos+8]) 313 | } else 
{ 314 | args[i] = int64(binary.LittleEndian.Uint64(paramValues[pos : pos+8])) 315 | } 316 | pos += 8 317 | continue 318 | 319 | case MYSQL_TYPE_FLOAT: 320 | if len(paramValues) < (pos + 4) { 321 | return ErrMalformPacket 322 | } 323 | 324 | args[i] = float32(math.Float32frombits(binary.LittleEndian.Uint32(paramValues[pos : pos+4]))) 325 | pos += 4 326 | continue 327 | 328 | case MYSQL_TYPE_DOUBLE: 329 | if len(paramValues) < (pos + 8) { 330 | return ErrMalformPacket 331 | } 332 | 333 | args[i] = math.Float64frombits(binary.LittleEndian.Uint64(paramValues[pos : pos+8])) 334 | pos += 8 335 | continue 336 | 337 | case MYSQL_TYPE_DECIMAL, MYSQL_TYPE_NEWDECIMAL, MYSQL_TYPE_VARCHAR, 338 | MYSQL_TYPE_BIT, MYSQL_TYPE_ENUM, MYSQL_TYPE_SET, MYSQL_TYPE_TINY_BLOB, 339 | MYSQL_TYPE_MEDIUM_BLOB, MYSQL_TYPE_LONG_BLOB, MYSQL_TYPE_BLOB, 340 | MYSQL_TYPE_VAR_STRING, MYSQL_TYPE_STRING, MYSQL_TYPE_GEOMETRY, 341 | MYSQL_TYPE_DATE, MYSQL_TYPE_NEWDATE, 342 | MYSQL_TYPE_TIMESTAMP, MYSQL_TYPE_DATETIME, MYSQL_TYPE_TIME: 343 | if len(paramValues) < (pos + 1) { 344 | return ErrMalformPacket 345 | } 346 | 347 | v, isNull, n, err = LengthEnodedString(paramValues[pos:]) 348 | pos += n 349 | if err != nil { 350 | return err 351 | } 352 | 353 | if !isNull { 354 | args[i] = v 355 | continue 356 | } else { 357 | args[i] = nil 358 | continue 359 | } 360 | default: 361 | return fmt.Errorf("Stmt Unknown FieldType %d", tp) 362 | } 363 | } 364 | return nil 365 | } 366 | 367 | func (c *Conn) handleStmtSendLongData(data []byte) error { 368 | if len(data) < 6 { 369 | return ErrMalformPacket 370 | } 371 | 372 | id := binary.LittleEndian.Uint32(data[0:4]) 373 | 374 | s, ok := c.stmts[id] 375 | if !ok { 376 | return NewDefaultError(ER_UNKNOWN_STMT_HANDLER, 377 | strconv.FormatUint(uint64(id), 10), "stmt_send_longdata") 378 | } 379 | 380 | paramId := binary.LittleEndian.Uint16(data[4:6]) 381 | if paramId >= uint16(s.params) { 382 | return NewDefaultError(ER_WRONG_ARGUMENTS, "stmt_send_longdata") 383 | } 384 | 
385 | if s.args[paramId] == nil { 386 | s.args[paramId] = data[6:] 387 | } else { 388 | if b, ok := s.args[paramId].([]byte); ok { 389 | b = append(b, data[6:]...) 390 | s.args[paramId] = b 391 | } else { 392 | return fmt.Errorf("invalid param long data type %T", s.args[paramId]) 393 | } 394 | } 395 | 396 | return nil 397 | } 398 | 399 | func (c *Conn) handleStmtReset(data []byte) error { 400 | if len(data) < 4 { 401 | return ErrMalformPacket 402 | } 403 | 404 | id := binary.LittleEndian.Uint32(data[0:4]) 405 | 406 | s, ok := c.stmts[id] 407 | if !ok { 408 | return NewDefaultError(ER_UNKNOWN_STMT_HANDLER, 409 | strconv.FormatUint(uint64(id), 10), "stmt_reset") 410 | } 411 | 412 | s.ResetParams() 413 | 414 | return c.writeOK(nil) 415 | } 416 | 417 | func (c *Conn) handleStmtClose(data []byte) error { 418 | if len(data) < 4 { 419 | return nil 420 | } 421 | 422 | id := binary.LittleEndian.Uint32(data[0:4]) 423 | 424 | delete(c.stmts, id) 425 | 426 | return nil 427 | } 428 | -------------------------------------------------------------------------------- /proxy/conn_query.go: -------------------------------------------------------------------------------- 1 | package proxy 2 | 3 | import ( 4 | "fmt" 5 | "github.com/siddontang/mixer/client" 6 | "github.com/siddontang/mixer/hack" 7 | . 
"github.com/siddontang/mixer/mysql" 8 | "github.com/siddontang/mixer/sqlparser" 9 | "strconv" 10 | "strings" 11 | "sync" 12 | ) 13 | 14 | func (c *Conn) handleQuery(sql string) (err error) { 15 | defer func() { 16 | if e := recover(); e != nil { 17 | err = fmt.Errorf("execute %s error %v", sql, e) 18 | return 19 | } 20 | }() 21 | 22 | sql = strings.TrimRight(sql, ";") 23 | 24 | var stmt sqlparser.Statement 25 | stmt, err = sqlparser.Parse(sql) 26 | if err != nil { 27 | return fmt.Errorf(`parse sql "%s" error`, sql) 28 | } 29 | 30 | switch v := stmt.(type) { 31 | case *sqlparser.Select: 32 | return c.handleSelect(v, sql, nil) 33 | case *sqlparser.Insert: 34 | return c.handleExec(stmt, sql, nil) 35 | case *sqlparser.Update: 36 | return c.handleExec(stmt, sql, nil) 37 | case *sqlparser.Delete: 38 | return c.handleExec(stmt, sql, nil) 39 | case *sqlparser.Replace: 40 | return c.handleExec(stmt, sql, nil) 41 | case *sqlparser.Set: 42 | return c.handleSet(v) 43 | case *sqlparser.Begin: 44 | return c.handleBegin() 45 | case *sqlparser.Commit: 46 | return c.handleCommit() 47 | case *sqlparser.Rollback: 48 | return c.handleRollback() 49 | case *sqlparser.SimpleSelect: 50 | return c.handleSimpleSelect(sql, v) 51 | case *sqlparser.Show: 52 | return c.handleShow(sql, v) 53 | case *sqlparser.Admin: 54 | return c.handleAdmin(v) 55 | default: 56 | return fmt.Errorf("statement %T not support now", stmt) 57 | } 58 | 59 | return nil 60 | } 61 | 62 | func (c *Conn) getShardList(stmt sqlparser.Statement, bindVars map[string]interface{}) ([]*Node, error) { 63 | if c.schema == nil { 64 | return nil, NewDefaultError(ER_NO_DB_ERROR) 65 | } 66 | 67 | ns, err := sqlparser.GetStmtShardList(stmt, c.schema.rule, bindVars) 68 | if err != nil { 69 | return nil, err 70 | } 71 | 72 | if len(ns) == 0 { 73 | return nil, nil 74 | } 75 | 76 | n := make([]*Node, 0, len(ns)) 77 | for _, name := range ns { 78 | n = append(n, c.server.getNode(name)) 79 | } 80 | return n, nil 81 | } 82 | 83 | func (c 
*Conn) getConn(n *Node, isSelect bool) (co *client.SqlConn, err error) { 84 | if !c.needBeginTx() { 85 | if isSelect { 86 | co, err = n.getSelectConn() 87 | } else { 88 | co, err = n.getMasterConn() 89 | } 90 | if err != nil { 91 | return 92 | } 93 | } else { 94 | var ok bool 95 | c.Lock() 96 | co, ok = c.txConns[n] 97 | c.Unlock() 98 | 99 | if !ok { 100 | if co, err = n.getMasterConn(); err != nil { 101 | return 102 | } 103 | 104 | if err = co.Begin(); err != nil { 105 | return 106 | } 107 | 108 | c.Lock() 109 | c.txConns[n] = co 110 | c.Unlock() 111 | } 112 | } 113 | 114 | //todo, set conn charset, etc... 115 | if err = co.UseDB(c.schema.db); err != nil { 116 | return 117 | } 118 | 119 | if err = co.SetCharset(c.charset); err != nil { 120 | return 121 | } 122 | 123 | return 124 | } 125 | 126 | func (c *Conn) getShardConns(isSelect bool,stmt sqlparser.Statement, bindVars map[string]interface{}) ([]*client.SqlConn, error) { 127 | nodes, err := c.getShardList(stmt, bindVars) 128 | if err != nil { 129 | return nil, err 130 | } else if nodes == nil { 131 | return nil, nil 132 | } 133 | 134 | conns := make([]*client.SqlConn, 0, len(nodes)) 135 | 136 | var co *client.SqlConn 137 | for _, n := range nodes { 138 | co, err = c.getConn(n, isSelect) 139 | if err != nil { 140 | break 141 | } 142 | 143 | conns = append(conns, co) 144 | } 145 | 146 | return conns, err 147 | } 148 | 149 | func (c *Conn) executeInShard(conns []*client.SqlConn, sql string, args []interface{}) ([]*Result, error) { 150 | var wg sync.WaitGroup 151 | wg.Add(len(conns)) 152 | 153 | rs := make([]interface{}, len(conns)) 154 | 155 | f := func(rs []interface{}, i int, co *client.SqlConn) { 156 | r, err := co.Execute(sql, args...) 
157 | if err != nil { 158 | rs[i] = err 159 | } else { 160 | rs[i] = r 161 | } 162 | 163 | wg.Done() 164 | } 165 | 166 | for i, co := range conns { 167 | go f(rs, i, co) 168 | } 169 | 170 | wg.Wait() 171 | 172 | var err error 173 | r := make([]*Result, len(conns)) 174 | for i, v := range rs { 175 | if e, ok := v.(error); ok { 176 | err = e 177 | break 178 | } 179 | r[i] = rs[i].(*Result) 180 | } 181 | 182 | return r, err 183 | } 184 | 185 | func (c *Conn) closeShardConns(conns []*client.SqlConn, rollback bool) { 186 | if c.isInTransaction() { 187 | return 188 | } 189 | 190 | for _, co := range conns { 191 | if rollback { 192 | co.Rollback() 193 | } 194 | 195 | co.Close() 196 | } 197 | } 198 | 199 | func (c *Conn) newEmptyResultset(stmt *sqlparser.Select) *Resultset { 200 | r := new(Resultset) 201 | r.Fields = make([]*Field, len(stmt.SelectExprs)) 202 | 203 | for i, expr := range stmt.SelectExprs { 204 | r.Fields[i] = &Field{} 205 | switch e := expr.(type) { 206 | case *sqlparser.StarExpr: 207 | r.Fields[i].Name = []byte("*") 208 | case *sqlparser.NonStarExpr: 209 | if e.As != nil { 210 | r.Fields[i].Name = e.As 211 | r.Fields[i].OrgName = hack.Slice(nstring(e.Expr)) 212 | } else { 213 | r.Fields[i].Name = hack.Slice(nstring(e.Expr)) 214 | } 215 | default: 216 | r.Fields[i].Name = hack.Slice(nstring(e)) 217 | } 218 | } 219 | 220 | r.Values = make([][]interface{}, 0) 221 | r.RowDatas = make([]RowData, 0) 222 | 223 | return r 224 | } 225 | 226 | func makeBindVars(args []interface{}) map[string]interface{} { 227 | bindVars := make(map[string]interface{}, len(args)) 228 | 229 | for i, v := range args { 230 | bindVars[fmt.Sprintf("v%d", i+1)] = v 231 | } 232 | 233 | return bindVars 234 | } 235 | 236 | func (c *Conn) handleSelect(stmt *sqlparser.Select, sql string, args []interface{}) error { 237 | bindVars := makeBindVars(args) 238 | 239 | conns, err := c.getShardConns(true,stmt, bindVars) 240 | if err != nil { 241 | return err 242 | } else if conns == nil { 243 | r := 
c.newEmptyResultset(stmt) 244 | return c.writeResultset(c.status, r) 245 | } 246 | 247 | var rs []*Result 248 | 249 | rs, err = c.executeInShard(conns, sql, args) 250 | 251 | c.closeShardConns(conns, false) 252 | 253 | if err == nil { 254 | err = c.mergeSelectResult(rs, stmt) 255 | } 256 | 257 | return err 258 | } 259 | 260 | func (c *Conn) beginShardConns(conns []*client.SqlConn) error { 261 | if c.isInTransaction() { 262 | return nil 263 | } 264 | 265 | for _, co := range conns { 266 | if err := co.Begin(); err != nil { 267 | return err 268 | } 269 | } 270 | 271 | return nil 272 | } 273 | 274 | func (c *Conn) commitShardConns(conns []*client.SqlConn) error { 275 | if c.isInTransaction() { 276 | return nil 277 | } 278 | 279 | for _, co := range conns { 280 | if err := co.Commit(); err != nil { 281 | return err 282 | } 283 | } 284 | 285 | return nil 286 | } 287 | 288 | func (c *Conn) handleExec(stmt sqlparser.Statement, sql string, args []interface{}) error { 289 | bindVars := makeBindVars(args) 290 | 291 | conns, err := c.getShardConns(false,stmt, bindVars) 292 | if err != nil { 293 | return err 294 | } else if conns == nil { 295 | return c.writeOK(nil) 296 | } 297 | 298 | var rs []*Result 299 | 300 | if len(conns) == 1 { 301 | rs, err = c.executeInShard(conns, sql, args) 302 | } else { 303 | //for multi nodes, 2PC simple, begin, exec, commit 304 | //if commit error, data maybe corrupt 305 | for { 306 | if err = c.beginShardConns(conns); err != nil { 307 | break 308 | } 309 | 310 | if rs, err = c.executeInShard(conns, sql, args); err != nil { 311 | break 312 | } 313 | 314 | err = c.commitShardConns(conns) 315 | break 316 | } 317 | } 318 | 319 | c.closeShardConns(conns, err != nil) 320 | 321 | if err == nil { 322 | err = c.mergeExecResult(rs) 323 | } 324 | 325 | return err 326 | } 327 | 328 | func (c *Conn) mergeExecResult(rs []*Result) error { 329 | r := new(Result) 330 | 331 | for _, v := range rs { 332 | r.Status |= v.Status 333 | r.AffectedRows += 
v.AffectedRows 334 | if r.InsertId == 0 { 335 | r.InsertId = v.InsertId 336 | } else if r.InsertId > v.InsertId { 337 | //last insert id is first gen id for multi row inserted 338 | //see http://dev.mysql.com/doc/refman/5.6/en/information-functions.html#function_last-insert-id 339 | r.InsertId = v.InsertId 340 | } 341 | } 342 | 343 | if r.InsertId > 0 { 344 | c.lastInsertId = int64(r.InsertId) 345 | } 346 | 347 | c.affectedRows = int64(r.AffectedRows) 348 | 349 | return c.writeOK(r) 350 | } 351 | 352 | func (c *Conn) mergeSelectResult(rs []*Result, stmt *sqlparser.Select) error { 353 | r := rs[0].Resultset 354 | 355 | status := c.status | rs[0].Status 356 | 357 | for i := 1; i < len(rs); i++ { 358 | status |= rs[i].Status 359 | 360 | //check fields equal 361 | 362 | for j := range rs[i].Values { 363 | r.Values = append(r.Values, rs[i].Values[j]) 364 | r.RowDatas = append(r.RowDatas, rs[i].RowDatas[j]) 365 | } 366 | } 367 | 368 | //to do order by, group by, limit offset 369 | c.sortSelectResult(r, stmt) 370 | //to do, add log here, sort may error because order by key not exist in resultset fields 371 | 372 | if err := c.limitSelectResult(r, stmt); err != nil { 373 | return err 374 | } 375 | 376 | return c.writeResultset(status, r) 377 | } 378 | 379 | func (c *Conn) sortSelectResult(r *Resultset, stmt *sqlparser.Select) error { 380 | if stmt.OrderBy == nil { 381 | return nil 382 | } 383 | 384 | sk := make([]SortKey, len(stmt.OrderBy)) 385 | 386 | for i, o := range stmt.OrderBy { 387 | sk[i].Name = nstring(o.Expr) 388 | sk[i].Direction = o.Direction 389 | } 390 | 391 | return r.Sort(sk) 392 | } 393 | 394 | func (c *Conn) limitSelectResult(r *Resultset, stmt *sqlparser.Select) error { 395 | if stmt.Limit == nil { 396 | return nil 397 | } 398 | 399 | var offset, count int64 400 | var err error 401 | if stmt.Limit.Offset == nil { 402 | offset = 0 403 | } else { 404 | if o, ok := stmt.Limit.Offset.(sqlparser.NumVal); !ok { 405 | return fmt.Errorf("invalid select limit 
%s", nstring(stmt.Limit)) 406 | } else { 407 | if offset, err = strconv.ParseInt(hack.String([]byte(o)), 10, 64); err != nil { 408 | return err 409 | } 410 | } 411 | } 412 | 413 | if o, ok := stmt.Limit.Rowcount.(sqlparser.NumVal); !ok { 414 | return fmt.Errorf("invalid limit %s", nstring(stmt.Limit)) 415 | } else { 416 | if count, err = strconv.ParseInt(hack.String([]byte(o)), 10, 64); err != nil { 417 | return err 418 | } else if count < 0 { 419 | return fmt.Errorf("invalid limit %s", nstring(stmt.Limit)) 420 | } 421 | } 422 | 423 | if offset+count > int64(len(r.Values)) { 424 | count = int64(len(r.Values)) - offset 425 | } 426 | 427 | r.Values = r.Values[offset : offset+count] 428 | r.RowDatas = r.RowDatas[offset : offset+count] 429 | 430 | return nil 431 | } 432 | --------------------------------------------------------------------------------