├── gopher.png ├── platform-test ├── go.mod ├── go.sum └── platform_test.go ├── test_config.go ├── go.mod ├── examples ├── migrations │ └── 20191121120000_init.sql ├── go.mod ├── server_test.go ├── examples_test.go └── go.sum ├── rename_test.go ├── rename.go ├── cache_locator_test.go ├── CONTRIBUTING.md ├── LICENSE ├── cache_locator.go ├── .circleci └── config.yml ├── .golangci.yml ├── go.sum ├── logging.go ├── version_strategy.go ├── logging_test.go ├── test_util_test.go ├── .github └── workflows │ └── build.yml ├── .gitignore ├── decompression.go ├── version_strategy_test.go ├── prepare_database.go ├── remote_fetch.go ├── config.go ├── decompression_test.go ├── README.md ├── prepare_database_test.go ├── embedded_postgres.go ├── remote_fetch_test.go └── embedded_postgres_test.go /gopher.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fergusstrange/embedded-postgres/HEAD/gopher.png -------------------------------------------------------------------------------- /platform-test/go.mod: -------------------------------------------------------------------------------- 1 | module github.com/fergusstrange/embedded-postgres/platform-test 2 | 3 | replace github.com/fergusstrange/embedded-postgres => ../ 4 | 5 | go 1.18 6 | 7 | require github.com/fergusstrange/embedded-postgres v0.0.0 8 | 9 | require ( 10 | github.com/lib/pq v1.10.9 // indirect 11 | github.com/xi2/xz v0.0.0-20171230120015-48954b6210f8 // indirect 12 | ) 13 | -------------------------------------------------------------------------------- /test_config.go: -------------------------------------------------------------------------------- 1 | package embeddedpostgres 2 | 3 | import "testing" 4 | 5 | func TestGetConnectionURL(t *testing.T) { 6 | config := DefaultConfig().Database("mydb").Username("myuser").Password("mypass") 7 | expect := "postgresql://myuser:mypass@localhost:5432/mydb" 8 | 9 | if got := config.GetConnectionURL(); got != 
expect { 10 | t.Errorf("expected \"%s\" got \"%s\"", expect, got) 11 | } 12 | } 13 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/fergusstrange/embedded-postgres 2 | 3 | go 1.18 4 | 5 | require ( 6 | github.com/lib/pq v1.10.9 7 | github.com/stretchr/testify v1.10.0 8 | github.com/xi2/xz v0.0.0-20171230120015-48954b6210f8 9 | go.uber.org/goleak v1.3.0 10 | ) 11 | 12 | require ( 13 | github.com/davecgh/go-spew v1.1.1 // indirect 14 | github.com/kr/text v0.2.0 // indirect 15 | github.com/pmezard/go-difflib v1.0.0 // indirect 16 | gopkg.in/yaml.v3 v3.0.1 // indirect 17 | ) 18 | -------------------------------------------------------------------------------- /examples/migrations/20191121120000_init.sql: -------------------------------------------------------------------------------- 1 | -- +goose Up 2 | -- SQL in this section is executed when the migration is applied. 3 | CREATE TABLE beer_catalogue 4 | ( 5 | id SERIAL PRIMARY KEY, 6 | name TEXT, 7 | consumed BOOL DEFAULT TRUE, 8 | rating DOUBLE PRECISION 9 | ); 10 | 11 | INSERT INTO beer_catalogue (name, consumed, rating) 12 | VALUES ('Punk IPA', true, 68.29); 13 | 14 | -- +goose Down 15 | -- SQL in this section is executed when the migration is rolled back. 
-------------------------------------------------------------------------------- /rename_test.go: -------------------------------------------------------------------------------- 1 | package embeddedpostgres 2 | 3 | import ( 4 | "os" 5 | "testing" 6 | 7 | "github.com/stretchr/testify/assert" 8 | "github.com/stretchr/testify/require" 9 | ) 10 | 11 | func Test_renameOrIgnore_NoErrorOnEEXIST(t *testing.T) { 12 | // t.TempDir directories are removed automatically when the test finishes, so nothing is left behind in the 13 | // system temp directory. 14 | tmpDir := t.TempDir() 15 | 16 | tmpFil, err := os.CreateTemp(t.TempDir(), "test_file") 17 | require.NoError(t, err) 18 | // Close the handle before renaming; an open file can block the rename and the cleanup on Windows. 19 | require.NoError(t, tmpFil.Close()) 20 | 21 | // os.Rename would return an error here, ensure that the error is handled and returned as nil 22 | err = renameOrIgnore(tmpFil.Name(), tmpDir) 23 | assert.NoError(t, err) 24 | } 25 | -------------------------------------------------------------------------------- /examples/go.mod: -------------------------------------------------------------------------------- 1 | module github.com/fergusstrange/embedded-postgres/examples 2 | 3 | go 1.18 4 | 5 | replace github.com/fergusstrange/embedded-postgres => ../ 6 | 7 | require ( 8 | github.com/fergusstrange/embedded-postgres v0.0.0 9 | github.com/jmoiron/sqlx v1.3.5 10 | github.com/lib/pq v1.10.9 11 | github.com/pressly/goose/v3 v3.0.1 12 | go.uber.org/zap v1.21.0 13 | ) 14 | 15 | require ( 16 | github.com/pkg/errors v0.9.1 // indirect 17 | github.com/xi2/xz v0.0.0-20171230120015-48954b6210f8 // indirect 18 | go.uber.org/atomic v1.7.0 // indirect 19 | go.uber.org/multierr v1.6.0 // indirect 20 | ) 21 | -------------------------------------------------------------------------------- /platform-test/go.sum: -------------------------------------------------------------------------------- 1 | github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= 2 | github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= 3 | github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= 4 |
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= 5 | github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= 6 | github.com/xi2/xz v0.0.0-20171230120015-48954b6210f8 h1:nIPpBwaJSVYIxUFsDv3M8ofmx9yWTog9BfvIu0q41lo= 7 | github.com/xi2/xz v0.0.0-20171230120015-48954b6210f8/go.mod h1:HUYIGzjTL3rfEspMxjDjgmT5uz5wzYJKVo23qUhYTos= 8 | go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= 9 | gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= 10 | -------------------------------------------------------------------------------- /rename.go: -------------------------------------------------------------------------------- 1 | package embeddedpostgres 2 | 3 | import ( 4 | "errors" 5 | "os" 6 | "syscall" 7 | ) 8 | 9 | // renameOrIgnore will rename the oldpath to the newpath. 10 | // 11 | // On Unix this will be a safe atomic operation. 12 | // On Windows this will do nothing if the new path already exists. 13 | // 14 | // This is only safe to use if you can be sure that the newpath is either missing, or contains the same data as the 15 | // old path. 16 | func renameOrIgnore(oldpath, newpath string) error { 17 | err := os.Rename(oldpath, newpath) 18 | 19 | // if the error is due to syscall.EEXIST then this is most likely windows, and a race condition with 20 | // multiple downloads of the file.
We can assume that the existing file is the correct one and ignore 21 | // the error. Note that on Unix os.Rename also reports EEXIST when newpath is an existing directory, so that case is ignored as well. 22 | if errors.Is(err, syscall.EEXIST) { 23 | return nil 24 | } 25 | 26 | return err 27 | } 28 | -------------------------------------------------------------------------------- /cache_locator_test.go: -------------------------------------------------------------------------------- 1 | package embeddedpostgres 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/stretchr/testify/assert" 7 | ) 8 | 9 | func Test_defaultCacheLocator_NotExists(t *testing.T) { 10 | locator := defaultCacheLocator("", func() (string, string, PostgresVersion) { 11 | return "a", "b", "1.2.3" 12 | }) 13 | 14 | cacheLocation, exists := locator() 15 | // No cache directory is configured, so the default .embedded-postgres-go location is used; nothing is cached for this fake version. 16 | assert.Contains(t, cacheLocation, ".embedded-postgres-go/embedded-postgres-binaries-a-b-1.2.3.txz") 17 | assert.False(t, exists) 18 | } 19 | 20 | func Test_defaultCacheLocator_CustomPath(t *testing.T) { 21 | locator := defaultCacheLocator("/custom/path", func() (string, string, PostgresVersion) { 22 | return "a", "b", "1.2.3" 23 | }) 24 | 25 | cacheLocation, exists := locator() 26 | // A custom cache directory is used verbatim in place of the default location. 27 | assert.Equal(t, cacheLocation, "/custom/path/embedded-postgres-binaries-a-b-1.2.3.txz") 28 | assert.False(t, exists) 29 | } 30 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to embedded-postgres 2 | 3 | Thank you for taking the time to contribute. These are mostly guidelines, not rules. Use your best judgment, and feel 4 | free to propose changes to this document in a pull request. 5 | 6 | # Working with forked go repos 7 | 8 | If you haven't worked with forked go repos before, take a look at this blog post for some excellent advice 9 | about [contributing to go open source git repositories](https://splice.com/blog/contributing-open-source-git-repositories-go/) 10 | .
11 | 12 | # PRs 13 | 14 | - Please open PRs against master. 15 | - We prefer single-commit PRs, but sometimes multiple commits are justified - use your best judgment. 16 | - Please add/modify tests to cover the proposed code changes. 17 | - If the PR contains a new feature, please document it in the README. 18 | 19 | # Documentation 20 | 21 | For simple typo fixes and documentation improvements, feel free to raise a PR without first opening an issue on GitHub. For 22 | anything more complicated, please file an issue. -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Fergus Strange 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE.
22 | -------------------------------------------------------------------------------- /cache_locator.go: -------------------------------------------------------------------------------- 1 | package embeddedpostgres 2 | 3 | import ( 4 | "fmt" 5 | "os" 6 | "path/filepath" 7 | ) 8 | 9 | // CacheLocator retrieves the location of the Postgres binary cache returning it to location. 10 | // The result of whether this cache is present will be returned to exists. 11 | type CacheLocator func() (location string, exists bool) 12 | 13 | // defaultCacheLocator returns a CacheLocator for the archive <dir>/embedded-postgres-binaries-<os>-<arch>-<version>.txz, 14 | // where <dir> is cacheDirectory or, when that is empty, ~/.embedded-postgres-go (a relative .embedded-postgres-go if 15 | // the home directory cannot be determined). The archive counts as present only if it can be stat'ed and is not a 16 | // directory. 17 | func defaultCacheLocator(cacheDirectory string, versionStrategy VersionStrategy) CacheLocator { 18 | return func() (string, bool) { 19 | // Resolve into a local so calling the locator never mutates the captured parameter. 20 | cacheDir := cacheDirectory 21 | if cacheDir == "" { 22 | cacheDir = ".embedded-postgres-go" 23 | if userHome, err := os.UserHomeDir(); err == nil { 24 | cacheDir = filepath.Join(userHome, ".embedded-postgres-go") 25 | } 26 | } 27 | 28 | operatingSystem, architecture, version := versionStrategy() 29 | cacheLocation := filepath.Join(cacheDir, 30 | fmt.Sprintf("embedded-postgres-binaries-%s-%s-%s.txz", 31 | operatingSystem, 32 | architecture, 33 | version)) 34 | 35 | info, err := os.Stat(cacheLocation) 36 | if err != nil { 37 | // info is nil whenever Stat fails, so it must not be dereferenced here; any error (missing file, 38 | // permission denied, ...) is reported as "not cached". 39 | return cacheLocation, false 40 | } 41 | 42 | return cacheLocation, !info.IsDir() 43 | } 44 | } 45 | -------------------------------------------------------------------------------- /.circleci/config.yml: -------------------------------------------------------------------------------- 1 | version: 2.1 2 | executors: 3 | linux-arm64: 4 | machine: 5 | image: ubuntu-2204:2024.01.2 6 | resource_class: arm.medium 7 | working_directory: /home/circleci/go/src/github.com/fergusstrange/embedded-postgres 8 | apple-m4: &macos-executor 9 | resource_class: m4pro.medium 10 | macos: 11 | xcode: "26.1.0" 12 | orbs: 13 | go: circleci/go@3.0.3 14 | jobs: 15 | platform_test: 16 | parameters: 17 | executor: 18 | type: executor 19 | executor: << parameters.executor >> 20 | steps: 21 | -
checkout 22 | - when: 23 | condition: 24 | equal: [ *macos-executor, << parameters.executor >> ] 25 | steps: 26 | - go/install: 27 | version: "1.18" 28 | - run: /usr/sbin/softwareupdate --install-rosetta --agree-to-license 29 | - go/load-mod-cache 30 | - go/mod-download 31 | - go/save-mod-cache 32 | - run: cd platform-test && go mod download && go test -v -race ./... 33 | 34 | workflows: 35 | version: 2 36 | test: 37 | jobs: 38 | - platform_test: 39 | matrix: 40 | parameters: 41 | executor: 42 | - linux-arm64 43 | - apple-m4 44 | -------------------------------------------------------------------------------- /.golangci.yml: -------------------------------------------------------------------------------- 1 | linters: 2 | disable-all: true 3 | enable: 4 | - deadcode 5 | - errcheck 6 | - gosimple 7 | - govet 8 | - ineffassign 9 | - staticcheck 10 | - structcheck 11 | - typecheck 12 | - unused 13 | - varcheck 14 | - bodyclose 15 | - depguard 16 | - dogsled 17 | - dupl 18 | - funlen 19 | - gochecknoglobals 20 | - gochecknoinits 21 | - gocognit 22 | - goconst 23 | - gocritic 24 | - gocyclo 25 | - godox 26 | - gofmt 27 | - goimports 28 | - revive 29 | - misspell 30 | - nakedret 31 | - prealloc 32 | - exportloopref 33 | - stylecheck 34 | - unconvert 35 | - unparam 36 | - whitespace 37 | - wsl 38 | issues: 39 | exclude-use-default: false 40 | exclude: 41 | - Error return value of .((os\.)?std(out|err)\..*|.*Close|.*Flush|os\.Remove(All)?|.*printf?|os\.(Un)?Setenv). is not checked 42 | - ST1000 43 | - func name will be used as test\.Test.* by other packages, and that stutters; consider calling this 44 | - (possible misuse of unsafe.Pointer|should have signature) 45 | - ineffective break statement. 
Did you mean to break out of the outer loop 46 | - Use of unsafe calls should be audited 47 | - Subprocess launch(ed with variable|ing should be audited) 48 | - G104 49 | - (Expect directory permissions to be 0750 or less|Expect file permissions to be 0600 or less) 50 | - Potential file inclusion via variable 51 | exclude-rules: 52 | - path: '(.+)_test\.go' 53 | linters: 54 | - funlen 55 | - goconst 56 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= 2 | github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= 3 | github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 4 | github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= 5 | github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= 6 | github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= 7 | github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= 8 | github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= 9 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= 10 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 11 | github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= 12 | github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= 13 | github.com/xi2/xz v0.0.0-20171230120015-48954b6210f8 h1:nIPpBwaJSVYIxUFsDv3M8ofmx9yWTog9BfvIu0q41lo= 14 | github.com/xi2/xz v0.0.0-20171230120015-48954b6210f8/go.mod h1:HUYIGzjTL3rfEspMxjDjgmT5uz5wzYJKVo23qUhYTos= 15 | go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= 16 | go.uber.org/goleak v1.3.0/go.mod 
h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE= 17 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= 18 | gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY= 19 | gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= 20 | gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 21 | -------------------------------------------------------------------------------- /examples/server_test.go: -------------------------------------------------------------------------------- 1 | package examples 2 | 3 | import ( 4 | "encoding/json" 5 | "log" 6 | "net/http" 7 | 8 | "github.com/jmoiron/sqlx" 9 | "github.com/pressly/goose/v3" 10 | ) 11 | 12 | type BeerCatalogue struct { 13 | ID int64 `json:"id"` 14 | Name string `json:"name"` 15 | Consumed bool `json:"consumed"` 16 | Rating float64 `json:"rating"` 17 | } 18 | 19 | type App struct { 20 | router *http.ServeMux 21 | } 22 | 23 | func (a *App) Start() error { 24 | return http.ListenAndServe("localhost:8080", a.router) 25 | } 26 | 27 | func NewApp() *App { 28 | db, err := sqlx.Connect("postgres", "host=localhost port=5432 user=postgres password=postgres dbname=postgres sslmode=disable") 29 | if err != nil { 30 | log.Fatal(err) 31 | } 32 | 33 | if err := goose.Up(db.DB, "./migrations"); err != nil { 34 | log.Fatal(err) 35 | } 36 | 37 | router := http.NewServeMux() 38 | router.HandleFunc("/beer-catalogue", GetBeer(db)) 39 | 40 | return &App{router: router} 41 | } 42 | 43 | // GetBeer returns a handler that looks up beer_catalogue rows whose name matches the "name" query parameter case-insensitively and writes them as a JSON array. 44 | // A missing name yields 400 Bad Request; a query or encoding failure yields 500 Internal Server Error. 45 | func GetBeer(db *sqlx.DB) func(w http.ResponseWriter, r *http.Request) { 46 | return func(w http.ResponseWriter, r *http.Request) { 47 | beerName := r.URL.Query().Get("name") 48 | if beerName == "" { 49 | w.WriteHeader(http.StatusBadRequest) 50 | return 51 | }
52 | 53 | beers := make([]BeerCatalogue, 0) 54 | if err := db.Select(&beers, "SELECT * FROM beer_catalogue WHERE UPPER(name) = UPPER($1)", beerName); err != nil { 55 | w.WriteHeader(http.StatusInternalServerError) 56 | return 57 | } 58 | 59 | jsonPayload, err := json.Marshal(beers) 60 | if err != nil { 61 | w.WriteHeader(http.StatusInternalServerError) 62 | return 63 | } 64 | 65 | // Write sends an implicit 200 status before the body, so a failure here can no longer be reported with a different status code. 66 | _, _ = w.Write(jsonPayload) 67 | } 68 | } 69 | -------------------------------------------------------------------------------- /logging.go: -------------------------------------------------------------------------------- 1 | package embeddedpostgres 2 | 3 | import ( 4 | "fmt" 5 | "io" 6 | "io/ioutil" 7 | "os" 8 | "time" 9 | ) 10 | 11 | // syncedLogger forwards the contents of a temporary log file to logger, 12 | // remembering via offset how many bytes have already been forwarded. 13 | type syncedLogger struct { 14 | offset int64 15 | logger io.Writer 16 | file *os.File 17 | } 18 | 19 | // newSyncedLogger creates the temporary log file in dir (the default temp 20 | // directory when dir is empty) and returns a syncedLogger writing to logger. 21 | func newSyncedLogger(dir string, logger io.Writer) (*syncedLogger, error) { 22 | file, err := os.CreateTemp(dir, "embedded_postgres_log") 23 | if err != nil { 24 | return nil, err 25 | } 26 | 27 | s := syncedLogger{ 28 | logger: logger, 29 | file: file, 30 | } 31 | 32 | return &s, nil 33 | } 34 | 35 | // flush copies log output written since the previous flush from the temporary 36 | // file to s.logger and advances s.offset. It does nothing when s.logger is nil. 37 | func (s *syncedLogger) flush() error { 38 | if s.logger == nil { 39 | return nil 40 | } 41 | 42 | file, err := os.Open(s.file.Name()) 43 | if err != nil { 44 | return fmt.Errorf("unable to process postgres logs: %w", err) 45 | } 46 | 47 | // The file is only read, so a Close error cannot lose data; ignore it 48 | // instead of panicking inside library code. 49 | defer func() { _ = file.Close() }() 50 | 51 | if _, err = file.Seek(s.offset, io.SeekStart); err != nil { 52 | return fmt.Errorf("unable to process postgres logs: %w", err) 53 | } 54 | 55 | readBytes, err := io.Copy(s.logger, file) 56 | // Count whatever reached the logger even when the copy fails, so the 57 | // next flush does not forward the same bytes twice. 58 | s.offset += readBytes 59 | 60 | if err != nil { 61 | return fmt.Errorf("unable to process postgres logs: %w", err) 62 | } 63 | 64 | return nil 65 | } 66 | 67 | // readLogsOrTimeout returns the full contents of the file named by logger. If the 68 | // file cannot be read, or reading takes longer than 10 seconds, it returns the 69 | // placeholder "logs could not be read" together with the error. 70 | func readLogsOrTimeout(logger *os.File) (logContent []byte, err error) { 71 | logContent = []byte("logs could not be read") 72 | 73 | logContentChan := make(chan []byte, 1) 74 | errChan := make(chan error, 1) 75 | 76 | // Both channels are buffered, so the reader goroutine never blocks on its 77 | // send, even if the select below has already timed out. 78 | go func() { 79 | if actualLogContent, err := ioutil.ReadFile(logger.Name()); err == nil { 80 | logContentChan <- actualLogContent 81 | } else { 82 | errChan <- err 83 | } 84 | }() 85 | 86 | select { 87 | case logContent = <-logContentChan: 88 | case err = <-errChan: 89 | case <-time.After(10 * time.Second): 90 | err = fmt.Errorf("timed out waiting for logs") 91 | } 92 | 93 | return logContent, err 94 | } 95 | -------------------------------------------------------------------------------- /version_strategy.go: --------------------------------------------------------------------------------
1 | package embeddedpostgres 2 | 3 | import ( 4 | "fmt" 5 | "os" 6 | "os/exec" 7 | "strings" 8 | ) 9 | 10 | // VersionStrategy provides a strategy that can be used to determine which version of Postgres should be used based on 11 | // the operating system, architecture and desired Postgres version. 12 | type VersionStrategy func() (operatingSystem string, architecture string, postgresVersion PostgresVersion) 13 | 14 | // defaultVersionStrategy returns a VersionStrategy that rewrites goos/arch into the platform names used by the 15 | // published Postgres binaries (see the rules below) and always reports config.version as the Postgres version. 16 | func defaultVersionStrategy(config Config, goos, arch string, linuxMachineName func() string, shouldUseAlpineLinuxBuild func() bool) VersionStrategy { 17 | return func() (string, string, PostgresVersion) { 18 | goos := goos 19 | arch := arch 20 | 21 | if goos == "linux" { 22 | // the zonkyio/embedded-postgres-binaries project produces 23 | // arm binaries with the following name schema: 24 | // 32bit: arm32v6 / arm32v7 25 | // 64bit (aarch64): arm64v8 26 | if arch == "arm64" { 27 | arch += "v8" 28 | } else if arch == "arm" { 29 | machineName := linuxMachineName() 30 | if strings.HasPrefix(machineName, "armv7") { 31 | arch += "32v7" 32 | } else if strings.HasPrefix(machineName, "armv6") { 33 | arch += "32v6" 34 | } 35 | } 36 | 37 | if shouldUseAlpineLinuxBuild() { 38 | arch += "-alpine" 39 | } 40 | } 41 | 42 | // postgres below version 14.2 is not available for macos on arm 43 | if goos == "darwin" && arch == "arm64" { 44 | var majorVer, minorVer int 45 | if _, err := fmt.Sscanf(string(config.version), "%d.%d", &majorVer, &minorVer); err == nil && 46 | (majorVer < 14 || (majorVer == 14 && minorVer < 2)) { 47 | arch = "amd64" 48 | } else { 49 | arch += "v8" 50 | } 51 | } 52 | 53 | return goos, arch, config.version 54 | } 55 | } 56 | 57 | // linuxMachineName returns the output of `uname -m` exactly as printed (it is not trimmed), or an empty string if 58 | // the command fails. 59 | func linuxMachineName() string { 60 | var uname string 61 | 62 | if output, err := exec.Command("uname", "-m").Output(); err == nil { 63 | uname = string(output) 64 | } 65 | 66 | return uname 67 | } 68 | 69 | // shouldUseAlpineLinuxBuild reports whether /etc/alpine-release exists, which is taken as the sign of an Alpine Linux host. 70 | func shouldUseAlpineLinuxBuild() bool { 71 | _, err := os.Stat("/etc/alpine-release") 72 | return err == nil 73 | } 74 | -------------------------------------------------------------------------------- /logging_test.go: --------------------------------------------------------------------------------
1 | package embeddedpostgres 2 | 3 | import ( 4 | "fmt" 5 | "io/ioutil" 6 | "os" 7 | "testing" 8 | 9 | "github.com/stretchr/testify/assert" 10 | "github.com/stretchr/testify/require" 11 | ) 12 | 13 | type customLogger struct { 14 | logLines []byte 15 | } 16 | 17 | func (cl *customLogger) Write(p []byte) (n int, err error) { 18 | cl.logLines = append(cl.logLines, p...)
19 | return len(p), nil 20 | } 21 | 22 | func Test_SyncedLogger_CreateError(t *testing.T) { 23 | logger := customLogger{} 24 | _, err := newSyncedLogger("/not-exists-anywhere", &logger) 25 | 26 | assert.Error(t, err) 27 | } 28 | 29 | func Test_SyncedLogger_ErrorDuringFlush(t *testing.T) { 30 | logger := customLogger{} 31 | 32 | sl, slErr := newSyncedLogger("", &logger) 33 | 34 | assert.NoError(t, slErr) 35 | 36 | rmFileErr := os.Remove(sl.file.Name()) 37 | 38 | assert.NoError(t, rmFileErr) 39 | 40 | err := sl.flush() 41 | 42 | assert.Error(t, err) 43 | } 44 | 45 | func Test_SyncedLogger_NoErrorDuringFlush(t *testing.T) { 46 | logger := customLogger{} 47 | 48 | sl, slErr := newSyncedLogger("", &logger) 49 | 50 | assert.NoError(t, slErr) 51 | 52 | err := os.WriteFile(sl.file.Name(), []byte("some logs\non a new line"), os.ModeAppend) 53 | 54 | assert.NoError(t, err) 55 | 56 | err = sl.flush() 57 | 58 | assert.NoError(t, err) 59 | 60 | assert.Equal(t, "some logs\non a new line", string(logger.logLines)) 61 | } 62 | 63 | func Test_readLogsOrTimeout(t *testing.T) { 64 | logFile, err := ioutil.TempFile("", "prepare_database_test_log") 65 | if err != nil { 66 | panic(err) 67 | } 68 | 69 | logContent, err := readLogsOrTimeout(logFile) 70 | assert.NoError(t, err) 71 | assert.Equal(t, []byte(""), logContent) 72 | 73 | _, _ = logFile.Write([]byte("and here are the logs!")) 74 | 75 | logContent, err = readLogsOrTimeout(logFile) 76 | assert.NoError(t, err) 77 | assert.Equal(t, []byte("and here are the logs!"), logContent) 78 | 79 | require.NoError(t, os.Remove(logFile.Name())) 80 | logContent, err = readLogsOrTimeout(logFile) 81 | assert.Equal(t, []byte("logs could not be read"), logContent) 82 | assert.EqualError(t, err, fmt.Sprintf("open %s: no such file or directory", logFile.Name())) 83 | } 84 | -------------------------------------------------------------------------------- /test_util_test.go: --------------------------------------------------------------------------------
1 | package embeddedpostgres 2 | 3 | import ( 4 | "encoding/base64" 5 | "os" 6 | "testing" 7 | 8 | "go.uber.org/goleak" 9 | ) 10 | 11 | func createTempXzArchive() (string, func()) { 12 | return writeFileWithBase64Content("remote_fetch_test*.txz", "/Td6WFoAAATm1rRGAgAhARYAAAB0L+Wj4Av/AKZdADIaSqdFdWDG5Dyin7tszujmfm9YJn6/1REVUfqW8HwXvgwbrrcDDc4Q2ql+L+ybLTxJ+QNhhaKnawviRjKhUOT3syXi2Ye8k4QMkeurnnCu4a8eoCV+hqNFWkk8/w8MzyMzQZ2D3wtvoaZV/KqJ8jyLbNVj+vsKrzqg5vbSGz5/h7F37nqN1V8ZsdCnKnDMZPzovM8RwtelDd0g3fPC0dG/W9PH4wAAAAC2dqs1k9ZA0QABwgGAGAAAIQZ5XbHEZ/sCAAAAAARZWg==") 13 | } 14 | 15 | func createTempZipArchive() (string, func()) { 16 | return writeFileWithBase64Content("remote_fetch_test*.zip", "UEsDBBQACAAIAExBSlMAAAAAAAAAAAAAAAAaAAkAcmVtb3RlX2ZldGNoX3Rlc3Q4MDA0NjE5MDVVVAUAAfCfYmEBAAD//1BLBwgAAAAABQAAAAAAAABQSwMEFAAIAAAATEFKUwAAAAAAAAAAAAAAABUACQByZW1vdGVfZmV0Y2hfdGVzdC50eHpVVAUAAfCfYmH9N3pYWgAABObWtEYCACEBFgAAAHQv5aPgBf8Abl0AORlJ/tq+A8rMBye1kCuXLnw2aeeO0gdfXeVHCWpF8/VeZU/MTVkdLzI+XgKLEMlHJukIdxP7iSAuKts+v7aDrJu68RHNgIsXGrGouAjf780FXjTUjX4vXDh08vNY1yOBayt9z9dKHdoG9AeAIgAAAAAOKMpgA1Mm3wABigGADAAAjIVdpbHEZ/sCAAAAAARZWlBLBwhkmQgRsAAAALAAAABQSwECFAMUAAgACABMQUpTAAAAAAUAAAAAAAAAGgAJAAAAAAAAAAAAgIEAAAAAcmVtb3RlX2ZldGNoX3Rlc3Q4MDA0NjE5MDVVVAUAAfCfYmFQSwECFAMUAAgAAABMQUpTZJkIEbAAAACwAAAAFQAJAAAAAAAAAAAApIFWAAAAcmVtb3RlX2ZldGNoX3Rlc3QudHh6VVQFAAHwn2JhUEsFBgAAAAACAAIAnQAAAFIBAAAAAA==") 17 | } 18 | 19 | // writeFileWithBase64Content decodes base64Content into a new temporary file named after the os.CreateTemp pattern 20 | // filename, panicking on any error, and returns the file's path plus a cleanup func that removes it. 21 | func writeFileWithBase64Content(filename, base64Content string) (string, func()) { 22 | tempFile, err := os.CreateTemp("", filename) 23 | if err != nil { 24 | panic(err) 25 | } 26 | 27 | // The content is written by path below, so close this handle now instead of leaking it (an open handle can 28 | // also prevent the cleanup func from removing the file on Windows). 29 | if err := tempFile.Close(); err != nil { 30 | panic(err) 31 | } 32 | 33 | byteContent, err := base64.StdEncoding.DecodeString(base64Content) 34 | if err != nil { 35 | panic(err) 36 | } 37 | 38 | if err := os.WriteFile(tempFile.Name(), byteContent, 0666); err != nil { 39 | panic(err) 40 | } 41 | 42 | return tempFile.Name(), func() { 43 | if err := os.RemoveAll(tempFile.Name()); err != nil { 44 | panic(err) 45 | } 46 | } 47 | } 48 | 49 | func shutdownDBAndFail(t *testing.T, err
error, db *EmbeddedPostgres) { 50 | if db.started { 51 | if stopErr := db.Stop(); stopErr != nil { 52 | t.Errorf("Failed to shutdown server with error %s", stopErr) 53 | } 54 | } 55 | 56 | t.Errorf("Failed for version %s with error %s", db.config.version, err) 57 | } 58 | 59 | func testVersionStrategy() VersionStrategy { 60 | return func() (string, string, PostgresVersion) { 61 | return "darwin", "amd64", "1.2.3" 62 | } 63 | } 64 | 65 | func testCacheLocator() CacheLocator { 66 | return func() (s string, b bool) { 67 | return "", false 68 | } 69 | } 70 | 71 | func verifyLeak(t *testing.T) { 72 | // Ideally, there should be no exceptions here. 73 | goleak.VerifyNone(t, goleak.IgnoreTopFunction("internal/poll.runtime_pollWait")) 74 | } 75 | -------------------------------------------------------------------------------- /.github/workflows/build.yml: -------------------------------------------------------------------------------- 1 | name: Embedded Postgres 2 | on: 3 | push: 4 | branches: [ "master" ] 5 | pull_request: 6 | branches: [ "*" ] 7 | jobs: 8 | tests: 9 | name: Tests 10 | runs-on: ubuntu-latest 11 | steps: 12 | - name: Checkout 13 | id: go 14 | uses: actions/checkout@v4 15 | - name: Set Up Golang 16 | uses: actions/setup-go@v5 17 | with: 18 | go-version: 1.22 19 | - name: Check Dependencies 20 | run: | 21 | go list -json -deps > go.list 22 | for d in "." "examples" "platform-test"; do 23 | pushd $d 24 | go mod tidy 25 | if [ ! -z "$(git status --porcelain go.mod)" ]; then 26 | printf "go.mod has modifications\n" 27 | git diff go.mod 28 | exit 1 29 | fi 30 | if [ !
-z "$(git status --porcelain go.sum)" ]; then 31 | printf "go.sum has modifications\n" 32 | git diff go.sum 33 | exit 1 34 | fi 35 | popd 36 | done; 37 | - name: Nancy Vulnerability 38 | uses: sonatype-nexus-community/nancy-github-action@main 39 | with: 40 | nancyVersion: v1.0.52 41 | nancyCommand: sleuth 42 | env: 43 | OSSI_TOKEN: ${{ secrets.OSSI_TOKEN }} 44 | OSSI_USERNAME: ${{ secrets.OSSI_USERNAME }} 45 | - name: GolangCI Lint 46 | run: | 47 | curl -sSfL https://raw.githubusercontent.com/golangci/golangci-lint/master/install.sh | sh -s -- -b $(go env GOPATH)/bin v1.42.1 48 | /home/runner/go/bin/golangci-lint run 49 | - name: Test 50 | run: go test -v -test.timeout 0 -race -cover -covermode=atomic -coverprofile=coverage.out ./... 51 | - name: Test Examples 52 | run: | 53 | pushd examples && \ 54 | go test -v ./... && \ 55 | popd 56 | - name: Upload Coverage Report 57 | env: 58 | COVERALLS_TOKEN: ${{ secrets.GITHUB_TOKEN }} 59 | run: go install github.com/mattn/goveralls@latest && $(go env GOPATH)/bin/goveralls -v -coverprofile=coverage.out -service=github 60 | alpine_tests: 61 | name: Alpine Linux Platform Tests 62 | runs-on: ubuntu-latest 63 | container: 64 | image: golang:1.22-alpine 65 | steps: 66 | - uses: actions/checkout@v4 67 | - name: Set Up 68 | run: | 69 | apk add --upgrade gcc g++ && \ 70 | adduser testuser -D 71 | - name: All Tests 72 | run: su - testuser -c 'export PATH=$PATH:/usr/local/go/bin; cd /__w/embedded-postgres/embedded-postgres && go test -v ./... && cd platform-test && go test -v ./...' 73 | platform_tests: 74 | name: Platform tests 75 | strategy: 76 | matrix: 77 | os: [ ubuntu-latest, windows-latest, macos-14 ] 78 | runs-on: ${{ matrix.os }} 79 | steps: 80 | - name: Checkout 81 | uses: actions/checkout@v4 82 | - name: Set Up Golang 83 | uses: actions/setup-go@v5 84 | with: 85 | go-version: 1.22 86 | - name: Platform Tests 87 | run: | 88 | cd platform-test 89 | go test -v -race ./... 
90 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | 2 | # Created by https://www.gitignore.io/api/go,intellij+all,visualstudiocode 3 | # Edit at https://www.gitignore.io/?templates=go,intellij+all,visualstudiocode 4 | 5 | ### Go ### 6 | # Binaries for programs and plugins 7 | *.exe 8 | *.exe~ 9 | *.dll 10 | *.so 11 | *.dylib 12 | 13 | # Test binary, built with `go test -c` 14 | *.test 15 | 16 | # Output of the go coverage tool, specifically when used with LiteIDE 17 | *.out 18 | 19 | # Dependency directories (remove the comment below to include it) 20 | # vendor/ 21 | 22 | ### Go Patch ### 23 | /vendor/ 24 | /Godeps/ 25 | 26 | ### Intellij+all ### 27 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and WebStorm 28 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 29 | 30 | # User-specific stuff 31 | .idea/**/workspace.xml 32 | .idea/**/tasks.xml 33 | .idea/**/usage.statistics.xml 34 | .idea/**/dictionaries 35 | .idea/**/shelf 36 | 37 | # Generated files 38 | .idea/**/contentModel.xml 39 | 40 | # Sensitive or high-churn files 41 | .idea/**/dataSources/ 42 | .idea/**/dataSources.ids 43 | .idea/**/dataSources.local.xml 44 | .idea/**/sqlDataSources.xml 45 | .idea/**/dynamic.xml 46 | .idea/**/uiDesigner.xml 47 | .idea/**/dbnavigator.xml 48 | 49 | # Gradle 50 | .idea/**/gradle.xml 51 | .idea/**/libraries 52 | 53 | # Gradle and Maven with auto-import 54 | # When using Gradle or Maven with auto-import, you should exclude module files, 55 | # since they will be recreated, and may cause churn. Uncomment if using 56 | # auto-import. 
57 | # .idea/modules.xml 58 | # .idea/*.iml 59 | # .idea/modules 60 | # *.iml 61 | # *.ipr 62 | 63 | # CMake 64 | cmake-build-*/ 65 | 66 | # Mongo Explorer plugin 67 | .idea/**/mongoSettings.xml 68 | 69 | # File-based project format 70 | *.iws 71 | 72 | # IntelliJ 73 | out/ 74 | 75 | # mpeltonen/sbt-idea plugin 76 | .idea_modules/ 77 | 78 | # JIRA plugin 79 | atlassian-ide-plugin.xml 80 | 81 | # Cursive Clojure plugin 82 | .idea/replstate.xml 83 | 84 | # Crashlytics plugin (for Android Studio and IntelliJ) 85 | com_crashlytics_export_strings.xml 86 | crashlytics.properties 87 | crashlytics-build.properties 88 | fabric.properties 89 | 90 | # Editor-based Rest Client 91 | .idea/httpRequests 92 | 93 | # Android studio 3.1+ serialized cache file 94 | .idea/caches/build_file_checksums.ser 95 | 96 | ### Intellij+all Patch ### 97 | # Ignores the whole .idea folder and all .iml files 98 | # See https://github.com/joeblau/gitignore.io/issues/186 and https://github.com/joeblau/gitignore.io/issues/360 99 | 100 | .idea/ 101 | 102 | # Reason: https://github.com/joeblau/gitignore.io/issues/186#issuecomment-249601023 103 | 104 | *.iml 105 | modules.xml 106 | .idea/misc.xml 107 | *.ipr 108 | 109 | # Sonarlint plugin 110 | .idea/sonarlint 111 | 112 | ### VisualStudioCode ### 113 | .vscode/* 114 | !.vscode/settings.json 115 | !.vscode/tasks.json 116 | !.vscode/launch.json 117 | !.vscode/extensions.json 118 | 119 | ### VisualStudioCode Patch ### 120 | # Ignore all local history of files 121 | .history 122 | 123 | # End of https://www.gitignore.io/api/go,intellij+all,visualstudiocode 124 | 125 | main -------------------------------------------------------------------------------- /decompression.go: -------------------------------------------------------------------------------- 1 | package embeddedpostgres 2 | 3 | import ( 4 | "archive/tar" 5 | "fmt" 6 | "io" 7 | "os" 8 | "path/filepath" 9 | 10 | "github.com/xi2/xz" 11 | ) 12 | 13 | func defaultTarReader(xzReader *xz.Reader) (func() 
(*tar.Header, error), func() io.Reader) { 14 | tarReader := tar.NewReader(xzReader) 15 | 16 | return func() (*tar.Header, error) { 17 | return tarReader.Next() 18 | }, func() io.Reader { 19 | return tarReader 20 | } 21 | } 22 | 23 | func decompressTarXz(tarReader func(*xz.Reader) (func() (*tar.Header, error), func() io.Reader), path, extractPath string) error { 24 | extractDirectory := filepath.Dir(extractPath) 25 | 26 | if err := os.MkdirAll(extractDirectory, os.ModePerm); err != nil { 27 | return errorUnableToExtract(path, extractPath, err) 28 | } 29 | 30 | tempExtractPath, err := os.MkdirTemp(extractDirectory, "temp_") 31 | if err != nil { 32 | return errorUnableToExtract(path, extractPath, err) 33 | } 34 | defer func() { 35 | if err := os.RemoveAll(tempExtractPath); err != nil { 36 | panic(err) 37 | } 38 | }() 39 | 40 | tarFile, err := os.Open(path) 41 | if err != nil { 42 | return errorUnableToExtract(path, extractPath, err) 43 | } 44 | 45 | defer func() { 46 | if err := tarFile.Close(); err != nil { 47 | panic(err) 48 | } 49 | }() 50 | 51 | xzReader, err := xz.NewReader(tarFile, 0) 52 | if err != nil { 53 | return errorUnableToExtract(path, extractPath, err) 54 | } 55 | 56 | readNext, reader := tarReader(xzReader) 57 | 58 | for { 59 | header, err := readNext() 60 | 61 | if err == io.EOF { 62 | break 63 | } 64 | 65 | if err != nil { 66 | return errorExtractingPostgres(err) 67 | } 68 | 69 | targetPath := filepath.Join(tempExtractPath, header.Name) 70 | finalPath := filepath.Join(extractPath, header.Name) 71 | 72 | if err := os.MkdirAll(filepath.Dir(targetPath), os.ModePerm); err != nil { 73 | return errorExtractingPostgres(err) 74 | } 75 | 76 | if err := os.MkdirAll(filepath.Dir(finalPath), os.ModePerm); err != nil { 77 | return errorExtractingPostgres(err) 78 | } 79 | 80 | switch header.Typeflag { 81 | case tar.TypeReg: 82 | outFile, err := os.OpenFile(targetPath, os.O_CREATE|os.O_RDWR, os.FileMode(header.Mode)) 83 | if err != nil { 84 | return 
errorExtractingPostgres(err) 85 | } 86 | 87 | if _, err := io.Copy(outFile, reader()); err != nil { 88 | return errorExtractingPostgres(err) 89 | } 90 | 91 | if err := outFile.Close(); err != nil { 92 | return errorExtractingPostgres(err) 93 | } 94 | case tar.TypeSymlink: 95 | if err := os.RemoveAll(targetPath); err != nil { 96 | return errorExtractingPostgres(err) 97 | } 98 | 99 | if err := os.Symlink(header.Linkname, targetPath); err != nil { 100 | return errorExtractingPostgres(err) 101 | } 102 | 103 | case tar.TypeDir: 104 | if err := os.MkdirAll(finalPath, os.FileMode(header.Mode)); err != nil { 105 | return errorExtractingPostgres(err) 106 | } 107 | continue 108 | } 109 | 110 | if err := renameOrIgnore(targetPath, finalPath); err != nil { 111 | return errorExtractingPostgres(err) 112 | } 113 | } 114 | 115 | return nil 116 | } 117 | 118 | func errorUnableToExtract(cacheLocation, binariesPath string, err error) error { 119 | return fmt.Errorf("unable to extract postgres archive %s to %s, if running parallel tests, configure RuntimePath to isolate testing directories, %w", 120 | cacheLocation, 121 | binariesPath, 122 | err, 123 | ) 124 | } 125 | -------------------------------------------------------------------------------- /platform-test/platform_test.go: -------------------------------------------------------------------------------- 1 | package platform_test 2 | 3 | import ( 4 | "database/sql" 5 | "fmt" 6 | "os" 7 | "path/filepath" 8 | "runtime" 9 | "strings" 10 | "testing" 11 | 12 | embeddedpostgres "github.com/fergusstrange/embedded-postgres" 13 | ) 14 | 15 | func Test_AllMajorVersions(t *testing.T) { 16 | allVersions := []embeddedpostgres.PostgresVersion{ 17 | embeddedpostgres.V18, 18 | embeddedpostgres.V17, 19 | embeddedpostgres.V16, 20 | embeddedpostgres.V15, 21 | embeddedpostgres.V14, 22 | } 23 | 24 | isLikelyAppleSilicon := runtime.GOOS == "darwin" && runtime.GOARCH == "arm64" 25 | 26 | if !isLikelyAppleSilicon { 27 | allVersions = append(allVersions, 
28 | embeddedpostgres.V13, 29 | embeddedpostgres.V12, 30 | embeddedpostgres.V11, 31 | embeddedpostgres.V10, 32 | embeddedpostgres.V9) 33 | } 34 | 35 | tempExtractLocation, err := os.MkdirTemp("", "embedded_postgres_tests") 36 | if err != nil { 37 | t.Fatal(err) 38 | } 39 | 40 | for i, v := range allVersions { 41 | testNumber := i 42 | version := v 43 | t.Run(fmt.Sprintf("MajorVersion_%s", version), func(t *testing.T) { 44 | port := uint32(5555 + testNumber) 45 | runtimePath := filepath.Join(tempExtractLocation, string(version)) 46 | 47 | maxConnections := 150 48 | database := embeddedpostgres.NewDatabase(embeddedpostgres.DefaultConfig(). 49 | Version(version). 50 | Port(port). 51 | RuntimePath(runtimePath). 52 | StartParameters(map[string]string{ 53 | "max_connections": fmt.Sprintf("%d", maxConnections), 54 | })) 55 | 56 | if err := database.Start(); err != nil { 57 | shutdownDBAndFail(t, err, database, version) 58 | } 59 | 60 | db, err := connect(port) 61 | if err != nil { 62 | shutdownDBAndFail(t, err, database, version) 63 | } 64 | 65 | rows, err := db.Query("SELECT 1") 66 | if err != nil { 67 | shutdownDBAndFail(t, err, database, version) 68 | } 69 | if err := rows.Close(); err != nil { 70 | shutdownDBAndFail(t, err, database, version) 71 | } 72 | 73 | rows, err = db.Query(`SELECT setting::int max_conn FROM pg_settings WHERE name = 'max_connections';`) 74 | if err != nil { 75 | shutdownDBAndFail(t, err, database, version) 76 | } 77 | if !rows.Next() { 78 | shutdownDBAndFail(t, fmt.Errorf("no rows returned for max_connections"), database, version) 79 | } 80 | var maxConnReturned int 81 | if err := rows.Scan(&maxConnReturned); err != nil { 82 | shutdownDBAndFail(t, err, database, version) 83 | } 84 | if maxConnReturned != maxConnections { 85 | shutdownDBAndFail(t, fmt.Errorf("max_connections is %d, not %d as expected", maxConnReturned, maxConnections), database, version) 86 | } 87 | if err := rows.Close(); err != nil { 88 | shutdownDBAndFail(t, err, database, 
version) 89 | } 90 | 91 | if err := db.Close(); err != nil { 92 | shutdownDBAndFail(t, err, database, version) 93 | } 94 | 95 | if err := database.Stop(); err != nil { 96 | t.Fatal(err) 97 | } 98 | 99 | if err := checkPgVersionFile(filepath.Join(runtimePath, "data"), version); err != nil { 100 | t.Fatal(err) 101 | } 102 | }) 103 | } 104 | } 105 | 106 | func shutdownDBAndFail(t *testing.T, err error, db *embeddedpostgres.EmbeddedPostgres, version embeddedpostgres.PostgresVersion) { 107 | if err2 := db.Stop(); err2 != nil { 108 | t.Fatalf("Failed for version %s with error %s, original error %s", version, err2, err) 109 | } 110 | 111 | t.Fatalf("Failed for version %s with error %s", version, err) 112 | } 113 | 114 | func connect(port uint32) (*sql.DB, error) { 115 | db, err := sql.Open("postgres", fmt.Sprintf("host=localhost port=%d user=postgres password=postgres dbname=postgres sslmode=disable", port)) 116 | return db, err 117 | } 118 | 119 | func checkPgVersionFile(dataDir string, version embeddedpostgres.PostgresVersion) error { 120 | pgVersion := filepath.Join(dataDir, "PG_VERSION") 121 | 122 | d, err := os.ReadFile(pgVersion) 123 | if err != nil { 124 | return fmt.Errorf("could not read file %v", pgVersion) 125 | } 126 | 127 | v := strings.TrimSuffix(string(d), "\n") 128 | 129 | if strings.HasPrefix(string(version), v) { 130 | return nil 131 | } 132 | 133 | return fmt.Errorf("version missmatch in PG_VERSION: %v <> %v", string(version), v) 134 | } 135 | -------------------------------------------------------------------------------- /version_strategy_test.go: -------------------------------------------------------------------------------- 1 | package embeddedpostgres 2 | 3 | import ( 4 | "fmt" 5 | "strings" 6 | "testing" 7 | 8 | "github.com/stretchr/testify/assert" 9 | ) 10 | 11 | func Test_DefaultVersionStrategy_AllGolangDistributions(t *testing.T) { 12 | allGolangDistributions := map[string][]string{ 13 | "aix/ppc64": {"aix", "ppc64"}, 14 | "android/386": 
{"android", "386"}, 15 | "android/amd64": {"android", "amd64"}, 16 | "android/arm": {"android", "arm"}, 17 | "android/arm64": {"android", "arm64"}, 18 | "darwin/amd64": {"darwin", "amd64"}, 19 | "darwin/arm64": {"darwin", "amd64"}, 20 | "dragonfly/amd64": {"dragonfly", "amd64"}, 21 | "freebsd/386": {"freebsd", "386"}, 22 | "freebsd/amd64": {"freebsd", "amd64"}, 23 | "freebsd/arm": {"freebsd", "arm"}, 24 | "freebsd/arm64": {"freebsd", "arm64"}, 25 | "illumos/amd64": {"illumos", "amd64"}, 26 | "js/wasm": {"js", "wasm"}, 27 | "linux/386": {"linux", "386"}, 28 | "linux/amd64": {"linux", "amd64"}, 29 | "linux/arm": {"linux", "arm"}, 30 | "linux/arm64": {"linux", "arm64v8"}, 31 | "linux/mips": {"linux", "mips"}, 32 | "linux/mips64": {"linux", "mips64"}, 33 | "linux/mips64le": {"linux", "mips64le"}, 34 | "linux/mipsle": {"linux", "mipsle"}, 35 | "linux/ppc64": {"linux", "ppc64"}, 36 | "linux/ppc64le": {"linux", "ppc64le"}, 37 | "linux/riscv64": {"linux", "riscv64"}, 38 | "linux/s390x": {"linux", "s390x"}, 39 | "netbsd/386": {"netbsd", "386"}, 40 | "netbsd/amd64": {"netbsd", "amd64"}, 41 | "netbsd/arm": {"netbsd", "arm"}, 42 | "netbsd/arm64": {"netbsd", "arm64"}, 43 | "openbsd/386": {"openbsd", "386"}, 44 | "openbsd/amd64": {"openbsd", "amd64"}, 45 | "openbsd/arm": {"openbsd", "arm"}, 46 | "openbsd/arm64": {"openbsd", "arm64"}, 47 | "plan9/386": {"plan9", "386"}, 48 | "plan9/amd64": {"plan9", "amd64"}, 49 | "plan9/arm": {"plan9", "arm"}, 50 | "solaris/amd64": {"solaris", "amd64"}, 51 | "windows/386": {"windows", "386"}, 52 | "windows/amd64": {"windows", "amd64"}, 53 | "windows/arm": {"windows", "arm"}, 54 | } 55 | 56 | versionDifferences := map[PostgresVersion]map[string][]string{ 57 | PostgresVersion("14.0.0"): {}, 58 | PostgresVersion("14.1.0"): {}, 59 | PostgresVersion("14.2.0"): {"darwin/arm64": {"darwin", "arm64v8"}}, 60 | V15: {"darwin/arm64": {"darwin", "arm64v8"}}, 61 | } 62 | defaultConfig := DefaultConfig() 63 | 64 | for version, differences := range 
versionDifferences { 65 | defaultConfig.version = version 66 | 67 | for dist, expected := range allGolangDistributions { 68 | dist := dist 69 | expected := expected 70 | 71 | if override, ok := differences[dist]; ok { 72 | expected = override 73 | } 74 | 75 | t.Run(fmt.Sprintf("DefaultVersionStrategy_%s", dist), func(t *testing.T) { 76 | osArch := strings.Split(dist, "/") 77 | 78 | operatingSystem, architecture, postgresVersion := defaultVersionStrategy( 79 | defaultConfig, 80 | osArch[0], 81 | osArch[1], 82 | linuxMachineName, 83 | func() bool { 84 | return false 85 | })() 86 | 87 | assert.Equal(t, expected[0], operatingSystem) 88 | assert.Equal(t, expected[1], architecture) 89 | assert.Equal(t, version, postgresVersion) 90 | }) 91 | } 92 | } 93 | } 94 | 95 | func Test_DefaultVersionStrategy_Linux_ARM32V6(t *testing.T) { 96 | operatingSystem, architecture, postgresVersion := defaultVersionStrategy( 97 | DefaultConfig(), 98 | "linux", 99 | "arm", 100 | func() string { 101 | return "armv6l" 102 | }, func() bool { 103 | return false 104 | })() 105 | 106 | assert.Equal(t, "linux", operatingSystem) 107 | assert.Equal(t, "arm32v6", architecture) 108 | assert.Equal(t, V18, postgresVersion) 109 | } 110 | 111 | func Test_DefaultVersionStrategy_Linux_ARM32V7(t *testing.T) { 112 | operatingSystem, architecture, postgresVersion := defaultVersionStrategy( 113 | DefaultConfig(), 114 | "linux", 115 | "arm", 116 | func() string { 117 | return "armv7l" 118 | }, func() bool { 119 | return false 120 | })() 121 | 122 | assert.Equal(t, "linux", operatingSystem) 123 | assert.Equal(t, "arm32v7", architecture) 124 | assert.Equal(t, V18, postgresVersion) 125 | } 126 | 127 | func Test_DefaultVersionStrategy_Linux_Alpine(t *testing.T) { 128 | operatingSystem, architecture, postgresVersion := defaultVersionStrategy( 129 | DefaultConfig(), 130 | "linux", 131 | "amd64", 132 | func() string { 133 | return "" 134 | }, 135 | func() bool { 136 | return true 137 | }, 138 | )() 139 | 140 | 
assert.Equal(t, "linux", operatingSystem) 141 | assert.Equal(t, "amd64-alpine", architecture) 142 | assert.Equal(t, V18, postgresVersion) 143 | } 144 | 145 | func Test_DefaultVersionStrategy_shouldUseAlpineLinuxBuild(t *testing.T) { 146 | assert.NotPanics(t, func() { 147 | shouldUseAlpineLinuxBuild() 148 | }) 149 | } 150 | -------------------------------------------------------------------------------- /examples/examples_test.go: -------------------------------------------------------------------------------- 1 | package examples 2 | 3 | import ( 4 | "fmt" 5 | "net/http" 6 | "net/http/httptest" 7 | "reflect" 8 | "testing" 9 | 10 | embeddedpostgres "github.com/fergusstrange/embedded-postgres" 11 | "github.com/jmoiron/sqlx" 12 | _ "github.com/lib/pq" 13 | "github.com/pressly/goose/v3" 14 | "go.uber.org/zap" 15 | "go.uber.org/zap/zapio" 16 | ) 17 | 18 | func Test_GooseMigrations(t *testing.T) { 19 | database := embeddedpostgres.NewDatabase() 20 | if err := database.Start(); err != nil { 21 | t.Fatal(err) 22 | } 23 | 24 | defer func() { 25 | if err := database.Stop(); err != nil { 26 | t.Fatal(err) 27 | } 28 | }() 29 | 30 | db, err := connect() 31 | if err != nil { 32 | t.Fatal(err) 33 | } 34 | 35 | if err := goose.Up(db.DB, "./migrations"); err != nil { 36 | t.Fatal(err) 37 | } 38 | } 39 | 40 | func Test_ZapioLogger(t *testing.T) { 41 | logger, err := zap.NewProduction() 42 | if err != nil { 43 | t.Fatal(err) 44 | } 45 | 46 | w := &zapio.Writer{Log: logger} 47 | 48 | database := embeddedpostgres.NewDatabase(embeddedpostgres.DefaultConfig(). 
49 | Logger(w)) 50 | if err := database.Start(); err != nil { 51 | t.Fatal(err) 52 | } 53 | 54 | defer func() { 55 | if err := database.Stop(); err != nil { 56 | t.Fatal(err) 57 | } 58 | }() 59 | 60 | db, err := connect() 61 | if err != nil { 62 | t.Fatal(err) 63 | } 64 | 65 | if err := goose.Up(db.DB, "./migrations"); err != nil { 66 | t.Fatal(err) 67 | } 68 | } 69 | 70 | func Test_Sqlx_SelectOne(t *testing.T) { 71 | database := embeddedpostgres.NewDatabase() 72 | if err := database.Start(); err != nil { 73 | t.Fatal(err) 74 | } 75 | 76 | defer func() { 77 | if err := database.Stop(); err != nil { 78 | t.Fatal(err) 79 | } 80 | }() 81 | 82 | db, err := connect() 83 | if err != nil { 84 | t.Fatal(err) 85 | } 86 | 87 | rows := make([]int32, 0) 88 | 89 | err = db.Select(&rows, "SELECT 1") 90 | if err != nil { 91 | t.Fatal(err) 92 | } 93 | 94 | if len(rows) != 1 { 95 | t.Fatal("Expected one row returned") 96 | } 97 | } 98 | 99 | func Test_ManyTestsAgainstOneDatabase(t *testing.T) { 100 | database := embeddedpostgres.NewDatabase() 101 | if err := database.Start(); err != nil { 102 | t.Fatal(err) 103 | } 104 | 105 | defer func() { 106 | if err := database.Stop(); err != nil { 107 | t.Fatal(err) 108 | } 109 | }() 110 | 111 | db, err := connect() 112 | if err != nil { 113 | t.Fatal(err) 114 | } 115 | 116 | if err := goose.Up(db.DB, "./migrations"); err != nil { 117 | t.Fatal(err) 118 | } 119 | 120 | tests := []func(t *testing.T){ 121 | func(t *testing.T) { 122 | rows := make([]BeerCatalogue, 0) 123 | if err := db.Select(&rows, "SELECT * FROM beer_catalogue WHERE UPPER(name) = UPPER('Elvis Juice')"); err != nil { 124 | t.Fatal(err) 125 | } 126 | 127 | if len(rows) != 0 { 128 | t.Fatalf("expected 0 rows but got %d", len(rows)) 129 | } 130 | }, 131 | func(t *testing.T) { 132 | _, err := db.Exec(`INSERT INTO beer_catalogue (name, consumed, rating) VALUES ($1, $2, $3)`, 133 | "Kernal", 134 | true, 135 | 99.32) 136 | if err != nil { 137 | t.Fatal(err) 138 | } 139 | 140 | 
actualBeerCatalogue := make([]BeerCatalogue, 0) 141 | if err := db.Select(&actualBeerCatalogue, "SELECT * FROM beer_catalogue WHERE id = 2"); err != nil { 142 | t.Fatal(err) 143 | } 144 | 145 | expectedBeerCatalogue := BeerCatalogue{ 146 | ID: 2, 147 | Name: "Kernal", 148 | Consumed: true, 149 | Rating: 99.32, 150 | } 151 | if !reflect.DeepEqual(expectedBeerCatalogue, actualBeerCatalogue[0]) { 152 | t.Fatalf("expected %+v did not match actual %+v", expectedBeerCatalogue, actualBeerCatalogue) 153 | } 154 | }, 155 | } 156 | 157 | for testNumber, test := range tests { 158 | t.Run(fmt.Sprintf("%d", testNumber), test) 159 | } 160 | } 161 | 162 | func Test_SimpleHttpWebApp(t *testing.T) { 163 | database := embeddedpostgres.NewDatabase() 164 | if err := database.Start(); err != nil { 165 | t.Fatal(err) 166 | } 167 | 168 | defer func() { 169 | if err := database.Stop(); err != nil { 170 | t.Fatal(err) 171 | } 172 | }() 173 | 174 | request := httptest.NewRequest("GET", "/beer-catalogue?name=Punk%20IPA", nil) 175 | recorder := httptest.NewRecorder() 176 | 177 | NewApp().router.ServeHTTP(recorder, request) 178 | 179 | if recorder.Code != http.StatusOK { 180 | t.Fatalf("expected 200 but receieved %d", recorder.Code) 181 | } 182 | 183 | expectedPayload := `[{"id":1,"name":"Punk IPA","consumed":true,"rating":68.29}]` 184 | actualPayload := recorder.Body.String() 185 | 186 | if actualPayload != expectedPayload { 187 | t.Fatalf("expected %+v but receieved %+v", expectedPayload, actualPayload) 188 | } 189 | } 190 | 191 | func connect() (*sqlx.DB, error) { 192 | db, err := sqlx.Connect("postgres", "host=localhost port=5432 user=postgres password=postgres dbname=postgres sslmode=disable") 193 | return db, err 194 | } 195 | -------------------------------------------------------------------------------- /prepare_database.go: -------------------------------------------------------------------------------- 1 | package embeddedpostgres 2 | 3 | import ( 4 | "context" 5 | "database/sql" 6 
| "errors" 7 | "fmt" 8 | "io" 9 | "os" 10 | "os/exec" 11 | "path/filepath" 12 | 13 | "github.com/lib/pq" 14 | ) 15 | 16 | const ( 17 | fmtCloseDBConn = "unable to close database connection: %w" 18 | fmtAfterError = "%v happened after error: %w" 19 | ) 20 | 21 | type initDatabase func(binaryExtractLocation, runtimePath, pgDataDir, username, password, locale string, encoding string, logger *os.File) error 22 | type createDatabase func(port uint32, username, password, database string) error 23 | 24 | func defaultInitDatabase(binaryExtractLocation, runtimePath, pgDataDir, username, password, locale string, encoding string, logger *os.File) error { 25 | passwordFile, err := createPasswordFile(runtimePath, password) 26 | if err != nil { 27 | return err 28 | } 29 | 30 | args := []string{ 31 | "-A", "password", 32 | "-U", username, 33 | "-D", pgDataDir, 34 | fmt.Sprintf("--pwfile=%s", passwordFile), 35 | } 36 | 37 | if locale != "" { 38 | args = append(args, fmt.Sprintf("--locale=%s", locale)) 39 | } 40 | 41 | if encoding != "" { 42 | args = append(args, fmt.Sprintf("--encoding=%s", encoding)) 43 | } 44 | 45 | postgresInitDBBinary := filepath.Join(binaryExtractLocation, "bin/initdb") 46 | postgresInitDBProcess := exec.Command(postgresInitDBBinary, args...) 
47 | postgresInitDBProcess.Stderr = logger 48 | postgresInitDBProcess.Stdout = logger 49 | 50 | if err = postgresInitDBProcess.Run(); err != nil { 51 | logContent, readLogsErr := readLogsOrTimeout(logger) // we want to preserve the original error 52 | if readLogsErr != nil { 53 | logContent = []byte(string(logContent) + " - " + readLogsErr.Error()) 54 | } 55 | return fmt.Errorf("unable to init database using '%s': %w\n%s", postgresInitDBProcess.String(), err, string(logContent)) 56 | } 57 | 58 | if err = os.Remove(passwordFile); err != nil { 59 | return fmt.Errorf("unable to remove password file '%v': %w", passwordFile, err) 60 | } 61 | 62 | return nil 63 | } 64 | 65 | func createPasswordFile(runtimePath, password string) (string, error) { 66 | passwordFileLocation := filepath.Join(runtimePath, "pwfile") 67 | if err := os.WriteFile(passwordFileLocation, []byte(password), 0600); err != nil { 68 | return "", fmt.Errorf("unable to write password file to %s", passwordFileLocation) 69 | } 70 | 71 | return passwordFileLocation, nil 72 | } 73 | 74 | func defaultCreateDatabase(port uint32, username, password, database string) (err error) { 75 | if database == "postgres" { 76 | return nil 77 | } 78 | 79 | conn, err := openDatabaseConnection(port, username, password, "postgres") 80 | if err != nil { 81 | return errorCustomDatabase(database, err) 82 | } 83 | 84 | db := sql.OpenDB(conn) 85 | defer func() { 86 | err = connectionClose(db, err) 87 | }() 88 | 89 | if _, err := db.Exec(fmt.Sprintf("CREATE DATABASE \"%s\"", database)); err != nil { 90 | return errorCustomDatabase(database, err) 91 | } 92 | 93 | return nil 94 | } 95 | 96 | // connectionClose closes the database connection and handles the error of the function that used the database connection 97 | func connectionClose(db io.Closer, err error) error { 98 | closeErr := db.Close() 99 | if closeErr != nil { 100 | closeErr = fmt.Errorf(fmtCloseDBConn, closeErr) 101 | 102 | if err != nil { 103 | err = 
fmt.Errorf(fmtAfterError, closeErr, err) 104 | } else { 105 | err = closeErr 106 | } 107 | } 108 | 109 | return err 110 | } 111 | 112 | func healthCheckDatabaseOrTimeout(config Config) error { 113 | healthCheckSignal := make(chan bool) 114 | 115 | defer close(healthCheckSignal) 116 | 117 | timeout, cancelFunc := context.WithTimeout(context.Background(), config.startTimeout) 118 | 119 | defer cancelFunc() 120 | 121 | go func() { 122 | for timeout.Err() == nil { 123 | if err := healthCheckDatabase(config.port, config.database, config.username, config.password); err != nil { 124 | continue 125 | } 126 | healthCheckSignal <- true 127 | 128 | break 129 | } 130 | }() 131 | 132 | select { 133 | case <-healthCheckSignal: 134 | return nil 135 | case <-timeout.Done(): 136 | return errors.New("timed out waiting for database to become available") 137 | } 138 | } 139 | 140 | func healthCheckDatabase(port uint32, database, username, password string) (err error) { 141 | conn, err := openDatabaseConnection(port, username, password, database) 142 | if err != nil { 143 | return err 144 | } 145 | 146 | db := sql.OpenDB(conn) 147 | defer func() { 148 | err = connectionClose(db, err) 149 | }() 150 | 151 | if _, err := db.Query("SELECT 1"); err != nil { 152 | return err 153 | } 154 | 155 | return nil 156 | } 157 | 158 | func openDatabaseConnection(port uint32, username string, password string, database string) (*pq.Connector, error) { 159 | conn, err := pq.NewConnector(fmt.Sprintf("host=localhost port=%d user=%s password=%s dbname=%s sslmode=disable", 160 | port, 161 | username, 162 | password, 163 | database)) 164 | if err != nil { 165 | return nil, err 166 | } 167 | 168 | return conn, nil 169 | } 170 | 171 | func errorCustomDatabase(database string, err error) error { 172 | return fmt.Errorf("unable to connect to create database with custom name %s with the following error: %s", database, err) 173 | } 174 | 
-------------------------------------------------------------------------------- /remote_fetch.go: -------------------------------------------------------------------------------- 1 | package embeddedpostgres 2 | 3 | import ( 4 | "archive/zip" 5 | "bytes" 6 | "crypto/sha256" 7 | "encoding/hex" 8 | "errors" 9 | "fmt" 10 | "io" 11 | "log" 12 | "net/http" 13 | "os" 14 | "path/filepath" 15 | "strings" 16 | ) 17 | 18 | // RemoteFetchStrategy provides a strategy to fetch a Postgres binary so that it is available for use. 19 | type RemoteFetchStrategy func() error 20 | 21 | //nolint:funlen 22 | func defaultRemoteFetchStrategy(remoteFetchHost string, versionStrategy VersionStrategy, cacheLocator CacheLocator) RemoteFetchStrategy { 23 | return func() error { 24 | operatingSystem, architecture, version := versionStrategy() 25 | 26 | jarDownloadURL := fmt.Sprintf("%s/io/zonky/test/postgres/embedded-postgres-binaries-%s-%s/%s/embedded-postgres-binaries-%s-%s-%s.jar", 27 | remoteFetchHost, 28 | operatingSystem, 29 | architecture, 30 | version, 31 | operatingSystem, 32 | architecture, 33 | version) 34 | 35 | jarDownloadResponse, err := http.Get(jarDownloadURL) 36 | if err != nil { 37 | return fmt.Errorf("unable to connect to %s", remoteFetchHost) 38 | } 39 | 40 | defer closeBody(jarDownloadResponse)() 41 | 42 | if jarDownloadResponse.StatusCode != http.StatusOK { 43 | return fmt.Errorf("no version found matching %s", version) 44 | } 45 | 46 | jarBodyBytes, err := io.ReadAll(jarDownloadResponse.Body) 47 | if err != nil { 48 | return errorFetchingPostgres(err) 49 | } 50 | 51 | shaDownloadURL := fmt.Sprintf("%s.sha256", jarDownloadURL) 52 | shaDownloadResponse, err := http.Get(shaDownloadURL) 53 | if err != nil { 54 | return fmt.Errorf("download sha256 from %s failed: %w", shaDownloadURL, err) 55 | } 56 | defer closeBody(shaDownloadResponse)() 57 | 58 | if err == nil && shaDownloadResponse.StatusCode == http.StatusOK { 59 | if shaBodyBytes, err := 
io.ReadAll(shaDownloadResponse.Body); err == nil { 60 | jarChecksum := sha256.Sum256(jarBodyBytes) 61 | if !bytes.Equal(shaBodyBytes, []byte(hex.EncodeToString(jarChecksum[:]))) { 62 | return errors.New("downloaded checksums do not match") 63 | } 64 | } 65 | } 66 | 67 | return decompressResponse(jarBodyBytes, jarDownloadResponse.ContentLength, cacheLocator, jarDownloadURL) 68 | } 69 | } 70 | 71 | func closeBody(resp *http.Response) func() { 72 | return func() { 73 | if resp == nil || resp.Body == nil { 74 | return 75 | } 76 | if err := resp.Body.Close(); err != nil { 77 | log.Fatal(err) 78 | } 79 | } 80 | } 81 | 82 | func decompressResponse(bodyBytes []byte, contentLength int64, cacheLocator CacheLocator, downloadURL string) error { 83 | size := contentLength 84 | // if the content length is not set (i.e. chunked encoding), 85 | // we need to use the length of the bodyBytes otherwise 86 | // the unzip operation will fail 87 | if contentLength < 0 { 88 | size = int64(len(bodyBytes)) 89 | } 90 | zipReader, err := zip.NewReader(bytes.NewReader(bodyBytes), size) 91 | if err != nil { 92 | return errorFetchingPostgres(err) 93 | } 94 | 95 | cacheLocation, _ := cacheLocator() 96 | 97 | if err := os.MkdirAll(filepath.Dir(cacheLocation), 0755); err != nil { 98 | return errorExtractingPostgres(err) 99 | } 100 | 101 | for _, file := range zipReader.File { 102 | if !file.FileHeader.FileInfo().IsDir() && strings.HasSuffix(file.FileHeader.Name, ".txz") { 103 | if err := decompressSingleFile(file, cacheLocation); err != nil { 104 | return err 105 | } 106 | 107 | // we have successfully found the file, return early 108 | return nil 109 | } 110 | } 111 | 112 | return fmt.Errorf("error fetching postgres: cannot find binary in archive retrieved from %s", downloadURL) 113 | } 114 | 115 | func decompressSingleFile(file *zip.File, cacheLocation string) error { 116 | renamed := false 117 | 118 | archiveReader, err := file.Open() 119 | if err != nil { 120 | return 
errorExtractingPostgres(err) 121 | } 122 | 123 | archiveBytes, err := io.ReadAll(archiveReader) 124 | if err != nil { 125 | return errorExtractingPostgres(err) 126 | } 127 | 128 | // if multiple processes attempt to extract 129 | // to prevent file corruption when multiple processes attempt to extract at the same time 130 | // first to a cache location, and then move the file into place. 131 | tmp, err := os.CreateTemp(filepath.Dir(cacheLocation), "temp_") 132 | if err != nil { 133 | return errorExtractingPostgres(err) 134 | } 135 | defer func() { 136 | // if anything failed before the rename then the temporary file should be cleaned up. 137 | // if the rename was successful then there is no temporary file to remove. 138 | if !renamed { 139 | if err := os.Remove(tmp.Name()); err != nil { 140 | panic(err) 141 | } 142 | } 143 | }() 144 | 145 | if _, err := tmp.Write(archiveBytes); err != nil { 146 | return errorExtractingPostgres(err) 147 | } 148 | 149 | // Windows cannot rename a file if is it still open. 150 | // The file needs to be manually closed to allow the rename to happen 151 | if err := tmp.Close(); err != nil { 152 | return errorExtractingPostgres(err) 153 | } 154 | 155 | if err := renameOrIgnore(tmp.Name(), cacheLocation); err != nil { 156 | return errorExtractingPostgres(err) 157 | } 158 | renamed = true 159 | 160 | return nil 161 | } 162 | 163 | func errorExtractingPostgres(err error) error { 164 | return fmt.Errorf("unable to extract postgres archive: %s", err) 165 | } 166 | 167 | func errorFetchingPostgres(err error) error { 168 | return fmt.Errorf("error fetching postgres: %s", err) 169 | } 170 | -------------------------------------------------------------------------------- /config.go: -------------------------------------------------------------------------------- 1 | package embeddedpostgres 2 | 3 | import ( 4 | "fmt" 5 | "io" 6 | "os" 7 | "time" 8 | ) 9 | 10 | // Config maintains the runtime configuration for the Postgres process to be created. 
11 | type Config struct { 12 | version PostgresVersion 13 | port uint32 14 | database string 15 | username string 16 | password string 17 | cachePath string 18 | runtimePath string 19 | dataPath string 20 | binariesPath string 21 | locale string 22 | encoding string 23 | startParameters map[string]string 24 | binaryRepositoryURL string 25 | startTimeout time.Duration 26 | logger io.Writer 27 | } 28 | 29 | // DefaultConfig provides a default set of configuration to be used "as is" or modified using the provided builders. 30 | // The following can be assumed as defaults: 31 | // Version: 16 32 | // Port: 5432 33 | // Database: postgres 34 | // Username: postgres 35 | // Password: postgres 36 | // StartTimeout: 15 Seconds 37 | func DefaultConfig() Config { 38 | return Config{ 39 | version: V18, 40 | port: 5432, 41 | database: "postgres", 42 | username: "postgres", 43 | password: "postgres", 44 | startTimeout: 15 * time.Second, 45 | logger: os.Stdout, 46 | binaryRepositoryURL: "https://repo1.maven.org/maven2", 47 | } 48 | } 49 | 50 | // Version will set the Postgres binary version. 51 | func (c Config) Version(version PostgresVersion) Config { 52 | c.version = version 53 | return c 54 | } 55 | 56 | // Port sets the runtime port that Postgres can be accessed on. 57 | func (c Config) Port(port uint32) Config { 58 | c.port = port 59 | return c 60 | } 61 | 62 | // Database sets the database name that will be created. 63 | func (c Config) Database(database string) Config { 64 | c.database = database 65 | return c 66 | } 67 | 68 | // Username sets the username that will be used to connect. 69 | func (c Config) Username(username string) Config { 70 | c.username = username 71 | return c 72 | } 73 | 74 | // Password sets the password that will be used to connect. 75 | func (c Config) Password(password string) Config { 76 | c.password = password 77 | return c 78 | } 79 | 80 | // RuntimePath sets the path that will be used for the extracted Postgres runtime directory. 
81 | // If Postgres data directory is not set with DataPath(), this directory is also used as data directory. 82 | func (c Config) RuntimePath(path string) Config { 83 | c.runtimePath = path 84 | return c 85 | } 86 | 87 | // CachePath sets the path that will be used for storing Postgres binaries archive. 88 | // If this option is not set, ~/.go-embedded-postgres will be used. 89 | func (c Config) CachePath(path string) Config { 90 | c.cachePath = path 91 | return c 92 | } 93 | 94 | // DataPath sets the path that will be used for the Postgres data directory. 95 | // If this option is set, a previously initialized data directory will be reused if possible. 96 | func (c Config) DataPath(path string) Config { 97 | c.dataPath = path 98 | return c 99 | } 100 | 101 | // BinariesPath sets the path of the pre-downloaded postgres binaries. 102 | // If this option is left unset, the binaries will be downloaded. 103 | func (c Config) BinariesPath(path string) Config { 104 | c.binariesPath = path 105 | return c 106 | } 107 | 108 | // Locale sets the default locale for initdb 109 | func (c Config) Locale(locale string) Config { 110 | c.locale = locale 111 | return c 112 | } 113 | 114 | // Encoding sets the default character set for initdb 115 | func (c Config) Encoding(encoding string) Config { 116 | c.encoding = encoding 117 | return c 118 | } 119 | 120 | // StartParameters sets run-time parameters when starting Postgres (passed to Postgres via "-c"). 121 | // 122 | // These parameters can be used to override the default configuration values in postgres.conf such 123 | // as max_connections=100. See https://www.postgresql.org/docs/current/runtime-config.html 124 | func (c Config) StartParameters(parameters map[string]string) Config { 125 | c.startParameters = parameters 126 | return c 127 | } 128 | 129 | // StartTimeout sets the max timeout that will be used when starting the Postgres process and creating the initial database. 
130 | func (c Config) StartTimeout(timeout time.Duration) Config { 131 | c.startTimeout = timeout 132 | return c 133 | } 134 | 135 | // Logger sets the logger for postgres output 136 | func (c Config) Logger(logger io.Writer) Config { 137 | c.logger = logger 138 | return c 139 | } 140 | 141 | // BinaryRepositoryURL set BinaryRepositoryURL to fetch PG Binary in case of Maven proxy 142 | func (c Config) BinaryRepositoryURL(binaryRepositoryURL string) Config { 143 | c.binaryRepositoryURL = binaryRepositoryURL 144 | return c 145 | } 146 | 147 | func (c Config) GetConnectionURL() string { 148 | return fmt.Sprintf("postgresql://%s:%s@%s:%d/%s", c.username, c.password, "localhost", c.port, c.database) 149 | } 150 | 151 | // PostgresVersion represents the semantic version used to fetch and run the Postgres process. 152 | type PostgresVersion string 153 | 154 | // Predefined supported Postgres versions. 155 | const ( 156 | V18 = PostgresVersion("18.0.0") 157 | V17 = PostgresVersion("17.5.0") 158 | V16 = PostgresVersion("16.9.0") 159 | V15 = PostgresVersion("15.13.0") 160 | V14 = PostgresVersion("14.18.0") 161 | V13 = PostgresVersion("13.21.0") 162 | V12 = PostgresVersion("12.22.0") 163 | V11 = PostgresVersion("11.22.0") 164 | V10 = PostgresVersion("10.23.0") 165 | V9 = PostgresVersion("9.6.24") 166 | ) 167 | -------------------------------------------------------------------------------- /decompression_test.go: -------------------------------------------------------------------------------- 1 | package embeddedpostgres 2 | 3 | import ( 4 | "archive/tar" 5 | "errors" 6 | "fmt" 7 | "io" 8 | "os" 9 | "path" 10 | "path/filepath" 11 | "syscall" 12 | "testing" 13 | 14 | "github.com/stretchr/testify/assert" 15 | "github.com/stretchr/testify/require" 16 | "github.com/xi2/xz" 17 | ) 18 | 19 | func Test_decompressTarXz(t *testing.T) { 20 | tempDir, err := os.MkdirTemp("", "temp_tar_test") 21 | if err != nil { 22 | panic(err) 23 | } 24 | if err := syscall.Rmdir(tempDir); err != nil { 
25 | panic(err) 26 | } 27 | 28 | archive, cleanUp := createTempXzArchive() 29 | defer cleanUp() 30 | 31 | err = decompressTarXz(defaultTarReader, archive, tempDir) 32 | 33 | assert.NoError(t, err) 34 | 35 | expectedExtractedFileLocation := filepath.Join(tempDir, "dir1", "dir2", "some_content") 36 | assert.FileExists(t, expectedExtractedFileLocation) 37 | 38 | fileContentBytes, err := os.ReadFile(expectedExtractedFileLocation) 39 | assert.NoError(t, err) 40 | 41 | assert.Equal(t, "b33r is g00d", string(fileContentBytes)) 42 | } 43 | 44 | func Test_decompressTarXz_ErrorWhenFileNotExists(t *testing.T) { 45 | err := decompressTarXz(defaultTarReader, "/does-not-exist", "/also-fake") 46 | 47 | assert.Error(t, err) 48 | assert.Contains( 49 | t, 50 | err.Error(), 51 | "unable to extract postgres archive /does-not-exist to /also-fake, if running parallel tests, configure RuntimePath to isolate testing directories", 52 | ) 53 | } 54 | 55 | func Test_decompressTarXz_ErrorWhenErrorDuringRead(t *testing.T) { 56 | tempDir, err := os.MkdirTemp("", "temp_tar_test") 57 | if err != nil { 58 | panic(err) 59 | } 60 | if err := syscall.Rmdir(tempDir); err != nil { 61 | panic(err) 62 | } 63 | 64 | archive, cleanUp := createTempXzArchive() 65 | defer cleanUp() 66 | 67 | err = decompressTarXz(func(reader *xz.Reader) (func() (*tar.Header, error), func() io.Reader) { 68 | return func() (*tar.Header, error) { 69 | return nil, errors.New("oh noes") 70 | }, nil 71 | }, archive, tempDir) 72 | 73 | assert.EqualError(t, err, "unable to extract postgres archive: oh noes") 74 | } 75 | 76 | func Test_decompressTarXz_ErrorWhenFailedToReadFileToCopy(t *testing.T) { 77 | tempDir, err := os.MkdirTemp("", "temp_tar_test") 78 | if err != nil { 79 | panic(err) 80 | } 81 | 82 | archive, cleanUp := createTempXzArchive() 83 | defer cleanUp() 84 | 85 | blockingFile := filepath.Join(tempDir, "blocking") 86 | 87 | if err = os.WriteFile(blockingFile, []byte("wazz"), 0000); err != nil { 88 | panic(err) 89 | } 90 | 
91 | fileBlockingExtractTarReader := func(reader *xz.Reader) (func() (*tar.Header, error), func() io.Reader) { 92 | shouldReadFile := true 93 | 94 | return func() (*tar.Header, error) { 95 | if shouldReadFile { 96 | shouldReadFile = false 97 | 98 | return &tar.Header{ 99 | Typeflag: tar.TypeReg, 100 | Name: "blocking", 101 | }, nil 102 | } 103 | 104 | return nil, io.EOF 105 | }, func() io.Reader { 106 | open, _ := os.Open("file_not_exists") 107 | return open 108 | } 109 | } 110 | 111 | err = decompressTarXz(fileBlockingExtractTarReader, archive, tempDir) 112 | 113 | assert.Regexp(t, "^unable to extract postgres archive:.+$", err) 114 | } 115 | 116 | func Test_decompressTarXz_ErrorWhenFileToCopyToNotExists(t *testing.T) { 117 | tempDir, err := os.MkdirTemp("", "temp_tar_test") 118 | if err != nil { 119 | panic(err) 120 | } 121 | if err := syscall.Rmdir(tempDir); err != nil { 122 | panic(err) 123 | } 124 | 125 | archive, cleanUp := createTempXzArchive() 126 | defer cleanUp() 127 | 128 | fileBlockingExtractTarReader := func(reader *xz.Reader) (func() (*tar.Header, error), func() io.Reader) { 129 | shouldReadFile := true 130 | 131 | return func() (*tar.Header, error) { 132 | if shouldReadFile { 133 | shouldReadFile = false 134 | 135 | return &tar.Header{ 136 | Typeflag: tar.TypeReg, 137 | Name: "some_dir/wazz/dazz/fazz", 138 | }, nil 139 | } 140 | 141 | return nil, io.EOF 142 | }, func() io.Reader { 143 | open, _ := os.Open("file_not_exists") 144 | return open 145 | } 146 | } 147 | 148 | err = decompressTarXz(fileBlockingExtractTarReader, archive, tempDir) 149 | 150 | assert.Regexp(t, "^unable to extract postgres archive:.+$", err) 151 | } 152 | 153 | func Test_decompressTarXz_ErrorWhenArchiveCorrupted(t *testing.T) { 154 | tempDir, err := os.MkdirTemp("", "temp_tar_test") 155 | if err != nil { 156 | panic(err) 157 | } 158 | if err := syscall.Rmdir(tempDir); err != nil { 159 | panic(err) 160 | } 161 | 162 | archive, cleanup := createTempXzArchive() 163 | 164 | defer 
cleanup() 165 | 166 | file, err := os.OpenFile(archive, os.O_WRONLY, 0664) 167 | if err != nil { 168 | panic(err) 169 | } 170 | 171 | if _, err := file.Seek(35, 0); err != nil { 172 | panic(err) 173 | } 174 | 175 | if _, err := file.WriteString("someJunk"); err != nil { 176 | panic(err) 177 | } 178 | 179 | if err := file.Close(); err != nil { 180 | panic(err) 181 | } 182 | 183 | err = decompressTarXz(defaultTarReader, archive, tempDir) 184 | 185 | assert.EqualError(t, err, "unable to extract postgres archive: xz: data is corrupt") 186 | } 187 | 188 | func Test_decompressTarXz_ErrorWithInvalidDestination(t *testing.T) { 189 | archive, cleanUp := createTempXzArchive() 190 | defer cleanUp() 191 | 192 | tempDir, err := os.MkdirTemp("", "temp_tar_test") 193 | require.NoError(t, err) 194 | defer func() { 195 | os.RemoveAll(tempDir) 196 | }() 197 | 198 | op := fmt.Sprintf(path.Join(tempDir, "%c"), rune(0)) 199 | 200 | err = decompressTarXz(defaultTarReader, archive, op) 201 | assert.EqualError( 202 | t, 203 | err, 204 | fmt.Sprintf("unable to extract postgres archive: mkdir %s: invalid argument", op), 205 | ) 206 | } 207 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 |
3 |