├── .gitignore ├── .golangci.yml ├── .idea ├── .gitignore ├── codeStyles │ └── codeStyleConfig.xml ├── copyright │ ├── Geert_JM.xml │ └── profiles_settings.xml ├── dataSources.xml ├── inspectionProfiles │ └── Project_Default.xml ├── modules.xml ├── pxmysql.iml ├── sqldialects.xml ├── vcs.xml └── watcherTasks.xml ├── CHANGELOG.yaml ├── LICENSE.md ├── README.md ├── _badges ├── badges.json ├── go-version.svg └── license.svg ├── _support └── pxmysql-compose │ ├── .gitignore │ ├── conf.d │ ├── .gitignore │ ├── 01_basics.cnf │ ├── 02_tls.cnf │ ├── ca-key.pem │ ├── ca.pem │ ├── ca.srl │ ├── client-cert.pem │ ├── client-key.pem │ ├── server-cert.pem │ └── server-key.pem │ ├── conf.srl │ ├── docker-compose.yml │ ├── generate_tls.sh │ ├── server_x509_ext.conf │ └── shared │ ├── build.sh │ └── goapps │ └── unix_socket │ ├── go.mod │ ├── go.sum │ └── main.go ├── cmd ├── .gitignore ├── gencollations │ └── main.go ├── genprotobuf │ └── main.go └── make │ └── main.go ├── connection.go ├── connection_test.go ├── connector.go ├── datasource.go ├── datasource_test.go ├── decimal ├── decimal.go └── decimal_test.go ├── driver.go ├── driver_test.go ├── errors.go ├── go.mod ├── go.sum ├── interfaces └── message.go ├── internal ├── mysqlx │ ├── info.md │ ├── mysqlx │ │ └── mysqlx.pb.go │ ├── mysqlxconnection │ │ └── mysqlx_connection.pb.go │ ├── mysqlxcrud │ │ └── mysqlx_crud.pb.go │ ├── mysqlxcursor │ │ └── mysqlx_cursor.pb.go │ ├── mysqlxdatatypes │ │ └── mysqlx_datatypes.pb.go │ ├── mysqlxexpect │ │ └── mysqlx_expect.pb.go │ ├── mysqlxexpr │ │ └── mysqlx_expr.pb.go │ ├── mysqlxnotice │ │ └── mysqlx_notice.pb.go │ ├── mysqlxprepare │ │ └── mysqlx_prepare.pb.go │ ├── mysqlxresultset │ │ └── mysqlx_resultset.pb.go │ ├── mysqlxsession │ │ └── mysqlx_session.pb.go │ └── mysqlxsql │ │ └── mysqlx_sql.pb.go └── xxt │ ├── builder.go │ ├── context.go │ ├── credentials.go │ ├── docker.go │ ├── errors.go │ ├── memory.go │ └── server.go ├── main_test.go ├── mysqlerrors ├── error.go ├── 
error_client.go ├── error_test.go └── test_errors │ └── mysqlerrors_test.go ├── null ├── bytes.go ├── bytes_test.go ├── decimal.go ├── decimal_test.go ├── duration.go ├── duration_test.go ├── float32.go ├── float32_test.go ├── float64.go ├── float64_test.go ├── int64.go ├── int64_test.go ├── main.go ├── main_test.go ├── string.go ├── string_test.go ├── strings.go ├── strings_test.go ├── time.go ├── time_test.go ├── uint64.go └── uint64_test.go ├── register ├── mysql │ ├── register.go │ └── register_test.go ├── register.go └── register_test.go ├── result.go ├── rows.go ├── rows_test.go ├── statement.go ├── statement_test.go ├── transaction.go └── xmysql ├── _testdata ├── .gitignore ├── base.sql ├── data_types_datetime.sql ├── data_types_numeric.sql ├── data_types_string.sql ├── inserting.sql ├── prepared_stmt.sql └── schema_collections.sql ├── authentication.go ├── capabilities.go ├── collations.go ├── collations_data.go ├── collations_test.go ├── collection.go ├── collection ├── create.go └── get.go ├── collection_test.go ├── connection_config.go ├── context.go ├── context_test.go ├── crud_add.go ├── defaults.go ├── errors.go ├── examples_test.go ├── internal ├── network │ ├── messages.go │ ├── proto.go │ ├── read.go │ ├── tls.go │ ├── trace.go │ └── write.go └── statements │ ├── quote.go │ ├── quote_test.go │ ├── statement_test.go │ └── statements.go ├── main_test.go ├── notice.go ├── prepared.go ├── prepared_test.go ├── result.go ├── result_test.go ├── schema.go ├── schema_test.go ├── session.go ├── session_test.go └── xproto ├── command.go ├── expr.go ├── expr_test.go ├── fields.go ├── scalar.go └── scalar_test.go /.gitignore: -------------------------------------------------------------------------------- 1 | ?dev_* 2 | *.proto 3 | cmd/scratch 4 | _testdata/*.pem 5 | 6 | .DS_Store -------------------------------------------------------------------------------- /.golangci.yml: -------------------------------------------------------------------------------- 1 | 
run: 2 | skip-dirs: 3 | - cmd/scratch 4 | - internal/mysqlx 5 | -------------------------------------------------------------------------------- /.idea/.gitignore: -------------------------------------------------------------------------------- 1 | # Default ignored files 2 | /shelf/ 3 | /workspace.xml 4 | # Editor-based HTTP Client requests 5 | /httpRequests/ 6 | # Datasource local storage ignored files 7 | /dataSources/ 8 | /dataSources.local.xml 9 | -------------------------------------------------------------------------------- /.idea/codeStyles/codeStyleConfig.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 5 | -------------------------------------------------------------------------------- /.idea/copyright/Geert_JM.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 6 | -------------------------------------------------------------------------------- /.idea/copyright/profiles_settings.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 10 | 11 | -------------------------------------------------------------------------------- /.idea/dataSources.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | mysql.8 6 | true 7 | com.mysql.cj.jdbc.Driver 8 | jdbc:mysql://localhost:53306 9 | $ProjectFileDir$ 10 | 11 | 12 | -------------------------------------------------------------------------------- /.idea/inspectionProfiles/Project_Default.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 8 | -------------------------------------------------------------------------------- /.idea/modules.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /.idea/pxmysql.iml: 
-------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | -------------------------------------------------------------------------------- /.idea/sqldialects.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | -------------------------------------------------------------------------------- /.idea/vcs.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 8 | 9 | 10 | 12 | 13 | 14 | 15 | 16 | 17 | -------------------------------------------------------------------------------- /.idea/watcherTasks.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 16 | 28 | 29 | -------------------------------------------------------------------------------- /CHANGELOG.yaml: -------------------------------------------------------------------------------- 1 | pxmysql: 2 | - meta: 3 | projectURL: https://github.com/golistic/pxmysql 4 | description: | 5 | Go MySQL driver using X Protocol communicating with the MySQL server using 6 | Protocol Buffers. 7 | 8 | All notable changes to this project will be documented in this file. 9 | We follow the conventionalcommits.org specification. 10 | 11 | Change entries with prefix `(!)` warn for a "breaking change". 12 | - versions: 13 | - version: v0.9 14 | date: 2023-01-24 15 | description: Initial development release (not production ready). 16 | patches: 17 | - version: v0.9.8 18 | date: 2023-08-27 19 | refactor: 20 | driver: 21 | - (!) We move the registration of the `sql`-driver "pxmysql" to the subpackage 22 | `github.com/golistic/pxmysql/register` (driver name "mysql" to `../register/mysql`). 23 | Refactoring should not break things, but this does. Users must change the (anonymous) 24 | import using the new sub-package. 25 | - Use `github.com/golistic/xgo/xsql` for managing the data source name. 
26 | build: 27 | - Dependencies have been tidied and updated where needed. 28 | - version: v0.9.7 29 | date: 2023-08-15 30 | fixed: 31 | driver: 32 | - Properly deallocate prepared statements when using the connection methods 33 | `ExecContext` and `QueryContext` preventing the server from reaching the maximum 34 | number of prepared statements. 35 | refactor: 36 | general: 37 | - Clean up dependencies and use `golistic/xgo` instead of the now deprecated 38 | subpackages within `golistic` or `github.com/geertjanvdk/xkit`. 39 | build: 40 | general: 41 | - Go version has been upped to 1.21 to make it clear that we eventually might 42 | use some features from that version. 43 | - version: v0.9.6 44 | date: 2023-08-09 45 | fixed: 46 | driver: 47 | - `pxmysql.QueryContext()` will now correctly return an empty Rows object when 48 | the result has no rows, instead of returning `sql.ErrNoRows`. 49 | - (!) Go `sql` driver is now named `pxmysql` so it aligns with the package name; 50 | we do not keep backward compatibility. 51 | - We support the driver name "mysql" as some projects need to use this name. When 52 | this is needed, load anonymous sub-package `github.com/golistic/pxmysql/mysql`. 53 | added: 54 | driver: 55 | - We support the driver name "mysql" as some projects need to use this name. When 56 | this is needed, load anonymous sub-package `github.com/golistic/pxmysql/mysql`. 57 | build: 58 | - Upgrade ProtoBuf MySQL code to MySQL 8.0.34 (but no changes). 59 | - version: v0.9.5 60 | date: 2023-05-28 61 | fixed: 62 | - handling too large packets 63 | - wrap driver.ErrBadConn 64 | - version: v0.9.4 65 | date: 2023-05-10 66 | fixed: 67 | - Recover from server timing out connections. 68 | - version: v0.9.3 69 | date: 2023-05-10 70 | changed: 71 | - Updated protocol buffer generated code and collations to MySQL 8.0.32. 72 | added: 73 | - Added golistic/gomake targets for linting, reporting, and badges. 74 | - Added badges, generated/stored within repository, to README.md. 
75 | fixed: 76 | - Fixed error returned when Unix socket is not available. 77 | - Fixed cmd/gencollations to use TLS and set password as valid nullable. 78 | - Fixed linting issues reported by linters run by golangci-lint. 79 | - Replaced deprecated package golang.org/x/crypto/ssh/terminal. 80 | - Fixed handling DATETIME zero values for time parts. 81 | - version: v0.9.2 82 | date: 2023-02-14 83 | changed: 84 | - Fixed naming of `pxmysql.ParseDSN` (before it was `ParseDNS`). 85 | fixed: 86 | - Fixed including query part when getting string representation of DataSource. 87 | - Fixed slash detection when not using schema name together with query part. 88 | - version: v0.9.1 89 | date: 2023-02-03 90 | fixed: 91 | - Fixed parsing query string of DSN so `useTLS` works as expected. 92 | - Fixed using connection address without TCP port. 93 | - Fix wrapping errors. 94 | - Finish testing Unix socket support. 95 | -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022, 2023, Geert JM Vanderkelen 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /_badges/badges.json: -------------------------------------------------------------------------------- 1 | { 2 | "style": "flat-square", 3 | "urlShieldsIO": "https://img.shields.io/static/v1", 4 | "destFolder": "_badges", 5 | "badges": [ 6 | { 7 | "name": "go-version", 8 | "label": "GO", 9 | "messageFunc": "go.mod.version", 10 | "color": "#00ADD8" 11 | }, 12 | { 13 | "name": "license", 14 | "label": "License", 15 | "message": "MIT", 16 | "color": "#97ca00" 17 | } 18 | ] 19 | } -------------------------------------------------------------------------------- /_badges/go-version.svg: -------------------------------------------------------------------------------- 1 | GO: 1.21GO1.21 -------------------------------------------------------------------------------- /_badges/license.svg: -------------------------------------------------------------------------------- 1 | License: MITLicenseMIT -------------------------------------------------------------------------------- /_support/pxmysql-compose/.gitignore: -------------------------------------------------------------------------------- 1 | data*/ -------------------------------------------------------------------------------- /_support/pxmysql-compose/conf.d/.gitignore: -------------------------------------------------------------------------------- 1 | *req.pem -------------------------------------------------------------------------------- /_support/pxmysql-compose/conf.d/01_basics.cnf: -------------------------------------------------------------------------------- 1 | [mysqld] 2 | skip_name_resolve # DNS not important for testing 3 | mysqlx_socket = 
mysqlx.sock -------------------------------------------------------------------------------- /_support/pxmysql-compose/conf.d/02_tls.cnf: -------------------------------------------------------------------------------- 1 | [mysqld] 2 | ssl_ca=/etc/mysql/conf.d/ca.pem 3 | ssl_cert=/etc/mysql/conf.d/server-cert.pem 4 | ssl_key=/etc/mysql/conf.d/server-key.pem 5 | -------------------------------------------------------------------------------- /_support/pxmysql-compose/conf.d/ca-key.pem: -------------------------------------------------------------------------------- 1 | -----BEGIN RSA PRIVATE KEY----- 2 | MIIEogIBAAKCAQEAucVTs35msvYgC9iu6elW3aXzknxW3CmQasgjD7VkufT44OgG 3 | /ce10Hjp3QrDSqXqXCpVBxi8bxWHzCnSwtSh6K7/HjoC8akeVzv9HIqjUUCEpQWT 4 | 9eayDG7hPrq1wJ8Hc6krh1h//dQlDNQeUG+uuhvknA83H01yyJKQcWBN2mMLZuBN 5 | apNEAcdsWjq5rrRxj4B0Igja1g7E4XQnYfy8fY+Uc0a8G1wt01+3ftSrpa0+jJug 6 | 0c6S8H14TgDjGtGqDK2wmy4QV4rcKXqkO9E1E1Gh8LmGkexCw49pgeujzh8TqvdU 7 | 5o2jxh3YSPbZWkNfilqE2SrYAiF2X99LNRlYvwIDAQABAoIBADusg2KZK+w427py 8 | dF13MwwoDsHzZwN55oYmm/yjzCNf6cJ1RimnSWQaMyVqG6mS+mF4x69r5rvYMrMG 9 | jElBfHD+Jb1T7TYrmS90ea39atDi5LkNvaWz4WXVCE3aNCAX9ZDVusHTT+n9h5lD 10 | WimEdqAZ7amjyZUoj8KWMgf5Y4jO07iOOD3FH5g3VW3D8C/bl4hQuddlFveKp708 11 | UeGFE9EobDf1YU/vF+J+tVZC5xaniPiSicQsyebuTrTmT4Z5niVZu2C9dynzujWn 12 | 3zBjIX8iwmQdjznLCOOWJKnKr5GP+eyW9NHvpC+DFWxtp3q0gn6G8Q65ICkM+UWw 13 | sqawrIECgYEA5wDLju75zLrxhRpn5gAC1z23gPa+B0APy93fMgbVYFKFI/jk0LSu 14 | 4fCPDsQZ2OfmBzQO4Q8i8mYLvypywP1lPiaquDYSe1Ykp7pfCXEp5ZTc1OL6Zudv 15 | Hb6v7QvJ7VYXrtU3/Wy9fVDtGNQfI9J6JvXIknQbSvn9COlvSGnYNRECgYEAzd+D 16 | 4U9qm4wESYpTXM4ea9HB61YcGhu9gpLIF2H56aOSRzm382kzjI9QlpMJEiaeBXyy 17 | RZa3Uw7w5pRkGhRcwFpEJBfzvvEFQqtM9yiZw+n/2S3ZvRPeDe5dH8BqRO/KKRiq 18 | 2LSQ/OEu78/cq/GU0dAKiLmrmzVp6HHK6usfcM8CgYBWO8zBieKEk9DvYEEi8iQd 19 | V7O2F+YubLK45xWX5kcnUwbSu+onIxwZyiSNXZVMjJ0pWTyotW7VUFTYQy9dbfqq 20 | beLTK5RQqIK8fm1V6AG864pYinbxjTnEv9eKxRjXWYkzwfLJzxsZuekYmK8bP0pM 21 | WvpJ+b/qiFH2TrY1MRX+EQKBgGAFWzZwWxHXmXxPZxhHDstNFzxTemH3BEnteiPl 
22 | z7FYWHaeBh0iuSdbBMRmKfnsRxHaGi/43uJ/en6hQZskWiphL50CCu7I7aIt0YUJ 23 | y8Yj0vARwZe9t3kZ7xdLIIWsrcbDOZQ/i8xWnxS9B3ivAbFmbjNdHhwTKqV+xZ0S 24 | MyTjAoGACgtKzGhSdZYgJX4tapuRV+f0rcvN6dvKMleWtShIfA+g0hJwzOmMe1hw 25 | JFAbK433pXQYvnpCYbajPm9wcjygvDotjvjdw5d1nIEBJP6VuLux10OwKxno4+1c 26 | HgGhCcrueb9mfwUQ3FvooJ1bsSkkU+8qV25GwQjk11vL9qyDA8I= 27 | -----END RSA PRIVATE KEY----- 28 | -------------------------------------------------------------------------------- /_support/pxmysql-compose/conf.d/ca.pem: -------------------------------------------------------------------------------- 1 | -----BEGIN CERTIFICATE----- 2 | MIICyjCCAbICCQC/DWKJCY0kvjANBgkqhkiG9w0BAQsFADAnMQswCQYDVQQGEwJE 3 | RTEYMBYGA1UEAwwPcHhteXNxbC10ZXN0LUNBMB4XDTIyMTIxMzA2MzE1NFoXDTMy 4 | MTAyMTA2MzE1NFowJzELMAkGA1UEBhMCREUxGDAWBgNVBAMMD3B4bXlzcWwtdGVz 5 | dC1DQTCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEBALnFU7N+ZrL2IAvY 6 | runpVt2l85J8VtwpkGrIIw+1ZLn0+ODoBv3HtdB46d0Kw0ql6lwqVQcYvG8Vh8wp 7 | 0sLUoeiu/x46AvGpHlc7/RyKo1FAhKUFk/Xmsgxu4T66tcCfB3OpK4dYf/3UJQzU 8 | HlBvrrob5JwPNx9NcsiSkHFgTdpjC2bgTWqTRAHHbFo6ua60cY+AdCII2tYOxOF0 9 | J2H8vH2PlHNGvBtcLdNft37Uq6WtPoyboNHOkvB9eE4A4xrRqgytsJsuEFeK3Cl6 10 | pDvRNRNRofC5hpHsQsOPaYHro84fE6r3VOaNo8Yd2Ej22VpDX4pahNkq2AIhdl/f 11 | SzUZWL8CAwEAATANBgkqhkiG9w0BAQsFAAOCAQEARq0IJN06ajIEkHclKa8eIjOH 12 | HuYH1c38cJsXQN1BXE9Sv2EalrTRrwXSsAwHvKhRClbJeodlLXbC+C4gX+fywX/m 13 | Oc5SvJNGdUZyT1pqF0wPjIfYx9aO9sXsh/yE8ePwhxAeRU2i68QW6wB+AB6AnTIl 14 | op1v8gtzXHaLsUdiDD++98LEHEbYiUElnQBKwG2mzt223gCIIMlfUIHF6Hd2APKg 15 | ChAkUS2buFx93+b1ik4l63HAFqXue1X0Q8W4eMubbU0Ti2AkkLtOfwlTLrSEOOtf 16 | rPMMVMfwpIeq8tkrVNLm0zrKUSXP2cu1BmffZUoaewz/5yv5eybK00h4iqldwQ== 17 | -----END CERTIFICATE----- 18 | -------------------------------------------------------------------------------- /_support/pxmysql-compose/conf.d/ca.srl: -------------------------------------------------------------------------------- 1 | 5DE6BF0D6CC6DAF943F4306AE1EA712AD007F2DF 2 | 
-------------------------------------------------------------------------------- /_support/pxmysql-compose/conf.d/client-cert.pem: -------------------------------------------------------------------------------- 1 | -----BEGIN CERTIFICATE----- 2 | MIICzjCCAbYCCQC8fGqNl/yoxjANBgkqhkiG9w0BAQsFADAnMQswCQYDVQQGEwJE 3 | RTEYMBYGA1UEAwwPcHhteXNxbC10ZXN0LUNBMB4XDTIyMTIxMzA2MzE1NFoXDTMy 4 | MTAyMTA2MzE1NFowKzELMAkGA1UEBhMCREUxHDAaBgNVBAMME3B4bXlzcWwtdGVz 5 | dC1jbGllbnQwggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQC/hOzDqY4I 6 | 85p8VK1rtqkSgqNAg3tSixRuucwZRjW3wEuNbO1ZxW0hzVqXGxNp/hLzlFaa/YOe 7 | 3oQTDdztdUOysDVVUnnytx/8uUtMgcsYrMw8RWPZbazCwwLPW0v3KHigqyC+3A2y 8 | voRYMfUeS4MR2AWnGzJSxS5wM0Pyd2VX31slapFZdG/bDxFKolQDa9nkEG5WWoj2 9 | pLBPajDdXkYequUQvmRURQkv0lkIy6TX16Y3dmwepgbX2lIROQy8TcwY+iisge+O 10 | w4IEW2XHCZLr7deIzcdU3Q50ttFJCYsThn2hs0zHAzk6gccSMdPWcChpQfU0f0hV 11 | whe9bjTs3W/zAgMBAAEwDQYJKoZIhvcNAQELBQADggEBABVInbT5lKQ40C9frZHs 12 | IEhKDWtkXuy6pQAdTZu+ZasMR9RJx5DtJVbijs7DRvXu8Zbiiv7GpuVUEGxGqZBs 13 | CiA122aIr4KX36SfC4/apA127WbcYI1a6jBkYU+B8shXRcCeD23MaNlMZ05U8odd 14 | YFG/FHBaiJ3+atrqCkuqXl9fXCRWBWlClhxuzS+16RU6uQoqboFwZLobeN8zzVX7 15 | hw7qDkS/6yWGtHxDB7yBPil+cj3IKqQ5pEnrpufpKBN2yiB6Rrh6dNVZdfaSUutV 16 | rrei0lpWQBAfhe/3On/pXfuUtadWZbKR0EGItghVYoHwyh0Eq1zvPJpopt016qVm 17 | sPM= 18 | -----END CERTIFICATE----- 19 | -------------------------------------------------------------------------------- /_support/pxmysql-compose/conf.d/client-key.pem: -------------------------------------------------------------------------------- 1 | -----BEGIN RSA PRIVATE KEY----- 2 | MIIEpQIBAAKCAQEAv4Tsw6mOCPOafFSta7apEoKjQIN7UosUbrnMGUY1t8BLjWzt 3 | WcVtIc1alxsTaf4S85RWmv2Dnt6EEw3c7XVDsrA1VVJ58rcf/LlLTIHLGKzMPEVj 4 | 2W2swsMCz1tL9yh4oKsgvtwNsr6EWDH1HkuDEdgFpxsyUsUucDND8ndlV99bJWqR 5 | WXRv2w8RSqJUA2vZ5BBuVlqI9qSwT2ow3V5GHqrlEL5kVEUJL9JZCMuk19emN3Zs 6 | HqYG19pSETkMvE3MGPoorIHvjsOCBFtlxwmS6+3XiM3HVN0OdLbRSQmLE4Z9obNM 7 | xwM5OoHHEjHT1nAoaUH1NH9IVcIXvW407N1v8wIDAQABAoIBAQCOX/LjQhkk7nPa 8 | 
GdkSSihGaneSbiwvoNT/u3/PCjLE918zM9b+9ZW7mz3NN4OnOAo+qff4IJ7IbAMj 9 | ZxrmLFa3b+c2FqoxlZFh/x3LMnIZVdw+shcYfEACSZa9L9G5W4zRZGZjfJNyXc9l 10 | AT6H1vsJON566+ztO0jagEHy7m+YcliHTzwzod/9melbto7AwedlbJNut/j3aFVz 11 | 6wN5NkaQu7C/NiNQ26N9QDnNSVoYlFOB9opCNVmLcHK/hdSNjNcOB3UpKePQXBYo 12 | GKkN/XtuZBSVhw9d6tEub+C5oBOOQ3EJuX5WwST/aj1cgB6HGlSahmie0r94ZWxX 13 | OC4lOxYBAoGBAPGa/V/1bvJVhq65sM543nuC4tMZUgOb5mQp+oumtPplnG3RPNP7 14 | np9+xoAaUz55KYkf74gC4Ux+10A/qXwrzJ9YIX8+ZpLNJZwWfQUmheZOXowfP0HL 15 | 9oMBRXmy1ivRUnToqR6fkSr+yHVJ3AJ/oy0MDVNrjs0YWYIu+tNvt+YjAoGBAMru 16 | Aug7l8HKgQi/Wg1ylpLi4Xl3UcqiHSW+soctcjrJm1z10pih92+Hqcg21wbFAR0P 17 | ESAslcD8OlCJYblSXi0zXE1eI111iCxcF/X9CyY0o+BsBIiJA+QDtIUwIfvk2Ka1 18 | nx0rkVpQ1H3ilDm4BpgMm0SFdCXDj+e2cgHpjSPxAoGBAPD9YeJHU4UQ3iiGO9+X 19 | HIQiR9G8ndvPs30RikGl5TsmA2ReosfnYY9BywmYOJRGErIeUrRd+xBsLJR/a7TZ 20 | k18Vb0QWoAWp7uvEWqu6gzD31sL5oAUnRxnhOMVtJsfKIO9P6vEKxKgYPycOpw8u 21 | 9TpHnTsqO+RDd3StG6+u7cX1AoGBAKAWdfqpEIZT78lr02nqbPkBvShqxf6aN25Q 22 | a1ySsJvJ8iO61eGNXLsChiEpiiaQAdnfyf3czmMJWCOyzYI6hYsZCocKbdHL55o/ 23 | KLPpZQNF4cYo0Ma5eHVHqwCrQRQLrBKQEy8a8LcULx4EQjTqhWEsCM1cjo1AIuWE 24 | G5qAmdSxAoGAe43Q/b3WFTLorGORzZ0ro9VILYp1iqAm2A2k4Fms1OiZP1MXksXU 25 | +kLG0ovtpGbxnzqd/ybyZWWWTr6q06aTXAhAzttgM1f2twk66oXbIpbMO1/RG/h5 26 | 13d4xXJFsDDASwH41cmcj92/sxUbRg4comNKmMUkJPGQt5EssVJwCJ4= 27 | -----END RSA PRIVATE KEY----- 28 | -------------------------------------------------------------------------------- /_support/pxmysql-compose/conf.d/server-cert.pem: -------------------------------------------------------------------------------- 1 | -----BEGIN CERTIFICATE----- 2 | MIIC6DCCAdCgAwIBAgIJALx8ao2X/KjFMA0GCSqGSIb3DQEBCwUAMCcxCzAJBgNV 3 | BAYTAkRFMRgwFgYDVQQDDA9weG15c3FsLXRlc3QtQ0EwHhcNMjIxMjEzMDYzMTU0 4 | WhcNMzIxMDIxMDYzMTU0WjArMQswCQYDVQQGEwJERTEcMBoGA1UEAwwTcHhteXNx 5 | bC10ZXN0LXNlcnZlcjCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEBAK0d 6 | K4okwsm61i5KX200bAeSzBujCLKp+TBr3PxhziSEL0i4ib904ipAwqYKdLnMiuzk 7 | 
QFGXNo0UriSwBeiSZDCTWPUGyxic8iI+pP+83OvxP20DWNaxizP3QxwdLDnsXqZm 8 | KQriYYBIsAul6bc5MnzVRRUVkVdH8xc/LvVfsqYkEEnpsbM1hQY2vsOhRvs0hnJa 9 | 3Q6HaFFLpmtwmmrnostIhAw79zyDec7D0naXy9+zhAPV9AwI4yieBxIs+6zdO+dR 10 | BmTNN4yrY8sVsRrqHVVs/KMiAOM7xzUVSlL7g6PepNaTFno6jZYzDVZu00I/l9KL 11 | jsZyYAWzCX0RntaGqnUCAwEAAaMTMBEwDwYDVR0RBAgwBocEfwAAATANBgkqhkiG 12 | 9w0BAQsFAAOCAQEApccqby/++1A7EFv5+jTnRVgnB3fQ1POzPbeBADye35mJNmcU 13 | Ks7qCj22gGRNO/INW71j2GBtyaTzT0FtZytKIMWv3kC6kgtVYo3r6cfOQihD+T3o 14 | G1PNKLicGgmfaSyHhtTA6WK3CLAQu/yS4KARIWi3X88H7zzBfKHZ23+rNGzILtGv 15 | 3RiyP4hPD+y1a4nObBjcSM5F09qRZzQpcpwldrZGSZrbrqL9uqeE4V7403sFuqNt 16 | IycSNEchF7vAm/0oqbYsuRQkjZyk77OiDxhWrtWSDkUdx6v+CAaCme3DD9SK/0l/ 17 | vJhrCW2HUmXnrYEWVkzX6mverbTDAQ5v4I/sAg== 18 | -----END CERTIFICATE----- 19 | -------------------------------------------------------------------------------- /_support/pxmysql-compose/conf.d/server-key.pem: -------------------------------------------------------------------------------- 1 | -----BEGIN RSA PRIVATE KEY----- 2 | MIIEpgIBAAKCAQEArR0riiTCybrWLkpfbTRsB5LMG6MIsqn5MGvc/GHOJIQvSLiJ 3 | v3TiKkDCpgp0ucyK7ORAUZc2jRSuJLAF6JJkMJNY9QbLGJzyIj6k/7zc6/E/bQNY 4 | 1rGLM/dDHB0sOexepmYpCuJhgEiwC6XptzkyfNVFFRWRV0fzFz8u9V+ypiQQSemx 5 | szWFBja+w6FG+zSGclrdDodoUUuma3Caaueiy0iEDDv3PIN5zsPSdpfL37OEA9X0 6 | DAjjKJ4HEiz7rN0751EGZM03jKtjyxWxGuodVWz8oyIA4zvHNRVKUvuDo96k1pMW 7 | ejqNljMNVm7TQj+X0ouOxnJgBbMJfRGe1oaqdQIDAQABAoIBAQCQN13fTvKrZigq 8 | FjFbY7GfuY6qc266kNmUmjdWVhCK4UgW+A1hX3lOo/bEpq9JXfpakWh30FZUv+a3 9 | j6DMeLBYu1f/gLJPheg92RxSJL+TG76wDXrEGNKT7yiMUk1Wz/CmBTOp6qA5Y9St 10 | T4Hd7xt9XZqYjwguwzTjp/Jx3lCRENlzAPBgIeXmEvAFWdT9XNaI4PcDBR9uGCp7 11 | HOg9GyGrsPdExKDQtukrJzNKu/63oFrQKbibks3cVWwtKEq0Xyz/tt5I4yACG8AB 12 | DFCrmA+e04C8OzliGupOI4GRrQhYxNOdNSxqJU04x9q1JyqHfZ0Nf9pevdYKA9Z7 13 | rwK5Vp79AoGBANNBbhdV0hjqjMNO99Nds9Rscg30dfojbxUxulcl/N4OAE5UpkmC 14 | vFwNu0ktfMz8OBlgkaW9lSRS6KsCtnh/cLNKB14VErM58J1qx0/9PUJSPbNipuVn 15 | hk0ZV3TGgxV1mYqZhNerKaUPJ7ZNPCe9h0lh6OEV/UlQiq7QErAMwwpjAoGBANHH 16 | 
pcXwPSMjtf6pXHaDF0es6DH8dQB2ZpsaiJwm6AbMJvwFfJEaYApnEzTEtoPahD5i 17 | 24iPW9IVBERKJaTAZe4QYwd0gaQYFPmMEPJZtCyD9cwtOUgbIVMSsCraYhLXg9ou 18 | 6QRpaXdERiLzmffO+QUkfPYofggONuuX3YW9s+NHAoGBAJyk0JAnB7GIAbY0kNi+ 19 | i0CA5RVp5i0DJzQM+oHyXgz9TsbGR8MMWMTdPbkmLHsGrkZK79R4veUAQRvE2C6D 20 | OLsIsmvVrlcNKFhhO8cZHNpXhv7DsMM7vz7eApZJOBuqZp559SHB/hAxK54mqOtC 21 | wtTr77UvC+/X8+1pxeGapOjHAoGBANAeW64mCuFjulituReyMlRfi/SbW5Bb5quW 22 | BVW1m5eyzjJVVyG1ovZvEDTXu6LQFUa3WMkAQL4JL7R4QyRR5E3sX/KzeTJM2fJB 23 | LUbiC8fmGuK3Mw8AK215KuE4yveabCr3QyGnWoSCbXqbZnLdGVwquPaVcYOYZpAQ 24 | mCro6yBdAoGBAMYcSrtG8T4/UsN0j7vyjRm0+VKDci4zjvzpKR5O8597tnvY8R63 25 | b60rkqogT0x/yTD7qbgBgJ41Z4o3W7WilVt10E/eA4sStdhepQ6NPXQnPKIwMMof 26 | ZkVAH0mGvUuYf1lFn0DKidBe823nS/ezdhckjrZc8iIAPbz+frSZnHCq 27 | -----END RSA PRIVATE KEY----- 28 | -------------------------------------------------------------------------------- /_support/pxmysql-compose/conf.srl: -------------------------------------------------------------------------------- 1 | BC7C6A8D97FCA8C6 2 | -------------------------------------------------------------------------------- /_support/pxmysql-compose/docker-compose.yml: -------------------------------------------------------------------------------- 1 | networks: 2 | pxmysql.test.net: 3 | driver: bridge 4 | 5 | volumes: 6 | shared: 7 | 8 | services: 9 | mysql: 10 | container_name: pxmysql.test.db 11 | image: mysql:8.0 12 | environment: 13 | MYSQL_ROOT_PASSWORD: rootpwd 14 | command: --default-authentication-plugin=mysql_native_password 15 | networks: 16 | - pxmysql.test.net 17 | volumes: 18 | - ./data:/var/lib/mysql 19 | - ./conf.d:/etc/mysql/conf.d 20 | - ./shared:/shared:ro 21 | ports: 22 | - 127.0.0.1:53306:3306 23 | - 127.0.0.1:53360:33060 24 | 25 | go: 26 | container_name: pxmysql.test.go 27 | image: golang:1.21-bullseye 28 | tty: true 29 | stdin_open: true 30 | volumes: 31 | - ./shared:/shared 32 | - ../../../pxmysql:/go/src/github.com/golistic/pxmysql:ro 33 | 34 | 
-------------------------------------------------------------------------------- /_support/pxmysql-compose/generate_tls.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env sh 2 | # 3 | # Copyright (c) 2022, Geert JM Vanderkelen 4 | # 5 | 6 | set -e 7 | 8 | # 9 | # This script generates the MySQL Server Certificate Authority (CA). 10 | # It also generates server and client certificates. 11 | # 12 | # The server certificate has X.509 subjectAltName extension with value 13 | # of 'IP:127.0.0.1'. 14 | # 15 | # 16 | 17 | OUT_DIR=conf.d 18 | CA_KEY="${OUT_DIR}/ca-key.pem" 19 | CA_CERT="${OUT_DIR}/ca.pem" 20 | SERVER_KEY="${OUT_DIR}/server-key.pem" 21 | SERVER_REQ="${OUT_DIR}/server-req.pem" 22 | SERVER_CERT="${OUT_DIR}/server-cert.pem" 23 | CLIENT_KEY="${OUT_DIR}/client-key.pem" 24 | CLIENT_REQ="${OUT_DIR}/client-req.pem" 25 | CLIENT_CERT="${OUT_DIR}/client-cert.pem" 26 | DAYS=3600 27 | 28 | OPENSSL=$(command -v openssl) 29 | if [ "${OPENSSL}" = "" ]; then 30 | echo "Error: openssl command not available in path" 31 | exit 1 32 | fi 33 | 34 | v=$($OPENSSL version) 35 | case "${v}" in 36 | "OpenSSL 1.1"*) ;; 37 | "LibreSSL 3.3.6"*) ;; 38 | *) 39 | echo "Error: expecting OpenSSL v1.1 or greater, or LibreSSL v3.3 or greater (got ${v})" 40 | exit 1 41 | esac 42 | 43 | # Create CA certificate 44 | ${OPENSSL} genrsa -out ${CA_KEY} 2048 45 | ${OPENSSL} req -new \ 46 | -subj "/C=DE/CN=pxmysql-test-CA" \ 47 | -x509 -days ${DAYS} -nodes -key ${CA_KEY} -out ${CA_CERT} 48 | 49 | # Create server certificate (removing passphrase) 50 | ${OPENSSL} genrsa -out ${SERVER_KEY} 2048 51 | ${OPENSSL} req -new -nodes \ 52 | -subj "/C=DE/CN=pxmysql-test-server" \ 53 | -key ${SERVER_KEY} -out ${SERVER_REQ} 54 | ${OPENSSL} x509 -req -days ${DAYS} -in ${SERVER_REQ} \ 55 | -CA ${CA_CERT} -CAkey ${CA_KEY} -CAcreateserial \ 56 | -extfile server_x509_ext.conf \ 57 | -out ${SERVER_CERT} 58 | 59 | # Create client certificate (removing passphrase) 60 | 
${OPENSSL} genrsa -out ${CLIENT_KEY} 2048 61 | ${OPENSSL} req -new -nodes \ 62 | -subj "/C=DE/CN=pxmysql-test-client" \ 63 | -key ${CLIENT_KEY} -out ${CLIENT_REQ} 64 | ${OPENSSL} x509 -req -days ${DAYS} -in ${CLIENT_REQ} \ 65 | -CA ${CA_CERT} -CAkey ${CA_KEY} -CAcreateserial \ 66 | -out ${CLIENT_CERT} 67 | 68 | rm ${OUT_DIR}/*-req.pem 69 | -------------------------------------------------------------------------------- /_support/pxmysql-compose/server_x509_ext.conf: -------------------------------------------------------------------------------- 1 | subjectAltName=IP:127.0.0.1 -------------------------------------------------------------------------------- /_support/pxmysql-compose/shared/build.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | out="/shared/builds/$1" 4 | main="/shared/goapps/$1" 5 | 6 | export GOPRIVATE="github.com/golistic/pxmysql" 7 | 8 | cd "${main}" || exit 1 9 | go mod tidy 10 | go build -o "${out}" . 11 | -------------------------------------------------------------------------------- /_support/pxmysql-compose/shared/goapps/unix_socket/go.mod: -------------------------------------------------------------------------------- 1 | module example.com/unix_socket 2 | 3 | go 1.21 4 | 5 | replace github.com/golistic/pxmysql => /go/src/github.com/golistic/pxmysql 6 | 7 | require github.com/golistic/pxmysql v1.0.0 8 | 9 | require ( 10 | github.com/golistic/xgo v1.0.0 // indirect 11 | golang.org/x/exp v0.0.0-20230817173708-d852ddb80c63 // indirect 12 | google.golang.org/protobuf v1.31.0 // indirect 13 | ) 14 | -------------------------------------------------------------------------------- /_support/pxmysql-compose/shared/goapps/unix_socket/go.sum: -------------------------------------------------------------------------------- 1 | github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= 2 | github.com/golistic/xgo v1.0.0 
h1:NtwCDhGBJR0vOvn0l8S8z9wzmN3hYRM8Wwwd1+2BFO0= 3 | github.com/golistic/xgo v1.0.0/go.mod h1:em3spZJ1b8mrGv1P5Fx2Ewclaf0rqIfSwrzYIVnL6o4= 4 | github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= 5 | github.com/google/go-cmp v0.5.8 h1:e6P7q2lk1O+qJJb4BtCQXlK8vWEO8V1ZeuEdJNOqZyg= 6 | github.com/google/go-cmp v0.5.8/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= 7 | golang.org/x/exp v0.0.0-20230817173708-d852ddb80c63 h1:m64FZMko/V45gv0bNmrNYoDEq8U5YUhetc9cBWKS1TQ= 8 | golang.org/x/exp v0.0.0-20230817173708-d852ddb80c63/go.mod h1:0v4NqG35kSWCMzLaMeX+IQrlSnVE/bqGSyC2cz/9Le8= 9 | golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= 10 | google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= 11 | google.golang.org/protobuf v1.31.0 h1:g0LDEJHgrBl9N9r17Ru3sqWhkIx2NB67okBHPwC7hs8= 12 | google.golang.org/protobuf v1.31.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= 13 | -------------------------------------------------------------------------------- /_support/pxmysql-compose/shared/goapps/unix_socket/main.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2022, Geert JM Vanderkelen 2 | 3 | /* 4 | This application is executed within the MySQL container by the pxmysql 5 | Go tests. 6 | 7 | The MySQL sock-file within the Docker container cannot be accessed through 8 | a Docker volume. Copying in an application which uses pxmysql is therefore 9 | the only way to automate testing of UNIX socket file support. 
10 | */ 11 | 12 | package main 13 | 14 | import ( 15 | "context" 16 | "database/sql" 17 | "fmt" 18 | "log" 19 | 20 | _ "github.com/golistic/pxmysql/register" 21 | ) 22 | 23 | func main() { 24 | // credentials are the once used when running pxmysql tests 25 | db, err := sql.Open("pxmysql", "root:rootpwd@unix(/var/lib/mysql/mysqlx.sock)") 26 | if err != nil { 27 | log.Fatalln("open:", err) 28 | } 29 | 30 | if err := db.Ping(); err != nil { 31 | log.Fatalln("ping:", err) 32 | } 33 | 34 | var version string 35 | if err := db.QueryRowContext(context.Background(), "SELECT VERSION()").Scan(&version); err != nil { 36 | log.Fatalln("query row:", err) 37 | } 38 | 39 | fmt.Println(version) 40 | } 41 | -------------------------------------------------------------------------------- /cmd/.gitignore: -------------------------------------------------------------------------------- 1 | scratch/ -------------------------------------------------------------------------------- /cmd/gencollations/main.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2022, 2023, Geert JM Vanderkelen 2 | 3 | package main 4 | 5 | import ( 6 | "context" 7 | "flag" 8 | "fmt" 9 | "go/format" 10 | "os" 11 | "time" 12 | 13 | "github.com/golistic/xgo/xstrings" 14 | "golang.org/x/term" 15 | 16 | "github.com/golistic/pxmysql/xmysql" 17 | ) 18 | 19 | type appFlags struct { 20 | addr string 21 | username string 22 | } 23 | 24 | func main() { 25 | flags := initFlags() 26 | 27 | fmt.Printf("Password for %s (enter for empty): ", flags.username) 28 | p, err := term.ReadPassword(0) 29 | if err != nil { 30 | fmt.Printf("Error: failed reading password (%s)\n", err) 31 | os.Exit(1) 32 | } 33 | fmt.Println() 34 | 35 | config := &xmysql.ConnectConfig{ 36 | Address: flags.addr, 37 | Username: flags.username, 38 | Password: xstrings.Pointer(string(p)), 39 | UseTLS: true, 40 | } 41 | 42 | ses, err := xmysql.GetSession(context.Background(), config) 43 | if err != nil { 44 
| fmt.Printf("Error: %s\n", err) 45 | os.Exit(1) 46 | } 47 | 48 | q := "SELECT VERSION()" 49 | res, err := ses.ExecuteStatement(context.Background(), q) 50 | if err != nil { 51 | fmt.Printf("Error: %s\n", err) 52 | os.Exit(1) 53 | } 54 | version := res.Rows[0].Values[0].(string) 55 | 56 | charset := "utf8mb4" 57 | q = "SELECT id, collation_name AS name " + 58 | "FROM information_schema.collations WHERE character_set_name = ? ORDER BY id" 59 | res, err = ses.ExecuteStatement(context.Background(), q, charset) 60 | if err != nil { 61 | fmt.Printf("Error: %s\n", err) 62 | os.Exit(1) 63 | } 64 | 65 | content := fmt.Sprintf("// Code generated by gencollations. DO NOT EDIT.\n"+ 66 | "// MySQL v%s; generated at %s\n\n"+ 67 | "package xmysql\n\nvar Collations = map[string]Collation{\n", version, time.Now().UTC()) 68 | 69 | entry := "\"%s\": {ID: %d, Name: \"%s\", CharSet: \"%s\",},\n" 70 | 71 | entryID := "%d: \"%s\",\n" 72 | 73 | var ids string 74 | for _, r := range res.Rows { 75 | v := r.Values 76 | content += fmt.Sprintf(entry, v[1].(string), v[0].(uint64), v[1].(string), charset) 77 | ids += fmt.Sprintf(entryID, v[0].(uint64), v[1].(string)) 78 | } 79 | 80 | content += "}\n" 81 | 82 | content += "\nvar collationIDs = map[uint64]string{\n" + ids + "\n}" 83 | 84 | data, err := format.Source([]byte(content)) 85 | if err != nil { 86 | fmt.Printf("Error: failed formatting Go code (%s)\n", err) 87 | os.Exit(1) 88 | } 89 | 90 | fn := "xmysql/collations_data.go" 91 | fp, err := os.OpenFile(fn, os.O_WRONLY|os.O_TRUNC, 0666) 92 | if err != nil { 93 | fmt.Printf("Error: failed opening %s for writing (%s)\n", fn, err) 94 | os.Exit(1) 95 | } 96 | defer func() { _ = fp.Close() }() 97 | 98 | if _, err := fp.Write(data); err != nil { 99 | if err != nil { 100 | fmt.Printf("Error: failed writing to %s (%s)\n", fn, err) 101 | os.Exit(1) 102 | } 103 | } 104 | } 105 | 106 | func initFlags() *appFlags { 107 | f := &appFlags{} 108 | flag.StringVar(&f.addr, "address", "localhost:33060", 109 | 
"address (host with port) of the MySQL Server's X Plugin") 110 | flag.StringVar(&f.username, "user", "root", 111 | "user for the connection to the MySQL Server") 112 | 113 | flag.Parse() 114 | return f 115 | } 116 | -------------------------------------------------------------------------------- /cmd/genprotobuf/main.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2022, Geert JM Vanderkelen 2 | 3 | package main 4 | 5 | import ( 6 | "bytes" 7 | "fmt" 8 | "io" 9 | "net/http" 10 | "os" 11 | "os/exec" 12 | "path" 13 | "strings" 14 | "time" 15 | 16 | "github.com/golistic/xgo/xos" 17 | ) 18 | 19 | const ( 20 | mysqlVersion = "8.0.34" 21 | baseURL = "https://raw.githubusercontent.com/mysql/mysql-server/mysql-" + mysqlVersion + "/plugin/x/protocol/protobuf/" 22 | ) 23 | 24 | const pkgxmysql = "github.com/golistic/pxmysql" 25 | 26 | const nrOfFilesAtLeast = 5 27 | 28 | var protoPath = path.Join("internal", "mysqlx") 29 | 30 | func main() { 31 | if err := generate(); err != nil { 32 | exitWithErr(err) 33 | } 34 | } 35 | 36 | func exitWithErr(err error) { 37 | fmt.Println("Error:", err) 38 | os.Exit(1) 39 | } 40 | 41 | func checkExecLocation() (string, error) { 42 | d, err := os.Getwd() 43 | if err != nil { 44 | return "", fmt.Errorf("failed getting working directory (%w)", err) 45 | } 46 | 47 | needles := []string{".git", protoPath} 48 | 49 | for _, n := range needles { 50 | if !xos.IsDir(path.Join(d, n)) { 51 | return "", fmt.Errorf("must execute within root of xmysql repository") 52 | } 53 | } 54 | 55 | return d, nil 56 | } 57 | 58 | func fetchFile(name string) ([]byte, error) { 59 | u := baseURL + name 60 | resp, err := http.Get(u) 61 | if err != nil { 62 | return nil, fmt.Errorf("failed opening URL downloading %s (%w)", name, err) 63 | } 64 | if resp.StatusCode != http.StatusOK { 65 | return nil, fmt.Errorf("failed opening URL downloading %s (HTTP status %d)", name, resp.StatusCode) 66 | } 67 | defer func() { _ = 
resp.Body.Close() }() 68 | 69 | body, err := io.ReadAll(resp.Body) 70 | if err != nil { 71 | return nil, fmt.Errorf("failed reading body downloading file %s (%w)", name, err) 72 | } 73 | 74 | return body, nil 75 | } 76 | 77 | func protoFiles() ([]string, error) { 78 | fileData, err := fetchFile("source_files.cmake") 79 | if err != nil { 80 | return nil, err 81 | } 82 | 83 | files := make([]string, 0, nrOfFilesAtLeast) 84 | for _, l := range bytes.Split(fileData, []byte("\n")) { 85 | l = bytes.TrimSpace(l) 86 | if len(l) == 0 || l[0] == '#' || 87 | !(bytes.HasPrefix(l, []byte("mysqlx")) && bytes.HasSuffix(l, []byte(".proto"))) { 88 | continue 89 | } 90 | files = append(files, string(l)) 91 | } 92 | 93 | return files, nil 94 | } 95 | 96 | func downloadFile(dir, filename string) error { 97 | fileData, err := fetchFile(filename) 98 | if err != nil { 99 | return err 100 | } 101 | fp, err := os.OpenFile(path.Join(dir, filename), os.O_CREATE|os.O_WRONLY, 0666) 102 | if err != nil { 103 | return fmt.Errorf("failed opening file %s (%w)", filename, err) 104 | } 105 | if _, err := fp.Write(fileData); err != nil { 106 | _ = fp.Close() 107 | return fmt.Errorf("failed writing to file %s (%w)", filename, err) 108 | } 109 | _ = fp.Close() 110 | 111 | return nil 112 | } 113 | 114 | func generate() error { 115 | wd, err := checkExecLocation() 116 | if err != nil { 117 | return err 118 | } 119 | 120 | protoc, err := exec.LookPath("protoc") 121 | if err != nil { 122 | return fmt.Errorf("protoc executable not available") 123 | } 124 | 125 | files, err := protoFiles() 126 | if err != nil { 127 | return err 128 | } 129 | 130 | args := []string{protoc, "--proto_path=" + protoPath, 131 | "--go_out=.", 132 | "--go_opt=paths=import", 133 | "--go_opt=module=github.com/golistic/pxmysql", 134 | } 135 | 136 | for _, f := range files { 137 | if err := downloadFile(protoPath, f); err != nil { 138 | return err 139 | } 140 | 141 | m := strings.Replace(f, ".proto", "", 1) 142 | m = strings.Replace(m, 
"_", "", -1) 143 | args = append(args, fmt.Sprintf("--go_opt=M%s=%s/%s/%s", f, pkgxmysql, protoPath, m)) 144 | } 145 | 146 | args = append(args, files...) 147 | 148 | output := bytes.NewBuffer(nil) 149 | 150 | cmd := exec.Cmd{ 151 | Dir: wd, 152 | Path: protoc, 153 | Args: args, 154 | Stdout: output, 155 | Stderr: output, 156 | } 157 | 158 | err = cmd.Run() 159 | switch err.(type) { 160 | case *exec.ExitError: 161 | return fmt.Errorf("execution of protoc failed: %s", output.String()) 162 | case error: 163 | return fmt.Errorf("could not run protoc") 164 | } 165 | 166 | for _, f := range files { 167 | _ = os.Remove(path.Join(protoPath, f)) 168 | } 169 | 170 | infoFile := path.Join(protoPath, "info.md") 171 | fp, err := os.OpenFile(infoFile, os.O_CREATE|os.O_WRONLY, 0666) 172 | if err != nil { 173 | return fmt.Errorf("failed opening file %s (%w)", infoFile, err) 174 | } 175 | defer func() { _ = fp.Close() }() 176 | 177 | _, err = fp.WriteString(fmt.Sprintf("Generated from MySQL Server %s at %s.\n", mysqlVersion, 178 | time.Now().UTC().Format(time.RFC3339))) 179 | if err != nil { 180 | return fmt.Errorf("failed writing to file %s (%w)", infoFile, err) 181 | } 182 | 183 | return nil 184 | } 185 | -------------------------------------------------------------------------------- /cmd/make/main.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2023, Geert JM Vanderkelen 2 | 3 | package main 4 | 5 | import ( 6 | "github.com/golistic/gomake" 7 | ) 8 | 9 | func main() { 10 | gomake.RegisterTargets( 11 | &gomake.TargetBadges, 12 | &gomake.TargetGoLint, 13 | ) 14 | gomake.Make() 15 | } 16 | -------------------------------------------------------------------------------- /connection.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2022, Geert JM Vanderkelen 2 | 3 | package pxmysql 4 | 5 | import ( 6 | "context" 7 | "database/sql/driver" 8 | "fmt" 9 | 10 | 
"github.com/golistic/pxmysql/xmysql" 11 | ) 12 | 13 | type connection struct { 14 | cfg *xmysql.ConnectConfig 15 | session *xmysql.Session 16 | } 17 | 18 | var ( 19 | _ driver.Conn = (*connection)(nil) 20 | _ driver.ConnBeginTx = (*connection)(nil) 21 | _ driver.Pinger = (*connection)(nil) 22 | _ driver.ExecerContext = (*connection)(nil) 23 | _ driver.QueryerContext = (*connection)(nil) 24 | ) 25 | 26 | func (c *connection) Prepare(query string) (driver.Stmt, error) { 27 | 28 | prep, err := c.session.PrepareStatement(context.Background(), query) 29 | if err != nil { 30 | return nil, err 31 | } 32 | 33 | s := &statement{ 34 | prepared: prep, 35 | } 36 | 37 | return s, nil 38 | } 39 | 40 | func (c *connection) Close() error { 41 | if c.session != nil { 42 | return c.session.Close() 43 | } 44 | return nil 45 | } 46 | 47 | func (c *connection) Begin() (driver.Tx, error) { 48 | return c.BeginTx(context.Background(), driver.TxOptions{}) 49 | } 50 | 51 | func (c *connection) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) { 52 | q := "START TRANSACTION" 53 | if opts.ReadOnly { 54 | q += q + " READ ONLY" 55 | } 56 | 57 | if _, err := c.session.ExecuteStatement(ctx, q); err != nil { 58 | return nil, err 59 | } 60 | 61 | return &Transaction{session: c.session}, nil 62 | } 63 | 64 | func (c *connection) Ping(ctx context.Context) error { 65 | if c.session == nil { 66 | return fmt.Errorf("not connected (%w)", driver.ErrBadConn) 67 | } 68 | 69 | if _, err := c.session.SessionID(ctx); err != nil { 70 | return fmt.Errorf("ping failed (%w)", err) 71 | } 72 | 73 | return nil 74 | } 75 | 76 | func (c *connection) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { 77 | prep, err := c.session.PrepareStatement(context.Background(), query) 78 | if err != nil { 79 | return nil, handleError(err) 80 | } 81 | 82 | defer func() { _ = prep.Deallocate(ctx) }() 83 | 84 | stmt := &statement{ 85 | prepared: prep, 86 | } 87 | 88 
| return stmt.ExecContext(ctx, args) 89 | } 90 | 91 | func (c *connection) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { 92 | prep, err := c.session.PrepareStatement(context.Background(), query) 93 | if err != nil { 94 | return nil, handleError(err) 95 | } 96 | 97 | defer func() { _ = prep.Deallocate(ctx) }() 98 | 99 | stmt := &statement{ 100 | prepared: prep, 101 | } 102 | 103 | return stmt.QueryContext(ctx, args) 104 | } 105 | -------------------------------------------------------------------------------- /connection_test.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2022, Geert JM Vanderkelen 2 | 3 | package pxmysql_test 4 | 5 | import ( 6 | "context" 7 | "database/sql" 8 | "errors" 9 | "fmt" 10 | "testing" 11 | "time" 12 | 13 | "github.com/golistic/xgo/xt" 14 | 15 | "github.com/golistic/pxmysql/mysqlerrors" 16 | ) 17 | 18 | func TestConnection_Begin(t *testing.T) { 19 | dsn := getTCPDSN() 20 | db, err := sql.Open("pxmysql", dsn) 21 | xt.OK(t, err) 22 | defer func() { _ = db.Close() }() 23 | 24 | tbl := "t29dkckiidk" 25 | 26 | _, err = db.Exec(fmt.Sprintf("DROP TABLE IF EXISTS `%s`", tbl)) 27 | xt.OK(t, err) 28 | _, err = db.Exec(fmt.Sprintf("CREATE TABLE `%s` (id INT PRIMARY KEY, c1 INT)", tbl)) 29 | xt.OK(t, err) 30 | 31 | t.Run("start transaction and commit", func(t *testing.T) { 32 | tx, err := db.Begin() 33 | xt.OK(t, err) 34 | 35 | id := 1 36 | exp := 123 37 | 38 | stmtInsert := fmt.Sprintf("INSERT INTO `%s` (id, c1) VALUES (?, ?)", tbl) 39 | result, err := tx.Exec(stmtInsert, id, exp) 40 | xt.OK(t, err) 41 | affected, err := result.RowsAffected() 42 | xt.OK(t, err) 43 | xt.Eq(t, 1, affected) 44 | 45 | xt.OK(t, tx.Commit()) 46 | 47 | q := fmt.Sprintf("SELECT c1 FROM `%s` WHERE id = ?", tbl) 48 | var have int 49 | xt.OK(t, db.QueryRowContext(context.Background(), q, id).Scan(&have)) 50 | xt.Eq(t, exp, have) 51 | }) 52 | 53 | t.Run("start 
transaction and rollback", func(t *testing.T) { 54 | tx, err := db.Begin() 55 | xt.OK(t, err) 56 | 57 | id := 2 58 | value := 987 59 | 60 | stmtInsert := fmt.Sprintf("INSERT INTO `%s` (id, c1) VALUES (?, ?)", tbl) 61 | result, err := tx.Exec(stmtInsert, id, value) 62 | xt.OK(t, err) 63 | affected, err := result.RowsAffected() 64 | xt.OK(t, err) 65 | xt.Eq(t, 1, affected) 66 | 67 | xt.OK(t, tx.Rollback()) 68 | 69 | q := fmt.Sprintf("SELECT c1 FROM `%s` WHERE id = ?", tbl) 70 | rows, err := db.QueryContext(context.Background(), q, id) 71 | xt.OK(t, err) 72 | xt.Assert(t, !rows.Next()) 73 | }) 74 | } 75 | 76 | func TestConnection_ExecContext(t *testing.T) { 77 | t.Run("respect timeout", func(t *testing.T) { 78 | dsn := getTCPDSN() 79 | db, err := sql.Open("pxmysql", dsn) 80 | xt.OK(t, err) 81 | defer func() { _ = db.Close() }() 82 | 83 | ctx := context.Background() 84 | ctx, cancel := context.WithTimeout(ctx, 100*time.Millisecond) 85 | defer cancel() 86 | 87 | _, err = db.ExecContext(ctx, "SELECT SLEEP(5)") 88 | xt.KO(t, err) 89 | xt.Assert(t, errors.Is(err, mysqlerrors.ErrContextDeadlineExceeded), err.Error()) 90 | }) 91 | 92 | t.Run("prepared statement should close using Query", func(t *testing.T) { 93 | dsn := getTCPDSN() 94 | db, err := sql.Open("pxmysql", dsn) 95 | xt.OK(t, err) 96 | defer func() { _ = db.Close() }() 97 | 98 | needle := "SDFIciwkdixks" 99 | stmt := "/* " + needle + " */ SELECT COUNT(*) FROM performance_schema.prepared_statements_instances WHERE SQL_TEXT LIKE ?" 
100 | needleParam := "%" + needle + "%" 101 | 102 | for i := 0; i < 2; i++ { 103 | var got int 104 | xt.OK(t, db.QueryRowContext(context.Background(), stmt, needleParam).Scan(&got)) 105 | xt.Eq(t, 1, got) // 1 because the query is seeing itself 106 | } 107 | }) 108 | 109 | t.Run("prepared statement should close using Exec", func(t *testing.T) { 110 | dsn := getTCPDSN() 111 | db, err := sql.Open("pxmysql", dsn) 112 | xt.OK(t, err) 113 | defer func() { _ = db.Close() }() 114 | 115 | needle := "owicIOwidols" 116 | stmt := "/* " + needle + " */ DELETE FROM mysql.user WHERE user = 'nobody'" 117 | needleParam := "%" + needle + "%" 118 | 119 | selectStmt := "SELECT COUNT(*) FROM performance_schema.prepared_statements_instances WHERE SQL_TEXT LIKE ?" 120 | 121 | for i := 0; i < 2; i++ { 122 | var got int 123 | _, err := db.ExecContext(context.Background(), stmt) 124 | xt.OK(t, err) 125 | xt.OK(t, db.QueryRowContext(context.Background(), selectStmt, needleParam).Scan(&got)) 126 | xt.Eq(t, 0, got) 127 | } 128 | }) 129 | } 130 | 131 | func TestConnection_QueryContext(t *testing.T) { 132 | dsn := getTCPDSN() 133 | db, err := sql.Open("pxmysql", dsn) 134 | xt.OK(t, err) 135 | defer func() { _ = db.Close() }() 136 | 137 | t.Run("respect timeout", func(t *testing.T) { 138 | ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) 139 | defer cancel() 140 | 141 | _, err := db.QueryContext(ctx, "SELECT SLEEP(5)") 142 | xt.KO(t, err) 143 | xt.Assert(t, errors.Is(err, mysqlerrors.ErrContextDeadlineExceeded), err.Error()) 144 | }) 145 | } 146 | 147 | func TestConnector_Connect(t *testing.T) { 148 | t.Run("server closing stale connection and reconnect", func(t *testing.T) { 149 | dsn := getTCPDSN() 150 | db, err := sql.Open("pxmysql", dsn) 151 | xt.OK(t, err) 152 | 153 | _, err = db.Exec("SET @@SESSION.mysqlx_wait_timeout = 2") 154 | xt.OK(t, err) 155 | 156 | var cnxID int 157 | xt.OK(t, db.QueryRow("SELECT CONNECTION_ID()").Scan(&cnxID)) 158 | 159 | var n string 
160 | var v string 161 | xt.OK(t, db.QueryRow("SHOW SESSION VARIABLES LIKE 'mysqlx_wait_timeout'").Scan(&n, &v)) 162 | 163 | time.Sleep(3 * time.Second) // server should close connection 164 | 165 | var cnxIDAfter int 166 | xt.OK(t, db.QueryRow("SELECT CONNECTION_ID()").Scan(&cnxIDAfter)) 167 | 168 | xt.Assert(t, cnxID != cnxIDAfter) 169 | }) 170 | } 171 | -------------------------------------------------------------------------------- /connector.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2022, Geert JM Vanderkelen 2 | 3 | package pxmysql 4 | 5 | import ( 6 | "context" 7 | "database/sql/driver" 8 | 9 | "github.com/golistic/pxmysql/xmysql" 10 | ) 11 | 12 | type connector struct { 13 | dataSource DataSource 14 | } 15 | 16 | var _ driver.Connector = &connector{} 17 | 18 | func (c connector) Connect(ctx context.Context) (driver.Conn, error) { 19 | 20 | // dataSource at this point is valid 21 | config := &xmysql.ConnectConfig{ 22 | UseTLS: c.dataSource.UseTLS, 23 | AuthMethod: xmysql.AuthMethodAuto, 24 | Username: c.dataSource.User, 25 | Schema: c.dataSource.Schema, 26 | } 27 | config.SetPassword(c.dataSource.Password) 28 | 29 | switch c.dataSource.Protocol { 30 | case "unix": 31 | config.UnixSockAddr = c.dataSource.Address 32 | case "tcp": 33 | config.Address = c.dataSource.Address 34 | } 35 | 36 | ses, err := xmysql.GetSession(ctx, config) 37 | if err != nil { 38 | return nil, err 39 | } 40 | 41 | return &connection{ 42 | cfg: config, 43 | session: ses, 44 | }, nil 45 | } 46 | 47 | func (c connector) Driver() driver.Driver { 48 | return &Driver{} 49 | } 50 | -------------------------------------------------------------------------------- /datasource.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2022, Geert JM Vanderkelen 2 | 3 | package pxmysql 4 | 5 | import ( 6 | "fmt" 7 | 8 | "github.com/golistic/xgo/xconv" 9 | "github.com/golistic/xgo/xsql" 10 
| ) 11 | 12 | // DataSource defines the configuration of the connection. It embeds xsql.DataSource 13 | // and extends it with attributes defined in the Options. 14 | type DataSource struct { 15 | xsql.DataSource 16 | 17 | UseTLS bool 18 | } 19 | 20 | // NewDataSource instantiates a DataSource using the Data Source Name (DSN). 21 | func NewDataSource(name string) (DataSource, error) { 22 | xds, err := xsql.ParseDSN(name) 23 | if err != nil { 24 | return DataSource{}, err 25 | } 26 | 27 | ds := DataSource{ 28 | DataSource: *xds, 29 | UseTLS: false, 30 | } 31 | 32 | if err := ds.handleOptions(); err != nil { 33 | return DataSource{}, err 34 | } 35 | 36 | if err := ds.CheckValidity(); err != nil { 37 | return DataSource{}, fmt.Errorf("configuration not valid (%w)", err) 38 | } 39 | 40 | return ds, nil 41 | } 42 | 43 | // CheckValidity returns whether the DataSource has enough configuration to establish 44 | // a connection. Needed are the address, protocol, and username. 45 | func (ds *DataSource) CheckValidity() error { 46 | switch { 47 | case ds.Address == "": 48 | return fmt.Errorf("address missing") 49 | case ds.User == "": 50 | return fmt.Errorf("user missing") 51 | case ds.Protocol == "": 52 | return fmt.Errorf("protocol missing") 53 | default: 54 | return nil 55 | } 56 | } 57 | 58 | func (ds *DataSource) handleOptions() error { 59 | var err error 60 | useTLS := ds.Options.Get("useTLS") 61 | if useTLS != "" { 62 | ds.UseTLS, err = xconv.ParseBool(useTLS) 63 | if err != nil { 64 | return fmt.Errorf("invalid value for useTLS option (was %s)", useTLS) 65 | } 66 | } 67 | 68 | return nil 69 | } 70 | -------------------------------------------------------------------------------- /datasource_test.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2023, Geert JM Vanderkelen 2 | 3 | package pxmysql 4 | 5 | import ( 6 | "errors" 7 | "testing" 8 | 9 | "github.com/golistic/xgo/xsql" 10 | "github.com/golistic/xgo/xt" 11 
| ) 12 | 13 | func TestDataSource_IsZero(t *testing.T) { 14 | t.Run("zero when username missing", func(t *testing.T) { 15 | _, err := NewDataSource(":pwd@tcp(127.0.0.1)") 16 | xt.KO(t, err) 17 | xt.Eq(t, "user missing", errors.Unwrap(err).Error()) 18 | }) 19 | 20 | t.Run("zero when address missing", func(t *testing.T) { 21 | ds := DataSource{ 22 | DataSource: xsql.DataSource{ 23 | Driver: "pxmysql", 24 | User: "user", 25 | Password: "", 26 | Protocol: "tcp", 27 | Address: "", 28 | Schema: "", 29 | Options: nil, 30 | }, 31 | } 32 | err := ds.CheckValidity() 33 | xt.KO(t, err) 34 | xt.Eq(t, "address missing", err.Error()) 35 | }) 36 | 37 | t.Run("protocol missing", func(t *testing.T) { 38 | ds := DataSource{ 39 | DataSource: xsql.DataSource{ 40 | Driver: "pxmysql", 41 | User: "user", 42 | Password: "", 43 | Protocol: "", 44 | Address: "127.0.0.1", 45 | Schema: "", 46 | Options: nil, 47 | }, 48 | } 49 | err := ds.CheckValidity() 50 | xt.KO(t, err) 51 | xt.Eq(t, "protocol missing", err.Error()) 52 | }) 53 | } 54 | 55 | func TestNewDataSource(t *testing.T) { 56 | t.Run("invalid useTLS option value", func(t *testing.T) { 57 | _, err := NewDataSource("user:pwd@tcp(127.0.0.1)/?useTLS=nope") 58 | xt.KO(t, err) 59 | xt.Eq(t, "invalid value for useTLS option (was nope)", err.Error()) 60 | }) 61 | } 62 | -------------------------------------------------------------------------------- /driver.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2022, 2023, Geert JM Vanderkelen 2 | 3 | package pxmysql 4 | 5 | import ( 6 | "context" 7 | "database/sql/driver" 8 | ) 9 | 10 | type Driver struct{} 11 | 12 | var ( 13 | _ driver.Driver = &Driver{} 14 | _ driver.DriverContext = &Driver{} 15 | ) 16 | 17 | // Open returns a new connection to the MySQL database using MySQL X Protocol. 
18 | func (d *Driver) Open(name string) (driver.Conn, error) { 19 | c, err := d.OpenConnector(name) 20 | if err != nil { 21 | return nil, err 22 | } 23 | 24 | return c.Connect(context.Background()) 25 | } 26 | 27 | // OpenConnector returns a connector which will be used by sql.DB to open a connection 28 | // to the MySQL database using MySQL X Protocol. 29 | // This will be used instead of the Open-method (which actually uses this method). 30 | func (d *Driver) OpenConnector(name string) (driver.Connector, error) { 31 | ds, err := NewDataSource(name) 32 | if err != nil { 33 | return nil, err 34 | } 35 | 36 | return &connector{ 37 | dataSource: ds, 38 | }, nil 39 | } 40 | -------------------------------------------------------------------------------- /driver_test.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2022, Geert JM Vanderkelen 2 | 3 | package pxmysql_test 4 | 5 | import ( 6 | "database/sql" 7 | "errors" 8 | "fmt" 9 | "os" 10 | "testing" 11 | 12 | "github.com/golistic/xgo/xstrings" 13 | "github.com/golistic/xgo/xt" 14 | 15 | "github.com/golistic/pxmysql" 16 | "github.com/golistic/pxmysql/internal/xxt" 17 | "github.com/golistic/pxmysql/mysqlerrors" 18 | "github.com/golistic/pxmysql/register" 19 | ) 20 | 21 | func TestSQLDriver_Open(t *testing.T) { 22 | pwd := "aPassword" 23 | users := []string{"userfkEivks", "userFcae283"} 24 | 25 | for _, u := range users { 26 | _ = testContext.Server.DropUser(u) 27 | xt.OK(t, testContext.Server.CreateUser(u, pwd, testSchema, xxt.AuthPluginNative)) 28 | } 29 | 30 | defer func() { 31 | for _, u := range users { 32 | _ = testContext.Server.DropUser(u) 33 | } 34 | }() 35 | 36 | t.Run("valid data source names", func(t *testing.T) { 37 | var cases = map[string]string{ 38 | "no query": fmt.Sprintf("%s:%s@tcp(%s)/%s", 39 | users[0], pwd, testContext.XPluginAddr, testSchema), 40 | "no schema": fmt.Sprintf("%s:%s@tcp(%s)/?useTLS=true", 41 | users[1], pwd, 
testContext.XPluginAddr), 42 | } 43 | 44 | for cn, dsn := range cases { 45 | t.Run(cn, func(t *testing.T) { 46 | drv := &pxmysql.Driver{} 47 | _, err := drv.Open(dsn) 48 | xt.OK(t, err) 49 | }) 50 | } 51 | }) 52 | 53 | t.Run("retrieve LastInsertID after insert", func(t *testing.T) { 54 | tbl := "test_AFiek23eeF" 55 | q := fmt.Sprintf("DROP TABLE IF EXISTS `%s`", tbl) 56 | _, err := testContext.Server.ExecSQLStmt(q) 57 | xt.OK(t, err) 58 | 59 | _, err = testContext.Server.ExecSQLStmt("CREATE TABLE " + tbl + 60 | " (id INT AUTO_INCREMENT PRIMARY KEY, t1 INT)") 61 | xt.OK(t, err) 62 | 63 | dsn := getTCPDSN("", "") 64 | db, err := sql.Open("pxmysql", dsn) 65 | 66 | t.Run("using Prepared Statement", func(t *testing.T) { 67 | xt.OK(t, err) 68 | q := fmt.Sprintf("INSERT INTO `%s` (`t1`) VALUES (?)", tbl) 69 | stmt, err := db.Prepare(q) 70 | xt.OK(t, err) 71 | res, err := stmt.Exec("45") 72 | xt.OK(t, err) 73 | lastID, err := res.LastInsertId() 74 | xt.OK(t, err) 75 | xt.Eq(t, 1, lastID) 76 | }) 77 | 78 | t.Run("executing INSERT directly", func(t *testing.T) { 79 | xt.OK(t, err) 80 | q := fmt.Sprintf("INSERT INTO `%s` (`t1`) VALUES (?)", tbl) 81 | res, err := db.Exec(q, 46) 82 | xt.OK(t, err) 83 | lastID, err := res.LastInsertId() 84 | xt.OK(t, err) 85 | xt.Eq(t, 2, lastID) 86 | }) 87 | }) 88 | 89 | t.Run("using Unix socket", func(t *testing.T) { 90 | // runs app within Container; will not add to coverage 91 | app := "unix_socket" 92 | _, err := testContext.Builder.App(app) 93 | xt.OK(t, err) 94 | 95 | out, err := testContext.Server.ExecApp("/shared/builds/" + app) 96 | xt.OK(t, err) 97 | xt.Eq(t, testContext.Server.Version, string(out)) 98 | }) 99 | 100 | t.Run("unsupported protocol", func(t *testing.T) { 101 | drv := &pxmysql.Driver{} 102 | _, err := drv.Open("scott:tiger@UDP(localhost)/") 103 | xt.KO(t, err) 104 | xt.Eq(t, "unsupported protocol 'UDP'", errors.Unwrap(err).Error()) 105 | }) 106 | 107 | t.Run("not enough configured with missing username", func(t 
*testing.T) { 108 | drv := &pxmysql.Driver{} 109 | _, err := drv.Open(":tiger@tcp(localhost)/") 110 | xt.KO(t, err) 111 | xt.Eq(t, "configuration not valid (user missing)", err.Error()) 112 | xt.Eq(t, "user missing", errors.Unwrap(err).Error()) 113 | }) 114 | } 115 | 116 | func TestConnection_Ping(t *testing.T) { 117 | t.Run("using TCP", func(t *testing.T) { 118 | db, err := sql.Open(register.DriverName, getTCPDSN()) 119 | defer func() { _ = db.Close() }() 120 | xt.OK(t, err) 121 | xt.OK(t, db.Ping()) 122 | xt.Eq(t, "tcp", cnxType(t, db)) 123 | }) 124 | 125 | t.Run("using non-existing Unix socket", func(t *testing.T) { 126 | drv := &pxmysql.Driver{} 127 | os.TempDir() 128 | _, err := drv.Open("username:pwd@unix(_testdata/mysqlx.sock)/myschema") 129 | xt.KO(t, err) 130 | xt.Eq(t, mysqlerrors.ClientBadUnixSocket, err.(*mysqlerrors.Error).Code) 131 | xt.KO(t, errors.Unwrap(err)) 132 | xt.Eq(t, "no such file or directory", errors.Unwrap(err).Error()) 133 | }) 134 | } 135 | 136 | func TestDriver_Open(t *testing.T) { 137 | t.Run("pxmysql is registered", func(t *testing.T) { 138 | xt.Assert(t, xstrings.SliceHas(sql.Drivers(), "pxmysql"), "expected driver pxmysql to be registered") 139 | }) 140 | 141 | t.Run("mysql is not registered", func(t *testing.T) { 142 | xt.Assert(t, !xstrings.SliceHas(sql.Drivers(), "mysql"), "expected driver mysql to be NOT registered") 143 | }) 144 | } 145 | -------------------------------------------------------------------------------- /errors.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2022, Geert JM Vanderkelen 2 | 3 | package pxmysql 4 | 5 | import ( 6 | "errors" 7 | "os" 8 | 9 | "github.com/golistic/pxmysql/mysqlerrors" 10 | ) 11 | 12 | func handleError(err error) error { 13 | if errors.Is(err, os.ErrDeadlineExceeded) { 14 | return mysqlerrors.ErrContextDeadlineExceeded 15 | } 16 | 17 | return err 18 | } 19 | 
-------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/golistic/pxmysql 2 | 3 | go 1.21 4 | 5 | toolchain go1.21.0 6 | 7 | require ( 8 | github.com/golistic/gomake v0.9.4 9 | github.com/golistic/xgo v1.0.0 10 | golang.org/x/exp v0.0.0-20230817173708-d852ddb80c63 11 | golang.org/x/term v0.11.0 12 | google.golang.org/protobuf v1.31.0 13 | ) 14 | 15 | require ( 16 | github.com/golistic/shieldbadger v0.0.0-20230223210348-5649a4ba6aa9 // indirect 17 | golang.org/x/mod v0.12.0 // indirect 18 | golang.org/x/sys v0.11.0 // indirect 19 | ) 20 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= 2 | github.com/golistic/gomake v0.9.4 h1:6cOBeAsbSnacsskMvyg49DjCRCE9WnyhJv4Viy9rF9M= 3 | github.com/golistic/gomake v0.9.4/go.mod h1:IiYQuN6aK4Jb3QLA/rfOpeOqG7s8DNfBzVXZt8Pf6PA= 4 | github.com/golistic/shieldbadger v0.0.0-20230223210348-5649a4ba6aa9 h1:NBSzSvgVJjhI37ytLLcHLkRA7UTR/P1L/0dm6EY4GoE= 5 | github.com/golistic/shieldbadger v0.0.0-20230223210348-5649a4ba6aa9/go.mod h1:9Ad43QCHXG87MxOV2rB668vLcEV/lC4qJZhs+VdEXJQ= 6 | github.com/golistic/xgo v1.0.0 h1:NtwCDhGBJR0vOvn0l8S8z9wzmN3hYRM8Wwwd1+2BFO0= 7 | github.com/golistic/xgo v1.0.0/go.mod h1:em3spZJ1b8mrGv1P5Fx2Ewclaf0rqIfSwrzYIVnL6o4= 8 | github.com/golistic/xt v1.0.1 h1:prcwpL757GEu+dj6x2v6vxHkulkFdyQ3S1PqO+6ohPM= 9 | github.com/golistic/xt v1.0.1/go.mod h1:j1ZuWefyOD4HegoapgSjbXanA7X9YOKUeEB5gsEPgYg= 10 | github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= 11 | github.com/google/go-cmp v0.5.8 h1:e6P7q2lk1O+qJJb4BtCQXlK8vWEO8V1ZeuEdJNOqZyg= 12 | github.com/google/go-cmp v0.5.8/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= 13 | 
golang.org/x/exp v0.0.0-20230817173708-d852ddb80c63 h1:m64FZMko/V45gv0bNmrNYoDEq8U5YUhetc9cBWKS1TQ= 14 | golang.org/x/exp v0.0.0-20230817173708-d852ddb80c63/go.mod h1:0v4NqG35kSWCMzLaMeX+IQrlSnVE/bqGSyC2cz/9Le8= 15 | golang.org/x/mod v0.12.0 h1:rmsUpXtvNzj340zd98LZ4KntptpfRHwpFOHG188oHXc= 16 | golang.org/x/mod v0.12.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= 17 | golang.org/x/sys v0.11.0 h1:eG7RXZHdqOJ1i+0lgLgCpSXAp6M3LYlAo6osgSi0xOM= 18 | golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= 19 | golang.org/x/term v0.11.0 h1:F9tnn/DA/Im8nCwm+fX+1/eBwi4qFjRT++MhtVC4ZX0= 20 | golang.org/x/term v0.11.0/go.mod h1:zC9APTIj3jG3FdV/Ons+XE1riIZXG4aZ4GTHiPZJPIU= 21 | golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= 22 | google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= 23 | google.golang.org/protobuf v1.31.0 h1:g0LDEJHgrBl9N9r17Ru3sqWhkIx2NB67okBHPwC7hs8= 24 | google.golang.org/protobuf v1.31.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= 25 | -------------------------------------------------------------------------------- /interfaces/message.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2022, Geert JM Vanderkelen 2 | 3 | package interfaces 4 | 5 | import ( 6 | "google.golang.org/protobuf/proto" 7 | 8 | "github.com/golistic/pxmysql/internal/mysqlx/mysqlx" 9 | ) 10 | 11 | type ServerMessager interface { 12 | Unmarshall(message proto.Message) error 13 | ServerMessageType() mysqlx.ServerMessages_Type 14 | } 15 | -------------------------------------------------------------------------------- /internal/mysqlx/info.md: -------------------------------------------------------------------------------- 1 | Generated from MySQL Server 8.0.34 at 2023-08-09T15:38:05Z. 
2 | -------------------------------------------------------------------------------- /internal/xxt/builder.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2022, Geert JM Vanderkelen 2 | 3 | package xxt 4 | 5 | import ( 6 | "fmt" 7 | "strings" 8 | ) 9 | 10 | type GoBuilder struct { 11 | tctx *TestContext 12 | Container *Container 13 | } 14 | 15 | func NewGoBuilder(tctx *TestContext, container *Container) (*GoBuilder, error) { 16 | gb := &GoBuilder{ 17 | Container: container, 18 | tctx: tctx, 19 | // Schema is stored at the end 20 | } 21 | 22 | _, err := gb.goVersion() 23 | if err != nil { 24 | return nil, fmt.Errorf("failed getting Go version (%w)", err) 25 | } 26 | 27 | return gb, nil 28 | } 29 | 30 | // App takes the application name which is located in the container's 31 | // shared volume located at "/shared". 32 | func (gb GoBuilder) App(name string) ([]byte, error) { 33 | args := []string{ 34 | "exec", "-i", gb.Container.Name, 35 | "sh", "/shared/build.sh", name, 36 | } 37 | 38 | return gb.Container.run(args...) 39 | } 40 | 41 | func (gb GoBuilder) goVersion() (string, error) { 42 | args := []string{ 43 | "exec", "-i", gb.Container.Name, 44 | "go", "version", 45 | } 46 | 47 | buf, err := gb.Container.run(args...) 
48 | if err != nil { 49 | return "", err 50 | } 51 | 52 | version := string(buf) 53 | version = strings.Replace(version, "go version go", "", 1) 54 | 55 | return version, nil 56 | } 57 | -------------------------------------------------------------------------------- /internal/xxt/context.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2022, Geert JM Vanderkelen 2 | 3 | package xxt 4 | 5 | import ( 6 | "fmt" 7 | "os" 8 | "regexp" 9 | ) 10 | 11 | var reVersion = regexp.MustCompile(`(\d+)\.(\d+).(\d+)`) 12 | 13 | const minMySQLVersion = 8000028 14 | const minMySQLVersionStr = "8.0.28" 15 | 16 | // container names are defined in _support/pxmysql-compose/docker-compose.yml 17 | const ( 18 | defaultDockerContainer = "pxmysql.test.db" 19 | defaultDockerContainerGo = "pxmysql.test.go" 20 | ) 21 | 22 | const ( 23 | defaultDockerExec = "docker" 24 | defaultDockerXPluginAddr = "127.0.0.1:53360" 25 | defaultDockerMySQLAddr = "127.0.0.1:53306" 26 | defaultDockerMySQLRootPwd = "rootpwd" 27 | ) 28 | 29 | const ( 30 | AuthPluginNative = "mysql_native_password" 31 | AuthPluginCachedSha2 = "caching_sha2_password" 32 | ) 33 | 34 | type TestContext struct { 35 | MySQLRootPwd string 36 | XPluginAddr string 37 | MySQLAddr string 38 | Server *MySQLServer 39 | Builder *GoBuilder 40 | } 41 | 42 | func New(schema string) (*TestContext, error) { 43 | dbContainerName := defaultDockerContainer 44 | if v, have := os.LookupEnv("PXMYSQL_TEST_DOCKER_CONTAINER"); have { 45 | dbContainerName = v 46 | } 47 | 48 | goContainerName := defaultDockerContainerGo 49 | if v, have := os.LookupEnv("PXMYSQL_TEST_DOCKER_CONTAINER_GO"); have { 50 | goContainerName = v 51 | } 52 | 53 | var err error 54 | tctx := &TestContext{ 55 | MySQLRootPwd: defaultDockerMySQLRootPwd, 56 | XPluginAddr: defaultDockerXPluginAddr, 57 | MySQLAddr: defaultDockerMySQLAddr, 58 | } 59 | 60 | if v, have := os.LookupEnv("PXMYSQL_TEST_DOCKER_XPLUGIN_ADDR"); have { 61 | 
tctx.XPluginAddr = v 62 | } 63 | 64 | if v, have := os.LookupEnv("PXMYSQL_TEST_DOCKER_MYSQL_ADDR"); have { 65 | tctx.MySQLAddr = v 66 | } 67 | 68 | if v, have := os.LookupEnv("PXMYSQL_TEST_DOCKER_MYSQL_PWD"); have { 69 | tctx.MySQLRootPwd = v 70 | } 71 | 72 | dockerExec := defaultDockerExec 73 | if v, have := os.LookupEnv("PXMYSQL_TEST_DOCKER_EXEC"); have { 74 | dockerExec = v 75 | } 76 | 77 | dbContainer, err := NewContainer(dbContainerName, dockerExec) 78 | if err != nil { 79 | return nil, err 80 | } 81 | 82 | if err := dbContainer.CheckRunning(); err != nil { 83 | return nil, fmt.Errorf("make sure the Docker is available (set XMYSQL_TEST_DOCKER_EXEC?)"+ 84 | " and container %s is running (%s)", dbContainerName, err) 85 | } 86 | 87 | goContainer, err := NewContainer(goContainerName, dockerExec) 88 | if err != nil { 89 | return nil, err 90 | } 91 | 92 | if err := goContainer.CheckRunning(); err != nil { 93 | return nil, fmt.Errorf("make sure the Docker is available (set XMYSQL_TEST_DOCKER_EXEC?)"+ 94 | " and container %s is running (%s)", goContainerName, err) 95 | } 96 | 97 | tctx.Server, err = NewMySQLServer(tctx, dbContainer, schema) 98 | if err != nil { 99 | return nil, err 100 | } 101 | 102 | tctx.Builder, err = NewGoBuilder(tctx, goContainer) 103 | if err != nil { 104 | return nil, err 105 | } 106 | 107 | return tctx, nil 108 | } 109 | -------------------------------------------------------------------------------- /internal/xxt/credentials.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2023, Geert JM Vanderkelen 2 | 3 | package xxt 4 | 5 | const ( 6 | UserNative = "user_native" 7 | UserNativePwd = "pwd_user_native" 8 | ) 9 | 10 | const ( 11 | UserCachedSHA256 = "user_sha256" 12 | UserCachedSHA256Pwd = "pwd_user_sha256" 13 | ) 14 | -------------------------------------------------------------------------------- /internal/xxt/docker.go: 
// Container identifies a Docker container by name, together with the
// executable used to talk to Docker.
type Container struct {
	dockerExec string
	Name       string
}

// NewContainer returns a Container for the container called name which is
// managed using the executable dockerExec. The returned error is always nil.
func NewContainer(name string, dockerExec string) (*Container, error) {
	return &Container{
		dockerExec: dockerExec,
		Name:       name,
	}, nil
}

// CopyFileFromContainer copies a file from a Docker container.
func (c Container) CopyFileFromContainer(srcPath, dstPath string) error {
	_, err := c.run("cp", c.Name+":"+srcPath, dstPath)
	return err
}

// getDockerCmd prepares (but does not start) the Docker command with args.
// Both STDOUT and STDERR of the command are written to output; when output is
// nil they are discarded.
//
// When the configured executable is not an absolute path, it is searched for
// in the directories named by the PATH environment variable. An empty
// executable means "docker".
func (c Container) getDockerCmd(output io.Writer, args ...string) (*exec.Cmd, error) {
	dockerExec := c.dockerExec
	if dockerExec == "" {
		dockerExec = "docker"
	}

	if dockerExec[0] != '/' {
		// Look up what was configured, not the literal "docker"; otherwise
		// any non-absolute configured executable would be ignored.
		resolved, err := exec.LookPath(dockerExec)
		if err != nil {
			return nil, err
		}
		dockerExec = resolved
	}

	if output == nil {
		output = io.Discard
	}

	return &exec.Cmd{
		Path:   dockerExec,
		Args:   append([]string{dockerExec}, args...),
		Stdout: output,
		Stderr: output,
	}, nil
}

// run executes the docker command using provided arguments.
//
// On success the combined output is returned with lines containing
// "[Warning]" removed and the remaining lines concatenated without newlines.
// When the command exits unsuccessfully, the error is derived from its output.
func (c Container) run(args ...string) ([]byte, error) {
	output := bytes.NewBuffer(nil)

	cmd, err := c.getDockerCmd(output, args...)
	if err != nil {
		return nil, err
	}

	if err := cmd.Run(); err != nil {
		if _, exited := err.(*exec.ExitError); !exited {
			return nil, err
		}
		if execErr := getContainerExecError(output); execErr != nil {
			return nil, execErr
		}
		return nil, err
	}

	var res []byte
	for _, l := range strings.Split(output.String(), "\n") {
		if strings.Contains(l, "[Warning]") {
			continue
		}
		res = append(res, l...)
	}

	return res, nil
}

// CheckRunning checks whether the container is running. An error is returned
// when the container cannot be inspected, or when it is not running.
func (c Container) CheckRunning() error {
	out, err := c.run("inspect", "-f", "'{{.State.Running}}'", c.Name)
	if err != nil {
		return err
	}

	// No shell is involved, so the single quotes of the template are part of
	// the output. Inspecting a stopped container succeeds and reports false,
	// which means the output itself must be checked.
	if state := strings.Trim(strings.TrimSpace(string(out)), "'"); state != "true" {
		return fmt.Errorf("container %s is not running (state.running=%s)", c.Name, state)
	}

	return nil
}

// getContainerExecError reads the output of a failed command from r and
// turns it into an error. The result is never nil.
//
// Lines containing "[Warning]" are skipped. The first line which looks like
// an error is used as message; if there is none, the whole output is used
// with newlines replaced by "; ".
func getContainerExecError(r io.Reader) error {
	buf, err := io.ReadAll(r)
	if err != nil {
		return err
	}

	// The output is data, never a format string: messages regularly contain
	// '%', which would otherwise be rendered as "%!x(MISSING)".
	for _, l := range strings.Split(string(buf), "\n") {
		switch {
		case strings.Contains(l, "[Warning]"):
			continue
		case strings.HasPrefix(l, "Error "),
			strings.HasPrefix(l, "error:"),
			strings.Contains(l, "ERROR "),
			strings.Contains(l, "Error: "),
			strings.Contains(l, "[ERROR]"):
			return fmt.Errorf("%s", l)
		case strings.HasPrefix(l, "OCI runtime exec failed"):
			return fmt.Errorf("%s", strings.Replace(l, "OCI runtime exec failed: ", "", -1))
		}
	}

	if len(bytes.TrimSpace(buf)) == 0 {
		return fmt.Errorf("command failed without output")
	}

	msg := bytes.Replace(buf, []byte("\n"), []byte("; "), -1)

	return fmt.Errorf("%s", msg)
}
// NewTestErr returns an error whose message is format interpolated with a.
// When err is not nil, its message is appended between parentheses.
//
// The text of err is never interpreted as part of the format, so a '%' in it
// is kept as is.
func NewTestErr(err error, format string, a ...any) error {
	base := fmt.Errorf(format, a...)
	if err == nil {
		return base
	}

	// base is wrapped so that a %w used in format stays reachable.
	return fmt.Errorf("%w (%s)", base, err)
}

// MemoryUse records memory statistics at two points in time so they can be
// compared.
type MemoryUse struct {
	start *runtime.MemStats
	end   *runtime.MemStats
}

// NewMemoryUse runs the garbage collector and returns a MemoryUse holding the
// memory statistics read right after it.
func NewMemoryUse() *MemoryUse {
	m := &MemoryUse{
		start: &runtime.MemStats{},
		end:   &runtime.MemStats{},
	}

	runtime.GC()
	runtime.ReadMemStats(m.start)
	return m
}

// Stop reads the memory statistics marking the end of the measurement.
func (m *MemoryUse) Stop() {
	runtime.ReadMemStats(m.end)
}

// DiffAlloc returns by how many bytes Alloc grew between NewMemoryUse and
// Stop. It returns 0 when Alloc did not grow (which includes Stop not having
// been called), instead of wrapping around to a huge unsigned value.
func (m *MemoryUse) DiffAlloc() uint64 {
	if m.end.Alloc < m.start.Alloc {
		return 0
	}
	return m.end.Alloc - m.start.Alloc
}

// String returns the difference in allocated bytes, as reported by DiffAlloc,
// and the number of completed GC cycles at the time of Stop.
func (m MemoryUse) String() string {
	return fmt.Sprintf("DiffTotalAlloc = % 10d\tNumGC = % 5d\n",
		m.DiffAlloc(), m.end.NumGC)
}
NewMySQLServer(tctx *TestContext, container *Container, schema string) (*MySQLServer, error) { 26 | server := &MySQLServer{ 27 | Container: container, 28 | tctx: tctx, 29 | // Schema is stored at the end 30 | } 31 | 32 | if _, err := server.ExecSQLStmt("DROP SCHEMA IF EXISTS " + schema); err != nil { 33 | return nil, err 34 | } 35 | if _, err := server.ExecSQLStmt("CREATE SCHEMA " + schema); err != nil { 36 | return nil, err 37 | } 38 | 39 | errMsg := "failed getting MySQL version running in container %s (%s)" 40 | if output, err := server.ExecSQLStmt("SELECT VERSION()"); err != nil { 41 | return nil, NewTestErr(err, errMsg, container.Name, err) 42 | } else { 43 | parts := reVersion.FindAllStringSubmatch(string(output), -1) 44 | if parts == nil { 45 | return nil, NewTestErr(err, errMsg, container.Name, "reVersion") 46 | } 47 | 48 | // simplistic way of checking the MySQL version 49 | vMaj, err := strconv.ParseInt(parts[0][1], 10, 64) 50 | if err != nil { 51 | return nil, NewTestErr(nil, errMsg, container.Name) 52 | } 53 | 54 | vMin, err := strconv.ParseInt(parts[0][2], 10, 64) 55 | if err != nil { 56 | return nil, NewTestErr(nil, errMsg, container.Name) 57 | 58 | } 59 | 60 | patch, err := strconv.ParseInt(parts[0][3], 10, 64) 61 | if err != nil { 62 | return nil, NewTestErr(nil, errMsg, container.Name) 63 | } 64 | 65 | v := vMaj*1000000 + vMin*1000 + patch 66 | if v < minMySQLVersion { 67 | return nil, NewTestErr(fmt.Errorf("MySQL version must be %s or greater", minMySQLVersionStr), 68 | errMsg, container.Name) 69 | } 70 | 71 | server.Version = fmt.Sprintf("%d.%d.%d", vMaj, vMin, patch) 72 | } 73 | 74 | server.Schema = schema 75 | 76 | return server, nil 77 | } 78 | 79 | // ExecSQLStmt executes the SQL stmt using the mysql CLI within the container. 80 | // This is not SQL-injection safe and is only used for testing. 
81 | func (my MySQLServer) ExecSQLStmt(stmt string) ([]byte, error) { 82 | args := []string{ 83 | "exec", "-i", my.Container.Name, 84 | "mysql", "-uroot", "-p" + my.tctx.MySQLRootPwd, "-NB", "-e", stmt, 85 | } 86 | 87 | if my.Schema != "" { 88 | args = append(args, []string{"-D", my.Schema}...) 89 | } 90 | 91 | return my.Container.run(args...) 92 | } 93 | 94 | // LoadSQLScript executes the statements from files provided as scripts 95 | // using the mysql CLI within the container. 96 | func (my MySQLServer) LoadSQLScript(scripts ...string) error { 97 | args := []string{ 98 | "exec", "-i", my.Container.Name, 99 | "mysql", "-uroot", "-p" + my.tctx.MySQLRootPwd, 100 | } 101 | 102 | stderr := bytes.NewBuffer(nil) 103 | 104 | cmd, err := my.Container.getDockerCmd(stderr, args...) 105 | if err != nil { 106 | return err 107 | } 108 | 109 | stdin, err := cmd.StdinPipe() 110 | if err != nil { 111 | return err 112 | } 113 | 114 | wg := sync.WaitGroup{} 115 | 116 | var capturedErr error 117 | go func() { 118 | defer func() { 119 | _ = stdin.Close() 120 | wg.Done() 121 | }() 122 | 123 | for _, s := range scripts { 124 | if !strings.HasSuffix(s, ".sql") { 125 | s += ".sql" 126 | } 127 | p := path.Join("_testdata", s) 128 | sql, err := os.ReadFile(p) 129 | if err != nil { 130 | capturedErr = fmt.Errorf("failed reading SQL script %s (%s)", p, err) 131 | break 132 | } 133 | 134 | if _, err := io.WriteString(stdin, string(sql)); err != nil { 135 | capturedErr = fmt.Errorf("failed writing SQL script to STDIN %s (%s)", p, err) 136 | break 137 | } 138 | } 139 | }() 140 | 141 | wg.Add(1) 142 | if err := cmd.Run(); err != nil { 143 | if err := getContainerExecError(stderr); err != nil { 144 | capturedErr = err 145 | } 146 | } 147 | 148 | wg.Wait() 149 | return capturedErr 150 | } 151 | 152 | func (my MySQLServer) FlushPrivileges() error { 153 | args := []string{ 154 | "exec", "-i", my.Container.Name, 155 | "mysqladmin", "-uroot", "-p" + my.tctx.MySQLRootPwd, "flush-privileges", 156 | } 
157 | 158 | _, err := my.Container.run(args...) 159 | return err 160 | } 161 | 162 | func (my MySQLServer) Variable(scope, variable string) (string, error) { 163 | if !(scope == "global" || scope == "session") { 164 | panic("scope must be one of 'session' or 'global'") 165 | } 166 | 167 | output, err := my.ExecSQLStmt(fmt.Sprintf("SELECT @@%s.%s", scope, variable)) 168 | if err != nil { 169 | return "", err 170 | } 171 | 172 | return string(output), nil 173 | } 174 | 175 | func (my MySQLServer) CreateUser(username, password, schema, authPlugin string) error { 176 | // this function is not SQL-injection-safe; only used for testing 177 | 178 | authPlugins := []string{AuthPluginNative, AuthPluginCachedSha2} 179 | 180 | if !xstrings.SliceHas(authPlugins, authPlugin) { 181 | panic("unsupported authMethod") 182 | } 183 | 184 | createUser := fmt.Sprintf("CREATE USER '%s'@'%%' IDENTIFIED WITH %s BY '%s'", 185 | username, authPlugin, password) 186 | grant := fmt.Sprintf("GRANT ALL ON %s.* TO '%s'@'%%'", schema, username) 187 | 188 | if _, err := my.ExecSQLStmt(createUser); err != nil { 189 | return err 190 | } 191 | if _, err := my.ExecSQLStmt(grant); err != nil { 192 | return err 193 | } 194 | 195 | return nil 196 | } 197 | 198 | func (my MySQLServer) DropUser(username string) error { 199 | // this function is not SQL-injection-safe; only used for testing 200 | 201 | dropUser := fmt.Sprintf("DROP USER IF EXISTS '%s'@'%%'", username) 202 | 203 | if _, err := my.ExecSQLStmt(dropUser); err != nil { 204 | return err 205 | } 206 | return nil 207 | } 208 | 209 | // ExecApp runs the application within the container found at path and returns its output. 210 | func (my MySQLServer) ExecApp(path string) ([]byte, error) { 211 | args := []string{ 212 | "exec", "-i", my.Container.Name, 213 | path, 214 | } 215 | 216 | return my.Container.run(args...) 
217 | } 218 | -------------------------------------------------------------------------------- /main_test.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2022, Geert JM Vanderkelen 2 | 3 | package pxmysql_test 4 | 5 | import ( 6 | "database/sql" 7 | "fmt" 8 | "os" 9 | "testing" 10 | 11 | "github.com/golistic/xgo/xt" 12 | 13 | _ "github.com/golistic/pxmysql/register" // registers pxmysql 14 | 15 | "github.com/golistic/pxmysql/internal/xxt" 16 | ) 17 | 18 | var ( 19 | testExitCode int 20 | testErr error 21 | testDockerContainer string 22 | testSchema = "pxmysqldriver_tests" 23 | testContext *xxt.TestContext 24 | ) 25 | 26 | func testTearDown() { 27 | if testErr != nil { 28 | testExitCode = 1 29 | fmt.Println(testErr) 30 | } 31 | } 32 | 33 | func TestMain(m *testing.M) { 34 | defer func() { os.Exit(testExitCode) }() 35 | defer testTearDown() 36 | 37 | var err error 38 | if testContext, testErr = xxt.New(testSchema); err != nil { 39 | return 40 | } 41 | 42 | if err := testContext.Server.Container.CopyFileFromContainer( 43 | "/etc/mysql/conf.d/ca.pem", "_testdata/mysql_ca.pem"); err != nil { 44 | testErr = fmt.Errorf("failed copying MySQL CA certificate from container %s (%s)", 45 | testDockerContainer, err) 46 | return 47 | } 48 | 49 | testExitCode = m.Run() 50 | } 51 | 52 | func getCredentials(credentials ...string) (string, string) { 53 | username := "root" 54 | password := testContext.MySQLRootPwd 55 | if len(credentials) > 0 && credentials[0] != "" { 56 | username = credentials[0] 57 | } 58 | if len(credentials) > 1 && credentials[1] != "" { 59 | password = credentials[1] 60 | } 61 | 62 | return username, password 63 | } 64 | 65 | func getTCPDSN(credentials ...string) string { 66 | username, password := getCredentials(credentials...) 
67 | return fmt.Sprintf("%s:%s@tcp(%s)/%s?useTLS=yes", username, password, testContext.XPluginAddr, 68 | testSchema) 69 | } 70 | 71 | func cnxType(t *testing.T, db *sql.DB) string { 72 | t.Helper() 73 | 74 | q := "SELECT IF(HOST='localhost', 'unix', 'tcp') As CnxType " + 75 | "FROM performance_schema.processlist WHERE ID = CONNECTION_ID()" 76 | 77 | var ct string 78 | err := db.QueryRow(q).Scan(&ct) 79 | xt.OK(t, err) 80 | 81 | return ct 82 | } 83 | -------------------------------------------------------------------------------- /mysqlerrors/error.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2022, Geert JM Vanderkelen 2 | 3 | package mysqlerrors 4 | 5 | import ( 6 | "errors" 7 | "fmt" 8 | "strings" 9 | "unicode" 10 | 11 | "github.com/golistic/pxmysql/interfaces" 12 | "github.com/golistic/pxmysql/internal/mysqlx/mysqlx" 13 | "github.com/golistic/pxmysql/internal/mysqlx/mysqlxnotice" 14 | ) 15 | 16 | var ( 17 | ErrContextDeadlineExceeded = errors.New("context deadline exceeded") 18 | ) 19 | 20 | // Error holds information of MySQL returned error and is implementing 21 | // the Go error interface. 22 | type Error struct { 23 | Inner error 24 | Message string 25 | Code int 26 | SQLState string 27 | Severity int 28 | Parameters []any 29 | ExtraMessage string 30 | } 31 | 32 | // New instantiates an Error object with code being the MySQL 33 | // client or server error code. When the message contains placeholders (is 34 | // a format specifier), params are used to interpolate them. 35 | // When one of the params is an error, it is used as a wrapped error. 36 | // Panics when params contains more than one value which is error-type. 
37 | func New(code int, params ...any) *Error { 38 | e, have := mysqlClientErrors[code] 39 | if !have { 40 | panic(fmt.Sprintf("error code %d not registered in mysqlClientErrors", code)) 41 | } 42 | 43 | for _, param := range params { 44 | if p, ok := param.(error); ok { 45 | if e.Inner != nil { 46 | panic("only one parameter can be of error type") 47 | } 48 | e.Inner = p 49 | } 50 | } 51 | 52 | e.Parameters = params 53 | return &e 54 | } 55 | 56 | func (e *Error) Unwrap() error { 57 | return e.Inner 58 | } 59 | 60 | // Error is the string representation of the error. Messages look the same, but 61 | // they are differently formatted than the MySQL Client message to conform Go 62 | // best practices. 63 | func (e *Error) Error() string { 64 | msg := e.Message 65 | 66 | if len(e.Parameters) > 0 { 67 | if e.Inner != nil { 68 | msg = fmt.Errorf(e.Message, e.Parameters...).Error() 69 | } else { 70 | msg = fmt.Sprintf(e.Message, e.Parameters...) 71 | } 72 | } 73 | 74 | if e.ExtraMessage != "" { 75 | msg += " (" + e.ExtraMessage + ")" 76 | } 77 | 78 | if len(msg) > 2 { 79 | r := []rune(msg) 80 | if !strings.HasPrefix(msg, "MySQL") { 81 | r[0] = unicode.ToLower(r[0]) 82 | } 83 | msg = string(r) 84 | } 85 | 86 | return fmt.Sprintf("%s [%d:%s]", msg, e.Code, e.SQLState) 87 | } 88 | 89 | // NewFromServerMessage takes msg and transforms it into an Error. 90 | func NewFromServerMessage(msg interfaces.ServerMessager) error { 91 | myErr := &mysqlx.Error{} 92 | 93 | if err := msg.Unmarshall(myErr); err != nil { 94 | return err 95 | } 96 | 97 | return &Error{ 98 | Message: myErr.GetMsg(), 99 | Code: int(myErr.GetCode()), 100 | SQLState: myErr.GetSqlState(), 101 | Severity: int(myErr.GetSeverity()), 102 | } 103 | } 104 | 105 | // MySQLWarning holds information of MySQL returned warning and is implementing 106 | // the Go error interface. 
107 | type MySQLWarning struct { 108 | Message string 109 | Code int 110 | Level string 111 | } 112 | 113 | // Error is the string representation of the warning, mimicking how MySQL would 114 | // show them. 115 | func (w *MySQLWarning) Error() string { 116 | return fmt.Sprintf("%s %d: %s", w.Level, w.Code, w.Message) 117 | } 118 | 119 | // NewFromWarning takes msg and transforms it into a MySQLWarning. 120 | func NewFromWarning(msg *mysqlxnotice.Warning) error { 121 | return &MySQLWarning{ 122 | Level: msg.Level.String(), 123 | Message: msg.GetMsg(), 124 | Code: int(msg.GetCode()), 125 | } 126 | } 127 | -------------------------------------------------------------------------------- /mysqlerrors/error_client.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2022, Geert JM Vanderkelen 2 | 3 | package mysqlerrors 4 | 5 | // MySQL Client errors as found in the MySQL manual under 6 | // https://dev.mysql.com/doc/mysql-errors/8.0/en/client-error-reference.html. 7 | // Names have been altered so that they make more sense. For example, 8 | // CR_CONNECTION_ERROR became ClientBadUnixSocket. 
9 | const ( 10 | ClientUnknown = 2000 11 | ClientBadUnixSocket = 2002 12 | ClientBadTCPSocket = 2005 13 | ClientWrongProtocol = 2007 14 | ClientNetPacketTooLarge = 2020 15 | ) 16 | 17 | var mysqlClientErrors = map[int]Error{ 18 | ClientUnknown: { // 2000 19 | Message: "unknown error", 20 | Code: ClientUnknown, 21 | SQLState: "HY000", 22 | }, 23 | ClientBadUnixSocket: { // 2002 24 | Message: "cannot connect to local MySQL server through socket '%s' (%w)", 25 | Code: ClientBadUnixSocket, 26 | SQLState: "HY000", 27 | }, 28 | ClientBadTCPSocket: { // 2005 29 | Message: "unknown MySQL server host '%s' (%w)", 30 | Code: ClientBadTCPSocket, 31 | SQLState: "HY000", 32 | }, 33 | ClientWrongProtocol: { // 2007 34 | Message: "wrong protocol", 35 | Code: ClientBadTCPSocket, 36 | SQLState: "HY000", 37 | }, 38 | ClientNetPacketTooLarge: { // 2020 39 | Message: "got packet bigger than 'mysqlx_max_allowed_packet' bytes", 40 | Code: ClientNetPacketTooLarge, 41 | SQLState: "HY000", 42 | }, 43 | } 44 | -------------------------------------------------------------------------------- /mysqlerrors/error_test.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2022, Geert JM Vanderkelen 2 | 3 | package mysqlerrors 4 | 5 | import ( 6 | "testing" 7 | 8 | "github.com/golistic/xgo/xt" 9 | ) 10 | 11 | func TestMySLQError_Error(t *testing.T) { 12 | t.Run("formatted MySQL errors", func(t *testing.T) { 13 | myErr := &Error{ 14 | Message: "Table 'test.no_such_table' doesn't exist", // shamelessly copy/pasted from MySQL docs 15 | Code: 1146, 16 | SQLState: "42S02", 17 | Severity: 1, 18 | } 19 | 20 | xt.Eq(t, "table 'test.no_such_table' doesn't exist [1146:42S02]", myErr.Error()) 21 | }) 22 | } 23 | 24 | func TestMySQLWarning_Error(t *testing.T) { 25 | t.Run("formatted as MySQL would do", func(t *testing.T) { 26 | myErr := &MySQLWarning{ 27 | Message: "Data truncated for column 'b' at row 1", // shamelessly copy/pasted from MySQL docs 28 | Code: 
1265, 29 | Level: "Warning", 30 | } 31 | 32 | xt.Eq(t, "Warning 1265: Data truncated for column 'b' at row 1", myErr.Error()) 33 | }) 34 | } 35 | -------------------------------------------------------------------------------- /mysqlerrors/test_errors/mysqlerrors_test.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2023, Geert JM Vanderkelen 2 | 3 | package test_errors 4 | 5 | import ( 6 | "context" 7 | "errors" 8 | "testing" 9 | "time" 10 | 11 | "github.com/golistic/xgo/xt" 12 | 13 | "github.com/golistic/pxmysql/xmysql" 14 | ) 15 | 16 | func TestMySQLErrors(t *testing.T) { 17 | t.Run("wrapped errors", func(t *testing.T) { 18 | config := &xmysql.ConnectConfig{ 19 | Address: "127.0.0.40", 20 | } 21 | 22 | ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) 23 | defer cancel() 24 | 25 | _, err := xmysql.GetSession(ctx, config) 26 | xt.Eq(t, "unknown MySQL server host '127.0.0.40:33060' (i/o timeout) [2005:HY000]", err.Error()) 27 | xt.Eq(t, "i/o timeout", errors.Unwrap(err).Error()) 28 | }) 29 | } 30 | -------------------------------------------------------------------------------- /null/bytes.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2022, Geert JM Vanderkelen 2 | 3 | package null 4 | 5 | import ( 6 | "bytes" 7 | "database/sql/driver" 8 | "fmt" 9 | ) 10 | 11 | // Bytes represents a []byte (any MySQL BINARY types) that may be NULL. 12 | // This is not available in Go's sql package, and does not implement the Scanner interface. 13 | type Bytes struct { 14 | Bytes []byte 15 | Valid bool 16 | } 17 | 18 | var _ driver.Valuer = &Bytes{} 19 | var _ Nullable = &Bytes{} 20 | 21 | // Compare returns whether value compares with the nullable Bytes. 
22 | // It returns: 23 | // - true when Valid and stored Bytes is equal to value 24 | // - true when not Valid and value is nil 25 | // - false in any other case 26 | func (n Bytes) Compare(value any) bool { 27 | if !n.Valid && value != nil { 28 | return false 29 | } 30 | 31 | if value == nil { 32 | return !n.Valid 33 | } 34 | 35 | switch v := value.(type) { 36 | case []byte: 37 | return bytes.Equal(n.Bytes, v) 38 | case *[]byte: 39 | return bytes.Equal(n.Bytes, *v) 40 | default: 41 | panic(fmt.Sprintf("value must be []byte or *[]byte; not %T", value)) 42 | } 43 | } 44 | 45 | // Value returns the value of n and implements the driver.Valuer 46 | // as well as Nullable interface. 47 | func (n Bytes) Value() (driver.Value, error) { 48 | if !n.Valid { 49 | return nil, nil 50 | } 51 | return n.Bytes, nil 52 | } 53 | -------------------------------------------------------------------------------- /null/bytes_test.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2022, Geert JM Vanderkelen 2 | 3 | package null 4 | 5 | import ( 6 | "fmt" 7 | "testing" 8 | 9 | "github.com/golistic/xgo/xt" 10 | ) 11 | 12 | func TestBytes_Compare(t *testing.T) { 13 | var cases = []struct { 14 | n Nullable 15 | value any 16 | exp bool 17 | }{ 18 | { 19 | n: Bytes{Bytes: []byte("Sakila"), Valid: true}, 20 | value: bytesPtr([]byte("Sakila")), 21 | exp: true, 22 | }, 23 | { 24 | n: Bytes{Bytes: []byte("Go gopher"), Valid: true}, 25 | value: bytesPtr([]byte("Sakila")), // supposed to be not 'Go gopher' 26 | exp: false, 27 | }, 28 | { 29 | n: Bytes{Bytes: []byte("Sakila"), Valid: true}, 30 | value: []byte("Sakila"), 31 | exp: true, 32 | }, 33 | { 34 | n: Bytes{Bytes: []byte("Go gopher"), Valid: true}, 35 | value: []byte("Sakila"), // supposed to be not 'Go gopher' 36 | exp: false, 37 | }, 38 | { 39 | n: Bytes{Bytes: nil, Valid: false}, 40 | value: nil, 41 | exp: true, 42 | }, 43 | { 44 | n: Bytes{Bytes: nil, Valid: false}, 45 | value: 
[]byte{}, 46 | exp: false, 47 | }, 48 | } 49 | 50 | for _, c := range cases { 51 | t.Run("", func(t *testing.T) { 52 | xt.Eq(t, c.exp, c.n.Compare(c.value)) 53 | }) 54 | } 55 | 56 | t.Run("panics if value type is not supported", func(t *testing.T) { 57 | xt.Panics(t, func() { 58 | _ = Bytes{Bytes: nil, Valid: true}.Compare("str") 59 | }) 60 | }) 61 | } 62 | 63 | func TestBytes_Value(t *testing.T) { 64 | data := []byte("I am Data") 65 | 66 | t.Run("valid", func(t *testing.T) { 67 | nb := Bytes{Bytes: data, Valid: true} 68 | v, _ := nb.Value() 69 | d, ok := v.([]byte) 70 | xt.Assert(t, ok, fmt.Sprintf("expected []byte; got %T", v)) 71 | xt.Eq(t, data, d) 72 | }) 73 | 74 | t.Run("not valid", func(t *testing.T) { 75 | nb := Bytes{Bytes: data, Valid: false} 76 | v, _ := nb.Value() 77 | xt.Eq(t, nil, v, "expected nil") 78 | }) 79 | } 80 | -------------------------------------------------------------------------------- /null/decimal.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2022, Geert JM Vanderkelen 2 | 3 | package null 4 | 5 | import ( 6 | "database/sql/driver" 7 | "fmt" 8 | 9 | "github.com/golistic/pxmysql/decimal" 10 | ) 11 | 12 | // Decimal represents a decimal.Decimal (MySQL DECIMAL type) that may be NULL. 13 | // This is similar to types provided by Go's sql package, and does not implement the Scanner interface. 14 | type Decimal struct { 15 | Decimal decimal.Decimal 16 | Valid bool 17 | } 18 | 19 | var _ driver.Valuer = &Decimal{} 20 | var _ Nullable = &Decimal{} 21 | 22 | // Compare returns whether value compares with the nullable Decimal. 
23 | // It returns: 24 | // - true when Valid and stored Decimal is equal to value 25 | // - true when not Valid and value is nil 26 | // - false in any other case 27 | func (nd Decimal) Compare(value any) bool { 28 | if !nd.Valid && value != nil { 29 | return false 30 | } 31 | 32 | if value == nil { 33 | return !nd.Valid 34 | } 35 | 36 | switch v := value.(type) { 37 | case decimal.Decimal: 38 | return nd.Decimal.Equal(v) 39 | case *decimal.Decimal: 40 | return nd.Decimal.Equal(*v) 41 | default: 42 | panic(fmt.Sprintf("value is of unsupported type %T", value)) 43 | } 44 | } 45 | 46 | // Value returns the value of n and implements the driver.Valuer 47 | // as well as Nullable interface. 48 | func (nd Decimal) Value() (driver.Value, error) { 49 | if !nd.Valid { 50 | return nil, nil 51 | } 52 | return nd.Decimal, nil 53 | } 54 | -------------------------------------------------------------------------------- /null/decimal_test.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2022, Geert JM Vanderkelen 2 | 3 | package null 4 | 5 | import ( 6 | "fmt" 7 | "testing" 8 | 9 | "github.com/golistic/xgo/xt" 10 | 11 | "github.com/golistic/pxmysql/decimal" 12 | ) 13 | 14 | func TestDecimal_Compare(t *testing.T) { 15 | dec1 := decimal.MustNew("8.56") 16 | dec2 := decimal.MustNew("19.469") 17 | 18 | t.Run("Decimal", func(t *testing.T) { 19 | var cases = []struct { 20 | n Nullable 21 | value any 22 | exp bool 23 | }{ 24 | { 25 | n: Decimal{Decimal: *dec1, Valid: true}, 26 | value: *dec1, 27 | exp: true, 28 | }, 29 | { 30 | n: Decimal{Decimal: *dec2, Valid: true}, 31 | value: *dec1, 32 | exp: false, 33 | }, 34 | { 35 | n: Decimal{Decimal: *dec1, Valid: true}, 36 | value: dec1, 37 | exp: true, 38 | }, 39 | { 40 | n: Decimal{Decimal: *dec2, Valid: true}, 41 | value: dec1, 42 | exp: false, 43 | }, 44 | { 45 | n: Decimal{Decimal: *dec1, Valid: false}, 46 | value: nil, 47 | exp: true, 48 | }, 49 | { 50 | n: Decimal{Decimal: *dec1, 
Valid: false}, 51 | value: dec1, 52 | exp: false, 53 | }, 54 | } 55 | 56 | for _, c := range cases { 57 | t.Run("", func(t *testing.T) { 58 | xt.Eq(t, c.exp, c.n.Compare(c.value)) 59 | }) 60 | } 61 | }) 62 | 63 | t.Run("panics if value type is not supported", func(t *testing.T) { 64 | xt.Panics(t, func() { 65 | _ = Decimal{Decimal: *decimal.Zero, Valid: true}.Compare("str") 66 | }) 67 | }) 68 | } 69 | 70 | func TestDecimal_Value(t *testing.T) { 71 | pi := decimal.MustNew("3.14") 72 | 73 | t.Run("valid", func(t *testing.T) { 74 | nd := Decimal{Decimal: *pi, Valid: true} 75 | v, _ := nd.Value() 76 | d, ok := v.(decimal.Decimal) 77 | xt.Assert(t, ok, fmt.Sprintf("expected decimal.Decimal; got %T", v)) 78 | xt.Eq(t, *pi, d) 79 | }) 80 | 81 | t.Run("not valid", func(t *testing.T) { 82 | nd := Decimal{Decimal: *pi, Valid: false} 83 | v, _ := nd.Value() 84 | xt.Eq(t, nil, v, "expected nil") 85 | }) 86 | } 87 | -------------------------------------------------------------------------------- /null/duration.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2022, Geert JM Vanderkelen 2 | 3 | package null 4 | 5 | import ( 6 | "database/sql/driver" 7 | "fmt" 8 | "time" 9 | ) 10 | 11 | // Duration represents a time.Duration (MySQL TIME type) that may be NULL. 12 | // This is not available in Go's sql package, and does not implement the Scanner interface. 13 | // Note that the sql.NullTime is for timestamps (which includes dates). 14 | type Duration struct { 15 | Duration time.Duration 16 | Valid bool 17 | } 18 | 19 | var _ driver.Valuer = &Duration{} 20 | var _ Nullable = &Duration{} 21 | 22 | // Compare returns whether value compares with the nullable Duration. 
23 | // It returns: 24 | // - true when Valid and stored Duration is equal to value 25 | // - true when not Valid and value is nil 26 | // - false in any other case 27 | func (nd Duration) Compare(value any) bool { 28 | if !nd.Valid && value != nil { 29 | return false 30 | } 31 | 32 | if value == nil { 33 | return !nd.Valid 34 | } 35 | 36 | switch v := value.(type) { 37 | case time.Duration: 38 | return nd.Duration == v 39 | case *time.Duration: 40 | return nd.Duration == *v 41 | default: 42 | panic(fmt.Sprintf("value must be time.Duration or *time.Duration; not %T", value)) 43 | } 44 | } 45 | 46 | // Value returns the value of n and implements the driver.Valuer 47 | // as well as Nullable interface. 48 | func (nd Duration) Value() (driver.Value, error) { 49 | if !nd.Valid { 50 | return nil, nil 51 | } 52 | return nd.Duration, nil 53 | } 54 | -------------------------------------------------------------------------------- /null/duration_test.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2022, Geert JM Vanderkelen 2 | 3 | package null 4 | 5 | import ( 6 | "testing" 7 | "time" 8 | 9 | "github.com/golistic/xgo/xt" 10 | ) 11 | 12 | func TestDuration_Compare(t *testing.T) { 13 | dur2d3h, _ := time.ParseDuration("2d3h") 14 | dur5h6m, _ := time.ParseDuration("5h6m") 15 | 16 | var cases = []struct { 17 | n Nullable 18 | value any 19 | exp bool 20 | }{ 21 | { 22 | n: Duration{Duration: dur2d3h, Valid: true}, 23 | value: &dur2d3h, 24 | exp: true, 25 | }, 26 | { 27 | n: Duration{Duration: dur2d3h, Valid: true}, 28 | value: &dur5h6m, // supposed to be different 29 | exp: false, 30 | }, 31 | { 32 | n: Duration{Duration: dur2d3h, Valid: true}, 33 | value: dur2d3h, 34 | exp: true, 35 | }, 36 | { 37 | n: Duration{Duration: dur2d3h, Valid: true}, 38 | value: dur5h6m, // supposed to be different 39 | exp: false, 40 | }, 41 | { 42 | n: Duration{Valid: false}, 43 | value: nil, 44 | exp: true, 45 | }, 46 | 47 | { 48 | n: 
Duration{Valid: false}, 49 | value: 0, 50 | exp: false, 51 | }, 52 | } 53 | 54 | for _, c := range cases { 55 | t.Run("", func(t *testing.T) { 56 | xt.Eq(t, c.exp, c.n.Compare(c.value)) 57 | }) 58 | } 59 | 60 | t.Run("panics if value type is not supported", func(t *testing.T) { 61 | xt.Panics(t, func() { 62 | _ = Duration{Duration: 0, Valid: true}.Compare("str") 63 | }) 64 | }) 65 | } 66 | 67 | func TestDuration_Value(t *testing.T) { 68 | t.Run("valid", func(t *testing.T) { 69 | dur2d3h, _ := time.ParseDuration("2d3h") 70 | nd := Duration{Duration: dur2d3h, Valid: true} 71 | v, _ := nd.Value() 72 | d, ok := v.(time.Duration) 73 | xt.Assert(t, ok, "expected time.Duration") 74 | xt.Eq(t, dur2d3h, d) 75 | }) 76 | 77 | t.Run("not valid", func(t *testing.T) { 78 | nd := Duration{Duration: 0, Valid: false} 79 | v, _ := nd.Value() 80 | xt.Eq(t, nil, v, "expected nil") 81 | }) 82 | } 83 | -------------------------------------------------------------------------------- /null/float32.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2022, Geert JM Vanderkelen 2 | 3 | package null 4 | 5 | import ( 6 | "database/sql/driver" 7 | "fmt" 8 | ) 9 | 10 | // Float32 represents a float32 (any MySQL FLOAT type) that may be NULL. 11 | // This is similar to sql.NullFloat64, and does not implement the Scanner interface. 12 | type Float32 struct { 13 | Float32 float32 14 | Valid bool 15 | } 16 | 17 | var _ driver.Valuer = &Float32{} 18 | var _ Nullable = &Float32{} 19 | 20 | // Compare returns whether value compares with the nullable Float32. 
21 | // It returns: 22 | // - true when Valid and stored Float32 is equal to value 23 | // - true when not Valid and value is nil 24 | // - false in any other case 25 | func (nf Float32) Compare(value any) bool { 26 | if !nf.Valid && value != nil { 27 | return false 28 | } 29 | 30 | if value == nil { 31 | return !nf.Valid 32 | } 33 | 34 | switch v := value.(type) { 35 | case float32: 36 | return v == nf.Float32 37 | case *float32: 38 | return *v == nf.Float32 39 | default: 40 | panic(fmt.Sprintf("value is of unsupported type %T", value)) 41 | } 42 | } 43 | 44 | // Value returns the value of n and implements the driver.Valuer 45 | // as well as Nullable interface. 46 | func (nf Float32) Value() (driver.Value, error) { 47 | if !nf.Valid { 48 | return nil, nil 49 | } 50 | return nf.Float32, nil 51 | } 52 | -------------------------------------------------------------------------------- /null/float32_test.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2022, Geert JM Vanderkelen 2 | 3 | package null 4 | 5 | import ( 6 | "testing" 7 | 8 | "github.com/golistic/xgo/xt" 9 | ) 10 | 11 | func TestFloat32_Compare(t *testing.T) { 12 | t.Run("Float32", func(t *testing.T) { 13 | var cases = []struct { 14 | n Nullable 15 | value any 16 | exp bool 17 | }{ 18 | { 19 | n: Float32{Float32: 8.56, Valid: true}, 20 | value: float32(8.56), 21 | exp: true, 22 | }, 23 | { 24 | n: Float32{Float32: 19.469, Valid: true}, 25 | value: float32(8.56), // supposed to be not 19.469 26 | exp: false, 27 | }, 28 | { 29 | n: Float32{Float32: 8.56, Valid: true}, 30 | value: float32Ptr(8.56), 31 | exp: true, 32 | }, 33 | { 34 | n: Float32{Float32: 19.469, Valid: true}, 35 | value: float32Ptr(8.56), // supposed to be not 19.469 36 | exp: false, 37 | }, 38 | { 39 | n: Float32{Float32: 898, Valid: false}, 40 | value: nil, 41 | exp: true, 42 | }, 43 | { 44 | n: Float32{Float32: 9, Valid: false}, 45 | value: float32(9), 46 | exp: false, 47 | }, 48 | } 
49 | 50 | for _, c := range cases { 51 | t.Run("", func(t *testing.T) { 52 | xt.Eq(t, c.exp, c.n.Compare(c.value)) 53 | }) 54 | } 55 | }) 56 | 57 | t.Run("panics if value type is not supported", func(t *testing.T) { 58 | xt.Panics(t, func() { 59 | _ = Float32{Float32: 0, Valid: true}.Compare("str") 60 | }) 61 | }) 62 | } 63 | 64 | func TestFloat32_Value(t *testing.T) { 65 | pi := float32(3.14) 66 | 67 | t.Run("valid", func(t *testing.T) { 68 | nt := Float32{Float32: pi, Valid: true} 69 | v, _ := nt.Value() 70 | d, ok := v.(float32) 71 | xt.Assert(t, ok, "expected float32") 72 | xt.Eq(t, pi, d) 73 | }) 74 | 75 | t.Run("not valid", func(t *testing.T) { 76 | nd := Float32{Float32: pi, Valid: false} 77 | v, _ := nd.Value() 78 | xt.Eq(t, nil, v, "expected nil") 79 | }) 80 | } 81 | -------------------------------------------------------------------------------- /null/float64.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2022, Geert JM Vanderkelen 2 | 3 | package null 4 | 5 | import ( 6 | "database/sql/driver" 7 | "fmt" 8 | ) 9 | 10 | // Float64 represents a float64 (any MySQL float/double type) that may be NULL. 11 | // This is similar to sql.NullFloat64, and does not implement the Scanner interface. 12 | type Float64 struct { 13 | Float64 float64 14 | Valid bool 15 | } 16 | 17 | var _ driver.Valuer = &Float64{} 18 | var _ Nullable = &Float64{} 19 | 20 | // Compare returns whether value compares with the nullable Float64. 
21 | // It returns: 22 | // - true when Valid and stored Float64 is equal to value 23 | // - true when not Valid and value is nil 24 | // - false in any other case 25 | func (nf Float64) Compare(value any) bool { 26 | if !nf.Valid && value != nil { 27 | return false 28 | } 29 | 30 | if value == nil { 31 | return !nf.Valid 32 | } 33 | 34 | switch v := value.(type) { 35 | case float64: 36 | return v == nf.Float64 37 | case *float64: 38 | return *v == nf.Float64 39 | default: 40 | panic(fmt.Sprintf("value is of unsupported type %T", value)) 41 | } 42 | } 43 | 44 | // Value returns the value of n and implements the driver.Valuer 45 | // as well as Nullable interface. 46 | func (nf Float64) Value() (driver.Value, error) { 47 | if !nf.Valid { 48 | return nil, nil 49 | } 50 | return nf.Float64, nil 51 | } 52 | -------------------------------------------------------------------------------- /null/float64_test.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2022, Geert JM Vanderkelen 2 | 3 | package null 4 | 5 | import ( 6 | "testing" 7 | 8 | "github.com/golistic/xgo/xt" 9 | ) 10 | 11 | func TestFloat64_Compare(t *testing.T) { 12 | t.Run("Float64", func(t *testing.T) { 13 | var cases = []struct { 14 | n Nullable 15 | value any 16 | exp bool 17 | }{ 18 | { 19 | n: Float64{Float64: 8.56, Valid: true}, 20 | value: float64Ptr(8.56), 21 | exp: true, 22 | }, 23 | { 24 | n: Float64{Float64: 19.469, Valid: true}, 25 | value: float64Ptr(8.56), // supposed to be not 19.469 26 | exp: false, 27 | }, 28 | { 29 | n: Float64{Float64: 8.56, Valid: true}, 30 | value: 8.56, 31 | exp: true, 32 | }, 33 | { 34 | n: Float64{Float64: 19.469, Valid: true}, 35 | value: 8.56, // supposed to be not 19.469 36 | exp: false, 37 | }, 38 | { 39 | n: Float64{Float64: 898, Valid: false}, 40 | value: nil, 41 | exp: true, 42 | }, 43 | { 44 | n: Float64{Float64: 898, Valid: false}, 45 | value: 898, 46 | exp: false, 47 | }, 48 | } 49 | 50 | for _, c := 
range cases { 51 | t.Run("", func(t *testing.T) { 52 | xt.Eq(t, c.exp, c.n.Compare(c.value)) 53 | }) 54 | } 55 | }) 56 | 57 | t.Run("panics if value type is not supported", func(t *testing.T) { 58 | xt.Panics(t, func() { 59 | _ = Float64{Float64: 0, Valid: true}.Compare("str") 60 | }) 61 | }) 62 | } 63 | 64 | func TestFloat64_Value(t *testing.T) { 65 | pi := 3.14 66 | 67 | t.Run("valid", func(t *testing.T) { 68 | nt := Float64{Float64: pi, Valid: true} 69 | v, _ := nt.Value() 70 | d, ok := v.(float64) 71 | xt.Assert(t, ok, "expected float64") 72 | xt.Eq(t, pi, d) 73 | }) 74 | 75 | t.Run("not valid", func(t *testing.T) { 76 | nd := Float64{Float64: pi, Valid: false} 77 | v, _ := nd.Value() 78 | xt.Eq(t, nil, v, "expected nil") 79 | }) 80 | } 81 | -------------------------------------------------------------------------------- /null/int64.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2022, Geert JM Vanderkelen 2 | 3 | package null 4 | 5 | import ( 6 | "database/sql/driver" 7 | "fmt" 8 | 9 | "github.com/golistic/xgo/xconv" 10 | ) 11 | 12 | // Int64 represents an int64 (any MySQL signed integral type) that may be NULL. 13 | // This is similar to sql.NullInt64, and does not implement the Scanner interface. 14 | type Int64 struct { 15 | Int64 int64 16 | Valid bool 17 | } 18 | 19 | var _ driver.Valuer = &Int64{} 20 | var _ Nullable = &Int64{} 21 | 22 | // Compare returns whether value compares with the nullable Int64. 
23 | // It returns: 24 | // - true when Valid and stored Int64 is equal to value 25 | // - true when not Valid and value is nil 26 | // - false in any other case 27 | func (ni Int64) Compare(value any) bool { 28 | if !ni.Valid && value != nil { 29 | return false 30 | } 31 | 32 | if value == nil { 33 | return !ni.Valid 34 | } 35 | 36 | switch v := value.(type) { 37 | case int64, int, int8, int16, int32: 38 | return xconv.SignedAsInt64(v) == ni.Int64 39 | case *int64, *int, *int8, *int16, *int32: 40 | return *xconv.SignedAsInt64Ptr(v) == ni.Int64 41 | default: 42 | panic(fmt.Sprintf("value is of unsupported type %T", value)) 43 | } 44 | } 45 | 46 | // Value returns the value of n and implements the driver.Valuer 47 | // as well as Nullable interface. 48 | func (ni Int64) Value() (driver.Value, error) { 49 | if !ni.Valid { 50 | return nil, nil 51 | } 52 | return ni.Int64, nil 53 | } 54 | -------------------------------------------------------------------------------- /null/int64_test.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2022, Geert JM Vanderkelen 2 | 3 | package null 4 | 5 | import ( 6 | "testing" 7 | 8 | "github.com/golistic/xgo/xt" 9 | ) 10 | 11 | func TestInt64_Compare(t *testing.T) { 12 | t.Run("Int64", func(t *testing.T) { 13 | var cases = []struct { 14 | n Nullable 15 | value any 16 | exp bool 17 | }{ 18 | { 19 | n: Int64{Int64: -8, Valid: true}, 20 | value: int64Ptr(-8), 21 | exp: true, 22 | }, 23 | { 24 | n: Int64{Int64: -19, Valid: true}, 25 | value: int64Ptr(-8), // supposed to be not 19 26 | exp: false, 27 | }, 28 | { 29 | n: Int64{Int64: -8, Valid: true}, 30 | value: int64(-8), 31 | exp: true, 32 | }, 33 | { 34 | n: Int64{Int64: -19, Valid: true}, 35 | value: int64(-8), // supposed to be not 19 36 | exp: false, 37 | }, 38 | { 39 | n: Int64{Int64: -898, Valid: false}, 40 | value: nil, 41 | exp: true, 42 | }, 43 | { 44 | n: Int64{Int64: -898, Valid: false}, 45 | value: 123, 46 | exp: 
false, 47 | }, 48 | } 49 | 50 | for _, c := range cases { 51 | t.Run("", func(t *testing.T) { 52 | xt.Eq(t, c.exp, c.n.Compare(c.value)) 53 | }) 54 | } 55 | }) 56 | 57 | t.Run("panics if value type is not supported", func(t *testing.T) { 58 | xt.Panics(t, func() { 59 | _ = Int64{Int64: 0, Valid: true}.Compare("str") 60 | }) 61 | }) 62 | } 63 | 64 | func TestInt64_Value(t *testing.T) { 65 | t.Run("valid", func(t *testing.T) { 66 | ni := Int64{Int64: 9, Valid: true} 67 | v, _ := ni.Value() 68 | d, ok := v.(int64) 69 | xt.Assert(t, ok, "expected int64") 70 | xt.Eq(t, 9, d) 71 | }) 72 | 73 | t.Run("not valid", func(t *testing.T) { 74 | ni := Int64{Int64: 0, Valid: false} 75 | v, _ := ni.Value() 76 | xt.Eq(t, nil, v, "expected nil") 77 | }) 78 | } 79 | -------------------------------------------------------------------------------- /null/main.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2022, Geert JM Vanderkelen 2 | 3 | package null 4 | 5 | import "database/sql/driver" 6 | 7 | type Nullable interface { 8 | Compare(any) bool 9 | Value() (driver.Value, error) 10 | } 11 | 12 | // Compare checks value against the value stored with Nullable. 13 | // It returns true when: 14 | // - nullable is valid and value of nullable is equal to given value, 15 | // - or when nullable is not valid (SQL NULL) and value is nil 16 | // 17 | // Panics when value cannot be used with given Nullable n. 
18 | func Compare(n Nullable, value any) bool { 19 | return n.Compare(value) 20 | } 21 | -------------------------------------------------------------------------------- /null/main_test.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2022, Geert JM Vanderkelen 2 | 3 | package null 4 | 5 | func uint64Ptr(v uint64) *uint64 { 6 | return &v 7 | } 8 | 9 | func bytesPtr(v []byte) *[]byte { 10 | return &v 11 | } 12 | 13 | func stringPtr(v string) *string { 14 | return &v 15 | } 16 | 17 | func float64Ptr(v float64) *float64 { 18 | return &v 19 | } 20 | 21 | func float32Ptr(v float32) *float32 { 22 | return &v 23 | } 24 | 25 | func int64Ptr(v int64) *int64 { 26 | return &v 27 | } 28 | -------------------------------------------------------------------------------- /null/string.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2022, Geert JM Vanderkelen 2 | 3 | package null 4 | 5 | import ( 6 | "database/sql/driver" 7 | "fmt" 8 | ) 9 | 10 | // String represents a string (any MySQL CHAR-kind of data type) that may be NULL. 11 | // This is similar to sql.NullString, and does not implement the Scanner interface. 12 | type String struct { 13 | String string 14 | Valid bool 15 | } 16 | 17 | var _ driver.Valuer = &String{} 18 | var _ Nullable = &String{} 19 | 20 | // Compare returns whether value compares with the nullable String. 
21 | // It returns: 22 | // - true when Valid and stored String is equal to value 23 | // - true when not Valid and value is nil 24 | // - false in any other case 25 | func (ns String) Compare(value any) bool { 26 | if !ns.Valid && value != nil { 27 | return false 28 | } 29 | 30 | if value == nil { 31 | return !ns.Valid 32 | } 33 | 34 | switch v := value.(type) { 35 | case string: 36 | return ns.String == v 37 | case *string: 38 | return ns.String == *v 39 | default: 40 | panic(fmt.Sprintf("value must be string or *string; not %T", value)) 41 | } 42 | } 43 | 44 | // Value returns the value of n and implements the driver.Valuer 45 | // as well as Nullable interface. 46 | func (ns String) Value() (driver.Value, error) { 47 | if !ns.Valid { 48 | return nil, nil 49 | } 50 | return ns.String, nil 51 | } 52 | -------------------------------------------------------------------------------- /null/string_test.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2022, Geert JM Vanderkelen 2 | 3 | package null 4 | 5 | import ( 6 | "testing" 7 | 8 | "github.com/golistic/xgo/xt" 9 | ) 10 | 11 | func TestString_Compare(t *testing.T) { 12 | var cases = []struct { 13 | n Nullable 14 | value any 15 | exp bool 16 | }{ 17 | { 18 | n: String{String: "Sakila", Valid: true}, 19 | value: "Sakila", 20 | exp: true, 21 | }, 22 | { 23 | n: String{String: "Sakila", Valid: true}, 24 | value: "Go gopher", // supposed to not include 'Go gopher' 25 | exp: false, 26 | }, 27 | { 28 | n: String{String: "Sakila", Valid: true}, 29 | value: stringPtr("Sakila"), 30 | exp: true, 31 | }, 32 | { 33 | n: String{String: "Sakila", Valid: true}, 34 | value: stringPtr("Go gopher"), // supposed to not include 'Go gopher' 35 | exp: false, 36 | }, 37 | { 38 | n: String{Valid: false}, 39 | value: nil, 40 | exp: true, 41 | }, 42 | { 43 | n: String{Valid: false}, 44 | value: "", 45 | exp: false, 46 | }, 47 | } 48 | 49 | for _, c := range cases { 50 | t.Run("", 
func(t *testing.T) { 51 | xt.Eq(t, c.exp, c.n.Compare(c.value)) 52 | }) 53 | } 54 | 55 | t.Run("panics if value type is not supported", func(t *testing.T) { 56 | xt.Panics(t, func() { 57 | _ = String{Valid: true}.Compare(123) 58 | }) 59 | }) 60 | } 61 | 62 | func TestString_Value(t *testing.T) { 63 | str := "String!" 64 | 65 | t.Run("valid", func(t *testing.T) { 66 | ns := String{String: str, Valid: true} 67 | v, _ := ns.Value() 68 | d, ok := v.(string) 69 | xt.Assert(t, ok, "expected string") 70 | xt.Eq(t, str, d) 71 | }) 72 | 73 | t.Run("not valid", func(t *testing.T) { 74 | ns := String{String: str, Valid: false} 75 | v, _ := ns.Value() 76 | xt.Eq(t, nil, v, "expected nil") 77 | }) 78 | } 79 | -------------------------------------------------------------------------------- /null/strings.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2022, Geert JM Vanderkelen 2 | 3 | package null 4 | 5 | import ( 6 | "database/sql/driver" 7 | "fmt" 8 | "sort" 9 | ) 10 | 11 | // Strings represents a []string (slice of strings), for example used for MySQL ENUM 12 | // type, that may be NULL. 13 | // This is not available in Go's sql package, and does not implement the Scanner interface. 14 | type Strings struct { 15 | Strings []string 16 | Valid bool 17 | } 18 | 19 | var _ driver.Valuer = &Strings{} 20 | var _ Nullable = &Strings{} 21 | 22 | // Compare returns whether value compares with the nullable Strings. 
23 | // It returns: 24 | // - true when Valid and stored Strings is equal to value 25 | // - true when not Valid and value is nil 26 | // - false in any other case 27 | func (ns Strings) Compare(value any) bool { 28 | if !ns.Valid && value != nil { 29 | return false 30 | } 31 | 32 | if value == nil { 33 | return !ns.Valid 34 | } 35 | 36 | equal := func(a, b []string) bool { 37 | if len(a) != len(b) { 38 | return false 39 | } 40 | 41 | ac := make([]string, len(a)) 42 | copy(ac, a) 43 | sort.Strings(ac) 44 | 45 | bc := make([]string, len(b)) 46 | copy(bc, b) 47 | sort.Strings(bc) 48 | 49 | for i, v := range ac { 50 | if v != bc[i] { 51 | return false 52 | } 53 | } 54 | return true 55 | } 56 | 57 | switch v := value.(type) { 58 | case []string: 59 | return equal(ns.Strings, v) 60 | case *[]string: 61 | return equal(ns.Strings, *v) 62 | default: 63 | panic(fmt.Sprintf("value must be []strings or []*strings; not %T", value)) 64 | } 65 | } 66 | 67 | // Value returns the value of n and implements the driver.Valuer 68 | // as well as Nullable interface. 
69 | func (ns Strings) Value() (driver.Value, error) { 70 | if !ns.Valid { 71 | return nil, nil 72 | } 73 | return ns.Strings, nil 74 | } 75 | -------------------------------------------------------------------------------- /null/strings_test.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2022, Geert JM Vanderkelen 2 | 3 | package null 4 | 5 | import ( 6 | "fmt" 7 | "testing" 8 | 9 | "github.com/golistic/xgo/xt" 10 | ) 11 | 12 | func TestStrings_Compare(t *testing.T) { 13 | mascots1 := []string{"Sakila", "Go gopher"} 14 | mascots2 := []string{"Sakila", "Duke"} 15 | mascots3 := []string{"Sakila"} 16 | 17 | var cases = []struct { 18 | n Nullable 19 | value any 20 | exp bool 21 | }{ 22 | { 23 | n: Strings{Strings: mascots1, Valid: true}, 24 | value: mascots1, 25 | exp: true, 26 | }, 27 | { 28 | n: Strings{Strings: mascots1, Valid: true}, 29 | value: mascots2, 30 | exp: false, 31 | }, 32 | { 33 | n: Strings{Strings: mascots1, Valid: true}, 34 | value: &mascots1, 35 | exp: true, 36 | }, 37 | { 38 | n: Strings{Strings: mascots1, Valid: true}, 39 | value: &mascots3, 40 | exp: false, 41 | }, 42 | { 43 | n: Strings{Strings: nil, Valid: false}, 44 | value: nil, 45 | exp: true, 46 | }, 47 | { 48 | n: Strings{Strings: []string{}, Valid: false}, 49 | value: nil, 50 | exp: true, 51 | }, 52 | { 53 | n: Strings{Strings: []string{}, Valid: false}, 54 | value: []string{""}, 55 | exp: false, 56 | }, 57 | } 58 | 59 | for _, c := range cases { 60 | t.Run("", func(t *testing.T) { 61 | xt.Eq(t, c.exp, c.n.Compare(c.value)) 62 | }) 63 | } 64 | 65 | t.Run("panics if value type is not supported", func(t *testing.T) { 66 | xt.Panics(t, func() { 67 | _ = Strings{Strings: nil, Valid: true}.Compare("str") 68 | }) 69 | }) 70 | } 71 | 72 | func TestStrings_Value(t *testing.T) { 73 | data := []string{"Sakila", "Go gopher"} 74 | 75 | t.Run("valid", func(t *testing.T) { 76 | ns := Strings{Strings: data, Valid: true} 77 | v, _ := ns.Value() 
78 | d, ok := v.([]string) 79 | xt.Assert(t, ok, fmt.Sprintf("expected []string; got %T", v)) 80 | xt.Eq(t, data, d) 81 | }) 82 | 83 | t.Run("not valid", func(t *testing.T) { 84 | ns := Strings{Strings: data, Valid: false} 85 | v, _ := ns.Value() 86 | xt.Eq(t, nil, v, "expected nil") 87 | }) 88 | } 89 | -------------------------------------------------------------------------------- /null/time.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2022, Geert JM Vanderkelen 2 | 3 | package null 4 | 5 | import ( 6 | "database/sql/driver" 7 | "fmt" 8 | "time" 9 | ) 10 | 11 | // Time represents a time.Time (MySQL TIMESTAMP, DATETIME, and DATE types) that may be NULL. 12 | // This is similar to sql.NullTime, and does not implement the Scanner interface. 13 | type Time struct { 14 | Time time.Time 15 | Valid bool 16 | } 17 | 18 | var _ driver.Valuer = &Time{} 19 | var _ Nullable = &Time{} 20 | 21 | // Compare returns whether value compares with the nullable Time. 22 | // It returns: 23 | // - true when Valid and stored Time is equal to value 24 | // - true when not Valid and value is nil 25 | // - false in any other case 26 | func (nd Time) Compare(value any) bool { 27 | if !nd.Valid && value != nil { 28 | return false 29 | } 30 | 31 | if value == nil { 32 | return !nd.Valid 33 | } 34 | 35 | switch v := value.(type) { 36 | case time.Time: 37 | return nd.Time.Equal(v) 38 | case *time.Time: 39 | return nd.Time.Equal(*v) 40 | default: 41 | panic(fmt.Sprintf("value must be time.Time or *time.Time; not %T", value)) 42 | } 43 | } 44 | 45 | // Value returns the value of n and implements the driver.Valuer 46 | // as well as Nullable interface. 
47 | func (nd Time) Value() (driver.Value, error) { 48 | if !nd.Valid { 49 | return nil, nil 50 | } 51 | return nd.Time, nil 52 | } 53 | -------------------------------------------------------------------------------- /null/time_test.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2022, Geert JM Vanderkelen 2 | 3 | package null 4 | 5 | import ( 6 | "testing" 7 | "time" 8 | 9 | "github.com/golistic/xgo/xt" 10 | "github.com/golistic/xgo/xtime" 11 | ) 12 | 13 | func TestTime_Compare(t *testing.T) { 14 | now := time.Now() 15 | yesterday := xtime.Yesterday() 16 | 17 | t.Run("Uint64", func(t *testing.T) { 18 | var cases = []struct { 19 | n Nullable 20 | value *time.Time 21 | exp bool 22 | }{ 23 | { 24 | n: Time{Time: now, Valid: true}, 25 | value: &now, 26 | exp: true, 27 | }, 28 | { 29 | n: Time{Time: now, Valid: true}, 30 | value: &yesterday, 31 | exp: false, 32 | }, 33 | { 34 | n: Time{Time: now, Valid: false}, 35 | value: nil, 36 | exp: false, 37 | }, 38 | } 39 | 40 | for _, c := range cases { 41 | t.Run("", func(t *testing.T) { 42 | xt.Eq(t, c.exp, c.n.Compare(c.value)) 43 | }) 44 | } 45 | }) 46 | 47 | t.Run("non-pointer", func(t *testing.T) { 48 | xt.Assert(t, !Time{Time: now, Valid: false}.Compare(now)) 49 | xt.Assert(t, Time{Time: now, Valid: true}.Compare(now)) 50 | }) 51 | 52 | t.Run("panics if value type is not supported", func(t *testing.T) { 53 | xt.Panics(t, func() { 54 | _ = Time{Time: now, Valid: true}.Compare("str") 55 | }) 56 | }) 57 | 58 | t.Run("value is explicitly nil", func(t *testing.T) { 59 | xt.Assert(t, Time{Time: now, Valid: false}.Compare(nil)) 60 | }) 61 | } 62 | 63 | func TestTime_Value(t *testing.T) { 64 | now := time.Now() 65 | 66 | t.Run("valid", func(t *testing.T) { 67 | nt := Time{Time: now, Valid: true} 68 | v, _ := nt.Value() 69 | d, ok := v.(time.Time) 70 | xt.Assert(t, ok, "expected time.Time") 71 | xt.Eq(t, now, d) 72 | }) 73 | 74 | t.Run("not valid", func(t *testing.T) { 75 
| nd := Time{Time: now, Valid: false} 76 | v, _ := nd.Value() 77 | xt.Eq(t, nil, v, "expected nil") 78 | }) 79 | } 80 | -------------------------------------------------------------------------------- /null/uint64.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2022, Geert JM Vanderkelen 2 | 3 | package null 4 | 5 | import ( 6 | "database/sql/driver" 7 | "fmt" 8 | 9 | "github.com/golistic/xgo/xconv" 10 | ) 11 | 12 | // Uint64 represents an uint64 (any MySQL unsigned integer type) that may be NULL. 13 | // This is not available in Go's sql package, and does not implement the Scanner interface. 14 | type Uint64 struct { 15 | Uint64 uint64 16 | Valid bool 17 | } 18 | 19 | var _ driver.Valuer = &Uint64{} 20 | var _ Nullable = &Uint64{} 21 | 22 | // Compare returns whether value compares with the nullable Uint64. 23 | // It returns: 24 | // - true when Valid and stored Uint64 is equal to value 25 | // - true when not Valid and value is nil 26 | // - false in any other case 27 | func (ni Uint64) Compare(value any) bool { 28 | if !ni.Valid && value != nil { 29 | return false 30 | } 31 | 32 | if value == nil { 33 | return !ni.Valid 34 | } 35 | 36 | switch v := value.(type) { 37 | case uint64, uint, uint8, uint16, uint32: 38 | return xconv.UnsignedAsUint64(v) == ni.Uint64 39 | case *uint64, *uint, *uint8, *uint16, *uint32: 40 | return *xconv.UnsignedAsUint64Ptr(v) == ni.Uint64 41 | default: 42 | panic(fmt.Sprintf("value is of unsupported type %T", value)) 43 | } 44 | } 45 | 46 | // Value returns the value of n and implements the driver.Valuer 47 | // as well as Nullable interface. 
48 | func (ni Uint64) Value() (driver.Value, error) { 49 | if !ni.Valid { 50 | return nil, nil 51 | } 52 | return ni.Uint64, nil 53 | } 54 | -------------------------------------------------------------------------------- /null/uint64_test.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2022, Geert JM Vanderkelen 2 | 3 | package null 4 | 5 | import ( 6 | "testing" 7 | 8 | "github.com/golistic/xgo/xt" 9 | ) 10 | 11 | func TestUint64_Compare(t *testing.T) { 12 | t.Run("Uint64", func(t *testing.T) { 13 | var cases = []struct { 14 | n Nullable 15 | value any 16 | exp bool 17 | }{ 18 | { 19 | n: Uint64{Uint64: 8, Valid: true}, 20 | value: uint64Ptr(8), 21 | exp: true, 22 | }, 23 | { 24 | n: Uint64{Uint64: 19, Valid: true}, 25 | value: uint64Ptr(8), 26 | exp: false, 27 | }, 28 | { 29 | n: Uint64{Uint64: 8, Valid: true}, 30 | value: uint(8), 31 | exp: true, 32 | }, 33 | { 34 | n: Uint64{Uint64: 19, Valid: true}, 35 | value: uint(8), 36 | exp: false, 37 | }, 38 | { 39 | n: Uint64{Uint64: 898, Valid: false}, 40 | value: nil, 41 | exp: true, 42 | }, 43 | { 44 | n: Uint64{Uint64: 898, Valid: false}, 45 | value: uint(898), 46 | exp: false, 47 | }, 48 | } 49 | 50 | for _, c := range cases { 51 | t.Run("", func(t *testing.T) { 52 | xt.Eq(t, c.exp, c.n.Compare(c.value)) 53 | }) 54 | } 55 | }) 56 | 57 | t.Run("panics if value type is not supported", func(t *testing.T) { 58 | xt.Panics(t, func() { 59 | _ = Uint64{Uint64: 0, Valid: true}.Compare("str") 60 | }) 61 | }) 62 | } 63 | 64 | func TestUint64_Value(t *testing.T) { 65 | t.Run("valid", func(t *testing.T) { 66 | ni := Uint64{Uint64: 9, Valid: true} 67 | v, _ := ni.Value() 68 | d, ok := v.(uint64) 69 | xt.Assert(t, ok, "expected uint64") 70 | xt.Eq(t, 9, d) 71 | }) 72 | 73 | t.Run("not valid", func(t *testing.T) { 74 | ni := Uint64{Uint64: 0, Valid: false} 75 | v, _ := ni.Value() 76 | xt.Eq(t, nil, v, "expected nil") 77 | }) 78 | } 79 | 
-------------------------------------------------------------------------------- /register/mysql/register.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2023, Geert JM Vanderkelen 2 | 3 | package mysql 4 | 5 | import ( 6 | "database/sql" 7 | 8 | "github.com/golistic/pxmysql" 9 | ) 10 | 11 | const DriverName = "mysql" 12 | 13 | func init() { 14 | sql.Register(DriverName, &pxmysql.Driver{}) 15 | } 16 | -------------------------------------------------------------------------------- /register/mysql/register_test.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2023, Geert JM Vanderkelen 2 | 3 | package mysql 4 | 5 | import ( 6 | "database/sql" 7 | "testing" 8 | 9 | "github.com/golistic/xgo/xstrings" 10 | "github.com/golistic/xgo/xt" 11 | ) 12 | 13 | func TestDriver_Open(t *testing.T) { 14 | t.Run("pxmysql is not registered", func(t *testing.T) { 15 | xt.Assert(t, !xstrings.SliceHas(sql.Drivers(), "pxmysql"), "expected driver pxmysql not to be registered") 16 | }) 17 | 18 | t.Run("mysql is registered", func(t *testing.T) { 19 | xt.Assert(t, xstrings.SliceHas(sql.Drivers(), "mysql"), "expected driver mysql to be registered") 20 | }) 21 | } 22 | -------------------------------------------------------------------------------- /register/register.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2023, Geert JM Vanderkelen 2 | 3 | package register 4 | 5 | import ( 6 | "database/sql" 7 | 8 | "github.com/golistic/pxmysql" 9 | ) 10 | 11 | const DriverName = "pxmysql" 12 | 13 | func init() { 14 | sql.Register(DriverName, &pxmysql.Driver{}) 15 | } 16 | -------------------------------------------------------------------------------- /register/register_test.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2023, Geert JM Vanderkelen 2 | 3 | package register 
4 | 5 | import ( 6 | "database/sql" 7 | "testing" 8 | 9 | "github.com/golistic/xgo/xstrings" 10 | "github.com/golistic/xgo/xt" 11 | ) 12 | 13 | func TestDriver_Open(t *testing.T) { 14 | t.Run("pxmysql is registered", func(t *testing.T) { 15 | xt.Assert(t, xstrings.SliceHas(sql.Drivers(), "pxmysql"), "expected driver pxmysql to be registered") 16 | }) 17 | 18 | t.Run("mysql is not registered", func(t *testing.T) { 19 | xt.Assert(t, !xstrings.SliceHas(sql.Drivers(), "mysql"), "expected driver mysql not to be registered") 20 | }) 21 | } 22 | -------------------------------------------------------------------------------- /result.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2022, Geert JM Vanderkelen 2 | 3 | package pxmysql 4 | 5 | import ( 6 | "database/sql/driver" 7 | "fmt" 8 | "math" 9 | 10 | "github.com/golistic/pxmysql/xmysql" 11 | ) 12 | 13 | type result struct { 14 | xpresult *xmysql.Result 15 | } 16 | 17 | var _ driver.Result = &result{} 18 | 19 | // LastInsertId returns the database's auto-generated ID after, 20 | // for example, an INSERT into a table with primary key. 21 | // If this information is not available, zero (0) is returned. 22 | func (r result) LastInsertId() (int64, error) { 23 | if r.xpresult != nil { 24 | lid := r.xpresult.LastInsertID() 25 | if lid > math.MaxInt64 { 26 | return 0, fmt.Errorf("LastInsertID overflowed max 64-bit unsigned integer") 27 | } 28 | return int64(lid), nil 29 | } 30 | return 0, nil 31 | } 32 | 33 | // RowsAffected returns the number of rows affected by the query. 34 | // If this information is not available, zero (0) is returned. 
35 | func (r result) RowsAffected() (int64, error) { 36 | if r.xpresult != nil { 37 | affected := r.xpresult.RowsAffected() 38 | if affected > math.MaxInt64 { 39 | return 0, fmt.Errorf("RowsAffected overflowed max 64-bit unsigned integer") 40 | } 41 | return int64(affected), nil 42 | } 43 | return 0, nil 44 | } 45 | -------------------------------------------------------------------------------- /rows.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2022, Geert JM Vanderkelen 2 | 3 | package pxmysql 4 | 5 | import ( 6 | "database/sql/driver" 7 | "io" 8 | 9 | "github.com/golistic/pxmysql/null" 10 | "github.com/golistic/pxmysql/xmysql" 11 | ) 12 | 13 | type rows struct { 14 | xpresult *xmysql.Result 15 | 16 | currRowIndex int 17 | } 18 | 19 | var _ driver.Rows = &rows{} 20 | 21 | // Columns returns the names of the columns. 22 | func (r *rows) Columns() []string { 23 | if r.xpresult == nil { 24 | return nil 25 | } 26 | 27 | cols := make([]string, len(r.xpresult.Columns)) 28 | 29 | for i, c := range r.xpresult.Columns { 30 | cols[i] = string(c.Name) 31 | } 32 | 33 | return cols 34 | } 35 | 36 | func (r *rows) Close() error { 37 | return nil 38 | } 39 | 40 | func (r *rows) Next(dest []driver.Value) error { 41 | if r.xpresult == nil || r.currRowIndex >= len(r.xpresult.Rows) { 42 | return io.EOF 43 | } 44 | 45 | for i, value := range r.xpresult.Rows[r.currRowIndex].Values { 46 | if n, ok := value.(null.Nullable); ok { 47 | var err error 48 | dest[i], err = n.Value() 49 | if err != nil { 50 | return err 51 | } 52 | } else { 53 | dest[i] = value 54 | } 55 | } 56 | 57 | r.currRowIndex++ 58 | return nil 59 | } 60 | -------------------------------------------------------------------------------- /rows_test.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2023, Geert JM Vanderkelen 2 | 3 | package pxmysql_test 4 | 5 | import ( 6 | "context" 7 | "database/sql" 8 | 
"fmt" 9 | "testing" 10 | "time" 11 | 12 | "github.com/golistic/xgo/xt" 13 | ) 14 | 15 | func TestRows_Next(t *testing.T) { 16 | db, err := sql.Open("pxmysql", getTCPDSN("", "")) 17 | xt.OK(t, err) 18 | defer func() { _ = db.Close() }() 19 | 20 | ctx := context.Background() 21 | 22 | t.Run("time.Time", func(t *testing.T) { 23 | tbl := "test_data_types_null_time" 24 | _, err := db.ExecContext(ctx, fmt.Sprintf("CREATE TABLE `%s` (id int, ts DATETIME NULL)", tbl)) 25 | xt.OK(t, err) 26 | 27 | _, err = db.ExecContext(ctx, fmt.Sprintf( 28 | "INSERT INTO `%s` (id, ts) VALUE (1, NOW()),(2, NULL)", tbl)) 29 | xt.OK(t, err) 30 | 31 | stmt := fmt.Sprintf("SELECT ts FROM `%s` WHERE id = ?", tbl) 32 | 33 | var ts time.Time 34 | xt.OK(t, db.QueryRowContext(ctx, stmt, 1).Scan(&ts)) 35 | 36 | var tsNull sql.NullTime 37 | xt.OK(t, db.QueryRowContext(ctx, stmt, 2).Scan(&tsNull)) 38 | }) 39 | } 40 | -------------------------------------------------------------------------------- /statement.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2022, Geert JM Vanderkelen 2 | 3 | package pxmysql 4 | 5 | import ( 6 | "context" 7 | "database/sql/driver" 8 | "time" 9 | 10 | "github.com/golistic/pxmysql/xmysql" 11 | ) 12 | 13 | var closeTimeout = time.Second 14 | 15 | type statement struct { 16 | prepared *xmysql.Prepared 17 | result *xmysql.Result 18 | } 19 | 20 | var ( 21 | _ driver.Stmt = &statement{} 22 | _ driver.StmtQueryContext = &statement{} 23 | _ driver.StmtExecContext = &statement{} 24 | ) 25 | 26 | func (s *statement) Close() error { 27 | ctx, cancel := context.WithTimeout(context.Background(), closeTimeout) 28 | defer cancel() 29 | 30 | if err := s.prepared.Deallocate(ctx); err != nil { 31 | return err 32 | } else { 33 | s.prepared = nil 34 | s.result = nil 35 | } 36 | return nil 37 | } 38 | 39 | // NumInput returns the number of placeholders. 
40 | func (s *statement) NumInput() int { 41 | if s.prepared != nil { 42 | return s.prepared.NumPlaceholders() 43 | } 44 | 45 | return 0 46 | } 47 | 48 | // Exec executes a query that doesn't return rows, such as an INSERT or UPDATE. 49 | // Deprecated: use ExecContext instead. 50 | func (s *statement) Exec(args []driver.Value) (driver.Result, error) { 51 | named := make([]driver.NamedValue, len(args)) 52 | for i, arg := range args { 53 | named[i].Name = "" 54 | named[i].Ordinal = i + 1 55 | named[i].Value = arg 56 | } 57 | return s.ExecContext(context.Background(), named) 58 | } 59 | 60 | func (s *statement) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) { 61 | execArgs := make([]any, len(args)) 62 | 63 | for i, a := range args { 64 | execArgs[i] = a.Value 65 | } 66 | 67 | execResult, err := s.prepared.Execute(ctx, execArgs...) 68 | if err != nil { 69 | return nil, handleError(err) 70 | } 71 | 72 | res := &result{ 73 | xpresult: execResult, 74 | } 75 | 76 | return res, nil 77 | } 78 | 79 | // Query executes a query that may return rows, such as a SELECT. 80 | // Deprecated: use QueryContext instead. 81 | func (s *statement) Query(args []driver.Value) (driver.Rows, error) { 82 | named := make([]driver.NamedValue, len(args)) 83 | for i, arg := range args { 84 | named[i].Name = "" 85 | named[i].Ordinal = i + 1 86 | named[i].Value = arg 87 | } 88 | return s.QueryContext(context.Background(), named) 89 | } 90 | 91 | // QueryContext executes a query that may return rows, such as a SELECT. 92 | func (s *statement) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) { 93 | execArgs := make([]any, len(args)) 94 | 95 | for i, a := range args { 96 | execArgs[i] = a 97 | } 98 | 99 | execResult, err := s.prepared.Execute(ctx, execArgs...) 
100 | if err != nil { 101 | return nil, handleError(err) 102 | } 103 | 104 | if len(execResult.Rows) == 0 { 105 | return &rows{}, nil 106 | } 107 | 108 | r := &rows{ 109 | xpresult: execResult, 110 | } 111 | 112 | return r, nil 113 | } 114 | -------------------------------------------------------------------------------- /statement_test.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2022, Geert JM Vanderkelen 2 | 3 | package pxmysql_test 4 | 5 | import ( 6 | "context" 7 | "crypto/md5" 8 | "database/sql" 9 | "encoding/hex" 10 | "errors" 11 | "fmt" 12 | "sort" 13 | "strings" 14 | "testing" 15 | "time" 16 | 17 | "github.com/golistic/xgo/xt" 18 | 19 | "github.com/golistic/pxmysql/mysqlerrors" 20 | ) 21 | 22 | func TestStatement_Close(t *testing.T) { 23 | dsn := getTCPDSN("", "") 24 | db, err := sql.Open("pxmysql", dsn) 25 | xt.OK(t, err) 26 | 27 | stmt := "SELECT ?" 28 | prep, err := db.Prepare(stmt) 29 | xt.OK(t, err) 30 | 31 | _, err = prep.Exec(3) 32 | xt.OK(t, err) 33 | 34 | xt.OK(t, prep.Close()) 35 | 36 | _, err = prep.Exec(3) 37 | xt.KO(t, err) 38 | xt.Eq(t, "sql: statement is closed", err.Error()) 39 | } 40 | 41 | func testOpenQueryRowsClose() ([]string, error) { 42 | dsn := getTCPDSN("", "") 43 | db, err := sql.Open("pxmysql", dsn) 44 | if err != nil { 45 | return nil, err 46 | } 47 | defer func() { _ = db.Close() }() 48 | 49 | stmt := `SELECT TABLE_NAME FROM information_schema.tables WHERE TABLE_SCHEMA = ? 
ORDER BY TABLE_SCHEMA` 50 | prep, err := db.Prepare(stmt) 51 | if err != nil { 52 | return nil, err 53 | } 54 | defer func() { _ = prep.Close() }() 55 | 56 | rows, err := prep.QueryContext(context.Background(), "mysql") 57 | if err != nil { 58 | return nil, err 59 | } 60 | 61 | var tablesNames []string 62 | for rows.Next() { 63 | var name sql.NullString 64 | if err := rows.Scan(&name); err != nil { 65 | return nil, err 66 | } 67 | if !name.Valid { 68 | return nil, fmt.Errorf("found entry with null as table name") 69 | } 70 | tablesNames = append(tablesNames, name.String) 71 | } 72 | 73 | return tablesNames, nil 74 | } 75 | 76 | func TestStatement_ExecContext(t *testing.T) { 77 | t.Run("respect timeout", func(t *testing.T) { 78 | dsn := getTCPDSN("", "") 79 | db, err := sql.Open("pxmysql", dsn) 80 | xt.OK(t, err) 81 | defer func() { _ = db.Close() }() 82 | 83 | ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) 84 | defer cancel() 85 | 86 | _, err = db.ExecContext(ctx, "SELECT SLEEP(5)") 87 | xt.KO(t, err) 88 | xt.Assert(t, errors.Is(err, mysqlerrors.ErrContextDeadlineExceeded), err.Error()) 89 | }) 90 | } 91 | 92 | func BenchmarkStatement_QueryContext(b *testing.B) { 93 | b.Run("fetch tables from mysql database", func(b *testing.B) { 94 | if _, err := testOpenQueryRowsClose(); err != nil { 95 | b.Error(err) 96 | } 97 | }) 98 | } 99 | 100 | func TestStatement_QueryContext(t *testing.T) { 101 | t.Run("has rows in result", func(t *testing.T) { 102 | tableNames, err := testOpenQueryRowsClose() 103 | xt.OK(t, err) 104 | 105 | sort.Strings(tableNames) 106 | sum := md5.Sum([]byte(strings.Join(tableNames, " "))) 107 | xt.Eq(t, "859173a1b7b8ef446282e772dcd3039b", hex.EncodeToString(sum[:])) 108 | }) 109 | 110 | t.Run("respect timeout", func(t *testing.T) { 111 | dsn := getTCPDSN("", "") 112 | db, err := sql.Open("pxmysql", dsn) 113 | xt.OK(t, err) 114 | defer func() { _ = db.Close() }() 115 | 116 | ctx, cancel := 
context.WithTimeout(context.Background(), 100*time.Millisecond) 117 | defer cancel() 118 | 119 | _, err = db.QueryContext(ctx, "SELECT SLEEP(5)") 120 | xt.KO(t, err) 121 | xt.Assert(t, errors.Is(err, mysqlerrors.ErrContextDeadlineExceeded), err.Error()) 122 | }) 123 | 124 | t.Run("does not return sql.ErrNoRows", func(t *testing.T) { 125 | db, err := sql.Open("pxmysql", getTCPDSN("", "")) 126 | xt.OK(t, err) 127 | defer func() { _ = db.Close() }() 128 | 129 | stmt := `SELECT TABLE_NAME FROM information_schema.tables WHERE TABLE_SCHEMA = ? ORDER BY TABLE_SCHEMA` 130 | rows, err := db.QueryContext(context.Background(), stmt, "_this_does_not_exists_") 131 | xt.OK(t, err) 132 | xt.Assert(t, !rows.Next(), "expected no rows") 133 | }) 134 | 135 | t.Run("QueryRowContext does return sql.ErrNoRows", func(t *testing.T) { 136 | db, err := sql.Open("pxmysql", getTCPDSN("", "")) 137 | xt.OK(t, err) 138 | defer func() { _ = db.Close() }() 139 | 140 | var name string 141 | stmt := `SELECT TABLE_NAME FROM information_schema.tables WHERE TABLE_SCHEMA = ? 
ORDER BY TABLE_SCHEMA` 142 | err = db.QueryRowContext(context.Background(), stmt, "_this_does_not_exists_").Scan(&name) 143 | xt.KO(t, err) 144 | xt.Assert(t, errors.Is(err, sql.ErrNoRows)) 145 | }) 146 | } 147 | -------------------------------------------------------------------------------- /transaction.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2022, Geert JM Vanderkelen 2 | 3 | package pxmysql 4 | 5 | import ( 6 | "context" 7 | "database/sql/driver" 8 | "fmt" 9 | 10 | "github.com/golistic/pxmysql/xmysql" 11 | ) 12 | 13 | type Transaction struct { 14 | session *xmysql.Session 15 | } 16 | 17 | var _ driver.Tx = &Transaction{} 18 | 19 | func (tx *Transaction) Commit() error { 20 | if tx.session == nil { 21 | return fmt.Errorf("not connected (%w)", driver.ErrBadConn) 22 | } 23 | 24 | if _, err := tx.session.ExecuteStatement(context.Background(), "COMMIT"); err != nil { 25 | return err 26 | } 27 | 28 | return nil 29 | } 30 | 31 | func (tx *Transaction) Rollback() error { 32 | if tx.session == nil { 33 | return fmt.Errorf("not connected (%w)", driver.ErrBadConn) 34 | } 35 | 36 | if _, err := tx.session.ExecuteStatement(context.Background(), "ROLLBACK"); err != nil { 37 | return err 38 | } 39 | 40 | return nil 41 | } 42 | -------------------------------------------------------------------------------- /xmysql/_testdata/.gitignore: -------------------------------------------------------------------------------- 1 | mysql_ca.pem -------------------------------------------------------------------------------- /xmysql/_testdata/base.sql: -------------------------------------------------------------------------------- 1 | CREATE SCHEMA IF NOT EXISTS pxmysql_tests; 2 | CREATE SCHEMA IF NOT EXISTS pxmysql_tests_a; 3 | 4 | -- Following users are used for testing the authentication with 5 | -- MySQL Authentication Plugins. 
-- Fixture for the date and time data types. The script can be re-run: the
-- table is dropped first when it already exists.
USE pxmysql_tests;

DROP TABLE IF EXISTS `data_types_datetime`;

-- One NOT NULL column per temporal type; dt_time and dt_datetime are declared
-- with six fractional-second digits.
CREATE TABLE data_types_datetime
(
    id TINYINT AUTO_INCREMENT,
    dt_date DATE NOT NULL,
    dt_time TIME(6) NOT NULL,
    dt_datetime DATETIME(6) NOT NULL,
    dt_timestamp TIMESTAMP NOT NULL,
    dt_year YEAR NOT NULL,
    PRIMARY KEY (id)
);

-- The session time zone is set to UTC before the rows are inserted; the
-- dt_timestamp values are produced from Unix timestamps with FROM_UNIXTIME().
SET @@time_zone = '+00:00';
INSERT INTO data_types_datetime
VALUES (1, '2005-03-01', '08:00:01.123456', '2005-03-01 07:00:01', FROM_UNIXTIME(1109660401),
        2005),
       (2, '9999-12-31', '838:59:59.0', '9999-12-31 23:59:59.999999', FROM_UNIXTIME(2147483647),
        1901),
       (3, '1000-01-01', '-838:59:59.0', '1000-01-01 00:00:00', FROM_UNIXTIME(1),
        1901);
-- Fixture table for the numeric data types: one column per integer type in
-- signed and unsigned form, plus BIT, BOOL and three DECIMAL precisions.
-- The first four value columns accept NULL; all others are NOT NULL.
CREATE TABLE data_types_numeric
(
    id TINYINT AUTO_INCREMENT,
    numeric_bit BIT(6) NULL,
    numeric_bool BOOL NULL,
    numeric_tinyint TINYINT NULL,
    numeric_tinyint_unsigned TINYINT UNSIGNED NULL,
    numeric_smallint SMALLINT NOT NULL,
    numeric_smallint_unsigned SMALLINT UNSIGNED NOT NULL,
    numeric_mediumint MEDIUMINT NOT NULL,
    numeric_mediumint_unsigned MEDIUMINT UNSIGNED NOT NULL,
    numeric_int INT NOT NULL,
    numeric_int_unsigned INT UNSIGNED NOT NULL,
    numeric_bigint BIGINT NOT NULL,
    numeric_bigint_unsigned BIGINT UNSIGNED NOT NULL,
    numeric_decimal DECIMAL(65, 30) NOT NULL,
    numeric_decimal2 DECIMAL(65, 1) NOT NULL,
    numeric_decimal3 DECIMAL(18, 9) NOT NULL,
    PRIMARY KEY (id)
);
-- Creates the table inserts01 without any rows (presumably the target of the
-- insert tests — confirm against the test code). The script can be re-run:
-- the table is dropped first when it already exists.
USE pxmysql_tests;

DROP TABLE IF EXISTS `inserts01`;

-- id is generated by AUTO_INCREMENT; c1 accepts NULL.
CREATE TABLE inserts01
(
    id TINYINT NOT NULL AUTO_INCREMENT,
    c1 VARCHAR(20),
    PRIMARY KEY (id)
);
-- Temporal types with every value column declared NOT NULL. DATETIME and
-- TIMESTAMP are declared with six fractional-second digits.
CREATE TABLE temporal_not_null
(
    id TINYINT AUTO_INCREMENT,
    datetime_ DATETIME(6) NOT NULL,
    date_ DATE NOT NULL,
    timestamp_ TIMESTAMP(6) NOT NULL,
    year_ YEAR NOT NULL,
    PRIMARY KEY (id)
);

-- Same columns and types as temporal_not_null, but every value column
-- accepts NULL.
CREATE TABLE temporal_null
(
    id TINYINT AUTO_INCREMENT,
    datetime_ DATETIME(6) NULL,
    date_ DATE NULL,
    timestamp_ TIMESTAMP(6) NULL,
    year_ YEAR NULL,
    PRIMARY KEY (id)
);
-- Fixture with one plain table and two tables whose names start with
-- "collection_". The script can be re-run: all three tables are dropped first
-- when they already exist.
USE pxmysql_tests;

DROP TABLE IF EXISTS not_collection_28380dew22;
DROP TABLE IF EXISTS collection_wic28skwixkd;
DROP TABLE IF EXISTS collection_weux73293jsnsj;

-- Plain table with a single INT column (judging by its name, the
-- counter-example that must not be reported as a collection — confirm
-- against the schema tests).
CREATE TABLE not_collection_28380dew22
(
    id INT
);

-- JSON document table: `doc` holds the document, `_id` is a stored generated
-- column extracted from the document's "_id" member and serves as primary
-- key, and the check constraint validates `doc` against `_json_schema`.
CREATE TABLE collection_wic28skwixkd
(
    doc json null,
    _id varbinary(32) as (json_unquote(json_extract(`doc`, _utf8mb4'$._id'))) stored
        primary key,
    _json_schema json as (_utf8mb4'{"type":"object"}'),
    constraint $val_strict_wic28skwixkd
        check (json_schema_valid(`_json_schema`, `doc`))
);

-- Second JSON document table with the same column layout; only the table and
-- constraint names differ.
CREATE TABLE collection_weux73293jsnsj
(
    doc json null,
    _id varbinary(32) as (json_unquote(json_extract(`doc`, _utf8mb4'$._id'))) stored
        primary key,
    _json_schema json as (_utf8mb4'{"type":"object"}'),
    constraint $val_strict_weux73293jsnsj
        check (json_schema_valid(`_json_schema`, `doc`))
);
Username and scrambled password are returned as hex. 20 | // See: https://dev.mysql.com/doc/internals/en/x-protocol-authentication-authentication.html. 21 | func authSHA256Data(an authn) ([]byte, error) { 22 | if len(an.challenge) != authChallengeLen { 23 | return nil, fmt.Errorf("authentication challenge must be 20 bytes (was %d)", len(an.challenge)) 24 | } 25 | 26 | var scramble string 27 | if an.password != "" { 28 | // hex(sha256(password) XOR sha256(challenge + sha256(sha256(password)))) 29 | h1 := sha256.Sum256([]byte(an.password)) 30 | hh1 := sha256.Sum256(h1[:]) 31 | 32 | hr := sha256.New() 33 | hr.Write(hh1[:]) 34 | hr.Write(an.challenge) 35 | h2 := hr.Sum(nil) 36 | 37 | for i := range h2 { 38 | h1[i] ^= h2[i] 39 | } 40 | scramble = fmt.Sprintf("%x", h1) 41 | } 42 | 43 | return []byte(fmt.Sprintf("%s\x00%s\x00%s", an.schema, an.username, scramble)), nil 44 | } 45 | 46 | // authMYSQL41Data prepares authentication data to be sent with the AuthenticateContinue 47 | // message using SHA1 (also known as mysql_native_password). Username and scrambled password 48 | // are returned as hex. 49 | // See: https://dev.mysql.com/doc/internals/en/x-protocol-authentication-authentication.html. 
50 | func authMySQL41Data(an authn) ([]byte, error) { 51 | if len(an.challenge) != authChallengeLen { 52 | return nil, fmt.Errorf("authentication challenge must be 20 bytes (was %d)", len(an.challenge)) 53 | } 54 | 55 | var scramble string 56 | if an.password != "" { 57 | // hex(sha1(password) XOR sha1(challenge + sha1(sha1(password)))) 58 | h1 := sha1.Sum([]byte(an.password)) 59 | hh1 := sha1.Sum(h1[:]) 60 | 61 | hr := sha1.New() 62 | hr.Write(an.challenge) 63 | hr.Write(hh1[:]) 64 | h2 := hr.Sum(nil) 65 | 66 | for i := range h1 { 67 | h1[i] ^= h2[i] 68 | } 69 | 70 | scramble = fmt.Sprintf("*%x", h1) 71 | } 72 | 73 | return []byte(fmt.Sprintf("%s\x00%s\x00%s", an.schema, an.username, scramble)), nil 74 | } 75 | 76 | // authMySQLPlain prepares authentication data to be sent in plain text. This is only 77 | // supported when connection is encrypted (TLS) 78 | // See: https://dev.mysql.com/doc/internals/en/x-protocol-authentication-authentication.html. 79 | func authMySQLPlain(an authn) ([]byte, error) { 80 | return []byte(fmt.Sprintf("%s\x00%s\x00%s", an.schema, an.username, an.password)), nil 81 | } 82 | -------------------------------------------------------------------------------- /xmysql/capabilities.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2022, Geert JM Vanderkelen 2 | 3 | package xmysql 4 | 5 | import ( 6 | "fmt" 7 | 8 | "github.com/golistic/pxmysql/internal/mysqlx/mysqlxconnection" 9 | "github.com/golistic/pxmysql/xmysql/internal/network" 10 | ) 11 | 12 | // ServerCapabilities holds the capabilities returned by the server. 13 | type ServerCapabilities struct { 14 | TLS bool 15 | AuthMechanisms []string 16 | } 17 | 18 | // NewServerCapabilitiesFromMessage instantiates a new ServerCapabilities object 19 | // using a message returned by MySQL Server's X Plugin. 
20 | func NewServerCapabilitiesFromMessage(msg *network.ServerMessage) (*ServerCapabilities, error) { 21 | capabilities := &mysqlxconnection.Capabilities{} 22 | if err := msg.Unmarshall(capabilities); err != nil { 23 | return nil, fmt.Errorf("message was not mysqlxconnection.Capabilities") 24 | } 25 | 26 | sc := &ServerCapabilities{} 27 | 28 | for _, c := range capabilities.Capabilities { 29 | switch c.GetName() { 30 | case "tls": 31 | sc.TLS = c.Value.Scalar.GetVBool() 32 | case "authentication.mechanisms": 33 | sc.AuthMechanisms = []string{} 34 | for _, m := range c.Value.Array.Value { 35 | sc.AuthMechanisms = append(sc.AuthMechanisms, string(m.Scalar.GetVString().GetValue())) 36 | } 37 | } 38 | } 39 | 40 | return sc, nil 41 | } 42 | -------------------------------------------------------------------------------- /xmysql/collations.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2022, 2023, Geert JM Vanderkelen 2 | 3 | package xmysql 4 | 5 | type CollationID int 6 | 7 | type Collation struct { 8 | ID int `json:"id"` 9 | Name string `json:"name"` 10 | CharSet string `json:"charSet"` 11 | } 12 | 13 | // IsSupportedCollation returns whether c is a valid/supported collation. Note that 14 | // MySQL Protocol X only supports the utf8mb4 character set, and consequently, only collations of utf8mb4. 15 | // Argument c can be the internal MYSQL ID or MySQL name. 
16 | func IsSupportedCollation[T string | uint64 | int](c T) bool { 17 | var have bool 18 | switch v := any(c).(type) { 19 | case string: 20 | _, have = Collations[v] 21 | case uint64: 22 | _, have = collationIDs[v] 23 | case int: 24 | _, have = collationIDs[uint64(v)] 25 | } 26 | return have 27 | } 28 | -------------------------------------------------------------------------------- /xmysql/collations_test.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2022, Geert JM Vanderkelen 2 | 3 | package xmysql_test 4 | 5 | import ( 6 | "testing" 7 | 8 | "github.com/golistic/xgo/xt" 9 | 10 | "github.com/golistic/pxmysql/xmysql" 11 | ) 12 | 13 | func TestIsSupportedCollation(t *testing.T) { 14 | t.Run("using MySQL ID", func(t *testing.T) { 15 | xt.Assert(t, xmysql.IsSupportedCollation(241)) 16 | xt.Assert(t, xmysql.IsSupportedCollation(uint64(241))) 17 | }) 18 | 19 | t.Run("using MySQL name", func(t *testing.T) { 20 | xt.Assert(t, xmysql.IsSupportedCollation("utf8mb4_esperanto_ci")) 21 | }) 22 | } 23 | -------------------------------------------------------------------------------- /xmysql/collection.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2023, Geert JM Vanderkelen 2 | 3 | package xmysql 4 | 5 | import ( 6 | "context" 7 | "fmt" 8 | "slices" 9 | ) 10 | 11 | type Collection struct { 12 | schema *Schema 13 | session *Session 14 | name string 15 | } 16 | 17 | var ( 18 | _ adder = (*Collection)(nil) 19 | ) 20 | 21 | // newCollection instantiates a new Collection object with schema. 
22 | func newCollection(schema *Schema, name string) (*Collection, error) { 23 | if schema == nil || schema.session == nil { 24 | return nil, fmt.Errorf("session closed") 25 | } 26 | 27 | if name == "" { 28 | return nil, fmt.Errorf("invalid name") 29 | } 30 | 31 | return &Collection{ 32 | schema: schema, 33 | session: schema.session, 34 | name: name, 35 | }, nil 36 | } 37 | 38 | func (c *Collection) String() string { 39 | return fmt.Sprintf("", c.name, c.schema) 40 | } 41 | 42 | // Name returns the collection. 43 | func (c *Collection) Name() string { 44 | return c.name 45 | } 46 | 47 | func (c *Collection) CheckExistence(ctx context.Context) error { 48 | names, err := c.schema.objectNames(ctx, ObjectCollection) 49 | if err != nil { 50 | return err 51 | } 52 | 53 | if _, ok := slices.BinarySearch(names, c.name); !ok { 54 | return ErrNotAvailable 55 | } 56 | 57 | return nil 58 | } 59 | 60 | func (c *Collection) Add(object ...any) *Add { 61 | 62 | return NewAdd(c).Add(object...) 63 | } 64 | 65 | func (c *Collection) Remove(ctx context.Context) *Add { 66 | 67 | return nil 68 | } 69 | -------------------------------------------------------------------------------- /xmysql/collection/create.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2023, Geert JM Vanderkelen 2 | 3 | package collection 4 | 5 | type CreateOptions struct { 6 | ReuseExisting bool 7 | } 8 | 9 | type CreateOption func(opts *CreateOptions) 10 | 11 | func NewCreateOptions(opts ...CreateOption) *CreateOptions { 12 | options := &CreateOptions{} 13 | 14 | for _, opt := range opts { 15 | opt(options) 16 | } 17 | 18 | return options 19 | } 20 | 21 | func CreateReuseExisting() CreateOption { 22 | return func(opts *CreateOptions) { 23 | opts.ReuseExisting = true 24 | } 25 | } 26 | -------------------------------------------------------------------------------- /xmysql/collection/get.go: 
// GetOptions holds the settings that can be changed through GetOption
// functions.
type GetOptions struct {
	ValidateExistence bool
}

// GetOption changes a setting of GetOptions.
type GetOption func(opts *GetOptions)

// NewGetOptions returns GetOptions with all settings at their zero value,
// after applying each of opts in the order given.
func NewGetOptions(opts ...GetOption) *GetOptions {
	var options GetOptions

	for _, apply := range opts {
		apply(&options)
	}

	return &options
}

// GetValidateExistence returns the option which sets ValidateExistence to
// true.
func GetValidateExistence() GetOption {
	return func(opts *GetOptions) {
		opts.ValidateExistence = true
	}
}
coll.Add(1).GetError() 49 | xt.KO(t, err) 50 | xt.Eq(t, "unsupported object kind int", err.Error()) 51 | }) 52 | 53 | t.Run("object as pointer value", func(t *testing.T) { 54 | xt.OK(t, coll.Add(&Person{Name: "Alice"}).GetError()) 55 | }) 56 | 57 | t.Run("object as value", func(t *testing.T) { 58 | xt.OK(t, coll.Add(Person{Name: "Alice,c"}).GetError()) 59 | }) 60 | 61 | t.Run("execute stores data", func(t *testing.T) { 62 | xt.OK(t, coll. 63 | Add(&Person{Name: "Laurie", Age: 19}). 64 | Add(&Person{Name: "Nadya", Age: 54}, &Person{Name: "Lucas", Age: 32}). 65 | Execute(context.Background())) 66 | exp := []string{"Laurie", "Nadya", "Lucas"} 67 | sort.Strings(exp) 68 | 69 | ses := schema.GetSession() 70 | res, err := ses.ExecuteStatement(context.Background(), "SELECT doc FROM person_2987dk8dj0s") 71 | xt.OK(t, err) 72 | 73 | var got []string 74 | for _, row := range res.Rows { 75 | doc, ok := row.Values[0].(null.Bytes) 76 | xt.Assert(t, ok, "null.Bytes") 77 | p := Person{} 78 | xt.OK(t, json.Unmarshal(doc.Bytes, &p)) 79 | got = append(got, p.Name) 80 | } 81 | sort.Strings(got) 82 | 83 | xt.Eq(t, exp, got) 84 | }) 85 | 86 | t.Run("execute to return error stored by adding", func(t *testing.T) { 87 | adder := coll.Add(&Person{Name: "Laurie", Age: 19}).Add("something not OK") 88 | err := adder.Execute(context.Background()) 89 | xt.KO(t, err) 90 | xt.Eq(t, "unsupported object kind string", errors.Unwrap(err).Error()) 91 | }) 92 | } 93 | -------------------------------------------------------------------------------- /xmysql/connection_config.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2022, Geert JM Vanderkelen 2 | 3 | package xmysql 4 | 5 | import ( 6 | "github.com/golistic/xgo/xstrings" 7 | ) 8 | 9 | // ConnectConfig manages the configuration of a connection to a MySQL server. 
10 | type ConnectConfig struct { 11 | Address string 12 | UnixSockAddr string 13 | Username string 14 | Password *string 15 | Schema string 16 | UseTLS bool 17 | AuthMethod AuthMethodType 18 | TLSServerCACertPath string `envVar:"PXMYSQL_CA_CERT"` 19 | TimeZoneName string 20 | } 21 | 22 | // DefaultConnectConfig is the default configuration used if none is provided 23 | // when a Connection is instantiated. 24 | var DefaultConnectConfig = &ConnectConfig{ 25 | Address: "127.0.0.1:33060", // note that the port number is of X Plugin 26 | Username: "root", 27 | Password: xstrings.Pointer(""), 28 | Schema: "", 29 | UseTLS: false, 30 | AuthMethod: AuthMethodAuto, 31 | } 32 | 33 | // Clone duplicates other, but leaves the password nil. The caller must 34 | // save the password. 35 | func (cfg *ConnectConfig) Clone() *ConnectConfig { 36 | return &ConnectConfig{ 37 | Address: cfg.Address, 38 | UnixSockAddr: cfg.UnixSockAddr, 39 | Username: cfg.Username, 40 | Password: nil, 41 | Schema: cfg.Schema, 42 | UseTLS: cfg.UseTLS, 43 | AuthMethod: cfg.AuthMethod, 44 | TLSServerCACertPath: cfg.TLSServerCACertPath, 45 | TimeZoneName: cfg.TimeZoneName, 46 | } 47 | } 48 | 49 | // SetPassword sets the password within cfg. If no password is provided, 50 | // the Password-field of cfg will be nil. 51 | // Panics when p has more than 1 element. 
52 | func (cfg *ConnectConfig) SetPassword(p ...string) *ConnectConfig { 53 | switch len(p) { 54 | case 1: 55 | cfg.Password = xstrings.Pointer(p[0]) 56 | case 0: 57 | cfg.Password = nil 58 | default: 59 | panic("accepting only 1 optional string") 60 | } 61 | 62 | return cfg 63 | } 64 | -------------------------------------------------------------------------------- /xmysql/context.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2022, Geert JM Vanderkelen 2 | 3 | package xmysql 4 | 5 | import ( 6 | "context" 7 | "time" 8 | ) 9 | 10 | type CtxKey struct{} 11 | 12 | var CtxTimeLocation = &CtxKey{} 13 | 14 | var DefaultTimeLocation = time.UTC 15 | 16 | // SetContextTimeLocation sets the time location used when decoding MySQL DATETIME and 17 | // TIMESTAMP to Go `time.Time` objects. If l is nil, it is unset, and default will 18 | // be used. 19 | func SetContextTimeLocation(ctx context.Context, l *time.Location) context.Context { 20 | return context.WithValue(ctx, CtxTimeLocation, l) 21 | } 22 | 23 | // ContextTimeLocation retrieves the time location set in context used when decoding 24 | // MySQL DATETIME and TIMESTAMP to Go `time.Time`. If none is defined in context, 25 | // or a none `*time.Location` was found, the default will be returned. 
26 | func ContextTimeLocation(ctx context.Context) *time.Location { 27 | if v := ctx.Value(CtxTimeLocation); v != nil { 28 | if l, ok := v.(*time.Location); ok { 29 | return l 30 | } 31 | } 32 | 33 | return DefaultTimeLocation 34 | } 35 | -------------------------------------------------------------------------------- /xmysql/context_test.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2022, Geert JM Vanderkelen 2 | 3 | package xmysql 4 | 5 | import ( 6 | "context" 7 | "testing" 8 | 9 | "github.com/golistic/xgo/xt" 10 | ) 11 | 12 | func TestContextTimeLocation(t *testing.T) { 13 | t.Run("no time location in context", func(t *testing.T) { 14 | xt.Eq(t, DefaultTimeLocation.String(), ContextTimeLocation(context.Background()).String()) 15 | }) 16 | } 17 | -------------------------------------------------------------------------------- /xmysql/crud_add.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2023, Geert JM Vanderkelen 2 | 3 | package xmysql 4 | 5 | import ( 6 | "context" 7 | "fmt" 8 | "reflect" 9 | 10 | "github.com/golistic/xgo/xstrings" 11 | 12 | "github.com/golistic/pxmysql/internal/mysqlx/mysqlxcrud" 13 | mysqlxexpr "github.com/golistic/pxmysql/internal/mysqlx/mysqlxexpr" 14 | "github.com/golistic/pxmysql/xmysql/xproto" 15 | ) 16 | 17 | type cruder interface { 18 | Execute(ctx context.Context) error 19 | GetError() error 20 | } 21 | 22 | type adder interface { 23 | Add(object ...any) *Add 24 | } 25 | 26 | type Add struct { 27 | collection *Collection 28 | values []any 29 | err error 30 | } 31 | 32 | var ( 33 | _ cruder = (*Add)(nil) 34 | _ adder = (*Add)(nil) 35 | ) 36 | 37 | func NewAdd(c *Collection) *Add { 38 | 39 | return &Add{collection: c} 40 | } 41 | 42 | // Add adds object to the queue. 
43 | func (a *Add) Add(objects ...any) *Add { 44 | 45 | for _, object := range objects { 46 | rt := reflect.TypeOf(object) 47 | if reflect.ValueOf(object).Kind() == reflect.Pointer { 48 | rt = rt.Elem() 49 | } 50 | if rt.Kind() != reflect.Struct { 51 | a.err = fmt.Errorf("unsupported object kind %s", rt.Kind()) 52 | } 53 | 54 | a.values = append(a.values, object) 55 | } 56 | 57 | return a 58 | } 59 | 60 | func (a *Add) Execute(ctx context.Context) error { 61 | 62 | errBaseMsg := "adding to collection %s (%w)" 63 | 64 | if a.err != nil { 65 | return fmt.Errorf(errBaseMsg, a.collection.name, a.err) 66 | } 67 | 68 | rows := make([]*mysqlxcrud.Insert_TypedRow, len(a.values)) 69 | 70 | for i, v := range a.values { 71 | rows[i] = &mysqlxcrud.Insert_TypedRow{ 72 | Field: []*mysqlxexpr.Expr{ 73 | { 74 | Type: mysqlxexpr.Expr_OBJECT.Enum(), 75 | Object: xproto.StructExpr(v), 76 | }, 77 | }, 78 | } 79 | } 80 | 81 | msg := &mysqlxcrud.Insert{ 82 | Collection: &mysqlxcrud.Collection{ 83 | Name: xstrings.Pointer(a.collection.Name()), 84 | Schema: xstrings.Pointer(a.collection.schema.Name()), 85 | }, 86 | DataModel: mysqlxcrud.DataModel_DOCUMENT.Enum(), 87 | Projection: nil, 88 | Row: rows, 89 | } 90 | 91 | ses := a.collection.schema.GetSession() 92 | if err := ses.Write(ctx, msg); err != nil { 93 | return fmt.Errorf(errBaseMsg, a.collection.name, err) 94 | } 95 | 96 | _, err := ses.handleResult(ctx, func(r *Result) bool { 97 | return r.stmtOK 98 | }) 99 | if err != nil { 100 | return fmt.Errorf(errBaseMsg, a.collection.name, err) 101 | } 102 | 103 | return nil 104 | } 105 | 106 | func (a *Add) GetError() error { 107 | 108 | return a.err 109 | } 110 | -------------------------------------------------------------------------------- /xmysql/defaults.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2023, Geert JM Vanderkelen 2 | 3 | package xmysql 4 | 5 | const authChallengeLen = 20 6 | 7 | const ( 8 | AuthMethodPlain 
AuthMethodType = "PLAIN" 9 | AuthMethodAuto AuthMethodType = "AUTO" 10 | AuthMethodSHA256Memory AuthMethodType = "SHA256_MEMORY" 11 | AuthMethodMySQL41 AuthMethodType = "MYSQL41" 12 | ) 13 | 14 | const DefaultPort = "33060" 15 | const DefaultHost = "127.0.0.1" 16 | 17 | type AuthMethodType string 18 | 19 | type AuthMethodTypes []AuthMethodType 20 | 21 | func (a AuthMethodTypes) Has(m AuthMethodType) bool { 22 | for _, v := range a { 23 | if v == m { 24 | return true 25 | } 26 | } 27 | return false 28 | } 29 | 30 | var defaultAuthMethods = []AuthMethodType{AuthMethodMySQL41, AuthMethodSHA256Memory} 31 | 32 | var supportedAuthMethods = AuthMethodTypes{AuthMethodSHA256Memory, AuthMethodMySQL41, AuthMethodPlain, AuthMethodAuto} 33 | 34 | func DefaultAuthMethods() []AuthMethodType { 35 | return defaultAuthMethods 36 | } 37 | 38 | func SupportedAuthMethods() AuthMethodTypes { 39 | return supportedAuthMethods 40 | } 41 | -------------------------------------------------------------------------------- /xmysql/errors.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2023, Geert JM Vanderkelen 2 | 3 | package xmysql 4 | 5 | import "fmt" 6 | 7 | var ErrNotAvailable = fmt.Errorf("not available") 8 | -------------------------------------------------------------------------------- /xmysql/examples_test.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2022, Geert JM Vanderkelen 2 | 3 | package xmysql_test 4 | 5 | import ( 6 | "context" 7 | "fmt" 8 | "log" 9 | 10 | "github.com/golistic/pxmysql/null" 11 | "github.com/golistic/pxmysql/xmysql" 12 | ) 13 | 14 | func ExampleGetSession_auto_notls() { 15 | config := &xmysql.ConnectConfig{ 16 | Address: "127.0.0.1:53360", // see _support/pxmysql-compose/docker-compose.yml 17 | Username: "user_native", 18 | } 19 | config.SetPassword("pwd_user_native") 20 | 21 | session, err := xmysql.GetSession(context.Background(), config) 22 | 
if err != nil { 23 | log.Fatal(err) 24 | } 25 | fmt.Println("TLS:", session.UsesTLS()) 26 | // Output: TLS: false 27 | } 28 | 29 | func ExampleGetSession_plain_withtls() { 30 | config := &xmysql.ConnectConfig{ 31 | Address: "127.0.0.1:53360", // see _support/pxmysql-compose/docker-compose.yml 32 | AuthMethod: xmysql.AuthMethodPlain, 33 | UseTLS: true, 34 | Username: "user_native", 35 | } 36 | config.SetPassword("pwd_user_native") 37 | 38 | session, err := xmysql.GetSession(context.Background(), config) 39 | if err != nil { 40 | log.Fatal(err) 41 | } 42 | 43 | fmt.Println("TLS:", session.UsesTLS()) 44 | fmt.Println("Auth Method:", config.AuthMethod) 45 | // Output: 46 | // TLS: true 47 | // Auth Method: PLAIN 48 | } 49 | 50 | func ExampleSession_ExecuteStatement() { 51 | config := &xmysql.ConnectConfig{ 52 | Address: "127.0.0.1:53360", // see _support/pxmysql-compose/docker-compose.yml 53 | AuthMethod: xmysql.AuthMethodPlain, 54 | UseTLS: true, 55 | Username: "user_native", 56 | } 57 | config.SetPassword("pwd_user_native") 58 | 59 | session, err := xmysql.GetSession(context.Background(), config) 60 | if err != nil { 61 | log.Fatal(err) 62 | } 63 | 64 | q := "SELECT ?, STR_TO_DATE('2005-03-01 07:00:01', '%Y-%m-%d %H:%i:%s')" 65 | res, err := session.ExecuteStatement(context.Background(), q, "started") 66 | if err != nil { 67 | log.Fatal(err) 68 | } 69 | 70 | for _, row := range res.Rows { 71 | fmt.Printf("%s at %s\n", row.Values[0].(string), row.Values[1].(null.Time).Time) 72 | } 73 | 74 | // Output: 75 | // started at 2005-03-01 07:00:01 +0000 UTC 76 | } 77 | -------------------------------------------------------------------------------- /xmysql/internal/network/messages.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2023, Geert JM Vanderkelen 2 | 3 | package network 4 | 5 | import ( 6 | "database/sql/driver" 7 | "encoding/binary" 8 | "errors" 9 | "fmt" 10 | "io" 11 | "os" 12 | 13 | 
"google.golang.org/protobuf/proto" 14 | 15 | "github.com/golistic/pxmysql/interfaces" 16 | "github.com/golistic/pxmysql/internal/mysqlx/mysqlx" 17 | "github.com/golistic/pxmysql/internal/mysqlx/mysqlxconnection" 18 | "github.com/golistic/pxmysql/internal/mysqlx/mysqlxcrud" 19 | "github.com/golistic/pxmysql/internal/mysqlx/mysqlxprepare" 20 | "github.com/golistic/pxmysql/internal/mysqlx/mysqlxsession" 21 | "github.com/golistic/pxmysql/internal/mysqlx/mysqlxsql" 22 | "github.com/golistic/pxmysql/mysqlerrors" 23 | ) 24 | 25 | type ServerMessage struct { 26 | msgType int 27 | payload []byte 28 | } 29 | 30 | var _ interfaces.ServerMessager = &ServerMessage{} 31 | 32 | var maxServerMessageType int32 33 | 34 | func init() { 35 | for n := range mysqlx.ServerMessages_Type_name { 36 | if n > maxServerMessageType { 37 | maxServerMessageType = n 38 | } 39 | } 40 | } 41 | 42 | func (m *ServerMessage) Unmarshall(into proto.Message) error { 43 | if err := UnmarshalPartial(m.payload, into); err != nil { 44 | return fmt.Errorf("failed unmarshalling server message type %s (%w)", 45 | mysqlx.ServerMessages_Type(m.msgType).String(), err) 46 | } 47 | return nil 48 | } 49 | 50 | func (m *ServerMessage) ServerMessageType() mysqlx.ServerMessages_Type { 51 | return mysqlx.ServerMessages_Type(m.msgType) 52 | } 53 | 54 | func readMessage(r io.Reader) (*ServerMessage, error) { 55 | var header [5]byte 56 | if n, err := io.ReadFull(r, header[:]); err != nil { 57 | if errors.Is(err, os.ErrDeadlineExceeded) { 58 | err = os.ErrDeadlineExceeded 59 | } 60 | if errors.Is(err, io.EOF) && n < 5 { 61 | return nil, fmt.Errorf("broken pipe when reading (%w)", driver.ErrBadConn) 62 | } 63 | return nil, fmt.Errorf("failed reading message header (%w)", err) 64 | } 65 | 66 | if header[4] == 0x0a || int32(header[4]) > maxServerMessageType { 67 | return nil, mysqlerrors.New(2007) 68 | } 69 | 70 | msg := &ServerMessage{ 71 | msgType: int(header[4]), 72 | } 73 | 74 | msg.payload = make([]byte, 
binary.LittleEndian.Uint32(header[0:4])-1) 75 | if _, err := io.ReadFull(r, msg.payload); err != nil { 76 | if errors.Is(err, os.ErrDeadlineExceeded) { 77 | err = os.ErrDeadlineExceeded 78 | } 79 | return nil, fmt.Errorf("failed reading message payload (%w)", err) 80 | } 81 | 82 | return msg, nil 83 | } 84 | 85 | func clientMessageType(msg proto.Message) (mysqlx.ClientMessages_Type, error) { 86 | // cases ordered as ClientMessage_Type constants 87 | switch msg.(type) { 88 | case *mysqlxconnection.CapabilitiesGet: 89 | return mysqlx.ClientMessages_CON_CAPABILITIES_GET, nil 90 | case *mysqlxconnection.CapabilitiesSet: 91 | return mysqlx.ClientMessages_CON_CAPABILITIES_SET, nil 92 | case *mysqlxconnection.Close: 93 | return mysqlx.ClientMessages_CON_CLOSE, nil 94 | 95 | case *mysqlxprepare.Execute: 96 | return mysqlx.ClientMessages_PREPARE_EXECUTE, nil 97 | case *mysqlxprepare.Prepare: 98 | return mysqlx.ClientMessages_PREPARE_PREPARE, nil 99 | case *mysqlxprepare.Deallocate: 100 | return mysqlx.ClientMessages_PREPARE_DEALLOCATE, nil 101 | 102 | case *mysqlxsession.AuthenticateStart: 103 | return mysqlx.ClientMessages_SESS_AUTHENTICATE_START, nil 104 | case *mysqlxsession.AuthenticateContinue: 105 | return mysqlx.ClientMessages_SESS_AUTHENTICATE_CONTINUE, nil 106 | case *mysqlxsession.Reset: 107 | return mysqlx.ClientMessages_SESS_RESET, nil 108 | case *mysqlxsession.Close: 109 | return mysqlx.ClientMessages_SESS_CLOSE, nil 110 | 111 | case *mysqlxcrud.Insert: 112 | return mysqlx.ClientMessages_CRUD_INSERT, nil 113 | 114 | case *mysqlxsql.StmtExecute: 115 | return mysqlx.ClientMessages_SQL_STMT_EXECUTE, nil 116 | default: 117 | return 0, fmt.Errorf("unsupported message '%T'", msg) 118 | } 119 | } 120 | -------------------------------------------------------------------------------- /xmysql/internal/network/proto.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2023, Geert JM Vanderkelen 2 | 3 | package network 4 
| 5 | import ( 6 | "google.golang.org/protobuf/encoding/protowire" 7 | "google.golang.org/protobuf/proto" 8 | ) 9 | 10 | const NamespaceMySQLx = "mysqlx" 11 | 12 | // UnmarshalPartial parses the wire-format message in b and places the result in m. 13 | // The provided message must be mutable (e.g., a non-nil pointer to a message). 14 | // This is the same function as proto.Unmarshall except that AllowPartial option set to true. 15 | func UnmarshalPartial(b []byte, m proto.Message) error { 16 | return proto.UnmarshalOptions{ 17 | RecursionLimit: protowire.DefaultRecursionLimit, 18 | AllowPartial: true, 19 | }.Unmarshal(b, m) 20 | } 21 | -------------------------------------------------------------------------------- /xmysql/internal/network/read.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2023, Geert JM Vanderkelen 2 | 3 | package network 4 | 5 | import ( 6 | "context" 7 | "errors" 8 | "fmt" 9 | "io" 10 | "net" 11 | "time" 12 | 13 | "github.com/golistic/pxmysql/mysqlerrors" 14 | ) 15 | 16 | // Read reads a message from the network connection conn. 17 | // If no deadline is set in ctx, a default will be used. 
18 | func Read(ctx context.Context, conn net.Conn) (*ServerMessage, error) { 19 | 20 | deadline, ok := ctx.Deadline() 21 | if !ok { 22 | deadline = time.Now().Add(10 * time.Second) 23 | } 24 | 25 | if err := conn.SetReadDeadline(deadline); err != nil { 26 | return nil, fmt.Errorf("setting read deadline (%w)", err) 27 | } 28 | 29 | msg, err := readMessage(conn) 30 | if err != nil { 31 | if err == io.EOF { 32 | return nil, io.EOF 33 | } 34 | 35 | var myErr *mysqlerrors.Error 36 | if errors.As(err, &myErr) { 37 | return nil, myErr 38 | } 39 | 40 | return nil, err 41 | } 42 | 43 | Trace("r", msg) 44 | 45 | return msg, nil 46 | } 47 | -------------------------------------------------------------------------------- /xmysql/internal/network/tls.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2023, Geert JM Vanderkelen 2 | 3 | package network 4 | 5 | import ( 6 | "crypto/x509" 7 | "fmt" 8 | "os" 9 | "sync" 10 | ) 11 | 12 | var ServerCAPool *x509.CertPool 13 | var muServerCAPool sync.RWMutex 14 | 15 | func init() { 16 | ServerCAPool = x509.NewCertPool() 17 | } 18 | 19 | func addServerCACert(certs []byte) error { 20 | 21 | muServerCAPool.Lock() 22 | defer muServerCAPool.Unlock() 23 | 24 | if ok := ServerCAPool.AppendCertsFromPEM(certs); !ok { 25 | return fmt.Errorf("appending CA certificate to pool") 26 | } 27 | return nil 28 | } 29 | 30 | func AddServerCACertFromFile(filename string) error { 31 | 32 | certs, err := os.ReadFile(filename) 33 | if err != nil { 34 | return fmt.Errorf("reading server CA certificate (%w)", err) 35 | } 36 | 37 | return addServerCACert(certs) 38 | } 39 | -------------------------------------------------------------------------------- /xmysql/internal/network/trace.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2023, Geert JM Vanderkelen 2 | 3 | package network 4 | 5 | import ( 6 | "encoding/json" 7 | "fmt" 8 | "os" 9 | 10 | 
"google.golang.org/protobuf/proto" 11 | 12 | "github.com/golistic/pxmysql/internal/mysqlx/mysqlxnotice" 13 | "github.com/golistic/pxmysql/internal/mysqlx/mysqlxsql" 14 | ) 15 | 16 | var ( 17 | traceReadWrites bool 18 | TraceValues bool 19 | ) 20 | 21 | func init() { 22 | _, traceReadWrites = os.LookupEnv("PXMYSQL_TRACE") 23 | _, TraceValues = os.LookupEnv("PXMYSQL_TRACE_VALUES") 24 | if !traceReadWrites { 25 | TraceValues = false 26 | } 27 | } 28 | 29 | // Trace is used for debugging and is enabled by setting the PYMYSQL_TRACE 30 | // environment variable. 31 | func Trace(action string, msg any, a ...any) { 32 | if !traceReadWrites || msg == nil { 33 | return 34 | } 35 | 36 | var indicator string 37 | 38 | switch action { 39 | case "w", "write": 40 | indicator = "\n> write:" 41 | case "r", "read": 42 | indicator = "< read" 43 | case "un", "unhandled": 44 | indicator = "< unhandled " 45 | case "error": 46 | indicator = "< ERROR " 47 | case "state": 48 | indicator = "\t< STATE " 49 | default: 50 | indicator = "< unknown" 51 | } 52 | 53 | prefix := "\t" 54 | 55 | var s string 56 | var topic string 57 | switch v := msg.(type) { 58 | case *ServerMessage: 59 | topic = v.ServerMessageType().String() 60 | case *mysqlxnotice.SessionStateChanged: 61 | topic = v.GetParam().String() 62 | doc, err := json.MarshalIndent(v.Value, prefix, " ") 63 | if err != nil { 64 | panic(err) 65 | } 66 | if doc[1] != '}' { 67 | s = fmt.Sprintf(" %s\n", string(doc)) 68 | } 69 | case *mysqlxsql.StmtExecute: 70 | s = " SQL Statement: " + string(v.Stmt) + "\n" 71 | case proto.Message: 72 | topic = string(v.ProtoReflect().Descriptor().Name()) 73 | doc, err := json.MarshalIndent(v, prefix, " ") 74 | if err != nil { 75 | panic(err) 76 | } 77 | if doc[1] != '}' { 78 | s = fmt.Sprintf(prefix+"%s\n", string(doc)) 79 | } 80 | case string: 81 | topic = v 82 | default: 83 | topic = fmt.Sprintf("unhandled %T", msg) 84 | } 85 | 86 | _, err := fmt.Fprintf(os.Stderr, indicator+" "+topic+"\n"+s) 87 | if err != 
nil { 88 | panic(err) 89 | } 90 | 91 | if len(a) > 0 { 92 | _, err := fmt.Fprintf(os.Stderr, prefix+fmt.Sprint(a...)+"\n") 93 | if err != nil { 94 | panic(err) 95 | } 96 | } 97 | } 98 | -------------------------------------------------------------------------------- /xmysql/internal/network/write.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2023, Geert JM Vanderkelen 2 | 3 | package network 4 | 5 | import ( 6 | "context" 7 | "database/sql/driver" 8 | "encoding/binary" 9 | "errors" 10 | "fmt" 11 | "net" 12 | "syscall" 13 | 14 | "google.golang.org/protobuf/proto" 15 | 16 | "github.com/golistic/pxmysql/mysqlerrors" 17 | ) 18 | 19 | // Write writes protobuf msg using conn to the server. 20 | func Write(ctx context.Context, conn net.Conn, msg proto.Message, maxAllowedPacket int) error { 21 | 22 | if conn == nil { 23 | return fmt.Errorf("not connected (%w)", driver.ErrBadConn) 24 | } 25 | 26 | msgType, err := clientMessageType(msg) 27 | if err != nil { 28 | return err 29 | } 30 | 31 | deadline, _ := ctx.Deadline() 32 | if err := conn.SetWriteDeadline(deadline); err != nil { 33 | return fmt.Errorf("failed setting write deadline (%w)", err) 34 | } 35 | 36 | b, err := proto.Marshal(msg) 37 | if err != nil { 38 | return fmt.Errorf("failed marshalling protobuf message (%w)", err) 39 | } 40 | 41 | if maxAllowedPacket > 0 && len(b) > maxAllowedPacket { 42 | return mysqlerrors.New(mysqlerrors.ClientNetPacketTooLarge) 43 | } 44 | 45 | var header [5]byte 46 | binary.LittleEndian.PutUint32(header[:], uint32(len(b))+1) // +1 is final \x00 47 | 48 | header[4] = byte(msgType) 49 | 50 | buf := &net.Buffers{header[:], b} 51 | _, err = buf.WriteTo(conn) 52 | switch { 53 | case errors.Is(err, syscall.EPIPE): 54 | return fmt.Errorf("broken pipe when writing (%w)", driver.ErrBadConn) 55 | case err != nil: 56 | return fmt.Errorf("failed sending message (%w)", err) 57 | } 58 | 59 | Trace("w", msg) 60 | 61 | return nil 62 | } 63 | 
-------------------------------------------------------------------------------- /xmysql/internal/statements/quote.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2023, Geert JM Vanderkelen 2 | 3 | package statements 4 | 5 | import ( 6 | "fmt" 7 | "strings" 8 | ) 9 | 10 | // QuoteValue quotes p so that it can be safely used to substituted placeholders 11 | // within a SQL query. 12 | func QuoteValue(p any) (string, error) { 13 | 14 | switch v := p.(type) { 15 | case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64: 16 | return fmt.Sprintf("%d", v), nil 17 | case []byte: 18 | return fmt.Sprintf("_binary'%x'", v), nil 19 | case float32, float64: 20 | return fmt.Sprintf("%f", v), nil 21 | case string: 22 | return "'" + strings.ReplaceAll(v, "'", `\'`) + "'", nil 23 | default: 24 | return "", fmt.Errorf("cannot quote parameter with value type %T", p) 25 | } 26 | } 27 | 28 | func QuoteIdentifier(p string) (string, error) { 29 | return "`" + strings.Replace(p, "`", "``", -1) + "`", nil 30 | } 31 | -------------------------------------------------------------------------------- /xmysql/internal/statements/quote_test.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2023, Geert JM Vanderkelen 2 | 3 | package statements_test 4 | 5 | import ( 6 | "strconv" 7 | "testing" 8 | 9 | "github.com/golistic/xgo/xt" 10 | 11 | "github.com/golistic/pxmysql/xmysql/internal/statements" 12 | ) 13 | 14 | func TestQuoteValue(t *testing.T) { 15 | t.Run("strings", func(t *testing.T) { 16 | var cases = []struct { 17 | got string 18 | exp string 19 | }{ 20 | { 21 | got: "Gopher", 22 | exp: "'Gopher'", 23 | }, 24 | { 25 | got: "'Gopher'", 26 | exp: `'\'Gopher\''`, 27 | }, 28 | { 29 | got: "'poop'; DROP TABLE gophers", 30 | exp: `'\'poop\'; DROP TABLE gophers'`, 31 | }, 32 | { 33 | got: "🐰", 34 | exp: `'🐰'`, 35 | }, 36 | } 37 | 38 | for i, c := range cases { 39 | 
t.Run(strconv.Itoa(i+1), func(t *testing.T) { 40 | got, err := statements.QuoteValue(c.got) 41 | xt.OK(t, err) 42 | xt.Eq(t, c.exp, got) 43 | }) 44 | } 45 | }) 46 | } 47 | -------------------------------------------------------------------------------- /xmysql/internal/statements/statement_test.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2023, Geert JM Vanderkelen 2 | 3 | package statements_test 4 | 5 | import ( 6 | "testing" 7 | 8 | "github.com/golistic/xgo/xt" 9 | 10 | "github.com/golistic/pxmysql/xmysql/internal/statements" 11 | ) 12 | 13 | func TestPlaceholderIndexes(t *testing.T) { 14 | var cases = []struct { 15 | stmt string 16 | exp []int 17 | }{ 18 | { 19 | stmt: `SELECT ?`, 20 | exp: []int{7}, 21 | }, 22 | { 23 | stmt: `SELECT ?, '?', "?"`, 24 | exp: []int{7}, 25 | }, 26 | { 27 | stmt: `SELECT ?, '?', "?", ?`, 28 | exp: []int{7, 20}, 29 | }, 30 | 31 | { 32 | stmt: `SELECT ?, '?', "?", ?, "'?'", ?`, 33 | exp: []int{7, 20, 30}, 34 | }, 35 | } 36 | 37 | for _, c := range cases { 38 | t.Run("", func(t *testing.T) { 39 | xt.Eq(t, c.exp, statements.PlaceholderIndexes(statements.Placeholder, c.stmt)) 40 | }) 41 | } 42 | } 43 | -------------------------------------------------------------------------------- /xmysql/internal/statements/statements.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2023, Geert JM Vanderkelen 2 | 3 | package statements 4 | 5 | import ( 6 | "bytes" 7 | "fmt" 8 | 9 | "github.com/golistic/xgo/xmath" 10 | ) 11 | 12 | const Placeholder = '?' 13 | 14 | // SubstitutePlaceholders replaces the placeholders within stmt with respective element of args. 
15 | func SubstitutePlaceholders(stmt string, args ...any) (string, error) { 16 | 17 | placeholders := PlaceholderIndexes(Placeholder, stmt) 18 | if len(placeholders) != len(args) { 19 | return "", fmt.Errorf("need %d placeholder(s); found %d)", len(args), len(placeholders)) 20 | } 21 | 22 | var nextArg int 23 | var buf []byte 24 | 25 | var index int 26 | for _, ph := range placeholders { 27 | buf = append(buf, stmt[index:ph]...) 28 | 29 | arg := args[nextArg] 30 | nextArg++ 31 | index = ph + 1 32 | 33 | if arg == nil { 34 | buf = append(buf, "NULL"...) 35 | continue 36 | } 37 | 38 | quoted, err := QuoteValue(arg) 39 | if err != nil { 40 | return "", err 41 | } 42 | buf = append(buf, quoted...) 43 | } 44 | 45 | // rest of stmt 46 | buf = append(buf, stmt[index:]...) 47 | 48 | if len(args) > nextArg { 49 | return "", fmt.Errorf("%d argument(s) not substituted", xmath.AbsInt(len(args)-nextArg)) 50 | } else if len(args) < nextArg { 51 | return "", fmt.Errorf("%d placeholder(s) not substituted", xmath.AbsInt(len(args)-nextArg)) 52 | } 53 | 54 | return string(buf), nil 55 | } 56 | 57 | // PlaceholderIndexes returns the indices of all placeholders within query. 
58 | func PlaceholderIndexes(placeholder rune, query string) []int { 59 | 60 | var indexes []int 61 | 62 | var quoted bool 63 | var quote rune 64 | for i, r := range bytes.Runes([]byte(query)) { 65 | // we skip quoted so that we support queries which have placeholder in string literals 66 | if r == '"' || r == '\'' { 67 | if quoted && quote == r { 68 | quoted = false 69 | quote = 0 70 | continue 71 | } else if !quoted { 72 | quoted = true 73 | quote = r 74 | continue 75 | } 76 | } 77 | 78 | if quoted { 79 | continue 80 | } 81 | 82 | if r == placeholder { 83 | indexes = append(indexes, i) 84 | } 85 | } 86 | 87 | return indexes 88 | } 89 | -------------------------------------------------------------------------------- /xmysql/main_test.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2022, Geert JM Vanderkelen 2 | 3 | package xmysql_test 4 | 5 | import ( 6 | "fmt" 7 | "os" 8 | "runtime/debug" 9 | "strconv" 10 | "testing" 11 | 12 | "github.com/golistic/pxmysql/internal/xxt" 13 | ) 14 | 15 | var ( 16 | testExitCode int 17 | testErr error 18 | testSchema = "pxmysql_tests" 19 | testContext *xxt.TestContext 20 | ) 21 | 22 | var ( 23 | testMySQLMaxAllowedPacket = -1 // MySQL's mysqlx_max_allowed_packet 24 | ) 25 | 26 | func testTearDown() { 27 | if testErr != nil { 28 | testExitCode = 1 29 | fmt.Println(testErr) 30 | } 31 | } 32 | 33 | func TestMain(m *testing.M) { 34 | defer func() { os.Exit(testExitCode) }() 35 | defer testTearDown() 36 | defer func() { 37 | if r := recover(); r != nil { 38 | fmt.Println(string(debug.Stack())) 39 | os.Exit(1) 40 | } 41 | }() 42 | 43 | var err error 44 | if testContext, testErr = xxt.New(testSchema); err != nil { 45 | return 46 | } 47 | 48 | if err := testContext.Server.LoadSQLScript("base"); err != nil { 49 | testErr = fmt.Errorf("failed testing MySQL running in container %s (%s)", 50 | testContext.Server.Container.Name, err) 51 | return 52 | } 53 | 54 | if err := 
testContext.Server.Container.CopyFileFromContainer( 55 | "/etc/mysql/conf.d/ca.pem", "_testdata/mysql_ca.pem"); err != nil { 56 | testErr = fmt.Errorf("failed copying MySQL CA certificate from container %s (%s)", 57 | testContext.Server.Container.Name, err) 58 | return 59 | } 60 | 61 | if v, err := testContext.Server.Variable("global", "mysqlx_max_allowed_packet"); err != nil { 62 | testErr = fmt.Errorf("failed getting variable mysqlx_max_allowed_packet (%s)", err) 63 | return 64 | } else { 65 | n, err := strconv.ParseInt(v, 10, 32) 66 | if err != nil { 67 | testErr = fmt.Errorf("failed converting variable mysqlx_max_allowed_packet (%s)", err) 68 | return 69 | } 70 | testMySQLMaxAllowedPacket = int(n) 71 | } 72 | 73 | testExitCode = m.Run() 74 | } 75 | -------------------------------------------------------------------------------- /xmysql/notice.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2022, 2023, Geert JM Vanderkelen 2 | 3 | package xmysql 4 | 5 | import ( 6 | "fmt" 7 | 8 | "github.com/golistic/pxmysql/internal/mysqlx/mysqlxnotice" 9 | "github.com/golistic/pxmysql/xmysql/internal/network" 10 | ) 11 | 12 | type StateChanges struct { 13 | ClientID uint64 14 | GeneratedInsertID uint64 15 | RowsAffected uint64 16 | CurrentSchema string 17 | ProducedMessage string 18 | } 19 | 20 | type notices struct { 21 | warnings []*mysqlxnotice.Warning 22 | sessionVariableChanges []*mysqlxnotice.SessionVariableChanged 23 | sessionStateChanges []*mysqlxnotice.SessionStateChanged 24 | groupReplicationStateChanges []*mysqlxnotice.GroupReplicationStateChanged 25 | serverHello *mysqlxnotice.ServerHello 26 | unhandled []mysqlxnotice.Frame_Type 27 | stateChanges StateChanges 28 | } 29 | 30 | func (n *notices) add(msg *network.ServerMessage) error { 31 | frame := &mysqlxnotice.Frame{} 32 | if err := msg.Unmarshall(frame); err != nil { 33 | return fmt.Errorf("failed unmarshalling notice message (%w)", err) 34 | } 35 | 36 | 
switch mysqlxnotice.Frame_Type(frame.GetType()) { 37 | case mysqlxnotice.Frame_WARNING: 38 | m := &mysqlxnotice.Warning{} 39 | if err := msg.Unmarshall(m); err != nil { 40 | return err 41 | } 42 | n.warnings = append(n.warnings, m) 43 | case mysqlxnotice.Frame_SESSION_VARIABLE_CHANGED: 44 | m := &mysqlxnotice.SessionVariableChanged{} 45 | if err := msg.Unmarshall(m); err != nil { 46 | return fmt.Errorf("failed unmarshalling '%s' (%w)", m.String(), err) 47 | } 48 | n.sessionVariableChanges = append(n.sessionVariableChanges, m) 49 | case mysqlxnotice.Frame_SESSION_STATE_CHANGED: 50 | m := &mysqlxnotice.SessionStateChanged{} 51 | if err := network.UnmarshalPartial(frame.Payload, m); err != nil { 52 | return fmt.Errorf("failed unmarshalling '%s' (%w)", m.String(), err) 53 | } 54 | network.Trace("state", m) 55 | 56 | switch m.GetParam() { 57 | case mysqlxnotice.SessionStateChanged_GENERATED_INSERT_ID: 58 | if len(m.Value) > 0 { 59 | n.stateChanges.GeneratedInsertID = m.Value[0].GetVUnsignedInt() 60 | } 61 | case mysqlxnotice.SessionStateChanged_ROWS_AFFECTED: 62 | if len(m.Value) > 0 { 63 | n.stateChanges.RowsAffected = m.Value[0].GetVUnsignedInt() 64 | } 65 | case mysqlxnotice.SessionStateChanged_CURRENT_SCHEMA: 66 | if len(m.Value) > 0 { 67 | n.stateChanges.CurrentSchema = string(m.Value[0].VString.Value) 68 | } 69 | case mysqlxnotice.SessionStateChanged_PRODUCED_MESSAGE: 70 | if len(m.Value) > 0 { 71 | n.stateChanges.ProducedMessage = string(m.Value[0].VString.Value) 72 | } 73 | case mysqlxnotice.SessionStateChanged_CLIENT_ID_ASSIGNED: 74 | if len(m.Value) > 0 { 75 | n.stateChanges.ClientID = m.Value[0].GetVUnsignedInt() 76 | } 77 | } 78 | 79 | n.sessionStateChanges = append(n.sessionStateChanges, m) 80 | case mysqlxnotice.Frame_GROUP_REPLICATION_STATE_CHANGED: 81 | m := &mysqlxnotice.GroupReplicationStateChanged{} 82 | if err := msg.Unmarshall(m); err != nil { 83 | return fmt.Errorf("failed unmarshalling '%s' (%w)", m.String(), err) 84 | } 85 | 
n.groupReplicationStateChanges = append(n.groupReplicationStateChanges, m) 86 | case mysqlxnotice.Frame_SERVER_HELLO: 87 | m := &mysqlxnotice.ServerHello{} 88 | if err := msg.Unmarshall(m); err != nil { 89 | return fmt.Errorf("failed unmarshalling '%s' (%w)", m.String(), err) 90 | } 91 | n.serverHello = m 92 | default: 93 | n.unhandled = append(n.unhandled, mysqlxnotice.Frame_Type(frame.GetType())) 94 | } 95 | 96 | return nil 97 | } 98 | -------------------------------------------------------------------------------- /xmysql/prepared.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2022, Geert JM Vanderkelen 2 | 3 | package xmysql 4 | 5 | import ( 6 | "context" 7 | "database/sql/driver" 8 | "fmt" 9 | "strings" 10 | "time" 11 | 12 | "github.com/golistic/pxmysql/decimal" 13 | "github.com/golistic/pxmysql/internal/mysqlx/mysqlxdatatypes" 14 | "github.com/golistic/pxmysql/internal/mysqlx/mysqlxprepare" 15 | "github.com/golistic/pxmysql/xmysql/xproto" 16 | ) 17 | 18 | type Prepared struct { 19 | session *Session 20 | result *Result 21 | numPlaceholders int 22 | } 23 | 24 | // Execute the prepared statements replacing placeholders with args. 
25 | func (p *Prepared) Execute(ctx context.Context, args ...any) (*Result, error) { 26 | if p.session == nil || p.result == nil || p.result.stmtID == 0 { 27 | return nil, fmt.Errorf("not initialized") 28 | } 29 | 30 | pArgs := make([]*mysqlxdatatypes.Any, len(args)) 31 | 32 | for i, arg := range args { 33 | var err error 34 | 35 | var a any 36 | switch v := arg.(type) { 37 | case driver.NamedValue: 38 | a = v.Value 39 | default: 40 | a = arg 41 | } 42 | 43 | // ridiculous type-switch; preventing using reflection 44 | switch v := a.(type) { 45 | case nil: 46 | pArgs[i] = xproto.Nil() 47 | case bool: 48 | pArgs[i] = xproto.Bool(v) 49 | case *bool: 50 | pArgs[i] = xproto.Bool(*v) 51 | case int: 52 | pArgs[i] = xproto.SignedInt(v) 53 | case int8: 54 | pArgs[i] = xproto.SignedInt(v) 55 | case int16: 56 | pArgs[i] = xproto.SignedInt(v) 57 | case int32: 58 | pArgs[i] = xproto.SignedInt(v) 59 | case int64: 60 | pArgs[i] = xproto.SignedInt(v) 61 | case uint: 62 | pArgs[i] = xproto.UnsignedInt(v) 63 | case uint8: 64 | pArgs[i] = xproto.UnsignedInt(v) 65 | case uint16: 66 | pArgs[i] = xproto.UnsignedInt(v) 67 | case uint32: 68 | pArgs[i] = xproto.UnsignedInt(v) 69 | case uint64: 70 | pArgs[i] = xproto.UnsignedInt(v) 71 | case *int: 72 | pArgs[i] = xproto.SignedInt(*v) 73 | case *int8: 74 | pArgs[i] = xproto.SignedInt(*v) 75 | case *int16: 76 | pArgs[i] = xproto.SignedInt(*v) 77 | case *int32: 78 | pArgs[i] = xproto.SignedInt(*v) 79 | case *int64: 80 | pArgs[i] = xproto.SignedInt(*v) 81 | case *uint: 82 | pArgs[i] = xproto.UnsignedInt(*v) 83 | case *uint8: 84 | pArgs[i] = xproto.UnsignedInt(*v) 85 | case *uint16: 86 | pArgs[i] = xproto.UnsignedInt(*v) 87 | case *uint32: 88 | pArgs[i] = xproto.UnsignedInt(*v) 89 | case *uint64: 90 | pArgs[i] = xproto.UnsignedInt(*v) 91 | case string: 92 | pArgs[i] = xproto.String(v) 93 | case *string: 94 | pArgs[i] = xproto.String(v) 95 | case []byte: 96 | pArgs[i] = xproto.Bytes(v) 97 | case float32: 98 | pArgs[i] = xproto.Float32(v) 99 | 
case *float32: 100 | pArgs[i] = xproto.Float32(*v) 101 | case float64: 102 | pArgs[i] = xproto.Float64(v) 103 | case *float64: 104 | pArgs[i] = xproto.Float64(*v) 105 | case decimal.Decimal: 106 | pArgs[i] = xproto.Decimal(v) 107 | case *decimal.Decimal: 108 | pArgs[i] = xproto.Decimal(*v) 109 | case time.Time: 110 | if pArgs[i], err = xproto.Time(v, p.session.TimeLocation().String()); err != nil { 111 | return nil, err 112 | } 113 | case *time.Time: 114 | if pArgs[i], err = xproto.Time(*v, p.session.TimeLocation().String()); err != nil { 115 | return nil, err 116 | } 117 | case []string: 118 | pArgs[i] = xproto.String(strings.Join(v, ",")) 119 | default: 120 | return nil, fmt.Errorf("argument type '%T' not supported", a) 121 | } 122 | } 123 | 124 | if err := p.session.Write(ctx, &mysqlxprepare.Execute{ 125 | StmtId: &p.result.stmtID, 126 | Args: pArgs, 127 | }); err != nil { 128 | return nil, err 129 | } 130 | 131 | res, err := p.session.handleResult(ctx, func(r *Result) bool { 132 | return r.stmtOK 133 | }) 134 | if err != nil { 135 | return nil, err 136 | } 137 | 138 | return res, nil 139 | } 140 | 141 | // Deallocate makes this prepared statement not usable any longer. 142 | func (p *Prepared) Deallocate(ctx context.Context) error { 143 | return p.session.DeallocatePrepareStatement(ctx, p.result.stmtID) 144 | } 145 | 146 | // StatementID returns the statement ID. 147 | func (p *Prepared) StatementID() uint32 { 148 | return p.result.stmtID 149 | } 150 | 151 | // NumPlaceholders returns the number of placeholder parameters. 
152 | func (p *Prepared) NumPlaceholders() int { 153 | return p.numPlaceholders 154 | } 155 | -------------------------------------------------------------------------------- /xmysql/result_test.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2022, 2023, Geert JM Vanderkelen 2 | 3 | package xmysql_test 4 | 5 | import ( 6 | "context" 7 | "fmt" 8 | "testing" 9 | "time" 10 | 11 | "github.com/golistic/xgo/xt" 12 | 13 | "github.com/golistic/pxmysql/internal/xxt" 14 | "github.com/golistic/pxmysql/xmysql" 15 | ) 16 | 17 | func TestResult_FetchRow(t *testing.T) { 18 | config := &xmysql.ConnectConfig{ 19 | Address: testContext.XPluginAddr, 20 | Username: xxt.UserNative, 21 | } 22 | config.SetPassword(xxt.UserNativePwd) 23 | 24 | tbl := "bulk_fidiEfiS223" 25 | 26 | ses, err := xmysql.GetSession(context.Background(), config) 27 | xt.OK(t, err) 28 | xt.OK(t, ses.SetActiveSchema(context.Background(), testSchema)) 29 | 30 | createTable := fmt.Sprintf( 31 | "CREATE TABLE `%s` (id INT AUTO_INCREMENT PRIMARY KEY, c1 VARCHAR(30) NOT NULL)", tbl) 32 | 33 | _, err = ses.ExecuteStatement(context.Background(), fmt.Sprintf("DROP TABLE IF EXISTS `%s`", tbl)) 34 | xt.OK(t, err) 35 | 36 | _, err = ses.ExecuteStatement(context.Background(), createTable) 37 | xt.OK(t, err) 38 | 39 | nrRows := 100 40 | for i := 0; i < nrRows; i++ { 41 | _, err = ses.ExecuteStatement(context.Background(), 42 | fmt.Sprintf("INSERT INTO `%s` (c1) VALUES (?)", tbl), fmt.Sprintf("data%d", i+1)) 43 | xt.OK(t, err) 44 | } 45 | 46 | t.Run("fetch", func(t *testing.T) { 47 | ses, err := xmysql.GetSession(context.Background(), config) 48 | xt.OK(t, err) 49 | xt.OK(t, ses.SetActiveSchema(context.Background(), testSchema)) 50 | 51 | mUse := xxt.NewMemoryUse() 52 | res, err := ses.ExecuteStatement(context.Background(), 53 | fmt.Sprintf("SELECT * FROM `%s` ORDER BY id", tbl)) 54 | xt.OK(t, err) 55 | xt.Eq(t, nrRows, len(res.Rows)) 56 | 57 | rowCtx, cancel := 
context.WithTimeout(context.Background(), time.Second) 58 | defer cancel() 59 | 60 | for i := 1; res.Row != nil; i++ { 61 | id := res.Row.Values[0].(int64) 62 | xt.Eq(t, i, id) 63 | xt.Eq(t, fmt.Sprintf("data%d", i), res.Row.Values[1].(string)) 64 | 65 | err = res.FetchRow(rowCtx) 66 | xt.OK(t, err) 67 | } 68 | mUse.Stop() 69 | 70 | // keep allocations in check (if nrRows changes, this will obviously go up) 71 | xt.Assert(t, mUse.DiffAlloc() < 35000) 72 | }) 73 | } 74 | -------------------------------------------------------------------------------- /xmysql/schema.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2023, Geert JM Vanderkelen 2 | 3 | package xmysql 4 | 5 | import ( 6 | "context" 7 | "fmt" 8 | "sort" 9 | 10 | "github.com/golistic/pxmysql/null" 11 | "github.com/golistic/pxmysql/xmysql/collection" 12 | "github.com/golistic/pxmysql/xmysql/xproto" 13 | ) 14 | 15 | type ObjectKind string 16 | 17 | const ( 18 | ObjectCollection ObjectKind = "COLLECTION" 19 | ) 20 | 21 | // Schema defines the representation of a database schema. It provides 22 | // functionality to access the schema's contents. 23 | type Schema struct { 24 | session *Session 25 | name string 26 | } 27 | 28 | // newSchema instantiates a new Schema object using session. If name is the 29 | // empty string, the current schema of session will be used. 30 | func newSchema(session *Session, name string) (*Schema, error) { 31 | if session == nil { 32 | return nil, fmt.Errorf("session closed") 33 | } 34 | 35 | return &Schema{ 36 | session: session, 37 | name: name, 38 | }, nil 39 | } 40 | 41 | func (s *Schema) String() string { 42 | return fmt.Sprintf("", s.name, s.session) 43 | } 44 | 45 | // Name returns the schema or database name. 46 | func (s *Schema) Name() string { 47 | return s.name 48 | } 49 | 50 | // GetSession returns the underlying session of s. 
51 | func (s *Schema) GetSession() *Session { 52 | return s.session 53 | } 54 | 55 | // GetCollection retrieve the collection using its name. 56 | // To keep compatible with behavior seen it MySQL connectors, when the collection 57 | // does not exist, by default, no error is returned. To return ErrNotAvailable instead, 58 | // use the functional option collection.GetValidateExistence. 59 | func (s *Schema) GetCollection(ctx context.Context, name string, options ...collection.GetOption) (*Collection, error) { 60 | 61 | c, err := newCollection(s, name) 62 | if err != nil { 63 | return nil, fmt.Errorf("getting collection (%w)", err) 64 | } 65 | 66 | opts := collection.NewGetOptions(options...) 67 | if opts.ValidateExistence { 68 | if err := c.CheckExistence(ctx); err != nil { 69 | return nil, err 70 | } 71 | } 72 | 73 | return c, nil 74 | } 75 | 76 | // GetCollections retrieve all available collections (does not include views or tables). 77 | func (s *Schema) GetCollections(ctx context.Context) ([]*Collection, error) { 78 | 79 | names, err := s.objectNames(ctx, ObjectCollection) 80 | if err != nil { 81 | return nil, fmt.Errorf("getting collections (%w)", err) 82 | } 83 | 84 | if len(names) == 0 { 85 | return nil, nil 86 | } 87 | 88 | collections := make([]*Collection, len(names)) 89 | for i, name := range names { 90 | collections[i], err = newCollection(s, name) 91 | if err != nil { 92 | return nil, fmt.Errorf("getting collections (%w)", err) 93 | } 94 | } 95 | 96 | return collections, nil 97 | } 98 | 99 | // CreateCollection creates a new collection. If the functional option 100 | // collection.CreateReuseExisting is used, no error is reported when collection 101 | // already exists. 
102 | func (s *Schema) CreateCollection(ctx context.Context, name string, 103 | options ...collection.CreateOption) (*Collection, error) { 104 | 105 | c, err := newCollection(s, name) 106 | if err != nil { 107 | return nil, fmt.Errorf("creating collection (%w)", err) 108 | } 109 | 110 | opts := collection.NewCreateOptions(options...) 111 | 112 | args := xproto.CommandArgs( 113 | xproto.ObjectField("schema", s.name), 114 | xproto.ObjectField("name", name), 115 | xproto.ObjectField("options", xproto.ObjectFields{ 116 | xproto.ObjectField("reuse_existing", opts.ReuseExisting), 117 | }), 118 | ) 119 | 120 | _, err = s.session.ExecCommand(ctx, "create_collection", args) 121 | if err != nil { 122 | return nil, fmt.Errorf("creating collection (%w)", err) 123 | } 124 | 125 | return c, nil 126 | } 127 | 128 | // DropCollection drops the collection. 129 | func (s *Schema) DropCollection(ctx context.Context, name string) error { 130 | 131 | args := xproto.CommandArgs( 132 | xproto.ObjectField("schema", s.name), 133 | xproto.ObjectField("name", name), 134 | ) 135 | 136 | _, err := s.session.ExecCommand(ctx, "drop_collection", args) 137 | if err != nil { 138 | return fmt.Errorf("dropping collection (%w)", err) 139 | } 140 | 141 | return nil 142 | } 143 | 144 | func (s *Schema) objectNames(ctx context.Context, kind ObjectKind) ([]string, error) { 145 | 146 | args := xproto.CommandArgs( 147 | xproto.ObjectField("schema", s.name), 148 | ) 149 | 150 | if err := s.session.Write(ctx, xproto.Command("list_objects", args)); err != nil { 151 | return nil, err 152 | } 153 | 154 | res, err := s.session.handleResult(ctx, func(r *Result) bool { 155 | return r.stmtOK 156 | }) 157 | if err != nil { 158 | return nil, err 159 | } 160 | 161 | if len(res.Rows) == 0 { 162 | return nil, nil 163 | } 164 | 165 | var names []string 166 | for _, row := range res.Rows { 167 | objType, ok := row.Values[1].(string) 168 | if !ok || objType != string(kind) { 169 | continue 170 | } 171 | 172 | name, ok := 
row.Values[0].(null.String) 173 | if !ok || !name.Valid { 174 | continue 175 | } 176 | 177 | names = append(names, name.String) 178 | } 179 | 180 | if len(names) == 0 { 181 | return nil, nil 182 | } 183 | 184 | sort.Strings(names) 185 | 186 | return names, nil 187 | } 188 | -------------------------------------------------------------------------------- /xmysql/schema_test.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2023, Geert JM Vanderkelen 2 | 3 | package xmysql_test 4 | 5 | import ( 6 | "context" 7 | "errors" 8 | "fmt" 9 | "sort" 10 | "testing" 11 | 12 | "github.com/golistic/xgo/xstrings" 13 | "github.com/golistic/xgo/xt" 14 | 15 | "github.com/golistic/pxmysql/internal/xxt" 16 | "github.com/golistic/pxmysql/xmysql" 17 | "github.com/golistic/pxmysql/xmysql/collection" 18 | ) 19 | 20 | func TestSchema_GetSession(t *testing.T) { 21 | 22 | t.Run("session of schema returned", func(t *testing.T) { 23 | 24 | config := &xmysql.ConnectConfig{ 25 | Address: testContext.XPluginAddr, 26 | Username: xxt.UserNative, 27 | } 28 | config.SetPassword(xxt.UserNativePwd) 29 | 30 | ctx := context.Background() 31 | 32 | exp, err := xmysql.GetSession(ctx, config) 33 | xt.OK(t, err) 34 | 35 | for i := 0; i < 10; i++ { 36 | schema, err := exp.GetSchema(ctx) 37 | xt.OK(t, err) 38 | 39 | got := schema.GetSession() 40 | xt.Assert(t, got != nil, "expected not nil") 41 | xt.Eq(t, exp, got) 42 | } 43 | }) 44 | 45 | t.Run("no session returns nil", func(t *testing.T) { 46 | 47 | xt.Eq(t, nil, (&xmysql.Schema{}).GetSession()) 48 | }) 49 | } 50 | 51 | func TestSchema_GetCollections(t *testing.T) { 52 | 53 | config := &xmysql.ConnectConfig{ 54 | Address: testContext.XPluginAddr, 55 | Username: xxt.UserNative, 56 | } 57 | config.SetPassword(xxt.UserNativePwd) 58 | 59 | t.Run("all collections", func(t *testing.T) { 60 | 61 | xt.OK(t, testContext.Server.LoadSQLScript("schema_collections")) 62 | 63 | ses, err := 
xmysql.GetSession(context.Background(), config) 64 | xt.OK(t, err) 65 | 66 | exp := []string{"collection_wic28skwixkd", "collection_weux73293jsnsj"} 67 | sort.Strings(exp) 68 | 69 | schema, err := ses.GetSchemaWithName(context.Background(), "pxmysql_tests") 70 | xt.OK(t, err) 71 | 72 | collections, err := schema.GetCollections(context.Background()) 73 | xt.OK(t, err) 74 | 75 | xt.Assert(t, len(collections) >= len(exp), fmt.Sprintf("expected at least %d", len(exp))) 76 | 77 | var got []string 78 | for _, s := range collections { 79 | got = append(got, s.Name()) 80 | } 81 | sort.Strings(got) 82 | 83 | xt.Assert(t, func(exp, got []string) bool { 84 | if len(exp) > len(got) { 85 | return false 86 | } 87 | for _, l := range exp { 88 | if !xstrings.SliceHas(got, l) { 89 | return false 90 | } 91 | } 92 | return true 93 | }(exp, got)) 94 | }) 95 | } 96 | 97 | func TestSchema_CreateCollection(t *testing.T) { 98 | 99 | config := &xmysql.ConnectConfig{ 100 | Address: testContext.XPluginAddr, 101 | Username: xxt.UserNative, 102 | Schema: "pxmysql_tests", 103 | } 104 | config.SetPassword(xxt.UserNativePwd) 105 | 106 | ses, err := xmysql.GetSession(context.Background(), config) 107 | xt.OK(t, err) 108 | 109 | schema, err := ses.GetSchema(context.Background()) 110 | xt.OK(t, err) 111 | 112 | t.Run("name is required", func(t *testing.T) { 113 | 114 | ctx := context.Background() 115 | _, err := schema.CreateCollection(ctx, "") 116 | xt.KO(t, err) 117 | xt.Eq(t, "creating collection (invalid name)", err.Error()) 118 | }) 119 | 120 | t.Run("create and drop", func(t *testing.T) { 121 | 122 | ctx := context.Background() 123 | name := "ciwejkuwmidi2938x" 124 | c, err := schema.CreateCollection(ctx, name) 125 | xt.OK(t, err) 126 | xt.Eq(t, name, c.Name()) 127 | 128 | t.Run("check existence", func(t *testing.T) { 129 | c, err := schema.GetCollection(ctx, name, collection.GetValidateExistence()) 130 | xt.OK(t, err) 131 | xt.Eq(t, name, c.Name()) 132 | 133 | t.Run("drop", func(t *testing.T) 
{ 134 | err := schema.DropCollection(ctx, name) 135 | xt.OK(t, err) 136 | 137 | err = c.CheckExistence(ctx) 138 | xt.KO(t, err) 139 | xt.Assert(t, errors.Is(err, xmysql.ErrNotAvailable)) 140 | }) 141 | }) 142 | }) 143 | 144 | t.Run("reuse existing", func(t *testing.T) { 145 | 146 | ctx := context.Background() 147 | 148 | name := "eovwo28373" 149 | _, err := schema.CreateCollection(ctx, name) 150 | xt.OK(t, err) 151 | 152 | _, err = schema.CreateCollection(ctx, name) 153 | xt.KO(t, err) 154 | xt.Eq(t, "table 'eovwo28373' already exists [1050:42S01]", errors.Unwrap(err).Error()) 155 | 156 | _, err = schema.CreateCollection(ctx, name, collection.CreateReuseExisting()) 157 | xt.OK(t, err) 158 | }) 159 | } 160 | -------------------------------------------------------------------------------- /xmysql/xproto/command.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2023, Geert JM Vanderkelen 2 | 3 | package xproto 4 | 5 | import ( 6 | "github.com/golistic/xgo/xstrings" 7 | 8 | "github.com/golistic/pxmysql/internal/mysqlx/mysqlxdatatypes" 9 | "github.com/golistic/pxmysql/internal/mysqlx/mysqlxsql" 10 | "github.com/golistic/pxmysql/xmysql/internal/network" 11 | ) 12 | 13 | func Command(command string, args *mysqlxdatatypes.Any) *mysqlxsql.StmtExecute { 14 | return &mysqlxsql.StmtExecute{ 15 | Namespace: xstrings.Pointer(network.NamespaceMySQLx), 16 | Stmt: []byte(command), 17 | Args: []*mysqlxdatatypes.Any{args}, 18 | } 19 | } 20 | 21 | func CommandArgs(fields ...*mysqlxdatatypes.Object_ObjectField) *mysqlxdatatypes.Any { 22 | return &mysqlxdatatypes.Any{ 23 | Type: mysqlxdatatypes.Any_OBJECT.Enum(), 24 | Obj: &mysqlxdatatypes.Object{ 25 | Fld: fields, 26 | }, 27 | } 28 | } 29 | -------------------------------------------------------------------------------- /xmysql/xproto/expr.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2023, Geert JM Vanderkelen 2 | 3 | 
package xproto 4 | 5 | import ( 6 | "reflect" 7 | "strings" 8 | 9 | "github.com/golistic/xgo/xstrings" 10 | 11 | "github.com/golistic/pxmysql/internal/mysqlx/mysqlxexpr" 12 | ) 13 | 14 | func Expr[T reflect.Value | any](value T) *mysqlxexpr.Expr { 15 | 16 | var rv reflect.Value 17 | 18 | switch v := any(value).(type) { 19 | case nil: 20 | return nil 21 | case reflect.Value: 22 | rv = v 23 | default: 24 | rv = reflect.ValueOf(v) 25 | } 26 | 27 | switch reflect.Indirect(rv).Kind() { 28 | case reflect.Slice: 29 | return &mysqlxexpr.Expr{ 30 | Type: mysqlxexpr.Expr_ARRAY.Enum(), 31 | Array: sliceExpr(rv), 32 | } 33 | case reflect.Struct: 34 | return &mysqlxexpr.Expr{ 35 | Type: mysqlxexpr.Expr_OBJECT.Enum(), 36 | Object: StructExpr(rv.Interface()), 37 | } 38 | default: 39 | return &mysqlxexpr.Expr{ 40 | Type: mysqlxexpr.Expr_LITERAL.Enum(), 41 | Literal: Scalar(rv), 42 | } 43 | } 44 | } 45 | 46 | func StructExpr(object any) *mysqlxexpr.Object { 47 | 48 | rv := reflect.Indirect(reflect.ValueOf(object)) 49 | 50 | rt := reflect.TypeOf(object) 51 | if rt.Kind() == reflect.Pointer { 52 | rt = reflect.TypeOf(object).Elem() 53 | } 54 | 55 | obj := &mysqlxexpr.Object{} 56 | 57 | for i := 0; i < rv.NumField(); i++ { 58 | rvf := rv.Field(i) 59 | if !rvf.CanInterface() { 60 | continue 61 | } 62 | rtf := rt.Field(i) 63 | 64 | name := rtf.Name 65 | tag := rtf.Tag.Get("json") 66 | if tag != "" { 67 | if tag == "-" || (strings.HasSuffix(tag, ",omitempty") && rvf.IsZero()) { 68 | continue 69 | } 70 | name = strings.Replace(tag, ",omitempty", "", -1) 71 | } 72 | 73 | obj.Fld = append(obj.Fld, &mysqlxexpr.Object_ObjectField{ 74 | Key: xstrings.Pointer(name), 75 | Value: Expr(rvf), 76 | }) 77 | } 78 | 79 | return obj 80 | } 81 | 82 | func sliceExpr(value reflect.Value) *mysqlxexpr.Array { 83 | array := &mysqlxexpr.Array{ 84 | Value: make([]*mysqlxexpr.Expr, value.Len()), 85 | } 86 | 87 | for i := 0; i < value.Len(); i++ { 88 | array.Value[i] = Expr(value.Index(i)) 89 | } 90 | 91 | return 
array 92 | } 93 | -------------------------------------------------------------------------------- /xmysql/xproto/expr_test.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2023, Geert JM Vanderkelen 2 | 3 | package xproto 4 | 5 | import ( 6 | "testing" 7 | 8 | "github.com/golistic/xgo/xt" 9 | ) 10 | 11 | type Department struct { 12 | Name string `json:"name"` 13 | } 14 | 15 | type Person struct { 16 | Name string `json:"name"` 17 | Age int `json:"age"` 18 | Department Department `json:"department,omitempty"` 19 | } 20 | 21 | func TestExpr(t *testing.T) { 22 | t.Run("object with nested object", func(t *testing.T) { 23 | expr := Expr(&Person{Name: "Alice", Age: 36, Department: Department{Name: "Engineering"}}) 24 | got := expr.Object 25 | 26 | xt.Eq(t, "name", *got.Fld[0].Key) 27 | xt.Eq(t, "Alice", string(got.Fld[0].Value.Literal.VString.Value)) 28 | xt.Eq(t, "age", *got.Fld[1].Key) 29 | xt.Eq(t, int64(36), *got.Fld[1].Value.Literal.VSignedInt) 30 | xt.Eq(t, "department", *got.Fld[2].Key) 31 | xt.Eq(t, "Engineering", string(got.Fld[2].Value.Object.Fld[0].Value.Literal.VString.Value)) 32 | }) 33 | 34 | t.Run("array of objects", func(t *testing.T) { 35 | exp := []*Person{ 36 | {Name: "Alice", Age: 36}, 37 | {Name: "Bob", Age: 34}, 38 | } 39 | 40 | expr := Expr(exp) 41 | array := expr.Array 42 | xt.Eq(t, 2, len(array.Value)) 43 | 44 | for i, v := range array.Value { 45 | got := v.Object.Fld[0] 46 | xt.Eq(t, "name", *got.Key) 47 | xt.Eq(t, exp[i].Name, got.Value.Literal.VString.Value) 48 | } 49 | }) 50 | } 51 | -------------------------------------------------------------------------------- /xmysql/xproto/fields.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2023, Geert JM Vanderkelen 2 | 3 | package xproto 4 | 5 | import ( 6 | "github.com/golistic/xgo/xstrings" 7 | 8 | "github.com/golistic/pxmysql/internal/mysqlx/mysqlxdatatypes" 9 | ) 10 | 11 | type 
ObjectFields = []*mysqlxdatatypes.Object_ObjectField 12 | 13 | func ObjectField[T ~string | ~bool | ~[]*mysqlxdatatypes.Object_ObjectField](key string, value T) *mysqlxdatatypes.Object_ObjectField { 14 | f := &mysqlxdatatypes.Object_ObjectField{ 15 | Key: xstrings.Pointer(key), 16 | } 17 | 18 | switch v := any(value).(type) { 19 | case string: 20 | f.Value = String(v) 21 | case bool: 22 | f.Value = Bool(v) 23 | case []*mysqlxdatatypes.Object_ObjectField: 24 | f.Value = &mysqlxdatatypes.Any{ 25 | Type: mysqlxdatatypes.Any_OBJECT.Enum(), 26 | Obj: &mysqlxdatatypes.Object{ 27 | Fld: v, 28 | }, 29 | } 30 | default: 31 | panic("unsupported value type") 32 | } 33 | 34 | return f 35 | } 36 | -------------------------------------------------------------------------------- /xmysql/xproto/scalar.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2023, Geert JM Vanderkelen 2 | 3 | package xproto 4 | 5 | import ( 6 | "fmt" 7 | "reflect" 8 | "time" 9 | 10 | "golang.org/x/exp/constraints" 11 | "google.golang.org/protobuf/proto" 12 | 13 | "github.com/golistic/pxmysql/decimal" 14 | "github.com/golistic/pxmysql/internal/mysqlx/mysqlxdatatypes" 15 | ) 16 | 17 | func Scalar[T reflect.Value | any](value T) *mysqlxdatatypes.Scalar { 18 | 19 | var rv reflect.Value 20 | 21 | switch v := any(value).(type) { 22 | case nil: 23 | return NilScalar() 24 | case reflect.Value: 25 | rv = v 26 | default: 27 | rv = reflect.ValueOf(v) 28 | } 29 | 30 | switch { 31 | case rv.Kind() == reflect.Slice: 32 | switch { 33 | case rv.Type().Elem().Kind() == reflect.Uint8 && rv.Len() == 0: // empty []byte 34 | return NilScalar() 35 | case rv.Type().Elem().Kind() == reflect.Uint8: // []byte 36 | return BytesScalar(rv.Bytes()) 37 | default: 38 | panic("unsupported scalar slice") 39 | } 40 | case rv.Kind() == reflect.Pointer && rv.IsNil(): 41 | return NilScalar() 42 | } 43 | 44 | rv = reflect.Indirect(rv) 45 | 46 | switch rv.Kind() { 47 | case reflect.Bool: 
48 | return BoolScalar(rv.Bool()) 49 | case reflect.String: 50 | return StringScalar(rv.String()) 51 | case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: 52 | return UnsignedIntScalar(rv.Uint()) 53 | case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: 54 | return SignedIntScalar(rv.Int()) 55 | case reflect.Float32: 56 | return Float32Scalar(float32(rv.Float())) 57 | case reflect.Float64: 58 | return Float64Scalar(rv.Float()) 59 | default: 60 | panic(fmt.Sprintf("unsupported scalar value; was %s", rv.Kind())) 61 | } 62 | } 63 | 64 | func Bool(value bool) *mysqlxdatatypes.Any { 65 | return &mysqlxdatatypes.Any{ 66 | Type: mysqlxdatatypes.Any_SCALAR.Enum(), 67 | Scalar: BoolScalar(value), 68 | } 69 | } 70 | 71 | func BoolScalar(value bool) *mysqlxdatatypes.Scalar { 72 | return &mysqlxdatatypes.Scalar{ 73 | Type: mysqlxdatatypes.Scalar_V_BOOL.Enum(), 74 | VBool: proto.Bool(value), 75 | } 76 | } 77 | 78 | func Nil() *mysqlxdatatypes.Any { 79 | return &mysqlxdatatypes.Any{ 80 | Type: mysqlxdatatypes.Any_SCALAR.Enum(), 81 | Scalar: NilScalar(), 82 | } 83 | } 84 | 85 | func NilScalar() *mysqlxdatatypes.Scalar { 86 | return &mysqlxdatatypes.Scalar{ 87 | Type: mysqlxdatatypes.Scalar_V_NULL.Enum(), 88 | } 89 | } 90 | 91 | func SignedInt[T constraints.Signed](value T) *mysqlxdatatypes.Any { 92 | return &mysqlxdatatypes.Any{ 93 | Type: mysqlxdatatypes.Any_SCALAR.Enum(), 94 | Scalar: SignedIntScalar(value), 95 | } 96 | } 97 | 98 | func SignedIntScalar[T constraints.Signed](value T) *mysqlxdatatypes.Scalar { 99 | return &mysqlxdatatypes.Scalar{ 100 | Type: mysqlxdatatypes.Scalar_V_SINT.Enum(), 101 | VSignedInt: proto.Int64(int64(value)), 102 | } 103 | } 104 | 105 | func UnsignedInt[T constraints.Unsigned](value T) *mysqlxdatatypes.Any { 106 | return &mysqlxdatatypes.Any{ 107 | Type: mysqlxdatatypes.Any_SCALAR.Enum(), 108 | Scalar: UnsignedIntScalar(value), 109 | } 110 | } 111 | 112 | func UnsignedIntScalar[T 
constraints.Unsigned](value T) *mysqlxdatatypes.Scalar { 113 | return &mysqlxdatatypes.Scalar{ 114 | Type: mysqlxdatatypes.Scalar_V_UINT.Enum(), 115 | VUnsignedInt: proto.Uint64(uint64(value)), 116 | } 117 | } 118 | 119 | func String(value any) *mysqlxdatatypes.Any { 120 | var v string 121 | if value != nil { 122 | switch sv := value.(type) { 123 | case string: 124 | v = sv 125 | case *string: 126 | if sv != nil { 127 | v = *sv 128 | } else { 129 | return Nil() 130 | } 131 | default: 132 | panic(fmt.Sprintf("String accepts string or *string; not %T", value)) 133 | } 134 | } else { 135 | return Nil() 136 | } 137 | 138 | return &mysqlxdatatypes.Any{ 139 | Type: mysqlxdatatypes.Any_SCALAR.Enum(), 140 | Scalar: StringScalar(v), 141 | } 142 | } 143 | 144 | func StringScalar[T ~string](value T) *mysqlxdatatypes.Scalar { 145 | return &mysqlxdatatypes.Scalar{ 146 | Type: mysqlxdatatypes.Scalar_V_STRING.Enum(), 147 | VString: &mysqlxdatatypes.Scalar_String{ 148 | Value: []byte(value), 149 | }, 150 | } 151 | } 152 | 153 | func Bytes[T ~[]byte](value T) *mysqlxdatatypes.Any { 154 | v := []byte(value) 155 | return &mysqlxdatatypes.Any{ 156 | Type: mysqlxdatatypes.Any_SCALAR.Enum(), 157 | Scalar: BytesScalar(v), 158 | } 159 | } 160 | 161 | func BytesScalar[T ~[]byte](value T) *mysqlxdatatypes.Scalar { 162 | return &mysqlxdatatypes.Scalar{ 163 | Type: mysqlxdatatypes.Scalar_V_OCTETS.Enum(), 164 | VOctets: &mysqlxdatatypes.Scalar_Octets{ 165 | Value: value, 166 | }, 167 | } 168 | } 169 | 170 | func Float32[T ~float32](value T) *mysqlxdatatypes.Any { 171 | v := float32(value) 172 | return &mysqlxdatatypes.Any{ 173 | Type: mysqlxdatatypes.Any_SCALAR.Enum(), 174 | Scalar: Float32Scalar(v), 175 | } 176 | } 177 | 178 | func Float32Scalar[T ~float32](value T) *mysqlxdatatypes.Scalar { 179 | return &mysqlxdatatypes.Scalar{ 180 | Type: mysqlxdatatypes.Scalar_V_FLOAT.Enum(), 181 | VFloat: proto.Float32(float32(value)), 182 | } 183 | } 184 | 185 | func Float64[T ~float64](value T) 
*mysqlxdatatypes.Any { 186 | v := float64(value) 187 | return &mysqlxdatatypes.Any{ 188 | Type: mysqlxdatatypes.Any_SCALAR.Enum(), 189 | Scalar: Float64Scalar(v), 190 | } 191 | } 192 | 193 | func Float64Scalar[T ~float64](value T) *mysqlxdatatypes.Scalar { 194 | return &mysqlxdatatypes.Scalar{ 195 | Type: mysqlxdatatypes.Scalar_V_DOUBLE.Enum(), 196 | VDouble: proto.Float64(float64(value)), 197 | } 198 | } 199 | 200 | func Decimal(value decimal.Decimal) *mysqlxdatatypes.Any { 201 | // MySQL X Protocol does not support sending the encoded BCD. 202 | return String(value.String()) 203 | } 204 | 205 | func Time(value time.Time, timeZoneName string) (*mysqlxdatatypes.Any, error) { 206 | tz, err := time.LoadLocation(timeZoneName) 207 | if err != nil { 208 | return nil, err 209 | } 210 | return String(value.In(tz).Format("2006-01-02 15:04:05.999999")), nil 211 | } 212 | -------------------------------------------------------------------------------- /xmysql/xproto/scalar_test.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2023, Geert JM Vanderkelen 2 | 3 | package xproto_test 4 | 5 | import ( 6 | "fmt" 7 | "testing" 8 | 9 | "github.com/golistic/xgo/xstrings" 10 | "github.com/golistic/xgo/xt" 11 | 12 | "github.com/golistic/pxmysql/internal/mysqlx/mysqlxdatatypes" 13 | "github.com/golistic/pxmysql/xmysql/xproto" 14 | ) 15 | 16 | func TestScalar(t *testing.T) { 17 | t.Run("basic types", func(t *testing.T) { 18 | var nilString *string 19 | 20 | var cases = []struct { 21 | have any 22 | exp *mysqlxdatatypes.Scalar 23 | }{ 24 | { 25 | have: "gopher", 26 | exp: &mysqlxdatatypes.Scalar{ 27 | Type: mysqlxdatatypes.Scalar_V_STRING.Enum(), 28 | VString: &mysqlxdatatypes.Scalar_String{ 29 | Value: []byte("gopher"), 30 | }, 31 | }, 32 | }, 33 | { 34 | have: "", 35 | exp: &mysqlxdatatypes.Scalar{ 36 | Type: mysqlxdatatypes.Scalar_V_STRING.Enum(), 37 | VString: &mysqlxdatatypes.Scalar_String{ 38 | Value: []byte(""), 39 | }, 40 | 
}, 41 | }, 42 | { 43 | have: xstrings.Pointer("gopher"), 44 | exp: &mysqlxdatatypes.Scalar{ 45 | Type: mysqlxdatatypes.Scalar_V_STRING.Enum(), 46 | VString: &mysqlxdatatypes.Scalar_String{ 47 | Value: []byte("gopher"), 48 | }, 49 | }, 50 | }, 51 | { 52 | have: nilString, 53 | exp: &mysqlxdatatypes.Scalar{ 54 | Type: mysqlxdatatypes.Scalar_V_NULL.Enum(), 55 | }, 56 | }, 57 | { 58 | have: nil, 59 | exp: &mysqlxdatatypes.Scalar{ 60 | Type: mysqlxdatatypes.Scalar_V_NULL.Enum(), 61 | }, 62 | }, 63 | { 64 | have: []byte("gopher"), 65 | exp: &mysqlxdatatypes.Scalar{ 66 | Type: mysqlxdatatypes.Scalar_V_OCTETS.Enum(), 67 | VOctets: &mysqlxdatatypes.Scalar_Octets{ 68 | Value: []byte("gopher"), 69 | }, 70 | }, 71 | }, 72 | { 73 | have: []byte{}, 74 | exp: &mysqlxdatatypes.Scalar{ 75 | Type: mysqlxdatatypes.Scalar_V_NULL.Enum(), 76 | }, 77 | }, 78 | } 79 | 80 | for _, c := range cases { 81 | t.Run(fmt.Sprintf("%T", c.have), func(t *testing.T) { 82 | got := xproto.Scalar(c.have) 83 | xt.Eq(t, c.exp, got) 84 | }) 85 | } 86 | }) 87 | 88 | } 89 | --------------------------------------------------------------------------------