├── .gitignore ├── .travis.yml ├── LICENSE.txt ├── README.md ├── db.go ├── gomigrate.go ├── gomigrate_test.go ├── migration.go ├── test_migrations ├── test1_mysql │ ├── 1_test_down.sql │ └── 1_test_up.sql ├── test1_pg │ ├── 1_test_down.sql │ ├── 1_test_up.sql │ ├── 2_create_index_function_if_not_exists_down.sql │ ├── 2_create_index_function_if_not_exists_up.sql │ ├── 3_multiple_inserts_down.sql │ ├── 3_multiple_inserts_up.sql │ ├── 4_dash-test_down.sql │ └── 4_dash-test_up.sql └── test1_sqlite3 │ ├── 1_test_down.sql │ └── 1_test_up.sql └── utils.go /.gitignore: -------------------------------------------------------------------------------- 1 | *.test 2 | 3 | /.idea 4 | /gomigrate.iml 5 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: go 2 | go: 1.3 3 | addons: 4 | postgresql: "9.3" 5 | before_script: 6 | - psql -c 'create database gomigrate;' -U postgres 7 | - mysql -uroot -e "CREATE USER 'gomigrate'@'localhost' IDENTIFIED BY 'password';" 8 | - mysql -uroot -e "GRANT ALL PRIVILEGES ON * . 
* TO 'gomigrate'@'localhost';" 9 | - mysql -uroot -e "CREATE DATABASE gomigrate;" 10 | - go get github.com/lib/pq 11 | - go get github.com/go-sql-driver/mysql 12 | - go get github.com/mattn/go-sqlite3 13 | script: 14 | - DB=pg go test 15 | - DB=mysql go test 16 | - DB=sqlite3 go test 17 | -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | Copyright (c) 2014 David Huie 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining 4 | a copy of this software and associated documentation files (the 5 | "Software"), to deal in the Software without restriction, including 6 | without limitation the rights to use, copy, modify, merge, publish, 7 | distribute, sublicense, and/or sell copies of the Software, and to 8 | permit persons to whom the Software is furnished to do so, subject to 9 | the following conditions: 10 | 11 | The above copyright notice and this permission notice shall be 12 | included in all copies or substantial portions of the Software. 13 | 14 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, 15 | EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF 16 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND 17 | NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE 18 | LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION 19 | OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION 20 | WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 21 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # gomigrate 2 | 3 | [![Build Status](https://travis-ci.org/DavidHuie/gomigrate.svg?branch=master)](https://travis-ci.org/DavidHuie/gomigrate) 4 | 5 | A SQL database migration toolkit in Golang. 
6 | 7 | ## Supported databases 8 | 9 | - PostgreSQL 10 | - MariaDB 11 | - MySQL 12 | - Sqlite3 13 | 14 | ## Usage 15 | 16 | First import the package: 17 | 18 | ```go 19 | import "github.com/DavidHuie/gomigrate" 20 | ``` 21 | 22 | Given a `database/sql` connection to a PostgreSQL database, `db`, 23 | and the path to a directory of migration files, create a migrator: 24 | 25 | ```go 26 | migrator, _ := gomigrate.NewMigrator(db, gomigrate.Postgres{}, "./migrations") 27 | ``` 28 | 29 | You may also specify a logger to use, such as logrus: 30 | 31 | ```go 32 | migrator, _ := gomigrate.NewMigratorWithLogger(db, gomigrate.Postgres{}, "./migrations", logrus.New()) 33 | ``` 34 | 35 | To migrate the database, run: 36 | 37 | ```go 38 | err := migrator.Migrate() 39 | ``` 40 | 41 | To roll back the last migration, run: 42 | 43 | ```go 44 | err := migrator.Rollback() 45 | ``` 46 | 47 | ## Migration files 48 | 49 | Migration files need to follow a standard format and must be present 50 | in the same directory. Given "up" and "down" steps for a migration, 51 | create a file for each by following this template: 52 | 53 | ``` 54 | {{ id }}_{{ name }}_{{ "up" or "down" }}.sql 55 | ``` 56 | 57 | For a given migration, the `id` and `name` fields must be the same in both files. 58 | The id field is an integer that corresponds to the order in which 59 | the migration should run relative to the other migrations. 60 | 61 | `id` must not be `0`: a migration with id `0` is rejected as invalid. 62 | 63 | ### Example 64 | 65 | If I'm trying to add a "users" table to the database, I would create 66 | the following two files: 67 | 68 | #### 1_add_users_table_up.sql 69 | 70 | ``` 71 | CREATE TABLE users(); 72 | ``` 73 | 74 | #### 1_add_users_table_down.sql 75 | ``` 76 | DROP TABLE users; 77 | ``` 78 | 79 | ## Copyright 80 | 81 | Copyright (c) 2014 David Huie. See LICENSE.txt for further details.
82 | -------------------------------------------------------------------------------- /db.go: -------------------------------------------------------------------------------- 1 | package gomigrate 2 | 3 | import "strings" 4 | 5 | type Migratable interface { 6 | SelectMigrationTableSql() string 7 | CreateMigrationTableSql() string 8 | GetMigrationSql() string 9 | MigrationLogInsertSql() string 10 | MigrationLogDeleteSql() string 11 | GetMigrationCommands(string) []string 12 | } 13 | 14 | // POSTGRES 15 | 16 | type Postgres struct{} 17 | 18 | func (p Postgres) SelectMigrationTableSql() string { 19 | return "SELECT tablename FROM pg_catalog.pg_tables WHERE tablename = $1" 20 | } 21 | 22 | func (p Postgres) CreateMigrationTableSql() string { 23 | return `CREATE TABLE gomigrate ( 24 | id SERIAL PRIMARY KEY, 25 | migration_id BIGINT UNIQUE NOT NULL 26 | )` 27 | } 28 | 29 | func (p Postgres) GetMigrationSql() string { 30 | return `SELECT migration_id FROM gomigrate WHERE migration_id = $1` 31 | } 32 | 33 | func (p Postgres) MigrationLogInsertSql() string { 34 | return "INSERT INTO gomigrate (migration_id) values ($1)" 35 | } 36 | 37 | func (p Postgres) MigrationLogDeleteSql() string { 38 | return "DELETE FROM gomigrate WHERE migration_id = $1" 39 | } 40 | 41 | func (p Postgres) GetMigrationCommands(sql string) []string { 42 | return []string{sql} 43 | } 44 | 45 | // MYSQL 46 | 47 | type Mysql struct{} 48 | 49 | func (m Mysql) SelectMigrationTableSql() string { 50 | return "SELECT table_name FROM information_schema.tables WHERE table_name = ? 
AND table_schema = (SELECT DATABASE())" 51 | } 52 | 53 | func (m Mysql) CreateMigrationTableSql() string { 54 | return `CREATE TABLE gomigrate ( 55 | id INT NOT NULL AUTO_INCREMENT, 56 | migration_id BIGINT NOT NULL UNIQUE, 57 | PRIMARY KEY (id) 58 | )` 59 | } 60 | 61 | func (m Mysql) GetMigrationSql() string { 62 | return `SELECT migration_id FROM gomigrate WHERE migration_id = ?` 63 | } 64 | 65 | func (m Mysql) MigrationLogInsertSql() string { 66 | return "INSERT INTO gomigrate (migration_id) values (?)" 67 | } 68 | 69 | func (m Mysql) MigrationLogDeleteSql() string { 70 | return "DELETE FROM gomigrate WHERE migration_id = ?" 71 | } 72 | 73 | func (m Mysql) GetMigrationCommands(sql string) []string { 74 | count := strings.Count(sql, ";") 75 | commands := strings.SplitN(string(sql), ";", count) 76 | return commands 77 | } 78 | 79 | // MARIADB 80 | 81 | type Mariadb struct { 82 | Mysql 83 | } 84 | 85 | // SQLITE3 86 | 87 | type Sqlite3 struct{} 88 | 89 | func (s Sqlite3) SelectMigrationTableSql() string { 90 | return "SELECT name FROM sqlite_master WHERE type = 'table' AND name = ?" 91 | } 92 | 93 | func (s Sqlite3) CreateMigrationTableSql() string { 94 | return `CREATE TABLE gomigrate ( 95 | id INTEGER PRIMARY KEY, 96 | migration_id INTEGER NOT NULL UNIQUE 97 | )` 98 | } 99 | 100 | func (s Sqlite3) GetMigrationSql() string { 101 | return "SELECT migration_id FROM gomigrate WHERE migration_id = ?" 102 | } 103 | 104 | func (s Sqlite3) MigrationLogInsertSql() string { 105 | return "INSERT INTO gomigrate (migration_id) values (?)" 106 | } 107 | 108 | func (s Sqlite3) MigrationLogDeleteSql() string { 109 | return "DELETE FROM gomigrate WHERE migration_id = ?" 
110 | } 111 | 112 | func (s Sqlite3) GetMigrationCommands(sql string) []string { 113 | return []string{sql} 114 | } 115 | 116 | // SqlServer 117 | 118 | type SqlServer struct{} 119 | 120 | func (s SqlServer) SelectMigrationTableSql() string { 121 | return "SELECT name FROM sys.objects WHERE object_id = object_id(?)" 122 | } 123 | 124 | func (s SqlServer) CreateMigrationTableSql() string { 125 | return `CREATE TABLE gomigrate ( 126 | id INT IDENTITY(1,1) PRIMARY KEY, 127 | migration_id BIGINT NOT NULL 128 | )` 129 | } 130 | 131 | func (s SqlServer) GetMigrationSql() string { 132 | return `SELECT migration_id FROM gomigrate WHERE migration_id = ?` 133 | } 134 | 135 | func (s SqlServer) MigrationLogInsertSql() string { 136 | return "INSERT INTO gomigrate (migration_id) values (?)" 137 | } 138 | 139 | func (s SqlServer) MigrationLogDeleteSql() string { 140 | return "DELETE FROM gomigrate WHERE migration_id = ?" 141 | } 142 | 143 | func (s SqlServer) GetMigrationCommands(sql string) []string { 144 | return []string{sql} 145 | } 146 | -------------------------------------------------------------------------------- /gomigrate.go: -------------------------------------------------------------------------------- 1 | // A simple database migrator for PostgreSQL. 
2 | 3 | package gomigrate 4 | 5 | import ( 6 | "database/sql" 7 | "errors" 8 | "io/ioutil" 9 | "log" 10 | "os" 11 | "path/filepath" 12 | "sort" 13 | ) 14 | 15 | type migrationType string 16 | 17 | const ( 18 | migrationTableName = "gomigrate" 19 | upMigration = migrationType("up") 20 | downMigration = migrationType("down") 21 | ) 22 | 23 | var ( 24 | InvalidMigrationFile = errors.New("Invalid migration file") 25 | InvalidMigrationPair = errors.New("Invalid pair of migration files") 26 | InvalidMigrationsPath = errors.New("Invalid migrations path") 27 | InvalidMigrationType = errors.New("Invalid migration type") 28 | NoActiveMigrations = errors.New("No active migrations to rollback") 29 | ) 30 | 31 | type Migrator struct { 32 | DB *sql.DB 33 | MigrationsPath string 34 | dbAdapter Migratable 35 | migrations map[uint64]*Migration 36 | logger Logger 37 | } 38 | 39 | type Logger interface { 40 | Print(v ...interface{}) 41 | Printf(format string, v ...interface{}) 42 | Println(v ...interface{}) 43 | Fatalf(format string, v ...interface{}) 44 | } 45 | 46 | // Returns true if the migration table already exists. 47 | func (m *Migrator) MigrationTableExists() (bool, error) { 48 | row := m.DB.QueryRow(m.dbAdapter.SelectMigrationTableSql(), migrationTableName) 49 | var tableName string 50 | err := row.Scan(&tableName) 51 | if err == sql.ErrNoRows { 52 | m.logger.Print("Migrations table not found") 53 | return false, nil 54 | } 55 | if err != nil { 56 | m.logger.Printf("Error checking for migration table: %v", err) 57 | return false, err 58 | } 59 | m.logger.Print("Migrations table found") 60 | return true, nil 61 | } 62 | 63 | // Creates the migrations table if it doesn't exist. 
64 | func (m *Migrator) CreateMigrationsTable() error { 65 | _, err := m.DB.Exec(m.dbAdapter.CreateMigrationTableSql()) 66 | if err != nil { 67 | m.logger.Fatalf("Error creating migrations table: %v", err) 68 | } 69 | 70 | m.logger.Printf("Created migrations table: %s", migrationTableName) 71 | 72 | return nil 73 | } 74 | 75 | // Returns a new migrator. 76 | func NewMigrator(db *sql.DB, adapter Migratable, migrationsPath string) (*Migrator, error) { 77 | return NewMigratorWithLogger(db, adapter, migrationsPath, log.New(os.Stderr, "[gomigrate] ", log.LstdFlags)) 78 | } 79 | 80 | // Returns a new migrator with the specified logger. 81 | func NewMigratorWithLogger(db *sql.DB, adapter Migratable, migrationsPath string, logger Logger) (*Migrator, error) { 82 | // Normalize the migrations path. 83 | path := []byte(migrationsPath) 84 | pathLength := len(path) 85 | if path[pathLength-1] != '/' { 86 | path = append(path, '/') 87 | } 88 | 89 | logger.Printf("Migrations path: %s", path) 90 | 91 | migrator := Migrator{ 92 | db, 93 | string(path), 94 | adapter, 95 | make(map[uint64]*Migration), 96 | logger, 97 | } 98 | 99 | // Create the migrations table if it doesn't exist. 100 | tableExists, err := migrator.MigrationTableExists() 101 | if err != nil { 102 | return nil, err 103 | } 104 | if !tableExists { 105 | if err := migrator.CreateMigrationsTable(); err != nil { 106 | return nil, err 107 | } 108 | } 109 | 110 | // Get all metadata from the database. 111 | if err := migrator.fetchMigrations(); err != nil { 112 | return nil, err 113 | } 114 | if err := migrator.getMigrationStatuses(); err != nil { 115 | return nil, err 116 | } 117 | 118 | return &migrator, nil 119 | } 120 | 121 | // Populates a migrator with a sorted list of migrations from the file system. 122 | func (m *Migrator) fetchMigrations() error { 123 | pathGlob := append([]byte(m.MigrationsPath), []byte("*")...) 
124 | 125 | matches, err := filepath.Glob(string(pathGlob)) 126 | if err != nil { 127 | m.logger.Fatalf("Error while globbing migrations: %v", err) 128 | } 129 | 130 | for _, match := range matches { 131 | num, migrationType, name, err := parseMigrationPath(match) 132 | if err != nil { 133 | m.logger.Printf("Invalid migration file found: %s", match) 134 | continue 135 | } 136 | 137 | m.logger.Printf("Migration file found: %s", match) 138 | 139 | migration, ok := m.migrations[num] 140 | if !ok { 141 | migration = &Migration{Id: num, Name: name, Status: Inactive} 142 | m.migrations[num] = migration 143 | } 144 | if migrationType == upMigration { 145 | migration.UpPath = match 146 | } else { 147 | migration.DownPath = match 148 | } 149 | } 150 | 151 | // Validate each migration. 152 | for _, migration := range m.migrations { 153 | if !migration.valid() { 154 | path := migration.UpPath 155 | if path == "" { 156 | path = migration.DownPath 157 | } 158 | m.logger.Printf("Invalid migration pair for path: %s", path) 159 | return InvalidMigrationPair 160 | } 161 | } 162 | 163 | m.logger.Printf("Migrations file pairs found: %v", len(m.migrations)) 164 | 165 | return nil 166 | } 167 | 168 | // Queries the migration table to determine the status of each 169 | // migration. 170 | func (m *Migrator) getMigrationStatuses() error { 171 | for _, migration := range m.migrations { 172 | row := m.DB.QueryRow(m.dbAdapter.GetMigrationSql(), migration.Id) 173 | var mid uint64 174 | err := row.Scan(&mid) 175 | if err == sql.ErrNoRows { 176 | continue 177 | } 178 | if err != nil { 179 | m.logger.Printf( 180 | "Error getting migration status for %s: %v", 181 | migration.Name, 182 | err, 183 | ) 184 | return err 185 | } 186 | migration.Status = Active 187 | } 188 | return nil 189 | } 190 | 191 | // Returns a sorted list of migration ids for a given status. -1 returns 192 | // all migrations. 193 | func (m *Migrator) Migrations(status int) []*Migration { 194 | // Sort all migration ids. 
195 | ids := make([]uint64, 0) 196 | for id, _ := range m.migrations { 197 | ids = append(ids, id) 198 | } 199 | sort.Sort(uint64slice(ids)) 200 | 201 | // Find ids for the given status. 202 | migrations := make([]*Migration, 0) 203 | for _, id := range ids { 204 | migration := m.migrations[id] 205 | if status == -1 || migration.Status == status { 206 | migrations = append(migrations, migration) 207 | } 208 | } 209 | return migrations 210 | } 211 | 212 | // Applies a single migration. 213 | func (m *Migrator) ApplyMigration(migration *Migration, mType migrationType) error { 214 | var path string 215 | if mType == upMigration { 216 | path = migration.UpPath 217 | } else if mType == downMigration { 218 | path = migration.DownPath 219 | } else { 220 | return InvalidMigrationType 221 | } 222 | 223 | m.logger.Printf("Applying migration: %s", path) 224 | 225 | sql, err := ioutil.ReadFile(path) 226 | if err != nil { 227 | m.logger.Printf("Error reading migration: %s", path) 228 | return err 229 | } 230 | transaction, err := m.DB.Begin() 231 | if err != nil { 232 | m.logger.Printf("Error opening transaction: %v", err) 233 | return err 234 | } 235 | 236 | // Certain adapters can not handle multiple sql commands in one file so we need the adapter to split up the command 237 | commands := m.dbAdapter.GetMigrationCommands(string(sql)) 238 | 239 | // Perform the migration. 
240 | for _, cmd := range commands { 241 | result, err := transaction.Exec(cmd) 242 | if err != nil { 243 | m.logger.Printf("Error executing migration: %v", err) 244 | if rollbackErr := transaction.Rollback(); rollbackErr != nil { 245 | m.logger.Printf("Error rolling back transaction: %v", rollbackErr) 246 | return rollbackErr 247 | } 248 | return err 249 | } 250 | if result != nil { 251 | if rowsAffected, err := result.RowsAffected(); err != nil { 252 | m.logger.Printf("Error getting rows affected: %v", err) 253 | if rollbackErr := transaction.Rollback(); rollbackErr != nil { 254 | m.logger.Printf("Error rolling back transaction: %v", rollbackErr) 255 | return rollbackErr 256 | } 257 | return err 258 | } else { 259 | m.logger.Printf("Rows affected: %v", rowsAffected) 260 | } 261 | } 262 | } 263 | 264 | // Log the event. 265 | if mType == upMigration { 266 | _, err = transaction.Exec( 267 | m.dbAdapter.MigrationLogInsertSql(), 268 | migration.Id, 269 | ) 270 | } else { 271 | _, err = transaction.Exec( 272 | m.dbAdapter.MigrationLogDeleteSql(), 273 | migration.Id, 274 | ) 275 | } 276 | if err != nil { 277 | m.logger.Printf("Error logging migration: %v", err) 278 | if rollbackErr := transaction.Rollback(); rollbackErr != nil { 279 | m.logger.Printf("Error rolling back transaction: %v", rollbackErr) 280 | return rollbackErr 281 | } 282 | return err 283 | } 284 | 285 | // Commit and update the struct status. 286 | if err := transaction.Commit(); err != nil { 287 | m.logger.Printf("Error commiting transaction: %v", err) 288 | return err 289 | } 290 | if mType == upMigration { 291 | migration.Status = Active 292 | } else { 293 | migration.Status = Inactive 294 | } 295 | 296 | return nil 297 | } 298 | 299 | // Applies all inactive migrations. 
300 | func (m *Migrator) Migrate() error { 301 | for _, migration := range m.Migrations(Inactive) { 302 | if err := m.ApplyMigration(migration, upMigration); err != nil { 303 | return err 304 | } 305 | } 306 | return nil 307 | } 308 | 309 | // Rolls back the last migration. 310 | func (m *Migrator) Rollback() error { 311 | return m.RollbackN(1) 312 | } 313 | 314 | // Rolls back N migrations. 315 | func (m *Migrator) RollbackN(n int) error { 316 | migrations := m.Migrations(Active) 317 | if len(migrations) == 0 { 318 | return nil 319 | } 320 | 321 | last_migration := len(migrations) - 1 - n 322 | 323 | for i := len(migrations) - 1; i != last_migration; i-- { 324 | if err := m.ApplyMigration(migrations[i], downMigration); err != nil { 325 | return err 326 | } 327 | } 328 | 329 | return nil 330 | } 331 | 332 | // Rolls back all migrations. 333 | func (m *Migrator) RollbackAll() error { 334 | migrations := m.Migrations(Active) 335 | return m.RollbackN(len(migrations)) 336 | } 337 | -------------------------------------------------------------------------------- /gomigrate_test.go: -------------------------------------------------------------------------------- 1 | package gomigrate 2 | 3 | import ( 4 | "database/sql" 5 | "fmt" 6 | "log" 7 | "os" 8 | "testing" 9 | 10 | _ "github.com/go-sql-driver/mysql" 11 | _ "github.com/lib/pq" 12 | _ "github.com/mattn/go-sqlite3" 13 | ) 14 | 15 | var ( 16 | db *sql.DB 17 | adapter Migratable 18 | dbType string 19 | ) 20 | 21 | func GetMigrator(test string) *Migrator { 22 | path := fmt.Sprintf("test_migrations/%s_%s", test, dbType) 23 | m, err := NewMigrator(db, adapter, path) 24 | if err != nil { 25 | panic(err) 26 | } 27 | return m 28 | } 29 | 30 | func TestNewMigrator(t *testing.T) { 31 | m := GetMigrator("test1") 32 | switch { 33 | case dbType == "pg" && len(m.migrations) != 4: 34 | t.Errorf("Invalid number of migrations detected") 35 | 36 | case dbType == "mysql" && len(m.migrations) != 1: 37 | t.Errorf("Invalid number of 
migrations detected") 38 | 39 | case dbType == "sqlite3" && len(m.migrations) != 1: 40 | t.Errorf("Invalid number of migrations detected") 41 | } 42 | 43 | migration := m.migrations[1] 44 | 45 | if migration.Name != "test" { 46 | t.Errorf("Invalid migration name detected: %s", migration.Name) 47 | } 48 | if migration.Id != 1 { 49 | t.Errorf("Invalid migration num detected: %d", migration.Id) 50 | } 51 | if migration.Status != Inactive { 52 | t.Errorf("Invalid migration num detected: %d", migration.Status) 53 | } 54 | 55 | cleanup() 56 | } 57 | 58 | func TestCreatingMigratorWhenTableExists(t *testing.T) { 59 | // Create the table and populate it with a row. 60 | _, err := db.Exec(adapter.CreateMigrationTableSql()) 61 | if err != nil { 62 | t.Error(err) 63 | } 64 | _, err = db.Exec(adapter.MigrationLogInsertSql(), 123) 65 | if err != nil { 66 | t.Error(err) 67 | } 68 | 69 | GetMigrator("test1") 70 | 71 | // Check that our row is still present. 72 | row := db.QueryRow("select migration_id from gomigrate") 73 | var id uint64 74 | err = row.Scan(&id) 75 | if err != nil { 76 | t.Error(err) 77 | } 78 | if id != 123 { 79 | t.Error("Invalid id found in database") 80 | } 81 | cleanup() 82 | } 83 | 84 | func TestMigrationAndRollback(t *testing.T) { 85 | m := GetMigrator("test1") 86 | 87 | if err := m.Migrate(); err != nil { 88 | t.Error(err) 89 | } 90 | 91 | // Ensure that the migration ran. 92 | row := db.QueryRow( 93 | adapter.SelectMigrationTableSql(), 94 | "test", 95 | ) 96 | var tableName string 97 | if err := row.Scan(&tableName); err != nil { 98 | t.Error(err) 99 | } 100 | if tableName != "test" { 101 | t.Errorf("Migration table not created") 102 | } 103 | // Ensure that the migrate status is correct. 
104 | row = db.QueryRow( 105 | adapter.GetMigrationSql(), 106 | 1, 107 | ) 108 | var status int 109 | if err := row.Scan(&status); err != nil { 110 | t.Error(err) 111 | } 112 | if status != Active || m.migrations[1].Status != Active { 113 | t.Error("Invalid status for migration") 114 | } 115 | if err := m.RollbackN(len(m.migrations)); err != nil { 116 | t.Error(err) 117 | } 118 | 119 | // Ensure that the down migration ran. 120 | row = db.QueryRow( 121 | adapter.SelectMigrationTableSql(), 122 | "test", 123 | ) 124 | err := row.Scan(&tableName) 125 | if err != nil && err != sql.ErrNoRows { 126 | t.Errorf("Migration table should be deleted: %v", err) 127 | } 128 | 129 | // Ensure that the migration log is missing. 130 | row = db.QueryRow( 131 | adapter.GetMigrationSql(), 132 | 1, 133 | ) 134 | if err := row.Scan(&status); err != nil && err != sql.ErrNoRows { 135 | t.Error(err) 136 | } 137 | if m.migrations[1].Status != Inactive { 138 | t.Errorf("Invalid status for migration, expected: %d, got: %v", Inactive, m.migrations[1].Status) 139 | } 140 | 141 | cleanup() 142 | } 143 | 144 | func cleanup() { 145 | _, err := db.Exec("drop table gomigrate") 146 | if err != nil { 147 | panic(err) 148 | } 149 | } 150 | 151 | func init() { 152 | var err error 153 | 154 | switch os.Getenv("DB") { 155 | case "mysql": 156 | dbType = "mysql" 157 | log.Print("Using mysql") 158 | adapter = Mariadb{} 159 | db, err = sql.Open("mysql", "gomigrate:password@/gomigrate") 160 | case "sqlite3": 161 | dbType = "sqlite3" 162 | log.Print("Using sqlite3") 163 | adapter = Sqlite3{} 164 | db, err = sql.Open("sqlite3", "file::memory:?cache=shared") 165 | default: 166 | dbType = "pg" 167 | log.Print("Using postgres") 168 | adapter = Postgres{} 169 | db, err = sql.Open("postgres", "host=localhost dbname=gomigrate sslmode=disable") 170 | } 171 | 172 | if err != nil { 173 | panic(err) 174 | } 175 | } 176 | -------------------------------------------------------------------------------- /migration.go: 
-------------------------------------------------------------------------------- 1 | // Holds metadata about a migration. 2 | 3 | package gomigrate 4 | 5 | // Migration statuses. 6 | const ( 7 | Inactive = iota 8 | Active 9 | ) 10 | 11 | // Holds configuration information for a given migration. 12 | type Migration struct { 13 | DownPath string 14 | Id uint64 15 | Name string 16 | Status int 17 | UpPath string 18 | } 19 | 20 | // Performs a basic validation of a migration. 21 | func (m *Migration) valid() bool { 22 | if m.Id != 0 && m.Name != "" && m.UpPath != "" && m.DownPath != "" { 23 | return true 24 | } 25 | return false 26 | } 27 | -------------------------------------------------------------------------------- /test_migrations/test1_mysql/1_test_down.sql: -------------------------------------------------------------------------------- 1 | DROP TABLE test; 2 | DROP TABLE test2; 3 | DROP TABLE tt; 4 | -------------------------------------------------------------------------------- /test_migrations/test1_mysql/1_test_up.sql: -------------------------------------------------------------------------------- 1 | CREATE TABLE test ( 2 | id INT NOT NULL AUTO_INCREMENT, 3 | PRIMARY KEY (id) 4 | ); 5 | 6 | CREATE TABLE test2 ( 7 | id INT NOT NULL AUTO_INCREMENT, 8 | PRIMARY KEY (id) 9 | ); 10 | 11 | CREATE TABLE tt (c text NOT NULL); 12 | 13 | INSERT INTO tt VALUES('a'); 14 | INSERT INTO tt VALUES('x'); 15 | -------------------------------------------------------------------------------- /test_migrations/test1_pg/1_test_down.sql: -------------------------------------------------------------------------------- 1 | drop table if exists test; 2 | -------------------------------------------------------------------------------- /test_migrations/test1_pg/1_test_up.sql: -------------------------------------------------------------------------------- 1 | create table if not exists test(); 2 | -------------------------------------------------------------------------------- 
/test_migrations/test1_pg/2_create_index_function_if_not_exists_down.sql: -------------------------------------------------------------------------------- 1 | drop function if exists create_index_if_not_exists (t_name text, i_name text, index_sql text); -------------------------------------------------------------------------------- /test_migrations/test1_pg/2_create_index_function_if_not_exists_up.sql: -------------------------------------------------------------------------------- 1 | -- this function allows us to create indexes if they don't exist 2 | create or replace function create_index_if_not_exists (t_name text, i_name text, index_sql text) returns void as $$ 3 | declare 4 | full_index_name varchar; 5 | schema_name varchar; 6 | begin 7 | 8 | full_index_name = t_name || '_' || i_name; 9 | schema_name = 'public'; 10 | 11 | if not exists ( 12 | select 1 13 | from pg_class c 14 | join pg_namespace n on n.oid = c.relnamespace 15 | where c.relname = full_index_name 16 | and n.nspname = schema_name 17 | ) then 18 | 19 | execute 'create index ' || full_index_name || ' on ' || schema_name || '.' 
|| t_name || ' ' || index_sql; 20 | end if; 21 | end 22 | $$ 23 | language plpgsql volatile; -------------------------------------------------------------------------------- /test_migrations/test1_pg/3_multiple_inserts_down.sql: -------------------------------------------------------------------------------- 1 | drop table tt; 2 | -------------------------------------------------------------------------------- /test_migrations/test1_pg/3_multiple_inserts_up.sql: -------------------------------------------------------------------------------- 1 | CREATE TABLE tt ( 2 | c text NOT NULL 3 | ); 4 | insert into tt values ('a'); 5 | insert into tt values ('x'); 6 | -------------------------------------------------------------------------------- /test_migrations/test1_pg/4_dash-test_down.sql: -------------------------------------------------------------------------------- 1 | select 1; 2 | -------------------------------------------------------------------------------- /test_migrations/test1_pg/4_dash-test_up.sql: -------------------------------------------------------------------------------- 1 | select 1; 2 | -------------------------------------------------------------------------------- /test_migrations/test1_sqlite3/1_test_down.sql: -------------------------------------------------------------------------------- 1 | DROP TABLE test; 2 | -------------------------------------------------------------------------------- /test_migrations/test1_sqlite3/1_test_up.sql: -------------------------------------------------------------------------------- 1 | CREATE TABLE test ( 2 | id INTEGER PRIMARY KEY 3 | ) 4 | -------------------------------------------------------------------------------- /utils.go: -------------------------------------------------------------------------------- 1 | package gomigrate 2 | 3 | import ( 4 | "path/filepath" 5 | "regexp" 6 | "strconv" 7 | ) 8 | 9 | var ( 10 | upMigrationFile = regexp.MustCompile(`(\d+)_([\w-]+)_up\.sql`) 11 | downMigrationFile = 
regexp.MustCompile(`(\d+)_([\w-]+)_down\.sql`) 12 | subMigrationSplit = regexp.MustCompile(`;\s*`) 13 | allWhitespace = regexp.MustCompile(`^\s*$`) 14 | ) 15 | 16 | // Returns the migration number, type and base name, so 1, "up", "migration" from "01_migration_up.sql" 17 | func parseMigrationPath(path string) (uint64, migrationType, string, error) { 18 | filebase := filepath.Base(path) 19 | 20 | matches := upMigrationFile.FindAllSubmatch([]byte(filebase), -1) 21 | if matches != nil { 22 | return parseMatches(matches, upMigration) 23 | } 24 | matches = downMigrationFile.FindAllSubmatch([]byte(filebase), -1) 25 | if matches != nil { 26 | return parseMatches(matches, downMigration) 27 | } 28 | 29 | return 0, "", "", InvalidMigrationFile 30 | } 31 | 32 | // Parses matches given by a migration file regex. 33 | func parseMatches(matches [][][]byte, mType migrationType) (uint64, migrationType, string, error) { 34 | num := matches[0][1] 35 | name := matches[0][2] 36 | parsedNum, err := strconv.ParseUint(string(num), 10, 64) 37 | if err != nil { 38 | return 0, "", "", err 39 | } 40 | return parsedNum, mType, string(name), nil 41 | } 42 | 43 | // This type is used to sort migration ids. 44 | type uint64slice []uint64 45 | 46 | func (u uint64slice) Len() int { 47 | return len(u) 48 | } 49 | 50 | func (u uint64slice) Less(a, b int) bool { 51 | return u[a] < u[b] 52 | } 53 | 54 | func (u uint64slice) Swap(a, b int) { 55 | tempA := u[a] 56 | u[a] = u[b] 57 | u[b] = tempA 58 | } 59 | --------------------------------------------------------------------------------