├── .github
│   ├── dependabot.yml
│   ├── local_start_mysql_ci.sh
│   ├── pre_install.sh
│   ├── test.sh
│   └── workflows
│       ├── reviewdog.yml
│       └── tests.yml
├── .gitignore
├── LICENSE
├── Makefile
├── README.md
├── _examples
│   ├── query
│   │   └── main.go
│   └── update
│       └── main.go
├── api.go
├── clause.go
├── clause_test.go
├── go.mod
├── go.sum
├── sql_type.go
├── sql_type_test.go
├── tests_test.go
├── update.go
├── update_test.go
├── util.go
├── util_test.go
├── where.go
└── where_test.go

/.github/dependabot.yml:
--------------------------------------------------------------------------------
---
version: 2
updates:
  # The Go module and its *_test.go files live at the repository root; there is
  # no separate /tests module, so a gomod entry for /tests would always fail.
  - package-ecosystem: gomod
    directory: /
    schedule:
      interval: weekly
  - package-ecosystem: github-actions
    directory: /
    schedule:
      interval: weekly
--------------------------------------------------------------------------------
/.github/local_start_mysql_ci.sh:
--------------------------------------------------------------------------------
#!/bin/bash

set -e

docker run \
  -d \
  -p 9910:3306 \
  -e MYSQL_DATABASE=gorm \
  -e MYSQL_USER=gorm \
  -e MYSQL_PASSWORD=gorm \
  -e MYSQL_ALLOW_EMPTY_PASSWORD=yes \
  mysql:latest
--------------------------------------------------------------------------------
/.github/pre_install.sh:
--------------------------------------------------------------------------------
#!/usr/bin/env bash

# Installs the formatting tools used by `make check` / `make lint` unless a
# binary with the same name is already present in $(go env GOPATH)/bin.

set -e

# `ls` fails when GOPATH/bin does not exist yet (fresh toolchain); under
# `set -e` that would abort the whole script, so treat it as "nothing installed".
installed_deps=$(ls "$(go env GOPATH)/bin/" 2>/dev/null || true)

deps=(
  "goimports;golang.org/x/tools/cmd/goimports" # Go's official goimports tool
  "gofumpt;mvdan.cc/gofumpt" # formatting tool
)

for i in "${deps[@]}"; do
  dep_name=$(echo "$i" | cut -d ';' -f 1-1)
  dep_path=$(echo "$i" | cut -d ';' -f 2-2)
  if echo "$installed_deps" | grep -w "$dep_name" > /dev/null; then
    # Constant format string; the dependency name is passed as an argument.
    # tput errors (e.g. no $TERM in CI) are silenced and just drop the colour.
    printf '%s"%s"%s installed, skip.\n' "$(tput setaf 2 2>/dev/null)" "$dep_name" "$(tput sgr0 2>/dev/null)"
  else
    echo "go install $dep_path@latest" && go install "$dep_path"@latest;
  fi
done

printf "\n"

--------------------------------------------------------------------------------
/.github/test.sh:
--------------------------------------------------------------------------------
#!/bin/bash

set -e

module_name=$(grep '^module ' go.mod | cut -d ' ' -f 2)
echo "module_name is $module_name"

# Exported so that every `go` command below sees it; a `VAR=value cmd` prefix
# only applies to that single command.
export GO111MODULE=on

# Run each package's tests separately and concatenate their coverage profiles
# (minus each profile's own "mode:" header) into c.out.
# `-I{}` already runs one command per input line, so `-n1` is not needed.
echo 'mode: atomic' > c.out && \
go list ./... | grep -v 'frontend' | xargs -I{} sh -c 'LOCAL_TEST=true go test -covermode=atomic -coverprofile=coverage.tmp -coverpkg=./... -parallel 1 -p 1 -count=1 -gcflags="all=-l -N" {} && tail -n +2 coverage.tmp >> c.out' && \
rm coverage.tmp

# go tool cover -func=c.out -o coverage.txt
--------------------------------------------------------------------------------
/.github/workflows/reviewdog.yml:
--------------------------------------------------------------------------------
name: reviewdog
on: [pull_request]
jobs:
  golangci-lint:
    name: runner / golangci-lint
    runs-on: ubuntu-latest
    steps:
      - name: Check out code into the Go module directory
        uses: actions/checkout@v3
      - name: golangci-lint
        uses: reviewdog/action-golangci-lint@v2

      - name: Setup reviewdog
        uses: reviewdog/action-setup@v1

      - name: gofumpt -s with reviewdog
        env:
          REVIEWDOG_GITHUB_API_TOKEN: ${{ secrets.GITHUB_TOKEN }}
        run: |
          go install mvdan.cc/gofumpt@v0.2.0
          gofumpt -e -d .
| \ 22 | reviewdog -name="gofumpt" -f=diff -f.diff.strip=0 -reporter=github-pr-review -------------------------------------------------------------------------------- /.github/workflows/tests.yml: -------------------------------------------------------------------------------- 1 | name: tests 2 | 3 | on: 4 | push: 5 | branches: [ master ] 6 | pull_request: 7 | branches: [ master ] 8 | 9 | permissions: 10 | contents: read 11 | 12 | jobs: 13 | # Label of the container job 14 | mysql: 15 | strategy: 16 | matrix: 17 | go: ['1.20', '1.19', '1.18'] 18 | platform: [ubuntu-latest] 19 | runs-on: ${{ matrix.platform }} 20 | 21 | services: 22 | mysql: 23 | image: mysql:latest 24 | env: 25 | MYSQL_DATABASE: gorm 26 | MYSQL_USER: gorm 27 | MYSQL_PASSWORD: gorm 28 | MYSQL_RANDOM_ROOT_PASSWORD: "yes" 29 | ports: 30 | - 9910:3306 31 | options: >- 32 | --health-cmd "mysqladmin ping -ugorm -pgorm" 33 | --health-interval 10s 34 | --health-start-period 10s 35 | --health-timeout 5s 36 | --health-retries 10 37 | 38 | steps: 39 | - name: Set up Go 1.x 40 | uses: actions/setup-go@v4 41 | with: 42 | go-version: ${{ matrix.go }} 43 | 44 | - name: Check out code into the Go module directory 45 | uses: actions/checkout@v3 46 | 47 | - name: go mod package cache 48 | uses: actions/cache@v3 49 | with: 50 | path: ~/go/pkg/mod 51 | key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('go.mod') }} 52 | 53 | - name: Build 54 | run: make check && make lint && make build 55 | 56 | - name: Tests 57 | run: GITHUB_ACTION=true GORM_DIALECT=mysql GORM_DSN="gorm:gorm@tcp(localhost:9910)/gorm?charset=utf8&parseTime=True" make test 58 | 59 | - name: Upload Cov 60 | run: | 61 | curl -Os https://uploader.codecov.io/latest/linux/codecov 62 | chmod +x codecov 63 | ./codecov -f c.out -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .idea/ 2 | c.out 3 | coverage.tmp 4 | 
-------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 chyroc(chyroc@qq.com) 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | local_package=gorm.io/gormx 2 | all: lint build test 3 | 4 | check: 5 | @$(set_env) ./.github/pre_install.sh 6 | @$(set_env) test -z "$$(goimports -local $(local_package) -d .)" 7 | @$(set_env) test -z "$$(gofumpt -d -e . | tee /dev/stderr)" 8 | 9 | lint: 10 | @$(set_env) ./.github/pre_install.sh 11 | @$(set_env) go fmt ./... 12 | @$(set_env) goimports -local $(local_package) -w . 13 | @$(set_env) gofumpt -l -w . 14 | 15 | build: 16 | go build ./... 
17 | 18 | test: 19 | ./.github/test.sh 20 | 21 | html: 22 | go tool cover -html=c.out 23 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # GORM ToolBox 2 | 3 | [![codecov](https://codecov.io/github/go-gorm/gormx/branch/master/graph/badge.svg?token=C7F3NAJH6U)](https://codecov.io/github/go-gorm/gormx) 4 | [![go report card](https://goreportcard.com/badge/github.com/go-gorm/gormx "go report card")](https://goreportcard.com/report/github.com/go-gorm/gormx) 5 | [![test status](https://github.com/go-gorm/gormx/workflows/tests/badge.svg?branch=master "test status")](https://github.com/go-gorm/gormx/actions) 6 | [![MIT license](https://img.shields.io/badge/license-MIT-brightgreen.svg)](https://opensource.org/licenses/MIT) 7 | [![Go.Dev reference](https://img.shields.io/badge/go.dev-reference-blue?logo=go&logoColor=white)](https://pkg.go.dev/gorm.io/gormx?tab=doc) 8 | 9 | 10 | ## Overview 11 | 12 | - Query supports gorm's where method 13 | - Update supports gorm's update method 14 | - 95% test coverage 15 | 16 | ## Getting Started 17 | 18 | - GORMX Guides https://gorm.io/gormx/index.html 19 | - GORM Guides http://gorm.io/docs 20 | 21 | ## Maintainers 22 | 23 | [@chyroc](https://github.com/chyroc) 24 | [@jinzhu](https://github.com/jinzhu) 25 | 26 | ## Contributing 27 | 28 | [You can help to deliver a better GORM/GORMX, check out things you can do](https://gorm.io/contribute.html) 29 | 30 | ## License 31 | 32 | Released under the [MIT License](https://github.com/go-gorm/gormx/blob/master/LICENSE) 33 | -------------------------------------------------------------------------------- /_examples/query/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "fmt" 5 | 6 | "gorm.io/gorm" 7 | 8 | "gorm.io/gormx" 9 | ) 10 | 11 | type User struct { 12 | Name string 
`gorm:"column:name"` 13 | Age uint `gorm:"column:age"` 14 | } 15 | 16 | type Where struct { 17 | Name *string `gorm:"column:name;"` 18 | AgeGT *uint `gorm:"column:age; query_expr:>"` 19 | Or *WhereOr `gorm:"query_expr:or"` 20 | } 21 | 22 | type WhereOr struct { 23 | Name *string `gorm:"column:name;"` 24 | AgeGT *uint `gorm:"column:age; query_expr:>"` 25 | } 26 | 27 | type Update struct { 28 | Age uint `gorm:"column:age; update_expr:+"` 29 | Extra *ExtraInfo `gorm:"column:extra; update_expr:merge_json"` 30 | } 31 | 32 | type ExtraInfo struct { 33 | City string `json:"city"` 34 | } 35 | 36 | func queryExample(db *gorm.DB) { 37 | fmt.Println(db.ToSQL(func(tx *gorm.DB) *gorm.DB { 38 | res := tx.Table("users").Where(gormx.Query(Where{ 39 | Name: ptr("a"), 40 | AgeGT: ptr(uint(10)), 41 | Or: &WhereOr{ 42 | Name: ptr("or-name"), 43 | AgeGT: ptr(uint(20)), 44 | }, 45 | })).Find(&[]User{}) 46 | if res.Error != nil { 47 | panic(res.Error) 48 | } 49 | return res 50 | })) 51 | } 52 | 53 | func ptr[T any](v T) *T { 54 | return &v 55 | } 56 | -------------------------------------------------------------------------------- /_examples/update/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "fmt" 5 | 6 | "gorm.io/gorm" 7 | 8 | "gorm.io/gormx" 9 | ) 10 | 11 | type Where struct { 12 | Name *string `gorm:"column:name;"` 13 | AgeGT *uint `gorm:"column:age; query_expr:>"` 14 | Or *WhereOr `gorm:"query_expr:or"` 15 | } 16 | 17 | type WhereOr struct { 18 | Name *string `gorm:"column:name;"` 19 | AgeGT *uint `gorm:"column:age; query_expr:>"` 20 | } 21 | 22 | type Update struct { 23 | Age uint `gorm:"column:age; update_expr:+"` 24 | Extra *ExtraInfo `gorm:"column:extra; update_expr:merge_json"` 25 | } 26 | 27 | type ExtraInfo struct { 28 | City string `json:"city"` 29 | } 30 | 31 | func updateExample(db *gorm.DB) { 32 | where := Where{ 33 | Name: ptr("a"), 34 | AgeGT: ptr(uint(10)), 35 | Or: &WhereOr{ 36 | Name: 
ptr("or-name"), 37 | AgeGT: ptr(uint(20)), 38 | }, 39 | } 40 | update := Update{ 41 | Age: 10, 42 | Extra: &ExtraInfo{ 43 | City: "beijing", 44 | }, 45 | } 46 | fmt.Println(db.ToSQL(func(tx *gorm.DB) *gorm.DB { 47 | res := tx.Table("users").Where(gormx.Query(where)).Updates(gormx.Update(update)) 48 | if res.Error != nil { 49 | panic(res.Error) 50 | } 51 | return res 52 | })) 53 | } 54 | 55 | func ptr[T any](v T) *T { 56 | return &v 57 | } 58 | -------------------------------------------------------------------------------- /api.go: -------------------------------------------------------------------------------- 1 | package gormx 2 | 3 | import ( 4 | "gorm.io/gorm" 5 | "gorm.io/gorm/clause" 6 | ) 7 | 8 | func Query(where any) clause.Expression { 9 | expression, err := buildSQLWhere(where) 10 | if err != nil { 11 | return errExpression{err} 12 | } 13 | return expression 14 | } 15 | 16 | func Update(update any) gorm.StatementModifier { 17 | return &updateModifyStatement{update} 18 | } 19 | 20 | type updateModifyStatement struct { 21 | update any 22 | } 23 | 24 | var _ gorm.StatementModifier = (*updateModifyStatement)(nil) 25 | 26 | func (u updateModifyStatement) ModifyStatement(stmt *gorm.Statement) { 27 | m, err := buildSQLUpdate(u.update) 28 | if err != nil { 29 | _ = stmt.AddError(err) 30 | return 31 | } 32 | stmt.Dest = m 33 | } 34 | -------------------------------------------------------------------------------- /clause.go: -------------------------------------------------------------------------------- 1 | package gormx 2 | 3 | import "gorm.io/gorm/clause" 4 | 5 | type notIn struct { 6 | in clause.IN 7 | } 8 | 9 | func (in notIn) Build(builder clause.Builder) { 10 | in.in.NegationBuild(builder) 11 | } 12 | 13 | type errExpression struct { 14 | err error 15 | } 16 | 17 | func (e errExpression) Build(builder clause.Builder) { 18 | _ = builder.AddError(e.err) 19 | } 20 | -------------------------------------------------------------------------------- 
/clause_test.go: -------------------------------------------------------------------------------- 1 | package gormx 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/stretchr/testify/assert" 7 | "gorm.io/gorm/clause" 8 | ) 9 | 10 | func Test_Clause(t *testing.T) { 11 | as := assert.New(t) 12 | db := newDB() 13 | 14 | t.Run("no in", func(t *testing.T) { 15 | res := notIn{clause.IN{Column: "name", Values: []any{1, 2}}} 16 | res.Build(db.Statement) 17 | as.Equal("`name` NOT IN (?,?)", db.Statement.SQL.String()) 18 | }) 19 | } 20 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module gorm.io/gormx 2 | 3 | go 1.18 4 | 5 | require ( 6 | github.com/stretchr/testify v1.8.1 7 | gorm.io/driver/mysql v1.4.7 8 | gorm.io/gorm v1.24.5 9 | ) 10 | 11 | require ( 12 | github.com/davecgh/go-spew v1.1.1 // indirect 13 | github.com/go-sql-driver/mysql v1.7.0 // indirect 14 | github.com/jinzhu/inflection v1.0.0 // indirect 15 | github.com/jinzhu/now v1.1.5 // indirect 16 | github.com/kr/pretty v0.3.0 // indirect 17 | github.com/pmezard/go-difflib v1.0.0 // indirect 18 | gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect 19 | gopkg.in/yaml.v3 v3.0.1 // indirect 20 | ) 21 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= 2 | github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 3 | github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= 4 | github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 5 | github.com/go-sql-driver/mysql v1.7.0 h1:ueSltNNllEqE3qcWBTD0iQd3IpL/6U+mJxLkazJ7YPc= 6 | github.com/go-sql-driver/mysql v1.7.0/go.mod 
h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI= 7 | github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= 8 | github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= 9 | github.com/jinzhu/now v1.1.4/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= 10 | github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ= 11 | github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= 12 | github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= 13 | github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= 14 | github.com/kr/pretty v0.3.0 h1:WgNl7dwNpEZ6jJ9k1snq4pZsg7DOEN8hP9Xw0Tsjwk0= 15 | github.com/kr/pretty v0.3.0/go.mod h1:640gp4NfQd8pI5XOwp5fnNeVWj67G7CFk/SaSQn7NBk= 16 | github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= 17 | github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= 18 | github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= 19 | github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= 20 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= 21 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 22 | github.com/rogpeppe/go-internal v1.6.1 h1:/FiVV8dS/e+YqF2JvO3yXRFbBLTIuSDkuC7aBOAvL+k= 23 | github.com/rogpeppe/go-internal v1.6.1/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTEfhy4qGm1nDQc= 24 | github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= 25 | github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= 26 | github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= 27 | github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= 28 | github.com/stretchr/testify v1.8.0/go.mod 
h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= 29 | github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk= 30 | github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= 31 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= 32 | gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= 33 | gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= 34 | gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= 35 | gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI= 36 | gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 37 | gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= 38 | gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 39 | gorm.io/driver/mysql v1.4.7 h1:rY46lkCspzGHn7+IYsNpSfEv9tA+SU4SkkB+GFX125Y= 40 | gorm.io/driver/mysql v1.4.7/go.mod h1:SxzItlnT1cb6e1e4ZRpgJN2VYtcqJgqnHxWr4wsP8oc= 41 | gorm.io/gorm v1.23.8/go.mod h1:l2lP/RyAtc1ynaTjFksBde/O8v9oOGIApu2/xRitmZk= 42 | gorm.io/gorm v1.24.5 h1:g6OPREKqqlWq4kh/3MCQbZKImeB9e6Xgc4zD+JgNZGE= 43 | gorm.io/gorm v1.24.5/go.mod h1:DVrVomtaYTbqs7gB/x2uVvqnXzv0nqjB396B8cG4dBA= 44 | -------------------------------------------------------------------------------- /sql_type.go: -------------------------------------------------------------------------------- 1 | package gormx 2 | 3 | import ( 4 | "fmt" 5 | "reflect" 6 | "sync" 7 | 8 | "gorm.io/gorm/schema" 9 | ) 10 | 11 | const ( 12 | tagColumn = "COLUMN" 13 | tagQuery = "QUERY_EXPR" 14 | tagUpdate = "UPDATE_EXPR" 15 | ) 16 | 17 | var structTypeCacheMap sync.Map 18 | 19 | type structType struct { 20 | Names []string 21 | Fields map[string]*fieldType 22 | } 23 | 24 | type fieldType 
struct { 25 | Name string // field name 26 | Column string // tag sql_field 27 | QueryExpr string // tag query_expr 28 | UpdateExpr string // tag update_expr 29 | IsAnonymous bool // field 是否是匿名字段 30 | Kind reflect.Kind // field Kind 31 | OrType reflect.Type // field OrType 32 | Tag map[string]string // key: COLUMN etc. 33 | } 34 | 35 | func parseStructType(t reflect.Type) (*structType, error) { 36 | structType := loadStructTypeFromCache(t) 37 | if structType != nil { 38 | return structType, nil 39 | } 40 | parsedType, err := parseStructTypeNoCache(t) 41 | if err != nil { 42 | return nil, err 43 | } 44 | structTypeCacheMap.Store(t, parsedType) 45 | return parsedType, nil 46 | } 47 | 48 | func parseStructTypeNoCache(t reflect.Type) (_ *structType, err error) { 49 | return parseStructTypeRev(t, nil, false) 50 | } 51 | 52 | func parseStructTypeRev(t reflect.Type, sType *structType, isField bool) (_ *structType, err error) { 53 | if t.Kind() == reflect.Ptr { 54 | t = t.Elem() 55 | } 56 | if isField && (t.Kind() != reflect.Struct && t.Kind() != reflect.Slice) { 57 | return nil, fmt.Errorf("field's type must be struct/slice, but got %s", t.Kind()) 58 | } 59 | 60 | if sType == nil { 61 | sType = &structType{ 62 | Names: []string{}, 63 | Fields: map[string]*fieldType{}, 64 | } 65 | } 66 | for i := 0; i < t.NumField(); i++ { 67 | structField := t.Field(i) 68 | tag := schema.ParseTagSetting(structField.Tag.Get("gorm"), ";") 69 | columnName := tag[tagColumn] 70 | queryExprString := tag[tagQuery] 71 | updateExprString := tag[tagUpdate] 72 | ft := structField.Type 73 | if ft.Kind() == reflect.Ptr { 74 | ft = ft.Elem() 75 | } 76 | isOr := isColumnEmpty(columnName) && queryExprString == operatorOr 77 | if isColumnEmpty(columnName) { 78 | if structField.Anonymous { 79 | // 匿名字段,不跳过 80 | } else if queryExprString == "" && updateExprString == "" { 81 | // 没有 query_expr 和 update_expr,跳过 82 | continue 83 | } 84 | } 85 | if err := checkField(structField, columnName, queryExprString, 
ft); err != nil { 86 | return nil, err 87 | } 88 | 89 | // 匿名字段, 将匿名字段的字段加入到当前结构体中 90 | if structField.Anonymous { 91 | if err := parseAnonymousStructField(structField, sType); err != nil { 92 | return nil, err 93 | } 94 | } else { 95 | var fieldStructType reflect.Type 96 | if isOr { 97 | if ft.Kind() == reflect.Slice { 98 | ft = ft.Elem() 99 | } 100 | fieldStructType = ft 101 | } 102 | if err := parseNormalStructField(structField, queryExprString, updateExprString, columnName, tag, fieldStructType, sType); err != nil { 103 | return nil, err 104 | } 105 | } 106 | } 107 | return sType, nil 108 | } 109 | 110 | func reOrderNames(sType *structType, name string) { 111 | for idx, n := range sType.Names { 112 | if n == name { 113 | copy(sType.Names[idx:], sType.Names[idx+1:]) 114 | sType.Names[len(sType.Names)-1] = name 115 | break 116 | } 117 | } 118 | } 119 | 120 | func checkField(structField reflect.StructField, columnName, queryExprString string, ft reflect.Type) error { 121 | // 匿名 122 | if structField.Anonymous { 123 | if !isColumnEmpty(columnName) { 124 | return fmt.Errorf("field %s is anonymous that can not have column tag", structField.Name) 125 | } 126 | } 127 | 128 | // or 129 | if queryExprString == operatorOr { 130 | if !isColumnEmpty(columnName) { 131 | return fmt.Errorf("struct field(%s) with query_expr(%s) cannot set column tag", structField.Name, queryExprString) 132 | } 133 | if ft.Kind() == reflect.Struct { 134 | } else if ft.Kind() == reflect.Slice || ft.Kind() == reflect.Array { 135 | if ft.Elem().Kind() != reflect.Struct { 136 | return fmt.Errorf("struct field(%s) with query_expr(%s) must be struct or it's list", structField.Name, queryExprString) 137 | } 138 | } else { 139 | return fmt.Errorf("struct field(%s) with query_expr(%s) must be struct or it's list", structField.Name, queryExprString) 140 | } 141 | } 142 | 143 | // 非匿名 144 | if !structField.Anonymous { 145 | if isColumnEmpty(columnName) && queryExprString != operatorOr { 146 | return 
fmt.Errorf("struct field(%s) need column tag", structField.Name) 147 | } 148 | } 149 | 150 | // op 和 类型对应 151 | rt := structField.Type 152 | if rt.Kind() == reflect.Ptr { 153 | rt = rt.Elem() 154 | } 155 | switch queryExprString { 156 | case operatorIn: 157 | if rt.Kind() != reflect.Slice && rt.Kind() != reflect.Array { 158 | return fmt.Errorf("struct field(%s) with in query_expr must be slice/array", structField.Name) 159 | } 160 | case operatorEq: 161 | if rt.Kind() == reflect.Slice || rt.Kind() == reflect.Array { 162 | return fmt.Errorf("struct field(%s) with eq query_expr can not be slice/array", structField.Name) 163 | } 164 | } 165 | 166 | return nil 167 | } 168 | 169 | func checkQueryExpr(field reflect.StructField, q string) error { 170 | if q != "" { 171 | if _, ok := queryExprMap[q]; !ok { 172 | return fmt.Errorf("field(%s) query_expr(%s) invalid", field.Name, q) 173 | } 174 | } 175 | 176 | return nil 177 | } 178 | 179 | func checkUpdateExpr(field reflect.StructField, q string) error { 180 | if q != "" { 181 | if _, ok := updaterMap[q]; !ok { 182 | return fmt.Errorf("field(%s) update_expr(%s) invalid", field.Name, q) 183 | } 184 | } 185 | return nil 186 | } 187 | 188 | func loadStructTypeFromCache(t reflect.Type) *structType { 189 | v, ok := structTypeCacheMap.Load(t) 190 | if ok { 191 | sqlType := v.(*structType) 192 | return sqlType 193 | } 194 | return nil 195 | } 196 | 197 | func isColumnEmpty(column string) bool { 198 | return column == "" || column == "-" 199 | } 200 | 201 | func parseAnonymousStructField(structField reflect.StructField, sType *structType) error { 202 | t := structField.Type 203 | if t.Kind() == reflect.Ptr { 204 | t = t.Elem() 205 | } 206 | childSType, err := parseStructTypeRev(t, sType, true) 207 | if err != nil { 208 | return err 209 | } 210 | for kk, vv := range childSType.Fields { 211 | if _, ok := sType.Fields[kk]; !ok { 212 | sType.Fields[kk] = vv 213 | sType.Names = append(sType.Names, kk) 214 | } 215 | } 216 | return nil 217 
| } 218 | 219 | func parseNormalStructField(structField reflect.StructField, queryExprString string, updateExprString string, columnName string, tag map[string]string, fieldStructType reflect.Type, sType *structType) error { 220 | if err := checkQueryExpr(structField, queryExprString); err != nil { 221 | return err 222 | } 223 | if err := checkUpdateExpr(structField, updateExprString); err != nil { 224 | return err 225 | } 226 | column := &fieldType{ 227 | Name: structField.Name, 228 | Column: columnName, 229 | QueryExpr: queryExprString, 230 | UpdateExpr: updateExprString, 231 | IsAnonymous: structField.Anonymous, 232 | Kind: structField.Type.Kind(), 233 | OrType: fieldStructType, 234 | Tag: tag, 235 | } 236 | // 已经有这个 name 的 field 这说明需要覆盖 237 | if _, ok := sType.Fields[structField.Name]; ok { 238 | reOrderNames(sType, structField.Name) 239 | } else { 240 | sType.Names = append(sType.Names, structField.Name) 241 | } 242 | sType.Fields[structField.Name] = column 243 | 244 | return nil 245 | } 246 | -------------------------------------------------------------------------------- /sql_type_test.go: -------------------------------------------------------------------------------- 1 | package gormx 2 | 3 | import ( 4 | "fmt" 5 | "io" 6 | "reflect" 7 | "sort" 8 | "testing" 9 | 10 | "github.com/stretchr/testify/assert" 11 | ) 12 | 13 | func Test_ParseType(t *testing.T) { 14 | as := assert.New(t) 15 | 16 | type AStruct struct { 17 | A *int `gorm:"column:a; query_expr:>"` 18 | } 19 | type BStruct struct { 20 | B *string `gorm:"column:b; query_expr:!="` 21 | } 22 | type CStruct struct { 23 | B *int `gorm:"column:b; query_expr:!="` 24 | BStruct 25 | } 26 | 27 | tests := []struct { 28 | name string 29 | args reflect.Type 30 | want *structType 31 | wantErr assert.ErrorAssertionFunc 32 | }{ 33 | {"err - invalid update_expr", reflect.TypeOf(struct { 34 | A *int `gorm:"column:a; update_expr:x"` 35 | }{}), &structType{}, func(t assert.TestingT, err error, i ...interface{}) bool { 
36 | as.NotNil(err) 37 | as.Equal("field(A) update_expr(x) invalid", err.Error()) 38 | return false 39 | }}, 40 | {"err - invalid query_expr", reflect.TypeOf(struct { 41 | A *int `gorm:"column:a; query_expr:x"` 42 | }{}), &structType{}, func(t assert.TestingT, err error, i ...interface{}) bool { 43 | as.NotNil(err) 44 | as.Equal("field(A) query_expr(x) invalid", err.Error()) 45 | return false 46 | }}, 47 | {"err - invalid anonymous type", reflect.TypeOf(struct { 48 | io.Reader 49 | }{}), &structType{}, func(t assert.TestingT, err error, i ...interface{}) bool { 50 | as.NotNil(err) 51 | as.Equal("field's type must be struct/slice, but got interface", err.Error()) 52 | return false 53 | }}, 54 | {"err - invalid anonymous column name not empty", reflect.TypeOf(struct { 55 | BStruct `gorm:"column:b"` 56 | }{}), &structType{}, func(t assert.TestingT, err error, i ...interface{}) bool { 57 | as.NotNil(err) 58 | as.Equal("field BStruct is anonymous that can not have column tag", err.Error()) 59 | return false 60 | }}, 61 | {"err - invalid or column name not empty", reflect.TypeOf(struct { 62 | Or BStruct `gorm:"column:b; query_expr:or"` 63 | }{}), &structType{}, func(t assert.TestingT, err error, i ...interface{}) bool { 64 | as.NotNil(err) 65 | as.Equal("struct field(Or) with query_expr(or) cannot set column tag", err.Error()) 66 | return false 67 | }}, 68 | {"err - invalid or must be struct/slice", reflect.TypeOf(struct { 69 | Or int `gorm:"column:b; query_expr:or"` 70 | }{}), &structType{}, func(t assert.TestingT, err error, i ...interface{}) bool { 71 | as.NotNil(err) 72 | as.Equal("struct field(Or) with query_expr(or) cannot set column tag", err.Error()) 73 | return false 74 | }}, 75 | {"err - invalid or must be struct/slice", reflect.TypeOf(struct { 76 | Or []int `gorm:"query_expr:or"` 77 | }{}), &structType{}, func(t assert.TestingT, err error, i ...interface{}) bool { 78 | as.NotNil(err) 79 | as.Equal("struct field(Or) with query_expr(or) must be struct or it's 
list", err.Error()) 80 | return false 81 | }}, 82 | {"err - invalid or must be struct/slice", reflect.TypeOf(struct { 83 | Or int `gorm:"query_expr:or"` 84 | }{}), &structType{}, func(t assert.TestingT, err error, i ...interface{}) bool { 85 | as.NotNil(err) 86 | as.Equal("struct field(Or) with query_expr(or) must be struct or it's list", err.Error()) 87 | return false 88 | }}, 89 | {"err - invalid must set column name", reflect.TypeOf(struct { 90 | A string `gorm:"query_expr:="` 91 | }{}), &structType{}, func(t assert.TestingT, err error, i ...interface{}) bool { 92 | as.NotNil(err) 93 | as.Equal("struct field(A) need column tag", err.Error()) 94 | return false 95 | }}, 96 | {"err - in with not slice type", reflect.TypeOf(struct { 97 | A string `gorm:"column:a; query_expr:in"` 98 | }{}), &structType{}, func(t assert.TestingT, err error, i ...interface{}) bool { 99 | as.NotNil(err) 100 | as.Equal("struct field(A) with in query_expr must be slice/array", err.Error()) 101 | return false 102 | }}, 103 | 104 | {"empty - basetype", reflect.TypeOf(struct { 105 | A int `json:"a"` 106 | }{}), &structType{ 107 | Names: []string{}, 108 | }, func(t assert.TestingT, err error, i ...interface{}) bool { 109 | as.Nil(err) 110 | return true 111 | }}, 112 | {"empty - pointer", reflect.TypeOf(struct { 113 | A *int `json:"a"` 114 | }{}), &structType{ 115 | Names: []string{}, 116 | }, func(t assert.TestingT, err error, i ...interface{}) bool { 117 | as.Nil(err) 118 | return true 119 | }}, 120 | 121 | {"ok - one field - basetype", reflect.TypeOf(struct { 122 | A int `gorm:"column:a"` 123 | }{}), &structType{ 124 | Names: []string{"A"}, 125 | Fields: map[string]*fieldType{ 126 | "A": {Name: "A", Column: "a", Kind: reflect.Int, Tag: map[string]string{"COLUMN": "a"}}, 127 | }, 128 | }, func(t assert.TestingT, err error, i ...interface{}) bool { 129 | as.Nil(err) 130 | return true 131 | }}, 132 | 133 | {"ok - one field - pointer", reflect.TypeOf(struct { 134 | A *int `gorm:"column:a"` 135 
| }{}), &structType{ 136 | Names: []string{"A"}, 137 | Fields: map[string]*fieldType{ 138 | "A": {Name: "A", Column: "a", Kind: reflect.Pointer, Tag: map[string]string{"COLUMN": "a"}}, 139 | }, 140 | }, func(t assert.TestingT, err error, i ...interface{}) bool { 141 | as.Nil(err) 142 | return true 143 | }}, 144 | {"ok - one field - struct is pointer", reflect.TypeOf(&struct { 145 | A int `gorm:"column:a"` 146 | }{}), &structType{ 147 | Names: []string{"A"}, 148 | Fields: map[string]*fieldType{ 149 | "A": {Name: "A", Column: "a", Kind: reflect.Int, Tag: map[string]string{"COLUMN": "a"}}, 150 | }, 151 | }, func(t assert.TestingT, err error, i ...interface{}) bool { 152 | as.Nil(err) 153 | return true 154 | }}, 155 | {"ok - two field", reflect.TypeOf(struct { 156 | A int `gorm:"column:a"` 157 | B *string `gorm:"column:b"` 158 | }{}), &structType{ 159 | Names: []string{"A", "B"}, 160 | Fields: map[string]*fieldType{ 161 | "A": {Name: "A", Column: "a", Kind: reflect.Int, Tag: map[string]string{"COLUMN": "a"}}, 162 | "B": {Name: "B", Column: "b", Kind: reflect.Pointer, Tag: map[string]string{"COLUMN": "b"}}, 163 | }, 164 | }, func(t assert.TestingT, err error, i ...interface{}) bool { 165 | as.Nil(err) 166 | return true 167 | }}, 168 | {"ok - rewrite field", reflect.TypeOf(struct { 169 | A *int `gorm:"column:a"` 170 | AStruct 171 | }{}), &structType{ 172 | Names: []string{"A"}, 173 | Fields: map[string]*fieldType{ 174 | "A": { 175 | Name: "A", Column: "a", Kind: reflect.Pointer, 176 | QueryExpr: ">", Tag: map[string]string{ 177 | "COLUMN": "a", "QUERY_EXPR": ">", 178 | }, 179 | }, 180 | }, 181 | }, func(t assert.TestingT, err error, i ...interface{}) bool { 182 | as.Nil(err) 183 | return true 184 | }}, 185 | {"ok - anonymous", reflect.TypeOf(struct { 186 | BStruct 187 | }{}), &structType{ 188 | Names: []string{"B"}, 189 | Fields: map[string]*fieldType{ 190 | "B": { 191 | Name: "B", Column: "b", 192 | Kind: reflect.Ptr, 193 | IsAnonymous: false, 194 | QueryExpr: "!=", 
Tag: map[string]string{"COLUMN": "b", "QUERY_EXPR": "!="}, 195 | }, 196 | }, 197 | }, func(t assert.TestingT, err error, i ...interface{}) bool { 198 | as.Nil(err) 199 | return true 200 | }}, 201 | {"ok - anonymous - pointer", reflect.TypeOf(struct { 202 | B int `gorm:"column:b; query_expr:!="` 203 | *BStruct 204 | }{}), &structType{ 205 | Names: []string{"B"}, 206 | Fields: map[string]*fieldType{ 207 | "B": { 208 | Name: "B", Column: "b", 209 | Kind: reflect.Ptr, 210 | IsAnonymous: false, 211 | QueryExpr: "!=", Tag: map[string]string{"COLUMN": "b", "QUERY_EXPR": "!="}, 212 | }, 213 | }, 214 | }, func(t assert.TestingT, err error, i ...interface{}) bool { 215 | as.Nil(err) 216 | return true 217 | }}, 218 | {"ok - anonymous'2s", reflect.TypeOf(struct { 219 | B int `gorm:"column:b; query_expr:!="` 220 | *CStruct 221 | }{}), &structType{ 222 | Names: []string{"B"}, 223 | Fields: map[string]*fieldType{ 224 | "B": { 225 | Name: "B", Column: "b", 226 | Kind: reflect.Ptr, 227 | IsAnonymous: false, 228 | QueryExpr: "!=", Tag: map[string]string{"COLUMN": "b", "QUERY_EXPR": "!="}, 229 | }, 230 | }, 231 | }, func(t assert.TestingT, err error, i ...interface{}) bool { 232 | as.Nil(err) 233 | return true 234 | }}, 235 | } 236 | for _, tt := range tests { 237 | t.Run(tt.name+"_parseTypeNoCache", func(t *testing.T) { 238 | got, err := parseStructTypeNoCache(tt.args) 239 | if !tt.wantErr(t, err, fmt.Sprintf("parseStructTypeNoCache(%v)", tt.args)) { 240 | return 241 | } 242 | assertStructTypeEqual(t, got, tt.want, fmt.Sprintf("parseStructTypeNoCache(%v)", tt.args)) 243 | }) 244 | t.Run(tt.name+"_parseType", func(t *testing.T) { 245 | got, err := parseStructType(tt.args) 246 | if !tt.wantErr(t, err, fmt.Sprintf("parseStructTypeNoCache(%v)", tt.args)) { 247 | return 248 | } 249 | assertStructTypeEqual(t, got, tt.want, fmt.Sprintf("parseStructTypeNoCache(%v)", tt.args)) 250 | }) 251 | } 252 | } 253 | 254 | func assertStructTypeEqual(t assert.TestingT, a, b *structType, msg string) { 
255 | as := assert.New(t) 256 | 257 | if a == nil { 258 | as.Nil(b, msg) 259 | return 260 | } 261 | 262 | as.Equal(len(a.Names), len(b.Names), msg) 263 | 264 | sort.Strings(a.Names) 265 | sort.Strings(b.Names) 266 | as.Equal(a.Names, b.Names, msg) 267 | 268 | as.Equal(len(a.Fields), len(b.Fields), msg) 269 | for _, name := range a.Names { 270 | as.NotNil(a.Fields[name], msg) 271 | as.NotNil(b.Fields[name], msg) 272 | assertFieldTypeEqual(t, a.Fields[name], b.Fields[name], fmt.Sprintf("name:%s, val:%v; name:%s, val:%v; %s", name, a.Fields[name], name, b.Fields[name], msg)) 273 | } 274 | } 275 | 276 | func assertFieldTypeEqual(t assert.TestingT, a, b *fieldType, msg string) { 277 | as := assert.New(t) 278 | 279 | as.NotNil(a) 280 | as.NotNil(b) 281 | 282 | as.Equal(a.Name, b.Name, msg) 283 | as.Equal(a.Column, b.Column, msg) 284 | as.Equal(a.QueryExpr, b.QueryExpr, msg) 285 | as.Equal(a.UpdateExpr, b.UpdateExpr, msg) 286 | as.Equal(a.IsAnonymous, b.IsAnonymous, msg) 287 | as.Equal(a.Kind, b.Kind, msg) 288 | if a.OrType == nil { 289 | as.Nil(b.OrType) 290 | } else { 291 | as.NotNil(b.OrType) 292 | as.Equal(a.OrType.String(), b.OrType.String()) 293 | } 294 | 295 | as.Equal(len(a.Tag), len(b.Tag), msg) 296 | for k, v := range a.Tag { 297 | as.Equal(v, b.Tag[k], msg) 298 | } 299 | } 300 | -------------------------------------------------------------------------------- /tests_test.go: -------------------------------------------------------------------------------- 1 | package gormx 2 | 3 | import ( 4 | "log" 5 | "os" 6 | 7 | "gorm.io/driver/mysql" 8 | "gorm.io/gorm" 9 | "gorm.io/gorm/logger" 10 | ) 11 | 12 | func init() { 13 | db := newDB() 14 | testDB(db) 15 | } 16 | 17 | func ptr[T any](v T) *T { 18 | return &v 19 | } 20 | 21 | func newDB() *gorm.DB { 22 | dbDSN := os.Getenv("GORM_DSN") 23 | if dbDSN == "" { 24 | dbDSN = "gorm:gorm@tcp(localhost:9910)/gorm?charset=utf8&parseTime=True&loc=Local" 25 | } 26 | db, err := gorm.Open(mysql.Open(dbDSN), &gorm.Config{}) 27 | if 
err != nil { 28 | log.Printf("connect to mysql fail: %s\n", err) 29 | os.Exit(1) 30 | } 31 | 32 | if debug := os.Getenv("DEBUG"); debug == "true" { 33 | db.Logger = db.Logger.LogMode(logger.Info) 34 | } else if debug == "false" { 35 | db.Logger = db.Logger.LogMode(logger.Silent) 36 | } 37 | 38 | return db 39 | } 40 | 41 | func testDB(db *gorm.DB) { 42 | sqlDB, err := db.DB() 43 | if err != nil { 44 | log.Printf("failed to connect database, got error %v", err) 45 | os.Exit(1) 46 | } 47 | 48 | err = sqlDB.Ping() 49 | if err != nil { 50 | log.Printf("failed to ping sqlDB, got error %v", err) 51 | os.Exit(1) 52 | } 53 | } 54 | -------------------------------------------------------------------------------- /update.go: -------------------------------------------------------------------------------- 1 | package gormx 2 | 3 | import ( 4 | "encoding/json" 5 | "fmt" 6 | "reflect" 7 | 8 | "gorm.io/gorm" 9 | "gorm.io/gorm/clause" 10 | ) 11 | 12 | func buildSQLUpdate(opt interface{}) (result map[string]interface{}, err error) { 13 | defer func() { 14 | if r := recover(); r != nil { 15 | err = packPanicError(r) 16 | } 17 | }() 18 | 19 | rv, rt, err := getValueAndType(opt) 20 | if err != nil { 21 | return nil, err 22 | } 23 | 24 | // 针对类型的检查 解析的时候有做 25 | sqlType, err := parseStructType(rt) 26 | if err != nil { 27 | return nil, err 28 | } 29 | 30 | // 遍历 field,将非 nil 的值拼到 map 中 31 | return buildUpdateMap(rv, sqlType) 32 | } 33 | 34 | // 遍历 field,将非 nil 的值拼到 map 中 35 | func buildUpdateMap(rv reflect.Value, structType *structType) (result map[string]interface{}, err error) { 36 | result = make(map[string]interface{}) 37 | for _, name := range structType.Names { 38 | column := structType.Fields[name] // 前置函数已经检查过,一定存在 39 | data := rv.FieldByName(column.Name) 40 | // 字段的值是空值, 直接忽略, 不做处理 41 | if isEmptyValue(data) { 42 | continue 43 | } 44 | if data.Kind() == reflect.Ptr { 45 | data = data.Elem() 46 | } 47 | if column.UpdateExpr != "" { 48 | updateExprBuilder := 
updaterMap[column.UpdateExpr] // 前置函数已经检查过,一定存在 49 | if updaterResult := updateExprBuilder(column.Column, data.Interface()); updaterResult.SQL != "" { 50 | result[column.Column] = updaterResult 51 | } 52 | } else { 53 | result[column.Column] = data.Interface() 54 | } 55 | } 56 | 57 | return result, nil 58 | } 59 | 60 | const ( 61 | updateExprAdd = "+" 62 | updateExprSub = "-" 63 | updateExprMergeJSON = "merge_json" 64 | ) 65 | 66 | type buildUpdateExpr func(field string, data interface{}) clause.Expr 67 | 68 | var updaterMap = map[string]buildUpdateExpr{ 69 | updateExprAdd: func(field string, data interface{}) clause.Expr { 70 | return gorm.Expr(field+" + ?", data) 71 | }, 72 | updateExprSub: func(field string, data interface{}) clause.Expr { 73 | return gorm.Expr(field+" - ?", data) 74 | }, 75 | updateExprMergeJSON: func(field string, data interface{}) clause.Expr { 76 | var bs []byte 77 | if isMergeJSONStruct(data) { 78 | dataMap, _ := mergeJSONStructToJSONMap(data) 79 | bs, _ = json.Marshal(dataMap) 80 | } else { 81 | bs, _ = json.Marshal(data) 82 | } 83 | s := string(bs) 84 | if s == "" { 85 | return clause.Expr{} 86 | } 87 | 88 | return gorm.Expr("CASE WHEN (`"+field+"` IS NULL OR `"+field+"` = '') THEN CAST(? AS JSON) ELSE JSON_MERGE_PATCH(`"+field+"`, CAST(? 
AS JSON)) END", s, s) 89 | }, 90 | } 91 | 92 | func isMergeJSONStruct(v interface{}) bool { 93 | vt := reflect.TypeOf(v) 94 | if vt.Kind() == reflect.Ptr { 95 | vt = vt.Elem() 96 | } 97 | return vt.Kind() == reflect.Struct 98 | } 99 | 100 | func mergeJSONStructToJSONMap(v interface{}) (map[string]interface{}, error) { 101 | vt := reflect.TypeOf(v) 102 | vv := reflect.ValueOf(v) 103 | 104 | if vt.Kind() == reflect.Ptr { 105 | vt = vt.Elem() 106 | vv = vv.Elem() 107 | } 108 | if vt.Kind() != reflect.Struct { 109 | return nil, fmt.Errorf("update(JSON_MERGE_PATCH) need struct type") 110 | } 111 | 112 | m := map[string]interface{}{} 113 | for i := 0; i < vt.NumField(); i++ { 114 | vtField := vt.Field(i) 115 | vvField := vv.Field(i) 116 | 117 | if !vvField.IsValid() { 118 | continue 119 | } 120 | 121 | jsonField := vtField.Tag.Get("json") 122 | if jsonField == "" || jsonField == "-" { 123 | continue 124 | } 125 | 126 | // ptr 127 | if vtField.Type.Kind() == reflect.Ptr { 128 | if vvField.IsNil() { 129 | continue 130 | } 131 | m[jsonField] = vvField.Elem().Interface() 132 | } else { 133 | m[jsonField] = vvField.Interface() 134 | } 135 | } 136 | 137 | return m, nil 138 | } 139 | -------------------------------------------------------------------------------- /update_test.go: -------------------------------------------------------------------------------- 1 | package gormx 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/stretchr/testify/assert" 7 | "gorm.io/gorm" 8 | "gorm.io/gorm/clause" 9 | ) 10 | 11 | func TestBuildSQLUpdate(t *testing.T) { 12 | as := assert.New(t) 13 | db := newDB() 14 | as.Nil(db.Migrator().AutoMigrate(&User{})) 15 | 16 | testBuildSQLUpdate := func(opt interface{}, check func(result map[string]interface{}, sql string, err error)) { 17 | result, err := buildSQLUpdate(opt) 18 | sql := db.ToSQL(func(tx *gorm.DB) *gorm.DB { 19 | return tx.Table("user").Where(Query(struct { 20 | ID int `gorm:"column:id"` 21 | }{ID: 1})).Updates(Update(opt)) 22 | }) 23 | 
check(result, sql, err) 24 | } 25 | 26 | t.Run("invalid", func(t *testing.T) { 27 | testBuildSQLUpdate(struct { 28 | Age *int `gorm:"column:age; update_expr:invalid"` 29 | }{ 30 | Age: ptr[int](1), 31 | }, func(m map[string]interface{}, sql string, err error) { 32 | as.NotNil(err) 33 | as.Equal("", sql) 34 | 35 | as.Equal("field(Age) update_expr(invalid) invalid", err.Error()) 36 | }) 37 | }) 38 | 39 | t.Run("empty", func(t *testing.T) { 40 | testBuildSQLUpdate(struct { 41 | Name *string `gorm:"column:name"` 42 | }{}, func(m map[string]interface{}, sql string, err error) { 43 | as.Nil(err) 44 | as.Equal("", sql) 45 | 46 | as.Len(m, 0) 47 | }) 48 | }) 49 | 50 | t.Run("direct set", func(t *testing.T) { 51 | t.Run("one field", func(t *testing.T) { 52 | name := "bob" 53 | 54 | testBuildSQLUpdate(struct { 55 | Name *string `gorm:"column:name"` 56 | Age *int `gorm:"column:age"` 57 | }{ 58 | Name: &name, 59 | }, func(m map[string]interface{}, sql string, err error) { 60 | as.Nil(err) 61 | as.Equal("UPDATE `user` SET `name`='bob' WHERE `id` = 1", sql) 62 | 63 | as.Len(m, 1) 64 | as.Equal(name, m["name"]) 65 | }) 66 | }) 67 | 68 | t.Run("field 名称不同", func(t *testing.T) { 69 | name := "bob" 70 | age := 20 71 | testBuildSQLUpdate(struct { 72 | Name *string `gorm:"column:name_jjj"` 73 | Age *int `gorm:"column:age_hhh"` 74 | }{ 75 | Name: &name, 76 | Age: &age, 77 | }, func(m map[string]interface{}, sql string, err error) { 78 | as.Nil(err) 79 | as.Equal("UPDATE `user` SET `age_hhh`=20,`name_jjj`='bob' WHERE `id` = 1", sql) 80 | 81 | as.Len(m, 2) 82 | as.Equal(name, m["name_jjj"]) 83 | as.Equal(age, m["age_hhh"]) 84 | }) 85 | }) 86 | }) 87 | 88 | t.Run("+", func(t *testing.T) { 89 | testBuildSQLUpdate(struct { 90 | Age *int `gorm:"column:age; update_expr:+"` 91 | }{ 92 | Age: ptr[int](1), 93 | }, func(m map[string]interface{}, sql string, err error) { 94 | as.Nil(err) 95 | as.Equal("UPDATE `user` SET `age`=age + 1 WHERE `id` = 1", sql) 96 | 97 | as.Len(m, 1) 98 | 
as.Equal(clause.Expr{SQL: "age + ?", Vars: []interface{}{1}}, m["age"]) 99 | }) 100 | }) 101 | 102 | t.Run("-", func(t *testing.T) { 103 | testBuildSQLUpdate(struct { 104 | Age *int `gorm:"column:age; update_expr:-"` 105 | }{ 106 | Age: ptr[int](1), 107 | }, func(m map[string]interface{}, sql string, err error) { 108 | as.Nil(err) 109 | as.Equal("UPDATE `user` SET `age`=age - 1 WHERE `id` = 1", sql) 110 | 111 | as.Len(m, 1) 112 | as.Equal(clause.Expr{SQL: "age - ?", Vars: []interface{}{1}}, m["age"]) 113 | }) 114 | }) 115 | 116 | t.Run("merge_json", func(t *testing.T) { 117 | t.Run("nil", func(t *testing.T) { 118 | testBuildSQLUpdate(struct { 119 | Data *map[string]interface{} `gorm:"column:data; update_expr:merge_json"` 120 | }{ 121 | Data: nil, 122 | }, func(m map[string]interface{}, sql string, err error) { 123 | as.Nil(err) 124 | as.Equal("", sql) 125 | 126 | as.Len(m, 0) 127 | }) 128 | }) 129 | 130 | t.Run("empty", func(t *testing.T) { 131 | testBuildSQLUpdate(struct { 132 | Data string `gorm:"column:data; update_expr:merge_json"` 133 | }{ 134 | Data: "", 135 | }, func(m map[string]interface{}, sql string, err error) { 136 | as.Nil(err) 137 | as.Equal("", sql) 138 | 139 | as.Len(m, 0) 140 | }) 141 | }) 142 | 143 | t.Run("map", func(t *testing.T) { 144 | m := map[string]interface{}{"a": "a", "b": 2, "c": false} 145 | testBuildSQLUpdate(struct { 146 | Data *map[string]interface{} `gorm:"column:data; update_expr:merge_json"` 147 | }{ 148 | Data: &m, 149 | }, func(m map[string]interface{}, sql string, err error) { 150 | as.Nil(err) 151 | as.Equal("UPDATE `user` SET `data`=CASE WHEN (`data` IS NULL OR `data` = '') THEN CAST('{\"a\":\"a\",\"b\":2,\"c\":false}' AS JSON) ELSE JSON_MERGE_PATCH(`data`, CAST('{\"a\":\"a\",\"b\":2,\"c\":false}' AS JSON)) END WHERE `id` = 1", sql) 152 | 153 | as.Len(m, 1) 154 | as.Equal(clause.Expr{ 155 | SQL: "CASE WHEN (`data` IS NULL OR `data` = '') THEN CAST(? AS JSON) ELSE JSON_MERGE_PATCH(`data`, CAST(? 
AS JSON)) END", 156 | Vars: []interface{}{"{\"a\":\"a\",\"b\":2,\"c\":false}", "{\"a\":\"a\",\"b\":2,\"c\":false}"}, 157 | }, m["data"]) 158 | }) 159 | }) 160 | 161 | t.Run("struct-no-nil", func(t *testing.T) { 162 | type data struct { 163 | A string `json:"a"` 164 | B int `json:"b"` 165 | C bool `json:"c"` 166 | D string `json:"-"` 167 | } 168 | testBuildSQLUpdate(struct { 169 | Data *data `gorm:"column:data; update_expr:merge_json"` 170 | }{ 171 | Data: &data{ 172 | A: "a", 173 | B: 2, 174 | C: false, 175 | }, 176 | }, func(m map[string]interface{}, sql string, err error) { 177 | as.Nil(err) 178 | as.Equal("UPDATE `user` SET `data`=CASE WHEN (`data` IS NULL OR `data` = '') THEN CAST('{\"a\":\"a\",\"b\":2,\"c\":false}' AS JSON) ELSE JSON_MERGE_PATCH(`data`, CAST('{\"a\":\"a\",\"b\":2,\"c\":false}' AS JSON)) END WHERE `id` = 1", sql) 179 | 180 | as.Len(m, 1) 181 | as.Equal(clause.Expr{ 182 | SQL: "CASE WHEN (`data` IS NULL OR `data` = '') THEN CAST(? AS JSON) ELSE JSON_MERGE_PATCH(`data`, CAST(? AS JSON)) END", 183 | Vars: []interface{}{"{\"a\":\"a\",\"b\":2,\"c\":false}", "{\"a\":\"a\",\"b\":2,\"c\":false}"}, 184 | }, m["data"]) 185 | }) 186 | }) 187 | 188 | t.Run("struct-no-nil", func(t *testing.T) { 189 | type data struct { 190 | A interface{} `json:"a"` 191 | } 192 | testBuildSQLUpdate(struct { 193 | Data *data `gorm:"column:data; update_expr:merge_json"` 194 | }{ 195 | Data: &data{ 196 | A: nil, 197 | }, 198 | }, func(m map[string]interface{}, sql string, err error) { 199 | as.Nil(err) 200 | as.Equal("UPDATE `user` SET `data`=CASE WHEN (`data` IS NULL OR `data` = '') THEN CAST('{\"a\":null}' AS JSON) ELSE JSON_MERGE_PATCH(`data`, CAST('{\"a\":null}' AS JSON)) END WHERE `id` = 1", sql) 201 | 202 | as.Len(m, 1) 203 | as.Equal(clause.Expr{ 204 | SQL: "CASE WHEN (`data` IS NULL OR `data` = '') THEN CAST(? AS JSON) ELSE JSON_MERGE_PATCH(`data`, CAST(? 
AS JSON)) END", 205 | Vars: []interface{}{"{\"a\":null}", "{\"a\":null}"}, 206 | }, m["data"]) 207 | }) 208 | }) 209 | 210 | t.Run("struct-no-nil - pointer of pointer", func(t *testing.T) { 211 | type data struct { 212 | A string `json:"a"` 213 | B int `json:"b"` 214 | C bool `json:"c"` 215 | } 216 | dataX := &data{ 217 | A: "a", 218 | B: 2, 219 | C: false, 220 | } 221 | testBuildSQLUpdate(struct { 222 | Data **data `gorm:"column:data; update_expr:merge_json"` 223 | }{ 224 | Data: &dataX, 225 | }, func(m map[string]interface{}, sql string, err error) { 226 | as.Nil(err) 227 | as.Equal("UPDATE `user` SET `data`=CASE WHEN (`data` IS NULL OR `data` = '') THEN CAST('{\"a\":\"a\",\"b\":2,\"c\":false}' AS JSON) ELSE JSON_MERGE_PATCH(`data`, CAST('{\"a\":\"a\",\"b\":2,\"c\":false}' AS JSON)) END WHERE `id` = 1", sql) 228 | 229 | as.Len(m, 1) 230 | as.Equal(clause.Expr{ 231 | SQL: "CASE WHEN (`data` IS NULL OR `data` = '') THEN CAST(? AS JSON) ELSE JSON_MERGE_PATCH(`data`, CAST(? AS JSON)) END", 232 | Vars: []interface{}{"{\"a\":\"a\",\"b\":2,\"c\":false}", "{\"a\":\"a\",\"b\":2,\"c\":false}"}, 233 | }, m["data"]) 234 | }) 235 | }) 236 | 237 | t.Run("struct-nil", func(t *testing.T) { 238 | type data struct { 239 | A string `json:"a"` 240 | B int `json:"b"` 241 | C bool `json:"c"` 242 | } 243 | testBuildSQLUpdate(struct { 244 | Data *data `gorm:"column:data; update_expr:merge_json"` 245 | }{ 246 | Data: nil, 247 | }, func(m map[string]interface{}, sql string, err error) { 248 | as.Nil(err) 249 | as.Equal("", sql) 250 | 251 | as.Len(m, 0) 252 | }) 253 | }) 254 | }) 255 | } 256 | 257 | func Test_StructHelper(t *testing.T) { 258 | as := assert.New(t) 259 | 260 | { 261 | _, err := mergeJSONStructToJSONMap(0) 262 | as.NotNil(err) 263 | as.Equal("update(JSON_MERGE_PATCH) need struct type", err.Error()) 264 | } 265 | 266 | { 267 | _, err := mergeJSONStructToJSONMap(ptr[int64](1)) 268 | as.NotNil(err) 269 | as.Equal("update(JSON_MERGE_PATCH) need struct type", err.Error()) 270 | } 
271 | 272 | { 273 | 274 | m, err := mergeJSONStructToJSONMap(struct { 275 | Name *string `json:"name"` 276 | }{}) 277 | as.Nil(err) 278 | as.Equal(map[string]interface{}(map[string]interface{}{}), m) 279 | } 280 | 281 | { 282 | 283 | m, err := mergeJSONStructToJSONMap(struct { 284 | Name *string `json:"name"` 285 | }{ 286 | Name: ptr("name1"), 287 | }) 288 | as.Nil(err) 289 | as.Equal(map[string]interface{}(map[string]interface{}{"name": "name1"}), m) 290 | } 291 | 292 | { 293 | 294 | m, err := mergeJSONStructToJSONMap(struct { 295 | Name *string `json:"name"` 296 | Age *int `json:"age"` 297 | }{ 298 | Name: ptr("name1"), 299 | Age: nil, 300 | }) 301 | as.Nil(err) 302 | as.Equal(map[string]interface{}(map[string]interface{}{"name": "name1"}), m) 303 | } 304 | 305 | { 306 | 307 | m, err := mergeJSONStructToJSONMap(struct { 308 | Name *string `json:"name"` 309 | Age int32 `json:"age"` 310 | }{ 311 | Name: ptr("name1"), 312 | Age: 0, 313 | }) 314 | as.Nil(err) 315 | as.Equal(map[string]interface{}(map[string]interface{}{"name": "name1", "age": int32(0)}), m) 316 | } 317 | } 318 | 319 | func Test_buildSQLUpdate(t *testing.T) { 320 | as := assert.New(t) 321 | 322 | t.Run("", func(t *testing.T) { 323 | _, err := buildSQLUpdate(nil) 324 | as.NotNil(err) 325 | as.Equal("gormx's data is invalid", err.Error()) 326 | }) 327 | } 328 | -------------------------------------------------------------------------------- /util.go: -------------------------------------------------------------------------------- 1 | package gormx 2 | 3 | import ( 4 | "fmt" 5 | "reflect" 6 | ) 7 | 8 | func getValueAndType(structData interface{}) (reflect.Value, reflect.Type, error) { 9 | rv := reflect.ValueOf(structData) 10 | 11 | if !rv.IsValid() { 12 | return reflect.Value{}, nil, fmt.Errorf("gormx's data is invalid") 13 | } 14 | 15 | if rv.Kind() == reflect.Ptr { 16 | rv = rv.Elem() 17 | } 18 | 19 | if rv.Kind() != reflect.Struct { 20 | return reflect.Value{}, nil, fmt.Errorf("data's kind must be 
struct, but got '%s'", rv.Kind()) 21 | } 22 | 23 | rt := rv.Type() 24 | return rv, rt, nil 25 | } 26 | 27 | func packPanicError(r interface{}) (err error) { 28 | switch je := r.(type) { 29 | case error: 30 | return je 31 | default: 32 | return fmt.Errorf("gormx panic: %s", r) 33 | } 34 | } 35 | 36 | func interfaceToSlice(v any) []any { 37 | rv := reflect.ValueOf(v) 38 | if rv.Kind() != reflect.Slice && rv.Kind() != reflect.Array { 39 | panic("interfaceToSlice: v must be slice") 40 | } 41 | 42 | sliceType := rv.Type().Elem() 43 | slice := make([]any, rv.Len()) 44 | for i := 0; i < rv.Len(); i++ { 45 | x := reflect.New(sliceType).Elem() 46 | x.Set(rv.Index(i)) 47 | slice[i] = x.Interface() 48 | } 49 | return slice 50 | } 51 | 52 | func isEmptyValue(rv reflect.Value) bool { 53 | // data may be string, int, *string, slice, check data is empty 54 | // example: "", 0, nil, []string{}, []int{}, []*string{}, []*int{} 55 | 56 | switch rv.Kind() { 57 | case reflect.String: 58 | return rv.String() == "" 59 | case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: 60 | return rv.Int() == 0 61 | case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: 62 | return rv.Uint() == 0 63 | case reflect.Float32, reflect.Float64: 64 | return rv.Float() == 0 65 | case reflect.Bool: 66 | return !rv.Bool() 67 | case reflect.Interface, reflect.Ptr: 68 | return rv.IsNil() 69 | case reflect.Invalid: 70 | return true 71 | case reflect.Complex64, reflect.Complex128: 72 | return rv.Complex() == 0 73 | case reflect.Slice, reflect.Array, reflect.Map: 74 | return rv.IsNil() || rv.Len() == 0 75 | default: 76 | return false 77 | } 78 | } 79 | -------------------------------------------------------------------------------- /util_test.go: -------------------------------------------------------------------------------- 1 | package gormx 2 | 3 | import ( 4 | "fmt" 5 | "reflect" 6 | "testing" 7 | 8 | "github.com/stretchr/testify/assert" 9 | ) 10 | 11 | func 
Test_packPanicError(t *testing.T) { 12 | as := assert.New(t) 13 | 14 | t.Run("panic-string", func(t *testing.T) { 15 | defer func() { 16 | err := packPanicError(recover()) 17 | as.NotNil(err) 18 | as.Equal("gormx panic: test", err.Error()) 19 | }() 20 | 21 | panic("test") 22 | }) 23 | 24 | t.Run("panic-error", func(t *testing.T) { 25 | defer func() { 26 | err := packPanicError(recover()) 27 | as.NotNil(err) 28 | as.Equal("error", err.Error()) 29 | }() 30 | 31 | panic(fmt.Errorf("error")) 32 | }) 33 | } 34 | 35 | func Test_interfaceToSlice(t *testing.T) { 36 | tests := []struct { 37 | name string 38 | args any 39 | want []any 40 | err error 41 | }{ 42 | {"1", []int{1, 2, 3}, []any{1, 2, 3}, nil}, 43 | {"2", [...]int{1, 2, 3}, []any{1, 2, 3}, nil}, 44 | {"3", "string", nil, fmt.Errorf("gormx panic: interfaceToSlice: v must be slice")}, 45 | } 46 | for _, tt := range tests { 47 | t.Run(tt.name, func(t *testing.T) { 48 | if tt.err == nil { 49 | assert.Equalf(t, tt.want, interfaceToSlice(tt.args), "interfaceToSlice(%v)", tt.args) 50 | } else { 51 | defer func() { 52 | err := packPanicError(recover()) 53 | assert.NotNil(t, err) 54 | assert.Equal(t, tt.err.Error(), err.Error()) 55 | }() 56 | interfaceToSlice(tt.args) 57 | } 58 | }) 59 | } 60 | } 61 | 62 | func Test_isEmptyValue(t *testing.T) { 63 | tests := []struct { 64 | name string 65 | args any 66 | want bool 67 | }{ 68 | {"1", nil, true}, 69 | 70 | {"int-0", 0, true}, 71 | {"int-1", 1, false}, 72 | 73 | {"uint-0", uint(0), true}, 74 | {"uint-1", uint(1), false}, 75 | 76 | {"float-0.0", 0.0, true}, 77 | {"float-1.0", 1.0, false}, 78 | 79 | {"complex-0+0i", 0 + 0i, true}, 80 | {"complex-1+0i", 1 + 0i, false}, 81 | 82 | {"bool-false", false, true}, 83 | {"bool-true", true, false}, 84 | 85 | {"string-''", "", true}, 86 | {"string-str", "str", false}, 87 | 88 | {"slice-[]", []int{}, true}, 89 | {"slice-[1]", []int{1}, false}, 90 | 91 | {"map-{}", map[string]int{}, true}, 92 | {"map-{1}", map[string]int{"1": 1}, false}, 93 
| 94 | {"struct-{}", struct{}{}, false}, 95 | } 96 | for _, tt := range tests { 97 | t.Run(tt.name, func(t *testing.T) { 98 | v := reflect.ValueOf(tt.args) 99 | assert.Equalf(t, tt.want, isEmptyValue(v), "isEmptyValue(%v)", tt.args) 100 | }) 101 | } 102 | } 103 | 104 | func Test_getValueAndType(t *testing.T) { 105 | as := assert.New(t) 106 | 107 | t.Run("invalid", func(t *testing.T) { 108 | _, _, err := getValueAndType(nil) 109 | as.NotNil(err) 110 | as.Equal("gormx's data is invalid", err.Error()) 111 | }) 112 | 113 | t.Run("pointer", func(t *testing.T) { 114 | _, _, _ = getValueAndType(&struct{}{}) 115 | }) 116 | 117 | t.Run("not struct", func(t *testing.T) { 118 | _, _, err := getValueAndType(1) 119 | as.NotNil(err) 120 | as.Equal("data's kind must be struct, but got 'int'", err.Error()) 121 | }) 122 | } 123 | -------------------------------------------------------------------------------- /where.go: -------------------------------------------------------------------------------- 1 | package gormx 2 | 3 | import ( 4 | "fmt" 5 | "reflect" 6 | 7 | "gorm.io/gorm/clause" 8 | ) 9 | 10 | func buildSQLWhere(where interface{}) (expression clause.Expression, err error) { 11 | defer func() { 12 | if r := recover(); r != nil { 13 | err = packPanicError(r) 14 | } 15 | }() 16 | 17 | rv, rt, err := getValueAndType(where) 18 | if err != nil { 19 | return nil, err 20 | } 21 | 22 | sqlType, err := parseStructType(rt) 23 | if err != nil { 24 | return nil, err 25 | } 26 | 27 | return buildClauseExpression(rv, sqlType, true) 28 | } 29 | 30 | func buildClauseExpression(rv reflect.Value, sqlType *structType, joinAnd bool) (result clause.Expression, err error) { 31 | expressions := []clause.Expression{} 32 | for _, name := range sqlType.Names { 33 | column := sqlType.Fields[name] // 前置步骤检查过,一定存在 34 | 35 | // 计算字段的值 36 | data := rv.FieldByName(column.Name) 37 | // 字段的值是 nil 直接忽略 不做处理 38 | if isEmptyValue(data) { 39 | continue 40 | } 41 | if data.Kind() == reflect.Ptr { 42 | data = 
data.Elem() 43 | } 44 | inter := data.Interface() 45 | 46 | queryExprBuilder, err := getQueryExpr(column.QueryExpr) 47 | if err != nil { 48 | return nil, err 49 | } 50 | 51 | if column.OrType != nil { 52 | orType, err := parseStructType(column.OrType) 53 | if err != nil { 54 | return nil, err 55 | } 56 | if data.Kind() == reflect.Slice { 57 | list := []clause.Expression{} 58 | for i := 0; i < data.Len(); i++ { 59 | or, err := buildClauseExpression(data.Index(i), orType, false) 60 | if err != nil { 61 | return nil, err 62 | } else if or != nil { 63 | list = append(list, or) 64 | } 65 | } 66 | if len(list) > 0 { 67 | expressions = append(expressions, joinExpression(list, false)) 68 | } 69 | } else { 70 | or, err := buildClauseExpression(data, orType, false) 71 | if err != nil { 72 | return nil, err 73 | } else if or != nil { 74 | expressions = append(expressions, or) 75 | } 76 | } 77 | } else { 78 | and := queryExprBuilder(column.Column, inter) 79 | if and != nil { 80 | expressions = append(expressions, and) 81 | } 82 | } 83 | } 84 | 85 | return joinExpression(expressions, joinAnd), nil 86 | } 87 | 88 | func joinExpression(exprs []clause.Expression, joinAnd bool) clause.Expression { 89 | if len(exprs) == 1 { 90 | return exprs[0] 91 | } 92 | if joinAnd { 93 | return clause.And(exprs...) 94 | } 95 | return clause.Or(exprs...) 
96 | } 97 | 98 | const ( 99 | operatorOr = "or" // clause.OrConditions 100 | operatorIn = "in" // clause.IN 101 | operatorNin = "not in" // notIn // 无 clause.NIN 102 | operatorGt = ">" // clause.Gt 103 | operatorGte = ">=" // clause.Gte 104 | operatorLt = "<" // clause.Lt 105 | operatorLte = "<=" // clause.Lte 106 | operatorEq = "=" // clause.Eq 107 | operatorNeq = "!=" // clause.Neq 108 | operatorLike = "like" // clause.Like 109 | operatorNull = "null" // clause.Null 110 | ) 111 | 112 | type ( 113 | buildExpression func(field string, data interface{}) clause.Expression 114 | ) 115 | 116 | var queryExprMap = map[string]struct { 117 | build buildExpression 118 | }{ 119 | operatorLt: { 120 | build: func(field string, data interface{}) clause.Expression { 121 | return clause.Lt{ 122 | Column: clause.Column{Name: field}, 123 | Value: data, 124 | } 125 | }, 126 | }, 127 | operatorLte: { 128 | build: func(field string, data interface{}) clause.Expression { 129 | return clause.Lte{ 130 | Column: clause.Column{Name: field}, 131 | Value: data, 132 | } 133 | }, 134 | }, 135 | operatorEq: { 136 | build: func(field string, data interface{}) clause.Expression { 137 | return clause.Eq{ 138 | Column: clause.Column{Name: field}, 139 | Value: data, 140 | } 141 | }, 142 | }, 143 | "": { 144 | build: func(field string, data interface{}) clause.Expression { 145 | return clause.Eq{ 146 | Column: clause.Column{Name: field}, 147 | Value: data, 148 | } 149 | }, 150 | }, 151 | operatorNeq: { 152 | build: func(field string, data interface{}) clause.Expression { 153 | return clause.Neq{ 154 | Column: clause.Column{Name: field}, 155 | Value: data, 156 | } 157 | }, 158 | }, 159 | operatorGt: { 160 | build: func(field string, data interface{}) clause.Expression { 161 | return clause.Gt{ 162 | Column: clause.Column{Name: field}, 163 | Value: data, 164 | } 165 | }, 166 | }, 167 | operatorGte: { 168 | build: func(field string, data interface{}) clause.Expression { 169 | return clause.Gte{ 170 | 
Column: clause.Column{Name: field}, 171 | Value: data, 172 | } 173 | }, 174 | }, 175 | operatorNull: { 176 | build: func(field string, data interface{}) clause.Expression { 177 | switch v := data.(type) { 178 | case bool: 179 | if v { 180 | return clause.Eq{ 181 | Column: clause.Column{Name: field}, 182 | Value: nil, 183 | } 184 | } else { 185 | return clause.Neq{ 186 | Column: clause.Column{Name: field}, 187 | Value: nil, 188 | } 189 | } 190 | } 191 | return nil 192 | }, 193 | }, 194 | operatorIn: { 195 | build: func(field string, data interface{}) clause.Expression { 196 | return clause.IN{ 197 | Column: clause.Column{Name: field}, 198 | Values: interfaceToSlice(data), 199 | } 200 | }, 201 | }, 202 | operatorNin: { 203 | build: func(field string, data interface{}) clause.Expression { 204 | return notIn{clause.IN{ 205 | Column: clause.Column{Name: field}, 206 | Values: interfaceToSlice(data), 207 | }} 208 | }, 209 | }, 210 | operatorLike: { 211 | build: func(field string, data interface{}) clause.Expression { 212 | return clause.Like{ 213 | Column: clause.Column{Name: field}, 214 | Value: data.(string), 215 | } 216 | }, 217 | }, 218 | operatorOr: {}, 219 | } 220 | 221 | func getQueryExpr(queryExprString string) (buildExpression, error) { 222 | queryExpr, ok := queryExprMap[queryExprString] 223 | if !ok { 224 | return nil, fmt.Errorf("query_expr '%s' invalid", queryExprString) 225 | } 226 | return queryExpr.build, nil 227 | } 228 | -------------------------------------------------------------------------------- /where_test.go: -------------------------------------------------------------------------------- 1 | package gormx 2 | 3 | import ( 4 | "reflect" 5 | "testing" 6 | 7 | "github.com/stretchr/testify/assert" 8 | "gorm.io/gorm" 9 | "gorm.io/gorm/clause" 10 | ) 11 | 12 | type User struct { 13 | Name string `gorm:"column:name"` 14 | } 15 | 16 | func (User) TableName() string { 17 | return "user" 18 | } 19 | 20 | func TestBuildSQLWhere(t *testing.T) { 21 | as := 
assert.New(t) 22 | db := newDB() 23 | 24 | type AOrInvalid struct { 25 | In string `gorm:"column:in; query_expr:in"` 26 | } 27 | type BOrOfOrInvalid struct { 28 | Or struct { 29 | Or AOrInvalid `gorm:"query_expr:or"` 30 | } `gorm:"query_expr:or"` 31 | } 32 | 33 | testBuildSQLWhere := func(opt interface{}, check func(expression clause.Expression, sql string, err error)) { 34 | expression, err := buildSQLWhere(opt) 35 | sql := db.ToSQL(func(tx *gorm.DB) *gorm.DB { return tx.Where(Query(opt)).Find(&[]User{}) }) 36 | check(expression, sql, err) 37 | } 38 | 39 | t.Run("invalid", func(t *testing.T) { 40 | t.Run("basetype", func(t *testing.T) { 41 | testBuildSQLWhere(struct { 42 | Name string `gorm:"column:name; query_expr:invalid"` 43 | }{}, func(expression clause.Expression, sql string, err error) { 44 | as.NotNil(err) 45 | as.Contains(err.Error(), `field(Name) query_expr(invalid) invalid`) 46 | }) 47 | }) 48 | 49 | t.Run("pointer", func(t *testing.T) { 50 | testBuildSQLWhere(struct { 51 | Name *string `gorm:"column:name; query_expr:invalid"` 52 | }{}, func(expression clause.Expression, sql string, err error) { 53 | as.NotNil(err) 54 | as.Contains(err.Error(), `field(Name) query_expr(invalid) invalid`) 55 | }) 56 | }) 57 | 58 | t.Run("invalid data", func(t *testing.T) { 59 | _, err := buildSQLWhere(nil) 60 | as.NotNil(err) 61 | as.Equal("gormx's data is invalid", err.Error()) 62 | }) 63 | 64 | t.Run("in[or] with invalid datatype", func(t *testing.T) { 65 | testBuildSQLWhere(struct { 66 | Or AOrInvalid `gorm:"query_expr:or"` 67 | }{}, func(expression clause.Expression, sql string, err error) { 68 | as.NotNil(err) 69 | as.Contains(err.Error(), `struct field(In) with in query_expr must be slice/array`) 70 | }) 71 | }) 72 | 73 | t.Run("in[or of or] with invalid datatype", func(t *testing.T) { 74 | testBuildSQLWhere(struct { 75 | Or BOrOfOrInvalid `gorm:"query_expr:or"` 76 | }{}, func(expression clause.Expression, sql string, err error) { 77 | as.NotNil(err) 78 | 
as.Contains(err.Error(), `struct field(In) with in query_expr must be slice/array`) 79 | }) 80 | }) 81 | }) 82 | 83 | t.Run("empty-expression", func(t *testing.T) { 84 | t.Run("basetype", func(t *testing.T) { 85 | testBuildSQLWhere(struct { 86 | Name string `gorm:"column:name"` 87 | }{}, func(expression clause.Expression, sql string, err error) { 88 | as.Nil(err) 89 | as.Equal("SELECT * FROM `user`", sql) 90 | as.Nil(expression) 91 | }) 92 | }) 93 | 94 | t.Run("pointer", func(t *testing.T) { 95 | testBuildSQLWhere(struct { 96 | Name *string `gorm:"column:name"` 97 | }{}, func(expression clause.Expression, sql string, err error) { 98 | as.Nil(err) 99 | as.Equal("SELECT * FROM `user`", sql) 100 | as.Nil(expression) 101 | }) 102 | }) 103 | }) 104 | 105 | t.Run("one field", func(t *testing.T) { 106 | t.Run("basetype", func(t *testing.T) { 107 | name := "bob" 108 | testBuildSQLWhere(struct { 109 | Name string `gorm:"column:name"` 110 | }{ 111 | Name: name, 112 | }, func(expression clause.Expression, sql string, err error) { 113 | as.Nil(err) 114 | as.Equal("SELECT * FROM `user` WHERE `name` = 'bob'", sql) 115 | assertExprEq[clause.Eq](t, expression, "name", name) 116 | }) 117 | }) 118 | 119 | t.Run("pointer", func(t *testing.T) { 120 | name := "bob" 121 | testBuildSQLWhere(struct { 122 | Name *string `gorm:"column:name"` 123 | }{ 124 | Name: &name, 125 | }, func(expression clause.Expression, sql string, err error) { 126 | as.Nil(err) 127 | as.Equal("SELECT * FROM `user` WHERE `name` = 'bob'", sql) 128 | assertExprEq[clause.Eq](t, expression, "name", name) 129 | }) 130 | }) 131 | }) 132 | 133 | t.Run("two field", func(t *testing.T) { 134 | t.Run("not empty", func(t *testing.T) { 135 | name := "bob" 136 | age := 0 137 | 138 | testBuildSQLWhere(struct { 139 | Name string `gorm:"column:name_jjj"` 140 | Age *int `gorm:"column:age_hhh"` 141 | }{ 142 | Name: name, 143 | Age: &age, 144 | }, func(expression clause.Expression, sql string, err error) { 145 | as.Nil(err) 146 | 
as.Equal("SELECT * FROM `user` WHERE (`name_jjj` = 'bob' AND `age_hhh` = 0)", sql) 147 | exprs := assertExprList[clause.AndConditions](t, expression, 2) 148 | assertExprEq[clause.Eq](t, exprs[0], "name_jjj", name) 149 | assertExprEq[clause.Eq](t, exprs[1], "age_hhh", age) 150 | }) 151 | }) 152 | 153 | t.Run("empty", func(t *testing.T) { 154 | testBuildSQLWhere(struct { 155 | Name string `gorm:"column:name_jjj"` 156 | Age *int `gorm:"column:age_hhh"` 157 | }{ 158 | Name: "", 159 | Age: nil, 160 | }, func(expression clause.Expression, sql string, err error) { 161 | as.Nil(err) 162 | as.Equal("SELECT * FROM `user`", sql) 163 | as.Nil(expression) 164 | }) 165 | }) 166 | }) 167 | 168 | t.Run("in", func(t *testing.T) { 169 | ids := []int64{1, 2, 3} 170 | names := []string{"a", "b"} 171 | 172 | idsEmpty := []int64{} 173 | namesEmpty := []string{} 174 | 175 | t.Run("in && not int", func(t *testing.T) { 176 | testBuildSQLWhere(struct { 177 | IDs *[]int64 `gorm:"column:id; query_expr:in"` 178 | Names *[]string `gorm:"column:name; query_expr:not in"` 179 | }{ 180 | IDs: &ids, 181 | Names: &names, 182 | }, func(expression clause.Expression, sql string, err error) { 183 | as.Nil(err) 184 | as.Equal("SELECT * FROM `user` WHERE (`id` IN (1,2,3) AND `name` NOT IN ('a','b'))", sql) 185 | exprs := assertExprList[clause.AndConditions](t, expression, 2) 186 | assertExprIn(t, exprs[0], "id", toAnySlice(ids)) 187 | assertExprNotIn(t, exprs[1], "name", toAnySlice(names)) 188 | }) 189 | }) 190 | 191 | t.Run("empty slice", func(t *testing.T) { 192 | t.Run("basetype", func(t *testing.T) { 193 | testBuildSQLWhere(struct { 194 | IDs []int64 `gorm:"column:id; query_expr:in"` 195 | Names []string `gorm:"column:name; query_expr:not in"` 196 | }{ 197 | IDs: []int64{}, 198 | Names: []string{}, 199 | }, func(expression clause.Expression, sql string, err error) { 200 | as.Nil(err) 201 | as.Equal("SELECT * FROM `user`", sql) 202 | assertExprList[clause.AndConditions](t, expression, 0) 203 | }) 204 | 
}) 205 | 206 | t.Run("pointer", func(t *testing.T) { 207 | testBuildSQLWhere(struct { 208 | IDs *[]int64 `gorm:"column:id; query_expr:in"` 209 | Names *[]string `gorm:"column:name; query_expr:not in"` 210 | }{ 211 | IDs: &idsEmpty, 212 | Names: &namesEmpty, 213 | }, func(expression clause.Expression, sql string, err error) { 214 | as.Nil(err) 215 | as.Equal("SELECT * FROM `user` WHERE (`id` IN (NULL) AND `name` IS NOT NULL)", sql) 216 | exprs := assertExprList[clause.AndConditions](t, expression, 2) 217 | assertExprIn(t, exprs[0], "id", []any{}) 218 | assertExprNotIn(t, exprs[1], "name", []any{}) 219 | }) 220 | }) 221 | }) 222 | 223 | t.Run("slice", func(t *testing.T) { 224 | testBuildSQLWhere(struct { 225 | IDs []int64 `gorm:"column:id; query_expr:in"` 226 | Names []string `gorm:"column:name; query_expr:not in"` 227 | }{ 228 | IDs: ids, 229 | Names: names, 230 | }, func(expression clause.Expression, sql string, err error) { 231 | as.Nil(err) 232 | as.Equal("SELECT * FROM `user` WHERE (`id` IN (1,2,3) AND `name` NOT IN ('a','b'))", sql) 233 | exprs := assertExprList[clause.AndConditions](t, expression, 2) 234 | assertExprIn(t, exprs[0], "id", toAnySlice(ids)) 235 | assertExprNotIn(t, exprs[1], "name", toAnySlice(names)) 236 | }) 237 | }) 238 | 239 | t.Run("invalid", func(t *testing.T) { 240 | testBuildSQLWhere(struct { 241 | IDs []int64 `gorm:"column:id; query_expr:="` 242 | }{ 243 | IDs: ids, 244 | }, func(expression clause.Expression, sql string, err error) { 245 | as.NotNil(err) 246 | as.Equal("struct field(IDs) with eq query_expr can not be slice/array", err.Error()) 247 | }) 248 | }) 249 | }) 250 | 251 | t.Run("like", func(t *testing.T) { 252 | t.Run("like", func(t *testing.T) { 253 | name := "%name%" 254 | testBuildSQLWhere(struct { 255 | NameLike *string `gorm:"column:name; query_expr:like"` 256 | }{ 257 | NameLike: &name, 258 | }, func(expression clause.Expression, sql string, err error) { 259 | as.Nil(err) 260 | as.Equal("SELECT * FROM `user` WHERE `name` 
LIKE '%name%'", sql) 261 | assertExprEq[clause.Like](t, expression, "name", name) 262 | }) 263 | }) 264 | }) 265 | 266 | t.Run("compare", func(t *testing.T) { 267 | t.Run(">", func(t *testing.T) { 268 | testBuildSQLWhere(struct { 269 | ID *int `gorm:"column:id; query_expr:>"` 270 | Name *bool `gorm:"column:name; query_expr:null"` 271 | }{ 272 | ID: ptr[int](1), 273 | }, func(expression clause.Expression, sql string, err error) { 274 | as.Nil(err) 275 | as.Equal("SELECT * FROM `user` WHERE `id` > 1", sql) 276 | assertExprEq[clause.Gt](t, expression, "id", 1) 277 | }) 278 | }) 279 | 280 | t.Run(">=", func(t *testing.T) { 281 | testBuildSQLWhere(struct { 282 | ID *int `gorm:"column:id; query_expr:>="` 283 | Name *bool `gorm:"column:name; query_expr:null"` 284 | }{ 285 | ID: ptr[int](1), 286 | }, func(expression clause.Expression, sql string, err error) { 287 | as.Nil(err) 288 | as.Equal("SELECT * FROM `user` WHERE `id` >= 1", sql) 289 | assertExprEq[clause.Gte](t, expression, "id", 1) 290 | }) 291 | }) 292 | 293 | t.Run("=", func(t *testing.T) { 294 | testBuildSQLWhere(struct { 295 | ID *int `gorm:"column:id; query_expr:="` 296 | Name *bool `gorm:"column:name; query_expr:null"` 297 | }{ 298 | ID: ptr[int](1), 299 | }, func(expression clause.Expression, sql string, err error) { 300 | as.Nil(err) 301 | as.Equal("SELECT * FROM `user` WHERE `id` = 1", sql) 302 | assertExprEq[clause.Eq](t, expression, "id", 1) 303 | }) 304 | }) 305 | 306 | t.Run("<", func(t *testing.T) { 307 | testBuildSQLWhere(struct { 308 | ID *int `gorm:"column:id; query_expr:<"` 309 | Name *bool `gorm:"column:name; query_expr:null"` 310 | }{ 311 | ID: ptr[int](1), 312 | }, func(expression clause.Expression, sql string, err error) { 313 | as.Nil(err) 314 | as.Equal("SELECT * FROM `user` WHERE `id` < 1", sql) 315 | assertExprEq[clause.Lt](t, expression, "id", 1) 316 | }) 317 | }) 318 | 319 | t.Run("<=", func(t *testing.T) { 320 | testBuildSQLWhere(struct { 321 | ID *int `gorm:"column:id; query_expr:<="` 
322 | Name *bool `gorm:"column:name; query_expr:null"` 323 | }{ 324 | ID: ptr[int](1), 325 | }, func(expression clause.Expression, sql string, err error) { 326 | as.Nil(err) 327 | as.Equal("SELECT * FROM `user` WHERE `id` <= 1", sql) 328 | assertExprEq[clause.Lte](t, expression, "id", 1) 329 | }) 330 | }) 331 | 332 | t.Run("!=", func(t *testing.T) { 333 | testBuildSQLWhere(struct { 334 | ID *int `gorm:"column:id; query_expr:!="` 335 | Name *bool `gorm:"column:name; query_expr:null"` 336 | }{ 337 | ID: ptr[int](1), 338 | }, func(expression clause.Expression, sql string, err error) { 339 | as.Nil(err) 340 | as.Equal("SELECT * FROM `user` WHERE `id` <> 1", sql) 341 | assertExprEq[clause.Neq](t, expression, "id", 1) 342 | }) 343 | }) 344 | }) 345 | 346 | t.Run("null", func(t *testing.T) { 347 | t.Run("is null - empty", func(t *testing.T) { 348 | testBuildSQLWhere(struct { 349 | Name *bool `gorm:"column:name; query_expr:null"` 350 | }{ 351 | Name: nil, 352 | }, func(expression clause.Expression, sql string, err error) { 353 | as.Nil(err) 354 | as.Equal("SELECT * FROM `user`", sql) 355 | as.Nil(expression) 356 | }) 357 | }) 358 | 359 | t.Run("is null - is null", func(t *testing.T) { 360 | testBuildSQLWhere(struct { 361 | Name *bool `gorm:"column:name; query_expr:null"` 362 | }{ 363 | Name: ptr(true), 364 | }, func(expression clause.Expression, sql string, err error) { 365 | as.Nil(err) 366 | as.Equal("SELECT * FROM `user` WHERE `name` IS NULL", sql) 367 | assertExprEq[clause.Eq](t, expression, "name", nil) 368 | }) 369 | }) 370 | 371 | t.Run("is null - is not null", func(t *testing.T) { 372 | testBuildSQLWhere(struct { 373 | Name *bool `gorm:"column:name; query_expr:null"` 374 | }{ 375 | Name: ptr(false), 376 | }, func(expression clause.Expression, sql string, err error) { 377 | as.Nil(err) 378 | as.Equal("SELECT * FROM `user` WHERE `name` IS NOT NULL", sql) 379 | assertExprEq[clause.Neq](t, expression, "name", nil) 380 | }) 381 | }) 382 | 383 | t.Run("is null - not 
bool", func(t *testing.T) { 384 | testBuildSQLWhere(struct { 385 | Name string `gorm:"column:name; query_expr:null"` 386 | }{ 387 | Name: "x", 388 | }, func(expression clause.Expression, sql string, err error) { 389 | as.Nil(err) 390 | as.Equal("SELECT * FROM `user`", sql) 391 | as.Nil(expression) 392 | }) 393 | }) 394 | }) 395 | 396 | t.Run("anonymous struct", func(t *testing.T) { 397 | type WhereUser struct { 398 | UserID *int64 `gorm:"column:user_id"` 399 | } 400 | testBuildSQLWhere(struct { 401 | WhereUser 402 | ParentID *int64 `gorm:"column:parent_id"` 403 | }{ 404 | ParentID: ptr[int64](1), 405 | WhereUser: WhereUser{ 406 | UserID: ptr[int64](2), 407 | }, 408 | }, func(expression clause.Expression, sql string, err error) { 409 | as.Nil(err) 410 | as.Equal("SELECT * FROM `user` WHERE (`user_id` = 2 AND `parent_id` = 1)", sql) 411 | 412 | exprs := assertExprList[clause.AndConditions](t, expression, 2) 413 | assertExprEq[clause.Eq](t, exprs[0], "user_id", int64(2)) 414 | assertExprEq[clause.Eq](t, exprs[1], "parent_id", int64(1)) 415 | }) 416 | }) 417 | 418 | t.Run("or", func(t *testing.T) { 419 | type WhereUser struct { 420 | UserID *int64 `gorm:"column:user_id"` 421 | UserName *string `gorm:"column:user_name"` 422 | UserAge *int64 `gorm:"column:user_age"` 423 | OrClauses1 []WhereUser `gorm:"query_expr:or"` 424 | OrClauses2 *WhereUser `gorm:"query_expr:or"` 425 | } 426 | 427 | t.Run("or with slice", func(t *testing.T) { 428 | testBuildSQLWhere(WhereUser{ 429 | OrClauses1: []WhereUser{ 430 | { 431 | UserName: ptr("dirac"), 432 | }, 433 | { 434 | UserAge: ptr[int64](18), 435 | }, 436 | }, 437 | }, func(expression clause.Expression, sql string, err error) { 438 | as.Nil(err) 439 | as.Equal("SELECT * FROM `user` WHERE (`user_name` = 'dirac' OR `user_age` = 18)", sql) 440 | 441 | exprs := assertExprList[clause.OrConditions](t, expression, 2) 442 | assertExprEq[clause.Eq](t, exprs[0], "user_name", "dirac") 443 | assertExprEq[clause.Eq](t, exprs[1], "user_age", 
int64(18)) 444 | }) 445 | }) 446 | 447 | t.Run("and + or", func(t *testing.T) { 448 | testBuildSQLWhere(WhereUser{ 449 | UserID: ptr[int64](1), 450 | OrClauses1: []WhereUser{ 451 | { 452 | UserName: ptr("dirac"), 453 | }, 454 | { 455 | UserAge: ptr[int64](18), 456 | }, 457 | }, 458 | }, func(expression clause.Expression, sql string, err error) { 459 | as.Nil(err) 460 | as.Equal("SELECT * FROM `user` WHERE (`user_id` = 1 AND (`user_name` = 'dirac' OR `user_age` = 18))", sql) 461 | 462 | exprs := assertExprList[clause.AndConditions](t, expression, 2) 463 | assertExprEq[clause.Eq](t, exprs[0], "user_id", int64(1)) 464 | exprs2 := assertExprList[clause.OrConditions](t, exprs[1], 2) 465 | assertExprEq[clause.Eq](t, exprs2[0], "user_name", "dirac") 466 | assertExprEq[clause.Eq](t, exprs2[1], "user_age", int64(18)) 467 | }) 468 | }) 469 | 470 | t.Run("empty or", func(t *testing.T) { 471 | testBuildSQLWhere(WhereUser{ 472 | UserID: ptr[int64](1), 473 | }, func(expression clause.Expression, sql string, err error) { 474 | as.Nil(err) 475 | as.Equal("SELECT * FROM `user` WHERE `user_id` = 1", sql) 476 | assertExprEq[clause.Eq](t, expression, "user_id", int64(1)) 477 | }) 478 | }) 479 | 480 | t.Run("empty or", func(t *testing.T) { 481 | testBuildSQLWhere(WhereUser{ 482 | UserID: ptr[int64](1), 483 | OrClauses1: []WhereUser{ 484 | {}, 485 | {}, 486 | }, 487 | }, func(expression clause.Expression, sql string, err error) { 488 | as.Nil(err) 489 | as.Equal("SELECT * FROM `user` WHERE `user_id` = 1", sql) 490 | assertExprEq[clause.Eq](t, expression, "user_id", int64(1)) 491 | }) 492 | }) 493 | 494 | t.Run("or of or", func(t *testing.T) { 495 | testBuildSQLWhere(WhereUser{ 496 | UserID: ptr[int64](1), 497 | OrClauses1: []WhereUser{ 498 | { 499 | UserAge: ptr[int64](18), 500 | OrClauses1: []WhereUser{ 501 | { 502 | UserName: ptr("bob"), 503 | }, 504 | { 505 | UserName: ptr("dirac"), 506 | }, 507 | }, 508 | }, 509 | { 510 | UserAge: ptr[int64](19), 511 | }, 512 | }, 513 | }, 
func(expression clause.Expression, sql string, err error) { 514 | as.Nil(err) 515 | as.Equal("SELECT * FROM `user` WHERE (`user_id` = 1 AND ((`user_age` = 18 OR (`user_name` = 'bob' OR `user_name` = 'dirac')) OR `user_age` = 19))", sql) 516 | 517 | exprs := assertExprList[clause.AndConditions](t, expression, 2) 518 | assertExprEq[clause.Eq](t, exprs[0], "user_id", int64(1)) 519 | exprs2 := assertExprList[clause.OrConditions](t, exprs[1], 2) 520 | 521 | exprs3 := assertExprList[clause.OrConditions](t, exprs2[0], 2) 522 | assertExprEq[clause.Eq](t, exprs2[1], "user_age", int64(19)) 523 | 524 | assertExprEq[clause.Eq](t, exprs3[0], "user_age", int64(18)) 525 | exprs4 := assertExprList[clause.OrConditions](t, exprs3[1], 2) 526 | 527 | assertExprEq[clause.Eq](t, exprs4[0], "user_name", "bob") 528 | assertExprEq[clause.Eq](t, exprs4[1], "user_name", "dirac") 529 | }) 530 | }) 531 | 532 | t.Run("multi or", func(t *testing.T) { 533 | testBuildSQLWhere(WhereUser{ 534 | UserID: ptr[int64](1), 535 | OrClauses1: []WhereUser{ 536 | { 537 | UserAge: ptr[int64](18), 538 | }, 539 | { 540 | UserAge: ptr[int64](19), 541 | }, 542 | }, 543 | OrClauses2: &WhereUser{ 544 | UserName: ptr("dirac"), 545 | }, 546 | }, func(expression clause.Expression, sql string, err error) { 547 | as.Nil(err) 548 | as.Equal("SELECT * FROM `user` WHERE (`user_id` = 1 AND (`user_age` = 18 OR `user_age` = 19) AND `user_name` = 'dirac')", sql) 549 | 550 | exprs := assertExprList[clause.AndConditions](t, expression, 3) 551 | assertExprEq[clause.Eq](t, exprs[0], "user_id", int64(1)) 552 | exprs2 := assertExprList[clause.OrConditions](t, exprs[1], 2) 553 | assertExprEq[clause.Eq](t, exprs[2], "user_name", "dirac") 554 | 555 | assertExprEq[clause.Eq](t, exprs2[0], "user_age", int64(18)) 556 | assertExprEq[clause.Eq](t, exprs2[1], "user_age", int64(19)) 557 | }) 558 | }) 559 | 560 | t.Run("column=-, query_expr=or", func(t *testing.T) { 561 | type WhereUser struct { 562 | UserID *int64 `gorm:"column:user_id"` 563 | 
UserName *string `gorm:"column:user_name"` 564 | UserAge *int64 `gorm:"column:user_age"` 565 | OrClauses []WhereUser `gorm:"column:-; query_expr:or"` 566 | } 567 | 568 | testBuildSQLWhere(WhereUser{ 569 | UserAge: ptr[int64](18), 570 | OrClauses: []WhereUser{ 571 | { 572 | UserID: ptr[int64](123), 573 | UserName: ptr("bob"), 574 | }, 575 | { 576 | UserID: ptr[int64](234), 577 | UserName: ptr("dirac"), 578 | }, 579 | }, 580 | }, func(expression clause.Expression, sql string, err error) { 581 | as.Nil(err) 582 | as.Equal("SELECT * FROM `user` WHERE (`user_age` = 18 AND ((`user_id` = 123 OR `user_name` = 'bob') OR (`user_id` = 234 OR `user_name` = 'dirac')))", sql) 583 | 584 | exprs := assertExprList[clause.AndConditions](t, expression, 2) 585 | assertExprEq[clause.Eq](t, exprs[0], "user_age", int64(18)) 586 | exprs2 := assertExprList[clause.OrConditions](t, exprs[1], 2) 587 | 588 | expr3 := assertExprList[clause.OrConditions](t, exprs2[0], 2) 589 | expr4 := assertExprList[clause.OrConditions](t, exprs2[1], 2) 590 | 591 | assertExprEq[clause.Eq](t, expr3[0], "user_id", int64(123)) 592 | assertExprEq[clause.Eq](t, expr3[1], "user_name", "bob") 593 | 594 | assertExprEq[clause.Eq](t, expr4[0], "user_id", int64(234)) 595 | assertExprEq[clause.Eq](t, expr4[1], "user_name", "dirac") 596 | }) 597 | }) 598 | }) 599 | 600 | t.Run("json_contains", func(t *testing.T) { 601 | testBuildSQLWhere(struct { 602 | JSONField *string `gorm:"column:json_field; json_contains:json_key"` 603 | }{}, func(expression clause.Expression, sql string, err error) { 604 | }) 605 | }) 606 | } 607 | 608 | func assertExprIn(t *testing.T, expression clause.Expression, column string, value any) { 609 | as := assert.New(t) 610 | 611 | eq, ok := expression.(clause.IN) 612 | as.True(ok) 613 | 614 | _, ok = eq.Column.(clause.Column) 615 | as.True(ok) 616 | 617 | as.Equal(column, eq.Column.(clause.Column).Name) 618 | as.Equal(value, eq.Values) 619 | } 620 | 621 | func assertExprEq[T any](t *testing.T, 
expression clause.Expression, column string, value any) { 622 | as := assert.New(t) 623 | 624 | _, ok := expression.(T) 625 | if !ok { 626 | t.Errorf("expression(%T) is not %T", expression, new(T)) 627 | return 628 | } 629 | 630 | eqType := reflect.TypeOf(clause.Eq{}) 631 | ev := reflect.ValueOf(expression) 632 | if !ev.CanConvert(eqType) { 633 | t.Errorf("expression can not convert to clause.Eq") 634 | return 635 | } 636 | eq := ev.Convert(eqType).Interface().(clause.Eq) 637 | 638 | _, ok = eq.Column.(clause.Column) 639 | as.True(ok) 640 | 641 | as.Equal(column, eq.Column.(clause.Column).Name) 642 | as.Equal(value, eq.Value) 643 | } 644 | 645 | func assertExprNotIn(t *testing.T, expression clause.Expression, column string, value any) { 646 | as := assert.New(t) 647 | 648 | eq, ok := expression.(notIn) 649 | as.True(ok) 650 | 651 | as.Equal(column, eq.in.Column.(clause.Column).Name) 652 | as.Equal(value, eq.in.Values) 653 | } 654 | 655 | func assertExprList[T any](t *testing.T, expression clause.Expression, length int) []clause.Expression { 656 | as := assert.New(t) 657 | 658 | if length == 0 { 659 | as.Nil(expression) 660 | return nil 661 | } 662 | 663 | if _, ok := expression.(T); !ok { 664 | t.Errorf("expression(%T) is not %T", expression, new(T)) 665 | return nil 666 | } 667 | 668 | rv := reflect.ValueOf(expression) 669 | if rv.Kind() == reflect.Ptr { 670 | rv = rv.Elem() 671 | } 672 | exprs := rv.FieldByName("Exprs").Interface().([]clause.Expression) 673 | 674 | // eq, ok := expression.(clause.AndConditions) 675 | // as.True(ok) 676 | 677 | as.Equal(length, len(exprs)) 678 | 679 | return exprs 680 | } 681 | 682 | func toAnySlice[T any](data []T) []any { 683 | var result []any 684 | for _, v := range data { 685 | result = append(result, v) 686 | } 687 | return result 688 | } 689 | 690 | func Test_getQueryExpr(t *testing.T) { 691 | as := assert.New(t) 692 | 693 | t.Run("not found", func(t *testing.T) { 694 | _, err := getQueryExpr("not found") 695 | 
as.NotNil(err) 696 | as.Equal("query_expr 'not found' invalid", err.Error()) 697 | }) 698 | } 699 | 700 | func Test_buildClauseExpression(t *testing.T) { 701 | as := assert.New(t) 702 | 703 | t.Run("invalid query_expr", func(t *testing.T) { 704 | _, err := buildClauseExpression(reflect.ValueOf(struct { 705 | Name string `query_expr:"invalid"` 706 | }{Name: "str"}), &structType{ 707 | Names: []string{"Name"}, 708 | Fields: map[string]*fieldType{ 709 | "Name": { 710 | Name: "Name", 711 | QueryExpr: "invalid", 712 | }, 713 | }, 714 | }, true) 715 | as.NotNil(err) 716 | as.Equal("query_expr 'invalid' invalid", err.Error()) 717 | }) 718 | } 719 | --------------------------------------------------------------------------------