├── .gitignore
├── makefile
├── mysql
├── result.go
├── error.go
├── resultset_sort_test.go
├── packetio.go
├── resultset_sort.go
├── field.go
├── const.go
├── util.go
└── resultset.go
├── bootstrap.sh
├── sqlparser
├── Makefile
├── sql_test.go
├── analyzer_test.go
├── parsed_query.go
├── tracked_buffer.go
├── analyzer.go
└── parsed_query_test.go
├── dev.env
├── hack
├── hack_test.go
└── hack.go
├── LICENSE
├── etc
├── mixer_multi.conf.yaml
├── mixer_single.conf.yaml
└── mixer.conf.yaml
├── proxy
├── schema.go
├── conn_set.go
├── conn_tx.go
├── conn_admin.go
├── server.go
├── server_test.go
├── conn_select.go
├── conn_resultset.go
├── conn_show.go
├── conn_stmt_test.go
├── node.go
├── conn.go
├── conn_test.go
├── conn_shard_test.go
├── conn_stmt.go
└── conn_query.go
├── license_vitess
├── cmd
└── mixer-proxy
│ └── main.go
├── config
├── config.go
└── config_test.go
├── router
├── router_test.go
├── numkey.go
├── config.go
├── router.go
├── shard.go
└── key.go
├── client
├── conn_test.go
├── db.go
├── stmt.go
└── stmt_test.go
├── doc
└── mysql-proxy
│ └── scripting.txt
├── sqltypes
└── sqltypes.go
└── README.md
/.gitignore:
--------------------------------------------------------------------------------
1 | bin
2 | pkg
3 | .DS_Store
4 | y.output
5 |
--------------------------------------------------------------------------------
/makefile:
--------------------------------------------------------------------------------
# Top-level targets for building, cleaning and testing every Go package in
# the repository. They are declared phony so that a file or directory that
# happens to share a target's name (e.g. "test") cannot make make treat the
# target as already up to date.
.PHONY: all build clean test

all: build

# Build and install all packages and commands.
build:
	go install ./...

# Remove the artifacts that "go install" produced.
clean:
	go clean -i ./...

# Run the unit tests of all packages.
test:
	go test ./...
--------------------------------------------------------------------------------
/mysql/result.go:
--------------------------------------------------------------------------------
1 | package mysql
2 |
// Result is the outcome of executing a statement against a backend.
type Result struct {
	// Status holds server status flags (a uint16 bit set).
	Status uint16

	// InsertId and AffectedRows carry the last insert id and the number of
	// rows changed by the statement.
	InsertId     uint64
	AffectedRows uint64

	// Resultset is embedded so its fields and methods are promoted onto
	// Result. NOTE(review): presumably nil for statements that return no
	// rows — confirm against the code that builds Result.
	*Resultset
}
11 |
--------------------------------------------------------------------------------
/bootstrap.sh:
--------------------------------------------------------------------------------
#!/bin/bash

# Fetch mixer's third-party Go dependencies into the GOPATH prepared by
# dev.env. The script must be started from the directory that contains it,
# because it sources ./dev.env relative to the current directory.
[ -f bootstrap.sh ] || {
    echo "bootstrap.sh must be run from its current directory" 1>&2
    exit 1
}

. ./dev.env

for pkg in github.com/siddontang/go-log/log github.com/siddontang/go-yaml/yaml; do
    go get "$pkg"
done
--------------------------------------------------------------------------------
/sqlparser/Makefile:
--------------------------------------------------------------------------------
1 | # Copyright 2012, Google Inc. All rights reserved.
2 | # Use of this source code is governed by a BSD-style license that can
3 | # be found in the LICENSE file.
4 |
# -s: do not echo recipe commands as they run.
MAKEFLAGS = -s

# Regenerate the parser from the yacc grammar, then gofmt the output.
# NOTE(review): "go tool yacc" was removed from the Go distribution in Go 1.8;
# newer toolchains need golang.org/x/tools/cmd/goyacc instead.
sql.go: sql.y
	go tool yacc -o sql.go sql.y
	gofmt -w sql.go

# Remove the generated parser and yacc's y.output report.
clean:
	rm -f y.output sql.go
13 |
--------------------------------------------------------------------------------
/dev.env:
--------------------------------------------------------------------------------
# Development environment for mixer. Source this file from the repository
# root (bootstrap.sh does so). It derives VTROOT from the checkout location,
# warns when the checkout is not under src/github.com/siddontang/mixer, and
# prepends VTROOT to GOPATH.
export VTTOP="$(pwd)"
export VTROOT="${VTROOT:-${VTTOP/\/src\/github.com\/siddontang\/mixer/}}"
# VTTOP sanity check
if [[ "$VTTOP" == "${VTTOP/\/src\/github.com\/siddontang\/mixer/}" ]]; then
  echo "WARNING: VTTOP($VTTOP) does not contain src/github.com/siddontang/mixer"
fi

export GOTOP="$VTTOP"

# prepend_path PATHLIST DIR
# Prints PATHLIST with DIR prepended when DIR is an existing directory that is
# not already listed; otherwise prints PATHLIST unchanged. An empty PATHLIST
# yields just DIR, with no dangling ':'.
function prepend_path()
{
  # $1 path variable
  # $2 path to add
  if [ -d "$2" ] && [[ ":$1:" != *":$2:"* ]]; then
    if [ -z "$1" ]; then
      echo "$2"
    else
      echo "$2:$1"
    fi
  else
    echo "$1"
  fi
}

# Quote both arguments: unquoted, an empty GOPATH shifts VTROOT into $1, and
# paths containing spaces are split into several arguments.
export GOPATH="$(prepend_path "$GOPATH" "$VTROOT")"
22 |
--------------------------------------------------------------------------------
/hack/hack_test.go:
--------------------------------------------------------------------------------
1 | package hack
2 |
3 | import (
4 | "bytes"
5 | "testing"
6 | )
7 |
8 | func TestString(t *testing.T) {
9 | b := []byte("hello world")
10 | a := String(b)
11 |
12 | if a != "hello world" {
13 | t.Fatal(a)
14 | }
15 |
16 | b[0] = 'a'
17 |
18 | if a != "aello world" {
19 | t.Fatal(a)
20 | }
21 |
22 | b = append(b, "abc"...)
23 | if a != "aello world" {
24 | t.Fatal(a)
25 | }
26 | }
27 |
28 | func TestByte(t *testing.T) {
29 | a := "hello world"
30 |
31 | b := Slice(a)
32 |
33 | if !bytes.Equal(b, []byte("hello world")) {
34 | t.Fatal(string(b))
35 | }
36 | }
37 |
--------------------------------------------------------------------------------
/hack/hack.go:
--------------------------------------------------------------------------------
1 | package hack
2 |
3 | import (
4 | "reflect"
5 | "unsafe"
6 | )
7 |
// String converts b to a string without copying. The returned string points
// at b's backing array, so later writes to b show through it; use at your
// own risk, since Go code normally assumes strings are immutable.
func String(b []byte) (s string) {
	dst := (*reflect.StringHeader)(unsafe.Pointer(&s))
	src := (*reflect.SliceHeader)(unsafe.Pointer(&b))

	dst.Len = src.Len
	dst.Data = src.Data

	return s
}
17 |
// Slice converts s to a byte slice without copying. The slice's length and
// capacity both equal len(s) and it aliases the string's memory, so it must
// never be written to; use at your own risk.
func Slice(s string) (b []byte) {
	src := (*reflect.StringHeader)(unsafe.Pointer(&s))
	dst := (*reflect.SliceHeader)(unsafe.Pointer(&b))

	dst.Data, dst.Len, dst.Cap = src.Data, src.Len, src.Len

	return b
}
28 |
--------------------------------------------------------------------------------
/sqlparser/sql_test.go:
--------------------------------------------------------------------------------
1 | package sqlparser
2 |
3 | import (
4 | "testing"
5 | )
6 |
7 | func testParse(t *testing.T, sql string) {
8 | _, err := Parse(sql)
9 | if err != nil {
10 | t.Fatal(err)
11 | }
12 |
13 | }
14 |
15 | func TestSet(t *testing.T) {
16 | sql := "set names gbk"
17 | testParse(t, sql)
18 | }
19 |
20 | func TestSimpleSelect(t *testing.T) {
21 | sql := "select last_insert_id() as a"
22 | testParse(t, sql)
23 | }
24 |
25 | func TestMixer(t *testing.T) {
26 | sql := `admin upnode("node1", "master", "127.0.0.1")`
27 | testParse(t, sql)
28 |
29 | sql = "show databases"
30 | testParse(t, sql)
31 |
32 | sql = "show tables from abc"
33 | testParse(t, sql)
34 |
35 | sql = "show tables from abc like a"
36 | testParse(t, sql)
37 |
38 | sql = "show tables from abc where a = 1"
39 | testParse(t, sql)
40 |
41 | sql = "show proxy abc"
42 | testParse(t, sql)
43 | }
44 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | The MIT License (MIT)
2 |
3 | Copyright (c) 2014 siddontang
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy of
6 | this software and associated documentation files (the "Software"), to deal in
7 | the Software without restriction, including without limitation the rights to
8 | use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
9 | the Software, and to permit persons to whom the Software is furnished to do so,
10 | subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
17 | FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
18 | COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
19 | IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
20 | CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
21 |
--------------------------------------------------------------------------------
/etc/mixer_multi.conf.yaml:
--------------------------------------------------------------------------------
1 | addr : 127.0.0.1:4000
2 | user : root
3 | password :
4 | log_level : error
5 |
6 | nodes :
7 | -
8 | name : node1
9 | down_after_noalive : 300
10 | idle_conns : 16
11 | rw_split: false
12 | user: root
13 | password:
14 | master : 127.0.0.1:3306
15 | slave :
16 | -
17 | name : node2
18 | down_after_noalive : 300
19 | idle_conns : 16
20 | rw_split: false
21 | user: root
22 | password:
23 | master : 127.0.0.1:3307
24 |
25 | -
26 | name : node3
27 | down_after_noalive : 300
28 | idle_conns : 16
29 | rw_split: false
30 | user: root
31 | password:
32 | master : 127.0.0.1:3308
33 |
34 | schemas :
35 | -
36 | db : mixer
37 | nodes: [node1,node2,node3]
38 | rules:
39 | default: node1
40 | shard:
41 | -
42 | table: mixer_test_shard_hash
43 | key: id
44 | nodes: [node2, node3]
45 | type: hash
46 |
47 | -
48 | table: mixer_test_shard_range
49 | key: id
50 | type: range
51 | nodes: [node2, node3]
52 | range: -10000-
53 |
--------------------------------------------------------------------------------
/sqlparser/analyzer_test.go:
--------------------------------------------------------------------------------
1 | // Copyright 2012, Google Inc. All rights reserved.
2 | // Use of this source code is governed by a BSD-style
3 | // license that can be found in the LICENSE file.
4 |
5 | package sqlparser
6 |
7 | import "testing"
8 |
9 | func TestGetDBName(t *testing.T) {
10 | wantYes := []string{
11 | "insert into a.b values(1)",
12 | "update a.b set c=1",
13 | "delete from a.b where c=d",
14 | }
15 | for _, stmt := range wantYes {
16 | result, err := GetDBName(stmt)
17 | if err != nil {
18 | t.Errorf("error %v on %s", err, stmt)
19 | continue
20 | }
21 | if result != "a" {
22 | t.Errorf("want a, got %s", result)
23 | }
24 | }
25 |
26 | wantNo := []string{
27 | "insert into a values(1)",
28 | "update a set c=1",
29 | "delete from a where c=d",
30 | }
31 | for _, stmt := range wantNo {
32 | result, err := GetDBName(stmt)
33 | if err != nil {
34 | t.Errorf("error %v on %s", err, stmt)
35 | continue
36 | }
37 | if result != "" {
38 | t.Errorf("want '', got %s", result)
39 | }
40 | }
41 |
42 | wantErr := []string{
43 | "select * from a",
44 | "syntax error",
45 | }
46 | for _, stmt := range wantErr {
47 | _, err := GetDBName(stmt)
48 | if err == nil {
49 | t.Errorf("want error, got nil")
50 | }
51 | }
52 | }
53 |
--------------------------------------------------------------------------------
/etc/mixer_single.conf.yaml:
--------------------------------------------------------------------------------
1 | # Mixer configuration, using yaml format
2 |
3 | # server listen addr
4 | addr : 127.0.0.1:4000
5 |
6 | # server user and password
7 | user : root
8 | password :
9 |
10 | # log level[debug|info|warn|error],default error
11 | log_level : error
12 |
13 | # a node stands for a real remote mysql server: one master plus an optional slave.
14 | nodes :
15 | -
16 | name : node1
17 |
18 | # default max idle conns for mysql server
19 | idle_conns : 16
20 |
21 | # if rw_split is true, select will use slave server
22 | rw_split: true
23 |
24 | # all mysql in a node must have the same user and password
25 | user : root
26 | password :
27 |
28 | # master represents a real mysql master server
29 | master : 127.0.0.1:3306
30 |
31 | # slave represents a real mysql slave server (may be left empty)
32 | slave :
33 |
34 | # mark a mysql server down after it has been unreachable for N seconds
35 | # 0 means never mark it down
36 | down_after_noalive : 0
37 |
38 |
39 |
40 | # a schema defines a db that clients can use and the nodes on which that db's sql is executed
41 | schemas :
42 | -
43 | db : mixer
44 | nodes: [node1]
45 | rules:
46 | # any other table not set above will use default [node1]
47 | default: node1
48 | shard:
49 | # empty shard tables rule
50 | -
--------------------------------------------------------------------------------
/mysql/error.go:
--------------------------------------------------------------------------------
1 | package mysql
2 |
3 | import (
4 | "errors"
5 | "fmt"
6 | )
7 |
// ErrBadConn reports that a connection can no longer be used.
var ErrBadConn = errors.New("connection was bad")

// ErrMalformPacket reports a malformed protocol packet.
var ErrMalformPacket = errors.New("Malform packet error")

// ErrTxDone reports use of a transaction that was already committed or
// rolled back.
var ErrTxDone = errors.New("sql: Transaction has already been committed or rolled back")
14 |
// SqlError is a MySQL-style error: a numeric error code, a state string
// (filled from MySQLState, or DEFAULT_MYSQL_STATE, by NewError and
// NewDefaultError) and a human-readable message.
type SqlError struct {
	Code    uint16
	Message string
	State   string
}
20 |
21 | func (e *SqlError) Error() string {
22 | return fmt.Sprintf("ERROR %d (%s): %s", e.Code, e.State, e.Message)
23 | }
24 |
25 | //default mysql error, must adapt errname message format
26 | func NewDefaultError(errCode uint16, args ...interface{}) *SqlError {
27 | e := new(SqlError)
28 | e.Code = errCode
29 |
30 | if s, ok := MySQLState[errCode]; ok {
31 | e.State = s
32 | } else {
33 | e.State = DEFAULT_MYSQL_STATE
34 | }
35 |
36 | if format, ok := MySQLErrName[errCode]; ok {
37 | e.Message = fmt.Sprintf(format, args...)
38 | } else {
39 | e.Message = fmt.Sprint(args...)
40 | }
41 |
42 | return e
43 | }
44 |
45 | func NewError(errCode uint16, message string) *SqlError {
46 | e := new(SqlError)
47 | e.Code = errCode
48 |
49 | if s, ok := MySQLState[errCode]; ok {
50 | e.State = s
51 | } else {
52 | e.State = DEFAULT_MYSQL_STATE
53 | }
54 |
55 | e.Message = message
56 |
57 | return e
58 | }
59 |
--------------------------------------------------------------------------------
/proxy/schema.go:
--------------------------------------------------------------------------------
1 | package proxy
2 |
3 | import (
4 | "fmt"
5 | "github.com/siddontang/mixer/router"
6 | )
7 |
// Schema is the proxy's runtime form of one configured schema (built by
// parseSchemas).
type Schema struct {
	// db is the database name clients use for this schema.
	db string

	// nodes maps node name to node for every node listed in the schema.
	nodes map[string]*Node

	// rule is the router built from the schema's rules config; it looks up
	// the shard rule for a table.
	rule *router.Router
}
15 |
16 | func (s *Server) parseSchemas() error {
17 | s.schemas = make(map[string]*Schema)
18 |
19 | for _, schemaCfg := range s.cfg.Schemas {
20 | if _, ok := s.schemas[schemaCfg.DB]; ok {
21 | return fmt.Errorf("duplicate schema [%s].", schemaCfg.DB)
22 | }
23 | if len(schemaCfg.Nodes) == 0 {
24 | return fmt.Errorf("schema [%s] must have a node.", schemaCfg.DB)
25 | }
26 |
27 | nodes := make(map[string]*Node)
28 | for _, n := range schemaCfg.Nodes {
29 | if s.getNode(n) == nil {
30 | return fmt.Errorf("schema [%s] node [%s] config is not exists.", schemaCfg.DB, n)
31 | }
32 |
33 | if _, ok := nodes[n]; ok {
34 | return fmt.Errorf("schema [%s] node [%s] duplicate.", schemaCfg.DB, n)
35 | }
36 |
37 | nodes[n] = s.getNode(n)
38 | }
39 |
40 | rule, err := router.NewRouter(&schemaCfg)
41 | if err != nil {
42 | return err
43 | }
44 |
45 | s.schemas[schemaCfg.DB] = &Schema{
46 | db: schemaCfg.DB,
47 | nodes: nodes,
48 | rule: rule,
49 | }
50 | }
51 |
52 | return nil
53 | }
54 |
55 | func (s *Server) getSchema(db string) *Schema {
56 | return s.schemas[db]
57 | }
58 |
--------------------------------------------------------------------------------
/license_vitess:
--------------------------------------------------------------------------------
1 | Copyright 2012, Google Inc.
2 | All rights reserved.
3 |
4 | Redistribution and use in source and binary forms, with or without
5 | modification, are permitted provided that the following conditions are
6 | met:
7 |
8 | * Redistributions of source code must retain the above copyright
9 | notice, this list of conditions and the following disclaimer.
10 | * Redistributions in binary form must reproduce the above
11 | copyright notice, this list of conditions and the following disclaimer
12 | in the documentation and/or other materials provided with the
13 | distribution.
14 | * Neither the name of Google Inc. nor the names of its
15 | contributors may be used to endorse or promote products derived from
16 | this software without specific prior written permission.
17 |
18 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
19 | "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
20 | LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
21 | A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
22 | OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
23 | SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
24 | LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
25 | DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
26 | THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
27 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29 |
--------------------------------------------------------------------------------
/proxy/conn_set.go:
--------------------------------------------------------------------------------
1 | package proxy
2 |
3 | import (
4 | "fmt"
5 | . "github.com/siddontang/mixer/mysql"
6 | "github.com/siddontang/mixer/sqlparser"
7 | "strings"
8 | )
9 |
// nstring is a short alias for sqlparser.String, which renders a parsed node
// as text. It is used in error messages and to read ADMIN command arguments.
var nstring = sqlparser.String
11 |
12 | func (c *Conn) handleSet(stmt *sqlparser.Set) error {
13 | if len(stmt.Exprs) != 1 {
14 | return fmt.Errorf("must set one item once, not %s", nstring(stmt))
15 | }
16 |
17 | k := string(stmt.Exprs[0].Name.Name)
18 |
19 | switch strings.ToUpper(k) {
20 | case `AUTOCOMMIT`:
21 | return c.handleSetAutoCommit(stmt.Exprs[0].Expr)
22 | case `NAMES`:
23 | return c.handleSetNames(stmt.Exprs[0].Expr)
24 | default:
25 | return fmt.Errorf("set %s is not supported now", k)
26 | }
27 | }
28 |
29 | func (c *Conn) handleSetAutoCommit(val sqlparser.ValExpr) error {
30 | value, ok := val.(sqlparser.NumVal)
31 | if !ok {
32 | return fmt.Errorf("set autocommit error")
33 | }
34 | switch value[0] {
35 | case '1':
36 | c.status |= SERVER_STATUS_AUTOCOMMIT
37 | case '0':
38 | c.status &= ^SERVER_STATUS_AUTOCOMMIT
39 | default:
40 | return fmt.Errorf("invalid autocommit flag %s", value)
41 | }
42 |
43 | return c.writeOK(nil)
44 | }
45 |
46 | func (c *Conn) handleSetNames(val sqlparser.ValExpr) error {
47 | value, ok := val.(sqlparser.StrVal)
48 | if !ok {
49 | return fmt.Errorf("set names charset error")
50 | }
51 |
52 | charset := strings.ToLower(string(value))
53 | cid, ok := CharsetIds[charset]
54 | if !ok {
55 | return fmt.Errorf("invalid charset %s", charset)
56 | }
57 |
58 | c.charset = charset
59 | c.collation = cid
60 |
61 | return c.writeOK(nil)
62 | }
63 |
--------------------------------------------------------------------------------
/proxy/conn_tx.go:
--------------------------------------------------------------------------------
1 | package proxy
2 |
3 | import (
4 | "github.com/siddontang/mixer/client"
5 | . "github.com/siddontang/mixer/mysql"
6 | )
7 |
8 | func (c *Conn) isInTransaction() bool {
9 | return c.status&SERVER_STATUS_IN_TRANS > 0
10 | }
11 |
12 | func (c *Conn) isAutoCommit() bool {
13 | return c.status&SERVER_STATUS_AUTOCOMMIT > 0
14 | }
15 |
16 | func (c *Conn) handleBegin() error {
17 | c.status |= SERVER_STATUS_IN_TRANS
18 | return c.writeOK(nil)
19 | }
20 |
21 | func (c *Conn) handleCommit() (err error) {
22 | if err := c.commit(); err != nil {
23 | return err
24 | } else {
25 | return c.writeOK(nil)
26 | }
27 | }
28 |
29 | func (c *Conn) handleRollback() (err error) {
30 | if err := c.rollback(); err != nil {
31 | return err
32 | } else {
33 | return c.writeOK(nil)
34 | }
35 | }
36 |
37 | func (c *Conn) commit() (err error) {
38 | c.status &= ^SERVER_STATUS_IN_TRANS
39 |
40 | for _, co := range c.txConns {
41 | if e := co.Commit(); e != nil {
42 | err = e
43 | }
44 | co.Close()
45 | }
46 |
47 | c.txConns = map[*Node]*client.SqlConn{}
48 |
49 | return
50 | }
51 |
52 | func (c *Conn) rollback() (err error) {
53 | c.status &= ^SERVER_STATUS_IN_TRANS
54 |
55 | for _, co := range c.txConns {
56 | if e := co.Rollback(); e != nil {
57 | err = e
58 | }
59 | co.Close()
60 | }
61 |
62 | c.txConns = map[*Node]*client.SqlConn{}
63 |
64 | return
65 | }
66 |
67 | //if status is in_trans, need
68 | //else if status is not autocommit, need
69 | //else no need
70 | func (c *Conn) needBeginTx() bool {
71 | return c.isInTransaction() || !c.isAutoCommit()
72 | }
73 |
--------------------------------------------------------------------------------
/proxy/conn_admin.go:
--------------------------------------------------------------------------------
1 | package proxy
2 |
3 | import (
4 | "fmt"
5 | "github.com/siddontang/mixer/sqlparser"
6 | "strings"
7 | )
8 |
9 | func (c *Conn) handleAdmin(admin *sqlparser.Admin) error {
10 | name := string(admin.Name)
11 |
12 | var err error
13 | switch strings.ToLower(name) {
14 | case "upnode":
15 | err = c.adminUpNodeServer(admin.Values)
16 | case "downnode":
17 | err = c.adminDownNodeServer(admin.Values)
18 | default:
19 | return fmt.Errorf("admin %s not supported now", name)
20 | }
21 |
22 | if err != nil {
23 | return err
24 | }
25 |
26 | return c.writeOK(nil)
27 | }
28 |
29 | func (c *Conn) adminUpNodeServer(values sqlparser.ValExprs) error {
30 | if len(values) != 3 {
31 | return fmt.Errorf("upnode needs 3 args, not %d", len(values))
32 | }
33 |
34 | nodeName := nstring(values[0])
35 | sType := strings.ToLower(nstring(values[1]))
36 | addr := strings.ToLower(nstring(values[2]))
37 |
38 | switch sType {
39 | case Master:
40 | return c.server.UpMaster(nodeName, addr)
41 | case Slave:
42 | return c.server.UpSlave(nodeName, addr)
43 | default:
44 | return fmt.Errorf("invalid server type %s", sType)
45 | }
46 | }
47 |
48 | func (c *Conn) adminDownNodeServer(values sqlparser.ValExprs) error {
49 | if len(values) != 2 {
50 | return fmt.Errorf("upnode needs 2 args, not %d", len(values))
51 | }
52 |
53 | nodeName := nstring(values[0])
54 | sType := strings.ToLower(nstring(values[1]))
55 |
56 | switch sType {
57 | case Master:
58 | return c.server.DownMaster(nodeName)
59 | case Slave:
60 | return c.server.DownSlave(nodeName)
61 | default:
62 | return fmt.Errorf("invalid server type %s", sType)
63 | }
64 | }
65 |
--------------------------------------------------------------------------------
/cmd/mixer-proxy/main.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | "flag"
5 | "github.com/siddontang/go-log/log"
6 | "github.com/siddontang/mixer/config"
7 | "github.com/siddontang/mixer/proxy"
8 | "os"
9 | "os/signal"
10 | "runtime"
11 | "strings"
12 | "syscall"
13 | )
14 |
15 | var configFile *string = flag.String("config", "/etc/mixer.conf", "mixer proxy config file")
16 | var logLevel *string = flag.String("log-level", "", "log level [debug|info|warn|error], default error")
17 |
18 | func main() {
19 | runtime.GOMAXPROCS(runtime.NumCPU())
20 |
21 | flag.Parse()
22 |
23 | if len(*configFile) == 0 {
24 | log.Error("must use a config file")
25 | return
26 | }
27 |
28 | cfg, err := config.ParseConfigFile(*configFile)
29 | if err != nil {
30 | log.Error(err.Error())
31 | return
32 | }
33 |
34 | if *logLevel != "" {
35 | setLogLevel(*logLevel)
36 | } else {
37 | setLogLevel(cfg.LogLevel)
38 | }
39 |
40 | var svr *proxy.Server
41 | svr, err = proxy.NewServer(cfg)
42 | if err != nil {
43 | log.Error(err.Error())
44 | return
45 | }
46 |
47 | sc := make(chan os.Signal, 1)
48 | signal.Notify(sc,
49 | syscall.SIGHUP,
50 | syscall.SIGINT,
51 | syscall.SIGTERM,
52 | syscall.SIGQUIT)
53 |
54 | go func() {
55 | sig := <-sc
56 | log.Info("Got signal [%d] to exit.", sig)
57 | svr.Close()
58 | }()
59 |
60 | svr.Run()
61 | }
62 |
63 | func setLogLevel(level string) {
64 | switch strings.ToLower(level) {
65 | case "debug":
66 | log.SetLevel(log.LevelDebug)
67 | case "info":
68 | log.SetLevel(log.LevelInfo)
69 | case "warn":
70 | log.SetLevel(log.LevelWarn)
71 | case "error":
72 | log.SetLevel(log.LevelError)
73 | default:
74 | log.SetLevel(log.LevelError)
75 | }
76 | }
77 |
--------------------------------------------------------------------------------
/config/config.go:
--------------------------------------------------------------------------------
1 | package config
2 |
3 | import (
4 | "github.com/siddontang/go-yaml/yaml"
5 | "io/ioutil"
6 | )
7 |
// NodeConfig describes one backend node (YAML "nodes" entry): a MySQL master
// and an optional slave. Per the sample config, both share one user and
// password.
type NodeConfig struct {
	Name string `yaml:"name"`
	// DownAfterNoAlive: per the sample config comments, the number of seconds
	// a server may stay unreachable before it is marked down; 0 never marks it
	// down.
	DownAfterNoAlive int `yaml:"down_after_noalive"`
	// IdleConns: per the sample config, the max idle connections kept per
	// MySQL server.
	IdleConns int `yaml:"idle_conns"`
	// RWSplit: per the sample config, when true SELECTs go to the slave.
	RWSplit bool `yaml:"rw_split"`

	User     string `yaml:"user"`
	Password string `yaml:"password"`

	// Master and Slave are host:port addresses; Slave may be left empty.
	Master string `yaml:"master"`
	Slave  string `yaml:"slave"`
}
20 |
// SchemaConfig describes one schema (YAML "schemas" entry): the db clients
// use, the nodes it may run on, and its routing rules.
type SchemaConfig struct {
	DB    string   `yaml:"db"`
	Nodes []string `yaml:"nodes"`
	// RulesConifg is misspelled but exported; renaming it would break
	// existing callers, so the spelling is kept.
	RulesConifg RulesConfig `yaml:"rules"`
}
26 |
// RulesConfig holds a schema's routing rules: the default node name for
// tables without a shard rule, and the per-table shard rules.
type RulesConfig struct {
	Default   string        `yaml:"default"`
	ShardRule []ShardConfig `yaml:"shard"`
}
31 |
// ShardConfig is one shard rule: Table is split across Nodes by the value of
// the Key column.
type ShardConfig struct {
	Table string   `yaml:"table"`
	Key   string   `yaml:"key"`
	Nodes []string `yaml:"nodes"`
	// Type is "hash" or "range" in the sample configs.
	Type string `yaml:"type"`
	// Range is a range spec such as "-10000-" used by range rules
	// (presumably parsed by router.ParseNumShardingSpec — confirm).
	Range string `yaml:"range"`
}
39 |
// Config is the root of the mixer YAML configuration.
type Config struct {
	// Addr is the proxy's listen address.
	Addr string `yaml:"addr"`
	// User and Password are the proxy server's own credentials.
	User     string `yaml:"user"`
	Password string `yaml:"password"`
	// LogLevel is debug, info, warn or error; anything else means error.
	LogLevel string `yaml:"log_level"`

	Nodes []NodeConfig `yaml:"nodes"`

	Schemas []SchemaConfig `yaml:"schemas"`
}
50 |
51 | func ParseConfigData(data []byte) (*Config, error) {
52 | var cfg Config
53 | if err := yaml.Unmarshal([]byte(data), &cfg); err != nil {
54 | return nil, err
55 | }
56 | return &cfg, nil
57 | }
58 |
59 | func ParseConfigFile(fileName string) (*Config, error) {
60 | data, err := ioutil.ReadFile(fileName)
61 | if err != nil {
62 | return nil, err
63 | }
64 |
65 | return ParseConfigData(data)
66 | }
67 |
--------------------------------------------------------------------------------
/etc/mixer.conf.yaml:
--------------------------------------------------------------------------------
1 | # Mixer configuration, using yaml format
2 |
3 | # server listen addr
4 | addr : 127.0.0.1:4000
5 |
6 | # server user and password
7 | user : root
8 | password :
9 |
10 | # log level[debug|info|warn|error],default error
11 | log_level : error
12 |
13 | # a node stands for a real remote mysql server: one master plus an optional slave.
14 | nodes :
15 | -
16 | name : node1
17 |
18 | # default max idle conns for mysql server
19 | idle_conns : 16
20 |
21 | # if rw_split is true, select will use slave server
22 | rw_split: true
23 |
24 | # all mysql in a node must have the same user and password
25 | user : root
26 | password:
27 |
28 | # master represents a real mysql master server
29 | master : 127.0.0.1:3306
30 |
31 | # slave represents a real mysql slave server
32 | slave : 127.0.0.1:4306
33 |
34 | # mark a mysql server down after it has been unreachable for N seconds
35 | # 0 means never mark it down
36 | down_after_noalive : 300
37 |
38 | -
39 | name : node2
40 | user: root
41 | password:
42 | master : 127.0.0.1:3308
43 |
44 |
45 | # a schema defines a db that clients can use and the nodes on which that db's sql is executed
46 | schemas :
47 | -
48 | db : mixer
49 | nodes: [node1, node2]
50 |
51 | # rules define how sql is routed to the nodes
52 | rules:
53 | # any table without a shard rule below uses the default node [node1]
54 | default: node1
55 | shard:
56 | -
57 | table: test1
58 | key: id
59 | type: hash
60 | # node will be node1, node2
61 | nodes: [node1, node2]
62 |
63 | -
64 | table: test2
65 | key: name
66 | nodes: [node1, node2]
67 | type: range
68 | # ranges are left-closed and right-open
69 | # node1 range (-inf, 10000)
70 | # node2 range [10000, +inf)
71 | range: -10000-
72 |
--------------------------------------------------------------------------------
/router/router_test.go:
--------------------------------------------------------------------------------
1 | package router
2 |
3 | import (
4 | "github.com/siddontang/go-yaml/yaml"
5 | "github.com/siddontang/mixer/config"
6 | "testing"
7 | )
8 |
// TestParseRule decodes a schema with one hash and one range shard rule,
// builds a Router from it, and checks rule lookup and node selection,
// including the fallback to the default rule for a table with no shard rule.
func TestParseRule(t *testing.T) {
	var s = `
schemas :
-
db : mixer
nodes: [node1, node2, node3]
rules:
default: node1
shard:
-
table: mixer_test_shard_hash
key: id
nodes: [node2, node3]
type: hash

-
table: mixer_test_shard_range
key: id
type: range
nodes: [node2, node3]
range: -10000-
`
	var cfg config.Config
	if err := yaml.Unmarshal([]byte(s), &cfg); err != nil {
		t.Fatal(err)
	}

	rt, err := NewRouter(&cfg.Schemas[0])
	if err != nil {
		t.Fatal(err)
	}
	if rt.DefaultRule.Nodes[0] != "node1" {
		t.Fatal("default rule parse not correct.")
	}

	hashRule := rt.GetRule("mixer_test_shard_hash")
	if hashRule.Type != HashRuleType {
		t.Fatal(hashRule.Type)
	}

	if len(hashRule.Nodes) != 2 || hashRule.Nodes[0] != "node2" || hashRule.Nodes[1] != "node3" {
		t.Fatal("parse nodes not correct.")
	}

	// The test expects key 11 to hash onto the second node (node3).
	if n := hashRule.FindNode(uint64(11)); n != "node3" {
		t.Fatal(n)
	}

	rangeRule := rt.GetRule("mixer_test_shard_range")
	if rangeRule.Type != RangeRuleType {
		t.Fatal(rangeRule.Type)
	}

	// "-10000-" yields [MinNumKey, 10000) and [10000, MaxNumKey]; 9999 falls
	// in the first range, i.e. node2.
	if n := rangeRule.FindNode(10000 - 1); n != "node2" {
		t.Fatal(n)
	}

	// A table with no shard rule gets the default rule, which routes to node1.
	defaultRule := rt.GetRule("mixer_defaultRule_table")
	if defaultRule == nil {
		t.Fatal("must not nil")
	}

	if defaultRule.Type != DefaultRuleType {
		t.Fatal(defaultRule.Type)
	}

	if defaultRule.Shard == nil {
		t.Fatal("nil error")
	}

	if n := defaultRule.FindNode(11); n != "node1" {
		t.Fatal(n)
	}
}
83 |
--------------------------------------------------------------------------------
/router/numkey.go:
--------------------------------------------------------------------------------
1 | package router
2 |
3 | import (
4 | "fmt"
5 | "math"
6 | "strconv"
7 | "strings"
8 | )
9 |
// MinNumKey is the implicit lower bound of a sharding spec that starts with
// an empty bound (e.g. "-10000").
const MinNumKey = math.MinInt64

// MaxNumKey is the implicit upper bound of a sharding spec that ends with an
// empty bound (e.g. "10000-").
const MaxNumKey = math.MaxInt64
14 |
// NumKeyRange is a numeric shard key range covering [Start, End). When End
// is MaxNumKey, Contains treats the range as closed on the right.
type NumKeyRange struct {
	Start int64
	End   int64
}
19 |
20 | func (kr NumKeyRange) MapKey() string {
21 | return fmt.Sprintf("%d-%d", kr.String(), kr.End)
22 | }
23 |
24 | func (kr NumKeyRange) Contains(i int64) bool {
25 | return kr.Start <= i && (kr.End == MaxNumKey || i < kr.End)
26 | }
27 |
28 | func (kr NumKeyRange) String() string {
29 | return fmt.Sprintf("{Start: %d, End: %d}", kr.Start, kr.End)
30 | }
31 |
32 | // ParseShardingSpec parses a string that describes a sharding
33 | // specification. a-b-c-d will be parsed as a-b, b-c, c-d. The empty
34 | // string may serve both as the start and end of the keyspace: -a-b-
35 | // will be parsed as start-a, a-b, b-end.
36 | func ParseNumShardingSpec(spec string) ([]NumKeyRange, error) {
37 | parts := strings.Split(spec, "-")
38 | if len(parts) == 1 {
39 | return nil, fmt.Errorf("malformed spec: doesn't define a range: %q", spec)
40 | }
41 | var old int64
42 | var err error
43 | if len(parts[0]) != 0 {
44 | old, err = strconv.ParseInt(parts[0], 10, 64)
45 | if err != nil {
46 | return nil, err
47 | }
48 | } else {
49 | old = MinNumKey
50 | }
51 |
52 | ranges := make([]NumKeyRange, len(parts)-1)
53 |
54 | var n int64
55 |
56 | for i, p := range parts[1:] {
57 | if p == "" && i != (len(parts)-2) {
58 | return nil, fmt.Errorf("malformed spec: MinKey/MaxKey cannot be in the middle of the spec: %q", spec)
59 | }
60 |
61 | if p != "" {
62 | n, err = strconv.ParseInt(p, 10, 64)
63 | if err != nil {
64 | return nil, err
65 | }
66 | } else {
67 | n = MaxNumKey
68 | }
69 | if n <= old {
70 | return nil, fmt.Errorf("malformed spec: shard limits should be in order: %q", spec)
71 | }
72 |
73 | ranges[i] = NumKeyRange{Start: old, End: n}
74 | old = n
75 | }
76 | return ranges, nil
77 | }
78 |
--------------------------------------------------------------------------------
/proxy/server.go:
--------------------------------------------------------------------------------
1 | package proxy
2 |
3 | import (
4 | "github.com/siddontang/go-log/log"
5 | "github.com/siddontang/mixer/config"
6 |
7 | "net"
8 | "runtime"
9 | "strings"
10 | )
11 |
// Server is the mixer proxy server. It accepts MySQL-protocol client
// connections on addr and serves each one on its own goroutine.
type Server struct {
	cfg *config.Config

	// addr, user and password are copied from cfg by NewServer.
	addr     string
	user     string
	password string

	// running is set by Run and cleared by Close to stop the accept loop.
	// NOTE(review): it is read and written from different goroutines
	// without synchronization.
	running bool

	listener net.Listener

	// nodes maps node name to node (presumably filled by parseNodes).
	nodes map[string]*Node

	// schemas maps db name to schema (filled by parseSchemas).
	schemas map[string]*Schema
}
27 |
28 | func NewServer(cfg *config.Config) (*Server, error) {
29 | s := new(Server)
30 |
31 | s.cfg = cfg
32 |
33 | s.addr = cfg.Addr
34 | s.user = cfg.User
35 | s.password = cfg.Password
36 |
37 | if err := s.parseNodes(); err != nil {
38 | return nil, err
39 | }
40 |
41 | if err := s.parseSchemas(); err != nil {
42 | return nil, err
43 | }
44 |
45 | var err error
46 | netProto := "tcp"
47 | if strings.Contains(netProto, "/") {
48 | netProto = "unix"
49 | }
50 | s.listener, err = net.Listen(netProto, s.addr)
51 |
52 | if err != nil {
53 | return nil, err
54 | }
55 |
56 | log.Info("Server run MySql Protocol Listen(%s) at [%s]", netProto, s.addr)
57 | return s, nil
58 | }
59 |
60 | func (s *Server) Run() error {
61 | s.running = true
62 |
63 | for s.running {
64 | conn, err := s.listener.Accept()
65 | if err != nil {
66 | log.Error("accept error %s", err.Error())
67 | continue
68 | }
69 |
70 | go s.onConn(conn)
71 | }
72 |
73 | return nil
74 | }
75 |
76 | func (s *Server) Close() {
77 | s.running = false
78 | if s.listener != nil {
79 | s.listener.Close()
80 | }
81 | }
82 |
83 | func (s *Server) onConn(c net.Conn) {
84 | conn := s.newConn(c)
85 |
86 | defer func() {
87 | if err := recover(); err != nil {
88 | const size = 4096
89 | buf := make([]byte, size)
90 | buf = buf[:runtime.Stack(buf, false)]
91 | log.Error("onConn panic %v: %v\n%s", c.RemoteAddr().String(), err, buf)
92 | }
93 |
94 | conn.Close()
95 | }()
96 |
97 | if err := conn.Handshake(); err != nil {
98 | log.Error("handshake error %s", err.Error())
99 | c.Close()
100 | return
101 | }
102 |
103 | conn.Run()
104 |
105 | }
106 |
--------------------------------------------------------------------------------
/mysql/resultset_sort_test.go:
--------------------------------------------------------------------------------
1 | package mysql
2 |
3 | import (
4 | "fmt"
5 | "reflect"
6 | "sort"
7 | "testing"
8 | )
9 |
10 | func TestResultsetSort(t *testing.T) {
11 | r1 := new(Resultset)
12 | r2 := new(Resultset)
13 |
14 | r1.Values = [][]interface{}{
15 | []interface{}{int64(1), "a", []byte("aa")},
16 | []interface{}{int64(2), "a", []byte("bb")},
17 | []interface{}{int64(3), "c", []byte("bb")},
18 | }
19 |
20 | r1.RowDatas = []RowData{
21 | RowData([]byte("1")),
22 | RowData([]byte("2")),
23 | RowData([]byte("3")),
24 | }
25 |
26 | s := new(resultsetSorter)
27 |
28 | s.Resultset = r1
29 |
30 | s.sk = []SortKey{
31 | SortKey{column: 0, Direction: SortDesc},
32 | }
33 |
34 | sort.Sort(s)
35 |
36 | r2.Values = [][]interface{}{
37 | []interface{}{int64(3), "c", []byte("bb")},
38 | []interface{}{int64(2), "a", []byte("bb")},
39 | []interface{}{int64(1), "a", []byte("aa")},
40 | }
41 |
42 | r2.RowDatas = []RowData{
43 | RowData([]byte("3")),
44 | RowData([]byte("2")),
45 | RowData([]byte("1")),
46 | }
47 |
48 | if !reflect.DeepEqual(r1, r2) {
49 | t.Fatal(fmt.Sprintf("%v %v", r1, r2))
50 | }
51 |
52 | s.sk = []SortKey{
53 | SortKey{column: 1, Direction: SortAsc},
54 | SortKey{column: 2, Direction: SortDesc},
55 | }
56 |
57 | sort.Sort(s)
58 |
59 | r2.Values = [][]interface{}{
60 | []interface{}{int64(2), "a", []byte("bb")},
61 | []interface{}{int64(1), "a", []byte("aa")},
62 | []interface{}{int64(3), "c", []byte("bb")},
63 | }
64 |
65 | r2.RowDatas = []RowData{
66 | RowData([]byte("2")),
67 | RowData([]byte("1")),
68 | RowData([]byte("3")),
69 | }
70 |
71 | if !reflect.DeepEqual(r1, r2) {
72 | t.Fatal(fmt.Sprintf("%v %v", r1, r2))
73 | }
74 |
75 | s.sk = []SortKey{
76 | SortKey{column: 1, Direction: SortAsc},
77 | SortKey{column: 2, Direction: SortAsc},
78 | }
79 |
80 | sort.Sort(s)
81 |
82 | r2.Values = [][]interface{}{
83 | []interface{}{int64(1), "a", []byte("aa")},
84 | []interface{}{int64(2), "a", []byte("bb")},
85 | []interface{}{int64(3), "c", []byte("bb")},
86 | }
87 |
88 | r2.RowDatas = []RowData{
89 | RowData([]byte("1")),
90 | RowData([]byte("2")),
91 | RowData([]byte("3")),
92 | }
93 |
94 | if !reflect.DeepEqual(r1, r2) {
95 | t.Fatal(fmt.Sprintf("%v %v", r1, r2))
96 | }
97 |
98 | }
99 |
--------------------------------------------------------------------------------
/mysql/packetio.go:
--------------------------------------------------------------------------------
1 | package mysql
2 |
3 | import (
4 | "bufio"
5 | "fmt"
6 | "io"
7 | "net"
8 | )
9 |
// PacketIO frames MySQL protocol packets on a connection. Each wire packet is
// a 4-byte header followed by the payload. The header holds a 3-byte
// little-endian payload length and a 1-byte sequence id.
type PacketIO struct {
	// rb buffers reads from the connection.
	rb *bufio.Reader
	// wb receives writes directly, without buffering (NewPacketIO sets it to
	// the connection itself).
	wb io.Writer

	// Sequence is the id expected on the next packet read and stamped on the
	// next packet written; every packet read or written increments it.
	Sequence uint8
}
16 |
17 | func NewPacketIO(conn net.Conn) *PacketIO {
18 | p := new(PacketIO)
19 |
20 | p.rb = bufio.NewReaderSize(conn, 1024)
21 | p.wb = conn
22 |
23 | p.Sequence = 0
24 |
25 | return p
26 | }
27 |
28 | func (p *PacketIO) ReadPacket() ([]byte, error) {
29 | header := []byte{0, 0, 0, 0}
30 |
31 | if _, err := io.ReadFull(p.rb, header); err != nil {
32 | return nil, ErrBadConn
33 | }
34 |
35 | length := int(uint32(header[0]) | uint32(header[1])<<8 | uint32(header[2])<<16)
36 | if length < 1 {
37 | return nil, fmt.Errorf("invalid payload length %d", length)
38 | }
39 |
40 | sequence := uint8(header[3])
41 |
42 | if sequence != p.Sequence {
43 | return nil, fmt.Errorf("invalid sequence %d != %d", sequence, p.Sequence)
44 | }
45 |
46 | p.Sequence++
47 |
48 | data := make([]byte, length)
49 | if _, err := io.ReadFull(p.rb, data); err != nil {
50 | return nil, ErrBadConn
51 | } else {
52 | if length < MaxPayloadLen {
53 | return data, nil
54 | }
55 |
56 | var buf []byte
57 | buf, err = p.ReadPacket()
58 | if err != nil {
59 | return nil, ErrBadConn
60 | } else {
61 | return append(data, buf...), nil
62 | }
63 | }
64 | }
65 |
66 | //data already have header
67 | func (p *PacketIO) WritePacket(data []byte) error {
68 | length := len(data) - 4
69 |
70 | for length >= MaxPayloadLen {
71 |
72 | data[0] = 0xff
73 | data[1] = 0xff
74 | data[2] = 0xff
75 |
76 | data[3] = p.Sequence
77 |
78 | if n, err := p.wb.Write(data[:4+MaxPayloadLen]); err != nil {
79 | return ErrBadConn
80 | } else if n != (4 + MaxPayloadLen) {
81 | return ErrBadConn
82 | } else {
83 | p.Sequence++
84 | length -= MaxPayloadLen
85 | data = data[MaxPayloadLen:]
86 | }
87 | }
88 |
89 | data[0] = byte(length)
90 | data[1] = byte(length >> 8)
91 | data[2] = byte(length >> 16)
92 | data[3] = p.Sequence
93 |
94 | if n, err := p.wb.Write(data); err != nil {
95 | return ErrBadConn
96 | } else if n != len(data) {
97 | return ErrBadConn
98 | } else {
99 | p.Sequence++
100 | return nil
101 | }
102 | }
103 |
--------------------------------------------------------------------------------
/proxy/server_test.go:
--------------------------------------------------------------------------------
1 | package proxy
2 |
3 | import (
4 | "github.com/siddontang/mixer/client"
5 | "github.com/siddontang/mixer/config"
6 | "sync"
7 | "testing"
8 | "time"
9 | )
10 |
11 | var testServerOnce sync.Once
12 | var testServer *Server
13 | var testDBOnce sync.Once
14 | var testDB *client.DB
15 |
// testConfigData is the proxy configuration used by the tests.
//   - The proxy listens on 127.0.0.1:4000 and fronts three MySQL backends on
//     ports 3306-3308.
//   - In schema "mixer", mixer_test_shard_hash is routed by hash across
//     node2/node3.
//   - mixer_test_shard_range is routed by id range, split at 10000, across
//     node2/node3.
//   - Every other table goes to the default node, node1.
var testConfigData = []byte(`
addr : 127.0.0.1:4000
user : root
password :

nodes :
-
name : node1
down_after_noalive : 300
idle_conns : 16
rw_split: false
user: root
password:
master : 127.0.0.1:3306
slave :
-
name : node2
down_after_noalive : 300
idle_conns : 16
rw_split: false
user: root
password:
master : 127.0.0.1:3307

-
name : node3
down_after_noalive : 300
idle_conns : 16
rw_split: false
user: root
password:
master : 127.0.0.1:3308

schemas :
-
db : mixer
nodes: [node1, node2, node3]
rules:
default: node1
shard:
-
table: mixer_test_shard_hash
key: id
nodes: [node2, node3]
type: hash

-
table: mixer_test_shard_range
key: id
nodes: [node2, node3]
range: -10000-
type: range
`)
69 |
70 | func newTestServer(t *testing.T) *Server {
71 | f := func() {
72 | cfg, err := config.ParseConfigData(testConfigData)
73 | if err != nil {
74 | t.Fatal(err.Error())
75 | }
76 |
77 | testServer, err = NewServer(cfg)
78 | if err != nil {
79 | t.Fatal(err)
80 | }
81 |
82 | go testServer.Run()
83 |
84 | time.Sleep(1 * time.Second)
85 | }
86 |
87 | testServerOnce.Do(f)
88 |
89 | return testServer
90 | }
91 |
92 | func newTestDB(t *testing.T) *client.DB {
93 | newTestServer(t)
94 |
95 | f := func() {
96 | var err error
97 | testDB, err = client.Open("127.0.0.1:4000", "root", "", "mixer")
98 |
99 | if err != nil {
100 | t.Fatal(err)
101 | }
102 |
103 | testDB.SetMaxIdleConnNum(4)
104 | }
105 |
106 | testDBOnce.Do(f)
107 | return testDB
108 | }
109 |
110 | func newTestDBConn(t *testing.T) *client.SqlConn {
111 | db := newTestDB(t)
112 |
113 | c, err := db.GetConn()
114 |
115 | if err != nil {
116 | t.Fatal(err)
117 | }
118 |
119 | return c
120 | }
121 |
122 | func TestServer(t *testing.T) {
123 | newTestServer(t)
124 | }
125 |
--------------------------------------------------------------------------------
/router/config.go:
--------------------------------------------------------------------------------
1 | package router
2 |
3 | import (
4 | "fmt"
5 | "github.com/siddontang/mixer/config"
6 | "regexp"
7 | "strconv"
8 | "strings"
9 | )
10 |
// Rule type names used in a shard rule's "type" field. parseShard maps any
// other value to a DefaultShard.
var (
	// HashRuleType selects a HashShard over the rule's nodes.
	HashRuleType = "hash"
	// RangeRuleType selects a NumRangeShard built from the rule's range spec.
	RangeRuleType = "range"
	// DefaultRuleType is the type of a schema's single fallback rule
	// (see NewDefaultRule).
	DefaultRuleType = "default"
)
16 |
// RuleConfig wraps one shard rule from the configuration file and turns it
// into a routing Rule via ParseRule.
type RuleConfig struct {
	config.ShardConfig
}
20 |
21 | func (c *RuleConfig) ParseRule(db string) (*Rule, error) {
22 | r := new(Rule)
23 | r.DB = db
24 | r.Table = c.Table
25 | r.Key = c.Key
26 | r.Type = c.Type
27 | r.Nodes = c.Nodes
28 |
29 | if err := c.parseShard(r); err != nil {
30 | return nil, err
31 | }
32 |
33 | return r, nil
34 | }
35 |
36 | func (c *RuleConfig) parseNodes(r *Rule) error {
37 | // Note: did not used yet, by HuangChuanTong
38 | reg, err := regexp.Compile(`(\w+)\((\d+)\-(\d+)\)`)
39 | if err != nil {
40 | return err
41 | }
42 |
43 | ns := c.Nodes // strings.Split(c.Nodes, ",")
44 |
45 | nodes := map[string]struct{}{}
46 |
47 | for _, n := range ns {
48 | n = strings.TrimSpace(n)
49 | if s := reg.FindStringSubmatch(n); s == nil {
50 | if _, ok := nodes[n]; ok {
51 | return fmt.Errorf("duplicate node %s", n)
52 | }
53 |
54 | nodes[n] = struct{}{}
55 | r.Nodes = append(r.Nodes, n)
56 | } else {
57 | var start, stop int
58 | if start, err = strconv.Atoi(s[2]); err != nil {
59 | return err
60 | }
61 |
62 | if stop, err = strconv.Atoi(s[3]); err != nil {
63 | return err
64 | }
65 |
66 | if start >= stop {
67 | return fmt.Errorf("invalid node format %s", n)
68 | }
69 |
70 | for i := start; i <= stop; i++ {
71 | n = fmt.Sprintf("%s%d", s[1], i)
72 |
73 | if _, ok := nodes[n]; ok {
74 | return fmt.Errorf("duplicate node %s", n)
75 | }
76 |
77 | nodes[n] = struct{}{}
78 | r.Nodes = append(r.Nodes, n)
79 |
80 | }
81 | }
82 | }
83 |
84 | if len(r.Nodes) == 0 {
85 | return fmt.Errorf("empty nodes info")
86 | }
87 |
88 | if r.Type == DefaultRuleType && len(r.Nodes) != 1 {
89 | return fmt.Errorf("default rule must have only one node")
90 | }
91 |
92 | return nil
93 | }
94 |
95 | func (c *RuleConfig) parseShard(r *Rule) error {
96 | if r.Type == HashRuleType {
97 | //hash shard
98 | r.Shard = &HashShard{ShardNum: len(r.Nodes)}
99 | } else if r.Type == RangeRuleType {
100 | rs, err := ParseNumShardingSpec(c.Range)
101 | if err != nil {
102 | return err
103 | }
104 |
105 | if len(rs) != len(r.Nodes) {
106 | return fmt.Errorf("range space %d not equal nodes %d", len(rs), len(r.Nodes))
107 | }
108 |
109 | r.Shard = &NumRangeShard{Shards: rs}
110 | } else {
111 | r.Shard = &DefaultShard{}
112 | }
113 |
114 | return nil
115 | }
116 |
--------------------------------------------------------------------------------
/router/router.go:
--------------------------------------------------------------------------------
1 | package router
2 |
3 | import (
4 | "fmt"
5 | "github.com/siddontang/mixer/config"
6 | "strings"
7 | )
8 |
// Rule describes how the rows of one table are routed to backend nodes.
type Rule struct {
	// DB and Table identify the routed table. Table is empty for the
	// schema-wide default rule built by NewDefaultRule.
	DB    string
	Table string
	// Key names the sharding column (the shard config's "key", e.g. "id").
	Key string

	// Type is the rule type, e.g. HashRuleType, RangeRuleType or
	// DefaultRuleType.
	Type string

	// Nodes lists the candidate node names. Shard.FindForKey returns an index
	// into this slice.
	Nodes []string
	Shard Shard
}
19 |
20 | func (r *Rule) FindNode(key interface{}) string {
21 | i := r.Shard.FindForKey(key)
22 | return r.Nodes[i]
23 | }
24 |
25 | func (r *Rule) FindNodeIndex(key interface{}) int {
26 | return r.Shard.FindForKey(key)
27 | }
28 |
29 | func (r *Rule) String() string {
30 | return fmt.Sprintf("%s.%s?key=%v&shard=%s&nodes=%s",
31 | r.DB, r.Table, r.Key, r.Type, strings.Join(r.Nodes, ", "))
32 | }
33 |
34 | func NewDefaultRule(db string, node string) *Rule {
35 | var r *Rule = &Rule{
36 | DB: db,
37 | Type: DefaultRuleType,
38 | Nodes: []string{node},
39 | Shard: new(DefaultShard),
40 | }
41 | return r
42 | }
43 |
44 | func (r *Router) GetRule(table string) *Rule {
45 | rule := r.Rules[table]
46 | if rule == nil {
47 | return r.DefaultRule
48 | } else {
49 | return rule
50 | }
51 | }
52 |
// Router holds the routing rules of one schema (database).
type Router struct {
	DB string
	// Rules holds the shard rules of individual tables.
	Rules map[string]*Rule // key is the table name
	// DefaultRule is used for tables that have no entry in Rules.
	DefaultRule *Rule
	// nodes lists the schema's node names. NewRouter checks the default node
	// and every shard node against it.
	nodes []string
}
59 |
60 | func NewRouter(schemaConfig *config.SchemaConfig) (*Router, error) {
61 |
62 | if !includeNode(schemaConfig.Nodes, schemaConfig.RulesConifg.Default) {
63 | return nil, fmt.Errorf("default node[%s] not in the nodes list.",
64 | schemaConfig.RulesConifg.Default)
65 | }
66 |
67 | rt := new(Router)
68 | rt.DB = schemaConfig.DB
69 | rt.nodes = schemaConfig.Nodes
70 | rt.Rules = make(map[string]*Rule, len(schemaConfig.RulesConifg.ShardRule))
71 | rt.DefaultRule = NewDefaultRule(rt.DB, schemaConfig.RulesConifg.Default)
72 |
73 | for _, shard := range schemaConfig.RulesConifg.ShardRule {
74 | rc := &RuleConfig{shard}
75 | for _, node := range shard.Nodes {
76 | if !includeNode(rt.nodes, node) {
77 | return nil, fmt.Errorf("shard table[%s] node[%s] not in the schema.nodes list:[%s].",
78 | shard.Table, node, strings.Join(shard.Nodes, ","))
79 | }
80 | }
81 | rule, err := rc.ParseRule(rt.DB)
82 | if err != nil {
83 | return nil, err
84 | }
85 |
86 | if rule.Type == DefaultRuleType {
87 | return nil, fmt.Errorf("[default-rule] duplicate, must only one.")
88 | } else {
89 | if _, ok := rt.Rules[rule.Table]; ok {
90 | return nil, fmt.Errorf("table %s rule in %s duplicate", rule.Table, rule.DB)
91 | }
92 | rt.Rules[rule.Table] = rule
93 | }
94 | }
95 | return rt, nil
96 | }
97 |
// includeNode reports whether node appears in nodes.
func includeNode(nodes []string, node string) bool {
	for i := range nodes {
		if nodes[i] == node {
			return true
		}
	}
	return false
}
106 |
--------------------------------------------------------------------------------
/sqlparser/parsed_query.go:
--------------------------------------------------------------------------------
1 | // Copyright 2012, Google Inc. All rights reserved.
2 | // Use of this source code is governed by a BSD-style
3 | // license that can be found in the LICENSE file.
4 |
5 | package sqlparser
6 |
7 | import (
8 | "bytes"
9 | "encoding/json"
10 | "fmt"
11 | "strconv"
12 |
13 | "github.com/siddontang/mixer/sqltypes"
14 | )
15 |
// BindLocation marks one bind variable in ParsedQuery.Query. Offset is the
// byte offset of the variable's ':' prefix. Length covers the ':' plus the
// variable name.
type BindLocation struct {
	Offset, Length int
}
19 |
// ParsedQuery is a query string together with the locations of its bind
// variables, so values can be substituted without re-parsing
// (see GenerateQuery).
type ParsedQuery struct {
	Query string
	// BindLocations must be in ascending Offset order, because GenerateQuery
	// copies the text between them in sequence.
	BindLocations []BindLocation
}
24 |
// EncoderFunc converts a value to bytes, presumably its SQL text form.
// NOTE(review): not referenced in this file — confirm it is still used.
type EncoderFunc func(value interface{}) ([]byte, error)
26 |
27 | func (pq *ParsedQuery) GenerateQuery(bindVariables map[string]interface{}, listVariables []sqltypes.Value) ([]byte, error) {
28 | if len(pq.BindLocations) == 0 {
29 | return []byte(pq.Query), nil
30 | }
31 | buf := bytes.NewBuffer(make([]byte, 0, len(pq.Query)))
32 | current := 0
33 | for _, loc := range pq.BindLocations {
34 | buf.WriteString(pq.Query[current:loc.Offset])
35 | varName := pq.Query[loc.Offset+1 : loc.Offset+loc.Length]
36 | var supplied interface{}
37 | if varName[0] >= '0' && varName[0] <= '9' {
38 | index, err := strconv.Atoi(varName)
39 | if err != nil {
40 | return nil, fmt.Errorf("unexpected: %v for %s", err, varName)
41 | }
42 | if index >= len(listVariables) {
43 | return nil, fmt.Errorf("index out of range: %d", index)
44 | }
45 | supplied = listVariables[index]
46 | } else if varName[0] == '*' {
47 | supplied = listVariables
48 | } else {
49 | var ok bool
50 | supplied, ok = bindVariables[varName]
51 | if !ok {
52 | return nil, fmt.Errorf("missing bind var %s", varName)
53 | }
54 | }
55 | if err := EncodeValue(buf, supplied); err != nil {
56 | return nil, err
57 | }
58 | current = loc.Offset + loc.Length
59 | }
60 | buf.WriteString(pq.Query[current:])
61 | return buf.Bytes(), nil
62 | }
63 |
64 | func (pq *ParsedQuery) MarshalJSON() ([]byte, error) {
65 | return json.Marshal(pq.Query)
66 | }
67 |
68 | func EncodeValue(buf *bytes.Buffer, value interface{}) error {
69 | switch bindVal := value.(type) {
70 | case nil:
71 | buf.WriteString("null")
72 | case []sqltypes.Value:
73 | for i := 0; i < len(bindVal); i++ {
74 | if i != 0 {
75 | buf.WriteString(", ")
76 | }
77 | if err := EncodeValue(buf, bindVal[i]); err != nil {
78 | return err
79 | }
80 | }
81 | case [][]sqltypes.Value:
82 | for i := 0; i < len(bindVal); i++ {
83 | if i != 0 {
84 | buf.WriteString(", ")
85 | }
86 | buf.WriteByte('(')
87 | if err := EncodeValue(buf, bindVal[i]); err != nil {
88 | return err
89 | }
90 | buf.WriteByte(')')
91 | }
92 | default:
93 | v, err := sqltypes.BuildValue(bindVal)
94 | if err != nil {
95 | return err
96 | }
97 | v.EncodeSql(buf)
98 | }
99 | return nil
100 | }
101 |
--------------------------------------------------------------------------------
/mysql/resultset_sort.go:
--------------------------------------------------------------------------------
1 | package mysql
2 |
3 | import (
4 | "bytes"
5 | "fmt"
6 | "github.com/siddontang/mixer/hack"
7 | "sort"
8 | )
9 |
// Sort directions for SortKey.Direction. Any value other than SortDesc
// sorts ascending.
const (
	// SortAsc orders a key from smallest to largest.
	SortAsc = "asc"
	// SortDesc orders a key from largest to smallest.
	SortDesc = "desc"
)
14 |
// SortKey is one ordering term used by Resultset.Sort.
type SortKey struct {
	// Name is the field name to sort by. It is resolved to a column index
	// through Resultset.FieldNames.
	Name string

	// Direction is SortAsc or SortDesc. Anything other than SortDesc is
	// treated as ascending.
	Direction string

	// column is the resolved column index of Name, set by newResultsetSorter.
	column int
}
24 |
// resultsetSorter adapts a Resultset to sort.Interface. Rows are ordered by
// the keys in sk, and earlier keys take precedence. Swap moves Values and
// RowDatas together so they stay aligned.
type resultsetSorter struct {
	*Resultset

	// sk holds the sort keys with their column indexes already resolved.
	sk []SortKey
}
30 |
31 | func newResultsetSorter(r *Resultset, sk []SortKey) (*resultsetSorter, error) {
32 | s := new(resultsetSorter)
33 |
34 | s.Resultset = r
35 |
36 | for i, k := range sk {
37 | if column, ok := r.FieldNames[k.Name]; ok {
38 | sk[i].column = column
39 | } else {
40 | return nil, fmt.Errorf("key %s not in resultset fields, can not sort", k.Name)
41 | }
42 | }
43 |
44 | s.sk = sk
45 |
46 | return s, nil
47 | }
48 |
49 | func (r *resultsetSorter) Len() int {
50 | return r.RowNumber()
51 | }
52 |
53 | func (r *resultsetSorter) Less(i, j int) bool {
54 | v1 := r.Values[i]
55 | v2 := r.Values[j]
56 |
57 | for _, k := range r.sk {
58 | v := cmpValue(v1[k.column], v2[k.column])
59 |
60 | if k.Direction == SortDesc {
61 | v = -v
62 | }
63 |
64 | if v < 0 {
65 | return true
66 | } else if v > 0 {
67 | return false
68 | }
69 |
70 | //equal, cmp next key
71 | }
72 |
73 | return false
74 | }
75 |
76 | //compare value using asc
77 | func cmpValue(v1 interface{}, v2 interface{}) int {
78 | if v1 == nil && v2 == nil {
79 | return 0
80 | } else if v1 == nil {
81 | return -1
82 | } else if v2 == nil {
83 | return 1
84 | }
85 |
86 | switch v := v1.(type) {
87 | case string:
88 | s := v2.(string)
89 | return bytes.Compare(hack.Slice(v), hack.Slice(s))
90 | case []byte:
91 | s := v2.([]byte)
92 | return bytes.Compare(v, s)
93 | case int64:
94 | s := v2.(int64)
95 | if v < s {
96 | return -1
97 | } else if v > s {
98 | return 1
99 | } else {
100 | return 0
101 | }
102 | case uint64:
103 | s := v2.(uint64)
104 | if v < s {
105 | return -1
106 | } else if v > s {
107 | return 1
108 | } else {
109 | return 0
110 | }
111 | case float64:
112 | s := v2.(float64)
113 | if v < s {
114 | return -1
115 | } else if v > s {
116 | return 1
117 | } else {
118 | return 0
119 | }
120 | default:
121 | //can not go here
122 | panic(fmt.Sprintf("invalid type %T", v))
123 | }
124 | }
125 |
126 | func (r *resultsetSorter) Swap(i, j int) {
127 | r.Values[i], r.Values[j] = r.Values[j], r.Values[i]
128 |
129 | r.RowDatas[i], r.RowDatas[j] = r.RowDatas[j], r.RowDatas[i]
130 | }
131 |
132 | func (r *Resultset) Sort(sk []SortKey) error {
133 | s, err := newResultsetSorter(r, sk)
134 |
135 | if err != nil {
136 | return err
137 | }
138 |
139 | sort.Sort(s)
140 |
141 | return nil
142 | }
143 |
--------------------------------------------------------------------------------
/proxy/conn_select.go:
--------------------------------------------------------------------------------
1 | package proxy
2 |
3 | import (
4 | "bytes"
5 | "fmt"
6 | . "github.com/siddontang/mixer/mysql"
7 | "github.com/siddontang/mixer/sqlparser"
8 | "strings"
9 | )
10 |
11 | func (c *Conn) handleSimpleSelect(sql string, stmt *sqlparser.SimpleSelect) error {
12 | if len(stmt.SelectExprs) != 1 {
13 | return fmt.Errorf("support select one informaction function, %s", sql)
14 | }
15 |
16 | expr, ok := stmt.SelectExprs[0].(*sqlparser.NonStarExpr)
17 | if !ok {
18 | return fmt.Errorf("support select informaction function, %s", sql)
19 | }
20 |
21 | var f *sqlparser.FuncExpr
22 | f, ok = expr.Expr.(*sqlparser.FuncExpr)
23 | if !ok {
24 | return fmt.Errorf("support select informaction function, %s", sql)
25 | }
26 |
27 | var r *Resultset
28 | var err error
29 |
30 | switch strings.ToLower(string(f.Name)) {
31 | case "last_insert_id":
32 | r, err = c.buildSimpleSelectResult(c.lastInsertId, f.Name, expr.As)
33 | case "row_count":
34 | r, err = c.buildSimpleSelectResult(c.affectedRows, f.Name, expr.As)
35 | case "version":
36 | r, err = c.buildSimpleSelectResult(ServerVersion, f.Name, expr.As)
37 | case "connection_id":
38 | r, err = c.buildSimpleSelectResult(c.connectionId, f.Name, expr.As)
39 | case "database":
40 | if c.schema != nil {
41 | r, err = c.buildSimpleSelectResult(c.schema.db, f.Name, expr.As)
42 | } else {
43 | r, err = c.buildSimpleSelectResult("NULL", f.Name, expr.As)
44 | }
45 | default:
46 | return fmt.Errorf("function %s not support", f.Name)
47 | }
48 |
49 | if err != nil {
50 | return err
51 | }
52 |
53 | return c.writeResultset(c.status, r)
54 | }
55 |
56 | func (c *Conn) buildSimpleSelectResult(value interface{}, name []byte, asName []byte) (*Resultset, error) {
57 | field := &Field{}
58 |
59 | field.Name = name
60 |
61 | if asName != nil {
62 | field.Name = asName
63 | }
64 |
65 | field.OrgName = name
66 |
67 | formatField(field, value)
68 |
69 | r := &Resultset{Fields: []*Field{field}}
70 | row, err := formatValue(value)
71 | if err != nil {
72 | return nil, err
73 | }
74 | r.RowDatas = append(r.RowDatas, PutLengthEncodedString(row))
75 |
76 | return r, nil
77 | }
78 |
79 | func (c *Conn) handleFieldList(data []byte) error {
80 | index := bytes.IndexByte(data, 0x00)
81 | table := string(data[0:index])
82 | wildcard := string(data[index+1:])
83 |
84 | if c.schema == nil {
85 | return NewDefaultError(ER_NO_DB_ERROR)
86 | }
87 |
88 | nodeName := c.schema.rule.GetRule(table).Nodes[0]
89 |
90 | n := c.server.getNode(nodeName)
91 |
92 | co, err := n.getMasterConn()
93 | if err != nil {
94 | return err
95 | }
96 | defer co.Close()
97 |
98 | if err = co.UseDB(c.schema.db); err != nil {
99 | return err
100 | }
101 |
102 | if fs, err := co.FieldList(table, wildcard); err != nil {
103 | return err
104 | } else {
105 | return c.writeFieldList(c.status, fs)
106 | }
107 | }
108 |
109 | func (c *Conn) writeFieldList(status uint16, fs []*Field) error {
110 | c.affectedRows = int64(-1)
111 |
112 | data := make([]byte, 4, 1024)
113 |
114 | for _, v := range fs {
115 | data = data[0:4]
116 | data = append(data, v.Dump()...)
117 | if err := c.writePacket(data); err != nil {
118 | return err
119 | }
120 | }
121 |
122 | if err := c.writeEOF(status); err != nil {
123 | return err
124 | }
125 | return nil
126 | }
127 |
--------------------------------------------------------------------------------
/sqlparser/tracked_buffer.go:
--------------------------------------------------------------------------------
1 | // Copyright 2012, Google Inc. All rights reserved.
2 | // Use of this source code is governed by a BSD-style
3 | // license that can be found in the LICENSE file.
4 |
5 | package sqlparser
6 |
7 | import (
8 | "bytes"
9 | "fmt"
10 | )
11 |
// ParserError is an error that carries a preformatted message.
// It is slated for deprecation.
// TODO(sougou): deprecate.
type ParserError struct {
	Message string
}

// NewParserError formats its arguments with fmt.Sprintf and wraps the result
// in a ParserError.
func NewParserError(format string, args ...interface{}) ParserError {
	msg := fmt.Sprintf(format, args...)
	return ParserError{Message: msg}
}

// Error returns the message, so ParserError satisfies the error interface.
func (pe ParserError) Error() string {
	return pe.Message
}
25 |
// handleError converts a panic into an error result. Defer it directly, as in
// `defer handleError(&err)`, in a function whose error result is named err.
//
// A recovered value that is already an error is stored unchanged. Any other
// value is wrapped with fmt.Errorf, such as the strings panicked by
// TrackedBuffer.Fprintf. Without this, the failed type assertion would cause
// a second panic. When nothing panicked, *err is left untouched.
func handleError(err *error) {
	if x := recover(); x != nil {
		if e, ok := x.(error); ok {
			*err = e
			return
		}
		*err = fmt.Errorf("%v", x)
	}
}
31 |
// TrackedBuffer is used to rebuild a query from the AST.
// bindLocations records where bind variables were written (via WriteArg), so
// values can be substituted later without re-parsing.
// nodeFormatter is the function used to format a node. When it is nil,
// Fprintf calls the node's own Format method. Supplying a different function
// produces a query that differs from the default rendering.
type TrackedBuffer struct {
	*bytes.Buffer
	bindLocations []BindLocation
	nodeFormatter func(buf *TrackedBuffer, node SQLNode)
}
44 |
45 | func NewTrackedBuffer(nodeFormatter func(buf *TrackedBuffer, node SQLNode)) *TrackedBuffer {
46 | buf := &TrackedBuffer{
47 | Buffer: bytes.NewBuffer(make([]byte, 0, 128)),
48 | bindLocations: make([]BindLocation, 0, 4),
49 | nodeFormatter: nodeFormatter,
50 | }
51 | return buf
52 | }
53 |
54 | // Fprintf mimics fmt.Fprintf, but limited to Node(%v), Node.Value(%s) and string(%s).
55 | // It also allows a %a for a value argument, in which case it adds tracking info for
56 | // future substitutions.
57 | func (buf *TrackedBuffer) Fprintf(format string, values ...interface{}) {
58 | end := len(format)
59 | fieldnum := 0
60 | for i := 0; i < end; {
61 | lasti := i
62 | for i < end && format[i] != '%' {
63 | i++
64 | }
65 | if i > lasti {
66 | buf.WriteString(format[lasti:i])
67 | }
68 | if i >= end {
69 | break
70 | }
71 | i++ // '%'
72 | switch format[i] {
73 | case 'c':
74 | switch v := values[fieldnum].(type) {
75 | case byte:
76 | buf.WriteByte(v)
77 | case rune:
78 | buf.WriteRune(v)
79 | default:
80 | panic(fmt.Sprintf("unexpected type %T", v))
81 | }
82 | case 's':
83 | switch v := values[fieldnum].(type) {
84 | case []byte:
85 | buf.Write(v)
86 | case string:
87 | buf.WriteString(v)
88 | default:
89 | panic(fmt.Sprintf("unexpected type %T", v))
90 | }
91 | case 'v':
92 | node := values[fieldnum].(SQLNode)
93 | if buf.nodeFormatter == nil {
94 | node.Format(buf)
95 | } else {
96 | buf.nodeFormatter(buf, node)
97 | }
98 | case 'a':
99 | buf.WriteArg(values[fieldnum].(string))
100 | default:
101 | panic("unexpected")
102 | }
103 | fieldnum++
104 | i++
105 | }
106 | }
107 |
108 | // WriteArg writes a value argument into the buffer. arg should not contain
109 | // the ':' prefix. It also adds tracking info for future substitutions.
110 | func (buf *TrackedBuffer) WriteArg(arg string) {
111 | buf.bindLocations = append(buf.bindLocations, BindLocation{buf.Len(), len(arg) + 1})
112 | buf.WriteByte(':')
113 | buf.WriteString(arg)
114 | }
115 |
116 | func (buf *TrackedBuffer) ParsedQuery() *ParsedQuery {
117 | return &ParsedQuery{buf.String(), buf.bindLocations}
118 | }
119 |
--------------------------------------------------------------------------------
/config/config_test.go:
--------------------------------------------------------------------------------
1 | package config
2 |
3 | import (
4 | "fmt"
5 | "reflect"
6 | "testing"
7 | )
8 |
// TestConfig parses an inline YAML configuration. It checks that the nodes,
// the shard rules, the schema and the top-level settings decode into the
// expected config structs.
func TestConfig(t *testing.T) {
	// The configuration has three nodes and one schema, "mixer":
	//   - node1 sets every field, including a slave and rw_split;
	//   - node2 sets only user and master;
	//   - the schema has a hash rule and a range rule.
	var testConfigData = []byte(
		`
addr : 127.0.0.1:4000
user : root
password :
log_level : error

nodes :
-
name : node1
down_after_noalive : 300
idle_conns : 16
rw_split: true
user: root
password:
master : 127.0.0.1:3306
slave : 127.0.0.1:4306
-
name : node2
user: root
master : 127.0.0.1:3307

-
name : node3
down_after_noalive : 300
idle_conns : 16
rw_split: false
user: root
password:
master : 127.0.0.1:3308

schemas :
-
db : mixer
nodes: [node1, node2, node3]
rules:
default: node1
shard:
-
table: mixer_test_shard_hash
key: id
nodes: [node1, node2, node3]
type: hash

-
table: mixer_test_shard_range
key: id
type: range
nodes: [node2, node3]
range: -10000-
`)

	cfg, err := ParseConfigData(testConfigData)
	if err != nil {
		t.Fatal(err)
	}

	if len(cfg.Nodes) != 3 {
		t.Fatal(len(cfg.Nodes))
	}

	if len(cfg.Schemas) != 1 {
		t.Fatal(len(cfg.Schemas))
	}

	// node1 sets every field explicitly.
	testNode := NodeConfig{
		Name:             "node1",
		DownAfterNoAlive: 300,
		IdleConns:        16,
		RWSplit:          true,

		User:     "root",
		Password: "",

		Master: "127.0.0.1:3306",
		Slave:  "127.0.0.1:4306",
	}

	if !reflect.DeepEqual(cfg.Nodes[0], testNode) {
		fmt.Printf("%v\n", cfg.Nodes[0])
		t.Fatal("node1 must equal")
	}

	// node2 omits the optional fields, which must stay at their zero values.
	testNode_2 := NodeConfig{
		Name:   "node2",
		User:   "root",
		Master: "127.0.0.1:3307",
	}

	if !reflect.DeepEqual(cfg.Nodes[1], testNode_2) {
		t.Fatal("node2 must equal")
	}

	// NOTE(review): ShardRule[0] and [1] are indexed before the length check
	// below, so too few parsed rules would panic instead of failing cleanly.
	testShard_1 := ShardConfig{
		Table: "mixer_test_shard_hash",
		Key:   "id",
		Nodes: []string{"node1", "node2", "node3"},
		Type:  "hash",
	}
	if !reflect.DeepEqual(cfg.Schemas[0].RulesConifg.ShardRule[0], testShard_1) {
		t.Fatal("ShardConfig0 must equal")
	}

	testShard_2 := ShardConfig{
		Table: "mixer_test_shard_range",
		Key:   "id",
		Nodes: []string{"node2", "node3"},
		Type:  "range",
		Range: "-10000-",
	}
	if !reflect.DeepEqual(cfg.Schemas[0].RulesConifg.ShardRule[1], testShard_2) {
		t.Fatal("ShardConfig1 must equal")
	}

	if 2 != len(cfg.Schemas[0].RulesConifg.ShardRule) {
		t.Fatal("ShardRule must 2")
	}

	testRules := RulesConfig{
		Default:   "node1",
		ShardRule: []ShardConfig{testShard_1, testShard_2},
	}
	if !reflect.DeepEqual(cfg.Schemas[0].RulesConifg, testRules) {
		t.Fatal("RulesConfig must equal")
	}

	testSchema := SchemaConfig{
		DB:          "mixer",
		Nodes:       []string{"node1", "node2", "node3"},
		RulesConifg: testRules,
	}

	if !reflect.DeepEqual(cfg.Schemas[0], testSchema) {
		t.Fatal("schema must equal")
	}

	// Top-level settings.
	if cfg.LogLevel != "error" || cfg.User != "root" || cfg.Password != "" || cfg.Addr != "127.0.0.1:4000" {
		t.Fatal("Top Config not equal.")
	}
}
150 |
--------------------------------------------------------------------------------
/client/conn_test.go:
--------------------------------------------------------------------------------
1 | package client
2 |
3 | import (
4 | "fmt"
5 | . "github.com/siddontang/mixer/mysql"
6 | "testing"
7 | )
8 |
9 | func newTestConn() *Conn {
10 | c := new(Conn)
11 |
12 | if err := c.Connect("127.0.0.1:3306", "root", "", "mixer"); err != nil {
13 | panic(err)
14 | }
15 |
16 | return c
17 | }
18 |
19 | func TestConn_Connect(t *testing.T) {
20 | c := newTestConn()
21 | defer c.Close()
22 | }
23 |
24 | func TestConn_Ping(t *testing.T) {
25 | c := newTestConn()
26 | defer c.Close()
27 |
28 | if err := c.Ping(); err != nil {
29 | t.Fatal(err)
30 | }
31 | }
32 |
33 | func TestConn_DeleteTable(t *testing.T) {
34 | c := newTestConn()
35 | defer c.Close()
36 |
37 | if _, err := c.Execute("drop table if exists mixer_test_conn"); err != nil {
38 | t.Fatal(err)
39 | }
40 | }
41 |
42 | func TestConn_CreateTable(t *testing.T) {
43 | s := `CREATE TABLE IF NOT EXISTS mixer_test_conn (
44 | id BIGINT(64) UNSIGNED NOT NULL,
45 | str VARCHAR(256),
46 | f DOUBLE,
47 | e enum("test1", "test2"),
48 | u tinyint unsigned,
49 | i tinyint,
50 | PRIMARY KEY (id)
51 | ) ENGINE=InnoDB DEFAULT CHARSET=utf8`
52 |
53 | c := newTestConn()
54 | defer c.Close()
55 |
56 | if _, err := c.Execute(s); err != nil {
57 | t.Fatal(err)
58 | }
59 | }
60 |
61 | func TestConn_Insert(t *testing.T) {
62 | s := `insert into mixer_test_conn (id, str, f, e) values(1, "a", 3.14, "test1")`
63 |
64 | c := newTestConn()
65 | defer c.Close()
66 |
67 | if pkg, err := c.Execute(s); err != nil {
68 | t.Fatal(err)
69 | } else {
70 | if pkg.AffectedRows != 1 {
71 | t.Fatal(pkg.AffectedRows)
72 | }
73 | }
74 | }
75 |
76 | func TestConn_Select(t *testing.T) {
77 | s := `select str, f, e from mixer_test_conn where id = 1`
78 |
79 | c := newTestConn()
80 | defer c.Close()
81 |
82 | if result, err := c.Execute(s); err != nil {
83 | t.Fatal(err)
84 | } else {
85 | if len(result.Fields) != 3 {
86 | t.Fatal(len(result.Fields))
87 | }
88 |
89 | if len(result.Values) != 1 {
90 | t.Fatal(len(result.Values))
91 | }
92 |
93 | if str, _ := result.GetString(0, 0); str != "a" {
94 | t.Fatal("invalid str", str)
95 | }
96 |
97 | if f, _ := result.GetFloat(0, 1); f != float64(3.14) {
98 | t.Fatal("invalid f", f)
99 | }
100 |
101 | if e, _ := result.GetString(0, 2); e != "test1" {
102 | t.Fatal("invalid e", e)
103 | }
104 |
105 | if str, _ := result.GetStringByName(0, "str"); str != "a" {
106 | t.Fatal("invalid str", str)
107 | }
108 |
109 | if f, _ := result.GetFloatByName(0, "f"); f != float64(3.14) {
110 | t.Fatal("invalid f", f)
111 | }
112 |
113 | if e, _ := result.GetStringByName(0, "e"); e != "test1" {
114 | t.Fatal("invalid e", e)
115 | }
116 |
117 | }
118 | }
119 |
120 | func TestConn_Escape(t *testing.T) {
121 | c := newTestConn()
122 | defer c.Close()
123 |
124 | e := `""''\abc`
125 | s := fmt.Sprintf(`insert into mixer_test_conn (id, str) values(5, "%s")`,
126 | Escape(e))
127 |
128 | if _, err := c.Execute(s); err != nil {
129 | t.Fatal(err)
130 | }
131 |
132 | s = `select str from mixer_test_conn where id = ?`
133 |
134 | if r, err := c.Execute(s, 5); err != nil {
135 | t.Fatal(err)
136 | } else {
137 | str, _ := r.GetString(0, 0)
138 | if str != e {
139 | t.Fatal(str)
140 | }
141 | }
142 | }
143 |
144 | func TestConn_SetCharset(t *testing.T) {
145 | c := newTestConn()
146 | defer c.Close()
147 |
148 | if err := c.SetCharset("gb2312"); err != nil {
149 | t.Fatal(err)
150 | }
151 | }
152 |
--------------------------------------------------------------------------------
/mysql/field.go:
--------------------------------------------------------------------------------
1 | package mysql
2 |
3 | import (
4 | "encoding/binary"
5 | )
6 |
// FieldData is the raw payload of a column-definition packet; Parse
// decodes it into a Field.
type FieldData []byte

// Field describes one result-set column, mirroring the column-definition
// packet decoded by FieldData.Parse and encoded by Field.Dump.
type Field struct {
	// Data is the raw packet this Field was parsed from (set by Parse).
	// When non-nil, Dump returns it verbatim instead of re-encoding.
	Data         FieldData
	Schema       []byte
	Table        []byte
	OrgTable     []byte
	Name         []byte
	OrgName      []byte
	Charset      uint16 // character set / collation id
	ColumnLength uint32
	Type         uint8  // one of the MYSQL_TYPE_* constants
	Flag         uint16 // bitmask of the *_FLAG constants
	Decimal      uint8

	// DefaultValueLength and DefaultValue are only filled in when the
	// packet has trailing data after the fixed-size fields (the field-list
	// command case); otherwise DefaultValue is nil.
	DefaultValueLength uint64
	DefaultValue       []byte
}
25 |
26 | func (p FieldData) Parse() (f *Field, err error) {
27 | f = new(Field)
28 |
29 | f.Data = p
30 |
31 | var n int
32 | pos := 0
33 | //skip catelog, always def
34 | n, err = SkipLengthEnodedString(p)
35 | if err != nil {
36 | return
37 | }
38 | pos += n
39 |
40 | //schema
41 | f.Schema, _, n, err = LengthEnodedString(p[pos:])
42 | if err != nil {
43 | return
44 | }
45 | pos += n
46 |
47 | //table
48 | f.Table, _, n, err = LengthEnodedString(p[pos:])
49 | if err != nil {
50 | return
51 | }
52 | pos += n
53 |
54 | //org_table
55 | f.OrgTable, _, n, err = LengthEnodedString(p[pos:])
56 | if err != nil {
57 | return
58 | }
59 | pos += n
60 |
61 | //name
62 | f.Name, _, n, err = LengthEnodedString(p[pos:])
63 | if err != nil {
64 | return
65 | }
66 | pos += n
67 |
68 | //org_name
69 | f.OrgName, _, n, err = LengthEnodedString(p[pos:])
70 | if err != nil {
71 | return
72 | }
73 | pos += n
74 |
75 | //skip oc
76 | pos += 1
77 |
78 | //charset
79 | f.Charset = binary.LittleEndian.Uint16(p[pos:])
80 | pos += 2
81 |
82 | //column length
83 | f.ColumnLength = binary.LittleEndian.Uint32(p[pos:])
84 | pos += 4
85 |
86 | //type
87 | f.Type = p[pos]
88 | pos++
89 |
90 | //flag
91 | f.Flag = binary.LittleEndian.Uint16(p[pos:])
92 | pos += 2
93 |
94 | //decimals 1
95 | f.Decimal = p[pos]
96 | pos++
97 |
98 | //filter [0x00][0x00]
99 | pos += 2
100 |
101 | f.DefaultValue = nil
102 | //if more data, command was field list
103 | if len(p) > pos {
104 | //length of default value lenenc-int
105 | f.DefaultValueLength, _, n = LengthEncodedInt(p[pos:])
106 | pos += n
107 |
108 | if pos+int(f.DefaultValueLength) > len(p) {
109 | err = ErrMalformPacket
110 | return
111 | }
112 |
113 | //default value string[$len]
114 | f.DefaultValue = p[pos:(pos + int(f.DefaultValueLength))]
115 | }
116 |
117 | return
118 | }
119 |
120 | func (f *Field) Dump() []byte {
121 | if f.Data != nil {
122 | return []byte(f.Data)
123 | }
124 |
125 | l := len(f.Schema) + len(f.Table) + len(f.OrgTable) + len(f.Name) + len(f.OrgName) + len(f.DefaultValue) + 48
126 |
127 | data := make([]byte, 0, l)
128 |
129 | data = append(data, PutLengthEncodedString([]byte("def"))...)
130 |
131 | data = append(data, PutLengthEncodedString(f.Schema)...)
132 |
133 | data = append(data, PutLengthEncodedString(f.Table)...)
134 | data = append(data, PutLengthEncodedString(f.OrgTable)...)
135 |
136 | data = append(data, PutLengthEncodedString(f.Name)...)
137 | data = append(data, PutLengthEncodedString(f.OrgName)...)
138 |
139 | data = append(data, 0x0c)
140 |
141 | data = append(data, Uint16ToBytes(f.Charset)...)
142 | data = append(data, Uint32ToBytes(f.ColumnLength)...)
143 | data = append(data, f.Type)
144 | data = append(data, Uint16ToBytes(f.Flag)...)
145 | data = append(data, f.Decimal)
146 | data = append(data, 0, 0)
147 |
148 | if f.DefaultValue != nil {
149 | data = append(data, Uint64ToBytes(f.DefaultValueLength)...)
150 | data = append(data, f.DefaultValue...)
151 | }
152 |
153 | return data
154 | }
155 |
--------------------------------------------------------------------------------
/mysql/const.go:
--------------------------------------------------------------------------------
1 | package mysql
2 |
// Protocol-wide limits and identifiers.
const (
	// MinProtocolVersion is handshake protocol version 10.
	// NOTE(review): presumably the lowest version accepted from a server
	// during the handshake — confirm in the handshake code.
	MinProtocolVersion byte = 10
	// MaxPayloadLen is 2^24-1, the largest length the 3-byte packet
	// header can express; larger payloads must be split across packets.
	MaxPayloadLen int = 1<<24 - 1
	// TimeFormat is the Go time layout for MySQL "YYYY-MM-DD hh:mm:ss" values.
	TimeFormat string = "2006-01-02 15:04:05"
	// ServerVersion is the version string this proxy identifies itself with
	// (presumably sent in the server handshake).
	ServerVersion string = "5.5.31-mixer-0.1"
)

// First byte of a response payload, identifying the packet kind.
const (
	OK_HEADER          byte = 0x00
	ERR_HEADER         byte = 0xff
	EOF_HEADER         byte = 0xfe
	LocalInFile_HEADER byte = 0xfb // request to upload a LOCAL INFILE
)

// Server status bit flags, as defined by the MySQL protocol (carried in
// OK and EOF packets). SERVER_STATUS_LAST_ROW_SEND and
// SERVER_STATUS_NO_BACKSLASH_ESCAPED correspond to MySQL's
// SERVER_STATUS_LAST_ROW_SENT and SERVER_STATUS_NO_BACKSLASH_ESCAPES;
// the names are kept as-is for existing callers.
const (
	SERVER_STATUS_IN_TRANS             uint16 = 0x0001
	SERVER_STATUS_AUTOCOMMIT           uint16 = 0x0002
	SERVER_MORE_RESULTS_EXISTS         uint16 = 0x0008
	SERVER_STATUS_NO_GOOD_INDEX_USED   uint16 = 0x0010
	SERVER_STATUS_NO_INDEX_USED        uint16 = 0x0020
	SERVER_STATUS_CURSOR_EXISTS        uint16 = 0x0040
	SERVER_STATUS_LAST_ROW_SEND        uint16 = 0x0080
	SERVER_STATUS_DB_DROPPED           uint16 = 0x0100
	SERVER_STATUS_NO_BACKSLASH_ESCAPED uint16 = 0x0200
	SERVER_STATUS_METADATA_CHANGED     uint16 = 0x0400
	SERVER_QUERY_WAS_SLOW              uint16 = 0x0800
	SERVER_PS_OUT_PARAMS               uint16 = 0x1000
)
31 |
// Command bytes: the first byte of a client command packet (e.g.
// COM_STMT_EXECUTE and COM_STMT_CLOSE sent by the prepared-statement code).
// Values come from iota, so this list must match the MySQL numbering
// exactly (COM_SLEEP = 0x00 ... COM_RESET_CONNECTION = 0x1f). Never insert,
// remove or reorder entries.
const (
	COM_SLEEP byte = iota
	COM_QUIT
	COM_INIT_DB
	COM_QUERY
	COM_FIELD_LIST
	COM_CREATE_DB
	COM_DROP_DB
	COM_REFRESH
	COM_SHUTDOWN
	COM_STATISTICS
	COM_PROCESS_INFO
	COM_CONNECT
	COM_PROCESS_KILL
	COM_DEBUG
	COM_PING
	COM_TIME
	COM_DELAYED_INSERT
	COM_CHANGE_USER
	COM_BINLOG_DUMP
	COM_TABLE_DUMP
	COM_CONNECT_OUT
	COM_REGISTER_SLAVE
	COM_STMT_PREPARE
	COM_STMT_EXECUTE
	COM_STMT_SEND_LONG_DATA
	COM_STMT_CLOSE
	COM_STMT_RESET
	COM_SET_OPTION
	COM_STMT_FETCH
	COM_DAEMON
	COM_BINLOG_DUMP_GTID
	COM_RESET_CONNECTION
)

// Capability flags, one bit each from CLIENT_LONG_PASSWORD (1<<0) up to
// CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA (1<<21). The bit positions come
// from declaration order, so the order must match the MySQL protocol.
const (
	CLIENT_LONG_PASSWORD uint32 = 1 << iota
	CLIENT_FOUND_ROWS
	CLIENT_LONG_FLAG
	CLIENT_CONNECT_WITH_DB
	CLIENT_NO_SCHEMA
	CLIENT_COMPRESS
	CLIENT_ODBC
	CLIENT_LOCAL_FILES
	CLIENT_IGNORE_SPACE
	CLIENT_PROTOCOL_41
	CLIENT_INTERACTIVE
	CLIENT_SSL
	CLIENT_IGNORE_SIGPIPE
	CLIENT_TRANSACTIONS
	CLIENT_RESERVED
	CLIENT_SECURE_CONNECTION
	CLIENT_MULTI_STATEMENTS
	CLIENT_MULTI_RESULTS
	CLIENT_PS_MULTI_RESULTS
	CLIENT_PLUGIN_AUTH
	CLIENT_CONNECT_ATTRS
	CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA
)
91 |
// Column and parameter type codes (Field.Type, prepared-statement
// parameter types). MYSQL_TYPE_DECIMAL through MYSQL_TYPE_BIT are
// 0x00-0x10, numbered by declaration order.
const (
	MYSQL_TYPE_DECIMAL byte = iota
	MYSQL_TYPE_TINY
	MYSQL_TYPE_SHORT
	MYSQL_TYPE_LONG
	MYSQL_TYPE_FLOAT
	MYSQL_TYPE_DOUBLE
	MYSQL_TYPE_NULL
	MYSQL_TYPE_TIMESTAMP
	MYSQL_TYPE_LONGLONG
	MYSQL_TYPE_INT24
	MYSQL_TYPE_DATE
	MYSQL_TYPE_TIME
	MYSQL_TYPE_DATETIME
	MYSQL_TYPE_YEAR
	MYSQL_TYPE_NEWDATE
	MYSQL_TYPE_VARCHAR
	MYSQL_TYPE_BIT
)

// Type codes in the high range 0xf6-0xff, numbered by declaration order
// starting at MYSQL_TYPE_NEWDECIMAL = 0xf6 (MYSQL_TYPE_GEOMETRY = 0xff).
const (
	MYSQL_TYPE_NEWDECIMAL byte = iota + 0xf6
	MYSQL_TYPE_ENUM
	MYSQL_TYPE_SET
	MYSQL_TYPE_TINY_BLOB
	MYSQL_TYPE_MEDIUM_BLOB
	MYSQL_TYPE_LONG_BLOB
	MYSQL_TYPE_BLOB
	MYSQL_TYPE_VAR_STRING
	MYSQL_TYPE_STRING
	MYSQL_TYPE_GEOMETRY
)

// Column-definition flag bits (Field.Flag). NUM_FLAG and GROUP_FLAG both
// equal 32768, exactly as in MySQL's own headers.
const (
	NOT_NULL_FLAG       = 1
	PRI_KEY_FLAG        = 2
	UNIQUE_KEY_FLAG     = 4
	BLOB_FLAG           = 16
	UNSIGNED_FLAG       = 32
	ZEROFILL_FLAG       = 64
	BINARY_FLAG         = 128
	ENUM_FLAG           = 256
	AUTO_INCREMENT_FLAG = 512
	TIMESTAMP_FLAG      = 1024
	SET_FLAG            = 2048
	NUM_FLAG            = 32768
	PART_KEY_FLAG       = 16384
	GROUP_FLAG          = 32768
	UNIQUE_FLAG         = 65536
)

// AUTH_NAME names the authentication plugin (mysql_native_password).
const (
	AUTH_NAME = "mysql_native_password"
)
146 |
--------------------------------------------------------------------------------
/proxy/conn_resultset.go:
--------------------------------------------------------------------------------
1 | package proxy
2 |
3 | import (
4 | "fmt"
5 | "github.com/siddontang/mixer/hack"
6 | . "github.com/siddontang/mixer/mysql"
7 | "strconv"
8 | )
9 |
10 | func formatValue(value interface{}) ([]byte, error) {
11 | switch v := value.(type) {
12 | case int8:
13 | return strconv.AppendInt(nil, int64(v), 10), nil
14 | case int16:
15 | return strconv.AppendInt(nil, int64(v), 10), nil
16 | case int32:
17 | return strconv.AppendInt(nil, int64(v), 10), nil
18 | case int64:
19 | return strconv.AppendInt(nil, int64(v), 10), nil
20 | case int:
21 | return strconv.AppendInt(nil, int64(v), 10), nil
22 | case uint8:
23 | return strconv.AppendUint(nil, uint64(v), 10), nil
24 | case uint16:
25 | return strconv.AppendUint(nil, uint64(v), 10), nil
26 | case uint32:
27 | return strconv.AppendUint(nil, uint64(v), 10), nil
28 | case uint64:
29 | return strconv.AppendUint(nil, uint64(v), 10), nil
30 | case uint:
31 | return strconv.AppendUint(nil, uint64(v), 10), nil
32 | case float32:
33 | return strconv.AppendFloat(nil, float64(v), 'f', -1, 64), nil
34 | case float64:
35 | return strconv.AppendFloat(nil, float64(v), 'f', -1, 64), nil
36 | case []byte:
37 | return v, nil
38 | case string:
39 | return hack.Slice(v), nil
40 | default:
41 | return nil, fmt.Errorf("invalid type %T", value)
42 | }
43 | }
44 |
45 | func formatField(field *Field, value interface{}) error {
46 | switch value.(type) {
47 | case int8, int16, int32, int64, int:
48 | field.Charset = 63
49 | field.Type = MYSQL_TYPE_LONGLONG
50 | field.Flag = BINARY_FLAG | NOT_NULL_FLAG
51 | case uint8, uint16, uint32, uint64, uint:
52 | field.Charset = 63
53 | field.Type = MYSQL_TYPE_LONGLONG
54 | field.Flag = BINARY_FLAG | NOT_NULL_FLAG | UNSIGNED_FLAG
55 | case string, []byte:
56 | field.Charset = 33
57 | field.Type = MYSQL_TYPE_VAR_STRING
58 | default:
59 | return fmt.Errorf("unsupport type %T for resultset", value)
60 | }
61 | return nil
62 | }
63 |
64 | func (c *Conn) buildResultset(names []string, values [][]interface{}) (*Resultset, error) {
65 | r := new(Resultset)
66 |
67 | r.Fields = make([]*Field, len(names))
68 |
69 | var b []byte
70 | var err error
71 |
72 | for i, vs := range values {
73 | if len(vs) != len(r.Fields) {
74 | return nil, fmt.Errorf("row %d has %d column not equal %d", i, len(vs), len(r.Fields))
75 | }
76 |
77 | var row []byte
78 | for j, value := range vs {
79 | if i == 0 {
80 | field := &Field{}
81 | r.Fields[j] = field
82 | field.Name = hack.Slice(names[j])
83 |
84 | if err = formatField(field, value); err != nil {
85 | return nil, err
86 | }
87 | }
88 | b, err = formatValue(value)
89 |
90 | if err != nil {
91 | return nil, err
92 | }
93 |
94 | row = append(row, PutLengthEncodedString(b)...)
95 | }
96 |
97 | r.RowDatas = append(r.RowDatas, row)
98 | }
99 |
100 | return r, nil
101 | }
102 |
103 | func (c *Conn) writeResultset(status uint16, r *Resultset) error {
104 | c.affectedRows = int64(-1)
105 |
106 | columnLen := PutLengthEncodedInt(uint64(len(r.Fields)))
107 |
108 | data := make([]byte, 4, 1024)
109 |
110 | data = append(data, columnLen...)
111 | if err := c.writePacket(data); err != nil {
112 | return err
113 | }
114 |
115 | for _, v := range r.Fields {
116 | data = data[0:4]
117 | data = append(data, v.Dump()...)
118 | if err := c.writePacket(data); err != nil {
119 | return err
120 | }
121 | }
122 |
123 | if err := c.writeEOF(status); err != nil {
124 | return err
125 | }
126 |
127 | for _, v := range r.RowDatas {
128 | data = data[0:4]
129 | data = append(data, v...)
130 | if err := c.writePacket(data); err != nil {
131 | return err
132 | }
133 | }
134 |
135 | if err := c.writeEOF(status); err != nil {
136 | return err
137 | }
138 |
139 | return nil
140 | }
141 |
--------------------------------------------------------------------------------
/sqlparser/analyzer.go:
--------------------------------------------------------------------------------
1 | // Copyright 2012, Google Inc. All rights reserved.
2 | // Use of this source code is governed by a BSD-style
3 | // license that can be found in the LICENSE file.
4 |
5 | package sqlparser
6 |
7 | // analyzer.go contains utility analysis functions.
8 |
9 | import (
10 | "fmt"
11 |
12 | "github.com/siddontang/mixer/sqltypes"
13 | )
14 |
15 | // GetDBName parses the specified DML and returns the
16 | // db name if it was used to qualify the table name.
17 | // It returns an error if parsing fails or if the statement
18 | // is not a DML.
19 | func GetDBName(sql string) (string, error) {
20 | statement, err := Parse(sql)
21 | if err != nil {
22 | return "", err
23 | }
24 | switch stmt := statement.(type) {
25 | case *Insert:
26 | return string(stmt.Table.Qualifier), nil
27 | case *Update:
28 | return string(stmt.Table.Qualifier), nil
29 | case *Delete:
30 | return string(stmt.Table.Qualifier), nil
31 | }
32 | return "", fmt.Errorf("statement '%s' is not a dml", sql)
33 | }
34 |
35 | // GetTableName returns the table name from the SimpleTableExpr
36 | // only if it's a simple expression. Otherwise, it returns "".
37 | func GetTableName(node SimpleTableExpr) string {
38 | if n, ok := node.(*TableName); ok && n.Qualifier == nil {
39 | return string(n.Name)
40 | }
41 | // sub-select or '.' expression
42 | return ""
43 | }
44 |
45 | // GetColName returns the column name, only if
46 | // it's a simple expression. Otherwise, it returns "".
47 | func GetColName(node Expr) string {
48 | if n, ok := node.(*ColName); ok {
49 | return string(n.Name)
50 | }
51 | return ""
52 | }
53 |
54 | // IsColName returns true if the ValExpr is a *ColName.
55 | func IsColName(node ValExpr) bool {
56 | _, ok := node.(*ColName)
57 | return ok
58 | }
59 |
60 | // IsVal returns true if the ValExpr is a string, number or value arg.
61 | // NULL is not considered to be a value.
62 | func IsValue(node ValExpr) bool {
63 | switch node.(type) {
64 | case StrVal, NumVal, ValArg:
65 | return true
66 | }
67 | return false
68 | }
69 |
70 | // HasINCaluse returns true if an yof the conditions has an IN clause.
71 | func HasINClause(conditions []BoolExpr) bool {
72 | for _, node := range conditions {
73 | if c, ok := node.(*ComparisonExpr); ok && c.Operator == AST_IN {
74 | return true
75 | }
76 | }
77 | return false
78 | }
79 |
80 | // IsSimpleTuple returns true if the ValExpr is a ValTuple that
81 | // contains simple values.
82 | func IsSimpleTuple(node ValExpr) bool {
83 | list, ok := node.(ValTuple)
84 | if !ok {
85 | // It's a subquery.
86 | return false
87 | }
88 | for _, n := range list {
89 | if !IsValue(n) {
90 | return false
91 | }
92 | }
93 | return true
94 | }
95 |
96 | // AsInterface converts the ValExpr to an interface. It converts
97 | // ValTuple to []interface{}, ValArg to string, StrVal to sqltypes.String,
98 | // NumVal to sqltypes.Numeric. Otherwise, it returns an error.
99 | func AsInterface(node ValExpr) (interface{}, error) {
100 | switch node := node.(type) {
101 | case ValTuple:
102 | vals := make([]interface{}, 0, len(node))
103 | for _, val := range node {
104 | v, err := AsInterface(val)
105 | if err != nil {
106 | return nil, err
107 | }
108 | vals = append(vals, v)
109 | }
110 | return vals, nil
111 | case ValArg:
112 | return string(node), nil
113 | case StrVal:
114 | return sqltypes.MakeString(node), nil
115 | case NumVal:
116 | n, err := sqltypes.BuildNumeric(string(node))
117 | if err != nil {
118 | return nil, fmt.Errorf("type mismatch: %s", err)
119 | }
120 | return n, nil
121 | }
122 | return nil, fmt.Errorf("unexpected node %v", node)
123 | }
124 |
// StringIn reports whether str equals at least one of values.
// It returns false when no values are given.
func StringIn(str string, values ...string) bool {
	for i := range values {
		if values[i] == str {
			return true
		}
	}
	return false
}
135 |
--------------------------------------------------------------------------------
/router/shard.go:
--------------------------------------------------------------------------------
1 | // Copyright 2012, Google Inc. All rights reserved.
2 | // Use of this source code is governed by a BSD-style
3 | // license that can be found in the LICENSE file.
4 |
5 | package router
6 |
7 | import (
8 | "fmt"
9 | "github.com/siddontang/mixer/hack"
10 | "hash/crc32"
11 | "strconv"
12 | )
13 |
// KeyError is the error raised (via panic) by the sharding helpers in this
// file when a key cannot be encoded, parsed or placed in a shard.
type KeyError string

// NewKeyError builds a KeyError from a fmt.Sprintf-style format.
func NewKeyError(format string, args ...interface{}) KeyError {
	return KeyError(fmt.Sprintf(format, args...))
}

// Error implements the error interface.
func (ke KeyError) Error() string {
	return string(ke)
}

// handleError is meant to be deferred by callers of the panicking helpers
// in this file. It recovers a KeyError panic and stores it in *err.
//
// Any other panic value (runtime errors, foreign panics) is re-raised
// unchanged, rather than being masked by a failed type assertion on the
// recovered value.
func handleError(err *error) {
	if x := recover(); x != nil {
		ke, ok := x.(KeyError)
		if !ok {
			panic(x)
		}
		*err = ke
	}
}
29 |
30 | func EncodeValue(value interface{}) string {
31 | switch val := value.(type) {
32 | case int:
33 | return Uint64Key(val).String()
34 | case uint64:
35 | return Uint64Key(val).String()
36 | case int64:
37 | return Uint64Key(val).String()
38 | case string:
39 | return val
40 | case []byte:
41 | return hack.String(val)
42 | }
43 | panic(NewKeyError("Unexpected key variable type %T", value))
44 | }
45 |
46 | func HashValue(value interface{}) uint64 {
47 | switch val := value.(type) {
48 | case int:
49 | return uint64(val)
50 | case uint64:
51 | return uint64(val)
52 | case int64:
53 | return uint64(val)
54 | case string:
55 | return uint64(crc32.ChecksumIEEE(hack.Slice(val)))
56 | case []byte:
57 | return uint64(crc32.ChecksumIEEE(val))
58 | }
59 | panic(NewKeyError("Unexpected key variable type %T", value))
60 | }
61 |
62 | func NumValue(value interface{}) int64 {
63 | switch val := value.(type) {
64 | case int:
65 | return int64(val)
66 | case uint64:
67 | return int64(val)
68 | case int64:
69 | return int64(val)
70 | case string:
71 | if v, err := strconv.ParseInt(val, 10, 64); err != nil {
72 | panic(NewKeyError("invalid num format %s", v))
73 | } else {
74 | return v
75 | }
76 | case []byte:
77 | if v, err := strconv.ParseInt(hack.String(val), 10, 64); err != nil {
78 | panic(NewKeyError("invalid num format %s", v))
79 | } else {
80 | return v
81 | }
82 | }
83 | panic(NewKeyError("Unexpected key variable type %T", value))
84 | }
85 |
86 | type Shard interface {
87 | FindForKey(key interface{}) int
88 | }
89 |
90 | type RangeShard interface {
91 | Shard
92 | EqualStart(key interface{}, index int) bool
93 | EqualStop(key interface{}, index int) bool
94 | }
95 |
96 | type HashShard struct {
97 | ShardNum int
98 | }
99 |
100 | func (s *HashShard) FindForKey(key interface{}) int {
101 | h := HashValue(key)
102 |
103 | return int(h % uint64(s.ShardNum))
104 | }
105 |
106 | type NumRangeShard struct {
107 | Shards []NumKeyRange
108 | }
109 |
110 | func (s *NumRangeShard) FindForKey(key interface{}) int {
111 | v := NumValue(key)
112 | for i, r := range s.Shards {
113 | if r.Contains(v) {
114 | return i
115 | }
116 | }
117 | panic(NewKeyError("Unexpected key %v, not in range", key))
118 | }
119 |
120 | func (s *NumRangeShard) EqualStart(key interface{}, index int) bool {
121 | v := NumValue(key)
122 | return s.Shards[index].Start == v
123 | }
124 | func (s *NumRangeShard) EqualStop(key interface{}, index int) bool {
125 | v := NumValue(key)
126 | return s.Shards[index].End == v
127 | }
128 |
129 | type KeyRangeShard struct {
130 | Shards []KeyRange
131 | }
132 |
133 | func (s *KeyRangeShard) FindForKey(key interface{}) int {
134 | v := KeyspaceId(EncodeValue(key))
135 | for i, r := range s.Shards {
136 | if r.Contains(v) {
137 | return i
138 | }
139 | }
140 | panic(NewKeyError("Unexpected key %v, not in range", key))
141 | }
142 |
143 | func (s *KeyRangeShard) EqualStart(key interface{}, index int) bool {
144 | v := KeyspaceId(EncodeValue(key))
145 | return s.Shards[index].Start == v
146 | }
147 | func (s *KeyRangeShard) EqualStop(key interface{}, index int) bool {
148 | v := KeyspaceId(EncodeValue(key))
149 | return s.Shards[index].End == v
150 | }
151 |
152 | type DefaultShard struct {
153 | }
154 |
155 | func (s *DefaultShard) FindForKey(key interface{}) int {
156 | return 0
157 | }
158 |
--------------------------------------------------------------------------------
/client/db.go:
--------------------------------------------------------------------------------
1 | package client
2 |
3 | import (
4 | "container/list"
5 | "fmt"
6 | . "github.com/siddontang/mixer/mysql"
7 | "sync"
8 | "sync/atomic"
9 | )
10 |
// DB is a pool of client connections to one MySQL server and database.
//
// The embedded Mutex is held around changes to idleConns. connNum is
// changed with sync/atomic adds.
type DB struct {
	sync.Mutex

	addr     string
	user     string
	password string
	db       string // database selected when connecting
	// maxIdleConns caps how many connections PushConn keeps; Open leaves
	// it at 0, which disables reuse.
	maxIdleConns int

	// idleConns holds *Conn values ready for reuse: PopConn takes from the
	// front and PushConn appends at the back.
	idleConns *list.List

	// connNum is the pool's running count of open connections.
	connNum int32
}

// Open returns a pool for the given address, credentials and database.
// No connection is dialed here, and the returned error is always nil.
func Open(addr string, user string, password string, dbName string) (*DB, error) {
	pool := &DB{
		addr:      addr,
		user:      user,
		password:  password,
		db:        dbName,
		idleConns: list.New(),
		connNum:   0,
	}
	return pool, nil
}

// Addr returns the server address the pool connects to.
func (db *DB) Addr() string {
	return db.addr
}

// String renders the pool as user:password@addr/db?maxIdleConns=N.
// NOTE(review): the password appears in clear text; avoid logging this.
func (db *DB) String() string {
	desc := fmt.Sprintf("%s:%s@%s/%s?maxIdleConns=%v",
		db.user, db.password, db.addr, db.db, db.maxIdleConns)
	return desc
}
47 |
48 | func (db *DB) Close() error {
49 | db.Lock()
50 |
51 | for {
52 | if db.idleConns.Len() > 0 {
53 | v := db.idleConns.Back()
54 | co := v.Value.(*Conn)
55 | db.idleConns.Remove(v)
56 |
57 | co.Close()
58 |
59 | } else {
60 | break
61 | }
62 | }
63 |
64 | db.Unlock()
65 |
66 | return nil
67 | }
68 |
69 | func (db *DB) Ping() error {
70 | c, err := db.PopConn()
71 | if err != nil {
72 | return err
73 | }
74 |
75 | err = c.Ping()
76 | db.PushConn(c, err)
77 | return err
78 | }
79 |
80 | func (db *DB) SetMaxIdleConnNum(num int) {
81 | db.maxIdleConns = num
82 | }
83 |
84 | func (db *DB) GetIdleConnNum() int {
85 | return db.idleConns.Len()
86 | }
87 |
88 | func (db *DB) GetConnNum() int {
89 | return int(db.connNum)
90 | }
91 |
92 | func (db *DB) newConn() (*Conn, error) {
93 | co := new(Conn)
94 |
95 | if err := co.Connect(db.addr, db.user, db.password, db.db); err != nil {
96 | return nil, err
97 | }
98 |
99 | return co, nil
100 | }
101 |
102 | func (db *DB) tryReuse(co *Conn) error {
103 | if co.IsInTransaction() {
104 | //we can not reuse a connection in transaction status
105 | if err := co.Rollback(); err != nil {
106 | return err
107 | }
108 | }
109 |
110 | if !co.IsAutoCommit() {
111 | //we can not reuse a connection not in autocomit
112 | if _, err := co.exec("set autocommit = 1"); err != nil {
113 | return err
114 | }
115 | }
116 |
117 | //connection may be set names early
118 | //we must use default utf8
119 | if co.GetCharset() != DEFAULT_CHARSET {
120 | if err := co.SetCharset(DEFAULT_CHARSET); err != nil {
121 | return err
122 | }
123 | }
124 |
125 | return nil
126 | }
127 |
128 | func (db *DB) PopConn() (co *Conn, err error) {
129 | db.Lock()
130 | if db.idleConns.Len() > 0 {
131 | v := db.idleConns.Front()
132 | co = v.Value.(*Conn)
133 | db.idleConns.Remove(v)
134 | }
135 | db.Unlock()
136 |
137 | if co != nil {
138 | if err := co.Ping(); err == nil {
139 | if err := db.tryReuse(co); err == nil {
140 | //connection may alive
141 | return co, nil
142 | }
143 | }
144 | co.Close()
145 | }
146 |
147 | co, err = db.newConn()
148 | if err == nil {
149 | atomic.AddInt32(&db.connNum, 1)
150 | }
151 | return
152 | }
153 |
154 | func (db *DB) PushConn(co *Conn, err error) {
155 | var closeConn *Conn = nil
156 |
157 | if err != nil {
158 | closeConn = co
159 | } else {
160 | if db.maxIdleConns > 0 {
161 | db.Lock()
162 |
163 | if db.idleConns.Len() >= db.maxIdleConns {
164 | v := db.idleConns.Front()
165 | closeConn = v.Value.(*Conn)
166 | db.idleConns.Remove(v)
167 | }
168 |
169 | db.idleConns.PushBack(co)
170 |
171 | db.Unlock()
172 |
173 | } else {
174 | closeConn = co
175 | }
176 |
177 | }
178 |
179 | if closeConn != nil {
180 | atomic.AddInt32(&db.connNum, -1)
181 |
182 | closeConn.Close()
183 | }
184 | }
185 |
186 | type SqlConn struct {
187 | *Conn
188 |
189 | db *DB
190 | }
191 |
192 | func (p *SqlConn) Close() {
193 | if p.Conn != nil {
194 | p.db.PushConn(p.Conn, p.Conn.pkgErr)
195 | p.Conn = nil
196 | }
197 | }
198 |
199 | func (db *DB) GetConn() (*SqlConn, error) {
200 | c, err := db.PopConn()
201 | return &SqlConn{c, db}, err
202 | }
203 |
--------------------------------------------------------------------------------
/sqlparser/parsed_query_test.go:
--------------------------------------------------------------------------------
1 | // Copyright 2012, Google Inc. All rights reserved.
2 | // Use of this source code is governed by a BSD-style
3 | // license that can be found in the LICENSE file.
4 |
5 | package sqlparser
6 |
7 | import (
8 | "testing"
9 |
10 | "github.com/siddontang/mixer/sqltypes"
11 | )
12 |
// TestParsedQuery parses each query, regenerates it through
// ParsedQuery.GenerateQuery with the case's bind and list variables, and
// compares the outcome with output. output holds the expected SQL on
// success, or the expected error text when generation should fail.
func TestParsedQuery(t *testing.T) {
	tcases := []struct {
		desc     string
		query    string
		bindVars map[string]interface{} // named variables, referenced as :name
		listVars []sqltypes.Value       // positional variables, referenced as :0, :1, ...
		output   string                 // expected SQL, or the expected error message
	}{
		{
			"no subs",
			"select * from a where id = 2",
			map[string]interface{}{
				"id": 1,
			},
			nil,
			"select * from a where id = 2",
		}, {
			"simple bindvar sub",
			"select * from a where id1 = :id1 and id2 = :id2",
			map[string]interface{}{
				"id1": 1,
				"id2": nil,
			},
			nil,
			"select * from a where id1 = 1 and id2 = null",
		}, {
			"missing bind var",
			"select * from a where id1 = :id1 and id2 = :id2",
			map[string]interface{}{
				"id1": 1,
			},
			nil,
			"missing bind var id2",
		}, {
			"unencodable bind var",
			"select * from a where id1 = :id",
			map[string]interface{}{
				"id": make([]int, 1),
			},
			nil,
			"unsupported bind variable type []int: [0]",
		}, {
			"list var sub",
			"select * from a where id = :0 and name = :1",
			nil,
			[]sqltypes.Value{
				sqltypes.MakeNumeric([]byte("1")),
				sqltypes.MakeString([]byte("aa")),
			},
			"select * from a where id = 1 and name = 'aa'",
		}, {
			"list inside bind vars",
			"select * from a where id in (:vals)",
			map[string]interface{}{
				"vals": []sqltypes.Value{
					sqltypes.MakeNumeric([]byte("1")),
					sqltypes.MakeString([]byte("aa")),
				},
			},
			nil,
			"select * from a where id in (1, 'aa')",
		}, {
			"two lists inside bind vars",
			"select * from a where id in (:vals)",
			map[string]interface{}{
				"vals": [][]sqltypes.Value{
					[]sqltypes.Value{
						sqltypes.MakeNumeric([]byte("1")),
						sqltypes.MakeString([]byte("aa")),
					},
					[]sqltypes.Value{
						sqltypes.Value{},
						sqltypes.MakeString([]byte("bb")),
					},
				},
			},
			nil,
			"select * from a where id in ((1, 'aa'), (null, 'bb'))",
		}, {
			"illega list var name",
			"select * from a where id = :0a",
			nil,
			[]sqltypes.Value{
				sqltypes.MakeNumeric([]byte("1")),
				sqltypes.MakeString([]byte("aa")),
			},
			`unexpected: strconv.ParseInt: parsing "0a": invalid syntax for 0a`,
		}, {
			"out of range list var index",
			"select * from a where id = :10",
			nil,
			[]sqltypes.Value{
				sqltypes.MakeNumeric([]byte("1")),
				sqltypes.MakeString([]byte("aa")),
			},
			"index out of range: 10",
		},
	}

	for _, tcase := range tcases {
		tree, err := Parse(tcase.query)
		if err != nil {
			t.Errorf("parse failed for %s: %v", tcase.desc, err)
			continue
		}
		// Format the parsed tree into a TrackedBuffer to obtain a
		// ParsedQuery, then substitute the variables into it.
		buf := NewTrackedBuffer(nil)
		buf.Fprintf("%v", tree)
		pq := buf.ParsedQuery()
		bytes, err := pq.GenerateQuery(tcase.bindVars, tcase.listVars)
		// Compare either the generated SQL or the error text with output.
		var got string
		if err != nil {
			got = err.Error()
		} else {
			got = string(bytes)
		}
		if got != tcase.output {
			t.Errorf("for test case: %s, got: '%s', want '%s'", tcase.desc, got, tcase.output)
		}
	}
}
133 |
134 | func TestStarParam(t *testing.T) {
135 | buf := NewTrackedBuffer(nil)
136 | buf.Fprintf("select * from a where id in (%a)", "*")
137 | pq := buf.ParsedQuery()
138 | listvars := []sqltypes.Value{
139 | sqltypes.MakeNumeric([]byte("1")),
140 | sqltypes.MakeString([]byte("aa")),
141 | }
142 | bytes, err := pq.GenerateQuery(nil, listvars)
143 | if err != nil {
144 | t.Errorf("generate failed: %v", err)
145 | return
146 | }
147 | got := string(bytes)
148 | want := "select * from a where id in (1, 'aa')"
149 | if got != want {
150 | t.Errorf("got %s, want %s", got, want)
151 | }
152 | }
153 |
--------------------------------------------------------------------------------
/client/stmt.go:
--------------------------------------------------------------------------------
1 | package client
2 |
3 | import (
4 | "encoding/binary"
5 | "fmt"
6 | . "github.com/siddontang/mixer/mysql"
7 | "math"
8 | )
9 |
10 | type Stmt struct {
11 | conn *Conn
12 | id uint32
13 | query string
14 |
15 | params int
16 | columns int
17 | }
18 |
19 | func (s *Stmt) ParamNum() int {
20 | return s.params
21 | }
22 |
23 | func (s *Stmt) ColumnNum() int {
24 | return s.columns
25 | }
26 |
27 | func (s *Stmt) Execute(args ...interface{}) (*Result, error) {
28 | if err := s.write(args...); err != nil {
29 | return nil, err
30 | }
31 |
32 | return s.conn.readResult(true)
33 | }
34 |
35 | func (s *Stmt) Close() error {
36 | if err := s.conn.writeCommandUint32(COM_STMT_CLOSE, s.id); err != nil {
37 | return err
38 | }
39 |
40 | return nil
41 | }
42 |
43 | func (s *Stmt) write(args ...interface{}) error {
44 | paramsNum := s.params
45 |
46 | if len(args) != paramsNum {
47 | return fmt.Errorf("argument mismatch, need %d but got %d", s.params, len(args))
48 | }
49 |
50 | paramTypes := make([]byte, paramsNum<<1)
51 | paramValues := make([][]byte, paramsNum)
52 |
53 | //NULL-bitmap, length: (num-params+7)
54 | nullBitmap := make([]byte, (paramsNum+7)>>3)
55 |
56 | var length int = int(1 + 4 + 1 + 4 + ((paramsNum + 7) >> 3) + 1 + (paramsNum << 1))
57 |
58 | var newParamBoundFlag byte = 0
59 |
60 | for i := range args {
61 | if args[i] == nil {
62 | nullBitmap[i/8] |= (1 << (uint(i) % 8))
63 | paramTypes[i<<1] = MYSQL_TYPE_NULL
64 | continue
65 | }
66 |
67 | newParamBoundFlag = 1
68 |
69 | switch v := args[i].(type) {
70 | case int8:
71 | paramTypes[i<<1] = MYSQL_TYPE_TINY
72 | paramValues[i] = []byte{byte(v)}
73 | case int16:
74 | paramTypes[i<<1] = MYSQL_TYPE_SHORT
75 | paramValues[i] = Uint16ToBytes(uint16(v))
76 | case int32:
77 | paramTypes[i<<1] = MYSQL_TYPE_LONG
78 | paramValues[i] = Uint32ToBytes(uint32(v))
79 | case int:
80 | paramTypes[i<<1] = MYSQL_TYPE_LONGLONG
81 | paramValues[i] = Uint64ToBytes(uint64(v))
82 | case int64:
83 | paramTypes[i<<1] = MYSQL_TYPE_LONGLONG
84 | paramValues[i] = Uint64ToBytes(uint64(v))
85 | case uint8:
86 | paramTypes[i<<1] = MYSQL_TYPE_TINY
87 | paramTypes[(i<<1)+1] = 0x80
88 | paramValues[i] = []byte{v}
89 | case uint16:
90 | paramTypes[i<<1] = MYSQL_TYPE_SHORT
91 | paramTypes[(i<<1)+1] = 0x80
92 | paramValues[i] = Uint16ToBytes(uint16(v))
93 | case uint32:
94 | paramTypes[i<<1] = MYSQL_TYPE_LONG
95 | paramTypes[(i<<1)+1] = 0x80
96 | paramValues[i] = Uint32ToBytes(uint32(v))
97 | case uint:
98 | paramTypes[i<<1] = MYSQL_TYPE_LONGLONG
99 | paramTypes[(i<<1)+1] = 0x80
100 | paramValues[i] = Uint64ToBytes(uint64(v))
101 | case uint64:
102 | paramTypes[i<<1] = MYSQL_TYPE_LONGLONG
103 | paramTypes[(i<<1)+1] = 0x80
104 | paramValues[i] = Uint64ToBytes(uint64(v))
105 | case bool:
106 | paramTypes[i<<1] = MYSQL_TYPE_TINY
107 | if v {
108 | paramValues[i] = []byte{1}
109 | } else {
110 | paramValues[i] = []byte{0}
111 |
112 | }
113 | case float32:
114 | paramTypes[i<<1] = MYSQL_TYPE_FLOAT
115 | paramValues[i] = Uint32ToBytes(math.Float32bits(v))
116 | case float64:
117 | paramTypes[i<<1] = MYSQL_TYPE_DOUBLE
118 | paramValues[i] = Uint64ToBytes(math.Float64bits(v))
119 | case string:
120 | paramTypes[i<<1] = MYSQL_TYPE_STRING
121 | paramValues[i] = append(PutLengthEncodedInt(uint64(len(v))), v...)
122 | case []byte:
123 | paramTypes[i<<1] = MYSQL_TYPE_STRING
124 | paramValues[i] = append(PutLengthEncodedInt(uint64(len(v))), v...)
125 | default:
126 | return fmt.Errorf("invalid argument type %T", args[i])
127 | }
128 |
129 | length += len(paramValues[i])
130 | }
131 |
132 | data := make([]byte, 4, 4+length)
133 |
134 | data = append(data, COM_STMT_EXECUTE)
135 | data = append(data, byte(s.id), byte(s.id>>8), byte(s.id>>16), byte(s.id>>24))
136 |
137 | //flag: CURSOR_TYPE_NO_CURSOR
138 | data = append(data, 0x00)
139 |
140 | //iteration-count, always 1
141 | data = append(data, 1, 0, 0, 0)
142 |
143 | if s.params > 0 {
144 | data = append(data, nullBitmap...)
145 |
146 | //new-params-bound-flag
147 | data = append(data, newParamBoundFlag)
148 |
149 | if newParamBoundFlag == 1 {
150 | //type of each parameter, length: num-params * 2
151 | data = append(data, paramTypes...)
152 |
153 | //value of each parameter
154 | for _, v := range paramValues {
155 | data = append(data, v...)
156 | }
157 | }
158 | }
159 |
160 | s.conn.pkg.Sequence = 0
161 |
162 | return s.conn.writePacket(data)
163 | }
164 |
165 | func (c *Conn) Prepare(query string) (*Stmt, error) {
166 | if err := c.writeCommandStr(COM_STMT_PREPARE, query); err != nil {
167 | return nil, err
168 | }
169 |
170 | data, err := c.readPacket()
171 | if err != nil {
172 | return nil, err
173 | }
174 |
175 | if data[0] == ERR_HEADER {
176 | return nil, c.handleErrorPacket(data)
177 | } else if data[0] != OK_HEADER {
178 | return nil, ErrMalformPacket
179 | }
180 |
181 | s := new(Stmt)
182 | s.conn = c
183 |
184 | pos := 1
185 |
186 | //for statement id
187 | s.id = binary.LittleEndian.Uint32(data[pos:])
188 | pos += 4
189 |
190 | //number columns
191 | s.columns = int(binary.LittleEndian.Uint16(data[pos:]))
192 | pos += 2
193 |
194 | //number params
195 | s.params = int(binary.LittleEndian.Uint16(data[pos:]))
196 | pos += 2
197 |
198 | //warnings
199 | //warnings = binary.LittleEndian.Uint16(data[pos:])
200 |
201 | if s.params > 0 {
202 | if err := s.conn.readUntilEOF(); err != nil {
203 | return nil, err
204 | }
205 | }
206 |
207 | if s.columns > 0 {
208 | if err := s.conn.readUntilEOF(); err != nil {
209 | return nil, err
210 | }
211 | }
212 |
213 | return s, nil
214 | }
215 |
--------------------------------------------------------------------------------
/proxy/conn_show.go:
--------------------------------------------------------------------------------
1 | package proxy
2 |
3 | import (
4 | "fmt"
5 | "github.com/siddontang/go-log/log"
6 | "github.com/siddontang/mixer/hack"
7 | . "github.com/siddontang/mixer/mysql"
8 | "github.com/siddontang/mixer/sqlparser"
9 | "sort"
10 | "strings"
11 | "time"
12 | )
13 |
14 | func (c *Conn) handleShow(sql string, stmt *sqlparser.Show) error {
15 | var err error
16 | var r *Resultset
17 | switch strings.ToLower(stmt.Section) {
18 | case "databases":
19 | r, err = c.handleShowDatabases()
20 | case "tables":
21 | r, err = c.handleShowTables(sql, stmt)
22 | case "proxy":
23 | r, err = c.handleShowProxy(sql, stmt)
24 | default:
25 | err = fmt.Errorf("unsupport show %s now", sql)
26 | }
27 |
28 | if err != nil {
29 | return err
30 | }
31 |
32 | return c.writeResultset(c.status, r)
33 | }
34 |
35 | func (c *Conn) handleShowDatabases() (*Resultset, error) {
36 | dbs := make([]interface{}, 0, len(c.server.schemas))
37 | for key := range c.server.schemas {
38 | dbs = append(dbs, key)
39 | }
40 |
41 | return c.buildSimpleShowResultset(dbs, "Database")
42 | }
43 |
44 | func (c *Conn) handleShowTables(sql string, stmt *sqlparser.Show) (*Resultset, error) {
45 | s := c.schema
46 | if stmt.From != nil {
47 | db := nstring(stmt.From)
48 | s = c.server.getSchema(db)
49 | }
50 |
51 | if s == nil {
52 | return nil, NewDefaultError(ER_NO_DB_ERROR)
53 | }
54 |
55 | var tables []string
56 | tmap := map[string]struct{}{}
57 | for _, n := range s.nodes {
58 | co, err := n.getMasterConn()
59 | if err != nil {
60 | return nil, err
61 | }
62 |
63 | if err := co.UseDB(s.db); err != nil {
64 | co.Close()
65 | return nil, err
66 | }
67 |
68 | if r, err := co.Execute(sql); err != nil {
69 | co.Close()
70 | return nil, err
71 | } else {
72 | co.Close()
73 | for i := 0; i < r.RowNumber(); i++ {
74 | n, _ := r.GetString(i, 0)
75 | if _, ok := tmap[n]; !ok {
76 | tables = append(tables, n)
77 | }
78 | }
79 | }
80 | }
81 |
82 | sort.Strings(tables)
83 |
84 | values := make([]interface{}, len(tables))
85 | for i := range tables {
86 | values[i] = tables[i]
87 | }
88 |
89 | return c.buildSimpleShowResultset(values, fmt.Sprintf("Tables_in_%s", s.db))
90 | }
91 |
92 | func (c *Conn) handleShowProxy(sql string, stmt *sqlparser.Show) (*Resultset, error) {
93 | var err error
94 | var r *Resultset
95 | switch strings.ToLower(stmt.Key) {
96 | case "config":
97 | r, err = c.handleShowProxyConfig()
98 | case "status":
99 | r, err = c.handleShowProxyStatus(sql, stmt)
100 | default:
101 | err = fmt.Errorf("Unsupport show proxy [%v] yet, just support [config|status] now.", stmt.Key)
102 | log.Warn(err.Error())
103 | return nil, err
104 | }
105 | return r, err
106 | }
107 |
108 | func (c *Conn) handleShowProxyConfig() (*Resultset, error) {
109 | var names []string = []string{"Section", "Key", "Value"}
110 | var rows [][]string
111 | const (
112 | Column = 3
113 | )
114 |
115 | rows = append(rows, []string{"Global_Config", "Addr", c.server.cfg.Addr})
116 | rows = append(rows, []string{"Global_Config", "User", c.server.cfg.User})
117 | rows = append(rows, []string{"Global_Config", "Password", c.server.cfg.Password})
118 | rows = append(rows, []string{"Global_Config", "LogLevel", c.server.cfg.LogLevel})
119 | rows = append(rows, []string{"Global_Config", "Schemas_Count", fmt.Sprintf("%d", len(c.server.schemas))})
120 | rows = append(rows, []string{"Global_Config", "Nodes_Count", fmt.Sprintf("%d", len(c.server.nodes))})
121 |
122 | for db, schema := range c.server.schemas {
123 | rows = append(rows, []string{"Schemas", "DB", db})
124 |
125 | var nodeNames []string
126 | var nodeRows [][]string
127 | for name, node := range schema.nodes {
128 | nodeNames = append(nodeNames, name)
129 | var nodeSection = fmt.Sprintf("Schemas[%s]-Node[ %v ]", db, name)
130 |
131 | if node.master != nil {
132 | nodeRows = append(nodeRows, []string{nodeSection, "Master", node.master.String()})
133 | }
134 |
135 | if node.slave != nil {
136 | nodeRows = append(nodeRows, []string{nodeSection, "Slave", node.slave.String()})
137 | }
138 | nodeRows = append(nodeRows, []string{nodeSection, "Last_Master_Ping", fmt.Sprintf("%v", time.Unix(node.lastMasterPing, 0))})
139 |
140 | nodeRows = append(nodeRows, []string{nodeSection, "Last_Slave_Ping", fmt.Sprintf("%v", time.Unix(node.lastSlavePing, 0))})
141 |
142 | nodeRows = append(nodeRows, []string{nodeSection, "down_after_noalive", fmt.Sprintf("%v", node.downAfterNoAlive)})
143 |
144 | }
145 | rows = append(rows, []string{fmt.Sprintf("Schemas[%s]", db), "Nodes_List", strings.Join(nodeNames, ",")})
146 |
147 | var defaultRule = schema.rule.DefaultRule
148 | if defaultRule.DB == db {
149 | if defaultRule.DB == db {
150 | rows = append(rows, []string{fmt.Sprintf("Schemas[%s]_Rule_Default", db),
151 | "Default_Table", defaultRule.String()})
152 | }
153 | }
154 | for tb, r := range schema.rule.Rules {
155 | if r.DB == db {
156 | rows = append(rows, []string{fmt.Sprintf("Schemas[%s]_Rule_Table", db),
157 | fmt.Sprintf("Table[ %s ]", tb), r.String()})
158 | }
159 | }
160 |
161 | rows = append(rows, nodeRows...)
162 |
163 | }
164 |
165 | var values [][]interface{} = make([][]interface{}, len(rows))
166 | for i := range rows {
167 | values[i] = make([]interface{}, Column)
168 | for j := range rows[i] {
169 | values[i][j] = rows[i][j]
170 | }
171 | }
172 |
173 | return c.buildResultset(names, values)
174 | }
175 |
176 | func (c *Conn) handleShowProxyStatus(sql string, stmt *sqlparser.Show) (*Resultset, error) {
177 | // TODO: handle like_or_where expr
178 | return nil, nil
179 | }
180 |
181 | func (c *Conn) buildSimpleShowResultset(values []interface{}, name string) (*Resultset, error) {
182 |
183 | r := new(Resultset)
184 |
185 | field := &Field{}
186 |
187 | field.Name = hack.Slice(name)
188 | field.Charset = 33
189 | field.Type = MYSQL_TYPE_VAR_STRING
190 |
191 | r.Fields = []*Field{field}
192 |
193 | var row []byte
194 | var err error
195 |
196 | for _, value := range values {
197 | row, err = formatValue(value)
198 | if err != nil {
199 | return nil, err
200 | }
201 | r.RowDatas = append(r.RowDatas,
202 | PutLengthEncodedString(row))
203 | }
204 |
205 | return r, nil
206 | }
207 |
--------------------------------------------------------------------------------
/proxy/conn_stmt_test.go:
--------------------------------------------------------------------------------
1 | package proxy
2 |
3 | import (
4 | "testing"
5 | )
6 |
7 | func TestStmt_DropTable(t *testing.T) {
8 | server := newTestServer(t)
9 | n := server.nodes["node1"]
10 | c, err := n.getMasterConn()
11 | if err != nil {
12 | t.Fatal(err)
13 | }
14 | c.UseDB("mixer")
15 | if _, err := c.Execute(`drop table if exists mixer_test_proxy_stmt`); err != nil {
16 | t.Fatal(err)
17 | }
18 | c.Close()
19 | }
20 |
21 | func TestStmt_CreateTable(t *testing.T) {
22 | str := `CREATE TABLE IF NOT EXISTS mixer_test_proxy_stmt (
23 | id BIGINT(64) UNSIGNED NOT NULL,
24 | str VARCHAR(256),
25 | f DOUBLE,
26 | e enum("test1", "test2"),
27 | u tinyint unsigned,
28 | i tinyint,
29 | PRIMARY KEY (id)
30 | ) ENGINE=InnoDB DEFAULT CHARSET=utf8`
31 |
32 | server := newTestServer(t)
33 | n := server.nodes["node1"]
34 | c, err := n.getMasterConn()
35 | if err != nil {
36 | t.Fatal(err)
37 | }
38 |
39 | c.UseDB("mixer")
40 | defer c.Close()
41 | if _, err := c.Execute(str); err != nil {
42 | t.Fatal(err)
43 | }
44 | }
45 |
46 | func TestStmt_Insert(t *testing.T) {
47 | str := `insert into mixer_test_proxy_stmt (id, str, f, e, u, i) values (?, ?, ?, ?, ?, ?)`
48 |
49 | c := newTestDBConn(t)
50 | defer c.Close()
51 |
52 | s, err := c.Prepare(str)
53 |
54 | if err != nil {
55 | t.Fatal(err)
56 | }
57 |
58 | if pkg, err := s.Execute(1, "a", 3.14, "test1", 255, -127); err != nil {
59 | t.Fatal(err)
60 | } else {
61 | if pkg.AffectedRows != 1 {
62 | t.Fatal(pkg.AffectedRows)
63 | }
64 | }
65 |
66 | s.Close()
67 | }
68 |
69 | func TestStmt_Select(t *testing.T) {
70 | str := `select str, f, e from mixer_test_proxy_stmt where id = ?`
71 |
72 | c := newTestDBConn(t)
73 | defer c.Close()
74 |
75 | s, err := c.Prepare(str)
76 |
77 | if err != nil {
78 | t.Fatal(err)
79 | }
80 |
81 | if result, err := s.Execute(1); err != nil {
82 | t.Fatal(err)
83 | } else {
84 | if len(result.Values) != 1 {
85 | t.Fatal(len(result.Values))
86 | }
87 |
88 | if len(result.Fields) != 3 {
89 | t.Fatal(len(result.Fields))
90 | }
91 |
92 | if str, _ := result.GetString(0, 0); str != "a" {
93 | t.Fatal("invalid str", str)
94 | }
95 |
96 | if f, _ := result.GetFloat(0, 1); f != float64(3.14) {
97 | t.Fatal("invalid f", f)
98 | }
99 |
100 | if e, _ := result.GetString(0, 2); e != "test1" {
101 | t.Fatal("invalid e", e)
102 | }
103 |
104 | if str, _ := result.GetStringByName(0, "str"); str != "a" {
105 | t.Fatal("invalid str", str)
106 | }
107 |
108 | if f, _ := result.GetFloatByName(0, "f"); f != float64(3.14) {
109 | t.Fatal("invalid f", f)
110 | }
111 |
112 | if e, _ := result.GetStringByName(0, "e"); e != "test1" {
113 | t.Fatal("invalid e", e)
114 | }
115 |
116 | }
117 |
118 | s.Close()
119 | }
120 |
121 | func TestStmt_NULL(t *testing.T) {
122 | str := `insert into mixer_test_proxy_stmt (id, str, f, e) values (?, ?, ?, ?)`
123 |
124 | c := newTestDBConn(t)
125 | defer c.Close()
126 |
127 | s, err := c.Prepare(str)
128 |
129 | if err != nil {
130 | t.Fatal(err)
131 | }
132 |
133 | if pkg, err := s.Execute(2, nil, 3.14, nil); err != nil {
134 | t.Fatal(err)
135 | } else {
136 | if pkg.AffectedRows != 1 {
137 | t.Fatal(pkg.AffectedRows)
138 | }
139 | }
140 |
141 | s.Close()
142 |
143 | str = `select * from mixer_test_proxy_stmt where id = ?`
144 | s, err = c.Prepare(str)
145 |
146 | if err != nil {
147 | t.Fatal(err)
148 | }
149 |
150 | if r, err := s.Execute(2); err != nil {
151 | t.Fatal(err)
152 | } else {
153 | if b, err := r.IsNullByName(0, "id"); err != nil {
154 | t.Fatal(err)
155 | } else if b == true {
156 | t.Fatal(b)
157 | }
158 |
159 | if b, err := r.IsNullByName(0, "str"); err != nil {
160 | t.Fatal(err)
161 | } else if b == false {
162 | t.Fatal(b)
163 | }
164 |
165 | if b, err := r.IsNullByName(0, "f"); err != nil {
166 | t.Fatal(err)
167 | } else if b == true {
168 | t.Fatal(b)
169 | }
170 |
171 | if b, err := r.IsNullByName(0, "e"); err != nil {
172 | t.Fatal(err)
173 | } else if b == false {
174 | t.Fatal(b)
175 | }
176 | }
177 |
178 | s.Close()
179 | }
180 |
181 | func TestStmt_Unsigned(t *testing.T) {
182 | str := `insert into mixer_test_proxy_stmt (id, u) values (?, ?)`
183 |
184 | c := newTestDBConn(t)
185 | defer c.Close()
186 |
187 | s, err := c.Prepare(str)
188 |
189 | if err != nil {
190 | t.Fatal(err)
191 | }
192 |
193 | if pkg, err := s.Execute(3, uint8(255)); err != nil {
194 | t.Fatal(err)
195 | } else {
196 | if pkg.AffectedRows != 1 {
197 | t.Fatal(pkg.AffectedRows)
198 | }
199 | }
200 |
201 | s.Close()
202 |
203 | str = `select u from mixer_test_proxy_stmt where id = ?`
204 |
205 | s, err = c.Prepare(str)
206 | if err != nil {
207 | t.Fatal(err)
208 | }
209 |
210 | if r, err := s.Execute(3); err != nil {
211 | t.Fatal(err)
212 | } else {
213 | if u, err := r.GetUint(0, 0); err != nil {
214 | t.Fatal(err)
215 | } else if u != uint64(255) {
216 | t.Fatal(u)
217 | }
218 | }
219 |
220 | s.Close()
221 | }
222 |
223 | func TestStmt_Signed(t *testing.T) {
224 | str := `insert into mixer_test_proxy_stmt (id, i) values (?, ?)`
225 |
226 | c := newTestDBConn(t)
227 | defer c.Close()
228 |
229 | s, err := c.Prepare(str)
230 |
231 | if err != nil {
232 | t.Fatal(err)
233 | }
234 |
235 | if _, err := s.Execute(4, 127); err != nil {
236 | t.Fatal(err)
237 | }
238 |
239 | if _, err := s.Execute(uint64(18446744073709551516), int8(-128)); err != nil {
240 | t.Fatal(err)
241 | }
242 |
243 | s.Close()
244 |
245 | }
246 |
247 | func TestStmt_Trans(t *testing.T) {
248 | c1 := newTestDBConn(t)
249 | defer c1.Close()
250 |
251 | if _, err := c1.Execute(`insert into mixer_test_proxy_stmt (id, str) values (1002, "abc")`); err != nil {
252 | t.Fatal(err)
253 | }
254 |
255 | var err error
256 | if err = c1.Begin(); err != nil {
257 | t.Fatal(err)
258 | }
259 |
260 | str := `select str from mixer_test_proxy_stmt where id = ?`
261 |
262 | s, err := c1.Prepare(str)
263 | if err != nil {
264 | t.Fatal(err)
265 | }
266 |
267 | if _, err := s.Execute(1002); err != nil {
268 | t.Fatal(err)
269 | }
270 |
271 | if err := c1.Commit(); err != nil {
272 | t.Fatal(err)
273 | }
274 |
275 | if r, err := s.Execute(1002); err != nil {
276 | t.Fatal(err)
277 | } else {
278 | if str, _ := r.GetString(0, 0); str != `abc` {
279 | t.Fatal(str)
280 | }
281 | }
282 |
283 | if err := s.Close(); err != nil {
284 | t.Fatal(err)
285 | }
286 | }
287 |
--------------------------------------------------------------------------------
/doc/mysql-proxy/scripting.txt:
--------------------------------------------------------------------------------
1 | Hooks
2 | =====
3 |
4 | connect_server
5 | --------------
6 |
7 | read_auth
8 | ---------
9 |
10 | read_auth_result
11 | ----------------
12 |
13 | read_query
14 | ----------
15 |
16 | read_query_result
17 | -----------------
18 |
19 | disconnect_client
20 | -----------------
21 |
22 | Modules
23 | =======
24 |
25 | mysql.proto
26 | -----------
27 |
28 | The ``mysql.proto`` module provides encoders and decoders for the packets exchanged between client and server.
29 |
30 |
31 | from_err_packet
32 | ...............
33 |
34 | Decodes an ERR packet into a table.
35 |
36 | Parameters:
37 |
38 | ``packet``
39 | (string) mysql packet
40 |
41 |
42 | On success it returns a table containing:
43 |
44 | ``errmsg``
45 | (string)
46 |
47 | ``sqlstate``
48 | (string)
49 |
50 | ``errcode``
51 | (int)
52 |
53 | Otherwise it raises an error.
54 |
55 | to_err_packet
56 | .............
57 |
58 | Encodes a table describing an ERR packet into a MySQL packet.
59 |
60 | Parameters:
61 |
62 | ``err``
63 | (table)
64 |
65 | ``errmsg``
66 | (string)
67 |
68 | ``sqlstate``
69 | (string)
70 |
71 | ``errcode``
72 | (int)
73 |
74 | into a MySQL packet.
75 |
76 | Returns a string.
77 |
78 | from_ok_packet
79 | ..............
80 |
81 | Decodes an OK packet.
82 |
83 | ``packet``
84 | (string) mysql packet
85 |
86 |
87 | On success it returns a table containing:
88 |
89 | ``server_status``
90 | (int) bit-mask of the connection status
91 |
92 | ``insert_id``
93 | (int) last used insert id
94 |
95 | ``warnings``
96 | (int) number of warnings for the last executed statement
97 |
98 | ``affected_rows``
99 | (int) rows affected by the last statement
100 |
101 | Otherwise it raises an error.
102 |
103 |
104 | to_ok_packet
105 | ............
106 |
107 | Encodes an OK packet.
108 |
109 | from_eof_packet
110 | ...............
111 |
112 | Decodes a EOF-packet
113 |
114 | Parameters:
115 |
116 | ``packet``
117 | (string) mysql packet
118 |
119 |
120 | On success it returns a table containing:
121 |
122 | ``server_status``
123 | (int) bit-mask of the connection status
124 |
125 | ``warnings``
126 | (int)
127 |
128 | Otherwise it raises an error.
129 |
130 |
131 | to_eof_packet
132 | .............
133 |
134 | from_challenge_packet
135 | .....................
136 |
137 | Decodes a auth-challenge-packet
138 |
139 | Parameters:
140 |
141 | ``packet``
142 | (string) mysql packet
143 |
144 | On success it returns a table containing:
145 |
146 | ``protocol_version``
147 | (int) version of the mysql protocol, usually 10
148 |
149 | ``server_version``
150 | (int) version of the server as integer: 50506 is MySQL 5.5.6
151 |
152 | ``thread_id``
153 | (int) connection id
154 |
155 | ``capabilities``
156 | (int) bit-mask of the server capabilities
157 |
158 | ``charset``
159 | (int) server default character-set
160 |
161 | ``server_status``
162 | (int) bit-mask of the connection-status
163 |
164 | ``challenge``
165 | (string) password challenge
166 |
167 |
168 | to_challenge_packet
169 | ...................
170 |
171 | Encode a auth-response-packet
172 |
173 | from_response_packet
174 | ....................
175 |
175 | Decodes an auth-response packet.
177 |
178 | Parameters:
179 |
180 | ``packet``
181 | (string) mysql packet
182 |
183 |
184 | to_response_packet
185 | ..................
186 |
187 | from_masterinfo_string
188 | ......................
189 |
190 | Decodes the content of the ``master.info`` file.
191 |
192 |
193 | to_masterinfo_string
194 | ....................
195 |
196 | from_stmt_prepare_packet
197 | ........................
198 |
199 | Decodes a COM_STMT_PREPARE-packet
200 |
201 | Parameters:
202 |
203 | ``packet``
204 | (string) mysql packet
205 |
206 |
207 | On success it returns a table containing:
208 |
209 | ``stmt_text``
210 | (string)
211 | text of the prepared statement
212 |
213 | Otherwise it raises an error.
214 |
215 | from_stmt_prepare_ok_packet
216 | ...........................
217 |
218 | Decodes the OK packet sent in response to a COM_STMT_PREPARE.
219 |
220 | Parameters:
221 |
222 | ``packet``
223 | (string) mysql packet
224 |
225 |
226 | On success it returns a table containing:
227 |
228 | ``stmt_id``
229 | (int) statement-id
230 |
231 | ``num_columns``
232 | (int) number of columns in the resultset
233 |
234 | ``num_params``
235 | (int) number of parameters
236 |
237 | ``warnings``
238 | (int) warnings generated by the prepare statement
239 |
240 | Otherwise it raises an error.
241 |
242 |
243 | from_stmt_execute_packet
244 | ........................
245 |
246 | Decodes a COM_STMT_EXECUTE-packet
247 |
248 | Parameters:
249 |
250 | ``packet``
251 | (string) mysql packet
252 |
253 | ``num_params``
254 | (int) number of parameters of the corresponding prepared statement
255 |
256 | On success it returns a table containing:
257 |
258 | ``stmt_id``
259 | (int) statement-id
260 |
261 | ``flags``
262 | (int) flags describing the kind of cursor used
263 |
264 | ``iteration_count``
265 | (int) iteration count: always 1
266 |
267 | ``new_params_bound``
268 | (bool)
269 |
270 | ``params``
271 | (nil, table)
272 | number-index array of parameters if ``new_params_bound`` is ``true``
273 |
274 | Each param is a table of:
275 |
276 | ``type``
277 | (int)
278 | MYSQL_TYPE_INT, MYSQL_TYPE_STRING ... and so on
279 |
280 | ``value``
281 | (nil, number, string)
282 | if the value is NULL, it is ``nil``
283 | if it is a number (_INT, _DOUBLE, ...) it is a ``number``
284 | otherwise it is a ``string``
285 |
286 | If decoding fails it raises an error.
287 |
288 | To get the ``num_params`` for this function, you have to track the number of parameters as returned
289 | by the `from_stmt_prepare_ok_packet`_. Use `stmt_id_from_stmt_execute_packet`_ to get the ``statement-id`` from
290 | the COM_STMT_EXECUTE packet and lookup your tracked information.
291 |
292 | stmt_id_from_stmt_execute_packet
293 | ................................
294 |
295 | Decodes statement-id from a COM_STMT_EXECUTE-packet
296 |
297 | Parameters:
298 |
299 | ``packet``
300 | (string) mysql packet
301 |
302 |
303 | On success it returns the ``statement-id`` as ``int``.
304 |
305 | Otherwise it raises an error.
306 |
307 | from_stmt_close_packet
308 | ......................
309 |
310 | Decodes a COM_STMT_CLOSE-packet
311 |
312 | Parameters:
313 |
314 | ``packet``
315 | (string) mysql packet
316 |
317 |
318 | On success it returns a table containing:
319 |
320 | ``stmt_id``
321 | (int)
322 | statement-id that shall be closed
323 |
324 | Otherwise it raises an error.
325 |
326 |
327 |
--------------------------------------------------------------------------------
/router/key.go:
--------------------------------------------------------------------------------
1 | // Copyright 2012, Google Inc. All rights reserved.
2 | // Use of this source code is governed by a BSD-style
3 | // license that can be found in the LICENSE file.
4 |
5 | package router
6 |
7 | import (
8 | "bytes"
9 | "encoding/binary"
10 | "encoding/hex"
11 | "fmt"
12 | "strings"
13 | )
14 |
//
// KeyspaceId definitions
//

var (
	// MinKey is smaller than all KeyspaceId (the value really is).
	MinKey = KeyspaceId("")

	// MaxKey is bigger than all KeyspaceId (by convention). It has the same
	// value as MinKey, so range code must check for it explicitly.
	MaxKey = KeyspaceId("")
)

// KeyspaceId is the type we base sharding on.
type KeyspaceId string

// Hex returns kid hex-encoded with upper-case digits.
func (kid KeyspaceId) Hex() HexKeyspaceId {
	encoded := hex.EncodeToString([]byte(kid))
	return HexKeyspaceId(strings.ToUpper(encoded))
}

//
// Uint64Key definitions
//

// Uint64Key is a uint64 that can be converted into a KeyspaceId.
type Uint64Key uint64

// String returns the 8-byte big-endian encoding of i as a string.
func (i Uint64Key) String() string {
	var buf bytes.Buffer
	binary.Write(&buf, binary.BigEndian, uint64(i))
	return buf.String()
}

// KeyspaceId returns the KeyspaceId associated with a Uint64Key, i.e. its
// big-endian byte encoding.
func (i Uint64Key) KeyspaceId() KeyspaceId {
	encoded := i.String()
	return KeyspaceId(encoded)
}

// HexKeyspaceId is the hex represention of a KeyspaceId.
type HexKeyspaceId string

// Unhex decodes hkid (hex digits of either case) into a KeyspaceId. On a
// decoding error it returns the empty KeyspaceId and the error.
func (hkid HexKeyspaceId) Unhex() (KeyspaceId, error) {
	raw, err := hex.DecodeString(string(hkid))
	if err != nil {
		return KeyspaceId(""), err
	}
	return KeyspaceId(raw), nil
}
61 | }
62 |
//
// KeyspaceIdType definitions
//

// KeyspaceIdType represents the type of the KeyspaceId.
// Usually we don't care, but some parts of the code will need that info.
type KeyspaceIdType string

const (
	// KIT_UNSET means no type is set for this KeyspaceId.
	KIT_UNSET KeyspaceIdType = ""

	// KIT_UINT64 means a uint64 value is used; this is represented as
	// 'unsigned bigint' in mysql.
	KIT_UINT64 KeyspaceIdType = "uint64"

	// KIT_BYTES means a string of bytes is used; this is represented as
	// 'varbinary' in mysql.
	KIT_BYTES KeyspaceIdType = "bytes"
)

// AllKeyspaceIdTypes lists every KeyspaceIdType constant defined above.
var AllKeyspaceIdTypes = []KeyspaceIdType{KIT_UNSET, KIT_UINT64, KIT_BYTES}

// IsKeyspaceIdTypeInList reports whether typ appears in types.
// Use it with AllKeyspaceIdTypes for instance.
func IsKeyspaceIdTypeInList(typ KeyspaceIdType, types []KeyspaceIdType) bool {
	found := false
	for i := 0; !found && i < len(types); i++ {
		found = types[i] == typ
	}
	return found
}
100 |
101 | //
102 | // KeyRange definitions
103 | //
104 |
105 | // KeyRange is an interval of KeyspaceId values. It contains Start,
106 | // but excludes End. In other words, it is: [Start, End[
107 | type KeyRange struct {
108 | Start KeyspaceId
109 | End KeyspaceId
110 | }
111 |
112 | func (kr KeyRange) MapKey() string {
113 | return string(kr.Start) + "-" + string(kr.End)
114 | }
115 |
116 | func (kr KeyRange) Contains(i KeyspaceId) bool {
117 | return kr.Start <= i && (kr.End == MaxKey || i < kr.End)
118 | }
119 |
120 | func (kr KeyRange) String() string {
121 | return fmt.Sprintf("{Start: %v, End: %v}", string(kr.Start.Hex()), string(kr.End.Hex()))
122 | }
123 |
124 | // Parse a start and end hex values and build a KeyRange
125 | func ParseKeyRangeParts(start, end string) (KeyRange, error) {
126 | s, err := HexKeyspaceId(start).Unhex()
127 | if err != nil {
128 | return KeyRange{}, err
129 | }
130 | e, err := HexKeyspaceId(end).Unhex()
131 | if err != nil {
132 | return KeyRange{}, err
133 | }
134 | return KeyRange{Start: s, End: e}, nil
135 | }
136 |
137 | // Returns true if the KeyRange does not cover the entire space.
138 | func (kr KeyRange) IsPartial() bool {
139 | return !(kr.Start == MinKey && kr.End == MaxKey)
140 | }
141 |
142 | // KeyRangesIntersect returns true if some Keyspace values exist in both ranges.
143 | //
144 | // See: http://stackoverflow.com/questions/4879315/what-is-a-tidy-algorithm-to-find-overlapping-intervals
145 | // two segments defined as (a,b) and (c,d) (with a c) && (a < d)
147 | // overlap = min(b, d) - max(c, a)
148 | func KeyRangesIntersect(first, second KeyRange) bool {
149 | return (first.End == MaxKey || second.Start < first.End) &&
150 | (second.End == MaxKey || first.Start < second.End)
151 | }
152 |
153 | // KeyRangesOverlap returns the overlap between two KeyRanges.
154 | // They need to overlap, otherwise an error is returned.
155 | func KeyRangesOverlap(first, second KeyRange) (KeyRange, error) {
156 | if !KeyRangesIntersect(first, second) {
157 | return KeyRange{}, fmt.Errorf("KeyRanges %v and %v don't overlap", first, second)
158 | }
159 | // compute max(c,a) and min(b,d)
160 | // start with (a,b)
161 | result := first
162 | // if c > a, then use c
163 | if second.Start > first.Start {
164 | result.Start = second.Start
165 | }
166 | // if b is maxed out, or
167 | // (d is not maxed out and d < b)
168 | // ^ valid test as neither b nor d are max
169 | // then use d
170 | if first.End == MaxKey || (second.End != MaxKey && second.End < first.End) {
171 | result.End = second.End
172 | }
173 | return result, nil
174 | }
175 |
176 | // ParseShardingSpec parses a string that describes a sharding
177 | // specification. a-b-c-d will be parsed as a-b, b-c, c-d. The empty
178 | // string may serve both as the start and end of the keyspace: -a-b-
179 | // will be parsed as start-a, a-b, b-end.
180 | func ParseShardingSpec(spec string) ([]KeyRange, error) {
181 | parts := strings.Split(spec, "-")
182 | if len(parts) == 1 {
183 | return nil, fmt.Errorf("malformed spec: doesn't define a range: %q", spec)
184 | }
185 | old := parts[0]
186 | ranges := make([]KeyRange, len(parts)-1)
187 |
188 | for i, p := range parts[1:] {
189 | if p == "" && i != (len(parts)-2) {
190 | return nil, fmt.Errorf("malformed spec: MinKey/MaxKey cannot be in the middle of the spec: %q", spec)
191 | }
192 | if p != "" && p <= old {
193 | return nil, fmt.Errorf("malformed spec: shard limits should be in order: %q", spec)
194 | }
195 | s, err := HexKeyspaceId(old).Unhex()
196 | if err != nil {
197 | return nil, err
198 | }
199 | e, err := HexKeyspaceId(p).Unhex()
200 | if err != nil {
201 | return nil, err
202 | }
203 | ranges[i] = KeyRange{Start: s, End: e}
204 | old = p
205 | }
206 | return ranges, nil
207 | }
208 |
--------------------------------------------------------------------------------
/proxy/node.go:
--------------------------------------------------------------------------------
1 | package proxy
2 |
3 | import (
4 | "fmt"
5 | "github.com/siddontang/go-log/log"
6 | "github.com/siddontang/mixer/client"
7 | "github.com/siddontang/mixer/config"
8 | "sync"
9 | "time"
10 | )
11 |
// Role names for a node's two backend databases.
const (
	// Master names the master database role.
	Master = "master"

	// Slave names the slave database role.
	Slave = "slave"
)
16 |
17 | type Node struct {
18 | sync.Mutex
19 |
20 | server *Server
21 |
22 | cfg config.NodeConfig
23 |
24 | //running master db
25 | db *client.DB
26 |
27 | master *client.DB
28 | slave *client.DB
29 |
30 | downAfterNoAlive time.Duration
31 |
32 | lastMasterPing int64
33 | lastSlavePing int64
34 | }
35 |
36 | func (n *Node) run() {
37 | //to do
38 | //1 check connection alive
39 | //2 check remove mysql server alive
40 |
41 | t := time.NewTicker(3000 * time.Second)
42 | defer t.Stop()
43 |
44 | n.lastMasterPing = time.Now().Unix()
45 | n.lastSlavePing = n.lastMasterPing
46 | for {
47 | select {
48 | case <-t.C:
49 | n.checkMaster()
50 | n.checkSlave()
51 | }
52 | }
53 | }
54 |
55 | func (n *Node) String() string {
56 | return n.cfg.Name
57 | }
58 |
59 | func (n *Node) getMasterConn() (*client.SqlConn, error) {
60 | n.Lock()
61 | db := n.db
62 | n.Unlock()
63 |
64 | if db == nil {
65 | return nil, fmt.Errorf("master is down")
66 | }
67 |
68 | return db.GetConn()
69 | }
70 |
71 | func (n *Node) getSelectConn() (*client.SqlConn, error) {
72 | var db *client.DB
73 |
74 | n.Lock()
75 | if n.cfg.RWSplit && n.slave != nil {
76 | db = n.slave
77 | } else {
78 | db = n.db
79 | }
80 | n.Unlock()
81 |
82 | if db == nil {
83 | return nil, fmt.Errorf("no alive mysql server")
84 | }
85 |
86 | return db.GetConn()
87 | }
88 |
// checkMaster pings the routing master. On success it records the ping
// time. On failure, once the master has been unreachable for longer than
// downAfterNoAlive, it takes the master down.
func (n *Node) checkMaster() {
	n.Lock()
	db := n.db
	n.Unlock()

	if db == nil {
		log.Info("no master avaliable")
		return
	}

	err := db.Ping()
	if err == nil {
		n.lastMasterPing = time.Now().Unix()
		return
	}
	log.Error("%s ping master %s error %s", n, db.Addr(), err.Error())

	// lastMasterPing is in Unix seconds, so compare against the limit in
	// seconds too. int64(Duration) is nanoseconds and would never be exceeded.
	limit := int64(n.downAfterNoAlive / time.Second)
	if limit > 0 && time.Now().Unix()-n.lastMasterPing > limit {
		// Log the pool we actually pinged; n.master may already be nil.
		log.Error("%s down master db %s", n, db.Addr())

		n.downMaster()
	}
}
112 |
// checkSlave pings the slave, if any. A slave that has not answered for
// longer than downAfterNoAlive is taken down.
func (n *Node) checkSlave() {
	// upSlave/downSlave swap n.slave under the lock, so read it the same way.
	n.Lock()
	db := n.slave
	n.Unlock()

	if db == nil {
		return
	}

	if err := db.Ping(); err != nil {
		log.Error("%s ping slave %s error %s", n, db.Addr(), err.Error())
	} else {
		n.lastSlavePing = time.Now().Unix()
	}

	// Both sides in seconds: lastSlavePing is a Unix timestamp.
	limit := int64(n.downAfterNoAlive / time.Second)
	if limit > 0 && time.Now().Unix()-n.lastSlavePing > limit {
		log.Error("%s slave db %s not alive over %ds, down it",
			n, db.Addr(), limit)

		n.downSlave()
	}
}
132 |
133 | func (n *Node) openDB(addr string) (*client.DB, error) {
134 | db, err := client.Open(addr, n.cfg.User, n.cfg.Password, "")
135 | if err != nil {
136 | return nil, err
137 | }
138 |
139 | db.SetMaxIdleConnNum(n.cfg.IdleConns)
140 | return db, nil
141 | }
142 |
143 | func (n *Node) checkUpDB(addr string) (*client.DB, error) {
144 | db, err := n.openDB(addr)
145 | if err != nil {
146 | return nil, err
147 | }
148 |
149 | if err := db.Ping(); err != nil {
150 | db.Close()
151 | return nil, err
152 | }
153 |
154 | return db, nil
155 | }
156 |
// upMaster brings up a master at addr and makes it the routing master.
// It fails if a master is already up. The new pool is dialed and pinged
// without holding the lock, so the "no master" condition is re-checked
// before installing it. Otherwise a concurrent upMaster could be silently
// overwritten and its pool leaked.
func (n *Node) upMaster(addr string) error {
	n.Lock()
	if n.master != nil {
		n.Unlock()
		return fmt.Errorf("%s master must be down first", n)
	}
	n.Unlock()

	db, err := n.checkUpDB(addr)
	if err != nil {
		return err
	}

	n.Lock()
	if n.master != nil {
		// Lost a race with another upMaster while dialing.
		n.Unlock()
		db.Close()
		return fmt.Errorf("%s master must be down first", n)
	}
	n.master = db
	n.db = db
	n.Unlock()

	return nil
}
177 |
// upSlave brings up a slave at addr. It fails if a slave is already up.
// As in upMaster, the "no slave" condition is re-checked after dialing so
// that a concurrent upSlave cannot be overwritten and leaked.
func (n *Node) upSlave(addr string) error {
	n.Lock()
	if n.slave != nil {
		n.Unlock()
		return fmt.Errorf("%s, slave must be down first", n)
	}
	n.Unlock()

	db, err := n.checkUpDB(addr)
	if err != nil {
		return err
	}

	n.Lock()
	if n.slave != nil {
		// Lost a race with another upSlave while dialing.
		n.Unlock()
		db.Close()
		return fmt.Errorf("%s, slave must be down first", n)
	}
	n.slave = db
	n.Unlock()

	return nil
}
197 |
198 | func (n *Node) downMaster() error {
199 | n.Lock()
200 | if n.master != nil {
201 | n.master = nil
202 | }
203 | return nil
204 | }
205 |
206 | func (n *Node) downSlave() error {
207 | n.Lock()
208 | db := n.slave
209 | n.slave = nil
210 | n.Unlock()
211 |
212 | if db != nil {
213 | db.Close()
214 | }
215 |
216 | return nil
217 | }
218 |
// UpMaster brings up a master at addr for the named node.
func (s *Server) UpMaster(node string, addr string) error {
	if n := s.getNode(node); n != nil {
		return n.upMaster(addr)
	}
	return fmt.Errorf("invalid node %s", node)
}
227 |
// UpSlave brings up a slave at addr for the named node.
func (s *Server) UpSlave(node string, addr string) error {
	if n := s.getNode(node); n != nil {
		return n.upSlave(addr)
	}
	return fmt.Errorf("invalid node %s", node)
}
// DownMaster stops routing to the named node's master and takes it down.
func (s *Server) DownMaster(node string) error {
	n := s.getNode(node)
	if n == nil {
		return fmt.Errorf("invalid node %s", node)
	}

	// getMasterConn/getSelectConn read n.db under the node lock; clearing it
	// without the lock was a data race.
	n.Lock()
	n.db = nil
	n.Unlock()

	return n.downMaster()
}
244 |
// DownSlave takes the named node's slave down.
func (s *Server) DownSlave(node string) error {
	if n := s.getNode(node); n != nil {
		return n.downSlave()
	}
	return fmt.Errorf("invalid node [%s].", node)
}
252 |
253 | func (s *Server) getNode(name string) *Node {
254 | return s.nodes[name]
255 | }
256 |
// parseNodes builds s.nodes from the configuration, rejecting duplicate
// node names and stopping at the first node that fails to parse.
func (s *Server) parseNodes() error {
	nodeCfgs := s.cfg.Nodes
	s.nodes = make(map[string]*Node, len(nodeCfgs))

	for _, nc := range nodeCfgs {
		if _, dup := s.nodes[nc.Name]; dup {
			return fmt.Errorf("duplicate node [%s].", nc.Name)
		}

		node, err := s.parseNode(nc)
		if err != nil {
			return err
		}
		s.nodes[nc.Name] = node
	}

	return nil
}
276 |
277 | func (s *Server) parseNode(cfg config.NodeConfig) (*Node, error) {
278 | n := new(Node)
279 | n.server = s
280 | n.cfg = cfg
281 |
282 | n.downAfterNoAlive = time.Duration(cfg.DownAfterNoAlive) * time.Second
283 |
284 | if len(cfg.Master) == 0 {
285 | return nil, fmt.Errorf("must setting master MySQL node.")
286 | }
287 |
288 | var err error
289 | if n.master, err = n.openDB(cfg.Master); err != nil {
290 | return nil, err
291 | }
292 |
293 | n.db = n.master
294 |
295 | if len(cfg.Slave) > 0 {
296 | if n.slave, err = n.openDB(cfg.Slave); err != nil {
297 | log.Error(err.Error())
298 | n.slave = nil
299 | }
300 | }
301 |
302 | go n.run()
303 |
304 | return n, nil
305 | }
306 |
--------------------------------------------------------------------------------
/client/stmt_test.go:
--------------------------------------------------------------------------------
1 | package client
2 |
3 | import (
4 | "testing"
5 | )
6 |
// TestStmt_DropTable drops mixer_test_stmt through a prepared statement.
// The connection is now closed; it previously leaked on every run.
func TestStmt_DropTable(t *testing.T) {
	str := `drop table if exists mixer_test_stmt`

	c := newTestConn()
	defer c.Close()

	s, err := c.Prepare(str)
	if err != nil {
		t.Fatal(err)
	}
	defer s.Close()

	if _, err := s.Execute(); err != nil {
		t.Fatal(err)
	}
}
23 |
24 | func TestStmt_CreateTable(t *testing.T) {
25 | str := `CREATE TABLE IF NOT EXISTS mixer_test_stmt (
26 | id BIGINT(64) UNSIGNED NOT NULL,
27 | str VARCHAR(256),
28 | f DOUBLE,
29 | e enum("test1", "test2"),
30 | u tinyint unsigned,
31 | i tinyint,
32 | PRIMARY KEY (id)
33 | ) ENGINE=InnoDB DEFAULT CHARSET=utf8`
34 |
35 | c := newTestConn()
36 | defer c.Close()
37 |
38 | s, err := c.Prepare(str)
39 |
40 | if err != nil {
41 | t.Fatal(err)
42 | }
43 |
44 | if _, err = s.Execute(); err != nil {
45 | t.Fatal(err)
46 | }
47 |
48 | s.Close()
49 | }
50 |
// TestStmt_Delete empties mixer_test_stmt via a prepared statement.
func TestStmt_Delete(t *testing.T) {
	conn := newTestConn()
	defer conn.Close()

	stmt, err := conn.Prepare(`delete from mixer_test_stmt`)
	if err != nil {
		t.Fatal(err)
	}

	_, execErr := stmt.Execute()
	if execErr != nil {
		t.Fatal(execErr)
	}

	stmt.Close()
}
69 |
// TestStmt_Insert inserts row id=1 with one value per column type and
// expects exactly one affected row.
func TestStmt_Insert(t *testing.T) {
	conn := newTestConn()
	defer conn.Close()

	stmt, err := conn.Prepare(`insert into mixer_test_stmt (id, str, f, e, u, i) values (?, ?, ?, ?, ?, ?)`)
	if err != nil {
		t.Fatal(err)
	}

	res, err := stmt.Execute(1, "a", 3.14, "test1", 255, -127)
	if err != nil {
		t.Fatal(err)
	}
	if res.AffectedRows != 1 {
		t.Fatal(res.AffectedRows)
	}

	stmt.Close()
}
92 |
// TestStmt_Select reads row id=1 back and checks each column by index and
// by name.
func TestStmt_Select(t *testing.T) {
	conn := newTestConn()
	defer conn.Close()

	stmt, err := conn.Prepare(`select str, f, e from mixer_test_stmt where id = ?`)
	if err != nil {
		t.Fatal(err)
	}

	result, err := stmt.Execute(1)
	if err != nil {
		t.Fatal(err)
	}

	if rows := len(result.Values); rows != 1 {
		t.Fatal(rows)
	}
	if cols := len(result.Fields); cols != 3 {
		t.Fatal(cols)
	}

	if v, _ := result.GetString(0, 0); v != "a" {
		t.Fatal("invalid str", v)
	}
	if v, _ := result.GetFloat(0, 1); v != float64(3.14) {
		t.Fatal("invalid f", v)
	}
	if v, _ := result.GetString(0, 2); v != "test1" {
		t.Fatal("invalid e", v)
	}

	if v, _ := result.GetStringByName(0, "str"); v != "a" {
		t.Fatal("invalid str", v)
	}
	if v, _ := result.GetFloatByName(0, "f"); v != float64(3.14) {
		t.Fatal("invalid f", v)
	}
	if v, _ := result.GetStringByName(0, "e"); v != "test1" {
		t.Fatal("invalid e", v)
	}

	stmt.Close()
}
143 |
// TestStmt_NULL inserts a row with NULL parameters and checks, column by
// column, which values read back as NULL.
func TestStmt_NULL(t *testing.T) {
	conn := newTestConn()
	defer conn.Close()

	ins, err := conn.Prepare(`insert into mixer_test_stmt (id, str, f, e) values (?, ?, ?, ?)`)
	if err != nil {
		t.Fatal(err)
	}
	res, err := ins.Execute(2, nil, 3.14, nil)
	if err != nil {
		t.Fatal(err)
	}
	if res.AffectedRows != 1 {
		t.Fatal(res.AffectedRows)
	}
	ins.Close()

	sel, err := conn.Prepare(`select * from mixer_test_stmt where id = ?`)
	if err != nil {
		t.Fatal(err)
	}
	r, err := sel.Execute(2)
	if err != nil {
		t.Fatal(err)
	}

	expectations := []struct {
		column string
		isNull bool
	}{
		{"id", false},
		{"str", true},
		{"f", false},
		{"e", true},
	}
	for _, want := range expectations {
		got, err := r.IsNullByName(0, want.column)
		if err != nil {
			t.Fatal(err)
		}
		if got != want.isNull {
			t.Fatal(got)
		}
	}

	sel.Close()
}
203 |
// TestStmt_Unsigned round-trips uint8(255) through a TINYINT UNSIGNED column.
func TestStmt_Unsigned(t *testing.T) {
	conn := newTestConn()
	defer conn.Close()

	ins, err := conn.Prepare(`insert into mixer_test_stmt (id, u) values (?, ?)`)
	if err != nil {
		t.Fatal(err)
	}
	res, err := ins.Execute(3, uint8(255))
	if err != nil {
		t.Fatal(err)
	}
	if res.AffectedRows != 1 {
		t.Fatal(res.AffectedRows)
	}
	ins.Close()

	sel, err := conn.Prepare(`select u from mixer_test_stmt where id = ?`)
	if err != nil {
		t.Fatal(err)
	}
	r, err := sel.Execute(3)
	if err != nil {
		t.Fatal(err)
	}
	u, err := r.GetUint(0, 0)
	if err != nil {
		t.Fatal(err)
	}
	if u != uint64(255) {
		t.Fatal(u)
	}

	sel.Close()
}
245 |
// TestStmt_Signed inserts signed TINYINT boundary values, including a
// uint64 id close to the top of the range.
func TestStmt_Signed(t *testing.T) {
	conn := newTestConn()
	defer conn.Close()

	stmt, err := conn.Prepare(`insert into mixer_test_stmt (id, i) values (?, ?)`)
	if err != nil {
		t.Fatal(err)
	}

	if _, execErr := stmt.Execute(4, 127); execErr != nil {
		t.Fatal(execErr)
	}
	if _, execErr := stmt.Execute(uint64(18446744073709551516), int8(-128)); execErr != nil {
		t.Fatal(execErr)
	}

	stmt.Close()
}
269 |
// TestStmt_Trans checks that a statement prepared inside a transaction is
// still usable after the transaction commits.
func TestStmt_Trans(t *testing.T) {
	conn := newTestConn()
	defer conn.Close()

	if _, err := conn.Execute(`insert into mixer_test_stmt (id, str) values (1002, "abc")`); err != nil {
		t.Fatal(err)
	}
	if err := conn.Begin(); err != nil {
		t.Fatal(err)
	}

	stmt, err := conn.Prepare(`select str from mixer_test_stmt where id = ?`)
	if err != nil {
		t.Fatal(err)
	}
	if _, err := stmt.Execute(1002); err != nil {
		t.Fatal(err)
	}
	if err := conn.Commit(); err != nil {
		t.Fatal(err)
	}

	r, err := stmt.Execute(1002)
	if err != nil {
		t.Fatal(err)
	}
	if got, _ := r.GetString(0, 0); got != `abc` {
		t.Fatal(got)
	}

	if err := stmt.Close(); err != nil {
		t.Fatal(err)
	}
}
309 |
--------------------------------------------------------------------------------
/mysql/util.go:
--------------------------------------------------------------------------------
1 | package mysql
2 |
import (
	crand "crypto/rand"
	"crypto/sha1"
	"encoding/binary"
	"fmt"
	"io"
	"math/rand"
	"runtime"
	"time"
	"unicode/utf8"
)
13 |
14 | func Pstack() string {
15 | buf := make([]byte, 1024)
16 | n := runtime.Stack(buf, false)
17 | return string(buf[0:n])
18 | }
19 |
// CalcPassword computes the mysql_native_password auth token:
//
//	SHA1(password) XOR SHA1(scramble + SHA1(SHA1(password)))
//
// An empty password yields a nil token. The caller's scramble is not
// modified; a fresh slice is returned.
func CalcPassword(scramble, password []byte) []byte {
	if len(password) == 0 {
		return nil
	}

	stage1 := sha1.Sum(password)
	stage2 := sha1.Sum(stage1[:])

	h := sha1.New()
	h.Write(scramble)
	h.Write(stage2[:])
	token := h.Sum(nil)

	for i := range token {
		token[i] ^= stage1[i]
	}
	return token
}
48 |
// RandomBuf returns size random bytes in [1, 126], never '$'. It is used
// for the handshake auth salt, where NUL would act as a terminator.
//
// Bytes come from crypto/rand. The previous version reseeded the global
// math/rand from the clock on every call, which made salts predictable,
// repeated them for calls within the same clock tick, and clobbered the
// global generator for other users.
func RandomBuf(size int) []byte {
	buf := make([]byte, size)
	if _, err := crand.Read(buf); err != nil {
		// crypto/rand failing is exceptional; fall back to a private
		// time-seeded generator rather than returning a predictable zero salt.
		r := rand.New(rand.NewSource(time.Now().UnixNano()))
		for i := range buf {
			buf[i] = byte(r.Intn(256))
		}
	}
	for i, b := range buf {
		b %= 127
		if b == 0 || b == '$' {
			b++
		}
		buf[i] = b
	}
	return buf
}
60 |
// LengthEncodedInt decodes a MySQL length-encoded integer at the start of b.
// It returns the value, whether the marker was NULL (0xfb), and how many
// bytes were consumed. b must be long enough for the encoding; a short
// buffer panics.
func LengthEncodedInt(b []byte) (num uint64, isNull bool, n int) {
	switch first := b[0]; first {
	case 0xfb: // NULL marker
		return 0, true, 1
	case 0xfc: // 2-byte little-endian value follows
		return uint64(binary.LittleEndian.Uint16(b[1:])), false, 3
	case 0xfd: // 3-byte little-endian value follows
		return uint64(b[1]) | uint64(b[2])<<8 | uint64(b[3])<<16, false, 4
	case 0xfe: // 8-byte little-endian value follows
		_ = b[8] // panic on a short buffer, exactly as per-byte indexing would
		return binary.LittleEndian.Uint64(b[1:]), false, 9
	default: // 0-250 (and 0xff) is the value itself
		return uint64(first), false, 1
	}
}
96 |
// PutLengthEncodedInt encodes n as a MySQL length-encoded integer, using
// the shortest form. 251 (0xfb) and up need a prefix byte because 0xfb is
// the NULL marker.
func PutLengthEncodedInt(n uint64) []byte {
	switch {
	case n <= 250:
		return []byte{byte(n)}
	case n <= 0xffff:
		return []byte{0xfc, byte(n), byte(n >> 8)}
	case n <= 0xffffff:
		return []byte{0xfd, byte(n), byte(n >> 8), byte(n >> 16)}
	}

	out := make([]byte, 9)
	out[0] = 0xfe
	binary.LittleEndian.PutUint64(out[1:], n)
	return out
}
114 |
115 | func LengthEnodedString(b []byte) ([]byte, bool, int, error) {
116 | // Get length
117 | num, isNull, n := LengthEncodedInt(b)
118 | if num < 1 {
119 | return nil, isNull, n, nil
120 | }
121 |
122 | n += int(num)
123 |
124 | // Check data length
125 | if len(b) >= n {
126 | return b[n-int(num) : n], false, n, nil
127 | }
128 | return nil, false, n, io.EOF
129 | }
130 |
131 | func SkipLengthEnodedString(b []byte) (int, error) {
132 | // Get length
133 | num, _, n := LengthEncodedInt(b)
134 | if num < 1 {
135 | return n, nil
136 | }
137 |
138 | n += int(num)
139 |
140 | // Check data length
141 | if len(b) >= n {
142 | return n, nil
143 | }
144 | return n, io.EOF
145 | }
146 |
147 | func PutLengthEncodedString(b []byte) []byte {
148 | data := make([]byte, 0, len(b)+9)
149 | data = append(data, PutLengthEncodedInt(uint64(len(b)))...)
150 | data = append(data, b...)
151 | return data
152 | }
153 |
// Uint16ToBytes encodes n as 2 little-endian bytes.
func Uint16ToBytes(n uint16) []byte {
	out := make([]byte, 2)
	binary.LittleEndian.PutUint16(out, n)
	return out
}
160 |
// Uint32ToBytes encodes n as 4 little-endian bytes.
func Uint32ToBytes(n uint32) []byte {
	out := make([]byte, 4)
	binary.LittleEndian.PutUint32(out, n)
	return out
}
169 |
// Uint64ToBytes encodes n as 8 little-endian bytes.
func Uint64ToBytes(n uint64) []byte {
	out := make([]byte, 8)
	binary.LittleEndian.PutUint64(out, n)
	return out
}
182 |
// FormatBinaryDate renders a binary-protocol DATE value of length n as
// "YYYY-MM-DD". Length 0 is the zero date; any length other than 0 or 4 is
// an error.
func FormatBinaryDate(n int, data []byte) ([]byte, error) {
	if n == 0 {
		return []byte("0000-00-00"), nil
	}
	if n != 4 {
		return nil, fmt.Errorf("invalid date packet length %d", n)
	}

	year := binary.LittleEndian.Uint16(data[:2])
	text := fmt.Sprintf("%04d-%02d-%02d", year, data[2], data[3])
	return []byte(text), nil
}
196 |
197 | func FormatBinaryDateTime(n int, data []byte) ([]byte, error) {
198 | switch n {
199 | case 0:
200 | return []byte("0000-00-00 00:00:00"), nil
201 | case 4:
202 | return []byte(fmt.Sprintf("%04d-%02d-%02d 00:00:00",
203 | binary.LittleEndian.Uint16(data[:2]),
204 | data[2],
205 | data[3])), nil
206 | case 7:
207 | return []byte(fmt.Sprintf(
208 | "%04d-%02d-%02d %02d:%02d:%02d",
209 | binary.LittleEndian.Uint16(data[:2]),
210 | data[2],
211 | data[3],
212 | data[4],
213 | data[5],
214 | data[6])), nil
215 | case 11:
216 | return []byte(fmt.Sprintf(
217 | "%04d-%02d-%02d %02d:%02d:%02d.%06d",
218 | binary.LittleEndian.Uint16(data[:2]),
219 | data[2],
220 | data[3],
221 | data[4],
222 | data[5],
223 | data[6],
224 | binary.LittleEndian.Uint32(data[7:11]))), nil
225 | default:
226 | return nil, fmt.Errorf("invalid datetime packet length %d", n)
227 | }
228 | }
229 |
230 | func FormatBinaryTime(n int, data []byte) ([]byte, error) {
231 | if n == 0 {
232 | return []byte("0000-00-00"), nil
233 | }
234 |
235 | var sign byte
236 | if data[0] == 1 {
237 | sign = byte('-')
238 | }
239 |
240 | switch n {
241 | case 8:
242 | return []byte(fmt.Sprintf(
243 | "%c%02d:%02d:%02d",
244 | sign,
245 | uint16(data[1])*24+uint16(data[5]),
246 | data[6],
247 | data[7],
248 | )), nil
249 | case 12:
250 | return []byte(fmt.Sprintf(
251 | "%c%02d:%02d:%02d.%06d",
252 | sign,
253 | uint16(data[1])*24+uint16(data[5]),
254 | data[6],
255 | data[7],
256 | binary.LittleEndian.Uint32(data[8:12]),
257 | )), nil
258 | default:
259 | return nil, fmt.Errorf("invalid time packet length %d", n)
260 | }
261 | }
262 |
var (
	// DONTESCAPE is the EncodeMap sentinel for bytes that are copied
	// through unescaped.
	DONTESCAPE byte = 255

	// EncodeMap maps each byte to the character written after a backslash
	// when escaping, or to DONTESCAPE. It is filled in by init from
	// encodeRef.
	EncodeMap [256]byte
)
268 |
// Escape backslash-escapes the special characters in sql (see encodeRef)
// so the result can be embedded in a quoted SQL string literal.
//
// Every escapable character is ASCII. In UTF-8, all bytes of a multi-byte
// sequence are >= utf8.RuneSelf, so a byte-wise scan over ASCII bytes is
// correct. The previous rune-wise version looked up EncodeMap[byte(rune)],
// which truncated the rune and mangled characters such as 'ħ' (U+0127,
// low byte 0x27 = '\'') into "\'".
func Escape(sql string) string {
	dest := make([]byte, 0, 2*len(sql))

	for i := 0; i < len(sql); i++ {
		b := sql[i]
		if b < utf8.RuneSelf {
			if c := EncodeMap[b]; c != DONTESCAPE {
				dest = append(dest, '\\', c)
				continue
			}
		}
		dest = append(dest, b)
	}

	return string(dest)
}
284 |
// encodeRef lists the bytes that must be backslash-escaped in a SQL string
// literal, mapped to the character that follows the backslash.
var encodeRef = map[byte]byte{
	'\\':   '\\',
	'\'':   '\'',
	'"':    '"',
	'\x00': '0',
	'\x1a': 'Z', // ctl-Z
	'\b':   'b',
	'\n':   'n',
	'\r':   'r',
	'\t':   't',
}
296 |
297 | func init() {
298 | for i := range EncodeMap {
299 | EncodeMap[i] = DONTESCAPE
300 | }
301 | for i := range EncodeMap {
302 | if to, ok := encodeRef[byte(i)]; ok {
303 | EncodeMap[byte(i)] = to
304 | }
305 | }
306 | }
307 |
--------------------------------------------------------------------------------
/proxy/conn.go:
--------------------------------------------------------------------------------
1 | package proxy
2 |
3 | import (
4 | "bytes"
5 | "encoding/binary"
6 | "fmt"
7 | "github.com/siddontang/go-log/log"
8 | "github.com/siddontang/mixer/client"
9 | "github.com/siddontang/mixer/hack"
10 | . "github.com/siddontang/mixer/mysql"
11 | "net"
12 | "runtime"
13 | "sync"
14 | "sync/atomic"
15 | )
16 |
17 | var DEFAULT_CAPABILITY uint32 = CLIENT_LONG_PASSWORD | CLIENT_LONG_FLAG |
18 | CLIENT_CONNECT_WITH_DB | CLIENT_PROTOCOL_41 |
19 | CLIENT_TRANSACTIONS | CLIENT_SECURE_CONNECTION
20 |
//client <-> proxy
// Conn is the proxy's side of one client connection: packet I/O over the
// client socket, the session state negotiated during the handshake, and
// per-session prepared statements.
type Conn struct {
	sync.Mutex

	// Packet reader/writer over c; Sequence is reset to 0 between commands.
	pkg *PacketIO

	// Underlying client socket.
	c net.Conn

	server *Server

	// Capability flags sent by the client in its handshake response.
	capability uint32

	// Id allocated from baseConnId and sent in the initial handshake.
	connectionId uint32

	// Status flags reported in OK packets (starts as SERVER_STATUS_AUTOCOMMIT).
	status    uint16
	collation CollationId
	charset   string

	user string
	// Current schema name; set together with schema by useDB.
	db string

	// Random challenge (20 bytes, from newConn) used to check the password.
	salt []byte

	schema *Schema

	// Backend connections keyed by node; presumably held for an open
	// transaction and released by rollback (called from Close) — confirm.
	txConns map[*Node]*client.SqlConn

	// Set by Close; Run stops its loop once it is true.
	closed bool

	lastInsertId int64
	affectedRows int64

	// Prepared-statement id counter and the statements it indexes
	// (NOTE(review): allocation happens outside this file — confirm).
	stmtId uint32

	stmts map[uint32]*Stmt
}
57 |
// baseConnId is the most recently issued connection id; newConn bumps it
// atomically, so the first client gets 10001.
var baseConnId = uint32(10000)
59 |
// newConn wraps an accepted client socket in a Conn. The Conn starts with
// a fresh connection id and salt, autocommit status, the default
// collation/charset, and empty transaction and statement tables.
func (s *Server) newConn(co net.Conn) *Conn {
	c := &Conn{
		c:            co,
		pkg:          NewPacketIO(co),
		server:       s,
		connectionId: atomic.AddUint32(&baseConnId, 1),
		status:       SERVER_STATUS_AUTOCOMMIT,
		salt:         RandomBuf(20),
		txConns:      make(map[*Node]*client.SqlConn),
		collation:    DEFAULT_COLLATION_ID,
		charset:      DEFAULT_CHARSET,
		stmts:        make(map[uint32]*Stmt),
	}
	c.pkg.Sequence = 0

	return c
}
90 |
91 | func (c *Conn) Handshake() error {
92 | if err := c.writeInitialHandshake(); err != nil {
93 | log.Error("send initial handshake error %s", err.Error())
94 | return err
95 | }
96 |
97 | if err := c.readHandshakeResponse(); err != nil {
98 | log.Error("recv handshake response error %s", err.Error())
99 |
100 | c.writeError(err)
101 |
102 | return err
103 | }
104 |
105 | if err := c.writeOK(nil); err != nil {
106 | log.Error("write ok fail %s", err.Error())
107 | return err
108 | }
109 |
110 | c.pkg.Sequence = 0
111 |
112 | return nil
113 | }
114 |
115 | func (c *Conn) Close() error {
116 | if c.closed {
117 | return nil
118 | }
119 |
120 | c.c.Close()
121 |
122 | c.rollback()
123 |
124 | c.closed = true
125 |
126 | return nil
127 | }
128 |
// writeInitialHandshake sends the server greeting packet: protocol version
// 10, server version, connection id, the salt split into two parts,
// capability flags, default collation and status flags. Field order is
// fixed by the protocol, so the appends below must stay in this order.
func (c *Conn) writeInitialHandshake() error {
	// The first 4 bytes are left for the packet header; presumably
	// writePacket fills them in (confirm against PacketIO.WritePacket).
	data := make([]byte, 4, 128)

	//min version 10
	// (protocol version byte; always 10 here)
	data = append(data, 10)

	//server version[00]
	data = append(data, ServerVersion...)
	data = append(data, 0)

	//connection id
	// (4 bytes, little-endian)
	data = append(data, byte(c.connectionId), byte(c.connectionId>>8), byte(c.connectionId>>16), byte(c.connectionId>>24))

	//auth-plugin-data-part-1
	// (first 8 bytes of the salt)
	data = append(data, c.salt[0:8]...)

	//filter [00]
	data = append(data, 0)

	//capability flag lower 2 bytes, using default capability here
	data = append(data, byte(DEFAULT_CAPABILITY), byte(DEFAULT_CAPABILITY>>8))

	//charset, utf-8 default
	data = append(data, uint8(DEFAULT_COLLATION_ID))

	//status
	data = append(data, byte(c.status), byte(c.status>>8))

	//below 13 byte may not be used
	//capability flag upper 2 bytes, using default capability here
	data = append(data, byte(DEFAULT_CAPABILITY>>16), byte(DEFAULT_CAPABILITY>>24))

	//filter [0x15], for wireshark dump, value is 0x15
	// NOTE(review): the protocol docs describe this byte as the
	// auth-plugin-data length; 0x15 = 21 = 20-byte salt + 1.
	data = append(data, 0x15)

	//reserved 10 [00]
	data = append(data, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0)

	//auth-plugin-data-part-2
	// (remaining salt bytes, 12 for the 20-byte salt created in newConn)
	data = append(data, c.salt[8:]...)

	//filter [00]
	data = append(data, 0)

	return c.writePacket(data)
}
175 |
176 | func (c *Conn) readPacket() ([]byte, error) {
177 | return c.pkg.ReadPacket()
178 | }
179 |
180 | func (c *Conn) writePacket(data []byte) error {
181 | return c.pkg.WritePacket(data)
182 | }
183 |
184 | func (c *Conn) readHandshakeResponse() error {
185 | data, err := c.readPacket()
186 |
187 | if err != nil {
188 | return err
189 | }
190 |
191 | pos := 0
192 |
193 | //capability
194 | c.capability = binary.LittleEndian.Uint32(data[:4])
195 | pos += 4
196 |
197 | //skip max packet size
198 | pos += 4
199 |
200 | //charset, skip, if you want to use another charset, use set names
201 | //c.collation = CollationId(data[pos])
202 | pos++
203 |
204 | //skip reserved 23[00]
205 | pos += 23
206 |
207 | //user name
208 | c.user = string(data[pos : pos+bytes.IndexByte(data[pos:], 0)])
209 | pos += len(c.user) + 1
210 |
211 | //auth length and auth
212 | authLen := int(data[pos])
213 | pos++
214 | auth := data[pos : pos+authLen]
215 |
216 | checkAuth := CalcPassword(c.salt, []byte(c.server.cfg.Password))
217 |
218 | if !bytes.Equal(auth, checkAuth) {
219 | return NewDefaultError(ER_ACCESS_DENIED_ERROR, c.c.RemoteAddr().String(), c.user, "Yes")
220 | }
221 |
222 | pos += authLen
223 |
224 | if c.capability&CLIENT_CONNECT_WITH_DB > 0 {
225 | if len(data[pos:]) == 0 {
226 | return nil
227 | }
228 |
229 | db := string(data[pos : pos+bytes.IndexByte(data[pos:], 0)])
230 | pos += len(c.db) + 1
231 |
232 | if err := c.useDB(db); err != nil {
233 | return err
234 | }
235 | }
236 |
237 | return nil
238 | }
239 |
// Run serves commands from the client until the connection fails or is
// closed. A dispatch error is logged and, unless the backend connection was
// bad, reported to the client. Any panic is recovered and logged with its
// stack, and the connection is always closed on exit.
func (c *Conn) Run() {
	defer func() {
		// Log every recovered panic. The old check `r.(error)` silently
		// dropped panics raised with non-error values (e.g. strings).
		if r := recover(); r != nil {
			const size = 4096
			buf := make([]byte, size)
			buf = buf[:runtime.Stack(buf, false)]

			log.Error("%v, %s", r, buf)
		}

		c.Close()
	}()

	for {
		data, err := c.readPacket()
		if err != nil {
			return
		}

		if err := c.dispatch(data); err != nil {
			log.Error("dispatch error %s", err.Error())
			if err != ErrBadConn {
				c.writeError(err)
			}
		}

		if c.closed {
			return
		}

		c.pkg.Sequence = 0
	}
}
275 |
// dispatch routes one client command packet (command byte + payload) to
// its handler. Unsupported commands return ER_UNKNOWN_ERROR. An empty packet
// is rejected with an error instead of panicking on data[0].
func (c *Conn) dispatch(data []byte) error {
	if len(data) == 0 {
		return NewError(ER_UNKNOWN_ERROR, "empty command packet")
	}

	cmd := data[0]
	data = data[1:]

	switch cmd {
	case COM_QUIT:
		c.Close()
		return nil
	case COM_QUERY:
		return c.handleQuery(hack.String(data))
	case COM_PING:
		return c.writeOK(nil)
	case COM_INIT_DB:
		if err := c.useDB(hack.String(data)); err != nil {
			return err
		}
		return c.writeOK(nil)
	case COM_FIELD_LIST:
		return c.handleFieldList(data)
	case COM_STMT_PREPARE:
		return c.handleStmtPrepare(hack.String(data))
	case COM_STMT_EXECUTE:
		return c.handleStmtExecute(data)
	case COM_STMT_CLOSE:
		return c.handleStmtClose(data)
	case COM_STMT_SEND_LONG_DATA:
		return c.handleStmtSendLongData(data)
	case COM_STMT_RESET:
		return c.handleStmtReset(data)
	default:
		msg := fmt.Sprintf("command %d not supported now", cmd)
		return NewError(ER_UNKNOWN_ERROR, msg)
	}
}
313 |
314 | func (c *Conn) useDB(db string) error {
315 | if s := c.server.getSchema(db); s == nil {
316 | return NewDefaultError(ER_BAD_DB_ERROR, db)
317 | } else {
318 | c.schema = s
319 | c.db = db
320 | }
321 | return nil
322 | }
323 |
324 | func (c *Conn) writeOK(r *Result) error {
325 | if r == nil {
326 | r = &Result{Status: c.status}
327 | }
328 | data := make([]byte, 4, 32)
329 |
330 | data = append(data, OK_HEADER)
331 |
332 | data = append(data, PutLengthEncodedInt(r.AffectedRows)...)
333 | data = append(data, PutLengthEncodedInt(r.InsertId)...)
334 |
335 | if c.capability&CLIENT_PROTOCOL_41 > 0 {
336 | data = append(data, byte(r.Status), byte(r.Status>>8))
337 | data = append(data, 0, 0)
338 | }
339 |
340 | return c.writePacket(data)
341 | }
342 |
343 | func (c *Conn) writeError(e error) error {
344 | var m *SqlError
345 | var ok bool
346 | if m, ok = e.(*SqlError); !ok {
347 | m = NewError(ER_UNKNOWN_ERROR, e.Error())
348 | }
349 |
350 | data := make([]byte, 4, 16+len(m.Message))
351 |
352 | data = append(data, ERR_HEADER)
353 | data = append(data, byte(m.Code), byte(m.Code>>8))
354 |
355 | if c.capability&CLIENT_PROTOCOL_41 > 0 {
356 | data = append(data, '#')
357 | data = append(data, m.State...)
358 | }
359 |
360 | data = append(data, m.Message...)
361 |
362 | return c.writePacket(data)
363 | }
364 |
365 | func (c *Conn) writeEOF(status uint16) error {
366 | data := make([]byte, 4, 9)
367 |
368 | data = append(data, EOF_HEADER)
369 | if c.capability&CLIENT_PROTOCOL_41 > 0 {
370 | data = append(data, 0, 0)
371 | data = append(data, byte(status), byte(status>>8))
372 | }
373 |
374 | return c.writePacket(data)
375 | }
376 |
--------------------------------------------------------------------------------
/proxy/conn_test.go:
--------------------------------------------------------------------------------
1 | package proxy
2 |
3 | import (
4 | . "github.com/siddontang/mixer/mysql"
5 | "testing"
6 | )
7 |
// TestConn_Handshake connects through the proxy and pings it.
func TestConn_Handshake(t *testing.T) {
	conn := newTestDBConn(t)

	err := conn.Ping()
	if err != nil {
		t.Fatal(err)
	}

	conn.Close()
}
17 |
// TestConn_DeleteTable drops mixer_test_proxy_conn directly on node1's
// master. A failed UseDB is now reported instead of surfacing later as a
// confusing "no database selected" error, and the backend connection is
// released on every path.
func TestConn_DeleteTable(t *testing.T) {
	server := newTestServer(t)
	n := server.nodes["node1"]
	c, err := n.getMasterConn()
	if err != nil {
		t.Fatal(err)
	}
	defer c.Close()

	if err := c.UseDB("mixer"); err != nil {
		t.Fatal(err)
	}
	if _, err := c.Execute(`drop table if exists mixer_test_proxy_conn`); err != nil {
		t.Fatal(err)
	}
}
31 |
// TestConn_CreateTable creates mixer_test_proxy_conn directly on node1's
// master. The UseDB error is now checked.
func TestConn_CreateTable(t *testing.T) {
	s := `CREATE TABLE IF NOT EXISTS mixer_test_proxy_conn (
id BIGINT(64) UNSIGNED NOT NULL,
str VARCHAR(256),
f DOUBLE,
e enum("test1", "test2"),
u tinyint unsigned,
i tinyint,
ni tinyint,
PRIMARY KEY (id)
) ENGINE=InnoDB DEFAULT CHARSET=utf8`

	server := newTestServer(t)
	n := server.nodes["node1"]
	c, err := n.getMasterConn()
	if err != nil {
		t.Fatal(err)
	}
	defer c.Close()

	if err := c.UseDB("mixer"); err != nil {
		t.Fatal(err)
	}
	if _, err := c.Execute(s); err != nil {
		t.Fatal(err)
	}
}
57 |
// TestConn_Insert inserts row id=1 through the proxy.
func TestConn_Insert(t *testing.T) {
	conn := newTestDBConn(t)
	defer conn.Close()

	res, err := conn.Execute(`insert into mixer_test_proxy_conn (id, str, f, e, u, i) values(1, "abc", 3.14, "test1", 255, -127)`)
	if err != nil {
		t.Fatal(err)
	}
	if res.AffectedRows != 1 {
		t.Fatal(res.AffectedRows)
	}
}
72 |
// TestConn_Select reads row id=1 back through the proxy and checks every
// column, including the NULL ni column.
func TestConn_Select(t *testing.T) {
	conn := newTestDBConn(t)
	defer conn.Close()

	r, err := conn.Execute(`select str, f, e, u, i, ni from mixer_test_proxy_conn where id = 1`)
	if err != nil {
		t.Fatal(err)
	}

	if rows := r.RowNumber(); rows != 1 {
		t.Fatal(rows)
	}
	if cols := r.ColumnNumber(); cols != 6 {
		t.Fatal(cols)
	}

	if got, _ := r.GetString(0, 0); got != `abc` {
		t.Fatal(got)
	}
	if got, _ := r.GetFloat(0, 1); got != 3.14 {
		t.Fatal(got)
	}
	if got, _ := r.GetString(0, 2); got != `test1` {
		t.Fatal(got)
	}
	if got, _ := r.GetUint(0, 3); got != 255 {
		t.Fatal(got)
	}
	if got, _ := r.GetInt(0, 4); got != -127 {
		t.Fatal(got)
	}
	if null, _ := r.IsNull(0, 5); !null {
		t.Fatal("ni not null")
	}
}
115 |
// TestConn_Update updates row id=1 and checks the new value is visible.
func TestConn_Update(t *testing.T) {
	conn := newTestDBConn(t)
	defer conn.Close()

	if _, err := conn.Execute(`update mixer_test_proxy_conn set str = "123" where id = 1`); err != nil {
		t.Fatal(err)
	}

	r, err := conn.Execute(`select str from mixer_test_proxy_conn where id = 1`)
	if err != nil {
		t.Fatal(err)
	}
	if got, _ := r.GetString(0, 0); got != `123` {
		t.Fatal(got)
	}
}
134 |
// TestConn_Replace checks REPLACE affected-row counts: 2 when an existing
// row is replaced, 1 for a fresh insert. It then verifies the stored values.
func TestConn_Replace(t *testing.T) {
	conn := newTestDBConn(t)
	defer conn.Close()

	replaces := []struct {
		sql      string
		affected uint64
	}{
		{`replace into mixer_test_proxy_conn (id, str, f) values(1, "abc", 3.14159)`, 2},
		{`replace into mixer_test_proxy_conn (id, str) values(2, "abcb")`, 1},
	}
	for _, rp := range replaces {
		r, err := conn.Execute(rp.sql)
		if err != nil {
			t.Fatal(err)
		}
		if r.AffectedRows != rp.affected {
			t.Fatal(r.AffectedRows)
		}
	}

	r, err := conn.Execute(`select str, f from mixer_test_proxy_conn`)
	if err != nil {
		t.Fatal(err)
	}
	if got, _ := r.GetString(0, 0); got != `abc` {
		t.Fatal(got)
	}
	if got, _ := r.GetString(1, 0); got != `abcb` {
		t.Fatal(got)
	}
	if got, _ := r.GetFloat(0, 1); got != 3.14159 {
		t.Fatal(got)
	}
	if null, _ := r.IsNull(1, 1); !null {
		t.Fatal(null)
	}
}
181 |
// TestConn_Delete deletes a nonexistent id and expects zero affected rows.
func TestConn_Delete(t *testing.T) {
	conn := newTestDBConn(t)
	defer conn.Close()

	res, err := conn.Execute(`delete from mixer_test_proxy_conn where id = 100000`)
	if err != nil {
		t.Fatal(err)
	}
	if res.AffectedRows != 0 {
		t.Fatal(res.AffectedRows)
	}
}
196 |
// TestConn_SetAutoCommit toggles autocommit and checks that the status
// flag in each OK packet follows.
func TestConn_SetAutoCommit(t *testing.T) {
	conn := newTestDBConn(t)
	defer conn.Close()

	steps := []struct {
		sql string
		on  bool
	}{
		{"set autocommit = 1", true},
		{"set autocommit = 0", false},
	}
	for _, step := range steps {
		r, err := conn.Execute(step.sql)
		if err != nil {
			t.Fatal(err)
		}
		if (r.Status&SERVER_STATUS_AUTOCOMMIT > 0) != step.on {
			t.Fatal(r.Status)
		}
	}
}
217 |
218 | func TestConn_Trans(t *testing.T) {
219 | c1 := newTestDBConn(t)
220 | defer c1.Close()
221 |
222 | c2 := newTestDBConn(t)
223 | defer c2.Close()
224 |
225 | var err error
226 |
227 | if err = c1.Begin(); err != nil {
228 | t.Fatal(err)
229 | }
230 |
231 | if err = c2.Begin(); err != nil {
232 | t.Fatal(err)
233 | }
234 |
235 | if _, err := c1.Execute(`insert into mixer_test_proxy_conn (id, str) values (111, "abc")`); err != nil {
236 | t.Fatal(err)
237 | }
238 |
239 | if r, err := c2.Execute(`select str from mixer_test_proxy_conn where id = 111`); err != nil {
240 | t.Fatal(err)
241 | } else {
242 | if r.RowNumber() != 0 {
243 | t.Fatal(r.RowNumber())
244 | }
245 | }
246 |
247 | if err := c1.Commit(); err != nil {
248 | t.Fatal(err)
249 | }
250 |
251 | if err := c2.Commit(); err != nil {
252 | t.Fatal(err)
253 | }
254 |
255 | if r, err := c1.Execute(`select str from mixer_test_proxy_conn where id = 111`); err != nil {
256 | t.Fatal(err)
257 | } else {
258 | if r.RowNumber() != 1 {
259 | t.Fatal(r.RowNumber())
260 | }
261 |
262 | if v, _ := r.GetString(0, 0); v != `abc` {
263 | t.Fatal(v)
264 | }
265 | }
266 | }
267 |
268 | func TestConn_SetNames(t *testing.T) {
269 | c := newTestDBConn(t)
270 | defer c.Close()
271 |
272 | if err := c.SetCharset("gb2312"); err != nil {
273 | t.Fatal(err)
274 | }
275 | }
276 |
277 | func TestConn_LastInsertId(t *testing.T) {
278 | s := `CREATE TABLE IF NOT EXISTS mixer_test_conn_id (
279 | id BIGINT(64) UNSIGNED AUTO_INCREMENT NOT NULL,
280 | str VARCHAR(256),
281 | PRIMARY KEY (id)
282 | ) ENGINE=InnoDB DEFAULT CHARSET=utf8`
283 |
284 | server := newTestServer(t)
285 | n := server.nodes["node1"]
286 |
287 | c1, err := n.getMasterConn()
288 | if err != nil {
289 | t.Fatal(err)
290 | }
291 |
292 | if _, err := c1.Execute(s); err != nil {
293 | t.Fatal(err)
294 | }
295 |
296 | c1.Close()
297 |
298 | c := newTestDBConn(t)
299 | defer c.Close()
300 |
301 | r, err := c.Execute(`insert into mixer_test_conn_id (str) values ("abc")`)
302 | if err != nil {
303 | t.Fatal(err)
304 | }
305 |
306 | lastId := r.InsertId
307 | if r, err := c.Execute(`select last_insert_id()`); err != nil {
308 | t.Fatal(err)
309 | } else {
310 | if r.ColumnNumber() != 1 {
311 | t.Fatal(r.ColumnNumber())
312 | }
313 |
314 | if v, _ := r.GetUint(0, 0); v != lastId {
315 | t.Fatal(v)
316 | }
317 | }
318 |
319 | if r, err := c.Execute(`select last_insert_id() as a`); err != nil {
320 | t.Fatal(err)
321 | } else {
322 | if string(r.Fields[0].Name) != "a" {
323 | t.Fatal(string(r.Fields[0].Name))
324 | }
325 |
326 | if v, _ := r.GetUint(0, 0); v != lastId {
327 | t.Fatal(v)
328 | }
329 | }
330 |
331 | c1, _ = n.getMasterConn()
332 |
333 | if _, err := c1.Execute(`drop table if exists mixer_test_conn_id`); err != nil {
334 | t.Fatal(err)
335 | }
336 |
337 | c1.Close()
338 | }
339 |
340 | func TestConn_RowCount(t *testing.T) {
341 | c := newTestDBConn(t)
342 | defer c.Close()
343 |
344 | r, err := c.Execute(`insert into mixer_test_proxy_conn (id, str) values (1002, "abc")`)
345 | if err != nil {
346 | t.Fatal(err)
347 | }
348 |
349 | row := r.AffectedRows
350 |
351 | if r, err := c.Execute("select row_count()"); err != nil {
352 | t.Fatal(err)
353 | } else {
354 | if v, _ := r.GetUint(0, 0); v != row {
355 | t.Fatal(v)
356 | }
357 | }
358 |
359 | if r, err := c.Execute("select row_count() as b"); err != nil {
360 | t.Fatal(err)
361 | } else {
362 | if v, _ := r.GetInt(0, 0); v != -1 {
363 | t.Fatal(v)
364 | }
365 | }
366 | }
367 |
368 | func TestConn_SelectVersion(t *testing.T) {
369 | c := newTestDBConn(t)
370 | defer c.Close()
371 |
372 | if r, err := c.Execute("select version()"); err != nil {
373 | t.Fatal(err)
374 | } else {
375 | if v, _ := r.GetString(0, 0); v != ServerVersion {
376 | t.Fatal(v)
377 | }
378 | }
379 | }
380 |
--------------------------------------------------------------------------------
/sqltypes/sqltypes.go:
--------------------------------------------------------------------------------
1 | // Copyright 2012, Google Inc. All rights reserved.
2 | // Use of this source code is governed by a BSD-style
3 | // license that can be found in the LICENSE file.
4 |
5 | // Package sqltypes implements interfaces and types that represent SQL values.
6 | package sqltypes
7 |
8 | import (
9 | "encoding/base64"
10 | "encoding/gob"
11 | "encoding/json"
12 | "fmt"
13 | "strconv"
14 | "time"
15 |
16 | "github.com/siddontang/mixer/hack"
17 | )
18 |
19 | var (
20 | NULL = Value{}
21 | DONTESCAPE = byte(255)
22 | nullstr = []byte("null")
23 | )
24 |
25 | // BinWriter interface is used for encoding values.
26 | // Types like bytes.Buffer conform to this interface.
27 | // We expect the writer objects to be in-memory buffers.
28 | // So, we don't expect the write operations to fail.
29 | type BinWriter interface {
30 | Write([]byte) (int, error)
31 | WriteByte(byte) error
32 | }
33 |
34 | // Value can store any SQL value. NULL is stored as nil.
35 | type Value struct {
36 | Inner InnerValue
37 | }
38 |
39 | // Numeric represents non-fractional SQL number.
40 | type Numeric []byte
41 |
42 | // Fractional represents fractional types like float and decimal
43 | // It's functionally equivalent to Numeric other than how it's constructed
44 | type Fractional []byte
45 |
46 | // String represents any SQL type that needs to be represented using quotes.
47 | type String []byte
48 |
49 | // MakeNumeric makes a Numeric from a []byte without validation.
50 | func MakeNumeric(b []byte) Value {
51 | return Value{Numeric(b)}
52 | }
53 |
54 | // MakeFractional makes a Fractional value from a []byte without validation.
55 | func MakeFractional(b []byte) Value {
56 | return Value{Fractional(b)}
57 | }
58 |
59 | // MakeString makes a String value from a []byte.
60 | func MakeString(b []byte) Value {
61 | return Value{String(b)}
62 | }
63 |
64 | // Raw returns the raw bytes. All types are currently implemented as []byte.
65 | func (v Value) Raw() []byte {
66 | if v.Inner == nil {
67 | return nil
68 | }
69 | return v.Inner.raw()
70 | }
71 |
72 | // String returns the raw value as a string
73 | func (v Value) String() string {
74 | if v.Inner == nil {
75 | return ""
76 | }
77 | return hack.String(v.Inner.raw())
78 | }
79 |
80 | // ParseInt64 will parse a Numeric value into an int64
81 | func (v Value) ParseInt64() (val int64, err error) {
82 | if v.Inner == nil {
83 | return 0, fmt.Errorf("value is null")
84 | }
85 | n, ok := v.Inner.(Numeric)
86 | if !ok {
87 | return 0, fmt.Errorf("value is not Numeric")
88 | }
89 | return strconv.ParseInt(string(n.raw()), 10, 64)
90 | }
91 |
92 | // ParseUint64 will parse a Numeric value into a uint64
93 | func (v Value) ParseUint64() (val uint64, err error) {
94 | if v.Inner == nil {
95 | return 0, fmt.Errorf("value is null")
96 | }
97 | n, ok := v.Inner.(Numeric)
98 | if !ok {
99 | return 0, fmt.Errorf("value is not Numeric")
100 | }
101 | return strconv.ParseUint(string(n.raw()), 10, 64)
102 | }
103 |
104 | // EncodeSql encodes the value into an SQL statement. Can be binary.
105 | func (v Value) EncodeSql(b BinWriter) {
106 | if v.Inner == nil {
107 | if _, err := b.Write(nullstr); err != nil {
108 | panic(err)
109 | }
110 | } else {
111 | v.Inner.encodeSql(b)
112 | }
113 | }
114 |
115 | // EncodeAscii encodes the value using 7-bit clean ascii bytes.
116 | func (v Value) EncodeAscii(b BinWriter) {
117 | if v.Inner == nil {
118 | if _, err := b.Write(nullstr); err != nil {
119 | panic(err)
120 | }
121 | } else {
122 | v.Inner.encodeAscii(b)
123 | }
124 | }
125 |
126 | func (v Value) IsNull() bool {
127 | return v.Inner == nil
128 | }
129 |
130 | func (v Value) IsNumeric() (ok bool) {
131 | if v.Inner != nil {
132 | _, ok = v.Inner.(Numeric)
133 | }
134 | return ok
135 | }
136 |
137 | func (v Value) IsFractional() (ok bool) {
138 | if v.Inner != nil {
139 | _, ok = v.Inner.(Fractional)
140 | }
141 | return ok
142 | }
143 |
144 | func (v Value) IsString() (ok bool) {
145 | if v.Inner != nil {
146 | _, ok = v.Inner.(String)
147 | }
148 | return ok
149 | }
150 |
151 | // MarshalJSON should only be used for testing.
152 | // It's not a complete implementation.
153 | func (v Value) MarshalJSON() ([]byte, error) {
154 | return json.Marshal(v.Inner)
155 | }
156 |
157 | // UnmarshalJSON should only be used for testing.
158 | // It's not a complete implementation.
159 | func (v *Value) UnmarshalJSON(b []byte) error {
160 | if len(b) == 0 {
161 | return fmt.Errorf("error unmarshaling empty bytes")
162 | }
163 | var val interface{}
164 | var err error
165 | switch b[0] {
166 | case '-':
167 | var ival int64
168 | err = json.Unmarshal(b, &ival)
169 | val = ival
170 | case '"':
171 | var bval []byte
172 | err = json.Unmarshal(b, &bval)
173 | val = bval
174 | case 'n': // null
175 | err = json.Unmarshal(b, &val)
176 | default:
177 | var uval uint64
178 | err = json.Unmarshal(b, &uval)
179 | val = uval
180 | }
181 | if err != nil {
182 | return err
183 | }
184 | *v, err = BuildValue(val)
185 | return err
186 | }
187 |
188 | // InnerValue defines methods that need to be supported by all non-null value types.
189 | type InnerValue interface {
190 | raw() []byte
191 | encodeSql(BinWriter)
192 | encodeAscii(BinWriter)
193 | }
194 |
195 | func BuildValue(goval interface{}) (v Value, err error) {
196 | switch bindVal := goval.(type) {
197 | case nil:
198 | // no op
199 | case int:
200 | v = Value{Numeric(strconv.AppendInt(nil, int64(bindVal), 10))}
201 | case int32:
202 | v = Value{Numeric(strconv.AppendInt(nil, int64(bindVal), 10))}
203 | case int64:
204 | v = Value{Numeric(strconv.AppendInt(nil, int64(bindVal), 10))}
205 | case uint:
206 | v = Value{Numeric(strconv.AppendUint(nil, uint64(bindVal), 10))}
207 | case uint32:
208 | v = Value{Numeric(strconv.AppendUint(nil, uint64(bindVal), 10))}
209 | case uint64:
210 | v = Value{Numeric(strconv.AppendUint(nil, uint64(bindVal), 10))}
211 | case float64:
212 | v = Value{Fractional(strconv.AppendFloat(nil, bindVal, 'f', -1, 64))}
213 | case string:
214 | v = Value{String([]byte(bindVal))}
215 | case []byte:
216 | v = Value{String(bindVal)}
217 | case time.Time:
218 | v = Value{String([]byte(bindVal.Format("'2006-01-02 15:04:05'")))}
219 | case Numeric, Fractional, String:
220 | v = Value{bindVal.(InnerValue)}
221 | case Value:
222 | v = bindVal
223 | default:
224 | return Value{}, fmt.Errorf("unsupported bind variable type %T: %v", goval, goval)
225 | }
226 | return v, nil
227 | }
228 |
229 | // BuildNumeric builds a Numeric type that represents any whole number.
230 | // It normalizes the representation to ensure 1:1 mapping between the
231 | // number and its representation.
232 | func BuildNumeric(val string) (n Value, err error) {
233 | if val[0] == '-' || val[0] == '+' {
234 | signed, err := strconv.ParseInt(val, 0, 64)
235 | if err != nil {
236 | return Value{}, err
237 | }
238 | n = Value{Numeric(strconv.AppendInt(nil, signed, 10))}
239 | } else {
240 | unsigned, err := strconv.ParseUint(val, 0, 64)
241 | if err != nil {
242 | return Value{}, err
243 | }
244 | n = Value{Numeric(strconv.AppendUint(nil, unsigned, 10))}
245 | }
246 | return n, nil
247 | }
248 |
249 | func (n Numeric) raw() []byte {
250 | return []byte(n)
251 | }
252 |
253 | func (n Numeric) encodeSql(b BinWriter) {
254 | if _, err := b.Write(n.raw()); err != nil {
255 | panic(err)
256 | }
257 | }
258 |
259 | func (n Numeric) encodeAscii(b BinWriter) {
260 | if _, err := b.Write(n.raw()); err != nil {
261 | panic(err)
262 | }
263 | }
264 |
265 | func (n Numeric) MarshalJSON() ([]byte, error) {
266 | return n.raw(), nil
267 | }
268 |
269 | func (f Fractional) raw() []byte {
270 | return []byte(f)
271 | }
272 |
273 | func (f Fractional) encodeSql(b BinWriter) {
274 | if _, err := b.Write(f.raw()); err != nil {
275 | panic(err)
276 | }
277 | }
278 |
279 | func (f Fractional) encodeAscii(b BinWriter) {
280 | if _, err := b.Write(f.raw()); err != nil {
281 | panic(err)
282 | }
283 | }
284 |
285 | func (s String) raw() []byte {
286 | return []byte(s)
287 | }
288 |
289 | func (s String) encodeSql(b BinWriter) {
290 | writebyte(b, '\'')
291 | for _, ch := range s.raw() {
292 | if encodedChar := SqlEncodeMap[ch]; encodedChar == DONTESCAPE {
293 | writebyte(b, ch)
294 | } else {
295 | writebyte(b, '\\')
296 | writebyte(b, encodedChar)
297 | }
298 | }
299 | writebyte(b, '\'')
300 | }
301 |
302 | func (s String) encodeAscii(b BinWriter) {
303 | writebyte(b, '\'')
304 | encoder := base64.NewEncoder(base64.StdEncoding, b)
305 | encoder.Write(s.raw())
306 | encoder.Close()
307 | writebyte(b, '\'')
308 | }
309 |
310 | func writebyte(b BinWriter, c byte) {
311 | if err := b.WriteByte(c); err != nil {
312 | panic(err)
313 | }
314 | }
315 |
var (
	// SqlEncodeMap specifies how to escape binary data with '\'.
	// Complies to http://dev.mysql.com/doc/refman/5.1/en/string-syntax.html
	// Entry i holds the character written after a backslash for byte i, or
	// DONTESCAPE when byte i is written verbatim. It is filled in by init.
	SqlEncodeMap [256]byte

	// SqlDecodeMap is the reverse of SqlEncodeMap: it maps an escape
	// character back to the byte it stands for.
	SqlDecodeMap [256]byte

	// encodeRef lists every byte that must be escaped, keyed by the byte
	// itself, with the character that follows the backslash as the value.
	encodeRef = map[byte]byte{
		'\x00': '0',
		'\'':   '\'',
		'"':    '"',
		'\b':   'b',
		'\n':   'n',
		'\r':   'r',
		'\t':   't',
		'\x1a': 'Z', // ctrl-Z (26)
		'\\':   '\\',
	}
)
334 |
335 | func init() {
336 | for i := range SqlEncodeMap {
337 | SqlEncodeMap[i] = DONTESCAPE
338 | SqlDecodeMap[i] = DONTESCAPE
339 | }
340 | for i := range SqlEncodeMap {
341 | if to, ok := encodeRef[byte(i)]; ok {
342 | SqlEncodeMap[byte(i)] = to
343 | SqlDecodeMap[to] = byte(i)
344 | }
345 | }
346 | gob.Register(Numeric(nil))
347 | gob.Register(Fractional(nil))
348 | gob.Register(String(nil))
349 | }
350 |
--------------------------------------------------------------------------------
/mysql/resultset.go:
--------------------------------------------------------------------------------
1 | package mysql
2 |
3 | import (
4 | "encoding/binary"
5 | "fmt"
6 | "github.com/siddontang/mixer/hack"
7 | "math"
8 | "strconv"
9 | )
10 |
11 | type RowData []byte
12 |
13 | func (p RowData) Parse(f []*Field, binary bool) ([]interface{}, error) {
14 | if binary {
15 | return p.ParseBinary(f)
16 | } else {
17 | return p.ParseText(f)
18 | }
19 | }
20 |
21 | func (p RowData) ParseText(f []*Field) ([]interface{}, error) {
22 | data := make([]interface{}, len(f))
23 |
24 | var err error
25 | var v []byte
26 | var isNull, isUnsigned bool
27 | var pos int = 0
28 | var n int = 0
29 |
30 | for i := range f {
31 | v, isNull, n, err = LengthEnodedString(p[pos:])
32 | if err != nil {
33 | return nil, err
34 | }
35 |
36 | pos += n
37 |
38 | if isNull {
39 | data[i] = nil
40 | } else {
41 | isUnsigned = (f[i].Flag&UNSIGNED_FLAG > 0)
42 |
43 | switch f[i].Type {
44 | case MYSQL_TYPE_TINY, MYSQL_TYPE_SHORT, MYSQL_TYPE_INT24,
45 | MYSQL_TYPE_LONGLONG, MYSQL_TYPE_YEAR:
46 | if isUnsigned {
47 | data[i], err = strconv.ParseUint(string(v), 10, 64)
48 | } else {
49 | data[i], err = strconv.ParseInt(string(v), 10, 64)
50 | }
51 | case MYSQL_TYPE_FLOAT, MYSQL_TYPE_DOUBLE:
52 | data[i], err = strconv.ParseFloat(string(v), 64)
53 | default:
54 | data[i] = v
55 | }
56 |
57 | if err != nil {
58 | return nil, err
59 | }
60 | }
61 | }
62 |
63 | return data, nil
64 | }
65 |
66 | func (p RowData) ParseBinary(f []*Field) ([]interface{}, error) {
67 | data := make([]interface{}, len(f))
68 |
69 | if p[0] != OK_HEADER {
70 | return nil, ErrMalformPacket
71 | }
72 |
73 | pos := 1 + ((len(f) + 7 + 2) >> 3)
74 |
75 | nullBitmap := p[1:pos]
76 |
77 | var isUnsigned bool
78 | var isNull bool
79 | var n int
80 | var err error
81 | var v []byte
82 | for i := range data {
83 | if nullBitmap[(i+2)/8]&(1<<(uint(i+2)%8)) > 0 {
84 | data[i] = nil
85 | continue
86 | }
87 |
88 | isUnsigned = f[i].Flag&UNSIGNED_FLAG > 0
89 |
90 | switch f[i].Type {
91 | case MYSQL_TYPE_NULL:
92 | data[i] = nil
93 | continue
94 |
95 | case MYSQL_TYPE_TINY:
96 | if isUnsigned {
97 | data[i] = uint64(p[pos])
98 | } else {
99 | data[i] = int64(p[pos])
100 | }
101 | pos++
102 | continue
103 |
104 | case MYSQL_TYPE_SHORT, MYSQL_TYPE_YEAR:
105 | if isUnsigned {
106 | data[i] = uint64(binary.LittleEndian.Uint16(p[pos : pos+2]))
107 | } else {
108 | data[i] = int64((binary.LittleEndian.Uint16(p[pos : pos+2])))
109 | }
110 | pos += 2
111 | continue
112 |
113 | case MYSQL_TYPE_INT24, MYSQL_TYPE_LONG:
114 | if isUnsigned {
115 | data[i] = uint64(binary.LittleEndian.Uint32(p[pos : pos+4]))
116 | } else {
117 | data[i] = int64(binary.LittleEndian.Uint32(p[pos : pos+4]))
118 | }
119 | pos += 4
120 | continue
121 |
122 | case MYSQL_TYPE_LONGLONG:
123 | if isUnsigned {
124 | data[i] = binary.LittleEndian.Uint64(p[pos : pos+8])
125 | } else {
126 | data[i] = int64(binary.LittleEndian.Uint64(p[pos : pos+8]))
127 | }
128 | pos += 8
129 | continue
130 |
131 | case MYSQL_TYPE_FLOAT:
132 | data[i] = float64(math.Float32frombits(binary.LittleEndian.Uint32(p[pos : pos+4])))
133 | pos += 4
134 | continue
135 |
136 | case MYSQL_TYPE_DOUBLE:
137 | data[i] = math.Float64frombits(binary.LittleEndian.Uint64(p[pos : pos+8]))
138 | pos += 8
139 | continue
140 |
141 | case MYSQL_TYPE_DECIMAL, MYSQL_TYPE_NEWDECIMAL, MYSQL_TYPE_VARCHAR,
142 | MYSQL_TYPE_BIT, MYSQL_TYPE_ENUM, MYSQL_TYPE_SET, MYSQL_TYPE_TINY_BLOB,
143 | MYSQL_TYPE_MEDIUM_BLOB, MYSQL_TYPE_LONG_BLOB, MYSQL_TYPE_BLOB,
144 | MYSQL_TYPE_VAR_STRING, MYSQL_TYPE_STRING, MYSQL_TYPE_GEOMETRY:
145 | v, isNull, n, err = LengthEnodedString(p[pos:])
146 | pos += n
147 | if err != nil {
148 | return nil, err
149 | }
150 |
151 | if !isNull {
152 | data[i] = v
153 | continue
154 | } else {
155 | data[i] = nil
156 | continue
157 | }
158 | case MYSQL_TYPE_DATE, MYSQL_TYPE_NEWDATE:
159 | var num uint64
160 | num, isNull, n = LengthEncodedInt(p[pos:])
161 |
162 | pos += n
163 |
164 | if isNull {
165 | data[i] = nil
166 | continue
167 | }
168 |
169 | data[i], err = FormatBinaryDate(int(num), p[pos:])
170 | pos += int(num)
171 |
172 | if err != nil {
173 | return nil, err
174 | }
175 |
176 | case MYSQL_TYPE_TIMESTAMP, MYSQL_TYPE_DATETIME:
177 | var num uint64
178 | num, isNull, n = LengthEncodedInt(p[pos:])
179 |
180 | pos += n
181 |
182 | if isNull {
183 | data[i] = nil
184 | continue
185 | }
186 |
187 | data[i], err = FormatBinaryDateTime(int(num), p[pos:])
188 | pos += int(num)
189 |
190 | if err != nil {
191 | return nil, err
192 | }
193 |
194 | case MYSQL_TYPE_TIME:
195 | var num uint64
196 | num, isNull, n = LengthEncodedInt(p[pos:])
197 |
198 | pos += n
199 |
200 | if isNull {
201 | data[i] = nil
202 | continue
203 | }
204 |
205 | data[i], err = FormatBinaryTime(int(num), p[pos:])
206 | pos += int(num)
207 |
208 | if err != nil {
209 | return nil, err
210 | }
211 |
212 | default:
213 | return nil, fmt.Errorf("Stmt Unknown FieldType %d %s", f[i].Type, f[i].Name)
214 | }
215 | }
216 |
217 | return data, nil
218 | }
219 |
220 | type Resultset struct {
221 | Fields []*Field
222 | FieldNames map[string]int
223 | Values [][]interface{}
224 |
225 | RowDatas []RowData
226 | }
227 |
228 | func (r *Resultset) RowNumber() int {
229 | return len(r.Values)
230 | }
231 |
232 | func (r *Resultset) ColumnNumber() int {
233 | return len(r.Fields)
234 | }
235 |
236 | func (r *Resultset) GetValue(row, column int) (interface{}, error) {
237 | if row >= len(r.Values) || row < 0 {
238 | return nil, fmt.Errorf("invalid row index %d", row)
239 | }
240 |
241 | if column >= len(r.Fields) || column < 0 {
242 | return nil, fmt.Errorf("invalid column index %d", column)
243 | }
244 |
245 | return r.Values[row][column], nil
246 | }
247 |
248 | func (r *Resultset) NameIndex(name string) (int, error) {
249 | if column, ok := r.FieldNames[name]; ok {
250 | return column, nil
251 | } else {
252 | return 0, fmt.Errorf("invalid field name %s", name)
253 | }
254 | }
255 |
256 | func (r *Resultset) GetValueByName(row int, name string) (interface{}, error) {
257 | if column, err := r.NameIndex(name); err != nil {
258 | return nil, err
259 | } else {
260 | return r.GetValue(row, column)
261 | }
262 | }
263 |
264 | func (r *Resultset) IsNull(row, column int) (bool, error) {
265 | d, err := r.GetValue(row, column)
266 | if err != nil {
267 | return false, err
268 | }
269 |
270 | return d == nil, nil
271 | }
272 |
273 | func (r *Resultset) IsNullByName(row int, name string) (bool, error) {
274 | if column, err := r.NameIndex(name); err != nil {
275 | return false, err
276 | } else {
277 | return r.IsNull(row, column)
278 | }
279 | }
280 |
281 | func (r *Resultset) GetUint(row, column int) (uint64, error) {
282 | d, err := r.GetValue(row, column)
283 | if err != nil {
284 | return 0, err
285 | }
286 |
287 | switch v := d.(type) {
288 | case uint64:
289 | return v, nil
290 | case int64:
291 | return uint64(v), nil
292 | case float64:
293 | return uint64(v), nil
294 | case string:
295 | return strconv.ParseUint(v, 10, 64)
296 | case []byte:
297 | return strconv.ParseUint(string(v), 10, 64)
298 | case nil:
299 | return 0, nil
300 | default:
301 | return 0, fmt.Errorf("data type is %T", v)
302 | }
303 | }
304 |
305 | func (r *Resultset) GetUintByName(row int, name string) (uint64, error) {
306 | if column, err := r.NameIndex(name); err != nil {
307 | return 0, err
308 | } else {
309 | return r.GetUint(row, column)
310 | }
311 | }
312 |
313 | func (r *Resultset) GetInt(row, column int) (int64, error) {
314 | v, err := r.GetUint(row, column)
315 | if err != nil {
316 | return 0, err
317 | }
318 |
319 | return int64(v), nil
320 | }
321 |
322 | func (r *Resultset) GetIntByName(row int, name string) (int64, error) {
323 | v, err := r.GetUintByName(row, name)
324 | if err != nil {
325 | return 0, err
326 | }
327 |
328 | return int64(v), nil
329 | }
330 |
331 | func (r *Resultset) GetFloat(row, column int) (float64, error) {
332 | d, err := r.GetValue(row, column)
333 | if err != nil {
334 | return 0, err
335 | }
336 |
337 | switch v := d.(type) {
338 | case float64:
339 | return v, nil
340 | case uint64:
341 | return float64(v), nil
342 | case int64:
343 | return float64(v), nil
344 | case string:
345 | return strconv.ParseFloat(v, 64)
346 | case []byte:
347 | return strconv.ParseFloat(string(v), 64)
348 | case nil:
349 | return 0, nil
350 | default:
351 | return 0, fmt.Errorf("data type is %T", v)
352 | }
353 | }
354 |
355 | func (r *Resultset) GetFloatByName(row int, name string) (float64, error) {
356 | if column, err := r.NameIndex(name); err != nil {
357 | return 0, err
358 | } else {
359 | return r.GetFloat(row, column)
360 | }
361 | }
362 |
363 | func (r *Resultset) GetString(row, column int) (string, error) {
364 | d, err := r.GetValue(row, column)
365 | if err != nil {
366 | return "", err
367 | }
368 |
369 | switch v := d.(type) {
370 | case string:
371 | return v, nil
372 | case []byte:
373 | return hack.String(v), nil
374 | case int64:
375 | return strconv.FormatInt(v, 10), nil
376 | case uint64:
377 | return strconv.FormatUint(v, 10), nil
378 | case float64:
379 | return strconv.FormatFloat(v, 'f', -1, 64), nil
380 | case nil:
381 | return "", nil
382 | default:
383 | return "", fmt.Errorf("data type is %T", v)
384 | }
385 | }
386 |
387 | func (r *Resultset) GetStringByName(row int, name string) (string, error) {
388 | if column, err := r.NameIndex(name); err != nil {
389 | return "", err
390 | } else {
391 | return r.GetString(row, column)
392 | }
393 | }
394 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Mixer
2 |
3 | Mixer is a MySQL proxy powered by Go which aims to supply a simple solution for MySQL sharding.
4 |
5 | ## Features
6 |
7 | - Supports basic SQL statements (select, insert, update, replace, delete)
8 | - Supports transactions
9 | - Splits reads and writes (not fully tested)
10 | - MySQL HA
11 | - Basic SQL Routing
12 | - Supports prepared statement: `COM_STMT_PREPARE`, `COM_STMT_EXECUTE`, etc.
13 |
14 | ## TODO
15 |
16 | - Some admin commands
17 | - Some show command support, e.g. ```show databases```
18 | - Some select system variable support, e.g. ```select @@version```
19 | - Enhance routing rules
20 | - Monitor
21 | - SQL validation check
22 | - Statistics
23 | - Many other things ...
24 |
25 | ## Install
26 |
27 | cd $WORKSPACE
28 | git clone git@github.com:siddontang/mixer.git src/github.com/siddontang/mixer
29 |
30 | cd src/github.com/siddontang/mixer
31 |
32 | ./bootstrap.sh
33 |
34 | . ./dev.env
35 |
36 | make
37 | make test
38 |
39 | ## Keywords
40 |
41 | ### proxy
42 |
43 | A proxy is the bridge connecting clients and the real MySQL servers.
44 |
45 | It acts as a MySQL server too, so clients can communicate with it using the MySQL protocol.
46 |
47 | ### node
48 |
49 | Mixer uses nodes to represent the real remote MySQL servers. A node can have two MySQL servers:
50 |
51 | + master: the main MySQL server. All write operations, and read operations when ```rw_split``` is not set or no slave is configured, will be executed here.
52 | All transactions will be executed here too.
53 | + slave: optional. If ```rw_split``` is set, select operations will be executed here.
54 |
55 | Notice:
56 |
57 | + You can use ```admin upnode``` or ```admin downnode``` commands to bring a specified MySQL server up or down.
58 | + If the master was down, you must use an admin command to bring it up manually.
59 | + You must set up MySQL replication for yourself, mixer does not do it.
60 |
61 | ### schema
62 |
63 | A schema is like a MySQL database: if a client executes the ```use db``` command, ```db``` must exist as a schema.
64 |
65 | A schema contains one or more nodes. When a client uses a schema, every command is routed only to the nodes that belong to that schema.
66 |
67 | ### rule
68 |
69 | You must set some rules for a schema to let the mixer decide how to route SQL statements to different nodes to be executed.
70 |
71 | Mixer uses ```table + key``` to route. Duplicate rules for the same table are not allowed.
72 |
73 | When SQL needs to be routed, mixer does the following steps:
74 |
75 | + Parse SQL and find the table operated on
76 | + If there is no rule for the table, mixer uses the default rule
77 | + If a rule exists, mixer tries to route it with the specified key
78 |
79 | Rules have three types: default, hash and range.
80 |
81 | A schema must have a default rule with only one node assigned.
82 |
83 | For hash and range routing you can see the example below.
84 |
85 | ## admin commands
86 |
87 | Mixer supplies an `admin` statement for administration. The `admin` format is `admin func(arg, ...)`, like `select func(arg, ...)`. An admin password may be added later for safer use.
88 |
89 | Currently supported admin functions:
90 |
91 | - admin upnode(node, servertype, addr);
92 | - admin downnode(node, servertype);
93 | - show proxy config;
94 |
95 | ## Basic Example
96 |
97 | ```
98 | #start mixer
99 | mixer-proxy -config=/etc/mixer.conf
100 |
101 | #another shell
102 | mysql -uroot -h127.0.0.1 -P4000 -p -Dmixer
103 |
104 | Welcome to the MySQL monitor. Commands end with ; or \g.
105 | Your MySQL connection id is 158
106 | Server version: 5.6.19 Homebrew
107 |
108 | mysql> use mixer;
109 | Database changed
110 |
111 | mysql> delete from mixer_test_conn;
112 | Query OK, 3 rows affected (0.04 sec)
113 |
114 | mysql> insert into mixer_test_conn (id, str) values (1, "a");
115 | Query OK, 1 row affected (0.00 sec)
116 |
117 | mysql> insert into mixer_test_conn (id, str) values (2, "b");
118 | Query OK, 1 row affected (0.00 sec)
119 |
120 | mysql> select id, str from mixer_test_conn;
121 | +----+------+
122 | | id | str |
123 | +----+------+
124 | | 1 | a |
125 | | 2 | b |
126 | +----+------+
127 | ```
128 |
129 | ## Hash Sharding Example
130 |
131 | ```
132 | schemas :
133 | -
134 | db : mixer
135 | nodes: [node1, node2, node3]
136 | rules:
137 | default: node1
138 | shard:
139 | -
140 | table: mixer_test_shard_hash
141 | key: id
142 | nodes: [node2, node3]
143 | type: hash
144 |
145 | hash algorithm: value % len(nodes)
146 |
147 | table: mixer_test_shard_hash
148 |
149 | Node: node2, node3
150 | node2 mysql: 127.0.0.1:3307
151 | node3 mysql: 127.0.0.1:3308
152 |
153 | mixer-proxy: 127.0.0.1:4000
154 |
155 | proxy> mysql -uroot -h127.0.0.1 -P4000 -p -Dmixer
156 | node2> mysql -uroot -h127.0.0.1 -P3307 -p -Dmixer
157 | node3> mysql -uroot -h127.0.0.1 -P3308 -p -Dmixer
158 |
159 | proxy> insert into mixer_test_shard_hash (id, str) values (0, "a");
160 | node2> select str from mixer_test_shard_hash where id = 0;
161 | +------+
162 | | str |
163 | +------+
164 | | a |
165 | +------+
166 |
167 | proxy> insert into mixer_test_shard_hash (id, str) values (1, "b");
168 | node3> select str from mixer_test_shard_hash where id = 1;
169 | +------+
170 | | str |
171 | +------+
172 | | b |
173 | +------+
174 |
175 | proxy> select str from mixer_test_shard_hash where id in (0, 1);
176 | +------+
177 | | str |
178 | +------+
179 | | a |
180 | | b |
181 | +------+
182 |
183 | proxy> select str from mixer_test_shard_hash where id = 0 or id = 1;
184 | +------+
185 | | str |
186 | +------+
187 | | a |
188 | | b |
189 | +------+
190 |
191 | proxy> select str from mixer_test_shard_hash where id = 0 and id = 1;
192 | Empty set
193 | ```
194 |
195 |
196 | ## Range Sharding Example
197 |
198 | ```
199 | schemas :
200 | -
201 | db : mixer
202 | nodes: [node1, node2, node3]
203 | rules:
204 | default: node1
205 | shard:
206 | -
207 | table: mixer_test_shard_range
208 | key: id
209 | nodes: [node2, node3]
210 | range: -10000-
211 | type: range
212 |
213 | range algorithm: node key start <= value < node key stop
214 |
215 | table: mixer_test_shard_range
216 |
217 | Node: node2, node3
218 | node2 range: (-inf, 10000)
219 | node3 range: [10000, +inf)
220 | node2 mysql: 127.0.0.1:3307
221 | node3 mysql: 127.0.0.1:3308
222 |
223 | mixer-proxy: 127.0.0.1:4000
224 |
225 | proxy> mysql -uroot -h127.0.0.1 -P4000 -p -Dmixer
226 | node2> mysql -uroot -h127.0.0.1 -P3307 -p -Dmixer
227 | node3> mysql -uroot -h127.0.0.1 -P3308 -p -Dmixer
228 |
229 | proxy> insert into mixer_test_shard_range (id, str) values (0, "a");
230 | node2> select str from mixer_test_shard_range where id = 0;
231 | +------+
232 | | str |
233 | +------+
234 | | a |
235 | +------+
236 |
237 | proxy> insert into mixer_test_shard_range (id, str) values (10000, "b");
238 | node3> select str from mixer_test_shard_range where id = 10000;
239 | +------+
240 | | str |
241 | +------+
242 | | b |
243 | +------+
244 |
245 | proxy> select str from mixer_test_shard_range where id in (0, 10000);
246 | +------+
247 | | str |
248 | +------+
249 | | a |
250 | | b |
251 | +------+
252 |
253 | proxy> select str from mixer_test_shard_range where id = 0 or id = 10000;
254 | +------+
255 | | str |
256 | +------+
257 | | a |
258 | | b |
259 | +------+
260 |
261 | proxy> select str from mixer_test_shard_range where id = 0 and id = 10000;
262 | Empty set
263 |
264 | proxy> select str from mixer_test_shard_range where id > 100;
265 | +------+
266 | | str |
267 | +------+
268 | | b |
269 | +------+
270 |
271 | proxy> select str from mixer_test_shard_range where id < 100;
272 | +------+
273 | | str |
274 | +------+
275 | | a |
276 | +------+
277 |
278 | proxy> select str from mixer_test_shard_range where id >=0 and id < 100000;
279 | +------+
280 | | str |
281 | +------+
282 | | a |
283 | | b |
284 | +------+
285 | ```
286 |
287 | ## Limitations
288 |
289 | ### Select
290 |
291 | + Joins are not supported yet; later, only cross-shard joins will remain unsupported.
292 | + Subselects are not supported yet; later, only cross-shard subselects will remain unsupported.
293 | + Cross-shard "group by" only works correctly when the "group by" key is the routing key
294 | + Cross-shard "order by" only takes effect when the "order by" key also appears as a select expression field
295 |
296 | ```select id from t1 order by id``` is ok.
297 |
298 | ```select str from t1 order by id``` is not ok: mixer does not know how to sort, because it cannot find the `id` data to compare.
299 |
300 | + Limit should be used with "order by", otherwise you may receive incorrect results
301 |
302 | ### Insert
303 |
304 | + "insert into select" is not supported yet; later, only the cross-shard form will remain unsupported.
305 | + Multi-row inserts whose values route to different nodes are not supported.
306 | + "insert on duplicate key update" cannot set the routing key.
307 |
308 | ### Replace
309 |
310 | + Multi-row replaces whose values route to different nodes are not supported.
311 |
312 | ### Update
313 |
314 | + Update can not set the routing key
315 |
316 | ### Set
317 |
318 | + `set autocommit` is supported
319 | + `set names <charset>` is supported
320 |
321 | ### Range Rule
322 |
323 | + Only int64 numeric ranges are supported
324 |
325 | ## Caveat
326 |
327 | + Mixer uses a simple two-phase scheme (begin, execute, commit) for write operations spanning multiple nodes. If some nodes commit successfully while others fail, your data may become inconsistent, and you must recover it yourself.
328 | + You must design your routing rule and write SQL carefully. (e.g. if your where condition contains no routing key, mixer will route the SQL to all nodes, maybe).
329 |
330 | ## Why not [vitess](https://github.com/youtube/vitess)?
331 |
332 | Vitess is very awesome, and I use some of its code like sqlparser. Why not use vitess directly? Maybe below:
333 |
334 | + Vitess is too huge for me, I need a simple proxy
335 | + Vitess uses an RPC protocol based on BSON, I want to use the MySQL protocol
336 | + Most likely, something has gone wrong in my head
337 |
338 | ## Status
339 |
340 | Mixer is still under development and should not be used in production.
341 |
342 | ## Feedback
343 |
344 | Email: ,
345 |
--------------------------------------------------------------------------------
/proxy/conn_shard_test.go:
--------------------------------------------------------------------------------
1 | package proxy
2 |
3 | import (
4 | "fmt"
5 | "github.com/siddontang/mixer/mysql"
6 | "reflect"
7 | "testing"
8 | )
9 |
10 | func testShard_Insert(t *testing.T, table string, node string, id int, str string) {
11 | conn := newTestDBConn(t)
12 |
13 | s := fmt.Sprintf(`insert into %s (id, str) values (%d, "%s")`, table, id, str)
14 | if r, err := conn.Execute(s); err != nil {
15 | t.Fatal(s, err)
16 | } else if r.AffectedRows != 1 {
17 | t.Fatal(r.AffectedRows)
18 | }
19 | s = fmt.Sprintf(`select str from %s where id = %d`, table, id)
20 |
21 | n := newTestServer(t).nodes[node]
22 | c, err := n.getMasterConn()
23 | if err != nil {
24 | t.Fatal(s, err)
25 | } else {
26 | if r, err := c.Execute(s); err != nil {
27 | t.Fatal(s, err)
28 | } else if v, _ := r.GetString(0, 0); v != str {
29 | t.Fatal(s, v)
30 | }
31 | }
32 |
33 | if r, err := conn.Execute(s); err != nil {
34 | t.Fatal(s, err)
35 | } else if v, _ := r.GetString(0, 0); v != str {
36 | t.Fatal(s, v)
37 | }
38 | }
39 |
40 | func testShard_Select(t *testing.T, table string, where string, strs ...string) {
41 | sql := fmt.Sprintf("select str from %s where %s", table, where)
42 | conn := newTestDBConn(t)
43 |
44 | r, err := conn.Execute(sql)
45 | if err != nil {
46 | t.Fatal(sql, err)
47 | } else if r.RowNumber() != len(strs) {
48 | t.Fatal(sql, r.RowNumber(), len(strs))
49 | }
50 |
51 | m := map[string]struct{}{}
52 | for _, s := range strs {
53 | m[s] = struct{}{}
54 | }
55 |
56 | for i := 0; i < r.RowNumber(); i++ {
57 | if v, err := r.GetString(i, 0); err != nil {
58 | t.Fatal(sql, err)
59 | } else if _, ok := m[v]; !ok {
60 | t.Fatal(sql, v, "no in check strs")
61 | } else {
62 | delete(m, v)
63 | }
64 | }
65 |
66 | if len(m) != 0 {
67 | t.Fatal(sql, "invalid select")
68 | }
69 | }
70 |
71 | func testShard_StmtInsert(t *testing.T, table string, node string, id int, str string) {
72 | conn := newTestDBConn(t)
73 |
74 | s := fmt.Sprintf(`insert into %s (id, str) values (?, ?)`, table)
75 | if r, err := conn.Execute(s, id, str); err != nil {
76 | t.Fatal(s, err)
77 | } else if r.AffectedRows != 1 {
78 | t.Fatal(r.AffectedRows)
79 | }
80 | s = fmt.Sprintf(`select str from %s where id = ?`, table)
81 |
82 | n := newTestServer(t).nodes[node]
83 | c, err := n.getMasterConn()
84 | if err != nil {
85 | t.Fatal(s, err)
86 | } else {
87 | if r, err := c.Execute(s, id); err != nil {
88 | t.Fatal(s, err)
89 | } else if v, _ := r.GetString(0, 0); v != str {
90 | t.Fatal(s, v)
91 | }
92 | }
93 |
94 | if r, err := conn.Execute(s, id); err != nil {
95 | t.Fatal(s, err)
96 | } else if v, _ := r.GetString(0, 0); v != str {
97 | t.Fatal(s, v)
98 | }
99 | }
100 |
101 | func testShard_StmtSelect(t *testing.T, table string, where string, args []interface{}, strs ...string) {
102 | sql := fmt.Sprintf("select str from %s where %s", table, where)
103 | conn := newTestDBConn(t)
104 |
105 | r, err := conn.Execute(sql, args...)
106 | if err != nil {
107 | t.Fatal(sql, err)
108 | } else if r.RowNumber() != len(strs) {
109 | t.Fatal(sql, r.RowNumber(), len(strs))
110 | }
111 |
112 | m := map[string]struct{}{}
113 | for _, s := range strs {
114 | m[s] = struct{}{}
115 | }
116 |
117 | for i := 0; i < r.RowNumber(); i++ {
118 | if v, err := r.GetString(i, 0); err != nil {
119 | t.Fatal(sql, err)
120 | } else if _, ok := m[v]; !ok {
121 | t.Fatal(sql, v, "no in check strs")
122 | } else {
123 | delete(m, v)
124 | }
125 | }
126 |
127 | if len(m) != 0 {
128 | t.Fatal(sql, "invalid select")
129 | }
130 | }
131 |
132 | func TestShard_DeleteHashTable(t *testing.T) {
133 | s := `drop table if exists mixer_test_shard_hash`
134 |
135 | server := newTestServer(t)
136 |
137 | for _, n := range server.nodes {
138 | if n.String() != "node2" && n.String() != "node3" {
139 | continue
140 | }
141 | c, err := n.getMasterConn()
142 | if err != nil {
143 | t.Fatal(err)
144 | }
145 |
146 | c.UseDB("mixer")
147 | defer c.Close()
148 | if _, err := c.Execute(s); err != nil {
149 | t.Fatal(err)
150 | }
151 |
152 | }
153 | }
154 |
155 | func TestShard_CreateHashTable(t *testing.T) {
156 | s := `CREATE TABLE IF NOT EXISTS mixer_test_shard_hash (
157 | id BIGINT(64) UNSIGNED NOT NULL,
158 | str VARCHAR(256),
159 | PRIMARY KEY (id)
160 | ) ENGINE=InnoDB DEFAULT CHARSET=utf8`
161 |
162 | server := newTestServer(t)
163 |
164 | for _, n := range server.nodes {
165 | if n.String() != "node2" && n.String() != "node3" {
166 | continue
167 | }
168 | c, err := n.getMasterConn()
169 | if err != nil {
170 | t.Fatal(err)
171 | }
172 |
173 | c.UseDB("mixer")
174 | defer c.Close()
175 | if _, err := c.Execute(s); err != nil {
176 | t.Fatal(err)
177 | }
178 | }
179 | }
180 |
181 | func TestShard_Hash(t *testing.T) {
182 | table := "mixer_test_shard_hash"
183 | testShard_Insert(t, table, "node2", 0, "a")
184 | testShard_Insert(t, table, "node3", 1, "b")
185 | testShard_Insert(t, table, "node2", 2, "c")
186 | testShard_Insert(t, table, "node3", 3, "d")
187 |
188 | testShard_Select(t, table, "id = 2", "c")
189 | testShard_Select(t, table, "id = 2 or id = 3", "c", "d")
190 | testShard_Select(t, table, "id = 2 and id = 3")
191 | testShard_Select(t, table, "id in (0, 1, 3)", "a", "b", "d")
192 |
193 | testShard_StmtInsert(t, table, "node2", 10, "a")
194 | testShard_StmtInsert(t, table, "node3", 11, "b")
195 | testShard_StmtInsert(t, table, "node2", 12, "c")
196 | testShard_StmtInsert(t, table, "node3", 13, "d")
197 |
198 | testShard_StmtSelect(t, table, "id = ?", []interface{}{12}, "c")
199 | testShard_StmtSelect(t, table, "id = ? or id = ?", []interface{}{12, 13}, "c", "d")
200 | testShard_StmtSelect(t, table, "id = ? and id = ?", []interface{}{12, 13})
201 | testShard_StmtSelect(t, table, "id in (?, ?, ?)", []interface{}{10, 11, 13}, "a", "b", "d")
202 | }
203 |
204 | func testExecute(t *testing.T, sql string) *mysql.Result {
205 | conn := newTestDBConn(t)
206 |
207 | r, err := conn.Execute(sql)
208 | if err != nil {
209 | t.Fatal(err)
210 | }
211 |
212 | return r
213 | }
214 |
215 | func testShared_SelectOrderBy(t *testing.T, table string, where string, v [][]interface{}) {
216 | sql := fmt.Sprintf("select id, str from %s where %s", table, where)
217 |
218 | r := testExecute(t, sql)
219 |
220 | if !reflect.DeepEqual(r.Values, v) {
221 | t.Fatal(fmt.Sprintf("%v != %v", r.Values, v))
222 | }
223 | }
224 |
225 | func TestShard_HashOrderByLimit(t *testing.T) {
226 | table := "mixer_test_shard_hash"
227 |
228 | testShard_Insert(t, table, "node2", 4, "a")
229 | testShard_Insert(t, table, "node3", 5, "a")
230 | testShard_Insert(t, table, "node2", 6, "b")
231 | testShard_Insert(t, table, "node3", 7, "b")
232 |
233 | var v [][]interface{}
234 | v = [][]interface{}{
235 | []interface{}{uint64(7), []byte("b")},
236 | []interface{}{uint64(6), []byte("b")},
237 | []interface{}{uint64(5), []byte("a")},
238 | []interface{}{uint64(4), []byte("a")},
239 | }
240 |
241 | testShared_SelectOrderBy(t, table, "id in (4,5,6,7) order by id desc", v)
242 |
243 | v = [][]interface{}{
244 | []interface{}{uint64(6), []byte("b")},
245 | []interface{}{uint64(7), []byte("b")},
246 | []interface{}{uint64(4), []byte("a")},
247 | []interface{}{uint64(5), []byte("a")},
248 | }
249 |
250 | testShared_SelectOrderBy(t, table, "id in (4,5,6,7) order by str desc, id asc", v)
251 |
252 | v = [][]interface{}{
253 | []interface{}{uint64(6), []byte("b")},
254 | []interface{}{uint64(7), []byte("b")},
255 | }
256 |
257 | testShared_SelectOrderBy(t, table, "id in (4,5,6,7) order by str desc, id asc limit 0, 2", v)
258 |
259 | v = [][]interface{}{
260 | []interface{}{uint64(5), []byte("a")},
261 | }
262 |
263 | testShared_SelectOrderBy(t, table, "id in (4,5,6,7) order by str desc, id asc limit 1, 2", v)
264 |
265 | }
266 |
267 | func TestShard_DeleteRangeTable(t *testing.T) {
268 | s := `drop table if exists mixer_test_shard_range`
269 |
270 | server := newTestServer(t)
271 |
272 | for _, n := range server.nodes {
273 | if n.String() != "node2" && n.String() != "node3" {
274 | continue
275 | }
276 | c, err := n.getMasterConn()
277 | if err != nil {
278 | t.Fatal(err)
279 | }
280 |
281 | c.UseDB("mixer")
282 | defer c.Close()
283 | if _, err := c.Execute(s); err != nil {
284 | t.Fatal(err)
285 | }
286 |
287 | }
288 | }
289 |
290 | func TestShard_CreateRangeTable(t *testing.T) {
291 | s := `CREATE TABLE IF NOT EXISTS mixer_test_shard_range (
292 | id BIGINT(64) UNSIGNED NOT NULL,
293 | str VARCHAR(256),
294 | PRIMARY KEY (id)
295 | ) ENGINE=InnoDB DEFAULT CHARSET=utf8`
296 |
297 | server := newTestServer(t)
298 |
299 | for _, n := range server.nodes {
300 | if n.String() != "node2" && n.String() != "node3" {
301 | continue
302 | }
303 | c, err := n.getMasterConn()
304 | if err != nil {
305 | t.Fatal(err)
306 | }
307 |
308 | c.UseDB("mixer")
309 | defer c.Close()
310 | if _, err := c.Execute(s); err != nil {
311 | t.Fatal(err)
312 | }
313 |
314 | }
315 | }
316 |
317 | func TestShard_Range(t *testing.T) {
318 | table := "mixer_test_shard_range"
319 | testShard_Insert(t, table, "node2", 0, "a")
320 | testShard_Insert(t, table, "node3", 10000, "b")
321 | testShard_StmtInsert(t, table, "node2", 2, "c")
322 | testShard_StmtInsert(t, table, "node3", 10001, "d")
323 |
324 | testShard_Select(t, table, "id = 2", "c")
325 | testShard_Select(t, table, "id = 2 or id = 10001", "c", "d")
326 | testShard_Select(t, table, "id = 2 and id = 10001")
327 | testShard_Select(t, table, "id in (0, 10000, 10001)", "a", "b", "d")
328 | testShard_Select(t, table, "id < 1 or id >= 10000", "a", "b", "d")
329 | testShard_Select(t, table, "id > 1 and id <= 10000", "b", "c")
330 | testShard_Select(t, table, "id < 1 and id >= 10000")
331 |
332 | testShard_StmtSelect(t, table, "id = ?", []interface{}{2}, "c")
333 | testShard_StmtSelect(t, table, "id = ? or id = ?", []interface{}{2, 10001}, "c", "d")
334 | testShard_StmtSelect(t, table, "id = ? and id = ?", []interface{}{2, 10001})
335 | testShard_StmtSelect(t, table, "id in (?, ?, ?)", []interface{}{0, 10000, 10001}, "a", "b", "d")
336 | testShard_StmtSelect(t, table, "id < ? or id >= ?", []interface{}{1, 10000}, "a", "b", "d")
337 | testShard_StmtSelect(t, table, "id > ? and id <= ?", []interface{}{1, 10000}, "b", "c")
338 | testShard_StmtSelect(t, table, "id < ? and id >= ?", []interface{}{1, 10000})
339 | }
340 |
--------------------------------------------------------------------------------
/proxy/conn_stmt.go:
--------------------------------------------------------------------------------
1 | package proxy
2 |
3 | import (
4 | "encoding/binary"
5 | "fmt"
6 | . "github.com/siddontang/mixer/mysql"
7 | "github.com/siddontang/mixer/sqlparser"
8 | "math"
9 | "strconv"
10 | "strings"
11 | )
12 |
13 | var paramFieldData []byte
14 | var columnFieldData []byte
15 |
16 | func init() {
17 | var p = &Field{Name: []byte("?")}
18 | var c = &Field{}
19 |
20 | paramFieldData = p.Dump()
21 | columnFieldData = c.Dump()
22 | }
23 |
24 | type Stmt struct {
25 | id uint32
26 |
27 | params int
28 | columns int
29 |
30 | args []interface{}
31 |
32 | s sqlparser.Statement
33 |
34 | sql string
35 | }
36 |
37 | func (s *Stmt) ResetParams() {
38 | s.args = make([]interface{}, s.params)
39 | }
40 |
41 | func (c *Conn) handleStmtPrepare(sql string) error {
42 | if c.schema == nil {
43 | return NewDefaultError(ER_NO_DB_ERROR)
44 | }
45 |
46 | s := new(Stmt)
47 |
48 | sql = strings.TrimRight(sql, ";")
49 |
50 | var err error
51 | s.s, err = sqlparser.Parse(sql)
52 | if err != nil {
53 | return fmt.Errorf(`parse sql "%s" error`, sql)
54 | }
55 |
56 | s.sql = sql
57 |
58 | var tableName string
59 | switch s := s.s.(type) {
60 | case *sqlparser.Select:
61 | tableName = nstring(s.From)
62 | case *sqlparser.Insert:
63 | tableName = nstring(s.Table)
64 | case *sqlparser.Update:
65 | tableName = nstring(s.Table)
66 | case *sqlparser.Delete:
67 | tableName = nstring(s.Table)
68 | case *sqlparser.Replace:
69 | tableName = nstring(s.Table)
70 | default:
71 | return fmt.Errorf(`unsupport prepare sql "%s"`, sql)
72 | }
73 |
74 | r := c.schema.rule.GetRule(tableName)
75 |
76 | n := c.server.getNode(r.Nodes[0])
77 |
78 | if co, err := n.getMasterConn(); err != nil {
79 | return fmt.Errorf("prepare error %s", err)
80 | } else {
81 | defer co.Close()
82 |
83 | if err = co.UseDB(c.schema.db); err != nil {
84 | return fmt.Errorf("parepre error %s", err)
85 | }
86 |
87 | if t, err := co.Prepare(sql); err != nil {
88 | return fmt.Errorf("parepre error %s", err)
89 | } else {
90 |
91 | s.params = t.ParamNum()
92 | s.columns = t.ColumnNum()
93 | }
94 | }
95 |
96 | s.id = c.stmtId
97 | c.stmtId++
98 |
99 | if err = c.writePrepare(s); err != nil {
100 | return err
101 | }
102 |
103 | s.ResetParams()
104 |
105 | c.stmts[s.id] = s
106 |
107 | return nil
108 | }
109 |
110 | func (c *Conn) writePrepare(s *Stmt) error {
111 | data := make([]byte, 4, 128)
112 |
113 | //status ok
114 | data = append(data, 0)
115 | //stmt id
116 | data = append(data, Uint32ToBytes(s.id)...)
117 | //number columns
118 | data = append(data, Uint16ToBytes(uint16(s.columns))...)
119 | //number params
120 | data = append(data, Uint16ToBytes(uint16(s.params))...)
121 | //filter [00]
122 | data = append(data, 0)
123 | //warning count
124 | data = append(data, 0, 0)
125 |
126 | if err := c.writePacket(data); err != nil {
127 | return err
128 | }
129 |
130 | if s.params > 0 {
131 | for i := 0; i < s.params; i++ {
132 | data = data[0:4]
133 | data = append(data, []byte(paramFieldData)...)
134 |
135 | if err := c.writePacket(data); err != nil {
136 | return err
137 | }
138 | }
139 |
140 | if err := c.writeEOF(c.status); err != nil {
141 | return err
142 | }
143 | }
144 |
145 | if s.columns > 0 {
146 | for i := 0; i < s.columns; i++ {
147 | data = data[0:4]
148 | data = append(data, []byte(columnFieldData)...)
149 |
150 | if err := c.writePacket(data); err != nil {
151 | return err
152 | }
153 | }
154 |
155 | if err := c.writeEOF(c.status); err != nil {
156 | return err
157 | }
158 |
159 | }
160 | return nil
161 | }
162 |
163 | func (c *Conn) handleStmtExecute(data []byte) error {
164 | if len(data) < 9 {
165 | return ErrMalformPacket
166 | }
167 |
168 | pos := 0
169 | id := binary.LittleEndian.Uint32(data[0:4])
170 | pos += 4
171 |
172 | s, ok := c.stmts[id]
173 | if !ok {
174 | return NewDefaultError(ER_UNKNOWN_STMT_HANDLER,
175 | strconv.FormatUint(uint64(id), 10), "stmt_execute")
176 | }
177 |
178 | flag := data[pos]
179 | pos++
180 | //now we only support CURSOR_TYPE_NO_CURSOR flag
181 | if flag != 0 {
182 | return NewError(ER_UNKNOWN_ERROR, fmt.Sprintf("unsupported flag %d", flag))
183 | }
184 |
185 | //skip iteration-count, always 1
186 | pos += 4
187 |
188 | var nullBitmaps []byte
189 | var paramTypes []byte
190 | var paramValues []byte
191 |
192 | paramNum := s.params
193 |
194 | if paramNum > 0 {
195 | nullBitmapLen := (s.params + 7) >> 3
196 | if len(data) < (pos + nullBitmapLen + 1) {
197 | return ErrMalformPacket
198 | }
199 | nullBitmaps = data[pos : pos+nullBitmapLen]
200 | pos += nullBitmapLen
201 |
202 | //new param bound flag
203 | if data[pos] == 1 {
204 | pos++
205 | if len(data) < (pos + (paramNum << 1)) {
206 | return ErrMalformPacket
207 | }
208 |
209 | paramTypes = data[pos : pos+(paramNum<<1)]
210 | pos += (paramNum << 1)
211 |
212 | paramValues = data[pos:]
213 | }
214 |
215 | if err := c.bindStmtArgs(s, nullBitmaps, paramTypes, paramValues); err != nil {
216 | return err
217 | }
218 | }
219 |
220 | var err error
221 |
222 | switch stmt := s.s.(type) {
223 | case *sqlparser.Select:
224 | err = c.handleSelect(stmt, s.sql, s.args)
225 | case *sqlparser.Insert:
226 | err = c.handleExec(s.s, s.sql, s.args)
227 | case *sqlparser.Update:
228 | err = c.handleExec(s.s, s.sql, s.args)
229 | case *sqlparser.Delete:
230 | err = c.handleExec(s.s, s.sql, s.args)
231 | case *sqlparser.Replace:
232 | err = c.handleExec(s.s, s.sql, s.args)
233 | default:
234 | err = fmt.Errorf("command %T not supported now", stmt)
235 | }
236 |
237 | s.ResetParams()
238 |
239 | return err
240 | }
241 |
242 | func (c *Conn) bindStmtArgs(s *Stmt, nullBitmap, paramTypes, paramValues []byte) error {
243 | args := s.args
244 |
245 | pos := 0
246 |
247 | var v []byte
248 | var n int = 0
249 | var isNull bool
250 | var err error
251 |
252 | for i := 0; i < s.params; i++ {
253 | if nullBitmap[i>>3]&(1<<(uint(i)%8)) > 0 {
254 | args[i] = nil
255 | continue
256 | }
257 |
258 | tp := paramTypes[i<<1]
259 | isUnsigned := (paramTypes[(i<<1)+1] & 0x80) > 0
260 |
261 | switch tp {
262 | case MYSQL_TYPE_NULL:
263 | args[i] = nil
264 | continue
265 |
266 | case MYSQL_TYPE_TINY:
267 | if len(paramValues) < (pos + 1) {
268 | return ErrMalformPacket
269 | }
270 |
271 | if isUnsigned {
272 | args[i] = uint8(paramValues[pos])
273 | } else {
274 | args[i] = int8(paramValues[pos])
275 | }
276 |
277 | pos++
278 | continue
279 |
280 | case MYSQL_TYPE_SHORT, MYSQL_TYPE_YEAR:
281 | if len(paramValues) < (pos + 2) {
282 | return ErrMalformPacket
283 | }
284 |
285 | if isUnsigned {
286 | args[i] = uint16(binary.LittleEndian.Uint16(paramValues[pos : pos+2]))
287 | } else {
288 | args[i] = int16((binary.LittleEndian.Uint16(paramValues[pos : pos+2])))
289 | }
290 | pos += 2
291 | continue
292 |
293 | case MYSQL_TYPE_INT24, MYSQL_TYPE_LONG:
294 | if len(paramValues) < (pos + 4) {
295 | return ErrMalformPacket
296 | }
297 |
298 | if isUnsigned {
299 | args[i] = uint32(binary.LittleEndian.Uint32(paramValues[pos : pos+4]))
300 | } else {
301 | args[i] = int32(binary.LittleEndian.Uint32(paramValues[pos : pos+4]))
302 | }
303 | pos += 4
304 | continue
305 |
306 | case MYSQL_TYPE_LONGLONG:
307 | if len(paramValues) < (pos + 8) {
308 | return ErrMalformPacket
309 | }
310 |
311 | if isUnsigned {
312 | args[i] = binary.LittleEndian.Uint64(paramValues[pos : pos+8])
313 | } else {
314 | args[i] = int64(binary.LittleEndian.Uint64(paramValues[pos : pos+8]))
315 | }
316 | pos += 8
317 | continue
318 |
319 | case MYSQL_TYPE_FLOAT:
320 | if len(paramValues) < (pos + 4) {
321 | return ErrMalformPacket
322 | }
323 |
324 | args[i] = float32(math.Float32frombits(binary.LittleEndian.Uint32(paramValues[pos : pos+4])))
325 | pos += 4
326 | continue
327 |
328 | case MYSQL_TYPE_DOUBLE:
329 | if len(paramValues) < (pos + 8) {
330 | return ErrMalformPacket
331 | }
332 |
333 | args[i] = math.Float64frombits(binary.LittleEndian.Uint64(paramValues[pos : pos+8]))
334 | pos += 8
335 | continue
336 |
337 | case MYSQL_TYPE_DECIMAL, MYSQL_TYPE_NEWDECIMAL, MYSQL_TYPE_VARCHAR,
338 | MYSQL_TYPE_BIT, MYSQL_TYPE_ENUM, MYSQL_TYPE_SET, MYSQL_TYPE_TINY_BLOB,
339 | MYSQL_TYPE_MEDIUM_BLOB, MYSQL_TYPE_LONG_BLOB, MYSQL_TYPE_BLOB,
340 | MYSQL_TYPE_VAR_STRING, MYSQL_TYPE_STRING, MYSQL_TYPE_GEOMETRY,
341 | MYSQL_TYPE_DATE, MYSQL_TYPE_NEWDATE,
342 | MYSQL_TYPE_TIMESTAMP, MYSQL_TYPE_DATETIME, MYSQL_TYPE_TIME:
343 | if len(paramValues) < (pos + 1) {
344 | return ErrMalformPacket
345 | }
346 |
347 | v, isNull, n, err = LengthEnodedString(paramValues[pos:])
348 | pos += n
349 | if err != nil {
350 | return err
351 | }
352 |
353 | if !isNull {
354 | args[i] = v
355 | continue
356 | } else {
357 | args[i] = nil
358 | continue
359 | }
360 | default:
361 | return fmt.Errorf("Stmt Unknown FieldType %d", tp)
362 | }
363 | }
364 | return nil
365 | }
366 |
367 | func (c *Conn) handleStmtSendLongData(data []byte) error {
368 | if len(data) < 6 {
369 | return ErrMalformPacket
370 | }
371 |
372 | id := binary.LittleEndian.Uint32(data[0:4])
373 |
374 | s, ok := c.stmts[id]
375 | if !ok {
376 | return NewDefaultError(ER_UNKNOWN_STMT_HANDLER,
377 | strconv.FormatUint(uint64(id), 10), "stmt_send_longdata")
378 | }
379 |
380 | paramId := binary.LittleEndian.Uint16(data[4:6])
381 | if paramId >= uint16(s.params) {
382 | return NewDefaultError(ER_WRONG_ARGUMENTS, "stmt_send_longdata")
383 | }
384 |
385 | if s.args[paramId] == nil {
386 | s.args[paramId] = data[6:]
387 | } else {
388 | if b, ok := s.args[paramId].([]byte); ok {
389 | b = append(b, data[6:]...)
390 | s.args[paramId] = b
391 | } else {
392 | return fmt.Errorf("invalid param long data type %T", s.args[paramId])
393 | }
394 | }
395 |
396 | return nil
397 | }
398 |
399 | func (c *Conn) handleStmtReset(data []byte) error {
400 | if len(data) < 4 {
401 | return ErrMalformPacket
402 | }
403 |
404 | id := binary.LittleEndian.Uint32(data[0:4])
405 |
406 | s, ok := c.stmts[id]
407 | if !ok {
408 | return NewDefaultError(ER_UNKNOWN_STMT_HANDLER,
409 | strconv.FormatUint(uint64(id), 10), "stmt_reset")
410 | }
411 |
412 | s.ResetParams()
413 |
414 | return c.writeOK(nil)
415 | }
416 |
417 | func (c *Conn) handleStmtClose(data []byte) error {
418 | if len(data) < 4 {
419 | return nil
420 | }
421 |
422 | id := binary.LittleEndian.Uint32(data[0:4])
423 |
424 | delete(c.stmts, id)
425 |
426 | return nil
427 | }
428 |
--------------------------------------------------------------------------------
/proxy/conn_query.go:
--------------------------------------------------------------------------------
1 | package proxy
2 |
3 | import (
4 | "fmt"
5 | "github.com/siddontang/mixer/client"
6 | "github.com/siddontang/mixer/hack"
7 | . "github.com/siddontang/mixer/mysql"
8 | "github.com/siddontang/mixer/sqlparser"
9 | "strconv"
10 | "strings"
11 | "sync"
12 | )
13 |
14 | func (c *Conn) handleQuery(sql string) (err error) {
15 | defer func() {
16 | if e := recover(); e != nil {
17 | err = fmt.Errorf("execute %s error %v", sql, e)
18 | return
19 | }
20 | }()
21 |
22 | sql = strings.TrimRight(sql, ";")
23 |
24 | var stmt sqlparser.Statement
25 | stmt, err = sqlparser.Parse(sql)
26 | if err != nil {
27 | return fmt.Errorf(`parse sql "%s" error`, sql)
28 | }
29 |
30 | switch v := stmt.(type) {
31 | case *sqlparser.Select:
32 | return c.handleSelect(v, sql, nil)
33 | case *sqlparser.Insert:
34 | return c.handleExec(stmt, sql, nil)
35 | case *sqlparser.Update:
36 | return c.handleExec(stmt, sql, nil)
37 | case *sqlparser.Delete:
38 | return c.handleExec(stmt, sql, nil)
39 | case *sqlparser.Replace:
40 | return c.handleExec(stmt, sql, nil)
41 | case *sqlparser.Set:
42 | return c.handleSet(v)
43 | case *sqlparser.Begin:
44 | return c.handleBegin()
45 | case *sqlparser.Commit:
46 | return c.handleCommit()
47 | case *sqlparser.Rollback:
48 | return c.handleRollback()
49 | case *sqlparser.SimpleSelect:
50 | return c.handleSimpleSelect(sql, v)
51 | case *sqlparser.Show:
52 | return c.handleShow(sql, v)
53 | case *sqlparser.Admin:
54 | return c.handleAdmin(v)
55 | default:
56 | return fmt.Errorf("statement %T not support now", stmt)
57 | }
58 |
59 | return nil
60 | }
61 |
62 | func (c *Conn) getShardList(stmt sqlparser.Statement, bindVars map[string]interface{}) ([]*Node, error) {
63 | if c.schema == nil {
64 | return nil, NewDefaultError(ER_NO_DB_ERROR)
65 | }
66 |
67 | ns, err := sqlparser.GetStmtShardList(stmt, c.schema.rule, bindVars)
68 | if err != nil {
69 | return nil, err
70 | }
71 |
72 | if len(ns) == 0 {
73 | return nil, nil
74 | }
75 |
76 | n := make([]*Node, 0, len(ns))
77 | for _, name := range ns {
78 | n = append(n, c.server.getNode(name))
79 | }
80 | return n, nil
81 | }
82 |
83 | func (c *Conn) getConn(n *Node, isSelect bool) (co *client.SqlConn, err error) {
84 | if !c.needBeginTx() {
85 | if isSelect {
86 | co, err = n.getSelectConn()
87 | } else {
88 | co, err = n.getMasterConn()
89 | }
90 | if err != nil {
91 | return
92 | }
93 | } else {
94 | var ok bool
95 | c.Lock()
96 | co, ok = c.txConns[n]
97 | c.Unlock()
98 |
99 | if !ok {
100 | if co, err = n.getMasterConn(); err != nil {
101 | return
102 | }
103 |
104 | if err = co.Begin(); err != nil {
105 | return
106 | }
107 |
108 | c.Lock()
109 | c.txConns[n] = co
110 | c.Unlock()
111 | }
112 | }
113 |
114 | //todo, set conn charset, etc...
115 | if err = co.UseDB(c.schema.db); err != nil {
116 | return
117 | }
118 |
119 | if err = co.SetCharset(c.charset); err != nil {
120 | return
121 | }
122 |
123 | return
124 | }
125 |
126 | func (c *Conn) getShardConns(isSelect bool,stmt sqlparser.Statement, bindVars map[string]interface{}) ([]*client.SqlConn, error) {
127 | nodes, err := c.getShardList(stmt, bindVars)
128 | if err != nil {
129 | return nil, err
130 | } else if nodes == nil {
131 | return nil, nil
132 | }
133 |
134 | conns := make([]*client.SqlConn, 0, len(nodes))
135 |
136 | var co *client.SqlConn
137 | for _, n := range nodes {
138 | co, err = c.getConn(n, isSelect)
139 | if err != nil {
140 | break
141 | }
142 |
143 | conns = append(conns, co)
144 | }
145 |
146 | return conns, err
147 | }
148 |
149 | func (c *Conn) executeInShard(conns []*client.SqlConn, sql string, args []interface{}) ([]*Result, error) {
150 | var wg sync.WaitGroup
151 | wg.Add(len(conns))
152 |
153 | rs := make([]interface{}, len(conns))
154 |
155 | f := func(rs []interface{}, i int, co *client.SqlConn) {
156 | r, err := co.Execute(sql, args...)
157 | if err != nil {
158 | rs[i] = err
159 | } else {
160 | rs[i] = r
161 | }
162 |
163 | wg.Done()
164 | }
165 |
166 | for i, co := range conns {
167 | go f(rs, i, co)
168 | }
169 |
170 | wg.Wait()
171 |
172 | var err error
173 | r := make([]*Result, len(conns))
174 | for i, v := range rs {
175 | if e, ok := v.(error); ok {
176 | err = e
177 | break
178 | }
179 | r[i] = rs[i].(*Result)
180 | }
181 |
182 | return r, err
183 | }
184 |
185 | func (c *Conn) closeShardConns(conns []*client.SqlConn, rollback bool) {
186 | if c.isInTransaction() {
187 | return
188 | }
189 |
190 | for _, co := range conns {
191 | if rollback {
192 | co.Rollback()
193 | }
194 |
195 | co.Close()
196 | }
197 | }
198 |
199 | func (c *Conn) newEmptyResultset(stmt *sqlparser.Select) *Resultset {
200 | r := new(Resultset)
201 | r.Fields = make([]*Field, len(stmt.SelectExprs))
202 |
203 | for i, expr := range stmt.SelectExprs {
204 | r.Fields[i] = &Field{}
205 | switch e := expr.(type) {
206 | case *sqlparser.StarExpr:
207 | r.Fields[i].Name = []byte("*")
208 | case *sqlparser.NonStarExpr:
209 | if e.As != nil {
210 | r.Fields[i].Name = e.As
211 | r.Fields[i].OrgName = hack.Slice(nstring(e.Expr))
212 | } else {
213 | r.Fields[i].Name = hack.Slice(nstring(e.Expr))
214 | }
215 | default:
216 | r.Fields[i].Name = hack.Slice(nstring(e))
217 | }
218 | }
219 |
220 | r.Values = make([][]interface{}, 0)
221 | r.RowDatas = make([]RowData, 0)
222 |
223 | return r
224 | }
225 |
// makeBindVars maps positional statement arguments to the bind-variable
// names used for routing: the first argument becomes "v1", the second "v2",
// and so on. A nil or empty args yields an empty, non-nil map.
func makeBindVars(args []interface{}) map[string]interface{} {
	vars := make(map[string]interface{}, len(args))
	for idx := range args {
		key := fmt.Sprintf("v%d", idx+1)
		vars[key] = args[idx]
	}
	return vars
}
235 |
236 | func (c *Conn) handleSelect(stmt *sqlparser.Select, sql string, args []interface{}) error {
237 | bindVars := makeBindVars(args)
238 |
239 | conns, err := c.getShardConns(true,stmt, bindVars)
240 | if err != nil {
241 | return err
242 | } else if conns == nil {
243 | r := c.newEmptyResultset(stmt)
244 | return c.writeResultset(c.status, r)
245 | }
246 |
247 | var rs []*Result
248 |
249 | rs, err = c.executeInShard(conns, sql, args)
250 |
251 | c.closeShardConns(conns, false)
252 |
253 | if err == nil {
254 | err = c.mergeSelectResult(rs, stmt)
255 | }
256 |
257 | return err
258 | }
259 |
260 | func (c *Conn) beginShardConns(conns []*client.SqlConn) error {
261 | if c.isInTransaction() {
262 | return nil
263 | }
264 |
265 | for _, co := range conns {
266 | if err := co.Begin(); err != nil {
267 | return err
268 | }
269 | }
270 |
271 | return nil
272 | }
273 |
274 | func (c *Conn) commitShardConns(conns []*client.SqlConn) error {
275 | if c.isInTransaction() {
276 | return nil
277 | }
278 |
279 | for _, co := range conns {
280 | if err := co.Commit(); err != nil {
281 | return err
282 | }
283 | }
284 |
285 | return nil
286 | }
287 |
288 | func (c *Conn) handleExec(stmt sqlparser.Statement, sql string, args []interface{}) error {
289 | bindVars := makeBindVars(args)
290 |
291 | conns, err := c.getShardConns(false,stmt, bindVars)
292 | if err != nil {
293 | return err
294 | } else if conns == nil {
295 | return c.writeOK(nil)
296 | }
297 |
298 | var rs []*Result
299 |
300 | if len(conns) == 1 {
301 | rs, err = c.executeInShard(conns, sql, args)
302 | } else {
303 | //for multi nodes, 2PC simple, begin, exec, commit
304 | //if commit error, data maybe corrupt
305 | for {
306 | if err = c.beginShardConns(conns); err != nil {
307 | break
308 | }
309 |
310 | if rs, err = c.executeInShard(conns, sql, args); err != nil {
311 | break
312 | }
313 |
314 | err = c.commitShardConns(conns)
315 | break
316 | }
317 | }
318 |
319 | c.closeShardConns(conns, err != nil)
320 |
321 | if err == nil {
322 | err = c.mergeExecResult(rs)
323 | }
324 |
325 | return err
326 | }
327 |
328 | func (c *Conn) mergeExecResult(rs []*Result) error {
329 | r := new(Result)
330 |
331 | for _, v := range rs {
332 | r.Status |= v.Status
333 | r.AffectedRows += v.AffectedRows
334 | if r.InsertId == 0 {
335 | r.InsertId = v.InsertId
336 | } else if r.InsertId > v.InsertId {
337 | //last insert id is first gen id for multi row inserted
338 | //see http://dev.mysql.com/doc/refman/5.6/en/information-functions.html#function_last-insert-id
339 | r.InsertId = v.InsertId
340 | }
341 | }
342 |
343 | if r.InsertId > 0 {
344 | c.lastInsertId = int64(r.InsertId)
345 | }
346 |
347 | c.affectedRows = int64(r.AffectedRows)
348 |
349 | return c.writeOK(r)
350 | }
351 |
352 | func (c *Conn) mergeSelectResult(rs []*Result, stmt *sqlparser.Select) error {
353 | r := rs[0].Resultset
354 |
355 | status := c.status | rs[0].Status
356 |
357 | for i := 1; i < len(rs); i++ {
358 | status |= rs[i].Status
359 |
360 | //check fields equal
361 |
362 | for j := range rs[i].Values {
363 | r.Values = append(r.Values, rs[i].Values[j])
364 | r.RowDatas = append(r.RowDatas, rs[i].RowDatas[j])
365 | }
366 | }
367 |
368 | //to do order by, group by, limit offset
369 | c.sortSelectResult(r, stmt)
370 | //to do, add log here, sort may error because order by key not exist in resultset fields
371 |
372 | if err := c.limitSelectResult(r, stmt); err != nil {
373 | return err
374 | }
375 |
376 | return c.writeResultset(status, r)
377 | }
378 |
379 | func (c *Conn) sortSelectResult(r *Resultset, stmt *sqlparser.Select) error {
380 | if stmt.OrderBy == nil {
381 | return nil
382 | }
383 |
384 | sk := make([]SortKey, len(stmt.OrderBy))
385 |
386 | for i, o := range stmt.OrderBy {
387 | sk[i].Name = nstring(o.Expr)
388 | sk[i].Direction = o.Direction
389 | }
390 |
391 | return r.Sort(sk)
392 | }
393 |
394 | func (c *Conn) limitSelectResult(r *Resultset, stmt *sqlparser.Select) error {
395 | if stmt.Limit == nil {
396 | return nil
397 | }
398 |
399 | var offset, count int64
400 | var err error
401 | if stmt.Limit.Offset == nil {
402 | offset = 0
403 | } else {
404 | if o, ok := stmt.Limit.Offset.(sqlparser.NumVal); !ok {
405 | return fmt.Errorf("invalid select limit %s", nstring(stmt.Limit))
406 | } else {
407 | if offset, err = strconv.ParseInt(hack.String([]byte(o)), 10, 64); err != nil {
408 | return err
409 | }
410 | }
411 | }
412 |
413 | if o, ok := stmt.Limit.Rowcount.(sqlparser.NumVal); !ok {
414 | return fmt.Errorf("invalid limit %s", nstring(stmt.Limit))
415 | } else {
416 | if count, err = strconv.ParseInt(hack.String([]byte(o)), 10, 64); err != nil {
417 | return err
418 | } else if count < 0 {
419 | return fmt.Errorf("invalid limit %s", nstring(stmt.Limit))
420 | }
421 | }
422 |
423 | if offset+count > int64(len(r.Values)) {
424 | count = int64(len(r.Values)) - offset
425 | }
426 |
427 | r.Values = r.Values[offset : offset+count]
428 | r.RowDatas = r.RowDatas[offset : offset+count]
429 |
430 | return nil
431 | }
432 |
--------------------------------------------------------------------------------