├── .gitignore ├── README.md ├── go.mod ├── go.sum ├── main.go ├── protocol ├── flags.go ├── math.go └── protocol.go └── proxy ├── connection.go ├── proxy.go └── proxy_test.go /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | .idea/ -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # go-mysql-proxy 2 | This repository accompanies the series of articles `Writing MySQL Proxy` that I'm posting on Medium. 3 | 4 | 1 - [Writing MySQL Proxy in GO for self-learning: Part 1 — TCP Proxy](https://medium.com/@alexanderravikovich/quarantine-journey-writing-mysql-proxy-in-go-for-self-learning-part-1-tcp-proxy-39810479b7e9?source=friends_link&sk=9b498aca1d0b239228ab294ba09414bb) 5 | 2 - [Writing MySQL Proxy in GO for self-learning: Part 2 — decoding handshake packet](https://medium.com/@alexanderravikovich/writing-mysql-proxy-in-go-for-learning-purposes-part-2-decoding-connection-phase-server-response-7091d87e877e?source=friends_link&sk=c2efb5dfe76e5e061b0679c48e224f2b) 6 | 7 | The main goal is to learn the MySQL Protocol by implementing it. 8 | 9 | The plan: 10 | - [x] Implement TCP Proxy as a starting point 11 | - [ ] Implement state machine 12 | - [ ] Implement query/query data buffering 13 | - [ ] Implement plugins 14 | 15 | Packet decoding/encoding to-do: 16 | - [x] Handshake Packet 17 | - [ ] Authorization Packet 18 | 19 | 20 | go version go1.12.9 21 | 22 | To try it, start a MySQL server on 127.0.0.1:3306, clone the repository, and run the command below; the proxy listens on port 3336: 23 | 24 | ``` 25 | go run . 
26 | ```
27 | 
--------------------------------------------------------------------------------
/go.mod:
--------------------------------------------------------------------------------
1 | module go-mysql-proxy
2 | 
3 | go 1.13
4 | 
5 | require github.com/go-sql-driver/mysql v1.5.0
6 | 
--------------------------------------------------------------------------------
/go.sum:
--------------------------------------------------------------------------------
1 | github.com/go-sql-driver/mysql v1.5.0 h1:ozyZYNQW3x3HtqT1jira07DN2PArx2v7/mN66gGcHOs=
2 | github.com/go-sql-driver/mysql v1.5.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg=
3 | 
--------------------------------------------------------------------------------
/main.go:
--------------------------------------------------------------------------------
1 | package main
2 | 
3 | import (
4 | 	"context"
5 | 	"log"
6 | 	"os"
7 | 	"os/signal"
8 | 
9 | 	proxy2 "go-mysql-proxy/proxy"
10 | )
11 | 
12 | // main runs a MySQL proxy that listens on port 3336, forwards every client
13 | // connection to the MySQL server at 127.0.0.1:3306 with handshake decoding
14 | // enabled, and stops accepting connections once an interrupt (Ctrl+C) arrives.
15 | func main() {
16 | 	ctx, cancel := context.WithCancel(context.Background())
17 | 	defer cancel()
18 | 
19 | 	proxy := proxy2.NewProxy("127.0.0.1", ":3306", ctx)
20 | 	proxy.EnableDecoding()
21 | 
22 | 	// Cancelling ctx makes proxy.Start close its listener and return.
23 | 	signals := make(chan os.Signal, 1)
24 | 	signal.Notify(signals, os.Interrupt)
25 | 	go func() {
26 | 		for sig := range signals {
27 | 			log.Printf("Signal received %v, stopping and exiting...", sig)
28 | 			cancel()
29 | 		}
30 | 	}()
31 | 
32 | 	if err := proxy.Start("3336"); err != nil {
33 | 		log.Fatal(err)
34 | 	}
35 | }
36 | 
--------------------------------------------------------------------------------
/protocol/flags.go:
--------------------------------------------------------------------------------
1 | package protocol
2 | 
3 | import (
4 | 	"fmt"
5 | 	"strings"
6 | )
7 | 
8 | // CapabilityFlag is a set of MySQL client/server capability flags, one bit per
9 | // flag. See https://dev.mysql.com/doc/internals/en/capability-flags.html
10 | type CapabilityFlag uint32
11 | 
12 | /*
13 | Each flag is a number with exactly one bit set, which lets us combine and test
14 | flags with fast bitwise operations. Flags are built with the << operator, and
15 | each shift by one position multiplies the value by 2:
16 | 
17 | 	1 = 00000001
18 | 	2 = 00000010
19 | 	4 = 00000100
20 | 	...
21 | 
22 | To check whether a flag is set, we use the & operator:
23 | 
24 | 	00000111 & 00000001 = 1 => true
25 | 	00000111 & 01000000 = 0 => false
26 | */
27 | 
28 | // Has reports whether flag is set in r. flag is normally a single capability
29 | // bit; for a mask with several bits, Has reports whether any of them is set.
30 | func (r CapabilityFlag) Has(flag CapabilityFlag) bool {
31 | 	return r&flag != 0
32 | }
33 | 
34 | // String lists the known capability flags that are set in r, one per line in
35 | // ascending bit order, formatted as "<hex> - <binary> - <name>". Set bits
36 | // without a known name are omitted.
37 | func (r CapabilityFlag) String() string {
38 | 	var names []string
39 | 
40 | 	for bit := uint(0); bit < 32; bit++ {
41 | 		flag := CapabilityFlag(1) << bit
42 | 		if !r.Has(flag) {
43 | 			continue
44 | 		}
45 | 		if name, ok := flags[flag]; ok {
46 | 			// Format a plain uint32: applying %x to a CapabilityFlag would call
47 | 			// this String method again.
48 | 			value := uint32(flag)
49 | 			names = append(names, fmt.Sprintf("0x%08x - %032b - %s", value, value, name))
50 | 		}
51 | 	}
52 | 
53 | 	return strings.Join(names, "\n")
54 | }
55 | 
56 | // Capability flags in protocol order: iota gives each constant the next bit,
57 | // starting with clientLongPassword = 1 (bit 0).
58 | const (
59 | 	clientLongPassword CapabilityFlag = 1 << iota
60 | 	clientFoundRows
61 | 	clientLongFlag
62 | 	clientConnectWithDB
63 | 	clientNoSchema
64 | 	clientCompress
65 | 	clientODBC
66 | 	clientLocalFiles
67 | 	clientIgnoreSpace
68 | 	clientProtocol41
69 | 	clientInteractive
70 | 	clientSSL
71 | 	clientIgnoreSIGPIPE
72 | 	clientTransactions
73 | 	clientReserved
74 | 	clientSecureConn
75 | 	clientMultiStatements
76 | 	clientMultiResults
77 | 	clientPSMultiResults
78 | 	clientPluginAuth
79 | 	clientConnectAttrs
80 | 	clientPluginAuthLenEncClientData
81 | 	clientCanHandleExpiredPasswords
82 | 	clientSessionTrack
83 | 	clientDeprecateEOF
84 | )
85 | 
86 | // flags maps every known capability flag to its name, for CapabilityFlag.String.
87 | var flags = map[CapabilityFlag]string{
88 | 	clientLongPassword:               "clientLongPassword",
89 | 	clientFoundRows:                  "clientFoundRows",
90 | 	clientLongFlag:                   "clientLongFlag",
91 | 	clientConnectWithDB:              "clientConnectWithDB",
92 | 	clientNoSchema:                   "clientNoSchema",
93 | 	clientCompress:                   "clientCompress",
94 | 	clientODBC:                       "clientODBC",
95 | 	clientLocalFiles:                 "clientLocalFiles",
96 | 	clientIgnoreSpace:                "clientIgnoreSpace",
97 | 	clientProtocol41:                 "clientProtocol41",
98 | 	clientInteractive:                "clientInteractive",
99 | 	clientSSL:                        "clientSSL",
100 | 	clientIgnoreSIGPIPE:              "clientIgnoreSIGPIPE",
101 | 	clientTransactions:               "clientTransactions",
102 | 	clientReserved:                   "clientReserved",
103 | 	clientSecureConn:                 "clientSecureConn",
104 | 	clientMultiStatements:            "clientMultiStatements",
105 | 	clientMultiResults:               "clientMultiResults",
106 | 	clientPSMultiResults:             "clientPSMultiResults",
107 | 	clientPluginAuth:                 "clientPluginAuth",
108 | 	clientConnectAttrs:               "clientConnectAttrs",
109 | 	clientPluginAuthLenEncClientData: "clientPluginAuthLenEncClientData",
110 | 	clientCanHandleExpiredPasswords:  "clientCanHandleExpiredPasswords",
111 | 	clientSessionTrack:               "clientSessionTrack",
112 | 	clientDeprecateEOF:               "clientDeprecateEOF",
113 | }
114 | 
--------------------------------------------------------------------------------
/protocol/math.go:
--------------------------------------------------------------------------------
1 | package protocol
2 | 
3 | // Max returns the larger of x and y.
4 | func Max(x, y int) int {
5 | 	if x > y {
6 | 		return x
7 | 	}
8 | 	return y
9 | }
10 | 
--------------------------------------------------------------------------------
/protocol/protocol.go:
--------------------------------------------------------------------------------
1 | package protocol
2 | 
3 | import (
4 | 	"bytes"
5 | 	"encoding/binary"
6 | 	"errors"
7 | 	"fmt"
8 | 	"io"
9 | 	"net"
10 | )
11 | 
12 | const (
13 | 	// handshakeFixedLen is the size of the fixed-length part of the handshake
14 | 	// payload that follows the server version string: connection id (4),
15 | 	// auth-plugin-data-part-1 (8), filler (1), capability flags lower bytes (2),
16 | 	// character set (1), status flags (2), capability flags upper bytes (2),
17 | 	// auth-plugin-data length (1) and reserved bytes (10).
18 | 	handshakeFixedLen = 31
19 | 
20 | 	// maxPayloadLen is the largest payload a single packet can carry: the
21 | 	// packet header stores the payload length in 3 bytes.
22 | 	maxPayloadLen = 1<<24 - 1
23 | )
24 | 
25 | // errShortPacket is returned when a handshake payload ends before a field
26 | // the decoder expects.
27 | var errShortPacket = errors.New("malformed handshake packet: payload is too short")
28 | 
29 | // PacketHeader represents the 4-byte header that precedes every packet: a
30 | // 3-byte little-endian payload length followed by a 1-byte sequence id.
31 | type PacketHeader struct {
32 | 	Length     uint32
33 | 	SequenceId uint8
34 | }
35 | 
36 | // InitialHandshakePacket represents the initial handshake packet
37 | // (Protocol::HandshakeV10) sent by the MySQL server.
38 | type InitialHandshakePacket struct {
39 | 	ProtocolVersion   uint8
40 | 	ServerVersion     []byte
41 | 	ConnectionId      uint32
42 | 	AuthPluginData    []byte
43 | 	Filler            byte
44 | 	CapabilitiesFlags CapabilityFlag
45 | 	CharacterSet      uint8
46 | 	StatusFlags       uint16
47 | 	AuthPluginDataLen uint8
48 | 	AuthPluginName    []byte
49 | 	header            *PacketHeader
50 | }
51 | 
52 | // readPacket reads exactly one packet from rd and returns its header and
53 | // payload. Unlike a single Read into a fixed-size buffer, it neither stops at
54 | // a partial TCP segment nor fails on payloads larger than the buffer.
55 | func readPacket(rd io.Reader) (*PacketHeader, []byte, error) {
56 | 	var raw [4]byte
57 | 	if _, err := io.ReadFull(rd, raw[:]); err != nil {
58 | 		return nil, nil, fmt.Errorf("reading packet header: %w", err)
59 | 	}
60 | 
61 | 	// The length is a 3-byte little-endian integer; the sequence id is a
62 | 	// single byte, which needs no byte-order handling.
63 | 	header := &PacketHeader{
64 | 		Length:     uint32(raw[0]) | uint32(raw[1])<<8 | uint32(raw[2])<<16,
65 | 		SequenceId: raw[3],
66 | 	}
67 | 
68 | 	payload := make([]byte, header.Length)
69 | 	if _, err := io.ReadFull(rd, payload); err != nil {
70 | 		return nil, nil, fmt.Errorf("reading packet payload: %w", err)
71 | 	}
72 | 
73 | 	return header, payload, nil
74 | }
75 | 
76 | // Decode reads the first packet sent by the MySQL server on conn, the initial
77 | // handshake packet, and populates r from it. It returns an error if the
78 | // packet cannot be read, uses a protocol version other than 10, or is too
79 | // short for the fields its capability flags announce.
80 | func (r *InitialHandshakePacket) Decode(conn net.Conn) error {
81 | 	header, payload, err := readPacket(conn)
82 | 	if err != nil {
83 | 		return err
84 | 	}
85 | 	r.header = header
86 | 
87 | 	// 1 [0a] protocol version: always 10 (0x0a in hex).
88 | 	if len(payload) == 0 {
89 | 		return errShortPacket
90 | 	}
91 | 	r.ProtocolVersion = payload[0]
92 | 	if r.ProtocolVersion != 0x0a {
93 | 		return errors.New("non supported protocol for the proxy. Only version 10 is supported")
94 | 	}
95 | 	position := 1
96 | 
97 | 	// string[NUL] server version: the bytes up to the terminating 0x00.
98 | 	versionEnd := bytes.IndexByte(payload[position:], 0x00)
99 | 	if versionEnd == -1 {
100 | 		return errShortPacket
101 | 	}
102 | 	r.ServerVersion = payload[position : position+versionEnd]
103 | 	position += versionEnd + 1
104 | 
105 | 	// Check the fixed-length part once, so the reads below cannot go out of
106 | 	// range on a malformed packet.
107 | 	if len(payload)-position < handshakeFixedLen {
108 | 		return errShortPacket
109 | 	}
110 | 
111 | 	// 4 connection id
112 | 	r.ConnectionId = binary.LittleEndian.Uint32(payload[position : position+4])
113 | 	position += 4
114 | 
115 | 	// string[8] auth-plugin-data-part-1. Copied into its own slice so that
116 | 	// appending part 2 below never writes into payload.
117 | 	r.AuthPluginData = make([]byte, 8)
118 | 	copy(r.AuthPluginData, payload[position:position+8])
119 | 	position += 8
120 | 
121 | 	// 1 [00] filler
122 | 	r.Filler = payload[position]
123 | 	if r.Filler != 0x00 {
124 | 		return errors.New("failed to decode filler value")
125 | 	}
126 | 	position++
127 | 
128 | 	// 2 capability flags (lower 2 bytes)
129 | 	capLow := binary.LittleEndian.Uint16(payload[position : position+2])
130 | 	position += 2
131 | 
132 | 	// 1 character set
133 | 	r.CharacterSet = payload[position]
134 | 	position++
135 | 
136 | 	// 2 status flags
137 | 	r.StatusFlags = binary.LittleEndian.Uint16(payload[position : position+2])
138 | 	position += 2
139 | 
140 | 	// 2 capability flags (upper 2 bytes)
141 | 	capHigh := binary.LittleEndian.Uint16(payload[position : position+2])
142 | 	position += 2
143 | 
144 | 	// Reconstruct the 32-bit capability set from its two 16-bit halves.
145 | 	r.CapabilitiesFlags = CapabilityFlag(uint32(capLow) | uint32(capHigh)<<16)
146 | 
147 | 	// 1 length of auth-plugin-data if CLIENT_PLUGIN_AUTH is set, otherwise [00].
148 | 	r.AuthPluginDataLen = 0
149 | 	if r.CapabilitiesFlags.Has(clientPluginAuth) {
150 | 		r.AuthPluginDataLen = payload[position]
151 | 		if r.AuthPluginDataLen == 0 {
152 | 			return errors.New("wrong auth plugin data len")
153 | 		}
154 | 	}
155 | 	position++
156 | 
157 | 	// string[10] reserved (all [00])
158 | 	position += 10
159 | 
160 | 	// With CLIENT_SECURE_CONNECTION the client hashes the password as described in
161 | 	// https://dev.mysql.com/doc/internals/en/secure-password-authentication.html#packet-Authentication::Native41
162 | 	// and the packet carries auth-plugin-data-part-2. The full auth-plugin-data
163 | 	// is the concatenation of part 1 and part 2.
164 | 	if r.CapabilitiesFlags.Has(clientSecureConn) {
165 | 		end := position + Max(13, int(r.AuthPluginDataLen)-8)
166 | 		if end > len(payload) {
167 | 			return errShortPacket
168 | 		}
169 | 		r.AuthPluginData = append(r.AuthPluginData, payload[position:end]...)
170 | 		position = end
171 | 	}
172 | 
173 | 	// string[NUL] auth-plugin name, present only with CLIENT_PLUGIN_AUTH.
174 | 	// Due to Bug#59453 the name is missing its terminating NUL in versions prior
175 | 	// to 5.5.10 and 5.6.2, so without a NUL the rest of the payload is the name.
176 | 	r.AuthPluginName = nil
177 | 	if r.CapabilitiesFlags.Has(clientPluginAuth) {
178 | 		rest := payload[position:]
179 | 		if nameEnd := bytes.IndexByte(rest, 0x00); nameEnd != -1 {
180 | 			r.AuthPluginName = rest[:nameEnd]
181 | 		} else {
182 | 			r.AuthPluginName = rest
183 | 		}
184 | 	}
185 | 
186 | 	return nil
187 | }
188 | 
189 | // Encode serializes r, including the 4-byte packet header, into the wire
190 | // format that Decode reads. The sequence id comes from the decoded header, or
191 | // is 0 when r was not populated by Decode. As in Decode,
192 | // auth-plugin-data-part-2 is written only with CLIENT_SECURE_CONNECTION and
193 | // the auth-plugin name only with CLIENT_PLUGIN_AUTH.
194 | func (r InitialHandshakePacket) Encode() ([]byte, error) {
195 | 	if len(r.AuthPluginData) < 8 {
196 | 		return nil, errors.New("auth plugin data must contain at least 8 bytes")
197 | 	}
198 | 
199 | 	capabilities := uint32(r.CapabilitiesFlags)
200 | 
201 | 	payload := make([]byte, 0, 3+handshakeFixedLen+len(r.ServerVersion)+len(r.AuthPluginData)+len(r.AuthPluginName))
202 | 	payload = append(payload, r.ProtocolVersion)
203 | 	payload = append(payload, r.ServerVersion...)
204 | 	payload = append(payload, 0x00)
205 | 	payload = appendUint32(payload, r.ConnectionId)
206 | 	payload = append(payload, r.AuthPluginData[:8]...)
207 | 	// filler
208 | 	payload = append(payload, 0x00)
209 | 	// capability flags (lower 2 bytes), character set, status flags,
210 | 	// capability flags (upper 2 bytes), auth-plugin-data length
211 | 	payload = appendUint16(payload, uint16(capabilities))
212 | 	payload = append(payload, r.CharacterSet)
213 | 	payload = appendUint16(payload, r.StatusFlags)
214 | 	payload = appendUint16(payload, uint16(capabilities>>16))
215 | 	payload = append(payload, r.AuthPluginDataLen)
216 | 	// string[10] reserved (all [00])
217 | 	payload = append(payload, make([]byte, 10)...)
218 | 
219 | 	if r.CapabilitiesFlags.Has(clientSecureConn) {
220 | 		payload = append(payload, r.AuthPluginData[8:]...)
221 | 	}
222 | 	if r.CapabilitiesFlags.Has(clientPluginAuth) {
223 | 		payload = append(payload, r.AuthPluginName...)
224 | 		payload = append(payload, 0x00)
225 | 	}
226 | 
227 | 	if len(payload) > maxPayloadLen {
228 | 		return nil, fmt.Errorf("handshake payload of %d bytes exceeds the %d-byte packet limit", len(payload), maxPayloadLen)
229 | 	}
230 | 
231 | 	var sequenceId uint8
232 | 	if r.header != nil {
233 | 		sequenceId = r.header.SequenceId
234 | 	}
235 | 
236 | 	// Header: 3-byte little-endian payload length, then the sequence id.
237 | 	length := uint32(len(payload))
238 | 	packet := make([]byte, 0, 4+len(payload))
239 | 	packet = append(packet, byte(length), byte(length>>8), byte(length>>16), sequenceId)
240 | 	packet = append(packet, payload...)
241 | 
242 | 	return packet, nil
243 | }
244 | 
245 | // appendUint16 appends v to buf in little-endian byte order.
246 | func appendUint16(buf []byte, v uint16) []byte {
247 | 	return append(buf, byte(v), byte(v>>8))
248 | }
249 | 
250 | // appendUint32 appends v to buf in little-endian byte order.
251 | func appendUint32(buf []byte, v uint32) []byte {
252 | 	return append(buf, byte(v), byte(v>>8), byte(v>>16), byte(v>>24))
253 | }
254 | 
255 | // String lists the capability flags advertised in the packet, one per line.
256 | func (r InitialHandshakePacket) String() string {
257 | 	return r.CapabilitiesFlags.String()
258 | }
259 | 
--------------------------------------------------------------------------------
/proxy/connection.go:
--------------------------------------------------------------------------------
1 | package proxy
2 | 
3 | import (
4 | 	"fmt"
5 | 	"io"
6 | 	"log"
7 | 	"net"
8 | 	"strings"
9 | 	"sync/atomic"
10 | 
11 | 	"go-mysql-proxy/protocol"
12 | )
13 | 
14 | // NewConnection creates a Connection that relays the accepted client
15 | // connection conn to the MySQL server at host and port. port may be given with
16 | // or without a leading colon (":3306" or "3306"). id tags log lines and errors.
17 | func NewConnection(host string, port string, conn net.Conn, id uint64, enableDecoding bool) *Connection {
18 | 	return &Connection{
19 | 		host:           host,
20 | 		port:           port,
21 | 		conn:           conn,
22 | 		id:             id,
23 | 		enableDecoding: enableDecoding,
24 | 	}
25 | }
26 | 
27 | // Connection proxies one client connection to the MySQL server.
28 | type Connection struct {
29 | 	id             uint64
30 | 	conn           net.Conn
31 | 	host           string
32 | 	port           string
33 | 	enableDecoding bool
34 | }
35 | 
36 | // Handle dials the MySQL server and relays traffic in both directions until
37 | // either side disconnects; both connections are closed before it returns,
38 | // including on error. When decoding is enabled, the server's initial handshake
39 | // packet is decoded, printed and re-encoded before it reaches the client.
40 | func (r *Connection) Handle() error {
41 | 	address := net.JoinHostPort(r.host, strings.TrimPrefix(r.port, ":"))
42 | 	mysql, err := net.Dial("tcp", address)
43 | 	if err != nil {
44 | 		r.conn.Close()
45 | 		return fmt.Errorf("connecting to MySQL at %s: [%d] %w", address, r.id, err)
46 | 	}
47 | 
48 | 	if r.enableDecoding {
49 | 		if err := r.forwardHandshake(mysql); err != nil {
50 | 			r.conn.Close()
51 | 			mysql.Close()
52 | 			return err
53 | 		}
54 | 	}
55 | 
56 | 	r.pipe(mysql)
57 | 	return nil
58 | }
59 | 
60 | // forwardHandshake reads the initial handshake packet from mysql, prints it,
61 | // and writes the re-encoded packet to the client.
62 | func (r *Connection) forwardHandshake(mysql net.Conn) error {
63 | 	handshakePacket := &protocol.InitialHandshakePacket{}
64 | 	if err := handshakePacket.Decode(mysql); err != nil {
65 | 		return fmt.Errorf("decoding initial handshake packet: [%d] %w", r.id, err)
66 | 	}
67 | 
68 | 	fmt.Printf("InitialHandshakePacket:\n%s\n", handshakePacket)
69 | 
70 | 	res, err := handshakePacket.Encode()
71 | 	if err != nil {
72 | 		return fmt.Errorf("encoding initial handshake packet: [%d] %w", r.id, err)
73 | 	}
74 | 
75 | 	if written, err := r.conn.Write(res); err != nil {
76 | 		return fmt.Errorf("writing handshake to client after %d bytes: [%d] %w", written, r.id, err)
77 | 	}
78 | 
79 | 	return nil
80 | }
81 | 
82 | // pipe copies bytes between the client and mysql in both directions. As soon
83 | // as one direction ends, both connections are closed so that the other
84 | // direction unblocks as well; pipe returns once both copies have finished.
85 | func (r *Connection) pipe(mysql net.Conn) {
86 | 	var closed int32
87 | 	closeBoth := func() {
88 | 		if atomic.CompareAndSwapInt32(&closed, 0, 1) {
89 | 			r.conn.Close()
90 | 			mysql.Close()
91 | 		}
92 | 	}
93 | 
94 | 	transfer := func(dst, src net.Conn) {
95 | 		copied, err := io.Copy(dst, src)
96 | 		// After closeBoth has run, an error here is the expected result of
97 | 		// tearing down the other direction, not a failure worth reporting.
98 | 		if err != nil && atomic.LoadInt32(&closed) == 0 {
99 | 			log.Printf("Connection error: [%d] %s", r.id, err.Error())
100 | 		}
101 | 		log.Printf("Connection closed. Bytes copied: [%d] %d", r.id, copied)
102 | 		closeBoth()
103 | 	}
104 | 
105 | 	done := make(chan struct{})
106 | 	go func() {
107 | 		defer close(done)
108 | 		// client to server
109 | 		transfer(mysql, r.conn)
110 | 	}()
111 | 
112 | 	// server to client
113 | 	transfer(r.conn, mysql)
114 | 	<-done
115 | }
116 | 
--------------------------------------------------------------------------------
/proxy/proxy.go:
--------------------------------------------------------------------------------
1 | package proxy
2 | 
3 | import (
4 | 	"context"
5 | 	"fmt"
6 | 	"log"
7 | 	"net"
8 | 	"time"
9 | )
10 | 
11 | // NewProxy creates a Proxy that forwards client connections to the MySQL
12 | // server at host and port (for example "127.0.0.1" and ":3306"). Cancelling
13 | // ctx stops a running Start.
14 | func NewProxy(host, port string, ctx context.Context) *Proxy {
15 | 	return &Proxy{
16 | 		host: host,
17 | 		port: port,
18 | 		ctx:  ctx,
19 | 	}
20 | }
21 | 
22 | // Proxy accepts client connections and relays each of them to a MySQL server.
23 | type Proxy struct {
24 | 	host           string
25 | 	port           string
26 | 	connectionId   uint64
27 | 	enableDecoding bool
28 | 	ctx            context.Context
29 | }
30 | 
31 | // Start listens on the given local port and serves client connections, each
32 | // in its own goroutine, until the Proxy's context is cancelled; it then closes
33 | // the listener and returns nil. Temporary accept errors are retried with a
34 | // capped backoff; any other accept error stops the proxy and is returned.
35 | func (r *Proxy) Start(port string) error {
36 | 	log.Printf("Start listening on: %s", port)
37 | 	ln, err := net.Listen("tcp", fmt.Sprintf(":%s", port))
38 | 	if err != nil {
39 | 		return err
40 | 	}
41 | 	defer ln.Close()
42 | 
43 | 	// stopped lets the watcher goroutine exit when Start returns for a reason
44 | 	// other than cancellation, instead of lingering until ctx is done.
45 | 	stopped := make(chan struct{})
46 | 	defer close(stopped)
47 | 
48 | 	go func() {
49 | 		log.Printf("Waiting for shut down signal ^C")
50 | 		select {
51 | 		case <-r.ctx.Done():
52 | 			log.Printf("Shut down signal received, closing connections...")
53 | 			// Closing the listener unblocks the Accept call below.
54 | 			ln.Close()
55 | 		case <-stopped:
56 | 		}
57 | 	}()
58 | 
59 | 	var retryDelay time.Duration
60 | 	for {
61 | 		conn, err := ln.Accept()
62 | 		if err != nil {
63 | 			// ctx.Err is safe to read from any goroutine, unlike a plain bool
64 | 			// written by the watcher goroutine.
65 | 			if r.ctx.Err() != nil {
66 | 				log.Printf("Shutdown asked [%d]", r.connectionId)
67 | 				return nil
68 | 			}
69 | 
70 | 			if ne, ok := err.(net.Error); ok && ne.Temporary() {
71 | 				// Back off (e.g. on EMFILE) instead of spinning on Accept.
72 | 				if retryDelay == 0 {
73 | 					retryDelay = 5 * time.Millisecond
74 | 				} else {
75 | 					retryDelay *= 2
76 | 				}
77 | 				if retryDelay > time.Second {
78 | 					retryDelay = time.Second
79 | 				}
80 | 				log.Printf("Failed to accept new connection: %s", err.Error())
81 | 				time.Sleep(retryDelay)
82 | 				continue
83 | 			}
84 | 
85 | 			return fmt.Errorf("accepting connections on port %s: %w", port, err)
86 | 		}
87 | 		retryDelay = 0
88 | 
89 | 		r.connectionId++
90 | 		log.Printf("Connection accepted: [%d] %s", r.connectionId, conn.RemoteAddr())
91 | 		go r.handle(conn, r.connectionId, r.enableDecoding)
92 | 	}
93 | }
94 | 
95 | // handle serves one client connection and logs the error it ends with, if any.
96 | func (r *Proxy) handle(conn net.Conn, connectionId uint64, enableDecoding bool) {
97 | 	connection := NewConnection(r.host, r.port, conn, connectionId, enableDecoding)
98 | 	if err := connection.Handle(); err != nil {
99 | 		log.Printf("Error handling proxy connection: %s", err.Error())
100 | 	}
101 | }
102 | 
103 | // EnableDecoding makes connections accepted afterwards decode the server's
104 | // initial handshake packet instead of relaying it untouched. Call it before
105 | // Start; it is not synchronized with a running accept loop.
106 | func (r *Proxy) EnableDecoding() {
107 | 	r.enableDecoding = true
108 | }
109 | 
--------------------------------------------------------------------------------
/proxy/proxy_test.go:
--------------------------------------------------------------------------------
1 | package proxy
2 | 
3 | import (
4 | 	"context"
5 | 	"database/sql"
6 | 	"sync"
7 | 	"testing"
8 | 	"time"
9 | 
10 | 	_ "github.com/go-sql-driver/mysql"
11 | )
12 | 
13 | // Test_Proxy is an integration test: it expects a MySQL server on
14 | // 127.0.0.1:3306 with a proxydb database, a users(id, name) table and a
15 | // dbuser/dbpassword account. It runs the proxy on port 3336 and queries the
16 | // users table through it. It is skipped when tests run with -short.
17 | func Test_Proxy(t *testing.T) {
18 | 	if testing.Short() {
19 | 		t.Skip("integration test: needs a MySQL server on 127.0.0.1:3306")
20 | 	}
21 | 
22 | 	ctx, cancel := context.WithCancel(context.Background())
23 | 	var wg sync.WaitGroup
24 | 	// Stop the proxy and wait for Start to return even if the test fails early.
25 | 	defer func() {
26 | 		cancel()
27 | 		wg.Wait()
28 | 	}()
29 | 
30 | 	wg.Add(1)
31 | 	go func() {
32 | 		defer wg.Done()
33 | 		proxy := NewProxy("127.0.0.1", ":3306", ctx)
34 | 		proxy.EnableDecoding()
35 | 		// t.Fatal may only be called from the test goroutine, so report with Errorf.
36 | 		if err := proxy.Start("3336"); err != nil {
37 | 			t.Errorf("proxy stopped with error: %s", err.Error())
38 | 		}
39 | 	}()
40 | 
41 | 	// Give the proxy time to start listening.
42 | 	time.Sleep(2 * time.Second)
43 | 
44 | 	db, err := sql.Open("mysql", "dbuser:dbpassword@tcp(localhost:3336)/proxydb")
45 | 	if err != nil {
46 | 		t.Fatal(err)
47 | 	}
48 | 	defer db.Close()
49 | 
50 | 	if err := db.Ping(); err != nil {
51 | 		t.Fatalf("Failed to connect to db: %s", err.Error())
52 | 	}
53 | 
54 | 	type User struct {
55 | 		Id   int64
56 | 		Name string
57 | 	}
58 | 
59 | 	query := "SELECT id, name FROM users"
60 | 	rows, err := db.Query(query)
61 | 	if err != nil {
62 | 		t.Fatalf("Failed to query db: %s (%s)", query, err.Error())
63 | 	}
64 | 	defer rows.Close()
65 | 
66 | 	for rows.Next() {
67 | 		user := &User{}
68 | 		if err := rows.Scan(&user.Id, &user.Name); err != nil {
69 | 			t.Fatalf("Failed to scan row: %s", err.Error())
70 | 		}
71 | 		t.Logf("User fetched, id: %d, name: %s", user.Id, user.Name)
72 | 	}
73 | 
74 | 	// rows.Err reports the error, if any, that ended the iteration early.
75 | 	if err := rows.Err(); err != nil {
76 | 		t.Fatalf("Failed to fetch all data: %s", err.Error())
77 | 	}
78 | }
79 | 
--------------------------------------------------------------------------------