├── .gitignore ├── .gitmodules ├── .travis.yml ├── Godeps ├── Godeps.json └── Readme ├── LICENSE ├── README.md ├── build.sh ├── cmd ├── dbatman │ ├── .gitignore │ └── main.go └── version │ └── version.go ├── config ├── config.go ├── config_test.go ├── proxy.yml └── test.yml ├── database ├── cluster │ └── driver.go ├── mysql │ ├── benchmark_test.go │ ├── buffer.go │ ├── charset.go │ ├── collations.go │ ├── connection.go │ ├── const.go │ ├── const_public.go │ ├── convert.go │ ├── driver.go │ ├── driver_test.go │ ├── dsn.go │ ├── dsn_test.go │ ├── errcode.go │ ├── errname.go │ ├── errors.go │ ├── errors_test.go │ ├── infile.go │ ├── interface.go │ ├── packets.go │ ├── pool.go │ ├── result.go │ ├── rows.go │ ├── server_conn.go │ ├── server_conn_test.go │ ├── state.go │ ├── statement.go │ ├── transaction.go │ ├── utils.go │ └── utils_test.go └── sql │ └── driver │ ├── driver.go │ ├── types.go │ └── types_test.go ├── docs ├── config_example.txt ├── internal │ ├── COM_QUERY.png │ ├── CS Protocol.xmind │ ├── Command.jpg │ ├── CommandLifeCycle.jpg │ ├── ConfigReloadFlow.png │ ├── ConfigStructure.png │ ├── ConnectionFlow.png │ ├── ConnectionLifecycle.jpg │ ├── Design.md │ ├── ListenRoutine.jpg │ ├── MySQL请求时序图.jpg │ ├── RWSplit.jpg │ ├── RW_Splite.xmind │ ├── design.mdj │ └── proxy_user_case.jpg ├── mysql-proxy │ ├── protocol.txt │ └── scripting.txt └── protocol.txt ├── hack ├── hack.go └── hack_test.go ├── parser ├── .gitignore ├── Makefile ├── ast.go ├── ast_alter.go ├── ast_compound.go ├── ast_create.go ├── ast_dal.go ├── ast_ddl.go ├── ast_dml.go ├── ast_drop.go ├── ast_expr.go ├── ast_prepare.go ├── ast_replication.go ├── ast_show.go ├── ast_table.go ├── ast_trans.go ├── ast_util.go ├── bin │ └── yacc ├── charset │ ├── charset.go │ ├── charset_test.go │ └── utf8_general_cli.go ├── debug.go ├── lex.go ├── lex_ident.go ├── lex_ident_test.go ├── lex_keywords.go ├── lex_keywords_test.go ├── lex_nchar.go ├── lex_number.go ├── lex_number_test.go ├── lex_test.go ├── 
lex_text.go ├── lex_text_test.go ├── lex_var_test.go ├── parser.go ├── parser_dal_test.go ├── parser_ddl_test.go ├── parser_dml_test.go ├── parser_test.go ├── parser_trans_test.go ├── parser_util_test.go ├── sql_yacc.go ├── sql_yacc.prf ├── sql_yacc.yy ├── state │ └── state.go └── test.sh ├── pool ├── .gitignore ├── slice.go ├── slice1.go ├── slice1_test.go ├── slice_test.go └── utils.go ├── proxy ├── auth.go ├── com_query.go ├── com_query_test.go ├── conn.go ├── conn_resultset.go ├── conn_select.go ├── conn_set.go ├── conn_show.go ├── conn_show_test.go ├── conn_stmt.go ├── conn_stmt_test.go ├── conn_test.go ├── conn_tx.go ├── conn_tx_test.go ├── debug.go ├── dispatch.go ├── dispatch_test.go ├── rows.go ├── server.go ├── server_test.go ├── session.go └── signal.go ├── run.sh ├── systest.sh ├── unitest.sh └── wercker.yml /.gitignore: -------------------------------------------------------------------------------- 1 | # Compiled Object files, Static and Dynamic libs (Shared Objects) 2 | *.o 3 | *.a 4 | *.so 5 | 6 | # Folders 7 | _obj 8 | _test 9 | 10 | # Architecture specific extensions/prefixes 11 | *.[568vq] 12 | [568vq].out 13 | 14 | *.cgo1.go 15 | *.cgo2.c 16 | _cgo_defun.c 17 | _cgo_gotypes.go 18 | _cgo_export.* 19 | 20 | _testmain.go 21 | 22 | *.exe 23 | *.test 24 | *.prof 25 | 26 | *.swp 27 | 28 | *.out 29 | 30 | # ignore IntelliJ IDEA config 31 | .idea 32 | 33 | *.log 34 | coverage.txt 35 | 36 | git.sh 37 | output 38 | GoDeps/_workspace/src 39 | 40 | mysql-workbench 41 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "mysql-server"] 2 | path = mysql-server 3 | url = https://github.com/mysql/mysql-server.git 4 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | dist: trusty 2 | sudo: 
required 3 | language: go 4 | go: 5 | - 1.5 6 | - 1.6 7 | 8 | addons: 9 | apt: 10 | packages: 11 | - mysql-server-5.5 12 | - mysql-client-core-5.5 13 | - mysql-client-5.5 14 | - mysql-testsuite 15 | 16 | before_install: 17 | - git submodule update --init --recursive 18 | - go get github.com/onsi/gomega 19 | - go get github.com/onsi/ginkgo 20 | - go get golang.org/x/tools/cmd/cover 21 | 22 | before_script: 23 | - mysql -u root -e 'create database gotest;' 24 | - mysql -u root -e 'create database dbatman_test;' 25 | 26 | script: 27 | - ./unitest.sh 28 | #- go build ./... 29 | #- ./cmd/dbatman/dbatman -config config/test.yml & 30 | #- mysql_client_test -uproxy_test_user -ptest -h127.0.0.1 -P4306 -Dclient_test_db 31 | 32 | after_success: 33 | - bash <(curl -s https://codecov.io/bash) 34 | -------------------------------------------------------------------------------- /Godeps/Godeps.json: -------------------------------------------------------------------------------- 1 | { 2 | "ImportPath": "github.com/bytedance/dbatman", 3 | "GoVersion": "go1.5", 4 | "Packages": [ 5 | "./..." 6 | ], 7 | "Deps": [ 8 | { 9 | "ImportPath": "github.com/go-sql-driver/mysql", 10 | "Comment": "v1.2-194-g7ebe0a5", 11 | "Rev": "7ebe0a500653eeb1859664bed5e48dec1e164e73" 12 | }, 13 | { 14 | "ImportPath": "github.com/juju/errors", 15 | "Rev": "b2c7a7da5b2995941048f60146e67702a292e468" 16 | }, 17 | { 18 | "ImportPath": "github.com/ngaut/log", 19 | "Rev": "37d3e0f43b4fe05429e1adb75e835bf31fc1bba6" 20 | }, 21 | { 22 | "ImportPath": "gopkg.in/yaml.v2", 23 | "Rev": "f7716cbe52baa25d2e9b0d0da546fcf909fc16b4" 24 | } 25 | ] 26 | } 27 | -------------------------------------------------------------------------------- /Godeps/Readme: -------------------------------------------------------------------------------- 1 | This directory tree is generated automatically by godep. 2 | 3 | Please do not edit. 4 | 5 | See https://github.com/tools/godep for more information. 
6 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | // Copyright 2016 ByteDance, Inc.
2 | //
3 | // Licensed under the Apache License, Version 2.0 (the "License");
4 | // you may not use this file except in compliance with the License.
5 | // You may obtain a copy of the License at
6 | //
7 | // http://www.apache.org/licenses/LICENSE-2.0
8 | //
9 | // Unless required by applicable law or agreed to in writing, software
10 | // distributed under the License is distributed on an "AS IS" BASIS,
11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | // See the License for the specific language governing permissions and
13 | // limitations under the License.
14 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | Dbatman - A MySQL Proxy
2 | ==============
3 | 
4 | [![Build Status](https://travis-ci.org/bytedance/dbatman.svg?branch=master)](https://travis-ci.org/bytedance/dbatman)
5 | [![codecov](https://codecov.io/gh/bytedance/dbatman/branch/master/graph/badge.svg)](https://codecov.io/gh/bytedance/dbatman)
6 | 
7 | 
8 | ## Work In Progress
9 | 
10 | This project is under development and not ready for production yet.
11 | 
--------------------------------------------------------------------------------
/build.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | 
3 | set -e
4 | 
5 | which godep || go get github.com/tools/godep
6 | 
7 | godep restore
8 | 
9 | cd cmd/dbatman && go build && cd -
10 | 
11 | mkdir -p output
12 | 
13 | cp cmd/dbatman/dbatman ./output
14 | cp config/proxy.yml ./output
15 | cp config/test.yml ./output
16 | 
--------------------------------------------------------------------------------
/cmd/dbatman/.gitignore:
--------------------------------------------------------------------------------
1 | main
2 | dbatman
3 | 
--------------------------------------------------------------------------------
/cmd/dbatman/main.go:
--------------------------------------------------------------------------------
1 | package main
2 | 
3 | import (
4 | 	"flag"
5 | 	"net/http"
6 | 	_ "net/http/pprof"
7 | 	"os"
8 | 	"os/exec"
9 | 	"os/signal"
10 | 	"path/filepath"
11 | 	"runtime"
12 | 	"runtime/debug"
13 | 	"strconv"
14 | 	"syscall"
15 | 
16 | 	"github.com/bytedance/dbatman/config"
17 | 	"github.com/bytedance/dbatman/database/cluster"
18 | 	"github.com/bytedance/dbatman/database/mysql"
19 | 	"github.com/bytedance/dbatman/proxy"
20 | 	"github.com/ngaut/log"
21 | )
22 | 
23 | var (
24 | 	configFile = flag.String("config", filepath.Join(getCurrentDir(), "proxy.yml"), "go mysql proxy config file")
25 | 	// NOTE(review): -loglevel is parsed but never applied in this file.
26 | 	logLevel = flag.Int("loglevel", 0, "0-debug| 1-notice|2-warn|3-fatal")
27 | 	logFile = flag.String("logfile", filepath.Join(getCurrentDir(), "proxy.log"), "go mysql proxy logfile")
28 | 	// gcLevel is applied via debug.SetGCPercent, but only when -gclevel is set explicitly.
29 | 	gcLevel = flag.String("gclevel", "500", "go gc level")
30 | )
31 | 
32 | // getCurrentDir returns the directory containing the running executable.
33 | // os.Args[0] is resolved through $PATH and made absolute; lookup errors are
34 | // ignored so that the flag defaults above can always be computed.
35 | func getCurrentDir() string {
36 | 	file, _ := exec.LookPath(os.Args[0])
37 | 	path, _ := filepath.Abs(file)
38 | 	// filepath.Dir is separator-aware and safe for multi-byte paths.
39 | 	return filepath.Dir(path)
40 | }
41 | 
42 | // applyGCLevel sets the GC target percentage from -gclevel. It only acts when
43 | // the flag was passed explicitly, so a default run keeps the GOGC value the
44 | // process started with.
45 | func applyGCLevel() {
46 | 	explicit := false
47 | 	flag.Visit(func(f *flag.Flag) {
48 | 		if f.Name == "gclevel" {
49 | 			explicit = true
50 | 		}
51 | 	})
52 | 	if !explicit {
53 | 		return
54 | 	}
55 | 	percent, err := strconv.Atoi(*gcLevel)
56 | 	if err != nil {
57 | 		log.Warnf("ignoring invalid -gclevel %q: %v", *gcLevel, err)
58 | 		return
59 | 	}
60 | 	debug.SetGCPercent(percent)
61 | }
62 | 
63 | // main loads the config, initializes the cluster, starts the proxy server and
64 | // handles exit (SIGQUIT/SIGHUP/SIGINT/SIGTERM) and restart (SIGUSR1) signals.
65 | func main() {
66 | 	runtime.GOMAXPROCS(runtime.NumCPU())
67 | 	runtime.SetBlockProfileRate(1)
68 | 
69 | 	// Parse flags before reading any of them: SetOutputByName used to run
70 | 	// before flag.Parse, so a -logfile argument was silently ignored.
71 | 	flag.Parse()
72 | 	log.SetOutputByName(*logFile)
73 | 	// Setting GOGC with os.Setenv after start-up has no effect on the running
74 | 	// process, so the GC percentage is applied through runtime/debug instead.
75 | 	applyGCLevel()
76 | 	println(*logFile)
77 | 	println(*configFile)
78 | 
79 | 	if len(*configFile) == 0 {
80 | 		log.Fatal("must use a config file")
81 | 		os.Exit(1)
82 | 	}
83 | 
84 | 	cfg, err := config.LoadConfig(*configFile)
85 | 	if err != nil {
86 | 		log.Fatal(err.Error())
87 | 		os.Exit(1)
88 | 	}
89 | 
90 | 	if err = cluster.Init(cfg); err != nil {
91 | 		log.Fatal(err.Error())
92 | 		os.Exit(1)
93 | 	}
94 | 
95 | 	mysql.SetLogger(log.Logger())
96 | 
97 | 	go func() {
98 | 		err := cluster.DisasterControl()
99 | 		if err != nil {
100 | 			log.Warn(err)
101 | 		}
102 | 	}()
103 | 	go func() {
104 | 		// log.info("start checking config file")
105 | 		cfg.CheckConfigUpdate(cluster.NotifyChan)
106 | 	}()
107 | 
108 | 	exitCh := make(chan os.Signal, 1)
109 | 	restartCh := make(chan os.Signal, 1)
110 | 	signal.Notify(restartCh, syscall.SIGUSR1)
111 | 	signal.Notify(exitCh, syscall.SIGQUIT,
112 | 		syscall.SIGHUP,
113 | 		syscall.SIGINT,
114 | 		syscall.SIGTERM)
115 | 
116 | 	var svr *proxy.Server
117 | 	svr, err = proxy.NewServer(cfg)
118 | 	if err != nil {
119 | 		log.Fatal(err.Error())
120 | 		os.Exit(1)
121 | 	}
122 | 
123 | 	// Port for go pprof debugging (handlers registered by net/http/pprof).
124 | 	go func() {
125 | 		if err := http.ListenAndServe(":11888", nil); err != nil {
126 | 			log.Warn(err)
127 | 		}
128 | 	}()
129 | 
130 | 	// signal.Notify disables the default action of these signals, so this
131 | 	// handler must keep looping after a restart; returning after the first
132 | 	// SIGUSR1 left SIGTERM/SIGINT with no effect for the rest of the run.
133 | 	go func() {
134 | 		for {
135 | 			select {
136 | 			case sig := <-exitCh:
137 | 				log.Infof("Got signal [%d] to exit.", sig)
138 | 				svr.Close()
139 | 				return
140 | 			case sig := <-restartCh:
141 | 				log.Infof("Got signal [%d] to Restart.", sig)
142 | 				svr.Restart()
143 | 			}
144 | 		}
145 | 	}()
146 | 
147 | 	svr.Serve()
148 | 	os.Exit(0)
149 | }
150 | 
--------------------------------------------------------------------------------
/cmd/version/version.go:
--------------------------------------------------------------------------------
1 | package version
2 | 
3 | var Version string = "5.6.24-72.2-log"
4 | 
-------------------------------------------------------------------------------- /config/config_test.go: -------------------------------------------------------------------------------- 1 | package config 2 | 3 | import ( 4 | "fmt" 5 | "reflect" 6 | "testing" 7 | ) 8 | 9 | func TestConfig(t *testing.T) { 10 | conf, err := LoadConfig("./proxy.yml") 11 | if err != nil { 12 | t.Fatal(err) 13 | } 14 | cfg := conf.GetConfig() 15 | 16 | globalConfig := GlobalConfig{ 17 | Port: 3306, 18 | ManagePort: 3307, 19 | MaxConnections: 10, 20 | LogLevel: 1, 21 | LogFilename: "./log/dbatman.log", 22 | LogMaxSize: 1024, 23 | ClientTimeout: 1800, 24 | ServerTimeout: 1800, 25 | WriteTimeInterval: 10, 26 | ConfAutoload: 1, 27 | AuthIPActive: false, 28 | ReqRate: 1000, 29 | ReqBurst: 2000, 30 | AuthIPs: []string{"10.4.64.1", "10.4.64.2"}, 31 | } 32 | 33 | masterNode := NodeConfig{ 34 | Host: "10.4.4.4", 35 | Port: 3307, 36 | Username: "pgc", 37 | Password: "pgc", 38 | DBName: "pgc", 39 | Charset: "utf8mb4", 40 | MaxConnections: 100, 41 | MaxConnectionPoolSize: 10, 42 | ConnectTimeout: 10, 43 | TimeReconnectInterval: 10, 44 | Weight: 1, 45 | } 46 | 47 | userNode := UserConfig{ 48 | Username: "proxy_pgc_user", 49 | Password: "pgc", 50 | MaxConnections: 1000, 51 | MinConnections: 100, 52 | DBName: "pgc", 53 | Charset: "utf8mb4", 54 | ClusterName: "pgc_cluster", 55 | AuthIPs: []string{"10.1.1.1", "10.1.1.2"}, 56 | BlackListIPs: []string{"10.1.1.3", "10.1.1.4"}, 57 | } 58 | 59 | if !reflect.DeepEqual(cfg.Global, &globalConfig) { 60 | fmt.Printf("%v\n", globalConfig) 61 | t.Fatal("global must equal") 62 | } 63 | 64 | master, _ := cfg.GetMasterNodefromClusterByName("pgc_cluster") 65 | if !reflect.DeepEqual(master, &masterNode) { 66 | fmt.Printf("%v\n", masterNode) 67 | t.Fatal("master must equal") 68 | } 69 | 70 | u, _ := cfg.GetUserByName("proxy_pgc_user") 71 | if !reflect.DeepEqual(u, &userNode) { 72 | fmt.Printf("%v\n", userNode) 73 | t.Fatal("user must equal") 74 | } 75 | } 76 | 
-------------------------------------------------------------------------------- /config/proxy.yml: -------------------------------------------------------------------------------- 1 | global: 2 | port: 3306 3 | manage_port: 3307 4 | max_connections: 10 5 | log_filename: ./log/dbatman.log 6 | log_level: 1 7 | log_maxsize: 1024 8 | log_query_min_time: 0 9 | client_timeout: 1800 10 | server_timeout: 1800 11 | write_time_interval: 10 12 | conf_autoload: 1 13 | authip_active: false 14 | auth_ips: 15 | - 10.4.64.1 16 | - 10.4.64.2 17 | 18 | clusters: 19 | pgc_cluster: 20 | master: 21 | host: 10.4.4.4 22 | port: 3307 23 | username: pgc 24 | password: pgc 25 | dbname: pgc 26 | charset: utf8mb4 27 | max_connections: 100 28 | max_connection_pool_size: 10 29 | connect_timeout: 10 30 | time_reconnect_interval: 10 31 | weight: 1 32 | slaves: 33 | - host: 10.4.4.2 34 | port: 3306 35 | username: pgc 36 | password: pgc 37 | dbname: pgc 38 | charset: utf8mb4 39 | max_connections: 100 40 | max_connection_pool_size: 10 41 | connect_timeout: 10 42 | time_reconnect_interval: 10 43 | weight: 1 44 | - host: 10.4.4.3 45 | port: 3306 46 | username: pgc 47 | password: pgc 48 | dbname: pgc 49 | charset: utf8mb4 50 | max_connections: 100 51 | max_connection_pool_size: 10 52 | connect_timeout: 10 53 | time_reconnect_interval: 10 54 | weight: 1 55 | 56 | users: 57 | proxy_pgc_user: 58 | username: proxy_pgc_user 59 | password: pgc 60 | max_connections: 1000 61 | min_connections: 100 62 | dbname: pgc 63 | charset: utf8mb4 64 | cluster_name: pgc_cluster 65 | auth_ips: 66 | - 10.1.1.1 67 | - 10.1.1.2 68 | black_list_ips: 69 | - 10.1.1.3 70 | - 10.1.1.4 71 | 72 | 73 | -------------------------------------------------------------------------------- /config/test.yml: -------------------------------------------------------------------------------- 1 | global: 2 | port: 4306 3 | manage_port: 4307 4 | max_connections: 10 5 | log_filename: ./log/dbatman.log 6 | log_level: 1 7 | log_maxsize: 1024 8 | 
log_query_min_time: 0
9 | client_timeout: 1800
10 | server_timeout: 1800
11 | write_time_interval: 10
12 | conf_autoload: 1
13 | authip_active: false
14 | auth_ips:
15 | - 10.4.64.1
16 | - 10.4.64.2
17 | clusters:
18 | test_cluster:
19 | master:
20 | host: 127.0.0.1
21 | port: 3306
22 | username: root
23 | password:
24 | dbname: client_test_db
25 | charset: utf8mb4
26 | weight: 1
27 | max_connections: 100
28 | max_connection_pool_size: 10
29 | connect_timeout: 10
30 | time_reconnect_interval: 10
31 | slaves:
32 | - host: 127.0.0.1
33 | port: 3306
34 | username: root
35 | password:
36 | dbname: client_test_db
37 | charset: utf8mb4
38 | weight: 1
39 | max_connections: 100
40 | max_connection_pool_size: 10
41 | connect_timeout: 10
42 | time_reconnect_interval: 10
43 | users:
44 | proxy_test_user:
45 | username: proxy_test_user
46 | password: test
47 | dbname: client_test_db
48 | charset: utf8mb4
49 | max_connections: 1000
50 | min_connections: 100
51 | cluster_name: test_cluster
52 | auth_ips:
53 | - 10.1.1.1
54 | - 10.1.1.2
55 | black_list_ips:
56 | - 10.1.1.3
57 | - 10.1.1.4
58 | 
--------------------------------------------------------------------------------
/database/mysql/benchmark_test.go:
--------------------------------------------------------------------------------
1 | // Go MySQL Driver - A MySQL-Driver for Go's database/sql package
2 | //
3 | // Copyright 2013 The Go-MySQL-Driver Authors. All rights reserved.
4 | //
5 | // This Source Code Form is subject to the terms of the Mozilla Public
6 | // License, v. 2.0. If a copy of the MPL was not distributed with this file,
7 | // You can obtain one at http://mozilla.org/MPL/2.0/.
8 | 
9 | package mysql
10 | 
11 | import (
12 | 	"bytes"
13 | 	//"github.com/bytedance/dbatman/database/sql"
14 | 	"github.com/bytedance/dbatman/database/sql/driver"
15 | 	"math"
16 | 	"strings"
17 | 	"sync"
18 | 	"sync/atomic"
19 | 	"testing"
20 | 	"time"
21 | )
22 | 
23 | type TB testing.B // TB adds helpers to testing.B that Fatal on a non-nil error.
24 | 
25 | func (tb *TB) check(err error) {
26 | 	if err != nil {
27 | 		tb.Fatal(err)
28 | 	}
29 | }
30 | 
31 | func (tb *TB) checkDB(db *DB, err error) *DB {
32 | 	tb.check(err)
33 | 	return db
34 | }
35 | 
36 | func (tb *TB) checkRows(rows Rows, err error) Rows {
37 | 	tb.check(err)
38 | 	return rows
39 | }
40 | 
41 | func (tb *TB) checkStmt(stmt *Stmt, err error) *Stmt {
42 | 	tb.check(err)
43 | 	return stmt
44 | }
45 | 
46 | func initDB(b *testing.B, queries ...string) *DB {
47 | 	tb := (*TB)(b)
48 | 	db := tb.checkDB(Open("mysql", dsn))
49 | 	for _, query := range queries {
50 | 		if _, err := db.Exec(query); err != nil {
51 | 			if w, ok := err.(MySQLWarnings); ok {
52 | 				b.Logf("warning on %q: %v", query, w)
53 | 			} else {
54 | 				b.Fatalf("error on %q: %v", query, err)
55 | 			}
56 | 		}
57 | 	}
58 | 	return db
59 | }
60 | 
61 | const concurrencyLevel = 10 // number of worker goroutines in the parallel benchmarks
62 | 
63 | func BenchmarkQuery(b *testing.B) {
64 | 	tb := (*TB)(b)
65 | 	b.StopTimer()
66 | 	b.ReportAllocs()
67 | 	db := initDB(b,
68 | 		"DROP TABLE IF EXISTS foo",
69 | 		"CREATE TABLE foo (id INT PRIMARY KEY, val CHAR(50))",
70 | 		`INSERT INTO foo VALUES (1, "one")`,
71 | 		`INSERT INTO foo VALUES (2, "two")`,
72 | 	)
73 | 	db.SetMaxIdleConns(concurrencyLevel)
74 | 	defer db.Close()
75 | 
76 | 	stmt := tb.checkStmt(db.Prepare("SELECT val FROM foo WHERE id=?"))
77 | 	defer stmt.Close()
78 | 
79 | 	remain := int64(b.N)
80 | 	var wg sync.WaitGroup
81 | 	wg.Add(concurrencyLevel)
82 | 	defer wg.Wait()
83 | 	b.StartTimer()
84 | 
85 | 	for i := 0; i < concurrencyLevel; i++ {
86 | 		go func() {
87 | 			// Workers must not call Fatal (FailNow is only valid on the benchmark
88 | 			// goroutine) and must always call Done, or the deferred wg.Wait hangs.
89 | 			defer wg.Done()
90 | 			for atomic.AddInt64(&remain, -1) >= 0 {
91 | 				var got string
92 | 				if err := stmt.QueryRow(1).Scan(&got); err != nil {
93 | 					b.Error(err)
94 | 					return
95 | 				}
96 | 				if got != "one" {
97 | 					b.Errorf("query = %q; want one", got)
98 | 					return
99 | 				}
100 | 			}
101 | 		}()
102 | 	}
103 | }
104 | 
105 | func BenchmarkExec(b *testing.B) {
106 | 	tb := (*TB)(b)
107 | 	b.StopTimer()
108 | 	b.ReportAllocs()
109 | 	db := tb.checkDB(Open("mysql", dsn))
110 | 	db.SetMaxIdleConns(concurrencyLevel)
111 | 	defer db.Close()
112 | 
113 | 	stmt := tb.checkStmt(db.Prepare("DO 1"))
114 | 	defer stmt.Close()
115 | 
116 | 	remain := int64(b.N)
117 | 	var wg sync.WaitGroup
118 | 	wg.Add(concurrencyLevel)
119 | 	defer wg.Wait()
120 | 	b.StartTimer()
121 | 
122 | 	for i := 0; i < concurrencyLevel; i++ {
123 | 		go func() {
124 | 			// Report failures with Error: Fatal/FailNow is only valid on the
125 | 			// benchmark goroutine, and a worker that skipped Done would leave
126 | 			// the deferred wg.Wait blocked forever.
127 | 			defer wg.Done()
128 | 			for atomic.AddInt64(&remain, -1) >= 0 {
129 | 				if _, err := stmt.Exec(); err != nil {
130 | 					b.Error(err)
131 | 					return
132 | 				}
133 | 			}
134 | 		}()
135 | 	}
136 | }
137 | 
138 | // data, but no db writes
139 | var roundtripSample []byte
140 | 
141 | func initRoundtripBenchmarks() ([]byte, int, int) {
142 | 	if roundtripSample == nil {
143 | 		roundtripSample = []byte(strings.Repeat("0123456789abcdef", 1024*1024))
144 | 	}
145 | 	return roundtripSample, 16, len(roundtripSample)
146 | }
147 | 
148 | func BenchmarkRoundtripTxt(b *testing.B) {
149 | 	b.StopTimer()
150 | 	sample, min, max := initRoundtripBenchmarks()
151 | 	sampleString := string(sample)
152 | 	b.ReportAllocs()
153 | 	tb := (*TB)(b)
154 | 	db := tb.checkDB(Open("mysql", dsn))
155 | 	defer db.Close()
156 | 	b.StartTimer()
157 | 	var result string
158 | 	for i := 0; i < b.N; i++ {
159 | 		length := min + i
160 | 		if length > max {
161 | 			length = max
162 | 		}
163 | 		test := sampleString[0:length]
164 | 		rows := tb.checkRows(db.Query(`SELECT "` + test + `"`))
165 | 		if !rows.Next() {
166 | 			rows.Close()
167 | 			b.Fatalf("crashed")
168 | 		}
169 | 		err := rows.Scan(&result)
170 | 		if err != nil {
171 | 			rows.Close()
172 | 			b.Fatalf("crashed")
173 | 		}
174 | 		if result != test {
175 | 			rows.Close()
176 | 			b.Errorf("mismatch")
177 | 		}
178 | 		rows.Close()
179 | 	}
180 | }
181 | 
182 | func BenchmarkRoundtripBin(b *testing.B) {
183 | 	b.StopTimer()
184 | 	sample, min, max := initRoundtripBenchmarks()
185 | 	b.ReportAllocs()
186 | 	tb := (*TB)(b)
187 | 	db := tb.checkDB(Open("mysql", dsn))
188 | 	defer db.Close()
189 | 	stmt := tb.checkStmt(db.Prepare("SELECT ?"))
190 | 	defer stmt.Close()
191 | 	b.StartTimer()
192 | 	var result RawBytes
193 | 	for i := 0; i < b.N; i++ {
194 | 		length := min + i
195 | 		if length > max {
196 | 			length = max
197 | 		}
198 | 		test := sample[0:length]
199 | 		rows := tb.checkRows(stmt.Query(test))
200 | 		if !rows.Next() {
201 | 			rows.Close()
202 | 			b.Fatalf("crashed")
203 | 		}
204 | 		err := rows.Scan(&result)
205 | 		if err != nil {
206 | 			rows.Close()
207 | 			b.Fatalf("crashed")
208 | 		}
209 | 		if !bytes.Equal(result, test) {
210 | 			rows.Close()
211 | 			b.Errorf("mismatch")
212 | 		}
213 | 		rows.Close()
214 | 	}
215 | }
216 | 
217 | func BenchmarkInterpolation(b *testing.B) {
218 | 	mc := &MySQLConn{
219 | 		cfg: &Config{
220 | 			InterpolateParams: true,
221 | 			Loc: time.UTC,
222 | 		},
223 | 		maxPacketAllowed: maxPacketSize,
224 | 		maxWriteSize: maxPacketSize - 1,
225 | 		buf: newBuffer(nil),
226 | 	}
227 | 
228 | 	args := []driver.Value{
229 | 		int64(42424242),
230 | 		float64(math.Pi),
231 | 		false,
232 | 		time.Unix(1423411542, 807015000),
233 | 		[]byte("bytes containing special chars ' \" \a \x00"),
234 | 		"string containing special chars ' \" \a \x00",
235 | 	}
236 | 	q := "SELECT ?, ?, ?, ?, ?, ?"
237 | 
238 | 	b.ReportAllocs()
239 | 	b.ResetTimer()
240 | 	for i := 0; i < b.N; i++ {
241 | 		_, err := mc.interpolateParams(q, args)
242 | 		if err != nil {
243 | 			b.Fatal(err)
244 | 		}
245 | 	}
246 | }
247 | 
--------------------------------------------------------------------------------
/database/mysql/buffer.go:
--------------------------------------------------------------------------------
1 | // Go MySQL Driver - A MySQL-Driver for Go's database/sql package
2 | //
3 | // Copyright 2013 The Go-MySQL-Driver Authors. All rights reserved.
4 | // 5 | // This Source Code Form is subject to the terms of the Mozilla Public 6 | // License, v. 2.0. If a copy of the MPL was not distributed with this file, 7 | // You can obtain one at http://mozilla.org/MPL/2.0/. 8 | 9 | package mysql 10 | 11 | import ( 12 | "io" 13 | "net" 14 | "time" 15 | 16 | "github.com/ngaut/log" 17 | ) 18 | 19 | const defaultBufSize = 4096 * 3 20 | const checkBrokenReadTimeoutStr = "100us" 21 | 22 | // A buffer which is used for both reading and writing. 23 | // This is possible since communication on each connection is synchronous. 24 | // In other words, we can't write and read simultaneously on the same connection. 25 | // The buffer is similar to bufio.Reader / Writer but zero-copy-ish 26 | // Also highly optimized for this particular use case. 27 | type buffer struct { 28 | buf []byte 29 | nc net.Conn 30 | idx int 31 | length int 32 | timeout time.Duration 33 | } 34 | 35 | func newBuffer(nc net.Conn) buffer { 36 | var b [defaultBufSize]byte 37 | return buffer{ 38 | buf: b[:], 39 | nc: nc, 40 | } 41 | } 42 | 43 | // read io.EOF or other exception when connection closed 44 | func (b *buffer) isBroken() bool { 45 | if b.nc == nil { 46 | log.Warn("the conn become's nile return") 47 | return true 48 | } 49 | timeout, _ := time.ParseDuration(checkBrokenReadTimeoutStr) 50 | if err := b.nc.SetReadDeadline(time.Now().Add(timeout)); err != nil { 51 | return true 52 | } 53 | buf := make([]byte, 1) 54 | _, err := b.nc.Read(buf) 55 | 56 | //restore read dead line 57 | var zeroTime time.Time 58 | b.nc.SetReadDeadline(zeroTime) 59 | 60 | //only timeout represents the connection is alive 61 | if oe, ok := err.(*net.OpError); ok { 62 | return !oe.Timeout() 63 | } 64 | return true 65 | } 66 | 67 | // fill reads into the buffer until at least _need_ bytes are in it 68 | func (b *buffer) fill(need int) error { 69 | n := b.length 70 | 71 | // move existing data to the beginning 72 | if n > 0 && b.idx > 0 { 73 | copy(b.buf[0:n], b.buf[b.idx:]) 74 | } 
75 | 76 | // grow buffer if necessary 77 | // TODO: let the buffer shrink again at some point 78 | // Maybe keep the org buf slice and swap back? 79 | if need > len(b.buf) { 80 | // Round up to the next multiple of the default size 81 | newBuf := make([]byte, ((need/defaultBufSize)+1)*defaultBufSize) 82 | copy(newBuf, b.buf) 83 | b.buf = newBuf 84 | } 85 | 86 | b.idx = 0 87 | 88 | for { 89 | if b.timeout > 0 { 90 | if err := b.nc.SetReadDeadline(time.Now().Add(b.timeout)); err != nil { 91 | return err 92 | } 93 | } 94 | 95 | nn, err := b.nc.Read(b.buf[n:]) 96 | n += nn 97 | // log.Warnf("current read num: %d ,need,", n, need) 98 | 99 | switch err { 100 | case nil: 101 | if n < need { 102 | continue 103 | } 104 | b.length = n 105 | return nil 106 | 107 | case io.EOF: 108 | if n >= need { 109 | b.length = n 110 | return nil 111 | } 112 | 113 | log.Debugf("need: %d, readed: %d", need, n) 114 | return io.ErrUnexpectedEOF 115 | 116 | default: 117 | return err 118 | } 119 | } 120 | } 121 | 122 | // returns next N bytes from buffer. 123 | // The returned slice is only guaranteed to be valid until the next read 124 | func (b *buffer) readNext(need int) ([]byte, error) { 125 | if b.length < need { 126 | // refill 127 | if err := b.fill(need); err != nil { 128 | return nil, err 129 | } 130 | } 131 | 132 | offset := b.idx 133 | b.idx += need 134 | b.length -= need 135 | return b.buf[offset:b.idx], nil 136 | } 137 | 138 | // returns a buffer with the requested size. 139 | // If possible, a slice from the existing buffer is returned. 140 | // Otherwise a bigger buffer is made. 141 | // Only one buffer (total) can be used at a time. 
142 | func (b *buffer) takeBuffer(length int) []byte { 143 | if b.length > 0 { 144 | return nil 145 | } 146 | 147 | // test (cheap) general case first 148 | if length <= defaultBufSize || length <= cap(b.buf) { 149 | return b.buf[:length] 150 | } 151 | 152 | if length < maxPacketSize { 153 | b.buf = make([]byte, length) 154 | return b.buf 155 | } 156 | return make([]byte, length) 157 | } 158 | 159 | // shortcut which can be used if the requested buffer is guaranteed to be 160 | // smaller than defaultBufSize 161 | // Only one buffer (total) can be used at a time. 162 | func (b *buffer) takeSmallBuffer(length int) []byte { 163 | if b.length == 0 { 164 | return b.buf[:length] 165 | } 166 | return nil 167 | } 168 | 169 | // takeCompleteBuffer returns the complete existing buffer. 170 | // This can be used if the necessary buffer size is unknown. 171 | // Only one buffer (total) can be used at a time. 172 | func (b *buffer) takeCompleteBuffer() []byte { 173 | if b.length == 0 { 174 | return b.buf 175 | } 176 | return nil 177 | } 178 | -------------------------------------------------------------------------------- /database/mysql/const.go: -------------------------------------------------------------------------------- 1 | // Go MySQL Driver - A MySQL-Driver for Go's database/sql package 2 | // 3 | // Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved. 4 | // 5 | // This Source Code Form is subject to the terms of the Mozilla Public 6 | // License, v. 2.0. If a copy of the MPL was not distributed with this file, 7 | // You can obtain one at http://mozilla.org/MPL/2.0/. 
8 | 
9 | package mysql
10 | 
11 | const (
12 | 	minProtocolVersion byte = 10
13 | 	maxPacketSize = 1<<24 - 1 // 16777215 bytes (16 MiB - 1)
14 | 	timeFormat = "2006-01-02 15:04:05.999999" // Go time layout; up to 6 fractional-second digits
15 | 	PacketHeaderLen int = 4
16 | )
17 | 
18 | // MySQL constants documentation:
19 | // http://dev.mysql.com/doc/internals/en/client-server-protocol.html
20 | 
21 | const (
22 | 	iOK byte = 0x00
23 | 	iLocalInFile byte = 0xfb
24 | 	iEOF byte = 0xfe
25 | 	iERR byte = 0xff
26 | )
27 | 
28 | // https://dev.mysql.com/doc/internals/en/capability-flags.html#packet-Protocol::CapabilityFlags
29 | type clientFlag uint32
30 | 
31 | const ( // clientFlag bits, one per constant, starting at 1<<0; order is significant
32 | 	clientLongPassword clientFlag = 1 << iota
33 | 	clientFoundRows
34 | 	clientLongFlag
35 | 	clientConnectWithDB
36 | 	clientNoSchema
37 | 	clientCompress
38 | 	clientODBC
39 | 	clientLocalFiles
40 | 	clientIgnoreSpace
41 | 	clientProtocol41
42 | 	clientInteractive
43 | 	clientSSL
44 | 	clientIgnoreSIGPIPE
45 | 	clientTransactions
46 | 	clientReserved
47 | 	clientSecureConn
48 | 	clientMultiStatements
49 | 	clientMultiResults
50 | 	clientPSMultiResults
51 | 	clientPluginAuth
52 | 	clientConnectAttrs
53 | 	clientPluginAuthLenEncClientData
54 | 	clientCanHandleExpiredPasswords
55 | 	clientSessionTrack
56 | 	clientDeprecateEOF
57 | )
58 | 
59 | const ( // command bytes; iota + 1 makes comQuit 0x01 and each following name one more
60 | 	comQuit byte = iota + 1
61 | 	comInitDB
62 | 	comQuery
63 | 	comFieldList
64 | 	comCreateDB
65 | 	comDropDB
66 | 	comRefresh
67 | 	comShutdown
68 | 	comStatistics
69 | 	comProcessInfo
70 | 	comConnect
71 | 	comProcessKill
72 | 	comDebug
73 | 	comPing
74 | 	comTime
75 | 	comDelayedInsert
76 | 	comChangeUser
77 | 	comBinlogDump
78 | 	comTableDump
79 | 	comConnectOut
80 | 	comRegisterSlave
81 | 	comStmtPrepare
82 | 	comStmtExecute
83 | 	comStmtSendLongData
84 | 	comStmtClose
85 | 	comStmtReset
86 | 	comSetOption
87 | 	comStmtFetch
88 | )
89 | 
90 | // https://dev.mysql.com/doc/internals/en/com-query-response.html#packet-Protocol::ColumnType
91 | const ( // column types 0x00 (fieldTypeDecimal) through 0x10 (fieldTypeBit)
92 | 	fieldTypeDecimal byte = iota
93 | 	fieldTypeTiny
94 | 	fieldTypeShort
95 | 	fieldTypeLong
96 | 	fieldTypeFloat
97 | 	fieldTypeDouble
98 | 	fieldTypeNULL
99 | 	fieldTypeTimestamp
100 | 	fieldTypeLongLong
101 | 	fieldTypeInt24
102 | 	fieldTypeDate
103 | 	fieldTypeTime
104 | 	fieldTypeDateTime
105 | 	fieldTypeYear
106 | 	fieldTypeNewDate
107 | 	fieldTypeVarChar
108 | 	fieldTypeBit
109 | )
110 | const ( // column types 0xf5 (fieldTypeJSON) through 0xff (fieldTypeGeometry)
111 | 	fieldTypeJSON byte = iota + 0xf5
112 | 	fieldTypeNewDecimal
113 | 	fieldTypeEnum
114 | 	fieldTypeSet
115 | 	fieldTypeTinyBLOB
116 | 	fieldTypeMediumBLOB
117 | 	fieldTypeLongBLOB
118 | 	fieldTypeBLOB
119 | 	fieldTypeVarString
120 | 	fieldTypeString
121 | 	fieldTypeGeometry
122 | )
123 | 
124 | type fieldFlag uint16 // flag bits below: flagNotNULL = 1<<0 through flagUnknown4 = 1<<15
125 | 
126 | const (
127 | 	flagNotNULL fieldFlag = 1 << iota
128 | 	flagPriKey
129 | 	flagUniqueKey
130 | 	flagMultipleKey
131 | 	flagBLOB
132 | 	flagUnsigned
133 | 	flagZeroFill
134 | 	flagBinary
135 | 	flagEnum
136 | 	flagAutoIncrement
137 | 	flagTimestamp
138 | 	flagSet
139 | 	flagUnknown1
140 | 	flagUnknown2
141 | 	flagUnknown3
142 | 	flagUnknown4
143 | )
144 | 
145 | // http://dev.mysql.com/doc/internals/en/status-flags.html
146 | type statusFlag uint16
147 | 
148 | const ( // statusFlag bits, one per constant, starting at 1<<0; order is significant
149 | 	statusInTrans statusFlag = 1 << iota
150 | 	statusInAutocommit
151 | 	statusReserved // Not in documentation
152 | 	statusMoreResultsExists
153 | 	statusNoGoodIndexUsed
154 | 	statusNoIndexUsed
155 | 	statusCursorExists
156 | 	statusLastRowSent
157 | 	statusDbDropped
158 | 	statusNoBackslashEscapes
159 | 	statusMetadataChanged
160 | 	statusQueryWasSlow
161 | 	statusPsOutParams
162 | 	statusInTransReadonly
163 | 	statusSessionStateChanged
164 | )
165 | 
--------------------------------------------------------------------------------
/database/mysql/driver.go:
--------------------------------------------------------------------------------
1 | // Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved.
2 | //
3 | // This Source Code Form is subject to the terms of the Mozilla Public
4 | // License, v. 2.0. If a copy of the MPL was not distributed with this file,
5 | // You can obtain one at http://mozilla.org/MPL/2.0/.
6 | 7 | // Package mysql provides a MySQL driver for Go's database/sql package 8 | // 9 | // The driver should be used via the database/sql package: 10 | // 11 | // import "database/sql" 12 | // import _ "github.com/go-sql-driver/mysql" 13 | // 14 | // db, err := sql.Open("mysql", "user:password@/dbname") 15 | // 16 | // See https://github.com/go-sql-driver/mysql#usage for details 17 | package mysql 18 | 19 | import ( 20 | "net" 21 | 22 | "github.com/bytedance/dbatman/database/sql/driver" 23 | ) 24 | 25 | // MySQLDriver is exported to make the driver directly accessible. 26 | // In general the driver is used via the database/sql package. 27 | type MySQLDriver struct{} 28 | 29 | // DialFunc is a function which can be used to establish the network connection. 30 | // Custom dial functions must be registered with RegisterDial 31 | type DialFunc func(addr string) (net.Conn, error) 32 | 33 | var dials map[string]DialFunc 34 | 35 | // RegisterDial registers a custom dial function. It can then be used by the 36 | // network address mynet(addr), where mynet is the registered new network. 37 | // addr is passed as a parameter to the dial function. 38 | func RegisterDial(net string, dial DialFunc) { 39 | if dials == nil { 40 | dials = make(map[string]DialFunc) 41 | } 42 | dials[net] = dial 43 | } 44 | 45 | // Open new Connection. 
46 | // See https://github.com/go-sql-driver/mysql#dsn-data-source-name for how 47 | // the DSN string is formated 48 | func (d MySQLDriver) Open(dsn string) (driver.Conn, error) { 49 | var err error 50 | 51 | // New MySQLConn 52 | mc := &MySQLConn{ 53 | maxPacketAllowed: maxPacketSize, 54 | maxWriteSize: maxPacketSize - 1, 55 | } 56 | mc.cfg, err = ParseDSN(dsn) 57 | if err != nil { 58 | return nil, err 59 | } 60 | mc.parseTime = mc.cfg.ParseTime 61 | mc.strict = mc.cfg.Strict 62 | 63 | // Connect to Server 64 | if dial, ok := dials[mc.cfg.Net]; ok { 65 | mc.netConn, err = dial(mc.cfg.Addr) 66 | } else { 67 | nd := net.Dialer{Timeout: mc.cfg.Timeout} 68 | mc.netConn, err = nd.Dial(mc.cfg.Net, mc.cfg.Addr) 69 | } 70 | if err != nil { 71 | return nil, err 72 | } 73 | 74 | // Enable TCP Keepalives on TCP connections 75 | if tc, ok := mc.netConn.(*net.TCPConn); ok { 76 | if err := tc.SetKeepAlive(true); err != nil { 77 | // Don't send COM_QUIT before handshake. 78 | mc.netConn.Close() 79 | mc.netConn = nil 80 | return nil, err 81 | } 82 | } 83 | //disable the no delay opt 84 | if tc, ok := mc.netConn.(*net.TCPConn); ok { 85 | if err := tc.SetNoDelay(false); err != nil { 86 | // Don't send COM_QUIT before handshake. 
87 | mc.netConn.Close() 88 | mc.netConn = nil 89 | return nil, err 90 | } 91 | } 92 | 93 | mc.buf = newBuffer(mc.netConn) 94 | 95 | // Set I/O timeouts 96 | mc.buf.timeout = mc.cfg.ReadTimeout 97 | mc.writeTimeout = mc.cfg.WriteTimeout 98 | 99 | // Reading Handshake Initialization Packet 100 | cipher, threadId, err := mc.readInitPacket() 101 | if err != nil { 102 | mc.cleanup() 103 | return nil, err 104 | } 105 | mc.threadId = threadId 106 | 107 | // Send Client Authentication Packet 108 | if err = mc.writeAuthPacket(cipher); err != nil { 109 | mc.cleanup() 110 | return nil, err 111 | } 112 | 113 | // Handle response to auth packet, switch methods if possible 114 | if err = handleAuthResult(mc, cipher); err != nil { 115 | // Authentication failed and MySQL has already closed the connection 116 | // (https://dev.mysql.com/doc/internals/en/authentication-fails.html). 117 | // Do not send COM_QUIT, just cleanup and return the error. 118 | mc.cleanup() 119 | return nil, err 120 | } 121 | 122 | // Get max allowed packet size 123 | maxap, err := mc.getSystemVar("max_allowed_packet") 124 | if err != nil { 125 | mc.Close() 126 | return nil, err 127 | } 128 | mc.maxPacketAllowed = stringToInt(maxap) - 1 129 | if mc.maxPacketAllowed < maxPacketSize { 130 | mc.maxWriteSize = mc.maxPacketAllowed 131 | } 132 | 133 | // Handle DSN Params 134 | err = mc.handleParams() 135 | if err != nil { 136 | mc.Close() 137 | return nil, err 138 | } 139 | 140 | return mc, nil 141 | } 142 | 143 | func handleAuthResult(mc *MySQLConn, cipher []byte) error { 144 | // Read Result Packet 145 | err := mc.readResultOK() 146 | if err == nil { 147 | return nil // auth successful 148 | } 149 | 150 | if mc.cfg == nil { 151 | return err // auth failed and retry not possible 152 | } 153 | 154 | // Retry auth if configured to do so. 155 | if mc.cfg.AllowOldPasswords && err == ErrOldPassword { 156 | // Retry with old authentication method. 
Note: there are edge cases 157 | // where this should work but doesn't; this is currently "wontfix": 158 | // https://github.com/go-sql-driver/mysql/issues/184 159 | if err = mc.writeOldAuthPacket(cipher); err != nil { 160 | return err 161 | } 162 | err = mc.readResultOK() 163 | } else if mc.cfg.AllowCleartextPasswords && err == ErrCleartextPassword { 164 | // Retry with clear text password for 165 | // http://dev.mysql.com/doc/refman/5.7/en/cleartext-authentication-plugin.html 166 | // http://dev.mysql.com/doc/refman/5.7/en/pam-authentication-plugin.html 167 | if err = mc.writeClearAuthPacket(); err != nil { 168 | return err 169 | } 170 | err = mc.readResultOK() 171 | } 172 | return err 173 | } 174 | 175 | func init() { 176 | Register("dbatman", &MySQLDriver{}) 177 | } 178 | -------------------------------------------------------------------------------- /database/mysql/errors.go: -------------------------------------------------------------------------------- 1 | // Go MySQL Driver - A MySQL-Driver for Go's database/sql package 2 | // 3 | // Copyright 2013 The Go-MySQL-Driver Authors. All rights reserved. 4 | // 5 | // This Source Code Form is subject to the terms of the Mozilla Public 6 | // License, v. 2.0. If a copy of the MPL was not distributed with this file, 7 | // You can obtain one at http://mozilla.org/MPL/2.0/. 8 | 9 | package mysql 10 | 11 | import ( 12 | "errors" 13 | "fmt" 14 | "github.com/bytedance/dbatman/database/sql/driver" 15 | "io" 16 | "log" 17 | "os" 18 | ) 19 | 20 | // Various errors the driver might return. Can change between driver versions. 21 | var ( 22 | ErrInvalidConn = errors.New("invalid connection") 23 | ErrMalformPkt = errors.New("malformed packet") 24 | ErrNoTLS = errors.New("TLS requested but server does not support TLS") 25 | ErrOldPassword = errors.New("this user requires old password authentication. If you still want to use it, please add 'allowOldPasswords=1' to your DSN. 
See also https://github.com/go-sql-driver/mysql/wiki/old_passwords") 26 | ErrCleartextPassword = errors.New("this user requires clear text authentication. If you still want to use it, please add 'allowCleartextPasswords=1' to your DSN") 27 | ErrUnknownPlugin = errors.New("this authentication plugin is not supported") 28 | ErrOldProtocol = errors.New("MySQL server does not support required protocol 41+") 29 | ErrPktSync = errors.New("commands out of sync. You can't run this command now") 30 | ErrPktSyncMul = errors.New("commands out of sync. Did you run multiple statements at once?") 31 | ErrPktTooLarge = errors.New("packet for query is too large. Try adjusting the 'max_allowed_packet' variable on the server") 32 | ErrBusyBuffer = errors.New("busy buffer") 33 | ) 34 | 35 | var errLog = Logger(log.New(os.Stderr, "[mysql] ", log.Ldate|log.Ltime|log.Lshortfile)) 36 | 37 | // Logger is used to log critical error messages. 38 | type Logger interface { 39 | Print(v ...interface{}) 40 | } 41 | 42 | // SetLogger is used to set the logger for critical errors. 43 | // The initial logger is os.Stderr. 44 | func SetLogger(logger Logger) error { 45 | if logger == nil { 46 | return errors.New("logger is nil") 47 | } 48 | errLog = logger 49 | return nil 50 | } 51 | 52 | // MySQLError is an error type which represents a single MySQL error 53 | type MySQLError struct { 54 | Number uint16 55 | Message string 56 | State string 57 | } 58 | 59 | func (me *MySQLError) Error() string { 60 | return fmt.Sprintf("Error %d (%s): %s", me.Number, me.State, me.Message) 61 | } 62 | 63 | //default mysql error, must adapt errname message format 64 | func NewDefaultError(number uint16, args ...interface{}) *MySQLError { 65 | e := new(MySQLError) 66 | e.Number = number 67 | 68 | if s, ok := MySQLState[number]; ok { 69 | e.State = s 70 | } else { 71 | e.State = DEFAULT_MYSQL_STATE 72 | } 73 | 74 | if format, ok := MySQLErrName[number]; ok { 75 | e.Message = fmt.Sprintf(format, args...) 
76 | } else { 77 | e.Message = fmt.Sprint(args...) 78 | } 79 | 80 | return e 81 | } 82 | 83 | // MySQLWarnings is an error type which represents a group of one or more MySQL 84 | // warnings 85 | type MySQLWarnings []*MySQLWarning 86 | 87 | func (mws MySQLWarnings) Error() string { 88 | var msg string 89 | for i, warning := range mws { 90 | if i > 0 { 91 | msg += "\r\n" 92 | } 93 | msg += fmt.Sprintf( 94 | "%s %s: %s", 95 | warning.Level, 96 | warning.Code, 97 | warning.Message, 98 | ) 99 | } 100 | return msg 101 | } 102 | 103 | func (mws MySQLWarnings) Errors() []error { 104 | errs := make([]error, 0, len(mws)) 105 | 106 | for _, warning := range mws { 107 | errs = append(errs, warning) 108 | } 109 | 110 | return errs 111 | } 112 | 113 | // MySQLWarning is an error type which represents a single MySQL warning. 114 | // Warnings are returned in groups only. See MySQLWarnings 115 | type MySQLWarning struct { 116 | Level string 117 | Code string 118 | Message string 119 | } 120 | 121 | func (warning *MySQLWarning) Error() string { 122 | return fmt.Sprintf( 123 | "%s %s: %s", 124 | warning.Level, 125 | warning.Code, 126 | warning.Message, 127 | ) 128 | } 129 | 130 | func (mc *MySQLConn) getWarnings() (err error) { 131 | rows, err := mc.Query("SHOW WARNINGS", nil) 132 | if err != nil { 133 | return 134 | } 135 | 136 | var warnings = MySQLWarnings{} 137 | var values = make([]driver.Value, 3) 138 | 139 | for { 140 | err = rows.Next(values) 141 | switch err { 142 | case nil: 143 | warning := &MySQLWarning{} 144 | 145 | if raw, ok := values[0].([]byte); ok { 146 | warning.Level = string(raw) 147 | } else { 148 | warning.Level = fmt.Sprintf("%s", values[0]) 149 | } 150 | if raw, ok := values[1].([]byte); ok { 151 | warning.Code = string(raw) 152 | } else { 153 | warning.Code = fmt.Sprintf("%s", values[1]) 154 | } 155 | if raw, ok := values[2].([]byte); ok { 156 | warning.Message = string(raw) 157 | } else { 158 | warning.Message = fmt.Sprintf("%s", values[0]) 159 | } 160 | 
161 | warnings = append(warnings, warning) 162 | 163 | case io.EOF: 164 | return warnings 165 | 166 | default: 167 | rows.Close() 168 | return 169 | } 170 | } 171 | } 172 | -------------------------------------------------------------------------------- /database/mysql/errors_test.go: -------------------------------------------------------------------------------- 1 | // Go MySQL Driver - A MySQL-Driver for Go's database/sql package 2 | // 3 | // Copyright 2013 The Go-MySQL-Driver Authors. All rights reserved. 4 | // 5 | // This Source Code Form is subject to the terms of the Mozilla Public 6 | // License, v. 2.0. If a copy of the MPL was not distributed with this file, 7 | // You can obtain one at http://mozilla.org/MPL/2.0/. 8 | 9 | package mysql 10 | 11 | import ( 12 | "bytes" 13 | "log" 14 | "testing" 15 | ) 16 | 17 | func TestErrorsSetLogger(t *testing.T) { 18 | previous := errLog 19 | defer func() { 20 | errLog = previous 21 | }() 22 | 23 | // set up logger 24 | const expected = "prefix: test\n" 25 | buffer := bytes.NewBuffer(make([]byte, 0, 64)) 26 | logger := log.New(buffer, "prefix: ", 0) 27 | 28 | // print 29 | SetLogger(logger) 30 | errLog.Print("test") 31 | 32 | // check result 33 | if actual := buffer.String(); actual != expected { 34 | t.Errorf("expected %q, got %q", expected, actual) 35 | } 36 | } 37 | 38 | func TestErrorsStrictIgnoreNotes(t *testing.T) { 39 | runTests(t, dsn+"&sql_notes=false", func(dbt *DBTest) { 40 | dbt.mustExec("DROP TABLE IF EXISTS does_not_exist") 41 | }) 42 | } 43 | -------------------------------------------------------------------------------- /database/mysql/infile.go: -------------------------------------------------------------------------------- 1 | // Go MySQL Driver - A MySQL-Driver for Go's database/sql package 2 | // 3 | // Copyright 2013 The Go-MySQL-Driver Authors. All rights reserved. 4 | // 5 | // This Source Code Form is subject to the terms of the Mozilla Public 6 | // License, v. 2.0. 
If a copy of the MPL was not distributed with this file, 7 | // You can obtain one at http://mozilla.org/MPL/2.0/. 8 | 9 | package mysql 10 | 11 | import ( 12 | "fmt" 13 | "io" 14 | "os" 15 | "strings" 16 | "sync" 17 | ) 18 | 19 | var ( 20 | fileRegister map[string]bool 21 | fileRegisterLock sync.RWMutex 22 | readerRegister map[string]func() io.Reader 23 | readerRegisterLock sync.RWMutex 24 | ) 25 | 26 | // RegisterLocalFile adds the given file to the file whitelist, 27 | // so that it can be used by "LOAD DATA LOCAL INFILE ". 28 | // Alternatively you can allow the use of all local files with 29 | // the DSN parameter 'allowAllFiles=true' 30 | // 31 | // filePath := "/home/gopher/data.csv" 32 | // mysql.RegisterLocalFile(filePath) 33 | // err := db.Exec("LOAD DATA LOCAL INFILE '" + filePath + "' INTO TABLE foo") 34 | // if err != nil { 35 | // ... 36 | // 37 | func RegisterLocalFile(filePath string) { 38 | fileRegisterLock.Lock() 39 | // lazy map init 40 | if fileRegister == nil { 41 | fileRegister = make(map[string]bool) 42 | } 43 | 44 | fileRegister[strings.Trim(filePath, `"`)] = true 45 | fileRegisterLock.Unlock() 46 | } 47 | 48 | // DeregisterLocalFile removes the given filepath from the whitelist. 49 | func DeregisterLocalFile(filePath string) { 50 | fileRegisterLock.Lock() 51 | delete(fileRegister, strings.Trim(filePath, `"`)) 52 | fileRegisterLock.Unlock() 53 | } 54 | 55 | // RegisterReaderHandler registers a handler function which is used 56 | // to receive a io.Reader. 57 | // The Reader can be used by "LOAD DATA LOCAL INFILE Reader::". 58 | // If the handler returns a io.ReadCloser Close() is called when the 59 | // request is finished. 60 | // 61 | // mysql.RegisterReaderHandler("data", func() io.Reader { 62 | // var csvReader io.Reader // Some Reader that returns CSV data 63 | // ... // Open Reader here 64 | // return csvReader 65 | // }) 66 | // err := db.Exec("LOAD DATA LOCAL INFILE 'Reader::data' INTO TABLE foo") 67 | // if err != nil { 68 | // ... 
69 | // 70 | func RegisterReaderHandler(name string, handler func() io.Reader) { 71 | readerRegisterLock.Lock() 72 | // lazy map init 73 | if readerRegister == nil { 74 | readerRegister = make(map[string]func() io.Reader) 75 | } 76 | 77 | readerRegister[name] = handler 78 | readerRegisterLock.Unlock() 79 | } 80 | 81 | // DeregisterReaderHandler removes the ReaderHandler function with 82 | // the given name from the registry. 83 | func DeregisterReaderHandler(name string) { 84 | readerRegisterLock.Lock() 85 | delete(readerRegister, name) 86 | readerRegisterLock.Unlock() 87 | } 88 | 89 | func deferredClose(err *error, closer io.Closer) { 90 | closeErr := closer.Close() 91 | if *err == nil { 92 | *err = closeErr 93 | } 94 | } 95 | 96 | func (mc *MySQLConn) handleInFileRequest(name string) (err error) { 97 | var rdr io.Reader 98 | var data []byte 99 | packetSize := 16 * 1024 // 16KB is small enough for disk readahead and large enough for TCP 100 | if mc.maxWriteSize < packetSize { 101 | packetSize = mc.maxWriteSize 102 | } 103 | 104 | if idx := strings.Index(name, "Reader::"); idx == 0 || (idx > 0 && name[idx-1] == '/') { // io.Reader 105 | // The server might return an an absolute path. See issue #355. 
106 | name = name[idx+8:] 107 | 108 | readerRegisterLock.RLock() 109 | handler, inMap := readerRegister[name] 110 | readerRegisterLock.RUnlock() 111 | 112 | if inMap { 113 | rdr = handler() 114 | if rdr != nil { 115 | if cl, ok := rdr.(io.Closer); ok { 116 | defer deferredClose(&err, cl) 117 | } 118 | } else { 119 | err = fmt.Errorf("Reader '%s' is ", name) 120 | } 121 | } else { 122 | err = fmt.Errorf("Reader '%s' is not registered", name) 123 | } 124 | } else { // File 125 | name = strings.Trim(name, `"`) 126 | fileRegisterLock.RLock() 127 | fr := fileRegister[name] 128 | fileRegisterLock.RUnlock() 129 | if mc.cfg.AllowAllFiles || fr { 130 | var file *os.File 131 | var fi os.FileInfo 132 | 133 | if file, err = os.Open(name); err == nil { 134 | defer deferredClose(&err, file) 135 | 136 | // get file size 137 | if fi, err = file.Stat(); err == nil { 138 | rdr = file 139 | if fileSize := int(fi.Size()); fileSize < packetSize { 140 | packetSize = fileSize 141 | } 142 | } 143 | } 144 | } else { 145 | err = fmt.Errorf("local file '%s' is not registered", name) 146 | } 147 | } 148 | 149 | // send content packets 150 | if err == nil { 151 | data := make([]byte, 4+packetSize) 152 | var n int 153 | for err == nil { 154 | n, err = rdr.Read(data[4:]) 155 | if n > 0 { 156 | if ioErr := mc.writePacket(data[:4+n]); ioErr != nil { 157 | return ioErr 158 | } 159 | } 160 | } 161 | if err == io.EOF { 162 | err = nil 163 | } 164 | } 165 | 166 | // send empty packet (termination) 167 | if data == nil { 168 | data = make([]byte, 4) 169 | } 170 | if ioErr := mc.writePacket(data[:4]); ioErr != nil { 171 | return ioErr 172 | } 173 | 174 | // read OK packet 175 | if err == nil { 176 | return mc.readResultOK() 177 | } 178 | 179 | mc.readPacket() 180 | return err 181 | } 182 | -------------------------------------------------------------------------------- /database/mysql/interface.go: -------------------------------------------------------------------------------- 1 | package mysql 2 | 3 | 
type Executor interface { 4 | Exec(query string, args ...interface{}) (Result, error) 5 | Query(query string, args ...interface{}) (Rows, error) 6 | Prepare(query string) (*Stmt, error) 7 | } 8 | -------------------------------------------------------------------------------- /database/mysql/result.go: -------------------------------------------------------------------------------- 1 | // Go MySQL Driver - A MySQL-Driver for Go's database/sql package 2 | // 3 | // Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved. 4 | // 5 | // This Source Code Form is subject to the terms of the Mozilla Public 6 | // License, v. 2.0. If a copy of the MPL was not distributed with this file, 7 | // You can obtain one at http://mozilla.org/MPL/2.0/. 8 | 9 | package mysql 10 | 11 | type MySQLResult struct { 12 | status statusFlag 13 | warnings []error 14 | affectedRows int64 15 | insertId int64 16 | statusInfo string 17 | } 18 | 19 | func (r *MySQLResult) Status() (int64, error) { 20 | return int64(r.status), nil 21 | } 22 | 23 | func (r *MySQLResult) Warnings() []error { 24 | return r.warnings 25 | } 26 | 27 | func (r *MySQLResult) LastInsertId() (int64, error) { 28 | return r.insertId, nil 29 | } 30 | 31 | func (r *MySQLResult) RowsAffected() (int64, error) { 32 | return r.affectedRows, nil 33 | } 34 | 35 | func (r *MySQLResult) Info() (string, error) { 36 | return r.statusInfo, nil 37 | } 38 | -------------------------------------------------------------------------------- /database/mysql/server_conn_test.go: -------------------------------------------------------------------------------- 1 | package mysql 2 | 3 | import ( 4 | // "github.com/bytedance/dbatman/database/sql" 5 | "testing" 6 | ) 7 | 8 | func TestWriteCommandFieldList(t *testing.T) { 9 | 10 | runTests(t, dsn, func(dbt *DBTest) { 11 | dbt.mustExec("CREATE TABLE `test` (`id` int(11) NOT NULL, `value` int(11) NOT NULL) ") 12 | 13 | var rows Rows 14 | var err error 15 | if rows, err = dbt.db.FieldList("test", 
""); err != nil { 16 | t.Fatal(err) 17 | } 18 | 19 | cols, err := rows.ColumnPackets() 20 | if err != nil { 21 | t.Fatal(err) 22 | } 23 | 24 | if len(cols) != 2 { 25 | t.Fatalf("expect 2 rows, got %d", len(cols)) 26 | } 27 | 28 | }) 29 | } 30 | -------------------------------------------------------------------------------- /database/mysql/statement.go: -------------------------------------------------------------------------------- 1 | // Go MySQL Driver - A MySQL-Driver for Go's database/sql package 2 | // 3 | // Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved. 4 | // 5 | // This Source Code Form is subject to the terms of the Mozilla Public 6 | // License, v. 2.0. If a copy of the MPL was not distributed with this file, 7 | // You can obtain one at http://mozilla.org/MPL/2.0/. 8 | 9 | package mysql 10 | 11 | import ( 12 | "fmt" 13 | "github.com/bytedance/dbatman/database/sql/driver" 14 | "reflect" 15 | "strconv" 16 | ) 17 | 18 | type mysqlStmt struct { 19 | mc *MySQLConn 20 | id uint32 21 | paramCount uint16 22 | columnCount uint16 23 | params []MySQLField // cached from the prepare 24 | prepareColumns []MySQLField // cached from the prepare 25 | columns []MySQLField // cached from the first query 26 | } 27 | 28 | func (stmt *mysqlStmt) Close() error { 29 | if stmt.mc == nil || stmt.mc.netConn == nil { 30 | errLog.Print(ErrInvalidConn) 31 | return driver.ErrBadConn 32 | } 33 | 34 | err := stmt.mc.writeCommandPacketUint32(comStmtClose, stmt.id) 35 | stmt.mc = nil 36 | return err 37 | } 38 | 39 | func (stmt *mysqlStmt) NumInput() int { 40 | return int(stmt.paramCount) 41 | } 42 | 43 | func (stmt *mysqlStmt) ColumnConverter(idx int) driver.ValueConverter { 44 | return converter{} 45 | } 46 | 47 | func (stmt *mysqlStmt) Exec(args []driver.Value) (driver.Result, error) { 48 | if stmt.mc.netConn == nil { 49 | errLog.Print(ErrInvalidConn) 50 | return nil, driver.ErrBadConn 51 | } 52 | 53 | err := stmt.exec(args) 54 | mc := stmt.mc 55 | if err == nil { 
56 | return &MySQLResult{ 57 | affectedRows: int64(mc.affectedRows), 58 | insertId: int64(mc.insertId), 59 | status: mc.status, 60 | warnings: nil, 61 | statusInfo: mc.popStatusInfo(), 62 | }, nil 63 | } else if errs, ok := err.(MySQLWarnings); ok { 64 | return &MySQLResult{ 65 | affectedRows: int64(mc.affectedRows), 66 | insertId: int64(mc.insertId), 67 | status: mc.status, 68 | warnings: errs.Errors(), 69 | statusInfo: mc.popStatusInfo(), 70 | }, err 71 | } 72 | 73 | return nil, err 74 | } 75 | 76 | func (stmt *mysqlStmt) exec(args []driver.Value) error { 77 | // Send command 78 | err := stmt.writeExecutePacket(args) 79 | if err != nil { 80 | return err 81 | } 82 | 83 | mc := stmt.mc 84 | 85 | // Read Result 86 | resLen, err := mc.readResultSetHeaderPacket() 87 | if err == nil && resLen > 0 { 88 | // Columns 89 | err = mc.readUntilEOF() 90 | if err != nil { 91 | return err 92 | } 93 | 94 | // Rows 95 | err = mc.readUntilEOF() 96 | } 97 | 98 | return err 99 | } 100 | 101 | func (stmt *mysqlStmt) reset() error { 102 | 103 | // Send command 104 | err := stmt.mc.writeCommandPacketUint32(comStmtReset, stmt.id) 105 | if err != nil { 106 | return err 107 | } 108 | 109 | mc := stmt.mc 110 | 111 | // Read Result 112 | resLen, err := mc.readResultSetHeaderPacket() 113 | if err == nil && resLen > 0 { 114 | // Columns 115 | err = mc.readUntilEOF() 116 | if err != nil { 117 | return err 118 | } 119 | 120 | // Rows 121 | err = mc.readUntilEOF() 122 | } 123 | 124 | return err 125 | } 126 | 127 | func (stmt *mysqlStmt) Query(args []driver.Value) (driver.Rows, error) { 128 | if stmt.mc.netConn == nil { 129 | errLog.Print(ErrInvalidConn) 130 | return nil, driver.ErrBadConn 131 | } 132 | // Send command 133 | err := stmt.writeExecutePacket(args) 134 | if err != nil { 135 | return nil, err 136 | } 137 | 138 | mc := stmt.mc 139 | 140 | // Read Result 141 | resLen, err := mc.readResultSetHeaderPacket() 142 | if err != nil { 143 | return nil, err 144 | } 145 | 146 | rows := 
new(BinaryRows) 147 | 148 | if resLen > 0 { 149 | rows.mc = mc 150 | // Columns 151 | // If not cached, read them and cache them 152 | if stmt.columns == nil { 153 | rows.columns, err = mc.readColumns(resLen) 154 | stmt.columns = rows.columns 155 | } else { 156 | rows.columns = stmt.columns 157 | err = mc.readUntilEOF() 158 | } 159 | } 160 | 161 | return rows, err 162 | } 163 | 164 | func (s *mysqlStmt) Columns() []driver.RawPacket { 165 | var ret []driver.RawPacket 166 | for _, col := range s.prepareColumns { 167 | ret = append(ret, col.Dump()) 168 | } 169 | 170 | return ret 171 | } 172 | 173 | func (s *mysqlStmt) Params() []driver.RawPacket { 174 | var ret []driver.RawPacket 175 | for _, col := range s.params { 176 | ret = append(ret, col.Dump()) 177 | } 178 | 179 | return ret 180 | } 181 | 182 | func (stmt *mysqlStmt) SendLongData(paramId int, data []byte) error { 183 | if stmt.mc.netConn == nil { 184 | errLog.Print(ErrInvalidConn) 185 | return driver.ErrBadConn 186 | } 187 | 188 | return stmt.writeCommandLongData(paramId, data) 189 | } 190 | 191 | func (stmt *mysqlStmt) Reset() (driver.Result, error) { 192 | if stmt.mc.netConn == nil { 193 | errLog.Print(ErrInvalidConn) 194 | return nil, driver.ErrBadConn 195 | } 196 | 197 | err := stmt.reset() 198 | mc := stmt.mc 199 | if err == nil { 200 | return &MySQLResult{ 201 | affectedRows: int64(mc.affectedRows), 202 | insertId: int64(mc.insertId), 203 | status: mc.status, 204 | warnings: nil, 205 | statusInfo: mc.popStatusInfo(), 206 | }, nil 207 | } else if errs, ok := err.(MySQLWarnings); ok { 208 | return &MySQLResult{ 209 | affectedRows: int64(mc.affectedRows), 210 | insertId: int64(mc.insertId), 211 | status: mc.status, 212 | warnings: errs.Errors(), 213 | statusInfo: mc.popStatusInfo(), 214 | }, err 215 | } 216 | 217 | return nil, err 218 | } 219 | 220 | func (s *mysqlStmt) StatementID() uint32 { 221 | return s.id 222 | } 223 | 224 | type converter struct{} 225 | 226 | func (c converter) ConvertValue(v 
interface{}) (driver.Value, error) { 227 | if driver.IsValue(v) { 228 | return v, nil 229 | } 230 | 231 | rv := reflect.ValueOf(v) 232 | switch rv.Kind() { 233 | case reflect.Ptr: 234 | // indirect pointers 235 | if rv.IsNil() { 236 | return nil, nil 237 | } 238 | return c.ConvertValue(rv.Elem().Interface()) 239 | case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: 240 | return rv.Int(), nil 241 | case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32: 242 | return int64(rv.Uint()), nil 243 | case reflect.Uint64: 244 | u64 := rv.Uint() 245 | if u64 >= 1<<63 { 246 | return strconv.FormatUint(u64, 10), nil 247 | } 248 | return int64(u64), nil 249 | case reflect.Float32, reflect.Float64: 250 | return rv.Float(), nil 251 | } 252 | return nil, fmt.Errorf("unsupported type %T, a %s", v, rv.Kind()) 253 | } 254 | -------------------------------------------------------------------------------- /database/mysql/transaction.go: -------------------------------------------------------------------------------- 1 | // Go MySQL Driver - A MySQL-Driver for Go's database/sql package 2 | // 3 | // Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved. 4 | // 5 | // This Source Code Form is subject to the terms of the Mozilla Public 6 | // License, v. 2.0. If a copy of the MPL was not distributed with this file, 7 | // You can obtain one at http://mozilla.org/MPL/2.0/. 
8 | 9 | package mysql 10 | 11 | type mysqlTx struct { 12 | mc *MySQLConn 13 | } 14 | 15 | func (tx *mysqlTx) Commit() (err error) { 16 | if tx.mc == nil || tx.mc.netConn == nil { 17 | // fmt.Println("error in tx.mc", tx.mc) 18 | return ErrInvalidConn 19 | } 20 | err = tx.mc.exec("COMMIT") 21 | 22 | //TODO when to release the mc 23 | //tx.mc = nil 24 | return 25 | } 26 | 27 | func (tx *mysqlTx) Rollback() (err error) { 28 | if tx.mc == nil || tx.mc.netConn == nil { 29 | return ErrInvalidConn 30 | } 31 | err = tx.mc.exec("ROLLBACK") 32 | //tx.mc = nil 33 | return 34 | } 35 | -------------------------------------------------------------------------------- /database/sql/driver/types_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2011 The Go Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | 5 | package driver 6 | 7 | import ( 8 | "reflect" 9 | "testing" 10 | "time" 11 | ) 12 | 13 | type valueConverterTest struct { 14 | c ValueConverter 15 | in interface{} 16 | out interface{} 17 | err string 18 | } 19 | 20 | var now = time.Now() 21 | var answer int64 = 42 22 | 23 | var valueConverterTests = []valueConverterTest{ 24 | {Bool, "true", true, ""}, 25 | {Bool, "True", true, ""}, 26 | {Bool, []byte("t"), true, ""}, 27 | {Bool, true, true, ""}, 28 | {Bool, "1", true, ""}, 29 | {Bool, 1, true, ""}, 30 | {Bool, int64(1), true, ""}, 31 | {Bool, uint16(1), true, ""}, 32 | {Bool, "false", false, ""}, 33 | {Bool, false, false, ""}, 34 | {Bool, "0", false, ""}, 35 | {Bool, 0, false, ""}, 36 | {Bool, int64(0), false, ""}, 37 | {Bool, uint16(0), false, ""}, 38 | {c: Bool, in: "foo", err: "sql/driver: couldn't convert \"foo\" into type bool"}, 39 | {c: Bool, in: 2, err: "sql/driver: couldn't convert 2 into type bool"}, 40 | {DefaultParameterConverter, now, now, ""}, 41 | {DefaultParameterConverter, (*int64)(nil), nil, ""}, 42 | 
{DefaultParameterConverter, &answer, answer, ""}, 43 | {DefaultParameterConverter, &now, now, ""}, 44 | } 45 | 46 | func TestValueConverters(t *testing.T) { 47 | for i, tt := range valueConverterTests { 48 | out, err := tt.c.ConvertValue(tt.in) 49 | goterr := "" 50 | if err != nil { 51 | goterr = err.Error() 52 | } 53 | if goterr != tt.err { 54 | t.Errorf("test %d: %T(%T(%v)) error = %q; want error = %q", 55 | i, tt.c, tt.in, tt.in, goterr, tt.err) 56 | } 57 | if tt.err != "" { 58 | continue 59 | } 60 | if !reflect.DeepEqual(out, tt.out) { 61 | t.Errorf("test %d: %T(%T(%v)) = %v (%T); want %v (%T)", 62 | i, tt.c, tt.in, tt.in, out, out, tt.out, tt.out) 63 | } 64 | } 65 | } 66 | -------------------------------------------------------------------------------- /docs/config_example.txt: -------------------------------------------------------------------------------- 1 | global: 2 | port: 3306 3 | manage_port: 3307 4 | max_connections: 10 5 | log_filename: /var/log/tiger/dbatman.log 6 | log_level: 1 7 | log_maxsize: 1024 8 | log_query_min_time: 0 9 | client_timeout: 1800 10 | server_timeout: 1800 11 | write_time_interval: 10 12 | conf_autoload: 1 13 | auth_ips: 14 | - 10.4.64.1 15 | - 10.4.64.2 16 | 17 | clusters: 18 | pgc_cluster: 19 | master: 20 | host: 10.4.4.1 21 | port: 3306 22 | username: pgc 23 | password: pgc 24 | dbname: pgc 25 | max_connections: 100 26 | max_connection_pool_size: 10 27 | connect_timeout: 10 28 | time_reconnect_interval:10 29 | weight: 1 30 | slaves: 31 | slave1: 32 | host: 10.4.4.2 33 | port: 3306 34 | username: pgc 35 | password: pgc 36 | dbname: pgc 37 | max_connections: 100 38 | max_connection_pool_size: 10 39 | connect_timeout: 10 40 | time_reconnect_interval:10 41 | weight: 1 42 | slave2: 43 | host: 10.4.4.3 44 | port: 3306 45 | username: pgc 46 | password: pgc 47 | dbname: pgc 48 | max_connections: 100 49 | max_connection_pool_size: 10 50 | connect_timeout: 10 51 | time_reconnect_interval:10 52 | weight: 1 53 | 54 | users: 55 | 
proxy_pgc_user: 56 | username: proxy_pgc_user 57 | password: pgc 58 | max_connections: 1000 59 | min_connections: 100 60 | default_db: pgc 61 | default_charset: utf8mb4 62 | cluster_name: pgc_cluster 63 | auth_ips: 64 | - 10.1.1.1 65 | - 10.1.1.2 66 | black_list_ips: 67 | - 10.1.1.3 68 | - 10.1.1.4 69 | 70 | 71 | -------------------------------------------------------------------------------- /docs/internal/COM_QUERY.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/dbatman/fbfcd0a293af49e20030c606b3e94782db957219/docs/internal/COM_QUERY.png -------------------------------------------------------------------------------- /docs/internal/CS Protocol.xmind: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/dbatman/fbfcd0a293af49e20030c606b3e94782db957219/docs/internal/CS Protocol.xmind -------------------------------------------------------------------------------- /docs/internal/Command.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/dbatman/fbfcd0a293af49e20030c606b3e94782db957219/docs/internal/Command.jpg -------------------------------------------------------------------------------- /docs/internal/CommandLifeCycle.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/dbatman/fbfcd0a293af49e20030c606b3e94782db957219/docs/internal/CommandLifeCycle.jpg -------------------------------------------------------------------------------- /docs/internal/ConfigReloadFlow.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/dbatman/fbfcd0a293af49e20030c606b3e94782db957219/docs/internal/ConfigReloadFlow.png -------------------------------------------------------------------------------- 
/docs/internal/ConfigStructure.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/dbatman/fbfcd0a293af49e20030c606b3e94782db957219/docs/internal/ConfigStructure.png -------------------------------------------------------------------------------- /docs/internal/ConnectionFlow.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/dbatman/fbfcd0a293af49e20030c606b3e94782db957219/docs/internal/ConnectionFlow.png -------------------------------------------------------------------------------- /docs/internal/ConnectionLifecycle.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/dbatman/fbfcd0a293af49e20030c606b3e94782db957219/docs/internal/ConnectionLifecycle.jpg -------------------------------------------------------------------------------- /docs/internal/ListenRoutine.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/dbatman/fbfcd0a293af49e20030c606b3e94782db957219/docs/internal/ListenRoutine.jpg -------------------------------------------------------------------------------- /docs/internal/MySQL请求时序图.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/dbatman/fbfcd0a293af49e20030c606b3e94782db957219/docs/internal/MySQL请求时序图.jpg -------------------------------------------------------------------------------- /docs/internal/RWSplit.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/dbatman/fbfcd0a293af49e20030c606b3e94782db957219/docs/internal/RWSplit.jpg -------------------------------------------------------------------------------- /docs/internal/RW_Splite.xmind: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/dbatman/fbfcd0a293af49e20030c606b3e94782db957219/docs/internal/RW_Splite.xmind -------------------------------------------------------------------------------- /docs/internal/proxy_user_case.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/dbatman/fbfcd0a293af49e20030c606b3e94782db957219/docs/internal/proxy_user_case.jpg -------------------------------------------------------------------------------- /docs/mysql-proxy/scripting.txt: -------------------------------------------------------------------------------- 1 | Hooks 2 | ===== 3 | 4 | connect_server 5 | -------------- 6 | 7 | read_auth 8 | --------- 9 | 10 | read_auth_result 11 | ---------------- 12 | 13 | read_query 14 | ---------- 15 | 16 | read_query_result 17 | ----------------- 18 | 19 | disconnect_client 20 | ----------------- 21 | 22 | Modules 23 | ======= 24 | 25 | mysql.proto 26 | ----------- 27 | 28 | The ``mysql.proto`` module provides encoders and decoders for the packets exchanged between client and server. 29 | 30 | 31 | from_err_packet 32 | ............... 33 | 34 | Decodes an ERR-packet into a table. 35 | 36 | Parameters: 37 | 38 | ``packet`` 39 | (string) mysql packet 40 | 41 | 42 | On success it returns a table containing: 43 | 44 | ``errmsg`` 45 | (string) 46 | 47 | ``sqlstate`` 48 | (string) 49 | 50 | ``errcode`` 51 | (int) 52 | 53 | Otherwise it raises an error. 54 | 55 | to_err_packet 56 | ............. 57 | 58 | Encodes a table containing an ERR-packet into a MySQL packet. 59 | 60 | Parameters: 61 | 62 | ``err`` 63 | (table) 64 | 65 | ``errmsg`` 66 | (string) 67 | 68 | ``sqlstate`` 69 | (string) 70 | 71 | ``errcode`` 72 | (int) 73 | 74 | into a MySQL packet. 75 | 76 | Returns a string. 77 | 78 | from_ok_packet 79 | ..............
80 | 81 | Decodes an OK-packet 82 | 83 | ``packet`` 84 | (string) mysql packet 85 | 86 | 87 | On success it returns a table containing: 88 | 89 | ``server_status`` 90 | (int) bit-mask of the connection status 91 | 92 | ``insert_id`` 93 | (int) last used insert id 94 | 95 | ``warnings`` 96 | (int) number of warnings for the last executed statement 97 | 98 | ``affected_rows`` 99 | (int) rows affected by the last statement 100 | 101 | Otherwise it raises an error. 102 | 103 | 104 | to_ok_packet 105 | ............ 106 | 107 | Encodes an OK-packet 108 | 109 | from_eof_packet 110 | ............... 111 | 112 | Decodes an EOF-packet 113 | 114 | Parameters: 115 | 116 | ``packet`` 117 | (string) mysql packet 118 | 119 | 120 | On success it returns a table containing: 121 | 122 | ``server_status`` 123 | (int) bit-mask of the connection status 124 | 125 | ``warnings`` 126 | (int) 127 | 128 | Otherwise it raises an error. 129 | 130 | 131 | to_eof_packet 132 | ............. 133 | 134 | from_challenge_packet 135 | ..................... 136 | 137 | Decodes an auth-challenge-packet 138 | 139 | Parameters: 140 | 141 | ``packet`` 142 | (string) mysql packet 143 | 144 | On success it returns a table containing: 145 | 146 | ``protocol_version`` 147 | (int) version of the mysql protocol, usually 10 148 | 149 | ``server_version`` 150 | (int) version of the server as integer: 50506 is MySQL 5.5.6 151 | 152 | ``thread_id`` 153 | (int) connection id 154 | 155 | ``capabilities`` 156 | (int) bit-mask of the server capabilities 157 | 158 | ``charset`` 159 | (int) server default character-set 160 | 161 | ``server_status`` 162 | (int) bit-mask of the connection-status 163 | 164 | ``challenge`` 165 | (string) password challenge 166 | 167 | 168 | to_challenge_packet 169 | ................... 170 | 171 | Encodes an auth-challenge-packet 172 | 173 | from_response_packet 174 | ....................
175 | 176 | Decodes an auth-response-packet 177 | 178 | Parameters: 179 | 180 | ``packet`` 181 | (string) mysql packet 182 | 183 | 184 | to_response_packet 185 | .................. 186 | 187 | from_masterinfo_string 188 | ...................... 189 | 190 | Decodes the content of the ``master.info`` file. 191 | 192 | 193 | to_masterinfo_string 194 | .................... 195 | 196 | from_stmt_prepare_packet 197 | ........................ 198 | 199 | Decodes a COM_STMT_PREPARE-packet 200 | 201 | Parameters: 202 | 203 | ``packet`` 204 | (string) mysql packet 205 | 206 | 207 | On success it returns a table containing: 208 | 209 | ``stmt_text`` 210 | (string) 211 | text of the prepared statement 212 | 213 | Otherwise it raises an error. 214 | 215 | from_stmt_prepare_ok_packet 216 | ........................... 217 | 218 | Decodes a COM_STMT_PREPARE OK-packet 219 | 220 | Parameters: 221 | 222 | ``packet`` 223 | (string) mysql packet 224 | 225 | 226 | On success it returns a table containing: 227 | 228 | ``stmt_id`` 229 | (int) statement-id 230 | 231 | ``num_columns`` 232 | (int) number of columns in the resultset 233 | 234 | ``num_params`` 235 | (int) number of parameters 236 | 237 | ``warnings`` 238 | (int) warnings generated by the prepare statement 239 | 240 | Otherwise it raises an error. 241 | 242 | 243 | from_stmt_execute_packet 244 | ........................
245 | 246 | Decodes a COM_STMT_EXECUTE-packet 247 | 248 | Parameters: 249 | 250 | ``packet`` 251 | (string) mysql packet 252 | 253 | ``num_params`` 254 | (int) number of parameters of the corresponding prepared statement 255 | 256 | On success it returns a table containing: 257 | 258 | ``stmt_id`` 259 | (int) statement-id 260 | 261 | ``flags`` 262 | (int) flags describing the kind of cursor used 263 | 264 | ``iteration_count`` 265 | (int) iteration count: always 1 266 | 267 | ``new_params_bound`` 268 | (bool) 269 | 270 | ``params`` 271 | (nil, table) 272 | number-indexed array of parameters if ``new_params_bound`` is ``true`` 273 | 274 | Each param is a table of: 275 | 276 | ``type`` 277 | (int) 278 | MYSQL_TYPE_INT, MYSQL_TYPE_STRING ... and so on 279 | 280 | ``value`` 281 | (nil, number, string) 282 | if the value is NULL, it is ``nil`` 283 | if it is a number (_INT, _DOUBLE, ...) it is a ``number`` 284 | otherwise it is a ``string`` 285 | 286 | If decoding fails it raises an error. 287 | 288 | To get the ``num_params`` for this function, you have to track the number of parameters as returned 289 | by the `from_stmt_prepare_ok_packet`_. Use `stmt_id_from_stmt_execute_packet`_ to get the ``statement-id`` from 290 | the COM_STMT_EXECUTE packet and look up your tracked information. 291 | 292 | stmt_id_from_stmt_execute_packet 293 | ................................ 294 | 295 | Decodes the statement-id from a COM_STMT_EXECUTE-packet 296 | 297 | Parameters: 298 | 299 | ``packet`` 300 | (string) mysql packet 301 | 302 | 303 | On success it returns the ``statement-id`` as ``int``. 304 | 305 | Otherwise it raises an error. 306 | 307 | from_stmt_close_packet 308 | ......................
309 | 310 | Decodes a COM_STMT_CLOSE-packet 311 | 312 | Parameters: 313 | 314 | ``packet`` 315 | (string) mysql packet 316 | 317 | 318 | On success it returns a table containing: 319 | 320 | ``stmt_id`` 321 | (int) 322 | statement-id that shall be closed 323 | 324 | Otherwise it raises an error. 325 | 326 | 327 | -------------------------------------------------------------------------------- /hack/hack.go: -------------------------------------------------------------------------------- 1 | package hack 2 | 3 | import ( 4 | "reflect" 5 | "unsafe" 6 | ) 7 | 8 | // String provides no copy to change slice to string 9 | // use your own risk 10 | func String(b []byte) (s string) { 11 | pbytes := (*reflect.SliceHeader)(unsafe.Pointer(&b)) 12 | pstring := (*reflect.StringHeader)(unsafe.Pointer(&s)) 13 | pstring.Data = pbytes.Data 14 | pstring.Len = pbytes.Len 15 | return 16 | } 17 | 18 | // Slice provides no copy to change string to slice 19 | // use your own risk 20 | func Slice(s string) (b []byte) { 21 | pbytes := (*reflect.SliceHeader)(unsafe.Pointer(&b)) 22 | pstring := (*reflect.StringHeader)(unsafe.Pointer(&s)) 23 | pbytes.Data = pstring.Data 24 | pbytes.Len = pstring.Len 25 | pbytes.Cap = pstring.Len 26 | return 27 | } 28 | -------------------------------------------------------------------------------- /hack/hack_test.go: -------------------------------------------------------------------------------- 1 | package hack 2 | 3 | import ( 4 | "bytes" 5 | "testing" 6 | ) 7 | 8 | func TestString(t *testing.T) { 9 | b := []byte("hello world") 10 | a := String(b) 11 | 12 | if a != "hello world" { 13 | t.Fatal(a) 14 | } 15 | 16 | b[0] = 'a' 17 | 18 | if a != "aello world" { 19 | t.Fatal(a) 20 | } 21 | 22 | b = append(b, "abc"...) 
23 | if a != "aello world" { 24 | t.Fatal(a) 25 | } 26 | } 27 | 28 | func TestByte(t *testing.T) { 29 | a := "hello world" 30 | 31 | b := Slice(a) 32 | 33 | if !bytes.Equal(b, []byte("hello world")) { 34 | t.Fatal(string(b)) 35 | } 36 | } 37 | -------------------------------------------------------------------------------- /parser/.gitignore: -------------------------------------------------------------------------------- 1 | *.output 2 | -------------------------------------------------------------------------------- /parser/Makefile: -------------------------------------------------------------------------------- 1 | # Copyright 2012, Google Inc. All rights reserved. 2 | # Use of this source code is governed by a BSD-style license that can 3 | # be found in the LICENSE file. 4 | 5 | # MAKEFLAGS = -s 6 | 7 | sql.go: sql_yacc.yy 8 | goyacc -o sql_yacc.go -p MySQL sql_yacc.yy 9 | gofmt -w sql_yacc.go 10 | 11 | clean: 12 | rm -f y.output sql_yacc.go 13 | -------------------------------------------------------------------------------- /parser/ast.go: -------------------------------------------------------------------------------- 1 | package parser 2 | 3 | type IStatement interface { 4 | IStatement() 5 | } 6 | 7 | func SetParseTree(yylex interface{}, stmt IStatement) { 8 | yylex.(*SQLLexer).ParseTree = stmt 9 | } 10 | -------------------------------------------------------------------------------- /parser/ast_alter.go: -------------------------------------------------------------------------------- 1 | package parser 2 | 3 | type AlterTable struct { 4 | Table ISimpleTable 5 | } 6 | 7 | func (*AlterTable) IStatement() {} 8 | func (*AlterTable) IDDLStatement() {} 9 | 10 | type AlterDatabase struct { 11 | Schema []byte 12 | } 13 | 14 | func (*AlterDatabase) IStatement() {} 15 | func (*AlterDatabase) IDDLStatement() {} 16 | 17 | type AlterProcedure struct { 18 | Procedure *Spname 19 | } 20 | 21 | func (*AlterProcedure) IStatement() {} 22 | func (*AlterProcedure) 
IDDLStatement() {} 23 | 24 | type AlterFunction struct { 25 | Function *Spname 26 | } 27 | 28 | func (*AlterFunction) IStatement() {} 29 | func (*AlterFunction) IDDLStatement() {} 30 | 31 | /************************* 32 | * Alter View Statement 33 | *************************/ 34 | func (*AlterView) IStatement() {} 35 | func (*AlterView) IDDLStatement() {} 36 | 37 | type AlterView struct { 38 | View ISimpleTable 39 | As ISelect 40 | } 41 | // GetSchemas returns the schemas of the altered view followed by the schemas referenced by its AS select; either part may be empty. 42 | func (av *AlterView) GetSchemas() []string { 43 | d := av.View.GetSchemas() 44 | p := av.As.GetSchemas() 45 | if len(d) == 0 { // the view name contributed nothing: still report the select's schemas instead of dropping them 46 | return p 47 | } 48 | 49 | return append(d, p...) 50 | } 51 | 52 | type viewTail struct { 53 | View ISimpleTable 54 | As ISelect 55 | } 56 | // GetSchemas returns the schemas of the view followed by the schemas referenced by its AS select; either part may be empty. 57 | func (av *viewTail) GetSchemas() []string { 58 | d := av.View.GetSchemas() 59 | p := av.As.GetSchemas() 60 | if len(d) == 0 { // the view name contributed nothing: still report the select's schemas instead of dropping them 61 | return p 62 | } 63 | 64 | return append(d, p...) 65 | } 66 | 67 | /************************* 68 | * Alter Event Statement 69 | *************************/ 70 | func (*AlterEvent) IStatement() {} 71 | func (*AlterEvent) IDDLStatement() {} 72 | func (*AlterEvent) HasDDLSchemas() {} 73 | func (a *AlterEvent) GetSchemas() []string { 74 | if a.Rename == nil { 75 | return a.Event.GetSchemas() 76 | } 77 | 78 | return GetSchemas(a.Event.GetSchemas(), a.Rename.GetSchemas()) 79 | } 80 | 81 | type AlterEvent struct { 82 | Event *Spname 83 | Rename *Spname 84 | } 85 | 86 | type AlterTablespace struct{} 87 | 88 | func (*AlterTablespace) IStatement() {} 89 | func (*AlterTablespace) IDDLStatement() {} 90 | 91 | type AlterLogfile struct{} 92 | 93 | func (*AlterLogfile) IStatement() {} 94 | func (*AlterLogfile) IDDLStatement() {} 95 | 96 | type AlterServer struct{} 97 | 98 | func (*AlterServer) IStatement() {} 99 | func (*AlterServer) IDDLStatement() {} 100 | -------------------------------------------------------------------------------- /parser/ast_compound.go:
-------------------------------------------------------------------------------- 1 | package parser 2 | 3 | type Signal struct{} 4 | 5 | func (*Signal) IStatement() {} 6 | 7 | type Resignal struct{} 8 | 9 | func (*Resignal) IStatement() {} 10 | 11 | type Diagnostics struct{} 12 | 13 | func (*Diagnostics) IStatement() {} 14 | -------------------------------------------------------------------------------- /parser/ast_create.go: -------------------------------------------------------------------------------- 1 | package parser 2 | 3 | func (*CreateTable) IStatement() {} 4 | func (*CreateTable) IDDLStatement() {} 5 | func (*CreateTable) HasDDLSchemas() {} 6 | 7 | func (c *CreateTable) GetSchemas() []string { 8 | return c.Table.GetSchemas() 9 | } 10 | 11 | type CreateTable struct { 12 | Table ISimpleTable 13 | } 14 | 15 | func (*CreateIndex) IStatement() {} 16 | func (*CreateIndex) IDDLStatement() {} 17 | 18 | type CreateIndex struct{} 19 | 20 | /**************************** 21 | * Create Database Statement 22 | ***************************/ 23 | func (*CreateDatabase) IStatement() {} 24 | func (*CreateDatabase) IDDLStatement() {} 25 | 26 | type CreateDatabase struct{} 27 | 28 | func (*CreateView) IStatement() {} 29 | func (*CreateView) IDDLStatement() {} 30 | func (*CreateView) HasDDLSchemas() {} 31 | 32 | type CreateView struct { 33 | View ISimpleTable 34 | As ISelect 35 | } 36 | 37 | func (c *CreateView) GetSchemas() []string { 38 | return GetSchemas(c.View.GetSchemas(), c.As.GetSchemas()) 39 | } 40 | 41 | func (*CreateLog) IStatement() {} 42 | func (*CreateLog) IDDLStatement() {} 43 | 44 | type CreateLog struct{} 45 | 46 | func (*CreateTablespace) IStatement() {} 47 | func (*CreateTablespace) IDDLStatement() {} 48 | 49 | type CreateTablespace struct{} 50 | 51 | func (*CreateServer) IStatement() {} 52 | func (*CreateServer) IDDLStatement() {} 53 | 54 | type CreateServer struct{} 55 | 56 | /********************** 57 | * Create Event Statement 58 | * 
http://dev.mysql.com/doc/refman/5.7/en/create-event.html 59 | *********************/ 60 | func (*CreateEvent) IStatement() {} 61 | func (*CreateEvent) IDDLStatement() {} 62 | func (*CreateEvent) HasDDLSchemas() {} 63 | 64 | type CreateEvent struct { 65 | Event ISimpleTable 66 | } 67 | 68 | func (c *CreateEvent) GetSchemas() []string { 69 | return c.Event.GetSchemas() 70 | } 71 | 72 | type eventTail struct { 73 | Event ISimpleTable 74 | } 75 | 76 | func (*CreateProcedure) IStatement() {} 77 | func (*CreateProcedure) IDDLStatement() {} 78 | func (*CreateProcedure) HasDDLSchemas() {} 79 | 80 | type CreateProcedure struct { 81 | Procedure ISimpleTable 82 | } 83 | 84 | func (c *CreateProcedure) GetSchemas() []string { 85 | return c.Procedure.GetSchemas() 86 | } 87 | 88 | type spTail struct { 89 | Procedure ISimpleTable 90 | } 91 | 92 | func (*CreateFunction) IStatement() {} 93 | func (*CreateFunction) IDDLStatement() {} 94 | func (*CreateFunction) HasDDLSchemas() {} 95 | 96 | type CreateFunction struct { 97 | Function ISimpleTable 98 | } 99 | type sfTail struct { 100 | Function ISimpleTable 101 | } 102 | 103 | func (c *CreateFunction) GetSchemas() []string { 104 | return c.Function.GetSchemas() 105 | } 106 | 107 | func (*CreateTrigger) IStatement() {} 108 | func (*CreateTrigger) IDDLStatement() {} 109 | func (*CreateTrigger) HasDDLSchemas() {} 110 | 111 | type CreateTrigger struct { 112 | Trigger ISimpleTable 113 | } 114 | type triggerTail struct { 115 | Trigger ISimpleTable 116 | } 117 | 118 | func (c *CreateTrigger) GetSchemas() []string { 119 | return c.Trigger.GetSchemas() 120 | } 121 | -------------------------------------------------------------------------------- /parser/ast_dal.go: -------------------------------------------------------------------------------- 1 | package parser 2 | 3 | func (*Set) IStatement() {} 4 | 5 | type Set struct { 6 | VarList Vars 7 | } 8 | 9 | type Vars []*Variable 10 | 11 | type Variable struct { 12 | Type VarType 13 | Life LifeType 
14 | Name string 15 | Value IExpr 16 | } 17 | 18 | type VarType int 19 | type LifeType int 20 | 21 | const ( 22 | Type_Sys = 1 23 | Type_Usr = 2 24 | 25 | Life_Unknown = 0 26 | Life_Global = 1 27 | Life_Local = 2 28 | Life_Session = 3 29 | ) 30 | 31 | type IAccountMgrStmt interface { 32 | IsAccountMgrStmt() 33 | IStatement 34 | } 35 | 36 | type Partition struct{} 37 | 38 | func (*Partition) IStatement() {} 39 | 40 | /******************************* 41 | * Table Maintenance Statements 42 | ******************************/ 43 | type ITableMtStmt interface { 44 | IStatement 45 | IsTableMtStmt() 46 | GetSchemas() []string 47 | } 48 | 49 | func (*Check) IStatement() {} 50 | func (*Check) IsTableMtStmt() {} 51 | func (*CheckSum) IStatement() {} 52 | func (*CheckSum) IsTableMtStmt() {} 53 | func (*Repair) IStatement() {} 54 | func (*Repair) IsTableMtStmt() {} 55 | func (*Analyze) IStatement() {} 56 | func (*Analyze) IsTableMtStmt() {} 57 | func (*Optimize) IStatement() {} 58 | func (*Optimize) IsTableMtStmt() {} 59 | 60 | func (c *Check) GetSchemas() []string { 61 | return c.Tables.GetSchemas() 62 | } 63 | 64 | func (c *CheckSum) GetSchemas() []string { 65 | return c.Tables.GetSchemas() 66 | } 67 | 68 | func (r *Repair) GetSchemas() []string { 69 | return r.Tables.GetSchemas() 70 | } 71 | 72 | func (a *Analyze) GetSchemas() []string { 73 | return a.Tables.GetSchemas() 74 | } 75 | 76 | func (o *Optimize) GetSchemas() []string { 77 | return o.Tables.GetSchemas() 78 | } 79 | 80 | type Check struct { 81 | Tables ISimpleTables 82 | } 83 | 84 | type CheckSum struct { 85 | Tables ISimpleTables 86 | } 87 | 88 | type Repair struct { 89 | Tables ISimpleTables 90 | } 91 | 92 | type Analyze struct { 93 | Tables ISimpleTables 94 | } 95 | 96 | type Optimize struct { 97 | Tables ISimpleTables 98 | } 99 | 100 | /**************************** 101 | * Cache Index Statement 102 | ***************************/ 103 | func (*CacheIndex) IStatement() {} 104 | 105 | type CacheIndex struct { 106 | 
TableIndexList TableIndexes 107 | } 108 | 109 | func (c *CacheIndex) GetSchemas() []string { 110 | if c.TableIndexList == nil || len(c.TableIndexList) == 0 { 111 | return nil 112 | } 113 | return c.TableIndexList.GetSchemas() 114 | } 115 | 116 | func (*LoadIndex) IStatement() {} 117 | 118 | type LoadIndex struct { 119 | TableIndexList TableIndexes 120 | } 121 | 122 | func (l *LoadIndex) GetSchemas() []string { 123 | if l.TableIndexList == nil || len(l.TableIndexList) == 0 { 124 | return nil 125 | } 126 | return l.TableIndexList.GetSchemas() 127 | } 128 | 129 | type TableIndexes []*TableIndex 130 | 131 | func (tis TableIndexes) GetSchemas() []string { 132 | var rt []string 133 | for _, v := range tis { 134 | if v == nil { 135 | continue 136 | } 137 | 138 | if r := v.Table.GetSchemas(); r != nil && len(r) != 0 { 139 | rt = append(rt, r...) 140 | } 141 | } 142 | 143 | if len(rt) == 0 { 144 | return nil 145 | } 146 | 147 | return rt 148 | } 149 | 150 | type TableIndex struct { 151 | Table ISimpleTable 152 | } 153 | 154 | type Binlog struct{} 155 | 156 | func (*Binlog) IStatement() {} 157 | 158 | func (*Flush) IStatement() {} 159 | 160 | type Flush struct{} 161 | 162 | func (*FlushTables) IStatement() {} 163 | 164 | func (f *FlushTables) GetSchemas() []string { 165 | if f.Tables == nil { 166 | return nil 167 | } 168 | return f.Tables.GetSchemas() 169 | } 170 | 171 | type FlushTables struct { 172 | Tables ISimpleTables 173 | } 174 | 175 | type Kill struct{} 176 | 177 | func (*Kill) IStatement() {} 178 | 179 | type Reset struct{} 180 | 181 | func (*Reset) IStatement() {} 182 | 183 | /********************************************** 184 | * Plugin and User-Defined Function Statements 185 | *********************************************/ 186 | type IPluginAndUdf interface { 187 | IStatement 188 | IsPluginAndUdf() 189 | } 190 | 191 | func (*Install) IStatement() {} 192 | func (*Install) IsPluginAndUdf() {} 193 | func (*CreateUDF) IStatement() {} 194 | func (*CreateUDF) 
IDDLStatement() {} 195 | func (*CreateUDF) IsPluginAndUdf() {} 196 | func (*Uninstall) IStatement() {} 197 | func (*Uninstall) IsPluginAndUdf() {} 198 | 199 | type Install struct{} 200 | 201 | type Uninstall struct{} 202 | 203 | type CreateUDF struct { 204 | Function ISimpleTable 205 | } 206 | 207 | type udfTail struct { 208 | Function ISimpleTable 209 | } 210 | 211 | /********************************** 212 | * Account Management Statements 213 | *********************************/ 214 | func (*Grant) IStatement() {} 215 | func (*Grant) IsAccountMgrStmt() {} 216 | 217 | type Grant struct{} 218 | 219 | func (*SetPassword) IStatement() {} 220 | func (*SetPassword) IsAccountMgrStmt() {} // makes *SetPassword an IAccountMgrStmt, like *Grant and *RenameUser 221 | func (*SetPassword) IsAccountStmt() {} // historical misspelling of IsAccountMgrStmt; kept so existing callers still compile 222 | type SetPassword struct{} 223 | 224 | func (*RenameUser) IStatement() {} 225 | func (*RenameUser) IsAccountMgrStmt() {} 226 | 227 | type RenameUser struct{} 228 | 229 | func (*Revoke) IStatement() {} 230 | func (*Revoke) IsAccountMgrStmt() {} 231 | 232 | type Revoke struct{} 233 | 234 | func (*CreateUser) IStatement() {} 235 | func (*CreateUser) IDDLStatement() {} 236 | func (*CreateUser) IsAccountMgrStmt() {} 237 | 238 | type CreateUser struct{} 239 | 240 | func (*AlterUser) IStatement() {} 241 | func (*AlterUser) IDDLStatement() {} 242 | func (*AlterUser) IsAccountMgrStmt() {} 243 | 244 | type AlterUser struct{} 245 | 246 | func (*DropUser) IStatement() {} 247 | func (*DropUser) IDDLStatement() {} 248 | func (*DropUser) IsAccountMgrStmt() {} 249 | 250 | type DropUser struct{} 251 | -------------------------------------------------------------------------------- /parser/ast_ddl.go: -------------------------------------------------------------------------------- 1 | package parser 2 | 3 | type IDDLStatement interface { 4 | IDDLStatement() 5 | IStatement() 6 | } 7 | 8 | type IDDLSchemas interface { 9 | GetSchemas() []string 10 | HasDDLSchemas() 11 | } 12 | 13 | type RenameTable struct { 14 | ToList []*TableToTable 15 | } 16 | 17 | func (*RenameTable) IStatement() {} 18 | func
(*RenameTable) IDDLStatement() {} 19 | 20 | func (*TruncateTable) IStatement() {} 21 | func (*TruncateTable) IDDLStatement() {} 22 | func (*TruncateTable) HasDDLSchemas() {} 23 | func (t *TruncateTable) GetSchemas() []string { 24 | return t.Table.GetSchemas() 25 | } 26 | 27 | type TruncateTable struct { 28 | Table ISimpleTable 29 | } 30 | -------------------------------------------------------------------------------- /parser/ast_dml.go: -------------------------------------------------------------------------------- 1 | package parser 2 | 3 | /*********************************** 4 | * Select Clause 5 | ***********************************/ 6 | 7 | type ISelect interface { 8 | ISelect() 9 | IsLocked() bool 10 | GetSchemas() []string 11 | IStatement 12 | } 13 | 14 | func (*Select) ISelect() {} 15 | func (*ParenSelect) ISelect() {} 16 | func (*Union) ISelect() {} 17 | func (*SubQuery) ISelect() {} 18 | 19 | func (*Select) IStatement() {} 20 | func (*ParenSelect) IStatement() {} 21 | func (*Union) IStatement() {} 22 | func (*SubQuery) IStatement() {} 23 | 24 | type Union struct { 25 | Left, Right ISelect 26 | } 27 | 28 | func (u *Union) IsLocked() bool { 29 | return u.Left.IsLocked() || u.Right.IsLocked() 30 | } 31 | 32 | func (u *Union) GetSchemas() []string { 33 | if u.Left == nil { 34 | panic("union must have left select statement") 35 | } 36 | 37 | if u.Right == nil { 38 | panic("union must have right select statement") 39 | } 40 | 41 | l := u.Left.GetSchemas() 42 | r := u.Right.GetSchemas() 43 | 44 | if l == nil && r == nil { 45 | return nil 46 | } else if l == nil { 47 | return r 48 | } else if r == nil { 49 | return l 50 | } 51 | return append(l, r...) 
52 | } 53 | 54 | // SubQuery --------- 55 | type SubQuery struct { 56 | SelectStatement ISelect 57 | } 58 | 59 | func (s *SubQuery) IsLocked() bool { 60 | return s.SelectStatement.IsLocked() 61 | } 62 | 63 | func (s *SubQuery) GetSchemas() []string { 64 | if s.SelectStatement == nil { 65 | panic("subquery has no content") 66 | } 67 | 68 | return s.SelectStatement.GetSchemas() 69 | } 70 | 71 | // Select ----------- 72 | type Select struct { 73 | From ITables 74 | LockType LockType 75 | } 76 | 77 | func (s *Select) IsLocked() bool { 78 | return s.LockType != LockType_NoLock 79 | } 80 | 81 | func (s *Select) GetSchemas() []string { 82 | if s.From == nil { 83 | return nil 84 | } 85 | 86 | ret := make([]string, 0, 8) 87 | for _, v := range s.From { 88 | r := v.GetSchemas() 89 | if len(r) != 0 { 90 | ret = append(ret, r...) 91 | } 92 | } 93 | 94 | return ret 95 | } 96 | 97 | // ParenSelect ------ 98 | type ParenSelect struct { 99 | Select ISelect 100 | } 101 | 102 | func (p *ParenSelect) IsLocked() bool { 103 | return p.Select.IsLocked() 104 | } 105 | 106 | func (p *ParenSelect) GetSchemas() []string { 107 | return p.Select.GetSchemas() 108 | } 109 | 110 | type LockType int 111 | 112 | const ( 113 | LockType_NoLock = iota 114 | LockType_ForUpdate 115 | LockType_LockInShareMode 116 | ) 117 | 118 | /********************************* 119 | * Insert Clause 120 | * - http://dev.mysql.com/doc/refman/5.7/en/insert.html 121 | ********************************/ 122 | func (*Insert) IStatement() {} 123 | func (i *Insert) HasISelect() bool { 124 | if i.InsertFields == nil { 125 | return false 126 | } 127 | 128 | if _, ok := i.InsertFields.(ISelect); !ok { 129 | return false 130 | } 131 | 132 | return true 133 | } 134 | // GetSchemas returns the schemas of the target table followed by those of the INSERT ... SELECT source, when there is one. 135 | func (i *Insert) GetSchemas() []string { 136 | ret := i.Table.GetSchemas() 137 | var s []string 138 | if i.HasISelect() { 139 | s = i.InsertFields.(ISelect).GetSchemas() // HasISelect only guarantees ISelect (e.g. *Union, *ParenSelect), so asserting *Select could panic 140 | } 141 | 142 | if len(ret) == 0 { 143 | return s 144 |
} 145 | 146 | if len(s) == 0 { 147 | return ret 148 | } 149 | 150 | return append(ret, s...) 151 | } 152 | 153 | type Insert struct { 154 | Table ISimpleTable 155 | // can be `values(x,y,z)` list or `select` statement 156 | InsertFields interface{} 157 | } 158 | 159 | /********************************* 160 | * Update Clause 161 | * - http://dev.mysql.com/doc/refman/5.7/en/update.html 162 | ********************************/ 163 | func (*Update) IStatement() {} 164 | func (u *Update) GetSchemas() []string { 165 | if u.Tables == nil { 166 | panic("update must have table identifier") 167 | } 168 | 169 | return u.Tables.GetSchemas() 170 | } 171 | 172 | type Update struct { 173 | Tables ITables 174 | } 175 | 176 | /********************************* 177 | * Delete Clause 178 | ********************************/ 179 | func (*Delete) IStatement() {} 180 | 181 | type Delete struct { 182 | Tables ITables 183 | } 184 | 185 | func (d *Delete) GetSchemas() []string { 186 | if d.Tables == nil || len(d.Tables) == 0 { 187 | return nil 188 | } 189 | return d.Tables.GetSchemas() 190 | } 191 | 192 | /*********************************************** 193 | * Replace Clause 194 | **********************************************/ 195 | func (*Replace) IStatement() {} 196 | func (r *Replace) HasISelect() bool { 197 | if r.ReplaceFields == nil { 198 | return false 199 | } 200 | 201 | if _, ok := r.ReplaceFields.(ISelect); !ok { 202 | return false 203 | } 204 | 205 | return true 206 | } 207 | func (r *Replace) GetSchemas() []string { 208 | ret := r.Table.GetSchemas() 209 | var s []string 210 | if r.HasISelect() { 211 | s = r.ReplaceFields.(ISelect).GetSchemas() // HasISelect only guarantees ISelect, not *Select; asserting *Select could panic 212 | } 213 | 214 | if len(ret) == 0 { 215 | return s 216 | } 217 | 218 | if len(s) == 0 { 219 | return ret 220 | } 221 | 222 | return append(ret, s...)
223 | } 224 | 225 | type Replace struct { 226 | Table ITable 227 | // can be `values(x,y,z)` list or `select` statement 228 | ReplaceFields interface{} 229 | } 230 | 231 | type Call struct { 232 | Spname *Spname 233 | } 234 | 235 | func (*Call) IStatement() {} 236 | 237 | type Do struct{} 238 | 239 | func (*Do) IStatement() {} 240 | 241 | type Load struct{} 242 | 243 | func (*Load) IStatement() {} 244 | 245 | type Handler struct{} 246 | 247 | func (*Handler) IStatement() {} 248 | -------------------------------------------------------------------------------- /parser/ast_drop.go: -------------------------------------------------------------------------------- 1 | package parser 2 | 3 | func (*DropTables) IStatement() {} 4 | func (*DropTables) IDDLStatement() {} 5 | func (*DropTables) HasDDLSchemas() {} 6 | func (d *DropTables) GetSchemas() []string { 7 | return d.Tables.GetSchemas() 8 | } 9 | 10 | type DropTables struct { 11 | Tables ISimpleTables 12 | } 13 | 14 | func (*DropIndex) IStatement() {} 15 | func (*DropIndex) IDDLStatement() {} 16 | func (*DropIndex) HasDDLSchemas() {} 17 | func (d *DropIndex) GetSchemas() []string { 18 | return d.On.GetSchemas() 19 | } 20 | 21 | type DropIndex struct { 22 | On ISimpleTable 23 | } 24 | 25 | type DropDatabase struct{} 26 | 27 | func (*DropDatabase) IStatement() {} 28 | func (*DropDatabase) IDDLStatement() {} 29 | 30 | func (*DropFunction) IStatement() {} 31 | func (*DropFunction) IDDLStatement() {} 32 | func (*DropFunction) HasDDLSchemas() {} 33 | func (d *DropFunction) GetSchemas() []string { 34 | return d.Function.GetSchemas() 35 | } 36 | 37 | type DropFunction struct { 38 | Function *Spname 39 | } 40 | 41 | func (*DropProcedure) IStatement() {} 42 | func (*DropProcedure) IDDLStatement() {} 43 | func (*DropProcedure) HasDDLSchemas() {} 44 | func (d *DropProcedure) GetSchemas() []string { 45 | return d.Procedure.GetSchemas() 46 | } 47 | 48 | type DropProcedure struct { 49 | Procedure *Spname 50 | } 51 | 52 | type 
DropView struct{} 53 | 54 | func (*DropView) IStatement() {} 55 | func (*DropView) IDDLStatement() {} 56 | 57 | func (*DropTrigger) IStatement() {} 58 | func (*DropTrigger) IDDLStatement() {} 59 | func (*DropTrigger) HasDDLSchemas() {} 60 | func (d *DropTrigger) GetSchemas() []string { 61 | return d.Trigger.GetSchemas() 62 | } 63 | 64 | type DropTrigger struct { 65 | Trigger *Spname 66 | } 67 | 68 | func (*DropTablespace) IStatement() {} 69 | func (*DropTablespace) IDDLStatement() {} 70 | 71 | type DropTablespace struct{} 72 | 73 | func (*DropLogfile) IStatement() {} 74 | func (*DropLogfile) IDDLStatement() {} 75 | 76 | type DropLogfile struct{} 77 | 78 | func (*DropServer) IStatement() {} 79 | func (*DropServer) IDDLStatement() {} 80 | 81 | type DropServer struct{} 82 | 83 | func (*DropEvent) IStatement() {} 84 | func (*DropEvent) IDDLStatement() {} 85 | func (*DropEvent) HasDDLSchemas() {} 86 | func (d *DropEvent) GetSchemas() []string { 87 | return d.Event.GetSchemas() 88 | } 89 | 90 | type DropEvent struct { 91 | Event *Spname 92 | } 93 | -------------------------------------------------------------------------------- /parser/ast_prepare.go: -------------------------------------------------------------------------------- 1 | package parser 2 | 3 | type Deallocate struct{} 4 | 5 | func (*Deallocate) IStatement() {} 6 | 7 | type Prepare struct{} 8 | 9 | func (*Prepare) IStatement() {} 10 | 11 | type Execute struct{} 12 | 13 | func (*Execute) IStatement() {} 14 | -------------------------------------------------------------------------------- /parser/ast_replication.go: -------------------------------------------------------------------------------- 1 | package parser 2 | 3 | type Change struct{} 4 | 5 | func (*Change) IStatement() {} 6 | 7 | type Purge struct{} 8 | 9 | func (*Purge) IStatement() {} 10 | 11 | type StartSlave struct{} 12 | 13 | func (*StartSlave) IStatement() {} 14 | 15 | type StopSlave struct{} 16 | 17 | func (*StopSlave) IStatement() {} 18 | 
-------------------------------------------------------------------------------- /parser/ast_show.go: -------------------------------------------------------------------------------- 1 | package parser 2 | 3 | type IShow interface { 4 | IShow() 5 | IStatement 6 | } 7 | 8 | type IShowSchemas interface { 9 | IShow 10 | GetSchemas() []string 11 | } 12 | 13 | func (*ShowLogs) IStatement() {} 14 | func (*ShowLogs) IShow() {} 15 | func (*ShowLogEvents) IStatement() {} 16 | func (*ShowLogEvents) IShow() {} 17 | func (*ShowCharset) IStatement() {} 18 | func (*ShowCharset) IShow() {} 19 | func (*ShowCollation) IStatement() {} 20 | func (*ShowCollation) IShow() {} 21 | 22 | // SHOW CREATE [event|procedure|table|trigger|view] 23 | func (*ShowCreate) IStatement() {} 24 | func (*ShowCreate) IShow() {} 25 | func (*ShowCreateDatabase) IStatement() {} 26 | func (*ShowCreateDatabase) IShow() {} 27 | 28 | func (*ShowColumns) IStatement() {} 29 | func (*ShowColumns) IShow() {} 30 | 31 | func (*ShowDatabases) IStatement() {} 32 | func (*ShowDatabases) IShow() {} 33 | 34 | func (*ShowEngines) IStatement() {} 35 | func (*ShowEngines) IShow() {} 36 | 37 | func (*ShowErrors) IStatement() {} 38 | func (*ShowErrors) IShow() {} 39 | func (*ShowWarnings) IStatement() {} 40 | func (*ShowWarnings) IShow() {} 41 | 42 | func (*ShowEvents) IStatement() {} 43 | func (*ShowEvents) IShow() {} 44 | 45 | func (*ShowFunction) IStatement() {} 46 | func (*ShowFunction) IShow() {} 47 | 48 | func (*ShowGrants) IStatement() {} 49 | func (*ShowGrants) IShow() {} 50 | 51 | func (*ShowIndex) IStatement() {} 52 | func (*ShowIndex) IShow() {} 53 | 54 | func (*ShowStatus) IStatement() {} 55 | func (*ShowStatus) IShow() {} 56 | 57 | func (*ShowOpenTables) IStatement() {} 58 | func (*ShowOpenTables) IShow() {} 59 | func (*ShowTables) IStatement() {} 60 | func (*ShowTables) IShow() {} 61 | func (*ShowTableStatus) IStatement() {} 62 | func (*ShowTableStatus) IShow() {} 63 | 64 | func (*ShowPlugins) IStatement() {} 65 
| func (*ShowPlugins) IShow() {} 66 | 67 | func (*ShowPrivileges) IStatement() {} 68 | func (*ShowPrivileges) IShow() {} 69 | 70 | func (*ShowProcedure) IStatement() {} 71 | func (*ShowProcedure) IShow() {} 72 | 73 | func (*ShowProcessList) IStatement() {} 74 | func (*ShowProcessList) IShow() {} 75 | 76 | func (*ShowProfiles) IStatement() {} 77 | func (*ShowProfiles) IShow() {} 78 | 79 | func (*ShowSlaveHosts) IStatement() {} 80 | func (*ShowSlaveHosts) IShow() {} 81 | func (*ShowSlaveStatus) IStatement() {} 82 | func (*ShowSlaveStatus) IShow() {} 83 | func (*ShowMasterStatus) IStatement() {} 84 | func (*ShowMasterStatus) IShow() {} 85 | 86 | func (*ShowTriggers) IStatement() {} 87 | func (*ShowTriggers) IShow() {} 88 | 89 | func (*ShowVariables) IStatement() {} 90 | func (*ShowVariables) IShow() {} 91 | 92 | // currently we use only like for `show databases` syntax 93 | type LikeOrWhere struct { 94 | Like string 95 | } 96 | 97 | type ShowDatabases struct { 98 | LikeOrWhere *LikeOrWhere 99 | } 100 | 101 | func (s *ShowTables) GetSchemas() []string { 102 | if s.From == nil || len(s.From) == 0 { 103 | return nil 104 | } 105 | 106 | return []string{string(s.From)} 107 | } 108 | 109 | type ShowTables struct { 110 | From []byte 111 | } 112 | 113 | func (s *ShowTriggers) GetSchemas() []string { 114 | if s.From == nil || len(s.From) == 0 { 115 | return nil 116 | } 117 | 118 | return []string{string(s.From)} 119 | } 120 | 121 | type ShowTriggers struct { 122 | From []byte 123 | } 124 | 125 | func (s *ShowEvents) GetSchemas() []string { 126 | if s.From == nil || len(s.From) == 0 { 127 | return nil 128 | } 129 | 130 | return []string{string(s.From)} 131 | } 132 | 133 | type ShowEvents struct { 134 | From []byte 135 | } 136 | 137 | func (s *ShowTableStatus) GetSchemas() []string { 138 | if s.From == nil || len(s.From) == 0 { 139 | return nil 140 | } 141 | 142 | return []string{string(s.From)} 143 | } 144 | 145 | type ShowTableStatus struct { 146 | From []byte 147 | } 148 | 
149 | func (s *ShowOpenTables) GetSchemas() []string { 150 | if s.From == nil || len(s.From) == 0 { 151 | return nil 152 | } 153 | 154 | return []string{string(s.From)} 155 | } 156 | 157 | type ShowOpenTables struct { 158 | From []byte 159 | } 160 | 161 | func (s *ShowColumns) GetSchemas() []string { 162 | if s.From == nil || len(s.From) == 0 { 163 | return s.Table.GetSchemas() 164 | } 165 | 166 | return []string{string(s.From)} 167 | } 168 | 169 | type ShowColumns struct { 170 | Table ISimpleTable 171 | From []byte 172 | } 173 | 174 | func (s *ShowIndex) GetSchemas() []string { 175 | if s.From == nil || len(s.From) == 0 { 176 | return s.Table.GetSchemas() 177 | } 178 | 179 | return []string{string(s.From)} 180 | } 181 | 182 | type ShowIndex struct { 183 | Table ISimpleTable 184 | From []byte 185 | } 186 | 187 | func (s *ShowProcedure) GetSchemas() []string { 188 | return s.Procedure.GetSchemas() 189 | } 190 | 191 | type ShowProcedure struct { 192 | Procedure *Spname 193 | } 194 | 195 | func (s *ShowFunction) GetSchemas() []string { 196 | return s.Function.GetSchemas() 197 | } 198 | 199 | type ShowFunction struct { 200 | Function *Spname 201 | } 202 | 203 | func (s *ShowCreate) GetSchemas() []string { 204 | return s.Table.GetSchemas() 205 | } 206 | 207 | type ShowCreate struct { 208 | Prefix []byte 209 | Table ISimpleTable 210 | } 211 | 212 | func (s *ShowCreateDatabase) GetSchemas() []string { 213 | if s.Schema == nil || len(s.Schema) == 0 { 214 | return nil 215 | } 216 | 217 | return []string{string(s.Schema)} 218 | } 219 | 220 | type ShowCreateDatabase struct { 221 | Schema []byte 222 | } 223 | 224 | type ShowGrants struct{} 225 | type ShowCollation struct{} 226 | type ShowCharset struct{} 227 | type ShowVariables struct{} 228 | type ShowProcessList struct{} 229 | type ShowStatus struct{} 230 | type ShowProfiles struct{} 231 | type ShowPrivileges struct{} 232 | type ShowWarnings struct{} 233 | type ShowErrors struct{} 234 | type ShowLogEvents struct{} 235 | type 
ShowSlaveHosts struct{} 236 | type ShowSlaveStatus struct{} 237 | type ShowMasterStatus struct{} 238 | type ShowLogs struct{} 239 | type ShowPlugins struct{} 240 | type ShowEngines struct{} 241 | -------------------------------------------------------------------------------- /parser/ast_table.go: -------------------------------------------------------------------------------- 1 | package parser 2 | 3 | import ( 4 | "fmt" 5 | ) 6 | 7 | /******************************************* 8 | * Table Interfaces and Structs 9 | * doc: 10 | * - table_references http://dev.mysql.com/doc/refman/5.7/en/join.html 11 | * - table_factor http://dev.mysql.com/doc/refman/5.7/en/join.html 12 | * - join_table http://dev.mysql.com/doc/refman/5.7/en/join.html 13 | ******************************************/ 14 | type ITable interface { 15 | IsTable() 16 | GetSchemas() []string 17 | } 18 | 19 | type ITables []ITable 20 | 21 | func (ts ITables) GetSchemas() []string { 22 | if ts == nil && len(ts) == 0 { 23 | return nil 24 | } 25 | 26 | var ret []string 27 | for _, v := range ts { 28 | if r := v.GetSchemas(); r != nil && len(r) != 0 { 29 | ret = append(ret, r...) 30 | } 31 | } 32 | 33 | if len(ret) == 0 { 34 | return nil 35 | } 36 | 37 | return ret 38 | } 39 | 40 | func (*JoinTable) IsTable() {} 41 | func (*ParenTable) IsTable() {} 42 | func (*AliasedTable) IsTable() {} 43 | 44 | type JoinTable struct { 45 | Left ITable 46 | Join []byte 47 | Right ITable 48 | // TODO On BoolExpr 49 | } 50 | 51 | func (j *JoinTable) GetSchemas() []string { 52 | 53 | if j.Left == nil { 54 | panic("join table must have left value") 55 | } 56 | 57 | if j.Right == nil { 58 | panic("join table must have right value") 59 | } 60 | 61 | l := j.Left.GetSchemas() 62 | r := j.Right.GetSchemas() 63 | 64 | if l == nil && r == nil { 65 | return nil 66 | } else if l == nil { 67 | return r 68 | } else if r == nil { 69 | return l 70 | } 71 | 72 | return append(l, r...) 
73 | } 74 | 75 | type ParenTable struct { 76 | Table ITable 77 | } 78 | 79 | func (p *ParenTable) GetSchemas() []string { 80 | if p.Table == nil { 81 | return nil 82 | } 83 | return p.Table.GetSchemas() 84 | } 85 | 86 | type AliasedTable struct { 87 | TableOrSubQuery interface{} // here may be the table_ident or subquery 88 | As []byte 89 | // TODO IndexHints 90 | } 91 | 92 | func (a *AliasedTable) GetSchemas() []string { 93 | if t, ok := a.TableOrSubQuery.(ITable); ok { 94 | return t.GetSchemas() 95 | } else if s, can := a.TableOrSubQuery.(*SubQuery); can { 96 | return s.SelectStatement.GetSchemas() 97 | } else { 98 | panic(fmt.Sprintf("alias table has no table_factor or subquery, element type[%T]", a.TableOrSubQuery)) 99 | } 100 | } 101 | 102 | // SimpleTable contains only qualifier, name and a column field 103 | func (*SimpleTable) IsSimpleTable() {} 104 | func (*SimpleTable) IsTable() {} 105 | 106 | type ISimpleTable interface { 107 | IsSimpleTable() 108 | ITable 109 | } 110 | 111 | type SimpleTable struct { 112 | Qualifier []byte 113 | Name []byte 114 | Column []byte 115 | } 116 | 117 | func (s *SimpleTable) GetSchemas() []string { 118 | if s.Qualifier == nil || len(s.Qualifier) == 0 { 119 | return nil 120 | } 121 | return []string{string(s.Qualifier)} 122 | } 123 | 124 | type ISimpleTables []ISimpleTable 125 | 126 | func (ts ISimpleTables) GetSchemas() []string { 127 | if ts == nil && len(ts) == 0 { 128 | return nil 129 | } 130 | 131 | var ret []string 132 | for _, v := range ts { 133 | if r := v.GetSchemas(); r != nil && len(r) != 0 { 134 | ret = append(ret, r...) 
135 | } 136 | } 137 | 138 | if len(ret) == 0 { 139 | return nil 140 | } 141 | 142 | return ret 143 | } 144 | 145 | func (*Spname) IsSimpleTable() {} 146 | func (*Spname) IsTable() {} 147 | 148 | func (s *Spname) GetSchemas() []string { 149 | if s.Qualifier == nil || len(s.Qualifier) == 0 { 150 | return nil 151 | } 152 | 153 | return []string{string(s.Qualifier)} 154 | } 155 | 156 | type Spname struct { 157 | Qualifier []byte 158 | Name []byte 159 | } 160 | 161 | type SchemaInfo struct { 162 | Name []byte 163 | } 164 | 165 | func GetSchemas(params ...[]string) []string { 166 | var dst []string 167 | for _, arr := range params { 168 | if arr != nil { 169 | dst = append(dst, arr...) 170 | } 171 | } 172 | 173 | if len(dst) == 0 { 174 | return nil 175 | } 176 | 177 | return dst 178 | } 179 | 180 | type TableToTable struct { 181 | From ISimpleTable 182 | To ISimpleTable 183 | } 184 | -------------------------------------------------------------------------------- /parser/ast_trans.go: -------------------------------------------------------------------------------- 1 | package parser 2 | 3 | func (*StartTrans) IStatement() {} 4 | func (*Lock) IStatement() {} 5 | func (*Unlock) IStatement() {} 6 | func (*Begin) IStatement() {} 7 | func (*Commit) IStatement() {} 8 | func (*Rollback) IStatement() {} 9 | func (*XA) IStatement() {} 10 | func (*SavePoint) IStatement() {} 11 | func (*Release) IStatement() {} 12 | func (*SetTrans) IStatement() {} 13 | 14 | type StartTrans struct{} 15 | 16 | func (l *Lock) GetSchemas() []string { 17 | return l.Tables.GetSchemas() 18 | } 19 | 20 | type Lock struct { 21 | Tables ISimpleTables 22 | } 23 | 24 | type Unlock struct{} 25 | 26 | type Begin struct{} 27 | 28 | type Commit struct{} 29 | 30 | type Rollback struct { 31 | Point []byte 32 | } 33 | 34 | type XA struct{} 35 | 36 | type SavePoint struct { 37 | Point []byte 38 | } 39 | 40 | type Release struct { 41 | Point []byte 42 | } 43 | 44 | type SetTrans struct{} 45 | 
-------------------------------------------------------------------------------- /parser/ast_util.go: -------------------------------------------------------------------------------- 1 | package parser 2 | 3 | import ( 4 | "fmt" 5 | ) 6 | 7 | func (*Help) IStatement() {} 8 | func (*DescribeTable) IStatement() {} 9 | func (*DescribeStmt) IStatement() {} 10 | func (*Use) IStatement() {} 11 | 12 | type Help struct{} 13 | 14 | func (d *DescribeTable) GetSchemas() []string { 15 | return d.Table.GetSchemas() 16 | } 17 | 18 | type DescribeTable struct { 19 | Table ISimpleTable 20 | } 21 | 22 | func (d *DescribeStmt) GetSchemas() []string { 23 | switch st := d.Stmt.(type) { 24 | case *Select: 25 | return st.GetSchemas() 26 | case *Insert: 27 | return st.GetSchemas() 28 | case *Update: 29 | return st.GetSchemas() 30 | case *Replace: 31 | return st.GetSchemas() 32 | case *Delete: 33 | return st.GetSchemas() 34 | default: 35 | panic(fmt.Sprintf("statement type %T is not explainable", st)) 36 | } 37 | } 38 | 39 | type DescribeStmt struct { 40 | Stmt IStatement 41 | } 42 | 43 | type Use struct { 44 | DB []byte 45 | } 46 | -------------------------------------------------------------------------------- /parser/bin/yacc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/dbatman/fbfcd0a293af49e20030c606b3e94782db957219/parser/bin/yacc -------------------------------------------------------------------------------- /parser/charset/charset.go: -------------------------------------------------------------------------------- 1 | package charset 2 | 3 | import ( 4 | "bytes" 5 | . 
"github.com/bytedance/dbatman/parser/state" 6 | ) 7 | 8 | type ( 9 | CharsetInfo struct { 10 | Number int 11 | PrimaryNumber int 12 | BinaryNumber int 13 | 14 | CSName string 15 | Name string 16 | 17 | CType []byte 18 | 19 | StateMap []uint 20 | IdentMap []uint 21 | } 22 | ) 23 | 24 | func init() { 25 | ValidCharsets = make(map[string]*CharsetInfo) 26 | ValidCharsets["utf8_general_cli"] = CSUtf8GeneralCli 27 | 28 | for _, v := range ValidCharsets { 29 | initStateMaps(v) 30 | } 31 | } 32 | 33 | var ValidCharsets map[string]*CharsetInfo 34 | 35 | func IsValidCharsets(cs []byte) bool { 36 | if _, ok := ValidCharsets[string(bytes.ToLower(cs))]; ok { 37 | return true 38 | } 39 | 40 | return false 41 | } 42 | 43 | func initStateMaps(cs *CharsetInfo) { 44 | 45 | var state_map [256]uint 46 | 47 | for i := 0; i < 256; i++ { 48 | if cs.IsAlpha(byte(i)) == true { 49 | state_map[i] = (MY_LEX_IDENT) 50 | } else if cs.IsDigit(byte(i)) { 51 | state_map[i] = MY_LEX_NUMBER_IDENT 52 | } else if cs.IsSpace(byte(i)) { 53 | state_map[i] = MY_LEX_SKIP 54 | } else { 55 | state_map[i] = MY_LEX_CHAR 56 | } 57 | } 58 | state_map[0] = MY_LEX_EOL 59 | state_map['_'] = MY_LEX_IDENT 60 | state_map['$'] = MY_LEX_IDENT 61 | state_map['\''] = MY_LEX_STRING 62 | state_map['.'] = MY_LEX_REAL_OR_POINT 63 | state_map['>'] = MY_LEX_CMP_OP 64 | state_map['='] = MY_LEX_CMP_OP 65 | state_map['!'] = MY_LEX_CMP_OP 66 | state_map['<'] = MY_LEX_LONG_CMP_OP 67 | state_map['&'] = MY_LEX_BOOL 68 | state_map['|'] = MY_LEX_BOOL 69 | state_map['#'] = MY_LEX_COMMENT 70 | state_map[';'] = MY_LEX_SEMICOLON 71 | state_map[':'] = MY_LEX_SET_VAR 72 | state_map['\\'] = MY_LEX_ESCAPE 73 | state_map['/'] = MY_LEX_LONG_COMMENT 74 | state_map['*'] = MY_LEX_END_LONG_COMMENT 75 | state_map['@'] = MY_LEX_USER_END 76 | state_map['`'] = MY_LEX_USER_VARIABLE_DELIMITER 77 | state_map['"'] = MY_LEX_STRING_OR_DELIMITER 78 | 79 | var ident_map [256]uint 80 | for i := 0; i < 256; i++ { 81 | ident_map[i] = func() uint { 82 | if 
state_map[i] == MY_LEX_IDENT || state_map[i] == MY_LEX_NUMBER_IDENT { 83 | return 1 84 | } 85 | return 0 86 | }() 87 | } 88 | 89 | state_map['x'] = MY_LEX_IDENT_OR_HEX 90 | state_map['X'] = MY_LEX_IDENT_OR_HEX 91 | state_map['b'] = MY_LEX_IDENT_OR_BIN 92 | state_map['B'] = MY_LEX_IDENT_OR_BIN 93 | state_map['n'] = (MY_LEX_IDENT_OR_NCHAR) 94 | state_map['N'] = (MY_LEX_IDENT_OR_NCHAR) 95 | 96 | cs.IdentMap = ident_map[:] 97 | cs.StateMap = state_map[:] 98 | } 99 | 100 | func (cs *CharsetInfo) IsAlpha(c byte) bool { 101 | if cs.CType[c+1]&(_MY_U|_MY_L) == 0 { 102 | return false 103 | } 104 | return true 105 | } 106 | 107 | func (cs *CharsetInfo) IsDigit(c byte) bool { 108 | if cs.CType[c+1]&_MY_NMR == 0 { 109 | return false 110 | } 111 | 112 | return true 113 | } 114 | 115 | func (cs *CharsetInfo) IsSpace(c byte) bool { 116 | if cs.CType[c+1]&_MY_SPC == 0 { 117 | return false 118 | } 119 | 120 | return true 121 | } 122 | 123 | func (cs *CharsetInfo) IsCntrl(c byte) bool { 124 | if cs.CType[c+1]&_MY_CTR == 0 { 125 | return false 126 | } 127 | 128 | return true 129 | } 130 | 131 | func (cs *CharsetInfo) IsXdigit(c byte) bool { 132 | if cs.CType[c+1]&_MY_X == 0 { 133 | return false 134 | } 135 | return true 136 | } 137 | 138 | func (cs *CharsetInfo) IsAlnum(c byte) bool { 139 | if cs.CType[c+1]&(_MY_U|_MY_L|_MY_NMR) == 0 { 140 | return false 141 | } 142 | 143 | return true 144 | } 145 | 146 | const ( 147 | _MY_U = 01 148 | _MY_L = 02 149 | _MY_NMR = 04 /* Numeral (digit) */ 150 | _MY_SPC = 010 /* Spacing character */ 151 | _MY_PNT = 020 /* Punctuation */ 152 | _MY_CTR = 040 /* Control character */ 153 | _MY_B = 0100 /* Blank */ 154 | _MY_X = 0200 /* heXadecimal digit */ 155 | ) 156 | -------------------------------------------------------------------------------- /parser/charset/charset_test.go: -------------------------------------------------------------------------------- 1 | package charset 2 | 3 | import ( 4 | "testing" 5 | ) 6 | 7 | func TestUtf8(t *testing.T) { 8 
| 9 | func() { // TEST for utf8 digit 10 | for i := 0; i < 10; i++ { 11 | b := byte('0') + byte(i) 12 | if CSUtf8GeneralCli.IsDigit(b) == false { 13 | t.Fatalf("%v is not digit type", b) 14 | } 15 | } 16 | }() 17 | 18 | func() { // TEST for utf8 digit 19 | for i := 0; i < 26; i++ { 20 | b := byte('A') + byte(i) 21 | if CSUtf8GeneralCli.IsAlpha(b) == false { 22 | t.Fatalf("%v is not digit type", b) 23 | } 24 | 25 | b = byte('a') + byte(i) 26 | if CSUtf8GeneralCli.IsAlpha(b) == false { 27 | t.Fatalf("%v is not digit type", b) 28 | } 29 | } 30 | }() 31 | } 32 | -------------------------------------------------------------------------------- /parser/charset/utf8_general_cli.go: -------------------------------------------------------------------------------- 1 | package charset 2 | 3 | var CSUtf8GeneralCli *CharsetInfo = &CharsetInfo{ 4 | 33, 5 | 0, 6 | 0, 7 | 8 | "utf8", 9 | "utf8_general_ci", 10 | 11 | ctype_utf8, 12 | nil, 13 | nil, 14 | } 15 | 16 | var ctype_utf8 []byte = []byte{ 17 | 0, 18 | 32, 32, 32, 32, 32, 32, 32, 32, 32, 40, 40, 40, 40, 40, 32, 32, // 0 - 15 19 | 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, // 16 - 31 20 | 72, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, // 32 - 47 21 | //0, 1, 2 ...... 
9 22 | 132, 132, 132, 132, 132, 132, 132, 132, 132, 132, 16, 16, 16, 16, 16, 16, 23 | 16, 129, 129, 129, 129, 129, 129, 1, 1, 1, 1, 1, 1, 1, 1, 1, 24 | 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 16, 16, 16, 16, 16, 25 | 16, 130, 130, 130, 130, 130, 130, 2, 2, 2, 2, 2, 2, 2, 2, 2, 26 | 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 16, 16, 16, 16, 32, 27 | 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 28 | 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 29 | 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 30 | 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 31 | 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 32 | 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 33 | 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 34 | 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 0, 35 | } 36 | -------------------------------------------------------------------------------- /parser/debug.go: -------------------------------------------------------------------------------- 1 | package parser 2 | 3 | import "fmt" 4 | 5 | var debug bool = false 6 | 7 | func DEBUG(i interface{}) { 8 | if debug { 9 | fmt.Printf("%v", i) 10 | } 11 | } 12 | 13 | func setDebug(dbg bool) { 14 | debug = dbg 15 | } 16 | -------------------------------------------------------------------------------- /parser/lex_ident.go: -------------------------------------------------------------------------------- 1 | package parser 2 | 3 | import ( 4 | "fmt" 5 | 6 | "github.com/bytedance/dbatman/parser/charset" 7 | . "github.com/bytedance/dbatman/parser/state" 8 | ) 9 | 10 | func (lex *SQLLexer) getPureIdentifier() (int, []byte) { 11 | ident_map := lex.cs.IdentMap 12 | c := lex.yyPeek() 13 | rs := int(c) 14 | 15 | for ident_map[lex.yyPeek()] != 0 { 16 | rs |= int(c) 17 | c = lex.yyNext() 18 | } 19 | 20 | if rs&0x80 != 0 { 21 | rs = IDENT_QUOTED 22 | } else { 23 | rs = IDENT 24 | } 25 | 26 | if lex.yyPeek() == '.' 
&& ident_map[int(lex.yyPeek2())] != 0 { 27 | lex.next_state = MY_LEX_IDENT_SEP 28 | } 29 | 30 | return rs, lex.buf[lex.tok_start:lex.ptr] 31 | } 32 | 33 | func (lex *SQLLexer) getIdentifier() (int, []byte) { 34 | 35 | ident_map := lex.cs.IdentMap 36 | 37 | c := lex.yyPeek() 38 | rs := int(c) 39 | 40 | for ident_map[lex.yyPeek()] != 0 { 41 | rs |= int(c) 42 | c = lex.yyNext() 43 | } 44 | 45 | if rs&0x80 != 0 { 46 | rs = IDENT_QUOTED 47 | } else { 48 | rs = IDENT 49 | } 50 | 51 | idc := lex.buf[lex.tok_start:lex.ptr] 52 | if debug { 53 | DEBUG(fmt.Sprintf("idc:[" + string(idc) + "]\n")) 54 | } 55 | 56 | start := lex.ptr 57 | 58 | /* 59 | for ; lex.ignore_space && state_map[c] == MY_LEX_SKIP; c = lex.yyNext() { 60 | }*/ 61 | 62 | c = lex.yyPeek() 63 | if start == lex.ptr && lex.yyPeek() == '.' && ident_map[int(lex.yyPeek())] != 0 { 64 | lex.next_state = MY_LEX_IDENT_SEP 65 | } else if ret, ok := findKeywords(idc, c == '('); ok { 66 | lex.next_state = MY_LEX_START 67 | return ret, idc 68 | } 69 | 70 | if idc[0] == '_' && charset.IsValidCharsets(idc[1:]) { 71 | return UNDERSCORE_CHARSET, idc 72 | } 73 | 74 | return rs, idc 75 | } 76 | -------------------------------------------------------------------------------- /parser/lex_ident_test.go: -------------------------------------------------------------------------------- 1 | package parser 2 | 3 | import ( 4 | "testing" 5 | ) 6 | 7 | func TestIdentifier(t *testing.T) { 8 | testMatchReturn(t, "`test ` ", IDENT_QUOTED, false) 9 | } 10 | 11 | func TestMultiIdentifier(t *testing.T) { 12 | str := "SELECT INSERT 'string ' UPDATE DELEte `SELECT` `Update`" 13 | lex, lval := getLexer(str) 14 | 15 | lexExpect(t, lex, lval, SELECT_SYM) 16 | lexExpect(t, lex, lval, INSERT) 17 | 18 | lexExpect(t, lex, lval, TEXT_STRING) 19 | lvalExpect(t, lval, "'string '") 20 | 21 | lexExpect(t, lex, lval, UPDATE_SYM) 22 | lexExpect(t, lex, lval, DELETE_SYM) 23 | 24 | lexExpect(t, lex, lval, IDENT_QUOTED) 25 | lvalExpect(t, lval, "`SELECT`") 26 | 27 
| lexExpect(t, lex, lval, IDENT_QUOTED) 28 | lvalExpect(t, lval, "`Update`") 29 | 30 | lexExpect(t, lex, lval, END_OF_INPUT) 31 | } 32 | 33 | func TestParamMarker(t *testing.T) { 34 | str := "select ?,?,? from t1;" 35 | lex, lval := getLexer(str) 36 | 37 | lexExpect(t, lex, lval, SELECT_SYM) 38 | lexExpect(t, lex, lval, PARAM_MARKER) 39 | lexExpect(t, lex, lval, ',') 40 | lexExpect(t, lex, lval, PARAM_MARKER) 41 | lexExpect(t, lex, lval, ',') 42 | lexExpect(t, lex, lval, PARAM_MARKER) 43 | } 44 | 45 | func TestMultiIdentifier1(t *testing.T) { 46 | str := "s n insert `s` `` s" 47 | lex, lval := getLexer(str) 48 | 49 | lexExpect(t, lex, lval, IDENT) 50 | lvalExpect(t, lval, `s`) 51 | 52 | lexExpect(t, lex, lval, IDENT) 53 | lvalExpect(t, lval, `n`) 54 | 55 | lexExpect(t, lex, lval, INSERT) 56 | 57 | lexExpect(t, lex, lval, IDENT_QUOTED) 58 | lvalExpect(t, lval, "`s`") 59 | 60 | lexExpect(t, lex, lval, IDENT_QUOTED) 61 | lvalExpect(t, lval, "``") 62 | 63 | lexExpect(t, lex, lval, IDENT) 64 | lvalExpect(t, lval, `s`) 65 | } 66 | 67 | func TestMultiIdentifier2(t *testing.T) { 68 | str := `table1.column_name=table2.column_name` 69 | lex, lval := getLexer(str) 70 | lexExpect(t, lex, lval, IDENT) 71 | lvalExpect(t, lval, "table1") 72 | 73 | lexExpect(t, lex, lval, '.') 74 | 75 | lexExpect(t, lex, lval, IDENT) 76 | lvalExpect(t, lval, "column_name") 77 | 78 | lexExpect(t, lex, lval, EQ) 79 | 80 | lexExpect(t, lex, lval, IDENT) 81 | lvalExpect(t, lval, "table2") 82 | 83 | lexExpect(t, lex, lval, '.') 84 | 85 | lexExpect(t, lex, lval, IDENT) 86 | lvalExpect(t, lval, "column_name") 87 | } 88 | -------------------------------------------------------------------------------- /parser/lex_keywords_test.go: -------------------------------------------------------------------------------- 1 | package parser 2 | 3 | import ( 4 | "testing" 5 | ) 6 | 7 | func TestKeywords(t *testing.T) { 8 | testMatchReturn(t, `SELECT`, SELECT_SYM, false) 9 | } 10 | 11 | func TestFunctions(t *testing.T) 
{ 12 | testMatchReturn(t, `CURTIME()`, CURTIME, false) 13 | } 14 | 15 | func TestCharsetName(t *testing.T) { 16 | testMatchReturn(t, `_utf8_general_cli`, UNDERSCORE_CHARSET, false) 17 | } 18 | 19 | func TestIdent(t *testing.T) { 20 | testMatchReturn(t, `thisisaident`, IDENT, false) 21 | } 22 | 23 | func TestBoolOp(t *testing.T) { 24 | testMatchReturn(t, `&&`, AND_AND_SYM, false) 25 | testMatchReturn(t, `||`, OR_OR_SYM, false) 26 | testMatchReturn(t, `<`, LT, false) 27 | testMatchReturn(t, `<=`, LE, false) 28 | testMatchReturn(t, `<>`, NE, false) 29 | testMatchReturn(t, `!=`, NE, false) 30 | testMatchReturn(t, `=`, EQ, false) 31 | testMatchReturn(t, `>`, GT_SYM, false) 32 | testMatchReturn(t, `>=`, GE, false) 33 | testMatchReturn(t, `<<`, SHIFT_LEFT, false) 34 | testMatchReturn(t, `>>`, SHIFT_RIGHT, false) 35 | testMatchReturn(t, `<=>`, EQUAL_SYM, false) 36 | 37 | testMatchReturn(t, `:=`, SET_VAR, false) 38 | } 39 | 40 | func TestChar(t *testing.T) { 41 | testMatchReturn(t, `& `, '&', false) 42 | } 43 | 44 | func TestMultiKeywords(t *testing.T) { 45 | lexer, lval := getLexer(`SELECT SHOW Databases SELECT `) 46 | 47 | lexExpect(t, lexer, lval, SELECT_SYM) 48 | lexExpect(t, lexer, lval, SHOW) 49 | lexExpect(t, lexer, lval, DATABASES) 50 | 51 | lexExpect(t, lexer, lval, SELECT_SYM) 52 | } 53 | -------------------------------------------------------------------------------- /parser/lex_nchar.go: -------------------------------------------------------------------------------- 1 | package parser 2 | 3 | import () 4 | 5 | func (lexer *SQLLexer) scanNChar(lval *MySQLSymType) (int, byte) { 6 | 7 | // found N'string' 8 | lexer.yyNext() // Skip ' 9 | 10 | // Skip any char except ' 11 | var c byte 12 | for c = lexer.yyNext(); c != 0 && c != '\''; c = lexer.yyNext() { 13 | } 14 | 15 | if c != '\'' { 16 | return ABORT_SYM, c 17 | } 18 | 19 | lval.bytes = lexer.buf[lexer.tok_start:lexer.ptr] 20 | 21 | return NCHAR_STRING, c 22 | } 23 | 
-------------------------------------------------------------------------------- /parser/lex_number.go: -------------------------------------------------------------------------------- 1 | package parser 2 | 3 | import "fmt" 4 | 5 | const ( 6 | LONG_LEN = 10 7 | LONGLONG_LEN = 19 8 | SIGNED_LONGLONG_LEN = 19 9 | UNSIGNED_LONGLONG_LEN = 20 10 | ) 11 | 12 | var ( 13 | LONG []byte = []byte{'2', '1', '4', '7', '4', '8', '3', '6', '4', '7'} 14 | SIGNED_LONG []byte = []byte{'-', '2', '1', '4', '7', '4', '8', '3', '6', '4', '8'} 15 | LONGLONG []byte = []byte{'9', '2', '2', '3', '3', '7', '2', '0', '3', '6', '8', '5', '4', '7', '7', '5', '8', '0', '7'} 16 | SIGNED_LONGLONG []byte = []byte{'-', '9', '2', '2', '3', '3', '7', '2', '0', '3', '6', '8', '5', '4', '7', '7', '5', '8', '0', '8'} 17 | UNSIGNED_LONGLONG []byte = []byte{'1', '8', '4', '4', '6', '7', '4', '4', '0', '7', '3', '7', '0', '9', '5', '5', '1', '6', '1', '5'} 18 | ) 19 | // scanInt classifies the integer literal lex.buf[lex.tok_start:lex.ptr] as NUM, LONG_NUM, ULONGLONG_NUM or DECIMAL_NUM by comparing its significant digits (sign and leading zeros skipped) with the decimal bounds LONG, SIGNED_LONG, LONGLONG, SIGNED_LONGLONG and UNSIGNED_LONGLONG. 20 | func (lex *SQLLexer) scanInt(lval *MySQLSymType) int { 21 | length := lex.ptr - lex.tok_start 22 | lval.bytes = lex.buf[lex.tok_start:lex.ptr] 23 | 24 | if length < LONG_LEN { 25 | return NUM 26 | } 27 | 28 | neg := false 29 | start := lex.tok_start 30 | if lex.buf[start] == '+' { 31 | start += 1 32 | length -= 1 33 | } else if lex.buf[start] == '-' { 34 | start += 1 35 | length -= 1 36 | neg = true 37 | } 38 | 39 | // skip leading zeros: they do not change the magnitude being classified 40 | for start < lex.ptr && lex.buf[start] == '0' { 41 | start += 1 42 | length -= 1 43 | } 44 | 45 | if length < LONG_LEN { 46 | return NUM 47 | } 48 | // cmp is the bound with the same digit count as the literal: values <= cmp yield smaller, larger values yield bigger. 49 | var cmp []byte 50 | var smaller int 51 | var bigger int 52 | if neg { 53 | if length == LONG_LEN { 54 | cmp = SIGNED_LONG[1:len(SIGNED_LONG)] 55 | smaller = NUM 56 | bigger = LONG_NUM 57 | } else if length < SIGNED_LONGLONG_LEN { 58 | return LONG_NUM 59 | } else if length > SIGNED_LONGLONG_LEN { 60 | return DECIMAL_NUM 61 | } else { 62 | cmp = SIGNED_LONGLONG[1:len(SIGNED_LONGLONG)] 63 | smaller = LONG_NUM 64 | bigger = DECIMAL_NUM 65 | } 66 |
} else { 67 | if length == LONG_LEN { 68 | cmp = LONG 69 | smaller = NUM 70 | bigger = LONG_NUM 71 | } else if length < LONGLONG_LEN { 72 | return LONG_NUM 73 | } else if length > LONGLONG_LEN { 74 | if length > UNSIGNED_LONGLONG_LEN { 75 | return DECIMAL_NUM 76 | } 77 | cmp = UNSIGNED_LONGLONG 78 | smaller = ULONGLONG_NUM 79 | bigger = DECIMAL_NUM 80 | } else { 81 | cmp = LONGLONG 82 | smaller = LONG_NUM 83 | bigger = ULONGLONG_NUM 84 | } 85 | } 86 | // walk the common prefix of the literal and cmp 87 | idx := 0 88 | for idx < len(cmp) && cmp[idx] == lex.buf[start] { 89 | if debug { 90 | DEBUG(fmt.Sprintf("cmp:[%c] buf[%c]\n", cmp[idx], lex.buf[start])) 91 | } 92 | idx += 1 93 | start += 1 94 | } 95 | // every digit matched: the literal equals the bound 96 | if idx == len(cmp) { 97 | return smaller 98 | } 99 | // the first differing digit decides the order 100 | if lex.buf[start] <= cmp[idx] { 101 | return smaller 102 | } 103 | return bigger 104 | } 105 | // scanFloat consumes an optional '+' or '-' followed by digits (presumably the exponent part of a literal such as 1.20E-10). If at least one digit follows, it stores lex.buf[lex.tok_start:lex.ptr] in lval.bytes and returns FLOAT_NUM, true; otherwise it returns 0, false, with any sign already consumed. The c parameter is currently unused. 106 | func (lex *SQLLexer) scanFloat(lval *MySQLSymType, c *byte) (int, bool) { 107 | cs := lex.cs 108 | 109 | // try match (+|-)? digit+ 110 | if lex.yyPeek() == '+' || lex.yyPeek() == '-' { 111 | lex.yySkip() // ignore this char 112 | } 113 | 114 | // at least we have 1 digit-char 115 | if cs.IsDigit(lex.yyPeek()) { 116 | for ; cs.IsDigit(lex.yyPeek()); lex.yySkip() { 117 | } 118 | 119 | lval.bytes = lex.buf[lex.tok_start:lex.ptr] 120 | return FLOAT_NUM, true 121 | } 122 | 123 | return 0, false 124 | } 125 | -------------------------------------------------------------------------------- /parser/lex_number_test.go: -------------------------------------------------------------------------------- 1 | package parser 2 | 3 | import ( 4 | "testing" 5 | ) 6 | 7 | func TestInt(t *testing.T) { 8 | testMatchReturn(t, `123456`, NUM, false) 9 | testMatchReturn(t, `0000000000000000000000000123456`, NUM, false) 10 | testMatchReturn(t, `2147483646`, NUM, false) // NUM 11 | testMatchReturn(t, `2147483647`, NUM, false) // 2^31 - 1 12 | testMatchReturn(t, `2147483648`, LONG_NUM, false) // 2^31 13 | testMatchReturn(t, `0000000000000000000002147483648`, LONG_NUM, false) //
2^31 14 | testMatchReturn(t, `2147483648`, LONG_NUM, false) // 2^31 15 | testMatchReturn(t, `2147483648`, LONG_NUM, false) // 2^31 16 | testMatchReturn(t, `2147483648`, LONG_NUM, false) // 2^31 17 | 18 | testMatchReturn(t, `9223372036854775807`, LONG_NUM, false) 19 | testMatchReturn(t, `9223372036854775808`, ULONGLONG_NUM, false) 20 | testMatchReturn(t, `18446744073709551615`, ULONGLONG_NUM, false) 21 | testMatchReturn(t, `18446744073709551616`, DECIMAL_NUM, false) 22 | } 23 | 24 | func TestNum(t *testing.T) { 25 | testMatchReturn(t, `0x1234`, HEX_NUM, false) 26 | testMatchReturn(t, `0xa4234`, HEX_NUM, false) 27 | testMatchReturn(t, `0b0110`, BIN_NUM, false) 28 | } 29 | 30 | func TestFloatNum(t *testing.T) { 31 | testMatchReturn(t, " 10e-10", FLOAT_NUM, false) 32 | testMatchReturn(t, " 10E+10", FLOAT_NUM, false) 33 | testMatchReturn(t, " 10E10", FLOAT_NUM, false) 34 | testMatchReturn(t, "1.20E10", FLOAT_NUM, false) 35 | testMatchReturn(t, "1.20E-10", FLOAT_NUM, false) 36 | } 37 | 38 | func TestDecimalNum(t *testing.T) { 39 | testMatchReturn(t, `.21`, DECIMAL_NUM, false) 40 | testMatchReturn(t, `72.21`, DECIMAL_NUM, false) 41 | } 42 | 43 | func TestHex(t *testing.T) { 44 | testMatchReturn(t, `X'4D7953514C'`, HEX_NUM, false) 45 | 46 | testMatchReturn(t, `x'D34F2X`, ABORT_SYM, false) 47 | testMatchReturn(t, `x'`, ABORT_SYM, false) 48 | 49 | } 50 | 51 | func TestBin(t *testing.T) { 52 | testMatchReturn(t, `b'0101010111000'`, BIN_NUM, false) 53 | testMatchReturn(t, `b'0S01010111000'`, ABORT_SYM, false) 54 | testMatchReturn(t, `b'12312351123`, ABORT_SYM, false) 55 | } 56 | 57 | func TestMultiNum(t *testing.T) { 58 | str := `123 'string1' 18446744073709551616 1.20E-10 .312 x'4D7953514C' ` 59 | lex, lval := getLexer(str) 60 | 61 | lexExpect(t, lex, lval, NUM) 62 | lvalExpect(t, lval, `123`) 63 | 64 | lexExpect(t, lex, lval, TEXT_STRING) 65 | lvalExpect(t, lval, `'string1'`) 66 | 67 | lexExpect(t, lex, lval, DECIMAL_NUM) 68 | lvalExpect(t, lval, `18446744073709551616`) 69 |
70 | lexExpect(t, lex, lval, FLOAT_NUM) 71 | lvalExpect(t, lval, `1.20E-10`) 72 | 73 | lexExpect(t, lex, lval, DECIMAL_NUM) 74 | lvalExpect(t, lval, `.312`) 75 | 76 | lexExpect(t, lex, lval, HEX_NUM) 77 | lvalExpect(t, lval, `x'4D7953514C'`) 78 | 79 | lexExpect(t, lex, lval, END_OF_INPUT) 80 | } 81 | 82 | func TestNumberInPlacehold(t *testing.T) { 83 | str := ` (5)` 84 | lex, lval := getLexer(str) 85 | lexExpect(t, lex, lval, '(') 86 | lexExpect(t, lex, lval, NUM) 87 | lexExpect(t, lex, lval, ')') 88 | } 89 | -------------------------------------------------------------------------------- /parser/lex_test.go: -------------------------------------------------------------------------------- 1 | package parser 2 | 3 | import ( 4 | "testing" 5 | ) 6 | 7 | func getLexer(str string) (lexer *SQLLexer, lval *MySQLSymType) { 8 | lval = new(MySQLSymType) 9 | lexer = NewSQLLexer(str) 10 | 11 | return 12 | } 13 | // testMatchReturn sets debug output to dbg (it is not reset afterwards), lexes the first token of str and fails the test unless it equals match; the lexer and lval are returned so callers can keep lexing. 14 | func testMatchReturn(t *testing.T, str string, match int, dbg bool) (*SQLLexer, *MySQLSymType) { 15 | setDebug(dbg) 16 | lexer, lval := getLexer(str) 17 | ret := lexer.Lex(lval) 18 | if ret != match { 19 | t.Fatalf("test failed! expect[%s] return[%s]", MySQLSymName(match), MySQLSymName(ret)) 20 | } 21 | 22 | return lexer, lval 23 | } 24 | 25 | func TestNULLEscape(t *testing.T) { 26 | lexer, lval := getLexer("\\N") 27 | if lexer.Lex(lval) != NULL_SYM { 28 | t.Fatal("test failed") 29 | } 30 | } 31 | 32 | func TestSingleComment(t *testing.T) { 33 | lexer, lval := getLexer(" -- Single Line Comment.
\r\n") 34 | 35 | if lexer.Lex(lval) != END_OF_INPUT { 36 | t.Fatal("test failed") 37 | } 38 | } 39 | -------------------------------------------------------------------------------- /parser/lex_text.go: -------------------------------------------------------------------------------- 1 | package parser 2 | 3 | import ( 4 | "errors" 5 | ) 6 | 7 | /** 8 | * For ANTLR3 definition: 9 | 10 | SINGLE_QUOTED_TEXT 11 | @init { int escape_count = 0; }: 12 | SINGLE_QUOTE 13 | ( 14 | SINGLE_QUOTE SINGLE_QUOTE { escape_count++; } 15 | | {!SQL_MODE_ACTIVE(SQL_MODE_NO_BACKSLASH_ESCAPES)}? => ESCAPE_OPERATOR . { escape_count++; } 16 | | {SQL_MODE_ACTIVE(SQL_MODE_NO_BACKSLASH_ESCAPES)}? => ~(SINGLE_QUOTE) 17 | | {!SQL_MODE_ACTIVE(SQL_MODE_NO_BACKSLASH_ESCAPES)}? => ~(SINGLE_QUOTE | ESCAPE_OPERATOR) 18 | )* 19 | SINGLE_QUOTE 20 | { EMIT(); LTOKEN->user1 = escape_count; } 21 | ; 22 | 23 | DOUBLE_QUOTED_TEXT 24 | @init { int escape_count = 0; }: 25 | DOUBLE_QUOTE 26 | ( 27 | DOUBLE_QUOTE DOUBLE_QUOTE { escape_count++; } 28 | | {!SQL_MODE_ACTIVE(SQL_MODE_NO_BACKSLASH_ESCAPES)}? => ESCAPE_OPERATOR . { escape_count++; } 29 | | {SQL_MODE_ACTIVE(SQL_MODE_NO_BACKSLASH_ESCAPES)}? => ~(DOUBLE_QUOTE) 30 | | {!SQL_MODE_ACTIVE(SQL_MODE_NO_BACKSLASH_ESCAPES)}? => ~(DOUBLE_QUOTE | ESCAPE_OPERATOR) 31 | )* 32 | DOUBLE_QUOTE 33 | { EMIT(); LTOKEN->user1 = escape_count; } 34 | ; 35 | */ 36 | 37 | var StringFormatError error = errors.New("text string format error") 38 | // getQuotedText consumes a quoted string whose quote kind comes from yyLookHead ('"' selects double quotes, anything else single quotes) and returns lexer.buf[lexer.tok_start:lexer.ptr] once the closing quote is consumed. Doubled quotes and, unless MODE_NO_BACKSLASH_ESCAPES is set, backslash escapes do not terminate it; StringFormatError is returned if the input ends first. 39 | func (lexer *SQLLexer) getQuotedText() ([]byte, error) { 40 | var dq bool 41 | var sep byte 42 | 43 | if sep = lexer.yyLookHead(); sep == '"' { 44 | dq = true 45 | } 46 | 47 | for lexer.ptr < uint(len(lexer.buf)) { 48 | c := lexer.yyNext() 49 | 50 | if c == '\\' && !lexer.sqlMode.MODE_NO_BACKSLASH_ESCAPES { 51 | if lexer.yyPeek() == EOF { 52 | return nil, StringFormatError 53 | } 54 | 55 | lexer.yySkip() // skip next char 56 | } else if matchQuote(c, dq) { 57 | if matchQuote(lexer.yyPeek(), dq) { 58 | // found an escaped quote, e.g.
'' "" 59 | lexer.yySkip() // skip for the second quote 60 | continue 61 | } 62 | // we have found the last quote 63 | return lexer.buf[lexer.tok_start:lexer.ptr], nil 64 | } 65 | } 66 | 67 | return nil, StringFormatError 68 | } 69 | // matchQuote reports whether c is the quote character that closes the current string: '"' when double_quote is true, '\'' otherwise. 70 | func matchQuote(c byte, double_quote bool) bool { 71 | if double_quote { 72 | return c == '"' 73 | } 74 | 75 | return c == '\'' 76 | } 77 | -------------------------------------------------------------------------------- /parser/lex_text_test.go: -------------------------------------------------------------------------------- 1 | package parser 2 | 3 | import ( 4 | "testing" 5 | ) 6 | 7 | func testTextParse(t *testing.T, str string, mode SQLMode) { 8 | lexer, lval := getLexer(str) 9 | lexer.sqlMode = mode 10 | if r := lexer.Lex(lval); r != TEXT_STRING { 11 | t.Fatalf("parse text failed. return[%s]", MySQLSymName(r)) 12 | } 13 | 14 | if string(lval.bytes) != str { 15 | t.Fatalf("orgin[%s] not match parsed[%s]", str, string(lval.bytes)) 16 | } 17 | } 18 | 19 | func TestSingleQuoteString(t *testing.T) { 20 | testMatchReturn(t, `'single Quoted string'`, TEXT_STRING, false) 21 | } 22 | 23 | func TestDoubleQuoteString(t *testing.T) { 24 | testMatchReturn(t, `"double quoted string"`, TEXT_STRING, false) 25 | } 26 | 27 | func TestAnsiQuotesSQLModeString(t *testing.T) { 28 | str := `'a' ' ' 'string'` 29 | lexer, lval := getLexer(str) 30 | lexer.sqlMode.MODE_ANSI_QUOTES = true 31 | 32 | if lexer.Lex(lval) != TEXT_STRING { 33 | t.Fatalf("parse ansi quotes string failed!") 34 | } 35 | 36 | } 37 | 38 | func TestSingleQuoteString3(t *testing.T) { 39 | testTextParse(t, `'afasgasdgasg'`, SQLMode{}) 40 | testTextParse(t, `'''afasgasdgasg'`, SQLMode{}) 41 | testTextParse(t, `''`, SQLMode{}) 42 | testTextParse(t, `""`, SQLMode{}) 43 | 44 | testTextParse(t, `'""hello""'`, SQLMode{}) 45 | testTextParse(t, `'hel''lo'`, SQLMode{}) 46 | testTextParse(t, `'\'hello'`, SQLMode{}) 47 | 48 | testTextParse(t, `'\''`, SQLMode{}) 49 | testTextParse(t, `'\'`,
SQLMode{MODE_NO_BACKSLASH_ESCAPES: true}) 50 | } 51 | 52 | func TestStringException(t *testing.T) { 53 | str := `'\'` 54 | lexer, lval := getLexer(str) 55 | if r := lexer.Lex(lval); r != ABORT_SYM { 56 | t.Fatalf("parse text failed. return[%s]", MySQLSymName(r)) 57 | } 58 | 59 | lexer, lval = getLexer(`"\`) 60 | if r := lexer.Lex(lval); r != ABORT_SYM { 61 | t.Fatalf("parse text failed. return[%s]", MySQLSymName(r)) 62 | } 63 | } 64 | 65 | func TestNChar(t *testing.T) { 66 | testMatchReturn(t, `n'some text'`, NCHAR_STRING, false) 67 | testMatchReturn(t, `N'some text'`, NCHAR_STRING, false) 68 | 69 | testMatchReturn(t, `N'`, ABORT_SYM, false) 70 | } 71 | 72 | func lexExpect(t *testing.T, lexer *SQLLexer, lval *MySQLSymType, expect int) { 73 | if ret := lexer.Lex(lval); ret != expect { 74 | t.Fatalf("expect[%s] return[%s]", MySQLSymName(expect), MySQLSymName(ret)) 75 | } 76 | } 77 | 78 | func lvalExpect(t *testing.T, lval *MySQLSymType, expect string) { 79 | if string(lval.bytes) != expect { 80 | t.Fatalf("expect[%s] return[%s]", expect, string(lval.bytes)) 81 | } 82 | } 83 | 84 | func TestMultiString(t *testing.T) { 85 | str := `"string1" 'string2' 'string3' n'string 4' ` 86 | lex, lval := getLexer(str) 87 | 88 | lexExpect(t, lex, lval, TEXT_STRING) 89 | lvalExpect(t, lval, `"string1"`) 90 | 91 | lexExpect(t, lex, lval, TEXT_STRING) 92 | lvalExpect(t, lval, `'string2'`) 93 | 94 | lexExpect(t, lex, lval, TEXT_STRING) 95 | lvalExpect(t, lval, `'string3'`) 96 | 97 | lexExpect(t, lex, lval, NCHAR_STRING) 98 | lvalExpect(t, lval, `n'string 4'`) 99 | 100 | lexExpect(t, lex, lval, END_OF_INPUT) 101 | } 102 | -------------------------------------------------------------------------------- /parser/lex_var_test.go: -------------------------------------------------------------------------------- 1 | package parser 2 | 3 | import ( 4 | "testing" 5 | ) 6 | 7 | func TestHostName(t *testing.T) { 8 | setDebug(true) 9 | // testMatchReturn(t, `user@hostname`,
LEX_HOSTNAME, true) 10 | } 11 | 12 | func TestSystemVariables(t *testing.T) { 13 | lexer, lval := testMatchReturn(t, `@@uservar`, '@', false) 14 | ret := lexer.Lex(lval) 15 | if ret != '@' { 16 | t.Fatalf("expect[@] unexpect %s", MySQLSymName(ret)) 17 | } 18 | 19 | ret = lexer.Lex(lval) 20 | if ret != IDENT { 21 | t.Fatalf("expect[IDENT] unexpect %s", MySQLSymName(ret)) 22 | } 23 | } 24 | 25 | func TestUserDefinedVariables(t *testing.T) { 26 | lexer, lval := testMatchReturn(t, "@`uservar`", '@', false) 27 | ret := lexer.Lex(lval) 28 | if ret != IDENT_QUOTED { 29 | t.Fatalf("expect[IDENT_QUOTED] unexpect %s", MySQLSymName(ret)) 30 | } 31 | } 32 | 33 | func TestSetVarIdent(t *testing.T) { 34 | 35 | lexer, lval := testMatchReturn(t, "set @var=1", SET, false) 36 | 37 | lexExpect(t, lexer, lval, '@') 38 | 39 | lexExpect(t, lexer, lval, IDENT) 40 | lvalExpect(t, lval, "var") 41 | } 42 | -------------------------------------------------------------------------------- /parser/parser.go: -------------------------------------------------------------------------------- 1 | package parser 2 | 3 | import ( 4 | "errors" 5 | ) 6 | // Parse lexes and parses sql and returns the resulting statement tree; when MySQLParse returns a non-zero result, the error carries the lexer's LastError message. 7 | func Parse(sql string) (IStatement, error) { 8 | //TODO MEM used 70%in total 9 | lexer := NewSQLLexer(sql) 10 | if MySQLParse(lexer) != 0 { 11 | return nil, errors.New(lexer.LastError) 12 | } 13 | 14 | return lexer.ParseTree, nil 15 | } 16 | -------------------------------------------------------------------------------- /parser/parser_ddl_test.go: -------------------------------------------------------------------------------- 1 | package parser 2 | 3 | import ( 4 | "testing" 5 | ) 6 | 7 | func TestAlter(t *testing.T) { 8 | st := testParse(`alter view d1.v1 as select * from t2;`, t, false) 9 | matchSchemas(t, st, `d1`) 10 | 11 | st = testParse( 12 | `ALTER EVENT myschema.myevent 13 | ON SCHEDULE 14 | AT CURRENT_TIMESTAMP + INTERVAL 1 DAY 15 | DO 16 | TRUNCATE TABLE myschema.mytable;`, t, false) 17 | matchSchemas(t, st, `myschema`) 18 |
19 | st = testParse(`ALTER EVENT olddb.myevent RENAME TO newdb.myevent;`, t, false) 20 | matchSchemas(t, st, `olddb`, `newdb`) 21 | 22 | st = testParse(`ALTER SERVER s OPTIONS (USER 'sally');`, t, false) 23 | 24 | } 25 | 26 | func TestCreate(t *testing.T) { 27 | st := testParse(`CREATE DATABASE IF NOT EXISTS my_db default charset utf8 COLLATE utf8_general_ci;`, t, false) 28 | 29 | st = testParse(`CREATE EVENT mydb.myevent 30 | ON SCHEDULE AT CURRENT_TIMESTAMP + INTERVAL 1 HOUR 31 | DO 32 | UPDATE myschema.mytable SET mycol = mycol + 1;`, t, false) 33 | matchSchemas(t, st, `mydb`) 34 | 35 | st = testParse(`CREATE FUNCTION thisdb.hello (s CHAR(20)) RETURNS CHAR(50) DETERMINISTIC RETURN CONCAT('Hello, ',s,'!');`, t, false) 36 | matchSchemas(t, st, `thisdb`) 37 | 38 | st = testParse( 39 | `CREATE DEFINER = 'admin'@'localhost' PROCEDURE db1.account_count() 40 | SQL SECURITY INVOKER 41 | BEGIN 42 | SELECT 'Number of accounts:', COUNT(*) FROM mysql.user; 43 | END;`, t, false) 44 | matchSchemas(t, st, `db1`) 45 | 46 | st = testParse(`CREATE INDEX part_of_name ON customer (name(10));`, t, false) 47 | st = testParse(`CREATE INDEX id_index ON lookup (id) USING BTREE;`, t, false) 48 | st = testParse(`CREATE INDEX id_index ON t1 (id) COMMENT 'MERGE_THRESHOLD=40';`, t, false) 49 | 50 | st = testParse( 51 | `CREATE SERVER s FOREIGN DATA WRAPPER mysql 52 | OPTIONS (USER 'Remote', HOST '192.168.1.106', DATABASE 'test');`, t, false) 53 | 54 | st = testParse( 55 | `create view v1 as select s2,sum(s1) - count(s2) as vx 56 | from t1.t1 group by s2 having sum(s1) - count(s2) < (select f1() from t1.t2);`, t, false) 57 | matchSchemas(t, st, `t1`) 58 | } 59 | 60 | func TestCreateTable(t *testing.T) { 61 | st := testParse(`CREATE TABLE db1.t1 (col1 INT, col2 CHAR(5)) 62 | PARTITION BY HASH(col1);`, t, false) 63 | matchSchemas(t, st, `db1`) 64 | 65 | testParse(`CREATE TABLE t1 (col1 INT, col2 CHAR(5), col3 DATETIME) 66 | PARTITION BY HASH ( YEAR(col3) );`, t, false) 67 | testParse(`CREATE
/*!32302 TEMPORARY */ TABLE t (a INT);`, t, false) 68 | 69 | testParse(`SELECT /*! STRAIGHT_JOIN */ col1 FROM table1,table2`, t, false) 70 | } 71 | 72 | func TestDrop(t *testing.T) { 73 | st := testParse(`DROP EVENT IF EXISTS db1.event_name`, t, false) 74 | matchSchemas(t, st, `db1`) 75 | 76 | st = testParse(`Drop Procedure If exists db1.sp_name`, t, false) 77 | matchSchemas(t, st, `db1`) 78 | 79 | st = testParse("DROP INDEX `PRIMARY` ON db1.t1;", t, false) 80 | matchSchemas(t, st, `db1`) 81 | 82 | testParse("Drop server if exists server_name", t, false) 83 | 84 | st = testParse("DROP TABLE IF EXISTS B.B, C.C, A.A;", t, false) 85 | matchSchemas(t, st, `B`, `C`, `A`) 86 | 87 | st = testParse("DROP TRIGGER schema_name.trigger_name;", t, false) 88 | matchSchemas(t, st, `schema_name`) 89 | } 90 | 91 | func TestOthers(t *testing.T) { 92 | st := testParse(`Truncate db1.table1`, t, false) 93 | matchSchemas(t, st, `db1`) 94 | 95 | testParse(`RENAME TABLE current_db.tbl_name TO other_db.tbl_name;`, t, false) 96 | } 97 | -------------------------------------------------------------------------------- /parser/parser_dml_test.go: -------------------------------------------------------------------------------- 1 | package parser 2 | 3 | import ( 4 | "fmt" 5 | "reflect" 6 | "testing" 7 | ) 8 | 9 | func fmtimport() { 10 | fmt.Println() 11 | } 12 | // matchType fails the test unless st has the same dynamic type as ref. 13 | func matchType(t *testing.T, st IStatement, ref interface{}) { 14 | if reflect.TypeOf(st) != reflect.TypeOf(ref) { 15 | t.Fatalf("expect type[%v] not match[%v]", reflect.TypeOf(ref), reflect.TypeOf(st)) 16 | } 17 | } 18 | // matchSchemas collects st's schema names via GetSchemas and fails the test unless they equal tables, in order. 19 | func matchSchemas(t *testing.T, st IStatement, tables ...string) { 20 | var ts []string 21 | 22 | switch ast := st.(type) { 23 | case *Select: 24 | ts = ast.GetSchemas() 25 | case *Union: 26 | ts = ast.GetSchemas() 27 | case *Insert: 28 | ts = ast.GetSchemas() 29 | case *Delete: 30 | ts = ast.GetSchemas() 31 | case *Update: 32 | ts = ast.GetSchemas() 33 | case *Replace: 34 | ts = ast.GetSchemas() 35 | case
*AlterView: 36 | ts = ast.GetSchemas() 37 | case IDDLSchemas: 38 | ts = ast.GetSchemas() 39 | case *Lock: 40 | ts = ast.GetSchemas() 41 | case *DescribeTable: 42 | ts = ast.GetSchemas() 43 | case *DescribeStmt: 44 | ts = ast.GetSchemas() 45 | case ITableMtStmt: 46 | ts = ast.GetSchemas() 47 | case *CacheIndex: 48 | ts = ast.GetSchemas() 49 | case *LoadIndex: 50 | ts = ast.GetSchemas() 51 | case *FlushTables: 52 | ts = ast.GetSchemas() 53 | case IShowSchemas: 54 | ts = ast.GetSchemas() 55 | default: 56 | t.Fatalf("unknow statement type: %T", ast) 57 | } 58 | 59 | if len(tables) == 0 && len(ts) == 0 { 60 | return 61 | } else if len(tables) != len(ts) { 62 | t.Fatalf("expect table number[%d] not match return[%d]", len(tables), len(ts)) 63 | } 64 | 65 | for k, v := range ts { 66 | if v != tables[k] { 67 | t.Fatalf("expect table[%s] not match return[%s]", tables[k], v) 68 | } 69 | } 70 | 71 | } 72 | 73 | func TestSelect(t *testing.T) { 74 | st := testParse("SELECT * FROM table1;", t, false) 75 | matchSchemas(t, st) 76 | 77 | st = testParse("SELECT t1.* FROM (select * from db1.table1) as t1;", t, false) 78 | matchSchemas(t, st, "db1") 79 | 80 | st = testParse("SELECT sb1,sb2,sb3 \n FROM (SELECT s1 AS sb1, s2 AS sb2, s3*2 AS sb3 FROM db1.t1) AS sb \n WHERE sb1 > 1;", t, false) 81 | matchSchemas(t, st, "db1") 82 | 83 | st = testParse("SELECT AVG(SUM(column1)) FROM t1 GROUP BY column1;", t, false) 84 | matchSchemas(t, st) 85 | 86 | st = testParse("SELECT REPEAT('a',1) UNION SELECT REPEAT('b',10);", t, false) 87 | matchSchemas(t, st) 88 | 89 | st = testParse(`(SELECT a FROM db1.t1 WHERE a=10 AND B=1 ORDER BY a LIMIT 10) 90 | UNION 91 | (SELECT a FROM db2.t2 WHERE a=11 AND B=2 ORDER BY a LIMIT 10);`, t, false) 92 | matchSchemas(t, st, "db1", "db2") 93 | 94 | st = testParse(`SELECT funcs(s) 95 | FROM db1.table1 96 | LEFT OUTER JOIN db2.table2 97 | ON db1.table1.column_name=db2.table2.column_name;`, t, false) 98 | matchSchemas(t, st, "db1", "db2") 99 | 100 | st =
testParse("SELECT * FROM db1.table1 LEFT JOIN db2.table2 ON table1.id=table2.id LEFT JOIN db3.table3 ON table2.id = table3.id for update", t, false) 101 | matchSchemas(t, st, "db1", "db2", "db3") 102 | 103 | if st.(*Select).LockType != LockType_ForUpdate { 104 | t.Fatalf("lock type is not For Update") 105 | } 106 | 107 | st = testParse(`select last_insert_id() as a`, t, false) 108 | st = testParse(`SELECT substr('''a''bc',0,3) FROM dual`, t, false) 109 | testParse(`SELECT /*mark for picman*/ * FROM filterd limit 1;`, t, false) 110 | 111 | testParse(`SELECT ?,?,? from t1;`, t, false) 112 | } 113 | 114 | func TestInsert(t *testing.T) { 115 | st := testParse(`INSERT INTO db1.tbl_temp2 (fld_id) 116 | SELECT tempdb.tbl_temp1.fld_order_id 117 | FROM tempdb.tbl_temp1 WHERE tbl_temp1.fld_order_id > 100;`, t, false) 118 | matchSchemas(t, st, "db1", "tempdb") 119 | } 120 | 121 | func TestUpdate(t *testing.T) { 122 | st := testParse(`UPDATE t1 SET col1 = col1 + 1, col2 = col1;`, t, false) 123 | matchSchemas(t, st) 124 | 125 | st = testParse("UPDATE `Table A`,`Table B` SET `Table A`.`text`=concat_ws('',`Table A`.`text`,`Table B`.`B-num`,\" from \",`Table B`.`date`,'/') WHERE `Table A`.`A-num` = `Table B`.`A-num`", t, false) 126 | matchSchemas(t, st) 127 | 128 | st = testParse(`UPDATE db1.items,db2.month SET items.price=month.price 129 | WHERE items.id=month.id;`, t, false) 130 | matchSchemas(t, st, "db1", "db2") 131 | } 132 | 133 | func TestDelete(t *testing.T) { 134 | st := testParse(`DELETE FROM db.somelog WHERE user = 'jcole' 135 | ORDER BY timestamp_column LIMIT 1;`, t, false) 136 | matchSchemas(t, st, "db") 137 | 138 | st = testParse(`DELETE FROM db1.t1, db2.t2 USING t1 INNER JOIN t2 INNER JOIN db3.t3 139 | WHERE t1.id=t2.id AND t2.id=t3.id;`, t, false) 140 | matchSchemas(t, st, "db1", "db2", "db3") 141 | 142 | st = testParse(`DELETE FROM a1, a2 USING db1.t1 AS a1 INNER JOIN t2 AS a2 143 | WHERE a1.id=a2.id;`, t, false) 144 | matchSchemas(t, st, "db1") 145 | } 146 | 147 |
func TestReplace(t *testing.T) { 148 | st := testParse(`REPLACE INTO test2 VALUES (1, 'Old', '2014-08-20 18:47:00');`, t, false) 149 | matchSchemas(t, st) 150 | 151 | st = testParse(`REPLACE INTO dbname2.test2 VALUES (1, 'Old', '2014-08-20 18:47:00');`, t, false) 152 | matchSchemas(t, st, "dbname2") 153 | } 154 | -------------------------------------------------------------------------------- /parser/parser_test.go: -------------------------------------------------------------------------------- 1 | package parser 2 | 3 | import ( 4 | "fmt" 5 | "testing" 6 | ) 7 | // PrintTree prints a short debug summary of statement to stdout. 8 | func PrintTree(statement IStatement) { 9 | if statement == nil { 10 | fmt.Println(`(nil)`) 11 | return 12 | } 13 | switch st := statement.(type) { 14 | case *Union: 15 | fmt.Printf("left: %+v right: %+v\n", st.Left, st.Right) 16 | case *Select: 17 | fmt.Printf("From: %+v Lock: %+v\n", st.From, st.LockType) 18 | default: 19 | fmt.Println("Yet Unknow Statement:", st) 20 | } 21 | } 22 | // testParse parses sql with debug output set to dbg, resets debug output to false afterwards, and fails the test if Parse returns an error. 23 | func testParse(sql string, t *testing.T, dbg bool) IStatement { 24 | setDebug(dbg) 25 | st, err := Parse(sql) 26 | // reset before reporting a failure so debug output does not leak into later tests 27 | setDebug(false) 28 | if err != nil { 29 | t.Fatalf("%v", err) 30 | } 31 | 32 | return st 33 | } 34 | 35 | func TestExplain(t *testing.T) { 36 | testParse("EXPLAIN SELECT f1(5)", t, false) 37 | testParse("EXPLAIN SELECT * FROM t1 AS a1, (SELECT BENCHMARK(1000000, MD5(NOW())));", t, false) 38 | } 39 | 40 | func TestParse(t *testing.T) { 41 | setDebug(false) 42 | if _, err := Parse("Select version()"); err != nil { 43 | t.Fatalf("%v", err) 44 | } 45 | } 46 | 47 | func TestTokenName(t *testing.T) { 48 | if name := MySQLSymName(ABORT_SYM); name == "" { 49 | t.Fatal("get token name error") 50 | } 51 | } 52 | -------------------------------------------------------------------------------- /parser/parser_trans_test.go: -------------------------------------------------------------------------------- 1 | package parser 2 | 3 | import ( 4 | "testing" 5 | ) 6 | 7 | func TestTransaction(t
*testing.T) { 8 | st := testParse(`Start Transaction WITH CONSISTENT SNAPSHOT`, t, false) 9 | matchType(t, st, &StartTrans{}) 10 | 11 | st = testParse(`BEGIN`, t, false) 12 | matchType(t, st, &Begin{}) 13 | 14 | st = testParse(`COMMIT WORk NO RELEASE`, t, false) 15 | matchType(t, st, &Commit{}) 16 | 17 | st = testParse(`rollback`, t, false) 18 | matchType(t, st, &Rollback{}) 19 | } 20 | 21 | func TestSavePoint(t *testing.T) { 22 | st := testParse(`Savepoint identifier`, t, false) 23 | matchType(t, st, &SavePoint{}) 24 | 25 | st = testParse(`rollback to identifier`, t, false) 26 | matchType(t, st, &Rollback{}) 27 | 28 | st = testParse(`release savepoint identifier`, t, false) 29 | matchType(t, st, &Release{}) 30 | } 31 | 32 | func TestLockTables(t *testing.T) { 33 | st := testParse(`LOCK TABLES tb1 AS alias1 read, db2.tb2 low_priority write`, t, false) 34 | matchType(t, st, &Lock{}) 35 | matchSchemas(t, st, `db2`) 36 | 37 | st = testParse(`UNLOCK TABLES`, t, false) 38 | matchType(t, st, &Unlock{}) 39 | } 40 | -------------------------------------------------------------------------------- /parser/parser_util_test.go: -------------------------------------------------------------------------------- 1 | package parser 2 | 3 | import ( 4 | "testing" 5 | ) 6 | 7 | func TestDesc(t *testing.T) { 8 | st := testParse(` DESCRIBE db1.tb1;`, t, false) 9 | matchType(t, st, &DescribeTable{}) 10 | matchSchemas(t, st, `db1`) 11 | 12 | st = testParse(`explain select * from db1.table1`, t, false) 13 | matchSchemas(t, st, `db1`) 14 | } 15 | 16 | func TestHelp(t *testing.T) { 17 | st := testParse(`help 'help me'`, t, false) 18 | matchType(t, st, &Help{}) 19 | } 20 | 21 | func TestUse(t *testing.T) { 22 | st := testParse(`use mydb`, t, false) 23 | matchType(t, st, &Use{}) 24 | 25 | if string(st.(*Use).DB) != `mydb` { 26 | t.Fatalf("expect [mydb] match[%s]", string(st.(*Use).DB)) 27 | } 28 | } 29 | --------------------------------------------------------------------------------
/parser/state/state.go: -------------------------------------------------------------------------------- 1 | package state 2 | 3 | import "fmt" 4 | 5 | const ( 6 | MY_LEX_START = iota 7 | MY_LEX_CHAR 8 | MY_LEX_IDENT 9 | MY_LEX_IDENT_SEP 10 | MY_LEX_IDENT_START 11 | MY_LEX_REAL 12 | MY_LEX_HEX_NUMBER 13 | MY_LEX_BIN_NUMBER 14 | MY_LEX_CMP_OP 15 | MY_LEX_LONG_CMP_OP 16 | MY_LEX_STRING 17 | MY_LEX_COMMENT 18 | MY_LEX_END 19 | MY_LEX_OPERATOR_OR_IDENT 20 | MY_LEX_NUMBER_IDENT 21 | MY_LEX_INT_OR_REAL 22 | MY_LEX_REAL_OR_POINT 23 | MY_LEX_BOOL 24 | MY_LEX_EOL 25 | MY_LEX_ESCAPE 26 | MY_LEX_LONG_COMMENT 27 | MY_LEX_END_LONG_COMMENT 28 | MY_LEX_SEMICOLON 29 | MY_LEX_SET_VAR 30 | MY_LEX_USER_END 31 | MY_LEX_HOSTNAME 32 | MY_LEX_SKIP 33 | MY_LEX_USER_VARIABLE_DELIMITER 34 | MY_LEX_SYSTEM_VAR 35 | MY_LEX_IDENT_OR_KEYWORD 36 | MY_LEX_IDENT_OR_HEX 37 | MY_LEX_IDENT_OR_BIN 38 | MY_LEX_IDENT_OR_NCHAR 39 | MY_LEX_STRING_OR_DELIMITER 40 | ) 41 | 42 | // statusMap maps each MY_LEX_* lexer state to its printable name. 43 | var statusMap = map[uint]string{ 44 | MY_LEX_START: "MY_LEX_START", 45 | MY_LEX_CHAR: "MY_LEX_CHAR", 46 | MY_LEX_IDENT: "MY_LEX_IDENT", 47 | MY_LEX_IDENT_SEP: "MY_LEX_IDENT_SEP", 48 | MY_LEX_IDENT_START: "MY_LEX_IDENT_START", 49 | MY_LEX_REAL: "MY_LEX_REAL", 50 | MY_LEX_HEX_NUMBER: "MY_LEX_HEX_NUMBER", 51 | MY_LEX_BIN_NUMBER: "MY_LEX_BIN_NUMBER", 52 | MY_LEX_CMP_OP: "MY_LEX_CMP_OP", 53 | MY_LEX_LONG_CMP_OP: "MY_LEX_LONG_CMP_OP", 54 | MY_LEX_STRING: "MY_LEX_STRING", 55 | MY_LEX_COMMENT: "MY_LEX_COMMENT", 56 | MY_LEX_END: "MY_LEX_END", 57 | MY_LEX_OPERATOR_OR_IDENT: "MY_LEX_OPERATOR_OR_IDENT", 58 | MY_LEX_NUMBER_IDENT: "MY_LEX_NUMBER_IDENT", 59 | MY_LEX_INT_OR_REAL: "MY_LEX_INT_OR_REAL", 60 | MY_LEX_REAL_OR_POINT: "MY_LEX_REAL_OR_POINT", 61 | MY_LEX_BOOL: "MY_LEX_BOOL", 62 | MY_LEX_EOL: "MY_LEX_EOL", 63 | MY_LEX_ESCAPE: "MY_LEX_ESCAPE", 64 | MY_LEX_LONG_COMMENT: "MY_LEX_LONG_COMMENT", 65 | MY_LEX_END_LONG_COMMENT: "MY_LEX_END_LONG_COMMENT", 66 | MY_LEX_SEMICOLON: "MY_LEX_SEMICOLON", 67 | MY_LEX_SET_VAR:
"MY_LEX_SET_VAR", 68 | MY_LEX_USER_END: "MY_LEX_USER_END", 69 | MY_LEX_HOSTNAME: "MY_LEX_HOSTNAME", 70 | MY_LEX_SKIP: "MY_LEX_SKIP", 71 | MY_LEX_USER_VARIABLE_DELIMITER: "MY_LEX_USER_VARIABLE_DELIMITER", 72 | MY_LEX_SYSTEM_VAR: "MY_LEX_SYSTEM_VAR", 73 | MY_LEX_IDENT_OR_KEYWORD: "MY_LEX_IDENT_OR_KEYWORD", 74 | MY_LEX_IDENT_OR_HEX: "MY_LEX_IDENT_OR_HEX", 75 | MY_LEX_IDENT_OR_BIN: "MY_LEX_IDENT_OR_BIN", 76 | MY_LEX_IDENT_OR_NCHAR: "MY_LEX_IDENT_OR_NCHAR", 77 | MY_LEX_STRING_OR_DELIMITER: "MY_LEX_STRING_OR_DELIMITER", 78 | } 79 | // GetLexStatus returns the name registered in statusMap for lexer state which, or "Unknow Status[N]" (N being the numeric value) when which is not registered. 80 | func GetLexStatus(which uint) string { 81 | if v, ok := statusMap[which]; ok { 82 | return v 83 | } 84 | 85 | return fmt.Sprintf("Unknow Status[%d]", which) 86 | } 87 | -------------------------------------------------------------------------------- /parser/test.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | go test -coverprofile=$HOME/cover/coverage.out 4 | cd ~/cover/ && go tool cover -html=coverage.out -o coverage.html 5 | cp coverage.html /mnt/hgfs/ubuntu 6 | cd - 7 | -------------------------------------------------------------------------------- /pool/.gitignore: -------------------------------------------------------------------------------- 1 | *.out 2 | -------------------------------------------------------------------------------- /pool/slice.go: -------------------------------------------------------------------------------- 1 | package pool 2 | 3 | import ( 4 | "reflect" 5 | "sync" 6 | ) 7 | 8 | const maxSliceType = 16 9 | const minSliceSize = 1 << 3 // 8 len slice 10 | const maxSliceSize = 1 << maxSliceType // 64k len slice 11 | 12 | type ( 13 | PoolI interface { 14 | Borrow(size int) interface{} 15 | Return(b interface{}) 16 | } 17 | 18 | // SliceSyncPool holds bufs.
19 | syncPool struct { 20 | capV int 21 | lenV int 22 | *sync.Pool 23 | } 24 | 25 | SliceSyncPool struct { 26 | pools []*syncPool 27 | 28 | New func(l int, c int) interface{} 29 | checkType func(interface{}) bool 30 | } 31 | ) 32 | 33 | func newSyncPool(NewFunc func(l int, c int) interface{}, lv int, cv int) *syncPool { 34 | p := new(syncPool) 35 | p.capV = cv 36 | p.lenV = lv 37 | p.Pool = &sync.Pool{New: func() interface{} { return NewFunc(p.lenV, p.capV) }} 38 | return p 39 | } 40 | 41 | func NewSliceSyncPool(NewFunc func(l int, c int) interface{}, check func(interface{}) bool) *SliceSyncPool { 42 | p := new(SliceSyncPool) 43 | 44 | p.New = NewFunc 45 | p.checkType = check 46 | 47 | p.pools = make([]*syncPool, maxSliceType+1) 48 | min := floorlog2(minSliceSize) 49 | max := floorlog2(maxSliceSize) 50 | for i := min; i <= max; i++ { 51 | // return 2^i size slice 52 | p.pools[i] = newSyncPool(NewFunc, 0, 1< maxSliceSize { 64 | return p.New(size, 2*size) 65 | } else if size < minSliceSize { 66 | // small than 8 len's slice all return 8 cap interface{} 67 | ret = p.borrow(floorlog2(minSliceSize)) 68 | } else { 69 | idx := floorlog2(uint(size)) 70 | if 1< maxSliceSize || v.Cap() < minSliceSize { 95 | return // too big or too small, let it go 96 | } 97 | 98 | idx := floorlog2(uint(v.Cap())) 99 | rs := 1 << uint(idx) 100 | p.pools[idx].Put(v.Slice3(0, rs, rs).Interface()) 101 | } 102 | -------------------------------------------------------------------------------- /pool/slice1.go: -------------------------------------------------------------------------------- 1 | package pool 2 | 3 | import ( 4 | "reflect" 5 | ) 6 | 7 | const cacheSliceCap = 10240 8 | 9 | type ( 10 | // SlicePool holds bufs.
11 | SlicePool struct { 12 | pools []chan interface{} 13 | 14 | New func(l int, c int) interface{} 15 | checkType func(interface{}) bool 16 | } 17 | ) 18 | 19 | func NewSlicePool(NewFunc func(l int, c int) interface{}, check func(i interface{}) bool) *SlicePool { 20 | p := new(SlicePool) 21 | 22 | p.New = NewFunc 23 | p.checkType = check 24 | 25 | p.pools = make([]chan interface{}, maxSliceType+1) 26 | min := floorlog2(minSliceSize) 27 | max := floorlog2(maxSliceSize) 28 | for i := min; i <= max; i++ { 29 | // return 2^i size slice 30 | p.pools[i] = make(chan interface{}, cacheSliceCap) 31 | } 32 | 33 | return p 34 | } 35 | 36 | // borrow a buf from the pool. 37 | func (p *SlicePool) Borrow(size int) interface{} { 38 | 39 | var ret interface{} 40 | 41 | if size > maxSliceSize { 42 | return p.New(size, 2*size) 43 | } else if size < minSliceSize { 44 | // small than 8 len's slice all return 8 cap interface{} 45 | ret = p.borrow(floorlog2(minSliceSize)) 46 | } else { 47 | idx := floorlog2(uint(size)) 48 | if 1< maxSliceSize || v.Cap() < minSliceSize { 78 | return // too big or too small, let it go 79 | } 80 | 81 | idx := floorlog2(uint(v.Cap())) 82 | rs := 1 << uint(idx) 83 | select { 84 | case p.pools[idx] <- v.Slice3(0, rs, rs).Interface(): 85 | default: 86 | // let it go, let it go 87 | } 88 | } 89 | -------------------------------------------------------------------------------- /pool/slice1_test.go: -------------------------------------------------------------------------------- 1 | package pool 2 | 3 | import ( 4 | "math/rand" 5 | "testing" 6 | ) 7 | 8 | func TestChanPoolConsist(t *testing.T) { 9 | 10 | isPool := NewSlicePool( 11 | func(l int, c int) interface{} { return make([]int, l, c) }, 12 | checkInts, 13 | ) 14 | 15 | testPoolConsist(isPool, t) 16 | } 17 | 18 | func TestHugePool(t *testing.T) { 19 | isPool := NewSlicePool( 20 | func(l int, c int) interface{} { return make([]int, l, c) }, 21 | checkInts, 22 | ) 23 | 24 | testHugePool(isPool, t) 25 | } 26 |
27 | func TestPoolEdgeCondition(t *testing.T) { 28 | bsPool := NewSlicePool( 29 | func(l int, c int) interface{} { return make([]byte, l, c) }, 30 | checkBytes, 31 | ) 32 | 33 | testPoolEdgeCondition(bsPool, t) 34 | } 35 | 36 | func TestDifferentTypePanic(t *testing.T) { 37 | 38 | bsPool := NewSlicePool( 39 | func(l int, c int) interface{} { return make([]byte, l, c) }, 40 | checkBytes, 41 | ) 42 | 43 | testDifferentTypePanic(bsPool, t) 44 | } 45 | func TestPoolFull(t *testing.T) { 46 | 47 | bsPool := NewSlicePool( 48 | func(l int, c int) interface{} { return make([]byte, l, c) }, 49 | checkBytes, 50 | ) 51 | 52 | for i := 0; i <= cacheSliceCap+1; i++ { 53 | bsPool.Return(make([]byte, 0, 8)) 54 | } 55 | } 56 | // BenchmarkSliceBorrowReturn repeatedly borrows a []byte of random non-zero length below maxSliceSize from a SlicePool, checks its length and capacity, and returns it. 57 | func BenchmarkSliceBorrowReturn(t *testing.B) { 58 | 59 | bytesPool := NewSlicePool( 60 | func(l int, c int) interface{} { return make([]byte, l, c) }, 61 | checkBytes, 62 | ) 63 | 64 | for i := 0; i < t.N; i++ { 65 | size := rand.Intn(maxSliceSize) 66 | if size == 0 { 67 | continue 68 | } 69 | 70 | v := bytesPool.Borrow(size) 71 | b, ok := v.([]byte) 72 | if !ok { 73 | t.Fatal(v, "is not slice type!") 74 | } 75 | 76 | if len(b) != size || cap(b) < len(b) { 77 | t.Fatal("borrowed slice: len", len(b), "cap", cap(b), "requested size", size) 78 | } else { 79 | bytesPool.Return(b) 80 | } 81 | } 82 | } 83 | -------------------------------------------------------------------------------- /pool/slice_test.go: -------------------------------------------------------------------------------- 1 | package pool 2 | 3 | import ( 4 | "math/rand" 5 | "reflect" 6 | "testing" 7 | "unsafe" 8 | ) 9 | // getBytes returns i as a []byte, failing tb if i holds any other type. 10 | func getBytes(i interface{}, tb testing.TB) []byte { 11 | var b []byte 12 | var ok bool 13 | if b, ok = i.([]byte); !ok { 14 | tb.Fatal(i, "is not bytes slice type!") 15 | } 16 | return b 17 | } 18 | // getInts returns i as a []int, failing tb if i holds any other type. 19 | func getInts(i interface{}, tb testing.TB) []int { 20 | var b []int 21 | var ok bool 22 | if b, ok = i.([]int); !ok { 23 | tb.Fatal(i, "is not int slice type!") 24 | } 25 | return b 26 | } 27
// Review notes for this span of the dump. It covers, in order:
//   - pool/slice_test.go lines 28-116;
//   - a truncated tail of (presumably) pool/utils.go;
//   - proxy/auth.go, proxy/com_query_test.go, proxy/conn.go,
//     proxy/conn_resultset.go, proxy/conn_select.go and proxy/conn_set.go.
// Each original source line is prefixed with its "NN |" number and the lines
// are collapsed onto long physical lines.
//
// pool/slice_test.go:
//   - checkBytes / checkInts report whether a pooled value has the expected
//     slice type. The Test*/test* pairs run shared assertions against a PoolI.
//   - testPoolConsist and testHugePool compare backing-array addresses
//     through reflect.SliceHeader after a Return/Borrow round trip.
//     NOTE(review): if NewSliceSyncPool is backed by sync.Pool, getting the
//     same buffer back is not guaranteed (sync.Pool may drop items), so the
//     "same underly buffer" check could be flaky. Confirm against slice.go.
//   - NOTE(review): text is missing right after "116 | if 1<". The dump
//     jumps from slice_test.go line 116 to line 6 of another file (the
//     "size >>= 1 / idx++ / return idx" loop, presumably floorlog2 in
//     pool/utils.go). The rest of testPoolEdgeCondition and the head of that
//     helper are not recoverable from this dump.
//
// proxy/auth.go:
//   - RandomBuf returns `size` bytes from crypto/rand and replaces every 0x00
//     with '0' (presumably so the salt never contains a NUL in the
//     handshake). When the package-level `debug` flag is true it returns
//     all-0x01 bytes instead.
//     NOTE(review): proxy/debug.go declares `var debug bool = true`, so the
//     auth salt is constant unless builds change it. Confirm release builds
//     disable it.
//   - CalcPassword returns nil for an empty password. Otherwise it returns
//     SHA1(password) XOR SHA1(scramble + SHA1(SHA1(password))), which is the
//     mysql_native_password token.
//   - CheckAuth runs these checks in order:
//       1. the global auth IP list (when AuthIPActive);
//       2. the user lookup;
//       3. the requested db;
//       4. the per-user auth IP list;
//       5. the password token.
//     It then switches to the user's db.
//     NOTE(review): the error from GetGlobalConfig is never checked. It is
//     overwritten by the GetUserByName assignment. If gc can be nil on
//     error, gc.AuthIPActive would panic.
//     NOTE(review): session.cliAddr is matched against the configured IPs by
//     exact string equality. If cliAddr includes a port it never matches;
//     verify its format.
//
// proxy/com_query_test.go:
//   - These are integration tests that send DML/DDL through the proxy via
//     newSqlDB(testProxyDSN).
//     NOTE(review): TestProxy_QueryWithInfo runs "CREATE TABLE test" without
//     IF NOT EXISTS and never drops it, so a second run against the same
//     database fails. It also ignores the error returned by res.Info().
//
// proxy/conn.go:
//   - SqlConn holds the master/slave handles, prepared statements and the
//     current transaction.
//   - commit/rollback set bc.tx = nil only when inAutoCommit is true. They do
//     so via defer, which also runs when Commit/Rollback returns an error.
//   - Executor returns bc.tx while the session status has StatusInTrans.
//     Otherwise it returns the slave for reads and the master for writes.
//     NOTE(review): if StatusInTrans is set while bc.tx is nil, the returned
//     Executor wraps a nil *mysql.Tx (a non-nil interface). Confirm callers
//     cannot reach that state.
//
// proxy/conn_resultset.go:
//   - formatValue renders a scalar as the text bytes of a result-set cell.
//     NOTE(review): the float32 case passes bitSize 64 to strconv.AppendFloat.
//     float32 values therefore print float64 widening artefacts, e.g.
//     float32(0.1) -> "0.10000000149011612"; bitSize 32 would give "0.1".
//   - formatField, buildResultset and writeResultset are commented out.
//
// proxy/conn_select.go:
//   - handleQuery sends a SELECT whose IsLocked() is false, or any SHOW, to
//     the read executor, and everything else to the write executor. It then
//     writes the rows back with writeRows.
//
// proxy/conn_set.go:
//   - handleSet handles AUTOCOMMIT items itself, then forwards the whole SET
//     statement via handleOtherSet. A deferred clearAutoCommitTx runs
//     afterwards whenever c.autoCommit == 1.
//     NOTE(review): err is overwritten per AUTOCOMMIT item, so only the last
//     item's error is returned. The deferred cleanup also runs when
//     handleOtherSet fails.
//   - clearAutoCommitTx:
//     NOTE(review): the Warnf format has one %d but is passed two arguments,
//     so err is not rendered properly.
//     NOTE(review): XORStatus presumably toggles the autocommit flag rather
//     than setting it.
//   - handleSetAutoCommit records transitions in c.autoCommit
//     (1 = "0 -> 1", 2 = "1 -> 0").
//     NOTE(review): the error from c.bc.begin is only logged, and
//     StatusInTrans is still toggled on.
//     NOTE(review): the "%s" verb is used with the int64 `i`.
//     NOTE(review): the 'ON'/'OFF' string forms only flip status flags (XOR
//     for ON) and do not begin or end the backend transaction, unlike the
//     numeric forms.
| 28 | func checkBytes(i interface{}) bool { 29 | _, ok := i.([]byte) 30 | return ok 31 | } 32 | 33 | func checkInts(i interface{}) bool { 34 | _, ok := i.([]int) 35 | return ok 36 | } 37 | 38 | func TestSyncPoolConsist(t *testing.T) { 39 | 40 | isSyncPool := NewSliceSyncPool( 41 | func(l int, c int) interface{} { return make([]int, l, c) }, 42 | checkInts, 43 | ) 44 | 45 | testPoolConsist(isSyncPool, t) 46 | } 47 | 48 | func testPoolConsist(isPool PoolI, t testing.TB) { 49 | 50 | contents := [8]int{1, 2, 3, 4, 5, 6, 7, 8} 51 | b := getInts(isPool.Borrow(0), t) 52 | b = append(b, contents[:]...) 53 | isPool.Return(b) 54 | 55 | nb := getInts(isPool.Borrow(8), t) 56 | 57 | if (*reflect.SliceHeader)(unsafe.Pointer(&nb)).Data != (*reflect.SliceHeader)(unsafe.Pointer(&b)).Data { 58 | t.Fatal("not the same underly buffer!") 59 | } 60 | } 61 | 62 | func TestSyncHugePool(t *testing.T) { 63 | isSyncPool := NewSliceSyncPool( 64 | func(l int, c int) interface{} { return make([]int, l, c) }, 65 | checkInts, 66 | ) 67 | 68 | testHugePool(isSyncPool, t) 69 | } 70 | 71 | func testHugePool(isPool PoolI, tb testing.TB) { 72 | b := getInts(isPool.Borrow(maxSliceSize+1), tb) 73 | isPool.Return(b) // should not pool this really big buffer 74 | 75 | nb := getInts(isPool.Borrow(maxSliceSize), tb) 76 | if (*reflect.SliceHeader)(unsafe.Pointer(&nb)).Data == (*reflect.SliceHeader)(unsafe.Pointer(&b)).Data { 77 | tb.Fatal("these two buffer should be different underly array!") 78 | } 79 | } 80 | 81 | func TestSyncPoolEdgeCondition(t *testing.T) { 82 | bsSyncPool := NewSliceSyncPool( 83 | func(l int, c int) interface{} { return make([]byte, l, c) }, 84 | checkBytes, 85 | ) 86 | 87 | testPoolEdgeCondition(bsSyncPool, t) 88 | } 89 | 90 | func testPoolEdgeCondition(bsSyncPool PoolI, t testing.TB) { 91 | 92 | for i := 1; i <= minSliceSize; i++ { 93 | s := bsSyncPool.Borrow(i) 94 | b := getBytes(s, t) 95 | 96 | if len(b) != i { 97 | t.Fatal("len:", len(b), "not match required size:", i) 98 | } 99 
| 100 | if cap(b) != minSliceSize { 101 | t.Fatal("cap:", cap(b), "not match minSliceSize:", minSliceSize) 102 | } 103 | 104 | bsSyncPool.Return(b) 105 | } 106 | 107 | for i := minSliceSize + 1; i <= maxSliceSize; i++ { 108 | s := bsSyncPool.Borrow(i) 109 | b := getBytes(s, t) 110 | 111 | if len(b) != i { 112 | t.Fatal("len:", len(b), "not match required size:", i) 113 | } 114 | 115 | fl := floorlog2(uint(i)) 116 | if 1< 1 { 6 | size >>= 1 7 | idx++ 8 | } 9 | return idx 10 | } 11 | -------------------------------------------------------------------------------- /proxy/auth.go: -------------------------------------------------------------------------------- 1 | package proxy 2 | 3 | import ( 4 | "bytes" 5 | "crypto/rand" 6 | "crypto/sha1" 7 | "io" 8 | 9 | . "github.com/bytedance/dbatman/database/mysql" 10 | "github.com/ngaut/log" 11 | ) 12 | 13 | func RandomBuf(size int) ([]byte, error) { 14 | 15 | buf := make([]byte, size) 16 | if debug { 17 | for i, _ := range buf { 18 | buf[i] = 0x01 19 | } 20 | 21 | return buf, nil 22 | } 23 | 24 | if _, err := io.ReadFull(rand.Reader, buf); err != nil { 25 | return nil, err 26 | } 27 | 28 | for i, b := range buf { 29 | if uint8(b) == 0 { 30 | buf[i] = '0' 31 | } 32 | } 33 | return buf, nil 34 | } 35 | 36 | func CalcPassword(scramble, password []byte) []byte { 37 | if len(password) == 0 { 38 | return nil 39 | } 40 | 41 | // stage1Hash = SHA1(password) 42 | crypt := sha1.New() 43 | crypt.Write(password) 44 | stage1 := crypt.Sum(nil) 45 | 46 | // scrambleHash = SHA1(scramble + SHA1(stage1Hash)) 47 | // inner Hash 48 | crypt.Reset() 49 | crypt.Write(stage1) 50 | hash := crypt.Sum(nil) 51 | 52 | // outer Hash 53 | crypt.Reset() 54 | crypt.Write(scramble) 55 | crypt.Write(hash) 56 | scramble = crypt.Sum(nil) 57 | 58 | // token = scrambleHash XOR stage1Hash 59 | for i := range scramble { 60 | scramble[i] ^= stage1[i] 61 | } 62 | return scramble 63 | } 64 | 65 | func (session *Session) CheckAuth(username string, passwd []byte, db
string) error { 66 | 67 | var err error 68 | //check the global authip 69 | gc, err := session.config.GetGlobalConfig() 70 | cliAddr := session.cliAddr 71 | if gc.AuthIPActive == true { 72 | if len(gc.AuthIPs) > 0 { 73 | //TODO white and black ip logic 74 | globalAuthIp := &gc.AuthIPs 75 | authIpFlag := false 76 | for _, ip := range *globalAuthIp { 77 | if ip == cliAddr { 78 | authIpFlag = true 79 | break 80 | } 81 | } 82 | 83 | if authIpFlag != true { 84 | // log.Info("This user's Ip is not in the list of User's auth_Ip") 85 | return NewDefaultError(ER_NO, "IP Is not in the auth_ip list of the global config") 86 | } 87 | 88 | } 89 | } 90 | //global auth pass 91 | // There is no user named with parameter username 92 | if session.user, err = session.config.GetUserByName(username); err != nil { 93 | if session.user == nil { 94 | return NewDefaultError(ER_ACCESS_DENIED_ERROR, username, session.fc.RemoteAddr().String(), "Yes") 95 | } 96 | return NewDefaultError(ER_ACCESS_DENIED_ERROR, session.user.Username, session.fc.RemoteAddr().String(), "Yes") 97 | } 98 | 99 | if db != "" && session.user.DBName != db { 100 | log.Debugf("request db: %s, user's db: %s", db, session.user.DBName) 101 | return NewDefaultError(ER_BAD_DB_ERROR, db) 102 | } 103 | //TODO add the IP auth module to check the global auth_ip 104 | //check user config auth_IP with current Session Ip 105 | if gc.AuthIPActive == true { 106 | authIpFlag := false 107 | if len(session.user.AuthIPs) > 0 { 108 | userIPs := &session.user.AuthIPs 109 | 110 | // log.Debug("client IP : ", session.cliAddr) 111 | // log.Debug("User's Auth IP is: ", userIPs) 112 | for _, ip := range *userIPs { 113 | if cliAddr == ip { 114 | authIpFlag = true 115 | break 116 | } 117 | } 118 | } 119 | if authIpFlag != true { 120 | log.Debug("This user's Ip is not in the list of User's auth_Ip") 121 | 122 | return NewDefaultError(ER_NO, "IP Is not in the auth_ip list of the user") 123 | } 124 | 125 | } 126 | if !bytes.Equal(passwd, 
CalcPassword(session.salt, []byte(session.user.Password))) { 127 | return NewDefaultError(ER_ACCESS_DENIED_ERROR, session.user.Username, session.fc.RemoteAddr().String(), "Yes") 128 | } 129 | if err := session.useDB(session.user.DBName); err != nil { 130 | return err 131 | } 132 | 133 | return nil 134 | } 135 | -------------------------------------------------------------------------------- /proxy/com_query_test.go: -------------------------------------------------------------------------------- 1 | package proxy 2 | 3 | import ( 4 | "testing" 5 | ) 6 | 7 | func TestProxy_Query(t *testing.T) { 8 | 9 | db := newSqlDB(testProxyDSN) 10 | defer db.Close() 11 | 12 | if rs, err := db.Exec(` 13 | CREATE TABLE IF NOT EXISTS go_proxy_test_proxy_conn ( 14 | id BIGINT(64) UNSIGNED NOT NULL, 15 | str VARCHAR(256), 16 | f DOUBLE, 17 | e enum("test1", "test2"), 18 | u tinyint unsigned, 19 | i tinyint, 20 | ni tinyint, 21 | PRIMARY KEY (id) 22 | ) ENGINE=InnoDB DEFAULT CHARSET=utf8`); err != nil { 23 | t.Fatal("create table failed: ", err) 24 | } else if rows, err := rs.RowsAffected(); err != nil { 25 | t.Fatal("create table failed: ", err) 26 | } else if rows != 0 { 27 | t.Fatal("ddl should have no affected rows") 28 | } 29 | 30 | if rs, err := db.Exec(` 31 | insert into go_proxy_test_proxy_conn (id, str, f, e, u, i) values( 32 | 1, 33 | "abc", 34 | 3.14, 35 | "test1", 36 | 255, 37 | -127)`); err != nil { 38 | t.Fatal("insert failed: ", err) 39 | } else if rows, err := rs.RowsAffected(); err != nil { 40 | t.Fatal("insert failed: ", err) 41 | } else if rows != 1 { 42 | t.Fatalf("expect insert 1 rows, got %d", rows) 43 | } 44 | 45 | if rs, err := db.Exec(` 46 | update go_proxy_test_proxy_conn 47 | set str="abcde", f=3.1415926, e="test2", u=128, i=126 48 | where id=1`); err != nil { 49 | t.Fatal("update failed: ", err) 50 | } else if rows, err := rs.RowsAffected(); err != nil { 51 | t.Fatal("update failed: ", err) 52 | } else if rows != 1 { 53 | t.Fatalf("expect update 1 rows, got 
%d", rows) 54 | } 55 | 56 | if rs, err := db.Exec(` 57 | insert into go_proxy_test_proxy_conn (id, str, f, e, u, i) values( 58 | 2, 59 | "abc", 60 | 3.14, 61 | "test1", 62 | 255, 63 | -127)`); err != nil { 64 | t.Fatal("insert failed: ", err) 65 | } else if rows, err := rs.RowsAffected(); err != nil { 66 | t.Fatal("insert failed: ", err) 67 | } else if rows != 1 { 68 | t.Fatalf("expect insert 1 rows, got %d", rows) 69 | } 70 | 71 | if rs, err := db.Exec(`delete from go_proxy_test_proxy_conn where id = 1 or id = 2`); err != nil { 72 | t.Fatal("delete failed: ", err) 73 | } else if rows, err := rs.RowsAffected(); err != nil { 74 | t.Fatal("delete failed: ", err) 75 | } else if rows != 2 { 76 | t.Fatalf("expect delete 2 rows, got %d", rows) 77 | } 78 | } 79 | 80 | func TestProxy_QueryFailed(t *testing.T) { 81 | 82 | db := newSqlDB(testProxyDSN) 83 | defer db.Close() 84 | 85 | if _, err := db.Exec(` 86 | update go_proxy_test_proxy_conn 87 | set str="abcde", f=3.1415926, e="test2", u=128, i=255 88 | when id=1`); err == nil { 89 | t.Fatal("syntax error sql expect error, but go ok") 90 | } 91 | } 92 | 93 | func TestProxy_QueryWithInfo(t *testing.T) { 94 | 95 | db := newSqlDB(testProxyDSN) 96 | defer db.Close() 97 | 98 | if _, err := db.Exec(` 99 | CREATE TABLE test (a INT, b INT, c INT, UNIQUE (A), UNIQUE(B))`); err != nil { 100 | t.Fatalf("create table failed: %s", err) 101 | } 102 | 103 | res, err := db.Exec("INSERT test VALUES (1,2,10), (3,4,20)") 104 | if err != nil { 105 | t.Fatalf("insert table failed: %s", err) 106 | } 107 | 108 | count, err := res.RowsAffected() 109 | if err != nil { 110 | t.Fatalf("res.RowsAffected() returned error: %s", err.Error()) 111 | } 112 | if count != 2 { 113 | t.Fatalf("expected 2 affected row, got %d", count) 114 | } 115 | 116 | // Create Data With Duplicate 117 | res, err = db.Exec("INSERT test VALUES (5,6,30), (7,4,40), (8,9,60) ON DUPLICATE KEY UPDATE c=c+100;") 118 | if err != nil { 119 | t.Fatalf("insert table failed: %s", err) 120 
| } 121 | 122 | count, err = res.RowsAffected() 123 | if err != nil { 124 | t.Fatalf("res.RowsAffected() returned error: %s", err.Error()) 125 | } 126 | if count != 4 { 127 | t.Fatalf("expected 4 affected row, got %d", count) 128 | } 129 | 130 | info, _ := res.Info() 131 | 132 | if len(info) == 0 { 133 | t.Fatal("expected duplicate message, got empty string") 134 | } 135 | 136 | } 137 | 138 | func TestProxy_Use(t *testing.T) { 139 | 140 | db := newSqlDB(testProxyDSN) 141 | defer db.Close() 142 | 143 | if _, err := db.Exec("use dbatman_test"); err != nil { 144 | t.Fatalf("use dbatman_test failed: %s", err.Error()) 145 | } 146 | 147 | if _, err := db.Exec("use mysql"); err == nil { 148 | t.Fatalf("use mysql for this user expect deny, got pass") 149 | } 150 | } 151 | -------------------------------------------------------------------------------- /proxy/conn.go: -------------------------------------------------------------------------------- 1 | package proxy 2 | 3 | import ( 4 | "errors" 5 | 6 | "github.com/bytedance/dbatman/database/mysql" 7 | "github.com/bytedance/dbatman/database/sql/driver" 8 | ) 9 | 10 | // Wrap the connection 11 | type SqlConn struct { 12 | master *mysql.DB 13 | slave *mysql.DB 14 | stmts map[uint32]*mysql.Stmt 15 | tx *mysql.Tx 16 | 17 | session *Session 18 | } 19 | 20 | func (bc *SqlConn) begin(s *Session) error { 21 | if bc.tx != nil { 22 | return errors.New("duplicate begin") 23 | } 24 | 25 | var err error 26 | var s_i driver.SessionI = s 27 | bc.tx, err = bc.master.Begin(s_i) 28 | if err != nil { 29 | return err 30 | } 31 | 32 | return nil 33 | } 34 | 35 | func (bc *SqlConn) commit(inAutoCommit bool) error { 36 | if bc.tx == nil { 37 | return errors.New("unexpect commit") 38 | } 39 | 40 | defer func() { 41 | if inAutoCommit { 42 | bc.tx = nil 43 | } 44 | }() 45 | 46 | if err := bc.tx.Commit(inAutoCommit); err != nil { 47 | // fmt.Println("commit err :", err) 48 | return err 49 | } 50 | 51 | return nil 52 | } 53 | 54 | func (bc *SqlConn) 
rollback(inAutoCommit bool) error { 55 | if bc.tx == nil { 56 | return errors.New("unexpect rollback") 57 | } 58 | 59 | defer func() { 60 | if inAutoCommit { 61 | bc.tx = nil 62 | } 63 | }() 64 | 65 | if err := bc.tx.Rollback(inAutoCommit); err != nil { 66 | return err 67 | } 68 | 69 | return nil 70 | } 71 | 72 | func (session *Session) Executor(isread bool) mysql.Executor { 73 | 74 | // TODO set autocommit 75 | if session.isInTransaction() { 76 | return session.bc.tx 77 | } 78 | 79 | if isread { 80 | return session.bc.slave 81 | } 82 | 83 | return session.bc.master 84 | } 85 | -------------------------------------------------------------------------------- /proxy/conn_resultset.go: -------------------------------------------------------------------------------- 1 | package proxy 2 | 3 | import ( 4 | "fmt" 5 | "github.com/bytedance/dbatman/hack" 6 | "strconv" 7 | ) 8 | 9 | func formatValue(value interface{}) ([]byte, error) { 10 | switch v := value.(type) { 11 | case int8: 12 | return strconv.AppendInt(nil, int64(v), 10), nil 13 | case int16: 14 | return strconv.AppendInt(nil, int64(v), 10), nil 15 | case int32: 16 | return strconv.AppendInt(nil, int64(v), 10), nil 17 | case int64: 18 | return strconv.AppendInt(nil, int64(v), 10), nil 19 | case int: 20 | return strconv.AppendInt(nil, int64(v), 10), nil 21 | case uint8: 22 | return strconv.AppendUint(nil, uint64(v), 10), nil 23 | case uint16: 24 | return strconv.AppendUint(nil, uint64(v), 10), nil 25 | case uint32: 26 | return strconv.AppendUint(nil, uint64(v), 10), nil 27 | case uint64: 28 | return strconv.AppendUint(nil, uint64(v), 10), nil 29 | case uint: 30 | return strconv.AppendUint(nil, uint64(v), 10), nil 31 | case float32: 32 | return strconv.AppendFloat(nil, float64(v), 'f', -1, 64), nil 33 | case float64: 34 | return strconv.AppendFloat(nil, float64(v), 'f', -1, 64), nil 35 | case []byte: 36 | return v, nil 37 | case string: 38 | return hack.Slice(v), nil 39 | default: 40 | return nil, fmt.Errorf("invalid 
type %T", value) 41 | } 42 | } 43 | 44 | /* 45 | func formatField(field *Field, value interface{}) error { 46 | switch value.(type) { 47 | case int8, int16, int32, int64, int: 48 | field.Charset = 63 49 | field.Type = MYSQL_TYPE_LONGLONG 50 | field.Flag = BINARY_FLAG | NOT_NULL_FLAG 51 | case uint8, uint16, uint32, uint64, uint: 52 | field.Charset = 63 53 | field.Type = MYSQL_TYPE_LONGLONG 54 | field.Flag = BINARY_FLAG | NOT_NULL_FLAG | UNSIGNED_FLAG 55 | case string, []byte: 56 | field.Charset = 33 57 | field.Type = MYSQL_TYPE_VAR_STRING 58 | default: 59 | return fmt.Errorf("unsupport type %T for resultset", value) 60 | } 61 | return nil 62 | } 63 | 64 | func (c *Session) buildResultset(names []string, values [][]interface{}) (*ResultSet, error) { 65 | r := new(Resultset) 66 | 67 | r.Fields = make([]*Field, len(names)) 68 | 69 | var b []byte 70 | var err error 71 | 72 | for i, vs := range values { 73 | if len(vs) != len(r.Fields) { 74 | return nil, fmt.Errorf("row %d has %d column not equal %d", i, len(vs), len(r.Fields)) 75 | } 76 | 77 | var row []byte 78 | for j, value := range vs { 79 | if i == 0 { 80 | field := &Field{} 81 | r.Fields[j] = field 82 | field.Name = hack.Slice(names[j]) 83 | 84 | if err = formatField(field, value); err != nil { 85 | return nil, err 86 | } 87 | } 88 | b, err = formatValue(value) 89 | 90 | if err != nil { 91 | return nil, err 92 | } 93 | 94 | row = append(row, PutLengthEncodedString(b)...) 95 | } 96 | 97 | r.RowDatas = append(r.RowDatas, row) 98 | } 99 | 100 | return r, nil 101 | } 102 | */ 103 | 104 | /* 105 | func (c *Session) writeResultset(status uint16, r *ResultSet) error { 106 | c.affectedRows = int64(-1) 107 | 108 | columnLen := PutLengthEncodedInt(uint64(len(r.Fields))) 109 | 110 | data := make([]byte, 4, 1024) 111 | 112 | data = append(data, columnLen...) 
113 | if err := c.writePacket(data); err != nil { 114 | return err 115 | } 116 | 117 | for _, v := range r.Fields { 118 | data = data[0:4] 119 | data = append(data, v.Dump()...) 120 | if err := c.writePacket(data); err != nil { 121 | return err 122 | } 123 | } 124 | 125 | if err := c.writeEOF(status); err != nil { 126 | return err 127 | } 128 | 129 | for _, v := range r.RowDatas { 130 | data = data[0:4] 131 | data = append(data, v...) 132 | if err := c.writePacket(data); err != nil { 133 | return err 134 | } 135 | } 136 | 137 | if err := c.writeEOF(status); err != nil { 138 | return err 139 | } 140 | 141 | return nil 142 | } 143 | */ 144 | -------------------------------------------------------------------------------- /proxy/conn_select.go: -------------------------------------------------------------------------------- 1 | package proxy 2 | 3 | import ( 4 | "github.com/bytedance/dbatman/parser" 5 | "github.com/ngaut/log" 6 | ) 7 | 8 | func (session *Session) handleQuery(stmt parser.IStatement, sqlstmt string) error { 9 | 10 | if err := session.checkDB(stmt); err != nil { 11 | log.Debugf("check db error: %s", err.Error()) 12 | return err 13 | } 14 | 15 | isread := false 16 | 17 | if s, ok := stmt.(parser.ISelect); ok { 18 | isread = !s.IsLocked() 19 | } else if _, sok := stmt.(parser.IShow); sok { 20 | isread = true 21 | } 22 | 23 | rs, err := session.Executor(isread).Query(sqlstmt) 24 | // TODO here should handler error 25 | if err != nil { 26 | return session.handleMySQLError(err) 27 | } 28 | 29 | defer rs.Close() 30 | return session.writeRows(rs) 31 | } 32 | -------------------------------------------------------------------------------- /proxy/conn_set.go: -------------------------------------------------------------------------------- 1 | package proxy 2 | 3 | import ( 4 | "fmt" 5 | 6 | "strings" 7 | 8 | . 
"github.com/bytedance/dbatman/database/mysql" 9 | "github.com/bytedance/dbatman/parser" 10 | "github.com/ngaut/log" 11 | ) 12 | 13 | func (c *Session) handleSet(stmt *parser.Set, sql string) error { 14 | if len(stmt.VarList) < 1 { 15 | return fmt.Errorf("must set one item at least") 16 | } 17 | 18 | var err error 19 | for _, v := range stmt.VarList { 20 | if strings.ToUpper(v.Name) == "AUTOCOMMIT" { 21 | log.Debug("handle autocommit") 22 | err = c.handleSetAutoCommit(v.Value) //?? 23 | } 24 | } 25 | 26 | if err != nil { 27 | return err 28 | } 29 | 30 | defer func() { 31 | //only execute when the autocommit 0->1 //clear 32 | if c.autoCommit == 1 { 33 | log.Debug("clear autocommit tx") 34 | c.clearAutoCommitTx() 35 | } 36 | 37 | }() 38 | return c.handleOtherSet(stmt, sql) 39 | } 40 | 41 | func (c *Session) clearAutoCommitTx() { 42 | // clear the AUTOCOMMIT status 43 | if _, err := c.bc.tx.Exec("set autocommit = 1"); err != nil { 44 | log.Warnf("session id :%d,clear autocommit errr", c.sessionId, err) 45 | //don;t need to put conn back 46 | return 47 | } 48 | c.fc.XORStatus(uint16(StatusInAutocommit)) 49 | //clear the backend conn's Tx status; 50 | //put conn back to free conn 51 | if err := c.bc.rollback(c.isAutoCommit()); err != nil { 52 | log.Warnf("session %d clear autocommit err:%s: ", c.sessionId, err.Error()) 53 | } 54 | c.fc.AndStatus(^uint16(StatusInTrans)) 55 | c.autoCommit = 0 56 | } 57 | 58 | func (c *Session) handleSetAutoCommit(val parser.IExpr) error { 59 | 60 | var stmt *parser.Predicate 61 | var ok bool 62 | if stmt, ok = val.(*parser.Predicate); !ok { 63 | return fmt.Errorf("set autocommit is not support for complicate expressions") 64 | } 65 | 66 | switch value := stmt.Expr.(type) { 67 | case parser.NumVal: 68 | if i, err := value.ParseInt(); err != nil { 69 | return err 70 | } else if i == 1 { 71 | // 72 | if c.isAutoCommit() { 73 | return nil 74 | } 75 | 76 | //inply the tx cleanUp step after last query c.handleOtherSet(stmt, sql) 77 
| c.autoCommit = 1 //indicate 0 -> 1 78 | //TODO when previous handle error need 79 | 80 | log.Debug("autocommit is set") 81 | } else if i == 0 { 82 | // indicate a transection 83 | //current is autocommit = true do nothing 84 | if !c.isAutoCommit() { 85 | return nil 86 | } 87 | c.fc.AndStatus(^uint16(StatusInAutocommit)) 88 | ////atuocommit 1->0 start a transection 89 | err := c.bc.begin(c) 90 | if err != nil { 91 | log.Debug(err) 92 | } 93 | c.fc.XORStatus(uint16(StatusInTrans)) 94 | c.autoCommit = 2 // indicate 1 -> zero 95 | // log.Debug("start a transection") 96 | // log.Debug("auto commit is unset") 97 | } else { 98 | return fmt.Errorf("Variable 'autocommit' can't be set to the value of '%s'", i) 99 | } 100 | case parser.StrVal: 101 | if s := value.Trim(); s == "" { 102 | return fmt.Errorf("Variable 'autocommit' can't be set to the value of ''") 103 | } else if us := strings.ToUpper(s); us == `ON` { 104 | c.fc.XORStatus(uint16(StatusInAutocommit)) 105 | log.Debug("auto commit is set") 106 | // return c.handleBegin() 107 | } else if us == `OFF` { 108 | c.fc.AndStatus(^uint16(StatusInAutocommit)) 109 | log.Debug("auto commit is unset") 110 | } else { 111 | return fmt.Errorf("Variable 'autocommit' can't be set to the value of '%s'", us) 112 | } 113 | default: 114 | return fmt.Errorf("set autocommit error, value type is %T", val) 115 | } 116 | 117 | return nil 118 | } 119 | 120 | func (c *Session) handleSetNames(val parser.IValExpr) error { 121 | value, ok := val.(parser.StrVal) 122 | if !ok { 123 | return fmt.Errorf("set names charset error") 124 | } 125 | 126 | charset := strings.ToLower(string(value)) 127 | cid, ok := CharsetIds[charset] 128 | if !ok { 129 | return fmt.Errorf("invalid charset %s", charset) 130 | } 131 | 132 | c.fc.SetCollation(cid) 133 | 134 | return c.fc.WriteOK(nil) 135 | } 136 | 137 | func (c *Session) handleOtherSet(stmt parser.IStatement, sql string) error { 138 | return c.handleExec(stmt, sql, false) 139 | } 140 |
-------------------------------------------------------------------------------- /proxy/conn_show.go: -------------------------------------------------------------------------------- 1 | package proxy 2 | 3 | import ( 4 | "bytes" 5 | 6 | "github.com/bytedance/dbatman/database/mysql" 7 | "github.com/bytedance/dbatman/hack" 8 | "github.com/bytedance/dbatman/parser" 9 | "github.com/ngaut/log" 10 | ) 11 | 12 | func (session *Session) handleShow(sqlstmt string, stmt parser.IShow) error { 13 | var err error 14 | 15 | switch stmt.(type) { 16 | case *parser.ShowDatabases: 17 | err = session.handleShowDatabases() 18 | default: 19 | err = session.handleQuery(stmt, sqlstmt) 20 | } 21 | 22 | if err != nil { 23 | return session.handleMySQLError(err) 24 | } 25 | 26 | return nil 27 | } 28 | 29 | func (session *Session) handleFieldList(data []byte) error { 30 | index := bytes.IndexByte(data, 0x00) 31 | table := string(data[0:index]) 32 | wildcard := string(data[index+1:]) 33 | 34 | rs, err := session.bc.master.FieldList(table, wildcard) 35 | // TODO here should handler error 36 | if err != nil { 37 | return session.handleMySQLError(err) 38 | } 39 | 40 | defer rs.Close() 41 | 42 | return session.writeFieldList(rs) 43 | } 44 | 45 | func (session *Session) writeFieldList(rs mysql.Rows) error { 46 | 47 | cols, err := rs.ColumnPackets() 48 | 49 | if err != nil { 50 | return session.handleMySQLError(err) 51 | } 52 | 53 | // Write Columns Packet 54 | for _, col := range cols { 55 | if err := session.fc.WritePacket(col); err != nil { 56 | log.Debugf("write columns packet error %v", err) 57 | return err 58 | } 59 | } 60 | 61 | // TODO Write a ok packet 62 | if err = session.fc.WriteEOF(); err != nil { 63 | return err 64 | } 65 | 66 | return nil 67 | } 68 | 69 | func (session *Session) handleShowDatabases() error { 70 | dbs := make([]interface{}, 0, 1) 71 | dbs = append(dbs, session.user.DBName) 72 | 73 | if r, err := session.buildSimpleShowResultset(dbs, "Database"); err != nil { 74 | 
return err 75 | } else { 76 | return session.writeRows(r) 77 | } 78 | } 79 | 80 | func (session *Session) buildSimpleShowResultset(values []interface{}, name string) (mysql.Rows, error) { 81 | 82 | r := new(SimpleRows) 83 | 84 | r.Cols = []*mysql.MySQLField{ 85 | &mysql.MySQLField{ 86 | Name: hack.Slice(name), 87 | Charset: uint16(session.fc.Collation()), 88 | FieldType: mysql.FieldTypeVarString, 89 | }, 90 | } 91 | 92 | var row []byte 93 | var err error 94 | 95 | for _, value := range values { 96 | row, err = formatValue(value) 97 | if err != nil { 98 | return nil, err 99 | } 100 | 101 | r.Rows = append(r.Rows, mysql.AppendLengthEncodedString(make([]byte, 0, len(row)+9), row)) 102 | } 103 | 104 | return r, nil 105 | } 106 | -------------------------------------------------------------------------------- /proxy/conn_show_test.go: -------------------------------------------------------------------------------- 1 | package proxy 2 | 3 | import ( 4 | "testing" 5 | ) 6 | 7 | func TestProxy_Show(t *testing.T) { 8 | 9 | db := newSqlDB(testProxyDSN) 10 | defer db.Close() 11 | 12 | if q, err := db.Query("show databases"); err != nil { 13 | t.Fatalf("show databases failed: %s", err.Error()) 14 | } else { 15 | q.Next() 16 | var database string 17 | if err := q.Scan(&database); err != nil { 18 | t.Fatalf("show databases got error %s", err) 19 | } else if database != "dbatman_test" { 20 | t.Fatalf("expect %s, got %s", "dbatman", database) 21 | } 22 | } 23 | 24 | if _, err := db.Query("show tables"); err != nil { 25 | t.Fatalf("show tables failed: %s", err.Error()) 26 | } 27 | } 28 | -------------------------------------------------------------------------------- /proxy/conn_tx.go: -------------------------------------------------------------------------------- 1 | package proxy 2 | 3 | import . 
"github.com/bytedance/dbatman/database/mysql" 4 | 5 | func (c *Session) isInTransaction() bool { 6 | return c.fc.Status()&uint16(StatusInTrans) > 0 7 | } 8 | 9 | func (c *Session) isAutoCommit() bool { 10 | return c.fc.Status()&uint16(StatusInAutocommit) > 0 11 | } 12 | 13 | func (c *Session) handleBegin() error { 14 | 15 | // We already in transaction 16 | if c.isInTransaction() { 17 | return c.fc.WriteOK(nil) 18 | } 19 | 20 | c.fc.XORStatus(uint16(StatusInTrans)) 21 | if err := c.bc.begin(c); err != nil { 22 | return c.handleMySQLError(err) 23 | } 24 | 25 | return c.fc.WriteOK(nil) 26 | } 27 | 28 | func (c *Session) handleCommit() (err error) { 29 | 30 | if !c.isInTransaction() { 31 | return c.fc.WriteOK(nil) 32 | } 33 | 34 | defer func() { 35 | if c.isInTransaction() { 36 | if c.isAutoCommit() { 37 | c.fc.AndStatus(uint16(^StatusInTrans)) 38 | // fmt.Println("close the proxy tx") 39 | } 40 | } 41 | }() 42 | 43 | // fmt.Println("commit") 44 | // fmt.Println("this is a autocommit tx:", !c.isAutoCommit()) 45 | if err := c.bc.commit(c.isAutoCommit()); err != nil { 46 | return c.handleMySQLError(err) 47 | } else { 48 | return c.fc.WriteOK(nil) 49 | } 50 | } 51 | 52 | func (c *Session) handleRollback() (err error) { 53 | if !c.isInTransaction() { 54 | return c.fc.WriteOK(nil) 55 | } 56 | 57 | defer func() { 58 | if c.isInTransaction() { 59 | if c.isAutoCommit() { 60 | c.fc.AndStatus(uint16(^StatusInTrans)) 61 | // fmt.Println("close the proxy tx") 62 | } 63 | } 64 | }() 65 | // fmt.Println("rollback") 66 | // fmt.Println("this is a autocommit tx:", !c.isAutoCommit()) 67 | if err := c.bc.rollback(c.isAutoCommit()); err != nil { 68 | return c.handleMySQLError(err) 69 | } 70 | 71 | return c.fc.WriteOK(nil) 72 | } 73 | -------------------------------------------------------------------------------- /proxy/conn_tx_test.go: -------------------------------------------------------------------------------- 1 | package proxy 2 | 3 | import ( 4 | "testing" 5 | ) 6 | 7 | var 
inAutoCommit bool = true 8 | 9 | func TestProxy_Tx(t *testing.T) { 10 | db := newSqlDB(testProxyDSN) 11 | if _, err := db.Exec(` 12 | CREATE TABLE IF NOT EXISTS dbatman_test_tx ( 13 | id BIGINT(64) UNSIGNED NOT NULL, 14 | str VARCHAR(256), 15 | PRIMARY KEY (id) 16 | ) ENGINE=InnoDB DEFAULT CHARSET=utf8`); err != nil { 17 | t.Fatal("create tx table failed: ", err) 18 | } 19 | 20 | tx, err := db.Begin() 21 | if err != nil { 22 | t.Fatalf("start transaction failed: %s", err) 23 | } 24 | 25 | if rs, err := tx.Exec(`insert into dbatman_test_tx values( 26 | 1, 27 | "abc")`); err != nil { 28 | tx.Rollback(inAutoCommit) 29 | t.Fatalf("insert in transaction failed: %s", err) 30 | } else if rn, err := rs.RowsAffected(); err != nil { 31 | tx.Rollback(inAutoCommit) 32 | t.Fatalf("insert failed: %s", err) 33 | } else if rn != 1 { 34 | tx.Rollback(inAutoCommit) 35 | t.Fatalf("expect 1 rows, got %d", rn) 36 | } 37 | if _, err := tx.Exec(`savepoint a1`); err != nil { 38 | t.Fatalf("save point faied", err) 39 | } 40 | if rs, err := tx.Query("select * from dbatman_test_tx"); err != nil { 41 | t.Fatalf("select in trans failed: %s", err) 42 | } else { 43 | var row int 44 | for rs.Next() { 45 | row += 1 46 | } 47 | 48 | if row != 1 { 49 | t.Fatalf("expect 1 rows after transaction, got %d", row) 50 | } 51 | } 52 | 53 | if rs, err := tx.Exec(`insert into dbatman_test_tx values( 54 | 2,'def')`); err != nil { 55 | tx.Rollback(inAutoCommit) 56 | t.Fatalf("insert in transaction failed: %s", err) 57 | } else if rn, err := rs.RowsAffected(); err != nil { 58 | tx.Rollback(inAutoCommit) 59 | t.Fatalf("insert failed: %s", err) 60 | } else if rn != 1 { 61 | tx.Rollback(inAutoCommit) 62 | t.Fatalf("expect 1 rows, got %d", rn) 63 | } 64 | if rs, err := tx.Query("select * from dbatman_test_tx"); err != nil { 65 | t.Fatalf("select in trans failed: %s", err) 66 | } else { 67 | var row int 68 | for rs.Next() { 69 | row += 1 70 | } 71 | 72 | if row != 2 { 73 | t.Fatalf("expect 2 rows after transaction, 
got %d", row) 74 | } 75 | } 76 | 77 | if _, err := tx.Exec(`rollback to a1`); err != nil { 78 | t.Fatalf("rollback to faild", err) 79 | } 80 | 81 | if rs, err := tx.Query("select * from dbatman_test_tx"); err != nil { 82 | t.Fatalf("select in trans failed: %s", err) 83 | } else { 84 | var row int 85 | for rs.Next() { 86 | row += 1 87 | } 88 | 89 | if row != 1 { 90 | t.Fatalf("expect 0 rows after transaction, got %d", row) 91 | } 92 | } 93 | // add savepoint 94 | if err := tx.Rollback(inAutoCommit); err != nil { 95 | t.Fatalf("rollback in trans failed: %s", err) 96 | } 97 | tx, err = db.Begin() 98 | 99 | if err := tx.Rollback(inAutoCommit); err != nil { 100 | t.Fatalf("rollback in trans failed: %s", err) 101 | } 102 | 103 | if rs, err := db.Query("select * from dbatman_test_tx"); err != nil { 104 | t.Fatalf("select after trans failed: %s", err) 105 | } else { 106 | var row int 107 | for rs.Next() { 108 | row += 1 109 | } 110 | 111 | if row > 0 { 112 | t.Fatalf("expect none rows after transaction, got %d", row) 113 | } 114 | } 115 | 116 | tx, err = db.Begin() 117 | if err != nil { 118 | t.Fatalf("start transaction failed: %s", err) 119 | } 120 | 121 | if rs, err := tx.Exec(`insert into dbatman_test_tx values( 122 | 1, 123 | "abc")`); err != nil { 124 | tx.Rollback(inAutoCommit) 125 | t.Fatalf("insert in transaction failed: %s", err) 126 | } else if rn, err := rs.RowsAffected(); err != nil { 127 | tx.Rollback(inAutoCommit) 128 | t.Fatalf("insert failed: %s", err) 129 | } else if rn != 1 { 130 | tx.Rollback(inAutoCommit) 131 | t.Fatalf("expect 1 rows, got %d", rn) 132 | } 133 | 134 | if rs, err := tx.Query("select * from dbatman_test_tx"); err != nil { 135 | t.Fatalf("select in trans failed: %s", err) 136 | } else { 137 | var row int 138 | for rs.Next() { 139 | row += 1 140 | } 141 | 142 | if row != 1 { 143 | t.Fatalf("expect 1 rows after transaction, got %d", row) 144 | } 145 | } 146 | 147 | if err := tx.Commit(inAutoCommit); err != nil { 148 | t.Fatalf("commit in 
trans failed: %s", err) 149 | } 150 | 151 | if rs, err := db.Query("select * from dbatman_test_tx"); err != nil { 152 | t.Fatalf("select after trans failed: %s", err) 153 | } else { 154 | var row int 155 | for rs.Next() { 156 | row += 1 157 | } 158 | 159 | if row != 1 { 160 | t.Fatalf("expect 1 rows after transaction, got %d", row) 161 | } 162 | } 163 | } 164 | -------------------------------------------------------------------------------- /proxy/debug.go: -------------------------------------------------------------------------------- 1 | package proxy 2 | 3 | import ( 4 | "io/ioutil" 5 | ) 6 | 7 | var debug bool = true 8 | 9 | func tmpFile(content []byte) (string, error) { 10 | 11 | tmpfile, err := ioutil.TempFile("", "tmp") 12 | if err != nil { 13 | return "", err 14 | } 15 | 16 | if _, err := tmpfile.Write(content); err != nil { 17 | return "", err 18 | } 19 | if err := tmpfile.Close(); err != nil { 20 | return "", err 21 | } 22 | 23 | return tmpfile.Name(), nil 24 | } 25 | -------------------------------------------------------------------------------- /proxy/dispatch.go: -------------------------------------------------------------------------------- 1 | // Licensed under the Apache License, Version 2.0 (the "License"); 2 | // you may not use this file except in compliance with the License. 3 | // You may obtain a copy of the License at 4 | // 5 | // http://www.apache.org/licenses/LICENSE-2.0 6 | // 7 | // Unless required by applicable law or agreed to in writing, software 8 | // distributed under the License is distributed on an "AS IS" BASIS, 9 | // See the License for the specific language governing permissions and 10 | // limitations under the License. 
11 | 12 | package proxy 13 | 14 | import ( 15 | "fmt" 16 | "io" 17 | 18 | "strings" 19 | 20 | "github.com/bytedance/dbatman/database/cluster" 21 | "github.com/bytedance/dbatman/database/mysql" 22 | "github.com/bytedance/dbatman/database/sql/driver" 23 | "github.com/bytedance/dbatman/hack" 24 | "github.com/ngaut/log" 25 | ) 26 | // dispatch routes one client command packet (data[0] is the command byte) to its handler and always flushes the client connection afterwards; a flush error is reported only if the handler itself succeeded. 27 | func (session *Session) dispatch(data []byte) (err error) { 28 | cmd := data[0] 29 | data = data[1:] 30 | 31 | defer func() { 32 | flush_error := session.fc.Flush() 33 | if err == nil { 34 | err = flush_error 35 | } 36 | }() 37 | 38 | switch cmd { 39 | case mysql.ComQuery: 40 | err = session.comQuery(hack.String(data)) 41 | case mysql.ComPing: 42 | err = session.fc.WriteOK(nil) 43 | case mysql.ComInitDB: 44 | if useErr := session.useDB(hack.String(data)); useErr != nil { 45 | err = session.handleMySQLError(useErr) // assign the named result; a shadowing `err :=` would silently drop non-MySQL errors 46 | } else { 47 | err = session.fc.WriteOK(nil) 48 | } 49 | case mysql.ComFieldList: 50 | err = session.handleFieldList(data) 51 | case mysql.ComStmtPrepare: 52 | err = session.handleComStmtPrepare(hack.String(data)) 53 | case mysql.ComStmtExecute: 54 | err = session.handleComStmtExecute(data) 55 | case mysql.ComStmtClose: 56 | err = session.handleComStmtClose(data) 57 | case mysql.ComStmtSendLongData: 58 | err = session.handleComStmtSendLongData(data) 59 | case mysql.ComStmtReset: 60 | err = session.handleComStmtReset(data) 61 | default: 62 | msg := fmt.Sprintf("command %d not supported now", cmd) 63 | log.Warn(msg) // msg is not a format string 64 | err = mysql.NewDefaultError(mysql.ER_UNKNOWN_ERROR, msg) 65 | } 66 | 67 | return 68 | } 69 | // proceDbName returns the text between the first and second backquote of db (e.g. "`test`" -> "test"); a name without backquotes is returned unchanged. 70 | func proceDbName(db string) string { 71 | ret := db 72 | // strip the backquotes: Split yields at least two parts when db contains one, so a[1] is always in range 73 | if strings.Contains(db, "`") { 74 | log.Debug("db name error :,", db) 75 | a := strings.Split(db, "`") 76 | ret = a[1] 77 | } 78 | return ret 79 | 80 | } 81 | func (session *Session) useDB(dbName string) error { 82 | // useDB validates db against the proxy config and, on first use, binds the session to the user's cluster and its master/slave connections 83 | // log.Info("transfer db", proceDbName(db)) 84 | db := proceDbName(dbName) 85
| if session.cluster != nil { 86 | if session.cluster.DBName != db { 87 | // a session already bound to a cluster can only use that cluster's database 88 | return mysql.NewDefaultError(mysql.ER_BAD_DB_ERROR, db) 89 | } 90 | 91 | return nil 92 | } 93 | 94 | if _, err := session.config.GetClusterByDBName(db); err != nil { 95 | // log.Debug("er2,:", err) 96 | return mysql.NewDefaultError(mysql.ER_BAD_DB_ERROR, db) 97 | } else if session.cluster, err = cluster.New(session.user.ClusterName); err != nil { 98 | session.cluster = nil // never keep a cluster we failed to open 99 | return err 100 | } 101 | 102 | if session.bc == nil { 103 | master, err := session.cluster.Master() 104 | if err != nil { 105 | session.cluster = nil // otherwise the early return above would skip creating session.bc on the next USE, leaving it nil 106 | return mysql.NewDefaultError(mysql.ER_BAD_DB_ERROR, db) 107 | } 108 | slave, err := session.cluster.Slave() 109 | if err != nil { // no usable slave: fall back to the master 110 | slave = master 111 | } 112 | session.bc = &SqlConn{ 113 | master: master, 114 | slave: slave, 115 | stmts: make(map[uint32]*mysql.Stmt), 116 | tx: nil, 117 | session: session, 118 | } 119 | } 120 | 121 | return nil 122 | } 123 | // IsAutoCommit reports whether the client connection's status flags have StatusInAutocommit set. 124 | func (session *Session) IsAutoCommit() bool { 125 | return session.fc.Status()&uint16(mysql.StatusInAutocommit) > 0 126 | } 127 | // writeRows streams rs to the client: a column-count packet, the column definitions, an EOF, then every row packet and a final EOF; MySQL errors are sent to the client via handleMySQLError. 128 | func (session *Session) writeRows(rs mysql.Rows) error { 129 | var cols []driver.RawPacket 130 | var err error 131 | cols, err = rs.ColumnPackets() 132 | 133 | if err != nil { 134 | return session.handleMySQLError(err) 135 | } 136 | 137 | // Send a packet containing the column count; the first 4 bytes are presumably reserved for the packet header written by WritePacket (confirm) 138 | data := make([]byte, 4, 32) 139 | data = mysql.AppendLengthEncodedInteger(data, uint64(len(cols))) 140 | if err = session.fc.WritePacket(data); err != nil { 141 | return err 142 | } 143 | 144 | // Write Columns Packet 145 | for _, col := range cols { 146 | if err := session.fc.WritePacket(col); err != nil { 147 | log.Debugf("write columns packet error %v", err) 148 | return err 149 | } 150 | } 151 | 152 | // TODO Write a ok packet 153 | if err = session.fc.WriteEOF(); err != nil { 154 | return err 155 | } 156 | 157 | for { 158 |
packet, err := rs.NextRowPacket() 159 | // var p []byte = packet 160 | // defer mysql.SysBytePool.Return([]byte(packet)) 161 | 162 | // Handle Error 163 | 164 | // warning: in client deprecate-EOF mode an OK packet would arrive here instead of EOF 165 | if err != nil { 166 | if err == io.EOF { 167 | return session.fc.WriteEOF() 168 | } else { 169 | return session.handleMySQLError(err) 170 | } 171 | } 172 | 173 | if err := session.fc.WritePacket(packet); err != nil { 174 | return err 175 | } 176 | } 177 | 178 | // not reached: the loop above only exits through its return statements 179 | } 180 | // handleMySQLError sends a *mysql.MySQLError to the client as an error packet and returns the result of that write; any other error is logged and returned unchanged. 181 | func (session *Session) handleMySQLError(e error) error { 182 | 183 | switch inst := e.(type) { 184 | case *mysql.MySQLError: 185 | // propagate a failed write instead of reporting success 186 | return session.fc.WriteError(inst) 187 | default: 188 | log.Warnf("default error: %T %s", e, e) 189 | return e 190 | } 191 | } 192 | -------------------------------------------------------------------------------- /proxy/dispatch_test.go: -------------------------------------------------------------------------------- 1 | package proxy 2 | 3 | import ( 4 | "github.com/bytedance/dbatman/database/mysql" 5 | "testing" 6 | ) 7 | 8 | func TestProxy_ComPing(t *testing.T) { 9 | conn := newRawProxyConn(t) 10 | 11 | if err := conn.WriteCommandPacket(mysql.ComPing); err != nil { 12 | t.Fatal(err) 13 | } 14 | 15 | // Test ComQuit 16 | defer conn.Close() 17 | 18 | _, err := conn.ReadPacket() 19 | if err != nil { 20 | t.Fatal(err) 21 | } 22 | } 23 | 24 | func TestProxy_ComInitDB(t *testing.T) { 25 | conn := newRawProxyConn(t) 26 | 27 | if err := conn.WriteCommandPacketStr(mysql.ComInitDB, "dbatman_test"); err != nil { 28 | t.Fatal(err) 29 | } 30 | 31 | // should receive an ok result 32 | if err := conn.ReadResultOK(); err != nil { 33 | t.Fatal(err) 34 | } 35 | 36 | if err := conn.WriteCommandPacketStr(mysql.ComInitDB, "db_no_exist"); err != nil { 37 | t.Fatal(err) 38 | } 39 | 40 | if err := conn.ReadResultOK(); err == nil { 41 | t.Fatal("expect an error result packet") 42 | } else if e, ok := err.(*mysql.MySQLError); !ok { 43 |
t.Fatal(err) 44 | } else if e.Number != 1049 { // 1049 is MySQL's ER_BAD_DB_ERROR (unknown database) 45 | t.Fatal("expect an Unknown DB error") 46 | } 47 | } 48 | -------------------------------------------------------------------------------- /proxy/rows.go: -------------------------------------------------------------------------------- 1 | package proxy 2 | 3 | import ( 4 | "errors" 5 | . "github.com/bytedance/dbatman/database/mysql" 6 | "github.com/bytedance/dbatman/database/sql/driver" 7 | "io" 8 | ) 9 | 10 | // SimpleRows is an in-memory result set made of column definitions and pre-encoded row packets. It offers the row-streaming methods listed below (presumably to satisfy mysql.Rows; confirm against the interface). 11 | type SimpleRows struct { 12 | Cols []*MySQLField 13 | Rows []driver.RawPacket 14 | rowsi int // index of the next row NextRowPacket will return 15 | } 16 | 17 | // Next() bool 18 | // NextRowPacket() (driver.RawPacket, error) 19 | // ColumnPackets() ([]driver.RawPacket, error) 20 | // Scan(dest ...interface{}) error 21 | // Close() error 22 | // Err() error 23 | 24 | func (rs *SimpleRows) Next() bool { 25 | return rs.rowsi < len(rs.Rows) 26 | } 27 | // NextRowPacket returns a copy of the next row prefixed by PacketHeaderLen reserved bytes (presumably filled with the packet header on write), or io.EOF once every row has been returned. 28 | func (rs *SimpleRows) NextRowPacket() (driver.RawPacket, error) { 29 | if rs.rowsi >= len(rs.Rows) { 30 | return nil, io.EOF 31 | } 32 | 33 | ret := make([]byte, PacketHeaderLen, len(rs.Rows[rs.rowsi])+PacketHeaderLen) 34 | ret = append(ret, rs.Rows[rs.rowsi]...)
35 | rs.rowsi += 1 36 | return ret, nil 37 | } 38 | 39 | func (rs *SimpleRows) ColumnPackets() ([]driver.RawPacket, error) { 40 | pkgs := make([]driver.RawPacket, len(rs.Cols)) 41 | 42 | for i, column := range rs.Cols { 43 | pkgs[i] = driver.RawPacket(column.Dump()) 44 | } 45 | 46 | return pkgs, nil 47 | } 48 | 49 | func (rs *SimpleRows) Columns() ([]string, error) { 50 | return nil, errors.New("SimpleRows does not support Columns operations") 51 | } 52 | 53 | func (rs *SimpleRows) Scan(dest ...interface{}) error { 54 | return errors.New("SimpleRows does not support Scan operations") 55 | } 56 | 57 | func (rs *SimpleRows) Close() error { 58 | return nil 59 | } 60 | 61 | func (rs *SimpleRows) Err() error { 62 | return nil 63 | } 64 | -------------------------------------------------------------------------------- /proxy/server.go: -------------------------------------------------------------------------------- 1 | package proxy 2 | 3 | import ( 4 | "fmt" 5 | "net" 6 | "os" 7 | "runtime" 8 | "time" 9 | 10 | "sync" 11 | "syscall" 12 | 13 | "github.com/bytedance/dbatman/config" 14 | _ "github.com/bytedance/dbatman/database/mysql" 15 | "github.com/ngaut/log" 16 | ) 17 | 18 | var startNum = 0 19 | var closeNum = 0 20 | 21 | const defaultSessionIDChannelSize = 4096 22 | 23 | var sessionChan = make(chan int64, defaultSessionIDChannelSize) 24 | 25 | type LimitReqNode struct { 26 | excess int64 27 | last int64 28 | query string 29 | count int64 30 | lastSecond int64 //Last second to refresh the excess? 
31 | 32 | start int64 //qps start time by millsecond 33 | lastcount int64 //last count rep num means qps 34 | currentcount int64 //repnum in current 1s dperiod 35 | } 36 | 37 | type Ip struct { 38 | ip string 39 | mu sync.Mutex 40 | printfinger map[string]*LimitReqNode 41 | } 42 | type User struct { 43 | user string 44 | iplist map[string]*Ip 45 | } 46 | type Server struct { 47 | cfg *config.Conf 48 | 49 | // nodes map[string]*Node 50 | 51 | // schemas map[string]*Schema 52 | 53 | // users *userAuth 54 | mu *sync.Mutex 55 | // users map[string]*User 56 | //qps base on fingerprint 57 | fingerprints map[string]*LimitReqNode 58 | //qps base on server 59 | qpsOnServer *LimitReqNode 60 | listener net.Listener 61 | running bool 62 | restart bool 63 | wg sync.WaitGroup 64 | } 65 | 66 | func NewServer(cfg *config.Conf) (*Server, error) { 67 | s := new(Server) 68 | 69 | s.cfg = cfg 70 | 71 | var err error 72 | 73 | s.fingerprints = make(map[string]*LimitReqNode) 74 | // s.users = make(map[string]*User) 75 | // s.qpsOnServer = &LimitReqNode{} 76 | s.mu = &sync.Mutex{} 77 | s.restart = false 78 | port := s.cfg.GetConfig().Global.Port 79 | 80 | // get listenfd from file when restart 81 | if os.Getenv("_GRACEFUL_RESTART") == "true" { 82 | log.Info("graceful restart with previous listenfd") 83 | 84 | //get the linstenfd 85 | file := os.NewFile(3, "") 86 | s.listener, err = net.FileListener(file) 87 | if err != nil { 88 | log.Warn("get linstener err ") 89 | } 90 | 91 | } else { 92 | s.listener, err = net.Listen("tcp4", fmt.Sprintf(":%d", port)) 93 | } 94 | if err != nil { 95 | return nil, err 96 | } 97 | 98 | log.Infof("Dbatman Listen(tcp4) at [%d]", port) 99 | return s, nil 100 | } 101 | 102 | func (s *Server) Serve() error { 103 | log.Debug("this is ddbatman v4") 104 | s.running = true 105 | var sessionId int64 = 0 106 | for s.running { 107 | select { 108 | case sessionChan <- sessionId: 109 | //do nothing 110 | default: 111 | //warnning! 
112 | log.Warnf("TASK_CHANNEL is full!") 113 | } 114 | 115 | conn, err := s.Accept() 116 | if err != nil { 117 | log.Warnf("accept error %s", err.Error()) 118 | continue 119 | } 120 | //allocate a sessionId for a session 121 | go s.onConn(conn) 122 | sessionId += 1 123 | } 124 | if s.restart { 125 | log.Debug("Begin to restart graceful") 126 | listenerFile, err := s.listener.(*net.TCPListener).File() 127 | if err != nil { 128 | log.Fatal("Fail to get socket file descriptor:", err) 129 | } 130 | listenerFd := listenerFile.Fd() 131 | 132 | os.Setenv("_GRACEFUL_RESTART", "true") 133 | execSpec := &syscall.ProcAttr{ 134 | Env: os.Environ(), 135 | Files: []uintptr{os.Stdin.Fd(), os.Stdout.Fd(), os.Stderr.Fd(), listenerFd}, 136 | } 137 | fork, err := syscall.ForkExec(os.Args[0], os.Args, execSpec) 138 | if err != nil { 139 | return fmt.Errorf("failed to forkexec: %v", err) 140 | } 141 | 142 | log.Infof("start new process success, pid %d.", fork) 143 | } 144 | timeout := time.NewTimer(time.Minute) 145 | wait := make(chan struct{}) 146 | go func() { 147 | s.wg.Wait() 148 | close(wait) // close rather than send, so this goroutine cannot block forever after the timeout branch has returned 149 | }() 150 | defer timeout.Stop() 151 | select { 152 | case <-timeout.C: 153 | log.Error("server : Waittimeout error when close the service") 154 | return nil 155 | case <-wait: 156 | log.Info("server : all goroutine has been done") 157 | return nil 158 | } 159 | // not reached: every select branch returns 160 | } 161 | func (s *Server) Accept() (net.Conn, error) { 162 | 163 | conn, err := s.listener.Accept() 164 | if err != nil { 165 | return nil, err 166 | } 167 | // tc.SetKeepAlive(true) 168 | // tc.SetKeepAlivePeriod(3 * time.Minute) 169 | // reserve a WaitGroup slot for this connection; it must be released with s.wg.Done() once the connection is finished 170 | s.wg.Add(1) 171 | startNum += 1 172 | // log.Info("wait group add 1 total is :", startNum) 173 | 174 | return conn, nil 175 | } 176 | 177 | // Close stops the accept loop and closes the listener. TODO: not goroutine-safe, s.running and s.listener are written without synchronization 178 | func (s *Server) Close() { 179 | s.running = false 180 | if s.listener != nil { 181 | s.listener.Close() 182 | s.listener = nil 183 | } 184 | } 185 | func (s *Server) Restart() { 186 |
s.running = false 187 | s.restart = true 188 | if s.listener != nil { 189 | //s.listener.Close() 190 | //s.listener = nil 191 | } 192 | } 193 | // onConn serves one accepted client connection (handshake, then the command loop) and releases the WaitGroup slot reserved in Accept on every exit path. 194 | func (s *Server) onConn(c net.Conn) { 195 | session := s.newSession(c) 196 | defer s.wg.Done() // pairs with s.wg.Add(1) in Accept; registered first so it runs after the cleanup below 197 | defer func() { 198 | if !debug { 199 | if err := recover(); err != nil { 200 | const size = 4096 201 | buf := make([]byte, size) 202 | buf = buf[:runtime.Stack(buf, false)] 203 | log.Errorf("onConn panic %v: %v\n%s", c.RemoteAddr().String(), err, buf) // Errorf rather than Fatalf: a panic in one connection must not terminate the proxy 204 | } 205 | } 206 | 207 | session.Close() 208 | }() 209 | // Handshake error: the deferred handler above still closes the session 210 | if err := session.Handshake(); err != nil { 211 | log.Warnf("session %d handshake error: %s", session.sessionId, err) 212 | return 213 | } 214 | 215 | if err := session.Run(); err != nil { 216 | // TODO 217 | 218 | // session.WriteError(NewDefaultError(err)) 219 | session.Close() 220 | if err == errSessionQuit { 221 | 222 | log.Warnf("session %d: %s", session.sessionId, err.Error()) 223 | // return 224 | } 225 | closeNum += 1 226 | // s.wg.Done() runs via the defer at the top of onConn; calling it here as well would drive the counter negative 227 | log.Infof("session %d finished, closed sessions: %d", session.sessionId, closeNum) 228 | log.Warnf("session %d:session run error: %s", session.sessionId, err.Error()) 229 | return 230 | } 231 | } 232 | -------------------------------------------------------------------------------- /proxy/server_test.go: -------------------------------------------------------------------------------- 1 | package proxy 2 | 3 | import ( 4 | "fmt" 5 | "github.com/bytedance/dbatman/config" 6 | "github.com/bytedance/dbatman/database/cluster" 7 | "github.com/bytedance/dbatman/database/mysql" 8 | "github.com/ngaut/log" 9 | "os" 10 | "sync" 11 | "testing" 12 | "time" 13 | ) 14 | 15 | var testServerOnce sync.Once 16 | var testServer *Server 17 | var testServerError error 18 | 19 | var testClusterOnce sync.Once 20 | var testCluster *cluster.Cluster 21 | var testClusterError error 22 | 23 | var proxyConfig *config.ProxyConfig 24 | 25 | var testConfigData = []byte(` 26 | global: 27 |
port: 4307 28 | manage_port: 4308 29 | max_connections: 10 30 | log_filename: ./log/dbatman.log 31 | log_level: 31 32 | log_maxsize: 1024 33 | log_query_min_time: 0 34 | client_timeout: 1800 35 | server_timeout: 1800 36 | write_time_interval: 10 37 | conf_autoload: 1 38 | auth_ips: 39 | 40 | clusters: 41 | dbatman_test_cluster: 42 | master: 43 | host: 127.0.0.1 44 | port: 3306 45 | username: root 46 | password: 47 | dbname: dbatman_test 48 | charset: utf8mb4 49 | max_connections: 100 50 | max_connection_pool_size: 10 51 | connect_timeout: 10 52 | time_reconnect_interval: 10 53 | weight: 1 54 | slaves: 55 | - host: 127.0.0.1 56 | port: 3306 57 | username: root 58 | password: 59 | dbname: dbatman_test 60 | charset: utf8mb4 61 | max_connections: 100 62 | max_connection_pool_size: 10 63 | connect_timeout: 10 64 | time_reconnect_interval: 10 65 | weight: 1 66 | 67 | users: 68 | proxy_mysql_user: 69 | username: proxy_mysql_user 70 | password: proxy_mysql_passwd 71 | max_connections: 1000 72 | min_connections: 100 73 | dbname: dbatman_test 74 | charset: utf8mb4 75 | cluster_name: dbatman_test_cluster 76 | auth_ips: 77 | - 127.0.0.1 78 | black_list_ips: 79 | - 10.1.1.3 80 | - 10.1.1.4 81 | `) 82 | 83 | var testDBDSN = "root:@tcp(127.0.0.1:3306)/mysql" 84 | var testProxyDSN = "proxy_mysql_user:proxy_mysql_passwd@tcp(127.0.0.1:4307)/dbatman_test" 85 | 86 | func newTestServer() (*Server, error) { 87 | f := func() { 88 | 89 | path, err := tmpFile(testConfigData) 90 | if err != nil { 91 | testServer, testServerError = nil, err 92 | return 93 | } 94 | 95 | defer os.Remove(path) // clean up tmp file 96 | 97 | cfg, err := config.LoadConfig(path) 98 | if err != nil { 99 | testServer, testServerError = nil, err 100 | return 101 | } 102 | 103 | if err := cluster.Init(cfg); err != nil { 104 | testServer, testServerError = nil, err 105 | return 106 | } 107 | 108 | log.SetLevel(log.LogLevel(cfg.GetConfig().Global.LogLevel)) 109 | mysql.SetLogger(log.Logger()) 110 | 111 | testServer, err 
= NewServer(cfg) 112 | if err != nil { 113 | testServer, testServerError = nil, err 114 | return 115 | } 116 | 117 | go testServer.Serve() 118 | 119 | time.Sleep(1 * time.Second) 120 | } 121 | 122 | testServerOnce.Do(f) 123 | 124 | return testServer, testServerError 125 | } 126 | // newTestCluster returns the shared test cluster, creating it on first use; if the proxy test server could not be started it returns that error instead. 127 | func newTestCluster(cluster_name string) (*cluster.Cluster, error) { 128 | if _, err := newTestServer(); err != nil { 129 | return nil, err // don't go on to create a cluster without a running proxy 130 | } 131 | 132 | f := func() { 133 | testCluster, testClusterError = cluster.New(cluster_name) 134 | } 135 | 136 | testClusterOnce.Do(f) 137 | return testCluster, testClusterError 138 | } 139 | // newTestDB returns the master DB handle of the test cluster after a successful Ping, failing the test otherwise. 140 | func newTestDB(t *testing.T) *mysql.DB { 141 | cls, err := newTestCluster("dbatman_test_cluster") 142 | if err != nil { 143 | t.Fatal(err) 144 | } 145 | 146 | db, err := cls.Master() 147 | 148 | if err != nil { 149 | t.Fatal(err) 150 | } 151 | 152 | if err := db.Ping(); err != nil { 153 | t.Fatal(err) 154 | } 155 | return db 156 | } 157 | 158 | // newRawProxyConn opens a raw MySQL protocol connection to the test proxy (testProxyDSN), failing the test if it cannot. 159 | func newRawProxyConn(t *testing.T) *mysql.MySQLConn { 160 | newTestServer() 161 | 162 | d := mysql.MySQLDriver{} 163 | 164 | if conn, err := d.Open(testProxyDSN); err != nil { 165 | t.Fatal(err) 166 | } else if c, ok := conn.(*mysql.MySQLConn); !ok { 167 | t.Fatal("connection is not MySQLConn type") 168 | } else { 169 | return c 170 | } 171 | 172 | return nil 173 | } 174 | 175 | // newSqlDB opens a database handle for dsn and exits the test binary with status 2 if it cannot be opened or pinged. 176 | func newSqlDB(dsn string) *mysql.DB { 177 | 178 | db, err := mysql.Open("dbatman", dsn) 179 | if err != nil { 180 | fmt.Fprintf(os.Stderr, "%s is unavailable\n", dsn) 181 | os.Exit(2) 182 | } 183 | 184 | if err := db.Ping(); err != nil { 185 | fmt.Fprintf(os.Stderr, "%s is unreachable\n", dsn) 186 | os.Exit(2) 187 | } 188 | 189 | return db 190 | } 191 | // TestMain recreates the dbatman_test database, starts the proxy test server and cluster, runs the tests and finally drops the database again. 192 | func TestMain(m *testing.M) { 193 | // Init dbatman_test database 194 | 195 | db := newSqlDB(testDBDSN) 196 | 197 | // Create
DataBase dbatman_test from scratch: drop any leftover copy first 198 | if _, err := db.Exec("DROP DATABASE IF EXISTS `dbatman_test`"); err != nil { 199 | fmt.Fprintln(os.Stderr, "drop database `dbatman_test` failed: ", err.Error()) 200 | os.Exit(2) 201 | } 202 | 203 | // Create DataBase dbatman_test 204 | if _, err := db.Exec("CREATE DATABASE IF NOT EXISTS `dbatman_test`"); err != nil { 205 | fmt.Fprintln(os.Stderr, "create database `dbatman_test` failed: ", err.Error()) 206 | os.Exit(2) 207 | } 208 | 209 | if _, err := newTestServer(); err != nil { 210 | fmt.Fprintln(os.Stderr, "setup proxy server failed: ", err.Error()) 211 | os.Exit(2) 212 | } 213 | 214 | if _, err := newTestCluster("dbatman_test_cluster"); err != nil { 215 | fmt.Fprintln(os.Stderr, "setup proxy -> cluster failed: ", err.Error()) 216 | os.Exit(2) 217 | } 218 | 219 | exit := m.Run() 220 | 221 | // Clear Up Database 222 | 223 | if _, err := db.Exec("DROP DATABASE IF EXISTS `dbatman_test`"); err != nil { 224 | fmt.Fprintln(os.Stderr, "drop database `dbatman_test` failed: ", err.Error()) 225 | os.Exit(2) 226 | } 227 | 228 | os.Exit(exit) 229 | } 230 | -------------------------------------------------------------------------------- /proxy/session.go: -------------------------------------------------------------------------------- 1 | // Copyright 2016 ByteDance, Inc. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // See the License for the specific language governing permissions and 12 | // limitations under the License.
13 | 14 | package proxy 15 | 16 | import ( 17 | "errors" 18 | "fmt" 19 | "net" 20 | "strings" 21 | 22 | "github.com/bytedance/dbatman/cmd/version" 23 | "github.com/bytedance/dbatman/config" 24 | "github.com/bytedance/dbatman/database/cluster" 25 | . "github.com/bytedance/dbatman/database/mysql" 26 | "github.com/bytedance/dbatman/database/sql/driver" 27 | "github.com/bytedance/dbatman/hack" 28 | "github.com/ngaut/log" 29 | ) 30 | 31 | type Session struct { 32 | server *Server 33 | config *config.ProxyConfig 34 | user *config.UserConfig 35 | 36 | salt []byte 37 | 38 | cluster *cluster.Cluster 39 | bc *SqlConn 40 | fc *MySQLServerConn 41 | 42 | cliAddr string //client ip for auth 43 | autoCommit uint 44 | sessionId int64 45 | 46 | //session status 47 | txIsolationStmt string 48 | txIsolationInDef bool //is the tx isolation level in dafault? 49 | 50 | closed bool 51 | 52 | // lastcmd uint8 53 | } 54 | 55 | var errSessionQuit error = errors.New("session closed by client") 56 | 57 | func (s *Server) newSession(conn net.Conn) *Session { 58 | session := new(Session) 59 | id := <-sessionChan 60 | session.server = s 61 | session.config = s.cfg.GetConfig() 62 | session.salt, _ = RandomBuf(20) 63 | session.autoCommit = 0 64 | session.cliAddr = strings.Split(conn.RemoteAddr().String(), ":")[0] 65 | session.sessionId = id 66 | session.txIsolationInDef = true 67 | session.fc = NewMySQLServerConn(session, conn) 68 | //session.lastcmd = ComQuit 69 | log.Info("start new session", session.sessionId) 70 | return session 71 | } 72 | 73 | func (session *Session) Handshake() error { 74 | 75 | if err := session.fc.Handshake(); err != nil { 76 | erro := fmt.Errorf("session %d : handshake error: %s", session.sessionId, err.Error()) 77 | return erro 78 | } 79 | 80 | return nil 81 | } 82 | 83 | func (session *Session) Run() error { 84 | 85 | for { 86 | 87 | data, err := session.fc.ReadPacket() 88 | 89 | if err != nil { 90 | // log.Warn(err) 91 | // Usually client close the conn 92 | return 
err 93 | } 94 | 95 | if data[0] == ComQuit { 96 | return errSessionQuit 97 | } 98 | 99 | if err := session.dispatch(data); err != nil { 100 | if err == driver.ErrBadConn { 101 | // TODO handle error 102 | } 103 | 104 | log.Warnf("sessionId %d:dispatch error: %s", session.sessionId, err.Error()) 105 | return err 106 | } 107 | 108 | session.fc.ResetSequence() 109 | 110 | if session.closed { 111 | // TODO return MySQL Go Away ? 112 | return errors.New("session closed!") 113 | } 114 | } 115 | 116 | // not reached: the loop only exits through its return statements 117 | } 118 | // Close rolls back any open backend transaction and closes the client connection, then marks the session closed so later calls are no-ops. If autocommit is off but no transaction is open, it closes the client connection and returns an error. 119 | func (session *Session) Close() error { 120 | if session.closed { 121 | return nil 122 | } 123 | 124 | // the connection is in autocommit=0 (AC) tx mode: reset it before it is stored back in the pool 125 | if !session.isAutoCommit() { 126 | if !session.isInTransaction() { // sanity check: autocommit = 0 must come with an open transaction 127 | session.fc.Close() // do not leak the client connection on this error path 128 | session.closed = true 129 | return errors.New("transaction must be in true in the autocommit = 0 mode") 130 | } 131 | //rollback uncommit data 132 | 133 | // set the autocommit mode back to true 134 | session.clearAutoCommitTx() 135 | for _, s := range session.bc.stmts { 136 | s.Close() 137 | } 138 | 139 | } 140 | if session.isInTransaction() { 141 | // session.handleCommit() 142 | log.Debugf("session : %d reset the tx status", session.sessionId) 143 | if !session.txIsolationInDef { 144 | session.bc.tx.Exec("set session transaction isolation level read committed;") // reset to the default level that GetIsoLevel uses 145 | } 146 | if err := session.bc.rollback(session.isAutoCommit()); err != nil { 147 | log.Info(err.Error()) 148 | } 149 | } 150 | session.fc.Close() 151 | 152 | // session.bc.tx.Exec("set autocommit =0 ") 153 | // TODO transaction 154 | // session.rollback() 155 | 156 | // TODO stmts 157 | // for _, s := range session.stmts { 158 | // s.Close() 159 | // } 160 | 161 | // session.stmts = nil 162 | 163 | session.closed = true 164 | 165 | return nil 166 | } 167 | // ServerName returns version.Version as the server name bytes. 168 | func (session *Session) ServerName() []byte { 169 | return hack.Slice(version.Version) 170 | } 171 | func (session *Session) GetIsoLevel() (string,
bool) { 172 | if session.txIsolationInDef { 173 | sql := "set session transaction isolation level read committed" 174 | return sql, true 175 | } else { 176 | return session.txIsolationStmt, false 177 | } 178 | } 179 | 180 | func (session *Session) Salt() []byte { 181 | return session.salt 182 | } 183 | -------------------------------------------------------------------------------- /proxy/signal.go: -------------------------------------------------------------------------------- 1 | // Copyright 2016 ByteDance, Inc. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // See the License for the specific language governing permissions and 12 | // limitations under the License. 
13 | 14 | package proxy 15 | 16 | import ( 17 | "github.com/ngaut/log" 18 | "os" 19 | ) 20 | 21 | type SignalHandler func(s os.Signal, arg interface{}) error 22 | 23 | type SignalSet struct { 24 | M map[os.Signal]SignalHandler 25 | } 26 | 27 | func NewSignalSet() *SignalSet { 28 | s := new(SignalSet) 29 | s.M = make(map[os.Signal]SignalHandler) 30 | return s 31 | } 32 | 33 | func (s *SignalSet) Register(sig os.Signal, handler SignalHandler) { 34 | if _, exist := s.M[sig]; !exist { 35 | s.M[sig] = handler 36 | } 37 | } 38 | 39 | func (s *SignalSet) Handle(sig os.Signal, arg interface{}) error { 40 | if handler, exist := s.M[sig]; exist { 41 | return handler(sig, arg) 42 | } else { 43 | log.Warnf("no available handler for signal %v, ignore!", sig) 44 | return nil 45 | } 46 | } 47 | 48 | /* vim: set expandtab ts=4 sw=4 */ 49 | -------------------------------------------------------------------------------- /run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | ./build.sh 4 | 5 | ./cmd/dbatman/dbatman -config config/test.yml 6 | -------------------------------------------------------------------------------- /systest.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | TEST_CASES=' 4 | disable_query_logs 5 | test_view_sp_list_fields 6 | client_query 7 | test_prepare_insert_update 8 | test_fetch_seek 9 | test_fetch_nobuffs 10 | test_open_direct 11 | test_fetch_null 12 | test_ps_null_param 13 | test_fetch_date 14 | test_fetch_str 15 | test_fetch_long 16 | test_fetch_short 17 | test_fetch_tiny 18 | test_fetch_bigint 19 | test_fetch_float 20 | test_fetch_double 21 | test_bind_result_ext 22 | test_bind_result_ext1 23 | test_select_direct 24 | test_select_prepare 25 | test_select 26 | test_select_version 27 | test_ps_conj_select 28 | test_select_show_table 29 | test_func_fields 30 | test_long_data 31 | test_insert 32 | test_set_variable 33 | 
test_select_show 34 | test_prepare_noparam 35 | test_bind_result 36 | test_prepare_simple 37 | test_prepare 38 | test_null 39 | test_debug_example 40 | test_update 41 | test_simple_update 42 | test_simple_delete 43 | test_double_compare 44 | client_store_result 45 | client_use_result 46 | test_tran_bdb 47 | test_tran_innodb 48 | test_prepare_ext 49 | test_prepare_syntax 50 | test_field_names 51 | test_field_flags 52 | test_long_data_str 53 | test_long_data_str1 54 | test_long_data_bin 55 | test_warnings 56 | test_errors 57 | test_prepare_resultset 58 | test_stmt_close 59 | test_prepare_field_result 60 | test_multi_stmt 61 | test_multi_statements 62 | test_prepare_multi_statements 63 | test_store_result 64 | test_store_result1 65 | test_store_result2 66 | test_subselect 67 | test_date 68 | test_date_frac 69 | test_temporal_param 70 | test_date_date 71 | test_date_time 72 | test_date_ts 73 | test_date_dt 74 | test_prepare_alter 75 | test_manual_sample 76 | test_pure_coverage 77 | test_buffers 78 | test_ushort_bug 79 | test_sshort_bug 80 | test_stiny_bug 81 | test_field_misc 82 | test_set_option 83 | test_prepare_grant 84 | test_frm_bug 85 | test_explain_bug 86 | test_decimal_bug 87 | test_nstmts 88 | test_logs 89 | test_cuted_rows 90 | test_fetch_offset 91 | test_fetch_column 92 | test_mem_overun 93 | test_list_fields 94 | test_free_result 95 | test_free_store_result 96 | test_sqlmode 97 | test_ts 98 | test_bug1115 99 | test_bug1180 100 | test_bug1500 101 | test_bug1644 102 | test_bug1946 103 | test_bug2248 104 | test_parse_error_and_bad_length 105 | test_bug2247 106 | test_subqueries 107 | test_bad_union 108 | test_distinct 109 | test_subqueries_ref 110 | test_union 111 | test_bug3117 112 | test_join 113 | test_selecttmp 114 | test_create_drop 115 | test_rename 116 | test_do_set 117 | test_multi 118 | test_insert_select 119 | test_bind_nagative 120 | test_derived 121 | test_xjoin 122 | test_bug3035 123 | test_union2 124 | test_bug1664 125 | test_union_param 126 |
test_order_param 127 | test_ps_i18n 128 | test_bug3796 129 | test_bug4026 130 | test_bug4079 131 | test_bug4236 132 | test_bug4030 133 | test_bug5126 134 | test_bug4231 135 | test_bug5399 136 | test_bug5194 137 | test_bug5315 138 | test_bug6049 139 | test_bug6058 140 | test_bug6059 141 | test_bug6046 142 | test_bug6081 143 | test_bug6096 144 | test_datetime_ranges 145 | test_bug4172 146 | test_conversion 147 | test_rewind 148 | test_bug6761 149 | test_view 150 | test_view_where 151 | test_view_2where 152 | test_view_star 153 | test_view_insert 154 | test_left_join_view 155 | test_view_insert_fields 156 | test_basic_cursors 157 | test_cursors_with_union 158 | test_cursors_with_procedure 159 | test_truncation 160 | test_truncation_option 161 | test_client_character_set 162 | test_bug8330 163 | test_bug7990 164 | test_bug8378 165 | test_bug8722 166 | test_bug8880 167 | test_bug9159 168 | test_bug9520 169 | test_bug9478 170 | test_bug9643 171 | test_bug10729 172 | test_bug11111 173 | test_bug9992 174 | test_bug10736 175 | test_bug10794 176 | test_bug11172 177 | test_bug11656 178 | test_bug10214 179 | test_bug21246 180 | test_bug9735 181 | test_bug11183 182 | test_bug11037 183 | test_bug10760 184 | test_bug12001 185 | test_bug11718 186 | test_bug12925 187 | test_bug11909 188 | test_bug11901 189 | test_bug11904 190 | test_bug12243 191 | test_bug14210 192 | test_bug13488 193 | test_bug13524 194 | test_bug14845 195 | test_opt_reconnect 196 | test_bug15510 197 | test_bug12744 198 | test_bug16143 199 | test_bug16144 200 | test_bug15613 201 | test_bug20152 202 | test_bug14169 203 | test_bug17667 204 | test_bug15752 205 | test_mysql_insert_id 206 | test_bug19671 207 | test_bug21206 208 | test_bug21726 209 | test_bug15518 210 | test_bug23383 211 | test_bug32265 212 | test_bug21635 213 | test_status 214 | test_bug24179 215 | test_ps_query_cache 216 | test_bug28075 217 | test_bug27876 218 | test_bug28505 219 | test_bug28934 220 | test_bug27592 221 | test_bug29687 222 | 
test_bug29692 223 | test_bug29306 224 | test_change_user 225 | test_bug30472 226 | test_bug20023 227 | test_bug45010 228 | test_bug53371 229 | test_bug31418 230 | test_bug31669 231 | test_bug28386 232 | test_wl4166_1 233 | test_wl4166_2 234 | test_wl4166_3 235 | test_wl4166_4 236 | test_bug36004 237 | test_wl4284_1 238 | test_wl4435 239 | test_wl4435_2 240 | test_wl4435_3 241 | test_bug38486 242 | test_bug33831 243 | test_bug40365 244 | test_bug43560 245 | test_bug36326 246 | test_bug41078 247 | test_bug44495 248 | test_bug49972 249 | test_bug42373 250 | test_bug54041 251 | test_bug47485 252 | test_bug58036 253 | test_bug57058 254 | test_bug56976 255 | test_bug11766854 256 | test_bug54790 257 | test_bug12337762 258 | test_bug11754979 259 | test_bug13001491 260 | test_wl5968 261 | test_wl5924 262 | test_wl6587 263 | test_wl5928 264 | test_wl6797 265 | test_wl6791 266 | test_wl5768 267 | test_bug17309863 268 | test_bug17512527 269 | test_bug20810928 270 | test_wl8016 271 | test_bug20645725 272 | test_bug20444737 273 | test_bug21104470 274 | test_bug21293012 275 | test_bug21199582 276 | test_bug20821550 277 | ' 278 | 279 | for i in ${TEST_CASES}; do 280 | echo ${i} 281 | done 282 | -------------------------------------------------------------------------------- /unitest.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -e 4 | echo "" > coverage.txt 5 | 6 | for d in $(find ./* -maxdepth 3 -type d); do 7 | if ls $d/*.go &> /dev/null; then 8 | IMPORT_LIST=`go list -f "{{.Imports}}" $d | sed -e "s/\[//g" | sed -e "s/\]//g"` 9 | BYTEDANCE_PKG=`go list $d` 10 | for i in ${IMPORT_LIST[@]}; do 11 | if [[ ${i} == github\.com\/bytedance\/dbatman* ]]; then 12 | BYTEDANCE_PKG=${BYTEDANCE_PKG},${i} 13 | fi 14 | done 15 | 16 | go test -coverprofile=profile.out -covermode=atomic $d -coverpkg=${BYTEDANCE_PKG} 17 | 18 | if [ -f profile.out ]; then 19 | cat profile.out >> coverage.txt 20 | rm profile.out 21 | fi 22 
| fi 23 | done 24 | -------------------------------------------------------------------------------- /wercker.yml: -------------------------------------------------------------------------------- 1 | # This references the default golang container from 2 | # the Docker Hub: https://registry.hub.docker.com/u/library/golang/ 3 | # If you want Google's container you would reference google/golang 4 | # Read more about containers on our dev center 5 | # http://devcenter.wercker.com/docs/containers/index.html 6 | box: golang 7 | # This is the build pipeline. Pipelines are the core of wercker 8 | # Read more about pipelines on our dev center 9 | # http://devcenter.wercker.com/docs/pipelines/index.html 10 | 11 | # You can also use services such as databases. Read more on our dev center: 12 | # http://devcenter.wercker.com/docs/services/index.html 13 | # services: 14 | # - postgres 15 | # http://devcenter.wercker.com/docs/services/postgresql.html 16 | 17 | # - mongo 18 | # http://devcenter.wercker.com/docs/services/mongodb.html 19 | build: 20 | # The steps that will be executed on build 21 | # Steps make up the actions in your pipeline 22 | # Read more about steps on our dev center: 23 | # http://devcenter.wercker.com/docs/steps/index.html 24 | steps: 25 | # Sets the go workspace and places your package 26 | # at the right place in the workspace tree 27 | - setup-go-workspace 28 | 29 | # Builds the project with build.sh 30 | - script: 31 | name: build 32 | code: | 33 | sh build.sh 34 | 35 | test: 36 | steps: 37 | - script: 38 | name: test 39 | code: | 40 | go test ./... 41 | --------------------------------------------------------------------------------