├── .github ├── dependabot.yaml └── workflows │ └── ci.yaml ├── .gitignore ├── .golangci.yaml ├── LICENSE ├── README.md ├── clause.go ├── clause_exporter_test.go ├── clause_test.go ├── cmd └── genorm │ ├── generator │ ├── codegen │ │ ├── codegen.go │ │ ├── column.go │ │ ├── column_test.go │ │ ├── directory.go │ │ ├── directory_test.go │ │ ├── joined_table.go │ │ ├── lib_type.go │ │ └── table.go │ ├── convert │ │ ├── convert.go │ │ └── convert_test.go │ ├── generator.go │ ├── parser │ │ ├── parser.go │ │ └── parser_test.go │ └── types │ │ └── types.go │ ├── main.go │ └── main_test.go ├── column.go ├── column_mock_test.go ├── config.go ├── context.go ├── db.go ├── delete.go ├── delete_exporter_test.go ├── delete_test.go ├── errors.go ├── expr.go ├── expr_exporter_test.go ├── expr_mock_test.go ├── find.go ├── find_expoter_test.go ├── find_test.go ├── function.go ├── function_test.go ├── go.mod ├── go.sum ├── insert.go ├── insert_exporter_test.go ├── insert_test.go ├── mock ├── column.go └── expr.go ├── operator.go ├── operator_test.go ├── pluck.go ├── pluck_expoter_test.go ├── pluck_test.go ├── relation ├── relation.go ├── relation_test.go └── table.go ├── select.go ├── select_exporter_test.go ├── select_test.go ├── table.go ├── table_mock_test.go ├── tuple.go ├── tuple_mock_test.go ├── type.go ├── type_mock_test.go ├── update.go ├── update_exprter_test.go └── update_test.go /.github/dependabot.yaml: -------------------------------------------------------------------------------- 1 | version: 2 2 | updates: 3 | - package-ecosystem: gomod 4 | directory: "/" 5 | schedule: 6 | interval: weekly 7 | day: saturday 8 | time: "00:00" 9 | timezone: Asia/Tokyo 10 | - package-ecosystem: github-actions 11 | directory: "/" 12 | schedule: 13 | interval: weekly 14 | day: saturday 15 | time: "00:00" 16 | timezone: Asia/Tokyo 17 | -------------------------------------------------------------------------------- /.github/workflows/ci.yaml: 
-------------------------------------------------------------------------------- 1 | name: CI 2 | 3 | on: 4 | push: 5 | branches: 6 | - 'main' 7 | pull_request: 8 | 9 | jobs: 10 | build: 11 | name: Build 12 | runs-on: ubuntu-latest 13 | steps: 14 | - uses: actions/checkout@v4 15 | - uses: actions/setup-go@v5 16 | with: 17 | go-version-file: go.mod 18 | - run: go build -o genorm ./cmd/genorm/ 19 | - uses: actions/upload-artifact@v4 20 | with: 21 | name: genorm 22 | path: ./genorm 23 | lint: 24 | name: Lint 25 | runs-on: ubuntu-latest 26 | steps: 27 | - uses: actions/checkout@v4 28 | # go generate用に、golangci-lintの前にGoのinstallをする 29 | - uses: actions/setup-go@v5 30 | with: 31 | go-version-file: go.mod 32 | - run: go generate ./... 33 | - name: golangci-lint 34 | uses: reviewdog/action-golangci-lint@v2.7 35 | with: 36 | go_version_file: go.mod 37 | reporter: github-pr-review 38 | github_token: ${{ secrets.GITHUB_TOKEN }} 39 | fail_on_error: true 40 | test: 41 | name: Test 42 | runs-on: ubuntu-latest 43 | steps: 44 | - uses: actions/checkout@v4 45 | - uses: actions/setup-go@v5 46 | with: 47 | go-version-file: go.mod 48 | - run: go generate ./... 49 | - run: go test ./... 
-v -coverprofile=./coverage.txt -race -vet=off 50 | - name: Upload coverage data 51 | uses: codecov/codecov-action@v5.4.0 52 | with: 53 | file: ./coverage.txt 54 | fail_ci_if_error: true 55 | token: ${{ secrets.CODECOV_TOKEN }} 56 | - uses: actions/upload-artifact@v4 57 | with: 58 | name: coverage.txt 59 | path: coverage.txt 60 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .vscode 2 | .idea 3 | 4 | # Binaries for programs and plugins 5 | *.exe 6 | *.exe~ 7 | *.dll 8 | *.so 9 | *.dylib 10 | 11 | # Test binary, built with `go test -c` 12 | *.test 13 | 14 | # Output of the go coverage tool, specifically when used with LiteIDE 15 | *.out 16 | 17 | # Go workspace file 18 | go.work 19 | 20 | # mock file 21 | mock/*_mock.go 22 | -------------------------------------------------------------------------------- /.golangci.yaml: -------------------------------------------------------------------------------- 1 | run: 2 | skip-dirs: 3 | - mock 4 | linters: 5 | enable: 6 | - govet 7 | - errcheck 8 | - staticcheck 9 | - unused 10 | - gosimple 11 | - structcheck 12 | - varcheck 13 | - ineffassign 14 | - typecheck 15 | - revive 16 | - gofmt 17 | - asasalint 18 | - asciicheck 19 | - bidichk 20 | - errname 21 | - errorlint 22 | - exhaustive 23 | - exportloopref 24 | - forcetypeassert 25 | - gocheckcompilerdirectives 26 | - gocritic 27 | - goheader 28 | - goimports 29 | - gosec 30 | - misspell 31 | - nakedret 32 | - nilerr 33 | - nosprintfhostport 34 | - sqlclosecheck 35 | - testpackage 36 | - unconvert 37 | - unparam 38 | - whitespace 39 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Shunsuke Wakamatsu 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a 
copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # GenORM 2 | 3 | [![MIT License](http://img.shields.io/badge/license-MIT-blue.svg?style=flat)](LICENSE) 4 | [![](https://pkg.go.dev/badge/github.com/mazrean/genorm)](https://pkg.go.dev/github.com/mazrean/genorm) 5 | [![](https://github.com/mazrean/genorm/workflows/CI/badge.svg)](https://github.com/mazrean/genorm/actions) 6 | 7 | SQL Builder to prevent SQL mistakes using the Golang generics 8 | 9 | #### document 10 | - [English](https://mazrean.github.io/genorm-docs/en/) 11 | - [日本語](https://mazrean.github.io/genorm-docs/ja/) 12 | 13 | ## Feature 14 | 15 | By mapping SQL expressions to appropriate golang types using generics, you can discover many SQL mistakes at the time of compilation that are not prevented by traditional Golang ORMs or query builders. 
16 | 17 | For example: 18 | 19 | * A compile error occurs when values of different Go types are used in SQL comparisons (`=`, etc.) or assigned to columns in UPDATE statements 20 | * A compile error occurs when using column names of unavailable tables 21 | 22 | It also supports many CRUD syntaxes in SQL. 23 | 24 | ## Example 25 | #### Example 1 26 | 27 | String column `users.name` can be compared to a `string` value, but comparing it to an `int` value will result in a compile error. 28 | 29 | ```go 30 | // correct 31 | userValues, err := genorm. 32 | Select(orm.User()). 33 | Where(genorm.EqLit(user.NameExpr, genorm.Wrap("name"))). 34 | GetAll(db) 35 | 36 | // compile error 37 | userValues, err := genorm. 38 | Select(orm.User()). 39 | Where(genorm.EqLit(user.NameExpr, genorm.Wrap(1))). 40 | GetAll(db) 41 | ``` 42 | 43 | #### Example 2 44 | 45 | You can use an `id` column from the `users` table in a `SELECT` statement that retrieves data from the `users` table, but using an `id` column from the `messages` table will result in a compile error. 46 | 47 | ```go 48 | // correct 49 | userValues, err := genorm. 50 | Select(orm.User()). 51 | Where(genorm.EqLit(user.IDExpr, uuid.New())). 52 | GetAll(db) 53 | 54 | // compile error 55 | userValues, err := genorm. 56 | Select(orm.User()). 57 | Where(genorm.EqLit(message.IDExpr, uuid.New())). 58 | GetAll(db) 59 | ``` 60 | 61 | ## Install 62 | 63 | GenORM uses the CLI to generate code. The `genorm` package is used to invoke queries. For this reason, both the CLI and the package must be installed. 64 | 65 | #### CLI 66 | 67 | ``` 68 | go install github.com/mazrean/genorm/cmd/genorm@v1.0.0 69 | ``` 70 | 71 | #### Package 72 | 73 | ``` 74 | go get -u github.com/mazrean/genorm 75 | ``` 76 | 77 | ### Configuration 78 | 79 | #### Example 80 | 81 | The `users` table can join the `messages` table. 
82 | 83 | ```go 84 | import "github.com/mazrean/genorm" 85 | 86 | type User struct { 87 | // Column Information 88 | Message genorm.Ref[Message] 89 | } 90 | 91 | func (*User) TableName() string { 92 | return "users" 93 | } 94 | 95 | type Message struct { 96 | // Column Information 97 | } 98 | 99 | func (*Message) TableName() string { 100 | return "messages" 101 | } 102 | ``` 103 | 104 | ## Usage 105 | ### Connecting to a Database 106 | ```go 107 | import ( 108 | "database/sql" 109 | _ "github.com/go-sql-driver/mysql" 110 | ) 111 | 112 | db, err := sql.Open("mysql", "user:pass@tcp(host:port)/database?parseTime=true&loc=Asia%2FTokyo&charset=utf8mb4") 113 | ``` 114 | 115 | ### Insert 116 | ```go 117 | // INSERT INTO users (id, name, created_at) VALUES ({{uuid.New()}}, "name1", {{time.Now()}}), ({{uuid.New()}}, "name2", {{time.Now()}}) 118 | affectedRows, err := genorm. 119 | Insert(orm.User()). 120 | Values(&orm.UserTable{ 121 | ID: uuid.New(), 122 | Name: genorm.Wrap("name1"), 123 | CreatedAt: genorm.Wrap(time.Now()), 124 | }, &orm.UserTable{ 125 | ID: uuid.New(), 126 | Name: genorm.Wrap("name2"), 127 | CreatedAt: genorm.Wrap(time.Now()), 128 | }). 129 | Do(db) 130 | ``` 131 | 132 | ### Select 133 | 134 | ```go 135 | // SELECT id, name, created_at FROM users 136 | // userValues: []orm.UserTable 137 | userValues, err := genorm. 138 | Select(orm.User()). 139 | GetAll(db) 140 | 141 | // SELECT id, name, created_at FROM users LIMIT 1 142 | // userValue: orm.UserTable 143 | userValue, err := genorm. 144 | Select(orm.User()). 145 | Get(db) 146 | 147 | // SELECT id FROM users 148 | // userIDs: []uuid.UUID 149 | userIDs, err := genorm. 150 | Pluck(orm.User(), user.IDExpr). 151 | GetAll(db) 152 | 153 | // SELECT COUNT(id) AS result FROM users LIMIT 1 154 | // userNum: int64 155 | userNum, err := genorm. 156 | Pluck(orm.User(), genorm.Count(user.IDExpr, false)). 
157 | Get(db) 158 | ``` 159 | 160 | ### Update 161 | ```go 162 | // UPDATE users SET name="name" 163 | affectedRows, err = genorm. 164 | Update(orm.User()). 165 | Set( 166 | genorm.AssignLit(user.Name, genorm.Wrap("name")), 167 | ). 168 | Do(db) 169 | ``` 170 | 171 | 172 | ### Delete 173 | ```go 174 | // DELETE FROM users 175 | affectedRows, err = genorm. 176 | Delete(orm.User()). 177 | Do(db) 178 | ``` 179 | 180 | ### Join 181 | #### Select 182 | ```go 183 | // SELECT users.name, messages.content FROM users INNER JOIN messages ON users.id = messages.user_id 184 | // messageUserValues: []orm.MessageUserTable 185 | userID := orm.MessageUserParseExpr(user.ID) 186 | userName := orm.MessageUserParse(user.Name) 187 | messageUserID := orm.MessageUserParseExpr(message.UserID) 188 | messageContent := orm.MessageUserParse(message.Content) 189 | messageUserValues, err := genorm. 190 | Select(orm.User(). 191 | Message().Join(genorm.Eq(userID, messageUserID))). 192 | Fields(userName, messageContent). 193 | GetAll(db) 194 | ``` 195 | 196 | #### Update 197 | ```go 198 | // UPDATE users INNER JOIN messages ON users.id = messages.user_id SET content="hello world" 199 | userID := orm.MessageUserParseExpr(user.ID) 200 | messageUserID := orm.MessageUserParseExpr(message.UserID) 201 | messageContent := orm.MessageUserParse(message.Content) 202 | affectedRows, err := genorm. 203 | Update(orm.User(). 204 | Message().Join(genorm.Eq(userID, messageUserID))). 205 | Set(genorm.AssignLit(messageContent, genorm.Wrap("hello world"))). 206 | Do(db) 207 | ``` 208 | 209 | ### Transaction 210 | ```go 211 | tx, err := db.Begin() 212 | if err != nil { 213 | log.Fatal(err) 214 | } 215 | 216 | _, err = genorm. 217 | Insert(orm.User()). 218 | Values(&orm.UserTable{ 219 | ID: uuid.New(), 220 | Name: genorm.Wrap("name1"), 221 | CreatedAt: genorm.Wrap(time.Now()), 222 | }, &orm.UserTable{ 223 | ID: uuid.New(), 224 | Name: genorm.Wrap("name2"), 225 | CreatedAt: genorm.Wrap(time.Now()), 226 | }). 
227 | Do(tx) 228 | if err != nil { 229 | _ = tx.Rollback() 230 | log.Fatal(err) 231 | } 232 | 233 | err = tx.Commit() 234 | if err != nil { 235 | log.Fatal(err) 236 | } 237 | ``` 238 | 239 | ### Context 240 | 241 | ```go 242 | // SELECT id, name, created_at FROM users 243 | // userValues: []orm.UserTable 244 | userValues, err := genorm. 245 | Select(orm.User()). 246 | GetAllCtx(context.Background(), db) 247 | ``` 248 | 249 | ```go 250 | // INSERT INTO users (id, name, created_at) VALUES ({{uuid.New()}}, "name", {{time.Now()}}) 251 | affectedRows, err := genorm. 252 | Insert(orm.User()). 253 | Values(&orm.UserTable{ 254 | ID: uuid.New(), 255 | Name: genorm.Wrap("name"), 256 | CreatedAt: genorm.Wrap(time.Now()), 257 | }). 258 | DoCtx(context.Background(), db) 259 | ``` 260 | 261 | -------------------------------------------------------------------------------- /clause.go: -------------------------------------------------------------------------------- 1 | package genorm 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | "strings" 7 | ) 8 | 9 | type whereConditionClause[T Table] struct { 10 | condition TypedTableExpr[T, WrappedPrimitive[bool]] 11 | } 12 | 13 | func (c *whereConditionClause[T]) set(condition TypedTableExpr[T, WrappedPrimitive[bool]]) error { 14 | if c.condition != nil { 15 | return errors.New("where conditions already set") 16 | } 17 | if condition == nil { 18 | return errors.New("empty where condition") 19 | } 20 | 21 | c.condition = condition 22 | 23 | return nil 24 | } 25 | 26 | func (c *whereConditionClause[T]) exists() bool { 27 | return c.condition != nil 28 | } 29 | 30 | func (c *whereConditionClause[T]) getExpr() (string, []ExprType, error) { 31 | if c.condition == nil { 32 | return "", nil, errors.New("empty where condition") 33 | } 34 | 35 | query, args, errs := c.condition.Expr() 36 | if len(errs) != 0 { 37 | return "", nil, errs[0] 38 | } 39 | 40 | return query, args, nil 41 | } 42 | 43 | type groupClause[T Table] struct { 44 | exprs 
[]TableExpr[T] 45 | } 46 | 47 | func (c *groupClause[T]) set(exprs []TableExpr[T]) error { 48 | if len(c.exprs) != 0 { 49 | return errors.New("group by already set") 50 | } 51 | if len(exprs) == 0 { 52 | return errors.New("empty group by") 53 | } 54 | 55 | c.exprs = exprs 56 | 57 | return nil 58 | } 59 | 60 | func (c *groupClause[_]) exists() bool { 61 | return len(c.exprs) != 0 62 | } 63 | 64 | func (c *groupClause[T]) getExpr() (string, []ExprType, error) { 65 | if len(c.exprs) == 0 { 66 | return "", nil, errors.New("empty group by") 67 | } 68 | 69 | queries := make([]string, 0, len(c.exprs)) 70 | args := []ExprType{} 71 | for _, expr := range c.exprs { 72 | groupQuery, groupArgs, errs := expr.Expr() 73 | if len(errs) != 0 { 74 | return "", nil, errs[0] 75 | } 76 | 77 | queries = append(queries, groupQuery) 78 | args = append(args, groupArgs...) 79 | } 80 | 81 | return "GROUP BY " + strings.Join(queries, ", "), args, nil 82 | } 83 | 84 | type orderClause[T Table] struct { 85 | orderExprs []orderItem[T] 86 | } 87 | 88 | type orderItem[T Table] struct { 89 | expr TableExpr[T] 90 | direction OrderDirection 91 | } 92 | 93 | type OrderDirection uint8 94 | 95 | const ( 96 | Asc OrderDirection = iota + 1 97 | Desc 98 | ) 99 | 100 | func (c *orderClause[T]) add(item orderItem[T]) error { 101 | if item.expr == nil { 102 | return errors.New("empty order expr") 103 | } 104 | 105 | if item.direction != Asc && item.direction != Desc { 106 | return errors.New("invalid order direction") 107 | } 108 | 109 | c.orderExprs = append(c.orderExprs, item) 110 | 111 | return nil 112 | } 113 | 114 | func (c *orderClause[T]) exists() bool { 115 | return len(c.orderExprs) != 0 116 | } 117 | 118 | func (c *orderClause[T]) getExpr() (string, []ExprType, error) { 119 | if len(c.orderExprs) == 0 { 120 | return "", nil, errors.New("empty order by") 121 | } 122 | 123 | args := []ExprType{} 124 | orderQueries := make([]string, 0, len(c.orderExprs)) 125 | for _, orderItem := range c.orderExprs { 
126 | orderQuery, orderArgs, errs := orderItem.expr.Expr() 127 | if len(errs) != 0 { 128 | return "", nil, errs[0] 129 | } 130 | 131 | var directionQuery string 132 | switch orderItem.direction { 133 | case Asc: 134 | directionQuery = "ASC" 135 | case Desc: 136 | directionQuery = "DESC" 137 | default: 138 | return "", nil, fmt.Errorf("invalid order direction: %d", orderItem.direction) 139 | } 140 | 141 | orderQueries = append(orderQueries, fmt.Sprintf("%s %s", orderQuery, directionQuery)) 142 | args = append(args, orderArgs...) 143 | } 144 | 145 | return fmt.Sprintf("ORDER BY %s", strings.Join(orderQueries, ", ")), args, nil 146 | } 147 | 148 | type limitClause struct { 149 | limit uint64 150 | } 151 | 152 | func (l *limitClause) set(limit uint64) error { 153 | if l.limit != 0 { 154 | return errors.New("limit already set") 155 | } 156 | if limit == 0 { 157 | return errors.New("invalid limit") 158 | } 159 | 160 | l.limit = limit 161 | 162 | return nil 163 | } 164 | 165 | func (l *limitClause) exists() bool { 166 | return l.limit != 0 167 | } 168 | 169 | func (l *limitClause) getExpr() (string, []ExprType, error) { 170 | if l.limit == 0 { 171 | return "", nil, errors.New("empty limit") 172 | } 173 | 174 | return fmt.Sprintf("LIMIT %d", l.limit), nil, nil 175 | } 176 | 177 | type offsetClause struct { 178 | offset uint64 179 | } 180 | 181 | func (o *offsetClause) set(offset uint64) error { 182 | if o.offset != 0 { 183 | return errors.New("offset already set") 184 | } 185 | if offset == 0 { 186 | return errors.New("invalid offset") 187 | } 188 | 189 | o.offset = offset 190 | 191 | return nil 192 | } 193 | 194 | func (o *offsetClause) exists() bool { 195 | return o.offset != 0 196 | } 197 | 198 | func (o *offsetClause) getExpr() (string, []ExprType, error) { 199 | if o.offset == 0 { 200 | return "", nil, errors.New("empty offset") 201 | } 202 | 203 | return fmt.Sprintf("OFFSET %d", o.offset), nil, nil 204 | } 205 | 206 | type lockClause struct { 207 | lockType LockType 
208 | } 209 | 210 | type LockType uint8 211 | 212 | const ( 213 | none LockType = iota 214 | ForUpdate 215 | ForShare 216 | ) 217 | 218 | func (l *lockClause) set(lockType LockType) error { 219 | if l.lockType != none { 220 | return errors.New("lock type already set") 221 | } 222 | if lockType != ForUpdate && lockType != ForShare { 223 | return errors.New("invalid lock type") 224 | } 225 | 226 | l.lockType = lockType 227 | 228 | return nil 229 | } 230 | 231 | func (l *lockClause) exists() bool { 232 | return l.lockType != none 233 | } 234 | 235 | func (l *lockClause) getExpr() (string, []ExprType, error) { 236 | switch l.lockType { 237 | case ForUpdate: 238 | return "FOR UPDATE", nil, nil 239 | case ForShare: 240 | return "FOR SHARE", nil, nil 241 | case none: 242 | return "", nil, nil 243 | } 244 | 245 | return "", nil, errors.New("invalid lock type") 246 | } 247 | -------------------------------------------------------------------------------- /clause_exporter_test.go: -------------------------------------------------------------------------------- 1 | package genorm 2 | 3 | //nolint:revive 4 | func NewWhereConditionClause[T Table](condition TypedTableExpr[T, WrappedPrimitive[bool]]) *whereConditionClause[T] { 5 | return &whereConditionClause[T]{ 6 | condition: condition, 7 | } 8 | } 9 | 10 | func (c *whereConditionClause[T]) GetCondition() TypedTableExpr[T, WrappedPrimitive[bool]] { 11 | return c.condition 12 | } 13 | 14 | func (c *whereConditionClause[T]) Set(condition TypedTableExpr[T, WrappedPrimitive[bool]]) error { 15 | return c.set(condition) 16 | } 17 | 18 | func (c *whereConditionClause[T]) Exists() bool { 19 | return c.exists() 20 | } 21 | 22 | func (c *whereConditionClause[T]) GetExpr() (string, []ExprType, error) { 23 | return c.getExpr() 24 | } 25 | 26 | //nolint:revive 27 | func NewGroupClause[T Table](exprs []TableExpr[T]) *groupClause[T] { 28 | return &groupClause[T]{ 29 | exprs: exprs, 30 | } 31 | } 32 | 33 | func (c *groupClause[T]) 
GetCondition() []TableExpr[T] { 34 | return c.exprs 35 | } 36 | 37 | func (c *groupClause[T]) Set(exprs []TableExpr[T]) error { 38 | return c.set(exprs) 39 | } 40 | 41 | func (c *groupClause[T]) Exists() bool { 42 | return c.exists() 43 | } 44 | 45 | func (c *groupClause[T]) GetExpr() (string, []ExprType, error) { 46 | return c.getExpr() 47 | } 48 | 49 | type OrderItem[T Table] orderItem[T] 50 | 51 | func NewOrderItem[T Table](expr TableExpr[T], direction OrderDirection) OrderItem[T] { 52 | return OrderItem[T]{ 53 | expr: expr, 54 | direction: direction, 55 | } 56 | } 57 | 58 | func (c *OrderItem[T]) Value() (TableExpr[T], OrderDirection) { 59 | return c.expr, c.direction 60 | } 61 | 62 | //nolint:revive 63 | func NewOrderClause[T Table](items []OrderItem[T]) *orderClause[T] { 64 | orderItems := make([]orderItem[T], 0, len(items)) 65 | for _, item := range items { 66 | orderItems = append(orderItems, orderItem[T](item)) 67 | } 68 | 69 | return &orderClause[T]{ 70 | orderExprs: orderItems, 71 | } 72 | } 73 | 74 | func (c *orderClause[T]) GetItems() []OrderItem[T] { 75 | items := make([]OrderItem[T], 0, len(c.orderExprs)) 76 | for _, item := range c.orderExprs { 77 | items = append(items, OrderItem[T](item)) 78 | } 79 | 80 | return items 81 | } 82 | 83 | func (c *orderClause[T]) Add(item OrderItem[T]) error { 84 | return c.add(orderItem[T](item)) 85 | } 86 | 87 | func (c *orderClause[T]) Exists() bool { 88 | return c.exists() 89 | } 90 | 91 | func (c *orderClause[T]) GetExpr() (string, []ExprType, error) { 92 | return c.getExpr() 93 | } 94 | 95 | //nolint:revive 96 | func NewLimitClause(limit uint64) *limitClause { 97 | return &limitClause{ 98 | limit: limit, 99 | } 100 | } 101 | 102 | func (c *limitClause) GetLimit() uint64 { 103 | return c.limit 104 | } 105 | 106 | func (c *limitClause) Set(limit uint64) error { 107 | return c.set(limit) 108 | } 109 | 110 | func (c *limitClause) Exists() bool { 111 | return c.exists() 112 | } 113 | 114 | func (c *limitClause) 
GetExpr() (string, []ExprType, error) { 115 | return c.getExpr() 116 | } 117 | 118 | //nolint:revive 119 | func NewOffsetClause(offset uint64) *offsetClause { 120 | return &offsetClause{ 121 | offset: offset, 122 | } 123 | } 124 | 125 | func (c *offsetClause) GetOffset() uint64 { 126 | return c.offset 127 | } 128 | 129 | func (c *offsetClause) Set(offset uint64) error { 130 | return c.set(offset) 131 | } 132 | 133 | func (c *offsetClause) Exists() bool { 134 | return c.exists() 135 | } 136 | 137 | func (c *offsetClause) GetExpr() (string, []ExprType, error) { 138 | return c.getExpr() 139 | } 140 | 141 | //nolint:revive 142 | func NewLockClause(lockType LockType) *lockClause { 143 | return &lockClause{ 144 | lockType: lockType, 145 | } 146 | } 147 | 148 | func (c *lockClause) GetLockType() LockType { 149 | return c.lockType 150 | } 151 | 152 | func (c *lockClause) Set(lockType LockType) error { 153 | return c.set(lockType) 154 | } 155 | 156 | func (c *lockClause) Exists() bool { 157 | return c.exists() 158 | } 159 | 160 | func (c *lockClause) GetExpr() (string, []ExprType, error) { 161 | return c.getExpr() 162 | } 163 | -------------------------------------------------------------------------------- /cmd/genorm/generator/codegen/codegen.go: -------------------------------------------------------------------------------- 1 | package codegen 2 | 3 | import ( 4 | "fmt" 5 | "go/ast" 6 | "go/token" 7 | 8 | "github.com/mazrean/genorm/cmd/genorm/generator/types" 9 | ) 10 | 11 | const ( 12 | genormImport = `"github.com/mazrean/genorm"` 13 | genormRelationImport = `"github.com/mazrean/genorm/relation"` 14 | fmtImport = `"fmt"` 15 | ) 16 | 17 | var ( 18 | genormIdent = ast.NewIdent("genorm") 19 | genormRelationIdent = ast.NewIdent("relation") 20 | fmtIdent = ast.NewIdent("fmt") 21 | 22 | rootPackageIdent *ast.Ident 23 | ) 24 | 25 | func Codegen( 26 | packageName string, 27 | modulePath string, 28 | destinationDir string, 29 | baseAst *ast.File, 30 | tables []*types.Table, 31 
| joinedTables []*types.JoinedTable, 32 | ) error { 33 | rootPackageIdent = ast.NewIdent(packageName) 34 | 35 | dir, err := newDirectory(destinationDir, packageName, modulePath) 36 | if err != nil { 37 | return fmt.Errorf("failed to create directory: %w", err) 38 | } 39 | 40 | importDecls := codegenImportDecls(baseAst) 41 | 42 | codegenTables, codegenJoinedTables, err := convert(tables, joinedTables) 43 | if err != nil { 44 | return fmt.Errorf("failed to convert tables: %w", err) 45 | } 46 | 47 | err = codegenMain(dir, importDecls, codegenTables, codegenJoinedTables) 48 | if err != nil { 49 | return fmt.Errorf("failed to codegen main: %w", err) 50 | } 51 | 52 | for _, table := range codegenTables { 53 | err = codegenTable(dir, importDecls, table) 54 | if err != nil { 55 | return fmt.Errorf("failed to codegen table(%s): %w", table.name, err) 56 | } 57 | } 58 | 59 | return nil 60 | } 61 | 62 | func codegenImportDecls(baseAst *ast.File) []ast.Decl { 63 | importDecls := []ast.Decl{} 64 | 65 | haveImport := false 66 | haveGenorm := false 67 | haveGenormRelation := false 68 | haveFmt := false 69 | for _, decl := range baseAst.Decls { 70 | genDecl, ok := decl.(*ast.GenDecl) 71 | if !ok || genDecl == nil || genDecl.Tok != token.IMPORT || len(genDecl.Specs) == 0 { 72 | continue 73 | } 74 | 75 | haveImport = true 76 | 77 | for _, spec := range genDecl.Specs { 78 | importSpec, ok := spec.(*ast.ImportSpec) 79 | if !ok || importSpec == nil { 80 | continue 81 | } 82 | 83 | switch importSpec.Path.Value { 84 | case genormImport: 85 | if importSpec.Name != nil { 86 | genormIdent = importSpec.Name 87 | } 88 | haveGenorm = true 89 | case genormRelationImport: 90 | if importSpec.Name != nil { 91 | genormRelationIdent = importSpec.Name 92 | } 93 | haveGenormRelation = true 94 | case fmtImport: 95 | if importSpec.Name != nil { 96 | fmtIdent = importSpec.Name 97 | } 98 | haveFmt = true 99 | } 100 | } 101 | 102 | if !haveGenorm { 103 | genDecl.Specs = append(genDecl.Specs, 
&ast.ImportSpec{ 104 | Path: &ast.BasicLit{ 105 | Kind: token.STRING, 106 | Value: genormImport, 107 | }, 108 | }) 109 | 110 | haveGenorm = true 111 | } 112 | 113 | if !haveGenormRelation { 114 | genDecl.Specs = append(genDecl.Specs, &ast.ImportSpec{ 115 | Path: &ast.BasicLit{ 116 | Kind: token.STRING, 117 | Value: genormRelationImport, 118 | }, 119 | }) 120 | } 121 | 122 | if !haveFmt { 123 | genDecl.Specs = append(genDecl.Specs, &ast.ImportSpec{ 124 | Path: &ast.BasicLit{ 125 | Kind: token.STRING, 126 | Value: fmtImport, 127 | }, 128 | }) 129 | 130 | haveFmt = true 131 | } 132 | 133 | importDecls = append(importDecls, genDecl) 134 | } 135 | 136 | if !haveImport { 137 | importDecls = append(importDecls, &ast.GenDecl{ 138 | Tok: token.IMPORT, 139 | Specs: []ast.Spec{ 140 | &ast.ImportSpec{ 141 | Name: genormIdent, 142 | Path: &ast.BasicLit{ 143 | Kind: token.STRING, 144 | Value: genormImport, 145 | }, 146 | }, 147 | &ast.ImportSpec{ 148 | Name: genormRelationIdent, 149 | Path: &ast.BasicLit{ 150 | Kind: token.STRING, 151 | Value: genormRelationImport, 152 | }, 153 | }, 154 | &ast.ImportSpec{ 155 | Name: fmtIdent, 156 | Path: &ast.BasicLit{ 157 | Kind: token.STRING, 158 | Value: fmtImport, 159 | }, 160 | }, 161 | }, 162 | }) 163 | } 164 | 165 | return importDecls 166 | } 167 | 168 | type refTable struct { 169 | refTable *table 170 | joinedTable *joinedTable 171 | } 172 | 173 | type refJoinedTable struct { 174 | refTable *joinedTable 175 | joinedTable *joinedTable 176 | } 177 | 178 | func convert(tables []*types.Table, joinedTables []*types.JoinedTable) ([]*table, []*joinedTable, error) { 179 | tableMap := make(map[string]*table, len(tables)) 180 | codegenTables := make([]*table, 0, len(tables)) 181 | for _, table := range tables { 182 | codegenTable, err := newTable(table) 183 | if err != nil { 184 | return nil, nil, fmt.Errorf("failed to create table(%s): %w", table.StructName, err) 185 | } 186 | 187 | tableMap[table.StructName] = codegenTable 188 | codegenTables = 
append(codegenTables, codegenTable) 189 | } 190 | 191 | joinedTableMap := make(map[string]*joinedTable, len(joinedTables)) 192 | codegenJoinedTables := make([]*joinedTable, 0, len(joinedTables)) 193 | for _, joinedTable := range joinedTables { 194 | codegenJoinedTable := newJoinedTable(joinedTable) 195 | 196 | joinedTableMap[joinedTableName(joinedTable)] = codegenJoinedTable 197 | codegenJoinedTables = append(codegenJoinedTables, codegenJoinedTable) 198 | } 199 | 200 | for _, typesTable := range tables { 201 | codegenTable := tableMap[typesTable.StructName] 202 | 203 | refTables := make([]*refTable, 0, len(typesTable.RefTables)) 204 | for _, typeRefTable := range typesTable.RefTables { 205 | refTables = append(refTables, &refTable{ 206 | refTable: tableMap[typeRefTable.Table.StructName], 207 | joinedTable: joinedTableMap[joinedTableName(typeRefTable.JoinedTable)], 208 | }) 209 | } 210 | codegenTable.refTables = refTables 211 | 212 | refJoinedTables := make([]*refJoinedTable, 0, len(typesTable.RefJoinedTables)) 213 | for _, typeRefJoinedTable := range typesTable.RefJoinedTables { 214 | refJoinedTables = append(refJoinedTables, &refJoinedTable{ 215 | refTable: joinedTableMap[joinedTableName(typeRefJoinedTable.Table)], 216 | joinedTable: joinedTableMap[joinedTableName(typeRefJoinedTable.JoinedTable)], 217 | }) 218 | } 219 | codegenTable.refJoinedTables = refJoinedTables 220 | } 221 | 222 | for _, typesJoinedTable := range joinedTables { 223 | codegenJoinedTable := joinedTableMap[joinedTableName(typesJoinedTable)] 224 | 225 | tables := make([]*table, 0, len(typesJoinedTable.Tables)) 226 | for _, typeTable := range typesJoinedTable.Tables { 227 | tables = append(tables, tableMap[typeTable.StructName]) 228 | } 229 | codegenJoinedTable.tables = tables 230 | 231 | refTables := make([]*refTable, 0, len(typesJoinedTable.RefTables)) 232 | for _, typeRefTable := range typesJoinedTable.RefTables { 233 | refTables = append(refTables, &refTable{ 234 | refTable: 
tableMap[typeRefTable.Table.StructName], 235 | joinedTable: joinedTableMap[joinedTableName(typeRefTable.JoinedTable)], 236 | }) 237 | } 238 | codegenJoinedTable.refTables = refTables 239 | 240 | refJoinedTables := make([]*refJoinedTable, 0, len(typesJoinedTable.RefJoinedTables)) 241 | for _, typeRefJoinedTable := range typesJoinedTable.RefJoinedTables { 242 | refJoinedTables = append(refJoinedTables, &refJoinedTable{ 243 | refTable: joinedTableMap[joinedTableName(typeRefJoinedTable.Table)], 244 | joinedTable: joinedTableMap[joinedTableName(typeRefJoinedTable.JoinedTable)], 245 | }) 246 | } 247 | codegenJoinedTable.refJoinedTables = refJoinedTables 248 | } 249 | 250 | return codegenTables, codegenJoinedTables, nil 251 | } 252 | 253 | func codegenMain(dir *directory, importDecls []ast.Decl, tables []*table, joinedTables []*joinedTable) error { 254 | f, err := dir.addFile("genorm.go") 255 | if err != nil { 256 | return fmt.Errorf("failed to create file: %w", err) 257 | } 258 | defer f.Close() 259 | 260 | astFile := f.ast() 261 | 262 | astFile.Decls = append(astFile.Decls, importDecls...) 263 | 264 | for _, table := range tables { 265 | astFile.Decls = append(astFile.Decls, table.decl()...) 266 | } 267 | 268 | for _, joinedTable := range joinedTables { 269 | astFile.Decls = append(astFile.Decls, joinedTable.decl()...) 270 | } 271 | 272 | return nil 273 | } 274 | 275 | func codegenTable(dir *directory, importDecls []ast.Decl, table *table) error { 276 | rootModulePath := dir.modulePath 277 | 278 | dir, err := dir.addDirectory(table.snakeName(), table.lowerName()) 279 | if err != nil { 280 | return fmt.Errorf("failed to create directory: %w", err) 281 | } 282 | 283 | f, err := dir.addFile(fmt.Sprintf("%s.go", table.snakeName())) 284 | if err != nil { 285 | return fmt.Errorf("failed to create file: %w", err) 286 | } 287 | defer f.Close() 288 | 289 | astFile := f.ast() 290 | 291 | astFile.Decls = append(astFile.Decls, importDecls...) 
292 | 293 | astFile.Decls = append(astFile.Decls, &ast.GenDecl{ 294 | Tok: token.IMPORT, 295 | Specs: []ast.Spec{ 296 | &ast.ImportSpec{ 297 | Path: &ast.BasicLit{ 298 | Kind: token.STRING, 299 | Value: fmt.Sprintf(`"%s"`, rootModulePath), 300 | }, 301 | }, 302 | &ast.ImportSpec{ 303 | Path: &ast.BasicLit{ 304 | Kind: token.STRING, 305 | Value: genormImport, 306 | }, 307 | }, 308 | }, 309 | }) 310 | 311 | astFile.Decls = append(astFile.Decls, table.tablePackageDecls()...) 312 | 313 | return nil 314 | } 315 | -------------------------------------------------------------------------------- /cmd/genorm/generator/codegen/column.go: -------------------------------------------------------------------------------- 1 | package codegen 2 | 3 | import ( 4 | "fmt" 5 | "go/ast" 6 | "go/token" 7 | "strings" 8 | 9 | "github.com/mazrean/genorm/cmd/genorm/generator/types" 10 | ) 11 | 12 | type column struct { 13 | table *table 14 | columnName string 15 | fieldIdent *ast.Ident 16 | fieldType ast.Expr 17 | typeIdent *ast.Ident 18 | varIdent *ast.Ident 19 | tablePackageVarIdent *ast.Ident 20 | tablePackageExprIdent *ast.Ident 21 | recvIdent *ast.Ident 22 | } 23 | 24 | func newColumn(tbl *table, clmn *types.Column) *column { 25 | return &column{ 26 | table: tbl, 27 | columnName: clmn.Name, 28 | fieldIdent: ast.NewIdent(clmn.FieldName), 29 | fieldType: fieldTypeExpr(clmn.Type), 30 | typeIdent: ast.NewIdent(tbl.lowerName() + clmn.FieldName), 31 | varIdent: ast.NewIdent(tbl.name + clmn.FieldName), 32 | tablePackageVarIdent: ast.NewIdent(clmn.FieldName), 33 | tablePackageExprIdent: ast.NewIdent(clmn.FieldName + "Expr"), 34 | recvIdent: ast.NewIdent("c"), 35 | } 36 | } 37 | 38 | func (clmn *column) field() *ast.Field { 39 | return &ast.Field{ 40 | Names: []*ast.Ident{clmn.fieldIdent}, 41 | Type: clmn.fieldType, 42 | } 43 | } 44 | 45 | func fieldTypeExpr(columnTypeExpr ast.Expr) ast.Expr { 46 | columnTypeIdentExpr, ok := columnTypeExpr.(*ast.Ident) 47 | if ok { 48 | switch 
columnTypeIdentExpr.Name { 49 | case "bool", 50 | "int", "int8", "int16", "int32", "int64", 51 | "uint", "uint8", "uint16", "uint32", "uint64", 52 | "float32", "float64", 53 | "string": 54 | return wrappedPrimitive(columnTypeExpr) 55 | default: 56 | return columnTypeExpr 57 | } 58 | } 59 | 60 | columnTypeSelectorExpr, ok := columnTypeExpr.(*ast.SelectorExpr) 61 | if ok && columnTypeSelectorExpr != nil { 62 | identExpr, ok := columnTypeSelectorExpr.X.(*ast.Ident) 63 | if ok && 64 | identExpr != nil && 65 | identExpr.Name == "time" && 66 | columnTypeSelectorExpr.Sel.Name == "Time" { 67 | return wrappedPrimitive(columnTypeExpr) 68 | } 69 | } 70 | 71 | return columnTypeExpr 72 | } 73 | 74 | func (clmn *column) decls() []ast.Decl { 75 | return []ast.Decl{ 76 | clmn.structDecl(), 77 | clmn.varDecl(), 78 | clmn.exprDecl(), 79 | clmn.sqlColumnsDecl(), 80 | clmn.tableNameDecl(), 81 | clmn.columnNameDecl(), 82 | clmn.tableExprDecl(), 83 | clmn.typeExprDecl(), 84 | } 85 | } 86 | 87 | func (clmn *column) structDecl() ast.Decl { 88 | return &ast.GenDecl{ 89 | Tok: token.TYPE, 90 | Specs: []ast.Spec{ 91 | &ast.TypeSpec{ 92 | Name: clmn.typeIdent, 93 | Type: &ast.StructType{ 94 | Fields: &ast.FieldList{}, 95 | }, 96 | }, 97 | }, 98 | } 99 | } 100 | 101 | func (clmn *column) varDecl() ast.Decl { 102 | return &ast.GenDecl{ 103 | Tok: token.VAR, 104 | Specs: []ast.Spec{ 105 | &ast.ValueSpec{ 106 | Names: []*ast.Ident{clmn.varIdent}, 107 | Type: typedTableColumn(&ast.StarExpr{ 108 | X: clmn.table.structIdent, 109 | }, clmn.fieldType), 110 | Values: []ast.Expr{ 111 | &ast.CompositeLit{ 112 | Type: clmn.typeIdent, 113 | }, 114 | }, 115 | }, 116 | }, 117 | } 118 | } 119 | 120 | func (clmn *column) exprDecl() ast.Decl { 121 | return &ast.FuncDecl{ 122 | Recv: &ast.FieldList{ 123 | List: []*ast.Field{ 124 | { 125 | Names: []*ast.Ident{clmn.recvIdent}, 126 | Type: clmn.typeIdent, 127 | }, 128 | }, 129 | }, 130 | Name: exprExprIdent, 131 | Type: &ast.FuncType{ 132 | Results: &ast.FieldList{ 
133 | List: []*ast.Field{ 134 | { 135 | Type: ast.NewIdent("string"), 136 | }, 137 | { 138 | Type: &ast.ArrayType{ 139 | Elt: exprTypeInterfaceTypeExpr, 140 | }, 141 | }, 142 | { 143 | Type: &ast.ArrayType{ 144 | Elt: ast.NewIdent("error"), 145 | }, 146 | }, 147 | }, 148 | }, 149 | }, 150 | Body: &ast.BlockStmt{ 151 | List: []ast.Stmt{ 152 | &ast.ReturnStmt{ 153 | Results: []ast.Expr{ 154 | &ast.CallExpr{ 155 | Fun: &ast.SelectorExpr{ 156 | X: clmn.recvIdent, 157 | Sel: columnSQLColumnsIdent, 158 | }, 159 | }, 160 | ast.NewIdent("nil"), 161 | ast.NewIdent("nil"), 162 | }, 163 | }, 164 | }, 165 | }, 166 | } 167 | } 168 | 169 | func (clmn *column) sqlColumnsDecl() ast.Decl { 170 | return &ast.FuncDecl{ 171 | Recv: &ast.FieldList{ 172 | List: []*ast.Field{ 173 | { 174 | Names: []*ast.Ident{clmn.recvIdent}, 175 | Type: clmn.typeIdent, 176 | }, 177 | }, 178 | }, 179 | Name: columnSQLColumnsIdent, 180 | Type: &ast.FuncType{ 181 | Results: &ast.FieldList{ 182 | List: []*ast.Field{ 183 | { 184 | Type: ast.NewIdent("string"), 185 | }, 186 | }, 187 | }, 188 | }, 189 | Body: &ast.BlockStmt{ 190 | List: []ast.Stmt{ 191 | &ast.ReturnStmt{ 192 | Results: []ast.Expr{ 193 | &ast.CallExpr{ 194 | Fun: &ast.SelectorExpr{ 195 | X: fmtIdent, 196 | Sel: ast.NewIdent("Sprintf"), 197 | }, 198 | Args: []ast.Expr{ 199 | &ast.BasicLit{ 200 | Kind: token.STRING, 201 | Value: `"%s.%s"`, 202 | }, 203 | &ast.CallExpr{ 204 | Fun: &ast.SelectorExpr{ 205 | X: clmn.recvIdent, 206 | Sel: columnTableNameIdent, 207 | }, 208 | }, 209 | &ast.CallExpr{ 210 | Fun: &ast.SelectorExpr{ 211 | X: clmn.recvIdent, 212 | Sel: columnColumnNameIdent, 213 | }, 214 | }, 215 | }, 216 | }, 217 | }, 218 | }, 219 | }, 220 | }, 221 | } 222 | } 223 | 224 | func (clmn *column) tableNameDecl() ast.Decl { 225 | return &ast.FuncDecl{ 226 | Recv: &ast.FieldList{ 227 | List: []*ast.Field{ 228 | { 229 | Names: []*ast.Ident{clmn.recvIdent}, 230 | Type: clmn.typeIdent, 231 | }, 232 | }, 233 | }, 234 | Name: columnTableNameIdent, 235 
// escapeTag escapes tag so that the result can be wrapped in double quotes to
// form a valid interpreted Go string literal with the same value as tag.
//
// Backslashes, double quotes and newlines are replaced by their escape
// sequences; every other character is copied unchanged. The replacement is
// done in a single pass, so an inserted backslash is never escaped again.
func escapeTag(tag string) string {
	return strings.NewReplacer(
		`\`, `\\`,
		`"`, `\"`,
		"\n", `\n`,
	).Replace(tag)
}
&ast.ArrayType{ 334 | Elt: exprTypeInterfaceTypeExpr, 335 | }, 336 | }, 337 | { 338 | Type: &ast.ArrayType{ 339 | Elt: ast.NewIdent("error"), 340 | }, 341 | }, 342 | }, 343 | }, 344 | }, 345 | Body: &ast.BlockStmt{ 346 | List: []ast.Stmt{ 347 | &ast.ReturnStmt{ 348 | Results: []ast.Expr{ 349 | &ast.CallExpr{ 350 | Fun: &ast.SelectorExpr{ 351 | X: clmn.recvIdent, 352 | Sel: exprExprIdent, 353 | }, 354 | }, 355 | }, 356 | }, 357 | }, 358 | }, 359 | } 360 | } 361 | 362 | func (clmn *column) typeExprDecl() ast.Decl { 363 | return &ast.FuncDecl{ 364 | Recv: &ast.FieldList{ 365 | List: []*ast.Field{ 366 | { 367 | Names: []*ast.Ident{clmn.recvIdent}, 368 | Type: clmn.typeIdent, 369 | }, 370 | }, 371 | }, 372 | Name: typedExprTypedExprIdent, 373 | Type: &ast.FuncType{ 374 | Params: &ast.FieldList{ 375 | List: []*ast.Field{ 376 | { 377 | Type: clmn.fieldType, 378 | }, 379 | }, 380 | }, 381 | Results: &ast.FieldList{ 382 | List: []*ast.Field{ 383 | { 384 | Type: ast.NewIdent("string"), 385 | }, 386 | { 387 | Type: &ast.ArrayType{ 388 | Elt: exprTypeInterfaceTypeExpr, 389 | }, 390 | }, 391 | { 392 | Type: &ast.ArrayType{ 393 | Elt: ast.NewIdent("error"), 394 | }, 395 | }, 396 | }, 397 | }, 398 | }, 399 | Body: &ast.BlockStmt{ 400 | List: []ast.Stmt{ 401 | &ast.ReturnStmt{ 402 | Results: []ast.Expr{ 403 | &ast.CallExpr{ 404 | Fun: &ast.SelectorExpr{ 405 | X: clmn.recvIdent, 406 | Sel: exprExprIdent, 407 | }, 408 | }, 409 | }, 410 | }, 411 | }, 412 | }, 413 | } 414 | } 415 | 416 | func (clmn *column) tablePackageDecls() []ast.Decl { 417 | return []ast.Decl{ 418 | clmn.tablePackageVarDecl(), 419 | clmn.tablePackageExprDecl(), 420 | } 421 | } 422 | 423 | func (clmn *column) tablePackageVarDecl() ast.Decl { 424 | return &ast.GenDecl{ 425 | Tok: token.VAR, 426 | Specs: []ast.Spec{ 427 | &ast.ValueSpec{ 428 | Names: []*ast.Ident{clmn.tablePackageVarIdent}, 429 | Type: typedTableColumn(&ast.StarExpr{ 430 | X: &ast.SelectorExpr{ 431 | X: rootPackageIdent, 432 | Sel: 
clmn.table.structIdent, 433 | }, 434 | }, clmn.fieldType), 435 | Values: []ast.Expr{ 436 | &ast.SelectorExpr{ 437 | X: rootPackageIdent, 438 | Sel: clmn.varIdent, 439 | }, 440 | }, 441 | }, 442 | }, 443 | } 444 | } 445 | 446 | func (clmn *column) tablePackageExprDecl() ast.Decl { 447 | return &ast.GenDecl{ 448 | Tok: token.VAR, 449 | Specs: []ast.Spec{ 450 | &ast.ValueSpec{ 451 | Names: []*ast.Ident{clmn.tablePackageExprIdent}, 452 | Type: typedTableExpr(&ast.StarExpr{ 453 | X: &ast.SelectorExpr{ 454 | X: rootPackageIdent, 455 | Sel: clmn.table.structIdent, 456 | }, 457 | }, clmn.fieldType), 458 | Values: []ast.Expr{ 459 | &ast.SelectorExpr{ 460 | X: rootPackageIdent, 461 | Sel: clmn.varIdent, 462 | }, 463 | }, 464 | }, 465 | }, 466 | } 467 | } 468 | -------------------------------------------------------------------------------- /cmd/genorm/generator/codegen/column_test.go: -------------------------------------------------------------------------------- 1 | package codegen 2 | 3 | import "testing" 4 | 5 | func TestEscapeTag(t *testing.T) { 6 | tests := []struct { 7 | description string 8 | tag string 9 | expected string 10 | }{ 11 | { 12 | description: "normal", 13 | tag: "hoge", 14 | expected: "hoge", 15 | }, 16 | { 17 | description: "with double quote", 18 | tag: `"piyo"`, 19 | expected: `\"piyo\"`, 20 | }, 21 | { 22 | description: "with single quote", 23 | tag: `'piyo'`, 24 | expected: `'piyo'`, 25 | }, 26 | { 27 | description: "with back quote", 28 | tag: "`piyo`", 29 | expected: "`piyo`", 30 | }, 31 | } 32 | 33 | for _, test := range tests { 34 | t.Run(test.description, func(t *testing.T) { 35 | actual := escapeTag(test.tag) 36 | if actual != test.expected { 37 | t.Errorf("expected: %s, actual: %s", test.expected, actual) 38 | } 39 | }) 40 | } 41 | } 42 | -------------------------------------------------------------------------------- /cmd/genorm/generator/codegen/directory.go: -------------------------------------------------------------------------------- 1 
// directory describes one output directory of the generator.
type directory struct {
	// path is the location of the directory on the file system.
	path string
	// packageName is the package name given to files created in the directory.
	packageName string
	// modulePath is the Go import path that corresponds to the directory.
	modulePath string
}

// newDirectory creates path (including missing parents) on the file system
// and returns a directory describing it. An already existing path is not an
// error.
func newDirectory(path string, packageName string, modulePath string) (*directory, error) {
	err := os.MkdirAll(path, os.ModePerm)
	if err != nil {
		return nil, fmt.Errorf("failed to create destination directory: %w", err)
	}

	return &directory{
		path:        path,
		packageName: packageName,
		modulePath:  modulePath,
	}, nil
}

// addDirectory creates the sub directory name below d and returns it. The
// sub directory uses packageName as its package name and d's module path
// extended by name as its module path.
func (d *directory) addDirectory(name string, packageName string) (*directory, error) {
	// The module path is a Go import path and must stay slash separated on
	// every platform, while filepath.Join produces the OS separator.
	subModulePath := filepath.ToSlash(filepath.Join(d.modulePath, name))

	return newDirectory(filepath.Join(d.path, name), packageName, subModulePath)
}

// addFile creates (or truncates) the file name inside d and returns it with
// an empty AST that carries d's package name.
func (d *directory) addFile(name string) (*file, error) {
	return newFile(filepath.Join(d.path, name), d.packageName)
}

// file couples an output destination with the AST that is written to it.
type file struct {
	writer io.WriteCloser
	file   *ast.File
}

// newFile creates (or truncates) the file at path and pairs it with a new
// AST whose package clause is packageName.
func newFile(path string, packageName string) (*file, error) {
	f, err := os.Create(path)
	if err != nil {
		return nil, fmt.Errorf("failed to create file: %w", err)
	}

	astFile := &ast.File{
		Name: ast.NewIdent(packageName),
	}

	return &file{
		writer: f,
		file:   astFile,
	}, nil
}

// ast returns the AST of the file so that declarations can be appended to it.
func (f *file) ast() *ast.File {
	return f.file
}
return fmt.Errorf("failed to write file: %w", err) 86 | } 87 | 88 | return nil 89 | } 90 | -------------------------------------------------------------------------------- /cmd/genorm/generator/codegen/directory_test.go: -------------------------------------------------------------------------------- 1 | package codegen 2 | 3 | import ( 4 | "os" 5 | "path/filepath" 6 | "testing" 7 | 8 | "github.com/stretchr/testify/assert" 9 | ) 10 | 11 | func TestNewDestination(t *testing.T) { 12 | err := os.MkdirAll("./test", os.ModePerm) 13 | if err != nil { 14 | t.Fatalf("create test directory: %s", err) 15 | } 16 | defer func() { 17 | err := os.RemoveAll("./test") 18 | if err != nil { 19 | t.Errorf("failed to remove test directory: %s", err) 20 | } 21 | }() 22 | 23 | absPath, err := filepath.Abs("./test2") 24 | if err != nil { 25 | t.Fatalf("failed to get absolute path: %s", err) 26 | } 27 | 28 | tests := []struct { 29 | description string 30 | path string 31 | packageName string 32 | modulePath string 33 | exists bool 34 | err bool 35 | }{ 36 | { 37 | description: "normal destination -> success", 38 | path: "./test2", 39 | packageName: "test2", 40 | modulePath: "github/mazrean/genorm/test2", 41 | exists: true, 42 | }, 43 | { 44 | description: "non-existent destination -> success", 45 | path: "./test2", 46 | packageName: "test2", 47 | modulePath: "github/mazrean/genorm/test2", 48 | }, 49 | { 50 | description: "destination in directory -> success", 51 | path: "./test/test", 52 | packageName: "test", 53 | modulePath: "github/mazrean/genorm/test", 54 | exists: true, 55 | }, 56 | { 57 | description: "non-existent destination in directory -> success", 58 | path: "./test/test", 59 | packageName: "test", 60 | modulePath: "github/mazrean/genorm/test", 61 | }, 62 | { 63 | description: "destination(absolute path) -> success", 64 | path: absPath, 65 | packageName: "test2", 66 | modulePath: "github/mazrean/genorm/test2", 67 | exists: true, 68 | }, 69 | { 70 | description: "non-existent 
destination(absolute path) -> success", 71 | path: absPath, 72 | packageName: "test2", 73 | modulePath: "github/mazrean/genorm/test2", 74 | exists: true, 75 | }, 76 | } 77 | 78 | for _, test := range tests { 79 | t.Run(test.description, func(t *testing.T) { 80 | if test.exists { 81 | err := os.MkdirAll(test.path, os.ModePerm) 82 | if err != nil { 83 | t.Fatalf("failed to create test file: %s", err) 84 | } 85 | 86 | defer func() { 87 | err := os.RemoveAll(test.path) 88 | if err != nil { 89 | t.Errorf("failed to remove test file: %s", err) 90 | } 91 | }() 92 | } 93 | 94 | dir, err := newDirectory(test.path, test.packageName, test.modulePath) 95 | if err != nil { 96 | if !test.err { 97 | t.Fatalf("unexpected error: %s", err) 98 | } 99 | return 100 | } 101 | 102 | if test.err { 103 | t.Fatalf("expected error but got none") 104 | } 105 | 106 | assert.Equal(t, test.path, dir.path) 107 | assert.Equal(t, test.packageName, dir.packageName) 108 | assert.Equal(t, test.modulePath, dir.modulePath) 109 | assert.DirExists(t, dir.path) 110 | }) 111 | } 112 | } 113 | -------------------------------------------------------------------------------- /cmd/genorm/generator/codegen/lib_type.go: -------------------------------------------------------------------------------- 1 | package codegen 2 | 3 | import ( 4 | "go/ast" 5 | ) 6 | 7 | var ( 8 | columnInterfaceTypeExpr = &ast.SelectorExpr{ 9 | X: genormIdent, 10 | Sel: ast.NewIdent("Column"), 11 | } 12 | exprTypeInterfaceTypeExpr = &ast.SelectorExpr{ 13 | X: genormIdent, 14 | Sel: ast.NewIdent("ExprType"), 15 | } 16 | columnFieldExprTypeExpr = &ast.SelectorExpr{ 17 | X: genormIdent, 18 | Sel: ast.NewIdent("ColumnFieldExprType"), 19 | } 20 | tableTypeExpr = &ast.SelectorExpr{ 21 | X: genormIdent, 22 | Sel: ast.NewIdent("Table"), 23 | } 24 | basicTableTypeExpr = &ast.SelectorExpr{ 25 | X: genormIdent, 26 | Sel: ast.NewIdent("BasicTable"), 27 | } 28 | relationTypeExpr = &ast.SelectorExpr{ 29 | X: genormRelationIdent, 30 | Sel: 
ast.NewIdent("Relation"), 31 | } 32 | 33 | exprExprIdent = ast.NewIdent("Expr") 34 | tableExprTableExprIdent = ast.NewIdent("TableExpr") 35 | typedExprTypedExprIdent = ast.NewIdent("TypedExpr") 36 | 37 | tableColumnsIdent = ast.NewIdent("Columns") 38 | tableGetErrorsIdent = ast.NewIdent("GetErrors") 39 | tableAddErrorIdent = ast.NewIdent("AddError") 40 | tableColumnMapIdent = ast.NewIdent("ColumnMap") 41 | basicTableTableNameIdent = ast.NewIdent("TableName") 42 | joinedTableBaseTablesIdent = ast.NewIdent("BaseTables") 43 | joinedTableSetRelationIdent = ast.NewIdent("SetRelation") 44 | 45 | columnSQLColumnsIdent = ast.NewIdent("SQLColumnName") 46 | columnTableNameIdent = ast.NewIdent("TableName") 47 | columnColumnNameIdent = ast.NewIdent("ColumnName") 48 | 49 | relationJoinedTableNameIdent = ast.NewIdent("JoinedTableName") 50 | ) 51 | 52 | func wrappedPrimitive(primitive ast.Expr) ast.Expr { 53 | return &ast.IndexExpr{ 54 | X: &ast.SelectorExpr{ 55 | X: genormIdent, 56 | Sel: ast.NewIdent("WrappedPrimitive"), 57 | }, 58 | Index: primitive, 59 | } 60 | } 61 | 62 | func typedTableExpr(tableType ast.Expr, exprType ast.Expr) ast.Expr { 63 | return &ast.IndexListExpr{ 64 | X: &ast.SelectorExpr{ 65 | X: genormIdent, 66 | Sel: ast.NewIdent("TypedTableExpr"), 67 | }, 68 | Indices: []ast.Expr{ 69 | tableType, 70 | exprType, 71 | }, 72 | } 73 | } 74 | 75 | func typedTableColumn(tableType ast.Expr, exprType ast.Expr) ast.Expr { 76 | return &ast.IndexListExpr{ 77 | X: &ast.SelectorExpr{ 78 | X: genormIdent, 79 | Sel: ast.NewIdent("TypedTableColumns"), 80 | }, 81 | Indices: []ast.Expr{ 82 | tableType, 83 | exprType, 84 | }, 85 | } 86 | } 87 | 88 | func relationContext(baseTable ast.Expr, refTable ast.Expr, joinedTable ast.Expr) ast.Expr { 89 | return &ast.StarExpr{ 90 | X: &ast.IndexListExpr{ 91 | X: &ast.SelectorExpr{ 92 | X: genormRelationIdent, 93 | Sel: ast.NewIdent("RelationContext"), 94 | }, 95 | Indices: []ast.Expr{ 96 | baseTable, 97 | refTable, 98 | &ast.StarExpr{ 99 | 
X: joinedTable, 100 | }, 101 | joinedTable, 102 | }, 103 | }, 104 | } 105 | } 106 | 107 | func newRelationContext(baseTable ast.Expr, refTable ast.Expr, joinedTable ast.Expr) ast.Expr { 108 | return &ast.IndexListExpr{ 109 | X: &ast.SelectorExpr{ 110 | X: genormRelationIdent, 111 | Sel: ast.NewIdent("NewRelationContext"), 112 | }, 113 | Indices: []ast.Expr{ 114 | baseTable, 115 | refTable, 116 | &ast.StarExpr{ 117 | X: joinedTable, 118 | }, 119 | joinedTable, 120 | }, 121 | } 122 | } 123 | -------------------------------------------------------------------------------- /cmd/genorm/generator/codegen/table.go: -------------------------------------------------------------------------------- 1 | package codegen 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | "go/ast" 7 | "go/token" 8 | "strings" 9 | 10 | "github.com/mazrean/genorm/cmd/genorm/generator/types" 11 | ) 12 | 13 | type table struct { 14 | table *types.Table 15 | name string 16 | funcIdent *ast.Ident 17 | structIdent *ast.Ident 18 | recvIdent *ast.Ident 19 | methods []*method 20 | columns []*column 21 | refTables []*refTable 22 | refJoinedTables []*refJoinedTable 23 | } 24 | 25 | func newTable(tbl *types.Table) (*table, error) { 26 | codegenTable := &table{ 27 | table: tbl, 28 | name: tbl.StructName, 29 | funcIdent: ast.NewIdent(tbl.StructName), 30 | structIdent: ast.NewIdent(tbl.StructName + "Table"), 31 | recvIdent: ast.NewIdent("t"), 32 | } 33 | 34 | methods := make([]*method, 0, len(tbl.Methods)) 35 | for _, m := range tbl.Methods { 36 | mthd, err := newMethod(codegenTable, m) 37 | if err != nil { 38 | return nil, fmt.Errorf("failed to create method: %w", err) 39 | } 40 | 41 | methods = append(methods, mthd) 42 | } 43 | codegenTable.methods = methods 44 | 45 | columns := make([]*column, 0, len(tbl.Columns)) 46 | for _, c := range tbl.Columns { 47 | col := newColumn(codegenTable, c) 48 | 49 | columns = append(columns, col) 50 | } 51 | codegenTable.columns = columns 52 | 53 | return codegenTable, nil 54 | } 55 
| 56 | func (tbl *table) lowerName() string { 57 | return strings.ToLower(tbl.name[0:1]) + tbl.name[1:] 58 | } 59 | 60 | func (tbl *table) snakeName() string { 61 | snakeName := "" 62 | for i, c := range tbl.name { 63 | if i != 0 && c >= 'A' && c <= 'Z' { 64 | snakeName += "_" 65 | } 66 | 67 | snakeName += strings.ToLower(string(c)) 68 | } 69 | 70 | return snakeName 71 | } 72 | 73 | func (tbl *table) decl() []ast.Decl { 74 | tableDecls := []ast.Decl{} 75 | 76 | tableDecls = append(tableDecls, tbl.structDecl(), tbl.funcDecl()) 77 | 78 | for _, ref := range tbl.refTables { 79 | tableDecls = append(tableDecls, tbl.tableJoinDecl(ref)) 80 | } 81 | 82 | for _, ref := range tbl.refJoinedTables { 83 | tableDecls = append(tableDecls, tbl.joinedTableJoinDecl(ref)) 84 | } 85 | 86 | for _, method := range tbl.methods { 87 | tableDecls = append(tableDecls, method.Decl) 88 | } 89 | 90 | tableDecls = append(tableDecls, tbl.exprDecl(), tbl.columnsDecl(), tbl.columnMapDecl(), tbl.getErrorsDecl()) 91 | 92 | for _, column := range tbl.columns { 93 | tableDecls = append(tableDecls, column.decls()...) 
94 | } 95 | 96 | return tableDecls 97 | } 98 | 99 | func (tbl *table) structDecl() ast.Decl { 100 | fields := make([]*ast.Field, 0, len(tbl.columns)) 101 | for _, column := range tbl.columns { 102 | fields = append(fields, column.field()) 103 | } 104 | 105 | return &ast.GenDecl{ 106 | Tok: token.TYPE, 107 | Specs: []ast.Spec{ 108 | &ast.TypeSpec{ 109 | Name: tbl.structIdent, 110 | Type: &ast.StructType{ 111 | Fields: &ast.FieldList{ 112 | List: fields, 113 | }, 114 | }, 115 | }, 116 | }, 117 | } 118 | } 119 | 120 | func (tbl *table) funcDecl() ast.Decl { 121 | return &ast.FuncDecl{ 122 | Name: tbl.funcIdent, 123 | Type: &ast.FuncType{ 124 | Results: &ast.FieldList{ 125 | List: []*ast.Field{ 126 | { 127 | Type: &ast.StarExpr{ 128 | X: tbl.structIdent, 129 | }, 130 | }, 131 | }, 132 | }, 133 | }, 134 | Body: &ast.BlockStmt{ 135 | List: []ast.Stmt{ 136 | &ast.ReturnStmt{ 137 | Results: []ast.Expr{ 138 | &ast.UnaryExpr{ 139 | Op: token.AND, 140 | X: &ast.CompositeLit{ 141 | Type: tbl.structIdent, 142 | Elts: []ast.Expr{}, 143 | }, 144 | }, 145 | }, 146 | }, 147 | }, 148 | }, 149 | } 150 | } 151 | 152 | func (tbl *table) exprDecl() ast.Decl { 153 | return &ast.FuncDecl{ 154 | Recv: &ast.FieldList{ 155 | List: []*ast.Field{ 156 | { 157 | Names: []*ast.Ident{tbl.recvIdent}, 158 | Type: &ast.StarExpr{ 159 | X: tbl.structIdent, 160 | }, 161 | }, 162 | }, 163 | }, 164 | Name: exprExprIdent, 165 | Type: &ast.FuncType{ 166 | Results: &ast.FieldList{ 167 | List: []*ast.Field{ 168 | { 169 | Type: ast.NewIdent("string"), 170 | }, 171 | { 172 | Type: &ast.ArrayType{ 173 | Elt: exprTypeInterfaceTypeExpr, 174 | }, 175 | }, 176 | { 177 | Type: &ast.ArrayType{ 178 | Elt: ast.NewIdent("error"), 179 | }, 180 | }, 181 | }, 182 | }, 183 | }, 184 | Body: &ast.BlockStmt{ 185 | List: []ast.Stmt{ 186 | &ast.ReturnStmt{ 187 | Results: []ast.Expr{ 188 | &ast.CallExpr{ 189 | Fun: &ast.SelectorExpr{ 190 | X: tbl.recvIdent, 191 | Sel: basicTableTableNameIdent, 192 | }, 193 | }, 194 | 
ast.NewIdent("nil"), 195 | ast.NewIdent("nil"), 196 | }, 197 | }, 198 | }, 199 | }, 200 | } 201 | } 202 | 203 | func (tbl *table) columnsDecl() ast.Decl { 204 | columnExprs := make([]ast.Expr, 0, len(tbl.columns)) 205 | for _, column := range tbl.columns { 206 | columnExprs = append(columnExprs, column.varIdent) 207 | } 208 | 209 | return &ast.FuncDecl{ 210 | Recv: &ast.FieldList{ 211 | List: []*ast.Field{ 212 | { 213 | Names: []*ast.Ident{tbl.recvIdent}, 214 | Type: &ast.StarExpr{ 215 | X: tbl.structIdent, 216 | }, 217 | }, 218 | }, 219 | }, 220 | Name: tableColumnsIdent, 221 | Type: &ast.FuncType{ 222 | Results: &ast.FieldList{ 223 | List: []*ast.Field{ 224 | { 225 | Type: &ast.ArrayType{ 226 | Elt: columnInterfaceTypeExpr, 227 | }, 228 | }, 229 | }, 230 | }, 231 | }, 232 | Body: &ast.BlockStmt{ 233 | List: []ast.Stmt{ 234 | &ast.ReturnStmt{ 235 | Results: []ast.Expr{ 236 | &ast.CompositeLit{ 237 | Type: &ast.ArrayType{ 238 | Elt: columnInterfaceTypeExpr, 239 | }, 240 | Elts: columnExprs, 241 | }, 242 | }, 243 | }, 244 | }, 245 | }, 246 | } 247 | } 248 | 249 | func (tbl *table) columnMapDecl() ast.Decl { 250 | columnMapKeyValueExprs := make([]ast.Expr, 0, len(tbl.columns)) 251 | for _, column := range tbl.columns { 252 | columnMapKeyValueExprs = append(columnMapKeyValueExprs, &ast.KeyValueExpr{ 253 | Key: &ast.CallExpr{ 254 | Fun: &ast.SelectorExpr{ 255 | X: column.varIdent, 256 | Sel: ast.NewIdent("SQLColumnName"), 257 | }, 258 | }, 259 | Value: &ast.UnaryExpr{ 260 | Op: token.AND, 261 | X: &ast.SelectorExpr{ 262 | X: tbl.recvIdent, 263 | Sel: column.fieldIdent, 264 | }, 265 | }, 266 | }) 267 | } 268 | 269 | return &ast.FuncDecl{ 270 | Recv: &ast.FieldList{ 271 | List: []*ast.Field{ 272 | { 273 | Names: []*ast.Ident{tbl.recvIdent}, 274 | Type: &ast.StarExpr{ 275 | X: tbl.structIdent, 276 | }, 277 | }, 278 | }, 279 | }, 280 | Name: tableColumnMapIdent, 281 | Type: &ast.FuncType{ 282 | Results: &ast.FieldList{ 283 | List: []*ast.Field{ 284 | { 285 | Type: 
&ast.MapType{ 286 | Key: ast.NewIdent("string"), 287 | Value: columnFieldExprTypeExpr, 288 | }, 289 | }, 290 | }, 291 | }, 292 | }, 293 | Body: &ast.BlockStmt{ 294 | List: []ast.Stmt{ 295 | &ast.ReturnStmt{ 296 | Results: []ast.Expr{ 297 | &ast.CompositeLit{ 298 | Type: &ast.MapType{ 299 | Key: ast.NewIdent("string"), 300 | Value: columnFieldExprTypeExpr, 301 | }, 302 | Elts: columnMapKeyValueExprs, 303 | }, 304 | }, 305 | }, 306 | }, 307 | }, 308 | } 309 | } 310 | 311 | func (tbl *table) getErrorsDecl() ast.Decl { 312 | return &ast.FuncDecl{ 313 | Recv: &ast.FieldList{ 314 | List: []*ast.Field{ 315 | { 316 | Names: []*ast.Ident{tbl.recvIdent}, 317 | Type: &ast.StarExpr{ 318 | X: tbl.structIdent, 319 | }, 320 | }, 321 | }, 322 | }, 323 | Name: tableGetErrorsIdent, 324 | Type: &ast.FuncType{ 325 | Results: &ast.FieldList{ 326 | List: []*ast.Field{ 327 | { 328 | Type: &ast.ArrayType{ 329 | Elt: ast.NewIdent("error"), 330 | }, 331 | }, 332 | }, 333 | }, 334 | }, 335 | Body: &ast.BlockStmt{ 336 | List: []ast.Stmt{ 337 | &ast.ReturnStmt{ 338 | Results: []ast.Expr{ast.NewIdent("nil")}, 339 | }, 340 | }, 341 | }, 342 | } 343 | } 344 | 345 | func (tbl *table) tableJoinDecl(ref *refTable) ast.Decl { 346 | joinIdent := ast.NewIdent(ref.refTable.name) 347 | refIdent := ast.NewIdent("ref") 348 | 349 | return &ast.FuncDecl{ 350 | Recv: &ast.FieldList{ 351 | List: []*ast.Field{ 352 | { 353 | Names: []*ast.Ident{tbl.recvIdent}, 354 | Type: &ast.StarExpr{ 355 | X: tbl.structIdent, 356 | }, 357 | }, 358 | }, 359 | }, 360 | Name: joinIdent, 361 | Type: &ast.FuncType{ 362 | Results: &ast.FieldList{ 363 | List: []*ast.Field{ 364 | { 365 | Type: relationContext(&ast.StarExpr{ 366 | X: tbl.structIdent, 367 | }, &ast.StarExpr{ 368 | X: ref.refTable.structIdent, 369 | }, 370 | ref.joinedTable.structIdent, 371 | ), 372 | }, 373 | }, 374 | }, 375 | }, 376 | Body: &ast.BlockStmt{ 377 | List: []ast.Stmt{ 378 | &ast.AssignStmt{ 379 | Lhs: []ast.Expr{refIdent}, 380 | Tok: token.DEFINE, 381 | 
Rhs: []ast.Expr{ 382 | &ast.CompositeLit{ 383 | Type: ref.refTable.structIdent, 384 | Elts: []ast.Expr{}, 385 | }, 386 | }, 387 | }, 388 | &ast.ReturnStmt{ 389 | Results: []ast.Expr{ 390 | &ast.CallExpr{ 391 | Fun: newRelationContext(&ast.StarExpr{ 392 | X: tbl.structIdent, 393 | }, &ast.StarExpr{ 394 | X: ref.refTable.structIdent, 395 | }, 396 | ref.joinedTable.structIdent, 397 | ), 398 | Args: []ast.Expr{ 399 | tbl.recvIdent, 400 | &ast.UnaryExpr{ 401 | Op: token.AND, 402 | X: refIdent, 403 | }, 404 | }, 405 | }, 406 | }, 407 | }, 408 | }, 409 | }, 410 | } 411 | } 412 | 413 | func (tbl *table) joinedTableJoinDecl(ref *refJoinedTable) ast.Decl { 414 | joinIdent := ast.NewIdent(ref.refTable.name) 415 | refIdent := ast.NewIdent("ref") 416 | 417 | return &ast.FuncDecl{ 418 | Recv: &ast.FieldList{ 419 | List: []*ast.Field{ 420 | { 421 | Names: []*ast.Ident{tbl.recvIdent}, 422 | Type: &ast.StarExpr{ 423 | X: tbl.structIdent, 424 | }, 425 | }, 426 | }, 427 | }, 428 | Name: joinIdent, 429 | Type: &ast.FuncType{ 430 | Params: &ast.FieldList{ 431 | List: []*ast.Field{ 432 | { 433 | Names: []*ast.Ident{refIdent}, 434 | Type: &ast.StarExpr{ 435 | X: ref.refTable.structIdent, 436 | }, 437 | }, 438 | }, 439 | }, 440 | Results: &ast.FieldList{ 441 | List: []*ast.Field{ 442 | { 443 | Type: relationContext(&ast.StarExpr{ 444 | X: tbl.structIdent, 445 | }, &ast.StarExpr{ 446 | X: ref.refTable.structIdent, 447 | }, 448 | ref.joinedTable.structIdent, 449 | ), 450 | }, 451 | }, 452 | }, 453 | }, 454 | Body: &ast.BlockStmt{ 455 | List: []ast.Stmt{ 456 | &ast.ReturnStmt{ 457 | Results: []ast.Expr{ 458 | &ast.CallExpr{ 459 | Fun: newRelationContext(&ast.StarExpr{ 460 | X: tbl.structIdent, 461 | }, &ast.StarExpr{ 462 | X: ref.refTable.structIdent, 463 | }, 464 | ref.joinedTable.structIdent, 465 | ), 466 | Args: []ast.Expr{ 467 | tbl.recvIdent, 468 | refIdent, 469 | }, 470 | }, 471 | }, 472 | }, 473 | }, 474 | }, 475 | } 476 | } 477 | 478 | func (tbl *table) tablePackageDecls() []ast.Decl 
{ 479 | decls := []ast.Decl{} 480 | for _, column := range tbl.columns { 481 | decls = append(decls, column.tablePackageDecls()...) 482 | } 483 | 484 | return decls 485 | } 486 | 487 | type method struct { 488 | Type types.MethodType 489 | Decl *ast.FuncDecl 490 | } 491 | 492 | func newMethod(tbl *table, m *types.Method) (*method, error) { 493 | mthd := &method{ 494 | Type: m.Type, 495 | Decl: m.Decl, 496 | } 497 | 498 | if mthd.Decl == nil || 499 | mthd.Decl.Recv == nil || 500 | len(mthd.Decl.Recv.List) == 0 || 501 | mthd.Decl.Recv.List[0] == nil || 502 | mthd.Decl.Recv.List[0].Type == nil { 503 | return nil, errors.New("invalid method") 504 | } 505 | switch mthd.Type { 506 | case types.MethodTypeIdentifier: 507 | mthd.Decl.Recv.List[0].Names = []*ast.Ident{ 508 | ast.NewIdent(tbl.structIdent.Name), 509 | } 510 | mthd.Decl.Recv.List[0].Type = tbl.structIdent 511 | case types.MethodTypeStar: 512 | mthd.Decl.Recv.List[0].Names = []*ast.Ident{ 513 | ast.NewIdent(tbl.structIdent.Name), 514 | } 515 | 516 | star, ok := mthd.Decl.Recv.List[0].Type.(*ast.StarExpr) 517 | if !ok || star == nil { 518 | return nil, errors.New("invalid method") 519 | } 520 | 521 | star.X = tbl.structIdent 522 | default: 523 | return nil, errors.New("unknown method type") 524 | } 525 | 526 | return mthd, nil 527 | } 528 | -------------------------------------------------------------------------------- /cmd/genorm/generator/convert/convert_test.go: -------------------------------------------------------------------------------- 1 | package convert 2 | 3 | import ( 4 | "go/ast" 5 | "sort" 6 | "testing" 7 | 8 | "github.com/mazrean/genorm/cmd/genorm/generator/types" 9 | "github.com/stretchr/testify/assert" 10 | ) 11 | 12 | func TestTablesHash(t *testing.T) { 13 | t.Parallel() 14 | 15 | tests := []struct { 16 | description string 17 | joinedTable *converterJoinedTable 18 | tableLength int 19 | hash int64 20 | }{ 21 | { 22 | description: "simple", 23 | joinedTable: &converterJoinedTable{ 24 | hash: 
-1, 25 | tables: map[int]*converterTable{ 26 | 1: {}, 27 | }, 28 | }, 29 | tableLength: 2, 30 | hash: 1, 31 | }, 32 | { 33 | description: "id: 0", 34 | joinedTable: &converterJoinedTable{ 35 | hash: -1, 36 | tables: map[int]*converterTable{ 37 | 0: {}, 38 | }, 39 | }, 40 | tableLength: 2, 41 | hash: 0, 42 | }, 43 | { 44 | description: "multiple", 45 | joinedTable: &converterJoinedTable{ 46 | hash: -1, 47 | tables: map[int]*converterTable{ 48 | 1: {}, 49 | 3: {}, 50 | }, 51 | }, 52 | tableLength: 4, 53 | hash: 13, 54 | }, 55 | { 56 | description: "length: 0", 57 | joinedTable: &converterJoinedTable{ 58 | hash: -1, 59 | tables: map[int]*converterTable{}, 60 | }, 61 | tableLength: 0, 62 | hash: 0, 63 | }, 64 | { 65 | description: "use cache", 66 | joinedTable: &converterJoinedTable{ 67 | hash: 50, 68 | }, 69 | tableLength: 10, 70 | hash: 50, 71 | }, 72 | } 73 | 74 | for _, test := range tests { 75 | t.Run(test.description, func(t *testing.T) { 76 | assert.Equal(t, test.hash, test.joinedTable.tablesHash(test.tableLength)) 77 | }) 78 | } 79 | } 80 | 81 | func TestConvertJoinedTables(t *testing.T) { 82 | t.Parallel() 83 | 84 | typeIdent1 := ast.NewIdent("int64") 85 | funcDecl := &ast.FuncDecl{ 86 | Name: ast.NewIdent("GetID"), 87 | } 88 | 89 | messageTable := &types.Table{ 90 | StructName: "Message", 91 | Columns: []*types.Column{ 92 | { 93 | Name: "id", 94 | FieldName: "ID", 95 | Type: typeIdent1, 96 | }, 97 | }, 98 | Methods: []*types.Method{}, 99 | RefTables: []*types.RefTable{}, 100 | RefJoinedTables: []*types.RefJoinedTable{}, 101 | } 102 | userTable := &types.Table{ 103 | StructName: "User", 104 | Columns: []*types.Column{ 105 | { 106 | Name: "id", 107 | FieldName: "ID", 108 | Type: typeIdent1, 109 | }, 110 | }, 111 | Methods: []*types.Method{ 112 | { 113 | Type: types.MethodTypeIdentifier, 114 | Decl: funcDecl, 115 | }, 116 | }, 117 | RefTables: []*types.RefTable{}, 118 | RefJoinedTables: []*types.RefJoinedTable{}, 119 | } 120 | userMessageJoinedTable := 
&types.JoinedTable{ 121 | Tables: []*types.Table{messageTable, userTable}, 122 | RefTables: []*types.RefTable{}, 123 | RefJoinedTables: []*types.RefJoinedTable{}, 124 | } 125 | userTable.RefTables = []*types.RefTable{ 126 | { 127 | Table: messageTable, 128 | JoinedTable: userMessageJoinedTable, 129 | }, 130 | } 131 | 132 | messageOptionTable2 := &types.Table{ 133 | StructName: "MessageOption", 134 | Columns: []*types.Column{ 135 | { 136 | Name: "id", 137 | FieldName: "ID", 138 | Type: typeIdent1, 139 | }, 140 | }, 141 | Methods: []*types.Method{}, 142 | RefTables: []*types.RefTable{}, 143 | RefJoinedTables: []*types.RefJoinedTable{}, 144 | } 145 | messageTable2 := &types.Table{ 146 | StructName: "Message", 147 | Columns: []*types.Column{ 148 | { 149 | Name: "id", 150 | FieldName: "ID", 151 | Type: typeIdent1, 152 | }, 153 | }, 154 | Methods: []*types.Method{}, 155 | RefTables: []*types.RefTable{}, 156 | RefJoinedTables: []*types.RefJoinedTable{}, 157 | } 158 | userTable2 := &types.Table{ 159 | StructName: "User", 160 | Columns: []*types.Column{ 161 | { 162 | Name: "id", 163 | FieldName: "ID", 164 | Type: typeIdent1, 165 | }, 166 | }, 167 | Methods: []*types.Method{ 168 | { 169 | Type: types.MethodTypeIdentifier, 170 | Decl: funcDecl, 171 | }, 172 | }, 173 | RefTables: []*types.RefTable{}, 174 | RefJoinedTables: []*types.RefJoinedTable{}, 175 | } 176 | userMessageMessageOptionTable2 := &types.JoinedTable{ 177 | Tables: []*types.Table{messageTable2, messageOptionTable2, userTable2}, 178 | RefTables: []*types.RefTable{}, 179 | RefJoinedTables: []*types.RefJoinedTable{}, 180 | } 181 | userMessageJoinedTable2 := &types.JoinedTable{ 182 | Tables: []*types.Table{messageTable2, userTable2}, 183 | RefTables: []*types.RefTable{ 184 | { 185 | Table: messageOptionTable2, 186 | JoinedTable: userMessageMessageOptionTable2, 187 | }, 188 | }, 189 | RefJoinedTables: []*types.RefJoinedTable{}, 190 | } 191 | messageMessageOptionTable2 := &types.JoinedTable{ 192 | Tables: 
[]*types.Table{messageTable2, messageOptionTable2}, 193 | RefTables: []*types.RefTable{}, 194 | RefJoinedTables: []*types.RefJoinedTable{}, 195 | } 196 | userTable2.RefTables = []*types.RefTable{ 197 | { 198 | Table: messageTable2, 199 | JoinedTable: userMessageJoinedTable2, 200 | }, 201 | } 202 | userTable2.RefJoinedTables = []*types.RefJoinedTable{ 203 | { 204 | Table: messageMessageOptionTable2, 205 | JoinedTable: userMessageMessageOptionTable2, 206 | }, 207 | } 208 | messageTable2.RefTables = []*types.RefTable{ 209 | { 210 | Table: messageOptionTable2, 211 | JoinedTable: messageMessageOptionTable2, 212 | }, 213 | } 214 | 215 | tests := []struct { 216 | description string 217 | tables []*types.Table 218 | expectTables []*types.Table 219 | expectJoinedTables []*types.JoinedTable 220 | joinNum int 221 | err bool 222 | }{ 223 | { 224 | description: "simple", 225 | joinNum: 5, 226 | tables: []*types.Table{ 227 | { 228 | StructName: "User", 229 | Columns: []*types.Column{ 230 | { 231 | Name: "id", 232 | FieldName: "ID", 233 | Type: typeIdent1, 234 | }, 235 | }, 236 | Methods: []*types.Method{ 237 | { 238 | Type: types.MethodTypeIdentifier, 239 | Decl: funcDecl, 240 | }, 241 | }, 242 | }, 243 | }, 244 | expectTables: []*types.Table{ 245 | { 246 | StructName: "User", 247 | Columns: []*types.Column{ 248 | { 249 | Name: "id", 250 | FieldName: "ID", 251 | Type: typeIdent1, 252 | }, 253 | }, 254 | Methods: []*types.Method{ 255 | { 256 | Type: types.MethodTypeIdentifier, 257 | Decl: funcDecl, 258 | }, 259 | }, 260 | RefTables: []*types.RefTable{}, 261 | RefJoinedTables: []*types.RefJoinedTable{}, 262 | }, 263 | }, 264 | expectJoinedTables: []*types.JoinedTable{}, 265 | }, 266 | { 267 | description: "join", 268 | joinNum: 5, 269 | tables: []*types.Table{ 270 | { 271 | StructName: "User", 272 | Columns: []*types.Column{ 273 | { 274 | Name: "id", 275 | FieldName: "ID", 276 | Type: typeIdent1, 277 | }, 278 | }, 279 | Methods: []*types.Method{ 280 | { 281 | Type: 
types.MethodTypeIdentifier, 282 | Decl: funcDecl, 283 | }, 284 | }, 285 | RefTables: []*types.RefTable{ 286 | { 287 | Table: messageTable, 288 | }, 289 | }, 290 | }, 291 | messageTable, 292 | }, 293 | expectTables: []*types.Table{ 294 | userTable, 295 | messageTable, 296 | }, 297 | expectJoinedTables: []*types.JoinedTable{userMessageJoinedTable}, 298 | }, 299 | { 300 | description: "no join", 301 | joinNum: 1, 302 | tables: []*types.Table{ 303 | { 304 | StructName: "User", 305 | Columns: []*types.Column{ 306 | { 307 | Name: "id", 308 | FieldName: "ID", 309 | Type: typeIdent1, 310 | }, 311 | }, 312 | Methods: []*types.Method{ 313 | { 314 | Type: types.MethodTypeIdentifier, 315 | Decl: funcDecl, 316 | }, 317 | }, 318 | RefTables: []*types.RefTable{}, 319 | }, 320 | messageTable, 321 | }, 322 | expectTables: []*types.Table{ 323 | { 324 | StructName: "User", 325 | Columns: []*types.Column{ 326 | { 327 | Name: "id", 328 | FieldName: "ID", 329 | Type: typeIdent1, 330 | }, 331 | }, 332 | Methods: []*types.Method{ 333 | { 334 | Type: types.MethodTypeIdentifier, 335 | Decl: funcDecl, 336 | }, 337 | }, 338 | RefTables: []*types.RefTable{}, 339 | RefJoinedTables: []*types.RefJoinedTable{}, 340 | }, 341 | messageTable, 342 | }, 343 | expectJoinedTables: []*types.JoinedTable{}, 344 | }, 345 | { 346 | description: "multiple join", 347 | joinNum: 5, 348 | tables: []*types.Table{ 349 | { 350 | StructName: "User", 351 | Columns: []*types.Column{ 352 | { 353 | Name: "id", 354 | FieldName: "ID", 355 | Type: typeIdent1, 356 | }, 357 | }, 358 | Methods: []*types.Method{ 359 | { 360 | Type: types.MethodTypeIdentifier, 361 | Decl: funcDecl, 362 | }, 363 | }, 364 | RefTables: []*types.RefTable{ 365 | { 366 | Table: messageTable2, 367 | }, 368 | }, 369 | }, 370 | messageTable2, 371 | messageOptionTable2, 372 | }, 373 | expectTables: []*types.Table{ 374 | userTable2, 375 | messageTable2, 376 | messageOptionTable2, 377 | }, 378 | expectJoinedTables: []*types.JoinedTable{ 379 | 
userMessageJoinedTable2, 380 | messageMessageOptionTable2, 381 | userMessageMessageOptionTable2, 382 | }, 383 | }, 384 | } 385 | 386 | for _, test := range tests { 387 | t.Run(test.description, func(t *testing.T) { 388 | tables, joinedTables, err := convertJoinedTables(test.tables, test.joinNum) 389 | if err != nil { 390 | assert.Error(t, err) 391 | return 392 | } 393 | 394 | assert.NoError(t, err) 395 | 396 | for _, table := range tables { 397 | sort.Slice(table.RefTables, func(i, j int) bool { 398 | return table.RefTables[i].Table.StructName < table.RefTables[j].Table.StructName 399 | }) 400 | } 401 | for _, joinedTable := range joinedTables { 402 | sort.Slice(joinedTable.Tables, func(i, j int) bool { 403 | return joinedTable.Tables[i].StructName < joinedTable.Tables[j].StructName 404 | }) 405 | } 406 | 407 | assert.ElementsMatch(t, test.expectTables, tables) 408 | assert.ElementsMatch(t, test.expectJoinedTables, joinedTables) 409 | }) 410 | } 411 | } 412 | -------------------------------------------------------------------------------- /cmd/genorm/generator/generator.go: -------------------------------------------------------------------------------- 1 | package generator 2 | 3 | import ( 4 | "fmt" 5 | p "go/parser" 6 | "go/token" 7 | "io" 8 | 9 | "github.com/mazrean/genorm/cmd/genorm/generator/codegen" 10 | "github.com/mazrean/genorm/cmd/genorm/generator/convert" 11 | "github.com/mazrean/genorm/cmd/genorm/generator/parser" 12 | ) 13 | 14 | type Config struct { 15 | JoinNum int 16 | } 17 | 18 | func Generate(packageName string, moduleName string, destinationDir string, src io.Reader, config Config) error { 19 | fset := token.NewFileSet() 20 | f, err := p.ParseFile(fset, "", src, p.Mode(0)) 21 | if err != nil { 22 | return fmt.Errorf("parse source: %w", err) 23 | } 24 | 25 | parserTables, err := parser.Parse(f) 26 | if err != nil { 27 | return fmt.Errorf("parse: %w", err) 28 | } 29 | 30 | tables, joinedTables, err := convert.Convert(parserTables, config.JoinNum) 
31 | if err != nil { 32 | return fmt.Errorf("convert: %w", err) 33 | } 34 | 35 | err = codegen.Codegen(packageName, moduleName, destinationDir, f, tables, joinedTables) 36 | if err != nil { 37 | return fmt.Errorf("codegen: %w", err) 38 | } 39 | 40 | return nil 41 | } 42 | -------------------------------------------------------------------------------- /cmd/genorm/generator/parser/parser.go: -------------------------------------------------------------------------------- 1 | package parser 2 | 3 | import ( 4 | "fmt" 5 | "go/ast" 6 | "reflect" 7 | "strconv" 8 | 9 | "github.com/mazrean/genorm/cmd/genorm/generator/types" 10 | ) 11 | 12 | type parserTable struct { 13 | StructName string 14 | Columns []*parserColumn 15 | Methods []*parserMethod 16 | RefTables []*parserRefTable 17 | } 18 | 19 | type parserMethod struct { 20 | StructName string 21 | Type types.MethodType 22 | Decl *ast.FuncDecl 23 | } 24 | 25 | type parserRefTable struct { 26 | FieldName string 27 | StructName string 28 | } 29 | 30 | type parserColumn struct { 31 | Name string 32 | FieldName string 33 | Type ast.Expr 34 | } 35 | 36 | func Parse(f *ast.File) ([]*types.Table, error) { 37 | parserTables := []*parserTable{} 38 | methodMap := make(map[string][]*parserMethod) 39 | 40 | for _, decl := range f.Decls { 41 | genDecl, ok := decl.(*ast.GenDecl) 42 | if !ok { 43 | funcDecl, ok := decl.(*ast.FuncDecl) 44 | if !ok { 45 | continue 46 | } 47 | 48 | method, isMethod := parseFuncDecl(funcDecl) 49 | 50 | if isMethod { 51 | methodMap[method.StructName] = append(methodMap[method.StructName], method) 52 | } 53 | 54 | continue 55 | } 56 | 57 | newTables, err := parseGenDecl(genDecl) 58 | if err != nil { 59 | return nil, fmt.Errorf("parse gen: %w", err) 60 | } 61 | 62 | parserTables = append(parserTables, newTables...) 
63 | } 64 | 65 | for _, parserTable := range parserTables { 66 | parserTable.Methods = methodMap[parserTable.StructName] 67 | } 68 | 69 | tables, err := convertTables(parserTables) 70 | if err != nil { 71 | return nil, fmt.Errorf("convert tables: %w", err) 72 | } 73 | 74 | return tables, nil 75 | } 76 | 77 | func convertTables(tables []*parserTable) ([]*types.Table, error) { 78 | type tablePair struct { 79 | parser *parserTable 80 | converted *types.Table 81 | } 82 | 83 | pairMap := make(map[string]*tablePair, len(tables)) 84 | for _, table := range tables { 85 | pairMap[table.StructName] = &tablePair{ 86 | parser: table, 87 | converted: convertTable(table), 88 | } 89 | } 90 | 91 | convertedTables := make([]*types.Table, 0, len(tables)) 92 | for _, pair := range pairMap { 93 | refTables := make([]*types.RefTable, 0, len(pair.parser.RefTables)) 94 | for _, refParserTable := range pair.parser.RefTables { 95 | refTable, ok := pairMap[refParserTable.StructName] 96 | if !ok { 97 | return nil, fmt.Errorf("ref table not found: %s", refParserTable.StructName) 98 | } 99 | 100 | refTables = append(refTables, &types.RefTable{ 101 | Table: refTable.converted, 102 | }) 103 | } 104 | 105 | pair.converted.RefTables = refTables 106 | 107 | convertedTables = append(convertedTables, pair.converted) 108 | } 109 | 110 | return convertedTables, nil 111 | } 112 | 113 | func convertTable(table *parserTable) *types.Table { 114 | columns := make([]*types.Column, 0, len(table.Columns)) 115 | for _, column := range table.Columns { 116 | columns = append(columns, &types.Column{ 117 | Name: column.Name, 118 | FieldName: column.FieldName, 119 | Type: column.Type, 120 | }) 121 | } 122 | 123 | methods := make([]*types.Method, 0, len(table.Methods)) 124 | for _, method := range table.Methods { 125 | methods = append(methods, &types.Method{ 126 | Type: method.Type, 127 | Decl: method.Decl, 128 | }) 129 | } 130 | 131 | return &types.Table{ 132 | StructName: table.StructName, 133 | Columns: columns, 
134 | Methods: methods, 135 | } 136 | } 137 | 138 | func parseFuncDecl(f *ast.FuncDecl) (*parserMethod, bool) { 139 | recv := f.Recv 140 | if recv == nil { 141 | return nil, false 142 | } 143 | 144 | if len(recv.List) == 0 { 145 | return nil, false 146 | } 147 | 148 | recvType := recv.List[0].Type 149 | identType, ok := recvType.(*ast.Ident) 150 | if !ok { 151 | starType, ok := recvType.(*ast.StarExpr) 152 | if !ok { 153 | return nil, false 154 | } 155 | 156 | identType, ok = starType.X.(*ast.Ident) 157 | if !ok { 158 | return nil, false 159 | } 160 | 161 | return &parserMethod{ 162 | StructName: identType.Name, 163 | Type: types.MethodTypeStar, 164 | Decl: f, 165 | }, true 166 | } 167 | 168 | return &parserMethod{ 169 | StructName: identType.Name, 170 | Type: types.MethodTypeIdentifier, 171 | Decl: f, 172 | }, true 173 | } 174 | 175 | func parseGenDecl(g *ast.GenDecl) ([]*parserTable, error) { 176 | tables := []*parserTable{} 177 | 178 | for _, spec := range g.Specs { 179 | typeSpec, ok := spec.(*ast.TypeSpec) 180 | if !ok || typeSpec == nil { 181 | continue 182 | } 183 | 184 | structType, ok := typeSpec.Type.(*ast.StructType) 185 | if !ok || structType == nil { 186 | continue 187 | } 188 | 189 | table, err := parseStructType(typeSpec.Name.Name, structType) 190 | if err != nil { 191 | return nil, fmt.Errorf("parse struct: %w", err) 192 | } 193 | 194 | if table != nil { 195 | tables = append(tables, table) 196 | } 197 | } 198 | 199 | return tables, nil 200 | } 201 | 202 | func parseStructType(name string, s *ast.StructType) (*parserTable, error) { 203 | fieldList := s.Fields 204 | if fieldList == nil { 205 | return nil, nil 206 | } 207 | 208 | fields := fieldList.List 209 | if len(fields) == 0 { 210 | return nil, nil 211 | } 212 | 213 | columns := []*parserColumn{} 214 | refTables := []*parserRefTable{} 215 | for _, field := range fields { 216 | if tableName, isRef := checkRefType(field.Type); isRef { 217 | for _, name := range field.Names { 218 | if name == nil { 
// checkRefType reports whether t has the exact shape genorm.Ref[Name], where
// Name is a plain identifier. On a match it returns Name and true; for any
// other expression it returns "" and false.
func checkRefType(t ast.Expr) (string, bool) {
	index, isIndex := t.(*ast.IndexExpr)
	if !isIndex || index == nil {
		return "", false
	}

	// The indexed operand must be the qualified identifier genorm.Ref.
	sel, isSelector := index.X.(*ast.SelectorExpr)
	switch {
	case !isSelector || sel == nil:
		return "", false
	case sel.X == nil || sel.Sel == nil:
		return "", false
	}

	pkg, isIdent := sel.X.(*ast.Ident)
	if !isIdent || pkg.Name != "genorm" || sel.Sel.Name != "Ref" {
		return "", false
	}

	// The type argument names the referenced table struct.
	if target, isIdent := index.Index.(*ast.Ident); isIdent && target != nil {
		return target.Name, true
	}

	return "", false
}
package types 2 | 3 | import ( 4 | "go/ast" 5 | ) 6 | 7 | type Table struct { 8 | StructName string 9 | Columns []*Column 10 | Methods []*Method 11 | RefTables []*RefTable 12 | RefJoinedTables []*RefJoinedTable 13 | } 14 | 15 | type Method struct { 16 | Type MethodType 17 | Decl *ast.FuncDecl 18 | } 19 | 20 | type JoinedTable struct { 21 | Tables []*Table 22 | RefTables []*RefTable 23 | RefJoinedTables []*RefJoinedTable 24 | } 25 | 26 | type MethodType int8 27 | 28 | const ( 29 | MethodTypeIdentifier MethodType = iota + 1 30 | MethodTypeStar 31 | ) 32 | 33 | type RefTable struct { 34 | Table *Table 35 | JoinedTable *JoinedTable 36 | } 37 | 38 | type RefJoinedTable struct { 39 | Table *JoinedTable 40 | JoinedTable *JoinedTable 41 | } 42 | 43 | type Column struct { 44 | Name string 45 | FieldName string 46 | Type ast.Expr 47 | } 48 | -------------------------------------------------------------------------------- /cmd/genorm/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "errors" 5 | "flag" 6 | "fmt" 7 | "io" 8 | "os" 9 | "runtime/debug" 10 | 11 | "github.com/mazrean/genorm/cmd/genorm/generator" 12 | ) 13 | 14 | var ( 15 | // flags 16 | showVersionInfo bool 17 | source string 18 | destination string 19 | packageName string 20 | moduleName string 21 | joinNum int 22 | ) 23 | 24 | func init() { 25 | flag.BoolVar(&showVersionInfo, "version", false, "If true, output version information.") 26 | flag.StringVar(&source, "source", "", "The source file to parse.") 27 | flag.StringVar(&destination, "destination", "", "The destination file to write.") 28 | flag.StringVar(&packageName, "package", "", "The root package name to use.") 29 | flag.StringVar(&moduleName, "module", "", "The root module name to use.") 30 | flag.IntVar(&joinNum, "join-num", 5, "The number of joins to generate.") 31 | } 32 | 33 | func main() { 34 | flag.Parse() 35 | 36 | if showVersionInfo { 37 | err := printVersionInfo() 38 
| if err != nil { 39 | panic(err) 40 | } 41 | } 42 | 43 | src, err := openSource(source) 44 | if err != nil { 45 | panic(err) 46 | } 47 | defer src.Close() 48 | 49 | if len(destination) == 0 { 50 | panic("Destination directory path is required.") 51 | } 52 | 53 | if len(packageName) == 0 { 54 | panic("package name is required") 55 | } 56 | if len(moduleName) == 0 { 57 | panic("module name is required") 58 | } 59 | 60 | err = generator.Generate(packageName, moduleName, destination, src, generator.Config{ 61 | JoinNum: joinNum, 62 | }) 63 | if err != nil { 64 | panic(err) 65 | } 66 | } 67 | 68 | func printVersionInfo() error { 69 | buildInfo, ok := debug.ReadBuildInfo() 70 | if !ok { 71 | return errors.New("no build info") 72 | } 73 | 74 | _, err := io.WriteString(os.Stderr, fmt.Sprintf("Version: %s\n", buildInfo.Main.Version)) 75 | if err != nil { 76 | return fmt.Errorf("print version info: %w", err) 77 | } 78 | 79 | return nil 80 | } 81 | 82 | func openSource(source string) (io.ReadCloser, error) { 83 | if len(source) == 0 { 84 | return nil, errors.New("empty source file") 85 | } 86 | 87 | file, err := os.Open(source) 88 | if err != nil { 89 | return nil, fmt.Errorf("open source: %w", err) 90 | } 91 | 92 | return file, nil 93 | } 94 | -------------------------------------------------------------------------------- /cmd/genorm/main_test.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "io" 5 | "os" 6 | "path/filepath" 7 | "strings" 8 | "testing" 9 | ) 10 | 11 | func TestOpenSource(t *testing.T) { 12 | err := os.MkdirAll("./test", os.ModePerm) 13 | if err != nil { 14 | t.Fatalf("create test directory: %s", err) 15 | } 16 | defer func() { 17 | err := os.RemoveAll("./test") 18 | if err != nil { 19 | t.Errorf("failed to remove test directory: %s", err) 20 | } 21 | }() 22 | 23 | absPath, err := filepath.Abs("./test.go") 24 | if err != nil { 25 | t.Fatalf("failed to get absolute path: %s", err) 26 | 
} 27 | 28 | tests := []struct { 29 | description string 30 | path string 31 | exists bool 32 | content string 33 | err bool 34 | }{ 35 | { 36 | description: "normal source -> success", 37 | path: "./test.go", 38 | exists: true, 39 | content: `package main 40 | 41 | func main() { 42 | fmt.Println("Hello, World!") 43 | } 44 | `, 45 | }, 46 | { 47 | description: "source in directory -> success", 48 | path: "./test/main.go", 49 | exists: true, 50 | content: `package main 51 | 52 | func main() { 53 | fmt.Println("Hello, World!") 54 | } 55 | `, 56 | }, 57 | { 58 | description: "source(absolute path) -> success", 59 | path: absPath, 60 | exists: true, 61 | content: `package main 62 | 63 | func main() { 64 | fmt.Println("Hello, World!") 65 | } 66 | `, 67 | }, 68 | { 69 | description: "non-existent source -> error", 70 | path: "./test/main.go", 71 | err: true, 72 | }, 73 | { 74 | description: "empty source -> error", 75 | path: "", 76 | err: true, 77 | }, 78 | } 79 | 80 | for _, test := range tests { 81 | t.Run(test.description, func(t *testing.T) { 82 | if test.exists { 83 | func() { 84 | f, err := os.Create(test.path) 85 | if err != nil { 86 | t.Fatalf("failed to create test file: %s", err) 87 | } 88 | defer f.Close() 89 | 90 | _, err = f.WriteString(test.content) 91 | if err != nil { 92 | t.Fatalf("failed to write test file: %s", err) 93 | } 94 | }() 95 | 96 | defer func() { 97 | err := os.Remove(test.path) 98 | if err != nil { 99 | t.Errorf("failed to remove test file: %s", err) 100 | } 101 | }() 102 | } 103 | 104 | src, err := openSource(test.path) 105 | if err != nil { 106 | if !test.err { 107 | t.Fatalf("unexpected error: %s", err) 108 | } 109 | return 110 | } 111 | defer src.Close() 112 | 113 | if test.err { 114 | t.Fatalf("expected error but got none") 115 | } 116 | 117 | if test.exists { 118 | sb := strings.Builder{} 119 | 120 | _, err := io.Copy(&sb, src) 121 | if err != nil { 122 | t.Fatalf("failed to read source: %s", err) 123 | } 124 | 125 | actualContent := 
sb.String() 126 | if actualContent != test.content { 127 | t.Fatalf("unexpected content: %s", actualContent) 128 | } 129 | } 130 | }) 131 | } 132 | } 133 | -------------------------------------------------------------------------------- /column.go: -------------------------------------------------------------------------------- 1 | package genorm 2 | 3 | type Column interface { 4 | Expr 5 | // SQLColumnName table_name.column_name 6 | SQLColumnName() string 7 | // TableName table_name 8 | TableName() string 9 | // ColumnName column_name 10 | ColumnName() string 11 | } 12 | 13 | type TableColumns[T Table] interface { 14 | Column 15 | TableExpr[T] 16 | } 17 | 18 | type TypedColumns[S ExprType] interface { 19 | Column 20 | TypedExpr[S] 21 | } 22 | 23 | type TypedTableColumns[T Table, S ExprType] interface { 24 | TableColumns[T] 25 | TypedColumns[S] 26 | } 27 | -------------------------------------------------------------------------------- /column_mock_test.go: -------------------------------------------------------------------------------- 1 | package genorm_test 2 | 3 | import ( 4 | "github.com/mazrean/genorm" 5 | ) 6 | 7 | //go:generate go run github.com/golang/mock/mockgen@v1.6.0 -source=$GOFILE -destination=mock/column_mock.go -package=mock 8 | 9 | type Column interface { 10 | genorm.Column 11 | } 12 | -------------------------------------------------------------------------------- /config.go: -------------------------------------------------------------------------------- 1 | package genorm 2 | 3 | type Ref[T any] struct{} 4 | -------------------------------------------------------------------------------- /context.go: -------------------------------------------------------------------------------- 1 | package genorm 2 | 3 | type Context[T Table] struct { 4 | table T 5 | errs []error 6 | } 7 | 8 | func newContext[T Table](table T) *Context[T] { 9 | return &Context[T]{ 10 | table: table, 11 | errs: table.GetErrors(), 12 | } 13 | } 14 | 15 | func (c *Context[T]) 
Table() T { 16 | return c.table 17 | } 18 | 19 | func (c *Context[T]) addError(err error) { 20 | c.errs = append(c.errs, err) 21 | } 22 | 23 | func (c *Context[T]) Errors() []error { 24 | if len(c.errs) == 0 { 25 | return nil 26 | } 27 | 28 | return c.errs 29 | } 30 | -------------------------------------------------------------------------------- /db.go: -------------------------------------------------------------------------------- 1 | package genorm 2 | 3 | import ( 4 | "context" 5 | "database/sql" 6 | ) 7 | 8 | type DB interface { 9 | ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) 10 | QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) 11 | QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row 12 | } 13 | -------------------------------------------------------------------------------- /delete.go: -------------------------------------------------------------------------------- 1 | package genorm 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "strings" 7 | ) 8 | 9 | type DeleteContext[T BasicTable] struct { 10 | *Context[T] 11 | whereCondition whereConditionClause[T] 12 | order orderClause[T] 13 | limit limitClause 14 | } 15 | 16 | func Delete[T BasicTable](table T) *DeleteContext[T] { 17 | ctx := newContext(table) 18 | 19 | return &DeleteContext[T]{ 20 | Context: ctx, 21 | } 22 | } 23 | 24 | func (c *DeleteContext[T]) Where( 25 | condition TypedTableExpr[T, WrappedPrimitive[bool]], 26 | ) *DeleteContext[T] { 27 | err := c.whereCondition.set(condition) 28 | if err != nil { 29 | c.addError(fmt.Errorf("where condition: %w", err)) 30 | } 31 | 32 | return c 33 | } 34 | 35 | func (c *DeleteContext[T]) OrderBy(direction OrderDirection, expr TableExpr[T]) *DeleteContext[T] { 36 | err := c.order.add(orderItem[T]{ 37 | expr: expr, 38 | direction: direction, 39 | }) 40 | if err != nil { 41 | c.addError(fmt.Errorf("order by: %w", err)) 42 | } 43 | 44 | return c 45 | } 46 | 47 | func (c 
*DeleteContext[T]) Limit(limit uint64) *DeleteContext[T] { 48 | err := c.limit.set(limit) 49 | if err != nil { 50 | c.addError(fmt.Errorf("limit: %w", err)) 51 | } 52 | 53 | return c 54 | } 55 | 56 | func (c *DeleteContext[T]) DoCtx(ctx context.Context, db DB) (rowsAffected int64, err error) { 57 | errs := c.Errors() 58 | if len(errs) != 0 { 59 | return 0, errs[0] 60 | } 61 | 62 | query, exprArgs, err := c.buildQuery() 63 | if err != nil { 64 | return 0, fmt.Errorf("build query: %w", err) 65 | } 66 | 67 | args := make([]any, 0, len(exprArgs)) 68 | for _, arg := range exprArgs { 69 | args = append(args, arg) 70 | } 71 | 72 | result, err := db.ExecContext(ctx, query, args...) 73 | if err != nil { 74 | return 0, fmt.Errorf("exec: %w", err) 75 | } 76 | 77 | rowsAffected, err = result.RowsAffected() 78 | if err != nil { 79 | return 0, fmt.Errorf("rows affected: %w", err) 80 | } 81 | 82 | return rowsAffected, nil 83 | } 84 | 85 | func (c *DeleteContext[T]) Do(db DB) (rowsAffected int64, err error) { 86 | return c.DoCtx(context.Background(), db) 87 | } 88 | 89 | func (c *DeleteContext[T]) buildQuery() (string, []ExprType, error) { 90 | args := []ExprType{} 91 | 92 | sb := strings.Builder{} 93 | 94 | str := "DELETE FROM " 95 | _, err := sb.WriteString(str) 96 | if err != nil { 97 | return "", nil, fmt.Errorf("write string(%s): %w", str, err) 98 | } 99 | 100 | str = c.table.TableName() 101 | _, err = sb.WriteString(str) 102 | if err != nil { 103 | return "", nil, fmt.Errorf("write string(%s): %w", str, err) 104 | } 105 | 106 | if c.whereCondition.exists() { 107 | whereQuery, whereArgs, err := c.whereCondition.getExpr() 108 | if err != nil { 109 | return "", nil, fmt.Errorf("where condition: %w", err) 110 | } 111 | 112 | str = " WHERE " 113 | _, err = sb.WriteString(str) 114 | if err != nil { 115 | return "", nil, fmt.Errorf("write string(%s): %w", str, err) 116 | } 117 | 118 | _, err = sb.WriteString(whereQuery) 119 | if err != nil { 120 | return "", nil, fmt.Errorf("write 
string(%s): %w", whereQuery, err) 121 | } 122 | 123 | args = append(args, whereArgs...) 124 | } 125 | 126 | if c.order.exists() { 127 | orderQuery, orderArgs, err := c.order.getExpr() 128 | if err != nil { 129 | return "", nil, fmt.Errorf("order: %w", err) 130 | } 131 | 132 | str = " " 133 | _, err = sb.WriteString(str) 134 | if err != nil { 135 | return "", nil, fmt.Errorf("write string(%s): %w", str, err) 136 | } 137 | 138 | _, err = sb.WriteString(orderQuery) 139 | if err != nil { 140 | return "", nil, fmt.Errorf("write string(%s): %w", orderQuery, err) 141 | } 142 | 143 | args = append(args, orderArgs...) 144 | } 145 | 146 | if c.limit.exists() { 147 | limitQuery, limitArgs, err := c.limit.getExpr() 148 | if err != nil { 149 | return "", nil, fmt.Errorf("limit: %w", err) 150 | } 151 | 152 | str = " " 153 | _, err = sb.WriteString(str) 154 | if err != nil { 155 | return "", nil, fmt.Errorf("write string(%s): %w", str, err) 156 | } 157 | 158 | _, err = sb.WriteString(limitQuery) 159 | if err != nil { 160 | return "", nil, fmt.Errorf("write string(%s): %w", limitQuery, err) 161 | } 162 | 163 | args = append(args, limitArgs...) 
164 | } 165 | 166 | return sb.String(), args, nil 167 | } 168 | -------------------------------------------------------------------------------- /delete_exporter_test.go: -------------------------------------------------------------------------------- 1 | package genorm 2 | 3 | func (c *DeleteContext[T]) BuildQuery() (string, []ExprType, error) { 4 | return c.buildQuery() 5 | } 6 | -------------------------------------------------------------------------------- /delete_test.go: -------------------------------------------------------------------------------- 1 | package genorm_test 2 | 3 | import ( 4 | "errors" 5 | "testing" 6 | 7 | "github.com/golang/mock/gomock" 8 | "github.com/mazrean/genorm" 9 | "github.com/mazrean/genorm/mock" 10 | "github.com/stretchr/testify/assert" 11 | ) 12 | 13 | func TestDeleteBuildQuery(t *testing.T) { 14 | t.Parallel() 15 | 16 | type expr struct { 17 | query string 18 | args []genorm.ExprType 19 | errs []error 20 | } 21 | 22 | type orderItem struct { 23 | direction genorm.OrderDirection 24 | expr expr 25 | } 26 | 27 | tests := []struct { 28 | description string 29 | tableName string 30 | whereCondition *expr 31 | orderItems []orderItem 32 | limit uint64 33 | query string 34 | args []genorm.ExprType 35 | err bool 36 | }{ 37 | { 38 | description: "normal", 39 | tableName: "hoge", 40 | query: "DELETE FROM hoge", 41 | args: []genorm.ExprType{}, 42 | }, 43 | { 44 | description: "where", 45 | tableName: "hoge", 46 | whereCondition: &expr{ 47 | query: "(hoge.huga = ?)", 48 | args: []genorm.ExprType{genorm.Wrap(1)}, 49 | }, 50 | query: "DELETE FROM hoge WHERE (hoge.huga = ?)", 51 | args: []genorm.ExprType{genorm.Wrap(1)}, 52 | }, 53 | { 54 | description: "where error", 55 | tableName: "hoge", 56 | whereCondition: &expr{ 57 | errs: []error{errors.New("where error")}, 58 | }, 59 | err: true, 60 | }, 61 | { 62 | description: "order by", 63 | tableName: "hoge", 64 | orderItems: []orderItem{ 65 | { 66 | direction: genorm.Asc, 67 | expr: expr{ 68 | 
query: "(hoge.huga = ?)", 69 | args: []genorm.ExprType{genorm.Wrap(1)}, 70 | }, 71 | }, 72 | }, 73 | query: "DELETE FROM hoge ORDER BY (hoge.huga = ?) ASC", 74 | args: []genorm.ExprType{genorm.Wrap(1)}, 75 | }, 76 | { 77 | description: "order by error", 78 | tableName: "hoge", 79 | orderItems: []orderItem{ 80 | { 81 | direction: genorm.Asc, 82 | expr: expr{ 83 | errs: []error{errors.New("order by error")}, 84 | }, 85 | }, 86 | }, 87 | err: true, 88 | }, 89 | { 90 | description: "multi order by", 91 | tableName: "hoge", 92 | orderItems: []orderItem{ 93 | { 94 | direction: genorm.Asc, 95 | expr: expr{ 96 | query: "(hoge.huga = ?)", 97 | args: []genorm.ExprType{genorm.Wrap(1)}, 98 | }, 99 | }, 100 | { 101 | direction: genorm.Desc, 102 | expr: expr{ 103 | query: "(hoge.nya = ?)", 104 | args: []genorm.ExprType{genorm.Wrap(2)}, 105 | }, 106 | }, 107 | }, 108 | query: "DELETE FROM hoge ORDER BY (hoge.huga = ?) ASC, (hoge.nya = ?) DESC", 109 | args: []genorm.ExprType{genorm.Wrap(1), genorm.Wrap(2)}, 110 | }, 111 | { 112 | description: "limit", 113 | tableName: "hoge", 114 | limit: 1, 115 | query: "DELETE FROM hoge LIMIT 1", 116 | args: []genorm.ExprType{}, 117 | }, 118 | } 119 | 120 | for _, test := range tests { 121 | t.Run(test.description, func(t *testing.T) { 122 | ctrl := gomock.NewController(t) 123 | 124 | table := mock.NewMockBasicTable(ctrl) 125 | table. 126 | EXPECT(). 127 | TableName(). 128 | Return(test.tableName) 129 | table. 130 | EXPECT(). 131 | GetErrors(). 132 | Return(nil) 133 | 134 | builder := genorm.Delete(table) 135 | 136 | if test.whereCondition != nil { 137 | mockExpr := mock.NewMockTypedTableExpr[*mock.MockBasicTable, genorm.WrappedPrimitive[bool]](ctrl) 138 | mockExpr. 139 | EXPECT(). 140 | Expr(). 
141 | Return(test.whereCondition.query, test.whereCondition.args, test.whereCondition.errs) 142 | 143 | builder = builder.Where(mockExpr) 144 | } 145 | 146 | for _, orderItem := range test.orderItems { 147 | mockExpr := mock.NewMockTypedTableExpr[*mock.MockBasicTable, genorm.WrappedPrimitive[bool]](ctrl) 148 | mockExpr. 149 | EXPECT(). 150 | Expr(). 151 | Return(orderItem.expr.query, orderItem.expr.args, orderItem.expr.errs) 152 | builder = builder.OrderBy(orderItem.direction, mockExpr) 153 | } 154 | 155 | if test.limit > 0 { 156 | builder = builder.Limit(test.limit) 157 | } 158 | 159 | query, args, err := builder.BuildQuery() 160 | 161 | if test.err { 162 | assert.Error(t, err) 163 | return 164 | } else { 165 | if !assert.NoError(t, err) { 166 | return 167 | } 168 | } 169 | 170 | assert.Equal(t, test.query, query) 171 | assert.Equal(t, test.args, args) 172 | }) 173 | } 174 | } 175 | -------------------------------------------------------------------------------- /errors.go: -------------------------------------------------------------------------------- 1 | package genorm 2 | 3 | import "errors" 4 | 5 | var ( 6 | ErrRecordNotFound = errors.New("record not found") 7 | ErrNullValue = errors.New("null value") 8 | ) 9 | -------------------------------------------------------------------------------- /expr.go: -------------------------------------------------------------------------------- 1 | package genorm 2 | 3 | type Expr interface { 4 | Expr() (string, []ExprType, []error) 5 | } 6 | 7 | type TableExpr[T Table] interface { 8 | Expr 9 | TableExpr(T) (string, []ExprType, []error) 10 | } 11 | 12 | type TypedExpr[T ExprType] interface { 13 | Expr 14 | TypedExpr(T) (string, []ExprType, []error) 15 | } 16 | 17 | type TypedTableExpr[T Table, S ExprType] interface { 18 | Expr 19 | TableExpr[T] 20 | TypedExpr[S] 21 | } 22 | 23 | type TableAssignExpr[T Table] struct { 24 | query string 25 | args []ExprType 26 | errs []error 27 | } 28 | 29 | func (tae *TableAssignExpr[_]) 
AssignExpr() (string, []ExprType, []error) { 30 | if len(tae.errs) != 0 { 31 | return "", nil, tae.errs 32 | } 33 | 34 | return tae.query, tae.args, nil 35 | } 36 | 37 | type ExprStruct[T Table, S ExprType] struct { 38 | query string 39 | args []ExprType 40 | errs []error 41 | } 42 | 43 | func RawExpr[T Table, S ExprType](query string, args ...ExprType) *ExprStruct[T, S] { 44 | return &ExprStruct[T, S]{ 45 | query: query, 46 | args: args, 47 | } 48 | } 49 | 50 | func (es *ExprStruct[_, _]) Expr() (string, []ExprType, []error) { 51 | if len(es.errs) != 0 { 52 | return "", nil, es.errs 53 | } 54 | 55 | return es.query, es.args, nil 56 | } 57 | 58 | func (es *ExprStruct[T, _]) TableExpr(T) (string, []ExprType, []error) { 59 | return es.Expr() 60 | } 61 | 62 | func (es *ExprStruct[_, S]) TypedExpr(S) (string, []ExprType, []error) { 63 | return es.Expr() 64 | } 65 | -------------------------------------------------------------------------------- /expr_exporter_test.go: -------------------------------------------------------------------------------- 1 | package genorm 2 | 3 | func NewTableAssignExpr[T Table]( 4 | query string, 5 | args []ExprType, 6 | errs []error, 7 | ) *TableAssignExpr[T] { 8 | return &TableAssignExpr[T]{ 9 | query: query, 10 | args: args, 11 | errs: errs, 12 | } 13 | } 14 | -------------------------------------------------------------------------------- /expr_mock_test.go: -------------------------------------------------------------------------------- 1 | package genorm_test 2 | 3 | import ( 4 | "github.com/mazrean/genorm" 5 | ) 6 | 7 | //go:generate go run github.com/golang/mock/mockgen@v1.6.0 -source=$GOFILE -destination=mock/expr_mock.go -package=mock 8 | 9 | type Expr interface { 10 | genorm.Expr 11 | } 12 | -------------------------------------------------------------------------------- /find.go: -------------------------------------------------------------------------------- 1 | package genorm 2 | 3 | import ( 4 | "context" 5 | "database/sql" 6 
| "errors" 7 | "fmt" 8 | "strings" 9 | ) 10 | 11 | type FindContext[S Table, T TuplePointer[U], U any] struct { 12 | *Context[S] 13 | distinct bool 14 | tuple T 15 | whereCondition whereConditionClause[S] 16 | groupExpr groupClause[S] 17 | havingCondition whereConditionClause[S] 18 | order orderClause[S] 19 | limit limitClause 20 | offset offsetClause 21 | lockType lockClause 22 | } 23 | 24 | func Find[T Table, S TuplePointer[U], U any](table T, tuple S) *FindContext[T, S, U] { 25 | return &FindContext[T, S, U]{ 26 | Context: newContext(table), 27 | tuple: tuple, 28 | } 29 | } 30 | 31 | func (c *FindContext[S, T, U]) Distinct() *FindContext[S, T, U] { 32 | if c.distinct { 33 | c.addError(errors.New("distinct already set")) 34 | return c 35 | } 36 | 37 | c.distinct = true 38 | 39 | return c 40 | } 41 | 42 | func (c *FindContext[S, T, U]) Where( 43 | condition TypedTableExpr[S, WrappedPrimitive[bool]], 44 | ) *FindContext[S, T, U] { 45 | err := c.whereCondition.set(condition) 46 | if err != nil { 47 | c.addError(fmt.Errorf("where condition: %w", err)) 48 | } 49 | 50 | return c 51 | } 52 | 53 | func (c *FindContext[S, T, U]) GroupBy(exprs ...TableExpr[S]) *FindContext[S, T, U] { 54 | err := c.groupExpr.set(exprs) 55 | if err != nil { 56 | c.addError(fmt.Errorf("group by: %w", err)) 57 | } 58 | 59 | return c 60 | } 61 | 62 | func (c *FindContext[S, T, U]) Having( 63 | condition TypedTableExpr[S, WrappedPrimitive[bool]], 64 | ) *FindContext[S, T, U] { 65 | err := c.havingCondition.set(condition) 66 | if err != nil { 67 | c.addError(fmt.Errorf("having condition: %w", err)) 68 | } 69 | 70 | return c 71 | } 72 | 73 | func (c *FindContext[S, T, U]) OrderBy(direction OrderDirection, expr TableExpr[S]) *FindContext[S, T, U] { 74 | err := c.order.add(orderItem[S]{ 75 | expr: expr, 76 | direction: direction, 77 | }) 78 | if err != nil { 79 | c.addError(fmt.Errorf("order by: %w", err)) 80 | } 81 | 82 | return c 83 | } 84 | 85 | func (c *FindContext[S, T, U]) Limit(limit uint64) 
*FindContext[S, T, U] { 86 | err := c.limit.set(limit) 87 | if err != nil { 88 | c.addError(fmt.Errorf("limit: %w", err)) 89 | } 90 | 91 | return c 92 | } 93 | 94 | func (c *FindContext[S, T, U]) Offset(offset uint64) *FindContext[S, T, U] { 95 | err := c.offset.set(offset) 96 | if err != nil { 97 | c.addError(fmt.Errorf("offset: %w", err)) 98 | } 99 | 100 | return c 101 | } 102 | 103 | func (c *FindContext[S, T, U]) Lock(lockType LockType) *FindContext[S, T, U] { 104 | err := c.lockType.set(lockType) 105 | if err != nil { 106 | c.addError(fmt.Errorf("lock: %w", err)) 107 | } 108 | 109 | return c 110 | } 111 | 112 | func (c *FindContext[S, T, U]) GetAllCtx(ctx context.Context, db DB) ([]T, error) { 113 | errs := c.Errors() 114 | if len(errs) != 0 { 115 | return nil, errs[0] 116 | } 117 | 118 | query, exprArgs, err := c.buildQuery() 119 | if err != nil { 120 | return nil, fmt.Errorf("build query: %w", err) 121 | } 122 | 123 | args := make([]any, 0, len(exprArgs)) 124 | for _, arg := range exprArgs { 125 | args = append(args, arg) 126 | } 127 | 128 | rows, err := db.QueryContext(ctx, query, args...) 129 | if errors.Is(err, sql.ErrNoRows) { 130 | return []T{}, nil 131 | } 132 | if err != nil { 133 | return nil, fmt.Errorf("query: %w", err) 134 | } 135 | defer rows.Close() 136 | 137 | exprs := []T{} 138 | for rows.Next() { 139 | var tuple U 140 | columns := T(&tuple).Columns() 141 | dests := make([]any, 0, len(columns)) 142 | for _, column := range columns { 143 | dests = append(dests, column) 144 | } 145 | 146 | err = rows.Scan(dests...) 
147 | if err != nil { 148 | return nil, fmt.Errorf("query: %w", err) 149 | } 150 | 151 | exprs = append(exprs, &tuple) 152 | } 153 | 154 | return exprs, nil 155 | } 156 | 157 | func (c *FindContext[S, T, U]) GetAll(db DB) ([]T, error) { 158 | return c.GetAllCtx(context.Background(), db) 159 | } 160 | 161 | func (c *FindContext[S, T, U]) GetCtx(ctx context.Context, db DB) (T, error) { 162 | err := c.limit.set(1) 163 | if err != nil { 164 | return nil, fmt.Errorf("set limit 1: %w", err) 165 | } 166 | 167 | errs := c.Errors() 168 | if len(errs) != 0 { 169 | return nil, errs[0] 170 | } 171 | 172 | query, queryArgs, err := c.buildQuery() 173 | if err != nil { 174 | return nil, fmt.Errorf("build query: %w", err) 175 | } 176 | 177 | args := make([]any, 0, len(queryArgs)) 178 | for _, arg := range queryArgs { 179 | args = append(args, arg) 180 | } 181 | 182 | row := db.QueryRowContext(ctx, query, args...) 183 | 184 | var tuple U 185 | columns := T(&tuple).Columns() 186 | dests := make([]any, 0, len(columns)) 187 | for _, column := range columns { 188 | dests = append(dests, column) 189 | } 190 | 191 | err = row.Scan(dests...) 
192 | if errors.Is(err, sql.ErrNoRows) { 193 | return nil, ErrRecordNotFound 194 | } 195 | if err != nil { 196 | return nil, fmt.Errorf("query: %w", err) 197 | } 198 | 199 | return &tuple, nil 200 | } 201 | 202 | func (c *FindContext[S, T, U]) Get(db DB) (T, error) { 203 | return c.GetCtx(context.Background(), db) 204 | } 205 | 206 | func (c *FindContext[S, T, U]) buildQuery() (string, []ExprType, error) { 207 | sb := strings.Builder{} 208 | args := []ExprType{} 209 | 210 | str := "SELECT " 211 | _, err := sb.WriteString(str) 212 | if err != nil { 213 | return "", nil, fmt.Errorf("write select(%s): %w", str, err) 214 | } 215 | 216 | if c.distinct { 217 | str = "DISTINCT " 218 | _, err = sb.WriteString(str) 219 | if err != nil { 220 | return "", nil, fmt.Errorf("write distinct(%s): %w", str, err) 221 | } 222 | } 223 | 224 | fields := c.tuple.Exprs() 225 | for i, field := range fields { 226 | if i != 0 { 227 | str = ", " 228 | _, err = sb.WriteString(str) 229 | if err != nil { 230 | return "", nil, fmt.Errorf("write comma(%s): %w", str, err) 231 | } 232 | } 233 | 234 | fieldQuery, fieldArgs, errs := field.Expr() 235 | if len(errs) != 0 { 236 | return "", nil, fmt.Errorf("field: %w", errs[0]) 237 | } 238 | 239 | _, err = sb.WriteString(fieldQuery) 240 | if err != nil { 241 | return "", nil, fmt.Errorf("write field(%s): %w", fieldQuery, err) 242 | } 243 | 244 | str = fmt.Sprintf(" AS value%d", i) 245 | _, err = sb.WriteString(str) 246 | if err != nil { 247 | return "", nil, fmt.Errorf("write as(%s): %w", str, err) 248 | } 249 | 250 | args = append(args, fieldArgs...) 
251 | } 252 | 253 | str = " FROM " 254 | _, err = sb.WriteString(str) 255 | if err != nil { 256 | return "", nil, fmt.Errorf("write from(%s): %w", str, err) 257 | } 258 | 259 | tableQuery, tableArgs, errs := c.table.Expr() 260 | if len(errs) != 0 { 261 | return "", nil, fmt.Errorf("table expr: %w", errs[0]) 262 | } 263 | 264 | _, err = sb.WriteString(tableQuery) 265 | if err != nil { 266 | return "", nil, fmt.Errorf("write table(%s): %w", tableQuery, err) 267 | } 268 | 269 | args = append(args, tableArgs...) 270 | 271 | if c.whereCondition.exists() { 272 | whereQuery, whereArgs, err := c.whereCondition.getExpr() 273 | if err != nil { 274 | return "", nil, fmt.Errorf("where condition: %w", err) 275 | } 276 | 277 | str = " WHERE " 278 | _, err = sb.WriteString(str) 279 | if err != nil { 280 | return "", nil, fmt.Errorf("write where(%s): %w", str, err) 281 | } 282 | 283 | _, err = sb.WriteString(whereQuery) 284 | if err != nil { 285 | return "", nil, fmt.Errorf("write where(%s): %w", whereQuery, err) 286 | } 287 | 288 | args = append(args, whereArgs...) 289 | } 290 | 291 | if c.groupExpr.exists() { 292 | groupExpr, groupArgs, err := c.groupExpr.getExpr() 293 | if err != nil { 294 | return "", nil, fmt.Errorf("group expr: %w", err) 295 | } 296 | 297 | str = " " 298 | _, err = sb.WriteString(str) 299 | if err != nil { 300 | return "", nil, fmt.Errorf("write string(%s): %w", str, err) 301 | } 302 | 303 | _, err = sb.WriteString(groupExpr) 304 | if err != nil { 305 | return "", nil, fmt.Errorf("write string(%s): %w", groupExpr, err) 306 | } 307 | 308 | args = append(args, groupArgs...) 
309 | } 310 | 311 | if c.havingCondition.exists() { 312 | havingQuery, havingArgs, err := c.havingCondition.getExpr() 313 | if err != nil { 314 | return "", nil, fmt.Errorf("having condition: %w", err) 315 | } 316 | 317 | str = " HAVING " 318 | _, err = sb.WriteString(str) 319 | if err != nil { 320 | return "", nil, fmt.Errorf("write having(%s): %w", str, err) 321 | } 322 | 323 | _, err = sb.WriteString(havingQuery) 324 | if err != nil { 325 | return "", nil, fmt.Errorf("write having(%s): %w", havingQuery, err) 326 | } 327 | 328 | args = append(args, havingArgs...) 329 | } 330 | 331 | if c.order.exists() { 332 | orderQuery, orderArgs, err := c.order.getExpr() 333 | if err != nil { 334 | return "", nil, fmt.Errorf("order: %w", err) 335 | } 336 | 337 | str = " " 338 | _, err = sb.WriteString(str) 339 | if err != nil { 340 | return "", nil, fmt.Errorf("write order(%s): %w", str, err) 341 | } 342 | 343 | _, err = sb.WriteString(orderQuery) 344 | if err != nil { 345 | return "", nil, fmt.Errorf("write order(%s): %w", orderQuery, err) 346 | } 347 | 348 | args = append(args, orderArgs...) 349 | } 350 | 351 | if c.limit.exists() { 352 | limitQuery, limitArgs, err := c.limit.getExpr() 353 | if err != nil { 354 | return "", nil, fmt.Errorf("limit: %w", err) 355 | } 356 | 357 | str = " " 358 | _, err = sb.WriteString(str) 359 | if err != nil { 360 | return "", nil, fmt.Errorf("write limit(%s): %w", str, err) 361 | } 362 | 363 | _, err = sb.WriteString(limitQuery) 364 | if err != nil { 365 | return "", nil, fmt.Errorf("write limit(%s): %w", limitQuery, err) 366 | } 367 | 368 | args = append(args, limitArgs...) 
369 | } 370 | 371 | if c.offset.exists() { 372 | offsetQuery, offsetArgs, err := c.offset.getExpr() 373 | if err != nil { 374 | return "", nil, fmt.Errorf("offset: %w", err) 375 | } 376 | 377 | str = " " 378 | _, err = sb.WriteString(str) 379 | if err != nil { 380 | return "", nil, fmt.Errorf("write offset(%s): %w", str, err) 381 | } 382 | 383 | _, err = sb.WriteString(offsetQuery) 384 | if err != nil { 385 | return "", nil, fmt.Errorf("write offset(%s): %w", offsetQuery, err) 386 | } 387 | 388 | args = append(args, offsetArgs...) 389 | } 390 | 391 | if c.lockType.exists() { 392 | lockQuery, lockArgs, err := c.lockType.getExpr() 393 | if err != nil { 394 | return "", nil, fmt.Errorf("lock type: %w", err) 395 | } 396 | 397 | str = " " 398 | _, err = sb.WriteString(str) 399 | if err != nil { 400 | return "", nil, fmt.Errorf("write lock(%s): %w", str, err) 401 | } 402 | 403 | _, err = sb.WriteString(lockQuery) 404 | if err != nil { 405 | return "", nil, fmt.Errorf("write lock(%s): %w", lockQuery, err) 406 | } 407 | 408 | args = append(args, lockArgs...) 
409 | } 410 | 411 | return sb.String(), args, nil 412 | } 413 | -------------------------------------------------------------------------------- /find_expoter_test.go: -------------------------------------------------------------------------------- 1 | package genorm 2 | 3 | func (c *FindContext[_, _, _]) BuildQuery() (string, []ExprType, error) { 4 | return c.buildQuery() 5 | } 6 | -------------------------------------------------------------------------------- /function.go: -------------------------------------------------------------------------------- 1 | package genorm 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | ) 7 | 8 | // Aggregate Functions 9 | 10 | /* 11 | Avg 12 | if (distinct) {return AVG(DISTINCT expr)} 13 | else {return AVG(expr)} 14 | */ 15 | func Avg[T Table, S ExprType](expr TypedTableExpr[T, S], distinct bool) TypedTableExpr[T, WrappedPrimitive[float64]] { 16 | if expr == nil { 17 | return &ExprStruct[T, WrappedPrimitive[float64]]{ 18 | errs: []error{errors.New("avg expr is nil")}, 19 | } 20 | } 21 | 22 | query, args, errs := expr.Expr() 23 | if len(errs) != 0 { 24 | return &ExprStruct[T, WrappedPrimitive[float64]]{ 25 | errs: errs, 26 | } 27 | } 28 | 29 | if distinct { 30 | query = fmt.Sprintf("AVG(DISTINCT %s)", query) 31 | } else { 32 | query = fmt.Sprintf("AVG(%s)", query) 33 | } 34 | 35 | return &ExprStruct[T, WrappedPrimitive[float64]]{ 36 | query: query, 37 | args: args, 38 | } 39 | } 40 | 41 | /* 42 | Count 43 | if (distinct) {return COUNT(DISTINCT expr)} 44 | else {return COUNT(expr)} 45 | */ 46 | func Count[T Table, S ExprType](expr TypedTableExpr[T, S], distinct bool) TypedTableExpr[T, WrappedPrimitive[int64]] { 47 | if expr == nil { 48 | return &ExprStruct[T, WrappedPrimitive[int64]]{ 49 | errs: []error{errors.New("count expr is nil")}, 50 | } 51 | } 52 | 53 | query, args, errs := expr.Expr() 54 | if len(errs) != 0 { 55 | return &ExprStruct[T, WrappedPrimitive[int64]]{ 56 | errs: errs, 57 | } 58 | } 59 | 60 | if distinct { 61 | query = 
fmt.Sprintf("COUNT(DISTINCT %s)", query) 62 | } else { 63 | query = fmt.Sprintf("COUNT(%s)", query) 64 | } 65 | 66 | return &ExprStruct[T, WrappedPrimitive[int64]]{ 67 | query: query, 68 | args: args, 69 | } 70 | } 71 | 72 | // Max MAX(expr) 73 | func Max[T Table, S ExprType](expr TypedTableExpr[T, S]) TypedTableExpr[T, S] { 74 | if expr == nil { 75 | return &ExprStruct[T, S]{ 76 | errs: []error{errors.New("max expr is nil")}, 77 | } 78 | } 79 | 80 | query, args, errs := expr.Expr() 81 | if len(errs) != 0 { 82 | return &ExprStruct[T, S]{ 83 | errs: errs, 84 | } 85 | } 86 | 87 | return &ExprStruct[T, S]{ 88 | query: fmt.Sprintf("MAX(%s)", query), 89 | args: args, 90 | } 91 | } 92 | 93 | // Min MIN(expr) 94 | func Min[T Table, S ExprType](expr TypedTableExpr[T, S]) TypedTableExpr[T, S] { 95 | if expr == nil { 96 | return &ExprStruct[T, S]{ 97 | errs: []error{errors.New("max expr is nil")}, 98 | } 99 | } 100 | 101 | query, args, errs := expr.Expr() 102 | if len(errs) != 0 { 103 | return &ExprStruct[T, S]{ 104 | errs: errs, 105 | } 106 | } 107 | 108 | return &ExprStruct[T, S]{ 109 | query: fmt.Sprintf("MIN(%s)", query), 110 | args: args, 111 | } 112 | } 113 | -------------------------------------------------------------------------------- /function_test.go: -------------------------------------------------------------------------------- 1 | package genorm_test 2 | 3 | import ( 4 | "errors" 5 | "testing" 6 | 7 | "github.com/golang/mock/gomock" 8 | "github.com/mazrean/genorm" 9 | "github.com/mazrean/genorm/mock" 10 | "github.com/stretchr/testify/assert" 11 | ) 12 | 13 | func TestAvg(t *testing.T) { 14 | t.Parallel() 15 | 16 | tests := []struct { 17 | description string 18 | exprIsNil bool 19 | exprQuery string 20 | exprArgs []genorm.ExprType 21 | exprErrs []error 22 | distinct bool 23 | expectedQuery string 24 | expectedArgs []genorm.ExprType 25 | isError bool 26 | }{ 27 | { 28 | description: "normal", 29 | exprQuery: "(hoge.huga = ?)", 30 | exprArgs: 
[]genorm.ExprType{genorm.Wrap(1)}, 31 | expectedQuery: "AVG((hoge.huga = ?))", 32 | expectedArgs: []genorm.ExprType{genorm.Wrap(1)}, 33 | }, 34 | { 35 | description: "distinct", 36 | exprQuery: "(hoge.huga = ?)", 37 | exprArgs: []genorm.ExprType{genorm.Wrap(1)}, 38 | distinct: true, 39 | expectedQuery: "AVG(DISTINCT (hoge.huga = ?))", 40 | expectedArgs: []genorm.ExprType{genorm.Wrap(1)}, 41 | }, 42 | { 43 | description: "nil expr", 44 | exprIsNil: true, 45 | isError: true, 46 | }, 47 | { 48 | description: "expr error", 49 | exprErrs: []error{errors.New("expr1 error")}, 50 | isError: true, 51 | }, 52 | { 53 | description: "expr1 no args", 54 | exprQuery: "(hoge.huga = hoge.huga)", 55 | exprArgs: nil, 56 | expectedQuery: "AVG((hoge.huga = hoge.huga))", 57 | expectedArgs: nil, 58 | }, 59 | } 60 | 61 | for _, test := range tests { 62 | t.Run(test.description, func(t *testing.T) { 63 | ctrl := gomock.NewController(t) 64 | 65 | var expr genorm.TypedTableExpr[*mock.MockTable, genorm.WrappedPrimitive[bool]] 66 | if test.exprIsNil { 67 | expr = nil 68 | } else { 69 | mockExpr := mock.NewMockTypedTableColumn[*mock.MockTable, genorm.WrappedPrimitive[bool]](ctrl) 70 | expr = mockExpr 71 | 72 | mockExpr. 73 | EXPECT(). 74 | Expr(). 
75 | Return(test.exprQuery, test.exprArgs, test.exprErrs) 76 | } 77 | 78 | res := genorm.Avg(expr, test.distinct) 79 | 80 | assert.NotNil(t, res) 81 | 82 | query, args, errs := res.Expr() 83 | if test.isError { 84 | assert.NotNil(t, errs) 85 | assert.NotEmpty(t, errs) 86 | 87 | return 88 | } 89 | 90 | if !assert.Nil(t, errs) { 91 | return 92 | } 93 | 94 | assert.Equal(t, test.expectedQuery, query) 95 | assert.Equal(t, test.expectedArgs, args) 96 | }) 97 | } 98 | } 99 | 100 | func TestCount(t *testing.T) { 101 | t.Parallel() 102 | 103 | tests := []struct { 104 | description string 105 | exprIsNil bool 106 | exprQuery string 107 | exprArgs []genorm.ExprType 108 | exprErrs []error 109 | distinct bool 110 | expectedQuery string 111 | expectedArgs []genorm.ExprType 112 | isError bool 113 | }{ 114 | { 115 | description: "normal", 116 | exprQuery: "(hoge.huga = ?)", 117 | exprArgs: []genorm.ExprType{genorm.Wrap(1)}, 118 | expectedQuery: "COUNT((hoge.huga = ?))", 119 | expectedArgs: []genorm.ExprType{genorm.Wrap(1)}, 120 | }, 121 | { 122 | description: "distinct", 123 | exprQuery: "(hoge.huga = ?)", 124 | exprArgs: []genorm.ExprType{genorm.Wrap(1)}, 125 | distinct: true, 126 | expectedQuery: "COUNT(DISTINCT (hoge.huga = ?))", 127 | expectedArgs: []genorm.ExprType{genorm.Wrap(1)}, 128 | }, 129 | { 130 | description: "nil expr", 131 | exprIsNil: true, 132 | isError: true, 133 | }, 134 | { 135 | description: "expr error", 136 | exprErrs: []error{errors.New("expr1 error")}, 137 | isError: true, 138 | }, 139 | { 140 | description: "expr1 no args", 141 | exprQuery: "(hoge.huga = hoge.huga)", 142 | exprArgs: nil, 143 | expectedQuery: "COUNT((hoge.huga = hoge.huga))", 144 | expectedArgs: nil, 145 | }, 146 | } 147 | 148 | for _, test := range tests { 149 | t.Run(test.description, func(t *testing.T) { 150 | ctrl := gomock.NewController(t) 151 | 152 | var expr genorm.TypedTableExpr[*mock.MockTable, genorm.WrappedPrimitive[bool]] 153 | if test.exprIsNil { 154 | expr = nil 155 | } else 
{ 156 | mockExpr := mock.NewMockTypedTableColumn[*mock.MockTable, genorm.WrappedPrimitive[bool]](ctrl) 157 | expr = mockExpr 158 | 159 | mockExpr. 160 | EXPECT(). 161 | Expr(). 162 | Return(test.exprQuery, test.exprArgs, test.exprErrs) 163 | } 164 | 165 | res := genorm.Count(expr, test.distinct) 166 | 167 | assert.NotNil(t, res) 168 | 169 | query, args, errs := res.Expr() 170 | if test.isError { 171 | assert.NotNil(t, errs) 172 | assert.NotEmpty(t, errs) 173 | 174 | return 175 | } 176 | 177 | if !assert.Nil(t, errs) { 178 | return 179 | } 180 | 181 | assert.Equal(t, test.expectedQuery, query) 182 | assert.Equal(t, test.expectedArgs, args) 183 | }) 184 | } 185 | } 186 | 187 | func TestMax(t *testing.T) { 188 | t.Parallel() 189 | 190 | tests := []struct { 191 | description string 192 | exprIsNil bool 193 | exprQuery string 194 | exprArgs []genorm.ExprType 195 | exprErrs []error 196 | expectedQuery string 197 | expectedArgs []genorm.ExprType 198 | isError bool 199 | }{ 200 | { 201 | description: "normal", 202 | exprQuery: "(hoge.huga = ?)", 203 | exprArgs: []genorm.ExprType{genorm.Wrap(1)}, 204 | expectedQuery: "MAX((hoge.huga = ?))", 205 | expectedArgs: []genorm.ExprType{genorm.Wrap(1)}, 206 | }, 207 | { 208 | description: "nil expr", 209 | exprIsNil: true, 210 | isError: true, 211 | }, 212 | { 213 | description: "expr error", 214 | exprErrs: []error{errors.New("expr1 error")}, 215 | isError: true, 216 | }, 217 | { 218 | description: "expr1 no args", 219 | exprQuery: "(hoge.huga = hoge.huga)", 220 | exprArgs: nil, 221 | expectedQuery: "MAX((hoge.huga = hoge.huga))", 222 | expectedArgs: nil, 223 | }, 224 | } 225 | 226 | for _, test := range tests { 227 | t.Run(test.description, func(t *testing.T) { 228 | ctrl := gomock.NewController(t) 229 | 230 | var expr genorm.TypedTableExpr[*mock.MockTable, genorm.WrappedPrimitive[bool]] 231 | if test.exprIsNil { 232 | expr = nil 233 | } else { 234 | mockExpr := mock.NewMockTypedTableColumn[*mock.MockTable, 
genorm.WrappedPrimitive[bool]](ctrl) 235 | expr = mockExpr 236 | 237 | mockExpr. 238 | EXPECT(). 239 | Expr(). 240 | Return(test.exprQuery, test.exprArgs, test.exprErrs) 241 | } 242 | 243 | res := genorm.Max(expr) 244 | 245 | assert.NotNil(t, res) 246 | 247 | query, args, errs := res.Expr() 248 | if test.isError { 249 | assert.NotNil(t, errs) 250 | assert.NotEmpty(t, errs) 251 | 252 | return 253 | } 254 | 255 | if !assert.Nil(t, errs) { 256 | return 257 | } 258 | 259 | assert.Equal(t, test.expectedQuery, query) 260 | assert.Equal(t, test.expectedArgs, args) 261 | }) 262 | } 263 | } 264 | 265 | func TestMin(t *testing.T) { 266 | t.Parallel() 267 | 268 | tests := []struct { 269 | description string 270 | exprIsNil bool 271 | exprQuery string 272 | exprArgs []genorm.ExprType 273 | exprErrs []error 274 | expectedQuery string 275 | expectedArgs []genorm.ExprType 276 | isError bool 277 | }{ 278 | { 279 | description: "normal", 280 | exprQuery: "(hoge.huga = ?)", 281 | exprArgs: []genorm.ExprType{genorm.Wrap(1)}, 282 | expectedQuery: "MIN((hoge.huga = ?))", 283 | expectedArgs: []genorm.ExprType{genorm.Wrap(1)}, 284 | }, 285 | { 286 | description: "nil expr", 287 | exprIsNil: true, 288 | isError: true, 289 | }, 290 | { 291 | description: "expr error", 292 | exprErrs: []error{errors.New("expr1 error")}, 293 | isError: true, 294 | }, 295 | { 296 | description: "expr1 no args", 297 | exprQuery: "(hoge.huga = hoge.huga)", 298 | exprArgs: nil, 299 | expectedQuery: "MIN((hoge.huga = hoge.huga))", 300 | expectedArgs: nil, 301 | }, 302 | } 303 | 304 | for _, test := range tests { 305 | t.Run(test.description, func(t *testing.T) { 306 | ctrl := gomock.NewController(t) 307 | 308 | var expr genorm.TypedTableExpr[*mock.MockTable, genorm.WrappedPrimitive[bool]] 309 | if test.exprIsNil { 310 | expr = nil 311 | } else { 312 | mockExpr := mock.NewMockTypedTableColumn[*mock.MockTable, genorm.WrappedPrimitive[bool]](ctrl) 313 | expr = mockExpr 314 | 315 | mockExpr. 316 | EXPECT(). 
317 | Expr(). 318 | Return(test.exprQuery, test.exprArgs, test.exprErrs) 319 | } 320 | 321 | res := genorm.Min(expr) 322 | 323 | assert.NotNil(t, res) 324 | 325 | query, args, errs := res.Expr() 326 | if test.isError { 327 | assert.NotNil(t, errs) 328 | assert.NotEmpty(t, errs) 329 | 330 | return 331 | } 332 | 333 | if !assert.Nil(t, errs) { 334 | return 335 | } 336 | 337 | assert.Equal(t, test.expectedQuery, query) 338 | assert.Equal(t, test.expectedArgs, args) 339 | }) 340 | } 341 | } 342 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/mazrean/genorm 2 | 3 | go 1.22.0 4 | 5 | require ( 6 | github.com/golang/mock v1.6.0 7 | github.com/stretchr/testify v1.10.0 8 | ) 9 | 10 | require ( 11 | golang.org/x/mod v0.23.0 // indirect 12 | golang.org/x/sync v0.11.0 // indirect 13 | ) 14 | 15 | require ( 16 | github.com/davecgh/go-spew v1.1.1 // indirect 17 | github.com/pmezard/go-difflib v1.0.0 // indirect 18 | golang.org/x/tools v0.30.0 19 | gopkg.in/yaml.v3 v3.0.1 // indirect 20 | ) 21 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= 2 | github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 3 | github.com/golang/mock v1.6.0 h1:ErTB+efbowRARo13NNdxyJji2egdxLGQhRaY+DUumQc= 4 | github.com/golang/mock v1.6.0/go.mod h1:p6yTPP+5HYm5mzsMV8JkE6ZKdX+/wYM6Hr+LicevLPs= 5 | github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= 6 | github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= 7 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= 8 | github.com/pmezard/go-difflib v1.0.0/go.mod 
h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 9 | github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= 10 | github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= 11 | github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= 12 | golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= 13 | golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= 14 | golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= 15 | golang.org/x/mod v0.23.0 h1:Zb7khfcRGKk+kqfxFaP5tZqCnDZMjC5VtUBs87Hr6QM= 16 | golang.org/x/mod v0.23.0/go.mod h1:6SkKJ3Xj0I0BrPOZoBy3bdMptDDU9oJrpohJ3eWZ1fY= 17 | golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= 18 | golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= 19 | golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM= 20 | golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= 21 | golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= 22 | golang.org/x/sync v0.11.0 h1:GGz8+XQP4FvTTrjZPzNKTMFtSXH80RAzG+5ghFPgK9w= 23 | golang.org/x/sync v0.11.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= 24 | golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= 25 | golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= 26 | golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= 27 | golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= 28 | golang.org/x/sys 
// InsertContext collects the state of an INSERT statement against a table of
// type T: the rows to insert and, optionally, the columns to write.
type InsertContext[T BasicTable] struct {
	*Context[T]
	// values holds the rows to insert.
	values []T
	// fields holds the explicitly selected columns; nil until they are set.
	fields []TableColumns[T]
}
BasicTable](table T) *InsertContext[T] { 17 | ctx := newContext(table) 18 | 19 | return &InsertContext[T]{ 20 | Context: ctx, 21 | fields: nil, 22 | } 23 | } 24 | 25 | func (c *InsertContext[T]) Values(tableBases ...T) *InsertContext[T] { 26 | if len(tableBases) == 0 { 27 | c.addError(errors.New("no values")) 28 | 29 | return c 30 | } 31 | if len(c.values) != 0 { 32 | c.addError(errors.New("values already set")) 33 | 34 | return c 35 | } 36 | 37 | c.values = append(c.values, tableBases...) 38 | 39 | return c 40 | } 41 | 42 | func (c *InsertContext[T]) Fields(fields ...TableColumns[T]) *InsertContext[T] { 43 | if c.fields != nil { 44 | c.addError(errors.New("fields already set")) 45 | return c 46 | } 47 | if len(fields) == 0 { 48 | c.addError(errors.New("no fields")) 49 | return c 50 | } 51 | 52 | fields = append(c.fields, fields...) 53 | fieldMap := make(map[TableColumns[T]]struct{}, len(fields)) 54 | for _, field := range fields { 55 | if _, ok := fieldMap[field]; ok { 56 | c.addError(errors.New("duplicate field")) 57 | return c 58 | } 59 | 60 | fieldMap[field] = struct{}{} 61 | } 62 | 63 | c.fields = fields 64 | 65 | return c 66 | } 67 | 68 | func (c *InsertContext[T]) DoCtx(ctx context.Context, db DB) (rowsAffected int64, err error) { 69 | errs := c.Errors() 70 | if len(errs) != 0 { 71 | return 0, errs[0] 72 | } 73 | 74 | query, args, err := c.buildQuery() 75 | if err != nil { 76 | return 0, fmt.Errorf("build query: %w", err) 77 | } 78 | 79 | result, err := db.ExecContext(ctx, query, args...) 
80 | if err != nil { 81 | return 0, fmt.Errorf("exec: %w", err) 82 | } 83 | 84 | rowsAffected, err = result.RowsAffected() 85 | if err != nil { 86 | return 0, fmt.Errorf("rows affected: %w", err) 87 | } 88 | 89 | return rowsAffected, nil 90 | } 91 | 92 | func (c *InsertContext[T]) Do(db DB) (rowsAffected int64, err error) { 93 | return c.DoCtx(context.Background(), db) 94 | } 95 | 96 | func (c *InsertContext[T]) buildQuery() (string, []any, error) { 97 | args := []any{} 98 | 99 | sb := &strings.Builder{} 100 | 101 | str := "INSERT INTO " 102 | _, err := sb.WriteString(str) 103 | if err != nil { 104 | return "", nil, fmt.Errorf("write string(%s): %w", str, err) 105 | } 106 | 107 | str = c.table.TableName() 108 | _, err = sb.WriteString(str) 109 | if err != nil { 110 | return "", nil, fmt.Errorf("write string(%s): %w", str, err) 111 | } 112 | 113 | var fields []string 114 | if c.fields == nil { 115 | columns := c.table.Columns() 116 | fields = make([]string, 0, len(columns)) 117 | for _, column := range columns { 118 | fields = append(fields, column.SQLColumnName()) 119 | } 120 | } else { 121 | fields = make([]string, 0, len(c.fields)) 122 | for _, field := range c.fields { 123 | fields = append(fields, field.SQLColumnName()) 124 | } 125 | } 126 | 127 | str = " (" 128 | _, err = sb.WriteString(str) 129 | if err != nil { 130 | return "", nil, fmt.Errorf("write string(%s): %w", str, err) 131 | } 132 | 133 | str = strings.Join(fields, ", ") 134 | _, err = sb.WriteString(str) 135 | if err != nil { 136 | return "", nil, fmt.Errorf("write string(%s): %w", str, err) 137 | } 138 | 139 | str = ") VALUES " 140 | _, err = sb.WriteString(str) 141 | if err != nil { 142 | return "", nil, fmt.Errorf("write string(%s): %w", str, err) 143 | } 144 | 145 | for i, value := range c.values { 146 | if i != 0 { 147 | str = ", " 148 | _, err = sb.WriteString(str) 149 | if err != nil { 150 | return "", nil, fmt.Errorf("write string(%s): %w", str, err) 151 | } 152 | } 153 | 154 | var err error 
155 | sb, args, err = c.buildValueList(sb, args, fields, value.ColumnMap()) 156 | if err != nil { 157 | return "", nil, fmt.Errorf("build value list: %w", err) 158 | } 159 | } 160 | 161 | return sb.String(), args, nil 162 | } 163 | 164 | func (c *InsertContext[T]) buildValueList(sb *strings.Builder, args []any, fields []string, fieldValueMap map[string]ColumnFieldExprType) (*strings.Builder, []any, error) { 165 | str := "(" 166 | _, err := sb.WriteString(str) 167 | if err != nil { 168 | return sb, args, fmt.Errorf("write string(%s): %w", str, err) 169 | } 170 | 171 | for i, columnName := range fields { 172 | if i != 0 { 173 | str = ", " 174 | _, err = sb.WriteString(str) 175 | if err != nil { 176 | return sb, args, fmt.Errorf("write string(%s): %w", str, err) 177 | } 178 | } 179 | 180 | columnField, ok := fieldValueMap[columnName] 181 | if !ok { 182 | return sb, nil, fmt.Errorf("field(%s) not found", columnName) 183 | } 184 | 185 | _, err := columnField.Value() 186 | if err != nil && !errors.Is(err, ErrNullValue) { 187 | return sb, nil, fmt.Errorf("failed to get field value: %w", err) 188 | } 189 | 190 | if errors.Is(err, ErrNullValue) { 191 | str = "NULL" 192 | _, err = sb.WriteString(str) 193 | if err != nil { 194 | return sb, nil, fmt.Errorf("write string(%s): %w", str, err) 195 | } 196 | } else { 197 | str = "?" 
198 | _, err = sb.WriteString(str) 199 | if err != nil { 200 | return sb, nil, fmt.Errorf("write string(%s): %w", str, err) 201 | } 202 | 203 | args = append(args, columnField) 204 | } 205 | } 206 | 207 | str = ")" 208 | _, err = sb.WriteString(str) 209 | if err != nil { 210 | return sb, nil, fmt.Errorf("write string(%s): %w", str, err) 211 | } 212 | 213 | return sb, args, nil 214 | } 215 | -------------------------------------------------------------------------------- /insert_exporter_test.go: -------------------------------------------------------------------------------- 1 | package genorm 2 | 3 | func (c *InsertContext[T]) BuildQuery() (string, []any, error) { 4 | return c.buildQuery() 5 | } 6 | -------------------------------------------------------------------------------- /insert_test.go: -------------------------------------------------------------------------------- 1 | package genorm_test 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/golang/mock/gomock" 7 | "github.com/mazrean/genorm" 8 | "github.com/mazrean/genorm/mock" 9 | "github.com/stretchr/testify/assert" 10 | ) 11 | 12 | func TestInsertBuildQuery(t *testing.T) { 13 | t.Parallel() 14 | 15 | columnFieldExpr1 := genorm.Wrap(1) 16 | columnFieldExpr2 := genorm.Wrap(2) 17 | columnFieldNull := genorm.WrappedPrimitive[int]{} 18 | 19 | tests := []struct { 20 | description string 21 | tableName string 22 | isFieldSet bool 23 | fields []string 24 | values []map[string]genorm.ColumnFieldExprType 25 | query string 26 | args []any 27 | err bool 28 | }{ 29 | { 30 | description: "normal", 31 | tableName: "hoge", 32 | fields: []string{"hoge.huga"}, 33 | values: []map[string]genorm.ColumnFieldExprType{ 34 | { 35 | "hoge.huga": &columnFieldExpr1, 36 | }, 37 | }, 38 | query: "INSERT INTO hoge (hoge.huga) VALUES (?)", 39 | args: []any{&columnFieldExpr1}, 40 | }, 41 | { 42 | description: "multi fields", 43 | tableName: "hoge", 44 | fields: []string{"hoge.huga", "hoge.piyo"}, 45 | values: 
[]map[string]genorm.ColumnFieldExprType{ 46 | { 47 | "hoge.huga": &columnFieldExpr1, 48 | "hoge.piyo": &columnFieldExpr2, 49 | }, 50 | }, 51 | query: "INSERT INTO hoge (hoge.huga, hoge.piyo) VALUES (?, ?)", 52 | args: []any{&columnFieldExpr1, &columnFieldExpr2}, 53 | }, 54 | { 55 | description: "multi values", 56 | tableName: "hoge", 57 | fields: []string{"hoge.huga"}, 58 | values: []map[string]genorm.ColumnFieldExprType{ 59 | { 60 | "hoge.huga": &columnFieldExpr1, 61 | }, 62 | { 63 | "hoge.huga": &columnFieldExpr2, 64 | }, 65 | }, 66 | query: "INSERT INTO hoge (hoge.huga) VALUES (?), (?)", 67 | args: []any{&columnFieldExpr1, &columnFieldExpr2}, 68 | }, 69 | { 70 | description: "null value", 71 | tableName: "hoge", 72 | fields: []string{"hoge.huga"}, 73 | values: []map[string]genorm.ColumnFieldExprType{ 74 | { 75 | "hoge.huga": &columnFieldNull, 76 | }, 77 | }, 78 | query: "INSERT INTO hoge (hoge.huga) VALUES (NULL)", 79 | args: []any{}, 80 | }, 81 | } 82 | 83 | for _, test := range tests { 84 | t.Run(test.description, func(t *testing.T) { 85 | ctrl := gomock.NewController(t) 86 | 87 | table := mock.NewMockBasicTable(ctrl) 88 | table. 89 | EXPECT(). 90 | TableName(). 91 | Return(test.tableName) 92 | table. 93 | EXPECT(). 94 | GetErrors(). 95 | Return(nil) 96 | 97 | builder := genorm.Insert(table) 98 | 99 | fields := make([]genorm.Column, 0, len(test.fields)) 100 | if test.isFieldSet { 101 | tableFields := make([]genorm.TableColumns[*mock.MockBasicTable], 0, len(test.fields)) 102 | for _, field := range test.fields { 103 | mockColumn := mock.NewMockTypedTableColumn[*mock.MockBasicTable, genorm.WrappedPrimitive[bool]](ctrl) 104 | mockColumn. 105 | EXPECT(). 106 | SQLColumnName(). 107 | Return(field) 108 | 109 | tableFields = append(tableFields, mockColumn) 110 | fields = append(fields, mockColumn) 111 | } 112 | 113 | builder.Fields(tableFields...) 114 | } else { 115 | for _, field := range test.fields { 116 | mockColumn := mock.NewMockColumn(ctrl) 117 | mockColumn. 
118 | EXPECT(). 119 | SQLColumnName(). 120 | Return(field) 121 | 122 | fields = append(fields, mockColumn) 123 | } 124 | 125 | table. 126 | EXPECT(). 127 | Columns(). 128 | Return(fields) 129 | } 130 | 131 | values := make([]*mock.MockBasicTable, 0, len(test.values)) 132 | for _, value := range test.values { 133 | table := mock.NewMockBasicTable(ctrl) 134 | table. 135 | EXPECT(). 136 | ColumnMap(). 137 | Return(value) 138 | 139 | values = append(values, table) 140 | } 141 | builder = builder.Values(values...) 142 | 143 | query, args, err := builder.BuildQuery() 144 | 145 | if test.err { 146 | assert.Error(t, err) 147 | return 148 | } else { 149 | if !assert.NoError(t, err) { 150 | return 151 | } 152 | } 153 | 154 | assert.Equal(t, test.query, query) 155 | assert.Equal(t, test.args, args) 156 | }) 157 | } 158 | } 159 | -------------------------------------------------------------------------------- /mock/column.go: -------------------------------------------------------------------------------- 1 | package mock 2 | 3 | import ( 4 | reflect "reflect" 5 | 6 | gomock "github.com/golang/mock/gomock" 7 | genorm "github.com/mazrean/genorm" 8 | ) 9 | 10 | // MockTypedTableColumn is a mock of Column interface. 11 | type MockTypedTableColumn[T genorm.Table, S genorm.ExprType] struct { 12 | ctrl *gomock.Controller 13 | recorder *MockTypedTableColumnMockRecorder[T, S] 14 | } 15 | 16 | // MockTypedTableColumnMockRecorder is the mock recorder for MockTypedTableColumn. 17 | type MockTypedTableColumnMockRecorder[T genorm.Table, S genorm.ExprType] struct { 18 | mock *MockTypedTableColumn[T, S] 19 | } 20 | 21 | // NewMockTypedTableColumn creates a new mock instance. 
22 | func NewMockTypedTableColumn[T genorm.Table, S genorm.ExprType](ctrl *gomock.Controller) *MockTypedTableColumn[T, S] { 23 | mock := &MockTypedTableColumn[T, S]{ctrl: ctrl} 24 | mock.recorder = &MockTypedTableColumnMockRecorder[T, S]{mock} 25 | return mock 26 | } 27 | 28 | // EXPECT returns an object that allows the caller to indicate expected use. 29 | func (m *MockTypedTableColumn[T, S]) EXPECT() *MockTypedTableColumnMockRecorder[T, S] { 30 | return m.recorder 31 | } 32 | 33 | // ColumnName mocks base method. 34 | func (m *MockTypedTableColumn[_, _]) ColumnName() string { 35 | m.ctrl.T.Helper() 36 | ret := m.ctrl.Call(m, "ColumnName") 37 | ret0, _ := ret[0].(string) 38 | return ret0 39 | } 40 | 41 | // ColumnName indicates an expected call of ColumnName. 42 | func (mr *MockTypedTableColumnMockRecorder[T, S]) ColumnName() *gomock.Call { 43 | mr.mock.ctrl.T.Helper() 44 | return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ColumnName", reflect.TypeOf((*MockTypedTableColumn[T, S])(nil).ColumnName)) 45 | } 46 | 47 | // Expr mocks base method. 48 | func (m *MockTypedTableColumn[_, _]) Expr() (string, []genorm.ExprType, []error) { 49 | m.ctrl.T.Helper() 50 | ret := m.ctrl.Call(m, "Expr") 51 | ret0, _ := ret[0].(string) 52 | ret1, _ := ret[1].([]genorm.ExprType) 53 | ret2, _ := ret[2].([]error) 54 | return ret0, ret1, ret2 55 | } 56 | 57 | // Expr indicates an expected call of Expr. 58 | func (mr *MockTypedTableColumnMockRecorder[T, S]) Expr() *gomock.Call { 59 | mr.mock.ctrl.T.Helper() 60 | return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Expr", reflect.TypeOf((*MockTypedTableColumn[T, S])(nil).Expr)) 61 | } 62 | 63 | // TableExpr mocks base method. 
64 | func (m *MockTypedTableColumn[T, _]) TableExpr(t T) (string, []genorm.ExprType, []error) { 65 | m.ctrl.T.Helper() 66 | ret := m.ctrl.Call(m, "TableExpr", t) 67 | ret0, _ := ret[0].(string) 68 | ret1, _ := ret[1].([]genorm.ExprType) 69 | ret2, _ := ret[2].([]error) 70 | return ret0, ret1, ret2 71 | } 72 | 73 | // TableExpr indicates an expected call of Expr. 74 | func (mr *MockTypedTableColumnMockRecorder[T, S]) TableExpr(t T) *gomock.Call { 75 | mr.mock.ctrl.T.Helper() 76 | return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TableExpr", reflect.TypeOf((*MockTypedTableColumn[T, S])(nil).Expr), t) 77 | } 78 | 79 | // TypedExpr mocks base method. 80 | func (m *MockTypedTableColumn[_, S]) TypedExpr(s S) (string, []genorm.ExprType, []error) { 81 | m.ctrl.T.Helper() 82 | ret := m.ctrl.Call(m, "TypedExpr", s) 83 | ret0, _ := ret[0].(string) 84 | ret1, _ := ret[1].([]genorm.ExprType) 85 | ret2, _ := ret[2].([]error) 86 | return ret0, ret1, ret2 87 | } 88 | 89 | // Expr indicates an expected call of Expr. 90 | func (mr *MockTypedTableColumnMockRecorder[T, S]) TypedExpr(s S) *gomock.Call { 91 | mr.mock.ctrl.T.Helper() 92 | return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TypedExpr", reflect.TypeOf((*MockTypedTableColumn[T, S])(nil).Expr), s) 93 | } 94 | 95 | // SQLColumnName mocks base method. 96 | func (m *MockTypedTableColumn[_, _]) SQLColumnName() string { 97 | m.ctrl.T.Helper() 98 | ret := m.ctrl.Call(m, "SQLColumnName") 99 | ret0, _ := ret[0].(string) 100 | return ret0 101 | } 102 | 103 | // SQLColumnName indicates an expected call of SQLColumnName. 104 | func (mr *MockTypedTableColumnMockRecorder[T, S]) SQLColumnName() *gomock.Call { 105 | mr.mock.ctrl.T.Helper() 106 | return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SQLColumnName", reflect.TypeOf((*MockTypedTableColumn[T, S])(nil).SQLColumnName)) 107 | } 108 | 109 | // TableName mocks base method. 
110 | func (m *MockTypedTableColumn[_, _]) TableName() string { 111 | m.ctrl.T.Helper() 112 | ret := m.ctrl.Call(m, "TableName") 113 | ret0, _ := ret[0].(string) 114 | return ret0 115 | } 116 | 117 | // TableName indicates an expected call of TableName. 118 | func (mr *MockTypedTableColumnMockRecorder[T, S]) TableName() *gomock.Call { 119 | mr.mock.ctrl.T.Helper() 120 | return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TableName", reflect.TypeOf((*MockTypedTableColumn[T, S])(nil).TableName)) 121 | } 122 | -------------------------------------------------------------------------------- /mock/expr.go: -------------------------------------------------------------------------------- 1 | package mock 2 | 3 | import ( 4 | reflect "reflect" 5 | 6 | gomock "github.com/golang/mock/gomock" 7 | genorm "github.com/mazrean/genorm" 8 | ) 9 | 10 | // MockTypedTableExpr is a mock of Expr interface. 11 | type MockTypedTableExpr[T genorm.Table, S genorm.ExprType] struct { 12 | ctrl *gomock.Controller 13 | recorder *MockTypedTableExprMockRecorder[T, S] 14 | } 15 | 16 | // MockTypedTableExprMockRecorder is the mock recorder for MockTypedTableExpr. 17 | type MockTypedTableExprMockRecorder[T genorm.Table, S genorm.ExprType] struct { 18 | mock *MockTypedTableExpr[T, S] 19 | } 20 | 21 | // NewMockTypedTableExpr creates a new mock instance. 22 | func NewMockTypedTableExpr[T genorm.Table, S genorm.ExprType](ctrl *gomock.Controller) *MockTypedTableExpr[T, S] { 23 | mock := &MockTypedTableExpr[T, S]{ctrl: ctrl} 24 | mock.recorder = &MockTypedTableExprMockRecorder[T, S]{mock} 25 | return mock 26 | } 27 | 28 | // EXPECT returns an object that allows the caller to indicate expected use. 29 | func (m *MockTypedTableExpr[T, S]) EXPECT() *MockTypedTableExprMockRecorder[T, S] { 30 | return m.recorder 31 | } 32 | 33 | // Expr mocks base method. 
34 | func (m *MockTypedTableExpr[_, _]) Expr() (string, []genorm.ExprType, []error) { 35 | m.ctrl.T.Helper() 36 | ret := m.ctrl.Call(m, "Expr") 37 | ret0, _ := ret[0].(string) 38 | ret1, _ := ret[1].([]genorm.ExprType) 39 | ret2, _ := ret[2].([]error) 40 | return ret0, ret1, ret2 41 | } 42 | 43 | // Expr indicates an expected call of Expr. 44 | func (mr *MockTypedTableExprMockRecorder[T, S]) Expr() *gomock.Call { 45 | mr.mock.ctrl.T.Helper() 46 | return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Expr", reflect.TypeOf((*MockTypedTableExpr[T, S])(nil).Expr)) 47 | } 48 | 49 | // TableExpr mocks base method. 50 | func (m *MockTypedTableExpr[T, _]) TableExpr(t T) (string, []genorm.ExprType, []error) { 51 | m.ctrl.T.Helper() 52 | ret := m.ctrl.Call(m, "TableExpr", t) 53 | ret0, _ := ret[0].(string) 54 | ret1, _ := ret[1].([]genorm.ExprType) 55 | ret2, _ := ret[2].([]error) 56 | return ret0, ret1, ret2 57 | } 58 | 59 | // TableExpr indicates an expected call of Expr. 60 | func (mr *MockTypedTableExprMockRecorder[T, S]) TableExpr(t T) *gomock.Call { 61 | mr.mock.ctrl.T.Helper() 62 | return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TableExpr", reflect.TypeOf((*MockTypedTableExpr[T, S])(nil).Expr), t) 63 | } 64 | 65 | // TypedExpr mocks base method. 66 | func (m *MockTypedTableExpr[_, S]) TypedExpr(s S) (string, []genorm.ExprType, []error) { 67 | m.ctrl.T.Helper() 68 | ret := m.ctrl.Call(m, "TypedExpr", s) 69 | ret0, _ := ret[0].(string) 70 | ret1, _ := ret[1].([]genorm.ExprType) 71 | ret2, _ := ret[2].([]error) 72 | return ret0, ret1, ret2 73 | } 74 | 75 | // TableExpr indicates an expected call of Expr. 
76 | func (mr *MockTypedTableExprMockRecorder[T, S]) TypedExpr(s S) *gomock.Call { 77 | mr.mock.ctrl.T.Helper() 78 | return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TypedExpr", reflect.TypeOf((*MockTypedTableExpr[T, S])(nil).Expr), s) 79 | } 80 | -------------------------------------------------------------------------------- /pluck.go: -------------------------------------------------------------------------------- 1 | package genorm 2 | 3 | import ( 4 | "context" 5 | "database/sql" 6 | "errors" 7 | "fmt" 8 | "strings" 9 | ) 10 | 11 | type PluckContext[T Table, S ExprType] struct { 12 | *Context[T] 13 | distinct bool 14 | field TypedTableExpr[T, S] 15 | whereCondition whereConditionClause[T] 16 | groupExpr groupClause[T] 17 | havingCondition whereConditionClause[T] 18 | order orderClause[T] 19 | limit limitClause 20 | offset offsetClause 21 | lockType lockClause 22 | } 23 | 24 | func Pluck[T Table, S ExprType](table T, field TypedTableExpr[T, S]) *PluckContext[T, S] { 25 | return &PluckContext[T, S]{ 26 | Context: newContext(table), 27 | field: field, 28 | } 29 | } 30 | 31 | func (c *PluckContext[T, S]) Distinct() *PluckContext[T, S] { 32 | if c.distinct { 33 | c.addError(errors.New("distinct already set")) 34 | return c 35 | } 36 | 37 | c.distinct = true 38 | 39 | return c 40 | } 41 | 42 | func (c *PluckContext[T, S]) Where( 43 | condition TypedTableExpr[T, WrappedPrimitive[bool]], 44 | ) *PluckContext[T, S] { 45 | err := c.whereCondition.set(condition) 46 | if err != nil { 47 | c.addError(fmt.Errorf("where condition: %w", err)) 48 | } 49 | 50 | return c 51 | } 52 | 53 | func (c *PluckContext[T, S]) GroupBy(exprs ...TableExpr[T]) *PluckContext[T, S] { 54 | err := c.groupExpr.set(exprs) 55 | if err != nil { 56 | c.addError(fmt.Errorf("group by: %w", err)) 57 | } 58 | 59 | return c 60 | } 61 | 62 | func (c *PluckContext[T, S]) Having( 63 | condition TypedTableExpr[T, WrappedPrimitive[bool]], 64 | ) *PluckContext[T, S] { 65 | err := 
c.havingCondition.set(condition) 66 | if err != nil { 67 | c.addError(fmt.Errorf("having condition: %w", err)) 68 | } 69 | 70 | return c 71 | } 72 | 73 | func (c *PluckContext[T, S]) OrderBy(direction OrderDirection, expr TableExpr[T]) *PluckContext[T, S] { 74 | err := c.order.add(orderItem[T]{ 75 | expr: expr, 76 | direction: direction, 77 | }) 78 | if err != nil { 79 | c.addError(fmt.Errorf("order by: %w", err)) 80 | } 81 | 82 | return c 83 | } 84 | 85 | func (c *PluckContext[T, S]) Limit(limit uint64) *PluckContext[T, S] { 86 | err := c.limit.set(limit) 87 | if err != nil { 88 | c.addError(fmt.Errorf("limit: %w", err)) 89 | } 90 | 91 | return c 92 | } 93 | 94 | func (c *PluckContext[T, S]) Offset(offset uint64) *PluckContext[T, S] { 95 | err := c.offset.set(offset) 96 | if err != nil { 97 | c.addError(fmt.Errorf("offset: %w", err)) 98 | } 99 | 100 | return c 101 | } 102 | 103 | func (c *PluckContext[T, S]) Lock(lockType LockType) *PluckContext[T, S] { 104 | err := c.lockType.set(lockType) 105 | if err != nil { 106 | c.addError(fmt.Errorf("lock: %w", err)) 107 | } 108 | 109 | return c 110 | } 111 | 112 | func (c *PluckContext[T, S]) GetAllCtx(ctx context.Context, db DB) ([]S, error) { 113 | errs := c.Errors() 114 | if len(errs) != 0 { 115 | return nil, errs[0] 116 | } 117 | 118 | query, exprArgs, err := c.buildQuery() 119 | if err != nil { 120 | return nil, fmt.Errorf("build query: %w", err) 121 | } 122 | 123 | args := make([]any, 0, len(exprArgs)) 124 | for _, arg := range exprArgs { 125 | args = append(args, arg) 126 | } 127 | 128 | rows, err := db.QueryContext(ctx, query, args...) 
129 | if errors.Is(err, sql.ErrNoRows) { 130 | return []S{}, nil 131 | } 132 | if err != nil { 133 | return nil, fmt.Errorf("query: %w", err) 134 | } 135 | defer rows.Close() 136 | 137 | exprs := []S{} 138 | for rows.Next() { 139 | var expr S 140 | 141 | err := rows.Scan(&expr) 142 | if err != nil { 143 | return nil, fmt.Errorf("scan: %w", err) 144 | } 145 | 146 | exprs = append(exprs, expr) 147 | } 148 | 149 | return exprs, nil 150 | } 151 | 152 | func (c *PluckContext[T, S]) GetAll(db DB) ([]S, error) { 153 | return c.GetAllCtx(context.Background(), db) 154 | } 155 | 156 | func (c *PluckContext[T, S]) GetCtx(ctx context.Context, db DB) (S, error) { 157 | var res S 158 | 159 | err := c.limit.set(1) 160 | if err != nil { 161 | return res, fmt.Errorf("set limit 1: %w", err) 162 | } 163 | 164 | errs := c.Errors() 165 | if len(errs) != 0 { 166 | return res, errs[0] 167 | } 168 | 169 | query, queryArgs, err := c.buildQuery() 170 | if err != nil { 171 | return res, fmt.Errorf("build query: %w", err) 172 | } 173 | 174 | args := make([]any, 0, len(queryArgs)) 175 | for _, arg := range queryArgs { 176 | args = append(args, arg) 177 | } 178 | 179 | row := db.QueryRowContext(ctx, query, args...) 
180 | 181 | err = row.Scan(&res) 182 | if errors.Is(err, sql.ErrNoRows) { 183 | return res, ErrRecordNotFound 184 | } 185 | if err != nil { 186 | return res, fmt.Errorf("query: %w", err) 187 | } 188 | 189 | return res, nil 190 | } 191 | 192 | func (c *PluckContext[T, S]) Get(db DB) (S, error) { 193 | return c.GetCtx(context.Background(), db) 194 | } 195 | 196 | func (c *PluckContext[T, S]) buildQuery() (string, []ExprType, error) { 197 | sb := strings.Builder{} 198 | args := []ExprType{} 199 | 200 | str := "SELECT " 201 | _, err := sb.WriteString(str) 202 | if err != nil { 203 | return "", nil, fmt.Errorf("write select(%s): %w", str, err) 204 | } 205 | 206 | if c.distinct { 207 | str = "DISTINCT " 208 | _, err = sb.WriteString(str) 209 | if err != nil { 210 | return "", nil, fmt.Errorf("write distinct(%s): %w", str, err) 211 | } 212 | } 213 | 214 | fieldQuery, fieldArgs, errs := c.field.Expr() 215 | if len(errs) != 0 { 216 | return "", nil, fmt.Errorf("field: %w", errs[0]) 217 | } 218 | 219 | _, err = sb.WriteString(fieldQuery) 220 | if err != nil { 221 | return "", nil, fmt.Errorf("write field(%s): %w", fieldQuery, err) 222 | } 223 | 224 | str = " AS res" 225 | _, err = sb.WriteString(str) 226 | if err != nil { 227 | return "", nil, fmt.Errorf("write as(%s): %w", str, err) 228 | } 229 | 230 | args = append(args, fieldArgs...) 231 | 232 | str = " FROM " 233 | _, err = sb.WriteString(str) 234 | if err != nil { 235 | return "", nil, fmt.Errorf("write from(%s): %w", str, err) 236 | } 237 | 238 | tableQuery, tableArgs, errs := c.table.Expr() 239 | if len(errs) != 0 { 240 | return "", nil, fmt.Errorf("table expr: %w", errs[0]) 241 | } 242 | 243 | _, err = sb.WriteString(tableQuery) 244 | if err != nil { 245 | return "", nil, fmt.Errorf("write table(%s): %w", tableQuery, err) 246 | } 247 | 248 | args = append(args, tableArgs...) 
249 | 250 | if c.whereCondition.exists() { 251 | whereQuery, whereArgs, err := c.whereCondition.getExpr() 252 | if err != nil { 253 | return "", nil, fmt.Errorf("where condition: %w", err) 254 | } 255 | 256 | str = " WHERE " 257 | _, err = sb.WriteString(str) 258 | if err != nil { 259 | return "", nil, fmt.Errorf("write where(%s): %w", str, err) 260 | } 261 | 262 | _, err = sb.WriteString(whereQuery) 263 | if err != nil { 264 | return "", nil, fmt.Errorf("write where(%s): %w", whereQuery, err) 265 | } 266 | 267 | args = append(args, whereArgs...) 268 | } 269 | 270 | if c.groupExpr.exists() { 271 | groupExpr, groupArgs, err := c.groupExpr.getExpr() 272 | if err != nil { 273 | return "", nil, fmt.Errorf("group expr: %w", err) 274 | } 275 | 276 | str = " " 277 | _, err = sb.WriteString(str) 278 | if err != nil { 279 | return "", nil, fmt.Errorf("write string(%s): %w", str, err) 280 | } 281 | 282 | _, err = sb.WriteString(groupExpr) 283 | if err != nil { 284 | return "", nil, fmt.Errorf("write string(%s): %w", groupExpr, err) 285 | } 286 | 287 | args = append(args, groupArgs...) 288 | } 289 | 290 | if c.havingCondition.exists() { 291 | havingQuery, havingArgs, err := c.havingCondition.getExpr() 292 | if err != nil { 293 | return "", nil, fmt.Errorf("having condition: %w", err) 294 | } 295 | 296 | str = " HAVING " 297 | _, err = sb.WriteString(str) 298 | if err != nil { 299 | return "", nil, fmt.Errorf("write having(%s): %w", str, err) 300 | } 301 | 302 | _, err = sb.WriteString(havingQuery) 303 | if err != nil { 304 | return "", nil, fmt.Errorf("write having(%s): %w", havingQuery, err) 305 | } 306 | 307 | args = append(args, havingArgs...) 
308 | } 309 | 310 | if c.order.exists() { 311 | orderQuery, orderArgs, err := c.order.getExpr() 312 | if err != nil { 313 | return "", nil, fmt.Errorf("order: %w", err) 314 | } 315 | 316 | str = " " 317 | _, err = sb.WriteString(str) 318 | if err != nil { 319 | return "", nil, fmt.Errorf("write order(%s): %w", str, err) 320 | } 321 | 322 | _, err = sb.WriteString(orderQuery) 323 | if err != nil { 324 | return "", nil, fmt.Errorf("write order(%s): %w", orderQuery, err) 325 | } 326 | 327 | args = append(args, orderArgs...) 328 | } 329 | 330 | if c.limit.exists() { 331 | limitQuery, limitArgs, err := c.limit.getExpr() 332 | if err != nil { 333 | return "", nil, fmt.Errorf("limit: %w", err) 334 | } 335 | 336 | str = " " 337 | _, err = sb.WriteString(str) 338 | if err != nil { 339 | return "", nil, fmt.Errorf("write limit(%s): %w", str, err) 340 | } 341 | 342 | _, err = sb.WriteString(limitQuery) 343 | if err != nil { 344 | return "", nil, fmt.Errorf("write limit(%s): %w", limitQuery, err) 345 | } 346 | 347 | args = append(args, limitArgs...) 348 | } 349 | 350 | if c.offset.exists() { 351 | offsetQuery, offsetArgs, err := c.offset.getExpr() 352 | if err != nil { 353 | return "", nil, fmt.Errorf("offset: %w", err) 354 | } 355 | 356 | str = " " 357 | _, err = sb.WriteString(str) 358 | if err != nil { 359 | return "", nil, fmt.Errorf("write offset(%s): %w", str, err) 360 | } 361 | 362 | _, err = sb.WriteString(offsetQuery) 363 | if err != nil { 364 | return "", nil, fmt.Errorf("write offset(%s): %w", offsetQuery, err) 365 | } 366 | 367 | args = append(args, offsetArgs...) 
368 | } 369 | 370 | if c.lockType.exists() { 371 | lockQuery, lockArgs, err := c.lockType.getExpr() 372 | if err != nil { 373 | return "", nil, fmt.Errorf("lock type: %w", err) 374 | } 375 | 376 | str = " " 377 | _, err = sb.WriteString(str) 378 | if err != nil { 379 | return "", nil, fmt.Errorf("write lock(%s): %w", str, err) 380 | } 381 | 382 | _, err = sb.WriteString(lockQuery) 383 | if err != nil { 384 | return "", nil, fmt.Errorf("write lock(%s): %w", lockQuery, err) 385 | } 386 | 387 | args = append(args, lockArgs...) 388 | } 389 | 390 | return sb.String(), args, nil 391 | } 392 | -------------------------------------------------------------------------------- /pluck_expoter_test.go: -------------------------------------------------------------------------------- 1 | package genorm 2 | 3 | func (c *PluckContext[_, _]) BuildQuery() (string, []ExprType, error) { 4 | return c.buildQuery() 5 | } 6 | -------------------------------------------------------------------------------- /pluck_test.go: -------------------------------------------------------------------------------- 1 | package genorm_test 2 | 3 | import ( 4 | "errors" 5 | "testing" 6 | 7 | "github.com/golang/mock/gomock" 8 | "github.com/mazrean/genorm" 9 | "github.com/mazrean/genorm/mock" 10 | "github.com/stretchr/testify/assert" 11 | ) 12 | 13 | func TestPluckBuildQuery(t *testing.T) { 14 | t.Parallel() 15 | 16 | type expr struct { 17 | query string 18 | args []genorm.ExprType 19 | errs []error 20 | } 21 | 22 | type orderItem struct { 23 | direction genorm.OrderDirection 24 | expr expr 25 | } 26 | 27 | tests := []struct { 28 | description string 29 | tableExpr expr 30 | distinct bool 31 | fieldExpr expr 32 | groupExprs []expr 33 | havingCondition *expr 34 | whereCondition *expr 35 | orderItems []orderItem 36 | limit uint64 37 | offset uint64 38 | lockType genorm.LockType 39 | query string 40 | args []genorm.ExprType 41 | err bool 42 | }{ 43 | { 44 | description: "normal", 45 | tableExpr: expr{ 46 | 
query: "hoge", 47 | }, 48 | fieldExpr: expr{ 49 | query: "hoge.huga", 50 | }, 51 | query: "SELECT hoge.huga AS res FROM hoge", 52 | args: []genorm.ExprType{}, 53 | }, 54 | { 55 | description: "table error", 56 | tableExpr: expr{ 57 | errs: []error{errors.New("table error")}, 58 | }, 59 | fieldExpr: expr{ 60 | query: "hoge.huga", 61 | }, 62 | err: true, 63 | }, 64 | { 65 | description: "joined table", 66 | tableExpr: expr{ 67 | query: "hoge JOIN fuga ON hoge.id = fuga.id AND hoge.huga = ?", 68 | args: []genorm.ExprType{genorm.Wrap(1)}, 69 | }, 70 | fieldExpr: expr{ 71 | query: "hoge.huga", 72 | }, 73 | query: "SELECT hoge.huga AS res FROM hoge JOIN fuga ON hoge.id = fuga.id AND hoge.huga = ?", 74 | args: []genorm.ExprType{genorm.Wrap(1)}, 75 | }, 76 | { 77 | description: "field error", 78 | tableExpr: expr{ 79 | query: "hoge", 80 | }, 81 | fieldExpr: expr{ 82 | errs: []error{errors.New("field error")}, 83 | }, 84 | err: true, 85 | }, 86 | { 87 | description: "distinct", 88 | tableExpr: expr{ 89 | query: "hoge", 90 | }, 91 | distinct: true, 92 | fieldExpr: expr{ 93 | query: "hoge.huga", 94 | }, 95 | query: "SELECT DISTINCT hoge.huga AS res FROM hoge", 96 | args: []genorm.ExprType{}, 97 | }, 98 | { 99 | description: "group by", 100 | tableExpr: expr{ 101 | query: "hoge", 102 | }, 103 | fieldExpr: expr{ 104 | query: "hoge.huga", 105 | }, 106 | groupExprs: []expr{ 107 | { 108 | query: "hoge.fuga", 109 | }, 110 | }, 111 | query: "SELECT hoge.huga AS res FROM hoge GROUP BY hoge.fuga", 112 | args: []genorm.ExprType{}, 113 | }, 114 | { 115 | description: "group by with args", 116 | tableExpr: expr{ 117 | query: "hoge", 118 | }, 119 | fieldExpr: expr{ 120 | query: "hoge.huga", 121 | }, 122 | groupExprs: []expr{ 123 | { 124 | query: "hoge.fuga = ?", 125 | args: []genorm.ExprType{genorm.Wrap(1)}, 126 | }, 127 | }, 128 | query: "SELECT hoge.huga AS res FROM hoge GROUP BY hoge.fuga = ?", 129 | args: []genorm.ExprType{genorm.Wrap(1)}, 130 | }, 131 | { 132 | description: "group by 
error", 133 | tableExpr: expr{ 134 | query: "hoge", 135 | }, 136 | fieldExpr: expr{ 137 | query: "hoge.huga", 138 | }, 139 | groupExprs: []expr{ 140 | { 141 | errs: []error{errors.New("group error")}, 142 | }, 143 | }, 144 | err: true, 145 | }, 146 | { 147 | description: "multiple group by", 148 | tableExpr: expr{ 149 | query: "hoge", 150 | }, 151 | fieldExpr: expr{ 152 | query: "hoge.huga", 153 | }, 154 | groupExprs: []expr{ 155 | { 156 | query: "hoge.fuga = ?", 157 | args: []genorm.ExprType{genorm.Wrap(1)}, 158 | }, 159 | { 160 | query: "hoge.piyo = ?", 161 | args: []genorm.ExprType{genorm.Wrap(2)}, 162 | }, 163 | }, 164 | query: "SELECT hoge.huga AS res FROM hoge GROUP BY hoge.fuga = ?, hoge.piyo = ?", 165 | args: []genorm.ExprType{genorm.Wrap(1), genorm.Wrap(2)}, 166 | }, 167 | { 168 | description: "having", 169 | tableExpr: expr{ 170 | query: "hoge", 171 | }, 172 | fieldExpr: expr{ 173 | query: "hoge.huga", 174 | }, 175 | groupExprs: []expr{ 176 | { 177 | query: "hoge.fuga", 178 | }, 179 | }, 180 | havingCondition: &expr{ 181 | query: "hoge.huga = ?", 182 | args: []genorm.ExprType{genorm.Wrap(1)}, 183 | }, 184 | query: "SELECT hoge.huga AS res FROM hoge GROUP BY hoge.fuga HAVING hoge.huga = ?", 185 | args: []genorm.ExprType{genorm.Wrap(1)}, 186 | }, 187 | { 188 | description: "having error", 189 | tableExpr: expr{ 190 | query: "hoge", 191 | }, 192 | fieldExpr: expr{ 193 | query: "hoge.huga", 194 | }, 195 | groupExprs: []expr{ 196 | { 197 | query: "hoge.fuga", 198 | }, 199 | }, 200 | havingCondition: &expr{ 201 | errs: []error{errors.New("having error")}, 202 | }, 203 | err: true, 204 | }, 205 | { 206 | description: "where", 207 | tableExpr: expr{ 208 | query: "hoge", 209 | }, 210 | fieldExpr: expr{ 211 | query: "hoge.huga", 212 | }, 213 | whereCondition: &expr{ 214 | query: "(hoge.huga = ?)", 215 | args: []genorm.ExprType{genorm.Wrap(1)}, 216 | }, 217 | query: "SELECT hoge.huga AS res FROM hoge WHERE (hoge.huga = ?)", 218 | args: 
[]genorm.ExprType{genorm.Wrap(1)}, 219 | }, 220 | { 221 | description: "where error", 222 | tableExpr: expr{ 223 | query: "hoge", 224 | }, 225 | fieldExpr: expr{ 226 | query: "hoge.huga", 227 | }, 228 | whereCondition: &expr{ 229 | errs: []error{errors.New("where error")}, 230 | }, 231 | err: true, 232 | }, 233 | { 234 | description: "order by", 235 | tableExpr: expr{ 236 | query: "hoge", 237 | }, 238 | fieldExpr: expr{ 239 | query: "hoge.huga", 240 | }, 241 | orderItems: []orderItem{ 242 | { 243 | direction: genorm.Asc, 244 | expr: expr{ 245 | query: "(hoge.huga = ?)", 246 | args: []genorm.ExprType{genorm.Wrap(1)}, 247 | }, 248 | }, 249 | }, 250 | query: "SELECT hoge.huga AS res FROM hoge ORDER BY (hoge.huga = ?) ASC", 251 | args: []genorm.ExprType{genorm.Wrap(1)}, 252 | }, 253 | { 254 | description: "order by error", 255 | tableExpr: expr{ 256 | query: "hoge", 257 | }, 258 | fieldExpr: expr{ 259 | query: "hoge.huga", 260 | }, 261 | orderItems: []orderItem{ 262 | { 263 | direction: genorm.Asc, 264 | expr: expr{ 265 | errs: []error{errors.New("order by error")}, 266 | }, 267 | }, 268 | }, 269 | err: true, 270 | }, 271 | { 272 | description: "multi order by", 273 | tableExpr: expr{ 274 | query: "hoge", 275 | }, 276 | fieldExpr: expr{ 277 | query: "hoge.huga", 278 | }, 279 | orderItems: []orderItem{ 280 | { 281 | direction: genorm.Asc, 282 | expr: expr{ 283 | query: "(hoge.huga = ?)", 284 | args: []genorm.ExprType{genorm.Wrap(1)}, 285 | }, 286 | }, 287 | { 288 | direction: genorm.Desc, 289 | expr: expr{ 290 | query: "(hoge.nya = ?)", 291 | args: []genorm.ExprType{genorm.Wrap(2)}, 292 | }, 293 | }, 294 | }, 295 | query: "SELECT hoge.huga AS res FROM hoge ORDER BY (hoge.huga = ?) ASC, (hoge.nya = ?) 
DESC", 296 | args: []genorm.ExprType{genorm.Wrap(1), genorm.Wrap(2)}, 297 | }, 298 | { 299 | description: "limit", 300 | tableExpr: expr{ 301 | query: "hoge", 302 | }, 303 | fieldExpr: expr{ 304 | query: "hoge.huga", 305 | }, 306 | limit: 1, 307 | query: "SELECT hoge.huga AS res FROM hoge LIMIT 1", 308 | args: []genorm.ExprType{}, 309 | }, 310 | { 311 | description: "offset", 312 | tableExpr: expr{ 313 | query: "hoge", 314 | }, 315 | fieldExpr: expr{ 316 | query: "hoge.huga", 317 | }, 318 | offset: 1, 319 | query: "SELECT hoge.huga AS res FROM hoge OFFSET 1", 320 | args: []genorm.ExprType{}, 321 | }, 322 | { 323 | description: "for update", 324 | tableExpr: expr{ 325 | query: "hoge", 326 | }, 327 | fieldExpr: expr{ 328 | query: "hoge.huga", 329 | }, 330 | lockType: genorm.ForUpdate, 331 | query: "SELECT hoge.huga AS res FROM hoge FOR UPDATE", 332 | args: []genorm.ExprType{}, 333 | }, 334 | { 335 | description: "for share", 336 | tableExpr: expr{ 337 | query: "hoge", 338 | }, 339 | fieldExpr: expr{ 340 | query: "hoge.huga", 341 | }, 342 | lockType: genorm.ForShare, 343 | query: "SELECT hoge.huga AS res FROM hoge FOR SHARE", 344 | args: []genorm.ExprType{}, 345 | }, 346 | } 347 | 348 | for _, test := range tests { 349 | t.Run(test.description, func(t *testing.T) { 350 | ctrl := gomock.NewController(t) 351 | 352 | table := mock.NewMockTable(ctrl) 353 | if len(test.fieldExpr.errs) == 0 { 354 | table. 355 | EXPECT(). 356 | Expr(). 357 | Return(test.tableExpr.query, test.tableExpr.args, test.tableExpr.errs) 358 | } 359 | table. 360 | EXPECT(). 361 | GetErrors(). 362 | Return(nil) 363 | 364 | mockField := mock.NewMockTypedTableExpr[*mock.MockTable, genorm.WrappedPrimitive[int]](ctrl) 365 | mockField. 366 | EXPECT(). 367 | Expr(). 
368 | Return(test.fieldExpr.query, test.fieldExpr.args, test.fieldExpr.errs) 369 | var field genorm.TypedTableExpr[*mock.MockTable, genorm.WrappedPrimitive[int]] = mockField 370 | 371 | builder := genorm.Pluck(table, field) 372 | 373 | if test.distinct { 374 | builder = builder.Distinct() 375 | } 376 | 377 | if test.groupExprs != nil { 378 | groupExprs := make([]genorm.TableExpr[*mock.MockTable], 0, len(test.groupExprs)) 379 | for _, groupExpr := range test.groupExprs { 380 | mockExpr := mock.NewMockTypedTableExpr[*mock.MockTable, genorm.WrappedPrimitive[bool]](ctrl) 381 | mockExpr. 382 | EXPECT(). 383 | Expr(). 384 | Return(groupExpr.query, groupExpr.args, groupExpr.errs) 385 | 386 | groupExprs = append(groupExprs, mockExpr) 387 | } 388 | 389 | builder = builder.GroupBy(groupExprs...) 390 | } 391 | 392 | if test.havingCondition != nil { 393 | mockExpr := mock.NewMockTypedTableExpr[*mock.MockTable, genorm.WrappedPrimitive[bool]](ctrl) 394 | mockExpr. 395 | EXPECT(). 396 | Expr(). 397 | Return(test.havingCondition.query, test.havingCondition.args, test.havingCondition.errs) 398 | 399 | builder = builder.Having(mockExpr) 400 | } 401 | 402 | if test.whereCondition != nil { 403 | mockExpr := mock.NewMockTypedTableExpr[*mock.MockTable, genorm.WrappedPrimitive[bool]](ctrl) 404 | mockExpr. 405 | EXPECT(). 406 | Expr(). 407 | Return(test.whereCondition.query, test.whereCondition.args, test.whereCondition.errs) 408 | 409 | builder = builder.Where(mockExpr) 410 | } 411 | 412 | for _, orderItem := range test.orderItems { 413 | mockExpr := mock.NewMockTypedTableExpr[*mock.MockTable, genorm.WrappedPrimitive[bool]](ctrl) 414 | mockExpr. 415 | EXPECT(). 416 | Expr(). 
417 | Return(orderItem.expr.query, orderItem.expr.args, orderItem.expr.errs) 418 | builder = builder.OrderBy(orderItem.direction, mockExpr) 419 | } 420 | 421 | if test.limit > 0 { 422 | builder = builder.Limit(test.limit) 423 | } 424 | 425 | if test.offset > 0 { 426 | builder = builder.Offset(test.offset) 427 | } 428 | 429 | if test.lockType != 0 { 430 | builder = builder.Lock(test.lockType) 431 | } 432 | 433 | query, args, err := builder.BuildQuery() 434 | 435 | if test.err { 436 | assert.Error(t, err) 437 | return 438 | } else { 439 | if !assert.NoError(t, err) { 440 | return 441 | } 442 | } 443 | 444 | assert.Equal(t, test.query, query) 445 | assert.Equal(t, test.args, args) 446 | }) 447 | } 448 | } 449 | -------------------------------------------------------------------------------- /relation/relation.go: -------------------------------------------------------------------------------- 1 | package relation 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | "strings" 7 | 8 | "github.com/mazrean/genorm" 9 | ) 10 | 11 | //nolint:revive 12 | type RelationContext[S Table, T Table, _ JoinedTablePointer[V], V any] struct { 13 | baseTable S 14 | refTable T 15 | } 16 | 17 | func NewRelationContext[S Table, T Table, U JoinedTablePointer[V], V any](baseTable S, refTable T) *RelationContext[S, T, U, V] { 18 | return &RelationContext[S, T, U, V]{ 19 | baseTable: baseTable, 20 | refTable: refTable, 21 | } 22 | } 23 | 24 | // Join INNER JOIN(CROSS JOIN) 25 | func (r *RelationContext[S, T, U, V]) Join( 26 | expr genorm.TypedTableExpr[U, genorm.WrappedPrimitive[bool]], 27 | ) U { 28 | var joinedTable V 29 | 30 | relation, err := newRelation(join, r.baseTable, r.refTable, expr) 31 | if err != nil { 32 | U(&joinedTable).AddError(err) 33 | return &joinedTable 34 | } 35 | 36 | U(&joinedTable).SetRelation(relation) 37 | 38 | return &joinedTable 39 | } 40 | 41 | // LeftJoin LEFT JOIN 42 | func (r *RelationContext[S, T, U, V]) LeftJoin( 43 | expr genorm.TypedTableExpr[U, 
genorm.WrappedPrimitive[bool]], 44 | ) U { 45 | var joinedTable V 46 | 47 | relation, err := newRelation(leftJoin, r.baseTable, r.refTable, expr) 48 | if err != nil { 49 | U(&joinedTable).AddError(err) 50 | return &joinedTable 51 | } 52 | 53 | U(&joinedTable).SetRelation(relation) 54 | 55 | return &joinedTable 56 | } 57 | 58 | // RightJoin RIGHT JOIN 59 | func (r *RelationContext[S, T, U, V]) RightJoin( 60 | expr genorm.TypedTableExpr[U, genorm.WrappedPrimitive[bool]], 61 | ) U { 62 | var joinedTable V 63 | 64 | relation, err := newRelation(rightJoin, r.baseTable, r.refTable, expr) 65 | if err != nil { 66 | U(&joinedTable).AddError(err) 67 | return &joinedTable 68 | } 69 | 70 | U(&joinedTable).SetRelation(relation) 71 | 72 | return &joinedTable 73 | } 74 | 75 | type Relation struct { 76 | relationType RelationType 77 | baseTable Table 78 | refTable Table 79 | onExpr genorm.Expr 80 | } 81 | 82 | func newRelation(relationType RelationType, baseTable, refTable Table, expr genorm.Expr) (*Relation, error) { 83 | if err := relationType.validate(); err != nil { 84 | return nil, fmt.Errorf("validate relation type: %w", err) 85 | } 86 | 87 | return &Relation{ 88 | relationType: relationType, 89 | baseTable: baseTable, 90 | refTable: refTable, 91 | onExpr: expr, 92 | }, nil 93 | } 94 | 95 | func (r *Relation) JoinedTableName() (string, []genorm.ExprType, []error) { 96 | sb := strings.Builder{} 97 | args := []genorm.ExprType{} 98 | 99 | str := "(" 100 | _, err := sb.WriteString(str) 101 | if err != nil { 102 | return "", nil, []error{fmt.Errorf("write string(%s): %w", str, err)} 103 | } 104 | 105 | baseTableQuery, baseTableArgs, errs := r.baseTable.Expr() 106 | if len(errs) != 0 { 107 | return "", nil, errs 108 | } 109 | 110 | _, err = sb.WriteString(baseTableQuery) 111 | if err != nil { 112 | return "", nil, []error{fmt.Errorf("write string(%s): %w", baseTableQuery, err)} 113 | } 114 | 115 | args = append(args, baseTableArgs...) 
116 | 117 | switch r.relationType { 118 | case join: 119 | if r.onExpr != nil { 120 | str = " INNER JOIN " 121 | _, err = sb.WriteString(str) 122 | if err != nil { 123 | return "", nil, []error{fmt.Errorf("write string(%s): %w", str, err)} 124 | } 125 | } else { 126 | str = " CROSS JOIN " 127 | _, err = sb.WriteString(str) 128 | if err != nil { 129 | return "", nil, []error{fmt.Errorf("write string(%s): %w", str, err)} 130 | } 131 | } 132 | case leftJoin: 133 | str = " LEFT JOIN " 134 | _, err = sb.WriteString(str) 135 | if err != nil { 136 | return "", nil, []error{fmt.Errorf("write string(%s): %w", str, err)} 137 | } 138 | case rightJoin: 139 | str = " RIGHT JOIN " 140 | _, err = sb.WriteString(str) 141 | if err != nil { 142 | return "", nil, []error{fmt.Errorf("write string(%s): %w", str, err)} 143 | } 144 | default: 145 | return "", nil, []error{errors.New("unsupported relation type")} 146 | } 147 | 148 | refTableQuery, refTableArgs, errs := r.refTable.Expr() 149 | if len(errs) != 0 { 150 | return "", nil, errs 151 | } 152 | 153 | _, err = sb.WriteString(refTableQuery) 154 | if err != nil { 155 | return "", nil, []error{fmt.Errorf("write string(%s): %w", refTableQuery, err)} 156 | } 157 | 158 | args = append(args, refTableArgs...) 159 | 160 | if r.onExpr != nil { 161 | str = " ON " 162 | _, err = sb.WriteString(str) 163 | if err != nil { 164 | return "", nil, []error{fmt.Errorf("write string(%s): %w", str, err)} 165 | } 166 | 167 | onExprQuery, onExprArgs, errs := r.onExpr.Expr() 168 | if len(errs) != 0 { 169 | return "", nil, errs 170 | } 171 | 172 | _, err = sb.WriteString(onExprQuery) 173 | if err != nil { 174 | return "", nil, []error{fmt.Errorf("write string(%s): %w", onExprQuery, err)} 175 | } 176 | 177 | args = append(args, onExprArgs...) 
178 | } 179 | 180 | str = ")" 181 | _, err = sb.WriteString(str) 182 | if err != nil { 183 | return "", nil, []error{fmt.Errorf("write string(%s): %w", str, err)} 184 | } 185 | 186 | return sb.String(), args, nil 187 | } 188 | 189 | //nolint:revive 190 | type RelationType int8 191 | 192 | const ( 193 | join RelationType = iota + 1 194 | leftJoin 195 | rightJoin 196 | ) 197 | 198 | func (rt RelationType) validate() error { 199 | if rt != join && rt != leftJoin && rt != rightJoin { 200 | return errors.New("unsupported relation type") 201 | } 202 | 203 | return nil 204 | } 205 | -------------------------------------------------------------------------------- /relation/relation_test.go: -------------------------------------------------------------------------------- 1 | package relation 2 | 3 | import ( 4 | "errors" 5 | "testing" 6 | 7 | "github.com/golang/mock/gomock" 8 | "github.com/mazrean/genorm" 9 | "github.com/mazrean/genorm/mock" 10 | "github.com/stretchr/testify/assert" 11 | ) 12 | 13 | func TestJoinedTableName(t *testing.T) { 14 | t.Parallel() 15 | 16 | type expr struct { 17 | query string 18 | args []genorm.ExprType 19 | errs []error 20 | } 21 | 22 | tests := []struct { 23 | description string 24 | relationType RelationType 25 | baseExpr expr 26 | refTableExpr expr 27 | onExpr *expr 28 | query string 29 | args []genorm.ExprType 30 | err bool 31 | }{ 32 | { 33 | description: "normal", 34 | relationType: join, 35 | baseExpr: expr{ 36 | query: "hoge", 37 | }, 38 | refTableExpr: expr{ 39 | query: "fuga", 40 | }, 41 | query: "(hoge CROSS JOIN fuga)", 42 | args: []genorm.ExprType{}, 43 | }, 44 | { 45 | description: "onExpr", 46 | relationType: join, 47 | baseExpr: expr{ 48 | query: "hoge", 49 | }, 50 | refTableExpr: expr{ 51 | query: "fuga", 52 | }, 53 | onExpr: &expr{ 54 | query: "(hoge.id = fuga.id)", 55 | }, 56 | query: "(hoge INNER JOIN fuga ON (hoge.id = fuga.id))", 57 | args: []genorm.ExprType{}, 58 | }, 59 | { 60 | description: "onExpr with args", 61 | 
relationType: join, 62 | baseExpr: expr{ 63 | query: "hoge", 64 | }, 65 | refTableExpr: expr{ 66 | query: "fuga", 67 | }, 68 | onExpr: &expr{ 69 | query: "(hoge.id = ?)", 70 | args: []genorm.ExprType{genorm.Wrap(1)}, 71 | }, 72 | query: "(hoge INNER JOIN fuga ON (hoge.id = ?))", 73 | args: []genorm.ExprType{genorm.Wrap(1)}, 74 | }, 75 | { 76 | description: "onExpr with error", 77 | relationType: join, 78 | baseExpr: expr{ 79 | query: "hoge", 80 | }, 81 | refTableExpr: expr{ 82 | query: "fuga", 83 | }, 84 | onExpr: &expr{ 85 | errs: []error{errors.New("onExpr error")}, 86 | }, 87 | err: true, 88 | }, 89 | { 90 | description: "baseTable with args", 91 | relationType: join, 92 | baseExpr: expr{ 93 | query: "(hoge INNER JOIN fuga ON (fuga.id = ?))", 94 | args: []genorm.ExprType{genorm.Wrap(1)}, 95 | }, 96 | refTableExpr: expr{ 97 | query: "fuga", 98 | }, 99 | query: "((hoge INNER JOIN fuga ON (fuga.id = ?)) CROSS JOIN fuga)", 100 | args: []genorm.ExprType{genorm.Wrap(1)}, 101 | }, 102 | { 103 | description: "baseTable with error", 104 | relationType: join, 105 | baseExpr: expr{ 106 | errs: []error{errors.New("error")}, 107 | }, 108 | refTableExpr: expr{ 109 | query: "fuga", 110 | }, 111 | err: true, 112 | }, 113 | { 114 | description: "refTable with args", 115 | relationType: join, 116 | baseExpr: expr{ 117 | query: "hoge", 118 | }, 119 | refTableExpr: expr{ 120 | query: "(fuga INNER JOIN piyo ON (piyo.id = ?))", 121 | args: []genorm.ExprType{genorm.Wrap(1)}, 122 | }, 123 | query: "(hoge CROSS JOIN (fuga INNER JOIN piyo ON (piyo.id = ?)))", 124 | args: []genorm.ExprType{genorm.Wrap(1)}, 125 | }, 126 | { 127 | description: "refTable with error", 128 | relationType: join, 129 | baseExpr: expr{ 130 | query: "hoge", 131 | }, 132 | refTableExpr: expr{ 133 | errs: []error{errors.New("error")}, 134 | }, 135 | err: true, 136 | }, 137 | { 138 | description: "left join", 139 | relationType: leftJoin, 140 | baseExpr: expr{ 141 | query: "hoge", 142 | }, 143 | refTableExpr: expr{ 
144 | query: "fuga", 145 | }, 146 | onExpr: &expr{ 147 | query: "(hoge.id = fuga.id)", 148 | }, 149 | query: "(hoge LEFT JOIN fuga ON (hoge.id = fuga.id))", 150 | args: []genorm.ExprType{}, 151 | }, 152 | { 153 | description: "right join", 154 | relationType: rightJoin, 155 | baseExpr: expr{ 156 | query: "hoge", 157 | }, 158 | refTableExpr: expr{ 159 | query: "fuga", 160 | }, 161 | onExpr: &expr{ 162 | query: "(hoge.id = fuga.id)", 163 | }, 164 | query: "(hoge RIGHT JOIN fuga ON (hoge.id = fuga.id))", 165 | args: []genorm.ExprType{}, 166 | }, 167 | } 168 | 169 | for _, test := range tests { 170 | t.Run(test.description, func(t *testing.T) { 171 | ctrl := gomock.NewController(t) 172 | 173 | baseTable := mock.NewMockTable(ctrl) 174 | baseTable. 175 | EXPECT(). 176 | Expr(). 177 | Return(test.baseExpr.query, test.baseExpr.args, test.baseExpr.errs) 178 | 179 | refTable := mock.NewMockTable(ctrl) 180 | if len(test.baseExpr.errs) == 0 { 181 | refTable. 182 | EXPECT(). 183 | Expr(). 184 | Return(test.refTableExpr.query, test.refTableExpr.args, test.refTableExpr.errs) 185 | } 186 | 187 | var onExpr genorm.Expr 188 | if test.onExpr != nil { 189 | mockExpr := mock.NewMockExpr(ctrl) 190 | mockExpr. 191 | EXPECT(). 192 | Expr(). 
193 | Return(test.onExpr.query, test.onExpr.args, test.onExpr.errs) 194 | 195 | onExpr = mockExpr 196 | } 197 | 198 | relation := &Relation{ 199 | relationType: test.relationType, 200 | baseTable: baseTable, 201 | refTable: refTable, 202 | onExpr: onExpr, 203 | } 204 | 205 | query, args, errs := relation.JoinedTableName() 206 | 207 | if test.err { 208 | assert.Greater(t, len(errs), 0) 209 | return 210 | } else if !assert.Len(t, errs, 0) { 211 | return 212 | } 213 | 214 | assert.Equal(t, test.query, query) 215 | assert.Equal(t, test.args, args) 216 | }) 217 | } 218 | } 219 | -------------------------------------------------------------------------------- /relation/table.go: -------------------------------------------------------------------------------- 1 | package relation 2 | 3 | import ( 4 | "github.com/mazrean/genorm" 5 | ) 6 | 7 | type Table interface { 8 | genorm.Table 9 | } 10 | 11 | type BasicTable interface { 12 | Table 13 | genorm.BasicTable 14 | } 15 | 16 | type JoinedTable interface { 17 | Table 18 | genorm.JoinedTable 19 | SetRelation(*Relation) 20 | } 21 | 22 | type JoinedTablePointer[T any] interface { 23 | JoinedTable 24 | *T 25 | } 26 | -------------------------------------------------------------------------------- /select.go: -------------------------------------------------------------------------------- 1 | package genorm 2 | 3 | import ( 4 | "context" 5 | "database/sql" 6 | "errors" 7 | "fmt" 8 | "strings" 9 | ) 10 | 11 | type SelectContext[S any, T TablePointer[S]] struct { 12 | *Context[T] 13 | distinct bool 14 | fields []TableColumns[T] 15 | whereCondition whereConditionClause[T] 16 | groupExpr groupClause[T] 17 | havingCondition whereConditionClause[T] 18 | order orderClause[T] 19 | limit limitClause 20 | offset offsetClause 21 | lockType lockClause 22 | } 23 | 24 | func Select[S any, T TablePointer[S]](table T) *SelectContext[S, T] { 25 | return &SelectContext[S, T]{ 26 | Context: newContext(table), 27 | } 28 | } 29 | 30 | func (c 
*SelectContext[S, T]) Distinct() *SelectContext[S, T] { 31 | if c.distinct { 32 | c.addError(errors.New("distinct already set")) 33 | return c 34 | } 35 | 36 | c.distinct = true 37 | 38 | return c 39 | } 40 | 41 | func (c *SelectContext[S, T]) Fields(fields ...TableColumns[T]) *SelectContext[S, T] { 42 | if c.fields != nil { 43 | c.addError(errors.New("fields already set")) 44 | return c 45 | } 46 | if len(fields) == 0 { 47 | c.addError(errors.New("no fields")) 48 | return c 49 | } 50 | 51 | fields = append(c.fields, fields...) 52 | fieldMap := make(map[TableColumns[T]]struct{}, len(fields)) 53 | for _, field := range fields { 54 | if _, ok := fieldMap[field]; ok { 55 | c.addError(errors.New("duplicate field")) 56 | return c 57 | } 58 | 59 | fieldMap[field] = struct{}{} 60 | } 61 | 62 | c.fields = fields 63 | 64 | return c 65 | } 66 | 67 | func (c *SelectContext[S, T]) Where( 68 | condition TypedTableExpr[T, WrappedPrimitive[bool]], 69 | ) *SelectContext[S, T] { 70 | err := c.whereCondition.set(condition) 71 | if err != nil { 72 | c.addError(fmt.Errorf("where condition: %w", err)) 73 | } 74 | 75 | return c 76 | } 77 | 78 | func (c *SelectContext[S, T]) GroupBy(exprs ...TableExpr[T]) *SelectContext[S, T] { 79 | err := c.groupExpr.set(exprs) 80 | if err != nil { 81 | c.addError(fmt.Errorf("group by: %w", err)) 82 | } 83 | 84 | return c 85 | } 86 | 87 | func (c *SelectContext[S, T]) Having( 88 | condition TypedTableExpr[T, WrappedPrimitive[bool]], 89 | ) *SelectContext[S, T] { 90 | err := c.havingCondition.set(condition) 91 | if err != nil { 92 | c.addError(fmt.Errorf("having condition: %w", err)) 93 | } 94 | 95 | return c 96 | } 97 | 98 | func (c *SelectContext[S, T]) OrderBy(direction OrderDirection, expr TableExpr[T]) *SelectContext[S, T] { 99 | err := c.order.add(orderItem[T]{ 100 | expr: expr, 101 | direction: direction, 102 | }) 103 | if err != nil { 104 | c.addError(fmt.Errorf("order by: %w", err)) 105 | } 106 | 107 | return c 108 | } 109 | 110 | func (c 
*SelectContext[S, T]) Limit(limit uint64) *SelectContext[S, T] { 111 | err := c.limit.set(limit) 112 | if err != nil { 113 | c.addError(fmt.Errorf("limit: %w", err)) 114 | } 115 | 116 | return c 117 | } 118 | 119 | func (c *SelectContext[S, T]) Offset(offset uint64) *SelectContext[S, T] { 120 | err := c.offset.set(offset) 121 | if err != nil { 122 | c.addError(fmt.Errorf("offset: %w", err)) 123 | } 124 | 125 | return c 126 | } 127 | 128 | func (c *SelectContext[S, T]) Lock(lockType LockType) *SelectContext[S, T] { 129 | err := c.lockType.set(lockType) 130 | if err != nil { 131 | c.addError(fmt.Errorf("lockType: %w", err)) 132 | } 133 | 134 | return c 135 | } 136 | 137 | func (c *SelectContext[S, T]) GetAllCtx(ctx context.Context, db DB) ([]T, error) { 138 | errs := c.Errors() 139 | if len(errs) != 0 { 140 | return nil, errs[0] 141 | } 142 | 143 | columns, query, exprArgs, err := c.buildQuery() 144 | if err != nil { 145 | return nil, fmt.Errorf("build query: %w", err) 146 | } 147 | 148 | args := make([]any, 0, len(exprArgs)) 149 | for _, arg := range exprArgs { 150 | args = append(args, arg) 151 | } 152 | 153 | rows, err := db.QueryContext(ctx, query, args...) 154 | if errors.Is(err, sql.ErrNoRows) { 155 | return []T{}, nil 156 | } 157 | if err != nil { 158 | return nil, fmt.Errorf("query: %w", err) 159 | } 160 | defer rows.Close() 161 | 162 | tables := []T{} 163 | for rows.Next() { 164 | var table S 165 | columnMap := T(&table).ColumnMap() 166 | 167 | dests := make([]any, 0, len(columns)) 168 | for _, column := range columns { 169 | columnField, ok := columnMap[column.SQLColumnName()] 170 | if !ok { 171 | return nil, fmt.Errorf("column %s not found", column.SQLColumnName()) 172 | } 173 | 174 | dests = append(dests, columnField) 175 | } 176 | 177 | err := rows.Scan(dests...) 
178 | if err != nil { 179 | return nil, fmt.Errorf("scan: %w", err) 180 | } 181 | 182 | tables = append(tables, &table) 183 | } 184 | 185 | return tables, nil 186 | } 187 | 188 | func (c *SelectContext[S, T]) GetAll(db DB) ([]T, error) { 189 | return c.GetAllCtx(context.Background(), db) 190 | } 191 | 192 | func (c *SelectContext[S, T]) GetCtx(ctx context.Context, db DB) (T, error) { 193 | err := c.limit.set(1) 194 | if err != nil { 195 | return nil, fmt.Errorf("set limit 1: %w", err) 196 | } 197 | 198 | errs := c.Errors() 199 | if len(errs) != 0 { 200 | return nil, errs[0] 201 | } 202 | 203 | columns, query, exprArgs, err := c.buildQuery() 204 | if err != nil { 205 | return nil, fmt.Errorf("build query: %w", err) 206 | } 207 | 208 | args := make([]any, 0, len(exprArgs)) 209 | for _, arg := range exprArgs { 210 | args = append(args, arg) 211 | } 212 | 213 | row := db.QueryRowContext(ctx, query, args...) 214 | 215 | var table S 216 | columnMap := T(&table).ColumnMap() 217 | 218 | dests := make([]any, 0, len(columns)) 219 | for _, column := range columns { 220 | columnField, ok := columnMap[column.SQLColumnName()] 221 | if !ok { 222 | return nil, fmt.Errorf("column %s not found", column) 223 | } 224 | 225 | dests = append(dests, columnField) 226 | } 227 | 228 | err = row.Scan(dests...) 
229 | if errors.Is(err, sql.ErrNoRows) { 230 | return nil, ErrRecordNotFound 231 | } 232 | if err != nil { 233 | return nil, fmt.Errorf("query: %w", err) 234 | } 235 | 236 | return &table, nil 237 | } 238 | 239 | func (c *SelectContext[S, T]) Get(db DB) (T, error) { 240 | return c.GetCtx(context.Background(), db) 241 | } 242 | 243 | func (c *SelectContext[S, T]) buildQuery() ([]Column, string, []ExprType, error) { 244 | sb := strings.Builder{} 245 | args := []ExprType{} 246 | 247 | str := "SELECT " 248 | _, err := sb.WriteString(str) 249 | if err != nil { 250 | return nil, "", nil, fmt.Errorf("write string(%s): %w", str, err) 251 | } 252 | 253 | if c.distinct { 254 | str = "DISTINCT " 255 | _, err = sb.WriteString(str) 256 | if err != nil { 257 | return nil, "", nil, fmt.Errorf("write string(%s): %w", str, err) 258 | } 259 | } 260 | 261 | var columns []Column 262 | if len(c.fields) == 0 { 263 | columns = c.table.Columns() 264 | } else { 265 | columns = make([]Column, 0, len(c.fields)) 266 | for _, field := range c.fields { 267 | columns = append(columns, field) 268 | } 269 | } 270 | 271 | columnAliasMap := map[string]struct{}{} 272 | selectExprs := make([]string, 0, len(columns)) 273 | for _, column := range columns { 274 | var alias string 275 | i := 0 276 | for ok := true; ok; _, ok = columnAliasMap[alias] { 277 | alias = fmt.Sprintf("%s_%s_%d", column.TableName(), column.ColumnName(), i) 278 | i++ 279 | } 280 | 281 | columnAliasMap[alias] = struct{}{} 282 | selectExprs = append(selectExprs, fmt.Sprintf("%s AS %s", column.SQLColumnName(), alias)) 283 | } 284 | 285 | str = strings.Join(selectExprs, ", ") 286 | _, err = sb.WriteString(str) 287 | if err != nil { 288 | return nil, "", nil, fmt.Errorf("write string(%s): %w", str, err) 289 | } 290 | 291 | str = " FROM " 292 | _, err = sb.WriteString(str) 293 | if err != nil { 294 | return nil, "", nil, fmt.Errorf("write string(%s): %w", str, err) 295 | } 296 | 297 | tableQuery, tableArgs, errs := c.table.Expr() 298 | 
if len(errs) != 0 { 299 | return nil, "", nil, fmt.Errorf("table expr: %w", errs[0]) 300 | } 301 | 302 | _, err = sb.WriteString(tableQuery) 303 | if err != nil { 304 | return nil, "", nil, fmt.Errorf("write string(%s): %w", tableQuery, err) 305 | } 306 | 307 | args = append(args, tableArgs...) 308 | 309 | if c.whereCondition.exists() { 310 | whereQuery, whereArgs, err := c.whereCondition.getExpr() 311 | if err != nil { 312 | return nil, "", nil, fmt.Errorf("where condition: %w", err) 313 | } 314 | 315 | str = " WHERE " 316 | _, err = sb.WriteString(str) 317 | if err != nil { 318 | return nil, "", nil, fmt.Errorf("write string(%s): %w", str, err) 319 | } 320 | 321 | _, err = sb.WriteString(whereQuery) 322 | if err != nil { 323 | return nil, "", nil, fmt.Errorf("write string(%s): %w", whereQuery, err) 324 | } 325 | 326 | args = append(args, whereArgs...) 327 | } 328 | 329 | if c.groupExpr.exists() { 330 | groupExpr, groupArgs, err := c.groupExpr.getExpr() 331 | if err != nil { 332 | return nil, "", nil, fmt.Errorf("group expr: %w", err) 333 | } 334 | 335 | str = " " 336 | _, err = sb.WriteString(str) 337 | if err != nil { 338 | return nil, "", nil, fmt.Errorf("write string(%s): %w", str, err) 339 | } 340 | 341 | _, err = sb.WriteString(groupExpr) 342 | if err != nil { 343 | return nil, "", nil, fmt.Errorf("write string(%s): %w", groupExpr, err) 344 | } 345 | 346 | args = append(args, groupArgs...) 
347 | } 348 | 349 | if c.havingCondition.exists() { 350 | havingQuery, havingArgs, err := c.havingCondition.getExpr() 351 | if err != nil { 352 | return nil, "", nil, fmt.Errorf("having condition: %w", err) 353 | } 354 | 355 | str = " HAVING " 356 | _, err = sb.WriteString(str) 357 | if err != nil { 358 | return nil, "", nil, fmt.Errorf("write string(%s): %w", str, err) 359 | } 360 | 361 | _, err = sb.WriteString(havingQuery) 362 | if err != nil { 363 | return nil, "", nil, fmt.Errorf("write string(%s): %w", havingQuery, err) 364 | } 365 | 366 | args = append(args, havingArgs...) 367 | } 368 | 369 | if c.order.exists() { 370 | orderQuery, orderArgs, err := c.order.getExpr() 371 | if err != nil { 372 | return nil, "", nil, fmt.Errorf("order: %w", err) 373 | } 374 | 375 | str = " " 376 | _, err = sb.WriteString(str) 377 | if err != nil { 378 | return nil, "", nil, fmt.Errorf("write string(%s): %w", str, err) 379 | } 380 | 381 | _, err = sb.WriteString(orderQuery) 382 | if err != nil { 383 | return nil, "", nil, fmt.Errorf("write string(%s): %w", orderQuery, err) 384 | } 385 | 386 | args = append(args, orderArgs...) 387 | } 388 | 389 | if c.limit.exists() { 390 | limitQuery, limitArgs, err := c.limit.getExpr() 391 | if err != nil { 392 | return nil, "", nil, fmt.Errorf("limit: %w", err) 393 | } 394 | 395 | str = " " 396 | _, err = sb.WriteString(str) 397 | if err != nil { 398 | return nil, "", nil, fmt.Errorf("write string(%s): %w", str, err) 399 | } 400 | 401 | _, err = sb.WriteString(limitQuery) 402 | if err != nil { 403 | return nil, "", nil, fmt.Errorf("write string(%s): %w", limitQuery, err) 404 | } 405 | 406 | args = append(args, limitArgs...) 
407 | } 408 | 409 | if c.offset.exists() { 410 | offsetQuery, offsetArgs, err := c.offset.getExpr() 411 | if err != nil { 412 | return nil, "", nil, fmt.Errorf("offset: %w", err) 413 | } 414 | 415 | str = " " 416 | _, err = sb.WriteString(str) 417 | if err != nil { 418 | return nil, "", nil, fmt.Errorf("write string(%s): %w", str, err) 419 | } 420 | 421 | _, err = sb.WriteString(offsetQuery) 422 | if err != nil { 423 | return nil, "", nil, fmt.Errorf("write string(%s): %w", offsetQuery, err) 424 | } 425 | 426 | args = append(args, offsetArgs...) 427 | } 428 | 429 | if c.lockType.exists() { 430 | lockQuery, lockArgs, err := c.lockType.getExpr() 431 | if err != nil { 432 | return nil, "", nil, fmt.Errorf("lock: %w", err) 433 | } 434 | 435 | str = " " 436 | _, err = sb.WriteString(str) 437 | if err != nil { 438 | return nil, "", nil, fmt.Errorf("write string(%s): %w", str, err) 439 | } 440 | 441 | _, err = sb.WriteString(lockQuery) 442 | if err != nil { 443 | return nil, "", nil, fmt.Errorf("write string(%s): %w", lockQuery, err) 444 | } 445 | 446 | args = append(args, lockArgs...) 
447 | } 448 | 449 | return columns, sb.String(), args, nil 450 | } 451 | -------------------------------------------------------------------------------- /select_exporter_test.go: -------------------------------------------------------------------------------- 1 | package genorm 2 | 3 | func (c *SelectContext[_, _]) BuildQuery() ([]Column, string, []ExprType, error) { 4 | return c.buildQuery() 5 | } 6 | -------------------------------------------------------------------------------- /table.go: -------------------------------------------------------------------------------- 1 | package genorm 2 | 3 | type Table interface { 4 | Expr 5 | Columns() []Column 6 | // ColumnMap key: table_name.column_name 7 | ColumnMap() map[string]ColumnFieldExprType 8 | GetErrors() []error 9 | } 10 | 11 | type TablePointer[T any] interface { 12 | Table 13 | *T 14 | } 15 | 16 | type BasicTable interface { 17 | Table 18 | // TableName table_name 19 | TableName() string 20 | } 21 | 22 | type JoinedTable interface { 23 | Table 24 | BaseTables() []BasicTable 25 | AddError(error) 26 | } 27 | -------------------------------------------------------------------------------- /table_mock_test.go: -------------------------------------------------------------------------------- 1 | package genorm_test 2 | 3 | import "github.com/mazrean/genorm" 4 | 5 | //go:generate go run github.com/golang/mock/mockgen@v1.6.0 -source=$GOFILE -destination=mock/table_mock.go -package=mock 6 | 7 | type Table interface { 8 | genorm.Table 9 | } 10 | 11 | type BasicTable interface { 12 | genorm.BasicTable 13 | } 14 | 15 | type JoinedTable interface { 16 | genorm.JoinedTable 17 | } 18 | -------------------------------------------------------------------------------- /tuple.go: -------------------------------------------------------------------------------- 1 | package genorm 2 | 3 | type Tuple interface { 4 | Exprs() []Expr 5 | Columns() []ColumnFieldExprType 6 | } 7 | 8 | type TuplePointer[T any] interface { 9 | Exprs() 
[]Expr 10 | Columns() []ColumnFieldExprType 11 | *T 12 | } 13 | 14 | type Tuple2Struct[ 15 | S Table, 16 | T1 ExprType, U1 ColumnFieldExprTypePointer[T1], 17 | T2 ExprType, U2 ColumnFieldExprTypePointer[T2], 18 | ] struct { 19 | value1 T1 20 | value2 T2 21 | expr1 TypedTableExpr[S, T1] 22 | expr2 TypedTableExpr[S, T2] 23 | } 24 | 25 | func Tuple2[ 26 | S Table, 27 | T1 ExprType, U1 ColumnFieldExprTypePointer[T1], 28 | T2 ExprType, U2 ColumnFieldExprTypePointer[T2], 29 | ](expr1 TypedTableExpr[S, T1], expr2 TypedTableExpr[S, T2]) *Tuple2Struct[S, T1, U1, T2, U2] { 30 | return &Tuple2Struct[S, T1, U1, T2, U2]{ 31 | expr1: expr1, 32 | expr2: expr2, 33 | } 34 | } 35 | 36 | func (t *Tuple2Struct[_, _, _, _, _]) Exprs() []Expr { 37 | return []Expr{t.expr1, t.expr2} 38 | } 39 | 40 | func (t *Tuple2Struct[_, _, U1, _, U2]) Columns() []ColumnFieldExprType { 41 | return []ColumnFieldExprType{ 42 | U1(&t.value1), 43 | U2(&t.value2), 44 | } 45 | } 46 | 47 | func (t *Tuple2Struct[_, T1, _, T2, _]) Values() (T1, T2) { 48 | return t.value1, t.value2 49 | } 50 | 51 | type Tuple3Struct[ 52 | S Table, 53 | T1 ExprType, U1 ColumnFieldExprTypePointer[T1], 54 | T2 ExprType, U2 ColumnFieldExprTypePointer[T2], 55 | T3 ExprType, U3 ColumnFieldExprTypePointer[T3], 56 | ] struct { 57 | value1 T1 58 | value2 T2 59 | value3 T3 60 | expr1 TypedTableExpr[S, T1] 61 | expr2 TypedTableExpr[S, T2] 62 | expr3 TypedTableExpr[S, T3] 63 | } 64 | 65 | func Tuple3[ 66 | S Table, 67 | T1 ExprType, U1 ColumnFieldExprTypePointer[T1], 68 | T2 ExprType, U2 ColumnFieldExprTypePointer[T2], 69 | T3 ExprType, U3 ColumnFieldExprTypePointer[T3], 70 | ](expr1 TypedTableExpr[S, T1], expr2 TypedTableExpr[S, T2], expr3 TypedTableExpr[S, T3]) *Tuple3Struct[S, T1, U1, T2, U2, T3, U3] { 71 | return &Tuple3Struct[S, T1, U1, T2, U2, T3, U3]{ 72 | expr1: expr1, 73 | expr2: expr2, 74 | expr3: expr3, 75 | } 76 | } 77 | 78 | func (t *Tuple3Struct[_, _, _, _, _, _, _]) Exprs() []Expr { 79 | return []Expr{t.expr1, t.expr2, 
t.expr3} 80 | } 81 | 82 | func (t *Tuple3Struct[_, _, U1, _, U2, _, U3]) Columns() []ColumnFieldExprType { 83 | return []ColumnFieldExprType{ 84 | U1(&t.value1), 85 | U2(&t.value2), 86 | U3(&t.value3), 87 | } 88 | } 89 | 90 | func (t *Tuple3Struct[_, T1, _, T2, _, T3, _]) Values() (T1, T2, T3) { 91 | return t.value1, t.value2, t.value3 92 | } 93 | 94 | type Tuple4Struct[ 95 | S Table, 96 | T1 ExprType, U1 ColumnFieldExprTypePointer[T1], 97 | T2 ExprType, U2 ColumnFieldExprTypePointer[T2], 98 | T3 ExprType, U3 ColumnFieldExprTypePointer[T3], 99 | T4 ExprType, U4 ColumnFieldExprTypePointer[T4], 100 | ] struct { 101 | value1 T1 102 | value2 T2 103 | value3 T3 104 | value4 T4 105 | expr1 TypedTableExpr[S, T1] 106 | expr2 TypedTableExpr[S, T2] 107 | expr3 TypedTableExpr[S, T3] 108 | expr4 TypedTableExpr[S, T4] 109 | } 110 | 111 | func Tuple4[ 112 | S Table, 113 | T1 ExprType, U1 ColumnFieldExprTypePointer[T1], 114 | T2 ExprType, U2 ColumnFieldExprTypePointer[T2], 115 | T3 ExprType, U3 ColumnFieldExprTypePointer[T3], 116 | T4 ExprType, U4 ColumnFieldExprTypePointer[T4], 117 | ](expr1 TypedTableExpr[S, T1], expr2 TypedTableExpr[S, T2], expr3 TypedTableExpr[S, T3], expr4 TypedTableExpr[S, T4]) *Tuple4Struct[S, T1, U1, T2, U2, T3, U3, T4, U4] { 118 | return &Tuple4Struct[S, T1, U1, T2, U2, T3, U3, T4, U4]{ 119 | expr1: expr1, 120 | expr2: expr2, 121 | expr3: expr3, 122 | expr4: expr4, 123 | } 124 | } 125 | 126 | func (t *Tuple4Struct[_, _, _, _, _, _, _, _, _]) Exprs() []Expr { 127 | return []Expr{t.expr1, t.expr2, t.expr3, t.expr4} 128 | } 129 | 130 | func (t *Tuple4Struct[_, _, U1, _, U2, _, U3, _, U4]) Columns() []ColumnFieldExprType { 131 | return []ColumnFieldExprType{ 132 | U1(&t.value1), 133 | U2(&t.value2), 134 | U3(&t.value3), 135 | U4(&t.value4), 136 | } 137 | } 138 | 139 | func (t *Tuple4Struct[_, T1, _, T2, _, T3, _, T4, _]) Values() (T1, T2, T3, T4) { 140 | return t.value1, t.value2, t.value3, t.value4 141 | } 142 | 143 | type Tuple5Struct[ 144 | S Table, 145 | 
T1 ExprType, U1 ColumnFieldExprTypePointer[T1], 146 | T2 ExprType, U2 ColumnFieldExprTypePointer[T2], 147 | T3 ExprType, U3 ColumnFieldExprTypePointer[T3], 148 | T4 ExprType, U4 ColumnFieldExprTypePointer[T4], 149 | T5 ExprType, U5 ColumnFieldExprTypePointer[T5], 150 | ] struct { 151 | value1 T1 152 | value2 T2 153 | value3 T3 154 | value4 T4 155 | value5 T5 156 | expr1 TypedTableExpr[S, T1] 157 | expr2 TypedTableExpr[S, T2] 158 | expr3 TypedTableExpr[S, T3] 159 | expr4 TypedTableExpr[S, T4] 160 | expr5 TypedTableExpr[S, T5] 161 | } 162 | 163 | func Tuple5[ 164 | S Table, 165 | T1 ExprType, U1 ColumnFieldExprTypePointer[T1], 166 | T2 ExprType, U2 ColumnFieldExprTypePointer[T2], 167 | T3 ExprType, U3 ColumnFieldExprTypePointer[T3], 168 | T4 ExprType, U4 ColumnFieldExprTypePointer[T4], 169 | T5 ExprType, U5 ColumnFieldExprTypePointer[T5], 170 | ](expr1 TypedTableExpr[S, T1], expr2 TypedTableExpr[S, T2], expr3 TypedTableExpr[S, T3], expr4 TypedTableExpr[S, T4], expr5 TypedTableExpr[S, T5]) *Tuple5Struct[S, T1, U1, T2, U2, T3, U3, T4, U4, T5, U5] { 171 | return &Tuple5Struct[S, T1, U1, T2, U2, T3, U3, T4, U4, T5, U5]{ 172 | expr1: expr1, 173 | expr2: expr2, 174 | expr3: expr3, 175 | expr4: expr4, 176 | expr5: expr5, 177 | } 178 | } 179 | 180 | func (t *Tuple5Struct[_, _, _, _, _, _, _, _, _, _, _]) Exprs() []Expr { 181 | return []Expr{t.expr1, t.expr2, t.expr3, t.expr4, t.expr5} 182 | } 183 | 184 | func (t *Tuple5Struct[_, _, U1, _, U2, _, U3, _, U4, _, U5]) Columns() []ColumnFieldExprType { 185 | return []ColumnFieldExprType{ 186 | U1(&t.value1), 187 | U2(&t.value2), 188 | U3(&t.value3), 189 | U4(&t.value4), 190 | U5(&t.value5), 191 | } 192 | } 193 | 194 | func (t *Tuple5Struct[_, T1, _, T2, _, T3, _, T4, _, T5, _]) Values() (T1, T2, T3, T4, T5) { 195 | return t.value1, t.value2, t.value3, t.value4, t.value5 196 | } 197 | -------------------------------------------------------------------------------- /tuple_mock_test.go: 
// ExprType is the constraint for values that can be used as query
// arguments: anything implementing driver.Valuer.
type ExprType interface {
	driver.Valuer
}

// ColumnFieldExprType is an ExprType that can also be filled from a query
// result through sql.Scanner.
type ColumnFieldExprType interface {
	sql.Scanner
	driver.Valuer
}

// ColumnFieldExprTypePointer is a constraint satisfied by a pointer type *T
// that implements ColumnFieldExprType.
type ColumnFieldExprTypePointer[T ExprType] interface {
	ColumnFieldExprType
	*T
}

// ExprPrimitive lists the built-in types that WrappedPrimitive can wrap.
type ExprPrimitive interface {
	bool |
		int | int8 | int16 | int32 | int64 |
		uint | uint8 | uint16 | uint32 | uint64 |
		float32 | float64 |
		string | time.Time
}

// WrappedPrimitive pairs a primitive value with a validity flag; valid is
// false when the value represents SQL NULL.
type WrappedPrimitive[T ExprPrimitive] struct {
	valid bool
	val   T
}

// Wrap returns a valid WrappedPrimitive holding val.
func Wrap[T ExprPrimitive](val T) WrappedPrimitive[T] {
	return WrappedPrimitive[T]{
		valid: true,
		val:   val,
	}
}

// scanInteger reads src through sql.NullInt64 and converts the result to the
// integer type I. It returns the converted value, whether src was non-NULL,
// and an error when the scan fails or the value does not fit into I.
//
// Values above the int64 range (large uint64 values) cannot be read this
// way, because they have to pass through sql.NullInt64 first.
func scanInteger[I interface {
	int | int8 | int16 | int32 | int64 |
		uint | uint8 | uint16 | uint32 | uint64
}](src any) (I, bool, error) {
	var ni sql.NullInt64
	if err := ni.Scan(src); err != nil {
		return 0, false, err
	}

	v := I(ni.Int64)
	// The round trip detects truncation; the sign comparison detects a
	// negative value wrapped into an unsigned type of the same width.
	if int64(v) != ni.Int64 || (v < 0) != (ni.Int64 < 0) {
		return 0, false, errors.New("integer out of range")
	}

	return v, ni.Valid, nil
}

// Scan implements sql.Scanner. It picks a decoding strategy from the wrapped
// type T, stores the decoded value, and records whether src was non-NULL.
//
// Every type in ExprPrimitive is handled, including int and uint. Integer
// values that do not fit into T are reported as an error instead of being
// wrapped silently. The receiver is only modified when decoding succeeds.
func (wp *WrappedPrimitive[T]) Scan(src any) error {
	var (
		dest  any
		valid bool
		err   error
	)

	// wp.val is only inspected for its static type T, never for its value.
	switch any(wp.val).(type) {
	case bool:
		var nb sql.NullBool
		if err := nb.Scan(src); err != nil {
			return err
		}
		dest, valid = nb.Bool, nb.Valid
	case int:
		dest, valid, err = scanInteger[int](src)
	case int8:
		dest, valid, err = scanInteger[int8](src)
	case int16:
		dest, valid, err = scanInteger[int16](src)
	case int32:
		dest, valid, err = scanInteger[int32](src)
	case int64:
		dest, valid, err = scanInteger[int64](src)
	case uint:
		dest, valid, err = scanInteger[uint](src)
	case uint8:
		dest, valid, err = scanInteger[uint8](src)
	case uint16:
		dest, valid, err = scanInteger[uint16](src)
	case uint32:
		dest, valid, err = scanInteger[uint32](src)
	case uint64:
		dest, valid, err = scanInteger[uint64](src)
	case float32:
		var nf sql.NullFloat64
		if err := nf.Scan(src); err != nil {
			return err
		}
		dest, valid = float32(nf.Float64), nf.Valid
	case float64:
		var nf sql.NullFloat64
		if err := nf.Scan(src); err != nil {
			return err
		}
		dest, valid = nf.Float64, nf.Valid
	case string:
		var ns sql.NullString
		if err := ns.Scan(src); err != nil {
			return err
		}
		dest, valid = ns.String, ns.Valid
	case time.Time:
		var nt sql.NullTime
		if err := nt.Scan(src); err != nil {
			return err
		}
		dest, valid = nt.Time, nt.Valid
	default:
		return errors.New("unsupported type")
	}
	if err != nil {
		return err
	}

	val, ok := dest.(T)
	if !ok {
		return errors.New("failed to convert")
	}

	wp.val, wp.valid = val, valid

	return nil
}
func (wp WrappedPrimitive[_]) Value() (driver.Value, error) { 191 | if !wp.valid { 192 | return nil, ErrNullValue 193 | } 194 | 195 | return wp.val, nil 196 | } 197 | 198 | func (wp WrappedPrimitive[T]) Val() (T, bool) { 199 | return wp.val, wp.valid 200 | } 201 | -------------------------------------------------------------------------------- /type_mock_test.go: -------------------------------------------------------------------------------- 1 | package genorm_test 2 | 3 | import ( 4 | "github.com/mazrean/genorm" 5 | ) 6 | 7 | //go:generate go run github.com/golang/mock/mockgen@v1.6.0 -source=$GOFILE -destination=mock/type_mock.go -package=mock 8 | 9 | type ExprType interface { 10 | genorm.ExprType 11 | } 12 | 13 | type ColumnFieldExprType interface { 14 | genorm.ColumnFieldExprType 15 | } 16 | -------------------------------------------------------------------------------- /update.go: -------------------------------------------------------------------------------- 1 | package genorm 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | "fmt" 7 | "strings" 8 | ) 9 | 10 | type UpdateContext[T Table] struct { 11 | *Context[T] 12 | assignExprs []*TableAssignExpr[T] 13 | whereCondition whereConditionClause[T] 14 | order orderClause[T] 15 | limit limitClause 16 | } 17 | 18 | func Update[T Table](table T) *UpdateContext[T] { 19 | ctx := newContext(table) 20 | 21 | return &UpdateContext[T]{ 22 | Context: ctx, 23 | } 24 | } 25 | 26 | func (c *UpdateContext[T]) Set(assignExprs ...*TableAssignExpr[T]) (res *UpdateContext[T]) { 27 | if len(assignExprs) == 0 { 28 | c.addError(errors.New("no assign expressions")) 29 | return c 30 | } 31 | 32 | c.assignExprs = append(c.assignExprs, assignExprs...) 
33 | 34 | return c 35 | } 36 | 37 | func (c *UpdateContext[T]) Where( 38 | condition TypedTableExpr[T, WrappedPrimitive[bool]], 39 | ) *UpdateContext[T] { 40 | err := c.whereCondition.set(condition) 41 | if err != nil { 42 | c.addError(fmt.Errorf("where condition: %w", err)) 43 | } 44 | 45 | return c 46 | } 47 | 48 | func (c *UpdateContext[T]) OrderBy(direction OrderDirection, expr TableExpr[T]) *UpdateContext[T] { 49 | err := c.order.add(orderItem[T]{ 50 | expr: expr, 51 | direction: direction, 52 | }) 53 | if err != nil { 54 | c.addError(fmt.Errorf("order by: %w", err)) 55 | } 56 | 57 | return c 58 | } 59 | 60 | func (c *UpdateContext[T]) Limit(limit uint64) *UpdateContext[T] { 61 | err := c.limit.set(limit) 62 | if err != nil { 63 | c.addError(fmt.Errorf("limit: %w", err)) 64 | } 65 | 66 | return c 67 | } 68 | 69 | func (c *UpdateContext[T]) DoCtx(ctx context.Context, db DB) (rowsAffected int64, err error) { 70 | errs := c.Errors() 71 | if len(errs) != 0 { 72 | return 0, errs[0] 73 | } 74 | 75 | query, exprArgs, err := c.buildQuery() 76 | if err != nil { 77 | return 0, fmt.Errorf("build query: %w", err) 78 | } 79 | 80 | args := make([]any, 0, len(exprArgs)) 81 | for _, arg := range exprArgs { 82 | args = append(args, arg) 83 | } 84 | 85 | result, err := db.ExecContext(ctx, query, args...) 
86 | if err != nil { 87 | return 0, fmt.Errorf("exec: %w", err) 88 | } 89 | 90 | rowsAffected, err = result.RowsAffected() 91 | if err != nil { 92 | return 0, fmt.Errorf("rows affected: %w", err) 93 | } 94 | 95 | return rowsAffected, nil 96 | } 97 | 98 | func (c *UpdateContext[T]) Do(db DB) (rowsAffected int64, err error) { 99 | return c.DoCtx(context.Background(), db) 100 | } 101 | 102 | func (c *UpdateContext[T]) buildQuery() (string, []ExprType, error) { 103 | args := []ExprType{} 104 | 105 | sb := strings.Builder{} 106 | 107 | str := "UPDATE " 108 | _, err := sb.WriteString(str) 109 | if err != nil { 110 | return "", nil, fmt.Errorf("write string(%s): %w", str, err) 111 | } 112 | 113 | tableQuery, tableArgs, errs := c.table.Expr() 114 | if len(errs) != 0 { 115 | return "", nil, fmt.Errorf("table expr: %w", errs[0]) 116 | } 117 | 118 | _, err = sb.WriteString(tableQuery) 119 | if err != nil { 120 | return "", nil, fmt.Errorf("write string(%s): %w", tableQuery, err) 121 | } 122 | 123 | args = append(args, tableArgs...) 124 | 125 | if len(c.assignExprs) == 0 { 126 | return "", nil, errors.New("no assignment") 127 | } 128 | 129 | str = " SET " 130 | _, err = sb.WriteString(str) 131 | if err != nil { 132 | return "", nil, fmt.Errorf("write string(%s): %w", str, err) 133 | } 134 | 135 | assignments := make([]string, 0, len(c.assignExprs)) 136 | for _, expr := range c.assignExprs { 137 | assignmentQuery, assignmentArgs, errs := expr.AssignExpr() 138 | if len(errs) != 0 { 139 | return "", nil, errs[0] 140 | } 141 | 142 | assignments = append(assignments, assignmentQuery) 143 | args = append(args, assignmentArgs...) 
144 | } 145 | 146 | str = strings.Join(assignments, ", ") 147 | _, err = sb.WriteString(str) 148 | if err != nil { 149 | return "", nil, fmt.Errorf("write string(%s): %w", str, err) 150 | } 151 | 152 | if c.whereCondition.exists() { 153 | whereQuery, whereArgs, err := c.whereCondition.getExpr() 154 | if err != nil { 155 | return "", nil, fmt.Errorf("where condition: %w", err) 156 | } 157 | 158 | str = " WHERE " 159 | _, err = sb.WriteString(str) 160 | if err != nil { 161 | return "", nil, fmt.Errorf("write string(%s): %w", str, err) 162 | } 163 | 164 | _, err = sb.WriteString(whereQuery) 165 | if err != nil { 166 | return "", nil, fmt.Errorf("write string(%s): %w", whereQuery, err) 167 | } 168 | 169 | args = append(args, whereArgs...) 170 | } 171 | 172 | if c.order.exists() { 173 | orderQuery, orderArgs, err := c.order.getExpr() 174 | if err != nil { 175 | return "", nil, fmt.Errorf("order: %w", err) 176 | } 177 | 178 | str = " " 179 | _, err = sb.WriteString(str) 180 | if err != nil { 181 | return "", nil, fmt.Errorf("write string(%s): %w", str, err) 182 | } 183 | 184 | _, err = sb.WriteString(orderQuery) 185 | if err != nil { 186 | return "", nil, fmt.Errorf("write string(%s): %w", orderQuery, err) 187 | } 188 | 189 | args = append(args, orderArgs...) 190 | } 191 | 192 | if c.limit.exists() { 193 | limitQuery, limitArgs, err := c.limit.getExpr() 194 | if err != nil { 195 | return "", nil, fmt.Errorf("limit: %w", err) 196 | } 197 | 198 | str = " " 199 | _, err = sb.WriteString(str) 200 | if err != nil { 201 | return "", nil, fmt.Errorf("write string(%s): %w", str, err) 202 | } 203 | 204 | _, err = sb.WriteString(limitQuery) 205 | if err != nil { 206 | return "", nil, fmt.Errorf("write string(%s): %w", limitQuery, err) 207 | } 208 | 209 | args = append(args, limitArgs...) 
210 | } 211 | 212 | return sb.String(), args, nil 213 | } 214 | -------------------------------------------------------------------------------- /update_exprter_test.go: -------------------------------------------------------------------------------- 1 | package genorm 2 | 3 | func (c *UpdateContext[T]) BuildQuery() (string, []ExprType, error) { 4 | return c.buildQuery() 5 | } 6 | -------------------------------------------------------------------------------- /update_test.go: -------------------------------------------------------------------------------- 1 | package genorm_test 2 | 3 | import ( 4 | "errors" 5 | "testing" 6 | 7 | "github.com/golang/mock/gomock" 8 | "github.com/mazrean/genorm" 9 | "github.com/mazrean/genorm/mock" 10 | "github.com/stretchr/testify/assert" 11 | ) 12 | 13 | func TestUpdateBuildQuery(t *testing.T) { 14 | t.Parallel() 15 | 16 | type expr struct { 17 | query string 18 | args []genorm.ExprType 19 | errs []error 20 | } 21 | 22 | type orderItem struct { 23 | direction genorm.OrderDirection 24 | expr expr 25 | } 26 | 27 | tests := []struct { 28 | description string 29 | tableExpr expr 30 | assignExprs []expr 31 | whereCondition *expr 32 | orderItems []orderItem 33 | limit uint64 34 | query string 35 | args []genorm.ExprType 36 | err bool 37 | }{ 38 | { 39 | description: "normal", 40 | tableExpr: expr{ 41 | query: "hoge", 42 | }, 43 | assignExprs: []expr{ 44 | { 45 | query: "hoge.huga = ?", 46 | args: []genorm.ExprType{genorm.Wrap(1)}, 47 | }, 48 | }, 49 | query: "UPDATE hoge SET hoge.huga = ?", 50 | args: []genorm.ExprType{genorm.Wrap(1)}, 51 | }, 52 | { 53 | description: "joined table", 54 | tableExpr: expr{ 55 | query: "hoge JOIN fuga ON hoge.id = fuga.id AND hoge.huga = ?", 56 | args: []genorm.ExprType{genorm.Wrap(1)}, 57 | }, 58 | assignExprs: []expr{ 59 | { 60 | query: "hoge.huga = ?", 61 | args: []genorm.ExprType{genorm.Wrap(2)}, 62 | }, 63 | }, 64 | query: "UPDATE hoge JOIN fuga ON hoge.id = fuga.id AND hoge.huga = ? 
// TestUpdateBuildQuery drives UpdateContext.BuildQuery through a table of
// cases and compares the generated SQL text and argument list, or checks
// that an error is reported for the cases marked with err.
func TestUpdateBuildQuery(t *testing.T) {
	t.Parallel()

	// expr is the canned result a mocked expression returns.
	type expr struct {
		query string
		args  []genorm.ExprType
		errs  []error
	}

	// orderItem is one ORDER BY entry: a direction plus its expression.
	type orderItem struct {
		direction genorm.OrderDirection
		expr      expr
	}

	tests := []struct {
		description    string
		tableExpr      expr
		assignExprs    []expr
		whereCondition *expr // nil means no WHERE clause is set
		orderItems     []orderItem
		limit          uint64 // 0 means Limit is not called
		query          string
		args           []genorm.ExprType
		err            bool
	}{
		{
			description: "normal",
			tableExpr: expr{
				query: "hoge",
			},
			assignExprs: []expr{
				{
					query: "hoge.huga = ?",
					args:  []genorm.ExprType{genorm.Wrap(1)},
				},
			},
			query: "UPDATE hoge SET hoge.huga = ?",
			args:  []genorm.ExprType{genorm.Wrap(1)},
		},
		{
			description: "joined table",
			tableExpr: expr{
				query: "hoge JOIN fuga ON hoge.id = fuga.id AND hoge.huga = ?",
				args:  []genorm.ExprType{genorm.Wrap(1)},
			},
			assignExprs: []expr{
				{
					query: "hoge.huga = ?",
					args:  []genorm.ExprType{genorm.Wrap(2)},
				},
			},
			query: "UPDATE hoge JOIN fuga ON hoge.id = fuga.id AND hoge.huga = ? SET hoge.huga = ?",
			args:  []genorm.ExprType{genorm.Wrap(1), genorm.Wrap(2)},
		},
		{
			description: "multi assign",
			tableExpr: expr{
				query: "hoge",
			},
			assignExprs: []expr{
				{
					query: "hoge.huga = ?",
					args:  []genorm.ExprType{genorm.Wrap(1)},
				},
				{
					query: "hoge.nya = ?",
					args:  []genorm.ExprType{genorm.Wrap(2)},
				},
			},
			query: "UPDATE hoge SET hoge.huga = ?, hoge.nya = ?",
			args:  []genorm.ExprType{genorm.Wrap(1), genorm.Wrap(2)},
		},
		{
			description: "where",
			tableExpr: expr{
				query: "hoge",
			},
			assignExprs: []expr{
				{
					query: "hoge.huga = ?",
					args:  []genorm.ExprType{genorm.Wrap(1)},
				},
			},
			whereCondition: &expr{
				query: "(hoge.huga = ?)",
				args:  []genorm.ExprType{genorm.Wrap(2)},
			},
			query: "UPDATE hoge SET hoge.huga = ? WHERE (hoge.huga = ?)",
			args:  []genorm.ExprType{genorm.Wrap(1), genorm.Wrap(2)},
		},
		{
			description: "where error",
			tableExpr: expr{
				query: "hoge",
			},
			assignExprs: []expr{
				{
					query: "hoge.huga = ?",
					args:  []genorm.ExprType{genorm.Wrap(1)},
				},
			},
			whereCondition: &expr{
				errs: []error{errors.New("where error")},
			},
			err: true,
		},
		{
			description: "order by",
			tableExpr: expr{
				query: "hoge",
			},
			assignExprs: []expr{
				{
					query: "hoge.huga = ?",
					args:  []genorm.ExprType{genorm.Wrap(1)},
				},
			},
			orderItems: []orderItem{
				{
					direction: genorm.Asc,
					expr: expr{
						query: "(hoge.huga = ?)",
						args:  []genorm.ExprType{genorm.Wrap(2)},
					},
				},
			},
			query: "UPDATE hoge SET hoge.huga = ? ORDER BY (hoge.huga = ?) ASC",
			args:  []genorm.ExprType{genorm.Wrap(1), genorm.Wrap(2)},
		},
		{
			description: "order by error",
			tableExpr: expr{
				query: "hoge",
			},
			assignExprs: []expr{
				{
					query: "hoge.huga = ?",
					args:  []genorm.ExprType{genorm.Wrap(1)},
				},
			},
			orderItems: []orderItem{
				{
					direction: genorm.Asc,
					expr: expr{
						errs: []error{errors.New("order by error")},
					},
				},
			},
			err: true,
		},
		{
			description: "multi order by",
			tableExpr: expr{
				query: "hoge",
			},
			assignExprs: []expr{
				{
					query: "hoge.huga = ?",
					args:  []genorm.ExprType{genorm.Wrap(1)},
				},
			},
			orderItems: []orderItem{
				{
					direction: genorm.Asc,
					expr: expr{
						query: "(hoge.huga = ?)",
						args:  []genorm.ExprType{genorm.Wrap(2)},
					},
				},
				{
					direction: genorm.Desc,
					expr: expr{
						query: "(hoge.nya = ?)",
						args:  []genorm.ExprType{genorm.Wrap(3)},
					},
				},
			},
			query: "UPDATE hoge SET hoge.huga = ? ORDER BY (hoge.huga = ?) ASC, (hoge.nya = ?) DESC",
			args:  []genorm.ExprType{genorm.Wrap(1), genorm.Wrap(2), genorm.Wrap(3)},
		},
		{
			description: "limit",
			tableExpr: expr{
				query: "hoge",
			},
			assignExprs: []expr{
				{
					query: "hoge.huga = ?",
					args:  []genorm.ExprType{genorm.Wrap(1)},
				},
			},
			limit: 1,
			query: "UPDATE hoge SET hoge.huga = ? LIMIT 1",
			args:  []genorm.ExprType{genorm.Wrap(1)},
		},
	}

	for _, test := range tests {
		t.Run(test.description, func(t *testing.T) {
			ctrl := gomock.NewController(t)

			// The mocked table returns the case's table expression and
			// reports no errors of its own.
			table := mock.NewMockTable(ctrl)
			table.
				EXPECT().
				Expr().
				Return(test.tableExpr.query, test.tableExpr.args, test.tableExpr.errs)
			table.
				EXPECT().
				GetErrors().
				Return(nil)

			builder := genorm.Update(table)

			// Assignments are built from canned query/args/errs values
			// rather than from real column expressions.
			assignExprs := make([]*genorm.TableAssignExpr[*mock.MockTable], 0, len(test.assignExprs))
			for _, assignExpr := range test.assignExprs {
				assignExprs = append(assignExprs, genorm.NewTableAssignExpr[*mock.MockTable](assignExpr.query, assignExpr.args, assignExpr.errs))
			}
			builder = builder.Set(assignExprs...)

			if test.whereCondition != nil {
				mockExpr := mock.NewMockTypedTableExpr[*mock.MockTable, genorm.WrappedPrimitive[bool]](ctrl)
				mockExpr.
					EXPECT().
					Expr().
					Return(test.whereCondition.query, test.whereCondition.args, test.whereCondition.errs)

				builder = builder.Where(mockExpr)
			}

			// Each ORDER BY item gets its own mocked expression.
			for _, orderItem := range test.orderItems {
				mockExpr := mock.NewMockTypedTableExpr[*mock.MockTable, genorm.WrappedPrimitive[bool]](ctrl)
				mockExpr.
					EXPECT().
					Expr().
					Return(orderItem.expr.query, orderItem.expr.args, orderItem.expr.errs)
				builder = builder.OrderBy(orderItem.direction, mockExpr)
			}

			if test.limit > 0 {
				builder = builder.Limit(test.limit)
			}

			query, args, err := builder.BuildQuery()

			// Error cases only assert that an error occurred; the query
			// and arguments are not compared for them.
			if test.err {
				assert.Error(t, err)
				return
			} else {
				if !assert.NoError(t, err) {
					return
				}
			}

			assert.Equal(t, test.query, query)
			assert.Equal(t, test.args, args)
		})
	}
}