├── .gitignore ├── LICENSE ├── README.md ├── cmd └── mig │ ├── create.go │ ├── down.go │ ├── main.go │ ├── redo.go │ ├── root.go │ ├── status.go │ ├── up.go │ └── version.go ├── dialect.go ├── migrate.go ├── migrate_test.go ├── migration.go ├── migration_sql.go ├── migration_test.go ├── util.go └── wrappers.go /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | cmd/mig/mig 3 | *.swp 4 | *.test 5 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Modified work Copyright (c) 2017 Patrick O'Brien - github.com/nullbio/mig 2 | Modified work Copyright (c) 2016 Vojtech Vitek - github.com/pressly/goose 3 | Original work Copyright (c) 2012 Liam Staskawicz - bitbucket.org/liamstask/goose 4 | 5 | MIT License 6 | 7 | Permission is hereby granted, free of charge, to any person obtaining a copy of 8 | this software and associated documentation files (the "Software"), to deal in 9 | the Software without restriction, including without limitation the rights to 10 | use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of 11 | the Software, and to permit persons to whom the Software is furnished to do so, 12 | subject to the following conditions: 13 | 14 | The above copyright notice and this permission notice shall be included in all 15 | copies or substantial portions of the Software. 16 | 17 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 18 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS 19 | FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR 20 | COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER 21 | IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN 22 | CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
23 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # mig 2 | 3 | mig is a database migration tool. Manage your database's evolution by creating incremental SQL files. 4 | 5 | [![License](https://img.shields.io/badge/license-MIT-blue.svg)](https://github.com/volatiletech/mig/blob/master/LICENSE) 6 | [![GoDoc](https://godoc.org/github.com/volatiletech/mig?status.svg)](https://godoc.org/github.com/volatiletech/mig) 7 | [![Go Report Card](https://goreportcard.com/badge/volatiletech/mig)](http://goreportcard.com/report/volatiletech/mig) 8 | 9 | ### Goals of this fork 10 | 11 | This is a restructured and modified fork of https://github.com/pressly/goose 12 | which includes many necessary bug fixes and general improvements to 13 | better suit the Go abcweb framework https://github.com/nullbio/abcweb. 14 | That said, mig does not require abcweb: feel free to use it as a standalone tool or in your own 15 | projects even if you do not use abcweb. 16 | 17 | # Install 18 | 19 | $ go get -u github.com/volatiletech/mig/... 20 | 21 | # Usage 22 | 23 | ``` 24 | mig is a database migration tool for Postgres and MySQL. 
25 | 26 | Usage: 27 | mig [command] 28 | 29 | Examples: 30 | mig up postgres "user=postgres dbname=postgres sslmode=disable" 31 | mig down mysql "user:password@/dbname" 32 | mig create add_users 33 | 34 | Available Commands: 35 | create Create a blank migration template 36 | down Roll back the version by one 37 | downall Roll back all migrations 38 | help Help about any command 39 | redo Down then up the latest migration 40 | redoall Down then up all migrations 41 | status Dump the migration status for the database 42 | up Migrate the database to the most recent version available 43 | upone Migrate the database by one version 44 | version Print the current version of the database 45 | 46 | Flags: 47 | --version Print the mig tool version 48 | 49 | Use "mig [command] --help" for more information about a command. 50 | ``` 51 | 52 | Note: If you're using [ABCWeb](https://github.com/volatiletech/abcweb) the `mig` 53 | commands are built into the `abcweb` tool. See `abcweb --help` for usage. 54 | 55 | ## Supported Databases 56 | 57 | mig supports MySQL and Postgres. The drivers used are: 58 | 59 | https://github.com/go-sql-driver/mysql 60 | 61 | https://github.com/lib/pq 62 | 63 | See these drivers for details on the format of their connection strings. 64 | 65 | ## Couple of example runs 66 | 67 | ### create 68 | 69 | Create a new SQL migration. 70 | 71 | $ mig create add_users 72 | $ Created db/migrations/20130106093224_add_users.sql 73 | 74 | Edit the newly created script to define the behavior of your migration. Your 75 | SQL statements should go below the Up and Down comments. 76 | 77 | An example command run: 78 | 79 | ### up 80 | 81 | Apply all available migrations. 
82 | 83 | $ mig up postgres "user=username password=password dbname=database" 84 | $ Success 20170314220650_add_dogs.sql 85 | $ Success 20170314221501_add_cats.sql 86 | $ Success 2 migrations 87 | 88 | ## Migrations 89 | 90 | A sample SQL migration looks like: 91 | 92 | ```sql 93 | -- +mig Up 94 | CREATE TABLE post ( 95 | id int NOT NULL, 96 | title text, 97 | body text, 98 | PRIMARY KEY(id) 99 | ); 100 | 101 | -- +mig Down 102 | DROP TABLE post; 103 | ``` 104 | 105 | Notice the annotations in the comments. Any statements following `-- +mig Up` will be executed as part of a forward migration, and any statements following `-- +mig Down` will be executed as part of a rollback. 106 | 107 | By default, SQL statements are delimited by semicolons - in fact, query statements must end with a semicolon to be properly recognized by mig. 108 | 109 | More complex statements (PL/pgSQL) that have semicolons within them must be annotated with `-- +mig StatementBegin` and `-- +mig StatementEnd` to be properly recognized. 
For example: 110 | 111 | ```sql 112 | -- +mig Up 113 | -- +mig StatementBegin 114 | CREATE OR REPLACE FUNCTION histories_partition_creation( DATE, DATE ) 115 | returns void AS $$ 116 | DECLARE 117 | create_query text; 118 | BEGIN 119 | FOR create_query IN SELECT 120 | 'CREATE TABLE IF NOT EXISTS histories_' 121 | || TO_CHAR( d, 'YYYY_MM' ) 122 | || ' ( CHECK( created_at >= timestamp ''' 123 | || TO_CHAR( d, 'YYYY-MM-DD 00:00:00' ) 124 | || ''' AND created_at < timestamp ''' 125 | || TO_CHAR( d + INTERVAL '1 month', 'YYYY-MM-DD 00:00:00' ) 126 | || ''' ) ) inherits ( histories );' 127 | FROM generate_series( $1, $2, '1 month' ) AS d 128 | LOOP 129 | EXECUTE create_query; 130 | END LOOP; -- LOOP END 131 | END; -- FUNCTION END 132 | $$ 133 | language plpgsql; 134 | -- +mig StatementEnd 135 | ``` 136 | 137 | ## Library functions 138 | 139 | 140 | ```go 141 | // Global io.Writer variable that can be changed to get incremental success 142 | // messages from function calls that process more than one migration, 143 | // for example Up and DownAll. Defaults to ioutil.Discard. 144 | var mig.Log 145 | 146 | // Create a templated migration file in dir 147 | mig.Create(name, dir string) (name string, err error) 148 | 149 | // Down rolls back the version by one 150 | mig.Down(driver, conn, dir string) (name string, err error) 151 | 152 | // DownAll rolls back all migrations. 153 | // Logs success messages to global writer variable Log. 154 | mig.DownAll(driver, conn, dir string) (count int, err error) 155 | 156 | // Up migrates to the highest version available 157 | mig.Up(driver, conn, dir string) (count int, err error) 158 | 159 | // UpOne migrates one version 160 | mig.UpOne(driver, conn, dir string) (name string, err error) 161 | 162 | // Redo re-runs the latest migration. 
163 | mig.Redo(driver, conn, dir string) (name string, err error) 164 | 165 | // Return the status of each migration 166 | mig.Status(driver, conn, dir string) (status, error) 167 | 168 | // Return the current migration version 169 | mig.Version(driver, conn string) (version int64, err error) 170 | ``` 171 | -------------------------------------------------------------------------------- /cmd/mig/create.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | 7 | "github.com/spf13/cobra" 8 | "github.com/spf13/viper" 9 | "github.com/volatiletech/mig" 10 | ) 11 | 12 | var createCmd = &cobra.Command{ 13 | Use: "create", 14 | Short: "Create a blank migration template", 15 | Long: "Create a blank migration template", 16 | Example: `mig create add_users`, 17 | RunE: createRunE, 18 | } 19 | 20 | func init() { 21 | createCmd.Flags().StringP("dir", "d", ".", "directory with migration files") 22 | 23 | rootCmd.AddCommand(createCmd) 24 | createCmd.PreRun = func(*cobra.Command, []string) { 25 | viper.BindPFlags(createCmd.Flags()) 26 | } 27 | } 28 | 29 | func createRunE(cmd *cobra.Command, args []string) error { 30 | if len(args) < 1 || len(args[0]) == 0 { 31 | return errors.New("no migration name provided") 32 | } 33 | 34 | path, err := mig.Create(args[0], viper.GetString("dir")) 35 | if err != nil { 36 | return err 37 | } 38 | 39 | fmt.Println(fmt.Sprintf("Created %s", path)) 40 | 41 | return nil 42 | } 43 | -------------------------------------------------------------------------------- /cmd/mig/down.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "fmt" 5 | 6 | "github.com/spf13/cobra" 7 | "github.com/spf13/viper" 8 | "github.com/volatiletech/mig" 9 | ) 10 | 11 | var downCmd = &cobra.Command{ 12 | Use: "down", 13 | Short: "Roll back the version by one", 14 | Long: "Roll back the version by one", 15 | Example: `mig 
down mysql "user:password@/dbname"`, 16 | RunE: downRunE, 17 | } 18 | 19 | var downAllCmd = &cobra.Command{ 20 | Use: "downall", 21 | Short: "Roll back all migrations", 22 | Long: "Roll back all migrations", 23 | Example: `mig downall mysql "user:password@/dbname"`, 24 | RunE: downAllRunE, 25 | } 26 | 27 | func init() { 28 | downCmd.Flags().StringP("dir", "d", ".", "directory with migration files") 29 | downAllCmd.Flags().StringP("dir", "d", ".", "directory with migration files") 30 | 31 | rootCmd.AddCommand(downCmd) 32 | rootCmd.AddCommand(downAllCmd) 33 | 34 | downCmd.PreRun = func(*cobra.Command, []string) { 35 | viper.BindPFlags(downCmd.Flags()) 36 | } 37 | downAllCmd.PreRun = func(*cobra.Command, []string) { 38 | viper.BindPFlags(downAllCmd.Flags()) 39 | } 40 | } 41 | 42 | func downRunE(cmd *cobra.Command, args []string) error { 43 | driver, conn, err := getConnArgs(args) 44 | if err != nil { 45 | return err 46 | } 47 | 48 | name, err := mig.Down(driver, conn, viper.GetString("dir")) 49 | if mig.IsNoMigrationError(err) { 50 | fmt.Println("No migrations to run") 51 | return nil 52 | } else if err != nil { 53 | return err 54 | } 55 | 56 | fmt.Printf("Success %v\n", name) 57 | return nil 58 | } 59 | 60 | func downAllRunE(cmd *cobra.Command, args []string) error { 61 | driver, conn, err := getConnArgs(args) 62 | if err != nil { 63 | return err 64 | } 65 | 66 | count, err := mig.DownAll(driver, conn, viper.GetString("dir")) 67 | if err != nil { 68 | return err 69 | } 70 | 71 | if count == 0 { 72 | fmt.Printf("No migrations to run") 73 | } else { 74 | fmt.Printf("Success %d migrations\n", count) 75 | } 76 | 77 | return nil 78 | } 79 | -------------------------------------------------------------------------------- /cmd/mig/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "fmt" 5 | "os" 6 | ) 7 | 8 | const migVersion = "1.0.0" 9 | 10 | func main() { 11 | // Too much happens between here 
and cobra's argument handling, for 12 | // something so simple. Just do it immediately. 13 | if len(os.Args) > 1 && os.Args[1] == "--version" { 14 | fmt.Println("mig v" + migVersion) 15 | return 16 | } 17 | 18 | if err := rootCmd.Execute(); err != nil { 19 | fmt.Println(err) 20 | os.Exit(1) 21 | } 22 | } 23 | -------------------------------------------------------------------------------- /cmd/mig/redo.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "fmt" 5 | 6 | "github.com/spf13/cobra" 7 | "github.com/spf13/viper" 8 | "github.com/volatiletech/mig" 9 | ) 10 | 11 | var redoCmd = &cobra.Command{ 12 | Use: "redo", 13 | Short: "Down then up the latest migration", 14 | Long: "Down then up the latest migration", 15 | Example: `mig redo postgres "user=postgres dbname=postgres sslmode=disable"`, 16 | RunE: redoRunE, 17 | } 18 | 19 | var redoAllCmd = &cobra.Command{ 20 | Use: "redo", 21 | Short: "Down then up all migrations", 22 | Long: "Down then up all migrations", 23 | Example: `mig redoall postgres "user=postgres dbname=postgres sslmode=disable"`, 24 | RunE: redoAllRunE, 25 | } 26 | 27 | func init() { 28 | redoCmd.Flags().StringP("dir", "d", ".", "directory with migration files") 29 | 30 | rootCmd.AddCommand(redoCmd) 31 | redoCmd.PreRun = func(*cobra.Command, []string) { 32 | viper.BindPFlags(redoCmd.Flags()) 33 | } 34 | } 35 | 36 | func redoRunE(cmd *cobra.Command, args []string) error { 37 | driver, conn, err := getConnArgs(args) 38 | if err != nil { 39 | return err 40 | } 41 | 42 | name, err := mig.Redo(driver, conn, viper.GetString("dir")) 43 | if mig.IsNoMigrationError(err) { 44 | fmt.Println("No migrations to run") 45 | return nil 46 | } else if err != nil { 47 | return err 48 | } 49 | 50 | fmt.Printf("Success %v\n", name) 51 | return nil 52 | } 53 | 54 | func redoAllRunE(cmd *cobra.Command, args []string) error { 55 | driver, conn, err := getConnArgs(args) 56 | if err != nil { 57 | return 
// getConnArgs extracts the database driver name and the connection string
// from the positional arguments handed over by cobra.
//
// The driver is expected at index 0 and the connection string at index 1;
// any further arguments are ignored. When fewer than two arguments are
// present, both strings are empty and an error is returned.
func getConnArgs(args []string) (driver string, conn string, err error) {
	if len(args) >= 2 {
		return args[0], args[1], nil
	}

	return "", "", errors.New("no connection details provided")
}
3 | import ( 4 | "fmt" 5 | 6 | "github.com/spf13/cobra" 7 | "github.com/spf13/viper" 8 | "github.com/volatiletech/mig" 9 | ) 10 | 11 | var statusCmd = &cobra.Command{ 12 | Use: "status", 13 | Short: "Dump the migration status for the database", 14 | Long: "Dump the migration status for the database", 15 | Example: `mig status postgres "user=postgres dbname=postgres sslmode=disable"`, 16 | RunE: statusRunE, 17 | } 18 | 19 | func init() { 20 | statusCmd.Flags().StringP("dir", "d", ".", "directory with migration files") 21 | 22 | rootCmd.AddCommand(statusCmd) 23 | 24 | statusCmd.PreRun = func(*cobra.Command, []string) { 25 | viper.BindPFlags(statusCmd.Flags()) 26 | } 27 | } 28 | 29 | func statusRunE(cmd *cobra.Command, args []string) error { 30 | driver, conn, err := getConnArgs(args) 31 | if err != nil { 32 | return err 33 | } 34 | 35 | status, err := mig.Status(driver, conn, viper.GetString("dir")) 36 | if err != nil { 37 | return err 38 | } 39 | 40 | if len(status) == 0 { 41 | fmt.Printf("No migrations applied") 42 | return nil 43 | } 44 | 45 | fmt.Println("Applied At Migration") 46 | fmt.Println("===================================================") 47 | for _, s := range status { 48 | fmt.Printf("%-24s -- %v\n", s.Applied, s.Name) 49 | } 50 | 51 | return nil 52 | } 53 | -------------------------------------------------------------------------------- /cmd/mig/up.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "fmt" 5 | 6 | "github.com/spf13/cobra" 7 | "github.com/spf13/viper" 8 | "github.com/volatiletech/mig" 9 | ) 10 | 11 | var upCmd = &cobra.Command{ 12 | Use: "up", 13 | Short: "Migrate the database to the most recent version available", 14 | Long: "Migrate the database to the most recent version available", 15 | Example: `mig up mysql "user:password@/dbname"`, 16 | RunE: upRunE, 17 | } 18 | 19 | var upOneCmd = &cobra.Command{ 20 | Use: "upone", 21 | Short: "Migrate the database by one 
version", 22 | Long: "Migrate the database by one version", 23 | Example: `mig upone mysql "user:password@/dbname"`, 24 | RunE: upOneRunE, 25 | } 26 | 27 | func init() { 28 | upCmd.Flags().StringP("dir", "d", ".", "directory with migration files") 29 | upOneCmd.Flags().StringP("dir", "d", ".", "directory with migration files") 30 | 31 | rootCmd.AddCommand(upCmd) 32 | rootCmd.AddCommand(upOneCmd) 33 | 34 | upCmd.PreRun = func(*cobra.Command, []string) { 35 | viper.BindPFlags(upCmd.Flags()) 36 | } 37 | upOneCmd.PreRun = func(*cobra.Command, []string) { 38 | viper.BindPFlags(upOneCmd.Flags()) 39 | } 40 | } 41 | 42 | func upRunE(cmd *cobra.Command, args []string) error { 43 | driver, conn, err := getConnArgs(args) 44 | if err != nil { 45 | return err 46 | } 47 | 48 | count, err := mig.Up(driver, conn, viper.GetString("dir")) 49 | if err != nil { 50 | return err 51 | } 52 | 53 | if count == 0 { 54 | fmt.Printf("No migrations to run") 55 | } else { 56 | fmt.Printf("Success %d migrations\n", count) 57 | } 58 | 59 | return nil 60 | } 61 | 62 | func upOneRunE(cmd *cobra.Command, args []string) error { 63 | driver, conn, err := getConnArgs(args) 64 | if err != nil { 65 | return err 66 | } 67 | 68 | name, err := mig.UpOne(driver, conn, viper.GetString("dir")) 69 | if mig.IsNoMigrationError(err) { 70 | fmt.Println("No migrations to run") 71 | return nil 72 | } else if err != nil { 73 | return err 74 | } 75 | 76 | fmt.Printf("Success %v\n", name) 77 | return nil 78 | } 79 | -------------------------------------------------------------------------------- /cmd/mig/version.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "fmt" 5 | 6 | "github.com/spf13/cobra" 7 | "github.com/spf13/viper" 8 | "github.com/volatiletech/mig" 9 | ) 10 | 11 | var versionCmd = &cobra.Command{ 12 | Use: "version", 13 | Short: "Print the current version of the database", 14 | Long: "Print the current version of the database", 15 | 
Example: `mig version postgres "user=postgres dbname=postgres sslmode=disable"`, 16 | RunE: versionRunE, 17 | } 18 | 19 | func init() { 20 | rootCmd.AddCommand(versionCmd) 21 | versionCmd.PreRun = func(*cobra.Command, []string) { 22 | viper.BindPFlags(versionCmd.Flags()) 23 | } 24 | } 25 | 26 | func versionRunE(cmd *cobra.Command, args []string) error { 27 | driver, conn, err := getConnArgs(args) 28 | if err != nil { 29 | return err 30 | } 31 | 32 | version, err := mig.Version(driver, conn) 33 | if err != nil { 34 | return err 35 | } 36 | 37 | if version == 0 { 38 | fmt.Printf("No migrations applied") 39 | } else { 40 | fmt.Printf("Version %d\n", version) 41 | } 42 | 43 | return nil 44 | } 45 | -------------------------------------------------------------------------------- /dialect.go: -------------------------------------------------------------------------------- 1 | package mig 2 | 3 | import ( 4 | "database/sql" 5 | "fmt" 6 | "os" 7 | ) 8 | 9 | // sqlDialect abstracts the details of specific SQL dialects 10 | // for mig's few SQL specific statements 11 | type sqlDialect interface { 12 | createVersionTableSQL() string // sql string to create the mig_migrations table 13 | insertVersionSQL() string // sql string to insert the initial version table row 14 | versionQuery(db *sql.DB) (*sql.Rows, error) 15 | } 16 | 17 | var dialect sqlDialect = &postgresDialect{} 18 | 19 | func getDialect() sqlDialect { 20 | return dialect 21 | } 22 | 23 | // SetDialect sets the current driver dialect for all future calls 24 | // to the library. 
// setDialect replaces the package-level dialect with the implementation
// matching the driver name d.
//
// "postgres" and "mysql" are supported. Any other name except "sqlite3"
// leaves the current dialect untouched and returns an "unknown dialect" error.
//
// NOTE(review): for "sqlite3" this prints a message and terminates the whole
// process with os.Exit(1) instead of returning an error. That is surprising
// for a library function; consider returning an error here (and dropping the
// then-unused "os" import) once callers are checked.
func setDialect(d string) error {
	switch d {
	case "postgres":
		dialect = &postgresDialect{}
	case "mysql":
		dialect = &mySQLDialect{}
	case "sqlite3":
		fmt.Println("sqlite3 not supported")
		os.Exit(1)
		//dialect = &sqlite3Dialect{}
	default:
		return fmt.Errorf("%q: unknown dialect", d)
	}

	return nil
}
return `CREATE TABLE mig_migrations ( 98 | id INTEGER PRIMARY KEY AUTOINCREMENT, 99 | version_id INTEGER NOT NULL, 100 | is_applied INTEGER NOT NULL, 101 | tstamp TIMESTAMP DEFAULT (datetime('now')) 102 | );` 103 | } 104 | 105 | func (sqlite3Dialect) insertVersionSQL() string { 106 | return "INSERT INTO mig_migrations (version_id, is_applied) VALUES (?, ?);" 107 | } 108 | 109 | func (sqlite3Dialect) versionQuery(db *sql.DB) (*sql.Rows, error) { 110 | rows, err := db.Query("SELECT version_id, is_applied from mig_migrations ORDER BY id DESC") 111 | if err != nil { 112 | return nil, err 113 | } 114 | 115 | return rows, err 116 | } 117 | -------------------------------------------------------------------------------- /migrate.go: -------------------------------------------------------------------------------- 1 | package mig 2 | 3 | import ( 4 | "database/sql" 5 | "errors" 6 | "fmt" 7 | "io" 8 | "io/ioutil" 9 | "path/filepath" 10 | "sort" 11 | "time" 12 | ) 13 | 14 | var ( 15 | ErrNoCurrentVersion = errors.New("no current version found") 16 | ErrNoNextVersion = errors.New("no next version found") 17 | ) 18 | 19 | var Log io.Writer 20 | 21 | func init() { 22 | Log = ioutil.Discard 23 | } 24 | 25 | type errNoMigration struct{} 26 | 27 | func (e errNoMigration) Error() string { 28 | return "no migrations to execute" 29 | } 30 | 31 | // IsNoMigrationError returns true if the error type is of 32 | // errNoMigration, indicating that there is no migration to run 33 | func IsNoMigrationError(err error) bool { 34 | _, ok := err.(errNoMigration) 35 | if ok { 36 | return ok 37 | } 38 | 39 | return false 40 | } 41 | 42 | type migrations []*migration 43 | 44 | // helpers so we can use pkg sort 45 | func (m migrations) Len() int { return len(m) } 46 | func (m migrations) Swap(i, j int) { m[i], m[j] = m[j], m[i] } 47 | func (m migrations) Less(i, j int) bool { 48 | if m[i].version == m[j].version { 49 | panic(fmt.Sprintf("mig: duplicate version %v detected:\n%v\n%v", m[i].version, 
m[i].source, m[j].source)) 50 | } 51 | return m[i].version < m[j].version 52 | } 53 | 54 | func (m migrations) current(current int64) (*migration, error) { 55 | for i, migration := range m { 56 | if migration.version == current { 57 | return m[i], nil 58 | } 59 | } 60 | 61 | return nil, ErrNoCurrentVersion 62 | } 63 | 64 | func (m migrations) next(current int64) (*migration, error) { 65 | for i, migration := range m { 66 | if migration.version > current { 67 | return m[i], nil 68 | } 69 | } 70 | 71 | return nil, ErrNoNextVersion 72 | } 73 | 74 | func (m migrations) last() (*migration, error) { 75 | if len(m) == 0 { 76 | return nil, ErrNoNextVersion 77 | } 78 | 79 | return m[len(m)-1], nil 80 | } 81 | 82 | func (m migrations) String() string { 83 | str := "" 84 | for _, migration := range m { 85 | str += fmt.Sprintln(migration) 86 | } 87 | return str 88 | } 89 | 90 | // collect all the valid looking migration scripts in the migrations folder, 91 | // and order them by version. 92 | func collectMigrations(dirpath string, current, target int64) (migrations, error) { 93 | var migrations migrations 94 | 95 | // extract the numeric component of each migration, 96 | // filter out any uninteresting files, 97 | // and ensure we only have one file per migration version. 98 | files, err := filepath.Glob(dirpath + "/*.sql") 99 | if err != nil { 100 | return nil, err 101 | } 102 | 103 | for _, file := range files { 104 | v, err := numericComponent(file) 105 | if err != nil { 106 | return nil, err 107 | } 108 | if versionFilter(v, current, target) { 109 | migration := &migration{version: v, next: -1, previous: -1, source: file} 110 | migrations = append(migrations, migration) 111 | } 112 | } 113 | 114 | migrations = sortAndConnectMigrations(migrations) 115 | 116 | return migrations, nil 117 | } 118 | 119 | // sortAndConnectMigrations sorts the migrations based on the version numbers 120 | // and creates a linked list between each migration. 
121 | func sortAndConnectMigrations(migrations migrations) migrations { 122 | // Sort the migrations based on version 123 | sort.Sort(migrations) 124 | 125 | // now that we're sorted in the appropriate direction, 126 | // populate next and previous for each migration 127 | for i, m := range migrations { 128 | prev := int64(-1) 129 | if i > 0 { 130 | prev = migrations[i-1].version 131 | migrations[i-1].next = m.version 132 | } 133 | migrations[i].previous = prev 134 | } 135 | 136 | return migrations 137 | } 138 | 139 | // versionFilter returns true if v is within the current version and target 140 | // version range. 141 | func versionFilter(v, current, target int64) bool { 142 | if target > current { 143 | return v > current && v <= target 144 | } 145 | 146 | if target < current { 147 | return v <= current && v > target 148 | } 149 | 150 | return false 151 | } 152 | 153 | // Create the mig_migrations table 154 | // and insert the initial 0 value into it 155 | func createVersionTable(db *sql.DB) error { 156 | txn, err := db.Begin() 157 | if err != nil { 158 | return err 159 | } 160 | 161 | d := getDialect() 162 | 163 | if _, err := txn.Exec(d.createVersionTableSQL()); err != nil { 164 | txn.Rollback() 165 | return err 166 | } 167 | 168 | version := 0 169 | applied := true 170 | if _, err := txn.Exec(d.insertVersionSQL(), version, applied); err != nil { 171 | txn.Rollback() 172 | return err 173 | } 174 | 175 | return txn.Commit() 176 | } 177 | 178 | // getVersion retrieves the current version for this database. 179 | // Create and initialize the database migration table if it doesn't exist. 180 | func getVersion(db *sql.DB) (int64, error) { 181 | rows, err := getDialect().versionQuery(db) 182 | if err != nil { 183 | return 0, createVersionTable(db) 184 | } 185 | defer rows.Close() 186 | 187 | // The most recent record for each migration specifies 188 | // whether it has been applied or rolled back. 
189 | // The first version we find that has been applied is the current version. 190 | 191 | toSkip := make([]int64, 0) 192 | 193 | for rows.Next() { 194 | var row migrationRecord 195 | if err = rows.Scan(&row.versionId, &row.isApplied); err != nil { 196 | return 0, fmt.Errorf("error scanning rows: %s", err) 197 | } 198 | 199 | // have we already marked this version to be skipped? 200 | skip := false 201 | for _, v := range toSkip { 202 | if v == row.versionId { 203 | skip = true 204 | break 205 | } 206 | } 207 | 208 | if skip { 209 | continue 210 | } 211 | 212 | // if version has been applied we're done 213 | if row.isApplied { 214 | return row.versionId, nil 215 | } 216 | 217 | // latest version of migration has not been applied. 218 | toSkip = append(toSkip, row.versionId) 219 | } 220 | 221 | panic("unreachable") 222 | } 223 | 224 | func getMigrationStatus(db *sql.DB, version int64) string { 225 | var row migrationRecord 226 | q := fmt.Sprintf("SELECT tstamp, is_applied FROM mig_migrations WHERE version_id=%d ORDER BY tstamp DESC LIMIT 1", version) 227 | e := db.QueryRow(q).Scan(&row.tstamp, &row.isApplied) 228 | 229 | if e != nil && e != sql.ErrNoRows { 230 | panic(e) 231 | } 232 | 233 | var appliedAt string 234 | 235 | if row.isApplied { 236 | appliedAt = row.tstamp.Format(time.ANSIC) 237 | } else { 238 | appliedAt = "Pending" 239 | } 240 | 241 | return appliedAt 242 | } 243 | -------------------------------------------------------------------------------- /migrate_test.go: -------------------------------------------------------------------------------- 1 | package mig 2 | 3 | import ( 4 | "testing" 5 | ) 6 | 7 | func newMigration(v int64, src string) *migration { 8 | return &migration{version: v, previous: -1, next: -1, source: src} 9 | } 10 | 11 | func TestMigrationSort(t *testing.T) { 12 | 13 | ms := migrations{} 14 | 15 | // insert in any order 16 | ms = append(ms, newMigration(20120000, "test")) 17 | ms = append(ms, newMigration(20128000, "test")) 18 | ms 
= append(ms, newMigration(20129000, "test")) 19 | ms = append(ms, newMigration(20127000, "test")) 20 | 21 | ms = sortAndConnectMigrations(ms) 22 | 23 | sorted := []int64{20120000, 20127000, 20128000, 20129000} 24 | 25 | validateMigrationSort(t, ms, sorted) 26 | } 27 | 28 | func validateMigrationSort(t *testing.T, ms migrations, sorted []int64) { 29 | 30 | for i, m := range ms { 31 | if sorted[i] != m.version { 32 | t.Error("incorrect sorted version") 33 | } 34 | 35 | var next, prev int64 36 | 37 | if i == 0 { 38 | prev = -1 39 | next = ms[i+1].version 40 | } else if i == len(ms)-1 { 41 | prev = ms[i-1].version 42 | next = -1 43 | } else { 44 | prev = ms[i-1].version 45 | next = ms[i+1].version 46 | } 47 | 48 | if m.next != next { 49 | t.Errorf("mismatched next. v: %v, got %v, wanted %v\n", m, m.next, next) 50 | } 51 | 52 | if m.previous != prev { 53 | t.Errorf("mismatched previous v: %v, got %v, wanted %v\n", m, m.previous, prev) 54 | } 55 | } 56 | 57 | t.Log(ms) 58 | } 59 | -------------------------------------------------------------------------------- /migration.go: -------------------------------------------------------------------------------- 1 | package mig 2 | 3 | import ( 4 | "bufio" 5 | "bytes" 6 | "database/sql" 7 | "errors" 8 | "fmt" 9 | "io" 10 | "os" 11 | "path/filepath" 12 | "strconv" 13 | "strings" 14 | "text/template" 15 | "time" 16 | ) 17 | 18 | type migrationRecord struct { 19 | versionId int64 20 | tstamp time.Time 21 | isApplied bool // was this a result of up() or down() 22 | } 23 | 24 | type migration struct { 25 | version int64 26 | next int64 // next version, or -1 if none 27 | previous int64 // previous version, -1 if none 28 | source string // path to .sql script 29 | } 30 | 31 | const sqlCmdPrefix = "-- +mig " 32 | 33 | var migrationTemplate = template.Must(template.New("mig.sql-migration").Parse(`-- +mig Up 34 | 35 | -- +mig Down 36 | 37 | `)) 38 | 39 | func (m *migration) String() string { 40 | return fmt.Sprintf(m.source) 41 | } 42 | 
43 | func (m *migration) up(db *sql.DB) (string, error) { 44 | return m.run(db, true) 45 | } 46 | 47 | func (m *migration) down(db *sql.DB) (string, error) { 48 | return m.run(db, false) 49 | } 50 | 51 | func (m *migration) run(db *sql.DB, direction bool) (name string, err error) { 52 | if err := runMigration(db, m.source, m.version, direction); err != nil { 53 | return "", err 54 | } 55 | 56 | return filepath.Base(m.source), nil 57 | } 58 | 59 | // look for migration scripts with names in the form: 60 | // XXX_descriptivename.sql 61 | // where XXX specifies the version number 62 | func numericComponent(name string) (int64, error) { 63 | base := filepath.Base(name) 64 | 65 | if ext := filepath.Ext(base); ext != ".sql" { 66 | return 0, errors.New("not a recognized migration file type") 67 | } 68 | 69 | idx := strings.Index(base, "_") 70 | if idx < 0 { 71 | return 0, errors.New("no separator found") 72 | } 73 | 74 | n, e := strconv.ParseInt(base[:idx], 10, 64) 75 | if e == nil && n <= 0 { 76 | return 0, errors.New("migration IDs must be greater than zero") 77 | } 78 | 79 | return n, e 80 | } 81 | 82 | // Update the version table for the given migration, 83 | // and finalize the transaction. 84 | func finalizeMigration(tx *sql.Tx, direction bool, v int64) error { 85 | stmt := getDialect().insertVersionSQL() 86 | if _, err := tx.Exec(stmt, v, direction); err != nil { 87 | tx.Rollback() 88 | return err 89 | } 90 | 91 | return tx.Commit() 92 | } 93 | 94 | // Checks the line to see if the line has a statement-ending semicolon 95 | // or if the line contains a double-dash comment. 
96 | func endsWithSemicolon(line string) bool { 97 | prev := "" 98 | scanner := bufio.NewScanner(strings.NewReader(line)) 99 | scanner.Split(bufio.ScanWords) 100 | 101 | for scanner.Scan() { 102 | word := scanner.Text() 103 | if strings.HasPrefix(word, "--") { 104 | break 105 | } 106 | prev = word 107 | } 108 | 109 | return strings.HasSuffix(prev, ";") 110 | } 111 | 112 | // Split the given sql script into individual statements. 113 | // 114 | // The base case is to simply split on semicolons, as these 115 | // naturally terminate a statement. 116 | // 117 | // However, more complex cases like pl/pgsql can have semicolons 118 | // within a statement. For these cases, we provide the explicit annotations 119 | // 'StatementBegin' and 'StatementEnd' to allow the script to 120 | // tell us to ignore semicolons. 121 | func splitSQLStatements(r io.Reader, direction bool) ([]string, error) { 122 | var err error 123 | var stmts []string 124 | var buf bytes.Buffer 125 | scanner := bufio.NewScanner(r) 126 | 127 | // track the count of each section 128 | // so we can diagnose scripts with no annotations 129 | upSections := 0 130 | downSections := 0 131 | 132 | statementEnded := false 133 | ignoreSemicolons := false 134 | directionIsActive := false 135 | 136 | for scanner.Scan() { 137 | 138 | line := scanner.Text() 139 | 140 | // handle any mig-specific commands 141 | if strings.HasPrefix(line, sqlCmdPrefix) { 142 | cmd := strings.TrimSpace(line[len(sqlCmdPrefix):]) 143 | switch cmd { 144 | case "Up": 145 | directionIsActive = (direction == true) 146 | upSections++ 147 | break 148 | 149 | case "Down": 150 | directionIsActive = (direction == false) 151 | downSections++ 152 | break 153 | 154 | case "StatementBegin": 155 | if directionIsActive { 156 | ignoreSemicolons = true 157 | } 158 | break 159 | 160 | case "StatementEnd": 161 | if directionIsActive { 162 | statementEnded = (ignoreSemicolons == true) 163 | ignoreSemicolons = false 164 | } 165 | break 166 | } 167 | } 168 | 169 
| if !directionIsActive { 170 | continue 171 | } 172 | 173 | if _, err := buf.WriteString(line + "\n"); err != nil { 174 | panic(fmt.Sprintf("io err: %v", err)) 175 | } 176 | 177 | // Wrap up the two supported cases: 1) basic with semicolon; 2) psql statement 178 | // Lines that end with semicolon that are in a statement block 179 | // do not conclude statement. 180 | if (!ignoreSemicolons && endsWithSemicolon(line)) || statementEnded { 181 | statementEnded = false 182 | stmts = append(stmts, buf.String()) 183 | buf.Reset() 184 | } 185 | } 186 | 187 | if err := scanner.Err(); err != nil { 188 | return stmts, fmt.Errorf("error reading migration: %v", err) 189 | } 190 | 191 | // diagnose likely migration script errors 192 | if ignoreSemicolons { 193 | return stmts, errors.New("saw '-- +mig StatementBegin' with no matching '-- +mig StatementEnd'") 194 | } 195 | 196 | if bufferRemaining := strings.TrimSpace(buf.String()); len(bufferRemaining) > 0 { 197 | return stmts, fmt.Errorf("unexpected unfinished SQL query: %s. Missing a semicolon?", bufferRemaining) 198 | } 199 | 200 | if upSections == 0 && downSections == 0 { 201 | return stmts, fmt.Errorf(`no up/down annotations found, so no statements were executed`) 202 | } 203 | 204 | return stmts, err 205 | } 206 | 207 | // runMigration runs a migration specified in raw SQL. 208 | // 209 | // Sections of the script can be annotated with a special comment, 210 | // starting with "-- +mig" to specify whether the section should 211 | // be applied during an Up or Down migration 212 | // 213 | // All statements following an Up or Down directive are grouped together 214 | // until another direction directive is found. 
func runMigration(db *sql.DB, scriptFile string, v int64, direction bool) error {
	// Read and split the script before touching the database, so an
	// unreadable or malformed script cannot leave a transaction open.
	f, err := os.Open(scriptFile)
	if err != nil {
		return fmt.Errorf("cannot open migration file %s: %v", scriptFile, err)
	}
	// The file is only read from; its Close error carries no information.
	defer f.Close()

	stmts, err := splitSQLStatements(f, direction)
	if err != nil {
		return fmt.Errorf("error splitting migration %s: %v", filepath.Base(scriptFile), err)
	}

	tx, err := db.Begin()
	if err != nil {
		return fmt.Errorf("cannot begin transaction for migration %s: %v", filepath.Base(scriptFile), err)
	}

	// Execute each statement of the selected direction in the transaction.
	// Commits the transaction if successfully applied each statement and
	// records the version into the version table or returns an error and
	// rolls back the transaction.
	for _, query := range stmts {
		if _, err = tx.Exec(query); err != nil {
			// Best effort: the Exec failure is the error worth reporting.
			tx.Rollback()
			return fmt.Errorf("error executing migration %s: %v", filepath.Base(scriptFile), err)
		}
	}

	if err = finalizeMigration(tx, direction, v); err != nil {
		return fmt.Errorf("error committing migration %s: %v", filepath.Base(scriptFile), err)
	}

	return nil
}
249 | -------------------------------------------------------------------------------- /migration_sql.go: -------------------------------------------------------------------------------- 1 | package mig 2 | -------------------------------------------------------------------------------- /migration_test.go: -------------------------------------------------------------------------------- 1 | package mig 2 | 3 | import ( 4 | "strings" 5 | "testing" 6 | ) 7 | 8 | func TestSemicolons(t *testing.T) { 9 | 10 | type testData struct { 11 | line string 12 | result bool 13 | } 14 | 15 | tests := []testData{ 16 | { 17 | line: "END;", 18 | result: true, 19 | }, 20 | { 21 | line: "END; -- comment", 22 | result: true, 23 | }, 24 | { 25 | line: "END ; -- comment", 26
| result: true, 27 | }, 28 | { 29 | line: "END -- comment", 30 | result: false, 31 | }, 32 | { 33 | line: "END -- comment ;", 34 | result: false, 35 | }, 36 | { 37 | line: "END \" ; \" -- comment", 38 | result: false, 39 | }, 40 | } 41 | 42 | for _, test := range tests { 43 | r := endsWithSemicolon(test.line) 44 | if r != test.result { 45 | t.Errorf("incorrect semicolon. got %v, want %v", r, test.result) 46 | } 47 | } 48 | } 49 | 50 | func TestSplitStatements(t *testing.T) { 51 | 52 | type testData struct { 53 | sql string 54 | direction bool 55 | count int 56 | } 57 | 58 | tests := []testData{ 59 | { 60 | sql: functxt, 61 | direction: true, 62 | count: 2, 63 | }, 64 | { 65 | sql: functxt, 66 | direction: false, 67 | count: 2, 68 | }, 69 | { 70 | sql: multitxt, 71 | direction: true, 72 | count: 2, 73 | }, 74 | { 75 | sql: multitxt, 76 | direction: false, 77 | count: 2, 78 | }, 79 | } 80 | 81 | for _, test := range tests { 82 | stmts, err := splitSQLStatements(strings.NewReader(test.sql), test.direction) 83 | if err != nil { 84 | t.Error(err) 85 | } 86 | if len(stmts) != test.count { 87 | t.Errorf("incorrect number of stmts. 
got %v, want %v", len(stmts), test.count) 88 | } 89 | } 90 | } 91 | 92 | var functxt = `-- +mig Up 93 | CREATE TABLE IF NOT EXISTS histories ( 94 | id BIGSERIAL PRIMARY KEY, 95 | current_value varchar(2000) NOT NULL, 96 | created_at timestamp with time zone NOT NULL 97 | ); 98 | 99 | -- +mig StatementBegin 100 | CREATE OR REPLACE FUNCTION histories_partition_creation( DATE, DATE ) 101 | returns void AS $$ 102 | DECLARE 103 | create_query text; 104 | BEGIN 105 | FOR create_query IN SELECT 106 | 'CREATE TABLE IF NOT EXISTS histories_' 107 | || TO_CHAR( d, 'YYYY_MM' ) 108 | || ' ( CHECK( created_at >= timestamp ''' 109 | || TO_CHAR( d, 'YYYY-MM-DD 00:00:00' ) 110 | || ''' AND created_at < timestamp ''' 111 | || TO_CHAR( d + INTERVAL '1 month', 'YYYY-MM-DD 00:00:00' ) 112 | || ''' ) ) inherits ( histories );' 113 | FROM generate_series( $1, $2, '1 month' ) AS d 114 | LOOP 115 | EXECUTE create_query; 116 | END LOOP; -- LOOP END 117 | END; -- FUNCTION END 118 | $$ 119 | language plpgsql; 120 | -- +mig StatementEnd 121 | 122 | -- +mig Down 123 | drop function histories_partition_creation(DATE, DATE); 124 | drop TABLE histories; 125 | ` 126 | 127 | // test multiple up/down transitions in a single script 128 | var multitxt = `-- +mig Up 129 | CREATE TABLE post ( 130 | id int NOT NULL, 131 | title text, 132 | body text, 133 | PRIMARY KEY(id) 134 | ); 135 | 136 | -- +mig Down 137 | DROP TABLE post; 138 | 139 | -- +mig Up 140 | CREATE TABLE fancier_post ( 141 | id int NOT NULL, 142 | title text, 143 | body text, 144 | created_on timestamp without time zone, 145 | PRIMARY KEY(id) 146 | ); 147 | 148 | -- +mig Down 149 | DROP TABLE fancier_post; 150 | ` 151 | -------------------------------------------------------------------------------- /util.go: -------------------------------------------------------------------------------- 1 | package mig 2 | 3 | import ( 4 | "io" 5 | "os" 6 | "text/template" 7 | ) 8 | 9 | // common routines 10 | 11 | func writeTemplateToFile(path string, t 
*template.Template, data interface{}) (string, error) { 12 | f, e := os.Create(path) 13 | if e != nil { 14 | return "", e 15 | } 16 | defer f.Close() 17 | 18 | e = t.Execute(f, data) 19 | if e != nil { 20 | return "", e 21 | } 22 | 23 | return f.Name(), nil 24 | } 25 | 26 | func copyFile(dst, src string) (int64, error) { 27 | sf, err := os.Open(src) 28 | if err != nil { 29 | return 0, err 30 | } 31 | defer sf.Close() 32 | 33 | df, err := os.Create(dst) 34 | if err != nil { 35 | return 0, err 36 | } 37 | defer df.Close() 38 | 39 | return io.Copy(df, sf) 40 | } 41 | -------------------------------------------------------------------------------- /wrappers.go: -------------------------------------------------------------------------------- 1 | package mig 2 | 3 | import ( 4 | "database/sql" 5 | "fmt" 6 | "math" 7 | "path/filepath" 8 | "time" 9 | 10 | // mysql driver 11 | _ "github.com/go-sql-driver/mysql" 12 | // postgres driver 13 | _ "github.com/lib/pq" 14 | ) 15 | 16 | // Create a templated migration file in dir 17 | func Create(name, dir string) (string, error) { 18 | timestamp := time.Now().Format("20060102150405") 19 | filename := fmt.Sprintf("%v_%v.sql", timestamp, name) 20 | 21 | fpath := filepath.Join(dir, filename) 22 | tmpl := migrationTemplate 23 | 24 | path, err := writeTemplateToFile(fpath, tmpl, timestamp) 25 | return path, err 26 | } 27 | 28 | // Down rolls back the version by one 29 | func Down(driver, conn, dir string) (name string, err error) { 30 | db, err := sql.Open(driver, conn) 31 | if err != nil { 32 | return "", err 33 | } 34 | 35 | err = setDialect(driver) 36 | if err != nil { 37 | return "", err 38 | } 39 | 40 | return DownDB(db, dir) 41 | } 42 | 43 | // DownDB rolls back the version by one 44 | // Expects SetDialect to be called beforehand. 
45 | func DownDB(db *sql.DB, dir string) (name string, err error) { 46 | currentVersion, err := getVersion(db) 47 | if err != nil { 48 | return "", err 49 | } 50 | 51 | migrations, err := collectMigrations(dir, 0, math.MaxInt64) 52 | if err != nil { 53 | return "", err 54 | } 55 | 56 | current, err := migrations.current(currentVersion) 57 | if err != nil { 58 | return "", errNoMigration{} 59 | } 60 | 61 | return current.down(db) 62 | } 63 | 64 | // DownAll rolls back all migrations. 65 | // Logs success messages to global writer variable Log. 66 | func DownAll(driver, conn, dir string) (int, error) { 67 | db, err := sql.Open(driver, conn) 68 | if err != nil { 69 | return 0, err 70 | } 71 | 72 | err = setDialect(driver) 73 | if err != nil { 74 | return 0, err 75 | } 76 | 77 | return DownAllDB(db, dir) 78 | } 79 | 80 | // DownAllDB rolls back all migrations. 81 | // Logs success messages to global writer variable Log. 82 | // Expects SetDialect to be called beforehand. 83 | func DownAllDB(db *sql.DB, dir string) (int, error) { 84 | count := 0 85 | 86 | migrations, err := collectMigrations(dir, 0, math.MaxInt64) 87 | if err != nil { 88 | return count, err 89 | } 90 | 91 | for { 92 | currentVersion, err := getVersion(db) 93 | if err != nil { 94 | return count, err 95 | } 96 | 97 | current, err := migrations.current(currentVersion) 98 | // no migrations left to run 99 | if err != nil { 100 | return count, nil 101 | } 102 | 103 | name, err := current.down(db) 104 | if err != nil { 105 | return count, err 106 | } 107 | 108 | Log.Write([]byte(fmt.Sprintf("Success %v\n", name))) 109 | count++ 110 | } 111 | } 112 | 113 | // Up migrates to the highest version available 114 | func Up(driver, conn, dir string) (int, error) { 115 | db, err := sql.Open(driver, conn) 116 | if err != nil { 117 | return 0, err 118 | } 119 | 120 | err = setDialect(driver) 121 | if err != nil { 122 | return 0, err 123 | } 124 | 125 | return UpDB(db, dir) 126 | } 127 | 128 | // UpDB migrates to the 
highest version available 129 | // Expects SetDialect to be called beforehand. 130 | func UpDB(db *sql.DB, dir string) (int, error) { 131 | count := 0 132 | 133 | migrations, err := collectMigrations(dir, 0, math.MaxInt64) 134 | if err != nil { 135 | return count, err 136 | } 137 | 138 | for { 139 | currentVersion, err := getVersion(db) 140 | if err != nil { 141 | return count, err 142 | } 143 | 144 | next, err := migrations.next(currentVersion) 145 | // no migrations left to run 146 | if err != nil { 147 | return count, nil 148 | } 149 | 150 | name, err := next.up(db) 151 | if err != nil { 152 | return count, err 153 | } 154 | 155 | Log.Write([]byte(fmt.Sprintf("Success %v\n", name))) 156 | count++ 157 | } 158 | } 159 | 160 | // UpOne migrates one version 161 | func UpOne(driver, conn, dir string) (name string, err error) { 162 | db, err := sql.Open(driver, conn) 163 | if err != nil { 164 | return "", err 165 | } 166 | 167 | err = setDialect(driver) 168 | if err != nil { 169 | return "", err 170 | } 171 | 172 | return UpOneDB(db, dir) 173 | } 174 | 175 | // UpOneDB migrates one version 176 | // Expects SetDialect to be called beforehand. 177 | func UpOneDB(db *sql.DB, dir string) (name string, err error) { 178 | currentVersion, err := getVersion(db) 179 | if err != nil { 180 | return "", err 181 | } 182 | 183 | migrations, err := collectMigrations(dir, 0, math.MaxInt64) 184 | if err != nil { 185 | return "", err 186 | } 187 | 188 | next, err := migrations.next(currentVersion) 189 | if err != nil { 190 | return "", errNoMigration{} 191 | } 192 | 193 | return next.up(db) 194 | } 195 | 196 | // Redo re-runs the latest migration. 197 | func Redo(driver, conn, dir string) (string, error) { 198 | db, err := sql.Open(driver, conn) 199 | if err != nil { 200 | return "", err 201 | } 202 | 203 | err = setDialect(driver) 204 | if err != nil { 205 | return "", err 206 | } 207 | 208 | return RedoDB(db, dir) 209 | } 210 | 211 | // RedoDB re-runs the latest migration. 
212 | // Expects SetDialect to be called beforehand. 213 | func RedoDB(db *sql.DB, dir string) (string, error) { 214 | currentVersion, err := getVersion(db) 215 | if err != nil { 216 | return "", err 217 | } 218 | 219 | migrations, err := collectMigrations(dir, 0, math.MaxInt64) 220 | if err != nil { 221 | return "", err 222 | } 223 | 224 | current, err := migrations.current(currentVersion) 225 | if err != nil { 226 | return "", errNoMigration{} 227 | } 228 | 229 | if _, err := current.down(db); err != nil { 230 | return "", err 231 | } 232 | 233 | return current.up(db) 234 | } 235 | 236 | type migrationStatus struct { 237 | Applied string 238 | Name string 239 | } 240 | 241 | // status is a slice of migrationStatus 242 | type status []migrationStatus 243 | 244 | // Status returns the status of each migration 245 | func Status(driver, conn, dir string) (status, error) { 246 | s := status{} 247 | 248 | db, err := sql.Open(driver, conn) 249 | if err != nil { 250 | return s, err 251 | } 252 | 253 | err = setDialect(driver) 254 | if err != nil { 255 | return s, err 256 | } 257 | 258 | return StatusDB(db, dir) 259 | } 260 | 261 | // StatusDB returns the status of each migration 262 | // Expects SetDialect to be called beforehand 263 | func StatusDB(db *sql.DB, dir string) (status, error) { 264 | s := status{} 265 | 266 | migrations, err := collectMigrations(dir, 0, math.MaxInt64) 267 | if err != nil { 268 | return s, err 269 | } 270 | 271 | // must ensure that the version table exists if we're running on a pristine DB 272 | if _, err := getVersion(db); err != nil { 273 | return s, err 274 | } 275 | 276 | for _, migration := range migrations { 277 | s = append(s, migrationStatus{ 278 | Applied: getMigrationStatus(db, migration.version), 279 | Name: filepath.Base(migration.source), 280 | }) 281 | } 282 | 283 | return s, nil 284 | } 285 | 286 | // Version returns the current migration version 287 | func Version(driver, conn string) (int64, error) { 288 | db, err := 
sql.Open(driver, conn) 289 | if err != nil { 290 | return 0, err 291 | } 292 | 293 | err = setDialect(driver) 294 | if err != nil { 295 | return 0, err 296 | } 297 | 298 | return VersionDB(db) 299 | } 300 | 301 | // VersionDB returns the current migration version 302 | // Expects SetDialect to be called beforehand 303 | func VersionDB(db *sql.DB) (int64, error) { 304 | return getVersion(db) 305 | } 306 | --------------------------------------------------------------------------------