├── .github └── workflows │ ├── integrations.yml │ └── unit_test.yml ├── .gitignore ├── LICENSE ├── README.md ├── build.zig ├── integration_tests ├── config.zig ├── conn.zig └── main.zig └── src ├── auth.zig ├── config.zig ├── conn.zig ├── constants.zig ├── conversion.zig ├── myzql.zig ├── pool.zig ├── protocol.zig ├── protocol ├── auth_switch_request.zig ├── column_definition.zig ├── generic_response.zig ├── handshake_response.zig ├── handshake_v10.zig ├── packet.zig ├── packet_reader.zig ├── packet_writer.zig ├── prepared_statements.zig ├── text_command.zig └── utils.zig ├── result.zig ├── result_meta.zig ├── temporal.zig └── utils.zig /.github/workflows/integrations.yml: -------------------------------------------------------------------------------- 1 | name: Integration Tests 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | pull_request: 8 | branches: 9 | - main 10 | 11 | jobs: 12 | build: 13 | runs-on: ubuntu-latest 14 | steps: 15 | - name: run mysql server 16 | run: | 17 | docker run --name mysql_server --env MYSQL_ROOT_PASSWORD=password -p 3306:3306 -d mysql 18 | 19 | - uses: actions/checkout@v3 20 | 21 | - name: install zig 22 | run: | 23 | ZIG_VERSION=0.14.0 24 | wget https://ziglang.org/builds/zig-linux-x86_64-$ZIG_VERSION.tar.xz 25 | tar xf zig-linux-x86_64-$ZIG_VERSION.tar.xz 26 | mv zig-linux-x86_64-$ZIG_VERSION $HOME/zig-build 27 | 28 | - name: Run Integration Tests - MySQL 29 | run: | 30 | $HOME/zig-build/zig build 31 | while ! wget -qO- localhost:3306; do sleep 1; docker ps -a; done 32 | $HOME/zig-build/zig build integration_test --summary all 33 | rm -rf zig-cache/ zig-out/ 34 | 35 | # - name: Auth Test 36 | # run: | 37 | # # mysql_native_password 38 | # docker rm -f mysql_server 39 | # docker run --name mysql_server --env MYSQL_ROOT_PASSWORD=password -p 3306:3306 -d mysql \ 40 | # --default-authentication-plugin=mysql_native_password 41 | # while ! wget -qO- localhost:3306; do sleep 1; docker ps -a; docker logs mysql_server; done 42 | # $HOME/zig-build/zig build integration_test -Dtest-filter="ping" --summary all 43 | # rm -rf zig-cache/ zig-out/ 44 | 45 | # # sha2_password 46 | # docker rm -f mysql_server 47 | # docker run --name mysql_server --env MYSQL_ROOT_PASSWORD=password -p 3306:3306 -d mysql \ 48 | # --default-authentication-plugin=sha256_password 49 | # while ! wget -qO- localhost:3306; do sleep 1; docker ps -a; done 50 | # $HOME/zig-build/zig build integration_test -Dtest-filter="ping" --summary all 51 | # rm -rf zig-cache/ zig-out/ 52 | 53 | - name: Run Integration Tests - MariaDB 54 | run: | 55 | docker rm -f mysql_server 56 | docker run --name some-mariadb --env MARIADB_ROOT_PASSWORD=password -p 3306:3306 -d mariadb 57 | while ! wget -qO- localhost:3306; do sleep 1; docker ps -a ; done 58 | $HOME/zig-build/zig build integration_test --summary all 59 | rm -rf zig-cache/ zig-out/ 60 | -------------------------------------------------------------------------------- /.github/workflows/unit_test.yml: -------------------------------------------------------------------------------- 1 | name: Unit Tests 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | pull_request: 8 | branches: 9 | - main 10 | 11 | jobs: 12 | build: 13 | runs-on: ubuntu-latest 14 | steps: 15 | - uses: actions/checkout@v3 16 | - name: install zig 17 | run: | 18 | ZIG_VERSION=0.14.0 19 | wget https://ziglang.org/builds/zig-linux-x86_64-$ZIG_VERSION.tar.xz 20 | tar xf zig-linux-x86_64-$ZIG_VERSION.tar.xz 21 | mv zig-linux-x86_64-$ZIG_VERSION $HOME/zig-build 22 | 23 | - name: Run unit tests 24 | run: | 25 | $HOME/zig-build/zig build unit_test 26 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | zig-cache/ 2 | .zig-cache/ 3 | zig-out/ 4 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Zack 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # MyZql 2 | - MySQL and MariaDB driver in native zig 3 | 4 | ## Status 5 | - Beta 6 | 7 | ## Version Compatibility 8 | | MyZQL | Zig | 9 | |-------------|---------------------------| 10 | | 0.0.9.1 | 0.12.0 | 11 | | 0.13.2 | 0.13.0 | 12 | | 0.14.0 | 0.14.0 | 13 | | main | 0.14.0 | 14 | 15 | ## Features 16 | - Native Zig code, no external dependencies 17 | - TCP protocol 18 | - Prepared Statement 19 | - Structs from query result 20 | - Data insertion 21 | - MySQL DateTime and Time support 22 | 23 | ## Requirements 24 | - MySQL/MariaDB 5.7.5 and up 25 | 26 | ## TODOs 27 | - Config from URL 28 | - Connection Pooling 29 | - TLS support 30 | 31 | ## Add as dependency to your Zig project 32 | - `build.zig` 33 | ```zig 34 | //... 35 | const myzql_dep = b.dependency("myzql", .{}); 36 | const myzql = myzql_dep.module("myzql"); 37 | exe.addModule("myzql", myzql); 38 | //... 39 | ``` 40 | 41 | - `build.zig.zon` 42 | ```zon 43 | // ... 44 | .dependencies = .{ 45 | .myzql = .{ 46 | // choose a tag according to "Version Compatibility" table 47 | .url = "https://github.com/speed2exe/myzql/archive/refs/tags/0.13.2.tar.gz", 48 | .hash = "1220582ea45580eec6b16aa93d2a9404467db8bc1d911806d367513aa40f3817f84c", 49 | } 50 | }, 51 | // ... 52 | ``` 53 | 54 | ## Usage 55 | - Project integration example: [Usage](https://github.com/speed2exe/myzql-example) 56 | 57 | ### Connection 58 | ```zig 59 | const myzql = @import("myzql"); 60 | const Conn = myzql.conn.Conn; 61 | 62 | pub fn main() !void { 63 | // Setting up client 64 | var client = try Conn.init( 65 | allocator, 66 | &.{ 67 | .username = "some-user", // default: "root" 68 | .password = "password123", // default: "" 69 | .database = "customers", // default: "" 70 | 71 | // Current default value. 72 | // Use std.net.getAddressList if you need to look up ip based on hostname 73 | .address = std.net.Address.initIp4(.{ 127, 0, 0, 1 }, 3306), 74 | // ... 75 | }, 76 | ); 77 | defer client.deinit(); 78 | 79 | // Connection and Authentication 80 | try client.ping(); 81 | } 82 | ``` 83 | 84 | ## Querying 85 | ```zig 86 | 87 | const OkPacket = protocol.generic_response.OkPacket; 88 | 89 | pub fn main() !void { 90 | // ... 91 | // You can do a text query (text protocol) by using `query` method on `Conn` 92 | const result = try c.query("CREATE DATABASE testdb"); 93 | 94 | // Query results can have a few variant: 95 | // - ok: OkPacket => query is ok 96 | // - err: ErrorPacket => error occurred 97 | // In this example, res will either be `ok` or `err`. 98 | // We are using the convenient method `expect` for simplified error handling. 99 | // If the result variant does not match the kind of result you have specified, 100 | // a message will be printed and you will get an error instead. 101 | const ok: OkPacket = try result.expect(.ok); 102 | 103 | // Alternatively, you can also handle results manually for more control. 104 | // Here, we do a switch statement to handle all possible variant or results. 105 | switch (result.value) { 106 | .ok => |ok| {}, 107 | 108 | // `asError` is also another convenient method to print message and return as zig error. 109 | // You may also choose to inspect individual fields for more control. 110 | .err => |err| return err.asError(), 111 | } 112 | } 113 | ``` 114 | 115 | ## Querying returning rows (Text Results) 116 | - If you want to have query results to be represented by custom created structs, 117 | this is not the section, scroll down to "Executing prepared statements returning results" instead. 118 | ```zig 119 | const myzql = @import("myzql"); 120 | const QueryResult = myzql.result.QueryResult; 121 | const ResultSet = myzql.result.ResultSet; 122 | const ResultRow = myzql.result.ResultRow; 123 | const TextResultRow = myzql.result.TextResultData; 124 | const ResultSetIter = myzql.result.ResultSetIter; 125 | const TableTexts = myzql.result.TableTexts; 126 | const TextElemIter = myzql.result.TextElemIter; 127 | 128 | pub fn main() !void { 129 | const result = try c.queryRows("SELECT * FROM customers.purchases"); 130 | 131 | // This is a query that returns rows, you have to collect the result. 132 | // you can use `expect(.rows)` to try interpret query result as ResultSet(TextResultRow) 133 | const rows: ResultSet(TextResultRow) = try query_res.expect(.rows); 134 | 135 | // Allocation free interators 136 | const rows_iter: ResultRowIter(TextResultRow) = rows.iter(); 137 | { // Option 1: Iterate through every row and elem 138 | while (try rows_iter.next()) |row| { // ResultRow(TextResultRow) 139 | var elems_iter: TextElemIter = row.iter(); 140 | while (elems_iter.next()) |elem| { // ?[] const u8 141 | std.debug.print("{?s} ", .{elem}); 142 | } 143 | } 144 | } 145 | { // Option 2: Iterating over rows, collecting elements into []const ?[]const u8 146 | while (try rows_iter.next()) |row| { 147 | const text_elems: TextElems = try row.textElems(allocator); 148 | defer text_elems.deinit(allocator); // elems are valid until deinit is called 149 | const elems: []const ?[]const u8 = text_elems.elems; 150 | std.debug.print("elems: {any}\n", .{elems}); 151 | } 152 | } 153 | 154 | // You can also use `collectTexts` method to collect all rows. 155 | // Under the hood, it does network call and allocations, until EOF or error 156 | // Results are valid until `deinit` is called on TableTexts. 157 | const rows: ResultSet(TextResultRow) = try query_res.expect(.rows); 158 | const table = try rows.tableTexts(allocator); 159 | defer table.deinit(allocator); // table is valid until deinit is called 160 | std.debug.print("table: {any}\n", .{table.table}); 161 | } 162 | 163 | ``` 164 | 165 | ### Data Insertion 166 | - Let's assume that you have a table of this structure: 167 | ```sql 168 | CREATE TABLE test.person ( 169 | id INT NOT NULL AUTO_INCREMENT PRIMARY KEY, 170 | name VARCHAR(255), 171 | age INT 172 | ) 173 | ``` 174 | 175 | ```zig 176 | const myzql = @import("myzql"); 177 | const QueryResult = myzql.result.QueryResult; 178 | const PreparedStatement = myzql.result.PreparedStatement; 179 | const OkPacket = myzql.protocol.generic_response.OkPacket; 180 | 181 | pub fn main() void { 182 | // In order to do a insertion, you would first need to do a prepared statement. 183 | // Allocation is required as we need to store metadata of parameters and return type 184 | const prep_res = try c.prepare(allocator, "INSERT INTO test.person (name, age) VALUES (?, ?)"); 185 | defer prep_res.deinit(allocator); 186 | const prep_stmt: PreparedStatement = try prep_res.expect(.stmt); 187 | 188 | // Data to be inserted 189 | const params = .{ 190 | .{ "John", 42 }, 191 | .{ "Sam", 24 }, 192 | }; 193 | inline for (params) |param| { 194 | const exe_res = try c.execute(&prep_stmt, param); 195 | const ok: OkPacket = try exe_res.expect(.ok); // expecting ok here because there's no rows returned 196 | const last_insert_id: u64 = ok.last_insert_id; 197 | std.debug.print("last_insert_id: {any}\n", .{last_insert_id}); 198 | } 199 | 200 | // Currently only tuples are supported as an argument for insertion. 201 | // There are plans to include named structs in the future. 202 | } 203 | ``` 204 | 205 | ### Executing prepared statements returning results as structs 206 | ```zig 207 | const ResultSetIter = myzql.result.ResultSetIter; 208 | const QueryResult = myzql.result.QueryResult; 209 | const BinaryResultRow = myzql.result.BinaryResultRow; 210 | const TableStructs = myzql.result.TableStructs; 211 | const ResultSet = myzql.result.ResultSet; 212 | 213 | fn main() !void { 214 | const prep_res = try c.prepare(allocator, "SELECT name, age FROM test.person"); 215 | defer prep_res.deinit(allocator); 216 | const prep_stmt: PreparedStatement = try prep_res.expect(.stmt); 217 | 218 | // This is the struct that represents the columns of a single row. 219 | const Person = struct { 220 | name: []const u8, 221 | age: u8, 222 | }; 223 | 224 | // Execute query and get an iterator from results 225 | const res: QueryResult(BinaryResultRow) = try c.executeRows(&prep_stmt, .{}); 226 | const rows: ResultSet(BinaryResultRow) = try res.expect(.rows); 227 | const iter: ResultSetIter(BinaryResultRow) = rows.iter(); 228 | 229 | { // Iterating over rows, scanning into struct or creating struct 230 | const query_res = try c.executeRows(&prep_stmt, .{}); // no parameters because there's no ? in the query 231 | const rows: ResultSet(BinaryResultRow) = try query_res.expect(.rows); 232 | const rows_iter = rows.iter(); 233 | while (try rows_iter.next()) |row| { 234 | { // Option 1: scanning into preallocated person 235 | var person: Person = undefined; 236 | try row.scan(&person); 237 | person.greet(); 238 | // Important: if any field is a string, it will be valid until the next row is scanned 239 | // or next query. If your rows return have strings and you want to keep the data longer, 240 | // use the method below instead. 241 | } 242 | { // Option 2: passing in allocator to create person 243 | const person_ptr = try row.structCreate(Person, allocator); 244 | 245 | // Important: please use BinaryResultRow.structDestroy 246 | // to destroy the struct created by BinaryResultRow.structCreate 247 | // if your struct contains strings. 248 | // person is valid until BinaryResultRow.structDestroy is called. 249 | defer BinaryResultRow.structDestroy(person_ptr, allocator); 250 | person_ptr.greet(); 251 | } 252 | } 253 | } 254 | 255 | { // collect all rows into a table ([]const Person) 256 | const query_res = try c.executeRows(&prep_stmt, .{}); // no parameters because there's no ? in the query 257 | const rows: ResultSet(BinaryResultRow) = try query_res.expect(.rows); 258 | const rows_iter = rows.iter(); 259 | const person_structs = try rows_iter.tableStructs(Person, allocator); 260 | defer person_structs.deinit(allocator); // data is valid until deinit is called 261 | std.debug.print("person_structs: {any}\n", .{person_structs.struct_list.items}); 262 | } 263 | } 264 | ``` 265 | 266 | ### Temporal Types Support (DateTime, Time) 267 | - Example of using DateTime and Time MySQL column types. 268 | - Let's assume you already got this table set up: 269 | ```sql 270 | CREATE TABLE test.temporal_types_example ( 271 | event_time DATETIME(6) NOT NULL, 272 | duration TIME(6) NOT NULL 273 | ) 274 | ``` 275 | 276 | 277 | ```zig 278 | 279 | const DateTime = myzql.temporal.DateTime; 280 | const Duration = myzql.temporal.Duration; 281 | 282 | fn main() !void { 283 | { // Insert 284 | const prep_res = try c.prepare(allocator, "INSERT INTO test.temporal_types_example VALUES (?, ?)"); 285 | defer prep_res.deinit(allocator); 286 | const prep_stmt: PreparedStatement = try prep_res.expect(.stmt); 287 | 288 | const my_time: DateTime = .{ 289 | .year = 2023, 290 | .month = 11, 291 | .day = 30, 292 | .hour = 6, 293 | .minute = 50, 294 | .second = 58, 295 | .microsecond = 123456, 296 | }; 297 | const my_duration: Duration = .{ 298 | .days = 1, 299 | .hours = 23, 300 | .minutes = 59, 301 | .seconds = 59, 302 | .microseconds = 123456, 303 | }; 304 | const params = .{.{ my_time, my_duration }}; 305 | inline for (params) |param| { 306 | const exe_res = try c.execute(&prep_stmt, param); 307 | _ = try exe_res.expect(.ok); 308 | } 309 | } 310 | 311 | { // Select 312 | const DateTimeDuration = struct { 313 | event_time: DateTime, 314 | duration: Duration, 315 | }; 316 | const prep_res = try c.prepare(allocator, "SELECT * FROM test.temporal_types_example"); 317 | defer prep_res.deinit(allocator); 318 | const prep_stmt: PreparedStatement = try prep_res.expect(.stmt); 319 | const res = try c.executeRows(&prep_stmt, .{}); 320 | const rows: ResultSet(BinaryResultRow) = try res.expect(.rows); 321 | const rows_iter = rows.iter(); 322 | 323 | const structs = try rows_iter.tableStructs(DateTimeDuration, allocator); 324 | defer structs.deinit(allocator); 325 | std.debug.print("structs: {any}\n", .{structs.struct_list.items}); // structs.rows: []const DateTimeDuration 326 | // Do something with structs 327 | } 328 | } 329 | ``` 330 | 331 | ### Arrays Support 332 | - Assume that you have the SQL table: 333 | ```sql 334 | CREATE TABLE test.array_types_example ( 335 | name VARCHAR(16) NOT NULL, 336 | mac_addr BINARY(6) 337 | ) 338 | ``` 339 | 340 | ```zig 341 | fn main() !void { 342 | { // Insert 343 | const prep_res = try c.prepare(allocator, "INSERT INTO test.array_types_example VALUES (?, ?)"); 344 | defer prep_res.deinit(allocator); 345 | const prep_stmt: PreparedStatement = try prep_res.expect(.stmt); 346 | 347 | const params = .{ 348 | .{ "John", &[_]u8 { 0xFE } ** 6 }, 349 | .{ "Alice", null } 350 | }; 351 | inline for (params) |param| { 352 | const exe_res = try c.execute(&prep_stmt, param); 353 | _ = try exe_res.expect(.ok); 354 | } 355 | } 356 | 357 | { // Select 358 | const Client = struct { 359 | name: [16:1]u8, 360 | mac_addr: ?[6]u8, 361 | }; 362 | const prep_res = try c.prepare(allocator, "SELECT * FROM test.array_types_example"); 363 | defer prep_res.deinit(allocator); 364 | const prep_stmt: PreparedStatement = try prep_res.expect(.stmt); 365 | const res = try c.executeRows(&prep_stmt, .{}); 366 | const rows: ResultSet(BinaryResultRow) = try res.expect(.rows); 367 | const rows_iter = rows.iter(); 368 | 369 | const structs = try rows_iter.tableStructs(DateTimeDuration, allocator); 370 | defer structs.deinit(allocator); 371 | std.debug.print("structs: {any}\n", .{structs.struct_list.items}); // structs.rows: []const Client 372 | // Do something with structs 373 | } 374 | } 375 | ``` 376 | - Arrays will be initialized by their sentinel value. In this example, the value of the `name` field corresponding to `John`'s row will be `[16:1]u8 { 'J', 'o', 'h', 'n', 1, 1, 1, ... }` 377 | - If the array doesn't have a sentinel value, it will be zero-initialized. 378 | - Insufficiently sized arrays will silently truncate excess data 379 | 380 | ### `BoundedArray` Support 381 | - Assume that you have the SQL table: 382 | ```sql 383 | CREATE TABLE test.bounded_array_types_example ( 384 | name VARCHAR(16) NOT NULL, 385 | address VARCHAR(128) 386 | ) 387 | ``` 388 | 389 | ```zig 390 | const std = @import("std"); 391 | 392 | fn main() !void { 393 | { // Insert 394 | const prep_res = try c.prepare(allocator, "INSERT INTO test.bounded_array_types_example VALUES (?, ?)"); 395 | defer prep_res.deinit(allocator); 396 | const prep_stmt: PreparedStatement = try prep_res.expect(.stmt); 397 | 398 | const params = .{ 399 | .{ "John", "5 Rosewood Avenue Maryville, TN 37803"}, 400 | .{ "Alice", null } 401 | }; 402 | inline for (params) |param| { 403 | const exe_res = try c.execute(&prep_stmt, param); 404 | _ = try exe_res.expect(.ok); 405 | } 406 | } 407 | 408 | { // Select 409 | const Client = struct { 410 | name: std.BoundedArray(u8, 16), 411 | address: ?std.BoundedArray(u8, 128), 412 | }; 413 | const prep_res = try c.prepare(allocator, "SELECT * FROM test.bounded_array_types_example"); 414 | defer prep_res.deinit(allocator); 415 | const prep_stmt: PreparedStatement = try prep_res.expect(.stmt); 416 | const res = try c.executeRows(&prep_stmt, .{}); 417 | const rows: ResultSet(BinaryResultRow) = try res.expect(.rows); 418 | const rows_iter = rows.iter(); 419 | 420 | const structs = try rows_iter.tableStructs(Client, allocator); 421 | defer structs.deinit(allocator); 422 | std.debug.print("structs: {any}\n", .{structs.struct_list.items}); // structs.rows: []const Client 423 | // Do something with structs 424 | } 425 | } 426 | ``` 427 | 428 | - Insufficiently sized `BoundedArray`s will silently truncate excess data 429 | 430 | ## Unit Tests 431 | - `zig test src/myzql.zig` 432 | 433 | ## Integration Tests 434 | - Start up mysql/mariadb in docker: 435 | ```bash 436 | # MySQL 437 | docker run --name some-mysql --env MYSQL_ROOT_PASSWORD=password -p 3306:3306 -d mysql 438 | ```bash 439 | # MariaDB 440 | docker run --name some-mariadb --env MARIADB_ROOT_PASSWORD=password -p 3306:3306 -d mariadb 441 | ``` 442 | - Run all the test: In root directory of project: 443 | ```bash 444 | zig build -Dtest-filer='...' integration_test 445 | ``` 446 | 447 | ## Philosophy 448 | ### Correctness 449 | Focused on correct representation of server client protocol. 450 | ### Low-level and High-level APIs 451 | Low-level apis should contain all functionality you need. 452 | High-level apis are built on top of low-level ones for convenience and developer ergonomics. 453 | 454 | ### Binary Column Types support 455 | - MySQL Colums Types to Zig Values 456 | ``` 457 | - Null -> ?T 458 | - Int -> u64, u32, u16, u8 459 | - Float -> f32, f64 460 | - String -> []u8, []const u8, enum 461 | ``` 462 | -------------------------------------------------------------------------------- /build.zig: -------------------------------------------------------------------------------- 1 | const std = @import("std"); 2 | 3 | pub fn build(b: *std.Build) void { 4 | const myzql = b.addModule("myzql", .{ 5 | .root_source_file = b.path("./src/myzql.zig"), 6 | }); 7 | 8 | // -Dtest-filter="..." 9 | const test_filter = b.option([]const []const u8, "test-filter", "Filter for tests to run"); 10 | 11 | // zig build unit_test 12 | const unit_tests = b.addTest(.{ 13 | .root_source_file = b.path("./src/myzql.zig"), 14 | }); 15 | if (test_filter) |t| unit_tests.filters = t; 16 | 17 | // zig build [install] 18 | b.installArtifact(unit_tests); 19 | 20 | // zig build -Dtest-filter="..." run_unit_test 21 | const run_unit_tests = b.addRunArtifact(unit_tests); 22 | const unit_test_step = b.step("unit_test", "Run unit tests"); 23 | unit_test_step.dependOn(&run_unit_tests.step); 24 | 25 | // zig build -Dtest-filter="..." integration_test 26 | const integration_tests = b.addTest(.{ 27 | .root_source_file = b.path("./integration_tests/main.zig"), 28 | }); 29 | integration_tests.root_module.addImport("myzql", myzql); 30 | if (test_filter) |t| integration_tests.filters = t; 31 | 32 | // zig build [install] 33 | b.installArtifact(integration_tests); 34 | 35 | // zig build integration_test 36 | const run_integration_tests = b.addRunArtifact(integration_tests); 37 | const integration_test_step = b.step("integration_test", "Run integration tests"); 38 | integration_test_step.dependOn(&run_integration_tests.step); 39 | } 40 | -------------------------------------------------------------------------------- /integration_tests/config.zig: -------------------------------------------------------------------------------- 1 | const myzql = @import("myzql"); 2 | const Config = myzql.config.Config; 3 | 4 | pub const test_config: Config = .{ 5 | .password = "password", 6 | }; 7 | 8 | pub const test_config_with_db: Config = .{ 9 | .password = "password", 10 | .database = "mysql", 11 | }; 12 | -------------------------------------------------------------------------------- /integration_tests/conn.zig: -------------------------------------------------------------------------------- 1 | const std = @import("std"); 2 | const myzql = @import("myzql"); 3 | const Conn = myzql.conn.Conn; 4 | const test_config = @import("./config.zig").test_config; 5 | const test_config_with_db = @import("./config.zig").test_config_with_db; 6 | const allocator = std.testing.allocator; 7 | const ErrorPacket = myzql.protocol.generic_response.ErrorPacket; 8 | const minInt = std.math.minInt; 9 | const maxInt = std.math.maxInt; 10 | const DateTime = myzql.temporal.DateTime; 11 | const Duration = myzql.temporal.Duration; 12 | const ResultSet = myzql.result.ResultSet; 13 | const ResultRow = myzql.result.ResultRow; 14 | const BinaryResultRow = myzql.result.BinaryResultRow; 15 | const ResultRowIter = myzql.result.ResultRowIter; 16 | const TextResultRow = myzql.result.TextResultRow; 17 | const TextElemIter = myzql.result.TextElemIter; 18 | const TextElems = myzql.result.TextElems; 19 | const PreparedStatement = myzql.result.PreparedStatement; 20 | 21 | // convenient function for testing 22 | fn queryExpectOk(c: *Conn, query: []const u8) !void { 23 | const query_res = try c.query(query); 24 | _ = try query_res.expect(.ok); 25 | } 26 | 27 | fn queryExpectOkLogError(c: *Conn, query: []const u8) void { 28 | queryExpectOk(c, query) catch |err| { 29 | std.debug.print("Error: {}\n", .{err}); 30 | }; 31 | } 32 | 33 | test "ping" { 34 | var c = try Conn.init(std.testing.allocator, &test_config); 35 | defer c.deinit(); 36 | try c.ping(); 37 | } 38 | 39 | test "connect with database" { 40 | var c = try Conn.init(std.testing.allocator, &test_config_with_db); 41 | defer c.deinit(); 42 | try c.ping(); 43 | } 44 | 45 | test "query database create and drop" { 46 | var c = try Conn.init(std.testing.allocator, &test_config); 47 | defer c.deinit(); 48 | try queryExpectOk(&c, "CREATE DATABASE testdb"); 49 | try queryExpectOk(&c, "DROP DATABASE testdb"); 50 | } 51 | 52 | test "query syntax error" { 53 | var c = try Conn.init(std.testing.allocator, &test_config); 54 | defer c.deinit(); 55 | 56 | const qr = try c.query("garbage query"); 57 | _ = try qr.expect(.err); 58 | } 59 | 60 | test "query text protocol" { 61 | var c = try Conn.init(std.testing.allocator, &test_config); 62 | defer c.deinit(); 63 | 64 | { // Iterating over rows and elements 65 | const query_res = try c.queryRows("SELECT 1"); 66 | 67 | const rows: ResultSet(TextResultRow) = try query_res.expect(.rows); 68 | const rows_iter: ResultRowIter(TextResultRow) = rows.iter(); 69 | while (try rows_iter.next()) |row| { // ResultRow(TextResultRow) 70 | var elems_iter: TextElemIter = row.iter(); 71 | while (elems_iter.next()) |elem| { // ?[] const u8 72 | try std.testing.expectEqualDeep(@as(?[]const u8, "1"), elem); 73 | } 74 | } 75 | } 76 | { // Iterating over rows, collecting elements into []const ?[]const u8 77 | const query_res = try c.queryRows("SELECT 3, 4, null, 6, 7"); 78 | const rows: ResultSet(TextResultRow) = try query_res.expect(.rows); 79 | const rows_iter: ResultRowIter(TextResultRow) = rows.iter(); 80 | while (try rows_iter.next()) |row| { 81 | const elems: TextElems = try row.textElems(allocator); 82 | defer elems.deinit(allocator); 83 | 84 | try std.testing.expectEqualDeep( 85 | @as([]const ?[]const u8, &.{ "3", "4", null, "6", "7" }), 86 | elems.elems, 87 | ); 88 | } 89 | } 90 | { // Iterating over rows, collecting elements into []const []const ?[]const u8 91 | const query_res = try c.queryRows("SELECT 8,9 UNION ALL SELECT 10,11"); 92 | const rows: ResultSet(TextResultRow) = try query_res.expect(.rows); 93 | const table = try rows.tableTexts(allocator); 94 | defer table.deinit(allocator); 95 | 96 | try std.testing.expectEqualDeep( 97 | @as([]const []const ?[]const u8, &.{ 98 | &.{ "8", "9" }, 99 | &.{ "10", "11" }, 100 | }), 101 | table.table, 102 | ); 103 | } 104 | } 105 | 106 | test "prepare check" { 107 | var c = try Conn.init(std.testing.allocator, &test_config); 108 | defer c.deinit(); 109 | { // prepare no execute 110 | const prep_res = try c.prepare(allocator, "CREATE TABLE default.testtable (id INT, name VARCHAR(255))"); 111 | defer prep_res.deinit(allocator); 112 | _ = try prep_res.expect(.stmt); 113 | } 114 | { // prepare with params 115 | const prep_res = try c.prepare(allocator, "SELECT CONCAT(?, ?) as my_col"); 116 | defer prep_res.deinit(allocator); 117 | 118 | switch (prep_res) { 119 | .stmt => |prep_stmt| { 120 | try std.testing.expectEqual(prep_stmt.prep_ok.num_params, 2); 121 | try std.testing.expectEqual(prep_stmt.prep_ok.num_columns, 1); 122 | }, 123 | .err => |err| return err.asError(), 124 | } 125 | try std.testing.expectEqual(c.reader.len, c.reader.pos); 126 | } 127 | } 128 | 129 | test "prepare execute - 1" { 130 | var c = try Conn.init(std.testing.allocator, &test_config); 131 | defer c.deinit(); 132 | { 133 | const prep_res = try c.prepare(allocator, "CREATE DATABASE testdb"); 134 | defer prep_res.deinit(allocator); 135 | const prep_stmt: PreparedStatement = try prep_res.expect(.stmt); 136 | const query_res = try c.execute(&prep_stmt, .{}); 137 | _ = try query_res.expect(.ok); 138 | } 139 | { 140 | const prep_res = try c.prepare(allocator, "DROP DATABASE testdb"); 141 | defer prep_res.deinit(allocator); 142 | const prep_stmt: PreparedStatement = try prep_res.expect(.stmt); 143 | const query_res = try c.execute(&prep_stmt, .{}); 144 | _ = try query_res.expect(.ok); 145 | } 146 | } 147 | 148 | test "prepare execute - 2" { 149 | var c = try Conn.init(std.testing.allocator, &test_config); 150 | defer c.deinit(); 151 | 152 | const prep_res_1 = try c.prepare(allocator, "CREATE DATABASE testdb"); 153 | defer prep_res_1.deinit(allocator); 154 | const prep_stmt_1: PreparedStatement = try prep_res_1.expect(.stmt); 155 | 156 | const prep_res_2 = try c.prepare(allocator, "DROP DATABASE testdb"); 157 | defer prep_res_2.deinit(allocator); 158 | const prep_stmt_2: PreparedStatement = try prep_res_2.expect(.stmt); 159 | 160 | { 161 | const query_res = try c.execute(&prep_stmt_1, .{}); 162 | _ = try query_res.expect(.ok); 163 | } 164 | { 165 | const query_res = try c.execute(&prep_stmt_2, .{}); 166 | _ = try query_res.expect(.ok); 167 | } 168 | } 169 | 170 | test "prepare execute with result" { 171 | var c = try Conn.init(std.testing.allocator, &test_config); 172 | defer c.deinit(); 173 | 174 | { 175 | const query = 176 | \\SELECT null, "hello", 3 177 | ; 178 | const prep_res = try c.prepare(allocator, query); 179 | defer prep_res.deinit(allocator); 180 | const prep_stmt: PreparedStatement = try prep_res.expect(.stmt); 181 | const query_res = try c.executeRows(&prep_stmt, .{}); 182 | const rows: ResultSet(BinaryResultRow) = try query_res.expect(.rows); 183 | 184 | const MyType = struct { 185 | a: ?u8, 186 | b: []const u8, 187 | c: ?u8, 188 | }; 189 | const expected = MyType{ 190 | .a = null, 191 | .b = "hello", 192 | .c = 3, 193 | }; 194 | 195 | const rows_iter = rows.iter(); 196 | 197 | var dest_ptr: *MyType = undefined; 198 | while (try rows_iter.next()) |row| { 199 | { 200 | var dest: MyType = undefined; 201 | try row.scan(&dest); 202 | try std.testing.expectEqualDeep(expected, dest); 203 | } 204 | { 205 | dest_ptr = try row.structCreate(MyType, allocator); 206 | try std.testing.expectEqualDeep(expected, dest_ptr.*); 207 | } 208 | } 209 | defer BinaryResultRow.structDestroy(dest_ptr, allocator); 210 | 211 | { // Dummy query to test for invalid memory reuse 212 | const query_res2 = try c.queryRows("SELECT 3, 4, null, 6, 7"); 213 | 214 | const rows2: ResultSet(TextResultRow) = try query_res2.expect(.rows); 215 | const rows_iter2: ResultRowIter(TextResultRow) = rows2.iter(); 216 | while (try rows_iter2.next()) |row| { 217 | _ = row; 218 | } 219 | } 220 | 221 | try std.testing.expectEqualDeep(dest_ptr.b, "hello"); 222 | } 223 | { 224 | const query = 225 | \\SELECT 1, 2, 3 226 | \\UNION ALL 227 | \\SELECT 4, 5, 6 228 | ; 229 | const prep_res = try c.prepare(allocator, query); 230 | defer prep_res.deinit(allocator); 231 | const prep_stmt: PreparedStatement = try prep_res.expect(.stmt); 232 | const query_res = try c.executeRows(&prep_stmt, .{}); 233 | const rows: ResultSet(BinaryResultRow) = try query_res.expect(.rows); 234 | const rows_iter = rows.iter(); 235 | 236 | const MyType = struct { 237 | a: u8, 238 | b: u8, 239 | c: u8, 240 | }; 241 | const expected: []const MyType = &.{ 242 | .{ .a = 1, .b = 2, .c = 3 }, 243 | .{ .a = 4, .b = 5, .c = 6 }, 244 | }; 245 | 246 | const structs = try rows_iter.tableStructs(MyType, allocator); 247 | defer structs.deinit(allocator); 248 | try std.testing.expectEqualDeep(expected, structs.struct_list.items); 249 | } 250 | } 251 | 252 | test "prepare execute - first" { 253 | var c = try Conn.init(std.testing.allocator, &test_config); 254 | defer c.deinit(); 255 | 256 | { 257 | const query = 258 | \\SELECT 1 259 | \\UNION ALL 260 | \\SELECT 2 261 | ; 262 | 263 | const prep_res = try c.prepare(allocator, query); 264 | defer prep_res.deinit(allocator); 265 | const prep_stmt = try prep_res.expect(.stmt); 266 | const query_res = try c.executeRows(&prep_stmt, .{}); 267 | const rows = try query_res.expect(.rows); 268 | 269 | const MyType = struct { a: u8 }; 270 | 271 | const expected = MyType{ .a = 1 }; 272 | 273 | const first = try rows.first(); 274 | try std.testing.expect(first != null); 275 | 276 | var value: MyType = undefined; 277 | try first.?.scan(&value); 278 | try std.testing.expectEqualDeep(expected, value); 279 | try c.ping(); 280 | } 281 | 282 | { 283 | const query = 284 | \\SELECT NULL 285 | \\WHERE FALSE 286 | ; 287 | 288 | const prep_res = try c.prepare(allocator, query); 289 | defer prep_res.deinit(allocator); 290 | const prep_stmt = try prep_res.expect(.stmt); 291 | const query_res = try c.executeRows(&prep_stmt, .{}); 292 | const rows = try query_res.expect(.rows); 293 | 294 | const first = try rows.first(); 295 | try std.testing.expectEqual(null, first); 296 | try c.ping(); 297 | } 298 | } 299 | 300 | test "binary data types - int" { 301 | var c = try Conn.init(std.testing.allocator, &test_config); 302 | defer c.deinit(); 303 | 304 | try queryExpectOk(&c, "CREATE DATABASE test"); 305 | defer queryExpectOk(&c, "DROP DATABASE test") catch {}; 306 | 307 | try queryExpectOk(&c, 308 | \\CREATE TABLE test.int_types_example ( 309 | \\ tinyint_col TINYINT, 310 | \\ smallint_col SMALLINT, 311 | \\ mediumint_col MEDIUMINT, 312 | \\ int_col INT, 313 | \\ bigint_col BIGINT, 314 | \\ tinyint_unsigned_col TINYINT UNSIGNED, 315 | \\ smallint_unsigned_col SMALLINT UNSIGNED, 316 | \\ mediumint_unsigned_col MEDIUMINT UNSIGNED, 317 | \\ int_unsigned_col INT UNSIGNED, 318 | \\ bigint_unsigned_col BIGINT UNSIGNED 319 | \\) 320 | ); 321 | defer queryExpectOk(&c, "DROP TABLE test.int_types_example") catch {}; 322 | 323 | { // Insert (Binary Protocol) 324 | const prep_res = try c.prepare( 325 | allocator, 326 | "INSERT INTO test.int_types_example VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", 327 | ); 328 | defer prep_res.deinit(allocator); 329 | const prep_stmt = try prep_res.expect(.stmt); 330 | 331 | const params = .{ 332 | .{ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 }, 333 | .{ -(1 << 7), -(1 << 15), -(1 << 23), -(1 << 31), -(1 << 63), 0, 0, 0, 0, 0 }, 334 | .{ (1 << 7) - 1, (1 << 15) - 1, (1 << 23) - 1, (1 << 31) - 1, (1 << 63) - 1, (1 << 8) - 1, (1 << 16) - 1, (1 << 24) - 1, (1 << 32) - 1, (1 << 64) - 1 }, 335 | .{ null, null, null, null, null, null, null, null, null, null }, 336 | .{ @as(?i8, 0), @as(?i16, 0), @as(?i32, 0), @as(?i64, 0), @as(?u8, 0), @as(?u16, 0), @as(?u32, 0), @as(?u64, 0), @as(?u8, 0), @as(?u64, 0) }, 337 | .{ @as(i8, minInt(i8)), @as(i16, minInt(i16)), @as(i32, minInt(i24)), @as(i32, minInt(i32)), @as(i64, minInt(i64)), @as(u8, minInt(u8)), @as(u16, minInt(u16)), @as(u32, minInt(u24)), @as(u32, minInt(u32)), @as(u64, minInt(u64)) }, 338 | .{ @as(i8, maxInt(i8)), @as(i16, maxInt(i16)), @as(i32, maxInt(i24)), @as(i32, maxInt(i32)), @as(i64, maxInt(i64)), @as(u8, maxInt(u8)), @as(u16, maxInt(u16)), @as(u32, maxInt(u24)), @as(u32, maxInt(u32)), @as(u64, maxInt(u64)) }, 339 | .{ @as(?i8, null), @as(?i16, null), @as(?i32, null), @as(?i64, null), @as(?u8, null), @as(?u16, null), @as(?u32, null), @as(?u64, null), @as(?u8, null), @as(?u64, null) }, 340 | }; 341 | inline for (params) |param| { 342 | const exe_res = try c.execute(&prep_stmt, param); 343 | _ = try exe_res.expect(.ok); 344 | } 345 | } 346 | 347 | { // Select (Text Protocol) 348 | const res = try c.queryRows("SELECT * FROM test.int_types_example"); 349 | const rows: ResultSet(TextResultRow) = try res.expect(.rows); 350 | 351 | const table_texts = try rows.tableTexts(allocator); 352 | defer table_texts.deinit(allocator); 353 | 354 | const expected: []const []const ?[]const u8 = &.{ 355 | &.{ "0", "0", "0", "0", "0", "0", "0", "0", "0", "0" }, 356 | &.{ "-128", "-32768", "-8388608", "-2147483648", "-9223372036854775808", "0", "0", "0", "0", "0" }, 357 | &.{ "127", "32767", "8388607", "2147483647", "9223372036854775807", "255", "65535", "16777215", "4294967295", "18446744073709551615" }, 358 | &.{ null, null, null, null, null, null, null, null, null, null }, 359 | &.{ "0", "0", "0", "0", "0", "0", "0", "0", "0", "0" }, 360 | &.{ "-128", "-32768", "-8388608", "-2147483648", "-9223372036854775808", "0", "0", "0", "0", "0" }, 361 | &.{ "127", "32767", "8388607", "2147483647", "9223372036854775807", "255", "65535", "16777215", "4294967295", "18446744073709551615" }, 362 | &.{ null, null, null, null, null, null, null, null, null, null }, 363 | }; 364 | try std.testing.expectEqualDeep(expected, table_texts.table); 365 | } 366 | 367 | { // Select (Binary Protocol) 368 | const IntTypesExample = struct { 369 | tinyint_col: ?i8, 370 | smallint_col: ?i16, 371 | mediumint_col: ?i24, 372 | int_col: ?i32, 373 | bigint_col: ?i64, 374 | tinyint_unsigned_col: ?u8, 375 | smallint_unsigned_col: ?u16, 376 | mediumint_unsigned_col: ?u24, 377 | int_unsigned_col: ?u32, 378 | bigint_unsigned_col: ?u64, 379 | }; 380 | 381 | const prep_res = try c.prepare(allocator, "SELECT * FROM test.int_types_example LIMIT 4"); 382 | defer prep_res.deinit(allocator); 383 | const prep_stmt = try prep_res.expect(.stmt); 384 | const res = try c.executeRows(&prep_stmt, .{}); 385 | const rows: ResultSet(BinaryResultRow) = try res.expect(.rows); 386 | 387 | const expected: []const IntTypesExample = &.{ 388 | .{ 389 | .tinyint_col = 0, 390 | .smallint_col = 0, 391 | .mediumint_col = 0, 392 | .int_col = 0, 393 | .bigint_col = 0, 394 | .tinyint_unsigned_col = 0, 395 | .smallint_unsigned_col = 0, 396 | .mediumint_unsigned_col = 0, 397 | .int_unsigned_col = 0, 398 | .bigint_unsigned_col = 0, 399 | }, 400 | .{ 401 | .tinyint_col = -128, 402 | .smallint_col = -32768, 403 | .mediumint_col = -8388608, 404 | .int_col = -2147483648, 405 | .bigint_col = -9223372036854775808, 406 | .tinyint_unsigned_col = 0, 407 | .smallint_unsigned_col = 0, 408 | .mediumint_unsigned_col = 0, 409 | .int_unsigned_col = 0, 410 | .bigint_unsigned_col = 0, 411 | }, 412 | .{ 413 | .tinyint_col = 127, 414 | .smallint_col = 32767, 415 | .mediumint_col = 8388607, 416 | .int_col = 2147483647, 417 | .bigint_col = 9223372036854775807, 418 | .tinyint_unsigned_col = 255, 419 | .smallint_unsigned_col = 65535, 420 | .mediumint_unsigned_col = 16777215, 421 | .int_unsigned_col = 4294967295, 422 | .bigint_unsigned_col = 18446744073709551615, 423 | }, 424 | .{ 425 | .tinyint_col = null, 426 | .smallint_col = null, 427 | .mediumint_col = null, 428 | .int_col = null, 429 | .bigint_col = null, 430 | .tinyint_unsigned_col = null, 431 | .smallint_unsigned_col = null, 432 | .mediumint_unsigned_col = null, 433 | .int_unsigned_col = null, 434 | .bigint_unsigned_col = null, 435 | }, 436 | }; 437 | 438 | const structs = try rows.iter().tableStructs(IntTypesExample, allocator); 439 | defer structs.deinit(allocator); 440 | try std.testing.expectEqualDeep(expected, structs.struct_list.items); 441 | } 442 | } 443 | 444 | test "binary data types - float" { 445 | var c = try Conn.init(std.testing.allocator, &test_config); 446 | defer c.deinit(); 447 | 448 | try queryExpectOk(&c, "CREATE DATABASE test"); 449 | defer queryExpectOk(&c, "DROP DATABASE test") catch {}; 450 | 451 | try queryExpectOk(&c, 452 | \\CREATE TABLE test.float_types_example ( 453 | \\ float_col FLOAT, 454 | \\ double_col DOUBLE 455 | \\) 456 | ); 457 | defer queryExpectOk(&c, "DROP TABLE test.float_types_example") catch {}; 458 | 459 | { // Exec Insert 460 | const prep_res = try c.prepare(allocator, "INSERT INTO test.float_types_example VALUES (?, ?)"); 461 | defer prep_res.deinit(allocator); 462 | const prep_stmt = try prep_res.expect(.stmt); 463 | 464 | const params = .{ 465 | .{ 0.0, 0.0 }, 466 | .{ -1.23, -1.23 }, 467 | .{ 1.23, 1.23 }, 468 | .{ null, null }, 469 | .{ @as(?f32, 0), @as(?f64, 0) }, 470 | .{ @as(f32, -1.23), @as(f64, -1.23) }, 471 | .{ @as(f32, 1.23), @as(f64, 1.23) }, 472 | .{ @as(?f32, null), @as(?f64, null) }, 473 | }; 474 | inline for (params) |param| { 475 | const exe_res = try c.execute(&prep_stmt, param); 476 | _ = try exe_res.expect(.ok); 477 | } 478 | } 479 | 480 | { // Text Protocol 481 | const res = try c.queryRows("SELECT * FROM test.float_types_example"); 482 | const rows: ResultSet(TextResultRow) = try res.expect(.rows); 483 | const table_texts = try rows.tableTexts(allocator); 484 | defer table_texts.deinit(allocator); 485 | 486 | const expected: []const []const ?[]const u8 = &.{ 487 | &.{ "0", "0" }, 488 | &.{ "-1.23", "-1.23" }, 489 | &.{ "1.23", "1.23" }, 490 | &.{ null, null }, 491 | &.{ "0", "0" }, 492 | &.{ "-1.23", "-1.23" }, 493 | &.{ "1.23", "1.23" }, 494 | &.{ null, null }, 495 | }; 496 | try std.testing.expectEqualDeep(expected, table_texts.table); 497 | } 498 | 499 | { // Select (Binary Protocol) 500 | const FloatTypesExample = struct { 501 | float_col: f32, 502 | double_col: f64, 503 | }; 504 | 505 | const prep_res = try c.prepare(allocator, "SELECT * FROM test.float_types_example LIMIT 3"); 506 | defer prep_res.deinit(allocator); 507 | const prep_stmt = try prep_res.expect(.stmt); 508 | const res = try c.executeRows(&prep_stmt, .{}); 509 | const rows: ResultSet(BinaryResultRow) = try res.expect(.rows); 510 | const row_iter = rows.iter(); 511 | 512 | const expected: []const FloatTypesExample = &.{ 513 | .{ .float_col = 0, .double_col = 0 }, 514 | .{ .float_col = -1.23, .double_col = -1.23 }, 515 | .{ .float_col = 1.23, .double_col = 1.23 }, 516 | }; 517 | 518 | const structs = try row_iter.tableStructs(FloatTypesExample, allocator); 519 | defer structs.deinit(allocator); 520 | try std.testing.expectEqualDeep(expected, structs.struct_list.items); 521 | } 522 | } 523 | 524 | test "binary data types - string" { 525 | var c = try Conn.init(std.testing.allocator, &test_config); 526 | defer c.deinit(); 527 | 528 | try queryExpectOk(&c, "CREATE DATABASE test"); 529 | defer queryExpectOk(&c, "DROP DATABASE test") catch {}; 530 | 531 | try queryExpectOk(&c, 532 | \\CREATE TABLE test.string_types_example ( 533 | \\ varchar_col VARCHAR(255), 534 | \\ not_null_varchar_col VARCHAR(255) NOT NULL, 535 | \\ enum_col ENUM('a', 'b', 'c'), 536 | \\ not_null_enum_col ENUM('a', 'b', 'c') NOT NULL 537 | \\) 538 | ); 539 | defer queryExpectOk(&c, "DROP TABLE test.string_types_example") catch {}; 540 | 541 | const MyEnum = enum { a, b, c }; 542 | 543 | { // Exec Insert 544 | const prep_res = try c.prepare(allocator, "INSERT INTO test.string_types_example VALUES (?, ?, ?, ?)"); 545 | defer prep_res.deinit(allocator); 546 | const prep_stmt = try prep_res.expect(.stmt); 547 | 548 | const params = .{ 549 | .{ "hello", "world", "a", @as([*c]const u8, "b") }, 550 | .{ null, "foo", null, "c" }, 551 | .{ null, "", null, "a" }, 552 | .{ 553 | @as(?*const [3]u8, "baz"), 554 | @as([*:0]const u8, "bar"), 555 | @as(?[]const u8, null), 556 | @as(MyEnum, .c), 557 | }, 558 | }; 559 | inline for (params) |param| { 560 | const exe_res = try c.execute(&prep_stmt, param); 561 | _ = try exe_res.expect(.ok); 562 | } 563 | } 564 | 565 | { // Text Protocol 566 | const res = try c.queryRows("SELECT * FROM test.string_types_example"); 567 | const rows: ResultSet(TextResultRow) = try res.expect(.rows); 568 | 569 | const table_texts = try rows.tableTexts(allocator); 570 | defer table_texts.deinit(allocator); 571 | 572 | const expected: []const []const ?[]const u8 = &.{ 573 | &.{ "hello", "world", "a", "b" }, 574 | &.{ null, "foo", null, "c" }, 575 | &.{ null, "", null, "a" }, 576 | &.{ "baz", "bar", null, "c" }, 577 | }; 578 | try std.testing.expectEqualDeep(expected, table_texts.table); 579 | } 580 | 581 | { // Select (Binary Protocol) 582 | const StringTypesExample = struct { 583 | varchar_col: ?[]const u8, 584 | not_null_varchar_col: []const u8, 585 | enum_col: ?MyEnum, 586 | not_null_enum_col: MyEnum, 587 | }; 588 | 589 | const prep_res = try c.prepare(allocator, 590 | \\SELECT * FROM test.string_types_example 591 | ); 592 | defer prep_res.deinit(allocator); 593 | const prep_stmt = try prep_res.expect(.stmt); 594 | const res = try c.executeRows(&prep_stmt, .{}); 595 | const rows: ResultSet(BinaryResultRow) = try res.expect(.rows); 596 | const rows_iter = rows.iter(); 597 | 598 | const expected: []const StringTypesExample = &.{ 599 | .{ 600 | .varchar_col = "hello", 601 | .not_null_varchar_col = "world", 602 | .enum_col = .a, 603 | .not_null_enum_col = .b, 604 | }, 605 | .{ 606 | .varchar_col = null, 607 | .not_null_varchar_col = "foo", 608 | .enum_col = null, 609 | .not_null_enum_col = .c, 610 | }, 611 | .{ 612 | .varchar_col = null, 613 | .not_null_varchar_col = "", 614 | .enum_col = null, 615 | .not_null_enum_col = .a, 616 | }, 617 | .{ 618 | .varchar_col = "baz", 619 | .not_null_varchar_col = "bar", 620 | .enum_col = null, 621 | .not_null_enum_col = .c, 622 | }, 623 | }; 624 | 625 | const structs = try rows_iter.tableStructs(StringTypesExample, allocator); 626 | defer structs.deinit(allocator); 627 | try std.testing.expectEqualDeep(expected, structs.struct_list.items); 628 | } 629 | } 630 | 631 | test "binary data types - array" { 632 | var c = try Conn.init(std.testing.allocator, &test_config); 633 | defer c.deinit(); 634 | 635 | try queryExpectOk(&c, "CREATE DATABASE test"); 636 | defer queryExpectOk(&c, "DROP DATABASE test") catch {}; 637 | 638 | try queryExpectOk(&c, 639 | \\CREATE TABLE test.array_types_example ( 640 | \\ binary_col BINARY(4), 641 | \\ not_null_binary_col BINARY(4) NOT NULL 642 | \\) 643 | ); 644 | defer queryExpectOk(&c, "DROP TABLE test.array_types_example") catch {}; 645 | 646 | { // Exec Insert 647 | const prep_res = try c.prepare(allocator, "INSERT INTO test.array_types_example VALUES (?, ?)"); 648 | defer prep_res.deinit(allocator); 649 | const prep_stmt = try prep_res.expect(.stmt); 650 | 651 | const params = .{ 652 | .{ null, "1234" }, 653 | .{ "0246", "1234" }, 654 | .{ null, "123" }, 655 | .{ "024", "123" }, 656 | }; 657 | inline for (params) |param| { 658 | const exe_res = try c.execute(&prep_stmt, param); 659 | _ = try exe_res.expect(.ok); 660 | } 661 | 662 | const fail_params = .{ 663 | .{ null, "12345" }, 664 | .{ "02468", "12345" }, 665 | }; 666 | inline for (fail_params) |param| { 667 | const exe_res = try c.execute(&prep_stmt, param); 668 | _ = try exe_res.expect(.err); 669 | } 670 | } 671 | 672 | { // Text Protocol 673 | const res = try c.queryRows("SELECT * FROM test.array_types_example"); 674 | const rows: ResultSet(TextResultRow) = try res.expect(.rows); 675 | 676 | const table_texts = try rows.tableTexts(allocator); 677 | defer table_texts.deinit(allocator); 678 | 679 | const expected: []const []const ?[]const u8 = &.{ 680 | &.{ null, "1234" }, 681 | &.{ "0246", "1234" }, 682 | &.{ null, "123\x00" }, 683 | &.{ "024\x00", "123\x00" }, 684 | }; 685 | try std.testing.expectEqualDeep(expected, table_texts.table); 686 | } 687 | 688 | { // Select (Binary Protocol) 689 | const ArrayTypesExample = struct { 690 | binary_col: ?[4]u8, 691 | not_null_binary_col: [4]u8, 692 | }; 693 | 694 | const Long = struct { 695 | binary_col: ?[5]u8, 696 | not_null_binary_col: [5]u8, 697 | }; 698 | 699 | const Short = struct { 700 | binary_col: ?[3]u8, 701 | not_null_binary_col: [3]u8, 702 | }; 703 | 704 | const Sentinel = struct { 705 | binary_col: ?[4:1]u8, 706 | not_null_binary_col: [4:1]u8, 707 | }; 708 | 709 | const prep_res = try c.prepare(allocator, 710 | \\SELECT * FROM test.array_types_example 711 | ); 712 | defer prep_res.deinit(allocator); 713 | const prep_stmt = try prep_res.expect(.stmt); 714 | { 715 | const res = try c.executeRows(&prep_stmt, .{}); 716 | const rows: ResultSet(BinaryResultRow) = try res.expect(.rows); 717 | const rows_iter = rows.iter(); 718 | 719 | const expected: []const ArrayTypesExample = &.{ 720 | .{ 721 | .binary_col = null, 722 | .not_null_binary_col = "1234".*, 723 | }, 724 | .{ 725 | .binary_col = "0246".*, 726 | .not_null_binary_col = "1234".*, 727 | }, 728 | .{ .binary_col = null, .not_null_binary_col = "123\x00".* }, 729 | .{ 730 | .binary_col = "024\x00".*, 731 | .not_null_binary_col = "123\x00".*, 732 | }, 733 | }; 734 | 735 | const structs = try rows_iter.tableStructs(ArrayTypesExample, allocator); 736 | defer structs.deinit(allocator); 737 | try std.testing.expectEqualDeep(expected, structs.struct_list.items); 738 | } 739 | 740 | { 741 | const res = try c.executeRows(&prep_stmt, .{}); 742 | const rows: ResultSet(BinaryResultRow) = try res.expect(.rows); 743 | const rows_iter = rows.iter(); 744 | 745 | const expected_shorts: []const Short = &.{ 746 | .{ 747 | .binary_col = null, 748 | .not_null_binary_col = "123".*, 749 | }, 750 | .{ 751 | .binary_col = "024".*, 752 | .not_null_binary_col = "123".*, 753 | }, 754 | .{ .binary_col = null, .not_null_binary_col = "123".* }, 755 | .{ 756 | .binary_col = "024".*, 757 | .not_null_binary_col = "123".*, 758 | }, 759 | }; 760 | 761 | const shorts = try rows_iter.tableStructs(Short, allocator); 762 | defer shorts.deinit(allocator); 763 | try std.testing.expectEqualDeep(expected_shorts, shorts.struct_list.items); 764 | } 765 | 766 | { 767 | const res = try c.executeRows(&prep_stmt, .{}); 768 | const rows: ResultSet(BinaryResultRow) = try res.expect(.rows); 769 | const rows_iter = rows.iter(); 770 | 771 | const expected_sentinels: []const Sentinel = &.{ 772 | .{ 773 | .binary_col = null, 774 | .not_null_binary_col = @as(*const [4:1]u8, @ptrCast("1234\x01")).*, 775 | }, 776 | .{ 777 | .binary_col = @as(*const [4:1]u8, @ptrCast("0246\x01")).*, 778 | .not_null_binary_col = @as(*const [4:1]u8, @ptrCast("1234\x01")).*, 779 | }, 780 | .{ .binary_col = null, .not_null_binary_col = @as(*const [4:1]u8, @ptrCast("123\x00\x01")).* }, 781 | .{ 782 | .binary_col = @as(*const [4:1]u8, @ptrCast("024\x00\x01")).*, 783 | .not_null_binary_col = @as(*const [4:1]u8, @ptrCast("123\x00\x01")).*, 784 | }, 785 | }; 786 | 787 | const sentinels = try rows_iter.tableStructs(Sentinel, allocator); 788 | defer sentinels.deinit(allocator); 789 | try std.testing.expectEqualDeep(expected_sentinels, sentinels.struct_list.items); 790 | } 791 | 792 | { 793 | const res = try c.executeRows(&prep_stmt, .{}); 794 | const rows: ResultSet(BinaryResultRow) = try res.expect(.rows); 795 | const rows_iter = rows.iter(); 796 | 797 | const expected_longs: []const Long = &.{ 798 | .{ 799 | .binary_col = null, 800 | .not_null_binary_col = "1234\x00".*, 801 | }, 802 | .{ 803 | .binary_col = "0246\x00".*, 804 | .not_null_binary_col = "1234\x00".*, 805 | }, 806 | .{ .binary_col = null, .not_null_binary_col = "123\x00\x00".* }, 807 | .{ 808 | .binary_col = "024\x00\x00".*, 809 | .not_null_binary_col = "123\x00\x00".*, 810 | }, 811 | }; 812 | 813 | const longs = try rows_iter.tableStructs(Long, allocator); 814 | defer longs.deinit(allocator); 815 | try std.testing.expectEqualDeep(expected_longs, longs.struct_list.items); 816 | } 817 | } 818 | } 819 | 820 | test "binary data types - BoundedArray" { 821 | var c = try Conn.init(std.testing.allocator, &test_config); 822 | defer c.deinit(); 823 | 824 | try queryExpectOk(&c, "CREATE DATABASE test"); 825 | defer queryExpectOk(&c, "DROP DATABASE test") catch {}; 826 | 827 | try queryExpectOk(&c, 828 | \\CREATE TABLE test.bounded_array_types_example ( 829 | \\ varchar_col VARCHAR(255), 830 | \\ not_null_varchar_col VARCHAR(255) NOT NULL 831 | \\) 832 | ); 833 | defer queryExpectOk(&c, "DROP TABLE test.bounded_array_types_example") catch {}; 834 | 835 | { // Exec Insert 836 | const prep_res = try c.prepare(allocator, "INSERT INTO test.bounded_array_types_example VALUES (?, ?)"); 837 | defer prep_res.deinit(allocator); 838 | const prep_stmt = try prep_res.expect(.stmt); 839 | 840 | const params = .{ 841 | .{ "hello", "world" }, 842 | .{ null, "foo" }, 843 | .{ null, "" }, 844 | .{ 845 | @as(?*const [3]u8, "baz"), 846 | @as([*:0]const u8, "bar"), 847 | }, 848 | }; 849 | inline for (params) |param| { 850 | const exe_res = try c.execute(&prep_stmt, param); 851 | _ = try exe_res.expect(.ok); 852 | } 853 | } 854 | 855 | { // Select (Binary Protocol) 856 | const BoundedArray = std.BoundedArray(u8, 255); 857 | 858 | const BoundedArrayTypesExample = struct { 859 | varchar_col: ?BoundedArray, 860 | not_null_varchar_col: BoundedArray, 861 | }; 862 | 863 | const prep_res = try c.prepare(allocator, 864 | \\SELECT * FROM test.bounded_array_types_example 865 | ); 866 | defer prep_res.deinit(allocator); 867 | const prep_stmt = try prep_res.expect(.stmt); 868 | const res = try c.executeRows(&prep_stmt, .{}); 869 | const rows: ResultSet(BinaryResultRow) = try res.expect(.rows); 870 | const rows_iter = rows.iter(); 871 | 872 | const expected: []const BoundedArrayTypesExample = &.{ 873 | .{ 874 | .varchar_col = blk: { 875 | var arr = try BoundedArray.init(0); 876 | try arr.appendSlice("hello"); 877 | break :blk arr; 878 | }, 879 | .not_null_varchar_col = blk: { 880 | var arr = try BoundedArray.init(0); 881 | try arr.appendSlice("world"); 882 | break :blk arr; 883 | }, 884 | }, 885 | .{ 886 | .varchar_col = null, 887 | .not_null_varchar_col = blk: { 888 | var arr = try BoundedArray.init(0); 889 | try arr.appendSlice("foo"); 890 | break :blk arr; 891 | }, 892 | }, 893 | .{ 894 | .varchar_col = null, 895 | .not_null_varchar_col = blk: { 896 | const arr = try BoundedArray.init(0); 897 | break :blk arr; 898 | }, 899 | }, 900 | .{ 901 | .varchar_col = blk: { 902 | var arr = try BoundedArray.init(0); 903 | try arr.appendSlice("baz"); 904 | break :blk arr; 905 | }, 906 | .not_null_varchar_col = blk: { 907 | var arr = try BoundedArray.init(0); 908 | try arr.appendSlice("bar"); 909 | break :blk arr; 910 | }, 911 | }, 912 | }; 913 | 914 | const structs = try rows_iter.tableStructs(BoundedArrayTypesExample, allocator); 915 | defer structs.deinit(allocator); 916 | try std.testing.expectEqualDeep(expected, structs.struct_list.items); 917 | } 918 | 919 | { // Select (Binary Protocol) -- Small array 920 | const BoundedArray = std.BoundedArray(u8, 3); 921 | 922 | const BoundedArrayTypesExample = struct { 923 | varchar_col: ?BoundedArray, 924 | not_null_varchar_col: BoundedArray, 925 | }; 926 | 927 | const prep_res = try c.prepare(allocator, 928 | \\SELECT * FROM test.bounded_array_types_example 929 | ); 930 | defer prep_res.deinit(allocator); 931 | const prep_stmt = try prep_res.expect(.stmt); 932 | const res = try c.executeRows(&prep_stmt, .{}); 933 | const rows: ResultSet(BinaryResultRow) = try res.expect(.rows); 934 | const rows_iter = rows.iter(); 935 | 936 | const expected: []const BoundedArrayTypesExample = &.{ 937 | .{ 938 | .varchar_col = blk: { 939 | var arr = try BoundedArray.init(0); 940 | try arr.appendSlice("hel"); 941 | break :blk arr; 942 | }, 943 | .not_null_varchar_col = blk: { 944 | var arr = try BoundedArray.init(0); 945 | try arr.appendSlice("wor"); 946 | break :blk arr; 947 | }, 948 | }, 949 | .{ 950 | .varchar_col = null, 951 | .not_null_varchar_col = blk: { 952 | var arr = try BoundedArray.init(0); 953 | try arr.appendSlice("foo"); 954 | break :blk arr; 955 | }, 956 | }, 957 | .{ 958 | .varchar_col = null, 959 | .not_null_varchar_col = blk: { 960 | const arr = try BoundedArray.init(0); 961 | break :blk arr; 962 | }, 963 | }, 964 | .{ 965 | .varchar_col = blk: { 966 | var arr = try BoundedArray.init(0); 967 | try arr.appendSlice("baz"); 968 | break :blk arr; 969 | }, 970 | .not_null_varchar_col = blk: { 971 | var arr = try BoundedArray.init(0); 972 | try arr.appendSlice("bar"); 973 | break :blk arr; 974 | }, 975 | }, 976 | }; 977 | 978 | const structs = try rows_iter.tableStructs(BoundedArrayTypesExample, allocator); 979 | defer structs.deinit(allocator); 980 | try std.testing.expectEqualDeep(expected, structs.struct_list.items); 981 | } 982 | } 983 | 984 | test "binary data types - temporal" { 985 | var c = try Conn.init(std.testing.allocator, &test_config); 986 | defer c.deinit(); 987 | 988 | try queryExpectOk(&c, "CREATE DATABASE test"); 989 | defer queryExpectOk(&c, "DROP DATABASE test") catch {}; 990 | 991 | try queryExpectOk(&c, 992 | \\CREATE TABLE test.temporal_types_example ( 993 | \\ event_time DATETIME(6) NOT NULL, 994 | \\ event_time2 DATETIME(2) NOT NULL, 995 | \\ event_time3 DATETIME NOT NULL, 996 | \\ duration TIME(6) NOT NULL, 997 | \\ duration2 TIME(4) NOT NULL, 998 | \\ duration3 TIME NOT NULL 999 | \\) 1000 | ); 1001 | defer queryExpectOk(&c, "DROP TABLE test.temporal_types_example") catch {}; 1002 | 1003 | { // Exec Insert 1004 | const prep_res = try c.prepare(allocator, "INSERT INTO test.temporal_types_example VALUES (?, ?, ?, ?, ?, ?)"); 1005 | defer prep_res.deinit(allocator); 1006 | const prep_stmt = try prep_res.expect(.stmt); 1007 | 1008 | const my_time: DateTime = .{ .year = 2023, .month = 11, .day = 30, .hour = 6, .minute = 50, .second = 58, .microsecond = 123456 }; 1009 | const datetime_no_ms: DateTime = .{ .year = 2023, .month = 11, .day = 30, .hour = 6, .minute = 50, .second = 58 }; 1010 | const only_day: DateTime = .{ .year = 2023, .month = 11, .day = 30 }; 1011 | const my_duration: Duration = .{ .days = 1, .hours = 23, .minutes = 59, .seconds = 59, .microseconds = 123432 }; // should be 123456 but mariadb does not round, using this example just to pass the test 1012 | const duration_no_ms: Duration = .{ .days = 1, .hours = 23, .minutes = 59, .seconds = 59 }; 1013 | const duration_zero: Duration = .{}; 1014 | 1015 | const params = .{ 1016 | .{ my_time, my_time, my_time, my_duration, my_duration, my_duration }, 1017 | .{ datetime_no_ms, datetime_no_ms, datetime_no_ms, duration_no_ms, duration_no_ms, duration_no_ms }, 1018 | .{ only_day, only_day, only_day, duration_zero, duration_zero, duration_zero }, 1019 | }; 1020 | 1021 | inline for (params) |param| { 1022 | const exe_res = try c.execute(&prep_stmt, param); 1023 | _ = try exe_res.expect(.ok); 1024 | } 1025 | } 1026 | 1027 | { // Text Protocol 1028 | const res = try c.queryRows("SELECT * FROM test.temporal_types_example"); 1029 | const rows: ResultSet(TextResultRow) = try res.expect(.rows); 1030 | 1031 | const table_texts = try rows.tableTexts(allocator); 1032 | defer table_texts.deinit(allocator); 1033 | 1034 | const expected: []const []const ?[]const u8 = &.{ 1035 | &.{ "2023-11-30 06:50:58.123456", "2023-11-30 06:50:58.12", "2023-11-30 06:50:58", "47:59:59.123432", "47:59:59.1234", "47:59:59" }, 1036 | &.{ "2023-11-30 06:50:58.000000", "2023-11-30 06:50:58.00", "2023-11-30 06:50:58", "47:59:59.000000", "47:59:59.0000", "47:59:59" }, 1037 | &.{ "2023-11-30 00:00:00.000000", "2023-11-30 00:00:00.00", "2023-11-30 00:00:00", "00:00:00.000000", "00:00:00.0000", "00:00:00" }, 1038 | }; 1039 | 1040 | try std.testing.expectEqualDeep(expected, table_texts.table); 1041 | } 1042 | 1043 | { // Select (Binary Protocol) 1044 | const TemporalTypesExample = struct { 1045 | event_time: DateTime, 1046 | event_time2: DateTime, 1047 | event_time3: DateTime, 1048 | duration: Duration, 1049 | duration2: Duration, 1050 | duration3: Duration, 1051 | }; 1052 | const prep_res = try c.prepare(allocator, "SELECT * FROM test.temporal_types_example LIMIT 3"); 1053 | defer prep_res.deinit(allocator); 1054 | const prep_stmt = try prep_res.expect(.stmt); 1055 | const res = try c.executeRows(&prep_stmt, .{}); 1056 | const rows: ResultSet(BinaryResultRow) = try res.expect(.rows); 1057 | const rows_iter = rows.iter(); 1058 | 1059 | const expected: []const TemporalTypesExample = &.{ 1060 | .{ 1061 | .event_time = .{ .year = 2023, .month = 11, .day = 30, .hour = 6, .minute = 50, .second = 58, .microsecond = 123456 }, 1062 | .event_time2 = .{ .year = 2023, .month = 11, .day = 30, .hour = 6, .minute = 50, .second = 58, .microsecond = 120000 }, 1063 | .event_time3 = .{ .year = 2023, .month = 11, .day = 30, .hour = 6, .minute = 50, .second = 58 }, 1064 | .duration = .{ .days = 1, .hours = 23, .minutes = 59, .seconds = 59, .microseconds = 123432 }, 1065 | .duration2 = .{ .days = 1, .hours = 23, .minutes = 59, .seconds = 59, .microseconds = 123400 }, 1066 | .duration3 = .{ .days = 1, .hours = 23, .minutes = 59, .seconds = 59 }, 1067 | }, 1068 | .{ 1069 | .event_time = .{ .year = 2023, .month = 11, .day = 30, .hour = 6, .minute = 50, .second = 58 }, 1070 | .event_time2 = .{ .year = 2023, .month = 11, .day = 30, .hour = 6, .minute = 50, .second = 58 }, 1071 | .event_time3 = .{ .year = 2023, .month = 11, .day = 30, .hour = 6, .minute = 50, .second = 58 }, 1072 | .duration = .{ .days = 1, .hours = 23, .minutes = 59, .seconds = 59 }, 1073 | .duration2 = .{ .days = 1, .hours = 23, .minutes = 59, .seconds = 59 }, 1074 | .duration3 = .{ .days = 1, .hours = 23, .minutes = 59, .seconds = 59 }, 1075 | }, 1076 | .{ 1077 | .event_time = .{ .year = 2023, .month = 11, .day = 30 }, 1078 | .event_time2 = .{ .year = 2023, .month = 11, .day = 30 }, 1079 | .event_time3 = .{ .year = 2023, .month = 11, .day = 30 }, 1080 | .duration = .{}, 1081 | .duration2 = .{}, 1082 | .duration3 = .{}, 1083 | }, 1084 | }; 1085 | 1086 | const structs = try rows_iter.tableStructs(TemporalTypesExample, allocator); 1087 | defer structs.deinit(allocator); 1088 | try std.testing.expectEqualDeep(expected, structs.struct_list.items); 1089 | } 1090 | } 1091 | 1092 | test "select concat with params" { 1093 | var c = try Conn.init(std.testing.allocator, &test_config); 1094 | defer c.deinit(); 1095 | 1096 | { // Select (Binary Protocol) 1097 | const prep_res = try c.prepare(allocator, "SELECT CONCAT(?, ?) AS col1"); 1098 | defer prep_res.deinit(allocator); 1099 | const prep_stmt = try prep_res.expect(.stmt); 1100 | const res = try c.executeRows(&prep_stmt, .{ runtimeValue("hello"), runtimeValue("world") }); 1101 | const rows: ResultSet(BinaryResultRow) = try res.expect(.rows); 1102 | const rows_iter = rows.iter(); 1103 | 1104 | const Result = struct { col1: []const u8 }; 1105 | const expected: []const Result = &.{.{ .col1 = "helloworld" }}; 1106 | const structs = try rows_iter.tableStructs(Result, allocator); 1107 | defer structs.deinit(allocator); 1108 | try std.testing.expectEqualDeep(expected, structs.struct_list.items); 1109 | } 1110 | } 1111 | 1112 | fn runtimeValue(a: anytype) @TypeOf(a) { 1113 | return a; 1114 | } 1115 | -------------------------------------------------------------------------------- /integration_tests/main.zig: -------------------------------------------------------------------------------- 1 | pub const conn = @import("./conn.zig"); 2 | 3 | test { 4 | @import("std").testing.refAllDeclsRecursive(@This()); 5 | } 6 | -------------------------------------------------------------------------------- /src/auth.zig: -------------------------------------------------------------------------------- 1 | const std = @import("std"); 2 | const PublicKey = std.crypto.Certificate.rsa.PublicKey; 3 | const Sha1 = std.crypto.hash.Sha1; 4 | const Sha256 = std.crypto.hash.sha2.Sha256; 5 | 6 | const base64 = std.base64.standard.decoderWithIgnore(" \t\r\n"); 7 | 8 | pub const AuthPlugin = enum { 9 | unspecified, 10 | mysql_native_password, 11 | sha256_password, 12 | caching_sha2_password, 13 | mysql_clear_password, 14 | unknown, 15 | 16 | pub fn fromName(name: []const u8) AuthPlugin { 17 | return std.meta.stringToEnum(AuthPlugin, name) orelse .unknown; 18 | } 19 | 20 | pub fn toName(auth_plugin: AuthPlugin) [:0]const u8 { 21 | return @tagName(auth_plugin); 22 | } 23 | }; 24 | 25 | // https://mariadb.com/kb/en/sha256_password-plugin/ 26 | pub const sha256_password_public_key_request = 0x01; 27 | 28 | // https://dev.mysql.com/doc/dev/mysql-server/latest/page_caching_sha2_authentication_exchanges.html 29 | // https://mariadb.com/kb/en/caching_sha2_password-authentication-plugin/ 30 | pub const caching_sha2_password_public_key_response = 0x01; 31 | pub const caching_sha2_password_public_key_request = 0x02; 32 | pub const caching_sha2_password_fast_auth_success = 0x03; 33 | pub const caching_sha2_password_full_authentication_start = 0x04; 34 | 35 | pub const DecodedPublicKey = struct { 36 | allocated: []const u8, 37 | value: std.crypto.Certificate.rsa.PublicKey, 38 | 39 | pub fn deinit(d: *const DecodedPublicKey, allocator: std.mem.Allocator) void { 40 | allocator.free(d.allocated); 41 | } 42 | }; 43 | 44 | pub fn decodePublicKey(encoded_bytes: []const u8, allocator: std.mem.Allocator) !DecodedPublicKey { 45 | var decoded_pk: DecodedPublicKey = undefined; 46 | 47 | const start_marker = "-----BEGIN PUBLIC KEY-----"; 48 | const end_marker = "-----END PUBLIC KEY-----"; 49 | 50 | const base64_encoded = blk: { 51 | const start_marker_pos = std.mem.indexOfPos(u8, encoded_bytes, 0, start_marker).?; 52 | const base64_start = start_marker_pos + start_marker.len; 53 | const base64_end = std.mem.indexOfPos(u8, encoded_bytes, base64_start, end_marker).?; 54 | break :blk std.mem.trim(u8, encoded_bytes[base64_start..base64_end], " \t\r\n"); 55 | }; 56 | 57 | const dest = try allocator.alloc(u8, try base64.calcSizeUpperBound(base64_encoded.len)); 58 | decoded_pk.allocated = dest; 59 | errdefer allocator.free(decoded_pk.allocated); 60 | 61 | const base64_decoded = blk: { 62 | const n = try base64.decode(dest, base64_encoded); 63 | break :blk decoded_pk.allocated[0..n]; 64 | }; 65 | 66 | // Example of DER-encoded public key: 67 | // SEQUENCE (2 elem) 68 | // SEQUENCE (2 elem) 69 | // OBJECT IDENTIFIER 1.2.840.113549.1.1.1 rsaEncryption (PKCS #1) 70 | // NULL 71 | // BIT STRING (2160 bit) 001100001000001000000001000010100000001010000010000000010000000100000… 72 | // SEQUENCE (2 elem) 73 | // INTEGER (2048 bit) 273994660083475464992607454720526089815923926694328893650906911229409… 74 | // INTEGER 65537 75 | 76 | const bitstring = blk: { 77 | const Element = std.crypto.Certificate.der.Element; 78 | const top_level = try Element.parse(base64_decoded, 0); 79 | const seq_1 = try Element.parse(base64_decoded, top_level.slice.start); 80 | const bitstring_elem = try Element.parse(base64_decoded, seq_1.slice.end); 81 | break :blk std.mem.trim(u8, base64_decoded[bitstring_elem.slice.start..bitstring_elem.slice.end], &.{0}); 82 | }; 83 | 84 | const pk_decoded = try std.crypto.Certificate.rsa.PublicKey.parseDer(bitstring); 85 | decoded_pk.value = try std.crypto.Certificate.rsa.PublicKey.fromBytes(pk_decoded.exponent, pk_decoded.modulus); 86 | return decoded_pk; 87 | } 88 | 89 | test "decode public key" { 90 | const pk = 91 | \\-----BEGIN PUBLIC KEY----- 92 | \\MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEA2QurErkXa1sGRr1AV4wJ 93 | \\m7cT0aSDrLsA+PHT8D6yjWhLEOocBzxuK0Z/1ytBAjRH9LtCbyHML81OIIACt03u 94 | \\Y+8xbtFLyOP0NxsLe5FzQ+R4PPQDnubtJeSa4E7jZZEIkAWS11cPo7/wXX3elfeb 95 | \\tzJDEjvFa7VDTcD1jh+0p03k+iPbt9f91+PauD/oCr0RbgL737/UTeN7F5sXCS9F 96 | \\OOPW+bqgdPV08c4Dx4qSxg9WrktRUA9RDxWdetzYyNVc9/+VsKbnCUFQuGCevvWi 97 | \\MHxq6dOI8fy+OYkaNo3UbU+4surE+JVIEdvAkhwVDN3DBBZ6gtpU5PukS4mcpUPt 98 | \\wQIDAQAB 99 | \\-----END PUBLIC KEY----- 100 | ; 101 | 102 | const d = try decodePublicKey(pk, std.testing.allocator); 103 | defer d.deinit(std.testing.allocator); 104 | } 105 | 106 | // https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_connection_phase_authentication_methods_native_password_authentication.html 107 | // SHA1(password) XOR SHA1(scramble ++ SHA1(SHA1(password))) 108 | pub fn scramblePassword(scramble: []const u8, password: []const u8) [20]u8 { 109 | var message1 = blk: { // SHA1(password) 110 | var sha1 = Sha1.init(.{}); 111 | sha1.update(password); 112 | break :blk sha1.finalResult(); 113 | }; 114 | const message2 = blk: { // SHA1(SHA1(password)) 115 | var sha1 = Sha1.init(.{}); 116 | sha1.update(&message1); 117 | var hash = sha1.finalResult(); 118 | 119 | sha1 = Sha1.init(.{}); 120 | sha1.update(scramble); 121 | sha1.update(&hash); 122 | sha1.final(&hash); 123 | break :blk hash; 124 | }; 125 | for (&message1, message2) |*m1, m2| { 126 | m1.* ^= m2; 127 | } 128 | return message1; 129 | } 130 | 131 | test "scramblePassword" { 132 | const scramble: []const u8 = &.{ 133 | 10, 47, 74, 111, 75, 73, 34, 48, 88, 76, 134 | 114, 74, 37, 13, 3, 80, 82, 2, 23, 21, 135 | }; 136 | const tests = [_]struct { 137 | password: []const u8, 138 | expected: [20]u8, 139 | }{ 140 | .{ 141 | .password = "secret", 142 | .expected = .{ 143 | 106, 20, 155, 221, 128, 189, 161, 235, 240, 250, 144 | 43, 210, 207, 46, 151, 23, 254, 204, 52, 187, 145 | }, 146 | }, 147 | .{ 148 | .password = "secret2", 149 | .expected = .{ 150 | 101, 15, 7, 223, 53, 60, 206, 83, 112, 238, 151 | 163, 77, 88, 15, 46, 145, 24, 129, 139, 86, 152 | }, 153 | }, 154 | }; 155 | 156 | for (tests) |t| { 157 | const actual = scramblePassword(scramble, t.password); 158 | try std.testing.expectEqual(t.expected, actual); 159 | } 160 | } 161 | 162 | // https://dev.mysql.com/doc/dev/mysql-server/latest/page_caching_sha2_authentication_exchanges.html 163 | // XOR(SHA256(password), SHA256(SHA256(SHA256(password)), scramble)) 164 | pub fn scrambleSHA256Password(scramble: []const u8, password: []const u8) [32]u8 { 165 | std.debug.assert(password.len > 0); 166 | 167 | var message1 = blk: { // SHA256(password) 168 | var hasher = Sha256.init(.{}); 169 | hasher.update(password); 170 | break :blk hasher.finalResult(); 171 | }; 172 | const message2 = blk: { // SHA256(SHA256(SHA256(password)), scramble) 173 | var sha256 = Sha256.init(.{}); 174 | sha256.update(&message1); 175 | var hash = sha256.finalResult(); 176 | 177 | sha256 = Sha256.init(.{}); 178 | sha256.update(&hash); 179 | sha256.update(scramble); 180 | sha256.final(&hash); 181 | break :blk hash; 182 | }; 183 | for (&message1, message2) |*m1, m2| { 184 | m1.* ^= m2; 185 | } 186 | return message1; 187 | } 188 | 189 | test "scrambleSHA256Password" { 190 | const scramble = [_]u8{ 10, 47, 74, 111, 75, 73, 34, 48, 88, 76, 114, 74, 37, 13, 3, 80, 82, 2, 23, 21 }; 191 | const tests = [_]struct { 192 | password: []const u8, 193 | expected: [32]u8, 194 | }{ 195 | .{ 196 | .password = "secret", 197 | .expected = .{ 244, 144, 231, 111, 102, 217, 216, 102, 101, 206, 84, 217, 140, 120, 208, 172, 254, 47, 176, 176, 139, 66, 61, 168, 7, 20, 72, 115, 211, 11, 49, 44 }, 198 | }, 199 | .{ 200 | .password = "secret2", 201 | .expected = .{ 171, 195, 147, 74, 1, 44, 243, 66, 232, 118, 7, 28, 142, 226, 2, 222, 81, 120, 91, 67, 2, 88, 167, 160, 19, 139, 199, 156, 77, 128, 11, 198 }, 202 | }, 203 | }; 204 | 205 | for (tests) |t| { 206 | const actual = scrambleSHA256Password(&scramble, t.password); 207 | try std.testing.expectEqual(t.expected, actual); 208 | } 209 | } 210 | 211 | // https://mariadb.com/kb/en/sha256_password-plugin/#rsa-encrypted-password 212 | // RSA encrypted value of XOR(password, seed) using server public key (RSA_PKCS1_OAEP_PADDING). 213 | pub fn encryptPassword(allocator: std.mem.Allocator, password: []const u8, seed: *const [20]u8, pk: *const PublicKey) ![]const u8 { 214 | const plain = blk: { 215 | var plain = try allocator.alloc(u8, password.len + 1); 216 | @memcpy(plain.ptr, password); 217 | plain[plain.len - 1] = 0; 218 | break :blk plain; 219 | }; 220 | defer allocator.free(plain); 221 | 222 | for (plain, 0..) |*c, i| { 223 | c.* ^= seed[i % 20]; 224 | } 225 | 226 | return rsaEncryptOAEP(allocator, plain, pk); 227 | } 228 | 229 | fn rsaEncryptOAEP(allocator: std.mem.Allocator, msg: []const u8, pk: *const PublicKey) ![]const u8 { 230 | const init_hash = Sha1.init(.{}); 231 | 232 | const lHash = blk: { 233 | var hash = init_hash; 234 | hash.update(&.{}); 235 | break :blk hash.finalResult(); 236 | }; 237 | const digest_len = lHash.len; 238 | 239 | const k = (pk.n.bits() + 7) / 8; // modulus size in bytes 240 | 241 | var em = try allocator.alloc(u8, k); 242 | defer allocator.free(em); 243 | @memset(em, 0); 244 | const seed = em[1 .. 1 + digest_len]; 245 | const db = em[1 + digest_len ..]; 246 | 247 | @memcpy(db[0..lHash.len], &lHash); 248 | db[db.len - msg.len - 1] = 1; 249 | @memcpy(db[db.len - msg.len ..], msg); 250 | std.crypto.random.bytes(seed); 251 | 252 | mgf1XOR(db, &init_hash, seed); 253 | mgf1XOR(seed, &init_hash, db); 254 | 255 | return encryptMsg(allocator, em, pk); 256 | } 257 | 258 | fn encryptMsg(allocator: std.mem.Allocator, msg: []const u8, pk: *const PublicKey) ![]const u8 { 259 | // can remove this if it's publicly exposed in std.crypto.Certificate.rsa 260 | // for now, just copy it from std.crypto.ff 261 | const max_modulus_bits = 4096; 262 | const Modulus = std.crypto.ff.Modulus(max_modulus_bits); 263 | const Fe = Modulus.Fe; 264 | 265 | const m = try Fe.fromBytes(pk.*.n, msg, .big); 266 | const e = try pk.n.powPublic(m, pk.e); 267 | 268 | const res = try allocator.alloc(u8, msg.len); 269 | try e.toBytes(res, .big); 270 | return res; 271 | } 272 | 273 | // mgf1XOR XORs the bytes in out with a mask generated using the MGF1 function 274 | // specified in PKCS #1 v2.1. 275 | fn mgf1XOR(dest: []u8, init_hash: *const Sha1, seed: []const u8) void { 276 | var counter: [4]u8 = .{ 0, 0, 0, 0 }; 277 | var digest: [Sha1.digest_length]u8 = undefined; 278 | 279 | var done: usize = 0; 280 | while (done < dest.len) : (incCounter(&counter)) { 281 | var hash = init_hash.*; 282 | hash.update(seed); 283 | hash.update(counter[0..4]); 284 | digest = hash.finalResult(); 285 | 286 | for (&digest) |*d| { 287 | if (done >= dest.len) break; 288 | dest[done] ^= d.*; 289 | done += 1; 290 | } 291 | } 292 | } 293 | 294 | // incCounter increments a four byte, big-endian counter. 295 | fn incCounter(c: *[4]u8) void { 296 | inline for (&.{ 3, 2, 1, 0 }) |i| { 297 | c[i], const overflow_bit = @addWithOverflow(c[i], 1); 298 | if (overflow_bit == 0) return; // no overflow, so we're done 299 | } 300 | } 301 | -------------------------------------------------------------------------------- /src/config.zig: -------------------------------------------------------------------------------- 1 | const std = @import("std"); 2 | const constants = @import("./constants.zig"); 3 | 4 | pub const Config = struct { 5 | username: [:0]const u8 = "root", 6 | address: std.net.Address = std.net.Address.initIp4(.{ 127, 0, 0, 1 }, 3306), 7 | password: []const u8 = "", 8 | database: [:0]const u8 = "", 9 | collation: u8 = constants.utf8mb4_general_ci, 10 | 11 | // cfgs from Golang driver 12 | client_found_rows: bool = false, // Return number of matching rows instead of rows changed 13 | ssl: bool = false, 14 | multi_statements: bool = false, 15 | 16 | pub fn capability_flags(config: *const Config) u32 { 17 | // zig fmt: off 18 | var flags: u32 = constants.CLIENT_PROTOCOL_41 19 | | constants.CLIENT_PLUGIN_AUTH 20 | | constants.CLIENT_SECURE_CONNECTION 21 | | constants.CLIENT_DEPRECATE_EOF 22 | // TODO: Support more 23 | ; 24 | // zig fmt: on 25 | if (config.client_found_rows) { 26 | flags |= constants.CLIENT_FOUND_ROWS; 27 | } 28 | if (config.ssl) { 29 | flags |= constants.CLIENT_SSL; 30 | } 31 | if (config.multi_statements) { 32 | flags |= constants.CLIENT_MULTI_STATEMENTS; 33 | } 34 | if (config.database.len > 0) { 35 | flags |= constants.CLIENT_CONNECT_WITH_DB; 36 | } 37 | return flags; 38 | } 39 | }; 40 | -------------------------------------------------------------------------------- /src/conn.zig: -------------------------------------------------------------------------------- 1 | const std = @import("std"); 2 | const Config = @import("./config.zig").Config; 3 | const constants = @import("./constants.zig"); 4 | const auth = @import("./auth.zig"); 5 | const protocol = @import("./protocol.zig"); 6 | const HandshakeV10 = protocol.handshake_v10.HandshakeV10; 7 | const ErrorPacket = protocol.generic_response.ErrorPacket; 8 | const OkPacket = protocol.generic_response.OkPacket; 9 | const HandshakeResponse41 = protocol.handshake_response.HandshakeResponse41; 10 | const QueryRequest = protocol.text_command.QueryRequest; 11 | const prepared_statements = protocol.prepared_statements; 12 | const PrepareRequest = prepared_statements.PrepareRequest; 13 | const ExecuteRequest = prepared_statements.ExecuteRequest; 14 | const packet_writer = protocol.packet_writer; 15 | const Packet = protocol.packet.Packet; 16 | const PacketReader = protocol.packet_reader.PacketReader; 17 | const PacketWriter = protocol.packet_writer.PacketWriter; 18 | const result = @import("./result.zig"); 19 | const QueryResultRows = result.QueryResultRows; 20 | const QueryResult = result.QueryResult; 21 | const PrepareResult = result.PrepareResult; 22 | const PreparedStatement = result.PreparedStatement; 23 | const TextResultRow = result.TextResultRow; 24 | const BinaryResultRow = result.BinaryResultRow; 25 | const ResultMeta = @import("./result_meta.zig").ResultMeta; 26 | 27 | const max_packet_size = 1 << 24 - 1; 28 | 29 | // TODO: make this adjustable during compile time 30 | const buffer_size: usize = 4096; 31 | 32 | pub const Conn = struct { 33 | connected: bool, 34 | stream: std.net.Stream, 35 | reader: PacketReader, 36 | writer: PacketWriter, 37 | capabilities: u32, 38 | sequence_id: u8, 39 | 40 | // Buffer to store metadata of the result set 41 | result_meta: ResultMeta, 42 | 43 | // https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_connection_phase.html 44 | pub fn init(allocator: std.mem.Allocator, config: *const Config) !Conn { 45 | var conn: Conn = blk: { 46 | const stream = switch (config.address.any.family) { 47 | std.posix.AF.INET, std.posix.AF.INET6 => try std.net.tcpConnectToAddress(config.address), 48 | std.posix.AF.UNIX => try std.net.connectUnixSocket(std.mem.span(@as([*:0]const u8, @ptrCast(&config.address.un.path)))), 49 | else => unreachable, 50 | }; 51 | break :blk .{ 52 | .connected = true, 53 | .stream = stream, 54 | .reader = try PacketReader.init(stream, allocator), 55 | .writer = try PacketWriter.init(stream, allocator), 56 | .capabilities = undefined, // not known until we get the first packet 57 | .sequence_id = undefined, // not known until we get the first packet 58 | 59 | .result_meta = ResultMeta.init(allocator), 60 | }; 61 | }; 62 | errdefer conn.deinit(); 63 | 64 | const auth_plugin, const auth_data = blk: { 65 | const packet = try conn.readPacket(); 66 | const handshake_v10 = switch (packet.payload[0]) { 67 | constants.HANDSHAKE_V10 => HandshakeV10.init(&packet), 68 | constants.ERR => return ErrorPacket.initFirst(&packet).asError(), 69 | else => return packet.asError(), 70 | }; 71 | conn.capabilities = handshake_v10.capability_flags() & config.capability_flags(); 72 | 73 | if (conn.capabilities & constants.CLIENT_PROTOCOL_41 == 0) { 74 | std.log.err("protocol older than 4.1 is not supported\n", .{}); 75 | return error.UnsupportedProtocol; 76 | } 77 | 78 | break :blk .{ handshake_v10.get_auth_plugin(), handshake_v10.get_auth_data() }; 79 | }; 80 | 81 | // TODO: TLS handshake if enabled 82 | 83 | // more auth exchange based on auth_method 84 | // https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_connection_phase_authentication_methods.html 85 | switch (auth_plugin) { 86 | .caching_sha2_password => try conn.auth_caching_sha2_password(allocator, &auth_data, config), 87 | .mysql_native_password => try conn.auth_mysql_native_password(&auth_data, config), 88 | .sha256_password => try conn.auth_sha256_password(allocator, &auth_data, config), 89 | else => { 90 | std.log.warn("Unsupported auth plugin: {any}\n", .{auth_plugin}); 91 | return error.UnsupportedAuthPlugin; 92 | }, 93 | } 94 | 95 | return conn; 96 | } 97 | 98 | pub fn deinit(c: *Conn) void { 99 | c.quit() catch |err| { 100 | std.log.err("Failed to quit: {any}\n", .{err}); 101 | }; 102 | c.stream.close(); 103 | c.reader.deinit(); 104 | c.writer.deinit(); 105 | c.result_meta.deinit(); 106 | } 107 | 108 | pub fn ping(c: *Conn) !void { 109 | c.ready(); 110 | try c.writeBytesAsPacket(&[_]u8{constants.COM_PING}); 111 | try c.writer.flush(); 112 | const packet = try c.readPacket(); 113 | 114 | switch (packet.payload[0]) { 115 | constants.OK => _ = OkPacket.init(&packet, c.capabilities), 116 | else => return packet.asError(), 117 | } 118 | } 119 | 120 | // query that doesn't return any rows 121 | pub fn query(c: *Conn, query_string: []const u8) !QueryResult { 122 | c.ready(); 123 | const query_req: QueryRequest = .{ .query = query_string }; 124 | try c.writePacket(query_req); 125 | try c.writer.flush(); 126 | const packet = try c.readPacket(); 127 | return c.queryResult(&packet); 128 | } 129 | 130 | // query that expect rows, even if it returns 0 rows 131 | pub fn queryRows(c: *Conn, query_string: []const u8) !QueryResultRows(TextResultRow) { 132 | c.ready(); 133 | const query_req: QueryRequest = .{ .query = query_string }; 134 | try c.writePacket(query_req); 135 | try c.writer.flush(); 136 | return QueryResultRows(TextResultRow).init(c); 137 | } 138 | 139 | pub fn prepare(c: *Conn, allocator: std.mem.Allocator, query_string: []const u8) !PrepareResult { 140 | c.ready(); 141 | const prepare_request: PrepareRequest = .{ .query = query_string }; 142 | try c.writePacket(prepare_request); 143 | try c.writer.flush(); 144 | return PrepareResult.init(c, allocator); 145 | } 146 | 147 | // execute a prepared statement that doesn't return any rows 148 | pub fn execute(c: *Conn, prep_stmt: *const PreparedStatement, params: anytype) !QueryResult { 149 | c.ready(); 150 | std.debug.assert(prep_stmt.res_cols.len == 0); // execute expects no rows 151 | c.sequence_id = 0; 152 | const execute_request: ExecuteRequest = .{ 153 | .capabilities = c.capabilities, 154 | .prep_stmt = prep_stmt, 155 | }; 156 | try c.writePacketWithParam(execute_request, params); 157 | try c.writer.flush(); 158 | const packet = try c.readPacket(); 159 | return c.queryResult(&packet); 160 | } 161 | 162 | // execute a prepared statement that expect rows, even if it returns 0 rows 163 | pub fn executeRows(c: *Conn, prep_stmt: *const PreparedStatement, params: anytype) !QueryResultRows(BinaryResultRow) { 164 | c.ready(); 165 | std.debug.assert(prep_stmt.res_cols.len > 0); // executeRows expects rows 166 | c.sequence_id = 0; 167 | const execute_request: ExecuteRequest = .{ 168 | .capabilities = c.capabilities, 169 | .prep_stmt = prep_stmt, 170 | }; 171 | try c.writePacketWithParam(execute_request, params); 172 | try c.writer.flush(); 173 | return QueryResultRows(BinaryResultRow).init(c); 174 | } 175 | 176 | fn quit(c: *Conn) !void { 177 | c.ready(); 178 | try c.writeBytesAsPacket(&[_]u8{constants.COM_QUIT}); 179 | try c.writer.flush(); 180 | const packet = c.readPacket() catch |err| switch (err) { 181 | error.UnexpectedEndOfStream => { 182 | c.connected = false; 183 | return; 184 | }, 185 | else => return err, 186 | }; 187 | return packet.asError(); 188 | } 189 | 190 | fn auth_mysql_native_password(c: *Conn, auth_data: *const [20]u8, config: *const Config) !void { 191 | const auth_resp = auth.scramblePassword(auth_data, config.password); 192 | const response = HandshakeResponse41.init(.mysql_native_password, config, if (config.password.len > 0) &auth_resp else &[_]u8{}); 193 | try c.writePacket(response); 194 | try c.writer.flush(); 195 | 196 | const packet = try c.readPacket(); 197 | return switch (packet.payload[0]) { 198 | constants.OK => {}, 199 | else => packet.asError(), 200 | }; 201 | } 202 | 203 | fn auth_sha256_password(c: *Conn, allocator: std.mem.Allocator, auth_data: *const [20]u8, config: *const Config) !void { 204 | // TODO: if there is already a pub key, skip requesting it 205 | const response = HandshakeResponse41.init(.sha256_password, config, &[_]u8{auth.sha256_password_public_key_request}); 206 | try c.writePacket(response); 207 | try c.writer.flush(); 208 | 209 | const pk_packet = try c.readPacket(); 210 | 211 | // Decode public key 212 | const decoded_pk = try auth.decodePublicKey(pk_packet.payload, allocator); 213 | defer decoded_pk.deinit(allocator); 214 | 215 | const enc_pw = try auth.encryptPassword(allocator, config.password, auth_data, &decoded_pk.value); 216 | defer allocator.free(enc_pw); 217 | 218 | try c.writeBytesAsPacket(enc_pw); 219 | try c.writer.flush(); 220 | 221 | const resp_packet = try c.readPacket(); 222 | return switch (resp_packet.payload[0]) { 223 | constants.OK => {}, 224 | else => resp_packet.asError(), 225 | }; 226 | } 227 | 228 | fn auth_caching_sha2_password(c: *Conn, allocator: std.mem.Allocator, auth_data: *const [20]u8, config: *const Config) !void { 229 | const auth_resp = auth.scrambleSHA256Password(auth_data, config.password); 230 | const response = HandshakeResponse41.init(.caching_sha2_password, config, &auth_resp); 231 | try c.writePacket(&response); 232 | try c.writer.flush(); 233 | 234 | while (true) { 235 | const packet = try c.readPacket(); 236 | switch (packet.payload[0]) { 237 | constants.OK => return, 238 | constants.AUTH_MORE_DATA => { 239 | const more_data = packet.payload[1..]; 240 | switch (more_data[0]) { 241 | auth.caching_sha2_password_fast_auth_success => {}, // success (do nothing, wait for next packet) 242 | auth.caching_sha2_password_full_authentication_start => { 243 | // Full Authentication start 244 | 245 | // TODO: support TLS 246 | // // if TLS, send password as plain text 247 | // try conn.sendBytesAsPacket(config.password); 248 | 249 | try c.writeBytesAsPacket(&[_]u8{auth.caching_sha2_password_public_key_request}); 250 | try c.writer.flush(); 251 | const pk_packet = try c.readPacket(); 252 | 253 | // Decode public key 254 | const decoded_pk = try auth.decodePublicKey(pk_packet.payload, allocator); 255 | defer decoded_pk.deinit(allocator); 256 | 257 | // Encrypt password 258 | const enc_pw = try auth.encryptPassword(allocator, config.password, auth_data, &decoded_pk.value); 259 | defer allocator.free(enc_pw); 260 | 261 | try c.writeBytesAsPacket(enc_pw); 262 | try c.writer.flush(); 263 | }, 264 | else => return error.UnsupportedCachingSha2PasswordMoreData, 265 | } 266 | }, 267 | else => return packet.asError(), 268 | } 269 | } 270 | } 271 | 272 | pub inline fn readPacket(c: *Conn) !Packet { 273 | const packet = try c.reader.readPacket(); 274 | c.sequence_id = packet.sequence_id + 1; 275 | return packet; 276 | } 277 | 278 | pub inline fn readPutResultColumns(c: *Conn, n: usize) !void { 279 | try c.result_meta.readPutResultColumns(c, n); 280 | } 281 | 282 | inline fn writePacket(c: *Conn, packet: anytype) !void { 283 | try c.writer.writePacket(c.generateSequenceId(), packet); 284 | } 285 | 286 | inline fn writePacketWithParam(c: *Conn, packet: anytype, params: anytype) !void { 287 | try c.writer.writePacketWithParams(c.generateSequenceId(), packet, params); 288 | } 289 | 290 | inline fn writeBytesAsPacket(c: *Conn, packet: anytype) !void { 291 | try c.writer.writeBytesAsPacket(c.generateSequenceId(), packet); 292 | } 293 | 294 | inline fn generateSequenceId(c: *Conn) u8 { 295 | const sequence_id = c.sequence_id; 296 | c.sequence_id += 1; 297 | return sequence_id; 298 | } 299 | 300 | inline fn queryResult(c: *Conn, packet: *const Packet) !QueryResult { 301 | const res = QueryResult.init(packet, c.capabilities) catch |err| { 302 | switch (err) { 303 | error.UnrecoverableError => { 304 | c.stream.close(); 305 | c.connected = false; 306 | return err; 307 | }, 308 | else => return err, 309 | } 310 | }; 311 | return res; 312 | } 313 | 314 | inline fn ready(c: *Conn) void { 315 | std.debug.assert(c.connected); 316 | std.debug.assert(c.writer.pos == 0); 317 | std.debug.assert(c.reader.pos == c.reader.len); 318 | c.sequence_id = 0; 319 | } 320 | }; 321 | -------------------------------------------------------------------------------- /src/constants.zig: -------------------------------------------------------------------------------- 1 | // zig fmt: off 2 | const std = @import("std"); 3 | 4 | // MySQL Packet Header 5 | pub const OK: u8 = 0x00; 6 | pub const EOF: u8 = 0xFE; 7 | pub const AUTH_SWITCH: u8 = 0xFE; 8 | pub const AUTH_MORE_DATA: u8 = 0x01; 9 | pub const ERR: u8 = 0xFF; 10 | pub const HANDSHAKE_V10: u8 = 0x0A; 11 | pub const LOCAL_INFILE_REQUEST: u8 = 0xFB; 12 | 13 | // Query Result 14 | pub const TEXT_RESULT_ROW_NULL: u8 = 0xFB; 15 | 16 | // https://dev.mysql.com/doc/dev/mysql-server/latest/mysql__com_8h.html 17 | pub const SERVER_STATUS_IN_TRANS: u16 = 1 << 0; 18 | pub const SERVER_STATUS_AUTOCOMMIT: u16 = 1 << 1; 19 | pub const SERVER_MORE_RESULTS_EXISTS: u16 = 1 << 2; 20 | pub const SERVER_QUERY_NO_GOOD_INDEX_USED: u16 = 1 << 3; 21 | pub const SERVER_STATUS_CURSOR_EXISTS: u16 = 1 << 4; 22 | pub const SERVER_STATUS_LAST_ROW_SENT: u16 = 1 << 5; 23 | pub const SERVER_STATUS_DB_DROPPED: u16 = 1 << 6; 24 | pub const SERVER_STATUS_NO_BACKSLASH_ESCAPES: u16 = 1 << 7; 25 | pub const SERVER_QUERY_WAS_SLOW: u16 = 1 << 8; 26 | pub const SERVER_STATUS_IN_TRANS_READONLY: u16 = 1 << 9; 27 | pub const SERVER_SESSION_STATE_CHANGED: u16 = 1 << 10; 28 | 29 | // https://dev.mysql.com/doc/dev/mysql-server/latest/group__group__cs__capabilities__flags.html 30 | pub const CLIENT_LONG_PASSWORD: u32 = 1; 31 | pub const CLIENT_FOUND_ROWS: u32 = 2; 32 | pub const CLIENT_LONG_FLAG: u32 = 4; 33 | pub const CLIENT_CONNECT_WITH_DB: u32 = 8; 34 | pub const CLIENT_NO_SCHEMA: u32 = 16; 35 | pub const CLIENT_COMPRESS: u32 = 32; 36 | pub const CLIENT_ODBC: u32 = 64; 37 | pub const CLIENT_LOCAL_FILES: u32 = 128; 38 | pub const CLIENT_IGNORE_SPACE: u32 = 256; 39 | pub const CLIENT_PROTOCOL_41: u32 = 512; 40 | pub const CLIENT_INTERACTIVE: u32 = 1024; 41 | pub const CLIENT_SSL: u32 = 2048; 42 | pub const CLIENT_IGNORE_SIGPIPE: u32 = 4096; 43 | pub const CLIENT_TRANSACTIONS: u32 = 8192; 44 | pub const CLIENT_RESERVED: u32 = 16384; 45 | 46 | pub const CLIENT_SECURE_CONNECTION: u32 = 32768; // Appears deprecated in MySQL but still used in MariaDB 47 | pub const CLIENT_MULTI_STATEMENTS: u32 = 1 << 16; 48 | pub const CLIENT_MULTI_RESULTS: u32 = 1 << 17; 49 | pub const CLIENT_PS_MULTI_RESULTS: u32 = 1 << 18; 50 | pub const CLIENT_PLUGIN_AUTH: u32 = 1 << 19; 51 | pub const CLIENT_CONNECT_ATTRS: u32 = 1 << 20; 52 | pub const CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA: u32 = 1 << 21; 53 | pub const CLIENT_CAN_HANDLE_EXPIRED_PASSWORDS: u32 = 1 << 22; 54 | pub const CLIENT_SESSION_TRACK: u32 = 1 << 23; 55 | pub const CLIENT_DEPRECATE_EOF: u32 = 1 << 24; 56 | pub const CLIENT_OPTIONAL_RESULTSET_METADATA: u32 = 1 << 25; 57 | pub const CLIENT_ZSTD_COMPRESSION_ALGORITHM: u32 = 1 << 26; 58 | pub const CLIENT_QUERY_ATTRIBUTES: u32 = 1 << 27; 59 | pub const CTOR_AUTHENTICATION: u32 = 1 << 28; 60 | pub const CLIENT_CAPABILITY_EXTENSION: u32 = 1 << 29; 61 | pub const CLIENT_SSL_VERIFY_SERVER_CERT: u32 = 1 << 30; 62 | 63 | // https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_com_query.html 64 | pub const COM_QUERY: u8 = 0x03; 65 | 66 | // https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_command_phase_ps.html 67 | pub const COM_STMT_PREPARE: u8 = 0x16; 68 | pub const COM_STMT_EXECUTE: u8 = 0x17; 69 | 70 | pub const BINARY_PROTOCOL_RESULTSET_ROW_HEADER: u8 = 0x00; 71 | 72 | // https://dev.mysql.com/doc/dev/mysql-server/latest/field__types_8h_source.html 73 | pub const EnumFieldType = enum(u8) { 74 | MYSQL_TYPE_DECIMAL, 75 | MYSQL_TYPE_TINY, 76 | MYSQL_TYPE_SHORT, 77 | MYSQL_TYPE_LONG, 78 | MYSQL_TYPE_FLOAT, 79 | MYSQL_TYPE_DOUBLE, 80 | MYSQL_TYPE_NULL, 81 | MYSQL_TYPE_TIMESTAMP, 82 | MYSQL_TYPE_LONGLONG, 83 | MYSQL_TYPE_INT24, 84 | MYSQL_TYPE_DATE, 85 | MYSQL_TYPE_TIME, 86 | MYSQL_TYPE_DATETIME, 87 | MYSQL_TYPE_YEAR, 88 | MYSQL_TYPE_NEWDATE, 89 | MYSQL_TYPE_VARCHAR, 90 | MYSQL_TYPE_BIT, 91 | MYSQL_TYPE_TIMESTAMP2, 92 | MYSQL_TYPE_DATETIME2, 93 | MYSQL_TYPE_TIME2, 94 | MYSQL_TYPE_TYPED_ARRAY, 95 | 96 | MYSQL_TYPE_INVALID = 243, 97 | MYSQL_TYPE_BOOL = 244, 98 | MYSQL_TYPE_JSON = 245, 99 | MYSQL_TYPE_NEWDECIMAL = 246, 100 | MYSQL_TYPE_ENUM = 247, 101 | MYSQL_TYPE_SET = 248, 102 | MYSQL_TYPE_TINY_BLOB = 249, 103 | MYSQL_TYPE_MEDIUM_BLOB = 250, 104 | MYSQL_TYPE_LONG_BLOB = 251, 105 | MYSQL_TYPE_BLOB = 252, 106 | MYSQL_TYPE_VAR_STRING = 253, 107 | MYSQL_TYPE_STRING = 254, 108 | MYSQL_TYPE_GEOMETRY = 255 109 | }; 110 | 111 | 112 | // https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_command_phase_utility.html 113 | pub const COM_QUIT: u8 = 0x01; 114 | pub const COM_INIT_DB: u8 = 0x02; 115 | pub const COM_FIELD_LIST: u8 = 0x04; 116 | pub const COM_REFRESH: u8 = 0x07; 117 | pub const COM_STATISTICS: u8 = 0x08; 118 | pub const COM_PROCESS_INFO: u8 = 0x0a; 119 | pub const COM_PROCESS_KILL: u8 = 0x0c; 120 | pub const COM_DEBUG: u8 = 0x0d; 121 | pub const COM_PING: u8 = 0x0e; 122 | pub const COM_CHANGE_USER: u8 = 0x11; 123 | pub const COM_RESET_CONNECTION: u8 = 0x1f; 124 | pub const COM_SET_OPTION: u8 = 0x1a; 125 | 126 | // https://dev.mysql.com/doc/dev/mysql-server/latest/group__group__cs__column__definition__flags.html 127 | pub const NOT_NULL_FLAG: u16 = 1; 128 | pub const PRI_KEY_FLAG: u16 = 2; 129 | pub const UNIQUE_KEY_FLAG: u16 = 4; 130 | pub const MULTIPLE_KEY_FLAG: u16 = 8; 131 | pub const BLOB_FLAG: u16 = 16; 132 | pub const UNSIGNED_FLAG: u16 = 32; 133 | pub const ZEROFILL_FLAG: u16 = 64; 134 | pub const BINARY_FLAG: u16 = 128; 135 | pub const ENUM_FLAG: u16 = 256; 136 | pub const AUTO_INCREMENT_FLAG: u16 = 512; 137 | pub const TIMESTAMP_FLAG: u16 = 1024; 138 | pub const SET_FLAG: u16 = 2048; 139 | pub const NO_DEFAULT_VALUE_FLAG: u16 = 4096; 140 | pub const ON_UPDATE_NOW_FLAG: u16 = 8192; 141 | pub const NUM_FLAG: u16 = 32768; 142 | 143 | pub const PART_KEY_FLAG: u16 = 16384; 144 | pub const GROUP_FLAG: u16 = 32768; 145 | pub const UNIQUE_FLAG: u32 = 65536; 146 | pub const BINCMP_FLAG: u32 = 131072; 147 | pub const GET_FIXED_FIELDS_FLAG: u32 = (1 << 18); 148 | pub const FIELD_IN_PART_FUNC_FLAG: u32 = (1 << 19); 149 | pub const FIELD_IN_ADD_INDEX: u32 = (1 << 20); 150 | pub const FIELD_IS_RENAMED: u32 = (1 << 21); 151 | pub const FIELD_FLAGS_STORAGE_MEDIA: u32 = 22; 152 | pub const FIELD_FLAGS_STORAGE_MEDIA_MASK: u32 = (3 << FIELD_FLAGS_STORAGE_MEDIA); 153 | pub const FIELD_FLAGS_COLUMN_FORMAT: u32 = 24; 154 | pub const FIELD_FLAGS_COLUMN_FORMAT_MASK: u32 = (3 << FIELD_FLAGS_COLUMN_FORMAT); 155 | pub const FIELD_IS_DROPPED: u32 = (1 << 26); 156 | pub const EXPLICIT_NULL_FLAG: u32 = (1 << 27); 157 | pub const NOT_SECONDARY_FLAG: u32 = (1 << 29); 158 | pub const FIELD_IS_INVISIBLE: u32 = (1 << 30); 159 | 160 | // Derive from: 161 | // SELECT COLLATION_NAME, ID FROM information_schema.COLLATIONS WHERE ID < 256 ORDER BY ID 162 | pub const big5_chinese_ci: u8 = 1; 163 | pub const latin2_czech_cs: u8 = 2; 164 | pub const dec8_swedish_ci: u8 = 3; 165 | pub const cp850_general_ci: u8 = 4; 166 | pub const latin1_german1_ci: u8 = 5; 167 | pub const hp8_english_ci: u8 = 6; 168 | pub const koi8r_general_ci: u8 = 7; 169 | pub const latin1_swedish_ci: u8 = 8; 170 | pub const latin2_general_ci: u8 = 9; 171 | pub const swe7_swedish_ci: u8 = 10; 172 | pub const ascii_general_ci: u8 = 11; 173 | pub const ujis_japanese_ci: u8 = 12; 174 | pub const sjis_japanese_ci: u8 = 13; 175 | pub const cp1251_bulgarian_ci: u8 = 14; 176 | pub const latin1_danish_ci: u8 = 15; 177 | pub const hebrew_general_ci: u8 = 16; 178 | pub const tis620_thai_ci: u8 = 18; 179 | pub const euckr_korean_ci: u8 = 19; 180 | pub const latin7_estonian_cs: u8 = 20; 181 | pub const latin2_hungarian_ci: u8 = 21; 182 | pub const koi8u_general_ci: u8 = 22; 183 | pub const cp1251_ukrainian_ci: u8 = 23; 184 | pub const gb2312_chinese_ci: u8 = 24; 185 | pub const greek_general_ci: u8 = 25; 186 | pub const cp1250_general_ci: u8 = 26; 187 | pub const latin2_croatian_ci: u8 = 27; 188 | pub const gbk_chinese_ci: u8 = 28; 189 | pub const cp1257_lithuanian_ci: u8 = 29; 190 | pub const latin5_turkish_ci: u8 = 30; 191 | pub const latin1_german2_ci: u8 = 31; 192 | pub const armscii8_general_ci: u8 = 32; 193 | pub const utf8mb3_general_ci: u8 = 33; 194 | pub const cp1250_czech_cs: u8 = 34; 195 | pub const ucs2_general_ci: u8 = 35; 196 | pub const cp866_general_ci: u8 = 36; 197 | pub const keybcs2_general_ci: u8 = 37; 198 | pub const macce_general_ci: u8 = 38; 199 | pub const macroman_general_ci: u8 = 39; 200 | pub const cp852_general_ci: u8 = 40; 201 | pub const latin7_general_ci: u8 = 41; 202 | pub const latin7_general_cs: u8 = 42; 203 | pub const macce_bin: u8 = 43; 204 | pub const cp1250_croatian_ci: u8 = 44; 205 | pub const utf8mb4_general_ci: u8 = 45; 206 | pub const utf8mb4_bin: u8 = 46; 207 | pub const latin1_bin: u8 = 47; 208 | pub const latin1_general_ci: u8 = 48; 209 | pub const latin1_general_cs: u8 = 49; 210 | pub const cp1251_bin: u8 = 50; 211 | pub const cp1251_general_ci: u8 = 51; 212 | pub const cp1251_general_cs: u8 = 52; 213 | pub const macroman_bin: u8 = 53; 214 | pub const utf16_general_ci: u8 = 54; 215 | pub const utf16_bin: u8 = 55; 216 | pub const utf16le_general_ci: u8 = 56; 217 | pub const cp1256_general_ci: u8 = 57; 218 | pub const cp1257_bin: u8 = 58; 219 | pub const cp1257_general_ci: u8 = 59; 220 | pub const utf32_general_ci: u8 = 60; 221 | pub const utf32_bin: u8 = 61; 222 | pub const utf16le_bin: u8 = 62; 223 | pub const binary: u8 = 63; 224 | pub const armscii8_bin: u8 = 64; 225 | pub const ascii_bin: u8 = 65; 226 | pub const cp1250_bin: u8 = 66; 227 | pub const cp1256_bin: u8 = 67; 228 | pub const cp866_bin: u8 = 68; 229 | pub const dec8_bin: u8 = 69; 230 | pub const greek_bin: u8 = 70; 231 | pub const hebrew_bin: u8 = 71; 232 | pub const hp8_bin: u8 = 72; 233 | pub const keybcs2_bin: u8 = 73; 234 | pub const koi8r_bin: u8 = 74; 235 | pub const koi8u_bin: u8 = 75; 236 | pub const utf8mb3_tolower_ci: u8 = 76; 237 | pub const latin2_bin: u8 = 77; 238 | pub const latin5_bin: u8 = 78; 239 | pub const latin7_bin: u8 = 79; 240 | pub const cp850_bin: u8 = 80; 241 | pub const cp852_bin: u8 = 81; 242 | pub const swe7_bin: u8 = 82; 243 | pub const utf8mb3_bin: u8 = 83; 244 | pub const big5_bin: u8 = 84; 245 | pub const euckr_bin: u8 = 85; 246 | pub const gb2312_bin: u8 = 86; 247 | pub const gbk_bin: u8 = 87; 248 | pub const sjis_bin: u8 = 88; 249 | pub const tis620_bin: u8 = 89; 250 | pub const ucs2_bin: u8 = 90; 251 | pub const ujis_bin: u8 = 91; 252 | pub const geostd8_general_ci: u8 = 92; 253 | pub const geostd8_bin: u8 = 93; 254 | pub const latin1_spanish_ci: u8 = 94; 255 | pub const cp932_japanese_ci: u8 = 95; 256 | pub const cp932_bin: u8 = 96; 257 | pub const eucjpms_japanese_ci: u8 = 97; 258 | pub const eucjpms_bin: u8 = 98; 259 | pub const cp1250_polish_ci: u8 = 99; 260 | pub const utf16_unicode_ci: u8 = 101; 261 | pub const utf16_icelandic_ci: u8 = 102; 262 | pub const utf16_latvian_ci: u8 = 103; 263 | pub const utf16_romanian_ci: u8 = 104; 264 | pub const utf16_slovenian_ci: u8 = 105; 265 | pub const utf16_polish_ci: u8 = 106; 266 | pub const utf16_estonian_ci: u8 = 107; 267 | pub const utf16_spanish_ci: u8 = 108; 268 | pub const utf16_swedish_ci: u8 = 109; 269 | pub const utf16_turkish_ci: u8 = 110; 270 | pub const utf16_czech_ci: u8 = 111; 271 | pub const utf16_danish_ci: u8 = 112; 272 | pub const utf16_lithuanian_ci: u8 = 113; 273 | pub const utf16_slovak_ci: u8 = 114; 274 | pub const utf16_spanish2_ci: u8 = 115; 275 | pub const utf16_roman_ci: u8 = 116; 276 | pub const utf16_persian_ci: u8 = 117; 277 | pub const utf16_esperanto_ci: u8 = 118; 278 | pub const utf16_hungarian_ci: u8 = 119; 279 | pub const utf16_sinhala_ci: u8 = 120; 280 | pub const utf16_german2_ci: u8 = 121; 281 | pub const utf16_croatian_ci: u8 = 122; 282 | pub const utf16_unicode_520_ci: u8 = 123; 283 | pub const utf16_vietnamese_ci: u8 = 124; 284 | pub const ucs2_unicode_ci: u8 = 128; 285 | pub const ucs2_icelandic_ci: u8 = 129; 286 | pub const ucs2_latvian_ci: u8 = 130; 287 | pub const ucs2_romanian_ci: u8 = 131; 288 | pub const ucs2_slovenian_ci: u8 = 132; 289 | pub const ucs2_polish_ci: u8 = 133; 290 | pub const ucs2_estonian_ci: u8 = 134; 291 | pub const ucs2_spanish_ci: u8 = 135; 292 | pub const ucs2_swedish_ci: u8 = 136; 293 | pub const ucs2_turkish_ci: u8 = 137; 294 | pub const ucs2_czech_ci: u8 = 138; 295 | pub const ucs2_danish_ci: u8 = 139; 296 | pub const ucs2_lithuanian_ci: u8 = 140; 297 | pub const ucs2_slovak_ci: u8 = 141; 298 | pub const ucs2_spanish2_ci: u8 = 142; 299 | pub const ucs2_roman_ci: u8 = 143; 300 | pub const ucs2_persian_ci: u8 = 144; 301 | pub const ucs2_esperanto_ci: u8 = 145; 302 | pub const ucs2_hungarian_ci: u8 = 146; 303 | pub const ucs2_sinhala_ci: u8 = 147; 304 | pub const ucs2_german2_ci: u8 = 148; 305 | pub const ucs2_croatian_ci: u8 = 149; 306 | pub const ucs2_unicode_520_ci: u8 = 150; 307 | pub const ucs2_vietnamese_ci: u8 = 151; 308 | pub const ucs2_general_mysql500_ci: u8 = 159; 309 | pub const utf32_unicode_ci: u8 = 160; 310 | pub const utf32_icelandic_ci: u8 = 161; 311 | pub const utf32_latvian_ci: u8 = 162; 312 | pub const utf32_romanian_ci: u8 = 163; 313 | pub const utf32_slovenian_ci: u8 = 164; 314 | pub const utf32_polish_ci: u8 = 165; 315 | pub const utf32_estonian_ci: u8 = 166; 316 | pub const utf32_spanish_ci: u8 = 167; 317 | pub const utf32_swedish_ci: u8 = 168; 318 | pub const utf32_turkish_ci: u8 = 169; 319 | pub const utf32_czech_ci: u8 = 170; 320 | pub const utf32_danish_ci: u8 = 171; 321 | pub const utf32_lithuanian_ci: u8 = 172; 322 | pub const utf32_slovak_ci: u8 = 173; 323 | pub const utf32_spanish2_ci: u8 = 174; 324 | pub const utf32_roman_ci: u8 = 175; 325 | pub const utf32_persian_ci: u8 = 176; 326 | pub const utf32_esperanto_ci: u8 = 177; 327 | pub const utf32_hungarian_ci: u8 = 178; 328 | pub const utf32_sinhala_ci: u8 = 179; 329 | pub const utf32_german2_ci: u8 = 180; 330 | pub const utf32_croatian_ci: u8 = 181; 331 | pub const utf32_unicode_520_ci: u8 = 182; 332 | pub const utf32_vietnamese_ci: u8 = 183; 333 | pub const utf8mb3_unicode_ci: u8 = 192; 334 | pub const utf8mb3_icelandic_ci: u8 = 193; 335 | pub const utf8mb3_latvian_ci: u8 = 194; 336 | pub const utf8mb3_romanian_ci: u8 = 195; 337 | pub const utf8mb3_slovenian_ci: u8 = 196; 338 | pub const utf8mb3_polish_ci: u8 = 197; 339 | pub const utf8mb3_estonian_ci: u8 = 198; 340 | pub const utf8mb3_spanish_ci: u8 = 199; 341 | pub const utf8mb3_swedish_ci: u8 = 200; 342 | pub const utf8mb3_turkish_ci: u8 = 201; 343 | pub const utf8mb3_czech_ci: u8 = 202; 344 | pub const utf8mb3_danish_ci: u8 = 203; 345 | pub const utf8mb3_lithuanian_ci: u8 = 204; 346 | pub const utf8mb3_slovak_ci: u8 = 205; 347 | pub const utf8mb3_spanish2_ci: u8 = 206; 348 | pub const utf8mb3_roman_ci: u8 = 207; 349 | pub const utf8mb3_persian_ci: u8 = 208; 350 | pub const utf8mb3_esperanto_ci: u8 = 209; 351 | pub const utf8mb3_hungarian_ci: u8 = 210; 352 | pub const utf8mb3_sinhala_ci: u8 = 211; 353 | pub const utf8mb3_german2_ci: u8 = 212; 354 | pub const utf8mb3_croatian_ci: u8 = 213; 355 | pub const utf8mb3_unicode_520_ci: u8 = 214; 356 | pub const utf8mb3_vietnamese_ci: u8 = 215; 357 | pub const utf8mb3_general_mysql500_ci: u8 = 223; 358 | pub const utf8mb4_unicode_ci: u8 = 224; 359 | pub const utf8mb4_icelandic_ci: u8 = 225; 360 | pub const utf8mb4_latvian_ci: u8 = 226; 361 | pub const utf8mb4_romanian_ci: u8 = 227; 362 | pub const utf8mb4_slovenian_ci: u8 = 228; 363 | pub const utf8mb4_polish_ci: u8 = 229; 364 | pub const utf8mb4_estonian_ci: u8 = 230; 365 | pub const utf8mb4_spanish_ci: u8 = 231; 366 | pub const utf8mb4_swedish_ci: u8 = 232; 367 | pub const utf8mb4_turkish_ci: u8 = 233; 368 | pub const utf8mb4_czech_ci: u8 = 234; 369 | pub const utf8mb4_danish_ci: u8 = 235; 370 | pub const utf8mb4_lithuanian_ci: u8 = 236; 371 | pub const utf8mb4_slovak_ci: u8 = 237; 372 | pub const utf8mb4_spanish2_ci: u8 = 238; 373 | pub const utf8mb4_roman_ci: u8 = 239; 374 | pub const utf8mb4_persian_ci: u8 = 240; 375 | pub const utf8mb4_esperanto_ci: u8 = 241; 376 | pub const utf8mb4_hungarian_ci: u8 = 242; 377 | pub const utf8mb4_sinhala_ci: u8 = 243; 378 | pub const utf8mb4_german2_ci: u8 = 244; 379 | pub const utf8mb4_croatian_ci: u8 = 245; 380 | pub const utf8mb4_unicode_520_ci: u8 = 246; 381 | pub const utf8mb4_vietnamese_ci: u8 = 247; 382 | pub const gb18030_chinese_ci: u8 = 248; 383 | pub const gb18030_bin: u8 = 249; 384 | pub const gb18030_unicode_520_ci: u8 = 250; 385 | pub const utf8mb4_0900_ai_ci: u8 = 255; 386 | -------------------------------------------------------------------------------- /src/conversion.zig: -------------------------------------------------------------------------------- 1 | const std = @import("std"); 2 | const PayloadReader = @import("./protocol/packet.zig").PayloadReader; 3 | const ColumnDefinition41 = @import("./protocol/column_definition.zig").ColumnDefinition41; 4 | const constants = @import("./constants.zig"); 5 | const DateTime = @import("./temporal.zig").DateTime; 6 | const Duration = @import("./temporal.zig").Duration; 7 | const EnumFieldType = @import("./constants.zig").EnumFieldType; 8 | const Packet = @import("./protocol/packet.zig").Packet; 9 | 10 | // dest is a pointer to a struct 11 | pub fn scanBinResultRow(dest: anytype, packet: *const Packet, col_defs: []const ColumnDefinition41, allocator: ?std.mem.Allocator) !void { 12 | var reader = packet.reader(); 13 | const first = reader.readByte(); 14 | std.debug.assert(first == constants.BINARY_PROTOCOL_RESULTSET_ROW_HEADER); 15 | 16 | // null bitmap 17 | const null_bitmap_len = (col_defs.len + 7 + 2) / 8; 18 | const null_bitmap = reader.readRefRuntime(null_bitmap_len); 19 | 20 | const child_type = @typeInfo(@TypeOf(dest)).pointer.child; 21 | const struct_fields = @typeInfo(child_type).@"struct".fields; 22 | 23 | if (struct_fields.len != col_defs.len) { 24 | std.log.err("received {d} columns from mysql, but given {d} fields for struct", .{ struct_fields.len, col_defs.len }); 25 | return error.ColumnAndFieldCountMismatch; 26 | } 27 | 28 | inline for (struct_fields, col_defs, 0..) |field, col_def, i| { 29 | const field_info = @typeInfo(field.type); 30 | const isNull = binResIsNull(null_bitmap, i); 31 | 32 | switch (field_info) { 33 | .optional => { 34 | if (isNull) { 35 | @field(dest, field.name) = null; 36 | } else { 37 | @field(dest, field.name) = try binElemToValue(field_info.optional.child, field.name, &col_def, &reader, allocator); 38 | } 39 | }, 40 | else => { 41 | if (isNull) { 42 | std.log.err("column: {s} value is null, but field: {s} is not nullable\n", .{ col_def.name, field.name }); 43 | return error.UnexpectedNullMySQLValue; 44 | } 45 | @field(dest, field.name) = try binElemToValue(field.type, field.name, &col_def, &reader, allocator); 46 | }, 47 | } 48 | } 49 | std.debug.assert(reader.finished()); 50 | } 51 | 52 | fn decodeDateTime(reader: *PayloadReader) DateTime { 53 | const length = reader.readByte(); 54 | switch (length) { 55 | 11 => return .{ 56 | .year = reader.readInt(u16), 57 | .month = reader.readByte(), 58 | .day = reader.readByte(), 59 | .hour = reader.readByte(), 60 | .minute = reader.readByte(), 61 | .second = reader.readByte(), 62 | .microsecond = reader.readInt(u32), 63 | }, 64 | 7 => return .{ 65 | .year = reader.readInt(u16), 66 | .month = reader.readByte(), 67 | .day = reader.readByte(), 68 | .hour = reader.readByte(), 69 | .minute = reader.readByte(), 70 | .second = reader.readByte(), 71 | }, 72 | 4 => return .{ 73 | .year = reader.readInt(u16), 74 | .month = reader.readByte(), 75 | .day = reader.readByte(), 76 | }, 77 | 0 => return .{}, 78 | else => unreachable, 79 | } 80 | } 81 | 82 | fn decodeDuration(reader: *PayloadReader) Duration { 83 | const length = reader.readByte(); 84 | switch (length) { 85 | 12 => return .{ 86 | .is_negative = reader.readByte(), 87 | .days = reader.readInt(u32), 88 | .hours = reader.readByte(), 89 | .minutes = reader.readByte(), 90 | .seconds = reader.readByte(), 91 | .microseconds = reader.readInt(u32), 92 | }, 93 | 8 => return .{ 94 | .is_negative = reader.readByte(), 95 | .days = reader.readInt(u32), 96 | .hours = reader.readByte(), 97 | .minutes = reader.readByte(), 98 | .seconds = reader.readByte(), 99 | }, 100 | 0 => return .{}, 101 | else => { 102 | unreachable; 103 | }, 104 | } 105 | } 106 | 107 | inline fn logConversionError( 108 | comptime FieldType: type, 109 | comptime field_name: []const u8, 110 | col_name: []const u8, 111 | col_type: EnumFieldType, 112 | ) void { 113 | std.log.err( 114 | "Conversion Error: MySQL Column: (name: {s}, type: {any}), Zig Value: (name: {s}, type: {any})\n", 115 | .{ col_name, col_type, field_name, FieldType }, 116 | ); 117 | } 118 | 119 | inline fn binElemToValue( 120 | comptime FieldType: type, 121 | comptime field_name: []const u8, 122 | col_def: *const ColumnDefinition41, 123 | reader: *PayloadReader, 124 | allocator: ?std.mem.Allocator, 125 | ) !FieldType { 126 | const field_info = @typeInfo(FieldType); 127 | const col_type: EnumFieldType = @enumFromInt(col_def.column_type); 128 | 129 | switch (FieldType) { 130 | DateTime => { 131 | switch (col_type) { 132 | .MYSQL_TYPE_DATE, 133 | .MYSQL_TYPE_DATETIME, 134 | .MYSQL_TYPE_TIMESTAMP, 135 | => return decodeDateTime(reader), 136 | else => {}, 137 | } 138 | }, 139 | Duration => { 140 | switch (col_type) { 141 | .MYSQL_TYPE_TIME => return decodeDuration(reader), 142 | else => {}, 143 | } 144 | }, 145 | else => {}, 146 | } 147 | 148 | switch (field_info) { 149 | .pointer => |pointer| { 150 | switch (@typeInfo(pointer.child)) { 151 | .int => |int| { 152 | if (int.bits == 8) { 153 | switch (col_type) { 154 | .MYSQL_TYPE_STRING, 155 | .MYSQL_TYPE_VARCHAR, 156 | .MYSQL_TYPE_VAR_STRING, 157 | .MYSQL_TYPE_ENUM, 158 | .MYSQL_TYPE_SET, 159 | .MYSQL_TYPE_LONG_BLOB, 160 | .MYSQL_TYPE_MEDIUM_BLOB, 161 | .MYSQL_TYPE_BLOB, 162 | .MYSQL_TYPE_TINY_BLOB, 163 | .MYSQL_TYPE_GEOMETRY, 164 | .MYSQL_TYPE_BIT, 165 | .MYSQL_TYPE_DECIMAL, 166 | .MYSQL_TYPE_NEWDECIMAL, 167 | => { 168 | const str = reader.readLengthEncodedString(); 169 | if (allocator) |a| { 170 | if (pointer.sentinel()) |_| { 171 | return try a.dupeZ(u8, str); 172 | } 173 | return try a.dupe(u8, str); 174 | } 175 | return str; 176 | }, 177 | else => {}, 178 | } 179 | } 180 | }, 181 | else => {}, 182 | } 183 | }, 184 | .@"struct" => |s| { 185 | inline for (s.fields) |field| { 186 | if (std.mem.eql(u8, field.name, "buffer")) { 187 | const info = @typeInfo(field.type); 188 | switch (info) { 189 | .array => |array| { 190 | if (FieldType == std.BoundedArray(u8, array.len)) { 191 | switch (col_type) { 192 | .MYSQL_TYPE_STRING, 193 | .MYSQL_TYPE_VARCHAR, 194 | .MYSQL_TYPE_VAR_STRING, 195 | .MYSQL_TYPE_ENUM, 196 | .MYSQL_TYPE_SET, 197 | .MYSQL_TYPE_LONG_BLOB, 198 | .MYSQL_TYPE_MEDIUM_BLOB, 199 | .MYSQL_TYPE_BLOB, 200 | .MYSQL_TYPE_TINY_BLOB, 201 | .MYSQL_TYPE_GEOMETRY, 202 | .MYSQL_TYPE_BIT, 203 | .MYSQL_TYPE_DECIMAL, 204 | .MYSQL_TYPE_NEWDECIMAL, 205 | => { 206 | const str = reader.readLengthEncodedString(); 207 | var ret = try std.BoundedArray(u8, array.len).init(0); 208 | try ret.appendSlice(str[0..@min(str.len, array.len)]); 209 | return ret; 210 | }, 211 | else => {}, 212 | } 213 | } 214 | }, 215 | else => {}, 216 | } 217 | } 218 | break; 219 | } 220 | }, 221 | 222 | .@"enum" => |e| { 223 | switch (col_type) { 224 | .MYSQL_TYPE_STRING, 225 | .MYSQL_TYPE_VARCHAR, 226 | .MYSQL_TYPE_VAR_STRING, 227 | .MYSQL_TYPE_ENUM, 228 | .MYSQL_TYPE_SET, 229 | .MYSQL_TYPE_LONG_BLOB, 230 | .MYSQL_TYPE_MEDIUM_BLOB, 231 | .MYSQL_TYPE_BLOB, 232 | .MYSQL_TYPE_TINY_BLOB, 233 | .MYSQL_TYPE_GEOMETRY, 234 | .MYSQL_TYPE_BIT, 235 | .MYSQL_TYPE_DECIMAL, 236 | .MYSQL_TYPE_NEWDECIMAL, 237 | => { 238 | const str = reader.readLengthEncodedString(); 239 | inline for (e.fields) |f| { 240 | if (std.mem.eql(u8, str, f.name)) { 241 | return @field(FieldType, f.name); 242 | } 243 | } 244 | std.log.err( 245 | "received string: {s} from mysql, but could not find tag from enum: {s}, field name: {s}\n", 246 | .{ str, @typeName(FieldType), field_name }, 247 | ); 248 | }, 249 | else => {}, 250 | } 251 | }, 252 | .int => |int| { 253 | switch (int.signedness) { 254 | .unsigned => { 255 | switch (col_type) { 256 | .MYSQL_TYPE_LONGLONG => return @intCast(reader.readInt(u64)), 257 | 258 | .MYSQL_TYPE_LONG, 259 | .MYSQL_TYPE_INT24, 260 | => return @intCast(reader.readInt(u32)), 261 | 262 | .MYSQL_TYPE_SHORT, 263 | .MYSQL_TYPE_YEAR, 264 | => return @intCast(reader.readInt(u16)), 265 | 266 | .MYSQL_TYPE_TINY => return @intCast(reader.readByte()), 267 | 268 | else => {}, 269 | } 270 | }, 271 | .signed => { 272 | switch (col_type) { 273 | .MYSQL_TYPE_LONGLONG => return @intCast(@as(i64, @bitCast(reader.readInt(u64)))), 274 | 275 | .MYSQL_TYPE_LONG, 276 | .MYSQL_TYPE_INT24, 277 | => return @intCast(@as(i32, @bitCast(reader.readInt(u32)))), 278 | 279 | .MYSQL_TYPE_SHORT, 280 | .MYSQL_TYPE_YEAR, 281 | => return @intCast(@as(i16, @bitCast(reader.readInt(u16)))), 282 | 283 | .MYSQL_TYPE_TINY => return @intCast(@as(i8, @bitCast(reader.readByte()))), 284 | 285 | else => {}, 286 | } 287 | }, 288 | } 289 | }, 290 | .float => |float| { 291 | if (float.bits >= 64) { 292 | switch (col_type) { 293 | .MYSQL_TYPE_DOUBLE => return @as(f64, @bitCast(reader.readInt(u64))), 294 | .MYSQL_TYPE_FLOAT => return @as(f32, @bitCast(reader.readInt(u32))), 295 | else => {}, 296 | } 297 | } 298 | if (float.bits >= 32) { 299 | switch (col_type) { 300 | .MYSQL_TYPE_FLOAT => return @as(f32, @bitCast(reader.readInt(u32))), 301 | else => {}, 302 | } 303 | } 304 | }, 305 | .array => |array| { 306 | switch (@typeInfo(array.child)) { 307 | .int => |int| { 308 | if (int.bits == 8) { 309 | switch (col_type) { 310 | .MYSQL_TYPE_STRING, 311 | .MYSQL_TYPE_VARCHAR, 312 | .MYSQL_TYPE_VAR_STRING, 313 | .MYSQL_TYPE_ENUM, 314 | .MYSQL_TYPE_SET, 315 | .MYSQL_TYPE_LONG_BLOB, 316 | .MYSQL_TYPE_MEDIUM_BLOB, 317 | .MYSQL_TYPE_BLOB, 318 | .MYSQL_TYPE_TINY_BLOB, 319 | .MYSQL_TYPE_GEOMETRY, 320 | .MYSQL_TYPE_BIT, 321 | .MYSQL_TYPE_DECIMAL, 322 | .MYSQL_TYPE_NEWDECIMAL, 323 | => { 324 | const str = reader.readLengthEncodedString(); 325 | if (array.sentinel()) |sentinel| { 326 | var ret: [array.len:sentinel]u8 = [_:sentinel]u8{sentinel} ** array.len; 327 | const min = @min(str.len, array.len); 328 | @memcpy(ret[0..min], str[0..min]); 329 | return ret; 330 | } else { 331 | var ret: [array.len]u8 = [_]u8{0} ** array.len; 332 | const min = @min(str.len, array.len); 333 | @memcpy(ret[0..min], str[0..min]); 334 | return ret; 335 | } 336 | }, 337 | else => {}, 338 | } 339 | } 340 | }, 341 | else => {}, 342 | } 343 | }, 344 | else => {}, 345 | } 346 | 347 | logConversionError(FieldType, field_name, col_def.name, col_type); 348 | return error.IncompatibleBinaryConversion; 349 | } 350 | 351 | inline fn binResIsNull(null_bitmap: []const u8, col_idx: usize) bool { 352 | // TODO: optimize: divmod 353 | const byte_idx = (col_idx + 2) / 8; 354 | const bit_idx = (col_idx + 2) % 8; 355 | const byte = null_bitmap[byte_idx]; 356 | return (byte & (1 << bit_idx)) > 0; 357 | } 358 | 359 | test "binResIsNull" { 360 | const tests = .{ 361 | .{ 362 | .null_bitmap = &.{0b00000100}, 363 | .col_idx = 0, 364 | .expected = true, 365 | }, 366 | .{ 367 | .null_bitmap = &.{0b00000000}, 368 | .col_idx = 0, 369 | .expected = false, 370 | }, 371 | .{ 372 | .null_bitmap = &.{ 0b00000000, 0b00000001 }, 373 | .col_idx = 6, 374 | .expected = true, 375 | }, 376 | .{ 377 | .null_bitmap = &.{ 0b10000000, 0b00000000 }, 378 | .col_idx = 5, 379 | .expected = true, 380 | }, 381 | }; 382 | 383 | inline for (tests) |t| { 384 | const actual = binResIsNull(t.null_bitmap, t.col_idx); 385 | try std.testing.expectEqual(t.expected, actual); 386 | } 387 | } 388 | -------------------------------------------------------------------------------- /src/myzql.zig: -------------------------------------------------------------------------------- 1 | pub const config = @import("./config.zig"); 2 | pub const constants = @import("./constants.zig"); 3 | pub const conn = @import("./conn.zig"); 4 | pub const pool = @import("./pool.zig"); 5 | pub const protocol = @import("./protocol.zig"); 6 | pub const temporal = @import("./temporal.zig"); 7 | pub const result = @import("./result.zig"); 8 | 9 | test { 10 | @import("std").testing.refAllDeclsRecursive(@This()); 11 | } 12 | -------------------------------------------------------------------------------- /src/pool.zig: -------------------------------------------------------------------------------- 1 | const std = @import("std"); 2 | const Config = @import("./config.zig").Config; 3 | const protocol = @import("./protocol.zig"); 4 | const Conn = @import("./conn.zig").Conn; 5 | const result = @import("./result.zig"); 6 | 7 | // TODO: Pool 8 | // pub const Pool = struct { 9 | // config: Config, 10 | // conn: Conn, 11 | // 12 | // pub fn init(config: Config) Pool { 13 | // return .{ 14 | // .config = config, 15 | // .conn = .{}, 16 | // }; 17 | // } 18 | // 19 | // // TODO: 20 | // }; 21 | -------------------------------------------------------------------------------- /src/protocol.zig: -------------------------------------------------------------------------------- 1 | pub const packet_reader = @import("./protocol/packet_reader.zig"); 2 | pub const packet_writer = @import("./protocol/packet_writer.zig"); 3 | pub const packet = @import("./protocol/packet.zig"); 4 | pub const generic_response = @import("./protocol/generic_response.zig"); 5 | pub const handshake_v10 = @import("./protocol/handshake_v10.zig"); 6 | pub const handshake_response = @import("./protocol/handshake_response.zig"); 7 | pub const auth_switch_request = @import("./protocol/auth_switch_request.zig"); 8 | pub const text_command = @import("./protocol/text_command.zig"); 9 | pub const prepared_statements = @import("./protocol/prepared_statements.zig"); 10 | pub const column_definition = @import("./protocol/column_definition.zig"); 11 | -------------------------------------------------------------------------------- /src/protocol/auth_switch_request.zig: -------------------------------------------------------------------------------- 1 | const std = @import("std"); 2 | const constants = @import("../constants.zig"); 3 | const Packet = @import("./packet.zig").Packet; 4 | 5 | // https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_connection_phase_packets_protocol_auth_switch_request.html 6 | pub const AuthSwitchRequest = struct { 7 | plugin_name: [:0]const u8, 8 | plugin_data: []const u8, 9 | 10 | pub fn initFromPacket(packet: *const Packet) AuthSwitchRequest { 11 | var auth_switch_request: AuthSwitchRequest = undefined; 12 | var reader = packet.reader(); 13 | const header = reader.readByte(); 14 | std.debug.assert(header == constants.AUTH_SWITCH); 15 | auth_switch_request.plugin_name = reader.readNullTerminatedString(); 16 | auth_switch_request.plugin_data = reader.readRefRemaining(); 17 | return auth_switch_request; 18 | } 19 | }; 20 | -------------------------------------------------------------------------------- /src/protocol/column_definition.zig: -------------------------------------------------------------------------------- 1 | const std = @import("std"); 2 | const constants = @import("../constants.zig"); 3 | const Packet = @import("./packet.zig").Packet; 4 | const PacketReader = @import("./packet_reader.zig").PayloadReader; 5 | 6 | pub const ColumnDefinition41 = struct { 7 | catalog: []const u8, 8 | schema: []const u8, 9 | table: []const u8, 10 | org_table: []const u8, 11 | name: []const u8, 12 | org_name: []const u8, 13 | fixed_length_fields_length: u64, 14 | character_set: u16, 15 | column_length: u32, 16 | column_type: u8, 17 | flags: u16, 18 | decimals: u8, 19 | 20 | pub fn init(packet: *const Packet) ColumnDefinition41 { 21 | var column_definition_41: ColumnDefinition41 = undefined; 22 | column_definition_41.init2(packet); 23 | return column_definition_41; 24 | } 25 | 26 | pub fn init2(c: *ColumnDefinition41, packet: *const Packet) void { 27 | var reader = packet.reader(); 28 | 29 | c.catalog = reader.readLengthEncodedString(); 30 | c.schema = reader.readLengthEncodedString(); 31 | c.table = reader.readLengthEncodedString(); 32 | c.org_table = reader.readLengthEncodedString(); 33 | c.name = reader.readLengthEncodedString(); 34 | c.org_name = reader.readLengthEncodedString(); 35 | c.fixed_length_fields_length = reader.readLengthEncodedInteger(); 36 | c.character_set = reader.readInt(u16); 37 | c.column_length = reader.readInt(u32); 38 | c.column_type = reader.readByte(); 39 | c.flags = reader.readInt(u16); 40 | c.decimals = reader.readByte(); 41 | 42 | // https://mariadb.com/kb/en/result-set-packets/#column-definition-packet 43 | // According to mariadb, there seem to be extra 2 bytes at the end that is not being used 44 | std.debug.assert(reader.remained() == 2); 45 | } 46 | }; 47 | -------------------------------------------------------------------------------- /src/protocol/generic_response.zig: -------------------------------------------------------------------------------- 1 | const std = @import("std"); 2 | const constants = @import("../constants.zig"); 3 | const Packet = @import("./packet.zig").Packet; 4 | const PacketReader = @import("./packet_reader.zig").PacketReader; 5 | 6 | // https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_basic_err_packet.html 7 | pub const ErrorPacket = struct { 8 | error_code: u16, 9 | sql_state_marker: u8, 10 | sql_state: *const [5]u8, 11 | error_message: []const u8, 12 | 13 | pub fn initFirst(packet: *const Packet) ErrorPacket { 14 | var reader = packet.reader(); 15 | const header = reader.readByte(); 16 | std.debug.assert(header == constants.ERR); 17 | 18 | var error_packet: ErrorPacket = undefined; 19 | error_packet.error_code = reader.readInt(u16); 20 | error_packet.error_message = reader.readRefRemaining(); 21 | return error_packet; 22 | } 23 | 24 | pub fn init(packet: *const Packet) ErrorPacket { 25 | var reader = packet.reader(); 26 | const header = reader.readByte(); 27 | std.debug.assert(header == constants.ERR); 28 | 29 | var error_packet: ErrorPacket = undefined; 30 | error_packet.error_code = reader.readInt(u16); 31 | 32 | // CLIENT_PROTOCOL_41 33 | error_packet.sql_state_marker = reader.readByte(); 34 | error_packet.sql_state = reader.readRefComptime(5); 35 | 36 | error_packet.error_message = reader.readRefRemaining(); 37 | return error_packet; 38 | } 39 | 40 | pub fn asError(err: *const ErrorPacket) error{ErrorPacket} { 41 | // TODO: better way to do this? 42 | std.log.warn( 43 | "error packet: (code: {d}, message: {s})", 44 | .{ err.error_code, err.error_message }, 45 | ); 46 | return error.ErrorPacket; 47 | } 48 | }; 49 | 50 | //https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_basic_ok_packet.html 51 | pub const OkPacket = struct { 52 | affected_rows: u64, 53 | last_insert_id: u64, 54 | status_flags: ?u16, 55 | warnings: ?u16, 56 | info: ?[]const u8, 57 | session_state_info: ?[]const u8, 58 | 59 | pub fn init(packet: *const Packet, capabilities: u32) OkPacket { 60 | var ok_packet: OkPacket = undefined; 61 | 62 | var reader = packet.reader(); 63 | const header = reader.readByte(); 64 | std.debug.assert(header == constants.OK or header == constants.EOF); 65 | 66 | ok_packet.affected_rows = reader.readLengthEncodedInteger(); 67 | ok_packet.last_insert_id = reader.readLengthEncodedInteger(); 68 | 69 | // CLIENT_PROTOCOL_41 70 | ok_packet.status_flags = reader.readInt(u16); 71 | ok_packet.warnings = reader.readInt(u16); 72 | 73 | ok_packet.session_state_info = null; 74 | if (capabilities & constants.CLIENT_SESSION_TRACK > 0) { 75 | ok_packet.info = reader.readLengthEncodedString(); 76 | if (ok_packet.status_flags) |sf| { 77 | if (sf & constants.SERVER_SESSION_STATE_CHANGED > 0) { 78 | ok_packet.session_state_info = reader.readLengthEncodedString(); 79 | } 80 | } 81 | } else { 82 | ok_packet.info = reader.readRefRemaining(); 83 | } 84 | 85 | std.debug.assert(reader.finished()); 86 | return ok_packet; 87 | } 88 | }; 89 | -------------------------------------------------------------------------------- /src/protocol/handshake_response.zig: -------------------------------------------------------------------------------- 1 | // https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_connection_phase_packets_protocol_handshake_response.html 2 | // https://mariadb.com/kb/en/connection/#client-handshake-response 3 | const packer_writer = @import("./packet_writer.zig"); 4 | const std = @import("std"); 5 | const constants = @import("../constants.zig"); 6 | const Config = @import("../config.zig").Config; 7 | const AuthPlugin = @import("../auth.zig").AuthPlugin; 8 | const PacketWriter = @import("./packet_writer.zig").PacketWriter; 9 | 10 | pub const HandshakeResponse41 = struct { 11 | client_flag: u32, // capabilities 12 | max_packet_size: u32 = 0, // TODO: support configurable max packet size 13 | character_set: u8, 14 | username: [:0]const u8, 15 | auth_response: []const u8, 16 | database: [:0]const u8, 17 | client_plugin_name: [:0]const u8, 18 | key_values: []const [2][]const u8 = &.{}, 19 | zstd_compression_level: u8 = 0, 20 | 21 | pub fn init(comptime auth_plugin: AuthPlugin, config: *const Config, auth_resp: []const u8) HandshakeResponse41 { 22 | return .{ 23 | .database = config.database, 24 | .client_flag = config.capability_flags(), 25 | .character_set = config.collation, 26 | .username = config.username, 27 | .auth_response = auth_resp, 28 | .client_plugin_name = auth_plugin.toName(), 29 | }; 30 | } 31 | 32 | pub fn write(h: *const HandshakeResponse41, writer: *PacketWriter) !void { 33 | try writer.writeInt(u32, h.client_flag); 34 | 35 | try writer.writeInt(u32, h.max_packet_size); 36 | try writer.writeInt(u8, h.character_set); 37 | try writer.write(&([_]u8{0} ** 23)); // filler 38 | try writer.writeNullTerminatedString(h.username); 39 | 40 | if ((h.client_flag & constants.CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA) > 0) { 41 | try writer.writeLengthEncodedString(h.auth_response); 42 | } else if ((h.client_flag & constants.CLIENT_SECURE_CONNECTION) > 0) { 43 | const length: u8 = @intCast(h.auth_response.len); 44 | try writer.writeInt(u8, length); 45 | try writer.write(h.auth_response); 46 | } else { 47 | try writer.write(h.auth_response); 48 | try writer.writeInt(u8, 0); 49 | } 50 | if ((h.client_flag & constants.CLIENT_CONNECT_WITH_DB) > 0) { 51 | try writer.writeNullTerminatedString(h.database); 52 | } 53 | if ((h.client_flag & constants.CLIENT_PLUGIN_AUTH) > 0) { 54 | try writer.writeNullTerminatedString(h.client_plugin_name); 55 | } 56 | if ((h.client_flag & constants.CLIENT_CONNECT_ATTRS) > 0) { 57 | try writer.writeLengthEncodedInteger(h.key_values.len); 58 | for (h.key_values) |key_value| { 59 | try writer.writeLengthEncodedString(key_value[0]); 60 | try writer.writeLengthEncodedString(key_value[1]); 61 | } 62 | } 63 | if ((h.client_flag & constants.CLIENT_ZSTD_COMPRESSION_ALGORITHM) > 0) { 64 | try writer.writeInt(u8, h.zstd_compression_level); 65 | } 66 | } 67 | }; 68 | -------------------------------------------------------------------------------- /src/protocol/handshake_v10.zig: -------------------------------------------------------------------------------- 1 | const std = @import("std"); 2 | const Packet = @import("./packet.zig").Packet; 3 | const constants = @import("../constants.zig"); 4 | const AuthPlugin = @import("../auth.zig").AuthPlugin; 5 | const PayloadReader = @import("./packet.zig").PayloadReader; 6 | 7 | // https://mariadb.com/kb/en/connection/#initial-handshake-packet 8 | // https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_connection_phase_packets_protocol_handshake_v10.html 9 | pub const HandshakeV10 = struct { 10 | server_version: [:0]const u8, 11 | connection_id: u32, 12 | auth_plugin_data_part_1: *const [8]u8, 13 | capability_flags_1: u16, 14 | character_set: u8, 15 | status_flags: u16, 16 | capability_flags_2: u16, 17 | auth_plugin_data_len: ?u8, 18 | auth_plugin_data_part_2: [:0]const u8, 19 | auth_plugin_name: ?[:0]const u8, 20 | 21 | pub fn init(packet: *const Packet) HandshakeV10 { 22 | var reader = packet.reader(); 23 | var handshake_v10: HandshakeV10 = undefined; 24 | 25 | const protocol_version = reader.readByte(); 26 | std.debug.assert(protocol_version == constants.HANDSHAKE_V10); 27 | 28 | handshake_v10.server_version = reader.readNullTerminatedString(); 29 | handshake_v10.connection_id = reader.readInt(u32); 30 | handshake_v10.auth_plugin_data_part_1 = reader.readRefComptime(8); 31 | _ = reader.readByte(); // filler 32 | handshake_v10.capability_flags_1 = reader.readInt(u16); 33 | handshake_v10.character_set = reader.readByte(); 34 | handshake_v10.status_flags = reader.readInt(u16); 35 | handshake_v10.capability_flags_2 = reader.readInt(u16); 36 | 37 | handshake_v10.auth_plugin_data_len = reader.readByte(); 38 | 39 | // mariadb or mysql specific, ignore for now 40 | reader.skipComptime(10); 41 | 42 | // This part ambiguous in mariadb and mysql, 43 | // It seems like null terminated string works for both, at least for now 44 | // https://mariadb.com/kb/en/connection/#initial-handshake-packet 45 | // https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_connection_phase_packets_protocol_handshake_v10.html 46 | handshake_v10.auth_plugin_data_part_2 = reader.readNullTerminatedString(); 47 | if (handshake_v10.capability_flags() & constants.CLIENT_PLUGIN_AUTH > 0) { 48 | handshake_v10.auth_plugin_name = reader.readNullTerminatedString(); 49 | } else { 50 | handshake_v10.auth_plugin_name = null; 51 | } 52 | 53 | std.debug.assert(reader.finished()); 54 | return handshake_v10; 55 | } 56 | 57 | pub fn capability_flags(h: *const HandshakeV10) u32 { 58 | var f: u32 = h.capability_flags_2; 59 | f <<= 16; 60 | f |= h.capability_flags_1; 61 | return f; 62 | } 63 | 64 | pub fn get_auth_plugin(h: *const HandshakeV10) AuthPlugin { 65 | const name = h.auth_plugin_name orelse return .unspecified; 66 | return AuthPlugin.fromName(name); 67 | } 68 | 69 | pub fn get_auth_data(h: *const HandshakeV10) [20]u8 { 70 | const length = h.auth_plugin_data_part_1.len + h.auth_plugin_data_part_2.len; 71 | std.debug.assert(length <= 20); 72 | var auth_data: [20]u8 = undefined; 73 | 74 | const part_1_len = h.auth_plugin_data_part_1.len; 75 | @memcpy(auth_data[0..part_1_len], h.auth_plugin_data_part_1); 76 | @memcpy(auth_data[part_1_len..], h.auth_plugin_data_part_2); 77 | return auth_data; 78 | } 79 | }; 80 | -------------------------------------------------------------------------------- /src/protocol/packet.zig: -------------------------------------------------------------------------------- 1 | const std = @import("std"); 2 | const constants = @import("../constants.zig"); 3 | const ErrorPacket = @import("./generic_response.zig").ErrorPacket; 4 | // const PacketReader = @import("./packet_reader.zig").PacketReader; 5 | 6 | // https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_basic_packets.html#sect_protocol_basic_packets_packet 7 | pub const Packet = struct { 8 | sequence_id: u8, 9 | payload: []const u8, 10 | 11 | pub fn init(sequence_id: u8, payload: []const u8) Packet { 12 | return .{ .sequence_id = sequence_id, .payload = payload }; 13 | } 14 | 15 | pub fn asError(packet: *const Packet) error{ UnexpectedPacket, ErrorPacket } { 16 | if (packet.payload[0] == constants.ERR) { 17 | return ErrorPacket.init(packet).asError(); 18 | } 19 | std.log.warn("unexpected packet: {any}", .{packet}); 20 | return error.UnexpectedPacket; 21 | } 22 | 23 | pub fn reader(packet: *const Packet) PayloadReader { 24 | return PayloadReader.init(packet.payload); 25 | } 26 | 27 | pub fn cloneAlloc(packet: *const Packet, allocator: std.mem.Allocator) !Packet { 28 | const payload_copy = try allocator.dupe(u8, packet.payload); 29 | return .{ .sequence_id = packet.sequence_id, .payload = payload_copy }; 30 | } 31 | 32 | pub fn deinit(packet: *const Packet, allocator: std.mem.Allocator) void { 33 | allocator.free(packet.payload); 34 | } 35 | }; 36 | 37 | pub const PayloadReader = struct { 38 | payload: []const u8, 39 | pos: usize, 40 | 41 | fn init(payload: []const u8) PayloadReader { 42 | return .{ .payload = payload, .pos = 0 }; 43 | } 44 | 45 | pub fn peek(p: *const PayloadReader) ?u8 { 46 | std.debug.assert(p.pos <= p.payload.len); 47 | if (p.pos == p.payload.len) { 48 | return null; 49 | } 50 | return p.payload[p.pos]; 51 | } 52 | 53 | pub fn readByte(p: *PayloadReader) u8 { 54 | std.debug.assert(p.pos <= p.payload.len); 55 | const byte = p.payload[p.pos]; 56 | p.pos += 1; 57 | return byte; 58 | } 59 | 60 | pub fn readInt(p: *PayloadReader, Int: type) Int { 61 | const bytes = p.readRefComptime(@divExact(@typeInfo(Int).int.bits, 8)); 62 | return std.mem.readInt(Int, bytes, .little); 63 | } 64 | 65 | // https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_basic_dt_strings.html#sect_protocol_basic_dt_string_eof 66 | pub fn readRefRemaining(p: *PayloadReader) []const u8 { 67 | return p.readRefRuntime(p.payload.len - p.pos); 68 | } 69 | 70 | // https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_basic_dt_integers.html#sect_protocol_basic_dt_int_le 71 | // max possible value is 2^64 - 1, so return type is u64 72 | pub fn readLengthEncodedInteger(p: *PayloadReader) u64 { 73 | const first_byte = p.readByte(); 74 | switch (first_byte) { 75 | 0xFC => return p.readInt(u16), 76 | 0xFD => return p.readInt(u24), 77 | 0xFE => return p.readInt(u64), 78 | else => return first_byte, 79 | } 80 | } 81 | 82 | // https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_basic_dt_strings.html#sect_protocol_basic_dt_string_le 83 | pub fn readLengthEncodedString(p: *PayloadReader) []const u8 { 84 | const length = p.readLengthEncodedInteger(); 85 | return p.readRefRuntime(@as(usize, length)); 86 | } 87 | 88 | // https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_basic_dt_strings.html#sect_protocol_basic_dt_string_null 89 | pub fn readNullTerminatedString(p: *PayloadReader) [:0]const u8 { 90 | const i = std.mem.indexOfScalarPos(u8, p.payload, p.pos, 0) orelse { 91 | std.log.warn( 92 | "null terminated string not found\n, pos: {any}, payload: {any}", 93 | .{ p.pos, p.payload }, 94 | ); 95 | unreachable; 96 | }; 97 | 98 | const bytes = p.payload[p.pos..i]; 99 | p.pos = i + 1; 100 | return @ptrCast(bytes); 101 | } 102 | 103 | pub fn skipComptime(p: *PayloadReader, comptime n: usize) void { 104 | std.debug.assert(p.pos + n <= p.payload.len); 105 | p.pos += n; 106 | } 107 | 108 | pub fn finished(p: *const PayloadReader) bool { 109 | return p.payload.len == p.pos; 110 | } 111 | 112 | pub fn remained(p: *const PayloadReader) usize { 113 | return p.payload.len - p.pos; 114 | } 115 | 116 | pub fn readRefComptime(p: *PayloadReader, comptime n: usize) *const [n]u8 { 117 | std.debug.assert(p.pos + n <= p.payload.len); 118 | const bytes = p.payload[p.pos..][0..n]; 119 | p.pos += n; 120 | return bytes; 121 | } 122 | 123 | pub fn readRefRuntime(p: *PayloadReader, n: usize) []const u8 { 124 | std.debug.assert(p.pos + n <= p.payload.len); 125 | const bytes = p.payload[p.pos .. p.pos + n]; 126 | p.pos += n; 127 | return bytes; 128 | } 129 | }; 130 | -------------------------------------------------------------------------------- /src/protocol/packet_reader.zig: -------------------------------------------------------------------------------- 1 | const std = @import("std"); 2 | const utils = @import("./utils.zig"); 3 | const Packet = @import("./packet.zig"); 4 | 5 | pub const PacketReader = struct { 6 | stream: std.net.Stream, 7 | allocator: std.mem.Allocator, 8 | 9 | // valid buffer read from network but yet to consume to create packet: 10 | // buf[pos..len] 11 | buf: []u8, 12 | pos: usize, 13 | len: usize, 14 | 15 | // if in one read, the buffer is filled, we should double the buffer size 16 | should_double_buf: bool, 17 | 18 | pub fn init(stream: std.net.Stream, allocator: std.mem.Allocator) !PacketReader { 19 | return .{ 20 | .buf = &.{}, 21 | .stream = stream, 22 | .allocator = allocator, 23 | .pos = 0, 24 | .len = 0, 25 | .should_double_buf = false, 26 | }; 27 | } 28 | 29 | pub fn deinit(p: *const PacketReader) void { 30 | p.allocator.free(p.buf); 31 | } 32 | 33 | // invalidates the last packet returned 34 | pub fn readPacket(p: *PacketReader) !Packet.Packet { 35 | if (p.pos == p.len) { 36 | p.pos = 0; 37 | p.len = 0; 38 | try p.readToBufferAtLeast(4); 39 | } else if (p.len - p.pos < 4) { 40 | try p.readToBufferAtLeast(4); 41 | } 42 | 43 | // Packet header 44 | const payload_length = std.mem.readInt(u24, p.buf[p.pos..][0..3], .little); 45 | const sequence_id = p.buf[3]; 46 | p.pos += 4; 47 | 48 | { // read more bytes from network if required 49 | const n_valid_unread = p.len - p.pos; 50 | if (n_valid_unread < payload_length) { 51 | try p.readToBufferAtLeast(payload_length - n_valid_unread); 52 | } 53 | } 54 | 55 | // Packet payload 56 | const payload = p.buf[p.pos .. p.pos + payload_length]; 57 | p.pos += payload_length; 58 | 59 | return .{ 60 | .sequence_id = sequence_id, 61 | .payload = payload, 62 | }; 63 | } 64 | 65 | fn readToBufferAtLeast(p: *PacketReader, at_least: usize) !void { 66 | try p.expandBufIfNeeded(at_least); 67 | const n = try p.stream.readAtLeast(p.buf[p.len..], at_least); 68 | if (n == 0) { 69 | return error.UnexpectedEndOfStream; 70 | } 71 | 72 | p.len += n; 73 | if (n >= p.buf.len / 2) { 74 | p.should_double_buf = true; 75 | } 76 | } 77 | 78 | fn moveRemainingDataToBeginning(p: *PacketReader) void { 79 | if (p.pos == 0) { 80 | return; 81 | } 82 | const n_remain = p.len - p.pos; 83 | if (n_remain > p.pos) { // if overlap 84 | utils.memMove(p.buf, p.buf[p.pos..p.len]); 85 | } else { 86 | @memcpy(p.buf[0..n_remain], p.buf[p.pos..p.len]); 87 | } 88 | p.pos = 0; 89 | p.len = n_remain; 90 | } 91 | 92 | // ensure that the buffer can read extra `req_n` bytes 93 | fn expandBufIfNeeded(p: *PacketReader, req_n: usize) !void { 94 | if (p.buf.len - p.len >= req_n) { 95 | return; 96 | } 97 | 98 | const n_remain = p.len - p.pos; 99 | 100 | // possible to move remaining data to the beginning of the buffer 101 | // such that it will be enough? 102 | if (!p.should_double_buf) { 103 | // move remaining data to the beginning of the buffer 104 | const unused = p.buf.len - n_remain; 105 | if (unused >= req_n) { 106 | p.moveRemainingDataToBeginning(); 107 | return; 108 | } 109 | } 110 | 111 | const new_len = blk: { 112 | var current = p.buf.len; 113 | if (p.should_double_buf) { 114 | current *= 2; 115 | p.should_double_buf = false; 116 | } 117 | break :blk utils.nextPowerOf2(@truncate(req_n + current)); 118 | }; 119 | 120 | // try resize 121 | if (p.allocator.resize(p.buf, new_len)) { 122 | p.buf = p.buf[0..new_len]; 123 | p.moveRemainingDataToBeginning(); 124 | return; 125 | } 126 | 127 | // if resize failed, try to allocate a new buffer 128 | // and copy the remaining data to the new buffer 129 | const new_buf = try p.allocator.alloc(u8, new_len); 130 | @memcpy(new_buf[0..n_remain], p.buf[p.pos..p.len]); 131 | p.allocator.free(p.buf); 132 | p.buf = new_buf; 133 | p.pos = 0; 134 | p.len = n_remain; 135 | } 136 | }; 137 | -------------------------------------------------------------------------------- /src/protocol/packet_writer.zig: -------------------------------------------------------------------------------- 1 | const std = @import("std"); 2 | const utils = @import("./utils.zig"); 3 | 4 | pub const PacketWriter = struct { 5 | buf: []u8, 6 | pos: usize, // buf[0..pos]: buffer is written but not flushed 7 | stream: std.net.Stream, 8 | allocator: std.mem.Allocator, 9 | 10 | pub fn init(s: std.net.Stream, allocator: std.mem.Allocator) !PacketWriter { 11 | return .{ 12 | .stream = s, 13 | .buf = &.{}, 14 | .pos = 0, 15 | .allocator = allocator, 16 | }; 17 | } 18 | 19 | pub fn deinit(w: *const PacketWriter) void { 20 | w.allocator.free(w.buf); 21 | } 22 | 23 | // invalidates all previous writes 24 | pub fn reset(w: *PacketWriter) void { 25 | w.pos = 0; 26 | } 27 | 28 | pub fn write(w: *PacketWriter, src: []const u8) !void { 29 | try w.expandIfNeeded(src.len); 30 | const n = utils.copy(w.buf[w.pos..], src); 31 | std.debug.assert(n == src.len); 32 | w.pos += n; 33 | } 34 | 35 | // increase the length of the buffer as if it was written to 36 | fn skip(w: *PacketWriter, n: usize) !void { 37 | try w.expandIfNeeded(n); 38 | w.pos += n; 39 | } 40 | 41 | // increase the length of the buffer as if it was written to 42 | // returns a slice of the buffer that can be written to 43 | fn advance(w: *PacketWriter, n: usize) ![]u8 { 44 | try w.expandIfNeeded(n); 45 | const res = w.buf[w.pos .. w.pos + n]; 46 | w.pos += n; 47 | return res; 48 | } 49 | 50 | fn advanceComptime(w: *PacketWriter, comptime n: usize) !*[n]u8 { 51 | try w.expandIfNeeded(n); 52 | const res = w.buf[w.pos..][0..n]; 53 | w.pos += n; 54 | return res; 55 | } 56 | 57 | // flush the buffer to the stream 58 | pub inline fn flush(p: *PacketWriter) !void { 59 | try p.stream.writeAll(p.buf[0..p.pos]); 60 | p.pos = 0; 61 | } 62 | 63 | pub fn writeBytesAsPacket(p: *PacketWriter, sequence_id: u8, buffer: []const u8) !void { 64 | try p.writeInt(u24, @intCast(buffer.len)); 65 | try p.writeInt(u8, sequence_id); 66 | try p.write(buffer); 67 | } 68 | 69 | pub fn writePacket(p: *PacketWriter, sequence_id: u8, packet: anytype) !void { 70 | try p.writePacketInner(false, sequence_id, packet, {}); 71 | } 72 | 73 | pub fn writePacketWithParams(p: *PacketWriter, sequence_id: u8, packet: anytype, params: anytype) !void { 74 | try p.writePacketInner(true, sequence_id, packet, params); 75 | } 76 | 77 | fn writePacketInner( 78 | p: *PacketWriter, 79 | comptime has_params: bool, 80 | sequence_id: u8, 81 | packet: anytype, 82 | params: anytype, 83 | ) !void { 84 | const start = p.pos; 85 | try p.skip(4); 86 | // we need to write the payload length and sequence id later 87 | // after the packet is written 88 | // [0..3] [4] [......] 89 | // ^u24 payload_length ^u8 seq_id ^payload 90 | 91 | if (has_params) { 92 | try packet.writeWithParams(p, params); 93 | } else { 94 | try packet.write(p); 95 | } 96 | 97 | const written = p.pos - start - 4; 98 | const written_buf = p.buf[start..][0..3]; 99 | std.mem.writeInt(u24, written_buf, @intCast(written), .little); 100 | p.buf[start + 3] = sequence_id; 101 | } 102 | 103 | pub fn writeInt(p: *PacketWriter, comptime Int: type, int: Int) !void { 104 | const bytes = try p.advanceComptime(@divExact(@typeInfo(Int).int.bits, 8)); 105 | std.mem.writeInt(Int, bytes, int, .little); 106 | } 107 | 108 | pub fn writeNullTerminatedString(p: *PacketWriter, v: [:0]const u8) !void { 109 | try p.write(v[0 .. v.len + 1]); 110 | } 111 | 112 | pub fn writeFillers(p: *PacketWriter, comptime n: comptime_int) !void { 113 | _ = try p.advance(n); 114 | } 115 | 116 | pub fn writeLengthEncodedString(p: *PacketWriter, s: []const u8) !void { 117 | try p.writeLengthEncodedInteger(s.len); 118 | try p.write(s); 119 | } 120 | 121 | pub fn writeLengthEncodedInteger(p: *PacketWriter, v: u64) !void { 122 | if (v < 251) { 123 | try p.writeInt(u8, @intCast(v)); 124 | } else if (v < 1 << 16) { 125 | try p.writeInt(u8, 0xFC); 126 | try p.writeInt(u16, @intCast(v)); 127 | } else if (v < 1 << 24) { 128 | try p.writeInt(u8, 0xFD); 129 | try p.writeInt(u24, @intCast(v)); 130 | } else if (v < 1 << 64) { 131 | try p.writeInt(u8, 0xFE); 132 | try p.writeInt(u64, v); 133 | } else { 134 | std.log.warn("Invalid length encoded integer: {any}\n", .{v}); 135 | return error.InvalidLengthEncodedInteger; 136 | } 137 | } 138 | 139 | // invalidates all futures writes returned by `advance` 140 | fn expandIfNeeded(w: *PacketWriter, req_n: usize) !void { 141 | if (req_n <= w.buf.len - w.pos) { 142 | return; 143 | } 144 | 145 | const target_len = w.buf.len + req_n; 146 | const new_len = utils.nextPowerOf2(@truncate(target_len)); 147 | 148 | // try resize 149 | if (w.allocator.resize(w.buf, new_len)) { 150 | return; 151 | } 152 | 153 | // if resize failed, try to allocate a new buffer 154 | const new_buf = try w.allocator.alloc(u8, new_len); 155 | @memcpy(new_buf[0..w.buf.len], w.buf); 156 | w.allocator.free(w.buf); 157 | w.buf = new_buf; 158 | } 159 | }; 160 | 161 | // pub fn lengthEncodedStringPayloadSize(str_len: usize) u24 { 162 | // var str_len_24: u24 = @intCast(str_len); 163 | // if (str_len < 251) { 164 | // str_len_24 += 1; 165 | // } else if (str_len < 1 << 16) { 166 | // str_len_24 += 3; 167 | // } else if (str_len < 1 << 24) { 168 | // str_len_24 += 4; 169 | // } else if (str_len < 1 << 64) { 170 | // str_len_24 += 9; 171 | // } else unreachable; 172 | // return str_len_24; 173 | // } 174 | 175 | // pub fn lengthEncodedIntegerPayloadSize(v: u64) u24 { 176 | // if (v < 251) { 177 | // return 1; 178 | // } else if (v < 1 << 16) { 179 | // return 3; 180 | // } else if (v < 1 << 24) { 181 | // return 4; 182 | // } else if (v < 1 << 64) { 183 | // return 9; 184 | // } else unreachable; 185 | // } 186 | -------------------------------------------------------------------------------- /src/protocol/prepared_statements.zig: -------------------------------------------------------------------------------- 1 | const std = @import("std"); 2 | const constants = @import("../constants.zig"); 3 | const Packet = @import("./packet.zig").Packet; 4 | const PreparedStatement = @import("./../result.zig").PreparedStatement; 5 | const ColumnDefinition41 = @import("./column_definition.zig").ColumnDefinition41; 6 | const DateTime = @import("../temporal.zig").DateTime; 7 | const Duration = @import("../temporal.zig").Duration; 8 | const PacketWriter = @import("./packet_writer.zig").PacketWriter; 9 | const PacketReader = @import("./packet_reader.zig").PacketReader; 10 | const maxInt = std.math.maxInt; 11 | 12 | // https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_com_query.html 13 | pub const PrepareRequest = struct { 14 | query: []const u8, 15 | 16 | pub fn write(q: *const PrepareRequest, writer: *PacketWriter) !void { 17 | try writer.writeInt(u8, constants.COM_STMT_PREPARE); 18 | try writer.write(q.query); 19 | } 20 | }; 21 | 22 | pub const PrepareOk = struct { 23 | status: u8, 24 | statement_id: u32, 25 | num_columns: u16, 26 | num_params: u16, 27 | 28 | warning_count: ?u16, 29 | metadata_follows: ?u8, 30 | 31 | pub fn init(packet: *const Packet, capabilities: u32) PrepareOk { 32 | var prepare_ok_packet: PrepareOk = undefined; 33 | 34 | var reader = packet.reader(); 35 | prepare_ok_packet.status = reader.readByte(); 36 | prepare_ok_packet.statement_id = reader.readInt(u32); 37 | prepare_ok_packet.num_columns = reader.readInt(u16); 38 | prepare_ok_packet.num_params = reader.readInt(u16); 39 | 40 | // Reserved 1 byte 41 | _ = reader.readByte(); 42 | 43 | if (reader.payload.len >= 12) { // mysql says "> 12", but it seems to be ">= 12" 44 | prepare_ok_packet.warning_count = reader.readInt(u16); 45 | if (capabilities & constants.CLIENT_OPTIONAL_RESULTSET_METADATA > 0) { 46 | prepare_ok_packet.metadata_follows = reader.readByte(); 47 | } else { 48 | prepare_ok_packet.metadata_follows = null; 49 | } 50 | } else { 51 | prepare_ok_packet.warning_count = null; 52 | prepare_ok_packet.metadata_follows = null; 53 | } 54 | 55 | std.debug.assert(reader.finished()); 56 | return prepare_ok_packet; 57 | } 58 | }; 59 | 60 | pub const BinaryParam = struct { 61 | type_and_flag: [2]u8, // LSB: type, MSB: flag 62 | name: []const u8, 63 | raw: ?[]const u8, 64 | }; 65 | 66 | pub const ExecuteRequest = struct { 67 | prep_stmt: *const PreparedStatement, 68 | 69 | capabilities: u32, 70 | flags: u8 = 0, // Cursor type 71 | iteration_count: u32 = 1, // Always 1 72 | new_params_bind_flag: u8 = 1, 73 | 74 | // attributes: []const BinaryParam = &.{}, // Not supported yet 75 | 76 | pub fn writeWithParams(e: *const ExecuteRequest, writer: *PacketWriter, params: anytype) !void { 77 | try writer.writeInt(u8, constants.COM_STMT_EXECUTE); 78 | try writer.writeInt(u32, e.prep_stmt.prep_ok.statement_id); 79 | try writer.writeInt(u8, e.flags); 80 | try writer.writeInt(u32, e.iteration_count); 81 | 82 | const col_defs = e.prep_stmt.params; 83 | if (params.len != col_defs.len) { 84 | std.log.err("expected column count: {d}, but got {d}", .{ col_defs.len, params.len }); 85 | return error.ParamsCountNotMatch; 86 | } 87 | 88 | // const has_attributes_to_write = (e.capabilities & constants.CLIENT_QUERY_ATTRIBUTES > 0) and e.attributes.len > 0; 89 | 90 | // const param_count = params.len; 91 | if (params.len > 0 92 | //or has_attributes_to_write 93 | ) { 94 | // if (has_attributes_to_write) { 95 | // try packet_writer.writeLengthEncodedInteger(writer, e.attributes.len + param_count); 96 | // } 97 | 98 | // Write Null Bitmap 99 | // if (has_attributes_to_write) { 100 | // try writeNullBitmap(params, e.attributes, writer); 101 | // } else { 102 | // try writeNullBitmapWithAttrs(params, &.{}, writer); 103 | // } 104 | 105 | try writeNullBitmap(params, writer); 106 | 107 | // If a statement is re-executed without changing the params types, 108 | // the types do not need to be sent to the server again. 109 | // send type to server (0 / 1) 110 | try writer.writeLengthEncodedInteger(e.new_params_bind_flag); 111 | //if (e.new_params_bind_flag > 0) { 112 | comptime var enum_field_types: [params.len]constants.EnumFieldType = undefined; 113 | inline for (params, &enum_field_types) |param, *enum_field_type| { 114 | enum_field_type.* = comptime enumFieldTypeFromParam(@TypeOf(param)); 115 | } 116 | 117 | inline for (params, enum_field_types) |param, enum_field_type| { 118 | try writer.writeInt(u8, @intFromEnum(enum_field_type)); 119 | const sign_flag = switch (@typeInfo(@TypeOf(param))) { 120 | .comptime_int => if (param > maxInt(i64)) 0x80 else 0, 121 | .int => |int| if (int.signedness == .unsigned) 0x80 else 0, 122 | else => 0, 123 | }; 124 | try writer.writeInt(u8, sign_flag); 125 | 126 | // Not supported yet 127 | // if (e.capabilities & constants.CLIENT_QUERY_ATTRIBUTES > 0) { 128 | // try packet_writer.writeLengthEncodedString(writer, ""); 129 | // } 130 | } 131 | 132 | // if (has_attributes_to_write) { 133 | // for (e.attributes) |b| { 134 | // try writer.write(&b.type_and_flag); 135 | // try packet_writer.writeLengthEncodedString(writer, b.name); 136 | // } 137 | // } 138 | // } 139 | 140 | // TODO: Write params and attr as binary values 141 | // Write params as binary values 142 | inline for (params, enum_field_types) |param, enum_field_type| { 143 | if (isNull(param)) { 144 | try writeParamAsFieldType(writer, constants.EnumFieldType.MYSQL_TYPE_NULL, param); 145 | } else { 146 | try writeParamAsFieldType(writer, enum_field_type, param); 147 | } 148 | } 149 | 150 | // if (has_attributes_to_write) { 151 | // for (e.attributes) |b| { 152 | // try writeAttr(b, writer); 153 | // } 154 | // } 155 | } 156 | } 157 | }; 158 | 159 | fn enumFieldTypeFromParam(Param: type) constants.EnumFieldType { 160 | const param_type_info = @typeInfo(Param); 161 | return switch (Param) { 162 | DateTime => constants.EnumFieldType.MYSQL_TYPE_DATETIME, 163 | Duration => constants.EnumFieldType.MYSQL_TYPE_TIME, 164 | else => switch (param_type_info) { 165 | .null => return constants.EnumFieldType.MYSQL_TYPE_NULL, 166 | .optional => |o| return enumFieldTypeFromParam(o.child), 167 | .int => |int| { 168 | if (int.bits <= 8) { 169 | return constants.EnumFieldType.MYSQL_TYPE_TINY; 170 | } else if (int.bits <= 16) { 171 | return constants.EnumFieldType.MYSQL_TYPE_SHORT; 172 | } else if (int.bits <= 32) { 173 | return constants.EnumFieldType.MYSQL_TYPE_LONG; 174 | } else if (int.bits <= 64) { 175 | return constants.EnumFieldType.MYSQL_TYPE_LONGLONG; 176 | } 177 | }, 178 | .comptime_int => return constants.EnumFieldType.MYSQL_TYPE_LONGLONG, 179 | .float => |float| { 180 | if (float.bits <= 32) { 181 | return constants.EnumFieldType.MYSQL_TYPE_FLOAT; 182 | } else if (float.bits <= 64) { 183 | return constants.EnumFieldType.MYSQL_TYPE_DOUBLE; 184 | } 185 | }, 186 | .comptime_float => return constants.EnumFieldType.MYSQL_TYPE_DOUBLE, // Safer to assume double 187 | .array => |array| { 188 | switch (@typeInfo(array.child)) { 189 | .int => |int| { 190 | if (int.bits == 8) { 191 | return constants.EnumFieldType.MYSQL_TYPE_STRING; 192 | } 193 | }, 194 | else => {}, 195 | } 196 | }, 197 | .@"enum" => return constants.EnumFieldType.MYSQL_TYPE_STRING, 198 | .pointer => |pointer| { 199 | switch (pointer.size) { 200 | .one => return enumFieldTypeFromParam(pointer.child), 201 | else => {}, 202 | } 203 | switch (@typeInfo(pointer.child)) { 204 | .int => |int| { 205 | if (int.bits == 8) { 206 | switch (pointer.size) { 207 | .slice, .c, .many => return constants.EnumFieldType.MYSQL_TYPE_STRING, 208 | else => {}, 209 | } 210 | } 211 | }, 212 | else => {}, 213 | } 214 | }, 215 | else => { 216 | @compileLog(Param); 217 | @compileError("unsupported type"); 218 | }, 219 | }, 220 | }; 221 | } 222 | 223 | // https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_binary_resultset.html#sect_protocol_binary_resultset_row_value 224 | // https://mariadb.com/kb/en/com_stmt_execute/#binary-parameter-encoding 225 | fn writeParamAsFieldType( 226 | writer: *PacketWriter, 227 | comptime enum_field_type: constants.EnumFieldType, 228 | param: anytype, 229 | ) !void { 230 | return switch (@typeInfo(@TypeOf(param))) { 231 | .optional => if (param) |p| { 232 | return try writeParamAsFieldType(writer, enum_field_type, p); 233 | } else { 234 | return; 235 | }, 236 | else => switch (enum_field_type) { 237 | .MYSQL_TYPE_NULL => {}, 238 | .MYSQL_TYPE_TINY => try writer.writeInt(u8, uintCast(u8, i8, param)), 239 | .MYSQL_TYPE_SHORT => try writer.writeInt(u16, uintCast(u16, i16, param)), 240 | .MYSQL_TYPE_LONG => try writer.writeInt(u32, uintCast(u32, i32, param)), 241 | .MYSQL_TYPE_LONGLONG => try writer.writeInt(u64, uintCast(u64, i64, param)), 242 | .MYSQL_TYPE_FLOAT => try writer.writeInt(u32, @bitCast(@as(f32, param))), 243 | .MYSQL_TYPE_DOUBLE => try writer.writeInt(u64, @bitCast(@as(f64, param))), 244 | .MYSQL_TYPE_DATETIME => try writeDateTime(param, writer), 245 | .MYSQL_TYPE_TIME => try writeDuration(param, writer), 246 | .MYSQL_TYPE_STRING => try writer.writeLengthEncodedString(stringCast(param)), 247 | else => { 248 | @compileLog(enum_field_type); 249 | @compileLog(param); 250 | @compileError("unsupported type"); 251 | }, 252 | }, 253 | }; 254 | } 255 | 256 | fn stringCast(param: anytype) []const u8 { 257 | switch (@typeInfo(@TypeOf(param))) { 258 | .pointer => |pointer| { 259 | switch (pointer.size) { 260 | .c, .many => return std.mem.span(param), 261 | else => {}, 262 | } 263 | }, 264 | .@"enum" => return @tagName(param), 265 | else => {}, 266 | } 267 | 268 | return param; 269 | } 270 | 271 | fn uintCast(comptime UInt: type, comptime Int: type, value: anytype) UInt { 272 | return switch (@TypeOf(value)) { 273 | comptime_int => comptimeIntToUInt(UInt, Int, value), 274 | else => @bitCast(value), 275 | }; 276 | } 277 | 278 | fn comptimeIntToUInt( 279 | comptime UInt: type, 280 | comptime Int: type, 281 | comptime int: comptime_int, 282 | ) UInt { 283 | if (comptime (int < 0)) { 284 | return @bitCast(@as(Int, int)); 285 | } else { 286 | return int; 287 | } 288 | } 289 | 290 | // To save space the packet can be compressed: 291 | // if year, month, day, hour, minutes, seconds and microseconds are all 0, length is 0 and no other field is sent. 292 | // if hour, seconds and microseconds are all 0, length is 4 and no other field is sent. 293 | // if microseconds is 0, length is 7 and micro_seconds is not sent. 294 | // otherwise the length is 11 295 | fn writeDateTime(dt: DateTime, writer: *PacketWriter) !void { 296 | if (dt.microsecond > 0) { 297 | try writer.writeInt(u8, 11); 298 | try writer.writeInt(u16, dt.year); 299 | try writer.writeInt(u8, dt.month); 300 | try writer.writeInt(u8, dt.day); 301 | try writer.writeInt(u8, dt.hour); 302 | try writer.writeInt(u8, dt.minute); 303 | try writer.writeInt(u8, dt.second); 304 | try writer.writeInt(u32, dt.microsecond); 305 | } else if (dt.hour > 0 or dt.minute > 0 or dt.second > 0) { 306 | try writer.writeInt(u8, 7); 307 | try writer.writeInt(u16, dt.year); 308 | try writer.writeInt(u8, dt.month); 309 | try writer.writeInt(u8, dt.day); 310 | try writer.writeInt(u8, dt.hour); 311 | try writer.writeInt(u8, dt.minute); 312 | try writer.writeInt(u8, dt.second); 313 | } else if (dt.year > 0 or dt.month > 0 or dt.day > 0) { 314 | try writer.writeInt(u8, 4); 315 | try writer.writeInt(u16, dt.year); 316 | try writer.writeInt(u8, dt.month); 317 | try writer.writeInt(u8, dt.day); 318 | } else { 319 | try writer.writeInt(u8, 0); 320 | } 321 | } 322 | 323 | // To save space the packet can be compressed: 324 | // if day, hour, minutes, seconds and microseconds are all 0, length is 0 and no other field is sent. 325 | // if microseconds is 0, length is 8 and micro_seconds is not sent. 326 | // otherwise the length is 12 327 | fn writeDuration(d: Duration, writer: *PacketWriter) !void { 328 | if (d.microseconds > 0) { 329 | try writer.writeInt(u8, 12); 330 | try writer.writeInt(u8, d.is_negative); 331 | try writer.writeInt(u32, d.days); 332 | try writer.writeInt(u8, d.hours); 333 | try writer.writeInt(u8, d.minutes); 334 | try writer.writeInt(u8, d.seconds); 335 | try writer.writeInt(u32, d.microseconds); 336 | } else if (d.days > 0 or d.hours > 0 or d.minutes > 0 or d.seconds > 0) { 337 | try writer.writeInt(u8, 8); 338 | try writer.writeInt(u8, d.is_negative); 339 | try writer.writeInt(u32, d.days); 340 | try writer.writeInt(u8, d.hours); 341 | try writer.writeInt(u8, d.minutes); 342 | try writer.writeInt(u8, d.seconds); 343 | } else { 344 | try writer.writeInt(u8, 0); 345 | } 346 | } 347 | 348 | fn writeNullBitmap(params: anytype, writer: *PacketWriter) !void { 349 | comptime var pos: usize = 0; 350 | var byte: u8 = 0; 351 | var current_bit: u8 = 1; 352 | inline for (params) |param| { 353 | pos += 1; 354 | if (isNull(param)) { 355 | byte |= current_bit; 356 | } 357 | current_bit <<= 1; 358 | 359 | if (pos == 8) { 360 | try writer.writeInt(u8, byte); 361 | byte = 0; 362 | current_bit = 1; 363 | pos = 0; 364 | } 365 | } 366 | if (pos > 0) { 367 | try writer.writeInt(u8, byte); 368 | } 369 | } 370 | 371 | fn writeNullBitmapWithAttrs(params: anytype, attributes: []const BinaryParam, writer: *PacketWriter) !void { 372 | const byte_count = (params.len + attributes.len + 7) / 8; 373 | for (0..byte_count) |i| { 374 | const start = i * 8; 375 | const end = (i + 1) * 8; 376 | 377 | const byte: u8 = blk: { 378 | if (params.len >= end) { 379 | break :blk nullBitsParams(params, start); 380 | } else if (start >= params.len) { 381 | break :blk nullBitsAttrs(attributes[(start - params.len)..]); 382 | } else { 383 | break :blk nullBitsParamsAttrs(params, start, attributes); 384 | } 385 | }; 386 | 387 | // [1,1,1,1] [1,1,1] 388 | // start = 0, end = 8 389 | try writer.writeInt(u8, byte); 390 | } 391 | } 392 | 393 | pub fn nullBitsParams(params: anytype, start: usize) u8 { 394 | var byte: u8 = 0; 395 | 396 | var current_bit: u8 = 1; 397 | 398 | const end = comptime if (params.len > 8) 8 else params.len; 399 | inline for (params, 0..) |param, i| { 400 | if (i >= end) break; 401 | if (i >= start) { 402 | if (isNull(param)) { 403 | byte |= current_bit; 404 | } 405 | current_bit <<= 1; 406 | } 407 | } 408 | 409 | return byte; 410 | } 411 | 412 | pub fn nullBitsAttrs(attrs: []const BinaryParam) u8 { 413 | const final_attrs = if (attrs.len > 8) attrs[0..8] else attrs; 414 | 415 | var byte: u8 = 0; 416 | var current_bit: u8 = 1; 417 | for (final_attrs) |p| { 418 | if (p.raw == null) { 419 | byte |= current_bit; 420 | } 421 | current_bit <<= 1; 422 | } 423 | return byte; 424 | } 425 | 426 | pub fn nullBitsParamsAttrs(params: anytype, start: usize, attrs: []const BinaryParam) u8 { 427 | const final_attributes = if (attrs.len > 8) attrs[0..8] else attrs; 428 | 429 | var byte: u8 = 0; 430 | var current_bit: u8 = 1; 431 | 432 | inline for (params, 0..) |param, i| { 433 | if (i >= start) { 434 | if (isNull(param)) byte |= current_bit; 435 | current_bit <<= 1; 436 | } 437 | } 438 | 439 | for (final_attributes) |p| { 440 | if (p.raw == null) { 441 | byte |= current_bit; 442 | } 443 | current_bit <<= 1; 444 | } 445 | 446 | return byte; 447 | } 448 | 449 | inline fn isNull(param: anytype) bool { 450 | return comptime switch (@typeInfo(@TypeOf(param))) { 451 | inline .optional => if (param) |p| isNull(p) else true, 452 | inline .null => true, 453 | inline else => false, 454 | }; 455 | } 456 | 457 | fn nonNullBinaryParam() BinaryParam { 458 | return .{ 459 | .type_and_flag = .{ 0x00, 0x00 }, 460 | .name = "foo", 461 | .raw = "bar", 462 | }; 463 | } 464 | 465 | fn nullBinaryParam() BinaryParam { 466 | return .{ 467 | .type_and_flag = .{ 0x00, 0x00 }, 468 | .name = "hello", 469 | .raw = null, 470 | }; 471 | } 472 | 473 | test "writeNullBitmap" { 474 | const some_param = nonNullBinaryParam(); 475 | const null_param = nullBinaryParam(); 476 | 477 | const tests = .{ 478 | .{ 479 | .params = &.{1}, 480 | .attributes = &.{some_param}, 481 | .expected = &[_]u8{0b00000000}, 482 | }, 483 | .{ 484 | .params = &.{ null, @as(?u8, null) }, 485 | .attributes = &.{null_param}, 486 | .expected = &[_]u8{0b00000111}, 487 | }, 488 | .{ 489 | .params = &.{ null, null, null, null, null, null, null, null }, 490 | .attributes = &.{}, 491 | .expected = &[_]u8{0b11111111}, 492 | }, 493 | .{ 494 | .params = &.{ null, null, null, null, null, null, null, null, null }, 495 | .attributes = &.{}, 496 | .expected = &[_]u8{ 0b11111111, 0b00000001 }, 497 | }, 498 | .{ 499 | .params = &.{}, 500 | .attributes = &.{ null_param, null_param, null_param, null_param, null_param, null_param, null_param, null_param, null_param }, 501 | .expected = &[_]u8{ 0b11111111, 0b00000001 }, 502 | }, 503 | .{ 504 | .params = &.{}, 505 | .attributes = &.{ null_param, null_param, null_param, null_param, null_param, null_param, null_param, null_param }, 506 | .expected = &[_]u8{0b11111111}, 507 | }, 508 | .{ 509 | .params = &.{}, 510 | .attributes = &.{ null_param, null_param, null_param, null_param, null_param, null_param, null_param }, 511 | .expected = &[_]u8{0b01111111}, 512 | }, 513 | .{ 514 | .params = &.{ null, null, null, null, null, null, null, null }, 515 | .attributes = &.{null_param}, 516 | .expected = &[_]u8{ 0b11111111, 0b00000001 }, 517 | }, 518 | .{ 519 | .params = &.{ null, null, null, null, null, null, null }, 520 | .attributes = &.{ null_param, null_param }, 521 | .expected = &[_]u8{ 0b11111111, 0b00000001 }, 522 | }, 523 | .{ 524 | .params = &.{ null, null, null, null }, 525 | .attributes = &.{ null_param, null_param, null_param, null_param }, 526 | .expected = &[_]u8{0b11111111}, 527 | }, 528 | .{ 529 | .params = &.{ null, null, null, null, null, null, null, null, null }, 530 | .attributes = &.{ null_param, null_param }, 531 | .expected = &[_]u8{ 0b11111111, 0b00000111 }, 532 | }, 533 | }; 534 | 535 | inline for (tests) |t| { 536 | var buf: [1024]u8 = undefined; 537 | 538 | var fake_packet_writer: PacketWriter = .{ 539 | .buf = &buf, 540 | .pos = 0, 541 | .stream = undefined, 542 | .allocator = std.testing.allocator, 543 | }; 544 | fake_packet_writer = 545 | fake_packet_writer; 546 | 547 | _ = try writeNullBitmapWithAttrs(t.params, t.attributes, &fake_packet_writer); 548 | const written = buf[0..fake_packet_writer.pos]; 549 | try std.testing.expectEqualSlices(u8, t.expected, written); 550 | } 551 | } 552 | -------------------------------------------------------------------------------- /src/protocol/text_command.zig: -------------------------------------------------------------------------------- 1 | const std = @import("std"); 2 | const PacketWriter = @import("./packet_writer.zig").PacketWriter; 3 | const constants = @import("../constants.zig"); 4 | 5 | pub const QueryParam = struct { 6 | type_and_flag: [2]u8, // LSB: type, MSB: flag 7 | name: []const u8, 8 | value: []const u8, 9 | }; 10 | 11 | // https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_com_query.html 12 | pub const QueryRequest = struct { 13 | query: []const u8, 14 | 15 | // params 16 | capabilities: u32 = 0, 17 | params: []const ?QueryParam = &.{}, 18 | 19 | pub fn write(q: *const QueryRequest, writer: *PacketWriter) !void { 20 | // Packet Header 21 | try writer.writeInt(u8, constants.COM_QUERY); 22 | 23 | // Query Parameters 24 | if (q.capabilities & constants.CLIENT_QUERY_ATTRIBUTES > 0) { 25 | try writer.writeLengthEncodedInteger(q.params.len); 26 | try writer.writeLengthEncodedInteger(1); // Number of parameter sets. Currently always 1 27 | if (q.params.len > 0) { 28 | // NULL bitmap, length= (num_params + 7) / 8 29 | try writeNullBitmap(writer, q.params); 30 | 31 | // new_params_bind_flag 32 | // Always 1. Malformed packet error if not 1 33 | try writer.writeInt(u8, 1); 34 | 35 | // write type_and_flag, name and values 36 | // for each parameter 37 | for (q.params) |p_opt| { 38 | const p = p_opt orelse continue; // TODO: may not be correct 39 | try writer.write(&p.type_and_flag); 40 | try writer.writeLengthEncodedString(p.name); 41 | try writer.write(p.value); 42 | } 43 | } 44 | } 45 | 46 | // Query String 47 | try writer.write(q.query); 48 | } 49 | }; 50 | 51 | pub fn writeNullBitmap(writer: *PacketWriter, params: []const ?QueryParam) !void { 52 | const byte_count = (params.len + 7) / 8; 53 | for (0..byte_count) |i| { 54 | const byte = nullBits(params[i * 8 ..]); 55 | try writer.writeInt(u8, byte); 56 | } 57 | } 58 | 59 | pub fn nullBits(params: []const ?QueryParam) u8 { 60 | const final_params = if (params.len > 8) 61 | params[0..8] 62 | else 63 | params; 64 | 65 | var byte: u8 = 0; 66 | var current_bit: u8 = 1; 67 | for (final_params) |p_opt| { 68 | if (p_opt == null) { 69 | byte |= current_bit; 70 | } 71 | current_bit <<= 1; 72 | } 73 | return byte; 74 | } 75 | 76 | // test "writeNullBitmap - 1" { 77 | // const params: []const ?QueryParam = &.{ 78 | // null, 79 | // .{ 80 | // .type_and_flag = .{ 0, 0 }, 81 | // .name = "foo", 82 | // .value = "bar", 83 | // }, 84 | // null, 85 | // }; 86 | // 87 | // var buffer: [4]u8 = undefined; 88 | // var fbs = std.io.fixedBufferStream(&buffer); 89 | // try writeNullBitmap(params, &fbs); 90 | // 91 | // const written = fbs.buffer[0..fbs.pos]; 92 | // 93 | // // TODO: not sure if this is the expected result 94 | // // but it serves a good reference for now 95 | // // could be big endian 96 | // try std.testing.expectEqualSlices(u8, written, &[_]u8{0b00000101}); 97 | // } 98 | // 99 | // test "writeNullBitmap - 2" { 100 | // const params: []const ?QueryParam = &.{ 101 | // null, null, null, null, 102 | // null, null, null, null, 103 | // null, null, null, null, 104 | // }; 105 | // 106 | // var buffer: [4]u8 = undefined; 107 | // var fbs = std.io.fixedBufferStream(&buffer); 108 | // try writeNullBitmap(params, &fbs); 109 | // 110 | // const written = fbs.buffer[0..fbs.pos]; 111 | // 112 | // // TODO: not sure if this is the expected result 113 | // // but it serves a good reference for now 114 | // // could be big endian 115 | // try std.testing.expectEqualSlices(u8, &[_]u8{ 0b11111111, 0b00001111 }, written); 116 | // } 117 | -------------------------------------------------------------------------------- /src/protocol/utils.zig: -------------------------------------------------------------------------------- 1 | const std = @import("std"); 2 | 3 | pub fn copy(dest: []u8, src: []const u8) usize { 4 | const amount_copied = @min(dest.len, src.len); 5 | const final_dest = dest[0..amount_copied]; 6 | const final_src = src[0..amount_copied]; 7 | @memcpy(final_dest, final_src); 8 | return amount_copied; 9 | } 10 | 11 | test "copy - same length" { 12 | const src = "hello"; 13 | var dest = [_]u8{ 0, 0, 0, 0, 0 }; 14 | 15 | const n = copy(&dest, src); 16 | try std.testing.expectEqual(@as(usize, 5), n); 17 | try std.testing.expectEqualSlices(u8, src, &dest); 18 | } 19 | 20 | test "copy - src length is longer" { 21 | const src = "hello_goodbye"; 22 | var dest = [_]u8{ 0, 0, 0, 0, 0 }; 23 | 24 | const n = copy(&dest, src); 25 | try std.testing.expectEqual(@as(usize, 5), n); 26 | try std.testing.expectEqualSlices(u8, "hello", &dest); 27 | } 28 | 29 | test "copy - dest length is longer" { 30 | const src = "hello"; 31 | var dest = [_]u8{ 0, 0, 0, 0, 0, 0, 0, 0, 0 }; 32 | 33 | const n = copy(&dest, src); 34 | try std.testing.expectEqual(@as(usize, 5), n); 35 | try std.testing.expectEqualSlices(u8, "hello", dest[0..n]); 36 | } 37 | 38 | // dst.len >= src.len to ensure all data can be moved 39 | pub fn memMove(dst: []u8, src: []const u8) void { 40 | std.debug.assert(dst.len >= src.len); 41 | for (dst[0..src.len], src) |*d, s| { 42 | d.* = s; 43 | } 44 | } 45 | 46 | // 1 -> 1 47 | // 2 -> 2 48 | // 3 -> 4 49 | // 4 -> 4 50 | // 5 -> 8 51 | // 6 -> 8 52 | // ... 53 | pub fn nextPowerOf2(n: u32) u32 { 54 | std.debug.assert(n > 0); 55 | var x = n - 1; 56 | x |= x >> 1; 57 | x |= x >> 2; 58 | x |= x >> 4; 59 | x |= x >> 8; 60 | x |= x >> 16; 61 | return x + 1; 62 | } 63 | -------------------------------------------------------------------------------- /src/result.zig: -------------------------------------------------------------------------------- 1 | const std = @import("std"); 2 | const protocol = @import("./protocol.zig"); 3 | const constants = @import("./constants.zig"); 4 | const prep_stmts = protocol.prepared_statements; 5 | const PrepareOk = prep_stmts.PrepareOk; 6 | const Packet = protocol.packet.Packet; 7 | const PayloadReader = protocol.packet.PayloadReader; 8 | const OkPacket = protocol.generic_response.OkPacket; 9 | const ErrorPacket = protocol.generic_response.ErrorPacket; 10 | const ColumnDefinition41 = protocol.column_definition.ColumnDefinition41; 11 | const Conn = @import("./conn.zig").Conn; 12 | const conversion = @import("./conversion.zig"); 13 | 14 | pub const QueryResult = union(enum) { 15 | ok: OkPacket, 16 | err: ErrorPacket, 17 | 18 | pub fn init(packet: *const Packet, capabilities: u32) !QueryResult { 19 | return switch (packet.payload[0]) { 20 | constants.OK => .{ .ok = OkPacket.init(packet, capabilities) }, 21 | constants.ERR => .{ .err = ErrorPacket.init(packet) }, 22 | constants.LOCAL_INFILE_REQUEST => _ = @panic("not implemented"), 23 | else => { 24 | std.log.warn( 25 | \\Unexpected packet: {any}\n, 26 | \\Are you expecting a result set? If so, use QueryResultRows instead. 27 | \\This is unrecoverable error. 28 | , .{packet}); 29 | return error.UnrecoverableError; 30 | }, 31 | }; 32 | } 33 | 34 | pub fn expect( 35 | q: QueryResult, 36 | comptime value_variant: std.meta.FieldEnum(QueryResult), 37 | ) !std.meta.FieldType(QueryResult, value_variant) { 38 | return switch (q) { 39 | value_variant => @field(q, @tagName(value_variant)), 40 | else => { 41 | return switch (q) { 42 | .err => |err| return err.asError(), 43 | .ok => |ok| { 44 | std.log.err("Unexpected OkPacket: {any}\n", .{ok}); 45 | return error.UnexpectedOk; 46 | }, 47 | }; 48 | }, 49 | }; 50 | } 51 | }; 52 | 53 | /// T is either: 54 | /// - TextResultRow (queryRows) 55 | /// - BinaryResultRow (executeRows) 56 | pub fn QueryResultRows(comptime T: type) type { 57 | return union(enum) { 58 | err: ErrorPacket, 59 | rows: ResultSet(T), 60 | 61 | // allocation happens when a result set is returned 62 | pub fn init(c: *Conn) !QueryResultRows(T) { 63 | const packet = try c.readPacket(); 64 | return switch (packet.payload[0]) { 65 | constants.OK => { 66 | std.log.warn( 67 | \\Unexpected OkPacket: {any}\n, 68 | \\If your query is not expecting a result set, use QueryResult instead. 69 | , .{OkPacket.init(&packet, c.capabilities)}); 70 | return packet.asError(); 71 | }, 72 | constants.ERR => .{ .err = ErrorPacket.init(&packet) }, 73 | constants.LOCAL_INFILE_REQUEST => _ = @panic("not implemented"), 74 | else => .{ .rows = try ResultSet(T).init(c, &packet) }, 75 | }; 76 | } 77 | 78 | /// Example (TextResultRow): 79 | /// ... 80 | /// const result: QueryResultRows(TextResultRow) = try conn.queryRows("SELECT * FROM table"); 81 | /// const rows: ResultSet(TextResultRow) = try result.expect(.rows); 82 | /// ... 83 | pub fn expect( 84 | q: QueryResultRows(T), 85 | comptime value_variant: std.meta.FieldEnum(QueryResultRows(T)), 86 | ) !std.meta.FieldType(QueryResultRows(T), value_variant) { 87 | return switch (q) { 88 | value_variant => @field(q, @tagName(value_variant)), 89 | else => { 90 | return switch (q) { 91 | .err => |err| return err.asError(), 92 | .rows => |rows| { 93 | std.log.err("Unexpected ResultSet: {any}\n", .{rows}); 94 | return error.UnexpectedResultSet; 95 | }, 96 | }; 97 | }, 98 | }; 99 | } 100 | }; 101 | } 102 | 103 | /// T is either: 104 | /// - TextResultRow (queryRows) 105 | /// - BinaryResultRow (executeRows) 106 | pub fn ResultSet(comptime T: type) type { 107 | return struct { 108 | conn: *Conn, 109 | col_defs: []const ColumnDefinition41, 110 | 111 | pub fn init(conn: *Conn, packet: *const Packet) !ResultSet(T) { 112 | var reader = packet.reader(); 113 | const n_columns = reader.readLengthEncodedInteger(); 114 | std.debug.assert(reader.finished()); 115 | 116 | try conn.readPutResultColumns(n_columns); 117 | 118 | return .{ 119 | .conn = conn, 120 | .col_defs = conn.result_meta.col_defs.items, 121 | }; 122 | } 123 | 124 | fn deinit(r: *const ResultSet(T), allocator: std.mem.Allocator) void { 125 | for (r.col_packets) |packet| { 126 | packet.deinit(allocator); 127 | } 128 | allocator.free(r.col_packets); 129 | allocator.free(r.col_defs); 130 | } 131 | 132 | pub fn readRow(r: *const ResultSet(T)) !ResultRow(T) { 133 | return ResultRow(T).init(r.conn, r.col_defs); 134 | } 135 | 136 | pub fn tableTexts(r: *const ResultSet(TextResultRow), allocator: std.mem.Allocator) !TableTexts { 137 | const all_rows = try collectAllRowsPacketUntilEof(r.conn, allocator); 138 | errdefer deinitOwnedPacketList(all_rows); 139 | return try TableTexts.init(all_rows, allocator, r.col_defs.len); 140 | } 141 | 142 | pub fn first(r: *const ResultSet(T)) !?T { 143 | const row_res = try r.readRow(); 144 | return switch (row_res) { 145 | .ok => null, 146 | .err => |err| err.asError(), 147 | .row => |row| blk: { 148 | const i = r.iter(); 149 | while (try i.next()) |_| {} 150 | break :blk row; 151 | }, 152 | }; 153 | } 154 | 155 | pub fn iter(r: *const ResultSet(T)) ResultRowIter(T) { 156 | return .{ .result_set = r }; 157 | } 158 | }; 159 | } 160 | 161 | pub const TextResultRow = struct { 162 | packet: Packet, 163 | col_defs: []const ColumnDefinition41, 164 | 165 | pub fn iter(t: *const TextResultRow) TextElemIter { 166 | return TextElemIter.init(&t.packet); 167 | } 168 | 169 | pub fn textElems(t: *const TextResultRow, allocator: std.mem.Allocator) !TextElems { 170 | return TextElems.init(&t.packet, allocator, t.col_defs.len); 171 | } 172 | }; 173 | 174 | pub const TextElems = struct { 175 | packet: Packet, 176 | elems: []const ?[]const u8, 177 | 178 | pub fn init(p: *const Packet, allocator: std.mem.Allocator, n: usize) !TextElems { 179 | const packet = try p.cloneAlloc(allocator); 180 | errdefer packet.deinit(allocator); 181 | const elems = try allocator.alloc(?[]const u8, n); 182 | scanTextResultRow(elems, &packet); 183 | return .{ .packet = packet, .elems = elems }; 184 | } 185 | 186 | pub fn deinit(t: *const TextElems, allocator: std.mem.Allocator) void { 187 | t.packet.deinit(allocator); 188 | allocator.free(t.elems); 189 | } 190 | }; 191 | 192 | pub const TextElemIter = struct { 193 | reader: PayloadReader, 194 | 195 | pub fn init(packet: *const Packet) TextElemIter { 196 | return .{ .reader = packet.reader() }; 197 | } 198 | 199 | pub fn next(i: *TextElemIter) ??[]const u8 { 200 | const first_byte = i.reader.peek() orelse return null; 201 | if (first_byte == constants.TEXT_RESULT_ROW_NULL) { 202 | i.reader.skipComptime(1); 203 | return @as(?[]const u8, null); 204 | } 205 | return i.reader.readLengthEncodedString(); 206 | } 207 | }; 208 | 209 | fn scanTextResultRow(dest: []?[]const u8, packet: *const Packet) void { 210 | var reader = packet.reader(); 211 | for (dest) |*d| { 212 | d.* = blk: { 213 | const first_byte = reader.peek() orelse unreachable; 214 | if (first_byte == constants.TEXT_RESULT_ROW_NULL) { 215 | reader.skipComptime(1); 216 | break :blk null; 217 | } 218 | break :blk reader.readLengthEncodedString(); 219 | }; 220 | } 221 | } 222 | 223 | pub const BinaryResultRow = struct { 224 | packet: Packet, 225 | col_defs: []const ColumnDefinition41, 226 | 227 | // dest: pointer to a struct 228 | // string types like []u8, []const u8, ?[]u8 are shallow copied, data may be invalidated 229 | // from next scan, or network request. 230 | // use structCreate and structDestroy to allocate and deallocate struct objects 231 | // from binary result values 232 | pub fn scan(b: *const BinaryResultRow, dest: anytype) !void { 233 | try conversion.scanBinResultRow(dest, &b.packet, b.col_defs, null); 234 | } 235 | 236 | // returns a pointer to allocated struct object, caller must remember to call structDestroy 237 | // after use 238 | pub fn structCreate(b: *const BinaryResultRow, comptime Struct: type, allocator: std.mem.Allocator) !*Struct { 239 | const s = try allocator.create(Struct); 240 | try conversion.scanBinResultRow(s, &b.packet, b.col_defs, allocator); 241 | return s; 242 | } 243 | 244 | // deallocate struct object created from `structCreate` 245 | // s: *Struct 246 | pub fn structDestroy(s: anytype, allocator: std.mem.Allocator) void { 247 | structFreeDynamic(s.*, allocator); 248 | allocator.destroy(s); 249 | } 250 | 251 | fn structFreeDynamic(s: anytype, allocator: std.mem.Allocator) void { 252 | const s_ti = @typeInfo(@TypeOf(s)).@"struct"; 253 | inline for (s_ti.fields) |field| { 254 | structFreeStr(field.type, @field(s, field.name), allocator); 255 | } 256 | } 257 | 258 | fn structFreeStr(comptime StructField: type, value: StructField, allocator: std.mem.Allocator) void { 259 | switch (@typeInfo(StructField)) { 260 | .pointer => |p| switch (@typeInfo(p.child)) { 261 | .int => |int| if (int.bits == 8) { 262 | allocator.free(value); 263 | }, 264 | else => {}, 265 | }, 266 | .optional => |o| if (value) |some| structFreeStr(o.child, some, allocator), 267 | else => {}, 268 | } 269 | } 270 | }; 271 | 272 | pub fn ResultRow(comptime T: type) type { 273 | return union(enum) { 274 | err: ErrorPacket, 275 | ok: OkPacket, 276 | row: T, 277 | 278 | fn init(conn: *Conn, col_defs: []const ColumnDefinition41) !ResultRow(T) { 279 | const packet = try conn.readPacket(); 280 | return switch (packet.payload[0]) { 281 | constants.ERR => .{ .err = ErrorPacket.init(&packet) }, 282 | constants.EOF => .{ .ok = OkPacket.init(&packet, conn.capabilities) }, 283 | else => .{ .row = .{ .packet = packet, .col_defs = col_defs } }, 284 | }; 285 | } 286 | 287 | pub fn expect( 288 | r: ResultRow(T), 289 | comptime value_variant: std.meta.FieldEnum(ResultRow(T)), 290 | ) !std.meta.FieldType(ResultRow(T), value_variant) { 291 | return switch (r) { 292 | value_variant => @field(r, @tagName(value_variant)), 293 | else => { 294 | return switch (r) { 295 | .err => |err| return err.asError(), 296 | .ok => |ok| { 297 | std.log.err("Unexpected OkPacket: {any}\n", .{ok}); 298 | return error.UnexpectedOk; 299 | }, 300 | .row => |data| { 301 | std.log.err("Unexpected Row: {any}\n", .{data}); 302 | return error.UnexpectedResultData; 303 | }, 304 | }; 305 | }, 306 | }; 307 | } 308 | }; 309 | } 310 | 311 | fn deinitOwnedPacketList(packet_list: std.ArrayList(Packet)) void { 312 | for (packet_list.items) |packet| { 313 | packet.deinit(packet_list.allocator); 314 | } 315 | packet_list.deinit(); 316 | } 317 | 318 | fn collectAllRowsPacketUntilEof(conn: *Conn, allocator: std.mem.Allocator) !std.ArrayList(Packet) { 319 | var packet_list = std.ArrayList(Packet).init(allocator); 320 | errdefer deinitOwnedPacketList(packet_list); 321 | 322 | // Accumulate all packets until EOF 323 | while (true) { 324 | const packet = try conn.readPacket(); 325 | return switch (packet.payload[0]) { 326 | constants.ERR => ErrorPacket.init(&packet).asError(), 327 | constants.EOF => { 328 | _ = OkPacket.init(&packet, conn.capabilities); 329 | return packet_list; 330 | }, 331 | else => { 332 | const owned_packet = try packet.cloneAlloc(allocator); 333 | try packet_list.append(owned_packet); 334 | continue; 335 | }, 336 | }; 337 | } 338 | } 339 | 340 | pub const PrepareResult = union(enum) { 341 | err: ErrorPacket, 342 | stmt: PreparedStatement, 343 | 344 | pub fn init(c: *Conn, allocator: std.mem.Allocator) !PrepareResult { 345 | const response_packet = try c.readPacket(); 346 | return switch (response_packet.payload[0]) { 347 | constants.ERR => .{ .err = ErrorPacket.init(&response_packet) }, 348 | constants.OK => .{ .stmt = try PreparedStatement.init(&response_packet, c, allocator) }, 349 | else => return response_packet.asError(), 350 | }; 351 | } 352 | 353 | pub fn deinit(p: *const PrepareResult, allocator: std.mem.Allocator) void { 354 | switch (p.*) { 355 | .stmt => |prep_stmt| prep_stmt.deinit(allocator), 356 | else => {}, 357 | } 358 | } 359 | 360 | pub fn expect( 361 | p: PrepareResult, 362 | comptime value_variant: std.meta.FieldEnum(PrepareResult), 363 | ) !std.meta.FieldType(PrepareResult, value_variant) { 364 | return switch (p) { 365 | value_variant => @field(p, @tagName(value_variant)), 366 | else => { 367 | return switch (p) { 368 | .err => |err| return err.asError(), 369 | .stmt => |ok| { 370 | std.log.err("Unexpected PreparedStatement: {any}\n", .{ok}); 371 | return error.UnexpectedOk; 372 | }, 373 | }; 374 | }, 375 | }; 376 | } 377 | }; 378 | 379 | pub const PreparedStatement = struct { 380 | prep_ok: PrepareOk, 381 | packets: []const Packet, 382 | col_defs: []const ColumnDefinition41, 383 | params: []const ColumnDefinition41, // parameters that would be passed when executing the query 384 | res_cols: []const ColumnDefinition41, // columns that would be returned when executing the query 385 | 386 | pub fn init(ok_packet: *const Packet, conn: *Conn, allocator: std.mem.Allocator) !PreparedStatement { 387 | const prep_ok = PrepareOk.init(ok_packet, conn.capabilities); 388 | 389 | const col_count = prep_ok.num_params + prep_ok.num_columns; 390 | 391 | const packets = try allocator.alloc(Packet, col_count); 392 | @memset(packets, .{ .sequence_id = 0, .payload = &.{} }); 393 | errdefer { 394 | for (packets) |packet| { 395 | packet.deinit(allocator); 396 | } 397 | allocator.free(packets); 398 | } 399 | 400 | const col_defs = try allocator.alloc(ColumnDefinition41, col_count); 401 | errdefer allocator.free(col_defs); 402 | 403 | for (packets, col_defs) |*packet, *col_def| { 404 | packet.* = try (try conn.readPacket()).cloneAlloc(allocator); 405 | col_def.* = ColumnDefinition41.init(packet); 406 | } 407 | 408 | return .{ 409 | .prep_ok = prep_ok, 410 | .packets = packets, 411 | .col_defs = col_defs, 412 | .params = col_defs[0..prep_ok.num_params], 413 | .res_cols = col_defs[prep_ok.num_params..], 414 | }; 415 | } 416 | 417 | fn deinit(prep_stmt: *const PreparedStatement, allocator: std.mem.Allocator) void { 418 | allocator.free(prep_stmt.col_defs); 419 | for (prep_stmt.packets) |packet| { 420 | packet.deinit(allocator); 421 | } 422 | allocator.free(prep_stmt.packets); 423 | } 424 | }; 425 | 426 | pub fn ResultRowIter(comptime T: type) type { 427 | return struct { 428 | result_set: *const ResultSet(T), 429 | 430 | pub fn next(iter: *const ResultRowIter(T)) !?T { 431 | const row_res = try iter.result_set.readRow(); 432 | return switch (row_res) { 433 | .ok => return null, 434 | .err => |err| err.asError(), 435 | .row => |row| row, 436 | }; 437 | } 438 | 439 | pub fn tableStructs(iter: *const ResultRowIter(BinaryResultRow), comptime Struct: type, allocator: std.mem.Allocator) !TableStructs(Struct) { 440 | return TableStructs(Struct).init(iter, allocator); 441 | } 442 | }; 443 | } 444 | 445 | pub const TableTexts = struct { 446 | packet_list: std.ArrayList(Packet), 447 | 448 | flattened: []const ?[]const u8, 449 | table: []const []const ?[]const u8, 450 | 451 | fn init(packet_list: std.ArrayList(Packet), allocator: std.mem.Allocator, n_cols: usize) !TableTexts { 452 | var table = try allocator.alloc([]?[]const u8, packet_list.items.len); // TODO: alloc once instead 453 | errdefer allocator.free(table); 454 | var flattened = try allocator.alloc(?[]const u8, packet_list.items.len * n_cols); 455 | 456 | for (packet_list.items, 0..) |packet, i| { 457 | const dest_row = flattened[i * n_cols .. (i + 1) * n_cols]; 458 | scanTextResultRow(dest_row, &packet); 459 | table[i] = dest_row; 460 | } 461 | 462 | return .{ 463 | .packet_list = packet_list, 464 | .flattened = flattened, 465 | .table = table, 466 | }; 467 | } 468 | 469 | pub fn deinit(t: *const TableTexts, allocator: std.mem.Allocator) void { 470 | deinitOwnedPacketList(t.packet_list); 471 | allocator.free(t.table); 472 | allocator.free(t.flattened); 473 | } 474 | 475 | pub fn debugPrint(t: *const TableTexts) !void { 476 | const w = std.io.getStdOut().writer(); 477 | for (t.table, 0..) |row, i| { 478 | try w.print("row: {d} -> ", .{i}); 479 | try w.print("|", .{}); 480 | for (row) |elem| { 481 | try w.print("{?s}", .{elem}); 482 | try w.print("|", .{}); 483 | } 484 | try w.print("\n", .{}); 485 | } 486 | } 487 | }; 488 | 489 | pub fn TableStructs(comptime Struct: type) type { 490 | return struct { 491 | struct_list: std.ArrayList(Struct), 492 | 493 | pub fn init(iter: *const ResultRowIter(BinaryResultRow), allocator: std.mem.Allocator) !TableStructs(Struct) { 494 | var struct_list = std.ArrayList(Struct).init(allocator); 495 | while (try iter.next()) |row| { 496 | const new_struct_ptr = try struct_list.addOne(); 497 | try conversion.scanBinResultRow(new_struct_ptr, &row.packet, row.col_defs, allocator); 498 | } 499 | return .{ .struct_list = struct_list }; 500 | } 501 | 502 | pub fn deinit(t: *const TableStructs(Struct), allocator: std.mem.Allocator) void { 503 | for (t.struct_list.items) |s| { 504 | BinaryResultRow.structFreeDynamic(s, allocator); 505 | } 506 | t.struct_list.deinit(); 507 | } 508 | 509 | pub fn debugPrint(t: *const TableStructs(Struct)) !void { 510 | const w = std.io.getStdOut().writer(); 511 | for (t.struct_list.items, 0..) |row, i| { 512 | try w.print("row: {any} -> ", .{i}); 513 | try w.print("{any}", .{row}); 514 | try w.print("\n", .{}); 515 | } 516 | } 517 | }; 518 | } 519 | -------------------------------------------------------------------------------- /src/result_meta.zig: -------------------------------------------------------------------------------- 1 | const std = @import("std"); 2 | const protocol = @import("./protocol.zig"); 3 | const ColumnDefinition41 = protocol.column_definition.ColumnDefinition41; 4 | const Conn = @import("./conn.zig").Conn; 5 | 6 | pub const ResultMeta = struct { 7 | raw: std.ArrayList(u8), 8 | col_defs: std.ArrayList(ColumnDefinition41), 9 | 10 | pub fn init(allocator: std.mem.Allocator) ResultMeta { 11 | return ResultMeta{ 12 | .raw = std.ArrayList(u8).init(allocator), 13 | .col_defs = std.ArrayList(ColumnDefinition41).init(allocator), 14 | }; 15 | } 16 | 17 | pub fn deinit(r: *const ResultMeta) void { 18 | r.raw.deinit(); 19 | r.col_defs.deinit(); 20 | } 21 | 22 | pub inline fn readPutResultColumns(r: *ResultMeta, c: *Conn, n: usize) !void { 23 | r.raw.clearRetainingCapacity(); 24 | r.col_defs.clearRetainingCapacity(); 25 | 26 | const col_defs = try r.col_defs.addManyAsSlice(n); 27 | for (col_defs) |*col_def| { 28 | var packet = try c.readPacket(); 29 | const payload_owned = try r.raw.addManyAsSlice(packet.payload.len); 30 | @memcpy(payload_owned, packet.payload); 31 | packet.payload = payload_owned; 32 | col_def.init2(&packet); 33 | } 34 | } 35 | }; 36 | -------------------------------------------------------------------------------- /src/temporal.zig: -------------------------------------------------------------------------------- 1 | // Type for MYSQL_TYPE_DATE, MYSQL_TYPE_DATETIME and MYSQL_TYPE_TIMESTAMP, i.e. When was it? 2 | // https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_binary_resultset.html#sect_protocol_binary_resultset_row_value_date 3 | pub const DateTime = struct { 4 | year: u16 = 0, 5 | month: u8 = 0, 6 | day: u8 = 0, 7 | hour: u8 = 0, 8 | minute: u8 = 0, 9 | second: u8 = 0, 10 | microsecond: u32 = 0, 11 | }; 12 | 13 | // Type for MYSQL_TYPE_TIME, i.e. How long did it take? 14 | // `Time` is ambigious and confusing, `Duration` was chosen as the name instead 15 | // https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_binary_resultset.html#sect_protocol_binary_resultset_row_value_time 16 | pub const Duration = struct { 17 | is_negative: u8 = 0, // 1 if minus, 0 for plus 18 | days: u32 = 0, 19 | hours: u8 = 0, 20 | minutes: u8 = 0, 21 | seconds: u8 = 0, 22 | microseconds: u32 = 0, 23 | }; 24 | -------------------------------------------------------------------------------- /src/utils.zig: -------------------------------------------------------------------------------- 1 | const std = @import("std"); 2 | 3 | fn numSlice(comptime N: usize, shape: *const [N]usize) usize { 4 | return switch (N) { 5 | 0 => return 0, 6 | 1 => return shape[0], 7 | else => numSlice(N - 1, shape[1..]) * shape[0] + shape[0], 8 | }; 9 | } 10 | 11 | // General purpose Multi-Dimensional Array allocating function 12 | // T: type of the element 13 | // N: number of dimensions 14 | // shape: shape of the ndarray 15 | // does not work on packed structs 16 | pub fn ndArrayAlloc(comptime T: type, comptime N: usize, shape: *const [N]usize, allocator: std.mem.Allocator) !struct { NdSlice(T, N), []u8 } { 17 | std.debug.assert(N > 0); 18 | const num_elem = blk: { 19 | var res = shape[0]; 20 | inline for (shape[1..]) |n| { 21 | res *= n; 22 | } 23 | break :blk res; 24 | }; 25 | 26 | // Extra Allocation for Slices 27 | // *[N]T => 0 + num_elem * sizeof(T) 28 | // *[M]*[N]T => M * sizeof([]T) + num_elem * sizeof(T) 29 | // *[O]*[M]*[N]T => O * sizeof([][]T) + M * sizeof([]T) + num_elem * sizeof(T) 30 | // ... 31 | const num_slice = numSlice(N - 1, shape[0 .. shape.len - 1]); 32 | 33 | const size_of_slices = num_slice * @sizeOf([]void); 34 | const size_of_elems = num_elem * @sizeOf(T); 35 | 36 | const raw = try allocator.alloc(u8, (size_of_elems + size_of_slices)); 37 | 38 | const refs = raw[0..size_of_slices]; 39 | const elems = std.mem.bytesAsSlice(T, raw[size_of_slices..]); 40 | 41 | return .{ setNdArraySlices(T, N, shape, @alignCast(elems), refs), raw }; 42 | } 43 | 44 | fn NdSlice(comptime T: type, comptime N: usize) type { 45 | // N: 0, T 46 | // N: 1, []T 47 | // N: 2, [][]T 48 | // ... 49 | 50 | switch (N) { 51 | 0 => return T, 52 | else => return []NdSlice(T, N - 1), 53 | } 54 | } 55 | 56 | fn setNdArraySlices( 57 | comptime T: type, 58 | comptime N: usize, 59 | shape: *const [N]usize, // {2,3,4} 60 | elems: []T, // .{ T ** 24 } 61 | refs: []u8, 62 | ) NdSlice(T, N) { 63 | std.debug.assert(N > 0); 64 | if (N == 1) { 65 | return elems; 66 | } 67 | 68 | const divider = shape[0] * @sizeOf([]void); 69 | const parent_refs = refs[0..divider]; 70 | const children_refs = refs[divider..]; 71 | 72 | const res = std.mem.bytesAsSlice(NdSlice(T, N - 1), parent_refs); // [][][]T 73 | for (res, 0..) |*elem, i| { // elem: [][]T 74 | const next_refs = blk: { 75 | if (N == 2) { 76 | break :blk &.{}; 77 | } 78 | const ref_start = i * shape[1] * @sizeOf([]void); 79 | const ref_end = (i + 1) * shape[1] * @sizeOf([]void); 80 | break :blk children_refs[ref_start..ref_end]; 81 | }; 82 | 83 | const elem_start = i * elems.len / shape[0]; 84 | const elem_end = (i + 1) * elems.len / shape[0]; 85 | 86 | elem.* = setNdArraySlices(T, N - 1, shape[1..], elems[elem_start..elem_end], next_refs); 87 | } 88 | 89 | return @alignCast(res); 90 | } 91 | 92 | const MyStruct = struct { a: u8, b: u16, c: f32, d: f64, e: u64, f: u64 }; 93 | 94 | test "ndArrayAlloc - 1D" { 95 | const shape = &[_]usize{3}; 96 | const nd, const raw = try ndArrayAlloc(MyStruct, shape.len, shape, std.testing.allocator); 97 | defer std.testing.allocator.free(raw); 98 | for (nd) |*elem| { 99 | elem.* = .{ .a = 1, .b = 2, .c = 3.0, .d = 4.0, .e = 5, .f = 6 }; 100 | } 101 | } 102 | 103 | test "ndArrayAlloc - 2D" { 104 | const shape = &[_]usize{ 2, 3 }; 105 | const nnd, const raw = try ndArrayAlloc(MyStruct, shape.len, shape, std.testing.allocator); 106 | defer std.testing.allocator.free(raw); 107 | for (nnd) |nd| { 108 | for (nd) |*elem| { 109 | elem.* = .{ .a = 1, .b = 2, .c = 3.0, .d = 4.0, .e = 5, .f = 6 }; 110 | } 111 | } 112 | } 113 | 114 | test "ndArrayAlloc - 3D" { 115 | const shape = &[_]usize{ 2, 3, 4 }; 116 | const nnnd, const raw = try ndArrayAlloc(MyStruct, shape.len, shape, std.testing.allocator); 117 | defer std.testing.allocator.free(raw); 118 | for (nnnd) |nnd| { 119 | for (nnd) |nd| { 120 | for (nd) |*elem| { 121 | elem.* = .{ .a = 1, .b = 2, .c = 3.0, .d = 4.0, .e = 5, .f = 6 }; 122 | } 123 | } 124 | } 125 | } 126 | 127 | // broken, todo:use hashmap to check duplication pointer 128 | // test "ndArrayAlloc - 4D" { 129 | // const shape = &[_]usize{ 2, 3, 4, 5 }; 130 | // const nnnnd, const raw = try ndArrayAlloc(MyStruct, shape.len, shape, std.testing.allocator); 131 | // defer std.testing.allocator.free(raw); 132 | // for (nnnnd) |nnnd| { 133 | // for (nnnd) |nnd| { 134 | // for (nnd) |nd| { 135 | // for (nd) |*elem| { 136 | // elem.* = .{ .a = 1, .b = 2, .c = 3.0, .d = 4.0, .e = 5, .f = 6 }; 137 | // } 138 | // } 139 | // } 140 | // } 141 | // } 142 | --------------------------------------------------------------------------------