├── gopher.png ├── platform-test ├── go.mod ├── go.sum └── platform_test.go ├── test_config.go ├── go.mod ├── examples ├── migrations │ └── 20191121120000_init.sql ├── go.mod ├── server_test.go ├── examples_test.go └── go.sum ├── rename_test.go ├── rename.go ├── cache_locator_test.go ├── CONTRIBUTING.md ├── LICENSE ├── cache_locator.go ├── .circleci └── config.yml ├── .golangci.yml ├── go.sum ├── logging.go ├── version_strategy.go ├── logging_test.go ├── test_util_test.go ├── .github └── workflows │ └── build.yml ├── .gitignore ├── decompression.go ├── version_strategy_test.go ├── prepare_database.go ├── remote_fetch.go ├── config.go ├── decompression_test.go ├── README.md ├── prepare_database_test.go ├── embedded_postgres.go ├── remote_fetch_test.go └── embedded_postgres_test.go /gopher.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fergusstrange/embedded-postgres/HEAD/gopher.png -------------------------------------------------------------------------------- /platform-test/go.mod: -------------------------------------------------------------------------------- 1 | module github.com/fergusstrange/embedded-postgres/platform-test 2 | 3 | replace github.com/fergusstrange/embedded-postgres => ../ 4 | 5 | go 1.18 6 | 7 | require github.com/fergusstrange/embedded-postgres v0.0.0 8 | 9 | require ( 10 | github.com/lib/pq v1.10.9 // indirect 11 | github.com/xi2/xz v0.0.0-20171230120015-48954b6210f8 // indirect 12 | ) 13 | -------------------------------------------------------------------------------- /test_config.go: -------------------------------------------------------------------------------- 1 | package embeddedpostgres 2 | 3 | import "testing" 4 | 5 | func TestGetConnectionURL(t *testing.T) { 6 | config := DefaultConfig().Database("mydb").Username("myuser").Password("mypass") 7 | expect := "postgresql://myuser:mypass@localhost:5432/mydb" 8 | 9 | if got := config.GetConnectionURL(); got != 
expect { 10 | t.Errorf("expected \"%s\" got \"%s\"", expect, got) 11 | } 12 | } 13 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/fergusstrange/embedded-postgres 2 | 3 | go 1.18 4 | 5 | require ( 6 | github.com/lib/pq v1.10.9 7 | github.com/stretchr/testify v1.10.0 8 | github.com/xi2/xz v0.0.0-20171230120015-48954b6210f8 9 | go.uber.org/goleak v1.3.0 10 | ) 11 | 12 | require ( 13 | github.com/davecgh/go-spew v1.1.1 // indirect 14 | github.com/kr/text v0.2.0 // indirect 15 | github.com/pmezard/go-difflib v1.0.0 // indirect 16 | gopkg.in/yaml.v3 v3.0.1 // indirect 17 | ) 18 | -------------------------------------------------------------------------------- /examples/migrations/20191121120000_init.sql: -------------------------------------------------------------------------------- 1 | -- +goose Up 2 | -- SQL in this section is executed when the migration is applied. 3 | CREATE TABLE beer_catalogue 4 | ( 5 | id SERIAL PRIMARY KEY, 6 | name TEXT, 7 | consumed BOOL DEFAULT TRUE, 8 | rating DOUBLE PRECISION 9 | ); 10 | 11 | INSERT INTO beer_catalogue (name, consumed, rating) 12 | VALUES ('Punk IPA', true, 68.29); 13 | 14 | -- +goose Down 15 | -- SQL in this section is executed when the migration is rolled back. 16 | DROP TABLE beer_catalogue;
-------------------------------------------------------------------------------- /rename_test.go: -------------------------------------------------------------------------------- 1 | package embeddedpostgres 2 | 3 | import ( 4 | "os" 5 | "testing" 6 | 7 | "github.com/stretchr/testify/assert" 8 | "github.com/stretchr/testify/require" 9 | ) 10 | 11 | func Test_renameOrIgnore_NoErrorOnEEXIST(t *testing.T) { 12 | tmpDir, err := os.MkdirTemp("", "test_dir") 13 | require.NoError(t, err) 14 | 15 | tmpFil, err := os.CreateTemp("", "test_file") 16 | require.NoError(t, err) 17 | 18 | // os.Rename would return an error here, ensure that the error is handled and returned as nil 19 | err = renameOrIgnore(tmpFil.Name(), tmpDir) 20 | assert.NoError(t, err) 21 | } 22 | -------------------------------------------------------------------------------- /examples/go.mod: -------------------------------------------------------------------------------- 1 | module github.com/fergusstrange/embedded-postgres/examples 2 | 3 | go 1.18 4 | 5 | replace github.com/fergusstrange/embedded-postgres => ../ 6 | 7 | require ( 8 | github.com/fergusstrange/embedded-postgres v0.0.0 9 | github.com/jmoiron/sqlx v1.3.5 10 | github.com/lib/pq v1.10.9 11 | github.com/pressly/goose/v3 v3.0.1 12 | go.uber.org/zap v1.21.0 13 | ) 14 | 15 | require ( 16 | github.com/pkg/errors v0.9.1 // indirect 17 | github.com/xi2/xz v0.0.0-20171230120015-48954b6210f8 // indirect 18 | go.uber.org/atomic v1.7.0 // indirect 19 | go.uber.org/multierr v1.6.0 // indirect 20 | ) 21 | -------------------------------------------------------------------------------- /platform-test/go.sum: -------------------------------------------------------------------------------- 1 | github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= 2 | github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= 3 | github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= 4 | 
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= 5 | github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= 6 | github.com/xi2/xz v0.0.0-20171230120015-48954b6210f8 h1:nIPpBwaJSVYIxUFsDv3M8ofmx9yWTog9BfvIu0q41lo= 7 | github.com/xi2/xz v0.0.0-20171230120015-48954b6210f8/go.mod h1:HUYIGzjTL3rfEspMxjDjgmT5uz5wzYJKVo23qUhYTos= 8 | go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= 9 | gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= 10 | -------------------------------------------------------------------------------- /rename.go: -------------------------------------------------------------------------------- 1 | package embeddedpostgres 2 | 3 | import ( 4 | "errors" 5 | "os" 6 | "syscall" 7 | ) 8 | 9 | // renameOrIgnore will rename the oldpath to the newpath. 10 | // 11 | // On Unix this will be a safe atomic operation. 12 | // On Windows this will do nothing if the new path already exists. 13 | // 14 | // This is only safe to use if you can be sure that the newpath is either missing, or contains the same data as the 15 | // old path. 16 | func renameOrIgnore(oldpath, newpath string) error { 17 | err := os.Rename(oldpath, newpath) 18 | 19 | // if the error is due to syscall.EEXIST then this is most likely windows, and a race condition with 20 | // multiple downloads of the file. 
We can assume that the existing file is the correct one and ignore 21 | // the error 22 | if errors.Is(err, syscall.EEXIST) { 23 | return nil 24 | } 25 | 26 | return err 27 | } 28 | -------------------------------------------------------------------------------- /cache_locator_test.go: -------------------------------------------------------------------------------- 1 | package embeddedpostgres 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/stretchr/testify/assert" 7 | ) 8 | 9 | func Test_defaultCacheLocator_NotExists(t *testing.T) { 10 | locator := defaultCacheLocator("", func() (string, string, PostgresVersion) { 11 | return "a", "b", "1.2.3" 12 | }) 13 | 14 | cacheLocation, exists := locator() 15 | 16 | assert.Contains(t, cacheLocation, ".embedded-postgres-go/embedded-postgres-binaries-a-b-1.2.3.txz") 17 | assert.False(t, exists) 18 | } 19 | 20 | func Test_defaultCacheLocator_CustomPath(t *testing.T) { 21 | locator := defaultCacheLocator("/custom/path", func() (string, string, PostgresVersion) { 22 | return "a", "b", "1.2.3" 23 | }) 24 | 25 | cacheLocation, exists := locator() 26 | 27 | assert.Equal(t, cacheLocation, "/custom/path/embedded-postgres-binaries-a-b-1.2.3.txz") 28 | assert.False(t, exists) 29 | } 30 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to embedded-postgres 2 | 3 | Thank you for taking the time to contribute. These are mostly guidelines, not rules. Use your best judgment, and feel 4 | free to propose changes to this document in a pull request. 5 | 6 | # Working with forked go repos 7 | 8 | If you haven't worked with forked go repos before, take a look at this blog post for some excellent advice 9 | about [contributing to go open source git repositories](https://splice.com/blog/contributing-open-source-git-repositories-go/) 10 | . 
11 | 12 | # PRs 13 | 14 | - Please open PRs against master. 15 | - We prefer single-commit PRs, but sometimes multiple commits are justified - use your best judgment. 16 | - Please add/modify tests to cover the proposed code changes. 17 | - If the PR contains a new feature, please document it in the README. 18 | 19 | # Documentation 20 | 21 | For simple typo fixes and documentation improvements, feel free to raise a PR without raising an issue on GitHub. For 22 | anything more complicated, please file an issue. -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Fergus Strange 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE.
22 | -------------------------------------------------------------------------------- /cache_locator.go: -------------------------------------------------------------------------------- 1 | package embeddedpostgres 2 | 3 | import ( 4 | "fmt" 5 | "os" 6 | "path/filepath" 7 | ) 8 | 9 | // CacheLocator retrieves the location of the Postgres binary cache returning it to location. 10 | // The result of whether this cache is present will be returned to exists. 11 | type CacheLocator func() (location string, exists bool) 12 | 13 | func defaultCacheLocator(cacheDirectory string, versionStrategy VersionStrategy) CacheLocator { 14 | return func() (string, bool) { 15 | if cacheDirectory == "" { 16 | cacheDirectory = ".embedded-postgres-go" 17 | if userHome, err := os.UserHomeDir(); err == nil { 18 | cacheDirectory = filepath.Join(userHome, ".embedded-postgres-go") 19 | } 20 | } 21 | 22 | operatingSystem, architecture, version := versionStrategy() 23 | cacheLocation := filepath.Join(cacheDirectory, 24 | fmt.Sprintf("embedded-postgres-binaries-%s-%s-%s.txz", 25 | operatingSystem, 26 | architecture, 27 | version)) 28 | 29 | info, err := os.Stat(cacheLocation) 30 | 31 | if err != nil { 32 | return cacheLocation, os.IsExist(err) && !info.IsDir() 33 | } 34 | 35 | return cacheLocation, !info.IsDir() 36 | } 37 | } 38 | -------------------------------------------------------------------------------- /.circleci/config.yml: -------------------------------------------------------------------------------- 1 | version: 2.1 2 | executors: 3 | linux-arm64: 4 | machine: 5 | image: ubuntu-2204:2024.01.2 6 | resource_class: arm.medium 7 | working_directory: /home/circleci/go/src/github.com/fergusstrange/embedded-postgres 8 | apple-m4: &macos-executor 9 | resource_class: m4pro.medium 10 | macos: 11 | xcode: "26.1.0" 12 | orbs: 13 | go: circleci/go@3.0.3 14 | jobs: 15 | platform_test: 16 | parameters: 17 | executor: 18 | type: executor 19 | executor: << parameters.executor >> 20 | steps: 21 | - 
checkout 22 | - when: 23 | condition: 24 | equal: [ *macos-executor, << parameters.executor >> ] 25 | steps: 26 | - go/install: 27 | version: "1.18" 28 | - run: /usr/sbin/softwareupdate --install-rosetta --agree-to-license 29 | - go/load-mod-cache 30 | - go/mod-download 31 | - go/save-mod-cache 32 | - run: cd platform-test && go mod download && go test -v -race ./... 33 | 34 | workflows: 35 | version: 2 36 | test: 37 | jobs: 38 | - platform_test: 39 | matrix: 40 | parameters: 41 | executor: 42 | - linux-arm64 43 | - apple-m4 44 | -------------------------------------------------------------------------------- /.golangci.yml: -------------------------------------------------------------------------------- 1 | linters: 2 | disable-all: true 3 | enable: 4 | - deadcode 5 | - errcheck 6 | - gosimple 7 | - govet 8 | - ineffassign 9 | - staticcheck 10 | - structcheck 11 | - typecheck 12 | - unused 13 | - varcheck 14 | - bodyclose 15 | - depguard 16 | - dogsled 17 | - dupl 18 | - funlen 19 | - gochecknoglobals 20 | - gochecknoinits 21 | - gocognit 22 | - goconst 23 | - gocritic 24 | - gocyclo 25 | - godox 26 | - gofmt 27 | - goimports 28 | - revive 29 | - misspell 30 | - nakedret 31 | - prealloc 32 | - exportloopref 33 | - stylecheck 34 | - unconvert 35 | - unparam 36 | - whitespace 37 | - wsl 38 | issues: 39 | exclude-use-default: false 40 | exclude: 41 | - Error return value of .((os\.)?std(out|err)\..*|.*Close|.*Flush|os\.Remove(All)?|.*printf?|os\.(Un)?Setenv). is not checked 42 | - ST1000 43 | - func name will be used as test\.Test.* by other packages, and that stutters; consider calling this 44 | - (possible misuse of unsafe.Pointer|should have signature) 45 | - ineffective break statement. 
Did you mean to break out of the outer loop 46 | - Use of unsafe calls should be audited 47 | - Subprocess launch(ed with variable|ing should be audited) 48 | - G104 49 | - (Expect directory permissions to be 0750 or less|Expect file permissions to be 0600 or less) 50 | - Potential file inclusion via variable 51 | exclude-rules: 52 | - path: '(.+)_test\.go' 53 | linters: 54 | - funlen 55 | - goconst 56 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= 2 | github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= 3 | github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 4 | github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= 5 | github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= 6 | github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= 7 | github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= 8 | github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= 9 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= 10 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 11 | github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= 12 | github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= 13 | github.com/xi2/xz v0.0.0-20171230120015-48954b6210f8 h1:nIPpBwaJSVYIxUFsDv3M8ofmx9yWTog9BfvIu0q41lo= 14 | github.com/xi2/xz v0.0.0-20171230120015-48954b6210f8/go.mod h1:HUYIGzjTL3rfEspMxjDjgmT5uz5wzYJKVo23qUhYTos= 15 | go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= 16 | go.uber.org/goleak v1.3.0/go.mod 
h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE= 17 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= 18 | gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY= 19 | gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= 20 | gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 21 | -------------------------------------------------------------------------------- /examples/server_test.go: -------------------------------------------------------------------------------- 1 | package examples 2 | 3 | import ( 4 | "encoding/json" 5 | "log" 6 | "net/http" 7 | 8 | "github.com/jmoiron/sqlx" 9 | "github.com/pressly/goose/v3" 10 | ) 11 | 12 | type BeerCatalogue struct { 13 | ID int64 `json:"id"` 14 | Name string `json:"name"` 15 | Consumed bool `json:"consumed"` 16 | Rating float64 `json:"rating"` 17 | } 18 | 19 | type App struct { 20 | router *http.ServeMux 21 | } 22 | 23 | func (a *App) Start() error { 24 | return http.ListenAndServe("localhost:8080", a.router) 25 | } 26 | 27 | func NewApp() *App { 28 | db, err := sqlx.Connect("postgres", "host=localhost port=5432 user=postgres password=postgres dbname=postgres sslmode=disable") 29 | if err != nil { 30 | log.Fatal(err) 31 | } 32 | 33 | if err := goose.Up(db.DB, "./migrations"); err != nil { 34 | log.Fatal(err) 35 | } 36 | 37 | router := http.NewServeMux() 38 | router.HandleFunc("/beer-catalogue", GetBeer(db)) 39 | 40 | return &App{router: router} 41 | } 42 | 43 | func GetBeer(db *sqlx.DB) func(w http.ResponseWriter, r *http.Request) { 44 | return func(w http.ResponseWriter, r *http.Request) { 45 | if beerName := r.URL.Query().Get("name"); beerName != "" { 46 | beers := make([]BeerCatalogue, 0) 47 | if err := db.Select(&beers, "SELECT * FROM beer_catalogue WHERE UPPER(name) = UPPER($1)", beerName); err != nil { 48 | w.WriteHeader(http.StatusInternalServerError) 49 | return 50 | 
} 51 | 52 | jsonPayload, err := json.Marshal(beers) 53 | if err != nil { 54 | w.WriteHeader(http.StatusInternalServerError) 55 | return 56 | } 57 | 58 | // The first Write sends an implicit 200 status, so a failed write can no 59 | // longer be reported via WriteHeader; log it instead. 60 | if _, err := w.Write(jsonPayload); err != nil { 61 | log.Printf("unable to write beer catalogue response: %v", err) 62 | } 63 | 64 | // Stop here so a successful response is not followed by the 400 below, 65 | // which would be a superfluous WriteHeader call. 66 | return 67 | } 68 | 69 | // A missing or empty "name" query parameter is a client error. 70 | w.WriteHeader(http.StatusBadRequest) 71 | } 72 | } 73 | -------------------------------------------------------------------------------- /logging.go: -------------------------------------------------------------------------------- 1 | package embeddedpostgres 2 | 3 | import ( 4 | "fmt" 5 | "io" 6 | "io/ioutil" 7 | "os" 8 | "time" 9 | ) 10 | 11 | type syncedLogger struct { 12 | offset int64 13 | logger io.Writer 14 | file *os.File 15 | } 16 | 17 | func newSyncedLogger(dir string, logger io.Writer) (*syncedLogger, error) { 18 | file, err := os.CreateTemp(dir, "embedded_postgres_log") 19 | if err != nil { 20 | return nil, err 21 | } 22 | 23 | s := syncedLogger{ 24 | logger: logger, 25 | file: file, 26 | } 27 | 28 | return &s, nil 29 | } 30 | 31 | func (s *syncedLogger) flush() error { 32 | if s.logger != nil { 33 | file, err := os.Open(s.file.Name()) 34 | if err != nil { 35 | return fmt.Errorf("unable to process postgres logs: %w", err) 36 | } 37 | 38 | defer func() { 39 | // The file is only read here, so a failed Close loses no data; ignore it 40 | // instead of panicking inside library code. 41 | _ = file.Close() 42 | }() 43 | 44 | if _, err = file.Seek(s.offset, io.SeekStart); err != nil { 45 | return fmt.Errorf("unable to process postgres logs: %w", err) 46 | } 47 | 48 | readBytes, err := io.Copy(s.logger, file) 49 | // Advance past whatever was copied, even on error, so it is not re-emitted. 
50 | s.offset += readBytes 51 | if err != nil { 52 | return fmt.Errorf("unable to process postgres logs: %w", err) 53 | } 54 | } 55 | 56 | return nil 57 | } 58 | 59 | func readLogsOrTimeout(logger *os.File) (logContent []byte, err error) { 60 | logContent = []byte("logs could not be read") 61 | 62 | logContentChan := make(chan []byte, 1) 63 | errChan := make(chan error, 1) 64 | 65 | go func() { 66 | if actualLogContent, err :=
ioutil.ReadFile(logger.Name()); err == nil { 67 | logContentChan <- actualLogContent 68 | } else { 69 | errChan <- err 70 | } 71 | }() 72 | 73 | select { 74 | case logContent = <-logContentChan: 75 | case err = <-errChan: 76 | case <-time.After(10 * time.Second): 77 | err = fmt.Errorf("timed out waiting for logs") 78 | } 79 | 80 | return logContent, err 81 | } 82 | -------------------------------------------------------------------------------- /version_strategy.go: -------------------------------------------------------------------------------- 1 | package embeddedpostgres 2 | 3 | import ( 4 | "fmt" 5 | "os" 6 | "os/exec" 7 | "strings" 8 | ) 9 | 10 | // VersionStrategy provides a strategy that can be used to determine which version of Postgres should be used based on 11 | // the operating system, architecture and desired Postgres version. 12 | type VersionStrategy func() (operatingSystem string, architecture string, postgresVersion PostgresVersion) 13 | 14 | func defaultVersionStrategy(config Config, goos, arch string, linuxMachineName func() string, shouldUseAlpineLinuxBuild func() bool) VersionStrategy { 15 | return func() (string, string, PostgresVersion) { 16 | goos := goos 17 | arch := arch 18 | 19 | if goos == "linux" { 20 | // the zonkyio/embedded-postgres-binaries project produces 21 | // arm binaries with the following name schema: 22 | // 32bit: arm32v6 / arm32v7 23 | // 64bit (aarch64): arm64v8 24 | if arch == "arm64" { 25 | arch += "v8" 26 | } else if arch == "arm" { 27 | machineName := linuxMachineName() 28 | if strings.HasPrefix(machineName, "armv7") { 29 | arch += "32v7" 30 | } else if strings.HasPrefix(machineName, "armv6") { 31 | arch += "32v6" 32 | } 33 | } 34 | 35 | if shouldUseAlpineLinuxBuild() { 36 | arch += "-alpine" 37 | } 38 | } 39 | 40 | // postgres below version 14.2 is not available for macos on arm 41 | if goos == "darwin" && arch == "arm64" { 42 | var majorVer, minorVer int 43 | if _, err := fmt.Sscanf(string(config.version), "%d.%d", 
&majorVer, &minorVer); err == nil && 44 | (majorVer < 14 || (majorVer == 14 && minorVer < 2)) { 45 | arch = "amd64" 46 | } else { 47 | arch += "v8" 48 | } 49 | } 50 | 51 | return goos, arch, config.version 52 | } 53 | } 54 | 55 | func linuxMachineName() string { 56 | var uname string 57 | 58 | if output, err := exec.Command("uname", "-m").Output(); err == nil { 59 | uname = string(output) 60 | } 61 | 62 | return uname 63 | } 64 | 65 | func shouldUseAlpineLinuxBuild() bool { 66 | _, err := os.Stat("/etc/alpine-release") 67 | return err == nil 68 | } 69 | -------------------------------------------------------------------------------- /logging_test.go: -------------------------------------------------------------------------------- 1 | package embeddedpostgres 2 | 3 | import ( 4 | "fmt" 5 | "io/ioutil" 6 | "os" 7 | "testing" 8 | 9 | "github.com/stretchr/testify/assert" 10 | "github.com/stretchr/testify/require" 11 | ) 12 | 13 | type customLogger struct { 14 | logLines []byte 15 | } 16 | 17 | func (cl *customLogger) Write(p []byte) (n int, err error) { 18 | cl.logLines = append(cl.logLines, p...) 
19 | return len(p), nil 20 | } 21 | 22 | func Test_SyncedLogger_CreateError(t *testing.T) { 23 | logger := customLogger{} 24 | _, err := newSyncedLogger("/not-exists-anywhere", &logger) 25 | 26 | assert.Error(t, err) 27 | } 28 | 29 | func Test_SyncedLogger_ErrorDuringFlush(t *testing.T) { 30 | logger := customLogger{} 31 | 32 | sl, slErr := newSyncedLogger("", &logger) 33 | 34 | assert.NoError(t, slErr) 35 | 36 | rmFileErr := os.Remove(sl.file.Name()) 37 | 38 | assert.NoError(t, rmFileErr) 39 | 40 | err := sl.flush() 41 | 42 | assert.Error(t, err) 43 | } 44 | 45 | func Test_SyncedLogger_NoErrorDuringFlush(t *testing.T) { 46 | logger := customLogger{} 47 | 48 | sl, slErr := newSyncedLogger("", &logger) 49 | 50 | assert.NoError(t, slErr) 51 | 52 | err := os.WriteFile(sl.file.Name(), []byte("some logs\non a new line"), os.ModeAppend) 53 | 54 | assert.NoError(t, err) 55 | 56 | err = sl.flush() 57 | 58 | assert.NoError(t, err) 59 | 60 | assert.Equal(t, "some logs\non a new line", string(logger.logLines)) 61 | } 62 | 63 | func Test_readLogsOrTimeout(t *testing.T) { 64 | logFile, err := ioutil.TempFile("", "prepare_database_test_log") 65 | if err != nil { 66 | panic(err) 67 | } 68 | 69 | logContent, err := readLogsOrTimeout(logFile) 70 | assert.NoError(t, err) 71 | assert.Equal(t, []byte(""), logContent) 72 | 73 | _, _ = logFile.Write([]byte("and here are the logs!")) 74 | 75 | logContent, err = readLogsOrTimeout(logFile) 76 | assert.NoError(t, err) 77 | assert.Equal(t, []byte("and here are the logs!"), logContent) 78 | 79 | require.NoError(t, os.Remove(logFile.Name())) 80 | logContent, err = readLogsOrTimeout(logFile) 81 | assert.Equal(t, []byte("logs could not be read"), logContent) 82 | assert.EqualError(t, err, fmt.Sprintf("open %s: no such file or directory", logFile.Name())) 83 | } 84 | -------------------------------------------------------------------------------- /test_util_test.go: -------------------------------------------------------------------------------- 
1 | package embeddedpostgres 2 | 3 | import ( 4 | "encoding/base64" 5 | "os" 6 | "testing" 7 | 8 | "go.uber.org/goleak" 9 | ) 10 | 11 | func createTempXzArchive() (string, func()) { 12 | return writeFileWithBase64Content("remote_fetch_test*.txz", "/Td6WFoAAATm1rRGAgAhARYAAAB0L+Wj4Av/AKZdADIaSqdFdWDG5Dyin7tszujmfm9YJn6/1REVUfqW8HwXvgwbrrcDDc4Q2ql+L+ybLTxJ+QNhhaKnawviRjKhUOT3syXi2Ye8k4QMkeurnnCu4a8eoCV+hqNFWkk8/w8MzyMzQZ2D3wtvoaZV/KqJ8jyLbNVj+vsKrzqg5vbSGz5/h7F37nqN1V8ZsdCnKnDMZPzovM8RwtelDd0g3fPC0dG/W9PH4wAAAAC2dqs1k9ZA0QABwgGAGAAAIQZ5XbHEZ/sCAAAAAARZWg==") 13 | } 14 | 15 | func createTempZipArchive() (string, func()) { 16 | return writeFileWithBase64Content("remote_fetch_test*.zip", "UEsDBBQACAAIAExBSlMAAAAAAAAAAAAAAAAaAAkAcmVtb3RlX2ZldGNoX3Rlc3Q4MDA0NjE5MDVVVAUAAfCfYmEBAAD//1BLBwgAAAAABQAAAAAAAABQSwMEFAAIAAAATEFKUwAAAAAAAAAAAAAAABUACQByZW1vdGVfZmV0Y2hfdGVzdC50eHpVVAUAAfCfYmH9N3pYWgAABObWtEYCACEBFgAAAHQv5aPgBf8Abl0AORlJ/tq+A8rMBye1kCuXLnw2aeeO0gdfXeVHCWpF8/VeZU/MTVkdLzI+XgKLEMlHJukIdxP7iSAuKts+v7aDrJu68RHNgIsXGrGouAjf780FXjTUjX4vXDh08vNY1yOBayt9z9dKHdoG9AeAIgAAAAAOKMpgA1Mm3wABigGADAAAjIVdpbHEZ/sCAAAAAARZWlBLBwhkmQgRsAAAALAAAABQSwECFAMUAAgACABMQUpTAAAAAAUAAAAAAAAAGgAJAAAAAAAAAAAAgIEAAAAAcmVtb3RlX2ZldGNoX3Rlc3Q4MDA0NjE5MDVVVAUAAfCfYmFQSwECFAMUAAgAAABMQUpTZJkIEbAAAACwAAAAFQAJAAAAAAAAAAAApIFWAAAAcmVtb3RlX2ZldGNoX3Rlc3QudHh6VVQFAAHwn2JhUEsFBgAAAAACAAIAnQAAAFIBAAAAAA==") 17 | } 18 | 19 | func writeFileWithBase64Content(filename, base64Content string) (string, func()) { 20 | tempFile, err := os.CreateTemp("", filename) 21 | if err != nil { 22 | panic(err) 23 | } 24 | 25 | byteContent, err := base64.StdEncoding.DecodeString(base64Content) 26 | if err != nil { 27 | panic(err) 28 | } 29 | 30 | if err := os.WriteFile(tempFile.Name(), byteContent, 0666); err != nil { 31 | panic(err) 32 | } 33 | 34 | return tempFile.Name(), func() { 35 | if err := os.RemoveAll(tempFile.Name()); err != nil { 36 | panic(err) 37 | } 38 | } 39 | } 40 | 41 | func shutdownDBAndFail(t *testing.T, err 
error, db *EmbeddedPostgres) { 42 | if db.started { 43 | if stopErr := db.Stop(); stopErr != nil { 44 | t.Errorf("Failed to shutdown server with error %s", stopErr) 45 | } 46 | } 47 | 48 | t.Errorf("Failed for version %s with error %s", db.config.version, err) 49 | } 50 | 51 | func testVersionStrategy() VersionStrategy { 52 | return func() (string, string, PostgresVersion) { 53 | return "darwin", "amd64", "1.2.3" 54 | } 55 | } 56 | 57 | func testCacheLocator() CacheLocator { 58 | return func() (s string, b bool) { 59 | return "", false 60 | } 61 | } 62 | 63 | func verifyLeak(t *testing.T) { 64 | // Ideally, there should be no exceptions here. 65 | goleak.VerifyNone(t, goleak.IgnoreTopFunction("internal/poll.runtime_pollWait")) 66 | } 67 | -------------------------------------------------------------------------------- /.github/workflows/build.yml: -------------------------------------------------------------------------------- 1 | name: Embedded Postgres 2 | on: 3 | push: 4 | branches: [ "master" ] 5 | pull_request: 6 | branches: [ "*" ] 7 | jobs: 8 | tests: 9 | name: Tests 10 | runs-on: ubuntu-latest 11 | steps: 12 | - name: Checkout 13 | id: go 14 | uses: actions/checkout@v4 15 | - name: Set Up Golang 16 | uses: actions/setup-go@v5 17 | with: 18 | go-version: 1.22 19 | - name: Check Dependencies 20 | run: | 21 | go list -json -deps > go.list 22 | for d in "." "examples" "platform-test"; do 23 | pushd $d 24 | go mod tidy 25 | if [ ! -z "$(git status --porcelain go.mod)" ]; then 26 | printf "go.mod has modifications\n" 27 | git diff go.mod 28 | exit 1 29 | fi 30 | if [ ! 
-z "$(git status --porcelain go.sum)" ]; then 31 | printf "go.sum has modifications\n" 32 | git diff go.sum 33 | exit 1 34 | fi 35 | popd 36 | done; 37 | - name: Nancy Vulnerability 38 | uses: sonatype-nexus-community/nancy-github-action@main 39 | with: 40 | nancyVersion: v1.0.52 41 | nancyCommand: sleuth 42 | env: 43 | OSSI_TOKEN: ${{ secrets.OSSI_TOKEN }} 44 | OSSI_USERNAME: ${{ secrets.OSSI_USERNAME }} 45 | - name: GolangCI Lint 46 | run: | 47 | curl -sSfL https://raw.githubusercontent.com/golangci/golangci-lint/master/install.sh | sh -s -- -b $(go env GOPATH)/bin v1.42.1 48 | /home/runner/go/bin/golangci-lint run 49 | - name: Test 50 | run: go test -v -test.timeout 0 -race -cover -covermode=atomic -coverprofile=coverage.out ./... 51 | - name: Test Examples 52 | run: | 53 | pushd examples && \ 54 | go test -v ./... && \ 55 | popd 56 | - name: Upload Coverage Report 57 | env: 58 | COVERALLS_TOKEN: ${{ secrets.GITHUB_TOKEN }} 59 | run: go install github.com/mattn/goveralls@latest && $(go env GOPATH)/bin/goveralls -v -coverprofile=coverage.out -service=github 60 | alpine_tests: 61 | name: Alpine Linux Platform Tests 62 | runs-on: ubuntu-latest 63 | container: 64 | image: golang:1.22-alpine 65 | steps: 66 | - uses: actions/checkout@v4 67 | - name: Set Up 68 | run: | 69 | apk add --upgrade gcc g++ && \ 70 | adduser testuser -D 71 | - name: All Tests 72 | run: su - testuser -c 'export PATH=$PATH:/usr/local/go/bin; cd /__w/embedded-postgres/embedded-postgres && go test -v ./... && cd platform-test && go test -v ./...' 73 | platform_tests: 74 | name: Platform tests 75 | strategy: 76 | matrix: 77 | os: [ ubuntu-latest, windows-latest, macos-14 ] 78 | runs-on: ${{ matrix.os }} 79 | steps: 80 | - name: Checkout 81 | uses: actions/checkout@v4 82 | - name: Set Up Golang 83 | uses: actions/setup-go@v5 84 | with: 85 | go-version: 1.22 86 | - name: Platform Tests 87 | run: | 88 | cd platform-test 89 | go test -v -race ./... 
90 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | 2 | # Created by https://www.gitignore.io/api/go,intellij+all,visualstudiocode 3 | # Edit at https://www.gitignore.io/?templates=go,intellij+all,visualstudiocode 4 | 5 | ### Go ### 6 | # Binaries for programs and plugins 7 | *.exe 8 | *.exe~ 9 | *.dll 10 | *.so 11 | *.dylib 12 | 13 | # Test binary, built with `go test -c` 14 | *.test 15 | 16 | # Output of the go coverage tool, specifically when used with LiteIDE 17 | *.out 18 | 19 | # Dependency directories (remove the comment below to include it) 20 | # vendor/ 21 | 22 | ### Go Patch ### 23 | /vendor/ 24 | /Godeps/ 25 | 26 | ### Intellij+all ### 27 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and WebStorm 28 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 29 | 30 | # User-specific stuff 31 | .idea/**/workspace.xml 32 | .idea/**/tasks.xml 33 | .idea/**/usage.statistics.xml 34 | .idea/**/dictionaries 35 | .idea/**/shelf 36 | 37 | # Generated files 38 | .idea/**/contentModel.xml 39 | 40 | # Sensitive or high-churn files 41 | .idea/**/dataSources/ 42 | .idea/**/dataSources.ids 43 | .idea/**/dataSources.local.xml 44 | .idea/**/sqlDataSources.xml 45 | .idea/**/dynamic.xml 46 | .idea/**/uiDesigner.xml 47 | .idea/**/dbnavigator.xml 48 | 49 | # Gradle 50 | .idea/**/gradle.xml 51 | .idea/**/libraries 52 | 53 | # Gradle and Maven with auto-import 54 | # When using Gradle or Maven with auto-import, you should exclude module files, 55 | # since they will be recreated, and may cause churn. Uncomment if using 56 | # auto-import. 
57 | # .idea/modules.xml 58 | # .idea/*.iml 59 | # .idea/modules 60 | # *.iml 61 | # *.ipr 62 | 63 | # CMake 64 | cmake-build-*/ 65 | 66 | # Mongo Explorer plugin 67 | .idea/**/mongoSettings.xml 68 | 69 | # File-based project format 70 | *.iws 71 | 72 | # IntelliJ 73 | out/ 74 | 75 | # mpeltonen/sbt-idea plugin 76 | .idea_modules/ 77 | 78 | # JIRA plugin 79 | atlassian-ide-plugin.xml 80 | 81 | # Cursive Clojure plugin 82 | .idea/replstate.xml 83 | 84 | # Crashlytics plugin (for Android Studio and IntelliJ) 85 | com_crashlytics_export_strings.xml 86 | crashlytics.properties 87 | crashlytics-build.properties 88 | fabric.properties 89 | 90 | # Editor-based Rest Client 91 | .idea/httpRequests 92 | 93 | # Android studio 3.1+ serialized cache file 94 | .idea/caches/build_file_checksums.ser 95 | 96 | ### Intellij+all Patch ### 97 | # Ignores the whole .idea folder and all .iml files 98 | # See https://github.com/joeblau/gitignore.io/issues/186 and https://github.com/joeblau/gitignore.io/issues/360 99 | 100 | .idea/ 101 | 102 | # Reason: https://github.com/joeblau/gitignore.io/issues/186#issuecomment-249601023 103 | 104 | *.iml 105 | modules.xml 106 | .idea/misc.xml 107 | *.ipr 108 | 109 | # Sonarlint plugin 110 | .idea/sonarlint 111 | 112 | ### VisualStudioCode ### 113 | .vscode/* 114 | !.vscode/settings.json 115 | !.vscode/tasks.json 116 | !.vscode/launch.json 117 | !.vscode/extensions.json 118 | 119 | ### VisualStudioCode Patch ### 120 | # Ignore all local history of files 121 | .history 122 | 123 | # End of https://www.gitignore.io/api/go,intellij+all,visualstudiocode 124 | 125 | main -------------------------------------------------------------------------------- /decompression.go: -------------------------------------------------------------------------------- 1 | package embeddedpostgres 2 | 3 | import ( 4 | "archive/tar" 5 | "fmt" 6 | "io" 7 | "os" 8 | "path/filepath" 9 | 10 | "github.com/xi2/xz" 11 | ) 12 | 13 | func defaultTarReader(xzReader *xz.Reader) (func() 
(*tar.Header, error), func() io.Reader) { 14 | tarReader := tar.NewReader(xzReader) 15 | 16 | return func() (*tar.Header, error) { 17 | return tarReader.Next() 18 | }, func() io.Reader { 19 | return tarReader 20 | } 21 | } 22 | 23 | func decompressTarXz(tarReader func(*xz.Reader) (func() (*tar.Header, error), func() io.Reader), path, extractPath string) error { 24 | extractDirectory := filepath.Dir(extractPath) 25 | 26 | if err := os.MkdirAll(extractDirectory, os.ModePerm); err != nil { 27 | return errorUnableToExtract(path, extractPath, err) 28 | } 29 | 30 | tempExtractPath, err := os.MkdirTemp(extractDirectory, "temp_") 31 | if err != nil { 32 | return errorUnableToExtract(path, extractPath, err) 33 | } 34 | defer func() { 35 | if err := os.RemoveAll(tempExtractPath); err != nil { 36 | panic(err) 37 | } 38 | }() 39 | 40 | tarFile, err := os.Open(path) 41 | if err != nil { 42 | return errorUnableToExtract(path, extractPath, err) 43 | } 44 | 45 | defer func() { 46 | if err := tarFile.Close(); err != nil { 47 | panic(err) 48 | } 49 | }() 50 | 51 | xzReader, err := xz.NewReader(tarFile, 0) 52 | if err != nil { 53 | return errorUnableToExtract(path, extractPath, err) 54 | } 55 | 56 | readNext, reader := tarReader(xzReader) 57 | 58 | for { 59 | header, err := readNext() 60 | 61 | if err == io.EOF { 62 | break 63 | } 64 | 65 | if err != nil { 66 | return errorExtractingPostgres(err) 67 | } 68 | 69 | targetPath := filepath.Join(tempExtractPath, header.Name) 70 | finalPath := filepath.Join(extractPath, header.Name) 71 | 72 | if err := os.MkdirAll(filepath.Dir(targetPath), os.ModePerm); err != nil { 73 | return errorExtractingPostgres(err) 74 | } 75 | 76 | if err := os.MkdirAll(filepath.Dir(finalPath), os.ModePerm); err != nil { 77 | return errorExtractingPostgres(err) 78 | } 79 | 80 | switch header.Typeflag { 81 | case tar.TypeReg: 82 | outFile, err := os.OpenFile(targetPath, os.O_CREATE|os.O_RDWR, os.FileMode(header.Mode)) 83 | if err != nil { 84 | return 
errorExtractingPostgres(err) 85 | } 86 | 87 | if _, err := io.Copy(outFile, reader()); err != nil { 88 | // Close the file before bailing out so a failed copy does not leak the descriptor; the copy error is the one reported. 89 | _ = outFile.Close() 90 | return errorExtractingPostgres(err) 91 | } 92 | 93 | if err := outFile.Close(); err != nil { 94 | return errorExtractingPostgres(err) 95 | } 96 | case tar.TypeSymlink: 97 | if err := os.RemoveAll(targetPath); err != nil { 98 | return errorExtractingPostgres(err) 99 | } 100 | 101 | if err := os.Symlink(header.Linkname, targetPath); err != nil { 102 | return errorExtractingPostgres(err) 103 | } 104 | 105 | case tar.TypeDir: 106 | if err := os.MkdirAll(finalPath, os.FileMode(header.Mode)); err != nil { 107 | return errorExtractingPostgres(err) 108 | } 109 | continue 110 | } 111 | 112 | if err := renameOrIgnore(targetPath, finalPath); err != nil { 113 | return errorExtractingPostgres(err) 114 | } 115 | } 116 | 117 | return nil 118 | } 119 | 120 | func errorUnableToExtract(cacheLocation, binariesPath string, err error) error { 121 | return fmt.Errorf("unable to extract postgres archive %s to %s, if running parallel tests, configure RuntimePath to isolate testing directories, %w", 122 | cacheLocation, 123 | binariesPath, 124 | err, 125 | ) 126 | } 127 | -------------------------------------------------------------------------------- /platform-test/platform_test.go: -------------------------------------------------------------------------------- 1 | package platform_test 2 | 3 | import ( 4 | "database/sql" 5 | "fmt" 6 | "os" 7 | "path/filepath" 8 | "runtime" 9 | "strings" 10 | "testing" 11 | 12 | embeddedpostgres "github.com/fergusstrange/embedded-postgres" 13 | ) 14 | 15 | func Test_AllMajorVersions(t *testing.T) { 16 | allVersions := []embeddedpostgres.PostgresVersion{ 17 | embeddedpostgres.V18, 18 | embeddedpostgres.V17, 19 | embeddedpostgres.V16, 20 | embeddedpostgres.V15, 21 | embeddedpostgres.V14, 22 | } 23 | 24 | isLikelyAppleSilicon := runtime.GOOS == "darwin" && runtime.GOARCH == "arm64" 25 | 26 | if !isLikelyAppleSilicon { 27 | allVersions = append(allVersions,
28 | embeddedpostgres.V13, 29 | embeddedpostgres.V12, 30 | embeddedpostgres.V11, 31 | embeddedpostgres.V10, 32 | embeddedpostgres.V9) 33 | } 34 | 35 | tempExtractLocation, err := os.MkdirTemp("", "embedded_postgres_tests") 36 | if err != nil { 37 | t.Fatal(err) 38 | } 39 | 40 | for i, v := range allVersions { 41 | testNumber := i 42 | version := v 43 | t.Run(fmt.Sprintf("MajorVersion_%s", version), func(t *testing.T) { 44 | port := uint32(5555 + testNumber) 45 | runtimePath := filepath.Join(tempExtractLocation, string(version)) 46 | 47 | maxConnections := 150 48 | database := embeddedpostgres.NewDatabase(embeddedpostgres.DefaultConfig(). 49 | Version(version). 50 | Port(port). 51 | RuntimePath(runtimePath). 52 | StartParameters(map[string]string{ 53 | "max_connections": fmt.Sprintf("%d", maxConnections), 54 | })) 55 | 56 | if err := database.Start(); err != nil { 57 | shutdownDBAndFail(t, err, database, version) 58 | } 59 | 60 | db, err := connect(port) 61 | if err != nil { 62 | shutdownDBAndFail(t, err, database, version) 63 | } 64 | 65 | rows, err := db.Query("SELECT 1") 66 | if err != nil { 67 | shutdownDBAndFail(t, err, database, version) 68 | } 69 | if err := rows.Close(); err != nil { 70 | shutdownDBAndFail(t, err, database, version) 71 | } 72 | 73 | rows, err = db.Query(`SELECT setting::int max_conn FROM pg_settings WHERE name = 'max_connections';`) 74 | if err != nil { 75 | shutdownDBAndFail(t, err, database, version) 76 | } 77 | if !rows.Next() { 78 | shutdownDBAndFail(t, fmt.Errorf("no rows returned for max_connections"), database, version) 79 | } 80 | var maxConnReturned int 81 | if err := rows.Scan(&maxConnReturned); err != nil { 82 | shutdownDBAndFail(t, err, database, version) 83 | } 84 | if maxConnReturned != maxConnections { 85 | shutdownDBAndFail(t, fmt.Errorf("max_connections is %d, not %d as expected", maxConnReturned, maxConnections), database, version) 86 | } 87 | if err := rows.Close(); err != nil { 88 | shutdownDBAndFail(t, err, database, 
version) 89 | } 90 | 91 | if err := db.Close(); err != nil { 92 | shutdownDBAndFail(t, err, database, version) 93 | } 94 | 95 | if err := database.Stop(); err != nil { 96 | t.Fatal(err) 97 | } 98 | 99 | if err := checkPgVersionFile(filepath.Join(runtimePath, "data"), version); err != nil { 100 | t.Fatal(err) 101 | } 102 | }) 103 | } 104 | } 105 | 106 | func shutdownDBAndFail(t *testing.T, err error, db *embeddedpostgres.EmbeddedPostgres, version embeddedpostgres.PostgresVersion) { 107 | if err2 := db.Stop(); err2 != nil { 108 | t.Fatalf("Failed for version %s with error %s, original error %s", version, err2, err) 109 | } 110 | 111 | t.Fatalf("Failed for version %s with error %s", version, err) 112 | } 113 | 114 | func connect(port uint32) (*sql.DB, error) { 115 | db, err := sql.Open("postgres", fmt.Sprintf("host=localhost port=%d user=postgres password=postgres dbname=postgres sslmode=disable", port)) 116 | return db, err 117 | } 118 | 119 | func checkPgVersionFile(dataDir string, version embeddedpostgres.PostgresVersion) error { 120 | pgVersion := filepath.Join(dataDir, "PG_VERSION") 121 | 122 | d, err := os.ReadFile(pgVersion) 123 | if err != nil { 124 | return fmt.Errorf("could not read file %v", pgVersion) 125 | } 126 | 127 | v := strings.TrimSuffix(string(d), "\n") 128 | 129 | if strings.HasPrefix(string(version), v) { 130 | return nil 131 | } 132 | 133 | return fmt.Errorf("version missmatch in PG_VERSION: %v <> %v", string(version), v) 134 | } 135 | -------------------------------------------------------------------------------- /version_strategy_test.go: -------------------------------------------------------------------------------- 1 | package embeddedpostgres 2 | 3 | import ( 4 | "fmt" 5 | "strings" 6 | "testing" 7 | 8 | "github.com/stretchr/testify/assert" 9 | ) 10 | 11 | func Test_DefaultVersionStrategy_AllGolangDistributions(t *testing.T) { 12 | allGolangDistributions := map[string][]string{ 13 | "aix/ppc64": {"aix", "ppc64"}, 14 | "android/386": 
{"android", "386"}, 15 | "android/amd64": {"android", "amd64"}, 16 | "android/arm": {"android", "arm"}, 17 | "android/arm64": {"android", "arm64"}, 18 | "darwin/amd64": {"darwin", "amd64"}, 19 | "darwin/arm64": {"darwin", "amd64"}, 20 | "dragonfly/amd64": {"dragonfly", "amd64"}, 21 | "freebsd/386": {"freebsd", "386"}, 22 | "freebsd/amd64": {"freebsd", "amd64"}, 23 | "freebsd/arm": {"freebsd", "arm"}, 24 | "freebsd/arm64": {"freebsd", "arm64"}, 25 | "illumos/amd64": {"illumos", "amd64"}, 26 | "js/wasm": {"js", "wasm"}, 27 | "linux/386": {"linux", "386"}, 28 | "linux/amd64": {"linux", "amd64"}, 29 | "linux/arm": {"linux", "arm"}, 30 | "linux/arm64": {"linux", "arm64v8"}, 31 | "linux/mips": {"linux", "mips"}, 32 | "linux/mips64": {"linux", "mips64"}, 33 | "linux/mips64le": {"linux", "mips64le"}, 34 | "linux/mipsle": {"linux", "mipsle"}, 35 | "linux/ppc64": {"linux", "ppc64"}, 36 | "linux/ppc64le": {"linux", "ppc64le"}, 37 | "linux/riscv64": {"linux", "riscv64"}, 38 | "linux/s390x": {"linux", "s390x"}, 39 | "netbsd/386": {"netbsd", "386"}, 40 | "netbsd/amd64": {"netbsd", "amd64"}, 41 | "netbsd/arm": {"netbsd", "arm"}, 42 | "netbsd/arm64": {"netbsd", "arm64"}, 43 | "openbsd/386": {"openbsd", "386"}, 44 | "openbsd/amd64": {"openbsd", "amd64"}, 45 | "openbsd/arm": {"openbsd", "arm"}, 46 | "openbsd/arm64": {"openbsd", "arm64"}, 47 | "plan9/386": {"plan9", "386"}, 48 | "plan9/amd64": {"plan9", "amd64"}, 49 | "plan9/arm": {"plan9", "arm"}, 50 | "solaris/amd64": {"solaris", "amd64"}, 51 | "windows/386": {"windows", "386"}, 52 | "windows/amd64": {"windows", "amd64"}, 53 | "windows/arm": {"windows", "arm"}, 54 | } 55 | 56 | versionDifferences := map[PostgresVersion]map[string][]string{ 57 | PostgresVersion("14.0.0"): {}, 58 | PostgresVersion("14.1.0"): {}, 59 | PostgresVersion("14.2.0"): {"darwin/arm64": {"darwin", "arm64v8"}}, 60 | V15: {"darwin/arm64": {"darwin", "arm64v8"}}, 61 | } 62 | defaultConfig := DefaultConfig() 63 | 64 | for version, differences := range 
versionDifferences { 65 | defaultConfig.version = version 66 | 67 | for dist, expected := range allGolangDistributions { 68 | dist := dist 69 | expected := expected 70 | 71 | if override, ok := differences[dist]; ok { 72 | expected = override 73 | } 74 | 75 | t.Run(fmt.Sprintf("DefaultVersionStrategy_%s", dist), func(t *testing.T) { 76 | osArch := strings.Split(dist, "/") 77 | 78 | operatingSystem, architecture, postgresVersion := defaultVersionStrategy( 79 | defaultConfig, 80 | osArch[0], 81 | osArch[1], 82 | linuxMachineName, 83 | func() bool { 84 | return false 85 | })() 86 | 87 | assert.Equal(t, expected[0], operatingSystem) 88 | assert.Equal(t, expected[1], architecture) 89 | assert.Equal(t, version, postgresVersion) 90 | }) 91 | } 92 | } 93 | } 94 | 95 | func Test_DefaultVersionStrategy_Linux_ARM32V6(t *testing.T) { 96 | operatingSystem, architecture, postgresVersion := defaultVersionStrategy( 97 | DefaultConfig(), 98 | "linux", 99 | "arm", 100 | func() string { 101 | return "armv6l" 102 | }, func() bool { 103 | return false 104 | })() 105 | 106 | assert.Equal(t, "linux", operatingSystem) 107 | assert.Equal(t, "arm32v6", architecture) 108 | assert.Equal(t, V18, postgresVersion) 109 | } 110 | 111 | func Test_DefaultVersionStrategy_Linux_ARM32V7(t *testing.T) { 112 | operatingSystem, architecture, postgresVersion := defaultVersionStrategy( 113 | DefaultConfig(), 114 | "linux", 115 | "arm", 116 | func() string { 117 | return "armv7l" 118 | }, func() bool { 119 | return false 120 | })() 121 | 122 | assert.Equal(t, "linux", operatingSystem) 123 | assert.Equal(t, "arm32v7", architecture) 124 | assert.Equal(t, V18, postgresVersion) 125 | } 126 | 127 | func Test_DefaultVersionStrategy_Linux_Alpine(t *testing.T) { 128 | operatingSystem, architecture, postgresVersion := defaultVersionStrategy( 129 | DefaultConfig(), 130 | "linux", 131 | "amd64", 132 | func() string { 133 | return "" 134 | }, 135 | func() bool { 136 | return true 137 | }, 138 | )() 139 | 140 | 
assert.Equal(t, "linux", operatingSystem) 141 | assert.Equal(t, "amd64-alpine", architecture) 142 | assert.Equal(t, V18, postgresVersion) 143 | } 144 | 145 | func Test_DefaultVersionStrategy_shouldUseAlpineLinuxBuild(t *testing.T) { 146 | assert.NotPanics(t, func() { 147 | shouldUseAlpineLinuxBuild() 148 | }) 149 | } 150 | -------------------------------------------------------------------------------- /examples/examples_test.go: -------------------------------------------------------------------------------- 1 | package examples 2 | 3 | import ( 4 | "fmt" 5 | "net/http" 6 | "net/http/httptest" 7 | "reflect" 8 | "testing" 9 | 10 | embeddedpostgres "github.com/fergusstrange/embedded-postgres" 11 | "github.com/jmoiron/sqlx" 12 | _ "github.com/lib/pq" 13 | "github.com/pressly/goose/v3" 14 | "go.uber.org/zap" 15 | "go.uber.org/zap/zapio" 16 | ) 17 | 18 | func Test_GooseMigrations(t *testing.T) { 19 | database := embeddedpostgres.NewDatabase() 20 | if err := database.Start(); err != nil { 21 | t.Fatal(err) 22 | } 23 | 24 | defer func() { 25 | if err := database.Stop(); err != nil { 26 | t.Fatal(err) 27 | } 28 | }() 29 | 30 | db, err := connect() 31 | if err != nil { 32 | t.Fatal(err) 33 | } 34 | 35 | if err := goose.Up(db.DB, "./migrations"); err != nil { 36 | t.Fatal(err) 37 | } 38 | } 39 | 40 | func Test_ZapioLogger(t *testing.T) { 41 | logger, err := zap.NewProduction() 42 | if err != nil { 43 | t.Fatal(err) 44 | } 45 | 46 | w := &zapio.Writer{Log: logger} 47 | 48 | database := embeddedpostgres.NewDatabase(embeddedpostgres.DefaultConfig(). 
49 | Logger(w)) 50 | if err := database.Start(); err != nil { 51 | t.Fatal(err) 52 | } 53 | 54 | defer func() { 55 | if err := database.Stop(); err != nil { 56 | t.Fatal(err) 57 | } 58 | }() 59 | 60 | db, err := connect() 61 | if err != nil { 62 | t.Fatal(err) 63 | } 64 | 65 | if err := goose.Up(db.DB, "./migrations"); err != nil { 66 | t.Fatal(err) 67 | } 68 | } 69 | 70 | func Test_Sqlx_SelectOne(t *testing.T) { 71 | database := embeddedpostgres.NewDatabase() 72 | if err := database.Start(); err != nil { 73 | t.Fatal(err) 74 | } 75 | 76 | defer func() { 77 | if err := database.Stop(); err != nil { 78 | t.Fatal(err) 79 | } 80 | }() 81 | 82 | db, err := connect() 83 | if err != nil { 84 | t.Fatal(err) 85 | } 86 | 87 | rows := make([]int32, 0) 88 | 89 | err = db.Select(&rows, "SELECT 1") 90 | if err != nil { 91 | t.Fatal(err) 92 | } 93 | 94 | if len(rows) != 1 { 95 | t.Fatal("Expected one row returned") 96 | } 97 | } 98 | 99 | func Test_ManyTestsAgainstOneDatabase(t *testing.T) { 100 | database := embeddedpostgres.NewDatabase() 101 | if err := database.Start(); err != nil { 102 | t.Fatal(err) 103 | } 104 | 105 | defer func() { 106 | if err := database.Stop(); err != nil { 107 | t.Fatal(err) 108 | } 109 | }() 110 | 111 | db, err := connect() 112 | if err != nil { 113 | t.Fatal(err) 114 | } 115 | 116 | if err := goose.Up(db.DB, "./migrations"); err != nil { 117 | t.Fatal(err) 118 | } 119 | 120 | tests := []func(t *testing.T){ 121 | func(t *testing.T) { 122 | rows := make([]BeerCatalogue, 0) 123 | if err := db.Select(&rows, "SELECT * FROM beer_catalogue WHERE UPPER(name) = UPPER('Elvis Juice')"); err != nil { 124 | t.Fatal(err) 125 | } 126 | 127 | if len(rows) != 0 { 128 | t.Fatalf("expected 0 rows but got %d", len(rows)) 129 | } 130 | }, 131 | func(t *testing.T) { 132 | _, err := db.Exec(`INSERT INTO beer_catalogue (name, consumed, rating) VALUES ($1, $2, $3)`, 133 | "Kernal", 134 | true, 135 | 99.32) 136 | if err != nil { 137 | t.Fatal(err) 138 | } 139 | 140 | 
actualBeerCatalogue := make([]BeerCatalogue, 0) 141 | if err := db.Select(&actualBeerCatalogue, "SELECT * FROM beer_catalogue WHERE id = 2"); err != nil { 142 | t.Fatal(err) 143 | } 144 | 145 | expectedBeerCatalogue := BeerCatalogue{ 146 | ID: 2, 147 | Name: "Kernal", 148 | Consumed: true, 149 | Rating: 99.32, 150 | } 151 | if !reflect.DeepEqual(expectedBeerCatalogue, actualBeerCatalogue[0]) { 152 | t.Fatalf("expected %+v did not match actual %+v", expectedBeerCatalogue, actualBeerCatalogue) 153 | } 154 | }, 155 | } 156 | 157 | for testNumber, test := range tests { 158 | t.Run(fmt.Sprintf("%d", testNumber), test) 159 | } 160 | } 161 | 162 | func Test_SimpleHttpWebApp(t *testing.T) { 163 | database := embeddedpostgres.NewDatabase() 164 | if err := database.Start(); err != nil { 165 | t.Fatal(err) 166 | } 167 | 168 | defer func() { 169 | if err := database.Stop(); err != nil { 170 | t.Fatal(err) 171 | } 172 | }() 173 | 174 | request := httptest.NewRequest("GET", "/beer-catalogue?name=Punk%20IPA", nil) 175 | recorder := httptest.NewRecorder() 176 | 177 | NewApp().router.ServeHTTP(recorder, request) 178 | 179 | if recorder.Code != http.StatusOK { 180 | t.Fatalf("expected 200 but receieved %d", recorder.Code) 181 | } 182 | 183 | expectedPayload := `[{"id":1,"name":"Punk IPA","consumed":true,"rating":68.29}]` 184 | actualPayload := recorder.Body.String() 185 | 186 | if actualPayload != expectedPayload { 187 | t.Fatalf("expected %+v but receieved %+v", expectedPayload, actualPayload) 188 | } 189 | } 190 | 191 | func connect() (*sqlx.DB, error) { 192 | db, err := sqlx.Connect("postgres", "host=localhost port=5432 user=postgres password=postgres dbname=postgres sslmode=disable") 193 | return db, err 194 | } 195 | -------------------------------------------------------------------------------- /prepare_database.go: -------------------------------------------------------------------------------- 1 | package embeddedpostgres 2 | 3 | import ( 4 | "context" 5 | "database/sql" 6 
| "errors" 7 | "fmt" 8 | "io" 9 | "os" 10 | "os/exec" 11 | "path/filepath" 12 | 13 | "github.com/lib/pq" 14 | ) 15 | 16 | const ( 17 | fmtCloseDBConn = "unable to close database connection: %w" 18 | fmtAfterError = "%v happened after error: %w" 19 | ) 20 | 21 | type initDatabase func(binaryExtractLocation, runtimePath, pgDataDir, username, password, locale string, encoding string, logger *os.File) error 22 | type createDatabase func(port uint32, username, password, database string) error 23 | 24 | func defaultInitDatabase(binaryExtractLocation, runtimePath, pgDataDir, username, password, locale string, encoding string, logger *os.File) error { 25 | passwordFile, err := createPasswordFile(runtimePath, password) 26 | if err != nil { 27 | return err 28 | } 29 | 30 | args := []string{ 31 | "-A", "password", 32 | "-U", username, 33 | "-D", pgDataDir, 34 | fmt.Sprintf("--pwfile=%s", passwordFile), 35 | } 36 | 37 | if locale != "" { 38 | args = append(args, fmt.Sprintf("--locale=%s", locale)) 39 | } 40 | 41 | if encoding != "" { 42 | args = append(args, fmt.Sprintf("--encoding=%s", encoding)) 43 | } 44 | 45 | postgresInitDBBinary := filepath.Join(binaryExtractLocation, "bin/initdb") 46 | postgresInitDBProcess := exec.Command(postgresInitDBBinary, args...) 
47 | postgresInitDBProcess.Stderr = logger 48 | postgresInitDBProcess.Stdout = logger 49 | 50 | if err = postgresInitDBProcess.Run(); err != nil { 51 | logContent, readLogsErr := readLogsOrTimeout(logger) // we want to preserve the original error 52 | if readLogsErr != nil { 53 | logContent = []byte(string(logContent) + " - " + readLogsErr.Error()) 54 | } 55 | return fmt.Errorf("unable to init database using '%s': %w\n%s", postgresInitDBProcess.String(), err, string(logContent)) 56 | } 57 | 58 | if err = os.Remove(passwordFile); err != nil { 59 | return fmt.Errorf("unable to remove password file '%v': %w", passwordFile, err) 60 | } 61 | 62 | return nil 63 | } 64 | 65 | func createPasswordFile(runtimePath, password string) (string, error) { 66 | passwordFileLocation := filepath.Join(runtimePath, "pwfile") 67 | if err := os.WriteFile(passwordFileLocation, []byte(password), 0600); err != nil { 68 | return "", fmt.Errorf("unable to write password file to %s", passwordFileLocation) 69 | } 70 | 71 | return passwordFileLocation, nil 72 | } 73 | 74 | func defaultCreateDatabase(port uint32, username, password, database string) (err error) { 75 | if database == "postgres" { 76 | return nil 77 | } 78 | 79 | conn, err := openDatabaseConnection(port, username, password, "postgres") 80 | if err != nil { 81 | return errorCustomDatabase(database, err) 82 | } 83 | 84 | db := sql.OpenDB(conn) 85 | defer func() { 86 | err = connectionClose(db, err) 87 | }() 88 | 89 | if _, err := db.Exec(fmt.Sprintf("CREATE DATABASE \"%s\"", database)); err != nil { 90 | return errorCustomDatabase(database, err) 91 | } 92 | 93 | return nil 94 | } 95 | 96 | // connectionClose closes the database connection and handles the error of the function that used the database connection 97 | func connectionClose(db io.Closer, err error) error { 98 | closeErr := db.Close() 99 | if closeErr != nil { 100 | closeErr = fmt.Errorf(fmtCloseDBConn, closeErr) 101 | 102 | if err != nil { 103 | err = 
fmt.Errorf(fmtAfterError, closeErr, err) 104 | } else { 105 | err = closeErr 106 | } 107 | } 108 | 109 | return err 110 | } 111 | 112 | // healthCheckDatabaseOrTimeout polls healthCheckDatabase until it succeeds or config.startTimeout elapses. 113 | func healthCheckDatabaseOrTimeout(config Config) error { 114 | // Buffered and never closed: if the timeout wins the select below, a late success from the polling 115 | // goroutine must neither block forever nor panic by sending on a closed channel. 116 | healthCheckSignal := make(chan struct{}, 1) 117 | 118 | timeout, cancelFunc := context.WithTimeout(context.Background(), config.startTimeout) 119 | 120 | defer cancelFunc() 121 | 122 | go func() { 123 | for timeout.Err() == nil { 124 | if err := healthCheckDatabase(config.port, config.database, config.username, config.password); err != nil { 125 | continue 126 | } 127 | healthCheckSignal <- struct{}{} 128 | 129 | return 130 | } 131 | }() 132 | 133 | select { 134 | case <-healthCheckSignal: 135 | return nil 136 | case <-timeout.Done(): 137 | return errors.New("timed out waiting for database to become available") 138 | } 139 | } 140 | 141 | // healthCheckDatabase connects with the given credentials and runs "SELECT 1", returning any connection, query or close error. 142 | func healthCheckDatabase(port uint32, database, username, password string) (err error) { 143 | conn, err := openDatabaseConnection(port, username, password, database) 144 | if err != nil { 145 | return err 146 | } 147 | 148 | db := sql.OpenDB(conn) 149 | defer func() { 150 | err = connectionClose(db, err) 151 | }() 152 | 153 | // The result set must be closed, otherwise it keeps its connection checked out of the pool. 154 | rows, err := db.Query("SELECT 1") 155 | if err != nil { 156 | return err 157 | } 158 | 159 | return rows.Close() 160 | } 161 | 162 | func openDatabaseConnection(port uint32, username string, password string, database string) (*pq.Connector, error) { 163 | conn, err := pq.NewConnector(fmt.Sprintf("host=localhost port=%d user=%s password=%s dbname=%s sslmode=disable", 164 | port, 165 | username, 166 | password, 167 | database)) 168 | if err != nil { 169 | return nil, err 170 | } 171 | 172 | return conn, nil 173 | } 174 | 175 | func errorCustomDatabase(database string, err error) error { 176 | return fmt.Errorf("unable to connect to create database with custom name %s with the following error: %w", database, err) 177 | } 178 | 
-------------------------------------------------------------------------------- /remote_fetch.go: -------------------------------------------------------------------------------- 1 | package embeddedpostgres 2 | 3 | import ( 4 | "archive/zip" 5 | "bytes" 6 | "crypto/sha256" 7 | "encoding/hex" 8 | "errors" 9 | "fmt" 10 | "io" 11 | "log" 12 | "net/http" 13 | "os" 14 | "path/filepath" 15 | "strings" 16 | ) 17 | 18 | // RemoteFetchStrategy provides a strategy to fetch a Postgres binary so that it is available for use. 19 | type RemoteFetchStrategy func() error 20 | 21 | //nolint:funlen 22 | func defaultRemoteFetchStrategy(remoteFetchHost string, versionStrategy VersionStrategy, cacheLocator CacheLocator) RemoteFetchStrategy { 23 | return func() error { 24 | operatingSystem, architecture, version := versionStrategy() 25 | 26 | jarDownloadURL := fmt.Sprintf("%s/io/zonky/test/postgres/embedded-postgres-binaries-%s-%s/%s/embedded-postgres-binaries-%s-%s-%s.jar", 27 | remoteFetchHost, 28 | operatingSystem, 29 | architecture, 30 | version, 31 | operatingSystem, 32 | architecture, 33 | version) 34 | 35 | jarDownloadResponse, err := http.Get(jarDownloadURL) 36 | if err != nil { 37 | return fmt.Errorf("unable to connect to %s", remoteFetchHost) 38 | } 39 | 40 | defer closeBody(jarDownloadResponse)() 41 | 42 | if jarDownloadResponse.StatusCode != http.StatusOK { 43 | return fmt.Errorf("no version found matching %s", version) 44 | } 45 | 46 | jarBodyBytes, err := io.ReadAll(jarDownloadResponse.Body) 47 | if err != nil { 48 | return errorFetchingPostgres(err) 49 | } 50 | 51 | shaDownloadURL := fmt.Sprintf("%s.sha256", jarDownloadURL) 52 | shaDownloadResponse, err := http.Get(shaDownloadURL) 53 | if err != nil { 54 | return fmt.Errorf("download sha256 from %s failed: %w", shaDownloadURL, err) 55 | } 56 | defer closeBody(shaDownloadResponse)() 57 | 58 | if err == nil && shaDownloadResponse.StatusCode == http.StatusOK { 59 | if shaBodyBytes, err := 
io.ReadAll(shaDownloadResponse.Body); err == nil { 60 | jarChecksum := sha256.Sum256(jarBodyBytes) 61 | if !bytes.Equal(shaBodyBytes, []byte(hex.EncodeToString(jarChecksum[:]))) { 62 | return errors.New("downloaded checksums do not match") 63 | } 64 | } 65 | } 66 | 67 | return decompressResponse(jarBodyBytes, jarDownloadResponse.ContentLength, cacheLocator, jarDownloadURL) 68 | } 69 | } 70 | 71 | func closeBody(resp *http.Response) func() { 72 | return func() { 73 | if resp == nil || resp.Body == nil { 74 | return 75 | } 76 | if err := resp.Body.Close(); err != nil { 77 | log.Fatal(err) 78 | } 79 | } 80 | } 81 | 82 | func decompressResponse(bodyBytes []byte, contentLength int64, cacheLocator CacheLocator, downloadURL string) error { 83 | size := contentLength 84 | // if the content length is not set (i.e. chunked encoding), 85 | // we need to use the length of the bodyBytes otherwise 86 | // the unzip operation will fail 87 | if contentLength < 0 { 88 | size = int64(len(bodyBytes)) 89 | } 90 | zipReader, err := zip.NewReader(bytes.NewReader(bodyBytes), size) 91 | if err != nil { 92 | return errorFetchingPostgres(err) 93 | } 94 | 95 | cacheLocation, _ := cacheLocator() 96 | 97 | if err := os.MkdirAll(filepath.Dir(cacheLocation), 0755); err != nil { 98 | return errorExtractingPostgres(err) 99 | } 100 | 101 | for _, file := range zipReader.File { 102 | if !file.FileHeader.FileInfo().IsDir() && strings.HasSuffix(file.FileHeader.Name, ".txz") { 103 | if err := decompressSingleFile(file, cacheLocation); err != nil { 104 | return err 105 | } 106 | 107 | // we have successfully found the file, return early 108 | return nil 109 | } 110 | } 111 | 112 | return fmt.Errorf("error fetching postgres: cannot find binary in archive retrieved from %s", downloadURL) 113 | } 114 | 115 | func decompressSingleFile(file *zip.File, cacheLocation string) error { 116 | renamed := false 117 | 118 | archiveReader, err := file.Open() 119 | if err != nil { 120 | return 
errorExtractingPostgres(err) 121 | } 122 | // The entry is only read from, so a Close error carries nothing actionable. 123 | defer func() { _ = archiveReader.Close() }() 124 | 125 | // To prevent file corruption when multiple processes extract at the same time, the archive is 126 | // first written to a temporary file next to the cache location and then renamed into place. 127 | tmp, err := os.CreateTemp(filepath.Dir(cacheLocation), "temp_") 128 | if err != nil { 129 | return errorExtractingPostgres(err) 130 | } 131 | defer func() { 132 | // if anything failed before the rename then the temporary file should be cleaned up. 133 | // if the rename was successful then there is no temporary file to remove. 134 | if !renamed { 135 | // Close first (closing twice is harmless): Windows cannot remove a file that is still open. 136 | _ = tmp.Close() 137 | if err := os.Remove(tmp.Name()); err != nil { 138 | panic(err) 139 | } 140 | } 141 | }() 142 | 143 | // Stream the entry straight to disk instead of buffering the whole archive in memory. 144 | if _, err := io.Copy(tmp, archiveReader); err != nil { 145 | return errorExtractingPostgres(err) 146 | } 147 | 148 | // Windows cannot rename a file if it is still open. 149 | // The file needs to be manually closed to allow the rename to happen. 150 | if err := tmp.Close(); err != nil { 151 | return errorExtractingPostgres(err) 152 | } 153 | 154 | if err := renameOrIgnore(tmp.Name(), cacheLocation); err != nil { 155 | return errorExtractingPostgres(err) 156 | } 157 | renamed = true 158 | 159 | return nil 160 | } 161 | 162 | func errorExtractingPostgres(err error) error { 163 | return fmt.Errorf("unable to extract postgres archive: %w", err) 164 | } 165 | 166 | func errorFetchingPostgres(err error) error { 167 | return fmt.Errorf("error fetching postgres: %w", err) 168 | } 169 | -------------------------------------------------------------------------------- /config.go: -------------------------------------------------------------------------------- 1 | package embeddedpostgres 2 | 3 | import ( 4 | "fmt" 5 | "io" 6 | "os" 7 | "time" 8 | ) 9 | 10 | // Config maintains the runtime configuration for the Postgres process to be created.
11 | type Config struct { 12 | version PostgresVersion 13 | port uint32 14 | database string 15 | username string 16 | password string 17 | cachePath string 18 | runtimePath string 19 | dataPath string 20 | binariesPath string 21 | locale string 22 | encoding string 23 | startParameters map[string]string 24 | binaryRepositoryURL string 25 | startTimeout time.Duration 26 | logger io.Writer 27 | } 28 | 29 | // DefaultConfig provides a default set of configuration to be used "as is" or modified using the provided builders. 30 | // The following can be assumed as defaults: 31 | // Version: 18 32 | // Port: 5432 33 | // Database: postgres 34 | // Username: postgres 35 | // Password: postgres 36 | // StartTimeout: 15 Seconds 37 | func DefaultConfig() Config { 38 | return Config{ 39 | version: V18, 40 | port: 5432, 41 | database: "postgres", 42 | username: "postgres", 43 | password: "postgres", 44 | startTimeout: 15 * time.Second, 45 | logger: os.Stdout, 46 | binaryRepositoryURL: "https://repo1.maven.org/maven2", 47 | } 48 | } 49 | 50 | // Version sets the Postgres binary version. 51 | func (c Config) Version(version PostgresVersion) Config { 52 | c.version = version 53 | return c 54 | } 55 | 56 | // Port sets the runtime port that Postgres can be accessed on. 57 | func (c Config) Port(port uint32) Config { 58 | c.port = port 59 | return c 60 | } 61 | 62 | // Database sets the database name that will be created. 63 | func (c Config) Database(database string) Config { 64 | c.database = database 65 | return c 66 | } 67 | 68 | // Username sets the username that will be used to connect. 69 | func (c Config) Username(username string) Config { 70 | c.username = username 71 | return c 72 | } 73 | 74 | // Password sets the password that will be used to connect. 75 | func (c Config) Password(password string) Config { 76 | c.password = password 77 | return c 78 | } 79 | 80 | // RuntimePath sets the path that will be used for the extracted Postgres runtime directory.
81 | // If Postgres data directory is not set with DataPath(), this directory is also used as data directory. 82 | func (c Config) RuntimePath(path string) Config { 83 | c.runtimePath = path 84 | return c 85 | } 86 | 87 | // CachePath sets the path that will be used for storing Postgres binaries archive. 88 | // If this option is not set, ~/.go-embedded-postgres will be used. 89 | func (c Config) CachePath(path string) Config { 90 | c.cachePath = path 91 | return c 92 | } 93 | 94 | // DataPath sets the path that will be used for the Postgres data directory. 95 | // If this option is set, a previously initialized data directory will be reused if possible. 96 | func (c Config) DataPath(path string) Config { 97 | c.dataPath = path 98 | return c 99 | } 100 | 101 | // BinariesPath sets the path of the pre-downloaded postgres binaries. 102 | // If this option is left unset, the binaries will be downloaded. 103 | func (c Config) BinariesPath(path string) Config { 104 | c.binariesPath = path 105 | return c 106 | } 107 | 108 | // Locale sets the default locale for initdb 109 | func (c Config) Locale(locale string) Config { 110 | c.locale = locale 111 | return c 112 | } 113 | 114 | // Encoding sets the default character set for initdb 115 | func (c Config) Encoding(encoding string) Config { 116 | c.encoding = encoding 117 | return c 118 | } 119 | 120 | // StartParameters sets run-time parameters when starting Postgres (passed to Postgres via "-c"). 121 | // 122 | // These parameters can be used to override the default configuration values in postgres.conf such 123 | // as max_connections=100. See https://www.postgresql.org/docs/current/runtime-config.html 124 | func (c Config) StartParameters(parameters map[string]string) Config { 125 | c.startParameters = parameters 126 | return c 127 | } 128 | 129 | // StartTimeout sets the max timeout that will be used when starting the Postgres process and creating the initial database. 
130 | func (c Config) StartTimeout(timeout time.Duration) Config { 131 | c.startTimeout = timeout 132 | return c 133 | } 134 | 135 | // Logger sets the logger for postgres output 136 | func (c Config) Logger(logger io.Writer) Config { 137 | c.logger = logger 138 | return c 139 | } 140 | 141 | // BinaryRepositoryURL set BinaryRepositoryURL to fetch PG Binary in case of Maven proxy 142 | func (c Config) BinaryRepositoryURL(binaryRepositoryURL string) Config { 143 | c.binaryRepositoryURL = binaryRepositoryURL 144 | return c 145 | } 146 | 147 | func (c Config) GetConnectionURL() string { 148 | return fmt.Sprintf("postgresql://%s:%s@%s:%d/%s", c.username, c.password, "localhost", c.port, c.database) 149 | } 150 | 151 | // PostgresVersion represents the semantic version used to fetch and run the Postgres process. 152 | type PostgresVersion string 153 | 154 | // Predefined supported Postgres versions. 155 | const ( 156 | V18 = PostgresVersion("18.0.0") 157 | V17 = PostgresVersion("17.5.0") 158 | V16 = PostgresVersion("16.9.0") 159 | V15 = PostgresVersion("15.13.0") 160 | V14 = PostgresVersion("14.18.0") 161 | V13 = PostgresVersion("13.21.0") 162 | V12 = PostgresVersion("12.22.0") 163 | V11 = PostgresVersion("11.22.0") 164 | V10 = PostgresVersion("10.23.0") 165 | V9 = PostgresVersion("9.6.24") 166 | ) 167 | -------------------------------------------------------------------------------- /decompression_test.go: -------------------------------------------------------------------------------- 1 | package embeddedpostgres 2 | 3 | import ( 4 | "archive/tar" 5 | "errors" 6 | "fmt" 7 | "io" 8 | "os" 9 | "path" 10 | "path/filepath" 11 | "syscall" 12 | "testing" 13 | 14 | "github.com/stretchr/testify/assert" 15 | "github.com/stretchr/testify/require" 16 | "github.com/xi2/xz" 17 | ) 18 | 19 | func Test_decompressTarXz(t *testing.T) { 20 | tempDir, err := os.MkdirTemp("", "temp_tar_test") 21 | if err != nil { 22 | panic(err) 23 | } 24 | if err := syscall.Rmdir(tempDir); err != nil { 
25 | panic(err) 26 | } 27 | 28 | archive, cleanUp := createTempXzArchive() 29 | defer cleanUp() 30 | 31 | err = decompressTarXz(defaultTarReader, archive, tempDir) 32 | 33 | assert.NoError(t, err) 34 | 35 | expectedExtractedFileLocation := filepath.Join(tempDir, "dir1", "dir2", "some_content") 36 | assert.FileExists(t, expectedExtractedFileLocation) 37 | 38 | fileContentBytes, err := os.ReadFile(expectedExtractedFileLocation) 39 | assert.NoError(t, err) 40 | 41 | assert.Equal(t, "b33r is g00d", string(fileContentBytes)) 42 | } 43 | 44 | func Test_decompressTarXz_ErrorWhenFileNotExists(t *testing.T) { 45 | err := decompressTarXz(defaultTarReader, "/does-not-exist", "/also-fake") 46 | 47 | assert.Error(t, err) 48 | assert.Contains( 49 | t, 50 | err.Error(), 51 | "unable to extract postgres archive /does-not-exist to /also-fake, if running parallel tests, configure RuntimePath to isolate testing directories", 52 | ) 53 | } 54 | 55 | func Test_decompressTarXz_ErrorWhenErrorDuringRead(t *testing.T) { 56 | tempDir, err := os.MkdirTemp("", "temp_tar_test") 57 | if err != nil { 58 | panic(err) 59 | } 60 | if err := syscall.Rmdir(tempDir); err != nil { 61 | panic(err) 62 | } 63 | 64 | archive, cleanUp := createTempXzArchive() 65 | defer cleanUp() 66 | 67 | err = decompressTarXz(func(reader *xz.Reader) (func() (*tar.Header, error), func() io.Reader) { 68 | return func() (*tar.Header, error) { 69 | return nil, errors.New("oh noes") 70 | }, nil 71 | }, archive, tempDir) 72 | 73 | assert.EqualError(t, err, "unable to extract postgres archive: oh noes") 74 | } 75 | 76 | func Test_decompressTarXz_ErrorWhenFailedToReadFileToCopy(t *testing.T) { 77 | tempDir, err := os.MkdirTemp("", "temp_tar_test") 78 | if err != nil { 79 | panic(err) 80 | } 81 | 82 | archive, cleanUp := createTempXzArchive() 83 | defer cleanUp() 84 | 85 | blockingFile := filepath.Join(tempDir, "blocking") 86 | 87 | if err = os.WriteFile(blockingFile, []byte("wazz"), 0000); err != nil { 88 | panic(err) 89 | } 90 | 
91 | fileBlockingExtractTarReader := func(reader *xz.Reader) (func() (*tar.Header, error), func() io.Reader) { 92 | shouldReadFile := true 93 | 94 | return func() (*tar.Header, error) { 95 | if shouldReadFile { 96 | shouldReadFile = false 97 | 98 | return &tar.Header{ 99 | Typeflag: tar.TypeReg, 100 | Name: "blocking", 101 | }, nil 102 | } 103 | 104 | return nil, io.EOF 105 | }, func() io.Reader { 106 | open, _ := os.Open("file_not_exists") 107 | return open 108 | } 109 | } 110 | 111 | err = decompressTarXz(fileBlockingExtractTarReader, archive, tempDir) 112 | 113 | assert.Regexp(t, "^unable to extract postgres archive:.+$", err) 114 | } 115 | 116 | func Test_decompressTarXz_ErrorWhenFileToCopyToNotExists(t *testing.T) { 117 | tempDir, err := os.MkdirTemp("", "temp_tar_test") 118 | if err != nil { 119 | panic(err) 120 | } 121 | if err := syscall.Rmdir(tempDir); err != nil { 122 | panic(err) 123 | } 124 | 125 | archive, cleanUp := createTempXzArchive() 126 | defer cleanUp() 127 | 128 | fileBlockingExtractTarReader := func(reader *xz.Reader) (func() (*tar.Header, error), func() io.Reader) { 129 | shouldReadFile := true 130 | 131 | return func() (*tar.Header, error) { 132 | if shouldReadFile { 133 | shouldReadFile = false 134 | 135 | return &tar.Header{ 136 | Typeflag: tar.TypeReg, 137 | Name: "some_dir/wazz/dazz/fazz", 138 | }, nil 139 | } 140 | 141 | return nil, io.EOF 142 | }, func() io.Reader { 143 | open, _ := os.Open("file_not_exists") 144 | return open 145 | } 146 | } 147 | 148 | err = decompressTarXz(fileBlockingExtractTarReader, archive, tempDir) 149 | 150 | assert.Regexp(t, "^unable to extract postgres archive:.+$", err) 151 | } 152 | 153 | func Test_decompressTarXz_ErrorWhenArchiveCorrupted(t *testing.T) { 154 | tempDir, err := os.MkdirTemp("", "temp_tar_test") 155 | if err != nil { 156 | panic(err) 157 | } 158 | if err := syscall.Rmdir(tempDir); err != nil { 159 | panic(err) 160 | } 161 | 162 | archive, cleanup := createTempXzArchive() 163 | 164 | defer 
cleanup() 165 | 166 | file, err := os.OpenFile(archive, os.O_WRONLY, 0664) 167 | if err != nil { 168 | panic(err) 169 | } 170 | 171 | if _, err := file.Seek(35, 0); err != nil { 172 | panic(err) 173 | } 174 | 175 | if _, err := file.WriteString("someJunk"); err != nil { 176 | panic(err) 177 | } 178 | 179 | if err := file.Close(); err != nil { 180 | panic(err) 181 | } 182 | 183 | err = decompressTarXz(defaultTarReader, archive, tempDir) 184 | 185 | assert.EqualError(t, err, "unable to extract postgres archive: xz: data is corrupt") 186 | } 187 | 188 | func Test_decompressTarXz_ErrorWithInvalidDestination(t *testing.T) { 189 | archive, cleanUp := createTempXzArchive() 190 | defer cleanUp() 191 | 192 | tempDir, err := os.MkdirTemp("", "temp_tar_test") 193 | require.NoError(t, err) 194 | defer func() { 195 | os.RemoveAll(tempDir) 196 | }() 197 | 198 | op := fmt.Sprintf(path.Join(tempDir, "%c"), rune(0)) 199 | 200 | err = decompressTarXz(defaultTarReader, archive, op) 201 | assert.EqualError( 202 | t, 203 | err, 204 | fmt.Sprintf("unable to extract postgres archive: mkdir %s: invalid argument", op), 205 | ) 206 | } 207 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

2 | 3 |

4 | 5 |

6 | Godoc 7 | Coverage Status 8 | Build Status 9 | Build Status 10 | Go Report Card 11 |

12 | 13 | # embedded-postgres 14 | 15 | Run a real Postgres database locally on Linux, OSX or Windows as part of another Go application or test. 16 | 17 | When testing, this provides a higher level of confidence than using any in-memory alternative. It also requires no other 18 | external dependencies outside of the Go build ecosystem. 19 | 20 | Heavily inspired by the Java projects [zonkyio/embedded-postgres](https://github.com/zonkyio/embedded-postgres) 21 | and [opentable/otj-pg-embedded](https://github.com/opentable/otj-pg-embedded), and reliant on the great work being done 22 | by [zonkyio/embedded-postgres-binaries](https://github.com/zonkyio/embedded-postgres-binaries) in order to fetch 23 | precompiled binaries 24 | from [Maven](https://mvnrepository.com/artifact/io.zonky.test.postgres/embedded-postgres-binaries-bom). 25 | 26 | ## Installation 27 | 28 | embedded-postgres uses Go modules and as such can be referenced by release version for use as a library. Use the 29 | following to add the latest release to your project: 30 | 31 | ```bash 32 | go get -u github.com/fergusstrange/embedded-postgres 33 | ``` 34 | 35 | Please note that Postgres 18 & Mac/Darwin builds require [Rosetta 2](https://github.com/fergusstrange/embedded-postgres/blob/cf5b3570ca7fc727fae6e4874ec08b4818b705b1/.circleci/config.yml#L28).
36 | 37 | ## How to use 38 | 39 | This library aims to require as little configuration as possible, favouring overridable defaults. 40 | 41 | | Configuration | Default Value | 42 | |---------------------|-------------------------------------------------| 43 | | Username | postgres | 44 | | Password | postgres | 45 | | Database | postgres | 46 | | Version | 18.0.0 | 47 | | Encoding | UTF8 | 48 | | Locale | C | 49 | | CachePath | $USER_HOME/.embedded-postgres-go/ | 50 | | RuntimePath | $USER_HOME/.embedded-postgres-go/extracted | 51 | | DataPath | $USER_HOME/.embedded-postgres-go/extracted/data | 52 | | BinariesPath | $USER_HOME/.embedded-postgres-go/extracted | 53 | | BinaryRepositoryURL | https://repo1.maven.org/maven2 | 54 | | Port | 5432 | 55 | | StartTimeout | 15 Seconds | 56 | | StartParameters | map[string]string{"max_connections": "101"} | 57 | 58 | The *RuntimePath* directory is erased and recreated at each `Start()` and is therefore not suitable for persistent data. 59 | 60 | If a persistent data location is required, set *DataPath* to a directory outside *RuntimePath*. 61 | 62 | If the *DataPath* directory is empty or was initialized by an incompatible Postgres version, it will be 63 | removed and Postgres reinitialized. 64 | 65 | Postgres binaries will be downloaded (if not already cached) and extracted into *BinariesPath* if `BinariesPath/bin/pg_ctl` doesn't exist. 66 | The *BinaryRepositoryURL* parameter allows overriding the Maven repository URL used to fetch Postgres binaries. 67 | If the *BinariesPath* directory does exist, whatever binary version is placed there will be used (no version check 68 | is done). 69 | If your tests need to run multiple different versions of Postgres, make sure 70 | *BinariesPath* is a subdirectory of *RuntimePath*.
71 | 72 | A single Postgres instance can be created, started and stopped as follows: 73 | 74 | ```go 75 | postgres := embeddedpostgres.NewDatabase() 76 | err := postgres.Start() 77 | 78 | // Do test logic 79 | 80 | err = postgres.Stop() 81 | ``` 82 | 83 | or created with a custom configuration: 84 | 85 | ```go 86 | logger := &bytes.Buffer{} 87 | postgres := embeddedpostgres.NewDatabase(embeddedpostgres.DefaultConfig(). 88 | Username("beer"). 89 | Password("wine"). 90 | Database("gin"). 91 | Version(embeddedpostgres.V12). 92 | RuntimePath("/tmp/embedded-postgres-runtime"). // erased on every Start(), so use a dedicated directory 93 | BinaryRepositoryURL("https://repo.local/central.proxy"). 94 | Port(9876). 95 | StartTimeout(45 * time.Second). 96 | StartParameters(map[string]string{"max_connections": "200"}). 97 | Logger(logger)) 98 | err := postgres.Start() 99 | 100 | // Do test logic 101 | 102 | err = postgres.Stop() 103 | ``` 104 | 105 | Note that if `postgres.Stop()` is not called, the child Postgres process will not be released and the 106 | caller will block. 107 | 108 | ## Examples 109 | 110 | There are a number of realistic examples of how to use this library 111 | in [examples](https://github.com/fergusstrange/embedded-postgres/tree/master/examples). 112 | 113 | ## Credits 114 | 115 | - [Gopherize Me](https://gopherize.me) Thanks for the awesome logo template. 116 | - [zonkyio/embedded-postgres-binaries](https://github.com/zonkyio/embedded-postgres-binaries) Without which the 117 | precompiled Postgres binaries would not exist for this to work. 118 | 119 | ## Contributing 120 | 121 | View the [contributing guide](CONTRIBUTING.md).
122 | 123 | -------------------------------------------------------------------------------- /prepare_database_test.go: -------------------------------------------------------------------------------- 1 | package embeddedpostgres 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | "io/ioutil" 7 | "os" 8 | "path/filepath" 9 | "testing" 10 | 11 | "github.com/stretchr/testify/assert" 12 | ) 13 | 14 | func Test_defaultInitDatabase_ErrorWhenCannotCreatePasswordFile(t *testing.T) { 15 | err := defaultInitDatabase("path_not_exists", "path_not_exists", "path_not_exists", "Tom", "Beer", "", "", os.Stderr) 16 | 17 | assert.EqualError(t, err, "unable to write password file to path_not_exists/pwfile") 18 | } 19 | 20 | func Test_defaultInitDatabase_ErrorWhenCannotStartInitDBProcess(t *testing.T) { 21 | binTempDir, err := os.MkdirTemp("", "prepare_database_test_bin") 22 | if err != nil { 23 | panic(err) 24 | } 25 | 26 | runtimeTempDir, err := os.MkdirTemp("", "prepare_database_test_runtime") 27 | if err != nil { 28 | panic(err) 29 | } 30 | 31 | logFile, err := ioutil.TempFile("", "prepare_database_test_log") 32 | if err != nil { 33 | panic(err) 34 | } 35 | 36 | defer func() { 37 | if err := os.RemoveAll(binTempDir); err != nil { 38 | panic(err) 39 | } 40 | 41 | if err := os.RemoveAll(runtimeTempDir); err != nil { 42 | panic(err) 43 | } 44 | 45 | if err := os.Remove(logFile.Name()); err != nil { 46 | panic(err) 47 | } 48 | }() 49 | 50 | _, _ = logFile.Write([]byte("and here are the logs!")) 51 | 52 | err = defaultInitDatabase(binTempDir, runtimeTempDir, filepath.Join(runtimeTempDir, "data"), "Tom", "Beer", "", "", logFile) 53 | 54 | assert.NotNil(t, err) 55 | assert.Contains(t, err.Error(), fmt.Sprintf("unable to init database using '%s/bin/initdb -A password -U Tom -D %s/data --pwfile=%s/pwfile'", 56 | binTempDir, 57 | runtimeTempDir, 58 | runtimeTempDir)) 59 | assert.Contains(t, err.Error(), "and here are the logs!") 60 | assert.FileExists(t, filepath.Join(runtimeTempDir, "pwfile")) 61 
| } 62 | 63 | func Test_defaultInitDatabase_ErrorInvalidLocaleSetting(t *testing.T) { 64 | tempDir, err := os.MkdirTemp("", "prepare_database_test") 65 | if err != nil { 66 | panic(err) 67 | } 68 | 69 | defer func() { 70 | if err := os.RemoveAll(tempDir); err != nil { 71 | panic(err) 72 | } 73 | }() 74 | 75 | err = defaultInitDatabase(tempDir, tempDir, filepath.Join(tempDir, "data"), "postgres", "postgres", "en_XY", "", os.Stderr) 76 | 77 | assert.NotNil(t, err) 78 | assert.Contains(t, err.Error(), fmt.Sprintf("unable to init database using '%s/bin/initdb -A password -U postgres -D %s/data --pwfile=%s/pwfile --locale=en_XY'", 79 | tempDir, 80 | tempDir, 81 | tempDir)) 82 | } 83 | 84 | func Test_defaultInitDatabase_ErrorInvalidEncodingSetting(t *testing.T) { 85 | tempDir, err := os.MkdirTemp("", "prepare_database_test") 86 | if err != nil { 87 | panic(err) 88 | } 89 | 90 | defer func() { 91 | if err := os.RemoveAll(tempDir); err != nil { 92 | panic(err) 93 | } 94 | }() 95 | 96 | err = defaultInitDatabase(tempDir, tempDir, filepath.Join(tempDir, "data"), "postgres", "postgres", "", "invalid", os.Stderr) 97 | 98 | assert.NotNil(t, err) 99 | assert.Contains(t, err.Error(), fmt.Sprintf("unable to init database using '%s/bin/initdb -A password -U postgres -D %s/data --pwfile=%s/pwfile --encoding=invalid'", 100 | tempDir, 101 | tempDir, 102 | tempDir)) 103 | } 104 | 105 | func Test_defaultInitDatabase_PwFileRemoved(t *testing.T) { 106 | tempDir, err := os.MkdirTemp("", "prepare_database_test") 107 | if err != nil { 108 | panic(err) 109 | } 110 | 111 | defer func() { 112 | if err := os.RemoveAll(tempDir); err != nil { 113 | panic(err) 114 | } 115 | }() 116 | 117 | database := NewDatabase(DefaultConfig().RuntimePath(tempDir)) 118 | if err := database.Start(); err != nil { 119 | t.Fatal(err) 120 | } 121 | 122 | defer func() { 123 | if err := database.Stop(); err != nil { 124 | t.Fatal(err) 125 | } 126 | }() 127 | 128 | pwFile := filepath.Join(tempDir, "pwfile") 129 | _, err 
= os.Stat(pwFile) 130 | 131 | assert.True(t, os.IsNotExist(err), "pwfile (%v) still exists after starting the db", pwFile) 132 | } 133 | 134 | func Test_defaultCreateDatabase_ErrorWhenSQLOpenError(t *testing.T) { 135 | err := defaultCreateDatabase(1234, "user client_encoding=lol", "password", "database") 136 | 137 | assert.EqualError(t, err, "unable to connect to create database with custom name database with the following error: client_encoding must be absent or 'UTF8'") 138 | } 139 | 140 | func Test_defaultCreateDatabase_DashesInName(t *testing.T) { 141 | database := NewDatabase(DefaultConfig(). 142 | Port(9832). 143 | Database("my-cool-database")) 144 | 145 | if err := database.Start(); err != nil { 146 | t.Fatal(err) 147 | } 148 | 149 | if err := database.Stop(); err != nil { 150 | t.Fatal(err) 151 | } 152 | } 153 | 154 | func Test_defaultCreateDatabase_ErrorWhenQueryError(t *testing.T) { 155 | database := NewDatabase(DefaultConfig(). 156 | Port(9831). 157 | Database("b33r")) 158 | if err := database.Start(); err != nil { 159 | t.Fatal(err) 160 | } 161 | 162 | defer func() { 163 | if err := database.Stop(); err != nil { 164 | t.Fatal(err) 165 | } 166 | }() 167 | 168 | err := defaultCreateDatabase(9831, "postgres", "postgres", "b33r") 169 | 170 | assert.EqualError(t, err, `unable to connect to create database with custom name b33r with the following error: pq: database "b33r" already exists`) 171 | } 172 | 173 | func Test_healthCheckDatabase_ErrorWhenSQLConnectingError(t *testing.T) { 174 | err := healthCheckDatabase(1234, "tom client_encoding=lol", "more", "b33r") 175 | 176 | assert.EqualError(t, err, "client_encoding must be absent or 'UTF8'") 177 | } 178 | 179 | type CloserWithoutErr struct{} 180 | 181 | func (c *CloserWithoutErr) Close() error { 182 | return nil 183 | } 184 | 185 | func TestConnCloserWithoutErr(t *testing.T) { 186 | originalErr := errors.New("OriginalError") 187 | 188 | tests := []struct { 189 | name string 190 | err error 191 | 
expectedErrTxt string 192 | }{ 193 | { 194 | "No original error, no error from closer", 195 | nil, 196 | "", 197 | }, 198 | { 199 | "original error, no error from closer", 200 | originalErr, 201 | originalErr.Error(), 202 | }, 203 | } 204 | 205 | for _, tt := range tests { 206 | t.Run(tt.name, func(t *testing.T) { 207 | resultErr := connectionClose(&CloserWithoutErr{}, tt.err) 208 | 209 | if len(tt.expectedErrTxt) == 0 { 210 | if resultErr != nil { 211 | t.Fatalf("Expected nil error, got error: %v", resultErr) 212 | } 213 | 214 | return 215 | } 216 | 217 | if resultErr.Error() != tt.expectedErrTxt { 218 | t.Fatalf("Expected error: %v, got error: %v", tt.expectedErrTxt, resultErr) 219 | } 220 | }) 221 | } 222 | } 223 | 224 | type CloserWithErr struct{} 225 | 226 | const testError = "TestError" 227 | 228 | func (c *CloserWithErr) Close() error { 229 | return errors.New(testError) 230 | } 231 | 232 | func TestConnCloserWithErr(t *testing.T) { 233 | originalErr := errors.New("OriginalError") 234 | 235 | closeDBConnErr := fmt.Errorf(fmtCloseDBConn, errors.New(testError)) 236 | 237 | tests := []struct { 238 | name string 239 | err error 240 | expectedErrTxt string 241 | }{ 242 | { 243 | "No original error, error from closer", 244 | nil, 245 | closeDBConnErr.Error(), 246 | }, 247 | { 248 | "original error, error from closer", 249 | originalErr, 250 | fmt.Errorf(fmtAfterError, closeDBConnErr, originalErr).Error(), 251 | }, 252 | } 253 | 254 | for _, tt := range tests { 255 | t.Run(tt.name, func(t *testing.T) { 256 | resultErr := connectionClose(&CloserWithErr{}, tt.err) 257 | 258 | if len(tt.expectedErrTxt) == 0 { 259 | if resultErr != nil { 260 | t.Fatalf("Expected nil error, got error: %v", resultErr) 261 | } 262 | 263 | return 264 | } 265 | 266 | if resultErr.Error() != tt.expectedErrTxt { 267 | t.Fatalf("Expected error: %v, got error: %v", tt.expectedErrTxt, resultErr) 268 | } 269 | }) 270 | } 271 | } 272 | 
-------------------------------------------------------------------------------- /examples/go.sum: -------------------------------------------------------------------------------- 1 | github.com/ClickHouse/clickhouse-go v1.4.5/go.mod h1:EaI/sW7Azgz9UATzd5ZdZHRUhHgv5+JMS9NSr2smCJI= 2 | github.com/benbjohnson/clock v1.1.0 h1:Q92kusRqC1XV2MjkWETPvjJVqKetz1OzxZB7mHJLju8= 3 | github.com/benbjohnson/clock v1.1.0/go.mod h1:J11/hYXuz8f4ySSvYwY0FKfm+ezbsZBKZxNJlLklBHA= 4 | github.com/bkaradzic/go-lz4 v1.0.0/go.mod h1:0YdlkowM3VswSROI7qDxhRvJ3sLhlFrRRwjwegp5jy4= 5 | github.com/cloudflare/golz4 v0.0.0-20150217214814-ef862a3cdc58/go.mod h1:EOBUe0h4xcZ5GoxqC5SDxFQ8gwyZPKQoEzownBlhI80= 6 | github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 7 | github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= 8 | github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 9 | github.com/denisenkom/go-mssqldb v0.10.0/go.mod h1:xbL0rPBG9cCiLr28tMa8zpbdarY27NDyej4t/EjAShU= 10 | github.com/go-sql-driver/mysql v1.4.0/go.mod h1:zAC/RDZ24gD3HViQzih4MyKcchzm+sOG5ZlKdlhCg5w= 11 | github.com/go-sql-driver/mysql v1.6.0 h1:BCTh4TKNUYmOmMUcQ3IipzF5prigylS7XXjEkfCHuOE= 12 | github.com/go-sql-driver/mysql v1.6.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg= 13 | github.com/golang-sql/civil v0.0.0-20190719163853-cb61b32ac6fe/go.mod h1:8vg3r2VgvsThLBIFL93Qb5yWzgyZWhEmBwUJWevAkK0= 14 | github.com/jmoiron/sqlx v1.2.0/go.mod h1:1FEQNm3xlJgrMD+FBdI9+xvCksHtbpVBBw5dYhBSsks= 15 | github.com/jmoiron/sqlx v1.3.5 h1:vFFPA71p1o5gAeqtEAwLU4dnX2napprKtHr7PYIcN3g= 16 | github.com/jmoiron/sqlx v1.3.5/go.mod h1:nRVWtLre0KfCLJvgxzCsLVMogSvQ1zNJtpYr2Ccp0mQ= 17 | github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= 18 | github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= 19 | github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= 20 | 
github.com/lib/pq v1.0.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= 21 | github.com/lib/pq v1.2.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= 22 | github.com/lib/pq v1.10.2/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= 23 | github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= 24 | github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= 25 | github.com/mattn/go-sqlite3 v1.9.0/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc= 26 | github.com/mattn/go-sqlite3 v1.14.6/go.mod h1:NyWgC/yNuGj7Q9rpYnZvas74GogHl5/Z4A/KQRfk6bU= 27 | github.com/mattn/go-sqlite3 v1.14.8 h1:gDp86IdQsN/xWjIEmr9MF6o9mpksUgh0fu+9ByFxzIU= 28 | github.com/mattn/go-sqlite3 v1.14.8/go.mod h1:NyWgC/yNuGj7Q9rpYnZvas74GogHl5/Z4A/KQRfk6bU= 29 | github.com/pierrec/lz4 v2.0.5+incompatible/go.mod h1:pdkljMzZIN41W+lC3N2tnIh5sFi+IEE17M5jbnwPHcY= 30 | github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= 31 | github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= 32 | github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= 33 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= 34 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 35 | github.com/pressly/goose/v3 v3.0.1 h1:XdndErg0gNhnWGhQYAurNWw2oYuirQTaAN/Dbw+arUM= 36 | github.com/pressly/goose/v3 v3.0.1/go.mod h1:1L3t2XSf5sGj6OkiCD61z2DJARRWr2sqbJt3JKVnbwo= 37 | github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= 38 | github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= 39 | github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= 40 | github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= 41 | github.com/xi2/xz v0.0.0-20171230120015-48954b6210f8 
h1:nIPpBwaJSVYIxUFsDv3M8ofmx9yWTog9BfvIu0q41lo= 42 | github.com/xi2/xz v0.0.0-20171230120015-48954b6210f8/go.mod h1:HUYIGzjTL3rfEspMxjDjgmT5uz5wzYJKVo23qUhYTos= 43 | github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= 44 | github.com/ziutek/mymysql v1.5.4/go.mod h1:LMSpPZ6DbqWFxNCHW77HeMg9I646SAhApZ/wKdgO/C0= 45 | go.uber.org/atomic v1.7.0 h1:ADUqmZGgLDDfbSL9ZmPxKTybcoEYHgpYfELNoN+7hsw= 46 | go.uber.org/atomic v1.7.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= 47 | go.uber.org/goleak v1.1.11/go.mod h1:cwTWslyiVhfpKIDGSZEM2HlOvcqm+tG4zioyIeLoqMQ= 48 | go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= 49 | go.uber.org/multierr v1.6.0 h1:y6IPFStTAIT5Ytl7/XYmHvzXQ7S3g/IeZW9hyZ5thw4= 50 | go.uber.org/multierr v1.6.0/go.mod h1:cdWPpRnG4AhwMwsgIHip0KRBQjJy5kYEpYjJxpXp9iU= 51 | go.uber.org/zap v1.21.0 h1:WefMeulhovoZ2sYXz7st6K0sLj7bBhpiFaud4r4zST8= 52 | go.uber.org/zap v1.21.0/go.mod h1:wjWOCqI0f2ZZrJF/UufIOkiC8ii6tm1iqIsLo76RfJw= 53 | golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= 54 | golang.org/x/crypto v0.0.0-20190325154230-a5d413f7728c/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= 55 | golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= 56 | golang.org/x/lint v0.0.0-20190930215403-16217165b5de/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= 57 | golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= 58 | golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= 59 | golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= 60 | golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= 61 | golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM= 62 
| golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= 63 | golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= 64 | golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= 65 | golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= 66 | golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= 67 | golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= 68 | golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= 69 | golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= 70 | golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= 71 | golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= 72 | golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= 73 | golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= 74 | golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= 75 | golang.org/x/tools v0.1.5/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= 76 | golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= 77 | golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= 78 | golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= 79 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= 80 | gopkg.in/check.v1 
v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= 81 | gopkg.in/yaml.v2 v2.2.8 h1:obN1ZagJSUGI0Ek/LBmuj4SNLPfIny3KsKFopxRdj10= 82 | gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= 83 | gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 84 | gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 85 | gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= 86 | -------------------------------------------------------------------------------- /embedded_postgres.go: -------------------------------------------------------------------------------- 1 | package embeddedpostgres 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | "net" 7 | "os" 8 | "os/exec" 9 | "path/filepath" 10 | "runtime" 11 | "strings" 12 | "sync" 13 | ) 14 | 15 | var mu sync.Mutex 16 | 17 | var ( 18 | ErrServerNotStarted = errors.New("server has not been started") 19 | ErrServerAlreadyStarted = errors.New("server is already started") 20 | ) 21 | 22 | // EmbeddedPostgres maintains all configuration and runtime functions for maintaining the lifecycle of one Postgres process. 23 | type EmbeddedPostgres struct { 24 | config Config 25 | cacheLocator CacheLocator 26 | remoteFetchStrategy RemoteFetchStrategy 27 | initDatabase initDatabase 28 | createDatabase createDatabase 29 | started bool 30 | syncedLogger *syncedLogger 31 | } 32 | 33 | // NewDatabase creates a new EmbeddedPostgres struct that can be used to start and stop a Postgres process. 34 | // When called with no parameters it will assume a default configuration state provided by the DefaultConfig method. 35 | // When called with parameters the first Config parameter will be used for configuration. 
36 | func NewDatabase(config ...Config) *EmbeddedPostgres { 37 | if len(config) < 1 { 38 | return newDatabaseWithConfig(DefaultConfig()) 39 | } 40 | 41 | return newDatabaseWithConfig(config[0]) 42 | } 43 | 44 | func newDatabaseWithConfig(config Config) *EmbeddedPostgres { 45 | versionStrategy := defaultVersionStrategy( 46 | config, 47 | runtime.GOOS, 48 | runtime.GOARCH, 49 | linuxMachineName, 50 | shouldUseAlpineLinuxBuild, 51 | ) 52 | cacheLocator := defaultCacheLocator(config.cachePath, versionStrategy) 53 | remoteFetchStrategy := defaultRemoteFetchStrategy(config.binaryRepositoryURL, versionStrategy, cacheLocator) 54 | 55 | return &EmbeddedPostgres{ 56 | config: config, 57 | cacheLocator: cacheLocator, 58 | remoteFetchStrategy: remoteFetchStrategy, 59 | initDatabase: defaultInitDatabase, 60 | createDatabase: defaultCreateDatabase, 61 | started: false, 62 | } 63 | } 64 | 65 | // Start will try to start the configured Postgres process returning an error when there were any problems with invocation. 66 | // If any error occurs Start will try to also Stop the Postgres process in order to not leave any sub-process running. 
67 | // 68 | //nolint:funlen 69 | func (ep *EmbeddedPostgres) Start() error { 70 | if ep.started { 71 | return ErrServerAlreadyStarted 72 | } 73 | 74 | if err := ensurePortAvailable(ep.config.port); err != nil { 75 | return err 76 | } 77 | 78 | logger, err := newSyncedLogger("", ep.config.logger) 79 | if err != nil { 80 | return errors.New("unable to create logger") 81 | } 82 | 83 | ep.syncedLogger = logger 84 | 85 | cacheLocation, cacheExists := ep.cacheLocator() 86 | 87 | if ep.config.runtimePath == "" { 88 | ep.config.runtimePath = filepath.Join(filepath.Dir(cacheLocation), "extracted") 89 | } 90 | 91 | if ep.config.dataPath == "" { 92 | ep.config.dataPath = filepath.Join(ep.config.runtimePath, "data") 93 | } 94 | 95 | if err := os.RemoveAll(ep.config.runtimePath); err != nil { 96 | return fmt.Errorf("unable to clean up runtime directory %s with error: %s", ep.config.runtimePath, err) 97 | } 98 | 99 | if ep.config.binariesPath == "" { 100 | ep.config.binariesPath = ep.config.runtimePath 101 | } 102 | 103 | if err := ep.downloadAndExtractBinary(cacheExists, cacheLocation); err != nil { 104 | return err 105 | } 106 | 107 | if err := os.MkdirAll(ep.config.runtimePath, os.ModePerm); err != nil { 108 | return fmt.Errorf("unable to create runtime directory %s with error: %s", ep.config.runtimePath, err) 109 | } 110 | 111 | reuseData := dataDirIsValid(ep.config.dataPath, ep.config.version) 112 | 113 | if !reuseData { 114 | if err := ep.cleanDataDirectoryAndInit(); err != nil { 115 | return err 116 | } 117 | } 118 | 119 | if err := startPostgres(ep); err != nil { 120 | return err 121 | } 122 | 123 | if err := ep.syncedLogger.flush(); err != nil { 124 | return err 125 | } 126 | 127 | ep.started = true 128 | 129 | if !reuseData { 130 | if err := ep.createDatabase(ep.config.port, ep.config.username, ep.config.password, ep.config.database); err != nil { 131 | if stopErr := stopPostgres(ep); stopErr != nil { 132 | return fmt.Errorf("unable to stop database caused by error 
%s", err) 133 | } 134 | 135 | return err 136 | } 137 | } 138 | 139 | if err := healthCheckDatabaseOrTimeout(ep.config); err != nil { 140 | if stopErr := stopPostgres(ep); stopErr != nil { 141 | return fmt.Errorf("unable to stop database caused by error %s", err) 142 | } 143 | 144 | return err 145 | } 146 | 147 | return nil 148 | } 149 | 150 | func (ep *EmbeddedPostgres) downloadAndExtractBinary(cacheExists bool, cacheLocation string) error { 151 | // lock to prevent collisions with duplicate downloads 152 | mu.Lock() 153 | defer mu.Unlock() 154 | 155 | _, binDirErr := os.Stat(filepath.Join(ep.config.binariesPath, "bin", "pg_ctl")) 156 | if os.IsNotExist(binDirErr) { 157 | if !cacheExists { 158 | if err := ep.remoteFetchStrategy(); err != nil { 159 | return err 160 | } 161 | } 162 | 163 | if err := decompressTarXz(defaultTarReader, cacheLocation, ep.config.binariesPath); err != nil { 164 | return err 165 | } 166 | } 167 | return nil 168 | } 169 | 170 | func (ep *EmbeddedPostgres) cleanDataDirectoryAndInit() error { 171 | if err := os.RemoveAll(ep.config.dataPath); err != nil { 172 | return fmt.Errorf("unable to clean up data directory %s with error: %s", ep.config.dataPath, err) 173 | } 174 | 175 | if err := ep.initDatabase(ep.config.binariesPath, ep.config.runtimePath, ep.config.dataPath, ep.config.username, ep.config.password, ep.config.locale, ep.config.encoding, ep.syncedLogger.file); err != nil { 176 | return err 177 | } 178 | 179 | return nil 180 | } 181 | 182 | // Stop will try to stop the Postgres process gracefully returning an error when there were any problems. 
183 | func (ep *EmbeddedPostgres) Stop() error { 184 | if !ep.started { 185 | return ErrServerNotStarted 186 | } 187 | 188 | if err := stopPostgres(ep); err != nil { 189 | return err 190 | } 191 | 192 | ep.started = false 193 | 194 | if err := ep.syncedLogger.flush(); err != nil { 195 | return err 196 | } 197 | 198 | return nil 199 | } 200 | 201 | func encodeOptions(port uint32, parameters map[string]string) string { 202 | options := []string{fmt.Sprintf("-p %d", port)} 203 | for k, v := range parameters { 204 | // Double-quote parameter values - they may have spaces. 205 | // Careful: CMD on Windows uses only double quotes to delimit strings. 206 | // It treats single quotes as regular characters. 207 | options = append(options, fmt.Sprintf("-c %s=\"%s\"", k, v)) 208 | } 209 | return strings.Join(options, " ") 210 | } 211 | 212 | func startPostgres(ep *EmbeddedPostgres) error { 213 | postgresBinary := filepath.Join(ep.config.binariesPath, "bin/pg_ctl") 214 | postgresProcess := exec.Command(postgresBinary, "start", "-w", 215 | "-D", ep.config.dataPath, 216 | "-o", encodeOptions(ep.config.port, ep.config.startParameters)) 217 | postgresProcess.Stdout = ep.syncedLogger.file 218 | postgresProcess.Stderr = ep.syncedLogger.file 219 | 220 | if err := postgresProcess.Run(); err != nil { 221 | _ = ep.syncedLogger.flush() 222 | logContent, _ := readLogsOrTimeout(ep.syncedLogger.file) 223 | 224 | return fmt.Errorf("could not start postgres using %s:\n%s", postgresProcess.String(), string(logContent)) 225 | } 226 | 227 | return nil 228 | } 229 | 230 | func stopPostgres(ep *EmbeddedPostgres) error { 231 | postgresBinary := filepath.Join(ep.config.binariesPath, "bin/pg_ctl") 232 | postgresProcess := exec.Command(postgresBinary, "stop", "-w", 233 | "-D", ep.config.dataPath) 234 | postgresProcess.Stderr = ep.syncedLogger.file 235 | postgresProcess.Stdout = ep.syncedLogger.file 236 | 237 | if err := postgresProcess.Run(); err != nil { 238 | return err 239 | } 240 | 241 | return 
nil 242 | } 243 | 244 | func ensurePortAvailable(port uint32) error { 245 | conn, err := net.Listen("tcp", fmt.Sprintf("localhost:%d", port)) 246 | if err != nil { 247 | return fmt.Errorf("process already listening on port %d", port) 248 | } 249 | 250 | if err := conn.Close(); err != nil { 251 | return err 252 | } 253 | 254 | return nil 255 | } 256 | 257 | func dataDirIsValid(dataDir string, version PostgresVersion) bool { 258 | pgVersion := filepath.Join(dataDir, "PG_VERSION") 259 | 260 | d, err := os.ReadFile(pgVersion) 261 | if err != nil { 262 | return false 263 | } 264 | 265 | v := strings.TrimSuffix(string(d), "\n") 266 | 267 | return strings.HasPrefix(string(version), v) 268 | } 269 | -------------------------------------------------------------------------------- /remote_fetch_test.go: -------------------------------------------------------------------------------- 1 | package embeddedpostgres 2 | 3 | import ( 4 | "archive/zip" 5 | "crypto/sha256" 6 | "encoding/hex" 7 | "github.com/stretchr/testify/require" 8 | "io" 9 | "net/http" 10 | "net/http/httptest" 11 | "os" 12 | "path" 13 | "path/filepath" 14 | "strings" 15 | "testing" 16 | 17 | "github.com/stretchr/testify/assert" 18 | ) 19 | 20 | func Test_defaultRemoteFetchStrategy_ErrorWhenHttpGet(t *testing.T) { 21 | remoteFetchStrategy := defaultRemoteFetchStrategy("http://localhost:1234/maven2", 22 | testVersionStrategy(), 23 | testCacheLocator()) 24 | 25 | err := remoteFetchStrategy() 26 | 27 | assert.EqualError(t, err, "unable to connect to http://localhost:1234/maven2") 28 | } 29 | 30 | func Test_defaultRemoteFetchStrategy_ErrorWhenHttpStatusNot200(t *testing.T) { 31 | server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 32 | w.WriteHeader(http.StatusNotFound) 33 | })) 34 | defer server.Close() 35 | 36 | remoteFetchStrategy := defaultRemoteFetchStrategy(server.URL, 37 | testVersionStrategy(), 38 | testCacheLocator()) 39 | 40 | err := remoteFetchStrategy() 41 | 42 | 
assert.EqualError(t, err, "no version found matching 1.2.3") 43 | } 44 | 45 | func Test_defaultRemoteFetchStrategy_ErrorWhenBodyReadIssue(t *testing.T) { 46 | server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 47 | w.Header().Set("Content-Length", "1") 48 | })) 49 | defer server.Close() 50 | 51 | remoteFetchStrategy := defaultRemoteFetchStrategy(server.URL+"/maven2", 52 | testVersionStrategy(), 53 | testCacheLocator()) 54 | 55 | err := remoteFetchStrategy() 56 | 57 | assert.EqualError(t, err, "error fetching postgres: unexpected EOF") 58 | } 59 | 60 | func Test_defaultRemoteFetchStrategy_ErrorWhenCannotUnzipSubFile(t *testing.T) { 61 | server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 62 | if strings.HasSuffix(r.RequestURI, ".sha256") { 63 | w.WriteHeader(http.StatusNotFound) 64 | return 65 | } 66 | })) 67 | defer server.Close() 68 | 69 | remoteFetchStrategy := defaultRemoteFetchStrategy(server.URL+"/maven2", 70 | testVersionStrategy(), 71 | testCacheLocator()) 72 | 73 | err := remoteFetchStrategy() 74 | 75 | assert.EqualError(t, err, "error fetching postgres: zip: not a valid zip file") 76 | } 77 | 78 | func Test_defaultRemoteFetchStrategy_ErrorWhenCannotUnzip(t *testing.T) { 79 | server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 80 | if strings.HasSuffix(r.RequestURI, ".sha256") { 81 | w.WriteHeader(404) 82 | return 83 | } 84 | 85 | if _, err := w.Write([]byte("lolz")); err != nil { 86 | panic(err) 87 | } 88 | })) 89 | defer server.Close() 90 | 91 | remoteFetchStrategy := defaultRemoteFetchStrategy(server.URL+"/maven2", 92 | testVersionStrategy(), 93 | testCacheLocator()) 94 | 95 | err := remoteFetchStrategy() 96 | 97 | assert.EqualError(t, err, "error fetching postgres: zip: not a valid zip file") 98 | } 99 | 100 | func Test_defaultRemoteFetchStrategy_ErrorWhenNoSubTarArchive(t *testing.T) { 101 | server := 
httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 102 | if strings.HasSuffix(r.RequestURI, ".sha256") { 103 | w.WriteHeader(http.StatusNotFound) 104 | return 105 | } 106 | 107 | MyZipWriter := zip.NewWriter(w) 108 | 109 | if err := MyZipWriter.Close(); err != nil { 110 | t.Error(err) 111 | } 112 | })) 113 | defer server.Close() 114 | 115 | remoteFetchStrategy := defaultRemoteFetchStrategy(server.URL+"/maven2", 116 | testVersionStrategy(), 117 | testCacheLocator()) 118 | 119 | err := remoteFetchStrategy() 120 | 121 | assert.EqualError(t, err, "error fetching postgres: cannot find binary in archive retrieved from "+server.URL+"/maven2/io/zonky/test/postgres/embedded-postgres-binaries-darwin-amd64/1.2.3/embedded-postgres-binaries-darwin-amd64-1.2.3.jar") 122 | } 123 | 124 | func Test_defaultRemoteFetchStrategy_ErrorWhenCannotExtractSubArchive(t *testing.T) { 125 | jarFile, cleanUp := createTempZipArchive() 126 | defer cleanUp() 127 | 128 | server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 129 | if strings.HasSuffix(r.RequestURI, ".sha256") { 130 | w.WriteHeader(http.StatusNotFound) 131 | return 132 | } 133 | 134 | bytes, err := os.ReadFile(jarFile) 135 | if err != nil { 136 | panic(err) 137 | } 138 | if _, err := w.Write(bytes); err != nil { 139 | panic(err) 140 | } 141 | })) 142 | defer server.Close() 143 | 144 | remoteFetchStrategy := defaultRemoteFetchStrategy(server.URL+"/maven2", 145 | testVersionStrategy(), 146 | func() (s string, b bool) { 147 | return filepath.FromSlash("/invalid"), false 148 | }) 149 | 150 | err := remoteFetchStrategy() 151 | 152 | assert.Regexp(t, "^unable to extract postgres archive:.+$", err) 153 | } 154 | 155 | func Test_defaultRemoteFetchStrategy_ErrorWhenCannotCreateCacheDirectory(t *testing.T) { 156 | jarFile, cleanUp := createTempZipArchive() 157 | defer cleanUp() 158 | 159 | fileBlockingExtractDirectory := filepath.Join(filepath.Dir(jarFile), 
"a_file_blocking_extract") 160 | 161 | if _, err := os.Create(fileBlockingExtractDirectory); err != nil { 162 | panic(err) 163 | } 164 | 165 | cacheLocation := filepath.Join(fileBlockingExtractDirectory, "cache_file.jar") 166 | 167 | server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 168 | if strings.HasSuffix(r.RequestURI, ".sha256") { 169 | w.WriteHeader(http.StatusNotFound) 170 | return 171 | } 172 | 173 | bytes, err := os.ReadFile(jarFile) 174 | if err != nil { 175 | panic(err) 176 | } 177 | if _, err := w.Write(bytes); err != nil { 178 | panic(err) 179 | } 180 | })) 181 | 182 | defer server.Close() 183 | 184 | remoteFetchStrategy := defaultRemoteFetchStrategy(server.URL+"/maven2", 185 | testVersionStrategy(), 186 | func() (s string, b bool) { 187 | return cacheLocation, false 188 | }) 189 | 190 | err := remoteFetchStrategy() 191 | 192 | assert.Regexp(t, "^unable to extract postgres archive:.+$", err) 193 | } 194 | 195 | func Test_defaultRemoteFetchStrategy_ErrorWhenCannotCreateSubArchiveFile(t *testing.T) { 196 | jarFile, cleanUp := createTempZipArchive() 197 | defer cleanUp() 198 | 199 | cacheLocation := filepath.Join(filepath.Dir(jarFile), "extract_directory", "cache_file.jar") 200 | 201 | if err := os.MkdirAll(cacheLocation, os.ModePerm); err != nil { 202 | panic(err) 203 | } 204 | 205 | server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 206 | if strings.HasSuffix(r.RequestURI, ".sha256") { 207 | w.WriteHeader(http.StatusNotFound) 208 | return 209 | } 210 | 211 | bytes, err := os.ReadFile(jarFile) 212 | if err != nil { 213 | panic(err) 214 | } 215 | if _, err := w.Write(bytes); err != nil { 216 | panic(err) 217 | } 218 | })) 219 | defer server.Close() 220 | 221 | remoteFetchStrategy := defaultRemoteFetchStrategy(server.URL+"/maven2", 222 | testVersionStrategy(), 223 | func() (s string, b bool) { 224 | return "/\\000", false 225 | }) 226 | 227 | err := remoteFetchStrategy() 
228 | 229 | assert.Regexp(t, "^unable to extract postgres archive:.+$", err) 230 | } 231 | 232 | func Test_defaultRemoteFetchStrategy_ErrorWhenSHA256NotMatch(t *testing.T) { 233 | jarFile, cleanUp := createTempZipArchive() 234 | defer cleanUp() 235 | 236 | cacheLocation := filepath.Join(filepath.Dir(jarFile), "extract_location", "cache.jar") 237 | 238 | server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 239 | bytes, err := os.ReadFile(jarFile) 240 | if err != nil { 241 | panic(err) 242 | } 243 | 244 | if strings.HasSuffix(r.RequestURI, ".sha256") { 245 | w.WriteHeader(200) 246 | if _, err := w.Write([]byte("literallyN3verGonnaWork")); err != nil { 247 | panic(err) 248 | } 249 | 250 | return 251 | } 252 | 253 | if _, err := w.Write(bytes); err != nil { 254 | panic(err) 255 | } 256 | })) 257 | defer server.Close() 258 | 259 | remoteFetchStrategy := defaultRemoteFetchStrategy(server.URL+"/maven2", 260 | testVersionStrategy(), 261 | func() (s string, b bool) { 262 | return cacheLocation, false 263 | }) 264 | 265 | err := remoteFetchStrategy() 266 | 267 | assert.EqualError(t, err, "downloaded checksums do not match") 268 | } 269 | 270 | func Test_defaultRemoteFetchStrategy(t *testing.T) { 271 | jarFile, cleanUp := createTempZipArchive() 272 | defer cleanUp() 273 | 274 | cacheLocation := filepath.Join(filepath.Dir(jarFile), "extract_location", "cache.jar") 275 | 276 | server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 277 | bytes, err := os.ReadFile(jarFile) 278 | if err != nil { 279 | panic(err) 280 | } 281 | 282 | if strings.HasSuffix(r.RequestURI, ".sha256") { 283 | w.WriteHeader(200) 284 | contentHash := sha256.Sum256(bytes) 285 | if _, err := w.Write([]byte(hex.EncodeToString(contentHash[:]))); err != nil { 286 | panic(err) 287 | } 288 | 289 | return 290 | } 291 | 292 | if _, err := w.Write(bytes); err != nil { 293 | panic(err) 294 | } 295 | })) 296 | defer server.Close() 297 | 298 | 
remoteFetchStrategy := defaultRemoteFetchStrategy(server.URL+"/maven2", 299 | testVersionStrategy(), 300 | func() (s string, b bool) { 301 | return cacheLocation, false 302 | }) 303 | 304 | err := remoteFetchStrategy() 305 | 306 | assert.NoError(t, err) 307 | assert.FileExists(t, cacheLocation) 308 | } 309 | 310 | func Test_defaultRemoteFetchStrategyWithExistingDownload(t *testing.T) { 311 | jarFile, cleanUp := createTempZipArchive() 312 | defer cleanUp() 313 | 314 | // create a temp directory for testing 315 | tempFile, err := os.MkdirTemp("", "cache_output") 316 | if err != nil { 317 | panic(err) 318 | } 319 | // clean up once the test is finished. 320 | defer func() { 321 | if err := os.RemoveAll(tempFile); err != nil { 322 | panic(err) 323 | } 324 | }() 325 | 326 | cacheLocation := path.Join(tempFile, "temp.jar") 327 | 328 | server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 329 | bytes, err := os.ReadFile(jarFile) 330 | if err != nil { 331 | panic(err) 332 | } 333 | 334 | if strings.HasSuffix(r.RequestURI, ".sha256") { 335 | w.WriteHeader(200) 336 | contentHash := sha256.Sum256(bytes) 337 | if _, err := w.Write([]byte(hex.EncodeToString(contentHash[:]))); err != nil { 338 | panic(err) 339 | } 340 | 341 | return 342 | } 343 | 344 | if _, err := w.Write(bytes); err != nil { 345 | panic(err) 346 | } 347 | })) 348 | defer server.Close() 349 | 350 | remoteFetchStrategy := defaultRemoteFetchStrategy(server.URL+"/maven2", 351 | testVersionStrategy(), 352 | func() (s string, b bool) { 353 | return cacheLocation, false 354 | }) 355 | 356 | // call it the remoteFetchStrategy(). 
The output location should be empty and a new file created 357 | err = remoteFetchStrategy() 358 | assert.NoError(t, err) 359 | assert.FileExists(t, cacheLocation) 360 | out1, err := os.ReadFile(cacheLocation) 361 | 362 | // write some bad data to the file, this helps us test that the file is overwritten 363 | err = os.WriteFile(cacheLocation, []byte("invalid"), 0600) 364 | assert.NoError(t, err) 365 | 366 | // call the remoteFetchStrategy() again, this time the file should be overwritten 367 | err = remoteFetchStrategy() 368 | assert.NoError(t, err) 369 | assert.FileExists(t, cacheLocation) 370 | 371 | // ensure that the file contents are the same from both downloads, and that it doesn't contain the `invalid` data. 372 | out2, err := os.ReadFile(cacheLocation) 373 | assert.Equal(t, out1, out2) 374 | } 375 | 376 | func Test_defaultRemoteFetchStrategy_whenContentLengthNotSet(t *testing.T) { 377 | jarFile, cleanUp := createTempZipArchive() 378 | defer cleanUp() 379 | 380 | cacheLocation := filepath.Join(filepath.Dir(jarFile), "extract_location", "cache.jar") 381 | 382 | bytes, err := os.ReadFile(jarFile) 383 | if err != nil { 384 | require.NoError(t, err) 385 | } 386 | contentHash := sha256.Sum256(bytes) 387 | 388 | server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 389 | if strings.HasSuffix(r.RequestURI, ".sha256") { 390 | w.WriteHeader(200) 391 | if _, err := w.Write([]byte(hex.EncodeToString(contentHash[:]))); err != nil { 392 | panic(err) 393 | } 394 | 395 | return 396 | } 397 | 398 | f, err := os.Open(jarFile) 399 | if err != nil { 400 | panic(err) 401 | } 402 | 403 | // stream the file back so that Go uses 404 | // chunked encoding and never sets Content-Length 405 | _, _ = io.Copy(w, f) 406 | })) 407 | defer server.Close() 408 | 409 | remoteFetchStrategy := defaultRemoteFetchStrategy(server.URL+"/maven2", 410 | testVersionStrategy(), 411 | func() (s string, b bool) { 412 | return cacheLocation, false 413 | }) 414 | 
415 | err = remoteFetchStrategy() 416 | 417 | assert.NoError(t, err) 418 | assert.FileExists(t, cacheLocation) 419 | } 420 | 421 | func Test_closeBody_NilResponse(t *testing.T) { 422 | assert.NotPanics(t, func() { 423 | closeBody(nil)() 424 | }) 425 | } 426 | -------------------------------------------------------------------------------- /embedded_postgres_test.go: -------------------------------------------------------------------------------- 1 | package embeddedpostgres 2 | 3 | import ( 4 | "database/sql" 5 | "errors" 6 | "fmt" 7 | "net" 8 | "os" 9 | "os/user" 10 | "path" 11 | "path/filepath" 12 | "strings" 13 | "sync" 14 | "testing" 15 | "time" 16 | 17 | "github.com/stretchr/testify/assert" 18 | "github.com/stretchr/testify/require" 19 | ) 20 | 21 | func Test_DefaultConfig(t *testing.T) { 22 | defer verifyLeak(t) 23 | 24 | database := NewDatabase() 25 | if err := database.Start(); err != nil { 26 | shutdownDBAndFail(t, err, database) 27 | } 28 | 29 | db, err := sql.Open("postgres", "host=localhost port=5432 user=postgres password=postgres dbname=postgres sslmode=disable") 30 | if err != nil { 31 | shutdownDBAndFail(t, err, database) 32 | } 33 | 34 | if err = db.Ping(); err != nil { 35 | shutdownDBAndFail(t, err, database) 36 | } 37 | 38 | if err := db.Close(); err != nil { 39 | shutdownDBAndFail(t, err, database) 40 | } 41 | 42 | if err := database.Stop(); err != nil { 43 | shutdownDBAndFail(t, err, database) 44 | } 45 | } 46 | 47 | func Test_ErrorWhenPortAlreadyTaken(t *testing.T) { 48 | listener, err := net.Listen("tcp", "localhost:9887") 49 | if err != nil { 50 | panic(err) 51 | } 52 | 53 | defer func() { 54 | if err := listener.Close(); err != nil { 55 | panic(err) 56 | } 57 | }() 58 | 59 | database := NewDatabase(DefaultConfig(). 
60 | Port(9887)) 61 | 62 | err = database.Start() 63 | 64 | assert.EqualError(t, err, "process already listening on port 9887") 65 | } 66 | 67 | func Test_ErrorWhenRemoteFetchError(t *testing.T) { 68 | database := NewDatabase() 69 | database.cacheLocator = func() (string, bool) { 70 | return "", false 71 | } 72 | database.remoteFetchStrategy = func() error { 73 | return errors.New("did not work") 74 | } 75 | 76 | err := database.Start() 77 | 78 | assert.EqualError(t, err, "did not work") 79 | } 80 | 81 | func Test_ErrorWhenUnableToUnArchiveFile_WrongFormat(t *testing.T) { 82 | jarFile, cleanUp := createTempZipArchive() 83 | defer cleanUp() 84 | 85 | database := NewDatabase(DefaultConfig(). 86 | Username("gin"). 87 | Password("wine"). 88 | Database("beer"). 89 | StartTimeout(10 * time.Second)) 90 | 91 | database.cacheLocator = func() (string, bool) { 92 | return jarFile, true 93 | } 94 | 95 | err := database.Start() 96 | 97 | if err == nil { 98 | if err := database.Stop(); err != nil { 99 | panic(err) 100 | } 101 | } 102 | 103 | assert.EqualError(t, err, fmt.Sprintf(`unable to extract postgres archive %s to %s, if running parallel tests, configure RuntimePath to isolate testing directories, xz: file format not recognized`, jarFile, filepath.Join(filepath.Dir(jarFile), "extracted"))) 104 | } 105 | 106 | func Test_ErrorWhenUnableToInitDatabase(t *testing.T) { 107 | jarFile, cleanUp := createTempXzArchive() 108 | defer cleanUp() 109 | 110 | extractPath, err := os.MkdirTemp(filepath.Dir(jarFile), "extract") 111 | if err != nil { 112 | panic(err) 113 | } 114 | 115 | database := NewDatabase(DefaultConfig(). 116 | Username("gin"). 117 | Password("wine"). 118 | Database("beer"). 119 | RuntimePath(extractPath). 
120 | StartTimeout(10 * time.Second)) 121 | 122 | database.cacheLocator = func() (string, bool) { 123 | return jarFile, true 124 | } 125 | 126 | database.initDatabase = func(binaryExtractLocation, runtimePath, dataLocation, username, password, locale string, encoding string, logger *os.File) error { 127 | return errors.New("ah it did not work") 128 | } 129 | 130 | err = database.Start() 131 | 132 | if err == nil { 133 | if err := database.Stop(); err != nil { 134 | panic(err) 135 | } 136 | } 137 | 138 | assert.EqualError(t, err, "ah it did not work") 139 | } 140 | 141 | func Test_ErrorWhenUnableToCreateDatabase(t *testing.T) { 142 | jarFile, cleanUp := createTempXzArchive() 143 | 144 | defer cleanUp() 145 | 146 | extractPath, err := os.MkdirTemp(filepath.Dir(jarFile), "extract") 147 | 148 | if err != nil { 149 | panic(err) 150 | } 151 | 152 | database := NewDatabase(DefaultConfig(). 153 | Username("gin"). 154 | Password("wine"). 155 | Database("beer"). 156 | RuntimePath(extractPath). 157 | StartTimeout(10 * time.Second)) 158 | 159 | database.createDatabase = func(port uint32, username, password, database string) error { 160 | return errors.New("ah noes") 161 | } 162 | 163 | err = database.Start() 164 | 165 | if err == nil { 166 | if err := database.Stop(); err != nil { 167 | panic(err) 168 | } 169 | } 170 | 171 | assert.EqualError(t, err, "ah noes") 172 | } 173 | 174 | func Test_TimesOutWhenCannotStart(t *testing.T) { 175 | database := NewDatabase(DefaultConfig(). 176 | Database("something-fancy"). 
177 | StartTimeout(500 * time.Millisecond)) 178 | 179 | database.createDatabase = func(port uint32, username, password, database string) error { 180 | return nil 181 | } 182 | 183 | err := database.Start() 184 | 185 | assert.EqualError(t, err, "timed out waiting for database to become available") 186 | } 187 | 188 | func Test_ErrorWhenStopCalledBeforeStart(t *testing.T) { 189 | database := NewDatabase() 190 | 191 | err := database.Stop() 192 | 193 | assert.ErrorIs(t, err, ErrServerNotStarted) 194 | } 195 | 196 | func Test_ErrorWhenStartCalledWhenAlreadyStarted(t *testing.T) { 197 | database := NewDatabase() 198 | 199 | defer func() { 200 | if err := database.Stop(); err != nil { 201 | t.Fatal(err) 202 | } 203 | }() 204 | 205 | err := database.Start() 206 | assert.NoError(t, err) 207 | 208 | err = database.Start() 209 | assert.ErrorIs(t, err, ErrServerAlreadyStarted) 210 | } 211 | 212 | func Test_ErrorWhenCannotStartPostgresProcess(t *testing.T) { 213 | jarFile, cleanUp := createTempXzArchive() 214 | 215 | defer cleanUp() 216 | 217 | extractPath, err := os.MkdirTemp(filepath.Dir(jarFile), "extract") 218 | if err != nil { 219 | panic(err) 220 | } 221 | 222 | database := NewDatabase(DefaultConfig(). 
223 | RuntimePath(extractPath)) 224 | 225 | database.cacheLocator = func() (string, bool) { 226 | return jarFile, true 227 | } 228 | 229 | database.initDatabase = func(binaryExtractLocation, runtimePath, dataLocation, username, password, locale string, encoding string, logger *os.File) error { 230 | _, _ = logger.Write([]byte("ah it did not work")) 231 | return nil 232 | } 233 | 234 | err = database.Start() 235 | 236 | assert.EqualError(t, err, fmt.Sprintf("could not start postgres using %s/bin/pg_ctl start -w -D %s/data -o -p 5432:\nah it did not work", extractPath, extractPath)) 237 | } 238 | 239 | func Test_CustomConfig(t *testing.T) { 240 | tempDir, err := os.MkdirTemp("", "embedded_postgres_test") 241 | if err != nil { 242 | panic(err) 243 | } 244 | 245 | defer func() { 246 | if err := os.RemoveAll(tempDir); err != nil { 247 | panic(err) 248 | } 249 | }() 250 | 251 | database := NewDatabase(DefaultConfig(). 252 | Username("gin"). 253 | Password("wine"). 254 | Database("beer"). 255 | Version(V15). 256 | RuntimePath(tempDir). 257 | Port(9876). 258 | StartTimeout(10 * time.Second). 259 | Locale("C"). 260 | Encoding("UTF8"). 
261 | Logger(nil)) 262 | 263 | if err := database.Start(); err != nil { 264 | shutdownDBAndFail(t, err, database) 265 | } 266 | 267 | db, err := sql.Open("postgres", "host=localhost port=9876 user=gin password=wine dbname=beer sslmode=disable") 268 | if err != nil { 269 | shutdownDBAndFail(t, err, database) 270 | } 271 | 272 | if err = db.Ping(); err != nil { 273 | shutdownDBAndFail(t, err, database) 274 | } 275 | 276 | if err := db.Close(); err != nil { 277 | shutdownDBAndFail(t, err, database) 278 | } 279 | 280 | if err := database.Stop(); err != nil { 281 | shutdownDBAndFail(t, err, database) 282 | } 283 | } 284 | 285 | func Test_CustomLog(t *testing.T) { 286 | tempDir, err := os.MkdirTemp("", "embedded_postgres_test") 287 | if err != nil { 288 | panic(err) 289 | } 290 | 291 | defer func() { 292 | if err := os.RemoveAll(tempDir); err != nil { 293 | panic(err) 294 | } 295 | }() 296 | 297 | logger := customLogger{} 298 | 299 | database := NewDatabase(DefaultConfig(). 300 | Logger(&logger)) 301 | 302 | if err := database.Start(); err != nil { 303 | shutdownDBAndFail(t, err, database) 304 | } 305 | 306 | db, err := sql.Open("postgres", "host=localhost port=5432 user=postgres password=postgres dbname=postgres sslmode=disable") 307 | if err != nil { 308 | shutdownDBAndFail(t, err, database) 309 | } 310 | 311 | if err = db.Ping(); err != nil { 312 | shutdownDBAndFail(t, err, database) 313 | } 314 | 315 | if err := db.Close(); err != nil { 316 | shutdownDBAndFail(t, err, database) 317 | } 318 | 319 | if err := database.Stop(); err != nil { 320 | shutdownDBAndFail(t, err, database) 321 | } 322 | 323 | current, err := user.Current() 324 | 325 | lines := strings.Split(string(logger.logLines), "\n") 326 | 327 | assert.NoError(t, err) 328 | assert.Contains(t, lines, fmt.Sprintf("The files belonging to this database system will be owned by user \"%s\".", current.Username)) 329 | assert.Contains(t, lines, "syncing data to disk ... 
ok") 330 | assert.Contains(t, lines, "server stopped") 331 | assert.Less(t, len(lines), 55) 332 | assert.Greater(t, len(lines), 40) 333 | }

// The tests below report failures through shutdownDBAndFail, a helper defined
// elsewhere in this package. Each call is followed by an explicit return so a
// test never goes on to use a database or connection whose setup just failed
// (for example calling Ping on a nil *sql.DB), whether or not the helper
// itself stops the test.

// Test_CustomLocaleConfig starts a database configured with the "C" locale
// and checks that a client can connect to it on localhost:5432 and ping it.
func Test_CustomLocaleConfig(t *testing.T) {
	// C is the only locale we can guarantee to always work
	database := NewDatabase(DefaultConfig().Locale("C"))
	if err := database.Start(); err != nil {
		shutdownDBAndFail(t, err, database)
		return
	}

	db, err := sql.Open("postgres", "host=localhost port=5432 user=postgres password=postgres dbname=postgres sslmode=disable")
	if err != nil {
		shutdownDBAndFail(t, err, database)
		return
	}

	if err = db.Ping(); err != nil {
		shutdownDBAndFail(t, err, database)
		return
	}

	if err := db.Close(); err != nil {
		shutdownDBAndFail(t, err, database)
		return
	}

	if err := database.Stop(); err != nil {
		shutdownDBAndFail(t, err, database)
	}
}

// Test_CustomEncodingConfig starts a database configured with the UTF8
// encoding and asserts that SHOW SERVER_ENCODING reports "UTF8".
func Test_CustomEncodingConfig(t *testing.T) {
	database := NewDatabase(DefaultConfig().Encoding("UTF8"))
	if err := database.Start(); err != nil {
		shutdownDBAndFail(t, err, database)
		return
	}

	db, err := sql.Open("postgres", "host=localhost port=5432 user=postgres password=postgres dbname=postgres sslmode=disable")
	if err != nil {
		shutdownDBAndFail(t, err, database)
		return
	}

	var serverEncoding string
	if err := db.QueryRow("SHOW SERVER_ENCODING;").Scan(&serverEncoding); err != nil {
		shutdownDBAndFail(t, err, database)
		return
	}
	assert.Equal(t, "UTF8", serverEncoding)

	if err := db.Close(); err != nil {
		shutdownDBAndFail(t, err, database)
		return
	}

	if err := database.Stop(); err != nil {
		shutdownDBAndFail(t, err, database)
	}
}

// Test_ConcurrentStart removes the cached binaries and then starts three
// databases at the same time (ports 5433-5435, each with its own temporary
// runtime path), so all three start-ups run without a cache concurrently.
// Each database must accept a connection and stop cleanly.
//
// Workers never call into t: each returns its error, and the test goroutine
// reports it after wg.Wait. testing.T's FailNow family may only be used from
// the test goroutine, and a panic in a worker would abort the entire test
// binary instead of failing this one test.
func Test_ConcurrentStart(t *testing.T) {
	database := NewDatabase()
	// Only the location is needed; RemoveAll succeeds when nothing is there.
	cacheLocation, _ := database.cacheLocator()
	require.NoError(t, os.RemoveAll(cacheLocation))

	const workers = 3
	ports := make([]uint32, workers)
	runtimePaths := make([]string, workers)
	for i := 0; i < workers; i++ {
		ports[i] = uint32(5433 + i)

		// Created here, on the test goroutine, so setup failures can use require.
		tempDir, err := os.MkdirTemp("", "embedded_postgres_test")
		require.NoError(t, err)
		runtimePaths[i] = tempDir
		t.Cleanup(func() { assert.NoError(t, os.RemoveAll(tempDir)) })
	}

	// startPingAndStop runs one database on port using runtimePath and returns
	// the first error encountered. Once the database has started it is always
	// stopped, even when connecting to it fails.
	startPingAndStop := func(port uint32, runtimePath string) error {
		instance := NewDatabase(DefaultConfig().
			RuntimePath(runtimePath).
			Port(port))
		if err := instance.Start(); err != nil {
			return fmt.Errorf("starting database on port %d: %w", port, err)
		}

		connErr := func() error {
			db, err := sql.Open(
				"postgres",
				fmt.Sprintf("host=localhost port=%d user=postgres password=postgres dbname=postgres sslmode=disable", port),
			)
			if err != nil {
				return fmt.Errorf("opening connection: %w", err)
			}
			if err := db.Ping(); err != nil {
				_ = db.Close() // the ping error is the one worth reporting
				return fmt.Errorf("pinging: %w", err)
			}
			if err := db.Close(); err != nil {
				return fmt.Errorf("closing connection: %w", err)
			}
			return nil
		}()

		if err := instance.Stop(); err != nil {
			if connErr != nil {
				return fmt.Errorf("port %d: %w (stopping database also failed: %v)", port, connErr, err)
			}
			return fmt.Errorf("stopping database on port %d: %w", port, err)
		}
		if connErr != nil {
			return fmt.Errorf("port %d: %w", port, connErr)
		}
		return nil
	}

	// Each worker writes only its own slot; wg.Wait orders those writes
	// before the reads below.
	errs := make([]error, workers)
	var wg sync.WaitGroup
	for i := 0; i < workers; i++ {
		wg.Add(1)
		go func(i int) {
			defer wg.Done()
			errs[i] = startPingAndStop(ports[i], runtimePaths[i])
		}(i)
	}
	wg.Wait()

	for _, workerErr := range errs {
		assert.NoError(t, workerErr)
	}
}

// Test_CustomStartParameters passes max_connections=101 as a start parameter
// and asserts that the running server reports that value.
func Test_CustomStartParameters(t *testing.T) {
	database := NewDatabase(DefaultConfig().StartParameters(map[string]string{"max_connections": "101"}))
	if err := database.Start(); err != nil {
		shutdownDBAndFail(t, err, database)
		return
	}

	db, err := sql.Open("postgres", "host=localhost port=5432 user=postgres password=postgres dbname=postgres sslmode=disable")
	if err != nil {
		shutdownDBAndFail(t, err, database)
		return
	}

	if err := db.Ping(); err != nil {
		shutdownDBAndFail(t, err, database)
		return
	}

	var res string
	if err := db.QueryRow("SHOW max_connections").Scan(&res); err != nil {
		shutdownDBAndFail(t, err, database)
		return
	}
	assert.Equal(t, "101", res)

	if err := db.Close(); err != nil {
		shutdownDBAndFail(t, err, database)
		return
	}

	if err := database.Stop(); err != nil {
		shutdownDBAndFail(t, err, database)
	}
}

// Test_CanStartAndStopTwice runs two complete start / connect / ping / stop
// cycles on the same database value to check that it can be restarted.
func Test_CanStartAndStopTwice(t *testing.T) {
	database := NewDatabase()

	for cycle := 1; cycle <= 2; cycle++ {
		if err := database.Start(); err != nil {
			shutdownDBAndFail(t, err, database)
			return
		}

		db, err := sql.Open("postgres", "host=localhost port=5432 user=postgres password=postgres dbname=postgres sslmode=disable")
		if err != nil {
			shutdownDBAndFail(t, err, database)
			return
		}

		if err = db.Ping(); err != nil {
			shutdownDBAndFail(t, err, database)
			return
		}

		if err := db.Close(); err != nil {
			shutdownDBAndFail(t, err, database)
			return
		}

		if err := database.Stop(); err != nil {
			shutdownDBAndFail(t, err, database)
			return
		}
	}
}

// Test_ReuseData writes a row into a database whose DataPath is a temporary
// directory, stops it, starts a fresh instance on the same DataPath and
// verifies that the row written by the first instance is still there.
func Test_ReuseData(t *testing.T) {
	tempDir, err := os.MkdirTemp("", "embedded_postgres_test")
	require.NoError(t, err)
	defer func() { assert.NoError(t, os.RemoveAll(tempDir)) }()

	// First run: create a table and insert one row.
	database := NewDatabase(DefaultConfig().DataPath(tempDir))

	if err := database.Start(); err != nil {
		shutdownDBAndFail(t, err, database)
		return
	}

	db, err := sql.Open("postgres", "host=localhost port=5432 user=postgres password=postgres dbname=postgres sslmode=disable")
	if err != nil {
		shutdownDBAndFail(t, err, database)
		return
	}

	if _, err = db.Exec("CREATE TABLE test(id serial, value text, PRIMARY KEY(id))"); err != nil {
		shutdownDBAndFail(t, err, database)
		return
	}

	if _, err = db.Exec("INSERT INTO test (value) VALUES ('foobar')"); err != nil {
		shutdownDBAndFail(t, err, database)
		return
	}

	if err := db.Close(); err != nil {
		shutdownDBAndFail(t, err, database)
		return
	}

	if err := database.Stop(); err != nil {
		shutdownDBAndFail(t, err, database)
		return
	}

	// Second run: a new instance on the same data directory must see the row.
	database = NewDatabase(DefaultConfig().DataPath(tempDir))

	if err := database.Start(); err != nil {
		shutdownDBAndFail(t, err, database)
		return
	}

	db, err = sql.Open("postgres", "host=localhost port=5432 user=postgres password=postgres dbname=postgres sslmode=disable")
	if err != nil {
		shutdownDBAndFail(t, err, database)
		return
	}

	rows, err := db.Query("SELECT * FROM test")
	if err != nil {
		shutdownDBAndFail(t, err, database)
		return
	}

	// Read the single expected row. rows is always closed so the pooled
	// connection it holds is released before db.Close and database.Stop.
	readErr := func() (result error) {
		defer func() {
			if closeErr := rows.Close(); result == nil {
				result = closeErr
			}
		}()

		if !rows.Next() {
			// Next returns false both at end-of-rows and on error; prefer the
			// real error when there is one.
			if nextErr := rows.Err(); nextErr != nil {
				return nextErr
			}
			return errors.New("no row from db")
		}

		var (
			id    int64
			value string
		)
		if scanErr := rows.Scan(&id, &value); scanErr != nil {
			return scanErr
		}
		if value != "foobar" {
			return errors.New("wrong value from db")
		}
		return nil
	}()
	if readErr != nil {
		shutdownDBAndFail(t, readErr, database)
		return
	}

	if err := db.Close(); err != nil {
		shutdownDBAndFail(t, err, database)
		return
	}

	if err := database.Stop(); err != nil {
		shutdownDBAndFail(t, err, database)
	}
}

// Test_CustomBinariesRepo starts a V15 database with custom credentials,
// database name, port 9876, a temporary runtime path and the binary
// repository URL https://repo.maven.apache.org/maven2, then checks that it
// accepts connections with those credentials.
func Test_CustomBinariesRepo(t *testing.T) {
	tempDir, err := os.MkdirTemp("", "embedded_postgres_test")
	require.NoError(t, err)
	defer func() { assert.NoError(t, os.RemoveAll(tempDir)) }()

	database := NewDatabase(DefaultConfig().
		Username("gin").
		Password("wine").
		Database("beer").
		Version(V15).
		RuntimePath(tempDir).
		BinaryRepositoryURL("https://repo.maven.apache.org/maven2").
		Port(9876).
		StartTimeout(10 * time.Second).
		Locale("C").
		Logger(nil))

	if err := database.Start(); err != nil {
		shutdownDBAndFail(t, err, database)
		return
	}

	db, err := sql.Open("postgres", "host=localhost port=9876 user=gin password=wine dbname=beer sslmode=disable")
	if err != nil {
		shutdownDBAndFail(t, err, database)
		return
	}

	if err = db.Ping(); err != nil {
		shutdownDBAndFail(t, err, database)
		return
	}

	if err := db.Close(); err != nil {
		shutdownDBAndFail(t, err, database)
		return
	}

	if err := database.Stop(); err != nil {
		shutdownDBAndFail(t, err, database)
	}
}

// Test_CachePath starts and stops a database whose CachePath points at a
// fresh temporary directory.
func Test_CachePath(t *testing.T) {
	cacheTempDir, err := os.MkdirTemp("", "prepare_database_test_cache")
	require.NoError(t, err)
	defer func() { assert.NoError(t, os.RemoveAll(cacheTempDir)) }()

	database := NewDatabase(DefaultConfig().
		CachePath(cacheTempDir))

	if err := database.Start(); err != nil {
		shutdownDBAndFail(t, err, database)
		return
	}

	if err := database.Stop(); err != nil {
		shutdownDBAndFail(t, err, database)
	}
}

// Test_CustomRuntimePathCreatedWhenNotPresent points the runtime, binaries
// and data paths at a nested directory that does not exist yet inside a
// fresh temporary directory, and checks that the database still starts and
// stops.
func Test_CustomRuntimePathCreatedWhenNotPresent(t *testing.T) {
	runtimeTempDir, err := os.MkdirTemp("", "non_existent_runtime_path")
	require.NoError(t, err)
	defer func() { assert.NoError(t, os.RemoveAll(runtimeTempDir)) }()

	// Only runtimeTempDir exists; everything below it must be created on Start.
	postgresDataPath := filepath.Join(runtimeTempDir,
		fmt.Sprintf(".embedded-postgres-go-%d", 4444),
		"extracted")

	database := NewDatabase(DefaultConfig().
		RuntimePath(postgresDataPath).
		BinariesPath(postgresDataPath).
		DataPath(filepath.Join(postgresDataPath, "data")).
		Database("hoh"))

	if err := database.Start(); err != nil {
		shutdownDBAndFail(t, err, database)
		return
	}

	if err := database.Stop(); err != nil {
		shutdownDBAndFail(t, err, database)
	}
}

// Test_CustomBinariesLocation starts and stops a database whose binaries go
// into a custom BinariesPath, then deletes the cached download and starts and
// stops the same database again.
func Test_CustomBinariesLocation(t *testing.T) {
	tempDir, err := os.MkdirTemp("", "prepare_database_test")
	require.NoError(t, err)
	defer func() { assert.NoError(t, os.RemoveAll(tempDir)) }()

	database := NewDatabase(DefaultConfig().
		BinariesPath(tempDir))

	if err := database.Start(); err != nil {
		shutdownDBAndFail(t, err, database)
		return
	}

	if err := database.Stop(); err != nil {
		shutdownDBAndFail(t, err, database)
		return
	}

	// Delete cache to make sure unarchive doesn't happen again.
	cacheLocation, _ := database.cacheLocator()
	require.NoError(t, os.RemoveAll(cacheLocation))

	if err := database.Start(); err != nil {
		shutdownDBAndFail(t, err, database)
		return
	}

	if err := database.Stop(); err != nil {
		shutdownDBAndFail(t, err, database)
	}
}

// Test_PrefetchedBinaries fetches the binaries and extracts them into
// BinariesPath up front, then replaces cacheLocator and remoteFetchStrategy
// with versions that cannot supply binaries, and checks that the database
// still starts and stops using only the prefetched files.
func Test_PrefetchedBinaries(t *testing.T) {
	binTempDir, err := os.MkdirTemp("", "prepare_database_test_bin")
	require.NoError(t, err)
	defer func() { assert.NoError(t, os.RemoveAll(binTempDir)) }()

	runtimeTempDir, err := os.MkdirTemp("", "prepare_database_test_runtime")
	require.NoError(t, err)
	defer func() { assert.NoError(t, os.RemoveAll(runtimeTempDir)) }()

	database := NewDatabase(DefaultConfig().
		BinariesPath(binTempDir).
		RuntimePath(runtimeTempDir))

	// Download and unarchive postgres into the bindir.
	require.NoError(t, database.remoteFetchStrategy())

	cacheLocation, _ := database.cacheLocator()
	require.NoError(t, decompressTarXz(defaultTarReader, cacheLocation, binTempDir))

	// Expect everything to work without cacheLocator and/or remoteFetch abilities.
	database.cacheLocator = func() (string, bool) {
		return "", false
	}
	database.remoteFetchStrategy = func() error {
		return errors.New("did not work")
	}

	if err := database.Start(); err != nil {
		shutdownDBAndFail(t, err, database)
		return
	}

	if err := database.Stop(); err != nil {
		shutdownDBAndFail(t, err, database)
	}
}

// Test_RunningInParallel starts two databases at once on ports 8765 and 8766,
// each under its own subdirectory of one shared temporary runtime path, and
// checks that both accept connections and stop cleanly.
//
// Workers return their errors instead of touching t; the test goroutine
// reports them once both workers have finished, because FailNow-style
// reporting must not happen off the test goroutine and a worker panic would
// abort the whole test binary.
func Test_RunningInParallel(t *testing.T) {
	tempPath, err := os.MkdirTemp("", "parallel_tests_path")
	require.NoError(t, err)
	// Both runtime paths live under tempPath, so removing it after the workers
	// are done cleans up after both databases.
	defer func() { assert.NoError(t, os.RemoveAll(tempPath)) }()

	// runTestWithPortAndPath runs one database on port using runtimePath and
	// returns the first error encountered. Once the database has started it is
	// always stopped, even when connecting to it fails.
	runTestWithPortAndPath := func(port uint32, runtimePath string) error {
		database := NewDatabase(DefaultConfig().Port(port).RuntimePath(runtimePath))
		if err := database.Start(); err != nil {
			return fmt.Errorf("starting database on port %d: %w", port, err)
		}

		connErr := func() error {
			db, err := sql.Open("postgres", fmt.Sprintf("host=localhost port=%d user=postgres password=postgres dbname=postgres sslmode=disable", port))
			if err != nil {
				return fmt.Errorf("opening connection: %w", err)
			}
			if err := db.Ping(); err != nil {
				_ = db.Close() // the ping error is the one worth reporting
				return fmt.Errorf("pinging: %w", err)
			}
			if err := db.Close(); err != nil {
				return fmt.Errorf("closing connection: %w", err)
			}
			return nil
		}()

		if err := database.Stop(); err != nil {
			if connErr != nil {
				return fmt.Errorf("port %d: %w (stopping database also failed: %v)", port, connErr, err)
			}
			return fmt.Errorf("stopping database on port %d: %w", port, err)
		}
		if connErr != nil {
			return fmt.Errorf("port %d: %w", port, connErr)
		}
		return nil
	}

	targets := []struct {
		port   uint32
		subdir string
	}{
		{port: 8765, subdir: "1"},
		{port: 8766, subdir: "2"},
	}

	// Each worker writes only its own slot; waitGroup.Wait orders those writes
	// before the reads below.
	errs := make([]error, len(targets))
	var waitGroup sync.WaitGroup
	for i, target := range targets {
		waitGroup.Add(1)
		go func(i int, port uint32, runtimePath string) {
			defer waitGroup.Done()
			errs[i] = runTestWithPortAndPath(port, runtimePath)
		}(i, target.port, path.Join(tempPath, target.subdir))
	}
	waitGroup.Wait()

	for _, workerErr := range errs {
		assert.NoError(t, workerErr)
	}
}

--------------------------------------------------------------------------------