├── .chglog ├── CHANGELOG.tpl.md └── config.yml ├── .github ├── dependabot.yml └── workflows │ └── main.yml ├── .gitignore ├── APIDESIGN.md ├── CHANGELOG.md ├── LICENSE ├── Makefile ├── README.md ├── config.go ├── connection.go ├── connection_test.go ├── db.go ├── db_test.go ├── dialect.go ├── dialect_mysql.go ├── dialect_mysql_test.go ├── dialect_postgres.go ├── dialect_postgres_test.go ├── dialect_sqlite3.go ├── dialect_sqlite3_test.go ├── dialect_test.go ├── expr.go ├── expr_test.go ├── go.mod ├── go.sum ├── hook.go ├── hook_test.go ├── internal └── example │ └── models │ ├── common.go │ ├── moment.go │ ├── photo.go │ └── user.go ├── json.go ├── json_test.go ├── logger.go ├── mapper.go ├── mapper_test.go ├── model.go ├── model_test.go ├── model_wrapper.go ├── reflect.go ├── reflect_test.go ├── relation.go ├── relation_test.go ├── sql_builder.go ├── sql_builder_test.go ├── testdata └── test.sql ├── util.go └── util_test.go /.chglog/CHANGELOG.tpl.md: -------------------------------------------------------------------------------- 1 | {{ if .Versions -}} 2 | 3 | ## [Unreleased] 4 | 5 | {{ if .Unreleased.CommitGroups -}} 6 | {{ range .Unreleased.CommitGroups -}} 7 | ### {{ .Title }} 8 | {{ range .Commits -}} 9 | - {{ if .Scope }}**{{ .Scope }}:** {{ end }}{{ .Subject }} 10 | {{ end }} 11 | {{ end -}} 12 | {{ end -}} 13 | {{ end -}} 14 | 15 | {{ range .Versions }} 16 | 17 | ## {{ if .Tag.Previous }}[{{ .Tag.Name }}]{{ else }}{{ .Tag.Name }}{{ end }} - {{ datetime "2006-01-02" .Tag.Date }} 18 | {{ range .CommitGroups -}} 19 | ### {{ .Title }} 20 | {{ range .Commits -}} 21 | - {{ if .Scope }}**{{ .Scope }}:** {{ end }}{{ .Subject }} 22 | {{ end }} 23 | {{ end -}} 24 | 25 | {{- if .MergeCommits -}} 26 | ### Pull Requests 27 | {{ range .MergeCommits -}} 28 | - {{ .Header }} 29 | {{ end }} 30 | {{ end -}} 31 | 32 | {{- if .NoteGroups -}} 33 | {{ range .NoteGroups -}} 34 | ### {{ .Title }} 35 | {{ range .Notes }} 36 | {{ .Body }} 37 | {{ end }} 38 | {{ end -}} 39 | {{ end 
-}} 40 | {{ end -}} 41 | 42 | {{- if .Versions }} 43 | [Unreleased]: {{ .Info.RepositoryURL }}/compare/{{ $latest := index .Versions 0 }}{{ $latest.Tag.Name }}...HEAD 44 | {{ range .Versions -}} 45 | {{ if .Tag.Previous -}} 46 | [{{ .Tag.Name }}]: {{ $.Info.RepositoryURL }}/compare/{{ .Tag.Previous.Name }}...{{ .Tag.Name }} 47 | {{ end -}} 48 | {{ end -}} 49 | {{ end -}} -------------------------------------------------------------------------------- /.chglog/config.yml: -------------------------------------------------------------------------------- 1 | style: github 2 | template: CHANGELOG.tpl.md 3 | info: 4 | title: CHANGELOG 5 | repository_url: https://github.com/ilibs/gosql 6 | options: 7 | commits: 8 | # filters: 9 | # Type: 10 | # - feat 11 | # - fix 12 | # - perf 13 | # - refactor 14 | commit_groups: 15 | # title_maps: 16 | # feat: Features 17 | # fix: Bug Fixes 18 | # perf: Performance Improvements 19 | # refactor: Code Refactoring 20 | header: 21 | pattern: "^(\\w*)(?:\\(([\\w\\$\\.\\-\\*\\s]*)\\))?\\:\\s(.*)$" 22 | pattern_maps: 23 | - Type 24 | - Scope 25 | - Subject 26 | notes: 27 | keywords: 28 | - BREAKING CHANGE -------------------------------------------------------------------------------- /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | updates: 3 | - package-ecosystem: gomod 4 | directory: "/" 5 | schedule: 6 | interval: daily 7 | open-pull-requests-limit: 10 8 | -------------------------------------------------------------------------------- /.github/workflows/main.yml: -------------------------------------------------------------------------------- 1 | name: gosql 2 | on: [push, pull_request] 3 | jobs: 4 | test: 5 | strategy: 6 | matrix: 7 | go-version: [1.13, 1.14, 1.15, 1.16] 8 | runs-on: ubuntu-latest 9 | services: 10 | mysql: 11 | image: mysql:5.7 12 | env: 13 | MYSQL_ROOT_PASSWORD: root 14 | ports: 15 | - 32574:3306 16 | options: --health-cmd="mysqladmin ping" 
--health-interval=10s --health-timeout=5s --health-retries=3 17 | steps: 18 | - name: Checkout 19 | uses: actions/checkout@master 20 | 21 | - name: Set up Golang ${{ matrix.go-version }} 22 | uses: actions/setup-go@v1 23 | with: 24 | go-version: ${{ matrix.go-version }} 25 | id: go 26 | 27 | - name: Init Database 28 | run: | 29 | mysql -h127.0.0.1 --port 32574 -uroot -proot -e 'CREATE DATABASE IF NOT EXISTS db1;' 30 | mysql -h127.0.0.1 --port 32574 -uroot -proot -e 'CREATE DATABASE IF NOT EXISTS db2;' 31 | 32 | - name: Test 33 | env: 34 | MYSQL_TEST_DSN1: "root:root@tcp(127.0.0.1:32574)/db1?parseTime=true" 35 | MYSQL_TEST_DSN2: "root:root@tcp(127.0.0.1:32574)/db2?parseTime=true" 36 | run: | 37 | go get && make test 38 | bash <(curl -s https://codecov.io/bash) -t ${{ secrets.CODECOV_TOKEN}} 39 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | /vendor 2 | .idea/ 3 | Gopkg.lock 4 | cover.out 5 | cover_total.out 6 | cover.html -------------------------------------------------------------------------------- /APIDESIGN.md: -------------------------------------------------------------------------------- 1 | # API Design 2 | 3 | ## Model interface 4 | ```go 5 | type IModel interface { 6 | TableName() string 7 | PK() string 8 | } 9 | ``` 10 | 11 | > The V1 `DbName()` method has been removed; use the `Use()` function instead. 12 | 13 | ## Use sqlx 14 | ```go 15 | gosql.Sqlx() // returns the native *sqlx.DB 16 | ``` 17 | ## Change database 18 | ```go 19 | gosql.Use(name string) 20 | gosql.Use(db).Table("xxxx").Where("id = ?",1).Update(map[string]interface{}{"name":"test"}) 21 | gosql.Use(db).Model(&Users{}).Get() 22 | ``` 23 | 24 | ## Transaction context switching 25 | ```go 26 | gosql.Use(db).Tx(func(tx *gosql.DB) error { 27 | tx.Table("xxxx").Where("id = ?",1).Get(&user) 28 | tx.Model(&Users{}).Get() 29 | return nil 30 | }) 31 | ``` 32 | 
-------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | 2 | ## [Unreleased] 3 | 4 | ### Fix 5 | - ./dialect.go:45:2: unreachable code 6 | 7 | 8 | 9 | ## [v2.0.5] - 2020-08-07 10 | 11 | 12 | ## [v2.0.4] - 2020-07-14 13 | 14 | 15 | ## [v2.0.3] - 2020-04-02 16 | 17 | 18 | ## [v2.0.2] - 2020-01-09 19 | ### Pull Requests 20 | - Merge pull request [#20](https://github.com/ilibs/gosql/issues/20) from ilibs/dependabot/go_modules/github.com/go-sql-driver/mysql-1.5.0 21 | - Merge pull request [#19](https://github.com/ilibs/gosql/issues/19) from xzyaoi/patch-1 22 | 23 | 24 | 25 | ## [v2.0.1] - 2019-10-16 26 | 27 | 28 | ## [v2.0.0] - 2019-09-12 29 | ### Pull Requests 30 | - Merge pull request [#13](https://github.com/ilibs/gosql/issues/13) from brucehuang2/master 31 | 32 | 33 | 34 | ## [v1.1.9] - 2019-07-09 35 | ### Pull Requests 36 | - Merge pull request [#12](https://github.com/ilibs/gosql/issues/12) from brucehuang2/master 37 | 38 | 39 | 40 | ## [v1.1.8] - 2019-05-17 41 | 42 | 43 | ## [v1.1.7] - 2019-05-17 44 | 45 | 46 | ## [v1.1.6] - 2019-05-14 47 | 48 | 49 | ## [v1.1.5] - 2019-04-26 50 | 51 | 52 | ## [v1.1.4] - 2019-01-15 53 | 54 | 55 | ## [v1.1.3] - 2018-12-27 56 | 57 | 58 | ## [v1.1.2] - 2018-12-26 59 | 60 | 61 | ## [v1.1.1] - 2018-12-07 62 | ### Pull Requests 63 | - Merge pull request [#11](https://github.com/ilibs/gosql/issues/11) from HQ6968/master 64 | 65 | 66 | 67 | ## [v1.1.0] - 2018-12-06 68 | 69 | 70 | ## [v1.0.12] - 2018-12-03 71 | 72 | 73 | ## [v1.0.10] - 2018-12-03 74 | 75 | 76 | ## [v1.0.9] - 2018-11-28 77 | 78 | 79 | ## [v1.0.8] - 2018-11-20 80 | ### Pull Requests 81 | - Merge pull request [#10](https://github.com/ilibs/gosql/issues/10) from Wanchaochao/master 82 | 83 | 84 | 85 | ## [v1.0.7] - 2018-11-17 86 | 87 | 88 | ## [v1.0.6] - 2018-11-16 89 | 90 | 91 | ## [v1.0.5] - 2018-11-06 92 | 93 | 94 | ## [v1.0.4] - 2018-09-26 95 | 96 | 97 
| ## [v1.0.3] - 2018-09-18 98 | 99 | 100 | ## [v1.0.2] - 2018-09-18 101 | 102 | 103 | ## [v1.0.1] - 2018-09-18 104 | 105 | 106 | ## v1.0.0 - 2018-09-05 107 | ### Pull Requests 108 | - Merge pull request [#5](https://github.com/ilibs/gosql/issues/5) from ilibs/sqlnul 109 | 110 | 111 | [Unreleased]: https://github.com/ilibs/gosql/compare/v2.0.5...HEAD 112 | [v2.0.5]: https://github.com/ilibs/gosql/compare/v2.0.4...v2.0.5 113 | [v2.0.4]: https://github.com/ilibs/gosql/compare/v2.0.3...v2.0.4 114 | [v2.0.3]: https://github.com/ilibs/gosql/compare/v2.0.2...v2.0.3 115 | [v2.0.2]: https://github.com/ilibs/gosql/compare/v2.0.1...v2.0.2 116 | [v2.0.1]: https://github.com/ilibs/gosql/compare/v2.0.0...v2.0.1 117 | [v2.0.0]: https://github.com/ilibs/gosql/compare/v1.1.9...v2.0.0 118 | [v1.1.9]: https://github.com/ilibs/gosql/compare/v1.1.8...v1.1.9 119 | [v1.1.8]: https://github.com/ilibs/gosql/compare/v1.1.7...v1.1.8 120 | [v1.1.7]: https://github.com/ilibs/gosql/compare/v1.1.6...v1.1.7 121 | [v1.1.6]: https://github.com/ilibs/gosql/compare/v1.1.5...v1.1.6 122 | [v1.1.5]: https://github.com/ilibs/gosql/compare/v1.1.4...v1.1.5 123 | [v1.1.4]: https://github.com/ilibs/gosql/compare/v1.1.3...v1.1.4 124 | [v1.1.3]: https://github.com/ilibs/gosql/compare/v1.1.2...v1.1.3 125 | [v1.1.2]: https://github.com/ilibs/gosql/compare/v1.1.1...v1.1.2 126 | [v1.1.1]: https://github.com/ilibs/gosql/compare/v1.1.0...v1.1.1 127 | [v1.1.0]: https://github.com/ilibs/gosql/compare/v1.0.12...v1.1.0 128 | [v1.0.12]: https://github.com/ilibs/gosql/compare/v1.0.10...v1.0.12 129 | [v1.0.10]: https://github.com/ilibs/gosql/compare/v1.0.9...v1.0.10 130 | [v1.0.9]: https://github.com/ilibs/gosql/compare/v1.0.8...v1.0.9 131 | [v1.0.8]: https://github.com/ilibs/gosql/compare/v1.0.7...v1.0.8 132 | [v1.0.7]: https://github.com/ilibs/gosql/compare/v1.0.6...v1.0.7 133 | [v1.0.6]: https://github.com/ilibs/gosql/compare/v1.0.5...v1.0.6 134 | [v1.0.5]: https://github.com/ilibs/gosql/compare/v1.0.4...v1.0.5 135 | 
[v1.0.4]: https://github.com/ilibs/gosql/compare/v1.0.3...v1.0.4 136 | [v1.0.3]: https://github.com/ilibs/gosql/compare/v1.0.2...v1.0.3 137 | [v1.0.2]: https://github.com/ilibs/gosql/compare/v1.0.1...v1.0.2 138 | [v1.0.1]: https://github.com/ilibs/gosql/compare/v1.0.0...v1.0.1 139 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 iLibs 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | GO ?= go 2 | 3 | .PHONY: test 4 | test: 5 | $(GO) test -short -v -coverprofile=cover.out ./... 
6 | 7 | .PHONY: cover 8 | cover: 9 | $(GO) tool cover -func=cover.out -o cover_total.out 10 | $(GO) tool cover -html=cover.out -o cover.html -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # gosql 2 | A package built on [sqlx](https://github.com/jmoiron/sqlx). It's simple and designed to stay simple. 3 | 4 | Build Status 5 | codecov 6 | Go Report Card
  7 | 8 | GoDoc 9 | 10 | 11 | ⚠️ Because of some disruptive changes, The current major version is upgraded to V2,If you continue with V1, you can check out the v1 branches [https://github.com/ilibs/gosql/tree/v1](https://github.com/ilibs/gosql/tree/v1) 12 | 13 | ## V2 ChangeLog 14 | - Remove the second argument to the Model() and Table() functions and replace it with WithTx(tx) 15 | - Remove Model interface DbName() function,use the Use() function 16 | - Uniform API design specification, see [APIDESIGN](APIDESIGN.md) 17 | - Relation add `connection:"db2"` struct tag, Solve the cross-library connection problem caused by deleting DbName() 18 | - Discard the WithTx function 19 | 20 | ## Usage 21 | 22 | Connection database and use sqlx original function,See the https://github.com/jmoiron/sqlx 23 | 24 | ```go 25 | import ( 26 | _ "github.com/go-sql-driver/mysql" //mysql driver 27 | "github.com/ilibs/gosql/v2" 28 | ) 29 | 30 | func main(){ 31 | configs := make(map[string]*gosql.Config) 32 | 33 | configs["default"] = &gosql.Config{ 34 | Enable: true, 35 | Driver: "mysql", 36 | Dsn: "root:123456@tcp(127.0.0.1:3306)/test?charset=utf8&parseTime=True&loc=Asia%2FShanghai", 37 | ShowSql: true, 38 | } 39 | 40 | //connection database 41 | gosql.Connect(configs) 42 | gosql.QueryRowx("select * from users where id = 1") 43 | } 44 | 45 | ``` 46 | 47 | Use `default` database, So you can use wrapper function 48 | 49 | ```go 50 | //Exec 51 | gosql.Exec("insert into users(name,email,created_at,updated_at) value(?,?,?,?)","test","test@gmail.com",time.Now(),time.Now()) 52 | 53 | //Queryx 54 | rows,err := gosql.Queryx("select * from users") 55 | for rows.Next() { 56 | user := &Users{} 57 | err = rows.StructScan(user) 58 | } 59 | rows.Close() 60 | 61 | //QueryRowx 62 | user := &Users{} 63 | err := gosql.QueryRowx("select * from users where id = ?",1).StructScan(user) 64 | 65 | //Get 66 | user := &Users{} 67 | err := gosql.Get(user,"select * from users where id = ?",1) 68 | 69 | //Select 70 | 
users := make([]Users, 0) 71 | err := gosql.Select(&users,"select * from users") 72 | 73 | //Change database 74 | db := gosql.Use("test") 75 | db.Queryx("select * from tests") 76 | ``` 77 | 78 | You can also set the default database connection name 79 | 80 | ```go 81 | gosql.SetDefaultLink("log") 82 | gosql.Connect(configs) 83 | ``` 84 | 85 | > `gosql.Get` and the other package-level helpers will then use the connection named `log` 86 | 87 | ## Using struct 88 | 89 | ```go 90 | type Users struct { 91 | Id int `db:"id"` 92 | Name string `db:"name"` 93 | Email string `db:"email"` 94 | Status int `db:"status"` 95 | CreatedAt time.Time `db:"created_at"` 96 | UpdatedAt time.Time `db:"updated_at"` 97 | } 98 | 99 | func (u *Users) TableName() string { 100 | return "users" 101 | } 102 | 103 | func (u *Users) PK() string { 104 | return "id" 105 | } 106 | 107 | //Get 108 | user := &Users{} 109 | gosql.Model(user).Where("id=?",1).Get() 110 | 111 | //All 112 | users := make([]Users,0) 113 | gosql.Model(&users).All() 114 | 115 | //Create and auto set CreatedAt 116 | gosql.Model(&Users{Name:"test",Email:"test@gmail.com"}).Create() 117 | 118 | //Update 119 | gosql.Model(&Users{Name:"test2",Email:"test@gmail.com"}).Where("id=?",1).Update() 120 | //To update a zero value, name the field explicitly 121 | gosql.Model(&Users{Status:0}).Where("id=?",1).Update("status") 122 | 123 | //Delete 124 | gosql.Model(&Users{}).Where("id=?",1).Delete() 125 | 126 | ``` 127 | 128 | You can also use a struct's non-zero fields to generate the WHERE conditions 129 | 130 | ```go 131 | //Get where id = 1 and name = "test1" 132 | user := &Users{Id:1,Name:"test1"} 133 | gosql.Model(user).Get() 134 | 135 | //Update uses the primary key as the condition by default 136 | gosql.Model(&Users{Id:1,Name:"test2"}).Update() 137 | //Use custom conditions 138 | //Builder => UPDATE users SET `id`=?,`name`=?,`updated_at`=? WHERE (status = ?) 
139 | gosql.Model(&Users{Id:1,Name:"test2"}).Where("status = ?",1).Update() 140 | 141 | //Delete 142 | gosql.Model(&Users{Id:1}).Delete() 143 | ``` 144 | 145 | Zero values are filtered out by default, but you can name fields that should not be filtered. For example: 146 | 147 | ```go 148 | user := &Users{Id:1,Status:0} 149 | gosql.Model(user).Get("status") 150 | ``` 151 | 152 | > You can use the [genstruct](https://github.com/fifsky/genstruct) tool to quickly generate database structs 153 | 154 | ## Transaction 155 | The `Tx` function takes a callback; if the callback returns an error, the transaction is rolled back 156 | 157 | ```go 158 | gosql.Tx(func(tx *gosql.DB) error { 159 | for id := 1; id < 10; id++ { 160 | user := &Users{ 161 | Id: id, 162 | Name: "test" + strconv.Itoa(id), 163 | Email: "test" + strconv.Itoa(id) + "@test.com", 164 | } 165 | 166 | //run database operations inside the transaction with 'tx', not 'gosql' 167 | tx.Model(user).Create() 168 | 169 | if id == 8 { 170 | return errors.New("interrupt the transaction") 171 | } 172 | } 173 | 174 | //query with transaction 175 | var num int 176 | err := tx.QueryRowx("select count(*) from users where id = 1").Scan(&num) 177 | 178 | if err != nil { 179 | return err 180 | } 181 | 182 | return nil 183 | }) 184 | ``` 185 | 186 | > If you need to pass a context, use `gosql.Txx` 187 | 188 | `gosql.Begin()` and `gosql.Use("other").Begin()` are also supported, for example: 189 | ```go 190 | tx, err := gosql.Begin() 191 | if err != nil { 192 | return err 193 | } 194 | 195 | for id := 1; id < 10; id++ { 196 | _, err := tx.Exec("INSERT INTO users(id,name,status,created_at,updated_at) VALUES(?,?,?,?,?)", id, "test"+strconv.Itoa(id), 1, time.Now(), time.Now()) 197 | if err != nil { 198 | return tx.Rollback() 199 | } 200 | } 201 | 202 | return tx.Commit() 203 | ``` 204 | 205 | ## Automatic time 206 | If your struct has fields with the following column names, they are filled in automatically 207 | 208 | ``` 209 | 
AUTO_CREATE_TIME_FIELDS = []string{ 210 | "create_time", 211 | "create_at", 212 | "created_at", 213 | "update_time", 214 | "update_at", 215 | "updated_at", 216 | } 217 | AUTO_UPDATE_TIME_FIELDS = []string{ 218 | "update_time", 219 | "update_at", 220 | "updated_at", 221 | } 222 | ``` 223 | 224 | 225 | ## Using Map 226 | `Create`, `Update`, `Delete` and `Count` accept `map[string]interface{}`. For example: 227 | 228 | ```go 229 | //Create 230 | gosql.Table("users").Create(map[string]interface{}{ 231 | "id": 1, 232 | "name": "test", 233 | "email": "test@test.com", 234 | "created_at": "2018-07-11 11:58:21", 235 | "updated_at": "2018-07-11 11:58:21", 236 | }) 237 | 238 | //Update 239 | gosql.Table("users").Where("id = ?", 1).Update(map[string]interface{}{ 240 | "name": "fifsky", 241 | "email": "fifsky@test.com", 242 | }) 243 | 244 | //Delete 245 | gosql.Table("users").Where("id = ?", 1).Delete() 246 | 247 | //Count 248 | gosql.Table("users").Where("id = ?", 1).Count() 249 | 250 | //Change database 251 | gosql.Use("db2").Table("users").Where("id = ?", 1).Count() 252 | 253 | //Transaction `tx` 254 | tx.Table("users").Where("id = ?", 1).Count() 255 | ``` 256 | 257 | 258 | ## sql.Null* 259 | Models support `sql.Null*` fields. Note, however, that `sql.Null*` fields are also subject to zero-value filtering. For example 260 | 261 | ```go 262 | type Users struct { 263 | Id int `db:"id"` 264 | Name string `db:"name"` 265 | Email string `db:"email"` 266 | Status int `db:"status"` 267 | SuccessTime sql.NullString `db:"success_time" json:"success_time"` 268 | CreatedAt time.Time `db:"created_at" json:"created_at"` 269 | UpdatedAt time.Time `db:"updated_at" json:"updated_at"` 270 | } 271 | 272 | user := &Users{ 273 | Id: 1, 274 | SuccessTime: sql.NullString{ 275 | String: "2018-09-03 00:00:00", 276 | Valid: false, 277 | }, 278 | } 279 | 280 | err := gosql.Model(user).Get() 281 | ``` 282 | 283 | Builder SQL: 284 | 285 | ``` 286 | Query: SELECT * FROM users WHERE (id=?); 287 | Args: []interface {}{1} 288 | 
Time: 0.00082s 289 | ``` 290 | 291 | If the `Valid` field of a `sql.NullString` is false, the SQL builder treats it as a zero value and ignores it 292 | 293 | 294 | ## gosql.Expr 295 | Modeled on GORM's `Expr`: use it to update a column relative to its current value 296 | ```go 297 | gosql.Table("users").Update(map[string]interface{}{ 298 | "id":2, 299 | "count":gosql.Expr("count+?",1), 300 | }) 301 | //Builder SQL 302 | //UPDATE `users` SET `count`=count + ?,`id`=?; [1 2] 303 | ``` 304 | 305 | 306 | ## "In" Queries 307 | 308 | Because database/sql does not inspect your query and it passes your arguments directly to the driver, it makes dealing with queries with IN clauses difficult: 309 | 310 | ```sql 311 | SELECT * FROM users WHERE level IN (?); 312 | ``` 313 | 314 | `gosql` wraps `sqlx.In`, so a slice argument can be bound to a single `IN (?)` placeholder: 315 | 316 | ```go 317 | var levels = []int{4, 6, 7} 318 | rows, err := gosql.Queryx("SELECT * FROM users WHERE level IN (?);", levels) 319 | 320 | //or 321 | 322 | user := make([]Users, 0) 323 | err := gosql.Select(&user, "select * from users where id in(?)",[]int{1,2,3}) 324 | ``` 325 | 326 | ## Relation 327 | gosql uses Go structs to express the relationships between tables. Use the `relation` tag to specify the associated fields, for example: 328 | 329 | ⚠️ Since v2, a relation whose table lives on a different database connection must name that connection with the `connection` tag 330 | 331 | 332 | ```go 333 | type MomentList struct { 334 | models.Moments 335 | User *models.Users `json:"user" db:"-" relation:"user_id,id"` //one-to-one 336 | Photos []models.Photos `json:"photos" db:"-" relation:"id,moment_id" connection:"db2"` //one-to-many 337 | } 338 | ``` 339 | 340 | Get single result 341 | 342 | ```go 343 | moment := &MomentList{} 344 | err := gosql.Model(moment).Where("status = 1 and id = ?",14).Get() 345 | //User and Photos are loaded as well 346 | ``` 347 | 348 | SQL: 349 | 350 | ```sql 351 | 2018/12/06 13:27:54 352 | Query: SELECT
* FROM `moments` WHERE (status = 1 and id = ?); 353 | Args: []interface {}{14} 354 | Time: 0.00300s 355 | 356 | 2018/12/06 13:27:54 357 | Query: SELECT * FROM `moment_users` WHERE (id=?); 358 | Args: []interface {}{5} 359 | Time: 0.00081s 360 | 361 | 2018/12/06 13:27:54 362 | Query: SELECT * FROM `photos` WHERE (moment_id=?); 363 | Args: []interface {}{14} 364 | Time: 0.00093s 365 | ``` 366 | 367 | Get a list of results; each relation is loaded with a single `IN` query 368 | 369 | ```go 370 | var moments = make([]MomentList, 0) 371 | err := gosql.Model(&moments).Where("status = 1").Limit(10).All() 372 | //User and Photos are filled in for every MomentList in moments 373 | ``` 374 | 375 | SQL: 376 | 377 | ```sql 378 | 2018/12/06 13:50:59 379 | Query: SELECT * FROM `moments` WHERE (status = 1) LIMIT 10; 380 | Time: 0.00319s 381 | 382 | 2018/12/06 13:50:59 383 | Query: SELECT * FROM `moment_users` WHERE (id in(?)); 384 | Args: []interface {}{[]interface {}{5}} 385 | Time: 0.00094s 386 | 387 | 2018/12/06 13:50:59 388 | Query: SELECT * FROM `photos` WHERE (moment_id in(?, ?, ?, ?, ?, ?, ?, ?, ?, ?)); 389 | Args: []interface {}{[]interface {}{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}} 390 | Time: 0.00087s 391 | ``` 392 | 393 | 394 | Relation Where: 395 | 396 | ```go 397 | moment := &MomentList{} 398 | err := gosql.Relation("User" , func(b *gosql.Builder) { 399 | //b builds the query for the related User records 400 | b.Where("gender = 0") 401 | }).Get(moment , "select * from moments") 402 | ``` 403 | 404 | ## Hooks 405 | Hooks are functions that are called before or after creation/querying/updating/deletion. 406 | 407 | If a model defines any of the hook methods listed below, they are called automatically when creating, updating, querying or deleting. If any hook returns an error, `gosql` stops further operations and rolls back the current transaction.
408 | 409 | ``` 410 | // begin transaction 411 | BeforeChange 412 | BeforeCreate 413 | // update timestamp `CreatedAt`, `UpdatedAt` 414 | // save 415 | AfterCreate 416 | AfterChange 417 | // commit or rollback transaction 418 | ``` 419 | Example: 420 | 421 | ```go 422 | func (u *Users) BeforeCreate(ctx context.Context) (err error) { 423 | if !u.IsValid() { 424 | err = errors.New("can't save invalid data") 425 | } 426 | return 427 | } 428 | 429 | func (u *Users) AfterCreate(ctx context.Context, tx *gosql.DB) (err error) { 430 | if u.Id == 1 { 431 | u.Email, _ = ctx.Value("email").(string) 432 | tx.Model(u).Update() 433 | } 434 | return 435 | } 436 | ``` 437 | 438 | > BeforeChange and AfterChange are only called on create/update/delete 439 | 440 | All Hooks: 441 | 442 | ``` 443 | BeforeChange 444 | AfterChange 445 | BeforeCreate 446 | AfterCreate 447 | BeforeUpdate 448 | AfterUpdate 449 | BeforeDelete 450 | AfterDelete 451 | BeforeFind 452 | AfterFind 453 | ``` 454 | 455 | A hook method may use any of the following signatures: 456 | 457 | ``` 458 | func (u *Users) BeforeCreate() 459 | func (u *Users) BeforeCreate() (err error) 460 | func (u *Users) BeforeCreate(tx *gosql.DB) 461 | func (u *Users) BeforeCreate(tx *gosql.DB) (err error) 462 | func (u *Users) BeforeCreate(ctx context.Context) 463 | func (u *Users) BeforeCreate(ctx context.Context) (err error) 464 | func (u *Users) BeforeCreate(ctx context.Context, tx *gosql.DB) 465 | func (u *Users) BeforeCreate(ctx context.Context, tx *gosql.DB) (err error) 466 | 467 | ``` 468 | 469 | To use the `context` feature, start the query with one of the functions below; otherwise the context passed to hooks will be nil: 470 | 471 | 1. ` gosql.WithContext(ctx).Model(...)` 472 | 1.
` gosql.Use("xxx").WithContext(ctx).Model(...)` 473 | 474 | 475 | ## Thanks 476 | 477 | sqlx https://github.com/jmoiron/sqlx 478 | -------------------------------------------------------------------------------- /config.go: -------------------------------------------------------------------------------- 1 | package gosql 2 | 3 | // Config describes one database connection; Connect opens it when Enable is true. 4 | // MaxLifetime is in seconds. 5 | // 6 | // Each field carries a "yaml" tag, which is the key gopkg.in/yaml decodes; the 7 | // original "yml" tag is not read by YAML decoders and is kept only for compatibility. 8 | type Config struct { 9 | Enable bool `yml:"enable" yaml:"enable" toml:"enable" json:"enable"` 10 | Driver string `yml:"driver" yaml:"driver" toml:"driver" json:"driver"` 11 | Dsn string `yml:"dsn" yaml:"dsn" toml:"dsn" json:"dsn"` 12 | MaxOpenConns int `yml:"max_open_conns" yaml:"max_open_conns" toml:"max_open_conns" json:"max_open_conns"` 13 | MaxIdleConns int `yml:"max_idle_conns" yaml:"max_idle_conns" toml:"max_idle_conns" json:"max_idle_conns"` 14 | MaxLifetime int `yml:"max_lifetime" yaml:"max_lifetime" toml:"max_lifetime" json:"max_lifetime"` 15 | ShowSql bool `yml:"show_sql" yaml:"show_sql" toml:"show_sql" json:"show_sql"` 16 | } 17 | -------------------------------------------------------------------------------- /connection.go: -------------------------------------------------------------------------------- 1 | package gosql 2 | 3 | import ( 4 | "database/sql" 5 | "errors" 6 | "fmt" 7 | "log" 8 | "strings" 9 | "time" 10 | 11 | "github.com/jmoiron/sqlx" 12 | ) 13 | 14 | // defaultLink is the connection name used when none is given (e.g. by Sqlx and ShowSql). 15 | var defaultLink = "default" 16 | 17 | // FatalExit makes Connect call log.Fatal when any enabled connection fails to open. 18 | var FatalExit = true 19 | var dbService = make(map[string]*sqlx.DB, 0) 20 | 21 | // Sqlx returns the *sqlx.DB registered under name[0], or under defaultLink when 22 | // no name is given. It panics if that link has not been connected.
23 | func Sqlx(name ...string) *sqlx.DB { 24 | dbName := defaultLink 25 | if len(name) > 0 { 26 | dbName = name[0] 27 | } 28 | 29 | engine, ok := dbService[dbName] 30 | if !ok { 31 | panic(fmt.Sprintf("[db] the database link `%s` is not configured", dbName)) 32 | } 33 | return engine 34 | } 35 | 36 | // List returns the connected engines keyed by link name. It returns the registry map itself, not a copy. 37 | func List() map[string]*sqlx.DB { 38 | return dbService 39 | } 40 | // Options holds the optional connection-pool settings applied by Open. 41 | type Options struct { 42 | maxOpenConns int 43 | maxIdleConns int 44 | maxLifetime int 45 | } 46 | // Option sets one field of Options; pass Options to Open. 47 | type Option func(*Options) 48 | // WithMaxOpenConns sets the maximum number of open connections; Open ignores values <= 0. 49 | func WithMaxOpenConns(i int) Option { 50 | return func(options *Options) { 51 | options.maxOpenConns = i 52 | } 53 | } 54 | // WithMaxIdleConns sets the maximum number of idle connections; Open ignores values <= 0. 55 | func WithMaxIdleConns(i int) Option { 56 | return func(options *Options) { 57 | options.maxIdleConns = i 58 | } 59 | } 60 | // WithMaxLifetimes sets the maximum connection lifetime in seconds; Open ignores values <= 0. 61 | func WithMaxLifetimes(i int) Option { 62 | return func(options *Options) { 63 | options.maxLifetime = i 64 | } 65 | } 66 | 67 | // Open connects with sqlx.Connect, applies the given pool options and wraps the result in a *DB. 68 | func Open(driver, dbSource string, opts ...Option) (*DB, error) { 69 | 70 | var options Options 71 | for _, opt := range opts { 72 | opt(&options) 73 | } 74 | 75 | db, err := sqlx.Connect(driver, dbSource) 76 | if err != nil { 77 | return nil, err 78 | } 79 | 80 | if options.maxOpenConns > 0 { 81 | db.SetMaxOpenConns(options.maxOpenConns) 82 | } 83 | 84 | if options.maxIdleConns > 0 { 85 | db.SetMaxIdleConns(options.maxIdleConns) 86 | } 87 | 88 | if options.maxLifetime > 0 { 89 | db.SetConnMaxLifetime(time.Duration(options.maxLifetime) * time.Second) 90 | } 91 | 92 | return &DB{database: db}, nil 93 | } 94 | 95 | // OpenWithDB wraps an existing *sql.DB, using the given driver name, in a *DB. 96 | func OpenWithDB(driver string, db *sql.DB) *DB { 97 | return &DB{database: sqlx.NewDb(db, driver)} 98 | } 99 | 100 | // Connect opens every enabled config and registers it under its map key, closing any link it replaces. 101 | func Connect(configs map[string]*Config) (err error) { 102 | 103 | var errs []string 104 | defer func() { 105 | if len(errs) > 0 { 106 | err = errors.New("[db] " + strings.Join(errs, "\n")) 107 | if FatalExit {
108 | log.Fatal(err) 109 | } 110 | } 111 | }() 112 | 113 | for key, conf := range configs { 114 | // A nil entry is skipped just like a disabled one. 115 | if conf == nil || !conf.Enable { 116 | continue 117 | } 118 | 119 | sess, connErr := sqlx.Connect(conf.Driver, conf.Dsn) 120 | 121 | if connErr != nil { 122 | errs = append(errs, connErr.Error()) 123 | continue 124 | } 125 | log.Println("[db] connect:" + key) 126 | 127 | if conf.ShowSql { 128 | logger.SetLogging(true) 129 | } 130 | 131 | // Only override database/sql's pool defaults for positive values, as Open 132 | // does: SetMaxIdleConns(0) would stop idle connections from being reused. 133 | if conf.MaxOpenConns > 0 { 134 | sess.SetMaxOpenConns(conf.MaxOpenConns) 135 | } 136 | if conf.MaxIdleConns > 0 { 137 | sess.SetMaxIdleConns(conf.MaxIdleConns) 138 | } 139 | if conf.MaxLifetime > 0 { 140 | sess.SetConnMaxLifetime(time.Duration(conf.MaxLifetime) * time.Second) 141 | } 142 | 143 | // Close any connection previously registered under this name before replacing it. 144 | if db, ok := dbService[key]; ok { 145 | _ = db.Close() 146 | } 147 | 148 | dbService[key] = sess 149 | } 150 | return 151 | } 152 | -------------------------------------------------------------------------------- /connection_test.go: -------------------------------------------------------------------------------- 1 | package gosql 2 | 3 | import ( 4 | "os" 5 | "testing" 6 | 7 | _ "github.com/go-sql-driver/mysql" 8 | ) 9 | 10 | func TestMain(m *testing.M) { 11 | configs := make(map[string]*Config) 12 | 13 | dsn1 := os.Getenv("MYSQL_TEST_DSN1") 14 | 15 | if dsn1 == "" { 16 | dsn1 = "root:123456@tcp(127.0.0.1:3306)/test?charset=utf8&parseTime=True&loc=Asia%2FShanghai" 17 | } 18 | 19 | dsn2 := os.Getenv("MYSQL_TEST_DSN2") 20 | 21 | if dsn2 == "" { 22 | dsn2 = "root:123456@tcp(127.0.0.1:3306)/test2?charset=utf8&parseTime=True&loc=Asia%2FShanghai" 23 | } 24 | 25 | configs["default"] = &Config{ 26 | Enable: true, 27 | Driver: "mysql", 28 | Dsn: dsn1, 29 | ShowSql: true, 30 | } 31 | 32 | configs["db2"] = &Config{ 33 | Enable: true, 34 | Driver: "mysql", 35 | Dsn: dsn2, 36 | ShowSql: true, 37 | } 38 | 39 | _ = Connect(configs) 40 | // Exit with m.Run's code explicitly: before Go 1.15 a TestMain that just returns reports success even if tests failed. 41 | os.Exit(m.Run()) 42 | } 43 | 44 | func TestWithOptions(t *testing.T) { 45 | dsn := os.Getenv("MYSQL_TEST_DSN2") 46 | 47 | if dsn == "" { 48 | dsn = "root:123456@tcp(127.0.0.1:3306)/test2?charset=utf8&parseTime=True&loc=Asia%2FShanghai"
49 | } 50 | 51 | _, err := Open("mysql", dsn, WithMaxOpenConns(10), WithMaxIdleConns(100), WithMaxLifetimes(100)) 52 | if err != nil { 53 | t.Error(err) 54 | } 55 | } 56 | 57 | func TestConnect(t *testing.T) { 58 | db := Sqlx() 59 | 60 | if db.DriverName() != "mysql" { 61 | t.Fatalf("sqlx database connection error") 62 | } 63 | } 64 | -------------------------------------------------------------------------------- /db.go: -------------------------------------------------------------------------------- 1 | package gosql 2 | 3 | import ( 4 | "context" 5 | "database/sql" 6 | "log" 7 | "reflect" 8 | "time" 9 | 10 | "github.com/jmoiron/sqlx" 11 | ) 12 | // ISqlx is the query API used by DB; it is satisfied by both the *sqlx.DB and *sqlx.Tx values that db() returns. 13 | type ISqlx interface { 14 | Queryx(query string, args ...interface{}) (*sqlx.Rows, error) 15 | QueryRowx(query string, args ...interface{}) *sqlx.Row 16 | Get(dest interface{}, query string, args ...interface{}) error 17 | Select(dest interface{}, query string, args ...interface{}) error 18 | Exec(query string, args ...interface{}) (sql.Result, error) 19 | NamedExec(query string, arg interface{}) (sql.Result, error) 20 | Preparex(query string) (*sqlx.Stmt, error) 21 | Rebind(query string) string 22 | DriverName() string 23 | } 24 | // BuilderChainFunc customizes the Builder of a relation query; DB.RelationMap holds them keyed by relation name. 25 | type BuilderChainFunc func(b *Builder) 26 | // DB wraps a connection pool (database) or a transaction (tx); when tx is set it takes precedence. 27 | type DB struct { 28 | database *sqlx.DB 29 | tx *sqlx.Tx 30 | logging bool 31 | RelationMap map[string]BuilderChainFunc 32 | } 33 | 34 | // db returns the active transaction if there is one, otherwise the connection pool, both via Unsafe(). 35 | func (w *DB) db() ISqlx { 36 | if w.tx != nil { 37 | return w.tx.Unsafe() 38 | } 39 | 40 | return w.database.Unsafe() 41 | } 42 | 43 | // ShowSql returns the *DB that Use gives for the default link, with SQL logging switched on. 44 | func ShowSql() *DB { 45 | w := Use(defaultLink) 46 | w.logging = true 47 | return w 48 | } 49 | // argsIn expands slice arguments (e.g. for "IN (?)") with sqlx.In; on error it returns query and args unchanged. 50 | func (w *DB) argsIn(query string, args []interface{}) (string, []interface{}, error) { 51 | // Each "?" bound to a slice is rewritten into one placeholder per element. 52 | newQuery, newArgs, err := sqlx.In(query, args...)
53 | 54 | if err != nil { 55 | return query, args, err 56 | } 57 | 58 | return newQuery, newArgs, nil 59 | } 60 | 61 | // DriverName wrapper sqlx.DriverName 62 | func (w *DB) DriverName() string { 63 | if w.tx != nil { 64 | return w.tx.DriverName() 65 | } 66 | 67 | return w.database.DriverName() 68 | } 69 | 70 | func (w *DB) ShowSql() *DB { 71 | w.logging = true 72 | return w 73 | } 74 | 75 | // Beginx begins a transaction and returns an *gosql.DB instead of an *sql.Tx. 76 | func (w *DB) Begin() (*DB, error) { 77 | tx, err := w.database.Beginx() 78 | if err != nil { 79 | return nil, err 80 | } 81 | return &DB{tx: tx}, nil 82 | } 83 | 84 | // Commit commits the transaction. 85 | func (w *DB) Commit() error { 86 | return w.tx.Commit() 87 | } 88 | 89 | // Rollback aborts the transaction. 90 | func (w *DB) Rollback() error { 91 | return w.tx.Rollback() 92 | } 93 | 94 | // Rebind wrapper sqlx.Rebind 95 | func (w *DB) Rebind(query string) string { 96 | return w.db().Rebind(query) 97 | } 98 | 99 | // Preparex wrapper sqlx.Preparex 100 | func (w *DB) Preparex(query string) (*sqlx.Stmt, error) { 101 | return w.db().Preparex(query) 102 | } 103 | 104 | // Exec wrapper sqlx.Exec 105 | func (w *DB) Exec(query string, args ...interface{}) (result sql.Result, err error) { 106 | defer func(start time.Time) { 107 | logger.Log(&QueryStatus{ 108 | Query: query, 109 | Args: args, 110 | Err: err, 111 | Start: start, 112 | End: time.Now(), 113 | }, w.logging) 114 | 115 | }(time.Now()) 116 | 117 | return w.db().Exec(query, args...) 
118 | } 119 | 120 | // NamedExec wrapper sqlx.Exec 121 | func (w *DB) NamedExec(query string, args interface{}) (result sql.Result, err error) { 122 | defer func(start time.Time) { 123 | logger.Log(&QueryStatus{ 124 | Query: query, 125 | Args: args, 126 | Err: err, 127 | Start: start, 128 | End: time.Now(), 129 | }, w.logging) 130 | 131 | }(time.Now()) 132 | 133 | return w.db().NamedExec(query, args) 134 | } 135 | 136 | // Queryx wrapper sqlx.Queryx 137 | func (w *DB) Queryx(query string, args ...interface{}) (rows *sqlx.Rows, err error) { 138 | defer func(start time.Time) { 139 | logger.Log(&QueryStatus{ 140 | Query: query, 141 | Args: args, 142 | Err: err, 143 | Start: start, 144 | End: time.Now(), 145 | }, w.logging) 146 | }(time.Now()) 147 | 148 | query, newArgs, err := w.argsIn(query, args) 149 | if err != nil { 150 | return nil, err 151 | } 152 | 153 | return w.db().Queryx(query, newArgs...) 154 | } 155 | 156 | // QueryRowx wrapper sqlx.QueryRowx 157 | func (w *DB) QueryRowx(query string, args ...interface{}) (rows *sqlx.Row) { 158 | defer func(start time.Time) { 159 | logger.Log(&QueryStatus{ 160 | Query: query, 161 | Args: args, 162 | Err: rows.Err(), 163 | Start: start, 164 | End: time.Now(), 165 | }, w.logging) 166 | }(time.Now()) 167 | 168 | query, newArgs, _ := w.argsIn(query, args) 169 | 170 | return w.db().QueryRowx(query, newArgs...) 
171 | } 172 | 173 | // Get wrapper sqlx.Get 174 | func (w *DB) Get(dest interface{}, query string, args ...interface{}) (err error) { 175 | defer func(start time.Time) { 176 | logger.Log(&QueryStatus{ 177 | Query: query, 178 | Args: args, 179 | Err: err, 180 | Start: start, 181 | End: time.Now(), 182 | }, w.logging) 183 | }(time.Now()) 184 | 185 | wrapper, ok := dest.(*ModelWrapper) 186 | if ok { 187 | dest = wrapper.model 188 | } 189 | 190 | hook := NewHook(nil, w) 191 | refVal := reflect.ValueOf(dest) 192 | hook.callMethod("BeforeFind", refVal) 193 | 194 | query, newArgs, err := w.argsIn(query, args) 195 | if err != nil { 196 | return err 197 | } 198 | 199 | err = w.db().Get(dest, query, newArgs...) 200 | if err != nil { 201 | return err 202 | } 203 | 204 | if reflect.Indirect(refVal).Kind() == reflect.Struct { 205 | // relation data fill 206 | err = RelationOne(wrapper, w, dest) 207 | } 208 | 209 | if err != nil { 210 | return err 211 | } 212 | 213 | hook.callMethod("AfterFind", refVal) 214 | if hook.HasError() { 215 | return hook.Error() 216 | } 217 | 218 | return nil 219 | } 220 | 221 | func indirectType(v reflect.Type) reflect.Type { 222 | if v.Kind() != reflect.Ptr { 223 | return v 224 | } 225 | return v.Elem() 226 | } 227 | 228 | // Select wrapper sqlx.Select 229 | func (w *DB) Select(dest interface{}, query string, args ...interface{}) (err error) { 230 | defer func(start time.Time) { 231 | logger.Log(&QueryStatus{ 232 | Query: query, 233 | Args: args, 234 | Err: err, 235 | Start: start, 236 | End: time.Now(), 237 | }, w.logging) 238 | }(time.Now()) 239 | 240 | query, newArgs, err := w.argsIn(query, args) 241 | if err != nil { 242 | return err 243 | } 244 | 245 | wrapper, ok := dest.(*ModelWrapper) 246 | if ok { 247 | dest = wrapper.model 248 | } 249 | 250 | err = w.db().Select(dest, query, newArgs...) 
251 | if err != nil { 252 | return err 253 | } 254 | 255 | t := indirectType(reflect.TypeOf(dest)) 256 | if t.Kind() == reflect.Slice { 257 | if indirectType(t.Elem()).Kind() == reflect.Struct { 258 | // relation data fill 259 | err = RelationAll(wrapper, w, dest) 260 | } 261 | } 262 | 263 | if err != nil { 264 | return err 265 | } 266 | 267 | return nil 268 | } 269 | 270 | // Txx the transaction with context 271 | func (w *DB) Txx(ctx context.Context, fn func(ctx context.Context, tx *DB) error) (err error) { 272 | tx, err := w.database.BeginTxx(ctx, nil) 273 | 274 | if err != nil { 275 | return err 276 | } 277 | defer func() { 278 | if err != nil { 279 | err := tx.Rollback() 280 | if err != nil { 281 | log.Printf("gosql rollback error:%s", err) 282 | } 283 | } 284 | }() 285 | 286 | err = fn(ctx, &DB{tx: tx}) 287 | if err == nil { 288 | err = tx.Commit() 289 | } 290 | return 291 | } 292 | 293 | // Tx the transaction 294 | func (w *DB) Tx(fn func(w *DB) error) (err error) { 295 | tx, err := w.database.Beginx() 296 | if err != nil { 297 | return err 298 | } 299 | defer func() { 300 | if err != nil { 301 | err := tx.Rollback() 302 | if err != nil { 303 | log.Printf("gosql rollback error:%s", err) 304 | } 305 | } 306 | }() 307 | err = fn(&DB{tx: tx}) 308 | if err == nil { 309 | err = tx.Commit() 310 | } 311 | return 312 | } 313 | 314 | // Table database handler from to table name 315 | // for example: 316 | // gosql.Use("db2").Table("users") 317 | func (w *DB) Table(t string) *Mapper { 318 | return &Mapper{db: w, SQLBuilder: SQLBuilder{table: t, dialect: newDialect(w.DriverName())}} 319 | } 320 | 321 | // Model database handler from to struct 322 | // for example: 323 | // gosql.Use("db2").Model(&users{}) 324 | func (w *DB) Model(m interface{}) *Builder { 325 | if v1, ok := m.(*ModelWrapper); ok { 326 | return &Builder{modelWrapper: v1, model: v1.model, db: w, SQLBuilder: SQLBuilder{dialect: newDialect(w.DriverName())}} 327 | } else { 328 | return &Builder{model: m, 
db: w, SQLBuilder: SQLBuilder{dialect: newDialect(w.DriverName())}} 329 | } 330 | } 331 | 332 | // Model database handler from to struct with context 333 | // for example: 334 | // gosql.Use("db2").WithContext(ctx).Model(&users{}) 335 | func (w *DB) WithContext(ctx context.Context) *Builder { 336 | return &Builder{db: w, SQLBuilder: SQLBuilder{dialect: newDialect(w.DriverName())}, ctx: ctx} 337 | } 338 | 339 | // Relation association table builder handle 340 | func (w *DB) Relation(name string, fn BuilderChainFunc) *DB { 341 | if w.RelationMap == nil { 342 | w.RelationMap = make(map[string]BuilderChainFunc) 343 | } 344 | w.RelationMap[name] = fn 345 | return w 346 | } 347 | 348 | // Beginx begins a transaction for default database and returns an *gosql.DB instead of an *sql.Tx. 349 | func Begin() (*DB, error) { 350 | return Use(defaultLink).Begin() 351 | } 352 | 353 | // Use is change database 354 | func Use(db string) *DB { 355 | return &DB{database: Sqlx(db)} 356 | } 357 | 358 | // Exec default database 359 | func Exec(query string, args ...interface{}) (sql.Result, error) { 360 | return Use(defaultLink).Exec(query, args...) 361 | } 362 | 363 | // Exec default database 364 | func NamedExec(query string, args interface{}) (sql.Result, error) { 365 | return Use(defaultLink).NamedExec(query, args) 366 | } 367 | 368 | // Queryx default database 369 | func Queryx(query string, args ...interface{}) (*sqlx.Rows, error) { 370 | return Use(defaultLink).Queryx(query, args...) 371 | } 372 | 373 | // QueryRowx default database 374 | func QueryRowx(query string, args ...interface{}) *sqlx.Row { 375 | return Use(defaultLink).QueryRowx(query, args...) 
376 | } 377 | 378 | // Txx default database the transaction with context 379 | func Txx(ctx context.Context, fn func(ctx context.Context, tx *DB) error) error { 380 | return Use(defaultLink).Txx(ctx, fn) 381 | } 382 | 383 | // Tx default database the transaction 384 | func Tx(fn func(tx *DB) error) error { 385 | return Use(defaultLink).Tx(fn) 386 | } 387 | 388 | // Get default database 389 | func Get(dest interface{}, query string, args ...interface{}) error { 390 | return Use(defaultLink).Get(dest, query, args...) 391 | } 392 | 393 | // Select default database 394 | func Select(dest interface{}, query string, args ...interface{}) error { 395 | return Use(defaultLink).Select(dest, query, args...) 396 | } 397 | 398 | // Relation association table builder handle 399 | func Relation(name string, fn BuilderChainFunc) *DB { 400 | w := Use(defaultLink) 401 | w.RelationMap = make(map[string]BuilderChainFunc) 402 | w.RelationMap[name] = fn 403 | return w 404 | } 405 | 406 | // SetDefaultLink set default link name 407 | func SetDefaultLink(db string) { 408 | defaultLink = db 409 | } 410 | -------------------------------------------------------------------------------- /db_test.go: -------------------------------------------------------------------------------- 1 | package gosql 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | "fmt" 7 | "strconv" 8 | "testing" 9 | "time" 10 | 11 | "github.com/ilibs/gosql/v2/internal/example/models" 12 | ) 13 | 14 | func TestShowSql(t *testing.T) { 15 | logger.logging = false 16 | RunWithSchema(t, func(t *testing.T) { 17 | insert(1) 18 | 19 | ShowSql().Queryx("select * from users") 20 | user := &models.Users{} 21 | Model(user).ShowSQL().Where("id = ?", 1).Get() 22 | Table("users").ShowSQL().Where("id = ?", 1).Update(map[string]interface{}{ 23 | "name": "test2", 24 | }) 25 | }) 26 | } 27 | 28 | func TestExec(t *testing.T) { 29 | RunWithSchema(t, func(t *testing.T) { 30 | result, err := Exec("insert into users(name,status,created_at,updated_at) 
value(?,?,?,?)", "test", 1, time.Now(), time.Now()) 31 | 32 | if err != nil { 33 | t.Error(err) 34 | } 35 | 36 | id, err := result.LastInsertId() 37 | 38 | if err != nil { 39 | t.Error(err) 40 | } 41 | 42 | if id != 1 { 43 | t.Error("lastInsertId error") 44 | } 45 | 46 | Exec("update users set status = status + 1 where id = ?", 1) 47 | 48 | result, err = Exec("update users set status = status + 1 where id = ?", 1) 49 | if err != nil { 50 | t.Error("update user error", err) 51 | } 52 | 53 | if aff, _ := result.RowsAffected(); aff == 0 { 54 | t.Error("update set error") 55 | } 56 | }) 57 | } 58 | 59 | func TestBatchExec(t *testing.T) { 60 | RunWithSchema(t, func(t *testing.T) { 61 | { 62 | // batch insert with structs 63 | users := []models.Users{ 64 | {Name: "test1", Status: 1, CreatedAt: time.Now(), UpdatedAt: time.Now()}, 65 | {Name: "test2", Status: 1, CreatedAt: time.Now(), UpdatedAt: time.Now()}, 66 | {Name: "test3", Status: 1, CreatedAt: time.Now(), UpdatedAt: time.Now()}, 67 | {Name: "test4", Status: 1, CreatedAt: time.Now(), UpdatedAt: time.Now()}, 68 | } 69 | 70 | result, err := NamedExec("insert into users(name,status,created_at,updated_at) values(:name,:status,:created_at,:updated_at)", users) 71 | 72 | if err != nil { 73 | t.Error(err) 74 | } 75 | 76 | if aff, _ := result.RowsAffected(); aff != 4 { 77 | t.Error("update set error") 78 | } 79 | } 80 | 81 | { 82 | // batch insert with maps 83 | users := []map[string]interface{}{ 84 | {"name": "test5", "status": 1, "created_at": "2021-01-25 12:22:22", "updated_at": "2021-01-25 12:22:22"}, 85 | {"name": "test6", "status": 1, "created_at": "2021-01-25 12:22:22", "updated_at": "2021-01-25 12:22:22"}, 86 | {"name": "test7", "status": 1, "created_at": "2021-01-25 12:22:22", "updated_at": "2021-01-25 12:22:22"}, 87 | {"name": "test8", "status": 1, "created_at": "2021-01-25 12:22:22", "updated_at": "2021-01-25 12:22:22"}, 88 | } 89 | 90 | result, err := NamedExec("insert into 
users(name,status,created_at,updated_at) values(:name,:status,:created_at,:updated_at)", users) 91 | 92 | if err != nil { 93 | t.Error(err) 94 | } 95 | 96 | if aff, _ := result.RowsAffected(); aff != 4 { 97 | t.Error("update set error") 98 | } 99 | } 100 | }) 101 | } 102 | 103 | func TestQueryx(t *testing.T) { 104 | RunWithSchema(t, func(t *testing.T) { 105 | insert(1) 106 | insert(2) 107 | 108 | rows, err := Queryx("select * from users") 109 | 110 | if err != nil { 111 | t.Error(err) 112 | } 113 | defer rows.Close() 114 | 115 | for rows.Next() { 116 | user := &models.Users{} 117 | err = rows.StructScan(user) 118 | if err != nil { 119 | t.Error(err) 120 | } 121 | } 122 | 123 | rows, err = Queryx("select name from users") 124 | 125 | if err != nil { 126 | t.Error(err) 127 | } 128 | defer rows.Close() 129 | 130 | for rows.Next() { 131 | // results := make(map[string]interface{}) 132 | // err = rows.MapScan(results) 133 | var name string 134 | err = rows.Scan(&name) 135 | if err != nil { 136 | t.Error(err) 137 | } 138 | fmt.Println(name) 139 | } 140 | }) 141 | } 142 | 143 | func TestQueryRowx(t *testing.T) { 144 | RunWithSchema(t, func(t *testing.T) { 145 | insert(1) 146 | user := &models.Users{} 147 | err := QueryRowx("select * from users where id = 1").StructScan(user) 148 | 149 | if err != nil { 150 | t.Error(err) 151 | } 152 | 153 | if user.Id != 1 { 154 | t.Error("wraper QueryRowx error") 155 | } 156 | }) 157 | } 158 | 159 | func TestUse(t *testing.T) { 160 | RunWithSchema(t, func(t *testing.T) { 161 | db := Use("db2") 162 | _, err := db.Exec("insert into photos(moment_id,url,created_at,updated_at) value(?,?,?,?)", 1, "http://test.com", time.Now(), time.Now()) 163 | 164 | if err != nil { 165 | t.Error(err) 166 | } 167 | }) 168 | } 169 | 170 | func TestUseTable(t *testing.T) { 171 | RunWithSchema(t, func(t *testing.T) { 172 | post := &models.Photos{ 173 | MomentId: 1, 174 | Url: "http://test.com", 175 | } 176 | 177 | _, err := Use("db2").Model(post).Create() 178 | 
179 | if err != nil { 180 | t.Error(err) 181 | } 182 | 183 | post.Url = "http://test.com/2" 184 | _, err = Use("db2").WithContext(context.Background()).Model(post).Update() 185 | if err != nil { 186 | t.Error(err) 187 | } 188 | 189 | _, err = Use("db2").Table("photos").Where("id = ?", 1).Update(map[string]interface{}{ 190 | "url": "http://test2.com", 191 | }) 192 | 193 | if err != nil { 194 | t.Error(err) 195 | } 196 | }) 197 | } 198 | 199 | func TestGet(t *testing.T) { 200 | RunWithSchema(t, func(t *testing.T) { 201 | insert(1) 202 | db := Use("default") 203 | { 204 | user := &models.Users{} 205 | err := db.Get(user, "select * from users where id = ?", 1) 206 | 207 | if err != nil { 208 | t.Error(err) 209 | } 210 | 211 | fmt.Println(jsonEncode(user)) 212 | } 213 | }) 214 | } 215 | 216 | func TestGetSingle(t *testing.T) { 217 | RunWithSchema(t, func(t *testing.T) { 218 | insert(1) 219 | db := Use("default") 220 | { 221 | var name string 222 | err := db.Get(&name, "select name from users where id = ?", 1) 223 | 224 | if err != nil { 225 | t.Error(err) 226 | } 227 | 228 | fmt.Println(name) 229 | } 230 | }) 231 | } 232 | 233 | func TestSelectSlice(t *testing.T) { 234 | RunWithSchema(t, func(t *testing.T) { 235 | insert(1) 236 | insert(2) 237 | var users []string 238 | err := Select(&users, "select name from users") 239 | 240 | if err != nil { 241 | t.Error(err) 242 | } 243 | 244 | fmt.Println(jsonEncode(users)) 245 | }) 246 | } 247 | 248 | func TestSelect(t *testing.T) { 249 | RunWithSchema(t, func(t *testing.T) { 250 | insert(1) 251 | insert(2) 252 | db := Use("default") 253 | user := make([]*models.Users, 0) 254 | err := db.Select(&user, "select * from users") 255 | 256 | if err != nil { 257 | t.Error(err) 258 | } 259 | 260 | fmt.Println(jsonEncode(user)) 261 | }) 262 | } 263 | 264 | func TestQueryxIn(t *testing.T) { 265 | RunWithSchema(t, func(t *testing.T) { 266 | insert(1) 267 | insert(2) 268 | insert(4) 269 | insert(5) 270 | insert(6) 271 | 272 | rows, err := 
Queryx("select * from users where status = ? and id in (?)", 0, []int{1, 2, 3}) 273 | 274 | if err != nil { 275 | t.Error(err) 276 | } 277 | defer rows.Close() 278 | 279 | for rows.Next() { 280 | user := &models.Users{} 281 | err = rows.StructScan(user) 282 | if err != nil { 283 | t.Error(err) 284 | } 285 | } 286 | }) 287 | } 288 | 289 | func TestSelectIn(t *testing.T) { 290 | RunWithSchema(t, func(t *testing.T) { 291 | insert(1) 292 | insert(2) 293 | insert(4) 294 | insert(5) 295 | insert(6) 296 | db := Use("default") 297 | user := make([]*models.Users, 0) 298 | err := db.Select(&user, "select * from users where id in(?)", []int{1, 2, 3}) 299 | 300 | if err != nil { 301 | t.Error(err) 302 | } 303 | 304 | fmt.Println(jsonEncode(user)) 305 | }) 306 | } 307 | 308 | func TestTx(t *testing.T) { 309 | RunWithSchema(t, func(t *testing.T) { 310 | // 1 311 | { 312 | Tx(func(tx *DB) error { 313 | for id := 1; id < 10; id++ { 314 | user := &models.Users{ 315 | Id: id, 316 | Name: "test" + strconv.Itoa(id), 317 | } 318 | 319 | tx.Model(user).Create() 320 | 321 | if id == 8 { 322 | return errors.New("simulation terminated") 323 | } 324 | } 325 | 326 | return nil 327 | }) 328 | 329 | num, err := Model(&models.Users{}).Count() 330 | 331 | if err != nil { 332 | t.Error(err) 333 | } 334 | 335 | if num != 0 { 336 | t.Error("transaction abort failed") 337 | } 338 | } 339 | 340 | // 2 341 | { 342 | Tx(func(tx *DB) error { 343 | for id := 1; id < 10; id++ { 344 | user := &models.Users{ 345 | Id: id, 346 | Name: "test" + strconv.Itoa(id), 347 | } 348 | 349 | tx.Model(user).Create() 350 | } 351 | 352 | return nil 353 | }) 354 | 355 | num, err := Model(&models.Users{}).Count() 356 | 357 | if err != nil { 358 | t.Error(err) 359 | } 360 | 361 | if num != 9 { 362 | t.Error("transaction create failed") 363 | } 364 | } 365 | }) 366 | } 367 | 368 | func TestWithTx(t *testing.T) { 369 | RunWithSchema(t, func(t *testing.T) { 370 | { 371 | err := Tx(func(tx *DB) error { 372 | for id := 1; id < 
10; id++ { 373 | _, err := tx.Exec("INSERT INTO users(id,name,status,created_at,updated_at) VALUES(?,?,?,?,?)", id, "test"+strconv.Itoa(id), 1, time.Now(), time.Now()) 374 | if err != nil { 375 | return err 376 | } 377 | } 378 | 379 | var num int 380 | err := tx.QueryRowx("select count(*) from users").Scan(&num) 381 | 382 | if err != nil { 383 | return err 384 | } 385 | 386 | if num != 9 { 387 | t.Error("with transaction create failed") 388 | } 389 | 390 | return nil 391 | }) 392 | 393 | if err != nil { 394 | t.Fatalf("with transaction failed %s", err) 395 | } 396 | } 397 | }) 398 | } 399 | 400 | func TestTxx(t *testing.T) { 401 | RunWithSchema(t, func(t *testing.T) { 402 | 403 | ctx, cancel := context.WithCancel(context.Background()) 404 | 405 | err := Txx(ctx, func(ctx context.Context, tx *DB) error { 406 | for id := 1; id < 10; id++ { 407 | user := &models.Users{ 408 | Id: id, 409 | Name: "test" + strconv.Itoa(id), 410 | } 411 | 412 | tx.Model(user).Create() 413 | 414 | if id == 8 { 415 | cancel() 416 | break 417 | } 418 | } 419 | 420 | return nil 421 | }) 422 | 423 | if err == nil { 424 | t.Fatalf("with transaction must be cancel error") 425 | } 426 | 427 | num, err := Model(&models.Users{}).Count() 428 | 429 | if err != nil { 430 | t.Error(err) 431 | } 432 | 433 | if num != 0 { 434 | t.Error("transaction abort failed") 435 | } 436 | }) 437 | } 438 | 439 | func TestWrapper_Relation(t *testing.T) { 440 | RunWithSchema(t, func(t *testing.T) { 441 | initDatas(t) 442 | moment := &MomentList{} 443 | err := Relation("User", func(b *Builder) { 444 | b.Where("status = 0") 445 | }).Get(moment, "select * from moments") 446 | 447 | // b, _ := json.MarshalIndent(moment, "", " ") 448 | // fmt.Println(string(b), err) 449 | 450 | if err != nil { 451 | t.Fatal(err) 452 | } 453 | }) 454 | } 455 | 456 | func TestWrapper_Relation2(t *testing.T) { 457 | RunWithSchema(t, func(t *testing.T) { 458 | initDatas(t) 459 | var moments = make([]*MomentList, 0) 460 | err := Relation("User", 
func(b *Builder) { 461 | b.Where("status = 1") 462 | }).Select(&moments, "select * from moments") 463 | 464 | if err != nil { 465 | t.Fatal(err) 466 | } 467 | }) 468 | } 469 | 470 | func TestDB_Begin(t *testing.T) { 471 | RunWithSchema(t, func(t *testing.T) { 472 | tx, err := Begin() 473 | if err != nil { 474 | t.Fatalf("with transaction begin error %s", err) 475 | } 476 | 477 | var fn = func() error { 478 | for id := 1; id < 10; id++ { 479 | _, err := tx.Exec("INSERT INTO users(id,name,status,created_at,updated_at) VALUES(?,?,?,?,?)", id, "test"+strconv.Itoa(id), 1, time.Now(), time.Now()) 480 | if err != nil { 481 | return err 482 | } 483 | } 484 | 485 | var num int 486 | err = tx.QueryRowx("select count(*) from users").Scan(&num) 487 | 488 | if err != nil { 489 | return err 490 | } 491 | 492 | if num != 9 { 493 | return errors.New("with transaction create failed") 494 | } 495 | return nil 496 | } 497 | 498 | err = fn() 499 | 500 | if err != nil { 501 | err := tx.Rollback() 502 | if err != nil { 503 | t.Fatalf("with transaction rollback error %s", err) 504 | } 505 | } 506 | 507 | err = tx.Commit() 508 | if err != nil { 509 | t.Fatalf("with transaction commit error %s", err) 510 | } 511 | }) 512 | } 513 | 514 | func TestRelation(t *testing.T) { 515 | RunWithSchema(t, func(t *testing.T) { 516 | initDatas(t) 517 | 518 | moment := &MomentList{} 519 | err := Relation("User", func(b *Builder) { 520 | // this is builder instance 521 | b.Where("status = 1") 522 | }).Get(moment, "select * from moments where id = 1") 523 | 524 | if err != nil { 525 | t.Fatalf("relation query error %s", err) 526 | } 527 | }) 528 | } 529 | -------------------------------------------------------------------------------- /dialect.go: -------------------------------------------------------------------------------- 1 | package gosql 2 | 3 | import ( 4 | "fmt" 5 | ) 6 | 7 | // Dialect interface contains behaviors that differ across SQL database 8 | type Dialect interface { 9 | // GetName get 
dialect's name 10 | GetName() string 11 | 12 | // Quote quotes field name to avoid SQL parsing exceptions by using a reserved word as a field name 13 | Quote(key string) string 14 | 15 | // Placeholder is where value holder default "?" 16 | Placeholder() string 17 | } 18 | 19 | type commonDialect struct { 20 | } 21 | 22 | func (commonDialect) GetName() string { 23 | return "common" 24 | } 25 | 26 | func (commonDialect) Quote(key string) string { 27 | return fmt.Sprintf(`"%s"`, key) 28 | } 29 | 30 | func (*commonDialect) Placeholder() string { 31 | return "?" 32 | } 33 | 34 | var dialectsMap = map[string]Dialect{} 35 | 36 | // RegisterDialect register new dialect 37 | func RegisterDialect(name string, dialect Dialect) { 38 | dialectsMap[name] = dialect 39 | } 40 | 41 | // GetDialect gets the dialect for the specified dialect name 42 | func GetDialect(name string) (dialect Dialect, ok bool) { 43 | dialect, ok = dialectsMap[name] 44 | return 45 | } 46 | 47 | func mustGetDialect(name string) Dialect { 48 | if dialect, ok := dialectsMap[name]; ok { 49 | return dialect 50 | } 51 | panic(fmt.Sprintf("`%v` is not officially supported", name)) 52 | } 53 | 54 | func newDialect(name string) Dialect { 55 | if value, ok := GetDialect(name); ok { 56 | return value 57 | } 58 | 59 | fmt.Printf("`%v` is not officially supported, running under compatibility mode.\n", name) 60 | return &commonDialect{} 61 | } 62 | -------------------------------------------------------------------------------- /dialect_mysql.go: -------------------------------------------------------------------------------- 1 | package gosql 2 | 3 | import ( 4 | "fmt" 5 | ) 6 | 7 | type mysqlDialect struct { 8 | commonDialect 9 | } 10 | 11 | func init() { 12 | RegisterDialect("mysql", &mysqlDialect{}) 13 | } 14 | 15 | func (mysqlDialect) GetName() string { 16 | return "mysql" 17 | } 18 | 19 | func (mysqlDialect) Quote(key string) string { 20 | return fmt.Sprintf("`%s`", key) 21 | } 22 | 
-------------------------------------------------------------------------------- /dialect_mysql_test.go: -------------------------------------------------------------------------------- 1 | package gosql 2 | 3 | import "testing" 4 | 5 | func Test_mysqlDialect_GetName(t *testing.T) { 6 | type fields struct { 7 | commonDialect commonDialect 8 | } 9 | tests := []struct { 10 | name string 11 | fields fields 12 | want string 13 | }{ 14 | { 15 | name: "test", 16 | fields: fields{}, 17 | want: "mysql", 18 | }, 19 | } 20 | for _, tt := range tests { 21 | t.Run(tt.name, func(t *testing.T) { 22 | my := mysqlDialect{ 23 | commonDialect: tt.fields.commonDialect, 24 | } 25 | if got := my.GetName(); got != tt.want { 26 | t.Errorf("GetName() = %v, want %v", got, tt.want) 27 | } 28 | }) 29 | } 30 | } 31 | 32 | func Test_mysqlDialect_Quote(t *testing.T) { 33 | type fields struct { 34 | commonDialect commonDialect 35 | } 36 | type args struct { 37 | key string 38 | } 39 | tests := []struct { 40 | name string 41 | fields fields 42 | args args 43 | want string 44 | }{ 45 | { 46 | name: "test", 47 | fields: fields{}, 48 | args: args{"status"}, 49 | want: "`status`", 50 | }, 51 | } 52 | for _, tt := range tests { 53 | t.Run(tt.name, func(t *testing.T) { 54 | my := mysqlDialect{ 55 | commonDialect: tt.fields.commonDialect, 56 | } 57 | if got := my.Quote(tt.args.key); got != tt.want { 58 | t.Errorf("Quote() = %v, want %v", got, tt.want) 59 | } 60 | }) 61 | } 62 | } 63 | -------------------------------------------------------------------------------- /dialect_postgres.go: -------------------------------------------------------------------------------- 1 | package gosql 2 | 3 | import "strconv" 4 | 5 | type postgresDialect struct { 6 | commonDialect 7 | count int 8 | } 9 | 10 | func init() { 11 | RegisterDialect("postgres", &postgresDialect{}) 12 | } 13 | 14 | func (postgresDialect) GetName() string { 15 | return "postgres" 16 | } 17 | 18 | func (p *postgresDialect) Placeholder() string { 
19 | p.count++ 20 | return "$" + strconv.Itoa(p.count) 21 | } 22 | -------------------------------------------------------------------------------- /dialect_postgres_test.go: -------------------------------------------------------------------------------- 1 | package gosql 2 | 3 | import "testing" 4 | 5 | func Test_postgresDialect_GetName(t *testing.T) { 6 | type fields struct { 7 | commonDialect commonDialect 8 | } 9 | tests := []struct { 10 | name string 11 | fields fields 12 | want string 13 | }{ 14 | { 15 | name: "test", 16 | fields: fields{}, 17 | want: "postgres", 18 | }, 19 | } 20 | for _, tt := range tests { 21 | t.Run(tt.name, func(t *testing.T) { 22 | po := postgresDialect{ 23 | commonDialect: tt.fields.commonDialect, 24 | } 25 | if got := po.GetName(); got != tt.want { 26 | t.Errorf("GetName() = %v, want %v", got, tt.want) 27 | } 28 | }) 29 | } 30 | } 31 | -------------------------------------------------------------------------------- /dialect_sqlite3.go: -------------------------------------------------------------------------------- 1 | package gosql 2 | 3 | type sqlite3Dialect struct { 4 | commonDialect 5 | } 6 | 7 | func init() { 8 | RegisterDialect("sqlite3", &sqlite3Dialect{}) 9 | } 10 | 11 | func (sqlite3Dialect) GetName() string { 12 | return "sqlite3" 13 | } 14 | -------------------------------------------------------------------------------- /dialect_sqlite3_test.go: -------------------------------------------------------------------------------- 1 | package gosql 2 | 3 | import "testing" 4 | 5 | func Test_sqlite3Dialect_GetName(t *testing.T) { 6 | type fields struct { 7 | commonDialect commonDialect 8 | } 9 | tests := []struct { 10 | name string 11 | fields fields 12 | want string 13 | }{ 14 | { 15 | name: "test", 16 | fields: fields{}, 17 | want: "sqlite3", 18 | }, 19 | } 20 | for _, tt := range tests { 21 | t.Run(tt.name, func(t *testing.T) { 22 | sq := sqlite3Dialect{ 23 | commonDialect: tt.fields.commonDialect, 24 | } 25 | if got := 
sq.GetName(); got != tt.want { 26 | t.Errorf("GetName() = %v, want %v", got, tt.want) 27 | } 28 | }) 29 | } 30 | } 31 | -------------------------------------------------------------------------------- /dialect_test.go: -------------------------------------------------------------------------------- 1 | package gosql 2 | 3 | import ( 4 | "testing" 5 | ) 6 | 7 | func TestGetDialect(t *testing.T) { 8 | RegisterDialect("mysql", &mysqlDialect{}) 9 | d := newDialect("mysql") 10 | if d.GetName() != "mysql" { 11 | t.Fatal("get dialect not mysql") 12 | } 13 | } 14 | -------------------------------------------------------------------------------- /expr.go: -------------------------------------------------------------------------------- 1 | package gosql 2 | 3 | // SQL expression 4 | type expr struct { 5 | expr string 6 | args []interface{} 7 | } 8 | 9 | // Expr generate raw SQL expression, for example: 10 | // gosql.Table("user").Update(map[string]interface{}{"price", gorm.Expr("price * ? + ?", 2, 100)}) 11 | func Expr(expression string, args ...interface{}) *expr { 12 | return &expr{expr: expression, args: args} 13 | } 14 | -------------------------------------------------------------------------------- /expr_test.go: -------------------------------------------------------------------------------- 1 | package gosql 2 | 3 | import ( 4 | "testing" 5 | ) 6 | 7 | func TestExpr(t *testing.T) { 8 | b := &SQLBuilder{ 9 | table: "users", 10 | dialect: mustGetDialect("mysql"), 11 | } 12 | 13 | q := b.updateString(map[string]interface{}{ 14 | "id": 2, 15 | "count": Expr("count + ?", 1), 16 | }) 17 | 18 | //fmt.Println(q, b.args) 19 | 20 | if q != "UPDATE `users` SET `count`=count + ?,`id`=?;" { 21 | t.Error("Expr error,get:", q) 22 | } 23 | 24 | if len(b.args) != 2 { 25 | t.Error("Expr args count error") 26 | } 27 | } 28 | -------------------------------------------------------------------------------- /go.mod: 
-------------------------------------------------------------------------------- 1 | module github.com/ilibs/gosql/v2 2 | 3 | go 1.13 4 | 5 | require ( 6 | github.com/go-sql-driver/mysql v1.7.1 7 | github.com/jmoiron/sqlx v1.3.5 8 | ) 9 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | github.com/go-sql-driver/mysql v1.6.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg= 2 | github.com/go-sql-driver/mysql v1.7.1 h1:lUIinVbN1DY0xBg0eMOzmmtGoHwWBbvnWubQUrtU8EI= 3 | github.com/go-sql-driver/mysql v1.7.1/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI= 4 | github.com/jmoiron/sqlx v1.3.5 h1:vFFPA71p1o5gAeqtEAwLU4dnX2napprKtHr7PYIcN3g= 5 | github.com/jmoiron/sqlx v1.3.5/go.mod h1:nRVWtLre0KfCLJvgxzCsLVMogSvQ1zNJtpYr2Ccp0mQ= 6 | github.com/lib/pq v1.2.0 h1:LXpIM/LZ5xGFhOpXAQUIMM1HdyqzVYM13zNdjCEEcA0= 7 | github.com/lib/pq v1.2.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= 8 | github.com/mattn/go-sqlite3 v1.14.6 h1:dNPt6NO46WmLVt2DLNpwczCmdV5boIZ6g/tlDrlRUbg= 9 | github.com/mattn/go-sqlite3 v1.14.6/go.mod h1:NyWgC/yNuGj7Q9rpYnZvas74GogHl5/Z4A/KQRfk6bU= 10 | -------------------------------------------------------------------------------- /hook.go: -------------------------------------------------------------------------------- 1 | package gosql 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | "log" 7 | "reflect" 8 | "strings" 9 | ) 10 | 11 | type Hook struct { 12 | db *DB 13 | Errs []error 14 | ctx context.Context 15 | } 16 | 17 | func NewHook(ctx context.Context, db *DB) *Hook { 18 | return &Hook{ 19 | db: db, 20 | ctx: ctx, 21 | } 22 | } 23 | 24 | func (h *Hook) callMethod(methodName string, reflectValue reflect.Value) { 25 | // Only get address from non-pointer 26 | if reflectValue.CanAddr() && reflectValue.Kind() != reflect.Ptr { 27 | reflectValue = reflectValue.Addr() 28 | } 29 | 30 | if methodValue := 
reflectValue.MethodByName(methodName); methodValue.IsValid() { 31 | switch method := methodValue.Interface().(type) { 32 | case func(): 33 | method() 34 | case func() error: 35 | h.Err(method()) 36 | case func(db *DB): 37 | method(h.db) 38 | case func(db *DB) error: 39 | h.Err(method(h.db)) 40 | case func(ctx context.Context): 41 | method(h.ctx) 42 | case func(ctx context.Context) error: 43 | h.Err(method(h.ctx)) 44 | case func(ctx context.Context, db *DB): 45 | method(h.ctx, h.db) 46 | case func(ctx context.Context, db *DB) error: 47 | h.Err(method(h.ctx, h.db)) 48 | default: 49 | log.Panicf("unsupported function %v", methodName) 50 | } 51 | } 52 | } 53 | 54 | // Err add error 55 | func (h *Hook) Err(err error) { 56 | if err != nil { 57 | h.Errs = append(h.Errs, err) 58 | } 59 | } 60 | 61 | // HasError has errors 62 | func (h *Hook) HasError() bool { 63 | return len(h.Errs) > 0 64 | } 65 | 66 | // Error format happened errors 67 | func (h *Hook) Error() error { 68 | var errs = make([]string, 0) 69 | for _, e := range h.Errs { 70 | errs = append(errs, e.Error()) 71 | } 72 | return errors.New(strings.Join(errs, "; ")) 73 | } 74 | -------------------------------------------------------------------------------- /hook_test.go: -------------------------------------------------------------------------------- 1 | package gosql 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | "fmt" 7 | "reflect" 8 | "strings" 9 | "testing" 10 | 11 | "github.com/ilibs/gosql/v2/internal/example/models" 12 | ) 13 | 14 | type hookUser struct { 15 | models.Users 16 | } 17 | 18 | func (u *hookUser) BeforeCreate() error { 19 | fmt.Println("BeforCreate run") 20 | if u.Id == 1 { 21 | return errors.New("before error") 22 | } 23 | 24 | return nil 25 | } 26 | 27 | func (u *hookUser) AfterCreate() { 28 | fmt.Println("AfterCreate run") 29 | } 30 | 31 | func (u *hookUser) BeforeUpdate() { 32 | fmt.Println("BeforeUpdate run") 33 | } 34 | 35 | func (u *hookUser) AfterUpdate() error { 36 | 
fmt.Println("AfterUpdate run") 37 | user := &models.Users{ 38 | Id: 999, 39 | } 40 | 41 | err := WithContext(nil).Model(user).Get() 42 | return err 43 | } 44 | 45 | func (u *hookUser) BeforeDelete() { 46 | fmt.Println("BeforeDelete run") 47 | } 48 | 49 | func (u *hookUser) AfterDelete() { 50 | fmt.Println("AfterDelete run") 51 | } 52 | 53 | func (u *hookUser) AfterFind() { 54 | u.Name = "AfterUserName" 55 | fmt.Println("AfterFind run") 56 | } 57 | 58 | func TestNewHook(t *testing.T) { 59 | RunWithSchema(t, func(t *testing.T) { 60 | { 61 | user := &hookUser{models.Users{ 62 | Id: 1, 63 | Name: "test", 64 | Status: 1, 65 | }} 66 | _, err := WithContext(nil).Model(user).Create() 67 | if err == nil { 68 | t.Error("before create must error") 69 | } 70 | } 71 | 72 | { 73 | insert(2) 74 | user := &hookUser{models.Users{ 75 | Id: 2, 76 | }, 77 | } 78 | _, err := Model(user).Update() 79 | if err == nil { 80 | t.Error("after update must error") 81 | } 82 | } 83 | 84 | { 85 | user := &hookUser{models.Users{ 86 | Id: 3, 87 | Name: "test", 88 | Status: 1, 89 | }, 90 | } 91 | _, err := Model(user).Create() 92 | if err != nil { 93 | t.Fatal(err) 94 | } 95 | 96 | user.Name = "test2" 97 | Model(user).Update() 98 | user2 := &hookUser{} 99 | Model(user2).Where("id=3").Get() 100 | if user2.Name != "AfterUserName" { 101 | t.Error("AfterFind change username error") 102 | } 103 | 104 | Model(user).Delete() 105 | } 106 | }) 107 | } 108 | 109 | func TestHook_Err(t *testing.T) { 110 | hook := NewHook(nil, nil) 111 | hook.Err(errors.New("test")) 112 | if !hook.HasError() { 113 | t.Error("hook err") 114 | } 115 | } 116 | 117 | func TestHook_HasError(t *testing.T) { 118 | hook := NewHook(nil, nil) 119 | hook.Err(errors.New("test")) 120 | hook.Err(errors.New("test")) 121 | hook.Err(errors.New("test")) 122 | hook.Err(errors.New("test")) 123 | hook.Err(errors.New("test")) 124 | if !hook.HasError() { 125 | t.Error("has error err") 126 | } 127 | } 128 | 129 | func TestHook_Error(t *testing.T) { 130 
| hook := NewHook(nil, nil) 131 | hook.Err(errors.New("test")) 132 | hook.Err(errors.New("test")) 133 | hook.Err(errors.New("test")) 134 | hook.Err(errors.New("test")) 135 | hook.Err(errors.New("test")) 136 | if strings.Count(hook.Error().Error(), "test") != 5 { 137 | t.Error("get error err") 138 | } 139 | } 140 | 141 | type testModelCallBack struct { 142 | } 143 | 144 | func (m *testModelCallBack) BeforeCreate() { 145 | } 146 | 147 | func (m *testModelCallBack) AfterCreate() error { 148 | return nil 149 | } 150 | 151 | func (m *testModelCallBack) BeforeChange(tx *DB) { 152 | } 153 | 154 | func (m *testModelCallBack) AfterChange(tx *DB) error { 155 | return nil 156 | } 157 | 158 | func (m *testModelCallBack) BeforeUpdate(ctx context.Context) { 159 | } 160 | 161 | func (m *testModelCallBack) AfterUpdate(ctx context.Context) error { 162 | return nil 163 | } 164 | 165 | func (m *testModelCallBack) BeforeDelete(ctx context.Context, tx *DB) { 166 | } 167 | 168 | func (m *testModelCallBack) AfterDelete(ctx context.Context, tx *DB) error { 169 | return nil 170 | } 171 | 172 | func TestHook_callMethod(t *testing.T) { 173 | 174 | hook := NewHook(nil, nil) 175 | 176 | m := &testModelCallBack{} 177 | 178 | refVal := reflect.ValueOf(m) 179 | hook.callMethod("BeforeCreate", refVal) 180 | hook.callMethod("BeforeChange", refVal) 181 | hook.callMethod("BeforeDelete", refVal) 182 | hook.callMethod("BeforeUpdate", refVal) 183 | hook.callMethod("AfterCreate", refVal) 184 | hook.callMethod("AfterChange", refVal) 185 | hook.callMethod("AfterDelete", refVal) 186 | hook.callMethod("AfterUpdate", refVal) 187 | } 188 | -------------------------------------------------------------------------------- /internal/example/models/common.go: -------------------------------------------------------------------------------- 1 | package models 2 | 3 | import ( 4 | "time" 5 | ) 6 | 7 | type ModelTime struct { 8 | CreatedAt time.Time `form:"-" json:"created_at" db:"created_at" time_format:"2006-01-02 
15:04:05"` 9 | UpdatedAt time.Time `form:"-" json:"updated_at" db:"updated_at" time_format:"2006-01-02 15:04:05"` 10 | } 11 | -------------------------------------------------------------------------------- /internal/example/models/moment.go: -------------------------------------------------------------------------------- 1 | package models 2 | 3 | type Moments struct { 4 | Id int `form:"id" json:"id" db:"id"` 5 | UserId int `form:"user_id" json:"user_id" db:"user_id"` 6 | Content string `form:"content" json:"content" db:"content"` 7 | CommentTotal int `form:"comment_total" json:"comment_total" db:"comment_total"` 8 | LikeTotal int `form:"like_total" json:"like_total" db:"like_total"` 9 | Status int `form:"status" json:"status" db:"status"` 10 | ModelTime 11 | } 12 | 13 | func (p *Moments) TableName() string { 14 | return "moments" 15 | } 16 | 17 | func (p *Moments) PK() string { 18 | return "id" 19 | } 20 | -------------------------------------------------------------------------------- /internal/example/models/photo.go: -------------------------------------------------------------------------------- 1 | package models 2 | 3 | type Photos struct { 4 | Id int `form:"id" json:"id" db:"id"` 5 | MomentId int `form:"moment_id" json:"moment_id" db:"moment_id"` 6 | Url string `form:"url" json:"url" db:"url"` 7 | ModelTime 8 | } 9 | 10 | func (p *Photos) TableName() string { 11 | return "photos" 12 | } 13 | 14 | func (p *Photos) PK() string { 15 | return "id" 16 | } 17 | -------------------------------------------------------------------------------- /internal/example/models/user.go: -------------------------------------------------------------------------------- 1 | package models 2 | 3 | import ( 4 | "database/sql" 5 | "time" 6 | ) 7 | 8 | type Users struct { 9 | Id int `form:"id" json:"id" db:"id"` 10 | Name string `form:"name" json:"name" db:"name"` 11 | Status int `form:"status" json:"status" db:"status"` 12 | SuccessTime sql.NullString `form:"-" json:"success_time" 
db:"success_time"` 13 | CreatedAt time.Time `form:"-" json:"created_at" db:"created_at" time_format:"2006-01-02 15:04:05"` 14 | UpdatedAt time.Time `form:"-" json:"updated_at" db:"updated_at" time_format:"2006-01-02 15:04:05"` 15 | } 16 | 17 | func (u *Users) TableName() string { 18 | return "users" 19 | } 20 | 21 | func (u *Users) PK() string { 22 | return "id" 23 | } 24 | -------------------------------------------------------------------------------- /json.go: -------------------------------------------------------------------------------- 1 | package gosql 2 | 3 | import ( 4 | "database/sql/driver" 5 | "encoding/json" 6 | "errors" 7 | ) 8 | 9 | var emptyJSON = json.RawMessage("{}") 10 | 11 | func JsonObject(value interface{}) (json.RawMessage, error) { 12 | var source []byte 13 | switch t := value.(type) { 14 | case string: 15 | source = []byte(t) 16 | case []byte: 17 | source = t 18 | case nil: 19 | source = emptyJSON 20 | default: 21 | return nil, errors.New("incompatible type for json.RawMessage") 22 | } 23 | 24 | if len(source) == 0 { 25 | source = emptyJSON 26 | } 27 | 28 | return source, nil 29 | } 30 | 31 | // JSONText is a json.RawMessage, which is a []byte underneath. 32 | // Value() validates the json format in the source, and returns an error if 33 | // the json is not valid. Scan does no validation. JSONText additionally 34 | // implements `Unmarshal`, which unmarshals the json within to an interface{} 35 | type JSONText json.RawMessage 36 | 37 | // MarshalJSON returns the *j as the JSON encoding of j. 38 | func (j JSONText) MarshalJSON() ([]byte, error) { 39 | if len(j) == 0 { 40 | return emptyJSON, nil 41 | } 42 | return j, nil 43 | } 44 | 45 | // UnmarshalJSON sets *j to a copy of data 46 | func (j *JSONText) UnmarshalJSON(data []byte) error { 47 | if j == nil { 48 | return errors.New("JSONText: UnmarshalJSON on nil pointer") 49 | } 50 | *j = append((*j)[0:0], data...) 51 | return nil 52 | } 53 | 54 | // Value returns j as a value. 
This does a validating unmarshal into another 55 | // RawMessage. If j is invalid json, it returns an error. 56 | func (j JSONText) Value() (driver.Value, error) { 57 | var m json.RawMessage 58 | var err = j.Unmarshal(&m) 59 | if err != nil { 60 | return []byte{}, err 61 | } 62 | return []byte(j), nil 63 | } 64 | 65 | // Scan stores the src in *j. No validation is done. 66 | func (j *JSONText) Scan(src interface{}) error { 67 | var source []byte 68 | switch t := src.(type) { 69 | case string: 70 | source = []byte(t) 71 | case []byte: 72 | if len(t) == 0 { 73 | source = emptyJSON 74 | } else { 75 | source = t 76 | } 77 | case nil: 78 | *j = JSONText(emptyJSON) 79 | default: 80 | return errors.New("incompatible type for JSONText") 81 | } 82 | *j = append((*j)[0:0], source...) 83 | return nil 84 | } 85 | 86 | // Unmarshal unmarshal's the json in j to v, as in json.Unmarshal. 87 | func (j *JSONText) Unmarshal(v interface{}) error { 88 | if len(*j) == 0 { 89 | *j = JSONText(emptyJSON) 90 | } 91 | return json.Unmarshal([]byte(*j), v) 92 | } 93 | 94 | // String supports pretty printing for JSONText types. 
95 | func (j JSONText) String() string { 96 | return string(j) 97 | } 98 | 99 | func (j *JSONText) UnmarshalBinary(data []byte) error { 100 | return j.UnmarshalJSON(data) 101 | } 102 | 103 | func (j JSONText) MarshalBinary() ([]byte, error) { 104 | return j.MarshalJSON() 105 | } 106 | -------------------------------------------------------------------------------- /json_test.go: -------------------------------------------------------------------------------- 1 | package gosql 2 | 3 | import ( 4 | "encoding/json" 5 | "reflect" 6 | "testing" 7 | ) 8 | 9 | func TestJsonObject(t *testing.T) { 10 | type args struct { 11 | value interface{} 12 | } 13 | tests := []struct { 14 | name string 15 | args args 16 | want json.RawMessage 17 | wantErr bool 18 | }{ 19 | { 20 | name: "test1", 21 | args: args{ 22 | value: nil, 23 | }, 24 | want: emptyJSON, 25 | wantErr: false, 26 | }, 27 | { 28 | name: "test2", 29 | args: args{ 30 | value: []byte(""), 31 | }, 32 | want: emptyJSON, 33 | wantErr: false, 34 | }, 35 | { 36 | name: "test3", 37 | args: args{ 38 | value: "", 39 | }, 40 | want: emptyJSON, 41 | wantErr: false, 42 | }, 43 | { 44 | name: "test4", 45 | args: args{ 46 | value: `{"other":1}`, 47 | }, 48 | want: []byte(`{"other":1}`), 49 | wantErr: false, 50 | }, 51 | } 52 | for _, tt := range tests { 53 | t.Run(tt.name, func(t *testing.T) { 54 | got, err := JsonObject(tt.args.value) 55 | if (err != nil) != tt.wantErr { 56 | t.Errorf("JsonValue() error = %v, wantErr %v", err, tt.wantErr) 57 | return 58 | } 59 | if !reflect.DeepEqual(got, tt.want) { 60 | t.Errorf("JsonValue() got = %v, want %v", string(got), string(tt.want)) 61 | } 62 | }) 63 | } 64 | } 65 | 66 | func TestJSONText(t *testing.T) { 67 | j := JSONText(`{"foo": 1, "bar": 2}`) 68 | v, err := j.Value() 69 | if err != nil { 70 | t.Errorf("Was not expecting an error") 71 | } 72 | err = (&j).Scan(v) 73 | if err != nil { 74 | t.Errorf("Was not expecting an error") 75 | } 76 | m := map[string]interface{}{} 77 | j.Unmarshal(&m) 
78 | 79 | if m["foo"].(float64) != 1 || m["bar"].(float64) != 2 { 80 | t.Errorf("Expected valid json but got some garbage instead? %#v", m) 81 | } 82 | 83 | j = JSONText(`{"foo": 1, invalid, false}`) 84 | v, err = j.Value() 85 | if err == nil { 86 | t.Errorf("Was expecting invalid json to fail!") 87 | } 88 | 89 | j = JSONText("") 90 | v, err = j.Value() 91 | if err != nil { 92 | t.Errorf("Was not expecting an error") 93 | } 94 | 95 | err = (&j).Scan(v) 96 | if err != nil { 97 | t.Errorf("Was not expecting an error") 98 | } 99 | 100 | j = JSONText(nil) 101 | v, err = j.Value() 102 | if err != nil { 103 | t.Errorf("Was not expecting an error") 104 | } 105 | 106 | err = (&j).Scan(v) 107 | if err != nil { 108 | t.Errorf("Was not expecting an error") 109 | } 110 | 111 | t.Run("Binary", func(t *testing.T) { 112 | j := JSONText(`{"foo": 1, "bar": 2}`) 113 | v, err := j.MarshalBinary() 114 | if err != nil { 115 | t.Errorf("Was not expecting an error") 116 | } 117 | if string(v) != `{"foo": 1, "bar": 2}` { 118 | t.Errorf("MarshalBinary result error") 119 | } 120 | 121 | err = (&j).UnmarshalBinary(v) 122 | if err != nil { 123 | t.Errorf("Was not expecting an error") 124 | } 125 | }) 126 | } 127 | -------------------------------------------------------------------------------- /logger.go: -------------------------------------------------------------------------------- 1 | package gosql 2 | 3 | import ( 4 | "fmt" 5 | "log" 6 | "os" 7 | "regexp" 8 | "strings" 9 | "time" 10 | ) 11 | 12 | const ( 13 | fmtLogQuery = `Query: %s` 14 | fmtLogArgs = `Args: %#v` 15 | fmtLogError = `Error: %v` 16 | fmtLogTimeTaken = `Time: %0.5fs` 17 | ) 18 | 19 | var ( 20 | reInvisibleChars = regexp.MustCompile(`[\s\r\n\t]+`) 21 | ) 22 | 23 | // QueryStatus represents the status of a query after being executed. 
24 | type QueryStatus struct { 25 | Query string 26 | Args interface{} 27 | 28 | Start time.Time 29 | End time.Time 30 | 31 | Err error 32 | } 33 | 34 | // String returns a formatted log message. 35 | func (q *QueryStatus) String() string { 36 | lines := make([]string, 0, 8) 37 | 38 | if query := q.Query; query != "" { 39 | query = reInvisibleChars.ReplaceAllString(query, ` `) 40 | query = strings.TrimSpace(query) 41 | lines = append(lines, fmt.Sprintf(fmtLogQuery, query)) 42 | } 43 | 44 | if args, ok := q.Args.([]interface{}); ok && len(args) == 0 { 45 | q.Args = nil 46 | } 47 | 48 | if q.Args != nil { 49 | lines = append(lines, fmt.Sprintf(fmtLogArgs, q.Args)) 50 | } 51 | 52 | if q.Err != nil { 53 | lines = append(lines, fmt.Sprintf(fmtLogError, q.Err)) 54 | } 55 | 56 | lines = append(lines, fmt.Sprintf(fmtLogTimeTaken, float64(q.End.UnixNano()-q.Start.UnixNano())/float64(1e9))) 57 | 58 | return strings.Join(lines, "\n") 59 | } 60 | 61 | // Logger represents a logging collector. You can pass a logging collector to 62 | // gosql.SetLogger(myCollector) to make it collect QueryStatus messages 63 | // after executing a query. 
64 | type Logger interface { 65 | Printf(format string, v ...interface{}) 66 | } 67 | 68 | type defaultLogger struct { 69 | logging bool 70 | log Logger 71 | } 72 | 73 | func (d *defaultLogger) Log(m *QueryStatus, show bool) { 74 | if d.logging || show { 75 | d.log.Printf("\n\t%s\n\n", strings.Replace(m.String(), "\n", "\n\t", -1)) 76 | } 77 | } 78 | 79 | func (d *defaultLogger) SetLogging(logging bool) { 80 | d.logging = logging 81 | } 82 | 83 | var logger = &defaultLogger{log: log.New(os.Stderr, "", log.LstdFlags)} 84 | 85 | func SetLogger(l Logger) { 86 | logger.log = l 87 | } 88 | 89 | //SetLogging set default logger 90 | func SetLogging(logging bool) { 91 | logger.logging = logging 92 | } 93 | -------------------------------------------------------------------------------- /mapper.go: -------------------------------------------------------------------------------- 1 | package gosql 2 | 3 | type Mapper struct { 4 | db *DB 5 | SQLBuilder 6 | } 7 | 8 | // Table select table name 9 | func Table(t string) *Mapper { 10 | db := &DB{database: Sqlx(defaultLink)} 11 | return &Mapper{db: db, SQLBuilder: SQLBuilder{table: t, dialect: newDialect(db.DriverName())}} 12 | } 13 | 14 | func (m *Mapper) ShowSQL() *Mapper { 15 | m.db.logging = true 16 | return m 17 | } 18 | 19 | //Where 20 | func (m *Mapper) Where(str string, args ...interface{}) *Mapper { 21 | m.SQLBuilder.Where(str, args...) 22 | return m 23 | } 24 | 25 | //Update data from to map[string]interface 26 | func (m *Mapper) Update(data map[string]interface{}) (affected int64, err error) { 27 | result, err := m.db.Exec(m.updateString(data), m.args...) 28 | if err != nil { 29 | return 0, err 30 | } 31 | 32 | return result.RowsAffected() 33 | } 34 | 35 | //Create data from to map[string]interface 36 | func (m *Mapper) Create(data map[string]interface{}) (lastInsertId int64, err error) { 37 | result, err := m.db.Exec(m.insertString(data), m.args...) 
38 | if err != nil { 39 | return 0, err 40 | } 41 | 42 | return result.LastInsertId() 43 | } 44 | 45 | //Delete data from to map[string]interface 46 | func (m *Mapper) Delete() (affected int64, err error) { 47 | result, err := m.db.Exec(m.deleteString(), m.args...) 48 | if err != nil { 49 | return 0, err 50 | } 51 | 52 | return result.RowsAffected() 53 | } 54 | 55 | //Count data from to map[string]interface 56 | func (m *Mapper) Count() (num int64, err error) { 57 | err = m.db.Get(&num, m.countString(), m.args...) 58 | return num, err 59 | } 60 | -------------------------------------------------------------------------------- /mapper_test.go: -------------------------------------------------------------------------------- 1 | package gosql 2 | 3 | import ( 4 | "strconv" 5 | "testing" 6 | ) 7 | 8 | func mapInsert(t *testing.T, id int64) int64 { 9 | id, err := Table("users").Create(map[string]interface{}{ 10 | "id": id, 11 | "name": "test" + strconv.Itoa(int(id)), 12 | "status": 1, 13 | "created_at": "2018-07-11 11:58:21", 14 | "updated_at": "2018-07-11 11:58:21", 15 | }) 16 | 17 | if err != nil { 18 | t.Error(err) 19 | } 20 | 21 | if id <= 0 { 22 | t.Error("map insert error") 23 | } 24 | 25 | return id 26 | } 27 | 28 | func TestMapper_Create(t *testing.T) { 29 | RunWithSchema(t, func(t *testing.T) { 30 | mapInsert(t, 1) 31 | }) 32 | } 33 | 34 | func TestMapper_Update(t *testing.T) { 35 | RunWithSchema(t, func(t *testing.T) { 36 | id := mapInsert(t, 1) 37 | 38 | affected, err := Table("users").Where("id = ?", id).Update(map[string]interface{}{ 39 | "name": "fifsky", 40 | }) 41 | 42 | if err != nil { 43 | t.Error(err) 44 | } 45 | 46 | if affected <= 0 { 47 | t.Error("map update error") 48 | } 49 | }) 50 | } 51 | 52 | func TestMapper_Delete(t *testing.T) { 53 | RunWithSchema(t, func(t *testing.T) { 54 | { 55 | id := mapInsert(t, 1) 56 | affected, err := Table("users").Where("id = ?", id).Delete() 57 | 58 | if err != nil { 59 | t.Error(err) 60 | } 61 | 62 | if affected 
<= 0 { 63 | t.Error("map delete error") 64 | } 65 | } 66 | 67 | { 68 | mapInsert(t, 2) 69 | affected, err := Table("users").Delete() 70 | if err != nil { 71 | t.Error(err) 72 | } 73 | 74 | if affected <= 0 { 75 | t.Error("map delete error") 76 | } 77 | } 78 | }) 79 | } 80 | 81 | func TestMapper_Count(t *testing.T) { 82 | RunWithSchema(t, func(t *testing.T) { 83 | { 84 | id := mapInsert(t, 1) 85 | num, err := Table("users").Where("id = ?", id).Count() 86 | 87 | if err != nil { 88 | t.Error(err) 89 | } 90 | 91 | if num != 1 { 92 | t.Error("map count error") 93 | } 94 | } 95 | 96 | { 97 | mapInsert(t, 2) 98 | mapInsert(t, 3) 99 | num, err := Table("users").Count() 100 | if err != nil { 101 | t.Error(err) 102 | } 103 | 104 | if num <= 0 { 105 | t.Error("map count error") 106 | } 107 | } 108 | }) 109 | } 110 | -------------------------------------------------------------------------------- /model.go: -------------------------------------------------------------------------------- 1 | package gosql 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "log" 7 | "reflect" 8 | "strconv" 9 | ) 10 | 11 | var ( 12 | mapper = NewReflectMapper("db") 13 | // Insert database automatically updates fields 14 | AUTO_CREATE_TIME_FIELDS = []string{ 15 | "create_time", 16 | "create_at", 17 | "created_at", 18 | "update_time", 19 | "update_at", 20 | "updated_at", 21 | } 22 | // Update database automatically updates fields 23 | AUTO_UPDATE_TIME_FIELDS = []string{ 24 | "update_time", 25 | "update_at", 26 | "updated_at", 27 | } 28 | ) 29 | 30 | // Model interface 31 | type IModel interface { 32 | TableName() string 33 | PK() string 34 | } 35 | 36 | type Builder struct { 37 | model interface{} 38 | modelReflectValue reflect.Value 39 | modelEntity IModel 40 | db *DB 41 | ctx context.Context 42 | SQLBuilder 43 | modelWrapper *ModelWrapper 44 | } 45 | 46 | // Model construct SQL from Struct 47 | func Model(model interface{}) *Builder { 48 | return &Builder{ 49 | model: model, 50 | db: &DB{database: 
Sqlx(defaultLink)}, 51 | } 52 | } 53 | 54 | // Model construct SQL from Struct with context 55 | func (b *Builder) Model(model interface{}) *Builder { 56 | b.model = model 57 | return b 58 | } 59 | 60 | // Model construct SQL from Struct with context 61 | func WithContext(ctx context.Context) *Builder { 62 | w := Use(defaultLink) 63 | return &Builder{db: w, SQLBuilder: SQLBuilder{dialect: newDialect(w.DriverName())}, ctx: ctx} 64 | } 65 | 66 | // ShowSQL output single sql 67 | func (b *Builder) ShowSQL() *Builder { 68 | b.db.logging = true 69 | return b 70 | } 71 | 72 | func (b *Builder) initModel() { 73 | if b.model == nil { 74 | log.Panicf("model argument must not nil") 75 | } else if m, ok := b.model.(IModel); ok { 76 | b.modelEntity = m 77 | b.table = m.TableName() 78 | b.modelReflectValue = reflect.ValueOf(m) 79 | b.dialect = newDialect(b.db.DriverName()) 80 | } else { 81 | value := reflect.ValueOf(b.model) 82 | if value.Kind() != reflect.Ptr { 83 | log.Panicf("model argument must pass a pointer, not a value %#v", b.model) 84 | } 85 | 86 | if value.IsNil() { 87 | log.Panicf("model argument cannot be nil pointer passed") 88 | } 89 | 90 | tp := reflect.Indirect(value).Type() 91 | 92 | // If b.model is *interface{} have to do a second Elem 93 | // 94 | // For example, 95 | // var m interface{} 96 | // mm := make([]*Model,0) 97 | // mm = append(mm, &Model{Id:1}) 98 | // m = mm 99 | // reflect.Indirect(reflect.ValueOf(&m)).Elem().Type().Kind() == reflect.Slice 100 | 101 | if tp.Kind() == reflect.Interface { 102 | tp = reflect.Indirect(value).Elem().Type() 103 | } 104 | 105 | if tp.Kind() != reflect.Slice { 106 | log.Panicf("model argument must slice, but get %#v", b.model) 107 | } 108 | 109 | tpEl := tp.Elem() 110 | 111 | // Compatible with []*Struct or []Struct 112 | if tpEl.Kind() == reflect.Ptr { 113 | tpEl = tpEl.Elem() 114 | } 115 | 116 | if m, ok := reflect.New(tpEl).Interface().(IModel); ok { 117 | b.modelEntity = m 118 | b.table = m.TableName() 119 | 
b.modelReflectValue = reflect.ValueOf(m) 120 | b.dialect = newDialect(b.db.DriverName()) 121 | } else { 122 | log.Panicf("model argument must implementation IModel interface or slice []IModel and pointer,but get %#v", b.model) 123 | } 124 | } 125 | } 126 | 127 | // Hint is set TDDL "/*+TDDL:slave()*/" 128 | func (b *Builder) Hint(hint string) *Builder { 129 | b.hint = hint 130 | return b 131 | } 132 | 133 | // ForceIndex 134 | func (b *Builder) ForceIndex(i string) *Builder { 135 | b.forceIndex = i 136 | return b 137 | } 138 | 139 | // Where for example Where("id = ? and name = ?",1,"test") 140 | func (b *Builder) Where(str string, args ...interface{}) *Builder { 141 | b.SQLBuilder.Where(str, args...) 142 | return b 143 | } 144 | 145 | // Select filter column 146 | func (b *Builder) Select(fields string) *Builder { 147 | b.fields = fields 148 | return b 149 | } 150 | 151 | // Limit 152 | func (b *Builder) Limit(i int) *Builder { 153 | b.limit = strconv.Itoa(i) 154 | return b 155 | } 156 | 157 | // Offset 158 | func (b *Builder) Offset(i int) *Builder { 159 | b.offset = strconv.Itoa(i) 160 | return b 161 | } 162 | 163 | // OrderBy for example "id desc" 164 | func (b *Builder) OrderBy(str string) *Builder { 165 | b.order = str 166 | return b 167 | } 168 | 169 | func (b *Builder) reflectModel(autoTime []string) map[string]reflect.Value { 170 | fields := mapper.FieldMap(b.modelReflectValue) 171 | if autoTime != nil { 172 | structAutoTime(fields, autoTime) 173 | } 174 | return fields 175 | } 176 | 177 | // Relation association table builder handle 178 | func (b *Builder) Relation(fieldName string, fn BuilderChainFunc) *Builder { 179 | if b.db.RelationMap == nil { 180 | b.db.RelationMap = make(map[string]BuilderChainFunc) 181 | } 182 | b.db.RelationMap[fieldName] = fn 183 | return b 184 | } 185 | 186 | // All get data row from to Struct 187 | func (b *Builder) Get(zeroValues ...string) (err error) { 188 | b.initModel() 189 | m := zeroValueFilter(b.reflectModel(nil), 
zeroValues) 190 | // If where is empty, the primary key where condition is generated automatically 191 | b.generateWhere(m) 192 | 193 | if b.modelWrapper != nil { 194 | return b.db.Get(b.modelWrapper, b.queryString(), b.args...) 195 | } 196 | return b.db.Get(b.model, b.queryString(), b.args...) 197 | } 198 | 199 | // All get data rows from to Struct 200 | func (b *Builder) All() (err error) { 201 | b.initModel() 202 | 203 | if b.modelWrapper != nil { 204 | return b.db.Select(b.modelWrapper, b.queryString(), b.args...) 205 | } 206 | return b.db.Select(b.model, b.queryString(), b.args...) 207 | } 208 | 209 | // Create data from to Struct 210 | func (b *Builder) Create() (lastInsertId int64, err error) { 211 | b.initModel() 212 | hook := NewHook(b.ctx, b.db) 213 | hook.callMethod("BeforeChange", b.modelReflectValue) 214 | hook.callMethod("BeforeCreate", b.modelReflectValue) 215 | if hook.HasError() { 216 | return 0, hook.Error() 217 | } 218 | 219 | fields := b.reflectModel(AUTO_CREATE_TIME_FIELDS) 220 | m := structToMap(fields) 221 | 222 | result, err := b.db.Exec(b.insertString(m), b.args...) 
223 | if err != nil { 224 | return 0, err 225 | } 226 | 227 | hook.callMethod("AfterCreate", b.modelReflectValue) 228 | hook.callMethod("AfterChange", b.modelReflectValue) 229 | 230 | if hook.HasError() { 231 | return 0, hook.Error() 232 | } 233 | 234 | lastId, err := result.LastInsertId() 235 | 236 | if err != nil { 237 | return 0, err 238 | } 239 | 240 | if v, ok := fields[b.modelEntity.PK()]; ok { 241 | fillPrimaryKey(v, lastId) 242 | } 243 | 244 | return lastId, err 245 | } 246 | 247 | func (b *Builder) generateWhere(m map[string]interface{}) { 248 | for k, v := range m { 249 | b.Where(fmt.Sprintf("%s=%s", k, b.dialect.Placeholder()), v) 250 | } 251 | } 252 | 253 | func (b *Builder) generateWhereForPK(m map[string]interface{}) { 254 | pk := b.modelEntity.PK() 255 | pval, has := m[pk] 256 | if b.where == "" && has { 257 | b.Where(fmt.Sprintf("%s=%s", pk, b.dialect.Placeholder()), pval) 258 | delete(m, pk) 259 | } 260 | } 261 | 262 | // gosql.Model(&User{Id:1,Status:0}).Update("status") 263 | func (b *Builder) Update(zeroValues ...string) (affected int64, err error) { 264 | b.initModel() 265 | hook := NewHook(b.ctx, b.db) 266 | hook.callMethod("BeforeChange", b.modelReflectValue) 267 | hook.callMethod("BeforeUpdate", b.modelReflectValue) 268 | if hook.HasError() { 269 | return 0, hook.Error() 270 | } 271 | 272 | fields := b.reflectModel(AUTO_UPDATE_TIME_FIELDS) 273 | m := zeroValueFilter(fields, zeroValues) 274 | 275 | // If where is empty, the primary key where condition is generated automatically 276 | b.generateWhereForPK(m) 277 | 278 | result, err := b.db.Exec(b.updateString(m), b.args...) 
279 | if err != nil { 280 | return 0, err 281 | } 282 | 283 | hook.callMethod("AfterUpdate", b.modelReflectValue) 284 | hook.callMethod("AfterChange", b.modelReflectValue) 285 | 286 | if hook.HasError() { 287 | return 0, hook.Error() 288 | } 289 | 290 | return result.RowsAffected() 291 | } 292 | 293 | // gosql.Model(&User{Id:1}).Delete() 294 | func (b *Builder) Delete(zeroValues ...string) (affected int64, err error) { 295 | b.initModel() 296 | hook := NewHook(b.ctx, b.db) 297 | hook.callMethod("BeforeChange", b.modelReflectValue) 298 | hook.callMethod("BeforeDelete", b.modelReflectValue) 299 | if hook.HasError() { 300 | return 0, hook.Error() 301 | } 302 | 303 | m := zeroValueFilter(b.reflectModel(nil), zeroValues) 304 | // If where is empty, the primary key where condition is generated automatically 305 | b.generateWhere(m) 306 | 307 | result, err := b.db.Exec(b.deleteString(), b.args...) 308 | if err != nil { 309 | return 0, err 310 | } 311 | 312 | hook.callMethod("AfterDelete", b.modelReflectValue) 313 | hook.callMethod("AfterChange", b.modelReflectValue) 314 | 315 | if hook.HasError() { 316 | return 0, hook.Error() 317 | } 318 | 319 | return result.RowsAffected() 320 | } 321 | 322 | // gosql.Model(&User{}).Where("status = 0").Count() 323 | func (b *Builder) Count(zeroValues ...string) (num int64, err error) { 324 | b.initModel() 325 | 326 | m := zeroValueFilter(b.reflectModel(nil), zeroValues) 327 | // If where is empty, the primary key where condition is generated automatically 328 | b.generateWhere(m) 329 | 330 | err = b.db.Get(&num, b.countString(), b.args...) 
331 | return num, err 332 | } 333 | -------------------------------------------------------------------------------- /model_test.go: -------------------------------------------------------------------------------- 1 | package gosql 2 | 3 | import ( 4 | "database/sql" 5 | "encoding/json" 6 | "fmt" 7 | "log" 8 | "strconv" 9 | "testing" 10 | "time" 11 | 12 | "github.com/jmoiron/sqlx" 13 | 14 | "github.com/ilibs/gosql/v2/internal/example/models" 15 | ) 16 | 17 | var ( 18 | createSchemas = map[string]string{ 19 | "moments": ` 20 | CREATE TABLE moments ( 21 | id int(11) unsigned NOT NULL AUTO_INCREMENT, 22 | user_id int(11) NOT NULL COMMENT '成员ID', 23 | content text NOT NULL COMMENT '日记内容', 24 | comment_total int(11) NOT NULL DEFAULT '0' COMMENT '评论总数', 25 | like_total int(11) NOT NULL DEFAULT '0' COMMENT '点赞数', 26 | status int(11) NOT NULL DEFAULT '1' COMMENT '1 正常 2删除', 27 | created_at datetime NOT NULL COMMENT '创建时间', 28 | updated_at datetime NOT NULL COMMENT '更新时间', 29 | PRIMARY KEY (id) 30 | ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;`, 31 | "users": ` 32 | CREATE TABLE users ( 33 | id int(11) unsigned NOT NULL AUTO_INCREMENT, 34 | name varchar(50) NOT NULL DEFAULT '', 35 | status int(11) NOT NULL, 36 | success_time datetime, 37 | created_at datetime NOT NULL, 38 | updated_at datetime NOT NULL, 39 | PRIMARY KEY (id) 40 | ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4; 41 | `, 42 | "photos": ` 43 | CREATE TABLE photos ( 44 | id int(11) unsigned NOT NULL AUTO_INCREMENT, 45 | url varchar(255) NOT NULL DEFAULT '' COMMENT '照片路径', 46 | moment_id int(11) NOT NULL COMMENT '日记ID', 47 | created_at datetime NOT NULL COMMENT '创建时间', 48 | updated_at datetime NOT NULL COMMENT '更新时间', 49 | PRIMARY KEY (id) 50 | ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4; 51 | `, 52 | } 53 | 54 | datas = map[string]string{ 55 | "users": ` 56 | INSERT INTO users (id,name, status, created_at, updated_at) VALUES 57 | (5,'豆爸&玥爸',1,'2018-11-28 10:29:55','2018-11-28 10:29:55'), 58 | (6,'呵呵',1,'2018-11-28 
10:29:55','2018-11-28 10:29:55');`, 59 | "photos": ` 60 | INSERT INTO photos (id, url, moment_id, created_at, updated_at) 61 | VALUES 62 | (1,'https://static.fifsky.com/kids/upload/20181128/5febe3b6e23623168cb70eac39d26412.png!blog',1,'2018-11-28 18:15:39','2018-11-28 18:15:39'), 63 | (2,'https://static.fifsky.com/kids/upload/20181128/9c60f42f07d7a0e13293c91fc5740c9d.png!blog',10,'2018-11-28 18:15:39','2018-11-28 18:15:39'), 64 | (3,'https://static.fifsky.com/kids/upload/20181128/458762098fb20128996c9cb21309aa9a.png!blog',1,'2018-11-28 18:31:37','2018-11-28 18:31:37'), 65 | (4,'https://static.fifsky.com/kids/upload/20181128/5b90c5af1bc35375a08cbc990ed662d1.png!blog',14,'2018-11-28 18:31:37','2018-11-28 18:31:37'), 66 | (5,'https://static.fifsky.com/kids/upload/20181128/db190be4184774d88abb31521123b14c.png!blog',14,'2018-11-28 18:31:37','2018-11-28 18:31:37'), 67 | (6,'https://static.fifsky.com/kids/upload/20181128/e1bd15706a79edd1f92f54538622600e.png!blog',14,'2018-11-28 18:31:37','2018-11-28 18:31:37'), 68 | (7,'https://static.fifsky.com/kids/upload/20181128/6bf495726054fa12ae7e6f5d0d4560a4.png!blog',14,'2018-11-28 18:31:37','2018-11-28 18:31:37'), 69 | (8,'https://static.fifsky.com/kids/upload/20181128/e1bd15706a79edd1f92f54538622600e.png!blog',15,'2018-11-28 18:34:45','2018-11-28 18:34:45'), 70 | (9,'https://static.fifsky.com/kids/upload/20181128/c6cc28b912f805b6ef402603e0f67852.png!blog',9,'2018-11-28 18:34:45','2018-11-28 18:34:45'), 71 | (10,'https://static.fifsky.com/kids/upload/20181128/a63a0798d098272a39e76c88f39f2f29.png!blog',16,'2018-11-28 18:35:24','2018-11-28 18:35:24'), 72 | (11,'https://static.fifsky.com/kids/upload/20181128/381de1930d970183ab083fe08e2677ac.png!blog',16,'2018-11-28 18:35:24','2018-11-28 18:35:24'); 73 | `, 74 | "moments": ` 75 | INSERT INTO moments (id, user_id, content, comment_total, like_total, status, created_at, updated_at) 76 | VALUES 77 | (1,5,'sdfsdfsdfsdfsdf',0,0,1,'2018-11-28 14:04:02','2018-11-28 14:04:02'), 78 |
(2,5,'sdfsdfsdfsdfsdf',0,0,1,'2018-11-28 17:14:23','2018-11-28 17:14:23'), 79 | (3,6,'123123123',0,0,1,'2018-11-28 17:19:38','2018-11-28 17:19:38'), 80 | (4,5,'13212312313',0,0,1,'2018-11-28 17:22:25','2018-11-28 17:22:25'), 81 | (5,5,'123123123123',0,0,1,'2018-11-28 17:24:21','2018-11-28 17:24:21'), 82 | (6,6,'131231232345tasvdf',0,0,1,'2018-11-28 17:24:27','2018-11-28 17:24:27'), 83 | (7,5,'1231231231231231',0,0,1,'2018-11-28 18:07:48','2018-11-28 18:07:48'), 84 | (8,5,'1231231231231231',0,0,1,'2018-11-28 18:09:20','2018-11-28 18:09:20'), 85 | (9,6,'1231231231231231',0,0,1,'2018-11-28 18:11:19','2018-11-28 18:11:19'), 86 | (10,5,'1231231231231231',0,0,1,'2018-11-28 18:13:52','2018-11-28 18:13:52'), 87 | (11,5,'1231231231231231',0,0,1,'2018-11-28 18:15:02','2018-11-28 18:15:02'), 88 | (12,5,'1231231231231231',0,0,1,'2018-11-28 18:15:13','2018-11-28 18:15:13'), 89 | (13,5,'1231231231231231',0,0,1,'2018-11-28 18:15:39','2018-11-28 18:15:39'), 90 | (14,6,'开开信息想你',0,0,1,'2018-11-28 18:31:37','2018-11-28 18:31:37'), 91 | (15,5,'网友们已经开始争相给宝宝取名字了',0,0,1,'2018-11-28 18:34:45','2018-11-28 18:34:45'), 92 | (16,6,' B2B事业部的对外报价显示',0,0,1,'2018-11-28 18:35:24','2018-11-28 18:35:24'); 93 | `, 94 | } 95 | ) 96 | // RunWithSchema drops and recreates every table in createSchemas on both the default connection and "db2", then runs test with t; tables are left in place afterwards (the deferred cleanup is commented out). 97 | func RunWithSchema(t *testing.T, test func(t *testing.T)) { 98 | var dbs = []*sqlx.DB{ 99 | Sqlx(), 100 | Sqlx("db2"), 101 | } 102 | defer func() { 103 | // for k := range createSchemas { 104 | // _, err := db.Exec(fmt.Sprintf("DROP TABLE IF EXISTS `%s`", k)) 105 | // if err != nil { 106 | // t.Error(err) 107 | // } 108 | // } 109 | }() 110 | 111 | for _, db := range dbs { 112 | for k, v := range createSchemas { 113 | _, err := db.Exec(fmt.Sprintf("DROP TABLE IF EXISTS `%s`", k)) 114 | if err != nil { 115 | t.Error(err) 116 | } 117 | 118 | _, err = db.Exec(v) 119 | if err != nil { 120 | t.Fatalf("create schema %s error:%s", k, err) 121 | } 122 | } 123 | } 124 | 125 | test(t) 126 | } 127 | // initDatas loads the datas fixtures: the "photos" rows go to the "db2" connection, every other table to the default one. 128 | func initDatas(t *testing.T) { 129 | db := Sqlx() 130 | db2 :=
Sqlx("db2") 131 | for k, v := range datas { 132 | udb := db 133 | if k == "photos" { 134 | udb = db2 135 | } 136 | _, err := udb.Exec(v) 137 | if err != nil { 138 | t.Fatalf("init %s data error:%s", k, err) 139 | } 140 | } 141 | } 142 | // insert creates an active (Status 1) user named "test<id>" and exits the test binary via log.Fatal if that fails. 143 | func insert(id int) { 144 | user := &models.Users{ 145 | Id: id, 146 | Name: "test" + strconv.Itoa(id), 147 | Status: 1, 148 | } 149 | _, err := Model(user).Create() 150 | if err != nil { 151 | log.Fatal(err) 152 | } 153 | } 154 | // insertStatus creates a user named "test<id>" with the given status. NOTE(review): the Create error is ignored, so a failed insert only surfaces later as an unexpected query result. 155 | func insertStatus(id int, status int) { 156 | user := &models.Users{ 157 | Id: id, 158 | Name: "test" + strconv.Itoa(id), 159 | Status: status, 160 | } 161 | Model(user).Create() 162 | } 163 | 164 | func TestBuilder_Get(t *testing.T) { 165 | RunWithSchema(t, func(t *testing.T) { 166 | insert(1) 167 | { 168 | user := &models.Users{} 169 | err := Model(user).Where("id = ?", 1).Get() 170 | 171 | if err != nil { 172 | t.Error(err) 173 | } 174 | //fmt.Println(user) 175 | 176 | } 177 | 178 | { 179 | user := &models.Users{ 180 | Name: "test1", 181 | Status: 1, 182 | } 183 | err := Model(user).Get() 184 | 185 | if err != nil { 186 | t.Error(err) 187 | } 188 | fmt.Println(user) 189 | } 190 | 191 | { 192 | insertStatus(2, 0) 193 | user := &models.Users{ 194 | Status: 0, 195 | } 196 | 197 | err := Model(user).Where("id = ?", 2).Get("status") 198 | 199 | if err != nil { 200 | t.Error(err) 201 | } 202 | fmt.Println(user) 203 | } 204 | }) 205 | } 206 | 207 | func TestBuilder_Hint(t *testing.T) { 208 | RunWithSchema(t, func(t *testing.T) { 209 | insert(1) 210 | insert(2) 211 | 212 | user := make([]*models.Users, 0) 213 | err := Model(&user).Hint("/*+TDDL:slave()*/").All() 214 | 215 | if err != nil { 216 | t.Error(err) 217 | } 218 | 219 | fmt.Println(jsonEncode(user)) 220 | }) 221 | } 222 | // jsonEncode returns i encoded as JSON for debug output, or "" if encoding fails. 223 | func jsonEncode(i interface{}) string { 224 | ret, _ := json.Marshal(i) 225 | return string(ret) 226 | } 227 | 228 | func TestBuilder_All(t *testing.T) { 229 | RunWithSchema(t, func(t *testing.T) { 230 | insert(1) 231 |
insert(2) 232 | 233 | user := make([]*models.Users, 0) 234 | err := Model(&user).All() 235 | 236 | if err != nil { 237 | t.Error(err) 238 | } 239 | 240 | fmt.Println(jsonEncode(user)) 241 | }) 242 | } 243 | 244 | func TestBuilder_Select(t *testing.T) { 245 | RunWithSchema(t, func(t *testing.T) { 246 | insert(1) 247 | insert(2) 248 | 249 | user := make([]*models.Users, 0) 250 | err := Model(&user).Select("id,name").All() 251 | 252 | if err != nil { 253 | t.Error(err) 254 | } 255 | 256 | fmt.Println(jsonEncode(user)) 257 | }) 258 | } 259 | 260 | func TestBuilder_InAll(t *testing.T) { 261 | RunWithSchema(t, func(t *testing.T) { 262 | insert(1) 263 | insert(2) 264 | insert(3) 265 | insert(4) 266 | insert(5) 267 | 268 | user := make([]*models.Users, 0) 269 | err := Model(&user).Where("status = ? and id in(?)", 1, []int{1, 3, 4}).All() 270 | 271 | if err != nil { 272 | t.Error(err) 273 | } 274 | 275 | fmt.Println(jsonEncode(user)) 276 | }) 277 | } 278 | 279 | func TestBuilder_Update(t *testing.T) { 280 | RunWithSchema(t, func(t *testing.T) { 281 | insert(1) 282 | 283 | { 284 | user := &models.Users{ 285 | Name: "test2", 286 | } 287 | 288 | affected, err := Model(user).Where("id=?", 1).Update() 289 | 290 | if err != nil { 291 | t.Error("update user error", err) 292 | } 293 | 294 | if affected == 0 { 295 | t.Errorf("update user affected error: got %d rows", affected) 296 | } 297 | } 298 | 299 | { 300 | user := &models.Users{ 301 | Id: 1, 302 | Name: "test3", 303 | } 304 | 305 | affected, err := Model(user).Update() 306 | 307 | if err != nil { 308 | t.Error("update user error", err) 309 | } 310 | 311 | if affected == 0 { 312 | t.Errorf("update user affected error: got %d rows", affected) 313 | } 314 | } 315 | }) 316 | } 317 | 318 | func TestBuilder_Delete(t *testing.T) { 319 | RunWithSchema(t, func(t *testing.T) { 320 | { 321 | insert(1) 322 | affected, err := Model(&models.Users{}).Where("id=?", 1).Delete() 323 | 324 | if err != nil { 325 | t.Error("delete user error", err) 326 | } 327 | 328 | if affected ==
0 { 329 | t.Errorf("delete user affected error: got %d rows", affected) 330 | } 331 | } 332 | { 333 | insert(1) 334 | affected, err := Model(&models.Users{Id: 1}).Delete() 335 | 336 | if err != nil { 337 | t.Error("delete user error", err) 338 | } 339 | 340 | if affected == 0 { 341 | t.Errorf("delete user affected error: got %d rows", affected) 342 | } 343 | } 344 | 345 | { 346 | insertStatus(1, 0) 347 | insertStatus(2, 0) 348 | insertStatus(3, 0) 349 | 350 | affected, err := Model(&models.Users{Status: 0}).Delete("status") 351 | 352 | if err != nil { 353 | t.Error("delete user error", err) 354 | } 355 | 356 | if affected != 3 { 357 | t.Errorf("delete user affected error: got %d rows, want 3", affected) 358 | } 359 | } 360 | }) 361 | } 362 | 363 | func TestBuilder_Count(t *testing.T) { 364 | RunWithSchema(t, func(t *testing.T) { 365 | insert(1) 366 | { 367 | num, err := Model(&models.Users{}).Count() 368 | 369 | if err != nil { 370 | t.Error(err) 371 | } 372 | 373 | if num != 1 { 374 | t.Errorf("count user error: got %d, want 1", num) 375 | } 376 | } 377 | 378 | { 379 | insertStatus(2, 0) 380 | insertStatus(3, 0) 381 | 382 | num, err := Model(&models.Users{Status: 0}).Count("status") 383 | 384 | if err != nil { 385 | t.Error(err) 386 | } 387 | 388 | if num != 2 { 389 | t.Errorf("count user error: got %d, want 2", num) 390 | } 391 | } 392 | }) 393 | } 394 | 395 | func TestBuilder_Create(t *testing.T) { 396 | RunWithSchema(t, func(t *testing.T) { 397 | user := &models.Users{ 398 | //Id: 1, 399 | Name: "test", 400 | } 401 | id, err := Model(user).Create() 402 | 403 | if err != nil { 404 | t.Error(err) 405 | } 406 | 407 | if id != 1 { 408 | t.Error("lastInsertId error", id) 409 | } 410 | 411 | if int(id) != user.Id { 412 | t.Error("fill primaryKey error", id) 413 | } 414 | }) 415 | } 416 | 417 | func TestBuilder_Limit(t *testing.T) { 418 | RunWithSchema(t, func(t *testing.T) { 419 | insert(1) 420 | insert(2) 421 | insert(3) 422 | user := &models.Users{} 423 | err := Model(user).Limit(1).Get() 424 | 425 | if err != nil { 426 | t.Error(err) 427 | } 428 | }) 429 | } 430 | 431 |
func TestBuilder_Offset(t *testing.T) { 432 | RunWithSchema(t, func(t *testing.T) { 433 | insert(1) 434 | insert(2) 435 | insert(3) 436 | user := &models.Users{} 437 | err := Model(user).Limit(1).Offset(1).Get() 438 | 439 | if err != nil { 440 | t.Error(err) 441 | } 442 | }) 443 | } 444 | 445 | func TestBuilder_OrderBy(t *testing.T) { 446 | RunWithSchema(t, func(t *testing.T) { 447 | insert(1) 448 | insert(2) 449 | insert(3) 450 | user := &models.Users{} 451 | err := Model(user).OrderBy("id desc").Limit(1).Offset(1).Get() 452 | 453 | if err != nil { 454 | t.Error(err) 455 | } 456 | 457 | if user.Id != 2 { 458 | t.Error("order by error") 459 | } 460 | 461 | //fmt.Println(user) 462 | }) 463 | } 464 | 465 | func TestBuilder_Where(t *testing.T) { 466 | RunWithSchema(t, func(t *testing.T) { 467 | insert(1) 468 | insert(2) 469 | insert(3) 470 | user := make([]*models.Users, 0) 471 | err := Model(&user).Where("id in(?,?)", 2, 3).OrderBy("id desc").All() 472 | 473 | if err != nil { 474 | t.Error(err) 475 | } 476 | 477 | if len(user) != 2 { 478 | t.Error("where error") 479 | } 480 | 481 | //fmt.Println(user) 482 | }) 483 | } 484 | 485 | func TestBuilder_NullString(t *testing.T) { 486 | RunWithSchema(t, func(t *testing.T) { 487 | ct, _ := time.Parse("2006-01-02 15:04:05", "2018-09-02 00:00:00") 488 | { 489 | user := &models.Users{ 490 | Id: 1, 491 | Name: "test", 492 | Status: 1, 493 | SuccessTime: sql.NullString{ 494 | String: "2018-09-03 00:00:00", 495 | Valid: true, 496 | }, 497 | CreatedAt: ct, 498 | } 499 | _, err := Model(user).Create() 500 | if err != nil { 501 | t.Fatal(err) 502 | } 503 | } 504 | 505 | { 506 | user := &models.Users{} 507 | err := Model(user).Where("id=1").Get() 508 | 509 | if err != nil { 510 | t.Error(err) 511 | } 512 | 513 | fmt.Println(jsonEncode(user)) 514 | } 515 | 516 | { 517 | user := &models.Users{ 518 | Id: 1, 519 | SuccessTime: sql.NullString{ 520 | String: "2018-09-03 00:00:00", 521 | Valid: true, 522 | }, 523 | CreatedAt: ct, 524 | }
525 | 526 | err := Model(user).Get() 527 | 528 | if err != nil { 529 | t.Error(err) 530 | } 531 | 532 | fmt.Println(jsonEncode(user)) 533 | } 534 | }) 535 | } 536 | 537 | func TestBuilder_Relation1(t *testing.T) { 538 | RunWithSchema(t, func(t *testing.T) { 539 | initDatas(t) 540 | moment := &MomentList{} 541 | err := Model(moment).Relation("User", func(b *Builder) { 542 | b.Where("status = 1") 543 | }).Where("status = 1 and id = ?", 14).Get() 544 | 545 | b, _ := json.MarshalIndent(moment, "", " ") 546 | fmt.Println(string(b), err) 547 | 548 | if err != nil { 549 | t.Fatal(err) 550 | } 551 | }) 552 | } 553 | 554 | func TestBuilder_Relation2(t *testing.T) { 555 | RunWithSchema(t, func(t *testing.T) { 556 | var moments = make([]*MomentList, 0) 557 | err := Model(&moments).Relation("User", func(b *Builder) { 558 | b.Where("status = 0") 559 | }).Where("status = 1").Limit(10).All() 560 | 561 | b, _ := json.MarshalIndent(moments, "", " ") 562 | fmt.Println(string(b), err) 563 | 564 | if err != nil { 565 | t.Fatal(err) 566 | } 567 | }) 568 | } 569 | -------------------------------------------------------------------------------- /model_wrapper.go: -------------------------------------------------------------------------------- 1 | package gosql 2 | 3 | // ModelWrapper bundles a model with a set of named connections, so that 4 | // relation loading can resolve a field's `connection` tag through 5 | // GetRelationDB instead of the package-level connections. 6 | type ModelWrapper struct { 7 | dbList map[string]*DB 8 | model interface{} 9 | } 10 | 11 | // ModelWrapperFactory is a function that wraps the model m in a ModelWrapper. 12 | type ModelWrapperFactory func(m interface{}) *ModelWrapper 13 | 14 | // NewModelWrapper wraps model together with dbList, a map from connection 15 | // name to the *DB that relations using that name should query. 16 | func NewModelWrapper(dbList map[string]*DB, model interface{}) *ModelWrapper { 17 | return &ModelWrapper{dbList: dbList, model: model} 18 | } 19 | 20 | // GetRelationDB returns the connection registered under connect, or nil if 21 | // dbList has no entry with that name. 22 | func (m *ModelWrapper) GetRelationDB(connect string) *DB { 23 | return m.dbList[connect] 24 | } 25 | 26 | // UnWrap returns the wrapped model. 27 | func (m *ModelWrapper) UnWrap() interface{} { 28 | return m.model 29 | } 30 | -------------------------------------------------------------------------------- /reflect.go: -------------------------------------------------------------------------------- 1 | package gosql 2 | 3 | import ( 4 | "reflect" 5 | 6 |
"github.com/jmoiron/sqlx/reflectx" 7 | ) 8 | 9 | // ReflectMapper resolves struct fields by their mapped column name, using sqlx's reflectx.Mapper. 10 | type ReflectMapper struct { 11 | mapper *reflectx.Mapper 12 | } 13 | 14 | // NewReflectMapper returns a ReflectMapper whose column names come from the tagName struct tag (for example "db"). 15 | func NewReflectMapper(tagName string) *ReflectMapper { 16 | return &ReflectMapper{ 17 | mapper: reflectx.NewMapper(tagName), 18 | } 19 | } 20 | 21 | // FieldByName returns a field by its mapped name as a reflect.Value. 22 | // Panics if v's Kind is not Struct or v is not Indirectable to a struct Kind. 23 | // Returns zero Value if the name is not found. 24 | func (r *ReflectMapper) FieldByName(v reflect.Value, name string) reflect.Value { 25 | return r.mapper.FieldByName(v, name) 26 | } 27 | 28 | // FieldMap returns the mapper's mapping of column names to reflect values for 29 | // the columns of v itself, including those promoted from embedded structs. 30 | // Panics if v's Kind is not Struct, or v is not Indirectable to a struct kind. 31 | func (r *ReflectMapper) FieldMap(v reflect.Value) map[string]reflect.Value { 32 | v = reflect.Indirect(v) 33 | 34 | tm := r.mapper.TypeMap(v.Type()) 35 | ret := make(map[string]reflect.Value, len(tm.Names)) 36 | for tagName, fi := range tm.Names { 37 | // Skip fields whose parent is a non-embedded struct field (e.g. the members 38 | // of sql.NullString, which is stored as a single column) and pointer-to-struct 39 | // fields of non-embedded parents; fields of embedded structs are kept. 40 | if (fi.Parent.Zero.Kind() == reflect.Struct || (fi.Zero.Kind() == reflect.Ptr && fi.Zero.Type().Elem().Kind() == reflect.Struct)) && !fi.Parent.Field.Anonymous { 41 | continue 42 | } 43 | ret[tagName] = reflectx.FieldByIndexes(v, fi.Index) 44 | } 45 | 46 | return ret 47 | } 48 | -------------------------------------------------------------------------------- /reflect_test.go: -------------------------------------------------------------------------------- 1 | package gosql 2 | 3 | import ( 4 | "database/sql" 5 | "reflect" 6 | "testing" 7 | "time" 8 | 9 | "github.com/ilibs/gosql/v2/internal/example/models" 10 | ) 11 | 12 | func TestReflectMapper_FieldMap(t *testing.T) { 13 | mapper := NewReflectMapper("db") 14 | 15 | { 16 | user := &models.Users{ 17 | Id: 1, 18 | Name: "test", 19 | SuccessTime: sql.NullString{ 20 | String: "2018-09-03 00:00:00", 21 | Valid: false, 22 | }, 23 | CreatedAt:
time.Now(), 24 | UpdatedAt: time.Now(), 25 | } 26 | 27 | fields := mapper.FieldMap(reflect.ValueOf(user)) 28 | if len(fields) != 6 { 29 | t.Error("FieldMap length error") 30 | } 31 | 32 | if v := fields["name"].Interface().(string); v != user.Name { 33 | t.Errorf("Expecting %s, got %s", user.Name, v) 34 | } 35 | 36 | if v := fields["success_time"].Interface().(sql.NullString).String; v != user.SuccessTime.String { 37 | t.Errorf("Expecting %s, got %s", user.Name, v) 38 | } 39 | } 40 | 41 | { 42 | photos := &models.Photos{} 43 | fields := mapper.FieldMap(reflect.ValueOf(photos)) 44 | 45 | if len(fields) != 5 { 46 | t.Error("FieldMap length error") 47 | } 48 | 49 | if v := fields["url"].Interface().(string); v != photos.Url { 50 | t.Errorf("Expecting %s, got %s", photos.Url, v) 51 | } 52 | 53 | if _, ok := fields["created_at"].Interface().(time.Time); !ok { 54 | t.Error("Expecting true, got false") 55 | } 56 | 57 | //fmt.Println(fields) 58 | } 59 | } 60 | -------------------------------------------------------------------------------- /relation.go: -------------------------------------------------------------------------------- 1 | package gosql 2 | 3 | import ( 4 | "database/sql" 5 | "errors" 6 | "fmt" 7 | "reflect" 8 | "strings" 9 | ) 10 | 11 | func eachField(t reflect.Type, fn func(field reflect.StructField, val string, name string, relations []string, connection string) error) error { 12 | for i := 0; i < t.NumField(); i++ { 13 | val := t.Field(i).Tag.Get("relation") 14 | connection := t.Field(i).Tag.Get("connection") 15 | name := t.Field(i).Name 16 | field := t.Field(i) 17 | 18 | if val != "" && val != "-" { 19 | relations := strings.Split(val, ",") 20 | if len(relations) != 2 { 21 | return errors.New(fmt.Sprintf("relation tag error, length must 2,but get %v", relations)) 22 | } 23 | 24 | err := fn(field, val, name, relations, connection) 25 | if err != nil { 26 | return err 27 | } 28 | } 29 | } 30 | return nil 31 | } 32 | 33 | func newModel(value reflect.Value, 
connection string) *Builder { 34 | var m *Builder 35 | if connection != "" { 36 | m = Use(connection).Model(value.Interface()) 37 | } else { 38 | m = Model(value.Interface()) 39 | } 40 | 41 | return m 42 | } 43 | 44 | func newModelWithWrapper(wrapper *ModelWrapper, defaultDb *DB, value reflect.Value, connection string) *Builder { 45 | var m *Builder 46 | if connection != "" { 47 | var relationDb *DB 48 | if wrapper != nil { // if wrapper model 49 | relationDb = wrapper.GetRelationDB(connection) 50 | } else { 51 | // if not wrapper, so use 52 | relationDb = Use(connection) 53 | } 54 | m = relationDb.Model(value.Interface()) 55 | } else { 56 | // if connection is null,so db is default link 57 | m = defaultDb.Model(value.Interface()) 58 | } 59 | return m 60 | } 61 | 62 | // RelationOne is get the associated relational data for a single piece of data 63 | func RelationOne(wrapper *ModelWrapper, db *DB, data interface{}) error { 64 | refVal := reflect.Indirect(reflect.ValueOf(data)) 65 | t := refVal.Type() 66 | 67 | return eachField(t, func(field reflect.StructField, val string, name string, relations []string, connection string) error { 68 | var foreignModel reflect.Value 69 | // if field type is slice then one-to-many ,eg: []*Struct 70 | if field.Type.Kind() == reflect.Slice { 71 | foreignModel = reflect.New(field.Type) 72 | // m := newModel(foreignModel, connection) 73 | m := newModelWithWrapper(wrapper, db, foreignModel, connection) 74 | 75 | if chainFn, ok := db.RelationMap[name]; ok { 76 | chainFn(m) 77 | } 78 | 79 | // batch get field values 80 | // Since the structure is slice, there is no need to new Value 81 | err := m.Where(fmt.Sprintf("%s=%s", relations[1], m.dialect.Placeholder()), mapper.FieldByName(refVal, relations[0]).Interface()).All() 82 | if err != nil { 83 | return err 84 | } 85 | 86 | if reflect.Indirect(foreignModel).Len() == 0 { 87 | // If relation data is empty, must set empty slice 88 | // Otherwise, the JSON result will be null instead of [] 
89 | refVal.FieldByName(name).Set(reflect.MakeSlice(field.Type, 0, 0)) 90 | } else { 91 | refVal.FieldByName(name).Set(foreignModel.Elem()) 92 | } 93 | 94 | } else { 95 | // If field type is struct the one-to-one,eg: *Struct 96 | foreignModel = reflect.New(field.Type.Elem()) 97 | // m := newModel(foreignModel, connection) 98 | m := newModelWithWrapper(wrapper, db, foreignModel, connection) 99 | if chainFn, ok := db.RelationMap[name]; ok { 100 | chainFn(m) 101 | } 102 | 103 | err := m.Where(fmt.Sprintf("%s=%s", relations[1], m.dialect.Placeholder()), mapper.FieldByName(refVal, relations[0]).Interface()).Get() 104 | // If one-to-one NoRows is not an error that needs to be terminated 105 | if err != nil && err != sql.ErrNoRows { 106 | return err 107 | } 108 | 109 | if err == nil { 110 | refVal.FieldByName(name).Set(foreignModel) 111 | } 112 | } 113 | return nil 114 | }) 115 | } 116 | 117 | // RelationAll is gets the associated relational data for multiple pieces of data 118 | func RelationAll(wrapper *ModelWrapper, db *DB, data interface{}) error { 119 | refVal := reflect.Indirect(reflect.ValueOf(data)) 120 | 121 | l := refVal.Len() 122 | 123 | if l == 0 { 124 | return nil 125 | } 126 | 127 | // get the struct field in slice 128 | t := reflect.Indirect(refVal.Index(0)).Type() 129 | 130 | return eachField(t, func(field reflect.StructField, val string, name string, relations []string, connection string) error { 131 | relVals := make([]interface{}, 0) 132 | relValsMap := make(map[interface{}]interface{}, 0) 133 | 134 | // get relation field values and unique 135 | for j := 0; j < l; j++ { 136 | v := mapper.FieldByName(refVal.Index(j), relations[0]).Interface() 137 | relValsMap[v] = nil 138 | } 139 | 140 | for k, _ := range relValsMap { 141 | relVals = append(relVals, k) 142 | } 143 | 144 | var foreignModel reflect.Value 145 | // if field type is slice then one to many ,eg: []*Struct 146 | if field.Type.Kind() == reflect.Slice { 147 | foreignModel = 
reflect.New(field.Type) 148 | // m := newModel(foreignModel, connection) 149 | m := newModelWithWrapper(wrapper, db, foreignModel, connection) 150 | if chainFn, ok := db.RelationMap[name]; ok { 151 | chainFn(m) 152 | } 153 | 154 | // batch get field values 155 | // Since the structure is slice, there is no need to new Value 156 | err := m.Where(fmt.Sprintf("%s in(%s)", relations[1], m.dialect.Placeholder()), relVals).All() 157 | if err != nil { 158 | return err 159 | } 160 | 161 | fmap := make(map[interface{}]reflect.Value) 162 | 163 | // Combine relation data as a one-to-many relation 164 | // For example, if there are multiple images under an article 165 | // we use the article ID to associate the images, map[1][]*Images 166 | for n := 0; n < reflect.Indirect(foreignModel).Len(); n++ { 167 | val := reflect.Indirect(foreignModel).Index(n) 168 | fid := mapper.FieldByName(val, relations[1]) 169 | if _, has := fmap[fid.Interface()]; !has { 170 | fmap[fid.Interface()] = reflect.New(reflect.SliceOf(field.Type.Elem())).Elem() 171 | } 172 | fmap[fid.Interface()] = reflect.Append(fmap[fid.Interface()], val) 173 | } 174 | 175 | // Set the result to the model 176 | for j := 0; j < l; j++ { 177 | fid := mapper.FieldByName(refVal.Index(j), relations[0]) 178 | if value, has := fmap[fid.Interface()]; has { 179 | reflect.Indirect(refVal.Index(j)).FieldByName(name).Set(value) 180 | } else { 181 | // If relation data is empty, must set empty slice 182 | // Otherwise, the JSON result will be null instead of [] 183 | reflect.Indirect(refVal.Index(j)).FieldByName(name).Set(reflect.MakeSlice(field.Type, 0, 0)) 184 | } 185 | } 186 | } else { 187 | // If field type is struct the one to one,eg: *Struct 188 | foreignModel = reflect.New(field.Type.Elem()) 189 | 190 | // Batch get field values, but must new slice []*Struct 191 | fi := reflect.New(reflect.SliceOf(foreignModel.Type())) 192 | // m := newModel(fi, connection) 193 | m := newModelWithWrapper(wrapper, db, fi, connection) 194 | 195 
| if chainFn, ok := db.RelationMap[name]; ok { 196 | chainFn(m) 197 | } 198 | 199 | // TODO sqlx.In maybe not support postgres 200 | err := m.Where(fmt.Sprintf("%s in(%s)", relations[1], m.dialect.Placeholder()), relVals).All() 201 | if err != nil { 202 | return err 203 | } 204 | 205 | // Combine relation data as a one-to-one relation 206 | fmap := make(map[interface{}]reflect.Value) 207 | for n := 0; n < reflect.Indirect(fi).Len(); n++ { 208 | val := reflect.Indirect(fi).Index(n) 209 | fid := mapper.FieldByName(val, relations[1]) 210 | fmap[fid.Interface()] = val 211 | } 212 | 213 | // Set the result to the model 214 | for j := 0; j < l; j++ { 215 | fid := mapper.FieldByName(refVal.Index(j), relations[0]) 216 | if value, has := fmap[fid.Interface()]; has { 217 | reflect.Indirect(refVal.Index(j)).FieldByName(name).Set(value) 218 | } 219 | } 220 | } 221 | 222 | return nil 223 | }) 224 | } 225 | -------------------------------------------------------------------------------- /relation_test.go: -------------------------------------------------------------------------------- 1 | package gosql 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/ilibs/gosql/v2/internal/example/models" 7 | ) 8 | 9 | type UserMomentRelation struct { 10 | models.Users 11 | Moments []*models.Moments `json:"moments" db:"-" relation:"id,user_id" connection:"db2"` 12 | } 13 | 14 | func TestRelationOneWithRelationDB(t *testing.T) { 15 | RunWithSchema(t, func(t *testing.T) { 16 | initDatas(t) 17 | moment := &UserMomentRelation{} 18 | // Use("default").Model() 19 | err := Use("default").Model(NewModelWrapper(map[string]*DB{ 20 | "default": Use("default"), 21 | "db2": Use("db2"), 22 | }, moment)).Relation("Moments", func(b *Builder) { 23 | b.Limit(2) 24 | }).Where("id = ?", 5).Get() 25 | 26 | // b, _ := json.MarshalIndent(moment, "", " ") 27 | // fmt.Println(string(b), err) 28 | 29 | if err != nil { 30 | t.Fatal(err) 31 | } 32 | }) 33 | } 34 | 35 | type UserMoment struct { 36 | models.Users 37 | 
Moments []*models.Moments `json:"moments" db:"-" relation:"id,user_id" connection:"db2"` 38 | } 39 | 40 | func TestRelationOne2(t *testing.T) { 41 | RunWithSchema(t, func(t *testing.T) { 42 | initDatas(t) 43 | moment := &UserMoment{} 44 | err := Model(moment).Relation("Moments", func(b *Builder) { 45 | b.Limit(2) 46 | }).Where("id = ?", 5).Get() 47 | 48 | // b, _ := json.MarshalIndent(moment, "", " ") 49 | // fmt.Println(string(b), err) 50 | 51 | if err != nil { 52 | t.Fatal(err) 53 | } 54 | }) 55 | } 56 | 57 | type MomentList struct { 58 | models.Moments 59 | User *models.Users `json:"user" db:"-" relation:"user_id,id"` 60 | Photos []*models.Photos `json:"photos" db:"-" relation:"id,moment_id" connection:"db2"` 61 | } 62 | 63 | func TestRelationOne(t *testing.T) { 64 | RunWithSchema(t, func(t *testing.T) { 65 | initDatas(t) 66 | 67 | moment := &MomentList{} 68 | err := Model(moment).Where("status = 1 and id = ?", 14).Get() 69 | 70 | // b, _ := json.MarshalIndent(moment, "", " ") 71 | // fmt.Println(string(b), err) 72 | 73 | if err != nil { 74 | t.Fatal(err) 75 | } 76 | 77 | if moment.User == nil || moment.User.Name == "" { 78 | t.Fatal("relation one-to-one data error[user]") 79 | } 80 | 81 | if len(moment.Photos) == 0 { 82 | t.Fatal("relation get one-to-many data error[photos]") 83 | } 84 | }) 85 | } 86 | 87 | func TestRelationAll(t *testing.T) { 88 | RunWithSchema(t, func(t *testing.T) { 89 | initDatas(t) 90 | 91 | var moments = make([]*MomentList, 0) 92 | err := Model(&moments).Where("status = 1").Limit(10).All() 93 | if err != nil { 94 | t.Fatal(err) 95 | } 96 | 97 | // b, _ := json.MarshalIndent(moments, "", " ") 98 | // fmt.Println(string(b), err) 99 | 100 | if len(moments) == 0 { 101 | t.Fatal("relation get many-to-many data error[moments]") 102 | } 103 | 104 | if moments[0].User == nil || moments[0].User.Name == "" { 105 | t.Fatal("relation get many-to-many data error[user]") 106 | } 107 | 108 | if len(moments[0].Photos) == 0 { 109 | t.Fatal("relation get many-to-many data error[photos]") 110 | } 111 |
}) 112 | } 113 | 114 | type MomentListWrapper struct { 115 | models.Moments 116 | User *models.Users `json:"user" db:"-" relation:"user_id,id"` 117 | Photos []*models.Photos `json:"photos" db:"-" relation:"id,moment_id" connection:"db2"` 118 | } 119 | 120 | func TestRelationModelWrapper(t *testing.T) { 121 | RunWithSchema(t, func(t *testing.T) { 122 | initDatas(t) 123 | var moments = make([]*MomentListWrapper, 0) 124 | err := Use("default").Model(NewModelWrapper(map[string]*DB{ 125 | "default": Use("default"), 126 | "db2": Use("db2"), 127 | }, &moments)).Where("status = 1").Limit(10).All() 128 | if err != nil { 129 | t.Fatal(err) 130 | } 131 | 132 | // b, _ := json.MarshalIndent(moments, "", " ") 133 | // fmt.Println(string(b), err) 134 | 135 | if len(moments) == 0 { 136 | t.Fatal("relation get many-to-many data error[moments]") 137 | } 138 | 139 | if moments[0].User == nil || moments[0].User.Name == "" { 140 | t.Fatal("relation get many-to-many data error[user]") 141 | } 142 | 143 | if len(moments[0].Photos) == 0 { 144 | t.Fatal("relation get many-to-many data error[photos]") 145 | } 146 | }) 147 | } 148 | -------------------------------------------------------------------------------- /sql_builder.go: -------------------------------------------------------------------------------- 1 | package gosql 2 | 3 | import ( 4 | "fmt" 5 | "strings" 6 | ) 7 | 8 | // SQLBuilder collects the parts of a single-table statement and renders them 9 | // as SELECT, SELECT count(*), INSERT, UPDATE or DELETE SQL, using dialect for 10 | // identifier quoting and placeholders. 11 | type SQLBuilder struct { 12 | dialect Dialect 13 | fields string 14 | table string 15 | forceIndex string 16 | where string 17 | order string 18 | limit string 19 | offset string 20 | hint string 21 | // Extra args to be substituted in the *where* clause; insertString and 22 | // updateString also add their values here, in statement order. 23 | args []interface{} 24 | } 25 | 26 | // limitFormat renders the LIMIT clause, or "" when no limit is set. 27 | func (s *SQLBuilder) limitFormat() string { 28 | if s.limit != "" { 29 | return fmt.Sprintf("LIMIT %s", s.limit) 30 | } 31 | return "" 32 | } 33 | 34 | // offsetFormat renders the OFFSET clause, or "" when no offset is set. 35 | func (s *SQLBuilder) offsetFormat() string { 36 | if s.offset != "" { 37 | return fmt.Sprintf("OFFSET %s", s.offset) 38 | } 39 | return "" 40 | } 41 | 42 | // orderFormat renders the ORDER BY clause, or "" when no order is set. 43 | func (s *SQLBuilder) orderFormat() string {
44 | if s.order != "" { 45 | return fmt.Sprintf("ORDER BY %s", s.order) 46 | } 47 | return "" 48 | } 49 | 50 | // queryString Assemble the query statement. An empty field list becomes "*" 51 | // (and is stored back into s.fields); blanks left by unset trailing clauses 52 | // are trimmed before the terminating ";". 53 | func (s *SQLBuilder) queryString() string { 54 | if s.fields == "" { 55 | s.fields = "*" 56 | } 57 | 58 | table := s.dialect.Quote(s.table) 59 | if s.forceIndex != "" { 60 | table += fmt.Sprintf(" force index(%s)", s.forceIndex) 61 | } 62 | 63 | query := fmt.Sprintf("%sSELECT %s FROM %s %s %s %s %s", s.hint, s.fields, table, s.where, s.orderFormat(), s.limitFormat(), s.offsetFormat()) 64 | query = strings.TrimRight(query, " ") 65 | query = query + ";" 66 | 67 | return query 68 | } 69 | 70 | // countString Assemble the count statement for the current where clause. 71 | func (s *SQLBuilder) countString() string { 72 | query := fmt.Sprintf("%sSELECT count(*) FROM %s %s", s.hint, s.dialect.Quote(s.table), s.where) 73 | query = strings.TrimRight(query, " ") 74 | query = query + ";" 75 | 76 | return query 77 | } 78 | 79 | // insertString Assemble the insert statement. Columns are emitted in 80 | // sortedParamKeys order and each value is appended to s.args. 81 | func (s *SQLBuilder) insertString(params map[string]interface{}) string { 82 | keys := sortedParamKeys(params) 83 | cols := make([]string, 0, len(keys)) 84 | vals := make([]string, 0, len(keys)) 85 | for _, k := range keys { 86 | cols = append(cols, s.dialect.Quote(k)) 87 | vals = append(vals, s.dialect.Placeholder()) 88 | s.args = append(s.args, params[k]) 89 | } 90 | 91 | return fmt.Sprintf("INSERT INTO %s (%s) VALUES(%s);", s.dialect.Quote(s.table), strings.Join(cols, ","), strings.Join(vals, ",")) 92 | } 93 | 94 | // updateString Assemble the update statement. A *expr value is inlined as its 95 | // expression (contributing its own args); any other value becomes a 96 | // placeholder. The SET args are placed before the where args already in 97 | // s.args, matching their order in the statement text. 98 | func (s *SQLBuilder) updateString(params map[string]interface{}) string { 99 | keys := sortedParamKeys(params) 100 | updateFields := make([]string, 0, len(keys)) 101 | args := make([]interface{}, 0, len(keys)+len(s.args)) 102 | 103 | for _, k := range keys { 104 | if e, ok := params[k].(*expr); ok { 105 | updateFields = append(updateFields, fmt.Sprintf("%s=%s", s.dialect.Quote(k), e.expr)) 106 | args = append(args, e.args...)
107 | } else { 108 | updateFields = append(updateFields, fmt.Sprintf("%s=%s", s.dialect.Quote(k), s.dialect.Placeholder())) 109 | args = append(args, params[k]) 110 | } 111 | } 112 | s.args = append(args, s.args...) 113 | 114 | query := fmt.Sprintf("UPDATE %s SET %s %s", s.dialect.Quote(s.table), strings.Join(updateFields, ","), s.where) 115 | query = strings.TrimRight(query, " ") 116 | query = query + ";" 117 | return query 118 | } 119 | 120 | // deleteString Assemble the delete statement for the current where clause. 121 | func (s *SQLBuilder) deleteString() string { 122 | query := fmt.Sprintf("DELETE FROM %s %s", s.dialect.Quote(s.table), s.where) 123 | query = strings.TrimRight(query, " ") 124 | query = query + ";" 125 | return query 126 | } 127 | 128 | // Where adds str to the where clause. Each condition is wrapped in 129 | // parentheses and successive calls are joined with AND; args are appended 130 | // to s.args. 131 | func (s *SQLBuilder) Where(str string, args ...interface{}) { 132 | if s.where != "" { 133 | s.where = fmt.Sprintf("%s AND (%s)", s.where, str) 134 | } else { 135 | s.where = fmt.Sprintf("WHERE (%s)", str) 136 | } 137 | 138 | // NB this assumes that args are only supplied for where clauses 139 | // this may be an incorrect assumption! 140 | // Always copy with append: storing args directly would let s.args share the 141 | // backing array of a slice passed as Where(str, vals...), so a later append 142 | // to s.args could overwrite the caller's data. 143 | if len(args) > 0 { 144 | s.args = append(s.args, args...) 145 | } 146 | } 147 | -------------------------------------------------------------------------------- /sql_builder_test.go: -------------------------------------------------------------------------------- 1 | package gosql 2 | 3 | import ( 4 | "fmt" 5 | "testing" 6 | ) 7 | 8 | func TestSQLBuilder_queryString(t *testing.T) { 9 | 10 | b := &SQLBuilder{ 11 | dialect: mustGetDialect("mysql"), 12 | table: "users", 13 | order: "id desc", 14 | limit: "0", 15 | offset: "10", 16 | } 17 | 18 | b.Where("id = ?", 1) 19 | 20 | if b.queryString() != "SELECT * FROM `users` WHERE (id = ?)
ORDER BY id desc LIMIT 0 OFFSET 10;" { 21 | t.Error("sql builder query error", b.queryString()) 22 | } 23 | fmt.Println(b.queryString()) 24 | } 25 | 26 | func TestSQLBuilder_queryForceIndexString(t *testing.T) { 27 | 28 | b := &SQLBuilder{ 29 | dialect: mustGetDialect("mysql"), 30 | table: "users", 31 | order: "id desc", 32 | forceIndex: "idx_user", 33 | limit: "0", 34 | offset: "10", 35 | } 36 | 37 | b.Where("id = ?", 1) 38 | 39 | if b.queryString() != "SELECT * FROM `users` force index(idx_user) WHERE (id = ?) ORDER BY id desc LIMIT 0 OFFSET 10;" { 40 | t.Error("sql builder query error", b.queryString()) 41 | } 42 | fmt.Println(b.queryString()) 43 | } 44 | 45 | func TestSQLBuilder_insertString(t *testing.T) { 46 | 47 | b := &SQLBuilder{ 48 | dialect: mustGetDialect("mysql"), 49 | table: "users", 50 | } 51 | 52 | query := b.insertString(map[string]interface{}{ 53 | "id": 1, 54 | "name": "test", 55 | "email": "test@test.com", 56 | "created_at": "2018-07-11 11:58:21", 57 | "updated_at": "2018-07-11 11:58:21", 58 | }) 59 | 60 | if query != "INSERT INTO `users` (`created_at`,`email`,`id`,`name`,`updated_at`) VALUES(?,?,?,?,?);" { 61 | t.Error("sql builder insert error", query) 62 | } 63 | } 64 | 65 | func TestSQLBuilder_updateString(t *testing.T) { 66 | 67 | b := &SQLBuilder{ 68 | dialect: mustGetDialect("mysql"), 69 | table: "users", 70 | } 71 | 72 | b.Where("id = ?", 1) 73 | query := b.updateString(map[string]interface{}{ 74 | "name": "test", 75 | "email": "test@test.com", 76 | }) 77 | 78 | if query != "UPDATE `users` SET `email`=?,`name`=? 
WHERE (id = ?);" { 79 | t.Error("sql builder update error", query) 80 | } 81 | 82 | fmt.Println(query, b.args) 83 | } 84 | 85 | func TestSQLBuilder_deleteString(t *testing.T) { 86 | 87 | b := &SQLBuilder{ 88 | dialect: mustGetDialect("mysql"), 89 | table: "users", 90 | } 91 | 92 | b.Where("id = ?", 1) 93 | 94 | query := b.deleteString() 95 | 96 | if query != "DELETE FROM `users` WHERE (id = ?);" { 97 | t.Error("sql builder delete error", query) 98 | } 99 | } 100 | 101 | func TestSQLBuilder_countString(t *testing.T) { 102 | 103 | b := &SQLBuilder{ 104 | dialect: mustGetDialect("mysql"), 105 | table: "users", 106 | } 107 | b.Where("id = ?", 1) 108 | 109 | query := b.countString() 110 | 111 | if query != "SELECT count(*) FROM `users` WHERE (id = ?);" { 112 | t.Error("sql builder count error", query) 113 | } 114 | } 115 | 116 | func TestSQLBuilder_Dialect(t *testing.T) { 117 | testData := map[string]string{ 118 | "mysql": "INSERT INTO `users` (`created_at`,`email`,`id`,`name`,`updated_at`) VALUES(?,?,?,?,?);", 119 | "postgres": `INSERT INTO "users" ("created_at","email","id","name","updated_at") VALUES($1,$2,$3,$4,$5);`, 120 | "sqlite3": `INSERT INTO "users" ("created_at","email","id","name","updated_at") VALUES(?,?,?,?,?);`, 121 | } 122 | 123 | for k, v := range testData { 124 | b := &SQLBuilder{ 125 | dialect: mustGetDialect(k), 126 | table: "users", 127 | } 128 | 129 | query := b.insertString(map[string]interface{}{ 130 | "id": 1, 131 | "name": "test", 132 | "email": "test@test.com", 133 | "created_at": "2018-07-11 11:58:21", 134 | "updated_at": "2018-07-11 11:58:21", 135 | }) 136 | 137 | if query != v { 138 | t.Error(fmt.Sprintf("sql builder %s dialect insert error", k), query) 139 | } 140 | } 141 | } 142 | -------------------------------------------------------------------------------- /testdata/test.sql: -------------------------------------------------------------------------------- 1 | # Dump of table moments 2 | # 
# ------------------------------------------------------------

DROP TABLE IF EXISTS `moments`;

CREATE TABLE `moments` (
  `id` int(11) unsigned NOT NULL AUTO_INCREMENT,
  `user_id` int(11) NOT NULL COMMENT '成员ID',
  `content` text NOT NULL COMMENT '日记内容',
  `comment_total` int(11) NOT NULL DEFAULT '0' COMMENT '评论总数',
  `like_total` int(11) NOT NULL DEFAULT '0' COMMENT '点赞数',
  `status` int(11) NOT NULL DEFAULT '1' COMMENT '1 正常 2删除',
  `created_at` datetime NOT NULL COMMENT '创建时间',
  `updated_at` datetime NOT NULL COMMENT '更新时间',
  PRIMARY KEY (`id`)
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;

LOCK TABLES `moments` WRITE;
/*!40000 ALTER TABLE `moments` DISABLE KEYS */;

INSERT INTO `moments` (`id`, `user_id`, `content`, `comment_total`, `like_total`, `status`, `created_at`, `updated_at`)
VALUES
	(1,5,'sdfsdfsdfsdfsdf',0,0,1,'2018-11-28 14:04:02','2018-11-28 14:04:02'),
	(2,5,'sdfsdfsdfsdfsdf',0,0,1,'2018-11-28 17:14:23','2018-11-28 17:14:23'),
	(3,6,'123123123',0,0,1,'2018-11-28 17:19:38','2018-11-28 17:19:38'),
	(4,5,'13212312313',0,0,1,'2018-11-28 17:22:25','2018-11-28 17:22:25'),
	(5,5,'123123123123',0,0,1,'2018-11-28 17:24:21','2018-11-28 17:24:21'),
	(6,6,'131231232345tasvdf',0,0,1,'2018-11-28 17:24:27','2018-11-28 17:24:27'),
	(7,5,'1231231231231231',0,0,1,'2018-11-28 18:07:48','2018-11-28 18:07:48'),
	(8,5,'1231231231231231',0,0,1,'2018-11-28 18:09:20','2018-11-28 18:09:20'),
	(9,6,'1231231231231231',0,0,1,'2018-11-28 18:11:19','2018-11-28 18:11:19'),
	(10,5,'1231231231231231',0,0,1,'2018-11-28 18:13:52','2018-11-28 18:13:52'),
	(11,5,'1231231231231231',0,0,1,'2018-11-28 18:15:02','2018-11-28 18:15:02'),
	(12,5,'1231231231231231',0,0,1,'2018-11-28 18:15:13','2018-11-28 18:15:13'),
	(13,5,'1231231231231231',0,0,1,'2018-11-28 18:15:39','2018-11-28 18:15:39'),
	(14,6,'开开信息想你',0,0,1,'2018-11-28 18:31:37','2018-11-28 18:31:37'),
	(15,5,'网友们已经开始争相给宝宝取名字了',0,0,1,'2018-11-28 18:34:45','2018-11-28 18:34:45'),
	(16,6,' B2B事业部的对外报价显示',0,0,1,'2018-11-28 18:35:24','2018-11-28 18:35:24');

/*!40000 ALTER TABLE `moments` ENABLE KEYS */;
UNLOCK TABLES;


# Dump of table photos
# ------------------------------------------------------------

DROP TABLE IF EXISTS `photos`;

CREATE TABLE `photos` (
  `id` int(11) unsigned NOT NULL AUTO_INCREMENT,
  `url` varchar(255) NOT NULL DEFAULT '' COMMENT '照片路径',
  `moment_id` int(11) NOT NULL COMMENT '日记ID',
  `created_at` datetime NOT NULL COMMENT '创建时间',
  `updated_at` datetime NOT NULL COMMENT '更新时间',
  PRIMARY KEY (`id`)
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;

LOCK TABLES `photos` WRITE;
/*!40000 ALTER TABLE `photos` DISABLE KEYS */;

INSERT INTO `photos` (`id`, `url`, `moment_id`, `created_at`, `updated_at`)
VALUES
	(1,'https://static.fifsky.com/kids/upload/20181128/5febe3b6e23623168cb70eac39d26412.png!blog',1,'2018-11-28 18:15:39','2018-11-28 18:15:39'),
	(2,'https://static.fifsky.com/kids/upload/20181128/9c60f42f07d7a0e13293c91fc5740c9d.png!blog',10,'2018-11-28 18:15:39','2018-11-28 18:15:39'),
	(3,'https://static.fifsky.com/kids/upload/20181128/458762098fb20128996c9cb21309aa9a.png!blog',1,'2018-11-28 18:31:37','2018-11-28 18:31:37'),
	(4,'https://static.fifsky.com/kids/upload/20181128/5b90c5af1bc35375a08cbc990ed662d1.png!blog',14,'2018-11-28 18:31:37','2018-11-28 18:31:37'),
	(5,'https://static.fifsky.com/kids/upload/20181128/db190be4184774d88abb31521123b14c.png!blog',14,'2018-11-28 18:31:37','2018-11-28 18:31:37'),
	(6,'https://static.fifsky.com/kids/upload/20181128/e1bd15706a79edd1f92f54538622600e.png!blog',14,'2018-11-28 18:31:37','2018-11-28 18:31:37'),
	(7,'https://static.fifsky.com/kids/upload/20181128/6bf495726054fa12ae7e6f5d0d4560a4.png!blog',14,'2018-11-28 18:31:37','2018-11-28 18:31:37'),
	(8,'https://static.fifsky.com/kids/upload/20181128/e1bd15706a79edd1f92f54538622600e.png!blog',15,'2018-11-28 18:34:45','2018-11-28 18:34:45'),
	(9,'https://static.fifsky.com/kids/upload/20181128/c6cc28b912f805b6ef402603e0f67852.png!blog',9,'2018-11-28 18:34:45','2018-11-28 18:34:45'),
	(10,'https://static.fifsky.com/kids/upload/20181128/a63a0798d098272a39e76c88f39f2f29.png!blog',16,'2018-11-28 18:35:24','2018-11-28 18:35:24'),
	(11,'https://static.fifsky.com/kids/upload/20181128/381de1930d970183ab083fe08e2677ac.png!blog',16,'2018-11-28 18:35:24','2018-11-28 18:35:24');

/*!40000 ALTER TABLE `photos` ENABLE KEYS */;
UNLOCK TABLES;


# Dump of table users
# ------------------------------------------------------------

DROP TABLE IF EXISTS `users`;

CREATE TABLE `users` (
  `id` int(11) unsigned NOT NULL AUTO_INCREMENT,
  `name` varchar(50) NOT NULL DEFAULT '',
  `status` int(11) NOT NULL,
  `success_time` datetime,
  `created_at` datetime NOT NULL,
  `updated_at` datetime NOT NULL,
  PRIMARY KEY (`id`)
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;

LOCK TABLES `users` WRITE;
/*!40000 ALTER TABLE `users` DISABLE KEYS */;

INSERT INTO `users` (`id`,`name`, `status`, `created_at`, `updated_at`)
VALUES
	(5,'豆爸&玥爸',1,'2018-11-28 10:29:55','2018-11-28 10:29:55'),
	(6,'呵呵',1,'2018-11-28 10:29:55','2018-11-28 10:29:55');

/*!40000 ALTER TABLE `users` ENABLE KEYS */;
UNLOCK TABLES;

--------------------------------------------------------------------------------
/util.go:
--------------------------------------------------------------------------------
package gosql

import (
	"reflect"
	"sort"
	"time"
)

// inSlice reports whether k is an element of s; a nil or empty s contains
// nothing.
func inSlice(k string, s []string) bool {
	for _, v := range s {
		if v == k {
			return true
		}
	}
	return false
}

// IsZero
assert value is zero value 20 | func IsZero(val reflect.Value) bool { 21 | if !val.IsValid() { 22 | return true 23 | } 24 | 25 | kind := val.Kind() 26 | switch kind { 27 | case reflect.String: 28 | return val.Len() == 0 29 | case reflect.Bool: 30 | return val.Bool() == false 31 | case reflect.Float32, reflect.Float64: 32 | return val.Float() == 0 33 | case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: 34 | return val.Int() == 0 35 | case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: 36 | return val.Uint() == 0 37 | case reflect.Ptr, reflect.Chan, reflect.Func, reflect.Interface, reflect.Slice, reflect.Map: 38 | return val.IsNil() 39 | case reflect.Array: 40 | for i := 0; i < val.Len(); i++ { 41 | if !IsZero(val.Index(i)) { 42 | return false 43 | } 44 | } 45 | return true 46 | case reflect.Struct: 47 | if t, ok := val.Interface().(time.Time); ok { 48 | return t.IsZero() 49 | } else { 50 | valid := val.FieldByName("Valid") 51 | if valid.IsValid() { 52 | va, ok := valid.Interface().(bool) 53 | return ok && !va 54 | } 55 | 56 | return reflect.DeepEqual(val.Interface(), reflect.Zero(val.Type()).Interface()) 57 | } 58 | default: 59 | return reflect.DeepEqual(val.Interface(), reflect.Zero(val.Type()).Interface()) 60 | } 61 | } 62 | 63 | //zeroValueFilter filter zero value and keep the specified zero value 64 | func zeroValueFilter(fields map[string]reflect.Value, zv []string) map[string]interface{} { 65 | m := make(map[string]interface{}) 66 | 67 | for k, v := range fields { 68 | v = reflect.Indirect(v) 69 | if inSlice(k, zv) || !IsZero(v) { 70 | m[k] = v.Interface() 71 | } 72 | } 73 | 74 | return m 75 | } 76 | 77 | // structAutoTime auto set created_at updated_at 78 | func structAutoTime(fields map[string]reflect.Value, f []string) { 79 | for k, v := range fields { 80 | v = reflect.Indirect(v) 81 | if v.IsValid() && inSlice(k, f) && IsZero(v) { 82 | switch v.Kind() { 83 | case reflect.String: 84 | 
v.SetString(time.Now().Format("2006-01-02 15:04:05")) 85 | case reflect.Struct: 86 | // truncate 1 sec, Otherwise the data you create and the data you get will never be compared 87 | v.Set(reflect.ValueOf(time.Now().Truncate(1 * time.Second))) 88 | case reflect.Int, reflect.Int32, reflect.Int64: 89 | v.SetInt(time.Now().Unix()) 90 | case reflect.Uint, reflect.Uint32, reflect.Uint64, reflect.Uintptr: 91 | v.SetUint(uint64(time.Now().Unix())) 92 | } 93 | } 94 | } 95 | } 96 | 97 | // structToMap 98 | func structToMap(fields map[string]reflect.Value) map[string]interface{} { 99 | m := make(map[string]interface{}) 100 | for k, v := range fields { 101 | v = reflect.Indirect(v) 102 | m[k] = v.Interface() 103 | } 104 | return m 105 | } 106 | 107 | // fillPrimaryKey is created fill primary key 108 | func fillPrimaryKey(v reflect.Value, value int64) { 109 | v = reflect.Indirect(v) 110 | if v.IsValid() { 111 | switch v.Kind() { 112 | case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: 113 | v.SetInt(value) 114 | 115 | case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: 116 | v.SetUint(uint64(value)) 117 | } 118 | } 119 | } 120 | 121 | // sortedParamKeys Sorts the param names given - map iteration order is explicitly random in Go 122 | // but we need params in a defined order to avoid unexpected results. 
123 | func sortedParamKeys(params map[string]interface{}) []string { 124 | sortedKeys := make([]string, len(params)) 125 | i := 0 126 | for k := range params { 127 | sortedKeys[i] = k 128 | i++ 129 | } 130 | sort.Strings(sortedKeys) 131 | 132 | return sortedKeys 133 | } 134 | -------------------------------------------------------------------------------- /util_test.go: -------------------------------------------------------------------------------- 1 | package gosql 2 | 3 | import ( 4 | "database/sql" 5 | "fmt" 6 | "reflect" 7 | "testing" 8 | "time" 9 | 10 | "github.com/ilibs/gosql/v2/internal/example/models" 11 | ) 12 | 13 | type MyString string 14 | 15 | type MyStruct struct { 16 | num int 17 | text MyString 18 | } 19 | 20 | var ( 21 | zeroPtr *string 22 | zeroSlice []int 23 | zeroFunc func() string 24 | zeroMap map[string]string 25 | emptyIface interface{} 26 | zeroIface fmt.Stringer 27 | zeroValues = []interface{}{ 28 | nil, 29 | 30 | // bool 31 | false, 32 | 33 | // int 34 | 0, 35 | int8(0), 36 | int16(0), 37 | int32(0), 38 | int64(0), 39 | uint(0), 40 | uint8(0), 41 | uint16(0), 42 | uint32(0), 43 | uint64(0), 44 | 45 | // float 46 | 0.0, 47 | float32(0.0), 48 | float64(0.0), 49 | 50 | // string 51 | "", 52 | 53 | // alias 54 | MyString(""), 55 | 56 | // func 57 | zeroFunc, 58 | 59 | // array / slice 60 | [0]int{}, 61 | zeroSlice, 62 | 63 | // map 64 | zeroMap, 65 | 66 | // interface 67 | emptyIface, 68 | zeroIface, 69 | 70 | // pointer 71 | zeroPtr, 72 | 73 | // struct 74 | MyStruct{}, 75 | time.Time{}, 76 | MyStruct{num: 0}, 77 | MyStruct{text: MyString("")}, 78 | sql.NullString{String: "", Valid: false}, 79 | sql.NullInt64{Int64: 0, Valid: false}, 80 | } 81 | nonZeroIface fmt.Stringer = time.Now() 82 | nonZeroValues = []interface{}{ 83 | // bool 84 | true, 85 | 86 | // int 87 | 1, 88 | int8(1), 89 | int16(1), 90 | int32(1), 91 | int64(1), 92 | uint8(1), 93 | uint16(1), 94 | uint32(1), 95 | uint64(1), 96 | 97 | // float 98 | 1.0, 99 | float32(1.0), 100 | 
float64(1.0), 101 | 102 | // string 103 | "test", 104 | 105 | // alias 106 | MyString("test"), 107 | 108 | // func 109 | time.Now, 110 | 111 | // array / slice 112 | []int{}, 113 | []int{42}, 114 | [1]int{42}, 115 | 116 | // map 117 | make(map[string]string, 1), 118 | 119 | // interface 120 | nonZeroIface, 121 | 122 | // pointer 123 | &nonZeroIface, 124 | 125 | // struct 126 | MyStruct{num: 1}, 127 | time.Date(2000, time.January, 1, 0, 0, 0, 0, time.UTC), 128 | sql.NullString{String: "", Valid: true}, 129 | sql.NullInt64{Int64: 0, Valid: true}, 130 | } 131 | ) 132 | 133 | func TestIsZero(t *testing.T) { 134 | for _, value := range zeroValues { 135 | if !IsZero(reflect.ValueOf(value)) { 136 | t.Errorf("expected '%v' (%T) to be recognized as zero value", value, value) 137 | } 138 | } 139 | 140 | for _, value := range nonZeroValues { 141 | if IsZero(reflect.ValueOf(value)) { 142 | t.Errorf("did not expect '%v' (%T) to be recognized as zero value", value, value) 143 | } 144 | } 145 | } 146 | 147 | func TestUtil_inSlice(t *testing.T) { 148 | s := []string{"a", "b", "c"} 149 | 150 | if !inSlice("a", s) { 151 | t.Error("in slice find error") 152 | } 153 | 154 | if inSlice("d", s) { 155 | t.Error("in slice exist error") 156 | } 157 | } 158 | 159 | func TestUtil_zeroValueFilter(t *testing.T) { 160 | user := &models.Users{ 161 | Id: 1, 162 | Name: "test", 163 | } 164 | 165 | rv := reflect.Indirect(reflect.ValueOf(user)) 166 | fields := mapper.FieldMap(rv) 167 | 168 | m := zeroValueFilter(fields, nil) 169 | 170 | if _, ok := m["status"]; ok { 171 | t.Error("status value not filter") 172 | } 173 | 174 | if _, ok := m["creatd_at"]; ok { 175 | t.Error("creatd_at zero value not filter") 176 | } 177 | 178 | if _, ok := m["updated_at"]; ok { 179 | t.Error("updated_at zero value not filter") 180 | } 181 | 182 | m2 := zeroValueFilter(fields, []string{"status"}) 183 | if _, ok := m2["status"]; !ok { 184 | t.Error("status shouldn't be filtered") 185 | } 186 | } 187 | 188 | type 
testWithIntCreatedTime struct { 189 | models.Users 190 | CreateAt int `db:"create_at"` 191 | CreateTime uint `db:"create_time"` 192 | } 193 | 194 | func TestUtil_structAutoTime(t *testing.T) { 195 | user := &testWithIntCreatedTime{ 196 | models.Users{ 197 | Id: 1, 198 | Name: "test", 199 | }, 200 | 0, 201 | 0, 202 | } 203 | rv := reflect.Indirect(reflect.ValueOf(user)) 204 | fields := mapper.FieldMap(rv) 205 | 206 | structAutoTime(fields, AUTO_CREATE_TIME_FIELDS) 207 | 208 | if user.CreatedAt.IsZero() { 209 | t.Error("auto time fail") 210 | } 211 | if user.CreateAt == 0 { 212 | t.Error("auto time fail") 213 | } 214 | if user.CreateTime == 0 { 215 | t.Error("auto time fail") 216 | } 217 | 218 | } 219 | 220 | func TestUtil_sortedParamKeys(t *testing.T) { 221 | m := map[string]interface{}{ 222 | "id": 1, 223 | "name": "test", 224 | "created_at": "2018-07-11 11:58:21", 225 | "updated_at": "2018-07-11 11:58:21", 226 | } 227 | 228 | keySort := []string{"created_at", "id", "name", "updated_at"} 229 | 230 | s := sortedParamKeys(m) 231 | 232 | for i, k := range s { 233 | if k != keySort[i] { 234 | t.Error("sort error", k) 235 | } 236 | } 237 | } 238 | 239 | func Test_fillPrimaryKey(t *testing.T) { 240 | var a int 241 | var b uint 242 | 243 | { 244 | v := reflect.ValueOf(&a) 245 | fillPrimaryKey(v, 123) 246 | if a != 123 { 247 | t.Errorf("value want %d,but get %d", 123, a) 248 | } 249 | } 250 | 251 | { 252 | v := reflect.ValueOf(&b) 253 | fillPrimaryKey(v, 465) 254 | if b != 465 { 255 | t.Errorf("value want %d,but get %d", 465, b) 256 | } 257 | } 258 | } 259 | --------------------------------------------------------------------------------