├── .gitignore ├── .goreleaser.yml ├── LICENSE ├── README.md ├── client.go ├── docker-compose.yml ├── go.mod ├── go.sum ├── main.go ├── scanner.go └── scanner_test.go /.gitignore: -------------------------------------------------------------------------------- 1 | /dist/ 2 | /mysqldump-loader 3 | -------------------------------------------------------------------------------- /.goreleaser.yml: -------------------------------------------------------------------------------- 1 | builds: 2 | - ldflags: 3 | - -extldflags -static -s -w 4 | env: 5 | - CGO_ENABLED=0 6 | signs: 7 | - artifacts: checksum 8 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Satoshi Matsumoto 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # mysqldump-loader
2 |
3 | Load a MySQL dump file using LOAD DATA INFILE.
4 |
5 | ## Usage
6 |
7 |     ./mysqldump-loader [flags]
8 |
9 | The dump is read from the file given by `-dump-file`, or from standard input if that flag is omitted.
10 | If `-data-source-name` is omitted, the data source name is taken from the `DATA_SOURCE_NAME` environment variable.
11 |
12 | ### Flags
13 |
14 |     $ ./mysqldump-loader --help
15 |     Usage of ./mysqldump-loader:
16 |       -concurrency int
17 |             Maximum number of concurrent load operations. (default Number of available CPUs)
18 |       -data-source-name string
19 |             Data source name for MySQL server to load data into.
20 |       -dump-file string
21 |             MySQL dump file to load.
22 |       -low-priority
23 |             Use LOW_PRIORITY when loading data.
24 |       -mysql-variable value
25 |             MySQL variable (format: <name>=<value>)
26 |       -replace-table
27 |             Load data into a temporary table and replace the old table with it once load is complete.
28 |       -verbose
29 |             Verbose mode.
30 |
31 | ## Development
32 |
33 | ### Building
34 |
35 |     go build
36 |
--------------------------------------------------------------------------------
/client.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | 	"bytes"
5 | 	"context"
6 | 	"database/sql"
7 | 	"fmt"
8 | )
9 |
10 | // client runs statements on a single dedicated connection, so session
11 | // settings applied through it (SET NAMES, SESSION variables) stay in effect
12 | // for every later statement it executes.
13 | type client struct {
14 | 	conn *sql.Conn
15 | }
16 |
17 | // writeTableName appends the backquoted table name to query, qualified with
18 | // the backquoted database name when database is not empty.
19 | func writeTableName(query *bytes.Buffer, database, table string) {
20 | 	if database != "" {
21 | 		query.Write(quoteName(database))
22 | 		query.WriteByte('.')
23 | 	}
24 | 	query.Write(quoteName(table))
25 | }
26 |
27 | // addForeignKeys re-adds the given constraint definitions, verbatim as they
28 | // appeared in CREATE TABLE, with a single ALTER TABLE statement.
29 | func (c *client) addForeignKeys(ctx context.Context, database, table string, foreignKeys []string) error {
30 | 	var query bytes.Buffer
31 |
32 | 	query.WriteString("ALTER TABLE ")
33 | 	writeTableName(&query, database, table)
34 |
35 | 	for i, fk := range foreignKeys {
36 | 		if i != 0 {
37 | 			query.WriteByte(',')
38 | 		}
39 | 		query.WriteString(" ADD ")
40 | 		query.WriteString(fk)
41 | 	}
42 |
43 | 	return c.exec(ctx, query.String())
44 | }
45 |
46 | // createTable creates a table whose definition (column list and table
47 | // options) is given verbatim by body.
48 | func (c *client) createTable(ctx context.Context, database, table, body string) error {
49 | 	var query bytes.Buffer
50 |
51 | 	query.WriteString("CREATE TABLE ")
52 | 	writeTableName(&query, database, table)
53 | 	query.WriteByte(' ')
54 | 	query.WriteString(body)
55 |
56 | 	return c.exec(ctx, query.String())
57 | }
58 |
59 | // dropTableIfExists drops the table if it exists.
60 | func (c *client) dropTableIfExists(ctx context.Context, database, table string) error {
61 | 	var query bytes.Buffer
62 |
63 | 	query.WriteString("DROP TABLE IF EXISTS ")
64 | 	writeTableName(&query, database, table)
65 |
66 | 	return c.exec(ctx, query.String())
67 | }
68 |
69 | // renameTable renames oldTable to newTable within the same database.
70 | func (c *client) renameTable(ctx context.Context, database, oldTable, newTable string) error {
71 | 	var query bytes.Buffer
72 |
73 | 	query.WriteString("RENAME TABLE ")
74 | 	writeTableName(&query, database, oldTable)
75 | 	query.WriteString(" TO ")
76 | 	writeTableName(&query, database, newTable)
77 |
78 | 	return c.exec(ctx, query.String())
79 | }
80 |
81 | // setCharacterSet issues SET NAMES for the connection. charset is inserted
82 | // verbatim, so it must be a valid character set name.
83 | func (c *client) setCharacterSet(ctx context.Context, charset string) error {
84 | 	return c.exec(ctx, fmt.Sprintf("SET NAMES %s", charset))
85 | }
86 |
87 | // setVariables sets every name/value pair as a SESSION variable in a single
88 | // SET statement. Names and values are inserted verbatim, so values must be
89 | // valid SQL expressions. It does nothing when variables is empty, instead of
90 | // sending the invalid statement "SET ".
91 | func (c *client) setVariables(ctx context.Context, variables map[string]string) error {
92 | 	if len(variables) == 0 {
93 | 		return nil
94 | 	}
95 |
96 | 	var query bytes.Buffer
97 |
98 | 	query.WriteString("SET ")
99 | 	first := true
100 | 	for name, value := range variables {
101 | 		if !first {
102 | 			query.WriteString(", ")
103 | 		}
104 | 		first = false
105 | 		query.WriteString("SESSION ")
106 | 		query.WriteString(name)
107 | 		query.WriteString(" = ")
108 | 		query.WriteString(value)
109 | 	}
110 |
111 | 	return c.exec(ctx, query.String())
112 | }
113 |
114 | // exec runs a statement whose result is not needed.
115 | func (c *client) exec(ctx context.Context, query string) error {
116 | 	_, err := c.conn.ExecContext(ctx, query)
117 | 	return err
118 | }
119 |
120 | // close returns the connection to the pool.
121 | func (c *client) close() error {
122 | 	return c.conn.Close()
123 | }
124 |
--------------------------------------------------------------------------------
/docker-compose.yml:
--------------------------------------------------------------------------------
1 | version: '3'
2 | services:
3 |   mysql56:
4 |     image: mysql:5.6
5 |     ports:
6 |       - 3306
7 |     environment:
8 |       MYSQL_ROOT_PASSWORD: password
9 |   mysql57:
10 |     image: mysql:5.7
11 |     ports:
12 |       - 3306
13 |     environment:
14 |       MYSQL_ROOT_PASSWORD: password
15 |   mysql80:
16 |     image: mysql:8.0
17 |     ports:
18 |       - 3306
19 |     environment:
20 |       MYSQL_ROOT_PASSWORD: password
21 |
--------------------------------------------------------------------------------
/go.mod:
--------------------------------------------------------------------------------
1 | module github.com/kaorimatz/mysqldump-loader
2 |
3 | go 1.14
4 |
5 | require github.com/go-sql-driver/mysql v1.5.0
6 |
--------------------------------------------------------------------------------
/go.sum:
--------------------------------------------------------------------------------
1 | github.com/go-sql-driver/mysql v1.5.0 h1:ozyZYNQW3x3HtqT1jira07DN2PArx2v7/mN66gGcHOs=
2 | github.com/go-sql-driver/mysql v1.5.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg=
3 |
--------------------------------------------------------------------------------
/main.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | 	"bytes"
5 | 	"context"
6 | 	"database/sql"
7 | 	"encoding/hex"
8 | 	"errors"
9 | 	"flag"
10 | 	"fmt"
11 | 	"io"
12 | 	"log"
13 | 	"os"
14 | 	"runtime"
15 | 	"strconv"
16 | 	"strings"
17 | 	"sync"
18 |
19 | 	"github.com/go-sql-driver/mysql"
20 | )
21 |
22 | var (
23 | 	concurrency    = flag.Int("concurrency", 0, "Maximum number of concurrent load operations.")
24 | 	dataSourceName = flag.String("data-source-name", "", "Data source name for MySQL server to load data into.")
25 | 	dumpFile       = flag.String("dump-file", "", "MySQL dump file to load.")
26 | 	lowPriority    = flag.Bool("low-priority", false, "Use LOW_PRIORITY when loading data.")
27 | 	replaceTable   = flag.Bool("replace-table", false, "Load data into a temporary table and replace the old table with it once load is complete.")
28 | 	verbose        = flag.Bool("verbose", false, "Verbose mode.")
29 | 	mysqlVariables = make(mysqlVariableValue)
30 | )
31 |
32 | // mysqlVariableValue collects the name=value pairs given with the repeatable
33 | // -mysql-variable flag. It implements flag.Value.
34 | type mysqlVariableValue map[string]string
35 |
36 | // String formats the variables as comma-separated name=value pairs.
37 | func (v *mysqlVariableValue) String() string {
38 | 	var buf bytes.Buffer
39 | 	for name, value := range *v {
40 | 		if buf.Len() != 0 {
41 | 			buf.WriteByte(',')
42 | 		}
43 | 		buf.WriteString(name)
44 | 		buf.WriteString("=")
45 | 		buf.WriteString(value)
46 | 	}
47 | 	return buf.String()
48 | }
49 |
50 | // Set records one name=value pair; a later pair with the same name replaces
51 | // an earlier one.
52 | func (v *mysqlVariableValue) Set(value string) error {
53 | 	parts := strings.SplitN(value, "=", 2)
54 | 	if len(parts) != 2 {
55 | 		return errors.New("must be a name=value pair")
56 | 	}
57 | 	(*v)[parts[0]] = parts[1]
58 | 	return nil
59 | }
60 |
61 | func init() {
62 | 	flag.Var(&mysqlVariables, "mysql-variable", "MySQL variable (format: <name>=<value>)")
63 | 	flag.Lookup("concurrency").DefValue = "Number of available CPUs"
64 | }
65 |
66 | func main() {
67 | 	flag.Parse()
68 |
69 | 	if *concurrency == 0 {
70 | 		*concurrency = runtime.NumCPU()
71 | 	}
72 |
73 | 	if *dataSourceName == "" {
74 | 		*dataSourceName = os.Getenv("DATA_SOURCE_NAME")
75 | 	}
76 |
77 | 	// Set once, before any connection is opened: clientFactory is called
78 | 	// concurrently by the loader's workers, and writing to this shared map
79 | 	// from several goroutines is a data race.
80 | 	mysqlVariables["foreign_key_checks"] = "0"
81 |
82 | 	db, err := sql.Open("mysql", *dataSourceName)
83 | 	if err != nil {
84 | 		log.Fatal(err)
85 | 	}
86 |
87 | 	r := os.Stdin
88 | 	if *dumpFile != "" {
89 | 		if r, err = os.Open(*dumpFile); err != nil {
90 | 			log.Fatal(err)
91 | 		}
92 | 	}
93 |
94 | 	// clientFactory opens a dedicated connection and applies the session
95 | 	// variables to it, releasing the connection again if that fails.
96 | 	clientFactory := func(ctx context.Context) (*client, error) {
97 | 		conn, err := db.Conn(ctx)
98 | 		if err != nil {
99 | 			return nil, fmt.Errorf("error getting database connection: %w", err)
100 | 		}
101 |
102 | 		c := &client{conn: conn}
103 | 		if err := c.setVariables(ctx, mysqlVariables); err != nil {
104 | 			_ = c.close() // best effort; the setVariables error is the one to report
105 | 			return nil, fmt.Errorf("error setting MySQL variables: %w", err)
106 | 		}
107 |
108 | 		return c, nil
109 | 	}
110 |
111 | 	mainClient, err := clientFactory(context.Background())
112 | 	if err != nil {
113 | 		log.Fatal(err)
114 | 	}
115 | 	defer mainClient.close()
116 |
117 | 	var rep *replacer
118 | 	if *replaceTable {
119 | 		rep = newReplacer(mainClient)
120 | 	}
121 |
122 | 	e := &executor{
123 | 		client:   mainClient,
124 | 		loader:   newLoader(clientFactory, *concurrency, *lowPriority),
125 | 		replacer: rep,
126 | 		scanner:  newScanner(r),
127 | 	}
128 | 	if err := e.execute(); err != nil {
129 | 		log.Fatal(err)
130 | 	}
131 | }
132 |
133 | // executor replays the statements of a dump in order. INSERT and REPLACE
134 | // statements are converted to LOAD DATA and handed to loader; other
135 | // statements run on client. When replacer is set, each CREATE TABLE creates
136 | // a temporary table that replaces the original once its data is loaded.
137 | type executor struct {
138 | 	client   *client
139 | 	loader   *loader
140 | 	replacer *replacer
141 | 	scanner  *scanner
142 | }
143 |
144 | // execute processes the whole dump, then waits for every queued load and
145 | // table replacement to finish.
146 | func (e *executor) execute() error {
147 | 	var current *table
148 | 	var charset, database string
149 | 	var err error
150 |
151 | 	e.loader.start()
152 |
153 | 	for e.scanner.scan() {
154 | 		q := e.scanner.query()
155 | 		switch {
156 | 		case e.replacer != nil && q.isDropTableStatement():
157 | 			// Skipped: the replacer drops the old table once its
158 | 			// replacement has been loaded.
159 | 		case e.replacer != nil && q.isCreateTableStatement():
160 | 			// Every statement for the previous table has been queued by
161 | 			// now; its replacement waits until they have all been loaded.
162 | 			if current != nil {
163 | 				if err := e.replacer.execute(context.Background(), database, current); err != nil {
164 | 					return err
165 | 				}
166 | 			}
167 |
168 | 			current, err = parseCreateTableStatement(q)
169 | 			if err != nil {
170 | 				return fmt.Errorf("error parsing CREATE TABLE statement on line %d: %w", q.line, err)
171 | 			}
172 |
173 | 			if *verbose {
174 | 				log.Printf("Creating new table %s...", quoteName(current.name))
175 | 			}
176 |
177 | 			if err := e.client.createTable(context.Background(), database, current.name, current.body); err != nil {
178 | 				return fmt.Errorf("error creating table %s: %w", quoteName(current.name), err)
179 | 			}
180 | 		case q.isAlterTableStatement() || q.isLockTablesStatement() || q.isUnlockTablesStatement():
181 | 			// Skipped: the data is loaded through the loader's separate
182 | 			// connections, which locks taken here would only block.
183 | 		case q.isInsertStatement() || q.isReplaceStatement():
184 | 			if err := e.loader.execute(context.Background(), q, charset, database, current); err != nil {
185 | 				return err
186 | 			}
187 | 		default:
188 | 			if err := e.client.exec(context.Background(), q.s); err != nil {
189 | 				return fmt.Errorf("error executing query on line %d: %w", q.line, err)
190 | 			}
191 | 			// Track the character set and default database so that the
192 | 			// loader's connections use the same ones.
193 | 			if q.isSetNamesStatement() {
194 | 				if charset, err = parseSetNamesStatement(q); err != nil {
195 | 					return fmt.Errorf("error parsing SET NAMES statement on line %d: %w", q.line, err)
196 | 				}
197 | 			}
198 | 			if q.isUseStatement() {
199 | 				if database, err = parseUseStatement(q); err != nil {
200 | 					return fmt.Errorf("error parsing USE statement on line %d: %w", q.line, err)
201 | 				}
202 | 			}
203 | 		}
204 | 	}
205 |
206 | 	// Check for a read error before scheduling the last replacement, so that
207 | 	// a truncated dump never replaces a table with partially loaded data.
208 | 	if err := e.scanner.err(); err != nil {
209 | 		return fmt.Errorf("error reading dump file: %w", err)
210 | 	}
211 |
212 | 	if e.replacer != nil && current != nil {
213 | 		if err := e.replacer.execute(context.Background(), database, current); err != nil {
214 | 			return err
215 | 		}
216 | 	}
217 |
218 | 	if err := e.loader.wait(); err != nil {
219 | 		return err
220 | 	}
221 |
222 | 	if e.replacer != nil {
223 | 		if err := e.replacer.wait(); err != nil {
224 | 			return err
225 | 		}
226 | 	}
227 |
228 | 	return nil
229 | }
230 |
231 | // parseCreateTableStatement rewrites a CREATE TABLE statement written by
232 | // mysqldump so that it creates a temporary table named "_" + name + "_tmp".
233 | // CONSTRAINT definitions are left out of the new body and returned in
234 | // foreignKeys, to be added once the temporary table replaces the original.
235 | func parseCreateTableStatement(q *query) (*table, error) {
236 | 	var buf bytes.Buffer
237 | 	var foreignKeys []string
238 |
239 | 	origName, i, err := parseIdentifier(q.s, len("CREATE TABLE "), " ")
240 | 	if err != nil {
241 | 		return nil, fmt.Errorf("error parsing table name: %w", err)
242 | 	}
243 | 	i++
244 |
245 | 	if !strings.HasPrefix(q.s[i:], "(\n") {
246 | 		return nil, errors.New("unsupported CREATE TABLE statement")
247 | 	}
248 | 	i += 2
249 |
250 | 	name := "_" + origName + "_tmp"
251 |
252 | 	buf.WriteString("(\n")
253 | 	ts := &tableScanner{s: q.s[i:]}
254 | 	for ts.scan() {
255 | 		d := ts.definition()
256 | 		if isConstraintClause(d) {
257 | 			foreignKeys = append(foreignKeys, d)
258 | 			continue
259 | 		}
260 | 		if buf.Len() != len("(\n") {
261 | 			buf.WriteString(",\n")
262 | 		}
263 | 		buf.WriteString("  ")
264 | 		buf.WriteString(d)
265 | 	}
266 | 	if err := ts.err(); err != nil {
267 | 		return nil, fmt.Errorf("error parsing table definition: %w", err)
268 | 	}
269 | 	i += ts.pos()
270 |
271 | 	buf.WriteByte('\n')
272 | 	buf.WriteString(q.s[i:])
273 |
274 | 	return &table{body: buf.String(), foreignKeys: foreignKeys, name: name, origName: origName}, nil
275 | }
276 |
277 | // table is a table being loaded into a temporary table for the replacer.
278 | // wg counts the loads queued for it that have not completed successfully.
279 | type table struct {
280 | 	body        string
281 | 	foreignKeys []string
282 | 	name        string
283 | 	origName    string
284 | 	wg          sync.WaitGroup
285 | }
286 |
287 | // tableScanner splits the body of a CREATE TABLE statement formatted by
288 | // mysqldump into definitions: one per line, indented by two spaces and
289 | // separated by commas. Newlines inside quotes do not end a definition.
290 | type tableScanner struct {
291 | 	d             string
292 | 	e             error
293 | 	p             int
294 | 	quote         byte
295 | 	s             string
296 | 	stringLiteral bool
297 | }
298 |
299 | // scan advances to the next definition. It returns false at the first line
300 | // that is not indented by two spaces, or with err set when a definition is
301 | // not terminated by a newline.
302 | func (s *tableScanner) scan() bool {
303 | 	i := s.p
304 |
305 | 	if !strings.HasPrefix(s.s[i:], "  ") {
306 | 		return false
307 | 	}
308 | 	i += 2
309 |
310 | 	for {
311 | 		j := strings.IndexAny(s.s[i:], "`\"'\\\n")
312 | 		if j == -1 {
313 | 			s.e = errors.New("table definition is not terminated")
314 | 			return false
315 | 		} else if s.quote == 0 && strings.IndexByte("`\"'", s.s[i+j]) != -1 {
316 | 			s.quote = s.s[i+j]
317 | 			s.stringLiteral = s.s[i+j] == '\''
318 | 			i += j + 1
319 | 		} else if s.quote != 0 && s.s[i+j] == s.quote {
320 | 			// A doubled quote inside a quoted name stands for the quote itself.
321 | 			if !s.stringLiteral && len(s.s) > i+j+1 && s.s[i+j+1] == s.quote {
322 | 				i += j + 2
323 | 			} else {
324 | 				s.quote = 0
325 | 				s.stringLiteral = false
326 | 				i += j + 1
327 | 			}
328 | 		} else if s.stringLiteral && s.s[i+j] == '\\' {
329 | 			i += j + 2
330 | 		} else if s.quote == 0 && s.s[i+j] == '\n' {
331 | 			// Leave out the comma that separates this definition from the next.
332 | 			if s.s[i+j-1] == ',' {
333 | 				s.d = s.s[s.p+2 : i+j-1]
334 | 			} else {
335 | 				s.d = s.s[s.p+2 : i+j]
336 | 			}
337 | 			s.p = i + j + 1
338 | 			return true
339 | 		} else {
340 | 			i += j + 1
341 | 		}
342 | 	}
343 | }
344 |
345 | // err returns the error that stopped scanning, if any.
346 | func (s *tableScanner) err() error {
347 | 	return s.e
348 | }
349 |
350 | // definition returns the definition found by the last successful scan,
351 | // without its indentation and trailing comma.
352 | func (s *tableScanner) definition() string {
353 | 	return s.d
354 | }
355 |
356 | // pos returns the offset in s just past the last definition scanned.
357 | func (s *tableScanner) pos() int {
358 | 	return s.p
359 | }
360 |
361 | // isConstraintClause reports whether the table definition d is a named
362 | // constraint, such as a foreign key.
363 | func isConstraintClause(d string) bool {
364 | 	return strings.HasPrefix(d, "CONSTRAINT ")
365 | }
366 |
367 | // parseSetNamesStatement returns the character set named by a SET NAMES
368 | // statement, which may be wrapped in a versioned comment.
369 | func parseSetNamesStatement(q *query) (charset string, err error) {
370 | 	if strings.HasPrefix(q.s, "/*!") {
371 | 		// A version number after the ! character consists of exactly
372 | 		// 5 digits. See https://github.com/mysql/mysql-server/blob/7d10c82196c8e45554f27c00681474a9fb86d137/sql/sql_lex.cc#L1728-L1735.
373 | 		charset, _, err = parseIdentifier(q.s, len("/*!00000 SET NAMES "), " ")
374 | 	} else {
375 | 		charset, _, err = parseIdentifier(q.s, len(" SET NAMES "), " ")
376 | 	}
377 | 	return
378 | }
379 |
380 | // parseUseStatement returns the database named by a USE statement.
381 | func parseUseStatement(q *query) (database string, err error) {
382 | 	database, _, err = parseIdentifier(q.s, len("USE "), ";")
383 | 	return
384 | }
385 |
386 | // parseIdentifier reads the identifier starting at s[i]: either quoted with
387 | // ` or " (a doubled quote inside stands for the quote itself) or bare. It
388 | // returns the unquoted name and the index of the byte following it, which
389 | // must be one of the bytes in terms.
390 | func parseIdentifier(s string, i int, terms string) (string, int, error) {
391 | 	if i >= len(s) {
392 | 		return "", 0, errors.New("name is not terminated")
393 | 	}
394 |
395 | 	var buf bytes.Buffer
396 | 	if s[i] == '`' || s[i] == '"' {
397 | 		quote := s[i]
398 | 		i++
399 | 		for {
400 | 			j := strings.IndexByte(s[i:], quote)
401 | 			if j == -1 {
402 | 				return "", 0, fmt.Errorf("name is not enclosed by '%c'", quote)
403 | 			}
404 | 			buf.WriteString(s[i : i+j])
405 | 			i += j + 1
406 | 			if i >= len(s) {
407 | 				return "", 0, errors.New("name is not terminated")
408 | 			}
409 | 			if strings.IndexByte(terms, s[i]) != -1 {
410 | 				break
411 | 			}
412 | 			if s[i] != quote {
413 | 				return "", 0, fmt.Errorf("unexpected character '%c'", s[i])
414 | 			}
415 | 			// Doubled quote: keep one and continue after the second.
416 | 			buf.WriteByte(quote)
417 | 			i++
418 | 		}
419 | 	} else {
420 | 		j := strings.IndexAny(s[i:], terms)
421 | 		if j == -1 {
422 | 			return "", 0, errors.New("name is not terminated")
423 | 		}
424 | 		buf.WriteString(s[i : i+j])
425 | 		i += j
426 | 	}
427 | 	return buf.String(), i, nil
428 | }
429 |
430 | // query is one statement read from the dump and the line it starts on.
431 | type query struct {
432 | 	line int
433 | 	s    string
434 | }
435 |
436 | // The is...Statement methods classify a query by the literal text it
437 | // starts with.
438 | func (q *query) isAlterTableStatement() bool {
439 | 	return strings.HasPrefix(q.s, "/*!40000 ALTER TABLE ")
440 | }
441 |
442 | func (q *query) isCreateTableStatement() bool {
443 | 	return strings.HasPrefix(q.s, "CREATE TABLE ")
444 | }
445 |
446 | func (q *query) isDropTableStatement() bool {
447 | 	return strings.HasPrefix(q.s, "DROP TABLE ")
448 | }
449 |
450 | func (q *query) isInsertStatement() bool {
451 | 	return strings.HasPrefix(q.s, "INSERT ")
452 | }
453 |
454 | func (q *query) isLockTablesStatement() bool {
455 | 	return strings.HasPrefix(q.s, "LOCK TABLES ")
456 | }
457 |
458 | func (q *query) isReplaceStatement() bool {
459 | 	return strings.HasPrefix(q.s, "REPLACE ")
460 | }
461 |
462 | func (q *query) isSetNamesStatement() bool {
463 | 	return strings.HasPrefix(q.s, " SET NAMES ") ||
464 | 		strings.HasPrefix(q.s, "/*!40101 SET NAMES ") ||
465 | 		strings.HasPrefix(q.s, "/*!50503 SET NAMES ")
466 | }
467 |
468 | func (q *query) isUnlockTablesStatement() bool {
469 | 	return q.s == "UNLOCK TABLES;"
470 | }
471 |
472 | func (q *query) isUseStatement() bool {
473 | 	return strings.HasPrefix(q.s, "USE ")
474 | }
475 |
476 | // loader loads INSERT and REPLACE statements with LOAD DATA on a pool of
477 | // concurrent workers, each using its own client.
478 | type loader struct {
479 | 	ch            chan request
480 | 	concurrency   int
481 | 	clientFactory func(ctx context.Context) (*client, error)
482 | 	errCh         chan error
483 | 	lowPriority   bool
484 | 	wg            sync.WaitGroup
485 | }
486 |
487 | // newLoader returns a loader with concurrency workers that obtain their
488 | // clients from clientFactory. start must be called before execute.
489 | func newLoader(clientFactory func(ctx context.Context) (*client, error), concurrency int, lowPriority bool) *loader {
490 | 	return &loader{
491 | 		clientFactory: clientFactory,
492 | 		concurrency:   concurrency,
493 | 		lowPriority:   lowPriority,
494 | 	}
495 | }
496 |
497 | // start launches the workers.
498 | func (l *loader) start() {
499 | 	l.ch = make(chan request, l.concurrency*2)
500 | 	// Each worker reports at most one error, so sends to errCh never block.
501 | 	l.errCh = make(chan error, l.concurrency)
502 |
503 | 	l.wg.Add(l.concurrency)
504 |
505 | 	for i := 0; i < l.concurrency; i++ {
506 | 		go func() {
507 | 			defer l.wg.Done()
508 | 			l.loop()
509 | 		}()
510 | 	}
511 | }
512 |
513 | // loop is the body of one worker. After a failure it reports the error and
514 | // keeps discarding requests until the queue is closed, so that execute can
515 | // never block forever on a queue that no worker reads. Discarded requests are
516 | // not marked done on their table, so the replacer never swaps in a table
517 | // whose load failed.
518 | func (l *loader) loop() {
519 | 	if err := l.run(); err != nil {
520 | 		l.errCh <- err
521 | 		for range l.ch {
522 | 		}
523 | 	}
524 | }
525 |
526 | // run processes requests on a fresh client until the queue is closed or a
527 | // load fails.
528 | func (l *loader) run() error {
529 | 	c, err := l.clientFactory(context.Background())
530 | 	if err != nil {
531 | 		return err
532 | 	}
533 | 	defer c.close()
534 |
535 | 	for r := range l.ch {
536 | 		if err := l.load(r.ctx, c, r.q, r.charset, r.database, r.table); err != nil {
537 | 			return fmt.Errorf("error loading data on line %d: %w", r.q.line, err)
538 | 		}
539 | 		if r.table != nil {
540 | 			r.table.wg.Done()
541 | 		}
542 | 	}
543 | 	return nil
544 | }
545 |
546 | // execute queues q for loading into table (or, when table is nil, into the
547 | // table named by q). It returns an error reported earlier by a worker
548 | // instead, if there is one.
549 | func (l *loader) execute(ctx context.Context, q *query, charset, database string, table *table) error {
550 | 	select {
551 | 	case err := <-l.errCh:
552 | 		return err
553 | 	default:
554 | 	}
555 |
556 | 	if table != nil {
557 | 		table.wg.Add(1)
558 | 	}
559 |
560 | 	l.ch <- request{ctx: ctx, q: q, charset: charset, database: database, table: table}
561 |
562 | 	return nil
563 | }
564 |
565 | // load converts q into LOAD DATA input and streams it to the server with
566 | // LOAD DATA LOCAL INFILE on c, into t when t is not nil and into the table
567 | // named by q otherwise.
568 | func (l *loader) load(ctx context.Context, c *client, q *query, charset, database string, t *table) error {
569 | 	ins, err := convert(q)
570 | 	if err != nil {
571 | 		return fmt.Errorf("error converting query: %w", err)
572 | 	}
573 |
574 | 	// The driver reads 'Reader::' + handler from the reader registered
575 | 	// under handler below; the statement's start line makes it unique.
576 | 	handler := strconv.Itoa(q.line)
577 |
578 | 	var query bytes.Buffer
579 | 	query.WriteString("LOAD DATA ")
580 | 	if l.lowPriority {
581 | 		query.WriteString("LOW_PRIORITY ")
582 | 	}
583 | 	query.WriteString("LOCAL INFILE 'Reader::" + handler + "' ")
584 | 	if ins.replace {
585 | 		query.WriteString("REPLACE ")
586 | 	} else if ins.ignore {
587 | 		query.WriteString("IGNORE ")
588 | 	}
589 | 	query.WriteString("INTO TABLE ")
590 | 	if t != nil {
591 | 		writeTableName(&query, database, t.name)
592 | 	} else {
593 | 		writeTableName(&query, database, ins.table)
594 | 	}
595 | 	if charset != "" {
596 | 		query.WriteString(" CHARACTER SET ")
597 | 		query.WriteString(charset)
598 | 	}
599 | 	// Without a column list LOAD DATA fills the columns in table order, which
600 | 	// is wrong when the INSERT listed only some columns or another order.
601 | 	if len(ins.columns) > 0 {
602 | 		query.WriteString(" (")
603 | 		for k, column := range ins.columns {
604 | 			if k != 0 {
605 | 				query.WriteString(", ")
606 | 			}
607 | 			query.Write(quoteName(column))
608 | 		}
609 | 		query.WriteByte(')')
610 | 	}
611 |
612 | 	mysql.RegisterReaderHandler(handler, func() io.Reader { return ins.r })
613 | 	defer mysql.DeregisterReaderHandler(handler)
614 |
615 | 	if charset != "" {
616 | 		if err := c.setCharacterSet(ctx, charset); err != nil {
617 | 			return fmt.Errorf("error setting character set: %w", err)
618 | 		}
619 | 	}
620 |
621 | 	return c.exec(ctx, query.String())
622 | }
623 |
624 | // convert turns an extended INSERT or REPLACE statement written by
625 | // mysqldump into input for LOAD DATA's default format: one line per row,
626 | // tab-separated fields, \N for NULL and backslash escapes.
627 | func convert(q *query) (*insertion, error) {
628 | 	var replace, ignore bool
629 | 	var i int
630 | 	if strings.HasPrefix(q.s, "INSERT ") {
631 | 		i = len("INSERT ")
632 | 	} else if strings.HasPrefix(q.s, "REPLACE ") {
633 | 		replace = true
634 | 		i = len("REPLACE ")
635 | 	} else {
636 | 		return nil, errors.New("unsupported statement")
637 | 	}
638 |
639 | 	if strings.HasPrefix(q.s[i:], "IGNORE ") {
640 | 		ignore = true
641 | 		i += len("IGNORE ")
642 | 	}
643 |
644 | 	if !strings.HasPrefix(q.s[i:], "INTO ") {
645 | 		return nil, errors.New("unsupported statement")
646 | 	}
647 | 	i += len("INTO ")
648 |
649 | 	tableName, i, err := parseIdentifier(q.s, i, " ")
650 | 	if err != nil {
651 | 		return nil, fmt.Errorf("error parsing table name: %w", err)
652 | 	}
653 | 	i++
654 |
655 | 	var columns []string
656 | 	if q.s[i] == '(' {
657 | 		i++
658 | 		for {
659 | 			var column string
660 | 			column, i, err = parseIdentifier(q.s, i, ",)")
661 | 			if err != nil {
662 | 				return nil, fmt.Errorf("error parsing column name: %w", err)
663 | 			}
664 | 			columns = append(columns, column)
665 | 			if q.s[i] == ')' {
666 | 				i++
667 | 				break
668 | 			} else if strings.HasPrefix(q.s[i:], ", ") {
669 | 				i += 2
670 | 			} else {
671 | 				return nil, errors.New("no space character after ',' in a list of column names")
672 | 			}
673 | 		}
674 | 		if q.s[i] != ' ' {
675 | 			return nil, errors.New("no space character after a list of column names")
676 | 		}
677 | 		i++
678 | 	}
679 |
680 | 	if !strings.HasPrefix(q.s[i:], "VALUES ") {
681 | 		return nil, errors.New("unsupported statement")
682 | 	}
683 | 	i += len("VALUES ")
684 |
685 | 	var buf bytes.Buffer
686 | 	for {
687 | 		for {
688 | 			if q.s[i] == '(' {
689 | 				i++
690 | 			}
691 | 			if strings.HasPrefix(q.s[i:], "_binary ") {
692 | 				i += len("_binary ")
693 | 			}
694 | 			if q.s[i] == '\'' {
695 | 				i++
696 | 				for {
697 | 					// TODO: NO_BACKSLASH_ESCAPES
698 | 					j := strings.IndexAny(q.s[i:], "\\\t'")
699 | 					if j == -1 {
700 | 						return nil, errors.New("column value is not enclosed")
701 | 					}
702 | 					buf.WriteString(q.s[i : i+j])
703 | 					i += j
704 | 					if q.s[i] == '\\' {
705 | 						// Pass the escape through for LOAD DATA to interpret.
706 | 						buf.WriteString(q.s[i : i+2])
707 | 						i += 2
708 | 					} else if q.s[i] == '\t' {
709 | 						buf.WriteString(`\t`)
710 | 						i++
711 | 					} else if strings.IndexByte(",)", q.s[i+1]) != -1 {
712 | 						i++
713 | 						break
714 | 					} else {
715 | 						return nil, errors.New("unescaped single quote")
716 | 					}
717 | 				}
718 | 			} else if strings.HasPrefix(q.s[i:], "0x") {
719 | 				j := strings.IndexAny(q.s[i+2:], ",)")
720 | 				if j == -1 {
721 | 					return nil, errors.New("hex blob is not terminated")
722 | 				}
723 | 				b, err := hex.DecodeString(q.s[i+2 : i+2+j])
724 | 				if err != nil {
725 | 					return nil, fmt.Errorf("error decoding hex blob: %w", err)
726 | 				}
727 | 				// The decoded bytes are raw data that may contain the tab,
728 | 				// newline and backslash characters LOAD DATA treats specially.
729 | 				writeEscaped(&buf, b)
730 | 				i += 2 + j
731 | 			} else {
732 | 				j := strings.IndexAny(q.s[i:], ",)")
733 | 				if j == -1 {
734 | 					return nil, errors.New("column value is not terminated")
735 | 				}
736 | 				s := q.s[i : i+j]
737 | 				if s == "NULL" {
738 | 					buf.WriteString(`\N`)
739 | 				} else {
740 | 					buf.WriteString(s)
741 | 				}
742 | 				i += j
743 | 			}
744 | 			if q.s[i] == ',' {
745 | 				buf.WriteByte('\t')
746 | 				i++
747 | 			} else {
748 | 				buf.WriteByte('\n')
749 | 				i++
750 | 				break
751 | 			}
752 | 		}
753 | 		if q.s[i] == ',' {
754 | 			i++
755 | 		} else if q.s[i] == ';' {
756 | 			i++
757 | 			break
758 | 		} else {
759 | 			return nil, fmt.Errorf("unexpected character '%c'", q.s[i])
760 | 		}
761 | 	}
762 |
763 | 	return &insertion{columns: columns, ignore: ignore, r: &buf, replace: replace, table: tableName}, nil
764 | }
765 |
766 | // writeEscaped appends b to buf, escaping the bytes that are special in LOAD
767 | // DATA's default input format: backslash (the escape character), tab (the
768 | // field terminator) and newline (the line terminator).
769 | func writeEscaped(buf *bytes.Buffer, b []byte) {
770 | 	for _, c := range b {
771 | 		switch c {
772 | 		case '\\':
773 | 			buf.WriteString(`\\`)
774 | 		case '\t':
775 | 			buf.WriteString(`\t`)
776 | 		case '\n':
777 | 			buf.WriteString(`\n`)
778 | 		default:
779 | 			buf.WriteByte(c)
780 | 		}
781 | 	}
782 | }
783 |
784 | // quoteName quotes name as a MySQL identifier, doubling any backquotes in it.
785 | func quoteName(name string) []byte {
786 | 	return []byte("`" + strings.ReplaceAll(name, "`", "``") + "`")
787 | }
788 |
789 | // wait closes the queue, waits for the workers to finish and returns the
790 | // first error a worker reported, if any.
791 | func (l *loader) wait() error {
792 | 	close(l.ch)
793 |
794 | 	done := make(chan struct{})
795 | 	go func() {
796 | 		defer close(done)
797 | 		l.wg.Wait()
798 | 	}()
799 |
800 | 	select {
801 | 	case err := <-l.errCh:
802 | 		return err
803 | 	case <-done:
804 | 	}
805 |
806 | 	// Workers send their error before exiting, so once all of them are done
807 | 	// any error is already buffered, but select may have picked done above.
808 | 	select {
809 | 	case err := <-l.errCh:
810 | 		return err
811 | 	default:
812 | 		return nil
813 | 	}
814 | }
815 |
816 | // request is a statement queued for loading together with the session
817 | // state it must be loaded with.
818 | type request struct {
819 | 	charset  string
820 | 	ctx      context.Context
821 | 	database string
822 | 	q        *query
823 | 	table    *table
824 | }
825 |
826 | // insertion is the result of convert: LOAD DATA input in r for the named
827 | // table and, when the statement listed them, its columns.
828 | type insertion struct {
829 | 	columns []string
830 | 	ignore  bool
831 | 	r       io.Reader
832 | 	replace bool
833 | 	table   string
834 | }
835 |
836 | // replacer swaps fully loaded temporary tables in for the original tables.
837 | // Replacements run one at a time.
838 | type replacer struct {
839 | 	client *client
840 | 	errCh  chan error
841 | 	mutex  sync.Mutex
842 | 	wg     sync.WaitGroup
843 | }
844 |
845 | // newReplacer returns a replacer that issues its statements through client.
846 | func newReplacer(client *client) *replacer {
847 | 	return &replacer{client: client, errCh: make(chan error, 1)}
848 | }
849 |
850 | // execute schedules the replacement of table; it runs in the background once
851 | // every load queued for the table has completed. An error from an earlier
852 | // replacement is returned instead, if there is one.
853 | func (r *replacer) execute(ctx context.Context, database string, table *table) error {
854 | 	select {
855 | 	case err := <-r.errCh:
856 | 		return err
857 | 	default:
858 | 	}
859 | 	r.wg.Add(1)
860 | 	go func() {
861 | 		defer r.wg.Done()
862 | 		table.wg.Wait()
863 | 		if err := r.replace(ctx, database, table); err != nil {
864 | 			err = fmt.Errorf("error replacing table %s with new table %s: %w", quoteName(table.origName), quoteName(table.name), err)
865 | 			// Keep only the first error; a blocking send would leave this
866 | 			// goroutine stuck while an earlier error is still pending.
867 | 			select {
868 | 			case r.errCh <- err:
869 | 			default:
870 | 			}
871 | 		}
872 | 	}()
873 | 	return nil
874 | }
875 |
876 | // replace drops the original table, renames the temporary table to its name
877 | // and adds back the foreign keys removed from the temporary table.
878 | func (r *replacer) replace(ctx context.Context, database string, table *table) error {
879 | 	r.mutex.Lock()
880 | 	defer r.mutex.Unlock()
881 |
882 | 	if *verbose {
883 | 		log.Printf("Replacing table %s with new table %s...", quoteName(table.origName), quoteName(table.name))
884 | 	}
885 |
886 | 	if err := r.client.dropTableIfExists(ctx, database, table.origName); err != nil {
887 | 		return fmt.Errorf("error dropping table %s: %w", quoteName(table.origName), err)
888 | 	}
889 |
890 | 	if err := r.client.renameTable(ctx, database, table.name, table.origName); err != nil {
891 | 		return fmt.Errorf("error renaming table %s to %s: %w", quoteName(table.name), quoteName(table.origName), err)
892 | 	}
893 |
894 | 	if len(table.foreignKeys) > 0 {
895 | 		if err := r.client.addForeignKeys(ctx, database, table.origName, table.foreignKeys); err != nil {
896 | 			return fmt.Errorf("error restoring foreign keys in table %s: %w", quoteName(table.origName), err)
897 | 		}
898 | 	}
899 |
900 | 	return nil
901 | }
902 |
903 | // wait waits for every scheduled replacement and returns the first error
904 | // one of them reported, if any.
905 | func (r *replacer) wait() error {
906 | 	done := make(chan struct{})
907 | 	go func() {
908 | 		defer close(done)
909 | 		r.wg.Wait()
910 | 	}()
911 |
912 | 	select {
913 | 	case err := <-r.errCh:
914 | 		return err
915 | 	case <-done:
916 | 	}
917 |
918 | 	// An error sent just before the last replacement finished may have lost
919 | 	// the race against done in the select above.
920 | 	select {
921 | 	case err := <-r.errCh:
922 | 		return err
923 | 	default:
924 | 		return nil
925 | 	}
926 | }
927 |
--------------------------------------------------------------------------------
/scanner.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | 	"bufio"
5 | 	"bytes"
6 | 	"errors"
7 | 	"fmt"
8 | 	"io"
9 | 	"strings"
10 | )
11 |
12 | // scanner splits a dump into statements. A statement ends at a ';' that is
13 | // the last character on its line and is outside quotes and /*! ... */
14 | // comments. Blank lines and lines starting with "--" are skipped between
15 | // statements.
16 | type scanner struct {
17 | 	buf           bytes.Buffer
18 | 	comment       bool
19 | 	e             error
20 | 	line          int
21 | 	q             *query
22 | 	quote         byte
23 | 	reader        *bufio.Reader
24 | 	stringLiteral bool
25 | }
26 |
27 | // newScanner returns a scanner that reads from r.
28 | func newScanner(r io.Reader) *scanner {
29 | 	return &scanner{reader: bufio.NewReader(r)}
30 | }
31 |
32 | // scan reads the next statement, which query then returns. It returns false
33 | // at the end of the input or on an error, which err then reports.
34 | func (s *scanner) scan() bool {
35 | 	if s.e != nil {
36 | 		return false
37 | 	}
38 |
39 | 	// Line on which the statement being read starts.
40 | 	var start int
41 |
42 | 	for {
43 | 		str, err := s.reader.ReadString('\n')
44 | 		if err == io.EOF && str != "" {
45 | 			// Treat a last line that lacks a trailing newline as if it had
46 | 			// one instead of silently dropping it; the next read reports EOF.
47 | 			str += "\n"
48 | 			err = nil
49 | 		}
50 | 		if err == io.EOF && s.buf.Len() != 0 {
51 | 			s.e = errors.New("unexpected EOF")
52 | 			return false
53 | 		}
54 | 		if err != nil {
55 | 			s.e = err
56 | 			return false
57 | 		}
58 | 		s.line++
59 |
60 | 		if !s.comment && s.quote == 0 && (strings.HasPrefix(str, "--") || str == "\n") {
61 | 			continue
62 | 		}
63 |
64 | 		if s.buf.Len() == 0 {
65 | 			start = s.line
66 | 		}
67 |
68 | 		var i int
69 | 		for {
70 | 			j := strings.IndexAny(str[i:], "/*`\"'\\;")
71 | 			if j == -1 {
72 | 				s.buf.WriteString(str)
73 | 				break
74 | 			} else if !s.comment && s.quote == 0 && strings.HasPrefix(str[i+j:], "/*!") {
75 | 				s.comment = true
76 | 				i += j + 3
77 | 			} else if s.comment && s.quote == 0 && strings.HasPrefix(str[i+j:], "*/") {
78 | 				s.comment = false
79 | 				i += j + 2
80 | 			} else if s.quote == 0 && strings.IndexByte("`\"'", str[i+j]) != -1 {
81 | 				s.quote = str[i+j]
82 | 				s.stringLiteral = str[i+j] == '\''
83 | 				i += j + 1
84 | 			} else if s.quote != 0 && str[i+j] == s.quote {
85 | 				if !s.stringLiteral && len(str) > i+j+1 && str[i+j+1] == s.quote {
86 | 					i += j + 2
87 | 				} else {
88 | 					s.quote = 0
89 | 					s.stringLiteral = false
90 | 					i += j + 1
91 | 				}
92 | 			} else if s.stringLiteral && str[i+j] == '\\' {
93 | 				i += j + 2
94 | 			} else if !s.comment && s.quote == 0 && str[i+j] == ';' {
95 | 				if len(str) != i+j+2 || str[i+j+1] != '\n' {
96 | 					s.e = fmt.Errorf("newline is expected after ';'. line=%d", s.line)
97 | 					return false
98 | 				}
99 | 				s.buf.WriteString(str[:i+j+1])
100 | 				s.q = &query{line: start, s: s.buf.String()}
101 | 				s.buf.Reset()
102 | 				return true
103 | 			} else {
104 | 				i += j + 1
105 | 			}
106 | 		}
107 | 	}
108 | }
109 |
110 | // err returns the error that stopped scanning, or nil at a clean end of input.
111 | func (s *scanner) err() error {
112 | 	if s.e == io.EOF {
113 | 		return nil
114 | 	}
115 | 	return s.e
116 | }
117 |
118 | // query returns the statement read by the last successful call to scan.
119 | func (s *scanner) query() *query {
120 | 	if s.e != nil {
121 | 		return nil
122 | 	}
123 | 	return s.q
124 | }
125 |
--------------------------------------------------------------------------------
/scanner_test.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | 	"bufio"
5 | 	"errors"
6 | 	"strings"
7 | 	"testing"
8 | )
9 |
10 | func TestQuery(t *testing.T) {
11 | 	cases := []struct {
12 | 		in   string
13 | 		want *query
14 | 	}{
15 | 		{"\n", nil},
16 | 		{"--\n", nil},
17 | 		{"/*!40000 ALTER TABLE `/*!` DISABLE KEYS */;\n", &query{line: 1, s: "/*!40000 ALTER TABLE `/*!` DISABLE KEYS */;"}},
18 | 		{"/*!40000 ALTER TABLE `*/` DISABLE KEYS */;\n", &query{line: 1, s: "/*!40000 ALTER TABLE `*/` DISABLE KEYS */;"}},
19 | 		{"/*!40000 ALTER TABLE `'\"` DISABLE KEYS */;\n", &query{line: 1, s: "/*!40000 ALTER TABLE `'\"` DISABLE KEYS */;"}},
20 | 		{"/*!40000 ALTER TABLE `;` DISABLE KEYS */;\n", &query{line: 1, s: "/*!40000 ALTER TABLE `;` DISABLE KEYS */;"}},
21 | 		{"/*!40000 ALTER TABLE `a\nb` DISABLE KEYS */;\n", &query{line: 1, s: "/*!40000 ALTER TABLE `a\nb` DISABLE KEYS */;"}},
22 | 		{"--\n\nUNLOCK TABLES;\n", &query{line: 3, s: "UNLOCK TABLES;"}},
23 | 		{"UNLOCK TABLES;", &query{line: 1, s: "UNLOCK TABLES;"}},
24 | 	}
25 | 	for _, c := range cases {
26 | 		s := &scanner{reader: bufio.NewReader(strings.NewReader(c.in))}
27 | 		s.scan()
28 | 		got := s.query()
29 | 		if got == nil && c.want == nil {
30 | 			continue
31 | 		}
32 | 		if got == nil || c.want == nil || *got != *c.want {
33 | 			t.Errorf("in: %q, got: %#v, want: %#v", c.in, got, c.want)
34 | 		}
35 | 	}
36 | }
37 |
38 | func TestErr(t *testing.T) {
39 | 	cases := []struct {
40 | 		in   string
41 | 		want error
42 | 	}{
43 | 		{"", nil},
44 | 		{"-- Dump completed", nil},
45 | 		{"UNLOCK TABLES\n", errors.New("unexpected EOF")},
46 | 		{"UNLOCK TABLES", errors.New("unexpected EOF")},
47 | 		{"LOCK TABLES `a` WRITE;UNLOCK TABLES;\n", errors.New("newline is expected after ';'. line=1")},
48 | 	}
49 | 	for _, c := range cases {
50 | 		s := &scanner{reader: bufio.NewReader(strings.NewReader(c.in))}
51 | 		s.scan()
52 | 		got := s.err()
53 | 		if got == nil && c.want == nil {
54 | 			continue
55 | 		}
56 | 		if got == nil || c.want == nil || got.Error() != c.want.Error() {
57 | 			t.Errorf("in: %q, got: %#v, want: %#v", c.in, got, c.want)
58 | 		}
59 | 	}
60 | }
61 |
--------------------------------------------------------------------------------