├── .travis.yml ├── LICENSE ├── README.md ├── conversions.go ├── conversions_test.go ├── doc.go ├── go.mod ├── go.sum ├── int32_array.go ├── int32_array_test.go ├── int64_array.go ├── int64_array_test.go ├── json_text.go ├── json_text_test.go ├── postgis.go ├── postgis_test.go ├── sqltypes_test.go ├── string_array.go └── string_array_test.go /.travis.yml: -------------------------------------------------------------------------------- 1 | language: go 2 | 3 | sudo: required 4 | dist: trusty 5 | 6 | env: 7 | - PGVERSION=9.1 8 | - PGVERSION=9.2 9 | - PGVERSION=9.3 10 | - PGVERSION=9.4 11 | 12 | go: 13 | - 1.5.4 14 | - 1.6.3 15 | - 1.7 16 | - tip 17 | 18 | before_install: 19 | - sudo /etc/init.d/postgresql stop 20 | - sudo /etc/init.d/postgresql start $PGVERSION 21 | 22 | before_script: 23 | - createdb -V 24 | - createdb pq_types 25 | 26 | script: go test -v -check.v 27 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright (c) 2016 mc² software 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # pq-types [![Build Status](https://travis-ci.org/mc2soft/pq-types.svg?branch=master)](https://travis-ci.org/mc2soft/pq-types) [![GoDoc](https://godoc.org/github.com/mc2soft/pq-types?status.svg)](http://godoc.org/github.com/mc2soft/pq-types) 2 | 3 | This Go package provides additional types for PostgreSQL: 4 | 5 | * `Int32Array` for `int[]` (compatible with [`intarray`](http://www.postgresql.org/docs/current/static/intarray.html) module); 6 | * `Int64Array` for `bigint[]`; 7 | * `StringArray` for `varchar[]`; 8 | * `JSONText` for `varchar`, `text`, `json` and `jsonb`; 9 | * `PostGISPoint`, `PostGISBox2D` and `PostGISPolygon`. 
10 | 11 | Install it: `go get github.com/mc2soft/pq-types` 12 | -------------------------------------------------------------------------------- /conversions.go: -------------------------------------------------------------------------------- 1 | package pq_types 2 | 3 | import ( 4 | "database/sql" 5 | "time" 6 | ) 7 | 8 | // NullString covers trivial case of string to sql.NullString conversion assuming empty string to be NULL 9 | func NullString(src string) sql.NullString { 10 | return sql.NullString{ 11 | String: src, 12 | Valid: src != "", 13 | } 14 | } 15 | 16 | // NullInt32 covers trivial case of int32 to sql.NullInt32 conversion assuming 0 to be NULL 17 | func NullInt32(src int32) sql.NullInt32 { 18 | return sql.NullInt32{ 19 | Int32: src, 20 | Valid: src != 0, 21 | } 22 | } 23 | 24 | // NullInt64 covers trivial case of int64 to sql.NullInt64 conversion assuming 0 to be NULL 25 | func NullInt64(src int64) sql.NullInt64 { 26 | return sql.NullInt64{ 27 | Int64: src, 28 | Valid: src != 0, 29 | } 30 | } 31 | 32 | // NullTimestampP converts *time.Time to a sql.NullTime 33 | func NullTimestampP(src *time.Time) sql.NullTime { 34 | if src == nil { 35 | return sql.NullTime{Valid: false} 36 | } 37 | return sql.NullTime{ 38 | Time: *src, 39 | Valid: true, 40 | } 41 | } 42 | -------------------------------------------------------------------------------- /conversions_test.go: -------------------------------------------------------------------------------- 1 | package pq_types 2 | 3 | import ( 4 | "database/sql" 5 | "fmt" 6 | "time" 7 | 8 | "gopkg.in/check.v1" 9 | ) 10 | 11 | func insertQuery(col string) string { 12 | return fmt.Sprintf("INSERT INTO pq_types (%s) VALUES($1)", col) 13 | } 14 | 15 | func selectQuery(col string) string { 16 | return fmt.Sprintf("SELECT %s FROM pq_types LIMIT 1;", col) 17 | } 18 | 19 | // For each type there need to be a pair of tests: with empty -> nil value and with non-empty value -> value. 
20 | 
21 | func (s *TypesSuite) TestConversionNullString(c *check.C) {
22 | 	cases := []sql.NullString{
23 | 		{String: ""},
24 | 		{String: "truly random string", Valid: true},
25 | 	}
26 | 	for _, expected := range cases {
27 | 		// per-case reset, as in the other round-trip tests, so LIMIT 1 reads this case's row
28 | 		s.SetUpTest(c)
29 | 
30 | 		val := NullString(expected.String)
31 | 		_, err := s.db.Exec(insertQuery("null_str"), val)
32 | 		c.Assert(err, check.IsNil)
33 | 
34 | 		// Scan needs a pointer destination; passing the struct by value always fails.
35 | 		var actual sql.NullString
36 | 		err = s.db.QueryRow(selectQuery("null_str")).Scan(&actual)
37 | 		c.Check(err, check.IsNil)
38 | 		c.Check(actual, check.DeepEquals, expected)
39 | 	}
40 | }
41 | 
42 | func (s *TypesSuite) TestConversionNullInt32(c *check.C) {
43 | 	cases := []sql.NullInt32{
44 | 		{Int32: 0},
45 | 		{Int32: 0xabc, Valid: true},
46 | 	}
47 | 	for _, expected := range cases {
48 | 		s.SetUpTest(c)
49 | 
50 | 		val := NullInt32(expected.Int32)
51 | 		_, err := s.db.Exec(insertQuery("null_int32"), val)
52 | 		c.Assert(err, check.IsNil)
53 | 
54 | 		var actual sql.NullInt32
55 | 		err = s.db.QueryRow(selectQuery("null_int32")).Scan(&actual)
56 | 		c.Check(err, check.IsNil)
57 | 		c.Check(actual, check.DeepEquals, expected)
58 | 	}
59 | }
60 | 
61 | func (s *TypesSuite) TestConversionNullInt64(c *check.C) {
62 | 	cases := []sql.NullInt64{
63 | 		{Int64: 0},
64 | 		{Int64: 0xabcdef, Valid: true},
65 | 	}
66 | 	for _, expected := range cases {
67 | 		s.SetUpTest(c)
68 | 
69 | 		val := NullInt64(expected.Int64)
70 | 		_, err := s.db.Exec(insertQuery("null_int64"), val)
71 | 		c.Assert(err, check.IsNil)
72 | 
73 | 		var actual sql.NullInt64
74 | 		err = s.db.QueryRow(selectQuery("null_int64")).Scan(&actual)
75 | 		c.Check(err, check.IsNil)
76 | 		c.Check(actual, check.DeepEquals, expected)
77 | 	}
78 | }
79 | 
80 | func (s *TypesSuite) TestConversionNullTimestamp(c *check.C) {
81 | 	// here we use another approach, as input is a pointer.
82 | 	// PostgreSQL stores timestamps with microsecond precision, so truncate to make
83 | 	// the round trip exact; Truncate also strips the monotonic clock reading.
84 | 	now := time.Now().UTC().Truncate(time.Microsecond)
85 | 	cases := []*time.Time{nil, &now}
86 | 
87 | 	for _, expected := range cases {
88 | 		s.SetUpTest(c)
89 | 
90 | 		val := NullTimestampP(expected)
91 | 
92 | 		_, err := s.db.Exec(insertQuery("null_timestamp"), val)
93 | 		c.Assert(err, check.IsNil)
94 | 
95 | 		// Compare by value: a scanned time never shares a pointer with expected,
96 | 		// and time.Time instants must be compared with Equal rather than ==.
97 | 		var actual sql.NullTime
98 | 		err = s.db.QueryRow(selectQuery("null_timestamp")).Scan(&actual)
99 | 		c.Check(err, check.IsNil)
100 | 		if expected == nil {
101 | 			c.Check(actual.Valid, check.Equals, false)
102 | 			continue
103 | 		}
104 | 		c.Check(actual.Valid, check.Equals, true)
105 | 		c.Check(actual.Time.Equal(*expected), check.Equals, true,
106 | 			check.Commentf("actual = %v, expected = %v", actual.Time, *expected))
107 | 	}
108 | }
109 | 
--------------------------------------------------------------------------------
/doc.go:
--------------------------------------------------------------------------------
1 | // Package pq_types provides additional types for PostgreSQL: int, bigint and string arrays
2 | // (the int array is compatible with the intarray module), json and jsonb values, and a few PostGIS types.
3 | package pq_types
4 | 
--------------------------------------------------------------------------------
/go.mod:
--------------------------------------------------------------------------------
1 | module github.com/mc2soft/pq-types
2 | 
3 | go 1.17
4 | 
5 | require (
6 | 	github.com/lib/pq v1.10.4
7 | 	gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c
8 | )
9 | 
10 | require (
11 | 	github.com/kr/pretty v0.2.1 // indirect
12 | 	github.com/kr/text v0.1.0 // indirect
13 | )
14 | 
--------------------------------------------------------------------------------
/go.sum:
--------------------------------------------------------------------------------
1 | github.com/kr/pretty v0.2.1 h1:Fmg33tUaq4/8ym9TJN1x7sLJnHVwhP33CNkpYV/7rwI=
2 | github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI=
3 | github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
4 | github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE=
5 | github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
6 | github.com/lib/pq v1.10.4 h1:SO9z7FRPzA03QhHKJrH5BXA6HU1rS4V2nIVrrNC1iYk=
7 | github.com/lib/pq v1.10.4/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
8 | gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
9 | gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
-------------------------------------------------------------------------------- /int32_array.go: -------------------------------------------------------------------------------- 1 | package pq_types 2 | 3 | import ( 4 | "database/sql" 5 | "database/sql/driver" 6 | "fmt" 7 | "sort" 8 | "strconv" 9 | "strings" 10 | ) 11 | 12 | // Int32Array is a slice of int32 values, compatible with PostgreSQL's int[] and intarray module. 13 | type Int32Array []int32 14 | 15 | func (a Int32Array) Len() int { return len(a) } 16 | func (a Int32Array) Less(i, j int) bool { return a[i] < a[j] } 17 | func (a Int32Array) Swap(i, j int) { a[i], a[j] = a[j], a[i] } 18 | 19 | // Value implements database/sql/driver Valuer interface. 20 | func (a Int32Array) Value() (driver.Value, error) { 21 | if a == nil { 22 | return nil, nil 23 | } 24 | 25 | s := make([]string, len(a)) 26 | for i, v := range a { 27 | s[i] = strconv.Itoa(int(v)) 28 | } 29 | return []byte("{" + strings.Join(s, ",") + "}"), nil 30 | } 31 | 32 | // Scan implements database/sql Scanner interface. 
33 | func (a *Int32Array) Scan(value interface{}) error { 34 | if value == nil { 35 | *a = nil 36 | return nil 37 | } 38 | 39 | var b []byte 40 | switch v := value.(type) { 41 | case []byte: 42 | b = v 43 | case string: 44 | b = []byte(v) 45 | default: 46 | return fmt.Errorf("Int32Array.Scan: expected []byte or string, got %T (%q)", value, value) 47 | } 48 | 49 | if len(b) < 2 || b[0] != '{' || b[len(b)-1] != '}' { 50 | return fmt.Errorf("Int32Array.Scan: unexpected data %q", b) 51 | } 52 | 53 | p := strings.Split(string(b[1:len(b)-1]), ",") 54 | 55 | // reuse underlying array if present 56 | if *a == nil { 57 | *a = make(Int32Array, 0, len(p)) 58 | } 59 | *a = (*a)[:0] 60 | 61 | for _, s := range p { 62 | if s == "" { 63 | continue 64 | } 65 | i, err := strconv.Atoi(s) 66 | if err != nil { 67 | return err 68 | } 69 | *a = append(*a, int32(i)) 70 | } 71 | 72 | return nil 73 | } 74 | 75 | // EqualWithoutOrder returns true if two int32 arrays are equal without order, false otherwise. 76 | // It may sort both arrays in-place to do so. 77 | func (a Int32Array) EqualWithoutOrder(b Int32Array) bool { 78 | if len(a) != len(b) { 79 | return false 80 | } 81 | 82 | sort.Sort(a) 83 | sort.Sort(b) 84 | 85 | for i := range a { 86 | if a[i] != b[i] { 87 | return false 88 | } 89 | } 90 | 91 | return true 92 | } 93 | 94 | // check interfaces 95 | var ( 96 | _ sort.Interface = Int32Array{} 97 | _ driver.Valuer = Int32Array{} 98 | _ sql.Scanner = &Int32Array{} 99 | ) 100 | -------------------------------------------------------------------------------- /int32_array_test.go: -------------------------------------------------------------------------------- 1 | package pq_types 2 | 3 | import ( 4 | "database/sql" 5 | "fmt" 6 | 7 | . 
"gopkg.in/check.v1" 8 | ) 9 | 10 | func (s *TypesSuite) TestInt32Array(c *C) { 11 | type testData struct { 12 | a Int32Array 13 | b []byte 14 | } 15 | for _, d := range []testData{ 16 | {Int32Array(nil), []byte(nil)}, 17 | {Int32Array{}, []byte(`{}`)}, 18 | {Int32Array{1}, []byte(`{1}`)}, 19 | {Int32Array{1, 0, -3}, []byte(`{1,0,-3}`)}, 20 | {Int32Array{-3, 0, 1}, []byte(`{-3,0,1}`)}, 21 | } { 22 | s.SetUpTest(c) 23 | 24 | _, err := s.db.Exec("INSERT INTO pq_types (int32_array) VALUES($1)", d.a) 25 | c.Assert(err, IsNil) 26 | 27 | b1 := []byte("42") 28 | a1 := Int32Array{42} 29 | err = s.db.QueryRow("SELECT int32_array, int32_array FROM pq_types").Scan(&b1, &a1) 30 | c.Check(err, IsNil) 31 | c.Check(b1, DeepEquals, d.b, Commentf("\nb1 = %#q\nd.b = %#q", b1, d.b)) 32 | c.Check(a1, DeepEquals, d.a) 33 | 34 | // check db array length 35 | var length sql.NullInt64 36 | err = s.db.QueryRow("SELECT array_length(int32_array, 1) FROM pq_types").Scan(&length) 37 | c.Check(err, IsNil) 38 | c.Check(length.Valid, Equals, len(d.a) > 0) 39 | c.Check(length.Int64, Equals, int64(len(d.a))) 40 | 41 | // check db array elements 42 | for i := 0; i < len(d.a); i++ { 43 | q := fmt.Sprintf("SELECT int32_array[%d] FROM pq_types", i+1) 44 | var el sql.NullInt64 45 | err = s.db.QueryRow(q).Scan(&el) 46 | c.Check(err, IsNil) 47 | c.Check(el.Valid, Equals, true) 48 | c.Check(el.Int64, Equals, int64(d.a[i])) 49 | } 50 | } 51 | } 52 | 53 | func (s *TypesSuite) TestInt32ArrayEqualWithoutOrder(c *C) { 54 | c.Check(Int32Array{1, 0, -3}.EqualWithoutOrder(Int32Array{-3, 0, 1}), Equals, true) 55 | c.Check(Int32Array{1, 0, -3}.EqualWithoutOrder(Int32Array{1}), Equals, false) 56 | c.Check(Int32Array{1, 0, -3}.EqualWithoutOrder(Int32Array{1, 0, 42}), Equals, false) 57 | c.Check(Int32Array{}.EqualWithoutOrder(Int32Array{}), Equals, true) 58 | c.Check(Int32Array{}.EqualWithoutOrder(Int32Array{1}), Equals, false) 59 | } 60 | -------------------------------------------------------------------------------- 
/int64_array.go:
--------------------------------------------------------------------------------
1 | package pq_types
2 | 
3 | import (
4 | 	"database/sql"
5 | 	"database/sql/driver"
6 | 	"fmt"
7 | 	"sort"
8 | 	"strconv"
9 | 	"strings"
10 | )
11 | 
12 | // Int64Array is a slice of int64 values, compatible with PostgreSQL's bigint[].
13 | type Int64Array []int64
14 | 
15 | func (a Int64Array) Len() int           { return len(a) }
16 | func (a Int64Array) Less(i, j int) bool { return a[i] < a[j] }
17 | func (a Int64Array) Swap(i, j int)      { a[i], a[j] = a[j], a[i] }
18 | 
19 | // Value implements database/sql/driver Valuer interface.
20 | // A nil array is stored as NULL; otherwise it is encoded as an array literal like {1,2,3}.
21 | func (a Int64Array) Value() (driver.Value, error) {
22 | 	if a == nil {
23 | 		return nil, nil
24 | 	}
25 | 
26 | 	s := make([]string, len(a))
27 | 	for i, v := range a {
28 | 		// FormatInt keeps all 64 bits; strconv.Itoa(int(v)) truncates where int is 32 bits.
29 | 		s[i] = strconv.FormatInt(v, 10)
30 | 	}
31 | 	return []byte("{" + strings.Join(s, ",") + "}"), nil
32 | }
33 | 
34 | // Scan implements database/sql Scanner interface.
35 | // It accepts a []byte or string holding an array literal like {1,2,3}; NULL sets *a to nil.
36 | func (a *Int64Array) Scan(value interface{}) error {
37 | 	if value == nil {
38 | 		*a = nil
39 | 		return nil
40 | 	}
41 | 
42 | 	var b []byte
43 | 	switch v := value.(type) {
44 | 	case []byte:
45 | 		b = v
46 | 	case string:
47 | 		b = []byte(v)
48 | 	default:
49 | 		return fmt.Errorf("Int64Array.Scan: expected []byte or string, got %T (%q)", value, value)
50 | 	}
51 | 
52 | 	if len(b) < 2 || b[0] != '{' || b[len(b)-1] != '}' {
53 | 		return fmt.Errorf("Int64Array.Scan: unexpected data %q", b)
54 | 	}
55 | 
56 | 	p := strings.Split(string(b[1:len(b)-1]), ",")
57 | 
58 | 	// reuse underlying array if present
59 | 	if *a == nil {
60 | 		*a = make(Int64Array, 0, len(p))
61 | 	}
62 | 	*a = (*a)[:0]
63 | 
64 | 	for _, s := range p {
65 | 		if s == "" {
66 | 			continue
67 | 		}
68 | 		// ParseInt with bitSize 64 covers the full int64 range regardless of
69 | 		// the platform's int size, unlike strconv.Atoi.
70 | 		i, err := strconv.ParseInt(s, 10, 64)
71 | 		if err != nil {
72 | 			return fmt.Errorf("Int64Array.Scan: %w", err)
73 | 		}
74 | 		*a = append(*a, i)
75 | 	}
76 | 
77 | 	return nil
78 | }
79 | 
80 | // EqualWithoutOrder returns true if two int64 arrays are equal without order, false otherwise.
81 | // When the lengths match it sorts both arrays in place, so the caller's slices are reordered.
82 | func (a Int64Array) EqualWithoutOrder(b Int64Array) bool {
83 | 	if len(a) != len(b) {
84 | 		return false
85 | 	}
86 | 
87 | 	sort.Sort(a)
88 | 	sort.Sort(b)
89 | 
90 | 	for i := range a {
91 | 		if a[i] != b[i] {
92 | 			return false
93 | 		}
94 | 	}
95 | 
96 | 	return true
97 | }
98 | 
99 | // check interfaces
100 | var (
101 | 	_ sort.Interface = Int64Array{}
102 | 	_ driver.Valuer  = Int64Array{}
103 | 	_ sql.Scanner    = &Int64Array{}
104 | )
105 | 
--------------------------------------------------------------------------------
/int64_array_test.go:
--------------------------------------------------------------------------------
1 | package pq_types
2 | 
3 | import (
4 | 	"database/sql"
5 | 	"fmt"
6 | 
7 | 	. "gopkg.in/check.v1"
8 | )
9 | 
10 | func (s *TypesSuite) TestInt64Array(c *C) {
11 | 	type testData struct {
12 | 		a Int64Array
13 | 		b []byte
14 | 	}
15 | 	for _, d := range []testData{
16 | 		{Int64Array(nil), []byte(nil)},
17 | 		{Int64Array{}, []byte(`{}`)},
18 | 		{Int64Array{1}, []byte(`{1}`)},
19 | 		{Int64Array{1, 0, -3}, []byte(`{1,0,-3}`)},
20 | 		{Int64Array{-3, 0, 1}, []byte(`{-3,0,1}`)},
21 | 		// int64 boundaries must round-trip on every platform, including 32-bit ones
22 | 		{Int64Array{9223372036854775807, -9223372036854775808}, []byte(`{9223372036854775807,-9223372036854775808}`)},
23 | 	} {
24 | 		s.SetUpTest(c)
25 | 
26 | 		_, err := s.db.Exec("INSERT INTO pq_types (int64_array) VALUES($1)", d.a)
27 | 		c.Assert(err, IsNil)
28 | 
29 | 		b1 := []byte("42")
30 | 		a1 := Int64Array{42}
31 | 		err = s.db.QueryRow("SELECT int64_array, int64_array FROM pq_types").Scan(&b1, &a1)
32 | 		c.Check(err, IsNil)
33 | 		c.Check(b1, DeepEquals, d.b, Commentf("\nb1 = %#q\nd.b = %#q", b1, d.b))
34 | 		c.Check(a1, DeepEquals, d.a)
35 | 
36 | 		// check db array length
37 | 		var length sql.NullInt64
38 | 		err = s.db.QueryRow("SELECT array_length(int64_array, 1) FROM pq_types").Scan(&length)
39 | 		c.Check(err, IsNil)
40 | 		c.Check(length.Valid, Equals, len(d.a) > 0)
41 | 		c.Check(length.Int64, Equals, int64(len(d.a)))
42 | 
43 | 		// check db array elements
44 | 		for i := 0; i < len(d.a); i++ {
45 | 			q := fmt.Sprintf("SELECT int64_array[%d] FROM pq_types", i+1)
46 | 			var el sql.NullInt64
47 | 			err = s.db.QueryRow(q).Scan(&el)
48 | 			c.Check(err, IsNil)
49 | 			c.Check(el.Valid, Equals, true)
50 | 			c.Check(el.Int64, Equals, int64(d.a[i]))
51 | 		}
52 | 	}
53 | }
54 | 
55 | func (s *TypesSuite) TestInt64ArrayEqualWithoutOrder(c *C) {
56 | 	c.Check(Int64Array{1, 0, -3}.EqualWithoutOrder(Int64Array{-3, 0, 1}), Equals, true)
57 | 	c.Check(Int64Array{1, 0, -3}.EqualWithoutOrder(Int64Array{1}), Equals, false)
58 | 	c.Check(Int64Array{1, 0, -3}.EqualWithoutOrder(Int64Array{1, 0, 42}), Equals, false)
59 | 	c.Check(Int64Array{}.EqualWithoutOrder(Int64Array{}), Equals, true)
60 | 	c.Check(Int64Array{}.EqualWithoutOrder(Int64Array{1}), Equals, false)
61 | }
62 | 
--------------------------------------------------------------------------------
/json_text.go:
--------------------------------------------------------------------------------
1 | package pq_types
2 | 
3 | import (
4 | 	"database/sql"
5 | 	"database/sql/driver"
6 | 	"encoding/json"
7 | 	"errors"
8 | 	"fmt"
9 | )
10 | 
11 | // JSONText is a raw encoded JSON value, compatible with PostgreSQL's varchar, text, json and jsonb.
12 | // It behaves like json.RawMessage by implementing json.Marshaler and json.Unmarshaler
13 | // and can be used to delay JSON decoding or precompute a JSON encoding.
14 | type JSONText []byte
15 | 
16 | // String implements fmt.Stringer for better output and logging.
17 | func (j JSONText) String() string {
18 | 	return string(j)
19 | }
20 | 
21 | // MarshalJSON returns j as the JSON encoding of j; a nil j is encoded as null.
22 | func (j JSONText) MarshalJSON() ([]byte, error) {
23 | 	if j == nil {
24 | 		return []byte(`null`), nil
25 | 	}
26 | 	return j, nil
27 | }
28 | 
29 | // UnmarshalJSON sets *j to a copy of data, reusing the existing buffer of *j when it is large enough.
30 | func (j *JSONText) UnmarshalJSON(data []byte) error {
31 | 	if j == nil {
32 | 		return errors.New("JSONText.UnmarshalJSON: on nil pointer")
33 | 	}
34 | 	*j = append((*j)[0:0], data...)
35 | 	return nil
36 | }
37 | 
38 | // Value implements database/sql/driver Valuer interface.
39 | // It performs basic validation by unmarshaling itself into json.RawMessage.
40 | // If j is not valid JSON, it returns an error. A nil j is stored as NULL.
41 | func (j JSONText) Value() (driver.Value, error) {
42 | 	if j == nil {
43 | 		return nil, nil
44 | 	}
45 | 
46 | 	var m json.RawMessage
47 | 	if err := json.Unmarshal(j, &m); err != nil {
48 | 		// return no value alongside a non-nil error
49 | 		return nil, err
50 | 	}
51 | 	return []byte(j), nil
52 | }
53 | 
54 | // Scan implements database/sql Scanner interface.
55 | // It stores a copy of value in *j (the driver may reuse its buffer); NULL sets *j to nil.
56 | // No validation is done.
57 | func (j *JSONText) Scan(value interface{}) error {
58 | 	if value == nil {
59 | 		*j = nil
60 | 		return nil
61 | 	}
62 | 
63 | 	var b []byte
64 | 	switch v := value.(type) {
65 | 	case []byte:
66 | 		b = v
67 | 	case string:
68 | 		b = []byte(v)
69 | 	default:
70 | 		return fmt.Errorf("JSONText.Scan: expected []byte or string, got %T (%q)", value, value)
71 | 	}
72 | 
73 | 	*j = JSONText(append((*j)[0:0], b...))
74 | 	return nil
75 | }
76 | 
77 | // check interfaces
78 | var (
79 | 	_ json.Marshaler   = JSONText{}
80 | 	_ json.Unmarshaler = &JSONText{}
81 | 	_ driver.Valuer    = JSONText{}
82 | 	_ sql.Scanner      = &JSONText{}
83 | 	_ fmt.Stringer     = JSONText{}
84 | 	_ fmt.Stringer     = &JSONText{}
85 | )
86 | 
--------------------------------------------------------------------------------
/json_text_test.go:
--------------------------------------------------------------------------------
1 | package pq_types
2 | 
3 | import (
4 | 	"bytes"
5 | 	"encoding/json"
6 | 	"errors"
7 | 	"fmt"
8 | 	"strings"
9 | 
10 | 	.
"gopkg.in/check.v1" 11 | ) 12 | 13 | func (s *TypesSuite) TestJSONText(c *C) { 14 | type testData struct { 15 | j JSONText 16 | b []byte 17 | } 18 | 19 | for _, d := range []testData{ 20 | {JSONText(nil), []byte(nil)}, 21 | {JSONText(`null`), []byte(`null`)}, 22 | {JSONText(`{}`), []byte(`{}`)}, 23 | {JSONText(`[]`), []byte(`[]`)}, 24 | {JSONText(`[{"b": true, "n": 123}, {"s": "foo", "obj": {"f1": 456, "f2": false}}, [null]]`), 25 | []byte(`[{"b": true, "n": 123}, {"s": "foo", "obj": {"f1": 456, "f2": false}}, [null]]`)}, 26 | } { 27 | b1, err := json.Marshal(d.j) 28 | c.Check(err, IsNil) 29 | b := bytes.Replace(d.b, []byte(` `), nil, -1) 30 | if d.j == nil { 31 | // special case 32 | c.Check(b1, DeepEquals, []byte(`null`)) 33 | } else { 34 | c.Check(b1, DeepEquals, b, Commentf("\nb1 = %#q\nb = %#q", b1, b)) 35 | } 36 | 37 | for _, col := range []string{"jsontext_varchar", "jsontext_json", "jsontext_jsonb"} { 38 | if strings.HasSuffix(col, "json") && s.skipJSON { 39 | continue 40 | } 41 | if strings.HasSuffix(col, "jsonb") && s.skipJSONB { 42 | continue 43 | } 44 | 45 | s.SetUpTest(c) 46 | 47 | _, err = s.db.Exec(fmt.Sprintf("INSERT INTO pq_types (%s) VALUES($1)", col), d.j) 48 | c.Assert(err, IsNil) 49 | 50 | b1 := []byte(`{"foo": "bar"}`) 51 | j1 := JSONText(`{"foo": "bar"}`) 52 | err = s.db.QueryRow(fmt.Sprintf("SELECT %s, %s FROM pq_types", col, col)).Scan(&b1, &j1) 53 | c.Check(err, IsNil) 54 | c.Check(b1, DeepEquals, d.b, Commentf("\nb1 = %#q\nd.b = %#q", b1, d.b)) 55 | c.Check(j1, DeepEquals, d.j) 56 | } 57 | } 58 | 59 | for _, j := range []JSONText{ 60 | JSONText{}, 61 | } { 62 | for _, col := range []string{"jsontext_varchar", "jsontext_json", "jsontext_jsonb"} { 63 | if strings.HasSuffix(col, "json") && s.skipJSON { 64 | continue 65 | } 66 | if strings.HasSuffix(col, "jsonb") && s.skipJSONB { 67 | continue 68 | } 69 | 70 | s.SetUpTest(c) 71 | 72 | _, err := s.db.Exec(fmt.Sprintf("INSERT INTO pq_types (%s) VALUES($1)", col), j) 73 | c.Check(err, 
DeepEquals, errors.New(`sql: converting Exec argument #0's type: unexpected end of JSON input`)) 74 | } 75 | } 76 | } 77 | -------------------------------------------------------------------------------- /postgis.go: -------------------------------------------------------------------------------- 1 | package pq_types 2 | 3 | import ( 4 | "bytes" 5 | "database/sql" 6 | "database/sql/driver" 7 | "encoding/binary" 8 | "encoding/hex" 9 | "fmt" 10 | "strings" 11 | ) 12 | 13 | // PostGISPoint is wrapper for PostGIS POINT type. 14 | type PostGISPoint struct { 15 | Lon, Lat float64 16 | } 17 | 18 | // Value implements database/sql/driver Valuer interface. 19 | // It returns point as WKT with SRID 4326 (WGS 84). 20 | func (p PostGISPoint) Value() (driver.Value, error) { 21 | return []byte(fmt.Sprintf("SRID=4326;POINT(%.8f %.8f)", p.Lon, p.Lat)), nil 22 | } 23 | 24 | type ewkbPoint struct { 25 | ByteOrder byte // 1 (LittleEndian) 26 | WkbType uint32 // 0x20000001 (PointS) 27 | SRID uint32 // 4326 28 | Point PostGISPoint 29 | } 30 | 31 | // Scan implements database/sql Scanner interface. 32 | // It expectes EWKB with SRID 4326 (WGS 84). 
33 | func (p *PostGISPoint) Scan(value interface{}) error { 34 | if value == nil { 35 | *p = PostGISPoint{} 36 | return nil 37 | } 38 | 39 | v, ok := value.([]byte) 40 | if !ok { 41 | return fmt.Errorf("PostGISPoint.Scan: expected []byte, got %T (%v)", value, value) 42 | } 43 | 44 | ewkb := make([]byte, hex.DecodedLen(len(v))) 45 | n, err := hex.Decode(ewkb, v) 46 | if err != nil { 47 | return err 48 | } 49 | 50 | var ewkbP ewkbPoint 51 | err = binary.Read(bytes.NewReader(ewkb[:n]), binary.LittleEndian, &ewkbP) 52 | if err != nil { 53 | return err 54 | } 55 | 56 | if ewkbP.ByteOrder != 1 || ewkbP.WkbType != 0x20000001 || ewkbP.SRID != 4326 { 57 | return fmt.Errorf("PostGISPoint.Scan: unexpected ewkb %#v", ewkbP) 58 | } 59 | *p = ewkbP.Point 60 | return nil 61 | } 62 | 63 | // check interfaces 64 | var ( 65 | _ driver.Valuer = PostGISPoint{} 66 | _ sql.Scanner = &PostGISPoint{} 67 | ) 68 | 69 | // PostGISBox2D is wrapper for PostGIS Box2D type. 70 | type PostGISBox2D struct { 71 | Min, Max PostGISPoint 72 | } 73 | 74 | // Value implements database/sql/driver Valuer interface. 75 | // It returns box as WKT. 76 | func (b PostGISBox2D) Value() (driver.Value, error) { 77 | return []byte(fmt.Sprintf("BOX(%.8f %.8f,%.8f %.8f)", b.Min.Lon, b.Min.Lat, b.Max.Lon, b.Max.Lat)), nil 78 | } 79 | 80 | // Scan implements database/sql Scanner interface. 81 | // It expectes WKT. 
82 | func (b *PostGISBox2D) Scan(value interface{}) error { 83 | if value == nil { 84 | *b = PostGISBox2D{} 85 | return nil 86 | } 87 | 88 | v, ok := value.([]byte) 89 | if !ok { 90 | return fmt.Errorf("PostGISBox2D.Scan: expected []byte, got %T (%v)", value, value) 91 | } 92 | 93 | n, err := fmt.Sscanf(string(v), "BOX(%f %f,%f %f)", &b.Min.Lon, &b.Min.Lat, &b.Max.Lon, &b.Max.Lat) 94 | if err != nil { 95 | return err 96 | } 97 | if n != 4 { 98 | return fmt.Errorf("PostGISBox2D.Scan: not enough params in the string: %v, %v != 4", v, n) 99 | } 100 | 101 | return nil 102 | } 103 | 104 | // check interfaces 105 | var ( 106 | _ driver.Valuer = PostGISBox2D{} 107 | _ sql.Scanner = &PostGISBox2D{} 108 | ) 109 | 110 | // PostGISPolygon is wrapper for PostGIS Polygon type. 111 | type PostGISPolygon struct { 112 | Points []PostGISPoint 113 | } 114 | 115 | // MakeEnvelope returns rectangular (min, max) polygon 116 | func MakeEnvelope(min, max PostGISPoint) PostGISPolygon { 117 | return PostGISPolygon{ 118 | Points: []PostGISPoint{min, {Lon: min.Lon, Lat: max.Lat}, max, {Lon: max.Lon, Lat: min.Lat}, min}, 119 | } 120 | } 121 | 122 | // Min returns min side of rectangular polygon 123 | func (p *PostGISPolygon) Min() PostGISPoint { 124 | if len(p.Points) != 5 || p.Points[0] != p.Points[4] || 125 | p.Points[0].Lon != p.Points[1].Lon || p.Points[0].Lat != p.Points[3].Lat || 126 | p.Points[1].Lat != p.Points[2].Lat || p.Points[2].Lon != p.Points[3].Lon { 127 | panic("Not an envelope polygon") 128 | } 129 | 130 | return p.Points[0] 131 | } 132 | 133 | // Max returns max side of rectangular polygon 134 | func (p *PostGISPolygon) Max() PostGISPoint { 135 | if len(p.Points) != 5 || p.Points[0] != p.Points[4] || 136 | p.Points[0].Lon != p.Points[1].Lon || p.Points[0].Lat != p.Points[3].Lat || 137 | p.Points[1].Lat != p.Points[2].Lat || p.Points[2].Lon != p.Points[3].Lon { 138 | panic("Not an envelope polygon") 139 | } 140 | 141 | return p.Points[2] 142 | } 143 | 144 | // Value 
implements database/sql/driver Valuer interface. 145 | // It returns polygon as WKT with SRID 4326 (WGS 84). 146 | func (p PostGISPolygon) Value() (driver.Value, error) { 147 | parts := make([]string, len(p.Points)) 148 | for i, pt := range p.Points { 149 | parts[i] = fmt.Sprintf("%.8f %.8f", pt.Lon, pt.Lat) 150 | } 151 | return []byte(fmt.Sprintf("SRID=4326;POLYGON((%s))", strings.Join(parts, ","))), nil 152 | } 153 | 154 | type ewkbPolygon struct { 155 | ByteOrder byte // 1 (LittleEndian) 156 | WkbType uint32 // 0x20000003 (PolygonS) 157 | SRID uint32 // 4326 158 | Rings uint32 159 | Count uint32 160 | } 161 | 162 | // Scan implements database/sql Scanner interface. 163 | // It expectes EWKB with SRID 4326 (WGS 84). 164 | func (p *PostGISPolygon) Scan(value interface{}) error { 165 | if value == nil { 166 | *p = PostGISPolygon{} 167 | return nil 168 | } 169 | 170 | v, ok := value.([]byte) 171 | if !ok { 172 | return fmt.Errorf("PostGISPolygon.Scan: expected []byte, got %T (%v)", value, value) 173 | } 174 | 175 | ewkb := make([]byte, hex.DecodedLen(len(v))) 176 | _, err := hex.Decode(ewkb, v) 177 | if err != nil { 178 | return err 179 | } 180 | 181 | r := bytes.NewReader(ewkb) 182 | 183 | var ewkbP ewkbPolygon 184 | err = binary.Read(r, binary.LittleEndian, &ewkbP) 185 | if err != nil { 186 | return err 187 | } 188 | 189 | if ewkbP.ByteOrder != 1 || ewkbP.WkbType != 0x20000003 || ewkbP.SRID != 4326 || ewkbP.Rings != 1 { 190 | return fmt.Errorf("PostGISPolygon.Scan: unexpected ewkb %#v", ewkbP) 191 | } 192 | p.Points = make([]PostGISPoint, ewkbP.Count) 193 | 194 | err = binary.Read(r, binary.LittleEndian, p.Points) 195 | if err != nil { 196 | return err 197 | } 198 | 199 | return nil 200 | } 201 | 202 | // check interfaces 203 | var ( 204 | _ driver.Valuer = PostGISPolygon{} 205 | _ sql.Scanner = &PostGISPolygon{} 206 | ) 207 | -------------------------------------------------------------------------------- /postgis_test.go: 
--------------------------------------------------------------------------------
1 | package pq_types
2 | 
3 | import (
4 | 	. "gopkg.in/check.v1"
5 | )
6 | 
7 | func (s *TypesSuite) TestPostGISPointScanValue(c *C) {
8 | 	// hex-encoded little-endian EWKB: byte order, type 0x20000001, SRID 4326, lon, lat
9 | 	raw := []byte("01" + "01000020" + "E6100000" + "446E861BF0CD4240" + "2194F77134E94B40")
10 | 
11 | 	var got PostGISPoint
12 | 	c.Check(got.Scan(raw), IsNil)
13 | 	c.Check(got, DeepEquals, PostGISPoint{Lon: 37.6088900, Lat: 55.8219130})
14 | 
15 | 	wkt, err := got.Value()
16 | 	c.Check(err, IsNil)
17 | 	c.Check(wkt, DeepEquals, []byte(`SRID=4326;POINT(37.60889000 55.82191300)`), Commentf("%s", wkt))
18 | }
19 | 
20 | func (s *TypesSuite) TestPostGISPoint(c *C) {
21 | 	if s.skipPostGIS {
22 | 		c.Skip("PostGIS not available")
23 | 	}
24 | 
25 | 	points := []PostGISPoint{
26 | 		{Lon: 37.6088900, Lat: 55.8219130},
27 | 		{Lon: -37.6088900, Lat: -55.8219130},
28 | 		{Lon: 0, Lat: 0},
29 | 		{Lon: 0.00, Lat: 0.0},
30 | 	}
31 | 	for _, want := range points {
32 | 		s.SetUpTest(c)
33 | 
34 | 		_, err := s.db.Exec("INSERT INTO pq_types (point) VALUES($1)", want)
35 | 		c.Assert(err, IsNil)
36 | 
37 | 		// start from a non-zero sentinel so a Scan that leaves the value untouched is caught
38 | 		got := PostGISPoint{Lon: -1, Lat: -1}
39 | 		err = s.db.QueryRow("SELECT point FROM pq_types").Scan(&got)
40 | 		c.Check(err, IsNil)
41 | 
42 | 		c.Check(got, DeepEquals, want)
43 | 	}
44 | }
45 | 
46 | func (s *TypesSuite) TestPostGISBox2DScanValue(c *C) {
47 | 	raw := []byte("BOX(0.125 0.25,0.5 1)")
48 | 
49 | 	var got PostGISBox2D
50 | 	c.Check(got.Scan(raw), IsNil)
51 | 	c.Check(got, DeepEquals, PostGISBox2D{Min: PostGISPoint{Lon: 0.125, Lat: 0.25}, Max: PostGISPoint{Lon: 0.5, Lat: 1}})
52 | 
53 | 	wkt, err := got.Value()
54 | 	c.Check(err, IsNil)
55 | 	c.Check(wkt, DeepEquals, []byte(`BOX(0.12500000 0.25000000,0.50000000 1.00000000)`), Commentf("%s", wkt))
56 | }
57 | 
58 | func (s *TypesSuite) TestPostGISBox2D(c *C) {
59 | 	if s.skipPostGIS {
60 | 		c.Skip("PostGIS not available")
61 | 	}
62 | 
63 | 	boxes := []PostGISBox2D{
64 | 		{Min: PostGISPoint{Lon: 0.125, Lat: 0.25}, Max: PostGISPoint{Lon: 0.5, Lat: 1}},
65 | 		{Min: PostGISPoint{Lon: -0.125, Lat: -0.25}, Max: PostGISPoint{Lon: 0.5, Lat: 1}},
66 | 		{Min: PostGISPoint{Lon: -0.55, Lat: -0.55}, Max: PostGISPoint{Lon: 0.5, Lat: 1}},
67 | 	}
68 | 	for _, want := range boxes {
69 | 		s.SetUpTest(c)
70 | 
71 | 		_, err := s.db.Exec("INSERT INTO pq_types (box) VALUES($1)", want)
72 | 		c.Assert(err, IsNil)
73 | 
74 | 		var got PostGISBox2D
75 | 		err = s.db.QueryRow("SELECT box FROM pq_types").Scan(&got)
76 | 		c.Check(err, IsNil)
77 | 
78 | 		c.Check(got, DeepEquals, want)
79 | 	}
80 | }
81 | 
82 | func (s *TypesSuite) TestPostGISPolygonScanValue(c *C) {
83 | 	// hex-encoded little-endian EWKB: byte order, type 0x20000003, SRID 4326,
84 | 	// one ring, five points given as (lon, lat) float64 pairs
85 | 	raw := []byte("01" + "03000020" + "E6100000" + "01000000" + "05000000" +
86 | 		"000000000000C03F" + "000000000000D03F" +
87 | 		"000000000000C03F" + "000000000000F03F" +
88 | 		"000000000000E03F" + "000000000000F03F" +
89 | 		"000000000000E03F" + "000000000000D03F" +
90 | 		"000000000000C03F" + "000000000000D03F")
91 | 
92 | 	var got PostGISPolygon
93 | 	c.Check(got.Scan(raw), IsNil)
94 | 	c.Check(got, DeepEquals, PostGISPolygon{
95 | 		Points: []PostGISPoint{
96 | 			{Lon: 0.125, Lat: 0.25},
97 | 			{Lon: 0.125, Lat: 1},
98 | 			{Lon: 0.5, Lat: 1},
99 | 			{Lon: 0.5, Lat: 0.25},
100 | 			{Lon: 0.125, Lat: 0.25},
101 | 		},
102 | 	})
103 | 
104 | 	wkt, err := got.Value()
105 | 	c.Check(err, IsNil)
106 | 	c.Check(wkt, DeepEquals, []byte(`SRID=4326;POLYGON((0.12500000 0.25000000,0.12500000 1.00000000,0.50000000 1.00000000,0.50000000 0.25000000,0.12500000 0.25000000))`), Commentf("%s", wkt))
107 | }
108 | 
109 | func (s *TypesSuite) TestPostGISPolygon(c *C) {
110 | 	if s.skipPostGIS {
111 | 		c.Skip("PostGIS not available")
112 | 	}
113 | 
114 | 	polygons := []PostGISPolygon{
115 | 		{
116 | 			Points: []PostGISPoint{
117 | 				{Lon: 0.125, Lat: 0.25},
118 | 				{Lon: 0.125, Lat: 1},
119 | 				{Lon: 0.5, Lat: 1},
120 | 				{Lon: 0.5, Lat: 0.25},
121 | 				{Lon: 0.125, Lat: 0.25},
122 | 			},
123 | 		},
124 | 		{
125 | 			Points: []PostGISPoint{
126 | 				{Lon: 0.0, Lat: 0.0},
127 | 				{Lon: -50.555, Lat: -50.555},
128 | 				{Lon: -50, Lat: 0},
129 | 				{Lon: 0, Lat: 0},
130 | 			},
131 | 		},
132 | 	}
133 | 	for _, want := range polygons {
134 | 		s.SetUpTest(c)
135 | 
136 | 		_, err := s.db.Exec("INSERT INTO pq_types (polygon) VALUES($1)", want)
137 | 		c.Assert(err, IsNil)
138 | 
139 | 		var got PostGISPolygon
140 | 		err = s.db.QueryRow("SELECT polygon FROM pq_types").Scan(&got)
141 | 		c.Check(err, IsNil)
142 | 
143 | 		c.Check(got, DeepEquals, want)
144 | 	}
145 | }
146 | 
--------------------------------------------------------------------------------
/sqltypes_test.go:
--------------------------------------------------------------------------------
1 | package pq_types
2 | 
3 | import (
4 | 	"database/sql"
5 | 	"log"
6 | 	"strconv"
7 | 	"strings"
8 | 	"testing"
9 | 
10 | 	_ "github.com/lib/pq"
11 | 	.
"gopkg.in/check.v1" 12 | ) 13 | 14 | type Logger interface { 15 | Logf(format string, args ...interface{}) 16 | } 17 | 18 | type DB struct { 19 | *sql.DB 20 | l Logger 21 | } 22 | 23 | func (db *DB) Query(query string, args ...interface{}) (*sql.Rows, error) { 24 | if db.l != nil { 25 | db.l.Logf("%s (args = %#v)", query, args) 26 | } 27 | return db.DB.Query(query, args...) 28 | } 29 | 30 | func (db *DB) QueryRow(query string, args ...interface{}) *sql.Row { 31 | if db.l != nil { 32 | db.l.Logf("%s (args = %#v)", query, args) 33 | } 34 | return db.DB.QueryRow(query, args...) 35 | } 36 | 37 | func (db *DB) Exec(query string, args ...interface{}) (sql.Result, error) { 38 | if db.l != nil { 39 | db.l.Logf("%s (args = %#v)", query, args) 40 | } 41 | return db.DB.Exec(query, args...) 42 | } 43 | 44 | func Test(t *testing.T) { TestingT(t) } 45 | 46 | type TypesSuite struct { 47 | db *DB 48 | skipJSON bool 49 | skipJSONB bool 50 | skipPostGIS bool 51 | } 52 | 53 | var _ = Suite(&TypesSuite{}) 54 | 55 | func (s *TypesSuite) SetUpSuite(c *C) { 56 | db, err := sql.Open("postgres", "dbname=pq_types sslmode=disable") 57 | c.Assert(err, IsNil) 58 | s.db = &DB{ 59 | DB: db, 60 | l: c, 61 | } 62 | 63 | // log full version 64 | var version string 65 | row := db.QueryRow("SELECT version()") 66 | err = row.Scan(&version) 67 | c.Assert(err, IsNil) 68 | log.Print(version) 69 | 70 | // check minor version 71 | row = db.QueryRow("SHOW server_version") 72 | err = row.Scan(&version) 73 | c.Assert(err, IsNil) 74 | minor, err := strconv.Atoi(strings.Split(version, ".")[1]) 75 | c.Assert(err, IsNil) 76 | 77 | // check json and jsonb support 78 | if minor <= 1 { 79 | log.Print("json not available") 80 | s.skipJSON = true 81 | } 82 | if minor <= 3 { 83 | log.Print("jsonb not available") 84 | s.skipJSONB = true 85 | } 86 | 87 | s.db.Exec("DROP TABLE IF EXISTS pq_types") 88 | _, err = s.db.Exec(`CREATE TABLE pq_types( 89 | string_array varchar[], 90 | int32_array int[], 91 | int64_array 
bigint[], 92 | jsontext_varchar varchar, 93 | null_str varchar, 94 | null_int32 int4, 95 | null_int64 int8, 96 | null_timestamp timestamptz 97 | )`) 98 | c.Assert(err, IsNil) 99 | 100 | if !s.skipJSON { 101 | _, err = s.db.Exec(`ALTER TABLE pq_types ADD COLUMN jsontext_json json`) 102 | c.Assert(err, IsNil) 103 | } 104 | 105 | if !s.skipJSONB { 106 | _, err = s.db.Exec(`ALTER TABLE pq_types ADD COLUMN jsontext_jsonb jsonb`) 107 | c.Assert(err, IsNil) 108 | } 109 | 110 | // check PostGIS 111 | db.Exec("CREATE EXTENSION postgis") 112 | row = db.QueryRow("SELECT PostGIS_full_version()") 113 | err = row.Scan(&version) 114 | if err == nil { 115 | log.Print(version) 116 | 117 | _, err = s.db.Exec(`ALTER TABLE pq_types 118 | ADD COLUMN point geography(POINT, 4326), 119 | ADD COLUMN box box2d, 120 | ADD COLUMN polygon geography(POLYGON, 4326) 121 | `) 122 | c.Assert(err, IsNil) 123 | } else { 124 | log.Printf("PostGIS not available: %s", err) 125 | s.skipPostGIS = true 126 | } 127 | } 128 | 129 | func (s *TypesSuite) SetUpTest(c *C) { 130 | s.db.l = c 131 | _, err := s.db.Exec("TRUNCATE TABLE pq_types") 132 | c.Check(err, IsNil) 133 | } 134 | 135 | func (s *TypesSuite) TearDownSuite(c *C) { 136 | s.db.l = c 137 | s.db.Close() 138 | } 139 | 140 | func (s *TypesSuite) TestEmpty(c *C) { 141 | type record struct { 142 | i32a Int32Array 143 | i64a Int64Array 144 | sa StringArray 145 | } 146 | 147 | for _, r := range []record{ 148 | {}, 149 | {i32a: Int32Array{}, i64a: Int64Array{}, sa: StringArray{}}, 150 | } { 151 | s.SetUpTest(c) 152 | 153 | _, err := s.db.Exec( 154 | "INSERT INTO pq_types (int32_array, int64_array, string_array) VALUES($1, $2, $3)", 155 | r.i32a, r.i64a, r.sa, 156 | ) 157 | c.Assert(err, IsNil) 158 | 159 | var r1 record 160 | row := s.db.QueryRow("SELECT int32_array, int64_array, string_array FROM pq_types") 161 | err = row.Scan(&r1.i32a, &r1.i64a, &r1.sa) 162 | c.Check(err, IsNil) 163 | c.Check(r1, DeepEquals, r) 164 | } 165 | } 166 | 
-------------------------------------------------------------------------------- /string_array.go: -------------------------------------------------------------------------------- 1 | package pq_types 2 | 3 | import ( 4 | "bytes" 5 | "database/sql" 6 | "database/sql/driver" 7 | "fmt" 8 | "io" 9 | "sort" 10 | "strings" 11 | "unicode" 12 | ) 13 | 14 | // StringArray is a slice of string values, compatible with PostgreSQL's varchar[]. 15 | type StringArray []string 16 | 17 | func (a StringArray) Len() int { return len(a) } 18 | func (a StringArray) Less(i, j int) bool { return a[i] < a[j] } 19 | func (a StringArray) Swap(i, j int) { a[i], a[j] = a[j], a[i] } 20 | 21 | // Value implements database/sql/driver Valuer interface. 22 | func (a StringArray) Value() (driver.Value, error) { 23 | if a == nil { 24 | return nil, nil 25 | } 26 | 27 | res := make([]string, len(a)) 28 | for i, e := range a { 29 | r := e 30 | r = strings.Replace(r, `\`, `\\`, -1) 31 | r = strings.Replace(r, `"`, `\"`, -1) 32 | res[i] = `"` + r + `"` 33 | } 34 | return []byte("{" + strings.Join(res, ",") + "}"), nil 35 | } 36 | 37 | // Scan implements database/sql Scanner interface. 
38 | func (a *StringArray) Scan(value interface{}) error { 39 | if value == nil { 40 | *a = nil 41 | return nil 42 | } 43 | 44 | var b []byte 45 | switch v := value.(type) { 46 | case []byte: 47 | b = v 48 | case string: 49 | b = []byte(v) 50 | default: 51 | return fmt.Errorf("StringArray.Scan: expected []byte or string, got %T (%q)", value, value) 52 | } 53 | 54 | if len(b) < 2 || b[0] != '{' || b[len(b)-1] != '}' { 55 | return fmt.Errorf("StringArray.Scan: unexpected data %q", b) 56 | } 57 | 58 | // reuse underlying array if present 59 | if *a == nil { 60 | *a = make(StringArray, 0) 61 | } 62 | *a = (*a)[:0] 63 | 64 | if len(b) == 2 { // '{}' 65 | return nil 66 | } 67 | 68 | reader := bytes.NewReader(b[1 : len(b)-1]) // skip '{' and '}' 69 | 70 | // helper function to read next rune and check if it valid 71 | readRune := func() (rune, error) { 72 | r, _, err := reader.ReadRune() 73 | if err != nil { 74 | return 0, err 75 | } 76 | if r == unicode.ReplacementChar { 77 | return 0, fmt.Errorf("StringArray.Scan: invalid rune") 78 | } 79 | return r, nil 80 | } 81 | 82 | var q bool 83 | var e []rune 84 | for { 85 | // read next rune and check if we are done 86 | r, err := readRune() 87 | if err == io.EOF { 88 | break 89 | } 90 | if err != nil { 91 | return err 92 | } 93 | 94 | switch r { 95 | case '"': 96 | // enter or leave quotes 97 | q = !q 98 | continue 99 | case ',': 100 | // end of element unless in we are in quotes 101 | if !q { 102 | *a = append(*a, string(e)) 103 | e = e[:0] 104 | continue 105 | } 106 | case '\\': 107 | // skip to next rune, it should be present 108 | n, err := readRune() 109 | if err != nil { 110 | return err 111 | } 112 | r = n 113 | } 114 | 115 | e = append(e, r) 116 | } 117 | 118 | // we should not be in quotes at this point 119 | if q { 120 | panic("StringArray.Scan bug") 121 | } 122 | 123 | // add last element 124 | *a = append(*a, string(e)) 125 | return nil 126 | } 127 | 128 | // check interfaces 129 | var ( 130 | _ sort.Interface = 
StringArray{} 131 | _ driver.Valuer = StringArray{} 132 | _ sql.Scanner = &StringArray{} 133 | ) 134 | -------------------------------------------------------------------------------- /string_array_test.go: -------------------------------------------------------------------------------- 1 | package pq_types 2 | 3 | import ( 4 | "database/sql" 5 | "fmt" 6 | 7 | . "gopkg.in/check.v1" 8 | ) 9 | 10 | func (s *TypesSuite) TestStringArray(c *C) { 11 | type testData struct { 12 | a StringArray 13 | b []byte 14 | } 15 | for _, d := range []testData{ 16 | {StringArray(nil), []byte(nil)}, 17 | {StringArray{}, []byte(`{}`)}, 18 | 19 | {StringArray{`1234567`}, []byte(`{1234567}`)}, 20 | {StringArray{`abc123, def456 xyz789`, `абв`, `世界,`}, []byte(`{"abc123, def456 xyz789",абв,"世界,"}`)}, 21 | 22 | {StringArray{"", "`", "``", "```", "````"}, []byte("{\"\",`,``,```,````}")}, 23 | {StringArray{``, `'`, `''`, `'''`, `''''`}, []byte(`{"",','',''',''''}`)}, 24 | {StringArray{``, `"`, `""`, `"""`, `""""`}, []byte(`{"","\"","\"\"","\"\"\"","\"\"\"\""}`)}, 25 | {StringArray{``, `,`, `,,`, `,,,`, `,,,,`}, []byte(`{"",",",",,",",,,",",,,,"}`)}, 26 | {StringArray{``, `\`, `\\`, `\\\`, `\\\\`}, []byte(`{"","\\","\\\\","\\\\\\","\\\\\\\\"}`)}, 27 | {StringArray{``, `{`, `{{`, `}}`, `}`, `{{}}`}, []byte(`{"","{","{{","}}","}","{{}}"}`)}, 28 | 29 | {StringArray{`\{`, `\\{{`, `\}\}`, `\}}`}, []byte(`{"\\{","\\\\{{","\\}\\}","\\}}"}`)}, 30 | {StringArray{`\"'`, `\\"`, `\\\"`, `"\"\\""`}, []byte(`{"\\\"'","\\\\\"","\\\\\\\"","\"\\\"\\\\\"\""}`)}, 31 | } { 32 | s.SetUpTest(c) 33 | 34 | _, err := s.db.Exec("INSERT INTO pq_types (string_array) VALUES($1)", d.a) 35 | c.Assert(err, IsNil) 36 | 37 | b1 := []byte("lalala") 38 | a1 := StringArray{"lalala"} 39 | err = s.db.QueryRow("SELECT string_array, string_array FROM pq_types").Scan(&b1, &a1) 40 | c.Check(err, IsNil) 41 | c.Check(b1, DeepEquals, d.b, Commentf("\nb1 = %#q\nd.b = %#q", b1, d.b)) 42 | c.Check(a1, DeepEquals, d.a) 43 | 44 | // check db 
array length 45 | var length sql.NullInt64 46 | err = s.db.QueryRow("SELECT array_length(string_array, 1) FROM pq_types").Scan(&length) 47 | c.Check(err, IsNil) 48 | c.Check(length.Valid, Equals, len(d.a) > 0) 49 | c.Check(length.Int64, Equals, int64(len(d.a))) 50 | 51 | // check db array elements 52 | for i := 0; i < len(d.a); i++ { 53 | q := fmt.Sprintf("SELECT string_array[%d] FROM pq_types", i+1) 54 | var el sql.NullString 55 | err = s.db.QueryRow(q).Scan(&el) 56 | c.Check(err, IsNil) 57 | c.Check(el.Valid, Equals, true) 58 | c.Check(el.String, Equals, d.a[i]) 59 | } 60 | } 61 | } 62 | --------------------------------------------------------------------------------