├── .gitignore ├── LICENSE ├── README.md ├── all.bash ├── autorc ├── LICENSE ├── autorecon.go └── autorecon_test.go ├── codelingo.yaml ├── doc.go ├── examples ├── database_sql │ └── database_sql.go ├── long_data │ └── long_data.go ├── parallel │ └── parallel.go ├── prepared_stmt │ └── prepared_stmt.go ├── reconnect │ └── reconnect.go ├── simple │ └── simple.go └── transactions │ └── transactions.go ├── go.mod ├── godrv ├── appengine.go ├── driver.go └── driver_test.go ├── mysql ├── errors.go ├── field.go ├── interface.go ├── row.go ├── status.go ├── types.go ├── types_test.go └── utils.go ├── native ├── LICENSE ├── addons.go ├── bind_test.go ├── binding.go ├── codecs.go ├── command.go ├── common.go ├── consts.go ├── init.go ├── mysql.go ├── native_test.go ├── packet.go ├── paramvalue.go ├── passwd.go ├── prepared.go ├── result.go └── unsafe.go-disabled └── thrsafe ├── LICENSE ├── thrsafe.go └── thrsafe_test.go /.gitignore: -------------------------------------------------------------------------------- 1 | *.6 2 | *.8 3 | *.a 4 | *.o 5 | *.so 6 | *.out 7 | *.go~ 8 | _obj 9 | _testmain.go 10 | _go_.6 11 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2010, Michal Derkacz 2 | All rights reserved. 3 | 4 | Redistribution and use in source and binary forms, with or without 5 | modification, are permitted provided that the following conditions 6 | are met: 7 | 1. Redistributions of source code must retain the above copyright 8 | notice, this list of conditions and the following disclaimer. 9 | 2. Redistributions in binary form must reproduce the above copyright 10 | notice, this list of conditions and the following disclaimer in the 11 | documentation and/or other materials provided with the distribution. 12 | 3. 
The name of the author may not be used to endorse or promote products
derived from this software without specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR
IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT,
INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT
NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF
THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

--------------------------------------------------------------------------------
/all.bash:
--------------------------------------------------------------------------------
#!/usr/bin/env bash
# Runs one go subcommand over every mymysql library package, e.g.
#   ./all.bash test
#   ./all.bash install -v
p=github.com/ziutek/mymysql

# "$@" forwards every argument exactly as given (no re-splitting/globbing).
go "$@" "$p/mysql" "$p/native" "$p/thrsafe" "$p/autorc" "$p/godrv"

--------------------------------------------------------------------------------
/autorc/LICENSE:
--------------------------------------------------------------------------------
Copyright (c) 2010, Michal Derkacz
All rights reserved.

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
1. Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
3.
The name of the author may not be used to endorse or promote products 13 | derived from this software without specific prior written permission. 14 | 15 | THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR 16 | IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES 17 | OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 18 | IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT, 19 | INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT 20 | NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 21 | DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 22 | THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 23 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF 24 | THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 25 | -------------------------------------------------------------------------------- /autorc/autorecon.go: -------------------------------------------------------------------------------- 1 | // Package autorc provides an auto reconnect interface for MyMySQL. 2 | package autorc 3 | 4 | import ( 5 | "io" 6 | "log" 7 | "net" 8 | "time" 9 | 10 | "github.com/ziutek/mymysql/mysql" 11 | ) 12 | 13 | // IsNetErr returns true if error is network error or UnexpectedEOF. 14 | func IsNetErr(err error) bool { 15 | if err == io.ErrUnexpectedEOF { 16 | return true 17 | } 18 | if _, ok := err.(net.Error); ok { 19 | return true 20 | } 21 | if mysqlError, ok := err.(mysql.Error); ok { 22 | switch mysqlError.Code { 23 | case mysql.ER_QUERY_INTERRUPTED: 24 | return true 25 | case mysql.ER_NET_READ_ERROR: 26 | return true 27 | case mysql.ER_NET_READ_INTERRUPTED: 28 | return true 29 | case mysql.ER_NET_ERROR_ON_WRITE: 30 | return true 31 | case mysql.ER_NET_WRITE_INTERRUPTED: 32 | return true 33 | } 34 | } 35 | return false 36 | } 37 | 38 | // Conn is an autoreconnecting connection type. 
39 | type Conn struct { 40 | Raw mysql.Conn 41 | // Maximum reconnect retries. 42 | // Default is 7 which means 1+2+3+4+5+6+7 = 28 seconds before return error 43 | // (if waiting for error takes no time). 44 | MaxRetries int 45 | 46 | // Debug logging. You may change it at any time. 47 | Debug bool 48 | } 49 | 50 | // New creates a new autoreconnecting connection. 51 | func New(proto, laddr, raddr, user, passwd string, db ...string) *Conn { 52 | return &Conn{ 53 | Raw: mysql.New(proto, laddr, raddr, user, passwd, db...), 54 | MaxRetries: 7, 55 | } 56 | } 57 | 58 | // NewFromCF creates a new autoreconnecting connection from config file. 59 | // Returns connection handler and map containing unknown options. 60 | func NewFromCF(cfgFile string) (*Conn, map[string]string, error) { 61 | raw, unk, err := mysql.NewFromCF(cfgFile) 62 | if err != nil { 63 | return nil, nil, err 64 | } 65 | return &Conn{raw, 7, false}, unk, nil 66 | } 67 | 68 | // Clone makes a copy of the connection. 69 | func (c *Conn) Clone() *Conn { 70 | return &Conn{ 71 | Raw: c.Raw.Clone(), 72 | MaxRetries: c.MaxRetries, 73 | Debug: c.Debug, 74 | } 75 | } 76 | 77 | // SetTimeout sets a timeout for underlying mysql.Conn connection. 78 | func (c *Conn) SetTimeout(timeout time.Duration) { 79 | c.Raw.SetTimeout(timeout) 80 | } 81 | 82 | func (c *Conn) reconnectIfNetErr(nn *int, err *error) { 83 | for *err != nil && IsNetErr(*err) && *nn <= c.MaxRetries { 84 | if c.Debug { 85 | log.Printf("Error: '%s' - reconnecting...", *err) 86 | } 87 | time.Sleep(time.Second * time.Duration(*nn)) 88 | *err = c.Raw.Reconnect() 89 | if c.Debug && *err != nil { 90 | log.Println("Can't reconnect:", *err) 91 | } 92 | *nn++ 93 | } 94 | } 95 | 96 | func (c *Conn) connectIfNotConnected() (err error) { 97 | if c.Raw.IsConnected() { 98 | return 99 | } 100 | err = c.Raw.Connect() 101 | nn := 0 102 | c.reconnectIfNetErr(&nn, &err) 103 | return 104 | } 105 | 106 | // Reconnect tries to reconnect the connection up to MaxRetries times. 
107 | func (c *Conn) Reconnect() (err error) { 108 | err = c.Raw.Reconnect() 109 | nn := 0 110 | c.reconnectIfNetErr(&nn, &err) 111 | return 112 | } 113 | 114 | func (c *Conn) Register(sql string) { 115 | c.Raw.Register(sql) 116 | } 117 | 118 | func (c *Conn) SetMaxPktSize(new_size int) int { 119 | return c.Raw.SetMaxPktSize(new_size) 120 | } 121 | 122 | // Use is an automatic connect/reconnect/repeat version of mysql.Conn.Use. 123 | func (c *Conn) Use(dbname string) (err error) { 124 | if err = c.connectIfNotConnected(); err != nil { 125 | return 126 | } 127 | nn := 0 128 | for { 129 | if err = c.Raw.Use(dbname); err == nil { 130 | return 131 | } 132 | if c.reconnectIfNetErr(&nn, &err); err != nil { 133 | return 134 | } 135 | } 136 | panic(nil) 137 | } 138 | 139 | // Query is an automatic connect/reconnect/repeat version of mysql.Conn.Query. 140 | func (c *Conn) Query(sql string, params ...interface{}) (rows []mysql.Row, res mysql.Result, err error) { 141 | 142 | if err = c.connectIfNotConnected(); err != nil { 143 | return 144 | } 145 | nn := 0 146 | for { 147 | if rows, res, err = c.Raw.Query(sql, params...); err == nil { 148 | return 149 | } 150 | if c.reconnectIfNetErr(&nn, &err); err != nil { 151 | return 152 | } 153 | } 154 | panic(nil) 155 | } 156 | 157 | // QueryFirst is an automatic connect/reconnect/repeat version of mysql.Conn.QueryFirst. 158 | func (c *Conn) QueryFirst(sql string, params ...interface{}) (row mysql.Row, res mysql.Result, err error) { 159 | 160 | if err = c.connectIfNotConnected(); err != nil { 161 | return 162 | } 163 | nn := 0 164 | for { 165 | if row, res, err = c.Raw.QueryFirst(sql, params...); err == nil { 166 | return 167 | } 168 | if c.reconnectIfNetErr(&nn, &err); err != nil { 169 | return 170 | } 171 | } 172 | panic(nil) 173 | } 174 | 175 | // QueryLast is an automatic connect/reconnect/repeat version of mysql.Conn.QueryLast. 
176 | func (c *Conn) QueryLast(sql string, params ...interface{}) (row mysql.Row, res mysql.Result, err error) { 177 | 178 | if err = c.connectIfNotConnected(); err != nil { 179 | return 180 | } 181 | nn := 0 182 | for { 183 | if row, res, err = c.Raw.QueryLast(sql, params...); err == nil { 184 | return 185 | } 186 | if c.reconnectIfNetErr(&nn, &err); err != nil { 187 | return 188 | } 189 | } 190 | panic(nil) 191 | } 192 | 193 | // Escape is an automatic connect/reconnect/repeat version of mysql.Conn.Escape. 194 | func (c *Conn) Escape(s string) string { 195 | return c.Raw.Escape(s) 196 | } 197 | 198 | // Stmt contains mysql.Stmt and autoteconnecting connection. 199 | type Stmt struct { 200 | Raw mysql.Stmt 201 | con *Conn 202 | 203 | sql string 204 | } 205 | 206 | // PrepareOnce prepares a statement if it wasn't prepared before. 207 | func (c *Conn) PrepareOnce(s *Stmt, sql string) error { 208 | if s.Raw != nil { 209 | return nil 210 | } 211 | if err := c.connectIfNotConnected(); err != nil { 212 | return err 213 | } 214 | nn := 0 215 | for { 216 | var err error 217 | if s.Raw, err = c.Raw.Prepare(sql); err == nil { 218 | s.con = c 219 | return nil 220 | } 221 | if c.reconnectIfNetErr(&nn, &err); err != nil { 222 | return err 223 | } 224 | } 225 | panic(nil) 226 | } 227 | 228 | // Prepare is an automatic connect/reconnect/repeat version of mysql.Conn.Prepare. 229 | func (c *Conn) Prepare(sql string) (*Stmt, error) { 230 | var s Stmt 231 | s.sql = sql 232 | if err := c.PrepareOnce(&s, sql); err != nil { 233 | return nil, err 234 | } 235 | return &s, nil 236 | } 237 | 238 | func (c *Conn) reprepare(stmt *Stmt) error { 239 | sql := stmt.sql 240 | stmt.Raw = nil 241 | 242 | return c.PrepareOnce(stmt, sql) 243 | } 244 | 245 | // Begin starts a transaction and calls f to complete it. 246 | // If f returns an error and IsNetErr(error) == true it reconnects and calls 247 | // f up to MaxRetries times. 
If error is of type *mysql.Error it tries to rollback 248 | // the transaction. 249 | func (c *Conn) Begin(f func(mysql.Transaction, ...interface{}) error, args ...interface{}) error { 250 | err := c.connectIfNotConnected() 251 | if err != nil { 252 | return err 253 | } 254 | nn := 0 255 | for { 256 | var tr mysql.Transaction 257 | if tr, err = c.Raw.Begin(); err == nil { 258 | if err = f(tr, args...); err == nil { 259 | return nil 260 | } 261 | } 262 | if c.reconnectIfNetErr(&nn, &err); err != nil { 263 | if _, ok := err.(*mysql.Error); ok && tr.IsValid() { 264 | tr.Rollback() 265 | } 266 | return err 267 | } 268 | } 269 | panic(nil) 270 | } 271 | 272 | // Bind is an automatic connect/reconnect/repeat version of mysql.Stmt.Bind. 273 | func (s *Stmt) Bind(params ...interface{}) { 274 | s.Raw.Bind(params...) 275 | } 276 | 277 | func (s *Stmt) needsRepreparing(err error) bool { 278 | if mysqlErr, ok := err.(*mysql.Error); ok { 279 | if mysqlErr.Code == mysql.ER_UNKNOWN_STMT_HANDLER { 280 | return true 281 | } 282 | } 283 | 284 | return false 285 | } 286 | 287 | // Exec is an automatic connect/reconnect/repeat version of mysql.Stmt.Exec. 288 | func (s *Stmt) Exec(params ...interface{}) (rows []mysql.Row, res mysql.Result, err error) { 289 | 290 | if err = s.con.connectIfNotConnected(); err != nil { 291 | return 292 | } 293 | nn := 0 294 | for { 295 | if rows, res, err = s.Raw.Exec(params...); err == nil { 296 | return 297 | } 298 | 299 | if s.needsRepreparing(err) { 300 | if s.con.reprepare(s) != nil { 301 | return 302 | } 303 | 304 | // Try again 305 | continue 306 | } 307 | 308 | if s.con.reconnectIfNetErr(&nn, &err); err != nil { 309 | return 310 | } 311 | } 312 | panic(nil) 313 | } 314 | 315 | // ExecFirst is an automatic connect/reconnect/repeat version of mysql.Stmt.ExecFirst. 
316 | func (s *Stmt) ExecFirst(params ...interface{}) (row mysql.Row, res mysql.Result, err error) { 317 | 318 | if err = s.con.connectIfNotConnected(); err != nil { 319 | return 320 | } 321 | nn := 0 322 | for { 323 | if row, res, err = s.Raw.ExecFirst(params...); err == nil { 324 | return 325 | } 326 | 327 | if s.needsRepreparing(err) { 328 | if s.con.reprepare(s) != nil { 329 | return 330 | } 331 | 332 | // Try again 333 | continue 334 | } 335 | 336 | if s.con.reconnectIfNetErr(&nn, &err); err != nil { 337 | return 338 | } 339 | } 340 | panic(nil) 341 | } 342 | 343 | // ExecLast is an automatic connect/reconnect/repeat version of mysql.Stmt.ExecLast. 344 | func (s *Stmt) ExecLast(params ...interface{}) (row mysql.Row, res mysql.Result, err error) { 345 | 346 | if err = s.con.connectIfNotConnected(); err != nil { 347 | return 348 | } 349 | nn := 0 350 | for { 351 | if row, res, err = s.Raw.ExecLast(params...); err == nil { 352 | return 353 | } 354 | 355 | if s.needsRepreparing(err) { 356 | if s.con.reprepare(s) != nil { 357 | return 358 | } 359 | 360 | // Try again 361 | continue 362 | } 363 | 364 | if s.con.reconnectIfNetErr(&nn, &err); err != nil { 365 | return 366 | } 367 | } 368 | panic(nil) 369 | } 370 | -------------------------------------------------------------------------------- /autorc/autorecon_test.go: -------------------------------------------------------------------------------- 1 | package autorc 2 | 3 | import ( 4 | _ "github.com/ziutek/mymysql/thrsafe" 5 | "testing" 6 | ) 7 | 8 | var ( 9 | conn = []string{"tcp", "", "127.0.0.1:3306"} 10 | user = "testuser" 11 | passwd = "TestPasswd9" 12 | dbname = "test" 13 | ) 14 | 15 | func checkErr(t *testing.T, err error, exp_err error) { 16 | if err != exp_err { 17 | if exp_err == nil { 18 | t.Fatalf("Error: %v", err) 19 | } else { 20 | t.Fatalf("Error: %v\nExpected error: %v", err, exp_err) 21 | } 22 | } 23 | } 24 | 25 | func TestAutoConnectReconnect(t *testing.T) { 26 | c := New(conn[0], conn[1], 
conn[2], user, passwd) 27 | c.Debug = false 28 | 29 | // Register initialisation commands 30 | c.Register("set names utf8") 31 | 32 | // my is in unconnected state 33 | checkErr(t, c.Use(dbname), nil) 34 | 35 | // Disconnect 36 | c.Raw.Close() 37 | 38 | // Drop test table if exists 39 | _, _, err := c.Query("drop table if exists R") 40 | checkErr(t, err, nil) 41 | 42 | // Disconnect 43 | c.Raw.Close() 44 | 45 | // Create table 46 | _, _, err = c.Query( 47 | "create table R (id int primary key, name varchar(20))", 48 | ) 49 | checkErr(t, err, nil) 50 | 51 | // Kill the connection 52 | c.Query("kill %d", c.Raw.ThreadId()) 53 | // MySQL 5.5 returns "Query execution was interrupted" after kill command 54 | 55 | // Prepare insert statement 56 | ins, err := c.Prepare("insert R values (?, ?)") 57 | checkErr(t, err, nil) 58 | 59 | // Kill the connection 60 | c.Query("kill %d", c.Raw.ThreadId()) 61 | 62 | // Bind insert parameters 63 | ins.Bind(1, "jeden") 64 | // Insert into table 65 | _, _, err = ins.Exec() 66 | checkErr(t, err, nil) 67 | 68 | // Kill the connection 69 | c.Query("kill %d", c.Raw.ThreadId()) 70 | 71 | // Bind insert parameters 72 | ins.Bind(2, "dwa") 73 | // Insert into table 74 | _, _, err = ins.Exec() 75 | checkErr(t, err, nil) 76 | 77 | // Kill the connection 78 | c.Query("kill %d", c.Raw.ThreadId()) 79 | 80 | // Select from table 81 | rows, res, err := c.Query("select * from R") 82 | checkErr(t, err, nil) 83 | id := res.Map("id") 84 | name := res.Map("name") 85 | if len(rows) != 2 || 86 | rows[0].Int(id) != 1 || rows[0].Str(name) != "jeden" || 87 | rows[1].Int(id) != 2 || rows[1].Str(name) != "dwa" { 88 | t.Fatal("Bad result") 89 | } 90 | 91 | // Kill the connection 92 | c.Query("kill %d", c.Raw.ThreadId()) 93 | 94 | // Drop table 95 | _, _, err = c.Query("drop table R") 96 | checkErr(t, err, nil) 97 | 98 | // Disconnect 99 | c.Raw.Close() 100 | } 101 | -------------------------------------------------------------------------------- /codelingo.yaml: 
-------------------------------------------------------------------------------- 1 | tenets: 2 | - import: codelingo/code-review-comments 3 | - import: codelingo/effective-go 4 | -------------------------------------------------------------------------------- /doc.go: -------------------------------------------------------------------------------- 1 | // Package mymysql provides MySQL client API and database/sql driver. 2 | // 3 | // It can be used as a library or as a database/sql driver. 4 | // 5 | // Using as a library 6 | // 7 | // Import native or thrsafe engine. Optionally import autorc for autoreconnect connections. 8 | // 9 | // import ( 10 | // "github.com/ziutek/mymysql/mysql" 11 | // _ "github.com/ziutek/mymysql/thrsafe" // OR native 12 | // // _ "github.com/ziutek/mymysql/native" 13 | // "github.com/ziutek/mymysql/autorc" // for autoreconnect 14 | // ) 15 | // 16 | // 17 | // 18 | // Using as a Go sql driver 19 | // 20 | // Import Go standard sql package and godrv driver. 21 | // 22 | // import ( 23 | // "database/sql" 24 | // _ "github.com/ziutek/mymysql/godrv" 25 | // ) 26 | // 27 | // 28 | // 29 | package mymysql 30 | -------------------------------------------------------------------------------- /examples/database_sql/database_sql.go: -------------------------------------------------------------------------------- 1 | // Example demonstrates using of mymysql/godrv driver. 2 | package main 3 | 4 | import ( 5 | "database/sql" 6 | "fmt" 7 | "log" 8 | 9 | _ "github.com/ziutek/mymysql/godrv" // Go driver for database/sql package 10 | ) 11 | 12 | func main() { 13 | db, err := sql.Open("mymysql", "tcp:127.0.0.1:3306*mydb/username/passw0rd") 14 | if err != nil { 15 | log.Fatal(err) 16 | } 17 | 18 | id := 1 19 | var query = "SELECT email from users WHERE id = ?" 
20 | 21 | rows, err := db.Query(query, id) 22 | if err != nil { 23 | log.Fatal(err) 24 | } 25 | 26 | var email string 27 | for rows.Next() { 28 | if err := rows.Scan(&email); err != nil { 29 | log.Fatal(err) 30 | } 31 | 32 | fmt.Printf("Email address: %s\n", email) 33 | 34 | } 35 | 36 | if err := rows.Err(); err != nil { 37 | log.Fatal(err) 38 | } 39 | } 40 | -------------------------------------------------------------------------------- /examples/long_data/long_data.go: -------------------------------------------------------------------------------- 1 | // Example reads URL from stdin and retrieves its content directly to 2 | // database using SendLongData method. 3 | package main 4 | 5 | import ( 6 | "fmt" 7 | "net/http" 8 | "os" 9 | "strings" 10 | 11 | "github.com/ziutek/mymysql/mysql" 12 | _ "github.com/ziutek/mymysql/thrsafe" 13 | ) 14 | 15 | func printOK() { 16 | fmt.Println("OK") 17 | } 18 | 19 | func checkError(err error) { 20 | if err != nil { 21 | fmt.Println(err) 22 | os.Exit(1) 23 | } 24 | } 25 | 26 | func main() { 27 | user := "testuser" 28 | pass := "TestPasswd9" 29 | dbname := "test" 30 | //proto := "unix" 31 | //addr := "/var/run/mysqld/mysqld.sock" 32 | proto := "tcp" 33 | addr := "127.0.0.1:3306" 34 | 35 | db := mysql.New(proto, "", addr, user, pass, dbname) 36 | //db.Debug = true 37 | 38 | fmt.Printf("Connect to %s:%s... ", proto, addr) 39 | checkError(db.Connect()) 40 | printOK() 41 | 42 | fmt.Print("Drop 'web' table if exists... ") 43 | _, err := db.Start("DROP TABLE web") 44 | if err == nil { 45 | printOK() 46 | } else if e, ok := err.(*mysql.Error); ok { 47 | // Error from MySQL server 48 | fmt.Println(e) 49 | } else { 50 | checkError(err) 51 | } 52 | 53 | fmt.Print("Create 'web' table... ") 54 | _, err = db.Start("CREATE TABLE web (url VARCHAR(80), content LONGBLOB)") 55 | checkError(err) 56 | printOK() 57 | 58 | fmt.Print("Prepare insert statement... 
") 59 | ins, err := db.Prepare("INSERT INTO web VALUES (?, ?)") 60 | checkError(err) 61 | printOK() 62 | 63 | fmt.Print("Prepare select statement... ") 64 | sel, err := db.Prepare("SELECT url, OCTET_LENGTH(content) FROM web") 65 | checkError(err) 66 | printOK() 67 | 68 | var url string 69 | 70 | fmt.Print("Bind insert parameters... ") 71 | ins.Bind(&url, []byte(nil)) 72 | printOK() 73 | 74 | fmt.Println() 75 | for { 76 | url = "" 77 | fmt.Print("Please enter an URL (blank line terminates input): ") 78 | fmt.Scanln(&url) 79 | if len(url) == 0 { 80 | break 81 | } 82 | if !strings.Contains(url, "://") { 83 | url = "http://" + url 84 | } 85 | http_res, err := http.Get(url) 86 | if err != nil { 87 | fmt.Println(err) 88 | continue 89 | } 90 | // Retrieve response directly into database. Use 8 kB buffer. 91 | checkError(ins.SendLongData(1, http_res.Body, 8192)) 92 | _, err = ins.Run() 93 | checkError(err) 94 | } 95 | fmt.Println() 96 | 97 | fmt.Print("Select from 'web' table... ") 98 | rows, res, err := sel.Exec() 99 | checkError(err) 100 | printOK() 101 | 102 | // Print fields names 103 | fmt.Println() 104 | for _, field := range res.Fields() { 105 | fmt.Printf("%-38s ", field.Name) 106 | } 107 | fmt.Println() 108 | fmt.Println("------------------------------------------------------------") 109 | 110 | // Print result 111 | for _, row := range rows { 112 | for ii, col := range row { 113 | if col == nil { 114 | fmt.Print("%-38s ", "NULL") 115 | } else { 116 | fmt.Printf("%-38s ", row.Bin(ii)) 117 | } 118 | } 119 | fmt.Println() 120 | } 121 | fmt.Println() 122 | 123 | fmt.Print("Hit ENTER to exit ") 124 | fmt.Scanln() 125 | 126 | fmt.Print("Remove 'web' table... ") 127 | _, err = db.Start("DROP TABLE web") 128 | checkError(err) 129 | printOK() 130 | 131 | fmt.Print("Close connection... 
") 132 | checkError(db.Close()) 133 | printOK() 134 | } 135 | -------------------------------------------------------------------------------- /examples/parallel/parallel.go: -------------------------------------------------------------------------------- 1 | // This file is there temporary and it isn't any example of how to use mymysql. 2 | package main 3 | 4 | import ( 5 | "io" 6 | "log" 7 | "os" 8 | "os/signal" 9 | "syscall" 10 | "time" 11 | 12 | "github.com/ziutek/mymysql/mysql" 13 | _ "github.com/ziutek/mymysql/native" 14 | ) 15 | 16 | const ( 17 | n_sends = 3 * 1000 18 | n_goroutines = 100 19 | ) 20 | 21 | func checkErr(err error) { 22 | if err != nil { 23 | log.Fatal(err) 24 | } 25 | } 26 | 27 | func main() { 28 | work_chan := make(chan bool) 29 | sends_chan := make(chan bool) 30 | results_chan := make(chan bool, n_sends) 31 | 32 | signal_chan := make(chan os.Signal, 1) 33 | signal.Notify(signal_chan, syscall.SIGINT) 34 | 35 | for i := 0; i < n_goroutines; i++ { 36 | go func() { 37 | conn := mysql.New( 38 | "tcp", "", "127.0.0.1:3306", 39 | "testuser", "TestPasswd9", 40 | ) 41 | conn.SetTimeout(2 * time.Second) 42 | defer conn.Close() 43 | 44 | for { 45 | <-work_chan 46 | 47 | if !conn.IsConnected() { 48 | checkErr(conn.Reconnect()) 49 | } 50 | 51 | res, err := conn.Start("show processlist") 52 | checkErr(err) 53 | row := res.MakeRow() 54 | for { 55 | err := res.ScanRow(row) 56 | if err == io.EOF { 57 | break 58 | } 59 | checkErr(err) 60 | // _, _ = row.ForceUint64(0), row.ForceUint(1) 61 | } 62 | 63 | // sleep_time := time.Duration(rand.Intn(10)) * time.Millisecond 64 | // time.Sleep(sleep_time) 65 | 66 | results_chan <- true 67 | } 68 | }() 69 | } 70 | 71 | go func() { 72 | for i := 0; i < n_sends; i++ { 73 | work_chan <- true 74 | sends_chan <- true 75 | } 76 | }() 77 | 78 | done_sends := 0 79 | ticker := time.NewTicker(1 * time.Second) 80 | 81 | for got_results := 0; got_results < n_sends; { 82 | select { 83 | case <-results_chan: 84 | got_results++ 85 | 
case <-sends_chan: 86 | done_sends++ 87 | case <-ticker.C: 88 | log.Printf("done %d sends, got %d results", done_sends, got_results) 89 | case <-signal_chan: 90 | panic("show me the goroutines") 91 | } 92 | } 93 | log.Printf("got all %d results", n_sends) 94 | 95 | // panic("show me the goroutines") 96 | } 97 | -------------------------------------------------------------------------------- /examples/prepared_stmt/prepared_stmt.go: -------------------------------------------------------------------------------- 1 | // Examples prepares and executes statements. 2 | package main 3 | 4 | import ( 5 | "fmt" 6 | "os" 7 | 8 | "github.com/ziutek/mymysql/mysql" 9 | _ "github.com/ziutek/mymysql/thrsafe" 10 | ) 11 | 12 | func printOK() { 13 | fmt.Println("OK") 14 | } 15 | 16 | func checkError(err error) { 17 | if err != nil { 18 | fmt.Println(err) 19 | os.Exit(1) 20 | } 21 | } 22 | 23 | func checkedResult(rows []mysql.Row, res mysql.Result, err error) ([]mysql.Row, mysql.Result) { 24 | checkError(err) 25 | return rows, res 26 | } 27 | 28 | func main() { 29 | user := "testuser" 30 | pass := "TestPasswd9" 31 | dbname := "test" 32 | //proto := "unix" 33 | //addr := "/var/run/mysqld/mysqld.sock" 34 | proto := "tcp" 35 | addr := "127.0.0.1:3306" 36 | 37 | db := mysql.New(proto, "", addr, user, pass, dbname) 38 | 39 | fmt.Printf("Connect to %s:%s... ", proto, addr) 40 | checkError(db.Connect()) 41 | printOK() 42 | 43 | fmt.Print("Drop A table if exists... ") 44 | _, err := db.Start("drop table A") 45 | if err == nil { 46 | printOK() 47 | } else if e, ok := err.(*mysql.Error); ok { 48 | // Error from MySQL server 49 | fmt.Println(e) 50 | } else { 51 | checkError(err) 52 | } 53 | 54 | fmt.Print("Create A table... ") 55 | checkedResult(db.Query("create table A (name varchar(40), number int)")) 56 | printOK() 57 | 58 | fmt.Print("Prepare insert statement... 
") 59 | ins, err := db.Prepare("insert A values (?, ?)") 60 | checkError(err) 61 | printOK() 62 | 63 | fmt.Print("Prepare select statement... ") 64 | sel, err := db.Prepare("select * from A where number > ? or number is null") 65 | checkError(err) 66 | printOK() 67 | 68 | params := struct { 69 | txt *string 70 | number *int 71 | }{} 72 | 73 | fmt.Print("Bind insert parameters... ") 74 | ins.Bind(¶ms) 75 | printOK() 76 | 77 | fmt.Print("Insert into A... ") 78 | for ii := 0; ii < 1000; ii += 100 { 79 | if ii%500 == 0 { 80 | // Assign NULL values to the parameters 81 | params.txt = nil 82 | params.number = nil 83 | } else { 84 | // Modify parameters 85 | str := fmt.Sprintf("%d*10= %d", ii/100, ii/10) 86 | params.txt = &str 87 | params.number = &ii 88 | } 89 | // Execute statement with modified data 90 | _, err = ins.Run() 91 | checkError(err) 92 | } 93 | printOK() 94 | 95 | fmt.Println("Select from A... ") 96 | rows, res := checkedResult(sel.Exec(0)) 97 | name := res.Map("name") 98 | number := res.Map("number") 99 | for ii, row := range rows { 100 | fmt.Printf( 101 | "Row: %d\n name: %-10s {%#v}\n number: %-8d {%#v}\n", ii, 102 | "'"+row.Str(name)+"'", row[name], 103 | row.Int(number), row[number], 104 | ) 105 | } 106 | 107 | fmt.Print("Remove A... ") 108 | checkedResult(db.Query("drop table A")) 109 | printOK() 110 | 111 | fmt.Print("Close connection... ") 112 | checkError(db.Close()) 113 | printOK() 114 | } 115 | -------------------------------------------------------------------------------- /examples/reconnect/reconnect.go: -------------------------------------------------------------------------------- 1 | // Example demonstrates using of autorc package. 
2 | package main 3 | 4 | import ( 5 | "fmt" 6 | "os" 7 | "time" 8 | 9 | "github.com/ziutek/mymysql/autorc" 10 | _ "github.com/ziutek/mymysql/thrsafe" 11 | ) 12 | 13 | func main() { 14 | user := "testuser" 15 | passwd := "TestPasswd9" 16 | dbname := "test" 17 | //conn := []string{"unix", "", "/var/run/mysqld/mysqld.sock"} 18 | conn := []string{"tcp", "", "127.0.0.1:3306"} 19 | 20 | c := autorc.New(conn[0], conn[1], conn[2], user, passwd) 21 | 22 | // Register initialisation commands 23 | c.Raw.Register("set names utf8") 24 | 25 | // my is in unconnected state 26 | checkErr(c.Use(dbname)) 27 | 28 | // Now we ar connected - disconnect 29 | c.Raw.Close() 30 | 31 | // Drop test table if exists 32 | _, _, err := c.Query("drop table R") 33 | 34 | fmt.Println("You may restart MySQL sererr or down the network interface.") 35 | sec := 9 36 | fmt.Printf("Waiting %ds...", sec) 37 | for sec--; sec >= 0; sec-- { 38 | time.Sleep(1e9) 39 | fmt.Printf("\b\b\b\b\b%ds...", sec) 40 | } 41 | fmt.Println() 42 | 43 | // Create table 44 | _, _, err = c.Query( 45 | "create table R (id int primary key, name varchar(20))", 46 | ) 47 | checkErr(err) 48 | 49 | // Kill the connection 50 | _, _, err = c.Query("kill %d", c.Raw.ThreadId()) 51 | checkErr(err) 52 | 53 | // Prepare insert statement 54 | ins, err := c.Prepare("insert R values (?, ?)") 55 | checkErr(err) 56 | 57 | // Kill the connection 58 | _, _, err = c.Query("kill %d", c.Raw.ThreadId()) 59 | checkErr(err) 60 | 61 | // Bind insert parameters 62 | ins.Raw.Bind(1, "jeden") 63 | // Insert into table 64 | _, _, err = ins.Exec() 65 | checkErr(err) 66 | 67 | // Kill the connection 68 | _, _, err = c.Query("kill %d", c.Raw.ThreadId()) 69 | checkErr(err) 70 | 71 | // Bind insert parameters 72 | ins.Raw.Bind(2, "dwa") 73 | // Insert into table 74 | _, _, err = ins.Exec() 75 | checkErr(err) 76 | 77 | // Kill the connection 78 | _, _, err = c.Query("kill %d", c.Raw.ThreadId()) 79 | checkErr(err) 80 | 81 | // Select from table 82 | rows, res, 
err := c.Query("select * from R") 83 | checkErr(err) 84 | id := res.Map("id") 85 | name := res.Map("name") 86 | if len(rows) != 2 || 87 | rows[0].Int(id) != 1 || rows[0].Str(name) != "jeden" || 88 | rows[1].Int(id) != 2 || rows[1].Str(name) != "dwa" { 89 | fmt.Println("Bad result") 90 | } 91 | 92 | // Kill the connection 93 | _, _, err = c.Query("kill %d", c.Raw.ThreadId()) 94 | checkErr(err) 95 | 96 | // Drop table 97 | _, _, err = c.Query("drop table R") 98 | checkErr(err) 99 | 100 | // Disconnect 101 | c.Raw.Close() 102 | 103 | } 104 | 105 | func checkErr(err error) { 106 | if err != nil { 107 | fmt.Println("Error:", err) 108 | os.Exit(1) 109 | } 110 | } 111 | -------------------------------------------------------------------------------- /examples/simple/simple.go: -------------------------------------------------------------------------------- 1 | // Example demonstrates simple thread safe DB operations. 2 | package main 3 | 4 | import ( 5 | "fmt" 6 | "os" 7 | 8 | "github.com/ziutek/mymysql/mysql" 9 | _ "github.com/ziutek/mymysql/thrsafe" 10 | ) 11 | 12 | func printOK() { 13 | fmt.Println("OK") 14 | } 15 | 16 | func checkError(err error) { 17 | if err != nil { 18 | fmt.Println(err) 19 | os.Exit(1) 20 | } 21 | } 22 | 23 | func checkedResult(rows []mysql.Row, res mysql.Result, err error) ([]mysql.Row, 24 | mysql.Result) { 25 | checkError(err) 26 | return rows, res 27 | } 28 | 29 | func main() { 30 | user := "testuser" 31 | pass := "TestPasswd9" 32 | dbname := "test" 33 | //proto := "unix" 34 | //addr := "/var/run/mysqld/mysqld.sock" 35 | proto := "tcp" 36 | addr := "127.0.0.1:3306" 37 | 38 | db := mysql.New(proto, "", addr, user, pass, dbname) 39 | 40 | fmt.Printf("Connect to %s:%s... ", proto, addr) 41 | checkError(db.Connect()) 42 | printOK() 43 | 44 | fmt.Print("Drop A table if exists... 
") 45 | _, err := db.Start("drop table A") 46 | if err == nil { 47 | printOK() 48 | } else if e, ok := err.(*mysql.Error); ok { 49 | // Error from MySQL server 50 | fmt.Println(e) 51 | } else { 52 | checkError(err) 53 | } 54 | 55 | fmt.Print("Create A table... ") 56 | checkedResult(db.Query("create table A (name varchar(40), number int)")) 57 | printOK() 58 | 59 | fmt.Print("Insert into A... ") 60 | for ii := 0; ii < 10; ii++ { 61 | if ii%5 == 0 { 62 | checkedResult(db.Query("insert A values (null, null)")) 63 | } else { 64 | checkedResult(db.Query( 65 | "insert A values ('%d*10= %d', %d)", ii, ii*10, ii*100, 66 | )) 67 | } 68 | } 69 | printOK() 70 | 71 | fmt.Println("Select from A... ") 72 | rows, res := checkedResult(db.Query("select * from A")) 73 | name := res.Map("name") 74 | number := res.Map("number") 75 | for ii, row := range rows { 76 | fmt.Printf( 77 | "Row: %d\n name: %-10s {%#v}\n number: %-8d {%#v}\n", ii, 78 | "'"+row.Str(name)+"'", row[name], 79 | row.Int(number), row[number], 80 | ) 81 | } 82 | 83 | fmt.Print("Remove A... ") 84 | checkedResult(db.Query("drop table A")) 85 | printOK() 86 | 87 | fmt.Print("Close connection... ") 88 | checkError(db.Close()) 89 | printOK() 90 | } 91 | -------------------------------------------------------------------------------- /examples/transactions/transactions.go: -------------------------------------------------------------------------------- 1 | // Example using transactions. 
2 | package main 3 | 4 | import ( 5 | "fmt" 6 | "os" 7 | 8 | "github.com/ziutek/mymysql/mysql" 9 | _ "github.com/ziutek/mymysql/thrsafe" 10 | //_ "github.com/ziutek/mymysql/native" 11 | ) 12 | 13 | func printOK() { 14 | fmt.Println("OK") 15 | } 16 | 17 | func checkError(err error) { 18 | if err != nil { 19 | fmt.Println(err) 20 | os.Exit(1) 21 | } 22 | } 23 | 24 | func checkedResult(rows []mysql.Row, res mysql.Result, err error) ([]mysql.Row, mysql.Result) { 25 | checkError(err) 26 | return rows, res 27 | } 28 | 29 | func main() { 30 | user := "testuser" 31 | pass := "TestPasswd9" 32 | dbname := "test" 33 | //proto := "unix" 34 | //addr := "/var/run/mysqld/mysqld.sock" 35 | proto := "tcp" 36 | addr := "127.0.0.1:3306" 37 | 38 | db := mysql.New(proto, "", addr, user, pass, dbname) 39 | 40 | fmt.Printf("Connect to %s:%s... ", proto, addr) 41 | checkError(db.Connect()) 42 | printOK() 43 | 44 | fmt.Print("Drop A table if exists... ") 45 | _, err := db.Start("drop table A") 46 | if err == nil { 47 | printOK() 48 | } else if e, ok := err.(*mysql.Error); ok { 49 | // Error from MySQL server 50 | fmt.Println(e) 51 | } else { 52 | checkError(err) 53 | } 54 | 55 | fmt.Print("Create A table... ") 56 | _, err = db.Start("create table A (name varchar(9), number int) engine=InnoDB") 57 | checkError(err) 58 | printOK() 59 | 60 | fmt.Print("Prepare insert statement... ") 61 | ins, err := db.Prepare("insert A values (?, ?)") 62 | checkError(err) 63 | printOK() 64 | 65 | fmt.Print("Prepare select statement... ") 66 | sel, err := db.Prepare("select * from A") 67 | checkError(err) 68 | printOK() 69 | 70 | fmt.Print("Begining a new transaction... ") 71 | tr, err := db.Begin() 72 | checkError(err) 73 | printOK() 74 | 75 | tr_ins := tr.Do(ins) 76 | 77 | fmt.Print("Performing two inserts... ") 78 | _, err = tr_ins.Run("jeden", 1) 79 | checkError(err) 80 | _, err = tr_ins.Run("dwa", 2) 81 | checkError(err) 82 | printOK() 83 | 84 | fmt.Print("Commit the transaction... 
") 85 | checkError(tr.Commit()) 86 | printOK() 87 | 88 | fmt.Print("Begining a new transaction... ") 89 | tr, err = db.Begin() 90 | checkError(err) 91 | printOK() 92 | 93 | fmt.Print("Performing one insert... ") 94 | _, err = tr.Do(ins).Run("trzy", 3) 95 | checkError(err) 96 | printOK() 97 | 98 | fmt.Print("Rollback the transaction... ") 99 | checkError(tr.Rollback()) 100 | printOK() 101 | 102 | fmt.Println("Select from A... ") 103 | rows, res := checkedResult(sel.Exec()) 104 | name := res.Map("name") 105 | number := res.Map("number") 106 | for ii, row := range rows { 107 | fmt.Printf("%d: %-10s %-8d\n", ii, row[name], row[number]) 108 | } 109 | 110 | fmt.Print("Remove A... ") 111 | checkedResult(db.Query("drop table A")) 112 | printOK() 113 | 114 | fmt.Print("Close connection... ") 115 | checkError(db.Close()) 116 | printOK() 117 | } 118 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/ziutek/mymysql 2 | 3 | go 1.16 4 | -------------------------------------------------------------------------------- /godrv/appengine.go: -------------------------------------------------------------------------------- 1 | // +build appengine 2 | 3 | package godrv 4 | 5 | import ( 6 | "net" 7 | "time" 8 | 9 | "appengine/cloudsql" 10 | ) 11 | 12 | func init() { 13 | SetDialer(func(proto, laddr, raddr, user, dbname string, timeout time.Duration) (net.Conn, error) { 14 | return cloudsql.Dial(raddr) 15 | }) 16 | } 17 | -------------------------------------------------------------------------------- /godrv/driver.go: -------------------------------------------------------------------------------- 1 | // Package godrv implements database/sql MySQL driver. 
2 | package godrv 3 | 4 | import ( 5 | "database/sql" 6 | "database/sql/driver" 7 | "errors" 8 | "fmt" 9 | "io" 10 | "net" 11 | "strconv" 12 | "strings" 13 | "time" 14 | 15 | "github.com/ziutek/mymysql/mysql" 16 | "github.com/ziutek/mymysql/native" 17 | ) 18 | 19 | type conn struct { 20 | my mysql.Conn 21 | } 22 | 23 | type rowsRes struct { 24 | row mysql.Row 25 | my mysql.Result 26 | simpleQuery mysql.Stmt 27 | } 28 | 29 | func errFilter(err error) error { 30 | if err == io.ErrUnexpectedEOF { 31 | return driver.ErrBadConn 32 | } 33 | if _, ok := err.(net.Error); ok { 34 | return driver.ErrBadConn 35 | } 36 | return err 37 | } 38 | 39 | func join(a []string) string { 40 | n := 0 41 | for _, s := range a { 42 | n += len(s) 43 | } 44 | b := make([]byte, n) 45 | n = 0 46 | for _, s := range a { 47 | n += copy(b[n:], s) 48 | } 49 | return string(b) 50 | } 51 | 52 | func (c conn) parseQuery(query string, args []driver.Value) (string, error) { 53 | if len(args) == 0 { 54 | return query, nil 55 | } 56 | if strings.ContainsAny(query, `'"`) { 57 | return "", nil 58 | } 59 | q := make([]string, 2*len(args)+1) 60 | n := 0 61 | for _, a := range args { 62 | i := strings.IndexRune(query, '?') 63 | if i == -1 { 64 | return "", errors.New("number of parameters doesn't match number of placeholders") 65 | } 66 | var s string 67 | switch v := a.(type) { 68 | case nil: 69 | s = "NULL" 70 | case string: 71 | s = "'" + c.my.Escape(v) + "'" 72 | case []byte: 73 | s = "'" + c.my.Escape(string(v)) + "'" 74 | case int64: 75 | s = strconv.FormatInt(v, 10) 76 | case time.Time: 77 | s = "'" + v.Format(mysql.TimeFormat) + "'" 78 | case bool: 79 | if v { 80 | s = "1" 81 | } else { 82 | s = "0" 83 | } 84 | case float64: 85 | s = strconv.FormatFloat(v, 'e', 12, 64) 86 | default: 87 | panic(fmt.Sprintf("%v (%T) can't be handled by godrv", v, v)) 88 | } 89 | q[n] = query[:i] 90 | q[n+1] = s 91 | query = query[i+1:] 92 | n += 2 93 | } 94 | q[n] = query 95 | return join(q), nil 96 | } 97 | 98 | func 
(c conn) Exec(query string, args []driver.Value) (driver.Result, error) { 99 | q, err := c.parseQuery(query, args) 100 | if err != nil { 101 | return nil, err 102 | } 103 | if len(q) == 0 { 104 | return nil, driver.ErrSkip 105 | } 106 | res, err := c.my.Start(q) 107 | if err != nil { 108 | return nil, errFilter(err) 109 | } 110 | return &rowsRes{my: res}, nil 111 | } 112 | 113 | var textQuery = mysql.Stmt(new(native.Stmt)) 114 | 115 | func (c conn) Query(query string, args []driver.Value) (driver.Rows, error) { 116 | q, err := c.parseQuery(query, args) 117 | if err != nil { 118 | return nil, err 119 | } 120 | if len(q) == 0 { 121 | return nil, driver.ErrSkip 122 | } 123 | res, err := c.my.Start(q) 124 | if err != nil { 125 | return nil, errFilter(err) 126 | } 127 | return &rowsRes{row: res.MakeRow(), my: res, simpleQuery: textQuery}, nil 128 | } 129 | 130 | type stmt struct { 131 | my mysql.Stmt 132 | args []interface{} 133 | } 134 | 135 | func (s *stmt) run(args []driver.Value) (*rowsRes, error) { 136 | for i, v := range args { 137 | s.args[i] = interface{}(v) 138 | } 139 | res, err := s.my.Run(s.args...) 
140 | if err != nil { 141 | return nil, errFilter(err) 142 | } 143 | return &rowsRes{my: res}, nil 144 | } 145 | 146 | func (c conn) Prepare(query string) (driver.Stmt, error) { 147 | st, err := c.my.Prepare(query) 148 | if err != nil { 149 | return nil, errFilter(err) 150 | } 151 | return &stmt{st, make([]interface{}, st.NumParam())}, nil 152 | } 153 | 154 | func (c *conn) Close() (err error) { 155 | err = c.my.Close() 156 | c.my = nil 157 | if err != nil { 158 | err = errFilter(err) 159 | } 160 | return 161 | } 162 | 163 | type tx struct { 164 | my mysql.Transaction 165 | } 166 | 167 | func (c conn) Begin() (driver.Tx, error) { 168 | t, err := c.my.Begin() 169 | if err != nil { 170 | return nil, errFilter(err) 171 | } 172 | return tx{t}, nil 173 | } 174 | 175 | func (t tx) Commit() (err error) { 176 | err = t.my.Commit() 177 | if err != nil { 178 | err = errFilter(err) 179 | } 180 | return 181 | } 182 | 183 | func (t tx) Rollback() (err error) { 184 | err = t.my.Rollback() 185 | if err != nil { 186 | err = errFilter(err) 187 | } 188 | return 189 | } 190 | 191 | func (s *stmt) Close() (err error) { 192 | if s.my == nil { 193 | panic("godrv: stmt closed twice") 194 | } 195 | err = s.my.Delete() 196 | s.my = nil 197 | if err != nil { 198 | err = errFilter(err) 199 | } 200 | return 201 | } 202 | 203 | func (s *stmt) NumInput() int { 204 | return s.my.NumParam() 205 | } 206 | 207 | func (s *stmt) Exec(args []driver.Value) (driver.Result, error) { 208 | return s.run(args) 209 | } 210 | 211 | func (s *stmt) Query(args []driver.Value) (driver.Rows, error) { 212 | r, err := s.run(args) 213 | if err != nil { 214 | return nil, err 215 | } 216 | r.row = r.my.MakeRow() 217 | return r, nil 218 | } 219 | 220 | func (r *rowsRes) LastInsertId() (int64, error) { 221 | return int64(r.my.InsertId()), nil 222 | } 223 | 224 | func (r *rowsRes) RowsAffected() (int64, error) { 225 | return int64(r.my.AffectedRows()), nil 226 | } 227 | 228 | func (r *rowsRes) Columns() []string { 229 | 
flds := r.my.Fields() 230 | cls := make([]string, len(flds)) 231 | for i, f := range flds { 232 | cls[i] = f.Name 233 | } 234 | return cls 235 | } 236 | 237 | func (r *rowsRes) Close() error { 238 | if r.my == nil { 239 | return nil // closed before 240 | } 241 | if err := r.my.End(); err != nil { 242 | return errFilter(err) 243 | } 244 | if r.simpleQuery != nil && r.simpleQuery != textQuery { 245 | if err := r.simpleQuery.Delete(); err != nil { 246 | return errFilter(err) 247 | } 248 | } 249 | r.my = nil 250 | return nil 251 | } 252 | 253 | var location = time.Local 254 | 255 | // Next: DATE, DATETIME, TIMESTAMP are treated as they are in Local time zone (this 256 | // can be changed globaly using SetLocation function). 257 | func (r *rowsRes) Next(dest []driver.Value) error { 258 | if r.my == nil { 259 | return io.EOF // closed before 260 | } 261 | err := r.my.ScanRow(r.row) 262 | if err == nil { 263 | if r.simpleQuery == textQuery { 264 | // workaround for time.Time from text queries 265 | for i, f := range r.my.Fields() { 266 | if r.row[i] != nil { 267 | switch f.Type { 268 | case native.MYSQL_TYPE_TIMESTAMP, native.MYSQL_TYPE_DATETIME, 269 | native.MYSQL_TYPE_DATE, native.MYSQL_TYPE_NEWDATE: 270 | r.row[i] = r.row.ForceTime(i, location) 271 | } 272 | } 273 | } 274 | } 275 | for i, d := range r.row { 276 | dest[i] = driver.Value(d) 277 | } 278 | return nil 279 | } 280 | if err != io.EOF { 281 | return errFilter(err) 282 | } 283 | if r.simpleQuery != nil && r.simpleQuery != textQuery { 284 | if err = r.simpleQuery.Delete(); err != nil { 285 | return errFilter(err) 286 | } 287 | } 288 | r.my = nil 289 | return io.EOF 290 | } 291 | 292 | // Driver implements database/sql/driver interface. 293 | type Driver struct { 294 | // Defaults 295 | proto, laddr, raddr, user, passwd, db string 296 | timeout time.Duration 297 | dialer Dialer 298 | 299 | initCmds []string 300 | } 301 | 302 | // Open creates a new connection. 
The uri needs to have the following syntax: 303 | // 304 | // [PROTOCOL_SPECFIIC*]DBNAME/USER/PASSWD 305 | // 306 | // where protocol specific part may be empty (this means connection to 307 | // local server using default protocol). Currently possible forms are: 308 | // 309 | // DBNAME/USER/PASSWD 310 | // unix:SOCKPATH*DBNAME/USER/PASSWD 311 | // unix:SOCKPATH,OPTIONS*DBNAME/USER/PASSWD 312 | // tcp:ADDR*DBNAME/USER/PASSWD 313 | // tcp:ADDR,OPTIONS*DBNAME/USER/PASSWD 314 | // cloudsql:INSTANCE*DBNAME/USER/PASSWD 315 | // 316 | // OPTIONS can contain comma separated list of options in form: 317 | // opt1=VAL1,opt2=VAL2,boolopt3,boolopt4 318 | // Currently implemented options, in addition to default MySQL variables: 319 | // laddr - local address/port (eg. 1.2.3.4:0) 320 | // timeout - connect timeout in format accepted by time.ParseDuration 321 | func (d *Driver) Open(uri string) (driver.Conn, error) { 322 | cfg := *d // copy default configuration 323 | pd := strings.SplitN(uri, "*", 2) 324 | connCommands := []string{} 325 | if len(pd) == 2 { 326 | // Parse protocol part of URI 327 | p := strings.SplitN(pd[0], ":", 2) 328 | if len(p) != 2 { 329 | return nil, errors.New("Wrong protocol part of URI") 330 | } 331 | cfg.proto = p[0] 332 | options := strings.Split(p[1], ",") 333 | cfg.raddr = options[0] 334 | for _, o := range options[1:] { 335 | kv := strings.SplitN(o, "=", 2) 336 | var k, v string 337 | if len(kv) == 2 { 338 | k, v = kv[0], kv[1] 339 | } else { 340 | k, v = o, "true" 341 | } 342 | switch k { 343 | case "laddr": 344 | cfg.laddr = v 345 | case "timeout": 346 | to, err := time.ParseDuration(v) 347 | if err != nil { 348 | return nil, err 349 | } 350 | cfg.timeout = to 351 | default: 352 | connCommands = append(connCommands, "SET "+k+"="+v) 353 | } 354 | } 355 | // Remove protocol part 356 | pd = pd[1:] 357 | } 358 | // Parse database part of URI 359 | dup := strings.SplitN(pd[0], "/", 3) 360 | if len(dup) != 3 { 361 | return nil, errors.New("Wrong 
database part of URI") 362 | } 363 | cfg.db = dup[0] 364 | cfg.user = dup[1] 365 | cfg.passwd = dup[2] 366 | 367 | c := conn{mysql.New( 368 | cfg.proto, cfg.laddr, cfg.raddr, cfg.user, cfg.passwd, cfg.db, 369 | )} 370 | if d.dialer != nil { 371 | dialer := func(proto, laddr, raddr string, timeout time.Duration) ( 372 | net.Conn, error) { 373 | 374 | return d.dialer(proto, laddr, raddr, cfg.user, cfg.passwd, timeout) 375 | } 376 | c.my.SetDialer(dialer) 377 | } 378 | 379 | // Establish the connection 380 | c.my.SetTimeout(cfg.timeout) 381 | for _, q := range cfg.initCmds { 382 | c.my.Register(q) // Register initialisation commands 383 | } 384 | for _, q := range connCommands { 385 | c.my.Register(q) 386 | } 387 | if err := c.my.Connect(); err != nil { 388 | return nil, errFilter(err) 389 | } 390 | c.my.NarrowTypeSet(true) 391 | c.my.FullFieldInfo(false) 392 | return &c, nil 393 | } 394 | 395 | // Register registers initialization commands. 396 | // This is workaround, see http://codereview.appspot.com/5706047 397 | func (drv *Driver) Register(query string) { 398 | drv.initCmds = append(drv.initCmds, query) 399 | } 400 | 401 | // Dialer can be used to dial connections to MySQL. If Dialer returns (nil, nil) 402 | // the hook is skipped and normal dialing proceeds. user and dbname are there 403 | // only for logging. 404 | type Dialer func(proto, laddr, raddr, user, dbname string, timeout time.Duration) (net.Conn, error) 405 | 406 | // SetDialer sets custom Dialer used by Driver to make connections. 407 | func (drv *Driver) SetDialer(dialer Dialer) { 408 | drv.dialer = dialer 409 | } 410 | 411 | // Driver automatically registered in database/sql. 412 | var dfltdrv = Driver{proto: "tcp", raddr: "127.0.0.1:3306"} 413 | 414 | // Register calls Register method on driver registered in database/sql. 415 | // If Register is called twice with the same name it panics. 
416 | func Register(query string) { 417 | dfltdrv.Register(query) 418 | } 419 | 420 | // SetDialer calls SetDialer method on driver registered in database/sql. 421 | func SetDialer(dialer Dialer) { 422 | dfltdrv.SetDialer(dialer) 423 | } 424 | 425 | func init() { 426 | Register("SET NAMES utf8") 427 | sql.Register("mymysql", &dfltdrv) 428 | } 429 | 430 | // Version returns mymysql version string. 431 | func Version() string { 432 | return mysql.Version() 433 | } 434 | 435 | // SetLocation changes default location used to convert dates obtained from 436 | // server to time.Time. 437 | func SetLocation(loc *time.Location) { 438 | location = loc 439 | } 440 | -------------------------------------------------------------------------------- /godrv/driver_test.go: -------------------------------------------------------------------------------- 1 | package godrv 2 | 3 | import ( 4 | "database/sql" 5 | "fmt" 6 | "github.com/ziutek/mymysql/mysql" 7 | "testing" 8 | "time" 9 | ) 10 | 11 | func init() { 12 | Register("set names utf8") 13 | } 14 | 15 | func checkErr(t *testing.T, err error) { 16 | if err != nil { 17 | t.Fatalf("Error: %v", err) 18 | } 19 | } 20 | func checkErrId(t *testing.T, err error, rid, eid int64) { 21 | checkErr(t, err) 22 | if rid != eid { 23 | t.Fatal("res.LastInsertId() ==", rid, "but should be", eid) 24 | } 25 | } 26 | 27 | func TestAll(t *testing.T) { 28 | data := []string{"jeden", "dwa", "trzy"} 29 | 30 | db, err := sql.Open("mymysql", "test/testuser/TestPasswd9") 31 | checkErr(t, err) 32 | defer db.Close() 33 | defer db.Exec("DROP TABLE go") 34 | 35 | db.Exec("DROP TABLE go") 36 | 37 | _, err = db.Exec( 38 | `CREATE TABLE go ( 39 | id INT(11) NOT NULL PRIMARY KEY AUTO_INCREMENT, 40 | txt TEXT, 41 | n BIGINT 42 | ) ENGINE=InnoDB`) 43 | checkErr(t, err) 44 | 45 | ins, err := db.Prepare("INSERT go SET txt=?, n=?") 46 | checkErr(t, err) 47 | 48 | tx, err := db.Begin() 49 | checkErr(t, err) 50 | 51 | res, err := ins.Exec(data[0], 0) 52 | checkErr(t, 
err) 53 | id, err := res.LastInsertId() 54 | checkErrId(t, err, id, 1) 55 | 56 | res, err = ins.Exec(data[1], 1) 57 | checkErr(t, err) 58 | id, err = res.LastInsertId() 59 | checkErrId(t, err, id, 2) 60 | 61 | checkErr(t, tx.Commit()) 62 | 63 | tx, err = db.Begin() 64 | checkErr(t, err) 65 | 66 | res, err = tx.Exec("INSERT go SET txt=?, n=?", "cztery", 3) 67 | checkErr(t, err) 68 | id, err = res.LastInsertId() 69 | checkErrId(t, err, id, 3) 70 | 71 | checkErr(t, tx.Rollback()) 72 | 73 | rows, err := db.Query("SELECT * FROM go") 74 | checkErr(t, err) 75 | i := 1 76 | for rows.Next() { 77 | var ( 78 | id int 79 | txt string 80 | n int64 81 | ) 82 | checkErr(t, rows.Scan(&id, &txt, &n)) 83 | if id > len(data) { 84 | t.Fatal("To many rows in table") 85 | } 86 | if id != i || data[i-1] != txt || int64(i-1) != n { 87 | t.Fatalf("txt[%d] == '%s' != '%s'", id, txt, data[id-1]) 88 | } 89 | i++ 90 | } 91 | checkErr(t, rows.Err()) 92 | 93 | sel, err := db.Prepare("SELECT * FROM go") 94 | checkErr(t, err) 95 | 96 | rows, err = sel.Query() 97 | checkErr(t, err) 98 | i = 1 99 | for rows.Next() { 100 | var ( 101 | id int 102 | txt string 103 | n int64 104 | ) 105 | checkErr(t, rows.Scan(&id, &txt, &n)) 106 | if id > len(data) { 107 | t.Fatal("To many rows in table") 108 | } 109 | if id != i || data[i-1] != txt || int64(i-1) != n { 110 | t.Fatalf("txt[%d] == '%s' != '%s'", id, txt, data[id-1]) 111 | } 112 | i++ 113 | } 114 | checkErr(t, rows.Err()) 115 | 116 | sql := "select sum(41) as test" 117 | row := db.QueryRow(sql) 118 | var vi int64 119 | checkErr(t, row.Scan(&vi)) 120 | if vi != 41 { 121 | t.Fatal(sql) 122 | } 123 | sql = "select sum(4123232323232) as test" 124 | row = db.QueryRow(sql) 125 | var vf float64 126 | checkErr(t, row.Scan(&vf)) 127 | if vf != 4123232323232 { 128 | t.Fatal(sql) 129 | } 130 | } 131 | 132 | func TestMediumInt(t *testing.T) { 133 | db, err := sql.Open("mymysql", "test/testuser/TestPasswd9") 134 | checkErr(t, err) 135 | defer db.Exec("DROP TABLE mi") 
136 | defer db.Close() 137 | 138 | db.Exec("DROP TABLE mi") 139 | 140 | _, err = db.Exec( 141 | `CREATE TABLE mi ( 142 | id INT PRIMARY KEY AUTO_INCREMENT, 143 | m MEDIUMINT 144 | )`) 145 | checkErr(t, err) 146 | 147 | const n = 9 148 | 149 | for i := 0; i < n; i++ { 150 | _, err = db.Exec("INSERT mi VALUES (0, ?)", i) 151 | checkErr(t, err) 152 | } 153 | 154 | rows, err := db.Query("SELECT * FROM mi") 155 | checkErr(t, err) 156 | 157 | var i int 158 | for i = 0; rows.Next(); i++ { 159 | var id, m int 160 | checkErr(t, rows.Scan(&id, &m)) 161 | if id != i+1 || m != i { 162 | t.Fatalf("i=%d id=%d m=%d", i, id, m) 163 | } 164 | } 165 | checkErr(t, rows.Err()) 166 | if i != n { 167 | t.Fatalf("%d rows read, %d expected", i, n) 168 | } 169 | } 170 | 171 | func TestTypes(t *testing.T) { 172 | db, err := sql.Open("mymysql", "test/testuser/TestPasswd9") 173 | checkErr(t, err) 174 | defer db.Close() 175 | defer db.Exec("DROP TABLE t") 176 | 177 | db.Exec("DROP TABLE t") 178 | 179 | _, err = db.Exec( 180 | `CREATE TABLE t ( 181 | i INT NOT NULL, 182 | f DOUBLE NOT NULL, 183 | b BOOL NOT NULL, 184 | s VARCHAR(8) NOT NULL, 185 | d DATETIME NOT NULL, 186 | y DATE NOT NULL, 187 | n INT 188 | ) ENGINE=InnoDB`) 189 | checkErr(t, err) 190 | 191 | _, err = db.Exec( 192 | `INSERT t VALUES ( 193 | 23, 0.25, true, 'test', '2013-03-06 21:07', '2013-03-19', NULL 194 | )`, 195 | ) 196 | checkErr(t, err) 197 | l, err := time.LoadLocation("Local") 198 | td := time.Date(2013, 3, 6, 21, 7, 0, 0, l) 199 | dd := time.Date(2013, 3, 19, 0, 0, 0, 0, l) 200 | checkErr(t, err) 201 | _, err = db.Exec( 202 | "INSERT t VALUES (?, ?, ?, ?, ?, ?)", 203 | 23, 0.25, true, "test", td, dd, nil, 204 | ) 205 | 206 | rows, err := db.Query("SELECT * FROM t") 207 | checkErr(t, err) 208 | var ( 209 | i int64 210 | f float64 211 | b bool 212 | s string 213 | d time.Time 214 | y time.Time 215 | n sql.NullInt64 216 | ) 217 | 218 | for rows.Next() { 219 | checkErr(t, rows.Scan(&i, &f, &b, &s, &d, &y, &n)) 220 | if i 
!= 23 { 221 | t.Fatal("int64", i) 222 | } 223 | if f != 0.25 { 224 | t.Fatal("float64", f) 225 | } 226 | if b != true { 227 | t.Fatal("bool", b) 228 | } 229 | if s != "test" { 230 | t.Fatal("string", s) 231 | } 232 | if d != td { 233 | t.Fatal("time.Time", d) 234 | } 235 | if y != dd { 236 | t.Fatal("time.Time", y) 237 | } 238 | if n.Valid { 239 | t.Fatal("mysql.NullInt64", n) 240 | } 241 | } 242 | } 243 | 244 | func TestMultiple(t *testing.T) { 245 | db, err := sql.Open("mymysql", "test/testuser/TestPasswd9") 246 | checkErr(t, err) 247 | defer db.Close() 248 | defer db.Exec("DROP TABLE t") 249 | 250 | db.Exec("DROP TABLE t") 251 | _, err = db.Exec(`CREATE TABLE t ( 252 | email VARCHAR(16), 253 | password VARCHAR(16), 254 | status VARCHAR(16), 255 | signup_date DATETIME, 256 | zipcode VARCHAR(16), 257 | fname VARCHAR(16), 258 | lname VARCHAR(16) 259 | )`) 260 | checkErr(t, err) 261 | 262 | const shortFormat = "2006-01-02 15:04:05" 263 | now := time.Now() 264 | 265 | _, err = db.Exec(fmt.Sprintf(`INSERT INTO t ( 266 | email, 267 | password, 268 | status, 269 | signup_date, 270 | zipcode, 271 | fname, 272 | lname 273 | ) VALUES ( 274 | 'a@a.com', 275 | 'asdf', 276 | 'unverified', 277 | '%s', 278 | '111', 279 | 'asdf', 280 | 'asdf' 281 | );`, now.Format(mysql.TimeFormat))) 282 | checkErr(t, err) 283 | 284 | _, err = db.Exec(`INSERT INTO t ( 285 | email, 286 | password, 287 | status, 288 | signup_date, 289 | zipcode, 290 | fname, 291 | lname 292 | ) VALUES ( 293 | ?, ?, ?, ?, ?, ?, ? 
294 | );`, "a@a.com", "asdf", "unverified", now, "111", "asdf", "asdf") 295 | checkErr(t, err) 296 | 297 | _, err = db.Exec(`INSERT INTO t ( 298 | email, 299 | password, 300 | status, 301 | signup_date, 302 | zipcode, 303 | fname, 304 | lname 305 | ) VALUES ( 306 | "a@a.com", 'asdf', ?, ?, ?, ?, 'asdf' 307 | );`, "unverified", now, "111", "asdf") 308 | checkErr(t, err) 309 | 310 | rows, err := db.Query("SELECT * FROM t") 311 | checkErr(t, err) 312 | var ( 313 | email, password, status, zipcode, fname, lname string 314 | signup_date time.Time 315 | ) 316 | n := 0 317 | for rows.Next() { 318 | checkErr(t, rows.Scan( 319 | &email, &password, &status, &signup_date, &zipcode, &fname, &lname, 320 | )) 321 | if email != "a@a.com" { 322 | t.Fatal(n, "email:", email) 323 | } 324 | if password != "asdf" { 325 | t.Fatal(n, "password:", password) 326 | } 327 | if status != "unverified" { 328 | t.Fatal(n, "status:", status) 329 | 330 | } 331 | e := signup_date.Format(mysql.TimeFormat) 332 | d := signup_date.Format(mysql.TimeFormat) 333 | if e[:len(shortFormat)] != d[:len(shortFormat)] { 334 | t.Fatal(n, "signup_date:", d) 335 | } 336 | if zipcode != "111" { 337 | t.Fatal(n, "zipcode:", zipcode) 338 | } 339 | if fname != "asdf" { 340 | t.Fatal(n, "fname:", fname) 341 | } 342 | if lname != "asdf" { 343 | t.Fatal(n, "lname:", lname) 344 | } 345 | n++ 346 | } 347 | if n != 3 { 348 | t.Fatal("Too short result set") 349 | } 350 | } 351 | -------------------------------------------------------------------------------- /mysql/field.go: -------------------------------------------------------------------------------- 1 | package mysql 2 | 3 | type Field struct { 4 | Catalog string 5 | Db string 6 | Table string 7 | OrgTable string 8 | Name string 9 | OrgName string 10 | DispLen uint32 11 | // Charset uint16 12 | Flags uint16 13 | Type byte 14 | Scale byte 15 | } 16 | -------------------------------------------------------------------------------- /mysql/interface.go: 
-------------------------------------------------------------------------------- 1 | // Package mysql is a MySQL Client API written entirely in Go without any external dependences. 2 | package mysql 3 | 4 | import ( 5 | "net" 6 | "time" 7 | ) 8 | 9 | // ConnCommon is a common interface for the connection. 10 | // See mymysql/native for method documentation. 11 | type ConnCommon interface { 12 | Start(sql string, params ...interface{}) (Result, error) 13 | Prepare(sql string) (Stmt, error) 14 | 15 | Ping() error 16 | ThreadId() uint32 17 | Escape(txt string) string 18 | 19 | Query(sql string, params ...interface{}) ([]Row, Result, error) 20 | QueryFirst(sql string, params ...interface{}) (Row, Result, error) 21 | QueryLast(sql string, params ...interface{}) (Row, Result, error) 22 | } 23 | 24 | // Dialer can be used to dial connections to MySQL. If Dialer returns (nil, nil) 25 | // the hook is skipped and normal dialing proceeds. 26 | type Dialer func(proto, laddr, raddr string, timeout time.Duration) (net.Conn, error) 27 | 28 | // Conn represents connection to the MySQL server. 29 | // See mymysql/native for method documentation. 30 | type Conn interface { 31 | ConnCommon 32 | 33 | Clone() Conn 34 | SetTimeout(time.Duration) 35 | Connect() error 36 | NetConn() net.Conn 37 | SetDialer(Dialer) 38 | Close() error 39 | IsConnected() bool 40 | Reconnect() error 41 | Use(dbname string) error 42 | Register(sql string) 43 | SetMaxPktSize(new_size int) int 44 | NarrowTypeSet(narrow bool) 45 | FullFieldInfo(full bool) 46 | Status() ConnStatus 47 | Credentials() (user, passwd string) 48 | 49 | Begin() (Transaction, error) 50 | } 51 | 52 | // Transaction represents MySQL transaction. 53 | // See mymysql/native for method documentation. 54 | type Transaction interface { 55 | ConnCommon 56 | 57 | Commit() error 58 | Rollback() error 59 | Do(st Stmt) Stmt 60 | IsValid() bool 61 | } 62 | 63 | // Stmt represents MySQL prepared statement. 
64 | // See mymysql/native for method documentation. 65 | type Stmt interface { 66 | Bind(params ...interface{}) 67 | Run(params ...interface{}) (Result, error) 68 | Delete() error 69 | Reset() error 70 | SendLongData(pnum int, data interface{}, pkt_size int) error 71 | 72 | Fields() []*Field 73 | NumParam() int 74 | WarnCount() int 75 | 76 | Exec(params ...interface{}) ([]Row, Result, error) 77 | ExecFirst(params ...interface{}) (Row, Result, error) 78 | ExecLast(params ...interface{}) (Row, Result, error) 79 | } 80 | 81 | // Result represents one MySQL result set. 82 | // See mymysql/native for method documentation. 83 | type Result interface { 84 | StatusOnly() bool 85 | ScanRow(Row) error 86 | GetRow() (Row, error) 87 | 88 | MoreResults() bool 89 | NextResult() (Result, error) 90 | 91 | Fields() []*Field 92 | Map(string) int 93 | Message() string 94 | AffectedRows() uint64 95 | InsertId() uint64 96 | WarnCount() int 97 | 98 | MakeRow() Row 99 | GetRows() ([]Row, error) 100 | End() error 101 | GetFirstRow() (Row, error) 102 | GetLastRow() (Row, error) 103 | } 104 | 105 | // New can be used to establish a connection. It is set by imported engine 106 | // (see mymysql/native, mymysql/thrsafe). 107 | var New func(proto, laddr, raddr, user, passwd string, db ...string) Conn 108 | -------------------------------------------------------------------------------- /mysql/row.go: -------------------------------------------------------------------------------- 1 | package mysql 2 | 3 | import ( 4 | "bytes" 5 | "fmt" 6 | "math" 7 | "os" 8 | "reflect" 9 | "strconv" 10 | "time" 11 | ) 12 | 13 | // Row is a type for result row. It contains values for any column of received row. 14 | // 15 | // If row is a result of ordinary text query, its element can be 16 | // []byte slice, contained result text or nil if NULL is returned. 
17 | // 18 | // If it is a result of prepared statement execution, its element field can be: 19 | // intX, uintX, floatX, []byte, Date, Time, time.Time (in Local location) or nil. 20 | type Row []interface{} 21 | 22 | // Bin gets the nn-th value and returns it as []byte ([]byte{} if NULL). 23 | func (tr Row) Bin(nn int) (bin []byte) { 24 | switch data := tr[nn].(type) { 25 | case nil: 26 | // bin = []byte{} 27 | case []byte: 28 | bin = data 29 | default: 30 | buf := new(bytes.Buffer) 31 | fmt.Fprint(buf, data) 32 | bin = buf.Bytes() 33 | } 34 | return 35 | } 36 | 37 | // Str gets the nn-th value and returns it as string ("" if NULL). 38 | func (tr Row) Str(nn int) (str string) { 39 | switch data := tr[nn].(type) { 40 | case nil: 41 | // str = "" 42 | case []byte: 43 | str = string(data) 44 | case time.Time: 45 | str = TimeString(data) 46 | case time.Duration: 47 | str = DurationString(data) 48 | default: 49 | str = fmt.Sprint(data) 50 | } 51 | return 52 | } 53 | 54 | const _MAX_INT = int64(int(^uint(0) >> 1)) 55 | const _MIN_INT = -_MAX_INT - 1 56 | 57 | // IntErr gets the nn-th value and returns it as int (0 if NULL). Returns error if 58 | // conversion is impossible. 
59 | func (tr Row) IntErr(nn int) (val int, err error) { 60 | switch data := tr[nn].(type) { 61 | case nil: 62 | // nop 63 | case int32: 64 | val = int(data) 65 | case int16: 66 | val = int(data) 67 | case uint16: 68 | val = int(data) 69 | case int8: 70 | val = int(data) 71 | case uint8: 72 | val = int(data) 73 | case []byte: 74 | val, err = strconv.Atoi(string(data)) 75 | case int64: 76 | if data >= _MIN_INT && data <= _MAX_INT { 77 | val = int(data) 78 | } else { 79 | err = strconv.ErrRange 80 | } 81 | case uint32: 82 | if int64(data) <= _MAX_INT { 83 | val = int(data) 84 | } else { 85 | err = strconv.ErrRange 86 | } 87 | case uint64: 88 | if data <= uint64(_MAX_INT) { 89 | val = int(data) 90 | } else { 91 | err = strconv.ErrRange 92 | } 93 | default: 94 | err = os.ErrInvalid 95 | } 96 | return 97 | } 98 | 99 | // Int gets the nn-th value and returns it as int (0 if NULL). Panics if conversion is 100 | // impossible. 101 | func (tr Row) Int(nn int) (val int) { 102 | val, err := tr.IntErr(nn) 103 | if err != nil { 104 | panic(err) 105 | } 106 | return 107 | } 108 | 109 | // ForceInt gets the nn-th value and returns it as int. Returns 0 if value is NULL or 110 | // conversion is impossible. 111 | func (tr Row) ForceInt(nn int) (val int) { 112 | val, _ = tr.IntErr(nn) 113 | return 114 | } 115 | 116 | const _MAX_UINT = uint64(^uint(0)) 117 | 118 | // UintErr gets the nn-th value and return it as uint (0 if NULL). Returns error if 119 | // conversion is impossible. 
120 | func (tr Row) UintErr(nn int) (val uint, err error) { 121 | switch data := tr[nn].(type) { 122 | case nil: 123 | // nop 124 | case uint32: 125 | val = uint(data) 126 | case uint16: 127 | val = uint(data) 128 | case uint8: 129 | val = uint(data) 130 | case []byte: 131 | var v uint64 132 | v, err = strconv.ParseUint(string(data), 0, 0) 133 | val = uint(v) 134 | case uint64: 135 | if data <= _MAX_UINT { 136 | val = uint(data) 137 | } else { 138 | err = strconv.ErrRange 139 | } 140 | case int8, int16, int32, int64: 141 | v := reflect.ValueOf(data).Int() 142 | if v >= 0 && uint64(v) <= _MAX_UINT { 143 | val = uint(v) 144 | } else { 145 | err = strconv.ErrRange 146 | } 147 | default: 148 | err = os.ErrInvalid 149 | } 150 | return 151 | } 152 | 153 | // Uint gets the nn-th value and returns it as uint (0 if NULL). Panics if conversion is 154 | // impossible. 155 | func (tr Row) Uint(nn int) (val uint) { 156 | val, err := tr.UintErr(nn) 157 | if err != nil { 158 | panic(err) 159 | } 160 | return 161 | } 162 | 163 | // ForceUint gets the nn-th value and returns it as uint. Returns 0 if value is NULL or 164 | // conversion is impossible. 165 | func (tr Row) ForceUint(nn int) (val uint) { 166 | val, _ = tr.UintErr(nn) 167 | return 168 | } 169 | 170 | // DateErr gets the nn-th value and returns it as Date (0000-00-00 if NULL). Returns error 171 | // if conversion is impossible. 172 | func (tr Row) DateErr(nn int) (val Date, err error) { 173 | switch data := tr[nn].(type) { 174 | case nil: 175 | // nop 176 | case Date: 177 | val = data 178 | case []byte: 179 | val, err = ParseDate(string(data)) 180 | } 181 | return 182 | } 183 | 184 | // Date is like DateErr but panics if conversion is impossible. 185 | func (tr Row) Date(nn int) (val Date) { 186 | val, err := tr.DateErr(nn) 187 | if err != nil { 188 | panic(err) 189 | } 190 | return 191 | } 192 | 193 | // ForceDate is like DateErr but returns 0000-00-00 if conversion is impossible. 
194 | func (tr Row) ForceDate(nn int) (val Date) { 195 | val, _ = tr.DateErr(nn) 196 | return 197 | } 198 | 199 | // TimeErr gets the nn-th value and returns it as time.Time in loc location (zero if NULL) 200 | // Returns error if conversion is impossible. It can convert Date to time.Time. 201 | func (tr Row) TimeErr(nn int, loc *time.Location) (t time.Time, err error) { 202 | switch data := tr[nn].(type) { 203 | case nil: 204 | // nop 205 | case time.Time: 206 | if loc == time.Local { 207 | t = data 208 | } else { 209 | y, mon, d := data.Date() 210 | h, m, s := data.Clock() 211 | t = time.Date(y, mon, d, h, m, s, t.Nanosecond(), loc) 212 | } 213 | case Date: 214 | t = data.Time(loc) 215 | case []byte: 216 | t, err = ParseTime(string(data), loc) 217 | } 218 | return 219 | } 220 | 221 | // Time is like TimeErr but panics if conversion is impossible. 222 | func (tr Row) Time(nn int, loc *time.Location) (val time.Time) { 223 | val, err := tr.TimeErr(nn, loc) 224 | if err != nil { 225 | panic(err) 226 | } 227 | return 228 | } 229 | 230 | // ForceTime is like TimeErr but returns 0000-00-00 00:00:00 if conversion is 231 | // impossible. 232 | func (tr Row) ForceTime(nn int, loc *time.Location) (val time.Time) { 233 | val, _ = tr.TimeErr(nn, loc) 234 | return 235 | } 236 | 237 | // LocaltimeErr gets the nn-th value and returns it as time.Time in Local location 238 | // (zero if NULL). Returns error if conversion is impossible. 239 | // It can convert Date to time.Time. 240 | func (tr Row) LocaltimeErr(nn int) (t time.Time, err error) { 241 | switch data := tr[nn].(type) { 242 | case nil: 243 | // nop 244 | case time.Time: 245 | t = data 246 | case Date: 247 | t = data.Time(time.Local) 248 | case []byte: 249 | t, err = ParseTime(string(data), time.Local) 250 | } 251 | return 252 | } 253 | 254 | // Localtime is like LocaltimeErr but panics if conversion is impossible. 
255 | func (tr Row) Localtime(nn int) (val time.Time) { 256 | val, err := tr.LocaltimeErr(nn) 257 | if err != nil { 258 | panic(err) 259 | } 260 | return 261 | } 262 | 263 | // ForceLocaltime is like LocaltimeErr but returns 0000-00-00 00:00:00 if conversion is 264 | // impossible. 265 | func (tr Row) ForceLocaltime(nn int) (val time.Time) { 266 | val, _ = tr.LocaltimeErr(nn) 267 | return 268 | } 269 | 270 | // DurationErr gets the nn-th value and returns it as time.Duration (0 if NULL). Returns error 271 | // if conversion is impossible. 272 | func (tr Row) DurationErr(nn int) (val time.Duration, err error) { 273 | switch data := tr[nn].(type) { 274 | case nil: 275 | case time.Duration: 276 | val = data 277 | case []byte: 278 | val, err = ParseDuration(string(data)) 279 | default: 280 | err = fmt.Errorf("Can't convert `%v` to time.Duration", data) 281 | } 282 | return 283 | } 284 | 285 | // Duration is like DurationErr but panics if conversion is impossible. 286 | func (tr Row) Duration(nn int) (val time.Duration) { 287 | val, err := tr.DurationErr(nn) 288 | if err != nil { 289 | panic(err) 290 | } 291 | return 292 | } 293 | 294 | // ForceDuration is like DurationErr but returns 0 if conversion is impossible. 295 | func (tr Row) ForceDuration(nn int) (val time.Duration) { 296 | val, _ = tr.DurationErr(nn) 297 | return 298 | } 299 | 300 | // BoolErr gets the nn-th value and returns it as bool. Returns error 301 | // if conversion is impossible. 
302 | func (tr Row) BoolErr(nn int) (val bool, err error) { 303 | switch data := tr[nn].(type) { 304 | case nil: 305 | // nop 306 | case int8: 307 | val = (data != 0) 308 | case int32: 309 | val = (data != 0) 310 | case int16: 311 | val = (data != 0) 312 | case int64: 313 | val = (data != 0) 314 | case uint8: 315 | val = (data != 0) 316 | case uint32: 317 | val = (data != 0) 318 | case uint16: 319 | val = (data != 0) 320 | case uint64: 321 | val = (data != 0) 322 | case []byte: 323 | var v int64 324 | v, err = strconv.ParseInt(string(data), 0, 64) 325 | val = (v != 0) 326 | default: 327 | err = os.ErrInvalid 328 | } 329 | return 330 | } 331 | 332 | // Bool is like BoolErr but panics if conversion is impossible. 333 | func (tr Row) Bool(nn int) (val bool) { 334 | val, err := tr.BoolErr(nn) 335 | if err != nil { 336 | panic(err) 337 | } 338 | return 339 | } 340 | 341 | // ForceBool is like BoolErr but returns false if conversion is impossible. 342 | func (tr Row) ForceBool(nn int) (val bool) { 343 | val, _ = tr.BoolErr(nn) 344 | return 345 | } 346 | 347 | // Int64Err gets the nn-th value and returns it as int64 (0 if NULL). Returns error if 348 | // conversion is impossible. 349 | func (tr Row) Int64Err(nn int) (val int64, err error) { 350 | switch data := tr[nn].(type) { 351 | case nil: 352 | // nop 353 | case int64, int32, int16, int8: 354 | val = reflect.ValueOf(data).Int() 355 | case uint64, uint32, uint16, uint8: 356 | u := reflect.ValueOf(data).Uint() 357 | if u > math.MaxInt64 { 358 | err = strconv.ErrRange 359 | } else { 360 | val = int64(u) 361 | } 362 | case []byte: 363 | val, err = strconv.ParseInt(string(data), 10, 64) 364 | default: 365 | err = os.ErrInvalid 366 | } 367 | return 368 | } 369 | 370 | // Int64 gets the nn-th value and returns it as int64 (0 if NULL). 371 | // Panics if conversion is impossible. 
372 | func (tr Row) Int64(nn int) (val int64) { 373 | val, err := tr.Int64Err(nn) 374 | if err != nil { 375 | panic(err) 376 | } 377 | return 378 | } 379 | 380 | // ForceInt64 gets the nn-th value and returns it as int64. Returns 0 if value is NULL or 381 | // conversion is impossible. 382 | func (tr Row) ForceInt64(nn int) (val int64) { 383 | val, _ = tr.Int64Err(nn) 384 | return 385 | } 386 | 387 | // Uint64Err gets the nn-th value and returns it as uint64 (0 if NULL). Returns error if 388 | // conversion is impossible. 389 | func (tr Row) Uint64Err(nn int) (val uint64, err error) { 390 | switch data := tr[nn].(type) { 391 | case nil: 392 | // nop 393 | case uint64, uint32, uint16, uint8: 394 | val = reflect.ValueOf(data).Uint() 395 | case int64, int32, int16, int8: 396 | i := reflect.ValueOf(data).Int() 397 | if i < 0 { 398 | err = strconv.ErrRange 399 | } else { 400 | val = uint64(i) 401 | } 402 | case []byte: 403 | val, err = strconv.ParseUint(string(data), 10, 64) 404 | default: 405 | err = os.ErrInvalid 406 | } 407 | return 408 | } 409 | 410 | // Uint64 gets the nn-th value and returns it as uint64 (0 if NULL). 411 | // Panic if conversion is impossible. 412 | func (tr Row) Uint64(nn int) (val uint64) { 413 | val, err := tr.Uint64Err(nn) 414 | if err != nil { 415 | panic(err) 416 | } 417 | return 418 | } 419 | 420 | // ForceUint64 gets the nn-th value and returns it as uint64. Returns 0 if value is NULL or 421 | // conversion is impossible. 422 | func (tr Row) ForceUint64(nn int) (val uint64) { 423 | val, _ = tr.Uint64Err(nn) 424 | return 425 | } 426 | 427 | // FloatErr gets the nn-th value and returns it as float64 (0 if NULL). Returns error if 428 | // conversion is impossible. 
429 | func (tr Row) FloatErr(nn int) (val float64, err error) { 430 | switch data := tr[nn].(type) { 431 | case nil: 432 | // nop 433 | case float64, float32: 434 | val = reflect.ValueOf(data).Float() 435 | case int64, int32, int16, int8: 436 | i := reflect.ValueOf(data).Int() 437 | if i >= 2<<53 || i <= -(2<<53) { 438 | err = strconv.ErrRange 439 | } else { 440 | val = float64(i) 441 | } 442 | case uint64, uint32, uint16, uint8: 443 | u := reflect.ValueOf(data).Uint() 444 | if u >= 2<<53 { 445 | err = strconv.ErrRange 446 | } else { 447 | val = float64(u) 448 | } 449 | case []byte: 450 | val, err = strconv.ParseFloat(string(data), 64) 451 | default: 452 | err = os.ErrInvalid 453 | } 454 | return 455 | } 456 | 457 | // Float gets the nn-th value and returns it as float64 (0 if NULL). 458 | // Panics if conversion is impossible. 459 | func (tr Row) Float(nn int) (val float64) { 460 | val, err := tr.FloatErr(nn) 461 | if err != nil { 462 | panic(err) 463 | } 464 | return 465 | } 466 | 467 | // ForceFloat gets the nn-th value and returns it as float64. Returns 0 if value is NULL or 468 | // if conversion is impossible. 
469 | func (tr Row) ForceFloat(nn int) (val float64) { 470 | val, _ = tr.FloatErr(nn) 471 | return 472 | } 473 | -------------------------------------------------------------------------------- /mysql/status.go: -------------------------------------------------------------------------------- 1 | package mysql 2 | 3 | type ConnStatus uint16 4 | 5 | // Status of server connection 6 | const ( 7 | SERVER_STATUS_IN_TRANS ConnStatus = 0x01 // Transaction has started 8 | SERVER_STATUS_AUTOCOMMIT ConnStatus = 0x02 // Server in auto_commit mode 9 | SERVER_STATUS_MORE_RESULTS ConnStatus = 0x04 10 | SERVER_MORE_RESULTS_EXISTS ConnStatus = 0x08 // Multi query - next query exists 11 | SERVER_QUERY_NO_GOOD_INDEX_USED ConnStatus = 0x10 12 | SERVER_QUERY_NO_INDEX_USED ConnStatus = 0x20 13 | SERVER_STATUS_CURSOR_EXISTS ConnStatus = 0x40 // Server opened a read-only non-scrollable cursor for a query 14 | SERVER_STATUS_LAST_ROW_SENT ConnStatus = 0x80 15 | 16 | SERVER_STATUS_DB_DROPPED ConnStatus = 0x100 17 | SERVER_STATUS_NO_BACKSLASH_ESCAPES ConnStatus = 0x200 18 | ) 19 | -------------------------------------------------------------------------------- /mysql/types.go: -------------------------------------------------------------------------------- 1 | package mysql 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | "strconv" 7 | "strings" 8 | "time" 9 | ) 10 | 11 | // For MySQL DATE type 12 | type Date struct { 13 | Year int16 14 | Month, Day byte 15 | } 16 | 17 | func (dd Date) String() string { 18 | return fmt.Sprintf("%04d-%02d-%02d", dd.Year, dd.Month, dd.Day) 19 | } 20 | 21 | // IsZero: True if date is 0000-00-00 22 | func (dd Date) IsZero() bool { 23 | return dd.Day == 0 && dd.Month == 0 && dd.Year == 0 24 | } 25 | 26 | // Time: Converts Date to time.Time using loc location. 27 | // Converts MySQL zero to time.Time zero. 
28 | func (dd Date) Time(loc *time.Location) (t time.Time) { 29 | if !dd.IsZero() { 30 | t = time.Date( 31 | int(dd.Year), time.Month(dd.Month), int(dd.Day), 32 | 0, 0, 0, 0, 33 | loc, 34 | ) 35 | } 36 | return 37 | } 38 | 39 | // Localtime: Converts Date to time.Time using Local location. 40 | // Converts MySQL zero to time.Time zero. 41 | func (dd Date) Localtime() time.Time { 42 | return dd.Time(time.Local) 43 | } 44 | 45 | // ParseDate: Convert string date in format YYYY-MM-DD to Date. 46 | // Leading and trailing spaces are ignored. 47 | func ParseDate(str string) (dd Date, err error) { 48 | str = strings.TrimSpace(str) 49 | if str == "0000-00-00" { 50 | return 51 | } 52 | var ( 53 | y, m, d int 54 | ) 55 | if len(str) != 10 || str[4] != '-' || str[7] != '-' { 56 | goto invalid 57 | } 58 | if y, err = strconv.Atoi(str[0:4]); err != nil { 59 | return 60 | } 61 | if m, err = strconv.Atoi(str[5:7]); err != nil { 62 | return 63 | } 64 | if m < 0 || m > 12 { // MySQL permits month == 0 65 | goto invalid 66 | } 67 | if d, err = strconv.Atoi(str[8:10]); err != nil { 68 | return 69 | } 70 | if d < 0 { // MySQL permits day == 0 71 | goto invalid 72 | } 73 | switch m { 74 | case 1, 3, 5, 7, 8, 10, 12: 75 | if d > 31 { 76 | goto invalid 77 | } 78 | case 4, 6, 9, 11: 79 | if d > 30 { 80 | goto invalid 81 | } 82 | case 2: 83 | if d > 29 { 84 | goto invalid 85 | } 86 | } 87 | dd.Year = int16(y) 88 | dd.Month = byte(m) 89 | dd.Day = byte(d) 90 | return 91 | 92 | invalid: 93 | err = errors.New("Invalid MySQL DATE string: " + str) 94 | return 95 | } 96 | 97 | // Sandard MySQL datetime format 98 | const TimeFormat = "2006-01-02 15:04:05.000000000" 99 | 100 | // TimeString returns t as string in MySQL format Converts time.Time zero to MySQL zero. 
101 | func TimeString(t time.Time) string { 102 | if t.IsZero() { 103 | return "0000-00-00 00:00:00" 104 | } 105 | if t.Nanosecond() == 0 { 106 | return t.Format(TimeFormat[:19]) 107 | } 108 | return t.Format(TimeFormat) 109 | } 110 | 111 | // ParseTime: Parses string datetime in TimeFormat using loc location. 112 | // Converts MySQL zero to time.Time zero. 113 | func ParseTime(str string, loc *time.Location) (t time.Time, err error) { 114 | str = strings.TrimSpace(str) 115 | format := TimeFormat[:19] 116 | switch len(str) { 117 | case 10: 118 | if str == "0000-00-00" { 119 | return 120 | } 121 | format = format[:10] 122 | case 19: 123 | if str == "0000-00-00 00:00:00" { 124 | return 125 | } 126 | } 127 | // Don't expect 0000-00-00 00:00:00.0+ 128 | t, err = time.ParseInLocation(format, str, loc) 129 | return 130 | } 131 | 132 | // DurationString: Convert time.Duration to string representation of mysql.TIME 133 | func DurationString(d time.Duration) string { 134 | sign := 1 135 | if d < 0 { 136 | sign = -1 137 | d = -d 138 | } 139 | ns := int(d % 1e9) 140 | d /= 1e9 141 | sec := int(d % 60) 142 | d /= 60 143 | min := int(d % 60) 144 | hour := int(d/60) * sign 145 | if ns == 0 { 146 | return fmt.Sprintf("%d:%02d:%02d", hour, min, sec) 147 | } 148 | return fmt.Sprintf("%d:%02d:%02d.%09d", hour, min, sec, ns) 149 | } 150 | 151 | // ParseDuration: Parse duration from MySQL string format [+-]H+:MM:SS[.UUUUUUUUU]. 152 | // Leading and trailing spaces are ignored. If format is invalid returns nil. 
153 | func ParseDuration(str string) (dur time.Duration, err error) { 154 | str = strings.TrimSpace(str) 155 | orig := str 156 | // Check sign 157 | sign := int64(1) 158 | switch str[0] { 159 | case '-': 160 | sign = -1 161 | fallthrough 162 | case '+': 163 | str = str[1:] 164 | } 165 | var i, d int64 166 | // Find houre 167 | if nn := strings.IndexRune(str, ':'); nn != -1 { 168 | if i, err = strconv.ParseInt(str[0:nn], 10, 64); err != nil { 169 | return 170 | } 171 | d = i * 3600 172 | str = str[nn+1:] 173 | } else { 174 | goto invalid 175 | } 176 | if len(str) != 5 && len(str) != 15 || str[2] != ':' { 177 | goto invalid 178 | } 179 | if i, err = strconv.ParseInt(str[0:2], 10, 64); err != nil { 180 | return 181 | } 182 | if i < 0 || i > 59 { 183 | goto invalid 184 | } 185 | d += i * 60 186 | if i, err = strconv.ParseInt(str[3:5], 10, 64); err != nil { 187 | return 188 | } 189 | if i < 0 || i > 59 { 190 | goto invalid 191 | } 192 | d += i 193 | d *= 1e9 194 | if len(str) == 15 { 195 | if str[5] != '.' 
{ 196 | goto invalid 197 | } 198 | if i, err = strconv.ParseInt(str[6:15], 10, 64); err != nil { 199 | return 200 | } 201 | d += i 202 | } 203 | dur = time.Duration(d * sign) 204 | return 205 | 206 | invalid: 207 | err = errors.New("invalid MySQL TIME string: " + orig) 208 | return 209 | 210 | } 211 | 212 | type Blob []byte 213 | 214 | type Raw struct { 215 | Typ uint16 216 | Val *[]byte 217 | } 218 | 219 | type Timestamp struct { 220 | time.Time 221 | } 222 | 223 | func (t Timestamp) String() string { 224 | return TimeString(t.Time) 225 | } 226 | -------------------------------------------------------------------------------- /mysql/types_test.go: -------------------------------------------------------------------------------- 1 | package mysql 2 | 3 | import ( 4 | "testing" 5 | "time" 6 | ) 7 | 8 | type sio struct { 9 | in, out string 10 | } 11 | 12 | func checkRow(t *testing.T, examples []sio, conv func(string) interface{}) { 13 | row := make(Row, 1) 14 | for _, ex := range examples { 15 | row[0] = conv(ex.in) 16 | str := row.Str(0) 17 | if str != ex.out { 18 | t.Fatalf("Wrong conversion: '%s' != '%s'", str, ex.out) 19 | } 20 | } 21 | } 22 | 23 | var dates = []sio{ 24 | sio{"2121-11-22", "2121-11-22"}, 25 | sio{"0000-00-00", "0000-00-00"}, 26 | sio{" 1234-12-18 ", "1234-12-18"}, 27 | sio{"\t1234-12-18 \r\n", "1234-12-18"}, 28 | } 29 | 30 | func TestConvDate(t *testing.T) { 31 | conv := func(str string) interface{} { 32 | d, err := ParseDate(str) 33 | if err != nil { 34 | return err 35 | } 36 | return d 37 | } 38 | checkRow(t, dates, conv) 39 | } 40 | 41 | var datetimes = []sio{ 42 | sio{"2121-11-22 11:22:32", "2121-11-22 11:22:32"}, 43 | sio{" 1234-12-18 22:11:22 ", "1234-12-18 22:11:22"}, 44 | sio{"\t 1234-12-18 22:11:22 \r\n", "1234-12-18 22:11:22"}, 45 | sio{"2000-11-11", "2000-11-11 00:00:00"}, 46 | sio{"0000-00-00 00:00:00", "0000-00-00 00:00:00"}, 47 | sio{"0000-00-00", "0000-00-00 00:00:00"}, 48 | sio{"2000-11-22 11:11:11.000111222", "2000-11-22 
11:11:11.000111222"}, 49 | } 50 | 51 | func TestConvTime(t *testing.T) { 52 | conv := func(str string) interface{} { 53 | d, err := ParseTime(str, time.Local) 54 | if err != nil { 55 | return err 56 | } 57 | return d 58 | } 59 | checkRow(t, datetimes, conv) 60 | } 61 | 62 | var times = []sio{ 63 | sio{"1:23:45", "1:23:45"}, 64 | sio{"-112:23:45", "-112:23:45"}, 65 | sio{"+112:23:45", "112:23:45"}, 66 | sio{"1:60:00", "invalid MySQL TIME string: 1:60:00"}, 67 | sio{"1:00:60", "invalid MySQL TIME string: 1:00:60"}, 68 | sio{"1:23:45.000111333", "1:23:45.000111333"}, 69 | sio{"-1:23:45.000111333", "-1:23:45.000111333"}, 70 | } 71 | 72 | func TestConvDuration(t *testing.T) { 73 | conv := func(str string) interface{} { 74 | d, err := ParseDuration(str) 75 | if err != nil { 76 | return err 77 | } 78 | return d 79 | 80 | } 81 | checkRow(t, times, conv) 82 | } 83 | 84 | func TestEscapeString(t *testing.T) { 85 | txt := " \000 \n \r \\ ' \" \032 " 86 | exp := ` \0 \n \r \\ \' \" \Z ` 87 | out := escapeString(txt) 88 | if out != exp { 89 | t.Fatalf("escapeString: ret='%s' exp='%s'", out, exp) 90 | } 91 | } 92 | 93 | func TestEscapeQuotes(t *testing.T) { 94 | txt := " '' '' ' ' ' " 95 | exp := ` '''' '''' '' '' '' ` 96 | out := escapeQuotes(txt) 97 | if out != exp { 98 | t.Fatalf("escapeString: ret='%s' exp='%s'", out, exp) 99 | } 100 | } 101 | -------------------------------------------------------------------------------- /mysql/utils.go: -------------------------------------------------------------------------------- 1 | package mysql 2 | 3 | import ( 4 | "bufio" 5 | "bytes" 6 | "errors" 7 | "fmt" 8 | "io" 9 | "os" 10 | "strings" 11 | "time" 12 | "unicode" 13 | ) 14 | 15 | // Version returns mymysql version string 16 | func Version() string { 17 | return "1.5.3" 18 | } 19 | 20 | func syntaxError(ln int) error { 21 | return fmt.Errorf("syntax error at line: %d", ln) 22 | } 23 | 24 | // NewFromCF: Creates new conneection handler using configuration in cfgFile. 
Returns 25 | // connection handler and map contains unknown options. 26 | // 27 | // Config file format(example): 28 | // 29 | // # mymysql options (if some option isn't specified it defaults to "") 30 | // 31 | // DbRaddr 127.0.0.1:3306 32 | // # DbRaddr /var/run/mysqld/mysqld.sock 33 | // DbUser testuser 34 | // DbPass TestPasswd9 35 | // # optional: DbName test 36 | // # optional: DbEncd utf8 37 | // # optional: DbLaddr 127.0.0.1:0 38 | // # optional: DbTimeout 15s 39 | // 40 | // # Your options (returned in unk) 41 | // 42 | // MyOpt some text 43 | func NewFromCF(cfgFile string) (con Conn, unk map[string]string, err error) { 44 | var cf *os.File 45 | cf, err = os.Open(cfgFile) 46 | if err != nil { 47 | return 48 | } 49 | br := bufio.NewReader(cf) 50 | um := make(map[string]string) 51 | var proto, laddr, raddr, user, pass, name, encd, to string 52 | for i := 1; ; i++ { 53 | buf, isPrefix, e := br.ReadLine() 54 | if e != nil { 55 | if e == io.EOF { 56 | break 57 | } 58 | err = e 59 | return 60 | } 61 | l := string(buf) 62 | if isPrefix { 63 | err = fmt.Errorf("line %d is too long", i) 64 | return 65 | } 66 | l = strings.TrimFunc(l, unicode.IsSpace) 67 | if len(l) == 0 || l[0] == '#' { 68 | continue 69 | } 70 | n := strings.IndexFunc(l, unicode.IsSpace) 71 | if n == -1 { 72 | err = fmt.Errorf("syntax error at line: %d", i) 73 | return 74 | } 75 | v := l[:n] 76 | l = strings.TrimLeftFunc(l[n:], unicode.IsSpace) 77 | switch v { 78 | case "DbLaddr": 79 | laddr = l 80 | case "DbRaddr": 81 | raddr = l 82 | proto = "tcp" 83 | if !strings.ContainsRune(l, ':') { 84 | proto = "unix" 85 | } 86 | case "DbUser": 87 | user = l 88 | case "DbPass": 89 | pass = l 90 | case "DbName": 91 | name = l 92 | case "DbEncd": 93 | encd = l 94 | case "DbTimeout": 95 | to = l 96 | default: 97 | um[v] = l 98 | } 99 | } 100 | if raddr == "" { 101 | err = errors.New("DbRaddr option is empty") 102 | return 103 | } 104 | unk = um 105 | if name != "" { 106 | con = New(proto, laddr, raddr, user, 
pass, name) 107 | } else { 108 | con = New(proto, laddr, raddr, user, pass) 109 | } 110 | if encd != "" { 111 | con.Register(fmt.Sprintf("SET NAMES %s", encd)) 112 | } 113 | if to != "" { 114 | var timeout time.Duration 115 | timeout, err = time.ParseDuration(to) 116 | if err != nil { 117 | return 118 | } 119 | con.SetTimeout(timeout) 120 | } 121 | return 122 | } 123 | 124 | // Query: Calls Start and next calls GetRow as long as it reads all rows from the 125 | // result. Next it returns all readed rows as the slice of rows. 126 | func Query(c Conn, sql string, params ...interface{}) (rows []Row, res Result, err error) { 127 | res, err = c.Start(sql, params...) 128 | if err != nil { 129 | return 130 | } 131 | rows, err = GetRows(res) 132 | return 133 | } 134 | 135 | // QueryFirst: Calls Start and next calls GetFirstRow 136 | func QueryFirst(c Conn, sql string, params ...interface{}) (row Row, res Result, err error) { 137 | res, err = c.Start(sql, params...) 138 | if err != nil { 139 | return 140 | } 141 | row, err = GetFirstRow(res) 142 | return 143 | } 144 | 145 | // QueryLast: Calls Start and next calls GetLastRow 146 | func QueryLast(c Conn, sql string, params ...interface{}) (row Row, res Result, err error) { 147 | res, err = c.Start(sql, params...) 148 | if err != nil { 149 | return 150 | } 151 | row, err = GetLastRow(res) 152 | return 153 | } 154 | 155 | // Exec: Calls Run and next call GetRow as long as it reads all rows from the 156 | // result. Next it returns all readed rows as the slice of rows. 157 | func Exec(s Stmt, params ...interface{}) (rows []Row, res Result, err error) { 158 | res, err = s.Run(params...) 159 | if err != nil { 160 | return 161 | } 162 | rows, err = GetRows(res) 163 | return 164 | } 165 | 166 | // ExecFirst: Calls Run and next call GetFirstRow 167 | func ExecFirst(s Stmt, params ...interface{}) (row Row, res Result, err error) { 168 | res, err = s.Run(params...) 
169 | if err != nil { 170 | return 171 | } 172 | row, err = GetFirstRow(res) 173 | return 174 | } 175 | 176 | // ExecLast: Calls Run and next call GetLastRow 177 | func ExecLast(s Stmt, params ...interface{}) (row Row, res Result, err error) { 178 | res, err = s.Run(params...) 179 | if err != nil { 180 | return 181 | } 182 | row, err = GetLastRow(res) 183 | return 184 | } 185 | 186 | // GetRow: Calls r.MakeRow and next r.ScanRow. Doesn't return io.EOF error (returns nil 187 | // row insted). 188 | func GetRow(r Result) (Row, error) { 189 | row := r.MakeRow() 190 | err := r.ScanRow(row) 191 | if err != nil { 192 | if err == io.EOF { 193 | return nil, nil 194 | } 195 | return nil, err 196 | } 197 | return row, nil 198 | } 199 | 200 | // GetRows reads all rows from result and returns them as slice. 201 | func GetRows(r Result) (rows []Row, err error) { 202 | var row Row 203 | for { 204 | row, err = r.GetRow() 205 | if err != nil || row == nil { 206 | break 207 | } 208 | rows = append(rows, row) 209 | } 210 | return 211 | } 212 | 213 | // GetLastRow returns last row and discard others 214 | func GetLastRow(r Result) (Row, error) { 215 | row := r.MakeRow() 216 | err := r.ScanRow(row) 217 | if err == io.EOF { 218 | return nil, nil 219 | } 220 | for err == nil { 221 | err = r.ScanRow(row) 222 | } 223 | if err == io.EOF { 224 | return row, nil 225 | } 226 | return nil, err 227 | } 228 | 229 | // End reads all unreaded rows and discard them. This function is useful if you 230 | // don't want to use the remaining rows. It has an impact only on current 231 | // result. If there is multi result query, you must use NextResult method and 232 | // read/discard all rows in this result, before use other method that sends 233 | // data to the server. You can't use this function if last GetRow returned nil. 
234 | func End(r Result) error { 235 | _, err := GetLastRow(r) 236 | return err 237 | } 238 | 239 | // GetFirstRow returns first row and discard others 240 | func GetFirstRow(r Result) (row Row, err error) { 241 | row, err = r.GetRow() 242 | if err == nil && row != nil { 243 | err = r.End() 244 | } 245 | return 246 | } 247 | 248 | func escapeString(txt string) string { 249 | var ( 250 | esc string 251 | buf bytes.Buffer 252 | ) 253 | last := 0 254 | for ii, bb := range txt { 255 | switch bb { 256 | case 0: 257 | esc = `\0` 258 | case '\n': 259 | esc = `\n` 260 | case '\r': 261 | esc = `\r` 262 | case '\\': 263 | esc = `\\` 264 | case '\'': 265 | esc = `\'` 266 | case '"': 267 | esc = `\"` 268 | case '\032': 269 | esc = `\Z` 270 | default: 271 | continue 272 | } 273 | io.WriteString(&buf, txt[last:ii]) 274 | io.WriteString(&buf, esc) 275 | last = ii + 1 276 | } 277 | io.WriteString(&buf, txt[last:]) 278 | return buf.String() 279 | } 280 | 281 | func escapeQuotes(txt string) string { 282 | var buf bytes.Buffer 283 | last := 0 284 | for ii, bb := range txt { 285 | if bb == '\'' { 286 | io.WriteString(&buf, txt[last:ii]) 287 | io.WriteString(&buf, `''`) 288 | last = ii + 1 289 | } 290 | } 291 | io.WriteString(&buf, txt[last:]) 292 | return buf.String() 293 | } 294 | 295 | // Escape: Escapes special characters in the txt, so it is safe to place returned string 296 | // to Query method. 297 | func Escape(c Conn, txt string) string { 298 | if c.Status()&SERVER_STATUS_NO_BACKSLASH_ESCAPES != 0 { 299 | return escapeQuotes(txt) 300 | } 301 | return escapeString(txt) 302 | } 303 | -------------------------------------------------------------------------------- /native/LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2010, Michal Derkacz 2 | All rights reserved. 
3 | 4 | Redistribution and use in source and binary forms, with or without 5 | modification, are permitted provided that the following conditions 6 | are met: 7 | 1. Redistributions of source code must retain the above copyright 8 | notice, this list of conditions and the following disclaimer. 9 | 2. Redistributions in binary form must reproduce the above copyright 10 | notice, this list of conditions and the following disclaimer in the 11 | documentation and/or other materials provided with the distribution. 12 | 3. The name of the author may not be used to endorse or promote products 13 | derived from this software without specific prior written permission. 14 | 15 | THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR 16 | IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES 17 | OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 18 | IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT, 19 | INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT 20 | NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 21 | DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 22 | THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 23 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF 24 | THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
25 | -------------------------------------------------------------------------------- /native/addons.go: -------------------------------------------------------------------------------- 1 | package native 2 | 3 | func NbinToNstr(nbin *[]byte) *string { 4 | if nbin == nil { 5 | return nil 6 | } 7 | str := string(*nbin) 8 | return &str 9 | } 10 | 11 | func NstrToNbin(nstr *string) *[]byte { 12 | if nstr == nil { 13 | return nil 14 | } 15 | bin := []byte(*nstr) 16 | return &bin 17 | } 18 | -------------------------------------------------------------------------------- /native/bind_test.go: -------------------------------------------------------------------------------- 1 | package native 2 | 3 | import ( 4 | "bufio" 5 | "bytes" 6 | "github.com/ziutek/mymysql/mysql" 7 | "math" 8 | "reflect" 9 | "strconv" 10 | "testing" 11 | "time" 12 | ) 13 | 14 | var ( 15 | Bytes = []byte("Ala ma Kota!") 16 | String = "ssss" //"A kot ma Alę!" 17 | blob = mysql.Blob{1, 2, 3} 18 | dateT = time.Date(2010, 12, 30, 17, 21, 01, 0, time.Local) 19 | tstamp = mysql.Timestamp{dateT.Add(1e9)} 20 | date = mysql.Date{Year: 2011, Month: 2, Day: 3} 21 | tim = -time.Duration((5*24*3600+4*3600+3*60+2)*1e9 + 1) 22 | bol = true 23 | 24 | pBytes *[]byte 25 | pString *string 26 | pBlob *mysql.Blob 27 | pDateT *time.Time 28 | pTstamp *mysql.Timestamp 29 | pDate *mysql.Date 30 | pTim *time.Duration 31 | pBol *bool 32 | 33 | raw = mysql.Raw{MYSQL_TYPE_INT24, &[]byte{3, 2, 1, 0}} 34 | 35 | Int8 = int8(1) 36 | Uint8 = uint8(2) 37 | Int16 = int16(3) 38 | Uint16 = uint16(4) 39 | Int32 = int32(5) 40 | Uint32 = uint32(6) 41 | Int64 = int64(0x7000100020003001) 42 | Uint64 = uint64(0xffff0000ffff0000) 43 | Int = int(7) 44 | Uint = uint(8) 45 | 46 | Float32 = float32(1e10) 47 | Float64 = 256e256 48 | 49 | pInt8 *int8 50 | pUint8 *uint8 51 | pInt16 *int16 52 | pUint16 *uint16 53 | pInt32 *int32 54 | pUint32 *uint32 55 | pInt64 *int64 56 | pUint64 *uint64 57 | pInt *int 58 | pUint *uint 59 | pFloat32 *float32 60 | 
pFloat64 *float64 61 | ) 62 | 63 | type BindTest struct { 64 | val interface{} 65 | typ uint16 66 | length int 67 | } 68 | 69 | func intSize() int { 70 | switch strconv.IntSize { 71 | case 32: 72 | return 4 73 | case 64: 74 | return 8 75 | } 76 | panic("bad int size") 77 | } 78 | 79 | func intType() uint16 { 80 | switch strconv.IntSize { 81 | case 32: 82 | return MYSQL_TYPE_LONG 83 | case 64: 84 | return MYSQL_TYPE_LONGLONG 85 | } 86 | panic("bad int size") 87 | 88 | } 89 | 90 | var bindTests = []BindTest{ 91 | BindTest{nil, MYSQL_TYPE_NULL, 0}, 92 | 93 | BindTest{Bytes, MYSQL_TYPE_VAR_STRING, -1}, 94 | BindTest{String, MYSQL_TYPE_STRING, -1}, 95 | BindTest{blob, MYSQL_TYPE_BLOB, -1}, 96 | BindTest{dateT, MYSQL_TYPE_DATETIME, -1}, 97 | BindTest{tstamp, MYSQL_TYPE_TIMESTAMP, -1}, 98 | BindTest{date, MYSQL_TYPE_DATE, -1}, 99 | BindTest{tim, MYSQL_TYPE_TIME, -1}, 100 | BindTest{bol, MYSQL_TYPE_TINY, -1}, 101 | 102 | BindTest{&Bytes, MYSQL_TYPE_VAR_STRING, -1}, 103 | BindTest{&String, MYSQL_TYPE_STRING, -1}, 104 | BindTest{&blob, MYSQL_TYPE_BLOB, -1}, 105 | BindTest{&dateT, MYSQL_TYPE_DATETIME, -1}, 106 | BindTest{&tstamp, MYSQL_TYPE_TIMESTAMP, -1}, 107 | BindTest{&date, MYSQL_TYPE_DATE, -1}, 108 | BindTest{&tim, MYSQL_TYPE_TIME, -1}, 109 | 110 | BindTest{pBytes, MYSQL_TYPE_VAR_STRING, -1}, 111 | BindTest{pString, MYSQL_TYPE_STRING, -1}, 112 | BindTest{pBlob, MYSQL_TYPE_BLOB, -1}, 113 | BindTest{pDateT, MYSQL_TYPE_DATETIME, -1}, 114 | BindTest{pTstamp, MYSQL_TYPE_TIMESTAMP, -1}, 115 | BindTest{pDate, MYSQL_TYPE_DATE, -1}, 116 | BindTest{pTim, MYSQL_TYPE_TIME, -1}, 117 | BindTest{pBol, MYSQL_TYPE_TINY, -1}, 118 | 119 | BindTest{raw, MYSQL_TYPE_INT24, -1}, 120 | 121 | BindTest{Int8, MYSQL_TYPE_TINY, 1}, 122 | BindTest{Int16, MYSQL_TYPE_SHORT, 2}, 123 | BindTest{Int32, MYSQL_TYPE_LONG, 4}, 124 | BindTest{Int64, MYSQL_TYPE_LONGLONG, 8}, 125 | BindTest{Int, intType(), intSize()}, 126 | 127 | BindTest{&Int8, MYSQL_TYPE_TINY, 1}, 128 | BindTest{&Int16, MYSQL_TYPE_SHORT, 2}, 
129 | BindTest{&Int32, MYSQL_TYPE_LONG, 4}, 130 | BindTest{&Int64, MYSQL_TYPE_LONGLONG, 8}, 131 | BindTest{&Int, intType(), intSize()}, 132 | 133 | BindTest{pInt8, MYSQL_TYPE_TINY, 1}, 134 | BindTest{pInt16, MYSQL_TYPE_SHORT, 2}, 135 | BindTest{pInt32, MYSQL_TYPE_LONG, 4}, 136 | BindTest{pInt64, MYSQL_TYPE_LONGLONG, 8}, 137 | BindTest{pInt, intType(), intSize()}, 138 | 139 | BindTest{Uint8, MYSQL_TYPE_TINY | MYSQL_UNSIGNED_MASK, 1}, 140 | BindTest{Uint16, MYSQL_TYPE_SHORT | MYSQL_UNSIGNED_MASK, 2}, 141 | BindTest{Uint32, MYSQL_TYPE_LONG | MYSQL_UNSIGNED_MASK, 4}, 142 | BindTest{Uint64, MYSQL_TYPE_LONGLONG | MYSQL_UNSIGNED_MASK, 8}, 143 | BindTest{Uint, intType() | MYSQL_UNSIGNED_MASK, intSize()}, 144 | 145 | BindTest{&Uint8, MYSQL_TYPE_TINY | MYSQL_UNSIGNED_MASK, 1}, 146 | BindTest{&Uint16, MYSQL_TYPE_SHORT | MYSQL_UNSIGNED_MASK, 2}, 147 | BindTest{&Uint32, MYSQL_TYPE_LONG | MYSQL_UNSIGNED_MASK, 4}, 148 | BindTest{&Uint64, MYSQL_TYPE_LONGLONG | MYSQL_UNSIGNED_MASK, 8}, 149 | BindTest{&Uint, intType() | MYSQL_UNSIGNED_MASK, intSize()}, 150 | 151 | BindTest{pUint8, MYSQL_TYPE_TINY | MYSQL_UNSIGNED_MASK, 1}, 152 | BindTest{pUint16, MYSQL_TYPE_SHORT | MYSQL_UNSIGNED_MASK, 2}, 153 | BindTest{pUint32, MYSQL_TYPE_LONG | MYSQL_UNSIGNED_MASK, 4}, 154 | BindTest{pUint64, MYSQL_TYPE_LONGLONG | MYSQL_UNSIGNED_MASK, 8}, 155 | BindTest{pUint, intType() | MYSQL_UNSIGNED_MASK, intSize()}, 156 | 157 | BindTest{Float32, MYSQL_TYPE_FLOAT, 4}, 158 | BindTest{Float64, MYSQL_TYPE_DOUBLE, 8}, 159 | 160 | BindTest{&Float32, MYSQL_TYPE_FLOAT, 4}, 161 | BindTest{&Float64, MYSQL_TYPE_DOUBLE, 8}, 162 | } 163 | 164 | func makeAddressable(v reflect.Value) reflect.Value { 165 | if v.IsValid() { 166 | // Make an addresable value 167 | av := reflect.New(v.Type()).Elem() 168 | av.Set(v) 169 | v = av 170 | } 171 | return v 172 | } 173 | 174 | func TestBind(t *testing.T) { 175 | for _, test := range bindTests { 176 | v := makeAddressable(reflect.ValueOf(test.val)) 177 | val := bindValue(v) 178 | if 
val.typ != test.typ || val.length != test.length { 179 | t.Errorf( 180 | "Type: %s exp=0x%x res=0x%x Len: exp=%d res=%d", 181 | reflect.TypeOf(test.val), test.typ, val.typ, test.length, 182 | val.length, 183 | ) 184 | } 185 | } 186 | } 187 | 188 | type WriteTest struct { 189 | val interface{} 190 | exp []byte 191 | } 192 | 193 | var writeTest []WriteTest 194 | 195 | func encodeU16(v uint16) []byte { 196 | buf := make([]byte, 2) 197 | EncodeU16(buf, v) 198 | return buf 199 | } 200 | 201 | func encodeU24(v uint32) []byte { 202 | buf := make([]byte, 3) 203 | EncodeU24(buf, v) 204 | return buf 205 | } 206 | 207 | func encodeU32(v uint32) []byte { 208 | buf := make([]byte, 4) 209 | EncodeU32(buf, v) 210 | return buf 211 | } 212 | 213 | func encodeU64(v uint64) []byte { 214 | buf := make([]byte, 8) 215 | EncodeU64(buf, v) 216 | return buf 217 | } 218 | 219 | func encodeDuration(d time.Duration) []byte { 220 | buf := make([]byte, 13) 221 | n := EncodeDuration(buf, d) 222 | return buf[:n] 223 | } 224 | 225 | func encodeTime(t time.Time) []byte { 226 | buf := make([]byte, 12) 227 | n := EncodeTime(buf, t) 228 | return buf[:n] 229 | } 230 | 231 | func encodeDate(d mysql.Date) []byte { 232 | buf := make([]byte, 5) 233 | n := EncodeDate(buf, d) 234 | return buf[:n] 235 | } 236 | 237 | func encodeUint(u uint) []byte { 238 | switch strconv.IntSize { 239 | case 32: 240 | return encodeU32(uint32(u)) 241 | case 64: 242 | return encodeU64(uint64(u)) 243 | } 244 | panic("bad int size") 245 | 246 | } 247 | 248 | func init() { 249 | b := make([]byte, 64*1024) 250 | for ii := range b { 251 | b[ii] = byte(ii) 252 | } 253 | blob = mysql.Blob(b) 254 | 255 | writeTest = []WriteTest{ 256 | WriteTest{Bytes, append([]byte{byte(len(Bytes))}, Bytes...)}, 257 | WriteTest{String, append([]byte{byte(len(String))}, []byte(String)...)}, 258 | WriteTest{pBytes, nil}, 259 | WriteTest{pString, nil}, 260 | WriteTest{ 261 | blob, 262 | append( 263 | append([]byte{253}, byte(len(blob)), byte(len(blob)>>8), 
byte(len(blob)>>16)), 264 | []byte(blob)...), 265 | }, 266 | WriteTest{ 267 | dateT, 268 | []byte{ 269 | 7, byte(dateT.Year()), byte(dateT.Year() >> 8), 270 | byte(dateT.Month()), 271 | byte(dateT.Day()), byte(dateT.Hour()), byte(dateT.Minute()), 272 | byte(dateT.Second()), 273 | }, 274 | }, 275 | WriteTest{ 276 | &dateT, 277 | []byte{ 278 | 7, byte(dateT.Year()), byte(dateT.Year() >> 8), 279 | byte(dateT.Month()), 280 | byte(dateT.Day()), byte(dateT.Hour()), byte(dateT.Minute()), 281 | byte(dateT.Second()), 282 | }, 283 | }, 284 | WriteTest{ 285 | date, 286 | []byte{ 287 | 4, byte(date.Year), byte(date.Year >> 8), byte(date.Month), 288 | byte(date.Day), 289 | }, 290 | }, 291 | WriteTest{ 292 | &date, 293 | []byte{ 294 | 4, byte(date.Year), byte(date.Year >> 8), byte(date.Month), 295 | byte(date.Day), 296 | }, 297 | }, 298 | WriteTest{ 299 | tim, 300 | []byte{12, 1, 5, 0, 0, 0, 4, 3, 2, 1, 0, 0, 0}, 301 | }, 302 | WriteTest{ 303 | &tim, 304 | []byte{12, 1, 5, 0, 0, 0, 4, 3, 2, 1, 0, 0, 0}, 305 | }, 306 | WriteTest{bol, []byte{1}}, 307 | WriteTest{&bol, []byte{1}}, 308 | WriteTest{pBol, nil}, 309 | 310 | WriteTest{dateT, encodeTime(dateT)}, 311 | WriteTest{&dateT, encodeTime(dateT)}, 312 | WriteTest{pDateT, nil}, 313 | 314 | WriteTest{tstamp, encodeTime(tstamp.Time)}, 315 | WriteTest{&tstamp, encodeTime(tstamp.Time)}, 316 | WriteTest{pTstamp, nil}, 317 | 318 | WriteTest{date, encodeDate(date)}, 319 | WriteTest{&date, encodeDate(date)}, 320 | WriteTest{pDate, nil}, 321 | 322 | WriteTest{tim, encodeDuration(tim)}, 323 | WriteTest{&tim, encodeDuration(tim)}, 324 | WriteTest{pTim, nil}, 325 | 326 | WriteTest{Int, encodeUint(uint(Int))}, 327 | WriteTest{Int16, encodeU16(uint16(Int16))}, 328 | WriteTest{Int32, encodeU32(uint32(Int32))}, 329 | WriteTest{Int64, encodeU64(uint64(Int64))}, 330 | 331 | WriteTest{Uint, encodeUint(Uint)}, 332 | WriteTest{Uint16, encodeU16(Uint16)}, 333 | WriteTest{Uint32, encodeU32(Uint32)}, 334 | WriteTest{Uint64, encodeU64(Uint64)}, 335 | 336 
| WriteTest{&Int, encodeUint(uint(Int))}, 337 | WriteTest{&Int16, encodeU16(uint16(Int16))}, 338 | WriteTest{&Int32, encodeU32(uint32(Int32))}, 339 | WriteTest{&Int64, encodeU64(uint64(Int64))}, 340 | 341 | WriteTest{&Uint, encodeUint(Uint)}, 342 | WriteTest{&Uint16, encodeU16(Uint16)}, 343 | WriteTest{&Uint32, encodeU32(Uint32)}, 344 | WriteTest{&Uint64, encodeU64(Uint64)}, 345 | 346 | WriteTest{pInt, nil}, 347 | WriteTest{pInt16, nil}, 348 | WriteTest{pInt32, nil}, 349 | WriteTest{pInt64, nil}, 350 | 351 | WriteTest{Float32, encodeU32(math.Float32bits(Float32))}, 352 | WriteTest{Float64, encodeU64(math.Float64bits(Float64))}, 353 | 354 | WriteTest{&Float32, encodeU32(math.Float32bits(Float32))}, 355 | WriteTest{&Float64, encodeU64(math.Float64bits(Float64))}, 356 | 357 | WriteTest{pFloat32, nil}, 358 | WriteTest{pFloat64, nil}, 359 | } 360 | } 361 | 362 | func TestWrite(t *testing.T) { 363 | buf := new(bytes.Buffer) 364 | for _, test := range writeTest { 365 | buf.Reset() 366 | var seq byte 367 | pw := &pktWriter{ 368 | wr: bufio.NewWriter(buf), 369 | seq: &seq, 370 | to_write: len(test.exp), 371 | } 372 | v := makeAddressable(reflect.ValueOf(test.val)) 373 | val := bindValue(v) 374 | pw.writeValue(&val) 375 | if !reflect.Indirect(v).IsValid() && len(buf.Bytes()) == 0 { 376 | // writeValue writes nothing for nil 377 | continue 378 | } 379 | if len(buf.Bytes()) != len(test.exp)+4 || !bytes.Equal(buf.Bytes()[4:], test.exp) || val.Len() != len(test.exp) { 380 | t.Fatalf("%s - exp_len=%d res_len=%d exp: %v res: %v", 381 | reflect.TypeOf(test.val), len(test.exp), val.Len(), 382 | test.exp, buf.Bytes(), 383 | ) 384 | } 385 | } 386 | } 387 | -------------------------------------------------------------------------------- /native/binding.go: -------------------------------------------------------------------------------- 1 | package native 2 | 3 | import ( 4 | "github.com/ziutek/mymysql/mysql" 5 | "reflect" 6 | "time" 7 | ) 8 | 9 | var ( 10 | timeType = 
reflect.TypeOf(time.Time{}) 11 | timestampType = reflect.TypeOf(mysql.Timestamp{}) 12 | dateType = reflect.TypeOf(mysql.Date{}) 13 | durationType = reflect.TypeOf(time.Duration(0)) 14 | blobType = reflect.TypeOf(mysql.Blob{}) 15 | rawType = reflect.TypeOf(mysql.Raw{}) 16 | ) 17 | 18 | // val should be an addressable value 19 | func bindValue(val reflect.Value) (out paramValue) { 20 | if !val.IsValid() { 21 | out.typ = MYSQL_TYPE_NULL 22 | return 23 | } 24 | typ := val.Type() 25 | if typ.Kind() == reflect.Ptr { 26 | // We have addressable pointer 27 | out.addr = val.Addr() 28 | // Dereference pointer for next operation on its value 29 | typ = typ.Elem() 30 | val = val.Elem() 31 | } else { 32 | // We have addressable value. Create a pointer to it 33 | pv := val.Addr() 34 | // This pointer is unaddressable so copy it and return an address 35 | out.addr = reflect.New(pv.Type()) 36 | out.addr.Elem().Set(pv) 37 | } 38 | 39 | // Obtain value type 40 | switch typ.Kind() { 41 | case reflect.String: 42 | out.typ = MYSQL_TYPE_STRING 43 | out.length = -1 44 | return 45 | 46 | case reflect.Int: 47 | out.typ = _INT_TYPE 48 | out.length = _SIZE_OF_INT 49 | return 50 | 51 | case reflect.Int8: 52 | out.typ = MYSQL_TYPE_TINY 53 | out.length = 1 54 | return 55 | 56 | case reflect.Int16: 57 | out.typ = MYSQL_TYPE_SHORT 58 | out.length = 2 59 | return 60 | 61 | case reflect.Int32: 62 | out.typ = MYSQL_TYPE_LONG 63 | out.length = 4 64 | return 65 | 66 | case reflect.Int64: 67 | if typ == durationType { 68 | out.typ = MYSQL_TYPE_TIME 69 | out.length = -1 70 | return 71 | } 72 | out.typ = MYSQL_TYPE_LONGLONG 73 | out.length = 8 74 | return 75 | 76 | case reflect.Uint: 77 | out.typ = _INT_TYPE | MYSQL_UNSIGNED_MASK 78 | out.length = _SIZE_OF_INT 79 | return 80 | 81 | case reflect.Uint8: 82 | out.typ = MYSQL_TYPE_TINY | MYSQL_UNSIGNED_MASK 83 | out.length = 1 84 | return 85 | 86 | case reflect.Uint16: 87 | out.typ = MYSQL_TYPE_SHORT | MYSQL_UNSIGNED_MASK 88 | out.length = 2 89 | return 90 | 
91 | case reflect.Uint32: 92 | out.typ = MYSQL_TYPE_LONG | MYSQL_UNSIGNED_MASK 93 | out.length = 4 94 | return 95 | 96 | case reflect.Uint64: 97 | out.typ = MYSQL_TYPE_LONGLONG | MYSQL_UNSIGNED_MASK 98 | out.length = 8 99 | return 100 | 101 | case reflect.Float32: 102 | out.typ = MYSQL_TYPE_FLOAT 103 | out.length = 4 104 | return 105 | 106 | case reflect.Float64: 107 | out.typ = MYSQL_TYPE_DOUBLE 108 | out.length = 8 109 | return 110 | 111 | case reflect.Slice: 112 | out.length = -1 113 | if typ == blobType { 114 | out.typ = MYSQL_TYPE_BLOB 115 | return 116 | } 117 | if typ.Elem().Kind() == reflect.Uint8 { 118 | out.typ = MYSQL_TYPE_VAR_STRING 119 | return 120 | } 121 | 122 | case reflect.Struct: 123 | out.length = -1 124 | if typ == timeType { 125 | out.typ = MYSQL_TYPE_DATETIME 126 | return 127 | } 128 | if typ == dateType { 129 | out.typ = MYSQL_TYPE_DATE 130 | return 131 | } 132 | if typ == timestampType { 133 | out.typ = MYSQL_TYPE_TIMESTAMP 134 | return 135 | } 136 | if typ == rawType { 137 | out.typ = val.FieldByName("Typ").Interface().(uint16) 138 | out.addr = val.FieldByName("Val").Addr() 139 | out.raw = true 140 | return 141 | } 142 | 143 | case reflect.Bool: 144 | out.typ = MYSQL_TYPE_TINY 145 | // bool implementation isn't documented so we treat it in special way 146 | out.length = -1 147 | return 148 | } 149 | panic(mysql.ErrBindUnkType) 150 | } 151 | -------------------------------------------------------------------------------- /native/codecs.go: -------------------------------------------------------------------------------- 1 | package native 2 | 3 | import ( 4 | "github.com/ziutek/mymysql/mysql" 5 | "time" 6 | ) 7 | 8 | // Integers 9 | 10 | func DecodeU16(buf []byte) uint16 { 11 | return uint16(buf[1])<<8 | uint16(buf[0]) 12 | } 13 | func (pr *pktReader) readU16() uint16 { 14 | buf := pr.buf[:2] 15 | pr.readFull(buf) 16 | return DecodeU16(buf) 17 | } 18 | 19 | func DecodeU24(buf []byte) uint32 { 20 | return (uint32(buf[2])<<8|uint32(buf[1]))<<8 | 
uint32(buf[0]) 21 | } 22 | func (pr *pktReader) readU24() uint32 { 23 | buf := pr.buf[:3] 24 | pr.readFull(buf) 25 | return DecodeU24(buf) 26 | } 27 | 28 | func DecodeU32(buf []byte) uint32 { 29 | return ((uint32(buf[3])<<8|uint32(buf[2]))<<8| 30 | uint32(buf[1]))<<8 | uint32(buf[0]) 31 | } 32 | func (pr *pktReader) readU32() uint32 { 33 | buf := pr.buf[:4] 34 | pr.readFull(buf) 35 | return DecodeU32(buf) 36 | } 37 | 38 | func DecodeU64(buf []byte) (rv uint64) { 39 | for ii, vv := range buf { 40 | rv |= uint64(vv) << uint(ii*8) 41 | } 42 | return 43 | } 44 | func (pr *pktReader) readU64() (rv uint64) { 45 | buf := pr.buf[:8] 46 | pr.readFull(buf) 47 | return DecodeU64(buf) 48 | } 49 | 50 | func EncodeU16(buf []byte, val uint16) { 51 | buf[0] = byte(val) 52 | buf[1] = byte(val >> 8) 53 | } 54 | func (pw *pktWriter) writeU16(val uint16) { 55 | buf := pw.buf[:2] 56 | EncodeU16(buf, val) 57 | pw.write(buf) 58 | } 59 | 60 | func EncodeU24(buf []byte, val uint32) { 61 | buf[0] = byte(val) 62 | buf[1] = byte(val >> 8) 63 | buf[2] = byte(val >> 16) 64 | } 65 | func (pw *pktWriter) writeU24(val uint32) { 66 | buf := pw.buf[:3] 67 | EncodeU24(buf, val) 68 | pw.write(buf) 69 | } 70 | 71 | func EncodeU32(buf []byte, val uint32) { 72 | buf[0] = byte(val) 73 | buf[1] = byte(val >> 8) 74 | buf[2] = byte(val >> 16) 75 | buf[3] = byte(val >> 24) 76 | } 77 | func (pw *pktWriter) writeU32(val uint32) { 78 | buf := pw.buf[:4] 79 | EncodeU32(buf, val) 80 | pw.write(buf) 81 | } 82 | 83 | func EncodeU64(buf []byte, val uint64) { 84 | buf[0] = byte(val) 85 | buf[1] = byte(val >> 8) 86 | buf[2] = byte(val >> 16) 87 | buf[3] = byte(val >> 24) 88 | buf[4] = byte(val >> 32) 89 | buf[5] = byte(val >> 40) 90 | buf[6] = byte(val >> 48) 91 | buf[7] = byte(val >> 56) 92 | } 93 | func (pw *pktWriter) writeU64(val uint64) { 94 | buf := pw.buf[:8] 95 | EncodeU64(buf, val) 96 | pw.write(buf) 97 | } 98 | 99 | // Variable length values 100 | 101 | func (pr *pktReader) readNullLCB() (lcb uint64, null 
bool) { 102 | bb := pr.readByte() 103 | switch bb { 104 | case 251: 105 | null = true 106 | case 252: 107 | lcb = uint64(pr.readU16()) 108 | case 253: 109 | lcb = uint64(pr.readU24()) 110 | case 254: 111 | lcb = pr.readU64() 112 | default: 113 | lcb = uint64(bb) 114 | } 115 | return 116 | } 117 | 118 | func (pr *pktReader) readLCB() uint64 { 119 | lcb, null := pr.readNullLCB() 120 | if null { 121 | panic(mysql.ErrUnexpNullLCB) 122 | } 123 | return lcb 124 | } 125 | 126 | func (pw *pktWriter) writeLCB(val uint64) { 127 | switch { 128 | case val <= 250: 129 | pw.writeByte(byte(val)) 130 | 131 | case val <= 0xffff: 132 | pw.writeByte(252) 133 | pw.writeU16(uint16(val)) 134 | 135 | case val <= 0xffffff: 136 | pw.writeByte(253) 137 | pw.writeU24(uint32(val)) 138 | 139 | default: 140 | pw.writeByte(254) 141 | pw.writeU64(val) 142 | } 143 | } 144 | 145 | func lenLCB(val uint64) int { 146 | switch { 147 | case val <= 250: 148 | return 1 149 | 150 | case val <= 0xffff: 151 | return 3 152 | 153 | case val <= 0xffffff: 154 | return 4 155 | } 156 | return 9 157 | } 158 | 159 | func (pr *pktReader) readNullBin() (buf []byte, null bool) { 160 | var l uint64 161 | l, null = pr.readNullLCB() 162 | if null { 163 | return 164 | } 165 | buf = make([]byte, l) 166 | pr.readFull(buf) 167 | return 168 | } 169 | 170 | func (pr *pktReader) readBin() []byte { 171 | buf, null := pr.readNullBin() 172 | if null { 173 | panic(mysql.ErrUnexpNullLCS) 174 | } 175 | return buf 176 | } 177 | 178 | func (pr *pktReader) skipBin() { 179 | n, _ := pr.readNullLCB() 180 | pr.skipN(int(n)) 181 | } 182 | 183 | func (pw *pktWriter) writeBin(buf []byte) { 184 | pw.writeLCB(uint64(len(buf))) 185 | pw.write(buf) 186 | } 187 | 188 | func lenBin(buf []byte) int { 189 | return lenLCB(uint64(len(buf))) + len(buf) 190 | } 191 | 192 | func lenStr(str string) int { 193 | return lenLCB(uint64(len(str))) + len(str) 194 | } 195 | 196 | func (pw *pktWriter) writeLC(v interface{}) { 197 | switch val := v.(type) { 198 | 
case []byte: 199 | pw.writeBin(val) 200 | case *[]byte: 201 | pw.writeBin(*val) 202 | case string: 203 | pw.writeBin([]byte(val)) 204 | case *string: 205 | pw.writeBin([]byte(*val)) 206 | default: 207 | panic("Unknown data type for write as length coded string") 208 | } 209 | } 210 | 211 | func lenLC(v interface{}) int { 212 | switch val := v.(type) { 213 | case []byte: 214 | return lenBin(val) 215 | case *[]byte: 216 | return lenBin(*val) 217 | case string: 218 | return lenStr(val) 219 | case *string: 220 | return lenStr(*val) 221 | } 222 | panic("Unknown data type for write as length coded string") 223 | } 224 | 225 | func (pr *pktReader) readNTB() (buf []byte) { 226 | for { 227 | ch := pr.readByte() 228 | if ch == 0 { 229 | break 230 | } 231 | buf = append(buf, ch) 232 | } 233 | return 234 | } 235 | 236 | func (pw *pktWriter) writeNTB(buf []byte) { 237 | pw.write(buf) 238 | pw.writeByte(0) 239 | } 240 | 241 | func (pw *pktWriter) writeNT(v interface{}) { 242 | switch val := v.(type) { 243 | case []byte: 244 | pw.writeNTB(val) 245 | case string: 246 | pw.writeNTB([]byte(val)) 247 | default: 248 | panic("Unknown type for write as null terminated data") 249 | } 250 | } 251 | 252 | // Date and time 253 | 254 | func (pr *pktReader) readDuration() time.Duration { 255 | dlen := pr.readByte() 256 | switch dlen { 257 | case 251: 258 | // Null 259 | panic(mysql.ErrUnexpNullTime) 260 | case 0: 261 | // 00:00:00 262 | return 0 263 | case 5, 8, 12: 264 | // Properly time length 265 | default: 266 | panic(mysql.ErrWrongDateLen) 267 | } 268 | buf := pr.buf[:dlen] 269 | pr.readFull(buf) 270 | tt := int64(0) 271 | switch dlen { 272 | case 12: 273 | // Nanosecond part 274 | tt += int64(DecodeU32(buf[8:])) 275 | fallthrough 276 | case 8: 277 | // HH:MM:SS part 278 | tt += int64(int(buf[5])*3600+int(buf[6])*60+int(buf[7])) * 1e9 279 | fallthrough 280 | case 5: 281 | // Day part 282 | tt += int64(DecodeU32(buf[1:5])) * (24 * 3600 * 1e9) 283 | } 284 | if buf[0] != 0 { 285 | tt = -tt 
286 | } 287 | return time.Duration(tt) 288 | } 289 | 290 | func EncodeDuration(buf []byte, d time.Duration) int { 291 | buf[0] = 0 292 | if d < 0 { 293 | buf[1] = 1 294 | d = -d 295 | } 296 | if ns := uint32(d % 1e9); ns != 0 { 297 | EncodeU32(buf[9:13], ns) // nanosecond 298 | buf[0] += 4 299 | } 300 | d /= 1e9 301 | if hms := int(d % (24 * 3600)); buf[0] != 0 || hms != 0 { 302 | buf[8] = byte(hms % 60) // second 303 | hms /= 60 304 | buf[7] = byte(hms % 60) // minute 305 | buf[6] = byte(hms / 60) // hour 306 | buf[0] += 3 307 | } 308 | if day := uint32(d / (24 * 3600)); buf[0] != 0 || day != 0 { 309 | EncodeU32(buf[2:6], day) // day 310 | buf[0] += 4 311 | } 312 | buf[0]++ // For sign byte 313 | return int(buf[0] + 1) 314 | } 315 | 316 | func (pw *pktWriter) writeDuration(d time.Duration) { 317 | buf := pw.buf[:13] 318 | n := EncodeDuration(buf, d) 319 | pw.write(buf[:n]) 320 | } 321 | 322 | func lenDuration(d time.Duration) int { 323 | if d == 0 { 324 | return 2 325 | } 326 | if d%1e9 != 0 { 327 | return 13 328 | } 329 | d /= 1e9 330 | if d%(24*3600) != 0 { 331 | return 9 332 | } 333 | return 6 334 | } 335 | 336 | func (pr *pktReader) readTime() time.Time { 337 | dlen := pr.readByte() 338 | switch dlen { 339 | case 251: 340 | // Null 341 | panic(mysql.ErrUnexpNullDate) 342 | case 0: 343 | // return 0000-00-00 converted to time.Time zero 344 | return time.Time{} 345 | case 4, 7, 11: 346 | // Properly datetime length 347 | default: 348 | panic(mysql.ErrWrongDateLen) 349 | } 350 | 351 | buf := pr.buf[:dlen] 352 | pr.readFull(buf) 353 | var y, mon, d, h, m, s, u int 354 | switch dlen { 355 | case 11: 356 | // 2006-01-02 15:04:05.001004005 357 | u = int(DecodeU32(buf[7:])) 358 | fallthrough 359 | case 7: 360 | // 2006-01-02 15:04:05 361 | h = int(buf[4]) 362 | m = int(buf[5]) 363 | s = int(buf[6]) 364 | fallthrough 365 | case 4: 366 | // 2006-01-02 367 | y = int(DecodeU16(buf[0:2])) 368 | mon = int(buf[2]) 369 | d = int(buf[3]) 370 | } 371 | n := u * 
int(time.Microsecond) 372 | return time.Date(y, time.Month(mon), d, h, m, s, n, time.Local) 373 | } 374 | 375 | func encodeNonzeroTime(buf []byte, y int16, mon, d, h, m, s byte, u uint32) int { 376 | buf[0] = 0 377 | switch { 378 | case u != 0: 379 | EncodeU32(buf[8:12], u) 380 | buf[0] += 4 381 | fallthrough 382 | case s != 0 || m != 0 || h != 0: 383 | buf[7] = s 384 | buf[6] = m 385 | buf[5] = h 386 | buf[0] += 3 387 | } 388 | buf[4] = d 389 | buf[3] = mon 390 | EncodeU16(buf[1:3], uint16(y)) 391 | buf[0] += 4 392 | return int(buf[0] + 1) 393 | } 394 | 395 | func getTimeMicroseconds(t time.Time) int { 396 | return (t.Nanosecond() + int(time.Microsecond/2)) / int(time.Microsecond) 397 | } 398 | 399 | func EncodeTime(buf []byte, t time.Time) int { 400 | if t.IsZero() { 401 | // MySQL zero 402 | buf[0] = 0 403 | return 1 // MySQL zero 404 | } 405 | y, mon, d := t.Date() 406 | h, m, s := t.Clock() 407 | u:= getTimeMicroseconds(t) 408 | return encodeNonzeroTime( 409 | buf, 410 | int16(y), byte(mon), byte(d), 411 | byte(h), byte(m), byte(s), uint32(u), 412 | ) 413 | } 414 | 415 | func (pw *pktWriter) writeTime(t time.Time) { 416 | buf := pw.buf[:12] 417 | n := EncodeTime(buf, t) 418 | pw.write(buf[:n]) 419 | } 420 | 421 | func lenTime(t time.Time) int { 422 | switch { 423 | case t.IsZero(): 424 | return 1 425 | case getTimeMicroseconds(t) != 0: 426 | return 12 427 | case t.Second() != 0 || t.Minute() != 0 || t.Hour() != 0: 428 | return 8 429 | } 430 | return 5 431 | } 432 | 433 | func (pr *pktReader) readDate() mysql.Date { 434 | y, m, d := pr.readTime().Date() 435 | return mysql.Date{int16(y), byte(m), byte(d)} 436 | } 437 | 438 | func EncodeDate(buf []byte, d mysql.Date) int { 439 | if d.IsZero() { 440 | // MySQL zero 441 | buf[0] = 0 442 | return 1 443 | } 444 | return encodeNonzeroTime(buf, d.Year, d.Month, d.Day, 0, 0, 0, 0) 445 | } 446 | 447 | func (pw *pktWriter) writeDate(d mysql.Date) { 448 | buf := pw.buf[:5] 449 | n := EncodeDate(buf, d) 450 | 
pw.write(buf[:n]) 451 | } 452 | 453 | func lenDate(d mysql.Date) int { 454 | if d.IsZero() { 455 | return 1 456 | } 457 | return 5 458 | } 459 | -------------------------------------------------------------------------------- /native/command.go: -------------------------------------------------------------------------------- 1 | package native 2 | 3 | import ( 4 | "log" 5 | ) 6 | 7 | //import "log" 8 | 9 | // _COM_QUIT, _COM_STATISTICS, _COM_PROCESS_INFO, _COM_DEBUG, _COM_PING: 10 | func (my *Conn) sendCmd(cmd byte) { 11 | my.seq = 0 12 | pw := my.newPktWriter(1) 13 | pw.writeByte(cmd) 14 | if my.Debug { 15 | log.Printf("[%2d <-] Command packet: Cmd=0x%x", my.seq-1, cmd) 16 | } 17 | } 18 | 19 | // _COM_QUERY, _COM_INIT_DB, _COM_CREATE_DB, _COM_DROP_DB, _COM_STMT_PREPARE: 20 | func (my *Conn) sendCmdStr(cmd byte, s string) { 21 | my.seq = 0 22 | pw := my.newPktWriter(1 + len(s)) 23 | pw.writeByte(cmd) 24 | pw.write([]byte(s)) 25 | if my.Debug { 26 | log.Printf("[%2d <-] Command packet: Cmd=0x%x %s", my.seq-1, cmd, s) 27 | } 28 | } 29 | 30 | // _COM_PROCESS_KILL, _COM_STMT_CLOSE, _COM_STMT_RESET: 31 | func (my *Conn) sendCmdU32(cmd byte, u uint32) { 32 | my.seq = 0 33 | pw := my.newPktWriter(1 + 4) 34 | pw.writeByte(cmd) 35 | pw.writeU32(u) 36 | if my.Debug { 37 | log.Printf("[%2d <-] Command packet: Cmd=0x%x %d", my.seq-1, cmd, u) 38 | } 39 | } 40 | 41 | func (my *Conn) sendLongData(stmtid uint32, pnum uint16, data []byte) { 42 | my.seq = 0 43 | pw := my.newPktWriter(1 + 4 + 2 + len(data)) 44 | pw.writeByte(_COM_STMT_SEND_LONG_DATA) 45 | pw.writeU32(stmtid) // Statement ID 46 | pw.writeU16(pnum) // Parameter number 47 | pw.write(data) // payload 48 | if my.Debug { 49 | log.Printf("[%2d <-] SendLongData packet: pnum=%d", my.seq-1, pnum) 50 | } 51 | } 52 | 53 | /*func (my *Conn) sendCmd(cmd byte, argv ...interface{}) { 54 | // Reset sequence number 55 | my.seq = 0 56 | // Write command 57 | switch cmd { 58 | case _COM_QUERY, _COM_INIT_DB, _COM_CREATE_DB, _COM_DROP_DB, 
59 | _COM_STMT_PREPARE: 60 | pw := my.newPktWriter(1 + lenBS(argv[0])) 61 | writeByte(pw, cmd) 62 | writeBS(pw, argv[0]) 63 | 64 | case _COM_STMT_SEND_LONG_DATA: 65 | pw := my.newPktWriter(1 + 4 + 2 + lenBS(argv[2])) 66 | writeByte(pw, cmd) 67 | writeU32(pw, argv[0].(uint32)) // Statement ID 68 | writeU16(pw, argv[1].(uint16)) // Parameter number 69 | writeBS(pw, argv[2]) // payload 70 | 71 | case _COM_QUIT, _COM_STATISTICS, _COM_PROCESS_INFO, _COM_DEBUG, _COM_PING: 72 | pw := my.newPktWriter(1) 73 | writeByte(pw, cmd) 74 | 75 | case _COM_FIELD_LIST: 76 | pay_len := 1 + lenBS(argv[0]) + 1 77 | if len(argv) > 1 { 78 | pay_len += lenBS(argv[1]) 79 | } 80 | 81 | pw := my.newPktWriter(pay_len) 82 | writeByte(pw, cmd) 83 | writeNT(pw, argv[0]) 84 | if len(argv) > 1 { 85 | writeBS(pw, argv[1]) 86 | } 87 | 88 | case _COM_TABLE_DUMP: 89 | pw := my.newPktWriter(1 + lenLC(argv[0]) + lenLC(argv[1])) 90 | writeByte(pw, cmd) 91 | writeLC(pw, argv[0]) 92 | writeLC(pw, argv[1]) 93 | 94 | case _COM_REFRESH, _COM_SHUTDOWN: 95 | pw := my.newPktWriter(1 + 1) 96 | writeByte(pw, cmd) 97 | writeByte(pw, argv[0].(byte)) 98 | 99 | case _COM_STMT_FETCH: 100 | pw := my.newPktWriter(1 + 4 + 4) 101 | writeByte(pw, cmd) 102 | writeU32(pw, argv[0].(uint32)) 103 | writeU32(pw, argv[1].(uint32)) 104 | 105 | case _COM_PROCESS_KILL, _COM_STMT_CLOSE, _COM_STMT_RESET: 106 | pw := my.newPktWriter(1 + 4) 107 | writeByte(pw, cmd) 108 | writeU32(pw, argv[0].(uint32)) 109 | 110 | case _COM_SET_OPTION: 111 | pw := my.newPktWriter(1 + 2) 112 | writeByte(pw, cmd) 113 | writeU16(pw, argv[0].(uint16)) 114 | 115 | case _COM_CHANGE_USER: 116 | pw := my.newPktWriter( 117 | 1 + lenBS(argv[0]) + 1 + lenLC(argv[1]) + lenBS(argv[2]) + 1, 118 | ) 119 | writeByte(pw, cmd) 120 | writeNT(pw, argv[0]) // User name 121 | writeLC(pw, argv[1]) // Scrambled password 122 | writeNT(pw, argv[2]) // Database name 123 | //writeU16(pw, argv[3]) // Character set number (since 5.1.23?) 
124 | 125 | case _COM_BINLOG_DUMP: 126 | pay_len := 1 + 4 + 2 + 4 127 | if len(argv) > 3 { 128 | pay_len += lenBS(argv[3]) 129 | } 130 | 131 | pw := my.newPktWriter(pay_len) 132 | writeByte(pw, cmd) 133 | writeU32(pw, argv[0].(uint32)) // Start position 134 | writeU16(pw, argv[1].(uint16)) // Flags 135 | writeU32(pw, argv[2].(uint32)) // Slave server id 136 | if len(argv) > 3 { 137 | writeBS(pw, argv[3]) 138 | } 139 | 140 | // TODO: case COM_REGISTER_SLAVE: 141 | 142 | default: 143 | panic("Unknown code for MySQL command") 144 | } 145 | 146 | if my.Debug { 147 | log.Printf("[%2d <-] Command packet: Cmd=0x%x", my.seq-1, cmd) 148 | } 149 | }*/ 150 | -------------------------------------------------------------------------------- /native/common.go: -------------------------------------------------------------------------------- 1 | package native 2 | 3 | import ( 4 | "io" 5 | "runtime" 6 | ) 7 | 8 | var tab8s = " " 9 | 10 | func catchError(err *error) { 11 | if pv := recover(); pv != nil { 12 | switch e := pv.(type) { 13 | case runtime.Error: 14 | panic(pv) 15 | case error: 16 | if e == io.EOF { 17 | *err = io.ErrUnexpectedEOF 18 | } else { 19 | *err = e 20 | } 21 | default: 22 | panic(pv) 23 | } 24 | } 25 | } 26 | -------------------------------------------------------------------------------- /native/consts.go: -------------------------------------------------------------------------------- 1 | package native 2 | 3 | import "strconv" 4 | 5 | // Client caps - borrowed from GoMySQL 6 | const ( 7 | _CLIENT_LONG_PASSWORD = 1 << iota // new more secure passwords 8 | _CLIENT_FOUND_ROWS // Found instead of affected rows 9 | _CLIENT_LONG_FLAG // Get all column flags 10 | _CLIENT_CONNECT_WITH_DB // One can specify db on connect 11 | _CLIENT_NO_SCHEMA // Don't allow database.table.column 12 | _CLIENT_COMPRESS // Can use compression protocol 13 | _CLIENT_ODBC // Odbc client 14 | _CLIENT_LOCAL_FILES // Can use LOAD DATA LOCAL 15 | _CLIENT_IGNORE_SPACE // Ignore spaces before 
'(' 16 | _CLIENT_PROTOCOL_41 // New 4.1 protocol 17 | _CLIENT_INTERACTIVE // This is an interactive client 18 | _CLIENT_SSL // Switch to SSL after handshake 19 | _CLIENT_IGNORE_SIGPIPE // IGNORE sigpipes 20 | _CLIENT_TRANSACTIONS // Client knows about transactions 21 | _CLIENT_RESERVED // Old flag for 4.1 protocol 22 | _CLIENT_SECURE_CONN // New 4.1 authentication 23 | _CLIENT_MULTI_STATEMENTS // Enable/disable multi-stmt support 24 | _CLIENT_MULTI_RESULTS // Enable/disable multi-results 25 | _CLIENT_PS_MULTI_RESULTS // Enable/disable multiple resultsets for COM_STMT_EXECUTE 26 | _CLIENT_PLUGIN_AUTH // Supports authentication plugins 27 | _CLIENT_CONNECT_ATTRS // Sends connection attributes in Protocol::HandshakeResponse41 28 | _CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA // Length of auth response data in Protocol::HandshakeResponse41 is a length-encoded integer 29 | _CLIENT_CAN_HANDLE_EXPIRED_PASSWORDS // Enable/disable expired passwords 30 | _CLIENT_SESSION_TRACK // Can set SERVER_SESSION_STATE_CHANGED in the Status Flags and send session-state change data after a OK packet. 
31 | _CLIENT_DEPRECATE_EOF // Expects an OK (instead of EOF) after the resultset rows of a Text Resultset 32 | ) 33 | 34 | // Commands - borrowed from GoMySQL 35 | const ( 36 | _COM_QUIT = 0x01 37 | _COM_INIT_DB = 0x02 38 | _COM_QUERY = 0x03 39 | _COM_FIELD_LIST = 0x04 40 | _COM_CREATE_DB = 0x05 41 | _COM_DROP_DB = 0x06 42 | _COM_REFRESH = 0x07 43 | _COM_SHUTDOWN = 0x08 44 | _COM_STATISTICS = 0x09 45 | _COM_PROCESS_INFO = 0x0a 46 | _COM_CONNECT = 0x0b 47 | _COM_PROCESS_KILL = 0x0c 48 | _COM_DEBUG = 0x0d 49 | _COM_PING = 0x0e 50 | _COM_TIME = 0x0f 51 | _COM_DELAYED_INSERT = 0x10 52 | _COM_CHANGE_USER = 0x11 53 | _COM_BINLOG_DUMP = 0x12 54 | _COM_TABLE_DUMP = 0x13 55 | _COM_CONNECT_OUT = 0x14 56 | _COM_REGISTER_SLAVE = 0x15 57 | _COM_STMT_PREPARE = 0x16 58 | _COM_STMT_EXECUTE = 0x17 59 | _COM_STMT_SEND_LONG_DATA = 0x18 60 | _COM_STMT_CLOSE = 0x19 61 | _COM_STMT_RESET = 0x1a 62 | _COM_SET_OPTION = 0x1b 63 | _COM_STMT_FETCH = 0x1c 64 | ) 65 | 66 | // MySQL protocol types. 67 | // 68 | // mymysql uses only some of them for send data to the MySQL server. Used 69 | // MySQL types are marked with a comment contains mymysql type that uses it. 
70 | const ( 71 | MYSQL_TYPE_DECIMAL = 0x00 72 | MYSQL_TYPE_TINY = 0x01 // int8, uint8, bool 73 | MYSQL_TYPE_SHORT = 0x02 // int16, uint16 74 | MYSQL_TYPE_LONG = 0x03 // int32, uint32 75 | MYSQL_TYPE_FLOAT = 0x04 // float32 76 | MYSQL_TYPE_DOUBLE = 0x05 // float64 77 | MYSQL_TYPE_NULL = 0x06 // nil 78 | MYSQL_TYPE_TIMESTAMP = 0x07 // Timestamp 79 | MYSQL_TYPE_LONGLONG = 0x08 // int64, uint64 80 | MYSQL_TYPE_INT24 = 0x09 81 | MYSQL_TYPE_DATE = 0x0a // Date 82 | MYSQL_TYPE_TIME = 0x0b // Time 83 | MYSQL_TYPE_DATETIME = 0x0c // time.Time 84 | MYSQL_TYPE_YEAR = 0x0d 85 | MYSQL_TYPE_NEWDATE = 0x0e 86 | MYSQL_TYPE_VARCHAR = 0x0f 87 | MYSQL_TYPE_BIT = 0x10 88 | MYSQL_TYPE_NEWDECIMAL = 0xf6 89 | MYSQL_TYPE_ENUM = 0xf7 90 | MYSQL_TYPE_SET = 0xf8 91 | MYSQL_TYPE_TINY_BLOB = 0xf9 92 | MYSQL_TYPE_MEDIUM_BLOB = 0xfa 93 | MYSQL_TYPE_LONG_BLOB = 0xfb 94 | MYSQL_TYPE_BLOB = 0xfc // Blob 95 | MYSQL_TYPE_VAR_STRING = 0xfd // []byte 96 | MYSQL_TYPE_STRING = 0xfe // string 97 | MYSQL_TYPE_GEOMETRY = 0xff 98 | 99 | MYSQL_UNSIGNED_MASK = uint16(1 << 15) 100 | ) 101 | 102 | // Mapping of MySQL types to (prefered) protocol types. Use it if you create 103 | // your own Raw value. 104 | // 105 | // Comments contains corresponding types used by mymysql. string type may be 106 | // replaced by []byte type and vice versa. []byte type is native for sending 107 | // on a network, so any string is converted to it before sending. Than for 108 | // better performance use []byte. 
109 | const ( 110 | // Client send and receive, mymysql representation for send / receive 111 | TINYINT = MYSQL_TYPE_TINY // int8 / int8 112 | SMALLINT = MYSQL_TYPE_SHORT // int16 / int16 113 | INT = MYSQL_TYPE_LONG // int32 / int32 114 | BIGINT = MYSQL_TYPE_LONGLONG // int64 / int64 115 | FLOAT = MYSQL_TYPE_FLOAT // float32 / float32 116 | DOUBLE = MYSQL_TYPE_DOUBLE // float64 / float32 117 | TIME = MYSQL_TYPE_TIME // Time / Time 118 | DATE = MYSQL_TYPE_DATE // Date / Date 119 | DATETIME = MYSQL_TYPE_DATETIME // time.Time / time.Time 120 | TIMESTAMP = MYSQL_TYPE_TIMESTAMP // Timestamp / time.Time 121 | CHAR = MYSQL_TYPE_STRING // string / []byte 122 | BLOB = MYSQL_TYPE_BLOB // Blob / []byte 123 | NULL = MYSQL_TYPE_NULL // nil 124 | 125 | // Client send only, mymysql representation for send 126 | OUT_TEXT = MYSQL_TYPE_STRING // string 127 | OUT_VARCHAR = MYSQL_TYPE_STRING // string 128 | OUT_BINARY = MYSQL_TYPE_BLOB // Blob 129 | OUT_VARBINARY = MYSQL_TYPE_BLOB // Blob 130 | 131 | // Client receive only, mymysql representation for receive 132 | IN_MEDIUMINT = MYSQL_TYPE_LONG // int32 133 | IN_YEAR = MYSQL_TYPE_SHORT // int16 134 | IN_BINARY = MYSQL_TYPE_STRING // []byte 135 | IN_VARCHAR = MYSQL_TYPE_VAR_STRING // []byte 136 | IN_VARBINARY = MYSQL_TYPE_VAR_STRING // []byte 137 | IN_TINYBLOB = MYSQL_TYPE_TINY_BLOB // []byte 138 | IN_TINYTEXT = MYSQL_TYPE_TINY_BLOB // []byte 139 | IN_TEXT = MYSQL_TYPE_BLOB // []byte 140 | IN_MEDIUMBLOB = MYSQL_TYPE_MEDIUM_BLOB // []byte 141 | IN_MEDIUMTEXT = MYSQL_TYPE_MEDIUM_BLOB // []byte 142 | IN_LONGBLOB = MYSQL_TYPE_LONG_BLOB // []byte 143 | IN_LONGTEXT = MYSQL_TYPE_LONG_BLOB // []byte 144 | 145 | // MySQL 5.x specific 146 | IN_DECIMAL = MYSQL_TYPE_NEWDECIMAL // TODO 147 | IN_BIT = MYSQL_TYPE_BIT // []byte 148 | ) 149 | 150 | // Flags - borrowed from GoMySQL 151 | const ( 152 | _FLAG_NOT_NULL = 1 << iota 153 | _FLAG_PRI_KEY 154 | _FLAG_UNIQUE_KEY 155 | _FLAG_MULTIPLE_KEY 156 | _FLAG_BLOB 157 | _FLAG_UNSIGNED 158 | _FLAG_ZEROFILL 
159 | _FLAG_BINARY 160 | _FLAG_ENUM 161 | _FLAG_AUTO_INCREMENT 162 | _FLAG_TIMESTAMP 163 | _FLAG_SET 164 | _FLAG_NO_DEFAULT_VALUE 165 | ) 166 | 167 | var ( 168 | _SIZE_OF_INT int 169 | _INT_TYPE uint16 170 | ) 171 | 172 | func init() { 173 | switch strconv.IntSize { 174 | case 32: 175 | _INT_TYPE = MYSQL_TYPE_LONG 176 | _SIZE_OF_INT = 4 177 | case 64: 178 | _INT_TYPE = MYSQL_TYPE_LONGLONG 179 | _SIZE_OF_INT = 8 180 | default: 181 | panic("bad int size") 182 | } 183 | } 184 | -------------------------------------------------------------------------------- /native/init.go: -------------------------------------------------------------------------------- 1 | package native 2 | 3 | import ( 4 | "crypto/rand" 5 | "crypto/rsa" 6 | "crypto/sha1" 7 | "crypto/x509" 8 | "encoding/pem" 9 | "log" 10 | 11 | "github.com/ziutek/mymysql/mysql" 12 | ) 13 | 14 | func (my *Conn) init() { 15 | my.seq = 0 // Reset sequence number, mainly for reconnect 16 | if my.Debug { 17 | log.Printf("[%2d ->] Init packet:", my.seq) 18 | } 19 | pr := my.newPktReader() 20 | 21 | my.info.prot_ver = pr.readByte() 22 | my.info.serv_ver = pr.readNTB() 23 | my.info.thr_id = pr.readU32() 24 | pr.readFull(my.info.scramble[0:8]) 25 | pr.skipN(1) 26 | my.info.caps = uint32(pr.readU16()) // lower two bytes 27 | my.info.lang = pr.readByte() 28 | my.status = mysql.ConnStatus(pr.readU16()) 29 | my.info.caps = uint32(pr.readU16())<<16 | my.info.caps // upper two bytes 30 | pr.skipN(11) 31 | if my.info.caps&_CLIENT_PROTOCOL_41 != 0 { 32 | pr.readFull(my.info.scramble[8:]) 33 | } 34 | pr.skipN(1) // reserved (all [00]) 35 | if my.info.caps&_CLIENT_PLUGIN_AUTH != 0 { 36 | my.info.plugin = pr.readNTB() 37 | } 38 | pr.skipAll() // Skip other information 39 | if my.Debug { 40 | log.Printf(tab8s+"ProtVer=%d, ServVer=\"%s\" Status=0x%x", 41 | my.info.prot_ver, my.info.serv_ver, my.status, 42 | ) 43 | } 44 | if my.info.caps&_CLIENT_PROTOCOL_41 == 0 { 45 | panic(mysql.ErrOldProtocol) 46 | } 47 | } 48 | 49 | func (my *Conn) 
auth() { 50 | if my.Debug { 51 | log.Printf("[%2d <-] Authentication packet", my.seq) 52 | } 53 | flags := uint32( 54 | _CLIENT_PROTOCOL_41 | 55 | _CLIENT_LONG_PASSWORD | 56 | _CLIENT_LONG_FLAG | 57 | _CLIENT_TRANSACTIONS | 58 | _CLIENT_SECURE_CONN | 59 | _CLIENT_LOCAL_FILES | 60 | _CLIENT_MULTI_STATEMENTS | 61 | _CLIENT_MULTI_RESULTS) 62 | // Reset flags not supported by server 63 | flags &= uint32(my.info.caps) | 0xffff0000 64 | if my.plugin != string(my.info.plugin) { 65 | my.plugin = string(my.info.plugin) 66 | } 67 | var scrPasswd []byte 68 | switch my.plugin { 69 | case "caching_sha2_password": 70 | flags |= _CLIENT_PLUGIN_AUTH 71 | scrPasswd = encryptedSHA256Passwd(my.passwd, my.info.scramble[:]) 72 | case "mysql_old_password": 73 | my.oldPasswd() 74 | return 75 | default: 76 | // mysql_native_password by default 77 | scrPasswd = encryptedPasswd(my.passwd, my.info.scramble[:]) 78 | } 79 | 80 | // encode length of the auth plugin data 81 | var authRespLEIBuf [9]byte 82 | authRespLEI := appendLengthEncodedInteger(authRespLEIBuf[:0], uint64(len(scrPasswd))) 83 | if len(authRespLEI) > 1 { 84 | // if the length can not be written in 1 byte, it must be written as a 85 | // length encoded integer 86 | flags |= _CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA 87 | } 88 | 89 | pay_len := 4 + 4 + 1 + 23 + len(my.user) + 1 + len(authRespLEI) + len(scrPasswd) + 21 + 1 90 | 91 | if len(my.dbname) > 0 { 92 | pay_len += len(my.dbname) + 1 93 | flags |= _CLIENT_CONNECT_WITH_DB 94 | } 95 | pw := my.newPktWriter(pay_len) 96 | pw.writeU32(flags) 97 | pw.writeU32(uint32(my.max_pkt_size)) 98 | pw.writeByte(my.info.lang) // Charset number 99 | pw.writeZeros(23) // Filler 100 | pw.writeNTB([]byte(my.user)) // Username 101 | pw.writeBin(scrPasswd) // Encrypted password 102 | 103 | // write database name 104 | if len(my.dbname) > 0 { 105 | pw.writeNTB([]byte(my.dbname)) 106 | } 107 | 108 | // write plugin name 109 | if my.plugin != "" { 110 | pw.writeNTB([]byte(my.plugin)) 111 | } else { 112 
| pw.writeNTB([]byte("mysql_native_password")) 113 | } 114 | return 115 | } 116 | 117 | func (my *Conn) authResponse() { 118 | // Read Result Packet 119 | authData, newPlugin := my.getAuthResult() 120 | 121 | // handle auth plugin switch, if requested 122 | if newPlugin != "" { 123 | var scrPasswd []byte 124 | if len(authData) >= 20 { 125 | // old_password's len(authData) == 0 126 | copy(my.info.scramble[:], authData[:20]) 127 | } 128 | my.info.plugin = []byte(newPlugin) 129 | my.plugin = newPlugin 130 | switch my.plugin { 131 | case "caching_sha2_password": 132 | scrPasswd = encryptedSHA256Passwd(my.passwd, my.info.scramble[:]) 133 | case "mysql_old_password": 134 | scrPasswd = encryptedOldPassword(my.passwd, my.info.scramble[:]) 135 | // append \0 after old_password 136 | scrPasswd = append(scrPasswd, 0) 137 | case "sha256_password": 138 | // request public key from server 139 | scrPasswd = []byte{1} 140 | default: // mysql_native_password 141 | scrPasswd = encryptedPasswd(my.passwd, my.info.scramble[:]) 142 | } 143 | my.writeAuthSwitchPacket(scrPasswd) 144 | 145 | // Read Result Packet 146 | authData, newPlugin = my.getAuthResult() 147 | 148 | // Do not allow to change the auth plugin more than once 149 | if newPlugin != "" { 150 | return 151 | } 152 | } 153 | 154 | switch my.plugin { 155 | 156 | // https://insidemysql.com/preparing-your-community-connector-for-mysql-8-part-2-sha256/ 157 | case "caching_sha2_password": 158 | switch len(authData) { 159 | case 0: 160 | return // auth successful 161 | case 1: 162 | switch authData[0] { 163 | case 3: // cachingSha2PasswordFastAuthSuccess 164 | my.getResult(nil, nil) 165 | 166 | case 4: // cachingSha2PasswordPerformFullAuthentication 167 | // request public key from server 168 | pw := my.newPktWriter(1) 169 | pw.writeByte(2) 170 | 171 | // parse public key 172 | pr := my.newPktReader() 173 | pr.skipN(1) 174 | data := pr.readAll() 175 | block, _ := pem.Decode(data) 176 | pkix, err := 
x509.ParsePKIXPublicKey(block.Bytes) 177 | if err != nil { 178 | panic(mysql.ErrAuthentication) 179 | } 180 | pubKey := pkix.(*rsa.PublicKey) 181 | 182 | // send encrypted password 183 | my.sendEncryptedPassword(my.info.scramble[:], pubKey) 184 | my.getResult(nil, nil) 185 | } 186 | } 187 | case "sha256_password": 188 | switch len(authData) { 189 | case 0: 190 | return // auth successful 191 | default: 192 | // parse public key 193 | block, _ := pem.Decode(authData) 194 | pub, err := x509.ParsePKIXPublicKey(block.Bytes) 195 | if err != nil { 196 | panic(mysql.ErrAuthentication) 197 | } 198 | 199 | // send encrypted password 200 | my.sendEncryptedPassword(my.info.scramble[:], pub.(*rsa.PublicKey)) 201 | my.getResult(nil, nil) 202 | } 203 | } 204 | return 205 | } 206 | 207 | // http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchResponse 208 | func (my *Conn) writeAuthSwitchPacket(scrPasswd []byte) { 209 | pw := my.newPktWriter(len(scrPasswd)) 210 | pw.write(scrPasswd) // Encrypted password 211 | return 212 | } 213 | 214 | func (my *Conn) sendEncryptedPassword(seed []byte, pub *rsa.PublicKey) { 215 | enc, err := encryptPassword(my.passwd, seed, pub) 216 | if err != nil { 217 | panic(mysql.ErrAuthentication) 218 | } 219 | my.writeAuthSwitchPacket(enc) 220 | } 221 | 222 | func encryptPassword(password string, seed []byte, pub *rsa.PublicKey) ([]byte, error) { 223 | plain := make([]byte, len(password)+1) 224 | copy(plain, password) 225 | for i := range plain { 226 | j := i % len(seed) 227 | plain[i] ^= seed[j] 228 | } 229 | sha1 := sha1.New() 230 | return rsa.EncryptOAEP(sha1, rand.Reader, pub, plain, nil) 231 | } 232 | 233 | func (my *Conn) oldPasswd() { 234 | if my.Debug { 235 | log.Printf("[%2d <-] Password packet", my.seq) 236 | } 237 | scrPasswd := encryptedOldPassword(my.passwd, my.info.scramble[:]) 238 | pw := my.newPktWriter(len(scrPasswd) + 1) 239 | pw.write(scrPasswd) 240 | pw.writeByte(0) 241 | } 242 | 
-------------------------------------------------------------------------------- /native/mysql.go: -------------------------------------------------------------------------------- 1 | // Package native is a thread unsafe engine for MyMySQL. 2 | package native 3 | 4 | import ( 5 | "bufio" 6 | "fmt" 7 | "io" 8 | "net" 9 | "reflect" 10 | "strings" 11 | "time" 12 | 13 | "github.com/ziutek/mymysql/mysql" 14 | ) 15 | 16 | type serverInfo struct { 17 | prot_ver byte 18 | serv_ver []byte 19 | thr_id uint32 20 | scramble [20]byte 21 | caps uint32 22 | lang byte 23 | plugin []byte 24 | } 25 | 26 | // MySQL connection handler 27 | type Conn struct { 28 | proto string // Network protocol 29 | laddr string // Local address 30 | raddr string // Remote (server) address 31 | 32 | user string // MySQL username 33 | passwd string // MySQL password 34 | dbname string // Database name 35 | plugin string // authentication plugin 36 | 37 | net_conn net.Conn // MySQL connection 38 | rd *bufio.Reader 39 | wr *bufio.Writer 40 | 41 | info serverInfo // MySQL server information 42 | seq byte // MySQL sequence number 43 | 44 | unreaded_reply bool 45 | 46 | init_cmds []string // MySQL commands/queries executed after connect 47 | stmt_map map[uint32]*Stmt // For reprepare during reconnect 48 | 49 | // Current status of MySQL server connection 50 | status mysql.ConnStatus 51 | 52 | // Maximum packet size that client can accept from server. 53 | // Default 16*1024*1024-1. You may change it before connect. 54 | max_pkt_size int 55 | 56 | // Timeout for connect 57 | timeout time.Duration 58 | 59 | dialer mysql.Dialer 60 | 61 | // Return only types accepted by godrv 62 | narrowTypeSet bool 63 | // Store full information about fields in result 64 | fullFieldInfo bool 65 | 66 | // Debug logging. You may change it at any time. 67 | Debug bool 68 | } 69 | 70 | // New: Create new MySQL handler. The first three arguments are passed to net.Bind 71 | // for create connection. 
user and passwd are for authentication. Optional db 72 | // is database name (you may not specify it and use Use() method later). 73 | func New(proto, laddr, raddr, user, passwd string, args ...string) mysql.Conn { 74 | my := Conn{ 75 | proto: proto, 76 | laddr: laddr, 77 | raddr: raddr, 78 | plugin: "mysql_native_password", 79 | user: user, 80 | passwd: passwd, 81 | stmt_map: make(map[uint32]*Stmt), 82 | max_pkt_size: 16*1024*1024 - 1, 83 | timeout: 2 * time.Minute, 84 | fullFieldInfo: true, 85 | } 86 | if len(args) == 1 { 87 | my.dbname = args[0] 88 | } else if len(args) == 2 { 89 | my.dbname = args[0] 90 | my.plugin = args[1] 91 | } else if len(args) > 2 { 92 | panic("mymy.New: too many arguments") 93 | } 94 | return &my 95 | } 96 | 97 | func (my *Conn) Credentials() (user, passwd string) { 98 | return my.user, my.passwd 99 | } 100 | 101 | func (my *Conn) NarrowTypeSet(narrow bool) { 102 | my.narrowTypeSet = narrow 103 | } 104 | 105 | func (my *Conn) FullFieldInfo(full bool) { 106 | my.fullFieldInfo = full 107 | } 108 | 109 | // Clone: Creates new (not connected) connection using configuration from current 110 | // connection. 111 | func (my *Conn) Clone() mysql.Conn { 112 | var c *Conn 113 | if my.dbname == "" { 114 | c = New(my.proto, my.laddr, my.raddr, my.user, my.passwd).(*Conn) 115 | } else { 116 | c = New(my.proto, my.laddr, my.raddr, my.user, my.passwd, my.dbname).(*Conn) 117 | } 118 | c.max_pkt_size = my.max_pkt_size 119 | c.timeout = my.timeout 120 | c.Debug = my.Debug 121 | return c 122 | } 123 | 124 | // SetMaxPktSize: If new_size > 0 sets maximum packet size. Returns old size. 
125 | func (my *Conn) SetMaxPktSize(new_size int) int { 126 | old_size := my.max_pkt_size 127 | if new_size > 0 { 128 | my.max_pkt_size = new_size 129 | } 130 | return old_size 131 | } 132 | 133 | // SetTimeout sets timeout for Connect and Reconnect 134 | func (my *Conn) SetTimeout(timeout time.Duration) { 135 | my.timeout = timeout 136 | } 137 | 138 | // NetConn return internall net.Conn 139 | func (my *Conn) NetConn() net.Conn { 140 | return my.net_conn 141 | } 142 | 143 | type timeoutError struct{} 144 | 145 | func (e *timeoutError) Error() string { return "i/o timeout" } 146 | func (e *timeoutError) Timeout() bool { return true } 147 | func (e *timeoutError) Temporary() bool { return true } 148 | 149 | type stringAddr struct { 150 | net, addr string 151 | } 152 | 153 | func (a stringAddr) Network() string { return a.net } 154 | func (a stringAddr) String() string { return a.addr } 155 | 156 | var DefaultDialer mysql.Dialer = func(proto, laddr, raddr string, 157 | timeout time.Duration) (net.Conn, error) { 158 | 159 | if proto == "" { 160 | proto = "unix" 161 | if strings.ContainsRune(raddr, ':') { 162 | proto = "tcp" 163 | } 164 | } 165 | 166 | // Make a connection 167 | d := &net.Dialer{Timeout: timeout} 168 | if laddr != "" { 169 | var err error 170 | switch proto { 171 | case "tcp", "tcp4", "tcp6": 172 | d.LocalAddr, err = net.ResolveTCPAddr(proto, laddr) 173 | case "unix": 174 | d.LocalAddr, err = net.ResolveTCPAddr(proto, laddr) 175 | default: 176 | err = net.UnknownNetworkError(proto) 177 | } 178 | if err != nil { 179 | return nil, err 180 | } 181 | } 182 | return d.Dial(proto, raddr) 183 | } 184 | 185 | func (my *Conn) SetDialer(d mysql.Dialer) { 186 | my.dialer = d 187 | } 188 | 189 | func (my *Conn) connect() (err error) { 190 | defer catchError(&err) 191 | 192 | my.net_conn = nil 193 | if my.dialer != nil { 194 | my.net_conn, err = my.dialer(my.proto, my.laddr, my.raddr, my.timeout) 195 | if err != nil { 196 | my.net_conn = nil 197 | return 198 | } 
199 | } 200 | if my.net_conn == nil { 201 | my.net_conn, err = DefaultDialer(my.proto, my.laddr, my.raddr, my.timeout) 202 | if err != nil { 203 | my.net_conn = nil 204 | return 205 | } 206 | } 207 | my.rd = bufio.NewReader(my.net_conn) 208 | my.wr = bufio.NewWriter(my.net_conn) 209 | 210 | // Initialisation 211 | my.init() 212 | my.auth() 213 | my.authResponse() 214 | 215 | // Execute all registered commands 216 | for _, cmd := range my.init_cmds { 217 | // Send command 218 | my.sendCmdStr(_COM_QUERY, cmd) 219 | // Get command response 220 | res := my.getResponse() 221 | 222 | // Read and discard all result rows 223 | row := res.MakeRow() 224 | for res != nil { 225 | // Only read rows if they exist 226 | if !res.StatusOnly() { 227 | //read each row in this set 228 | for { 229 | err = res.getRow(row) 230 | if err == io.EOF { 231 | break 232 | } else if err != nil { 233 | return 234 | } 235 | } 236 | } 237 | 238 | // Move to the next result 239 | if res, err = res.nextResult(); err != nil { 240 | return 241 | } 242 | } 243 | } 244 | 245 | return 246 | } 247 | 248 | // Connect: Establishes a connection with MySQL server version 4.1 or later. 
249 | func (my *Conn) Connect() (err error) { 250 | if my.net_conn != nil { 251 | return mysql.ErrAlredyConn 252 | } 253 | 254 | return my.connect() 255 | } 256 | 257 | // IsConnected checks if connection is established 258 | func (my *Conn) IsConnected() bool { 259 | return my.net_conn != nil 260 | } 261 | 262 | func (my *Conn) closeConn() (err error) { 263 | defer catchError(&err) 264 | 265 | // Always close and invalidate connection, even if 266 | // COM_QUIT returns an error 267 | defer func() { 268 | err = my.net_conn.Close() 269 | my.net_conn = nil // Mark that we disconnect 270 | }() 271 | 272 | // Close the connection 273 | my.sendCmd(_COM_QUIT) 274 | return 275 | } 276 | 277 | // Close connection to the server 278 | func (my *Conn) Close() (err error) { 279 | if my.net_conn == nil { 280 | return mysql.ErrNotConn 281 | } 282 | if my.unreaded_reply { 283 | return mysql.ErrUnreadedReply 284 | } 285 | 286 | return my.closeConn() 287 | } 288 | 289 | // Reconnect: Close and reopen connection. 290 | // Ignore unreaded rows, reprepare all prepared statements. 291 | func (my *Conn) Reconnect() (err error) { 292 | if my.net_conn != nil { 293 | // Close connection, ignore all errors 294 | my.closeConn() 295 | } 296 | // Reopen the connection. 297 | if err = my.connect(); err != nil { 298 | return 299 | } 300 | 301 | // Reprepare all prepared statements 302 | var ( 303 | new_stmt *Stmt 304 | new_map = make(map[uint32]*Stmt) 305 | ) 306 | for _, stmt := range my.stmt_map { 307 | new_stmt, err = my.prepare(stmt.sql) 308 | if err != nil { 309 | return 310 | } 311 | // Assume that fields set in new_stmt by prepare() are indentical to 312 | // corresponding fields in stmt. Why can they be different? 
313 | stmt.id = new_stmt.id 314 | stmt.rebind = true 315 | new_map[stmt.id] = stmt 316 | } 317 | // Replace the stmt_map 318 | my.stmt_map = new_map 319 | 320 | return 321 | } 322 | 323 | // Use: Change database 324 | func (my *Conn) Use(dbname string) (err error) { 325 | defer catchError(&err) 326 | 327 | if my.net_conn == nil { 328 | return mysql.ErrNotConn 329 | } 330 | if my.unreaded_reply { 331 | return mysql.ErrUnreadedReply 332 | } 333 | 334 | // Send command 335 | my.sendCmdStr(_COM_INIT_DB, dbname) 336 | // Get server response 337 | my.getResult(nil, nil) 338 | // Save new database name if no errors 339 | my.dbname = dbname 340 | 341 | return 342 | } 343 | 344 | func (my *Conn) getResponse() (res *Result) { 345 | res = my.getResult(nil, nil) 346 | if res == nil { 347 | panic(mysql.ErrBadResult) 348 | } 349 | my.unreaded_reply = !res.StatusOnly() 350 | return 351 | } 352 | 353 | // Start new query. 354 | // 355 | // If you specify the parameters, the SQL string will be a result of 356 | // fmt.Sprintf(sql, params...). 357 | // You must get all result rows (if they exists) before next query. 358 | func (my *Conn) Start(sql string, params ...interface{}) (res mysql.Result, err error) { 359 | defer catchError(&err) 360 | 361 | if my.net_conn == nil { 362 | return nil, mysql.ErrNotConn 363 | } 364 | if my.unreaded_reply { 365 | return nil, mysql.ErrUnreadedReply 366 | } 367 | 368 | if len(params) != 0 { 369 | sql = fmt.Sprintf(sql, params...) 370 | } 371 | // Send query 372 | my.sendCmdStr(_COM_QUERY, sql) 373 | 374 | // Get command response 375 | res = my.getResponse() 376 | return 377 | } 378 | 379 | func (res *Result) getRow(row mysql.Row) (err error) { 380 | defer catchError(&err) 381 | 382 | if res.my.getResult(res, row) != nil { 383 | return io.EOF 384 | } 385 | return nil 386 | } 387 | 388 | // MoreResults returns true if more results exixts. You don't have to call it before 389 | // NextResult method (NextResult returns nil if there is no more results). 
390 | func (res *Result) MoreResults() bool { 391 | return res.status&mysql.SERVER_MORE_RESULTS_EXISTS != 0 392 | } 393 | 394 | // ScanRow gets the data row from server. This method reads one row of result set 395 | // directly from network connection (without rows buffering on client side). 396 | // Returns io.EOF if there is no more rows in current result set. 397 | func (res *Result) ScanRow(row mysql.Row) error { 398 | if row == nil { 399 | return mysql.ErrRowLength 400 | } 401 | if res.eor_returned { 402 | return mysql.ErrReadAfterEOR 403 | } 404 | if res.StatusOnly() { 405 | // There is no fields in result (OK result) 406 | res.eor_returned = true 407 | return io.EOF 408 | } 409 | err := res.getRow(row) 410 | if err == io.EOF { 411 | res.eor_returned = true 412 | if !res.MoreResults() { 413 | res.my.unreaded_reply = false 414 | } 415 | } 416 | return err 417 | } 418 | 419 | // GetRow: Like ScanRow but allocates memory for every row. 420 | // Returns nil row insted of io.EOF error. 421 | func (res *Result) GetRow() (mysql.Row, error) { 422 | return mysql.GetRow(res) 423 | } 424 | 425 | func (res *Result) nextResult() (next *Result, err error) { 426 | defer catchError(&err) 427 | if res.MoreResults() { 428 | next = res.my.getResponse() 429 | } 430 | return 431 | } 432 | 433 | // NextResult is used when last query was the multi result query or 434 | // procedure call. Returns the next result or nil if no more resuts exists. 435 | // 436 | // Statements within the procedure may produce unknown number of result sets. 437 | // The final result from the procedure is a status result that includes no 438 | // result set (Result.StatusOnly() == true) . 439 | func (res *Result) NextResult() (mysql.Result, error) { 440 | if !res.MoreResults() { 441 | return nil, nil 442 | } 443 | res, err := res.nextResult() 444 | return res, err 445 | } 446 | 447 | // Ping: Send MySQL PING to the server. 
448 | func (my *Conn) Ping() (err error) { 449 | defer catchError(&err) 450 | 451 | if my.net_conn == nil { 452 | return mysql.ErrNotConn 453 | } 454 | if my.unreaded_reply { 455 | return mysql.ErrUnreadedReply 456 | } 457 | 458 | // Send command 459 | my.sendCmd(_COM_PING) 460 | // Get server response 461 | my.getResult(nil, nil) 462 | 463 | return 464 | } 465 | 466 | func (my *Conn) prepare(sql string) (stmt *Stmt, err error) { 467 | defer catchError(&err) 468 | 469 | // Send command 470 | my.sendCmdStr(_COM_STMT_PREPARE, sql) 471 | // Get server response 472 | stmt, ok := my.getPrepareResult(nil).(*Stmt) 473 | if !ok { 474 | return nil, mysql.ErrBadResult 475 | } 476 | if len(stmt.params) > 0 { 477 | // Get param fields 478 | my.getPrepareResult(stmt) 479 | } 480 | if len(stmt.fields) > 0 { 481 | // Get column fields 482 | my.getPrepareResult(stmt) 483 | } 484 | return 485 | } 486 | 487 | // Prepare server side statement. Return statement handler. 488 | func (my *Conn) Prepare(sql string) (mysql.Stmt, error) { 489 | if my.net_conn == nil { 490 | return nil, mysql.ErrNotConn 491 | } 492 | if my.unreaded_reply { 493 | return nil, mysql.ErrUnreadedReply 494 | } 495 | 496 | stmt, err := my.prepare(sql) 497 | if err != nil { 498 | return nil, err 499 | } 500 | // Connect statement with database handler 501 | my.stmt_map[stmt.id] = stmt 502 | // Save SQL for reconnect 503 | stmt.sql = sql 504 | 505 | return stmt, nil 506 | } 507 | 508 | // Bind input data for the parameter markers in the SQL statement that was 509 | // passed to Prepare. 510 | // 511 | // params may be a parameter list (slice), a struct or a pointer to the struct. 512 | // A struct field can by value or pointer to value. A parameter (slice element) 513 | // can be value, pointer to value or pointer to pointer to value. 514 | // Values may be of the folowind types: intXX, uintXX, floatXX, bool, []byte, 515 | // Blob, string, Time, Date, Time, Timestamp, Raw. 
516 | func (stmt *Stmt) Bind(params ...interface{}) { 517 | stmt.rebind = true 518 | 519 | if len(params) == 1 { 520 | // Check for struct binding 521 | pval := reflect.ValueOf(params[0]) 522 | kind := pval.Kind() 523 | if kind == reflect.Ptr { 524 | // Dereference pointer 525 | pval = pval.Elem() 526 | kind = pval.Kind() 527 | } 528 | typ := pval.Type() 529 | if kind == reflect.Struct && 530 | typ != timeType && 531 | typ != dateType && 532 | typ != timestampType && 533 | typ != rawType { 534 | // We have a struct to bind 535 | if pval.NumField() != stmt.param_count { 536 | panic(mysql.ErrBindCount) 537 | } 538 | if !pval.CanAddr() { 539 | // Make an addressable structure 540 | v := reflect.New(pval.Type()).Elem() 541 | v.Set(pval) 542 | pval = v 543 | } 544 | for ii := 0; ii < stmt.param_count; ii++ { 545 | stmt.params[ii] = bindValue(pval.Field(ii)) 546 | } 547 | stmt.binded = true 548 | return 549 | } 550 | } 551 | 552 | // There isn't struct to bind 553 | 554 | if len(params) != stmt.param_count { 555 | panic(mysql.ErrBindCount) 556 | } 557 | for ii, par := range params { 558 | pval := reflect.ValueOf(par) 559 | if pval.IsValid() { 560 | if pval.Kind() == reflect.Ptr { 561 | // Dereference pointer - this value i addressable 562 | pval = pval.Elem() 563 | } else { 564 | // Make an addressable value 565 | v := reflect.New(pval.Type()).Elem() 566 | v.Set(pval) 567 | pval = v 568 | } 569 | } 570 | stmt.params[ii] = bindValue(pval) 571 | } 572 | stmt.binded = true 573 | } 574 | 575 | // Run executes prepared statement. If statement requires parameters you may bind 576 | // them first or specify directly. After this command you may use GetRow to 577 | // retrieve data. 
578 | func (stmt *Stmt) Run(params ...interface{}) (res mysql.Result, err error) { 579 | defer catchError(&err) 580 | 581 | if stmt.my.net_conn == nil { 582 | return nil, mysql.ErrNotConn 583 | } 584 | if stmt.my.unreaded_reply { 585 | return nil, mysql.ErrUnreadedReply 586 | } 587 | 588 | // Bind parameters if any 589 | if len(params) != 0 { 590 | stmt.Bind(params...) 591 | } else if stmt.param_count != 0 && !stmt.binded { 592 | panic(mysql.ErrBindCount) 593 | } 594 | 595 | // Send EXEC command with binded parameters 596 | stmt.sendCmdExec() 597 | // Get response 598 | r := stmt.my.getResponse() 599 | r.binary = true 600 | res = r 601 | return 602 | } 603 | 604 | // Delete: Destroy statement on server side. Client side handler is invalid after this 605 | // command. 606 | func (stmt *Stmt) Delete() (err error) { 607 | defer catchError(&err) 608 | 609 | if stmt.my.net_conn == nil { 610 | return mysql.ErrNotConn 611 | } 612 | if stmt.my.unreaded_reply { 613 | return mysql.ErrUnreadedReply 614 | } 615 | 616 | // Allways delete statement on client side, even if 617 | // the command return an error. 618 | defer func() { 619 | // Delete statement from stmt_map 620 | delete(stmt.my.stmt_map, stmt.id) 621 | // Invalidate handler 622 | *stmt = Stmt{} 623 | }() 624 | 625 | // Send command 626 | stmt.my.sendCmdU32(_COM_STMT_CLOSE, stmt.id) 627 | return 628 | } 629 | 630 | // Reset: Resets a prepared statement on server: data sent to the server, unbuffered 631 | // result sets and current errors. 632 | func (stmt *Stmt) Reset() (err error) { 633 | defer catchError(&err) 634 | 635 | if stmt.my.net_conn == nil { 636 | return mysql.ErrNotConn 637 | } 638 | if stmt.my.unreaded_reply { 639 | return mysql.ErrUnreadedReply 640 | } 641 | 642 | // Next exec must send type information. We set rebind flag regardless of 643 | // whether the command succeeds or not. 
644 | stmt.rebind = true 645 | // Send command 646 | stmt.my.sendCmdU32(_COM_STMT_RESET, stmt.id) 647 | // Get result 648 | stmt.my.getResult(nil, nil) 649 | return 650 | } 651 | 652 | // SendLongData: Send long data to MySQL server in chunks. 653 | // You can call this method after Bind and before Exec. It can be called 654 | // multiple times for one parameter to send TEXT or BLOB data in chunks. 655 | // 656 | // pnum - Parameter number to associate the data with. 657 | // 658 | // data - Data source string, []byte or io.Reader. 659 | // 660 | // pkt_size - It must be must be greater than 6 and less or equal to MySQL 661 | // max_allowed_packet variable. You can obtain value of this variable 662 | // using such query: SHOW variables WHERE Variable_name = 'max_allowed_packet' 663 | // If data source is io.Reader then (pkt_size - 6) is size of a buffer that 664 | // will be allocated for reading. 665 | // 666 | // If you have data source of type string or []byte in one piece you may 667 | // properly set pkt_size and call this method once. If you have data in 668 | // multiple pieces you can call this method multiple times. If data source is 669 | // io.Reader you should properly set pkt_size. Data will be readed from 670 | // io.Reader and send in pieces to the server until EOF. 
671 | func (stmt *Stmt) SendLongData(pnum int, data interface{}, pkt_size int) (err error) { 672 | defer catchError(&err) 673 | 674 | if stmt.my.net_conn == nil { 675 | return mysql.ErrNotConn 676 | } 677 | if stmt.my.unreaded_reply { 678 | return mysql.ErrUnreadedReply 679 | } 680 | if pnum < 0 || pnum >= stmt.param_count { 681 | return mysql.ErrWrongParamNum 682 | } 683 | if pkt_size -= 6; pkt_size < 0 { 684 | return mysql.ErrSmallPktSize 685 | } 686 | 687 | switch dd := data.(type) { 688 | case io.Reader: 689 | buf := make([]byte, pkt_size) 690 | for { 691 | nn, ee := dd.Read(buf) 692 | if nn != 0 { 693 | stmt.my.sendLongData(stmt.id, uint16(pnum), buf[0:nn]) 694 | } 695 | if ee == io.EOF { 696 | return 697 | } 698 | if ee != nil { 699 | return ee 700 | } 701 | } 702 | 703 | case []byte: 704 | for len(dd) > pkt_size { 705 | stmt.my.sendLongData(stmt.id, uint16(pnum), dd[0:pkt_size]) 706 | dd = dd[pkt_size:] 707 | } 708 | stmt.my.sendLongData(stmt.id, uint16(pnum), dd) 709 | return 710 | 711 | case string: 712 | for len(dd) > pkt_size { 713 | stmt.my.sendLongData( 714 | stmt.id, 715 | uint16(pnum), 716 | []byte(dd[0:pkt_size]), 717 | ) 718 | dd = dd[pkt_size:] 719 | } 720 | stmt.my.sendLongData(stmt.id, uint16(pnum), []byte(dd)) 721 | return 722 | } 723 | return mysql.ErrUnkDataType 724 | } 725 | 726 | // ThreadId returns the thread ID of the current connection. 727 | func (my *Conn) ThreadId() uint32 { 728 | return my.info.thr_id 729 | } 730 | 731 | // Register MySQL command/query to be executed immediately after connecting to 732 | // the server. You may register multiple commands. They will be executed in 733 | // the order of registration. Yhis method is mainly useful for reconnect. 
734 | func (my *Conn) Register(sql string) { 735 | my.init_cmds = append(my.init_cmds, sql) 736 | } 737 | 738 | // Query: See mysql.Query 739 | func (my *Conn) Query(sql string, params ...interface{}) ([]mysql.Row, mysql.Result, error) { 740 | return mysql.Query(my, sql, params...) 741 | } 742 | 743 | // QueryFirst: See mysql.QueryFirst 744 | func (my *Conn) QueryFirst(sql string, params ...interface{}) (mysql.Row, mysql.Result, error) { 745 | return mysql.QueryFirst(my, sql, params...) 746 | } 747 | 748 | // QueryLast: See mysql.QueryLast 749 | func (my *Conn) QueryLast(sql string, params ...interface{}) (mysql.Row, mysql.Result, error) { 750 | return mysql.QueryLast(my, sql, params...) 751 | } 752 | 753 | // Exec: See mysql.Exec 754 | func (stmt *Stmt) Exec(params ...interface{}) ([]mysql.Row, mysql.Result, error) { 755 | return mysql.Exec(stmt, params...) 756 | } 757 | 758 | // ExecFirst: See mysql.ExecFirst 759 | func (stmt *Stmt) ExecFirst(params ...interface{}) (mysql.Row, mysql.Result, error) { 760 | return mysql.ExecFirst(stmt, params...) 761 | } 762 | 763 | // ExecLast: See mysql.ExecLast 764 | func (stmt *Stmt) ExecLast(params ...interface{}) (mysql.Row, mysql.Result, error) { 765 | return mysql.ExecLast(stmt, params...) 766 | } 767 | 768 | // End: See mysql.End 769 | func (res *Result) End() error { 770 | return mysql.End(res) 771 | } 772 | 773 | // GetFirstRow: See mysql.GetFirstRow 774 | func (res *Result) GetFirstRow() (mysql.Row, error) { 775 | return mysql.GetFirstRow(res) 776 | } 777 | 778 | // GetLastRow: See mysql.GetLastRow 779 | func (res *Result) GetLastRow() (mysql.Row, error) { 780 | return mysql.GetLastRow(res) 781 | } 782 | 783 | // GetRows: See mysql.GetRows 784 | func (res *Result) GetRows() ([]mysql.Row, error) { 785 | return mysql.GetRows(res) 786 | } 787 | 788 | // Escape: Escapes special characters in the txt, so it is safe to place returned string 789 | // to Query method. 
790 | func (my *Conn) Escape(txt string) string { 791 | return mysql.Escape(my, txt) 792 | } 793 | 794 | func (my *Conn) Status() mysql.ConnStatus { 795 | return my.status 796 | } 797 | 798 | type Transaction struct { 799 | *Conn 800 | } 801 | 802 | // Begin starts a new transaction 803 | func (my *Conn) Begin() (mysql.Transaction, error) { 804 | _, err := my.Start("START TRANSACTION") 805 | return &Transaction{my}, err 806 | } 807 | 808 | // Commit a transaction 809 | func (tr Transaction) Commit() error { 810 | _, err := tr.Start("COMMIT") 811 | tr.Conn = nil // Invalidate this transaction 812 | return err 813 | } 814 | 815 | // Rollback a transaction 816 | func (tr Transaction) Rollback() error { 817 | _, err := tr.Start("ROLLBACK") 818 | tr.Conn = nil // Invalidate this transaction 819 | return err 820 | } 821 | 822 | func (tr Transaction) IsValid() bool { 823 | return tr.Conn != nil 824 | } 825 | 826 | // Do: Binds statement to the context of transaction. For native engine this is 827 | // identity function. 
828 | func (tr Transaction) Do(st mysql.Stmt) mysql.Stmt { 829 | if s, ok := st.(*Stmt); !ok || s.my != tr.Conn { 830 | panic("Transaction and statement doesn't belong to the same connection") 831 | } 832 | return st 833 | } 834 | 835 | func init() { 836 | mysql.New = New 837 | } 838 | -------------------------------------------------------------------------------- /native/packet.go: -------------------------------------------------------------------------------- 1 | package native 2 | 3 | import ( 4 | "bufio" 5 | "github.com/ziutek/mymysql/mysql" 6 | "io" 7 | "io/ioutil" 8 | ) 9 | 10 | type pktReader struct { 11 | rd *bufio.Reader 12 | seq *byte 13 | remain int 14 | last bool 15 | buf [12]byte 16 | ibuf [3]byte 17 | } 18 | 19 | func (my *Conn) newPktReader() *pktReader { 20 | return &pktReader{rd: my.rd, seq: &my.seq} 21 | } 22 | 23 | func (pr *pktReader) readHeader() { 24 | // Read next packet header 25 | buf := pr.ibuf[:] 26 | for { 27 | n, err := pr.rd.Read(buf) 28 | if err != nil { 29 | panic(err) 30 | } 31 | buf = buf[n:] 32 | if len(buf) == 0 { 33 | break 34 | } 35 | } 36 | pr.remain = int(DecodeU24(pr.ibuf[:])) 37 | seq, err := pr.rd.ReadByte() 38 | if err != nil { 39 | panic(err) 40 | } 41 | // Chceck sequence number 42 | if *pr.seq != seq { 43 | panic(mysql.ErrSeq) 44 | } 45 | *pr.seq++ 46 | // Last packet? 
47 | pr.last = (pr.remain != 0xffffff) 48 | } 49 | 50 | func (pr *pktReader) readFull(buf []byte) { 51 | for len(buf) > 0 { 52 | if pr.remain == 0 { 53 | if pr.last { 54 | // No more packets 55 | panic(io.EOF) 56 | } 57 | pr.readHeader() 58 | } 59 | n := len(buf) 60 | if n > pr.remain { 61 | n = pr.remain 62 | } 63 | n, err := pr.rd.Read(buf[:n]) 64 | pr.remain -= n 65 | if err != nil { 66 | panic(err) 67 | } 68 | buf = buf[n:] 69 | } 70 | return 71 | } 72 | 73 | func (pr *pktReader) readByte() byte { 74 | if pr.remain == 0 { 75 | if pr.last { 76 | // No more packets 77 | panic(io.EOF) 78 | } 79 | pr.readHeader() 80 | } 81 | b, err := pr.rd.ReadByte() 82 | if err != nil { 83 | panic(err) 84 | } 85 | pr.remain-- 86 | return b 87 | } 88 | 89 | func (pr *pktReader) readAll() (buf []byte) { 90 | m := 0 91 | for { 92 | if pr.remain == 0 { 93 | if pr.last { 94 | break 95 | } 96 | pr.readHeader() 97 | } 98 | new_buf := make([]byte, m+pr.remain) 99 | copy(new_buf, buf) 100 | buf = new_buf 101 | n, err := pr.rd.Read(buf[m:]) 102 | pr.remain -= n 103 | m += n 104 | if err != nil { 105 | panic(err) 106 | } 107 | } 108 | return 109 | } 110 | 111 | func (pr *pktReader) skipAll() { 112 | for { 113 | if pr.remain == 0 { 114 | if pr.last { 115 | break 116 | } 117 | pr.readHeader() 118 | } 119 | n, err := io.CopyN(ioutil.Discard, pr.rd, int64(pr.remain)) 120 | pr.remain -= int(n) 121 | if err != nil { 122 | panic(err) 123 | } 124 | } 125 | return 126 | } 127 | 128 | func (pr *pktReader) skipN(n int) { 129 | for n > 0 { 130 | if pr.remain == 0 { 131 | if pr.last { 132 | panic(io.EOF) 133 | } 134 | pr.readHeader() 135 | } 136 | m := int64(n) 137 | if n > pr.remain { 138 | m = int64(pr.remain) 139 | } 140 | m, err := io.CopyN(ioutil.Discard, pr.rd, m) 141 | pr.remain -= int(m) 142 | n -= int(m) 143 | if err != nil { 144 | panic(err) 145 | } 146 | } 147 | return 148 | } 149 | 150 | func (pr *pktReader) unreadByte() { 151 | if err := pr.rd.UnreadByte(); err != nil { 152 | panic(err) 153 
| } 154 | pr.remain++ 155 | } 156 | 157 | func (pr *pktReader) eof() bool { 158 | return pr.remain == 0 && pr.last 159 | } 160 | 161 | func (pr *pktReader) checkEof() { 162 | if !pr.eof() { 163 | panic(mysql.ErrPktLong) 164 | } 165 | } 166 | 167 | type pktWriter struct { 168 | wr *bufio.Writer 169 | seq *byte 170 | remain int 171 | to_write int 172 | last bool 173 | buf [23]byte 174 | ibuf [3]byte 175 | } 176 | 177 | func (my *Conn) newPktWriter(to_write int) *pktWriter { 178 | return &pktWriter{wr: my.wr, seq: &my.seq, to_write: to_write} 179 | } 180 | 181 | func (pw *pktWriter) writeHeader(l int) { 182 | buf := pw.ibuf[:] 183 | EncodeU24(buf, uint32(l)) 184 | if _, err := pw.wr.Write(buf); err != nil { 185 | panic(err) 186 | } 187 | if err := pw.wr.WriteByte(*pw.seq); err != nil { 188 | panic(err) 189 | } 190 | // Update sequence number 191 | *pw.seq++ 192 | } 193 | 194 | func (pw *pktWriter) write(buf []byte) { 195 | if len(buf) == 0 { 196 | return 197 | } 198 | var nn int 199 | for len(buf) != 0 { 200 | if pw.remain == 0 { 201 | if pw.to_write == 0 { 202 | panic("too many data for write as packet") 203 | } 204 | if pw.to_write >= 0xffffff { 205 | pw.remain = 0xffffff 206 | } else { 207 | pw.remain = pw.to_write 208 | pw.last = true 209 | } 210 | pw.to_write -= pw.remain 211 | pw.writeHeader(pw.remain) 212 | } 213 | nn = len(buf) 214 | if nn > pw.remain { 215 | nn = pw.remain 216 | } 217 | var err error 218 | nn, err = pw.wr.Write(buf[0:nn]) 219 | pw.remain -= nn 220 | if err != nil { 221 | panic(err) 222 | } 223 | buf = buf[nn:] 224 | } 225 | if pw.remain+pw.to_write == 0 { 226 | if !pw.last { 227 | // Write header for empty packet 228 | pw.writeHeader(0) 229 | } 230 | // Flush bufio buffers 231 | if err := pw.wr.Flush(); err != nil { 232 | panic(err) 233 | } 234 | } 235 | return 236 | } 237 | 238 | func (pw *pktWriter) writeByte(b byte) { 239 | pw.buf[0] = b 240 | pw.write(pw.buf[:1]) 241 | } 242 | 243 | // n should be <= 23 244 | func (pw *pktWriter) 
writeZeros(n int) { 245 | buf := pw.buf[:n] 246 | for i := range buf { 247 | buf[i] = 0 248 | } 249 | pw.write(buf) 250 | } 251 | -------------------------------------------------------------------------------- /native/paramvalue.go: -------------------------------------------------------------------------------- 1 | package native 2 | 3 | import ( 4 | "github.com/ziutek/mymysql/mysql" 5 | "math" 6 | "reflect" 7 | "time" 8 | ) 9 | 10 | type paramValue struct { 11 | typ uint16 12 | addr reflect.Value 13 | raw bool 14 | length int // >=0 - length of value, <0 - unknown length 15 | } 16 | 17 | func (val *paramValue) Len() int { 18 | if !val.addr.IsValid() { 19 | // Invalid Value was binded 20 | return 0 21 | } 22 | // val.addr always points to the pointer - lets dereference it 23 | v := val.addr.Elem() 24 | if v.IsNil() { 25 | // Binded Ptr Value is nil 26 | return 0 27 | } 28 | v = v.Elem() 29 | 30 | if val.length >= 0 { 31 | return val.length 32 | } 33 | 34 | switch val.typ { 35 | case MYSQL_TYPE_STRING: 36 | return lenStr(v.String()) 37 | 38 | case MYSQL_TYPE_DATE: 39 | return lenDate(v.Interface().(mysql.Date)) 40 | 41 | case MYSQL_TYPE_TIMESTAMP: 42 | return lenTime(v.Interface().(mysql.Timestamp).Time) 43 | case MYSQL_TYPE_DATETIME: 44 | return lenTime(v.Interface().(time.Time)) 45 | 46 | case MYSQL_TYPE_TIME: 47 | return lenDuration(v.Interface().(time.Duration)) 48 | 49 | case MYSQL_TYPE_TINY: // val.length < 0 so this is bool 50 | return 1 51 | } 52 | // MYSQL_TYPE_VAR_STRING, MYSQL_TYPE_BLOB and type of Raw value 53 | return lenBin(v.Bytes()) 54 | } 55 | 56 | func (pw *pktWriter) writeValue(val *paramValue) { 57 | if !val.addr.IsValid() { 58 | // Invalid Value was binded 59 | return 60 | } 61 | // val.addr always points to the pointer - lets dereference it 62 | v := val.addr.Elem() 63 | if v.IsNil() { 64 | // Binded Ptr Value is nil 65 | return 66 | } 67 | v = v.Elem() 68 | 69 | if val.raw || val.typ == MYSQL_TYPE_VAR_STRING || 70 | val.typ == 
MYSQL_TYPE_BLOB { 71 | pw.writeBin(v.Bytes()) 72 | return 73 | } 74 | // We don't need unsigned bit to check type 75 | unsign := (val.typ & MYSQL_UNSIGNED_MASK) != 0 76 | switch val.typ & ^MYSQL_UNSIGNED_MASK { 77 | case MYSQL_TYPE_NULL: 78 | // Don't write null values 79 | 80 | case MYSQL_TYPE_STRING: 81 | pw.writeBin([]byte(v.String())) 82 | 83 | case MYSQL_TYPE_LONG: 84 | i := v.Interface() 85 | if unsign { 86 | l, ok := i.(uint32) 87 | if !ok { 88 | l = uint32(i.(uint)) 89 | } 90 | pw.writeU32(l) 91 | } else { 92 | l, ok := i.(int32) 93 | if !ok { 94 | l = int32(i.(int)) 95 | } 96 | pw.writeU32(uint32(l)) 97 | } 98 | 99 | case MYSQL_TYPE_FLOAT: 100 | pw.writeU32(math.Float32bits(v.Interface().(float32))) 101 | 102 | case MYSQL_TYPE_SHORT: 103 | if unsign { 104 | pw.writeU16(v.Interface().(uint16)) 105 | } else { 106 | pw.writeU16(uint16(v.Interface().(int16))) 107 | 108 | } 109 | 110 | case MYSQL_TYPE_TINY: 111 | if val.length == -1 { 112 | // Translate bool value to MySQL tiny 113 | if v.Bool() { 114 | pw.writeByte(1) 115 | } else { 116 | pw.writeByte(0) 117 | } 118 | } else { 119 | if unsign { 120 | pw.writeByte(v.Interface().(uint8)) 121 | } else { 122 | pw.writeByte(uint8(v.Interface().(int8))) 123 | } 124 | } 125 | 126 | case MYSQL_TYPE_LONGLONG: 127 | i := v.Interface() 128 | if unsign { 129 | l, ok := i.(uint64) 130 | if !ok { 131 | l = uint64(i.(uint)) 132 | } 133 | pw.writeU64(l) 134 | } else { 135 | l, ok := i.(int64) 136 | if !ok { 137 | l = int64(i.(int)) 138 | } 139 | pw.writeU64(uint64(l)) 140 | } 141 | 142 | case MYSQL_TYPE_DOUBLE: 143 | pw.writeU64(math.Float64bits(v.Interface().(float64))) 144 | 145 | case MYSQL_TYPE_DATE: 146 | pw.writeDate(v.Interface().(mysql.Date)) 147 | 148 | case MYSQL_TYPE_TIMESTAMP: 149 | pw.writeTime(v.Interface().(mysql.Timestamp).Time) 150 | 151 | case MYSQL_TYPE_DATETIME: 152 | pw.writeTime(v.Interface().(time.Time)) 153 | 154 | case MYSQL_TYPE_TIME: 155 | pw.writeDuration(v.Interface().(time.Duration)) 156 | 157 | 
default: 158 | panic(mysql.ErrBindUnkType) 159 | } 160 | return 161 | } 162 | 163 | // encodes a uint64 value and appends it to the given bytes slice 164 | func appendLengthEncodedInteger(b []byte, n uint64) []byte { 165 | switch { 166 | case n <= 250: 167 | return append(b, byte(n)) 168 | 169 | case n <= 0xffff: 170 | return append(b, 0xfc, byte(n), byte(n>>8)) 171 | 172 | case n <= 0xffffff: 173 | return append(b, 0xfd, byte(n), byte(n>>8), byte(n>>16)) 174 | } 175 | return append(b, 0xfe, byte(n), byte(n>>8), byte(n>>16), byte(n>>24), 176 | byte(n>>32), byte(n>>40), byte(n>>48), byte(n>>56)) 177 | } -------------------------------------------------------------------------------- /native/passwd.go: -------------------------------------------------------------------------------- 1 | package native 2 | 3 | import ( 4 | "crypto/sha1" 5 | "crypto/sha256" 6 | "math" 7 | ) 8 | 9 | // Borrowed from GoMySQL 10 | // SHA1(SHA1(SHA1(password)), scramble) XOR SHA1(password) 11 | func encryptedPasswd(password string, scramble []byte) (out []byte) { 12 | if len(password) == 0 { 13 | return 14 | } 15 | // stage1_hash = SHA1(password) 16 | // SHA1 encode 17 | crypt := sha1.New() 18 | crypt.Write([]byte(password)) 19 | stg1Hash := crypt.Sum(nil) 20 | // token = SHA1(SHA1(stage1_hash), scramble) XOR stage1_hash 21 | // SHA1 encode again 22 | crypt.Reset() 23 | crypt.Write(stg1Hash) 24 | stg2Hash := crypt.Sum(nil) 25 | // SHA1 2nd hash and scramble 26 | crypt.Reset() 27 | crypt.Write(scramble) 28 | crypt.Write(stg2Hash) 29 | stg3Hash := crypt.Sum(nil) 30 | // XOR with first hash 31 | out = make([]byte, len(scramble)) 32 | for ii := range scramble { 33 | out[ii] = stg3Hash[ii] ^ stg1Hash[ii] 34 | } 35 | return 36 | } 37 | 38 | // Hash password using MySQL 8+ method (SHA256) 39 | func encryptedSHA256Passwd(password string, scramble []byte) []byte { 40 | if len(password) == 0 { 41 | return nil 42 | } 43 | 44 | // XOR(SHA256(password), SHA256(SHA256(SHA256(password)), scramble)) 45 | 
46 | crypt := sha256.New() 47 | crypt.Write([]byte(password)) 48 | message1 := crypt.Sum(nil) 49 | 50 | crypt.Reset() 51 | crypt.Write(message1) 52 | message1Hash := crypt.Sum(nil) 53 | 54 | crypt.Reset() 55 | crypt.Write(message1Hash) 56 | crypt.Write(scramble) 57 | message2 := crypt.Sum(nil) 58 | 59 | for i := range message1 { 60 | message1[i] ^= message2[i] 61 | } 62 | 63 | return message1 64 | } 65 | 66 | // Old password handling based on translating to Go some functions from 67 | // libmysql 68 | 69 | // The main idea is that no password are sent between client & server on 70 | // connection and that no password are saved in mysql in a decodable form. 71 | // 72 | // On connection a random string is generated and sent to the client. 73 | // The client generates a new string with a random generator inited with 74 | // the hash values from the password and the sent string. 75 | // This 'check' string is sent to the server where it is compared with 76 | // a string generated from the stored hash_value of the password and the 77 | // random string. 
78 | 79 | // libmysql/my_rnd.c 80 | type myRnd struct { 81 | seed1, seed2 uint32 82 | } 83 | 84 | const myRndMaxVal = 0x3FFFFFFF 85 | 86 | func newMyRnd(seed1, seed2 uint32) *myRnd { 87 | r := new(myRnd) 88 | r.seed1 = seed1 % myRndMaxVal 89 | r.seed2 = seed2 % myRndMaxVal 90 | return r 91 | } 92 | 93 | func (r *myRnd) Float64() float64 { 94 | r.seed1 = (r.seed1*3 + r.seed2) % myRndMaxVal 95 | r.seed2 = (r.seed1 + r.seed2 + 33) % myRndMaxVal 96 | return float64(r.seed1) / myRndMaxVal 97 | } 98 | 99 | // libmysql/password.c 100 | func pwHash(password []byte) (result [2]uint32) { 101 | var nr, add, nr2, tmp uint32 102 | nr, add, nr2 = 1345345333, 7, 0x12345671 103 | 104 | for _, c := range password { 105 | if c == ' ' || c == '\t' { 106 | continue // skip space in password 107 | } 108 | 109 | tmp = uint32(c) 110 | nr ^= (((nr & 63) + add) * tmp) + (nr << 8) 111 | nr2 += (nr2 << 8) ^ nr 112 | add += tmp 113 | } 114 | 115 | result[0] = nr & ((1 << 31) - 1) // Don't use sign bit (str2int) 116 | result[1] = nr2 & ((1 << 31) - 1) 117 | return 118 | } 119 | 120 | func encryptedOldPassword(password string, scramble []byte) []byte { 121 | if len(password) == 0 { 122 | return nil 123 | } 124 | scramble = scramble[:8] 125 | hashPw := pwHash([]byte(password)) 126 | hashSc := pwHash(scramble) 127 | r := newMyRnd(hashPw[0]^hashSc[0], hashPw[1]^hashSc[1]) 128 | var out [8]byte 129 | for i := range out { 130 | out[i] = byte(math.Floor(r.Float64()*31) + 64) 131 | } 132 | extra := byte(math.Floor(r.Float64() * 31)) 133 | for i := range out { 134 | out[i] ^= extra 135 | } 136 | return out[:] 137 | } 138 | -------------------------------------------------------------------------------- /native/prepared.go: -------------------------------------------------------------------------------- 1 | package native 2 | 3 | import ( 4 | "github.com/ziutek/mymysql/mysql" 5 | "log" 6 | ) 7 | 8 | type Stmt struct { 9 | my *Conn 10 | 11 | id uint32 12 | sql string // For reprepare during reconnect 13 
| 14 | params []paramValue // Parameters binding 15 | rebind bool 16 | binded bool 17 | 18 | fields []*mysql.Field 19 | 20 | field_count int 21 | param_count int 22 | warning_count int 23 | status mysql.ConnStatus 24 | 25 | null_bitmap []byte 26 | } 27 | 28 | func (stmt *Stmt) Fields() []*mysql.Field { 29 | return stmt.fields 30 | } 31 | 32 | func (stmt *Stmt) NumParam() int { 33 | return stmt.param_count 34 | } 35 | 36 | func (stmt *Stmt) WarnCount() int { 37 | return stmt.warning_count 38 | } 39 | 40 | func (stmt *Stmt) sendCmdExec() { 41 | // Calculate packet length and NULL bitmap 42 | pkt_len := 1 + 4 + 1 + 4 + 1 + len(stmt.null_bitmap) 43 | for ii := range stmt.null_bitmap { 44 | stmt.null_bitmap[ii] = 0 45 | } 46 | for ii, param := range stmt.params { 47 | par_len := param.Len() 48 | pkt_len += par_len 49 | if par_len == 0 { 50 | null_byte := ii >> 3 51 | null_mask := byte(1) << uint(ii-(null_byte<<3)) 52 | stmt.null_bitmap[null_byte] |= null_mask 53 | } 54 | } 55 | if stmt.rebind { 56 | pkt_len += stmt.param_count * 2 57 | } 58 | // Reset sequence number 59 | stmt.my.seq = 0 60 | // Packet sending 61 | pw := stmt.my.newPktWriter(pkt_len) 62 | pw.writeByte(_COM_STMT_EXECUTE) 63 | pw.writeU32(stmt.id) 64 | pw.writeByte(0) // flags = CURSOR_TYPE_NO_CURSOR 65 | pw.writeU32(1) // iteration_count 66 | pw.write(stmt.null_bitmap) 67 | if stmt.rebind { 68 | pw.writeByte(1) 69 | // Types 70 | for _, param := range stmt.params { 71 | pw.writeU16(param.typ) 72 | } 73 | } else { 74 | pw.writeByte(0) 75 | } 76 | // Values 77 | for i := range stmt.params { 78 | pw.writeValue(&stmt.params[i]) 79 | } 80 | 81 | if stmt.my.Debug { 82 | log.Printf("[%2d <-] Exec command packet: len=%d, null_bitmap=%v, rebind=%t", 83 | stmt.my.seq-1, pkt_len, stmt.null_bitmap, stmt.rebind) 84 | } 85 | 86 | // Mark that we sended information about binded types 87 | stmt.rebind = false 88 | } 89 | 90 | func (my *Conn) getPrepareResult(stmt *Stmt) interface{} { 91 | loop: 92 | pr := 
my.newPktReader() // New reader for next packet 93 | pkt0 := pr.readByte() 94 | 95 | //log.Println("pkt0:", pkt0, "stmt:", stmt) 96 | 97 | if pkt0 == 255 { 98 | // Error packet 99 | my.getErrorPacket(pr) 100 | } 101 | 102 | if stmt == nil { 103 | if pkt0 == 0 { 104 | // OK packet 105 | return my.getPrepareOkPacket(pr) 106 | } 107 | } else { 108 | unreaded_params := (stmt.param_count < len(stmt.params)) 109 | switch { 110 | case pkt0 == 254: 111 | // EOF packet 112 | stmt.warning_count, stmt.status = my.getEofPacket(pr) 113 | stmt.my.status = stmt.status 114 | return stmt 115 | 116 | case pkt0 > 0 && pkt0 < 251 && (stmt.field_count < len(stmt.fields) || 117 | unreaded_params): 118 | // Field packet 119 | if unreaded_params { 120 | // Read and ignore parameter field. Sentence from MySQL source: 121 | /* skip parameters data: we don't support it yet */ 122 | pr.skipAll() 123 | // Increment param_count count 124 | stmt.param_count++ 125 | } else { 126 | field := my.getFieldPacket(pr) 127 | stmt.fields[stmt.field_count] = field 128 | // Increment field count 129 | stmt.field_count++ 130 | } 131 | // Read next packet 132 | goto loop 133 | } 134 | } 135 | panic(mysql.ErrUnkResultPkt) 136 | } 137 | 138 | func (my *Conn) getPrepareOkPacket(pr *pktReader) (stmt *Stmt) { 139 | if my.Debug { 140 | log.Printf("[%2d ->] Perpared OK packet:", my.seq-1) 141 | } 142 | 143 | stmt = new(Stmt) 144 | stmt.my = my 145 | // First byte was readed by getPrepRes 146 | stmt.id = pr.readU32() 147 | stmt.fields = make([]*mysql.Field, int(pr.readU16())) // FieldCount 148 | pl := int(pr.readU16()) // ParamCount 149 | if pl > 0 { 150 | stmt.params = make([]paramValue, pl) 151 | stmt.null_bitmap = make([]byte, (pl+7)>>3) 152 | } 153 | pr.skipN(1) 154 | stmt.warning_count = int(pr.readU16()) 155 | pr.checkEof() 156 | 157 | if my.Debug { 158 | log.Printf(tab8s+"ID=0x%x ParamCount=%d FieldsCount=%d WarnCount=%d", 159 | stmt.id, len(stmt.params), len(stmt.fields), stmt.warning_count, 160 | ) 161 | } 
162 | return 163 | } 164 | -------------------------------------------------------------------------------- /native/result.go: -------------------------------------------------------------------------------- 1 | package native 2 | 3 | import ( 4 | "bytes" 5 | "errors" 6 | "github.com/ziutek/mymysql/mysql" 7 | "log" 8 | "math" 9 | "strconv" 10 | ) 11 | 12 | type Result struct { 13 | my *Conn 14 | status_only bool // true if result doesn't contain result set 15 | binary bool // Binary result expected 16 | 17 | field_count int 18 | fields []*mysql.Field // Fields table 19 | fc_map map[string]int // Maps field name to column number 20 | 21 | message []byte 22 | affected_rows uint64 23 | 24 | // Primary key value (useful for AUTO_INCREMENT primary keys) 25 | insert_id uint64 26 | 27 | // Number of warinigs during command execution 28 | // You can use the SHOW WARNINGS query for details. 29 | warning_count int 30 | 31 | // MySQL server status immediately after the query execution 32 | status mysql.ConnStatus 33 | 34 | // Seted by GetRow if it returns nil row 35 | eor_returned bool 36 | } 37 | 38 | // StatusOnly returns true if this is status result that includes no result set 39 | func (res *Result) StatusOnly() bool { 40 | return res.status_only 41 | } 42 | 43 | // Fields returns a table containing descriptions of the columns 44 | func (res *Result) Fields() []*mysql.Field { 45 | return res.fields 46 | } 47 | 48 | // Map returns index for given name or -1 if field of that name doesn't exist 49 | func (res *Result) Map(field_name string) int { 50 | if fi, ok := res.fc_map[field_name]; ok { 51 | return fi 52 | } 53 | return -1 54 | } 55 | 56 | func (res *Result) Message() string { 57 | return string(res.message) 58 | } 59 | 60 | func (res *Result) AffectedRows() uint64 { 61 | return res.affected_rows 62 | } 63 | 64 | func (res *Result) InsertId() uint64 { 65 | return res.insert_id 66 | } 67 | 68 | func (res *Result) WarnCount() int { 69 | return res.warning_count 70 | } 
71 | 72 | func (res *Result) MakeRow() mysql.Row { 73 | return make(mysql.Row, res.field_count) 74 | } 75 | 76 | // getAuthResult After sending login request 77 | // use this func get server return packet 78 | func (my *Conn) getAuthResult() ([]byte, string) { 79 | pr := my.newPktReader() 80 | pkt := pr.readAll() 81 | pkt0 := pkt[0] 82 | 83 | // packet indicator 84 | switch pkt0 { 85 | case 0: // OK 86 | return nil, "" 87 | 88 | case 1: // AuthMoreData 89 | return pkt[1:], "" 90 | 91 | case 254: // EOF 92 | if len(pkt) == 1 { 93 | // https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::OldAuthSwitchRequest 94 | return nil, "mysql_old_password" 95 | } 96 | pluginEndIndex := bytes.IndexByte(pkt, 0x00) 97 | if pluginEndIndex < 0 { 98 | return nil, "" 99 | } 100 | plugin := string(pkt[1:pluginEndIndex]) 101 | authData := pkt[pluginEndIndex+1:] 102 | return authData, plugin 103 | 104 | case 255: // Error packet 105 | panic(mysql.ErrAuthentication) 106 | return nil, "" 107 | 108 | default: // Error otherwise 109 | panic(mysql.ErrUnkResultPkt) 110 | return nil, "" 111 | } 112 | } 113 | 114 | func (my *Conn) getResult(res *Result, row mysql.Row) *Result { 115 | loop: 116 | pr := my.newPktReader() // New reader for next packet 117 | pkt0 := pr.readByte() 118 | 119 | if pkt0 == 255 { 120 | // Error packet 121 | my.getErrorPacket(pr) 122 | } 123 | 124 | if res == nil { 125 | switch { 126 | case pkt0 == 0: 127 | // OK packet 128 | return my.getOkPacket(pr) 129 | 130 | case pkt0 > 0 && pkt0 < 251: 131 | // Result set header packet 132 | res = my.getResSetHeadPacket(pr) 133 | // Read next packet 134 | goto loop 135 | case pkt0 == 251: 136 | // Load infile response 137 | // Handle response 138 | goto loop 139 | case pkt0 == 254: 140 | // EOF packet (without body) 141 | return nil 142 | } 143 | } else { 144 | switch { 145 | case pkt0 == 254: 146 | // EOF packet 147 | res.warning_count, res.status = my.getEofPacket(pr) 148 | my.status = res.status 
149 | return res 150 | 151 | case pkt0 > 0 && pkt0 < 251 && res.field_count < len(res.fields): 152 | // Field packet 153 | field := my.getFieldPacket(pr) 154 | res.fields[res.field_count] = field 155 | res.fc_map[field.Name] = res.field_count 156 | // Increment field count 157 | res.field_count++ 158 | // Read next packet 159 | goto loop 160 | 161 | case pkt0 < 254 && res.field_count == len(res.fields): 162 | // Row Data Packet 163 | if len(row) != res.field_count { 164 | panic(mysql.ErrRowLength) 165 | } 166 | if res.binary { 167 | my.getBinRowPacket(pr, res, row) 168 | } else { 169 | my.getTextRowPacket(pr, res, row) 170 | } 171 | return nil 172 | } 173 | } 174 | panic(mysql.ErrUnkResultPkt) 175 | } 176 | 177 | func (my *Conn) getOkPacket(pr *pktReader) (res *Result) { 178 | if my.Debug { 179 | log.Printf("[%2d ->] OK packet:", my.seq-1) 180 | } 181 | res = new(Result) 182 | res.status_only = true 183 | res.my = my 184 | // First byte was readed by getResult 185 | res.affected_rows = pr.readLCB() 186 | res.insert_id = pr.readLCB() 187 | res.status = mysql.ConnStatus(pr.readU16()) 188 | my.status = res.status 189 | res.warning_count = int(pr.readU16()) 190 | res.message = pr.readAll() 191 | pr.checkEof() 192 | 193 | if my.Debug { 194 | log.Printf(tab8s+"AffectedRows=%d InsertId=0x%x Status=0x%x "+ 195 | "WarningCount=%d Message=\"%s\"", res.affected_rows, res.insert_id, 196 | res.status, res.warning_count, res.message, 197 | ) 198 | } 199 | return 200 | } 201 | 202 | func (my *Conn) getErrorPacket(pr *pktReader) { 203 | if my.Debug { 204 | log.Printf("[%2d ->] Error packet:", my.seq-1) 205 | } 206 | var err mysql.Error 207 | err.Code = pr.readU16() 208 | if pr.readByte() != '#' { 209 | panic(mysql.ErrPkt) 210 | } 211 | pr.skipN(5) 212 | err.Msg = pr.readAll() 213 | pr.checkEof() 214 | 215 | if my.Debug { 216 | log.Printf(tab8s+"code=0x%x msg=\"%s\"", err.Code, err.Msg) 217 | } 218 | panic(&err) 219 | } 220 | 221 | func (my *Conn) getEofPacket(pr *pktReader) 
(warn_count int, status mysql.ConnStatus) { 222 | if my.Debug { 223 | if pr.eof() { 224 | log.Printf("[%2d ->] EOF packet without body", my.seq-1) 225 | } else { 226 | log.Printf("[%2d ->] EOF packet:", my.seq-1) 227 | } 228 | } 229 | if pr.eof() { 230 | return 231 | } 232 | warn_count = int(pr.readU16()) 233 | if pr.eof() { 234 | return 235 | } 236 | status = mysql.ConnStatus(pr.readU16()) 237 | pr.checkEof() 238 | 239 | if my.Debug { 240 | log.Printf(tab8s+"WarningCount=%d Status=0x%x", warn_count, status) 241 | } 242 | return 243 | } 244 | 245 | func (my *Conn) getResSetHeadPacket(pr *pktReader) (res *Result) { 246 | if my.Debug { 247 | log.Printf("[%2d ->] Result set header packet:", my.seq-1) 248 | } 249 | pr.unreadByte() 250 | 251 | field_count := int(pr.readLCB()) 252 | pr.checkEof() 253 | 254 | res = &Result{ 255 | my: my, 256 | fields: make([]*mysql.Field, field_count), 257 | fc_map: make(map[string]int), 258 | } 259 | 260 | if my.Debug { 261 | log.Printf(tab8s+"FieldCount=%d", field_count) 262 | } 263 | return 264 | } 265 | 266 | func (my *Conn) getFieldPacket(pr *pktReader) (field *mysql.Field) { 267 | if my.Debug { 268 | log.Printf("[%2d ->] Field packet:", my.seq-1) 269 | } 270 | pr.unreadByte() 271 | 272 | field = new(mysql.Field) 273 | if my.fullFieldInfo { 274 | field.Catalog = string(pr.readBin()) 275 | field.Db = string(pr.readBin()) 276 | field.Table = string(pr.readBin()) 277 | field.OrgTable = string(pr.readBin()) 278 | } else { 279 | pr.skipBin() 280 | pr.skipBin() 281 | pr.skipBin() 282 | pr.skipBin() 283 | } 284 | field.Name = string(pr.readBin()) 285 | if my.fullFieldInfo { 286 | field.OrgName = string(pr.readBin()) 287 | } else { 288 | pr.skipBin() 289 | } 290 | pr.skipN(1 + 2) 291 | //field.Charset= pr.readU16() 292 | field.DispLen = pr.readU32() 293 | field.Type = pr.readByte() 294 | field.Flags = pr.readU16() 295 | field.Scale = pr.readByte() 296 | pr.skipN(2) 297 | pr.checkEof() 298 | 299 | if my.Debug { 300 | 
log.Printf(tab8s+"Name=\"%s\" Type=0x%x", field.Name, field.Type) 301 | } 302 | return 303 | } 304 | 305 | func (my *Conn) getTextRowPacket(pr *pktReader, res *Result, row mysql.Row) { 306 | if my.Debug { 307 | log.Printf("[%2d ->] Text row data packet", my.seq-1) 308 | } 309 | pr.unreadByte() 310 | 311 | for ii := 0; ii < res.field_count; ii++ { 312 | bin, null := pr.readNullBin() 313 | if null { 314 | row[ii] = nil 315 | } else { 316 | row[ii] = bin 317 | } 318 | } 319 | pr.checkEof() 320 | } 321 | 322 | func (my *Conn) getBinRowPacket(pr *pktReader, res *Result, row mysql.Row) { 323 | if my.Debug { 324 | log.Printf("[%2d ->] Binary row data packet", my.seq-1) 325 | } 326 | // First byte was readed by getResult 327 | 328 | null_bitmap := make([]byte, (res.field_count+7+2)>>3) 329 | pr.readFull(null_bitmap) 330 | 331 | for ii, field := range res.fields { 332 | null_byte := (ii + 2) >> 3 333 | null_mask := byte(1) << uint(2+ii-(null_byte<<3)) 334 | if null_bitmap[null_byte]&null_mask != 0 { 335 | // Null field 336 | row[ii] = nil 337 | continue 338 | } 339 | unsigned := (field.Flags & _FLAG_UNSIGNED) != 0 340 | if my.narrowTypeSet { 341 | row[ii] = readValueNarrow(pr, field.Type, unsigned) 342 | } else { 343 | row[ii] = readValue(pr, field.Type, unsigned) 344 | } 345 | } 346 | } 347 | 348 | func readValue(pr *pktReader, typ byte, unsigned bool) interface{} { 349 | switch typ { 350 | case MYSQL_TYPE_STRING, MYSQL_TYPE_VAR_STRING, MYSQL_TYPE_VARCHAR, 351 | MYSQL_TYPE_BIT, MYSQL_TYPE_BLOB, MYSQL_TYPE_TINY_BLOB, 352 | MYSQL_TYPE_MEDIUM_BLOB, MYSQL_TYPE_LONG_BLOB, MYSQL_TYPE_SET, 353 | MYSQL_TYPE_ENUM, MYSQL_TYPE_GEOMETRY: 354 | return pr.readBin() 355 | case MYSQL_TYPE_TINY: 356 | if unsigned { 357 | return pr.readByte() 358 | } else { 359 | return int8(pr.readByte()) 360 | } 361 | case MYSQL_TYPE_SHORT, MYSQL_TYPE_YEAR: 362 | if unsigned { 363 | return pr.readU16() 364 | } else { 365 | return int16(pr.readU16()) 366 | } 367 | case MYSQL_TYPE_LONG, MYSQL_TYPE_INT24: 
368 | if unsigned { 369 | return pr.readU32() 370 | } else { 371 | return int32(pr.readU32()) 372 | } 373 | case MYSQL_TYPE_LONGLONG: 374 | if unsigned { 375 | return pr.readU64() 376 | } else { 377 | return int64(pr.readU64()) 378 | } 379 | case MYSQL_TYPE_FLOAT: 380 | return math.Float32frombits(pr.readU32()) 381 | case MYSQL_TYPE_DOUBLE: 382 | return math.Float64frombits(pr.readU64()) 383 | case MYSQL_TYPE_DECIMAL, MYSQL_TYPE_NEWDECIMAL: 384 | dec := string(pr.readBin()) 385 | r, err := strconv.ParseFloat(dec, 64) 386 | if err != nil { 387 | panic(errors.New("MySQL server returned wrong decimal value: " + dec)) 388 | } 389 | return r 390 | case MYSQL_TYPE_DATE, MYSQL_TYPE_NEWDATE: 391 | return pr.readDate() 392 | case MYSQL_TYPE_DATETIME, MYSQL_TYPE_TIMESTAMP: 393 | return pr.readTime() 394 | case MYSQL_TYPE_TIME: 395 | return pr.readDuration() 396 | } 397 | panic(mysql.ErrUnkMySQLType) 398 | } 399 | 400 | func readValueNarrow(pr *pktReader, typ byte, unsigned bool) interface{} { 401 | switch typ { 402 | case MYSQL_TYPE_STRING, MYSQL_TYPE_VAR_STRING, MYSQL_TYPE_VARCHAR, 403 | MYSQL_TYPE_BIT, MYSQL_TYPE_BLOB, MYSQL_TYPE_TINY_BLOB, 404 | MYSQL_TYPE_MEDIUM_BLOB, MYSQL_TYPE_LONG_BLOB, MYSQL_TYPE_SET, 405 | MYSQL_TYPE_ENUM, MYSQL_TYPE_GEOMETRY: 406 | return pr.readBin() 407 | case MYSQL_TYPE_TINY: 408 | if unsigned { 409 | return int64(pr.readByte()) 410 | } 411 | return int64(int8(pr.readByte())) 412 | case MYSQL_TYPE_SHORT, MYSQL_TYPE_YEAR: 413 | if unsigned { 414 | return int64(pr.readU16()) 415 | } 416 | return int64(int16(pr.readU16())) 417 | case MYSQL_TYPE_LONG, MYSQL_TYPE_INT24: 418 | if unsigned { 419 | return int64(pr.readU32()) 420 | } 421 | return int64(int32(pr.readU32())) 422 | case MYSQL_TYPE_LONGLONG: 423 | v := pr.readU64() 424 | if unsigned && v > math.MaxInt64 { 425 | panic(errors.New("Value to large for int64 type")) 426 | } 427 | return int64(v) 428 | case MYSQL_TYPE_FLOAT: 429 | return float64(math.Float32frombits(pr.readU32())) 430 | case 
MYSQL_TYPE_DOUBLE: 431 | return math.Float64frombits(pr.readU64()) 432 | case MYSQL_TYPE_DECIMAL, MYSQL_TYPE_NEWDECIMAL: 433 | dec := string(pr.readBin()) 434 | r, err := strconv.ParseFloat(dec, 64) 435 | if err != nil { 436 | panic("MySQL server returned wrong decimal value: " + dec) 437 | } 438 | return r 439 | case MYSQL_TYPE_DATETIME, MYSQL_TYPE_TIMESTAMP, MYSQL_TYPE_DATE, MYSQL_TYPE_NEWDATE: 440 | return pr.readTime() 441 | case MYSQL_TYPE_TIME: 442 | return int64(pr.readDuration()) 443 | } 444 | panic(mysql.ErrUnkMySQLType) 445 | } 446 | -------------------------------------------------------------------------------- /native/unsafe.go-disabled: -------------------------------------------------------------------------------- 1 | package native 2 | 3 | import ( 4 | "github.com/ziutek/mymysql/mysql" 5 | "time" 6 | "unsafe" 7 | ) 8 | 9 | type paramValue struct { 10 | typ uint16 11 | addr unsafe.Pointer 12 | raw bool 13 | length int // >=0 - length of value, <0 - unknown length 14 | } 15 | 16 | func (pv *paramValue) SetAddr(addr uintptr) { 17 | pv.addr = unsafe.Pointer(addr) 18 | } 19 | 20 | func (val *paramValue) Len() int { 21 | if val.addr == nil { 22 | // Invalid Value was binded 23 | return 0 24 | } 25 | // val.addr always points to the pointer - lets dereference it 26 | ptr := *(*unsafe.Pointer)(val.addr) 27 | if ptr == nil { 28 | // Binded Ptr Value is nil 29 | return 0 30 | } 31 | 32 | if val.length >= 0 { 33 | return val.length 34 | } 35 | 36 | switch val.typ { 37 | case MYSQL_TYPE_STRING: 38 | return lenStr(*(*string)(ptr)) 39 | 40 | case MYSQL_TYPE_DATE: 41 | return lenDate(*(*mysql.Date)(ptr)) 42 | 43 | case MYSQL_TYPE_TIMESTAMP, MYSQL_TYPE_DATETIME: 44 | return lenTime(*(*time.Time)(ptr)) 45 | 46 | case MYSQL_TYPE_TIME: 47 | return lenDuration(*(*time.Duration)(ptr)) 48 | 49 | case MYSQL_TYPE_TINY: // val.length < 0 so this is bool 50 | return 1 51 | } 52 | // MYSQL_TYPE_VAR_STRING, MYSQL_TYPE_BLOB and type of Raw value 53 | return 
lenBin(*(*[]byte)(ptr)) 54 | } 55 | 56 | func (pw *pktWriter) writeValue(val *paramValue) { 57 | if val.addr == nil { 58 | // Invalid Value was binded 59 | return 60 | } 61 | // val.addr always points to the pointer - lets dereference it 62 | ptr := *(*unsafe.Pointer)(val.addr) 63 | if ptr == nil { 64 | // Binded Ptr Value is nil 65 | return 66 | } 67 | 68 | if val.raw || val.typ == MYSQL_TYPE_VAR_STRING || 69 | val.typ == MYSQL_TYPE_BLOB { 70 | pw.writeBin(*(*[]byte)(ptr)) 71 | return 72 | } 73 | // We don't need unsigned bit to check type 74 | switch val.typ & ^MYSQL_UNSIGNED_MASK { 75 | case MYSQL_TYPE_NULL: 76 | // Don't write null values 77 | 78 | case MYSQL_TYPE_STRING: 79 | s := *(*string)(ptr) 80 | pw.writeBin([]byte(s)) 81 | 82 | case MYSQL_TYPE_LONG, MYSQL_TYPE_FLOAT: 83 | pw.writeU32(*(*uint32)(ptr)) 84 | 85 | case MYSQL_TYPE_SHORT: 86 | pw.writeU16(*(*uint16)(ptr)) 87 | 88 | case MYSQL_TYPE_TINY: 89 | if val.length == -1 { 90 | // Translate bool value to MySQL tiny 91 | if *(*bool)(ptr) { 92 | pw.writeByte(1) 93 | } else { 94 | pw.writeByte(0) 95 | } 96 | } else { 97 | pw.writeByte(*(*byte)(ptr)) 98 | } 99 | 100 | case MYSQL_TYPE_LONGLONG, MYSQL_TYPE_DOUBLE: 101 | pw.writeU64(*(*uint64)(ptr)) 102 | 103 | case MYSQL_TYPE_DATE: 104 | pw.writeDate(*(*mysql.Date)(ptr)) 105 | 106 | case MYSQL_TYPE_TIMESTAMP, MYSQL_TYPE_DATETIME: 107 | pw.writeTime(*(*time.Time)(ptr)) 108 | 109 | case MYSQL_TYPE_TIME: 110 | pw.writeDuration(*(*time.Duration)(ptr)) 111 | 112 | default: 113 | panic(mysql.ErrBindUnkType) 114 | } 115 | return 116 | } 117 | -------------------------------------------------------------------------------- /thrsafe/LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2010, Michal Derkacz 2 | All rights reserved. 3 | 4 | Redistribution and use in source and binary forms, with or without 5 | modification, are permitted provided that the following conditions 6 | are met: 7 | 1. 
Redistributions of source code must retain the above copyright 8 | notice, this list of conditions and the following disclaimer. 9 | 2. Redistributions in binary form must reproduce the above copyright 10 | notice, this list of conditions and the following disclaimer in the 11 | documentation and/or other materials provided with the distribution. 12 | 3. The name of the author may not be used to endorse or promote products 13 | derived from this software without specific prior written permission. 14 | 15 | THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR 16 | IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES 17 | OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 18 | IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT, 19 | INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT 20 | NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 21 | DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 22 | THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 23 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF 24 | THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 25 | -------------------------------------------------------------------------------- /thrsafe/thrsafe.go: -------------------------------------------------------------------------------- 1 | // Package thrsafe is a thread safe engine for MyMySQL. 2 | // 3 | // In contrast to the native engine: 4 | // 5 | // - one connection can be used by multiple goroutines, 6 | // 7 | // - if the connection is idle, pings are sent to the server (once per minute) to avoid a timeout. 8 | // 9 | // See the documentation of mymysql/native for details.
10 | package thrsafe 11 | 12 | import ( 13 | "io" 14 | "sync" 15 | "time" 16 | 17 | "github.com/ziutek/mymysql/mysql" 18 | _ "github.com/ziutek/mymysql/native" 19 | ) 20 | 21 | // Conn is a thread safe connection type. 22 | type Conn struct { 23 | mysql.Conn 24 | mutex *sync.Mutex 25 | 26 | stopPinger chan struct{} 27 | lastUsed time.Time 28 | } 29 | 30 | func (c *Conn) lock() { 31 | //log.Println(c, ":: lock @", c.mutex) 32 | c.mutex.Lock() 33 | } 34 | 35 | func (c *Conn) unlock() { 36 | //log.Println(c, ":: unlock @", c.mutex) 37 | c.lastUsed = time.Now() 38 | c.mutex.Unlock() 39 | } 40 | 41 | // Result is a thread safe result type. 42 | type Result struct { 43 | mysql.Result 44 | conn *Conn 45 | } 46 | 47 | // Stmt is a thread safe statement type. 48 | type Stmt struct { 49 | mysql.Stmt 50 | conn *Conn 51 | } 52 | 53 | // Transaction is a thread safe transaction type. 54 | type Transaction struct { 55 | *Conn 56 | conn *Conn 57 | } 58 | 59 | // New creates a new thread safe connection. 
60 | func New(proto, laddr, raddr, user, passwd string, db ...string) mysql.Conn { 61 | return &Conn{ 62 | Conn: orgNew(proto, laddr, raddr, user, passwd, db...), 63 | mutex: new(sync.Mutex), 64 | } 65 | } 66 | 67 | func (c *Conn) Clone() mysql.Conn { 68 | return &Conn{ 69 | Conn: c.Conn.Clone(), 70 | mutex: new(sync.Mutex), 71 | } 72 | } 73 | 74 | func (c *Conn) pinger() { 75 | const to = 60 * time.Second 76 | sleep := to 77 | for { 78 | timer := time.After(sleep) 79 | select { 80 | case <-c.stopPinger: 81 | return 82 | case t := <-timer: 83 | c.mutex.Lock() 84 | lastUsed := c.lastUsed 85 | c.mutex.Unlock() 86 | sleep := to - t.Sub(lastUsed) 87 | if sleep <= 0 { 88 | if c.Ping() != nil { 89 | return 90 | } 91 | sleep = to 92 | } 93 | } 94 | } 95 | } 96 | 97 | func (c *Conn) Connect() error { 98 | //log.Println("Connect") 99 | c.lock() 100 | defer c.unlock() 101 | c.stopPinger = make(chan struct{}) 102 | go c.pinger() 103 | return c.Conn.Connect() 104 | } 105 | 106 | // Close closes the connection. 107 | func (c *Conn) Close() error { 108 | //log.Println("Close") 109 | close(c.stopPinger) // Stop pinger before lock connection 110 | c.lock() 111 | defer c.unlock() 112 | return c.Conn.Close() 113 | } 114 | 115 | func (c *Conn) Reconnect() error { 116 | //log.Println("Reconnect") 117 | c.lock() 118 | defer c.unlock() 119 | if c.stopPinger == nil { 120 | go c.pinger() 121 | } 122 | return c.Conn.Reconnect() 123 | } 124 | 125 | func (c *Conn) Use(dbname string) error { 126 | //log.Println("Use") 127 | c.lock() 128 | defer c.unlock() 129 | return c.Conn.Use(dbname) 130 | } 131 | 132 | func (c *Conn) Start(sql string, params ...interface{}) (mysql.Result, error) { 133 | //log.Println("Start") 134 | c.lock() 135 | res, err := c.Conn.Start(sql, params...) 
136 | // Unlock if error or OK result (which doesn't provide any fields) 137 | if err != nil { 138 | c.unlock() 139 | return nil, err 140 | } 141 | if res.StatusOnly() && !res.MoreResults() { 142 | c.unlock() 143 | } 144 | return &Result{Result: res, conn: c}, err 145 | } 146 | 147 | func (c *Conn) Status() mysql.ConnStatus { 148 | c.lock() 149 | defer c.unlock() 150 | return c.Conn.Status() 151 | } 152 | 153 | func (c *Conn) Escape(txt string) string { 154 | return mysql.Escape(c, txt) 155 | } 156 | 157 | func (res *Result) ScanRow(row mysql.Row) error { 158 | //log.Println("ScanRow") 159 | err := res.Result.ScanRow(row) 160 | if err == nil { 161 | // There are more rows to read 162 | return nil 163 | } 164 | if err == mysql.ErrReadAfterEOR { 165 | // Trying read after EOR - connection unlocked before 166 | return err 167 | } 168 | if err != io.EOF || !res.StatusOnly() && !res.MoreResults() { 169 | // Error or no more rows in not empty result set and no more resutls. 170 | // In case if empty result set and no more resutls Start has unlocked 171 | // it before. 
172 | res.conn.unlock() 173 | } 174 | return err 175 | } 176 | 177 | func (res *Result) GetRow() (mysql.Row, error) { 178 | return mysql.GetRow(res) 179 | } 180 | 181 | func (res *Result) NextResult() (mysql.Result, error) { 182 | //log.Println("NextResult") 183 | next, err := res.Result.NextResult() 184 | if err != nil { 185 | return nil, err 186 | } 187 | if next == nil { 188 | return nil, nil 189 | } 190 | if next.StatusOnly() && !next.MoreResults() { 191 | res.conn.unlock() 192 | } 193 | return &Result{next, res.conn}, nil 194 | } 195 | 196 | func (c *Conn) Ping() error { 197 | c.lock() 198 | defer c.unlock() 199 | return c.Conn.Ping() 200 | } 201 | 202 | func (c *Conn) Prepare(sql string) (mysql.Stmt, error) { 203 | //log.Println("Prepare") 204 | c.lock() 205 | defer c.unlock() 206 | stmt, err := c.Conn.Prepare(sql) 207 | if err != nil { 208 | return nil, err 209 | } 210 | return &Stmt{Stmt: stmt, conn: c}, nil 211 | } 212 | 213 | func (stmt *Stmt) Run(params ...interface{}) (mysql.Result, error) { 214 | //log.Println("Run") 215 | stmt.conn.lock() 216 | res, err := stmt.Stmt.Run(params...) 
217 | // Unlock if error or OK result (which doesn't provide any fields) 218 | if err != nil { 219 | stmt.conn.unlock() 220 | return nil, err 221 | } 222 | if res.StatusOnly() && !res.MoreResults() { 223 | stmt.conn.unlock() 224 | } 225 | return &Result{Result: res, conn: stmt.conn}, nil 226 | } 227 | 228 | func (stmt *Stmt) Delete() error { 229 | //log.Println("Delete") 230 | stmt.conn.lock() 231 | defer stmt.conn.unlock() 232 | return stmt.Stmt.Delete() 233 | } 234 | 235 | func (stmt *Stmt) Reset() error { 236 | //log.Println("Reset") 237 | stmt.conn.lock() 238 | defer stmt.conn.unlock() 239 | return stmt.Stmt.Reset() 240 | } 241 | 242 | func (stmt *Stmt) SendLongData(pnum int, data interface{}, pkt_size int) error { 243 | //log.Println("SendLongData") 244 | stmt.conn.lock() 245 | defer stmt.conn.unlock() 246 | return stmt.Stmt.SendLongData(pnum, data, pkt_size) 247 | } 248 | 249 | // Query: See mysql.Query 250 | func (c *Conn) Query(sql string, params ...interface{}) ([]mysql.Row, mysql.Result, error) { 251 | return mysql.Query(c, sql, params...) 252 | } 253 | 254 | // QueryFirst: See mysql.QueryFirst 255 | func (my *Conn) QueryFirst(sql string, params ...interface{}) (mysql.Row, mysql.Result, error) { 256 | return mysql.QueryFirst(my, sql, params...) 257 | } 258 | 259 | // QueryLast: See mysql.QueryLast 260 | func (my *Conn) QueryLast(sql string, params ...interface{}) (mysql.Row, mysql.Result, error) { 261 | return mysql.QueryLast(my, sql, params...) 262 | } 263 | 264 | // Exec: See mysql.Exec 265 | func (stmt *Stmt) Exec(params ...interface{}) ([]mysql.Row, mysql.Result, error) { 266 | return mysql.Exec(stmt, params...) 267 | } 268 | 269 | // ExecFirst: See mysql.ExecFirst 270 | func (stmt *Stmt) ExecFirst(params ...interface{}) (mysql.Row, mysql.Result, error) { 271 | return mysql.ExecFirst(stmt, params...) 
272 | } 273 | 274 | // ExecLast: See mysql.ExecLast 275 | func (stmt *Stmt) ExecLast(params ...interface{}) (mysql.Row, mysql.Result, error) { 276 | return mysql.ExecLast(stmt, params...) 277 | } 278 | 279 | // End: See mysql.End 280 | func (res *Result) End() error { 281 | return mysql.End(res) 282 | } 283 | 284 | // GetFirstRow: See mysql.GetFirstRow 285 | func (res *Result) GetFirstRow() (mysql.Row, error) { 286 | return mysql.GetFirstRow(res) 287 | } 288 | 289 | // GetLastRow: See mysql.GetLastRow 290 | func (res *Result) GetLastRow() (mysql.Row, error) { 291 | return mysql.GetLastRow(res) 292 | } 293 | 294 | // GetRows: See mysql.GetRows 295 | func (res *Result) GetRows() ([]mysql.Row, error) { 296 | return mysql.GetRows(res) 297 | } 298 | 299 | // Begins a new transaction. No any other thread can send command on this 300 | // connection until Commit or Rollback will be called. 301 | // Periodical pinging the server is disabled during transaction. 302 | 303 | func (c *Conn) Begin() (mysql.Transaction, error) { 304 | //log.Println("Begin") 305 | c.lock() 306 | tr := Transaction{ 307 | &Conn{Conn: c.Conn, mutex: new(sync.Mutex)}, 308 | c, 309 | } 310 | _, err := c.Conn.Start("START TRANSACTION") 311 | if err != nil { 312 | c.unlock() 313 | return nil, err 314 | } 315 | return &tr, nil 316 | } 317 | 318 | func (tr *Transaction) end(cr string) error { 319 | tr.lock() 320 | _, err := tr.conn.Conn.Start(cr) 321 | tr.conn.unlock() 322 | // Invalidate this transaction 323 | m := tr.Conn.mutex 324 | tr.Conn = nil 325 | tr.conn = nil 326 | m.Unlock() // One goorutine which still uses this transaction will panic 327 | return err 328 | } 329 | 330 | func (tr *Transaction) Commit() error { 331 | //log.Println("Commit") 332 | return tr.end("COMMIT") 333 | } 334 | 335 | func (tr *Transaction) Rollback() error { 336 | //log.Println("Rollback") 337 | return tr.end("ROLLBACK") 338 | } 339 | 340 | func (tr *Transaction) IsValid() bool { 341 | return tr.Conn != nil 342 | } 343 | 
344 | func (tr *Transaction) Do(st mysql.Stmt) mysql.Stmt { 345 | if s, ok := st.(*Stmt); ok && s.conn == tr.conn { 346 | // Returns new statement which uses statement mutexes 347 | return &Stmt{s.Stmt, tr.Conn} 348 | } 349 | panic("Transaction and statement doesn't belong to the same connection") 350 | } 351 | 352 | var orgNew func(proto, laddr, raddr, user, passwd string, db ...string) mysql.Conn 353 | 354 | func init() { 355 | orgNew = mysql.New 356 | mysql.New = New 357 | } 358 | -------------------------------------------------------------------------------- /thrsafe/thrsafe_test.go: -------------------------------------------------------------------------------- 1 | package thrsafe 2 | 3 | import ( 4 | "github.com/ziutek/mymysql/mysql" 5 | "github.com/ziutek/mymysql/native" 6 | "testing" 7 | ) 8 | 9 | const ( 10 | user = "testuser" 11 | passwd = "TestPasswd9" 12 | dbname = "test" 13 | proto = "tcp" 14 | daddr = "127.0.0.1:3306" 15 | //proto = "unix" 16 | //daddr = "/var/run/mysqld/mysqld.sock" 17 | debug = false 18 | ) 19 | 20 | var db mysql.Conn 21 | 22 | func checkErr(t *testing.T, err error) { 23 | if err != nil { 24 | t.Fatalf("Error: %v", err) 25 | } 26 | } 27 | 28 | func connect(t *testing.T) mysql.Conn { 29 | db := New(proto, "", daddr, user, passwd, dbname) 30 | db.(*Conn).Conn.(*native.Conn).Debug = debug 31 | checkErr(t, db.Connect()) 32 | return db 33 | } 34 | 35 | func TestS(t *testing.T) { 36 | db := connect(t) 37 | res, err := db.Start("SET @a=1") 38 | checkErr(t, err) 39 | if !res.StatusOnly() { 40 | t.Fatalf("'SET @a' statement returns result with rows") 41 | } 42 | err = db.Close() 43 | checkErr(t, err) 44 | } 45 | 46 | func TestSS(t *testing.T) { 47 | db := connect(t) 48 | 49 | res, err := db.Start("SET @a=1; SET @b=2") 50 | checkErr(t, err) 51 | if !res.StatusOnly() { 52 | t.Fatalf("'SET @a' statement returns result with rows") 53 | } 54 | 55 | res, err = res.NextResult() 56 | checkErr(t, err) 57 | if !res.StatusOnly() { 58 | t.Fatalf("'SET 
@b' statement returns result with rows") 59 | } 60 | 61 | err = db.Close() 62 | checkErr(t, err) 63 | } 64 | 65 | func TestSDS(t *testing.T) { 66 | db := connect(t) 67 | 68 | res, err := db.Start("SET @a=1; SELECT @a; SET @b=2") 69 | checkErr(t, err) 70 | if !res.StatusOnly() { 71 | t.Fatalf("'SET @a' statement returns result with rows") 72 | } 73 | 74 | res, err = res.NextResult() 75 | checkErr(t, err) 76 | rows, err := res.GetRows() 77 | checkErr(t, err) 78 | if rows[0].Int(0) != 1 { 79 | t.Fatalf("First query doesn't return '1'") 80 | } 81 | 82 | res, err = res.NextResult() 83 | checkErr(t, err) 84 | if !res.StatusOnly() { 85 | t.Fatalf("'SET @b' statement returns result with rows") 86 | } 87 | 88 | err = db.Close() 89 | checkErr(t, err) 90 | } 91 | 92 | func TestSSDDD(t *testing.T) { 93 | db := connect(t) 94 | 95 | res, err := db.Start("SET @a=1; SET @b=2; SELECT @a; SELECT @b; SELECT 3") 96 | checkErr(t, err) 97 | if !res.StatusOnly() { 98 | t.Fatalf("'SET @a' statement returns result with rows") 99 | } 100 | 101 | res, err = res.NextResult() 102 | checkErr(t, err) 103 | if !res.StatusOnly() { 104 | t.Fatalf("'SET @b' statement returns result with rows") 105 | } 106 | 107 | res, err = res.NextResult() 108 | checkErr(t, err) 109 | rows, err := res.GetRows() 110 | checkErr(t, err) 111 | if rows[0].Int(0) != 1 { 112 | t.Fatalf("First query doesn't return '1'") 113 | } 114 | 115 | res, err = res.NextResult() 116 | checkErr(t, err) 117 | rows, err = res.GetRows() 118 | checkErr(t, err) 119 | if rows[0].Int(0) != 2 { 120 | t.Fatalf("Second query doesn't return '2'") 121 | } 122 | 123 | res, err = res.NextResult() 124 | checkErr(t, err) 125 | rows, err = res.GetRows() 126 | checkErr(t, err) 127 | if rows[0].Int(0) != 3 { 128 | t.Fatalf("Thrid query doesn't return '3'") 129 | } 130 | if res.MoreResults() { 131 | t.Fatalf("There is unexpected one more result") 132 | } 133 | 134 | err = db.Close() 135 | checkErr(t, err) 136 | } 137 | 
--------------------------------------------------------------------------------