├── .github ├── ISSUE_TEMPLATE │ ├── bug_report.md │ ├── feature_request.md │ └── question.md ├── dependabot.yml └── workflows │ ├── reviewdog.yml │ └── tests.yml ├── .gitignore ├── .golangci.yml ├── License ├── README.md ├── condition.go ├── config.go ├── do.go ├── do_options.go ├── do_test.go ├── errors.go ├── examples ├── README.md ├── biz │ └── query.go ├── cmd │ ├── from_object │ │ ├── generate.go │ │ └── prepare.go │ ├── gen │ │ ├── generate.go │ │ └── prepare.go │ ├── only_model │ │ ├── generate.go │ │ └── prepare.go │ ├── sync_table │ │ ├── generate.go │ │ └── prepare.go │ ├── ultimate │ │ ├── generate.go │ │ └── prepare.go │ └── without_db │ │ └── generate.go ├── conf │ └── mysql.go ├── dal │ ├── model │ │ ├── model.go │ │ └── mytables.gen.go │ ├── mysql.go │ └── query │ │ ├── gen.go │ │ └── mytables.gen.go ├── generate.sh ├── go.mod ├── go.sum └── main.go ├── field ├── assign_attr.go ├── association.go ├── asterisk.go ├── bool.go ├── doc.go ├── example_test.go ├── export.go ├── export_test.go ├── expr.go ├── field.go ├── field_test.go ├── float.go ├── function.go ├── int.go ├── serializer.go ├── string.go ├── tag.go └── time.go ├── field_options.go ├── generator.go ├── generator_test.go ├── go.mod ├── go.sum ├── helper ├── clause.go └── object.go ├── import.go ├── interface.go ├── internal ├── generate │ ├── clause.go │ ├── clause_test.go │ ├── export.go │ ├── generate.go │ ├── interface.go │ ├── query.go │ ├── section.go │ ├── table.go │ ├── test.go │ └── utils.go ├── model │ ├── base.go │ ├── config.go │ ├── options.go │ ├── tbl_column.go │ └── tbl_index.go ├── parser │ ├── export.go │ ├── method.go │ ├── parser.go │ └── utils.go ├── template │ ├── base.go │ ├── method.go │ ├── model.go │ ├── query.go │ └── struct.go └── utils │ ├── common.go │ └── pools │ ├── export.go │ └── pool.go ├── sec_check.go ├── tests ├── .expect │ ├── dal_1 │ │ ├── model │ │ │ ├── banks.gen.go │ │ │ ├── credit_cards.gen.go │ │ │ ├── customers.gen.go │ │ │ ├── people.gen.go │ │ 
│ └── users.gen.go │ │ └── query │ │ │ ├── banks.gen.go │ │ │ ├── credit_cards.gen.go │ │ │ ├── customers.gen.go │ │ │ ├── gen.go │ │ │ ├── people.gen.go │ │ │ └── users.gen.go │ ├── dal_2 │ │ ├── model │ │ │ ├── banks.gen.go │ │ │ ├── credit_cards.gen.go │ │ │ ├── customers.gen.go │ │ │ ├── people.gen.go │ │ │ └── users.gen.go │ │ └── query │ │ │ ├── banks.gen.go │ │ │ ├── banks.gen_test.go │ │ │ ├── credit_cards.gen.go │ │ │ ├── credit_cards.gen_test.go │ │ │ ├── customers.gen.go │ │ │ ├── customers.gen_test.go │ │ │ ├── gen.go │ │ │ ├── gen_test.go │ │ │ ├── people.gen.go │ │ │ ├── people.gen_test.go │ │ │ ├── users.gen.go │ │ │ └── users.gen_test.go │ ├── dal_3 │ │ ├── model │ │ │ ├── banks.gen.go │ │ │ ├── credit_cards.gen.go │ │ │ ├── customers.gen.go │ │ │ ├── people.gen.go │ │ │ └── users.gen.go │ │ └── query │ │ │ ├── banks.gen.go │ │ │ ├── banks.gen_test.go │ │ │ ├── credit_cards.gen.go │ │ │ ├── credit_cards.gen_test.go │ │ │ ├── customers.gen.go │ │ │ ├── customers.gen_test.go │ │ │ ├── gen.go │ │ │ ├── gen_test.go │ │ │ ├── people.gen.go │ │ │ ├── people.gen_test.go │ │ │ ├── users.gen.go │ │ │ └── users.gen_test.go │ ├── dal_4 │ │ ├── model │ │ │ ├── banks.gen.go │ │ │ ├── credit_cards.gen.go │ │ │ ├── customers.gen.go │ │ │ ├── people.gen.go │ │ │ └── users.gen.go │ │ └── query │ │ │ ├── banks.gen.go │ │ │ ├── banks.gen_test.go │ │ │ ├── credit_cards.gen.go │ │ │ ├── credit_cards.gen_test.go │ │ │ ├── customers.gen.go │ │ │ ├── customers.gen_test.go │ │ │ ├── gen.go │ │ │ ├── gen_test.go │ │ │ ├── people.gen.go │ │ │ ├── people.gen_test.go │ │ │ ├── users.gen.go │ │ │ └── users.gen_test.go │ ├── dal_5 │ │ ├── model │ │ │ └── users.gen.go │ │ └── query │ │ │ ├── gen.go │ │ │ ├── gen_test.go │ │ │ ├── users.gen.go │ │ │ └── users.gen_test.go │ ├── dal_6 │ │ ├── model │ │ │ └── users.gen.go │ │ └── query │ │ │ ├── gen.go │ │ │ ├── gen_test.go │ │ │ ├── users.gen.go │ │ │ └── users.gen_test.go │ ├── dal_7 │ │ ├── model │ │ │ ├── banks.gen.go │ │ │ ├── 
credit_cards.gen.go │ │ │ └── customers.gen.go │ │ └── query │ │ │ ├── customers.gen.go │ │ │ ├── customers.gen_test.go │ │ │ ├── gen.go │ │ │ └── gen_test.go │ ├── dal_8 │ │ └── query │ │ │ ├── comments.gen.go │ │ │ ├── comments.gen_test.go │ │ │ ├── gen.go │ │ │ ├── gen_test.go │ │ │ ├── posts.gen.go │ │ │ ├── posts.gen_test.go │ │ │ ├── users.gen.go │ │ │ └── users.gen_test.go │ ├── dal_test │ │ ├── model │ │ │ ├── banks.gen.go │ │ │ ├── credit_cards.gen.go │ │ │ ├── customers.gen.go │ │ │ ├── people.gen.go │ │ │ └── users.gen.go │ │ └── query │ │ │ ├── banks.gen.go │ │ │ ├── credit_cards.gen.go │ │ │ ├── customers.gen.go │ │ │ ├── gen.go │ │ │ ├── people.gen.go │ │ │ └── users.gen.go │ └── dal_test_relation │ │ ├── model │ │ ├── banks.gen.go │ │ ├── credit_cards.gen.go │ │ └── customers.gen.go │ │ └── query │ │ ├── banks.gen.go │ │ ├── credit_cards.gen.go │ │ ├── customers.gen.go │ │ └── gen.go ├── .gitignore ├── README.md ├── create_test.go ├── ddl_test.go ├── diy_method │ └── method.go ├── docker-compose.yml ├── gen_test.go ├── generate_test.go ├── go.mod ├── query_test.go ├── tables.sql ├── test.sh ├── tests_test.go └── transaction_test.go └── tools └── gentool ├── README.ZH_CN.md ├── README.md ├── gen.yml ├── gentool.go ├── go.mod └── go.sum /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: 🐛 Report a bug 3 | about: Your issue may already be reported! 
Please search https://github.com/go-gorm/gorm/issues before creating one 🥳 4 | labels: type:bug 5 | assignees: riverchu 6 | 7 | --- 8 | 9 | ## GORM Playground Link 10 | 11 | 18 | 19 | https://github.com/go-gorm/playground/pull/1 20 | 21 | ## Description 22 | 23 | 24 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: 🍻 Feature request 3 | about: Suggest an idea, Pull Request welcome 🚀 4 | labels: type:feature 5 | assignees: riverchu 6 | 7 | --- 8 | 9 | 10 | 11 | ## Describe the feature 12 | 13 | 14 | 15 | ## Motivation 16 | 17 | 18 | 19 | ## Related Issues 20 | 21 | 22 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/question.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: 💬 Question 3 | about: The resources of the GORM team are limited; please search documents/google/issues/test cases before asking 🙏 4 | labels: type:question 5 | assignees: riverchu 6 | 7 | --- 8 | 9 | 10 | 11 | ## Your Question 12 | 13 | 14 | 15 | 16 | 17 | ## The document where you expected this to be explained 18 | 19 | 20 | 21 | ## Expected answer 22 | 23 | 24 | -------------------------------------------------------------------------------- /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | --- 2 | version: 2 3 | updates: 4 | - package-ecosystem: gomod 5 | directory: / 6 | schedule: 7 | interval: weekly 8 | - package-ecosystem: github-actions 9 | directory: / 10 | schedule: 11 | interval: weekly 12 | - package-ecosystem: gomod 13 | directory: /tests 14 | schedule: 15 | interval: weekly 16 | - package-ecosystem: gomod 17 | directory: /tools/gentool 18 | schedule: 19 | interval: weekly 20 | 
-------------------------------------------------------------------------------- /.github/workflows/reviewdog.yml: -------------------------------------------------------------------------------- 1 | name: reviewdog 2 | on: [pull_request] 3 | jobs: 4 | golangci-lint: 5 | name: runner / golangci-lint 6 | runs-on: ubuntu-latest 7 | steps: 8 | - name: Check out code into the Go module directory 9 | uses: actions/checkout@v3 10 | 11 | - name: golangci-lint 12 | uses: reviewdog/action-golangci-lint@v2 13 | with: 14 | golangci_lint_flags: '--timeout 5m' 15 | -------------------------------------------------------------------------------- /.github/workflows/tests.yml: -------------------------------------------------------------------------------- 1 | name: tests 2 | 3 | on: 4 | push: 5 | branches: [master] 6 | pull_request: 7 | branches-ignore: 8 | - "gh-pages" 9 | 10 | jobs: 11 | # Label of the container job 12 | tests: 13 | strategy: 14 | matrix: 15 | go: ["1.21", "1.20", "1.19", "1.18"] 16 | platform: [ubuntu-latest] # can not run in windows OS 17 | runs-on: ${{ matrix.platform }} 18 | 19 | steps: 20 | - name: Set up Go 1.x 21 | uses: actions/setup-go@v5 22 | with: 23 | go-version: ${{ matrix.go }} 24 | 25 | - name: Check out code into the Go module directory 26 | uses: actions/checkout@v3 27 | 28 | - name: go mod package cache 29 | uses: actions/cache@v3 30 | with: 31 | path: ~/go/pkg/mod 32 | key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('go.sum') }} 33 | 34 | - name: Tests 35 | run: go test ./... 
36 | 37 | mysql: 38 | strategy: 39 | matrix: 40 | dbversion: ["mysql:5.7", "mysql:latest"] 41 | go: ["1.21", "1.20", "1.19", "1.18"] 42 | platform: [ubuntu-latest] 43 | runs-on: ${{ matrix.platform }} 44 | 45 | services: 46 | mysql: 47 | image: ${{ matrix.dbversion }} 48 | env: 49 | MYSQL_DATABASE: gen 50 | MYSQL_USER: gen 51 | MYSQL_PASSWORD: gen 52 | MYSQL_ROOT_PASSWORD: 123456 53 | ports: 54 | - 9910:3306 55 | options: >- 56 | --health-cmd "mysqladmin ping -ugen -pgen" 57 | --health-interval 10s 58 | --health-start-period 10s 59 | --health-timeout 5s 60 | --health-retries 10 61 | --name mysql_server 62 | 63 | steps: 64 | - name: Change MySQL sql_mode 65 | run: docker exec mysql_server mysql -uroot -p123456 -e "SET GLOBAL sql_mode = 'NO_ENGINE_SUBSTITUTION';" 66 | 67 | - name: Set up Go 1.x 68 | uses: actions/setup-go@v5 69 | with: 70 | go-version: ${{ matrix.go }} 71 | 72 | - name: Check out code into the Go module directory 73 | uses: actions/checkout@v3 74 | 75 | - name: go mod package cache 76 | uses: actions/cache@v3 77 | with: 78 | path: ~/go/pkg/mod 79 | key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }} 80 | 81 | - name: Tests 82 | run: GITHUB_ACTION=true GORM_DIALECT=mysql GORM_DSN="gen:gen@tcp(localhost:9910)/gen?charset=utf8&parseTime=True" ./tests/test.sh 83 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.o 2 | *.a 3 | *.so 4 | _obj 5 | _test 6 | *.[568vq] 7 | [568vq].out 8 | *.cgo1.go 9 | *.cgo2.c 10 | _cgo_defun.c 11 | _cgo_gotypes.go 12 | _cgo_export.* 13 | _testmain.go 14 | *.exe 15 | *.exe~ 16 | *.test 17 | *.prof 18 | *.rar 19 | *.zip 20 | *.gz 21 | *.psd 22 | *.bmd 23 | *.cfg 24 | *.pptx 25 | *.log 26 | *nohup.out 27 | *settings.pyc 28 | *.sublime-project 29 | *.sublime-workspace 30 | !.gitkeep 31 | .DS_Store 32 | /.idea 33 | /.vscode 34 | __debug_bin 35 | /test/ 36 | **/go.work 37 | 
**/go.work.sum 38 | 39 | *.db 40 | -------------------------------------------------------------------------------- /.golangci.yml: -------------------------------------------------------------------------------- 1 | version: "2" 2 | run: 3 | tests: false 4 | 5 | linters: 6 | enable: 7 | - bodyclose 8 | - revive 9 | - unparam 10 | exclusions: 11 | generated: lax 12 | paths: 13 | - third_party$ 14 | - builtin$ 15 | - examples$ 16 | 17 | formatters: 18 | enable: 19 | - goimports 20 | settings: 21 | gofmt: 22 | simplify: true 23 | goimports: 24 | local-prefixes: 25 | - gorm.io 26 | - gorm.io/gen 27 | exclusions: 28 | generated: lax 29 | paths: 30 | - third_party$ 31 | - builtin$ 32 | - examples$ 33 | -------------------------------------------------------------------------------- /License: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright (c) 2021-NOW Jinzhu 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in 13 | all copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 21 | THE SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # GORM Gen 2 | 3 | Friendly & Safer GORM powered by Code Generation. 4 | 5 | [![Release](https://img.shields.io/github/v/release/go-gorm/gen)](https://github.com/go-gorm/gen/releases) 6 | [![Go Report Card](https://goreportcard.com/badge/github.com/go-gorm/gen)](https://goreportcard.com/report/github.com/go-gorm/gen) 7 | [![MIT license](https://img.shields.io/badge/license-MIT-brightgreen.svg)](https://opensource.org/licenses/MIT) 8 | [![OpenIssue](https://img.shields.io/github/issues/go-gorm/gen)](https://github.com/go-gorm/gen/issues?q=is%3Aopen+is%3Aissue) 9 | [![ClosedIssue](https://img.shields.io/github/issues-closed/go-gorm/gen)](https://github.com/go-gorm/gen/issues?q=is%3Aissue+is%3Aclosed) 10 | [![TODOs](https://badgen.net/https/api.tickgit.com/badgen/github.com/go-gorm/gen)](https://www.tickgit.com/browse?repo=github.com/go-gorm/gen) 11 | [![Go.Dev reference](https://img.shields.io/badge/go.dev-reference-blue?logo=go&logoColor=white)](https://pkg.go.dev/gorm.io/gen?tab=doc) 12 | 13 | ## Overview 14 | 15 | - Idiomatic & Reusable API from Dynamic Raw SQL 16 | - 100% Type-safe DAO API without `interface{}` 17 | - Database To Struct follows GORM conventions 18 | - GORM under the hood, supports all features, plugins, DBMS that GORM supports 19 | 20 | ## Getting Started 21 | 22 | * Gen Guides [https://gorm.io/gen/index.html](https://gorm.io/gen/index.html) 23 | * GORM Guides [http://gorm.io/docs](http://gorm.io/docs) 24 | 25 | ## Maintainers 26 | 27 | [@riverchu](https://github.com/riverchu) 
[@iDer](https://github.com/idersec) [@qqxhb](https://github.com/qqxhb) [@dino-ma](https://github.com/dino-ma) 28 | 29 | [@jinzhu](https://github.com/jinzhu) 30 | 31 | ## Contributing 32 | 33 | [You can help to deliver a better GORM/Gen, check out things you can do](https://gorm.io/contribute.html) 34 | 35 | ## License 36 | 37 | Released under the [MIT License](https://github.com/go-gorm/gen/blob/master/License) 38 | -------------------------------------------------------------------------------- /condition.go: -------------------------------------------------------------------------------- 1 | package gen 2 | 3 | import ( 4 | "fmt" 5 | 6 | "gorm.io/datatypes" 7 | "gorm.io/gen/field" 8 | "gorm.io/gorm/clause" 9 | ) 10 | 11 | // Cond convert expression array to condition array 12 | func Cond(exprs ...clause.Expression) []Condition { 13 | return exprToCondition(exprs...) 14 | } 15 | 16 | var _ Condition = &condContainer{} 17 | 18 | type condContainer struct { 19 | value interface{} 20 | err error 21 | } 22 | 23 | func (c *condContainer) BeCond() interface{} { return c.value } 24 | func (c *condContainer) CondError() error { return c.err } 25 | 26 | func exprToCondition(exprs ...clause.Expression) []Condition { 27 | conds := make([]Condition, 0, len(exprs)) 28 | for _, e := range exprs { 29 | switch e := e.(type) { 30 | case *datatypes.JSONQueryExpression, *datatypes.JSONOverlapsExpression, *datatypes.JSONArrayExpression: 31 | conds = append(conds, &condContainer{value: e}) 32 | default: 33 | conds = append(conds, &condContainer{err: fmt.Errorf("unsupported Expression %T to converted to Condition", e)}) 34 | } 35 | } 36 | return conds 37 | } 38 | 39 | func condToExpression(conds []Condition) ([]clause.Expression, error) { 40 | if len(conds) == 0 { 41 | return nil, nil 42 | } 43 | exprs := make([]clause.Expression, 0, len(conds)) 44 | for _, cond := range conds { 45 | if cond == nil { 46 | continue 47 | } 48 | if err := cond.CondError(); err != nil { 49 | return nil, 
err 50 | } 51 | 52 | switch cond.(type) { 53 | case *condContainer, field.Expr, SubQuery: 54 | default: 55 | return nil, fmt.Errorf("unsupported condition: %+v", cond) 56 | } 57 | 58 | switch e := cond.BeCond().(type) { 59 | case []clause.Expression: 60 | exprs = append(exprs, e...) 61 | case clause.Expression: 62 | exprs = append(exprs, e) 63 | } 64 | } 65 | return exprs, nil 66 | } 67 | -------------------------------------------------------------------------------- /do_options.go: -------------------------------------------------------------------------------- 1 | package gen 2 | 3 | // DOOption is an option for a DO: Apply receives the DOConfig and AfterInitialize receives the DO. 4 | type DOOption interface { 5 | Apply(*DOConfig) error 6 | AfterInitialize(*DO) error 7 | } 8 | // DOConfig is the DO configuration; it has no fields yet, and *DOConfig itself satisfies DOOption. 9 | type DOConfig struct { 10 | } 11 | 12 | // Apply copies c into config unless both point to the same DOConfig; it always returns nil. 13 | func (c *DOConfig) Apply(config *DOConfig) error { 14 | if config != c { 15 | *config = *c 16 | } 17 | return nil 18 | } 19 | 20 | // AfterInitialize does nothing and returns nil. NOTE(review): presumably a hook run once the DO is ready; confirm against callers. 21 | func (c *DOConfig) AfterInitialize(db *DO) error { 22 | return nil 23 | } 24 | -------------------------------------------------------------------------------- /errors.go: -------------------------------------------------------------------------------- 1 | package gen 2 | 3 | import "errors" 4 | 5 | var ( 6 | // ErrEmptyCondition is the sentinel error for an empty condition. 7 | ErrEmptyCondition = errors.New("empty condition") 8 | ) 9 | -------------------------------------------------------------------------------- /examples/README.md: -------------------------------------------------------------------------------- 1 | # examples 2 | 3 | ***Run generate.sh*** 4 | 5 | 一个简单的`GEN`最佳实践。你可以通过配置`generate.sh`中的`TARGET_DIR`值指定执行不同的代码生成命令。 6 | 7 | A simple best practice of `GEN`. You can configure `TARGET_DIR` value in `generate.sh` to generate different code. 8 | 9 | ## TARGET_DIR 10 | 11 | - `gen` 12 | a slim quick start. 
13 | 14 | - `ultimate` 15 | an ultimate quick start 16 | 17 | - `sync_table` 18 | a quick start to show how to sync table from database. 19 | 20 | - `without_db` 21 | a quick start to show how to generate code without database connection. 22 | -------------------------------------------------------------------------------- /examples/biz/query.go: -------------------------------------------------------------------------------- 1 | package biz 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | 7 | "gorm.io/gen/examples/dal/query" 8 | ) 9 | 10 | var q = query.Q 11 | // Query runs sample Take/Find/Where/Select/UpdateSimple calls on q.Mytable under ctx and prints each result. 12 | func Query(ctx context.Context) { 13 | t := q.Mytable 14 | do := t.WithContext(ctx) // use the caller's ctx, not context.Background(), so cancellation and deadlines apply 15 | 16 | data, err := do.Take() 17 | catchError("Take", err) 18 | fmt.Printf("got %+v\n", data) 19 | 20 | dataArray, err := do.Find() 21 | catchError("Find", err) 22 | fmt.Printf("got %+v\n", dataArray) 23 | 24 | data, err = do.Where(t.ID.Eq(1)).Take() 25 | catchError("Take", err) 26 | fmt.Printf("got %+v\n", data) 27 | 28 | dataArray, err = do.Where(t.Age.Gt(18)).Order(t.Username).Find() 29 | catchError("Find", err) 30 | fmt.Printf("got %+v\n", dataArray) 31 | 32 | dataArray, err = do.Select(t.ID, t.Username).Order(t.Age.Desc()).Find() 33 | catchError("Find", err) 34 | fmt.Printf("got %+v\n", dataArray) 35 | 36 | info, err := do.Where(t.ID.Eq(1)).UpdateSimple(t.Age.Add(1)) 37 | catchError("Update", err) 38 | fmt.Printf("got %+v\n", info) 39 | } 40 | // catchError prints detail and err when err is non-nil; it never stops execution. 41 | func catchError(detail string, err error) { 42 | if err != nil { 43 | fmt.Printf("%s: %v\n", detail, err) 44 | } 45 | } 46 | -------------------------------------------------------------------------------- /examples/cmd/from_object/generate.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "gorm.io/gen" 5 | "gorm.io/gen/helper" 6 | ) 7 | 8 | var detail, data helper.Object 9 | 10 | func init() { 11 | detail = &Demo{ 12 | structName: "Detail", 13 | fileName: "diy_data_detail", 14 | 
fields: []helper.Field{ 15 | &DemoField{ 16 | name: "Username", 17 | typ: "string", 18 | jsonTag: "username", 19 | comment: "用户名", 20 | }, 21 | &DemoField{ 22 | name: "Age", 23 | typ: "uint", 24 | jsonTag: "age", 25 | comment: "用户年龄", 26 | }, 27 | &DemoField{ 28 | name: "Phone", 29 | typ: "string", 30 | jsonTag: "phone", 31 | comment: "手机号", 32 | }, 33 | }, 34 | } 35 | 36 | data = &Demo{ 37 | structName: "Data", 38 | tableName: "data", 39 | fileName: "diy_data", 40 | fields: []helper.Field{ 41 | &DemoField{ 42 | name: "ID", 43 | typ: "uint", 44 | gormTag: "column:id;type:bigint unsigned;primaryKey;autoIncrement:true", 45 | jsonTag: "id", 46 | tag: `kms:"enc:aes"`, 47 | comment: "主键", 48 | }, 49 | &DemoField{ 50 | name: "UserInfo", 51 | typ: "[]Detail", 52 | jsonTag: "user_info", 53 | comment: "用户信息", 54 | }, 55 | &DemoField{ 56 | name: "Remark", 57 | typ: "json.RawMessage", 58 | gormTag: "column:detail", 59 | jsonTag: "remark", 60 | tag: `kms:"enc:aes"`, 61 | comment: "备注\n详细信息", 62 | }, 63 | }, 64 | } 65 | } 66 | 67 | func main() { 68 | g := gen.NewGenerator(gen.Config{ 69 | OutPath: "/tmp/gentest/query", 70 | ModelPkgPath: "/tmp/gentest/demo", 71 | }) 72 | 73 | g.GenerateModelFrom(detail) 74 | 75 | g.ApplyBasic(g.GenerateModelFrom(data)) 76 | 77 | g.Execute() 78 | } 79 | -------------------------------------------------------------------------------- /examples/cmd/from_object/prepare.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "strings" 5 | 6 | "gorm.io/gen/field" 7 | "gorm.io/gen/helper" 8 | ) 9 | 10 | var _ helper.Object = new(Demo) 11 | 12 | // Demo demo structure 13 | type Demo struct { 14 | structName string 15 | tableName string 16 | fileName string 17 | fields []helper.Field 18 | } 19 | 20 | // TableName return table name 21 | func (d *Demo) TableName() string { return d.tableName } 22 | 23 | // StructName return struct name 24 | func (d *Demo) StructName() string { return 
d.structName } 25 | 26 | // FileName return file name 27 | func (d *Demo) FileName() string { return d.fileName } 28 | 29 | // ImportPkgPaths return import package paths 30 | func (d *Demo) ImportPkgPaths() []string { return nil } 31 | 32 | // Fields return fields 33 | func (d *Demo) Fields() []helper.Field { return d.fields } 34 | 35 | // DemoField demo field 36 | type DemoField struct { 37 | name string 38 | typ string 39 | gormTag string 40 | jsonTag string 41 | tag string 42 | comment string 43 | } 44 | 45 | // Name return name 46 | func (f *DemoField) Name() string { return f.name } 47 | 48 | // Type return field type 49 | func (f *DemoField) Type() string { return f.typ } 50 | 51 | // ColumnName return column name 52 | func (f *DemoField) ColumnName() string { return strings.ToLower(f.name) } 53 | 54 | // GORMTag return gorm tag 55 | func (f *DemoField) GORMTag() string { return f.gormTag } 56 | 57 | // JSONTag return json tag 58 | func (f *DemoField) JSONTag() string { return f.jsonTag } 59 | 60 | // Tag return new tag 61 | func (f *DemoField) Tag() field.Tag { return field.Tag{} } 62 | 63 | // Comment return comment 64 | func (f *DemoField) Comment() string { return f.comment } 65 | -------------------------------------------------------------------------------- /examples/cmd/gen/generate.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "gorm.io/gen" 5 | "gorm.io/gen/examples/conf" 6 | "gorm.io/gen/examples/dal" 7 | ) 8 | 9 | func init() { 10 | dal.DB = dal.ConnectDB(conf.MySQLDSN).Debug() 11 | 12 | prepare(dal.DB) // prepare table for generate 13 | } 14 | 15 | func main() { 16 | g := gen.NewGenerator(gen.Config{ 17 | OutPath: "../../dal/query", 18 | }) 19 | 20 | g.UseDB(dal.DB) 21 | 22 | // generate all table from database 23 | g.ApplyBasic(g.GenerateAllTable()...) 
24 | 25 | g.Execute() 26 | } 27 | -------------------------------------------------------------------------------- /examples/cmd/gen/prepare.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "gorm.io/gorm" 5 | ) 6 | 7 | // prepare table for test 8 | 9 | const mytableSQL = "CREATE TABLE IF NOT EXISTS `mytables` (" + 10 | " `ID` int(11) NOT NULL," + 11 | " `username` varchar(16) DEFAULT NULL," + 12 | " `age` int(8) NOT NULL," + 13 | " `phone` varchar(11) NOT NULL," + 14 | " INDEX `idx_username` (`username`)" + 15 | ") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;" 16 | 17 | func prepare(db *gorm.DB) { 18 | db.Exec(mytableSQL) 19 | } 20 | -------------------------------------------------------------------------------- /examples/cmd/only_model/generate.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "gorm.io/gen" 5 | "gorm.io/gen/examples/conf" 6 | "gorm.io/gen/examples/dal" 7 | ) 8 | 9 | func init() { 10 | dal.DB = dal.ConnectDB(conf.MySQLDSN).Debug() 11 | 12 | prepare(dal.DB) // prepare table for generate 13 | } 14 | 15 | func main() { 16 | g := gen.NewGenerator(gen.Config{ 17 | OutPath: "/tmp/gentest/query", 18 | }) 19 | 20 | g.UseDB(dal.DB) 21 | 22 | g.GenerateAllTable() 23 | 24 | g.Execute() 25 | } 26 | -------------------------------------------------------------------------------- /examples/cmd/only_model/prepare.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "gorm.io/gorm" 5 | ) 6 | 7 | // prepare table for test 8 | 9 | const mytableSQL = "CREATE TABLE IF NOT EXISTS `mytables` (" + 10 | " `ID` int(11) NOT NULL," + 11 | " `username` varchar(16) DEFAULT NULL," + 12 | " `age` int(8) NOT NULL," + 13 | " `phone` varchar(11) NOT NULL," + 14 | " INDEX `idx_username` (`username`)" + 15 | ") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;" 16 | 17 | func prepare(db 
*gorm.DB) { 18 | db.Exec(mytableSQL) 19 | } 20 | -------------------------------------------------------------------------------- /examples/cmd/sync_table/generate.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "strings" 5 | 6 | "gorm.io/gen" 7 | "gorm.io/gen/examples/conf" 8 | "gorm.io/gen/examples/dal" 9 | "gorm.io/gorm" 10 | ) 11 | 12 | func init() { 13 | dal.DB = dal.ConnectDB(conf.MySQLDSN).Debug() 14 | 15 | prepare(dal.DB) // prepare table for generate 16 | } 17 | 18 | // dataMap mapping relationship 19 | var dataMap = map[string]func(gorm.ColumnType) (dataType string){ 20 | // int mapping 21 | "int": func(columnType gorm.ColumnType) (dataType string) { return "int32" }, 22 | 23 | // bool mapping 24 | "tinyint": func(columnType gorm.ColumnType) (dataType string) { 25 | ct, _ := columnType.ColumnType() 26 | if strings.HasPrefix(ct, "tinyint(1)") { 27 | return "bool" 28 | } 29 | return "byte" 30 | }, 31 | } 32 | 33 | func main() { 34 | g := gen.NewGenerator(gen.Config{ 35 | OutPath: "../../dal/query", 36 | ModelPkgPath: "../../dal/model", 37 | 38 | // generate model global configuration 39 | FieldNullable: true, // generate pointer when field is nullable 40 | FieldCoverable: true, // generate pointer when field has default value 41 | FieldWithIndexTag: true, // generate with gorm index tag 42 | FieldWithTypeTag: true, // generate with gorm column type tag 43 | }) 44 | 45 | g.UseDB(dal.DB) 46 | 47 | // specify diy mapping relationship 48 | g.WithDataTypeMap(dataMap) 49 | 50 | // generate all field with json tag end with "_example" 51 | g.WithJSONTagNameStrategy(func(c string) string { return c + "_example" }) 52 | 53 | mytable := g.GenerateModel("mytables") 54 | g.ApplyBasic(mytable) 55 | // g.ApplyBasic(g.GenerateAllTable()...) 
// generate all table in db server 56 | 57 | g.Execute() 58 | } 59 | -------------------------------------------------------------------------------- /examples/cmd/sync_table/prepare.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "gorm.io/gorm" 5 | ) 6 | 7 | // prepare table for test 8 | 9 | const mytableSQL = "CREATE TABLE IF NOT EXISTS `mytables` (" + 10 | " `ID` int(11) NOT NULL," + 11 | " `username` varchar(16) DEFAULT NULL," + 12 | " `age` int(8) NOT NULL," + 13 | " `phone` varchar(11) NOT NULL," + 14 | " INDEX `idx_username` (`username`)" + 15 | ") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;" 16 | 17 | func prepare(db *gorm.DB) { 18 | db.Exec(mytableSQL) 19 | } 20 | -------------------------------------------------------------------------------- /examples/cmd/ultimate/generate.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "gorm.io/gen" 5 | "gorm.io/gen/examples/conf" 6 | "gorm.io/gen/examples/dal" 7 | "gorm.io/gen/examples/dal/model" 8 | "gorm.io/gorm" 9 | ) 10 | 11 | func init() { 12 | dal.DB = dal.ConnectDB(conf.MySQLDSN).Debug() 13 | 14 | prepare(dal.DB) // prepare table for generate 15 | } 16 | 17 | var dataMap = map[string]func(gorm.ColumnType) (dataType string){ 18 | "int": func(columnType gorm.ColumnType) (dataType string) { return "int64" }, 19 | "json": func(columnType gorm.ColumnType) string { return "json.RawMessage" }, 20 | } 21 | 22 | func main() { 23 | g := gen.NewGenerator(gen.Config{ 24 | OutPath: "../../dal/query", 25 | Mode: gen.WithDefaultQuery, 26 | 27 | WithUnitTest: true, 28 | 29 | FieldNullable: true, 30 | FieldCoverable: true, 31 | FieldWithIndexTag: true, 32 | }) 33 | 34 | g.UseDB(dal.DB) 35 | 36 | g.WithDataTypeMap(dataMap) 37 | g.WithJSONTagNameStrategy(func(c string) string { return "-" }) 38 | 39 | g.ApplyBasic(model.Customer{}) 40 | g.ApplyBasic(g.GenerateAllTable()...) 
41 | 42 | g.Execute() 43 | } 44 | -------------------------------------------------------------------------------- /examples/cmd/ultimate/prepare.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "gorm.io/gorm" 5 | ) 6 | 7 | // prepare table for test 8 | 9 | const mytableSQL = "CREATE TABLE IF NOT EXISTS `mytables` (" + 10 | " `ID` int(11) NOT NULL," + 11 | " `username` varchar(16) DEFAULT NULL," + 12 | " `age` int(8) NOT NULL," + 13 | " `phone` varchar(11) NOT NULL," + 14 | " INDEX `idx_username` (`username`)" + 15 | ") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;" 16 | 17 | func prepare(db *gorm.DB) { 18 | db.Exec(mytableSQL) 19 | } 20 | -------------------------------------------------------------------------------- /examples/cmd/without_db/generate.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "gorm.io/gen" 5 | "gorm.io/gen/examples/dal/model" 6 | ) 7 | 8 | func main() { 9 | g := gen.NewGenerator(gen.Config{ 10 | OutPath: "../../dal/query", 11 | Mode: gen.WithDefaultQuery, 12 | }) 13 | 14 | // generate from struct in project 15 | g.ApplyBasic(model.Customer{}) 16 | 17 | g.Execute() 18 | } 19 | -------------------------------------------------------------------------------- /examples/conf/mysql.go: -------------------------------------------------------------------------------- 1 | package conf 2 | 3 | const MySQLDSN = "root:local_mysql_test@tcp(localhost:3306)/test?charset=utf8mb4&parseTime=True" 4 | 5 | const SQLiteDBName = "gen_sqlite.db" 6 | -------------------------------------------------------------------------------- /examples/dal/model/model.go: -------------------------------------------------------------------------------- 1 | package model 2 | 3 | import "gorm.io/gorm" 4 | 5 | // some struct implement manually 6 | 7 | // Customer a struct mapping to table customers 8 | type Customer struct { 9 | gorm.Model 10 
| 11 | Name string `gorm:"type:varchar(100);not null"` 12 | Age int `gorm:"type:int"` 13 | Phone string `gorm:"type:varchar(11)"` 14 | Address string `gorm:"type:text"` 15 | Amount float64 `gorm:"type:float"` 16 | } 17 | 18 | func (Customer) TableName() string { 19 | return "customers" 20 | } 21 | -------------------------------------------------------------------------------- /examples/dal/model/mytables.gen.go: -------------------------------------------------------------------------------- 1 | // Code generated by gorm.io/gen. DO NOT EDIT. 2 | // Code generated by gorm.io/gen. DO NOT EDIT. 3 | // Code generated by gorm.io/gen. DO NOT EDIT. 4 | 5 | package model 6 | 7 | const TableNameMytable = "mytables" 8 | 9 | // Mytable mapped from table 10 | type Mytable struct { 11 | ID int32 `gorm:"column:ID;type:int(11);not null" json:"ID_example"` 12 | Username *string `gorm:"column:username;type:varchar(16);index:idx_username,priority:1;default:NULL" json:"username_example"` 13 | Age int32 `gorm:"column:age;type:int(8);not null" json:"age_example"` 14 | Phone string `gorm:"column:phone;type:varchar(11);not null" json:"phone_example"` 15 | } 16 | 17 | // TableName Mytable's table name 18 | func (*Mytable) TableName() string { 19 | return TableNameMytable 20 | } 21 | -------------------------------------------------------------------------------- /examples/dal/mysql.go: -------------------------------------------------------------------------------- 1 | package dal 2 | 3 | import ( 4 | "fmt" 5 | "strings" 6 | 7 | "gorm.io/driver/mysql" 8 | "gorm.io/driver/sqlite" 9 | "gorm.io/gorm" 10 | ) 11 | 12 | var DB *gorm.DB 13 | 14 | func ConnectDB(dsn string) (db *gorm.DB) { 15 | var err error 16 | 17 | if strings.HasSuffix(dsn, "sqlite.db") { 18 | db, err = gorm.Open(sqlite.Open(dsn), &gorm.Config{}) 19 | } else { 20 | db, err = gorm.Open(mysql.Open(dsn)) 21 | } 22 | 23 | if err != nil { 24 | panic(fmt.Errorf("connect db fail: %w", err)) 25 | } 26 | 27 | return db 28 | } 29 | 
-------------------------------------------------------------------------------- /examples/dal/query/gen.go: -------------------------------------------------------------------------------- 1 | // Code generated by gorm.io/gen. DO NOT EDIT. 2 | // Code generated by gorm.io/gen. DO NOT EDIT. 3 | // Code generated by gorm.io/gen. DO NOT EDIT. 4 | 5 | package query 6 | 7 | import ( 8 | "context" 9 | "database/sql" 10 | 11 | "gorm.io/gorm" 12 | ) 13 | 14 | var ( 15 | Q = new(Query) 16 | Mytable *mytable 17 | ) 18 | 19 | func SetDefault(db *gorm.DB) { 20 | *Q = *Use(db) 21 | Mytable = &Q.Mytable 22 | } 23 | 24 | func Use(db *gorm.DB) *Query { 25 | return &Query{ 26 | db: db, 27 | Mytable: newMytable(db), 28 | } 29 | } 30 | 31 | type Query struct { 32 | db *gorm.DB 33 | 34 | Mytable mytable 35 | } 36 | 37 | func (q *Query) Available() bool { return q.db != nil } 38 | 39 | func (q *Query) clone(db *gorm.DB) *Query { 40 | return &Query{ 41 | db: db, 42 | Mytable: q.Mytable.clone(db), 43 | } 44 | } 45 | 46 | type queryCtx struct { 47 | Mytable mytableDo 48 | } 49 | 50 | func (q *Query) WithContext(ctx context.Context) *queryCtx { 51 | return &queryCtx{ 52 | Mytable: *q.Mytable.WithContext(ctx), 53 | } 54 | } 55 | 56 | func (q *Query) Transaction(fc func(tx *Query) error, opts ...*sql.TxOptions) error { 57 | return q.db.Transaction(func(tx *gorm.DB) error { return fc(q.clone(tx)) }, opts...) 
58 | } 59 | 60 | func (q *Query) Begin(opts ...*sql.TxOptions) *QueryTx { 61 | return &QueryTx{q.clone(q.db.Begin(opts...))} 62 | } 63 | 64 | type QueryTx struct{ *Query } 65 | 66 | func (q *QueryTx) Commit() error { 67 | return q.db.Commit().Error 68 | } 69 | 70 | func (q *QueryTx) Rollback() error { 71 | return q.db.Rollback().Error 72 | } 73 | 74 | func (q *QueryTx) SavePoint(name string) error { 75 | return q.db.SavePoint(name).Error 76 | } 77 | 78 | func (q *QueryTx) RollbackTo(name string) error { 79 | return q.db.RollbackTo(name).Error 80 | } 81 | -------------------------------------------------------------------------------- /examples/generate.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # TARGET_DIR="gen" 4 | # TARGET_DIR="ultimate" 5 | # TARGET_DIR="sync_table" 6 | TARGET_DIR="without_db" 7 | 8 | PROJECT_DIR=$(dirname "$0") 9 | GENERATE_DIR="$PROJECT_DIR/cmd/$TARGET_DIR" 10 | 11 | cd "$GENERATE_DIR" || exit 12 | 13 | echo "Start Generating" 14 | go run . 
15 | -------------------------------------------------------------------------------- /examples/go.mod: -------------------------------------------------------------------------------- 1 | module examples 2 | 3 | go 1.19 4 | 5 | require ( 6 | gorm.io/driver/mysql v1.5.6 7 | gorm.io/driver/sqlite v1.5.5 8 | gorm.io/gen v0.3.25 9 | gorm.io/gorm v1.25.9 10 | ) 11 | 12 | require ( 13 | github.com/go-sql-driver/mysql v1.7.0 // indirect 14 | github.com/jinzhu/inflection v1.0.0 // indirect 15 | github.com/jinzhu/now v1.1.5 // indirect 16 | github.com/mattn/go-sqlite3 v1.14.17 // indirect 17 | golang.org/x/mod v0.14.0 // indirect 18 | golang.org/x/sys v0.14.0 // indirect 19 | golang.org/x/tools v0.15.0 // indirect 20 | gorm.io/datatypes v1.1.1-0.20230130040222-c43177d3cf8c // indirect 21 | gorm.io/hints v1.1.0 // indirect 22 | gorm.io/plugin/dbresolver v1.5.0 // indirect 23 | ) 24 | 25 | replace gorm.io/gen => ../ 26 | -------------------------------------------------------------------------------- /examples/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | 7 | "gorm.io/gen/examples/biz" 8 | "gorm.io/gen/examples/conf" 9 | "gorm.io/gen/examples/dal" 10 | "gorm.io/gen/examples/dal/query" 11 | ) 12 | 13 | func init() { 14 | dal.DB = dal.ConnectDB(conf.MySQLDSN).Debug() 15 | } 16 | 17 | func main() { 18 | // start your project here 19 | fmt.Println("hello world") 20 | defer fmt.Println("bye~") 21 | 22 | query.SetDefault(dal.DB) 23 | biz.Query(context.Background()) 24 | } 25 | -------------------------------------------------------------------------------- /field/assign_attr.go: -------------------------------------------------------------------------------- 1 | package field 2 | 3 | import ( 4 | "reflect" 5 | "strings" 6 | 7 | "gorm.io/gorm" 8 | "gorm.io/gorm/utils/tests" 9 | ) 10 | 11 | var testDB, _ = gorm.Open(tests.DummyDialector{}, nil) 12 | 13 | type IValues interface { 
14 | Values() interface{} 15 | } 16 | 17 | type attrs struct { 18 | expr 19 | value interface{} 20 | db *gorm.DB 21 | selectFields []IColumnName 22 | omitFields []IColumnName 23 | } 24 | 25 | func (att *attrs) AssignExpr() expression { 26 | return att 27 | } 28 | 29 | func (att *attrs) BeCond() interface{} { 30 | return att.db.Statement.BuildCondition(att.Values()) 31 | } 32 | 33 | func (att *attrs) Values() interface{} { 34 | if att == nil || att.value == nil { 35 | return nil 36 | } 37 | if len(att.selectFields) == 0 && len(att.omitFields) == 0 { 38 | return att.value 39 | } 40 | values := make(map[string]interface{}) 41 | if value, ok := att.value.(map[string]interface{}); ok { 42 | values = value 43 | } else if value, ok := att.value.(*map[string]interface{}); ok { 44 | values = *value 45 | } else { 46 | reflectValue := reflect.Indirect(reflect.ValueOf(att.value)) 47 | for reflectValue.Kind() == reflect.Ptr || reflectValue.Kind() == reflect.Interface { 48 | reflectValue = reflect.Indirect(reflectValue) 49 | } 50 | switch reflectValue.Kind() { 51 | case reflect.Struct: 52 | if err := att.db.Statement.Parse(att.value); err == nil { 53 | ignoreZero := len(att.selectFields) == 0 54 | for _, f := range att.db.Statement.Schema.Fields { 55 | if f.Readable { 56 | if v, isZero := f.ValueOf(att.db.Statement.Context, reflectValue); !isZero || !ignoreZero { 57 | values[f.DBName] = v 58 | } 59 | } 60 | } 61 | } 62 | } 63 | } 64 | if len(att.selectFields) > 0 { 65 | fm, all := toFieldMap(att.selectFields) 66 | if all { 67 | return values 68 | } 69 | tvs := make(map[string]interface{}, len(fm)) 70 | for fn, vl := range values { 71 | if fm[fn] { 72 | tvs[fn] = vl 73 | } 74 | } 75 | return tvs 76 | } 77 | fm, all := toFieldMap(att.omitFields) 78 | if all { 79 | return map[string]interface{}{} 80 | } 81 | for fn := range fm { 82 | delete(values, fn) 83 | } 84 | return values 85 | } 86 | 87 | func toFieldMap(fields []IColumnName) (fieldsMap map[string]bool, all bool) { 88 | 
fieldsMap = make(map[string]bool, len(fields)) 89 | for _, f := range fields { 90 | if strings.HasSuffix(string(f.ColumnName()), "*") { 91 | all = true 92 | return 93 | } 94 | fieldsMap[string(f.ColumnName())] = true 95 | } 96 | return 97 | } 98 | 99 | func (att *attrs) Select(fields ...IColumnName) *attrs { 100 | if att == nil || att.db == nil { 101 | return att 102 | } 103 | att.selectFields = fields 104 | return att 105 | } 106 | 107 | func (att *attrs) Omit(fields ...IColumnName) *attrs { 108 | if att == nil || att.db == nil { 109 | return att 110 | } 111 | att.omitFields = fields 112 | return att 113 | } 114 | 115 | func Attrs(attr interface{}) *attrs { 116 | res := &attrs{db: testDB.Debug()} 117 | if attr != nil { 118 | res.value = attr 119 | } 120 | return res 121 | } 122 | -------------------------------------------------------------------------------- /field/asterisk.go: -------------------------------------------------------------------------------- 1 | package field 2 | 3 | import ( 4 | "gorm.io/gorm" 5 | "gorm.io/gorm/clause" 6 | ) 7 | 8 | // Asterisk a type of xxx.* 9 | type Asterisk struct{ asteriskExpr } 10 | 11 | // Count count 12 | func (a Asterisk) Count() Asterisk { 13 | var expr *clause.Expr 14 | switch { 15 | case a.e != nil: 16 | expr = &clause.Expr{ 17 | SQL: "COUNT(?)", 18 | Vars: []interface{}{a.e}, 19 | } 20 | case a.col.Table == "": 21 | expr = &clause.Expr{SQL: "COUNT(*)"} 22 | default: 23 | expr = &clause.Expr{ 24 | SQL: "COUNT(?.*)", 25 | Vars: []interface{}{clause.Table{Name: a.col.Table}}, 26 | } 27 | } 28 | return Asterisk{asteriskExpr{expr: a.setE(expr)}} 29 | } 30 | 31 | // Distinct distinct 32 | func (a Asterisk) Distinct() Asterisk { 33 | var expr *clause.Expr 34 | if a.col.Table == "" { 35 | expr = &clause.Expr{SQL: "DISTINCT *"} 36 | } else { 37 | expr = &clause.Expr{ 38 | SQL: "DISTINCT ?.*", 39 | Vars: []interface{}{clause.Table{Name: a.col.Table}}, 40 | } 41 | } 42 | return Asterisk{asteriskExpr{expr: a.setE(expr)}} 43 | } 
44 | 45 | type asteriskExpr struct{ expr } 46 | 47 | func (e asteriskExpr) BuildWithArgs(*gorm.Statement) (query sql, args []interface{}) { 48 | // if e.expr has no expression it must be directly calling for "*" or "xxx.*" 49 | if e.e != nil { 50 | return "?", []interface{}{e.e} 51 | } 52 | if e.col.Table == "" { 53 | return "*", nil 54 | } 55 | return "?.*", []interface{}{clause.Table{Name: e.col.Table}} 56 | } 57 | -------------------------------------------------------------------------------- /field/bool.go: -------------------------------------------------------------------------------- 1 | package field 2 | 3 | // Bool boolean type field 4 | type Bool Field 5 | 6 | // Not ... 7 | func (field Bool) Not() Bool { 8 | return Bool{field.not()} 9 | } 10 | 11 | // Is ... 12 | func (field Bool) Is(value bool) Expr { 13 | return field.is(value) 14 | } 15 | 16 | // And boolean and 17 | func (field Bool) And(value bool) Expr { 18 | return Bool{field.and(value)} 19 | } 20 | 21 | // Or boolean or 22 | func (field Bool) Or(value bool) Expr { 23 | return Bool{field.or(value)} 24 | } 25 | 26 | // Xor ... 27 | func (field Bool) Xor(value bool) Expr { 28 | return Bool{field.xor(value)} 29 | } 30 | 31 | // BitXor ... 32 | func (field Bool) BitXor(value bool) Expr { 33 | return Bool{field.bitXor(value)} 34 | } 35 | 36 | // BitAnd ... 37 | func (field Bool) BitAnd(value bool) Expr { 38 | return Bool{field.bitAnd(value)} 39 | } 40 | 41 | // BitOr ... 42 | func (field Bool) BitOr(value bool) Expr { 43 | return Bool{field.bitOr(value)} 44 | } 45 | 46 | // Value ... 47 | func (field Bool) Value(value bool) AssignExpr { 48 | return field.value(value) 49 | } 50 | 51 | // Zero ... 
52 | func (field Bool) Zero() AssignExpr { 53 | return field.value(false) 54 | } 55 | -------------------------------------------------------------------------------- /field/doc.go: -------------------------------------------------------------------------------- 1 | // Package field implement all type field and method 2 | package field 3 | -------------------------------------------------------------------------------- /field/example_test.go: -------------------------------------------------------------------------------- 1 | package field_test 2 | 3 | import ( 4 | "fmt" 5 | 6 | "gorm.io/gen/field" 7 | ) 8 | 9 | func ExampleFunc() { 10 | expr := field.Func.UnixTimestamp() 11 | 12 | sql, vars := field.BuildToString(expr) 13 | fmt.Println(sql, vars) 14 | 15 | sql, vars = field.BuildToString(expr.Mul(100)) 16 | fmt.Println(sql, vars) 17 | 18 | // Output: 19 | // UNIX_TIMESTAMP() [] 20 | // (UNIX_TIMESTAMP())*? [100] 21 | } 22 | -------------------------------------------------------------------------------- /field/field.go: -------------------------------------------------------------------------------- 1 | package field 2 | 3 | import ( 4 | "database/sql/driver" 5 | 6 | "gorm.io/gorm/clause" 7 | ) 8 | 9 | // ScanValuer interface for Field 10 | type ScanValuer interface { 11 | Scan(src interface{}) error // sql.Scanner 12 | Value() (driver.Value, error) // driver.Valuer 13 | } 14 | 15 | // Field a standard field struct 16 | type Field struct{ expr } 17 | 18 | // Eq judge equal 19 | func (field Field) Eq(value driver.Valuer) Expr { 20 | return expr{e: clause.Eq{Column: field.RawExpr(), Value: value}} 21 | } 22 | 23 | // Neq judge not equal 24 | func (field Field) Neq(value driver.Valuer) Expr { 25 | return expr{e: clause.Neq{Column: field.RawExpr(), Value: value}} 26 | } 27 | 28 | // In ... 29 | func (field Field) In(values ...driver.Valuer) Expr { 30 | return expr{e: clause.IN{Column: field.RawExpr(), Values: field.toSlice(values...)}} 31 | } 32 | 33 | // NotIn ... 
34 | func (field Field) NotIn(values ...driver.Valuer) Expr { 35 | return expr{e: clause.Not(field.In(values...).expression())} 36 | } 37 | 38 | // Gt ... 39 | func (field Field) Gt(value driver.Valuer) Expr { 40 | return expr{e: clause.Gt{Column: field.RawExpr(), Value: value}} 41 | } 42 | 43 | // Gte ... 44 | func (field Field) Gte(value driver.Valuer) Expr { 45 | return expr{e: clause.Gte{Column: field.RawExpr(), Value: value}} 46 | } 47 | 48 | // Lt ... 49 | func (field Field) Lt(value driver.Valuer) Expr { 50 | return expr{e: clause.Lt{Column: field.RawExpr(), Value: value}} 51 | } 52 | 53 | // Lte ... 54 | func (field Field) Lte(value driver.Valuer) Expr { 55 | return expr{e: clause.Lte{Column: field.RawExpr(), Value: value}} 56 | } 57 | 58 | // Like ... 59 | func (field Field) Like(value driver.Valuer) Expr { 60 | return expr{e: clause.Like{Column: field.RawExpr(), Value: value}} 61 | } 62 | 63 | // Value ... 64 | func (field Field) Value(value driver.Valuer) AssignExpr { 65 | return field.value(value) 66 | } 67 | 68 | // Sum ... 69 | func (field Field) Sum() Field { 70 | return Field{field.sum()} 71 | } 72 | 73 | // IfNull ... 74 | func (field Field) IfNull(value driver.Valuer) Expr { 75 | return field.ifNull(value) 76 | } 77 | 78 | // Field ... 
79 | func (field Field) Field(value []interface{}) Expr { 80 | return field.field(value) 81 | } 82 | 83 | func (field Field) toSlice(values ...driver.Valuer) []interface{} { 84 | slice := make([]interface{}, len(values)) 85 | for i, v := range values { 86 | slice[i] = v 87 | } 88 | return slice 89 | } 90 | -------------------------------------------------------------------------------- /field/field_test.go: -------------------------------------------------------------------------------- 1 | package field 2 | 3 | import ( 4 | "reflect" 5 | "strings" 6 | "sync" 7 | "testing" 8 | 9 | "gorm.io/gorm" 10 | "gorm.io/gorm/clause" 11 | "gorm.io/gorm/schema" 12 | "gorm.io/gorm/utils/tests" 13 | ) 14 | 15 | var db, _ = gorm.Open(tests.DummyDialector{}, nil) 16 | 17 | func GetStatement() *gorm.Statement { 18 | user, _ := schema.Parse(&User{}, &sync.Map{}, db.NamingStrategy) 19 | return &gorm.Statement{DB: db, Table: user.Table, Schema: user, Clauses: map[string]clause.Clause{}} 20 | } 21 | 22 | func CheckBuildExpr(t *testing.T, e Expr, result string, vars []interface{}) { 23 | stmt := GetStatement() 24 | 25 | e.expression().Build(stmt) 26 | 27 | sql := strings.TrimSpace(stmt.SQL.String()) 28 | if sql != result { 29 | t.Errorf("SQL expects %v got %v", result, sql) 30 | } 31 | 32 | if !reflect.DeepEqual(stmt.Vars, vars) { 33 | t.Errorf("Vars expects %+v got %v", stmt.Vars, vars) 34 | } 35 | } 36 | 37 | func BuildToString(e Expr) (string, []interface{}) { 38 | stmt := GetStatement() 39 | 40 | e.expression().Build(stmt) 41 | 42 | return stmt.SQL.String(), stmt.Vars 43 | } 44 | 45 | type User struct { 46 | gorm.Model 47 | Name string 48 | Age uint 49 | // Birthday *time.Time 50 | // Account Account 51 | // Pets []*Pet 52 | // Toys []Toy `gorm:"polymorphic:Owner"` 53 | // CompanyID *int 54 | // Company Company 55 | // ManagerID *uint 56 | // Manager *User 57 | // Team []User `gorm:"foreignkey:ManagerID"` 58 | // Languages []Language `gorm:"many2many:UserSpeak;"` 59 | // Friends 
[]*User `gorm:"many2many:user_friends;"` 60 | // Active bool 61 | } 62 | -------------------------------------------------------------------------------- /field/function.go: -------------------------------------------------------------------------------- 1 | package field 2 | 3 | import ( 4 | "strings" 5 | 6 | "gorm.io/gorm/clause" 7 | ) 8 | 9 | // Func sql functions 10 | var Func = new(function) 11 | 12 | type function struct{} 13 | 14 | // UnixTimestamp same as UNIX_TIMESTAMP([date]) 15 | func (f *function) UnixTimestamp(date ...string) Uint64 { 16 | if len(date) > 0 { 17 | return Uint64{expr{e: clause.Expr{SQL: "UNIX_TIMESTAMP(?)", Vars: []interface{}{date[0]}}}} 18 | } 19 | return Uint64{expr{e: clause.Expr{SQL: "UNIX_TIMESTAMP()"}}} 20 | } 21 | 22 | // FromUnixTime FROM_UNIXTIME(unix_timestamp[,format]) 23 | func (f *function) FromUnixTime(date uint64, format string) String { 24 | if strings.TrimSpace(format) != "" { 25 | return String{expr{e: clause.Expr{SQL: "FROM_UNIXTIME(?, ?)", Vars: []interface{}{date, format}}}} 26 | } 27 | return String{expr{e: clause.Expr{SQL: "FROM_UNIXTIME(?)", Vars: []interface{}{date}}}} 28 | } 29 | 30 | func (f *function) Rand() String { 31 | return String{expr{e: clause.Expr{SQL: "RAND()"}}} 32 | } 33 | 34 | func (f *function) Random() String { 35 | return String{expr{e: clause.Expr{SQL: "RANDOM()"}}} 36 | } 37 | -------------------------------------------------------------------------------- /field/serializer.go: -------------------------------------------------------------------------------- 1 | package field 2 | 3 | import ( 4 | "context" 5 | "gorm.io/gorm/clause" 6 | "gorm.io/gorm/schema" 7 | "reflect" 8 | 9 | "gorm.io/gorm" 10 | ) 11 | 12 | type ValuerType struct { 13 | Column string 14 | Value schema.SerializerValuerInterface 15 | } 16 | 17 | func (v ValuerType) GormValue(ctx context.Context, db *gorm.DB) (expr clause.Expr) { 18 | stmt := db.Statement.Schema 19 | field := stmt.LookUpField(v.Column) 20 | newValue, err := 
v.Value.Value(ctx, field, reflect.ValueOf(v.Value), v.Value) 21 | _ = db.AddError(err) 22 | return clause.Expr{SQL: "?", Vars: []interface{}{newValue}} 23 | } 24 | 25 | // Field2 a standard field struct 26 | type Serializer struct{ expr } 27 | 28 | // Eq judge equal 29 | func (field Serializer) Eq(value schema.SerializerValuerInterface) Expr { 30 | return expr{e: clause.Eq{Column: field.RawExpr(), Value: ValuerType{Column: field.ColumnName().String(), Value: value}}} 31 | } 32 | 33 | // Neq judge not equal 34 | func (field Serializer) Neq(value schema.SerializerValuerInterface) Expr { 35 | return expr{e: clause.Neq{Column: field.RawExpr(), Value: ValuerType{Column: field.ColumnName().String(), Value: value}}} 36 | } 37 | 38 | // In ... 39 | func (field Serializer) In(values ...schema.SerializerValuerInterface) Expr { 40 | return expr{e: clause.IN{Column: field.RawExpr(), Values: field.toSlice(values...)}} 41 | } 42 | 43 | // Gt ... 44 | func (field Serializer) Gt(value schema.SerializerValuerInterface) Expr { 45 | return expr{e: clause.Gt{Column: field.RawExpr(), Value: ValuerType{Column: field.ColumnName().String(), Value: value}}} 46 | } 47 | 48 | // Gte ... 49 | func (field Serializer) Gte(value schema.SerializerValuerInterface) Expr { 50 | return expr{e: clause.Gte{Column: field.RawExpr(), Value: ValuerType{Column: field.ColumnName().String(), Value: value}}} 51 | } 52 | 53 | // Lt ... 54 | func (field Serializer) Lt(value schema.SerializerValuerInterface) Expr { 55 | return expr{e: clause.Lt{Column: field.RawExpr(), Value: ValuerType{Column: field.ColumnName().String(), Value: value}}} 56 | } 57 | 58 | // Lte ... 59 | func (field Serializer) Lte(value schema.SerializerValuerInterface) Expr { 60 | return expr{e: clause.Lte{Column: field.RawExpr(), Value: ValuerType{Column: field.ColumnName().String(), Value: value}}} 61 | } 62 | 63 | // Like ... 
64 | func (field Serializer) Like(value schema.SerializerValuerInterface) Expr { 65 | return expr{e: clause.Like{Column: field.RawExpr(), Value: ValuerType{Column: field.ColumnName().String(), Value: value}}} 66 | } 67 | 68 | // Value ... 69 | func (field Serializer) Value(value schema.SerializerValuerInterface) AssignExpr { 70 | return field.value(ValuerType{Column: field.ColumnName().String(), Value: value}) 71 | } 72 | 73 | // Sum ... 74 | func (field Serializer) Sum() Field { 75 | return Field{field.sum()} 76 | } 77 | 78 | // IfNull ... 79 | func (field Serializer) IfNull(value schema.SerializerValuerInterface) Expr { 80 | return field.ifNull(ValuerType{Column: field.ColumnName().String(), Value: value}) 81 | } 82 | 83 | func (field Serializer) toSlice(values ...schema.SerializerValuerInterface) []interface{} { 84 | slice := make([]interface{}, len(values)) 85 | for i, v := range values { 86 | slice[i] = ValuerType{Column: field.ColumnName().String(), Value: v} 87 | } 88 | return slice 89 | } 90 | -------------------------------------------------------------------------------- /field/tag.go: -------------------------------------------------------------------------------- 1 | package field 2 | 3 | import ( 4 | "sort" 5 | "strings" 6 | ) 7 | 8 | const ( 9 | TagKeyGorm = "gorm" 10 | TagKeyJson = "json" 11 | 12 | //gorm tag 13 | TagKeyGormColumn = "column" 14 | TagKeyGormType = "type" 15 | TagKeyGormPrimaryKey = "primaryKey" 16 | TagKeyGormAutoIncrement = "autoIncrement" 17 | TagKeyGormNotNull = "not null" 18 | TagKeyGormUniqueIndex = "uniqueIndex" 19 | TagKeyGormIndex = "index" 20 | TagKeyGormDefault = "default" 21 | TagKeyGormComment = "comment" 22 | ) 23 | 24 | var ( 25 | tagKeyPriorities = map[string]int16{ 26 | TagKeyGorm: 100, 27 | TagKeyJson: 99, 28 | 29 | TagKeyGormColumn: 10, 30 | TagKeyGormType: 9, 31 | TagKeyGormPrimaryKey: 8, 32 | TagKeyGormAutoIncrement: 7, 33 | TagKeyGormNotNull: 6, 34 | TagKeyGormUniqueIndex: 5, 35 | TagKeyGormIndex: 4, 36 | 
TagKeyGormDefault: 3, 37 | TagKeyGormComment: 0, 38 | } 39 | ) 40 | 41 | type TagBuilder interface { 42 | Build() string 43 | } 44 | 45 | type Tag map[string]string 46 | 47 | func (tag Tag) Set(key, value string) Tag { 48 | tag[key] = value 49 | return tag 50 | } 51 | 52 | func (tag Tag) Remove(key string) Tag { 53 | delete(tag, key) 54 | return tag 55 | } 56 | 57 | func (tag Tag) Build() string { 58 | if len(tag) == 0 { 59 | return "" 60 | } 61 | 62 | tags := make([]string, 0, len(tag)) 63 | for _, k := range tagKeys(tag) { 64 | v := tag[k] 65 | if k == "" { 66 | continue 67 | } 68 | tags = append(tags, k+":\""+v+"\"") 69 | } 70 | return strings.Join(tags, " ") 71 | } 72 | 73 | type GormTag map[string][]string 74 | 75 | func (tag GormTag) Append(key string, values ...string) GormTag { 76 | if _, ok := tag[key]; ok { 77 | tag[key] = append(tag[key], values...) 78 | } else { 79 | tag[key] = values 80 | } 81 | return tag 82 | } 83 | 84 | func (tag GormTag) Set(key string, values ...string) GormTag { 85 | tag[key] = values 86 | return tag 87 | } 88 | 89 | func (tag GormTag) Remove(key string) GormTag { 90 | delete(tag, key) 91 | return tag 92 | } 93 | 94 | func (tag GormTag) Build() string { 95 | if len(tag) == 0 { 96 | return "" 97 | } 98 | tags := make([]string, 0, len(tag)) 99 | for _, k := range gormKeys(tag) { 100 | vs := tag[k] 101 | if len(vs) == 0 && k == "" { 102 | continue 103 | } 104 | if len(vs) == 0 { 105 | tags = append(tags, k) 106 | continue 107 | } 108 | for _, v := range vs { 109 | if k == "" && v == "" { 110 | continue 111 | } 112 | tv := make([]string, 0, 2) 113 | if k != "" { 114 | tv = append(tv, k) 115 | } 116 | if v != "" { 117 | tv = append(tv, v) 118 | } 119 | tags = append(tags, strings.Join(tv, ":")) 120 | } 121 | } 122 | 123 | return strings.Join(tags, ";") 124 | } 125 | 126 | func tagKeys(tag Tag) []string { 127 | keys := make([]string, 0, len(tag)) 128 | if len(tag) == 0 { 129 | return keys 130 | } 131 | for k := range tag { 132 | keys = 
append(keys, k) 133 | } 134 | return keySort(keys) 135 | } 136 | 137 | func gormKeys(tag GormTag) []string { 138 | keys := make([]string, 0, len(tag)) 139 | if len(tag) == 0 { 140 | return keys 141 | } 142 | for k := range tag { 143 | keys = append(keys, k) 144 | } 145 | return keySort(keys) 146 | } 147 | 148 | func keySort(keys []string) []string { 149 | if len(keys) == 0 { 150 | return keys 151 | } 152 | sort.Slice(keys, func(i, j int) bool { 153 | if tagKeyPriorities[keys[i]] == tagKeyPriorities[keys[j]] { 154 | return keys[i] <= keys[j] 155 | } 156 | return tagKeyPriorities[keys[i]] > tagKeyPriorities[keys[j]] 157 | }) 158 | return keys 159 | } 160 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module gorm.io/gen 2 | 3 | go 1.18 4 | 5 | require ( 6 | golang.org/x/tools v0.17.0 7 | gorm.io/datatypes v1.2.4 8 | gorm.io/gorm v1.25.12 9 | gorm.io/hints v1.1.0 10 | gorm.io/plugin/dbresolver v1.5.3 11 | ) 12 | 13 | require ( 14 | filippo.io/edwards25519 v1.1.0 // indirect 15 | github.com/go-sql-driver/mysql v1.8.1 // indirect 16 | github.com/google/uuid v1.3.0 // indirect 17 | github.com/jinzhu/inflection v1.0.0 // indirect 18 | github.com/jinzhu/now v1.1.5 // indirect 19 | golang.org/x/mod v0.14.0 // indirect 20 | golang.org/x/text v0.14.0 // indirect 21 | gorm.io/driver/mysql v1.5.7 // indirect 22 | ) 23 | -------------------------------------------------------------------------------- /helper/object.go: -------------------------------------------------------------------------------- 1 | package helper 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | 7 | "gorm.io/gen/field" 8 | ) 9 | 10 | // Object an object interface 11 | type Object interface { 12 | // TableName return table name 13 | TableName() string 14 | // StructName return struct name 15 | StructName() string 16 | // FileName return field name 17 | FileName() string 18 | // 
// importPkgS is an ordered list of import lines. Non-empty entries are quoted
// import paths (optionally preceded by an alias); empty entries are kept as
// group separators.
type importPkgS struct {
	paths []string
}

// Add returns a new list made of the receiver's paths followed by paths and a
// trailing empty separator. The receiver is left untouched.
//
// Each path is trimmed. A non-empty path that does not already end with a
// double quote is wrapped in double quotes. A path that is already present,
// either in the receiver or earlier in the same call, is skipped. Blank
// entries are always kept.
func (ip importPkgS) Add(paths ...string) *importPkgS {
	// Build into a fresh slice. Appending to ip.paths directly could write
	// into spare capacity shared with the receiver and with every other list
	// derived from it.
	merged := make([]string, len(ip.paths), len(ip.paths)+len(paths)+1)
	copy(merged, ip.paths)

	seen := make(map[string]struct{}, len(ip.paths)+len(paths))
	for _, p := range ip.paths {
		if p != "" {
			seen[p] = struct{}{}
		}
	}

	for _, p := range paths {
		p = strings.TrimSpace(p)
		if p == "" {
			merged = append(merged, p)
			continue
		}

		if !strings.HasSuffix(p, `"`) {
			p = `"` + p + `"`
		}

		if _, dup := seen[p]; dup {
			continue
		}
		seen[p] = struct{}{}
		merged = append(merged, p)
	}
	merged = append(merged, "")

	return &importPkgS{paths: merged}
}

// Paths returns the stored import lines in insertion order.
func (ip importPkgS) Paths() []string { return ip.paths }
...field.Expr) Dao 54 | Group(columns ...field.Expr) Dao 55 | Having(conds ...Condition) Dao 56 | Limit(limit int) Dao 57 | Offset(offset int) Dao 58 | Scopes(funcs ...func(Dao) Dao) Dao 59 | Unscoped() Dao 60 | Attrs(attrs ...field.AssignExpr) Dao 61 | Assign(attrs ...field.AssignExpr) Dao 62 | Joins(field field.RelationField) Dao 63 | Preload(field field.RelationField) Dao 64 | Clauses(conds ...clause.Expression) Dao 65 | 66 | Create(value interface{}) error 67 | CreateInBatches(value interface{}, batchSize int) error 68 | Save(value interface{}) error 69 | First() (result interface{}, err error) 70 | Take() (result interface{}, err error) 71 | Last() (result interface{}, err error) 72 | Find() (results interface{}, err error) 73 | FindInBatches(dest interface{}, batchSize int, fc func(tx Dao, batch int) error) error 74 | FirstOrInit() (result interface{}, err error) 75 | FirstOrCreate() (result interface{}, err error) 76 | Update(column field.Expr, value interface{}) (info ResultInfo, err error) 77 | UpdateSimple(columns ...field.AssignExpr) (info ResultInfo, err error) 78 | Updates(values interface{}) (info ResultInfo, err error) 79 | UpdateColumn(column field.Expr, value interface{}) (info ResultInfo, err error) 80 | UpdateColumns(values interface{}) (info ResultInfo, err error) 81 | UpdateColumnSimple(columns ...field.AssignExpr) (info ResultInfo, err error) 82 | Delete(...interface{}) (info ResultInfo, err error) 83 | Count() (int64, error) 84 | Row() *sql.Row 85 | Rows() (*sql.Rows, error) 86 | Scan(dest interface{}) error 87 | Pluck(column field.Expr, dest interface{}) error 88 | ScanRows(rows *sql.Rows, dest interface{}) error 89 | 90 | AddError(err error) error 91 | } 92 | -------------------------------------------------------------------------------- /internal/generate/generate.go: -------------------------------------------------------------------------------- 1 | package generate 2 | 3 | import ( 4 | "fmt" 5 | "regexp" 6 | "strings" 7 | 8 | 
"gorm.io/gorm" 9 | "gorm.io/gorm/schema" 10 | 11 | "gorm.io/gen/internal/model" 12 | ) 13 | 14 | /* 15 | ** The feature of mapping table from database server to Golang struct 16 | ** Provided by @qqxhb 17 | */ 18 | 19 | func getFields(db *gorm.DB, conf *model.Config, columns []*model.Column) (fields []*model.Field) { 20 | for _, col := range columns { 21 | col.SetDataTypeMap(conf.DataTypeMap) 22 | col.WithNS(conf.FieldJSONTagNS) 23 | 24 | m := col.ToField(conf.FieldNullable, conf.FieldCoverable, conf.FieldSignable) 25 | 26 | if filterField(m, conf.FilterOpts) == nil { 27 | continue 28 | } 29 | if _, ok := col.ColumnType.ColumnType(); ok && !conf.FieldWithTypeTag { // remove type tag if FieldWithTypeTag == false 30 | m.GORMTag.Remove("type") 31 | } 32 | 33 | m = modifyField(m, conf.ModifyOpts) 34 | if ns, ok := db.NamingStrategy.(schema.NamingStrategy); ok { 35 | ns.SingularTable = true 36 | m.Name = ns.SchemaName(ns.TablePrefix + m.Name) 37 | } else if db.NamingStrategy != nil { 38 | m.Name = db.NamingStrategy.SchemaName(m.Name) 39 | } 40 | 41 | fields = append(fields, m) 42 | } 43 | for _, create := range conf.CreateOpts { 44 | m := create.Operator()(nil) 45 | if m.Relation != nil { 46 | if m.Relation.Model() != nil { 47 | stmt := gorm.Statement{DB: db} 48 | _ = stmt.Parse(m.Relation.Model()) 49 | if stmt.Schema != nil { 50 | m.Relation.AppendChildRelation(ParseStructRelationShip(&stmt.Schema.Relationships)...) 
// modelNameReg accepts names that consist only of word characters.
var modelNameReg = regexp.MustCompile(`^\w+$`)

// checkStructName validates a model (struct) name.
//
// An empty name is accepted. Otherwise the name must match modelNameReg and
// its first byte must be an upper-case ASCII letter; an error describing the
// first violated rule is returned.
func checkStructName(name string) error {
	switch {
	case name == "":
		// nothing to validate
		return nil
	case !modelNameReg.MatchString(name):
		return fmt.Errorf("model name cannot contains invalid character")
	case name[0] < 'A', name[0] > 'Z':
		return fmt.Errorf("model name must be initial capital")
	default:
		return nil
	}
}
getTableType(db *gorm.DB, tableName string) (result gorm.TableType, err error) { 35 | if db == nil || db.Migrator() == nil { 36 | return 37 | } 38 | return db.Migrator().TableType(tableName) 39 | } 40 | 41 | func getTableColumns(db *gorm.DB, schemaName string, tableName string, indexTag bool) (result []*model.Column, err error) { 42 | if db == nil { 43 | return nil, errors.New("gorm db is nil") 44 | } 45 | 46 | mt := getTableInfo(db) 47 | result, err = mt.GetTableColumns(schemaName, tableName) 48 | if err != nil { 49 | return nil, err 50 | } 51 | if !indexTag || len(result) == 0 { 52 | return result, nil 53 | } 54 | 55 | index, err := mt.GetTableIndex(schemaName, tableName) 56 | if err != nil { //ignore find index err 57 | db.Logger.Warn(context.Background(), "GetTableIndex for %s,err=%s", tableName, err.Error()) 58 | return result, nil 59 | } 60 | if len(index) == 0 { 61 | return result, nil 62 | } 63 | 64 | im := model.GroupByColumn(index) 65 | for _, c := range result { 66 | c.Indexes = im[c.Name()] 67 | } 68 | return result, nil 69 | } 70 | 71 | type tableInfo struct{ *gorm.DB } 72 | 73 | // GetTableColumns struct 74 | func (t *tableInfo) GetTableColumns(schemaName string, tableName string) (result []*model.Column, err error) { 75 | types, err := t.Migrator().ColumnTypes(tableName) 76 | if err != nil { 77 | return nil, err 78 | } 79 | for _, column := range types { 80 | result = append(result, &model.Column{ColumnType: column, TableName: tableName, UseScanType: t.Dialector.Name() != "mysql" && t.Dialector.Name() != "sqlite"}) 81 | } 82 | return result, nil 83 | } 84 | 85 | // GetTableIndex index 86 | func (t *tableInfo) GetTableIndex(schemaName string, tableName string) (indexes []gorm.Index, err error) { 87 | return t.Migrator().GetIndexes(tableName) 88 | } 89 | -------------------------------------------------------------------------------- /internal/generate/test.go: -------------------------------------------------------------------------------- 1 | package 
// isCapitalize reports whether s starts with an upper-case ASCII letter.
func isCapitalize(s string) bool {
	return s != "" && s[0] >= 'A' && s[0] <= 'Z'
}

// isEnd reports whether b terminates a name: it returns false for ASCII
// letters, digits, '-', '_' and '.', and true for every other byte.
func isEnd(b byte) bool {
	switch {
	case b >= 'a' && b <= 'z':
		return false
	case b >= 'A' && b <= 'Z':
		return false
	case b >= '0' && b <= '9':
		return false
	case b == '-' || b == '_' || b == '.':
		return false
	default:
		return true
	}
}

// getPackageName returns the part of fullName before the first '.', after
// leading '*' characters have been removed. Without a '.', the whole
// (pointer-stripped) name is returned.
func getPackageName(fullName string) string {
	return strings.Split(delPointerSym(fullName), ".")[0]
}

// strOutRange reports whether index lies past the last byte of str.
func strOutRange(index int, str string) bool {
	return index >= len(str)
}

// delPointerSym removes all leading '*' characters from name.
func delPointerSym(name string) string {
	return strings.TrimLeft(name, "*")
}

// getPureName returns the lower-cased first character of s, ignoring leading
// '*' characters. It returns "" when nothing is left after stripping the
// pointer symbols; previously that case panicked with an index out of range.
func getPureName(s string) string {
	name := delPointerSym(s)
	if name == "" {
		return ""
	}
	return string(strings.ToLower(name)[0])
}

// getStructName returns the segment of t after the last '.'.
// The result is returned as is; it is not capitalized.
func getStructName(t string) string {
	list := strings.Split(t, ".")
	return list[len(list)-1]
}

// uncaptialize returns s with its first byte lower-cased.
func uncaptialize(s string) string {
	if s == "" {
		return ""
	}

	return strings.ToLower(s[:1]) + s[1:]
}
FieldCoverable bool // generate pointer when field has default value 40 | FieldSignable bool // detect integer field's unsigned type, adjust generated data type 41 | FieldWithIndexTag bool // generate with gorm index tag 42 | FieldWithTypeTag bool // generate with gorm column type tag 43 | 44 | FieldJSONTagNS func(columnName string) string 45 | 46 | ModifyOpts []FieldOption 47 | FilterOpts []FieldOption 48 | CreateOpts []FieldOption 49 | } 50 | 51 | // MethodConfig method configuration 52 | type MethodConfig struct { 53 | MethodOpts []MethodOption 54 | } 55 | 56 | // Preprocess revise invalid field 57 | func (cfg *Config) Preprocess() *Config { 58 | if cfg.ModelPkg == "" { 59 | cfg.ModelPkg = DefaultModelPkg 60 | } 61 | cfg.ModelPkg = filepath.Base(cfg.ModelPkg) 62 | 63 | cfg.ModifyOpts, cfg.FilterOpts, cfg.CreateOpts, cfg.MethodOpts = sortOptions(cfg.ModelOpts) 64 | 65 | return cfg 66 | } 67 | 68 | // GetNames get names 69 | func (cfg *Config) GetNames() (tableName, structName, fileName string) { 70 | tableName, structName = cfg.TableName, cfg.ModelName 71 | 72 | if cfg.ModelNameNS != nil { 73 | structName = cfg.ModelNameNS(tableName) 74 | } 75 | 76 | if cfg.TableNameNS != nil { 77 | tableName = cfg.TableNameNS(tableName) 78 | } 79 | if tableName != "" && !strings.HasPrefix(tableName, cfg.TablePrefix) { 80 | tableName = cfg.TablePrefix + tableName 81 | } 82 | 83 | fileName = strings.ToLower(tableName) 84 | if cfg.FileNameNS != nil { 85 | fileName = cfg.FileNameNS(cfg.TableName) 86 | } 87 | 88 | return 89 | } 90 | 91 | // GetModelMethods get diy method from option 92 | func (cfg *Config) GetModelMethods() (methods []interface{}) { 93 | if cfg == nil { 94 | return 95 | } 96 | 97 | for _, opt := range cfg.MethodOpts { 98 | methods = append(methods, opt.Methods()...) 
99 | } 100 | return 101 | } 102 | 103 | // GetSchemaName get schema name 104 | func (cfg *Config) GetSchemaName(db *gorm.DB) string { 105 | if cfg == nil { 106 | return "" 107 | } 108 | 109 | for _, opt := range cfg.SchemaNameOpts { 110 | if name := opt(db); name != "" { 111 | return name 112 | } 113 | } 114 | return "" 115 | } 116 | -------------------------------------------------------------------------------- /internal/model/options.go: -------------------------------------------------------------------------------- 1 | package model 2 | 3 | import ( 4 | "gorm.io/gorm" 5 | ) 6 | 7 | // SchemaNameOpt schema name option 8 | type SchemaNameOpt func(*gorm.DB) string 9 | 10 | var defaultSchemaNameOpt = SchemaNameOpt(func(db *gorm.DB) string { return db.Migrator().CurrentDatabase() }) 11 | 12 | // Option field option 13 | type Option interface{ OptionType() string } 14 | 15 | const fieldType = "field" 16 | 17 | // FieldOption ... 18 | type FieldOption interface { 19 | Option 20 | Operator() func(*Field) *Field 21 | } 22 | 23 | const methodType = "method" 24 | 25 | // MethodOption ... 
26 | type MethodOption interface { 27 | Option 28 | Methods() (methods []interface{}) 29 | } 30 | 31 | var ( 32 | _ Option = ModifyFieldOpt(nil) 33 | _ Option = FilterFieldOpt(nil) 34 | _ Option = CreateFieldOpt(nil) 35 | 36 | _ Option = AddMethodOpt(nil) 37 | ) 38 | 39 | // ModifyFieldOpt modify field option 40 | type ModifyFieldOpt func(*Field) *Field 41 | 42 | // OptionType implement for interface Option 43 | func (ModifyFieldOpt) OptionType() string { return fieldType } 44 | 45 | // Operator implement for FieldOpt 46 | func (o ModifyFieldOpt) Operator() func(*Field) *Field { return o } 47 | 48 | // FilterFieldOpt filter field option 49 | type FilterFieldOpt ModifyFieldOpt 50 | 51 | // OptionType implement for interface Option 52 | func (FilterFieldOpt) OptionType() string { return fieldType } 53 | 54 | // Operator implement for FieldOpt 55 | func (o FilterFieldOpt) Operator() func(*Field) *Field { return o } 56 | 57 | // CreateFieldOpt create field option 58 | type CreateFieldOpt ModifyFieldOpt 59 | 60 | // OptionType implement for interface Option 61 | func (CreateFieldOpt) OptionType() string { return fieldType } 62 | 63 | // Operator implement for FieldOpt 64 | func (o CreateFieldOpt) Operator() func(*Field) *Field { return o } 65 | 66 | // AddMethodOpt diy method option 67 | type AddMethodOpt func() (methods []interface{}) 68 | 69 | // OptionType implement for interface Option 70 | func (AddMethodOpt) OptionType() string { return methodType } 71 | 72 | // Methods ... 
// InterfacePath interface path
type InterfacePath struct {
	Name     string
	FullName string
	Files    []string
	Package  string
}

// GetInterfacePath get interface's directory path and all files it contains.
//
// v must be a function; one InterfacePath is produced for each of its
// parameters. FullName is the parameter type as printed by reflect and Name
// is the part after the last '.'. Files lists the existing Go files of the
// package that declares the type: for types whose printed name starts with
// "main." the package directory is taken from the source file of a caller
// three frames up, otherwise the type's package path is resolved with
// go/build.
//
// An error is returned when v is not a function, when the package cannot be
// resolved, or when no Go file is found for a parameter.
func GetInterfacePath(v interface{}) (paths []*InterfacePath, err error) {
	value := reflect.ValueOf(v)
	if value.Kind() != reflect.Func {
		err = fmt.Errorf("model param is not function:%s", value.String())
		return
	}

	fnType := value.Type()
	for i := 0; i < fnType.NumIn(); i++ {
		var path InterfacePath
		arg := fnType.In(i)
		path.FullName = arg.String()

		// keep the last segment; with no '.', LastIndex is -1 and the whole
		// name is kept
		path.Name = path.FullName[strings.LastIndex(path.FullName, ".")+1:]

		ctx := build.Default
		var p *build.Package

		if path.FullName == "main" || strings.HasPrefix(path.FullName, "main.") {
			// NOTE: the frame depth is relative to this function; the call
			// must stay directly in GetInterfacePath.
			_, file, _, _ := runtime.Caller(3)
			p, err = ctx.ImportDir(filepath.Dir(file), build.ImportComment)
		} else {
			p, err = ctx.Import(arg.PkgPath(), "", build.ImportComment)
		}

		if err != nil {
			return
		}

		for _, file := range p.GoFiles {
			// filepath.Join produces an OS-correct, cleaned path
			goFile := filepath.Join(p.Dir, file)
			if fileExists(goFile) {
				path.Files = append(path.Files, goFile)
			}
		}

		if len(path.Files) == 0 {
			err = fmt.Errorf("interface file not found:%s", value.String())
			return
		}

		paths = append(paths, &path)
	}

	return
}

// fileExists reports whether os.Stat succeeds for path.
func fileExists(path string) bool {
	_, err := os.Stat(path)
	return err == nil
}
92 | default: 93 | return nil, fmt.Errorf("method param must be a function or struct") 94 | } 95 | 96 | var p *build.Package 97 | 98 | // if struct in main file 99 | ctx := build.Default 100 | if method.pkgPath == "main" { 101 | var skip int 102 | var file string 103 | for { 104 | _, file, _, _ = runtime.Caller(skip) 105 | if !(strings.Contains(file, "gorm/gen/generator.go") || strings.Contains(file, "gorm/gen/internal")) || file == "" { 106 | break 107 | } 108 | skip++ 109 | } 110 | p, err = ctx.ImportDir(filepath.Dir(file), build.ImportComment) 111 | } else { 112 | p, err = ctx.Import(method.pkgPath, "", build.ImportComment) 113 | } 114 | if err != nil { 115 | return nil, fmt.Errorf("diy method dir not found:%s.%s %w", method.pkgPath, method.MethodName, err) 116 | } 117 | 118 | for _, file := range p.GoFiles { 119 | goFile := p.Dir + "/" + file 120 | if fileExists(goFile) { 121 | method.pkgFiles = append(method.pkgFiles, goFile) 122 | } 123 | } 124 | if len(method.pkgFiles) == 0 { 125 | return nil, fmt.Errorf("diy method file not found:%s.%s", method.pkgPath, method.MethodName) 126 | } 127 | 128 | // read files got methods 129 | return method, method.LoadMethods() 130 | } 131 | -------------------------------------------------------------------------------- /internal/parser/method.go: -------------------------------------------------------------------------------- 1 | package parser 2 | 3 | import ( 4 | "fmt" 5 | "go/ast" 6 | "go/parser" 7 | "go/token" 8 | "io/ioutil" 9 | "strings" 10 | ) 11 | 12 | func DefaultMethodTableName(structName string) *Method { 13 | return &Method{ 14 | Receiver: Param{IsPointer: true, Type: structName}, 15 | MethodName: "TableName", 16 | Doc: fmt.Sprint("TableName ", structName, "'s table name "), 17 | Result: []Param{{Type: "string"}}, 18 | Body: fmt.Sprintf("{\n\treturn TableName%s\n} ", structName), 19 | } 20 | } 21 | 22 | // Method Apply to query struct and base struct custom method 23 | type Method struct { 24 | Receiver Param 25 
| MethodName string 26 | Doc string 27 | Params []Param 28 | Result []Param 29 | Body string 30 | } 31 | 32 | // FuncSign function signature 33 | func (m Method) FuncSign() string { 34 | return fmt.Sprintf("%s(%s) (%s)", m.MethodName, m.GetParamInTmpl(), m.GetResultParamInTmpl()) 35 | } 36 | 37 | // GetBaseStructTmpl return method bind info string 38 | func (m *Method) GetBaseStructTmpl() string { 39 | return m.Receiver.TmplString() 40 | } 41 | 42 | // GetParamInTmpl return param list 43 | func (m *Method) GetParamInTmpl() string { 44 | return paramToString(m.Params) 45 | } 46 | 47 | // GetResultParamInTmpl return result list 48 | func (m *Method) GetResultParamInTmpl() string { 49 | return paramToString(m.Result) 50 | } 51 | 52 | // paramToString param list to string used in tmpl 53 | func paramToString(params []Param) string { 54 | res := make([]string, len(params)) 55 | for i, param := range params { 56 | res[i] = param.TmplString() 57 | } 58 | return strings.Join(res, ",") 59 | } 60 | 61 | // DocComment return comment sql add "//" every line 62 | func (m *Method) DocComment() string { 63 | return strings.Replace(strings.TrimSpace(m.Doc), "\n", "\n//", -1) 64 | } 65 | 66 | // DIYMethods user Custom methods bind to db base struct 67 | type DIYMethods struct { 68 | BaseStructType string 69 | MethodName string 70 | pkgPath string 71 | currentFile string 72 | pkgFiles []string 73 | Methods []*Method 74 | } 75 | 76 | func (m *DIYMethods) parserPath(path string) error { 77 | pathList := strings.Split(path, ".") 78 | if len(pathList) < 3 { 79 | return fmt.Errorf("parser diy method error") 80 | } 81 | 82 | m.pkgPath = strings.Join(pathList[:len(pathList)-2], ".") 83 | methodName := pathList[len(pathList)-1] 84 | m.MethodName = methodName[:len(methodName)-3] 85 | 86 | structName := pathList[len(pathList)-2] 87 | m.BaseStructType = strings.Trim(structName, "()*") 88 | return nil 89 | } 90 | 91 | // Visit ast visit function 92 | func (m *DIYMethods) Visit(n ast.Node) (w 
// getBody returns the source text of fileName between the 1-based offsets
// start (inclusive) and end (exclusive), i.e. bytes [start-1, end-1).
//
// It returns "{}" when the file cannot be read. It also returns "{}" when the
// offsets do not describe a range inside the file, instead of panicking or
// returning bytes from beyond the end of the file content.
func getBody(fileName string, start, end int) string {
	src, err := ioutil.ReadFile(fileName)
	if err != nil {
		return "{}"
	}

	lo, hi := start-1, end-1
	if lo < 0 || hi < lo || hi > len(src) {
		return "{}"
	}
	return string(src[lo:hi])
}
fields.List { 14 | if field.Names == nil { 15 | par := Param{} 16 | par.astGetParamType(field) 17 | pars = append(pars, par) 18 | continue 19 | } 20 | 21 | for _, name := range field.Names { 22 | par := Param{ 23 | Name: name.Name, 24 | } 25 | par.astGetParamType(field) 26 | pars = append(pars, par) 27 | continue 28 | } 29 | } 30 | return pars 31 | } 32 | 33 | func fixParamPackagePath(imports map[string]string, params []Param) { 34 | for i := range params { 35 | if importPath, exist := imports[params[i].Package]; exist { 36 | params[i].PkgPath = importPath 37 | } 38 | } 39 | } 40 | -------------------------------------------------------------------------------- /internal/template/base.go: -------------------------------------------------------------------------------- 1 | package template 2 | 3 | const NotEditMark = ` 4 | // Code generated by gorm.io/gen. DO NOT EDIT. 5 | // Code generated by gorm.io/gen. DO NOT EDIT. 6 | // Code generated by gorm.io/gen. DO NOT EDIT. 7 | ` 8 | 9 | const Header = NotEditMark + ` 10 | package {{.Package}} 11 | 12 | import( 13 | {{range .ImportPkgPaths}}{{.}}` + "\n" + `{{end}} 14 | ) 15 | ` 16 | -------------------------------------------------------------------------------- /internal/template/model.go: -------------------------------------------------------------------------------- 1 | package template 2 | 3 | // Model used as a variable because it cannot load template file after packed, params still can pass file 4 | const Model = NotEditMark + ` 5 | package {{.StructInfo.Package}} 6 | 7 | import ( 8 | "encoding/json" 9 | "time" 10 | 11 | "gorm.io/datatypes" 12 | "gorm.io/gorm" 13 | "gorm.io/gorm/schema" 14 | {{range .ImportPkgPaths}}{{.}} ` + "\n" + `{{end}} 15 | ) 16 | 17 | {{if .TableName -}}const TableName{{.ModelStructName}} = "{{.TableName}}"{{- end}} 18 | 19 | // {{.ModelStructName}} {{.StructComment}} 20 | type {{.ModelStructName}} struct { 21 | {{range .Fields}} 22 | {{if .MultilineComment -}} 23 | /* 24 | 
// Pool goroutine pool
type Pool interface {
	// Wait waits for a token.
	Wait()
	// Done returns a token.
	Done()
	// Num returns the number of tokens currently handed out.
	Num() int
	// Size returns the total number of tokens.
	Size() int

	// WaitAll blocks until all tokens have been returned.
	WaitAll()
	// AsyncWaitAll waits in the background until all tokens have been
	// returned and then signals on the returned channel.
	AsyncWaitAll() <-chan struct{}
}

// pool implements Pool with a buffered channel as token bucket and a
// WaitGroup that tracks outstanding tokens.
type pool struct {
	pool chan struct{}

	wg sync.WaitGroup
}

// Init creates the token channel with capacity size. For a negative size no
// channel is created and Wait and Done do nothing.
func (p *pool) Init(size int) {
	if size >= 0 {
		p.pool = make(chan struct{}, size)
	}
}

// Wait registers one outstanding token and blocks until the token channel
// accepts it.
func (p *pool) Wait() {
	if p.pool != nil {
		p.wg.Add(1)
		p.pool <- struct{}{}
	}
}

// Done takes one token back out of the channel and marks it as returned.
func (p *pool) Done() {
	if p.pool != nil {
		<-p.pool
		p.wg.Done()
	}
}

// Num returns the number of tokens currently in the channel.
func (p *pool) Num() int {
	if p.pool != nil {
		return len(p.pool)
	}
	return 0
}

// Size returns the capacity of the token channel.
func (p *pool) Size() int {
	if p.pool != nil {
		return cap(p.pool)
	}
	return 0
}

// WaitAll blocks until every token registered by Wait has been returned.
func (p *pool) WaitAll() {
	p.wg.Wait()
}

// AsyncWaitAll starts a goroutine that calls WaitAll and then delivers one
// value on the returned channel before closing it.
//
// The channel has room for that value, so the goroutine always terminates,
// even when the caller never reads from the channel.
func (p *pool) AsyncWaitAll() <-chan struct{} {
	sig := make(chan struct{}, 1)
	go func() {
		p.WaitAll()
		sig <- struct{}{}
		close(sig)
	}()
	return sig
}
clause.OnConflict) error { 59 | for _, item := range c.DoUpdates { 60 | switch item.Value.(type) { 61 | case clause.Expr, *clause.Expr: 62 | return errors.New("OnConflict clause assignment with gorm.Expr is banned for security reasons for now") 63 | } 64 | } 65 | return nil 66 | } 67 | 68 | func checkLocking(c clause.Locking) error { 69 | if strength := strings.ToUpper(strings.TrimSpace(c.Strength)); strength != "UPDATE" && strength != "SHARE" { 70 | return errors.New("Locking clause's Strength only allow assignments of UPDATE/SHARE") 71 | } 72 | if c.Table.Raw { 73 | return errors.New("Locking clause's Table cannot be set Raw==true") 74 | } 75 | if options := strings.ToUpper(strings.TrimSpace(c.Options)); options != "" && options != "NOWAIT" && options != "SKIP LOCKED" { 76 | return errors.New("Locking clause's Options only allow assignments of NOWAIT/SKIP LOCKED for now") 77 | } 78 | return nil 79 | } 80 | 81 | // checkInsert check if clause.Insert is safe 82 | // https://dev.mysql.com/doc/refman/8.0/en/sql-statements.html#insert 83 | func checkInsert(c clause.Insert) error { 84 | if c.Table.Raw == true { 85 | return errors.New("Table Raw cannot be true") 86 | } 87 | 88 | if c.Modifier == "" { 89 | return nil 90 | } 91 | 92 | var priority, ignore string 93 | if modifiers := strings.SplitN(strings.ToUpper(strings.TrimSpace(c.Modifier)), " ", 2); len(modifiers) == 2 { 94 | priority, ignore = strings.TrimSpace(modifiers[0]), strings.TrimSpace(modifiers[1]) 95 | } else { 96 | ignore = strings.TrimSpace(modifiers[0]) 97 | } 98 | if priority != "" && !in(priority, "LOW_PRIORITY", "DELAYED", "HIGH_PRIORITY") { 99 | return errors.New("invalid priority value") 100 | } 101 | if ignore != "" && ignore != "IGNORE" { 102 | return errors.New("invalid modifiers value, should be IGNORE") 103 | } 104 | return nil 105 | } 106 | 107 | func in(s string, v ...string) bool { 108 | for _, vv := range v { 109 | if vv == s { 110 | return true 111 | } 112 | } 113 | return false 114 | } 
115 | -------------------------------------------------------------------------------- /tests/.expect/dal_1/model/banks.gen.go: -------------------------------------------------------------------------------- 1 | // Code generated by gorm.io/gen. DO NOT EDIT. 2 | // Code generated by gorm.io/gen. DO NOT EDIT. 3 | // Code generated by gorm.io/gen. DO NOT EDIT. 4 | 5 | package model 6 | 7 | const TableNameBank = "banks" 8 | 9 | // Bank mapped from table 10 | type Bank struct { 11 | ID int64 `gorm:"column:id;primaryKey;autoIncrement:true" json:"id"` 12 | Name string `gorm:"column:name" json:"name"` 13 | Address string `gorm:"column:address" json:"address"` 14 | Scale int64 `gorm:"column:scale" json:"scale"` 15 | } 16 | 17 | // TableName Bank's table name 18 | func (*Bank) TableName() string { 19 | return TableNameBank 20 | } 21 | -------------------------------------------------------------------------------- /tests/.expect/dal_1/model/credit_cards.gen.go: -------------------------------------------------------------------------------- 1 | // Code generated by gorm.io/gen. DO NOT EDIT. 2 | // Code generated by gorm.io/gen. DO NOT EDIT. 3 | // Code generated by gorm.io/gen. DO NOT EDIT. 
4 | 5 | package model 6 | 7 | import ( 8 | "time" 9 | 10 | "gorm.io/gorm" 11 | ) 12 | 13 | const TableNameCreditCard = "credit_cards" 14 | 15 | // CreditCard mapped from table 16 | type CreditCard struct { 17 | ID int64 `gorm:"column:id;primaryKey;autoIncrement:true" json:"id"` 18 | CreatedAt time.Time `gorm:"column:created_at" json:"created_at"` 19 | UpdatedAt time.Time `gorm:"column:updated_at" json:"updated_at"` 20 | DeletedAt gorm.DeletedAt `gorm:"column:deleted_at" json:"deleted_at"` 21 | Number string `gorm:"column:number" json:"number"` 22 | CustomerRefer int64 `gorm:"column:customer_refer" json:"customer_refer"` 23 | BankID int64 `gorm:"column:bank_id" json:"bank_id"` 24 | } 25 | 26 | // TableName CreditCard's table name 27 | func (*CreditCard) TableName() string { 28 | return TableNameCreditCard 29 | } 30 | -------------------------------------------------------------------------------- /tests/.expect/dal_1/model/customers.gen.go: -------------------------------------------------------------------------------- 1 | // Code generated by gorm.io/gen. DO NOT EDIT. 2 | // Code generated by gorm.io/gen. DO NOT EDIT. 3 | // Code generated by gorm.io/gen. DO NOT EDIT. 
4 | 5 | package model 6 | 7 | import ( 8 | "time" 9 | 10 | "gorm.io/gorm" 11 | ) 12 | 13 | const TableNameCustomer = "customers" 14 | 15 | // Customer mapped from table 16 | type Customer struct { 17 | ID int64 `gorm:"column:id;primaryKey;autoIncrement:true" json:"id"` 18 | CreatedAt time.Time `gorm:"column:created_at" json:"created_at"` 19 | UpdatedAt time.Time `gorm:"column:updated_at" json:"updated_at"` 20 | DeletedAt gorm.DeletedAt `gorm:"column:deleted_at" json:"deleted_at"` 21 | BankID int64 `gorm:"column:bank_id" json:"bank_id"` 22 | } 23 | 24 | // TableName Customer's table name 25 | func (*Customer) TableName() string { 26 | return TableNameCustomer 27 | } 28 | -------------------------------------------------------------------------------- /tests/.expect/dal_1/model/people.gen.go: -------------------------------------------------------------------------------- 1 | // Code generated by gorm.io/gen. DO NOT EDIT. 2 | // Code generated by gorm.io/gen. DO NOT EDIT. 3 | // Code generated by gorm.io/gen. DO NOT EDIT. 
4 | 5 | package model 6 | 7 | import ( 8 | "time" 9 | 10 | "gorm.io/gorm" 11 | ) 12 | 13 | const TableNamePerson = "people" 14 | 15 | // Person mapped from table 16 | type Person struct { 17 | ID int64 `gorm:"column:id;primaryKey;autoIncrement:true" json:"id"` 18 | Name string `gorm:"column:name" json:"name"` 19 | Alias_ string `gorm:"column:alias" json:"alias"` 20 | Age int32 `gorm:"column:age" json:"age"` 21 | Flag bool `gorm:"column:flag" json:"flag"` 22 | AnotherFlag int32 `gorm:"column:another_flag" json:"another_flag"` 23 | Commit string `gorm:"column:commit" json:"commit"` 24 | First bool `gorm:"column:First" json:"First"` 25 | Bit []uint8 `gorm:"column:bit" json:"bit"` 26 | Small int32 `gorm:"column:small" json:"small"` 27 | DeletedAt gorm.DeletedAt `gorm:"column:deleted_at" json:"deleted_at"` 28 | Score float64 `gorm:"column:score" json:"score"` 29 | Number int32 `gorm:"column:number" json:"number"` 30 | Birth time.Time `gorm:"column:birth;default:CURRENT_TIMESTAMP" json:"birth"` 31 | XMLHTTPRequest string `gorm:"column:xmlHTTPRequest;default:' '" json:"xmlHTTPRequest"` 32 | JStr string `gorm:"column:jStr" json:"jStr"` 33 | Geo string `gorm:"column:geo" json:"geo"` 34 | Mint int32 `gorm:"column:mint" json:"mint"` 35 | Blank string `gorm:"column:blank;default:' '" json:"blank"` 36 | Remark string `gorm:"column:remark" json:"remark"` 37 | LongRemark string `gorm:"column:long_remark" json:"long_remark"` 38 | } 39 | 40 | // TableName Person's table name 41 | func (*Person) TableName() string { 42 | return TableNamePerson 43 | } 44 | -------------------------------------------------------------------------------- /tests/.expect/dal_1/model/users.gen.go: -------------------------------------------------------------------------------- 1 | // Code generated by gorm.io/gen. DO NOT EDIT. 2 | // Code generated by gorm.io/gen. DO NOT EDIT. 3 | // Code generated by gorm.io/gen. DO NOT EDIT. 
4 | 5 | package model 6 | 7 | import ( 8 | "time" 9 | ) 10 | 11 | const TableNameUser = "users" 12 | 13 | // User mapped from table 14 | type User struct { 15 | ID int64 `gorm:"column:id;primaryKey;autoIncrement:true" json:"id"` 16 | CreatedAt time.Time `gorm:"column:created_at" json:"created_at"` 17 | Name string `gorm:"column:name;comment:oneline" json:"name"` // oneline 18 | Address string `gorm:"column:address" json:"address"` 19 | RegisterTime time.Time `gorm:"column:register_time" json:"register_time"` 20 | /* 21 | multiline 22 | line1 23 | line2 24 | */ 25 | Alive bool `gorm:"column:alive;comment:multiline\nline1\nline2" json:"alive"` 26 | CompanyID int64 `gorm:"column:company_id;default:666" json:"company_id"` 27 | PrivateURL string `gorm:"column:private_url;default:https://a.b.c" json:"private_url"` 28 | } 29 | 30 | // TableName User's table name 31 | func (*User) TableName() string { 32 | return TableNameUser 33 | } 34 | -------------------------------------------------------------------------------- /tests/.expect/dal_1/query/gen.go: -------------------------------------------------------------------------------- 1 | // Code generated by gorm.io/gen. DO NOT EDIT. 2 | // Code generated by gorm.io/gen. DO NOT EDIT. 3 | // Code generated by gorm.io/gen. DO NOT EDIT. 4 | 5 | package query 6 | 7 | import ( 8 | "context" 9 | "database/sql" 10 | 11 | "gorm.io/gorm" 12 | 13 | "gorm.io/gen" 14 | 15 | "gorm.io/plugin/dbresolver" 16 | ) 17 | 18 | var ( 19 | Q = new(Query) 20 | Bank *bank 21 | CreditCard *creditCard 22 | Customer *customer 23 | Person *person 24 | User *user 25 | ) 26 | 27 | func SetDefault(db *gorm.DB, opts ...gen.DOOption) { 28 | *Q = *Use(db, opts...) 
29 | Bank = &Q.Bank 30 | CreditCard = &Q.CreditCard 31 | Customer = &Q.Customer 32 | Person = &Q.Person 33 | User = &Q.User 34 | } 35 | 36 | func Use(db *gorm.DB, opts ...gen.DOOption) *Query { 37 | return &Query{ 38 | db: db, 39 | Bank: newBank(db, opts...), 40 | CreditCard: newCreditCard(db, opts...), 41 | Customer: newCustomer(db, opts...), 42 | Person: newPerson(db, opts...), 43 | User: newUser(db, opts...), 44 | } 45 | } 46 | 47 | type Query struct { 48 | db *gorm.DB 49 | 50 | Bank bank 51 | CreditCard creditCard 52 | Customer customer 53 | Person person 54 | User user 55 | } 56 | 57 | func (q *Query) Available() bool { return q.db != nil } 58 | 59 | func (q *Query) clone(db *gorm.DB) *Query { 60 | return &Query{ 61 | db: db, 62 | Bank: q.Bank.clone(db), 63 | CreditCard: q.CreditCard.clone(db), 64 | Customer: q.Customer.clone(db), 65 | Person: q.Person.clone(db), 66 | User: q.User.clone(db), 67 | } 68 | } 69 | 70 | func (q *Query) ReadDB() *Query { 71 | return q.ReplaceDB(q.db.Clauses(dbresolver.Read)) 72 | } 73 | 74 | func (q *Query) WriteDB() *Query { 75 | return q.ReplaceDB(q.db.Clauses(dbresolver.Write)) 76 | } 77 | 78 | func (q *Query) ReplaceDB(db *gorm.DB) *Query { 79 | return &Query{ 80 | db: db, 81 | Bank: q.Bank.replaceDB(db), 82 | CreditCard: q.CreditCard.replaceDB(db), 83 | Customer: q.Customer.replaceDB(db), 84 | Person: q.Person.replaceDB(db), 85 | User: q.User.replaceDB(db), 86 | } 87 | } 88 | 89 | type queryCtx struct { 90 | Bank *bankDo 91 | CreditCard *creditCardDo 92 | Customer *customerDo 93 | Person *personDo 94 | User *userDo 95 | } 96 | 97 | func (q *Query) WithContext(ctx context.Context) *queryCtx { 98 | return &queryCtx{ 99 | Bank: q.Bank.WithContext(ctx), 100 | CreditCard: q.CreditCard.WithContext(ctx), 101 | Customer: q.Customer.WithContext(ctx), 102 | Person: q.Person.WithContext(ctx), 103 | User: q.User.WithContext(ctx), 104 | } 105 | } 106 | 107 | func (q *Query) Transaction(fc func(tx *Query) error, opts ...*sql.TxOptions) error 
{ 108 | return q.db.Transaction(func(tx *gorm.DB) error { return fc(q.clone(tx)) }, opts...) 109 | } 110 | 111 | func (q *Query) Begin(opts ...*sql.TxOptions) *QueryTx { 112 | tx := q.db.Begin(opts...) 113 | return &QueryTx{Query: q.clone(tx), Error: tx.Error} 114 | } 115 | 116 | type QueryTx struct { 117 | *Query 118 | Error error 119 | } 120 | 121 | func (q *QueryTx) Commit() error { 122 | return q.db.Commit().Error 123 | } 124 | 125 | func (q *QueryTx) Rollback() error { 126 | return q.db.Rollback().Error 127 | } 128 | 129 | func (q *QueryTx) SavePoint(name string) error { 130 | return q.db.SavePoint(name).Error 131 | } 132 | 133 | func (q *QueryTx) RollbackTo(name string) error { 134 | return q.db.RollbackTo(name).Error 135 | } 136 | -------------------------------------------------------------------------------- /tests/.expect/dal_2/model/banks.gen.go: -------------------------------------------------------------------------------- 1 | // Code generated by gorm.io/gen. DO NOT EDIT. 2 | // Code generated by gorm.io/gen. DO NOT EDIT. 3 | // Code generated by gorm.io/gen. DO NOT EDIT. 4 | 5 | package model 6 | 7 | const TableNameBank = "banks" 8 | 9 | // Bank mapped from table 10 | type Bank struct { 11 | ID int64 `gorm:"column:id;primaryKey;autoIncrement:true" json:"-"` 12 | Name *string `gorm:"column:name" json:"-"` 13 | Address *string `gorm:"column:address" json:"-"` 14 | Scale *int64 `gorm:"column:scale" json:"-"` 15 | } 16 | 17 | // TableName Bank's table name 18 | func (*Bank) TableName() string { 19 | return TableNameBank 20 | } 21 | -------------------------------------------------------------------------------- /tests/.expect/dal_2/model/credit_cards.gen.go: -------------------------------------------------------------------------------- 1 | // Code generated by gorm.io/gen. DO NOT EDIT. 2 | // Code generated by gorm.io/gen. DO NOT EDIT. 3 | // Code generated by gorm.io/gen. DO NOT EDIT. 
4 | 5 | package model 6 | 7 | import ( 8 | "time" 9 | 10 | "gorm.io/gorm" 11 | ) 12 | 13 | const TableNameCreditCard = "credit_cards" 14 | 15 | // CreditCard mapped from table 16 | type CreditCard struct { 17 | ID int64 `gorm:"column:id;primaryKey;autoIncrement:true" json:"-"` 18 | CreatedAt *time.Time `gorm:"column:created_at" json:"-"` 19 | UpdatedAt *time.Time `gorm:"column:updated_at" json:"-"` 20 | DeletedAt gorm.DeletedAt `gorm:"column:deleted_at;index:idx_credit_cards_deleted_at,priority:1" json:"-"` 21 | Number *string `gorm:"column:number" json:"-"` 22 | CustomerRefer *int64 `gorm:"column:customer_refer" json:"-"` 23 | BankID *int64 `gorm:"column:bank_id" json:"-"` 24 | } 25 | 26 | // TableName CreditCard's table name 27 | func (*CreditCard) TableName() string { 28 | return TableNameCreditCard 29 | } 30 | -------------------------------------------------------------------------------- /tests/.expect/dal_2/model/customers.gen.go: -------------------------------------------------------------------------------- 1 | // Code generated by gorm.io/gen. DO NOT EDIT. 2 | // Code generated by gorm.io/gen. DO NOT EDIT. 3 | // Code generated by gorm.io/gen. DO NOT EDIT. 
4 | 5 | package model 6 | 7 | import ( 8 | "time" 9 | 10 | "gorm.io/gorm" 11 | ) 12 | 13 | const TableNameCustomer = "customers" 14 | 15 | // Customer mapped from table 16 | type Customer struct { 17 | ID int64 `gorm:"column:id;primaryKey;autoIncrement:true" json:"-"` 18 | CreatedAt *time.Time `gorm:"column:created_at" json:"-"` 19 | UpdatedAt *time.Time `gorm:"column:updated_at" json:"-"` 20 | DeletedAt gorm.DeletedAt `gorm:"column:deleted_at;index:idx_customers_deleted_at,priority:1" json:"-"` 21 | BankID *int64 `gorm:"column:bank_id" json:"-"` 22 | } 23 | 24 | // TableName Customer's table name 25 | func (*Customer) TableName() string { 26 | return TableNameCustomer 27 | } 28 | -------------------------------------------------------------------------------- /tests/.expect/dal_2/model/people.gen.go: -------------------------------------------------------------------------------- 1 | // Code generated by gorm.io/gen. DO NOT EDIT. 2 | // Code generated by gorm.io/gen. DO NOT EDIT. 3 | // Code generated by gorm.io/gen. DO NOT EDIT. 
4 | 5 | package model 6 | 7 | import ( 8 | "time" 9 | 10 | "gorm.io/gorm" 11 | ) 12 | 13 | const TableNamePerson = "people" 14 | 15 | // Person mapped from table 16 | type Person struct { 17 | ID int64 `gorm:"column:id;primaryKey;autoIncrement:true" json:"-"` 18 | Name *string `gorm:"column:name" json:"-"` 19 | Alias_ *string `gorm:"column:alias" json:"-"` 20 | Age *int32 `gorm:"column:age" json:"-"` 21 | Flag *bool `gorm:"column:flag" json:"-"` 22 | AnotherFlag *int32 `gorm:"column:another_flag" json:"-"` 23 | Commit *string `gorm:"column:commit" json:"-"` 24 | First *bool `gorm:"column:First" json:"-"` 25 | Bit *[]uint8 `gorm:"column:bit" json:"-"` 26 | Small *int32 `gorm:"column:small" json:"-"` 27 | DeletedAt gorm.DeletedAt `gorm:"column:deleted_at" json:"-"` 28 | Score *float64 `gorm:"column:score" json:"-"` 29 | Number *int32 `gorm:"column:number" json:"-"` 30 | Birth *time.Time `gorm:"column:birth;default:CURRENT_TIMESTAMP" json:"-"` 31 | XMLHTTPRequest *string `gorm:"column:xmlHTTPRequest;default:' '" json:"-"` 32 | JStr *string `gorm:"column:jStr" json:"-"` 33 | Geo *string `gorm:"column:geo" json:"-"` 34 | Mint *int32 `gorm:"column:mint" json:"-"` 35 | Blank *string `gorm:"column:blank;default:' '" json:"-"` 36 | Remark *string `gorm:"column:remark" json:"-"` 37 | LongRemark *string `gorm:"column:long_remark" json:"-"` 38 | } 39 | 40 | // TableName Person's table name 41 | func (*Person) TableName() string { 42 | return TableNamePerson 43 | } 44 | -------------------------------------------------------------------------------- /tests/.expect/dal_2/model/users.gen.go: -------------------------------------------------------------------------------- 1 | // Code generated by gorm.io/gen. DO NOT EDIT. 2 | // Code generated by gorm.io/gen. DO NOT EDIT. 3 | // Code generated by gorm.io/gen. DO NOT EDIT. 
4 | 5 | package model 6 | 7 | import ( 8 | "time" 9 | ) 10 | 11 | const TableNameUser = "users" 12 | 13 | // User mapped from table 14 | type User struct { 15 | ID int64 `gorm:"column:id;primaryKey;autoIncrement:true" json:"-"` 16 | CreatedAt *time.Time `gorm:"column:created_at" json:"-"` 17 | Name *string `gorm:"column:name;index:idx_name,priority:1;index:idx_name_company_id,priority:1;comment:oneline" json:"-"` // oneline 18 | Address *string `gorm:"column:address" json:"-"` 19 | RegisterTime *time.Time `gorm:"column:register_time" json:"-"` 20 | /* 21 | multiline 22 | line1 23 | line2 24 | */ 25 | Alive *bool `gorm:"column:alive;comment:multiline\nline1\nline2" json:"-"` 26 | CompanyID *int64 `gorm:"column:company_id;index:idx_name_company_id,priority:2;default:666" json:"-"` 27 | PrivateURL *string `gorm:"column:private_url;default:https://a.b.c" json:"-"` 28 | } 29 | 30 | // TableName User's table name 31 | func (*User) TableName() string { 32 | return TableNameUser 33 | } 34 | -------------------------------------------------------------------------------- /tests/.expect/dal_2/query/banks.gen_test.go: -------------------------------------------------------------------------------- 1 | // Code generated by gorm.io/gen. DO NOT EDIT. 2 | // Code generated by gorm.io/gen. DO NOT EDIT. 3 | // Code generated by gorm.io/gen. DO NOT EDIT. 
4 | 5 | package query 6 | 7 | import ( 8 | "context" 9 | "fmt" 10 | "testing" 11 | 12 | "gorm.io/gen" 13 | "gorm.io/gen/field" 14 | "gorm.io/gen/tests/.gen/dal_2/model" 15 | "gorm.io/gorm/clause" 16 | ) 17 | 18 | func init() { 19 | InitializeDB() 20 | err := _gen_test_db.AutoMigrate(&model.Bank{}) 21 | if err != nil { 22 | fmt.Printf("Error: AutoMigrate(&model.Bank{}) fail: %s", err) 23 | } 24 | } 25 | 26 | func Test_bankQuery(t *testing.T) { 27 | bank := newBank(_gen_test_db) 28 | bank = *bank.As(bank.TableName()) 29 | _do := bank.WithContext(context.Background()).Debug() 30 | 31 | primaryKey := field.NewString(bank.TableName(), clause.PrimaryKey) 32 | _, err := _do.Unscoped().Where(primaryKey.IsNotNull()).Delete() 33 | if err != nil { 34 | t.Error("clean table fail:", err) 35 | return 36 | } 37 | 38 | _, ok := bank.GetFieldByName("") 39 | if ok { 40 | t.Error("GetFieldByName(\"\") from bank success") 41 | } 42 | 43 | err = _do.Create(&model.Bank{}) 44 | if err != nil { 45 | t.Error("create item in table fail:", err) 46 | } 47 | 48 | err = _do.Save(&model.Bank{}) 49 | if err != nil { 50 | t.Error("create item in table fail:", err) 51 | } 52 | 53 | err = _do.CreateInBatches([]*model.Bank{{}, {}}, 10) 54 | if err != nil { 55 | t.Error("create item in table fail:", err) 56 | } 57 | 58 | _, err = _do.Select(bank.ALL).Take() 59 | if err != nil { 60 | t.Error("Take() on table fail:", err) 61 | } 62 | 63 | _, err = _do.First() 64 | if err != nil { 65 | t.Error("First() on table fail:", err) 66 | } 67 | 68 | _, err = _do.Last() 69 | if err != nil { 70 | t.Error("First() on table fail:", err) 71 | } 72 | 73 | _, err = _do.Where(primaryKey.IsNotNull()).FindInBatch(10, func(tx gen.Dao, batch int) error { return nil }) 74 | if err != nil { 75 | t.Error("FindInBatch() on table fail:", err) 76 | } 77 | 78 | err = _do.Where(primaryKey.IsNotNull()).FindInBatches(&[]*model.Bank{}, 10, func(tx gen.Dao, batch int) error { return nil }) 79 | if err != nil { 80 | 
t.Error("FindInBatches() on table fail:", err) 81 | } 82 | 83 | _, err = _do.Select(bank.ALL).Where(primaryKey.IsNotNull()).Order(primaryKey.Desc()).Find() 84 | if err != nil { 85 | t.Error("Find() on table fail:", err) 86 | } 87 | 88 | _, err = _do.Distinct(primaryKey).Take() 89 | if err != nil { 90 | t.Error("select Distinct() on table fail:", err) 91 | } 92 | 93 | _, err = _do.Select(bank.ALL).Omit(primaryKey).Take() 94 | if err != nil { 95 | t.Error("Omit() on table fail:", err) 96 | } 97 | 98 | _, err = _do.Group(primaryKey).Find() 99 | if err != nil { 100 | t.Error("Group() on table fail:", err) 101 | } 102 | 103 | _, err = _do.Scopes(func(dao gen.Dao) gen.Dao { return dao.Where(primaryKey.IsNotNull()) }).Find() 104 | if err != nil { 105 | t.Error("Scopes() on table fail:", err) 106 | } 107 | 108 | _, _, err = _do.FindByPage(0, 1) 109 | if err != nil { 110 | t.Error("FindByPage() on table fail:", err) 111 | } 112 | 113 | _, err = _do.ScanByPage(&model.Bank{}, 0, 1) 114 | if err != nil { 115 | t.Error("ScanByPage() on table fail:", err) 116 | } 117 | 118 | _, err = _do.Attrs(primaryKey).Assign(primaryKey).FirstOrInit() 119 | if err != nil { 120 | t.Error("FirstOrInit() on table fail:", err) 121 | } 122 | 123 | _, err = _do.Attrs(primaryKey).Assign(primaryKey).FirstOrCreate() 124 | if err != nil { 125 | t.Error("FirstOrCreate() on table fail:", err) 126 | } 127 | 128 | var _a _another 129 | var _aPK = field.NewString(_a.TableName(), "id") 130 | 131 | err = _do.Join(&_a, primaryKey.EqCol(_aPK)).Scan(map[string]interface{}{}) 132 | if err != nil { 133 | t.Error("Join() on table fail:", err) 134 | } 135 | 136 | err = _do.LeftJoin(&_a, primaryKey.EqCol(_aPK)).Scan(map[string]interface{}{}) 137 | if err != nil { 138 | t.Error("LeftJoin() on table fail:", err) 139 | } 140 | 141 | _, err = _do.Not().Or().Clauses().Take() 142 | if err != nil { 143 | t.Error("Not/Or/Clauses on table fail:", err) 144 | } 145 | } 146 | 
-------------------------------------------------------------------------------- /tests/.expect/dal_2/query/gen.go: -------------------------------------------------------------------------------- 1 | // Code generated by gorm.io/gen. DO NOT EDIT. 2 | // Code generated by gorm.io/gen. DO NOT EDIT. 3 | // Code generated by gorm.io/gen. DO NOT EDIT. 4 | 5 | package query 6 | 7 | import ( 8 | "context" 9 | "database/sql" 10 | 11 | "gorm.io/gorm" 12 | 13 | "gorm.io/gen" 14 | 15 | "gorm.io/plugin/dbresolver" 16 | ) 17 | 18 | var ( 19 | Q = new(Query) 20 | Bank *bank 21 | CreditCard *creditCard 22 | Customer *customer 23 | Person *person 24 | User *user 25 | ) 26 | 27 | func SetDefault(db *gorm.DB, opts ...gen.DOOption) { 28 | *Q = *Use(db, opts...) 29 | Bank = &Q.Bank 30 | CreditCard = &Q.CreditCard 31 | Customer = &Q.Customer 32 | Person = &Q.Person 33 | User = &Q.User 34 | } 35 | 36 | func Use(db *gorm.DB, opts ...gen.DOOption) *Query { 37 | return &Query{ 38 | db: db, 39 | Bank: newBank(db, opts...), 40 | CreditCard: newCreditCard(db, opts...), 41 | Customer: newCustomer(db, opts...), 42 | Person: newPerson(db, opts...), 43 | User: newUser(db, opts...), 44 | } 45 | } 46 | 47 | type Query struct { 48 | db *gorm.DB 49 | 50 | Bank bank 51 | CreditCard creditCard 52 | Customer customer 53 | Person person 54 | User user 55 | } 56 | 57 | func (q *Query) Available() bool { return q.db != nil } 58 | 59 | func (q *Query) clone(db *gorm.DB) *Query { 60 | return &Query{ 61 | db: db, 62 | Bank: q.Bank.clone(db), 63 | CreditCard: q.CreditCard.clone(db), 64 | Customer: q.Customer.clone(db), 65 | Person: q.Person.clone(db), 66 | User: q.User.clone(db), 67 | } 68 | } 69 | 70 | func (q *Query) ReadDB() *Query { 71 | return q.ReplaceDB(q.db.Clauses(dbresolver.Read)) 72 | } 73 | 74 | func (q *Query) WriteDB() *Query { 75 | return q.ReplaceDB(q.db.Clauses(dbresolver.Write)) 76 | } 77 | 78 | func (q *Query) ReplaceDB(db *gorm.DB) *Query { 79 | return &Query{ 80 | db: db, 81 | Bank: 
q.Bank.replaceDB(db), 82 | CreditCard: q.CreditCard.replaceDB(db), 83 | Customer: q.Customer.replaceDB(db), 84 | Person: q.Person.replaceDB(db), 85 | User: q.User.replaceDB(db), 86 | } 87 | } 88 | 89 | type queryCtx struct { 90 | Bank *bankDo 91 | CreditCard *creditCardDo 92 | Customer *customerDo 93 | Person *personDo 94 | User *userDo 95 | } 96 | 97 | func (q *Query) WithContext(ctx context.Context) *queryCtx { 98 | return &queryCtx{ 99 | Bank: q.Bank.WithContext(ctx), 100 | CreditCard: q.CreditCard.WithContext(ctx), 101 | Customer: q.Customer.WithContext(ctx), 102 | Person: q.Person.WithContext(ctx), 103 | User: q.User.WithContext(ctx), 104 | } 105 | } 106 | 107 | func (q *Query) Transaction(fc func(tx *Query) error, opts ...*sql.TxOptions) error { 108 | return q.db.Transaction(func(tx *gorm.DB) error { return fc(q.clone(tx)) }, opts...) 109 | } 110 | 111 | func (q *Query) Begin(opts ...*sql.TxOptions) *QueryTx { 112 | tx := q.db.Begin(opts...) 113 | return &QueryTx{Query: q.clone(tx), Error: tx.Error} 114 | } 115 | 116 | type QueryTx struct { 117 | *Query 118 | Error error 119 | } 120 | 121 | func (q *QueryTx) Commit() error { 122 | return q.db.Commit().Error 123 | } 124 | 125 | func (q *QueryTx) Rollback() error { 126 | return q.db.Rollback().Error 127 | } 128 | 129 | func (q *QueryTx) SavePoint(name string) error { 130 | return q.db.SavePoint(name).Error 131 | } 132 | 133 | func (q *QueryTx) RollbackTo(name string) error { 134 | return q.db.RollbackTo(name).Error 135 | } 136 | -------------------------------------------------------------------------------- /tests/.expect/dal_2/query/gen_test.go: -------------------------------------------------------------------------------- 1 | // Code generated by gorm.io/gen. DO NOT EDIT. 2 | // Code generated by gorm.io/gen. DO NOT EDIT. 3 | // Code generated by gorm.io/gen. DO NOT EDIT. 
4 | 5 | package query 6 | 7 | import ( 8 | "context" 9 | "fmt" 10 | "reflect" 11 | "sync" 12 | "testing" 13 | 14 | "gorm.io/driver/sqlite" 15 | "gorm.io/gorm" 16 | ) 17 | 18 | type Input struct { 19 | Args []interface{} 20 | } 21 | 22 | type Expectation struct { 23 | Ret []interface{} 24 | } 25 | 26 | type TestCase struct { 27 | Input 28 | Expectation 29 | } 30 | 31 | const _gen_test_db_name = "gen_test.db" 32 | 33 | var _gen_test_db *gorm.DB 34 | var _gen_test_once sync.Once 35 | 36 | func init() { 37 | InitializeDB() 38 | _gen_test_db.AutoMigrate(&_another{}) 39 | } 40 | 41 | func InitializeDB() { 42 | _gen_test_once.Do(func() { 43 | var err error 44 | _gen_test_db, err = gorm.Open(sqlite.Open(_gen_test_db_name), &gorm.Config{}) 45 | if err != nil { 46 | panic(fmt.Errorf("open sqlite %q fail: %w", _gen_test_db_name, err)) 47 | } 48 | }) 49 | } 50 | 51 | func assert(t *testing.T, methodName string, res, exp interface{}) { 52 | if !reflect.DeepEqual(res, exp) { 53 | t.Errorf("%v() gotResult = %v, want %v", methodName, res, exp) 54 | } 55 | } 56 | 57 | type _another struct { 58 | ID uint64 `gorm:"primaryKey"` 59 | } 60 | 61 | func (*_another) TableName() string { return "another_for_unit_test" } 62 | 63 | func Test_Available(t *testing.T) { 64 | if !Use(_gen_test_db).Available() { 65 | t.Errorf("query.Available() == false") 66 | } 67 | } 68 | 69 | func Test_WithContext(t *testing.T) { 70 | query := Use(_gen_test_db) 71 | if !query.Available() { 72 | t.Errorf("query Use(_gen_test_db) fail: query.Available() == false") 73 | } 74 | 75 | type Content string 76 | var key, value Content = "gen_tag", "unit_test" 77 | qCtx := query.WithContext(context.WithValue(context.Background(), key, value)) 78 | 79 | for _, ctx := range []context.Context{ 80 | qCtx.Bank.UnderlyingDB().Statement.Context, 81 | qCtx.CreditCard.UnderlyingDB().Statement.Context, 82 | qCtx.Customer.UnderlyingDB().Statement.Context, 83 | qCtx.Person.UnderlyingDB().Statement.Context, 84 | 
qCtx.User.UnderlyingDB().Statement.Context, 85 | } { 86 | if v := ctx.Value(key); v != value { 87 | t.Errorf("get value from context fail, expect %q, got %q", value, v) 88 | } 89 | } 90 | } 91 | 92 | func Test_Transaction(t *testing.T) { 93 | query := Use(_gen_test_db) 94 | if !query.Available() { 95 | t.Errorf("query Use(_gen_test_db) fail: query.Available() == false") 96 | } 97 | 98 | err := query.Transaction(func(tx *Query) error { return nil }) 99 | if err != nil { 100 | t.Errorf("query.Transaction execute fail: %s", err) 101 | } 102 | 103 | tx := query.Begin() 104 | 105 | err = tx.SavePoint("point") 106 | if err != nil { 107 | t.Errorf("query tx SavePoint fail: %s", err) 108 | } 109 | err = tx.RollbackTo("point") 110 | if err != nil { 111 | t.Errorf("query tx RollbackTo fail: %s", err) 112 | } 113 | err = tx.Commit() 114 | if err != nil { 115 | t.Errorf("query tx Commit fail: %s", err) 116 | } 117 | 118 | err = query.Begin().Rollback() 119 | if err != nil { 120 | t.Errorf("query tx Rollback fail: %s", err) 121 | } 122 | } 123 | -------------------------------------------------------------------------------- /tests/.expect/dal_2/query/users.gen_test.go: -------------------------------------------------------------------------------- 1 | // Code generated by gorm.io/gen. DO NOT EDIT. 2 | // Code generated by gorm.io/gen. DO NOT EDIT. 3 | // Code generated by gorm.io/gen. DO NOT EDIT. 
4 | 5 | package query 6 | 7 | import ( 8 | "context" 9 | "fmt" 10 | "testing" 11 | 12 | "gorm.io/gen" 13 | "gorm.io/gen/field" 14 | "gorm.io/gen/tests/.gen/dal_2/model" 15 | "gorm.io/gorm/clause" 16 | ) 17 | 18 | func init() { 19 | InitializeDB() 20 | err := _gen_test_db.AutoMigrate(&model.User{}) 21 | if err != nil { 22 | fmt.Printf("Error: AutoMigrate(&model.User{}) fail: %s", err) 23 | } 24 | } 25 | 26 | func Test_userQuery(t *testing.T) { 27 | user := newUser(_gen_test_db) 28 | user = *user.As(user.TableName()) 29 | _do := user.WithContext(context.Background()).Debug() 30 | 31 | primaryKey := field.NewString(user.TableName(), clause.PrimaryKey) 32 | _, err := _do.Unscoped().Where(primaryKey.IsNotNull()).Delete() 33 | if err != nil { 34 | t.Error("clean table fail:", err) 35 | return 36 | } 37 | 38 | _, ok := user.GetFieldByName("") 39 | if ok { 40 | t.Error("GetFieldByName(\"\") from user success") 41 | } 42 | 43 | err = _do.Create(&model.User{}) 44 | if err != nil { 45 | t.Error("create item in table fail:", err) 46 | } 47 | 48 | err = _do.Save(&model.User{}) 49 | if err != nil { 50 | t.Error("create item in table fail:", err) 51 | } 52 | 53 | err = _do.CreateInBatches([]*model.User{{}, {}}, 10) 54 | if err != nil { 55 | t.Error("create item in table fail:", err) 56 | } 57 | 58 | _, err = _do.Select(user.ALL).Take() 59 | if err != nil { 60 | t.Error("Take() on table fail:", err) 61 | } 62 | 63 | _, err = _do.First() 64 | if err != nil { 65 | t.Error("First() on table fail:", err) 66 | } 67 | 68 | _, err = _do.Last() 69 | if err != nil { 70 | t.Error("First() on table fail:", err) 71 | } 72 | 73 | _, err = _do.Where(primaryKey.IsNotNull()).FindInBatch(10, func(tx gen.Dao, batch int) error { return nil }) 74 | if err != nil { 75 | t.Error("FindInBatch() on table fail:", err) 76 | } 77 | 78 | err = _do.Where(primaryKey.IsNotNull()).FindInBatches(&[]*model.User{}, 10, func(tx gen.Dao, batch int) error { return nil }) 79 | if err != nil { 80 | 
t.Error("FindInBatches() on table fail:", err) 81 | } 82 | 83 | _, err = _do.Select(user.ALL).Where(primaryKey.IsNotNull()).Order(primaryKey.Desc()).Find() 84 | if err != nil { 85 | t.Error("Find() on table fail:", err) 86 | } 87 | 88 | _, err = _do.Distinct(primaryKey).Take() 89 | if err != nil { 90 | t.Error("select Distinct() on table fail:", err) 91 | } 92 | 93 | _, err = _do.Select(user.ALL).Omit(primaryKey).Take() 94 | if err != nil { 95 | t.Error("Omit() on table fail:", err) 96 | } 97 | 98 | _, err = _do.Group(primaryKey).Find() 99 | if err != nil { 100 | t.Error("Group() on table fail:", err) 101 | } 102 | 103 | _, err = _do.Scopes(func(dao gen.Dao) gen.Dao { return dao.Where(primaryKey.IsNotNull()) }).Find() 104 | if err != nil { 105 | t.Error("Scopes() on table fail:", err) 106 | } 107 | 108 | _, _, err = _do.FindByPage(0, 1) 109 | if err != nil { 110 | t.Error("FindByPage() on table fail:", err) 111 | } 112 | 113 | _, err = _do.ScanByPage(&model.User{}, 0, 1) 114 | if err != nil { 115 | t.Error("ScanByPage() on table fail:", err) 116 | } 117 | 118 | _, err = _do.Attrs(primaryKey).Assign(primaryKey).FirstOrInit() 119 | if err != nil { 120 | t.Error("FirstOrInit() on table fail:", err) 121 | } 122 | 123 | _, err = _do.Attrs(primaryKey).Assign(primaryKey).FirstOrCreate() 124 | if err != nil { 125 | t.Error("FirstOrCreate() on table fail:", err) 126 | } 127 | 128 | var _a _another 129 | var _aPK = field.NewString(_a.TableName(), "id") 130 | 131 | err = _do.Join(&_a, primaryKey.EqCol(_aPK)).Scan(map[string]interface{}{}) 132 | if err != nil { 133 | t.Error("Join() on table fail:", err) 134 | } 135 | 136 | err = _do.LeftJoin(&_a, primaryKey.EqCol(_aPK)).Scan(map[string]interface{}{}) 137 | if err != nil { 138 | t.Error("LeftJoin() on table fail:", err) 139 | } 140 | 141 | _, err = _do.Not().Or().Clauses().Take() 142 | if err != nil { 143 | t.Error("Not/Or/Clauses on table fail:", err) 144 | } 145 | } 146 | 
--------------------------------------------------------------------------------
/tests/.expect/dal_3/model/banks.gen.go:
--------------------------------------------------------------------------------
// Code generated by gorm.io/gen. DO NOT EDIT.
// Code generated by gorm.io/gen. DO NOT EDIT.
// Code generated by gorm.io/gen. DO NOT EDIT.

package model

// TableNameBank is the table name returned by (*Bank).TableName.
const TableNameBank = "banks"

// Bank mapped from table <banks>
//
// Every field carries a json:"-" tag, so none of them is emitted by
// encoding/json.
type Bank struct {
	ID      int64   `gorm:"column:id;primaryKey;autoIncrement:true" json:"-"`
	Name    *string `gorm:"column:name" json:"-"`
	Address *string `gorm:"column:address" json:"-"`
	Scale   *int64  `gorm:"column:scale" json:"-"`
}

// TableName Bank's table name
func (*Bank) TableName() string {
	return TableNameBank
}

--------------------------------------------------------------------------------
/tests/.expect/dal_3/model/credit_cards.gen.go:
--------------------------------------------------------------------------------
// Code generated by gorm.io/gen. DO NOT EDIT.
// Code generated by gorm.io/gen. DO NOT EDIT.
// Code generated by gorm.io/gen. DO NOT EDIT.
package model

import (
	"time"

	"gorm.io/gorm"
)

// TableNameCreditCard is the table name returned by (*CreditCard).TableName.
const TableNameCreditCard = "credit_cards"

// CreditCard mapped from table <credit_cards>
//
// Every field carries a json:"-" tag, so none of them is emitted by
// encoding/json.
type CreditCard struct {
	ID            int64          `gorm:"column:id;primaryKey;autoIncrement:true" json:"-"`
	CreatedAt     *time.Time     `gorm:"column:created_at" json:"-"`
	UpdatedAt     *time.Time     `gorm:"column:updated_at" json:"-"`
	DeletedAt     gorm.DeletedAt `gorm:"column:deleted_at;index:idx_credit_cards_deleted_at,priority:1" json:"-"`
	Number        *string        `gorm:"column:number" json:"-"`
	CustomerRefer *int64         `gorm:"column:customer_refer" json:"-"`
	BankID        *int64         `gorm:"column:bank_id" json:"-"`
}

// TableName CreditCard's table name
func (*CreditCard) TableName() string {
	return TableNameCreditCard
}

--------------------------------------------------------------------------------
/tests/.expect/dal_3/model/customers.gen.go:
--------------------------------------------------------------------------------
// Code generated by gorm.io/gen. DO NOT EDIT.
// Code generated by gorm.io/gen. DO NOT EDIT.
// Code generated by gorm.io/gen. DO NOT EDIT.
package model

import (
	"time"

	"gorm.io/gorm"
)

// TableNameCustomer is the table name returned by (*Customer).TableName.
const TableNameCustomer = "customers"

// Customer mapped from table <customers>
//
// Every field carries a json:"-" tag, so none of them is emitted by
// encoding/json.
type Customer struct {
	ID        int64          `gorm:"column:id;primaryKey;autoIncrement:true" json:"-"`
	CreatedAt *time.Time     `gorm:"column:created_at" json:"-"`
	UpdatedAt *time.Time     `gorm:"column:updated_at" json:"-"`
	DeletedAt gorm.DeletedAt `gorm:"column:deleted_at;index:idx_customers_deleted_at,priority:1" json:"-"`
	BankID    *int64         `gorm:"column:bank_id" json:"-"`
}

// TableName Customer's table name
func (*Customer) TableName() string {
	return TableNameCustomer
}

--------------------------------------------------------------------------------
/tests/.expect/dal_3/model/people.gen.go:
--------------------------------------------------------------------------------
// Code generated by gorm.io/gen. DO NOT EDIT.
// Code generated by gorm.io/gen. DO NOT EDIT.
// Code generated by gorm.io/gen. DO NOT EDIT.
package model

import (
	"time"

	"gorm.io/gorm"
)

// TableNamePerson is the table name returned by (*Person).TableName.
const TableNamePerson = "people"

// Person mapped from table <people>
//
// Every field carries a json:"-" tag, so none of them is emitted by
// encoding/json. Column names are taken verbatim from the gorm tags and are
// not all snake_case (for example "First", "xmlHTTPRequest", "jStr").
type Person struct {
	ID             int64          `gorm:"column:id;primaryKey;autoIncrement:true" json:"-"`
	Name           *string        `gorm:"column:name" json:"-"`
	Alias_         *string        `gorm:"column:alias" json:"-"`
	Age            *int32         `gorm:"column:age" json:"-"`
	Flag           *bool          `gorm:"column:flag" json:"-"`
	AnotherFlag    *int32         `gorm:"column:another_flag" json:"-"`
	Commit         *string        `gorm:"column:commit" json:"-"`
	First          *bool          `gorm:"column:First" json:"-"`
	Bit            *[]uint8       `gorm:"column:bit" json:"-"`
	Small          *int32         `gorm:"column:small" json:"-"`
	DeletedAt      gorm.DeletedAt `gorm:"column:deleted_at" json:"-"`
	Score          *float64       `gorm:"column:score" json:"-"`
	Number         *int32         `gorm:"column:number" json:"-"`
	Birth          *time.Time     `gorm:"column:birth;default:CURRENT_TIMESTAMP" json:"-"`
	XMLHTTPRequest *string        `gorm:"column:xmlHTTPRequest;default:' '" json:"-"`
	JStr           *string        `gorm:"column:jStr" json:"-"`
	Geo            *string        `gorm:"column:geo" json:"-"`
	Mint           *int32         `gorm:"column:mint" json:"-"`
	Blank          *string        `gorm:"column:blank;default:' '" json:"-"`
	Remark         *string        `gorm:"column:remark" json:"-"`
	LongRemark     *string        `gorm:"column:long_remark" json:"-"`
}

// TableName Person's table name
func (*Person) TableName() string {
	return TableNamePerson
}

--------------------------------------------------------------------------------
/tests/.expect/dal_3/model/users.gen.go:
--------------------------------------------------------------------------------
// Code generated by gorm.io/gen. DO NOT EDIT.
// Code generated by gorm.io/gen. DO NOT EDIT.
// Code generated by gorm.io/gen. DO NOT EDIT.
package model

import (
	"time"
)

// TableNameUser is the table name returned by (*User).TableName.
const TableNameUser = "users"

// User mapped from table <users>
//
// Every field carries a json:"-" tag, so none of them is emitted by
// encoding/json.
type User struct {
	ID           int64      `gorm:"column:id;primaryKey;autoIncrement:true" json:"-"`
	CreatedAt    *time.Time `gorm:"column:created_at" json:"-"`
	Name         *string    `gorm:"column:name;index:idx_name,priority:1;index:idx_name_company_id,priority:1" json:"-"` // oneline
	Address      *string    `gorm:"column:address" json:"-"`
	RegisterTime *time.Time `gorm:"column:register_time" json:"-"`
	/*
		multiline
		line1
		line2
	*/
	Alive      *bool   `gorm:"column:alive" json:"-"`
	CompanyID  *int64  `gorm:"column:company_id;index:idx_name_company_id,priority:2;default:666" json:"-"`
	PrivateURL *string `gorm:"column:private_url;default:https://a.b.c" json:"-"`
}

// TableName User's table name
func (*User) TableName() string {
	return TableNameUser
}

--------------------------------------------------------------------------------
/tests/.expect/dal_3/query/banks.gen_test.go:
--------------------------------------------------------------------------------
// Code generated by gorm.io/gen. DO NOT EDIT.
// Code generated by gorm.io/gen. DO NOT EDIT.
// Code generated by gorm.io/gen. DO NOT EDIT.
4 | 5 | package query 6 | 7 | import ( 8 | "context" 9 | "fmt" 10 | "testing" 11 | 12 | "gorm.io/gen" 13 | "gorm.io/gen/field" 14 | "gorm.io/gen/tests/.gen/dal_3/model" 15 | "gorm.io/gorm/clause" 16 | ) 17 | 18 | func init() { 19 | InitializeDB() 20 | err := _gen_test_db.AutoMigrate(&model.Bank{}) 21 | if err != nil { 22 | fmt.Printf("Error: AutoMigrate(&model.Bank{}) fail: %s", err) 23 | } 24 | } 25 | 26 | func Test_bankQuery(t *testing.T) { 27 | bank := newBank(_gen_test_db) 28 | bank = *bank.As(bank.TableName()) 29 | _do := bank.WithContext(context.Background()).Debug() 30 | 31 | primaryKey := field.NewString(bank.TableName(), clause.PrimaryKey) 32 | _, err := _do.Unscoped().Where(primaryKey.IsNotNull()).Delete() 33 | if err != nil { 34 | t.Error("clean table fail:", err) 35 | return 36 | } 37 | 38 | _, ok := bank.GetFieldByName("") 39 | if ok { 40 | t.Error("GetFieldByName(\"\") from bank success") 41 | } 42 | 43 | err = _do.Create(&model.Bank{}) 44 | if err != nil { 45 | t.Error("create item in table fail:", err) 46 | } 47 | 48 | err = _do.Save(&model.Bank{}) 49 | if err != nil { 50 | t.Error("create item in table fail:", err) 51 | } 52 | 53 | err = _do.CreateInBatches([]*model.Bank{{}, {}}, 10) 54 | if err != nil { 55 | t.Error("create item in table fail:", err) 56 | } 57 | 58 | _, err = _do.Select(bank.ALL).Take() 59 | if err != nil { 60 | t.Error("Take() on table fail:", err) 61 | } 62 | 63 | _, err = _do.First() 64 | if err != nil { 65 | t.Error("First() on table fail:", err) 66 | } 67 | 68 | _, err = _do.Last() 69 | if err != nil { 70 | t.Error("First() on table fail:", err) 71 | } 72 | 73 | _, err = _do.Where(primaryKey.IsNotNull()).FindInBatch(10, func(tx gen.Dao, batch int) error { return nil }) 74 | if err != nil { 75 | t.Error("FindInBatch() on table fail:", err) 76 | } 77 | 78 | err = _do.Where(primaryKey.IsNotNull()).FindInBatches(&[]*model.Bank{}, 10, func(tx gen.Dao, batch int) error { return nil }) 79 | if err != nil { 80 | 
t.Error("FindInBatches() on table fail:", err) 81 | } 82 | 83 | _, err = _do.Select(bank.ALL).Where(primaryKey.IsNotNull()).Order(primaryKey.Desc()).Find() 84 | if err != nil { 85 | t.Error("Find() on table fail:", err) 86 | } 87 | 88 | _, err = _do.Distinct(primaryKey).Take() 89 | if err != nil { 90 | t.Error("select Distinct() on table fail:", err) 91 | } 92 | 93 | _, err = _do.Select(bank.ALL).Omit(primaryKey).Take() 94 | if err != nil { 95 | t.Error("Omit() on table fail:", err) 96 | } 97 | 98 | _, err = _do.Group(primaryKey).Find() 99 | if err != nil { 100 | t.Error("Group() on table fail:", err) 101 | } 102 | 103 | _, err = _do.Scopes(func(dao gen.Dao) gen.Dao { return dao.Where(primaryKey.IsNotNull()) }).Find() 104 | if err != nil { 105 | t.Error("Scopes() on table fail:", err) 106 | } 107 | 108 | _, _, err = _do.FindByPage(0, 1) 109 | if err != nil { 110 | t.Error("FindByPage() on table fail:", err) 111 | } 112 | 113 | _, err = _do.ScanByPage(&model.Bank{}, 0, 1) 114 | if err != nil { 115 | t.Error("ScanByPage() on table fail:", err) 116 | } 117 | 118 | _, err = _do.Attrs(primaryKey).Assign(primaryKey).FirstOrInit() 119 | if err != nil { 120 | t.Error("FirstOrInit() on table fail:", err) 121 | } 122 | 123 | _, err = _do.Attrs(primaryKey).Assign(primaryKey).FirstOrCreate() 124 | if err != nil { 125 | t.Error("FirstOrCreate() on table fail:", err) 126 | } 127 | 128 | var _a _another 129 | var _aPK = field.NewString(_a.TableName(), "id") 130 | 131 | err = _do.Join(&_a, primaryKey.EqCol(_aPK)).Scan(map[string]interface{}{}) 132 | if err != nil { 133 | t.Error("Join() on table fail:", err) 134 | } 135 | 136 | err = _do.LeftJoin(&_a, primaryKey.EqCol(_aPK)).Scan(map[string]interface{}{}) 137 | if err != nil { 138 | t.Error("LeftJoin() on table fail:", err) 139 | } 140 | 141 | _, err = _do.Not().Or().Clauses().Take() 142 | if err != nil { 143 | t.Error("Not/Or/Clauses on table fail:", err) 144 | } 145 | } 146 | 
--------------------------------------------------------------------------------
/tests/.expect/dal_3/query/gen.go:
--------------------------------------------------------------------------------
// Code generated by gorm.io/gen. DO NOT EDIT.
// Code generated by gorm.io/gen. DO NOT EDIT.
// Code generated by gorm.io/gen. DO NOT EDIT.

package query

import (
	"context"
	"database/sql"

	"gorm.io/gorm"

	"gorm.io/gen"

	"gorm.io/plugin/dbresolver"
)

// Package-level default query object and per-table shortcuts. The shortcuts
// are nil until SetDefault points them at the fields of Q.
var (
	Q          = new(Query)
	Bank       *bank
	CreditCard *creditCard
	Customer   *customer
	Person     *person
	User       *user
)

// SetDefault overwrites *Q with Use(db, opts...) and points the package-level
// table shortcuts at the corresponding fields of Q.
func SetDefault(db *gorm.DB, opts ...gen.DOOption) {
	*Q = *Use(db, opts...)
	Bank = &Q.Bank
	CreditCard = &Q.CreditCard
	Customer = &Q.Customer
	Person = &Q.Person
	User = &Q.User
}

// Use returns a new Query bound to db, with one query object per table built
// from db and opts.
func Use(db *gorm.DB, opts ...gen.DOOption) *Query {
	return &Query{
		db:         db,
		Bank:       newBank(db, opts...),
		CreditCard: newCreditCard(db, opts...),
		Customer:   newCustomer(db, opts...),
		Person:     newPerson(db, opts...),
		User:       newUser(db, opts...),
	}
}

// Query groups the per-table query objects together with the *gorm.DB they
// were created from.
type Query struct {
	db *gorm.DB

	Bank       bank
	CreditCard creditCard
	Customer   customer
	Person     person
	User       user
}

// Available reports whether q holds a non-nil *gorm.DB.
func (q *Query) Available() bool { return q.db != nil }

// clone returns a new Query bound to db whose table objects are the clone(db)
// of q's table objects.
func (q *Query) clone(db *gorm.DB) *Query {
	return &Query{
		db:         db,
		Bank:       q.Bank.clone(db),
		CreditCard: q.CreditCard.clone(db),
		Customer:   q.Customer.clone(db),
		Person:     q.Person.clone(db),
		User:       q.User.clone(db),
	}
}

// ReadDB returns a Query whose DB is q.db with the dbresolver.Read clause
// applied.
func (q *Query) ReadDB() *Query {
	return q.ReplaceDB(q.db.Clauses(dbresolver.Read))
}

// WriteDB returns a Query whose DB is q.db with the dbresolver.Write clause
// applied.
func (q *Query) WriteDB() *Query {
	return q.ReplaceDB(q.db.Clauses(dbresolver.Write))
}

// ReplaceDB returns a new Query bound to db whose table objects are the
// replaceDB(db) of q's table objects.
func (q *Query) ReplaceDB(db *gorm.DB) *Query {
	return &Query{
		db:         db,
		Bank:       q.Bank.replaceDB(db),
		CreditCard: q.CreditCard.replaceDB(db),
		Customer:   q.Customer.replaceDB(db),
		Person:     q.Person.replaceDB(db),
		User:       q.User.replaceDB(db),
	}
}

// queryCtx holds, per table, the DO interface returned by that table's
// WithContext.
type queryCtx struct {
	Bank       IBankDo
	CreditCard ICreditCardDo
	Customer   ICustomerDo
	Person     IPersonDo
	User       IUserDo
}

// WithContext calls WithContext(ctx) on every table object and returns the
// results bundled in a queryCtx.
func (q *Query) WithContext(ctx context.Context) *queryCtx {
	return &queryCtx{
		Bank:       q.Bank.WithContext(ctx),
		CreditCard: q.CreditCard.WithContext(ctx),
		Customer:   q.Customer.WithContext(ctx),
		Person:     q.Person.WithContext(ctx),
		User:       q.User.WithContext(ctx),
	}
}

// Transaction runs fc through q.db.Transaction, handing fc a Query cloned
// onto the transaction's *gorm.DB, and returns what q.db.Transaction returns.
func (q *Query) Transaction(fc func(tx *Query) error, opts ...*sql.TxOptions) error {
	return q.db.Transaction(func(tx *gorm.DB) error { return fc(q.clone(tx)) }, opts...)
}

// Begin calls q.db.Begin(opts...) and wraps the result in a QueryTx. The
// Error field carries the Error of the *gorm.DB returned by Begin, so callers
// should check it before using the transaction.
func (q *Query) Begin(opts ...*sql.TxOptions) *QueryTx {
	tx := q.db.Begin(opts...)
	return &QueryTx{Query: q.clone(tx), Error: tx.Error}
}

// QueryTx is a Query bound to the *gorm.DB returned by Begin, plus the error
// recorded when it was started.
type QueryTx struct {
	*Query
	Error error
}

// Commit returns the Error of q.db.Commit().
func (q *QueryTx) Commit() error {
	return q.db.Commit().Error
}

// Rollback returns the Error of q.db.Rollback().
func (q *QueryTx) Rollback() error {
	return q.db.Rollback().Error
}

// SavePoint returns the Error of q.db.SavePoint(name).
func (q *QueryTx) SavePoint(name string) error {
	return q.db.SavePoint(name).Error
}

// RollbackTo returns the Error of q.db.RollbackTo(name).
func (q *QueryTx) RollbackTo(name string) error {
	return q.db.RollbackTo(name).Error
}

--------------------------------------------------------------------------------
/tests/.expect/dal_3/query/gen_test.go:
--------------------------------------------------------------------------------
// Code generated by gorm.io/gen. DO NOT EDIT.
// Code generated by gorm.io/gen. DO NOT EDIT.
// Code generated by gorm.io/gen. DO NOT EDIT.
4 | 5 | package query 6 | 7 | import ( 8 | "context" 9 | "fmt" 10 | "reflect" 11 | "sync" 12 | "testing" 13 | 14 | "gorm.io/driver/sqlite" 15 | "gorm.io/gorm" 16 | ) 17 | 18 | type Input struct { 19 | Args []interface{} 20 | } 21 | 22 | type Expectation struct { 23 | Ret []interface{} 24 | } 25 | 26 | type TestCase struct { 27 | Input 28 | Expectation 29 | } 30 | 31 | const _gen_test_db_name = "gen_test.db" 32 | 33 | var _gen_test_db *gorm.DB 34 | var _gen_test_once sync.Once 35 | 36 | func init() { 37 | InitializeDB() 38 | _gen_test_db.AutoMigrate(&_another{}) 39 | } 40 | 41 | func InitializeDB() { 42 | _gen_test_once.Do(func() { 43 | var err error 44 | _gen_test_db, err = gorm.Open(sqlite.Open(_gen_test_db_name), &gorm.Config{}) 45 | if err != nil { 46 | panic(fmt.Errorf("open sqlite %q fail: %w", _gen_test_db_name, err)) 47 | } 48 | }) 49 | } 50 | 51 | func assert(t *testing.T, methodName string, res, exp interface{}) { 52 | if !reflect.DeepEqual(res, exp) { 53 | t.Errorf("%v() gotResult = %v, want %v", methodName, res, exp) 54 | } 55 | } 56 | 57 | type _another struct { 58 | ID uint64 `gorm:"primaryKey"` 59 | } 60 | 61 | func (*_another) TableName() string { return "another_for_unit_test" } 62 | 63 | func Test_Available(t *testing.T) { 64 | if !Use(_gen_test_db).Available() { 65 | t.Errorf("query.Available() == false") 66 | } 67 | } 68 | 69 | func Test_WithContext(t *testing.T) { 70 | query := Use(_gen_test_db) 71 | if !query.Available() { 72 | t.Errorf("query Use(_gen_test_db) fail: query.Available() == false") 73 | } 74 | 75 | type Content string 76 | var key, value Content = "gen_tag", "unit_test" 77 | qCtx := query.WithContext(context.WithValue(context.Background(), key, value)) 78 | 79 | for _, ctx := range []context.Context{ 80 | qCtx.Bank.UnderlyingDB().Statement.Context, 81 | qCtx.CreditCard.UnderlyingDB().Statement.Context, 82 | qCtx.Customer.UnderlyingDB().Statement.Context, 83 | qCtx.Person.UnderlyingDB().Statement.Context, 84 | 
qCtx.User.UnderlyingDB().Statement.Context, 85 | } { 86 | if v := ctx.Value(key); v != value { 87 | t.Errorf("get value from context fail, expect %q, got %q", value, v) 88 | } 89 | } 90 | } 91 | 92 | func Test_Transaction(t *testing.T) { 93 | query := Use(_gen_test_db) 94 | if !query.Available() { 95 | t.Errorf("query Use(_gen_test_db) fail: query.Available() == false") 96 | } 97 | 98 | err := query.Transaction(func(tx *Query) error { return nil }) 99 | if err != nil { 100 | t.Errorf("query.Transaction execute fail: %s", err) 101 | } 102 | 103 | tx := query.Begin() 104 | 105 | err = tx.SavePoint("point") 106 | if err != nil { 107 | t.Errorf("query tx SavePoint fail: %s", err) 108 | } 109 | err = tx.RollbackTo("point") 110 | if err != nil { 111 | t.Errorf("query tx RollbackTo fail: %s", err) 112 | } 113 | err = tx.Commit() 114 | if err != nil { 115 | t.Errorf("query tx Commit fail: %s", err) 116 | } 117 | 118 | err = query.Begin().Rollback() 119 | if err != nil { 120 | t.Errorf("query tx Rollback fail: %s", err) 121 | } 122 | } 123 | -------------------------------------------------------------------------------- /tests/.expect/dal_3/query/users.gen_test.go: -------------------------------------------------------------------------------- 1 | // Code generated by gorm.io/gen. DO NOT EDIT. 2 | // Code generated by gorm.io/gen. DO NOT EDIT. 3 | // Code generated by gorm.io/gen. DO NOT EDIT. 
4 | 5 | package query 6 | 7 | import ( 8 | "context" 9 | "fmt" 10 | "testing" 11 | 12 | "gorm.io/gen" 13 | "gorm.io/gen/field" 14 | "gorm.io/gen/tests/.gen/dal_3/model" 15 | "gorm.io/gorm/clause" 16 | ) 17 | 18 | func init() { 19 | InitializeDB() 20 | err := _gen_test_db.AutoMigrate(&model.User{}) 21 | if err != nil { 22 | fmt.Printf("Error: AutoMigrate(&model.User{}) fail: %s", err) 23 | } 24 | } 25 | 26 | func Test_userQuery(t *testing.T) { 27 | user := newUser(_gen_test_db) 28 | user = *user.As(user.TableName()) 29 | _do := user.WithContext(context.Background()).Debug() 30 | 31 | primaryKey := field.NewString(user.TableName(), clause.PrimaryKey) 32 | _, err := _do.Unscoped().Where(primaryKey.IsNotNull()).Delete() 33 | if err != nil { 34 | t.Error("clean table fail:", err) 35 | return 36 | } 37 | 38 | _, ok := user.GetFieldByName("") 39 | if ok { 40 | t.Error("GetFieldByName(\"\") from user success") 41 | } 42 | 43 | err = _do.Create(&model.User{}) 44 | if err != nil { 45 | t.Error("create item in table fail:", err) 46 | } 47 | 48 | err = _do.Save(&model.User{}) 49 | if err != nil { 50 | t.Error("create item in table fail:", err) 51 | } 52 | 53 | err = _do.CreateInBatches([]*model.User{{}, {}}, 10) 54 | if err != nil { 55 | t.Error("create item in table fail:", err) 56 | } 57 | 58 | _, err = _do.Select(user.ALL).Take() 59 | if err != nil { 60 | t.Error("Take() on table fail:", err) 61 | } 62 | 63 | _, err = _do.First() 64 | if err != nil { 65 | t.Error("First() on table fail:", err) 66 | } 67 | 68 | _, err = _do.Last() 69 | if err != nil { 70 | t.Error("First() on table fail:", err) 71 | } 72 | 73 | _, err = _do.Where(primaryKey.IsNotNull()).FindInBatch(10, func(tx gen.Dao, batch int) error { return nil }) 74 | if err != nil { 75 | t.Error("FindInBatch() on table fail:", err) 76 | } 77 | 78 | err = _do.Where(primaryKey.IsNotNull()).FindInBatches(&[]*model.User{}, 10, func(tx gen.Dao, batch int) error { return nil }) 79 | if err != nil { 80 | 
t.Error("FindInBatches() on table fail:", err) 81 | } 82 | 83 | _, err = _do.Select(user.ALL).Where(primaryKey.IsNotNull()).Order(primaryKey.Desc()).Find() 84 | if err != nil { 85 | t.Error("Find() on table fail:", err) 86 | } 87 | 88 | _, err = _do.Distinct(primaryKey).Take() 89 | if err != nil { 90 | t.Error("select Distinct() on table fail:", err) 91 | } 92 | 93 | _, err = _do.Select(user.ALL).Omit(primaryKey).Take() 94 | if err != nil { 95 | t.Error("Omit() on table fail:", err) 96 | } 97 | 98 | _, err = _do.Group(primaryKey).Find() 99 | if err != nil { 100 | t.Error("Group() on table fail:", err) 101 | } 102 | 103 | _, err = _do.Scopes(func(dao gen.Dao) gen.Dao { return dao.Where(primaryKey.IsNotNull()) }).Find() 104 | if err != nil { 105 | t.Error("Scopes() on table fail:", err) 106 | } 107 | 108 | _, _, err = _do.FindByPage(0, 1) 109 | if err != nil { 110 | t.Error("FindByPage() on table fail:", err) 111 | } 112 | 113 | _, err = _do.ScanByPage(&model.User{}, 0, 1) 114 | if err != nil { 115 | t.Error("ScanByPage() on table fail:", err) 116 | } 117 | 118 | _, err = _do.Attrs(primaryKey).Assign(primaryKey).FirstOrInit() 119 | if err != nil { 120 | t.Error("FirstOrInit() on table fail:", err) 121 | } 122 | 123 | _, err = _do.Attrs(primaryKey).Assign(primaryKey).FirstOrCreate() 124 | if err != nil { 125 | t.Error("FirstOrCreate() on table fail:", err) 126 | } 127 | 128 | var _a _another 129 | var _aPK = field.NewString(_a.TableName(), "id") 130 | 131 | err = _do.Join(&_a, primaryKey.EqCol(_aPK)).Scan(map[string]interface{}{}) 132 | if err != nil { 133 | t.Error("Join() on table fail:", err) 134 | } 135 | 136 | err = _do.LeftJoin(&_a, primaryKey.EqCol(_aPK)).Scan(map[string]interface{}{}) 137 | if err != nil { 138 | t.Error("LeftJoin() on table fail:", err) 139 | } 140 | 141 | _, err = _do.Not().Or().Clauses().Take() 142 | if err != nil { 143 | t.Error("Not/Or/Clauses on table fail:", err) 144 | } 145 | } 146 | 
--------------------------------------------------------------------------------
/tests/.expect/dal_4/model/banks.gen.go:
--------------------------------------------------------------------------------
// Code generated by gorm.io/gen. DO NOT EDIT.
// Code generated by gorm.io/gen. DO NOT EDIT.
// Code generated by gorm.io/gen. DO NOT EDIT.

package model

// TableNameBank is the table name returned by (*Bank).TableName.
const TableNameBank = "banks"

// Bank mapped from table <banks>
//
// Every field carries a json:"-" tag, so none of them is emitted by
// encoding/json.
type Bank struct {
	ID      int64   `gorm:"column:id;primaryKey;autoIncrement:true" json:"-"`
	Name    *string `gorm:"column:name" json:"-"`
	Address *string `gorm:"column:address" json:"-"`
	Scale   *int64  `gorm:"column:scale" json:"-"`
}

// TableName Bank's table name
func (*Bank) TableName() string {
	return TableNameBank
}

--------------------------------------------------------------------------------
/tests/.expect/dal_4/model/credit_cards.gen.go:
--------------------------------------------------------------------------------
// Code generated by gorm.io/gen. DO NOT EDIT.
// Code generated by gorm.io/gen. DO NOT EDIT.
// Code generated by gorm.io/gen. DO NOT EDIT.
package model

import (
	"time"

	"gorm.io/gorm"
)

// TableNameCreditCard is the table name returned by (*CreditCard).TableName.
const TableNameCreditCard = "credit_cards"

// CreditCard mapped from table <credit_cards>
//
// Every field carries a json:"-" tag, so none of them is emitted by
// encoding/json.
type CreditCard struct {
	ID            int64          `gorm:"column:id;primaryKey;autoIncrement:true" json:"-"`
	CreatedAt     *time.Time     `gorm:"column:created_at" json:"-"`
	UpdatedAt     *time.Time     `gorm:"column:updated_at" json:"-"`
	DeletedAt     gorm.DeletedAt `gorm:"column:deleted_at;index:idx_credit_cards_deleted_at,priority:1" json:"-"`
	Number        *string        `gorm:"column:number" json:"-"`
	CustomerRefer *int64         `gorm:"column:customer_refer" json:"-"`
	BankID        *int64         `gorm:"column:bank_id" json:"-"`
}

// TableName CreditCard's table name
func (*CreditCard) TableName() string {
	return TableNameCreditCard
}

--------------------------------------------------------------------------------
/tests/.expect/dal_4/model/customers.gen.go:
--------------------------------------------------------------------------------
// Code generated by gorm.io/gen. DO NOT EDIT.
// Code generated by gorm.io/gen. DO NOT EDIT.
// Code generated by gorm.io/gen. DO NOT EDIT.
package model

import (
	"time"

	"gorm.io/gorm"
)

// TableNameCustomer is the table name returned by (*Customer).TableName.
const TableNameCustomer = "customers"

// Customer mapped from table <customers>
//
// Every field carries a json:"-" tag, so none of them is emitted by
// encoding/json.
type Customer struct {
	ID        int64          `gorm:"column:id;primaryKey;autoIncrement:true" json:"-"`
	CreatedAt *time.Time     `gorm:"column:created_at" json:"-"`
	UpdatedAt *time.Time     `gorm:"column:updated_at" json:"-"`
	DeletedAt gorm.DeletedAt `gorm:"column:deleted_at;index:idx_customers_deleted_at,priority:1" json:"-"`
	BankID    *int64         `gorm:"column:bank_id" json:"-"`
}

// TableName Customer's table name
func (*Customer) TableName() string {
	return TableNameCustomer
}

--------------------------------------------------------------------------------
/tests/.expect/dal_4/model/people.gen.go:
--------------------------------------------------------------------------------
// Code generated by gorm.io/gen. DO NOT EDIT.
// Code generated by gorm.io/gen. DO NOT EDIT.
// Code generated by gorm.io/gen. DO NOT EDIT.
package model

import (
	"time"

	"gorm.io/gorm"
)

// TableNamePerson is the table name returned by (*Person).TableName.
const TableNamePerson = "people"

// Person mapped from table <people>
//
// Every field carries a json:"-" tag, so none of them is emitted by
// encoding/json. Column names are taken verbatim from the gorm tags and are
// not all snake_case (for example "First", "xmlHTTPRequest", "jStr").
type Person struct {
	ID             int64          `gorm:"column:id;primaryKey;autoIncrement:true" json:"-"`
	Name           *string        `gorm:"column:name" json:"-"`
	Alias_         *string        `gorm:"column:alias" json:"-"`
	Age            *int32         `gorm:"column:age" json:"-"`
	Flag           *bool          `gorm:"column:flag" json:"-"`
	AnotherFlag    *int32         `gorm:"column:another_flag" json:"-"`
	Commit         *string        `gorm:"column:commit" json:"-"`
	First          *bool          `gorm:"column:First" json:"-"`
	Bit            *[]uint8       `gorm:"column:bit" json:"-"`
	Small          *int32         `gorm:"column:small" json:"-"`
	DeletedAt      gorm.DeletedAt `gorm:"column:deleted_at" json:"-"`
	Score          *float64       `gorm:"column:score" json:"-"`
	Number         *int32         `gorm:"column:number" json:"-"`
	Birth          *time.Time     `gorm:"column:birth;default:CURRENT_TIMESTAMP" json:"-"`
	XMLHTTPRequest *string        `gorm:"column:xmlHTTPRequest;default:' '" json:"-"`
	JStr           *string        `gorm:"column:jStr" json:"-"`
	Geo            *string        `gorm:"column:geo" json:"-"`
	Mint           *int32         `gorm:"column:mint" json:"-"`
	Blank          *string        `gorm:"column:blank;default:' '" json:"-"`
	Remark         *string        `gorm:"column:remark" json:"-"`
	LongRemark     *string        `gorm:"column:long_remark" json:"-"`
}

// TableName Person's table name
func (*Person) TableName() string {
	return TableNamePerson
}

--------------------------------------------------------------------------------
/tests/.expect/dal_4/model/users.gen.go:
--------------------------------------------------------------------------------
// Code generated by gorm.io/gen. DO NOT EDIT.
// Code generated by gorm.io/gen. DO NOT EDIT.
// Code generated by gorm.io/gen. DO NOT EDIT.
package model

import (
	"time"
)

// TableNameUser is the table name returned by (*User).TableName.
const TableNameUser = "users"

// User mapped from table <users>
//
// Every field carries a json:"-" tag, so none of them is emitted by
// encoding/json. In this variant the Name and Alive gorm tags also carry a
// comment: entry that repeats the adjacent source comment.
type User struct {
	ID           int64      `gorm:"column:id;primaryKey;autoIncrement:true" json:"-"`
	CreatedAt    *time.Time `gorm:"column:created_at" json:"-"`
	Name         *string    `gorm:"column:name;index:idx_name,priority:1;index:idx_name_company_id,priority:1;comment:oneline" json:"-"` // oneline
	Address      *string    `gorm:"column:address" json:"-"`
	RegisterTime *time.Time `gorm:"column:register_time" json:"-"`
	/*
		multiline
		line1
		line2
	*/
	Alive      *bool   `gorm:"column:alive;comment:multiline\nline1\nline2" json:"-"`
	CompanyID  *int64  `gorm:"column:company_id;index:idx_name_company_id,priority:2;default:666" json:"-"`
	PrivateURL *string `gorm:"column:private_url;default:https://a.b.c" json:"-"`
}

// TableName User's table name
func (*User) TableName() string {
	return TableNameUser
}

--------------------------------------------------------------------------------
/tests/.expect/dal_4/query/gen.go:
--------------------------------------------------------------------------------
// Code generated by gorm.io/gen. DO NOT EDIT.
// Code generated by gorm.io/gen. DO NOT EDIT.
// Code generated by gorm.io/gen. DO NOT EDIT.

package query

import (
	"context"
	"database/sql"

	"gorm.io/gorm"

	"gorm.io/gen"

	"gorm.io/plugin/dbresolver"
)

// Package-level default query object and per-table shortcuts. The shortcuts
// are nil until SetDefault points them at the fields of Q.
var (
	Q          = new(Query)
	Bank       *bank
	CreditCard *creditCard
	Customer   *customer
	Person     *person
	User       *user
)

// SetDefault overwrites *Q with Use(db, opts...) and points the package-level
// table shortcuts at the corresponding fields of Q.
func SetDefault(db *gorm.DB, opts ...gen.DOOption) {
	*Q = *Use(db, opts...)
	Bank = &Q.Bank
	CreditCard = &Q.CreditCard
	Customer = &Q.Customer
	Person = &Q.Person
	User = &Q.User
}

// Use returns a new Query bound to db, with one query object per table built
// from db and opts.
func Use(db *gorm.DB, opts ...gen.DOOption) *Query {
	return &Query{
		db:         db,
		Bank:       newBank(db, opts...),
		CreditCard: newCreditCard(db, opts...),
		Customer:   newCustomer(db, opts...),
		Person:     newPerson(db, opts...),
		User:       newUser(db, opts...),
	}
}

// Query groups the per-table query objects together with the *gorm.DB they
// were created from.
type Query struct {
	db *gorm.DB

	Bank       bank
	CreditCard creditCard
	Customer   customer
	Person     person
	User       user
}

// Available reports whether q holds a non-nil *gorm.DB.
func (q *Query) Available() bool { return q.db != nil }

// clone returns a new Query bound to db whose table objects are the clone(db)
// of q's table objects.
func (q *Query) clone(db *gorm.DB) *Query {
	return &Query{
		db:         db,
		Bank:       q.Bank.clone(db),
		CreditCard: q.CreditCard.clone(db),
		Customer:   q.Customer.clone(db),
		Person:     q.Person.clone(db),
		User:       q.User.clone(db),
	}
}

// ReadDB returns a Query whose DB is q.db with the dbresolver.Read clause
// applied.
func (q *Query) ReadDB() *Query {
	return q.ReplaceDB(q.db.Clauses(dbresolver.Read))
}

// WriteDB returns a Query whose DB is q.db with the dbresolver.Write clause
// applied.
func (q *Query) WriteDB() *Query {
	return q.ReplaceDB(q.db.Clauses(dbresolver.Write))
}

// ReplaceDB returns a new Query bound to db whose table objects are the
// replaceDB(db) of q's table objects.
func (q *Query) ReplaceDB(db *gorm.DB) *Query {
	return &Query{
		db:         db,
		Bank:       q.Bank.replaceDB(db),
		CreditCard: q.CreditCard.replaceDB(db),
		Customer:   q.Customer.replaceDB(db),
		Person:     q.Person.replaceDB(db),
		User:       q.User.replaceDB(db),
	}
}

// queryCtx holds, per table, the DO interface returned by that table's
// WithContext.
type queryCtx struct {
	Bank       IBankDo
	CreditCard ICreditCardDo
	Customer   ICustomerDo
	Person     IPersonDo
	User       IUserDo
}

// WithContext calls WithContext(ctx) on every table object and returns the
// results bundled in a queryCtx.
func (q *Query) WithContext(ctx context.Context) *queryCtx {
	return &queryCtx{
		Bank:       q.Bank.WithContext(ctx),
		CreditCard: q.CreditCard.WithContext(ctx),
		Customer:   q.Customer.WithContext(ctx),
		Person:     q.Person.WithContext(ctx),
		User:       q.User.WithContext(ctx),
	}
}

// Transaction runs fc through q.db.Transaction, handing fc a Query cloned
// onto the transaction's *gorm.DB, and returns what q.db.Transaction returns.
func (q *Query) Transaction(fc func(tx *Query) error, opts ...*sql.TxOptions) error {
	return q.db.Transaction(func(tx *gorm.DB) error { return fc(q.clone(tx)) }, opts...)
}

// Begin calls q.db.Begin(opts...) and wraps the result in a QueryTx. The
// Error field carries the Error of the *gorm.DB returned by Begin, so callers
// should check it before using the transaction.
func (q *Query) Begin(opts ...*sql.TxOptions) *QueryTx {
	tx := q.db.Begin(opts...)
	return &QueryTx{Query: q.clone(tx), Error: tx.Error}
}

// QueryTx is a Query bound to the *gorm.DB returned by Begin, plus the error
// recorded when it was started.
type QueryTx struct {
	*Query
	Error error
}

// Commit returns the Error of q.db.Commit().
func (q *QueryTx) Commit() error {
	return q.db.Commit().Error
}

// Rollback returns the Error of q.db.Rollback().
func (q *QueryTx) Rollback() error {
	return q.db.Rollback().Error
}

// SavePoint returns the Error of q.db.SavePoint(name).
func (q *QueryTx) SavePoint(name string) error {
	return q.db.SavePoint(name).Error
}

// RollbackTo returns the Error of q.db.RollbackTo(name).
func (q *QueryTx) RollbackTo(name string) error {
	return q.db.RollbackTo(name).Error
}

--------------------------------------------------------------------------------
/tests/.expect/dal_4/query/gen_test.go:
--------------------------------------------------------------------------------
// Code generated by gorm.io/gen. DO NOT EDIT.
// Code generated by gorm.io/gen. DO NOT EDIT.
// Code generated by gorm.io/gen. DO NOT EDIT.
4 | 5 | package query 6 | 7 | import ( 8 | "context" 9 | "fmt" 10 | "reflect" 11 | "sync" 12 | "testing" 13 | 14 | "gorm.io/driver/sqlite" 15 | "gorm.io/gorm" 16 | ) 17 | 18 | type Input struct { 19 | Args []interface{} 20 | } 21 | 22 | type Expectation struct { 23 | Ret []interface{} 24 | } 25 | 26 | type TestCase struct { 27 | Input 28 | Expectation 29 | } 30 | 31 | const _gen_test_db_name = "gen_test.db" 32 | 33 | var _gen_test_db *gorm.DB 34 | var _gen_test_once sync.Once 35 | 36 | func init() { 37 | InitializeDB() 38 | _gen_test_db.AutoMigrate(&_another{}) 39 | } 40 | 41 | func InitializeDB() { 42 | _gen_test_once.Do(func() { 43 | var err error 44 | _gen_test_db, err = gorm.Open(sqlite.Open(_gen_test_db_name), &gorm.Config{}) 45 | if err != nil { 46 | panic(fmt.Errorf("open sqlite %q fail: %w", _gen_test_db_name, err)) 47 | } 48 | }) 49 | } 50 | 51 | func assert(t *testing.T, methodName string, res, exp interface{}) { 52 | if !reflect.DeepEqual(res, exp) { 53 | t.Errorf("%v() gotResult = %v, want %v", methodName, res, exp) 54 | } 55 | } 56 | 57 | type _another struct { 58 | ID uint64 `gorm:"primaryKey"` 59 | } 60 | 61 | func (*_another) TableName() string { return "another_for_unit_test" } 62 | 63 | func Test_Available(t *testing.T) { 64 | if !Use(_gen_test_db).Available() { 65 | t.Errorf("query.Available() == false") 66 | } 67 | } 68 | 69 | func Test_WithContext(t *testing.T) { 70 | query := Use(_gen_test_db) 71 | if !query.Available() { 72 | t.Errorf("query Use(_gen_test_db) fail: query.Available() == false") 73 | } 74 | 75 | type Content string 76 | var key, value Content = "gen_tag", "unit_test" 77 | qCtx := query.WithContext(context.WithValue(context.Background(), key, value)) 78 | 79 | for _, ctx := range []context.Context{ 80 | qCtx.Bank.UnderlyingDB().Statement.Context, 81 | qCtx.CreditCard.UnderlyingDB().Statement.Context, 82 | qCtx.Customer.UnderlyingDB().Statement.Context, 83 | qCtx.Person.UnderlyingDB().Statement.Context, 84 | 
qCtx.User.UnderlyingDB().Statement.Context, 85 | } { 86 | if v := ctx.Value(key); v != value { 87 | t.Errorf("get value from context fail, expect %q, got %q", value, v) 88 | } 89 | } 90 | } 91 | 92 | func Test_Transaction(t *testing.T) { 93 | query := Use(_gen_test_db) 94 | if !query.Available() { 95 | t.Errorf("query Use(_gen_test_db) fail: query.Available() == false") 96 | } 97 | 98 | err := query.Transaction(func(tx *Query) error { return nil }) 99 | if err != nil { 100 | t.Errorf("query.Transaction execute fail: %s", err) 101 | } 102 | 103 | tx := query.Begin() 104 | 105 | err = tx.SavePoint("point") 106 | if err != nil { 107 | t.Errorf("query tx SavePoint fail: %s", err) 108 | } 109 | err = tx.RollbackTo("point") 110 | if err != nil { 111 | t.Errorf("query tx RollbackTo fail: %s", err) 112 | } 113 | err = tx.Commit() 114 | if err != nil { 115 | t.Errorf("query tx Commit fail: %s", err) 116 | } 117 | 118 | err = query.Begin().Rollback() 119 | if err != nil { 120 | t.Errorf("query tx Rollback fail: %s", err) 121 | } 122 | } 123 | -------------------------------------------------------------------------------- /tests/.expect/dal_5/model/users.gen.go: -------------------------------------------------------------------------------- 1 | // Code generated by gorm.io/gen. DO NOT EDIT. 2 | // Code generated by gorm.io/gen. DO NOT EDIT. 3 | // Code generated by gorm.io/gen. DO NOT EDIT. 
4 | 5 | package model 6 | 7 | import ( 8 | "time" 9 | ) 10 | 11 | const TableNameUser = "users" 12 | 13 | // User mapped from table 14 | type User struct { 15 | ID int64 `gorm:"column:id;primaryKey;autoIncrement:true" json:"-"` 16 | CreatedAt *time.Time `gorm:"column:created_at" json:"-"` 17 | Name *string `gorm:"column:name;index:idx_name,priority:1;index:idx_name_company_id,priority:1;comment:oneline" json:"-"` // oneline 18 | Address *string `gorm:"column:address" json:"-"` 19 | RegisterTime *time.Time `gorm:"column:register_time" json:"-"` 20 | /* 21 | multiline 22 | line1 23 | line2 24 | */ 25 | Alive *bool `gorm:"column:alive;comment:multiline\nline1\nline2" json:"-"` 26 | CompanyID *int64 `gorm:"column:company_id;index:idx_name_company_id,priority:2;default:666" json:"-"` 27 | PrivateURL *string `gorm:"column:private_url;default:https://a.b.c" json:"-"` 28 | } 29 | 30 | func (m *User) IsEmpty() bool { 31 | if m == nil { 32 | return true 33 | } 34 | return m.ID == 0 35 | } 36 | 37 | func (m *User) GetID() int { 38 | return int(m.ID) 39 | } 40 | 41 | // TableName User's table name 42 | func (*User) TableName() string { 43 | return TableNameUser 44 | } 45 | -------------------------------------------------------------------------------- /tests/.expect/dal_5/query/gen.go: -------------------------------------------------------------------------------- 1 | // Code generated by gorm.io/gen. DO NOT EDIT. 2 | // Code generated by gorm.io/gen. DO NOT EDIT. 3 | // Code generated by gorm.io/gen. DO NOT EDIT. 4 | 5 | package query 6 | 7 | import ( 8 | "context" 9 | "database/sql" 10 | 11 | "gorm.io/gorm" 12 | 13 | "gorm.io/gen" 14 | 15 | "gorm.io/plugin/dbresolver" 16 | ) 17 | 18 | var ( 19 | Q = new(Query) 20 | User *user 21 | ) 22 | 23 | func SetDefault(db *gorm.DB, opts ...gen.DOOption) { 24 | *Q = *Use(db, opts...) 
25 | User = &Q.User 26 | } 27 | 28 | func Use(db *gorm.DB, opts ...gen.DOOption) *Query { 29 | return &Query{ 30 | db: db, 31 | User: newUser(db, opts...), 32 | } 33 | } 34 | 35 | type Query struct { 36 | db *gorm.DB 37 | 38 | User user 39 | } 40 | 41 | func (q *Query) Available() bool { return q.db != nil } 42 | 43 | func (q *Query) clone(db *gorm.DB) *Query { 44 | return &Query{ 45 | db: db, 46 | User: q.User.clone(db), 47 | } 48 | } 49 | 50 | func (q *Query) ReadDB() *Query { 51 | return q.ReplaceDB(q.db.Clauses(dbresolver.Read)) 52 | } 53 | 54 | func (q *Query) WriteDB() *Query { 55 | return q.ReplaceDB(q.db.Clauses(dbresolver.Write)) 56 | } 57 | 58 | func (q *Query) ReplaceDB(db *gorm.DB) *Query { 59 | return &Query{ 60 | db: db, 61 | User: q.User.replaceDB(db), 62 | } 63 | } 64 | 65 | type queryCtx struct { 66 | User IUserDo 67 | } 68 | 69 | func (q *Query) WithContext(ctx context.Context) *queryCtx { 70 | return &queryCtx{ 71 | User: q.User.WithContext(ctx), 72 | } 73 | } 74 | 75 | func (q *Query) Transaction(fc func(tx *Query) error, opts ...*sql.TxOptions) error { 76 | return q.db.Transaction(func(tx *gorm.DB) error { return fc(q.clone(tx)) }, opts...) 77 | } 78 | 79 | func (q *Query) Begin(opts ...*sql.TxOptions) *QueryTx { 80 | tx := q.db.Begin(opts...) 
81 | return &QueryTx{Query: q.clone(tx), Error: tx.Error} 82 | } 83 | 84 | type QueryTx struct { 85 | *Query 86 | Error error 87 | } 88 | 89 | func (q *QueryTx) Commit() error { 90 | return q.db.Commit().Error 91 | } 92 | 93 | func (q *QueryTx) Rollback() error { 94 | return q.db.Rollback().Error 95 | } 96 | 97 | func (q *QueryTx) SavePoint(name string) error { 98 | return q.db.SavePoint(name).Error 99 | } 100 | 101 | func (q *QueryTx) RollbackTo(name string) error { 102 | return q.db.RollbackTo(name).Error 103 | } 104 | -------------------------------------------------------------------------------- /tests/.expect/dal_5/query/gen_test.go: -------------------------------------------------------------------------------- 1 | // Code generated by gorm.io/gen. DO NOT EDIT. 2 | // Code generated by gorm.io/gen. DO NOT EDIT. 3 | // Code generated by gorm.io/gen. DO NOT EDIT. 4 | 5 | package query 6 | 7 | import ( 8 | "context" 9 | "fmt" 10 | "reflect" 11 | "sync" 12 | "testing" 13 | 14 | "gorm.io/driver/sqlite" 15 | "gorm.io/gorm" 16 | ) 17 | 18 | type Input struct { 19 | Args []interface{} 20 | } 21 | 22 | type Expectation struct { 23 | Ret []interface{} 24 | } 25 | 26 | type TestCase struct { 27 | Input 28 | Expectation 29 | } 30 | 31 | const _gen_test_db_name = "gen_test.db" 32 | 33 | var _gen_test_db *gorm.DB 34 | var _gen_test_once sync.Once 35 | 36 | func init() { 37 | InitializeDB() 38 | _gen_test_db.AutoMigrate(&_another{}) 39 | } 40 | 41 | func InitializeDB() { 42 | _gen_test_once.Do(func() { 43 | var err error 44 | _gen_test_db, err = gorm.Open(sqlite.Open(_gen_test_db_name), &gorm.Config{}) 45 | if err != nil { 46 | panic(fmt.Errorf("open sqlite %q fail: %w", _gen_test_db_name, err)) 47 | } 48 | }) 49 | } 50 | 51 | func assert(t *testing.T, methodName string, res, exp interface{}) { 52 | if !reflect.DeepEqual(res, exp) { 53 | t.Errorf("%v() gotResult = %v, want %v", methodName, res, exp) 54 | } 55 | } 56 | 57 | type _another struct { 58 | ID uint64 
`gorm:"primaryKey"` 59 | } 60 | 61 | func (*_another) TableName() string { return "another_for_unit_test" } 62 | 63 | func Test_Available(t *testing.T) { 64 | if !Use(_gen_test_db).Available() { 65 | t.Errorf("query.Available() == false") 66 | } 67 | } 68 | 69 | func Test_WithContext(t *testing.T) { 70 | query := Use(_gen_test_db) 71 | if !query.Available() { 72 | t.Errorf("query Use(_gen_test_db) fail: query.Available() == false") 73 | } 74 | 75 | type Content string 76 | var key, value Content = "gen_tag", "unit_test" 77 | qCtx := query.WithContext(context.WithValue(context.Background(), key, value)) 78 | 79 | for _, ctx := range []context.Context{ 80 | qCtx.User.UnderlyingDB().Statement.Context, 81 | } { 82 | if v := ctx.Value(key); v != value { 83 | t.Errorf("get value from context fail, expect %q, got %q", value, v) 84 | } 85 | } 86 | } 87 | 88 | func Test_Transaction(t *testing.T) { 89 | query := Use(_gen_test_db) 90 | if !query.Available() { 91 | t.Errorf("query Use(_gen_test_db) fail: query.Available() == false") 92 | } 93 | 94 | err := query.Transaction(func(tx *Query) error { return nil }) 95 | if err != nil { 96 | t.Errorf("query.Transaction execute fail: %s", err) 97 | } 98 | 99 | tx := query.Begin() 100 | 101 | err = tx.SavePoint("point") 102 | if err != nil { 103 | t.Errorf("query tx SavePoint fail: %s", err) 104 | } 105 | err = tx.RollbackTo("point") 106 | if err != nil { 107 | t.Errorf("query tx RollbackTo fail: %s", err) 108 | } 109 | err = tx.Commit() 110 | if err != nil { 111 | t.Errorf("query tx Commit fail: %s", err) 112 | } 113 | 114 | err = query.Begin().Rollback() 115 | if err != nil { 116 | t.Errorf("query tx Rollback fail: %s", err) 117 | } 118 | } 119 | -------------------------------------------------------------------------------- /tests/.expect/dal_6/model/users.gen.go: -------------------------------------------------------------------------------- 1 | // Code generated by gorm.io/gen. DO NOT EDIT. 
2 | // Code generated by gorm.io/gen. DO NOT EDIT. 3 | // Code generated by gorm.io/gen. DO NOT EDIT. 4 | 5 | package model 6 | 7 | import ( 8 | "time" 9 | ) 10 | 11 | const TableNameUser = "users" 12 | 13 | // User mapped from table 14 | type User struct { 15 | ID int64 `gorm:"column:id;primaryKey;autoIncrement:true" json:"-"` 16 | CreatedAt *time.Time `gorm:"column:created_at" json:"-"` 17 | Name *string `gorm:"column:name;index:idx_name,priority:1;index:idx_name_company_id,priority:1;comment:oneline" json:"-"` // oneline 18 | Address *string `gorm:"column:address" json:"-"` 19 | RegisterTime *time.Time `gorm:"column:register_time" json:"-"` 20 | /* 21 | multiline 22 | line1 23 | line2 24 | */ 25 | Alive *bool `gorm:"column:alive;comment:multiline\nline1\nline2" json:"-"` 26 | CompanyID *int64 `gorm:"column:company_id;index:idx_name_company_id,priority:2;default:666" json:"-"` 27 | PrivateURL *string `gorm:"column:private_url;default:https://a.b.c" json:"-"` 28 | } 29 | 30 | func (m *User) IsEmpty() bool { 31 | if m == nil { 32 | return true 33 | } 34 | return m.ID == 0 35 | } 36 | 37 | func (m *User) GetID() int { 38 | return int(m.ID) 39 | } 40 | 41 | // TableName User's table name 42 | func (*User) TableName() string { 43 | return TableNameUser 44 | } 45 | -------------------------------------------------------------------------------- /tests/.expect/dal_6/query/gen.go: -------------------------------------------------------------------------------- 1 | // Code generated by gorm.io/gen. DO NOT EDIT. 2 | // Code generated by gorm.io/gen. DO NOT EDIT. 3 | // Code generated by gorm.io/gen. DO NOT EDIT. 4 | 5 | package query 6 | 7 | import ( 8 | "context" 9 | "database/sql" 10 | 11 | "gorm.io/gorm" 12 | 13 | "gorm.io/gen" 14 | 15 | "gorm.io/plugin/dbresolver" 16 | ) 17 | 18 | var ( 19 | Q = new(Query) 20 | User *user 21 | ) 22 | 23 | func SetDefault(db *gorm.DB, opts ...gen.DOOption) { 24 | *Q = *Use(db, opts...) 
25 | User = &Q.User 26 | } 27 | 28 | func Use(db *gorm.DB, opts ...gen.DOOption) *Query { 29 | return &Query{ 30 | db: db, 31 | User: newUser(db, opts...), 32 | } 33 | } 34 | 35 | type Query struct { 36 | db *gorm.DB 37 | 38 | User user 39 | } 40 | 41 | func (q *Query) Available() bool { return q.db != nil } 42 | 43 | func (q *Query) clone(db *gorm.DB) *Query { 44 | return &Query{ 45 | db: db, 46 | User: q.User.clone(db), 47 | } 48 | } 49 | 50 | func (q *Query) ReadDB() *Query { 51 | return q.ReplaceDB(q.db.Clauses(dbresolver.Read)) 52 | } 53 | 54 | func (q *Query) WriteDB() *Query { 55 | return q.ReplaceDB(q.db.Clauses(dbresolver.Write)) 56 | } 57 | 58 | func (q *Query) ReplaceDB(db *gorm.DB) *Query { 59 | return &Query{ 60 | db: db, 61 | User: q.User.replaceDB(db), 62 | } 63 | } 64 | 65 | type queryCtx struct { 66 | User IUserDo 67 | } 68 | 69 | func (q *Query) WithContext(ctx context.Context) *queryCtx { 70 | return &queryCtx{ 71 | User: q.User.WithContext(ctx), 72 | } 73 | } 74 | 75 | func (q *Query) Transaction(fc func(tx *Query) error, opts ...*sql.TxOptions) error { 76 | return q.db.Transaction(func(tx *gorm.DB) error { return fc(q.clone(tx)) }, opts...) 77 | } 78 | 79 | func (q *Query) Begin(opts ...*sql.TxOptions) *QueryTx { 80 | tx := q.db.Begin(opts...) 
81 | return &QueryTx{Query: q.clone(tx), Error: tx.Error} 82 | } 83 | 84 | type QueryTx struct { 85 | *Query 86 | Error error 87 | } 88 | 89 | func (q *QueryTx) Commit() error { 90 | return q.db.Commit().Error 91 | } 92 | 93 | func (q *QueryTx) Rollback() error { 94 | return q.db.Rollback().Error 95 | } 96 | 97 | func (q *QueryTx) SavePoint(name string) error { 98 | return q.db.SavePoint(name).Error 99 | } 100 | 101 | func (q *QueryTx) RollbackTo(name string) error { 102 | return q.db.RollbackTo(name).Error 103 | } 104 | -------------------------------------------------------------------------------- /tests/.expect/dal_6/query/gen_test.go: -------------------------------------------------------------------------------- 1 | // Code generated by gorm.io/gen. DO NOT EDIT. 2 | // Code generated by gorm.io/gen. DO NOT EDIT. 3 | // Code generated by gorm.io/gen. DO NOT EDIT. 4 | 5 | package query 6 | 7 | import ( 8 | "context" 9 | "fmt" 10 | "reflect" 11 | "sync" 12 | "testing" 13 | 14 | "gorm.io/driver/sqlite" 15 | "gorm.io/gorm" 16 | ) 17 | 18 | type Input struct { 19 | Args []interface{} 20 | } 21 | 22 | type Expectation struct { 23 | Ret []interface{} 24 | } 25 | 26 | type TestCase struct { 27 | Input 28 | Expectation 29 | } 30 | 31 | const _gen_test_db_name = "gen_test.db" 32 | 33 | var _gen_test_db *gorm.DB 34 | var _gen_test_once sync.Once 35 | 36 | func init() { 37 | InitializeDB() 38 | _gen_test_db.AutoMigrate(&_another{}) 39 | } 40 | 41 | func InitializeDB() { 42 | _gen_test_once.Do(func() { 43 | var err error 44 | _gen_test_db, err = gorm.Open(sqlite.Open(_gen_test_db_name), &gorm.Config{}) 45 | if err != nil { 46 | panic(fmt.Errorf("open sqlite %q fail: %w", _gen_test_db_name, err)) 47 | } 48 | }) 49 | } 50 | 51 | func assert(t *testing.T, methodName string, res, exp interface{}) { 52 | if !reflect.DeepEqual(res, exp) { 53 | t.Errorf("%v() gotResult = %v, want %v", methodName, res, exp) 54 | } 55 | } 56 | 57 | type _another struct { 58 | ID uint64 
`gorm:"primaryKey"` 59 | } 60 | 61 | func (*_another) TableName() string { return "another_for_unit_test" } 62 | 63 | func Test_Available(t *testing.T) { 64 | if !Use(_gen_test_db).Available() { 65 | t.Errorf("query.Available() == false") 66 | } 67 | } 68 | 69 | func Test_WithContext(t *testing.T) { 70 | query := Use(_gen_test_db) 71 | if !query.Available() { 72 | t.Errorf("query Use(_gen_test_db) fail: query.Available() == false") 73 | } 74 | 75 | type Content string 76 | var key, value Content = "gen_tag", "unit_test" 77 | qCtx := query.WithContext(context.WithValue(context.Background(), key, value)) 78 | 79 | for _, ctx := range []context.Context{ 80 | qCtx.User.UnderlyingDB().Statement.Context, 81 | } { 82 | if v := ctx.Value(key); v != value { 83 | t.Errorf("get value from context fail, expect %q, got %q", value, v) 84 | } 85 | } 86 | } 87 | 88 | func Test_Transaction(t *testing.T) { 89 | query := Use(_gen_test_db) 90 | if !query.Available() { 91 | t.Errorf("query Use(_gen_test_db) fail: query.Available() == false") 92 | } 93 | 94 | err := query.Transaction(func(tx *Query) error { return nil }) 95 | if err != nil { 96 | t.Errorf("query.Transaction execute fail: %s", err) 97 | } 98 | 99 | tx := query.Begin() 100 | 101 | err = tx.SavePoint("point") 102 | if err != nil { 103 | t.Errorf("query tx SavePoint fail: %s", err) 104 | } 105 | err = tx.RollbackTo("point") 106 | if err != nil { 107 | t.Errorf("query tx RollbackTo fail: %s", err) 108 | } 109 | err = tx.Commit() 110 | if err != nil { 111 | t.Errorf("query tx Commit fail: %s", err) 112 | } 113 | 114 | err = query.Begin().Rollback() 115 | if err != nil { 116 | t.Errorf("query tx Rollback fail: %s", err) 117 | } 118 | } 119 | -------------------------------------------------------------------------------- /tests/.expect/dal_7/model/banks.gen.go: -------------------------------------------------------------------------------- 1 | // Code generated by gorm.io/gen. DO NOT EDIT. 
2 | // Code generated by gorm.io/gen. DO NOT EDIT. 3 | // Code generated by gorm.io/gen. DO NOT EDIT. 4 | 5 | package model 6 | 7 | const TableNameBank = "banks" 8 | 9 | // Bank mapped from table 10 | type Bank struct { 11 | ID int64 `gorm:"column:id;primaryKey;autoIncrement:true" json:"-"` 12 | Name *string `gorm:"column:name" json:"-"` 13 | Address *string `gorm:"column:address" json:"-"` 14 | Scale *int64 `gorm:"column:scale" json:"-"` 15 | } 16 | 17 | // TableName Bank's table name 18 | func (*Bank) TableName() string { 19 | return TableNameBank 20 | } 21 | -------------------------------------------------------------------------------- /tests/.expect/dal_7/model/credit_cards.gen.go: -------------------------------------------------------------------------------- 1 | // Code generated by gorm.io/gen. DO NOT EDIT. 2 | // Code generated by gorm.io/gen. DO NOT EDIT. 3 | // Code generated by gorm.io/gen. DO NOT EDIT. 4 | 5 | package model 6 | 7 | import ( 8 | "time" 9 | 10 | "gorm.io/gorm" 11 | ) 12 | 13 | const TableNameCreditCard = "credit_cards" 14 | 15 | // CreditCard mapped from table 16 | type CreditCard struct { 17 | ID int64 `gorm:"column:id;primaryKey;autoIncrement:true" json:"-"` 18 | CreatedAt *time.Time `gorm:"column:created_at" json:"-"` 19 | UpdatedAt *time.Time `gorm:"column:updated_at" json:"-"` 20 | DeletedAt gorm.DeletedAt `gorm:"column:deleted_at;index:idx_credit_cards_deleted_at,priority:1" json:"-"` 21 | Number *string `gorm:"column:number" json:"-"` 22 | CustomerRefer *int64 `gorm:"column:customer_refer" json:"-"` 23 | BankID *int64 `gorm:"column:bank_id" json:"-"` 24 | } 25 | 26 | // TableName CreditCard's table name 27 | func (*CreditCard) TableName() string { 28 | return TableNameCreditCard 29 | } 30 | -------------------------------------------------------------------------------- /tests/.expect/dal_7/model/customers.gen.go: -------------------------------------------------------------------------------- 1 | // Code generated by 
gorm.io/gen. DO NOT EDIT. 2 | // Code generated by gorm.io/gen. DO NOT EDIT. 3 | // Code generated by gorm.io/gen. DO NOT EDIT. 4 | 5 | package model 6 | 7 | import ( 8 | "time" 9 | 10 | "gorm.io/gorm" 11 | ) 12 | 13 | const TableNameCustomer = "customers" 14 | 15 | // Customer mapped from table 16 | type Customer struct { 17 | ID int64 `gorm:"column:id;primaryKey;autoIncrement:true" json:"-"` 18 | CreatedAt *time.Time `gorm:"column:created_at" json:"-"` 19 | UpdatedAt *time.Time `gorm:"column:updated_at" json:"-"` 20 | DeletedAt gorm.DeletedAt `gorm:"column:deleted_at;index:idx_customers_deleted_at,priority:1" json:"-"` 21 | BankID *int64 `gorm:"column:bank_id" json:"-"` 22 | Bank Bank `gorm:"foreignKey:BankID;references:ID" json:"bank"` 23 | CreditCards []CreditCard `gorm:"foreignKey:CustomerRefer;references:ID" json:"credit_cards"` 24 | } 25 | 26 | // TableName Customer's table name 27 | func (*Customer) TableName() string { 28 | return TableNameCustomer 29 | } 30 | -------------------------------------------------------------------------------- /tests/.expect/dal_7/query/gen.go: -------------------------------------------------------------------------------- 1 | // Code generated by gorm.io/gen. DO NOT EDIT. 2 | // Code generated by gorm.io/gen. DO NOT EDIT. 3 | // Code generated by gorm.io/gen. DO NOT EDIT. 4 | 5 | package query 6 | 7 | import ( 8 | "context" 9 | "database/sql" 10 | 11 | "gorm.io/gorm" 12 | 13 | "gorm.io/gen" 14 | 15 | "gorm.io/plugin/dbresolver" 16 | ) 17 | 18 | var ( 19 | Q = new(Query) 20 | Customer *customer 21 | ) 22 | 23 | func SetDefault(db *gorm.DB, opts ...gen.DOOption) { 24 | *Q = *Use(db, opts...) 
25 | Customer = &Q.Customer 26 | } 27 | 28 | func Use(db *gorm.DB, opts ...gen.DOOption) *Query { 29 | return &Query{ 30 | db: db, 31 | Customer: newCustomer(db, opts...), 32 | } 33 | } 34 | 35 | type Query struct { 36 | db *gorm.DB 37 | 38 | Customer customer 39 | } 40 | 41 | func (q *Query) Available() bool { return q.db != nil } 42 | 43 | func (q *Query) clone(db *gorm.DB) *Query { 44 | return &Query{ 45 | db: db, 46 | Customer: q.Customer.clone(db), 47 | } 48 | } 49 | 50 | func (q *Query) ReadDB() *Query { 51 | return q.ReplaceDB(q.db.Clauses(dbresolver.Read)) 52 | } 53 | 54 | func (q *Query) WriteDB() *Query { 55 | return q.ReplaceDB(q.db.Clauses(dbresolver.Write)) 56 | } 57 | 58 | func (q *Query) ReplaceDB(db *gorm.DB) *Query { 59 | return &Query{ 60 | db: db, 61 | Customer: q.Customer.replaceDB(db), 62 | } 63 | } 64 | 65 | type queryCtx struct { 66 | Customer ICustomerDo 67 | } 68 | 69 | func (q *Query) WithContext(ctx context.Context) *queryCtx { 70 | return &queryCtx{ 71 | Customer: q.Customer.WithContext(ctx), 72 | } 73 | } 74 | 75 | func (q *Query) Transaction(fc func(tx *Query) error, opts ...*sql.TxOptions) error { 76 | return q.db.Transaction(func(tx *gorm.DB) error { return fc(q.clone(tx)) }, opts...) 77 | } 78 | 79 | func (q *Query) Begin(opts ...*sql.TxOptions) *QueryTx { 80 | tx := q.db.Begin(opts...) 
81 | return &QueryTx{Query: q.clone(tx), Error: tx.Error} 82 | } 83 | 84 | type QueryTx struct { 85 | *Query 86 | Error error 87 | } 88 | 89 | func (q *QueryTx) Commit() error { 90 | return q.db.Commit().Error 91 | } 92 | 93 | func (q *QueryTx) Rollback() error { 94 | return q.db.Rollback().Error 95 | } 96 | 97 | func (q *QueryTx) SavePoint(name string) error { 98 | return q.db.SavePoint(name).Error 99 | } 100 | 101 | func (q *QueryTx) RollbackTo(name string) error { 102 | return q.db.RollbackTo(name).Error 103 | } 104 | -------------------------------------------------------------------------------- /tests/.expect/dal_7/query/gen_test.go: -------------------------------------------------------------------------------- 1 | // Code generated by gorm.io/gen. DO NOT EDIT. 2 | // Code generated by gorm.io/gen. DO NOT EDIT. 3 | // Code generated by gorm.io/gen. DO NOT EDIT. 4 | 5 | package query 6 | 7 | import ( 8 | "context" 9 | "fmt" 10 | "reflect" 11 | "sync" 12 | "testing" 13 | 14 | "gorm.io/driver/sqlite" 15 | "gorm.io/gorm" 16 | ) 17 | 18 | type Input struct { 19 | Args []interface{} 20 | } 21 | 22 | type Expectation struct { 23 | Ret []interface{} 24 | } 25 | 26 | type TestCase struct { 27 | Input 28 | Expectation 29 | } 30 | 31 | const _gen_test_db_name = "gen_test.db" 32 | 33 | var _gen_test_db *gorm.DB 34 | var _gen_test_once sync.Once 35 | 36 | func init() { 37 | InitializeDB() 38 | _gen_test_db.AutoMigrate(&_another{}) 39 | } 40 | 41 | func InitializeDB() { 42 | _gen_test_once.Do(func() { 43 | var err error 44 | _gen_test_db, err = gorm.Open(sqlite.Open(_gen_test_db_name), &gorm.Config{}) 45 | if err != nil { 46 | panic(fmt.Errorf("open sqlite %q fail: %w", _gen_test_db_name, err)) 47 | } 48 | }) 49 | } 50 | 51 | func assert(t *testing.T, methodName string, res, exp interface{}) { 52 | if !reflect.DeepEqual(res, exp) { 53 | t.Errorf("%v() gotResult = %v, want %v", methodName, res, exp) 54 | } 55 | } 56 | 57 | type _another struct { 58 | ID uint64 
`gorm:"primaryKey"` 59 | } 60 | 61 | func (*_another) TableName() string { return "another_for_unit_test" } 62 | 63 | func Test_Available(t *testing.T) { 64 | if !Use(_gen_test_db).Available() { 65 | t.Errorf("query.Available() == false") 66 | } 67 | } 68 | 69 | func Test_WithContext(t *testing.T) { 70 | query := Use(_gen_test_db) 71 | if !query.Available() { 72 | t.Errorf("query Use(_gen_test_db) fail: query.Available() == false") 73 | } 74 | 75 | type Content string 76 | var key, value Content = "gen_tag", "unit_test" 77 | qCtx := query.WithContext(context.WithValue(context.Background(), key, value)) 78 | 79 | for _, ctx := range []context.Context{ 80 | qCtx.Customer.UnderlyingDB().Statement.Context, 81 | } { 82 | if v := ctx.Value(key); v != value { 83 | t.Errorf("get value from context fail, expect %q, got %q", value, v) 84 | } 85 | } 86 | } 87 | 88 | func Test_Transaction(t *testing.T) { 89 | query := Use(_gen_test_db) 90 | if !query.Available() { 91 | t.Errorf("query Use(_gen_test_db) fail: query.Available() == false") 92 | } 93 | 94 | err := query.Transaction(func(tx *Query) error { return nil }) 95 | if err != nil { 96 | t.Errorf("query.Transaction execute fail: %s", err) 97 | } 98 | 99 | tx := query.Begin() 100 | 101 | err = tx.SavePoint("point") 102 | if err != nil { 103 | t.Errorf("query tx SavePoint fail: %s", err) 104 | } 105 | err = tx.RollbackTo("point") 106 | if err != nil { 107 | t.Errorf("query tx RollbackTo fail: %s", err) 108 | } 109 | err = tx.Commit() 110 | if err != nil { 111 | t.Errorf("query tx Commit fail: %s", err) 112 | } 113 | 114 | err = query.Begin().Rollback() 115 | if err != nil { 116 | t.Errorf("query tx Rollback fail: %s", err) 117 | } 118 | } 119 | -------------------------------------------------------------------------------- /tests/.expect/dal_8/query/gen.go: -------------------------------------------------------------------------------- 1 | // Code generated by gorm.io/gen. DO NOT EDIT. 2 | // Code generated by gorm.io/gen. 
DO NOT EDIT. 3 | // Code generated by gorm.io/gen. DO NOT EDIT. 4 | 5 | package query 6 | 7 | import ( 8 | "context" 9 | "database/sql" 10 | 11 | "gorm.io/gorm" 12 | 13 | "gorm.io/gen" 14 | 15 | "gorm.io/plugin/dbresolver" 16 | ) 17 | 18 | var ( 19 | Q = new(Query) 20 | Comment *comment 21 | Post *post 22 | User *user 23 | ) 24 | 25 | func SetDefault(db *gorm.DB, opts ...gen.DOOption) { 26 | *Q = *Use(db, opts...) 27 | Comment = &Q.Comment 28 | Post = &Q.Post 29 | User = &Q.User 30 | } 31 | 32 | func Use(db *gorm.DB, opts ...gen.DOOption) *Query { 33 | return &Query{ 34 | db: db, 35 | Comment: newComment(db, opts...), 36 | Post: newPost(db, opts...), 37 | User: newUser(db, opts...), 38 | } 39 | } 40 | 41 | type Query struct { 42 | db *gorm.DB 43 | 44 | Comment comment 45 | Post post 46 | User user 47 | } 48 | 49 | func (q *Query) Available() bool { return q.db != nil } 50 | 51 | func (q *Query) clone(db *gorm.DB) *Query { 52 | return &Query{ 53 | db: db, 54 | Comment: q.Comment.clone(db), 55 | Post: q.Post.clone(db), 56 | User: q.User.clone(db), 57 | } 58 | } 59 | 60 | func (q *Query) ReadDB() *Query { 61 | return q.ReplaceDB(q.db.Clauses(dbresolver.Read)) 62 | } 63 | 64 | func (q *Query) WriteDB() *Query { 65 | return q.ReplaceDB(q.db.Clauses(dbresolver.Write)) 66 | } 67 | 68 | func (q *Query) ReplaceDB(db *gorm.DB) *Query { 69 | return &Query{ 70 | db: db, 71 | Comment: q.Comment.replaceDB(db), 72 | Post: q.Post.replaceDB(db), 73 | User: q.User.replaceDB(db), 74 | } 75 | } 76 | 77 | type queryCtx struct { 78 | Comment ICommentDo 79 | Post IPostDo 80 | User IUserDo 81 | } 82 | 83 | func (q *Query) WithContext(ctx context.Context) *queryCtx { 84 | return &queryCtx{ 85 | Comment: q.Comment.WithContext(ctx), 86 | Post: q.Post.WithContext(ctx), 87 | User: q.User.WithContext(ctx), 88 | } 89 | } 90 | 91 | func (q *Query) Transaction(fc func(tx *Query) error, opts ...*sql.TxOptions) error { 92 | return q.db.Transaction(func(tx *gorm.DB) error { return fc(q.clone(tx)) }, 
opts...) 93 | } 94 | 95 | func (q *Query) Begin(opts ...*sql.TxOptions) *QueryTx { 96 | tx := q.db.Begin(opts...) 97 | return &QueryTx{Query: q.clone(tx), Error: tx.Error} 98 | } 99 | 100 | type QueryTx struct { 101 | *Query 102 | Error error 103 | } 104 | 105 | func (q *QueryTx) Commit() error { 106 | return q.db.Commit().Error 107 | } 108 | 109 | func (q *QueryTx) Rollback() error { 110 | return q.db.Rollback().Error 111 | } 112 | 113 | func (q *QueryTx) SavePoint(name string) error { 114 | return q.db.SavePoint(name).Error 115 | } 116 | 117 | func (q *QueryTx) RollbackTo(name string) error { 118 | return q.db.RollbackTo(name).Error 119 | } 120 | -------------------------------------------------------------------------------- /tests/.expect/dal_8/query/gen_test.go: -------------------------------------------------------------------------------- 1 | // Code generated by gorm.io/gen. DO NOT EDIT. 2 | // Code generated by gorm.io/gen. DO NOT EDIT. 3 | // Code generated by gorm.io/gen. DO NOT EDIT. 
4 | 5 | package query 6 | 7 | import ( 8 | "context" 9 | "fmt" 10 | "reflect" 11 | "sync" 12 | "testing" 13 | 14 | "gorm.io/driver/sqlite" 15 | "gorm.io/gorm" 16 | ) 17 | 18 | type Input struct { 19 | Args []interface{} 20 | } 21 | 22 | type Expectation struct { 23 | Ret []interface{} 24 | } 25 | 26 | type TestCase struct { 27 | Input 28 | Expectation 29 | } 30 | 31 | const _gen_test_db_name = "gen_test.db" 32 | 33 | var _gen_test_db *gorm.DB 34 | var _gen_test_once sync.Once 35 | 36 | func init() { 37 | InitializeDB() 38 | _gen_test_db.AutoMigrate(&_another{}) 39 | } 40 | 41 | func InitializeDB() { 42 | _gen_test_once.Do(func() { 43 | var err error 44 | _gen_test_db, err = gorm.Open(sqlite.Open(_gen_test_db_name), &gorm.Config{}) 45 | if err != nil { 46 | panic(fmt.Errorf("open sqlite %q fail: %w", _gen_test_db_name, err)) 47 | } 48 | }) 49 | } 50 | 51 | func assert(t *testing.T, methodName string, res, exp interface{}) { 52 | if !reflect.DeepEqual(res, exp) { 53 | t.Errorf("%v() gotResult = %v, want %v", methodName, res, exp) 54 | } 55 | } 56 | 57 | type _another struct { 58 | ID uint64 `gorm:"primaryKey"` 59 | } 60 | 61 | func (*_another) TableName() string { return "another_for_unit_test" } 62 | 63 | func Test_Available(t *testing.T) { 64 | if !Use(_gen_test_db).Available() { 65 | t.Errorf("query.Available() == false") 66 | } 67 | } 68 | 69 | func Test_WithContext(t *testing.T) { 70 | query := Use(_gen_test_db) 71 | if !query.Available() { 72 | t.Errorf("query Use(_gen_test_db) fail: query.Available() == false") 73 | } 74 | 75 | type Content string 76 | var key, value Content = "gen_tag", "unit_test" 77 | qCtx := query.WithContext(context.WithValue(context.Background(), key, value)) 78 | 79 | for _, ctx := range []context.Context{ 80 | qCtx.Comment.UnderlyingDB().Statement.Context, 81 | qCtx.Post.UnderlyingDB().Statement.Context, 82 | qCtx.User.UnderlyingDB().Statement.Context, 83 | } { 84 | if v := ctx.Value(key); v != value { 85 | t.Errorf("get value from 
context fail, expect %q, got %q", value, v) 86 | } 87 | } 88 | } 89 | 90 | func Test_Transaction(t *testing.T) { 91 | query := Use(_gen_test_db) 92 | if !query.Available() { 93 | t.Errorf("query Use(_gen_test_db) fail: query.Available() == false") 94 | } 95 | 96 | err := query.Transaction(func(tx *Query) error { return nil }) 97 | if err != nil { 98 | t.Errorf("query.Transaction execute fail: %s", err) 99 | } 100 | 101 | tx := query.Begin() 102 | 103 | err = tx.SavePoint("point") 104 | if err != nil { 105 | t.Errorf("query tx SavePoint fail: %s", err) 106 | } 107 | err = tx.RollbackTo("point") 108 | if err != nil { 109 | t.Errorf("query tx RollbackTo fail: %s", err) 110 | } 111 | err = tx.Commit() 112 | if err != nil { 113 | t.Errorf("query tx Commit fail: %s", err) 114 | } 115 | 116 | err = query.Begin().Rollback() 117 | if err != nil { 118 | t.Errorf("query tx Rollback fail: %s", err) 119 | } 120 | } 121 | -------------------------------------------------------------------------------- /tests/.expect/dal_test/model/banks.gen.go: -------------------------------------------------------------------------------- 1 | // Code generated by gorm.io/gen. DO NOT EDIT. 2 | // Code generated by gorm.io/gen. DO NOT EDIT. 3 | // Code generated by gorm.io/gen. DO NOT EDIT. 
4 | 5 | package model 6 | 7 | const TableNameBank = "banks" 8 | 9 | // Bank mapped from table 10 | type Bank struct { 11 | ID int64 `gorm:"column:id;primaryKey;autoIncrement:true" json:"id"` 12 | Name string `gorm:"column:name" json:"name"` 13 | Address string `gorm:"column:address" json:"address"` 14 | Scale int64 `gorm:"column:scale" json:"scale"` 15 | } 16 | 17 | // TableName Bank's table name 18 | func (*Bank) TableName() string { 19 | return TableNameBank 20 | } 21 | -------------------------------------------------------------------------------- /tests/.expect/dal_test/model/credit_cards.gen.go: -------------------------------------------------------------------------------- 1 | // Code generated by gorm.io/gen. DO NOT EDIT. 2 | // Code generated by gorm.io/gen. DO NOT EDIT. 3 | // Code generated by gorm.io/gen. DO NOT EDIT. 4 | 5 | package model 6 | 7 | import ( 8 | "time" 9 | 10 | "gorm.io/gorm" 11 | ) 12 | 13 | const TableNameCreditCard = "credit_cards" 14 | 15 | // CreditCard mapped from table 16 | type CreditCard struct { 17 | ID int64 `gorm:"column:id;primaryKey;autoIncrement:true" json:"id"` 18 | CreatedAt time.Time `gorm:"column:created_at" json:"created_at"` 19 | UpdatedAt time.Time `gorm:"column:updated_at" json:"updated_at"` 20 | DeletedAt gorm.DeletedAt `gorm:"column:deleted_at" json:"deleted_at"` 21 | Number string `gorm:"column:number" json:"number"` 22 | CustomerRefer int64 `gorm:"column:customer_refer" json:"customer_refer"` 23 | BankID int64 `gorm:"column:bank_id" json:"bank_id"` 24 | } 25 | 26 | // TableName CreditCard's table name 27 | func (*CreditCard) TableName() string { 28 | return TableNameCreditCard 29 | } 30 | -------------------------------------------------------------------------------- /tests/.expect/dal_test/model/customers.gen.go: -------------------------------------------------------------------------------- 1 | // Code generated by gorm.io/gen. DO NOT EDIT. 2 | // Code generated by gorm.io/gen. DO NOT EDIT. 
3 | // Code generated by gorm.io/gen. DO NOT EDIT. 4 | 5 | package model 6 | 7 | import ( 8 | "time" 9 | 10 | "gorm.io/gorm" 11 | ) 12 | 13 | const TableNameCustomer = "customers" 14 | 15 | // Customer mapped from table 16 | type Customer struct { 17 | ID int64 `gorm:"column:id;primaryKey;autoIncrement:true" json:"id"` 18 | CreatedAt time.Time `gorm:"column:created_at" json:"created_at"` 19 | UpdatedAt time.Time `gorm:"column:updated_at" json:"updated_at"` 20 | DeletedAt gorm.DeletedAt `gorm:"column:deleted_at" json:"deleted_at"` 21 | BankID int64 `gorm:"column:bank_id" json:"bank_id"` 22 | } 23 | 24 | // TableName Customer's table name 25 | func (*Customer) TableName() string { 26 | return TableNameCustomer 27 | } 28 | -------------------------------------------------------------------------------- /tests/.expect/dal_test/model/people.gen.go: -------------------------------------------------------------------------------- 1 | // Code generated by gorm.io/gen. DO NOT EDIT. 2 | // Code generated by gorm.io/gen. DO NOT EDIT. 3 | // Code generated by gorm.io/gen. DO NOT EDIT. 
4 | 5 | package model 6 | 7 | import ( 8 | "time" 9 | 10 | "gorm.io/gorm" 11 | ) 12 | 13 | const TableNamePerson = "people" 14 | 15 | // Person mapped from table 16 | type Person struct { 17 | ID int64 `gorm:"column:id;primaryKey;autoIncrement:true" json:"id"` 18 | Name string `gorm:"column:name" json:"name"` 19 | Alias_ string `gorm:"column:alias" json:"alias"` 20 | Age int32 `gorm:"column:age" json:"age"` 21 | Flag bool `gorm:"column:flag" json:"flag"` 22 | AnotherFlag int32 `gorm:"column:another_flag" json:"another_flag"` 23 | Commit string `gorm:"column:commit" json:"commit"` 24 | First bool `gorm:"column:First" json:"First"` 25 | Bit []uint8 `gorm:"column:bit" json:"bit"` 26 | Small int32 `gorm:"column:small" json:"small"` 27 | DeletedAt gorm.DeletedAt `gorm:"column:deleted_at" json:"deleted_at"` 28 | Score float64 `gorm:"column:score" json:"score"` 29 | Number int32 `gorm:"column:number" json:"number"` 30 | Birth time.Time `gorm:"column:birth;default:CURRENT_TIMESTAMP" json:"birth"` 31 | XMLHTTPRequest string `gorm:"column:xmlHTTPRequest;default:' '" json:"xmlHTTPRequest"` 32 | JStr string `gorm:"column:jStr" json:"jStr"` 33 | Geo string `gorm:"column:geo" json:"geo"` 34 | Mint int32 `gorm:"column:mint" json:"mint"` 35 | Blank string `gorm:"column:blank;default:' '" json:"blank"` 36 | Remark string `gorm:"column:remark" json:"remark"` 37 | LongRemark string `gorm:"column:long_remark" json:"long_remark"` 38 | } 39 | 40 | // TableName Person's table name 41 | func (*Person) TableName() string { 42 | return TableNamePerson 43 | } 44 | -------------------------------------------------------------------------------- /tests/.expect/dal_test/model/users.gen.go: -------------------------------------------------------------------------------- 1 | // Code generated by gorm.io/gen. DO NOT EDIT. 2 | // Code generated by gorm.io/gen. DO NOT EDIT. 3 | // Code generated by gorm.io/gen. DO NOT EDIT. 
4 | 5 | package model 6 | 7 | import ( 8 | "time" 9 | ) 10 | 11 | const TableNameUser = "users" 12 | 13 | // User mapped from table 14 | type User struct { 15 | ID int64 `gorm:"column:id;primaryKey;autoIncrement:true" json:"id"` 16 | CreatedAt time.Time `gorm:"column:created_at" json:"created_at"` 17 | Name string `gorm:"column:name;comment:oneline" json:"name"` // oneline 18 | Address string `gorm:"column:address" json:"address"` 19 | RegisterTime time.Time `gorm:"column:register_time" json:"register_time"` 20 | /* 21 | multiline 22 | line1 23 | line2 24 | */ 25 | Alive bool `gorm:"column:alive;comment:multiline\nline1\nline2" json:"alive"` 26 | CompanyID int64 `gorm:"column:company_id;default:666" json:"company_id"` 27 | PrivateURL string `gorm:"column:private_url;default:https://a.b.c" json:"private_url"` 28 | } 29 | 30 | // TableName User's table name 31 | func (*User) TableName() string { 32 | return TableNameUser 33 | } 34 | -------------------------------------------------------------------------------- /tests/.expect/dal_test/query/gen.go: -------------------------------------------------------------------------------- 1 | // Code generated by gorm.io/gen. DO NOT EDIT. 2 | // Code generated by gorm.io/gen. DO NOT EDIT. 3 | // Code generated by gorm.io/gen. DO NOT EDIT. 4 | 5 | package query 6 | 7 | import ( 8 | "context" 9 | "database/sql" 10 | 11 | "gorm.io/gorm" 12 | 13 | "gorm.io/gen" 14 | 15 | "gorm.io/plugin/dbresolver" 16 | ) 17 | 18 | var ( 19 | Q = new(Query) 20 | Bank *bank 21 | CreditCard *creditCard 22 | Customer *customer 23 | Person *person 24 | User *user 25 | ) 26 | 27 | func SetDefault(db *gorm.DB, opts ...gen.DOOption) { 28 | *Q = *Use(db, opts...) 
29 | Bank = &Q.Bank 30 | CreditCard = &Q.CreditCard 31 | Customer = &Q.Customer 32 | Person = &Q.Person 33 | User = &Q.User 34 | } 35 | 36 | func Use(db *gorm.DB, opts ...gen.DOOption) *Query { 37 | return &Query{ 38 | db: db, 39 | Bank: newBank(db, opts...), 40 | CreditCard: newCreditCard(db, opts...), 41 | Customer: newCustomer(db, opts...), 42 | Person: newPerson(db, opts...), 43 | User: newUser(db, opts...), 44 | } 45 | } 46 | 47 | type Query struct { 48 | db *gorm.DB 49 | 50 | Bank bank 51 | CreditCard creditCard 52 | Customer customer 53 | Person person 54 | User user 55 | } 56 | 57 | func (q *Query) Available() bool { return q.db != nil } 58 | 59 | func (q *Query) clone(db *gorm.DB) *Query { 60 | return &Query{ 61 | db: db, 62 | Bank: q.Bank.clone(db), 63 | CreditCard: q.CreditCard.clone(db), 64 | Customer: q.Customer.clone(db), 65 | Person: q.Person.clone(db), 66 | User: q.User.clone(db), 67 | } 68 | } 69 | 70 | func (q *Query) ReadDB() *Query { 71 | return q.ReplaceDB(q.db.Clauses(dbresolver.Read)) 72 | } 73 | 74 | func (q *Query) WriteDB() *Query { 75 | return q.ReplaceDB(q.db.Clauses(dbresolver.Write)) 76 | } 77 | 78 | func (q *Query) ReplaceDB(db *gorm.DB) *Query { 79 | return &Query{ 80 | db: db, 81 | Bank: q.Bank.replaceDB(db), 82 | CreditCard: q.CreditCard.replaceDB(db), 83 | Customer: q.Customer.replaceDB(db), 84 | Person: q.Person.replaceDB(db), 85 | User: q.User.replaceDB(db), 86 | } 87 | } 88 | 89 | type queryCtx struct { 90 | Bank *bankDo 91 | CreditCard *creditCardDo 92 | Customer *customerDo 93 | Person *personDo 94 | User *userDo 95 | } 96 | 97 | func (q *Query) WithContext(ctx context.Context) *queryCtx { 98 | return &queryCtx{ 99 | Bank: q.Bank.WithContext(ctx), 100 | CreditCard: q.CreditCard.WithContext(ctx), 101 | Customer: q.Customer.WithContext(ctx), 102 | Person: q.Person.WithContext(ctx), 103 | User: q.User.WithContext(ctx), 104 | } 105 | } 106 | 107 | func (q *Query) Transaction(fc func(tx *Query) error, opts ...*sql.TxOptions) error 
{ 108 | return q.db.Transaction(func(tx *gorm.DB) error { return fc(q.clone(tx)) }, opts...) 109 | } 110 | 111 | func (q *Query) Begin(opts ...*sql.TxOptions) *QueryTx { 112 | tx := q.db.Begin(opts...) 113 | return &QueryTx{Query: q.clone(tx), Error: tx.Error} 114 | } 115 | 116 | type QueryTx struct { 117 | *Query 118 | Error error 119 | } 120 | 121 | func (q *QueryTx) Commit() error { 122 | return q.db.Commit().Error 123 | } 124 | 125 | func (q *QueryTx) Rollback() error { 126 | return q.db.Rollback().Error 127 | } 128 | 129 | func (q *QueryTx) SavePoint(name string) error { 130 | return q.db.SavePoint(name).Error 131 | } 132 | 133 | func (q *QueryTx) RollbackTo(name string) error { 134 | return q.db.RollbackTo(name).Error 135 | } 136 | -------------------------------------------------------------------------------- /tests/.expect/dal_test_relation/model/banks.gen.go: -------------------------------------------------------------------------------- 1 | // Code generated by gorm.io/gen. DO NOT EDIT. 2 | // Code generated by gorm.io/gen. DO NOT EDIT. 3 | // Code generated by gorm.io/gen. DO NOT EDIT. 4 | 5 | package model 6 | 7 | const TableNameBank = "banks" 8 | 9 | // Bank mapped from table 10 | type Bank struct { 11 | ID int64 `gorm:"column:id;primaryKey;autoIncrement:true" json:"id"` 12 | Name string `gorm:"column:name" json:"name"` 13 | Address string `gorm:"column:address" json:"address"` 14 | Scale int64 `gorm:"column:scale" json:"scale"` 15 | } 16 | 17 | // TableName Bank's table name 18 | func (*Bank) TableName() string { 19 | return TableNameBank 20 | } 21 | -------------------------------------------------------------------------------- /tests/.expect/dal_test_relation/model/credit_cards.gen.go: -------------------------------------------------------------------------------- 1 | // Code generated by gorm.io/gen. DO NOT EDIT. 2 | // Code generated by gorm.io/gen. DO NOT EDIT. 3 | // Code generated by gorm.io/gen. DO NOT EDIT. 
4 | 5 | package model 6 | 7 | import ( 8 | "time" 9 | 10 | "gorm.io/gorm" 11 | ) 12 | 13 | const TableNameCreditCard = "credit_cards" 14 | 15 | // CreditCard mapped from table 16 | type CreditCard struct { 17 | ID int64 `gorm:"column:id;primaryKey;autoIncrement:true" json:"id"` 18 | CreatedAt time.Time `gorm:"column:created_at" json:"created_at"` 19 | UpdatedAt time.Time `gorm:"column:updated_at" json:"updated_at"` 20 | DeletedAt gorm.DeletedAt `gorm:"column:deleted_at" json:"deleted_at"` 21 | Number string `gorm:"column:number" json:"number"` 22 | CustomerRefer int64 `gorm:"column:customer_refer" json:"customer_refer"` 23 | BankID int64 `gorm:"column:bank_id" json:"bank_id"` 24 | } 25 | 26 | // TableName CreditCard's table name 27 | func (*CreditCard) TableName() string { 28 | return TableNameCreditCard 29 | } 30 | -------------------------------------------------------------------------------- /tests/.expect/dal_test_relation/model/customers.gen.go: -------------------------------------------------------------------------------- 1 | // Code generated by gorm.io/gen. DO NOT EDIT. 2 | // Code generated by gorm.io/gen. DO NOT EDIT. 3 | // Code generated by gorm.io/gen. DO NOT EDIT. 
4 | 5 | package model 6 | 7 | import ( 8 | "time" 9 | 10 | "gorm.io/gorm" 11 | ) 12 | 13 | const TableNameCustomer = "customers" 14 | 15 | // Customer mapped from table 16 | type Customer struct { 17 | ID int64 `gorm:"column:id;primaryKey;autoIncrement:true" json:"id"` 18 | CreatedAt time.Time `gorm:"column:created_at" json:"created_at"` 19 | UpdatedAt time.Time `gorm:"column:updated_at" json:"updated_at"` 20 | DeletedAt gorm.DeletedAt `gorm:"column:deleted_at" json:"deleted_at"` 21 | BankID int64 `gorm:"column:bank_id" json:"bank_id"` 22 | Bank Bank `gorm:"foreignKey:BankID;references:ID" json:"bank"` 23 | CreditCards []CreditCard `gorm:"foreignKey:CustomerRefer;references:ID" json:"credit_cards"` 24 | } 25 | 26 | // TableName Customer's table name 27 | func (*Customer) TableName() string { 28 | return TableNameCustomer 29 | } 30 | -------------------------------------------------------------------------------- /tests/.expect/dal_test_relation/query/gen.go: -------------------------------------------------------------------------------- 1 | // Code generated by gorm.io/gen. DO NOT EDIT. 2 | // Code generated by gorm.io/gen. DO NOT EDIT. 3 | // Code generated by gorm.io/gen. DO NOT EDIT. 4 | 5 | package query 6 | 7 | import ( 8 | "context" 9 | "database/sql" 10 | 11 | "gorm.io/gorm" 12 | 13 | "gorm.io/gen" 14 | 15 | "gorm.io/plugin/dbresolver" 16 | ) 17 | 18 | var ( 19 | Q = new(Query) 20 | Bank *bank 21 | CreditCard *creditCard 22 | Customer *customer 23 | ) 24 | 25 | func SetDefault(db *gorm.DB, opts ...gen.DOOption) { 26 | *Q = *Use(db, opts...) 
27 | Bank = &Q.Bank 28 | CreditCard = &Q.CreditCard 29 | Customer = &Q.Customer 30 | } 31 | 32 | func Use(db *gorm.DB, opts ...gen.DOOption) *Query { 33 | return &Query{ 34 | db: db, 35 | Bank: newBank(db, opts...), 36 | CreditCard: newCreditCard(db, opts...), 37 | Customer: newCustomer(db, opts...), 38 | } 39 | } 40 | 41 | type Query struct { 42 | db *gorm.DB 43 | 44 | Bank bank 45 | CreditCard creditCard 46 | Customer customer 47 | } 48 | 49 | func (q *Query) Available() bool { return q.db != nil } 50 | 51 | func (q *Query) clone(db *gorm.DB) *Query { 52 | return &Query{ 53 | db: db, 54 | Bank: q.Bank.clone(db), 55 | CreditCard: q.CreditCard.clone(db), 56 | Customer: q.Customer.clone(db), 57 | } 58 | } 59 | 60 | func (q *Query) ReadDB() *Query { 61 | return q.ReplaceDB(q.db.Clauses(dbresolver.Read)) 62 | } 63 | 64 | func (q *Query) WriteDB() *Query { 65 | return q.ReplaceDB(q.db.Clauses(dbresolver.Write)) 66 | } 67 | 68 | func (q *Query) ReplaceDB(db *gorm.DB) *Query { 69 | return &Query{ 70 | db: db, 71 | Bank: q.Bank.replaceDB(db), 72 | CreditCard: q.CreditCard.replaceDB(db), 73 | Customer: q.Customer.replaceDB(db), 74 | } 75 | } 76 | 77 | type queryCtx struct { 78 | Bank *bankDo 79 | CreditCard *creditCardDo 80 | Customer *customerDo 81 | } 82 | 83 | func (q *Query) WithContext(ctx context.Context) *queryCtx { 84 | return &queryCtx{ 85 | Bank: q.Bank.WithContext(ctx), 86 | CreditCard: q.CreditCard.WithContext(ctx), 87 | Customer: q.Customer.WithContext(ctx), 88 | } 89 | } 90 | 91 | func (q *Query) Transaction(fc func(tx *Query) error, opts ...*sql.TxOptions) error { 92 | return q.db.Transaction(func(tx *gorm.DB) error { return fc(q.clone(tx)) }, opts...) 93 | } 94 | 95 | func (q *Query) Begin(opts ...*sql.TxOptions) *QueryTx { 96 | tx := q.db.Begin(opts...) 
// ddlPath is the SQL dump (relative to the test working directory) that holds
// the DROP/CREATE statements for every test table.
const ddlPath = "tables.sql"

// reg matches one "DROP TABLE IF EXISTS `name`;" statement that is followed,
// after exactly one whitespace character, by a "CREATE TABLE ...;" statement.
// Capture group 1 is the DROP statement, capture group 2 the CREATE statement
// (everything up to and including the first ';').
//
// MustCompile is used instead of discarding the error of regexp.Compile: a
// broken pattern now fails loudly at start-up rather than leaving reg nil and
// panicking on first use.
var reg = regexp.MustCompile(`(DROP TABLE IF EXISTS \x60.*?\x60;)\s(CREATE TABLE [\s\S][^;]*;)`)

// GetDDL reads the SQL dump at ddlPath and returns one {drop, create}
// statement pair per table found in it, in file order. It returns nil when the
// dump contains no matching pair. If the file cannot be read the process is
// terminated through log.Fatalf.
func GetDDL() (tableMetas [][2]string) {
	data, err := ioutil.ReadFile(ddlPath)
	if err != nil {
		// log.Fatalf exits the process, so nothing after it is reachable.
		log.Fatalf("read ddl fail: %s", err)
	}

	for _, res := range reg.FindAllStringSubmatch(string(data), -1) {
		// res[0] is the whole match; res[1] and res[2] are the two groups.
		tableMetas = append(tableMetas, [2]string{res[1], res[2]})
	}
	return tableMetas
}
gorm.io/plugin/dbresolver v1.5.3 13 | ) 14 | 15 | replace gorm.io/gen => ../ 16 | -------------------------------------------------------------------------------- /tests/query_test.go: -------------------------------------------------------------------------------- 1 | package tests_test 2 | 3 | import ( 4 | "testing" 5 | 6 | "gorm.io/gen/tests/.expect/dal_test/model" 7 | "gorm.io/gen/tests/.expect/dal_test/query" 8 | ) 9 | 10 | func TestQuery_Find(t *testing.T) { 11 | useOnce.Do(CRUDInit) 12 | 13 | u := query.User 14 | 15 | err := u.WithContext(ctx).Create(&model.User{ID: 100}) 16 | if err != nil { 17 | t.Errorf("create model fail: %s", err) 18 | } 19 | 20 | user, err := u.WithContext(ctx).Where(u.ID.Eq(100)).Take() 21 | if err != nil { 22 | t.Errorf("take model fail: %s", err) 23 | } 24 | if user.ID != 100 { 25 | t.Errorf("take model fail: %+v", user) 26 | } 27 | t.Logf("got model: %+v", user) 28 | } 29 | -------------------------------------------------------------------------------- /tests/tables.sql: -------------------------------------------------------------------------------- 1 | -- ------------------------------------------------------------- 2 | -- Database: gen 3 | -- Generation Time: 2022-08-29 11:37:29.2770 4 | -- ------------------------------------------------------------- 5 | 6 | 7 | /*!40101 SET @OLD_CHARACTER_SET_CLIENT=@@CHARACTER_SET_CLIENT */; 8 | /*!40101 SET @OLD_CHARACTER_SET_RESULTS=@@CHARACTER_SET_RESULTS */; 9 | /*!40101 SET @OLD_COLLATION_CONNECTION=@@COLLATION_CONNECTION */; 10 | /*!40101 SET NAMES utf8mb4 */; 11 | /*!40014 SET @OLD_UNIQUE_CHECKS=@@UNIQUE_CHECKS, UNIQUE_CHECKS=0 */; 12 | /*!40014 SET @OLD_FOREIGN_KEY_CHECKS=@@FOREIGN_KEY_CHECKS, FOREIGN_KEY_CHECKS=0 */; 13 | /*!40101 SET @OLD_SQL_MODE=@@SQL_MODE, SQL_MODE='NO_AUTO_VALUE_ON_ZERO' */; 14 | /*!40111 SET @OLD_SQL_NOTES=@@SQL_NOTES, SQL_NOTES=0 */; 15 | 16 | 17 | DROP TABLE IF EXISTS `banks`; 18 | CREATE TABLE `banks` ( 19 | `id` bigint(20) unsigned NOT NULL AUTO_INCREMENT, 
20 | `name` longtext, 21 | `address` longtext, 22 | `scale` bigint(20) DEFAULT NULL, 23 | PRIMARY KEY (`id`) 24 | ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4; 25 | 26 | DROP TABLE IF EXISTS `credit_cards`; 27 | CREATE TABLE `credit_cards` ( 28 | `id` bigint(20) unsigned NOT NULL AUTO_INCREMENT, 29 | `created_at` datetime(3) DEFAULT NULL, 30 | `updated_at` datetime(3) DEFAULT NULL, 31 | `deleted_at` datetime(3) DEFAULT NULL, 32 | `number` longtext, 33 | `customer_refer` bigint(20) unsigned DEFAULT NULL, 34 | `bank_id` bigint(20) unsigned DEFAULT NULL, 35 | PRIMARY KEY (`id`), 36 | KEY `idx_credit_cards_deleted_at` (`deleted_at`) 37 | ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4; 38 | 39 | DROP TABLE IF EXISTS `customers`; 40 | CREATE TABLE `customers` ( 41 | `id` bigint(20) unsigned NOT NULL AUTO_INCREMENT, 42 | `created_at` datetime(3) DEFAULT NULL, 43 | `updated_at` datetime(3) DEFAULT NULL, 44 | `deleted_at` datetime(3) DEFAULT NULL, 45 | `bank_id` bigint(20) unsigned DEFAULT NULL, 46 | PRIMARY KEY (`id`), 47 | KEY `idx_customers_deleted_at` (`deleted_at`) 48 | ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4; 49 | 50 | DROP TABLE IF EXISTS `people`; 51 | CREATE TABLE `people` ( 52 | `id` bigint(20) unsigned NOT NULL AUTO_INCREMENT, 53 | `name` varchar(255) DEFAULT NULL, 54 | `alias` varchar(255) DEFAULT NULL, 55 | `age` int(11) unsigned DEFAULT NULL, 56 | `flag` tinyint(1) DEFAULT NULL, 57 | `another_flag` tinyint(4) DEFAULT NULL, 58 | `commit` varchar(255) DEFAULT NULL, 59 | `First` tinyint(1) DEFAULT NULL, 60 | `bit` bit(1) DEFAULT NULL, 61 | `small` smallint(5) unsigned DEFAULT NULL, 62 | `deleted_at` datetime(3) DEFAULT NULL, 63 | `score` decimal(19,0) DEFAULT NULL, 64 | `number` int(11) DEFAULT NULL, 65 | `birth` datetime DEFAULT CURRENT_TIMESTAMP, 66 | `xmlHTTPRequest` varchar(255) DEFAULT ' ', 67 | `jStr` json DEFAULT NULL, 68 | `geo` geometry DEFAULT NULL, 69 | `mint` mediumint(9) DEFAULT NULL, 70 | `blank` varchar(64) DEFAULT ' ', 71 | `remark` text, 72 | `long_remark` 
longtext, 73 | PRIMARY KEY (`id`) 74 | ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4; 75 | 76 | DROP TABLE IF EXISTS `users`; 77 | CREATE TABLE `users` ( 78 | `id` bigint(20) unsigned NOT NULL AUTO_INCREMENT, 79 | `created_at` datetime(3) DEFAULT NULL, 80 | `name` varchar(255) DEFAULT NULL COMMENT 'oneline', 81 | `address` varchar(255) DEFAULT '', 82 | `register_time` datetime(3) DEFAULT NULL, 83 | `alive` tinyint(1) DEFAULT NULL COMMENT 'multiline\nline1\nline2', 84 | `company_id` bigint(20) unsigned DEFAULT '666', 85 | `private_url` varchar(255) DEFAULT 'https://a.b.c ', 86 | PRIMARY KEY (`id`), 87 | KEY `idx_name` (`name`) USING BTREE, 88 | KEY `idx_name_company_id` (`name`,`company_id`) 89 | ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4; 90 | 91 | 92 | 93 | /*!40101 SET SQL_MODE=@OLD_SQL_MODE */; 94 | /*!40014 SET FOREIGN_KEY_CHECKS=@OLD_FOREIGN_KEY_CHECKS */; 95 | /*!40014 SET UNIQUE_CHECKS=@OLD_UNIQUE_CHECKS */; 96 | /*!40101 SET CHARACTER_SET_CLIENT=@OLD_CHARACTER_SET_CLIENT */; 97 | /*!40101 SET CHARACTER_SET_RESULTS=@OLD_CHARACTER_SET_RESULTS */; 98 | /*!40101 SET COLLATION_CONNECTION=@OLD_COLLATION_CONNECTION */; 99 | /*!40111 SET SQL_NOTES=@OLD_SQL_NOTES */; 100 | -------------------------------------------------------------------------------- /tests/test.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash -e 2 | 3 | # dialects=("sqlite" "mysql" "postgres" "sqlserver") 4 | dialects=("mysql") 5 | 6 | if [[ $(pwd) == *"gen/tests"* ]]; then 7 | cd .. 8 | fi 9 | 10 | if [ -d tests ] 11 | then 12 | cd tests 13 | go get -t ./... 14 | go mod download 15 | go mod tidy 16 | cd .. 
17 | fi 18 | 19 | # SqlServer for Mac M1 20 | if [[ -z $GITHUB_ACTION ]]; then 21 | if [ -d tests ] 22 | then 23 | cd tests 24 | if [[ $(uname -a) == *" arm64" ]]; then 25 | MSSQL_IMAGE=mcr.microsoft.com/azure-sql-edge docker-compose start || true 26 | go install github.com/microsoft/go-sqlcmd/cmd/sqlcmd@latest || true 27 | SQLCMDPASSWORD=LoremIpsum86 sqlcmd -U sa -S localhost:9930 -Q "IF DB_ID('gen') IS NULL CREATE DATABASE gen" > /dev/null || true 28 | SQLCMDPASSWORD=LoremIpsum86 sqlcmd -U sa -S localhost:9930 -Q "IF SUSER_ID (N'gen') IS NULL CREATE LOGIN gen WITH PASSWORD = 'LoremIpsum86';" > /dev/null || true 29 | SQLCMDPASSWORD=LoremIpsum86 sqlcmd -U sa -S localhost:9930 -Q "IF USER_ID (N'gen') IS NULL CREATE USER gen FROM LOGIN gen; ALTER SERVER ROLE sysadmin ADD MEMBER [gen];" > /dev/null || true 30 | else 31 | docker-compose start 32 | fi 33 | cd .. 34 | fi 35 | fi 36 | 37 | for dialect in "${dialects[@]}" ; do 38 | if [ "$GORM_DIALECT" = "" ] || [ "$GORM_DIALECT" = "${dialect}" ] 39 | then 40 | echo "testing ${dialect}..." 41 | 42 | if [ "$GEN_VERBOSE" = "" ] 43 | then 44 | if [ -d tests ] 45 | then 46 | cd tests 47 | GORM_DIALECT=${dialect} go test -race -count=1 ./... 48 | cd .. 49 | fi 50 | else 51 | if [ -d tests ] 52 | then 53 | cd tests 54 | GORM_DIALECT=${dialect} go test -race -count=1 -v ./... 55 | cd .. 
56 | fi 57 | fi 58 | fi 59 | done 60 | -------------------------------------------------------------------------------- /tests/tests_test.go: -------------------------------------------------------------------------------- 1 | package tests_test 2 | 3 | import ( 4 | "log" 5 | "os" 6 | "path/filepath" 7 | 8 | "gorm.io/driver/mysql" 9 | "gorm.io/driver/sqlite" 10 | "gorm.io/gen" 11 | "gorm.io/gorm" 12 | "gorm.io/gorm/logger" 13 | ) 14 | 15 | const ( 16 | mysqlDSN = "gen:gen@tcp(localhost:9910)/gen?charset=utf8&parseTime=True&loc=Local" 17 | postgresDSN = "user=gen password=gen dbname=gen host=localhost port=9920 sslmode=disable TimeZone=Asia/Shanghai" 18 | sqlserverDSN = "sqlserver://gen:LoremIpsum86@localhost:9930?database=gen" 19 | ) 20 | 21 | var DB *gorm.DB 22 | 23 | func init() { 24 | log.Print("initing...") 25 | var err error 26 | if DB, err = OpenTestConnection(); err != nil { 27 | log.Printf("failed to connect database, got error %v", err) 28 | os.Exit(1) 29 | } else { 30 | sqlDB, err := DB.DB() 31 | if err != nil { 32 | log.Printf("failed to connect database, got error %v", err) 33 | os.Exit(1) 34 | } 35 | 36 | err = sqlDB.Ping() 37 | if err != nil { 38 | log.Printf("failed to ping sqlDB, got error %v", err) 39 | os.Exit(1) 40 | } 41 | 42 | // RunMigrations() 43 | if DB.Dialector.Name() == "sqlite" { 44 | DB.Exec("PRAGMA foreign_keys = ON") 45 | } 46 | } 47 | RunMigrations() 48 | 49 | var generators []*gen.Generator 50 | for dir, build := range generateCase { 51 | generators = append(generators, build(dir)) 52 | } 53 | RunGenerate(generators...) 
54 | } 55 | 56 | func OpenTestConnection() (db *gorm.DB, err error) { 57 | dbDSN := os.Getenv("GEN_DSN") 58 | switch os.Getenv("GORM_DIALECT") { 59 | case "mysql": 60 | log.Println("testing mysql...") 61 | if dbDSN == "" { 62 | dbDSN = mysqlDSN 63 | } 64 | db, err = gorm.Open(mysql.Open(dbDSN), &gorm.Config{}) 65 | default: 66 | log.Println("testing sqlite3...") 67 | db, err = gorm.Open(sqlite.Open(filepath.Join(os.TempDir(), "gorm.db")), &gorm.Config{}) 68 | } 69 | 70 | if err != nil { 71 | return 72 | } 73 | 74 | if debug := os.Getenv("DEBUG"); debug == "true" { 75 | db.Logger = db.Logger.LogMode(logger.Info) 76 | } else if debug == "false" { 77 | db.Logger = db.Logger.LogMode(logger.Silent) 78 | } 79 | 80 | return 81 | } 82 | 83 | func RunMigrations() { 84 | db := DB.Session(&gorm.Session{}) 85 | for _, meta := range GetDDL() { 86 | dropTable, createTable := meta[0], meta[1] 87 | if err := db.Exec(dropTable).Error; err != nil { 88 | log.Printf("drop table fail: %s", err) 89 | } 90 | if err := db.Exec(createTable).Error; err != nil { 91 | log.Printf("create table fail: %s", err) 92 | } 93 | } 94 | } 95 | 96 | func RunGenerate(gs ...*gen.Generator) { 97 | for _, g := range gs { 98 | g.Execute() 99 | } 100 | } 101 | -------------------------------------------------------------------------------- /tests/transaction_test.go: -------------------------------------------------------------------------------- 1 | package tests_test 2 | 3 | import ( 4 | "fmt" 5 | "testing" 6 | 7 | "gorm.io/gen/tests/.expect/dal_test_relation/model" 8 | "gorm.io/gen/tests/.expect/dal_test_relation/query" 9 | ) 10 | 11 | func TestQuery_Transaction_Relation(t *testing.T) { 12 | useOnce.Do(CRUDInit) 13 | 14 | t.Run("transaction has many", func(t *testing.T) { 15 | if err := query.Q.Transaction(func(tx *query.Query) error { 16 | c := tx.Customer 17 | customer := &model.Customer{ 18 | Bank: model.Bank{ 19 | Name: "bank1", 20 | Address: "bank-address1", 21 | Scale: 1, 22 | }, 23 | CreditCards: 
[]model.CreditCard{ 24 | {Number: "num1"}, 25 | {Number: "num2"}, 26 | }, 27 | } 28 | if err := c.WithContext(ctx).Create(customer); err != nil { 29 | return fmt.Errorf("create model fail: %s", err) 30 | } 31 | 32 | got, err := c.WithContext(ctx).Where(c.ID.Eq(customer.ID)). 33 | Preload(c.CreditCards). 34 | Preload(c.Bank). 35 | First() 36 | if err != nil { 37 | return fmt.Errorf("find model fail: %s", err) 38 | } 39 | if len(got.CreditCards) != 2 { 40 | return fmt.Errorf("replace model fail, expect %d, got %d", 1, len(got.CreditCards)) 41 | } 42 | 43 | if err := c.CreditCards.WithContext(ctx).Model(customer).Replace(&model.CreditCard{ 44 | Number: "num_replace", 45 | }); err != nil { 46 | return fmt.Errorf("replace model fail: %s", err) 47 | } 48 | 49 | got, err = c.WithContext(ctx).Where(c.ID.Eq(customer.ID)). 50 | Preload(c.CreditCards). 51 | Preload(c.Bank). 52 | First() 53 | if err != nil { 54 | return fmt.Errorf("find model fail: %s", err) 55 | } 56 | if len(got.CreditCards) != 1 { 57 | return fmt.Errorf("replace model fail, expect %d, got %d", 1, len(got.CreditCards)) 58 | } 59 | if got.CreditCards[0].Number != "num_replace" { 60 | return fmt.Errorf("replace model fail, expect %q, got %q", "num_replace", got.CreditCards[0].Number) 61 | } 62 | 63 | return nil 64 | }); err != nil { 65 | t.Errorf("transaction execute fail: %s", err) 66 | } 67 | }) 68 | 69 | t.Run("transaction has one", func(t *testing.T) { 70 | if err := query.Q.Transaction(func(tx *query.Query) error { 71 | c := tx.Customer 72 | customer := &model.Customer{ 73 | Bank: model.Bank{ 74 | Name: "bank1", 75 | Address: "bank-address1", 76 | Scale: 1, 77 | }, 78 | CreditCards: []model.CreditCard{ 79 | {Number: "num1"}, 80 | {Number: "num2"}, 81 | }, 82 | } 83 | 84 | if err := c.WithContext(ctx).Create(customer); err != nil { 85 | return fmt.Errorf("create model fail: %s", err) 86 | } 87 | if err := c.Bank.WithContext(ctx).Model(customer).Replace(&model.Bank{ 88 | Name: "bank-replace", 89 | Address: 
"bank-replace-address", 90 | Scale: 2, 91 | }); err != nil { 92 | return fmt.Errorf("replace model fail: %s", err) 93 | } 94 | 95 | got, err := c.WithContext(ctx).Where(c.ID.Eq(customer.ID)). 96 | Preload(c.CreditCards). 97 | Preload(c.Bank). 98 | First() 99 | if err != nil { 100 | return fmt.Errorf("find model fail: %s", err) 101 | } 102 | if got.Bank.Name != "bank-replace" { 103 | return fmt.Errorf("replace model fail, expect %q, got %q", "bank-replace", got.Bank.Name) 104 | } 105 | 106 | return nil 107 | }); err != nil { 108 | t.Errorf("transaction execute fail: %s", err) 109 | } 110 | }) 111 | } 112 | -------------------------------------------------------------------------------- /tools/gentool/README.ZH_CN.md: -------------------------------------------------------------------------------- 1 | # GenTool 2 | 3 | 将Gen作为二进制的方式进行安装 4 | 5 | 6 | 7 | ## 安装 8 | 9 | ```shell 10 | go install gorm.io/gen/tools/gentool@latest 11 | ``` 12 | 13 | ## 使用方式 14 | 15 | ```shell 16 | 17 | gentool -h 18 | 19 | Usage of gentool: 20 | -db string 21 | input mysql or postgres or sqlite or sqlserver. 
consult[https://gorm.io/docs/connecting_to_the_database.html] (default "mysql") 22 | -dsn string 23 | consult[https://gorm.io/docs/connecting_to_the_database.html] 24 | -fieldNullable 25 | generate with pointer when field is nullable 26 | -fieldCoverable 27 | generate with pointer when field has default value 28 | -fieldWithIndexTag 29 | generate field with gorm index tag 30 | -fieldWithTypeTag 31 | generate field with gorm column type tag 32 | -modelPkgName string 33 | generated model code's package name 34 | -outFile string 35 | query code file name, default: gen.go 36 | -outPath string 37 | specify a directory for output (default "./dao/query") 38 | -tables string 39 | enter the required data table or leave it blank 40 | -onlyModel 41 | only generate models (without query file) 42 | -withUnitTest 43 | generate unit test for query code 44 | -fieldSignable 45 | detect integer field's unsigned type, adjust generated data type 46 | 47 | ``` 48 | 49 | #### c 50 | default "" 51 | 可以指定配置文件gen.yml的路径。 52 | 用配置文件来代替命令行。 53 | 命令行是最高优先级。 54 | 55 | #### db 56 | 57 | 默认值:mysql 58 | 59 | 可以输入: mysql、 postgres、 sqlite 、 sqlserver 60 | 61 | 参考:https://gorm.io/docs/connecting_to_the_database.html 62 | 63 | #### dsn 64 | 65 | 你可以使用GORM所有的连接。 66 | 67 | 参考:https://gorm.io/docs/connecting_to_the_database.html 68 | 69 | #### fieldNullable 70 | 71 | 字段可为空时使用指针生成 72 | 73 | #### fieldCoverable 74 | 75 | 字段有默认值时使用指针生成 76 | 77 | #### fieldWithIndexTag 78 | 79 | 使用GORM索引标记生成字段 80 | 81 | #### fieldWithTypeTag 82 | 83 | 使用gorm列类型标记生成字段 84 | 85 | #### modelPkgName 86 | 87 | 默认值是数据表名称。 88 | 89 | 生成的model代码的包名称。 90 | 91 | #### outFile 92 | 93 | 默认为:gen.go 94 | 95 | 查询代码文件名。 96 | 97 | #### outPath 98 | 99 | 默认为:./dao/query 100 | 101 | 指定输出目录 102 | 103 | #### tables 104 | 105 | 值为 : 输入所需的数据表或将其留空 106 | 107 | eg : 108 | 109 | ​ --tables="orders" #orders 数据表 110 | 111 | ​ --tables="orders,users" #orders 数据表和 users数据表 112 | 113 | ​ --tables="" # 数据库中所有的数据表 114 | 115 | 基于数据表生成对应的代码。 116 | 117 | #### 
withUnitTest 118 | 119 | 值为 : False / True 120 | 121 | 生成单元测试。 122 | 123 | #### fieldSignable 124 | 125 | 值为 : False / True 126 | 127 | 检测整型字段的无符号(unsigned)类型,并据此调整生成的数据类型 128 | 129 | 130 | ### 使用示例 131 | 132 | ```shell 133 | gentool -dsn "user:pwd@tcp(127.0.0.1:3306)/database?charset=utf8mb4&parseTime=True&loc=Local" -tables "orders,doctor" 134 | ``` -------------------------------------------------------------------------------- /tools/gentool/README.md: -------------------------------------------------------------------------------- 1 | # GenTool 2 | 3 | Install GEN as a binary tool 4 | 5 | ## install 6 | 7 | ```shell 8 | go install gorm.io/gen/tools/gentool@latest 9 | ``` 10 | 11 | ## usage 12 | 13 | ```shell 14 | 15 | gentool -h 16 | 17 | Usage of gentool: 18 | -db string 19 | input mysql|postgres|sqlite|sqlserver|clickhouse. consult[https://gorm.io/docs/connecting_to_the_database.html] (default "mysql") 20 | -dsn string 21 | consult[https://gorm.io/docs/connecting_to_the_database.html] 22 | -fieldNullable 23 | generate with pointer when field is nullable 24 | -fieldCoverable 25 | generate with pointer when field has default value 26 | -fieldWithIndexTag 27 | generate field with gorm index tag 28 | -fieldWithTypeTag 29 | generate field with gorm column type tag 30 | -modelPkgName string 31 | generated model code's package name 32 | -outFile string 33 | query code file name, default: gen.go 34 | -outPath string 35 | specify a directory for output (default "./dao/query") 36 | -tables string 37 | enter the required data table or leave it blank 38 | -onlyModel 39 | only generate models (without query file) 40 | -withUnitTest 41 | generate unit test for query code 42 | -fieldSignable 43 | detect integer field's unsigned type, adjust generated data type 44 | 45 | ``` 46 | #### c 47 | default "" 48 | Path to the gen.yml configuration file. 49 | Use a configuration file instead of command-line flags. 50 | Command-line flags take the highest priority. 51 | 52 | 53 | #### db 54 | 55 | default:mysql 56 | 57 | 
input mysql, postgres, sqlite, sqlserver or clickhouse. 58 | 59 | consult : https://gorm.io/docs/connecting_to_the_database.html 60 | 61 | #### dsn 62 | 63 | You can use any DSN supported by GORM. 64 | 65 | consult : https://gorm.io/docs/connecting_to_the_database.html 66 | 67 | #### fieldNullable 68 | 69 | generate with pointer when field is nullable 70 | 71 | #### fieldCoverable 72 | 73 | generate with pointer when field has default value 74 | 75 | #### fieldWithIndexTag 76 | 77 | generate field with gorm index tag 78 | 79 | #### fieldWithTypeTag 80 | 81 | generate field with gorm column type tag 82 | 83 | #### modelPkgName 84 | 85 | default: the table name. 86 | 87 | generated model code's package name. 88 | 89 | #### outFile 90 | 91 | query code file name, default: gen.go 92 | 93 | #### outPath 94 | 95 | specify a directory for output (default "./dao/query") 96 | 97 | #### tables 98 | 99 | Value : enter the required data table or leave it blank. 100 | 101 | eg : 102 | 103 | ​ --tables="orders" #orders table 104 | 105 | ​ --tables="orders,users" #orders table and users table 106 | 107 | ​ --tables="" # All data tables in the database. 108 | 109 | Generate code for the specified tables. 110 | 111 | #### withUnitTest 112 | 113 | Value : False / True 114 | 115 | Generate unit tests. 
116 | 117 | #### fieldSignable 118 | 119 | Value : False / True 120 | 121 | detect integer field's unsigned type, adjust generated data type 122 | 123 | 124 | 125 | ### example 126 | 127 | ```shell 128 | gentool -dsn "user:pwd@tcp(127.0.0.1:3306)/database?charset=utf8mb4&parseTime=True&loc=Local" -tables "orders,doctor" 129 | ``` -------------------------------------------------------------------------------- /tools/gentool/gen.yml: -------------------------------------------------------------------------------- 1 | version: "0.1" 2 | database: 3 | # consult[https://gorm.io/docs/connecting_to_the_database.html] 4 | dsn : "user:pass@tcp(127.0.0.1:3306)/dbname?charset=utf8mb4&parseTime=True&loc=Local" 5 | # input mysql, postgres, sqlite, sqlserver or clickhouse. consult[https://gorm.io/docs/connecting_to_the_database.html] 6 | db : "mysql" 7 | # enter the required data table or leave it blank. You can input: 8 | # tables : 9 | # - orders 10 | # - users 11 | # - goods 12 | tables : 13 | # only generate models (without query file) 14 | onlyModel : false 15 | # specify a directory for output 16 | outPath : "./dao/query" 17 | # query code file name, default: gen.go 18 | outFile : "" 19 | # generate unit test for query code 20 | withUnitTest : false 21 | # generated model code's package name 22 | modelPkgName : "" 23 | # generate with pointer when field is nullable 24 | fieldNullable : false 25 | # generate with pointer when field has default value 26 | fieldCoverable : false 27 | # generate field with gorm index tag 28 | fieldWithIndexTag : false 29 | # generate field with gorm column type tag 30 | fieldWithTypeTag : false 31 | # detect integer field's unsigned type, adjust generated data type 32 | fieldSignable : false 33 | -------------------------------------------------------------------------------- /tools/gentool/go.mod: -------------------------------------------------------------------------------- 1 | module gorm.io/gen/tools/gentool 2 | 3 | go 1.19 4 | 5 | require ( 6 
| gopkg.in/yaml.v3 v3.0.1 7 | gorm.io/driver/clickhouse v0.6.0 8 | gorm.io/driver/mysql v1.5.6 9 | gorm.io/driver/postgres v1.5.7 10 | gorm.io/driver/sqlite v1.5.5 11 | gorm.io/driver/sqlserver v1.5.3 12 | gorm.io/gen v0.3.26 13 | gorm.io/gorm v1.25.9 14 | ) 15 | 16 | require ( 17 | github.com/ClickHouse/ch-go v0.58.2 // indirect 18 | github.com/ClickHouse/clickhouse-go/v2 v2.15.0 // indirect 19 | github.com/andybalholm/brotli v1.0.6 // indirect 20 | github.com/go-faster/city v1.0.1 // indirect 21 | github.com/go-faster/errors v0.6.1 // indirect 22 | github.com/go-sql-driver/mysql v1.7.0 // indirect 23 | github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 // indirect 24 | github.com/golang-sql/sqlexp v0.1.0 // indirect 25 | github.com/google/uuid v1.3.1 // indirect 26 | github.com/hashicorp/go-version v1.6.0 // indirect 27 | github.com/jackc/pgpassfile v1.0.0 // indirect 28 | github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect 29 | github.com/jackc/pgx/v5 v5.4.3 // indirect 30 | github.com/jinzhu/inflection v1.0.0 // indirect 31 | github.com/jinzhu/now v1.1.5 // indirect 32 | github.com/klauspost/compress v1.16.7 // indirect 33 | github.com/mattn/go-sqlite3 v1.14.17 // indirect 34 | github.com/microsoft/go-mssqldb v1.6.0 // indirect 35 | github.com/paulmach/orb v0.10.0 // indirect 36 | github.com/pierrec/lz4/v4 v4.1.18 // indirect 37 | github.com/pkg/errors v0.9.1 // indirect 38 | github.com/segmentio/asm v1.2.0 // indirect 39 | github.com/shopspring/decimal v1.3.1 // indirect 40 | go.opentelemetry.io/otel v1.19.0 // indirect 41 | go.opentelemetry.io/otel/trace v1.19.0 // indirect 42 | golang.org/x/crypto v0.14.0 // indirect 43 | golang.org/x/mod v0.14.0 // indirect 44 | golang.org/x/sys v0.13.0 // indirect 45 | golang.org/x/text v0.13.0 // indirect 46 | golang.org/x/tools v0.17.0 // indirect 47 | gorm.io/datatypes v1.1.1-0.20230130040222-c43177d3cf8c // indirect 48 | gorm.io/hints v1.1.0 // indirect 49 | 
gorm.io/plugin/dbresolver v1.5.0 // indirect 50 | ) 51 | --------------------------------------------------------------------------------