├── .travis.yml ├── README.md ├── asyncmysql.nim ├── asyncmysql.nimble ├── demo.nim └── test_sql.nim /.travis.yml: -------------------------------------------------------------------------------- 1 | language: c 2 | os: linux 3 | dist: xenial 4 | 5 | matrix: 6 | include: 7 | # Build and test against a range of versions 8 | - env: NIMVERSION=1.4.4 9 | services: 10 | mysql 11 | - env: NIMVERSION=1.2.8 12 | services: 13 | mysql 14 | - env: NIMVERSION=1.0.2 15 | services: 16 | mysql 17 | 18 | install: 19 | - curl https://nim-lang.org/download/nim-$NIMVERSION-linux_x64.tar.xz | xzcat | tar -C "$HOME" -xf - 20 | - export PATH=$HOME/nim-$NIMVERSION/bin:$HOME/.nimble/bin:$PATH 21 | - echo "export PATH=$PATH" >> ~/.profile 22 | 23 | before_script: 24 | - mysql -h 127.0.0.1 -u root -e "create database if not exists test;" 25 | - mysql -h 127.0.0.1 -u root -e "create user 'nimtest'@'127.0.0.1' identified WITH mysql_native_password by '123456';" 26 | - mysql -h 127.0.0.1 -u root -e "grant all on test.* to 'nimtest'@'127.0.0.1';" 27 | 28 | script: 29 | - nimble check 30 | - nim c -d:test -d:ssl -r asyncmysql.nim 31 | - nim c -d:test test_sql.nim 32 | - ./test_sql --no-ssl -D test -h localhost -u nimtest --password 123456 33 | - nim c -d:ssl -d:test test_sql.nim 34 | - ./test_sql --no-ssl -D test -h localhost -u nimtest --password 123456 35 | - ./test_sql --ssl --allow-mitm -D test -h localhost -u nimtest --password 123456 36 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | Non-blocking mysql client for Nim 2 | ================================= 3 | 4 | [![Build Status](https://travis-ci.org/wiml/nim-asyncmysql.svg?branch=master)](https://travis-ci.org/wiml/nim-asyncmysql) 5 | 6 | This is a scratch-written pure-[Nim][nimlang] implementation of the client 7 | side of the MySQL database protocol (also compatible 8 | with MariaDB, etc.). 
It's based on the `asyncdispatch` and 9 | `asyncnet` modules and should be a fully non-blocking, asynchronous 10 | library. 11 | 12 | The library implements both the 13 | "[text protocol](https://dev.mysql.com/doc/internals/en/com-query.html)" 14 | (send a simple string query, get back results as strings) 15 | and the 16 | "[binary protocol](https://dev.mysql.com/doc/internals/en/prepared-statements.html)" 17 | (get a prepared statement handle from a string with 18 | placeholders; send a set of value bindings, get back results 19 | as various datatypes approximating what the server is 20 | using). 21 | 22 | Other than ordinary queries, it does not support various 23 | other commands that might be useful. It doesn't support 24 | old versions of the server (pre-4.1) or probably several other things. 25 | It was primarily an exercise in learning Nim. 26 | 27 | Notes and Deficiencies 28 | ---------------------- 29 | 30 | For practical asynchronous use, some kind of turnstile mechanism needs 31 | to exist in order to prevent different requests from stomping on 32 | each other. It might make sense to combine this with some kind of 33 | transaction support. 34 | 35 | The API presented by this module is very specific to MySQL. A more 36 | generic async DB API would be nice. 37 | That API would ideally be a separate layer on top of the MySQL-specific API. 38 | 39 | Long packets (more than 2^24-1 bytes) are not handled correctly. 40 | 41 | The compressed protocol is not supported--- I'm not sure if this is 42 | actually a deficiency. As a workaround, SSL with a null cipher and 43 | compression could be used. 44 | 45 | For local (unix-domain) connections to work, you would need to extend Nim's 46 | socket modules to support those. 47 | 48 | ### Binary protocol 49 | 50 | Date and time values are not supported (but not hard to support). You 51 | can always CAST to a string or seconds-since-the-epoch format on the 52 | server side, though. 
53 | 54 | The protocol allows streaming large values to the server (if, for example, 55 | you are inserting a large BLOB) and this could be implemented elegantly 56 | as parameter that lazily generates strings. 57 | 58 | Cursors, FETCH, and the like are not implemented. 59 | 60 | [nimlang]: http://nim-lang.org/ 61 | -------------------------------------------------------------------------------- /asyncmysql.nim: --------------------------------------------------------------------------------
## This module implements (a subset of) the MySQL/MariaDB client
## protocol based on asyncnet and asyncdispatch.
##
## No attempt is made to make this look like the C-language
## libmysql API.
##
## This is currently somewhat experimental.
##
## Copyright (c) 2015,2020 William Lewis
##

{.experimental: "notnil".}
import asyncnet, asyncdispatch
import strutils
import std/sha1 as sha1
from endians import nil
from math import fcNormal, fcZero, fcNegZero, fcSubnormal, fcNan, fcInf, fcNegInf

when defined(ssl):
  import net # needed for the SslContext type

when isMainModule:
  import unittest
  # Forward declaration of a test helper used by the unit tests; its
  # definition is not in this section of the file.
  proc hexstr(s: string): string

# These are protocol constants; see
# https://dev.mysql.com/doc/internals/en/overview.html

const
  # Response-code byte values (OK / EOF / ERR).
  ResponseCode_OK : uint8 = 0
  ResponseCode_EOF : uint8 = 254 # Deprecated in mysql 5.7.5
  ResponseCode_ERR : uint8 = 255

  # NOTE(review): presumably the NULL marker byte in result rows -- confirm at use site.
  NullColumn = char(0xFB)

  # Prefix bytes of length-encoded integers. scanLenInt/putLenInt handle
  # LenEnc_16 (2 data bytes) and LenEnc_24 (3 data bytes); LenEnc_64 is
  # not handled by those packers.
  LenEnc_16 = 0xFC
  LenEnc_24 = 0xFD
  LenEnc_64 = 0xFE

  HandshakeV10 : uint8 = 0x0A # Initial handshake packet since MySQL 3.21

  # Character-set ids.
  Charset_swedish_ci : uint8 = 0x08
  Charset_utf8_ci : uint8 = 0x21
  Charset_binary : uint8 = 0x3f

type
  # These correspond to the bits in the capability words,
  # and the CLIENT_FOO_BAR definitions in mysql. We rely on
  # Nim's set representation being compatible with the
  # C bit-masking convention.
  Cap {.pure.} = enum
    longPassword = 0 # new more secure passwords
    foundRows = 1 # Found instead of affected rows
    longFlag = 2 # Get all column flags
    connectWithDb = 3 # One can specify db on connect
    noSchema = 4 # Don't allow database.table.column
    compress = 5 # Can use compression protocol
    odbc = 6 # Odbc client
    localFiles = 7 # Can use LOAD DATA LOCAL
    ignoreSpace = 8 # Ignore spaces before '('
    protocol41 = 9 # New 4.1 protocol
    interactive = 10 # This is an interactive client
    ssl = 11 # Switch to SSL after handshake
    ignoreSigpipe = 12 # Ignore SIGPIPEs
    transactions = 13 # Client knows about transactions
    reserved = 14 # Old flag for 4.1 protocol
    secureConnection = 15 # Old flag for 4.1 authentication
    multiStatements = 16 # Enable/disable multi-stmt support
    multiResults = 17 # Enable/disable multi-results
    psMultiResults = 18 # Multi-results in PS-protocol
    pluginAuth = 19 # Client supports plugin authentication
    connectAttrs = 20 # Client supports connection attributes
    pluginAuthLenencClientData = 21 # Enable authentication response packet to be larger than 255 bytes.
    canHandleExpiredPasswords = 22 # Don't close the connection for a connection with expired password.
    sessionTrack = 23
    deprecateEof = 24 # Client no longer needs EOF packet
    sslVerifyServerCert = 30
    rememberOptions = 31

  # Bits of the server status word; stored as ResponseOK.status_flags.
  Status {.pure.} = enum
    inTransaction = 0 # a transaction is active
    autoCommit = 1 # auto-commit is enabled
    moreResultsExist = 3
    noGoodIndexUsed = 4
    noIndexUsed = 5
    cursorExists = 6 # Used by Binary Protocol Resultset
    lastRowSent = 7
    dbDropped = 8
    noBackslashEscapes = 9
    metadataChanged = 10
    queryWasSlow = 11
    psOutParams = 12
    inTransactionReadOnly = 13 # in a read-only transaction
    sessionStateChanged = 14 # connection state information has changed

  # These correspond to the CMD_FOO definitions in mysql.
  # Commands marked "internal to the server", and commands
  # only used by the replication protocol, are commented out
  Command {.pure.} = enum
    # sleep = 0
    quiT = 1
    initDb = 2
    query = 3
    fieldList = 4
    createDb = 5
    dropDb = 6
    refresh = 7
    shutdown = 8
    statistics = 9
    processInfo = 10
    # connect = 11
    processKill = 12
    debug = 13
    ping = 14
    # time = 15
    # delayedInsert = 16
    changeUser = 17

    # Replication commands
    # binlogDump = 18
    # tableDump = 19
    # connectOut = 20
    # registerSlave = 21
    # binlogDumpGtid = 30

    # Prepared statements
    statementPrepare = 22
    statementExecute = 23
    statementSendLongData = 24
    statementClose = 25
    statementReset = 26

    # Stored procedures
    setOption = 27
    statementFetch = 28

    # daemon = 29
    resetConnection = 31

  # Column flag bits; stored as ColumnDefinition.flags.
  FieldFlag* {.pure.} = enum
    notNull = 0 # Field can't be NULL
    primaryKey = 1 # Field is part of a primary key
    uniqueKey = 2 # Field is part of a unique key
    multipleKey = 3 # Field is part of a key
    blob = 4 # Field is a blob
    unsigned = 5 # Field is unsigned
    zeroFill = 6 # Field is zerofill
    binary = 7 # Field is binary

    # The following are only sent to new clients (what is "new"? 4.1+?)
    enumeration = 8 # field is an enum
    autoIncrement = 9 # field is an autoincrement field
    timeStamp = 10 # Field is a timestamp
    isSet = 11 # Field is a set
    noDefaultValue = 12 # Field doesn't have default value
    onUpdateNow = 13 # Field is set to NOW on UPDATE
    isNum = 15 # Field is num (for clients)

  # Wire column-type codes; the ordinal is the byte value on the wire.
  FieldType* = enum
    fieldTypeDecimal = uint8(0)
    fieldTypeTiny = uint8(1)
    fieldTypeShort = uint8(2)
    fieldTypeLong = uint8(3)
    fieldTypeFloat = uint8(4)
    fieldTypeDouble = uint8(5)
    fieldTypeNull = uint8(6)
    fieldTypeTimestamp = uint8(7)
    fieldTypeLongLong = uint8(8)
    fieldTypeInt24 = uint8(9)
    fieldTypeDate = uint8(10)
    fieldTypeTime = uint8(11)
    fieldTypeDateTime = uint8(12)
    fieldTypeYear = uint8(13)
    fieldTypeVarchar = uint8(15)
    fieldTypeBit = uint8(16)
    fieldTypeNewDecimal = uint8(246)
    fieldTypeEnum = uint8(247)
    fieldTypeSet = uint8(248)
    fieldTypeTinyBlob = uint8(249)
    fieldTypeMediumBlob = uint8(250)
    fieldTypeLongBlob = uint8(251)
    fieldTypeBlob = uint8(252)
    fieldTypeVarString = uint8(253)
    fieldTypeString = uint8(254)
    fieldTypeGeometry = uint8(255)

  CursorType* {.pure.} = enum
    noCursor = 0
    readOnly = 1
    forUpdate = 2
    scrollable = 3

  # This represents a value returned from the server when using
  # the prepared statement / binary protocol. For convenience's sake
  # we combine multiple wire types into the nearest Nim type.
  ResultValueType = enum
    rvtNull,
    rvtInteger,
    rvtLong,
    rvtULong,
    rvtFloat32,
    rvtFloat64,
    rvtDate,
    rvtTime,
    rvtDateTime,
    rvtString,
    rvtBlob
  ResultValue* = object
    ## A value returned from the server when using the prepared statement
    ## (binary) protocol. This might contain a numeric or string type
    ## or NULL. To check for NULL, use `isNil`; attempts to read a value
    ## from a NULL result will result in a `ValueError`.
    case typ: ResultValueType
    of rvtInteger:
      intVal: int
    of rvtLong:
      longVal: int64
    of rvtULong:
      uLongVal: uint64
    of rvtString, rvtBlob:
      strVal: string
    of rvtNull:
      discard
    of rvtFloat32:
      floatVal: float32
    of rvtFloat64:
      doubleVal: float64
    of rvtDate, rvtTime, rvtDateTime:
      discard # TODO

  ResultString* = object
    ## A value returned from the server when using the text protocol.
    ## This contains either a string or an SQL NULL.
    case isNull: bool
    of false:
      value: string
    of true:
      discard

  # Discriminator for ParameterBinding.
  ParamBindingType = enum
    paramNull,
    paramString,
    paramBlob,
    paramInt,
    paramUInt,
    paramFloat,
    paramDouble,
    # paramLazyString, paramLazyBlob,
  ParameterBinding* = object
    ## This represents a value we're sending to the server as a parameter.
    ## Since parameters' types are always sent along with their values,
    ## we choose the wire type of integers based on the particular value
    ## we're sending each time.
    case typ: ParamBindingType
    of paramNull:
      discard
    of paramString, paramBlob:
      strVal: string
    of paramInt:
      intVal: int64
    of paramUInt:
      uintVal: uint64
    of paramFloat:
      floatVal: float32
    of paramDouble:
      doubleVal: float64

type
  # 0 .. 2^24-1, the range of a 24-bit unsigned value.
  # NOTE(review): presumably used for packet lengths -- confirm at use sites.
  nat24 = range[0 .. 16777215]
  Connection* = ref ConnectionObj ## A database connection handle.
  ConnectionObj* = object of RootObj
    socket: AsyncSocket not nil # Bytestream connection
    packet_number: uint8 # Next expected seq number (mod-256)

    # Information from the connection setup
    server_version*: string
    thread_id*: uint32
    server_caps: set[Cap]

    # Other connection parameters
    client_caps: set[Cap]

  ProtocolError* = object of IOError
    ## ProtocolError is thrown if we get something we don't understand
    ## or expect. This is generally a fatal error as far as this connection
    ## is concerned, since we might have lost framing, packet sequencing,
    ## etc. Unexpected connection closure will also result in this exception.

  # Server response packets: OK and EOF
  ResponseOK* {.final.} = object
    ## Status information returned from the server after each successful
    ## command.
    eof : bool # True if EOF packet, false if OK packet
    affected_rows* : Natural
    last_insert_id* : Natural
    status_flags* : set[Status]
    warning_count* : Natural
    info* : string
    # session_state_changes: seq[ ... ]

  # Server response packet: ERR (which can be thrown as an exception)
  ResponseERR* = object of CatchableError
    ## This exception is thrown when a command fails.
    error_code*: uint16 ## A MySQL-specific error number
    sqlstate*: string ## An ANSI SQL state code

  # Metadata for one column; used both for result-set columns
  # (ResultSet.columns) and for prepared-statement parameters/columns.
  ColumnDefinition* {.final.} = object
    catalog* : string
    schema* : string
    table* : string
    orig_table* : string
    name* : string
    orig_name* : string

    charset : int16
    length* : uint32
    column_type* : FieldType
    flags* : set[FieldFlag]
    decimals* : int

  # A complete result: final status, column metadata, and the rows.
  # NOTE(review): T is presumably ResultString (text protocol) or
  # ResultValue (binary protocol) -- confirm at the query procs.
  ResultSet*[T] {.final.} = object
    status* : ResponseOK
    columns* : seq[ColumnDefinition]
    rows* : seq[seq[T]]

  # Handle for a server-side prepared statement.
  PreparedStatement* = ref PreparedStatementObj
  PreparedStatementObj = object
    statement_id: array[4, char] # statement id kept as 4 raw bytes
    parameters: seq[ColumnDefinition]
    columns: seq[ColumnDefinition]
    warnings: Natural

type sqlNull = distinct tuple[]
const SQLNULL*: sqlNull = sqlNull( () )
  ## `SQLNULL` is a singleton value corresponding to SQL's NULL.
  ## This is used to send a NULL value for a parameter when
  ## executing a prepared statement.

const advertisedMaxPacketSize: uint32 = 65536 # max packet size, TODO: what should I put here?
# ######################################################################
#
# Forward declarations

proc selectDatabase*(conn: Connection, database: string): Future[ResponseOK]

# ######################################################################
#
# Basic datatype packers/unpackers
#
# Multi-byte integers are little-endian. The scan* procs read from `buf`
# (at a fixed offset, or advancing a `var` cursor); the put* procs append
# to the end of `buf`.

# Integers

proc scanU32(buf: string, pos: int): uint32 =
  ## Reads a little-endian unsigned 32-bit integer from `buf[pos .. pos+3]`.
  result = uint32(buf[pos]) or
           (uint32(buf[pos+1]) shl 8) or
           (uint32(buf[pos+2]) shl 16) or
           (uint32(buf[pos+3]) shl 24)

proc putU32(buf: var string, val: uint32) =
  ## Appends `val` as 4 little-endian bytes.
  buf.add( char( val and 0xff ) )
  buf.add( char( (val shr 8) and 0xff ) )
  buf.add( char( (val shr 16) and 0xff ) )
  buf.add( char( (val shr 24) and 0xff ) )

proc scanU16(buf: string, pos: int): uint16 =
  ## Reads a little-endian unsigned 16-bit integer from `buf[pos .. pos+1]`.
  result = uint16(buf[pos]) + (uint16(buf[pos+1]) shl 8'u16)
proc putU16(buf: var string, val: uint16) =
  ## Appends `val` as 2 little-endian bytes.
  buf.add( char( val and 0xFF ) )
  buf.add( char( (val shr 8) and 0xFF ) )

proc putU8(buf: var string, val: uint8) {.inline.} =
  ## Appends a single byte.
  buf.add( char(val) )
proc putU8(buf: var string, val: range[0..255]) {.inline.} =
  ## Appends a single byte given as an int in 0..255.
  buf.add( char(val) )

proc scanU64(buf: string, pos: int): uint64 =
  ## Reads a little-endian unsigned 64-bit integer from `buf[pos .. pos+7]`.
  let l32 = scanU32(buf, pos)
  let h32 = scanU32(buf, pos+4)
  return uint64(l32) + ( (uint64(h32) shl 32 ) )

proc putS64(buf: var string, val: int64) =
  ## Appends `val` as 8 little-endian bytes (two's complement).
  let compl: uint64 = cast[uint64](val)
  buf.putU32(uint32(compl and 0xFFFFFFFF'u64))
  buf.putU32(uint32(compl shr 32))

proc scanLenInt(buf: string, pos: var int): int =
  ## Reads a length-encoded integer at `pos` and advances `pos` past it.
  ##
  ## Supports the 1-byte form (< 251) and the LenEnc_16 / LenEnc_24 forms.
  ## Any other prefix byte (0xFB, LenEnc_64, 0xFF) yields -1 and leaves
  ## `pos` unchanged, so callers can inspect the prefix themselves.
  let b1 = uint8(buf[pos])
  if b1 < 251:
    inc(pos)
    return int(b1)
  if b1 == LenEnc_16:
    result = int(uint16(buf[pos+1]) + ( uint16(buf[pos+2]) shl 8 ))
    pos = pos + 3
    return
  if b1 == LenEnc_24:
    result = int(uint32(buf[pos+1]) + ( uint32(buf[pos+2]) shl 8 ) + ( uint32(buf[pos+3]) shl 16 ))
    pos = pos + 4
    return
  return -1

proc putLenInt(buf: var string, val: int) =
  ## Appends `val` as a length-encoded integer (1, 3 or 4 bytes).
  ## Raises `ProtocolError` for negative values or values above 0xFFFFFF.
  if val < 0:
    raise newException(ProtocolError, "trying to send a negative lenenc-int")
  elif val < 251:
    buf.add( char(val) )
  elif val < 65536:
    buf.add( char(LenEnc_16) )
    buf.add( char( val and 0xFF ) )
    buf.add( char( (val shr 8) and 0xFF ) )
  elif val <= 0xFFFFFF:
    buf.add( char(LenEnc_24) )
    buf.add( char( val and 0xFF ) )
    buf.add( char( (val shr 8) and 0xFF ) )
    buf.add( char( (val shr 16) and 0xFF ) )
  else:
    raise newException(ProtocolError, "lenenc-int too long for me!")

# Strings
proc scanNulString(buf: string, pos: var int): string =
  ## Reads a NUL-terminated string starting at `pos` and leaves `pos` just
  ## past the terminator. Raises `ProtocolError` if the buffer ends before a
  ## NUL byte is found (previously this indexed past the end of `buf`).
  result = ""
  while true:
    if pos > high(buf):
      raise newException(ProtocolError, "unterminated string")
    let ch = buf[pos]
    inc(pos)
    if ch == char(0):
      break
    result.add(ch)
proc scanNulStringX(buf: string, pos: var int): string =
  ## Lenient variant of `scanNulString`: the string ends at a NUL byte or at
  ## the end of `buf`, whichever comes first. `pos` is always advanced one
  ## past the stopping point (so it can end up at `len(buf) + 1`).
  result = ""
  while pos <= high(buf) and buf[pos] != char(0):
    result.add(buf[pos])
    inc(pos)
  inc(pos)

proc putNulString(buf: var string, val: string) =
  ## Appends `val` followed by a NUL terminator.
  buf.add(val)
  buf.add( char(0) )

proc scanLenStr(buf: string, pos: var int): string =
  ## Reads a length-encoded string at `pos` and advances `pos` past it.
  ## Raises `ProtocolError` if the length prefix is not understood or if the
  ## declared length runs past the end of `buf` (previously the result was
  ## silently truncated and `pos` was left beyond the buffer).
  let slen = scanLenInt(buf, pos)
  if slen < 0:
    raise newException(ProtocolError, "lenenc-int: is 0x" & toHex(int(buf[pos]), 2))
  if pos + slen > len(buf):
    raise newException(ProtocolError, "lenenc-str: truncated string")
  result = substr(buf, pos, pos+slen-1)
  pos = pos + slen

proc putLenStr(buf: var string, val: string) =
  ## Appends `val` as a length-encoded string.
  putLenInt(buf, val.len)
  buf.add(val)


# Floating point numbers. We assume that the wire protocol is always
# little-endian IEEE-754 (because all the world's a Vax^H^H^H 386),
# and we assume that the native representation is also IEEE-754 (but
# we check that second assumption in our unit tests).
proc scanIEEE754Single(buf: string, pos: int): float32 =
  ## Reads a little-endian IEEE-754 single from `buf[pos .. pos+3]`
  ## (no alignment required). Raises `ProtocolError` if fewer than 4 bytes
  ## remain; previously this read past the end of the string.
  if pos < 0 or pos + 4 > len(buf):
    raise newException(ProtocolError, "truncated float value")
  endians.littleEndian32(addr(result), unsafeAddr(buf[pos]))
proc scanIEEE754Double(buf: string, pos: int): float64 =
  ## Reads a little-endian IEEE-754 double from `buf[pos .. pos+7]`
  ## (no alignment required). Raises `ProtocolError` if fewer than 8 bytes
  ## remain; previously this read past the end of the string.
  if pos < 0 or pos + 8 > len(buf):
    raise newException(ProtocolError, "truncated double value")
  endians.littleEndian64(addr(result), unsafeAddr(buf[pos]))

proc putIEEE754(buf: var string, val: float32) =
  ## Appends `val` as 4 little-endian bytes.
  let oldLen = buf.len()
  buf.setLen(oldLen + 4)
  endians.littleEndian32(addr(buf[oldLen]), unsafeAddr val)
proc putIEEE754(buf: var string, val: float64) =
  ## Appends `val` as 8 little-endian bytes.
  let oldLen = buf.len()
  buf.setLen(oldLen + 8)
  endians.littleEndian64(addr(buf[oldLen]), unsafeAddr val)

when isMainModule: suite "Packing/unpacking of primitive types":
  test "Integers":
    var buf: string = ""
    putLenInt(buf, 0)
    putLenInt(buf, 1)
    putLenInt(buf, 250)
    putLenInt(buf, 251)
    putLenInt(buf, 252)
    putLenInt(buf, 512)
    putLenInt(buf, 640)
    putLenInt(buf, 65535)
    putLenInt(buf, 65536)
    putLenInt(buf, 15715755)
    putU32(buf, uint32(65535))
    putU32(buf, uint32(65536))
    putU32(buf, 0x80C00AAA'u32)
    check "0001fafcfb00fcfc00fc0002fc8002fcfffffd000001fdabcdefffff000000000100aa0ac080" == hexstr(buf)

    var pos: int = 0
    check 0 == scanLenInt(buf, pos)
    check 1 == scanLenInt(buf, pos)
    check 250 == scanLenInt(buf, pos)
    check 251 == scanLenInt(buf, pos)
    check 252 == scanLenInt(buf, pos)
    check 512 == scanLenInt(buf, pos)
    check 640 == scanLenInt(buf, pos)
    check 0x0FFFF == scanLenInt(buf, pos)
    check 0x10000 == scanLenInt(buf, pos)
    check 15715755 == scanLenInt(buf, pos)
    check 65535'u32 == scanU32(buf, pos)
    check 65535'u16 == scanU16(buf, pos)
    check 255'u16 == scanU16(buf, pos+1)
    check 0'u16 == scanU16(buf, pos+2)
    pos += 4
    check 65536'u32 == scanU32(buf, pos)
    pos += 4
    check 0x80C00AAA'u32 == scanU32(buf, pos)
    pos += 4
    check 0x80C00AAA00010000'u64 == scanU64(buf, pos-8)
    check len(buf) == pos

  test "Integers (bit-walking tests)":
    for bit in 0..63:
      var byhand: string = "\xFF"
      var test: string

      for b_off in 0..7:
        if b_off == bit div 8:
          byhand.add(chr(0x01 shl (bit mod 8)))
        else:
          byhand.add(chr(0))

      if bit < 16:
        let v16: uint16 = (1'u16) shl bit
        check scanU16(byhand, 1) == v16
        test = "\xFF"
        putU16(test, v16)
        test &= "\x00\x00\x00\x00\x00\x00"
        check test == byhand
        check hexstr(test) == hexstr(byhand)

      if bit < 32:
        let v32: uint32 = (1'u32) shl bit
        check scanU32(byhand, 1) == v32
        test = "\xFF"
        putU32(test, v32)
        test &= "\x00\x00\x00\x00"
        check test == byhand

      if bit < 63:
        test = "\xFF"
        putS64(test, (1'i64) shl bit)
        check test == byhand
        check hexstr(test) == hexstr(byhand)

      let v64: uint64 = (1'u64) shl bit
      check scanU64(byhand, 1) == v64

  const e32: float32 = 0.00000011920928955078125'f32

  test "Floats":
    var buf: string = ""

    putIEEE754(buf, 1.0'f32)
    putIEEE754(buf, e32)
    putIEEE754(buf, 1.0'f32 + e32)
    check "0000803f000000340100803f" == hexstr(buf)
    check:
      scanIEEE754Single(buf, 0) == 1.0'f32
      scanIEEE754Single(buf, 4) == e32
      scanIEEE754Single(buf, 8) == 1.0'f32 + e32

    # Non-word-aligned
    check:
      scanIEEE754Single("XAB\x01\x49Y", 1) == 0x81424 + 0.0625'f32

  test "Doubles":
    var buf: string = ""

    putIEEE754(buf, -2.0'f64)
    putIEEE754(buf, float64(e32))
    putIEEE754(buf, 1024'f64 + float64(e32))
    check "00000000000000c0000000000000803e0000080000009040" == hexstr(buf)
    check:
      scanIEEE754Double(buf, 0) == -2'f64
      scanIEEE754Double(buf, 8) == float64(e32)
      scanIEEE754Double(buf, 16) == 1024'f64 + float64(e32)

    # Non-word-aligned
    check:
      scanIEEE754Double("XYZGFEDCB\xFA\x42QRS", 3) == float64(0x1A42434445464) + 0.4375'f64

proc hexdump(buf: openarray[char], fp: File) {.used.} =
  ## Debugging aid: writes `buf` to `fp` as 16-byte rows of hex bytes
  ## followed by a printable-ASCII column ('.' for other bytes).
  var pos = low(buf)
  while pos <= high(buf):
    for i in 0 .. 15:
      fp.write(' ')
      if i == 8: fp.write(' ')
      let p = i+pos
      fp.write( if p <= high(buf): toHex(int(buf[p]), 2) else: " " )
    fp.write(" |")
    for i in 0 .. 15:
      var ch = ( if (i+pos) > high(buf): ' ' else: buf[i+pos] )
      if ch < ' ' or ch > '~':
        ch = '.'
      fp.write(ch)
    pos += 16
    fp.write("|\n")


# ######################################################################
#
# Parameter and result packers/unpackers

proc addTypeUnlessNULL(p: ParameterBinding, pkt: var string) =
  ## Appends the 2-byte (field type, flag) pair describing `p`; appends
  ## nothing for NULL. The flag byte is 0x80 when the value is sent as
  ## unsigned, 0 otherwise. Integers get the narrowest type that holds the
  ## value; this must agree with `addValueUnlessNULL`.
  case p.typ
  of paramNull:
    return
  of paramString:
    pkt.add(char(fieldTypeString))
    pkt.add(char(0))
  of paramBlob:
    pkt.add(char(fieldTypeBlob))
    pkt.add(char(0))
  of paramInt:
    if p.intVal >= 0:
      if p.intVal < 256'i64:
        pkt.add(char(fieldTypeTiny))
      elif p.intVal < 65536'i64:
        pkt.add(char(fieldTypeShort))
      elif p.intVal < (65536'i64 * 65536'i64):
        pkt.add(char(fieldTypeLong))
      else:
        pkt.add(char(fieldTypeLongLong))
      pkt.add(char(0x80))
    else:
      if p.intVal >= -128:
        pkt.add(char(fieldTypeTiny))
      elif p.intVal >= -32768:
        pkt.add(char(fieldTypeShort))
      else:
        pkt.add(char(fieldTypeLongLong))
      pkt.add(char(0))
  of paramUInt:
    if p.uintVal < (65536'u64 * 65536'u64):
      pkt.add(char(fieldTypeLong))
    else:
      pkt.add(char(fieldTypeLongLong))
    pkt.add(char(0x80))
  of paramFloat:
    pkt.add(char(fieldTypeFloat))
    pkt.add(char(0))
  of paramDouble:
    pkt.add(char(fieldTypeDouble))
    pkt.add(char(0))

proc addValueUnlessNULL(p: ParameterBinding, pkt: var string) =
  ## Appends the encoded value of `p`; appends nothing for NULL. The number
  ## of bytes written for integers must match the type chosen by
  ## `addTypeUnlessNULL` for the same value.
  case p.typ
  of paramNull:
    return
  of paramString, paramBlob:
    putLenStr(pkt, p.strVal)
  of paramInt:
    if p.intVal >= 0:
      pkt.putU8(p.intVal and 0xFF)
      if p.intVal >= 256:
        pkt.putU8((p.intVal shr 8) and 0xFF)
      if p.intVal >= 65536:
        pkt.putU16( ((p.intVal shr 16) and 0xFFFF).uint16 )
      if p.intVal >= (65536'i64 * 65536'i64):
        pkt.putU32(uint32(p.intVal shr 32))
    else:
      if p.intVal >= -128:
        pkt.putU8(uint8(p.intVal + 256))
      elif p.intVal >= -32768:
        pkt.putU16(uint16(p.intVal + 65536))
      else:
        pkt.putS64(p.intVal)
  of paramUInt:
    putU32(pkt, uint32(p.uintVal and 0xFFFFFFFF'u64))
    # Use the same 2^32 threshold as addTypeUnlessNULL. The old test
    # (`>= 0xFFFFFFFF`) wrote 8 bytes for exactly 0xFFFFFFFF while the
    # type byte announced a 4-byte LONG, corrupting the packet.
    if p.uintVal >= (65536'u64 * 65536'u64):
      putU32(pkt, uint32(p.uintVal shr 32))
  of paramFloat:
    pkt.putIEEE754(p.floatVal)
  of paramDouble:
    pkt.putIEEE754(p.doubleVal)

proc approximatePackedSize(p: ParameterBinding): int {.inline.} =
  ## Rough estimate of the bytes `addValueUnlessNULL` will write for `p`.
  ## Not exact: integers may actually take 1, 2, 4 or 8 bytes.
  case p.typ
  of paramNull:
    return 0
  of paramString, paramBlob:
    return 5 + len(p.strVal)
  of paramInt, paramUInt, paramFloat:
    return 4
  of paramDouble:
    return 8

proc asParam*(n: sqlNull): ParameterBinding {. inline .} = ParameterBinding(typ: paramNull)

proc asParam*(n: typeof(nil)): ParameterBinding {. deprecated("Do not use nil for NULL parameters, use SQLNULL") .} = ParameterBinding(typ: paramNull)

proc asParam*(s: string): ParameterBinding =
  ParameterBinding(typ: paramString, strVal: s)

proc asParam*(i: int): ParameterBinding {. inline .} = ParameterBinding(typ: paramInt, intVal: i)

proc asParam*(i: uint): ParameterBinding =
  ## Values above `high(int)` are bound as unsigned; others as signed.
  if i > uint(high(int)):
    ParameterBinding(typ: paramUInt, uintVal: uint64(i))
  else:
    ParameterBinding(typ: paramInt, intVal: int64(i))

proc asParam*(i: int64): ParameterBinding =
  ParameterBinding(typ: paramInt, intVal: i)

proc asParam*(i: uint64): ParameterBinding =
  ## Values above `high(int)` are bound as unsigned; others as signed.
  if i > uint64(high(int)):
    ParameterBinding(typ: paramUInt, uintVal: i)
  else:
    ParameterBinding(typ: paramInt, intVal: int64(i))

proc asParam*(b: bool): ParameterBinding = ParameterBinding(typ: paramInt, intVal: if b: 1 else: 0)

proc asParam*(f: float32): ParameterBinding {. inline .} =
  ParameterBinding(typ: paramFloat, floatVal:f)

proc asParam*(f: float64): ParameterBinding {. inline .} =
  ParameterBinding(typ: paramDouble, doubleVal:f)

proc isNil*(v: ResultValue): bool {.inline.} = v.typ == rvtNull

proc `$`*(v: ResultValue): string =
  ## Produce an approximate string representation of the value. This
  ## should mainly be restricted to debugging uses, since it is impossible
  ## to distinguish between, *e.g.*, a NULL value and the four-character
  ## string "NULL".
  case v.typ
  of rvtNull:
    return "NULL"
  of rvtString, rvtBlob:
    return v.strVal
  of rvtInteger:
    return $(v.intVal)
  of rvtLong:
    return $(v.longVal)
  of rvtULong:
    return $(v.uLongVal)
  of rvtFloat32:
    return $(v.floatVal)
  of rvtFloat64:
    return $(v.doubleVal)
  else:
    return "(unrepresentable!)"

{.push overflowChecks: on .}
proc toNumber[T: SomeInteger](v: ResultValue): T {.inline.} =
  ## Converts an integer-valued result to `T`. Raises `ValueError` for NULL
  ## or non-integer values. Out-of-range values fail the range-checked
  ## conversion `T(...)`.
  case v.typ
  of rvtInteger:
    return T(v.intVal)
  of rvtLong:
    return T(v.longVal)
  of rvtULong:
    return T(v.uLongVal)
  of rvtNull:
    raise newException(ValueError, "NULL value")
  else:
    raise newException(ValueError, "cannot convert " & $(v.typ) & " to " & $(T))

# Converters can't be generic; we need to explicitly instantiate
# the ones we think might be needed.
# NOTE: despite its name, `asInt8` converts to `uint8`.
converter asInt8*(v: ResultValue): uint8 = return toNumber[uint8](v)
converter asInt*(v: ResultValue): int = return toNumber[int](v)
converter asUInt*(v: ResultValue): uint = return toNumber[uint](v)
converter asInt64*(v: ResultValue): int64 = return toNumber[int64](v)
converter asUInt64*(v: ResultValue): uint64 = return toNumber[uint64](v)

proc toFloat[T: SomeFloat](v: ResultValue): T {.inline.} =
  ## Converts a floating-point result to `T`. Raises `ValueError` for NULL
  ## or non-float values (integers are *not* converted).
  case v.typ
  of rvtFloat32:
    return v.floatVal
  of rvtFloat64:
    return v.doubleVal
  of rvtNull:
    raise newException(ValueError, "NULL value")
  else:
    raise newException(ValueError, "cannot convert " & $(v.typ) & " to float")

converter asFloat32*(v: ResultValue): float32 = toFloat[float32](v)
converter asFloat64*(v: ResultValue): float64 = toFloat[float64](v)
{. pop .}

converter asString*(v: ResultValue): string =
  ## If the value is a string, return it; otherwise raise a `ValueError`.
  case v.typ
  of rvtNull:
    raise newException(ValueError, "NULL value")
  of rvtString, rvtBlob:
    return v.strVal
  else:
    raise newException(ValueError, "cannot convert " & $(v.typ) & " to string")

converter asBool*(v: ResultValue): bool =
  ## If the value is numeric, return it as a boolean; otherwise
  ## raise a `ValueError`. Note that `NULL` is neither true nor
  ## false and will raise.
  case v.typ
  of rvtInteger:
    return v.intVal != 0
  of rvtLong:
    return v.longVal != 0
  of rvtULong:
    return v.uLongVal != 0
  of rvtNull:
    raise newException(ValueError, "NULL value")
  else:
    raise newException(ValueError, "cannot convert " & $(v.typ) & " to boolean")

proc `==`*(v: ResultValue, s: string): bool =
  ## Compare the result value to a string.
  ## NULL values are not equal to any string.
  ## Non-string non-NULL values will result in an exception.
  case v.typ
  of rvtNull:
    return false
  of rvtString, rvtBlob:
    return v.strVal == s
  else:
    raise newException(ValueError, "cannot convert " & $(v.typ) & " to string")

proc floatEqualsInt[F: SomeFloat, I: SomeInteger](v: F, n: I): bool =
  ## Compare a float to an integer. Note that this is inherently a
  ## dodgy operation (which is why it's not overloading `==`). Floats
  ## are inexact, and each float corresponds to a range of real numbers;
  ## for larger numbers, a single float value can be "equal to" many
  ## different integers. (Or maybe it's equal to none of them if it
  ## can't represent any of them exactly — it really depends on what
  ## you're modeling with that float, doesn't it?) Anyway, for my particular
  ## case I don't care about that.

  # Infinities, NaNs, etc., are not equal to any integer. Subnormals
  # are also always less than 1 (and nonzero) so cannot be integers.
  case math.classify(v)
  of fcNormal:
    if n == 0:
      return false
    else:
      return v == F(n) # kludge
  of fcZero, fcNegZero:
    return n == 0
  of fcSubnormal, fcNan, fcInf, fcNegInf:
    return false

proc `==`[S: SomeSignedInt, U: SomeUnsignedInt](s: S, u: U): bool =
  ## Safely compare a signed and an unsigned integer of possibly
  ## different widths.
  if s < 0:
    return false
  when sizeof(U) >= sizeof(S):
    if u > U(high(S)):
      return false
    else:
      return S(u) == s
  else:
    if s > S(high(U)):
      return false
    else:
      return U(s) == u

when (NimMajor, NimMinor) < (1, 2) and uint isnot uint64:
  # Support for Nim < 1.2
  proc `==`(a: uint, b: uint64): bool =
    return uint64(a) == b

proc `==`*[T: SomeInteger](v: ResultValue, n: T): bool =
  ## Compare the result value to an integer.
  ## NULL values are not equal to any integer.
  ## Non-numeric non-NULL values (strings, etc.) will result in an exception.
  ##
  ## As a special case, this allows comparing a floating point ResultValue
  ## to an integer.
  case v.typ
  of rvtInteger:
    return v.intVal == n
  of rvtLong:
    return v.longVal == n
  of rvtULong:
    return n == v.uLongVal
  of rvtFloat32:
    return floatEqualsInt(v.floatVal, n)
  of rvtFloat64:
    return floatEqualsInt(v.doubleVal, n)
  of rvtNull:
    return false
  else:
    raise newException(ValueError, "cannot compare " & $(v.typ) & " to integer")

proc `==`*[F: SomeFloat](v: ResultValue, n: F): bool =
  ## Compare the result value to a float.
  ## NULL values are not equal to anything.
  ## Non-float values (including integers) will result in an exception.
  case v.typ
  of rvtFloat32:
    return v.floatVal == n
  of rvtFloat64:
    return v.doubleVal == n
  of rvtNull:
    return false
  else:
    raise newException(ValueError, "cannot compare " & $(v.typ) & " to floating-point number")

proc `==`*(v: ResultValue, b: bool): bool =
  ## Compare a result value to a boolean.
  ##
  ## The MySQL wire protocol does
  ## not have an explicit boolean type, so this tests an integer type against
  ## zero. NULL values are not equal to true *or* false (therefore,
  ## `if v == true:` is not equivalent to `if v:`: the latter will raise
  ## an exception if v is NULL). Non-integer values will result in an exception.
  if v.typ == rvtNull:
    return false
  else:
    return bool(v) == b

proc isNil*(v: ResultString): bool {.inline.} = v.isNull

proc `$`*(v: ResultString): string =
  ## Produce an approximate string representation of the value. This
  ## should mainly be restricted to debugging uses, since it is impossible
  ## to distinguish between a NULL value and the four-character
  ## string "NULL".
  case v.isNull
  of true:
    return "NULL"
  of false:
    return v.value

converter asString*(v: ResultString): string =
  ## Return the result as a string.
  ## Raise `ValueError` if the result is NULL.
  case v.isNull:
  of true:
    raise newException(ValueError, "NULL value")
  of false:
    return v.value

proc `==`*(a: ResultString, b: ResultString): bool =
  ## Compare two result strings. **Note:** This does not
  ## follow SQL semantics; NULL will compare equal to NULL.
  case a.isNull
  of true:
    return b.isNull
  of false:
    return (not b.isNull) and (a.value == b.value)

proc `==`*(a: ResultString, b: string): bool =
  ## Compare a result to a string.
NULL results are not 940 | ## equal to any string. 941 | case a.isNull 942 | of true: 943 | return false 944 | of false: 945 | return (a.value == b) 946 | 947 | proc asResultString*(s: string): ResultString {.inline.} = 948 | ResultString(isNull: false, value: s) 949 | proc asResultString*(n: sqlNull): ResultString {.inline.} = 950 | ResultString(isNull: true) 951 | 952 | # ###################################################################### 953 | # 954 | # MySQL packet packers/unpackers 955 | 956 | proc processHeader(c: Connection, hdr: array[4, char]): nat24 = 957 | result = int32(hdr[0]) + int32(hdr[1])*256 + int32(hdr[2])*65536 958 | let pnum = uint8(hdr[3]) 959 | if pnum != c.packet_number: 960 | raise newException(ProtocolError, "Bad packet number (got sequence number " & $(pnum) & ", expected " & $(c.packet_number) & ")") 961 | c.packet_number += 1 962 | 963 | proc receivePacket(conn:Connection, drop_ok: bool = false): Future[string] {.async.} = 964 | let hdr = await conn.socket.recv(4) 965 | if len(hdr) == 0: 966 | if drop_ok: 967 | return "" 968 | else: 969 | raise newException(ProtocolError, "Connection closed") 970 | if len(hdr) != 4: 971 | raise newException(ProtocolError, "Connection closed unexpectedly") 972 | let b = cast[ptr array[4,char]](cstring(hdr)) 973 | let packet_length = conn.processHeader(b[]) 974 | if packet_length == 0: 975 | return "" 976 | result = await conn.socket.recv(packet_length) 977 | if len(result) == 0: 978 | raise newException(ProtocolError, "Connection closed unexpectedly") 979 | if len(result) != packet_length: 980 | raise newException(ProtocolError, "TODO finish this part") 981 | 982 | # Caller must have left the first four bytes of the buffer available for 983 | # us to write the packet header. 
984 | proc sendPacket(conn: Connection, buf: var string, reset_seq_no = false): Future[void] = 985 | let bodylen = len(buf) - 4 986 | buf[0] = char( (bodylen and 0xFF) ) 987 | buf[1] = char( ((bodylen shr 8) and 0xFF) ) 988 | buf[2] = char( ((bodylen shr 16) and 0xFF) ) 989 | if reset_seq_no: 990 | conn.packet_number = 0 991 | buf[3] = char( conn.packet_number ) 992 | inc(conn.packet_number) 993 | # hexdump(buf, stdmsg) 994 | return conn.socket.send(buf) 995 | 996 | type 997 | greetingVars {.final.} = object 998 | scramble: string 999 | authentication_plugin: string 1000 | 1001 | # This implements the "mysql_native_password" auth plugin, 1002 | # which is the only auth we support. 1003 | proc mysql_native_password_hash(scramble: string, password: string): string = 1004 | let phash1 = sha1.Sha1Digest(sha1.secureHash(password)) 1005 | let phash2 = sha1.Sha1Digest(sha1.secureHash(cast[array[20, char]](phash1))) 1006 | 1007 | var ctx = sha1.newSha1State() 1008 | ctx.update(scramble) 1009 | ctx.update(cast[array[20, char]](phash2)) 1010 | let rhs = ctx.finalize() 1011 | 1012 | result = newString(1+high(phash1)) 1013 | for i in 0 .. high(phash1): 1014 | result[i] = char(phash1[i] xor rhs[i]) 1015 | const mysql_native_password_plugin = "mysql_native_password" 1016 | 1017 | when isMainModule: 1018 | test "Password hash": 1019 | # Test vectors captured from tcp traces of official mysql 1020 | check hexstr(mysql_native_password_hash("L\\i{NQ09k2W>p= (pos+5): 1045 | let cflags_h = scanU16(greeting, pos+3) 1046 | conn.server_caps = cast[set[Cap]]( uint32(cflags_l) + (uint32(cflags_h) shl 16) ) 1047 | 1048 | let moreScram = ( if Cap.protocol41 in conn.server_caps: int(greeting[pos+5]) else: 0 ) 1049 | if moreScram > 8: 1050 | result.scramble.add(greeting[pos + 16 .. 
pos + 16 + moreScram - 8 - 2]) 1051 | pos = pos + 16 + ( if moreScram < 20: 12 else: moreScram - 8 ) 1052 | 1053 | if Cap.pluginAuth in conn.server_caps: 1054 | result.authentication_plugin = scanNulStringX(greeting, pos) 1055 | 1056 | proc computeHandshakeResponse(conn: Connection, 1057 | greetingPacket: string, 1058 | username, password: string, 1059 | database: string, 1060 | starttls: bool): string = 1061 | 1062 | let greet: greetingVars = conn.parseInitialGreeting(greetingPacket) 1063 | 1064 | let server_caps = conn.server_caps 1065 | var caps: set[Cap] = { Cap.longPassword, 1066 | Cap.protocol41, 1067 | Cap.secureConnection } 1068 | if Cap.longFlag in server_caps: 1069 | incl(caps, Cap.longFlag) 1070 | 1071 | if len(database) > 0 and Cap.connectWithDb in conn.server_caps: 1072 | incl(caps, Cap.connectWithDb) 1073 | 1074 | if starttls: 1075 | if Cap.ssl notin conn.server_caps: 1076 | raise newException(ProtocolError, "Server does not support SSL") 1077 | else: 1078 | incl(caps, Cap.ssl) 1079 | 1080 | # Figure out our authentication response. Right now we only 1081 | # support the mysql_native_password_hash method. 1082 | var auth_response: string 1083 | var auth_plugin: string 1084 | 1085 | # password authentication 1086 | if password.len == 0: 1087 | # The caller passes a 0-length password to indicate no password, since 1088 | # we don't have nillable strings. 1089 | auth_response = "" 1090 | auth_plugin = "" 1091 | else: # in future: if greet.authentication_plugin == "" or greet.authentication_plugin == mysql_native_password 1092 | auth_response = mysql_native_password_hash(greet.scramble, password) 1093 | if Cap.pluginAuth in server_caps: 1094 | auth_plugin = mysql_native_password_plugin 1095 | incl(caps, Cap.pluginAuth) 1096 | else: 1097 | auth_plugin = "" 1098 | 1099 | # Do we need pluginAuthLenencClientData ? 
  # An auth response longer than 255 bytes cannot use the one-byte length
  # prefix written below (putU8). It needs the length-encoded form, which
  # only works if the server advertises support for it.
  if len(auth_response) > 255:
    if Cap.pluginAuthLenencClientData in server_caps:
      incl(caps, Cap.pluginAuthLenencClientData)
    else:
      raise newException(ProtocolError, "server cannot handle long auth_response")

  # Remember the negotiated client capabilities. Later code (e.g.
  # parseOKPacket, finishEstablishingConnection) consults them.
  conn.client_caps = caps

  # The first 4 bytes are placeholders for the packet header, which
  # sendPacket fills in.
  var buf: string = newStringOfCap(128)
  buf.setLen(4)

  # Fixed-length portion: capability flags, max packet size, charset.
  putU32(buf, cast[uint32](caps))
  putU32(buf, advertisedMaxPacketSize)
  buf.add( char(Charset_utf8_ci) )

  # 23 bytes of filler
  # This ends the fixed-length prefix at byte 36, counting the header.
  # The SSL path of establishConnection sends exactly that prefix as its
  # pre-TLS stub.
  for i in 1 .. 23:
    buf.add( char(0) )

  # Our username
  putNulString(buf, username)

  # Authentication data: length-encoded length when negotiated above,
  # otherwise a single length byte.
  let authLen = len(auth_response)
  if Cap.pluginAuthLenencClientData in caps:
    putLenInt(buf, authLen)
  else:
    putU8(buf, len(auth_response))
  buf.add(auth_response)

  # Initial database, only if the server lets us select it here.
  if Cap.connectWithDb in caps:
    putNulString(buf, database)

  # Name of the auth plugin our response was computed for.
  if Cap.pluginAuth in caps:
    putNulString(buf, auth_plugin)

  return buf

proc sendCommand(conn: Connection, cmd: Command): Future[void] =
  ## Send a simple, argument-less command. It starts a new command
  ## exchange, so the packet sequence number is reset.
1141 | var buf: string = newString(5) 1142 | buf[4] = char(cmd) 1143 | return conn.sendPacket(buf, reset_seq_no=true) 1144 | 1145 | proc sendQuery(conn: Connection, query: string): Future[void] = 1146 | var buf: string = newStringOfCap(4 + 1 + len(query)) 1147 | buf.setLen(4) 1148 | buf.add( char(Command.query) ) 1149 | buf.add(query) 1150 | return conn.sendPacket(buf, reset_seq_no=true) 1151 | 1152 | proc receiveMetadata(conn: Connection, count: Positive): Future[seq[ColumnDefinition]] {.async.} = 1153 | var received = 0 1154 | result = newSeq[ColumnDefinition](count) 1155 | while received < count: 1156 | let pkt = await conn.receivePacket() 1157 | # hexdump(pkt, stdmsg) 1158 | if uint8(pkt[0]) == ResponseCode_ERR or uint8(pkt[0]) == ResponseCode_EOF: 1159 | raise newException(ProtocolError, "TODO") 1160 | var pos = 0 1161 | result[received].catalog = scanLenStr(pkt, pos) 1162 | result[received].schema = scanLenStr(pkt, pos) 1163 | result[received].table = scanLenStr(pkt, pos) 1164 | result[received].orig_table = scanLenStr(pkt, pos) 1165 | result[received].name = scanLenStr(pkt, pos) 1166 | result[received].orig_name = scanLenStr(pkt, pos) 1167 | let extras_len = scanLenInt(pkt, pos) 1168 | if extras_len < 10 or (pos+extras_len > len(pkt)): 1169 | raise newException(ProtocolError, "truncated column packet") 1170 | result[received].charset = int16(scanU16(pkt, pos)) 1171 | result[received].length = scanU32(pkt, pos+2) 1172 | result[received].column_type = FieldType(uint8(pkt[pos+6])) 1173 | result[received].flags = cast[set[FieldFlag]](scanU16(pkt, pos+7)) 1174 | result[received].decimals = int(pkt[pos+9]) 1175 | inc(received) 1176 | let endPacket = await conn.receivePacket() 1177 | if uint8(endPacket[0]) != ResponseCode_EOF: 1178 | raise newException(ProtocolError, "Expected EOF after column defs, got something else") 1179 | 1180 | proc parseTextRow(pkt: string): seq[ResultString] = 1181 | var pos = 0 1182 | result = newSeq[ResultString]() 1183 | while pos < 
len(pkt): 1184 | if pkt[pos] == NullColumn: 1185 | result.add( ResultString(isNull: true) ) 1186 | inc(pos) 1187 | else: 1188 | result.add( ResultString(isNull: false, value: pkt.scanLenStr(pos)) ) 1189 | 1190 | # EOF is signaled by a packet that starts with 0xFE, which is 1191 | # also a valid length-encoded-integer. In order to distinguish 1192 | # between the two cases, we check the length of the packet: EOFs 1193 | # are always short, and an 0xFE in a result row would be followed 1194 | # by at least 65538 bytes of data. 1195 | proc isEOFPacket(pkt: string): bool = 1196 | result = (len(pkt) >= 1) and (pkt[0] == char(ResponseCode_EOF)) and (len(pkt) < 9) 1197 | 1198 | # Error packets are simpler to detect, because 0xFF is not (yet?) 1199 | # valid as the start of a length-encoded-integer. 1200 | proc isERRPacket(pkt: string): bool = (len(pkt) >= 3) and (pkt[0] == char(ResponseCode_ERR)) 1201 | 1202 | proc isOKPacket(pkt: string): bool = (len(pkt) >= 3) and (pkt[0] == char(ResponseCode_OK)) 1203 | 1204 | proc parseErrorPacket(pkt: string): ref ResponseERR not nil = 1205 | new(result) 1206 | result.error_code = scanU16(pkt, 1) 1207 | var pos: int 1208 | if len(pkt) >= 9 and pkt[3] == '#': 1209 | result.sqlstate = pkt.substr(4, 8) 1210 | pos = 9 1211 | else: 1212 | pos = 3 1213 | result.msg = pkt[pos .. 
high(pkt)] 1214 | 1215 | proc parseOKPacket(conn: Connection, pkt: string): ResponseOK = 1216 | result.eof = false 1217 | var pos: int = 1 1218 | result.affected_rows = scanLenInt(pkt, pos) 1219 | result.last_insert_id = scanLenInt(pkt, pos) 1220 | # We always supply Cap.protocol41 in client caps 1221 | result.status_flags = cast[set[Status]]( scanU16(pkt, pos) ) 1222 | result.warning_count = scanU16(pkt, pos+2) 1223 | pos = pos + 4 1224 | if Cap.sessionTrack in conn.client_caps: 1225 | result.info = scanLenStr(pkt, pos) 1226 | else: 1227 | result.info = scanNulStringX(pkt, pos) 1228 | 1229 | proc parseEOFPacket(pkt: string): ResponseOK = 1230 | result.eof = true 1231 | result.warning_count = scanU16(pkt, 1) 1232 | result.status_flags = cast[set[Status]]( scanU16(pkt, 3) ) 1233 | 1234 | proc expectOK(conn: Connection, ctxt: string): Future[ResponseOK] {.async.} = 1235 | let pkt = await conn.receivePacket() 1236 | if isERRPacket(pkt): 1237 | raise parseErrorPacket(pkt) 1238 | elif isOKPacket(pkt): 1239 | return parseOKPacket(conn, pkt) 1240 | else: 1241 | raise newException(ProtocolError, "unexpected response to " & ctxt) 1242 | 1243 | proc prepareStatement*(conn: Connection, query: string): Future[PreparedStatement] {.async.} = 1244 | ## Prepare a statement for future execution. The returned statement handle 1245 | ## must only be used with this connection. This is equivalent to 1246 | ## the `mysql_stmt_prepare()` function in the standard C API. 
1247 | var buf: string = newStringOfCap(4 + 1 + len(query)) 1248 | buf.setLen(4) 1249 | buf.add( char(Command.statementPrepare) ) 1250 | buf.add(query) 1251 | await conn.sendPacket(buf, reset_seq_no=true) 1252 | let pkt = await conn.receivePacket() 1253 | if isERRPacket(pkt): 1254 | raise parseErrorPacket(pkt) 1255 | if pkt[0] != char(ResponseCode_OK) or len(pkt) < 12: 1256 | raise newException(ProtocolError, "Unexpected response to STMT_PREPARE (len=" & $(pkt.len) & ", first byte=0x" & toHex(int(pkt[0]), 2) & ")") 1257 | let num_columns = scanU16(pkt, 5) 1258 | let num_params = scanU16(pkt, 7) 1259 | let num_warnings = scanU16(pkt, 10) 1260 | 1261 | new(result) 1262 | result.warnings = num_warnings 1263 | for b in 0 .. 3: result.statement_id[b] = pkt[1+b] 1264 | if num_params > 0'u16: 1265 | result.parameters = await conn.receiveMetadata(int(num_params)) 1266 | else: 1267 | result.parameters = newSeq[ColumnDefinition](0) 1268 | if num_columns > 0'u16: 1269 | result.columns = await conn.receiveMetadata(int(num_columns)) 1270 | 1271 | proc prepStmtBuf(stmt: PreparedStatement, buf: var string, cmd: Command, cap: int = 9) = 1272 | buf = newStringOfCap(cap) 1273 | buf.setLen(9) 1274 | buf[4] = char(cmd) 1275 | for b in 0..3: buf[b+5] = stmt.statement_id[b] 1276 | 1277 | proc closeStatement*(conn: Connection, stmt: PreparedStatement): Future[void] = 1278 | ## Indicate to the server that this prepared statement is no longer 1279 | ## needed. Note that statement handles are not closed automatically 1280 | ## if garbage-collected, and will continue to occupy a statement 1281 | ## handle on the server side until the connection is closed. 
1282 | var buf: string 1283 | stmt.prepStmtBuf(buf, Command.statementClose) 1284 | return conn.sendPacket(buf, reset_seq_no=true) 1285 | proc resetStatement*(conn: Connection, stmt: PreparedStatement): Future[void] = 1286 | var buf: string 1287 | stmt.prepStmtBuf(buf, Command.statementReset) 1288 | return conn.sendPacket(buf, reset_seq_no=true) 1289 | 1290 | proc formatBoundParams(stmt: PreparedStatement, params: openarray[ParameterBinding]): string = 1291 | if len(params) != len(stmt.parameters): 1292 | raise newException(ValueError, "Wrong number of parameters supplied to prepared statement (got " & $len(params) & ", statement expects " & $len(stmt.parameters) & ")") 1293 | var approx = 14 + ( (params.len + 7) div 8 ) + (params.len * 2) 1294 | for p in params: 1295 | approx += p.approximatePackedSize() 1296 | stmt.prepStmtBuf(result, Command.statementExecute, cap = approx) 1297 | result.putU8(uint8(CursorType.noCursor)) 1298 | result.putU32(1) # "iteration-count" always 1 1299 | if stmt.parameters.len == 0: 1300 | return 1301 | # Compute the null bitmap 1302 | var ch = 0 1303 | for p in 0 .. high(stmt.parameters): 1304 | let bit = p mod 8 1305 | if bit == 0 and p > 0: 1306 | result.add(char(ch)) 1307 | ch = 0 1308 | if params[p].typ == paramNull: 1309 | ch = ch or ( 1 shl bit ) 1310 | result.add(char(ch)) 1311 | result.add(char(1)) # new-params-bound flag 1312 | for p in params: 1313 | p.addTypeUnlessNULL(result) 1314 | for p in params: 1315 | p.addValueUnlessNULL(result) 1316 | 1317 | proc parseBinaryRow(columns: seq[ColumnDefinition], pkt: string): seq[ResultValue] = 1318 | let column_count = columns.len 1319 | let bitmap_len = (column_count + 9) div 8 1320 | if len(pkt) < (1 + bitmap_len) or pkt[0] != char(0): 1321 | raise newException(ProtocolError, "Truncated or incorrect binary result row") 1322 | newSeq(result, column_count) 1323 | var pos = 1 + bitmap_len 1324 | for ix in 0 .. 
column_count-1: 1325 | # First, check whether this column's bit is set in the null 1326 | # bitmap. The bitmap is offset by 2, for no apparent reason. 1327 | let bitmap_index = ix + 2 1328 | let bitmap_entry = uint8(pkt[ 1 + (bitmap_index div 8) ]) 1329 | if (bitmap_entry and uint8(1 shl (bitmap_index mod 8))) != 0'u8: 1330 | # This value is NULL 1331 | result[ix] = ResultValue(typ: rvtNull) 1332 | else: 1333 | let typ = columns[ix].column_type 1334 | let uns = FieldFlag.unsigned in columns[ix].flags 1335 | case typ 1336 | of fieldTypeNull: 1337 | result[ix] = ResultValue(typ: rvtNull) 1338 | of fieldTypeTiny: 1339 | let v = pkt[pos] 1340 | inc(pos) 1341 | let ext = (if uns: int(cast[uint8](v)) else: int(cast[int8](v))) 1342 | result[ix] = ResultValue(typ: rvtInteger, intVal: ext) 1343 | of fieldTypeShort, fieldTypeYear: 1344 | let v = int(scanU16(pkt, pos)) 1345 | inc(pos, 2) 1346 | let ext = (if uns or (v <= 32767): v else: 65536 - v) 1347 | result[ix] = ResultValue(typ: rvtInteger, intVal: ext) 1348 | of fieldTypeInt24, fieldTypeLong: 1349 | let v = scanU32(pkt, pos) 1350 | inc(pos, 4) 1351 | var ext: int 1352 | if not uns and (typ == fieldTypeInt24) and v >= 8388608'u32: 1353 | ext = 16777216 - int(v) 1354 | elif not uns and (typ == fieldTypeLong): 1355 | ext = int( cast[int32](v) ) # rely on 2's-complement reinterpretation here 1356 | else: 1357 | ext = int(v) 1358 | result[ix] = ResultValue(typ: rvtInteger, intVal: ext) 1359 | of fieldTypeLongLong: 1360 | let v = scanU64(pkt, pos) 1361 | inc(pos, 8) 1362 | if uns: 1363 | result[ix] = ResultValue(typ: rvtULong, uLongVal: v) 1364 | else: 1365 | result[ix] = ResultValue(typ: rvtLong, longVal: cast[int64](v)) 1366 | of fieldTypeFloat: 1367 | result[ix] = ResultValue(typ: rvtFloat32, 1368 | floatVal: scanIEEE754Single(pkt, pos)) 1369 | inc(pos, 4) 1370 | of fieldTypeDouble: 1371 | result[ix] = ResultValue(typ: rvtFloat64, 1372 | doubleVal: scanIEEE754Double(pkt, pos)) 1373 | inc(pos, 8) 1374 | of fieldTypeTime, 
fieldTypeDate, fieldTypeDateTime, fieldTypeTimestamp: 1375 | raise newException(Exception, "Not implemented, TODO") 1376 | of fieldTypeTinyBlob, fieldTypeMediumBlob, fieldTypeLongBlob, fieldTypeBlob, fieldTypeBit: 1377 | result[ix] = ResultValue(typ: rvtBlob, strVal: scanLenStr(pkt, pos)) 1378 | of fieldTypeVarchar, fieldTypeVarString, fieldTypeString, fieldTypeDecimal, fieldTypeNewDecimal: 1379 | result[ix] = ResultValue(typ: rvtString, strVal: scanLenStr(pkt, pos)) 1380 | of fieldTypeEnum, fieldTypeSet, fieldTypeGeometry: 1381 | raise newException(ProtocolError, "Unexpected field type " & $(typ) & " in resultset") 1382 | 1383 | proc finishEstablishingConnection(conn: Connection, database: string): Future[void] {.async.} = 1384 | # await confirmation from the server 1385 | let pkt = await conn.receivePacket() 1386 | if isOKPacket(pkt): 1387 | discard 1388 | elif isERRPacket(pkt): 1389 | raise parseErrorPacket(pkt) 1390 | else: 1391 | raise newException(ProtocolError, "Unexpected packet received after sending client handshake") 1392 | 1393 | # Normally we bundle the initial database selection into the 1394 | # connection setup exchange, but if we couldn't do that, then do it 1395 | # here. 1396 | if len(database) > 0 and Cap.connectWithDb notin conn.client_caps: 1397 | discard await conn.selectDatabase(database) 1398 | 1399 | when declared(SslContext) and defined(ssl): 1400 | proc establishConnection*(sock: AsyncSocket not nil, username: string, password: string, database: string = "", sslHostname: string, ssl: SslContext): Future[Connection] {.async.} = 1401 | ## Establish a connection, requesting SSL (TLS). The `sslHostname` and 1402 | ## `ssl` parameters are as used by `asyncnet.wrapConnectedSocket`. 
    # Both the SSL context and the socket must be non-nil.
    if isNil(ssl):
      raise newException(ValueError, "nil SSL context")
    if isNil(sock):
      raise newException(ValueError, "nil socket")
    else:
      result = Connection(socket: sock)
      # The server speaks first: read its initial greeting packet.
      let pkt = await result.receivePacket()
      var response = computeHandshakeResponse(result, pkt,
                                              username, password, database,
                                              starttls = true)

      # MySQL's equivalent of STARTTLS: we send a sort of stub response
      # here, which is a prefix of the real response just containing our
      # client caps flags, then do SSL setup, and send the entire response
      # over the encrypted connection.
      # The 36 bytes are the 4-byte packet header plus the fixed-length
      # part of the handshake response: caps, max packet size, charset
      # and the 23-byte filler.
      var stub: string = response[0 ..< 36]
      await result.sendPacket(stub)

      # The server will respond with the SSL SERVER_HELLO packet.
      # wrapConnectedSocket switches the already-connected socket to TLS,
      # acting as the client (see asyncnet.wrapConnectedSocket for when
      # the TLS handshake itself takes place).
      wrapConnectedSocket(ssl, result.socket,
                          handshake = handshakeAsClient,
                          hostname = sslHostname)
      # and, once the encryption is negotiated, we will continue
      # with the real handshake response.
      # sendPacket overwrites the header bytes with the next sequence number.
      await result.sendPacket(response)

      # And finish the handshake
      await result.finishEstablishingConnection(database)

proc establishConnection*(sock: AsyncSocket not nil, username: string, password: string, database: string = ""): Future[Connection] {.async.} =
  ## Establish a database session. The caller is responsible for setting up
  ## the underlying socket, which will be adopted by the returned `Connection`
  ## instance and closed when the connection is closed.
  ##
  ## If `password` is non-empty, password authentication is performed
  ## (it is not possible to perform password authentication with a zero-length
  ## password using this library). If `database` is non-empty, the named
  ## database will be selected.
1441 | if isNil(sock): 1442 | raise newException(ValueError, "nil socket") 1443 | else: 1444 | result = Connection(socket: sock) 1445 | let pkt = await result.receivePacket() 1446 | var response = computeHandshakeResponse(result, pkt, 1447 | username, password, database, 1448 | starttls = false) 1449 | await result.sendPacket(response) 1450 | await result.finishEstablishingConnection(database) 1451 | 1452 | proc textQuery*(conn: Connection, query: string): Future[ResultSet[ResultString]] {.async.} = 1453 | ## Perform a query using the text protocol, returning a single result set. 1454 | await conn.sendQuery(query) 1455 | let pkt = await conn.receivePacket() 1456 | if isOKPacket(pkt): 1457 | # Success, but no rows returned. 1458 | result.status = parseOKPacket(conn, pkt) 1459 | result.columns = @[] 1460 | result.rows = @[] 1461 | elif isERRPacket(pkt): 1462 | # Some kind of failure. 1463 | raise parseErrorPacket(pkt) 1464 | else: 1465 | var p = 0 1466 | let column_count = scanLenInt(pkt, p) 1467 | result.columns = await conn.receiveMetadata(column_count) 1468 | var rows: seq[seq[ResultString]] 1469 | newSeq(rows, 0) 1470 | while true: 1471 | let pkt = await conn.receivePacket() 1472 | if isEOFPacket(pkt): 1473 | result.status = parseEOFPacket(pkt) 1474 | break 1475 | elif isOKPacket(pkt): 1476 | result.status = parseOKPacket(conn, pkt) 1477 | break 1478 | elif isERRPacket(pkt): 1479 | raise parseErrorPacket(pkt) 1480 | else: 1481 | rows.add(parseTextRow(pkt)) 1482 | result.rows = rows 1483 | return 1484 | 1485 | proc performPreparedQuery(conn: Connection, stmt: PreparedStatement, st: Future[void]): Future[ResultSet[ResultValue]] {.async.} = 1486 | await st 1487 | let initialPacket = await conn.receivePacket() 1488 | if isOKPacket(initialPacket): 1489 | # Success, but no rows returned. 1490 | result.status = parseOKPacket(conn, initialPacket) 1491 | result.columns = @[] 1492 | result.rows = @[] 1493 | elif isERRPacket(initialPacket): 1494 | # Some kind of failure. 
1495 | raise parseErrorPacket(initialPacket) 1496 | else: 1497 | var p = 0 1498 | let column_count = scanLenInt(initialPacket, p) 1499 | result.columns = await conn.receiveMetadata(column_count) 1500 | var rows: seq[seq[ResultValue]] 1501 | newSeq(rows, 0) 1502 | while true: 1503 | let pkt = await conn.receivePacket() 1504 | # hexdump(pkt, stdmsg) 1505 | if isEOFPacket(pkt): 1506 | result.status = parseEOFPacket(pkt) 1507 | break 1508 | elif isERRPacket(pkt): 1509 | raise parseErrorPacket(pkt) 1510 | else: 1511 | rows.add(parseBinaryRow(result.columns, pkt)) 1512 | result.rows = rows 1513 | 1514 | proc preparedQuery*(conn: Connection, stmt: PreparedStatement, params: varargs[ParameterBinding, asParam]): Future[ResultSet[ResultValue]] = 1515 | ## Perform a query using the binary (prepared-statement) protocol, 1516 | ## returning a single result set. 1517 | var pkt = formatBoundParams(stmt, params) 1518 | var sent = conn.sendPacket(pkt, reset_seq_no=true) 1519 | return performPreparedQuery(conn, stmt, sent) 1520 | 1521 | proc selectDatabase*(conn: Connection, database: string): Future[ResponseOK] {.async.} = 1522 | ## Select a database. 1523 | ## This is equivalent to the `mysql_select_db()` function in the 1524 | ## standard C API. 1525 | var buf: string = newStringOfCap(4 + 1 + len(database)) 1526 | buf.setLen(4) 1527 | buf.add( char(Command.initDb) ) 1528 | buf.add(database) 1529 | await conn.sendPacket(buf, reset_seq_no=true) 1530 | return await conn.expectOK("COM_INIT_DB") 1531 | 1532 | proc ping*(conn: Connection): Future[ResponseOK] {.async.} = 1533 | ## Send a ping packet to the server to check for liveness. 1534 | ## This is equivalent to the `mysql_ping()` function in the 1535 | ## standard C API. 1536 | await conn.sendCommand(Command.ping) 1537 | return await conn.expectOK("COM_PING") 1538 | 1539 | proc close*(conn: Connection): Future[void] {.async.} = 1540 | ## Close the connection to the database, including the underlying socket. 
1541 | await conn.sendCommand(Command.quiT) 1542 | let pkt = await conn.receivePacket(drop_ok=true) 1543 | conn.socket.close() 1544 | 1545 | # ###################################################################### 1546 | # 1547 | # Internal tests 1548 | # These don't try to test everything, just basic things and things 1549 | # that won't be exercised by functional testing against a server 1550 | 1551 | 1552 | when isMainModule: 1553 | proc hexstr(s: string): string = 1554 | result = "" 1555 | let chs = "0123456789abcdef" 1556 | for ch in s: 1557 | let i = int(ch) 1558 | result.add(chs[ (i and 0xF0) shr 4]) 1559 | result.add(chs[ i and 0x0F ]) 1560 | 1561 | test "Parameter packing": 1562 | let dummy_param = ColumnDefinition() 1563 | var sth: PreparedStatement 1564 | new(sth) 1565 | sth.statement_id = ['\0', '\xFF', '\xAA', '\x55' ] 1566 | sth.parameters = @[dummy_param, dummy_param, dummy_param, dummy_param, dummy_param, dummy_param, dummy_param, dummy_param] 1567 | 1568 | # Small numbers 1569 | let buf = formatBoundParams(sth, [ asParam(0), asParam(1), asParam(127), asParam(128), asParam(255), asParam(256), asParam(-1), asParam(-127) ]) 1570 | let h = "000000001700ffaa5500010000000001" & # packet header 1571 | "01800180018001800180028001000100" & # wire type info 1572 | "00017f80ff0001ff81" # packed values 1573 | check h == hexstr(buf) 1574 | 1575 | # Numbers and NULLs 1576 | sth.parameters = sth.parameters & dummy_param 1577 | let buf2 = formatBoundParams(sth, [ asParam(-128), asParam(-129), asParam(-255), asParam(nil), asParam(SQLNULL), asParam(-256), asParam(-257), asParam(-32768), asParam(SQLNULL) ]) 1578 | let h2 = "000000001700ffaa550001000000180101" & # packet header 1579 | "010002000200020002000200" & # wire type info 1580 | "807fff01ff00fffffe0080" # packed values 1581 | check h2 == hexstr(buf2) 1582 | 1583 | # More values (strings, etc) 1584 | let buf3 = formatBoundParams(sth, [ asParam("hello"), asParam(SQLNULL), 1585 | asParam(0xFFFF), 
asParam(0xF1F2F3), asParam(0xFFFFFFFF), asParam(0xFFFFFFFFFF), 1586 | asParam(-12885), asParam(-2160069290), asParam(low(int64) + 512) ]) 1587 | let h3 = "000000001700ffaa550001000000020001" & # packet header 1588 | "fe000280038003800880020008000800" & # wire type info 1589 | "0568656c6c6ffffff3f2f100ffffffffffffffffff000000abcd56f53f7fffffffff0002000000000080" 1590 | check h3 == hexstr(buf3) 1591 | 1592 | # Floats and doubles 1593 | const e32: float32 = 0.00000011920928955078125'f32 1594 | let buf4 = formatBoundParams(sth, [ 1595 | asParam(0'f32), asParam(65535'f32), 1596 | asParam(e32), asParam(1 + e32), 1597 | asParam(0'f64), asParam(-1'f64), 1598 | asParam(float64(e32)), asParam(1 + float64(e32)), asParam(1024 + float64(e32)) ]) 1599 | let h4 = "000000001700ffaa550001000000000001" & # packet header 1600 | "040004000400040005000500050005000500" & # wire type info 1601 | "0000000000ff7f47000000340100803f" & # floats 1602 | "0000000000000000000000000000f0bf" & # doubles 1603 | "000000000000803e000000200000f03f0000080000009040" 1604 | check h4 == hexstr(buf4) 1605 |
--------------------------------------------------------------------------------
/asyncmysql.nimble:
--------------------------------------------------------------------------------
1 | # Package
2 | version = "0.2.2"
3 | author = "Wim Lewis"
4 | description = "Nonblocking pure-Nim mysql client module."
5 | license = "MIT"
6 |
7 | requires "nim >= 1.0.2"
8 |
--------------------------------------------------------------------------------
/demo.nim:
--------------------------------------------------------------------------------
1 | {.experimental: "notnil".}
2 |
3 | import asyncmysql, asyncdispatch, asyncnet
4 | from nativesockets import AF_INET, SOCK_STREAM
5 |
6 | import net
7 |
8 | proc printResultSet[T](resultSet: ResultSet[T]) =
9 |   ## Writes `resultSet` to `stdmsg`: per-column metadata (when there are
10 |   ## columns), each row as `name = value` / `name is NULL` lines, then the status fields.
11 |   if resultSet.columns.len > 0:
12 |     for ix in low(resultSet.columns) .. high(resultSet.columns):
13 |       stdmsg.writeLine("Column ", ix, " - ", $(resultSet.columns[ix].column_type))
14 |       stdmsg.writeLine(" Name: ", resultSet.columns[ix].name)
15 |       stdmsg.writeLine(" orig: ", resultSet.columns[ix].catalog, ".", resultSet.columns[ix].schema, ".", resultSet.columns[ix].orig_table, ".", resultSet.columns[ix].orig_name)
16 |       stdmsg.writeLine(" length=", int(resultSet.columns[ix].length))
17 |     stdmsg.writeLine("")
18 |   for row in resultSet.rows:
19 |     for ix in low(row)..high(row):
20 |       stdmsg.write(resultSet.columns[ix].name)
21 |       # NULL values (isNil) are reported separately from ordinary values.
22 |       if isNil(row[ix]):
23 |         stdmsg.writeLine(" is NULL")
24 |       else:
25 |         stdmsg.writeLine(" = ", row[ix])
26 |     stdmsg.writeLine("")
27 |   stdmsg.writeLine(resultSet.status.affected_rows, " rows affected")
28 |   stdmsg.writeLine("last_insert_id = ", resultSet.status.last_insert_id)
29 |   stdmsg.writeLine(resultSet.status.warning_count, " warnings")
30 |   stdmsg.writeLine("status: ", $(resultSet.status.status_flags))
31 |   if len(resultSet.status.info) > 0:
32 |     stdmsg.writeLine("Info: ", resultSet.status.info)
33 |
34 | proc demoTextQuery(conn: Connection, query: string) {.async.} =
35 |   ## Runs `query` through the text protocol and prints the result set.
36 |   let res = await conn.textQuery(query)
37 |   printResultSet(res)
38 |
39 | proc demoPreparedStatement(conn: Connection) {.async.} =
40 |   ## Prepares a two-placeholder statement, executes it once with bound
41 |   ## parameters, prints the result set and closes the statement.
42 |   let stmt = await conn.prepareStatement("select *, ( ? + 1 ) from user u where u.user = ?")
43 |   let rslt = await conn.preparedQuery(stmt, 42, "root")
44 |   printResultSet(rslt)
45 |   await conn.closeStatement(stmt)
46 |
47 | proc blah() {. async .} =
48 |   ## Connects to db4free.net:3306 as user "test" (database "testdb"), using TLS
49 |   ## with peer verification when built with -d:ssl, then runs a few demo queries.
50 |   let sockn = newAsyncSocket(AF_INET, SOCK_STREAM)
51 |   # `sock` is declared `not nil`, so the nil case of `sockn` is ruled out before assignment.
52 |   var sock: AsyncSocket not nil
53 |   if sockn.isNil:
54 |     raise newException(ValueError, "nil socket")
55 |   else:
56 |     sock = sockn
57 |   await connect(sock, "db4free.net", Port(3306))
58 |   stdmsg.writeLine("(socket connection established)")
59 |   when defined(ssl):
60 |     let conn = await establishConnection(sock, "test", database = "testdb", password = "test_pass", sslHostname = "db4free.net", ssl = newContext(verifyMode = CVerifyPeer))
61 |   else:
62 |     let conn = await establishConnection(sock, "test", database = "testdb", password = "test_pass")
63 |   stdmsg.writeLine("(mysql session established)")
64 |   await conn.demoTextQuery("select * from mysql.user")
65 |   await conn.demoPreparedStatement()
66 |   #await conn.demoTextQuery("show session variables like '%ssl%'");
67 |   await conn.demoTextQuery("show session variables like '%version%'");
68 |
69 | proc foof() =
70 |   ## Synchronous entry point: starts the demo and blocks until it completes.
71 |   let fut = blah()
72 |   stdmsg.writeLine("starting loop")
73 |   waitFor(fut)
74 |   stdmsg.writeLine("done")
75 |
76 | foof()
77 |
--------------------------------------------------------------------------------
/test_sql.nim:
--------------------------------------------------------------------------------
1 | import asyncmysql, asyncdispatch, asyncnet, os, parseutils
2 | from nativesockets import AF_INET, SOCK_STREAM
3 |
4 | import net
5 | import strutils
6 |
7 | var database_name: string
8 | var port: int = 3306
9 | var host_name: string = "localhost"
10 | var user_name: string
11 | var pass_word: string
12 | var ssl: bool = false
13 | var allow_mitm: bool = false
14 | var verbose: bool = false
15 |
16 | # TLS defaults to on in builds compiled with -d:ssl; --no-ssl turns it off.
17 | when defined(ssl):
18 |   ssl = true
19 |
20 | proc doTCPConnect(dbn: string = ""): Future[Connection] {.async.} =
21 |   ## Opens a TCP connection to `host_name`:`port` and performs the MySQL
22 |   ## handshake as `user_name` with initial database `dbn` (connTest relies on
23 |   ## an empty `dbn` meaning "no database selected"). When `ssl` is set the
24 |   ## session uses TLS; raises CatchableError if built without -d:ssl.
25 |   let sock = newAsyncSocket(AF_INET, SOCK_STREAM)
26 |   # Validate the socket before it is used, not after connecting with it.
27 |   if sock.isNil:
28 |     raise newException(ValueError, "nil socket")
29 |   await connect(sock, host_name, Port(port))
30 |   if ssl:
31 |     when defined(ssl):
32 |       # --allow-mitm switches certificate verification off.
33 |       let ctx = newContext(verifyMode = (if allow_mitm: CVerifyNone else: CVerifyPeer))
34 |       return await establishConnection(sock, user_name, database=dbn, password = pass_word, sslHostname = host_name, ssl=ctx)
35 |     else:
36 |       raise newException(CatchableError, "ssl is not enabled in this build")
37 |   return await establishConnection(sock, user_name, database=dbn, password = pass_word)
38 |
39 | proc getCurrentDatabase(conn: Connection): Future[ResultString] {.async.} =
40 |   ## Returns the single value of `select database()` for this session
41 |   ## (NULL when no database is selected).
42 |   let rslt = await conn.textQuery("select database()")
43 |   doAssert(len(rslt.columns) == 1, "wrong number of result columns")
44 |   doAssert(len(rslt.rows) == 1, "wrong number of result rows")
45 |   return rslt.rows[0][0]
46 |
47 | proc checkCurrentCipher(conn: Connection): Future[bool] {.async.} =
48 |   ## Prints the session's Ssl_cipher status variable and returns true when
49 |   ## its value is neither NULL nor empty, i.e. a cipher is in use.
50 |   let rslt = await conn.textQuery("show session status like 'Ssl_cipher'")
51 |   doAssert(len(rslt.columns) == 2, "wrong number of result columns")
52 |   doAssert(len(rslt.rows) == 1, "wrong number of result rows")
53 |   echo " ", rslt.rows[0][0], " = ", rslt.rows[0][1]
54 |   let ssl_cipher = rslt.rows[0][1]
55 |   return not (ssl_cipher.isNil or ssl_cipher == "")
56 |
57 | proc connTest(): Future[Connection] {.async.} =
58 |   ## Exercises connection setup (initial database, selectDatabase, thread ids
59 |   ## in the process list, TLS state, ping), closes the second connection and
60 |   ## returns the first for the remaining tests.
61 |   echo "Connecting (with initial db: ", database_name, ")"
62 |   let conn1 = await doTCPConnect(dbn = database_name)
63 |   echo "Checking current database is correct"
64 |   let conn1db1 = await getCurrentDatabase(conn1)
65 |   # A wrong database is a test failure: raise so the run exits non-zero.
66 |   if conn1db1 != database_name:
67 |     raiseAssert("FAIL (actual db: " & $conn1db1 & ")")
68 |   echo "Connecting (without initial db)"
69 |   let conn2 = await doTCPConnect()
70 |   let conn2db1 = await getCurrentDatabase(conn2)
71 |   if not isNil(conn2db1):
72 |     raiseAssert("FAIL (db should be NULL, is: " & $conn2db1 & ")")
73 |   discard await conn2.selectDatabase(database_name)
74 |   let conn2db2 = await getCurrentDatabase(conn2)
75 |   if conn2db2 != database_name:
76 |     raiseAssert("FAIL (db should be: " & database_name & " is: " & $conn2db2 & ")")
77 |   echo "Checking TIDs (", conn1.thread_id, ", ", conn2.thread_id, ")"
78 |   let rslt = await conn1.textQuery("show processlist");
79 |   var saw_conn1 = false
80 |   var saw_conn2 = false
81 |   # Column 0 of each process-list row is compared with each connection's thread id.
82 |   for row in rslt.rows:
83 |     if row[0] == $(conn1.thread_id):
84 |       doAssert(saw_conn1 == false, "Multiple rows with conn1's TID")
85 |       saw_conn1 = true
86 |     if row[0] == $(conn2.thread_id):
87 |       doAssert(saw_conn2 == false, "Multiple rows with conn2's TID")
88 |       saw_conn2 = true
89 |   doAssert(saw_conn1, "Didn't see conn1's TID")
90 |   doAssert(saw_conn2, "Didn't see conn2's TID")
91 |   # Start both cipher checks (and later both pings) before awaiting them together.
92 |   let ssl1 = conn1.checkCurrentCipher()
93 |   let ssl2 = conn2.checkCurrentCipher()
94 |   await `and`(ssl1, ssl2)
95 |   doAssert(ssl1.read() == ssl)
96 |   doAssert(ssl2.read() == ssl)
97 |   let p1 = conn1.ping()
98 |   let p2 = conn2.ping()
99 |   await `and`(p1, p2)
100 |   echo "Closing second connection"
101 |   await conn2.close()
102 |   return conn1
103 |
104 | template assertEq(T: typedesc, got: untyped, expect: untyped, msg: string = "incorrect value") =
105 |   ## Binds `got` to a value of type `T` and raises an assertion, reported at
106 |   ## the caller's line, if it differs from `expect`.
107 |   let aa: T = got
108 |   bind instantiationInfo
109 |   {.line: instantiationInfo().}:
110 |     if aa != expect:
111 |       raiseAssert("assertEq(" & astToStr(got) & ", " & astToStr(expect) & ") failed (got " & repr(aa) & "): " & msg)
112 |
113 | template assertEqrs(got: untyped, expect: varargs[ResultString, asResultString]) =
114 |   ## Asserts that row `got` holds exactly the values `expect`, column by column.
115 |   bind instantiationInfo
116 |   let aa: seq[ResultString] = got
117 |   let count = aa.len
118 |   {.line: instantiationInfo().}:
119 |     if count != expect.len:
120 |       raiseAssert(format("assertEqrs($1, ...) failed (got $2 columns, expected $3)", astToStr(got), count, expect.len))
121 |     for col in 0 .. high(aa):
122 |       if aa[col] != expect[col]:
123 |         raiseAssert(format("assertEqrs($1, $2) failed (mismatch at index $3)", astToStr(got), expect, col))
124 |
125 | proc numberTests(conn: Connection): Future[void] {.async.} =
126 |   ## Round-trips integer values through a scratch table, inserting with the
127 |   ## binary protocol and reading back with both the text and binary protocols.
128 |   echo "Setting up table for numeric tests..."
129 |   discard await conn.textQuery("drop table if exists num_tests")
130 |   discard await conn.textQuery("create table num_tests (s text, u8 tinyint unsigned, s8 tinyint, u int unsigned, i int, b bigint)")
131 |
132 |   echo "Testing numeric parameters"
133 |   # Insert values using the binary protocol
134 |   let insrow = await conn.prepareStatement("insert into `num_tests` (s, u8, s8, u, i, b) values (?, ?, ?, ?, ?, ?)")
135 |   discard await conn.preparedQuery(insrow, "one", 1, 1, 1, 1, 1)
136 |   discard await conn.preparedQuery(insrow, "max", 255, 127, 4294967295, 2147483647, 9223372036854775807'u64)
137 |   discard await conn.preparedQuery(insrow, "min", 0, -128, 0, -2147483648, (-9223372036854775807'i64 - 1))
138 |   discard await conn.preparedQuery(insrow, "foo", 128, -127, 256, -32767, -32768)
139 |   discard await conn.preparedQuery(insrow, "feh", 130'f32, -128'f64, 256.1'f32,
140 |                                    -2100000000'f32, 2147483649.0125'f64)
141 |   await conn.closeStatement(insrow)
142 |
143 |   # Read them back using the text protocol
144 |   let r1 = await conn.textQuery("select s, u8, s8, u, i, b from num_tests order by u8 asc")
145 |   assertEq(int, r1.columns.len(), 6, "column count")
146 |   assertEq(int, r1.rows.len(), 5, "row count")
147 |   assertEq(string, r1.columns[0].name, "s")
148 |   assertEq(string, r1.columns[5].name, "b")
149 |
150 |   assertEqrs(r1.rows[0], "min", "0", "-128", "0", "-2147483648", "-9223372036854775808")
151 |   assertEqrs(r1.rows[1], "one", "1", "1", "1", "1", "1")
152 |   assertEqrs(r1.rows[2], "foo", "128", "-127", "256", "-32767", "-32768")
153 |   assertEqrs(r1.rows[3], "feh", "130", "-128", "256", "-2100000000", "2147483649")
154 |   assertEqrs(r1.rows[4], "max", "255", "127", "4294967295", "2147483647", "9223372036854775807")
155 |
156 |   # Now read them back using the binary protocol
157 |   echo "Testing numeric results"
158 |   let rdtab = await conn.prepareStatement("select b, i, u, s, u8, s8 from num_tests order by i desc")
159 |   let r2 = await conn.preparedQuery(rdtab)
160 |   assertEq(int, r2.columns.len(), 6, "column count")
161 |   assertEq(int, r2.rows.len(), 5, "row count")
162 |   assertEq(string, r2.columns[0].name, "b")
163 |   assertEq(string, r2.columns[5].name, "s8")
164 |
165 |   assertEq(int64, r2.rows[0][0], 9223372036854775807'i64)
166 |   assertEq(uint64, r2.rows[0][0], 9223372036854775807'u64)
167 |   assertEq(int64, r2.rows[0][1], 2147483647'i64)
168 |   assertEq(uint64, r2.rows[0][1], 2147483647'u64)
169 |   assertEq(int, r2.rows[0][1], 2147483647)
170 |   assertEq(uint, r2.rows[0][1], 2147483647'u)
171 |   assertEq(uint, r2.rows[0][2], 4294967295'u)
172 |   assertEq(int64, r2.rows[0][2], 4294967295'i64)
173 |   assertEq(uint64, r2.rows[0][2], 4294967295'u64)
174 |   assertEq(string, r2.rows[0][3], "max")
175 |   assertEq(int, r2.rows[0][4], 255)
176 |   assertEq(int, r2.rows[0][5], 127)
177 |
178 |   assertEq(int, r2.rows[1][1], 1)
179 |   assertEq(string, r2.rows[1][3], "one")
180 |
181 |   assertEq(int, r2.rows[2][0], -32768)
182 |   assertEq(int64, r2.rows[2][0], -32768'i64)
183 |   assertEq(int, r2.rows[2][1], -32767)
184 |   assertEq(int64, r2.rows[2][1], -32767'i64)
185 |   assertEq(int, r2.rows[2][2], 256)
186 |   assertEq(string, r2.rows[2][3], "foo")
187 |   assertEq(int, r2.rows[2][4], 128)
188 |   assertEq(int, r2.rows[2][5], -127)
189 |   assertEq(int64, r2.rows[2][5], -127'i64)
190 |
191 |   assertEq(int64, r2.rows[3][0], 2147483649)
192 |   assertEq(uint, r2.rows[3][2], 256)
193 |
194 |   assertEq(int64, r2.rows[4][0], ( -9223372036854775807'i64 - 1 ))
195 |   assertEq(int, r2.rows[4][1], -2147483648)
196 |   assertEq(int, r2.rows[4][4], 0)
197 |   assertEq(int64, r2.rows[4][4], 0'i64)
198 |
199 |   await conn.closeStatement(rdtab)
200 |   discard await conn.textQuery("drop table `num_tests`")
201 |
202 | proc floatTests(conn: Connection): Future[void] {.async.} =
203 |   ## Round-trips FLOAT/DOUBLE values through a scratch table and checks sums
204 |   ## computed by the server, read back through the binary protocol.
205 |   echo "Setting up table for float tests..."
206 |   discard await conn.textQuery("drop table if exists float_tests")
207 |   discard await conn.textQuery("create table float_tests (s text, a FLOAT, b DOUBLE)")
208 |
209 |   echo "Inserting float values"
210 |   # Insert values using the binary protocol
211 |   let insrow = await conn.prepareStatement("insert into `float_tests` (s, a, b) values (?, ?, ?)")
212 |   discard await conn.preparedQuery(insrow, "one", int8(1), 1'f32)
213 |   discard await conn.preparedQuery(insrow, "thou", 0.001'f32, 0.001'f64)
214 |   discard await conn.preparedQuery(insrow, "many", 524288'f64, 1073741824'f32) #swapped
215 |   await conn.closeStatement(insrow)
216 |
217 |   # Read them back using the text protocol
218 |   let r1 = await conn.textQuery("select s, a, b from float_tests order by a asc")
219 |   assertEq(int, r1.columns.len(), 3, "column count")
220 |   assertEq(int, r1.rows.len(), 3, "row count")
221 |   assertEq(string, r1.columns[0].name, "s")
222 |   assertEq(string, r1.columns[1].name, "a")
223 |
224 |   assertEqrs(r1.rows[0], "thou", "0.001", "0.001")
225 |   assertEqrs(r1.rows[1], "one", "1", "1")
226 |   assertEqrs(r1.rows[2], "many", "524288", "1073741824")
227 |
228 |   # Now read them back using the binary protocol
229 |   echo "Reading float values"
230 |   let rdtab = await conn.prepareStatement("select s, a, b from float_tests order by a desc")
231 |   let rdcross = await conn.prepareStatement("select CONCAT(x.s, '+', y.s) as v, x.a + y.a, x.b + y.b from float_tests x, float_tests y where x.s <= y.s order by v")
232 |
233 |   let r2 = await conn.preparedQuery(rdtab)
234 |   assertEq(int, r2.rows.len(), 3, "row count")
235 |
236 |   doAssert(r2.rows[0][1] == 524288'i32)
237 |   doAssert(r2.rows[0][1] == 524288'i64)
238 |   doAssert(r2.rows[0][1] == 524288'f32)
239 |   doAssert(r2.rows[0][1] == 524288'f64)
240 |   doAssert(r2.rows[0][2] == 1073741824'i32)
241 |   doAssert(r2.rows[0][2] == 1073741824'i64)
242 |   doAssert(r2.rows[0][2] == 1073741824'f32)
243 |   doAssert(r2.rows[0][2] == 1073741824'f64)
244 |
245 |   # echo r2.rows[1]
246 |   doAssert(r2.rows[1][1] == 1'u)
247 |   doAssert(r2.rows[1][1] == 1'f32)
248 |   doAssert(r2.rows[1][1] == 1'f64)
249 |   doAssert(r2.rows[1][2] == 1'u)
250 |   doAssert(r2.rows[1][2] == 1'f32)
251 |   doAssert(r2.rows[1][2] == 1'f64)
252 |
253 |   let r3 = await conn.preparedQuery(rdcross)
254 |   assertEq(int, r3.rows.len(), 6, "row count")
255 |   assertEq(int, r3.columns.len(), 3, "column count")
256 |
257 |   assertEq(string, r3.rows[0][0], "many+many")
258 |   doAssert(r3.rows[0][1] == 1048576'f32)
259 |   doAssert(r3.rows[0][2] == 2147483648'f64)
260 |
261 |   assertEq(string, r3.rows[1][0], "many+one")
262 |   doAssert(r3.rows[1][1] == 524289'f32)
263 |   doAssert(r3.rows[1][2] == 1073741825'f64)
264 |
265 |   assertEq(string, r3.rows[2][0], "many+thou")
266 |   # Note: [2][1] should be 524288 in single-precision, or 524288.001 in double
267 |   doAssert(r3.rows[2][2] == 1073741824.001'f64)
268 |
269 |   assertEq(string, r3.rows[3][0], "one+one")
270 |   # nothing we haven't already tested here
271 |
272 |   assertEq(string, r3.rows[4][0], "one+thou")
273 |   assertEq(float32, r3.rows[4][1], 1.001'f32)
274 |   doAssert(r3.rows[4][1] != 1)
275 |   doAssert(r3.rows[4][2] == 1.001'f64)
276 |
277 |   await conn.closeStatement(rdtab)
278 |   await conn.closeStatement(rdcross)
279 |
280 |   discard await conn.textQuery("drop table `float_tests`")
281 |
282 | proc runTests(): Future[void] {.async.} =
283 |   ## Runs every test group on the connection returned by connTest, then closes it.
284 |   let conn = await connTest()
285 |   await conn.numberTests()
286 |   await conn.floatTests()
287 |   await conn.close()
288 |
289 | proc usage(unopt: string = "") =
290 |   ## Prints command-line help (preceded by the unrecognized argument, if any)
291 |   ## and exits with QuitFailure.
292 |   if unopt.len > 0:
293 |     stdmsg.writeLine("Unrecognized argument: ", unopt)
294 |   echo "Usage:"
295 |   echo paramStr(0), " [--ssl|--no-ssl] [--allow-mitm] [-v] [-D database] [-h host] [-P portnum] [-u username] [-p password]"
296 |   echo "\t-D, --database: Perform tests in specified database. (required)"
297 |   echo "\t-h, --host: Connect to server on host. (default: localhost)"
298 |   echo "\t-P, --port: Connect to specified TCP port (default: 3306)"
299 |   echo "\t-u, --user, --username: Connect as specified username (required)"
300 |   echo "\t-p, --password: Provide the specified password"
301 |   echo "\t--ssl, --no-ssl: Enable ssl/tls (default: on if built with -d:ssl, otherwise cleartext)"
302 |   echo "\t--allow-mitm: Disable security checks for SSL"
303 |   echo "\t-v, --verbose: More verbose output"
304 |   echo "The user must have the ability to create and drop tables in the"
305 |   echo "database, as well as the usual select and insert privileges."
306 |   quit(QuitFailure)
307 |
308 | proc takeValue(ix: var int, opt: string): string =
309 |   ## Returns command-line argument `ix` as the value of option `opt` and
310 |   ## advances `ix`; prints usage and exits if the arguments have run out.
311 |   if ix > os.paramCount():
312 |     stdmsg.writeLine("Missing value for ", opt)
313 |     usage()
314 |   result = os.paramStr(ix)
315 |   inc(ix)
316 |
317 | block:
318 |   ## Nim stdlib's parseopt2 doesn't handle standard argument syntax,
319 |   ## so this is a small hand-written parser for it.
320 |   var ix = 1
321 |   # Every argument is examined (so a trailing flag such as --ssl is accepted);
322 |   # options that need a value consume the next argument through takeValue.
323 |   while ix <= os.paramCount():
324 |     let param = os.paramStr(ix)
325 |     inc(ix)
326 |     case param
327 |     of "--database", "-D":
328 |       database_name = takeValue(ix, param)
329 |     of "--host", "-h":
330 |       host_name = takeValue(ix, param)
331 |     of "--port", "-P":
332 |       let val = takeValue(ix, param)
333 |       if val.len == 0 or parseInt(val, port, 0) != len(val):
334 |         usage()
335 |     of "--user", "--username", "-u":
336 |       user_name = takeValue(ix, param)
337 |     of "--password", "-p":
338 |       pass_word = takeValue(ix, param)
339 |     of "--ssl":
340 |       ssl = true
341 |     of "--no-ssl":
342 |       ssl = false
343 |     of "--allow-mitm":
344 |       allow_mitm = true
345 |     of "-v", "--verbose":
346 |       verbose = true
347 |     else:
348 |       usage(param)
349 |   if database_name.len == 0 or user_name.len == 0 or port < 1 or port > 65535:
350 |     usage()
351 |
352 | waitFor(runTests())
353 | echo "Done"
354 | quit(QuitSuccess)
355 |
--------------------------------------------------------------------------------