├── .travis.yml ├── LICENSE.md ├── README.md ├── TODO.md ├── go.mod ├── go.sum ├── main.go └── pgmgr ├── config.go ├── config_test.go ├── dump_config.go ├── dump_config_test.go ├── pgmgr.go └── pgmgr_test.go /.travis.yml: -------------------------------------------------------------------------------- 1 | language: go 2 | 3 | go: 4 | - '1.12' 5 | - stable 6 | - master 7 | 8 | env: 9 | global: 10 | - GO111MODULE=on 11 | 12 | addons: 13 | postgresql: "10" 14 | 15 | matrix: 16 | allow_failures: 17 | - go: master 18 | 19 | script: 20 | - go get -u golang.org/x/lint/golint 21 | - test -z "$(golint ./...)" 22 | - test -z "$(gofmt -l .)" 23 | - createuser pgmgr -s 24 | - go test -v ./... 25 | -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | The MIT License 2 | 3 | Copyright (c) 2016. rnubel 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a 6 | copy 7 | of this software and associated documentation files (the "Software"), to 8 | deal 9 | in the Software without restriction, including without limitation the 10 | rights 11 | to use, copy, modify, merge, publish, distribute, sublicense, and/or 12 | sell 13 | copies of the Software, and to permit persons to whom the Software is 14 | furnished to do so, subject to the following conditions: 15 | 16 | The above copyright notice and this permission notice shall be included 17 | in 18 | all copies or substantial portions of the Software. 19 | 20 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS 21 | OR 22 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 23 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL 24 | THE 25 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 26 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING 27 | FROM, 28 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS 29 | IN 30 | THE SOFTWARE. 31 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Postgres Manager (pgmgr) 2 | [![Build Status](https://travis-ci.org/rnubel/pgmgr.svg?branch=master)](https://travis-ci.org/rnubel/pgmgr) 3 | 4 | Utility for web applications to manage their Postgres database in a 5 | reliable, consistent manner. Inspired by [mattes/migrate](http://www.github.com/mattes/migrate), 6 | but with several benefits: 7 | 8 | * Migration version numbers are timestamp-based, not sequential. This saves 9 | significant headaches when multiple developers are working on the same 10 | application in parallel. However, `pgmgr` is still compatible with your 11 | old migrations. 12 | * `pgmgr` can generate database dumps, with seed data included, so that 13 | there's a single, authoritative source of your database structure. It's 14 | recommended you regularly check this file into source control. 15 | 16 | ## Installation 17 | 18 | ``` 19 | $ go install github.com/rnubel/pgmgr@latest 20 | ``` 21 | 22 | If you cannot run `pgmgr` after this, check that the directory Go installs binaries to 23 | (`$GOBIN` if set, otherwise `$GOPATH/bin`, where `GOPATH` defaults to `$HOME/go`) 24 | is in your PATH. 25 | 26 | ## Getting Started 27 | 28 | First, create a `.pgmgr.json` file in your app, as described below.
Then, 29 | generate your first migration: 30 | 31 | ``` 32 | $ pgmgr migration MyFirstMigration 33 | 34 | Created migrations/1433277961_MyFirstMigration.up.sql 35 | Created migrations/1433277961_MyFirstMigration.down.sql 36 | ``` 37 | 38 | Flesh it out: 39 | ``` 40 | $ echo 'CREATE TABLE foos (foo_id INTEGER)' > migrations/1433277961_MyFirstMigration.up.sql 41 | ``` 42 | 43 | Bootstrap your database: 44 | ``` 45 | $ pgmgr db create 46 | Database pgmgr-test-app created successfully. 47 | ``` 48 | 49 | And apply your migration: 50 | ``` 51 | $ pgmgr db migrate 52 | == Applying 1433277961_MyFirstMigration.up.sql == 53 | == Completed in 8 ms == 54 | ``` 55 | 56 | ## Configuration 57 | 58 | `pgmgr` supports file-based configuration (useful for checking into your 59 | source code) and environment-based configuration, which always overrides 60 | the former (useful for production deploys, Docker usage, etc.). 61 | 62 | ### .pgmgr.json 63 | 64 | By default, `pgmgr` will look for a file named `.pgmgr.json` in your 65 | working directory. You can override the file path with the `--config-file` (`-c`) flag 66 | or the environment variable `PGMGR_CONFIG_FILE`. It should look something like: 67 | 68 | ``` 69 | { 70 | "host": "localhost", 71 | "port": 5432, 72 | "username": "test", 73 | "password": "test", 74 | "database": "testdb", 75 | "sslmode": "disable", 76 | "migration-table": "public.schema_migrations", 77 | "migration-folder": "db/migrate", 78 | "column-type": "integer", 79 | "format": "unix", 80 | "dump-options": { 81 | "dump-file": "db/dump.sql", "seed-tables": [ "foos", "bars" ] } 82 | } 83 | ``` 84 | 85 | The `column-type` option can be `integer` or `string`, and determines 86 | the type of the `schema_migrations.version` column. The `string` column 87 | type will store versions as `CHARACTER VARYING (255)`. 88 | 89 | The `format` option can be `unix` or `datetime`. The `unix` format is 90 | the integer epoch time; the `datetime` format uses versions similar to ActiveRecord, 91 | such as `20150910120933`.
In order to use the `datetime` format, you must 92 | also use the `string` column type. 93 | 94 | The `migration-table` option can be used to specify an alternate table name 95 | in which to track migration status. It defaults to the schema-unqualified 96 | `schema_migrations`, which will typically create a table in the `public` 97 | schema unless the database's default search path has been modified. If you 98 | use a schema-qualified name, pgmgr will attempt to create the schema first 99 | if it does not yet exist. 100 | 101 | `migration-driver`, added in August 2019, allows migrations to be run either 102 | through the Go `pq` library (which runs the migrations as a single multi-statement 103 | command) or through the `psql` command-line utility. The possible options are 104 | `'pq'` or `'psql'`. The default is currently `pq` (subject to change). 105 | 106 | ### Environment variables 107 | 108 | The values above map to these environment variables: 109 | 110 | * `PGMGR_HOST` 111 | * `PGMGR_PORT` 112 | * `PGMGR_USERNAME` 113 | * `PGMGR_PASSWORD` 114 | * `PGMGR_DATABASE` 115 | * `PGMGR_SSLMODE` 116 | * `PGMGR_DUMP_FILE` (the filepath to dump the database definition out to) 117 | * `PGMGR_SEED_TABLES` (tables to include data with when dumping the database) 118 | * `PGMGR_COLUMN_TYPE` 119 | * `PGMGR_FORMAT` 120 | * `PGMGR_MIGRATION_TABLE` 121 | * `PGMGR_MIGRATION_DRIVER` 122 | * `PGMGR_MIGRATION_FOLDER` 123 | 124 | If you prefer to use a connection string, you can set `PGMGR_URL` which will supersede the other configuration settings, e.g.: 125 | 126 | ``` 127 | PGMGR_URL='postgres://test:test@localhost/testdb?sslmode=require' 128 | ``` 129 | 130 | Also, pgmgr reads the standard Postgres env vars (`PGHOST`, `PGPORT`, `PGUSER`, 131 | `PGPASSWORD`, `PGDATABASE`, `PGSSLMODE`). These take precedence over the config file, 132 | but are overridden by CLI arguments and the `PGMGR_*` environment variables above.
133 | 134 | ## Usage 135 | 136 | ``` 137 | pgmgr migration MigrationName # generates files for a new migration 138 | pgmgr db create # creates the database if it doesn't exist 139 | pgmgr db drop # drop the database 140 | pgmgr db migrate # apply un-applied migrations 141 | pgmgr db rollback # reverts the latest migration, if possible. 142 | pgmgr db load # loads the schema dump file from PGMGR_DUMP_FILE 143 | pgmgr db dump # dumps the database structure & seeds to PGMGR_DUMP_FILE 144 | ``` 145 | -------------------------------------------------------------------------------- /TODO.md: -------------------------------------------------------------------------------- 1 | # TODO 2 | 3 | * Add `status` command to view status of all migrations in MigrationFolder 4 | * Support PGPASSFILE 5 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/rnubel/pgmgr 2 | 3 | go 1.12 4 | 5 | require ( 6 | github.com/lib/pq v1.2.0 7 | github.com/urfave/cli v1.22.1 8 | ) 9 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= 2 | github.com/cpuguy83/go-md2man/v2 v2.0.0-20190314233015-f79a8a8ca69d h1:U+s90UTSYgptZMwQh2aRr3LuazLJIa+Pg3Kc1ylSYVY= 3 | github.com/cpuguy83/go-md2man/v2 v2.0.0-20190314233015-f79a8a8ca69d/go.mod h1:maD7wRr/U5Z6m/iR4s+kqSMx2CaBsrgA7czyZG/E6dU= 4 | github.com/lib/pq v1.2.0 h1:LXpIM/LZ5xGFhOpXAQUIMM1HdyqzVYM13zNdjCEEcA0= 5 | github.com/lib/pq v1.2.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= 6 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= 7 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 8 | 
github.com/russross/blackfriday/v2 v2.0.1 h1:lPqVAte+HuHNfhJ/0LC98ESWRz8afy9tM/0RK8m9o+Q= 9 | github.com/russross/blackfriday/v2 v2.0.1/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= 10 | github.com/shurcooL/sanitized_anchor_name v1.0.0 h1:PdmoCO6wvbs+7yrJyMORt4/BmY5IYyJwS/kOiWx8mHo= 11 | github.com/shurcooL/sanitized_anchor_name v1.0.0/go.mod h1:1NzhyTcUVG4SuEtjjoZeVRXNmyL/1OwPU0+IJeTBvfc= 12 | github.com/urfave/cli v1.22.1 h1:+mkCCcOFKPnCmVYVcURKps1Xe+3zP90gSYGNfRkjoIY= 13 | github.com/urfave/cli v1.22.1/go.mod h1:Gos4lmkARVdJ6EkW0WaNv/tZAAMe9V7XWyB60NtXRu0= 14 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= 15 | gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= 16 | -------------------------------------------------------------------------------- /main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "fmt" 5 | "os" 6 | 7 | "github.com/rnubel/pgmgr/pgmgr" 8 | cli "github.com/urfave/cli" 9 | ) 10 | 11 | // displayErrorOrMessage wraps a non-nil err in a CLI exit error (exit status 1); otherwise it prints args on one line and returns nil. 12 | func displayErrorOrMessage(err error, args ...interface{}) error { 13 | if err != nil { 14 | return cli.NewExitError(fmt.Sprintln("Error:", err), 1) 15 | } 16 | 17 | fmt.Println(args...)
18 | return nil 19 | } 20 | 21 | // displayVersion prints the latest applied migration version, or a hint to run `pgmgr db migrate` when pgmgr.Version reports a negative version. 22 | func displayVersion(config *pgmgr.Config) error { 23 | v, err := pgmgr.Version(config) 24 | if v < 0 { 25 | return displayErrorOrMessage(err, "Database has no schema_migrations table; run `pgmgr db migrate` to create it.") 26 | } 27 | 28 | return displayErrorOrMessage(err, "Latest migration version:", v) 29 | } 30 | 31 | // main wires up the pgmgr CLI: global flags, configuration loading (app.Before) and the migration, config and db commands. 32 | func main() { 33 | config := &pgmgr.Config{} 34 | app := cli.NewApp() 35 | 36 | app.Name = "pgmgr" 37 | app.Usage = "manage your app's Postgres database" 38 | app.Version = "1.1.5" 39 | 40 | app.Flags = []cli.Flag{ 41 | cli.StringFlag{ 42 | Name: "config-file, c", 43 | Value: ".pgmgr.json", 44 | Usage: "set the path to the JSON configuration file specifying your DB parameters", 45 | EnvVar: "PGMGR_CONFIG_FILE", 46 | }, 47 | cli.StringFlag{ 48 | Name: "database, d", 49 | Value: "", 50 | Usage: "the database name which pgmgr will connect to or try to create", 51 | EnvVar: "PGMGR_DATABASE", 52 | }, 53 | cli.StringFlag{ 54 | Name: "username, u", 55 | Value: "", 56 | Usage: "the username which pgmgr will connect with", 57 | EnvVar: "PGMGR_USERNAME", 58 | }, 59 | cli.StringFlag{ 60 | Name: "password, P", 61 | Value: "", 62 | Usage: "the password which pgmgr will connect with", 63 | EnvVar: "PGMGR_PASSWORD", 64 | }, 65 | cli.StringFlag{ 66 | Name: "host, H", 67 | Value: "", 68 | Usage: "the host which pgmgr will connect to", 69 | EnvVar: "PGMGR_HOST", 70 | }, 71 | cli.IntFlag{ 72 | Name: "port, p", 73 | Value: 0, 74 | Usage: "the port which pgmgr will connect to", 75 | EnvVar: "PGMGR_PORT", 76 | }, 77 | cli.StringFlag{ 78 | Name: "sslmode", 79 | Value: "", 80 | Usage: "whether to verify SSL connection or not.
See https://www.postgresql.org/docs/9.1/static/libpq-ssl.html", 81 | EnvVar: "PGMGR_SSLMODE", 82 | }, 83 | cli.StringFlag{ 84 | Name: "url", 85 | Value: "", 86 | Usage: "connection URL or DSN containing connection info; will override the other params if given", 87 | EnvVar: "PGMGR_URL", 88 | }, 89 | cli.StringFlag{ 90 | Name: "dump-file", 91 | Value: "", 92 | Usage: "where to dump or load the database structure and contents to or from", 93 | EnvVar: "PGMGR_DUMP_FILE", 94 | }, 95 | cli.StringFlag{ 96 | Name: "column-type", 97 | Value: "", 98 | Usage: "column type to use in schema_migrations table; 'integer' or 'string' (default: integer)", 99 | EnvVar: "PGMGR_COLUMN_TYPE", 100 | }, 101 | cli.StringFlag{ 102 | Name: "format", 103 | Value: "", 104 | Usage: "timestamp format for migrations; 'unix' or 'datetime' (default: unix)", 105 | EnvVar: "PGMGR_FORMAT", 106 | }, 107 | cli.StringFlag{ 108 | Name: "migration-table", 109 | Value: "", 110 | Usage: "table to use for storing migration status; eg 'myschema.applied_migrations' (default: schema_migrations)", 111 | EnvVar: "PGMGR_MIGRATION_TABLE", 112 | }, 113 | cli.StringFlag{ 114 | Name: "migration-folder", 115 | Value: "", 116 | Usage: "folder containing the migrations to apply", 117 | EnvVar: "PGMGR_MIGRATION_FOLDER", 118 | }, 119 | cli.StringFlag{ 120 | Name: "migration-driver", 121 | Value: "", 122 | Usage: "how to apply the migrations. supported options are pq (which will execute the migration as one statement) or psql (which will use the psql binary on your system to execute each line) (default: pq)", 123 | EnvVar: "PGMGR_MIGRATION_DRIVER", 124 | }, 125 | cli.BoolFlag{ 126 | Name: "no-compress", 127 | Usage: "whether to skip compressing the database dump. See pg_dump -Z.", 128 | EnvVar: "PGMGR_NO_COMPRESS", 129 | }, 130 | cli.BoolFlag{ 131 | Name: "include-triggers", 132 | Usage: "whether to enable triggers on the dump.
See pg_dump --disable-triggers.", 133 | EnvVar: "PGMGR_INCLUDE_TRIGGERS", 134 | }, 135 | cli.BoolFlag{ 136 | Name: "include-privileges", 137 | Usage: "whether to enable access privileges on the dump. See pg_dump -x.", 138 | EnvVar: "PGMGR_INCLUDE_PRIVILEGES", 139 | }, 140 | cli.StringSliceFlag{ 141 | Name: "seed-tables", 142 | Value: &cli.StringSlice{}, // each slice flag needs its own storage; a shared one mixes values between flags 143 | Usage: "only dump data from tables matching these table names or globs. See pg_dump -t.", 144 | EnvVar: "PGMGR_SEED_TABLES", 145 | }, 146 | cli.StringSliceFlag{ 147 | Name: "exclude-schemas", 148 | Value: &cli.StringSlice{}, // own storage, never shared with seed-tables 149 | Usage: "do not dump any schemas matching these schema names or globs. See pg_dump -N.", 150 | EnvVar: "PGMGR_EXCLUDE_SCHEMAS", 151 | }, 152 | } 153 | 154 | app.Before = func(c *cli.Context) error { 155 | return pgmgr.LoadConfig(config, c) 156 | } 157 | 158 | app.Commands = []cli.Command{ 159 | { 160 | Name: "migration", 161 | Usage: "generates a new migration with the given name", 162 | Flags: []cli.Flag{ 163 | cli.BoolFlag{ 164 | Name: "no-txn", 165 | Usage: "generate a migration that will not be wrapped in a transaction when run", 166 | }, 167 | }, 168 | Action: func(c *cli.Context) error { 169 | if len(c.Args()) == 0 { 170 | return cli.NewExitError("migration name not given! try `pgmgr migration NameGoesHere`", 1) 171 | } 172 | 173 | return displayErrorOrMessage(pgmgr.CreateMigration(config, c.Args()[0], c.Bool("no-txn"))) 174 | }, 175 | }, 176 | { 177 | Name: "config", 178 | Usage: "displays the current configuration as seen by pgmgr", 179 | Action: func(c *cli.Context) error { 180 | fmt.Printf("%+v\n", config) 181 | return nil 182 | }, 183 | }, 184 | { 185 | Name: "db", 186 | Usage: "manage your database.
use 'pgmgr db help' for more info", 187 | Subcommands: []cli.Command{ 188 | { 189 | Name: "create", 190 | Usage: "creates the database if it doesn't exist", 191 | Action: func(c *cli.Context) error { 192 | return displayErrorOrMessage(pgmgr.Create(config), "Database", config.Database, "created successfully.") 193 | }, 194 | }, 195 | { 196 | Name: "drop", 197 | Usage: "drops the database (all sessions must be disconnected first. this command does not force it)", 198 | Action: func(c *cli.Context) error { 199 | return displayErrorOrMessage(pgmgr.Drop(config), "Database", config.Database, "dropped successfully.") 200 | }, 201 | }, 202 | { 203 | Name: "dump", 204 | Usage: "dumps the database schema and contents to the dump file (see --dump-file)", 205 | Action: func(c *cli.Context) error { 206 | err := pgmgr.Dump(config) 207 | return displayErrorOrMessage(err, "Database dumped to", config.DumpConfig.GetDumpFile(), "successfully") 208 | }, 209 | }, 210 | { 211 | Name: "load", 212 | Usage: "loads the database schema and contents from the dump file (see --dump-file)", 213 | Action: func(c *cli.Context) error { 214 | err := pgmgr.Load(config) 215 | err = displayErrorOrMessage(err, "Database loaded successfully.") 216 | if err != nil { 217 | return err 218 | } 219 | 220 | return displayVersion(config) 221 | }, 222 | }, 223 | { 224 | Name: "version", 225 | Usage: "returns the current schema version", 226 | Action: func(c *cli.Context) error { 227 | return displayVersion(config) 228 | }, 229 | }, 230 | { 231 | Name: "migrate", 232 | Usage: "applies any un-applied migrations in the migration folder (see --migration-folder)", 233 | Action: func(c *cli.Context) error { 234 | err := pgmgr.Migrate(config) 235 | if err != nil { 236 | return cli.NewExitError(fmt.Sprintln("Error during migration:", err), 1) 237 | } 238 | 239 | return nil 240 | }, 241 | }, 242 | { 243 | Name: "rollback", 244 | Usage: "rolls back the latest migration", 245 | Action: func(c *cli.Context) error { 246 |
if err := pgmgr.Rollback(config); err != nil { 247 | return cli.NewExitError(fmt.Sprintln("Error during rollback:", err), 1) 248 | } 249 | 250 | return nil 251 | }, 252 | }, 253 | }, 254 | }, 255 | } 256 | 257 | app.Action = func(c *cli.Context) error { 258 | app.Command("help").Run(c) 259 | return nil 260 | } 261 | 262 | if err := app.Run(os.Args); err != nil { 263 | // surface failures returned by app.Run (e.g. LoadConfig rejecting the configuration in app.Before) as a non-zero exit status 264 | os.Exit(1) 265 | } 266 | } 267 | -------------------------------------------------------------------------------- /pgmgr/config.go: -------------------------------------------------------------------------------- 1 | package pgmgr 2 | 3 | import ( 4 | "encoding/json" 5 | "errors" 6 | "fmt" 7 | "io/ioutil" 8 | "os" 9 | "regexp" 10 | "strconv" 11 | "strings" 12 | 13 | "github.com/lib/pq" 14 | ) 15 | 16 | // Something that stores key-value pairs of various types, 17 | // e.g., cli.Context. 18 | type argumentContext interface { 19 | String(string) string 20 | Int(string) int 21 | StringSlice(string) []string 22 | Bool(string) bool 23 | } 24 | 25 | // Config stores the options used by pgmgr. 26 | type Config struct { 27 | // connection 28 | Username string 29 | Password string 30 | Database string 31 | Host string 32 | Port int 33 | URL string 34 | SslMode string 35 | 36 | // dump 37 | DumpConfig DumpConfig `json:"dump-options"` 38 | 39 | // filepaths 40 | MigrationFolder string `json:"migration-folder"` 41 | 42 | // options 43 | MigrationTable string `json:"migration-table"` 44 | MigrationDriver string `json:"migration-driver"` 45 | ColumnType string `json:"column-type"` 46 | Format string 47 | 48 | // deprecated -- see dump_config.go 49 | DumpFile string `json:"dump-file"` 50 | SeedTables []string `json:"seed-tables"` 51 | } 52 | 53 | // LoadConfig reads the config file, applies CLI arguments as 54 | // overrides, and returns an error if the configuration is invalid. 55 | func LoadConfig(config *Config, ctx argumentContext) error { 56 | // load configuration from file first; then override with 57 | // flags or env vars if they're present.
58 | configFile := ctx.String("config-file") 59 | config.populateFromFile(configFile) 60 | 61 | // apply defaults from Postgres environment variables, but allow 62 | // them to be overridden in the next step 63 | config.populateFromPostgresVars() 64 | 65 | // apply some other, sane defaults 66 | config.applyDefaults() 67 | 68 | // override if passed-in from the CLI or via environment variables 69 | config.applyArguments(ctx) 70 | 71 | // if a connection URL was passed, use that instead for our connection 72 | // configuration 73 | if config.URL != "" { 74 | config.overrideFromURL() 75 | } 76 | return config.validate() 77 | } 78 | 79 | func (config *Config) populateFromFile(configFile string) { 80 | contents, err := ioutil.ReadFile(configFile) 81 | if err == nil { 82 | json.Unmarshal(contents, &config) 83 | } else { 84 | fmt.Println("error reading config file: ", err) 85 | } 86 | if config.DumpFile != "" { 87 | deprecatedDumpFieldWarning("dump-file") 88 | config.DumpConfig.DumpFile = config.DumpFile 89 | } 90 | if len(config.SeedTables) != 0 { 91 | deprecatedDumpFieldWarning("seed-tables") 92 | config.DumpConfig.IncludeTables = config.SeedTables 93 | } 94 | } 95 | 96 | func (config *Config) populateFromPostgresVars() { 97 | if os.Getenv("PGUSER") != "" { 98 | config.Username = os.Getenv("PGUSER") 99 | } 100 | if os.Getenv("PGPASSWORD") != "" { 101 | config.Password = os.Getenv("PGPASSWORD") 102 | } 103 | if os.Getenv("PGDATABASE") != "" { 104 | config.Database = os.Getenv("PGDATABASE") 105 | } 106 | if os.Getenv("PGHOST") != "" { 107 | config.Host = os.Getenv("PGHOST") 108 | } 109 | if os.Getenv("PGPORT") != "" { 110 | config.Port, _ = strconv.Atoi(os.Getenv("PGPORT")) 111 | } 112 | if os.Getenv("PGSSLMODE") != "" { 113 | config.SslMode = os.Getenv("PGSSLMODE") 114 | } 115 | } 116 | 117 | // DumpToEnv applies all applicable keys as PG environment variables, so that 118 | // shell commands will work on the correct target. 
119 | func (config *Config) DumpToEnv() error { 120 | if err := os.Setenv("PGUSER", config.Username); err != nil { 121 | return err 122 | } 123 | if err := os.Setenv("PGPASSWORD", config.Password); err != nil { 124 | return err 125 | } 126 | if err := os.Setenv("PGDATABASE", config.Database); err != nil { 127 | return err 128 | } 129 | if err := os.Setenv("PGHOST", config.Host); err != nil { 130 | return err 131 | } 132 | if err := os.Setenv("PGPORT", fmt.Sprint(config.Port)); err != nil { 133 | return err 134 | } 135 | if err := os.Setenv("PGSSLMODE", config.SslMode); err != nil { 136 | return err 137 | } 138 | 139 | return nil 140 | } 141 | 142 | func (config *Config) applyDefaults() { 143 | if config.Port == 0 { 144 | config.Port = 5432 145 | } 146 | if config.Host == "" { 147 | config.Host = "localhost" 148 | } 149 | if config.Format == "" { 150 | config.Format = "unix" 151 | } 152 | if config.ColumnType == "" { 153 | config.ColumnType = "integer" 154 | } 155 | if config.MigrationTable == "" { 156 | config.MigrationTable = "schema_migrations" 157 | } 158 | if config.MigrationDriver == "" { 159 | config.MigrationDriver = "pq" 160 | } 161 | if config.SslMode == "" { 162 | config.SslMode = "disable" 163 | } 164 | config.DumpConfig.applyDefaults() 165 | } 166 | 167 | func (config *Config) applyArguments(ctx argumentContext) { 168 | if ctx.String("username") != "" { 169 | config.Username = ctx.String("username") 170 | } 171 | if ctx.String("password") != "" { 172 | config.Password = ctx.String("password") 173 | } 174 | if ctx.String("database") != "" { 175 | config.Database = ctx.String("database") 176 | } 177 | if ctx.String("host") != "" { 178 | config.Host = ctx.String("host") 179 | } 180 | if ctx.Int("port") != 0 { 181 | config.Port = ctx.Int("port") 182 | } 183 | if ctx.String("url") != "" { 184 | config.URL = ctx.String("url") 185 | } 186 | if ctx.String("sslmode") != "" { 187 | config.SslMode = ctx.String("sslmode") 188 | } 189 | if 
ctx.String("migration-folder") != "" { 190 | config.MigrationFolder = ctx.String("migration-folder") 191 | } 192 | if ctx.String("migration-driver") != "" { 193 | config.MigrationDriver = ctx.String("migration-driver") 194 | } 195 | if ctx.String("migration-table") != "" { 196 | config.MigrationTable = ctx.String("migration-table") 197 | } 198 | if ctx.String("column-type") != "" { 199 | config.ColumnType = ctx.String("column-type") 200 | } 201 | if ctx.String("format") != "" { 202 | config.Format = ctx.String("format") 203 | } 204 | config.DumpConfig.applyArguments(ctx) 205 | } 206 | 207 | func (config *Config) overrideFromURL() { 208 | // parse the DSN and populate the other configuration values. Some of the pg commands 209 | // accept a DSN parameter, but not all, so this will help unify things. 210 | r := regexp.MustCompile("^postgres://(.*)@(.*):([0-9]+)/([^?]+)") 211 | m := r.FindStringSubmatch(config.URL) 212 | if len(m) > 0 { 213 | user := m[1] 214 | config.Host = m[2] 215 | config.Port, _ = strconv.Atoi(m[3]) 216 | config.Database = m[4] 217 | 218 | userRegex := regexp.MustCompile("^(.*):(.*)$") 219 | userMatch := userRegex.FindStringSubmatch(user) 220 | 221 | if len(userMatch) > 0 { 222 | config.Username = userMatch[1] 223 | config.Password = userMatch[2] 224 | } else { 225 | config.Username = user 226 | } 227 | 228 | queryRegex := regexp.MustCompile("([a-zA-Z0-9_-]+)=([a-zA-Z0-9_-]+)") 229 | matches := queryRegex.FindAllStringSubmatch(config.URL, -1) 230 | for _, match := range matches { 231 | if match[1] == "sslmode" { 232 | config.SslMode = match[2] 233 | } 234 | } 235 | } else { 236 | println("Could not parse DSN: ", config.URL, " using regex ", r.String()) 237 | } 238 | } 239 | 240 | func (config *Config) validate() error { 241 | if config.ColumnType != "integer" && config.ColumnType != "string" { 242 | return errors.New(`ColumnType must be "integer" or "string"`) 243 | } 244 | 245 | if config.Format != "unix" && config.Format != "datetime" { 246 | 
return errors.New(`Format must be "unix" or "datetime"`) 247 | } 248 | 249 | if config.Format == "datetime" && config.ColumnType != "string" { 250 | return errors.New(`ColumnType must be "string" if Format is "datetime"`) 251 | } 252 | 253 | if config.MigrationDriver != "pq" && config.MigrationDriver != "psql" { 254 | return errors.New("MigrationDriver must be one of: pq, psql") 255 | } 256 | 257 | return nil 258 | } 259 | 260 | func (config *Config) quotedMigrationTable() string { 261 | if !strings.Contains(config.MigrationTable, ".") { 262 | return pq.QuoteIdentifier(config.MigrationTable) 263 | } 264 | 265 | tokens := strings.SplitN(config.MigrationTable, ".", 2) 266 | return pq.QuoteIdentifier(tokens[0]) + "." + pq.QuoteIdentifier(tokens[1]) 267 | } 268 | 269 | func (config *Config) versionColumnType() string { 270 | if config.ColumnType == "string" { 271 | return "CHARACTER VARYING (255)" 272 | } 273 | 274 | return "INTEGER" 275 | } 276 | 277 | func deprecatedDumpFieldWarning(field string) { 278 | fmt.Println( 279 | "WARN: Providing '"+field+"' as a top-level config key is deprecated.", 280 | "Specify it underneath the 'dump-options' key in your config file.", 281 | ) 282 | } 283 | -------------------------------------------------------------------------------- /pgmgr/config_test.go: -------------------------------------------------------------------------------- 1 | package pgmgr 2 | 3 | import ( 4 | "os" 5 | "testing" 6 | ) 7 | 8 | // create a mock to replace cli.Context 9 | type TestContext struct { 10 | StringVals map[string]string 11 | IntVals map[string]int 12 | StringSliceVals map[string][]string 13 | BoolVals map[string]bool 14 | } 15 | 16 | func (t *TestContext) String(key string) string { 17 | return t.StringVals[key] 18 | } 19 | func (t *TestContext) Int(key string) int { 20 | return t.IntVals[key] 21 | } 22 | func (t *TestContext) StringSlice(key string) []string { 23 | return t.StringSliceVals[key] 24 | } 25 | func (t *TestContext) Bool(key 
string) bool { 26 | return t.BoolVals[key] 27 | } 28 | 29 | func TestDefaults(t *testing.T) { 30 | c := &Config{} 31 | 32 | LoadConfig(c, &TestContext{}) 33 | 34 | if c.Port != 5432 { 35 | t.Fatal("config's port should default to 5432") 36 | } 37 | 38 | if c.Host != "localhost" { 39 | t.Fatal("config's host should default to localhost, but was ", c.Host) 40 | } 41 | 42 | if c.MigrationTable != "schema_migrations" { 43 | t.Fatal("config's migration table should default to schema_migrations, but was ", c.MigrationTable) 44 | } 45 | 46 | if c.ColumnType != "integer" { 47 | t.Fatal("config's column type should default to integer, but was ", c.ColumnType) 48 | } 49 | 50 | if c.Format != "unix" { 51 | t.Fatal("config's format should default to unix, but was ", c.Format) 52 | } 53 | 54 | if c.SslMode != "disable" { 55 | t.Fatal("config's sslmode should default to 'disable', but was ", c.SslMode) 56 | } 57 | } 58 | 59 | func TestOverlays(t *testing.T) { 60 | c := &Config{} 61 | ctx := &TestContext{IntVals: make(map[string]int)} 62 | 63 | // should prefer the value from ctx, since 64 | // it was passed-in explicitly at runtime 65 | c.Port = 123 66 | ctx.IntVals["port"] = 456 67 | os.Setenv("PGPORT", "789") 68 | 69 | LoadConfig(c, ctx) 70 | 71 | if c.Port != 456 { 72 | t.Fatal("config's port should come from the context, but was", c.Port) 73 | } 74 | 75 | // reset 76 | c = &Config{} 77 | ctx = &TestContext{IntVals: make(map[string]int)} 78 | 79 | // should prefer the value from PGPORT, since 80 | // nothing was passed-in at runtime 81 | c.Port = 123 82 | os.Setenv("PGPORT", "789") 83 | 84 | LoadConfig(c, ctx) 85 | 86 | if c.Port != 789 { 87 | t.Fatal("config's port should come from PGPORT, but was", c.Port) 88 | } 89 | 90 | // reset 91 | c = &Config{} 92 | ctx = &TestContext{IntVals: make(map[string]int)} 93 | 94 | // should prefer the value in the struct, since 95 | // nothing else is given 96 | c.Port = 123 97 | os.Setenv("PGPORT", "") 98 | 99 | LoadConfig(c, ctx) 100 | 101
| if c.Port != 123 { 102 | t.Fatal("config's port should not change, but was", c.Port) 103 | } 104 | 105 | // reset 106 | c = &Config{} 107 | ctx = &TestContext{StringVals: make(map[string]string)} 108 | 109 | // should prefer the value from ctx, since 110 | // it was passed-in explicitly at runtime 111 | c.ColumnType = "integer" 112 | ctx.StringVals["column-type"] = "string" 113 | 114 | LoadConfig(c, ctx) 115 | 116 | if c.ColumnType != "string" { 117 | t.Fatal("config's column-type should come from the context, but was", c.ColumnType) 118 | } 119 | } 120 | 121 | func TestURL(t *testing.T) { 122 | c := &Config{} 123 | c.URL = "postgres://foo@bar:5431/test-db.one?sslmode=verify-ca" 124 | 125 | LoadConfig(c, &TestContext{}) 126 | 127 | if c.Username != "foo" || c.Host != "bar" || c.Port != 5431 || c.Database != "test-db.one" || c.SslMode != "verify-ca" { 128 | t.Fatal("config did not populate itself from the given URL:", c) 129 | } 130 | } 131 | 132 | func TestURLWithPassword(t *testing.T) { 133 | c := &Config{} 134 | c.URL = "postgres://foo:baz@bar:5431/test-db.one?sslmode=verify-ca" 135 | 136 | LoadConfig(c, &TestContext{}) 137 | 138 | if c.Username != "foo" || c.Password != "baz" || c.Host != "bar" || c.Port != 5431 || c.Database != "test-db.one" || c.SslMode != "verify-ca" { 139 | t.Fatal("config did not populate itself from the given URL:", c) 140 | } 141 | } 142 | 143 | func TestValidation(t *testing.T) { 144 | c := &Config{} 145 | c.Format = "wrong" 146 | 147 | if err := LoadConfig(c, &TestContext{}); err == nil { 148 | t.Fatal("LoadConfig should reject invalid Format value") 149 | } 150 | 151 | c.Format = "" 152 | c.ColumnType = "wrong" 153 | if err := LoadConfig(c, &TestContext{}); err == nil { 154 | t.Fatal("LoadConfig should reject invalid ColumnType value") 155 | } 156 | 157 | c.Format = "datetime" 158 | c.ColumnType = "integer" 159 | if err := LoadConfig(c, &TestContext{}); err == nil { 160 | t.Fatal("LoadConfig should prevent Format=datetime when
ColumnType=integer") 161 | } 162 | } 163 | 164 | func TestQuotedMigrationTable(t *testing.T) { 165 | c := &Config{MigrationTable: "abc"} 166 | if c.quotedMigrationTable() != `"abc"` { 167 | t.Fatal(`Migration table should be "abc", got`, c.quotedMigrationTable()) 168 | } 169 | 170 | c.MigrationTable = "abc.def" 171 | if c.quotedMigrationTable() != `"abc"."def"` { 172 | t.Fatal(`Schema-qualified migration table should be "abc"."def", got`, c.quotedMigrationTable()) 173 | } 174 | } 175 | -------------------------------------------------------------------------------- /pgmgr/dump_config.go: -------------------------------------------------------------------------------- 1 | package pgmgr 2 | 3 | import "strings" 4 | 5 | // DumpConfig stores the options used by pgmgr's dump tool 6 | // and defers connection-type options to the main config file 7 | type DumpConfig struct { 8 | // exclusions 9 | ExcludeSchemas []string `json:"exclude-schemas"` 10 | 11 | // inclusions 12 | IncludeTables []string `json:"seed-tables"` 13 | IncludePrivileges bool `json:"include-privileges"` 14 | IncludeTriggers bool `json:"include-triggers"` 15 | 16 | // options 17 | NoCompress bool `json:"no-compress"` 18 | DumpFile string `json:"dump-file"` 19 | } 20 | 21 | // GetDumpFileRaw returns the literal dump file name as configured 22 | func (config DumpConfig) GetDumpFileRaw() string { 23 | return config.DumpFile 24 | } 25 | 26 | // GetDumpFile returns the dump file name as written to disk: 27 | // DumpFile, plus a ".gz" suffix when compression is enabled 28 | func (config DumpConfig) GetDumpFile() string { 29 | if config.IsCompressed() { 30 | return config.DumpFile + ".gz" 31 | } 32 | return config.DumpFile 33 | } 34 | 35 | // IsCompressed reports whether dumps are compressed (i.e. NoCompress is false) 36 | func (config DumpConfig) IsCompressed() bool { 37 | return !config.NoCompress 38 | } 39 | 40 | func (config *DumpConfig) applyArguments(ctx argumentContext) { 41 | if sliceValuesGiven(ctx, "exclude-schemas") { 42 | config.ExcludeSchemas =
ctx.StringSlice("exclude-schemas") 43 | } 44 | if sliceValuesGiven(ctx, "seed-tables") { 45 | config.IncludeTables = ctx.StringSlice("seed-tables") 46 | } 47 | if ctx.String("dump-file") != "" { 48 | config.DumpFile = ctx.String("dump-file") 49 | } 50 | if ctx.Bool("no-compress") { 51 | config.NoCompress = true 52 | } 53 | if ctx.Bool("include-privileges") { 54 | config.IncludePrivileges = true 55 | } 56 | if ctx.Bool("include-triggers") { 57 | config.IncludeTriggers = true 58 | } 59 | if strings.HasSuffix(config.DumpFile, ".gz") { 60 | config.DumpFile = config.DumpFile[0 : len(config.DumpFile)-3] 61 | config.NoCompress = false 62 | } 63 | } 64 | 65 | func (config *DumpConfig) applyDefaults() { 66 | if config.DumpFile == "" { 67 | config.DumpFile = "dump.sql" 68 | } 69 | } 70 | 71 | func sliceValuesGiven(ctx argumentContext, key string) bool { 72 | return ctx.StringSlice(key) != nil && len(ctx.StringSlice(key)) > 0 73 | } 74 | 75 | func (config DumpConfig) baseFlags() []string { 76 | var args []string 77 | for _, schema := range config.ExcludeSchemas { 78 | args = append(args, "-N", schema) 79 | } 80 | 81 | if config.IsCompressed() { 82 | args = append(args, "-Z", "9") 83 | } 84 | 85 | if !config.IncludePrivileges { 86 | args = append(args, "-x") 87 | } 88 | 89 | return args 90 | } 91 | 92 | func (config DumpConfig) schemaFlags() []string { 93 | args := config.baseFlags() 94 | return append(args, "--schema-only") 95 | } 96 | 97 | func (config DumpConfig) dataFlags() []string { 98 | args := config.baseFlags() 99 | 100 | for _, table := range config.IncludeTables { 101 | args = append(args, "-t", table) 102 | } 103 | 104 | if !config.IncludeTriggers { 105 | args = append(args, "--disable-triggers") 106 | } 107 | 108 | args = append(args, "--data-only") 109 | return args 110 | } 111 | -------------------------------------------------------------------------------- /pgmgr/dump_config_test.go: 
-------------------------------------------------------------------------------- 1 | package pgmgr 2 | 3 | import ( 4 | "strings" 5 | "testing" 6 | ) 7 | 8 | func TestDumpFlags(test *testing.T) { 9 | c := DumpConfig{ 10 | IncludeTables: []string{"iTable1", "iTable2"}, 11 | ExcludeSchemas: []string{}, 12 | } 13 | 14 | flags := strings.Join(c.baseFlags(), " ") 15 | for _, t := range c.ExcludeSchemas { 16 | if !strings.Contains(flags, "-N "+t) { 17 | test.Fatal("Dump flags should flag each excluded schema with '-N', missing", t) 18 | } 19 | } 20 | if !strings.Contains(flags, "-Z 9") { 21 | test.Fatal("Dump flags should set compression level to 9 when NoCompress is 'f'") 22 | } 23 | if !strings.Contains(flags, "-x") { 24 | test.Fatal("Dump flags should set -x when IncludePrivileges is 'f'") 25 | } 26 | 27 | c.NoCompress = true 28 | c.IncludePrivileges = true 29 | flags = strings.Join(c.baseFlags(), " ") 30 | if strings.Contains(flags, "-Z 9") { 31 | test.Fatal("Dump flags should not set compression level to 9 when NoCompress is 't'") 32 | } 33 | 34 | if strings.Contains(flags, "-x") { 35 | test.Fatal("Dump flags should not set -x when IncludePrivileges is 't'") 36 | } 37 | 38 | flags = strings.Join(c.dataFlags(), " ") 39 | for _, t := range c.IncludeTables { 40 | if !strings.Contains(flags, "-t "+t) { 41 | test.Fatal("Data flags should flag each included table with '-t', missing", t) 42 | } 43 | } 44 | if !strings.Contains(flags, "--data-only") { 45 | test.Fatal("Data flags should mark --data-only") 46 | } 47 | if !strings.Contains(flags, "--disable-triggers") { 48 | test.Fatal("Data flags should set --disable-triggers when IncludeTriggers is 'f'") 49 | } 50 | 51 | c.IncludeTriggers = true 52 | flags = strings.Join(c.dataFlags(), " ") 53 | if strings.Contains(flags, "--disable-triggers") { 54 | test.Fatal("Data flags should not set --disable-triggers when IncludeTriggers is 't'") 55 | } 56 | 57 | flags = strings.Join(c.schemaFlags(), " ") 58 | if 
!strings.Contains(flags, "--schema-only") { 59 | test.Fatal("Schema flags should mark --schema-only") 60 | } 61 | } 62 | 63 | func TestDumpDefaults(t *testing.T) { 64 | c := &Config{} 65 | c.applyDefaults() 66 | 67 | if c.DumpConfig.DumpFile != "dump.sql" { 68 | t.Fatal("dump config's dump-file should default to 'dump.sql', but was ", c.DumpConfig.DumpFile) 69 | } 70 | 71 | if c.DumpConfig.NoCompress { 72 | t.Fatal("dump config's NoCompress should default to 'f', but was ", c.DumpConfig.NoCompress) 73 | } 74 | 75 | if c.DumpConfig.IncludePrivileges { 76 | t.Fatal("dump config's include privileges should default to 'f', but was ", c.DumpConfig.IncludePrivileges) 77 | } 78 | 79 | if c.DumpConfig.IncludeTriggers { 80 | t.Fatal("dump config's include triggers should default to 'f', but was ", c.DumpConfig.IncludeTriggers) 81 | } 82 | 83 | dumpContext := TestContext{ 84 | StringVals: map[string]string{ 85 | "dump-file": "dump.file.sql.gz", 86 | }, 87 | } 88 | LoadConfig(c, &dumpContext) 89 | 90 | if c.DumpConfig.DumpFile != "dump.file.sql" { 91 | t.Fatal("dump config should strip '.gz' suffix, but was ", c.DumpConfig.DumpFile) 92 | } 93 | if c.DumpConfig.NoCompress { 94 | t.Fatal("dump config should set NoCompress='f' if '.gz' suffix is present, but was ", c.DumpConfig.NoCompress) 95 | } 96 | } 97 | 98 | func TestDumpOverlays(t *testing.T) { 99 | c := &Config{} 100 | ctx := &TestContext{StringVals: make(map[string]string)} 101 | 102 | // should prefer the value from ctx, since 103 | // it was passed-in explictly at runtime 104 | c.DumpConfig.DumpFile = "structval" 105 | ctx.StringVals["dump-file"] = "stringval" 106 | 107 | LoadConfig(c, ctx) 108 | 109 | if c.DumpConfig.DumpFile != "stringval" { 110 | t.Fatal("config's dump file should come from the context, but was", c.DumpConfig.DumpFile) 111 | } 112 | 113 | // reset 114 | c = &Config{} 115 | ctx = &TestContext{StringVals: make(map[string]string)} 116 | 117 | // should prefer the value in the struct, since 118 | // 
nothing else is given 119 | c.DumpConfig.DumpFile = "structval" 120 | LoadConfig(c, ctx) 121 | 122 | if c.DumpConfig.DumpFile != "structval" { 123 | t.Fatal("config's dump file should not change, but was", c.DumpConfig.DumpFile) 124 | } 125 | } 126 | -------------------------------------------------------------------------------- /pgmgr/pgmgr.go: -------------------------------------------------------------------------------- 1 | package pgmgr 2 | 3 | import ( 4 | "bytes" 5 | "database/sql" 6 | "errors" 7 | "fmt" 8 | "io/ioutil" 9 | "os" 10 | "os/exec" 11 | "path/filepath" 12 | "regexp" 13 | "strconv" 14 | "strings" 15 | "time" 16 | 17 | "github.com/lib/pq" 18 | ) 19 | 20 | // Migration directions 21 | const ( 22 | DOWN = iota 23 | UP 24 | ) 25 | 26 | const datetimeFormat = "20060102130405" 27 | 28 | // Migration directions used for error message building 29 | const ( 30 | MIGRATION = "migration" 31 | ROLLBACK = "rollback" 32 | ) 33 | 34 | // Migration stores a single migration's version and filename. 35 | type Migration struct { 36 | Filename string 37 | Version int64 38 | } 39 | 40 | // WrapInTransaction returns whether the migration should be run within 41 | // a transaction. 42 | func (m Migration) WrapInTransaction() bool { 43 | return !strings.Contains(m.Filename, ".no_txn.") 44 | } 45 | 46 | // Create creates the database specified by the configuration. 47 | func Create(c *Config) error { 48 | if err := c.DumpToEnv(); err != nil { 49 | return err 50 | } 51 | 52 | return sh("createdb", []string{c.Database}) 53 | } 54 | 55 | // Drop drops the database specified by the configuration. 56 | func Drop(c *Config) error { 57 | if err := c.DumpToEnv(); err != nil { 58 | return err 59 | } 60 | 61 | return sh("dropdb", []string{c.Database}) 62 | } 63 | 64 | // Dump dumps the schema and contents of the database to the dump file. 
65 | func Dump(c *Config) error { 66 | if err := c.DumpToEnv(); err != nil { 67 | return err 68 | } 69 | 70 | // See https://www.postgresql.org/docs/11/app-pgdump.html for flag details 71 | 72 | // first we want the structure to be dumped 73 | schemaDump, err := shRead("pg_dump", c.DumpConfig.schemaFlags()) 74 | if err != nil { 75 | return err 76 | } 77 | 78 | // then we want the data to be dumped 79 | dataDump, err := shRead("pg_dump", c.DumpConfig.dataFlags()) 80 | if err != nil { 81 | return err 82 | } 83 | 84 | // combine the results into one dump file. 85 | file, err := os.OpenFile(c.DumpConfig.GetDumpFile(), os.O_CREATE|os.O_TRUNC|os.O_RDWR, 0600) 86 | if err != nil { 87 | return err 88 | } 89 | 90 | file.Write(*schemaDump) 91 | file.Write(*dataDump) 92 | file.Close() 93 | 94 | return nil 95 | } 96 | 97 | // Load loads the database from the dump file using psql. 98 | func Load(c *Config) error { 99 | if err := c.DumpToEnv(); err != nil { 100 | return err 101 | } 102 | 103 | dumpFile := c.DumpConfig.GetDumpFile() 104 | dumpFileRaw := c.DumpConfig.GetDumpFileRaw() 105 | if _, err := os.Stat(dumpFile); os.IsNotExist(err) { 106 | fmt.Println("Dump file does not exist or was not provided. Exiting.") 107 | return nil 108 | } 109 | 110 | if c.DumpConfig.IsCompressed() { 111 | dumpSQL, err := shRead("gunzip", []string{"-c", dumpFile}) 112 | if err != nil { 113 | return err 114 | } 115 | 116 | defer func() { sh("rm", []string{"-f", dumpFileRaw}) }() 117 | 118 | file, err := os.OpenFile(dumpFileRaw, os.O_CREATE|os.O_TRUNC|os.O_RDWR, 0600) 119 | if err != nil { 120 | return err 121 | } 122 | 123 | file.Write(*dumpSQL) 124 | file.Close() 125 | } 126 | 127 | return sh("psql", []string{"-d", c.Database, "-f", dumpFileRaw}) 128 | } 129 | 130 | // Migrate applies un-applied migrations in the specified MigrationFolder. 
131 | func Migrate(c *Config) error { 132 | migrations, err := migrations(c, "up") 133 | if err != nil { 134 | return err 135 | } 136 | 137 | // ensure the version table is created 138 | if err := Initialize(c); err != nil { 139 | return err 140 | } 141 | 142 | appliedAny := false 143 | for _, m := range migrations { 144 | if applied, _ := migrationIsApplied(c, m.Version); !applied { 145 | fmt.Println("== Applying", m.Filename, "==") 146 | t0 := time.Now() 147 | 148 | if err = applyMigration(c, m, UP); err != nil { // halt the migration process and return the error. 149 | printFailedMigrationMessage(err, MIGRATION) 150 | return err 151 | } 152 | 153 | fmt.Println("== Completed in", time.Now().Sub(t0).Nanoseconds()/1e6, "ms ==") 154 | appliedAny = true 155 | } 156 | } 157 | 158 | if !appliedAny { 159 | fmt.Println("Nothing to do; all migrations already applied.") 160 | } 161 | 162 | return nil 163 | } 164 | 165 | // Rollback un-applies the latest migration, if possible. 166 | func Rollback(c *Config) error { 167 | migrations, err := migrations(c, "down") 168 | if err != nil { 169 | return err 170 | } 171 | 172 | v, _ := Version(c) 173 | var toRollback *Migration 174 | for _, m := range migrations { 175 | if m.Version == v { 176 | toRollback = &m 177 | break 178 | } 179 | } 180 | 181 | if toRollback == nil { 182 | return nil 183 | } 184 | 185 | // rollback only the last migration 186 | fmt.Println("== Reverting", toRollback.Filename, "==") 187 | t0 := time.Now() 188 | 189 | if err = applyMigration(c, *toRollback, DOWN); err != nil { 190 | printFailedMigrationMessage(err, ROLLBACK) 191 | return err 192 | } 193 | 194 | fmt.Println("== Completed in", time.Now().Sub(t0).Nanoseconds()/1e6, "ms ==") 195 | 196 | return nil 197 | } 198 | 199 | // Version returns the highest version number stored in the database. This is not 200 | // necessarily enough info to uniquely identify the version, since there may 201 | // be backdated migrations which have not yet applied. 
202 | func Version(c *Config) (int64, error) { 203 | db, err := openConnection(c) 204 | defer db.Close() 205 | 206 | if err != nil { 207 | return -1, err 208 | } 209 | 210 | exists, err := migrationTableExists(c, db) 211 | if err != nil { 212 | return -1, err 213 | } 214 | 215 | if !exists { 216 | return -1, nil 217 | } 218 | 219 | var version int64 220 | err = db.QueryRow(fmt.Sprintf( 221 | `SELECT COALESCE(MAX(version)::text, '-1') FROM %s`, 222 | c.quotedMigrationTable(), 223 | )).Scan(&version) 224 | 225 | return version, err 226 | } 227 | 228 | // Initialize creates the schema_migrations table if necessary. 229 | func Initialize(c *Config) error { 230 | db, err := openConnection(c) 231 | defer db.Close() 232 | if err != nil { 233 | return err 234 | } 235 | 236 | if err := createSchemaUnlessExists(c, db); err != nil { 237 | return err 238 | } 239 | 240 | tableExists, err := migrationTableExists(c, db) 241 | if err != nil { 242 | return err 243 | } 244 | 245 | if tableExists { 246 | return nil 247 | } 248 | 249 | _, err = db.Exec(fmt.Sprintf( 250 | "CREATE TABLE %s (version %s NOT NULL UNIQUE);", 251 | c.quotedMigrationTable(), 252 | c.versionColumnType(), 253 | )) 254 | 255 | if err != nil { 256 | return err 257 | } 258 | 259 | return nil 260 | } 261 | 262 | // CreateMigration generates new, empty migration files. 263 | func CreateMigration(c *Config, name string, noTransaction bool) error { 264 | version := generateVersion(c) 265 | prefix := fmt.Sprint(version, "_", name) 266 | 267 | if noTransaction { 268 | prefix += ".no_txn" 269 | } 270 | 271 | upFilepath := filepath.Join(c.MigrationFolder, prefix+".up.sql") 272 | downFilepath := filepath.Join(c.MigrationFolder, prefix+".down.sql") 273 | 274 | err := ioutil.WriteFile(upFilepath, []byte(`-- Migration goes here.`), 0644) 275 | if err != nil { 276 | return err 277 | } 278 | fmt.Println("Created", upFilepath) 279 | 280 | err = ioutil.WriteFile(downFilepath, []byte(`-- Rollback of migration goes here. 
If you don't want to write it, delete this file.`), 0644) 281 | if err != nil { 282 | return err 283 | } 284 | fmt.Println("Created", downFilepath) 285 | 286 | return nil 287 | } 288 | 289 | func generateVersion(c *Config) string { 290 | // TODO: guarantee no conflicts by incrementing if there is a conflict 291 | t := time.Now() 292 | 293 | if c.Format == "datetime" { 294 | return t.Format(datetimeFormat) 295 | } 296 | 297 | return strconv.FormatInt(t.Unix(), 10) 298 | } 299 | 300 | // need access to the original query contents in order to print it out properly, 301 | // unfortunately. 302 | func formatPgErr(contents *[]byte, pgerr *pq.Error) string { 303 | pos, _ := strconv.Atoi(pgerr.Position) 304 | lineNo := bytes.Count((*contents)[:pos], []byte("\n")) + 1 305 | columnNo := pos - bytes.LastIndex((*contents)[:pos], []byte("\n")) - 1 306 | 307 | return fmt.Sprint("PGERROR: line ", lineNo, " pos ", columnNo, ": ", pgerr.Message, ". ", pgerr.Detail) 308 | } 309 | 310 | type execer interface { 311 | Exec(query string, args ...interface{}) (sql.Result, error) 312 | } 313 | 314 | func applyMigration(c *Config, m Migration, direction int) error { 315 | if c.MigrationDriver == "psql" { 316 | return applyMigrationByPsql(c, m, direction) 317 | } 318 | 319 | return applyMigrationByPq(c, m, direction) 320 | } 321 | 322 | func applyMigrationByPsql(c *Config, m Migration, direction int) error { 323 | if err := c.DumpToEnv(); err != nil { 324 | return err 325 | } 326 | 327 | contents, err := ioutil.ReadFile(filepath.Join(c.MigrationFolder, m.Filename)) 328 | if err != nil { 329 | return err 330 | } 331 | 332 | tmpfile, err := ioutil.TempFile("", "migration") 333 | if err != nil { 334 | return err 335 | } 336 | defer os.Remove(tmpfile.Name()) // clean up 337 | 338 | if _, err := tmpfile.Write(contents); err != nil { 339 | return err 340 | } 341 | 342 | if direction == UP { 343 | tmpfile.WriteString( 344 | fmt.Sprintf("\n; INSERT INTO %s (version) VALUES ('%d');", 
c.quotedMigrationTable(), m.Version), 345 | ) 346 | } else { // DOWN 347 | tmpfile.WriteString( 348 | fmt.Sprintf("\n; DELETE FROM %s WHERE version = '%d';", c.quotedMigrationTable(), m.Version), 349 | ) 350 | } 351 | 352 | if err := tmpfile.Close(); err != nil { 353 | return err 354 | } 355 | 356 | migrationFilePath := tmpfile.Name() 357 | args := []string{"-f", migrationFilePath, "-v", "ON_ERROR_STOP=1"} 358 | 359 | if m.WrapInTransaction() { 360 | args = append(args, "-1") 361 | } 362 | 363 | if err := sh("psql", args); err != nil { 364 | return err 365 | } 366 | 367 | return nil 368 | } 369 | 370 | func applyMigrationByPq(c *Config, m Migration, direction int) error { 371 | var tx *sql.Tx 372 | var exec execer 373 | 374 | rollback := func() { 375 | if tx != nil { 376 | tx.Rollback() 377 | } 378 | } 379 | 380 | contents, err := ioutil.ReadFile(filepath.Join(c.MigrationFolder, m.Filename)) 381 | if err != nil { 382 | return err 383 | } 384 | 385 | db, err := openConnection(c) 386 | defer db.Close() 387 | if err != nil { 388 | return err 389 | } 390 | exec = db 391 | 392 | if m.WrapInTransaction() { 393 | tx, err = db.Begin() 394 | if err != nil { 395 | return err 396 | } 397 | exec = tx 398 | } 399 | 400 | if _, err = exec.Exec(string(contents)); err != nil { 401 | rollback() 402 | return errors.New(formatPgErr(&contents, err.(*pq.Error))) 403 | } 404 | 405 | if direction == UP { 406 | if err = insertSchemaVersion(c, exec, m.Version); err != nil { 407 | rollback() 408 | return errors.New(formatPgErr(&contents, err.(*pq.Error))) 409 | } 410 | } else { 411 | if err = deleteSchemaVersion(c, exec, m.Version); err != nil { 412 | rollback() 413 | return errors.New(formatPgErr(&contents, err.(*pq.Error))) 414 | } 415 | } 416 | 417 | if tx != nil { 418 | err = tx.Commit() 419 | if err != nil { 420 | return err 421 | } 422 | } 423 | 424 | return nil 425 | } 426 | 427 | func createSchemaUnlessExists(c *Config, db *sql.DB) error { 428 | // If there's no schema name in the 
config, we don't need to create the schema. 429 | if !strings.Contains(c.MigrationTable, ".") { 430 | return nil 431 | } 432 | 433 | var exists bool 434 | 435 | schema := strings.SplitN(c.MigrationTable, ".", 2)[0] 436 | err := db.QueryRow( 437 | `SELECT EXISTS(SELECT 1 FROM pg_catalog.pg_namespace WHERE nspname = $1)`, 438 | schema, 439 | ).Scan(&exists) 440 | 441 | if err != nil { 442 | return err 443 | } 444 | 445 | if exists { 446 | return nil 447 | } 448 | 449 | _, err = db.Exec(fmt.Sprintf( 450 | "CREATE SCHEMA %s;", 451 | pq.QuoteIdentifier(schema), 452 | )) 453 | return err 454 | } 455 | 456 | func insertSchemaVersion(c *Config, tx execer, version int64) error { 457 | _, err := tx.Exec( 458 | fmt.Sprintf(`INSERT INTO %s (version) VALUES ($1) RETURNING version;`, c.quotedMigrationTable()), 459 | typedVersion(c, version), 460 | ) 461 | return err 462 | } 463 | 464 | func deleteSchemaVersion(c *Config, tx execer, version int64) error { 465 | _, err := tx.Exec( 466 | fmt.Sprintf(`DELETE FROM %s WHERE version = $1`, c.quotedMigrationTable()), 467 | typedVersion(c, version), 468 | ) 469 | return err 470 | } 471 | 472 | func typedVersion(c *Config, version int64) interface{} { 473 | if c.ColumnType == "string" { 474 | return strconv.FormatInt(version, 10) 475 | } 476 | return version 477 | } 478 | 479 | func migrationTableExists(c *Config, db *sql.DB) (bool, error) { 480 | var hasTable bool 481 | var err error 482 | 483 | if strings.Contains(c.MigrationTable, ".") { 484 | tokens := strings.SplitN(c.MigrationTable, ".", 2) 485 | err = db.QueryRow( 486 | `SELECT EXISTS(SELECT 1 FROM pg_catalog.pg_tables WHERE schemaname = $1 AND tablename = $2)`, 487 | tokens[0], tokens[1], 488 | ).Scan(&hasTable) 489 | } else { 490 | err = db.QueryRow( 491 | `SELECT EXISTS(SELECT 1 FROM pg_catalog.pg_tables WHERE tablename = $1)`, 492 | c.MigrationTable, 493 | ).Scan(&hasTable) 494 | } 495 | 496 | return hasTable, err 497 | } 498 | 499 | func migrationIsApplied(c *Config, version 
int64) (bool, error) { 500 | db, err := openConnection(c) 501 | defer db.Close() 502 | if err != nil { 503 | return false, err 504 | } 505 | 506 | var applied bool 507 | err = db.QueryRow( 508 | fmt.Sprintf(`SELECT EXISTS(SELECT 1 FROM %s WHERE version = $1)`, c.quotedMigrationTable()), 509 | version, 510 | ).Scan(&applied) 511 | 512 | if err != nil { 513 | return false, err 514 | } 515 | 516 | return applied, nil 517 | } 518 | 519 | func openConnection(c *Config) (*sql.DB, error) { 520 | db, err := sql.Open("postgres", SQLConnectionString(c)) 521 | if err != nil { 522 | return nil, err 523 | } 524 | return db, db.Ping() 525 | } 526 | 527 | // SQLConnectionString formats the values pulled from the config into a connection string 528 | func SQLConnectionString(c *Config) string { 529 | args := make([]interface{}, 0) 530 | 531 | if c.Username != "" { 532 | args = append(args, " user='", c.Username, "'") 533 | } 534 | 535 | if c.Database != "" { 536 | args = append(args, " dbname='", c.Database, "'") 537 | } 538 | 539 | if c.Password != "" { 540 | args = append(args, " password='", c.Password, "'") 541 | } 542 | 543 | args = append(args, 544 | " host='", c.Host, "'", 545 | " port=", c.Port, 546 | " sslmode=", c.SslMode) 547 | 548 | return fmt.Sprint(args...) 
549 | } 550 | 551 | func migrations(c *Config, direction string) ([]Migration, error) { 552 | re := regexp.MustCompile("^[0-9]+") 553 | 554 | migrations := []Migration{} 555 | files, err := ioutil.ReadDir(c.MigrationFolder) 556 | if err != nil { 557 | return migrations, err 558 | } 559 | 560 | for _, file := range files { 561 | if match, _ := regexp.MatchString("^[0-9]+_.+\\."+direction+"\\.sql$", file.Name()); match { 562 | version, _ := strconv.ParseInt(re.FindString(file.Name()), 10, 64) 563 | migrations = append(migrations, Migration{Filename: file.Name(), Version: version}) 564 | } 565 | } 566 | 567 | return migrations, nil 568 | } 569 | 570 | func sh(command string, args []string) error { 571 | c := exec.Command(command, args...) 572 | output, err := c.CombinedOutput() 573 | fmt.Println(string(output)) 574 | if err != nil { 575 | return err 576 | } 577 | 578 | return nil 579 | } 580 | 581 | func shRead(command string, args []string) (*[]byte, error) { 582 | c := exec.Command(command, args...) 583 | output, err := c.CombinedOutput() 584 | return &output, err 585 | } 586 | 587 | func printFailedMigrationMessage(err error, migrationType string) { 588 | fmt.Fprintf(os.Stderr, err.Error()) 589 | fmt.Fprintf(os.Stderr, "\n\n") 590 | fmt.Fprintf(os.Stderr, "ERROR! 
Aborting the "+migrationType+" process.\n") 591 | } 592 | -------------------------------------------------------------------------------- /pgmgr/pgmgr_test.go: -------------------------------------------------------------------------------- 1 | package pgmgr 2 | 3 | import ( 4 | "fmt" 5 | "io/ioutil" 6 | "os/exec" 7 | "path" 8 | "path/filepath" 9 | "strings" 10 | "testing" 11 | "time" 12 | ) 13 | 14 | const ( 15 | testDBName = "pgmgr_testdb" 16 | migrationFolder = "/tmp/migrations/" 17 | dumpFile = "/tmp/pgmgr_dump.sql" 18 | ) 19 | 20 | func globalConfig() *Config { 21 | return &Config{ 22 | Username: "pgmgr", 23 | Database: testDBName, 24 | Host: "localhost", 25 | Port: 5432, 26 | MigrationFolder: migrationFolder, 27 | MigrationTable: "schema_migrations", 28 | SslMode: "disable", 29 | DumpConfig: DumpConfig{DumpFile: dumpFile, NoCompress: true}, 30 | } 31 | } 32 | 33 | func TestCreate(t *testing.T) { 34 | dropDB(t) 35 | 36 | if err := Create(globalConfig()); err != nil { 37 | t.Log(err) 38 | t.Fatal("Could not create database") 39 | } 40 | 41 | // if we can't remove that db, it couldn't have been created by us above. 
42 | if err := dropDB(t); err != nil { 43 | t.Fatal("database doesn't seem to have been created!") 44 | } 45 | } 46 | 47 | func TestDrop(t *testing.T) { 48 | if err := createDB(t); err != nil { 49 | t.Fatal("createdb failed: ", err) 50 | } 51 | 52 | if err := Drop(globalConfig()); err != nil { 53 | t.Log(err) 54 | t.Fatal("Could not drop database") 55 | } 56 | 57 | if err := createDB(t); err != nil { 58 | t.Fatal("database doesn't seem to have been dropped!") 59 | } 60 | } 61 | 62 | func TestDump(t *testing.T) { 63 | resetDB(t) 64 | psqlMustExec(t, `CREATE TABLE bars (bar_id INTEGER);`) 65 | psqlMustExec(t, `INSERT INTO bars (bar_id) VALUES (123), (456);`) 66 | psqlMustExec(t, `CREATE TABLE foos (foo_id INTEGER);`) 67 | psqlMustExec(t, `INSERT INTO foos (foo_id) VALUES (789);`) 68 | 69 | c := globalConfig() 70 | err := Dump(c) 71 | 72 | if err != nil { 73 | t.Log(err) 74 | t.Fatal("Could not dump database to file") 75 | } 76 | 77 | file, err := ioutil.ReadFile(dumpFile) 78 | if err != nil { 79 | t.Log(err) 80 | t.Fatal("Could not read dump") 81 | } 82 | 83 | if !(strings.Contains(string(file), "CREATE TABLE bars") || strings.Contains(string(file), "CREATE TABLE public.bars")) { 84 | t.Log(string(file)) 85 | t.Fatal("dump does not contain the table definition") 86 | } 87 | 88 | if !strings.Contains(string(file), "123") { 89 | t.Fatal("dump does not contain the table data when --seed-tables is not specified") 90 | } 91 | 92 | c.DumpConfig.IncludeTables = append(c.DumpConfig.IncludeTables, "foos") 93 | err = Dump(c) 94 | 95 | if err != nil { 96 | t.Log(err) 97 | t.Fatal("Could not dump database to file") 98 | } 99 | 100 | file, err = ioutil.ReadFile(dumpFile) 101 | if err != nil { 102 | t.Log(err) 103 | t.Fatal("Could not read dump") 104 | } 105 | 106 | if strings.Contains(string(file), "123") { 107 | t.Fatal("dump contains table data for non-seed tables, when --seed-tables was given") 108 | } 109 | 110 | if !strings.Contains(string(file), "789") { 111 | t.Fatal("dump 
does not contain table data for seed tables, when --seed-tables was given") 112 | } 113 | } 114 | 115 | func TestLoad(t *testing.T) { 116 | resetDB(t) 117 | 118 | ioutil.WriteFile(dumpFile, []byte(` 119 | CREATE TABLE foos (foo_id INTEGER); 120 | INSERT INTO foos (foo_id) VALUES (1), (2), (3); 121 | `), 0644) 122 | 123 | err := Load(globalConfig()) 124 | 125 | if err != nil { 126 | t.Log(err) 127 | t.Fatal("Could not load database from file") 128 | } 129 | 130 | psqlMustExec(t, `SELECT * FROM foos;`) 131 | } 132 | 133 | func TestInitialize(t *testing.T) { 134 | config := globalConfig() 135 | 136 | // Default config should create public.schema_migrations 137 | resetDB(t) 138 | 139 | if err := Initialize(config); err != nil { 140 | t.Fatal("Initialize failed:", err) 141 | } 142 | 143 | if err := Initialize(config); err != nil { 144 | t.Fatal("Initialize was not safe to run twice:", err) 145 | } 146 | 147 | psqlMustExec(t, `SELECT * FROM public.schema_migrations;`) 148 | 149 | // If we specify a table, it should create public. 150 | resetDB(t) 151 | config.MigrationTable = "applied_migrations" 152 | 153 | if err := Initialize(config); err != nil { 154 | t.Fatal("Initialize failed: ", err) 155 | } 156 | 157 | psqlMustExec(t, `SELECT * FROM public.applied_migrations;`) 158 | 159 | // If we specify a schema-qualified table, the schema should be 160 | // created if it does not yet exist. 161 | resetDB(t) 162 | config.MigrationTable = "pgmgr.applied_migrations" 163 | if err := Initialize(config); err != nil { 164 | t.Fatal("Initialize failed: ", err) 165 | } 166 | 167 | psqlMustExec(t, `SELECT * FROM pgmgr.applied_migrations`) 168 | 169 | // If we specify a schema-qualified table, and the schema already existed, 170 | // that's fine too. 
171 | resetDB(t) 172 | psqlMustExec(t, `CREATE SCHEMA pgmgr;`) 173 | if err := Initialize(config); err != nil { 174 | t.Fatal("Initialize failed: ", err) 175 | } 176 | 177 | psqlMustExec(t, `SELECT * FROM pgmgr.applied_migrations`) 178 | } 179 | 180 | func TestVersion(t *testing.T) { 181 | resetDB(t) 182 | 183 | version, err := Version(globalConfig()) 184 | 185 | if err != nil { 186 | t.Log(err) 187 | t.Fatal("Could not fetch version info") 188 | } 189 | 190 | if version != -1 { 191 | t.Fatal("expected version to be -1 before table exists, got", version) 192 | } 193 | 194 | Initialize(globalConfig()) 195 | psqlMustExec(t, `INSERT INTO schema_migrations (version) VALUES (1);`) 196 | 197 | version, err = Version(globalConfig()) 198 | if version != 1 { 199 | t.Fatal("expected version to be 1, got", version) 200 | } 201 | } 202 | 203 | func TestColumnTypeString(t *testing.T) { 204 | resetDB(t) 205 | 206 | config := globalConfig() 207 | config.ColumnType = "string" 208 | Initialize(config) 209 | 210 | psqlMustExec(t, `INSERT INTO schema_migrations (version) VALUES ('20150910120933');`) 211 | version, err := Version(config) 212 | if err != nil { 213 | t.Fatal(err) 214 | } 215 | 216 | if version != 20150910120933 { 217 | t.Fatal("expected version to be 20150910120933, got", version) 218 | } 219 | } 220 | 221 | func TestMigrate(t *testing.T) { 222 | resetDB(t) 223 | clearMigrationFolder(t) 224 | 225 | // add our first migration 226 | writeMigration(t, "002_this_is_a_migration.up.sql", ` 227 | CREATE TABLE foos (foo_id INTEGER); 228 | INSERT INTO foos (foo_id) VALUES (1), (2), (3); 229 | `) 230 | 231 | writeMigration(t, "002_this_is_a_migration.down.sql", `DROP TABLE foos;`) 232 | 233 | err := Migrate(globalConfig()) 234 | 235 | if err != nil { 236 | t.Log(err) 237 | t.Fatal("Migrations failed to run.") 238 | } 239 | 240 | // test simple idempotency 241 | err = Migrate(globalConfig()) 242 | if err != nil { 243 | t.Log(err) 244 | t.Fatal("Running migrations again was not 
idempotent!") 245 | } 246 | 247 | psqlMustExec(t, `SELECT * FROM foos;`) 248 | 249 | // add a new migration with an older version, as if another dev's branch was merged in 250 | writeMigration(t, "001_this_is_an_older_migration.up.sql", ` 251 | CREATE TABLE bars (bar_id INTEGER); 252 | INSERT INTO bars (bar_id) VALUES (4), (5), (6); 253 | `) 254 | 255 | err = Migrate(globalConfig()) 256 | if err != nil { 257 | t.Log(err) 258 | t.Fatal("Could not apply second migration!") 259 | } 260 | 261 | psqlMustExec(t, `SELECT * FROM bars;`) 262 | 263 | // Make a filename that would match a vim .swp file 264 | writeMigration(t, ".003_this_is_an_older_migration.up.sql.swp", ` 265 | CREATE TABLE baz (baz_id INTEGER); 266 | INSERT INTO baz (baz_id) VALUES (4), (5), (6); 267 | `) 268 | 269 | err = Migrate(globalConfig()) 270 | if err != nil { 271 | t.Log(err) 272 | t.Fatal("Migration returned an error instead of being skipped!") 273 | } 274 | 275 | psqlMustNotExec(t, `SELECT * FROM baz;`) 276 | 277 | // rollback the initial migration, since it has the latest version 278 | err = Rollback(globalConfig()) 279 | 280 | if err := psqlExec(t, `SELECT * FROM foos;`); err == nil { 281 | t.Fatal("Should not have been able to select from foos table") 282 | } 283 | 284 | v, err := Version(globalConfig()) 285 | if err != nil || v != 1 { 286 | t.Log(err) 287 | t.Fatal("Rollback did not reset version! 
Still on version ", v) 288 | } 289 | } 290 | 291 | func TestMigrateColumnTypeString(t *testing.T) { 292 | resetDB(t) 293 | clearMigrationFolder(t) 294 | 295 | config := globalConfig() 296 | config.ColumnType = "string" 297 | 298 | // migrate up 299 | writeMigration(t, "20150910120933_some_migration.up.sql", ` 300 | CREATE TABLE foos (foo_id INTEGER); 301 | INSERT INTO foos (foo_id) VALUES (1), (2), (3); 302 | `) 303 | 304 | err := Migrate(config) 305 | if err != nil { 306 | t.Fatal(err) 307 | } 308 | 309 | v, err := Version(config) 310 | if err != nil { 311 | t.Fatal(err) 312 | } 313 | 314 | if v != 20150910120933 { 315 | t.Fatal("Expected version 20150910120933 after migration, got", v) 316 | } 317 | 318 | // migrate down 319 | writeMigration(t, "20150910120933_some_migration.down.sql", `DROP TABLE foos;`) 320 | 321 | err = Rollback(config) 322 | if err != nil { 323 | t.Fatal(err) 324 | } 325 | 326 | v, err = Version(config) 327 | if err != nil { 328 | t.Fatal(err) 329 | } 330 | 331 | if v != -1 { 332 | t.Fatal("Expected version -1 after rollback, got", v) 333 | } 334 | } 335 | 336 | func TestMigrateNoTransaction(t *testing.T) { 337 | resetDB(t) 338 | clearMigrationFolder(t) 339 | 340 | // CREATE INDEX CONCURRENTLY can not run inside a transaction, so we can assert 341 | // that no transaction was used by verifying it ran successfully. 
342 | writeMigration(t, "001_create_foos.up.sql", `CREATE TABLE foos (foo_id INTEGER);`) 343 | writeMigration(t, "002_index_foos.no_txn.up.sql", `CREATE INDEX CONCURRENTLY idx_foo_id ON foos(foo_id);`) 344 | 345 | err := Migrate(globalConfig()) 346 | if err != nil { 347 | t.Fatal(err) 348 | } 349 | } 350 | 351 | func TestMigrateCustomMigrationTable(t *testing.T) { 352 | resetDB(t) 353 | clearMigrationFolder(t) 354 | writeMigration(t, "001_create_foos.up.sql", `CREATE TABLE foos (foo_id INTEGER);`) 355 | 356 | config := globalConfig() 357 | config.MigrationTable = "pgmgr.migrations" 358 | if err := Migrate(config); err != nil { 359 | t.Fatal(err) 360 | } 361 | 362 | v, err := Version(config) 363 | if err != nil { 364 | t.Fatal(err) 365 | } 366 | 367 | if v != 1 { 368 | t.Fatal("Expected version 1, got ", v) 369 | } 370 | } 371 | 372 | func TestMigratePsqlDriver(t *testing.T) { 373 | resetDB(t) 374 | clearMigrationFolder(t) 375 | 376 | writeMigration(t, "001_create_foos.up.sql", `CREATE TABLE foos (foo_id INTEGER, val BOOLEAN);`) 377 | writeMigration(t, "001_create_foos.down.sql", `DROP TABLE foos;`) 378 | writeMigration(t, "002_index_foos.no_txn.up.sql", `CREATE INDEX CONCURRENTLY ON foos (foo_id); CREATE INDEX CONCURRENTLY ON foos (val);`) 379 | writeMigration(t, "002_index_foos.no_txn.down.sql", ``) 380 | 381 | config := globalConfig() 382 | config.MigrationDriver = "psql" 383 | 384 | if err := Migrate(config); err != nil { 385 | t.Fatal(err) 386 | } 387 | 388 | v, err := Version(config) 389 | if err != nil { 390 | t.Fatal(err) 391 | } 392 | 393 | if v != 2 { 394 | t.Fatal("expected version 2, got", v) 395 | } 396 | 397 | if err := psqlExec(t, `SELECT * FROM foos;`); err != nil { 398 | t.Fatal("foos table is not queryable -- does it exist?") 399 | } 400 | 401 | if err := Rollback(config); err != nil { 402 | t.Fatal("rollback of 2 failed") 403 | } 404 | 405 | v, _ = Version(config) 406 | if v != 1 { 407 | t.Fatal("expected version 1, got", v) 408 | } 409 | 410 | if 
err := Rollback(config); err != nil { 411 | t.Fatal("rollback of 1 failed") 412 | } 413 | 414 | v, _ = Version(config) 415 | if v != -1 { 416 | t.Fatal("expected version -1, got", v) 417 | } 418 | 419 | if err := psqlExec(t, `SELECT * FROM foos;`); err == nil { 420 | t.Fatal("foos table is queryable but should not be") 421 | } 422 | } 423 | 424 | func TestMigratePsqlAddsSemicolon(t *testing.T) { 425 | resetDB(t) 426 | clearMigrationFolder(t) 427 | 428 | writeMigration(t, "001_create_foos.up.sql", `CREATE TABLE foos (foo_id INTEGER, val BOOLEAN)`) 429 | 430 | config := globalConfig() 431 | config.MigrationDriver = "psql" 432 | 433 | if err := Migrate(config); err != nil { 434 | t.Fatal(err) 435 | } 436 | 437 | v, err := Version(config) 438 | if err != nil { 439 | t.Fatal(err) 440 | } 441 | 442 | if v != 1 { 443 | t.Fatal("Expected migration to apply; did not -- version is still: ", v) 444 | } 445 | } 446 | 447 | func TestMigratePsqlAddsNewline(t *testing.T) { 448 | resetDB(t) 449 | clearMigrationFolder(t) 450 | 451 | writeMigration(t, "001_create_foos.up.sql", `CREATE TABLE foos (foo_id INTEGER, val BOOLEAN); --okay, all done!`) 452 | 453 | config := globalConfig() 454 | config.MigrationDriver = "psql" 455 | 456 | if err := Migrate(config); err != nil { 457 | t.Fatal(err) 458 | } 459 | 460 | v, err := Version(config) 461 | if err != nil { 462 | t.Fatal(err) 463 | } 464 | 465 | if v != 1 { 466 | t.Fatal("Expected migration to apply; did not -- version is still: ", v) 467 | } 468 | } 469 | 470 | func TestMigratePsqlFailedMigration(t *testing.T) { 471 | resetDB(t) 472 | clearMigrationFolder(t) 473 | 474 | writeMigration(t, "001_oops.up.sql", ` 475 | CREATE TABLE bars (bar_id INTEGER); 476 | CREATE TABLE foos (foo_id INTEGER, val BOOLEAN; 477 | `) // syntax error! 
478 | 479 | config := globalConfig() 480 | config.MigrationDriver = "psql" 481 | 482 | if err := Migrate(config); err == nil { 483 | t.Fatal("Expected malformed migration to raise error, but got none") 484 | } 485 | 486 | if err := psqlExec(t, `SELECT * FROM bars;`); err == nil { 487 | t.Fatal("Migration partially applied when it should've rolled back") 488 | } 489 | } 490 | 491 | func TestCreateMigration(t *testing.T) { 492 | clearMigrationFolder(t) 493 | 494 | assertFileExists := func(filename string) { 495 | err := testSh(t, "stat", []string{filepath.Join(migrationFolder, filename)}) 496 | if err != nil { 497 | t.Fatal(err) 498 | } 499 | } 500 | 501 | expectedVersion := time.Now().Unix() 502 | err := CreateMigration(globalConfig(), "new_migration", false) 503 | if err != nil { 504 | t.Fatal(err) 505 | } 506 | 507 | assertFileExists(fmt.Sprint(expectedVersion, "_new_migration.up.sql")) 508 | assertFileExists(fmt.Sprint(expectedVersion, "_new_migration.down.sql")) 509 | 510 | expectedStringVersion := time.Now().Format(datetimeFormat) 511 | config := globalConfig() 512 | config.Format = "datetime" 513 | err = CreateMigration(config, "rails_style", false) 514 | if err != nil { 515 | t.Fatal(err) 516 | } 517 | 518 | assertFileExists(fmt.Sprint(expectedStringVersion, "_rails_style.up.sql")) 519 | assertFileExists(fmt.Sprint(expectedStringVersion, "_rails_style.down.sql")) 520 | 521 | err = CreateMigration(config, "create_index", true) 522 | if err != nil { 523 | t.Fatal(err) 524 | } 525 | 526 | assertFileExists(fmt.Sprint(expectedStringVersion, "_create_index.no_txn.up.sql")) 527 | assertFileExists(fmt.Sprint(expectedStringVersion, "_create_index.no_txn.down.sql")) 528 | } 529 | 530 | func TestRollback(t *testing.T) { 531 | resetDB(t) 532 | clearMigrationFolder(t) 533 | 534 | writeMigration(t, "001_a_migration.up.sql", ` 535 | CREATE TABLE foos (foo_id INTEGER); 536 | INSERT INTO foos (foo_id) VALUES (1), (2), (3); 537 | `) 538 | 539 | writeMigration(t, 
"001_a_migration.down.sql", `DROP TABLE foos;`) 540 | 541 | err := Migrate(globalConfig()) 542 | if err != nil { 543 | t.Log(err) 544 | t.Fatal("Migrations failed to run.") 545 | } 546 | 547 | err = Rollback(globalConfig()) 548 | if err != nil { 549 | t.Log(err) 550 | t.Fatal("Failed to rollback.") 551 | } 552 | 553 | psqlMustNotExec(t, "SELECT * FROM bars;") 554 | } 555 | 556 | func TestRollbackFailed(t *testing.T) { 557 | resetDB(t) 558 | clearMigrationFolder(t) 559 | 560 | writeMigration(t, "001_a_migration.up.sql", ` 561 | CREATE TABLE foos (foo_id INTEGER); 562 | INSERT INTO foos (foo_id) VALUES (1), (2), (3); 563 | `) 564 | 565 | // Note the syntax error in the SQL 566 | writeMigration(t, "001_a_migration.down.sql", `DRO TABLE foos;`) 567 | 568 | err := Migrate(globalConfig()) 569 | if err != nil { 570 | t.Log(err) 571 | t.Fatal("Migrations failed to run.") 572 | } 573 | 574 | err = Rollback(globalConfig()) 575 | if err == nil { 576 | t.Fatal("Rollback succeeded when it shouldn't have.") 577 | } 578 | } 579 | 580 | // redundant, but I'm also lazy 581 | func testSh(t *testing.T, command string, args []string) error { 582 | c := exec.Command(command, args...) 583 | output, err := c.CombinedOutput() 584 | t.Log(string(output)) 585 | if err != nil { 586 | return err 587 | } 588 | 589 | return nil 590 | } 591 | 592 | func psqlExec(t *testing.T, statement string) error { 593 | return testSh(t, "psql", []string{"-d", testDBName, "-c", statement}) 594 | } 595 | 596 | func psqlMustExec(t *testing.T, statement string) { 597 | err := psqlExec(t, statement) 598 | if err != nil { 599 | t.Fatalf("Failed to execute statement: '%s': %s", statement, err) 600 | } 601 | } 602 | 603 | func psqlMustNotExec(t *testing.T, statement string) { 604 | err := psqlExec(t, statement) 605 | 606 | // If there is no error, the statement successfully executed. 607 | // We don't want that to happen. 
608 | if err == nil { 609 | t.Fatalf("SQL statement executed when it should not have: '%s'", statement) 610 | } 611 | } 612 | 613 | func resetDB(t *testing.T) { 614 | if err := dropDB(t); err != nil { 615 | t.Fatal("dropdb failed: ", err) 616 | } 617 | 618 | if err := createDB(t); err != nil { 619 | t.Fatal("createdb failed: ", err) 620 | } 621 | } 622 | 623 | func dropDB(t *testing.T) error { 624 | return testSh(t, "dropdb", []string{testDBName}) 625 | } 626 | 627 | func createDB(t *testing.T) error { 628 | return testSh(t, "createdb", []string{testDBName}) 629 | } 630 | 631 | func clearMigrationFolder(t *testing.T) { 632 | testSh(t, "rm", []string{"-r", migrationFolder}) 633 | testSh(t, "mkdir", []string{migrationFolder}) 634 | } 635 | 636 | func writeMigration(t *testing.T, name, contents string) { 637 | filename := path.Join(migrationFolder, name) 638 | err := ioutil.WriteFile(filename, []byte(contents), 0644) 639 | if err != nil { 640 | t.Fatalf("Failed to write %s: %s", filename, err) 641 | } 642 | } 643 | --------------------------------------------------------------------------------