├── .gitignore
├── .golangci.yml
├── .idea
├── .gitignore
├── codeStyles
│ └── codeStyleConfig.xml
├── copyright
│ ├── Geert_JM.xml
│ └── profiles_settings.xml
├── dataSources.xml
├── inspectionProfiles
│ └── Project_Default.xml
├── modules.xml
├── pxmysql.iml
├── sqldialects.xml
├── vcs.xml
└── watcherTasks.xml
├── CHANGELOG.yaml
├── LICENSE.md
├── README.md
├── _badges
├── badges.json
├── go-version.svg
└── license.svg
├── _support
└── pxmysql-compose
│ ├── .gitignore
│ ├── conf.d
│ ├── .gitignore
│ ├── 01_basics.cnf
│ ├── 02_tls.cnf
│ ├── ca-key.pem
│ ├── ca.pem
│ ├── ca.srl
│ ├── client-cert.pem
│ ├── client-key.pem
│ ├── server-cert.pem
│ └── server-key.pem
│ ├── conf.srl
│ ├── docker-compose.yml
│ ├── generate_tls.sh
│ ├── server_x509_ext.conf
│ └── shared
│ ├── build.sh
│ └── goapps
│ └── unix_socket
│ ├── go.mod
│ ├── go.sum
│ └── main.go
├── cmd
├── .gitignore
├── gencollations
│ └── main.go
├── genprotobuf
│ └── main.go
└── make
│ └── main.go
├── connection.go
├── connection_test.go
├── connector.go
├── datasource.go
├── datasource_test.go
├── decimal
├── decimal.go
└── decimal_test.go
├── driver.go
├── driver_test.go
├── errors.go
├── go.mod
├── go.sum
├── interfaces
└── message.go
├── internal
├── mysqlx
│ ├── info.md
│ ├── mysqlx
│ │ └── mysqlx.pb.go
│ ├── mysqlxconnection
│ │ └── mysqlx_connection.pb.go
│ ├── mysqlxcrud
│ │ └── mysqlx_crud.pb.go
│ ├── mysqlxcursor
│ │ └── mysqlx_cursor.pb.go
│ ├── mysqlxdatatypes
│ │ └── mysqlx_datatypes.pb.go
│ ├── mysqlxexpect
│ │ └── mysqlx_expect.pb.go
│ ├── mysqlxexpr
│ │ └── mysqlx_expr.pb.go
│ ├── mysqlxnotice
│ │ └── mysqlx_notice.pb.go
│ ├── mysqlxprepare
│ │ └── mysqlx_prepare.pb.go
│ ├── mysqlxresultset
│ │ └── mysqlx_resultset.pb.go
│ ├── mysqlxsession
│ │ └── mysqlx_session.pb.go
│ └── mysqlxsql
│ │ └── mysqlx_sql.pb.go
└── xxt
│ ├── builder.go
│ ├── context.go
│ ├── credentials.go
│ ├── docker.go
│ ├── errors.go
│ ├── memory.go
│ └── server.go
├── main_test.go
├── mysqlerrors
├── error.go
├── error_client.go
├── error_test.go
└── test_errors
│ └── mysqlerrors_test.go
├── null
├── bytes.go
├── bytes_test.go
├── decimal.go
├── decimal_test.go
├── duration.go
├── duration_test.go
├── float32.go
├── float32_test.go
├── float64.go
├── float64_test.go
├── int64.go
├── int64_test.go
├── main.go
├── main_test.go
├── string.go
├── string_test.go
├── strings.go
├── strings_test.go
├── time.go
├── time_test.go
├── uint64.go
└── uint64_test.go
├── register
├── mysql
│ ├── register.go
│ └── register_test.go
├── register.go
└── register_test.go
├── result.go
├── rows.go
├── rows_test.go
├── statement.go
├── statement_test.go
├── transaction.go
└── xmysql
├── _testdata
├── .gitignore
├── base.sql
├── data_types_datetime.sql
├── data_types_numeric.sql
├── data_types_string.sql
├── inserting.sql
├── prepared_stmt.sql
└── schema_collections.sql
├── authentication.go
├── capabilities.go
├── collations.go
├── collations_data.go
├── collations_test.go
├── collection.go
├── collection
├── create.go
└── get.go
├── collection_test.go
├── connection_config.go
├── context.go
├── context_test.go
├── crud_add.go
├── defaults.go
├── errors.go
├── examples_test.go
├── internal
├── network
│ ├── messages.go
│ ├── proto.go
│ ├── read.go
│ ├── tls.go
│ ├── trace.go
│ └── write.go
└── statements
│ ├── quote.go
│ ├── quote_test.go
│ ├── statement_test.go
│ └── statements.go
├── main_test.go
├── notice.go
├── prepared.go
├── prepared_test.go
├── result.go
├── result_test.go
├── schema.go
├── schema_test.go
├── session.go
├── session_test.go
└── xproto
├── command.go
├── expr.go
├── expr_test.go
├── fields.go
├── scalar.go
└── scalar_test.go
/.gitignore:
--------------------------------------------------------------------------------
1 | ?dev_*
2 | *.proto
3 | cmd/scratch
4 | _testdata/*.pem
5 |
6 | .DS_Store
--------------------------------------------------------------------------------
/.golangci.yml:
--------------------------------------------------------------------------------
1 | run:
2 | skip-dirs:
3 | - cmd/scratch
4 | - internal/mysqlx
5 |
--------------------------------------------------------------------------------
/.idea/.gitignore:
--------------------------------------------------------------------------------
1 | # Default ignored files
2 | /shelf/
3 | /workspace.xml
4 | # Editor-based HTTP Client requests
5 | /httpRequests/
6 | # Datasource local storage ignored files
7 | /dataSources/
8 | /dataSources.local.xml
9 |
--------------------------------------------------------------------------------
/.idea/codeStyles/codeStyleConfig.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
--------------------------------------------------------------------------------
/.idea/copyright/Geert_JM.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/.idea/copyright/profiles_settings.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
--------------------------------------------------------------------------------
/.idea/dataSources.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 | mysql.8
6 | true
7 | com.mysql.cj.jdbc.Driver
8 | jdbc:mysql://localhost:53306
9 | $ProjectFileDir$
10 |
11 |
12 |
--------------------------------------------------------------------------------
/.idea/inspectionProfiles/Project_Default.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
--------------------------------------------------------------------------------
/.idea/modules.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
--------------------------------------------------------------------------------
/.idea/pxmysql.iml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
--------------------------------------------------------------------------------
/.idea/sqldialects.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
--------------------------------------------------------------------------------
/.idea/vcs.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
--------------------------------------------------------------------------------
/.idea/watcherTasks.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
27 |
28 |
29 |
--------------------------------------------------------------------------------
/CHANGELOG.yaml:
--------------------------------------------------------------------------------
1 | pxmysql:
2 | - meta:
3 | projectURL: https://github.com/golistic/pxmysql
4 | description: |
5 | Go MySQL driver using X Protocol communicating with the MySQL server using
6 | Protocol Buffers.
7 |
8 | All notable changes to this project will be documented in this file.
9 | We follow the conventionalcommits.org specification.
10 |
11 | Change entries with prefix `(!)` warn for a "breaking change".
12 | - versions:
13 | - version: v0.9
14 | date: 2023-01-24
15 | description: Initial development release (not production ready).
16 | patches:
17 | - version: v0.9.8
18 | date: 2023-08-27
19 | refactor:
20 | driver:
21 | - (!) We moved the registration of the `sql`-driver "pxmysql" to the subpackage
22 | `github.com/golistic/pxmysql/register` (driver name "mysql" to `../register/mysql`).
23 | Refactoring should not break things, but this does. Users must change the (anonymous)
24 | import using the new sub-package.
25 | - Use `github.com/golistic/xgo/xsql` for managing the data source name.
26 | build:
27 | - Dependencies have been tidied and updated where needed.
28 | - version: v0.9.7
29 | date: 2023-08-15
30 | fixed:
31 | driver:
32 | - Properly deallocate prepared statements when using the connection methods
33 | `ExecContext` and `QueryContext`, preventing the server from reaching the maximum
34 | number of prepared statements.
35 | refactor:
36 | general:
37 | - Cleanup dependencies and use `golistic/xgo` instead of the now deprecated
38 | subpackages within `golistic` or `github.com/geertjanvdk/xkit`.
39 | build:
40 | general:
41 | - Go version has been upped to 1.21 to make it clear that we eventually might
42 | use some features from that version.
43 | - version: v0.9.6
44 | date: 2023-08-09
45 | fixed:
46 | driver:
47 | - `pxmysql.QueryContext()` will now correctly return an empty Rows object when
48 | the result has no rows, instead of returning `sql.ErrNoRows`.
49 | - (!) Go `sql` driver is now named `pxmysql` so it aligns with the package name;
50 | we do not keep backward compatibility.
51 | - We support the driver name "mysql" as some projects need to use this name. When
52 | this is needed, load anonymous sub-package `github.com/golistic/pxmysql/mysql`.
53 | added:
54 | driver:
55 | - We support the driver name "mysql" as some projects need to use this name. When
56 | this is needed, load anonymous sub-package `github.com/golistic/pxmysql/mysql`.
57 | build:
58 | - Upgrade ProtoBuf MySQL code to MySQL 8.0.34 (but no changes).
59 | - version: v0.9.5
60 | date: 2023-05-28
61 | fixed:
62 | - handling too large packets
63 | - wrap driver.ErrBadConn
64 | - version: v0.9.4
65 | date: 2023-05-10
66 | fixed:
67 | - Recover from server timing out connections.
68 | - version: v0.9.3
69 | date: 2023-05-10
70 | changed:
71 | - Updated protocol buffer generated code and collations to MySQL 8.0.32.
72 | added:
73 | - Added golistic/gomake targets for linting, reporting, and badges.
74 | - Added badges, generated/stored within repository, to README.md.
75 | fixed:
76 | - Fixed error returned when Unix socket is not available.
77 | - Fixed cmd/gencollations to use TLS and set password as valid nullable.
78 | - Fixed linting issues reported by linters run by golangci-lint.
79 | - Replaced deprecated package golang.org/x/crypto/ssh/terminal.
80 | - Fixed handling DATETIME zero values for time parts.
81 | - version: v0.9.2
82 | date: 2023-02-14
83 | changed:
84 | - Fixed naming of `pxmysql.ParseDSN` (before it was `ParseDNS`).
85 | fixed:
86 | - Fixed including query part when getting string representation of DataSource.
87 | - Fixed slash detection when not using schema name together with query part.
88 | - version: v0.9.1
89 | date: 2023-02-03
90 | fixed:
91 | - Fixed parsing query string of DSN so `useTLS` works as expected.
92 | - Fixed using connection address without TCP port.
93 | - Fix wrapping errors.
94 | - Finish testing Unix socket support.
95 |
--------------------------------------------------------------------------------
/LICENSE.md:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022, 2023, Geert JM Vanderkelen
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/_badges/badges.json:
--------------------------------------------------------------------------------
1 | {
2 | "style": "flat-square",
3 | "urlShieldsIO": "https://img.shields.io/static/v1",
4 | "destFolder": "_badges",
5 | "badges": [
6 | {
7 | "name": "go-version",
8 | "label": "GO",
9 | "messageFunc": "go.mod.version",
10 | "color": "#00ADD8"
11 | },
12 | {
13 | "name": "license",
14 | "label": "License",
15 | "message": "MIT",
16 | "color": "#97ca00"
17 | }
18 | ]
19 | }
--------------------------------------------------------------------------------
/_badges/go-version.svg:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/_badges/license.svg:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/_support/pxmysql-compose/.gitignore:
--------------------------------------------------------------------------------
1 | data*/
--------------------------------------------------------------------------------
/_support/pxmysql-compose/conf.d/.gitignore:
--------------------------------------------------------------------------------
1 | *req.pem
--------------------------------------------------------------------------------
/_support/pxmysql-compose/conf.d/01_basics.cnf:
--------------------------------------------------------------------------------
1 | [mysqld]
2 | skip_name_resolve # DNS not important for testing
3 | mysqlx_socket = mysqlx.sock
--------------------------------------------------------------------------------
/_support/pxmysql-compose/conf.d/02_tls.cnf:
--------------------------------------------------------------------------------
1 | [mysqld]
2 | ssl_ca=/etc/mysql/conf.d/ca.pem
3 | ssl_cert=/etc/mysql/conf.d/server-cert.pem
4 | ssl_key=/etc/mysql/conf.d/server-key.pem
5 |
--------------------------------------------------------------------------------
/_support/pxmysql-compose/conf.d/ca-key.pem:
--------------------------------------------------------------------------------
1 | -----BEGIN RSA PRIVATE KEY-----
2 | MIIEogIBAAKCAQEAucVTs35msvYgC9iu6elW3aXzknxW3CmQasgjD7VkufT44OgG
3 | /ce10Hjp3QrDSqXqXCpVBxi8bxWHzCnSwtSh6K7/HjoC8akeVzv9HIqjUUCEpQWT
4 | 9eayDG7hPrq1wJ8Hc6krh1h//dQlDNQeUG+uuhvknA83H01yyJKQcWBN2mMLZuBN
5 | apNEAcdsWjq5rrRxj4B0Igja1g7E4XQnYfy8fY+Uc0a8G1wt01+3ftSrpa0+jJug
6 | 0c6S8H14TgDjGtGqDK2wmy4QV4rcKXqkO9E1E1Gh8LmGkexCw49pgeujzh8TqvdU
7 | 5o2jxh3YSPbZWkNfilqE2SrYAiF2X99LNRlYvwIDAQABAoIBADusg2KZK+w427py
8 | dF13MwwoDsHzZwN55oYmm/yjzCNf6cJ1RimnSWQaMyVqG6mS+mF4x69r5rvYMrMG
9 | jElBfHD+Jb1T7TYrmS90ea39atDi5LkNvaWz4WXVCE3aNCAX9ZDVusHTT+n9h5lD
10 | WimEdqAZ7amjyZUoj8KWMgf5Y4jO07iOOD3FH5g3VW3D8C/bl4hQuddlFveKp708
11 | UeGFE9EobDf1YU/vF+J+tVZC5xaniPiSicQsyebuTrTmT4Z5niVZu2C9dynzujWn
12 | 3zBjIX8iwmQdjznLCOOWJKnKr5GP+eyW9NHvpC+DFWxtp3q0gn6G8Q65ICkM+UWw
13 | sqawrIECgYEA5wDLju75zLrxhRpn5gAC1z23gPa+B0APy93fMgbVYFKFI/jk0LSu
14 | 4fCPDsQZ2OfmBzQO4Q8i8mYLvypywP1lPiaquDYSe1Ykp7pfCXEp5ZTc1OL6Zudv
15 | Hb6v7QvJ7VYXrtU3/Wy9fVDtGNQfI9J6JvXIknQbSvn9COlvSGnYNRECgYEAzd+D
16 | 4U9qm4wESYpTXM4ea9HB61YcGhu9gpLIF2H56aOSRzm382kzjI9QlpMJEiaeBXyy
17 | RZa3Uw7w5pRkGhRcwFpEJBfzvvEFQqtM9yiZw+n/2S3ZvRPeDe5dH8BqRO/KKRiq
18 | 2LSQ/OEu78/cq/GU0dAKiLmrmzVp6HHK6usfcM8CgYBWO8zBieKEk9DvYEEi8iQd
19 | V7O2F+YubLK45xWX5kcnUwbSu+onIxwZyiSNXZVMjJ0pWTyotW7VUFTYQy9dbfqq
20 | beLTK5RQqIK8fm1V6AG864pYinbxjTnEv9eKxRjXWYkzwfLJzxsZuekYmK8bP0pM
21 | WvpJ+b/qiFH2TrY1MRX+EQKBgGAFWzZwWxHXmXxPZxhHDstNFzxTemH3BEnteiPl
22 | z7FYWHaeBh0iuSdbBMRmKfnsRxHaGi/43uJ/en6hQZskWiphL50CCu7I7aIt0YUJ
23 | y8Yj0vARwZe9t3kZ7xdLIIWsrcbDOZQ/i8xWnxS9B3ivAbFmbjNdHhwTKqV+xZ0S
24 | MyTjAoGACgtKzGhSdZYgJX4tapuRV+f0rcvN6dvKMleWtShIfA+g0hJwzOmMe1hw
25 | JFAbK433pXQYvnpCYbajPm9wcjygvDotjvjdw5d1nIEBJP6VuLux10OwKxno4+1c
26 | HgGhCcrueb9mfwUQ3FvooJ1bsSkkU+8qV25GwQjk11vL9qyDA8I=
27 | -----END RSA PRIVATE KEY-----
28 |
--------------------------------------------------------------------------------
/_support/pxmysql-compose/conf.d/ca.pem:
--------------------------------------------------------------------------------
1 | -----BEGIN CERTIFICATE-----
2 | MIICyjCCAbICCQC/DWKJCY0kvjANBgkqhkiG9w0BAQsFADAnMQswCQYDVQQGEwJE
3 | RTEYMBYGA1UEAwwPcHhteXNxbC10ZXN0LUNBMB4XDTIyMTIxMzA2MzE1NFoXDTMy
4 | MTAyMTA2MzE1NFowJzELMAkGA1UEBhMCREUxGDAWBgNVBAMMD3B4bXlzcWwtdGVz
5 | dC1DQTCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEBALnFU7N+ZrL2IAvY
6 | runpVt2l85J8VtwpkGrIIw+1ZLn0+ODoBv3HtdB46d0Kw0ql6lwqVQcYvG8Vh8wp
7 | 0sLUoeiu/x46AvGpHlc7/RyKo1FAhKUFk/Xmsgxu4T66tcCfB3OpK4dYf/3UJQzU
8 | HlBvrrob5JwPNx9NcsiSkHFgTdpjC2bgTWqTRAHHbFo6ua60cY+AdCII2tYOxOF0
9 | J2H8vH2PlHNGvBtcLdNft37Uq6WtPoyboNHOkvB9eE4A4xrRqgytsJsuEFeK3Cl6
10 | pDvRNRNRofC5hpHsQsOPaYHro84fE6r3VOaNo8Yd2Ej22VpDX4pahNkq2AIhdl/f
11 | SzUZWL8CAwEAATANBgkqhkiG9w0BAQsFAAOCAQEARq0IJN06ajIEkHclKa8eIjOH
12 | HuYH1c38cJsXQN1BXE9Sv2EalrTRrwXSsAwHvKhRClbJeodlLXbC+C4gX+fywX/m
13 | Oc5SvJNGdUZyT1pqF0wPjIfYx9aO9sXsh/yE8ePwhxAeRU2i68QW6wB+AB6AnTIl
14 | op1v8gtzXHaLsUdiDD++98LEHEbYiUElnQBKwG2mzt223gCIIMlfUIHF6Hd2APKg
15 | ChAkUS2buFx93+b1ik4l63HAFqXue1X0Q8W4eMubbU0Ti2AkkLtOfwlTLrSEOOtf
16 | rPMMVMfwpIeq8tkrVNLm0zrKUSXP2cu1BmffZUoaewz/5yv5eybK00h4iqldwQ==
17 | -----END CERTIFICATE-----
18 |
--------------------------------------------------------------------------------
/_support/pxmysql-compose/conf.d/ca.srl:
--------------------------------------------------------------------------------
1 | 5DE6BF0D6CC6DAF943F4306AE1EA712AD007F2DF
2 |
--------------------------------------------------------------------------------
/_support/pxmysql-compose/conf.d/client-cert.pem:
--------------------------------------------------------------------------------
1 | -----BEGIN CERTIFICATE-----
2 | MIICzjCCAbYCCQC8fGqNl/yoxjANBgkqhkiG9w0BAQsFADAnMQswCQYDVQQGEwJE
3 | RTEYMBYGA1UEAwwPcHhteXNxbC10ZXN0LUNBMB4XDTIyMTIxMzA2MzE1NFoXDTMy
4 | MTAyMTA2MzE1NFowKzELMAkGA1UEBhMCREUxHDAaBgNVBAMME3B4bXlzcWwtdGVz
5 | dC1jbGllbnQwggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQC/hOzDqY4I
6 | 85p8VK1rtqkSgqNAg3tSixRuucwZRjW3wEuNbO1ZxW0hzVqXGxNp/hLzlFaa/YOe
7 | 3oQTDdztdUOysDVVUnnytx/8uUtMgcsYrMw8RWPZbazCwwLPW0v3KHigqyC+3A2y
8 | voRYMfUeS4MR2AWnGzJSxS5wM0Pyd2VX31slapFZdG/bDxFKolQDa9nkEG5WWoj2
9 | pLBPajDdXkYequUQvmRURQkv0lkIy6TX16Y3dmwepgbX2lIROQy8TcwY+iisge+O
10 | w4IEW2XHCZLr7deIzcdU3Q50ttFJCYsThn2hs0zHAzk6gccSMdPWcChpQfU0f0hV
11 | whe9bjTs3W/zAgMBAAEwDQYJKoZIhvcNAQELBQADggEBABVInbT5lKQ40C9frZHs
12 | IEhKDWtkXuy6pQAdTZu+ZasMR9RJx5DtJVbijs7DRvXu8Zbiiv7GpuVUEGxGqZBs
13 | CiA122aIr4KX36SfC4/apA127WbcYI1a6jBkYU+B8shXRcCeD23MaNlMZ05U8odd
14 | YFG/FHBaiJ3+atrqCkuqXl9fXCRWBWlClhxuzS+16RU6uQoqboFwZLobeN8zzVX7
15 | hw7qDkS/6yWGtHxDB7yBPil+cj3IKqQ5pEnrpufpKBN2yiB6Rrh6dNVZdfaSUutV
16 | rrei0lpWQBAfhe/3On/pXfuUtadWZbKR0EGItghVYoHwyh0Eq1zvPJpopt016qVm
17 | sPM=
18 | -----END CERTIFICATE-----
19 |
--------------------------------------------------------------------------------
/_support/pxmysql-compose/conf.d/client-key.pem:
--------------------------------------------------------------------------------
1 | -----BEGIN RSA PRIVATE KEY-----
2 | MIIEpQIBAAKCAQEAv4Tsw6mOCPOafFSta7apEoKjQIN7UosUbrnMGUY1t8BLjWzt
3 | WcVtIc1alxsTaf4S85RWmv2Dnt6EEw3c7XVDsrA1VVJ58rcf/LlLTIHLGKzMPEVj
4 | 2W2swsMCz1tL9yh4oKsgvtwNsr6EWDH1HkuDEdgFpxsyUsUucDND8ndlV99bJWqR
5 | WXRv2w8RSqJUA2vZ5BBuVlqI9qSwT2ow3V5GHqrlEL5kVEUJL9JZCMuk19emN3Zs
6 | HqYG19pSETkMvE3MGPoorIHvjsOCBFtlxwmS6+3XiM3HVN0OdLbRSQmLE4Z9obNM
7 | xwM5OoHHEjHT1nAoaUH1NH9IVcIXvW407N1v8wIDAQABAoIBAQCOX/LjQhkk7nPa
8 | GdkSSihGaneSbiwvoNT/u3/PCjLE918zM9b+9ZW7mz3NN4OnOAo+qff4IJ7IbAMj
9 | ZxrmLFa3b+c2FqoxlZFh/x3LMnIZVdw+shcYfEACSZa9L9G5W4zRZGZjfJNyXc9l
10 | AT6H1vsJON566+ztO0jagEHy7m+YcliHTzwzod/9melbto7AwedlbJNut/j3aFVz
11 | 6wN5NkaQu7C/NiNQ26N9QDnNSVoYlFOB9opCNVmLcHK/hdSNjNcOB3UpKePQXBYo
12 | GKkN/XtuZBSVhw9d6tEub+C5oBOOQ3EJuX5WwST/aj1cgB6HGlSahmie0r94ZWxX
13 | OC4lOxYBAoGBAPGa/V/1bvJVhq65sM543nuC4tMZUgOb5mQp+oumtPplnG3RPNP7
14 | np9+xoAaUz55KYkf74gC4Ux+10A/qXwrzJ9YIX8+ZpLNJZwWfQUmheZOXowfP0HL
15 | 9oMBRXmy1ivRUnToqR6fkSr+yHVJ3AJ/oy0MDVNrjs0YWYIu+tNvt+YjAoGBAMru
16 | Aug7l8HKgQi/Wg1ylpLi4Xl3UcqiHSW+soctcjrJm1z10pih92+Hqcg21wbFAR0P
17 | ESAslcD8OlCJYblSXi0zXE1eI111iCxcF/X9CyY0o+BsBIiJA+QDtIUwIfvk2Ka1
18 | nx0rkVpQ1H3ilDm4BpgMm0SFdCXDj+e2cgHpjSPxAoGBAPD9YeJHU4UQ3iiGO9+X
19 | HIQiR9G8ndvPs30RikGl5TsmA2ReosfnYY9BywmYOJRGErIeUrRd+xBsLJR/a7TZ
20 | k18Vb0QWoAWp7uvEWqu6gzD31sL5oAUnRxnhOMVtJsfKIO9P6vEKxKgYPycOpw8u
21 | 9TpHnTsqO+RDd3StG6+u7cX1AoGBAKAWdfqpEIZT78lr02nqbPkBvShqxf6aN25Q
22 | a1ySsJvJ8iO61eGNXLsChiEpiiaQAdnfyf3czmMJWCOyzYI6hYsZCocKbdHL55o/
23 | KLPpZQNF4cYo0Ma5eHVHqwCrQRQLrBKQEy8a8LcULx4EQjTqhWEsCM1cjo1AIuWE
24 | G5qAmdSxAoGAe43Q/b3WFTLorGORzZ0ro9VILYp1iqAm2A2k4Fms1OiZP1MXksXU
25 | +kLG0ovtpGbxnzqd/ybyZWWWTr6q06aTXAhAzttgM1f2twk66oXbIpbMO1/RG/h5
26 | 13d4xXJFsDDASwH41cmcj92/sxUbRg4comNKmMUkJPGQt5EssVJwCJ4=
27 | -----END RSA PRIVATE KEY-----
28 |
--------------------------------------------------------------------------------
/_support/pxmysql-compose/conf.d/server-cert.pem:
--------------------------------------------------------------------------------
1 | -----BEGIN CERTIFICATE-----
2 | MIIC6DCCAdCgAwIBAgIJALx8ao2X/KjFMA0GCSqGSIb3DQEBCwUAMCcxCzAJBgNV
3 | BAYTAkRFMRgwFgYDVQQDDA9weG15c3FsLXRlc3QtQ0EwHhcNMjIxMjEzMDYzMTU0
4 | WhcNMzIxMDIxMDYzMTU0WjArMQswCQYDVQQGEwJERTEcMBoGA1UEAwwTcHhteXNx
5 | bC10ZXN0LXNlcnZlcjCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEBAK0d
6 | K4okwsm61i5KX200bAeSzBujCLKp+TBr3PxhziSEL0i4ib904ipAwqYKdLnMiuzk
7 | QFGXNo0UriSwBeiSZDCTWPUGyxic8iI+pP+83OvxP20DWNaxizP3QxwdLDnsXqZm
8 | KQriYYBIsAul6bc5MnzVRRUVkVdH8xc/LvVfsqYkEEnpsbM1hQY2vsOhRvs0hnJa
9 | 3Q6HaFFLpmtwmmrnostIhAw79zyDec7D0naXy9+zhAPV9AwI4yieBxIs+6zdO+dR
10 | BmTNN4yrY8sVsRrqHVVs/KMiAOM7xzUVSlL7g6PepNaTFno6jZYzDVZu00I/l9KL
11 | jsZyYAWzCX0RntaGqnUCAwEAAaMTMBEwDwYDVR0RBAgwBocEfwAAATANBgkqhkiG
12 | 9w0BAQsFAAOCAQEApccqby/++1A7EFv5+jTnRVgnB3fQ1POzPbeBADye35mJNmcU
13 | Ks7qCj22gGRNO/INW71j2GBtyaTzT0FtZytKIMWv3kC6kgtVYo3r6cfOQihD+T3o
14 | G1PNKLicGgmfaSyHhtTA6WK3CLAQu/yS4KARIWi3X88H7zzBfKHZ23+rNGzILtGv
15 | 3RiyP4hPD+y1a4nObBjcSM5F09qRZzQpcpwldrZGSZrbrqL9uqeE4V7403sFuqNt
16 | IycSNEchF7vAm/0oqbYsuRQkjZyk77OiDxhWrtWSDkUdx6v+CAaCme3DD9SK/0l/
17 | vJhrCW2HUmXnrYEWVkzX6mverbTDAQ5v4I/sAg==
18 | -----END CERTIFICATE-----
19 |
--------------------------------------------------------------------------------
/_support/pxmysql-compose/conf.d/server-key.pem:
--------------------------------------------------------------------------------
1 | -----BEGIN RSA PRIVATE KEY-----
2 | MIIEpgIBAAKCAQEArR0riiTCybrWLkpfbTRsB5LMG6MIsqn5MGvc/GHOJIQvSLiJ
3 | v3TiKkDCpgp0ucyK7ORAUZc2jRSuJLAF6JJkMJNY9QbLGJzyIj6k/7zc6/E/bQNY
4 | 1rGLM/dDHB0sOexepmYpCuJhgEiwC6XptzkyfNVFFRWRV0fzFz8u9V+ypiQQSemx
5 | szWFBja+w6FG+zSGclrdDodoUUuma3Caaueiy0iEDDv3PIN5zsPSdpfL37OEA9X0
6 | DAjjKJ4HEiz7rN0751EGZM03jKtjyxWxGuodVWz8oyIA4zvHNRVKUvuDo96k1pMW
7 | ejqNljMNVm7TQj+X0ouOxnJgBbMJfRGe1oaqdQIDAQABAoIBAQCQN13fTvKrZigq
8 | FjFbY7GfuY6qc266kNmUmjdWVhCK4UgW+A1hX3lOo/bEpq9JXfpakWh30FZUv+a3
9 | j6DMeLBYu1f/gLJPheg92RxSJL+TG76wDXrEGNKT7yiMUk1Wz/CmBTOp6qA5Y9St
10 | T4Hd7xt9XZqYjwguwzTjp/Jx3lCRENlzAPBgIeXmEvAFWdT9XNaI4PcDBR9uGCp7
11 | HOg9GyGrsPdExKDQtukrJzNKu/63oFrQKbibks3cVWwtKEq0Xyz/tt5I4yACG8AB
12 | DFCrmA+e04C8OzliGupOI4GRrQhYxNOdNSxqJU04x9q1JyqHfZ0Nf9pevdYKA9Z7
13 | rwK5Vp79AoGBANNBbhdV0hjqjMNO99Nds9Rscg30dfojbxUxulcl/N4OAE5UpkmC
14 | vFwNu0ktfMz8OBlgkaW9lSRS6KsCtnh/cLNKB14VErM58J1qx0/9PUJSPbNipuVn
15 | hk0ZV3TGgxV1mYqZhNerKaUPJ7ZNPCe9h0lh6OEV/UlQiq7QErAMwwpjAoGBANHH
16 | pcXwPSMjtf6pXHaDF0es6DH8dQB2ZpsaiJwm6AbMJvwFfJEaYApnEzTEtoPahD5i
17 | 24iPW9IVBERKJaTAZe4QYwd0gaQYFPmMEPJZtCyD9cwtOUgbIVMSsCraYhLXg9ou
18 | 6QRpaXdERiLzmffO+QUkfPYofggONuuX3YW9s+NHAoGBAJyk0JAnB7GIAbY0kNi+
19 | i0CA5RVp5i0DJzQM+oHyXgz9TsbGR8MMWMTdPbkmLHsGrkZK79R4veUAQRvE2C6D
20 | OLsIsmvVrlcNKFhhO8cZHNpXhv7DsMM7vz7eApZJOBuqZp559SHB/hAxK54mqOtC
21 | wtTr77UvC+/X8+1pxeGapOjHAoGBANAeW64mCuFjulituReyMlRfi/SbW5Bb5quW
22 | BVW1m5eyzjJVVyG1ovZvEDTXu6LQFUa3WMkAQL4JL7R4QyRR5E3sX/KzeTJM2fJB
23 | LUbiC8fmGuK3Mw8AK215KuE4yveabCr3QyGnWoSCbXqbZnLdGVwquPaVcYOYZpAQ
24 | mCro6yBdAoGBAMYcSrtG8T4/UsN0j7vyjRm0+VKDci4zjvzpKR5O8597tnvY8R63
25 | b60rkqogT0x/yTD7qbgBgJ41Z4o3W7WilVt10E/eA4sStdhepQ6NPXQnPKIwMMof
26 | ZkVAH0mGvUuYf1lFn0DKidBe823nS/ezdhckjrZc8iIAPbz+frSZnHCq
27 | -----END RSA PRIVATE KEY-----
28 |
--------------------------------------------------------------------------------
/_support/pxmysql-compose/conf.srl:
--------------------------------------------------------------------------------
1 | BC7C6A8D97FCA8C6
2 |
--------------------------------------------------------------------------------
/_support/pxmysql-compose/docker-compose.yml:
--------------------------------------------------------------------------------
1 | networks:
2 | pxmysql.test.net:
3 | driver: bridge
4 |
5 | volumes:
6 | shared:
7 |
8 | services:
9 | mysql:
10 | container_name: pxmysql.test.db
11 | image: mysql:8.0
12 | environment:
13 | MYSQL_ROOT_PASSWORD: rootpwd
14 | command: --default-authentication-plugin=mysql_native_password
15 | networks:
16 | - pxmysql.test.net
17 | volumes:
18 | - ./data:/var/lib/mysql
19 | - ./conf.d:/etc/mysql/conf.d
20 | - ./shared:/shared:ro
21 | ports:
22 | - 127.0.0.1:53306:3306
23 | - 127.0.0.1:53360:33060
24 |
25 | go:
26 | container_name: pxmysql.test.go
27 | image: golang:1.21-bullseye
28 | tty: true
29 | stdin_open: true
30 | volumes:
31 | - ./shared:/shared
32 | - ../../../pxmysql:/go/src/github.com/golistic/pxmysql:ro
33 |
34 |
--------------------------------------------------------------------------------
/_support/pxmysql-compose/generate_tls.sh:
--------------------------------------------------------------------------------
#!/usr/bin/env sh
#
# Copyright (c) 2022, Geert JM Vanderkelen
#

set -e

#
# This script generates the MySQL Server Certificate Authority (CA).
# It also generates server and client certificates.
#
# The server certificate has X.509 subjectAltName extension with value
# of 'IP:127.0.0.1' (see server_x509_ext.conf).
#
# All paths are relative to the directory containing this script; we change
# into it first so the script can be invoked from any working directory.
#

cd "$(dirname "$0")"

OUT_DIR=conf.d
CA_KEY="${OUT_DIR}/ca-key.pem"
CA_CERT="${OUT_DIR}/ca.pem"
SERVER_KEY="${OUT_DIR}/server-key.pem"
SERVER_REQ="${OUT_DIR}/server-req.pem"
SERVER_CERT="${OUT_DIR}/server-cert.pem"
CLIENT_KEY="${OUT_DIR}/client-key.pem"
CLIENT_REQ="${OUT_DIR}/client-req.pem"
CLIENT_CERT="${OUT_DIR}/client-cert.pem"
DAYS=3600

# With `set -e`, a bare assignment from a failing command substitution would
# abort the script before we can print a helpful error; `|| true` prevents that.
OPENSSL=$(command -v openssl) || true
if [ "${OPENSSL}" = "" ]; then
  echo "Error: openssl command not available in path" >&2
  exit 1
fi

# Accept OpenSSL 1.1 and 3.x or later, and LibreSSL 3.3 or later, matching
# what the error message below promises.
v=$("${OPENSSL}" version)
case "${v}" in
  "OpenSSL 1.1"* | "OpenSSL "[3-9].*) ;;
  "LibreSSL 3."[3-9]* | "LibreSSL "[4-9].*) ;;
  *)
    echo "Error: expecting OpenSSL v1.1 or greater, or LibreSSL v3.3 or greater (got ${v})" >&2
    exit 1
    ;;
esac

# Create CA certificate
"${OPENSSL}" genrsa -out "${CA_KEY}" 2048
"${OPENSSL}" req -new \
  -subj "/C=DE/CN=pxmysql-test-CA" \
  -x509 -days "${DAYS}" -nodes -key "${CA_KEY}" -out "${CA_CERT}"

# Create server certificate (removing passphrase)
"${OPENSSL}" genrsa -out "${SERVER_KEY}" 2048
"${OPENSSL}" req -new -nodes \
  -subj "/C=DE/CN=pxmysql-test-server" \
  -key "${SERVER_KEY}" -out "${SERVER_REQ}"
"${OPENSSL}" x509 -req -days "${DAYS}" -in "${SERVER_REQ}" \
  -CA "${CA_CERT}" -CAkey "${CA_KEY}" -CAcreateserial \
  -extfile server_x509_ext.conf \
  -out "${SERVER_CERT}"

# Create client certificate (removing passphrase)
"${OPENSSL}" genrsa -out "${CLIENT_KEY}" 2048
"${OPENSSL}" req -new -nodes \
  -subj "/C=DE/CN=pxmysql-test-client" \
  -key "${CLIENT_KEY}" -out "${CLIENT_REQ}"
"${OPENSSL}" x509 -req -days "${DAYS}" -in "${CLIENT_REQ}" \
  -CA "${CA_CERT}" -CAkey "${CA_KEY}" -CAcreateserial \
  -out "${CLIENT_CERT}"

# The signing requests are only intermediate artifacts.
rm -f "${SERVER_REQ}" "${CLIENT_REQ}"
69 |
--------------------------------------------------------------------------------
/_support/pxmysql-compose/server_x509_ext.conf:
--------------------------------------------------------------------------------
1 | subjectAltName=IP:127.0.0.1
--------------------------------------------------------------------------------
/_support/pxmysql-compose/shared/build.sh:
--------------------------------------------------------------------------------
#!/bin/sh
#
# Builds the Go application located in /shared/goapps/<name> and writes the
# resulting binary to /shared/builds/<name>.
#
# Usage: build.sh <name>
#

# Stop at the first failing step (e.g. `go mod tidy`) instead of continuing
# with a build that is bound to fail or be stale.
set -e

if [ -z "$1" ]; then
  echo "Usage: $0 <app-name>" >&2
  exit 1
fi

out="/shared/builds/$1"
main="/shared/goapps/$1"

# Treat pxmysql as a private module so the go command does not consult the
# public module proxy or checksum database for it.
export GOPRIVATE="github.com/golistic/pxmysql"

cd "${main}" || exit 1
go mod tidy
mkdir -p "$(dirname "${out}")"
go build -o "${out}" .
11 |
--------------------------------------------------------------------------------
/_support/pxmysql-compose/shared/goapps/unix_socket/go.mod:
--------------------------------------------------------------------------------
1 | module example.com/unix_socket
2 |
3 | go 1.21
4 |
5 | replace github.com/golistic/pxmysql => /go/src/github.com/golistic/pxmysql
6 |
7 | require github.com/golistic/pxmysql v1.0.0
8 |
9 | require (
10 | github.com/golistic/xgo v1.0.0 // indirect
11 | golang.org/x/exp v0.0.0-20230817173708-d852ddb80c63 // indirect
12 | google.golang.org/protobuf v1.31.0 // indirect
13 | )
14 |
--------------------------------------------------------------------------------
/_support/pxmysql-compose/shared/goapps/unix_socket/go.sum:
--------------------------------------------------------------------------------
1 | github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk=
2 | github.com/golistic/xgo v1.0.0 h1:NtwCDhGBJR0vOvn0l8S8z9wzmN3hYRM8Wwwd1+2BFO0=
3 | github.com/golistic/xgo v1.0.0/go.mod h1:em3spZJ1b8mrGv1P5Fx2Ewclaf0rqIfSwrzYIVnL6o4=
4 | github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
5 | github.com/google/go-cmp v0.5.8 h1:e6P7q2lk1O+qJJb4BtCQXlK8vWEO8V1ZeuEdJNOqZyg=
6 | github.com/google/go-cmp v0.5.8/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
7 | golang.org/x/exp v0.0.0-20230817173708-d852ddb80c63 h1:m64FZMko/V45gv0bNmrNYoDEq8U5YUhetc9cBWKS1TQ=
8 | golang.org/x/exp v0.0.0-20230817173708-d852ddb80c63/go.mod h1:0v4NqG35kSWCMzLaMeX+IQrlSnVE/bqGSyC2cz/9Le8=
9 | golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
10 | google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
11 | google.golang.org/protobuf v1.31.0 h1:g0LDEJHgrBl9N9r17Ru3sqWhkIx2NB67okBHPwC7hs8=
12 | google.golang.org/protobuf v1.31.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
13 |
--------------------------------------------------------------------------------
/_support/pxmysql-compose/shared/goapps/unix_socket/main.go:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2022, Geert JM Vanderkelen
2 |
3 | /*
4 | This application is executed within the MySQL container by the pxmysql
5 | Go tests.
6 |
7 | The MySQL sock-file within the Docker container cannot be accessed through
8 | a Docker volume. Copying in an application which uses pxmysql is therefore
9 | the only way to automate testing of UNIX socket file support.
10 | */
11 |
12 | package main
13 |
14 | import (
15 | "context"
16 | "database/sql"
17 | "fmt"
18 | "log"
19 |
20 | _ "github.com/golistic/pxmysql/register"
21 | )
22 |
// main opens a pxmysql connection over the UNIX socket file, pings the
// server, and prints the result of SELECT VERSION() to standard output.
// Any failure terminates the process through log.Fatalln.
func main() {
	const (
		driverName = "pxmysql"
		// Credentials are the ones used when running the pxmysql tests.
		dsn = "root:rootpwd@unix(/var/lib/mysql/mysqlx.sock)"
	)

	conn, openErr := sql.Open(driverName, dsn)
	if openErr != nil {
		log.Fatalln("open:", openErr)
	}

	if pingErr := conn.Ping(); pingErr != nil {
		log.Fatalln("ping:", pingErr)
	}

	ctx := context.Background()
	row := conn.QueryRowContext(ctx, "SELECT VERSION()")

	var serverVersion string
	if scanErr := row.Scan(&serverVersion); scanErr != nil {
		log.Fatalln("query row:", scanErr)
	}

	fmt.Println(serverVersion)
}
41 |
--------------------------------------------------------------------------------
/cmd/.gitignore:
--------------------------------------------------------------------------------
1 | scratch/
--------------------------------------------------------------------------------
/cmd/gencollations/main.go:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2022, 2023, Geert JM Vanderkelen
2 |
3 | package main
4 |
5 | import (
6 | "context"
7 | "flag"
8 | "fmt"
9 | "go/format"
10 | "os"
11 | "time"
12 |
13 | "github.com/golistic/xgo/xstrings"
14 | "golang.org/x/term"
15 |
16 | "github.com/golistic/pxmysql/xmysql"
17 | )
18 |
// appFlags holds the values of the command line flags.
type appFlags struct {
	addr     string // address (host:port) of the MySQL X Plugin
	username string // MySQL user used to connect
}
23 |
24 | func main() {
25 | flags := initFlags()
26 |
27 | fmt.Printf("Password for %s (enter for empty): ", flags.username)
28 | p, err := term.ReadPassword(0)
29 | if err != nil {
30 | fmt.Printf("Error: failed reading password (%s)\n", err)
31 | os.Exit(1)
32 | }
33 | fmt.Println()
34 |
35 | config := &xmysql.ConnectConfig{
36 | Address: flags.addr,
37 | Username: flags.username,
38 | Password: xstrings.Pointer(string(p)),
39 | UseTLS: true,
40 | }
41 |
42 | ses, err := xmysql.GetSession(context.Background(), config)
43 | if err != nil {
44 | fmt.Printf("Error: %s\n", err)
45 | os.Exit(1)
46 | }
47 |
48 | q := "SELECT VERSION()"
49 | res, err := ses.ExecuteStatement(context.Background(), q)
50 | if err != nil {
51 | fmt.Printf("Error: %s\n", err)
52 | os.Exit(1)
53 | }
54 | version := res.Rows[0].Values[0].(string)
55 |
56 | charset := "utf8mb4"
57 | q = "SELECT id, collation_name AS name " +
58 | "FROM information_schema.collations WHERE character_set_name = ? ORDER BY id"
59 | res, err = ses.ExecuteStatement(context.Background(), q, charset)
60 | if err != nil {
61 | fmt.Printf("Error: %s\n", err)
62 | os.Exit(1)
63 | }
64 |
65 | content := fmt.Sprintf("// Code generated by gencollations. DO NOT EDIT.\n"+
66 | "// MySQL v%s; generated at %s\n\n"+
67 | "package xmysql\n\nvar Collations = map[string]Collation{\n", version, time.Now().UTC())
68 |
69 | entry := "\"%s\": {ID: %d, Name: \"%s\", CharSet: \"%s\",},\n"
70 |
71 | entryID := "%d: \"%s\",\n"
72 |
73 | var ids string
74 | for _, r := range res.Rows {
75 | v := r.Values
76 | content += fmt.Sprintf(entry, v[1].(string), v[0].(uint64), v[1].(string), charset)
77 | ids += fmt.Sprintf(entryID, v[0].(uint64), v[1].(string))
78 | }
79 |
80 | content += "}\n"
81 |
82 | content += "\nvar collationIDs = map[uint64]string{\n" + ids + "\n}"
83 |
84 | data, err := format.Source([]byte(content))
85 | if err != nil {
86 | fmt.Printf("Error: failed formatting Go code (%s)\n", err)
87 | os.Exit(1)
88 | }
89 |
90 | fn := "xmysql/collations_data.go"
91 | fp, err := os.OpenFile(fn, os.O_WRONLY|os.O_TRUNC, 0666)
92 | if err != nil {
93 | fmt.Printf("Error: failed opening %s for writing (%s)\n", fn, err)
94 | os.Exit(1)
95 | }
96 | defer func() { _ = fp.Close() }()
97 |
98 | if _, err := fp.Write(data); err != nil {
99 | if err != nil {
100 | fmt.Printf("Error: failed writing to %s (%s)\n", fn, err)
101 | os.Exit(1)
102 | }
103 | }
104 | }
105 |
106 | func initFlags() *appFlags {
107 | f := &appFlags{}
108 | flag.StringVar(&f.addr, "address", "localhost:33060",
109 | "address (host with port) of the MySQL Server's X Plugin")
110 | flag.StringVar(&f.username, "user", "root",
111 | "user for the connection to the MySQL Server")
112 |
113 | flag.Parse()
114 | return f
115 | }
116 |
--------------------------------------------------------------------------------
/cmd/genprotobuf/main.go:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2022, Geert JM Vanderkelen
2 |
3 | package main
4 |
5 | import (
6 | "bytes"
7 | "fmt"
8 | "io"
9 | "net/http"
10 | "os"
11 | "os/exec"
12 | "path"
13 | "strings"
14 | "time"
15 |
16 | "github.com/golistic/xgo/xos"
17 | )
18 |
const (
	// mysqlVersion is the MySQL Server release whose X Protocol definitions
	// are downloaded.
	mysqlVersion = "8.0.34"

	// baseURL is the location of the X Protocol .proto files of mysqlVersion.
	baseURL = "https://raw.githubusercontent.com/mysql/mysql-server/mysql-" +
		mysqlVersion + "/plugin/x/protocol/protobuf/"

	// pkgxmysql is the Go module path of this repository.
	pkgxmysql = "github.com/golistic/pxmysql"

	// nrOfFilesAtLeast is the number of .proto files expected at minimum; it is
	// used when collecting the list of files to download.
	nrOfFilesAtLeast = 5
)

// protoPath is the directory, relative to the repository root, where .proto
// files are stored and the Go packages are generated.
var protoPath = path.Join("internal", "mysqlx")
29 |
30 | func main() {
31 | if err := generate(); err != nil {
32 | exitWithErr(err)
33 | }
34 | }
35 |
// exitWithErr writes err, prefixed with "Error:", to standard output and
// terminates the program with exit status 1.
func exitWithErr(err error) {
	fmt.Printf("Error: %v\n", err)
	os.Exit(1)
}
40 |
41 | func checkExecLocation() (string, error) {
42 | d, err := os.Getwd()
43 | if err != nil {
44 | return "", fmt.Errorf("failed getting working directory (%w)", err)
45 | }
46 |
47 | needles := []string{".git", protoPath}
48 |
49 | for _, n := range needles {
50 | if !xos.IsDir(path.Join(d, n)) {
51 | return "", fmt.Errorf("must execute within root of xmysql repository")
52 | }
53 | }
54 |
55 | return d, nil
56 | }
57 |
58 | func fetchFile(name string) ([]byte, error) {
59 | u := baseURL + name
60 | resp, err := http.Get(u)
61 | if err != nil {
62 | return nil, fmt.Errorf("failed opening URL downloading %s (%w)", name, err)
63 | }
64 | if resp.StatusCode != http.StatusOK {
65 | return nil, fmt.Errorf("failed opening URL downloading %s (HTTP status %d)", name, resp.StatusCode)
66 | }
67 | defer func() { _ = resp.Body.Close() }()
68 |
69 | body, err := io.ReadAll(resp.Body)
70 | if err != nil {
71 | return nil, fmt.Errorf("failed reading body downloading file %s (%w)", name, err)
72 | }
73 |
74 | return body, nil
75 | }
76 |
77 | func protoFiles() ([]string, error) {
78 | fileData, err := fetchFile("source_files.cmake")
79 | if err != nil {
80 | return nil, err
81 | }
82 |
83 | files := make([]string, 0, nrOfFilesAtLeast)
84 | for _, l := range bytes.Split(fileData, []byte("\n")) {
85 | l = bytes.TrimSpace(l)
86 | if len(l) == 0 || l[0] == '#' ||
87 | !(bytes.HasPrefix(l, []byte("mysqlx")) && bytes.HasSuffix(l, []byte(".proto"))) {
88 | continue
89 | }
90 | files = append(files, string(l))
91 | }
92 |
93 | return files, nil
94 | }
95 |
96 | func downloadFile(dir, filename string) error {
97 | fileData, err := fetchFile(filename)
98 | if err != nil {
99 | return err
100 | }
101 | fp, err := os.OpenFile(path.Join(dir, filename), os.O_CREATE|os.O_WRONLY, 0666)
102 | if err != nil {
103 | return fmt.Errorf("failed opening file %s (%w)", filename, err)
104 | }
105 | if _, err := fp.Write(fileData); err != nil {
106 | _ = fp.Close()
107 | return fmt.Errorf("failed writing to file %s (%w)", filename, err)
108 | }
109 | _ = fp.Close()
110 |
111 | return nil
112 | }
113 |
114 | func generate() error {
115 | wd, err := checkExecLocation()
116 | if err != nil {
117 | return err
118 | }
119 |
120 | protoc, err := exec.LookPath("protoc")
121 | if err != nil {
122 | return fmt.Errorf("protoc executable not available")
123 | }
124 |
125 | files, err := protoFiles()
126 | if err != nil {
127 | return err
128 | }
129 |
130 | args := []string{protoc, "--proto_path=" + protoPath,
131 | "--go_out=.",
132 | "--go_opt=paths=import",
133 | "--go_opt=module=github.com/golistic/pxmysql",
134 | }
135 |
136 | for _, f := range files {
137 | if err := downloadFile(protoPath, f); err != nil {
138 | return err
139 | }
140 |
141 | m := strings.Replace(f, ".proto", "", 1)
142 | m = strings.Replace(m, "_", "", -1)
143 | args = append(args, fmt.Sprintf("--go_opt=M%s=%s/%s/%s", f, pkgxmysql, protoPath, m))
144 | }
145 |
146 | args = append(args, files...)
147 |
148 | output := bytes.NewBuffer(nil)
149 |
150 | cmd := exec.Cmd{
151 | Dir: wd,
152 | Path: protoc,
153 | Args: args,
154 | Stdout: output,
155 | Stderr: output,
156 | }
157 |
158 | err = cmd.Run()
159 | switch err.(type) {
160 | case *exec.ExitError:
161 | return fmt.Errorf("execution of protoc failed: %s", output.String())
162 | case error:
163 | return fmt.Errorf("could not run protoc")
164 | }
165 |
166 | for _, f := range files {
167 | _ = os.Remove(path.Join(protoPath, f))
168 | }
169 |
170 | infoFile := path.Join(protoPath, "info.md")
171 | fp, err := os.OpenFile(infoFile, os.O_CREATE|os.O_WRONLY, 0666)
172 | if err != nil {
173 | return fmt.Errorf("failed opening file %s (%w)", infoFile, err)
174 | }
175 | defer func() { _ = fp.Close() }()
176 |
177 | _, err = fp.WriteString(fmt.Sprintf("Generated from MySQL Server %s at %s.\n", mysqlVersion,
178 | time.Now().UTC().Format(time.RFC3339)))
179 | if err != nil {
180 | return fmt.Errorf("failed writing to file %s (%w)", infoFile, err)
181 | }
182 |
183 | return nil
184 | }
185 |
--------------------------------------------------------------------------------
/cmd/make/main.go:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2023, Geert JM Vanderkelen
2 |
3 | package main
4 |
5 | import (
6 | "github.com/golistic/gomake"
7 | )
8 |
9 | func main() {
10 | gomake.RegisterTargets(
11 | &gomake.TargetBadges,
12 | &gomake.TargetGoLint,
13 | )
14 | gomake.Make()
15 | }
16 |
--------------------------------------------------------------------------------
/connection.go:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2022, Geert JM Vanderkelen
2 |
3 | package pxmysql
4 |
5 | import (
6 | "context"
7 | "database/sql/driver"
8 | "fmt"
9 |
10 | "github.com/golistic/pxmysql/xmysql"
11 | )
12 |
13 | type connection struct {
14 | cfg *xmysql.ConnectConfig
15 | session *xmysql.Session
16 | }
17 |
18 | var (
19 | _ driver.Conn = (*connection)(nil)
20 | _ driver.ConnBeginTx = (*connection)(nil)
21 | _ driver.Pinger = (*connection)(nil)
22 | _ driver.ExecerContext = (*connection)(nil)
23 | _ driver.QueryerContext = (*connection)(nil)
24 | )
25 |
26 | func (c *connection) Prepare(query string) (driver.Stmt, error) {
27 |
28 | prep, err := c.session.PrepareStatement(context.Background(), query)
29 | if err != nil {
30 | return nil, err
31 | }
32 |
33 | s := &statement{
34 | prepared: prep,
35 | }
36 |
37 | return s, nil
38 | }
39 |
40 | func (c *connection) Close() error {
41 | if c.session != nil {
42 | return c.session.Close()
43 | }
44 | return nil
45 | }
46 |
47 | func (c *connection) Begin() (driver.Tx, error) {
48 | return c.BeginTx(context.Background(), driver.TxOptions{})
49 | }
50 |
51 | func (c *connection) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) {
52 | q := "START TRANSACTION"
53 | if opts.ReadOnly {
54 | q += q + " READ ONLY"
55 | }
56 |
57 | if _, err := c.session.ExecuteStatement(ctx, q); err != nil {
58 | return nil, err
59 | }
60 |
61 | return &Transaction{session: c.session}, nil
62 | }
63 |
64 | func (c *connection) Ping(ctx context.Context) error {
65 | if c.session == nil {
66 | return fmt.Errorf("not connected (%w)", driver.ErrBadConn)
67 | }
68 |
69 | if _, err := c.session.SessionID(ctx); err != nil {
70 | return fmt.Errorf("ping failed (%w)", err)
71 | }
72 |
73 | return nil
74 | }
75 |
76 | func (c *connection) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
77 | prep, err := c.session.PrepareStatement(context.Background(), query)
78 | if err != nil {
79 | return nil, handleError(err)
80 | }
81 |
82 | defer func() { _ = prep.Deallocate(ctx) }()
83 |
84 | stmt := &statement{
85 | prepared: prep,
86 | }
87 |
88 | return stmt.ExecContext(ctx, args)
89 | }
90 |
91 | func (c *connection) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
92 | prep, err := c.session.PrepareStatement(context.Background(), query)
93 | if err != nil {
94 | return nil, handleError(err)
95 | }
96 |
97 | defer func() { _ = prep.Deallocate(ctx) }()
98 |
99 | stmt := &statement{
100 | prepared: prep,
101 | }
102 |
103 | return stmt.QueryContext(ctx, args)
104 | }
105 |
--------------------------------------------------------------------------------
/connection_test.go:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2022, Geert JM Vanderkelen
2 |
3 | package pxmysql_test
4 |
5 | import (
6 | "context"
7 | "database/sql"
8 | "errors"
9 | "fmt"
10 | "testing"
11 | "time"
12 |
13 | "github.com/golistic/xgo/xt"
14 |
15 | "github.com/golistic/pxmysql/mysqlerrors"
16 | )
17 |
18 | func TestConnection_Begin(t *testing.T) {
19 | dsn := getTCPDSN()
20 | db, err := sql.Open("pxmysql", dsn)
21 | xt.OK(t, err)
22 | defer func() { _ = db.Close() }()
23 |
24 | tbl := "t29dkckiidk"
25 |
26 | _, err = db.Exec(fmt.Sprintf("DROP TABLE IF EXISTS `%s`", tbl))
27 | xt.OK(t, err)
28 | _, err = db.Exec(fmt.Sprintf("CREATE TABLE `%s` (id INT PRIMARY KEY, c1 INT)", tbl))
29 | xt.OK(t, err)
30 |
31 | t.Run("start transaction and commit", func(t *testing.T) {
32 | tx, err := db.Begin()
33 | xt.OK(t, err)
34 |
35 | id := 1
36 | exp := 123
37 |
38 | stmtInsert := fmt.Sprintf("INSERT INTO `%s` (id, c1) VALUES (?, ?)", tbl)
39 | result, err := tx.Exec(stmtInsert, id, exp)
40 | xt.OK(t, err)
41 | affected, err := result.RowsAffected()
42 | xt.OK(t, err)
43 | xt.Eq(t, 1, affected)
44 |
45 | xt.OK(t, tx.Commit())
46 |
47 | q := fmt.Sprintf("SELECT c1 FROM `%s` WHERE id = ?", tbl)
48 | var have int
49 | xt.OK(t, db.QueryRowContext(context.Background(), q, id).Scan(&have))
50 | xt.Eq(t, exp, have)
51 | })
52 |
53 | t.Run("start transaction and rollback", func(t *testing.T) {
54 | tx, err := db.Begin()
55 | xt.OK(t, err)
56 |
57 | id := 2
58 | value := 987
59 |
60 | stmtInsert := fmt.Sprintf("INSERT INTO `%s` (id, c1) VALUES (?, ?)", tbl)
61 | result, err := tx.Exec(stmtInsert, id, value)
62 | xt.OK(t, err)
63 | affected, err := result.RowsAffected()
64 | xt.OK(t, err)
65 | xt.Eq(t, 1, affected)
66 |
67 | xt.OK(t, tx.Rollback())
68 |
69 | q := fmt.Sprintf("SELECT c1 FROM `%s` WHERE id = ?", tbl)
70 | rows, err := db.QueryContext(context.Background(), q, id)
71 | xt.OK(t, err)
72 | xt.Assert(t, !rows.Next())
73 | })
74 | }
75 |
76 | func TestConnection_ExecContext(t *testing.T) {
77 | t.Run("respect timeout", func(t *testing.T) {
78 | dsn := getTCPDSN()
79 | db, err := sql.Open("pxmysql", dsn)
80 | xt.OK(t, err)
81 | defer func() { _ = db.Close() }()
82 |
83 | ctx := context.Background()
84 | ctx, cancel := context.WithTimeout(ctx, 100*time.Millisecond)
85 | defer cancel()
86 |
87 | _, err = db.ExecContext(ctx, "SELECT SLEEP(5)")
88 | xt.KO(t, err)
89 | xt.Assert(t, errors.Is(err, mysqlerrors.ErrContextDeadlineExceeded), err.Error())
90 | })
91 |
92 | t.Run("prepared statement should close using Query", func(t *testing.T) {
93 | dsn := getTCPDSN()
94 | db, err := sql.Open("pxmysql", dsn)
95 | xt.OK(t, err)
96 | defer func() { _ = db.Close() }()
97 |
98 | needle := "SDFIciwkdixks"
99 | stmt := "/* " + needle + " */ SELECT COUNT(*) FROM performance_schema.prepared_statements_instances WHERE SQL_TEXT LIKE ?"
100 | needleParam := "%" + needle + "%"
101 |
102 | for i := 0; i < 2; i++ {
103 | var got int
104 | xt.OK(t, db.QueryRowContext(context.Background(), stmt, needleParam).Scan(&got))
105 | xt.Eq(t, 1, got) // 1 because the query is seeing itself
106 | }
107 | })
108 |
109 | t.Run("prepared statement should close using Exec", func(t *testing.T) {
110 | dsn := getTCPDSN()
111 | db, err := sql.Open("pxmysql", dsn)
112 | xt.OK(t, err)
113 | defer func() { _ = db.Close() }()
114 |
115 | needle := "owicIOwidols"
116 | stmt := "/* " + needle + " */ DELETE FROM mysql.user WHERE user = 'nobody'"
117 | needleParam := "%" + needle + "%"
118 |
119 | selectStmt := "SELECT COUNT(*) FROM performance_schema.prepared_statements_instances WHERE SQL_TEXT LIKE ?"
120 |
121 | for i := 0; i < 2; i++ {
122 | var got int
123 | _, err := db.ExecContext(context.Background(), stmt)
124 | xt.OK(t, err)
125 | xt.OK(t, db.QueryRowContext(context.Background(), selectStmt, needleParam).Scan(&got))
126 | xt.Eq(t, 0, got)
127 | }
128 | })
129 | }
130 |
131 | func TestConnection_QueryContext(t *testing.T) {
132 | dsn := getTCPDSN()
133 | db, err := sql.Open("pxmysql", dsn)
134 | xt.OK(t, err)
135 | defer func() { _ = db.Close() }()
136 |
137 | t.Run("respect timeout", func(t *testing.T) {
138 | ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
139 | defer cancel()
140 |
141 | _, err := db.QueryContext(ctx, "SELECT SLEEP(5)")
142 | xt.KO(t, err)
143 | xt.Assert(t, errors.Is(err, mysqlerrors.ErrContextDeadlineExceeded), err.Error())
144 | })
145 | }
146 |
147 | func TestConnector_Connect(t *testing.T) {
148 | t.Run("server closing stale connection and reconnect", func(t *testing.T) {
149 | dsn := getTCPDSN()
150 | db, err := sql.Open("pxmysql", dsn)
151 | xt.OK(t, err)
152 |
153 | _, err = db.Exec("SET @@SESSION.mysqlx_wait_timeout = 2")
154 | xt.OK(t, err)
155 |
156 | var cnxID int
157 | xt.OK(t, db.QueryRow("SELECT CONNECTION_ID()").Scan(&cnxID))
158 |
159 | var n string
160 | var v string
161 | xt.OK(t, db.QueryRow("SHOW SESSION VARIABLES LIKE 'mysqlx_wait_timeout'").Scan(&n, &v))
162 |
163 | time.Sleep(3 * time.Second) // server should close connection
164 |
165 | var cnxIDAfter int
166 | xt.OK(t, db.QueryRow("SELECT CONNECTION_ID()").Scan(&cnxIDAfter))
167 |
168 | xt.Assert(t, cnxID != cnxIDAfter)
169 | })
170 | }
171 |
--------------------------------------------------------------------------------
/connector.go:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2022, Geert JM Vanderkelen
2 |
3 | package pxmysql
4 |
5 | import (
6 | "context"
7 | "database/sql/driver"
8 |
9 | "github.com/golistic/pxmysql/xmysql"
10 | )
11 |
12 | type connector struct {
13 | dataSource DataSource
14 | }
15 |
16 | var _ driver.Connector = &connector{}
17 |
18 | func (c connector) Connect(ctx context.Context) (driver.Conn, error) {
19 |
20 | // dataSource at this point is valid
21 | config := &xmysql.ConnectConfig{
22 | UseTLS: c.dataSource.UseTLS,
23 | AuthMethod: xmysql.AuthMethodAuto,
24 | Username: c.dataSource.User,
25 | Schema: c.dataSource.Schema,
26 | }
27 | config.SetPassword(c.dataSource.Password)
28 |
29 | switch c.dataSource.Protocol {
30 | case "unix":
31 | config.UnixSockAddr = c.dataSource.Address
32 | case "tcp":
33 | config.Address = c.dataSource.Address
34 | }
35 |
36 | ses, err := xmysql.GetSession(ctx, config)
37 | if err != nil {
38 | return nil, err
39 | }
40 |
41 | return &connection{
42 | cfg: config,
43 | session: ses,
44 | }, nil
45 | }
46 |
47 | func (c connector) Driver() driver.Driver {
48 | return &Driver{}
49 | }
50 |
--------------------------------------------------------------------------------
/datasource.go:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2022, Geert JM Vanderkelen
2 |
3 | package pxmysql
4 |
5 | import (
6 | "fmt"
7 |
8 | "github.com/golistic/xgo/xconv"
9 | "github.com/golistic/xgo/xsql"
10 | )
11 |
12 | // DataSource defines the configuration of the connection. It embeds xsql.DataSource
13 | // and extends it with attributes defined in the Options.
14 | type DataSource struct {
15 | xsql.DataSource
16 |
17 | UseTLS bool
18 | }
19 |
20 | // NewDataSource instantiates a DataSource using the Data Source Name (DSN).
21 | func NewDataSource(name string) (DataSource, error) {
22 | xds, err := xsql.ParseDSN(name)
23 | if err != nil {
24 | return DataSource{}, err
25 | }
26 |
27 | ds := DataSource{
28 | DataSource: *xds,
29 | UseTLS: false,
30 | }
31 |
32 | if err := ds.handleOptions(); err != nil {
33 | return DataSource{}, err
34 | }
35 |
36 | if err := ds.CheckValidity(); err != nil {
37 | return DataSource{}, fmt.Errorf("configuration not valid (%w)", err)
38 | }
39 |
40 | return ds, nil
41 | }
42 |
43 | // CheckValidity returns whether the DataSource has enough configuration to establish
44 | // a connection. Needed are the address, protocol, and username.
45 | func (ds *DataSource) CheckValidity() error {
46 | switch {
47 | case ds.Address == "":
48 | return fmt.Errorf("address missing")
49 | case ds.User == "":
50 | return fmt.Errorf("user missing")
51 | case ds.Protocol == "":
52 | return fmt.Errorf("protocol missing")
53 | default:
54 | return nil
55 | }
56 | }
57 |
58 | func (ds *DataSource) handleOptions() error {
59 | var err error
60 | useTLS := ds.Options.Get("useTLS")
61 | if useTLS != "" {
62 | ds.UseTLS, err = xconv.ParseBool(useTLS)
63 | if err != nil {
64 | return fmt.Errorf("invalid value for useTLS option (was %s)", useTLS)
65 | }
66 | }
67 |
68 | return nil
69 | }
70 |
--------------------------------------------------------------------------------
/datasource_test.go:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2023, Geert JM Vanderkelen
2 |
3 | package pxmysql
4 |
5 | import (
6 | "errors"
7 | "testing"
8 |
9 | "github.com/golistic/xgo/xsql"
10 | "github.com/golistic/xgo/xt"
11 | )
12 |
13 | func TestDataSource_IsZero(t *testing.T) {
14 | t.Run("zero when username missing", func(t *testing.T) {
15 | _, err := NewDataSource(":pwd@tcp(127.0.0.1)")
16 | xt.KO(t, err)
17 | xt.Eq(t, "user missing", errors.Unwrap(err).Error())
18 | })
19 |
20 | t.Run("zero when address missing", func(t *testing.T) {
21 | ds := DataSource{
22 | DataSource: xsql.DataSource{
23 | Driver: "pxmysql",
24 | User: "user",
25 | Password: "",
26 | Protocol: "tcp",
27 | Address: "",
28 | Schema: "",
29 | Options: nil,
30 | },
31 | }
32 | err := ds.CheckValidity()
33 | xt.KO(t, err)
34 | xt.Eq(t, "address missing", err.Error())
35 | })
36 |
37 | t.Run("protocol missing", func(t *testing.T) {
38 | ds := DataSource{
39 | DataSource: xsql.DataSource{
40 | Driver: "pxmysql",
41 | User: "user",
42 | Password: "",
43 | Protocol: "",
44 | Address: "127.0.0.1",
45 | Schema: "",
46 | Options: nil,
47 | },
48 | }
49 | err := ds.CheckValidity()
50 | xt.KO(t, err)
51 | xt.Eq(t, "protocol missing", err.Error())
52 | })
53 | }
54 |
55 | func TestNewDataSource(t *testing.T) {
56 | t.Run("invalid useTLS option value", func(t *testing.T) {
57 | _, err := NewDataSource("user:pwd@tcp(127.0.0.1)/?useTLS=nope")
58 | xt.KO(t, err)
59 | xt.Eq(t, "invalid value for useTLS option (was nope)", err.Error())
60 | })
61 | }
62 |
--------------------------------------------------------------------------------
/driver.go:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2022, 2023, Geert JM Vanderkelen
2 |
3 | package pxmysql
4 |
5 | import (
6 | "context"
7 | "database/sql/driver"
8 | )
9 |
10 | type Driver struct{}
11 |
12 | var (
13 | _ driver.Driver = &Driver{}
14 | _ driver.DriverContext = &Driver{}
15 | )
16 |
17 | // Open returns a new connection to the MySQL database using MySQL X Protocol.
18 | func (d *Driver) Open(name string) (driver.Conn, error) {
19 | c, err := d.OpenConnector(name)
20 | if err != nil {
21 | return nil, err
22 | }
23 |
24 | return c.Connect(context.Background())
25 | }
26 |
27 | // OpenConnector returns a connector which will be used by sql.DB to open a connection
28 | // to the MySQL database using MySQL X Protocol.
29 | // This will be used instead of the Open-method (which actually uses this method).
30 | func (d *Driver) OpenConnector(name string) (driver.Connector, error) {
31 | ds, err := NewDataSource(name)
32 | if err != nil {
33 | return nil, err
34 | }
35 |
36 | return &connector{
37 | dataSource: ds,
38 | }, nil
39 | }
40 |
--------------------------------------------------------------------------------
/driver_test.go:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2022, Geert JM Vanderkelen
2 |
3 | package pxmysql_test
4 |
5 | import (
6 | "database/sql"
7 | "errors"
8 | "fmt"
9 | "os"
10 | "testing"
11 |
12 | "github.com/golistic/xgo/xstrings"
13 | "github.com/golistic/xgo/xt"
14 |
15 | "github.com/golistic/pxmysql"
16 | "github.com/golistic/pxmysql/internal/xxt"
17 | "github.com/golistic/pxmysql/mysqlerrors"
18 | "github.com/golistic/pxmysql/register"
19 | )
20 |
21 | func TestSQLDriver_Open(t *testing.T) {
22 | pwd := "aPassword"
23 | users := []string{"userfkEivks", "userFcae283"}
24 |
25 | for _, u := range users {
26 | _ = testContext.Server.DropUser(u)
27 | xt.OK(t, testContext.Server.CreateUser(u, pwd, testSchema, xxt.AuthPluginNative))
28 | }
29 |
30 | defer func() {
31 | for _, u := range users {
32 | _ = testContext.Server.DropUser(u)
33 | }
34 | }()
35 |
36 | t.Run("valid data source names", func(t *testing.T) {
37 | var cases = map[string]string{
38 | "no query": fmt.Sprintf("%s:%s@tcp(%s)/%s",
39 | users[0], pwd, testContext.XPluginAddr, testSchema),
40 | "no schema": fmt.Sprintf("%s:%s@tcp(%s)/?useTLS=true",
41 | users[1], pwd, testContext.XPluginAddr),
42 | }
43 |
44 | for cn, dsn := range cases {
45 | t.Run(cn, func(t *testing.T) {
46 | drv := &pxmysql.Driver{}
47 | _, err := drv.Open(dsn)
48 | xt.OK(t, err)
49 | })
50 | }
51 | })
52 |
53 | t.Run("retrieve LastInsertID after insert", func(t *testing.T) {
54 | tbl := "test_AFiek23eeF"
55 | q := fmt.Sprintf("DROP TABLE IF EXISTS `%s`", tbl)
56 | _, err := testContext.Server.ExecSQLStmt(q)
57 | xt.OK(t, err)
58 |
59 | _, err = testContext.Server.ExecSQLStmt("CREATE TABLE " + tbl +
60 | " (id INT AUTO_INCREMENT PRIMARY KEY, t1 INT)")
61 | xt.OK(t, err)
62 |
63 | dsn := getTCPDSN("", "")
64 | db, err := sql.Open("pxmysql", dsn)
65 |
66 | t.Run("using Prepared Statement", func(t *testing.T) {
67 | xt.OK(t, err)
68 | q := fmt.Sprintf("INSERT INTO `%s` (`t1`) VALUES (?)", tbl)
69 | stmt, err := db.Prepare(q)
70 | xt.OK(t, err)
71 | res, err := stmt.Exec("45")
72 | xt.OK(t, err)
73 | lastID, err := res.LastInsertId()
74 | xt.OK(t, err)
75 | xt.Eq(t, 1, lastID)
76 | })
77 |
78 | t.Run("executing INSERT directly", func(t *testing.T) {
79 | xt.OK(t, err)
80 | q := fmt.Sprintf("INSERT INTO `%s` (`t1`) VALUES (?)", tbl)
81 | res, err := db.Exec(q, 46)
82 | xt.OK(t, err)
83 | lastID, err := res.LastInsertId()
84 | xt.OK(t, err)
85 | xt.Eq(t, 2, lastID)
86 | })
87 | })
88 |
89 | t.Run("using Unix socket", func(t *testing.T) {
90 | // runs app within Container; will not add to coverage
91 | app := "unix_socket"
92 | _, err := testContext.Builder.App(app)
93 | xt.OK(t, err)
94 |
95 | out, err := testContext.Server.ExecApp("/shared/builds/" + app)
96 | xt.OK(t, err)
97 | xt.Eq(t, testContext.Server.Version, string(out))
98 | })
99 |
100 | t.Run("unsupported protocol", func(t *testing.T) {
101 | drv := &pxmysql.Driver{}
102 | _, err := drv.Open("scott:tiger@UDP(localhost)/")
103 | xt.KO(t, err)
104 | xt.Eq(t, "unsupported protocol 'UDP'", errors.Unwrap(err).Error())
105 | })
106 |
107 | t.Run("not enough configured with missing username", func(t *testing.T) {
108 | drv := &pxmysql.Driver{}
109 | _, err := drv.Open(":tiger@tcp(localhost)/")
110 | xt.KO(t, err)
111 | xt.Eq(t, "configuration not valid (user missing)", err.Error())
112 | xt.Eq(t, "user missing", errors.Unwrap(err).Error())
113 | })
114 | }
115 |
116 | func TestConnection_Ping(t *testing.T) {
117 | t.Run("using TCP", func(t *testing.T) {
118 | db, err := sql.Open(register.DriverName, getTCPDSN())
119 | defer func() { _ = db.Close() }()
120 | xt.OK(t, err)
121 | xt.OK(t, db.Ping())
122 | xt.Eq(t, "tcp", cnxType(t, db))
123 | })
124 |
125 | t.Run("using non-existing Unix socket", func(t *testing.T) {
126 | drv := &pxmysql.Driver{}
127 | os.TempDir()
128 | _, err := drv.Open("username:pwd@unix(_testdata/mysqlx.sock)/myschema")
129 | xt.KO(t, err)
130 | xt.Eq(t, mysqlerrors.ClientBadUnixSocket, err.(*mysqlerrors.Error).Code)
131 | xt.KO(t, errors.Unwrap(err))
132 | xt.Eq(t, "no such file or directory", errors.Unwrap(err).Error())
133 | })
134 | }
135 |
136 | func TestDriver_Open(t *testing.T) {
137 | t.Run("pxmysql is registered", func(t *testing.T) {
138 | xt.Assert(t, xstrings.SliceHas(sql.Drivers(), "pxmysql"), "expected driver pxmysql to be registered")
139 | })
140 |
141 | t.Run("mysql is not registered", func(t *testing.T) {
142 | xt.Assert(t, !xstrings.SliceHas(sql.Drivers(), "mysql"), "expected driver mysql to be NOT registered")
143 | })
144 | }
145 |
--------------------------------------------------------------------------------
/errors.go:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2022, Geert JM Vanderkelen
2 |
3 | package pxmysql
4 |
5 | import (
6 | "errors"
7 | "os"
8 |
9 | "github.com/golistic/pxmysql/mysqlerrors"
10 | )
11 |
12 | func handleError(err error) error {
13 | if errors.Is(err, os.ErrDeadlineExceeded) {
14 | return mysqlerrors.ErrContextDeadlineExceeded
15 | }
16 |
17 | return err
18 | }
19 |
--------------------------------------------------------------------------------
/go.mod:
--------------------------------------------------------------------------------
1 | module github.com/golistic/pxmysql
2 |
3 | go 1.21
4 |
5 | toolchain go1.21.0
6 |
7 | require (
8 | github.com/golistic/gomake v0.9.4
9 | github.com/golistic/xgo v1.0.0
10 | golang.org/x/exp v0.0.0-20230817173708-d852ddb80c63
11 | golang.org/x/term v0.11.0
12 | google.golang.org/protobuf v1.31.0
13 | )
14 |
15 | require (
16 | github.com/golistic/shieldbadger v0.0.0-20230223210348-5649a4ba6aa9 // indirect
17 | golang.org/x/mod v0.12.0 // indirect
18 | golang.org/x/sys v0.11.0 // indirect
19 | )
20 |
--------------------------------------------------------------------------------
/go.sum:
--------------------------------------------------------------------------------
1 | github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk=
2 | github.com/golistic/gomake v0.9.4 h1:6cOBeAsbSnacsskMvyg49DjCRCE9WnyhJv4Viy9rF9M=
3 | github.com/golistic/gomake v0.9.4/go.mod h1:IiYQuN6aK4Jb3QLA/rfOpeOqG7s8DNfBzVXZt8Pf6PA=
4 | github.com/golistic/shieldbadger v0.0.0-20230223210348-5649a4ba6aa9 h1:NBSzSvgVJjhI37ytLLcHLkRA7UTR/P1L/0dm6EY4GoE=
5 | github.com/golistic/shieldbadger v0.0.0-20230223210348-5649a4ba6aa9/go.mod h1:9Ad43QCHXG87MxOV2rB668vLcEV/lC4qJZhs+VdEXJQ=
6 | github.com/golistic/xgo v1.0.0 h1:NtwCDhGBJR0vOvn0l8S8z9wzmN3hYRM8Wwwd1+2BFO0=
7 | github.com/golistic/xgo v1.0.0/go.mod h1:em3spZJ1b8mrGv1P5Fx2Ewclaf0rqIfSwrzYIVnL6o4=
8 | github.com/golistic/xt v1.0.1 h1:prcwpL757GEu+dj6x2v6vxHkulkFdyQ3S1PqO+6ohPM=
9 | github.com/golistic/xt v1.0.1/go.mod h1:j1ZuWefyOD4HegoapgSjbXanA7X9YOKUeEB5gsEPgYg=
10 | github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
11 | github.com/google/go-cmp v0.5.8 h1:e6P7q2lk1O+qJJb4BtCQXlK8vWEO8V1ZeuEdJNOqZyg=
12 | github.com/google/go-cmp v0.5.8/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
13 | golang.org/x/exp v0.0.0-20230817173708-d852ddb80c63 h1:m64FZMko/V45gv0bNmrNYoDEq8U5YUhetc9cBWKS1TQ=
14 | golang.org/x/exp v0.0.0-20230817173708-d852ddb80c63/go.mod h1:0v4NqG35kSWCMzLaMeX+IQrlSnVE/bqGSyC2cz/9Le8=
15 | golang.org/x/mod v0.12.0 h1:rmsUpXtvNzj340zd98LZ4KntptpfRHwpFOHG188oHXc=
16 | golang.org/x/mod v0.12.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
17 | golang.org/x/sys v0.11.0 h1:eG7RXZHdqOJ1i+0lgLgCpSXAp6M3LYlAo6osgSi0xOM=
18 | golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
19 | golang.org/x/term v0.11.0 h1:F9tnn/DA/Im8nCwm+fX+1/eBwi4qFjRT++MhtVC4ZX0=
20 | golang.org/x/term v0.11.0/go.mod h1:zC9APTIj3jG3FdV/Ons+XE1riIZXG4aZ4GTHiPZJPIU=
21 | golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
22 | google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
23 | google.golang.org/protobuf v1.31.0 h1:g0LDEJHgrBl9N9r17Ru3sqWhkIx2NB67okBHPwC7hs8=
24 | google.golang.org/protobuf v1.31.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
25 |
--------------------------------------------------------------------------------
/interfaces/message.go:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2022, Geert JM Vanderkelen
2 |
3 | package interfaces
4 |
5 | import (
6 | "google.golang.org/protobuf/proto"
7 |
8 | "github.com/golistic/pxmysql/internal/mysqlx/mysqlx"
9 | )
10 |
11 | type ServerMessager interface {
12 | Unmarshall(message proto.Message) error
13 | ServerMessageType() mysqlx.ServerMessages_Type
14 | }
15 |
--------------------------------------------------------------------------------
/internal/mysqlx/info.md:
--------------------------------------------------------------------------------
1 | Generated from MySQL Server 8.0.34 at 2023-08-09T15:38:05Z.
2 |
--------------------------------------------------------------------------------
/internal/xxt/builder.go:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2022, Geert JM Vanderkelen
2 |
3 | package xxt
4 |
5 | import (
6 | "fmt"
7 | "strings"
8 | )
9 |
10 | type GoBuilder struct {
11 | tctx *TestContext
12 | Container *Container
13 | }
14 |
15 | func NewGoBuilder(tctx *TestContext, container *Container) (*GoBuilder, error) {
16 | gb := &GoBuilder{
17 | Container: container,
18 | tctx: tctx,
19 | // Schema is stored at the end
20 | }
21 |
22 | _, err := gb.goVersion()
23 | if err != nil {
24 | return nil, fmt.Errorf("failed getting Go version (%w)", err)
25 | }
26 |
27 | return gb, nil
28 | }
29 |
30 | // App takes the application name which is located in the container's
31 | // shared volume located at "/shared".
32 | func (gb GoBuilder) App(name string) ([]byte, error) {
33 | args := []string{
34 | "exec", "-i", gb.Container.Name,
35 | "sh", "/shared/build.sh", name,
36 | }
37 |
38 | return gb.Container.run(args...)
39 | }
40 |
41 | func (gb GoBuilder) goVersion() (string, error) {
42 | args := []string{
43 | "exec", "-i", gb.Container.Name,
44 | "go", "version",
45 | }
46 |
47 | buf, err := gb.Container.run(args...)
48 | if err != nil {
49 | return "", err
50 | }
51 |
52 | version := string(buf)
53 | version = strings.Replace(version, "go version go", "", 1)
54 |
55 | return version, nil
56 | }
57 |
--------------------------------------------------------------------------------
/internal/xxt/context.go:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2022, Geert JM Vanderkelen
2 |
3 | package xxt
4 |
5 | import (
6 | "fmt"
7 | "os"
8 | "regexp"
9 | )
10 |
// reVersion matches a "major.minor.patch" version such as "8.0.34" in
// "8.0.34-debug". Both separators must be literal dots.
var reVersion = regexp.MustCompile(`(\d+)\.(\d+)\.(\d+)`)

// minMySQLVersion is the minimum supported MySQL version, encoded as
// major*1000000 + minor*1000 + patch. minMySQLVersionStr is the same
// version in human-readable form.
const (
	minMySQLVersion    = 8000028
	minMySQLVersionStr = "8.0.28"
)
15 |
const (
	// Names of the containers running the MySQL server and the Go toolchain.
	// These are defined in _support/pxmysql-compose/docker-compose.yml.
	defaultDockerContainer   = "pxmysql.test.db"
	defaultDockerContainerGo = "pxmysql.test.go"

	// Default Docker executable, addresses of the X Plugin and of the MySQL
	// server, and the MySQL root password. New lets environment variables
	// override each of them.
	defaultDockerExec         = "docker"
	defaultDockerXPluginAddr  = "127.0.0.1:53360"
	defaultDockerMySQLAddr    = "127.0.0.1:53306"
	defaultDockerMySQLRootPwd = "rootpwd"

	// Authentication plugins accepted by MySQLServer.CreateUser.
	AuthPluginNative     = "mysql_native_password"
	AuthPluginCachedSha2 = "caching_sha2_password"
)
33 |
34 | type TestContext struct {
35 | MySQLRootPwd string
36 | XPluginAddr string
37 | MySQLAddr string
38 | Server *MySQLServer
39 | Builder *GoBuilder
40 | }
41 |
42 | func New(schema string) (*TestContext, error) {
43 | dbContainerName := defaultDockerContainer
44 | if v, have := os.LookupEnv("PXMYSQL_TEST_DOCKER_CONTAINER"); have {
45 | dbContainerName = v
46 | }
47 |
48 | goContainerName := defaultDockerContainerGo
49 | if v, have := os.LookupEnv("PXMYSQL_TEST_DOCKER_CONTAINER_GO"); have {
50 | goContainerName = v
51 | }
52 |
53 | var err error
54 | tctx := &TestContext{
55 | MySQLRootPwd: defaultDockerMySQLRootPwd,
56 | XPluginAddr: defaultDockerXPluginAddr,
57 | MySQLAddr: defaultDockerMySQLAddr,
58 | }
59 |
60 | if v, have := os.LookupEnv("PXMYSQL_TEST_DOCKER_XPLUGIN_ADDR"); have {
61 | tctx.XPluginAddr = v
62 | }
63 |
64 | if v, have := os.LookupEnv("PXMYSQL_TEST_DOCKER_MYSQL_ADDR"); have {
65 | tctx.MySQLAddr = v
66 | }
67 |
68 | if v, have := os.LookupEnv("PXMYSQL_TEST_DOCKER_MYSQL_PWD"); have {
69 | tctx.MySQLRootPwd = v
70 | }
71 |
72 | dockerExec := defaultDockerExec
73 | if v, have := os.LookupEnv("PXMYSQL_TEST_DOCKER_EXEC"); have {
74 | dockerExec = v
75 | }
76 |
77 | dbContainer, err := NewContainer(dbContainerName, dockerExec)
78 | if err != nil {
79 | return nil, err
80 | }
81 |
82 | if err := dbContainer.CheckRunning(); err != nil {
83 | return nil, fmt.Errorf("make sure the Docker is available (set XMYSQL_TEST_DOCKER_EXEC?)"+
84 | " and container %s is running (%s)", dbContainerName, err)
85 | }
86 |
87 | goContainer, err := NewContainer(goContainerName, dockerExec)
88 | if err != nil {
89 | return nil, err
90 | }
91 |
92 | if err := goContainer.CheckRunning(); err != nil {
93 | return nil, fmt.Errorf("make sure the Docker is available (set XMYSQL_TEST_DOCKER_EXEC?)"+
94 | " and container %s is running (%s)", goContainerName, err)
95 | }
96 |
97 | tctx.Server, err = NewMySQLServer(tctx, dbContainer, schema)
98 | if err != nil {
99 | return nil, err
100 | }
101 |
102 | tctx.Builder, err = NewGoBuilder(tctx, goContainer)
103 | if err != nil {
104 | return nil, err
105 | }
106 |
107 | return tctx, nil
108 | }
109 |
--------------------------------------------------------------------------------
/internal/xxt/credentials.go:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2023, Geert JM Vanderkelen
2 |
3 | package xxt
4 |
// Names and passwords of the MySQL users used in tests.
// NOTE(review): the names suggest they are created with AuthPluginNative
// and AuthPluginCachedSha2 respectively. Confirm against the callers of
// MySQLServer.CreateUser.
const (
	UserNative    = "user_native"
	UserNativePwd = "pwd_user_native"

	UserCachedSHA256    = "user_sha256"
	UserCachedSHA256Pwd = "pwd_user_sha256"
)
14 |
--------------------------------------------------------------------------------
/internal/xxt/docker.go:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2022, Geert JM Vanderkelen
2 |
3 | package xxt
4 |
5 | import (
6 | "bytes"
7 | "fmt"
8 | "io"
9 | "os/exec"
10 | "strings"
11 | )
12 |
// Container identifies a Docker container by name, together with the Docker
// executable used to manage it.
type Container struct {
	// dockerExec is the name or path of the Docker executable;
	// see getDockerCmd for how it is resolved.
	dockerExec string
	// Name is the name of the container.
	Name string
}
17 |
18 | func NewContainer(name string, dockerExec string) (*Container, error) {
19 | return &Container{
20 | dockerExec: dockerExec,
21 | Name: name,
22 | }, nil
23 | }
24 |
25 | // CopyFileFromContainer copies a file from a Docker container.
26 | func (c Container) CopyFileFromContainer(srcPath, dstPath string) error {
27 | args := []string{
28 | "cp", c.Name + ":" + srcPath, dstPath,
29 | }
30 |
31 | if _, err := c.run(args...); err != nil {
32 | return err
33 | }
34 |
35 | return nil
36 | }
37 |
38 | // getDockerCmd searches for the Docker executable in the directories named by
39 | // the PATH environment variable.
40 | func (c Container) getDockerCmd(output io.Writer, args ...string) (*exec.Cmd, error) {
41 | if c.dockerExec == "" || c.dockerExec[0] != '/' {
42 | dockerExec, err := exec.LookPath("docker")
43 | if err != nil {
44 | return nil, err
45 | }
46 | c.dockerExec = dockerExec
47 | }
48 |
49 | if output == nil {
50 | output = io.Discard
51 | }
52 |
53 | return &exec.Cmd{
54 | Path: c.dockerExec,
55 | Args: append([]string{c.dockerExec}, args...),
56 | Stdout: output,
57 | Stderr: output,
58 | }, nil
59 | }
60 |
61 | // run executes the docker command using provided arguments.
62 | func (c Container) run(args ...string) ([]byte, error) {
63 | output := bytes.NewBuffer(nil)
64 |
65 | cmd, err := c.getDockerCmd(output, args...)
66 | if err != nil {
67 | return nil, err
68 | }
69 |
70 | err = cmd.Run()
71 | switch err.(type) {
72 | case *exec.ExitError:
73 | if err := getContainerExecError(output); err != nil {
74 | return nil, err
75 | }
76 |
77 | case error:
78 | return nil, err
79 | }
80 |
81 | buf, err := io.ReadAll(output)
82 | if err != nil {
83 | return nil, err
84 | }
85 |
86 | var res []byte
87 | for _, l := range strings.Split(string(buf), "\n") {
88 | if strings.Contains(l, "[Warning]") {
89 | continue
90 | }
91 |
92 | res = append(res, []byte(l)...)
93 | }
94 |
95 | return res, nil
96 | }
97 |
98 | // CheckRunning checks whether the container is running.
99 | func (c Container) CheckRunning() error {
100 | args := []string{
101 | "inspect", "-f", "'{{.State.Running}}'", c.Name,
102 | }
103 |
104 | _, err := c.run(args...)
105 | return err
106 | }
107 |
// getContainerExecError reads the output of a failed command from r and
// turns it into an error. The returned error is never nil.
//
// Lines containing "[Warning]" are skipped. The first line that looks like an
// error message is returned verbatim. Such a line starts with "Error " or
// "error:", or contains "ERROR ", "Error: " or "[ERROR]". A line starting
// with "OCI runtime exec failed" is returned without the
// "OCI runtime exec failed: " text. Otherwise the whole output is returned
// with its lines joined by "; ".
func getContainerExecError(r io.Reader) error {
	buf, err := io.ReadAll(r)
	if err != nil {
		return err
	}

	for _, l := range strings.Split(string(buf), "\n") {
		if strings.Contains(l, "[Warning]") {
			continue
		}

		// Output is always passed as an argument, never as a format string:
		// MySQL messages routinely contain '%' (e.g. 'user'@'%').
		if strings.HasPrefix(l, "Error ") ||
			strings.HasPrefix(l, "error:") ||
			strings.Contains(l, "ERROR ") ||
			strings.Contains(l, "Error: ") ||
			strings.Contains(l, "[ERROR]") {
			return fmt.Errorf("%s", l)
		}

		if strings.HasPrefix(l, "OCI runtime exec failed") {
			return fmt.Errorf("%s", strings.ReplaceAll(l, "OCI runtime exec failed: ", ""))
		}
	}

	msg := strings.ReplaceAll(strings.TrimRight(string(buf), "\n"), "\n", "; ")

	return fmt.Errorf("%s", msg)
}
136 |
--------------------------------------------------------------------------------
/internal/xxt/errors.go:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2022, Geert JM Vanderkelen
2 |
3 | package xxt
4 |
5 | import (
6 | "errors"
7 | "fmt"
8 | "testing"
9 |
10 | "github.com/golistic/xgo/xt"
11 |
12 | "github.com/golistic/pxmysql/mysqlerrors"
13 | )
14 |
// NewTestErr returns an error whose message is format interpolated with a.
// When err is not nil, its message is appended between parentheses and err is
// wrapped, so errors.Is and errors.As can find it. The message of err is
// never interpreted as a format string.
func NewTestErr(err error, format string, a ...any) error {
	if err == nil {
		return fmt.Errorf(format, a...)
	}

	return fmt.Errorf("%s (%w)", fmt.Sprintf(format, a...), err)
}
21 |
22 | func AssertMySQLError(t *testing.T, err error, code int) {
23 | t.Helper()
24 |
25 | xt.KO(t, err)
26 | var errMySQL *mysqlerrors.Error
27 | xt.Assert(t, errors.As(err, &errMySQL))
28 | xt.Eq(t, code, errMySQL.Code)
29 | }
30 |
--------------------------------------------------------------------------------
/internal/xxt/memory.go:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2022, Geert JM Vanderkelen
2 |
3 | package xxt
4 |
5 | import (
6 | "fmt"
7 | "runtime"
8 | )
9 |
// MemoryUse records Go runtime memory statistics at two moments so the
// growth of the allocated heap between them can be reported.
type MemoryUse struct {
	start *runtime.MemStats
	end   *runtime.MemStats
}

// NewMemoryUse runs a garbage collection and then records the starting
// statistics. Call Stop to record the ending statistics.
func NewMemoryUse() *MemoryUse {
	m := &MemoryUse{
		start: &runtime.MemStats{},
		end:   &runtime.MemStats{},
	}

	runtime.GC()
	runtime.ReadMemStats(m.start)
	return m
}

// Stop records the ending statistics.
func (m *MemoryUse) Stop() {
	runtime.ReadMemStats(m.end)
}

// DiffAlloc returns by how many bytes the allocated heap (MemStats.Alloc)
// grew between NewMemoryUse and Stop. It returns 0 when the heap shrank,
// instead of wrapping around as an unsigned subtraction would.
func (m *MemoryUse) DiffAlloc() uint64 {
	if m.end.Alloc < m.start.Alloc {
		return 0
	}
	return m.end.Alloc - m.start.Alloc
}

// String reports DiffAlloc together with the number of completed GC cycles
// (MemStats.NumGC) recorded by Stop.
func (m MemoryUse) String() string {
	return fmt.Sprintf("DiffAlloc = % 10d\tNumGC = % 5d\n", m.DiffAlloc(), m.end.NumGC)
}
37 |
--------------------------------------------------------------------------------
/internal/xxt/server.go:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2022, Geert JM Vanderkelen
2 |
3 | package xxt
4 |
5 | import (
6 | "bytes"
7 | "fmt"
8 | "io"
9 | "os"
10 | "path"
11 | "strconv"
12 | "strings"
13 | "sync"
14 |
15 | "github.com/golistic/xgo/xstrings"
16 | )
17 |
18 | type MySQLServer struct {
19 | tctx *TestContext
20 | Container *Container
21 | Schema string
22 | Version string
23 | }
24 |
25 | func NewMySQLServer(tctx *TestContext, container *Container, schema string) (*MySQLServer, error) {
26 | server := &MySQLServer{
27 | Container: container,
28 | tctx: tctx,
29 | // Schema is stored at the end
30 | }
31 |
32 | if _, err := server.ExecSQLStmt("DROP SCHEMA IF EXISTS " + schema); err != nil {
33 | return nil, err
34 | }
35 | if _, err := server.ExecSQLStmt("CREATE SCHEMA " + schema); err != nil {
36 | return nil, err
37 | }
38 |
39 | errMsg := "failed getting MySQL version running in container %s (%s)"
40 | if output, err := server.ExecSQLStmt("SELECT VERSION()"); err != nil {
41 | return nil, NewTestErr(err, errMsg, container.Name, err)
42 | } else {
43 | parts := reVersion.FindAllStringSubmatch(string(output), -1)
44 | if parts == nil {
45 | return nil, NewTestErr(err, errMsg, container.Name, "reVersion")
46 | }
47 |
48 | // simplistic way of checking the MySQL version
49 | vMaj, err := strconv.ParseInt(parts[0][1], 10, 64)
50 | if err != nil {
51 | return nil, NewTestErr(nil, errMsg, container.Name)
52 | }
53 |
54 | vMin, err := strconv.ParseInt(parts[0][2], 10, 64)
55 | if err != nil {
56 | return nil, NewTestErr(nil, errMsg, container.Name)
57 |
58 | }
59 |
60 | patch, err := strconv.ParseInt(parts[0][3], 10, 64)
61 | if err != nil {
62 | return nil, NewTestErr(nil, errMsg, container.Name)
63 | }
64 |
65 | v := vMaj*1000000 + vMin*1000 + patch
66 | if v < minMySQLVersion {
67 | return nil, NewTestErr(fmt.Errorf("MySQL version must be %s or greater", minMySQLVersionStr),
68 | errMsg, container.Name)
69 | }
70 |
71 | server.Version = fmt.Sprintf("%d.%d.%d", vMaj, vMin, patch)
72 | }
73 |
74 | server.Schema = schema
75 |
76 | return server, nil
77 | }
78 |
79 | // ExecSQLStmt executes the SQL stmt using the mysql CLI within the container.
80 | // This is not SQL-injection safe and is only used for testing.
81 | func (my MySQLServer) ExecSQLStmt(stmt string) ([]byte, error) {
82 | args := []string{
83 | "exec", "-i", my.Container.Name,
84 | "mysql", "-uroot", "-p" + my.tctx.MySQLRootPwd, "-NB", "-e", stmt,
85 | }
86 |
87 | if my.Schema != "" {
88 | args = append(args, []string{"-D", my.Schema}...)
89 | }
90 |
91 | return my.Container.run(args...)
92 | }
93 |
94 | // LoadSQLScript executes the statements from files provided as scripts
95 | // using the mysql CLI within the container.
96 | func (my MySQLServer) LoadSQLScript(scripts ...string) error {
97 | args := []string{
98 | "exec", "-i", my.Container.Name,
99 | "mysql", "-uroot", "-p" + my.tctx.MySQLRootPwd,
100 | }
101 |
102 | stderr := bytes.NewBuffer(nil)
103 |
104 | cmd, err := my.Container.getDockerCmd(stderr, args...)
105 | if err != nil {
106 | return err
107 | }
108 |
109 | stdin, err := cmd.StdinPipe()
110 | if err != nil {
111 | return err
112 | }
113 |
114 | wg := sync.WaitGroup{}
115 |
116 | var capturedErr error
117 | go func() {
118 | defer func() {
119 | _ = stdin.Close()
120 | wg.Done()
121 | }()
122 |
123 | for _, s := range scripts {
124 | if !strings.HasSuffix(s, ".sql") {
125 | s += ".sql"
126 | }
127 | p := path.Join("_testdata", s)
128 | sql, err := os.ReadFile(p)
129 | if err != nil {
130 | capturedErr = fmt.Errorf("failed reading SQL script %s (%s)", p, err)
131 | break
132 | }
133 |
134 | if _, err := io.WriteString(stdin, string(sql)); err != nil {
135 | capturedErr = fmt.Errorf("failed writing SQL script to STDIN %s (%s)", p, err)
136 | break
137 | }
138 | }
139 | }()
140 |
141 | wg.Add(1)
142 | if err := cmd.Run(); err != nil {
143 | if err := getContainerExecError(stderr); err != nil {
144 | capturedErr = err
145 | }
146 | }
147 |
148 | wg.Wait()
149 | return capturedErr
150 | }
151 |
152 | func (my MySQLServer) FlushPrivileges() error {
153 | args := []string{
154 | "exec", "-i", my.Container.Name,
155 | "mysqladmin", "-uroot", "-p" + my.tctx.MySQLRootPwd, "flush-privileges",
156 | }
157 |
158 | _, err := my.Container.run(args...)
159 | return err
160 | }
161 |
162 | func (my MySQLServer) Variable(scope, variable string) (string, error) {
163 | if !(scope == "global" || scope == "session") {
164 | panic("scope must be one of 'session' or 'global'")
165 | }
166 |
167 | output, err := my.ExecSQLStmt(fmt.Sprintf("SELECT @@%s.%s", scope, variable))
168 | if err != nil {
169 | return "", err
170 | }
171 |
172 | return string(output), nil
173 | }
174 |
175 | func (my MySQLServer) CreateUser(username, password, schema, authPlugin string) error {
176 | // this function is not SQL-injection-safe; only used for testing
177 |
178 | authPlugins := []string{AuthPluginNative, AuthPluginCachedSha2}
179 |
180 | if !xstrings.SliceHas(authPlugins, authPlugin) {
181 | panic("unsupported authMethod")
182 | }
183 |
184 | createUser := fmt.Sprintf("CREATE USER '%s'@'%%' IDENTIFIED WITH %s BY '%s'",
185 | username, authPlugin, password)
186 | grant := fmt.Sprintf("GRANT ALL ON %s.* TO '%s'@'%%'", schema, username)
187 |
188 | if _, err := my.ExecSQLStmt(createUser); err != nil {
189 | return err
190 | }
191 | if _, err := my.ExecSQLStmt(grant); err != nil {
192 | return err
193 | }
194 |
195 | return nil
196 | }
197 |
198 | func (my MySQLServer) DropUser(username string) error {
199 | // this function is not SQL-injection-safe; only used for testing
200 |
201 | dropUser := fmt.Sprintf("DROP USER IF EXISTS '%s'@'%%'", username)
202 |
203 | if _, err := my.ExecSQLStmt(dropUser); err != nil {
204 | return err
205 | }
206 | return nil
207 | }
208 |
209 | // ExecApp runs the application within the container found at path and returns its output.
210 | func (my MySQLServer) ExecApp(path string) ([]byte, error) {
211 | args := []string{
212 | "exec", "-i", my.Container.Name,
213 | path,
214 | }
215 |
216 | return my.Container.run(args...)
217 | }
218 |
--------------------------------------------------------------------------------
/main_test.go:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2022, Geert JM Vanderkelen
2 |
3 | package pxmysql_test
4 |
5 | import (
6 | "database/sql"
7 | "fmt"
8 | "os"
9 | "testing"
10 |
11 | "github.com/golistic/xgo/xt"
12 |
13 | _ "github.com/golistic/pxmysql/register" // registers pxmysql
14 |
15 | "github.com/golistic/pxmysql/internal/xxt"
16 | )
17 |
18 | var (
19 | testExitCode int
20 | testErr error
21 | testDockerContainer string
22 | testSchema = "pxmysqldriver_tests"
23 | testContext *xxt.TestContext
24 | )
25 |
26 | func testTearDown() {
27 | if testErr != nil {
28 | testExitCode = 1
29 | fmt.Println(testErr)
30 | }
31 | }
32 |
33 | func TestMain(m *testing.M) {
34 | defer func() { os.Exit(testExitCode) }()
35 | defer testTearDown()
36 |
37 | var err error
38 | if testContext, testErr = xxt.New(testSchema); err != nil {
39 | return
40 | }
41 |
42 | if err := testContext.Server.Container.CopyFileFromContainer(
43 | "/etc/mysql/conf.d/ca.pem", "_testdata/mysql_ca.pem"); err != nil {
44 | testErr = fmt.Errorf("failed copying MySQL CA certificate from container %s (%s)",
45 | testDockerContainer, err)
46 | return
47 | }
48 |
49 | testExitCode = m.Run()
50 | }
51 |
52 | func getCredentials(credentials ...string) (string, string) {
53 | username := "root"
54 | password := testContext.MySQLRootPwd
55 | if len(credentials) > 0 && credentials[0] != "" {
56 | username = credentials[0]
57 | }
58 | if len(credentials) > 1 && credentials[1] != "" {
59 | password = credentials[1]
60 | }
61 |
62 | return username, password
63 | }
64 |
65 | func getTCPDSN(credentials ...string) string {
66 | username, password := getCredentials(credentials...)
67 | return fmt.Sprintf("%s:%s@tcp(%s)/%s?useTLS=yes", username, password, testContext.XPluginAddr,
68 | testSchema)
69 | }
70 |
71 | func cnxType(t *testing.T, db *sql.DB) string {
72 | t.Helper()
73 |
74 | q := "SELECT IF(HOST='localhost', 'unix', 'tcp') As CnxType " +
75 | "FROM performance_schema.processlist WHERE ID = CONNECTION_ID()"
76 |
77 | var ct string
78 | err := db.QueryRow(q).Scan(&ct)
79 | xt.OK(t, err)
80 |
81 | return ct
82 | }
83 |
--------------------------------------------------------------------------------
/mysqlerrors/error.go:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2022, Geert JM Vanderkelen
2 |
3 | package mysqlerrors
4 |
5 | import (
6 | "errors"
7 | "fmt"
8 | "strings"
9 | "unicode"
10 |
11 | "github.com/golistic/pxmysql/interfaces"
12 | "github.com/golistic/pxmysql/internal/mysqlx/mysqlx"
13 | "github.com/golistic/pxmysql/internal/mysqlx/mysqlxnotice"
14 | )
15 |
// ErrContextDeadlineExceeded reports that a context deadline was exceeded.
var ErrContextDeadlineExceeded = errors.New("context deadline exceeded")
19 |
// Error holds the information of an error returned by MySQL, or raised by the
// client, and implements the Go error interface.
type Error struct {
	// Inner is the wrapped error, if any; see Unwrap.
	Inner error
	// Message is the error message. It may contain fmt verbs, which are
	// interpolated with Parameters.
	Message string
	// Code is the MySQL client or server error code.
	Code int
	// SQLState is the SQLSTATE value, for example "HY000".
	SQLState string
	// Severity is the severity reported by the server.
	Severity int
	// Parameters are the values interpolated into Message.
	Parameters []any
	// ExtraMessage, when not empty, is appended between parentheses.
	ExtraMessage string
}
31 |
32 | // New instantiates an Error object with code being the MySQL
33 | // client or server error code. When the message contains placeholders (is
34 | // a format specifier), params are used to interpolate them.
35 | // When one of the params is an error, it is used as a wrapped error.
36 | // Panics when params contains more than one value which is error-type.
37 | func New(code int, params ...any) *Error {
38 | e, have := mysqlClientErrors[code]
39 | if !have {
40 | panic(fmt.Sprintf("error code %d not registered in mysqlClientErrors", code))
41 | }
42 |
43 | for _, param := range params {
44 | if p, ok := param.(error); ok {
45 | if e.Inner != nil {
46 | panic("only one parameter can be of error type")
47 | }
48 | e.Inner = p
49 | }
50 | }
51 |
52 | e.Parameters = params
53 | return &e
54 | }
55 |
56 | func (e *Error) Unwrap() error {
57 | return e.Inner
58 | }
59 |
60 | // Error is the string representation of the error. Messages look the same, but
61 | // they are differently formatted than the MySQL Client message to conform Go
62 | // best practices.
63 | func (e *Error) Error() string {
64 | msg := e.Message
65 |
66 | if len(e.Parameters) > 0 {
67 | if e.Inner != nil {
68 | msg = fmt.Errorf(e.Message, e.Parameters...).Error()
69 | } else {
70 | msg = fmt.Sprintf(e.Message, e.Parameters...)
71 | }
72 | }
73 |
74 | if e.ExtraMessage != "" {
75 | msg += " (" + e.ExtraMessage + ")"
76 | }
77 |
78 | if len(msg) > 2 {
79 | r := []rune(msg)
80 | if !strings.HasPrefix(msg, "MySQL") {
81 | r[0] = unicode.ToLower(r[0])
82 | }
83 | msg = string(r)
84 | }
85 |
86 | return fmt.Sprintf("%s [%d:%s]", msg, e.Code, e.SQLState)
87 | }
88 |
89 | // NewFromServerMessage takes msg and transforms it into an Error.
90 | func NewFromServerMessage(msg interfaces.ServerMessager) error {
91 | myErr := &mysqlx.Error{}
92 |
93 | if err := msg.Unmarshall(myErr); err != nil {
94 | return err
95 | }
96 |
97 | return &Error{
98 | Message: myErr.GetMsg(),
99 | Code: int(myErr.GetCode()),
100 | SQLState: myErr.GetSqlState(),
101 | Severity: int(myErr.GetSeverity()),
102 | }
103 | }
104 |
// MySQLWarning holds a warning reported by MySQL and implements the Go error
// interface.
type MySQLWarning struct {
	Message string // warning text
	Code    int    // MySQL warning code
	Level   string // level name, for example "Warning"
}

// Error formats the warning the way MySQL shows warnings, for example
// "Warning 1265: Data truncated for column 'b' at row 1".
func (w *MySQLWarning) Error() string {
	const layout = "%s %d: %s"
	return fmt.Sprintf(layout, w.Level, w.Code, w.Message)
}
118 |
119 | // NewFromWarning takes msg and transforms it into a MySQLWarning.
120 | func NewFromWarning(msg *mysqlxnotice.Warning) error {
121 | return &MySQLWarning{
122 | Level: msg.Level.String(),
123 | Message: msg.GetMsg(),
124 | Code: int(msg.GetCode()),
125 | }
126 | }
127 |
--------------------------------------------------------------------------------
/mysqlerrors/error_client.go:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2022, Geert JM Vanderkelen
2 |
3 | package mysqlerrors
4 |
5 | // MySQL Client errors as found in the MySQL manual under
6 | // https://dev.mysql.com/doc/mysql-errors/8.0/en/client-error-reference.html.
7 | // Names have been altered so that they make more sense. For example,
8 | // CR_CONNECTION_ERROR became ClientBadUnixSocket.
9 | const (
10 | ClientUnknown = 2000
11 | ClientBadUnixSocket = 2002
12 | ClientBadTCPSocket = 2005
13 | ClientWrongProtocol = 2007
14 | ClientNetPacketTooLarge = 2020
15 | )
16 |
17 | var mysqlClientErrors = map[int]Error{
18 | ClientUnknown: { // 2000
19 | Message: "unknown error",
20 | Code: ClientUnknown,
21 | SQLState: "HY000",
22 | },
23 | ClientBadUnixSocket: { // 2002
24 | Message: "cannot connect to local MySQL server through socket '%s' (%w)",
25 | Code: ClientBadUnixSocket,
26 | SQLState: "HY000",
27 | },
28 | ClientBadTCPSocket: { // 2005
29 | Message: "unknown MySQL server host '%s' (%w)",
30 | Code: ClientBadTCPSocket,
31 | SQLState: "HY000",
32 | },
33 | ClientWrongProtocol: { // 2007
34 | Message: "wrong protocol",
35 | Code: ClientBadTCPSocket,
36 | SQLState: "HY000",
37 | },
38 | ClientNetPacketTooLarge: { // 2020
39 | Message: "got packet bigger than 'mysqlx_max_allowed_packet' bytes",
40 | Code: ClientNetPacketTooLarge,
41 | SQLState: "HY000",
42 | },
43 | }
44 |
--------------------------------------------------------------------------------
/mysqlerrors/error_test.go:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2022, Geert JM Vanderkelen
2 |
3 | package mysqlerrors
4 |
5 | import (
6 | "testing"
7 |
8 | "github.com/golistic/xgo/xt"
9 | )
10 |
11 | func TestMySLQError_Error(t *testing.T) {
12 | t.Run("formatted MySQL errors", func(t *testing.T) {
13 | myErr := &Error{
14 | Message: "Table 'test.no_such_table' doesn't exist", // shamelessly copy/pasted from MySQL docs
15 | Code: 1146,
16 | SQLState: "42S02",
17 | Severity: 1,
18 | }
19 |
20 | xt.Eq(t, "table 'test.no_such_table' doesn't exist [1146:42S02]", myErr.Error())
21 | })
22 | }
23 |
24 | func TestMySQLWarning_Error(t *testing.T) {
25 | t.Run("formatted as MySQL would do", func(t *testing.T) {
26 | myErr := &MySQLWarning{
27 | Message: "Data truncated for column 'b' at row 1", // shamelessly copy/pasted from MySQL docs
28 | Code: 1265,
29 | Level: "Warning",
30 | }
31 |
32 | xt.Eq(t, "Warning 1265: Data truncated for column 'b' at row 1", myErr.Error())
33 | })
34 | }
35 |
--------------------------------------------------------------------------------
/mysqlerrors/test_errors/mysqlerrors_test.go:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2023, Geert JM Vanderkelen
2 |
3 | package test_errors
4 |
5 | import (
6 | "context"
7 | "errors"
8 | "testing"
9 | "time"
10 |
11 | "github.com/golistic/xgo/xt"
12 |
13 | "github.com/golistic/pxmysql/xmysql"
14 | )
15 |
16 | func TestMySQLErrors(t *testing.T) {
17 | t.Run("wrapped errors", func(t *testing.T) {
18 | config := &xmysql.ConnectConfig{
19 | Address: "127.0.0.40",
20 | }
21 |
22 | ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
23 | defer cancel()
24 |
25 | _, err := xmysql.GetSession(ctx, config)
26 | xt.Eq(t, "unknown MySQL server host '127.0.0.40:33060' (i/o timeout) [2005:HY000]", err.Error())
27 | xt.Eq(t, "i/o timeout", errors.Unwrap(err).Error())
28 | })
29 | }
30 |
--------------------------------------------------------------------------------
/null/bytes.go:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2022, Geert JM Vanderkelen
2 |
3 | package null
4 |
5 | import (
6 | "bytes"
7 | "database/sql/driver"
8 | "fmt"
9 | )
10 |
11 | // Bytes represents a []byte (any MySQL BINARY types) that may be NULL.
12 | // This is not available in Go's sql package, and does not implement the Scanner interface.
13 | type Bytes struct {
14 | Bytes []byte
15 | Valid bool
16 | }
17 |
18 | var _ driver.Valuer = &Bytes{}
19 | var _ Nullable = &Bytes{}
20 |
21 | // Compare returns whether value compares with the nullable Bytes.
22 | // It returns:
23 | // - true when Valid and stored Bytes is equal to value
24 | // - true when not Valid and value is nil
25 | // - false in any other case
26 | func (n Bytes) Compare(value any) bool {
27 | if !n.Valid && value != nil {
28 | return false
29 | }
30 |
31 | if value == nil {
32 | return !n.Valid
33 | }
34 |
35 | switch v := value.(type) {
36 | case []byte:
37 | return bytes.Equal(n.Bytes, v)
38 | case *[]byte:
39 | return bytes.Equal(n.Bytes, *v)
40 | default:
41 | panic(fmt.Sprintf("value must be []byte or *[]byte; not %T", value))
42 | }
43 | }
44 |
45 | // Value returns the value of n and implements the driver.Valuer
46 | // as well as Nullable interface.
47 | func (n Bytes) Value() (driver.Value, error) {
48 | if !n.Valid {
49 | return nil, nil
50 | }
51 | return n.Bytes, nil
52 | }
53 |
--------------------------------------------------------------------------------
/null/bytes_test.go:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2022, Geert JM Vanderkelen
2 |
3 | package null
4 |
5 | import (
6 | "fmt"
7 | "testing"
8 |
9 | "github.com/golistic/xgo/xt"
10 | )
11 |
12 | func TestBytes_Compare(t *testing.T) {
13 | var cases = []struct {
14 | n Nullable
15 | value any
16 | exp bool
17 | }{
18 | {
19 | n: Bytes{Bytes: []byte("Sakila"), Valid: true},
20 | value: bytesPtr([]byte("Sakila")),
21 | exp: true,
22 | },
23 | {
24 | n: Bytes{Bytes: []byte("Go gopher"), Valid: true},
25 | value: bytesPtr([]byte("Sakila")), // supposed to be not 'Go gopher'
26 | exp: false,
27 | },
28 | {
29 | n: Bytes{Bytes: []byte("Sakila"), Valid: true},
30 | value: []byte("Sakila"),
31 | exp: true,
32 | },
33 | {
34 | n: Bytes{Bytes: []byte("Go gopher"), Valid: true},
35 | value: []byte("Sakila"), // supposed to be not 'Go gopher'
36 | exp: false,
37 | },
38 | {
39 | n: Bytes{Bytes: nil, Valid: false},
40 | value: nil,
41 | exp: true,
42 | },
43 | {
44 | n: Bytes{Bytes: nil, Valid: false},
45 | value: []byte{},
46 | exp: false,
47 | },
48 | }
49 |
50 | for _, c := range cases {
51 | t.Run("", func(t *testing.T) {
52 | xt.Eq(t, c.exp, c.n.Compare(c.value))
53 | })
54 | }
55 |
56 | t.Run("panics if value type is not supported", func(t *testing.T) {
57 | xt.Panics(t, func() {
58 | _ = Bytes{Bytes: nil, Valid: true}.Compare("str")
59 | })
60 | })
61 | }
62 |
63 | func TestBytes_Value(t *testing.T) {
64 | data := []byte("I am Data")
65 |
66 | t.Run("valid", func(t *testing.T) {
67 | nb := Bytes{Bytes: data, Valid: true}
68 | v, _ := nb.Value()
69 | d, ok := v.([]byte)
70 | xt.Assert(t, ok, fmt.Sprintf("expected []byte; got %T", v))
71 | xt.Eq(t, data, d)
72 | })
73 |
74 | t.Run("not valid", func(t *testing.T) {
75 | nb := Bytes{Bytes: data, Valid: false}
76 | v, _ := nb.Value()
77 | xt.Eq(t, nil, v, "expected nil")
78 | })
79 | }
80 |
--------------------------------------------------------------------------------
/null/decimal.go:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2022, Geert JM Vanderkelen
2 |
3 | package null
4 |
5 | import (
6 | "database/sql/driver"
7 | "fmt"
8 |
9 | "github.com/golistic/pxmysql/decimal"
10 | )
11 |
12 | // Decimal represents a decimal.Decimal (MySQL DECIMAL type) that may be NULL.
13 | // This is similar to types provided by Go's sql package, and does not implement the Scanner interface.
14 | type Decimal struct {
15 | Decimal decimal.Decimal
16 | Valid bool
17 | }
18 |
19 | var _ driver.Valuer = &Decimal{}
20 | var _ Nullable = &Decimal{}
21 |
22 | // Compare returns whether value compares with the nullable Decimal.
23 | // It returns:
24 | // - true when Valid and stored Decimal is equal to value
25 | // - true when not Valid and value is nil
26 | // - false in any other case
27 | func (nd Decimal) Compare(value any) bool {
28 | if !nd.Valid && value != nil {
29 | return false
30 | }
31 |
32 | if value == nil {
33 | return !nd.Valid
34 | }
35 |
36 | switch v := value.(type) {
37 | case decimal.Decimal:
38 | return nd.Decimal.Equal(v)
39 | case *decimal.Decimal:
40 | return nd.Decimal.Equal(*v)
41 | default:
42 | panic(fmt.Sprintf("value is of unsupported type %T", value))
43 | }
44 | }
45 |
46 | // Value returns the value of n and implements the driver.Valuer
47 | // as well as Nullable interface.
48 | func (nd Decimal) Value() (driver.Value, error) {
49 | if !nd.Valid {
50 | return nil, nil
51 | }
52 | return nd.Decimal, nil
53 | }
54 |
--------------------------------------------------------------------------------
/null/decimal_test.go:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2022, Geert JM Vanderkelen
2 |
3 | package null
4 |
5 | import (
6 | "fmt"
7 | "testing"
8 |
9 | "github.com/golistic/xgo/xt"
10 |
11 | "github.com/golistic/pxmysql/decimal"
12 | )
13 |
14 | func TestDecimal_Compare(t *testing.T) {
15 | dec1 := decimal.MustNew("8.56")
16 | dec2 := decimal.MustNew("19.469")
17 |
18 | t.Run("Decimal", func(t *testing.T) {
19 | var cases = []struct {
20 | n Nullable
21 | value any
22 | exp bool
23 | }{
24 | {
25 | n: Decimal{Decimal: *dec1, Valid: true},
26 | value: *dec1,
27 | exp: true,
28 | },
29 | {
30 | n: Decimal{Decimal: *dec2, Valid: true},
31 | value: *dec1,
32 | exp: false,
33 | },
34 | {
35 | n: Decimal{Decimal: *dec1, Valid: true},
36 | value: dec1,
37 | exp: true,
38 | },
39 | {
40 | n: Decimal{Decimal: *dec2, Valid: true},
41 | value: dec1,
42 | exp: false,
43 | },
44 | {
45 | n: Decimal{Decimal: *dec1, Valid: false},
46 | value: nil,
47 | exp: true,
48 | },
49 | {
50 | n: Decimal{Decimal: *dec1, Valid: false},
51 | value: dec1,
52 | exp: false,
53 | },
54 | }
55 |
56 | for _, c := range cases {
57 | t.Run("", func(t *testing.T) {
58 | xt.Eq(t, c.exp, c.n.Compare(c.value))
59 | })
60 | }
61 | })
62 |
63 | t.Run("panics if value type is not supported", func(t *testing.T) {
64 | xt.Panics(t, func() {
65 | _ = Decimal{Decimal: *decimal.Zero, Valid: true}.Compare("str")
66 | })
67 | })
68 | }
69 |
70 | func TestDecimal_Value(t *testing.T) {
71 | pi := decimal.MustNew("3.14")
72 |
73 | t.Run("valid", func(t *testing.T) {
74 | nd := Decimal{Decimal: *pi, Valid: true}
75 | v, _ := nd.Value()
76 | d, ok := v.(decimal.Decimal)
77 | xt.Assert(t, ok, fmt.Sprintf("expected decimal.Decimal; got %T", v))
78 | xt.Eq(t, *pi, d)
79 | })
80 |
81 | t.Run("not valid", func(t *testing.T) {
82 | nd := Decimal{Decimal: *pi, Valid: false}
83 | v, _ := nd.Value()
84 | xt.Eq(t, nil, v, "expected nil")
85 | })
86 | }
87 |
--------------------------------------------------------------------------------
/null/duration.go:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2022, Geert JM Vanderkelen
2 |
3 | package null
4 |
5 | import (
6 | "database/sql/driver"
7 | "fmt"
8 | "time"
9 | )
10 |
// Duration holds a time.Duration, typically read from a MySQL TIME column,
// together with a flag telling whether the value is SQL NULL.
// Go's sql package has no equivalent (sql.NullTime is meant for timestamps,
// which include a date), and Duration does not implement the sql.Scanner
// interface.
type Duration struct {
	// Duration is the stored value; only meaningful when Valid is true.
	Duration time.Duration
	// Valid is false when the value is SQL NULL.
	Valid bool
}
18 |
19 | var _ driver.Valuer = &Duration{}
20 | var _ Nullable = &Duration{}
21 |
22 | // Compare returns whether value compares with the nullable Duration.
23 | // It returns:
24 | // - true when Valid and stored Duration is equal to value
25 | // - true when not Valid and value is nil
26 | // - false in any other case
27 | func (nd Duration) Compare(value any) bool {
28 | if !nd.Valid && value != nil {
29 | return false
30 | }
31 |
32 | if value == nil {
33 | return !nd.Valid
34 | }
35 |
36 | switch v := value.(type) {
37 | case time.Duration:
38 | return nd.Duration == v
39 | case *time.Duration:
40 | return nd.Duration == *v
41 | default:
42 | panic(fmt.Sprintf("value must be time.Duration or *time.Duration; not %T", value))
43 | }
44 | }
45 |
46 | // Value returns the value of n and implements the driver.Valuer
47 | // as well as Nullable interface.
48 | func (nd Duration) Value() (driver.Value, error) {
49 | if !nd.Valid {
50 | return nil, nil
51 | }
52 | return nd.Duration, nil
53 | }
54 |
--------------------------------------------------------------------------------
/null/duration_test.go:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2022, Geert JM Vanderkelen
2 |
3 | package null
4 |
5 | import (
6 | "testing"
7 | "time"
8 |
9 | "github.com/golistic/xgo/xt"
10 | )
11 |
12 | func TestDuration_Compare(t *testing.T) {
13 | dur2d3h, _ := time.ParseDuration("2d3h")
14 | dur5h6m, _ := time.ParseDuration("5h6m")
15 |
16 | var cases = []struct {
17 | n Nullable
18 | value any
19 | exp bool
20 | }{
21 | {
22 | n: Duration{Duration: dur2d3h, Valid: true},
23 | value: &dur2d3h,
24 | exp: true,
25 | },
26 | {
27 | n: Duration{Duration: dur2d3h, Valid: true},
28 | value: &dur5h6m, // supposed to be different
29 | exp: false,
30 | },
31 | {
32 | n: Duration{Duration: dur2d3h, Valid: true},
33 | value: dur2d3h,
34 | exp: true,
35 | },
36 | {
37 | n: Duration{Duration: dur2d3h, Valid: true},
38 | value: dur5h6m, // supposed to be different
39 | exp: false,
40 | },
41 | {
42 | n: Duration{Valid: false},
43 | value: nil,
44 | exp: true,
45 | },
46 |
47 | {
48 | n: Duration{Valid: false},
49 | value: 0,
50 | exp: false,
51 | },
52 | }
53 |
54 | for _, c := range cases {
55 | t.Run("", func(t *testing.T) {
56 | xt.Eq(t, c.exp, c.n.Compare(c.value))
57 | })
58 | }
59 |
60 | t.Run("panics if value type is not supported", func(t *testing.T) {
61 | xt.Panics(t, func() {
62 | _ = Duration{Duration: 0, Valid: true}.Compare("str")
63 | })
64 | })
65 | }
66 |
67 | func TestDuration_Value(t *testing.T) {
68 | t.Run("valid", func(t *testing.T) {
69 | dur2d3h, _ := time.ParseDuration("2d3h")
70 | nd := Duration{Duration: dur2d3h, Valid: true}
71 | v, _ := nd.Value()
72 | d, ok := v.(time.Duration)
73 | xt.Assert(t, ok, "expected time.Duration")
74 | xt.Eq(t, dur2d3h, d)
75 | })
76 |
77 | t.Run("not valid", func(t *testing.T) {
78 | nd := Duration{Duration: 0, Valid: false}
79 | v, _ := nd.Value()
80 | xt.Eq(t, nil, v, "expected nil")
81 | })
82 | }
83 |
--------------------------------------------------------------------------------
/null/float32.go:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2022, Geert JM Vanderkelen
2 |
3 | package null
4 |
5 | import (
6 | "database/sql/driver"
7 | "fmt"
8 | )
9 |
// Float32 holds a float32, typically read from a MySQL FLOAT column, together
// with a flag telling whether the value is SQL NULL.
// It is the float32 counterpart of sql.NullFloat64 and does not implement the
// sql.Scanner interface.
type Float32 struct {
	// Float32 is the stored value; only meaningful when Valid is true.
	Float32 float32
	// Valid is false when the value is SQL NULL.
	Valid bool
}
16 |
17 | var _ driver.Valuer = &Float32{}
18 | var _ Nullable = &Float32{}
19 |
20 | // Compare returns whether value compares with the nullable Float32.
21 | // It returns:
22 | // - true when Valid and stored Float32 is equal to value
23 | // - true when not Valid and value is nil
24 | // - false in any other case
25 | func (nf Float32) Compare(value any) bool {
26 | if !nf.Valid && value != nil {
27 | return false
28 | }
29 |
30 | if value == nil {
31 | return !nf.Valid
32 | }
33 |
34 | switch v := value.(type) {
35 | case float32:
36 | return v == nf.Float32
37 | case *float32:
38 | return *v == nf.Float32
39 | default:
40 | panic(fmt.Sprintf("value is of unsupported type %T", value))
41 | }
42 | }
43 |
44 | // Value returns the value of n and implements the driver.Valuer
45 | // as well as Nullable interface.
46 | func (nf Float32) Value() (driver.Value, error) {
47 | if !nf.Valid {
48 | return nil, nil
49 | }
50 | return nf.Float32, nil
51 | }
52 |
--------------------------------------------------------------------------------
/null/float32_test.go:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2022, Geert JM Vanderkelen
2 |
3 | package null
4 |
5 | import (
6 | "testing"
7 |
8 | "github.com/golistic/xgo/xt"
9 | )
10 |
11 | func TestFloat32_Compare(t *testing.T) {
12 | t.Run("Float32", func(t *testing.T) {
13 | var cases = []struct {
14 | n Nullable
15 | value any
16 | exp bool
17 | }{
18 | {
19 | n: Float32{Float32: 8.56, Valid: true},
20 | value: float32(8.56),
21 | exp: true,
22 | },
23 | {
24 | n: Float32{Float32: 19.469, Valid: true},
25 | value: float32(8.56), // supposed to be not 19.469
26 | exp: false,
27 | },
28 | {
29 | n: Float32{Float32: 8.56, Valid: true},
30 | value: float32Ptr(8.56),
31 | exp: true,
32 | },
33 | {
34 | n: Float32{Float32: 19.469, Valid: true},
35 | value: float32Ptr(8.56), // supposed to be not 19.469
36 | exp: false,
37 | },
38 | {
39 | n: Float32{Float32: 898, Valid: false},
40 | value: nil,
41 | exp: true,
42 | },
43 | {
44 | n: Float32{Float32: 9, Valid: false},
45 | value: float32(9),
46 | exp: false,
47 | },
48 | }
49 |
50 | for _, c := range cases {
51 | t.Run("", func(t *testing.T) {
52 | xt.Eq(t, c.exp, c.n.Compare(c.value))
53 | })
54 | }
55 | })
56 |
57 | t.Run("panics if value type is not supported", func(t *testing.T) {
58 | xt.Panics(t, func() {
59 | _ = Float32{Float32: 0, Valid: true}.Compare("str")
60 | })
61 | })
62 | }
63 |
64 | func TestFloat32_Value(t *testing.T) {
65 | pi := float32(3.14)
66 |
67 | t.Run("valid", func(t *testing.T) {
68 | nt := Float32{Float32: pi, Valid: true}
69 | v, _ := nt.Value()
70 | d, ok := v.(float32)
71 | xt.Assert(t, ok, "expected float32")
72 | xt.Eq(t, pi, d)
73 | })
74 |
75 | t.Run("not valid", func(t *testing.T) {
76 | nd := Float32{Float32: pi, Valid: false}
77 | v, _ := nd.Value()
78 | xt.Eq(t, nil, v, "expected nil")
79 | })
80 | }
81 |
--------------------------------------------------------------------------------
/null/float64.go:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2022, Geert JM Vanderkelen
2 |
3 | package null
4 |
5 | import (
6 | "database/sql/driver"
7 | "fmt"
8 | )
9 |
// Float64 holds a float64, typically read from a MySQL FLOAT or DOUBLE column,
// together with a flag telling whether the value is SQL NULL.
// It resembles sql.NullFloat64 but does not implement the sql.Scanner
// interface.
type Float64 struct {
	// Float64 is the stored value; only meaningful when Valid is true.
	Float64 float64
	// Valid is false when the value is SQL NULL.
	Valid bool
}
16 |
17 | var _ driver.Valuer = &Float64{}
18 | var _ Nullable = &Float64{}
19 |
20 | // Compare returns whether value compares with the nullable Float64.
21 | // It returns:
22 | // - true when Valid and stored Float64 is equal to value
23 | // - true when not Valid and value is nil
24 | // - false in any other case
25 | func (nf Float64) Compare(value any) bool {
26 | if !nf.Valid && value != nil {
27 | return false
28 | }
29 |
30 | if value == nil {
31 | return !nf.Valid
32 | }
33 |
34 | switch v := value.(type) {
35 | case float64:
36 | return v == nf.Float64
37 | case *float64:
38 | return *v == nf.Float64
39 | default:
40 | panic(fmt.Sprintf("value is of unsupported type %T", value))
41 | }
42 | }
43 |
44 | // Value returns the value of n and implements the driver.Valuer
45 | // as well as Nullable interface.
46 | func (nf Float64) Value() (driver.Value, error) {
47 | if !nf.Valid {
48 | return nil, nil
49 | }
50 | return nf.Float64, nil
51 | }
52 |
--------------------------------------------------------------------------------
/null/float64_test.go:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2022, Geert JM Vanderkelen
2 |
3 | package null
4 |
5 | import (
6 | "testing"
7 |
8 | "github.com/golistic/xgo/xt"
9 | )
10 |
11 | func TestFloat64_Compare(t *testing.T) {
12 | t.Run("Float64", func(t *testing.T) {
13 | var cases = []struct {
14 | n Nullable
15 | value any
16 | exp bool
17 | }{
18 | {
19 | n: Float64{Float64: 8.56, Valid: true},
20 | value: float64Ptr(8.56),
21 | exp: true,
22 | },
23 | {
24 | n: Float64{Float64: 19.469, Valid: true},
25 | value: float64Ptr(8.56), // supposed to be not 19.469
26 | exp: false,
27 | },
28 | {
29 | n: Float64{Float64: 8.56, Valid: true},
30 | value: 8.56,
31 | exp: true,
32 | },
33 | {
34 | n: Float64{Float64: 19.469, Valid: true},
35 | value: 8.56, // supposed to be not 19.469
36 | exp: false,
37 | },
38 | {
39 | n: Float64{Float64: 898, Valid: false},
40 | value: nil,
41 | exp: true,
42 | },
43 | {
44 | n: Float64{Float64: 898, Valid: false},
45 | value: 898,
46 | exp: false,
47 | },
48 | }
49 |
50 | for _, c := range cases {
51 | t.Run("", func(t *testing.T) {
52 | xt.Eq(t, c.exp, c.n.Compare(c.value))
53 | })
54 | }
55 | })
56 |
57 | t.Run("panics if value type is not supported", func(t *testing.T) {
58 | xt.Panics(t, func() {
59 | _ = Float64{Float64: 0, Valid: true}.Compare("str")
60 | })
61 | })
62 | }
63 |
64 | func TestFloat64_Value(t *testing.T) {
65 | pi := 3.14
66 |
67 | t.Run("valid", func(t *testing.T) {
68 | nt := Float64{Float64: pi, Valid: true}
69 | v, _ := nt.Value()
70 | d, ok := v.(float64)
71 | xt.Assert(t, ok, "expected float64")
72 | xt.Eq(t, pi, d)
73 | })
74 |
75 | t.Run("not valid", func(t *testing.T) {
76 | nd := Float64{Float64: pi, Valid: false}
77 | v, _ := nd.Value()
78 | xt.Eq(t, nil, v, "expected nil")
79 | })
80 | }
81 |
--------------------------------------------------------------------------------
/null/int64.go:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2022, Geert JM Vanderkelen
2 |
3 | package null
4 |
5 | import (
6 | "database/sql/driver"
7 | "fmt"
8 |
9 | "github.com/golistic/xgo/xconv"
10 | )
11 |
// Int64 holds an int64, typically read from any MySQL signed integer column,
// together with a flag telling whether the value is SQL NULL.
// It resembles sql.NullInt64 but does not implement the sql.Scanner interface.
type Int64 struct {
	// Int64 is the stored value; only meaningful when Valid is true.
	Int64 int64
	// Valid is false when the value is SQL NULL.
	Valid bool
}
18 |
19 | var _ driver.Valuer = &Int64{}
20 | var _ Nullable = &Int64{}
21 |
22 | // Compare returns whether value compares with the nullable Duration.
23 | // It returns:
24 | // - true when Valid and stored Duration is equal to value
25 | // - true when not Valid and value is nil
26 | // - false in any other case
27 | func (ni Int64) Compare(value any) bool {
28 | if !ni.Valid && value != nil {
29 | return false
30 | }
31 |
32 | if value == nil {
33 | return !ni.Valid
34 | }
35 |
36 | switch v := value.(type) {
37 | case int64, int, int8, int16, int32:
38 | return xconv.SignedAsInt64(v) == ni.Int64
39 | case *int64, *int, *int8, *int16, *int32:
40 | return *xconv.SignedAsInt64Ptr(v) == ni.Int64
41 | default:
42 | panic(fmt.Sprintf("value is of unsupported type %T", value))
43 | }
44 | }
45 |
46 | // Value returns the value of n and implements the driver.Valuer
47 | // as well as Nullable interface.
48 | func (ni Int64) Value() (driver.Value, error) {
49 | if !ni.Valid {
50 | return nil, nil
51 | }
52 | return ni.Int64, nil
53 | }
54 |
--------------------------------------------------------------------------------
/null/int64_test.go:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2022, Geert JM Vanderkelen
2 |
3 | package null
4 |
5 | import (
6 | "testing"
7 |
8 | "github.com/golistic/xgo/xt"
9 | )
10 |
11 | func TestInt64_Compare(t *testing.T) {
12 | t.Run("Int64", func(t *testing.T) {
13 | var cases = []struct {
14 | n Nullable
15 | value any
16 | exp bool
17 | }{
18 | {
19 | n: Int64{Int64: -8, Valid: true},
20 | value: int64Ptr(-8),
21 | exp: true,
22 | },
23 | {
24 | n: Int64{Int64: -19, Valid: true},
25 | value: int64Ptr(-8), // supposed to be not 19
26 | exp: false,
27 | },
28 | {
29 | n: Int64{Int64: -8, Valid: true},
30 | value: int64(-8),
31 | exp: true,
32 | },
33 | {
34 | n: Int64{Int64: -19, Valid: true},
35 | value: int64(-8), // supposed to be not 19
36 | exp: false,
37 | },
38 | {
39 | n: Int64{Int64: -898, Valid: false},
40 | value: nil,
41 | exp: true,
42 | },
43 | {
44 | n: Int64{Int64: -898, Valid: false},
45 | value: 123,
46 | exp: false,
47 | },
48 | }
49 |
50 | for _, c := range cases {
51 | t.Run("", func(t *testing.T) {
52 | xt.Eq(t, c.exp, c.n.Compare(c.value))
53 | })
54 | }
55 | })
56 |
57 | t.Run("panics if value type is not supported", func(t *testing.T) {
58 | xt.Panics(t, func() {
59 | _ = Int64{Int64: 0, Valid: true}.Compare("str")
60 | })
61 | })
62 | }
63 |
64 | func TestInt64_Value(t *testing.T) {
65 | t.Run("valid", func(t *testing.T) {
66 | ni := Int64{Int64: 9, Valid: true}
67 | v, _ := ni.Value()
68 | d, ok := v.(int64)
69 | xt.Assert(t, ok, "expected int64")
70 | xt.Eq(t, 9, d)
71 | })
72 |
73 | t.Run("not valid", func(t *testing.T) {
74 | ni := Int64{Int64: 0, Valid: false}
75 | v, _ := ni.Value()
76 | xt.Eq(t, nil, v, "expected nil")
77 | })
78 | }
79 |
--------------------------------------------------------------------------------
/null/main.go:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2022, Geert JM Vanderkelen
2 |
3 | package null
4 |
5 | import "database/sql/driver"
6 |
// Nullable is implemented by the nullable types of this package.
// Compare reports whether a given value matches the stored value, where nil
// matches SQL NULL. Value returns the stored value, or nil when it is NULL,
// and makes every Nullable a driver.Valuer.
type Nullable interface {
	Compare(any) bool
	Value() (driver.Value, error)
}
11 |
12 | // Compare checks value against the value stored with Nullable.
13 | // It returns true when:
14 | // - nullable is valid and value of nullable is equal to given value,
15 | // - or when nullable is not valid (SQL NULL) and value is nil
16 | //
17 | // Panics when value cannot be used with given Nullable n.
18 | func Compare(n Nullable, value any) bool {
19 | return n.Compare(value)
20 | }
21 |
--------------------------------------------------------------------------------
/null/main_test.go:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2022, Geert JM Vanderkelen
2 |
3 | package null
4 |
// ptrOf returns a pointer to a fresh copy of v. The typed helpers below wrap
// it so test tables can build pointer values inline.
func ptrOf[T any](v T) *T {
	return &v
}

func uint64Ptr(v uint64) *uint64 { return ptrOf(v) }

func bytesPtr(v []byte) *[]byte { return ptrOf(v) }

func stringPtr(v string) *string { return ptrOf(v) }

func float64Ptr(v float64) *float64 { return ptrOf(v) }

func float32Ptr(v float32) *float32 { return ptrOf(v) }

func int64Ptr(v int64) *int64 { return ptrOf(v) }
28 |
--------------------------------------------------------------------------------
/null/string.go:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2022, Geert JM Vanderkelen
2 |
3 | package null
4 |
5 | import (
6 | "database/sql/driver"
7 | "fmt"
8 | )
9 |
// String represents a string (any MySQL CHAR-kind of data type) that may be
// NULL. It is similar to sql.NullString but does not implement the
// sql.Scanner interface.
type String struct {
	// String is the stored value; only meaningful when Valid is true.
	String string
	// Valid is false when the value is SQL NULL.
	Valid bool
}
16 |
17 | var _ driver.Valuer = &String{}
18 | var _ Nullable = &String{}
19 |
20 | // Compare returns whether value compares with the nullable String.
21 | // It returns:
22 | // - true when Valid and stored String is equal to value
23 | // - true when not Valid and value is nil
24 | // - false in any other case
25 | func (ns String) Compare(value any) bool {
26 | if !ns.Valid && value != nil {
27 | return false
28 | }
29 |
30 | if value == nil {
31 | return !ns.Valid
32 | }
33 |
34 | switch v := value.(type) {
35 | case string:
36 | return ns.String == v
37 | case *string:
38 | return ns.String == *v
39 | default:
40 | panic(fmt.Sprintf("value must be string or *string; not %T", value))
41 | }
42 | }
43 |
44 | // Value returns the value of n and implements the driver.Valuer
45 | // as well as Nullable interface.
46 | func (ns String) Value() (driver.Value, error) {
47 | if !ns.Valid {
48 | return nil, nil
49 | }
50 | return ns.String, nil
51 | }
52 |
--------------------------------------------------------------------------------
/null/string_test.go:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2022, Geert JM Vanderkelen
2 |
3 | package null
4 |
5 | import (
6 | "testing"
7 |
8 | "github.com/golistic/xgo/xt"
9 | )
10 |
11 | func TestString_Compare(t *testing.T) {
12 | var cases = []struct {
13 | n Nullable
14 | value any
15 | exp bool
16 | }{
17 | {
18 | n: String{String: "Sakila", Valid: true},
19 | value: "Sakila",
20 | exp: true,
21 | },
22 | {
23 | n: String{String: "Sakila", Valid: true},
24 | value: "Go gopher", // supposed to not include 'Go gopher'
25 | exp: false,
26 | },
27 | {
28 | n: String{String: "Sakila", Valid: true},
29 | value: stringPtr("Sakila"),
30 | exp: true,
31 | },
32 | {
33 | n: String{String: "Sakila", Valid: true},
34 | value: stringPtr("Go gopher"), // supposed to not include 'Go gopher'
35 | exp: false,
36 | },
37 | {
38 | n: String{Valid: false},
39 | value: nil,
40 | exp: true,
41 | },
42 | {
43 | n: String{Valid: false},
44 | value: "",
45 | exp: false,
46 | },
47 | }
48 |
49 | for _, c := range cases {
50 | t.Run("", func(t *testing.T) {
51 | xt.Eq(t, c.exp, c.n.Compare(c.value))
52 | })
53 | }
54 |
55 | t.Run("panics if value type is not supported", func(t *testing.T) {
56 | xt.Panics(t, func() {
57 | _ = String{Valid: true}.Compare(123)
58 | })
59 | })
60 | }
61 |
62 | func TestString_Value(t *testing.T) {
63 | str := "String!"
64 |
65 | t.Run("valid", func(t *testing.T) {
66 | ns := String{String: str, Valid: true}
67 | v, _ := ns.Value()
68 | d, ok := v.(string)
69 | xt.Assert(t, ok, "expected string")
70 | xt.Eq(t, str, d)
71 | })
72 |
73 | t.Run("not valid", func(t *testing.T) {
74 | ns := String{String: str, Valid: false}
75 | v, _ := ns.Value()
76 | xt.Eq(t, nil, v, "expected nil")
77 | })
78 | }
79 |
--------------------------------------------------------------------------------
/null/strings.go:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2022, Geert JM Vanderkelen
2 |
3 | package null
4 |
5 | import (
6 | "database/sql/driver"
7 | "fmt"
8 | "sort"
9 | )
10 |
// Strings holds a []string, for example the value of a MySQL ENUM column,
// together with a flag telling whether the value is SQL NULL.
// Go's sql package has no equivalent, and Strings does not implement the
// sql.Scanner interface.
type Strings struct {
	// Strings is the stored value; only meaningful when Valid is true.
	Strings []string
	// Valid is false when the value is SQL NULL.
	Valid bool
}
18 |
19 | var _ driver.Valuer = &Strings{}
20 | var _ Nullable = &Strings{}
21 |
22 | // Compare returns whether value compares with the nullable Strings.
23 | // It returns:
24 | // - true when Valid and stored Strings is equal to value
25 | // - true when not Valid and value is nil
26 | // - false in any other case
27 | func (ns Strings) Compare(value any) bool {
28 | if !ns.Valid && value != nil {
29 | return false
30 | }
31 |
32 | if value == nil {
33 | return !ns.Valid
34 | }
35 |
36 | equal := func(a, b []string) bool {
37 | if len(a) != len(b) {
38 | return false
39 | }
40 |
41 | ac := make([]string, len(a))
42 | copy(ac, a)
43 | sort.Strings(ac)
44 |
45 | bc := make([]string, len(b))
46 | copy(bc, b)
47 | sort.Strings(bc)
48 |
49 | for i, v := range ac {
50 | if v != bc[i] {
51 | return false
52 | }
53 | }
54 | return true
55 | }
56 |
57 | switch v := value.(type) {
58 | case []string:
59 | return equal(ns.Strings, v)
60 | case *[]string:
61 | return equal(ns.Strings, *v)
62 | default:
63 | panic(fmt.Sprintf("value must be []strings or []*strings; not %T", value))
64 | }
65 | }
66 |
67 | // Value returns the value of n and implements the driver.Valuer
68 | // as well as Nullable interface.
69 | func (ns Strings) Value() (driver.Value, error) {
70 | if !ns.Valid {
71 | return nil, nil
72 | }
73 | return ns.Strings, nil
74 | }
75 |
--------------------------------------------------------------------------------
/null/strings_test.go:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2022, Geert JM Vanderkelen
2 |
3 | package null
4 |
5 | import (
6 | "fmt"
7 | "testing"
8 |
9 | "github.com/golistic/xgo/xt"
10 | )
11 |
12 | func TestStrings_Compare(t *testing.T) {
13 | mascots1 := []string{"Sakila", "Go gopher"}
14 | mascots2 := []string{"Sakila", "Duke"}
15 | mascots3 := []string{"Sakila"}
16 |
17 | var cases = []struct {
18 | n Nullable
19 | value any
20 | exp bool
21 | }{
22 | {
23 | n: Strings{Strings: mascots1, Valid: true},
24 | value: mascots1,
25 | exp: true,
26 | },
27 | {
28 | n: Strings{Strings: mascots1, Valid: true},
29 | value: mascots2,
30 | exp: false,
31 | },
32 | {
33 | n: Strings{Strings: mascots1, Valid: true},
34 | value: &mascots1,
35 | exp: true,
36 | },
37 | {
38 | n: Strings{Strings: mascots1, Valid: true},
39 | value: &mascots3,
40 | exp: false,
41 | },
42 | {
43 | n: Strings{Strings: nil, Valid: false},
44 | value: nil,
45 | exp: true,
46 | },
47 | {
48 | n: Strings{Strings: []string{}, Valid: false},
49 | value: nil,
50 | exp: true,
51 | },
52 | {
53 | n: Strings{Strings: []string{}, Valid: false},
54 | value: []string{""},
55 | exp: false,
56 | },
57 | }
58 |
59 | for _, c := range cases {
60 | t.Run("", func(t *testing.T) {
61 | xt.Eq(t, c.exp, c.n.Compare(c.value))
62 | })
63 | }
64 |
65 | t.Run("panics if value type is not supported", func(t *testing.T) {
66 | xt.Panics(t, func() {
67 | _ = Strings{Strings: nil, Valid: true}.Compare("str")
68 | })
69 | })
70 | }
71 |
72 | func TestStrings_Value(t *testing.T) {
73 | data := []string{"Sakila", "Go gopher"}
74 |
75 | t.Run("valid", func(t *testing.T) {
76 | ns := Strings{Strings: data, Valid: true}
77 | v, _ := ns.Value()
78 | d, ok := v.([]string)
79 | xt.Assert(t, ok, fmt.Sprintf("expected []string; got %T", v))
80 | xt.Eq(t, data, d)
81 | })
82 |
83 | t.Run("not valid", func(t *testing.T) {
84 | ns := Strings{Strings: data, Valid: false}
85 | v, _ := ns.Value()
86 | xt.Eq(t, nil, v, "expected nil")
87 | })
88 | }
89 |
--------------------------------------------------------------------------------
/null/time.go:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2022, Geert JM Vanderkelen
2 |
3 | package null
4 |
5 | import (
6 | "database/sql/driver"
7 | "fmt"
8 | "time"
9 | )
10 |
// Time represents a time.Time (MySQL TIMESTAMP, DATETIME, and DATE types) that
// may be NULL. It is similar to sql.NullTime but does not implement the
// sql.Scanner interface.
type Time struct {
	// Time is the stored value; only meaningful when Valid is true.
	Time time.Time
	// Valid is false when the value is SQL NULL.
	Valid bool
}
17 |
18 | var _ driver.Valuer = &Time{}
19 | var _ Nullable = &Time{}
20 |
21 | // Compare returns whether value compares with the nullable Time.
22 | // It returns:
23 | // - true when Valid and stored Time is equal to value
24 | // - true when not Valid and value is nil
25 | // - false in any other case
26 | func (nd Time) Compare(value any) bool {
27 | if !nd.Valid && value != nil {
28 | return false
29 | }
30 |
31 | if value == nil {
32 | return !nd.Valid
33 | }
34 |
35 | switch v := value.(type) {
36 | case time.Time:
37 | return nd.Time.Equal(v)
38 | case *time.Time:
39 | return nd.Time.Equal(*v)
40 | default:
41 | panic(fmt.Sprintf("value must be time.Time or *time.Time; not %T", value))
42 | }
43 | }
44 |
45 | // Value returns the value of n and implements the driver.Valuer
46 | // as well as Nullable interface.
47 | func (nd Time) Value() (driver.Value, error) {
48 | if !nd.Valid {
49 | return nil, nil
50 | }
51 | return nd.Time, nil
52 | }
53 |
--------------------------------------------------------------------------------
/null/time_test.go:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2022, Geert JM Vanderkelen
2 |
3 | package null
4 |
5 | import (
6 | "testing"
7 | "time"
8 |
9 | "github.com/golistic/xgo/xt"
10 | "github.com/golistic/xgo/xtime"
11 | )
12 |
13 | func TestTime_Compare(t *testing.T) {
14 | now := time.Now()
15 | yesterday := xtime.Yesterday()
16 |
17 | t.Run("Uint64", func(t *testing.T) {
18 | var cases = []struct {
19 | n Nullable
20 | value *time.Time
21 | exp bool
22 | }{
23 | {
24 | n: Time{Time: now, Valid: true},
25 | value: &now,
26 | exp: true,
27 | },
28 | {
29 | n: Time{Time: now, Valid: true},
30 | value: &yesterday,
31 | exp: false,
32 | },
33 | {
34 | n: Time{Time: now, Valid: false},
35 | value: nil,
36 | exp: false,
37 | },
38 | }
39 |
40 | for _, c := range cases {
41 | t.Run("", func(t *testing.T) {
42 | xt.Eq(t, c.exp, c.n.Compare(c.value))
43 | })
44 | }
45 | })
46 |
47 | t.Run("non-pointer", func(t *testing.T) {
48 | xt.Assert(t, !Time{Time: now, Valid: false}.Compare(now))
49 | xt.Assert(t, Time{Time: now, Valid: true}.Compare(now))
50 | })
51 |
52 | t.Run("panics if value type is not supported", func(t *testing.T) {
53 | xt.Panics(t, func() {
54 | _ = Time{Time: now, Valid: true}.Compare("str")
55 | })
56 | })
57 |
58 | t.Run("value is explicitly nil", func(t *testing.T) {
59 | xt.Assert(t, Time{Time: now, Valid: false}.Compare(nil))
60 | })
61 | }
62 |
63 | func TestTime_Value(t *testing.T) {
64 | now := time.Now()
65 |
66 | t.Run("valid", func(t *testing.T) {
67 | nt := Time{Time: now, Valid: true}
68 | v, _ := nt.Value()
69 | d, ok := v.(time.Time)
70 | xt.Assert(t, ok, "expected time.Time")
71 | xt.Eq(t, now, d)
72 | })
73 |
74 | t.Run("not valid", func(t *testing.T) {
75 | nd := Time{Time: now, Valid: false}
76 | v, _ := nd.Value()
77 | xt.Eq(t, nil, v, "expected nil")
78 | })
79 | }
80 |
--------------------------------------------------------------------------------
/null/uint64.go:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2022, Geert JM Vanderkelen
2 |
3 | package null
4 |
5 | import (
6 | "database/sql/driver"
7 | "fmt"
8 |
9 | "github.com/golistic/xgo/xconv"
10 | )
11 |
// Uint64 holds a uint64, typically read from any MySQL unsigned integer
// column, together with a flag telling whether the value is SQL NULL.
// Go's sql package has no equivalent, and Uint64 does not implement the
// sql.Scanner interface.
type Uint64 struct {
	// Uint64 is the stored value; only meaningful when Valid is true.
	Uint64 uint64
	// Valid is false when the value is SQL NULL.
	Valid bool
}
18 |
19 | var _ driver.Valuer = &Uint64{}
20 | var _ Nullable = &Uint64{}
21 |
22 | // Compare returns whether value compares with the nullable Uint64.
23 | // It returns:
24 | // - true when Valid and stored Uint64 is equal to value
25 | // - true when not Valid and value is nil
26 | // - false in any other case
27 | func (ni Uint64) Compare(value any) bool {
28 | if !ni.Valid && value != nil {
29 | return false
30 | }
31 |
32 | if value == nil {
33 | return !ni.Valid
34 | }
35 |
36 | switch v := value.(type) {
37 | case uint64, uint, uint8, uint16, uint32:
38 | return xconv.UnsignedAsUint64(v) == ni.Uint64
39 | case *uint64, *uint, *uint8, *uint16, *uint32:
40 | return *xconv.UnsignedAsUint64Ptr(v) == ni.Uint64
41 | default:
42 | panic(fmt.Sprintf("value is of unsupported type %T", value))
43 | }
44 | }
45 |
46 | // Value returns the value of n and implements the driver.Valuer
47 | // as well as Nullable interface.
48 | func (ni Uint64) Value() (driver.Value, error) {
49 | if !ni.Valid {
50 | return nil, nil
51 | }
52 | return ni.Uint64, nil
53 | }
54 |
--------------------------------------------------------------------------------
/null/uint64_test.go:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2022, Geert JM Vanderkelen
2 |
3 | package null
4 |
5 | import (
6 | "testing"
7 |
8 | "github.com/golistic/xgo/xt"
9 | )
10 |
11 | func TestUint64_Compare(t *testing.T) {
12 | t.Run("Uint64", func(t *testing.T) {
13 | var cases = []struct {
14 | n Nullable
15 | value any
16 | exp bool
17 | }{
18 | {
19 | n: Uint64{Uint64: 8, Valid: true},
20 | value: uint64Ptr(8),
21 | exp: true,
22 | },
23 | {
24 | n: Uint64{Uint64: 19, Valid: true},
25 | value: uint64Ptr(8),
26 | exp: false,
27 | },
28 | {
29 | n: Uint64{Uint64: 8, Valid: true},
30 | value: uint(8),
31 | exp: true,
32 | },
33 | {
34 | n: Uint64{Uint64: 19, Valid: true},
35 | value: uint(8),
36 | exp: false,
37 | },
38 | {
39 | n: Uint64{Uint64: 898, Valid: false},
40 | value: nil,
41 | exp: true,
42 | },
43 | {
44 | n: Uint64{Uint64: 898, Valid: false},
45 | value: uint(898),
46 | exp: false,
47 | },
48 | }
49 |
50 | for _, c := range cases {
51 | t.Run("", func(t *testing.T) {
52 | xt.Eq(t, c.exp, c.n.Compare(c.value))
53 | })
54 | }
55 | })
56 |
57 | t.Run("panics if value type is not supported", func(t *testing.T) {
58 | xt.Panics(t, func() {
59 | _ = Uint64{Uint64: 0, Valid: true}.Compare("str")
60 | })
61 | })
62 | }
63 |
64 | func TestUint64_Value(t *testing.T) {
65 | t.Run("valid", func(t *testing.T) {
66 | ni := Uint64{Uint64: 9, Valid: true}
67 | v, _ := ni.Value()
68 | d, ok := v.(uint64)
69 | xt.Assert(t, ok, "expected uint64")
70 | xt.Eq(t, 9, d)
71 | })
72 |
73 | t.Run("not valid", func(t *testing.T) {
74 | ni := Uint64{Uint64: 0, Valid: false}
75 | v, _ := ni.Value()
76 | xt.Eq(t, nil, v, "expected nil")
77 | })
78 | }
79 |
--------------------------------------------------------------------------------
/register/mysql/register.go:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2023, Geert JM Vanderkelen
2 |
3 | package mysql
4 |
5 | import (
6 | "database/sql"
7 |
8 | "github.com/golistic/pxmysql"
9 | )
10 |
11 | const DriverName = "mysql"
12 |
13 | func init() {
14 | sql.Register(DriverName, &pxmysql.Driver{})
15 | }
16 |
--------------------------------------------------------------------------------
/register/mysql/register_test.go:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2023, Geert JM Vanderkelen
2 |
3 | package mysql
4 |
5 | import (
6 | "database/sql"
7 | "testing"
8 |
9 | "github.com/golistic/xgo/xstrings"
10 | "github.com/golistic/xgo/xt"
11 | )
12 |
13 | func TestDriver_Open(t *testing.T) {
14 | t.Run("pxmysql is not registered", func(t *testing.T) {
15 | xt.Assert(t, !xstrings.SliceHas(sql.Drivers(), "pxmysql"), "expected driver pxmysql not to be registered")
16 | })
17 |
18 | t.Run("mysql is registered", func(t *testing.T) {
19 | xt.Assert(t, xstrings.SliceHas(sql.Drivers(), "mysql"), "expected driver mysql to be registered")
20 | })
21 | }
22 |
--------------------------------------------------------------------------------
/register/register.go:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2023, Geert JM Vanderkelen
2 |
3 | package register
4 |
5 | import (
6 | "database/sql"
7 |
8 | "github.com/golistic/pxmysql"
9 | )
10 |
11 | const DriverName = "pxmysql"
12 |
13 | func init() {
14 | sql.Register(DriverName, &pxmysql.Driver{})
15 | }
16 |
--------------------------------------------------------------------------------
/register/register_test.go:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2023, Geert JM Vanderkelen
2 |
3 | package register
4 |
5 | import (
6 | "database/sql"
7 | "testing"
8 |
9 | "github.com/golistic/xgo/xstrings"
10 | "github.com/golistic/xgo/xt"
11 | )
12 |
13 | func TestDriver_Open(t *testing.T) {
14 | t.Run("pxmysql is registered", func(t *testing.T) {
15 | xt.Assert(t, xstrings.SliceHas(sql.Drivers(), "pxmysql"), "expected driver pxmysql to be registered")
16 | })
17 |
18 | t.Run("mysql is not registered", func(t *testing.T) {
19 | xt.Assert(t, !xstrings.SliceHas(sql.Drivers(), "mysql"), "expected driver mysql not to be registered")
20 | })
21 | }
22 |
--------------------------------------------------------------------------------
/result.go:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2022, Geert JM Vanderkelen
2 |
3 | package pxmysql
4 |
5 | import (
6 | "database/sql/driver"
7 | "fmt"
8 | "math"
9 |
10 | "github.com/golistic/pxmysql/xmysql"
11 | )
12 |
13 | type result struct {
14 | xpresult *xmysql.Result
15 | }
16 |
17 | var _ driver.Result = &result{}
18 |
19 | // LastInsertId returns the database's auto-generated ID after,
20 | // for example, an INSERT into a table with primary key.
21 | // If this information is not available, zero (0) is returned.
22 | func (r result) LastInsertId() (int64, error) {
23 | if r.xpresult != nil {
24 | lid := r.xpresult.LastInsertID()
25 | if lid > math.MaxInt64 {
26 | return 0, fmt.Errorf("LastInsertID overflowed max 64-bit unsigned integer")
27 | }
28 | return int64(lid), nil
29 | }
30 | return 0, nil
31 | }
32 |
33 | // RowsAffected returns the number of rows affected by the query.
34 | // If this information is not available, zero (0) is returned.
35 | func (r result) RowsAffected() (int64, error) {
36 | if r.xpresult != nil {
37 | affected := r.xpresult.RowsAffected()
38 | if affected > math.MaxInt64 {
39 | return 0, fmt.Errorf("RowsAffected overflowed max 64-bit unsigned integer")
40 | }
41 | return int64(affected), nil
42 | }
43 | return 0, nil
44 | }
45 |
--------------------------------------------------------------------------------
/rows.go:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2022, Geert JM Vanderkelen
2 |
3 | package pxmysql
4 |
5 | import (
6 | "database/sql/driver"
7 | "io"
8 |
9 | "github.com/golistic/pxmysql/null"
10 | "github.com/golistic/pxmysql/xmysql"
11 | )
12 |
13 | type rows struct {
14 | xpresult *xmysql.Result
15 |
16 | currRowIndex int
17 | }
18 |
19 | var _ driver.Rows = &rows{}
20 |
21 | // Columns returns the names of the columns.
22 | func (r *rows) Columns() []string {
23 | if r.xpresult == nil {
24 | return nil
25 | }
26 |
27 | cols := make([]string, len(r.xpresult.Columns))
28 |
29 | for i, c := range r.xpresult.Columns {
30 | cols[i] = string(c.Name)
31 | }
32 |
33 | return cols
34 | }
35 |
36 | func (r *rows) Close() error {
37 | return nil
38 | }
39 |
40 | func (r *rows) Next(dest []driver.Value) error {
41 | if r.xpresult == nil || r.currRowIndex >= len(r.xpresult.Rows) {
42 | return io.EOF
43 | }
44 |
45 | for i, value := range r.xpresult.Rows[r.currRowIndex].Values {
46 | if n, ok := value.(null.Nullable); ok {
47 | var err error
48 | dest[i], err = n.Value()
49 | if err != nil {
50 | return err
51 | }
52 | } else {
53 | dest[i] = value
54 | }
55 | }
56 |
57 | r.currRowIndex++
58 | return nil
59 | }
60 |
--------------------------------------------------------------------------------
/rows_test.go:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2023, Geert JM Vanderkelen
2 |
3 | package pxmysql_test
4 |
5 | import (
6 | "context"
7 | "database/sql"
8 | "fmt"
9 | "testing"
10 | "time"
11 |
12 | "github.com/golistic/xgo/xt"
13 | )
14 |
15 | func TestRows_Next(t *testing.T) {
16 | db, err := sql.Open("pxmysql", getTCPDSN("", ""))
17 | xt.OK(t, err)
18 | defer func() { _ = db.Close() }()
19 |
20 | ctx := context.Background()
21 |
22 | t.Run("time.Time", func(t *testing.T) {
23 | tbl := "test_data_types_null_time"
24 | _, err := db.ExecContext(ctx, fmt.Sprintf("CREATE TABLE `%s` (id int, ts DATETIME NULL)", tbl))
25 | xt.OK(t, err)
26 |
27 | _, err = db.ExecContext(ctx, fmt.Sprintf(
28 | "INSERT INTO `%s` (id, ts) VALUE (1, NOW()),(2, NULL)", tbl))
29 | xt.OK(t, err)
30 |
31 | stmt := fmt.Sprintf("SELECT ts FROM `%s` WHERE id = ?", tbl)
32 |
33 | var ts time.Time
34 | xt.OK(t, db.QueryRowContext(ctx, stmt, 1).Scan(&ts))
35 |
36 | var tsNull sql.NullTime
37 | xt.OK(t, db.QueryRowContext(ctx, stmt, 2).Scan(&tsNull))
38 | })
39 | }
40 |
--------------------------------------------------------------------------------
/statement.go:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2022, Geert JM Vanderkelen
2 |
3 | package pxmysql
4 |
5 | import (
6 | "context"
7 | "database/sql/driver"
8 | "time"
9 |
10 | "github.com/golistic/pxmysql/xmysql"
11 | )
12 |
13 | var closeTimeout = time.Second
14 |
15 | type statement struct {
16 | prepared *xmysql.Prepared
17 | result *xmysql.Result
18 | }
19 |
20 | var (
21 | _ driver.Stmt = &statement{}
22 | _ driver.StmtQueryContext = &statement{}
23 | _ driver.StmtExecContext = &statement{}
24 | )
25 |
26 | func (s *statement) Close() error {
27 | ctx, cancel := context.WithTimeout(context.Background(), closeTimeout)
28 | defer cancel()
29 |
30 | if err := s.prepared.Deallocate(ctx); err != nil {
31 | return err
32 | } else {
33 | s.prepared = nil
34 | s.result = nil
35 | }
36 | return nil
37 | }
38 |
39 | // NumInput returns the number of placeholders.
40 | func (s *statement) NumInput() int {
41 | if s.prepared != nil {
42 | return s.prepared.NumPlaceholders()
43 | }
44 |
45 | return 0
46 | }
47 |
48 | // Exec executes a query that doesn't return rows, such as an INSERT or UPDATE.
49 | // Deprecated: use ExecContext instead.
50 | func (s *statement) Exec(args []driver.Value) (driver.Result, error) {
51 | named := make([]driver.NamedValue, len(args))
52 | for i, arg := range args {
53 | named[i].Name = ""
54 | named[i].Ordinal = i + 1
55 | named[i].Value = arg
56 | }
57 | return s.ExecContext(context.Background(), named)
58 | }
59 |
60 | func (s *statement) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
61 | execArgs := make([]any, len(args))
62 |
63 | for i, a := range args {
64 | execArgs[i] = a.Value
65 | }
66 |
67 | execResult, err := s.prepared.Execute(ctx, execArgs...)
68 | if err != nil {
69 | return nil, handleError(err)
70 | }
71 |
72 | res := &result{
73 | xpresult: execResult,
74 | }
75 |
76 | return res, nil
77 | }
78 |
79 | // Query executes a query that may return rows, such as a SELECT.
80 | // Deprecated: use QueryContext instead.
81 | func (s *statement) Query(args []driver.Value) (driver.Rows, error) {
82 | named := make([]driver.NamedValue, len(args))
83 | for i, arg := range args {
84 | named[i].Name = ""
85 | named[i].Ordinal = i + 1
86 | named[i].Value = arg
87 | }
88 | return s.QueryContext(context.Background(), named)
89 | }
90 |
91 | // QueryContext executes a query that may return rows, such as a SELECT.
92 | func (s *statement) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
93 | execArgs := make([]any, len(args))
94 |
95 | for i, a := range args {
96 | execArgs[i] = a
97 | }
98 |
99 | execResult, err := s.prepared.Execute(ctx, execArgs...)
100 | if err != nil {
101 | return nil, handleError(err)
102 | }
103 |
104 | if len(execResult.Rows) == 0 {
105 | return &rows{}, nil
106 | }
107 |
108 | r := &rows{
109 | xpresult: execResult,
110 | }
111 |
112 | return r, nil
113 | }
114 |
--------------------------------------------------------------------------------
/statement_test.go:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2022, Geert JM Vanderkelen
2 |
3 | package pxmysql_test
4 |
5 | import (
6 | "context"
7 | "crypto/md5"
8 | "database/sql"
9 | "encoding/hex"
10 | "errors"
11 | "fmt"
12 | "sort"
13 | "strings"
14 | "testing"
15 | "time"
16 |
17 | "github.com/golistic/xgo/xt"
18 |
19 | "github.com/golistic/pxmysql/mysqlerrors"
20 | )
21 |
22 | func TestStatement_Close(t *testing.T) {
23 | dsn := getTCPDSN("", "")
24 | db, err := sql.Open("pxmysql", dsn)
25 | xt.OK(t, err)
26 |
27 | stmt := "SELECT ?"
28 | prep, err := db.Prepare(stmt)
29 | xt.OK(t, err)
30 |
31 | _, err = prep.Exec(3)
32 | xt.OK(t, err)
33 |
34 | xt.OK(t, prep.Close())
35 |
36 | _, err = prep.Exec(3)
37 | xt.KO(t, err)
38 | xt.Eq(t, "sql: statement is closed", err.Error())
39 | }
40 |
41 | func testOpenQueryRowsClose() ([]string, error) {
42 | dsn := getTCPDSN("", "")
43 | db, err := sql.Open("pxmysql", dsn)
44 | if err != nil {
45 | return nil, err
46 | }
47 | defer func() { _ = db.Close() }()
48 |
49 | stmt := `SELECT TABLE_NAME FROM information_schema.tables WHERE TABLE_SCHEMA = ? ORDER BY TABLE_SCHEMA`
50 | prep, err := db.Prepare(stmt)
51 | if err != nil {
52 | return nil, err
53 | }
54 | defer func() { _ = prep.Close() }()
55 |
56 | rows, err := prep.QueryContext(context.Background(), "mysql")
57 | if err != nil {
58 | return nil, err
59 | }
60 |
61 | var tablesNames []string
62 | for rows.Next() {
63 | var name sql.NullString
64 | if err := rows.Scan(&name); err != nil {
65 | return nil, err
66 | }
67 | if !name.Valid {
68 | return nil, fmt.Errorf("found entry with null as table name")
69 | }
70 | tablesNames = append(tablesNames, name.String)
71 | }
72 |
73 | return tablesNames, nil
74 | }
75 |
76 | func TestStatement_ExecContext(t *testing.T) {
77 | t.Run("respect timeout", func(t *testing.T) {
78 | dsn := getTCPDSN("", "")
79 | db, err := sql.Open("pxmysql", dsn)
80 | xt.OK(t, err)
81 | defer func() { _ = db.Close() }()
82 |
83 | ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
84 | defer cancel()
85 |
86 | _, err = db.ExecContext(ctx, "SELECT SLEEP(5)")
87 | xt.KO(t, err)
88 | xt.Assert(t, errors.Is(err, mysqlerrors.ErrContextDeadlineExceeded), err.Error())
89 | })
90 | }
91 |
92 | func BenchmarkStatement_QueryContext(b *testing.B) {
93 | b.Run("fetch tables from mysql database", func(b *testing.B) {
94 | if _, err := testOpenQueryRowsClose(); err != nil {
95 | b.Error(err)
96 | }
97 | })
98 | }
99 |
100 | func TestStatement_QueryContext(t *testing.T) {
101 | t.Run("has rows in result", func(t *testing.T) {
102 | tableNames, err := testOpenQueryRowsClose()
103 | xt.OK(t, err)
104 |
105 | sort.Strings(tableNames)
106 | sum := md5.Sum([]byte(strings.Join(tableNames, " ")))
107 | xt.Eq(t, "859173a1b7b8ef446282e772dcd3039b", hex.EncodeToString(sum[:]))
108 | })
109 |
110 | t.Run("respect timeout", func(t *testing.T) {
111 | dsn := getTCPDSN("", "")
112 | db, err := sql.Open("pxmysql", dsn)
113 | xt.OK(t, err)
114 | defer func() { _ = db.Close() }()
115 |
116 | ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
117 | defer cancel()
118 |
119 | _, err = db.QueryContext(ctx, "SELECT SLEEP(5)")
120 | xt.KO(t, err)
121 | xt.Assert(t, errors.Is(err, mysqlerrors.ErrContextDeadlineExceeded), err.Error())
122 | })
123 |
124 | t.Run("does not return sql.ErrNoRows", func(t *testing.T) {
125 | db, err := sql.Open("pxmysql", getTCPDSN("", ""))
126 | xt.OK(t, err)
127 | defer func() { _ = db.Close() }()
128 |
129 | stmt := `SELECT TABLE_NAME FROM information_schema.tables WHERE TABLE_SCHEMA = ? ORDER BY TABLE_SCHEMA`
130 | rows, err := db.QueryContext(context.Background(), stmt, "_this_does_not_exists_")
131 | xt.OK(t, err)
132 | xt.Assert(t, !rows.Next(), "expected no rows")
133 | })
134 |
135 | t.Run("QueryRowContext does return sql.ErrNoRows", func(t *testing.T) {
136 | db, err := sql.Open("pxmysql", getTCPDSN("", ""))
137 | xt.OK(t, err)
138 | defer func() { _ = db.Close() }()
139 |
140 | var name string
141 | stmt := `SELECT TABLE_NAME FROM information_schema.tables WHERE TABLE_SCHEMA = ? ORDER BY TABLE_SCHEMA`
142 | err = db.QueryRowContext(context.Background(), stmt, "_this_does_not_exists_").Scan(&name)
143 | xt.KO(t, err)
144 | xt.Assert(t, errors.Is(err, sql.ErrNoRows))
145 | })
146 | }
147 |
--------------------------------------------------------------------------------
/transaction.go:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2022, Geert JM Vanderkelen
2 |
3 | package pxmysql
4 |
5 | import (
6 | "context"
7 | "database/sql/driver"
8 | "fmt"
9 |
10 | "github.com/golistic/pxmysql/xmysql"
11 | )
12 |
13 | type Transaction struct {
14 | session *xmysql.Session
15 | }
16 |
17 | var _ driver.Tx = &Transaction{}
18 |
19 | func (tx *Transaction) Commit() error {
20 | if tx.session == nil {
21 | return fmt.Errorf("not connected (%w)", driver.ErrBadConn)
22 | }
23 |
24 | if _, err := tx.session.ExecuteStatement(context.Background(), "COMMIT"); err != nil {
25 | return err
26 | }
27 |
28 | return nil
29 | }
30 |
31 | func (tx *Transaction) Rollback() error {
32 | if tx.session == nil {
33 | return fmt.Errorf("not connected (%w)", driver.ErrBadConn)
34 | }
35 |
36 | if _, err := tx.session.ExecuteStatement(context.Background(), "ROLLBACK"); err != nil {
37 | return err
38 | }
39 |
40 | return nil
41 | }
42 |
--------------------------------------------------------------------------------
/xmysql/_testdata/.gitignore:
--------------------------------------------------------------------------------
1 | mysql_ca.pem
--------------------------------------------------------------------------------
/xmysql/_testdata/base.sql:
--------------------------------------------------------------------------------
-- Base fixture for the test suite: schemas, test accounts and cleanup of
-- objects left behind by earlier test runs.
CREATE SCHEMA IF NOT EXISTS pxmysql_tests;
CREATE SCHEMA IF NOT EXISTS pxmysql_tests_a;

-- The following users are used to test authentication through the MySQL
-- authentication plugins (mysql_native_password and caching_sha2_password).
CREATE USER IF NOT EXISTS 'user_native'@'%' IDENTIFIED WITH mysql_native_password
BY 'pwd_user_native';
GRANT ALL ON pxmysql_tests.* TO 'user_native'@'%';
GRANT ALL ON pxmysql_tests_a.* TO 'user_native'@'%';

CREATE USER IF NOT EXISTS 'user_sha256'@'%' IDENTIFIED WITH caching_sha2_password
BY 'pwd_user_sha256';
GRANT ALL ON pxmysql_tests.* TO 'user_sha256'@'%';
GRANT ALL ON pxmysql_tests_a.* TO 'user_sha256'@'%';

-- Account with an empty password.
CREATE USER IF NOT EXISTS 'pxmysqltest'@'%' IDENTIFIED WITH mysql_native_password BY '';
GRANT ALL ON pxmysql_tests.* TO 'pxmysqltest'@'%';
GRANT ALL ON pxmysql_tests_a.* TO 'pxmysqltest'@'%';

-- Clean up objects that might have been created by tests
DROP SCHEMA IF EXISTS `pxmysql_2839cks829dka`;
22 |
--------------------------------------------------------------------------------
/xmysql/_testdata/data_types_datetime.sql:
--------------------------------------------------------------------------------
-- Fixture for temporal data types: one row of ordinary values, one row of
-- upper bounds and one row of lower bounds.
USE pxmysql_tests;

DROP TABLE IF EXISTS `data_types_datetime`;

CREATE TABLE data_types_datetime
(
id TINYINT AUTO_INCREMENT,
dt_date DATE NOT NULL,
dt_time TIME(6) NOT NULL,
dt_datetime DATETIME(6) NOT NULL,
dt_timestamp TIMESTAMP NOT NULL,
dt_year YEAR NOT NULL,
PRIMARY KEY (id)
);

-- Use UTC for this session so the FROM_UNIXTIME() values are deterministic.
SET @@time_zone = '+00:00';
-- Row 1: ordinary values. Row 2: maximum DATE/TIME/DATETIME/TIMESTAMP.
-- Row 3: minimum DATE/TIME/DATETIME/TIMESTAMP. Rows 2 and 3 both use 1901 for YEAR.
INSERT INTO data_types_datetime
VALUES (1, '2005-03-01', '08:00:01.123456', '2005-03-01 07:00:01', FROM_UNIXTIME(1109660401),
2005),
(2, '9999-12-31', '838:59:59.0', '9999-12-31 23:59:59.999999', FROM_UNIXTIME(2147483647),
1901),
(3, '1000-01-01', '-838:59:59.0', '1000-01-01 00:00:00', FROM_UNIXTIME(1),
1901);
24 |
--------------------------------------------------------------------------------
/xmysql/_testdata/data_types_numeric.sql:
--------------------------------------------------------------------------------
-- Fixture for numeric data types, signed and unsigned.
USE pxmysql_tests;

DROP TABLE IF EXISTS `data_types_numeric`;

CREATE TABLE data_types_numeric
(
id TINYINT AUTO_INCREMENT,
numeric_bit BIT(6) NULL,
numeric_bool BOOL NULL,
numeric_tinyint TINYINT NULL,
numeric_tinyint_unsigned TINYINT UNSIGNED NULL,
numeric_smallint SMALLINT NOT NULL,
numeric_smallint_unsigned SMALLINT UNSIGNED NOT NULL,
numeric_mediumint MEDIUMINT NOT NULL,
numeric_mediumint_unsigned MEDIUMINT UNSIGNED NOT NULL,
numeric_int INT NOT NULL,
numeric_int_unsigned INT UNSIGNED NOT NULL,
numeric_bigint BIGINT NOT NULL,
numeric_bigint_unsigned BIGINT UNSIGNED NOT NULL,
numeric_decimal DECIMAL(65, 30) NOT NULL,
numeric_decimal2 DECIMAL(65, 1) NOT NULL,
numeric_decimal3 DECIMAL(18, 9) NOT NULL,
PRIMARY KEY (id)
);

-- Row 1: maximum signed integers and zero for the unsigned columns.
-- Row 2: minimum signed integers and maximum unsigned integers.
-- The DECIMAL columns use the same magnitudes with opposite signs.
INSERT INTO data_types_numeric
VALUES (1, b'100110', false, 127, 0,
32767, 0, 8388607, 0, 2147483647, 0,
9223372036854775807, 0,
3.14,
9999999999999999999999999999999999999999999999999999999999991234.9,
123456789.000001),
(2, b'000110', true, -128, 255,
-32768, 65535, -8388608, 16777215, -2147483648, 4294967295,
-9223372036854775808, 18446744073709551615,
-3.14,
-9999999999999999999999999999999999999999999999999999999999991234.5,
-123456789.000001)
;
40 |
--------------------------------------------------------------------------------
/xmysql/_testdata/data_types_string.sql:
--------------------------------------------------------------------------------
-- Fixture for string and binary data types.
USE pxmysql_tests;

DROP TABLE IF EXISTS `data_types_string`;

CREATE TABLE `data_types_string`
(
id TINYINT AUTO_INCREMENT,
s_char CHAR(255) NOT NULL,
s_varchar VARCHAR(400) NOT NULL,
s_binary BINARY(20) NOT NULL,
s_varbinary VARBINARY(20) NOT NULL,
s_longtext LONGTEXT NOT NULL,
s_tinyblob TINYBLOB NOT NULL,
s_enum ENUM ('Go', 'Python', 'JavaScript') NOT NULL,
s_set SET ('Go', 'Python', 'JavaScript') NOT NULL,
PRIMARY KEY (id)
);

-- CHAR and VARCHAR values fill their declared lengths exactly (4+251 and
-- 7+393 characters). The LONGTEXT value is 2 bytes short of
-- @@mysqlx_max_allowed_packet, so it exercises payloads close to the
-- X Protocol packet limit.
INSERT INTO data_types_string
VALUES (1,
CONCAT('CHAR', REPEAT('a', 251)),
CONCAT('VARCHAR', REPEAT('b', 393)),
X'0708090a0b0c0d0e0f10',
X'08090a0b0c0d0e0f10',
CONCAT('LONGTEXT', REPEAT('l', @@mysqlx_max_allowed_packet - 10)),
'I am a tiny blob',
'Go',
'Python,Go');
29 |
--------------------------------------------------------------------------------
/xmysql/_testdata/inserting.sql:
--------------------------------------------------------------------------------
-- Fixture for insert tests: an empty table with an auto-increment key, so
-- tests can check the generated IDs.
USE pxmysql_tests;

DROP TABLE IF EXISTS `inserts01`;

CREATE TABLE inserts01
(
id TINYINT NOT NULL AUTO_INCREMENT,
c1 VARCHAR(20),
PRIMARY KEY (id)
);
11 |
--------------------------------------------------------------------------------
/xmysql/_testdata/prepared_stmt.sql:
--------------------------------------------------------------------------------
-- Fixture for prepared-statement tests. Every data type category has two
-- empty tables: one with NOT NULL columns and one with nullable columns.
USE pxmysql_tests;

DROP TABLE IF EXISTS `numeric_not_null`, `numeric_null`;
DROP TABLE IF EXISTS `temporal_not_null`, `temporal_null`;
DROP TABLE IF EXISTS `strings_not_null`, `strings_null`;

-- Numeric types, NOT NULL.
CREATE TABLE numeric_not_null
(
id TINYINT AUTO_INCREMENT,
bit_ BIT(6) NOT NULL,
bool_ BOOL NOT NULL,
tinyint_ TINYINT NOT NULL,
tinyint_unsigned TINYINT UNSIGNED NOT NULL,
smallint_ SMALLINT NOT NULL,
smallint_unsigned SMALLINT UNSIGNED NOT NULL,
mediumint_ MEDIUMINT NOT NULL,
mediumint_unsigned MEDIUMINT UNSIGNED NOT NULL,
int_ INT NOT NULL,
int_unsigned INT UNSIGNED NOT NULL,
bigint_ BIGINT NOT NULL,
bigint_unsigned BIGINT UNSIGNED NOT NULL,
decimal_ DECIMAL(65, 30) NOT NULL,
float_ FLOAT NOT NULL,
float_unsigned FLOAT UNSIGNED NOT NULL,
double_ DOUBLE NOT NULL,
double_unsigned DOUBLE UNSIGNED NOT NULL,
PRIMARY KEY (id)
);

-- Numeric types, nullable.
CREATE TABLE numeric_null
(
id TINYINT AUTO_INCREMENT,
bit_ BIT(6) NULL,
bool_ BOOL NULL,
tinyint_ TINYINT NULL,
tinyint_unsigned TINYINT UNSIGNED NULL,
smallint_ SMALLINT NULL,
smallint_unsigned SMALLINT UNSIGNED NULL,
mediumint_ MEDIUMINT NULL,
mediumint_unsigned MEDIUMINT UNSIGNED NULL,
int_ INT NULL,
int_unsigned INT UNSIGNED NULL,
bigint_ BIGINT NULL,
bigint_unsigned BIGINT UNSIGNED NULL,
decimal_ DECIMAL(65, 30) NULL,
float_ FLOAT NULL,
float_unsigned FLOAT UNSIGNED NULL,
double_ DOUBLE NULL,
double_unsigned DOUBLE UNSIGNED NULL,
PRIMARY KEY (id)
);

-- Temporal types, NOT NULL (DATETIME/TIMESTAMP with microseconds).
CREATE TABLE temporal_not_null
(
id TINYINT AUTO_INCREMENT,
datetime_ DATETIME(6) NOT NULL,
date_ DATE NOT NULL,
timestamp_ TIMESTAMP(6) NOT NULL,
year_ YEAR NOT NULL,
PRIMARY KEY (id)
);

-- Temporal types, nullable.
CREATE TABLE temporal_null
(
id TINYINT AUTO_INCREMENT,
datetime_ DATETIME(6) NULL,
date_ DATE NULL,
timestamp_ TIMESTAMP(6) NULL,
year_ YEAR NULL,
PRIMARY KEY (id)
);

-- String and binary types, NOT NULL.
CREATE TABLE strings_not_null
(
id TINYINT AUTO_INCREMENT,
char_ CHAR(255) NOT NULL,
binary_ BINARY(255) NOT NULL,
varchar_ VARCHAR(600) NOT NULL,
varbinary_ VARBINARY(410) NOT NULL,
tinyblob_ TINYBLOB NOT NULL,
tinytext_ TINYTEXT NOT NULL,
blob_ BLOB NOT NULL,
text_ TEXT NOT NULL,
mediumblob_ MEDIUMBLOB NOT NULL,
mediumtext_ MEDIUMTEXT NOT NULL,
longblob_ LONGBLOB NOT NULL,
longtext_ LONGTEXT NOT NULL,
enum_ ENUM ('Earth', 'Moon', 'Mars', 'Europa') NOT NULL,
set_ SET ('Earth', 'Moon', 'Mars') NOT NULL,
PRIMARY KEY (id)
);

-- String and binary types, nullable.
-- NOTE(review): enum_ here lacks 'Europa', unlike strings_not_null;
-- confirm this difference is intentional.
CREATE TABLE strings_null
(
id TINYINT AUTO_INCREMENT,
char_ CHAR(255) NULL,
binary_ BINARY(255) NULL,
varchar_ VARCHAR(600) NULL,
varbinary_ VARBINARY(410) NULL,
tinyblob_ TINYBLOB NULL,
tinytext_ TINYTEXT NULL,
blob_ BLOB NULL,
text_ TEXT NULL,
mediumblob_ MEDIUMBLOB NULL,
mediumtext_ MEDIUMTEXT NULL,
longblob_ LONGBLOB NULL,
longtext_ LONGTEXT NULL,
enum_ ENUM ('Earth', 'Moon', 'Mars'),
set_ SET ('Earth', 'Moon', 'Mars'),
PRIMARY KEY (id)
);
--------------------------------------------------------------------------------
/xmysql/_testdata/schema_collections.sql:
--------------------------------------------------------------------------------
-- Fixture for listing schema objects: one ordinary table, and two tables
-- with the column layout of an X DevAPI document collection (a JSON `doc`
-- column, a generated `_id` primary key and a `_json_schema` check).
USE pxmysql_tests;

DROP TABLE IF EXISTS not_collection_28380dew22;
DROP TABLE IF EXISTS collection_wic28skwixkd;
DROP TABLE IF EXISTS collection_weux73293jsnsj;

-- Plain table; must not be reported as a collection.
CREATE TABLE not_collection_28380dew22
(
id INT
);

CREATE TABLE collection_wic28skwixkd
(
doc json null,
_id varbinary(32) as (json_unquote(json_extract(`doc`, _utf8mb4'$._id'))) stored
primary key,
_json_schema json as (_utf8mb4'{"type":"object"}'),
constraint $val_strict_wic28skwixkd
check (json_schema_valid(`_json_schema`, `doc`))
);

CREATE TABLE collection_weux73293jsnsj
(
doc json null,
_id varbinary(32) as (json_unquote(json_extract(`doc`, _utf8mb4'$._id'))) stored
primary key,
_json_schema json as (_utf8mb4'{"type":"object"}'),
constraint $val_strict_weux73293jsnsj
check (json_schema_valid(`_json_schema`, `doc`))
);
--------------------------------------------------------------------------------
/xmysql/authentication.go:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2022, Geert JM Vanderkelen
2 |
3 | package xmysql
4 |
5 | import (
6 | "crypto/sha1"
7 | "crypto/sha256"
8 | "fmt"
9 | )
10 |
// authn holds the inputs used to build the authentication payload sent with
// the AuthenticateContinue message.
type authn struct {
	username string
	// password is the plain-text password. For the SHA256 and MYSQL41
	// helpers, an empty password means no scramble is sent.
	password string
	schema   string
	// challenge is the server-provided challenge; the hashing helpers require
	// exactly authChallengeLen bytes.
	challenge []byte
}
17 |
18 | // authSHA256Data prepares authentication data to be sent with the AuthenticateContinue
19 | // message using SHA256. Username and scrambled password are returned as hex.
20 | // See: https://dev.mysql.com/doc/internals/en/x-protocol-authentication-authentication.html.
21 | func authSHA256Data(an authn) ([]byte, error) {
22 | if len(an.challenge) != authChallengeLen {
23 | return nil, fmt.Errorf("authentication challenge must be 20 bytes (was %d)", len(an.challenge))
24 | }
25 |
26 | var scramble string
27 | if an.password != "" {
28 | // hex(sha256(password) XOR sha256(challenge + sha256(sha256(password))))
29 | h1 := sha256.Sum256([]byte(an.password))
30 | hh1 := sha256.Sum256(h1[:])
31 |
32 | hr := sha256.New()
33 | hr.Write(hh1[:])
34 | hr.Write(an.challenge)
35 | h2 := hr.Sum(nil)
36 |
37 | for i := range h2 {
38 | h1[i] ^= h2[i]
39 | }
40 | scramble = fmt.Sprintf("%x", h1)
41 | }
42 |
43 | return []byte(fmt.Sprintf("%s\x00%s\x00%s", an.schema, an.username, scramble)), nil
44 | }
45 |
46 | // authMYSQL41Data prepares authentication data to be sent with the AuthenticateContinue
47 | // message using SHA1 (also known as mysql_native_password). Username and scrambled password
48 | // are returned as hex.
49 | // See: https://dev.mysql.com/doc/internals/en/x-protocol-authentication-authentication.html.
50 | func authMySQL41Data(an authn) ([]byte, error) {
51 | if len(an.challenge) != authChallengeLen {
52 | return nil, fmt.Errorf("authentication challenge must be 20 bytes (was %d)", len(an.challenge))
53 | }
54 |
55 | var scramble string
56 | if an.password != "" {
57 | // hex(sha1(password) XOR sha1(challenge + sha1(sha1(password))))
58 | h1 := sha1.Sum([]byte(an.password))
59 | hh1 := sha1.Sum(h1[:])
60 |
61 | hr := sha1.New()
62 | hr.Write(an.challenge)
63 | hr.Write(hh1[:])
64 | h2 := hr.Sum(nil)
65 |
66 | for i := range h1 {
67 | h1[i] ^= h2[i]
68 | }
69 |
70 | scramble = fmt.Sprintf("*%x", h1)
71 | }
72 |
73 | return []byte(fmt.Sprintf("%s\x00%s\x00%s", an.schema, an.username, scramble)), nil
74 | }
75 |
76 | // authMySQLPlain prepares authentication data to be sent in plain text. This is only
77 | // supported when connection is encrypted (TLS)
78 | // See: https://dev.mysql.com/doc/internals/en/x-protocol-authentication-authentication.html.
79 | func authMySQLPlain(an authn) ([]byte, error) {
80 | return []byte(fmt.Sprintf("%s\x00%s\x00%s", an.schema, an.username, an.password)), nil
81 | }
82 |
--------------------------------------------------------------------------------
/xmysql/capabilities.go:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2022, Geert JM Vanderkelen
2 |
3 | package xmysql
4 |
5 | import (
6 | "fmt"
7 |
8 | "github.com/golistic/pxmysql/internal/mysqlx/mysqlxconnection"
9 | "github.com/golistic/pxmysql/xmysql/internal/network"
10 | )
11 |
// ServerCapabilities holds the capabilities returned by the server that
// this package interprets.
type ServerCapabilities struct {
	TLS            bool     // value of the "tls" capability
	AuthMechanisms []string // names listed in the "authentication.mechanisms" capability
}
17 |
18 | // NewServerCapabilitiesFromMessage instantiates a new ServerCapabilities object
19 | // using a message returned by MySQL Server's X Plugin.
20 | func NewServerCapabilitiesFromMessage(msg *network.ServerMessage) (*ServerCapabilities, error) {
21 | capabilities := &mysqlxconnection.Capabilities{}
22 | if err := msg.Unmarshall(capabilities); err != nil {
23 | return nil, fmt.Errorf("message was not mysqlxconnection.Capabilities")
24 | }
25 |
26 | sc := &ServerCapabilities{}
27 |
28 | for _, c := range capabilities.Capabilities {
29 | switch c.GetName() {
30 | case "tls":
31 | sc.TLS = c.Value.Scalar.GetVBool()
32 | case "authentication.mechanisms":
33 | sc.AuthMechanisms = []string{}
34 | for _, m := range c.Value.Array.Value {
35 | sc.AuthMechanisms = append(sc.AuthMechanisms, string(m.Scalar.GetVString().GetValue()))
36 | }
37 | }
38 | }
39 |
40 | return sc, nil
41 | }
42 |
--------------------------------------------------------------------------------
/xmysql/collations.go:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2022, 2023, Geert JM Vanderkelen
2 |
3 | package xmysql
4 |
// CollationID is the numeric MySQL identifier of a collation.
type CollationID int

// Collation describes a MySQL collation: its numeric ID, its name and the
// character set it belongs to.
type Collation struct {
	ID      int    `json:"id"`
	Name    string `json:"name"`
	CharSet string `json:"charSet"`
}
12 |
13 | // IsSupportedCollation returns whether c is a valid/supported collation. Note that
14 | // MySQL Protocol X only supports the utf8mb4 character set, and consequently, only collations of utf8mb4.
15 | // Argument c can be the internal MYSQL ID or MySQL name.
16 | func IsSupportedCollation[T string | uint64 | int](c T) bool {
17 | var have bool
18 | switch v := any(c).(type) {
19 | case string:
20 | _, have = Collations[v]
21 | case uint64:
22 | _, have = collationIDs[v]
23 | case int:
24 | _, have = collationIDs[uint64(v)]
25 | }
26 | return have
27 | }
28 |
--------------------------------------------------------------------------------
/xmysql/collations_test.go:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2022, Geert JM Vanderkelen
2 |
3 | package xmysql_test
4 |
5 | import (
6 | "testing"
7 |
8 | "github.com/golistic/xgo/xt"
9 |
10 | "github.com/golistic/pxmysql/xmysql"
11 | )
12 |
13 | func TestIsSupportedCollation(t *testing.T) {
14 | t.Run("using MySQL ID", func(t *testing.T) {
15 | xt.Assert(t, xmysql.IsSupportedCollation(241))
16 | xt.Assert(t, xmysql.IsSupportedCollation(uint64(241)))
17 | })
18 |
19 | t.Run("using MySQL name", func(t *testing.T) {
20 | xt.Assert(t, xmysql.IsSupportedCollation("utf8mb4_esperanto_ci"))
21 | })
22 | }
23 |
--------------------------------------------------------------------------------
/xmysql/collection.go:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2023, Geert JM Vanderkelen
2 |
3 | package xmysql
4 |
5 | import (
6 | "context"
7 | "fmt"
8 | "slices"
9 | )
10 |
// Collection is a handle to a named document collection within a Schema.
// Creating the handle does not check that the collection exists; see
// CheckExistence.
type Collection struct {
	schema  *Schema
	session *Session // the schema's session, captured by newCollection
	name    string
}

var (
	// Compile-time check that *Collection implements adder.
	_ adder = (*Collection)(nil)
)
20 |
21 | // newCollection instantiates a new Collection object with schema.
22 | func newCollection(schema *Schema, name string) (*Collection, error) {
23 | if schema == nil || schema.session == nil {
24 | return nil, fmt.Errorf("session closed")
25 | }
26 |
27 | if name == "" {
28 | return nil, fmt.Errorf("invalid name")
29 | }
30 |
31 | return &Collection{
32 | schema: schema,
33 | session: schema.session,
34 | name: name,
35 | }, nil
36 | }
37 |
38 | func (c *Collection) String() string {
39 | return fmt.Sprintf("", c.name, c.schema)
40 | }
41 |
42 | // Name returns the collection.
43 | func (c *Collection) Name() string {
44 | return c.name
45 | }
46 |
47 | func (c *Collection) CheckExistence(ctx context.Context) error {
48 | names, err := c.schema.objectNames(ctx, ObjectCollection)
49 | if err != nil {
50 | return err
51 | }
52 |
53 | if _, ok := slices.BinarySearch(names, c.name); !ok {
54 | return ErrNotAvailable
55 | }
56 |
57 | return nil
58 | }
59 |
60 | func (c *Collection) Add(object ...any) *Add {
61 |
62 | return NewAdd(c).Add(object...)
63 | }
64 |
65 | func (c *Collection) Remove(ctx context.Context) *Add {
66 |
67 | return nil
68 | }
69 |
--------------------------------------------------------------------------------
/xmysql/collection/create.go:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2023, Geert JM Vanderkelen
2 |
3 | package collection
4 |
// CreateOptions holds the settings used when creating a collection.
type CreateOptions struct {
	// ReuseExisting is set by CreateReuseExisting; judging by its name it
	// allows reusing an already existing collection (confirm at call site).
	ReuseExisting bool
}

// CreateOption modifies CreateOptions; see NewCreateOptions.
type CreateOption func(opts *CreateOptions)

// NewCreateOptions returns zero-valued CreateOptions with every option in
// opts applied to it, in order.
func NewCreateOptions(opts ...CreateOption) *CreateOptions {
	var options CreateOptions
	for _, apply := range opts {
		apply(&options)
	}
	return &options
}

// CreateReuseExisting returns an option setting ReuseExisting to true.
func CreateReuseExisting() CreateOption {
	return func(opts *CreateOptions) { opts.ReuseExisting = true }
}
26 |
--------------------------------------------------------------------------------
/xmysql/collection/get.go:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2023, Geert JM Vanderkelen
2 |
3 | package collection
4 |
// GetOptions holds the settings used when retrieving a collection.
type GetOptions struct {
	// ValidateExistence is set by GetValidateExistence; judging by its name
	// it requests checking that the collection exists (confirm at call site).
	ValidateExistence bool
}

// GetOption modifies GetOptions; see NewGetOptions.
type GetOption func(opts *GetOptions)

// NewGetOptions returns zero-valued GetOptions with every option in opts
// applied to it, in order.
func NewGetOptions(opts ...GetOption) *GetOptions {
	var options GetOptions
	for _, apply := range opts {
		apply(&options)
	}
	return &options
}

// GetValidateExistence returns an option setting ValidateExistence to true.
func GetValidateExistence() GetOption {
	return func(opts *GetOptions) { opts.ValidateExistence = true }
}
26 |
--------------------------------------------------------------------------------
/xmysql/collection_test.go:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2023, Geert JM Vanderkelen
2 |
3 | package xmysql_test
4 |
5 | import (
6 | "context"
7 | "encoding/json"
8 | "errors"
9 | "sort"
10 | "testing"
11 |
12 | "github.com/golistic/xgo/xt"
13 |
14 | "github.com/golistic/pxmysql/internal/xxt"
15 | "github.com/golistic/pxmysql/null"
16 | "github.com/golistic/pxmysql/xmysql"
17 | )
18 |
// Person is the document type stored in the test collections; the JSON tags
// define the document field names.
type Person struct {
	Name string `json:"name"`
	Age  int    `json:"age"`
}
23 |
24 | func crudTestCollection(t *testing.T, name string) (*xmysql.Schema, *xmysql.Collection) {
25 | config := &xmysql.ConnectConfig{
26 | Address: testContext.XPluginAddr,
27 | Username: xxt.UserNative,
28 | Schema: "pxmysql_tests",
29 | }
30 | config.SetPassword(xxt.UserNativePwd)
31 |
32 | ses, err := xmysql.GetSession(context.Background(), config)
33 | xt.OK(t, err)
34 |
35 | schema, err := ses.GetSchema(context.Background())
36 | xt.OK(t, err)
37 |
38 | c, err := schema.CreateCollection(context.Background(), name)
39 | xt.OK(t, err)
40 |
41 | return schema, c
42 | }
43 |
44 | func TestCollection_Add(t *testing.T) {
45 | schema, coll := crudTestCollection(t, "person_2987dk8dj0s")
46 |
47 | t.Run("can only add struct", func(t *testing.T) {
48 | err := coll.Add(1).GetError()
49 | xt.KO(t, err)
50 | xt.Eq(t, "unsupported object kind int", err.Error())
51 | })
52 |
53 | t.Run("object as pointer value", func(t *testing.T) {
54 | xt.OK(t, coll.Add(&Person{Name: "Alice"}).GetError())
55 | })
56 |
57 | t.Run("object as value", func(t *testing.T) {
58 | xt.OK(t, coll.Add(Person{Name: "Alice,c"}).GetError())
59 | })
60 |
61 | t.Run("execute stores data", func(t *testing.T) {
62 | xt.OK(t, coll.
63 | Add(&Person{Name: "Laurie", Age: 19}).
64 | Add(&Person{Name: "Nadya", Age: 54}, &Person{Name: "Lucas", Age: 32}).
65 | Execute(context.Background()))
66 | exp := []string{"Laurie", "Nadya", "Lucas"}
67 | sort.Strings(exp)
68 |
69 | ses := schema.GetSession()
70 | res, err := ses.ExecuteStatement(context.Background(), "SELECT doc FROM person_2987dk8dj0s")
71 | xt.OK(t, err)
72 |
73 | var got []string
74 | for _, row := range res.Rows {
75 | doc, ok := row.Values[0].(null.Bytes)
76 | xt.Assert(t, ok, "null.Bytes")
77 | p := Person{}
78 | xt.OK(t, json.Unmarshal(doc.Bytes, &p))
79 | got = append(got, p.Name)
80 | }
81 | sort.Strings(got)
82 |
83 | xt.Eq(t, exp, got)
84 | })
85 |
86 | t.Run("execute to return error stored by adding", func(t *testing.T) {
87 | adder := coll.Add(&Person{Name: "Laurie", Age: 19}).Add("something not OK")
88 | err := adder.Execute(context.Background())
89 | xt.KO(t, err)
90 | xt.Eq(t, "unsupported object kind string", errors.Unwrap(err).Error())
91 | })
92 | }
93 |
--------------------------------------------------------------------------------
/xmysql/connection_config.go:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2022, Geert JM Vanderkelen
2 |
3 | package xmysql
4 |
5 | import (
6 | "github.com/golistic/xgo/xstrings"
7 | )
8 |
9 | // ConnectConfig manages the configuration of a connection to a MySQL server.
10 | type ConnectConfig struct {
11 | Address string
12 | UnixSockAddr string
13 | Username string
14 | Password *string
15 | Schema string
16 | UseTLS bool
17 | AuthMethod AuthMethodType
18 | TLSServerCACertPath string `envVar:"PXMYSQL_CA_CERT"`
19 | TimeZoneName string
20 | }
21 |
22 | // DefaultConnectConfig is the default configuration used if none is provided
23 | // when a Connection is instantiated.
24 | var DefaultConnectConfig = &ConnectConfig{
25 | Address: "127.0.0.1:33060", // note that the port number is of X Plugin
26 | Username: "root",
27 | Password: xstrings.Pointer(""),
28 | Schema: "",
29 | UseTLS: false,
30 | AuthMethod: AuthMethodAuto,
31 | }
32 |
33 | // Clone duplicates other, but leaves the password nil. The caller must
34 | // save the password.
35 | func (cfg *ConnectConfig) Clone() *ConnectConfig {
36 | return &ConnectConfig{
37 | Address: cfg.Address,
38 | UnixSockAddr: cfg.UnixSockAddr,
39 | Username: cfg.Username,
40 | Password: nil,
41 | Schema: cfg.Schema,
42 | UseTLS: cfg.UseTLS,
43 | AuthMethod: cfg.AuthMethod,
44 | TLSServerCACertPath: cfg.TLSServerCACertPath,
45 | TimeZoneName: cfg.TimeZoneName,
46 | }
47 | }
48 |
49 | // SetPassword sets the password within cfg. If no password is provided,
50 | // the Password-field of cfg will be nil.
51 | // Panics when p has more than 1 element.
52 | func (cfg *ConnectConfig) SetPassword(p ...string) *ConnectConfig {
53 | switch len(p) {
54 | case 1:
55 | cfg.Password = xstrings.Pointer(p[0])
56 | case 0:
57 | cfg.Password = nil
58 | default:
59 | panic("accepting only 1 optional string")
60 | }
61 |
62 | return cfg
63 | }
64 |
--------------------------------------------------------------------------------
/xmysql/context.go:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2022, Geert JM Vanderkelen
2 |
3 | package xmysql
4 |
5 | import (
6 | "context"
7 | "time"
8 | )
9 |
// CtxKey is the type of the context keys defined by this package.
type CtxKey struct{}

// CtxTimeLocation is the context key under which the time location used for
// decoding MySQL DATETIME and TIMESTAMP values is stored.
var CtxTimeLocation = &CtxKey{}

// DefaultTimeLocation is used when the context holds no (non-nil) location.
var DefaultTimeLocation = time.UTC

// SetContextTimeLocation sets the time location used when decoding MySQL DATETIME and
// TIMESTAMP to Go `time.Time` objects. If l is nil, it is unset, and the default will
// be used.
func SetContextTimeLocation(ctx context.Context, l *time.Location) context.Context {
	return context.WithValue(ctx, CtxTimeLocation, l)
}

// ContextTimeLocation retrieves the time location set in context used when decoding
// MySQL DATETIME and TIMESTAMP to Go `time.Time`. If none is defined in context,
// the stored value is not a `*time.Location`, or it is a nil location, the
// default will be returned.
func ContextTimeLocation(ctx context.Context) *time.Location {
	// A nil *time.Location stored by SetContextTimeLocation is a non-nil
	// interface value, so nil must be rejected explicitly; otherwise callers
	// get a nil location instead of the documented default.
	if l, ok := ctx.Value(CtxTimeLocation).(*time.Location); ok && l != nil {
		return l
	}

	return DefaultTimeLocation
}
35 |
--------------------------------------------------------------------------------
/xmysql/context_test.go:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2022, Geert JM Vanderkelen
2 |
3 | package xmysql
4 |
5 | import (
6 | "context"
7 | "testing"
8 |
9 | "github.com/golistic/xgo/xt"
10 | )
11 |
12 | func TestContextTimeLocation(t *testing.T) {
13 | t.Run("no time location in context", func(t *testing.T) {
14 | xt.Eq(t, DefaultTimeLocation.String(), ContextTimeLocation(context.Background()).String())
15 | })
16 | }
17 |
--------------------------------------------------------------------------------
/xmysql/crud_add.go:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2023, Geert JM Vanderkelen
2 |
3 | package xmysql
4 |
5 | import (
6 | "context"
7 | "fmt"
8 | "reflect"
9 |
10 | "github.com/golistic/xgo/xstrings"
11 |
12 | "github.com/golistic/pxmysql/internal/mysqlx/mysqlxcrud"
13 | mysqlxexpr "github.com/golistic/pxmysql/internal/mysqlx/mysqlxexpr"
14 | "github.com/golistic/pxmysql/xmysql/xproto"
15 | )
16 |
17 | type cruder interface {
18 | Execute(ctx context.Context) error
19 | GetError() error
20 | }
21 |
22 | type adder interface {
23 | Add(object ...any) *Add
24 | }
25 |
26 | type Add struct {
27 | collection *Collection
28 | values []any
29 | err error
30 | }
31 |
32 | var (
33 | _ cruder = (*Add)(nil)
34 | _ adder = (*Add)(nil)
35 | )
36 |
37 | func NewAdd(c *Collection) *Add {
38 |
39 | return &Add{collection: c}
40 | }
41 |
42 | // Add adds object to the queue.
43 | func (a *Add) Add(objects ...any) *Add {
44 |
45 | for _, object := range objects {
46 | rt := reflect.TypeOf(object)
47 | if reflect.ValueOf(object).Kind() == reflect.Pointer {
48 | rt = rt.Elem()
49 | }
50 | if rt.Kind() != reflect.Struct {
51 | a.err = fmt.Errorf("unsupported object kind %s", rt.Kind())
52 | }
53 |
54 | a.values = append(a.values, object)
55 | }
56 |
57 | return a
58 | }
59 |
60 | func (a *Add) Execute(ctx context.Context) error {
61 |
62 | errBaseMsg := "adding to collection %s (%w)"
63 |
64 | if a.err != nil {
65 | return fmt.Errorf(errBaseMsg, a.collection.name, a.err)
66 | }
67 |
68 | rows := make([]*mysqlxcrud.Insert_TypedRow, len(a.values))
69 |
70 | for i, v := range a.values {
71 | rows[i] = &mysqlxcrud.Insert_TypedRow{
72 | Field: []*mysqlxexpr.Expr{
73 | {
74 | Type: mysqlxexpr.Expr_OBJECT.Enum(),
75 | Object: xproto.StructExpr(v),
76 | },
77 | },
78 | }
79 | }
80 |
81 | msg := &mysqlxcrud.Insert{
82 | Collection: &mysqlxcrud.Collection{
83 | Name: xstrings.Pointer(a.collection.Name()),
84 | Schema: xstrings.Pointer(a.collection.schema.Name()),
85 | },
86 | DataModel: mysqlxcrud.DataModel_DOCUMENT.Enum(),
87 | Projection: nil,
88 | Row: rows,
89 | }
90 |
91 | ses := a.collection.schema.GetSession()
92 | if err := ses.Write(ctx, msg); err != nil {
93 | return fmt.Errorf(errBaseMsg, a.collection.name, err)
94 | }
95 |
96 | _, err := ses.handleResult(ctx, func(r *Result) bool {
97 | return r.stmtOK
98 | })
99 | if err != nil {
100 | return fmt.Errorf(errBaseMsg, a.collection.name, err)
101 | }
102 |
103 | return nil
104 | }
105 |
106 | func (a *Add) GetError() error {
107 |
108 | return a.err
109 | }
110 |
--------------------------------------------------------------------------------
/xmysql/defaults.go:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2023, Geert JM Vanderkelen
2 |
3 | package xmysql
4 |
// authChallengeLen is the length of an authentication challenge (presumably
// in bytes — confirm at the authentication code).
const authChallengeLen = 20

// Authentication methods which can be set in ConnectConfig.AuthMethod.
const (
	AuthMethodPlain        AuthMethodType = "PLAIN"
	AuthMethodAuto         AuthMethodType = "AUTO"
	AuthMethodSHA256Memory AuthMethodType = "SHA256_MEMORY"
	AuthMethodMySQL41      AuthMethodType = "MYSQL41"
)

// DefaultPort is the default port of the MySQL X Plugin.
const DefaultPort = "33060"

// DefaultHost is the default host address.
const DefaultHost = "127.0.0.1"

// AuthMethodType names an authentication method.
type AuthMethodType string

// AuthMethodTypes is a list of authentication methods.
type AuthMethodTypes []AuthMethodType

// Has reports whether m is part of a.
func (a AuthMethodTypes) Has(m AuthMethodType) bool {
	for _, v := range a {
		if v == m {
			return true
		}
	}
	return false
}

var defaultAuthMethods = []AuthMethodType{AuthMethodMySQL41, AuthMethodSHA256Memory}

var supportedAuthMethods = AuthMethodTypes{AuthMethodSHA256Memory, AuthMethodMySQL41, AuthMethodPlain, AuthMethodAuto}

// DefaultAuthMethods returns the authentication methods used by default.
// The result is a copy: modifying it does not alter the package defaults.
func DefaultAuthMethods() []AuthMethodType {
	return append([]AuthMethodType(nil), defaultAuthMethods...)
}

// SupportedAuthMethods returns the authentication methods supported by this
// package. The result is a copy: modifying it does not alter the package list.
func SupportedAuthMethods() AuthMethodTypes {
	return append(AuthMethodTypes(nil), supportedAuthMethods...)
}
41 |
--------------------------------------------------------------------------------
/xmysql/errors.go:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2023, Geert JM Vanderkelen
2 |
3 | package xmysql
4 |
5 | import "fmt"
6 |
var (
	// ErrNotAvailable is returned when a requested object is not available,
	// for example by Collection.CheckExistence for a missing collection.
	ErrNotAvailable = fmt.Errorf("not available")
)
8 |
--------------------------------------------------------------------------------
/xmysql/examples_test.go:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2022, Geert JM Vanderkelen
2 |
3 | package xmysql_test
4 |
5 | import (
6 | "context"
7 | "fmt"
8 | "log"
9 |
10 | "github.com/golistic/pxmysql/null"
11 | "github.com/golistic/pxmysql/xmysql"
12 | )
13 |
14 | func ExampleGetSession_auto_notls() {
15 | config := &xmysql.ConnectConfig{
16 | Address: "127.0.0.1:53360", // see _support/pxmysql-compose/docker-compose.yml
17 | Username: "user_native",
18 | }
19 | config.SetPassword("pwd_user_native")
20 |
21 | session, err := xmysql.GetSession(context.Background(), config)
22 | if err != nil {
23 | log.Fatal(err)
24 | }
25 | fmt.Println("TLS:", session.UsesTLS())
26 | // Output: TLS: false
27 | }
28 |
29 | func ExampleGetSession_plain_withtls() {
30 | config := &xmysql.ConnectConfig{
31 | Address: "127.0.0.1:53360", // see _support/pxmysql-compose/docker-compose.yml
32 | AuthMethod: xmysql.AuthMethodPlain,
33 | UseTLS: true,
34 | Username: "user_native",
35 | }
36 | config.SetPassword("pwd_user_native")
37 |
38 | session, err := xmysql.GetSession(context.Background(), config)
39 | if err != nil {
40 | log.Fatal(err)
41 | }
42 |
43 | fmt.Println("TLS:", session.UsesTLS())
44 | fmt.Println("Auth Method:", config.AuthMethod)
45 | // Output:
46 | // TLS: true
47 | // Auth Method: PLAIN
48 | }
49 |
50 | func ExampleSession_ExecuteStatement() {
51 | config := &xmysql.ConnectConfig{
52 | Address: "127.0.0.1:53360", // see _support/pxmysql-compose/docker-compose.yml
53 | AuthMethod: xmysql.AuthMethodPlain,
54 | UseTLS: true,
55 | Username: "user_native",
56 | }
57 | config.SetPassword("pwd_user_native")
58 |
59 | session, err := xmysql.GetSession(context.Background(), config)
60 | if err != nil {
61 | log.Fatal(err)
62 | }
63 |
64 | q := "SELECT ?, STR_TO_DATE('2005-03-01 07:00:01', '%Y-%m-%d %H:%i:%s')"
65 | res, err := session.ExecuteStatement(context.Background(), q, "started")
66 | if err != nil {
67 | log.Fatal(err)
68 | }
69 |
70 | for _, row := range res.Rows {
71 | fmt.Printf("%s at %s\n", row.Values[0].(string), row.Values[1].(null.Time).Time)
72 | }
73 |
74 | // Output:
75 | // started at 2005-03-01 07:00:01 +0000 UTC
76 | }
77 |
--------------------------------------------------------------------------------
/xmysql/internal/network/messages.go:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2023, Geert JM Vanderkelen
2 |
3 | package network
4 |
5 | import (
6 | "database/sql/driver"
7 | "encoding/binary"
8 | "errors"
9 | "fmt"
10 | "io"
11 | "os"
12 |
13 | "google.golang.org/protobuf/proto"
14 |
15 | "github.com/golistic/pxmysql/interfaces"
16 | "github.com/golistic/pxmysql/internal/mysqlx/mysqlx"
17 | "github.com/golistic/pxmysql/internal/mysqlx/mysqlxconnection"
18 | "github.com/golistic/pxmysql/internal/mysqlx/mysqlxcrud"
19 | "github.com/golistic/pxmysql/internal/mysqlx/mysqlxprepare"
20 | "github.com/golistic/pxmysql/internal/mysqlx/mysqlxsession"
21 | "github.com/golistic/pxmysql/internal/mysqlx/mysqlxsql"
22 | "github.com/golistic/pxmysql/mysqlerrors"
23 | )
24 |
25 | type ServerMessage struct {
26 | msgType int
27 | payload []byte
28 | }
29 |
30 | var _ interfaces.ServerMessager = &ServerMessage{}
31 |
32 | var maxServerMessageType int32
33 |
34 | func init() {
35 | for n := range mysqlx.ServerMessages_Type_name {
36 | if n > maxServerMessageType {
37 | maxServerMessageType = n
38 | }
39 | }
40 | }
41 |
42 | func (m *ServerMessage) Unmarshall(into proto.Message) error {
43 | if err := UnmarshalPartial(m.payload, into); err != nil {
44 | return fmt.Errorf("failed unmarshalling server message type %s (%w)",
45 | mysqlx.ServerMessages_Type(m.msgType).String(), err)
46 | }
47 | return nil
48 | }
49 |
50 | func (m *ServerMessage) ServerMessageType() mysqlx.ServerMessages_Type {
51 | return mysqlx.ServerMessages_Type(m.msgType)
52 | }
53 |
54 | func readMessage(r io.Reader) (*ServerMessage, error) {
55 | var header [5]byte
56 | if n, err := io.ReadFull(r, header[:]); err != nil {
57 | if errors.Is(err, os.ErrDeadlineExceeded) {
58 | err = os.ErrDeadlineExceeded
59 | }
60 | if errors.Is(err, io.EOF) && n < 5 {
61 | return nil, fmt.Errorf("broken pipe when reading (%w)", driver.ErrBadConn)
62 | }
63 | return nil, fmt.Errorf("failed reading message header (%w)", err)
64 | }
65 |
66 | if header[4] == 0x0a || int32(header[4]) > maxServerMessageType {
67 | return nil, mysqlerrors.New(2007)
68 | }
69 |
70 | msg := &ServerMessage{
71 | msgType: int(header[4]),
72 | }
73 |
74 | msg.payload = make([]byte, binary.LittleEndian.Uint32(header[0:4])-1)
75 | if _, err := io.ReadFull(r, msg.payload); err != nil {
76 | if errors.Is(err, os.ErrDeadlineExceeded) {
77 | err = os.ErrDeadlineExceeded
78 | }
79 | return nil, fmt.Errorf("failed reading message payload (%w)", err)
80 | }
81 |
82 | return msg, nil
83 | }
84 |
85 | func clientMessageType(msg proto.Message) (mysqlx.ClientMessages_Type, error) {
86 | // cases ordered as ClientMessage_Type constants
87 | switch msg.(type) {
88 | case *mysqlxconnection.CapabilitiesGet:
89 | return mysqlx.ClientMessages_CON_CAPABILITIES_GET, nil
90 | case *mysqlxconnection.CapabilitiesSet:
91 | return mysqlx.ClientMessages_CON_CAPABILITIES_SET, nil
92 | case *mysqlxconnection.Close:
93 | return mysqlx.ClientMessages_CON_CLOSE, nil
94 |
95 | case *mysqlxprepare.Execute:
96 | return mysqlx.ClientMessages_PREPARE_EXECUTE, nil
97 | case *mysqlxprepare.Prepare:
98 | return mysqlx.ClientMessages_PREPARE_PREPARE, nil
99 | case *mysqlxprepare.Deallocate:
100 | return mysqlx.ClientMessages_PREPARE_DEALLOCATE, nil
101 |
102 | case *mysqlxsession.AuthenticateStart:
103 | return mysqlx.ClientMessages_SESS_AUTHENTICATE_START, nil
104 | case *mysqlxsession.AuthenticateContinue:
105 | return mysqlx.ClientMessages_SESS_AUTHENTICATE_CONTINUE, nil
106 | case *mysqlxsession.Reset:
107 | return mysqlx.ClientMessages_SESS_RESET, nil
108 | case *mysqlxsession.Close:
109 | return mysqlx.ClientMessages_SESS_CLOSE, nil
110 |
111 | case *mysqlxcrud.Insert:
112 | return mysqlx.ClientMessages_CRUD_INSERT, nil
113 |
114 | case *mysqlxsql.StmtExecute:
115 | return mysqlx.ClientMessages_SQL_STMT_EXECUTE, nil
116 | default:
117 | return 0, fmt.Errorf("unsupported message '%T'", msg)
118 | }
119 | }
120 |
--------------------------------------------------------------------------------
/xmysql/internal/network/proto.go:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2023, Geert JM Vanderkelen
2 |
3 | package network
4 |
5 | import (
6 | "google.golang.org/protobuf/encoding/protowire"
7 | "google.golang.org/protobuf/proto"
8 | )
9 |
10 | const NamespaceMySQLx = "mysqlx"
11 |
12 | // UnmarshalPartial parses the wire-format message in b and places the result in m.
13 | // The provided message must be mutable (e.g., a non-nil pointer to a message).
14 | // This is the same function as proto.Unmarshall except that AllowPartial option set to true.
15 | func UnmarshalPartial(b []byte, m proto.Message) error {
16 | return proto.UnmarshalOptions{
17 | RecursionLimit: protowire.DefaultRecursionLimit,
18 | AllowPartial: true,
19 | }.Unmarshal(b, m)
20 | }
21 |
--------------------------------------------------------------------------------
/xmysql/internal/network/read.go:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2023, Geert JM Vanderkelen
2 |
3 | package network
4 |
5 | import (
6 | "context"
7 | "errors"
8 | "fmt"
9 | "io"
10 | "net"
11 | "time"
12 |
13 | "github.com/golistic/pxmysql/mysqlerrors"
14 | )
15 |
16 | // Read reads a message from the network connection conn.
17 | // If no deadline is set in ctx, a default will be used.
18 | func Read(ctx context.Context, conn net.Conn) (*ServerMessage, error) {
19 |
20 | deadline, ok := ctx.Deadline()
21 | if !ok {
22 | deadline = time.Now().Add(10 * time.Second)
23 | }
24 |
25 | if err := conn.SetReadDeadline(deadline); err != nil {
26 | return nil, fmt.Errorf("setting read deadline (%w)", err)
27 | }
28 |
29 | msg, err := readMessage(conn)
30 | if err != nil {
31 | if err == io.EOF {
32 | return nil, io.EOF
33 | }
34 |
35 | var myErr *mysqlerrors.Error
36 | if errors.As(err, &myErr) {
37 | return nil, myErr
38 | }
39 |
40 | return nil, err
41 | }
42 |
43 | Trace("r", msg)
44 |
45 | return msg, nil
46 | }
47 |
--------------------------------------------------------------------------------
/xmysql/internal/network/tls.go:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2023, Geert JM Vanderkelen
2 |
3 | package network
4 |
5 | import (
6 | "crypto/x509"
7 | "fmt"
8 | "os"
9 | "sync"
10 | )
11 |
// ServerCAPool collects the CA certificates added through this package
// (presumably used to verify the server's TLS certificate — confirm at the
// TLS setup). Additions are guarded by muServerCAPool.
var (
	ServerCAPool   = x509.NewCertPool()
	muServerCAPool sync.RWMutex
)

// addServerCACert appends the PEM-encoded certificates in certs to
// ServerCAPool. It fails when no certificate could be parsed from certs.
func addServerCACert(certs []byte) error {
	muServerCAPool.Lock()
	defer muServerCAPool.Unlock()

	if !ServerCAPool.AppendCertsFromPEM(certs) {
		return fmt.Errorf("appending CA certificate to pool")
	}

	return nil
}

// AddServerCACertFromFile reads PEM-encoded CA certificates from filename and
// adds them to ServerCAPool.
func AddServerCACertFromFile(filename string) error {
	pemData, err := os.ReadFile(filename)
	if err != nil {
		return fmt.Errorf("reading server CA certificate (%w)", err)
	}

	return addServerCACert(pemData)
}
39 |
--------------------------------------------------------------------------------
/xmysql/internal/network/trace.go:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2023, Geert JM Vanderkelen
2 |
3 | package network
4 |
5 | import (
6 | "encoding/json"
7 | "fmt"
8 | "os"
9 |
10 | "google.golang.org/protobuf/proto"
11 |
12 | "github.com/golistic/pxmysql/internal/mysqlx/mysqlxnotice"
13 | "github.com/golistic/pxmysql/internal/mysqlx/mysqlxsql"
14 | )
15 |
// Tracing switches, read from the environment when the package initializes.
// Setting PXMYSQL_TRACE (to any value, even empty) enables traceReadWrites;
// TraceValues additionally requires PXMYSQL_TRACE_VALUES and is only true
// when tracing itself is enabled.
var (
	traceReadWrites bool
	TraceValues     bool
)

func init() {
	_, traceReadWrites = os.LookupEnv("PXMYSQL_TRACE")
	_, valuesRequested := os.LookupEnv("PXMYSQL_TRACE_VALUES")
	TraceValues = traceReadWrites && valuesRequested
}
28 |
29 | // Trace is used for debugging and is enabled by setting the PYMYSQL_TRACE
30 | // environment variable.
31 | func Trace(action string, msg any, a ...any) {
32 | if !traceReadWrites || msg == nil {
33 | return
34 | }
35 |
36 | var indicator string
37 |
38 | switch action {
39 | case "w", "write":
40 | indicator = "\n> write:"
41 | case "r", "read":
42 | indicator = "< read"
43 | case "un", "unhandled":
44 | indicator = "< unhandled "
45 | case "error":
46 | indicator = "< ERROR "
47 | case "state":
48 | indicator = "\t< STATE "
49 | default:
50 | indicator = "< unknown"
51 | }
52 |
53 | prefix := "\t"
54 |
55 | var s string
56 | var topic string
57 | switch v := msg.(type) {
58 | case *ServerMessage:
59 | topic = v.ServerMessageType().String()
60 | case *mysqlxnotice.SessionStateChanged:
61 | topic = v.GetParam().String()
62 | doc, err := json.MarshalIndent(v.Value, prefix, " ")
63 | if err != nil {
64 | panic(err)
65 | }
66 | if doc[1] != '}' {
67 | s = fmt.Sprintf(" %s\n", string(doc))
68 | }
69 | case *mysqlxsql.StmtExecute:
70 | s = " SQL Statement: " + string(v.Stmt) + "\n"
71 | case proto.Message:
72 | topic = string(v.ProtoReflect().Descriptor().Name())
73 | doc, err := json.MarshalIndent(v, prefix, " ")
74 | if err != nil {
75 | panic(err)
76 | }
77 | if doc[1] != '}' {
78 | s = fmt.Sprintf(prefix+"%s\n", string(doc))
79 | }
80 | case string:
81 | topic = v
82 | default:
83 | topic = fmt.Sprintf("unhandled %T", msg)
84 | }
85 |
86 | _, err := fmt.Fprintf(os.Stderr, indicator+" "+topic+"\n"+s)
87 | if err != nil {
88 | panic(err)
89 | }
90 |
91 | if len(a) > 0 {
92 | _, err := fmt.Fprintf(os.Stderr, prefix+fmt.Sprint(a...)+"\n")
93 | if err != nil {
94 | panic(err)
95 | }
96 | }
97 | }
98 |
--------------------------------------------------------------------------------
/xmysql/internal/network/write.go:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2023, Geert JM Vanderkelen
2 |
3 | package network
4 |
5 | import (
6 | "context"
7 | "database/sql/driver"
8 | "encoding/binary"
9 | "errors"
10 | "fmt"
11 | "net"
12 | "syscall"
13 |
14 | "google.golang.org/protobuf/proto"
15 |
16 | "github.com/golistic/pxmysql/mysqlerrors"
17 | )
18 |
19 | // Write writes protobuf msg using conn to the server.
20 | func Write(ctx context.Context, conn net.Conn, msg proto.Message, maxAllowedPacket int) error {
21 |
22 | if conn == nil {
23 | return fmt.Errorf("not connected (%w)", driver.ErrBadConn)
24 | }
25 |
26 | msgType, err := clientMessageType(msg)
27 | if err != nil {
28 | return err
29 | }
30 |
31 | deadline, _ := ctx.Deadline()
32 | if err := conn.SetWriteDeadline(deadline); err != nil {
33 | return fmt.Errorf("failed setting write deadline (%w)", err)
34 | }
35 |
36 | b, err := proto.Marshal(msg)
37 | if err != nil {
38 | return fmt.Errorf("failed marshalling protobuf message (%w)", err)
39 | }
40 |
41 | if maxAllowedPacket > 0 && len(b) > maxAllowedPacket {
42 | return mysqlerrors.New(mysqlerrors.ClientNetPacketTooLarge)
43 | }
44 |
45 | var header [5]byte
46 | binary.LittleEndian.PutUint32(header[:], uint32(len(b))+1) // +1 is final \x00
47 |
48 | header[4] = byte(msgType)
49 |
50 | buf := &net.Buffers{header[:], b}
51 | _, err = buf.WriteTo(conn)
52 | switch {
53 | case errors.Is(err, syscall.EPIPE):
54 | return fmt.Errorf("broken pipe when writing (%w)", driver.ErrBadConn)
55 | case err != nil:
56 | return fmt.Errorf("failed sending message (%w)", err)
57 | }
58 |
59 | Trace("w", msg)
60 |
61 | return nil
62 | }
63 |
--------------------------------------------------------------------------------
/xmysql/internal/statements/quote.go:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2023, Geert JM Vanderkelen
2 |
3 | package statements
4 |
5 | import (
6 | "fmt"
7 | "strings"
8 | )
9 |
// stringLiteralEscaper escapes the characters mysql_real_escape_string
// escapes, so the result can be placed between single quotes.
var stringLiteralEscaper = strings.NewReplacer(
	`\`, `\\`,
	`'`, `\'`,
	`"`, `\"`,
	"\x00", `\0`,
	"\n", `\n`,
	"\r", `\r`,
	"\x1a", `\Z`,
)

// QuoteValue quotes p so that it can safely substitute a placeholder within
// a SQL query.
//
// Integers are written in decimal; []byte as a hexadecimal _binary literal;
// float32/float64 in the shortest form that round-trips (NaN and infinities
// are rejected since SQL has no literal for them); strings single-quoted with
// backslash escaping, which assumes the NO_BACKSLASH_ESCAPES SQL mode is not
// enabled. Other types result in an error.
func QuoteValue(p any) (string, error) {
	switch v := p.(type) {
	case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64:
		return fmt.Sprintf("%d", v), nil
	case []byte:
		return fmt.Sprintf("_binary'%x'", v), nil
	case float32, float64:
		// %v uses the shortest exact representation; %f rounded to 6 decimals
		s := fmt.Sprint(v)
		switch s {
		case "NaN", "+Inf", "-Inf":
			return "", fmt.Errorf("cannot quote parameter with value %s", s)
		}
		return s, nil
	case string:
		// backslashes must be escaped too: otherwise a trailing `\` escapes
		// the closing quote and the value breaks out of the literal
		return "'" + stringLiteralEscaper.Replace(v) + "'", nil
	default:
		return "", fmt.Errorf("cannot quote parameter with value type %T", p)
	}
}
27 |
// QuoteIdentifier returns p enclosed in backticks, doubling every backtick
// inside p, for use as a MySQL identifier. The returned error is always nil.
func QuoteIdentifier(p string) (string, error) {
	escaped := strings.ReplaceAll(p, "`", "``")
	return "`" + escaped + "`", nil
}
31 |
--------------------------------------------------------------------------------
/xmysql/internal/statements/quote_test.go:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2023, Geert JM Vanderkelen
2 |
3 | package statements_test
4 |
5 | import (
6 | "strconv"
7 | "testing"
8 |
9 | "github.com/golistic/xgo/xt"
10 |
11 | "github.com/golistic/pxmysql/xmysql/internal/statements"
12 | )
13 |
14 | func TestQuoteValue(t *testing.T) {
15 | t.Run("strings", func(t *testing.T) {
16 | var cases = []struct {
17 | got string
18 | exp string
19 | }{
20 | {
21 | got: "Gopher",
22 | exp: "'Gopher'",
23 | },
24 | {
25 | got: "'Gopher'",
26 | exp: `'\'Gopher\''`,
27 | },
28 | {
29 | got: "'poop'; DROP TABLE gophers",
30 | exp: `'\'poop\'; DROP TABLE gophers'`,
31 | },
32 | {
33 | got: "🐰",
34 | exp: `'🐰'`,
35 | },
36 | }
37 |
38 | for i, c := range cases {
39 | t.Run(strconv.Itoa(i+1), func(t *testing.T) {
40 | got, err := statements.QuoteValue(c.got)
41 | xt.OK(t, err)
42 | xt.Eq(t, c.exp, got)
43 | })
44 | }
45 | })
46 | }
47 |
--------------------------------------------------------------------------------
/xmysql/internal/statements/statement_test.go:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2023, Geert JM Vanderkelen
2 |
3 | package statements_test
4 |
5 | import (
6 | "testing"
7 |
8 | "github.com/golistic/xgo/xt"
9 |
10 | "github.com/golistic/pxmysql/xmysql/internal/statements"
11 | )
12 |
13 | func TestPlaceholderIndexes(t *testing.T) {
14 | var cases = []struct {
15 | stmt string
16 | exp []int
17 | }{
18 | {
19 | stmt: `SELECT ?`,
20 | exp: []int{7},
21 | },
22 | {
23 | stmt: `SELECT ?, '?', "?"`,
24 | exp: []int{7},
25 | },
26 | {
27 | stmt: `SELECT ?, '?', "?", ?`,
28 | exp: []int{7, 20},
29 | },
30 |
31 | {
32 | stmt: `SELECT ?, '?', "?", ?, "'?'", ?`,
33 | exp: []int{7, 20, 30},
34 | },
35 | }
36 |
37 | for _, c := range cases {
38 | t.Run("", func(t *testing.T) {
39 | xt.Eq(t, c.exp, statements.PlaceholderIndexes(statements.Placeholder, c.stmt))
40 | })
41 | }
42 | }
43 |
--------------------------------------------------------------------------------
/xmysql/internal/statements/statements.go:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2023, Geert JM Vanderkelen
2 |
3 | package statements
4 |
5 | import (
6 | "bytes"
7 | "fmt"
8 |
9 | "github.com/golistic/xgo/xmath"
10 | )
11 |
// Placeholder is the rune marking where SubstitutePlaceholders inserts the
// arguments of a statement.
const Placeholder = '?'
13 |
14 | // SubstitutePlaceholders replaces the placeholders within stmt with respective element of args.
15 | func SubstitutePlaceholders(stmt string, args ...any) (string, error) {
16 |
17 | placeholders := PlaceholderIndexes(Placeholder, stmt)
18 | if len(placeholders) != len(args) {
19 | return "", fmt.Errorf("need %d placeholder(s); found %d)", len(args), len(placeholders))
20 | }
21 |
22 | var nextArg int
23 | var buf []byte
24 |
25 | var index int
26 | for _, ph := range placeholders {
27 | buf = append(buf, stmt[index:ph]...)
28 |
29 | arg := args[nextArg]
30 | nextArg++
31 | index = ph + 1
32 |
33 | if arg == nil {
34 | buf = append(buf, "NULL"...)
35 | continue
36 | }
37 |
38 | quoted, err := QuoteValue(arg)
39 | if err != nil {
40 | return "", err
41 | }
42 | buf = append(buf, quoted...)
43 | }
44 |
45 | // rest of stmt
46 | buf = append(buf, stmt[index:]...)
47 |
48 | if len(args) > nextArg {
49 | return "", fmt.Errorf("%d argument(s) not substituted", xmath.AbsInt(len(args)-nextArg))
50 | } else if len(args) < nextArg {
51 | return "", fmt.Errorf("%d placeholder(s) not substituted", xmath.AbsInt(len(args)-nextArg))
52 | }
53 |
54 | return string(buf), nil
55 | }
56 |
// PlaceholderIndexes returns the positions of all occurrences of placeholder
// within query, skipping those inside single- or double-quoted string
// literals. Positions are counted in runes (Unicode code points), not bytes,
// so they equal byte offsets only when the preceding text is ASCII.
//
// NOTE(review): backslash-escaped quotes (e.g. 'it\'s') are not recognised;
// such a quote ends the literal.
func PlaceholderIndexes(placeholder rune, query string) []int {

	var positions []int

	// open is the quote character of the literal currently being skipped,
	// or 0 when outside any literal.
	var open rune
	for pos, r := range bytes.Runes([]byte(query)) {
		switch {
		case open == 0 && (r == '"' || r == '\''):
			open = r // entering a string literal
		case open != 0 && r == open:
			open = 0 // leaving the string literal
		case open == 0 && r == placeholder:
			positions = append(positions, pos)
		}
	}

	return positions
}
89 |
--------------------------------------------------------------------------------
/xmysql/main_test.go:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2022, Geert JM Vanderkelen
2 |
3 | package xmysql_test
4 |
5 | import (
6 | "fmt"
7 | "os"
8 | "runtime/debug"
9 | "strconv"
10 | "testing"
11 |
12 | "github.com/golistic/pxmysql/internal/xxt"
13 | )
14 |
15 | var (
16 | testExitCode int
17 | testErr error
18 | testSchema = "pxmysql_tests"
19 | testContext *xxt.TestContext
20 | )
21 |
22 | var (
23 | testMySQLMaxAllowedPacket = -1 // MySQL's mysqlx_max_allowed_packet
24 | )
25 |
26 | func testTearDown() {
27 | if testErr != nil {
28 | testExitCode = 1
29 | fmt.Println(testErr)
30 | }
31 | }
32 |
33 | func TestMain(m *testing.M) {
34 | defer func() { os.Exit(testExitCode) }()
35 | defer testTearDown()
36 | defer func() {
37 | if r := recover(); r != nil {
38 | fmt.Println(string(debug.Stack()))
39 | os.Exit(1)
40 | }
41 | }()
42 |
43 | var err error
44 | if testContext, testErr = xxt.New(testSchema); err != nil {
45 | return
46 | }
47 |
48 | if err := testContext.Server.LoadSQLScript("base"); err != nil {
49 | testErr = fmt.Errorf("failed testing MySQL running in container %s (%s)",
50 | testContext.Server.Container.Name, err)
51 | return
52 | }
53 |
54 | if err := testContext.Server.Container.CopyFileFromContainer(
55 | "/etc/mysql/conf.d/ca.pem", "_testdata/mysql_ca.pem"); err != nil {
56 | testErr = fmt.Errorf("failed copying MySQL CA certificate from container %s (%s)",
57 | testContext.Server.Container.Name, err)
58 | return
59 | }
60 |
61 | if v, err := testContext.Server.Variable("global", "mysqlx_max_allowed_packet"); err != nil {
62 | testErr = fmt.Errorf("failed getting variable mysqlx_max_allowed_packet (%s)", err)
63 | return
64 | } else {
65 | n, err := strconv.ParseInt(v, 10, 32)
66 | if err != nil {
67 | testErr = fmt.Errorf("failed converting variable mysqlx_max_allowed_packet (%s)", err)
68 | return
69 | }
70 | testMySQLMaxAllowedPacket = int(n)
71 | }
72 |
73 | testExitCode = m.Run()
74 | }
75 |
--------------------------------------------------------------------------------
/xmysql/notice.go:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2022, 2023, Geert JM Vanderkelen
2 |
3 | package xmysql
4 |
5 | import (
6 | "fmt"
7 |
8 | "github.com/golistic/pxmysql/internal/mysqlx/mysqlxnotice"
9 | "github.com/golistic/pxmysql/xmysql/internal/network"
10 | )
11 |
// StateChanges holds the most recent values reported through session state
// change notices: the client ID, the generated insert ID, the number of rows
// affected, the current schema and the message produced by the server.
type StateChanges struct {
	ClientID, GeneratedInsertID, RowsAffected uint64
	CurrentSchema, ProducedMessage            string
}
19 |
// notices collects the notices received from the server; see add.
type notices struct {
	// warnings received in WARNING frames
	warnings []*mysqlxnotice.Warning
	// changes received in SESSION_VARIABLE_CHANGED frames
	sessionVariableChanges []*mysqlxnotice.SessionVariableChanged
	// changes received in SESSION_STATE_CHANGED frames
	sessionStateChanges []*mysqlxnotice.SessionStateChanged
	// changes received in GROUP_REPLICATION_STATE_CHANGED frames
	groupReplicationStateChanges []*mysqlxnotice.GroupReplicationStateChanged
	// most recent SERVER_HELLO notice
	serverHello *mysqlxnotice.ServerHello
	// types of frames add does not handle
	unhandled []mysqlxnotice.Frame_Type
	// latest values taken from session state changes
	stateChanges StateChanges
}
29 |
30 | func (n *notices) add(msg *network.ServerMessage) error {
31 | frame := &mysqlxnotice.Frame{}
32 | if err := msg.Unmarshall(frame); err != nil {
33 | return fmt.Errorf("failed unmarshalling notice message (%w)", err)
34 | }
35 |
36 | switch mysqlxnotice.Frame_Type(frame.GetType()) {
37 | case mysqlxnotice.Frame_WARNING:
38 | m := &mysqlxnotice.Warning{}
39 | if err := msg.Unmarshall(m); err != nil {
40 | return err
41 | }
42 | n.warnings = append(n.warnings, m)
43 | case mysqlxnotice.Frame_SESSION_VARIABLE_CHANGED:
44 | m := &mysqlxnotice.SessionVariableChanged{}
45 | if err := msg.Unmarshall(m); err != nil {
46 | return fmt.Errorf("failed unmarshalling '%s' (%w)", m.String(), err)
47 | }
48 | n.sessionVariableChanges = append(n.sessionVariableChanges, m)
49 | case mysqlxnotice.Frame_SESSION_STATE_CHANGED:
50 | m := &mysqlxnotice.SessionStateChanged{}
51 | if err := network.UnmarshalPartial(frame.Payload, m); err != nil {
52 | return fmt.Errorf("failed unmarshalling '%s' (%w)", m.String(), err)
53 | }
54 | network.Trace("state", m)
55 |
56 | switch m.GetParam() {
57 | case mysqlxnotice.SessionStateChanged_GENERATED_INSERT_ID:
58 | if len(m.Value) > 0 {
59 | n.stateChanges.GeneratedInsertID = m.Value[0].GetVUnsignedInt()
60 | }
61 | case mysqlxnotice.SessionStateChanged_ROWS_AFFECTED:
62 | if len(m.Value) > 0 {
63 | n.stateChanges.RowsAffected = m.Value[0].GetVUnsignedInt()
64 | }
65 | case mysqlxnotice.SessionStateChanged_CURRENT_SCHEMA:
66 | if len(m.Value) > 0 {
67 | n.stateChanges.CurrentSchema = string(m.Value[0].VString.Value)
68 | }
69 | case mysqlxnotice.SessionStateChanged_PRODUCED_MESSAGE:
70 | if len(m.Value) > 0 {
71 | n.stateChanges.ProducedMessage = string(m.Value[0].VString.Value)
72 | }
73 | case mysqlxnotice.SessionStateChanged_CLIENT_ID_ASSIGNED:
74 | if len(m.Value) > 0 {
75 | n.stateChanges.ClientID = m.Value[0].GetVUnsignedInt()
76 | }
77 | }
78 |
79 | n.sessionStateChanges = append(n.sessionStateChanges, m)
80 | case mysqlxnotice.Frame_GROUP_REPLICATION_STATE_CHANGED:
81 | m := &mysqlxnotice.GroupReplicationStateChanged{}
82 | if err := msg.Unmarshall(m); err != nil {
83 | return fmt.Errorf("failed unmarshalling '%s' (%w)", m.String(), err)
84 | }
85 | n.groupReplicationStateChanges = append(n.groupReplicationStateChanges, m)
86 | case mysqlxnotice.Frame_SERVER_HELLO:
87 | m := &mysqlxnotice.ServerHello{}
88 | if err := msg.Unmarshall(m); err != nil {
89 | return fmt.Errorf("failed unmarshalling '%s' (%w)", m.String(), err)
90 | }
91 | n.serverHello = m
92 | default:
93 | n.unhandled = append(n.unhandled, mysqlxnotice.Frame_Type(frame.GetType()))
94 | }
95 |
96 | return nil
97 | }
98 |
--------------------------------------------------------------------------------
/xmysql/prepared.go:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2022, Geert JM Vanderkelen
2 |
3 | package xmysql
4 |
5 | import (
6 | "context"
7 | "database/sql/driver"
8 | "fmt"
9 | "strings"
10 | "time"
11 |
12 | "github.com/golistic/pxmysql/decimal"
13 | "github.com/golistic/pxmysql/internal/mysqlx/mysqlxdatatypes"
14 | "github.com/golistic/pxmysql/internal/mysqlx/mysqlxprepare"
15 | "github.com/golistic/pxmysql/xmysql/xproto"
16 | )
17 |
// Prepared is a prepared statement obtained through a Session. It is run with
// Execute and released with Deallocate.
type Prepared struct {
	// session is used to send Execute and Deallocate requests.
	session *Session
	// result holds the statement ID (result.stmtID) sent with each Execute.
	result *Result
	// numPlaceholders is the value reported by NumPlaceholders.
	numPlaceholders int
}
23 |
24 | // Execute the prepared statements replacing placeholders with args.
25 | func (p *Prepared) Execute(ctx context.Context, args ...any) (*Result, error) {
26 | if p.session == nil || p.result == nil || p.result.stmtID == 0 {
27 | return nil, fmt.Errorf("not initialized")
28 | }
29 |
30 | pArgs := make([]*mysqlxdatatypes.Any, len(args))
31 |
32 | for i, arg := range args {
33 | var err error
34 |
35 | var a any
36 | switch v := arg.(type) {
37 | case driver.NamedValue:
38 | a = v.Value
39 | default:
40 | a = arg
41 | }
42 |
43 | // ridiculous type-switch; preventing using reflection
44 | switch v := a.(type) {
45 | case nil:
46 | pArgs[i] = xproto.Nil()
47 | case bool:
48 | pArgs[i] = xproto.Bool(v)
49 | case *bool:
50 | pArgs[i] = xproto.Bool(*v)
51 | case int:
52 | pArgs[i] = xproto.SignedInt(v)
53 | case int8:
54 | pArgs[i] = xproto.SignedInt(v)
55 | case int16:
56 | pArgs[i] = xproto.SignedInt(v)
57 | case int32:
58 | pArgs[i] = xproto.SignedInt(v)
59 | case int64:
60 | pArgs[i] = xproto.SignedInt(v)
61 | case uint:
62 | pArgs[i] = xproto.UnsignedInt(v)
63 | case uint8:
64 | pArgs[i] = xproto.UnsignedInt(v)
65 | case uint16:
66 | pArgs[i] = xproto.UnsignedInt(v)
67 | case uint32:
68 | pArgs[i] = xproto.UnsignedInt(v)
69 | case uint64:
70 | pArgs[i] = xproto.UnsignedInt(v)
71 | case *int:
72 | pArgs[i] = xproto.SignedInt(*v)
73 | case *int8:
74 | pArgs[i] = xproto.SignedInt(*v)
75 | case *int16:
76 | pArgs[i] = xproto.SignedInt(*v)
77 | case *int32:
78 | pArgs[i] = xproto.SignedInt(*v)
79 | case *int64:
80 | pArgs[i] = xproto.SignedInt(*v)
81 | case *uint:
82 | pArgs[i] = xproto.UnsignedInt(*v)
83 | case *uint8:
84 | pArgs[i] = xproto.UnsignedInt(*v)
85 | case *uint16:
86 | pArgs[i] = xproto.UnsignedInt(*v)
87 | case *uint32:
88 | pArgs[i] = xproto.UnsignedInt(*v)
89 | case *uint64:
90 | pArgs[i] = xproto.UnsignedInt(*v)
91 | case string:
92 | pArgs[i] = xproto.String(v)
93 | case *string:
94 | pArgs[i] = xproto.String(v)
95 | case []byte:
96 | pArgs[i] = xproto.Bytes(v)
97 | case float32:
98 | pArgs[i] = xproto.Float32(v)
99 | case *float32:
100 | pArgs[i] = xproto.Float32(*v)
101 | case float64:
102 | pArgs[i] = xproto.Float64(v)
103 | case *float64:
104 | pArgs[i] = xproto.Float64(*v)
105 | case decimal.Decimal:
106 | pArgs[i] = xproto.Decimal(v)
107 | case *decimal.Decimal:
108 | pArgs[i] = xproto.Decimal(*v)
109 | case time.Time:
110 | if pArgs[i], err = xproto.Time(v, p.session.TimeLocation().String()); err != nil {
111 | return nil, err
112 | }
113 | case *time.Time:
114 | if pArgs[i], err = xproto.Time(*v, p.session.TimeLocation().String()); err != nil {
115 | return nil, err
116 | }
117 | case []string:
118 | pArgs[i] = xproto.String(strings.Join(v, ","))
119 | default:
120 | return nil, fmt.Errorf("argument type '%T' not supported", a)
121 | }
122 | }
123 |
124 | if err := p.session.Write(ctx, &mysqlxprepare.Execute{
125 | StmtId: &p.result.stmtID,
126 | Args: pArgs,
127 | }); err != nil {
128 | return nil, err
129 | }
130 |
131 | res, err := p.session.handleResult(ctx, func(r *Result) bool {
132 | return r.stmtOK
133 | })
134 | if err != nil {
135 | return nil, err
136 | }
137 |
138 | return res, nil
139 | }
140 |
141 | // Deallocate makes this prepared statement not usable any longer.
142 | func (p *Prepared) Deallocate(ctx context.Context) error {
143 | return p.session.DeallocatePrepareStatement(ctx, p.result.stmtID)
144 | }
145 |
146 | // StatementID returns the statement ID.
147 | func (p *Prepared) StatementID() uint32 {
148 | return p.result.stmtID
149 | }
150 |
151 | // NumPlaceholders returns the number of placeholder parameters.
152 | func (p *Prepared) NumPlaceholders() int {
153 | return p.numPlaceholders
154 | }
155 |
--------------------------------------------------------------------------------
/xmysql/result_test.go:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2022, 2023, Geert JM Vanderkelen
2 |
3 | package xmysql_test
4 |
5 | import (
6 | "context"
7 | "fmt"
8 | "testing"
9 | "time"
10 |
11 | "github.com/golistic/xgo/xt"
12 |
13 | "github.com/golistic/pxmysql/internal/xxt"
14 | "github.com/golistic/pxmysql/xmysql"
15 | )
16 |
17 | func TestResult_FetchRow(t *testing.T) {
18 | config := &xmysql.ConnectConfig{
19 | Address: testContext.XPluginAddr,
20 | Username: xxt.UserNative,
21 | }
22 | config.SetPassword(xxt.UserNativePwd)
23 |
24 | tbl := "bulk_fidiEfiS223"
25 |
26 | ses, err := xmysql.GetSession(context.Background(), config)
27 | xt.OK(t, err)
28 | xt.OK(t, ses.SetActiveSchema(context.Background(), testSchema))
29 |
30 | createTable := fmt.Sprintf(
31 | "CREATE TABLE `%s` (id INT AUTO_INCREMENT PRIMARY KEY, c1 VARCHAR(30) NOT NULL)", tbl)
32 |
33 | _, err = ses.ExecuteStatement(context.Background(), fmt.Sprintf("DROP TABLE IF EXISTS `%s`", tbl))
34 | xt.OK(t, err)
35 |
36 | _, err = ses.ExecuteStatement(context.Background(), createTable)
37 | xt.OK(t, err)
38 |
39 | nrRows := 100
40 | for i := 0; i < nrRows; i++ {
41 | _, err = ses.ExecuteStatement(context.Background(),
42 | fmt.Sprintf("INSERT INTO `%s` (c1) VALUES (?)", tbl), fmt.Sprintf("data%d", i+1))
43 | xt.OK(t, err)
44 | }
45 |
46 | t.Run("fetch", func(t *testing.T) {
47 | ses, err := xmysql.GetSession(context.Background(), config)
48 | xt.OK(t, err)
49 | xt.OK(t, ses.SetActiveSchema(context.Background(), testSchema))
50 |
51 | mUse := xxt.NewMemoryUse()
52 | res, err := ses.ExecuteStatement(context.Background(),
53 | fmt.Sprintf("SELECT * FROM `%s` ORDER BY id", tbl))
54 | xt.OK(t, err)
55 | xt.Eq(t, nrRows, len(res.Rows))
56 |
57 | rowCtx, cancel := context.WithTimeout(context.Background(), time.Second)
58 | defer cancel()
59 |
60 | for i := 1; res.Row != nil; i++ {
61 | id := res.Row.Values[0].(int64)
62 | xt.Eq(t, i, id)
63 | xt.Eq(t, fmt.Sprintf("data%d", i), res.Row.Values[1].(string))
64 |
65 | err = res.FetchRow(rowCtx)
66 | xt.OK(t, err)
67 | }
68 | mUse.Stop()
69 |
70 | // keep allocations in check (if nrRows changes, this will obviously go up)
71 | xt.Assert(t, mUse.DiffAlloc() < 35000)
72 | })
73 | }
74 |
--------------------------------------------------------------------------------
/xmysql/schema.go:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2023, Geert JM Vanderkelen
2 |
3 | package xmysql
4 |
5 | import (
6 | "context"
7 | "fmt"
8 | "sort"
9 |
10 | "github.com/golistic/pxmysql/null"
11 | "github.com/golistic/pxmysql/xmysql/collection"
12 | "github.com/golistic/pxmysql/xmysql/xproto"
13 | )
14 |
// ObjectKind is the kind of a database object, as compared against the
// second column of the rows returned by the list_objects command.
type ObjectKind string

const (
	// ObjectCollection is the kind reported for document collections.
	ObjectCollection ObjectKind = "COLLECTION"
)
20 |
// Schema defines the representation of a database schema. It provides
// functionality to access the schema's contents.
type Schema struct {
	// session is the session through which the schema is accessed.
	session *Session
	// name is the schema (database) name.
	name string
}
27 |
28 | // newSchema instantiates a new Schema object using session. If name is the
29 | // empty string, the current schema of session will be used.
30 | func newSchema(session *Session, name string) (*Schema, error) {
31 | if session == nil {
32 | return nil, fmt.Errorf("session closed")
33 | }
34 |
35 | return &Schema{
36 | session: session,
37 | name: name,
38 | }, nil
39 | }
40 |
41 | func (s *Schema) String() string {
42 | return fmt.Sprintf("", s.name, s.session)
43 | }
44 |
45 | // Name returns the schema or database name.
46 | func (s *Schema) Name() string {
47 | return s.name
48 | }
49 |
50 | // GetSession returns the underlying session of s.
51 | func (s *Schema) GetSession() *Session {
52 | return s.session
53 | }
54 |
55 | // GetCollection retrieve the collection using its name.
56 | // To keep compatible with behavior seen it MySQL connectors, when the collection
57 | // does not exist, by default, no error is returned. To return ErrNotAvailable instead,
58 | // use the functional option collection.GetValidateExistence.
59 | func (s *Schema) GetCollection(ctx context.Context, name string, options ...collection.GetOption) (*Collection, error) {
60 |
61 | c, err := newCollection(s, name)
62 | if err != nil {
63 | return nil, fmt.Errorf("getting collection (%w)", err)
64 | }
65 |
66 | opts := collection.NewGetOptions(options...)
67 | if opts.ValidateExistence {
68 | if err := c.CheckExistence(ctx); err != nil {
69 | return nil, err
70 | }
71 | }
72 |
73 | return c, nil
74 | }
75 |
76 | // GetCollections retrieve all available collections (does not include views or tables).
77 | func (s *Schema) GetCollections(ctx context.Context) ([]*Collection, error) {
78 |
79 | names, err := s.objectNames(ctx, ObjectCollection)
80 | if err != nil {
81 | return nil, fmt.Errorf("getting collections (%w)", err)
82 | }
83 |
84 | if len(names) == 0 {
85 | return nil, nil
86 | }
87 |
88 | collections := make([]*Collection, len(names))
89 | for i, name := range names {
90 | collections[i], err = newCollection(s, name)
91 | if err != nil {
92 | return nil, fmt.Errorf("getting collections (%w)", err)
93 | }
94 | }
95 |
96 | return collections, nil
97 | }
98 |
99 | // CreateCollection creates a new collection. If the functional option
100 | // collection.CreateReuseExisting is used, no error is reported when collection
101 | // already exists.
102 | func (s *Schema) CreateCollection(ctx context.Context, name string,
103 | options ...collection.CreateOption) (*Collection, error) {
104 |
105 | c, err := newCollection(s, name)
106 | if err != nil {
107 | return nil, fmt.Errorf("creating collection (%w)", err)
108 | }
109 |
110 | opts := collection.NewCreateOptions(options...)
111 |
112 | args := xproto.CommandArgs(
113 | xproto.ObjectField("schema", s.name),
114 | xproto.ObjectField("name", name),
115 | xproto.ObjectField("options", xproto.ObjectFields{
116 | xproto.ObjectField("reuse_existing", opts.ReuseExisting),
117 | }),
118 | )
119 |
120 | _, err = s.session.ExecCommand(ctx, "create_collection", args)
121 | if err != nil {
122 | return nil, fmt.Errorf("creating collection (%w)", err)
123 | }
124 |
125 | return c, nil
126 | }
127 |
128 | // DropCollection drops the collection.
129 | func (s *Schema) DropCollection(ctx context.Context, name string) error {
130 |
131 | args := xproto.CommandArgs(
132 | xproto.ObjectField("schema", s.name),
133 | xproto.ObjectField("name", name),
134 | )
135 |
136 | _, err := s.session.ExecCommand(ctx, "drop_collection", args)
137 | if err != nil {
138 | return fmt.Errorf("dropping collection (%w)", err)
139 | }
140 |
141 | return nil
142 | }
143 |
144 | func (s *Schema) objectNames(ctx context.Context, kind ObjectKind) ([]string, error) {
145 |
146 | args := xproto.CommandArgs(
147 | xproto.ObjectField("schema", s.name),
148 | )
149 |
150 | if err := s.session.Write(ctx, xproto.Command("list_objects", args)); err != nil {
151 | return nil, err
152 | }
153 |
154 | res, err := s.session.handleResult(ctx, func(r *Result) bool {
155 | return r.stmtOK
156 | })
157 | if err != nil {
158 | return nil, err
159 | }
160 |
161 | if len(res.Rows) == 0 {
162 | return nil, nil
163 | }
164 |
165 | var names []string
166 | for _, row := range res.Rows {
167 | objType, ok := row.Values[1].(string)
168 | if !ok || objType != string(kind) {
169 | continue
170 | }
171 |
172 | name, ok := row.Values[0].(null.String)
173 | if !ok || !name.Valid {
174 | continue
175 | }
176 |
177 | names = append(names, name.String)
178 | }
179 |
180 | if len(names) == 0 {
181 | return nil, nil
182 | }
183 |
184 | sort.Strings(names)
185 |
186 | return names, nil
187 | }
188 |
--------------------------------------------------------------------------------
/xmysql/schema_test.go:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2023, Geert JM Vanderkelen
2 |
3 | package xmysql_test
4 |
5 | import (
6 | "context"
7 | "errors"
8 | "fmt"
9 | "sort"
10 | "testing"
11 |
12 | "github.com/golistic/xgo/xstrings"
13 | "github.com/golistic/xgo/xt"
14 |
15 | "github.com/golistic/pxmysql/internal/xxt"
16 | "github.com/golistic/pxmysql/xmysql"
17 | "github.com/golistic/pxmysql/xmysql/collection"
18 | )
19 |
20 | func TestSchema_GetSession(t *testing.T) {
21 |
22 | t.Run("session of schema returned", func(t *testing.T) {
23 |
24 | config := &xmysql.ConnectConfig{
25 | Address: testContext.XPluginAddr,
26 | Username: xxt.UserNative,
27 | }
28 | config.SetPassword(xxt.UserNativePwd)
29 |
30 | ctx := context.Background()
31 |
32 | exp, err := xmysql.GetSession(ctx, config)
33 | xt.OK(t, err)
34 |
35 | for i := 0; i < 10; i++ {
36 | schema, err := exp.GetSchema(ctx)
37 | xt.OK(t, err)
38 |
39 | got := schema.GetSession()
40 | xt.Assert(t, got != nil, "expected not nil")
41 | xt.Eq(t, exp, got)
42 | }
43 | })
44 |
45 | t.Run("no session returns nil", func(t *testing.T) {
46 |
47 | xt.Eq(t, nil, (&xmysql.Schema{}).GetSession())
48 | })
49 | }
50 |
51 | func TestSchema_GetCollections(t *testing.T) {
52 |
53 | config := &xmysql.ConnectConfig{
54 | Address: testContext.XPluginAddr,
55 | Username: xxt.UserNative,
56 | }
57 | config.SetPassword(xxt.UserNativePwd)
58 |
59 | t.Run("all collections", func(t *testing.T) {
60 |
61 | xt.OK(t, testContext.Server.LoadSQLScript("schema_collections"))
62 |
63 | ses, err := xmysql.GetSession(context.Background(), config)
64 | xt.OK(t, err)
65 |
66 | exp := []string{"collection_wic28skwixkd", "collection_weux73293jsnsj"}
67 | sort.Strings(exp)
68 |
69 | schema, err := ses.GetSchemaWithName(context.Background(), "pxmysql_tests")
70 | xt.OK(t, err)
71 |
72 | collections, err := schema.GetCollections(context.Background())
73 | xt.OK(t, err)
74 |
75 | xt.Assert(t, len(collections) >= len(exp), fmt.Sprintf("expected at least %d", len(exp)))
76 |
77 | var got []string
78 | for _, s := range collections {
79 | got = append(got, s.Name())
80 | }
81 | sort.Strings(got)
82 |
83 | xt.Assert(t, func(exp, got []string) bool {
84 | if len(exp) > len(got) {
85 | return false
86 | }
87 | for _, l := range exp {
88 | if !xstrings.SliceHas(got, l) {
89 | return false
90 | }
91 | }
92 | return true
93 | }(exp, got))
94 | })
95 | }
96 |
97 | func TestSchema_CreateCollection(t *testing.T) {
98 |
99 | config := &xmysql.ConnectConfig{
100 | Address: testContext.XPluginAddr,
101 | Username: xxt.UserNative,
102 | Schema: "pxmysql_tests",
103 | }
104 | config.SetPassword(xxt.UserNativePwd)
105 |
106 | ses, err := xmysql.GetSession(context.Background(), config)
107 | xt.OK(t, err)
108 |
109 | schema, err := ses.GetSchema(context.Background())
110 | xt.OK(t, err)
111 |
112 | t.Run("name is required", func(t *testing.T) {
113 |
114 | ctx := context.Background()
115 | _, err := schema.CreateCollection(ctx, "")
116 | xt.KO(t, err)
117 | xt.Eq(t, "creating collection (invalid name)", err.Error())
118 | })
119 |
120 | t.Run("create and drop", func(t *testing.T) {
121 |
122 | ctx := context.Background()
123 | name := "ciwejkuwmidi2938x"
124 | c, err := schema.CreateCollection(ctx, name)
125 | xt.OK(t, err)
126 | xt.Eq(t, name, c.Name())
127 |
128 | t.Run("check existence", func(t *testing.T) {
129 | c, err := schema.GetCollection(ctx, name, collection.GetValidateExistence())
130 | xt.OK(t, err)
131 | xt.Eq(t, name, c.Name())
132 |
133 | t.Run("drop", func(t *testing.T) {
134 | err := schema.DropCollection(ctx, name)
135 | xt.OK(t, err)
136 |
137 | err = c.CheckExistence(ctx)
138 | xt.KO(t, err)
139 | xt.Assert(t, errors.Is(err, xmysql.ErrNotAvailable))
140 | })
141 | })
142 | })
143 |
144 | t.Run("reuse existing", func(t *testing.T) {
145 |
146 | ctx := context.Background()
147 |
148 | name := "eovwo28373"
149 | _, err := schema.CreateCollection(ctx, name)
150 | xt.OK(t, err)
151 |
152 | _, err = schema.CreateCollection(ctx, name)
153 | xt.KO(t, err)
154 | xt.Eq(t, "table 'eovwo28373' already exists [1050:42S01]", errors.Unwrap(err).Error())
155 |
156 | _, err = schema.CreateCollection(ctx, name, collection.CreateReuseExisting())
157 | xt.OK(t, err)
158 | })
159 | }
160 |
--------------------------------------------------------------------------------
/xmysql/xproto/command.go:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2023, Geert JM Vanderkelen
2 |
3 | package xproto
4 |
5 | import (
6 | "github.com/golistic/xgo/xstrings"
7 |
8 | "github.com/golistic/pxmysql/internal/mysqlx/mysqlxdatatypes"
9 | "github.com/golistic/pxmysql/internal/mysqlx/mysqlxsql"
10 | "github.com/golistic/pxmysql/xmysql/internal/network"
11 | )
12 |
13 | func Command(command string, args *mysqlxdatatypes.Any) *mysqlxsql.StmtExecute {
14 | return &mysqlxsql.StmtExecute{
15 | Namespace: xstrings.Pointer(network.NamespaceMySQLx),
16 | Stmt: []byte(command),
17 | Args: []*mysqlxdatatypes.Any{args},
18 | }
19 | }
20 |
21 | func CommandArgs(fields ...*mysqlxdatatypes.Object_ObjectField) *mysqlxdatatypes.Any {
22 | return &mysqlxdatatypes.Any{
23 | Type: mysqlxdatatypes.Any_OBJECT.Enum(),
24 | Obj: &mysqlxdatatypes.Object{
25 | Fld: fields,
26 | },
27 | }
28 | }
29 |
--------------------------------------------------------------------------------
/xmysql/xproto/expr.go:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2023, Geert JM Vanderkelen
2 |
3 | package xproto
4 |
5 | import (
6 | "reflect"
7 | "strings"
8 |
9 | "github.com/golistic/xgo/xstrings"
10 |
11 | "github.com/golistic/pxmysql/internal/mysqlx/mysqlxexpr"
12 | )
13 |
14 | func Expr[T reflect.Value | any](value T) *mysqlxexpr.Expr {
15 |
16 | var rv reflect.Value
17 |
18 | switch v := any(value).(type) {
19 | case nil:
20 | return nil
21 | case reflect.Value:
22 | rv = v
23 | default:
24 | rv = reflect.ValueOf(v)
25 | }
26 |
27 | switch reflect.Indirect(rv).Kind() {
28 | case reflect.Slice:
29 | return &mysqlxexpr.Expr{
30 | Type: mysqlxexpr.Expr_ARRAY.Enum(),
31 | Array: sliceExpr(rv),
32 | }
33 | case reflect.Struct:
34 | return &mysqlxexpr.Expr{
35 | Type: mysqlxexpr.Expr_OBJECT.Enum(),
36 | Object: StructExpr(rv.Interface()),
37 | }
38 | default:
39 | return &mysqlxexpr.Expr{
40 | Type: mysqlxexpr.Expr_LITERAL.Enum(),
41 | Literal: Scalar(rv),
42 | }
43 | }
44 | }
45 |
46 | func StructExpr(object any) *mysqlxexpr.Object {
47 |
48 | rv := reflect.Indirect(reflect.ValueOf(object))
49 |
50 | rt := reflect.TypeOf(object)
51 | if rt.Kind() == reflect.Pointer {
52 | rt = reflect.TypeOf(object).Elem()
53 | }
54 |
55 | obj := &mysqlxexpr.Object{}
56 |
57 | for i := 0; i < rv.NumField(); i++ {
58 | rvf := rv.Field(i)
59 | if !rvf.CanInterface() {
60 | continue
61 | }
62 | rtf := rt.Field(i)
63 |
64 | name := rtf.Name
65 | tag := rtf.Tag.Get("json")
66 | if tag != "" {
67 | if tag == "-" || (strings.HasSuffix(tag, ",omitempty") && rvf.IsZero()) {
68 | continue
69 | }
70 | name = strings.Replace(tag, ",omitempty", "", -1)
71 | }
72 |
73 | obj.Fld = append(obj.Fld, &mysqlxexpr.Object_ObjectField{
74 | Key: xstrings.Pointer(name),
75 | Value: Expr(rvf),
76 | })
77 | }
78 |
79 | return obj
80 | }
81 |
82 | func sliceExpr(value reflect.Value) *mysqlxexpr.Array {
83 | array := &mysqlxexpr.Array{
84 | Value: make([]*mysqlxexpr.Expr, value.Len()),
85 | }
86 |
87 | for i := 0; i < value.Len(); i++ {
88 | array.Value[i] = Expr(value.Index(i))
89 | }
90 |
91 | return array
92 | }
93 |
--------------------------------------------------------------------------------
/xmysql/xproto/expr_test.go:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2023, Geert JM Vanderkelen
2 |
3 | package xproto
4 |
5 | import (
6 | "testing"
7 |
8 | "github.com/golistic/xgo/xt"
9 | )
10 |
// Department is a test fixture used as a nested object within Person.
type Department struct {
	Name string `json:"name"`
}

// Person is a test fixture for Expr; its json tags provide the object keys
// and Department is tagged omitempty.
type Person struct {
	Name string `json:"name"`
	Age int `json:"age"`
	Department Department `json:"department,omitempty"`
}
20 |
21 | func TestExpr(t *testing.T) {
22 | t.Run("object with nested object", func(t *testing.T) {
23 | expr := Expr(&Person{Name: "Alice", Age: 36, Department: Department{Name: "Engineering"}})
24 | got := expr.Object
25 |
26 | xt.Eq(t, "name", *got.Fld[0].Key)
27 | xt.Eq(t, "Alice", string(got.Fld[0].Value.Literal.VString.Value))
28 | xt.Eq(t, "age", *got.Fld[1].Key)
29 | xt.Eq(t, int64(36), *got.Fld[1].Value.Literal.VSignedInt)
30 | xt.Eq(t, "department", *got.Fld[2].Key)
31 | xt.Eq(t, "Engineering", string(got.Fld[2].Value.Object.Fld[0].Value.Literal.VString.Value))
32 | })
33 |
34 | t.Run("array of objects", func(t *testing.T) {
35 | exp := []*Person{
36 | {Name: "Alice", Age: 36},
37 | {Name: "Bob", Age: 34},
38 | }
39 |
40 | expr := Expr(exp)
41 | array := expr.Array
42 | xt.Eq(t, 2, len(array.Value))
43 |
44 | for i, v := range array.Value {
45 | got := v.Object.Fld[0]
46 | xt.Eq(t, "name", *got.Key)
47 | xt.Eq(t, exp[i].Name, got.Value.Literal.VString.Value)
48 | }
49 | })
50 | }
51 |
--------------------------------------------------------------------------------
/xmysql/xproto/fields.go:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2023, Geert JM Vanderkelen
2 |
3 | package xproto
4 |
5 | import (
6 | "github.com/golistic/xgo/xstrings"
7 |
8 | "github.com/golistic/pxmysql/internal/mysqlx/mysqlxdatatypes"
9 | )
10 |
// ObjectFields is a list of object fields; ObjectField accepts it as the value
// of a nested object.
type ObjectFields = []*mysqlxdatatypes.Object_ObjectField
12 |
13 | func ObjectField[T ~string | ~bool | ~[]*mysqlxdatatypes.Object_ObjectField](key string, value T) *mysqlxdatatypes.Object_ObjectField {
14 | f := &mysqlxdatatypes.Object_ObjectField{
15 | Key: xstrings.Pointer(key),
16 | }
17 |
18 | switch v := any(value).(type) {
19 | case string:
20 | f.Value = String(v)
21 | case bool:
22 | f.Value = Bool(v)
23 | case []*mysqlxdatatypes.Object_ObjectField:
24 | f.Value = &mysqlxdatatypes.Any{
25 | Type: mysqlxdatatypes.Any_OBJECT.Enum(),
26 | Obj: &mysqlxdatatypes.Object{
27 | Fld: v,
28 | },
29 | }
30 | default:
31 | panic("unsupported value type")
32 | }
33 |
34 | return f
35 | }
36 |
--------------------------------------------------------------------------------
/xmysql/xproto/scalar.go:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2023, Geert JM Vanderkelen
2 |
3 | package xproto
4 |
5 | import (
6 | "fmt"
7 | "reflect"
8 | "time"
9 |
10 | "golang.org/x/exp/constraints"
11 | "google.golang.org/protobuf/proto"
12 |
13 | "github.com/golistic/pxmysql/decimal"
14 | "github.com/golistic/pxmysql/internal/mysqlx/mysqlxdatatypes"
15 | )
16 |
17 | func Scalar[T reflect.Value | any](value T) *mysqlxdatatypes.Scalar {
18 |
19 | var rv reflect.Value
20 |
21 | switch v := any(value).(type) {
22 | case nil:
23 | return NilScalar()
24 | case reflect.Value:
25 | rv = v
26 | default:
27 | rv = reflect.ValueOf(v)
28 | }
29 |
30 | switch {
31 | case rv.Kind() == reflect.Slice:
32 | switch {
33 | case rv.Type().Elem().Kind() == reflect.Uint8 && rv.Len() == 0: // empty []byte
34 | return NilScalar()
35 | case rv.Type().Elem().Kind() == reflect.Uint8: // []byte
36 | return BytesScalar(rv.Bytes())
37 | default:
38 | panic("unsupported scalar slice")
39 | }
40 | case rv.Kind() == reflect.Pointer && rv.IsNil():
41 | return NilScalar()
42 | }
43 |
44 | rv = reflect.Indirect(rv)
45 |
46 | switch rv.Kind() {
47 | case reflect.Bool:
48 | return BoolScalar(rv.Bool())
49 | case reflect.String:
50 | return StringScalar(rv.String())
51 | case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
52 | return UnsignedIntScalar(rv.Uint())
53 | case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
54 | return SignedIntScalar(rv.Int())
55 | case reflect.Float32:
56 | return Float32Scalar(float32(rv.Float()))
57 | case reflect.Float64:
58 | return Float64Scalar(rv.Float())
59 | default:
60 | panic(fmt.Sprintf("unsupported scalar value; was %s", rv.Kind()))
61 | }
62 | }
63 |
64 | func Bool(value bool) *mysqlxdatatypes.Any {
65 | return &mysqlxdatatypes.Any{
66 | Type: mysqlxdatatypes.Any_SCALAR.Enum(),
67 | Scalar: BoolScalar(value),
68 | }
69 | }
70 |
71 | func BoolScalar(value bool) *mysqlxdatatypes.Scalar {
72 | return &mysqlxdatatypes.Scalar{
73 | Type: mysqlxdatatypes.Scalar_V_BOOL.Enum(),
74 | VBool: proto.Bool(value),
75 | }
76 | }
77 |
78 | func Nil() *mysqlxdatatypes.Any {
79 | return &mysqlxdatatypes.Any{
80 | Type: mysqlxdatatypes.Any_SCALAR.Enum(),
81 | Scalar: NilScalar(),
82 | }
83 | }
84 |
85 | func NilScalar() *mysqlxdatatypes.Scalar {
86 | return &mysqlxdatatypes.Scalar{
87 | Type: mysqlxdatatypes.Scalar_V_NULL.Enum(),
88 | }
89 | }
90 |
91 | func SignedInt[T constraints.Signed](value T) *mysqlxdatatypes.Any {
92 | return &mysqlxdatatypes.Any{
93 | Type: mysqlxdatatypes.Any_SCALAR.Enum(),
94 | Scalar: SignedIntScalar(value),
95 | }
96 | }
97 |
98 | func SignedIntScalar[T constraints.Signed](value T) *mysqlxdatatypes.Scalar {
99 | return &mysqlxdatatypes.Scalar{
100 | Type: mysqlxdatatypes.Scalar_V_SINT.Enum(),
101 | VSignedInt: proto.Int64(int64(value)),
102 | }
103 | }
104 |
105 | func UnsignedInt[T constraints.Unsigned](value T) *mysqlxdatatypes.Any {
106 | return &mysqlxdatatypes.Any{
107 | Type: mysqlxdatatypes.Any_SCALAR.Enum(),
108 | Scalar: UnsignedIntScalar(value),
109 | }
110 | }
111 |
112 | func UnsignedIntScalar[T constraints.Unsigned](value T) *mysqlxdatatypes.Scalar {
113 | return &mysqlxdatatypes.Scalar{
114 | Type: mysqlxdatatypes.Scalar_V_UINT.Enum(),
115 | VUnsignedInt: proto.Uint64(uint64(value)),
116 | }
117 | }
118 |
119 | func String(value any) *mysqlxdatatypes.Any {
120 | var v string
121 | if value != nil {
122 | switch sv := value.(type) {
123 | case string:
124 | v = sv
125 | case *string:
126 | if sv != nil {
127 | v = *sv
128 | } else {
129 | return Nil()
130 | }
131 | default:
132 | panic(fmt.Sprintf("String accepts string or *string; not %T", value))
133 | }
134 | } else {
135 | return Nil()
136 | }
137 |
138 | return &mysqlxdatatypes.Any{
139 | Type: mysqlxdatatypes.Any_SCALAR.Enum(),
140 | Scalar: StringScalar(v),
141 | }
142 | }
143 |
144 | func StringScalar[T ~string](value T) *mysqlxdatatypes.Scalar {
145 | return &mysqlxdatatypes.Scalar{
146 | Type: mysqlxdatatypes.Scalar_V_STRING.Enum(),
147 | VString: &mysqlxdatatypes.Scalar_String{
148 | Value: []byte(value),
149 | },
150 | }
151 | }
152 |
153 | func Bytes[T ~[]byte](value T) *mysqlxdatatypes.Any {
154 | v := []byte(value)
155 | return &mysqlxdatatypes.Any{
156 | Type: mysqlxdatatypes.Any_SCALAR.Enum(),
157 | Scalar: BytesScalar(v),
158 | }
159 | }
160 |
161 | func BytesScalar[T ~[]byte](value T) *mysqlxdatatypes.Scalar {
162 | return &mysqlxdatatypes.Scalar{
163 | Type: mysqlxdatatypes.Scalar_V_OCTETS.Enum(),
164 | VOctets: &mysqlxdatatypes.Scalar_Octets{
165 | Value: value,
166 | },
167 | }
168 | }
169 |
170 | func Float32[T ~float32](value T) *mysqlxdatatypes.Any {
171 | v := float32(value)
172 | return &mysqlxdatatypes.Any{
173 | Type: mysqlxdatatypes.Any_SCALAR.Enum(),
174 | Scalar: Float32Scalar(v),
175 | }
176 | }
177 |
178 | func Float32Scalar[T ~float32](value T) *mysqlxdatatypes.Scalar {
179 | return &mysqlxdatatypes.Scalar{
180 | Type: mysqlxdatatypes.Scalar_V_FLOAT.Enum(),
181 | VFloat: proto.Float32(float32(value)),
182 | }
183 | }
184 |
185 | func Float64[T ~float64](value T) *mysqlxdatatypes.Any {
186 | v := float64(value)
187 | return &mysqlxdatatypes.Any{
188 | Type: mysqlxdatatypes.Any_SCALAR.Enum(),
189 | Scalar: Float64Scalar(v),
190 | }
191 | }
192 |
193 | func Float64Scalar[T ~float64](value T) *mysqlxdatatypes.Scalar {
194 | return &mysqlxdatatypes.Scalar{
195 | Type: mysqlxdatatypes.Scalar_V_DOUBLE.Enum(),
196 | VDouble: proto.Float64(float64(value)),
197 | }
198 | }
199 |
200 | func Decimal(value decimal.Decimal) *mysqlxdatatypes.Any {
201 | // MySQL X Protocol does not support sending the encoded BCD.
202 | return String(value.String())
203 | }
204 |
205 | func Time(value time.Time, timeZoneName string) (*mysqlxdatatypes.Any, error) {
206 | tz, err := time.LoadLocation(timeZoneName)
207 | if err != nil {
208 | return nil, err
209 | }
210 | return String(value.In(tz).Format("2006-01-02 15:04:05.999999")), nil
211 | }
212 |
--------------------------------------------------------------------------------
/xmysql/xproto/scalar_test.go:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2023, Geert JM Vanderkelen
2 |
3 | package xproto_test
4 |
5 | import (
6 | "fmt"
7 | "testing"
8 |
9 | "github.com/golistic/xgo/xstrings"
10 | "github.com/golistic/xgo/xt"
11 |
12 | "github.com/golistic/pxmysql/internal/mysqlx/mysqlxdatatypes"
13 | "github.com/golistic/pxmysql/xmysql/xproto"
14 | )
15 |
16 | func TestScalar(t *testing.T) {
17 | t.Run("basic types", func(t *testing.T) {
18 | var nilString *string
19 |
20 | var cases = []struct {
21 | have any
22 | exp *mysqlxdatatypes.Scalar
23 | }{
24 | {
25 | have: "gopher",
26 | exp: &mysqlxdatatypes.Scalar{
27 | Type: mysqlxdatatypes.Scalar_V_STRING.Enum(),
28 | VString: &mysqlxdatatypes.Scalar_String{
29 | Value: []byte("gopher"),
30 | },
31 | },
32 | },
33 | {
34 | have: "",
35 | exp: &mysqlxdatatypes.Scalar{
36 | Type: mysqlxdatatypes.Scalar_V_STRING.Enum(),
37 | VString: &mysqlxdatatypes.Scalar_String{
38 | Value: []byte(""),
39 | },
40 | },
41 | },
42 | {
43 | have: xstrings.Pointer("gopher"),
44 | exp: &mysqlxdatatypes.Scalar{
45 | Type: mysqlxdatatypes.Scalar_V_STRING.Enum(),
46 | VString: &mysqlxdatatypes.Scalar_String{
47 | Value: []byte("gopher"),
48 | },
49 | },
50 | },
51 | {
52 | have: nilString,
53 | exp: &mysqlxdatatypes.Scalar{
54 | Type: mysqlxdatatypes.Scalar_V_NULL.Enum(),
55 | },
56 | },
57 | {
58 | have: nil,
59 | exp: &mysqlxdatatypes.Scalar{
60 | Type: mysqlxdatatypes.Scalar_V_NULL.Enum(),
61 | },
62 | },
63 | {
64 | have: []byte("gopher"),
65 | exp: &mysqlxdatatypes.Scalar{
66 | Type: mysqlxdatatypes.Scalar_V_OCTETS.Enum(),
67 | VOctets: &mysqlxdatatypes.Scalar_Octets{
68 | Value: []byte("gopher"),
69 | },
70 | },
71 | },
72 | {
73 | have: []byte{},
74 | exp: &mysqlxdatatypes.Scalar{
75 | Type: mysqlxdatatypes.Scalar_V_NULL.Enum(),
76 | },
77 | },
78 | }
79 |
80 | for _, c := range cases {
81 | t.Run(fmt.Sprintf("%T", c.have), func(t *testing.T) {
82 | got := xproto.Scalar(c.have)
83 | xt.Eq(t, c.exp, got)
84 | })
85 | }
86 | })
87 |
88 | }
89 |
--------------------------------------------------------------------------------