├── go.mod ├── README.md ├── .github └── workflows │ └── ci.yml ├── LICENSE ├── go.sum ├── validation.go ├── pgxrecord_test.go └── pgxrecord.go /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/jackc/pgxrecord 2 | 3 | go 1.18 4 | 5 | require ( 6 | github.com/jackc/pgx/v5 v5.4.2 7 | github.com/stretchr/testify v1.8.4 8 | ) 9 | 10 | require ( 11 | github.com/davecgh/go-spew v1.1.1 // indirect 12 | github.com/jackc/pgpassfile v1.0.0 // indirect 13 | github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect 14 | github.com/kr/text v0.2.0 // indirect 15 | github.com/pmezard/go-difflib v1.0.0 // indirect 16 | github.com/rogpeppe/go-internal v1.11.0 // indirect 17 | golang.org/x/crypto v0.11.0 // indirect 18 | golang.org/x/text v0.11.0 // indirect 19 | gopkg.in/yaml.v3 v3.0.1 // indirect 20 | ) 21 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [![Go Reference](https://pkg.go.dev/badge/github.com/jackc/pgxrecord.svg)](https://pkg.go.dev/github.com/jackc/pgxrecord) 2 | ![Build Status](https://github.com/jackc/pgxrecord/actions/workflows/ci.yml/badge.svg) 3 | 4 | # pgxrecord 5 | 6 | Package pgxrecord is a tiny library for CRUD operations. 7 | 8 | It does not and most likely will not have traditional ORM features such as associations. Its sole purpose is to provide a simple way to read and write records. 9 | 10 | ## Package Status 11 | 12 | pgxrecord is highly experimental. The API may change at any time or the package may be abandoned. 13 | 14 | ## Testing 15 | 16 | The pgxrecord tests require a PostgreSQL database. It will use the standard PG* environment variables (PGHOST, PGDATABASE, etc.) for its connection settings. Each test is run inside of a transaction which is rolled back at the end of the test. No permanent changes will be made to the test database. 
17 | -------------------------------------------------------------------------------- /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | name: CI 2 | 3 | on: 4 | push: 5 | branches: [ master ] 6 | pull_request: 7 | branches: [ master ] 8 | 9 | jobs: 10 | 11 | test: 12 | name: Test 13 | runs-on: ubuntu-22.04 14 | 15 | steps: 16 | 17 | - name: Start and set up PostgreSQL 18 | run: | 19 | sudo systemctl start postgresql.service 20 | pg_isready 21 | sudo -u postgres createuser -s runner 22 | createdb runner 23 | 24 | - name: Set up Go 1.x 25 | uses: actions/setup-go@v3 26 | with: 27 | go-version: 1.19 28 | 29 | - name: Check out code into the Go module directory 30 | uses: actions/checkout@v3 31 | 32 | # - name: Setup upterm session 33 | # uses: lhotari/action-upterm@v1 34 | # with: 35 | ## limits ssh access and adds the ssh public key for the user which triggered the workflow 36 | # limit-access-to-actor: true 37 | 38 | - name: Test 39 | run: go test -race -v ./... 40 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2022 Jack Christensen 2 | 3 | MIT License 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining 6 | a copy of this software and associated documentation files (the 7 | "Software"), to deal in the Software without restriction, including 8 | without limitation the rights to use, copy, modify, merge, publish, 9 | distribute, sublicense, and/or sell copies of the Software, and to 10 | permit persons to whom the Software is furnished to do so, subject to 11 | the following conditions: 12 | 13 | The above copyright notice and this permission notice shall be 14 | included in all copies or substantial portions of the Software. 
15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, 17 | EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF 18 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND 19 | NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE 20 | LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION 21 | OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION 22 | WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 23 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= 2 | github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 3 | github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= 4 | github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 5 | github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= 6 | github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= 7 | github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a h1:bbPeKD0xmW/Y25WS6cokEszi5g+S0QxI/d45PkRi7Nk= 8 | github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= 9 | github.com/jackc/pgx/v5 v5.4.2 h1:u1gmGDwbdRUZiwisBm/Ky2M14uQyUP65bG8+20nnyrg= 10 | github.com/jackc/pgx/v5 v5.4.2/go.mod h1:q6iHT8uDNXWiFNOlRqJzBTaSH3+2xCXkokxHZC5qWFY= 11 | github.com/kr/pretty v0.3.0 h1:WgNl7dwNpEZ6jJ9k1snq4pZsg7DOEN8hP9Xw0Tsjwk0= 12 | github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= 13 | github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= 14 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= 15 | 
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 16 | github.com/rogpeppe/go-internal v1.11.0 h1:cWPaGQEPrBb5/AsnsZesgZZ9yb1OQ+GOISoDNXVBh4M= 17 | github.com/rogpeppe/go-internal v1.11.0/go.mod h1:ddIwULY96R17DhadqLgMfk9H9tvdUzkipdSkR5nkCZA= 18 | github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= 19 | github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= 20 | github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= 21 | github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= 22 | github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= 23 | golang.org/x/crypto v0.11.0 h1:6Ewdq3tDic1mg5xRO4milcWCfMVQhI4NkqWWvqejpuA= 24 | golang.org/x/crypto v0.11.0/go.mod h1:xgJhtzW8F9jGdVFWZESrid1U1bjeNy4zgy5cRr/CIio= 25 | golang.org/x/text v0.11.0 h1:LAntKIrcmeSKERyiOh0XMV39LXS8IE9UL2yP7+f5ij4= 26 | golang.org/x/text v0.11.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= 27 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= 28 | gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= 29 | gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 30 | gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= 31 | gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 32 | -------------------------------------------------------------------------------- /validation.go: -------------------------------------------------------------------------------- 1 | package pgxrecord 2 | 3 | import ( 4 | "fmt" 5 | "strings" 6 | ) 7 | 8 | type ValidationError struct { 9 | field string 10 | err error 11 | } 12 | 13 | func (ve *ValidationError) Field() string { 14 | return ve.field 15 | } 16 | 17 | 
func (ve *ValidationError) Unwrap() error { 18 | return ve.err 19 | } 20 | 21 | func (ve *ValidationError) Error() string { 22 | return fmt.Sprintf("%s: %s", ve.field, ve.err) 23 | } 24 | 25 | type ValidationErrors struct { 26 | errors []*ValidationError 27 | } 28 | 29 | // Add adds a new error to the validation errors for the given field. By convention, an empty string for field indicates 30 | // a record-level error. 31 | func (ve *ValidationErrors) Add(field string, err error) { 32 | ve.errors = append(ve.errors, &ValidationError{field: field, err: err}) 33 | } 34 | 35 | // Len returns the number of errors in the ValidationErrors. 36 | func (ve *ValidationErrors) Len() int { 37 | if ve == nil { 38 | return 0 39 | } 40 | 41 | return len(ve.errors) 42 | } 43 | 44 | // On returns a []*ValidationError for the given field. 45 | func (ve *ValidationErrors) On(field string) []*ValidationError { 46 | if ve == nil { 47 | return nil 48 | } 49 | 50 | var errs []*ValidationError 51 | for _, e := range ve.errors { 52 | if e.field == field { 53 | errs = append(errs, e) 54 | } 55 | } 56 | return errs 57 | } 58 | 59 | // All returns all errors. 60 | func (ve *ValidationErrors) All() []*ValidationError { 61 | if ve == nil { 62 | return nil 63 | } 64 | 65 | return ve.errors 66 | } 67 | 68 | // Unwrap unwraps all errors. 69 | func (ve *ValidationErrors) Unwrap() []error { 70 | var errs []error 71 | for _, e := range ve.errors { 72 | errs = append(errs, e) 73 | } 74 | 75 | return errs 76 | } 77 | 78 | // Error satisfies the error interface. 
79 | func (ve *ValidationErrors) Error() string { 80 | if len(ve.errors) == 0 { 81 | return "BUG: ValidationErrors.Error() called with no errors" 82 | } 83 | 84 | sb := strings.Builder{} 85 | for i, e := range ve.errors { 86 | if i > 0 { 87 | sb.WriteString(", ") 88 | } 89 | 90 | if e.field == "" { 91 | sb.WriteString(e.err.Error()) 92 | } else { 93 | sb.WriteString(e.field) 94 | sb.WriteString(": ") 95 | sb.WriteString(e.err.Error()) 96 | } 97 | } 98 | 99 | return sb.String() 100 | } 101 | 102 | type GetterSetter interface { 103 | Get(attribute string) any 104 | Set(attribute string, value any) 105 | } 106 | 107 | type RecordValidator struct { 108 | record GetterSetter 109 | errors *ValidationErrors 110 | } 111 | 112 | func (v *RecordValidator) Validate(field string, validators ...ValueValidator) { 113 | value := v.record.Get(field) 114 | for _, validator := range validators { 115 | var err error 116 | value, err = validator.Validate(value) 117 | if err != nil { 118 | v.errors.Add(field, err) 119 | return 120 | } 121 | } 122 | v.record.Set(field, value) 123 | } 124 | 125 | type ValueValidator interface { 126 | Validate(any) (any, error) 127 | } 128 | 129 | // type RecordValidator interface { 130 | // ValidateRecord(ctx context.Context, db DB, table *Table, record *Record) error 131 | // } 132 | 133 | // type RecordValidatorBuilder struct { 134 | // ctx context.Context 135 | // db DB 136 | // table *Table 137 | // record *Record 138 | // } 139 | 140 | // func NewRecordValidatorBuilder(fn func(rvb *RecordValidatorBuilder)) RecordValidator { 141 | 142 | // } 143 | -------------------------------------------------------------------------------- /pgxrecord_test.go: -------------------------------------------------------------------------------- 1 | package pgxrecord_test 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "os" 7 | "testing" 8 | 9 | "github.com/jackc/pgx/v5" 10 | "github.com/jackc/pgx/v5/pgtype" 11 | "github.com/jackc/pgx/v5/pgxtest" 12 | 
"github.com/jackc/pgxrecord" 13 | "github.com/stretchr/testify/assert" 14 | "github.com/stretchr/testify/require" 15 | ) 16 | 17 | var defaultConnTestRunner pgxtest.ConnTestRunner 18 | 19 | func init() { 20 | defaultConnTestRunner = pgxtest.DefaultConnTestRunner() 21 | defaultConnTestRunner.CreateConfig = func(ctx context.Context, t testing.TB) *pgx.ConnConfig { 22 | config, err := pgx.ParseConfig(os.Getenv("PGXRECORD_TEST_DATABASE")) 23 | require.NoError(t, err) 24 | return config 25 | } 26 | } 27 | 28 | func TestTableLoadAllColumns(t *testing.T) { 29 | t.Parallel() 30 | 31 | defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { 32 | _, err := conn.Exec(ctx, `create temporary table t ( 33 | id int primary key generated by default as identity, 34 | name text not null, 35 | age int 36 | )`) 37 | require.NoError(t, err) 38 | 39 | table := &pgxrecord.Table{ 40 | Name: pgx.Identifier{"t"}, 41 | } 42 | err = table.LoadAllColumns(ctx, conn) 43 | require.NoError(t, err) 44 | 45 | require.Len(t, table.Columns, 3) 46 | expectedColumns := []pgxrecord.Column{ 47 | {Name: "id", OID: pgtype.Int4OID, NotNull: true, PrimaryKey: true}, 48 | {Name: "name", OID: pgtype.TextOID, NotNull: true, PrimaryKey: false}, 49 | {Name: "age", OID: pgtype.Int4OID, NotNull: false, PrimaryKey: false}, 50 | } 51 | for i := range expectedColumns { 52 | assert.Equalf(t, expectedColumns[i].Name, table.Columns[i].Name, "Column %d name", i+1) 53 | assert.Equalf(t, expectedColumns[i].OID, table.Columns[i].OID, "Column %d OID", i+1) 54 | assert.Equalf(t, expectedColumns[i].NotNull, table.Columns[i].NotNull, "Column %d not null", i+1) 55 | assert.Equalf(t, expectedColumns[i].PrimaryKey, table.Columns[i].PrimaryKey, "Column %d primary key", i+1) 56 | } 57 | }) 58 | } 59 | 60 | func TestTableSelectQuery(t *testing.T) { 61 | t.Parallel() 62 | 63 | defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn 
*pgx.Conn) { 64 | _, err := conn.Exec(ctx, `create temporary table t ( 65 | id int primary key generated by default as identity, 66 | name text not null, 67 | age int 68 | )`) 69 | require.NoError(t, err) 70 | 71 | table := &pgxrecord.Table{ 72 | Name: pgx.Identifier{"t"}, 73 | } 74 | err = table.LoadAllColumns(ctx, conn) 75 | require.NoError(t, err) 76 | 77 | require.Equal(t, `select "t"."id", "t"."name", "t"."age" from "t"`, table.SelectQuery()) 78 | }) 79 | } 80 | 81 | func TestTableNewRecord(t *testing.T) { 82 | t.Parallel() 83 | 84 | defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { 85 | _, err := conn.Exec(ctx, `create temporary table t ( 86 | id int primary key generated by default as identity, 87 | name text not null, 88 | age int 89 | )`) 90 | require.NoError(t, err) 91 | 92 | table := &pgxrecord.Table{ 93 | Name: pgx.Identifier{"t"}, 94 | } 95 | err = table.LoadAllColumns(ctx, conn) 96 | require.NoError(t, err) 97 | 98 | record := table.NewRecord() 99 | require.Equal(t, map[string]any{"id": nil, "name": nil, "age": nil}, record.Attributes()) 100 | 101 | err = record.SetAttributesStrict(map[string]any{"name": "John", "age": 42}) 102 | require.NoError(t, err) 103 | require.Equal(t, map[string]any{"id": nil, "name": "John", "age": 42}, record.Attributes()) 104 | }) 105 | } 106 | 107 | func TestTableFindByPK(t *testing.T) { 108 | t.Parallel() 109 | 110 | defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { 111 | _, err := conn.Exec(ctx, `create temporary table t ( 112 | id int primary key generated by default as identity, 113 | name text not null, 114 | age int 115 | )`) 116 | require.NoError(t, err) 117 | 118 | var id int32 119 | err = conn.QueryRow(ctx, `insert into t (name, age) values ('John', 42) returning id`).Scan(&id) 120 | require.NoError(t, err) 121 | 122 | table := &pgxrecord.Table{ 123 | Name: pgx.Identifier{"t"}, 124 | } 125 
| err = table.LoadAllColumns(ctx, conn) 126 | require.NoError(t, err) 127 | 128 | record, err := table.FindByPK(ctx, conn, id) 129 | require.NoError(t, err) 130 | require.Equal(t, map[string]any{"id": int32(1), "name": "John", "age": int32(42)}, record.Attributes()) 131 | }) 132 | } 133 | 134 | func TestRecordSetAndGet(t *testing.T) { 135 | t.Parallel() 136 | 137 | defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { 138 | _, err := conn.Exec(ctx, `create temporary table t ( 139 | id int primary key generated by default as identity, 140 | name text not null, 141 | age int 142 | )`) 143 | require.NoError(t, err) 144 | 145 | table := &pgxrecord.Table{ 146 | Name: pgx.Identifier{"t"}, 147 | } 148 | err = table.LoadAllColumns(ctx, conn) 149 | require.NoError(t, err) 150 | 151 | record := table.NewRecord() 152 | 153 | record.Set("name", "John") 154 | name := record.Get("name") 155 | require.Equal(t, "John", name) 156 | }) 157 | } 158 | 159 | func TestRecordSaveInsert(t *testing.T) { 160 | t.Parallel() 161 | 162 | defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { 163 | _, err := conn.Exec(ctx, `create temporary table t ( 164 | id int primary key generated by default as identity, 165 | name text not null, 166 | age int 167 | )`) 168 | require.NoError(t, err) 169 | 170 | table := &pgxrecord.Table{ 171 | Name: pgx.Identifier{"t"}, 172 | } 173 | err = table.LoadAllColumns(ctx, conn) 174 | require.NoError(t, err) 175 | 176 | record := table.NewRecord() 177 | err = record.SetAttributesStrict(map[string]any{"name": "John", "age": 42}) 178 | require.NoError(t, err) 179 | err = record.Save(ctx, conn) 180 | require.NoError(t, err) 181 | 182 | require.Equal(t, map[string]any{"id": int32(1), "name": "John", "age": int32(42)}, record.Attributes()) 183 | }) 184 | } 185 | 186 | func TestRecordSaveUpdate(t *testing.T) { 187 | t.Parallel() 188 | 189 | 
defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { 190 | _, err := conn.Exec(ctx, `create temporary table t ( 191 | id int primary key generated by default as identity, 192 | name text not null, 193 | age int 194 | )`) 195 | require.NoError(t, err) 196 | 197 | var id int32 198 | err = conn.QueryRow(ctx, `insert into t (name, age) values ('John', 42) returning id`).Scan(&id) 199 | require.NoError(t, err) 200 | 201 | table := &pgxrecord.Table{ 202 | Name: pgx.Identifier{"t"}, 203 | } 204 | err = table.LoadAllColumns(ctx, conn) 205 | require.NoError(t, err) 206 | 207 | record, err := table.FindByPK(ctx, conn, id) 208 | require.NoError(t, err) 209 | require.Equal(t, map[string]any{"id": int32(1), "name": "John", "age": int32(42)}, record.Attributes()) 210 | 211 | record.Set("name", "Bill") 212 | err = record.Save(ctx, conn) 213 | require.NoError(t, err) 214 | 215 | record, err = table.FindByPK(ctx, conn, id) 216 | require.NoError(t, err) 217 | require.Equal(t, map[string]any{"id": int32(1), "name": "Bill", "age": int32(42)}, record.Attributes()) 218 | }) 219 | } 220 | 221 | func TestRecordSaveNormalize(t *testing.T) { 222 | t.Parallel() 223 | 224 | defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { 225 | _, err := conn.Exec(ctx, `create temporary table t ( 226 | id int primary key generated by default as identity, 227 | name text not null, 228 | age int 229 | )`) 230 | require.NoError(t, err) 231 | 232 | normalizeCallCount := 0 233 | 234 | table := &pgxrecord.Table{ 235 | Name: pgx.Identifier{"t"}, 236 | Normalize: func(ctx context.Context, db pgxrecord.DB, table *pgxrecord.Table, record *pgxrecord.Record) error { 237 | record.Set("name", "Bill") 238 | normalizeCallCount++ 239 | return nil 240 | }, 241 | } 242 | err = table.LoadAllColumns(ctx, conn) 243 | require.NoError(t, err) 244 | 245 | // Insert calls normalize 246 | record := 
table.NewRecord() 247 | err = record.SetAttributesStrict(map[string]any{"name": "John", "age": 42}) 248 | require.NoError(t, err) 249 | err = record.Save(ctx, conn) 250 | require.NoError(t, err) 251 | require.EqualValues(t, 1, normalizeCallCount) 252 | require.Equal(t, map[string]any{"id": int32(1), "name": "Bill", "age": int32(42)}, record.Attributes()) 253 | 254 | // Update calls normalize 255 | err = record.SetAttributesStrict(map[string]any{"name": "George", "age": int32(43)}) 256 | require.NoError(t, err) 257 | err = record.Save(ctx, conn) 258 | require.NoError(t, err) 259 | require.Equal(t, map[string]any{"id": int32(1), "name": "Bill", "age": int32(43)}, record.Attributes()) 260 | require.EqualValues(t, 2, normalizeCallCount) 261 | }) 262 | } 263 | 264 | func TestRecordSaveValidate(t *testing.T) { 265 | t.Parallel() 266 | 267 | defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { 268 | _, err := conn.Exec(ctx, `create temporary table t ( 269 | id int primary key generated by default as identity, 270 | name text not null, 271 | age int 272 | )`) 273 | require.NoError(t, err) 274 | 275 | validateCallCount := 0 276 | 277 | table := &pgxrecord.Table{ 278 | Name: pgx.Identifier{"t"}, 279 | Validate: func(ctx context.Context, db pgxrecord.DB, table *pgxrecord.Table, record *pgxrecord.Record) error { 280 | validateCallCount++ 281 | 282 | ve := &pgxrecord.ValidationErrors{} 283 | if name := record.Get("name"); name == "" || name == nil { 284 | ve.Add("name", fmt.Errorf("cannot be blank")) 285 | } 286 | if age := record.Get("age"); age == nil { 287 | ve.Add("age", fmt.Errorf("cannot be blank")) 288 | } 289 | 290 | if ve.Len() > 0 { 291 | return ve 292 | } 293 | 294 | return nil 295 | }, 296 | } 297 | err = table.LoadAllColumns(ctx, conn) 298 | require.NoError(t, err) 299 | 300 | // Insert calls Validate and blocks Save 301 | record := table.NewRecord() 302 | err = record.Save(ctx, conn) 303 | 
require.Error(t, err) 304 | require.EqualValues(t, 1, validateCallCount) 305 | require.EqualValues(t, 2, record.Errors().Len()) 306 | nameErrors := record.Errors().On("name") 307 | require.Len(t, nameErrors, 1) 308 | require.Equal(t, "name: cannot be blank", nameErrors[0].Error()) 309 | ageErrors := record.Errors().On("age") 310 | require.Len(t, ageErrors, 1) 311 | require.Equal(t, "age: cannot be blank", ageErrors[0].Error()) 312 | require.Equal(t, map[string]any{"id": nil, "name": nil, "age": nil}, record.Attributes()) 313 | 314 | // Fix bad attributes and save 315 | err = record.SetAttributesStrict(map[string]any{"name": "John", "age": 42}) 316 | require.NoError(t, err) 317 | err = record.Save(ctx, conn) 318 | require.NoError(t, err) 319 | require.EqualValues(t, 2, validateCallCount) 320 | require.Nil(t, record.Errors()) 321 | 322 | // Update calls Validate 323 | err = record.SetAttributesStrict(map[string]any{"name": nil}) 324 | require.NoError(t, err) 325 | err = record.Save(ctx, conn) 326 | require.Error(t, err) 327 | require.EqualValues(t, 3, validateCallCount) 328 | require.Equal(t, map[string]any{"id": int32(1), "name": nil, "age": int32(42)}, record.Attributes()) 329 | }) 330 | } 331 | 332 | func TestRecordUpdateAttributes(t *testing.T) { 333 | t.Parallel() 334 | 335 | defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { 336 | table := &pgxrecord.Table{ 337 | Name: pgx.Identifier{"t"}, 338 | Columns: []*pgxrecord.Column{ 339 | {Name: "id", OID: pgtype.Int4OID, NotNull: true, PrimaryKey: true}, 340 | {Name: "name", OID: pgtype.TextOID, NotNull: true, PrimaryKey: false}, 341 | {Name: "age", OID: pgtype.Int4OID, NotNull: false, PrimaryKey: false}, 342 | }, 343 | } 344 | 345 | record := table.NewRecord() 346 | require.Equal(t, map[string]any{"id": nil, "name": nil, "age": nil}, record.Attributes()) 347 | 348 | record.SetAttributes(map[string]any{"name": "John", "age": 42, "ignore": "me"}) 349 | 
require.Equal(t, map[string]any{"id": nil, "name": "John", "age": 42}, record.Attributes()) 350 | }) 351 | } 352 | -------------------------------------------------------------------------------- /pgxrecord.go: -------------------------------------------------------------------------------- 1 | // Package pgxrecord is a tiny library for CRUD operations. 2 | package pgxrecord 3 | 4 | import ( 5 | "context" 6 | "errors" 7 | "fmt" 8 | "strconv" 9 | "strings" 10 | 11 | "github.com/jackc/pgx/v5" 12 | ) 13 | 14 | var errTooManyRows = fmt.Errorf("too many rows") 15 | 16 | // DB is the interface pgxrecord uses to access the database. It is satisfied by *pgx.Conn, pgx.Tx, *pgxpool.Pool, etc. 17 | type DB interface { 18 | Query(ctx context.Context, sql string, optionsAndArgs ...interface{}) (pgx.Rows, error) 19 | } 20 | 21 | // Column represents a column in a table. 22 | type Column struct { 23 | Name string 24 | quotedName string 25 | OID uint32 26 | NotNull bool 27 | PrimaryKey bool 28 | } 29 | 30 | // Table represents a table in a database. It must not be mutated after any method other than LoadAllColumns is called. 31 | type Table struct { 32 | Name pgx.Identifier 33 | Columns []*Column 34 | 35 | // Normalize is called before a record is saved. It is useful for normalizing data before it is saved. For example, 36 | // it can be used to trim strings. If Normalize returns an error then the save is aborted. 37 | Normalize func(ctx context.Context, db DB, table *Table, record *Record) error 38 | 39 | // Validate is called before a record is saved. If Validate returns an error then the save is aborted. A 40 | // *ValidationErrors should be returned if validation fails. Any other error indicates an error occurred while 41 | // validating. For example, a database query for a uniqueness check failed because of a broken database connection. 
42 | Validate func(ctx context.Context, db DB, table *Table, record *Record) error 43 | 44 | finalized bool 45 | quotedQualifiedName string 46 | quotedName string 47 | selectQuery string 48 | selectByPKQuery string 49 | pkWhereClause string 50 | returningClause string 51 | pkIndexes []int 52 | nameToColumnIndex map[string]int 53 | validationErrors *ValidationErrors 54 | } 55 | 56 | // Record represents a row from a table in the database. 57 | type Record struct { 58 | table *Table 59 | originalAttributes []any 60 | attributes []any 61 | assigned []bool 62 | } 63 | 64 | // LoadAllColumns queries the database for the table columns. It must not be called after any other method has been 65 | // called. 66 | func (t *Table) LoadAllColumns(ctx context.Context, db DB) error { 67 | if t.finalized { 68 | return fmt.Errorf("cannot call after table finalized") 69 | } 70 | 71 | var tableOID uint32 72 | 73 | { 74 | var rows pgx.Rows 75 | 76 | if len(t.Name) == 1 { 77 | rows, _ = db.Query(ctx, `select c.oid 78 | from pg_catalog.pg_class c 79 | where c.relname=$1 80 | and pg_catalog.pg_table_is_visible(c.oid) 81 | limit 1`, 82 | t.Name[0], 83 | ) 84 | } else if len(t.Name) == 2 { 85 | rows, _ = db.Query(ctx, `select c.oid 86 | from pg_catalog.pg_class c 87 | join pg_catalog.pg_namespace n on n.oid=c.relnamespace 88 | where c.relname=$1 89 | and n.nspname=$2 90 | and pg_catalog.pg_table_is_visible(c.oid) 91 | limit 1`, 92 | t.Name[1], t.Name[0], 93 | ) 94 | } 95 | 96 | var err error 97 | tableOID, err = pgx.CollectOneRow(rows, pgx.RowTo[uint32]) 98 | if err != nil { 99 | return fmt.Errorf("pgxrecord.Table (%s): LoadAllColumns: failed to find table OID: %v", t.Name.Sanitize(), err) 100 | } 101 | } 102 | 103 | rows, _ := db.Query(ctx, `select attname, atttypid, attnotnull, 104 | coalesce(( 105 | select true 106 | from pg_catalog.pg_index 107 | where pg_index.indrelid=pg_attribute.attrelid 108 | and pg_index.indisprimary 109 | and pg_attribute.attnum = any(pg_index.indkey) 110 | ), 
false) as isprimary 111 | from pg_catalog.pg_attribute 112 | where attrelid=$1 113 | and attnum > 0 114 | and not attisdropped 115 | order by attnum`, tableOID) 116 | var err error 117 | t.Columns, err = pgx.CollectRows(rows, pgx.RowToAddrOfStructByPos[Column]) 118 | if err != nil { 119 | return fmt.Errorf("pgxrecord.Table (%s): LoadAllColumns: failed to find columns: %v", t.Name.Sanitize(), err) 120 | } 121 | 122 | return nil 123 | } 124 | 125 | // finalize finishes the table initialization. 126 | func (t *Table) finalize() { 127 | if t.finalized { 128 | panic("BUG: cannot call after table finalized") 129 | } 130 | 131 | t.finalized = true 132 | 133 | t.quotedQualifiedName = t.Name.Sanitize() 134 | t.quotedName = pgx.Identifier{t.Name[len(t.Name)-1]}.Sanitize() 135 | for i, c := range t.Columns { 136 | c.quotedName = pgx.Identifier{c.Name}.Sanitize() 137 | if c.PrimaryKey { 138 | t.pkIndexes = append(t.pkIndexes, i) 139 | } 140 | } 141 | 142 | t.pkWhereClause = t.buildPKWhereClause() 143 | t.selectQuery = t.buildSelectQuery() 144 | t.selectByPKQuery = t.selectQuery + " " + t.pkWhereClause 145 | t.returningClause = t.buildReturningClause() 146 | t.nameToColumnIndex = buildNameToColumnIndex(t.Columns) 147 | } 148 | 149 | func (t *Table) buildSelectQuery() string { 150 | b := &strings.Builder{} 151 | b.WriteString("select ") 152 | for i := range t.Columns { 153 | if i > 0 { 154 | b.WriteString(", ") 155 | } 156 | b.WriteString(t.quotedName) 157 | b.WriteByte('.') 158 | b.WriteString(t.Columns[i].quotedName) 159 | } 160 | b.WriteString(" from ") 161 | b.WriteString(t.quotedQualifiedName) 162 | 163 | return b.String() 164 | } 165 | 166 | func (t *Table) buildPKWhereClause() string { 167 | b := &strings.Builder{} 168 | b.WriteString("where ") 169 | for i := range t.pkIndexes { 170 | if i > 0 { 171 | b.WriteString(" and ") 172 | } 173 | c := t.Columns[t.pkIndexes[i]] 174 | b.WriteString(c.quotedName) 175 | b.WriteString(" = $") 176 | 
b.WriteString(strconv.FormatInt(int64(i+1), 10)) 177 | } 178 | 179 | return b.String() 180 | } 181 | 182 | func (t *Table) buildReturningClause() string { 183 | b := &strings.Builder{} 184 | b.WriteString("returning ") 185 | for i, c := range t.Columns { 186 | if i > 0 { 187 | b.WriteString(", ") 188 | } 189 | b.WriteString(c.quotedName) 190 | } 191 | 192 | return b.String() 193 | } 194 | 195 | func (t *Table) buildSelectByPKQuery() string { 196 | b := &strings.Builder{} 197 | b.WriteString(t.selectQuery) 198 | 199 | for i := range t.Columns { 200 | if i > 0 { 201 | b.WriteString(", ") 202 | } 203 | b.WriteString(t.quotedName) 204 | b.WriteByte('.') 205 | b.WriteString(t.Columns[i].quotedName) 206 | } 207 | b.WriteString(" from ") 208 | b.WriteString(t.quotedQualifiedName) 209 | 210 | return b.String() 211 | } 212 | 213 | func buildNameToColumnIndex(columns []*Column) map[string]int { 214 | m := make(map[string]int, len(columns)) 215 | for i := range columns { 216 | m[columns[i].Name] = i 217 | } 218 | return m 219 | } 220 | 221 | // NewRecord creates an empty Record. 222 | func (t *Table) NewRecord() *Record { 223 | if !t.finalized { 224 | t.finalize() 225 | } 226 | 227 | record := &Record{ 228 | table: t, 229 | attributes: make([]any, len(t.Columns)), 230 | assigned: make([]bool, len(t.Columns)), 231 | } 232 | 233 | return record 234 | } 235 | 236 | // SelectQuery returns the SQL query to select all rows from the table. 237 | func (t *Table) SelectQuery() string { 238 | if !t.finalized { 239 | t.finalize() 240 | } 241 | 242 | return t.selectQuery 243 | } 244 | 245 | // FindByPK finds a record by primary key. 246 | func (t *Table) FindByPK(ctx context.Context, db DB, pk ...any) (*Record, error) { 247 | if !t.finalized { 248 | t.finalize() 249 | } 250 | 251 | rows, _ := db.Query(ctx, t.selectByPKQuery, pk...) 
252 | record, err := pgx.CollectOneRow(rows, t.RowToRecord) 253 | if err != nil { 254 | return nil, fmt.Errorf("pgxrecord.Table (%s): FindByPK (%v): %w", t.quotedQualifiedName, pk, err) 255 | } 256 | 257 | return record, nil 258 | } 259 | 260 | // RowToRecord is a pgx.RowToFunc that returns a *Record. 261 | func (t *Table) RowToRecord(row pgx.CollectableRow) (*Record, error) { 262 | if !t.finalized { 263 | t.finalize() 264 | } 265 | 266 | record := t.NewRecord() 267 | 268 | ptrsToAttributes := make([]any, len(record.attributes)) 269 | for i := range record.attributes { 270 | ptrsToAttributes[i] = &record.attributes[i] 271 | } 272 | 273 | err := row.Scan(ptrsToAttributes...) 274 | if err != nil { 275 | return nil, fmt.Errorf("pgxrecord.Table (%s): RowToRecord: %w", t.quotedQualifiedName, err) 276 | } 277 | 278 | record.originalAttributes = make([]any, len(record.attributes)) 279 | copy(record.originalAttributes, record.attributes) 280 | 281 | return record, nil 282 | } 283 | 284 | // Set sets an attribute to a value. It panics if attribute does not exist. 285 | func (r *Record) Set(attribute string, value any) { 286 | idx, ok := r.table.nameToColumnIndex[attribute] 287 | if !ok { 288 | panic(fmt.Sprintf("pgxrecord.Record (%s): Set: attribute %q is not found", r.table.quotedQualifiedName, attribute)) 289 | } 290 | 291 | r.attributes[idx] = value 292 | r.assigned[idx] = true 293 | } 294 | 295 | // Get returns the value of attribute. It panics if attribute does not exist. 296 | func (r *Record) Get(attribute string) any { 297 | idx, ok := r.table.nameToColumnIndex[attribute] 298 | if !ok { 299 | panic(fmt.Sprintf("pgxrecord.Record (%s): Get: attribute %q is not found", r.table.quotedQualifiedName, attribute)) 300 | } 301 | 302 | return r.attributes[idx] 303 | } 304 | 305 | // SetAttributes sets attributes. Ignores attributes that do not exist. 
306 | func (r *Record) SetAttributes(attributes map[string]any) { 307 | for k, v := range attributes { 308 | idx, ok := r.table.nameToColumnIndex[k] 309 | if ok { 310 | r.attributes[idx] = v 311 | r.assigned[idx] = true 312 | } 313 | } 314 | } 315 | 316 | // SetAttributesStrict sets attributes. Returns an error if any attributes do not exist. 317 | func (r *Record) SetAttributesStrict(attributes map[string]any) error { 318 | for k, v := range attributes { 319 | idx, ok := r.table.nameToColumnIndex[k] 320 | if !ok { 321 | return fmt.Errorf("pgxrecord.Record (%s): Set: attribute %q is not found", r.table.quotedQualifiedName, k) 322 | } 323 | 324 | r.attributes[idx] = v 325 | r.assigned[idx] = true 326 | } 327 | 328 | return nil 329 | } 330 | 331 | // Attributes returns all attributes. 332 | func (r *Record) Attributes() map[string]any { 333 | m := make(map[string]any, len(r.attributes)) 334 | for i := range r.table.Columns { 335 | m[r.table.Columns[i].Name] = r.attributes[i] 336 | } 337 | 338 | return m 339 | } 340 | 341 | // Save saves the record using db. 
342 | func (r *Record) Save(ctx context.Context, db DB) error { 343 | r.table.validationErrors = nil 344 | 345 | if fn := r.table.Normalize; fn != nil { 346 | err := fn(ctx, db, r.table, r) 347 | if err != nil { 348 | return fmt.Errorf("pgxrecord.Record (%s): Save: %w", r.table.quotedQualifiedName, err) 349 | } 350 | } 351 | 352 | if fn := r.table.Validate; fn != nil { 353 | err := fn(ctx, db, r.table, r) 354 | if err != nil { 355 | var ve *ValidationErrors 356 | if errors.As(err, &ve) { 357 | r.table.validationErrors = ve 358 | } 359 | return fmt.Errorf("pgxrecord.Record (%s): Save: %w", r.table.quotedQualifiedName, err) 360 | } 361 | } 362 | 363 | var sql string 364 | var args []any 365 | 366 | if r.originalAttributes == nil { 367 | sql, args = r.insert(ctx, db) 368 | } else { 369 | sql, args = r.update(ctx, db) 370 | } 371 | 372 | ptrsToAttributes := make([]any, len(r.attributes)) 373 | for i := range r.attributes { 374 | ptrsToAttributes[i] = &r.attributes[i] 375 | } 376 | 377 | err := queryRow(ctx, db, sql, args, ptrsToAttributes) 378 | if err != nil { 379 | return fmt.Errorf("pgxrecord.Record (%s): Save: %w", r.table.quotedQualifiedName, err) 380 | } 381 | 382 | r.originalAttributes = make([]any, len(r.attributes)) 383 | copy(r.originalAttributes, r.attributes) 384 | for i := range r.assigned { 385 | r.assigned[i] = false 386 | } 387 | 388 | return nil 389 | } 390 | 391 | func (r *Record) insert(ctx context.Context, db DB) (string, []any) { 392 | b := &strings.Builder{} 393 | b.WriteString("insert into ") 394 | b.WriteString(r.table.quotedQualifiedName) 395 | b.WriteString(" (") 396 | 397 | assignedCount := 0 398 | for i := range r.assigned { 399 | if r.assigned[i] { 400 | if assignedCount > 0 { 401 | b.WriteString(", ") 402 | } 403 | assignedCount++ 404 | b.WriteString(r.table.Columns[i].quotedName) 405 | } 406 | } 407 | 408 | b.WriteString(") values (") 409 | args := make([]any, assignedCount) 410 | assignedCount = 0 411 | for i := range r.assigned { 412 | 
if r.assigned[i] { 413 | if assignedCount > 0 { 414 | b.WriteString(", ") 415 | } 416 | args[assignedCount] = r.attributes[i] 417 | assignedCount++ 418 | b.WriteByte('$') 419 | b.WriteString(strconv.FormatInt(int64(assignedCount), 10)) 420 | } 421 | } 422 | 423 | b.WriteString(") ") 424 | b.WriteString(r.table.returningClause) 425 | 426 | return b.String(), args 427 | } 428 | 429 | func (r *Record) update(ctx context.Context, db DB) (string, []any) { 430 | b := &strings.Builder{} 431 | b.WriteString("update ") 432 | b.WriteString(r.table.quotedQualifiedName) 433 | b.WriteString(" set ") 434 | 435 | args := make([]any, 0, len(r.attributes)) 436 | for _, pkIdx := range r.table.pkIndexes { 437 | args = append(args, r.attributes[pkIdx]) 438 | } 439 | 440 | assignedCount := 0 441 | for i := range r.assigned { 442 | if r.assigned[i] { 443 | if assignedCount > 0 { 444 | b.WriteString(", ") 445 | } 446 | args = append(args, r.attributes[i]) 447 | assignedCount++ 448 | b.WriteString(r.table.Columns[i].quotedName) 449 | b.WriteString(" = $") 450 | b.WriteString(strconv.FormatInt(int64(len(args)), 10)) 451 | } 452 | } 453 | 454 | b.WriteByte(' ') 455 | b.WriteString(r.table.pkWhereClause) 456 | 457 | b.WriteByte(' ') 458 | b.WriteString(r.table.returningClause) 459 | 460 | return b.String(), args 461 | } 462 | 463 | func (r *Record) Errors() *ValidationErrors { 464 | return r.table.validationErrors 465 | } 466 | 467 | // queryRow builds QueryRow-like functionality on top of DB. This allows pgxutil to have the convenience of QueryRow 468 | // without needing it as part of the DB interface. 469 | func queryRow(ctx context.Context, db DB, sql string, args []any, scanTargets []any) error { 470 | rows, err := db.Query(ctx, sql, args...) 471 | if err != nil { 472 | return err 473 | } 474 | defer rows.Close() 475 | 476 | if rows.Next() { 477 | rows.Scan(scanTargets...) 
478 | } else { 479 | return pgx.ErrNoRows 480 | } 481 | 482 | if rows.Next() { 483 | return errTooManyRows 484 | } 485 | 486 | err = rows.Err() 487 | if err != nil { 488 | return err 489 | } 490 | 491 | return nil 492 | } 493 | --------------------------------------------------------------------------------