├── .editorconfig ├── .github └── workflows │ └── go.yml ├── .gitignore ├── LICENSE ├── README.md ├── README_cn.md ├── cmd └── cmd.go ├── example ├── build.sh ├── bun │ ├── user_model.gen.go │ └── user_model.go ├── example.sql ├── example_test │ ├── NOTES │ ├── bun │ │ ├── mock.go │ │ └── user_model.gen_test.go │ ├── gorm │ │ ├── mock.go │ │ └── user_model.gen_test.go │ ├── readme.md │ ├── sql │ │ ├── mock.go │ │ ├── scanner.go │ │ └── user_model.gen_test.go │ ├── sqlx │ │ ├── mock.go │ │ └── user_model.gen_test.go │ └── xorm │ │ ├── mock.go │ │ └── user_model.gen_test.go ├── go.mod ├── go.sum ├── gorm │ ├── user_model.gen.go │ └── user_model.go ├── sql │ ├── scanner.go │ ├── user_model.gen.go │ └── user_model.go ├── sqlx │ ├── user_model.gen.go │ └── user_model.go └── xorm │ ├── user_model.gen.go │ └── user_model.go ├── go.mod ├── go.sum ├── internal ├── buffer │ ├── buffer.go │ └── buffer_test.go ├── format │ ├── format.go │ └── format_test.go ├── gen │ ├── bun │ │ ├── bun.go │ │ ├── bun_custom.tpl │ │ ├── bun_gen.tpl │ │ └── bun_test.go │ ├── flags │ │ └── flags.go │ ├── gorm │ │ ├── funcmap.go │ │ ├── gorm.go │ │ ├── gorm_custom.tpl │ │ ├── gorm_gen.tpl │ │ ├── gorm_test.go │ │ └── table.go │ ├── sql │ │ ├── scanner.tpl │ │ ├── sql.go │ │ ├── sql_custom.tpl │ │ ├── sql_gen.tpl │ │ └── sql_test.go │ ├── sqlx │ │ ├── sqlx.go │ │ ├── sqlx_custom.tpl │ │ ├── sqlx_gen.tpl │ │ └── sqlx_test.go │ ├── testdata │ │ ├── test.sql │ │ └── testdata.go │ └── xorm │ │ ├── xorm.go │ │ ├── xorm_custom.tpl │ │ ├── xorm_gen.tpl │ │ └── xorm_test.go ├── infoschema │ ├── LICENSE │ ├── infoschemamodel.go │ ├── infoschemamodel_test.go │ └── mock_infoschemamodel.go ├── log │ ├── log.go │ └── log_test.go ├── parameter │ ├── parameter.go │ └── parameter_test.go ├── parser │ ├── column.go │ ├── column_test.go │ ├── comment.go │ ├── comment_test.go │ ├── ddl.go │ ├── delete.go │ ├── delete_test.go │ ├── dml.go │ ├── dml_test.go │ ├── error.go │ ├── funcmap.go │ ├── infoschema.go │ ├── 
infoschema_test.go │ ├── init.tpl.sql │ ├── insert.go │ ├── insert_test.go │ ├── parser.go │ ├── parser_test.go │ ├── select.go │ ├── select_test.go │ ├── table.go │ ├── test.sql │ ├── update.go │ └── update_test.go ├── patterns │ ├── patterns.go │ └── patterns_test.go ├── set │ ├── set.go │ └── set_test.go ├── spec │ ├── action.go │ ├── byitem.go │ ├── clause.go │ ├── column.tpl │ ├── comment.go │ ├── constraint.go │ ├── converter.go │ ├── ddl.go │ ├── delete.go │ ├── dml.go │ ├── insert.go │ ├── limit.go │ ├── op.go │ ├── select.go │ ├── spec.go │ ├── sql.go │ ├── stmt.go │ ├── table.go │ ├── transaction.go │ ├── type.go │ └── update.go ├── stringx │ ├── stringx.go │ └── stringx_test.go └── templatex │ ├── funcmap.go │ ├── funcmap_test.go │ ├── templatex.go │ └── templatex_test.go └── main.go /.editorconfig: -------------------------------------------------------------------------------- 1 | root = true 2 | 3 | [*] 4 | charset = utf-8 5 | end_of_line = lf 6 | indent_size = 4 7 | indent_style = space 8 | insert_final_newline = false 9 | max_line_length = 100 10 | tab_width = 4 11 | ij_continuation_indent_size = 8 12 | ij_formatter_off_tag = @formatter:off 13 | ij_formatter_on_tag = @formatter:on 14 | ij_formatter_tags_enabled = false 15 | ij_smart_tabs = false 16 | ij_visual_guides = 80,100 17 | ij_wrap_on_typing = false 18 | 19 | [{*.go,*.go2}] 20 | indent_style = tab 21 | ij_continuation_indent_size = 4 22 | ij_visual_guides = none 23 | ij_go_GROUP_CURRENT_PROJECT_IMPORTS = true 24 | ij_go_add_leading_space_to_comments = false 25 | ij_go_add_parentheses_for_single_import = false 26 | ij_go_call_parameters_new_line_after_left_paren = true 27 | ij_go_call_parameters_right_paren_on_new_line = true 28 | ij_go_call_parameters_wrap = off 29 | ij_go_fill_paragraph_width = 80 30 | ij_go_group_stdlib_imports = true 31 | ij_go_import_sorting = goimports 32 | ij_go_keep_indents_on_empty_lines = false 33 | ij_go_local_group_mode = project 34 | 
ij_go_move_all_imports_in_one_declaration = true 35 | ij_go_move_all_stdlib_imports_in_one_group = true 36 | ij_go_remove_redundant_import_aliases = false 37 | ij_go_run_go_fmt_on_reformat = true 38 | ij_go_use_back_quotes_for_imports = false 39 | ij_go_wrap_comp_lit = off 40 | ij_go_wrap_comp_lit_newline_after_lbrace = true 41 | ij_go_wrap_comp_lit_newline_before_rbrace = true 42 | ij_go_wrap_func_params = off 43 | ij_go_wrap_func_params_newline_after_lparen = true 44 | ij_go_wrap_func_params_newline_before_rparen = true 45 | ij_go_wrap_func_result = off 46 | ij_go_wrap_func_result_newline_after_lparen = true 47 | ij_go_wrap_func_result_newline_before_rparen = true 48 | -------------------------------------------------------------------------------- /.github/workflows/go.yml: -------------------------------------------------------------------------------- 1 | name: Go 2 | 3 | on: 4 | push: 5 | branches: [ "main" ] 6 | pull_request: 7 | branches: [ "main" ] 8 | 9 | jobs: 10 | 11 | build: 12 | runs-on: ubuntu-latest 13 | steps: 14 | - uses: actions/checkout@v3 15 | 16 | - name: Set up Go 17 | uses: actions/setup-go@v3 18 | with: 19 | go-version: 1.18 20 | 21 | - name: Build 22 | run: go build -v ./... 23 | 24 | - name: Test 25 | run: go test -race -coverprofile=coverage.txt -covermode=atomic ./... 
26 | 27 | - uses: codecov/codecov-action@v2 28 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Binaries for programs and plugins 2 | *.exe 3 | *.exe~ 4 | *.dll 5 | *.so 6 | *.dylib 7 | 8 | # Test binary, built with `go test -c` 9 | *.test 10 | 11 | # Output of the go coverage tool, specifically when used with LiteIDE 12 | *.out 13 | 14 | # Dependency directories (remove the comment below to include it) 15 | # vendor/ 16 | 17 | .DS_Store 18 | .idea -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 anqiansong 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # sqlgen 2 | 3 | English | [中文](README_cn.md) 4 | 5 | [![Go](https://github.com/anqiansong/sqlgen/actions/workflows/go.yml/badge.svg?branch=main)](https://github.com/anqiansong/sqlgen/actions/workflows/go.yml) 6 | [![codecov](https://codecov.io/gh/anqiansong/sqlgen/branch/main/graph/badge.svg?token=8mLCFUqD2l)](https://codecov.io/gh/anqiansong/sqlgen) 7 | [![Go Reference](https://pkg.go.dev/badge/github.com/anqiansong/sqlgen.svg)](https://pkg.go.dev/github.com/anqiansong/sqlgen) 8 | [![Go Report Card](https://goreportcard.com/badge/github.com/anqiansong/sqlgen)](https://goreportcard.com/report/github.com/anqiansong/sqlgen) 9 | [![Release](https://img.shields.io/github/v/release/anqiansong/sqlgen.svg?style=flat-square)](https://github.com/anqiansong/sqlgen) 10 | [![GitHub license](https://img.shields.io/github/license/anqiansong/sqlgen?style=flat-square)](https://github.com/anqiansong/sqlgen/blob/main/LICENSE) 11 | 12 | sqlgen is a tool to generate **bun**, **gorm**, **sql**, **sqlx** and **xorm** sql code from SQL 13 | file which is inspired by 14 | 15 | - [go-zero](https://github.com/zeromicro/go-zero) 16 | - [goctl](https://github.com/zeromicro/go-zero/tree/master/tools/goctl) 17 | - [sqlc](https://github.com/kyleconroy/sqlc). 18 | 19 | # Installation 20 | 21 | ```bash 22 | go install github.com/anqiansong/sqlgen@latest 23 | ``` 24 | 25 | # Example 26 | 27 | See [example](https://github.com/anqiansong/sqlgen/tree/main/example) 28 | 29 | # Queries rule 30 | 31 | ## 1. Function Name 32 | 33 | You can define a function via `fn` keyword in line comment, for example: 34 | 35 | ```sql 36 | -- fn: my_func 37 | SELECT * 38 | FROM user; 39 | ``` 40 | 41 | it will be generated as: 42 | 43 | ```go 44 | func (m *UserModel) my_func (...) { 45 | ... 46 | } 47 | ``` 48 | 49 | ## 2. 
Get One Record 50 | 51 | The expression `limit 1` must be explicitly defined if you want to get only one record; sqlgen uses it to decide whether a query returns a single record or a list, for example: 52 | 53 | ```sql 54 | -- fn: FindOne 55 | select * 56 | from user 57 | where id = ? limit 1; 58 | ``` 59 | 60 | ## 3. Marker or Values? 61 | 62 | For SQL arguments, you can use either the `?` marker or an explicit value; in both cases sqlgen converts the argument 63 | into a variable, so the following queries are equivalent: 64 | 65 | > NOTE: This does not apply to rule 2. 66 | 67 | ```sql 68 | -- fn: FindLimit 69 | select * 70 | from user 71 | where id = ?; 72 | 73 | -- fn: FindLimit 74 | select * 75 | from user 76 | where id = 1; 77 | 78 | ``` 79 | 80 | ## 4. SQL Function 81 | 82 | sqlgen supports aggregate functions in SQL queries; other functions are not 83 | supported yet. Every aggregate function expression must be given an alias with `AS`, for 84 | example: 85 | 86 | ```sql 87 | -- fn: CountAll 88 | select count(*) as count 89 | from user; 90 | ``` 91 | 92 | For more query examples, 93 | see [example.sql](https://github.com/anqiansong/sqlgen/blob/main/example/example.sql). 94 | 95 | # How it works 96 | 97 | 1. Create a SQL file 98 | 2. Write your SQL code in the SQL file 99 | 3. Run `sqlgen` to generate code 100 | 101 | # Notes 102 | 103 | 1. Only MySQL code generation is supported. 104 | 2. Multiple tables in one SQL file are not supported. 105 | 3. Join queries are not supported.
-------------------------------------------------------------------------------- /README_cn.md: -------------------------------------------------------------------------------- 1 | # sqlgen 2 | 3 | [English](README.md) | 中文 4 | 5 | [![Go](https://github.com/anqiansong/sqlgen/actions/workflows/go.yml/badge.svg?branch=main)](https://github.com/anqiansong/sqlgen/actions/workflows/go.yml) 6 | [![codecov](https://codecov.io/gh/anqiansong/sqlgen/branch/main/graph/badge.svg?token=8mLCFUqD2l)](https://codecov.io/gh/anqiansong/sqlgen) 7 | [![Go Reference](https://pkg.go.dev/badge/github.com/anqiansong/sqlgen.svg)](https://pkg.go.dev/github.com/anqiansong/sqlgen) 8 | [![Go Report Card](https://goreportcard.com/badge/github.com/anqiansong/sqlgen)](https://goreportcard.com/report/github.com/anqiansong/sqlgen) 9 | [![Release](https://img.shields.io/github/v/release/anqiansong/sqlgen.svg?style=flat-square)](https://github.com/anqiansong/sqlgen) 10 | [![GitHub license](https://img.shields.io/github/license/anqiansong/sqlgen?style=flat-square)](https://github.com/anqiansong/sqlgen/blob/main/LICENSE) 11 | 12 | sqlgen 是一个 SQL 代码生成工具,其支持 **bun**, **gorm**, **sql**, **sqlx**, **xorm** 的代码生成,灵感来自于: 13 | 14 | - [go-zero](https://github.com/zeromicro/go-zero) 15 | - [goctl](https://github.com/zeromicro/go-zero/tree/master/tools/goctl) 16 | - [sqlc](https://github.com/kyleconroy/sqlc). 17 | 18 | 19 | # 安装 20 | 21 | ```bash 22 | go install github.com/anqiansong/sqlgen@latest 23 | ``` 24 | 25 | # 示例 26 | 27 | 见 [example](https://github.com/anqiansong/sqlgen/tree/main/example) 28 | 29 | # SQL 查询编写规则 30 | ## 1. 函数名称 31 | 你可以通过在查询语句上方添加一个单行注释,用 `fn` 关键字来声明一个函数名称,例如: 32 | 33 | ```sql 34 | -- fn: my_func 35 | SELECT * FROM user; 36 | ``` 37 | 38 | 其生成后代码格式为: 39 | 40 | ```go 41 | func (m *UserModel) my_func (...) { 42 | ... 43 | } 44 | ``` 45 | 46 | ## 2. 
查询一条记录 47 | 当你只想要查询一条记录的需求时,你必须明确地指定 `limit 1`,sqlgen 通过此表达式来判断当前查询是单记录查询还是多记录查询,例如: 48 | 49 | ```sql 50 | -- fn: FindOne 51 | select * from user where id = ? limit 1; 52 | ``` 53 | 54 | ## 3. 使用 '?' 还是具体值? 55 | 在 SQL 查询语句的编写中,你可以用 `?` 来替代一个参数,也可以是具体值,他们最终都会被 sqlgen 转换成一个变量,下列示例中的两个查询是等价的。 56 | 57 | > 注意: 此规则不适用于规则 2 58 | 59 | ```sql 60 | -- fn: FindLimit 61 | select * from user where id = ?; 62 | 63 | -- fn: FindLimit 64 | select * from user where id = 1; 65 | 66 | ``` 67 | 68 | ## 4. SQL 内置函数支持 69 | sqlgen 支持 SQL 内置的聚合函数查询,除此之外的其他函数暂不支持,聚合函数查询的列必须要用 `AS` 来起一个别名,例如: 70 | 71 | ```sql 72 | -- fn: CountAll 73 | select count(*) as count from user; 74 | ``` 75 | 76 | 更多查询示例, 你可以点击 [example.sql](https://github.com/anqiansong/sqlgen/blob/main/example/example.sql) 查看详情. 77 | 78 | # sqlgen 使用步骤 79 | 1. 创建一个 SQL 文件 80 | 2. 编写 SQL 查询语句,如建表语句、查询语句等 81 | 3. 使用 `sqlgen` 工具,生成代码 82 | 83 | # 注意 84 | 1. 目前只支持 MYSQL 代码生成 85 | 3. 不支持多表操作 86 | 4. 不支持联表查询 -------------------------------------------------------------------------------- /cmd/cmd.go: -------------------------------------------------------------------------------- 1 | package cmd 2 | 3 | import ( 4 | "os" 5 | 6 | "github.com/anqiansong/sqlgen/internal/gen/flags" 7 | "github.com/spf13/cobra" 8 | ) 9 | 10 | const buildVersion = "0.0.3" 11 | 12 | var arg flags.RunArg 13 | 14 | var rootCmd = &cobra.Command{ 15 | Use: "sqlgen", 16 | Short: "A cli for mysql generator", 17 | } 18 | 19 | var sqlCmd = &cobra.Command{ 20 | Use: "sql", 21 | Short: "Generate sql model", 22 | Run: func(cmd *cobra.Command, args []string) { 23 | arg.Mode = flags.SQL 24 | flags.Run(arg) 25 | }, 26 | } 27 | 28 | var gormCmd = &cobra.Command{ 29 | Use: "gorm", 30 | Short: "Generate gorm model", 31 | Run: func(cmd *cobra.Command, args []string) { 32 | arg.Mode = flags.GORM 33 | flags.Run(arg) 34 | }, 35 | } 36 | 37 | var xormCmd = &cobra.Command{ 38 | Use: "xorm", 39 | Short: "Generate xorm model", 40 | Run: func(cmd *cobra.Command, args []string) { 41 | 
arg.Mode = flags.XORM 42 | flags.Run(arg) 43 | }, 44 | } 45 | 46 | var sqlxCmd = &cobra.Command{ 47 | Use: "sqlx", 48 | Short: "Generate sqlx model", 49 | Run: func(cmd *cobra.Command, args []string) { 50 | arg.Mode = flags.SQLX 51 | flags.Run(arg) 52 | }, 53 | } 54 | 55 | var bunCmd = &cobra.Command{ 56 | Use: "bun", 57 | Short: "Generate bun model", 58 | Run: func(cmd *cobra.Command, args []string) { 59 | arg.Mode = flags.BUN 60 | flags.Run(arg) 61 | }, 62 | } 63 | 64 | func init() { 65 | // flags init 66 | var persistentFlags = rootCmd.PersistentFlags() 67 | persistentFlags.StringVarP(&arg.DSN, "dsn", "d", "", "Mysql address") 68 | persistentFlags.StringSliceVarP(&arg.Table, "table", "t", []string{"*"}, "Patterns of table name") 69 | persistentFlags.StringSliceVarP(&arg.Filename, "filename", "f", []string{"*.sql"}, "Patterns of SQL filename") 70 | persistentFlags.StringVarP(&arg.Output, "output", "o", ".", "The output directory") 71 | 72 | // sub commands init 73 | rootCmd.AddCommand(bunCmd) 74 | rootCmd.AddCommand(gormCmd) 75 | rootCmd.AddCommand(sqlCmd) 76 | rootCmd.AddCommand(sqlxCmd) 77 | rootCmd.AddCommand(xormCmd) 78 | rootCmd.Version = buildVersion 79 | rootCmd.CompletionOptions.DisableDefaultCmd = true 80 | } 81 | 82 | // Execute executes the sql cmd. 83 | func Execute() { 84 | if err := rootCmd.Execute(); err != nil { 85 | os.Exit(1) 86 | } 87 | } 88 | -------------------------------------------------------------------------------- /example/build.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | wd=$(pwd) 4 | 5 | function generate() { 6 | dir=$1 7 | name=$2 8 | output="$dir/$name" 9 | if [ "$output" == "/" ]; then 10 | exit 1 11 | fi 12 | 13 | rm -rf "$output" 14 | mkdir -p "$output" 15 | 16 | cd "$output" 17 | sqlgen $name -f "$dir/example.sql" -o . 
18 | } 19 | 20 | 21 | cd "$wd" 22 | 23 | # generate bun code 24 | generate "$wd" "bun" 25 | 26 | # generate gorm code 27 | generate "$wd" "gorm" 28 | 29 | # generate sql code 30 | generate "$wd" "sql" 31 | 32 | # generate sqlx code 33 | generate "$wd" "sqlx" 34 | 35 | # generate xorm code 36 | generate "$wd" "xorm" 37 | 38 | # go mod tidy 39 | go mod tidy 40 | 41 | go test ./... -------------------------------------------------------------------------------- /example/bun/user_model.go: -------------------------------------------------------------------------------- 1 | package model 2 | 3 | import "context" 4 | 5 | // TODO(sqlgen): Add your own customize code here. 6 | func (m *UserModel) Customize(ctx context.Context, args ...interface{}) { 7 | 8 | } 9 | -------------------------------------------------------------------------------- /example/example_test/NOTES: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 zeromicro 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /example/example_test/bun/mock.go: -------------------------------------------------------------------------------- 1 | package sql 2 | 3 | import ( 4 | "time" 5 | 6 | model "github.com/anqiansong/sqlgen/example/bun" 7 | uuid "github.com/satori/go.uuid" 8 | ) 9 | 10 | func mustMockUser() *model.User { 11 | uid := uuid.NewV4().String() 12 | now := time.Now() 13 | return &model.User{ 14 | Name: uid, 15 | Password: "bar", 16 | Mobile: uid, 17 | Gender: "male", 18 | Nickname: "test", 19 | Type: 1, 20 | CreateAt: now, 21 | UpdateAt: now, 22 | } 23 | } 24 | -------------------------------------------------------------------------------- /example/example_test/gorm/mock.go: -------------------------------------------------------------------------------- 1 | package sql 2 | 3 | import ( 4 | "time" 5 | 6 | model "github.com/anqiansong/sqlgen/example/gorm" 7 | uuid "github.com/satori/go.uuid" 8 | ) 9 | 10 | func mustMockUser() *model.User { 11 | uid := uuid.NewV4().String() 12 | now := time.Now() 13 | return &model.User{ 14 | Name: uid, 15 | Password: "bar", 16 | Mobile: uid, 17 | Gender: "male", 18 | Nickname: "test", 19 | Type: 1, 20 | CreateAt: now, 21 | UpdateAt: now, 22 | } 23 | } 24 | -------------------------------------------------------------------------------- /example/example_test/readme.md: -------------------------------------------------------------------------------- 1 | # example_test 2 | 3 | ## before test 4 | 1. started docker 5 | 2. run a mysql container which dsn is `root:mysqlpw@(localhost:55000)` 6 | 3. new a schema `test` 7 | 4. 
create a table using the following SQL: 8 | ```sql 9 | CREATE TABLE `user` 10 | ( 11 | `id` bigint(10) unsigned NOT NULL AUTO_INCREMENT primary key, 12 | `name` varchar(255) COLLATE utf8mb4_general_ci NULL COMMENT 'The username', 13 | `password` varchar(255) COLLATE utf8mb4_general_ci NOT NULL DEFAULT '' COMMENT 'The \n user password', 14 | `mobile` varchar(255) COLLATE utf8mb4_general_ci NOT NULL DEFAULT '' COMMENT 'The mobile phone number', 15 | `gender` char(10) COLLATE utf8mb4_general_ci NOT NULL COMMENT 'gender,male|female|unknown', 16 | `nickname` varchar(255) COLLATE utf8mb4_general_ci DEFAULT '' COMMENT 'The nickname', 17 | `type` tinyint(1) COLLATE utf8mb4_general_ci DEFAULT 0 COMMENT 'The user type, 0:normal,1:vip, for test golang keyword', 18 | `create_at` timestamp NULL, 19 | `update_at` timestamp NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, 20 | UNIQUE KEY `name_index` (`name`), 21 | UNIQUE KEY `mobile_index` (`mobile`) 22 | ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COMMENT 'user table' COLLATE=utf8mb4_general_ci; 23 | ``` 24 | 5. clear the test data and reset `AUTO_INCREMENT` to `1` 25 | 6.
run the test 26 | 27 | 28 | -------------------------------------------------------------------------------- /example/example_test/sql/mock.go: -------------------------------------------------------------------------------- 1 | package sql 2 | 3 | import ( 4 | "time" 5 | 6 | model "github.com/anqiansong/sqlgen/example/sql" 7 | uuid "github.com/satori/go.uuid" 8 | ) 9 | 10 | func mustMockUser() *model.User { 11 | uid := uuid.NewV4().String() 12 | now := time.Now() 13 | return &model.User{ 14 | Name: uid, 15 | Password: "bar", 16 | Mobile: uid, 17 | Gender: "male", 18 | Nickname: "test", 19 | Type: 1, 20 | CreateAt: now, 21 | UpdateAt: now, 22 | } 23 | } 24 | -------------------------------------------------------------------------------- /example/example_test/sql/scanner.go: -------------------------------------------------------------------------------- 1 | package sql 2 | 3 | import ( 4 | "database/sql" 5 | "errors" 6 | "reflect" 7 | 8 | model "github.com/anqiansong/sqlgen/example/sql" 9 | "github.com/iancoleman/strcase" 10 | ) 11 | 12 | type customScanner struct { 13 | } 14 | 15 | func (c customScanner) ColumnMapper(colName string) string { 16 | return strcase.ToCamel(colName) 17 | } 18 | 19 | func (c customScanner) TagKey() string { 20 | return `db` 21 | } 22 | 23 | func (c customScanner) getRowElem(rows *sql.Rows, v interface{}) ([]interface{}, error) { 24 | var elem reflect.Value 25 | value, ok := v.(reflect.Value) 26 | if !ok { 27 | elem = reflect.ValueOf(v) 28 | } else { 29 | elem = value 30 | } 31 | 32 | switch elem.Kind() { 33 | case reflect.Pointer: 34 | return c.getRowElem(rows, elem.Elem()) 35 | case reflect.Struct: 36 | var list []interface{} 37 | cols, err := rows.Columns() 38 | if err != nil { 39 | return nil, err 40 | } 41 | 42 | targetField := make(map[string]reflect.Value) 43 | for i := 0; i < elem.NumField(); i++ { 44 | f := elem.Field(i) 45 | t := elem.Type().Field(i) 46 | tag, ok := t.Tag.Lookup(c.TagKey()) 47 | if ok { 48 | targetField[tag] = 
f 49 | } 50 | } 51 | 52 | for _, name := range cols { 53 | f, ok := targetField[name] 54 | if !ok { 55 | f = elem.FieldByName(c.ColumnMapper(name)) 56 | } 57 | if f.CanAddr() { 58 | list = append(list, f.Addr().Interface()) 59 | } 60 | } 61 | return list, nil 62 | default: 63 | return nil, errors.New("expect a struct") 64 | } 65 | } 66 | 67 | // getRowsElem is inspired by https://github.com/zeromicro/go-zero/blob/8ed22eafdda04c4526164450d7c13c2f4b0f076c/core/stores/sqlx/orm.go#L163 68 | func (c customScanner) getRowsElem(rows *sql.Rows, v interface{}) error { 69 | valueOf := reflect.ValueOf(v) 70 | if valueOf.Kind() != reflect.Ptr { 71 | return errors.New("expect a pointer") 72 | } 73 | 74 | typeOf := reflect.TypeOf(v) 75 | sliceTypeOf := typeOf.Elem() 76 | sliceValueOf := valueOf.Elem() 77 | 78 | if sliceTypeOf.Kind() != reflect.Slice { 79 | return errors.New("expect a slice") 80 | } 81 | if !sliceValueOf.CanSet() { 82 | return errors.New("expect a settable slice") 83 | } 84 | isASlicePointer := sliceTypeOf.Elem().Kind() == reflect.Ptr 85 | 86 | var itemReceiver reflect.Type 87 | itemType := sliceTypeOf.Elem() 88 | if itemType.Kind() == reflect.Ptr { 89 | itemReceiver = itemType.Elem() 90 | } else { 91 | itemReceiver = itemType 92 | } 93 | if itemReceiver.Kind() != reflect.Struct { 94 | return errors.New("expect a struct") 95 | } 96 | 97 | for rows.Next() { 98 | value := reflect.New(itemReceiver) 99 | dest, err := c.getRowElem(rows, value) 100 | if err != nil { 101 | return err 102 | } 103 | 104 | err = rows.Scan(dest...) 
105 | if err != nil { 106 | return err 107 | } 108 | 109 | if isASlicePointer { 110 | sliceValueOf.Set(reflect.Append(sliceValueOf, value)) 111 | } else { 112 | sliceValueOf.Set(reflect.Append(sliceValueOf, reflect.Indirect(value))) // value holds a *T; append the T it points to, not the destination slice itself 113 | } 114 | } 115 | 116 | return rows.Err() // surface any error that ended the rows.Next loop early 117 | } 118 | 119 | func (c customScanner) ScanRow(rows *sql.Rows, v interface{}) error { 120 | if !rows.Next() { 121 | return sql.ErrNoRows 122 | } 123 | 124 | dest, err := c.getRowElem(rows, v) 125 | if err != nil { 126 | return err 127 | } 128 | 129 | return rows.Scan(dest...) 130 | } 131 | 132 | func (c customScanner) ScanRows(rows *sql.Rows, v interface{}) error { 133 | return c.getRowsElem(rows, v) 134 | } 135 | 136 | func getScanner() model.Scanner { 137 | return customScanner{} 138 | } 139 | -------------------------------------------------------------------------------- /example/example_test/sqlx/mock.go: -------------------------------------------------------------------------------- 1 | package sql 2 | 3 | import ( 4 | "time" 5 | 6 | model "github.com/anqiansong/sqlgen/example/sqlx" 7 | uuid "github.com/satori/go.uuid" 8 | ) 9 | 10 | func mustMockUser() *model.User { 11 | uid := uuid.NewV4().String() 12 | now := time.Now() 13 | return &model.User{ 14 | Name: uid, 15 | Password: "bar", 16 | Mobile: uid, 17 | Gender: "male", 18 | Nickname: "test", 19 | Type: 1, 20 | CreateAt: now, 21 | UpdateAt: now, 22 | } 23 | } 24 | -------------------------------------------------------------------------------- /example/example_test/xorm/mock.go: -------------------------------------------------------------------------------- 1 | package sql 2 | 3 | import ( 4 | "time" 5 | 6 | model "github.com/anqiansong/sqlgen/example/xorm" 7 | uuid "github.com/satori/go.uuid" 8 | ) 9 | 10 | func mustMockUser() *model.User { 11 | uid := uuid.NewV4().String() 12 | now := time.Now() 13 | return &model.User{ 14 | Name: uid, 15 | Password: "bar", 16 | Mobile: uid, 17 | Gender: "male", 18 | Nickname: "test", 19 |
Type: 1, 20 | CreateAt: now, 21 | UpdateAt: now, 22 | } 23 | } 24 | -------------------------------------------------------------------------------- /example/go.mod: -------------------------------------------------------------------------------- 1 | module github.com/anqiansong/sqlgen/example 2 | 3 | go 1.18 4 | 5 | require ( 6 | github.com/go-sql-driver/mysql v1.6.0 7 | github.com/iancoleman/strcase v0.2.0 8 | github.com/jmoiron/sqlx v1.3.5 9 | github.com/satori/go.uuid v1.2.0 10 | github.com/shopspring/decimal v1.2.0 11 | github.com/stretchr/testify v1.7.0 12 | github.com/uptrace/bun v1.1.7 13 | github.com/uptrace/bun/dialect/mysqldialect v1.1.7 14 | github.com/uptrace/bun/extra/bundebug v1.1.7 15 | gorm.io/driver/mysql v1.3.6 16 | gorm.io/gorm v1.23.8 17 | xorm.io/builder v0.3.12 18 | xorm.io/xorm v1.3.1 19 | ) 20 | 21 | require ( 22 | github.com/davecgh/go-spew v1.1.1 // indirect 23 | github.com/fatih/color v1.13.0 // indirect 24 | github.com/goccy/go-json v0.8.1 // indirect 25 | github.com/golang/snappy v0.0.4 // indirect 26 | github.com/jinzhu/inflection v1.0.0 // indirect 27 | github.com/jinzhu/now v1.1.5 // indirect 28 | github.com/json-iterator/go v1.1.12 // indirect 29 | github.com/mattn/go-colorable v0.1.12 // indirect 30 | github.com/mattn/go-isatty v0.0.14 // indirect 31 | github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect 32 | github.com/modern-go/reflect2 v1.0.2 // indirect 33 | github.com/pmezard/go-difflib v1.0.0 // indirect 34 | github.com/syndtr/goleveldb v1.0.0 // indirect 35 | github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc // indirect 36 | github.com/vmihailenco/msgpack/v5 v5.3.5 // indirect 37 | github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect 38 | golang.org/x/mod v0.5.1 // indirect 39 | golang.org/x/sys v0.0.0-20220728004956-3c1f35247d10 // indirect 40 | gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c // indirect 41 | ) 42 | 
-------------------------------------------------------------------------------- /example/gorm/user_model.go: -------------------------------------------------------------------------------- 1 | package model 2 | 3 | import "context" 4 | 5 | // TODO(sqlgen): Add your own customize code here. 6 | func (m *UserModel) Customize(ctx context.Context, args ...interface{}) { 7 | 8 | } 9 | -------------------------------------------------------------------------------- /example/sql/scanner.go: -------------------------------------------------------------------------------- 1 | package model 2 | 3 | import "database/sql" 4 | 5 | type Scanner interface { 6 | ScanRow(rows *sql.Rows, v interface{}) error 7 | ScanRows(rows *sql.Rows, v interface{}) error 8 | ColumnMapper(colName string) string 9 | TagKey() string 10 | } 11 | -------------------------------------------------------------------------------- /example/sql/user_model.go: -------------------------------------------------------------------------------- 1 | package model 2 | 3 | import "context" 4 | 5 | // TODO(sqlgen): Add your own customize code here. 6 | func (m *UserModel) Customize(ctx context.Context, args ...interface{}) { 7 | 8 | } 9 | -------------------------------------------------------------------------------- /example/sqlx/user_model.go: -------------------------------------------------------------------------------- 1 | package model 2 | 3 | import "context" 4 | 5 | // TODO(sqlgen): Add your own customize code here. 6 | func (m *UserModel) Customize(ctx context.Context, args ...any) { 7 | 8 | } 9 | -------------------------------------------------------------------------------- /example/xorm/user_model.go: -------------------------------------------------------------------------------- 1 | package model 2 | 3 | import "context" 4 | 5 | // TODO(sqlgen): Add your own customize code here. 
6 | func (m *UserModel) Customize(ctx context.Context, args ...interface{}) { 7 | 8 | } 9 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/anqiansong/sqlgen 2 | 3 | go 1.18 4 | 5 | require ( 6 | github.com/DATA-DOG/go-sqlmock v1.5.0 7 | github.com/agiledragon/gomonkey/v2 v2.8.0 8 | github.com/go-sql-driver/mysql v1.6.0 9 | github.com/golang/mock v1.6.0 10 | github.com/iancoleman/strcase v0.2.0 11 | github.com/pingcap/parser v0.0.0-20220622031236-3bca03d3057b 12 | github.com/spf13/cobra v1.5.0 13 | github.com/stretchr/testify v1.8.0 14 | github.com/zeromicro/go-zero v1.3.5 15 | golang.org/x/tools v0.1.12 16 | ) 17 | 18 | require ( 19 | github.com/BurntSushi/toml v1.2.0 // indirect 20 | github.com/benbjohnson/clock v1.3.0 // indirect 21 | github.com/cznic/mathutil v0.0.0-20181122101859-297441e03548 // indirect 22 | github.com/davecgh/go-spew v1.1.1 // indirect 23 | github.com/fatih/color v1.13.0 // indirect 24 | github.com/go-logr/logr v1.2.3 // indirect 25 | github.com/go-logr/stdr v1.2.2 // indirect 26 | github.com/golang/protobuf v1.5.2 // indirect 27 | github.com/inconshreveable/mousetrap v1.0.0 // indirect 28 | github.com/mattn/go-colorable v0.1.12 // indirect 29 | github.com/mattn/go-isatty v0.0.14 // indirect 30 | github.com/openzipkin/zipkin-go v0.4.0 // indirect 31 | github.com/pelletier/go-toml/v2 v2.0.2 // indirect 32 | github.com/pingcap/check v0.0.0-20200212061837-5e12011dc712 // indirect 33 | github.com/pingcap/errors v0.11.5-0.20210425183316-da1aaba5fb63 // indirect 34 | github.com/pingcap/log v1.1.0 // indirect 35 | github.com/pmezard/go-difflib v1.0.0 // indirect 36 | github.com/remyoudompheng/bigfft v0.0.0-20200410134404-eec4a21b6bb0 // indirect 37 | github.com/spaolacci/murmur3 v1.1.0 // indirect 38 | github.com/spf13/pflag v1.0.5 // indirect 39 | go.opentelemetry.io/otel v1.8.0 // indirect 40 | 
go.opentelemetry.io/otel/exporters/jaeger v1.8.0 // indirect 41 | go.opentelemetry.io/otel/exporters/zipkin v1.8.0 // indirect 42 | go.opentelemetry.io/otel/sdk v1.8.0 // indirect 43 | go.opentelemetry.io/otel/trace v1.8.0 // indirect 44 | go.uber.org/atomic v1.9.0 // indirect 45 | go.uber.org/automaxprocs v1.5.1 // indirect 46 | go.uber.org/multierr v1.8.0 // indirect 47 | go.uber.org/zap v1.21.0 // indirect 48 | golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4 // indirect 49 | golang.org/x/sys v0.0.0-20220727055044-e65921a090b8 // indirect 50 | golang.org/x/text v0.3.7 // indirect 51 | google.golang.org/grpc v1.48.0 // indirect 52 | google.golang.org/protobuf v1.28.0 // indirect 53 | gopkg.in/natefinch/lumberjack.v2 v2.0.0 // indirect 54 | gopkg.in/yaml.v2 v2.4.0 // indirect 55 | gopkg.in/yaml.v3 v3.0.1 // indirect 56 | ) 57 | -------------------------------------------------------------------------------- /internal/buffer/buffer.go: -------------------------------------------------------------------------------- 1 | package buffer 2 | 3 | import ( 4 | "fmt" 5 | "strings" 6 | ) 7 | 8 | type b struct { 9 | list []string 10 | } 11 | 12 | func New() *b { 13 | return &b{} 14 | } 15 | 16 | func (b *b) Reset() { 17 | b.list = nil 18 | } 19 | 20 | func (b *b) Write(format string, a ...any) { 21 | b.list = append(b.list, fmt.Sprintf(format, a...)) 22 | } 23 | 24 | func (b *b) String() string { 25 | return strings.Join(b.list, "\n") 26 | } 27 | -------------------------------------------------------------------------------- /internal/buffer/buffer_test.go: -------------------------------------------------------------------------------- 1 | package buffer 2 | 3 | import ( 4 | "strings" 5 | "testing" 6 | 7 | "github.com/stretchr/testify/assert" 8 | ) 9 | 10 | func TestNew(t *testing.T) { 11 | instance := New() 12 | assert.NotNil(t, instance) 13 | } 14 | 15 | func TestB_Reset(t *testing.T) { 16 | instance := New() 17 | instance.Write("foo") 18 | instance.Reset() 19 
| assert.Equal(t, 0, len(instance.list)) 20 | } 21 | 22 | func TestB_Write(t *testing.T) { 23 | testData := map[string]string{ 24 | "": "", 25 | "foo": "foo", 26 | "foo,bar": "foo\nbar", 27 | } 28 | instance := New() 29 | for input, expected := range testData { 30 | instance.Reset() 31 | fields := strings.FieldsFunc(input, func(r rune) bool { 32 | return r == ',' 33 | }) 34 | for _, field := range fields { 35 | instance.Write(field) 36 | } 37 | assert.Equal(t, expected, instance.String()) 38 | } 39 | } 40 | 41 | func TestB_String(t *testing.T) { 42 | instance := New() 43 | instance.Write("foo") 44 | assert.Equal(t, "foo", instance.String()) 45 | } 46 | -------------------------------------------------------------------------------- /internal/format/format.go: -------------------------------------------------------------------------------- 1 | package format 2 | 3 | import ( 4 | "go/format" 5 | 6 | "golang.org/x/tools/imports" 7 | ) 8 | 9 | // Source formats go code and imports. 10 | func Source(data []byte) ([]byte, error) { 11 | ret, err := format.Source(data) 12 | if err != nil { 13 | return nil, err 14 | } 15 | 16 | return imports.Process("", ret, nil) 17 | } 18 | -------------------------------------------------------------------------------- /internal/format/format_test.go: -------------------------------------------------------------------------------- 1 | package format 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/stretchr/testify/assert" 7 | ) 8 | 9 | func TestSource(t *testing.T) { 10 | testData := []struct { 11 | input []byte 12 | expect []byte 13 | error bool 14 | }{ 15 | { 16 | input: []byte(``), 17 | error: true, 18 | }, 19 | { 20 | input: []byte(`package p`), 21 | expect: []byte(`package p 22 | `), 23 | error: false, 24 | }, 25 | { 26 | input: []byte(`package p;`), 27 | expect: []byte(`package p 28 | `), 29 | error: false, 30 | }, 31 | { 32 | input: []byte(`package p 33 | import "fmt"`), 34 | expect: []byte(`package p 35 | `), 36 | error: false, 
37 | }, 38 | { 39 | input: []byte(`package p 40 | import "foo'`), 41 | error: true, 42 | }, 43 | } 44 | for _, data := range testData { 45 | actual, err := Source(data.input) 46 | if data.error { 47 | assert.Error(t, err) 48 | continue 49 | } 50 | assert.NoError(t, err) 51 | assert.Equal(t, string(data.expect), string(actual)) 52 | } 53 | } 54 | -------------------------------------------------------------------------------- /internal/gen/bun/bun.go: -------------------------------------------------------------------------------- 1 | package bun 2 | 3 | import ( 4 | _ "embed" 5 | "fmt" 6 | "path/filepath" 7 | "text/template" 8 | 9 | "github.com/anqiansong/sqlgen/internal/spec" 10 | "github.com/anqiansong/sqlgen/internal/templatex" 11 | ) 12 | 13 | //go:embed bun_gen.tpl 14 | var bunGenTpl string 15 | 16 | //go:embed bun_custom.tpl 17 | var bunCustomTpl string 18 | 19 | func Run(list []spec.Context, output string) error { 20 | for _, ctx := range list { 21 | var genFilename = filepath.Join(output, fmt.Sprintf("%s_model.gen.go", ctx.Table.Name)) 22 | var customFilename = filepath.Join(output, fmt.Sprintf("%s_model.go", ctx.Table.Name)) 23 | gen := templatex.New() 24 | gen.AppendFuncMap(template.FuncMap{ 25 | "IsPrimary": func(name string) bool { 26 | return ctx.Table.IsPrimary(name) 27 | }, 28 | }) 29 | gen.MustParse(bunGenTpl) 30 | gen.MustExecute(ctx) 31 | gen.MustSaveAs(genFilename, true) 32 | 33 | custom := templatex.New() 34 | custom.MustParse(bunCustomTpl) 35 | custom.MustExecute(ctx) 36 | custom.MustSave(customFilename, true) 37 | } 38 | return nil 39 | } 40 | -------------------------------------------------------------------------------- /internal/gen/bun/bun_custom.tpl: -------------------------------------------------------------------------------- 1 | package model 2 | 3 | import "context" 4 | 5 | // TODO(sqlgen): Add your own customize code here. 
6 | func (m *{{UpperCamel $.Table.Name}}Model)Customize(ctx context.Context, args...interface{}) { 7 | 8 | } -------------------------------------------------------------------------------- /internal/gen/bun/bun_gen.tpl: -------------------------------------------------------------------------------- 1 | // Code generated by sqlgen. DO NOT EDIT! 2 | 3 | package model 4 | 5 | import ( 6 | "context" 7 | "database/sql" 8 | "fmt" 9 | "time" 10 | 11 | "github.com/uptrace/bun" 12 | "github.com/shopspring/decimal" 13 | ) 14 | 15 | // {{UpperCamel $.Table.Name}}Model represents a {{$.Table.Name}} model. 16 | type {{UpperCamel $.Table.Name}}Model struct { 17 | db bun.IDB 18 | } 19 | 20 | // {{UpperCamel $.Table.Name}} represents a {{$.Table.Name}} struct data. 21 | type {{UpperCamel $.Table.Name}} struct { 22 | bun.BaseModel `bun:"table:{{$.Table.Name}}"`{{range $.Table.Columns}} 23 | {{UpperCamel .Name}} {{.GoType}} `bun:"{{.Name}}{{if IsPrimary .Name}},pk{{end}}{{if .AutoIncrement}},autoincrement{{end}}" json:"{{LowerCamel .Name}}"`{{if .HasComment}}// {{TrimNewLine .Comment}}{{end}}{{end}} 24 | } 25 | {{range $stmt := .SelectStmt}}{{if $stmt.Where.IsValid}}{{$stmt.Where.ParameterStructure "Where"}} 26 | {{end}}{{if $stmt.Having.IsValid}}{{$stmt.Having.ParameterStructure "Having"}} 27 | {{end}}{{if $stmt.Limit.Multiple}}{{$stmt.Limit.ParameterStructure}} 28 | {{end}}{{$stmt.ReceiverStructure "bun"}} 29 | {{end}} 30 | 31 | {{range $stmt := .UpdateStmt}}{{if $stmt.Where.IsValid}}{{$stmt.Where.ParameterStructure "Where"}} 32 | {{end}} 33 | {{end}} 34 | 35 | {{range $stmt := .DeleteStmt}}{{if $stmt.Where.IsValid}}{{$stmt.Where.ParameterStructure "Where"}} 36 | {{end}} 37 | {{end}} 38 | 39 | 40 | // New{{UpperCamel $.Table.Name}}Model creates a new {{$.Table.Name}} model. 
41 | func New{{UpperCamel $.Table.Name}}Model(db bun.IDB) *{{UpperCamel $.Table.Name}}Model { 42 | return &{{UpperCamel $.Table.Name}}Model{ 43 | db: db, 44 | } 45 | } 46 | 47 | // Create creates {{$.Table.Name}} data. 48 | func (m *{{UpperCamel $.Table.Name}}Model) Create(ctx context.Context, data ...*{{UpperCamel $.Table.Name}}) error { 49 | if len(data) == 0 { 50 | return fmt.Errorf("data is empty") 51 | } 52 | 53 | list := data[:] 54 | _,err := m.db.NewInsert().Model(&list).Exec(ctx) 55 | return err 56 | } 57 | {{range $stmt := .SelectStmt}} 58 | // {{.FuncName}} is generated from sql: 59 | // {{LineComment $stmt.SQL}} 60 | func (m *{{UpperCamel $.Table.Name}}Model){{.FuncName}}(ctx context.Context{{if $stmt.Where.IsValid}}, where {{$stmt.Where.ParameterStructureName "Where"}}{{end}}{{if $stmt.Having.IsValid}}, having {{$stmt.Having.ParameterStructureName "Having"}}{{end}}{{if $stmt.Limit.Multiple}}, limit {{$stmt.Limit.ParameterStructureName}}{{end}})({{if $stmt.Limit.One}}*{{$stmt.ReceiverName}}, {{else}}[]*{{$stmt.ReceiverName}}, {{end}} error){ 61 | var result {{if $stmt.Limit.One}} = new({{$stmt.ReceiverName}}){{else}}[]*{{$stmt.ReceiverName}}{{end}} 62 | var db = m.db.NewSelect() 63 | db.Model({{if $stmt.Limit.One}}result{{else}}&result{{end}}) 64 | db.ColumnExpr(`{{$stmt.SelectSQL}}`) 65 | {{if $stmt.Where.IsValid}}db.Where({{$stmt.Where.SQL}}, {{$stmt.Where.Parameters "where"}}) 66 | {{end }}{{if $stmt.GroupBy.IsValid}}db.GroupExpr({{$stmt.GroupBy.SQL}}) 67 | {{end}}{{if $stmt.Having.IsValid}}db.Having({{$stmt.Having.SQL}}, {{$stmt.Having.Parameters "having"}}) 68 | {{end}}{{if $stmt.OrderBy.IsValid}}db.OrderExpr({{$stmt.OrderBy.SQL}}) 69 | {{end}}{{if $stmt.Limit.IsValid}}db{{if gt $stmt.Limit.Offset 0}}.Offset({{$stmt.Limit.OffsetParameter "limit"}}){{end}}.Limit({{if $stmt.Limit.One}}1{{else}}{{$stmt.Limit.LimitParameter "limit"}}{{end}}) 70 | {{end}}err := db.Scan(ctx) 71 | return result, err 72 | } 73 | {{end}} 74 | 75 | {{range $stmt := 
.UpdateStmt}} 76 | // {{.FuncName}} is generated from sql: 77 | // {{LineComment $stmt.SQL}} 78 | func (m *{{UpperCamel $.Table.Name}}Model){{.FuncName}}(ctx context.Context, data *{{UpperCamel $.Table.Name}}{{if $stmt.Where.IsValid}}, where {{$stmt.Where.ParameterStructureName "Where"}}{{end}}) error { 79 | var db = m.db.NewUpdate() 80 | db.Table("{{$.Table.Name}}") 81 | db.Model(&map[string]interface{}{ 82 | {{range $name := $stmt.Columns}}"{{$name}}": data.{{UpperCamel $name}}, 83 | {{end}} 84 | }) 85 | {{if $stmt.Where.IsValid}}db.Where({{$stmt.Where.SQL}}, {{$stmt.Where.Parameters "where"}}) 86 | {{end}}_, err := db.Exec(ctx) 87 | return err 88 | } 89 | {{end}} 90 | 91 | {{range $stmt := .DeleteStmt}} 92 | // {{.FuncName}} is generated from sql: 93 | // {{LineComment $stmt.SQL}} 94 | func (m *{{UpperCamel $.Table.Name}}Model){{.FuncName}}(ctx context.Context{{if $stmt.Where.IsValid}}, where {{$stmt.Where.ParameterStructureName "Where"}}{{end}}) error { 95 | var db = m.db.NewDelete() 96 | db.Model(&{{UpperCamel $.Table.Name}}{}) 97 | {{if $stmt.Where.IsValid}}db.Where({{$stmt.Where.SQL}}, {{$stmt.Where.Parameters "where"}}) 98 | {{end}}_, err := db.Exec(ctx) 99 | return err 100 | } 101 | {{end}} -------------------------------------------------------------------------------- /internal/gen/bun/bun_test.go: -------------------------------------------------------------------------------- 1 | package bun 2 | 3 | import ( 4 | _ "embed" 5 | "testing" 6 | 7 | "github.com/anqiansong/sqlgen/internal/gen/testdata" 8 | "github.com/anqiansong/sqlgen/internal/parser" 9 | "github.com/anqiansong/sqlgen/internal/spec" 10 | "github.com/stretchr/testify/assert" 11 | ) 12 | 13 | func TestRun(t *testing.T) { 14 | dxl, err := parser.Parse(testdata.TestSql) 15 | assert.NoError(t, err) 16 | ctx, err := spec.From(dxl) 17 | assert.NoError(t, err) 18 | err = Run(ctx, t.TempDir()) 19 | assert.NoError(t, err) 20 | } 21 | 
-------------------------------------------------------------------------------- /internal/gen/flags/flags.go: -------------------------------------------------------------------------------- 1 | package flags 2 | 3 | import ( 4 | "fmt" 5 | "io/ioutil" 6 | "path/filepath" 7 | 8 | "github.com/anqiansong/sqlgen/internal/gen/bun" 9 | "github.com/anqiansong/sqlgen/internal/gen/gorm" 10 | "github.com/anqiansong/sqlgen/internal/gen/sql" 11 | "github.com/anqiansong/sqlgen/internal/gen/sqlx" 12 | "github.com/anqiansong/sqlgen/internal/gen/xorm" 13 | "github.com/anqiansong/sqlgen/internal/log" 14 | "github.com/anqiansong/sqlgen/internal/parser" 15 | "github.com/anqiansong/sqlgen/internal/patterns" 16 | "github.com/anqiansong/sqlgen/internal/spec" 17 | ) 18 | 19 | const sqlExt = ".sql" 20 | 21 | type Mode int 22 | 23 | const ( 24 | SQL Mode = iota 25 | GORM 26 | XORM 27 | SQLX 28 | BUN 29 | ) 30 | 31 | type RunArg struct { 32 | DSN string 33 | Filename, Table []string 34 | Mode Mode 35 | Output string 36 | } 37 | 38 | func Run(arg RunArg) { 39 | var err error 40 | if len(arg.DSN) > 0 { 41 | err = runFromDSN(arg) 42 | } else if len(arg.Filename) > 0 { 43 | err = runFromSQL(arg) 44 | } else { 45 | err = fmt.Errorf("missing dsn or filename") 46 | } 47 | log.Must(err) 48 | } 49 | 50 | func runFromSQL(arg RunArg) error { 51 | var list []string 52 | for _, item := range arg.Filename { 53 | filename, err := filepath.Abs(item) 54 | if err != nil { 55 | return err 56 | } 57 | 58 | var dir = filepath.Dir(filename) 59 | var base = filepath.Base(filename) 60 | fileInfo, err := ioutil.ReadDir(dir) 61 | if err != nil { 62 | return err 63 | } 64 | var filenames []string 65 | for _, item := range fileInfo { 66 | if item.IsDir() { 67 | continue 68 | } 69 | 70 | ext := filepath.Ext(item.Name()) 71 | if ext != sqlExt { 72 | continue 73 | } 74 | 75 | var f = filepath.Join(dir, item.Name()) 76 | filenames = append(filenames, f) 77 | } 78 | var p = patterns.New(base) 79 | var matchSQLFile = 
p.Match(filenames...) 80 | 81 | list = append(list, matchSQLFile...) 82 | } 83 | 84 | if len(list) == 0 { 85 | return fmt.Errorf("no sql file found") 86 | } 87 | 88 | var ret spec.DXL 89 | for _, file := range list { 90 | data, err := ioutil.ReadFile(file) 91 | if err != nil { 92 | return err 93 | } 94 | 95 | dxl, err := parser.Parse(string(data)) 96 | if err != nil { 97 | return err 98 | } 99 | 100 | ret.DDL = append(ret.DDL, dxl.DDL...) 101 | ret.DML = append(ret.DML, dxl.DML...) 102 | } 103 | 104 | return run(&ret, arg.Mode, arg.Output) 105 | } 106 | 107 | func runFromDSN(arg RunArg) error { 108 | dxl, err := parser.From(arg.DSN, arg.Table...) 109 | if err != nil { 110 | return err 111 | } 112 | 113 | return run(dxl, arg.Mode, arg.Output) 114 | } 115 | 116 | var funcMap = map[Mode]func(context []spec.Context, output string) error{ 117 | SQL: sql.Run, 118 | GORM: gorm.Run, 119 | XORM: xorm.Run, 120 | SQLX: sqlx.Run, 121 | BUN: bun.Run, 122 | } 123 | // run builds the generation contexts for dxl and passes them to the generator registered for mode; a mode with no registered generator is reported as an error instead of silently generating nothing. 124 | func run(dxl *spec.DXL, mode Mode, output string) error { 125 | ctx, err := spec.From(dxl) 126 | if err != nil { 127 | return err 128 | } 129 | 130 | fn, ok := funcMap[mode] 131 | if !ok { 132 | return fmt.Errorf("unsupported mode: %d", mode) 133 | } 134 | 135 | return fn(ctx, output) 136 | } 137 | -------------------------------------------------------------------------------- /internal/gen/gorm/funcmap.go: -------------------------------------------------------------------------------- 1 | package gorm 2 | 3 | import ( 4 | "text/template" 5 | ) 6 | 7 | var funcMap = template.FuncMap{} 8 | -------------------------------------------------------------------------------- /internal/gen/gorm/gorm.go: -------------------------------------------------------------------------------- 1 | package gorm 2 | 3 | import ( 4 | _ "embed" 5 | "fmt" 6 | "path/filepath" 7 | "text/template" 8 | 9 | "github.com/anqiansong/sqlgen/internal/spec" 10 | "github.com/anqiansong/sqlgen/internal/templatex" 11 | ) 12 | 13 | //go:embed gorm_gen.tpl 14 | var gormGenTpl string 15
| 16 | //go:embed gorm_custom.tpl 17 | var gormCustomTpl string 18 | 19 | func Run(list []spec.Context, output string) error { 20 | for _, ctx := range list { 21 | var genFilename = filepath.Join(output, fmt.Sprintf("%s_model.gen.go", ctx.Table.Name)) 22 | var customFilename = filepath.Join(output, fmt.Sprintf("%s_model.go", ctx.Table.Name)) 23 | gen := templatex.New() 24 | gen.AppendFuncMap(funcMap) 25 | gen.AppendFuncMap(template.FuncMap{ 26 | "IsPrimary": func(name string) bool { 27 | return ctx.Table.IsPrimary(name) 28 | }, 29 | "IsExtraResult": func(name string) bool { 30 | return name != templatex.UpperCamel(ctx.Table.Name) 31 | }, 32 | }) 33 | gen.MustParse(gormGenTpl) 34 | gen.MustExecute(ctx) 35 | gen.MustSaveAs(genFilename, true) 36 | 37 | custom := templatex.New() 38 | custom.MustParse(gormCustomTpl) 39 | custom.MustExecute(ctx) 40 | custom.MustSave(customFilename, true) 41 | } 42 | return nil 43 | } 44 | -------------------------------------------------------------------------------- /internal/gen/gorm/gorm_custom.tpl: -------------------------------------------------------------------------------- 1 | package model 2 | 3 | import "context" 4 | 5 | // TODO(sqlgen): Add your own customize code here. 
6 | func (m *{{UpperCamel $.Table.Name}}Model)Customize(ctx context.Context, args...interface{}) { 7 | 8 | } -------------------------------------------------------------------------------- /internal/gen/gorm/gorm_test.go: -------------------------------------------------------------------------------- 1 | package gorm 2 | 3 | import ( 4 | _ "embed" 5 | "testing" 6 | 7 | "github.com/anqiansong/sqlgen/internal/gen/testdata" 8 | "github.com/anqiansong/sqlgen/internal/parser" 9 | "github.com/anqiansong/sqlgen/internal/spec" 10 | "github.com/stretchr/testify/assert" 11 | ) 12 | 13 | func TestRun(t *testing.T) { 14 | dxl, err := parser.Parse(testdata.TestSql) 15 | assert.NoError(t, err) 16 | ctx, err := spec.From(dxl) 17 | assert.NoError(t, err) 18 | err = Run(ctx, t.TempDir()) 19 | assert.NoError(t, err) 20 | } 21 | -------------------------------------------------------------------------------- /internal/gen/gorm/table.go: -------------------------------------------------------------------------------- 1 | package gorm 2 | -------------------------------------------------------------------------------- /internal/gen/sql/scanner.tpl: -------------------------------------------------------------------------------- 1 | package model 2 | 3 | import "database/sql" 4 | 5 | type Scanner interface { 6 | ScanRow(rows *sql.Rows, v interface{}) error 7 | ScanRows(rows *sql.Rows, v interface{}) error 8 | ColumnMapper(colName string) string 9 | TagKey() string 10 | } 11 | -------------------------------------------------------------------------------- /internal/gen/sql/sql.go: -------------------------------------------------------------------------------- 1 | package sql 2 | 3 | import ( 4 | _ "embed" 5 | "fmt" 6 | "path/filepath" 7 | "strings" 8 | "text/template" 9 | 10 | "github.com/iancoleman/strcase" 11 | 12 | "github.com/anqiansong/sqlgen/internal/spec" 13 | "github.com/anqiansong/sqlgen/internal/templatex" 14 | ) 15 | 16 | //go:embed sql_gen.tpl 17 | var sqlGenTpl string 18 
| 19 | //go:embed sql_custom.tpl 20 | var sqlCustomTpl string 21 | 22 | //go:embed scanner.tpl 23 | var scannerTpl string 24 | 25 | func Run(list []spec.Context, output string) error { 26 | var scannerFilename = filepath.Join(output, "scanner.go") 27 | scanner := templatex.New() 28 | scanner.MustParse(scannerTpl) 29 | scanner.MustExecute(nil) 30 | scanner.MustSave(scannerFilename, true) 31 | 32 | for _, ctx := range list { 33 | var genFilename = filepath.Join(output, fmt.Sprintf("%s_model.gen.go", ctx.Table.Name)) 34 | var customFilename = filepath.Join(output, fmt.Sprintf("%s_model.go", ctx.Table.Name)) 35 | gen := templatex.New() 36 | var insertQuery, insertQuotes []string 37 | for _, v := range ctx.Table.Columns { 38 | if v.AutoIncrement { 39 | continue 40 | } 41 | insertQuery = append(insertQuery, fmt.Sprintf("`%s`", v.Name)) 42 | insertQuotes = append(insertQuotes, "?") 43 | } 44 | gen.AppendFuncMap(template.FuncMap{ 45 | "IsPrimary": func(name string) bool { 46 | return ctx.Table.IsPrimary(name) 47 | }, 48 | "InsertSQL": func() string { 49 | return strings.Join(insertQuery, ", ") 50 | }, 51 | "InsertQuotes": func() string { 52 | return strings.Join(insertQuotes, ", ") 53 | }, 54 | "InsertValues": func(pkg string) string { 55 | var values []string 56 | for _, v := range ctx.Table.Columns { 57 | if v.AutoIncrement { 58 | continue 59 | } 60 | values = append(values, fmt.Sprintf("%s.%s", pkg, strcase.ToCamel(v.Name))) 61 | } 62 | return strings.Join(values, ", ") 63 | }, 64 | "HavingSprintf": func(format string) string { 65 | format = strings.ReplaceAll(format, "?", "'%v'") 66 | return format 67 | }, 68 | }) 69 | gen.MustParse(sqlGenTpl) 70 | gen.MustExecute(ctx) 71 | gen.MustSaveAs(genFilename, true) 72 | 73 | custom := templatex.New() 74 | custom.MustParse(sqlCustomTpl) 75 | custom.MustExecute(ctx) 76 | custom.MustSave(customFilename, true) 77 | } 78 | return nil 79 | } 80 | -------------------------------------------------------------------------------- 
/internal/gen/sql/sql_custom.tpl: -------------------------------------------------------------------------------- 1 | package model 2 | 3 | import "context" 4 | 5 | // TODO(sqlgen): Add your own customize code here. 6 | func (m *{{UpperCamel $.Table.Name}}Model)Customize(ctx context.Context, args...interface{}) { 7 | 8 | } -------------------------------------------------------------------------------- /internal/gen/sql/sql_gen.tpl: -------------------------------------------------------------------------------- 1 | // Code generated by sqlgen. DO NOT EDIT! 2 | 3 | package model 4 | 5 | import ( 6 | "context" 7 | "database/sql" 8 | "fmt" 9 | "time" 10 | 11 | "xorm.io/builder" 12 | "github.com/shopspring/decimal" 13 | ) 14 | 15 | // {{UpperCamel $.Table.Name}}Model represents a {{$.Table.Name}} model. 16 | type {{UpperCamel $.Table.Name}}Model struct { 17 | db *sql.DB 18 | scanner Scanner 19 | } 20 | 21 | // {{UpperCamel $.Table.Name}} represents a {{$.Table.Name}} struct data. 22 | type {{UpperCamel $.Table.Name}} struct { {{range $.Table.Columns}} 23 | {{UpperCamel .Name}} {{.GoType}} `json:"{{LowerCamel .Name}}"`{{if .HasComment}}// {{TrimNewLine .Comment}}{{end}}{{end}} 24 | } 25 | {{range $stmt := .SelectStmt}}{{if $stmt.Where.IsValid}}{{$stmt.Where.ParameterStructure "Where"}} 26 | {{end}}{{if $stmt.Having.IsValid}}{{$stmt.Having.ParameterStructure "Having"}} 27 | {{end}}{{if $stmt.Limit.Multiple}}{{$stmt.Limit.ParameterStructure}} 28 | {{end}}{{$stmt.ReceiverStructure "sql"}} 29 | {{end}} 30 | 31 | {{range $stmt := .UpdateStmt}}{{if $stmt.Where.IsValid}}{{$stmt.Where.ParameterStructure "Where"}} 32 | {{end}}{{if $stmt.Limit.Multiple}}{{$stmt.Limit.ParameterStructure}} 33 | {{end}} 34 | {{end}} 35 | 36 | {{range $stmt := .DeleteStmt}}{{if $stmt.Where.IsValid}}{{$stmt.Where.ParameterStructure "Where"}} 37 | {{end}}{{if $stmt.Limit.Multiple}}{{$stmt.Limit.ParameterStructure}} 38 | {{end}} 39 | {{end}} 40 | 41 | 42 | // New{{UpperCamel $.Table.Name}}Model 
creates a new {{$.Table.Name}} model. 43 | func New{{UpperCamel $.Table.Name}}Model(db *sql.DB, scanner Scanner) *{{UpperCamel $.Table.Name}}Model { 44 | return &{{UpperCamel $.Table.Name}}Model{ 45 | db: db, 46 | scanner: scanner, 47 | } 48 | } 49 | 50 | // Create creates {{$.Table.Name}} data. 51 | func (m *{{UpperCamel $.Table.Name}}Model) Create(ctx context.Context, data ...*{{UpperCamel $.Table.Name}}) error { 52 | if len(data) == 0 { 53 | return fmt.Errorf("data is empty") 54 | } 55 | 56 | var stmt *sql.Stmt 57 | stmt, err := m.db.PrepareContext(ctx, "INSERT INTO {{$.Table.Name}} ({{InsertSQL}}) VALUES ({{InsertQuotes}})") 58 | if err != nil { 59 | return err 60 | } 61 | defer stmt.Close() 62 | for _, v := range data { 63 | result, err := stmt.ExecContext(ctx, {{InsertValues "v"}}) 64 | if err != nil { 65 | return err 66 | } 67 | 68 | id, err := result.LastInsertId() 69 | if err != nil { 70 | return err 71 | } 72 | _ = id // id is only assigned below for an AUTO_INCREMENT primary key; this keeps the generated code compiling for tables without one. 73 | {{range $.Table.Columns}}{{if IsPrimary .Name}}{{if .AutoIncrement}}v.{{UpperCamel .Name}} = {{.GoType}}(id){{end}}{{end}}{{end}} 74 | } 75 | return nil 76 | } 77 | {{range $stmt := .SelectStmt}} 78 | // {{.FuncName}} is generated from sql: 79 | // {{LineComment $stmt.SQL}} 80 | func (m *{{UpperCamel $.Table.Name}}Model){{.FuncName}}(ctx context.Context{{if $stmt.Where.IsValid}}, where {{$stmt.Where.ParameterStructureName "Where"}}{{end}}{{if $stmt.Having.IsValid}}, having {{$stmt.Having.ParameterStructureName "Having"}}{{end}}{{if $stmt.Limit.Multiple}}, limit {{$stmt.Limit.ParameterStructureName}}{{end}})(result {{if $stmt.Limit.One}}*{{$stmt.ReceiverName}}, {{else}}[]*{{$stmt.ReceiverName}}, {{end}} err error){ {{if $stmt.Limit.One}} 81 | result = new({{$stmt.ReceiverName}}){{end}} 82 | b := builder.MySQL() 83 | b.Select(`{{$stmt.SelectSQL}}`) 84 | b.From("`{{$.Table.Name}}`") 85 | {{if $stmt.Where.IsValid}}b.Where(builder.Expr({{$stmt.Where.SQL}}, {{$stmt.Where.Parameters "where"}})) 86 | {{end }}{{if
$stmt.GroupBy.IsValid}}b.GroupBy({{$stmt.GroupBy.SQL}}) 87 | {{end}}{{if $stmt.Having.IsValid}}b.Having(fmt.Sprintf({{HavingSprintf $stmt.Having.SQL}}, {{$stmt.Having.Parameters "having"}})) 88 | {{end}}{{if $stmt.OrderBy.IsValid}}b.OrderBy({{$stmt.OrderBy.SQL}}) 89 | {{end}}{{if $stmt.Limit.IsValid}}b.Limit({{if $stmt.Limit.One}}1{{else}}{{$stmt.Limit.LimitParameter "limit"}}{{end}}{{if gt $stmt.Limit.Offset 0}}, {{$stmt.Limit.OffsetParameter "limit"}}{{end}}) 90 | {{end}}query, args, err := b.ToSQL() 91 | if err != nil { 92 | return nil, err 93 | } 94 | 95 | rows, err := m.db.QueryContext(ctx, query, args...) 96 | if err != nil { 97 | return nil, err 98 | } 99 | defer rows.Close() 100 | 101 | if err = m.scanner. {{if $stmt.Limit.One}}ScanRow{{else}}ScanRows{{end}}(rows, &result); err != nil{ 102 | return nil, err 103 | } 104 | 105 | return result, nil 106 | } 107 | {{end}} 108 | 109 | {{range $stmt := .UpdateStmt}} 110 | // {{.FuncName}} is generated from sql: 111 | // {{LineComment $stmt.SQL}} 112 | func (m *{{UpperCamel $.Table.Name}}Model){{.FuncName}}(ctx context.Context, data *{{UpperCamel $.Table.Name}}{{if $stmt.Where.IsValid}}, where {{$stmt.Where.ParameterStructureName "Where"}}{{end}}{{if $stmt.Limit.Multiple}}, limit {{$stmt.Limit.ParameterStructureName}}{{end}}) error { 113 | b := builder.MySQL() 114 | b.Update(builder.Eq{ 115 | {{range $name := $stmt.Columns}}"{{$name}}": data.{{UpperCamel $name}}, 116 | {{end}} 117 | }) 118 | b.From("`{{$.Table.Name}}`") 119 | {{if $stmt.Where.IsValid}}b.Where(builder.Expr({{$stmt.Where.SQL}}, {{$stmt.Where.Parameters "where"}})) 120 | {{end}}{{if $stmt.OrderBy.IsValid}}b.OrderBy({{$stmt.OrderBy.SQL}}) 121 | {{end}}{{if $stmt.Limit.IsValid}}b.Limit({{if $stmt.Limit.One}}1{{else}}{{$stmt.Limit.LimitParameter "limit"}}{{end}}{{if gt $stmt.Limit.Offset 0}}, {{$stmt.Limit.OffsetParameter "limit"}}{{end}}) 122 | {{end}}query, args, err := b.ToSQL() 123 | if err != nil { 124 | return err 125 | } 126 | _, err = 
m.db.ExecContext(ctx, query, args...) 127 | return err 128 | } 129 | {{end}} 130 | 131 | {{range $stmt := .DeleteStmt}} 132 | // {{.FuncName}} is generated from sql: 133 | // {{LineComment $stmt.SQL}} 134 | func (m *{{UpperCamel $.Table.Name}}Model){{.FuncName}}(ctx context.Context{{if $stmt.Where.IsValid}}, where {{$stmt.Where.ParameterStructureName "Where"}}{{end}}{{if $stmt.Limit.Multiple}}, limit {{$stmt.Limit.ParameterStructureName}}{{end}}) error { 135 | b := builder.MySQL() 136 | b.Delete() 137 | b.From("`{{$.Table.Name}}`") 138 | {{if $stmt.Where.IsValid}}b.Where(builder.Expr({{$stmt.Where.SQL}}, {{$stmt.Where.Parameters "where"}})) 139 | {{end}}{{if $stmt.OrderBy.IsValid}}b.OrderBy({{$stmt.OrderBy.SQL}}) 140 | {{end}}{{if $stmt.Limit.IsValid}}b.Limit({{if $stmt.Limit.One}}1{{else}}{{$stmt.Limit.LimitParameter "limit"}}{{end}}{{if gt $stmt.Limit.Offset 0}}, {{$stmt.Limit.OffsetParameter "limit"}}{{end}}) 141 | {{end}}query, args, err := b.ToSQL() 142 | if err != nil { 143 | return err 144 | } 145 | _, err = m.db.ExecContext(ctx, query, args...) 
146 | return err 147 | } 148 | {{end}} -------------------------------------------------------------------------------- /internal/gen/sql/sql_test.go: -------------------------------------------------------------------------------- 1 | package sql 2 | 3 | import ( 4 | _ "embed" 5 | "testing" 6 | 7 | "github.com/anqiansong/sqlgen/internal/gen/testdata" 8 | "github.com/anqiansong/sqlgen/internal/parser" 9 | "github.com/anqiansong/sqlgen/internal/spec" 10 | "github.com/stretchr/testify/assert" 11 | ) 12 | 13 | func TestRun(t *testing.T) { 14 | dxl, err := parser.Parse(testdata.TestSql) 15 | assert.NoError(t, err) 16 | ctx, err := spec.From(dxl) 17 | assert.NoError(t, err) 18 | err = Run(ctx, t.TempDir()) 19 | assert.NoError(t, err) 20 | } 21 | -------------------------------------------------------------------------------- /internal/gen/sqlx/sqlx.go: -------------------------------------------------------------------------------- 1 | package sqlx 2 | 3 | import ( 4 | _ "embed" 5 | "fmt" 6 | "path/filepath" 7 | "strings" 8 | "text/template" 9 | 10 | "github.com/anqiansong/sqlgen/internal/spec" 11 | "github.com/anqiansong/sqlgen/internal/templatex" 12 | "github.com/iancoleman/strcase" 13 | ) 14 | 15 | //go:embed sqlx_gen.tpl 16 | var sqlxGenTpl string 17 | 18 | //go:embed sqlx_custom.tpl 19 | var sqlxCustomTpl string 20 | 21 | func Run(list []spec.Context, output string) error { 22 | for _, ctx := range list { 23 | var genFilename = filepath.Join(output, fmt.Sprintf("%s_model.gen.go", ctx.Table.Name)) 24 | var customFilename = filepath.Join(output, fmt.Sprintf("%s_model.go", ctx.Table.Name)) 25 | gen := templatex.New() 26 | var insertQuery, insertQuotes []string 27 | for _, v := range ctx.Table.Columns { 28 | if v.AutoIncrement { 29 | continue 30 | } 31 | insertQuery = append(insertQuery, fmt.Sprintf("`%s`", v.Name)) 32 | insertQuotes = append(insertQuotes, "?") 33 | } 34 | gen.AppendFuncMap(template.FuncMap{ 35 | "IsPrimary": func(name string) bool { 36 | return 
ctx.Table.IsPrimary(name) 37 | }, 38 | "InsertSQL": func() string { 39 | return strings.Join(insertQuery, ", ") 40 | }, 41 | "InsertQuotes": func() string { 42 | return strings.Join(insertQuotes, ", ") 43 | }, 44 | "InsertValues": func(pkg string) string { 45 | var values []string 46 | for _, v := range ctx.Table.Columns { 47 | if v.AutoIncrement { 48 | continue 49 | } 50 | values = append(values, fmt.Sprintf("%s.%s", pkg, strcase.ToCamel(v.Name))) 51 | } 52 | return strings.Join(values, ", ") 53 | }, 54 | "HavingSprintf": func(format string) string { 55 | format = strings.ReplaceAll(format, "?", "'%v'") 56 | return format 57 | }, 58 | }) 59 | gen.MustParse(sqlxGenTpl) 60 | gen.MustExecute(ctx) 61 | gen.MustSaveAs(genFilename, true) 62 | 63 | custom := templatex.New() 64 | custom.MustParse(sqlxCustomTpl) 65 | custom.MustExecute(ctx) 66 | custom.MustSave(customFilename, true) 67 | } 68 | return nil 69 | } 70 | -------------------------------------------------------------------------------- /internal/gen/sqlx/sqlx_custom.tpl: -------------------------------------------------------------------------------- 1 | package model 2 | 3 | import "context" 4 | 5 | // TODO(sqlgen): Add your own customize code here. 6 | func (m *{{UpperCamel $.Table.Name}}Model)Customize(ctx context.Context, args...any) { 7 | 8 | } -------------------------------------------------------------------------------- /internal/gen/sqlx/sqlx_gen.tpl: -------------------------------------------------------------------------------- 1 | // Code generated by sqlgen. DO NOT EDIT! 2 | 3 | package model 4 | 5 | import ( 6 | "context" 7 | "database/sql" 8 | "fmt" 9 | "time" 10 | 11 | "xorm.io/builder" 12 | "github.com/jmoiron/sqlx" 13 | "github.com/shopspring/decimal" 14 | ) 15 | 16 | // {{UpperCamel $.Table.Name}}Model represents a {{$.Table.Name}} model.
17 | type {{UpperCamel $.Table.Name}}Model struct { 18 | db *sqlx.DB 19 | } 20 | 21 | // {{UpperCamel $.Table.Name}} represents a {{$.Table.Name}} struct data. 22 | type {{UpperCamel $.Table.Name}} struct { {{range $.Table.Columns}} 23 | {{UpperCamel .Name}} {{.GoType}} `db:"{{.Name}}" json:"{{LowerCamel .Name}}"`{{if .HasComment}}// {{TrimNewLine .Comment}}{{end}}{{end}} 24 | } 25 | {{range $stmt := .SelectStmt}}{{if $stmt.Where.IsValid}}{{$stmt.Where.ParameterStructure "Where"}} 26 | {{end}}{{if $stmt.Having.IsValid}}{{$stmt.Having.ParameterStructure "Having"}} 27 | {{end}}{{if $stmt.Limit.Multiple}}{{$stmt.Limit.ParameterStructure}} 28 | {{end}}{{$stmt.ReceiverStructure "sqlx"}} 29 | {{end}} 30 | 31 | {{range $stmt := .UpdateStmt}}{{if $stmt.Where.IsValid}}{{$stmt.Where.ParameterStructure "Where"}} 32 | {{end}}{{if $stmt.Limit.Multiple}}{{$stmt.Limit.ParameterStructure}} 33 | {{end}} 34 | {{end}} 35 | 36 | {{range $stmt := .DeleteStmt}}{{if $stmt.Where.IsValid}}{{$stmt.Where.ParameterStructure "Where"}} 37 | {{end}}{{if $stmt.Limit.Multiple}}{{$stmt.Limit.ParameterStructure}} 38 | {{end}} 39 | {{end}} 40 | 41 | 42 | // New{{UpperCamel $.Table.Name}}Model creates a new {{$.Table.Name}} model. 43 | func New{{UpperCamel $.Table.Name}}Model(db *sqlx.DB) *{{UpperCamel $.Table.Name}}Model { 44 | return &{{UpperCamel $.Table.Name}}Model{ 45 | db: db, 46 | } 47 | } 48 | 49 | // Create creates {{$.Table.Name}} data.
50 | func (m *{{UpperCamel $.Table.Name}}Model) Create(ctx context.Context, data ...*{{UpperCamel $.Table.Name}}) error { 51 | if len(data) == 0 { 52 | return fmt.Errorf("data is empty") 53 | } 54 | 55 | var stmt *sql.Stmt 56 | stmt, err := m.db.PrepareContext(ctx, "INSERT INTO {{$.Table.Name}} ({{InsertSQL}}) VALUES ({{InsertQuotes}})") 57 | if err != nil { 58 | return err 59 | } 60 | defer stmt.Close() 61 | for _, v := range data { 62 | result, err := stmt.ExecContext(ctx, {{InsertValues "v"}}) 63 | if err != nil { 64 | return err 65 | } 66 | 67 | id, err := result.LastInsertId() 68 | if err != nil { 69 | return err 70 | } 71 | 72 | {{range $.Table.Columns}}{{if IsPrimary .Name}}{{if .AutoIncrement}}v.{{UpperCamel .Name}} = {{.GoType}}(id){{end}}{{end}}{{end}} 73 | } 74 | return nil 75 | } 76 | {{range $stmt := .SelectStmt}} 77 | // {{.FuncName}} is generated from sql: 78 | // {{LineComment $stmt.SQL}} 79 | func (m *{{UpperCamel $.Table.Name}}Model){{.FuncName}}(ctx context.Context{{if $stmt.Where.IsValid}}, where {{$stmt.Where.ParameterStructureName "Where"}}{{end}}{{if $stmt.Having.IsValid}}, having {{$stmt.Having.ParameterStructureName "Having"}}{{end}}{{if $stmt.Limit.Multiple}}, limit {{$stmt.Limit.ParameterStructureName}}{{end}})(result {{if $stmt.Limit.One}}*{{$stmt.ReceiverName}}, {{else}}[]*{{$stmt.ReceiverName}}, {{end}} err error){ {{if $stmt.Limit.One}} 80 | result = new({{$stmt.ReceiverName}}){{end}} 81 | b := builder.MySQL() 82 | b.Select(`{{$stmt.SelectSQL}}`) 83 | b.From("`{{$.Table.Name}}`") 84 | {{if $stmt.Where.IsValid}}b.Where(builder.Expr({{$stmt.Where.SQL}}, {{$stmt.Where.Parameters "where"}})) 85 | {{end }}{{if $stmt.GroupBy.IsValid}}b.GroupBy({{$stmt.GroupBy.SQL}}) 86 | {{end}}{{if $stmt.Having.IsValid}}b.Having(fmt.Sprintf({{HavingSprintf $stmt.Having.SQL}}, {{$stmt.Having.Parameters "having"}})) 87 | {{end}}{{if $stmt.OrderBy.IsValid}}b.OrderBy({{$stmt.OrderBy.SQL}}) 88 | {{end}}{{if $stmt.Limit.IsValid}}b.Limit({{if
$stmt.Limit.One}}1{{else}}{{$stmt.Limit.LimitParameter "limit"}}{{end}}{{if gt $stmt.Limit.Offset 0}}, {{$stmt.Limit.OffsetParameter "limit"}}{{end}}) 89 | {{end}}query, args, err := b.ToSQL() 90 | if err != nil { 91 | return nil, err 92 | } 93 | 94 | var rows *sqlx.Rows 95 | rows, err = m.db.QueryxContext(ctx, query, args...) 96 | if err != nil { 97 | return nil, err 98 | } 99 | defer rows.Close() 100 | 101 | {{if $stmt.Limit.One}}if !rows.Next() { 102 | if err = rows.Err(); err != nil { 103 | return nil, err 104 | } 105 | return nil, sql.ErrNoRows 106 | } 107 | if err = rows.StructScan(result); err != nil { 108 | return nil, err 109 | } 110 | return result, nil 111 | 112 | {{else}} 113 | for rows.Next() { 114 | var v {{$stmt.ReceiverName}} 115 | err = rows.StructScan(&v) 116 | if err != nil { 117 | return nil, err 118 | } 119 | result = append(result, &v) 120 | } 121 | 122 | return result, rows.Err() 123 | {{end}} 124 | } 125 | {{end}} 126 | 127 | {{range $stmt := .UpdateStmt}} 128 | // {{.FuncName}} is generated from sql: 129 | // {{LineComment $stmt.SQL}} 130 | func (m *{{UpperCamel $.Table.Name}}Model){{.FuncName}}(ctx context.Context, data *{{UpperCamel $.Table.Name}}{{if $stmt.Where.IsValid}}, where {{$stmt.Where.ParameterStructureName "Where"}}{{end}}{{if $stmt.Limit.Multiple}}, limit {{$stmt.Limit.ParameterStructureName}}{{end}}) error { 131 | b := builder.MySQL() 132 | b.Update(builder.Eq{ 133 | {{range $name := $stmt.Columns}}"{{$name}}": data.{{UpperCamel $name}}, 134 | {{end}} 135 | }) 136 | b.From("`{{$.Table.Name}}`") 137 | {{if $stmt.Where.IsValid}}b.Where(builder.Expr({{$stmt.Where.SQL}}, {{$stmt.Where.Parameters "where"}})) 138 | {{end}}{{if $stmt.OrderBy.IsValid}}b.OrderBy({{$stmt.OrderBy.SQL}}) 139 | {{end}}{{if $stmt.Limit.IsValid}}b.Limit({{if $stmt.Limit.One}}1{{else}}{{$stmt.Limit.LimitParameter "limit"}}{{end}}{{if gt $stmt.Limit.Offset 0}}, {{$stmt.Limit.OffsetParameter "limit"}}{{end}}) 140 | {{end}}query, args, err := b.ToSQL() 141 | if err != nil { 142 | return err 143 | } 144 | _, err =
m.db.ExecContext(ctx, query, args...) 145 | return err 146 | } 147 | {{end}} 148 | 149 | {{range $stmt := .DeleteStmt}} 150 | // {{.FuncName}} is generated from sql: 151 | // {{LineComment $stmt.SQL}} 152 | func (m *{{UpperCamel $.Table.Name}}Model){{.FuncName}}(ctx context.Context{{if $stmt.Where.IsValid}}, where {{$stmt.Where.ParameterStructureName "Where"}}{{end}}{{if $stmt.Limit.Multiple}}, limit {{$stmt.Limit.ParameterStructureName}}{{end}}) error { 153 | b := builder.MySQL() 154 | b.Delete() 155 | b.From("`{{$.Table.Name}}`") 156 | {{if $stmt.Where.IsValid}}b.Where(builder.Expr({{$stmt.Where.SQL}}, {{$stmt.Where.Parameters "where"}})) 157 | {{end}}{{if $stmt.OrderBy.IsValid}}b.OrderBy({{$stmt.OrderBy.SQL}}) 158 | {{end}}{{if $stmt.Limit.IsValid}}b.Limit({{if $stmt.Limit.One}}1{{else}}{{$stmt.Limit.LimitParameter "limit"}}{{end}}{{if gt $stmt.Limit.Offset 0}}, {{$stmt.Limit.OffsetParameter "limit"}}{{end}}) 159 | {{end}}query, args, err := b.ToSQL() 160 | if err != nil { 161 | return err 162 | } 163 | _, err = m.db.ExecContext(ctx, query, args...)
164 | return err 165 | } 166 | {{end}} -------------------------------------------------------------------------------- /internal/gen/sqlx/sqlx_test.go: -------------------------------------------------------------------------------- 1 | package sqlx 2 | 3 | import ( 4 | _ "embed" 5 | "testing" 6 | 7 | "github.com/anqiansong/sqlgen/internal/gen/testdata" 8 | "github.com/anqiansong/sqlgen/internal/parser" 9 | "github.com/anqiansong/sqlgen/internal/spec" 10 | "github.com/stretchr/testify/assert" 11 | ) 12 | 13 | func TestRun(t *testing.T) { 14 | dxl, err := parser.Parse(testdata.TestSql) 15 | assert.NoError(t, err) 16 | ctx, err := spec.From(dxl) 17 | assert.NoError(t, err) 18 | err = Run(ctx, t.TempDir()) 19 | assert.NoError(t, err) 20 | } 21 | -------------------------------------------------------------------------------- /internal/gen/testdata/test.sql: -------------------------------------------------------------------------------- 1 | CREATE TABLE `foo` 2 | ( 3 | `id` bigint(10) unsigned NOT NULL AUTO_INCREMENT primary key, 4 | `name` varchar(255) COLLATE utf8mb4_general_ci NULL, 5 | UNIQUE KEY `name_index` (`name`) 6 | ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_general_ci; 7 | 8 | -- fn: Count 9 | select count(id) AS count from foo; 10 | 11 | -- fn: FindOne 12 | select name, count(id) AS c from foo where id > ? having c > ?
limit 1; -------------------------------------------------------------------------------- /internal/gen/testdata/testdata.go: -------------------------------------------------------------------------------- 1 | package testdata 2 | 3 | import _ "embed" 4 | // TestSql holds the contents of test.sql, embedded at build time via go:embed. 5 | //go:embed test.sql 6 | var TestSql string 7 | -------------------------------------------------------------------------------- /internal/gen/xorm/xorm.go: -------------------------------------------------------------------------------- 1 | package xorm 2 | 3 | import ( 4 | _ "embed" 5 | "fmt" 6 | "path/filepath" 7 | "strings" 8 | "text/template" 9 | 10 | "github.com/anqiansong/sqlgen/internal/spec" 11 | "github.com/anqiansong/sqlgen/internal/templatex" 12 | ) 13 | 14 | //go:embed xorm_gen.tpl 15 | var xormGenTpl string 16 | 17 | //go:embed xorm_custom.tpl 18 | var xormCustomTpl string 19 | // Run renders, for each context in list, a generated model file (<table>_model.gen.go) and a custom model file (<table>_model.go) into the output directory. It currently always returns nil; failures surface through the templatex Must* helpers (presumably by panicking — TODO confirm). 20 | func Run(list []spec.Context, output string) error { 21 | for _, ctx := range list { 22 | var genFilename = filepath.Join(output, fmt.Sprintf("%s_model.gen.go", ctx.Table.Name)) 23 | var customFilename = filepath.Join(output, fmt.Sprintf("%s_model.go", ctx.Table.Name)) 24 | gen := templatex.New() 25 | gen.AppendFuncMap(template.FuncMap{ 26 | "IsPrimary": func(name string) bool { 27 | return ctx.Table.IsPrimary(name) 28 | }, 29 | "HavingSprintf": func(format string) string { // NOTE(review): the generated code fills these '%v' placeholders with fmt.Sprintf, so HAVING values are inlined into the SQL text instead of being bound as parameters; confirm the inputs are trusted. 30 | format = strings.ReplaceAll(format, "?", "'%v'") 31 | return format 32 | }, 33 | "IsExtraResult": func(name string) bool { 34 | return name != templatex.UpperCamel(ctx.Table.Name) 35 | }, 36 | }) 37 | gen.MustParse(xormGenTpl) 38 | gen.MustExecute(ctx) 39 | gen.MustSaveAs(genFilename, true) 40 | 41 | custom := templatex.New() 42 | custom.MustParse(xormCustomTpl) 43 | custom.MustExecute(ctx) 44 | custom.MustSave(customFilename, true) 45 | } 46 | 47 | return nil 48 | } 49 | -------------------------------------------------------------------------------- /internal/gen/xorm/xorm_custom.tpl:
-------------------------------------------------------------------------------- 1 | package model 2 | 3 | import "context" 4 | 5 | // TODO(sqlgen): Add your own customize code here. 6 | func (m *{{UpperCamel $.Table.Name}}Model)Customize(ctx context.Context, args...interface{}) { 7 | 8 | } -------------------------------------------------------------------------------- /internal/gen/xorm/xorm_gen.tpl: -------------------------------------------------------------------------------- 1 | // Code generated by sqlgen. DO NOT EDIT! 2 | 3 | package model 4 | 5 | import ( 6 | "context" 7 | "database/sql" 8 | "fmt" 9 | "time" 10 | 11 | "github.com/shopspring/decimal" 12 | "xorm.io/xorm" 13 | ) 14 | 15 | // {{UpperCamel $.Table.Name}}Model represents a {{$.Table.Name}} model. 16 | type {{UpperCamel $.Table.Name}}Model struct { 17 | engine xorm.EngineInterface 18 | } 19 | 20 | // {{UpperCamel $.Table.Name}} represents a {{$.Table.Name}} struct data. 21 | type {{UpperCamel $.Table.Name}} struct { {{range $.Table.Columns}} 22 | {{UpperCamel .Name}} {{.GoType}} `xorm:"{{if IsPrimary .Name}}pk {{end}}{{if .AutoIncrement}}autoincr {{end}}'{{.Name}}'" json:"{{LowerCamel .Name}}"`{{if .HasComment}}// {{TrimNewLine .Comment}}{{end}}{{end}} 23 | } 24 | 25 | {{range $stmt := .SelectStmt}}{{if $stmt.Where.IsValid}}{{$stmt.Where.ParameterStructure "Where"}} 26 | {{end}}{{if $stmt.Having.IsValid}}{{$stmt.Having.ParameterStructure "Having"}} 27 | {{end}}{{if $stmt.Limit.Multiple}}{{$stmt.Limit.ParameterStructure}} 28 | {{end}} 29 | 30 | {{if IsExtraResult $stmt.ReceiverName}} 31 | {{$stmt.ReceiverStructure "xorm"}} 32 | 33 | // TableName returns the table name; xorm calls it to resolve the table for this struct.
34 | func ({{$stmt.ReceiverName}}) TableName() string { 35 | return "{{$.Table.Name}}" 36 | }{{end}} 37 | {{end}} 38 | 39 | {{range $stmt := .UpdateStmt}}{{if $stmt.Where.IsValid}}{{$stmt.Where.ParameterStructure "Where"}} 40 | {{end}}{{if $stmt.Limit.Multiple}}{{$stmt.Limit.ParameterStructure}} 41 | {{end}} 42 | {{end}} 43 | 44 | {{range $stmt := .DeleteStmt}}{{if $stmt.Where.IsValid}}{{$stmt.Where.ParameterStructure "Where"}} 45 | {{end}}{{if $stmt.Limit.Multiple}}{{$stmt.Limit.ParameterStructure}} 46 | {{end}} 47 | {{end}} 48 | // TableName returns the table name used by xorm for {{UpperCamel $.Table.Name}}. 49 | func ({{UpperCamel $.Table.Name}}) TableName() string{ 50 | return "{{$.Table.Name}}" 51 | } 52 | 53 | // New{{UpperCamel $.Table.Name}}Model returns a new {{$.Table.Name}} model. 54 | func New{{UpperCamel $.Table.Name}}Model (engine xorm.EngineInterface) *{{UpperCamel $.Table.Name}}Model { 55 | return &{{UpperCamel $.Table.Name}}Model{engine: engine} 56 | } 57 | 58 | // Create creates {{$.Table.Name}} data. 59 | func (m *{{UpperCamel $.Table.Name}}Model) Create(ctx context.Context, data ...*{{UpperCamel $.Table.Name}}) error { 60 | if len(data)==0{ 61 | return fmt.Errorf("data is empty") 62 | } 63 | 64 | var session = m.engine.Context(ctx) 65 | var list []interface{} 66 | for _, v := range data { 67 | list = append(list, v) 68 | } 69 | 70 | _,err := session.Insert(list...)
71 | return err 72 | } 73 | {{range $stmt := .SelectStmt}} 74 | // {{.FuncName}} is generated from sql: 75 | // {{LineComment $stmt.SQL}} 76 | func (m *{{UpperCamel $.Table.Name}}Model){{.FuncName}}(ctx context.Context{{if $stmt.Where.IsValid}}, where {{$stmt.Where.ParameterStructureName "Where"}}{{end}}{{if $stmt.Having.IsValid}}, having {{$stmt.Having.ParameterStructureName "Having"}}{{end}}{{if $stmt.Limit.Multiple}}, limit {{$stmt.Limit.ParameterStructureName}}{{end}})({{if $stmt.Limit.One}}*{{$stmt.ReceiverName}}, {{else}}[]*{{$stmt.ReceiverName}}, {{end}} error){ 77 | var result {{if $stmt.Limit.One}} = new({{$stmt.ReceiverName}}){{else}}[]*{{$stmt.ReceiverName}}{{end}} 78 | var session = m.engine.Context(ctx) 79 | session.Select(`{{$stmt.SelectSQL}}`) 80 | {{if $stmt.Where.IsValid}}session.Where({{$stmt.Where.SQL}}, {{$stmt.Where.Parameters "where"}}) 81 | {{end }}{{if $stmt.GroupBy.IsValid}}session.GroupBy({{$stmt.GroupBy.SQL}}) 82 | {{end}}{{if $stmt.Having.IsValid}}session.Having(fmt.Sprintf({{HavingSprintf $stmt.Having.SQL}}, {{$stmt.Having.Parameters "having"}})) 83 | {{end}}{{if $stmt.OrderBy.IsValid}}session.OrderBy({{$stmt.OrderBy.SQL}}) 84 | {{end}}{{if $stmt.Limit.IsValid}}session.Limit({{if $stmt.Limit.One}}1{{else}}{{$stmt.Limit.LimitParameter "limit"}}{{end}}{{if gt $stmt.Limit.Offset 0}}, {{$stmt.Limit.OffsetParameter "limit"}}{{end}}) 85 | {{end}}{{if $stmt.Limit.One}}has, err := session.Get(result) 86 | if err == nil && !has { 87 | return nil, sql.ErrNoRows 88 | } 89 | {{else}}err :=session.Find(&result){{end}} 90 | return result, err 91 | } 92 | {{end}} 93 | 94 | {{range $stmt := .UpdateStmt}} 95 | // {{.FuncName}} is generated from sql: 96 | // {{LineComment $stmt.SQL}} 97 | func (m *{{UpperCamel $.Table.Name}}Model){{.FuncName}}(ctx context.Context, data *{{UpperCamel $.Table.Name}}{{if $stmt.Where.IsValid}}, where {{$stmt.Where.ParameterStructureName "Where"}}{{end}}{{if $stmt.Limit.Multiple}}, limit {{$stmt.Limit.ParameterStructureName}}{{end}}) error
{ 98 | var session = m.engine.Context(ctx) 99 | session.Table(&{{UpperCamel $.Table.Name}}{}) 100 | {{if $stmt.Where.IsValid}}session.Where({{$stmt.Where.SQL}}, {{$stmt.Where.Parameters "where"}}) 101 | {{end}}{{if $stmt.OrderBy.IsValid}}session.OrderBy({{$stmt.OrderBy.SQL}}) 102 | {{end}}{{if $stmt.Limit.IsValid}}session.Limit({{if $stmt.Limit.One}}1{{else}}{{$stmt.Limit.LimitParameter "limit"}}{{end}}{{if gt $stmt.Limit.Offset 0}}, {{$stmt.Limit.OffsetParameter "limit"}}{{end}}) 103 | {{end}}_, err := session.Update(map[string]interface{}{ 104 | {{range $name := $stmt.Columns}}"{{$name}}": data.{{UpperCamel $name}}, 105 | {{end}} 106 | }) 107 | return err 108 | } 109 | {{end}} 110 | 111 | {{range $stmt := .DeleteStmt}} 112 | // {{.FuncName}} is generated from sql: 113 | // {{LineComment $stmt.SQL}} 114 | func (m *{{UpperCamel $.Table.Name}}Model){{.FuncName}}(ctx context.Context{{if $stmt.Where.IsValid}}, where {{$stmt.Where.ParameterStructureName "Where"}}{{end}}{{if $stmt.Limit.Multiple}}, limit {{$stmt.Limit.ParameterStructureName}}{{end}}) error { 115 | var session = m.engine.Context(ctx) 116 | {{if $stmt.Where.IsValid}}session.Where({{$stmt.Where.SQL}}, {{$stmt.Where.Parameters "where"}}) 117 | {{end}}{{if $stmt.OrderBy.IsValid}}session.OrderBy({{$stmt.OrderBy.SQL}}) 118 | {{end}}{{if $stmt.Limit.IsValid}}session.Limit({{if $stmt.Limit.One}}1{{else}}{{$stmt.Limit.LimitParameter "limit"}}{{end}}{{if gt $stmt.Limit.Offset 0}}, {{$stmt.Limit.OffsetParameter "limit"}}{{end}}) 119 | {{end}}_, err := session.Delete(&{{UpperCamel $.Table.Name}}{}) 120 | return err 121 | } 122 | {{end}} 123 | -------------------------------------------------------------------------------- /internal/gen/xorm/xorm_test.go: -------------------------------------------------------------------------------- 1 | package xorm 2 | 3 | import ( 4 | _ "embed" 5 | "testing" 6 | 7 | "github.com/anqiansong/sqlgen/internal/gen/testdata" 8 | "github.com/anqiansong/sqlgen/internal/parser" 9 |
"github.com/anqiansong/sqlgen/internal/spec" 10 | "github.com/stretchr/testify/assert" 11 | ) 12 | 13 | func TestRun(t *testing.T) { 14 | dxl, err := parser.Parse(testdata.TestSql) 15 | assert.NoError(t, err) 16 | ctx, err := spec.From(dxl) 17 | assert.NoError(t, err) 18 | err = Run(ctx, t.TempDir()) 19 | assert.NoError(t, err) 20 | } 21 | -------------------------------------------------------------------------------- /internal/infoschema/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 zeromicro 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /internal/infoschema/infoschemamodel.go: -------------------------------------------------------------------------------- 1 | // zeromicro copyright,do not edit.
2 | 3 | // MIT License 4 | // 5 | //Copyright (c) 2022 zeromicro 6 | // 7 | //Permission is hereby granted, free of charge, to any person obtaining a copy 8 | //of this software and associated documentation files (the "Software"), to deal 9 | //in the Software without restriction, including without limitation the rights 10 | //to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 11 | //copies of the Software, and to permit persons to whom the Software is 12 | //furnished to do so, subject to the following conditions: 13 | // 14 | //The above copyright notice and this permission notice shall be included in all 15 | //copies or substantial portions of the Software. 16 | // 17 | //THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 18 | //IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 19 | //FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 20 | //AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 21 | //LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 22 | //OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 23 | //SOFTWARE. 24 | 25 | // Filepath: go-zero/tools/goctl/model/sql/model/infoschemamodel.go 26 | 27 | package infoschema 28 | 29 | import ( 30 | "sort" 31 | 32 | "github.com/zeromicro/go-zero/core/stores/sqlx" 33 | ) 34 | 35 | var _ IInformationSchema = (*InformationSchemaModel)(nil) 36 | 37 | type ( 38 | // IInformationSchema defines an interface for schema. 39 | // Just for mock.
40 | IInformationSchema interface { 41 | GetAllTables(database string) ([]string, error) 42 | FindColumns(db, table string) (*Table, error) 43 | FindIndex(db, table, column string) ([]*DbIndex, error) 44 | } 45 | 46 | // InformationSchemaModel defines information schema model 47 | InformationSchemaModel struct { 48 | conn sqlx.SqlConn 49 | } 50 | 51 | // Column defines column in table 52 | Column struct { 53 | *DbColumn 54 | Index *DbIndex 55 | } 56 | 57 | // DbColumn defines column info of columns 58 | DbColumn struct { 59 | Name string `db:"COLUMN_NAME"` 60 | DataType string `db:"DATA_TYPE"` 61 | ColumnType string `db:"COLUMN_TYPE"` 62 | Extra string `db:"EXTRA"` 63 | Comment string `db:"COLUMN_COMMENT"` 64 | ColumnDefault interface{} `db:"COLUMN_DEFAULT"` 65 | IsNullAble string `db:"IS_NULLABLE"` 66 | OrdinalPosition int `db:"ORDINAL_POSITION"` 67 | } 68 | 69 | // DbIndex describes one index entry of a column, read from information_schema.STATISTICS 70 | DbIndex struct { 71 | IndexName string `db:"INDEX_NAME"` 72 | NonUnique int `db:"NON_UNIQUE"` 73 | SeqInIndex int `db:"SEQ_IN_INDEX"` 74 | } 75 | 76 | // Table defines table data 77 | Table struct { 78 | Db string 79 | Table string 80 | Columns []*Column 81 | } 82 | ) 83 | 84 | // NewInformationSchemaModel creates an instance for InformationSchemaModel 85 | func NewInformationSchemaModel(conn sqlx.SqlConn) IInformationSchema { 86 | return &InformationSchemaModel{conn: conn} 87 | } 88 | 89 | // GetAllTables returns the names of all tables in the given database (TABLE_SCHEMA) 90 | func (m *InformationSchemaModel) GetAllTables(database string) ([]string, error) { 91 | query := `select TABLE_NAME from TABLES where TABLE_SCHEMA = ?` 92 | var tables []string 93 | err := m.conn.QueryRows(&tables, query, database) 94 | if err != nil { 95 | return nil, err 96 | } 97 | 98 | return tables, nil 99 | } 100 | 101 | // FindColumns returns the columns of the specified database and table, ordered by ORDINAL_POSITION; a column that belongs to several indexes appears once per index 102 | func (m *InformationSchemaModel) FindColumns(db, table string) (*Table, error) { 103 | querySql :=
`SELECT c.COLUMN_NAME,c.DATA_TYPE,c.COLUMN_TYPE,EXTRA,c.COLUMN_COMMENT,c.COLUMN_DEFAULT,c.IS_NULLABLE,c.ORDINAL_POSITION from COLUMNS c WHERE c.TABLE_SCHEMA = ? and c.TABLE_NAME = ?` 104 | var reply []*DbColumn 105 | err := m.conn.QueryRowsPartial(&reply, querySql, db, table) 106 | if err != nil { 107 | return nil, err 108 | } 109 | 110 | var list []*Column 111 | for _, item := range reply { 112 | index, err := m.FindIndex(db, table, item.Name) 113 | if err != nil { 114 | if err != sqlx.ErrNotFound { 115 | return nil, err 116 | } 117 | continue // NOTE(review): a column whose index lookup fails with sqlx.ErrNotFound is omitted from the result entirely; confirm this is intended. 118 | } 119 | 120 | if len(index) > 0 { 121 | for _, i := range index { 122 | list = append(list, &Column{ 123 | DbColumn: item, 124 | Index: i, 125 | }) 126 | } 127 | } else { 128 | list = append(list, &Column{ 129 | DbColumn: item, 130 | }) 131 | } 132 | } 133 | 134 | sort.Slice(list, func(i, j int) bool { 135 | return list[i].OrdinalPosition < list[j].OrdinalPosition 136 | }) 137 | 138 | var ret Table 139 | ret.Db = db 140 | ret.Table = table 141 | ret.Columns = list 142 | return &ret, nil 143 | } 144 | 145 | // FindIndex finds index with given db, table and column. 146 | func (m *InformationSchemaModel) FindIndex(db, table, column string) ([]*DbIndex, error) { 147 | querySql := `SELECT s.INDEX_NAME,s.NON_UNIQUE,s.SEQ_IN_INDEX from STATISTICS s WHERE s.TABLE_SCHEMA = ? and s.TABLE_NAME = ?
and s.COLUMN_NAME = ?` 148 | var reply []*DbIndex 149 | err := m.conn.QueryRowsPartial(&reply, querySql, db, table, column) 150 | if err != nil { 151 | return nil, err 152 | } 153 | 154 | return reply, nil 155 | } 156 | -------------------------------------------------------------------------------- /internal/infoschema/infoschemamodel_test.go: -------------------------------------------------------------------------------- 1 | package infoschema 2 | 3 | import ( 4 | "errors" 5 | "testing" 6 | 7 | "github.com/DATA-DOG/go-sqlmock" 8 | "github.com/stretchr/testify/assert" 9 | "github.com/zeromicro/go-zero/core/logx" 10 | "github.com/zeromicro/go-zero/core/stores/sqlx" 11 | ) 12 | 13 | var dummyError = errors.New("dummy") 14 | 15 | func TestNewInformationSchemaModel(t *testing.T) { 16 | conn := sqlx.NewMysql("foo") 17 | instance := NewInformationSchemaModel(conn) 18 | assert.NotNil(t, instance) 19 | } 20 | 21 | func TestInformationSchemaModel_GetAllTables(t *testing.T) { 22 | logx.Disable() 23 | var query = `select TABLE_NAME from TABLES where TABLE_SCHEMA = ?` 24 | var database = "foo" 25 | var mockTableNames = []string{"foo"} 26 | 27 | db, mock, err := sqlmock.New() 28 | assert.NoError(t, err) 29 | mock.ExpectQuery(query).WithArgs(database).WillReturnRows(sqlmock.NewRows(mockTableNames)) 30 | 31 | conn := sqlx.NewSqlConnFromDB(db) 32 | model := NewInformationSchemaModel(conn) 33 | _, err = model.GetAllTables(database) 34 | assert.NoError(t, err) 35 | 36 | mock.ExpectQuery(query).WithArgs(database).WillReturnError(dummyError) 37 | _, err = model.GetAllTables(database) 38 | assert.ErrorIs(t, err, dummyError) 39 | } 40 | 41 | func TestInformationSchemaModel_FindColumns(t *testing.T) { 42 | logx.Disable() 43 | var indexQuery = `SELECT s.INDEX_NAME,s.NON_UNIQUE,s.SEQ_IN_INDEX from STATISTICS s WHERE s.TABLE_SCHEMA = ? and s.TABLE_NAME = ?
and s.COLUMN_NAME = ?` 44 | var query = `SELECT c.COLUMN_NAME,c.DATA_TYPE,c.COLUMN_TYPE,EXTRA,c.COLUMN_COMMENT,c.COLUMN_DEFAULT,c.IS_NULLABLE,c.ORDINAL_POSITION from COLUMNS c WHERE c.TABLE_SCHEMA = ? and c.TABLE_NAME = ?` 45 | var database = "foo" 46 | var table = "bar" 47 | var column = "baz" 48 | var mockTableNames = []string{"foo"} 49 | 50 | db, mock, err := sqlmock.New(sqlmock.QueryMatcherOption(sqlmock.QueryMatcherEqual)) 51 | assert.NoError(t, err) 52 | mock.ExpectQuery(query).WithArgs(database, table).WillReturnRows(sqlmock.NewRows(mockTableNames)) 53 | mock.ExpectQuery(indexQuery).WithArgs(database, table, column).WillReturnRows(sqlmock.NewRows(mockTableNames)) 54 | 55 | conn := sqlx.NewSqlConnFromDB(db) 56 | model := NewInformationSchemaModel(conn) 57 | _, err = model.FindColumns(database, table) 58 | assert.NoError(t, err) 59 | 60 | db, mock, err = sqlmock.New(sqlmock.QueryMatcherOption(sqlmock.QueryMatcherEqual)) 61 | assert.NoError(t, err) 62 | conn = sqlx.NewSqlConnFromDB(db) 63 | model = NewInformationSchemaModel(conn) 64 | mock.ExpectQuery(query).WithArgs(database, table).WillReturnError(dummyError) 65 | _, err = model.FindColumns(database, table) 66 | assert.ErrorIs(t, err, dummyError) 67 | } 68 | 69 | func TestInformationSchemaModel_FindIndex(t *testing.T) { 70 | logx.Disable() 71 | var indexQuery = `SELECT s.INDEX_NAME,s.NON_UNIQUE,s.SEQ_IN_INDEX from STATISTICS s WHERE s.TABLE_SCHEMA = ? and s.TABLE_NAME = ?
and s.COLUMN_NAME = ?` 72 | var database = "foo" 73 | var table = "bar" 74 | var column = "baz" 75 | var mockTableNames = []string{"foo"} 76 | 77 | db, mock, err := sqlmock.New(sqlmock.QueryMatcherOption(sqlmock.QueryMatcherEqual)) 78 | assert.NoError(t, err) 79 | mock.ExpectQuery(indexQuery).WithArgs(database, table, column).WillReturnRows(sqlmock.NewRows(mockTableNames)) 80 | 81 | conn := sqlx.NewSqlConnFromDB(db) 82 | model := NewInformationSchemaModel(conn) 83 | _, err = model.FindIndex(database, table, column) 84 | assert.NoError(t, err) 85 | 86 | mock.ExpectQuery(indexQuery).WithArgs(database, table, column).WillReturnError(dummyError) 87 | _, err = model.FindIndex(database, table, column) 88 | assert.ErrorIs(t, err, dummyError) 89 | } 90 | -------------------------------------------------------------------------------- /internal/infoschema/mock_infoschemamodel.go: -------------------------------------------------------------------------------- 1 | // Code generated by MockGen. DO NOT EDIT. 2 | // Source: infoschemamodel.go 3 | 4 | // Package infoschema is a generated GoMock package. 5 | package infoschema 6 | 7 | import ( 8 | reflect "reflect" 9 | 10 | gomock "github.com/golang/mock/gomock" 11 | ) 12 | 13 | // MockIInformationSchema is a mock of IInformationSchema interface. 14 | type MockIInformationSchema struct { 15 | ctrl *gomock.Controller 16 | recorder *MockIInformationSchemaMockRecorder 17 | } 18 | 19 | // MockIInformationSchemaMockRecorder is the mock recorder for MockIInformationSchema. 20 | type MockIInformationSchemaMockRecorder struct { 21 | mock *MockIInformationSchema 22 | } 23 | 24 | // NewMockIInformationSchema creates a new mock instance.
25 | func NewMockIInformationSchema(ctrl *gomock.Controller) *MockIInformationSchema { 26 | mock := &MockIInformationSchema{ctrl: ctrl} 27 | mock.recorder = &MockIInformationSchemaMockRecorder{mock} 28 | return mock 29 | } 30 | 31 | // EXPECT returns an object that allows the caller to indicate expected use. 32 | func (m *MockIInformationSchema) EXPECT() *MockIInformationSchemaMockRecorder { 33 | return m.recorder 34 | } 35 | 36 | // FindColumns mocks base method. 37 | func (m *MockIInformationSchema) FindColumns(db, table string) (*Table, error) { 38 | m.ctrl.T.Helper() 39 | ret := m.ctrl.Call(m, "FindColumns", db, table) 40 | ret0, _ := ret[0].(*Table) 41 | ret1, _ := ret[1].(error) 42 | return ret0, ret1 43 | } 44 | 45 | // FindColumns indicates an expected call of FindColumns. 46 | func (mr *MockIInformationSchemaMockRecorder) FindColumns(db, table interface{}) *gomock.Call { 47 | mr.mock.ctrl.T.Helper() 48 | return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FindColumns", reflect.TypeOf((*MockIInformationSchema)(nil).FindColumns), db, table) 49 | } 50 | 51 | // FindIndex mocks base method. 52 | func (m *MockIInformationSchema) FindIndex(db, table, column string) ([]*DbIndex, error) { 53 | m.ctrl.T.Helper() 54 | ret := m.ctrl.Call(m, "FindIndex", db, table, column) 55 | ret0, _ := ret[0].([]*DbIndex) 56 | ret1, _ := ret[1].(error) 57 | return ret0, ret1 58 | } 59 | 60 | // FindIndex indicates an expected call of FindIndex. 61 | func (mr *MockIInformationSchemaMockRecorder) FindIndex(db, table, column interface{}) *gomock.Call { 62 | mr.mock.ctrl.T.Helper() 63 | return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FindIndex", reflect.TypeOf((*MockIInformationSchema)(nil).FindIndex), db, table, column) 64 | } 65 | 66 | // GetAllTables mocks base method.
67 | func (m *MockIInformationSchema) GetAllTables(database string) ([]string, error) { 68 | m.ctrl.T.Helper() 69 | ret := m.ctrl.Call(m, "GetAllTables", database) 70 | ret0, _ := ret[0].([]string) 71 | ret1, _ := ret[1].(error) 72 | return ret0, ret1 73 | } 74 | 75 | // GetAllTables indicates an expected call of GetAllTables. 76 | func (mr *MockIInformationSchemaMockRecorder) GetAllTables(database interface{}) *gomock.Call { 77 | mr.mock.ctrl.T.Helper() 78 | return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAllTables", reflect.TypeOf((*MockIInformationSchema)(nil).GetAllTables), database) 79 | } 80 | -------------------------------------------------------------------------------- /internal/log/log.go: -------------------------------------------------------------------------------- 1 | package log 2 | 3 | import ( 4 | "fmt" 5 | "os" 6 | ) 7 | // Must prints err to standard error and exits the process with status 1; it is a no-op when err is nil. 8 | func Must(err error) { 9 | if err == nil { 10 | return 11 | } 12 | fmt.Fprintln(os.Stderr, err) 13 | os.Exit(1) 14 | } 15 | -------------------------------------------------------------------------------- /internal/log/log_test.go: -------------------------------------------------------------------------------- 1 | package log 2 | 3 | import "testing" 4 | 5 | func TestMust(t *testing.T) { 6 | Must(nil) 7 | } 8 | -------------------------------------------------------------------------------- /internal/parameter/parameter.go: -------------------------------------------------------------------------------- 1 | package parameter 2 | 3 | import ( 4 | "github.com/anqiansong/sqlgen/internal/set" 5 | "github.com/anqiansong/sqlgen/internal/stringx" 6 | ) 7 | // p is an insertion-ordered collection of Parameter values backed by set.ListSet. 8 | type p struct { 9 | s *set.ListSet 10 | } 11 | 12 | // Parameter represents an original description data for code generation structure info. 13 | type Parameter struct { 14 | // Column represents a parameter name. 15 | Column string 16 | // Type represents a parameter go type.
17 | Type string 18 | // ThirdPkg represents a go type which is a third package or go built-in package. 19 | ThirdPkg string 20 | } 21 | 22 | // Parameters is a list of Parameter values. 23 | type Parameters []Parameter 24 | 25 | // Empty is a placeholder of Parameters. 26 | var Empty = Parameters{} 27 | // New returns an empty parameter collection. 28 | func New() *p { 29 | return &p{s: set.From()} 30 | } 31 | // Add stores each parameter; one equal to an already stored value (same Column, Type and ThirdPkg) has its Column renamed with stringx.AutoIncrement until it is unique. NOTE(review): equality covers the whole struct, so entries with the same Column but a different Type can coexist — confirm that is intended. 32 | func (p *p) Add(parameter ...Parameter) { 33 | for _, v := range parameter { 34 | for { 35 | if p.s.Exists(v) { 36 | v.Column = stringx.AutoIncrement(v.Column, 1) 37 | continue 38 | } 39 | p.s.Add(v) 40 | break 41 | } 42 | } 43 | } 44 | // List returns the stored parameters in insertion order. 45 | func (p *p) List() Parameters { 46 | var ret Parameters 47 | p.s.Range(func(v interface{}) { 48 | ret = append(ret, v.(Parameter)) 49 | }) 50 | return ret 51 | } 52 | -------------------------------------------------------------------------------- /internal/parameter/parameter_test.go: -------------------------------------------------------------------------------- 1 | package parameter 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/stretchr/testify/assert" 7 | ) 8 | 9 | func TestNew(t *testing.T) { 10 | instance := New() 11 | assert.NotNil(t, instance) 12 | } 13 | 14 | func TestP_Add(t *testing.T) { 15 | instance := New() 16 | instance.Add(Parameter{ 17 | Column: "foo", 18 | Type: "bar", 19 | ThirdPkg: "baz", 20 | }, Parameter{ 21 | Column: "foo1", 22 | Type: "bar1", 23 | ThirdPkg: "baz1", 24 | }, Parameter{ 25 | Column: "foo", 26 | Type: "bar", 27 | ThirdPkg: "baz", 28 | }) 29 | } 30 | 31 | func TestP_List(t *testing.T) { 32 | instance := New() 33 | expected := Parameters{ 34 | Parameter{ 35 | Column: "foo", 36 | Type: "bar", 37 | ThirdPkg: "baz", 38 | }, 39 | Parameter{ 40 | Column: "foo1", 41 | Type: "bar", 42 | ThirdPkg: "baz", 43 | }, 44 | Parameter{ 45 | Column: "foo1", 46 | Type: "bar1", 47 | ThirdPkg: "baz1", 48 | }, 49 | } 50 | instance.Add( 51 | Parameter{ 52 | Column: "foo", 53 | Type: "bar", 54 | ThirdPkg: "baz", 55 | }, 56 | Parameter{ 57 | Column:
"foo", 58 | Type: "bar", 59 | ThirdPkg: "baz", 60 | }, 61 | Parameter{ 62 | Column: "foo1", 63 | Type: "bar1", 64 | ThirdPkg: "baz1", 65 | }) 66 | actual := instance.List() 67 | assert.Equal(t, len(expected), len(actual)) 68 | for idx, expectedItem := range expected { 69 | actualItem := actual[idx] 70 | assert.Equal(t, expectedItem, actualItem) 71 | } 72 | } 73 | -------------------------------------------------------------------------------- /internal/parser/column.go: -------------------------------------------------------------------------------- 1 | package parser 2 | 3 | import ( 4 | "github.com/pingcap/parser/ast" 5 | "github.com/pingcap/parser/mysql" 6 | "github.com/pingcap/parser/test_driver" 7 | 8 | "github.com/anqiansong/sqlgen/internal/set" 9 | "github.com/anqiansong/sqlgen/internal/spec" 10 | ) 11 | // parseColumnDef converts a column definition into a spec.Column plus the PRIMARY KEY / UNIQUE constraints declared inline on that column (keyed by the column name); it returns (nil, nil) when col or its name is nil. 12 | func parseColumnDef(col *ast.ColumnDef) (*spec.Column, *spec.Constraint) { 13 | if col == nil || col.Name == nil { 14 | return nil, nil 15 | } 16 | 17 | var column spec.Column 18 | var constraint = spec.NewConstraint() 19 | var tp = col.Tp 20 | if tp != nil { 21 | column.Unsigned = mysql.HasUnsignedFlag(tp.Flag) 22 | column.TP = tp.Tp 23 | } 24 | 25 | column.Name = col.Name.String() 26 | for _, opt := range col.Options { 27 | var tp = opt.Tp 28 | switch tp { 29 | case ast.ColumnOptionNotNull: 30 | column.NotNull = true 31 | case ast.ColumnOptionAutoIncrement: 32 | column.AutoIncrement = true 33 | case ast.ColumnOptionDefaultValue: 34 | column.HasDefaultValue = true 35 | case ast.ColumnOptionComment: 36 | var expr = opt.Expr 37 | if expr != nil { 38 | value, ok := expr.(*test_driver.ValueExpr) 39 | if ok { 40 | column.Comment = value.GetString() 41 | } 42 | } 43 | case ast.ColumnOptionUniqKey: 44 | constraint.AppendUniqueKey(column.Name, column.Name) 45 | case ast.ColumnOptionPrimaryKey: 46 | constraint.AppendPrimaryKey(column.Name, column.Name) 47 | default: 48 | // ignore other options 49 | } 50 | } 51 | 52 | return &column, constraint 53 | } 54 | // parseConstraint converts a table-level constraint into a spec.Constraint keyed by the constraint name; it returns nil for a nil constraint or one that names no columns, and an empty constraint for types other than primary key, key/index and unique. 55 |
func parseConstraint(constraint *ast.Constraint) *spec.Constraint { 56 | if constraint == nil { 57 | return nil 58 | } 59 | 60 | var columns = parseColumnFromKeys(constraint.Keys) 61 | if len(columns) == 0 { 62 | return nil 63 | } 64 | 65 | var ret = spec.NewConstraint() 66 | var key = constraint.Name 67 | switch constraint.Tp { 68 | case ast.ConstraintPrimaryKey: 69 | ret.AppendPrimaryKey(key, columns...) 70 | case ast.ConstraintKey, ast.ConstraintIndex: 71 | ret.AppendIndex(key, columns...) 72 | case ast.ConstraintUniq, ast.ConstraintUniqKey, ast.ConstraintUniqIndex: 73 | ret.AppendUniqueKey(key, columns...) 74 | default: 75 | // ignore other constraints 76 | } 77 | 78 | return ret 79 | } 80 | 81 | func parseColumnFromKeys(keys []*ast.IndexPartSpecification) []string { 82 | var columnSet = set.From() 83 | for _, key := range keys { 84 | if key.Column == nil { 85 | continue 86 | } 87 | 88 | var columnName = key.Column.String() 89 | columnSet.Add(columnName) 90 | } 91 | 92 | return columnSet.String() 93 | } 94 | -------------------------------------------------------------------------------- /internal/parser/column_test.go: -------------------------------------------------------------------------------- 1 | package parser 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/anqiansong/sqlgen/internal/spec" 7 | "github.com/pingcap/parser/ast" 8 | "github.com/pingcap/parser/model" 9 | "github.com/pingcap/parser/mysql" 10 | "github.com/pingcap/parser/test_driver" 11 | "github.com/pingcap/parser/types" 12 | "github.com/stretchr/testify/assert" 13 | ) 14 | 15 | func Test_parseColumnDef(t *testing.T) { 16 | t.Run("nil", func(t *testing.T) { 17 | column, constraint := parseColumnDef(nil) 18 | assert.Nil(t, column) 19 | assert.Nil(t, constraint) 20 | 21 | column, constraint = parseColumnDef(&ast.ColumnDef{}) 22 | assert.Nil(t, column) 23 | assert.Nil(t, constraint) 24 | }) 25 | 26 | t.Run("not nil", func(t *testing.T) { 27 | datum := test_driver.Datum{} 28 | 
datum.SetString("foo") 29 | name := &ast.ColumnName{ 30 | Name: model.CIStr{ 31 | O: "foo", 32 | L: "foo", 33 | }, 34 | } 35 | tp := &types.FieldType{ 36 | Tp: mysql.TypeBit, 37 | Flag: mysql.UnsignedFlag, 38 | } 39 | column, constraint := parseColumnDef(&ast.ColumnDef{ 40 | Name: name, 41 | Tp: tp, 42 | Options: []*ast.ColumnOption{ 43 | {Tp: ast.ColumnOptionNotNull}, 44 | {Tp: ast.ColumnOptionAutoIncrement}, 45 | {Tp: ast.ColumnOptionDefaultValue}, 46 | { 47 | Tp: ast.ColumnOptionComment, 48 | Expr: &test_driver.ValueExpr{Datum: datum}, 49 | }, 50 | {Tp: ast.ColumnOptionUniqKey}, 51 | {Tp: ast.ColumnOptionPrimaryKey}, 52 | {Tp: ast.ColumnOptionFulltext}, 53 | }, 54 | }) 55 | assert.Equal(t, spec.Column{ 56 | ColumnOption: spec.ColumnOption{ 57 | AutoIncrement: true, 58 | Comment: "foo", 59 | HasDefaultValue: true, 60 | NotNull: true, 61 | Unsigned: true, 62 | }, 63 | Name: "foo", 64 | TP: mysql.TypeBit, 65 | AggregateCall: false, 66 | }, *column) 67 | 68 | assertMapEqual(t, map[string][]string{ 69 | "foo": {"foo"}, 70 | }, constraint.UniqueKey) 71 | assertMapEqual(t, map[string][]string{ 72 | "foo": {"foo"}, 73 | }, constraint.PrimaryKey) 74 | }) 75 | } 76 | 77 | func assertMapEqual(t *testing.T, m1, m2 map[string][]string) { 78 | for k, v1 := range m1 { 79 | v2, ok := m2[k] 80 | assert.True(t, ok) 81 | assert.Equal(t, v1, v2) 82 | } 83 | return 84 | } 85 | 86 | func Test_parseConstraint(t *testing.T) { 87 | t.Run("nil", func(t *testing.T) { 88 | c := parseConstraint(nil) 89 | assert.Nil(t, c) 90 | 91 | c = parseConstraint(&ast.Constraint{Keys: nil}) 92 | assert.Nil(t, c) 93 | }) 94 | 95 | t.Run("ConstraintPrimaryKey", func(t *testing.T) { 96 | c := parseConstraint(&ast.Constraint{ 97 | Name: "foo", 98 | Tp: ast.ConstraintPrimaryKey, 99 | Keys: []*ast.IndexPartSpecification{ 100 | { 101 | Column: &ast.ColumnName{ 102 | Name: model.CIStr{ 103 | O: "foo", 104 | L: "foo", 105 | }, 106 | }, 107 | }, 108 | }}) 109 | assertMapEqual(t, map[string][]string{ 110 | "foo": 
{"foo"}, 111 | }, c.PrimaryKey) 112 | }) 113 | 114 | t.Run("ConstraintKey", func(t *testing.T) { 115 | c := parseConstraint(&ast.Constraint{ 116 | Name: "foo", 117 | Tp: ast.ConstraintKey, 118 | Keys: []*ast.IndexPartSpecification{ 119 | { 120 | Column: &ast.ColumnName{ 121 | Name: model.CIStr{ 122 | O: "foo", 123 | L: "foo", 124 | }, 125 | }, 126 | }, 127 | }}) 128 | assertMapEqual(t, map[string][]string{ 129 | "foo": {"foo"}, 130 | }, c.Index) 131 | }) 132 | 133 | t.Run("ConstraintIndex", func(t *testing.T) { 134 | c := parseConstraint(&ast.Constraint{ 135 | Name: "foo", 136 | Tp: ast.ConstraintIndex, 137 | Keys: []*ast.IndexPartSpecification{ 138 | { 139 | Column: &ast.ColumnName{ 140 | Name: model.CIStr{ 141 | O: "foo", 142 | L: "foo", 143 | }, 144 | }, 145 | }, 146 | }}) 147 | assertMapEqual(t, map[string][]string{ 148 | "foo": {"foo"}, 149 | }, c.Index) 150 | }) 151 | 152 | t.Run("ConstraintUniq", func(t *testing.T) { 153 | c := parseConstraint(&ast.Constraint{ 154 | Name: "foo", 155 | Tp: ast.ConstraintUniq, 156 | Keys: []*ast.IndexPartSpecification{ 157 | { 158 | Column: &ast.ColumnName{ 159 | Name: model.CIStr{ 160 | O: "foo", 161 | L: "foo", 162 | }, 163 | }, 164 | }, 165 | }}) 166 | assertMapEqual(t, map[string][]string{ 167 | "foo": {"foo"}, 168 | }, c.UniqueKey) 169 | }) 170 | t.Run("ConstraintUniqKey", func(t *testing.T) { 171 | c := parseConstraint(&ast.Constraint{ 172 | Name: "foo", 173 | Tp: ast.ConstraintUniqKey, 174 | Keys: []*ast.IndexPartSpecification{ 175 | { 176 | Column: &ast.ColumnName{ 177 | Name: model.CIStr{ 178 | O: "foo", 179 | L: "foo", 180 | }, 181 | }, 182 | }, 183 | }}) 184 | assertMapEqual(t, map[string][]string{ 185 | "foo": {"foo"}, 186 | }, c.UniqueKey) 187 | }) 188 | t.Run("ConstraintUniqIndex", func(t *testing.T) { 189 | c := parseConstraint(&ast.Constraint{ 190 | Name: "foo", 191 | Tp: ast.ConstraintUniqIndex, 192 | Keys: []*ast.IndexPartSpecification{ 193 | { 194 | Column: &ast.ColumnName{ 195 | Name: model.CIStr{ 196 | O: 
"foo", 197 | L: "foo", 198 | }, 199 | }, 200 | }, 201 | }}) 202 | assertMapEqual(t, map[string][]string{ 203 | "foo": {"foo"}, 204 | }, c.UniqueKey) 205 | }) 206 | 207 | t.Run("ConstraintUniqIndex", func(t *testing.T) { 208 | c := parseConstraint(&ast.Constraint{ 209 | Name: "foo", 210 | Tp: ast.ConstraintFulltext, 211 | Keys: []*ast.IndexPartSpecification{ 212 | { 213 | Column: &ast.ColumnName{ 214 | Name: model.CIStr{ 215 | O: "foo", 216 | L: "foo", 217 | }, 218 | }, 219 | }, 220 | }}) 221 | assert.Equal(t, spec.Constraint{ 222 | Index: map[string][]string{}, 223 | PrimaryKey: map[string][]string{}, 224 | UniqueKey: map[string][]string{}, 225 | }, *c) 226 | }) 227 | } 228 | 229 | func Test_parseColumnFromKeys(t *testing.T) { 230 | var testData = []struct { 231 | input string 232 | expect []string 233 | }{ 234 | { 235 | input: "foo", 236 | expect: []string{"foo"}, 237 | }, 238 | { 239 | expect: []string(nil), 240 | }, 241 | } 242 | getColumn := func(name string) *ast.ColumnName { 243 | if len(name) == 0 { 244 | return nil 245 | } 246 | var column ast.ColumnName 247 | column.Name = model.CIStr{ 248 | O: name, 249 | L: name, 250 | } 251 | return &column 252 | } 253 | for _, v := range testData { 254 | actual := parseColumnFromKeys([]*ast.IndexPartSpecification{ 255 | { 256 | Column: getColumn(v.input), 257 | }, 258 | }) 259 | assert.Equal(t, v.expect, actual) 260 | } 261 | } 262 | -------------------------------------------------------------------------------- /internal/parser/comment.go: -------------------------------------------------------------------------------- 1 | package parser 2 | 3 | import ( 4 | "bufio" 5 | "bytes" 6 | "fmt" 7 | "regexp" 8 | "strings" 9 | "text/scanner" 10 | 11 | "github.com/anqiansong/sqlgen/internal/spec" 12 | "github.com/anqiansong/sqlgen/internal/stringx" 13 | ) 14 | 15 | const ( 16 | plainTextMode = iota 17 | lineCommentMode 18 | docCommentOpenMode 19 | docCommentCloseMode 20 | ) 21 | 22 | var ( 23 | singleLineComment = []byte("--") 
24 | fnPrefix = "fn:" 25 | fnRegex = `(?m)^[a-zA-Z]\w*` 26 | ) 27 | 28 | type segment struct { 29 | start, end int 30 | } 31 | 32 | type sqlScanner struct { 33 | *scanner.Scanner 34 | mode int 35 | source string 36 | startIndex int 37 | segments []segment 38 | trim string 39 | } 40 | 41 | func NewSqlScanner(s string) *sqlScanner { 42 | var sc scanner.Scanner 43 | return &sqlScanner{ 44 | Scanner: sc.Init(strings.NewReader(s)), 45 | mode: plainTextMode, 46 | source: s, 47 | trim: s, 48 | } 49 | } 50 | 51 | // ScanAndTrim ignores non comment text. 52 | func (s *sqlScanner) ScanAndTrim() (string, error) { 53 | for tok := s.Next(); tok != scanner.EOF; tok = s.Next() { 54 | switch tok { 55 | case '-': 56 | s.enterLineCommentMode() 57 | case '/': 58 | err := s.enterDocCommentMode() 59 | if err != nil { 60 | return "", err 61 | } 62 | } 63 | } 64 | 65 | if len(s.segments) == 0 { 66 | return s.source, nil 67 | } 68 | 69 | var list []string 70 | for i := 0; i < len(s.segments); i++ { 71 | var segment string 72 | if i == 0 { 73 | if s.segments[i].start > 0 { 74 | segment = s.source[:s.segments[i].start] 75 | list = append(list, stringx.TrimNewLine(segment)) 76 | } 77 | } 78 | if i != len(s.segments)-1 { 79 | segment = s.source[s.segments[i].end:s.segments[i+1].start] 80 | } else { 81 | segment = s.source[s.segments[i].end:] 82 | } 83 | 84 | s := stringx.TrimWhiteSpace(segment) 85 | if len(s) == 0 { 86 | continue 87 | } 88 | 89 | list = append(list, segment) 90 | } 91 | 92 | return stringx.FormatIdentifiers(strings.Join(list, "")), nil 93 | } 94 | 95 | func (s *sqlScanner) enterDocCommentMode() error { 96 | var position = s.CurrentPosition() 97 | s.startIndex = position - 1 98 | if s.startIndex < 0 { 99 | s.startIndex = 0 100 | } 101 | 102 | tok := s.Next() 103 | if tok != '*' { 104 | s.startIndex = 0 105 | return nil 106 | } 107 | 108 | s.mode = docCommentOpenMode 109 | for { 110 | tok := s.Next() 111 | if tok == scanner.EOF { 112 | s.mode = plainTextMode 113 | return 
fmt.Errorf("expected close flag '*/'") 114 | } 115 | 116 | if tok == '*' { 117 | s.mode = docCommentCloseMode 118 | } else if tok == '/' { 119 | if s.mode == docCommentCloseMode { 120 | s.segments = append(s.segments, segment{ 121 | start: s.startIndex, 122 | end: s.CurrentPosition(), 123 | }) 124 | } 125 | break 126 | } else { 127 | if s.mode == docCommentCloseMode { 128 | s.mode = docCommentOpenMode 129 | } 130 | } 131 | } 132 | s.mode = plainTextMode 133 | return nil 134 | } 135 | 136 | func (s *sqlScanner) CurrentPosition() int { 137 | offset := s.Pos().Offset 138 | return offset 139 | } 140 | 141 | func (s *sqlScanner) enterLineCommentMode() { 142 | var position = s.CurrentPosition() 143 | s.startIndex = position - 1 144 | if s.startIndex < 0 { 145 | s.startIndex = 0 146 | } 147 | tok := s.Next() 148 | if tok != '-' { 149 | s.startIndex = 0 150 | return 151 | } 152 | 153 | s.mode = lineCommentMode 154 | for { 155 | tok = s.Next() 156 | if tok == scanner.EOF || tok == '\n' { 157 | s.segments = append(s.segments, segment{ 158 | start: s.startIndex, 159 | end: s.CurrentPosition(), 160 | }) 161 | break 162 | } 163 | } 164 | s.mode = plainTextMode 165 | } 166 | 167 | func parseLineComment(sql string) (spec.Comment, error) { 168 | r := bufio.NewReader(strings.NewReader(sql)) 169 | var comment spec.Comment 170 | comment.OriginText = sql 171 | for { 172 | line, _, err := r.ReadLine() 173 | if err != nil { 174 | break 175 | } 176 | 177 | comment.LineText = append(comment.LineText, string(line)) 178 | if bytes.HasPrefix(line, singleLineComment) { 179 | var text = strings.TrimSpace(string(line[2:])) 180 | funcName, err := parseFuncName(text) 181 | if err != nil { 182 | return spec.Comment{}, err 183 | } 184 | 185 | if len(funcName) > 0 { // it will override the previous comment. 
186 | comment.FuncName = funcName 187 | } 188 | } 189 | } 190 | 191 | return comment, nil 192 | } 193 | 194 | func parseFuncName(text string) (string, error) { 195 | s := stringx.TrimSpace(text) 196 | if !strings.HasPrefix(s, fnPrefix) { 197 | return "", nil 198 | } 199 | 200 | fn := strings.TrimPrefix(s, fnPrefix) 201 | if len(fn) == 0 { 202 | return "", errorMissingFunction 203 | } 204 | 205 | match, _ := regexp.MatchString(fnRegex, fn) 206 | if match { 207 | return fn, nil 208 | } 209 | 210 | return "", fmt.Errorf("invalid function name: %s", fn) 211 | } 212 | -------------------------------------------------------------------------------- /internal/parser/comment_test.go: -------------------------------------------------------------------------------- 1 | package parser 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/stretchr/testify/assert" 7 | 8 | "github.com/anqiansong/sqlgen/internal/spec" 9 | ) 10 | 11 | func Test_parseLineComment(t *testing.T) { 12 | var test = []struct { 13 | input string 14 | expected spec.Comment 15 | }{ 16 | {input: "", expected: spec.Comment{}}, 17 | {input: "--fn:a", expected: spec.Comment{ 18 | OriginText: "--fn:a", 19 | LineText: []string{"--fn:a"}, 20 | FuncName: "a", 21 | }}, 22 | {input: "-- fn: ab", expected: spec.Comment{ 23 | OriginText: "-- fn: ab", 24 | LineText: []string{"-- fn: ab"}, 25 | FuncName: "ab", 26 | }}, 27 | {input: "-- fn : a1", expected: spec.Comment{ 28 | OriginText: "-- fn : a1", 29 | LineText: []string{"-- fn : a1"}, 30 | FuncName: "a1", 31 | }}, 32 | {input: "-- fn:A", expected: spec.Comment{ 33 | OriginText: "-- fn:A", 34 | LineText: []string{"-- fn:A"}, 35 | FuncName: "A", 36 | }}, 37 | {input: "-- fn:A1", expected: spec.Comment{ 38 | OriginText: "--\t\tfn:A1", 39 | LineText: []string{"--\t\tfn:A1"}, 40 | FuncName: "A1", 41 | }}, 42 | {input: "-- fn : A_ ", expected: spec.Comment{ 43 | OriginText: "-- fn : A_ ", 44 | LineText: []string{"-- fn : A_ "}, 45 | FuncName: "A_", 46 | }}, 47 | {input: "-- fn 
: A_\n-- plain text", expected: spec.Comment{ 48 | OriginText: "-- fn : A_\n-- plain text", 49 | LineText: []string{"-- fn : A_", "-- plain text"}, 50 | FuncName: "A_", 51 | }}, 52 | {input: `-- fn: Insert 53 | -- name: foo 54 | -- 用户数据插入 55 | insert into user (user, name, password, mobile) 56 | values ('test', 'test', 'test', 'test');`, expected: spec.Comment{ 57 | OriginText: `-- fn: Insert 58 | -- name: foo 59 | -- 用户数据插入 60 | insert into user (user, name, password, mobile) 61 | values ('test', 'test', 'test', 'test');`, 62 | LineText: []string{"-- fn: Insert", "-- name: foo", "-- 用户数据插入", "insert into user (user, name, password, mobile)", "values ('test', 'test', 'test', 'test');"}, 63 | FuncName: "Insert", 64 | }}, 65 | } 66 | for _, c := range test { 67 | actual, _ := parseLineComment(c.input) 68 | assert.Equal(t, c.expected, actual) 69 | } 70 | } 71 | 72 | func Test_parseFuncName(t *testing.T) { 73 | var test = []struct { 74 | input string 75 | expected string 76 | err bool 77 | }{ 78 | {input: "", expected: ""}, 79 | {input: "fn:", err: true}, 80 | {input: "fn:a", expected: "a"}, 81 | {input: "fn: ab", expected: "ab"}, 82 | {input: "fn : a1", expected: "a1"}, 83 | {input: "fn:A", expected: "A"}, 84 | {input: "fn:A1", expected: "A1"}, 85 | {input: " fn : A_ ", expected: "A_"}, 86 | {input: "fn:2", err: true}, 87 | {input: "fn:_", err: true}, 88 | {input: "fn:-", err: true}, 89 | } 90 | for _, c := range test { 91 | actual, err := parseFuncName(c.input) 92 | if c.err { 93 | assert.Error(t, err) 94 | continue 95 | } 96 | assert.Equal(t, c.expected, actual) 97 | } 98 | } 99 | 100 | func Test_trimComment(t *testing.T) { 101 | test := []struct { 102 | input string 103 | expect string 104 | err bool 105 | }{ 106 | {}, 107 | { 108 | input: " ", 109 | expect: " ", 110 | }, 111 | { 112 | input: "-", 113 | expect: "-", 114 | }, 115 | { 116 | input: "--", 117 | expect: "", 118 | }, 119 | { 120 | input: `--foo--bar 121 | foo`, 122 | expect: "foo", 123 | }, 124 | { 125 | 
input: `--foo--bar 126 | foo 127 | bar`, 128 | expect: "foo bar", 129 | }, 130 | { 131 | input: `/**/foo`, 132 | expect: "foo", 133 | }, 134 | { 135 | input: `foo/**/ bar`, 136 | expect: "foo bar", 137 | }, 138 | { 139 | input: `foo/**/ 140 | --foo 141 | /*--*/ 142 | bar`, 143 | expect: "foo bar", 144 | }, 145 | { 146 | input: "/*", 147 | err: true, 148 | }, 149 | { 150 | input: "/**", 151 | err: true, 152 | }, 153 | } 154 | for _, v := range test { 155 | s := NewSqlScanner(v.input) 156 | actual, err := s.ScanAndTrim() 157 | if v.err { 158 | assert.Error(t, err) 159 | continue 160 | } 161 | assert.Equal(t, v.expect, actual) 162 | } 163 | } 164 | -------------------------------------------------------------------------------- /internal/parser/ddl.go: -------------------------------------------------------------------------------- 1 | package parser 2 | 3 | import ( 4 | "github.com/pingcap/parser/ast" 5 | 6 | "github.com/anqiansong/sqlgen/internal/spec" 7 | ) 8 | 9 | func parseDDL(node *ast.CreateTableStmt) (*spec.DDL, error) { 10 | var ddl spec.DDL 11 | ddl.Table = parseCreateTableStmt(node) 12 | return &ddl, nil 13 | } 14 | -------------------------------------------------------------------------------- /internal/parser/delete.go: -------------------------------------------------------------------------------- 1 | package parser 2 | 3 | import ( 4 | "github.com/pingcap/parser/ast" 5 | 6 | "github.com/anqiansong/sqlgen/internal/spec" 7 | ) 8 | 9 | func parseDelete(stmt *ast.DeleteStmt) (spec.DML, error) { 10 | var ret spec.DeleteStmt 11 | var text = stmt.Text() 12 | comment, err := parseLineComment(text) 13 | if err != nil { 14 | return nil, errorNearBy(err, text) 15 | } 16 | 17 | sql, err := NewSqlScanner(text).ScanAndTrim() 18 | if err != nil { 19 | return nil, errorNearBy(err, text) 20 | } 21 | 22 | if stmt.IsMultiTable { 23 | return nil, errorNearBy(errorMultipleTable, text) 24 | } 25 | 26 | tableName, err := parseTableRefsClause(stmt.TableRefs) 27 | if err != 
nil { 28 | return nil, errorNearBy(err, text) 29 | } 30 | 31 | if stmt.Where != nil { 32 | where, err := parseExprNode(stmt.Where, tableName, exprTypeWhereClause) 33 | if err != nil { 34 | return nil, errorNearBy(err, text) 35 | } 36 | 37 | ret.Where = where 38 | } 39 | 40 | if stmt.Order != nil { 41 | orderBy, err := parseOrderBy(stmt.Order, tableName) 42 | if err != nil { 43 | return nil, errorNearBy(err, text) 44 | } 45 | 46 | ret.OrderBy = orderBy 47 | } 48 | 49 | if stmt.Limit != nil { 50 | limit, err := parseLimit(stmt.Limit) 51 | if err != nil { 52 | return nil, errorNearBy(err, text) 53 | } 54 | 55 | ret.Limit = limit 56 | } 57 | 58 | ret.Comment = comment 59 | ret.SQL = sql 60 | ret.Action = spec.ActionDelete 61 | ret.From = tableName 62 | return &ret, nil 63 | } 64 | -------------------------------------------------------------------------------- /internal/parser/delete_test.go: -------------------------------------------------------------------------------- 1 | package parser 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/pingcap/parser" 7 | "github.com/pingcap/parser/ast" 8 | "github.com/pingcap/parser/model" 9 | "github.com/pingcap/parser/opcode" 10 | "github.com/stretchr/testify/assert" 11 | ) 12 | 13 | var testParser *parser.Parser 14 | 15 | func TestMain(m *testing.M) { 16 | testParser = parser.New() 17 | m.Run() 18 | } 19 | 20 | func Test_parseDelete(t *testing.T) { 21 | t.Run("success", func(t *testing.T) { 22 | stmt, _, err := testParser.Parse(`-- fn: foo 23 | delete from foo where id = ? 
order by name limit 1`, "", "") 24 | assert.NoError(t, err) 25 | for _, v := range stmt { 26 | deleteStmt, ok := v.(*ast.DeleteStmt) 27 | if !ok { 28 | continue 29 | } 30 | _, err := parseDelete(deleteStmt) 31 | assert.NoError(t, err) 32 | } 33 | }) 34 | t.Run("parseLineComment", func(t *testing.T) { 35 | stmt, _, err := testParser.Parse(`-- fn: 36 | delete from foo where id = ?`, "", "") 37 | assert.NoError(t, err) 38 | for _, v := range stmt { 39 | deleteStmt, ok := v.(*ast.DeleteStmt) 40 | if !ok { 41 | continue 42 | } 43 | _, err := parseDelete(deleteStmt) 44 | assert.Contains(t, err.Error(), errorMissingFunction.Error()) 45 | } 46 | }) 47 | t.Run("IsMultiTable", func(t *testing.T) { 48 | stmt, _, err := testParser.Parse(`-- fn: foo 49 | delete from foo where id = ?`, "", "") 50 | assert.NoError(t, err) 51 | for _, v := range stmt { 52 | deleteStmt, ok := v.(*ast.DeleteStmt) 53 | if !ok { 54 | continue 55 | } 56 | // mock 57 | deleteStmt.IsMultiTable = true 58 | _, err := parseDelete(deleteStmt) 59 | assert.Contains(t, err.Error(), errorMultipleTable.Error()) 60 | } 61 | }) 62 | 63 | t.Run("parseTableRefsClause", func(t *testing.T) { 64 | stmt := &ast.DeleteStmt{} 65 | _, err := parseDelete(stmt) 66 | assert.Contains(t, err.Error(), errorMissingTable.Error()) 67 | }) 68 | 69 | t.Run("whereExpr", func(t *testing.T) { 70 | stmt := &ast.DeleteStmt{ 71 | Where: &ast.BinaryOperationExpr{ 72 | Op: opcode.Plus, 73 | }, 74 | TableRefs: &ast.TableRefsClause{ 75 | TableRefs: &ast.Join{ 76 | Left: &ast.TableSource{ 77 | Source: &ast.TableName{ 78 | Name: model.CIStr{ 79 | O: "foo", 80 | L: "foo", 81 | }, 82 | }, 83 | }, 84 | }, 85 | }, 86 | } 87 | _, err := parseDelete(stmt) 88 | assert.Contains(t, err.Error(), "unsupported opcode") 89 | }) 90 | 91 | t.Run("orderExpr", func(t *testing.T) { 92 | stmt := &ast.DeleteStmt{ 93 | Order: &ast.OrderByClause{ 94 | Items: []*ast.ByItem{ 95 | {}, 96 | }, 97 | }, 98 | TableRefs: &ast.TableRefsClause{ 99 | TableRefs: &ast.Join{ 100 | 
Left: &ast.TableSource{ 101 | Source: &ast.TableName{ 102 | Name: model.CIStr{ 103 | O: "foo", 104 | L: "foo", 105 | }, 106 | }, 107 | }, 108 | }, 109 | }, 110 | } 111 | _, err := parseDelete(stmt) 112 | assert.Contains(t, err.Error(), errorInvalidExprNode.Error()) 113 | }) 114 | 115 | t.Run("limitExpr", func(t *testing.T) { 116 | stmt := &ast.DeleteStmt{ 117 | Limit: &ast.Limit{ 118 | Count: &ast.BetweenExpr{}, 119 | }, 120 | TableRefs: &ast.TableRefsClause{ 121 | TableRefs: &ast.Join{ 122 | Left: &ast.TableSource{ 123 | Source: &ast.TableName{ 124 | Name: model.CIStr{ 125 | O: "foo", 126 | L: "foo", 127 | }, 128 | }, 129 | }, 130 | }, 131 | }, 132 | } 133 | _, err := parseDelete(stmt) 134 | assert.Contains(t, err.Error(), errorUnsupportedLimitExpr.Error()) 135 | }) 136 | } 137 | -------------------------------------------------------------------------------- /internal/parser/dml.go: -------------------------------------------------------------------------------- 1 | package parser 2 | 3 | import ( 4 | "github.com/anqiansong/sqlgen/internal/buffer" 5 | "github.com/pingcap/parser/ast" 6 | 7 | "github.com/anqiansong/sqlgen/internal/spec" 8 | ) 9 | 10 | func parseDML(node ast.StmtNode) (spec.DML, error) { 11 | switch v := node.(type) { 12 | case *ast.InsertStmt: 13 | return parseInsert(v) 14 | case *ast.SelectStmt: 15 | return parseSelect(v) 16 | case *ast.DeleteStmt: 17 | return parseDelete(v) 18 | case *ast.UpdateStmt: 19 | return parseUpdate(v) 20 | default: 21 | return nil, errorUnsupportedStmt 22 | } 23 | } 24 | 25 | func parseTableRefsClause(clause *ast.TableRefsClause) (string, error) { 26 | if clause == nil { 27 | return "", errorMissingTable 28 | } 29 | 30 | var join = clause.TableRefs 31 | if join == nil { 32 | return "", errorMissingTable 33 | } 34 | 35 | if join.Left == nil { 36 | return "", errorMissingTable 37 | } 38 | 39 | if join.Right != nil { 40 | return "", errorMultipleTable 41 | } 42 | 43 | tableName, err := parseResultSetNode(join.Left) 44 | if 
err != nil { 45 | return "", err 46 | } 47 | 48 | return tableName, nil 49 | } 50 | 51 | func parseTransaction(node *transactionStmt) (spec.DML, error) { 52 | if node == nil { 53 | return nil, errorMissingTransaction 54 | } 55 | var sqlBuilder = buffer.New() 56 | var beginText = node.startTransactionStmt.Text() 57 | var commitText = node.commitStmt.Text() 58 | beginSQL, err := NewSqlScanner(beginText).ScanAndTrim() 59 | if err != nil { 60 | return nil, errorNearBy(err, beginText) 61 | } 62 | commitSQL, err := NewSqlScanner(commitText).ScanAndTrim() 63 | if err != nil { 64 | return nil, errorNearBy(err, commitText) 65 | } 66 | 67 | comment, err := parseLineComment(beginText) 68 | if err != nil { 69 | return nil, err 70 | } 71 | 72 | sqlBuilder.Write(beginSQL) 73 | var ret spec.Transaction 74 | ret.Action = spec.ActionTransaction 75 | for _, v := range node.nodes() { 76 | dml, err := parseDML(v) 77 | if err != nil { 78 | return nil, err 79 | } 80 | sqlBuilder.Write(dml.SQLText()) 81 | ret.Statements = append(ret.Statements, dml) 82 | } 83 | sqlBuilder.Write(commitSQL) 84 | ret.SQL = sqlBuilder.String() 85 | ret.Comment = comment 86 | return &ret, nil 87 | } 88 | -------------------------------------------------------------------------------- /internal/parser/dml_test.go: -------------------------------------------------------------------------------- 1 | package parser 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/pingcap/parser/ast" 7 | "github.com/stretchr/testify/assert" 8 | ) 9 | 10 | func Test_parseDML(t *testing.T) { 11 | t.Run("InsertStmt", func(t *testing.T) { 12 | stmt, _, err := testParser.Parse(` 13 | -- fn: 14 | insert into foo (name) values(?); 15 | -- fn: 16 | select * from foo; 17 | -- fn: 18 | delete from foo where id = ?; 19 | -- fn: 20 | update foo set name = ? 
where id = ?; 21 | -- fn: 22 | alter table foo add column bar varchar(255); 23 | `, "", "") 24 | assert.NoError(t, err) 25 | for _, v := range stmt { 26 | _, err := parseDML(v) 27 | assert.NotNil(t, err) 28 | } 29 | }) 30 | } 31 | 32 | func Test_parseTableRefsClause(t *testing.T) { 33 | t.Run("nil", func(t *testing.T) { 34 | _, err := parseTableRefsClause(nil) 35 | assert.ErrorIs(t, err, errorMissingTable) 36 | }) 37 | 38 | t.Run("joinNil", func(t *testing.T) { 39 | _, err := parseTableRefsClause(&ast.TableRefsClause{}) 40 | assert.ErrorIs(t, err, errorMissingTable) 41 | }) 42 | 43 | t.Run("joinLeftNil", func(t *testing.T) { 44 | _, err := parseTableRefsClause(&ast.TableRefsClause{ 45 | TableRefs: &ast.Join{ 46 | Left: &ast.TableName{}, 47 | Right: &ast.TableName{}, 48 | }, 49 | }) 50 | assert.ErrorIs(t, err, errorMultipleTable) 51 | }) 52 | 53 | t.Run("parseResultSetNode", func(t *testing.T) { 54 | _, err := parseTableRefsClause(&ast.TableRefsClause{ 55 | TableRefs: &ast.Join{ 56 | Left: &ast.SelectStmt{}, 57 | }, 58 | }) 59 | assert.ErrorIs(t, err, errorUnsupportedNestedQuery) 60 | }) 61 | } 62 | 63 | func Test_parseTransaction(t *testing.T) { 64 | t.Run("nil", func(t *testing.T) { 65 | _, err := parseTransaction(nil) 66 | assert.ErrorIs(t, err, errorMissingTransaction) 67 | }) 68 | } 69 | -------------------------------------------------------------------------------- /internal/parser/error.go: -------------------------------------------------------------------------------- 1 | package parser 2 | 3 | import "errors" 4 | 5 | var ( 6 | errorMissingTable = errors.New("missing table") 7 | errorUnsupportedStmt = errors.New("unsupported statement") 8 | errorMultipleTable = errors.New("unsupported multiple tables") 9 | errorUnsupportedTableStyle = errors.New("unsupported table style") 10 | errorUnsupportedNestedQuery = errors.New("unsupported nested query") 11 | errorUnsupportedUnionQuery = errors.New("unsupported union query") 12 | errorUnsupportedSubQuery = 
errors.New("unsupported sub-query query") 13 | errorInvalidExprNode = errors.New("invalid expr node") 14 | errorMissingHaving = errors.New("missing having expr") 15 | errorUnsupportedLimitExpr = errors.New("unsupported limit expr") 16 | errorParamMaker = errors.New("marker expr") 17 | errorUnsupportedNestedTransaction = errors.New("unsupported nested transaction") 18 | errorMissingCommit = errors.New("missing commit statement") 19 | errorMissingTransaction = errors.New("missing transaction statement") 20 | errorMissingFunction = errors.New("missing function name") 21 | ) 22 | 23 | func errorNearBy(err error, text string) error { 24 | return errors.New(err.Error() + " near by " + text) 25 | } 26 | -------------------------------------------------------------------------------- /internal/parser/funcmap.go: -------------------------------------------------------------------------------- 1 | package parser 2 | 3 | import ( 4 | "github.com/anqiansong/sqlgen/internal/spec" 5 | "github.com/pingcap/parser/mysql" 6 | ) 7 | 8 | var funcMap = map[string]byte{ 9 | // spec.TypeNullLongLong 10 | "count": spec.TypeNullLongLong, 11 | "length": spec.TypeNullLongLong, 12 | "char_length": spec.TypeNullLongLong, 13 | "locate": spec.TypeNullLongLong, 14 | "position": spec.TypeNullLongLong, 15 | "instr": spec.TypeNullLongLong, 16 | "field": spec.TypeNullLongLong, 17 | "sign": spec.TypeNullLongLong, 18 | "mod": spec.TypeNullLongLong, 19 | "unix_timestamp": spec.TypeNullLongLong, 20 | "sysdate": spec.TypeNullLongLong, 21 | "utc_date": spec.TypeNullLongLong, 22 | "utc_time": spec.TypeNullLongLong, 23 | "month": spec.TypeNullLongLong, 24 | "day": spec.TypeNullLongLong, 25 | "dayofmonth": spec.TypeNullLongLong, 26 | "dayofweek": spec.TypeNullLongLong, 27 | "dayofyear": spec.TypeNullLongLong, 28 | "week": spec.TypeNullLongLong, 29 | "weekday": spec.TypeNullLongLong, 30 | "weekofyear": spec.TypeNullLongLong, 31 | "quarter": spec.TypeNullLongLong, 32 | "hour": spec.TypeNullLongLong, 33 | 
"minute": spec.TypeNullLongLong, 34 | "second": spec.TypeNullLongLong, 35 | "extract": spec.TypeNullLongLong, 36 | "time_to_sec": spec.TypeNullLongLong, 37 | "to_days": spec.TypeNullLongLong, 38 | "to_seconds": spec.TypeNullLongLong, 39 | "datadiff": spec.TypeNullLongLong, 40 | 41 | // spec.TypeNullDecimal 42 | "avg": spec.TypeNullDecimal, 43 | "abs": spec.TypeNullDecimal, 44 | "ceil": spec.TypeNullDecimal, 45 | "floor": spec.TypeNullDecimal, 46 | "round": spec.TypeNullDecimal, 47 | "rand": spec.TypeNullDecimal, 48 | "pi": spec.TypeNullDecimal, 49 | "truncate": spec.TypeNullDecimal, 50 | "pow": spec.TypeNullDecimal, 51 | "sqrt": spec.TypeNullDecimal, 52 | "exp": spec.TypeNullDecimal, 53 | "log": spec.TypeNullDecimal, 54 | "log10": spec.TypeNullDecimal, 55 | "radians": spec.TypeNullDecimal, 56 | "degrees": spec.TypeNullDecimal, 57 | "sin": spec.TypeNullDecimal, 58 | "cos": spec.TypeNullDecimal, 59 | "tan": spec.TypeNullDecimal, 60 | "cot": spec.TypeNullDecimal, 61 | "asin": spec.TypeNullDecimal, 62 | "acos": spec.TypeNullDecimal, 63 | "atan": spec.TypeNullDecimal, 64 | 65 | // spec.TypeNullString 66 | "concat_ws": spec.TypeNullString, 67 | "concat": spec.TypeNullString, 68 | "insert": spec.TypeNullString, 69 | "upper": spec.TypeNullString, 70 | "ucaase": spec.TypeNullString, 71 | "lower": spec.TypeNullString, 72 | "lcase": spec.TypeNullString, 73 | "left": spec.TypeNullString, 74 | "right": spec.TypeNullString, 75 | "lpad": spec.TypeNullString, 76 | "rpad": spec.TypeNullString, 77 | "replace": spec.TypeNullString, 78 | "substring": spec.TypeNullString, 79 | "substr": spec.TypeNullString, 80 | "trim": spec.TypeNullString, 81 | "ltrim": spec.TypeNullString, 82 | "rtrim": spec.TypeNullString, 83 | "reverse": spec.TypeNullString, 84 | "repeat": spec.TypeNullString, 85 | "space": spec.TypeNullString, 86 | "strcmp": spec.TypeNullString, 87 | "mid": spec.TypeNullString, 88 | "from_unixtime": spec.TypeNullString, 89 | "month_name": spec.TypeNullString, 90 | "day_name": 
spec.TypeNullString, 91 | "date_format": spec.TypeNullString, 92 | "time_format": spec.TypeNullString, 93 | 94 | // mysql.TypeDate 95 | "curdate": mysql.TypeDate, 96 | "current_date": mysql.TypeDate, 97 | "adddate": mysql.TypeDate, 98 | "subdate": mysql.TypeDate, 99 | 100 | // mysql.TypeDuration 101 | "curtime": mysql.TypeDuration, 102 | "current_tim": mysql.TypeDuration, 103 | "now": mysql.TypeDuration, 104 | "localtime": mysql.TypeDuration, 105 | "sec_to_time": mysql.TypeDuration, 106 | "addtime": mysql.TypeDuration, 107 | 108 | // mysql.TypeTimestamp 109 | "current_timestamp": mysql.TypeTimestamp, 110 | "localtimestamp": mysql.TypeTimestamp, 111 | } 112 | -------------------------------------------------------------------------------- /internal/parser/infoschema.go: -------------------------------------------------------------------------------- 1 | package parser 2 | 3 | import ( 4 | "bytes" 5 | _ "embed" 6 | "errors" 7 | "fmt" 8 | "strings" 9 | "text/template" 10 | 11 | sql "github.com/go-sql-driver/mysql" 12 | "github.com/pingcap/parser/mysql" 13 | "github.com/zeromicro/go-zero/core/stores/sqlx" 14 | 15 | "github.com/anqiansong/sqlgen/internal/infoschema" 16 | "github.com/anqiansong/sqlgen/internal/patterns" 17 | "github.com/anqiansong/sqlgen/internal/spec" 18 | "github.com/anqiansong/sqlgen/internal/stringx" 19 | "github.com/anqiansong/sqlgen/internal/templatex" 20 | ) 21 | 22 | var errMissingSchema = errors.New("missing schema") 23 | 24 | func From(dsn string, pattern ...string) (*spec.DXL, error) { 25 | schema, url, err := parseDSN(dsn) 26 | if err != nil { 27 | return nil, err 28 | } 29 | 30 | var conn = sqlx.NewMysql(url) 31 | var model = infoschema.NewInformationSchemaModel(conn) 32 | tables, err := model.GetAllTables(schema) 33 | if err != nil { 34 | return nil, err 35 | } 36 | 37 | var p = patterns.New(pattern...) 38 | var matchTables = p.Match(tables...) 
39 | var dxl spec.DXL 40 | for _, table := range matchTables { 41 | modelTable, err := model.FindColumns(schema, table) 42 | if err != nil { 43 | return nil, err 44 | } 45 | 46 | ddl, err := convertDDL(modelTable) 47 | if err != nil { 48 | return nil, err 49 | } 50 | 51 | dml, err := convertDML(ddl.Table) 52 | if err != nil { 53 | return nil, err 54 | } 55 | 56 | dxl.DDL = append(dxl.DDL, ddl) 57 | dxl.DML = append(dxl.DML, dml...) 58 | } 59 | 60 | return &dxl, nil 61 | } 62 | 63 | func parseDSN(dsn string) (db, url string, err error) { 64 | cfg, err := sql.ParseDSN(dsn) 65 | if err != nil { 66 | return "", "", err 67 | } 68 | 69 | if cfg.DBName == "" { 70 | return "", "", errMissingSchema 71 | } 72 | 73 | url = fmt.Sprintf("%s:%s@tcp(%s)/%s", cfg.User, cfg.Passwd, cfg.Addr, "information_schema") 74 | db = cfg.DBName 75 | return 76 | } 77 | 78 | //go:embed init.tpl.sql 79 | var initSql string 80 | 81 | func convertDML(in *spec.Table) ([]spec.DML, error) { 82 | t, err := template.New("sql").Parse(initSql) 83 | if err != nil { 84 | return nil, err 85 | } 86 | 87 | var sqlBuffer bytes.Buffer 88 | if err = t.Execute(&sqlBuffer, map[string]interface{}{ 89 | "insert_columns": strings.Join(in.ColumnList(), ", "), 90 | "insert_table": in.Name, 91 | "insert_values": stringx.RepeatJoin("?", ", ", len(in.ColumnList())), 92 | "unique_indexes": getUniques(in), 93 | }); err != nil { 94 | return nil, err 95 | } 96 | 97 | dxl, err := Parse(sqlBuffer.String()) 98 | if err != nil { 99 | return nil, err 100 | } 101 | 102 | return dxl.DML, nil 103 | } 104 | 105 | // Unique is a unique index info. 106 | type Unique struct { 107 | SelectColumns string 108 | Table string 109 | UpdateSet string 110 | WhereClause string 111 | UniqueNameJoin string 112 | } 113 | 114 | func getUniques(in *spec.Table) []Unique { 115 | var list []Unique 116 | var columns = strings.Join(in.ColumnList(), ", ") 117 | var updateSet = strings.Join(in.ColumnList(), " = ?,") + " = ?" 
118 | var m = map[Unique]struct{}{} 119 | for _, c := range in.Constraint.PrimaryKey { 120 | var item = Unique{ 121 | SelectColumns: columns, 122 | Table: in.Name, 123 | UpdateSet: updateSet, 124 | WhereClause: strings.Join(c, " = ? AND") + " = ?", 125 | UniqueNameJoin: templatex.UpperCamel(strings.Join(c, "")), 126 | } 127 | if _, ok := m[item]; ok { 128 | continue 129 | } 130 | m[item] = struct{}{} 131 | list = append(list, item) 132 | } 133 | for _, c := range in.Constraint.UniqueKey { 134 | var item = Unique{ 135 | SelectColumns: columns, 136 | Table: in.Name, 137 | UpdateSet: updateSet, 138 | WhereClause: strings.Join(c, " = ? AND") + " = ?", 139 | UniqueNameJoin: templatex.UpperCamel(strings.Join(c, "")), 140 | } 141 | if _, ok := m[item]; ok { 142 | continue 143 | } 144 | m[item] = struct{}{} 145 | list = append(list, item) 146 | } 147 | 148 | return list 149 | } 150 | 151 | func convertDDL(in *infoschema.Table) (*spec.DDL, error) { 152 | var ddl spec.DDL 153 | var constraint = spec.NewConstraint() 154 | getConstraint(in.Columns, constraint) 155 | var table spec.Table 156 | table.Name = in.Table 157 | table.Schema = in.Db 158 | if !constraint.IsEmpty() { 159 | table.Constraint = *constraint 160 | } 161 | 162 | for _, c := range in.Columns { 163 | var extra = c.Extra 164 | var autoIncrement = strings.Contains(extra, "auto_increment") 165 | var unsigned = strings.Contains(c.DataType, "unsigned") 166 | tp, err := dbTypeMapper(c.DataType) 167 | if err != nil { 168 | return nil, err 169 | } 170 | 171 | table.Columns = append(table.Columns, spec.Column{ 172 | ColumnOption: spec.ColumnOption{ 173 | AutoIncrement: autoIncrement, 174 | Comment: stringx.TrimNewLine(c.Comment), 175 | HasDefaultValue: c.ColumnDefault != nil, 176 | NotNull: !strings.EqualFold(c.IsNullAble, "yes"), 177 | Unsigned: unsigned, 178 | }, 179 | Name: c.Name, 180 | TP: tp, 181 | }) 182 | } 183 | 184 | ddl.Table = &table 185 | return &ddl, nil 186 | } 187 | 188 | func getConstraint(columns 
[]*infoschema.Column, constraint *spec.Constraint) { 189 | for _, c := range columns { 190 | index := c.Index 191 | if index == nil { 192 | continue 193 | } 194 | indexName := index.IndexName 195 | if strings.EqualFold(indexName, "primary") { 196 | constraint.AppendPrimaryKey(indexName, c.Name) 197 | } 198 | if index.NonUnique == 0 { 199 | constraint.AppendUniqueKey(indexName, c.Name) 200 | } else { 201 | constraint.AppendIndex(indexName, c.Name) 202 | } 203 | } 204 | } 205 | 206 | var str2Type = map[string]byte{ 207 | "bit": mysql.TypeBit, 208 | "text": mysql.TypeBlob, 209 | "date": mysql.TypeDate, 210 | "datetime": mysql.TypeDatetime, 211 | "unspecified": mysql.TypeUnspecified, 212 | "decimal": mysql.TypeNewDecimal, 213 | "double": mysql.TypeDouble, 214 | "enum": mysql.TypeEnum, 215 | "float": mysql.TypeFloat, 216 | "geometry": mysql.TypeGeometry, 217 | "mediumint": mysql.TypeInt24, 218 | "json": mysql.TypeJSON, 219 | "int": mysql.TypeLong, 220 | "bigint": mysql.TypeLonglong, 221 | "longtext": mysql.TypeLongBlob, 222 | "mediumtext": mysql.TypeMediumBlob, 223 | "null": mysql.TypeNull, 224 | "set": mysql.TypeSet, 225 | "smallint": mysql.TypeShort, 226 | "char": mysql.TypeString, 227 | "time": mysql.TypeDuration, 228 | "timestamp": mysql.TypeTimestamp, 229 | "tinyint": mysql.TypeTiny, 230 | "tinytext": mysql.TypeTinyBlob, 231 | "varchar": mysql.TypeVarchar, 232 | "var_string": mysql.TypeVarString, 233 | "year": mysql.TypeYear, 234 | } 235 | 236 | func dbTypeMapper(tp string) (byte, error) { 237 | var l = strings.ToLower(tp) 238 | ret, ok := str2Type[l] 239 | if !ok { 240 | return 0, fmt.Errorf("unsupported type:%s", tp) 241 | } 242 | return ret, nil 243 | } 244 | -------------------------------------------------------------------------------- /internal/parser/init.tpl.sql: -------------------------------------------------------------------------------- 1 | -- fn: Insert 2 | insert into {{.insert_table}} ({{.insert_columns}}) values ({{.insert_values}}); 3 | {{range 
.unique_indexes}} 4 | -- fn: FindOneBy{{.UniqueNameJoin}} 5 | select {{.SelectColumns}} from {{.Table}} where {{.WhereClause}} limit 1; 6 | {{end}} 7 | {{range .unique_indexes}} 8 | -- fn: UpdateBy{{.UniqueNameJoin}} 9 | update {{.Table}} set {{.UpdateSet}} where {{.WhereClause}}; 10 | {{end}} 11 | {{range .unique_indexes}} 12 | -- fn: DeleteBy{{.UniqueNameJoin}} 13 | delete from {{.Table}} where {{.WhereClause}}; 14 | {{end}} -------------------------------------------------------------------------------- /internal/parser/insert.go: -------------------------------------------------------------------------------- 1 | package parser 2 | 3 | import ( 4 | "github.com/pingcap/parser/ast" 5 | 6 | "github.com/anqiansong/sqlgen/internal/spec" 7 | ) 8 | 9 | func parseInsert(stmt *ast.InsertStmt) (*spec.InsertStmt, error) { 10 | var text = stmt.Text() 11 | comment, err := parseLineComment(text) 12 | if err != nil { 13 | return nil, errorNearBy(err, text) 14 | } 15 | 16 | sql, err := NewSqlScanner(text).ScanAndTrim() 17 | if err != nil { 18 | return nil, errorNearBy(err, text) 19 | } 20 | 21 | var ret spec.InsertStmt 22 | ret.Comment = comment 23 | tableName, err := parseTableRefsClause(stmt.Table) 24 | if err != nil { 25 | return nil, errorNearBy(err, text) 26 | } 27 | 28 | columns, err := parseColumns(stmt.Columns, tableName) 29 | if err != nil { 30 | return nil, errorNearBy(err, text) 31 | } 32 | 33 | ret.Table = tableName 34 | ret.Action = spec.ActionCreate 35 | ret.SQL = sql 36 | ret.Columns = columns 37 | 38 | return &ret, nil 39 | } 40 | -------------------------------------------------------------------------------- /internal/parser/insert_test.go: -------------------------------------------------------------------------------- 1 | package parser 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/pingcap/parser/ast" 7 | "github.com/stretchr/testify/assert" 8 | ) 9 | 10 | func Test_parseInsert(t *testing.T) { 11 | t.Run("missingFunction", func(t *testing.T) { 12 | 
stmt, _, err := testParser.Parse(`-- fn: 13 | delete from foo where id = ?`, "", "") 14 | assert.NoError(t, err) 15 | for _, v := range stmt { 16 | insertStmt, ok := v.(*ast.InsertStmt) 17 | if !ok { 18 | continue 19 | } 20 | _, err := parseInsert(insertStmt) 21 | assert.Contains(t, err.Error(), errorMissingFunction.Error()) 22 | } 23 | }) 24 | 25 | t.Run("parseTableRefsClause", func(t *testing.T) { 26 | stmt := &ast.InsertStmt{} 27 | _, err := parseInsert(stmt) 28 | assert.Contains(t, err.Error(), errorMissingTable.Error()) 29 | }) 30 | } 31 | -------------------------------------------------------------------------------- /internal/parser/parser.go: -------------------------------------------------------------------------------- 1 | package parser 2 | 3 | import ( 4 | "fmt" 5 | 6 | "github.com/pingcap/parser" 7 | "github.com/pingcap/parser/ast" 8 | _ "github.com/pingcap/parser/test_driver" 9 | 10 | "github.com/anqiansong/sqlgen/internal/spec" 11 | ) 12 | 13 | var p *parser.Parser 14 | 15 | type stmts []stmt 16 | 17 | type stmt interface { 18 | nodes() []ast.StmtNode 19 | } 20 | 21 | type createTableStmt struct { 22 | stmt *ast.CreateTableStmt 23 | } 24 | 25 | func (c createTableStmt) nodes() []ast.StmtNode { 26 | return []ast.StmtNode{c.stmt} 27 | } 28 | 29 | type queryStmt struct { 30 | stmt ast.StmtNode 31 | } 32 | 33 | func (q queryStmt) nodes() []ast.StmtNode { 34 | return []ast.StmtNode{q.stmt} 35 | } 36 | 37 | type transactionStmt struct { 38 | startTransactionStmt ast.StmtNode 39 | queryList stmts 40 | commitStmt ast.StmtNode 41 | } 42 | 43 | func (t transactionStmt) nodes() []ast.StmtNode { 44 | var list []ast.StmtNode 45 | for _, v := range t.queryList { 46 | stmt, ok := v.(*queryStmt) 47 | if ok { 48 | list = append(list, stmt.stmt) 49 | } 50 | } 51 | return list 52 | } 53 | 54 | func init() { 55 | p = parser.New() 56 | } 57 | 58 | // Parse parses a SQL statement string and returns a spec.DXL. 
59 | func Parse(sql string) (*spec.DXL, error) { 60 | stmtNodes, _, err := p.Parse(sql, "", "") 61 | if err != nil { 62 | return nil, err 63 | } 64 | 65 | stmt, err := splits(stmtNodes) 66 | if err != nil { 67 | return nil, err 68 | } 69 | 70 | var ret spec.DXL 71 | for _, stmtNode := range stmt { 72 | switch node := stmtNode.(type) { 73 | case *createTableStmt: 74 | ddl, err := parseDDL(node.stmt) 75 | if err != nil { 76 | return nil, err 77 | } 78 | ret.DDL = append(ret.DDL, ddl) 79 | case *queryStmt: 80 | dml, err := parseDML(node.stmt) 81 | if err != nil { 82 | return nil, err 83 | } 84 | ret.DML = append(ret.DML, dml) 85 | case *transactionStmt: 86 | if node.queryList.hasTransactionStmt() { 87 | return nil, errorUnsupportedNestedTransaction 88 | } 89 | if len(node.queryList) == 0 { 90 | continue 91 | } 92 | dml, err := parseTransaction(node) 93 | if err != nil { 94 | return nil, err 95 | } 96 | ret.DML = append(ret.DML, dml) 97 | default: 98 | // ignores other statements 99 | } 100 | } 101 | 102 | if err = ret.Validate(); err != nil { 103 | return nil, err 104 | } 105 | 106 | return &ret, nil 107 | } 108 | 109 | func splits(stmtNodes []ast.StmtNode) ([]stmt, error) { 110 | var list stmts 111 | var transactionMode bool 112 | for _, v := range stmtNodes { 113 | switch node := v.(type) { 114 | case *ast.CreateTableStmt: 115 | if transactionMode { 116 | return nil, fmt.Errorf("missing begin stmt near by '%s'", v.Text()) 117 | } 118 | list = append(list, &createTableStmt{stmt: node}) 119 | case *ast.InsertStmt, *ast.SelectStmt, *ast.DeleteStmt, *ast.UpdateStmt: 120 | if transactionMode { 121 | transactionNode := list[len(list)-1].(*transactionStmt) 122 | transactionNode.queryList = append(transactionNode.queryList, &queryStmt{stmt: node}) 123 | } else { 124 | list = append(list, &queryStmt{stmt: node}) 125 | } 126 | case *ast.BeginStmt: 127 | if transactionMode { 128 | transactionNode := list[len(list)-1].(*transactionStmt) 129 | transactionNode.queryList = 
append(transactionNode.queryList, &transactionStmt{startTransactionStmt: node}) 130 | } else { 131 | transactionMode = true 132 | list = append(list, &transactionStmt{startTransactionStmt: v}) 133 | } 134 | case *ast.CommitStmt: 135 | if transactionMode { 136 | transactionNode := list[len(list)-1].(*transactionStmt) 137 | transactionNode.commitStmt = v 138 | transactionMode = false 139 | } else { 140 | return nil, fmt.Errorf("missing begin stmt near by '%s'", v.Text()) 141 | } 142 | default: 143 | return nil, errorUnsupportedStmt 144 | } 145 | } 146 | if transactionMode { 147 | return nil, errorMissingCommit 148 | } 149 | return list, nil 150 | } 151 | 152 | func (s stmts) hasTransactionStmt() bool { 153 | for _, v := range s { 154 | if _, ok := v.(*transactionStmt); ok { 155 | return true 156 | } 157 | } 158 | return false 159 | } 160 | -------------------------------------------------------------------------------- /internal/parser/parser_test.go: -------------------------------------------------------------------------------- 1 | package parser 2 | 3 | import ( 4 | _ "embed" 5 | "testing" 6 | 7 | "github.com/anqiansong/sqlgen/internal/spec" 8 | "github.com/stretchr/testify/assert" 9 | ) 10 | 11 | //go:embed test.sql 12 | var testSql string 13 | 14 | func TestParse(t *testing.T) { 15 | t.Run("ParseError", func(t *testing.T) { 16 | _, err := Parse("delete from where id = ?") 17 | assert.NotNil(t, err) 18 | }) 19 | 20 | t.Run("ParseError", func(t *testing.T) { 21 | _, err := Parse("alter table foo add column name varchar(255);") 22 | assert.ErrorIs(t, err, errorUnsupportedStmt) 23 | }) 24 | 25 | t.Run("success", func(t *testing.T) { 26 | dxl, err := Parse(testSql) 27 | assert.Nil(t, err) 28 | 29 | _, err = spec.From(dxl) 30 | assert.Nil(t, err) 31 | }) 32 | } 33 | -------------------------------------------------------------------------------- /internal/parser/select_test.go: -------------------------------------------------------------------------------- 1 | 
package parser

import (
	"testing"

	"github.com/pingcap/parser/ast"
	"github.com/pingcap/parser/model"
	"github.com/pingcap/parser/opcode"
	"github.com/stretchr/testify/assert"
)

// fooTableRefsClause returns a table-refs clause naming the single table
// `foo`; it is shared by the SELECT and UPDATE parser tests.
func fooTableRefsClause() *ast.TableRefsClause {
	return &ast.TableRefsClause{
		TableRefs: &ast.Join{
			Left: &ast.TableSource{
				Source: &ast.TableName{
					Name: model.CIStr{O: "foo", L: "foo"},
				},
			},
		},
	}
}

func Test_parseSelect(t *testing.T) {
	t.Run("missingFunction", func(t *testing.T) {
		// The fixture must be a SELECT: with any other statement kind the type
		// assertion below never matches and the subtest asserts nothing.
		stmt, _, err := testParser.Parse(`-- fn:
select * from foo where id = ?`, "", "")
		assert.NoError(t, err)
		var checked int
		for _, v := range stmt {
			selectStmt, ok := v.(*ast.SelectStmt)
			if !ok {
				continue
			}
			checked++
			_, err := parseSelect(selectStmt)
			if assert.Error(t, err) {
				assert.Contains(t, err.Error(), errorMissingFunction.Error())
			}
		}
		assert.Equal(t, 1, checked)
	})

	t.Run("parseTableRefsClause", func(t *testing.T) {
		stmt := &ast.SelectStmt{}
		_, err := parseSelect(stmt)
		assert.Contains(t, err.Error(), errorMissingTable.Error())
	})

	t.Run("whereExpr", func(t *testing.T) {
		stmt := &ast.SelectStmt{
			Where: &ast.BinaryOperationExpr{Op: opcode.Plus},
			From:  fooTableRefsClause(),
		}
		_, err := parseSelect(stmt)
		assert.Contains(t, err.Error(), "unsupported opcode")
	})

	t.Run("groupBy", func(t *testing.T) {
		stmt := &ast.SelectStmt{
			GroupBy: &ast.GroupByClause{Items: []*ast.ByItem{{}}},
			From:    fooTableRefsClause(),
		}
		_, err := parseSelect(stmt)
		assert.Contains(t, err.Error(), errorInvalidExprNode.Error())
	})

	t.Run("having", func(t *testing.T) {
		stmt := &ast.SelectStmt{
			Having: &ast.HavingClause{Expr: &ast.BinaryOperationExpr{Op: opcode.Plus}},
			From:   fooTableRefsClause(),
		}
		_, err := parseSelect(stmt)
		assert.Contains(t, err.Error(), "unsupported opcode")
	})

	t.Run("orderExpr", func(t *testing.T) {
		stmt := &ast.SelectStmt{
			OrderBy: &ast.OrderByClause{Items: []*ast.ByItem{{}}},
			From:    fooTableRefsClause(),
		}
		_, err := parseSelect(stmt)
		assert.Contains(t, err.Error(), errorInvalidExprNode.Error())
	})

	t.Run("limitExpr", func(t *testing.T) {
		stmt := &ast.SelectStmt{
			Limit: &ast.Limit{Count: &ast.BetweenExpr{}},
			From:  fooTableRefsClause(),
		}
		_, err := parseSelect(stmt)
		assert.Contains(t, err.Error(), errorUnsupportedLimitExpr.Error())
	})

	t.Run("fields", func(t *testing.T) {
		stmt := &ast.SelectStmt{
			Fields: &ast.FieldList{
				Fields: []*ast.SelectField{
					{
						Offset: 0,
						WildCard: &ast.WildCardField{
							Table: model.CIStr{O: "bar", L: "bar"},
						},
					},
				},
			},
			From: fooTableRefsClause(),
		}
		_, err := parseSelect(stmt)
		assert.Contains(t, err.Error(), "wildcard table")
	})
}

func Test_convertOP(t *testing.T) {
	_, err := convertOP(opcode.In)
	assert.NoError(t, err)

	_, err = convertOP(opcode.Regexp)
	assert.NotNil(t, err)
}
--------------------------------------------------------------------------------
/internal/parser/table.go:
--------------------------------------------------------------------------------
package parser

import (
	"github.com/pingcap/parser/ast"

	"github.com/anqiansong/sqlgen/internal/spec"
)

// parseCreateTableStmt converts a CREATE TABLE node into a spec.Table. The
// table name is left empty when stmt.Table is nil. Constraints declared on
// individual column definitions and table-level constraints are merged into
// a single spec.Constraint.
func parseCreateTableStmt(stmt *ast.CreateTableStmt) *spec.Table {
	var table spec.Table
	if stmt.Table != nil {
		table.Name = stmt.Table.Name.String()
	}

	var constraint = spec.NewConstraint()
	for _, col := range stmt.Cols {
		column, con := parseColumnDef(col)
		if column != nil {
			table.Columns = append(table.Columns, *column)
		}
		constraint.Merge(con)
	}

	for _, c := range stmt.Constraints {
		constraint.Merge(parseConstraint(c))
	}

	table.Constraint = *constraint
	return &table
}
--------------------------------------------------------------------------------
/internal/parser/test.sql:
--------------------------------------------------------------------------------
-- 用户表 --
CREATE TABLE `user`
(
    `id` bigint(10) unsigned NOT NULL AUTO_INCREMENT primary key,
    `name` varchar(255) COLLATE utf8mb4_general_ci NULL COMMENT '用户\t名称',
    `password` varchar(255) COLLATE utf8mb4_general_ci NOT NULL DEFAULT '' COMMENT '用户\n密码',
    `mobile` varchar(255) COLLATE utf8mb4_general_ci NOT NULL DEFAULT '' COMMENT '手机号',
    `gender` char(5) COLLATE utf8mb4_general_ci NOT NULL COMMENT '男|女|未公\r开',
    `nickname` varchar(255) COLLATE utf8mb4_general_ci DEFAULT '' COMMENT '用户昵称',
    `type` tinyint(1) COLLATE utf8mb4_general_ci DEFAULT 0 COMMENT '用户类型',
    `create_time` timestamp NULL,
    `update_time` timestamp NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
    UNIQUE KEY `name_index` (`name`),
    UNIQUE KEY `type_index` (`type`),
    UNIQUE KEY `mobile_index` (`mobile`)
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COMMENT '用户表' COLLATE=utf8mb4_general_ci;

-- fn: findLimit
select * from user where id > ? group by name having id > ? order by id desc limit ?,?;

-- fn: case1
select count(mobile) as mobileCount, count(1) as count, id,name from user;

-- fn: case2
select * from user where name in (?) and id between ? and ? or (mobile = ?) and nickname like ?;

-- fn: count
select count(id) as count from user;

-- fn: test
start transaction;
-- fn: foo1
select * from user where id = 1 ;
-- fn: foo2
select * from user where id = 2;
commit;

-- fn: test2
start transaction;
-- fn: foo3
update user set name = ? where id = ?;
-- fn: foo4
update user set nickname = ? where id = ?;
commit ;

-- fn: deleteUser
delete from user;


--------------------------------------------------------------------------------
/internal/parser/update.go:
--------------------------------------------------------------------------------
package parser

import (
	"github.com/pingcap/parser/ast"

	"github.com/anqiansong/sqlgen/internal/spec"
)

// parseUpdate converts a single-table UPDATE statement into a spec.UpdateStmt.
// The line comment parsed from the statement text (parseLineComment) is stored
// in Comment, and the trimmed SQL text in SQL. WHERE, ORDER BY and LIMIT are
// parsed only when present. Each SET assignment contributes its non-empty
// column name to Columns. Multi-table updates are rejected with
// errorMultipleTable. Every error is annotated with the statement text via
// errorNearBy.
func parseUpdate(stmt *ast.UpdateStmt) (spec.DML, error) {
	var ret spec.UpdateStmt
	var text = stmt.Text()
	comment, err := parseLineComment(text)
	if err != nil {
		return nil, errorNearBy(err, text)
	}

	sql, err := NewSqlScanner(text).ScanAndTrim()
	if err != nil {
		return nil, errorNearBy(err, text)
	}

	if stmt.MultipleTable {
		return nil, errorNearBy(errorMultipleTable, text)
	}

	tableName, err := parseTableRefsClause(stmt.TableRefs)
	if err != nil {
		return nil, errorNearBy(err, text)
	}

	if stmt.Where != nil {
		where, err := parseExprNode(stmt.Where, tableName, exprTypeWhereClause)
		if err != nil {
			return nil, errorNearBy(err, text)
		}

		ret.Where = where
	}

	if stmt.Order != nil {
		orderBy, err := parseOrderBy(stmt.Order, tableName)
		if err != nil {
			return nil, errorNearBy(err, text)
		}

		ret.OrderBy = orderBy
	}

	if stmt.Limit != nil {
		limit, err := parseLimit(stmt.Limit)
		if err != nil {
			return nil, errorNearBy(err, text)
		}

		ret.Limit = limit
	}

	for _, a := range stmt.List {
		colName, err := parseColumn(a.Column, tableName)
		if err != nil {
			return nil, errorNearBy(err, text)
		}

		if len(colName) > 0 {
			ret.Columns = append(ret.Columns, colName)
		}
	}

	ret.Comment = comment
	ret.SQL = sql
	ret.Action = spec.ActionUpdate
	ret.Table = tableName

	return &ret, nil
}
--------------------------------------------------------------------------------
/internal/parser/update_test.go:
--------------------------------------------------------------------------------
package parser

import (
	"testing"

	"github.com/pingcap/parser/ast"
	"github.com/pingcap/parser/opcode"
	"github.com/stretchr/testify/assert"
)

func Test_parseUpdate(t *testing.T) {
	t.Run("missingFunction", func(t *testing.T) {
		// The fixture must be an UPDATE: with any other statement kind the type
		// assertion below never matches and the subtest asserts nothing.
		stmt, _, err := testParser.Parse(`-- fn:
update foo set name = ? where id = ?`, "", "")
		assert.NoError(t, err)
		var checked int
		for _, v := range stmt {
			updateStmt, ok := v.(*ast.UpdateStmt)
			if !ok {
				continue
			}
			checked++
			_, err := parseUpdate(updateStmt)
			if assert.Error(t, err) {
				assert.Contains(t, err.Error(), errorMissingFunction.Error())
			}
		}
		assert.Equal(t, 1, checked)
	})

	t.Run("parseTableRefsClause", func(t *testing.T) {
		stmt := &ast.UpdateStmt{}
		_, err := parseUpdate(stmt)
		assert.Contains(t, err.Error(), errorMissingTable.Error())
	})

	t.Run("whereExpr", func(t *testing.T) {
		stmt := &ast.UpdateStmt{
			Where:     &ast.BinaryOperationExpr{Op: opcode.Plus},
			TableRefs: fooTableRefsClause(),
		}
		_, err := parseUpdate(stmt)
		assert.Contains(t, err.Error(), "unsupported opcode")
	})

	t.Run("orderExpr", func(t *testing.T) {
		stmt := &ast.UpdateStmt{
			Order:     &ast.OrderByClause{Items: []*ast.ByItem{{}}},
			TableRefs: fooTableRefsClause(),
		}
		_, err := parseUpdate(stmt)
		assert.Contains(t, err.Error(), errorInvalidExprNode.Error())
	})

	t.Run("limitExpr", func(t *testing.T) {
		stmt := &ast.UpdateStmt{
			Limit:     &ast.Limit{Count: &ast.BetweenExpr{}},
			TableRefs: fooTableRefsClause(),
		}
		_, err := parseUpdate(stmt)
		assert.Contains(t, err.Error(), errorUnsupportedLimitExpr.Error())
	})
}
--------------------------------------------------------------------------------
/internal/patterns/patterns.go:
--------------------------------------------------------------------------------
package patterns

import (
	"path/filepath"
	"strings"

	"github.com/anqiansong/sqlgen/internal/set"
)

// Pattern is a list of shell-style name patterns (filepath.Match syntax).
11 | type Pattern []string 12 | 13 | func (p Pattern) Match(list ...string) []string { 14 | var matchTableSet = set.From() 15 | for _, s := range list { 16 | for _, v := range p { 17 | match, _ := filepath.Match(v, filepath.Base(s)) 18 | if match { 19 | matchTableSet.Add(s) 20 | } 21 | } 22 | } 23 | return matchTableSet.String() 24 | 25 | } 26 | func New(patterns ...string) Pattern { 27 | var patternSet = set.From() 28 | if len(patterns) == 0 { 29 | patternSet.Add("*") 30 | return patternSet.String() 31 | } 32 | 33 | for _, v := range patterns { 34 | fields := strings.FieldsFunc(v, func(r rune) bool { 35 | return r == ',' 36 | }) 37 | for _, f := range fields { 38 | patternSet.Add(f) 39 | } 40 | } 41 | 42 | return patternSet.String() 43 | } 44 | -------------------------------------------------------------------------------- /internal/patterns/patterns_test.go: -------------------------------------------------------------------------------- 1 | package patterns 2 | 3 | import ( 4 | "strings" 5 | "testing" 6 | 7 | "github.com/stretchr/testify/assert" 8 | ) 9 | 10 | func TestPattern_Match(t *testing.T) { 11 | list := Pattern{} 12 | matched := list.Match() 13 | assert.Equal(t, []string(nil), matched) 14 | matched = list.Match("foo") 15 | assert.Equal(t, []string(nil), matched) 16 | 17 | list = Pattern{"*"} 18 | matched = list.Match() 19 | assert.Equal(t, []string(nil), matched) 20 | matched = list.Match("foo") 21 | assert.Equal(t, "foo", matched[0]) 22 | } 23 | 24 | func TestNew(t *testing.T) { 25 | p := New() 26 | assert.Equal(t, "*", p[0]) 27 | 28 | p = New("foo") 29 | assert.Equal(t, "foo", p[0]) 30 | 31 | p = New("foo,bar,baz,baz") 32 | assert.Equal(t, "foo,bar,baz", strings.Join(p, ",")) 33 | } 34 | -------------------------------------------------------------------------------- /internal/set/set.go: -------------------------------------------------------------------------------- 1 | package set 2 | 3 | import "container/list" 4 | 5 | type ListSet struct { 6 | 
list *list.List 7 | m map[interface{}]*list.Element 8 | } 9 | 10 | func From(data ...any) *ListSet { 11 | set := &ListSet{ 12 | list: list.New(), 13 | m: make(map[interface{}]*list.Element), 14 | } 15 | set.Add(data...) 16 | return set 17 | } 18 | 19 | func FromString(data ...string) *ListSet { 20 | return From(string2Any(data...)...) 21 | } 22 | 23 | func (s *ListSet) Add(data ...any) { 24 | for _, v := range data { 25 | if _, ok := s.m[v]; ok { 26 | continue 27 | } 28 | 29 | var el = s.list.PushBack(v) 30 | s.m[v] = el 31 | } 32 | } 33 | 34 | func string2Any(data ...string) []any { 35 | var ret []any 36 | for _, v := range data { 37 | ret = append(ret, v) 38 | } 39 | 40 | return ret 41 | } 42 | func (s *ListSet) AddStringList(data []string) { 43 | s.Add(string2Any(data...)...) 44 | } 45 | 46 | func (s *ListSet) Remove(data any) { 47 | el, ok := s.m[data] 48 | if !ok { 49 | return 50 | } 51 | 52 | delete(s.m, data) 53 | s.list.Remove(el) 54 | } 55 | 56 | func (s *ListSet) Exists(data any) bool { 57 | _, ok := s.m[data] 58 | return ok 59 | } 60 | 61 | func (s *ListSet) String() []string { 62 | var ret []string 63 | s.Range(func(v interface{}) { 64 | s, ok := v.(string) 65 | if ok { 66 | ret = append(ret, s) 67 | } 68 | }) 69 | 70 | return ret 71 | } 72 | 73 | func (s *ListSet) Int() []int { 74 | var ret []int 75 | s.Range(func(v interface{}) { 76 | i, ok := v.(int) 77 | if ok { 78 | ret = append(ret, i) 79 | } 80 | }) 81 | 82 | return ret 83 | } 84 | 85 | func (s *ListSet) Int32() []int32 { 86 | var ret []int32 87 | s.Range(func(v interface{}) { 88 | i, ok := v.(int32) 89 | if ok { 90 | ret = append(ret, i) 91 | } 92 | }) 93 | 94 | return ret 95 | } 96 | 97 | func (s *ListSet) Int64() []int64 { 98 | var ret []int64 99 | s.Range(func(v interface{}) { 100 | i, ok := v.(int64) 101 | if ok { 102 | ret = append(ret, i) 103 | } 104 | }) 105 | 106 | return ret 107 | } 108 | 109 | func (s *ListSet) Init() { 110 | s.list = list.New() 111 | s.m = 
make(map[interface{}]*list.Element) 112 | } 113 | 114 | func (s *ListSet) Range(fn func(v interface{})) { 115 | var next = s.list.Front() 116 | if next == nil { 117 | return 118 | } 119 | for { 120 | if next == nil { 121 | return 122 | } 123 | fn(next.Value) 124 | next = next.Next() 125 | } 126 | } 127 | -------------------------------------------------------------------------------- /internal/set/set_test.go: -------------------------------------------------------------------------------- 1 | package set 2 | 3 | import ( 4 | "fmt" 5 | "testing" 6 | 7 | "github.com/stretchr/testify/assert" 8 | ) 9 | 10 | func TestListSet(t *testing.T) { 11 | listSet := From("a", "b", "c") 12 | var result = listSet.String() 13 | assert.Equal(t, []string{"a", "b", "c"}, result) 14 | listSet.Add("b", "c", "d") 15 | result = listSet.String() 16 | assert.Equal(t, []string{"a", "b", "c", "d"}, result) 17 | listSet.Remove("a") 18 | result = listSet.String() 19 | assert.Equal(t, []string{"b", "c", "d"}, result) 20 | exists := listSet.Exists("b") 21 | assert.True(t, exists) 22 | listSet.Add(1, 2, 3) 23 | intResult := listSet.Int() 24 | assert.Equal(t, []int{1, 2, 3}, intResult) 25 | listSet.Add(int32(4), int32(5), int32(6)) 26 | int32Result := listSet.Int32() 27 | assert.Equal(t, []int32{4, 5, 6}, int32Result) 28 | listSet.Add(int64(7), int64(8), int64(9)) 29 | int64Result := listSet.Int64() 30 | assert.Equal(t, []int64{7, 8, 9}, int64Result) 31 | } 32 | 33 | func TestFromString(t *testing.T) { 34 | s := FromString() 35 | assert.NotNil(t, s) 36 | s = FromString("foo", "bar", "baz") 37 | var result = s.String() 38 | assert.Equal(t, []string{"foo", "bar", "baz"}, result) 39 | } 40 | 41 | func TestListSet_AddStringList(t *testing.T) { 42 | s := FromString() 43 | s.AddStringList([]string{"foo", "bar", "baz"}) 44 | var result = s.String() 45 | assert.Equal(t, []string{"foo", "bar", "baz"}, result) 46 | } 47 | 48 | func TestListSet_Remove(t *testing.T) { 49 | s := From("foo", "bar") 50 | 
s.Remove("baz") 51 | s.Remove("bar") 52 | var result = s.String() 53 | assert.Equal(t, []string{"foo"}, result) 54 | } 55 | 56 | func TestListSet_Init(t *testing.T) { 57 | s := From("foo", "bar", "baz") 58 | s.Init() 59 | var result = s.String() 60 | assert.Equal(t, []string(nil), result) 61 | } 62 | 63 | func TestListSet_Range(t *testing.T) { 64 | var str []string 65 | s := From() 66 | s.Range(func(v interface{}) { 67 | str = append(str, fmt.Sprint(v)) 68 | }) 69 | s = From("foo") 70 | s.Range(func(v interface{}) { 71 | str = append(str, fmt.Sprint(v)) 72 | }) 73 | 74 | assert.Equal(t, []string{"foo"}, str) 75 | } 76 | -------------------------------------------------------------------------------- /internal/spec/action.go: -------------------------------------------------------------------------------- 1 | package spec 2 | 3 | const ( 4 | _ = iota 5 | ActionCreate // ActionCreate represents a create action. 6 | ActionRead // ActionRead represents a read action. 7 | ActionUpdate // ActionUpdate represents an update action. 8 | ActionDelete // ActionDelete represents a delete action. 9 | ActionTransaction // ActionTransaction represents a transaction action. 10 | ) 11 | 12 | // Action represents an action. 13 | type Action int 14 | -------------------------------------------------------------------------------- /internal/spec/byitem.go: -------------------------------------------------------------------------------- 1 | package spec 2 | 3 | import ( 4 | "fmt" 5 | "strings" 6 | 7 | "github.com/iancoleman/strcase" 8 | 9 | "github.com/anqiansong/sqlgen/internal/buffer" 10 | "github.com/anqiansong/sqlgen/internal/parameter" 11 | "github.com/anqiansong/sqlgen/internal/set" 12 | ) 13 | 14 | // ByItems returns the by items. 15 | type ByItems []*ByItem 16 | 17 | // ByItem represents an order-by or group-by item. 18 | type ByItem struct { 19 | // Column represents the column name. 20 | Column string 21 | // Desc returns true if order by Column desc. 
22 | Desc bool 23 | 24 | // the below data are from table 25 | // ColumnInfo are the column info which are convert from Column. 26 | ColumnInfo Column 27 | // TableInfo is the table info. 28 | TableInfo *Table 29 | 30 | // the below data are from stmt 31 | // Comment represents a sql comment. 32 | Comment Comment 33 | } 34 | 35 | func (b *ByItem) IsValid() bool { 36 | if b == nil { 37 | return false 38 | } 39 | 40 | return len(b.Column) > 0 41 | } 42 | 43 | func (b ByItems) IsValid() bool { 44 | if len(b) == 0 { 45 | return false 46 | } 47 | 48 | for _, v := range b { 49 | if v.IsValid() { 50 | return true 51 | } 52 | } 53 | 54 | return false 55 | } 56 | 57 | // SQL returns the clause condition strings. 58 | func (b ByItems) SQL() (string, error) { 59 | sql, _, err := b.marshal() 60 | return fmt.Sprintf("`%s`", sql), err 61 | } 62 | 63 | // ParameterStructure returns the parameter type structure. 64 | func (b ByItems) ParameterStructure(identifier string) (string, error) { 65 | _, parameters, err := b.marshal() 66 | if err != nil { 67 | return "", err 68 | } 69 | 70 | var writer = buffer.New() 71 | writer.Write(`// %s is a %s parameter structure.`, b.ParameterStructureName(identifier), strcase.ToDelimited(identifier, ' ')) 72 | writer.Write(`type %s struct {`, b.ParameterStructureName(identifier)) 73 | for _, v := range parameters { 74 | writer.Write("%s %s", v.Column, v.Type) 75 | } 76 | 77 | writer.Write(`}`) 78 | 79 | return writer.String(), nil 80 | } 81 | 82 | // ParameterThirdImports returns the third package imports. 83 | func (b ByItems) ParameterThirdImports() (string, error) { 84 | _, parameters, err := b.marshal() 85 | if err != nil { 86 | return "", err 87 | } 88 | var thirdPkgSet = set.From() 89 | for _, v := range parameters { 90 | if len(v.ThirdPkg) == 0 { 91 | continue 92 | } 93 | thirdPkgSet.Add(v.ThirdPkg) 94 | } 95 | 96 | return strings.Join(thirdPkgSet.String(), "\n"), nil 97 | } 98 | 99 | // Parameters returns the parameter variables. 
100 | func (b ByItems) Parameters(pkg string) (string, error) { 101 | _, parameters, err := b.marshal() 102 | if err != nil { 103 | return "", err 104 | } 105 | var list []string 106 | for _, v := range parameters { 107 | list = append(list, fmt.Sprintf("%s.%s", pkg, v.Column)) 108 | } 109 | 110 | return strings.Join(list, ", "), nil 111 | } 112 | 113 | // ParameterStructureName returns the parameter structure name. 114 | func (b ByItems) ParameterStructureName(identifier string) string { 115 | if !b.IsValid() { 116 | return "" 117 | } 118 | 119 | one := b[0] 120 | return strcase.ToCamel(fmt.Sprintf("%s%sParameter", one.Comment.FuncName, identifier)) 121 | } 122 | 123 | func (b ByItems) marshal() (sql string, parameters parameter.Parameters, err error) { 124 | parameters = parameter.Empty 125 | if len(b) == 0 { 126 | return 127 | } 128 | 129 | var sqlJoin []string 130 | var ps = parameter.New() 131 | for _, v := range b { 132 | if v.Desc { 133 | sqlJoin = append(sqlJoin, fmt.Sprintf("%s desc", v.Column)) 134 | } else { 135 | sqlJoin = append(sqlJoin, v.Column) 136 | } 137 | 138 | p, err := v.ColumnInfo.DataType() 139 | if err != nil { 140 | return "", nil, err 141 | } 142 | 143 | ps.Add(p) 144 | } 145 | 146 | sql = strings.Join(sqlJoin, ", ") 147 | parameters = ps.List() 148 | return 149 | } 150 | -------------------------------------------------------------------------------- /internal/spec/clause.go: -------------------------------------------------------------------------------- 1 | package spec 2 | 3 | import ( 4 | "fmt" 5 | "strings" 6 | 7 | "github.com/iancoleman/strcase" 8 | 9 | "github.com/anqiansong/sqlgen/internal/buffer" 10 | "github.com/anqiansong/sqlgen/internal/parameter" 11 | "github.com/anqiansong/sqlgen/internal/set" 12 | ) 13 | 14 | // Clause represents a where clause, having clause. 15 | type Clause struct { 16 | // Column represents the column name. 17 | Column string 18 | // Left represents the left expr. 
19 | Left *Clause 20 | // Right represents the right expr. 21 | Right *Clause 22 | // OP represents the operator. 23 | OP OP 24 | 25 | // the below data are from table 26 | // ColumnInfo are the column info which are convert from Column. 27 | ColumnInfo Column 28 | // TableInfo is the table info. 29 | TableInfo *Table 30 | 31 | // the below data are from stmt 32 | // Comment represents a sql comment. 33 | Comment Comment 34 | } 35 | 36 | // NewParameter returns a new parameter. 37 | func NewParameter(column string, tp string, thirdPkg string) parameter.Parameter { 38 | return parameter.Parameter{Column: strcase.ToCamel(column), Type: tp, ThirdPkg: thirdPkg} 39 | } 40 | 41 | // IsValid returns true if the statement is valid. 42 | func (c *Clause) IsValid() bool { 43 | if c == nil { 44 | return false 45 | } 46 | 47 | return c.Column != "" || c.OP != 0 || c.Left != nil || c.Right != nil 48 | } 49 | 50 | // SQL returns the clause condition strings. 51 | func (c *Clause) SQL() (string, error) { 52 | if !c.IsValid() { 53 | return "", nil 54 | } 55 | 56 | sql, _, err := c.marshal() 57 | return fmt.Sprintf("`%s`", sql), err 58 | } 59 | 60 | // ParameterStructure returns the parameter type structure. 61 | func (c *Clause) ParameterStructure(identifier string) (string, error) { 62 | if !c.IsValid() { 63 | return "", nil 64 | } 65 | 66 | _, parameters, err := c.marshal() 67 | if err != nil { 68 | return "", err 69 | } 70 | 71 | var writer = buffer.New() 72 | writer.Write(`// %s is a %s parameter structure.`, c.ParameterStructureName(identifier), strcase.ToDelimited(identifier, ' ')) 73 | writer.Write(`type %s struct {`, c.ParameterStructureName(identifier)) 74 | for _, v := range parameters { 75 | writer.Write("%s %s", v.Column, v.Type) 76 | } 77 | 78 | writer.Write(`}`) 79 | 80 | return writer.String(), nil 81 | } 82 | 83 | // ParameterStructureName returns the parameter structure name. 
84 | func (c *Clause) ParameterStructureName(identifier string) string { 85 | if !c.IsValid() { 86 | return "" 87 | } 88 | return strcase.ToCamel(fmt.Sprintf("%s%sParameter", c.Comment.FuncName, identifier)) 89 | } 90 | 91 | // ParameterThirdImports returns the third package imports. 92 | func (c *Clause) ParameterThirdImports() (string, error) { 93 | if !c.IsValid() { 94 | return "", nil 95 | } 96 | 97 | _, parameters, err := c.marshal() 98 | if err != nil { 99 | return "", err 100 | } 101 | var thirdPkgSet = set.From() 102 | for _, v := range parameters { 103 | if len(v.ThirdPkg) == 0 { 104 | continue 105 | } 106 | thirdPkgSet.Add(v.ThirdPkg) 107 | } 108 | 109 | return strings.Join(thirdPkgSet.String(), "\n"), nil 110 | } 111 | 112 | // Parameters returns the parameter variables. 113 | func (c *Clause) Parameters(pkg string) (string, error) { 114 | if !c.IsValid() { 115 | return "", nil 116 | } 117 | 118 | _, parameters, err := c.marshal() 119 | if err != nil { 120 | return "", err 121 | } 122 | var list []string 123 | for _, v := range parameters { 124 | list = append(list, fmt.Sprintf("%s.%s", pkg, v.Column)) 125 | } 126 | 127 | return strings.Join(list, ", "), nil 128 | } 129 | 130 | func (c *Clause) marshal() (sql string, parameters parameter.Parameters, err error) { 131 | if !c.IsValid() { 132 | return 133 | } 134 | 135 | parameters = parameter.Empty 136 | var ps = parameter.New() 137 | switch c.OP { 138 | case And, Or: 139 | leftSQL, leftParameter, err := c.Left.marshal() 140 | if err != nil { 141 | return "", nil, err 142 | } 143 | 144 | rightSQL, rightParameter, err := c.Right.marshal() 145 | if err != nil { 146 | return "", nil, err 147 | } 148 | 149 | ps.Add(leftParameter...) 150 | ps.Add(rightParameter...) 
151 | var sqlList []string 152 | if len(leftSQL) > 0 { 153 | sqlList = append(sqlList, leftSQL) 154 | } 155 | if len(rightSQL) > 0 { 156 | sqlList = append(sqlList, rightSQL) 157 | } 158 | 159 | sql = strings.Join(sqlList, " "+Operator[c.OP]+" ") 160 | case EQ, GE, GT, LE, LT, Like, NE, Not, NotLike: 161 | sql = fmt.Sprintf("%s %s ?", c.Column, Operator[c.OP]) 162 | p, err := c.ColumnInfo.DataType() 163 | if err != nil { 164 | return "", nil, err 165 | } 166 | 167 | ps.Add(parameter.Parameter{ 168 | Column: p.Column + OpName[c.OP], 169 | Type: p.Type, 170 | ThirdPkg: p.ThirdPkg, 171 | }) 172 | case In, NotIn: 173 | sql = fmt.Sprintf("%s %s (?)", c.Column, Operator[c.OP]) 174 | p, err := c.ColumnInfo.DataType() 175 | if err != nil { 176 | return "", nil, err 177 | } 178 | 179 | p.Type = fmt.Sprintf("[]%s", p.Type) 180 | ps.Add(parameter.Parameter{ 181 | Column: p.Column + OpName[c.OP], 182 | Type: p.Type, 183 | ThirdPkg: p.ThirdPkg, 184 | }) 185 | case Between, NotBetween: 186 | sql = fmt.Sprintf("%s %s ? AND ?", c.Column, Operator[c.OP]) 187 | p, err := c.ColumnInfo.DataType() 188 | if err != nil { 189 | return "", nil, err 190 | } 191 | 192 | ps.Add( 193 | NewParameter(fmt.Sprintf("%s%sStart", c.Column, OpName[c.OP]), p.Type, p.ThirdPkg), 194 | NewParameter(fmt.Sprintf("%s%sEnd", c.Column, OpName[c.OP]), p.Type, p.ThirdPkg)) 195 | case Parentheses: 196 | leftSQL, leftParameter, err := c.Left.marshal() 197 | if err != nil { 198 | return "", nil, err 199 | } 200 | 201 | // assert right clause is nil 202 | //rightSQL, rightParameter, err := c.Right.marshal() 203 | //if err != nil { 204 | // return "", nil, err 205 | //} 206 | 207 | ps.Add(leftParameter...) 208 | //ps.Add(rightParameter...) 
209 | 210 | if len(leftSQL) > 0 { 211 | sql = fmt.Sprintf("( %s )", leftSQL) 212 | } 213 | default: 214 | // ignores 'case' 215 | } 216 | parameters = ps.List() 217 | return 218 | } 219 | -------------------------------------------------------------------------------- /internal/spec/column.tpl: -------------------------------------------------------------------------------- 1 | {{UpperCamel .Name}} {{.GoType}} `{{ColumnTag}}json:"{{LowerCamel .Name}}"` -------------------------------------------------------------------------------- /internal/spec/comment.go: -------------------------------------------------------------------------------- 1 | package spec 2 | 3 | import "fmt" 4 | 5 | // Comment represents a sql comment. 6 | type Comment struct { 7 | // OriginText represents the original sql text. 8 | OriginText string 9 | // LineText is the text of the line comment. 10 | LineText []string 11 | // FuncNames represents the generated function names. 12 | FuncName string 13 | } 14 | 15 | func (c Comment) validate() error { 16 | if len(c.FuncName) == 0 { 17 | return fmt.Errorf("missing func name near %s", c.OriginText) 18 | } 19 | return nil 20 | } 21 | -------------------------------------------------------------------------------- /internal/spec/constraint.go: -------------------------------------------------------------------------------- 1 | package spec 2 | 3 | import ( 4 | "github.com/anqiansong/sqlgen/internal/set" 5 | ) 6 | 7 | // NewConstraint returns a new Constraint. 8 | func NewConstraint() *Constraint { 9 | return &Constraint{ 10 | PrimaryKey: map[string][]string{}, 11 | UniqueKey: map[string][]string{}, 12 | Index: map[string][]string{}, 13 | } 14 | } 15 | 16 | // AppendPrimaryKey appends a column to the primary key. 
17 | func (c *Constraint) AppendPrimaryKey(key string, columns ...string) { 18 | c.append(func(key string) ([]string, bool) { 19 | list, ok := c.PrimaryKey[key] 20 | return list, ok 21 | }, func(columns []string) { 22 | c.PrimaryKey[key] = columns 23 | }, key, columns...) 24 | } 25 | 26 | // AppendUniqueKey appends a column to the unique key. 27 | func (c *Constraint) AppendUniqueKey(key string, columns ...string) { 28 | c.append(func(key string) ([]string, bool) { 29 | list, ok := c.UniqueKey[key] 30 | return list, ok 31 | }, func(columns []string) { 32 | c.UniqueKey[key] = columns 33 | }, key, columns...) 34 | } 35 | 36 | // AppendIndex appends a column to the unique key. 37 | func (c *Constraint) AppendIndex(key string, columns ...string) { 38 | c.append(func(key string) ([]string, bool) { 39 | list, ok := c.Index[key] 40 | return list, ok 41 | }, func(columns []string) { 42 | c.Index[key] = columns 43 | }, key, columns...) 44 | } 45 | 46 | // IsEmpty returns true if the constraint is empty. 47 | func (c *Constraint) IsEmpty() bool { 48 | return len(c.PrimaryKey) == 0 && len(c.UniqueKey) == 0 && len(c.Index) == 0 49 | } 50 | 51 | // Merge merges the constraint with another constraint. 52 | func (c *Constraint) Merge(constraint *Constraint) { 53 | if constraint == nil { 54 | return 55 | } 56 | if constraint.IsEmpty() { 57 | return 58 | } 59 | 60 | for key, columns := range constraint.PrimaryKey { 61 | c.AppendPrimaryKey(key, columns...) 62 | } 63 | 64 | for key, columns := range constraint.UniqueKey { 65 | c.AppendUniqueKey(key, columns...) 66 | } 67 | 68 | for key, columns := range constraint.Index { 69 | c.AppendIndex(key, columns...) 70 | } 71 | } 72 | 73 | func (c *Constraint) append(existFn func(key string) ([]string, bool), result func(columns []string), key string, columns ...string) { 74 | var columnSet = set.FromString(columns...) 
75 | if len(columns) == 0 { 76 | columns = []string{key} 77 | } 78 | 79 | list, ok := existFn(key) 80 | if !ok { 81 | result(columnSet.String()) 82 | return 83 | } 84 | 85 | columnSet.AddStringList(list) 86 | for _, column := range columns { 87 | columnSet.Add(column) 88 | } 89 | 90 | result(columnSet.String()) 91 | } 92 | -------------------------------------------------------------------------------- /internal/spec/converter.go: -------------------------------------------------------------------------------- 1 | package spec 2 | 3 | import ( 4 | "fmt" 5 | 6 | "github.com/pingcap/parser/mysql" 7 | ) 8 | 9 | func convertLimit(limit *Limit, table *Table, comment Comment) *Limit { 10 | if !limit.IsValid() { 11 | return limit 12 | } 13 | 14 | limit.TableInfo = table 15 | limit.Comment = comment 16 | return limit 17 | } 18 | 19 | func convertByItems(byItems ByItems, table *Table, comment Comment) (ByItems, error) { 20 | var list ByItems 21 | for _, v := range byItems { 22 | byItem, err := convertByItem(v, table, comment) 23 | if err != nil { 24 | return nil, err 25 | } 26 | list = append(list, byItem) 27 | } 28 | return list, nil 29 | } 30 | 31 | func convertByItem(byItem *ByItem, table *Table, comment Comment) (*ByItem, error) { 32 | if !byItem.IsValid() { 33 | return byItem, nil 34 | } 35 | 36 | byItem.TableInfo = table 37 | byItem.Comment = comment 38 | if byItem.Column == WildCard { 39 | return nil, fmt.Errorf("wildcard is not allowed in by item") 40 | } 41 | column, ok := table.GetColumnByName(byItem.Column) 42 | if !ok { 43 | return nil, fmt.Errorf("column %q no found in table %q", byItem.Column, table.Name) 44 | } 45 | byItem.ColumnInfo = column 46 | return byItem, nil 47 | } 48 | 49 | func convertClause(clause *Clause, table *Table, comment Comment, rows Columns) (*Clause, error) { 50 | if !clause.IsValid() { 51 | return clause, nil 52 | } 53 | 54 | clause.Comment = comment 55 | clause.TableInfo = table 56 | if clause.Column == WildCard { 57 | return nil, 
fmt.Errorf("wildcard is not allowed in by item") 58 | } 59 | if len(clause.Column) > 0 { 60 | column, ok := table.GetColumnByName(clause.Column) 61 | if ok { 62 | clause.ColumnInfo = column 63 | } else { 64 | // for case: select max(id) AS maxID from t having maxID > 0; 65 | column, ok = rows.GetColumn(clause.Column) 66 | if !ok { 67 | return nil, fmt.Errorf("column %q no found in table %q", clause.Column, table.Name) 68 | } 69 | } 70 | 71 | clause.ColumnInfo = column 72 | } 73 | 74 | leftClause, err := convertClause(clause.Left, table, comment, rows) 75 | if err != nil { 76 | return nil, err 77 | } 78 | rightClause, err := convertClause(clause.Right, table, comment, rows) 79 | if err != nil { 80 | return nil, err 81 | } 82 | 83 | clause.Left = leftClause 84 | clause.Right = rightClause 85 | return clause, nil 86 | } 87 | 88 | func convertColumn(table *Table, columns []string) Columns { 89 | var list Columns 90 | var m = map[string]struct{}{} 91 | for _, c := range columns { 92 | if _, ok := m[c]; ok { 93 | continue 94 | } 95 | if c == WildCard { 96 | list = append(list, table.Columns...) 97 | continue 98 | } 99 | 100 | column, ok := table.GetColumnByName(c) 101 | if ok { 102 | list = append(list, column) 103 | } 104 | 105 | } 106 | return list 107 | } 108 | 109 | func convertField(table *Table, fields []Field) (Columns, error) { 110 | var list Columns 111 | var m = map[string]struct{}{} 112 | for _, f := range fields { 113 | name := f.ColumnName 114 | if len(f.ASName) > 0 { 115 | name = f.ASName 116 | } 117 | if _, ok := m[name]; ok { 118 | continue 119 | } 120 | m[name] = struct{}{} 121 | if name == WildCard { 122 | list = append(list, table.Columns...) 
123 | continue 124 | } 125 | 126 | if len(f.ColumnName) > 0 { 127 | column, ok := table.GetColumnByName(f.ColumnName) 128 | if ok { 129 | column.Name = name 130 | column.AggregateCall = f.AggregateCall 131 | if f.TP != mysql.TypeUnspecified { 132 | column.TP = f.TP 133 | } 134 | list = append(list, column) 135 | } else { 136 | return nil, fmt.Errorf("column %q no found in table %q", f.ColumnName, table.Name) 137 | } 138 | } else { 139 | if f.TP == mysql.TypeUnspecified { 140 | return nil, fmt.Errorf("column %q no found in table %q", f.ColumnName, table.Name) 141 | } 142 | list = append(list, Column{ 143 | Name: name, 144 | TP: f.TP, 145 | AggregateCall: f.AggregateCall, 146 | }) 147 | } 148 | } 149 | return list, nil 150 | } 151 | -------------------------------------------------------------------------------- /internal/spec/ddl.go: -------------------------------------------------------------------------------- 1 | package spec 2 | 3 | // DDL represents a DDL statement. 4 | type DDL struct { 5 | // Table represents a table in the database. 6 | Table *Table 7 | } 8 | 9 | // IsEmpty returns true if the DDL is empty. 10 | func (d *DDL) IsEmpty() bool { 11 | return d.Table == nil 12 | } 13 | 14 | func (d *DDL) validate() error { 15 | if d.Table == nil { 16 | return nil 17 | } 18 | return d.Table.validate() 19 | } 20 | -------------------------------------------------------------------------------- /internal/spec/delete.go: -------------------------------------------------------------------------------- 1 | package spec 2 | 3 | // DeleteStmt represents a delete statement. 4 | type DeleteStmt struct { 5 | // Action represents the db action. 6 | Action Action 7 | // Comment represents a sql comment. 8 | Comment 9 | // From represents the operation table name, do not support multiple tables. 10 | From string 11 | // Limit represents the limit clause. 12 | Limit *Limit 13 | // OrderBy represents the order by clause. 
14 | OrderBy ByItems 15 | // SQL represents the original sql text. 16 | SQL string 17 | // Where represents the where clause. 18 | Where *Clause 19 | 20 | // the below data are from table 21 | // FromInfo is the table info which is convert from From. 22 | FromInfo *Table 23 | } 24 | 25 | func (d *DeleteStmt) SQLText() string { 26 | return d.SQL 27 | } 28 | 29 | func (d *DeleteStmt) TableName() string { 30 | return d.From 31 | } 32 | 33 | func (d *DeleteStmt) validate() (map[string]string, error) { 34 | return map[string]string{ 35 | d.FuncName: d.OriginText, 36 | }, d.Comment.validate() 37 | } 38 | 39 | func (d *DeleteStmt) HasArg() bool { 40 | if d.Limit.IsValid() { 41 | return true 42 | } 43 | if d.OrderBy.IsValid() { 44 | return true 45 | } 46 | if d.Where.IsValid() { 47 | return true 48 | } 49 | return false 50 | } 51 | -------------------------------------------------------------------------------- /internal/spec/dml.go: -------------------------------------------------------------------------------- 1 | package spec 2 | 3 | // DML represents a DML statement. 4 | type DML interface { 5 | // SQLText returns the SQL text of the DML statement. 6 | SQLText() string 7 | // TableName returns the table of the DML statement. 8 | TableName() string 9 | validate() (map[string]string, error) 10 | } 11 | -------------------------------------------------------------------------------- /internal/spec/insert.go: -------------------------------------------------------------------------------- 1 | package spec 2 | 3 | // InsertStmt represents a insert statement. 4 | type InsertStmt struct { 5 | // Action represents the db action. 6 | Action Action 7 | // Columns represents the operation columns. 8 | Columns []string 9 | // Comment represents a sql comment. 10 | Comment 11 | // SQL represents the original sql text. 12 | SQL string 13 | // Table represents the operation table name, do not support multiple tables. 
14 | Table string 15 | 16 | // the below data are from table 17 | // ColumnInfo are the column info which are convert from Columns. 18 | ColumnInfo Columns 19 | // TableInfo is the table info which is convert from Table. 20 | TableInfo *Table 21 | } 22 | 23 | func (i *InsertStmt) SQLText() string { 24 | return i.SQL 25 | } 26 | 27 | func (i *InsertStmt) TableName() string { 28 | return i.Table 29 | } 30 | 31 | func (i *InsertStmt) validate() (map[string]string, error) { 32 | return map[string]string{ 33 | i.FuncName: i.OriginText, 34 | }, i.Comment.validate() 35 | } 36 | 37 | func (i *InsertStmt) HasArg() bool { 38 | return false 39 | } 40 | -------------------------------------------------------------------------------- /internal/spec/limit.go: -------------------------------------------------------------------------------- 1 | package spec 2 | 3 | import ( 4 | "fmt" 5 | "strings" 6 | 7 | "github.com/iancoleman/strcase" 8 | 9 | "github.com/anqiansong/sqlgen/internal/buffer" 10 | "github.com/anqiansong/sqlgen/internal/parameter" 11 | "github.com/anqiansong/sqlgen/internal/set" 12 | ) 13 | 14 | const ( 15 | countField = "count" 16 | offsetFiled = "offset" 17 | ) 18 | 19 | func (l *Limit) IsValid() bool { 20 | if l == nil { 21 | return false 22 | } 23 | return l.Count > 0 24 | } 25 | 26 | // SQL returns the clause condition strings. 27 | func (l *Limit) SQL() (string, error) { 28 | sql, _, err := l.marshal() 29 | return fmt.Sprintf("`%s`", sql), err 30 | } 31 | 32 | // ParameterStructure returns the parameter type structure. 
33 | func (l *Limit) ParameterStructure() (string, error) { 34 | _, parameters, err := l.marshal() 35 | if err != nil { 36 | return "", err 37 | } 38 | 39 | var writer = buffer.New() 40 | writer.Write(`// %s is a limit parameter structure.`, l.ParameterStructureName()) 41 | writer.Write(`type %s struct {`, l.ParameterStructureName()) 42 | for _, v := range parameters { 43 | writer.Write("%s %s", v.Column, v.Type) 44 | } 45 | 46 | writer.Write(`}`) 47 | 48 | return writer.String(), nil 49 | } 50 | 51 | // ParameterThirdImports returns the third package imports. 52 | func (l *Limit) ParameterThirdImports() (string, error) { 53 | _, parameters, err := l.marshal() 54 | if err != nil { 55 | return "", err 56 | } 57 | var thirdPkgSet = set.From() 58 | for _, v := range parameters { 59 | if len(v.ThirdPkg) == 0 { 60 | continue 61 | } 62 | thirdPkgSet.Add(v.ThirdPkg) 63 | } 64 | 65 | return strings.Join(thirdPkgSet.String(), "\n"), nil 66 | } 67 | 68 | // Parameters returns the parameter variables. 69 | func (l *Limit) Parameters(pkg string) (string, error) { 70 | _, parameters, err := l.marshal() 71 | if err != nil { 72 | return "", err 73 | } 74 | var list []string 75 | for _, v := range parameters { 76 | list = append(list, fmt.Sprintf("%s.%s", pkg, v.Column)) 77 | } 78 | 79 | return strings.Join(list, ", "), nil 80 | } 81 | 82 | // LimitParameter returns the parameter variables. 83 | func (l *Limit) LimitParameter(pkg string) string { 84 | return fmt.Sprintf("%s.%s", pkg, strcase.ToCamel(countField)) 85 | } 86 | 87 | // OffsetParameter returns the parameter variables. 88 | func (l *Limit) OffsetParameter(pkg string) string { 89 | return fmt.Sprintf("%s.%s", pkg, strcase.ToCamel(offsetFiled)) 90 | } 91 | 92 | // ParameterStructureName returns the parameter structure name. 
93 | func (l *Limit) ParameterStructureName() string { 94 | if !l.IsValid() { 95 | return "" 96 | } 97 | 98 | return strcase.ToCamel(fmt.Sprintf("%sLimitParameter", l.Comment.FuncName)) 99 | } 100 | 101 | func (l *Limit) One() bool { 102 | if l == nil { 103 | return false 104 | } 105 | return l.Count == 1 106 | } 107 | 108 | func (l *Limit) Multiple() bool { 109 | if l == nil { 110 | return false 111 | } 112 | return l.Count > 1 113 | } 114 | 115 | func (l *Limit) marshal() (sql string, parameters parameter.Parameters, err error) { 116 | parameters = parameter.Empty 117 | if l == nil { 118 | return 119 | } 120 | 121 | sql = fmt.Sprintf("limit %d", l.Count) 122 | parameters = append(parameters, NewParameter(countField, "int", "")) 123 | if l.Offset > 0 { 124 | sql = fmt.Sprintf("limit %d, %d", l.Offset, l.Count) 125 | parameters = append(parameters, NewParameter(offsetFiled, "int", "")) 126 | } 127 | 128 | return 129 | } 130 | -------------------------------------------------------------------------------- /internal/spec/op.go: -------------------------------------------------------------------------------- 1 | package spec 2 | 3 | const ( 4 | _ OP = iota 5 | ColumnValue 6 | And 7 | Between 8 | Case 9 | EQ 10 | Or 11 | GE 12 | GT 13 | In 14 | LE 15 | LT 16 | Like 17 | NE 18 | Not 19 | NotBetween 20 | NotIn 21 | NotLike 22 | Parentheses 23 | ) 24 | 25 | // OP is opcode type. 
26 | type OP int 27 | 28 | var Operator = []string{ 29 | And: "AND", 30 | Between: "BETWEEN", 31 | Case: "CASE", 32 | EQ: "=", 33 | Or: "OR", 34 | GE: ">=", 35 | GT: ">", 36 | In: "IN", 37 | LE: "<=", 38 | LT: "<", 39 | Like: "LIKE", 40 | NE: "!=", 41 | Not: "NOT", 42 | NotBetween: "NOT BETWEEN", 43 | NotIn: "NOT IN", 44 | NotLike: "NOT LIKE", 45 | } 46 | 47 | var OpName = []string{ 48 | And: "And", 49 | Between: "Between", 50 | Case: "Case", 51 | EQ: "Equal", 52 | Or: "Or", 53 | GE: "GE", 54 | GT: "GT", 55 | In: "In", 56 | LE: "LE", 57 | LT: "LT", 58 | Like: "Like", 59 | NE: "NE", 60 | Not: "Not", 61 | NotBetween: "NotBetween", 62 | NotIn: "NotIn", 63 | NotLike: "NotLike", 64 | } 65 | -------------------------------------------------------------------------------- /internal/spec/select.go: -------------------------------------------------------------------------------- 1 | package spec 2 | 3 | import ( 4 | "bytes" 5 | _ "embed" 6 | "fmt" 7 | "strings" 8 | "text/template" 9 | 10 | "github.com/iancoleman/strcase" 11 | 12 | "github.com/anqiansong/sqlgen/internal/buffer" 13 | "github.com/anqiansong/sqlgen/internal/templatex" 14 | ) 15 | 16 | const ( 17 | ormBun = "bun" 18 | ormGorm = "gorm" 19 | ormSQL = "sql" 20 | ormSQLX = "sqlx" 21 | ormXorm = "xorm" 22 | ) 23 | 24 | // SelectStmt represents a select statement. 25 | type SelectStmt struct { 26 | // Action represents the db action. 27 | Action Action 28 | // SelectSQL represents the select filed sql. 29 | SelectSQL string 30 | // Columns represents the operation columns. 31 | Columns Fields 32 | // Comment represents a sql comment. 33 | Comment 34 | // Distinct represents the select distinct flag. 35 | Distinct bool 36 | // From represents the operation table name, do not support multiple tables. 37 | From string 38 | // GroupBy represents the group by clause. 39 | GroupBy ByItems 40 | // Having represents the having clause. 41 | Having *Clause 42 | // Limit represents the limit clause. 
43 | Limit *Limit 44 | // OrderBy represents the order by clause. 45 | OrderBy ByItems 46 | // SQL represents the original sql text. 47 | SQL string 48 | // Where represents the where clause. 49 | Where *Clause 50 | 51 | // the below data are from table 52 | // ColumnInfo are the column info which are convert from Columns. 53 | ColumnInfo Columns 54 | // FromInfo is the table info which is convert from From. 55 | FromInfo *Table 56 | } 57 | 58 | func (s *SelectStmt) SQLText() string { 59 | return s.SQL 60 | } 61 | 62 | func (s *SelectStmt) TableName() string { 63 | return s.From 64 | } 65 | 66 | func (s *SelectStmt) ReceiverName() string { 67 | if s.ContainsExtraColumns() { 68 | return strcase.ToCamel(fmt.Sprintf("%sResult", s.FuncName)) 69 | } 70 | return strcase.ToCamel(s.TableName()) 71 | } 72 | 73 | //go:embed column.tpl 74 | var fieldTpl string 75 | 76 | func (s *SelectStmt) ReceiverStructure(orm string) string { 77 | receiverName := s.ReceiverName() 78 | if strings.EqualFold(receiverName, s.TableName()) { 79 | // Use table struct 80 | return "" 81 | } 82 | var buf = buffer.New() 83 | buf.Write("\n") 84 | buf.Write("// %s is a %s.", receiverName, strcase.ToDelimited(receiverName, ' ')) 85 | buf.Write(`type %s struct {`, receiverName) 86 | if orm == ormBun { 87 | buf.Write("bun.BaseModel `bun:\"table:%s\"`", s.FromInfo.Name) 88 | } 89 | for _, v := range s.ColumnInfo { 90 | t := templatex.New() 91 | t.AppendFuncMap(template.FuncMap{ 92 | "ColumnTag": func() string { 93 | switch orm { 94 | case ormBun: 95 | return fmt.Sprintf(`bun:"%s" `, v.Name) 96 | case ormGorm: 97 | return fmt.Sprintf(`gorm:"column:%s" `, v.Name) 98 | case ormSQL: 99 | return "" // placeholder 100 | case ormSQLX: 101 | return fmt.Sprintf(`db:"%s" `, v.Name) 102 | case ormXorm: 103 | return fmt.Sprintf(`xorm:"'%s'" `, v.Name) 104 | default: 105 | return "" 106 | } 107 | }, 108 | }) 109 | t.MustParse(fieldTpl) 110 | t.MustExecute(v) 111 | var columnBuf bytes.Buffer 112 | t.Write(&columnBuf, 
false) 113 | buf.Write(columnBuf.String()) 114 | } 115 | buf.Write("}") 116 | return buf.String() 117 | } 118 | 119 | // ContainsExtraColumns returns true if the select statement contains extra columns. 120 | func (s *SelectStmt) ContainsExtraColumns() bool { 121 | for _, f := range s.Columns { 122 | name := f.Name() 123 | if name == WildCard { 124 | continue 125 | } 126 | if !s.FromInfo.Columns.Has(name) { 127 | return true 128 | } 129 | } 130 | return false 131 | } 132 | 133 | func (s *SelectStmt) validate() (map[string]string, error) { 134 | return map[string]string{ 135 | s.FuncName: s.OriginText, 136 | }, s.Comment.validate() 137 | } 138 | 139 | func (s *SelectStmt) HasArg() bool { 140 | if s.GroupBy.IsValid() { 141 | return true 142 | } 143 | if s.Having.IsValid() { 144 | return true 145 | } 146 | if s.Limit.IsValid() { 147 | return true 148 | } 149 | if s.OrderBy.IsValid() { 150 | return true 151 | } 152 | if s.Where.IsValid() { 153 | return true 154 | } 155 | return false 156 | } 157 | -------------------------------------------------------------------------------- /internal/spec/spec.go: -------------------------------------------------------------------------------- 1 | package spec 2 | 3 | // DXL describes ddl and dml. 4 | type DXL struct { 5 | // DDL represents a DDL statement. 6 | DDL []*DDL 7 | // DML represents a DML statement. 8 | DML []DML 9 | } 10 | 11 | // Validate validates the ddl and dml. 
12 | func (dxl *DXL) Validate() error {
13 | 	for _, ddl := range dxl.DDL {
14 | 		if err := ddl.validate(); err != nil {
15 | 			return err
16 | 		}
17 | 	}
18 |
19 | 	for _, dml := range dxl.DML {
20 | 		if _, err := dml.validate(); err != nil {
21 | 			return err
22 | 		}
23 | 	}
24 | 	return nil
25 | }
26 |
--------------------------------------------------------------------------------
/internal/spec/sql.go:
--------------------------------------------------------------------------------
1 | package spec
2 |
3 | import "fmt"
4 |
5 | // Context groups a table with the DML statements and transactions that
6 | // target it.
7 | type Context struct {
8 | 	Table       *Table
9 | 	InsertStmt  []*InsertStmt
10 | 	SelectStmt  []*SelectStmt
11 | 	UpdateStmt  []*UpdateStmt
12 | 	DeleteStmt  []*DeleteStmt
13 | 	Transaction []*Transaction
14 | }
15 |
16 | // From creates one Context per DDL table of dxl; see from for how the DML
17 | // statements are attached to each table.
18 | func From(dxl *DXL) ([]Context, error) {
19 | 	var list []Context
20 | 	for _, d := range dxl.DDL {
21 | 		ctx, err := from(d.Table, dxl.DML)
22 | 		if err != nil {
23 | 			return nil, err
24 | 		}
25 |
26 | 		list = append(list, ctx)
27 | 	}
28 |
29 | 	return list, nil
30 | }
31 |
32 | // from builds the Context of table. Statements whose TableName matches
33 | // table.Name are resolved against the table schema (columns, clauses,
34 | // ordering, limit), transactions are resolved recursively, and the result is
35 | // validated, which also rejects duplicate function names.
36 | func from(table *Table, dml []DML) (Context, error) {
37 | 	var ctx Context
38 | 	ctx.Table = table
39 | 	for _, d := range dml {
40 | 		if transaction, ok := d.(*Transaction); ok {
41 | 			childCtx, err := from(table, transaction.Statements)
42 | 			if err != nil {
43 | 				return Context{}, err
44 | 			}
45 | 			// NOTE(review): the same *Transaction is appended for every DDL
46 | 			// table, so with several tables its Context is overwritten by the
47 | 			// last table processed — confirm this is intended.
48 | 			transaction.Context = childCtx
49 | 			ctx.Transaction = append(ctx.Transaction, transaction)
50 | 			continue
51 | 		}
52 | 		if d.TableName() != table.Name {
53 | 			continue
54 | 		}
55 |
56 | 		switch v := d.(type) {
57 | 		case *InsertStmt:
58 | 			v.ColumnInfo = convertColumn(table, v.Columns)
59 | 			v.TableInfo = table
60 | 			ctx.InsertStmt = append(ctx.InsertStmt, v)
61 | 		case *SelectStmt:
62 | 			columns, err := convertField(table, v.Columns)
63 | 			if err != nil {
64 | 				return Context{}, err
65 | 			}
66 |
67 | 			v.GroupBy, err = convertByItems(v.GroupBy, table, v.Comment)
68 | 			if err != nil {
69 | 				return Context{}, err
70 | 			}
71 |
72 | 			v.Having, err = convertClause(v.Having, table, v.Comment, columns)
73 | 			if err != nil {
74 | 				return Context{}, err
75 | 			}
76 |
77 | 			v.Where, err = convertClause(v.Where, table, v.Comment, columns)
78 | 			if err != nil {
79 | 				return Context{}, err
80 | 			}
81 |
82 | 			v.OrderBy, err = convertByItems(v.OrderBy, table, v.Comment)
83 | 			if err != nil {
84 | 				return Context{}, err
85 | 			}
86 |
87 | 			v.Limit = convertLimit(v.Limit, table, v.Comment)
88 | 			v.ColumnInfo = columns
89 | 			v.FromInfo = table
90 | 			ctx.SelectStmt = append(ctx.SelectStmt, v)
91 | 		case *UpdateStmt:
92 | 			columns := convertColumn(table, v.Columns)
93 | 			var err error
94 | 			v.Where, err = convertClause(v.Where, table, v.Comment, columns)
95 | 			if err != nil {
96 | 				return Context{}, err
97 | 			}
98 |
99 | 			v.OrderBy, err = convertByItems(v.OrderBy, table, v.Comment)
100 | 			if err != nil {
101 | 				return Context{}, err
102 | 			}
103 |
104 | 			v.TableInfo = table
105 | 			v.ColumnInfo = columns
106 | 			v.Limit = convertLimit(v.Limit, table, v.Comment)
107 | 			ctx.UpdateStmt = append(ctx.UpdateStmt, v)
108 | 		case *DeleteStmt:
109 | 			var err error
110 | 			v.Where, err = convertClause(v.Where, table, v.Comment, nil)
111 | 			if err != nil {
112 | 				return Context{}, err
113 | 			}
114 |
115 | 			v.OrderBy, err = convertByItems(v.OrderBy, table, v.Comment)
116 | 			if err != nil {
117 | 				return Context{}, err
118 | 			}
119 |
120 | 			v.FromInfo = table
121 | 			v.Limit = convertLimit(v.Limit, table, v.Comment)
122 | 			ctx.DeleteStmt = append(ctx.DeleteStmt, v)
123 | 		}
124 | 	}
125 |
126 | 	if _, err := ctx.validate(); err != nil {
127 | 		return Context{}, err
128 | 	}
129 | 	return ctx, nil
130 | }
131 |
132 | // validate validates every statement of ctx and rejects duplicate function
133 | // names, including those coming from transactions. On success it returns the
134 | // collected function names mapped to their original SQL text.
135 | func (ctx Context) validate() (map[string]string, error) {
136 | 	funcName := map[string]string{}
137 | 	// register claims fn for originText, failing if fn is already taken.
138 | 	register := func(fn, originText string) error {
139 | 		if _, ok := funcName[fn]; ok {
140 | 			return fmt.Errorf("duplicate function %q near by %q", fn, originText)
141 | 		}
142 | 		funcName[fn] = originText
143 | 		return nil
144 | 	}
145 |
146 | 	for _, v := range ctx.InsertStmt {
147 | 		if _, err := v.validate(); err != nil {
148 | 			return nil, err
149 | 		}
150 | 		if err := register(v.FuncName, v.OriginText); err != nil {
151 | 			return nil, err
152 | 		}
153 | 	}
154 | 	for _, v := range ctx.SelectStmt {
155 | 		if _, err := v.validate(); err != nil {
156 | 			return nil, err
157 | 		}
158 | 		if err := register(v.FuncName, v.OriginText); err != nil {
159 | 			return nil, err
160 | 		}
161 | 	}
162 | 	for _, v := range ctx.UpdateStmt {
163 | 		if _, err := v.validate(); err != nil {
164 | 			return nil, err
165 | 		}
166 | 		if err := register(v.FuncName, v.OriginText); err != nil {
167 | 			return nil, err
168 | 		}
169 | 	}
170 | 	for _, v := range ctx.DeleteStmt {
171 | 		if _, err := v.validate(); err != nil {
172 | 			return nil, err
173 | 		}
174 | 		if err := register(v.FuncName, v.OriginText); err != nil {
175 | 			return nil, err
176 | 		}
177 | 	}
178 | 	for _, v := range ctx.Transaction {
179 | 		funcM, err := v.validate()
180 | 		if err != nil {
181 | 			return nil, err
182 | 		}
183 | 		for fn, originText := range funcM {
184 | 			if err := register(fn, originText); err != nil {
185 | 				return nil, err
186 | 			}
187 | 		}
188 | 	}
189 | 	return funcName, nil
190 | }
191 |
--------------------------------------------------------------------------------
/internal/spec/stmt.go:
--------------------------------------------------------------------------------
1 | package spec
2 |
3 | import (
4 | 	_ "embed"
5 | )
6 |
7 | // WildCard is a wildcard column.
8 | const WildCard = "*"
9 |
10 | var _ DML = (*InsertStmt)(nil)
11 | var _ DML = (*UpdateStmt)(nil)
12 | var _ DML = (*SelectStmt)(nil)
13 | var _ DML = (*DeleteStmt)(nil)
14 |
15 | type Fields []Field
16 |
17 | // Field represents a select field.
18 | type Field struct {
19 | 	// ASName is the alias of the field (AS ...), empty if it has none.
20 | 	ASName string
21 | 	// ColumnName is the name of the selected column.
22 | 	ColumnName string
23 | 	// TP is the type of the field.
24 | 	TP byte
25 | 	// AggregateCall reports whether the field is an aggregate function call.
26 | 	AggregateCall bool
27 | }
28 |
29 | // Limit represents a limit clause.
30 | type Limit struct {
31 | 	// Count represents the limit count.
32 | 	Count int64
33 | 	// Offset represents the limit offset.
34 | 	Offset int64
35 |
36 | 	// the below data are from table
37 | 	// TableInfo is the table info.
38 | 	TableInfo *Table
39 |
40 | 	// the below data are from stmt
41 | 	// Comment represents a sql comment.
42 | 	Comment Comment
43 | }
44 |
45 | // Name returns the alias of the field if it has one, otherwise its column
46 | // name.
47 | func (f Field) Name() string {
48 | 	if len(f.ASName) > 0 {
49 | 		return f.ASName
50 | 	}
51 | 	return f.ColumnName
52 | }
53 |
--------------------------------------------------------------------------------
/internal/spec/table.go:
--------------------------------------------------------------------------------
1 | package spec
2 |
3 | import (
4 | 	"fmt"
5 |
6 | 	"github.com/zeromicro/go-zero/core/stringx"
7 | )
8 |
9 | // Table represents a table in the database.
10 | type Table struct {
11 | 	// Columns is the list of columns in the table.
12 | 	Columns Columns
13 | 	// Constraint is a struct that contains the constraints of a table.
14 | 	// ConstraintForeignKey,ConstraintFulltext,ConstraintCheck are ignored.
15 | 	Constraint Constraint
16 | 	// Schema is the name of the schema that the table belongs to.
17 | 	Schema string
18 | 	// Name is the name of the table.
19 | 	Name string
20 | }
21 |
22 | // Columns is a list of table columns.
23 | type Columns []Column
24 |
25 | // Column represents a column in a table.
26 | type Column struct {
27 | 	// ColumnOption is a column option.
28 | 	ColumnOption
29 | 	// Name is the name of the column.
30 | 	Name string
31 | 	// TP is the type of the column.
32 | 	TP byte
33 | 	// AggregateCall reports whether the column is the result of an aggregate
34 | 	// function call; such columns map to nullable Go types.
35 | 	AggregateCall bool
36 | }
37 |
38 | // ColumnOption is a column option.
39 | type ColumnOption struct {
40 | 	// AutoIncrement is true if the column allows auto increment.
41 | 	AutoIncrement bool
42 | 	// Comment is the comment of the column.
43 | 	Comment string
44 | 	// HasDefaultValue is true if the column has default value.
45 | 	HasDefaultValue bool
46 | 	// TODO: Add default value
47 | 	// NotNull is true if the column is not null, false represents the column is null.
48 | 	NotNull bool
49 | 	// Unsigned is true if the column is unsigned.
50 | 	Unsigned bool
51 | }
52 |
53 | // Constraint is a struct that contains the constraints of a table.
54 | // ConstraintForeignKey,ConstraintFulltext,ConstraintCheck are ignored.
55 | type Constraint struct {
56 | 	// Index maps each index name to the names of the columns it covers.
57 | 	Index map[string][]string
58 | 	// PrimaryKey maps the primary key name to the names of the columns it covers.
59 | 	PrimaryKey map[string][]string
60 | 	// UniqueKey maps each unique key name to the names of the columns it covers.
61 | 	UniqueKey map[string][]string
62 | }
63 |
64 | // Has returns true if Columns has specified column.
65 | func (cs Columns) Has(name string) bool {
66 | 	_, ok := cs.GetColumn(name)
67 | 	return ok
68 | }
69 |
70 | // GetColumn returns the column named name and true, or a zero Column and
71 | // false if cs has no such column.
72 | func (cs Columns) GetColumn(name string) (Column, bool) {
73 | 	for _, c := range cs {
74 | 		if c.Name == name {
75 | 			return c, true
76 | 		}
77 | 	}
78 | 	return Column{}, false
79 | }
80 |
81 | // IsPrimary returns true if the column is part of the primary key.
82 | func (t *Table) IsPrimary(name string) bool {
83 | 	for _, c := range t.Constraint.PrimaryKey {
84 | 		if stringx.Contains(c, name) {
85 | 			return true
86 | 		}
87 | 	}
88 | 	return false
89 | }
90 |
91 | // ColumnList returns the names of all columns, in the order of t.Columns.
92 | func (t *Table) ColumnList() []string {
93 | 	var list []string
94 | 	for _, c := range t.Columns {
95 | 		list = append(list, c.Name)
96 | 	}
97 | 	return list
98 | }
99 |
100 | // PrimaryColumnList returns the columns that are part of the primary key;
101 | // key entries that do not name a column of t are skipped.
102 | func (t *Table) PrimaryColumnList() Columns {
103 | 	var ret Columns
104 | 	for _, list := range t.Constraint.PrimaryKey {
105 | 		for _, name := range list {
106 | 			c, ok := t.GetColumnByName(name)
107 | 			if !ok {
108 | 				continue
109 | 			}
110 | 			ret = append(ret, c)
111 | 		}
112 | 	}
113 | 	return ret
114 | }
115 |
116 | // PrimaryColumn returns the first primary key column, or a zero Column if
117 | // the table has none.
118 | func (t *Table) PrimaryColumn() Column {
119 | 	list := t.PrimaryColumnList()
120 | 	if len(list) == 0 {
121 | 		return Column{}
122 | 	}
123 | 	return list[0]
124 | }
125 |
126 | // HasOnePrimaryKey returns true if the primary key resolves to exactly one column.
127 | func (t *Table) HasOnePrimaryKey() bool {
128 | 	return len(t.PrimaryColumnList()) == 1
129 | }
130 |
131 | // GetColumnByName returns the column with the given name.
132 | func (t *Table) GetColumnByName(name string) (Column, bool) {
133 | 	return t.Columns.GetColumn(name)
134 | }
135 |
136 | // validate checks that the table has a name, at least one column and
137 | // exactly one primary key.
138 | func (t *Table) validate() error {
139 | 	if len(t.Name) == 0 {
140 | 		return fmt.Errorf("missing table name")
141 | 	}
142 | 	if len(t.Columns) == 0 {
143 | 		return fmt.Errorf("missing table columns")
144 | 	}
145 | 	if len(t.Constraint.PrimaryKey) == 0 {
146 | 		return fmt.Errorf("missing table primary key")
147 | 	}
148 | 	if len(t.Constraint.PrimaryKey) > 1 {
149 | 		return fmt.Errorf("unsupported multiple primary key")
150 | 	}
151 | 	return nil
152 | }
153 |
--------------------------------------------------------------------------------
/internal/spec/transaction.go:
--------------------------------------------------------------------------------
1 | package spec
2 |
3 | // Transaction represents a transaction block made up of several DML
4 | // statements.
5 | type Transaction struct {
6 | 	// Action represents the db action.
7 | 	Action Action
8 | 	// Comment represents a sql comment.
9 | 	Comment
10 | 	// SQL represents the original sql text.
11 | 	SQL string
12 | 	// Statements is the list of statements in the transaction.
13 | 	Statements []DML
14 |
15 | 	// the below fields are converted from Statements
16 | 	Context
17 | }
18 |
19 | // SQLText returns the original sql text of the transaction.
20 | func (t Transaction) SQLText() string {
21 | 	return t.SQL
22 | }
23 |
24 | // TableName returns an empty string; transactions are matched to tables
25 | // through their Statements instead.
26 | func (t Transaction) TableName() string {
27 | 	return ""
28 | }
29 |
30 | // validate validates the statements of the transaction's Context and returns
31 | // their function names mapped to their original SQL text.
32 | func (t Transaction) validate() (map[string]string, error) {
33 | 	return t.Context.validate()
34 | }
35 |
36 | // HasArg reports whether any statement of the transaction has arguments.
37 | func (t Transaction) HasArg() bool {
38 | 	for _, v := range t.InsertStmt {
39 | 		if v.HasArg() {
40 | 			return true
41 | 		}
42 | 	}
43 | 	for _, v := range t.SelectStmt {
44 | 		if v.HasArg() {
45 | 			return true
46 | 		}
47 | 	}
48 | 	for _, v := range t.UpdateStmt {
49 | 		if v.HasArg() {
50 | 			return true
51 | 		}
52 | 	}
53 | 	for _, v := range t.DeleteStmt {
54 | 		if v.HasArg() {
55 | 			return true
56 | 		}
57 | 	}
58 | 	return false
59 | }
60 |
--------------------------------------------------------------------------------
/internal/spec/type.go:
--------------------------------------------------------------------------------
1 | package spec
2 |
3 | import (
4 | 	"database/sql"
5 | 	"fmt"
6 |
7 | 	"github.com/anqiansong/sqlgen/internal/parameter"
8 | 	"github.com/pingcap/parser/mysql"
9 | )
10 |
11 | const (
12 | 	// TypeNullLongLong is a type extension for mysql.TypeLonglong.
13 | 	TypeNullLongLong byte = 0xf0
14 | 	// TypeNullDecimal is a type extension for mysql.TypeDecimal.
15 | 	TypeNullDecimal byte = 0xf1
16 | 	// TypeNullString is a type extension for mysql.TypeString.
17 | 	TypeNullString byte = 0xf2
18 | )
19 |
20 | const defaultThirdDecimalPkg = "github.com/shopspring/decimal"
21 |
22 | // typeKey is the lookup key of typeMapper.
23 | type typeKey struct {
24 | 	tp byte
25 | 	// signed is filled from Column.Unsigned despite its name: true selects
26 | 	// the unsigned Go types.
27 | 	signed        bool
28 | 	thirdPkg      string
29 | 	aggregateCall bool
30 | 	// NOTE(review): never set by DataType and zero in every key below; looks
31 | 	// like a leftover that only adds a constant component to the key.
32 | 	sql.NullFloat64
33 | }
34 |
35 | // typeMapper maps a column type key to the Go type used for it.
36 | var typeMapper = map[typeKey]string{
37 | 	typeKey{tp: mysql.TypeTiny}:                   "int8",
38 | 	typeKey{tp: mysql.TypeTiny, signed: true}:     "uint8",
39 | 	typeKey{tp: mysql.TypeShort}:                  "int16",
40 | 	typeKey{tp: mysql.TypeShort, signed: true}:    "uint16",
41 | 	typeKey{tp: mysql.TypeLong}:                   "int32",
42 | 	typeKey{tp: mysql.TypeLong, signed: true}:     "uint32",
43 | 	typeKey{tp: mysql.TypeFloat}:                  "float64",
44 | 	typeKey{tp: mysql.TypeDouble}:                 "float64",
45 | 	typeKey{tp: mysql.TypeTimestamp}:              "time.Time",
46 | 	typeKey{tp: mysql.TypeLonglong}:               "int64",
47 | 	typeKey{tp: mysql.TypeLonglong, signed: true}: "uint64",
48 | 	typeKey{tp: mysql.TypeInt24}:                  "int32",
49 | 	typeKey{tp: mysql.TypeInt24, signed: true}:    "uint32",
50 | 	typeKey{tp: mysql.TypeDate}:                   "time.Time",
51 | 	typeKey{tp: mysql.TypeDuration}:               "time.Time",
52 | 	typeKey{tp: mysql.TypeDatetime}:               "time.Time",
53 | 	typeKey{tp: mysql.TypeYear}:                   "string",
54 | 	typeKey{tp: mysql.TypeVarchar}:                "string",
55 | 	typeKey{tp: mysql.TypeBit}:                    "byte",
56 | 	typeKey{tp: mysql.TypeJSON}:                   "string",
57 | 	typeKey{
58 | 		tp:       mysql.TypeNewDecimal,
59 | 		thirdPkg: defaultThirdDecimalPkg,
60 | 	}: "decimal.Decimal",
61 | 	typeKey{
62 | 		tp:       TypeNullDecimal,
63 | 		thirdPkg: defaultThirdDecimalPkg,
64 | 	}: "decimal.NullDecimal",
65 | 	typeKey{tp: mysql.TypeEnum}:       "string",
66 | 	typeKey{tp: mysql.TypeSet}:        "string",
67 | 	typeKey{tp: mysql.TypeTinyBlob}:   "string",
68 | 	typeKey{tp: mysql.TypeMediumBlob}: "string",
69 | 	typeKey{tp: mysql.TypeLongBlob}:   "string",
70 | 	typeKey{tp: mysql.TypeBlob}:       "string",
71 | 	typeKey{tp: mysql.TypeVarString}:  "string",
72 | 	typeKey{tp: mysql.TypeString}:     "string",
73 | 	typeKey{tp: TypeNullString}:       "sql.NullString",
74 |
75 | 	// aggregate functions
76 | 	typeKey{tp: mysql.TypeTiny, aggregateCall: true}:     "sql.NullInt16",
77 | 	typeKey{tp: mysql.TypeShort, aggregateCall: true}:    "sql.NullInt16",
78 | 	typeKey{tp: mysql.TypeLong, aggregateCall: true}:     "sql.NullInt32",
79 | 	typeKey{tp: mysql.TypeFloat, aggregateCall: true}:    "sql.NullFloat64",
80 | 	typeKey{tp: mysql.TypeDouble, aggregateCall: true}:   "sql.NullFloat64",
81 | 	typeKey{tp: mysql.TypeLonglong, aggregateCall: true}: "sql.NullInt64",
82 | 	typeKey{tp: mysql.TypeInt24, aggregateCall: true}:    "sql.NullInt32",
83 | 	typeKey{tp: mysql.TypeYear, aggregateCall: true}:     "sql.NullString",
84 | 	typeKey{tp: mysql.TypeVarchar, aggregateCall: true}:  "sql.NullString",
85 | 	typeKey{tp: mysql.TypeBit, aggregateCall: true}:      "sql.NullInt16",
86 | 	typeKey{tp: mysql.TypeJSON, aggregateCall: true}:     "sql.NullString",
87 | 	typeKey{
88 | 		tp:            mysql.TypeNewDecimal,
89 | 		thirdPkg:      defaultThirdDecimalPkg,
90 | 		aggregateCall: true,
91 | 	}: "decimal.NullDecimal",
92 | 	typeKey{
93 | 		tp:            TypeNullDecimal,
94 | 		thirdPkg:      defaultThirdDecimalPkg,
95 | 		aggregateCall: true,
96 | 	}: "decimal.NullDecimal",
97 | 	typeKey{tp: mysql.TypeEnum, aggregateCall: true}:       "sql.NullString",
98 | 	typeKey{tp: mysql.TypeSet, aggregateCall: true}:        "sql.NullString",
99 | 	typeKey{tp: mysql.TypeTinyBlob, aggregateCall: true}:   "sql.NullString",
100 | 	typeKey{tp: mysql.TypeMediumBlob, aggregateCall: true}: "sql.NullString",
101 | 	typeKey{tp: mysql.TypeLongBlob, aggregateCall: true}:   "sql.NullString",
102 | 	typeKey{tp: mysql.TypeBlob, aggregateCall: true}:       "sql.NullString",
103 | 	typeKey{tp: mysql.TypeVarString, aggregateCall: true}:  "sql.NullString",
104 | 	typeKey{tp: mysql.TypeString, aggregateCall: true}:     "sql.NullString",
105 | 	typeKey{tp: TypeNullLongLong}:                          "sql.NullInt64",
106 | 	typeKey{tp: TypeNullDecimal}:                           "decimal.NullDecimal",
107 | 	typeKey{tp: TypeNullLongLong, aggregateCall: true}:     "sql.NullInt64",
108 | 	typeKey{tp: TypeNullDecimal, aggregateCall: true}:      "decimal.NullDecimal",
109 | 	typeKey{tp: TypeNullString, aggregateCall: true}:       "sql.NullString",
110 | }
111 |
112 | // Type is the type of the column.
113 | type Type byte
114 |
115 | // DataType returns the Go type of the column as a parameter.Parameter,
116 | // together with the third-party package (if any) that declares the type.
117 | func (c Column) DataType() (parameter.Parameter, error) {
118 | 	key := typeKey{tp: c.TP, signed: c.Unsigned, aggregateCall: c.AggregateCall}
119 | 	if c.AggregateCall {
120 | 		// Aggregate results always map to nullable types; signedness is ignored.
121 | 		key.signed = false
122 | 	}
123 | 	if c.TP == mysql.TypeNewDecimal || c.TP == TypeNullDecimal {
124 | 		// Both decimal flavours come from the decimal package, so the key (and
125 | 		// the returned Parameter) must carry its import path.
126 | 		key.thirdPkg = defaultThirdDecimalPkg
127 | 	}
128 |
129 | 	goType, ok := typeMapper[key]
130 | 	if !ok {
131 | 		return parameter.Parameter{}, fmt.Errorf("unsupported type: %v", c.TP)
132 | 	}
133 |
134 | 	return NewParameter(c.Name, goType, key.thirdPkg), nil
135 | }
136 |
137 | // GoType returns the Go type of the column.
138 | func (c Column) GoType() (string, error) {
139 | 	p, err := c.DataType()
140 | 	return p.Type, err
141 | }
142 |
143 | // HasComment reports whether the column has a non-empty comment.
144 | func (c Column) HasComment() bool {
145 | 	return len(c.Comment) > 0
146 | }
147 |
148 | // isNullType reports whether tp is one of the nullable type extensions
149 | // TypeNullLongLong, TypeNullDecimal or TypeNullString.
150 | func isNullType(tp byte) bool {
151 | 	return tp >= TypeNullLongLong && tp <= TypeNullString
152 | }
153 |
--------------------------------------------------------------------------------
/internal/spec/update.go:
--------------------------------------------------------------------------------
1 | package spec
2 |
3 | // UpdateStmt represents an update statement.
4 | type UpdateStmt struct {
5 | 	// Action represents the db action.
6 | 	Action Action
7 | 	// Columns represents the operation columns.
8 | 	Columns []string
9 | 	// Comment represents a sql comment.
10 | 	Comment
11 | 	// Limit represents the limit clause.
12 | 	Limit *Limit
13 | 	// OrderBy represents the order by clause.
14 | 	OrderBy ByItems
15 | 	// SQL represents the original sql text.
16 | 	SQL string
17 | 	// Table represents the operation table name, do not support multiple tables.
18 | 	Table string
19 | 	// Where represents the where clause.
20 | 	Where *Clause
21 |
22 | 	// the below data are from table
23 | 	// ColumnInfo are the column info which are convert from Columns.
24 | 	ColumnInfo Columns
25 | 	// TableInfo is the table info which is convert from Table.
26 | 	TableInfo *Table
27 | }
28 |
29 | // SQLText returns the original sql text of the statement.
30 | func (u *UpdateStmt) SQLText() string {
31 | 	return u.SQL
32 | }
33 |
34 | // TableName returns the name of the updated table.
35 | func (u *UpdateStmt) TableName() string {
36 | 	return u.Table
37 | }
38 |
39 | // validate validates the statement comment and returns the statement's
40 | // function name mapped to its original SQL text.
41 | func (u *UpdateStmt) validate() (map[string]string, error) {
42 | 	return map[string]string{
43 | 		u.FuncName: u.OriginText,
44 | 	}, u.Comment.validate()
45 | }
46 |
47 | // HasArg reports whether the statement has a valid limit, order by or where
48 | // clause.
49 | func (u *UpdateStmt) HasArg() bool {
50 | 	return u.Limit.IsValid() || u.OrderBy.IsValid() || u.Where.IsValid()
51 | }
52 |
--------------------------------------------------------------------------------
/internal/stringx/stringx.go:
--------------------------------------------------------------------------------
1 | package stringx
2 |
3 | import (
4 | 	"fmt"
5 | 	"strconv"
6 | 	"strings"
7 | )
8 |
9 | var (
10 | 	// newLineReplacer deletes carriage returns and line feeds.
11 | 	newLineReplacer = strings.NewReplacer("\r", "", "\n", "")
12 | 	// spaceReplacer deletes spaces and tabs.
13 | 	spaceReplacer = strings.NewReplacer(" ", "", "\t", "")
14 | )
15 |
16 | // TrimWhiteSpace removes every carriage return, line feed, space and tab from s.
17 | func TrimWhiteSpace(s string) string {
18 | 	return TrimSpace(TrimNewLine(s))
19 | }
20 |
21 | // TrimNewLine removes every carriage return and line feed from s.
22 | func TrimNewLine(s string) string {
23 | 	return newLineReplacer.Replace(s)
24 | }
25 |
26 | // TrimSpace removes every space and tab from s. Unlike strings.TrimSpace it
27 | // also removes them from the middle of s.
28 | func TrimSpace(s string) string {
29 | 	return spaceReplacer.Replace(s)
30 | }
31 |
32 | // RepeatJoin joins count copies of s with sep. It returns "" when s is empty
33 | // or count is not positive.
34 | func RepeatJoin(s, sep string, count int) string {
35 | 	if len(s) == 0 || count <= 0 {
36 | 		return ""
37 | 	}
38 |
39 | 	list := make([]string, count)
40 | 	for i := range list {
41 | 		list[i] = s
42 | 	}
43 | 	return strings.Join(list, sep)
44 | }
45 |
46 | // AutoIncrement adds step to the number that s ends with, e.g. "a1" becomes
47 | // "a2" and "a1b1" becomes "a1b2". When s does not end with a number, step is
48 | // appended instead ("a" becomes "a1"). An empty s yields "".
49 | func AutoIncrement(s string, step int) string {
50 | 	length := len(s)
51 | 	if length == 0 {
52 | 		return ""
53 | 	}
54 |
55 | 	for i := 0; i < length; i++ {
56 | 		r := s[i]
57 | 		if r >= '0' && r <= '9' {
58 | 			// The first digit whose suffix parses as a uint64 starts the
59 | 			// trailing number.
60 | 			if num, ok := IsNumber(s[i:]); ok {
61 | 				return fmt.Sprintf("%s%d", s[:i], num+uint64(step))
62 | 			}
63 | 		}
64 | 	}
65 | 	return fmt.Sprintf("%s%d", s, step)
66 | }
67 |
68 | // IsNumber parses s as a base-10 uint64 and reports whether it succeeded.
69 | func IsNumber(s string) (uint64, bool) {
70 | 	num, err := strconv.ParseUint(s, 10, 64)
71 | 	if err != nil {
72 | 		return 0, false
73 | 	}
74 | 	return num, true
75 | }
76 |
77 | // FormatIdentifiers collapses each run of spaces, tabs, carriage returns,
78 | // line feeds and form feeds in s into a single space and drops such runs at
79 | // both ends.
80 | func FormatIdentifiers(s string) string {
81 | 	// FieldsFunc never returns empty fields, so the result can be joined as is.
82 | 	fields := strings.FieldsFunc(s, func(r rune) bool {
83 | 		return r == ' ' || r == '\t' || r == '\r' || r == '\n' || r == '\f'
84 | 	})
85 | 	return strings.Join(fields, " ")
86 | }
87 |
--------------------------------------------------------------------------------
/internal/stringx/stringx_test.go:
--------------------------------------------------------------------------------
1 | package stringx
2 |
3 | import (
4 | 	"testing"
5 |
6 | 	"github.com/stretchr/testify/assert"
7 | )
8 |
9 | func TestAutoIncrement(t *testing.T) {
10 | 	test := map[string]string{
11 | 		"":     "",
12 | 		"1":    "2",
13 | 		"a":    "a1",
14 | 		"1a":   "1a1",
15 | 		"a1":   "a2",
16 | 		"a10":  "a11",
17 | 		"a1b1": "a1b2",
18 | 	}
19 | 	for input, expect := range test {
20 | 		actual := AutoIncrement(input, 1)
21 | 		assert.Equal(t, expect, actual)
22 | 	}
23 | }
24 |
25 | func TestTrimWhiteSpace(t *testing.T) {
26 | 	var testData = []struct {
27 | 		input  string
28 | 		expect string
29 | 	}{
30 | 		{input: "", expect: ""},
31 | 		{input: "foo", expect: "foo"},
32 | 		{input: "foo bar", expect: "foobar"},
33 | 		{input: "foo\nbar", expect: "foobar"},
34 | 		{input: "foo\rbar", expect: "foobar"},
35 | 		{input: "foo\tbar", expect: "foobar"},
36 | 		{input: "foo\n\r\tbar", expect: "foobar"},
37 | 	}
38 | 	for _, v := range testData {
39 | 		actual := TrimWhiteSpace(v.input)
40 | 		assert.Equal(t, v.expect, actual)
41 | 	}
42 | }
43 |
44 | func TestTrimNewLine(t *testing.T) {
45 | 	var testData = []struct {
46 | 		input  string
47 | 		expect string
48 | 	}{
49 | 		{input: "", expect: ""},
50 | 		{input: "foo", expect: "foo"},
51 | 		{input: "foo bar", expect: "foo bar"},
52 | 		{input: "foo\nbar", expect: "foobar"},
53 | 		{input: "foo\rbar", expect: "foobar"},
54 | 		{input: "foo\n\r\tbar", expect: "foo\tbar"},
55 | 	}
56 | 	for _, v := range testData {
57 | 		actual := TrimNewLine(v.input)
58 | 		assert.Equal(t, v.expect, actual)
59 | 	}
60 | }
61 |
62 | func TestTrimSpace(t *testing.T) {
63 | 	var testData = []struct {
64 | 		input  string
65 | 		expect string
66 | 	}{
67 | 		{input: "", expect: ""},
68 | 		{input: "foo", expect: "foo"},
69 | 		{input: "foo bar", expect: "foobar"},
70 | 		{input: "foo\nbar", expect: "foo\nbar"},
71 | 		{input: "foo\rbar", expect: "foo\rbar"},
72 | 		{input: "foo\n\r\tbar", expect: "foo\n\rbar"},
73 | 	}
74 | 	for _, v := range testData {
75 | 		actual := TrimSpace(v.input)
76 | 		assert.Equal(t, v.expect, actual)
77 | 	}
78 | }
79 |
80 | func TestRepeatJoin(t *testing.T) {
81 | 	var testData = []struct {
82 | 		input  string
83 | 		expect string
84 | 	}{
85 | 		{input: "", expect: ""},
86 | 		{input: "foo", expect: "foo,foo"},
87 | 	}
88 | 	for _, v := range testData {
89 | 		actual := RepeatJoin(v.input, ",", 2)
90 | 		assert.Equal(t, v.expect, actual)
91 | 	}
92 | }
93 |
94 | func TestFormatIdentifiers(t *testing.T) {
95 | 	var testData = []struct {
96 | 		input  string
97 | 		expect string
98 | 	}{
99 | 		{input: "", expect: ""},
100 | 		{input: "foo", expect: "foo"},
101 | 		{input: "foo bar", expect: "foo bar"},
102 | 		{input: "foo\nbar", expect: "foo bar"},
103 | 		{input: "foo\tbar", expect: "foo bar"},
104 | 		{input: "foo\rbar", expect: "foo bar"},
105 | 		{input: "foo\fbar", expect: "foo bar"},
106 | 		{input: "foo\n\t\r\fbar", expect: "foo bar"},
107 | 	}
108 | 	for _, v := range testData {
109 | 		actual := FormatIdentifiers(v.input)
110 | 		assert.Equal(t, v.expect, actual)
111 | 	}
112 | }
113 |
--------------------------------------------------------------------------------
/internal/templatex/funcmap.go:
--------------------------------------------------------------------------------
1 | package templatex
2 |
3 | import (
4 | 	"strings"
5 | 	"text/template"
6 |
7 | 	"github.com/anqiansong/sqlgen/internal/stringx"
8 | 	"github.com/iancoleman/strcase"
9 | )
10 |
11 | // UpperCamel converts s to UpperCamelCase, e.g. "foo_bar" becomes "FooBar".
12 | func UpperCamel(s string) string {
13 | 	return strcase.ToCamel(s)
14 | }
15 |
16 | // LowerCamel converts s to lowerCamelCase, e.g. "Foo_bar" becomes "fooBar".
17 | func LowerCamel(s string) string {
18 | 	return strcase.ToLowerCamel(s)
19 | }
20 |
21 | // Join concatenates the elements of list, separated by sep.
22 | func Join(list []string, sep string) string {
23 | 	return strings.Join(list, sep)
24 | }
25 |
26 | // LineComment turns multi-line text into a Go line comment body: empty lines
27 | // are dropped and every line after the first is prefixed with "// ". The
28 | // first line gets no prefix, presumably because the template writes it.
29 | func LineComment(s string) string {
30 | 	fields := strings.FieldsFunc(s, func(r rune) bool {
31 | 		return r == '\n'
32 | 	})
33 | 	return strings.Join(fields, "\n// ")
34 | }
35 |
36 | // funcMap is the default set of template functions used by New.
37 | var funcMap = template.FuncMap{
38 | 	"UpperCamel":  UpperCamel,
39 | 	"LowerCamel":  LowerCamel,
40 | 	"Join":        Join,
41 | 	"TrimNewLine": stringx.TrimNewLine,
42 | 	"LineComment": LineComment,
43 | }
44 |
--------------------------------------------------------------------------------
/internal/templatex/funcmap_test.go:
--------------------------------------------------------------------------------
1 | package templatex
2 |
3 | import (
4 | 	"testing"
5 |
6 | 	"github.com/stretchr/testify/assert"
7 | )
8 |
9 | func TestUpperCamel(t *testing.T) {
10 | 	var testData = []struct {
11 | 		input  string
12 | 		expect string
13 | 	}{
14 | 		{input: "", expect: ""},
15 | 		{input: "foo", expect: "Foo"},
16 | 		{input: "foo bar", expect: "FooBar"},
17 | 		{input: "foo_bar", expect: "FooBar"},
18 | 		{input: "foo-bar", expect: "FooBar"},
19 | 		{input: "_foobar", expect: "Foobar"},
20 | 		{input: "_foobar_", expect: "Foobar"},
21 | 	}
22 | 	for _, v := range testData {
23 | 		actual := UpperCamel(v.input)
24 | 		assert.Equal(t, v.expect, actual)
25 | 	}
26 | }
27 |
28 | func TestLowerCamel(t *testing.T) {
29 | 	var testData = []struct {
30 | 		input  string
31 | 		expect string
32 | 	}{
33 | 		{input: "", expect: ""},
34 | 		{input: "foo", expect: "foo"},
35 | 		{input: "Foo bar", expect: "fooBar"},
36 | 		{input: "Foo_bar", expect: "fooBar"},
37 | 		{input: "Foo-bar", expect: "fooBar"},
38 | 		{input: "_foobar", expect: "Foobar"},
39 | 		{input: "_foobar_", expect: "Foobar"},
40 | 		{input: "FooBar", expect: "fooBar"},
41 | 		{input: "Foo_Bar", expect: "fooBar"},
42 | 	}
43 | 	for _, v := range testData {
44 | 		actual := LowerCamel(v.input)
45 | 		assert.Equal(t, v.expect, actual)
46 | 	}
47 | }
48 |
49 | func TestJoin(t *testing.T) {
50 | 	var testData = []struct {
51 | 		input  []string
52 | 		expect string
53 | 	}{
54 | 		{input: []string{}, expect: ""},
55 | 		{input: []string{"foo"}, expect: "foo"},
56 | 		{input: []string{"foo", "bar"}, expect: "foo,bar"},
57 | 	}
58 | 	for _, v := range testData {
59 | 		actual := Join(v.input, ",")
60 | 		assert.Equal(t, v.expect, actual)
61 | 	}
62 | }
63 |
64 | func TestLineComment(t *testing.T) {
65 | 	var testData = []struct {
66 | 		input  string
67 | 		expect string
68 | 	}{
69 | 		{input: "", expect: ""},
70 | 		{input: "foo", expect: "foo"},
71 | 		{input: "foo\nbar", expect: "foo\n// bar"},
72 | 		{input: "foo\nbar\n", expect: "foo\n// bar"},
73 | 		{input: "\nfoo\nbar\n", expect: "foo\n// bar"},
74 | 	}
75 | 	for _, v := range testData {
76 | 		actual := LineComment(v.input)
77 | 		assert.Equal(t, v.expect, actual)
78 | 	}
79 | }
80 |
--------------------------------------------------------------------------------
/internal/templatex/templatex.go:
--------------------------------------------------------------------------------
1 | package templatex
2 |
3 | import (
4 | 	"bytes"
5 | 	"fmt"
6 | 	"io"
7 | 	"io/ioutil"
8 | 	"os"
9 | 	"path/filepath"
10 | 	"strings"
11 | 	"text/template"
12 |
13 | 	"github.com/anqiansong/sqlgen/internal/format"
14 | 	"github.com/anqiansong/sqlgen/internal/log"
15 | )
16 |
17 | const name = "_"
18 |
19 | // T is a template helper.
20 | type T struct {
21 | 	t      *template.Template // the wrapped template
22 | 	buffer *bytes.Buffer      // output of the last MustExecute
23 | 	fm     template.FuncMap   // functions installed by MustParse
24 | }
25 |
26 | // New creates a new template helper.
27 | func New() *T {
28 | 	t := &T{
29 | 		t:      template.New(name),
30 | 		buffer: bytes.NewBuffer(nil),
31 | 		fm:     make(template.FuncMap, len(funcMap)),
32 | 	}
33 | 	t.AppendFuncMap(funcMap) // copy defaults; never mutate the shared funcMap
34 | 	return t
35 | }
36 |
37 | // AppendFuncMap adds fm to the functions used by later MustParse calls.
38 | func (t *T) AppendFuncMap(fm template.FuncMap) {
39 | 	for k, v := range fm {
40 | 		t.fm[k] = v
41 | 	}
42 | }
43 |
44 | // MustParse parses the template.
45 | func (t *T) MustParse(text string) *T {
46 | 	t.t.Funcs(t.fm)
47 | 	_, err := t.t.Parse(text)
48 | 	log.Must(err)
49 | 	return t
50 | }
51 |
52 | // MustExecute executes the template.
53 | func (t *T) MustExecute(data interface{}) *T {
54 | 	t.buffer.Reset()
55 | 	log.Must(t.t.Execute(t.buffer, data))
56 | 	return t
57 | }
58 |
59 | // MustSaveAs saves the template to the given filename, it will overwrite the file if it exists.
60 | func (t *T) MustSaveAs(filename string, formatCode bool) {
61 | 	data := t.buffer.Bytes()
62 | 	if formatCode {
63 | 		formatted, err := format.Source(data)
64 | 		if err != nil {
65 | 			// Best effort: keep the unformatted source beside filename for debugging.
66 | 			ext := filepath.Ext(filename)
67 | 			_ = ioutil.WriteFile(strings.TrimSuffix(filename, ext)+".error"+ext, data, 0644)
68 | 			log.Must(err)
69 | 		}
70 | 		data = formatted
71 | 	}
72 | 	log.Must(ioutil.WriteFile(filename, data, 0666))
73 | }
74 |
75 | // MustSave saves the template to the given filename, it will do nothing if it exists.
76 | func (t *T) MustSave(filename string, formatCode bool) {
77 | 	// Any Stat error, not only "not exist", falls through to MustSaveAs,
78 | 	// which reports genuine I/O failures through log.Must.
79 | 	if _, err := os.Stat(filename); err != nil {
80 | 		t.MustSaveAs(filename, formatCode)
81 | 	}
82 | }
83 |
84 | // Write writes the rendered template to writer, running it through
85 | // format.Source first when formatCode is true. Formatting and write errors
86 | // are reported through log.Must.
87 | func (t *T) Write(writer io.Writer, formatCode bool) {
88 | 	data := t.buffer.Bytes()
89 | 	if formatCode {
90 | 		formatted, err := format.Source(data)
91 | 		if err != nil {
92 | 			// Dump the unformatted source to help locate the error.
93 | 			fmt.Printf("%+v\n", t.buffer.String())
94 | 			log.Must(err)
95 | 		}
96 | 		data = formatted
97 | 	}
98 | 	_, err := writer.Write(data)
99 | 	log.Must(err)
100 | }
101 |
--------------------------------------------------------------------------------
/internal/templatex/templatex_test.go:
--------------------------------------------------------------------------------
1 | package templatex
2 |
3 | import (
4 | 	"io/ioutil"
5 | 	"os"
6 | 	"path/filepath"
7 | 	"testing"
8 | 	"text/template"
9 |
10 | 	"github.com/stretchr/testify/assert"
11 | )
12 |
13 | func TestNew(t *testing.T) {
14 | 	instance := New()
15 | 	assert.NotNil(t, instance)
16 | }
17 |
18 | func TestT_AppendFuncMap(t *testing.T) {
19 | 	instance := New()
20 | 	fn := func() string {
21 | 		return "any"
22 | 	}
23 | 	list := template.FuncMap{
24 | 		"foo": fn,
25 | 		"bar": fn,
26 | 		"baz": fn,
27 | 	}
28 | 	instance.AppendFuncMap(list)
29 | 	for k := range list {
30 | 		_, ok := instance.fm[k]
31 | 		assert.True(t, ok)
32 | 	}
33 | }
34 |
35 | func TestT_MustParse(t *testing.T) {
36 | 	instance := New()
37 | 	instance.AppendFuncMap(template.FuncMap{"foo": func() string { return "bar" }})
38 | 	ret := instance.MustParse("{{foo}}")
39 | 	assert.Equal(t, instance, ret)
40 | }
41 |
42 | func TestT_MustExecute(t *testing.T) {
43 | 	instance := New()
44 | 	instance.MustParse("{{.foo}}")
45 | 	ret := instance.MustExecute(map[string]string{
46 | 		"foo": "bar",
47 | 	})
48 | 	assert.Equal(t, instance, ret)
49 | 	assert.Equal(t, "bar", instance.buffer.String())
50 | }
51 |
52 | func TestT_MustSaveAs(t *testing.T) {
53 | 	t.Run("format_false", func(t *testing.T) {
54 | 		instance := New()
55 | 		instance.MustParse("{{.foo}}")
56 | 		instance.MustExecute(map[string]string{
57 | 			"foo": "bar",
58 | 		})
59 | 		tempFile := filepath.Join(t.TempDir(), "foo")
60 | 		instance.MustSaveAs(tempFile, false)
61 | 		data, err := ioutil.ReadFile(tempFile)
62 | 		assert.NoError(t, err)
63 | 		assert.Equal(t, "bar", string(data))
64 | 	})
65 |
66 | 	t.Run("format_true", func(t *testing.T) {
67 | 		instance := New()
68 | 		instance.MustParse(" package {{.foo}}")
69 | 		instance.MustExecute(map[string]string{
70 | 			"foo": "bar",
71 | 		})
72 | 		tempFile := filepath.Join(t.TempDir(), "foo")
73 | 		instance.MustSaveAs(tempFile, true)
74 | 		data, err := ioutil.ReadFile(tempFile)
75 | 		assert.NoError(t, err)
76 | 		assert.Equal(t, "package bar\n", string(data))
77 | 	})
78 | }
79 |
80 | func TestT_MustSave(t *testing.T) {
81 | 	instance := New()
82 | 	instance.MustParse("{{.foo}}")
83 | 	instance.MustExecute(map[string]string{
84 | 		"foo": "bar",
85 | 	})
86 | 	tempFile := filepath.Join(t.TempDir(), "foo")
87 | 	instance.MustSave(tempFile, false)
88 | 	data, err := ioutil.ReadFile(tempFile)
89 | 	assert.NoError(t, err)
90 | 	assert.Equal(t, "bar", string(data))
91 |
92 | 	instance.MustExecute(map[string]string{
93 | 		"foo": "baz",
94 | 	})
95 | 	instance.MustSave(tempFile, false)
96 | 	data, err = ioutil.ReadFile(tempFile)
97 | 	assert.NoError(t, err)
98 | 	assert.Equal(t, "bar", string(data))
99 | }
100 |
101 | func TestT_Write(t *testing.T) {
102 | 	t.Run("format_false", func(t *testing.T) {
103 | 		instance := New()
104 | 		instance.MustParse("{{.foo}}")
105 | 		instance.MustExecute(map[string]string{
106 | 			"foo": "bar",
107 | 		})
108 | 		tempFile := filepath.Join(t.TempDir(), "foo")
109 | 		file, err := os.OpenFile(tempFile, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0600)
110 | 		assert.NoError(t, err)
111 | 		t.Cleanup(func() {
112 | 			file.Close()
113 | 		})
114 |
115 | 		instance.Write(file, false)
116 | 		data, err := ioutil.ReadFile(tempFile)
117 | 		assert.NoError(t, err)
118 | 		assert.Equal(t, "bar", string(data))
119 | 	})
120 |
121 | 	t.Run("format_true", func(t *testing.T) {
122 | 		instance := New()
123 | 		instance.MustParse(" package {{.foo}}")
124 | 		instance.MustExecute(map[string]string{
125 | 			"foo": "bar",
126 | 		})
127 | 		tempFile := filepath.Join(t.TempDir(), "foo")
128 | 		file, err := os.OpenFile(tempFile, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0600)
129 | 		assert.NoError(t, err)
130 | 		t.Cleanup(func() {
131 | 			file.Close()
132 | 		})
133 |
134 | 		instance.Write(file, true)
135 | 		data, err := ioutil.ReadFile(tempFile)
136 | 		assert.NoError(t, err)
137 | 		assert.Equal(t, "package bar\n", string(data))
138 | 	})
139 | }
140 |
--------------------------------------------------------------------------------
/main.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import "github.com/anqiansong/sqlgen/cmd"
4 |
5 | func main() {
6 | 	cmd.Execute()
7 | }
8 |
--------------------------------------------------------------------------------