├── .gitignore ├── Dockerfile ├── LICENSE ├── README.md ├── gateway ├── config.go ├── gateway.go ├── relay.go └── tls.go ├── go.mod ├── go.sum ├── main.go ├── mysql ├── buffer.go ├── compress.go ├── conn.go ├── conn_test.go ├── constants.go ├── packet_err.go ├── packet_handshake.go ├── packet_handshake_response.go ├── protocol_test.go └── util.go └── utility └── logger.go /.gitignore: -------------------------------------------------------------------------------- 1 | # Binaries for programs and plugins 2 | *.exe 3 | *.exe~ 4 | *.dll 5 | *.so 6 | *.dylib 7 | 8 | # Test binary, built with `go test -c` 9 | *.test 10 | 11 | # Output of the go coverage tool, specifically when used with LiteIDE 12 | *.out 13 | 14 | # Dependency directories (remove the comment below to include it) 15 | # vendor/ 16 | 17 | tidb-gateway 18 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | # TODO: a golang docker build example is provided by default, remove this if it is not applicable to your service 2 | # Build the tidb-gateway binary 3 | FROM golang:1.18 as builder 4 | 5 | WORKDIR /workspace 6 | 7 | # Copy the go source 8 | COPY . . 9 | 10 | # Build 11 | RUN CGO_ENABLED=0 go build 12 | 13 | FROM alpine:3.10 14 | WORKDIR / 15 | 16 | COPY --from=builder /workspace/tidb-gateway . 17 | 18 | USER 65532:65532 19 | 20 | ENTRYPOINT ["/tidb-gateway"] -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document.
11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # tidb-gateway 2 | 3 | Manage client connections to multiple TiDB instances. 
4 | 
5 | mysql 客户端可以通过 gateway 连接不同的后端 TiDB 集群。
6 | 
7 | ## How
8 | 
9 | 为了兼容尽可能多的 driver,我们修改了 UserName 字段(即 mysql 的 -u 参数)来指定后端集群。
10 | 
11 | 规则是 `username = {clusterid}.{username}`。
12 | 
13 | 
14 | ```mermaid
15 | sequenceDiagram
16 |     client->>tidb-gateway: connect
17 |     tidb-gateway->>client: InitialHandshake
18 |     client->>tidb-gateway: HandshakeResponse
19 |     Note over tidb-gateway: extract clusterID from UserName
20 |     tidb-gateway->>tidb: connect
21 |     tidb->>tidb-gateway: InitialHandshake
22 |     tidb-gateway->>tidb: HandshakeResponse
23 |     Note over client,tidb: Continue exchanging data
24 | ```
25 | 
26 | ## Example
27 | 
28 | ```bash
29 | # start tidb1 (localhost:4000)
30 | > ./tidb-server
31 | # start tidb2 (localhost:4001)
32 | > ./tidb-server -P 4001 -status 10081 -path /tmp/tidb2
33 | 
34 | # start tidb-gateway (localhost:3306)
35 | > ./tidb-gateway --addr :3306 --backend tidb1=localhost:4000 --backend tidb2=localhost:4001
36 | 
37 | # connect tidb1
38 | > mysql -h 127.0.0.1 -P 3306 -u tidb1.root -D test
39 | 
40 | # connect tidb2
41 | > mysql -h 127.0.0.1 -P 3306 -u tidb2.root -D test
42 | ```
43 | 
--------------------------------------------------------------------------------
/gateway/config.go:
--------------------------------------------------------------------------------
1 | package gateway
2 | 
3 | import (
4 | 	"errors"
5 | 	"strings"
6 | )
7 | 
8 | // BackendConfig maps a cluster ID (the prefix of a client's user name) to
9 | // the address of the backend that serves that cluster.
10 | type BackendConfig struct {
11 | 	ClusterID string
12 | 	Address   string
13 | }
14 | 
15 | // BackendConfigs is the list of configured backends. *BackendConfigs
16 | // implements flag.Value, so it can be filled from a repeatable flag.
17 | type BackendConfigs []BackendConfig
18 | 
19 | // String implements flag.Value. It returns a fixed description rather than
20 | // the configured backends.
21 | func (b *BackendConfigs) String() string {
22 | 	return "backend clusters"
23 | }
24 | 
25 | // Set implements flag.Value. It parses value as "clusterID=address" (split at
26 | // the first '=') and appends the result; both parts must be non-empty.
27 | func (b *BackendConfigs) Set(value string) error {
28 | 	clusterID, address, found := strings.Cut(value, "=")
29 | 	if !found || clusterID == "" || address == "" {
30 | 		return errors.New("backend must be in the form of clusterID=address")
31 | 	}
32 | 	*b = append(*b, BackendConfig{ClusterID: clusterID, Address: address})
33 | 	return nil
34 | }
35 | 
36 | // Find returns the address of the backend whose ClusterID equals cluster,
37 | // ignoring case. When no backend matches, cluster itself is returned so the
38 | // caller can use it directly as a host name.
39 | //
40 | // NOTE(review): this fallback lets any client make the gateway dial a host of
41 | // its choosing via the user name; confirm that open relaying is intended.
42 | func (b *BackendConfigs) Find(cluster string) string {
43 | 	for _, c := range *b {
44 | 		if strings.EqualFold(c.ClusterID, cluster) {
45 | 			return c.Address
46 | 		}
47 | 	}
48 | 	return cluster
49 | }
50 | 
51 | // TLSConfig is used to establish TLS connection. CA, Cert and Key are file
52 | // paths; MinVersion is one of "TLSv1.0", "TLSv1.1", "TLSv1.2" or "TLSv1.3"
53 | // (empty means TLSv1.2).
54 | type TLSConfig struct {
55 | 	CA         string
56 | 	Cert       string
57 | 	Key        string
58 | 	MinVersion string
59 | }
60 | 
61 | // Config is used to configure a gateway.
62 | type Config struct {
63 | 	TLS                      TLSConfig
64 | 	BackendConfigs           BackendConfigs
65 | 	// NOTE(review): EnableCompression is not read anywhere in this package;
66 | 	// compression is negotiated from the client's capability flags instead.
67 | 	EnableCompression        bool
68 | 	// BackendInsecureTransport keeps backend connections in plaintext even
69 | 	// when the client itself connected over TLS.
70 | 	BackendInsecureTransport bool
71 | }
72 | 
--------------------------------------------------------------------------------
/gateway/gateway.go:
--------------------------------------------------------------------------------
1 | package gateway
2 | 
3 | import (
4 | 	"bytes"
5 | 	"crypto/tls"
6 | 	"errors"
7 | 	"net"
8 | 	"regexp"
9 | 	"strings"
10 | 	"sync"
11 | 	"sync/atomic"
12 | 	"time"
13 | 
14 | 	"github.com/oh-my-tidb/tidb-gateway/mysql"
15 | 	"github.com/oh-my-tidb/tidb-gateway/utility"
16 | 	"go.uber.org/zap"
17 | )
18 | 
19 | const (
20 | 	// defaultBackendPort is appended to backend addresses that carry no port.
21 | 	defaultBackendPort = "4000"
22 | 	// backendDialTimeout bounds how long connecting to a backend may take, so
23 | 	// an unreachable host cannot hold a connection goroutine indefinitely.
24 | 	backendDialTimeout = 10 * time.Second
25 | )
26 | 
27 | // portSuffix matches addresses that already end in an explicit ":port".
28 | var portSuffix = regexp.MustCompile(`:\d+$`)
29 | 
30 | // Gateway accepts MySQL client connections and relays each one to the
31 | // backend cluster selected by the client's user name.
32 | type Gateway struct {
33 | 	log          *zap.SugaredLogger
34 | 	l            net.Listener
35 | 	conf         *Config
36 | 	tlsConf      *tls.Config // nil when no TLS material is configured
37 | 	quit         chan struct{}
38 | 	wg           sync.WaitGroup
39 | 	connectionID uint32
40 | }
41 | 
42 | // New creates a Gateway that serves on l. It fails if the TLS material
43 | // referenced by conf cannot be loaded. Call StartServe to accept connections.
44 | func New(l net.Listener, conf *Config) (*Gateway, error) {
45 | 	tlsConfig, err := loadTLSConfig(conf.TLS.CA, conf.TLS.Cert, conf.TLS.Key, conf.TLS.MinVersion)
46 | 	if err != nil {
47 | 		return nil, err
48 | 	}
49 | 
50 | 	return &Gateway{
51 | 		log:     utility.GetLogger(),
52 | 		conf:    conf,
53 | 		tlsConf: tlsConfig,
54 | 		l:       l,
55 | 		quit:    make(chan struct{}),
56 | 	}, nil
57 | }
58 | 
59 | // Stop closes the listener, tells active relays to finish and waits for all
60 | // connection goroutines to exit. It must be called at most once because it
61 | // closes the quit channel.
62 | //
63 | // NOTE(review): connections still in the handshake or auth phase do not
64 | // watch quit, so Stop also waits for their pending reads to return.
65 | func (g *Gateway) Stop() {
66 | 	g.log.Info("gateway starts to stop")
67 | 	close(g.quit)
68 | 	_ = g.l.Close()
69 | 	g.wg.Wait()
70 | 	_ = g.log.Sync()
71 | }
72 | 
73 | // StartServe starts accepting connections on a background goroutine.
74 | func (g *Gateway) StartServe() {
75 | 	g.wg.Add(1)
76 | 	go g.serve()
77 | }
78 | 
79 | // serve accepts connections until the listener fails or is closed by Stop,
80 | // handling each connection on its own goroutine.
81 | func (g *Gateway) serve() {
82 | 	defer g.wg.Done()
83 | 	g.log.Info("gateway starts to accept connections")
84 | 	for {
85 | 		conn, err := g.l.Accept()
86 | 		if err != nil {
87 | 			select {
88 | 			case <-g.quit:
89 | 				// Stop closed the listener: normal shutdown.
90 | 			default:
91 | 				g.log.Errorw("failed to accept connection, gateway stops serving", "err", err)
92 | 			}
93 | 			return
94 | 		}
95 | 		g.wg.Add(1)
96 | 		go g.handleConn(conn)
97 | 	}
98 | }
99 | 
100 | // handleConn serves one client connection end to end:
101 | //  1. greet the client and read its handshake response, upgrading to TLS
102 | //     when the client asks for it;
103 | //  2. pick the backend from the "{clusterID}.{username}" user name;
104 | //  3. replay the adjusted handshake response to the backend;
105 | //  4. relay the authentication exchange, then all traffic until either
106 | //     side closes or the gateway stops.
107 | func (g *Gateway) handleConn(rawConn net.Conn) {
108 | 	defer g.wg.Done()
109 | 
110 | 	connID := atomic.AddUint32(&g.connectionID, 1)
111 | 	// TODO: set keepalive and nodelay options
112 | 	g.log.Infow("accepting new connection", "connID", connID)
113 | 	conn := mysql.NewConn(rawConn)
114 | 	defer conn.Close()
115 | 
116 | 	if err := g.sendInitialHandshake(conn, connID); err != nil {
117 | 		g.log.Warnw("failed to send initial handshake", "connID", connID, "err", err)
118 | 		return
119 | 	}
120 | 
121 | 	res, err := g.recvHandshakeResponse(conn)
122 | 	if err != nil {
123 | 		g.log.Warnw("failed to recv handshake response", "connID", connID, "err", err)
124 | 		return
125 | 	}
126 | 
127 | 	if res.Capability&mysql.ClientSSL != 0 {
128 | 		// tls.Server documents that its config must be non-nil; refuse a
129 | 		// client that asks for TLS when none is configured instead of
130 | 		// crashing in the handshake.
131 | 		if g.tlsConf == nil {
132 | 			g.log.Warnw("client requested tls but no tls config is loaded", "connID", connID)
133 | 			return
134 | 		}
135 | 		tlsConn := tls.Server(rawConn, g.tlsConf)
136 | 		if err := tlsConn.Handshake(); err != nil {
137 | 			g.log.Warnw("failed to upgrade to tls connection", "connID", connID, "err", err)
138 | 			return
139 | 		}
140 | 		conn.SetRawConn(tlsConn)
141 | 		// The client sends its handshake response again, now over TLS.
142 | 		res, err = g.recvHandshakeResponse(conn)
143 | 		if err != nil {
144 | 			g.log.Warnw("failed to recv handshake response", "connID", connID, "err", err)
145 | 			return
146 | 		}
147 | 	}
148 | 
149 | 	enableCompress := res.Capability&mysql.ClientCompress != 0
150 | 
151 | 	backendAddr, err := g.getBackendAddr(res)
152 | 	if err != nil {
153 | 		g.log.Warnw("failed to get cluster address", "connID", connID, "err", err)
154 | 		g.sendErr(conn, err.Error())
155 | 		return
156 | 	}
157 | 
158 | 	g.log.Infow("start to connect backend", "connID", connID, "backend", backendAddr)
159 | 
160 | 	backendConn, err := g.connectBackend(backendAddr)
161 | 	if err != nil {
162 | 		g.log.Errorw("failed to connect backend", "connID", connID, "err", err)
163 | 		g.sendErr(conn, err.Error())
164 | 		return
165 | 	}
166 | 	defer backendConn.Close()
167 | 
168 | 	// The backend's InitialHandshake is read only to advance the protocol;
169 | 	// its content is not inspected. The client's response is replayed to the
170 | 	// backend and the two sides are expected to agree on the rest.
171 | 	if _, err := g.recvInitialHandshake(backendConn); err != nil {
172 | 		g.log.Errorw("recv initial handshake from backend failed", "connID", connID, "err", err)
173 | 		g.sendErr(conn, err.Error())
174 | 		return
175 | 	}
176 | 
177 | 	// Always connect backend without compression.
178 | 	// TiDB allows it even if it has compression enabled.
179 | 	res.Capability &^= mysql.ClientCompress
180 | 
181 | 	if g.conf.BackendInsecureTransport {
182 | 		// Keep the backend leg in plaintext: without ClientSSL the TLS
183 | 		// upgrade below is skipped.
184 | 		res.Capability &^= mysql.ClientSSL
185 | 	}
186 | 
187 | 	// Change auth plugin to an invalid name that backend does not know.
188 | 	// Backend will send a SwitchMethod to complete auth process.
189 | 	res.Capability |= mysql.ClientPluginAuth
190 | 	res.AuthPlugin = mysql.AuthInvalidMethod
191 | 
192 | 	if err := backendConn.SendPacket(res); err != nil {
193 | 		g.log.Errorw("failed to send handshake response to backend", "connID", connID, "err", err)
194 | 		g.sendErr(conn, err.Error())
195 | 		return
196 | 	}
197 | 
198 | 	if res.Capability&mysql.ClientSSL != 0 {
199 | 		// NOTE(review): the full handshake response (user name included) was
200 | 		// just sent in plaintext; the MySQL protocol only needs a short
201 | 		// SSLRequest before the upgrade. The backend certificate is not
202 | 		// verified (InsecureSkipVerify).
203 | 		tlsConn := tls.Client(backendConn.RawConn(), &tls.Config{InsecureSkipVerify: true}) // nolint: gosec // nolint
204 | 		if err := tlsConn.Handshake(); err != nil {
205 | 			g.log.Errorw("failed to upgrade to tls connection with backend", "connID", connID, "err", err)
206 | 			g.sendErr(conn, err.Error())
207 | 			return
208 | 		}
209 | 		backendConn.SetRawConn(tlsConn)
210 | 		if err := backendConn.SendPacket(res); err != nil {
211 | 			g.log.Errorw("failed to send handshake response to backend", "connID", connID, "err", err)
212 | 			g.sendErr(conn, err.Error())
213 | 			return
214 | 		}
215 | 	}
216 | 
217 | 	if err := g.exchangeAuth(conn, backendConn); err != nil {
218 | 		g.log.Errorw("failed to exchanage auth", "connID", connID, "err", err)
219 | 		return
220 | 	}
221 | 
222 | 	g.log.Infow("start to relay data", "connID", connID, "backend", backendAddr)
223 | 
224 | 	if enableCompress {
225 | 		conn.EnableCompression()
226 | 		err = RelayPackets(conn, backendConn, g.quit)
227 | 	} else {
228 | 		err = RelayRawBytes(conn, backendConn, g.quit)
229 | 	}
230 | 	g.log.Infow("connection is closed", "connID", connID, "reason", err)
231 | }
232 | 
233 | // sendInitialHandshake sends the server greeting to a client.
234 | //
235 | // The auth challenge is all zeros: the gateway never checks the client's
236 | // first auth response, because the backend is made to switch auth methods
237 | // and issue its own challenge. ClientSSL is withdrawn when no TLS config is
238 | // loaded, so well-behaved clients do not attempt a TLS upgrade.
239 | func (g *Gateway) sendInitialHandshake(conn *mysql.Conn, connID uint32) error {
240 | 	hs := &mysql.Handshake{
241 | 		ProtocolVersion: mysql.DefaultHandshakeVersion,
242 | 		ServerVersion:   "5.7.25-TiDB",
243 | 		ConnectionID:    connID,
244 | 		AuthPluginData:  make([]byte, 20),
245 | 		Capability:      mysql.DefaultCapability,
246 | 		CharacterSet:    mysql.DefaultCollationID,
247 | 		StatusFlags:     mysql.ServerStatusAutocommit,
248 | 		AuthPluginName:  mysql.AuthNativePassword,
249 | 	}
250 | 	if g.tlsConf == nil {
251 | 		hs.Capability &^= mysql.ClientSSL
252 | 	}
253 | 	return conn.SendPacket(hs)
254 | }
255 | 
256 | // recvInitialHandshake reads a server greeting from conn.
257 | func (g *Gateway) recvInitialHandshake(conn *mysql.Conn) (*mysql.Handshake, error) {
258 | 	var hs mysql.Handshake
259 | 	if err := conn.RecvPacket(&hs); err != nil {
260 | 		return nil, err
261 | 	}
262 | 	return &hs, nil
263 | }
264 | 
265 | // recvHandshakeResponse reads a client handshake response from conn.
266 | func (g *Gateway) recvHandshakeResponse(conn *mysql.Conn) (*mysql.HandshakeResponse, error) {
267 | 	var res mysql.HandshakeResponse
268 | 	if err := conn.RecvPacket(&res); err != nil {
269 | 		return nil, err
270 | 	}
271 | 	return &res, nil
272 | }
273 | 
274 | // copyPacket reads one packet from src, writes it to dst and flushes dst.
275 | // The payload is returned even when the write or flush fails.
276 | func copyPacket(dst, src *mysql.Conn) ([]byte, error) {
277 | 	var b bytes.Buffer
278 | 	if err := src.ReadPacket(&b); err != nil {
279 | 		return nil, err
280 | 	}
281 | 	if err := dst.WritePacket(b.Bytes()); err != nil {
282 | 		return b.Bytes(), err
283 | 	}
284 | 	return b.Bytes(), dst.Flush()
285 | }
286 | 
287 | // exchangeAuth relays the authentication phase, alternating backend->client
288 | // and client->backend packets, until the backend sends an OK or ERR packet
289 | // (which is forwarded to the client as well).
290 | func (g *Gateway) exchangeAuth(clientConn, backendConn *mysql.Conn) error {
291 | 	for {
292 | 		data, err := copyPacket(clientConn, backendConn)
293 | 		if err != nil {
294 | 			return err
295 | 		}
296 | 		if len(data) > 0 && (data[0] == mysql.HeaderOK || data[0] == mysql.HeaderErr) {
297 | 			return nil
298 | 		}
299 | 		if _, err := copyPacket(backendConn, clientConn); err != nil {
300 | 			return err
301 | 		}
302 | 	}
303 | }
304 | 
305 | // sendErr sends an ERR packet carrying msg to the client. It is best-effort:
306 | // every caller closes the connection right afterwards, so a send failure is
307 | // ignored.
308 | func (g *Gateway) sendErr(conn *mysql.Conn, msg string) {
309 | 	pkt := &mysql.Err{
310 | 		Header:     mysql.HeaderErr,
311 | 		Code:       mysql.ErrCodeUnknown,
312 | 		State:      mysql.UnknownState,
313 | 		Message:    msg,
314 | 		Capability: mysql.DefaultCapability,
315 | 	}
316 | 	_ = conn.SendPacket(pkt)
317 | }
318 | 
319 | // getBackendAddr splits a "{clusterID}.{username}" user name at the first
320 | // dot, stores the real user name back into res.UserName and returns the
321 | // backend address for the cluster, appending ":4000" when it has no port.
322 | // A user name without a dot is taken entirely as the cluster ID and leaves
323 | // an empty user name; an empty cluster ID is rejected.
324 | func (g *Gateway) getBackendAddr(res *mysql.HandshakeResponse) (string, error) {
325 | 	var clusterID string
326 | 	clusterID, res.UserName, _ = strings.Cut(res.UserName, ".")
327 | 	if clusterID == "" {
328 | 		return "", errors.New("user name must be in the form of {clusterID}.{username}")
329 | 	}
330 | 
331 | 	clusterAddr := g.conf.BackendConfigs.Find(clusterID)
332 | 	if !portSuffix.MatchString(clusterAddr) {
333 | 		clusterAddr += ":" + defaultBackendPort
334 | 	}
335 | 	return clusterAddr, nil
336 | }
337 | 
338 | // connectBackend dials the backend at addr over TCP, giving up after
339 | // backendDialTimeout.
340 | func (g *Gateway) connectBackend(addr string) (*mysql.Conn, error) {
341 | 	rawConn, err := net.DialTimeout("tcp", addr, backendDialTimeout)
342 | 	if err != nil {
343 | 		return nil, err
344 | 	}
345 | 	return mysql.NewConn(rawConn), nil
346 | }
347 | 
--------------------------------------------------------------------------------
/gateway/relay.go:
--------------------------------------------------------------------------------
1 | package gateway
2 | 
3 | import (
4 | 	"bytes"
5 | 	"io"
6 | 
7 | 	"github.com/oh-my-tidb/tidb-gateway/mysql"
8 | 	"github.com/pkg/errors"
9 | )
10 | 
11 | // headerLocalInFile is the first byte of a LOCAL INFILE request; after it the
12 | // server waits for the client to send the file contents.
13 | const headerLocalInFile = 0xfb
14 | 
15 | // RelayRawBytes copies raw bytes between remote and backend in both
16 | // directions. It returns when either direction stops (errors.Wrap yields nil
17 | // if that copy ended without error) or when quit is closed. It does not close
18 | // the connections; the remaining copy goroutine exits once the caller does.
19 | func RelayRawBytes(remote, backend *mysql.Conn, quit <-chan struct{}) error {
20 | 	remote.SetResetOption(mysql.SeqResetBoth)
21 | 	backend.SetResetOption(mysql.SeqResetBoth)
22 | 	// Room for both results, so neither goroutine blocks after we return.
23 | 	errCh := make(chan error, 2) // nolint:gomnd // nolint
24 | 	go func() {
25 | 		_, err := io.Copy(backend.RawConn(), remote.RawConn())
26 | 		errCh <- errors.Wrap(err, "remote -> backend closed")
27 | 	}()
28 | 	go func() {
29 | 		_, err := io.Copy(remote.RawConn(), backend.RawConn())
30 | 		errCh <- errors.Wrap(err, "backend -> remote closed")
31 | 	}()
32 | 	select {
33 | 	case err := <-errCh:
34 | 		return err
35 | 	case <-quit:
36 | 		return errors.New("relayer is closed")
37 | 	}
38 | }
39 | 
40 | // RelayPackets relays MySQL packets between remote and backend through the
41 | // mysql.Conn packet layer, which the gateway uses when the client connection
42 | // is compressed. It returns the first error from either direction, or an
43 | // error once quit is closed. It does not close the connections.
44 | func RelayPackets(remote, backend *mysql.Conn, quit <-chan struct{}) error {
45 | 	remote.SetResetOption(mysql.SeqResetBoth)
46 | 	backend.SetResetOption(mysql.SeqResetBoth)
47 | 	// Room for both results, so neither goroutine blocks after we return.
48 | 	errCh := make(chan error, 2) // nolint:gomnd // nolint
49 | 	go copyInboundPackets(remote, backend, errCh)
50 | 	go copyOutboundPackets(remote, backend, errCh)
51 | 	select {
52 | 	case err := <-errCh:
53 | 		return err
54 | 	case <-quit:
55 | 		return errors.New("relayer is closed")
56 | 	}
57 | }
58 | 
59 | // copyInboundPackets forwards packets from remote to backend, flushing after
60 | // each one, and reports the first failure on errCh.
61 | //
62 | // NOTE(review): this goroutine calls backend.SetResetOption while
63 | // copyOutboundPackets reads from backend concurrently; confirm mysql.Conn
64 | // tolerates that.
65 | func copyInboundPackets(remote, backend *mysql.Conn, errCh chan error) {
66 | 	var b bytes.Buffer
67 | 	for {
68 | 		b.Reset()
69 | 		if _, err := remote.ReadPartialPacket(&b); err != nil {
70 | 			errCh <- errors.Wrap(err, "read from remote failed")
71 | 			return
72 | 		}
73 | 		backend.SetResetOption(mysql.SeqResetOnWrite)
74 | 		err := backend.WritePacket(b.Bytes())
75 | 		if err == nil {
76 | 			err = backend.Flush()
77 | 		}
78 | 		if err != nil {
79 | 			errCh <- errors.Wrap(err, "write to backend failed")
80 | 			return
81 | 		}
82 | 	}
83 | }
84 | 
85 | // copyOutboundPackets forwards packets from backend to remote and reports the
86 | // first failure on errCh. Writes to remote are flushed only after packets
87 | // that may end a server response (see mayEndResponse), so the many packets
88 | // of a result set are batched.
89 | func copyOutboundPackets(remote, backend *mysql.Conn, errCh chan error) {
90 | 	var b bytes.Buffer
91 | 	for {
92 | 		b.Reset()
93 | 		if _, err := backend.ReadPartialPacket(&b); err != nil {
94 | 			errCh <- errors.Wrap(err, "read from backend failed")
95 | 			return
96 | 		}
97 | 		remote.SetResetOption(mysql.SeqResetOnRead)
98 | 		if err := remote.WritePacket(b.Bytes()); err != nil {
99 | 			errCh <- errors.Wrap(err, "write to remote failed")
100 | 			return
101 | 		}
102 | 		if !mayEndResponse(b.Bytes()) {
103 | 			// Partial result: more packets follow, so keep buffering.
104 | 			continue
105 | 		}
106 | 		if err := remote.Flush(); err != nil {
107 | 			errCh <- errors.Wrap(err, "write to remote failed")
108 | 			return
109 | 		}
110 | 	}
111 | }
112 | 
113 | // mayEndResponse reports whether a backend packet may be the last one sent
114 | // before the backend waits for the client, so buffered output must be
115 | // flushed. A LOCAL INFILE request is included: unflushed, it would stay
116 | // buffered while the backend waits for the file contents.
117 | //
118 | // NOTE(review): a response ending in any other first byte (for example the
119 | // plain-text reply to COM_STATISTICS) stays buffered until more data arrives.
120 | func mayEndResponse(p []byte) bool {
121 | 	if len(p) == 0 {
122 | 		return true
123 | 	}
124 | 	switch p[0] {
125 | 	case mysql.HeaderOK, mysql.HeaderEOF, mysql.HeaderErr, headerLocalInFile:
126 | 		return true
127 | 	default:
128 | 		return false
129 | 	}
130 | }
131 | 
--------------------------------------------------------------------------------
/gateway/tls.go:
--------------------------------------------------------------------------------
1 | package gateway
2 | 
3 | import (
4 | 	"crypto/tls"
5 | 	"crypto/x509"
6 | 	"os"
7 | 
8 | 	"github.com/pkg/errors"
9 | )
10 | 
11 | // loadTLSConfig builds the server-side TLS configuration offered to clients.
12 | // It returns (nil, nil), meaning TLS is disabled, when ca, cert and key are
13 | // all empty; otherwise cert and key are both required. version selects the
14 | // minimum TLS version (see parseTLSVersion).
15 | func loadTLSConfig(ca, cert, key, version string) (*tls.Config, error) {
16 | 	if ca == "" && cert == "" && key == "" {
17 | 		return nil, nil
18 | 	}
19 | 	// The config is used with tls.Server, which cannot complete a handshake
20 | 	// without a certificate, so reject partial configurations up front.
21 | 	if cert == "" || key == "" {
22 | 		return nil, errors.New("both cert and key are required to enable tls")
23 | 	}
24 | 
25 | 	minVersion, err := parseTLSVersion(version)
26 | 	if err != nil {
27 | 		return nil, err
28 | 	}
29 | 	tlsConfig := &tls.Config{MinVersion: minVersion}
30 | 
31 | 	if ca != "" {
32 | 		caCert, err := os.ReadFile(ca)
33 | 		if err != nil {
34 | 			return nil, errors.Wrap(err, "failed to read ca")
35 | 		}
36 | 		caCertPool := x509.NewCertPool()
37 | 		if !caCertPool.AppendCertsFromPEM(caCert) {
38 | 			return nil, errors.Errorf("no PEM certificates found in ca file %q", ca)
39 | 		}
40 | 		// NOTE(review): RootCAs only matters when acting as a TLS client;
41 | 		// verifying client certificates would need ClientCAs and ClientAuth.
42 | 		tlsConfig.RootCAs = caCertPool
43 | 	}
44 | 
45 | 	keyPair, err := tls.LoadX509KeyPair(cert, key)
46 | 	if err != nil {
47 | 		return nil, errors.Wrap(err, "failed to load key pair")
48 | 	}
49 | 	tlsConfig.Certificates = []tls.Certificate{keyPair}
50 | 	return tlsConfig, nil
51 | }
52 | 
53 | // parseTLSVersion maps a version name such as "TLSv1.2" to its crypto/tls
54 | // constant. An empty name selects TLS 1.2; unknown names are rejected.
55 | func parseTLSVersion(version string) (uint16, error) {
56 | 	switch version {
57 | 	case "", "TLSv1.2":
58 | 		return tls.VersionTLS12, nil
59 | 	case "TLSv1.0":
60 | 		return tls.VersionTLS10, nil
61 | 	case "TLSv1.1":
62 | 		return tls.VersionTLS11, nil
63 | 	case "TLSv1.3":
64 | 		return tls.VersionTLS13, nil
65 | 	default:
66 | 		return 0, errors.Errorf("unsupported tls version %q", version)
67 | 	}
68 | }
69 | 
--------------------------------------------------------------------------------
/go.mod:
--------------------------------------------------------------------------------
1 | module github.com/oh-my-tidb/tidb-gateway
2 | 
3 | go 1.18
4 | 
5 | require (
6 | 	github.com/pkg/errors v0.9.1
7 | 	github.com/stretchr/testify v1.7.1
8 | )
9 | 
10 | require (
11 | 	go.uber.org/atomic v1.7.0 // indirect
12 | 	go.uber.org/multierr v1.6.0 // indirect
13 | )
14 | 
15 | require (
16 | 	github.com/davecgh/go-spew v1.1.1
// indirect 17 | github.com/pmezard/go-difflib v1.0.0 // indirect 18 | go.uber.org/zap v1.21.0 19 | gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b // indirect 20 | ) 21 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | github.com/benbjohnson/clock v1.1.0 h1:Q92kusRqC1XV2MjkWETPvjJVqKetz1OzxZB7mHJLju8= 2 | github.com/benbjohnson/clock v1.1.0/go.mod h1:J11/hYXuz8f4ySSvYwY0FKfm+ezbsZBKZxNJlLklBHA= 3 | github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 4 | github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= 5 | github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 6 | github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= 7 | github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= 8 | github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= 9 | github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= 10 | github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= 11 | github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= 12 | github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= 13 | github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= 14 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= 15 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 16 | github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= 17 | github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= 18 | github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= 19 | github.com/stretchr/testify v1.7.1 
h1:5TQK59W5E3v0r2duFAb7P95B6hEeOyEnHRa8MjYSMTY= 20 | github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= 21 | github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= 22 | go.uber.org/atomic v1.7.0 h1:ADUqmZGgLDDfbSL9ZmPxKTybcoEYHgpYfELNoN+7hsw= 23 | go.uber.org/atomic v1.7.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= 24 | go.uber.org/goleak v1.1.11 h1:wy28qYRKZgnJTxGxvye5/wgWr1EKjmUDGYox5mGlRlI= 25 | go.uber.org/goleak v1.1.11/go.mod h1:cwTWslyiVhfpKIDGSZEM2HlOvcqm+tG4zioyIeLoqMQ= 26 | go.uber.org/multierr v1.6.0 h1:y6IPFStTAIT5Ytl7/XYmHvzXQ7S3g/IeZW9hyZ5thw4= 27 | go.uber.org/multierr v1.6.0/go.mod h1:cdWPpRnG4AhwMwsgIHip0KRBQjJy5kYEpYjJxpXp9iU= 28 | go.uber.org/zap v1.21.0 h1:WefMeulhovoZ2sYXz7st6K0sLj7bBhpiFaud4r4zST8= 29 | go.uber.org/zap v1.21.0/go.mod h1:wjWOCqI0f2ZZrJF/UufIOkiC8ii6tm1iqIsLo76RfJw= 30 | golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= 31 | golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= 32 | golang.org/x/lint v0.0.0-20190930215403-16217165b5de/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= 33 | golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= 34 | golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= 35 | golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= 36 | golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= 37 | golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM= 38 | golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= 39 | golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= 
40 | golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= 41 | golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= 42 | golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= 43 | golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= 44 | golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= 45 | golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= 46 | golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= 47 | golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= 48 | golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= 49 | golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= 50 | golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= 51 | golang.org/x/tools v0.1.5/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= 52 | golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= 53 | golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= 54 | golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= 55 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= 56 | gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY= 57 | gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= 58 | gopkg.in/yaml.v2 v2.2.8 
h1:obN1ZagJSUGI0Ek/LBmuj4SNLPfIny3KsKFopxRdj10= 59 | gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= 60 | gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 61 | gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b h1:h8qDotaEPuJATrMmW04NCwg7v22aHH28wwpauUhK9Oo= 62 | gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 63 | -------------------------------------------------------------------------------- /main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "flag" 5 | "net" 6 | "os" 7 | "os/signal" 8 | "syscall" 9 | 10 | "github.com/oh-my-tidb/tidb-gateway/gateway" 11 | "github.com/oh-my-tidb/tidb-gateway/utility" 12 | ) 13 | 14 | var ( 15 | addr string 16 | tlsCA string 17 | tlsCert string 18 | tlsKey string 19 | tlsVersion string 20 | backendConfigs gateway.BackendConfigs 21 | enableCompression bool 22 | backendInsecureTransport bool 23 | ) 24 | 25 | func main() { 26 | flag.StringVar(&addr, "addr", ":3306", "listening address") 27 | flag.StringVar(&tlsCA, "tls-ca", "", "TLS CA file") 28 | flag.StringVar(&tlsCert, "tls-cert", "", "TLS cert file") 29 | flag.StringVar(&tlsKey, "tls-key", "", "TLS key file") 30 | flag.StringVar(&tlsVersion, "tls-version", "", "Minimal TLS version (TLSv1.0/TLSv1.1/TLSv1.2/TLSv1.3)") 31 | flag.BoolVar(&enableCompression, "compress", false, "Enable compression") 32 | flag.Var(&backendConfigs, "backend", "backend cluster configs") 33 | flag.BoolVar(&backendInsecureTransport, "backend-insecure-transport", false, "Using insecure connection to backend") 34 | flag.Parse() 35 | 36 | log := utility.GetLogger() 37 | log.Infow("initializing gateway", "addr", addr, "backend", backendConfigs) 38 | 39 | lis, err := net.Listen("tcp", addr) 40 | if err != nil { 41 | log.Errorw("failed to listen", "err", err) 42 | return 43 | } 44 | 45 | 
tlsConfig := gateway.TLSConfig{ 46 | CA: tlsCA, 47 | Cert: tlsCert, 48 | Key: tlsKey, 49 | MinVersion: tlsVersion, 50 | } 51 | 52 | gw, err := gateway.New(lis, &gateway.Config{ 53 | TLS: tlsConfig, 54 | BackendConfigs: backendConfigs, 55 | EnableCompression: enableCompression, 56 | BackendInsecureTransport: backendInsecureTransport, 57 | }) 58 | if err != nil { 59 | log.Errorw("failed to create gateway", "err", err) 60 | return 61 | } 62 | gw.StartServe() 63 | 64 | sigs := make(chan os.Signal, 1) 65 | signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM) 66 | sig := <-sigs 67 | log.Warnw("received signal", "signal", sig) 68 | gw.Stop() 69 | } 70 | -------------------------------------------------------------------------------- /mysql/buffer.go: -------------------------------------------------------------------------------- 1 | package mysql 2 | 3 | import ( 4 | "bytes" 5 | "encoding/binary" 6 | "io" 7 | 8 | "github.com/pkg/errors" 9 | ) 10 | 11 | // Buffer wraps bytes.Buffer for read/write mysql data types. 12 | type Buffer struct { 13 | b *bytes.Buffer 14 | } 15 | 16 | func newBuffer(data []byte) *Buffer { 17 | return &Buffer{bytes.NewBuffer(data)} 18 | } 19 | 20 | // WriteByte writes a single byte. 21 | func (b *Buffer) WriteByte(by byte) { 22 | b.b.WriteByte(by) 23 | } 24 | 25 | // ReadByte reads a single byte. 26 | func (b *Buffer) ReadByte() (byte, error) { 27 | by, err := b.b.ReadByte() 28 | return by, errors.WithStack(err) 29 | } 30 | 31 | // WriteBytes writes bytes. 32 | func (b *Buffer) WriteBytes(bys []byte) { 33 | b.b.Write(bys) 34 | } 35 | 36 | // ReadBytes reads n bytes. 37 | func (b *Buffer) ReadBytes(n int) ([]byte, error) { 38 | data := b.b.Next(n) 39 | if len(data) == n { 40 | return data, nil 41 | } 42 | return nil, errors.WithStack(io.EOF) 43 | } 44 | 45 | // WriteString writes a string followed by a null byte. 
46 | func (b *Buffer) WriteStringNull(s string) { 47 | b.b.WriteString(s) 48 | b.WriteByte(0x00) 49 | } 50 | 51 | // ReadStringNull reads a string followed by a null byte. 52 | func (b *Buffer) ReadStringNull() (string, error) { 53 | s, err := b.b.ReadString(0x00) 54 | if err != nil { 55 | return "", errors.WithStack(err) 56 | } 57 | return s[:len(s)-1], nil 58 | } 59 | 60 | // WriteUint32 writes a uint32. 61 | func (b *Buffer) WriteUint32(n uint32) { 62 | var b4 [4]byte 63 | binary.LittleEndian.PutUint32(b4[:], n) 64 | b.WriteBytes(b4[:]) 65 | } 66 | 67 | // ReadUint32 reads a uint32. 68 | func (b *Buffer) ReadUint32() (uint32, error) { 69 | data, err := b.ReadBytes(4) 70 | if err != nil { 71 | return 0, err 72 | } 73 | return binary.LittleEndian.Uint32(data), nil 74 | } 75 | 76 | // WriteUint16 writes a uint16. 77 | func (b *Buffer) WriteUint16(n uint16) { 78 | var b2 [2]byte 79 | binary.LittleEndian.PutUint16(b2[:], n) 80 | b.WriteBytes(b2[:]) 81 | } 82 | 83 | // ReadUint16 reads a uint16. 84 | func (b *Buffer) ReadUint16() (uint16, error) { 85 | data, err := b.ReadBytes(2) 86 | if err != nil { 87 | return 0, err 88 | } 89 | return binary.LittleEndian.Uint16(data), nil 90 | } 91 | 92 | // WriteUint24 writes a uint24. 93 | func (b *Buffer) WriteUint24(n uint32) { 94 | b.WriteUint16(uint16(n & 0xFFFF)) 95 | b.WriteByte(byte(n >> 16)) 96 | } 97 | 98 | // ReadUint24 reads a uint24. 99 | func (b *Buffer) ReadUint24() (uint32, error) { 100 | u16, err := b.ReadUint16() 101 | if err != nil { 102 | return 0, err 103 | } 104 | b3, err := b.ReadByte() 105 | if err != nil { 106 | return 0, err 107 | } 108 | return uint32(u16) | uint32(b3)<<16, nil 109 | } 110 | 111 | // WriteUint64 writes a uint64. 112 | func (b *Buffer) WriteUint64(n uint64) { 113 | var b8 [8]byte 114 | binary.LittleEndian.PutUint64(b8[:], n) 115 | b.WriteBytes(b8[:]) 116 | } 117 | 118 | // ReadUint64 reads a uint64. 
119 | func (b *Buffer) ReadUint64() (uint64, error) { 120 | data, err := b.ReadBytes(8) 121 | if err != nil { 122 | return 0, err 123 | } 124 | return binary.LittleEndian.Uint64(data), nil 125 | } 126 | 127 | // Len returns the number of bytes written to the buffer or left to read from 128 | // the buffer 129 | func (b *Buffer) Len() int { 130 | return b.b.Len() 131 | } 132 | 133 | // Skip skips n bytes for read. 134 | func (b *Buffer) Skip(n int) error { 135 | _, err := b.ReadBytes(n) 136 | return err 137 | } 138 | 139 | // WriteLenencInt writes a lenenc int. 140 | func (b *Buffer) WriteLenencInt(n uint64) { 141 | switch { 142 | case n < 251: 143 | b.WriteByte(byte(n)) 144 | case n >= 251 && n < (1<<16): 145 | b.WriteByte(0xFC) 146 | b.WriteUint16(uint16(n)) 147 | case n >= (1<<16) && n < (1<<24): 148 | b.WriteByte(0xFD) 149 | b.WriteUint24(uint32(n)) 150 | default: 151 | b.WriteByte(0xFE) 152 | b.WriteUint64(n) 153 | } 154 | } 155 | 156 | // ReadLenencInt reads a lenenc int. 157 | func (b *Buffer) ReadLenencInt() (uint64, error) { 158 | b1, err := b.ReadByte() 159 | if err != nil { 160 | return 0, err 161 | } 162 | switch { 163 | case b1 < 0xFC: 164 | return uint64(b1), nil 165 | case b1 == 0xFC: 166 | n, err := b.ReadUint16() 167 | return uint64(n), err 168 | case b1 == 0xFD: 169 | n, err := b.ReadUint24() 170 | return uint64(n), err 171 | case b1 == 0xFE: 172 | return b.ReadUint64() 173 | } 174 | return 0, errors.New("invalid lenenc int") 175 | } 176 | 177 | // WriteLenencString writes a lenenc string. 178 | func (b *Buffer) WriteLenencString(s string) { 179 | b.WriteLenencInt(uint64(len(s))) 180 | b.b.WriteString(s) 181 | } 182 | 183 | // ReadLenencString reads a lenenc string. 
184 | func (b *Buffer) ReadLenencString() (string, error) { 185 | n, err := b.ReadLenencInt() 186 | if err != nil { 187 | return "", err 188 | } 189 | data, err := b.ReadBytes(int(n)) 190 | if err != nil { 191 | return "", err 192 | } 193 | return string(data), nil 194 | } 195 | 196 | // Bytes returns the underlying bytes. 197 | func (b *Buffer) Bytes() []byte { 198 | return b.b.Bytes() 199 | } 200 | -------------------------------------------------------------------------------- /mysql/compress.go: -------------------------------------------------------------------------------- 1 | package mysql 2 | 3 | import ( 4 | "bytes" 5 | "compress/zlib" 6 | "io" 7 | 8 | "github.com/pkg/errors" 9 | ) 10 | 11 | const ( 12 | minCompressLen = 128 13 | maxBufferLen = (1 << 23) - 1 14 | ) 15 | 16 | // Compressor wraps a Reader and a WriteFlusher for compression. 17 | type Compressor struct { 18 | r io.Reader 19 | w WriteFlusher 20 | sequence uint8 21 | seqreset uint8 22 | readBuffer bytes.Buffer // decompressed data to be read. 23 | writeBuffer bytes.Buffer // bytes to be compressed. 24 | flushBuffer bytes.Buffer // compressed data to be sent. 25 | } 26 | 27 | // NewCompressor creates a new Compressor. 28 | func NewCompressor(r io.Reader, w WriteFlusher) *Compressor { 29 | return &Compressor{ 30 | r: r, 31 | w: w, 32 | } 33 | } 34 | 35 | // Read reads data from the underlying reader. 36 | func (c *Compressor) Read(p []byte) (int, error) { 37 | // drain buffer before reading next trunk. 38 | if n := c.readBuffer.Len(); n > 0 { 39 | if n > len(p) { 40 | n = len(p) 41 | } 42 | return c.readBuffer.Read(p[:n]) 43 | } 44 | return 0, c.loadNextTrunk() 45 | } 46 | 47 | func (c *Compressor) loadNextTrunk() error { 48 | c.readBuffer.Reset() 49 | 50 | var head [7]byte 51 | n, err := io.ReadFull(c.r, head[:]) 52 | if n != 7 { 53 | return err // err is guranateed not nil. 
54 | } 55 | if c.seqreset&SeqResetOnRead != 0 { 56 | c.seqreset &= ^SeqResetOnRead 57 | c.sequence = 0 58 | } 59 | 60 | payloadLen := readLen3(head[0:3]) 61 | sequence := head[3] 62 | uncompressedLen := readLen3(head[4:7]) 63 | 64 | if sequence != c.sequence { 65 | return errors.Errorf("invalid sequence %d != %d", sequence, c.sequence) 66 | } 67 | 68 | if uncompressedLen == 0 { 69 | // uncompressed payload. 70 | n, err := io.CopyN(&c.readBuffer, c.r, int64(payloadLen)) 71 | if n != int64(payloadLen) { 72 | return err // err is guranateed not nil. 73 | } 74 | } else { 75 | zr, err := zlib.NewReader(io.LimitReader(c.r, int64(payloadLen))) 76 | if err != nil { 77 | return err 78 | } 79 | n, err := io.Copy(&c.readBuffer, zr) 80 | if n != int64(uncompressedLen) { 81 | return errors.Errorf("uncompessed length mismatch %d != %d", n, uncompressedLen) 82 | } 83 | } 84 | c.sequence++ 85 | return nil 86 | } 87 | 88 | // Write writes data to the underlying writer. It works like bufio.Writer with compression. 89 | func (c *Compressor) Write(p []byte) (int, error) { 90 | for len(p) > 0 { 91 | capacity := maxBufferLen - c.writeBuffer.Len() 92 | if capacity >= len(p) { 93 | return c.writeBuffer.Write(p) 94 | } 95 | n, err := c.writeBuffer.Write(p[:capacity]) 96 | if n != capacity { 97 | return n, err 98 | } 99 | err = c.Flush() 100 | if err != nil { 101 | return n, err 102 | } 103 | p = p[capacity:] 104 | } 105 | return 0, nil 106 | } 107 | 108 | // Flush compress then flush the data to the underlying writer. 109 | func (c *Compressor) Flush() error { 110 | if c.seqreset&SeqResetOnWrite != 0 { 111 | c.seqreset &= ^SeqResetOnWrite 112 | c.sequence = 0 113 | } 114 | 115 | var head [7]byte 116 | var payload []byte 117 | 118 | if c.writeBuffer.Len() < minCompressLen { 119 | // write without compression. 120 | writeLen3(head[0:3], c.writeBuffer.Len()) 121 | head[3] = c.sequence 122 | writeLen3(head[4:7], 0) 123 | payload = c.writeBuffer.Bytes() 124 | } else { 125 | // with compression. 
126 | zw := zlib.NewWriter(&c.flushBuffer) 127 | n, err := zw.Write(c.writeBuffer.Bytes()) 128 | if n != c.writeBuffer.Len() { 129 | return err // err is guranateed not nil. 130 | } 131 | err = zw.Close() 132 | if err != nil { 133 | return err 134 | } 135 | writeLen3(head[0:3], c.flushBuffer.Len()) 136 | head[3] = c.sequence 137 | writeLen3(head[4:7], c.writeBuffer.Len()) 138 | payload = c.flushBuffer.Bytes() 139 | } 140 | 141 | n, _ := c.w.Write(head[:]) 142 | if n != 7 { 143 | return errors.WithStack(ErrBadConn) 144 | } 145 | n, _ = c.w.Write(payload) 146 | if n != len(payload) { 147 | return errors.WithStack(ErrBadConn) 148 | } 149 | c.sequence++ 150 | c.writeBuffer.Reset() 151 | c.flushBuffer.Reset() 152 | return c.w.Flush() 153 | } 154 | 155 | // SetResetOption marks the sequence to be reset on next read or write. 156 | func (c *Compressor) SetResetOption(opt uint8) { 157 | c.seqreset = opt 158 | } 159 | -------------------------------------------------------------------------------- /mysql/conn.go: -------------------------------------------------------------------------------- 1 | // Copyright 2022 PingCAP, Inc. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | // Copyright 2013 The Go-MySQL-Driver Authors. All rights reserved. 16 | // 17 | // This Source Code Form is subject to the terms of the Mozilla Public 18 | // License, v. 2.0. 
If a copy of the MPL was not distributed with this file, 19 | // You can obtain one at http://mozilla.org/MPL/2.0/. 20 | 21 | // The MIT License (MIT) 22 | // 23 | // Copyright (c) 2014 wandoulabs 24 | // Copyright (c) 2014 siddontang 25 | // 26 | // Permission is hereby granted, free of charge, to any person obtaining a copy of 27 | // this software and associated documentation files (the "Software"), to deal in 28 | // the Software without restriction, including without limitation the rights to 29 | // use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of 30 | // the Software, and to permit persons to whom the Software is furnished to do so, 31 | // subject to the following conditions: 32 | // 33 | // The above copyright notice and this permission notice shall be included in all 34 | // copies or substantial portions of the Software. 35 | 36 | package mysql 37 | 38 | import ( 39 | "bufio" 40 | "bytes" 41 | "io" 42 | "math" 43 | "net" 44 | "time" 45 | 46 | "github.com/pkg/errors" 47 | ) 48 | 49 | // Portable analogs of some common call errors. 50 | var ( 51 | ErrBadConn = errors.New("connection was bad") 52 | errNetPacketTooLarge = errors.New("net packet too large") 53 | ErrMalformPacket = errors.New("malform packet") 54 | ) 55 | 56 | const ( 57 | defaultWriterSize = 16 * 1024 58 | defaultReaderSize = 16 * 1024 59 | ) 60 | 61 | const ( 62 | // MaxPayloadLen is the max packet payload length. 63 | MaxPayloadLen = 1<<24 - 1 64 | ) 65 | 66 | // Options to determine when the sequence number should be reset. 67 | const ( 68 | SeqResetNone uint8 = 0 69 | SeqResetOnRead uint8 = 1 70 | SeqResetOnWrite uint8 = 2 71 | SeqResetBoth uint8 = 3 72 | ) 73 | 74 | // WriterFlusher represents a buffered writer. (like bufio.Writer) 75 | type WriteFlusher interface { 76 | io.Writer 77 | Flush() error 78 | } 79 | 80 | // Conn wraps net.Conn for data read/write. 
81 | // MySQL Packets: https://dev.mysql.com/doc/internals/en/mysql-packet.html 82 | type Conn struct { 83 | conn net.Conn 84 | r io.Reader 85 | w WriteFlusher 86 | sequence uint8 87 | seqreset uint8 88 | readTimeout time.Duration 89 | // maxAllowedPacket is the maximum size of one packet in readPacket. 90 | maxAllowedPacket uint64 91 | compressor *Compressor 92 | } 93 | 94 | // NewConn wraps a raw net.Conn into a Conn. 95 | func NewConn(conn net.Conn) *Conn { 96 | return &Conn{ 97 | conn: conn, 98 | r: bufio.NewReaderSize(conn, defaultReaderSize), 99 | w: bufio.NewWriterSize(conn, defaultWriterSize), 100 | // TODO: config max allowed packet 101 | maxAllowedPacket: math.MaxUint64, 102 | } 103 | } 104 | 105 | // SetRawConn resets the underlying net.Conn. 106 | // Used for upgrading to TLS. 107 | func (c *Conn) SetRawConn(conn net.Conn) { 108 | c.conn = conn 109 | c.r = bufio.NewReaderSize(conn, defaultReaderSize) 110 | c.w = bufio.NewWriterSize(conn, defaultWriterSize) 111 | } 112 | 113 | // SetReadTimeout sets the read timeout for the connection. 114 | func (c *Conn) SetReadTimeout(timeout time.Duration) { 115 | c.readTimeout = timeout 116 | } 117 | 118 | // SetMaxAllowedPacket sets the maximum packet size. 119 | func (c *Conn) SetMaxAllowedPacket(maxAllowedPacket uint64) { 120 | c.maxAllowedPacket = maxAllowedPacket 121 | } 122 | 123 | // Packet is the interface for a MySQL packet. 124 | type Packet interface { 125 | Write(b *Buffer) 126 | Read(b *Buffer) error 127 | } 128 | 129 | // SendPacket sends a MySQL packet. 130 | func (c *Conn) SendPacket(pkt Packet) error { 131 | b := newBuffer(nil) 132 | pkt.Write(b) 133 | if err := c.WritePacket(b.Bytes()); err != nil { 134 | return err 135 | } 136 | return c.Flush() 137 | } 138 | 139 | // RecvPacket receives a MySQL packet. 
140 | func (c *Conn) RecvPacket(pkg Packet) error { 141 | var b bytes.Buffer 142 | err := c.ReadPacket(&b) 143 | if err != nil { 144 | return err 145 | } 146 | 147 | return pkg.Read(newBuffer(b.Bytes())) 148 | } 149 | 150 | func (c *Conn) readFull(data []byte) error { 151 | if c.readTimeout > 0 { 152 | if err := c.conn.SetReadDeadline(time.Now().Add(c.readTimeout)); err != nil { 153 | return errors.WithStack(err) 154 | } 155 | } 156 | if _, err := io.ReadFull(c.r, data); err != nil { 157 | return errors.WithStack(err) 158 | } 159 | return nil 160 | } 161 | 162 | // ReadPacket reads a complete MySQL packet. 163 | func (c *Conn) ReadPacket(b *bytes.Buffer) error { 164 | for { 165 | n, err := c.ReadPartialPacket(b) 166 | if err != nil { 167 | return err 168 | } 169 | if n < MaxPayloadLen { 170 | return nil 171 | } 172 | } 173 | } 174 | 175 | // ReadpartialPacket reads a MySQL wire packet. It may be 176 | // part of a larger packet. 177 | func (c *Conn) ReadPartialPacket(b *bytes.Buffer) (n int, err error) { 178 | var head [4]byte 179 | if err = c.readFull(head[:]); err != nil { 180 | return 181 | } 182 | if c.seqreset&SeqResetOnRead != 0 { 183 | c.seqreset &= ^SeqResetOnRead 184 | c.sequence = 0 185 | } 186 | sequence := head[3] 187 | if sequence != c.sequence { 188 | return 0, errors.Errorf("invalid sequence %d != %d", sequence, c.sequence) 189 | } 190 | c.sequence++ 191 | 192 | n = readLen3(head[:3]) 193 | b.Grow(n) 194 | readLen, err := b.ReadFrom(&io.LimitedReader{R: c.r, N: int64(n)}) 195 | if int(readLen) != n { 196 | return int(readLen), err 197 | } 198 | return n, nil 199 | } 200 | 201 | // WritePacket writes data. 
202 | func (c *Conn) WritePacket(data []byte) error { 203 | if c.seqreset&SeqResetOnWrite != 0 { 204 | c.seqreset &= ^SeqResetOnWrite 205 | c.sequence = 0 206 | } 207 | 208 | var head [4]byte 209 | for { 210 | plen := len(data) 211 | if plen >= MaxPayloadLen { 212 | plen = MaxPayloadLen 213 | } 214 | writeLen3(head[:3], plen) 215 | head[3] = c.sequence 216 | 217 | n, err := c.w.Write(head[:]) 218 | if err != nil || n != len(head) { 219 | return errors.WithStack(ErrBadConn) 220 | } 221 | 222 | if n, err := c.w.Write(data[:plen]); err != nil { 223 | return errors.WithStack(ErrBadConn) 224 | } else if n != plen { 225 | return errors.WithStack(ErrBadConn) 226 | } else { 227 | c.sequence++ 228 | data = data[plen:] 229 | } 230 | 231 | if len(data) == 0 { 232 | return nil 233 | } 234 | } 235 | } 236 | 237 | // Flush flushes data to the underlying connection. 238 | func (c *Conn) Flush() error { 239 | err := c.w.Flush() 240 | if err != nil { 241 | return errors.WithStack(err) 242 | } 243 | return err 244 | } 245 | 246 | // RawConn returns the underlying net.Conn. 247 | func (p *Conn) RawConn() net.Conn { 248 | return p.conn 249 | } 250 | 251 | // Close closes the connection. 252 | func (p *Conn) Close() { 253 | p.conn.Close() 254 | } 255 | 256 | // SetResetOption marks the connection to reset sequence on next read/write. 257 | func (c *Conn) SetResetOption(opt uint8) { 258 | c.seqreset = opt 259 | if c.compressor != nil { 260 | c.compressor.SetResetOption(opt) 261 | } 262 | } 263 | 264 | // EnableCompression wraps the underlying reader and writer to support compression. 
265 | func (c *Conn) EnableCompression() { 266 | c.compressor = NewCompressor(c.r, c.w) 267 | c.r = c.compressor 268 | c.w = c.compressor 269 | } 270 | -------------------------------------------------------------------------------- /mysql/conn_test.go: -------------------------------------------------------------------------------- 1 | package mysql 2 | 3 | import ( 4 | "bytes" 5 | "io" 6 | "math/rand" 7 | "net" 8 | "sync" 9 | "testing" 10 | "time" 11 | 12 | "github.com/stretchr/testify/require" 13 | ) 14 | 15 | func TestConnMultiplePackets(t *testing.T) { 16 | client, server := makeConnPair() 17 | testConnMultiplePackets(t, client, server) 18 | client, server = makeConnPairWithCompression() 19 | testConnMultiplePackets(t, client, server) 20 | } 21 | 22 | func testConnMultiplePackets(t *testing.T, client, server *Conn) { 23 | defer client.Close() 24 | defer server.Close() 25 | 26 | for i := 0; i < 10; i++ { 27 | p := randomPayloads() 28 | var wg sync.WaitGroup 29 | goSendPayloads(t, &wg, client, p) 30 | result := recvPayloads(t, server, len(p)) 31 | require.Equal(t, p, result) 32 | wg.Wait() 33 | } 34 | } 35 | 36 | func TestConnRequestResponse(t *testing.T) { 37 | client, server := makeConnPair() 38 | testConnRequestResponse(t, client, server) 39 | client, server = makeConnPairWithCompression() 40 | testConnRequestResponse(t, client, server) 41 | } 42 | 43 | func testConnRequestResponse(t *testing.T, client, server *Conn) { 44 | defer client.Close() 45 | defer server.Close() 46 | 47 | for i := 0; i < 5; i++ { 48 | p := randomPayloads() 49 | var wg sync.WaitGroup 50 | goSendPayloads(t, &wg, client, p) 51 | result := recvPayloads(t, server, len(p)) 52 | require.Equal(t, p, result) 53 | wg.Wait() 54 | 55 | p, result = randomPayloads(), nil 56 | goSendPayloads(t, &wg, server, p) 57 | result = recvPayloads(t, client, len(p)) 58 | require.Equal(t, p, result) 59 | wg.Wait() 60 | 61 | // reset sequence number. 
62 | client.SetResetOption(SeqResetOnWrite) 63 | server.SetResetOption(SeqResetOnRead) 64 | } 65 | } 66 | 67 | func randomPayloads() [][]byte { 68 | p := make([][]byte, rand.Intn(10)+1) 69 | for i := range p { 70 | p[i] = make([]byte, rand.Intn(10)*rand.Intn(1024)) 71 | rand.Read(p[i]) 72 | } 73 | return p 74 | } 75 | 76 | func goSendPayloads(t *testing.T, wg *sync.WaitGroup, conn *Conn, payload [][]byte) { 77 | wg.Add(1) 78 | go func() { 79 | defer wg.Done() 80 | for _, p := range payload { 81 | err := conn.WritePacket(p) 82 | require.NoError(t, err) 83 | } 84 | err := conn.Flush() 85 | require.NoError(t, err) 86 | }() 87 | } 88 | 89 | func recvPayloads(t *testing.T, conn *Conn, n int) [][]byte { 90 | var payload [][]byte 91 | for i := 0; i < n; i++ { 92 | var b bytes.Buffer 93 | err := conn.ReadPacket(&b) 94 | require.NoError(t, err) 95 | payload = append(payload, b.Bytes()) 96 | } 97 | return payload 98 | } 99 | 100 | type mockConn struct { 101 | *io.PipeReader 102 | *io.PipeWriter 103 | } 104 | 105 | func (conn mockConn) Close() error { 106 | conn.PipeReader.Close() 107 | conn.PipeWriter.Close() 108 | return nil 109 | } 110 | 111 | func (conn mockConn) LocalAddr() net.Addr { 112 | return nil 113 | } 114 | 115 | func (conn mockConn) RemoteAddr() net.Addr { 116 | return nil 117 | } 118 | 119 | func (conn mockConn) SetDeadline(t time.Time) error { 120 | return nil 121 | } 122 | 123 | func (conn mockConn) SetReadDeadline(t time.Time) error { 124 | return nil 125 | } 126 | 127 | func (conn mockConn) SetWriteDeadline(t time.Time) error { 128 | return nil 129 | } 130 | 131 | func makeConnPair() (*Conn, *Conn) { 132 | r1, w1 := io.Pipe() 133 | r2, w2 := io.Pipe() 134 | return NewConn(mockConn{r1, w2}), NewConn(mockConn{r2, w1}) 135 | } 136 | 137 | func makeConnPairWithCompression() (*Conn, *Conn) { 138 | conn1, conn2 := makeConnPair() 139 | conn1.EnableCompression() 140 | conn2.EnableCompression() 141 | return conn1, conn2 142 | } 143 | 
--------------------------------------------------------------------------------
/mysql/constants.go:
--------------------------------------------------------------------------------
package mysql

// Packet constants.
const (
	// DefaultHandshakeVersion is the protocol version number placed in the
	// initial handshake (10 corresponds to MySQL's HandshakeV10).
	DefaultHandshakeVersion = 10
	// DefaultCollationID is the default collation; 46 maps to "utf8mb4_bin"
	// in the Collations table.
	DefaultCollationID = 46
	// DefaultCapability is the default set of Client* capability flags.
	// NOTE(review): ClientSSL, ClientCompress and ClientDeprecateEOF are not
	// part of this set; presumably they are negotiated elsewhere — confirm.
	DefaultCapability = ClientLongPassword | ClientLongFlag |
		ClientConnectWithDB | ClientProtocol41 |
		ClientTransactions | ClientSecureConnection | ClientFoundRows |
		ClientMultiStatements | ClientMultiResults | ClientLocalFiles |
		ClientConnectAttrs | ClientPluginAuth | ClientInteractive
)

// OK packet constants.
// These are the header bytes (first payload byte) that identify OK, EOF and
// ERR response packets.
const (
	HeaderOK  = 0x00
	HeaderEOF = 0xFE
	HeaderErr = 0xFF
)

// Server information.
// Server status flags; each constant is a distinct bit of a uint16 status
// word (0x0004 is not defined here).
const (
	ServerStatusInTrans            uint16 = 0x0001
	ServerStatusAutocommit         uint16 = 0x0002
	ServerMoreResultsExists        uint16 = 0x0008
	ServerStatusNoGoodIndexUsed    uint16 = 0x0010
	ServerStatusNoIndexUsed        uint16 = 0x0020
	ServerStatusCursorExists       uint16 = 0x0040
	ServerStatusLastRowSend        uint16 = 0x0080
	ServerStatusDBDropped          uint16 = 0x0100
	ServerStatusNoBackslashEscaped uint16 = 0x0200
	ServerStatusMetadataChanged    uint16 = 0x0400
	ServerStatusWasSlow            uint16 = 0x0800
	ServerPSOutParams              uint16 = 0x1000
	ServerStatusInTransReadonly    uint16 = 0x2000
	ServerSessionStateChanged      uint16 = 0x4000
)

// Client information: capability flags, one bit each, assigned by 1 << iota.
40 | const ( 41 | ClientLongPassword uint32 = 1 << iota 42 | ClientFoundRows 43 | ClientLongFlag 44 | ClientConnectWithDB 45 | ClientNoSchema 46 | ClientCompress 47 | ClientODBC 48 | ClientLocalFiles 49 | ClientIgnoreSpace 50 | ClientProtocol41 51 | ClientInteractive 52 | ClientSSL 53 | ClientIgnoreSigpipe 54 | ClientTransactions 55 | ClientReserved 56 | ClientSecureConnection 57 | ClientMultiStatements 58 | ClientMultiResults 59 | ClientPSMultiResults 60 | ClientPluginAuth 61 | ClientConnectAttrs 62 | ClientPluginAuthLenencClientData 63 | ClientCanHandleExpiredPasswords 64 | ClientSessionTrack 65 | ClientDeprecateEOF 66 | ) 67 | 68 | // Auth name information. 69 | const ( 70 | AuthInvalidMethod = "invalid_dummy_method" 71 | AuthNativePassword = "mysql_native_password" // #nosec G101 72 | AuthCachingSha2Password = "caching_sha2_password" // #nosec G101 73 | AuthSocket = "auth_socket" 74 | ) 75 | 76 | // Collations maps MySQL collation ID to its name. 77 | var Collations = map[uint8]string{ 78 | 1: "big5_chinese_ci", 79 | 2: "latin2_czech_cs", 80 | 3: "dec8_swedish_ci", 81 | 4: "cp850_general_ci", 82 | 5: "latin1_german1_ci", 83 | 6: "hp8_english_ci", 84 | 7: "koi8r_general_ci", 85 | 8: "latin1_swedish_ci", 86 | 9: "latin2_general_ci", 87 | 10: "swe7_swedish_ci", 88 | 11: "ascii_general_ci", 89 | 12: "ujis_japanese_ci", 90 | 13: "sjis_japanese_ci", 91 | 14: "cp1251_bulgarian_ci", 92 | 15: "latin1_danish_ci", 93 | 16: "hebrew_general_ci", 94 | 18: "tis620_thai_ci", 95 | 19: "euckr_korean_ci", 96 | 20: "latin7_estonian_cs", 97 | 21: "latin2_hungarian_ci", 98 | 22: "koi8u_general_ci", 99 | 23: "cp1251_ukrainian_ci", 100 | 24: "gb2312_chinese_ci", 101 | 25: "greek_general_ci", 102 | 26: "cp1250_general_ci", 103 | 27: "latin2_croatian_ci", 104 | 28: "gbk_chinese_ci", 105 | 29: "cp1257_lithuanian_ci", 106 | 30: "latin5_turkish_ci", 107 | 31: "latin1_german2_ci", 108 | 32: "armscii8_general_ci", 109 | 33: "utf8_general_ci", 110 | 34: "cp1250_czech_cs", 111 | 35: 
"ucs2_general_ci", 112 | 36: "cp866_general_ci", 113 | 37: "keybcs2_general_ci", 114 | 38: "macce_general_ci", 115 | 39: "macroman_general_ci", 116 | 40: "cp852_general_ci", 117 | 41: "latin7_general_ci", 118 | 42: "latin7_general_cs", 119 | 43: "macce_bin", 120 | 44: "cp1250_croatian_ci", 121 | 45: "utf8mb4_general_ci", 122 | 46: "utf8mb4_bin", 123 | 47: "latin1_bin", 124 | 48: "latin1_general_ci", 125 | 49: "latin1_general_cs", 126 | 50: "cp1251_bin", 127 | 51: "cp1251_general_ci", 128 | 52: "cp1251_general_cs", 129 | 53: "macroman_bin", 130 | 54: "utf16_general_ci", 131 | 55: "utf16_bin", 132 | 56: "utf16le_general_ci", 133 | 57: "cp1256_general_ci", 134 | 58: "cp1257_bin", 135 | 59: "cp1257_general_ci", 136 | 60: "utf32_general_ci", 137 | 61: "utf32_bin", 138 | 62: "utf16le_bin", 139 | 63: "binary", 140 | 64: "armscii8_bin", 141 | 65: "ascii_bin", 142 | 66: "cp1250_bin", 143 | 67: "cp1256_bin", 144 | 68: "cp866_bin", 145 | 69: "dec8_bin", 146 | 70: "greek_bin", 147 | 71: "hebrew_bin", 148 | 72: "hp8_bin", 149 | 73: "keybcs2_bin", 150 | 74: "koi8r_bin", 151 | 75: "koi8u_bin", 152 | 77: "latin2_bin", 153 | 78: "latin5_bin", 154 | 79: "latin7_bin", 155 | 80: "cp850_bin", 156 | 81: "cp852_bin", 157 | 82: "swe7_bin", 158 | 83: "utf8_bin", 159 | 84: "big5_bin", 160 | 85: "euckr_bin", 161 | 86: "gb2312_bin", 162 | 87: "gbk_bin", 163 | 88: "sjis_bin", 164 | 89: "tis620_bin", 165 | 90: "ucs2_bin", 166 | 91: "ujis_bin", 167 | 92: "geostd8_general_ci", 168 | 93: "geostd8_bin", 169 | 94: "latin1_spanish_ci", 170 | 95: "cp932_japanese_ci", 171 | 96: "cp932_bin", 172 | 97: "eucjpms_japanese_ci", 173 | 98: "eucjpms_bin", 174 | 99: "cp1250_polish_ci", 175 | 101: "utf16_unicode_ci", 176 | 102: "utf16_icelandic_ci", 177 | 103: "utf16_latvian_ci", 178 | 104: "utf16_romanian_ci", 179 | 105: "utf16_slovenian_ci", 180 | 106: "utf16_polish_ci", 181 | 107: "utf16_estonian_ci", 182 | 108: "utf16_spanish_ci", 183 | 109: "utf16_swedish_ci", 184 | 110: "utf16_turkish_ci", 185 | 111: 
"utf16_czech_ci", 186 | 112: "utf16_danish_ci", 187 | 113: "utf16_lithuanian_ci", 188 | 114: "utf16_slovak_ci", 189 | 115: "utf16_spanish2_ci", 190 | 116: "utf16_roman_ci", 191 | 117: "utf16_persian_ci", 192 | 118: "utf16_esperanto_ci", 193 | 119: "utf16_hungarian_ci", 194 | 120: "utf16_sinhala_ci", 195 | 121: "utf16_german2_ci", 196 | 122: "utf16_croatian_ci", 197 | 123: "utf16_unicode_520_ci", 198 | 124: "utf16_vietnamese_ci", 199 | 128: "ucs2_unicode_ci", 200 | 129: "ucs2_icelandic_ci", 201 | 130: "ucs2_latvian_ci", 202 | 131: "ucs2_romanian_ci", 203 | 132: "ucs2_slovenian_ci", 204 | 133: "ucs2_polish_ci", 205 | 134: "ucs2_estonian_ci", 206 | 135: "ucs2_spanish_ci", 207 | 136: "ucs2_swedish_ci", 208 | 137: "ucs2_turkish_ci", 209 | 138: "ucs2_czech_ci", 210 | 139: "ucs2_danish_ci", 211 | 140: "ucs2_lithuanian_ci", 212 | 141: "ucs2_slovak_ci", 213 | 142: "ucs2_spanish2_ci", 214 | 143: "ucs2_roman_ci", 215 | 144: "ucs2_persian_ci", 216 | 145: "ucs2_esperanto_ci", 217 | 146: "ucs2_hungarian_ci", 218 | 147: "ucs2_sinhala_ci", 219 | 148: "ucs2_german2_ci", 220 | 149: "ucs2_croatian_ci", 221 | 150: "ucs2_unicode_520_ci", 222 | 151: "ucs2_vietnamese_ci", 223 | 159: "ucs2_general_mysql500_ci", 224 | 160: "utf32_unicode_ci", 225 | 161: "utf32_icelandic_ci", 226 | 162: "utf32_latvian_ci", 227 | 163: "utf32_romanian_ci", 228 | 164: "utf32_slovenian_ci", 229 | 165: "utf32_polish_ci", 230 | 166: "utf32_estonian_ci", 231 | 167: "utf32_spanish_ci", 232 | 168: "utf32_swedish_ci", 233 | 169: "utf32_turkish_ci", 234 | 170: "utf32_czech_ci", 235 | 171: "utf32_danish_ci", 236 | 172: "utf32_lithuanian_ci", 237 | 173: "utf32_slovak_ci", 238 | 174: "utf32_spanish2_ci", 239 | 175: "utf32_roman_ci", 240 | 176: "utf32_persian_ci", 241 | 177: "utf32_esperanto_ci", 242 | 178: "utf32_hungarian_ci", 243 | 179: "utf32_sinhala_ci", 244 | 180: "utf32_german2_ci", 245 | 181: "utf32_croatian_ci", 246 | 182: "utf32_unicode_520_ci", 247 | 183: "utf32_vietnamese_ci", 248 | 192: "utf8_unicode_ci", 249 
| 193: "utf8_icelandic_ci", 250 | 194: "utf8_latvian_ci", 251 | 195: "utf8_romanian_ci", 252 | 196: "utf8_slovenian_ci", 253 | 197: "utf8_polish_ci", 254 | 198: "utf8_estonian_ci", 255 | 199: "utf8_spanish_ci", 256 | 200: "utf8_swedish_ci", 257 | 201: "utf8_turkish_ci", 258 | 202: "utf8_czech_ci", 259 | 203: "utf8_danish_ci", 260 | 204: "utf8_lithuanian_ci", 261 | 205: "utf8_slovak_ci", 262 | 206: "utf8_spanish2_ci", 263 | 207: "utf8_roman_ci", 264 | 208: "utf8_persian_ci", 265 | 209: "utf8_esperanto_ci", 266 | 210: "utf8_hungarian_ci", 267 | 211: "utf8_sinhala_ci", 268 | 212: "utf8_german2_ci", 269 | 213: "utf8_croatian_ci", 270 | 214: "utf8_unicode_520_ci", 271 | 215: "utf8_vietnamese_ci", 272 | 223: "utf8_general_mysql500_ci", 273 | 224: "utf8mb4_unicode_ci", 274 | 225: "utf8mb4_icelandic_ci", 275 | 226: "utf8mb4_latvian_ci", 276 | 227: "utf8mb4_romanian_ci", 277 | 228: "utf8mb4_slovenian_ci", 278 | 229: "utf8mb4_polish_ci", 279 | 230: "utf8mb4_estonian_ci", 280 | 231: "utf8mb4_spanish_ci", 281 | 232: "utf8mb4_swedish_ci", 282 | 233: "utf8mb4_turkish_ci", 283 | 234: "utf8mb4_czech_ci", 284 | 235: "utf8mb4_danish_ci", 285 | 236: "utf8mb4_lithuanian_ci", 286 | 237: "utf8mb4_slovak_ci", 287 | 238: "utf8mb4_spanish2_ci", 288 | 239: "utf8mb4_roman_ci", 289 | 240: "utf8mb4_persian_ci", 290 | 241: "utf8mb4_esperanto_ci", 291 | 242: "utf8mb4_hungarian_ci", 292 | 243: "utf8mb4_sinhala_ci", 293 | 244: "utf8mb4_german2_ci", 294 | 245: "utf8mb4_croatian_ci", 295 | 246: "utf8mb4_unicode_520_ci", 296 | 247: "utf8mb4_vietnamese_ci", 297 | 255: "utf8mb4_0900_ai_ci", 298 | } 299 | 300 | // CollationNames maps MySQL collation name to its ID 301 | var CollationNames = map[string]uint8{ 302 | "big5_chinese_ci": 1, 303 | "latin2_czech_cs": 2, 304 | "dec8_swedish_ci": 3, 305 | "cp850_general_ci": 4, 306 | "latin1_german1_ci": 5, 307 | "hp8_english_ci": 6, 308 | "koi8r_general_ci": 7, 309 | "latin1_swedish_ci": 8, 310 | "latin2_general_ci": 9, 311 | "swe7_swedish_ci": 10, 312 | 
"ascii_general_ci": 11, 313 | "ujis_japanese_ci": 12, 314 | "sjis_japanese_ci": 13, 315 | "cp1251_bulgarian_ci": 14, 316 | "latin1_danish_ci": 15, 317 | "hebrew_general_ci": 16, 318 | "tis620_thai_ci": 18, 319 | "euckr_korean_ci": 19, 320 | "latin7_estonian_cs": 20, 321 | "latin2_hungarian_ci": 21, 322 | "koi8u_general_ci": 22, 323 | "cp1251_ukrainian_ci": 23, 324 | "gb2312_chinese_ci": 24, 325 | "greek_general_ci": 25, 326 | "cp1250_general_ci": 26, 327 | "latin2_croatian_ci": 27, 328 | "gbk_chinese_ci": 28, 329 | "cp1257_lithuanian_ci": 29, 330 | "latin5_turkish_ci": 30, 331 | "latin1_german2_ci": 31, 332 | "armscii8_general_ci": 32, 333 | "utf8_general_ci": 33, 334 | "cp1250_czech_cs": 34, 335 | "ucs2_general_ci": 35, 336 | "cp866_general_ci": 36, 337 | "keybcs2_general_ci": 37, 338 | "macce_general_ci": 38, 339 | "macroman_general_ci": 39, 340 | "cp852_general_ci": 40, 341 | "latin7_general_ci": 41, 342 | "latin7_general_cs": 42, 343 | "macce_bin": 43, 344 | "cp1250_croatian_ci": 44, 345 | "utf8mb4_general_ci": 45, 346 | "utf8mb4_bin": 46, 347 | "latin1_bin": 47, 348 | "latin1_general_ci": 48, 349 | "latin1_general_cs": 49, 350 | "cp1251_bin": 50, 351 | "cp1251_general_ci": 51, 352 | "cp1251_general_cs": 52, 353 | "macroman_bin": 53, 354 | "utf16_general_ci": 54, 355 | "utf16_bin": 55, 356 | "utf16le_general_ci": 56, 357 | "cp1256_general_ci": 57, 358 | "cp1257_bin": 58, 359 | "cp1257_general_ci": 59, 360 | "utf32_general_ci": 60, 361 | "utf32_bin": 61, 362 | "utf16le_bin": 62, 363 | "binary": 63, 364 | "armscii8_bin": 64, 365 | "ascii_bin": 65, 366 | "cp1250_bin": 66, 367 | "cp1256_bin": 67, 368 | "cp866_bin": 68, 369 | "dec8_bin": 69, 370 | "greek_bin": 70, 371 | "hebrew_bin": 71, 372 | "hp8_bin": 72, 373 | "keybcs2_bin": 73, 374 | "koi8r_bin": 74, 375 | "koi8u_bin": 75, 376 | "latin2_bin": 77, 377 | "latin5_bin": 78, 378 | "latin7_bin": 79, 379 | "cp850_bin": 80, 380 | "cp852_bin": 81, 381 | "swe7_bin": 82, 382 | "utf8_bin": 83, 383 | "big5_bin": 84, 384 | 
"euckr_bin": 85, 385 | "gb2312_bin": 86, 386 | "gbk_bin": 87, 387 | "sjis_bin": 88, 388 | "tis620_bin": 89, 389 | "ucs2_bin": 90, 390 | "ujis_bin": 91, 391 | "geostd8_general_ci": 92, 392 | "geostd8_bin": 93, 393 | "latin1_spanish_ci": 94, 394 | "cp932_japanese_ci": 95, 395 | "cp932_bin": 96, 396 | "eucjpms_japanese_ci": 97, 397 | "eucjpms_bin": 98, 398 | "cp1250_polish_ci": 99, 399 | "utf16_unicode_ci": 101, 400 | "utf16_icelandic_ci": 102, 401 | "utf16_latvian_ci": 103, 402 | "utf16_romanian_ci": 104, 403 | "utf16_slovenian_ci": 105, 404 | "utf16_polish_ci": 106, 405 | "utf16_estonian_ci": 107, 406 | "utf16_spanish_ci": 108, 407 | "utf16_swedish_ci": 109, 408 | "utf16_turkish_ci": 110, 409 | "utf16_czech_ci": 111, 410 | "utf16_danish_ci": 112, 411 | "utf16_lithuanian_ci": 113, 412 | "utf16_slovak_ci": 114, 413 | "utf16_spanish2_ci": 115, 414 | "utf16_roman_ci": 116, 415 | "utf16_persian_ci": 117, 416 | "utf16_esperanto_ci": 118, 417 | "utf16_hungarian_ci": 119, 418 | "utf16_sinhala_ci": 120, 419 | "utf16_german2_ci": 121, 420 | "utf16_croatian_ci": 122, 421 | "utf16_unicode_520_ci": 123, 422 | "utf16_vietnamese_ci": 124, 423 | "ucs2_unicode_ci": 128, 424 | "ucs2_icelandic_ci": 129, 425 | "ucs2_latvian_ci": 130, 426 | "ucs2_romanian_ci": 131, 427 | "ucs2_slovenian_ci": 132, 428 | "ucs2_polish_ci": 133, 429 | "ucs2_estonian_ci": 134, 430 | "ucs2_spanish_ci": 135, 431 | "ucs2_swedish_ci": 136, 432 | "ucs2_turkish_ci": 137, 433 | "ucs2_czech_ci": 138, 434 | "ucs2_danish_ci": 139, 435 | "ucs2_lithuanian_ci": 140, 436 | "ucs2_slovak_ci": 141, 437 | "ucs2_spanish2_ci": 142, 438 | "ucs2_roman_ci": 143, 439 | "ucs2_persian_ci": 144, 440 | "ucs2_esperanto_ci": 145, 441 | "ucs2_hungarian_ci": 146, 442 | "ucs2_sinhala_ci": 147, 443 | "ucs2_german2_ci": 148, 444 | "ucs2_croatian_ci": 149, 445 | "ucs2_unicode_520_ci": 150, 446 | "ucs2_vietnamese_ci": 151, 447 | "ucs2_general_mysql500_ci": 159, 448 | "utf32_unicode_ci": 160, 449 | "utf32_icelandic_ci": 161, 450 | 
"utf32_latvian_ci": 162, 451 | "utf32_romanian_ci": 163, 452 | "utf32_slovenian_ci": 164, 453 | "utf32_polish_ci": 165, 454 | "utf32_estonian_ci": 166, 455 | "utf32_spanish_ci": 167, 456 | "utf32_swedish_ci": 168, 457 | "utf32_turkish_ci": 169, 458 | "utf32_czech_ci": 170, 459 | "utf32_danish_ci": 171, 460 | "utf32_lithuanian_ci": 172, 461 | "utf32_slovak_ci": 173, 462 | "utf32_spanish2_ci": 174, 463 | "utf32_roman_ci": 175, 464 | "utf32_persian_ci": 176, 465 | "utf32_esperanto_ci": 177, 466 | "utf32_hungarian_ci": 178, 467 | "utf32_sinhala_ci": 179, 468 | "utf32_german2_ci": 180, 469 | "utf32_croatian_ci": 181, 470 | "utf32_unicode_520_ci": 182, 471 | "utf32_vietnamese_ci": 183, 472 | "utf8_unicode_ci": 192, 473 | "utf8_icelandic_ci": 193, 474 | "utf8_latvian_ci": 194, 475 | "utf8_romanian_ci": 195, 476 | "utf8_slovenian_ci": 196, 477 | "utf8_polish_ci": 197, 478 | "utf8_estonian_ci": 198, 479 | "utf8_spanish_ci": 199, 480 | "utf8_swedish_ci": 200, 481 | "utf8_turkish_ci": 201, 482 | "utf8_czech_ci": 202, 483 | "utf8_danish_ci": 203, 484 | "utf8_lithuanian_ci": 204, 485 | "utf8_slovak_ci": 205, 486 | "utf8_spanish2_ci": 206, 487 | "utf8_roman_ci": 207, 488 | "utf8_persian_ci": 208, 489 | "utf8_esperanto_ci": 209, 490 | "utf8_hungarian_ci": 210, 491 | "utf8_sinhala_ci": 211, 492 | "utf8_german2_ci": 212, 493 | "utf8_croatian_ci": 213, 494 | "utf8_unicode_520_ci": 214, 495 | "utf8_vietnamese_ci": 215, 496 | "utf8_general_mysql500_ci": 223, 497 | "utf8mb4_unicode_ci": 224, 498 | "utf8mb4_icelandic_ci": 225, 499 | "utf8mb4_latvian_ci": 226, 500 | "utf8mb4_romanian_ci": 227, 501 | "utf8mb4_slovenian_ci": 228, 502 | "utf8mb4_polish_ci": 229, 503 | "utf8mb4_estonian_ci": 230, 504 | "utf8mb4_spanish_ci": 231, 505 | "utf8mb4_swedish_ci": 232, 506 | "utf8mb4_turkish_ci": 233, 507 | "utf8mb4_czech_ci": 234, 508 | "utf8mb4_danish_ci": 235, 509 | "utf8mb4_lithuanian_ci": 236, 510 | "utf8mb4_slovak_ci": 237, 511 | "utf8mb4_spanish2_ci": 238, 512 | "utf8mb4_roman_ci": 239, 513 | 
"utf8mb4_persian_ci": 240, 514 | "utf8mb4_esperanto_ci": 241, 515 | "utf8mb4_hungarian_ci": 242, 516 | "utf8mb4_sinhala_ci": 243, 517 | "utf8mb4_german2_ci": 244, 518 | "utf8mb4_croatian_ci": 245, 519 | "utf8mb4_unicode_520_ci": 246, 520 | "utf8mb4_vietnamese_ci": 247, 521 | "utf8mb4_0900_ai_ci": 255, 522 | } 523 | 524 | const ( 525 | ErrCodeUnknown = 1105 526 | UnknownState = "08S01" 527 | ) 528 | -------------------------------------------------------------------------------- /mysql/packet_err.go: -------------------------------------------------------------------------------- 1 | package mysql 2 | 3 | // Err represnets a MySQL packet that contains an error. 4 | type Err struct { 5 | Header byte 6 | Code uint16 7 | State string 8 | Message string 9 | Capability uint32 10 | } 11 | 12 | // Write writes the packet to a buffer. 13 | func (e *Err) Write(b *Buffer) { 14 | b.WriteByte(e.Header) 15 | b.WriteUint16(e.Code) 16 | if e.Capability&ClientProtocol41 != 0 { 17 | b.WriteByte('#') 18 | b.WriteBytes([]byte(e.State)) 19 | } 20 | b.WriteBytes([]byte(e.Message)) 21 | } 22 | 23 | // Read reads packet from a buffer. 24 | func (e *Err) Read(b *Buffer) error { 25 | panic("implemented") 26 | } 27 | -------------------------------------------------------------------------------- /mysql/packet_handshake.go: -------------------------------------------------------------------------------- 1 | package mysql 2 | 3 | import "github.com/pkg/errors" 4 | 5 | // Handshake is the initial handshake packet sent from server to client. 6 | type Handshake struct { 7 | ProtocolVersion uint8 8 | ServerVersion string 9 | ConnectionID uint32 10 | AuthPluginData []byte 11 | Capability uint32 12 | CharacterSet uint8 13 | StatusFlags uint16 14 | AuthPluginName string 15 | } 16 | 17 | // Write writes the packet to a buffer. 
18 | func (s *Handshake) Write(b *Buffer) { 19 | // 1 [0a] protocol version 20 | b.WriteByte(s.ProtocolVersion) 21 | // string[NUL] server version 22 | b.WriteStringNull(s.ServerVersion) 23 | // 4 connection id 24 | b.WriteUint32(s.ConnectionID) 25 | // string[8] auth-plugin-data-part-1 26 | b.WriteBytes(s.AuthPluginData[:8]) 27 | // 1 [00] filler 28 | b.WriteByte(0x00) 29 | // 2 capability flags (lower 2 bytes) 30 | b.WriteUint16(uint16(s.Capability & 0xFFFF)) 31 | // 1 character set 32 | b.WriteByte(s.CharacterSet) 33 | // 2 status flags 34 | b.WriteUint16(s.StatusFlags) 35 | // 2 capability flags (upper 2 bytes) 36 | b.WriteUint16(uint16(s.Capability >> 16)) 37 | // if capabilities & CLIENT_PLUGIN_AUTH { 38 | // 1 length of auth-plugin-data 39 | // } else { 40 | // 1 [00] 41 | // } 42 | if s.Capability&ClientPluginAuth != 0 { 43 | b.WriteByte(byte(len(s.AuthPluginData) + 1)) 44 | } else { 45 | b.WriteByte(0x00) 46 | } 47 | // string[10] reserved (all [00]) 48 | b.WriteBytes(make([]byte, 10)) 49 | // if capabilities & CLIENT_SECURE_CONNECTION { 50 | // string[$len] auth-plugin-data-part-2 ($len=MAX(13, length of auth-plugin-data - 8)) 51 | if s.Capability&ClientSecureConnection != 0 { 52 | l := len(s.AuthPluginData) - 8 53 | b.WriteBytes(s.AuthPluginData[8 : 8+l]) 54 | if l < 13 { 55 | b.WriteBytes(make([]byte, 13-l)) 56 | } 57 | } 58 | // if capabilities & CLIENT_PLUGIN_AUTH { 59 | // string[NUL] auth-plugin name 60 | if s.Capability&ClientPluginAuth != 0 { 61 | b.WriteStringNull(s.AuthPluginName) 62 | } 63 | } 64 | 65 | // Read reads the packet from a buffer. 66 | // Support V9 or V10. 
67 | func (s *Handshake) Read(b *Buffer) error { 68 | var err error 69 | // 1 [0a] or [09] protocol version 70 | s.ProtocolVersion, err = b.ReadByte() 71 | if err != nil { 72 | return err 73 | } 74 | if s.ProtocolVersion != 10 && s.ProtocolVersion != 9 { 75 | return errors.New("only support protocol v9 or v10") 76 | } 77 | 78 | // string[NUL] server version 79 | s.ServerVersion, err = b.ReadStringNull() 80 | if err != nil { 81 | return err 82 | } 83 | 84 | // 4 connection id 85 | s.ConnectionID, err = b.ReadUint32() 86 | if err != nil { 87 | return err 88 | } 89 | 90 | if s.ProtocolVersion == 9 { 91 | // string[NUL] scramble 92 | str, err := b.ReadStringNull() 93 | if err != nil { 94 | return err 95 | } 96 | s.AuthPluginData = []byte(str) 97 | return nil 98 | } 99 | 100 | // string[8] auth-plugin-data-part-1 101 | data, err := b.ReadBytes(8) 102 | if err != nil { 103 | return err 104 | } 105 | s.AuthPluginData = append(s.AuthPluginData, data...) 106 | 107 | // 1 [00] filler 108 | if _, err := b.ReadByte(); err != nil { 109 | return err 110 | } 111 | 112 | // 2 capability flags (lower 2 bytes) 113 | capLow, err := b.ReadUint16() 114 | if err != nil { 115 | return err 116 | } 117 | s.Capability = uint32(capLow) 118 | 119 | if b.Len() == 0 { 120 | return nil 121 | } 122 | // if more data in the packet: 123 | 124 | // 1 character set 125 | s.CharacterSet, err = b.ReadByte() 126 | if err != nil { 127 | return err 128 | } 129 | 130 | // 2 status flags 131 | s.StatusFlags, err = b.ReadUint16() 132 | if err != nil { 133 | return err 134 | } 135 | 136 | // 2 capability flags (upper 2 bytes) 137 | capHigh, err := b.ReadUint16() 138 | if err != nil { 139 | return err 140 | } 141 | s.Capability |= uint32(capHigh) << 16 142 | 143 | // if capabilities & CLIENT_PLUGIN_AUTH { 144 | // 1 length of auth-plugin-data 145 | // } else { 146 | // 1 [00] 147 | // } 148 | var authDataLen byte 149 | authDataLen, err = b.ReadByte() 150 | if err != nil { 151 | return err 152 | } 153 | 154 | 
// string[10] reserved (all [00]) 155 | if err = b.Skip(10); err != nil { 156 | return err 157 | } 158 | 159 | // if capabilities & CLIENT_SECURE_CONNECTION { 160 | // string[$len] auth-plugin-data-part-2 ($len=MAX(13, length of auth-plugin-data - 8)) 161 | if s.Capability&ClientSecureConnection != 0 { 162 | l := int(authDataLen) - 8 - 1 163 | data, err = b.ReadBytes(l) 164 | if err != nil { 165 | return err 166 | } 167 | if l < 13 { 168 | err = b.Skip(13 - l) 169 | if err != nil { 170 | return err 171 | } 172 | } 173 | s.AuthPluginData = append(s.AuthPluginData, data...) 174 | } 175 | 176 | // if capabilities & CLIENT_PLUGIN_AUTH { 177 | // string[NUL] auth-plugin name 178 | if s.Capability&ClientPluginAuth != 0 { 179 | s.AuthPluginName, err = b.ReadStringNull() 180 | if err != nil { 181 | return err 182 | } 183 | } 184 | 185 | return nil 186 | } 187 | -------------------------------------------------------------------------------- /mysql/packet_handshake_response.go: -------------------------------------------------------------------------------- 1 | package mysql 2 | 3 | // HandshakerResponse is the initial handshake response from the client. 4 | type HandshakeResponse struct { 5 | Capability uint32 6 | MaxPacketSize uint32 7 | CharacterSet byte 8 | UserName string 9 | DBName string 10 | Auth []byte 11 | AuthPlugin string 12 | Attrs map[string]string 13 | } 14 | 15 | // Write writes the handshake response to the buffer. 
16 | func (s *HandshakeResponse) Write(b *Buffer) { 17 | // 4 capability flags 18 | b.WriteUint32(s.Capability) 19 | 20 | if s.Capability&ClientProtocol41 == 0 { 21 | // old format: https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::HandshakeResponse320 22 | 23 | // 3 max-packet size 24 | b.WriteUint24(s.MaxPacketSize) 25 | // string[NUL] username 26 | b.WriteStringNull(s.UserName) 27 | // if capabilities & CLIENT_CONNECT_WITH_DB { 28 | // string[NUL] auth-response 29 | // string[NUL] database 30 | // } else { 31 | // string[EOF] auth-response 32 | // } 33 | b.WriteBytes(s.Auth) 34 | if s.Capability&ClientConnectWithDB != 0 { 35 | b.WriteByte(0x00) 36 | b.WriteBytes([]byte(s.DBName)) 37 | b.WriteByte(0x00) 38 | } 39 | return 40 | } 41 | 42 | // new format: https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::HandshakeResponse41 43 | 44 | // 4 max-packet size 45 | b.WriteUint32(s.MaxPacketSize) 46 | // 1 character set 47 | b.WriteByte(s.CharacterSet) 48 | // string[23] reserved (all [0]) 49 | b.WriteBytes(make([]byte, 23)) 50 | // string[NUL] username 51 | b.WriteStringNull(s.UserName) 52 | // if capabilities & CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA { 53 | // lenenc-int length of auth-response 54 | // string[n] auth-response 55 | // } else if capabilities & CLIENT_SECURE_CONNECTION { 56 | // 1 length of auth-response 57 | // string[n] auth-response 58 | // } else { 59 | // string[NUL] auth-response 60 | // } 61 | if s.Capability&ClientPluginAuthLenencClientData != 0 { 62 | b.WriteLenencString(string(s.Auth)) 63 | } else if s.Capability&ClientSecureConnection != 0 { 64 | b.WriteByte(byte(len(s.Auth))) 65 | b.WriteBytes(s.Auth) 66 | } else { 67 | b.WriteStringNull(string(s.Auth)) 68 | } 69 | 70 | // if capabilities & CLIENT_CONNECT_WITH_DB { 71 | // string[NUL] database 72 | // } 73 | if s.Capability&ClientConnectWithDB != 0 { 74 | b.WriteStringNull(s.DBName) 75 | } 76 | 77 | // if capabilities & 
CLIENT_PLUGIN_AUTH { 78 | // string[NUL] auth plugin name 79 | // } 80 | if s.Capability&ClientPluginAuth != 0 { 81 | b.WriteStringNull(s.AuthPlugin) 82 | } 83 | 84 | // if capabilities & CLIENT_CONNECT_ATTRS { 85 | // lenenc-int length of all key-values 86 | // lenenc-str key 87 | // lenenc-str value 88 | if s.Capability&ClientConnectAttrs != 0 { 89 | ab := newBuffer(nil) 90 | for k, v := range s.Attrs { 91 | ab.WriteLenencString(k) 92 | ab.WriteLenencString(v) 93 | } 94 | b.WriteLenencInt(uint64(ab.Len())) 95 | b.WriteBytes(ab.Bytes()) 96 | } 97 | } 98 | 99 | // Read reads the handshake response from the buffer. 100 | func (s *HandshakeResponse) Read(b *Buffer) error { 101 | var err error 102 | // 4 capability flags 103 | s.Capability, err = b.ReadUint32() 104 | if s.Capability&ClientProtocol41 == 0 { 105 | // old format: https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::HandshakeResponse320 106 | 107 | // 3 max-packet size 108 | s.MaxPacketSize, err = b.ReadUint24() 109 | if err != nil { 110 | return err 111 | } 112 | // string[NUL] username 113 | s.UserName, err = b.ReadStringNull() 114 | if err != nil { 115 | return err 116 | } 117 | // if capabilities & CLIENT_CONNECT_WITH_DB { 118 | // string[NUL] auth-response 119 | // string[NUL] database 120 | // } else { 121 | // string[EOF] auth-response 122 | // } 123 | if s.Capability&ClientConnectWithDB != 0 { 124 | auth, err := b.ReadStringNull() 125 | if err != nil { 126 | return err 127 | } 128 | s.Auth = []byte(auth) 129 | s.DBName, err = b.ReadStringNull() 130 | if err != nil { 131 | return err 132 | } 133 | } else { 134 | s.Auth = b.Bytes() 135 | } 136 | return nil 137 | } 138 | 139 | // new format: https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::HandshakeResponse41 140 | 141 | // 4 max-packet size 142 | s.MaxPacketSize, err = b.ReadUint32() 143 | if err != nil { 144 | return err 145 | } 146 | // 1 character set 147 | s.CharacterSet, err 
= b.ReadByte() 148 | if err != nil { 149 | return err 150 | } 151 | // string[23] reserved (all [0]) 152 | err = b.Skip(23) 153 | if err != nil { 154 | return err 155 | } 156 | 157 | // Handle SSL Connection Request. 158 | if s.Capability&ClientSSL != 0 && b.Len() == 0 { 159 | return nil 160 | } 161 | 162 | // string[NUL] username 163 | s.UserName, err = b.ReadStringNull() 164 | if err != nil { 165 | return err 166 | } 167 | // if capabilities & CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA { 168 | // lenenc-int length of auth-response 169 | // string[n] auth-response 170 | // } else if capabilities & CLIENT_SECURE_CONNECTION { 171 | // 1 length of auth-response 172 | // string[n] auth-response 173 | // } else { 174 | // string[NUL] auth-response 175 | // } 176 | if s.Capability&ClientPluginAuthLenencClientData != 0 { 177 | l, err := b.ReadLenencInt() 178 | if err != nil { 179 | return err 180 | } 181 | s.Auth, err = b.ReadBytes(int(l)) 182 | if err != nil { 183 | return err 184 | } 185 | } else if s.Capability&ClientSecureConnection != 0 { 186 | l, err := b.ReadByte() 187 | if err != nil { 188 | return err 189 | } 190 | s.Auth, err = b.ReadBytes(int(l)) 191 | if err != nil { 192 | return err 193 | } 194 | } else { 195 | auth, err := b.ReadStringNull() 196 | if err != nil { 197 | return err 198 | } 199 | s.Auth = []byte(auth) 200 | } 201 | // if capabilities & CLIENT_CONNECT_WITH_DB { 202 | // string[NUL] database 203 | // } 204 | if s.Capability&ClientConnectWithDB != 0 { 205 | s.DBName, err = b.ReadStringNull() 206 | if err != nil { 207 | return err 208 | } 209 | } 210 | // if capabilities & CLIENT_PLUGIN_AUTH { 211 | // string[NUL] auth plugin name 212 | // } 213 | if s.Capability&ClientPluginAuth != 0 { 214 | s.AuthPlugin, err = b.ReadStringNull() 215 | if err != nil { 216 | return err 217 | } 218 | } 219 | // if capabilities & CLIENT_CONNECT_ATTRS { 220 | // lenenc-int length of all key-values 221 | // lenenc-str key 222 | // lenenc-str value 223 | if 
s.Capability&ClientConnectAttrs != 0 { 224 | l, err := b.ReadLenencInt() 225 | if err != nil { 226 | return err 227 | } 228 | data, err := b.ReadBytes(int(l)) 229 | if err != nil { 230 | return err 231 | } 232 | ab := newBuffer(data) 233 | for ab.Len() > 0 { 234 | k, err := ab.ReadLenencString() 235 | if err != nil { 236 | return err 237 | } 238 | v, err := ab.ReadLenencString() 239 | if err != nil { 240 | return err 241 | } 242 | if s.Attrs == nil { 243 | s.Attrs = make(map[string]string) 244 | } 245 | s.Attrs[k] = v 246 | } 247 | } 248 | 249 | return nil 250 | } 251 | -------------------------------------------------------------------------------- /mysql/protocol_test.go: -------------------------------------------------------------------------------- 1 | package mysql 2 | 3 | import ( 4 | "encoding/json" 5 | "testing" 6 | 7 | "github.com/stretchr/testify/assert" 8 | ) 9 | 10 | func TestProtocol(t *testing.T) { 11 | hs1 := Handshake{ 12 | ProtocolVersion: DefaultHandshakeVersion, 13 | ServerVersion: "5.7.25-TiDB", 14 | ConnectionID: 1, 15 | AuthPluginData: make([]byte, 20), 16 | Capability: DefaultCapability, 17 | CharacterSet: DefaultCollationID, 18 | StatusFlags: ServerStatusAutocommit, 19 | AuthPluginName: AuthNativePassword, 20 | } 21 | b := newBuffer(nil) 22 | hs1.Write(b) 23 | var hs2 Handshake 24 | b2 := newBuffer(b.Bytes()) 25 | hs2.Read(b2) 26 | 27 | assert.Equal(t, toJson(hs2), toJson(hs1)) 28 | } 29 | 30 | func toJson(x interface{}) string { 31 | jb, _ := json.Marshal(x) 32 | return string(jb) 33 | } 34 | -------------------------------------------------------------------------------- /mysql/util.go: -------------------------------------------------------------------------------- 1 | package mysql 2 | 3 | func readLen3(b []byte) int { 4 | return int(uint32(b[0]) | uint32(b[1])<<8 | uint32(b[2])<<16) 5 | } 6 | 7 | func writeLen3(b []byte, n int) { 8 | b[0] = byte(n) 9 | b[1] = byte(n >> 8) 10 | b[2] = byte(n >> 16) 11 | } 12 | 
-------------------------------------------------------------------------------- /utility/logger.go: -------------------------------------------------------------------------------- 1 | package utility 2 | 3 | import "go.uber.org/zap" 4 | 5 | func GetLogger() *zap.SugaredLogger { 6 | logger, _ := zap.NewDevelopment() 7 | return logger.Sugar() 8 | } 9 | --------------------------------------------------------------------------------