├── .gitignore ├── Makefile ├── go.mod ├── go.sum ├── docker-compose.yml ├── LICENSE ├── README.md ├── main_test.go └── main.go /.gitignore: -------------------------------------------------------------------------------- 1 | .idea/ 2 | vendor/ 3 | cover.out 4 | coverage.html 5 | *.swp 6 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | test: 2 | docker-compose up --detach 3 | go test ./... 4 | docker-compose down -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/tristanfisher/ivory/v2 2 | 3 | go 1.20 4 | 5 | require github.com/lib/pq v1.10.9 6 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= 2 | github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= 3 | -------------------------------------------------------------------------------- /docker-compose.yml: -------------------------------------------------------------------------------- 1 | version: "3.9" 2 | services: 3 | ivory: 4 | build: 5 | context: . 
6 | image: "postgres:14.1-alpine" 7 | environment: 8 | # POSTGRES_USER, POSTGRES_PASSWORD is the setup for the superuser 9 | POSTGRES_USER: postgres 10 | POSTGRES_PASSWORD: rootUserSeriousPassword1 11 | POSTGRES_DB: ivoryPgExisting 12 | ports: 13 | - "5555:5432" 14 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Tristan Fisher 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
-------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Ivory 🐘 2 | 3 | [![Go 4 | Reference](https://pkg.go.dev/badge/github.com/tristanfisher/ivory.svg)](https://pkg.go.dev/github.com/tristanfisher/ivory) 5 | [![Go Report 6 | Card](https://goreportcard.com/badge/github.com/tristanfisher/ivory)](https://goreportcard.com/report/github.com/tristanfisher/ivory) 7 | 8 | ## Overview 9 | 10 | Ivory makes it easy for you to create and manage PostgreSQL databases via a Go program. 11 | 12 | This is particularly useful during tests or any other time that you only want a pg database with a lifecycle that is 13 | managed programmatically. 14 | 15 | ## Usage 16 | 17 | When not told to create a database or run SQL, ivory will simply try to connect to a PostgreSQL database, return 18 | database handles, and a function for dropping an existing database and closing DB handles. 19 | 20 | However, the primary goal of this library is to make it easy to bootstrap and clean up databases. 21 | 22 | The following is likely all you need to know in order to make use of this library: 23 | 24 | - If a database name is provided in the options, a random database name is not generated. Otherwise, a 25 | collision-resistant database name is generated. 26 | - If told to create a database, Ivory will oblige and bind 2 connections: 27 | 1. Server-scoped (connection without database in connection string -- required to drop a database instance) 28 | 2. Database-scoped (connection with database in connection string) 29 | - A slice of strings may be provided for Ivory to treat as migrations. These may include any valid SQL, including 30 | transactions. 31 | - A "tear down function" that will drop the created (or specified) database and close database handles is returned. 32 | Calling this is optional. Ivory does not implicitly drop databases or close connections. 
33 | 34 | As an example: 35 | 36 | ```go 37 | dbOptions := &ivory.DatabaseOptions{ 38 | Host: "localhost", 39 | Port: 5555, 40 | SslMode: "disable", 41 | User: "postgres", 42 | Password: "rootUserSeriousPassword1", 43 | } 44 | // we discard two db handles here, one to the instance, one scoped to the database in the instance teardown() closes 45 | // database handles and cleans up the created database 46 | _, _, dbName, teardown, err := ivory.New( 47 | context.TODO(), 48 | dbOptions, 49 | []string{"CREATE TABLE grocery_list (item char(25))"}, 50 | true, 51 | "my_app") 52 | defer teardown() 53 | if err != nil { 54 | fmt.Println("error creating: ", err) 55 | return 56 | } 57 | fmt.Printf("created: %s", dbName) 58 | ``` 59 | 60 | After the deferred function is called, the database is dropped. 61 | 62 | As an example of how to create many databases: 63 | 64 | ```go 65 | // set aside database-context connections, say if we want to use them as normal *sql.DB handles 66 | availableDBs := make(map[string]*sql.DB, 0) 67 | for i := 0; i < 10; i++ { 68 | // important: must be inside loop as dbOptions are updated in place! 
69 | opts := &ivory.DatabaseOptions{ 70 | Host: "localhost", 71 | Port: 5555, 72 | SslMode: "disable", 73 | User: "postgres", 74 | Password: "rootUserSeriousPassword1", 75 | } 76 | 77 | // we discard the instance-scoped handle here and keep the one scoped to the created database. teardown() closes 78 | // database handles and cleans up the created database 79 | _, dbScopedConn, dbName, teardown, err := ivory.New(ctx, opts, fixtureTable, true, "") 80 | if err != nil { 81 | fmt.Println("=> failed to create a database: ", err) 82 | continue 83 | } 84 | defer teardown() 85 | availableDBs[dbName] = dbScopedConn 86 | } 87 | databasesCreated := "" 88 | for k, _ := range availableDBs { 89 | databasesCreated += fmt.Sprintf("%s ", k) 90 | } 91 | fmt.Printf("created: %s", databasesCreated) 92 | // created: _disp_pg_1664125384_q57qcq8qid6a7k7d _disp_pg_1664125384_1yj7i14v7iqauzln _disp_pg_1664125385_galb4ijb... 93 | ``` 94 | 95 | That's it! No need to act as a database janitor or deal with a slowly growing PostgreSQL instance. 96 | 97 | That said, if your code panics or the context is cancelled before the deferred functions can run, 98 | `FindLikelyAbandonedDBs()` and `DropDB()` are available to you for finding and dropping databases, respectively. 99 | 100 | ## Development / Contribution 101 | 102 | Pull requests or GitHub issues are welcomed. 103 | 104 | If you are creating a pull request, please include tests as well as a description of the problem being solved. 105 | 106 | If you are opening a GitHub issue, please include the error message and verify that your database is listening for 107 | connections (`connection refused` likely means the wrong database host or port is being specified). 
108 | -------------------------------------------------------------------------------- /main_test.go: -------------------------------------------------------------------------------- 1 | package ivory 2 | 3 | import ( 4 | "context" 5 | "database/sql" 6 | "fmt" 7 | "reflect" 8 | "strconv" 9 | "testing" 10 | "time" 11 | ) 12 | 13 | const createDatabaseInSqlTextTest = "createDatabaseInSqlText" 14 | const sqlTextDbName = "example_created_in_sqlText" 15 | 16 | func Test_mightHaveTransaction(t *testing.T) { 17 | type args struct { 18 | sqlText string 19 | } 20 | tests := []struct { 21 | name string 22 | args args 23 | want bool 24 | }{ 25 | { 26 | name: "obvious", 27 | args: args{ 28 | sqlText: "begin statement", 29 | }, 30 | want: true, 31 | }, 32 | { 33 | name: "noteSpace", 34 | args: args{ 35 | sqlText: "begin ", 36 | }, 37 | want: true, 38 | }, 39 | { 40 | name: "dontBeGreedy", 41 | args: args{ 42 | sqlText: "create table beginnings", 43 | }, 44 | want: false, 45 | }, 46 | { 47 | name: "nothingMatching", 48 | args: args{ 49 | sqlText: "insert into table pub (table_number, party_size, time) values (3, 7, '2021-10-13 03:38:57.914877')", 50 | }, 51 | want: false, 52 | }, 53 | { 54 | name: "empty", 55 | args: args{ 56 | sqlText: "", 57 | }, 58 | want: false, 59 | }, 60 | } 61 | for _, tt := range tests { 62 | t.Run(tt.name, func(t *testing.T) { 63 | if got := mightHaveTransaction(tt.args.sqlText); got != tt.want { 64 | t.Errorf("mightHaveTransaction() = %v, want %v", got, tt.want) 65 | } 66 | }) 67 | } 68 | } 69 | 70 | func TestDatabaseOptions_DSN(t *testing.T) { 71 | type fields struct { 72 | Host string 73 | Port int 74 | Database string 75 | Schema string 76 | User string 77 | Password string 78 | SslMode string 79 | SslCert string 80 | SslKey string 81 | SslRootcert string 82 | SslCertMode string 83 | ConnectTimeoutSeconds int 84 | MaxOpenConns int 85 | MaxIdleConns int 86 | 87 | reflectType reflect.Type 88 | } 89 | tests := []struct { 90 | name string 91 | fields 
fields 92 | want string 93 | wantErr bool 94 | }{ 95 | { 96 | name: "simpleString", 97 | fields: fields{ 98 | Host: "localhost", 99 | }, 100 | want: "host='localhost'", 101 | }, 102 | { 103 | name: "digit", 104 | fields: fields{ 105 | Port: 1024, 106 | }, 107 | want: "port=1024", 108 | }, 109 | { 110 | name: "hostPortIsSpaced", 111 | fields: fields{ 112 | Host: "localhost", 113 | Port: 1024, 114 | }, 115 | want: "host='localhost' port=1024", 116 | }, 117 | { 118 | name: "emptyStringSkipped", 119 | fields: fields{ 120 | Host: "", 121 | }, 122 | want: "", 123 | }, 124 | { 125 | name: "zeroIntSkipped", 126 | fields: fields{ 127 | Port: 0, 128 | }, 129 | want: "", 130 | }, 131 | { 132 | name: "invalidSslError", 133 | fields: fields{ 134 | SslMode: "post rock", 135 | }, 136 | want: "", 137 | wantErr: true, 138 | }, 139 | { 140 | name: "allFields", 141 | fields: fields{ 142 | Host: "localhost", 143 | Port: 1024, 144 | Database: "exampleDB", 145 | Schema: "exampleSchema", 146 | User: "user", 147 | Password: "password", 148 | SslMode: "verify-full", 149 | SslCert: "./cert.pem", 150 | SslKey: "./key.pem", 151 | SslRootcert: "./rootcert.pem", 152 | SslCertMode: "require", 153 | ConnectTimeoutSeconds: 10, 154 | }, 155 | want: "host='localhost' port=1024 dbname='exampleDB' search_path='exampleSchema' user='user' password='password' sslmode=verify-full sslcert='./cert.pem' sslkey='./key.pem' sslrootcert='./rootcert.pem' sslcertmode=require connect_timeout=10", 156 | }, 157 | } 158 | for _, tt := range tests { 159 | t.Run(tt.name, func(t *testing.T) { 160 | do := &DatabaseOptions{ 161 | Host: tt.fields.Host, 162 | Port: tt.fields.Port, 163 | Database: tt.fields.Database, 164 | Schema: tt.fields.Schema, 165 | User: tt.fields.User, 166 | Password: tt.fields.Password, 167 | SslMode: tt.fields.SslMode, 168 | SslCert: tt.fields.SslCert, 169 | SslKey: tt.fields.SslKey, 170 | SslRootCert: tt.fields.SslRootcert, 171 | SslCertMode: tt.fields.SslCertMode, 172 | ConnectTimeoutSeconds: 
tt.fields.ConnectTimeoutSeconds, 173 | MaxOpenConns: tt.fields.MaxOpenConns, 174 | MaxIdleConns: tt.fields.MaxIdleConns, 175 | reflectType: tt.fields.reflectType, 176 | } 177 | 178 | got, err := do.DSN() 179 | if (err != nil) != tt.wantErr { 180 | t.Errorf("DSN() error = %v, wantErr %v", err, tt.wantErr) 181 | return 182 | } 183 | if got != tt.want { 184 | t.Errorf("DSN() = %v, want %v", got, tt.want) 185 | } 186 | }) 187 | } 188 | } 189 | 190 | // this convenience func creates a string with a temporal and random portion 191 | // as such, we only test len does not exceed our expectation and that repeated runs don't collide on names 192 | func Test_generateDbName(t *testing.T) { 193 | 194 | // e.g. 1634246356 195 | lenTimeNow := len(strconv.FormatInt(time.Now().Unix(), 10)) 196 | 197 | type args struct { 198 | userProvidedId string 199 | } 200 | tests := []struct { 201 | name string 202 | args args 203 | wantLen int 204 | }{ 205 | { 206 | name: "noIdProvided", 207 | args: args{ 208 | userProvidedId: "", 209 | }, 210 | wantLen: PgMaxIdentifierLen - (remainingNameBudget - lenTimeNow), 211 | }, 212 | { 213 | name: "IDVeryLong", 214 | args: args{ 215 | userProvidedId: "aSBhbSByZWFsbHkgc2ljayBvZiBkb2luZyB3ZWIgZGV2ZWxvcG1lbnQhICBpcyBhbnlvbmUgaW50ZXJlc3RpbmcgaGlyaW5nPz8K", 216 | }, 217 | wantLen: PgMaxIdentifierLen, 218 | }, 219 | } 220 | for _, tt := range tests { 221 | t.Run(tt.name, func(t *testing.T) { 222 | got := generateDbName(tt.args.userProvidedId) 223 | gotLen := len(got) 224 | if gotLen != tt.wantLen { 225 | t.Errorf("generateDbName() got len = %v, want len %v, value: %s", gotLen, tt.wantLen, got) 226 | } 227 | }) 228 | } 229 | } 230 | 231 | // Corresponds to values in docker-compose.yml 232 | func TestNew(t *testing.T) { 233 | type args struct { 234 | ctx context.Context 235 | opts *DatabaseOptions 236 | sqlText []string 237 | createDatabase bool 238 | customIdPortion string 239 | } 240 | tests := []struct { 241 | name string 242 | args args 243 | wantDBName 
string 244 | wantErr bool 245 | }{ 246 | { 247 | name: "failToConnect", 248 | args: args{ 249 | ctx: context.TODO(), 250 | opts: &DatabaseOptions{ 251 | Host: "test_failToConnect.example.org", 252 | Port: 99999, 253 | Database: "abc123", 254 | }, 255 | sqlText: []string{}, 256 | }, 257 | wantDBName: "", // no db will be created if we can't connect 258 | wantErr: true, 259 | }, 260 | { 261 | name: "createRandomDB", 262 | args: args{ 263 | ctx: context.TODO(), 264 | opts: &DatabaseOptions{ 265 | Host: "localhost", 266 | Port: 5555, 267 | SslMode: "disable", 268 | User: "postgres", 269 | Password: "rootUserSeriousPassword1", 270 | }, 271 | createDatabase: true, 272 | sqlText: []string{}, 273 | }, 274 | wantDBName: "", // we don't know what it will be 275 | wantErr: false, 276 | }, 277 | { 278 | name: "honorDBNameCreateDB", 279 | args: args{ 280 | ctx: context.TODO(), 281 | opts: &DatabaseOptions{ 282 | Host: "localhost", 283 | Port: 5555, 284 | Database: "flannel", 285 | SslMode: "disable", 286 | User: "postgres", 287 | Password: "rootUserSeriousPassword1", 288 | }, 289 | createDatabase: true, 290 | sqlText: []string{"CREATE TABLE foo ( hello CHAR(5));"}, 291 | }, 292 | wantDBName: "flannel", 293 | wantErr: false, 294 | }, 295 | // if the user creates the database in sql text, they must clean up their own database(s) 296 | { 297 | name: "createDatabaseInSqlText", 298 | args: args{ 299 | ctx: context.TODO(), 300 | opts: &DatabaseOptions{ 301 | Host: "localhost", 302 | Port: 5555, 303 | SslMode: "disable", 304 | User: "postgres", 305 | Password: "rootUserSeriousPassword1", 306 | }, 307 | createDatabase: false, 308 | sqlText: []string{fmt.Sprintf("CREATE DATABASE %s;", sqlTextDbName)}, 309 | }, 310 | }, 311 | } 312 | for _, tt := range tests { 313 | t.Run(tt.name, func(t *testing.T) { 314 | dbHandleNoBoundDB, dbHandleBound, dbName, tearDown, err := New(tt.args.ctx, tt.args.opts, tt.args.sqlText, tt.args.createDatabase, tt.args.customIdPortion) 315 | // success of the 
tear down function should be tested separately 316 | defer func() { 317 | // a bit of a hack, but the "create db in the provided text" is a special, advanced case 318 | // where the caller is responsible. come up with more elegant behavior if time permits. 319 | // one naive approach would be allowing the user to include a snippet to run on tearDown 320 | // that executes before closing connections. 321 | if tt.name == createDatabaseInSqlTextTest { 322 | return 323 | } 324 | 325 | err = tearDown() 326 | // if we got an err on connection, we could not have created the database 327 | if (err != nil) != tt.wantErr { 328 | // stopping before creating more mess 329 | t.Errorf("Failed to clean up created database: %s. error = %v", dbName, err) 330 | t.FailNow() 331 | } 332 | }() 333 | // if we received an error, it can be a failure to connect to the db or a failure running sql 334 | if (err != nil) != tt.wantErr { 335 | t.Errorf("New() did not expect error. error = %v", err) 336 | return 337 | } 338 | if tt.wantErr { 339 | return 340 | } 341 | 342 | if len(tt.wantDBName) > 0 { 343 | if tt.wantDBName != dbName { 344 | t.Errorf("New() did not honor the provided DB name. got name = %v, want name = %v", dbName, tt.wantDBName) 345 | return 346 | } 347 | } 348 | 349 | if tt.name == createDatabaseInSqlTextTest { 350 | // make our own teardown func for this specific test 351 | err = tearDownFunc(tt.args.ctx, dbHandleNoBoundDB, dbHandleBound, sqlTextDbName)() 352 | if (err != nil) != tt.wantErr { 353 | // stopping before creating more mess 354 | t.Errorf("Failed to clean up created database: %s. 
error = %v", sqlTextDbName, err) 355 | t.FailNow() 356 | } 357 | } 358 | 359 | }) 360 | } 361 | } 362 | 363 | func TestFindLikelyAbandonedDBs(t *testing.T) { 364 | 365 | ctx := context.Background() 366 | expectedToFind := make([]string, 0) 367 | 368 | testLocalTimeString := strconv.FormatInt(time.Now().Unix(), 10) 369 | prefix := fmt.Sprintf("TestFLAD_%s", testLocalTimeString) 370 | 371 | // create databases that we want to find later 372 | // discarding teardown 373 | dbHandleNoBoundDB1, dbHandleBound1, dbName1, _, err := New( 374 | ctx, 375 | &DatabaseOptions{ 376 | Host: "localhost", 377 | Port: 5555, 378 | Database: fmt.Sprintf(prefix + "_1"), 379 | SslMode: "disable", 380 | User: "postgres", 381 | Password: "rootUserSeriousPassword1", 382 | }, 383 | []string{}, 384 | true, 385 | "", 386 | ) 387 | if err != nil { 388 | t.Errorf("Test setup failed while creating DBs to abandon. error = %v ", err) 389 | t.FailNow() 390 | } 391 | expectedToFind = append(expectedToFind, dbName1) 392 | 393 | dbHandleNoBoundDB2, dbHandleBound2, dbName2, _, err := New( 394 | ctx, 395 | &DatabaseOptions{ 396 | Host: "localhost", 397 | Port: 5555, 398 | Database: fmt.Sprintf(prefix + "_2"), 399 | SslMode: "disable", 400 | User: "postgres", 401 | Password: "rootUserSeriousPassword1", 402 | }, 403 | []string{}, 404 | true, 405 | "", 406 | ) 407 | if err != nil { 408 | t.Errorf("Test setup failed while creating fixture DBs. error = %v ", err) 409 | t.FailNow() 410 | } 411 | expectedToFind = append(expectedToFind, dbName2) 412 | 413 | // close without using our a cleanup function to simulate left work behind 414 | for _, closeFunc := range []func() error{ 415 | dbHandleNoBoundDB1.Close, dbHandleBound1.Close, dbHandleNoBoundDB2.Close, dbHandleBound2.Close} { 416 | err = closeFunc() 417 | if err != nil { 418 | t.Errorf("Test setup failed closing handle for fixture DB. 
error = %v ", err) 419 | t.FailNow() 420 | } 421 | } 422 | 423 | // bind a new connection not specific to a database 424 | noDBOpts := &DatabaseOptions{ 425 | Host: "localhost", 426 | Port: 5555, 427 | ConnectTimeoutSeconds: 10, 428 | SslMode: "disable", 429 | User: "postgres", 430 | Password: "rootUserSeriousPassword1", 431 | } 432 | dbHandle, err := Connect(ctx, noDBOpts) 433 | if err != nil { 434 | t.Errorf("Test setup failed while creating a new DB handle. error = %v", err) 435 | t.FailNow() 436 | } 437 | 438 | defer func(dbHandle *sql.DB) { 439 | err := dbHandle.Close() 440 | if err != nil { 441 | // not really an error in the test, but good to know 442 | // there is no t.Warn or t.Info 443 | t.Errorf("Failed to clean up database handle for test.") 444 | } 445 | }(dbHandle) 446 | 447 | dbsFound, err := FindLikelyAbandonedDBs(ctx, dbHandle, prefix) 448 | if err != nil { 449 | t.Errorf("FindLikelyAbandonedDBs() failed to run. error = %v", err) 450 | t.FailNow() 451 | return 452 | } 453 | 454 | // we could have some databases left over between tests, but it's really unlikely 455 | if len(dbsFound) != len(expectedToFind) { 456 | t.Errorf( 457 | "FindLikelyAbandonedDBs() unexpected number of databases found. got len() = %v, want len = %v ", len(dbsFound), len(expectedToFind)) 458 | if len(dbsFound) == 0 { 459 | // no point continuing with no work to do 460 | t.FailNow() 461 | } 462 | } 463 | 464 | dbsFoundTable := make(map[string]struct{}, len(dbsFound)) 465 | for _, v := range dbsFound { 466 | dbsFoundTable[v] = struct{}{} 467 | } 468 | 469 | for _, expectedDB := range expectedToFind { 470 | _, ok := dbsFoundTable[expectedDB] 471 | if !ok { 472 | t.Errorf("FindLikelyAbandonedDBs() did not find a fixture db. 
found: %s, missing db: %s", dbsFound, expectedDB) 473 | } 474 | // we don't want to use a teardown func in a loop as we'll close the database handle on the first iteration 475 | } 476 | _, errSlice := DropDB(ctx, dbHandle, expectedToFind) 477 | for i, e := range errSlice { 478 | if e != nil { 479 | t.Errorf("FindLikelyAbandonedDBs() failed to clean up a test-created database: %s. error = %v", expectedToFind[i], e) 480 | } 481 | } 482 | } 483 | 484 | // TestNew_UserProvidedSQL specifically tests the functionality of running user-provided SQL 485 | func TestNew_UserProvidedSQL(t *testing.T) { 486 | 487 | ctx := context.Background() 488 | _, dbHandleBound, dbName, tearDownFunc, err := New( 489 | ctx, 490 | &DatabaseOptions{ 491 | Host: "localhost", 492 | Port: 5555, 493 | SslMode: "disable", 494 | User: "postgres", 495 | Password: "rootUserSeriousPassword1", 496 | }, 497 | []string{ 498 | "create schema ivory;", 499 | "create table ivory.disposable_table ( hello char(5));", 500 | "insert into ivory.disposable_table (hello) values ('world');", 501 | }, 502 | true, 503 | "_test_ups", 504 | ) 505 | 506 | defer func() { 507 | err := tearDownFunc() 508 | if err != nil { 509 | t.Errorf("Failed to clean up created database: %s. error = %v", dbName, err) 510 | } 511 | }() 512 | 513 | // test database creation without error 514 | if err != nil { 515 | t.Errorf("New() failed to create database for test. error = %v", err) 516 | t.FailNow() 517 | } 518 | 519 | // confirm our table exists 520 | rows, err := dbHandleBound.QueryContext(ctx, "SELECT hello FROM ivory.disposable_table LIMIT 1;") 521 | if err != nil { 522 | t.Errorf("Failed to ") 523 | t.FailNow() 524 | } 525 | var r string 526 | 527 | for rows.Next() { 528 | err = rows.Scan(&r) 529 | if err != nil { 530 | t.Errorf("Failed to scan results. error = %v", err) 531 | t.FailNow() 532 | } 533 | } 534 | 535 | if r != "world" { 536 | t.Errorf("New() failed to run sql expressions during setup. 
got = %v, want value: %v", r, "world") 537 | } 538 | 539 | } 540 | -------------------------------------------------------------------------------- /main.go: -------------------------------------------------------------------------------- 1 | package ivory 2 | 3 | import ( 4 | "bufio" 5 | "context" 6 | "database/sql" 7 | "errors" 8 | "fmt" 9 | "math/rand" 10 | "reflect" 11 | "strconv" 12 | "strings" 13 | "time" 14 | 15 | _ "github.com/lib/pq" 16 | ) 17 | 18 | // PgMaxIdentifierLen is the longest a postgres identifier may be without truncation. For our purposes, we consider 19 | // this to be a limit to avoid 20 | // 21 | // The system uses no more than NAMEDATALEN-1 bytes of an identifier. Longer names can be written in commands, but they will be truncated. 22 | // By default, NAMEDATALEN is 64 so the maximum identifier length is 63 bytes. 23 | // more detail: https://www.postgresql.org/docs/14/sql-syntax-lexical.html 24 | // 25 | // Its length is currently defined as 64 bytes (63 usable characters plus terminator), but should be referenced 26 | // using the constant NAMEDATALEN in C source code. 27 | // more detail: https://www.postgresql.org/docs/current/datatype-character.html 28 | const PgMaxIdentifierLen = 63 29 | 30 | // we suffix autogenerated db names 31 | const randSuffixLen = 16 32 | 33 | // this makes it easy to find our generated DBs 34 | const AutogenDBPrefix = "_disp_pg" // _ is added as a suffix as part of a join 35 | 36 | // Calculate how many characters we can use to generate our table name 37 | // +1 for autogenDBPrefix "_" join, +1 for customIdPortion "_" join, but not one for the final randSuffix 38 | const remainingNameBudget = PgMaxIdentifierLen - (len(AutogenDBPrefix) + randSuffixLen + 2) 39 | 40 | // these are string replacements, not placeholders that can be substituted 41 | // note %% literal for wildcard in postgres 42 | // only a single database can be dropped at a time. 
// mightHaveTransaction looks for the start of a transaction in DDL to try to avoid creating tx inside a tx.
// ideally we want to run all user provided sql in a tx to report if we failed setup
//
// The check is deliberately naive: sqlText is read line by line, lower-cased, and searched for
// "begin statement" or "begin " (with a trailing space). It returns true as soon as one line
// contains either marker and false otherwise, including for an empty string.
func mightHaveTransaction(sqlText string) bool {
	// be naive with simple string matching until we have reason to not be.
	// if this affects you, please open a PR or issue
	const (
		evidence1 = "begin statement"
		evidence2 = "begin " // note trailing space -- avoid "beginning" or other innocent usage. prefer string cmp over regex (`begin\ (statement)?`)
	)

	scanner := bufio.NewScanner(strings.NewReader(sqlText))
	// by default a Scanner stops with bufio.ErrTooLong at the first line longer than
	// bufio.MaxScanTokenSize (64 KiB), which would silently hide a marker on or after a large
	// single-line statement. no line can be longer than the input itself, so allow that much.
	scanner.Buffer(nil, len(sqlText)+1)

	for scanner.Scan() {
		lowerSQL := strings.ToLower(scanner.Text())
		if strings.Contains(lowerSQL, evidence1) || strings.Contains(lowerSQL, evidence2) {
			return true
		}
	}
	return false
}
75 | // sqlText are SQL statements that are run in order 76 | // createDatabase does what you expect, but if the database does not exist, an error will return when New() tries to create a handle. 77 | // customIdPortion is an optional string that allows for specifying _disp_pg___bfpmfckppdetem30 78 | // 79 | // todo: customIdPortion maybe doesn't make sense 80 | // 81 | // return is: 82 | // - db handle without the db open 83 | // - db handle with the db open 84 | // - the name of the db 85 | // - a teardown/cleanup function specific to a DB (the connection not associated with a given db is closed in this function) 86 | // 87 | // The teardown function must be called to clean up resources after work is complete. 88 | // Not calling the teardown can also be useful if you want resources to persist for whatever reason 89 | // (such as single threaded tests with up/down migrations or for application instantiation with IF NOT EXISTS). 90 | func New(ctx context.Context, opts *DatabaseOptions, sqlText []string, createDatabase bool, customIdPortion string) (*sql.DB, *sql.DB, string, func() error, error) { 91 | 92 | // bind connection needed for cleanup func as we cannot drop an open database 93 | userOptsDatabaseName := opts.Database 94 | opts.Database = "" 95 | dbHandleNoBoundDB, err := Connect(ctx, opts) 96 | if err != nil { 97 | return nil, nil, opts.Database, dbHandleNoBoundDB.Close, err 98 | } 99 | 100 | // post connection binding without opening the database, store database name for usage 101 | // (necessary post Connect because we can't provide the DB name on connection if it doesn't exist yet) 102 | if len(userOptsDatabaseName) > 0 { 103 | opts.Database = userOptsDatabaseName 104 | } 105 | 106 | // the provided value could still be an empty string, which we take as a sign to generate a db name 107 | // do not generate a DB name if we're not creating a database or update our options 108 | if len(opts.Database) == 0 && createDatabase { 109 | opts.Database = 
generateDbName(customIdPortion) 110 | } 111 | 112 | tearDown := tearDownFunc(ctx, nil, dbHandleNoBoundDB, opts.Database) 113 | 114 | // pinging the database would be useful here, notwithstanding the temporal issue 115 | // however, Ping() connects to a db that may not exist yet. 116 | // an alternative is looking at open ports or dialing to see if something is potentially listening 117 | // 118 | // similarly, we can't trivially mark errors as classes of execution/connection failures because 119 | // it's not known for certain which Exec will happen first based on arguments to this function 120 | 121 | if createDatabase { 122 | _, err = dbHandleNoBoundDB.ExecContext(ctx, fmt.Sprintf(`CREATE DATABASE "%s"`, opts.Database)) // double quote for any upper case 123 | if err != nil { 124 | dsn, errDSN := opts.DSN() 125 | if errDSN != nil { 126 | dsn = fmt.Sprintf(" 0 { 260 | partFmt, err := do.GetDSNPart("Host") 261 | if err != nil { 262 | return "", err 263 | } 264 | dsnPortions = append(dsnPortions, fmt.Sprintf(partFmt, do.Host)) 265 | } 266 | 267 | if do.Port > 0 { 268 | partFmt, err := do.GetDSNPart("Port") 269 | if err != nil { 270 | return "", err 271 | } 272 | dsnPortions = append(dsnPortions, fmt.Sprintf(partFmt, do.Port)) 273 | } 274 | 275 | if len(do.Database) > 0 { 276 | partFmt, err := do.GetDSNPart("Database") 277 | if err != nil { 278 | return "", err 279 | } 280 | dsnPortions = append(dsnPortions, fmt.Sprintf(partFmt, do.Database)) 281 | } 282 | 283 | if len(do.Schema) > 0 { 284 | partFmt, err := do.GetDSNPart("Schema") 285 | if err != nil { 286 | return "", err 287 | } 288 | dsnPortions = append(dsnPortions, fmt.Sprintf(partFmt, do.Schema)) 289 | } 290 | 291 | if len(do.User) > 0 { 292 | partFmt, err := do.GetDSNPart("User") 293 | if err != nil { 294 | return "", err 295 | } 296 | dsnPortions = append(dsnPortions, fmt.Sprintf(partFmt, do.User)) 297 | } 298 | 299 | if len(do.Password) > 0 { 300 | partFmt, err := do.GetDSNPart("Password") 301 | if err != 
nil { 302 | return "", err 303 | } 304 | dsnPortions = append(dsnPortions, fmt.Sprintf(partFmt, do.Password)) 305 | } 306 | 307 | if len(do.SslMode) > 0 { 308 | if !IsValidSSLString(do.SslMode) { 309 | return "", fmt.Errorf("invalid ssl mode provided: %s", do.SslMode) 310 | } 311 | 312 | partFmt, err := do.GetDSNPart("SslMode") 313 | if err != nil { 314 | return "", err 315 | } 316 | dsnPortions = append(dsnPortions, fmt.Sprintf(partFmt, do.SslMode)) 317 | } 318 | 319 | if len(do.SslCert) > 0 { 320 | partFmt, err := do.GetDSNPart("SslCert") 321 | if err != nil { 322 | return "", err 323 | } 324 | dsnPortions = append(dsnPortions, fmt.Sprintf(partFmt, do.SslCert)) 325 | } 326 | if len(do.SslKey) > 0 { 327 | partFmt, err := do.GetDSNPart("SslKey") 328 | if err != nil { 329 | return "", err 330 | } 331 | dsnPortions = append(dsnPortions, fmt.Sprintf(partFmt, do.SslKey)) 332 | } 333 | if len(do.SslRootCert) > 0 { 334 | partFmt, err := do.GetDSNPart("SslRootCert") 335 | if err != nil { 336 | return "", err 337 | } 338 | dsnPortions = append(dsnPortions, fmt.Sprintf(partFmt, do.SslRootCert)) 339 | } 340 | 341 | if len(do.SslCertMode) > 0 { 342 | if !IsValidSSLCertModeString(do.SslCertMode) { 343 | return "", fmt.Errorf("invalid sslcertmode provided: %s", do.SslCertMode) 344 | } 345 | partFmt, err := do.GetDSNPart("SslCertMode") 346 | if err != nil { 347 | return "", err 348 | } 349 | dsnPortions = append(dsnPortions, fmt.Sprintf(partFmt, do.SslCertMode)) 350 | 351 | } 352 | 353 | if do.ConnectTimeoutSeconds > 0 { 354 | partFmt, err := do.GetDSNPart("ConnectTimeoutSeconds") 355 | if err != nil { 356 | return "", err 357 | } 358 | dsnPortions = append(dsnPortions, fmt.Sprintf(partFmt, do.ConnectTimeoutSeconds)) 359 | } 360 | 361 | return strings.Join(dsnPortions, " "), nil 362 | } 363 | 364 | // Connect uses a DSN to create a database handle to a target dbName 365 | // return is the db handle and an error if setup was prevented 366 | func Connect(ctx context.Context, do 
*DatabaseOptions) (*sql.DB, error) { 367 | 368 | if ctx.Err() != nil { 369 | return nil, ctx.Err() 370 | } 371 | 372 | dsn, err := do.DSN() 373 | if err != nil { 374 | return nil, err 375 | } 376 | 377 | // on postgres, sql.Open does not verify a connection will be successful. 378 | // 379 | // the connector for lib/pq takes a DSN string for construction, so unfortunately, we have to build a DSN 380 | // instead of assigning to struct fields 381 | db, err := sql.Open("postgres", dsn) 382 | if err != nil { 383 | return nil, err 384 | } 385 | 386 | db.SetMaxOpenConns(do.MaxOpenConns) 387 | db.SetMaxIdleConns(do.MaxIdleConns) 388 | return db, nil 389 | } 390 | 391 | // generateDBName creates a database name, using the prefix + suffix of this library 392 | // if an empty string is provided, epoch seconds are used 393 | func generateDbName(customIdPortion string) string { 394 | if len(customIdPortion) == 0 { 395 | customIdPortion = strconv.FormatInt(time.Now().Unix(), 10) 396 | } 397 | 398 | if len(customIdPortion) > remainingNameBudget { 399 | customIdPortion = customIdPortion[:remainingNameBudget] 400 | } 401 | 402 | // generate a random string. 
if a UUID becomes available in the standard lib, consider that instead 403 | // in go, each rune can be 1-4 bytes long, so we can't make([]byte,32) and rand.Read() into it without creating a 404 | // byte array of size 4*randSuffixLen 405 | const charChoices = "abcdefghijklmnopqrstuvwxyz0123456789" 406 | const charChoiceLen = len(charChoices) 407 | 408 | // a little gross to re-seed each call, but so would be doing this even if not required in New() 409 | // this is not expected to be needed to be crypto-secure 410 | rand.Seed(time.Now().UnixNano()) 411 | 412 | asciiSlice := make([]byte, randSuffixLen) 413 | for i := range asciiSlice { 414 | asciiSlice[i] = charChoices[rand.Intn(charChoiceLen)] 415 | } 416 | return strings.Join([]string{AutogenDBPrefix, customIdPortion, string(asciiSlice)}, "_") 417 | } 418 | 419 | // tearDownFunc returns a function for dropping a database and closing db handles 420 | func tearDownFunc(ctx context.Context, dbHandleDBOpen *sql.DB, dbHandleNoDBOpen *sql.DB, dbName string) func() error { 421 | return func() error { 422 | select { 423 | case <-ctx.Done(): 424 | errs := []string{"context cancelled before database dropped"} 425 | err := dbHandleNoDBOpen.Close() 426 | if err != nil { 427 | errs = append(errs, err.Error()) 428 | } 429 | return errors.New(strings.Join(errs, ", ")) 430 | default: 431 | } 432 | 433 | errs := make([]string, 0) 434 | _, dropDBErrs := DropDB(ctx, dbHandleNoDBOpen, []string{dbName}) 435 | if len(dropDBErrs) > 0 { 436 | for _, e := range dropDBErrs { 437 | errs = append(errs, fmt.Sprintf("dropping database error: %s", e)) 438 | } 439 | } 440 | errClose := dbHandleNoDBOpen.Close() 441 | if errClose != nil { 442 | errs = append(errs, fmt.Sprintf("*sql.DB without db open. %s", errClose)) 443 | } 444 | if dbHandleDBOpen != nil { 445 | errClose = dbHandleDBOpen.Close() 446 | if errClose != nil { 447 | errs = append(errs, fmt.Sprintf("*sql.DB with db open. 
%s", errClose)) 448 | } 449 | } 450 | 451 | retErr := strings.Join(errs, ", ") 452 | if len(retErr) > 0 { 453 | return errors.New(retErr) 454 | } 455 | 456 | return nil 457 | } 458 | } 459 | 460 | // FindLikelyAbandonedDBs finds databases that were not cleaned up for whatever reason (process crashing, teardown not being called) 461 | // if prefix is an empty string, the AutogenDBPrefix is used 462 | func FindLikelyAbandonedDBs(ctx context.Context, dbHandle *sql.DB, prefix string) ([]string, error) { 463 | if prefix == "" { 464 | prefix = AutogenDBPrefix 465 | } 466 | 467 | selectQuery := fmt.Sprintf(findCreatedTemplate, prefix) 468 | rows, err := dbHandle.QueryContext(ctx, selectQuery) 469 | if err != nil { 470 | return []string{}, err 471 | } 472 | results := make([]string, 0) 473 | for rows.Next() { 474 | var r string 475 | err = rows.Scan(&r) 476 | if err != nil { 477 | return []string{}, err 478 | } 479 | results = append(results, r) 480 | } 481 | return results, nil 482 | } 483 | 484 | // DropDB drops a slice of databases by name. It cannot drop the database that is currently open by the dbHandle 485 | func DropDB(ctx context.Context, dbHandle *sql.DB, dbNames []string) ([]sql.Result, []error) { 486 | results := make([]sql.Result, 0) 487 | if len(dbNames) == 0 { 488 | return results, nil 489 | } 490 | errs := make([]error, 0) 491 | 492 | for _, dbName := range dbNames { 493 | res, err := dbHandle.ExecContext(ctx, fmt.Sprintf(deleteTemplate, dbName)) 494 | if res != nil { 495 | results = append(results, res) 496 | } 497 | if err != nil { 498 | errs = append(errs, err) 499 | } 500 | } 501 | return results, errs 502 | } 503 | --------------------------------------------------------------------------------