├── .gitignore ├── images └── demo.png ├── go.mod ├── protocol ├── errors.go ├── command.go ├── packet_test.go ├── packet.go ├── conn_test.go ├── response.go └── conn.go ├── internal └── proxy │ ├── error.go │ ├── pool_test.go │ ├── proxy.go │ └── pool.go ├── LICENSE ├── README.md ├── go.sum └── cmd └── umyproxy └── main.go /.gitignore: -------------------------------------------------------------------------------- 1 | bin 2 | .DS_Store 3 | .idea 4 | .vscode/ 5 | -------------------------------------------------------------------------------- /images/demo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lyuangg/umyproxy/HEAD/images/demo.png -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/lyuangg/umyproxy 2 | 3 | go 1.17 4 | 5 | require ( 6 | github.com/davecgh/go-spew v1.1.1 // indirect 7 | github.com/pmezard/go-difflib v1.0.0 // indirect 8 | github.com/stretchr/testify v1.8.1 // indirect 9 | gopkg.in/yaml.v3 v3.0.1 // indirect 10 | ) 11 | -------------------------------------------------------------------------------- /protocol/errors.go: -------------------------------------------------------------------------------- 1 | package protocol 2 | 3 | import "errors" 4 | 5 | var ( 6 | ErrConnClosed = errors.New("connection is closed") 7 | ErrNoAuth = errors.New("client no auth") 8 | ErrAuth = errors.New("client auth error") 9 | ErrClientQuit = errors.New("client quit cmd") 10 | ) 11 | -------------------------------------------------------------------------------- /internal/proxy/error.go: -------------------------------------------------------------------------------- 1 | package proxy 2 | 3 | import "errors" 4 | 5 | var ( 6 | ErrConnExpired = errors.New("connection expired") 7 | ErrConnClosed = errors.New("connection Closed") 8 | ErrPoolClosed = 
errors.New("connection closed") 9 | ErrPoolFull = errors.New("pool full") 10 | ErrWaitConnTimeout = errors.New("wait mysql connection timeout") 11 | ) 12 | -------------------------------------------------------------------------------- /protocol/command.go: -------------------------------------------------------------------------------- 1 | package protocol 2 | 3 | // https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_command_phase.html 4 | const ( 5 | // Text Protocol 6 | COM_QUERY = 0x03 7 | 8 | // Utility Commands 9 | COM_QUIT = 0x01 10 | COM_INIT_DB = 0x02 11 | COM_FIELD_LIST = 0x04 12 | COM_REFRESH = 0x07 13 | COM_STATISTICS = 0x08 14 | COM_PROCESS_INFO = 0x0A 15 | COM_PROCESS_KILL = 0x0C 16 | COM_DEBUG = 0x0D 17 | COM_PING = 0x1E 18 | COM_CHANGE_USER = 0x11 19 | COM_RESET_CONNECTION = 0x1F 20 | COM_SET_OPTION = 0x1A 21 | 22 | // Prepared Statements 23 | COM_STMT_PREPARE = 0x16 24 | COM_STMT_EXECUTE = 0x17 25 | COM_STMT_FETCH = 0x19 26 | COM_STMT_CLOSE = 0x19 27 | COM_STMT_RESET = 0x1A 28 | COM_STMT_SEND_LONG_DATA = 0x18 29 | 30 | ) 31 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 lyuangg 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # umyproxy 2 | ![GitHub](https://img.shields.io/github/license/lyuangg/umyproxy?style=flat-square) 3 | ![GitHub go.mod Go version (subdirectory of monorepo)](https://img.shields.io/github/go-mod/go-version/lyuangg/umyproxy?style=flat-square&logo=go) 4 | ![GitHub Repo stars](https://img.shields.io/github/stars/lyuangg/umyproxy?style=flat-square) 5 | 6 | `umyproxy` 是一个 mysql 的本地代理服务程序。 7 | 可以为基于 php-fpm 的 PHP 程序提供 mysql 连接池的功能, 解决高并发中短连接产生大量 `TIME_WAIT` 的问题。 8 | 9 | `umyproxy` 使用 `Unix domain socket` 与客户端进行通信, 第一次与 mysql 服务端建立连接, 代理 client 端进行通信,认证通过后复用连接并放入连接池。 10 | client 端第二次连接时采用假的认证方式认证。 11 | 12 | ## 特点 13 | 14 | - 支持 mysql 连接池。 15 | - 纯 go 语言开发, 不依赖第三方库。 16 | - 使用 Unix domain socket 通信。 17 | - 使用简单,不需要配置 mysql 账号密码。 18 | 19 | ## 使用 20 | 21 | > laravel 框架为例 22 | 23 | 1. 启动服务 24 | 25 | ``` 26 | ./umyproxy -host 127.0.0.1 -port 3306 -socket /tmp/umyproxy.socket 27 | ``` 28 | 29 | 2. 
配置数据库 30 | 31 | 修改数据库配置: ./config/database.php 32 | 33 | mysql 配置增加 34 | 35 | ``` 36 | 'unix_socket' => '/tmp/umyproxy.socket', 37 | ``` 38 | 39 | ## 查看帮助 40 | 41 | ``` 42 | ./umyproxy -h 43 | ``` 44 | 45 | ## 编译 46 | 47 | ``` 48 | ./build 49 | ``` 50 | 51 | ## demo 52 | 53 | ![demo](./images/demo.png) 54 | 55 | 56 | ## laravel 使用 57 | 58 | [https://github.com/lyuangg/laravel-umyproxy](https://github.com/lyuangg/laravel-umyproxy) 59 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 2 | github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= 3 | github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 4 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= 5 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 6 | github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= 7 | github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= 8 | github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= 9 | github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= 10 | github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= 11 | github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk= 12 | github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= 13 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= 14 | gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 15 | gopkg.in/yaml.v3 v3.0.1 
h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= 16 | gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 17 | -------------------------------------------------------------------------------- /protocol/packet_test.go: -------------------------------------------------------------------------------- 1 | package protocol 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/stretchr/testify/assert" 7 | ) 8 | 9 | func TestHeader(t *testing.T) { 10 | mockSize := []int{ 11 | 100, 12 | 1000, 13 | MAX_PAYLOAD_LEN, 14 | MAX_PAYLOAD_LEN + 10, 15 | } 16 | 17 | for _, size := range mockSize { 18 | data := make([]byte, size) 19 | p := Packet{Payload: data, SeqId: 0} 20 | h := p.Header() 21 | assert.Len(t, h, 4, "header length error") 22 | if size >= MAX_PAYLOAD_LEN { 23 | assert.Equal(t, int(h[2]), int(0xff), "header size error") 24 | } else { 25 | assert.Equal(t, int(h[0]), int(byte(size)), "header size error") 26 | } 27 | } 28 | } 29 | 30 | func TestSplit(t *testing.T) { 31 | testCase := []struct { 32 | name string 33 | size int 34 | sliceLen int 35 | } { 36 | { 37 | "100", 38 | 100, 39 | 1, 40 | }, 41 | { 42 | "1000", 43 | 1000, 44 | 1, 45 | }, 46 | { 47 | "max_payload1", 48 | MAX_PAYLOAD_LEN, 49 | 2, 50 | }, 51 | { 52 | "max_payload2", 53 | MAX_PAYLOAD_LEN + 100, 54 | 2, 55 | }, 56 | } 57 | 58 | for _, tc := range testCase { 59 | t.Run(tc.name, func(t *testing.T) { 60 | data := make([]byte, tc.size) 61 | p := Packet{Payload: data, SeqId: 0} 62 | ps := p.Split() 63 | assert.Len(t, ps, tc.sliceLen, "split length error") 64 | }) 65 | } 66 | } 67 | -------------------------------------------------------------------------------- /protocol/packet.go: -------------------------------------------------------------------------------- 1 | package protocol 2 | 3 | const ( 4 | MAX_PAYLOAD_LEN int = 1<<24 - 1 5 | 6 | OK_PACKET byte = 0x00 7 | ERR_PACKET byte = 0xff 8 | EOF_PACKET byte = 0xfe 9 | QUIT_PACKET byte= 0x01 10 | ) 11 | 12 | 13 | type ( 14 | Packet struct 
{ 15 | // 不包含头 16 | Payload []byte 17 | SeqId uint8 18 | } 19 | ) 20 | 21 | // 分包 22 | func (p Packet) Split() []Packet { 23 | data := p.Payload 24 | pll := len(data) 25 | packets := make([]Packet, 0) 26 | seqId := p.SeqId 27 | 28 | 29 | for pll >= MAX_PAYLOAD_LEN { 30 | pk := Packet{ 31 | Payload: data[:MAX_PAYLOAD_LEN], 32 | SeqId: seqId, 33 | } 34 | packets = append(packets, pk) 35 | data = data[MAX_PAYLOAD_LEN:] 36 | pll = len(data) 37 | seqId ++ 38 | } 39 | 40 | if pll > 0 { 41 | pk := Packet{ 42 | Payload: data, 43 | SeqId: seqId, 44 | } 45 | packets = append(packets, pk) 46 | } else { 47 | pk := Packet{} 48 | packets = append(packets, pk) 49 | } 50 | 51 | return packets 52 | } 53 | 54 | // 包头 55 | func (p Packet) Header() []byte { 56 | header := make([]byte, 4) 57 | length := len(p.Payload) 58 | if length >= MAX_PAYLOAD_LEN { 59 | header[0] = 0xff 60 | header[1] = 0xff 61 | header[2] = 0xff 62 | } else { 63 | header[0] = byte(length) 64 | header[1] = byte(length >> 8) 65 | header[2] = byte(length >> 16) 66 | } 67 | header[3] = p.SeqId 68 | 69 | return header 70 | } 71 | 72 | func IsQuitPacket(p Packet) bool { 73 | if len(p.Payload) > 0 && p.Payload[0] == QUIT_PACKET { 74 | return true 75 | } 76 | return false 77 | } 78 | 79 | func IsEofPacket(p Packet) bool { 80 | if len(p.Payload) > 0 && p.Payload[0] == EOF_PACKET { 81 | return true 82 | } 83 | return false 84 | } 85 | 86 | func IsOkPacket(p Packet) bool { 87 | if len(p.Payload) > 0 && p.Payload[0] == OK_PACKET { 88 | return true 89 | } 90 | return false 91 | } 92 | 93 | func IsErrPacket(p Packet) bool { 94 | if len(p.Payload) > 0 && p.Payload[0] == ERR_PACKET { 95 | return true 96 | } 97 | return false 98 | } 99 | -------------------------------------------------------------------------------- /cmd/umyproxy/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "context" 5 | "flag" 6 | "fmt" 7 | "log" 8 | "os" 9 | "os/signal" 10 | 
"runtime" 11 | "syscall" 12 | "time" 13 | 14 | "github.com/lyuangg/umyproxy/internal/proxy" 15 | ) 16 | 17 | var ( 18 | appname string = "umyproxy" 19 | version string = "0.0.1" 20 | showversion bool 21 | host string 22 | port int 23 | socketfile string 24 | poolsize int 25 | maxlife int 26 | waittimeout int 27 | debug bool 28 | ) 29 | 30 | const ( 31 | logstr = ` 32 | __ ____ ___ ____ 33 | / / / / |/ /_ __/ __ \_________ _ ____ __ 34 | / / / / /|_/ / / / / /_/ / ___/ __ \| |/_/ / / / 35 | / /_/ / / / / /_/ / ____/ / / /_/ /> 8 && p.Payload[5] == 0 && p.Payload[6] == 0 && p.Payload[7] == 0 && p.Payload[8] == 0 { 119 | return nil 120 | } 121 | eofCount := 0 122 | 123 | // columns is 0: only the parameter-definition block follows, so stop at its single EOF 124 | if len(p.Payload) > 8 && p.Payload[5] == 0 && p.Payload[6] == 0 { 125 | eofCount = 1 126 | } 127 | // parameters is 0: only the column-definition block follows, so stop at its single EOF 128 | if len(p.Payload) > 8 && p.Payload[7] == 0 && p.Payload[8] == 0 { 129 | eofCount = 1 130 | } 131 | for { 132 | p2, err2 := TransportPacket(r.server, client) 133 | if err2 != nil { 134 | return err2 135 | } 136 | if IsEofPacket(p2) { 137 | if eofCount == 1 { 138 | return nil 139 | } else { 140 | eofCount++ 141 | } 142 | } 143 | } 144 | } 145 | return nil 146 | } 147 | 148 | // COM_STMT_EXECUTE response: unless the first packet is OK/ERR, forward packets until the second EOF (presumably end of column definitions, then end of rows). NOTE(review): an ERR packet after the first packet, a cursor-opening execute, or a multi-result response (SERVER_MORE_RESULTS_EXISTS) does not fit this two-EOF pattern — confirm these cannot occur for supported clients. 149 | if r.cmd == COM_STMT_EXECUTE { 150 | p, err := TransportPacket(r.server, client) 151 | if err != nil { 152 | return err 153 | } 154 | if IsErrPacket(p) || IsOkPacket(p) { 155 | return nil 156 | } 157 | eofCount := 0 158 | for { 159 | p2, err2 := TransportPacket(r.server, client) 160 | if err2 != nil { 161 | return err2 162 | } 163 | if IsEofPacket(p2) { 164 | if eofCount == 1 { 165 | return nil 166 | } else { 167 | eofCount++ 168 | } 169 | } 170 | } 171 | } 172 | 173 | // Fallback for every other command: forward exactly one response packet. NOTE(review): a command the server never answers (e.g. COM_STMT_CLOSE) would block here unless an earlier branch of this function handles it. 174 | _, err := TransportPacket(r.server, client) 175 | return err 176 | } 177 | -------------------------------------------------------------------------------- /internal/proxy/proxy.go: -------------------------------------------------------------------------------- 1 | package proxy 2 | 3 | import ( 4 |
| "context" 5 | "log" 6 | "net" 7 | "os" 8 | "sync/atomic" 9 | "time" 10 | 11 | "github.com/lyuangg/umyproxy/protocol" 12 | ) 13 | 14 | type ( 15 | Proxy struct { 16 | server net.Listener 17 | pool *Pool 18 | socketFile string 19 | debug bool 20 | inShutdown uint32 21 | } 22 | ) 23 | 24 | func NewProxy(p *Pool, socketfile string) *Proxy { 25 | return &Proxy{ 26 | pool: p, 27 | socketFile: socketfile, 28 | } 29 | } 30 | 31 | func (p *Proxy) Run() { 32 | p.deleteSocketFile() 33 | 34 | serv, err := net.Listen("unix", p.socketFile) 35 | if err != nil { 36 | log.Fatalln("Listen err:", err) 37 | } 38 | 39 | 40 | // Set socket file permissions 41 | perm := os.FileMode(0777) 42 | err = os.Chmod(p.socketFile, perm) 43 | if err != nil { 44 | panic(err) 45 | } 46 | 47 | p.server = serv 48 | p.startPrint() 49 | 50 | for { 51 | conn, err := p.server.Accept() 52 | if p.shuttingDown() { 53 | log.Println("shutting down...") 54 | return 55 | } 56 | if err != nil { 57 | log.Fatalln("conn err:", err) 58 | return 59 | } 60 | p.debugPrintf("accept conn") 61 | 62 | go p.HandleConn(conn) 63 | } 64 | } 65 | 66 | func (p *Proxy) SetDebug() { 67 | p.debug = true 68 | p.debugPrintf("debug mode") 69 | } 70 | 71 | func (p *Proxy) debugPrintf(format string, v ...interface{}) { 72 | if p.debug { 73 | format = "[DEBUG]" + format + "\n" 74 | log.Printf(format, v...) 
75 | } 76 | } 77 | 78 | func (p *Proxy) startPrint() { 79 | log.Println("start server: ", p.socketFile) 80 | log.Println("host:", p.pool.option.Host) 81 | log.Println("port:", p.pool.option.Port) 82 | log.Println("pool_size:", p.pool.option.PoolMaxSize) 83 | log.Println("conn_maxlifetime:", p.pool.option.MaxLifetime) 84 | log.Println("wait_timeout:", p.pool.option.WaitTimeout) 85 | } 86 | 87 | func (p *Proxy) HandleConn(conn net.Conn) { 88 | client := protocol.NewConn(conn) 89 | defer client.Close() 90 | 91 | mysqlServ, err := p.Get() 92 | if err != nil { 93 | log.Printf("get mysql conn err: %+v \n", err) 94 | return 95 | } 96 | p.debugPrintf("get mysql conn") 97 | defer p.Put(mysqlServ) 98 | 99 | // 认证 100 | if err := mysqlServ.Auth(client); err != nil { 101 | log.Printf("mysql auth err: %+v \n", err) 102 | return 103 | } 104 | p.debugPrintf("client auth success") 105 | 106 | // 发送命令 107 | for { 108 | cmd, err := client.ReadPacket() 109 | 110 | if err != nil { 111 | log.Println("read cmd err: ", err) 112 | return 113 | } 114 | 115 | p.debugPrintf("read cmd: %+v", cmd) 116 | 117 | if protocol.IsQuitPacket(cmd) { 118 | p.debugPrintf("client quit") 119 | return 120 | } 121 | 122 | err = mysqlServ.WritePacket(cmd) 123 | if err != nil { 124 | log.Printf("write cmd to server err: %+v \n", err) 125 | return 126 | } 127 | 128 | // response 129 | resp := protocol.NewResponse(mysqlServ, cmd.Payload[0]) 130 | err = resp.ResponsePacket(client) 131 | p.debugPrintf("transport response") 132 | if err != nil { 133 | log.Println("transport response err:", err) 134 | return 135 | } 136 | p.debugPrintf("end transport response") 137 | } 138 | 139 | } 140 | 141 | func (p *Proxy) Get() (protocol.Connector, error) { 142 | return p.pool.Get() 143 | } 144 | 145 | func (p *Proxy) Put(conn protocol.Connector) error { 146 | p.debugPrintf("put conn") 147 | return p.pool.Put(conn) 148 | } 149 | 150 | func (p *Proxy) Shutdown(ctx context.Context) error { 151 | atomic.StoreUint32(&p.inShutdown, 
1) 152 | 153 | p.pool.Close() 154 | 155 | // 检查请求 156 | t := time.NewTimer(time.Millisecond * 100) 157 | defer t.Stop() 158 | for { 159 | if p.pool.OpenSize() <= 0 { 160 | return p.server.Close() 161 | } 162 | select { 163 | case <-ctx.Done(): 164 | p.server.Close() 165 | return ctx.Err() 166 | case <-t.C: 167 | t.Reset(time.Millisecond * 100) 168 | } 169 | } 170 | } 171 | 172 | func (p *Proxy) shuttingDown() bool { 173 | if atomic.LoadUint32(&p.inShutdown) == 1 { 174 | return true 175 | } 176 | return false 177 | } 178 | 179 | func (p *Proxy) deleteSocketFile() error { 180 | _, err := os.Stat(p.socketFile) 181 | if err == nil || os.IsExist(err) { 182 | return os.Remove(p.socketFile) 183 | } 184 | return err 185 | } 186 | -------------------------------------------------------------------------------- /internal/proxy/pool.go: -------------------------------------------------------------------------------- 1 | package proxy 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | "github.com/lyuangg/umyproxy/protocol" 7 | "net" 8 | "sync" 9 | "time" 10 | ) 11 | 12 | type ( 13 | PoolOption struct { 14 | Host string 15 | Port int 16 | MaxLifetime time.Duration 17 | PoolMaxSize int 18 | WaitTimeout time.Duration 19 | } 20 | 21 | Pool struct { 22 | option PoolOption 23 | mu sync.Mutex 24 | freeConn []protocol.Connector 25 | openSize int 26 | connRequests map[uint64]chan protocol.Connector 27 | nextRequest uint64 28 | closed bool 29 | createConn ConnCreater 30 | } 31 | 32 | ConnCreater func(string) (protocol.Connector, error) 33 | ) 34 | 35 | func NewPool(option PoolOption) *Pool { 36 | freeConn := make([]protocol.Connector, 0) 37 | connRequest := make(map[uint64]chan protocol.Connector, 0) 38 | var createConn ConnCreater 39 | createConn = NewConnect 40 | return &Pool{option: option, freeConn: freeConn, connRequests: connRequest, createConn: createConn} 41 | } 42 | 43 | func (p *Pool) SetCreater(creater ConnCreater) { 44 | p.createConn = creater 45 | } 46 | 47 | func (p *Pool) Get() 
(protocol.Connector, error) { 48 | p.mu.Lock() 49 | if p.closed { 50 | p.mu.Unlock() 51 | return nil, errors.New("pool closed") 52 | } 53 | 54 | // 空闲连接 55 | freeNum := len(p.freeConn) 56 | if freeNum > 0 { 57 | for i, conn := range p.freeConn { 58 | 59 | // 判断 conn 过期 60 | if !conn.Expired(p.option.MaxLifetime) && !conn.Closed() { 61 | 62 | // 删除 63 | copy(p.freeConn, p.freeConn[i+1:]) 64 | p.freeConn = p.freeConn[:freeNum-i-1] 65 | 66 | conn.RefreshUseTime() 67 | 68 | p.mu.Unlock() 69 | return conn, nil 70 | } 71 | 72 | conn.Close() 73 | p.openSize-- 74 | } 75 | 76 | // clean all 77 | p.freeConn = nil 78 | } 79 | 80 | // 创建新连接 81 | if p.openSize < p.option.PoolMaxSize { 82 | conn, err := p.createConn(fmt.Sprintf("%s:%d", p.option.Host, p.option.Port)) 83 | if err != nil { 84 | p.mu.Unlock() 85 | return nil, fmt.Errorf("new connect err: %w", err) 86 | } 87 | p.openSize++ 88 | p.mu.Unlock() 89 | return conn, nil 90 | } 91 | 92 | // 等待队列 93 | req := make(chan protocol.Connector, 1) 94 | reqKey := p.nextRequest + 1 95 | p.nextRequest = reqKey 96 | p.connRequests[reqKey] = req 97 | p.mu.Unlock() 98 | select { 99 | case <-time.After(p.option.WaitTimeout): 100 | p.mu.Lock() 101 | delete(p.connRequests, reqKey) 102 | p.mu.Unlock() 103 | 104 | // put 105 | select { 106 | default: 107 | case conn, ok := <-req: 108 | if ok && !conn.Closed() && !conn.Expired(p.option.MaxLifetime) { 109 | p.Put(conn) 110 | } 111 | } 112 | 113 | return nil, ErrWaitConnTimeout 114 | case conn, ok := <-req: 115 | if !ok { 116 | return nil, ErrWaitConnTimeout 117 | } 118 | return conn, nil 119 | } 120 | } 121 | 122 | func (p *Pool) Put(conn protocol.Connector) error { 123 | p.mu.Lock() 124 | if p.closed { 125 | conn.Close() 126 | p.openSize-- 127 | p.mu.Unlock() 128 | return ErrPoolClosed 129 | } 130 | 131 | if conn.Expired(p.option.MaxLifetime) || conn.Closed() { 132 | conn.Close() 133 | p.openSize-- 134 | p.mu.Unlock() 135 | return ErrConnExpired 136 | } 137 | conn.RefreshUseTime() 138 | 139 | 
// 请求队列 140 | if len(p.connRequests) > 0 { 141 | for reqKey, ch := range p.connRequests { 142 | ch <- conn 143 | delete(p.connRequests, reqKey) 144 | close(ch) 145 | p.mu.Unlock() 146 | return nil 147 | } 148 | } 149 | 150 | // 放入freeConn 151 | freeNum := len(p.freeConn) 152 | if freeNum >= p.option.PoolMaxSize { 153 | // 删掉一个 154 | copy(p.freeConn, p.freeConn[1:]) 155 | p.freeConn = p.freeConn[:len(p.freeConn)-1] 156 | } 157 | p.freeConn = append(p.freeConn, conn) 158 | p.mu.Unlock() 159 | 160 | return nil 161 | } 162 | 163 | func (p *Pool) Close() { 164 | p.mu.Lock() 165 | p.closed = true 166 | for _, conn := range p.freeConn { 167 | p.openSize-- 168 | conn.Close() 169 | } 170 | p.freeConn = nil 171 | for _, reqCh := range p.connRequests { 172 | close(reqCh) 173 | } 174 | p.mu.Unlock() 175 | } 176 | 177 | func (p *Pool) OpenSize() int { 178 | p.mu.Lock() 179 | defer p.mu.Unlock() 180 | 181 | return p.openSize 182 | } 183 | 184 | func NewConnect(address string) (protocol.Connector, error) { 185 | conn, err := net.DialTimeout("tcp", address, time.Second*2) 186 | if err != nil { 187 | return nil, fmt.Errorf("new tcp connect err: %w", err) 188 | } 189 | tcpconn := conn.(*net.TCPConn) 190 | tcpconn.SetKeepAlive(true) 191 | mysqlConn := protocol.NewConn(tcpconn) 192 | return mysqlConn, nil 193 | } 194 | -------------------------------------------------------------------------------- /protocol/conn.go: -------------------------------------------------------------------------------- 1 | package protocol 2 | 3 | import ( 4 | "fmt" 5 | "io" 6 | "net" 7 | "time" 8 | ) 9 | 10 | type ( 11 | Connector interface { 12 | ReadPacket() (Packet, error) 13 | WritePacket(Packet) error 14 | Auth(Connector) error 15 | Closed() bool 16 | Expired(time.Duration) bool 17 | RefreshUseTime() 18 | Close() error 19 | } 20 | 21 | Conn struct { 22 | c net.Conn 23 | initHandPacket Packet 24 | authSuccessPacket Packet 25 | authSuccess bool 26 | usedTime time.Time 27 | closed bool 28 | } 29 | 30 | ) 
31 | 32 | func NewConn(c net.Conn) Connector { 33 | return &Conn{c:c, usedTime: time.Now()} 34 | } 35 | 36 | func (c *Conn) ReadPacket() (Packet, error) { 37 | 38 | p := Packet{} 39 | if c.Closed() { 40 | return p, ErrConnClosed 41 | } 42 | 43 | // read header 44 | header := make([]byte, 4) 45 | if _, err := io.ReadFull(c.c, header); err != nil { 46 | c.Close() 47 | return p, fmt.Errorf("read packet header err: %w", err) 48 | } 49 | 50 | p.SeqId = uint8(header[3]) 51 | 52 | dataLen := int(uint32(header[0]) | uint32(header[1]) << 8 | uint32(header[2]) << 16) 53 | if dataLen < 1 { 54 | return p, nil 55 | } 56 | 57 | // read body 58 | data := make([]byte, dataLen) 59 | if _, err := io.ReadFull(c.c, data); err != nil { 60 | c.Close() 61 | return p, fmt.Errorf("read packet payload err: %w", err) 62 | } 63 | p.Payload = data 64 | 65 | if dataLen < MAX_PAYLOAD_LEN { 66 | return p, nil 67 | } 68 | 69 | // append split packet 70 | p2, err := c.ReadPacket() 71 | if err != nil { 72 | c.Close() 73 | return p, fmt.Errorf("read split packet err: %w", err) 74 | } 75 | 76 | p.Payload = append(p.Payload, p2.Payload...) 77 | return p, nil 78 | } 79 | 80 | func (c *Conn) WritePacket(p Packet) error { 81 | if c.Closed() { 82 | return ErrConnClosed 83 | } 84 | 85 | ps := p.Split() 86 | for _, p2 := range ps { 87 | writeData := append(p2.Header(), p2.Payload...) 
88 | if n, err := c.c.Write(writeData); err != nil { 89 | return fmt.Errorf("write packet err: %w", err) 90 | } else if n != len(writeData) { 91 | return fmt.Errorf("write packet length err: write(%d) data(%d)", n, len(writeData)) 92 | } 93 | } 94 | return nil 95 | } 96 | 97 | func (c *Conn) Auth(client Connector) error { 98 | if c.authSuccess { 99 | return c.fakeAuth(client) 100 | } 101 | return c.firstAuth(client) 102 | } 103 | 104 | func (c *Conn) firstAuth(client Connector) error { 105 | var err error 106 | initPacket, err := c.ReadPacket() 107 | if err != nil { 108 | return fmt.Errorf("read init packet err: %w", err) 109 | } 110 | c.initHandPacket = initPacket 111 | 112 | // send init packet 113 | err = client.WritePacket(c.initHandPacket) 114 | if err != nil { 115 | return fmt.Errorf("send init err: %w", err) 116 | } 117 | 118 | // read auth packet 119 | authPacket, err := client.ReadPacket() 120 | if err != nil { 121 | return fmt.Errorf("read auth packet err: %w", err) 122 | } 123 | 124 | // send auth to server 125 | err = c.WritePacket(authPacket) 126 | if err != nil { 127 | return fmt.Errorf("send auth packet err: %w", err) 128 | } 129 | 130 | // read auth result 131 | authResult, err := c.ReadPacket() 132 | if err != nil { 133 | return fmt.Errorf("read auth result err: %w", err) 134 | } 135 | 136 | // send auth result 137 | err = client.WritePacket(authResult) 138 | if err != nil { 139 | return fmt.Errorf("send result err: %w", err) 140 | } 141 | 142 | if IsErrPacket(authResult) { 143 | return ErrAuth 144 | } 145 | 146 | c.authSuccessPacket = authResult 147 | c.authSuccess = true 148 | 149 | return nil 150 | } 151 | 152 | func (c *Conn) fakeAuth(client Connector) error { 153 | if c.authSuccess == false { 154 | return ErrNoAuth 155 | } 156 | 157 | var err error 158 | 159 | // send init packet 160 | err = client.WritePacket(c.initHandPacket) 161 | if err != nil { 162 | return fmt.Errorf("send init err: %w", err) 163 | } 164 | 165 | // read auth packet 166 | 
_, err = client.ReadPacket() 167 | if err != nil { 168 | return fmt.Errorf("read auth packet err: %w", err) 169 | } 170 | 171 | // send auth result 172 | err = client.WritePacket(c.authSuccessPacket) 173 | if err != nil { 174 | return fmt.Errorf("send result err: %w", err) 175 | } 176 | 177 | return nil 178 | } 179 | 180 | func (c *Conn) Closed() bool { 181 | return c.closed 182 | } 183 | 184 | func (c *Conn) Expired(t time.Duration) bool { 185 | if time.Now().Sub(c.usedTime) < t { 186 | return false 187 | } 188 | return true 189 | } 190 | 191 | func (c *Conn) RefreshUseTime() { 192 | c.usedTime = time.Now() 193 | } 194 | 195 | func (c *Conn) Close() error { 196 | if c.closed { 197 | return ErrConnClosed 198 | } 199 | c.closed = true 200 | return c.c.Close() 201 | } 202 | --------------------------------------------------------------------------------