├── .github └── workflows │ └── codeql.yml ├── .gitignore ├── LICENSE ├── README.md ├── coll_stack.go ├── config.go ├── config_builder.go ├── config_db.go ├── config_db_test.go ├── config_mapper.go ├── config_mapper_test.go ├── db.yml ├── examples ├── examples.go ├── mapper │ └── userMapper.xml └── sql │ └── user.sql ├── executor.go ├── expr.go ├── expr_test.go ├── go.mod ├── go.sum ├── gobatis.dtd ├── gobatis.go ├── gobatis_db.go ├── gobatis_test.go ├── logger.go ├── mapper.go ├── nilable_structs.go ├── option.go ├── option_db.go ├── option_ds.go ├── option_file.go ├── params_test.go ├── parser_xml.go ├── parser_xml_test.go ├── proc_params.go ├── proc_params_test.go ├── proc_res.go ├── sql_source.go ├── sql_source_test.go ├── util.go ├── util_builder.go ├── val.go ├── xmltag.go └── xmltag_test.go /.github/workflows/codeql.yml: -------------------------------------------------------------------------------- 1 | # For most projects, this workflow file will not need changing; you simply need 2 | # to commit it to your repository. 3 | # 4 | # You may wish to alter this file to override the set of languages analyzed, 5 | # or to provide custom queries or build logic. 6 | # 7 | # ******** NOTE ******** 8 | # We have attempted to detect the languages in your repository. Please check 9 | # the `language` matrix defined below to confirm you have the correct set of 10 | # supported CodeQL languages. 
11 | # 12 | name: "CodeQL" 13 | 14 | on: 15 | push: 16 | branches: [ "master" ] 17 | pull_request: 18 | # The branches below must be a subset of the branches above 19 | branches: [ "master" ] 20 | schedule: 21 | - cron: '45 6 * * 3' 22 | 23 | jobs: 24 | analyze: 25 | name: Analyze 26 | runs-on: ubuntu-latest 27 | permissions: 28 | actions: read 29 | contents: read 30 | security-events: write 31 | 32 | strategy: 33 | fail-fast: false 34 | matrix: 35 | language: [ 'go' ] 36 | # CodeQL supports [ 'cpp', 'csharp', 'go', 'java', 'javascript', 'python', 'ruby' ] 37 | # Learn more about CodeQL language support at https://aka.ms/codeql-docs/language-support 38 | 39 | steps: 40 | - name: Checkout repository 41 | uses: actions/checkout@v3 42 | 43 | # Initializes the CodeQL tools for scanning. 44 | - name: Initialize CodeQL 45 | uses: github/codeql-action/init@v2 46 | with: 47 | languages: ${{ matrix.language }} 48 | # If you wish to specify custom queries, you can do so here or in a config file. 49 | # By default, queries listed here will override any specified in a config file. 50 | # Prefix the list here with "+" to use these queries and those in the config file. 51 | 52 | # Details on CodeQL's query packs refer to : https://docs.github.com/en/code-security/code-scanning/automatically-scanning-your-code-for-vulnerabilities-and-errors/configuring-code-scanning#using-queries-in-ql-packs 53 | # queries: security-extended,security-and-quality 54 | 55 | 56 | # Autobuild attempts to build any compiled languages (C/C++, C#, or Java). 57 | # If this step fails, then you should remove it and run the build manually (see below) 58 | - name: Autobuild 59 | uses: github/codeql-action/autobuild@v2 60 | 61 | # ℹ️ Command-line programs to run using the OS shell. 62 | # 📚 See https://docs.github.com/en/actions/using-workflows/workflow-syntax-for-github-actions#jobsjob_idstepsrun 63 | 64 | # If the Autobuild fails above, remove it and uncomment the following three lines. 
65 | # modify them (or add more) to build your code if your project, please refer to the EXAMPLE below for guidance. 66 | 67 | # - run: | 68 | # echo "Run, Build Application using script" 69 | # ./location_of_script_within_repo/buildscript.sh 70 | 71 | - name: Perform CodeQL Analysis 72 | uses: github/codeql-action/analyze@v2 73 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Created by .ignore support plugin (hsz.mobi) 2 | ### Go template 3 | # Binaries for programs and plugins 4 | *.exe 5 | *.exe~ 6 | *.dll 7 | *.so 8 | *.dylib 9 | 10 | # Test binary, build with `go test -c` 11 | *.test 12 | 13 | # Output of the go coverage tool, specifically when used with LiteIDE 14 | *.out 15 | 16 | .DS_Store 17 | .idea/ 18 | pkg/ 19 | /src/ 20 | /.vscode 21 | /*.code-workspace 22 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. 
For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # gobatis 2 | 3 | [![CodeQL](https://github.com/wenj91/gobatis/actions/workflows/codeql.yml/badge.svg?branch=master)](https://github.com/wenj91/gobatis/actions/workflows/codeql.yml) 4 | 5 | 目前代码都是基于类mysql数据库编写测试的,其他数据库暂时还未做兼容处理 6 | 7 | - [x] 支持数据库 8 | - [x] mysql 9 | - [x] tidb 10 | - [x] mariadb 11 | - [ ] postgres 12 | - [ ] sqlite 13 | - [x] 基础操作 14 | - [x] query 15 | - [x] insert 16 | - [x] update 17 | - [x] delete 18 | 19 | ## ToDo 20 | 21 | - 增加更多易用表达式指令,目前已有`$blank`指令用于判别字符串是否为空的指令,比如判断name为空串: test="$blank(name)" 22 | 23 | ## 模板代码生成 24 | 25 | 提供了简单的增删改查代码自动生成 26 | 27 | 具体操作看仓库: [https://github.com/wenj91/mctl.git](https://github.com/wenj91/mctl.git) 28 | 29 | ## gobatis接口 30 | 31 | ```go 32 | type GoBatis interface { 33 | // Select 查询数据 34 | Select(stmt string, param interface{}, rowBound ...*rowBounds) func(res interface{}) (int64, error) 35 | // SelectContext 查询数据with context 36 | SelectContext(ctx context.Context, stmt string, param 
interface{}, rowBound ...*rowBounds) func(res interface{}) (int64, error) 37 | // Insert 插入数据 38 | Insert(stmt string, param interface{}) (lastInsertId int64, affected int64, err error) 39 | // InsertContext 插入数据with context 40 | InsertContext(ctx context.Context, stmt string, param interface{}) (lastInsertId int64, affected int64, err error) 41 | // Update 更新数据 42 | Update(stmt string, param interface{}) (affected int64, err error) 43 | // UpdateContext 更新数据with context 44 | UpdateContext(ctx context.Context, stmt string, param interface{}) (affected int64, err error) 45 | // Delete 删除数据 46 | Delete(stmt string, param interface{}) (affected int64, err error) 47 | // DeleteContext 删除数据with context 48 | DeleteContext(ctx context.Context, stmt string, param interface{}) (affected int64, err error) 49 | } 50 | ``` 51 | 52 | ## db数据源配置 53 | - 支持多数据源配置 54 | - db子级配置为一个列表,列表中每一项为一个数据源配置,其中datasource字段为数据源名称标识(必填,且各数据源之间不可重复) 55 | - 每个数据源的具体配置项如下表 56 | 57 | | 配置 | 是否必填配置 | 默认值 | 说明 | 58 | |:---|:----:|:----:|----| 59 | | driverName | 是 | | 数据源驱动名,必填配置项 60 | | dataSourceName | 是 | | 数据源连接串(DSN),必填配置项,例如: root:123456@tcp(127.0.0.1:3306)/test?charset=utf8 61 | | maxLifeTime | 否 | 120(单位: s)| 连接最大存活时间,默认值为: 120 单位为: s 62 | | maxOpenConns | 否 | 10 | 最大打开连接数,默认值为: 10 63 | | maxIdleConns | 否 | 5 | 最大空闲连接数,默认值为: 5 64 | 65 | ### 示例 66 | * db配置示例(配置较之前的有所调整) 67 | 以下为多数据源配置示例: db.yml 68 | ```yaml 69 | # 数据库配置 70 | db: 71 | # 数据源名称1 72 | - datasource: ds1 73 | # 驱动名 74 | driverName: mysql 75 | # 数据源连接串(DSN) 76 | dataSourceName: root:123456@tcp(127.0.0.1:3306)/test?charset=utf8 77 | # 连接最大存活时间(单位: s) 78 | maxLifeTime: 120 79 | # 最大open连接数 80 | maxOpenConns: 10 81 | # 最大空闲连接数 82 | maxIdleConns: 5 83 | # 数据源名称2 84 | - datasource: ds2 85 | # 驱动名 86 | driverName: mysql 87 | # 数据源连接串(DSN) 88 | dataSourceName: root:123456@tcp(127.0.0.1:3306)/test?charset=utf8 89 | # 连接最大存活时间(单位: s) 90 | maxLifeTime: 120 91 | # 最大open连接数 92 | maxOpenConns: 10 93 | # 最大空闲连接数 94 | maxIdleConns: 5 95 | # 是否显示SQL语句 96 | showSql: true 97 | # 数据表映射文件路径配置 98 |
mappers: 99 | # 映射文件路径, 可以为绝对路径,如: /usr/local/mapper/userMapper.xml 100 | - mapper/userMapper.xml 101 | ``` 102 | 103 | * mapper配置 104 | 1. mapper可以配置namespace属性 105 | 1. mapper可以包含: select, insert, update, delete标签 106 | 1. mapper子标签id属性则为标签唯一标识, 必须配置属性 107 | 1. 其中select标签必须包含resultType属性,resultType可以是: map, maps, array, arrays, struct, structs, value 108 | 109 | * 标签说明 110 | select: 用于查询操作 111 | insert: 用于插入sql操作 112 | update: 用于更新sql操作 113 | delete: 用于删除sql操作 114 | 115 | * resultType说明 116 | map: 则数据库查询结果为map 117 | maps: 则数据库查询结果为map数组 118 | array: 则数据库查询结果为值数组 119 | arrays: 则数据库查询结果为多个值数组 120 | struct: 则数据库查询结果为单个结构体 121 | structs: 则数据库查询结果为结构体数组 122 | value: 则数据库查询结果为单个数值 123 | 124 | 以下是mapper配置示例: mapper/userMapper.xml 125 | ```xml 126 | 127 | 129 | 130 | 131 | id, name, crtTm, pwd, email 132 | 133 | 139 | 142 | 145 | 148 | 151 | 154 | 155 | insert into user (name, email, crtTm) 156 | values (#{Name}, #{Email}, #{CrtTm}) 157 | 158 | 159 | delete from user where id=#{id} 160 | 161 | 168 | 175 | 176 | update user 177 | 178 | name = #{Name}, 179 | 180 | where id = #{Id} 181 | 182 | 183 | ``` 184 | 185 | ## 使用方法 186 | 187 | ### 使用配置文件配置 188 | example1.go 189 | ```go 190 | package main 191 | 192 | import ( 193 | "encoding/json" 194 | "fmt" 195 | _ "github.com/go-sql-driver/mysql" // 引入驱动 196 | "github.com/wenj91/gobatis" // 引入gobatis 197 | ) 198 | 199 | // 实体结构示例, tag:field为数据库对应字段名称 200 | type User struct { 201 | Id gobatis.NullInt64 `field:"id"` 202 | Name gobatis.NullString `field:"name"` 203 | Email gobatis.NullString `field:"email"` 204 | CrtTm gobatis.NullTime `field:"crtTm"` 205 | } 206 | 207 | // User to string 208 | func (u *User) String() string { 209 | bs, _ := json.Marshal(u) 210 | return string(bs) 211 | } 212 | 213 | func main() { 214 | // 初始化db,参数为db.yml路径,如:db.yml 215 | gobatis.Init(gobatis.NewFileOption("db.yml")) 216 | 217 | // 获取数据源,参数为数据源名称,如:ds1 218 | gb := gobatis.Get("ds1") 219 | 220 | //传入id查询Map 221 | mapRes := make(map[string]interface{}) 222 |
// stmt标识为:namespace + '.' + id, 如:userMapper.findMapById 223 | // 查询参数可以是map,也可以是数组,也可以是实体结构 224 | _, err := gb.Select("userMapper.findMapById", map[string]interface{}{"id": 1})(mapRes) 225 | fmt.Println("userMapper.findMapById-->", mapRes, err) 226 | 227 | // 根据传入实体查询对象 228 | param := User{Id: gobatis.NullInt64{Int64: 1, Valid: true}} 229 | var structRes *User 230 | _, err = gb.Select("userMapper.findStructByStruct", param)(&structRes) 231 | fmt.Println("userMapper.findStructByStruct-->", structRes, err) 232 | 233 | // 查询实体列表 234 | structsRes := make([]*User, 0) 235 | _, err = gb.Select("userMapper.queryStructs", map[string]interface{}{})(&structsRes) 236 | fmt.Println("userMapper.queryStructs-->", structsRes, err) 237 | 238 | param = User{ 239 | Id: gobatis.NullInt64{Int64: 1, Valid: true}, 240 | Name: gobatis.NullString{String: "wenj1993", Valid: true}, 241 | } 242 | 243 | // set tag 244 | affected, err := gb.Update("userMapper.updateByCond", param) 245 | fmt.Println("updateByCond:", affected, err) 246 | 247 | param = User{Name: gobatis.NullString{String: "wenj1993", Valid: true}} 248 | // where tag 249 | res := make([]*User, 0) 250 | _, err = gb.Select("userMapper.queryStructsByCond", param)(&res) 251 | fmt.Println("queryStructsByCond", res, err) 252 | 253 | // trim tag 254 | res = make([]*User, 0) 255 | _, err = gb.Select("userMapper.queryStructsByCond2", param)(&res) 256 | fmt.Println("queryStructsByCond2", res, err) 257 | 258 | // include tag 259 | ms := make([]map[string]interface{}, 0) 260 | _, err = gb.Select("userMapper.findIncludeMaps", nil)(&ms) 261 | fmt.Println("userMapper.findIncludeMaps-->", ms, err) 262 | 263 | // ${id} 264 | res = make([]*User, 0) 265 | _, err = gb.Select("userMapper.queryStructsByOrder", map[string]interface{}{ 266 | "id":"id", 267 | })(&res) 268 | fmt.Println("queryStructsByOrder", res, err) 269 | 270 | // ${id} with count, 传入RowBounds(0, 100)即可返回count总数 271 | res = make([]*User, 0) 272 | cnt, err :=
gb.Select("userMapper.queryStructsByOrder", map[string]interface{}{ 273 | "id":"id", 274 | }, gobatis.RowBounds(0, 100))(&res) 275 | fmt.Println("queryStructsByOrder", cnt, res, err) 276 | 277 | 278 | // 开启事务示例 279 | tx, _ := gb.Begin() 280 | defer tx.Rollback() 281 | _, err = tx.Select("userMapper.findMapById", map[string]interface{}{"id": 1})(mapRes) 282 | fmt.Println("tx userMapper.findMapById-->", mapRes, err) 283 | tx.Commit() 284 | } 285 | ``` 286 | 287 | ### 代码配置方式 288 | 289 | example2.go 290 | 291 | ```go 292 | package main 293 | 294 | import ( 295 | "fmt" 296 | _ "github.com/go-sql-driver/mysql" // 引入驱动 297 | "github.com/wenj91/gobatis" // 引入gobatis 298 | ) 299 | 300 | // 实体结构示例, tag:field为数据库对应字段名称 301 | type User struct { 302 | Id gobatis.NullInt64 `field:"id"` 303 | Name gobatis.NullString `field:"name"` 304 | Email gobatis.NullString `field:"email"` 305 | CrtTm gobatis.NullTime `field:"crtTm"` 306 | } 307 | 308 | func main() { 309 | // 初始化db 310 | ds1 := gobatis.NewDataSourceBuilder(). 311 | DataSource("ds1"). 312 | DriverName("mysql"). 313 | DataSourceName("root:123456@tcp(127.0.0.1:3306)/test?charset=utf8"). 314 | MaxLifeTime(120). 315 | MaxOpenConns(10). 316 | MaxIdleConns(5). 317 | Build() 318 | 319 | option := gobatis.NewDSOption(). 320 | DS([]*gobatis.DataSource{ds1}). 321 | Mappers([]string{"examples/mapper/userMapper.xml"}). 322 | ShowSQL(true) 323 | 324 | gobatis.Init(option) 325 | 326 | // 获取数据源,参数为数据源名称,如:ds1 327 | gb := gobatis.Get("ds1") 328 | 329 | //传入id查询Map 330 | mapRes := make(map[string]interface{}) 331 | // stmt标识为:namespace + '.'
+ id, 如:userMapper.findMapById 332 | // 查询参数可以是map,也可以是数组,也可以是实体结构 333 | _, err := gb.Select("userMapper.findMapById", map[string]interface{}{"id": 1})(mapRes) 334 | fmt.Println("userMapper.findMapById-->", mapRes, err) 335 | } 336 | ``` 337 | 338 | example3.go 339 | 340 | ```go 341 | package main 342 | 343 | import ( 344 | "database/sql" 345 | "fmt" 346 | _ "github.com/go-sql-driver/mysql" // 引入驱动 347 | "github.com/wenj91/gobatis" // 引入gobatis 348 | ) 349 | 350 | // 实体结构示例, tag:field为数据库对应字段名称 351 | type User struct { 352 | Id gobatis.NullInt64 `field:"id"` 353 | Name gobatis.NullString `field:"name"` 354 | Email gobatis.NullString `field:"email"` 355 | CrtTm gobatis.NullTime `field:"crtTm"` 356 | } 357 | 358 | func main() { 359 | // 初始化db 360 | db, _ := sql.Open("mysql", "root:123456@tcp(127.0.0.1:3306)/test?charset=utf8") 361 | dbs := make(map[string]*gobatis.GoBatisDB) 362 | dbs["ds1"] = gobatis.NewGoBatisDB(gobatis.DBTypeMySQL, db) 363 | 364 | option := gobatis.NewDBOption(). 365 | DB(dbs). 366 | ShowSQL(true). 367 | Mappers([]string{"examples/mapper/userMapper.xml"}) 368 | 369 | gobatis.Init(option) 370 | 371 | // 获取数据源,参数为数据源名称,如:ds1 372 | gb := gobatis.Get("ds1") 373 | 374 | //传入id查询Map 375 | mapRes := make(map[string]interface{}) 376 | // stmt标识为:namespace + '.' + id, 如:userMapper.findMapById 377 | // 查询参数可以是map,也可以是数组,也可以是实体结构 378 | _, err := gb.Select("userMapper.findMapById", map[string]interface{}{"id": 1})(mapRes) 379 | fmt.Println("userMapper.findMapById-->", mapRes, err) 380 | } 381 | ``` 382 | 383 | ## 致谢 384 | 385 | 感谢jetbrains提供的goland! 
386 | -------------------------------------------------------------------------------- /coll_stack.go: -------------------------------------------------------------------------------- 1 | package gobatis 2 | 3 | import ( 4 | l "container/list" 5 | "sync" 6 | ) 7 | 8 | type stack struct { 9 | list *l.List 10 | mu sync.Mutex 11 | } 12 | 13 | func NewStack() *stack { 14 | list := l.New() 15 | return &stack{list: list,} 16 | } 17 | 18 | func (s *stack) Push(t interface{}){ 19 | s.mu.Lock() 20 | defer s.mu.Unlock() 21 | s.list.PushFront(t) 22 | } 23 | 24 | func (s *stack) Pop() interface{} { 25 | s.mu.Lock() 26 | defer s.mu.Unlock() 27 | ele := s.list.Front() 28 | if nil != ele { 29 | s.list.Remove(ele) 30 | return ele.Value 31 | } 32 | 33 | return nil 34 | } 35 | 36 | func (s *stack) Peak() interface{} { 37 | s.mu.Lock() 38 | defer s.mu.Unlock() 39 | ele := s.list.Front() 40 | return ele.Value 41 | } 42 | 43 | func (s *stack) Len() int { 44 | return s.list.Len() 45 | } 46 | 47 | func (s *stack) IsEmpty() bool { 48 | return s.list.Len() == 0 49 | } 50 | -------------------------------------------------------------------------------- /config.go: -------------------------------------------------------------------------------- 1 | package gobatis 2 | 3 | import ( 4 | "database/sql" 5 | "os" 6 | "strings" 7 | "time" 8 | ) 9 | 10 | type config struct { 11 | dbConf *DBConfig 12 | mapperConf *mapperConfig 13 | } 14 | 15 | var conf *config 16 | var db map[string]*GoBatisDB 17 | 18 | func Init(option IOption) { 19 | configInit(option.ToDBConf()) 20 | } 21 | 22 | func configInit(dbConf *DBConfig) { 23 | if nil == dbConf { 24 | panic("Build db config err: dbConf == nil") 25 | } 26 | 27 | if len(dbConf.DB) <= 0 && dbConf.db == nil { 28 | panic("No datasource config") 29 | } 30 | 31 | mapperConf := &mapperConfig{ 32 | mappedStmts: make(map[string]*node), 33 | mappedSql: make(map[string]*node), 34 | } 35 | 36 | for _, item := range dbConf.Mappers { 37 | f, err := os.Open(item) 38 | 
if nil != err { 39 | panic("Open mapper config: " + item + " err:" + err.Error()) 40 | } 41 | 42 | LOG.Info("mapper config:%s %s", item, "init...") 43 | mc := buildMapperConfig(f) 44 | for k, ms := range mc.mappedStmts { 45 | mapperConf.put(k, ms) 46 | } 47 | 48 | // sql tag cache 49 | for k, ms := range mc.mappedSql { 50 | mapperConf.putSql(k, ms) 51 | } 52 | } 53 | 54 | conf = &config{ 55 | dbConf: dbConf, 56 | mapperConf: mapperConf, 57 | } 58 | 59 | // init db 60 | dbInit(dbConf) 61 | } 62 | 63 | func dbInit(dbConf *DBConfig) { 64 | db = make(map[string]*GoBatisDB) 65 | if len(dbConf.DB) <= 0 && dbConf.db == nil { 66 | panic("No config for datasource") 67 | } 68 | 69 | for _, item := range dbConf.DB { 70 | if item.DataSource == "" { 71 | panic("DB config err: datasource must not be nil") 72 | } 73 | 74 | item.DataSource = strings.TrimSpace(item.DataSource) 75 | 76 | _, ok := db[item.DataSource] 77 | if ok { 78 | panic("DB config datasource name repeat:" + item.DataSource) 79 | } 80 | 81 | if item.DriverName == "" { 82 | panic("DB config err: driverName must not be nil") 83 | } 84 | 85 | if item.DataSourceName == "" { 86 | panic("DB config err: dataSourceName must not be nil") 87 | } 88 | 89 | dbConn, err := sql.Open(item.DriverName, item.DataSourceName) 90 | if nil != err { 91 | panic(err) 92 | } 93 | 94 | if err := dbConn.Ping(); err != nil { 95 | panic(err) 96 | } 97 | 98 | if item.MaxLifeTime == 0 { 99 | dbConn.SetConnMaxLifetime(120 * time.Second) 100 | } else { 101 | dbConn.SetConnMaxLifetime(time.Duration(item.MaxLifeTime) * time.Second) 102 | } 103 | 104 | if item.MaxOpenConns == 0 { 105 | dbConn.SetMaxOpenConns(10) 106 | } else { 107 | dbConn.SetMaxOpenConns(item.MaxOpenConns) 108 | } 109 | 110 | if item.MaxOpenConns == 0 { 111 | dbConn.SetMaxIdleConns(5) 112 | } else { 113 | dbConn.SetMaxIdleConns(item.MaxIdleConns) 114 | } 115 | 116 | d := NewGoBatisDB(DBType(item.DriverName), dbConn) 117 | db[item.DataSource] = d 118 | } 119 | 120 | if dbConf.db != 
nil { 121 | for k, v := range dbConf.db { 122 | _, ok := db[k] 123 | if ok { 124 | panic("DB config datasource name repeat:" + k) 125 | } 126 | db[k] = v 127 | } 128 | } 129 | } 130 | -------------------------------------------------------------------------------- /config_builder.go: -------------------------------------------------------------------------------- 1 | package gobatis 2 | 3 | type DataSourceBuilder struct { 4 | ds *DataSource 5 | } 6 | 7 | func NewDataSourceBuilder() *DataSourceBuilder { 8 | return &DataSourceBuilder{ 9 | ds: &DataSource{}, 10 | } 11 | } 12 | 13 | // DataSource 14 | func (d *DataSourceBuilder) DataSource(ds string) *DataSourceBuilder { 15 | d.ds.DataSource = ds 16 | return d 17 | } 18 | 19 | // DriverName 20 | func (d *DataSourceBuilder) DriverName(dn string) *DataSourceBuilder { 21 | d.ds.DriverName = dn 22 | return d 23 | } 24 | 25 | // DataSourceName 26 | func (d *DataSourceBuilder) DataSourceName(dsn string) *DataSourceBuilder { 27 | d.ds.DataSourceName = dsn 28 | return d 29 | } 30 | 31 | // MaxLifeTime 32 | func (d *DataSourceBuilder) MaxLifeTime(mlt int) *DataSourceBuilder { 33 | d.ds.MaxLifeTime = mlt 34 | return d 35 | } 36 | 37 | // MaxOpenConns 38 | func (d *DataSourceBuilder) MaxOpenConns(moc int) *DataSourceBuilder { 39 | d.ds.MaxOpenConns = moc 40 | return d 41 | } 42 | 43 | // MaxIdleConns 44 | func (d *DataSourceBuilder) MaxIdleConns(mic int) *DataSourceBuilder { 45 | d.ds.MaxIdleConns = mic 46 | return d 47 | } 48 | 49 | func (d *DataSourceBuilder) Build() *DataSource { 50 | if d.ds.DataSource == "" { 51 | panic("DataSource is nil") 52 | } 53 | 54 | if d.ds.DataSourceName == "" { 55 | panic("DataSourceName is nil") 56 | } 57 | 58 | if d.ds.DriverName == "" { 59 | panic("DriverName is nil") 60 | } 61 | 62 | return d.ds 63 | } 64 | 65 | type DBConfigBuilder struct { 66 | d *DBConfig 67 | } 68 | 69 | func NewDBConfigBuilder() *DBConfigBuilder { 70 | return &DBConfigBuilder{ 71 | d: &DBConfig{ 72 | DB: 
make([]*DataSource, 0), 73 | }, 74 | } 75 | } 76 | 77 | func (d *DBConfigBuilder) Mappers(mappers []string) *DBConfigBuilder { 78 | d.d.Mappers = mappers 79 | return d 80 | } 81 | 82 | func (d *DBConfigBuilder) DS(dss []*DataSource) *DBConfigBuilder { 83 | d.d.DB = dss 84 | return d 85 | } 86 | 87 | func (d *DBConfigBuilder) DB(db map[string]*GoBatisDB) *DBConfigBuilder { 88 | d.d.db = db 89 | return d 90 | } 91 | 92 | func (d *DBConfigBuilder) ShowSQL(showSQL bool) *DBConfigBuilder { 93 | d.d.ShowSQL = showSQL 94 | return d 95 | } 96 | 97 | func (d *DBConfigBuilder) Build() *DBConfig { 98 | if len(d.d.DB) <= 0 && d.d.db == nil { 99 | panic("No config for datasource") 100 | } 101 | 102 | return d.d 103 | } 104 | -------------------------------------------------------------------------------- /config_db.go: -------------------------------------------------------------------------------- 1 | package gobatis 2 | 3 | type DataSource struct { 4 | DataSource string `yaml:"datasource"` 5 | DriverName string `yaml:"driverName"` 6 | DataSourceName string `yaml:"dataSourceName"` 7 | MaxLifeTime int `yaml:"maxLifeTime"` 8 | MaxOpenConns int `yaml:"maxOpenConns"` 9 | MaxIdleConns int `yaml:"maxIdleConns"` 10 | } 11 | 12 | // NewDataSource new data source 13 | func NewDataSource(dataSource string, driverName string, dataSourceName string) *DataSource { 14 | return &DataSource{ 15 | DataSource: dataSource, 16 | DriverName: driverName, 17 | DataSourceName: dataSourceName, 18 | } 19 | } 20 | 21 | // NewDataSource_ new data source 22 | func NewDataSource_(dataSource string, driverName string, dataSourceName string, 23 | maxLifeTime int, maxOpenConns int, maxIdleConns int) *DataSource { 24 | return &DataSource{ 25 | DataSource: dataSource, 26 | DriverName: driverName, 27 | DataSourceName: dataSourceName, 28 | MaxLifeTime: maxLifeTime, 29 | MaxOpenConns: maxOpenConns, 30 | MaxIdleConns: maxIdleConns, 31 | } 32 | } 33 | 34 | type DBConfig struct { 35 | DB []*DataSource `yaml:"db"` 36 
| ShowSQL bool `yaml:"showSQL"` 37 | Mappers []string `yaml:"mappers"` 38 | db map[string]*GoBatisDB 39 | dbMap map[string]*DataSource 40 | } 41 | 42 | func (this *DBConfig) getDataSourceByName(datasource string) *DataSource { 43 | if this.dbMap == nil { 44 | this.dbMap = make(map[string]*DataSource) 45 | } 46 | 47 | if v, ok := this.dbMap[datasource]; ok { 48 | return v 49 | } 50 | 51 | for _, v := range this.DB { 52 | if v.DataSource == datasource { 53 | this.dbMap[datasource] = v 54 | return v 55 | } 56 | } 57 | 58 | return nil 59 | } 60 | -------------------------------------------------------------------------------- /config_db_test.go: -------------------------------------------------------------------------------- 1 | package gobatis 2 | 3 | import ( 4 | "fmt" 5 | "github.com/stretchr/testify/assert" 6 | "testing" 7 | ) 8 | 9 | func TestDbConfig(t *testing.T) { 10 | ymlStr := ` 11 | db: 12 | - datasource: ds1 13 | driverName: mysql 14 | dataSourceName: root:123456@tcp(127.0.0.1:3306)/test?charset=utf8 15 | maxLifeTime: 120 16 | maxOpenConns: 10 17 | maxIdleConns: 5 18 | - datasource: ds2 19 | driverName: mysql 20 | dataSourceName: root:123456@tcp(127.0.0.1:3306)/test?charset=utf8 21 | maxLifeTime: 120 22 | maxOpenConns: 10 23 | maxIdleConns: 5 24 | showSQL: true 25 | mappers: 26 | - userMapper.xml 27 | - orderMapper.xml 28 | ` 29 | dbconf := buildDbConfig(ymlStr) 30 | 31 | dbc := dbconf.getDataSourceByName("ds1") 32 | assert.True(t, dbc != nil, "test fail: No datasource1") 33 | assert.True(t, dbconf.ShowSQL, "test fail: showSql == false") 34 | assert.Equal(t, dbc.DriverName, "mysql", "test fail, actual:"+dbc.DriverName) 35 | assert.Equal(t, dbc.DataSourceName, "root:123456@tcp(127.0.0.1:3306)/test?charset=utf8", "test fail, actual:"+dbc.DataSourceName) 36 | assert.Equal(t, dbc.MaxLifeTime, 120, "test fail, actual:"+fmt.Sprintf("%d", dbc.MaxLifeTime)) 37 | assert.Equal(t, dbc.MaxOpenConns, 10, "test fail, actual:"+fmt.Sprintf("%d", dbc.MaxOpenConns)) 38 | 
assert.Equal(t, dbc.MaxIdleConns, 5, "test fail, actual:"+fmt.Sprintf("%d", dbc.MaxIdleConns)) 39 | assert.True(t, len(dbconf.Mappers) == 2, "len(dbconf.Mappers) != 2") 40 | assert.Equal(t, dbconf.Mappers[0], "userMapper.xml", "test fail, actual:"+dbconf.Mappers[0]) 41 | assert.Equal(t, dbconf.Mappers[1], "orderMapper.xml", "test fail, actual:"+dbconf.Mappers[1]) 42 | } 43 | 44 | func TestDbConfigCodeInit(t *testing.T) { 45 | ds1 := NewDataSourceBuilder(). 46 | DataSource("ds1"). 47 | DriverName("mysql"). 48 | DataSourceName("root:123456@tcp(127.0.0.1:3306)/test?charset=utf8"). 49 | MaxLifeTime(120). 50 | MaxOpenConns(10). 51 | MaxIdleConns(5). 52 | Build() 53 | 54 | dbconf := NewDBConfigBuilder(). 55 | DS([]*DataSource{ds1}). 56 | ShowSQL(true). 57 | Mappers([]string{"userMapper.xml", "orderMapper.xml"}). 58 | Build() 59 | 60 | dbc := dbconf.getDataSourceByName("ds1") 61 | assert.True(t, dbc != nil, "test fail: No datasource1") 62 | assert.True(t, dbconf.ShowSQL, "test fail: showSql == false") 63 | assert.Equal(t, dbc.DriverName, "mysql", "test fail, actual:"+dbc.DriverName) 64 | assert.Equal(t, dbc.DataSourceName, "root:123456@tcp(127.0.0.1:3306)/test?charset=utf8", "test fail, actual:"+dbc.DataSourceName) 65 | assert.Equal(t, dbc.MaxLifeTime, 120, "test fail, actual:"+fmt.Sprintf("%d", dbc.MaxLifeTime)) 66 | assert.Equal(t, dbc.MaxOpenConns, 10, "test fail, actual:"+fmt.Sprintf("%d", dbc.MaxOpenConns)) 67 | assert.Equal(t, dbc.MaxIdleConns, 5, "test fail, actual:"+fmt.Sprintf("%d", dbc.MaxIdleConns)) 68 | assert.True(t, len(dbconf.Mappers) == 2, "len(dbconf.Mappers) != 2") 69 | assert.Equal(t, dbconf.Mappers[0], "userMapper.xml", "test fail, actual:"+dbconf.Mappers[0]) 70 | assert.Equal(t, dbconf.Mappers[1], "orderMapper.xml", "test fail, actual:"+dbconf.Mappers[1]) 71 | } 72 | -------------------------------------------------------------------------------- /config_mapper.go: -------------------------------------------------------------------------------- 1 | 
package gobatis 2 | 3 | import ( 4 | "sync" 5 | ) 6 | 7 | type mapperConfig struct { 8 | mappedStmts map[string]*node 9 | mappedSql map[string]*node 10 | cache map[string]*mappedStmt 11 | mu sync.Mutex 12 | } 13 | 14 | func (this *mapperConfig) put(id string, n *node) bool { 15 | this.mu.Lock() 16 | defer this.mu.Unlock() 17 | 18 | if _, ok := this.mappedStmts[id]; ok { 19 | return false 20 | } 21 | 22 | this.mappedStmts[id] = n 23 | return true 24 | } 25 | 26 | func (this *mapperConfig) putSql(id string, n *node) bool { 27 | this.mu.Lock() 28 | defer this.mu.Unlock() 29 | 30 | if _, ok := this.mappedSql[id]; ok { 31 | return false 32 | } 33 | 34 | this.mappedSql[id] = n 35 | return true 36 | } 37 | 38 | func (this *mapperConfig) getXmlNode(id string) (rootNode *node, resultType string) { 39 | rootNode, ok := this.mappedStmts[id] 40 | if !ok { 41 | panic("Can not find id:" + id + "mapped stmt") 42 | } 43 | 44 | resultType = "" 45 | if rootNode.Name == "select" { 46 | resultTypeAttr, ok := rootNode.Attrs["resultType"] 47 | if !ok { 48 | panic("Tag ` 14 | SELECT id, name FROM user where id=#{id} order by id 15 | 16 | 17 | insert into user (name, email, create_time) 18 | values 19 | 20 | #{Name}, #{Email}, #{CrtTm} 21 | 22 | 23 | 24 | update user set name = #{Name}, email = #{Email} 25 | where id = #{Id} 26 | 27 | 28 | update user 29 | 30 | name = #{Name}, 31 | email = #{Email}, 32 | 33 | where id = #{Id} 34 | 35 | 36 | delete from user where id=#{id} 37 | 38 | 39 | ` 40 | r := strings.NewReader(xmlStr) 41 | conf := buildMapperConfig(r) 42 | assert.NotNil(t, conf.getMappedStmt("Mapper.findMapById"), "Mapper.findMapById mapped stmt is nil") 43 | assert.NotNil(t, conf.getMappedStmt("Mapper.insertStructsBatch"), "Mapper.insertStructsBatch mapped stmt is nil") 44 | assert.NotNil(t, conf.getMappedStmt("Mapper.updateByStruct"), "Mapper.updateByStruct mapped stmt is nil") 45 | assert.NotNil(t, conf.getMappedStmt("Mapper.deleteById"), "Mapper.deleteById mapped stmt is nil") 
46 | assert.NotNil(t, conf.getMappedStmt("Mapper.updateByCond"), "Mapper.deleteById mapped stmt is nil") 47 | } 48 | -------------------------------------------------------------------------------- /db.yml: -------------------------------------------------------------------------------- 1 | db: 2 | - datasource: ds 3 | driverName: mysql 4 | dataSourceName: root:123456@tcp(127.0.0.1:3306)/test?charset=utf8 5 | maxLifeTime: 120 6 | maxOpenConns: 10 7 | maxIdleConns: 5 8 | showSQL: true 9 | mappers: 10 | - examples/mapper/userMapper.xml 11 | -------------------------------------------------------------------------------- /examples/examples.go: -------------------------------------------------------------------------------- 1 | package examples 2 | 3 | import ( 4 | "context" 5 | "database/sql" 6 | "fmt" 7 | 8 | _ "github.com/go-sql-driver/mysql" // 引入驱动 9 | "github.com/wenj91/gobatis" // 引入gobatis 10 | ) 11 | 12 | // 实体结构示例, tag:field为数据库对应字段名称 13 | type User struct { 14 | Id gobatis.NullInt64 `field:"id"` 15 | Name gobatis.NullString `field:"name"` 16 | Email gobatis.NullString `field:"email"` 17 | CrtTm gobatis.NullTime `field:"crtTm"` 18 | } 19 | 20 | func main() { 21 | // 初始化db 22 | db, _ := sql.Open("mysql", "root:123456@tcp(127.0.0.1:3306)/test?charset=utf8") 23 | dbs := make(map[string]*gobatis.GoBatisDB) 24 | dbs["ds1"] = gobatis.NewGoBatisDB(gobatis.DBTypeMySQL, db) 25 | 26 | option := gobatis.NewDBOption(). 27 | DB(dbs). 28 | ShowSQL(true). 29 | Mappers([]string{"mapper/userMapper.xml"}) 30 | 31 | gobatis.Init(option) 32 | 33 | // 获取数据源,参数为数据源名称,如:ds1 34 | gb := gobatis.Get("ds1") 35 | 36 | //传入id查询Map 37 | mapRes := make(map[string]interface{}) 38 | // stmt标识为:namespace + '.' 
+ id, 如:userMapper.findMapById 39 | // 查询参数可以是map,也可以是数组,也可以是实体结构 40 | _, err := gb.Select("userMapper.findMapById", map[string]interface{}{"id": 1})(mapRes) 41 | fmt.Println("userMapper.findMapById-->", mapRes, err) 42 | 43 | mapRes2 := make(map[string]interface{}) 44 | _, err = gb.SelectContext(context.TODO(), "userMapper.findMapById", map[string]interface{}{"id": 4})(mapRes2) 45 | fmt.Println("userMapper.findMapById-->", mapRes2, err) 46 | } 47 | -------------------------------------------------------------------------------- /examples/mapper/userMapper.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | id, name, crtTm, pwd, email 6 | 7 | 13 | 16 | 19 | 22 | 25 | 28 | 35 | 43 | 63 | 64 | 65 | update user 66 | 67 | name = #{Name}, 68 | pwd = #{Password}, 69 | 70 | where id = #{Id} 71 | 72 | 73 | insert into user (name, email, crtTm) 74 | values (#{Name}, #{Email}, #{CrtTm}) 75 | 76 | 77 | delete from user where id=#{id} 78 | 79 | -------------------------------------------------------------------------------- /examples/sql/user.sql: -------------------------------------------------------------------------------- 1 | -- MariaDB dump 10.17 Distrib 10.4.11-MariaDB, for Win64 (AMD64) 2 | -- 3 | -- Host: 127.0.0.1 Database: test 4 | -- ------------------------------------------------------ 5 | -- Server version 10.4.11-MariaDB 6 | 7 | /*!40101 SET @OLD_CHARACTER_SET_CLIENT=@@CHARACTER_SET_CLIENT */; 8 | /*!40101 SET @OLD_CHARACTER_SET_RESULTS=@@CHARACTER_SET_RESULTS */; 9 | /*!40101 SET @OLD_COLLATION_CONNECTION=@@COLLATION_CONNECTION */; 10 | /*!40101 SET NAMES utf8mb4 */; 11 | /*!40103 SET @OLD_TIME_ZONE=@@TIME_ZONE */; 12 | /*!40103 SET TIME_ZONE='+00:00' */; 13 | /*!40014 SET @OLD_UNIQUE_CHECKS=@@UNIQUE_CHECKS, UNIQUE_CHECKS=0 */; 14 | /*!40014 SET @OLD_FOREIGN_KEY_CHECKS=@@FOREIGN_KEY_CHECKS, FOREIGN_KEY_CHECKS=0 */; 15 | /*!40101 SET @OLD_SQL_MODE=@@SQL_MODE, SQL_MODE='NO_AUTO_VALUE_ON_ZERO' */; 16 | 
/*!40111 SET @OLD_SQL_NOTES=@@SQL_NOTES, SQL_NOTES=0 */; 17 | 18 | -- 19 | -- Table structure for table `user` 20 | -- 21 | 22 | DROP TABLE IF EXISTS `user`; 23 | /*!40101 SET @saved_cs_client = @@character_set_client */; 24 | /*!40101 SET character_set_client = utf8 */; 25 | CREATE TABLE `user` ( 26 | `id` int(11) unsigned NOT NULL AUTO_INCREMENT, 27 | `name` varchar(50) CHARACTER SET latin1 DEFAULT NULL, 28 | `email` varchar(50) CHARACTER SET latin1 DEFAULT NULL, 29 | `pwd` varchar(50) CHARACTER SET latin1 DEFAULT NULL, 30 | `crtTm` datetime DEFAULT NULL, 31 | PRIMARY KEY (`id`) 32 | ) ENGINE=InnoDB AUTO_INCREMENT=9 DEFAULT CHARSET=utf8mb4; 33 | /*!40101 SET character_set_client = @saved_cs_client */; 34 | 35 | -- 36 | -- Dumping data for table `user` 37 | -- 38 | 39 | LOCK TABLES `user` WRITE; 40 | /*!40000 ALTER TABLE `user` DISABLE KEYS */; 41 | INSERT INTO `user` VALUES (1,'wenj1991',NULL,'654321',NULL),(2,'wenj1991','sss@qq.com','12345678','2019-08-26 17:41:04'),(4,'wenj1991',NULL,NULL,NULL),(5,'wenj1991',NULL,NULL,NULL),(6,'wenj1991',NULL,NULL,NULL),(7,'wenj1991',NULL,NULL,NULL),(8,'wenj1991',NULL,NULL,NULL); 42 | /*!40000 ALTER TABLE `user` ENABLE KEYS */; 43 | UNLOCK TABLES; 44 | /*!40103 SET TIME_ZONE=@OLD_TIME_ZONE */; 45 | 46 | /*!40101 SET SQL_MODE=@OLD_SQL_MODE */; 47 | /*!40014 SET FOREIGN_KEY_CHECKS=@OLD_FOREIGN_KEY_CHECKS */; 48 | /*!40014 SET UNIQUE_CHECKS=@OLD_UNIQUE_CHECKS */; 49 | /*!40101 SET CHARACTER_SET_CLIENT=@OLD_CHARACTER_SET_CLIENT */; 50 | /*!40101 SET CHARACTER_SET_RESULTS=@OLD_CHARACTER_SET_RESULTS */; 51 | /*!40101 SET COLLATION_CONNECTION=@OLD_COLLATION_CONNECTION */; 52 | /*!40111 SET SQL_NOTES=@OLD_SQL_NOTES */; 53 | 54 | -- Dump completed on 2020-02-05 16:46:02 55 | -------------------------------------------------------------------------------- /executor.go: -------------------------------------------------------------------------------- 1 | package gobatis 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | "fmt" 7 | ) 8 | 9 | 
type executor struct { 10 | gb *gbBase 11 | } 12 | 13 | func (exec *executor) updateContext(ctx context.Context, ms *mappedStmt, params map[string]interface{}) (lastInsertId int64, affected int64, err error) { 14 | boundSql, paramArr, err := paramProc(ms, params) 15 | if nil != err { 16 | return 0, 0, err 17 | } 18 | 19 | if conf.dbConf.ShowSQL { 20 | LOG.Info("SQL:%s\nParamMappings:%s\nParams:%v", boundSql.sqlStr, boundSql.paramMappings, paramArr) 21 | } 22 | 23 | stmt, err := exec.gb.db.PrepareContext(ctx, boundSql.sqlStr) 24 | if nil != err { 25 | return 0, 0, err 26 | } 27 | defer stmt.Close() 28 | 29 | result, err := stmt.ExecContext(ctx, paramArr...) 30 | if nil != err { 31 | return 0, 0, err 32 | } 33 | 34 | lastInsertId, err = result.LastInsertId() 35 | if nil != err { 36 | return 0, 0, err 37 | } 38 | affected, err = result.RowsAffected() 39 | if nil != err { 40 | return 0, 0, err 41 | } 42 | 43 | return lastInsertId, affected, nil 44 | } 45 | 46 | func (exec *executor) update(ms *mappedStmt, params map[string]interface{}) (lastInsertId int64, affected int64, err error) { 47 | return exec.updateContext(context.Background(), ms, params) 48 | } 49 | 50 | func (exec *executor) queryContext(ctx context.Context, ms *mappedStmt, params map[string]interface{}, res interface{}, rowBound ...*rowBounds) (int64, error) { 51 | boundSql, paramArr, err := paramProc(ms, params) 52 | if nil != err { 53 | return 0, err 54 | } 55 | 56 | sqlStr := boundSql.sqlStr 57 | 58 | if conf.dbConf.ShowSQL { 59 | LOG.Info("SQL:%s\nParamMappings:%s\nParams:%v", sqlStr, boundSql.paramMappings, paramArr) 60 | } 61 | 62 | count := int64(0) 63 | if len(rowBound) > 0 { 64 | countSql := "SELECT COUNT(1) cnt FROM (" + sqlStr + ") AS t" 65 | rows, err := exec.gb.db.QueryContext(ctx, countSql, paramArr...) 
66 | if nil != err { 67 | return 0, err 68 | } 69 | defer rows.Close() // release the COUNT query's connection even if rowsToMaps fails; Close is idempotent 70 | resProc, err := rowsToMaps(rows) 71 | if nil != err { 72 | return 0, err 73 | } 74 | 75 | c, err := valToInt64(resProc[0].(map[string]interface{})["cnt"]) 76 | if nil != err { 77 | return 0, err 78 | } 79 | 80 | count = c 81 | 82 | if count <= 0 { 83 | return 0, nil 84 | } 85 | 86 | sqlStr += " LIMIT " + fmt.Sprint(rowBound[0].offset) + "," + fmt.Sprint(rowBound[0].limit) 87 | } 88 | 89 | rows, err := exec.gb.db.QueryContext(ctx, sqlStr, paramArr...) 90 | if nil != err { 91 | return 0, err 92 | } 93 | defer rows.Close() 94 | 95 | resProc, ok := resSetProcMap[ms.resultType] 96 | if !ok { 97 | return 0, errors.New("No exec result type proc, result type:" + string(ms.resultType)) 98 | } 99 | 100 | // func(rows *sql.Rows, res interface{}) error 101 | err = resProc(rows, res) 102 | if nil != err { 103 | return 0, err 104 | } 105 | 106 | return count, nil 107 | } 108 | 109 | func (exec *executor) query(ms *mappedStmt, params map[string]interface{}, res interface{}, rowBound ...*rowBounds) (int64, error) { 110 | return exec.queryContext(context.Background(), ms, params, res, rowBound...) 
111 | } 112 | 113 | func paramProc(ms *mappedStmt, params map[string]interface{}) (boundSql *boundSql, paramArr []interface{}, err error) { 114 | boundSql = ms.sqlSource.getBoundSql(params) 115 | if nil == boundSql { 116 | err = errors.New("get boundSql err: boundSql == nil") 117 | return 118 | } 119 | 120 | paramArr = make([]interface{}, 0) 121 | for i := 0; i < len(boundSql.paramMappings); i++ { 122 | paramName := boundSql.paramMappings[i] 123 | param, ok := boundSql.extParams[paramName] 124 | if !ok { 125 | err = errors.New("param:" + paramName + " not exists") 126 | return 127 | } 128 | 129 | paramArr = append(paramArr, param) 130 | } 131 | 132 | return 133 | } 134 | -------------------------------------------------------------------------------- /expr.go: -------------------------------------------------------------------------------- 1 | package gobatis 2 | 3 | import ( 4 | "fmt" 5 | "strings" 6 | 7 | "github.com/antonmedv/expr" 8 | ) 9 | // blank reports whether arg is nil or formats (via fmt.Sprint) to an empty or all-whitespace string. 10 | func blank(arg interface{}) bool { 11 | if nil == arg { 12 | return true 13 | } 14 | 15 | res := fmt.Sprint(arg) 16 | if res == "" { 17 | return true 18 | } 19 | 20 | if strings.TrimSpace(res) == "" { 21 | return true 22 | } 23 | 24 | return false 25 | } 26 | // eval compiles and runs expression against mapper (plus the $blank helper); compile errors, run errors and non-bool results all yield false. 27 | func eval(expression string, mapper map[string]interface{}) bool { 28 | env := map[string]interface{}{ 29 | "$blank": blank, 30 | } 31 | 32 | for k, v := range mapper { 33 | env[k] = v 34 | } 35 | 36 | program, err := expr.Compile(expression, expr.Env(env)) 37 | if err != nil { 38 | LOG.Debug("[WARN] Expression:%s >>> Compile result err:%v", expression, err) 39 | return false 40 | } 41 | 42 | ok, err := expr.Run(program, env) 43 | if err != nil { 44 | LOG.Debug("[WARN] Expression:%s >>> eval result err:%v", expression, err) 45 | return false 46 | } 47 | res, isBool := ok.(bool) 48 | return isBool && res // a non-bool result (e.g. a bare string) counts as false instead of panicking 49 | } 50 | -------------------------------------------------------------------------------- /expr_test.go:
-------------------------------------------------------------------------------- 1 | package gobatis 2 | 3 | import ( 4 | "fmt" 5 | "testing" 6 | 7 | "github.com/stretchr/testify/assert" 8 | ) 9 | 10 | type TestUser struct { 11 | Name string 12 | } 13 | 14 | func TestExpr_eval(t *testing.T) { 15 | params := map[string]interface{}{ 16 | "name": "wenj91", 17 | "val": "", 18 | "user": &TestUser{Name: "wenj91"}, 19 | "m": map[string]interface{}{"user": &TestUser{Name: "wenj91"}}, 20 | "m1": map[string]interface{}{"name": "wenj91"}, 21 | "arr": []string{"1", "2"}, 22 | "arr2": []string{}, 23 | } 24 | expression := []string{ 25 | "1 != 1", 26 | "1 == 1", 27 | "name == 'wenj91'", 28 | "name != 'wenj91'", 29 | "user.Name1 == 'wenj91'", 30 | "user.Name == 'wenj91'", 31 | "user.Name != 'wenj91'", 32 | "user.Name != nil", 33 | "user.Name == nil", 34 | "m.user.Name != 'wenj91'", 35 | "m.user.Name == 'wenj91'", 36 | "m1.name == 'wenj91'", 37 | "m1.name != 'wenj91'", 38 | "m.user.Name == 'wenj91' && 1 == 1", 39 | "m.user.Name == 'wenj91' && 1 != 1", 40 | "m.user.Name == 'wenj91' || 1 != 1", 41 | "val != nil", 42 | "val != ''", 43 | "val == ''", 44 | "val != nil && val == ''", 45 | "val != nil and val == ''", 46 | "arr != nil and len(arr) > 0", 47 | "arr2 != nil and len(arr2) > 0", 48 | "$blank(val)", 49 | "!$blank(val)", 50 | } 51 | 52 | for i, ex := range expression { 53 | ok := eval(ex, params) 54 | fmt.Printf("Index:%v Expr:%v >>>> Result:%v \n", i, ex, ok) 55 | assertExpr(t, i, ok, ex) 56 | } 57 | } 58 | 59 | func assertExpr(t *testing.T, i int, ok bool, expr string) { 60 | switch i { 61 | case 0, 3, 4, 6, 8, 9, 12, 14, 17, 22, 24: // false 62 | assert.True(t, !ok, "Expr:"+expr+" Result:true") 63 | case 1, 2, 5, 7, 10, 11, 13, 15, 16, 18, 19, 20, 21, 23: // true 64 | assert.True(t, ok, "Expr:"+expr+" Result:false") 65 | } 66 | } 67 | -------------------------------------------------------------------------------- /go.mod: 
-------------------------------------------------------------------------------- 1 | module github.com/wenj91/gobatis 2 | 3 | go 1.12 4 | 5 | require ( 6 | github.com/antonmedv/expr v1.9.0 7 | github.com/go-sql-driver/mysql v1.4.1 8 | github.com/json-iterator/go v1.1.12 9 | github.com/stretchr/testify v1.7.0 10 | google.golang.org/appengine v1.6.7 // indirect 11 | gopkg.in/yaml.v2 v2.4.0 12 | ) 13 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | github.com/DATA-DOG/go-sqlmock v1.3.3/go.mod h1:f/Ixk793poVmq4qj/V1dPUg2JEAKC73Q5eFN3EC/SaM= 2 | github.com/antonmedv/expr v1.9.0 h1:j4HI3NHEdgDnN9p6oI6Ndr0G5QryMY0FNxT4ONrFDGU= 3 | github.com/antonmedv/expr v1.9.0/go.mod h1:5qsM3oLGDND7sDmQGDXHkYfkjYMUX14qsgqmHhwGEk8= 4 | github.com/davecgh/go-spew v0.0.0-20161028175848-04cdfd42973b/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 5 | github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 6 | github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= 7 | github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 8 | github.com/gdamore/encoding v1.0.0/go.mod h1:alR0ol34c49FCSBLjhosxzcPHQbf2trDkoo5dl+VrEg= 9 | github.com/gdamore/tcell v1.3.0/go.mod h1:Hjvr+Ofd+gLglo7RYKxxnzCBmev3BzsS67MebKS4zMM= 10 | github.com/go-sql-driver/mysql v1.4.1 h1:g24URVg0OFbNUTx9qqY1IRZ9D9z3iPyi5zKhQZpNwpA= 11 | github.com/go-sql-driver/mysql v1.4.1/go.mod h1:zAC/RDZ24gD3HViQzih4MyKcchzm+sOG5ZlKdlhCg5w= 12 | github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= 13 | github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= 14 | github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= 15 | github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= 16 | 
github.com/lucasb-eyer/go-colorful v1.0.2/go.mod h1:0MS4r+7BZKSJ5mw4/S5MPN+qHFF1fYclkSPilDOKW0s= 17 | github.com/lucasb-eyer/go-colorful v1.0.3/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0= 18 | github.com/mattn/go-runewidth v0.0.4/go.mod h1:LwmH8dsx7+W8Uxz3IHJYH5QSwggIsqBzpuz5H//U1FU= 19 | github.com/mattn/go-runewidth v0.0.8/go.mod h1:H031xJmbD/WCDINGzjvQ9THkh0rPKHF+m2gUSrubnMI= 20 | github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421 h1:ZqeYNhU3OHLH3mGKHDcjJRFFRrJa6eAM5H+CtDdOsPc= 21 | github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= 22 | github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M= 23 | github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk= 24 | github.com/pmezard/go-difflib v0.0.0-20151028094244-d8ed2627bdf0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 25 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= 26 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 27 | github.com/rivo/tview v0.0.0-20200219210816-cd38d7432498/go.mod h1:6lkG1x+13OShEf0EaOCaTQYyB7d5nSbb181KtjlS+84= 28 | github.com/rivo/uniseg v0.1.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= 29 | github.com/sanity-io/litter v1.2.0/go.mod h1:JF6pZUFgu2Q0sBZ+HSV35P8TVPI1TTzEwyu9FXAw2W4= 30 | github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= 31 | github.com/stretchr/testify v0.0.0-20161117074351-18a02ba4a312/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= 32 | github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= 33 | github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= 34 | github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY= 35 | github.com/stretchr/testify v1.7.0/go.mod 
h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= 36 | golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= 37 | golang.org/x/net v0.0.0-20190603091049-60506f45cf65/go.mod h1:HSz+uSET+XFnRR8LxR5pz3Of3rY3CfYBVs4xY44aLks= 38 | golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= 39 | golang.org/x/sys v0.0.0-20190626150813-e07cf5db2756/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= 40 | golang.org/x/sys v0.0.0-20200212091648-12a6c2dcc1e4/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= 41 | golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= 42 | golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= 43 | golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= 44 | google.golang.org/appengine v1.6.7 h1:FZR1q0exgwxzPzp/aF+VccGrSfxfPpkBqjIIEq3ru6c= 45 | google.golang.org/appengine v1.6.7/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc= 46 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= 47 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= 48 | gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= 49 | gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= 50 | gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= 51 | gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo= 52 | gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 53 | -------------------------------------------------------------------------------- /gobatis.dtd: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 15 | 16 
| 17 | 18 | 21 | 22 | 23 | 26 | 27 | 28 | 32 | 33 | 34 | 37 | 38 | 39 | 43 | 44 | 45 | 50 | 51 | 52 | 58 | 59 | 60 | 61 | 62 | 63 | 71 | 72 | 73 | 74 | 77 | 78 | 79 | 80 | 83 | -------------------------------------------------------------------------------- /gobatis.go: -------------------------------------------------------------------------------- 1 | package gobatis 2 | 3 | import ( 4 | "context" 5 | "database/sql" 6 | "errors" 7 | ) 8 | 9 | type ResultType string 10 | 11 | var LOG ILogger = defLog 12 | 13 | const ( 14 | resultTypeMap ResultType = "map" // result set is a map: map[string]interface{} 15 | resultTypeMaps ResultType = "maps" // result set is a slice, item is map: []map[string]interface{} 16 | resultTypeStruct ResultType = "struct" // result set is a struct 17 | resultTypeStructs ResultType = "structs" // result set is a slice, item is struct 18 | resultTypeSlice ResultType = "slice" // result set is a value slice, []interface{} 19 | resultTypeSlices ResultType = "slices" // result set is a value slice, item is value slice, []interface{} 20 | resultTypeArray ResultType = "array" // 21 | resultTypeArrays ResultType = "arrays" // result set is a value slice, item is value slice, []interface{} 22 | resultTypeValue ResultType = "value" // result set is single value 23 | ) 24 | 25 | type GoBatis interface { 26 | // Select 查询数据 27 | Select(stmt string, param interface{}, rowBound ...*rowBounds) func(res interface{}) (int64, error) 28 | // SelectContext 查询数据with context 29 | SelectContext(ctx context.Context, stmt string, param interface{}, rowBound ...*rowBounds) func(res interface{}) (int64, error) 30 | // Insert 插入数据 31 | Insert(stmt string, param interface{}) (lastInsertId int64, affected int64, err error) 32 | // InsertContext 插入数据with context 33 | InsertContext(ctx context.Context, stmt string, param interface{}) (lastInsertId int64, affected int64, err error) 34 | // Update 更新数据 35 | Update(stmt string, param interface{}) (affected int64, err error) 
36 | // UpdateContext 更新数据with context 37 | UpdateContext(ctx context.Context, stmt string, param interface{}) (affected int64, err error) 38 | // Delete 刪除数据 39 | Delete(stmt string, param interface{}) (affected int64, err error) 40 | // DeleteContext 刪除数据with context 41 | DeleteContext(ctx context.Context, stmt string, param interface{}) (affected int64, err error) 42 | } 43 | 44 | // reference from https://github.com/yinshuwei/osm/blob/master/osm.go start 45 | type dbRunner interface { 46 | Prepare(query string) (*sql.Stmt, error) 47 | PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) 48 | Exec(query string, args ...interface{}) (sql.Result, error) 49 | ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) 50 | Query(query string, args ...interface{}) (*sql.Rows, error) 51 | QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) 52 | QueryRow(query string, args ...interface{}) *sql.Row 53 | QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row 54 | } 55 | 56 | func Get(datasource string) *DB { 57 | if nil == conf { 58 | panic(errors.New("DB config no init, please invoke DB.ConfInit() to init db config!")) 59 | } 60 | 61 | if nil == db { 62 | panic(errors.New("DB init err, db == nil!")) 63 | } 64 | 65 | ds, ok := db[datasource] 66 | if !ok { 67 | panic(errors.New("Datasource:" + datasource + " not exists!")) 68 | } 69 | 70 | dbType := ds.dbType 71 | if dbType != DBTypeMySQL { 72 | panic(errors.New("No support to this driver name!")) 73 | } 74 | 75 | gb := &DB{ 76 | gbBase{ 77 | db: ds.db, 78 | dbType: ds.dbType, 79 | config: conf, 80 | }, 81 | } 82 | 83 | return gb 84 | } 85 | 86 | func SetLogger(log ILogger) { 87 | LOG = log 88 | } 89 | 90 | type rowBounds struct { 91 | offset int 92 | limit int 93 | } 94 | 95 | func RowBounds(offset int, limit int) *rowBounds { 96 | if offset < 0 { 97 | offset = 0 98 | } 99 | 100 | if limit < 0 { 101 | limit = 0 102 
| } 103 | 104 | return &rowBounds{ 105 | offset: offset, 106 | limit: limit, 107 | } 108 | } 109 | 110 | type gbBase struct { 111 | db dbRunner 112 | dbType DBType 113 | config *config 114 | } 115 | 116 | // DB 117 | type DB struct { 118 | gbBase 119 | } 120 | 121 | var _ GoBatis = &DB{} 122 | 123 | // TX 124 | type TX struct { 125 | gbBase 126 | } 127 | 128 | var _ GoBatis = &TX{} 129 | 130 | // Begin TX 131 | // 132 | // ps: 133 | // TX, err := this.Begin() 134 | func (d *DB) Begin() (*TX, error) { 135 | if nil == d.db { 136 | return nil, errors.New("db no opened") 137 | } 138 | 139 | sqlDB, ok := d.db.(*sql.DB) 140 | if !ok { 141 | return nil, errors.New("db no opened") 142 | } 143 | 144 | db, err := sqlDB.Begin() 145 | if nil != err { 146 | return nil, err 147 | } 148 | 149 | t := &TX{ 150 | gbBase{ 151 | dbType: d.dbType, 152 | config: d.config, 153 | db: db, 154 | }, 155 | } 156 | return t, nil 157 | } 158 | 159 | // Begin TX with ctx & opts 160 | // 161 | // ps: 162 | // TX, err := this.BeginTx(ctx, ops) 163 | func (d *DB) BeginTx(ctx context.Context, opts *sql.TxOptions) (*TX, error) { 164 | if nil == d.db { 165 | return nil, errors.New("db no opened") 166 | } 167 | 168 | sqlDb, ok := d.db.(*sql.DB) 169 | if !ok { 170 | return nil, errors.New("db no opened") 171 | } 172 | 173 | db, err := sqlDb.BeginTx(ctx, opts) 174 | if nil != err { 175 | return nil, err 176 | } 177 | 178 | t := &TX{ 179 | gbBase{ 180 | dbType: d.dbType, 181 | config: d.config, 182 | db: db, 183 | }, 184 | } 185 | return t, nil 186 | } 187 | 188 | // Transaction tx 189 | func (d *DB) Transaction(fn func(tx *TX) error) error { 190 | 191 | tx, err := d.Begin() 192 | if nil != err { 193 | return err 194 | } 195 | 196 | defer func() { 197 | if err != nil { 198 | err = tx.Rollback() 199 | if nil != err { 200 | LOG.Error("tx rollback err:#v", err) 201 | } 202 | } 203 | }() 204 | 205 | err = fn(tx) 206 | if nil != err { 207 | return err 208 | } 209 | 210 | err = tx.Commit() 211 | if nil != err { 
212 | return err 213 | } 214 | 215 | return nil 216 | } 217 | 218 | // Transaction tx 219 | func (d *DB) TransactionTX(ctx context.Context, opts *sql.TxOptions, fn func(tx *TX) error) error { 220 | 221 | tx, err := d.BeginTx(ctx, opts) 222 | if nil != err { 223 | return err 224 | } 225 | 226 | defer func() { 227 | if err != nil { 228 | err = tx.Rollback() 229 | if nil != err { 230 | LOG.Error("tx rollback err:#v", err) 231 | } 232 | } 233 | }() 234 | 235 | err = fn(tx) 236 | if nil != err { 237 | return err 238 | } 239 | 240 | err = tx.Commit() 241 | if nil != err { 242 | return err 243 | } 244 | 245 | return nil 246 | } 247 | 248 | // Close db 249 | // 250 | // ps: 251 | // err := this.Close() 252 | func (g *gbBase) Close() error { 253 | if nil == g.db { 254 | return errors.New("db no opened") 255 | } 256 | 257 | sqlDb, ok := g.db.(*sql.DB) 258 | if !ok { 259 | return errors.New("db no opened") 260 | } 261 | 262 | err := sqlDb.Close() 263 | g.db = nil 264 | return err 265 | } 266 | 267 | // Commit TX 268 | // 269 | // ps: 270 | // err := TX.Commit() 271 | func (t *TX) Commit() error { 272 | if nil == t.db { 273 | return errors.New("TX no running") 274 | } 275 | 276 | sqlTx, ok := t.db.(*sql.Tx) 277 | if !ok { 278 | return errors.New("TX no running") 279 | 280 | } 281 | 282 | return sqlTx.Commit() 283 | } 284 | 285 | // Rollback TX 286 | // 287 | // ps: 288 | // err := TX.Rollback() 289 | func (t *TX) Rollback() error { 290 | if nil == t.db { 291 | return errors.New("TX no running") 292 | } 293 | 294 | sqlTx, ok := t.db.(*sql.Tx) 295 | if !ok { 296 | return errors.New("TX no running") 297 | } 298 | 299 | return sqlTx.Rollback() 300 | } 301 | 302 | // reference from https://github.com/yinshuwei/osm/blob/master/osm.go end 303 | func (g *gbBase) Select(stmt string, param interface{}, rowBound ...*rowBounds) func(res interface{}) (int64, error) { 304 | ms := g.config.mapperConf.getMappedStmt(stmt) 305 | if nil == ms { 306 | return func(res interface{}) (int64, error) { 
307 | return 0, errors.New("Mapped statement not found:" + stmt) 308 | } 309 | } 310 | ms.dbType = g.dbType 311 | 312 | params := paramProcess(param) 313 | 314 | return func(res interface{}) (int64, error) { 315 | executor := &executor{ 316 | gb: g, 317 | } 318 | count, err := executor.query(ms, params, res, rowBound...) 319 | return count, err 320 | } 321 | } 322 | 323 | func (g *gbBase) SelectContext(ctx context.Context, stmt string, param interface{}, rowBound ...*rowBounds) func(res interface{}) (int64, error) { 324 | ms := g.config.mapperConf.getMappedStmt(stmt) 325 | if nil == ms { 326 | return func(res interface{}) (int64, error) { 327 | return 0, errors.New("Mapped statement not found:" + stmt) 328 | } 329 | } 330 | ms.dbType = g.dbType 331 | 332 | params := paramProcess(param) 333 | 334 | return func(res interface{}) (int64, error) { 335 | executor := &executor{ 336 | gb: g, 337 | } 338 | count, err := executor.queryContext(ctx, ms, params, res) 339 | return count, err 340 | } 341 | } 342 | 343 | // insert(stmt string, param interface{}) 344 | func (g *gbBase) Insert(stmt string, param interface{}) (int64, int64, error) { 345 | ms := g.config.mapperConf.getMappedStmt(stmt) 346 | if nil == ms { 347 | return 0, 0, errors.New("Mapped statement not found:" + stmt) 348 | } 349 | ms.dbType = g.dbType 350 | 351 | params := paramProcess(param) 352 | 353 | executor := &executor{ 354 | gb: g, 355 | } 356 | 357 | lastInsertId, affected, err := executor.update(ms, params) 358 | if nil != err { 359 | return 0, 0, err 360 | } 361 | 362 | return lastInsertId, affected, nil 363 | } 364 | 365 | func (g *gbBase) InsertContext(ctx context.Context, stmt string, param interface{}) (int64, int64, error) { 366 | ms := g.config.mapperConf.getMappedStmt(stmt) 367 | if nil == ms { 368 | return 0, 0, errors.New("Mapped statement not found:" + stmt) 369 | } 370 | ms.dbType = g.dbType 371 | 372 | params := paramProcess(param) 373 | 374 | executor := &executor{ 375 | gb: g, 376 | } 377 
| 378 | lastInsertId, affected, err := executor.updateContext(ctx, ms, params) 379 | if nil != err { 380 | return 0, 0, err 381 | } 382 | 383 | return lastInsertId, affected, nil 384 | } 385 | 386 | // update(stmt string, param interface{}) 387 | func (g *gbBase) Update(stmt string, param interface{}) (int64, error) { 388 | ms := g.config.mapperConf.getMappedStmt(stmt) 389 | if nil == ms { 390 | return 0, errors.New("Mapped statement not found:" + stmt) 391 | } 392 | ms.dbType = g.dbType 393 | 394 | params := paramProcess(param) 395 | 396 | executor := &executor{ 397 | gb: g, 398 | } 399 | 400 | _, affected, err := executor.update(ms, params) 401 | if nil != err { 402 | return 0, err 403 | } 404 | 405 | return affected, nil 406 | } 407 | 408 | func (g *gbBase) UpdateContext(ctx context.Context, stmt string, param interface{}) (int64, error) { 409 | ms := g.config.mapperConf.getMappedStmt(stmt) 410 | if nil == ms { 411 | return 0, errors.New("Mapped statement not found:" + stmt) 412 | } 413 | ms.dbType = g.dbType 414 | 415 | params := paramProcess(param) 416 | 417 | executor := &executor{ 418 | gb: g, 419 | } 420 | 421 | _, affected, err := executor.updateContext(ctx, ms, params) 422 | if nil != err { 423 | return 0, err 424 | } 425 | 426 | return affected, nil 427 | } 428 | 429 | // delete(stmt string, param interface{}) 430 | func (g *gbBase) Delete(stmt string, param interface{}) (int64, error) { 431 | return g.Update(stmt, param) 432 | } 433 | 434 | func (g *gbBase) DeleteContext(ctx context.Context, stmt string, param interface{}) (int64, error) { 435 | return g.UpdateContext(ctx, stmt, param) 436 | } 437 | -------------------------------------------------------------------------------- /gobatis_db.go: -------------------------------------------------------------------------------- 1 | package gobatis 2 | 3 | import "database/sql" 4 | 5 | type DBType string 6 | 7 | const ( 8 | DBTypeMySQL DBType = "mysql" 9 | DBTypePostgres DBType = "postgres" 10 | ) 11 | 12 | 
type GoBatisDB struct { 13 | db *sql.DB 14 | dbType DBType 15 | } 16 | 17 | func NewGoBatisDB(dbType DBType, db *sql.DB) *GoBatisDB { 18 | return &GoBatisDB{ 19 | db: db, 20 | dbType: dbType, 21 | } 22 | } 23 | -------------------------------------------------------------------------------- /gobatis_test.go: -------------------------------------------------------------------------------- 1 | package gobatis 2 | 3 | import ( 4 | "context" 5 | "database/sql" 6 | "encoding/json" 7 | "fmt" 8 | "testing" 9 | 10 | _ "github.com/go-sql-driver/mysql" 11 | ) 12 | 13 | type TUser struct { 14 | Id int64 `field:"id"` 15 | Name string `field:"name"` 16 | Password NullString `field:"pwd"` 17 | Email NullString `field:"email"` 18 | CrtTm NullTime `field:"crtTm"` 19 | } 20 | 21 | // TUser to string 22 | func (u *TUser) String() string { 23 | bs, _ := json.Marshal(u) 24 | return string(bs) 25 | } 26 | 27 | func TestGoBatis(t *testing.T) { 28 | Init(NewFileOption()) 29 | if nil == conf { 30 | LOG.Error("db config == nil") 31 | return 32 | } 33 | 34 | gb := Get("ds") 35 | 36 | //result := make(map[string]interface{}) 37 | //result := make([]interface{}, 0) 38 | //var result interface{} 39 | //result := make([]TUser, 0) 40 | var result *TUser 41 | _, err := gb.Select("userMapper.findById", map[string]interface{}{ 42 | "id": 2, 43 | })(&result) 44 | 45 | fmt.Println("result:", result, "err:", err) 46 | 47 | u := &TUser{ 48 | Name: "wenj1991", 49 | Password: NullString{ 50 | String: "654321", 51 | Valid: true, 52 | }, 53 | } 54 | 55 | id, a, err := gb.Insert("userMapper.saveUser", u) 56 | fmt.Println("id:", id, "affected:", a, "err:", err) 57 | 58 | uu := &TUser{ 59 | Id: 1, 60 | Name: "wenj1993", 61 | Password: NullString{ 62 | String: "654321", 63 | Valid: true, 64 | }, 65 | } 66 | 67 | // test set 68 | affected, err := gb.Update("userMapper.updateByCond", uu) 69 | fmt.Println("updateByCond:", affected, err) 70 | 71 | param := &TUser{ 72 | Name: "wenj1993", 73 | } 74 | 75 | // test 
where 76 | res := make([]*TUser, 0) 77 | _, err = gb.Select("userMapper.queryStructsByCond", param)(&res) 78 | fmt.Println("queryStructsByCond", res, err) 79 | 80 | // test trim 81 | res2 := make([]*TUser, 0) 82 | _, err = gb.Select("userMapper.queryStructsByCond2", param)(&res2) 83 | fmt.Println("queryStructsByCond", res2, err) 84 | 85 | affected, err = gb.Delete("userMapper.deleteById", map[string]interface{}{ 86 | "id": 3, 87 | }) 88 | fmt.Println("delete affected:", affected, "err:", err) 89 | } 90 | 91 | func TestGoBatisWithDB(t *testing.T) { 92 | db, _ := sql.Open("mysql", "root:123456@tcp(127.0.0.1:3306)/test?charset=utf8") 93 | dbs := make(map[string]*GoBatisDB) 94 | dbs["ds"] = NewGoBatisDB(DBTypeMySQL, db) 95 | 96 | option := NewDBOption(). 97 | DB(dbs). 98 | ShowSQL(true). 99 | Mappers([]string{"examples/mapper/userMapper.xml"}) 100 | Init(option) 101 | 102 | if nil == conf { 103 | LOG.Info("db config == nil") 104 | return 105 | } 106 | 107 | gb := Get("ds") 108 | 109 | var result *TUser 110 | _, err := gb.Select("userMapper.findById", map[string]interface{}{ 111 | "id": 2, 112 | })(&result) 113 | 114 | fmt.Println("result:", result, "err:", err) 115 | 116 | var result2 *TUser 117 | _, err = gb.SelectContext(context.Background(), "userMapper.findById", map[string]interface{}{ 118 | "id": 4, 119 | })(&result2) 120 | fmt.Println("result:", result2, "err:", err) 121 | 122 | var result3 *TUser 123 | cnt, err := gb.Select("userMapper.findById", map[string]interface{}{ 124 | "id": 2, 125 | }, RowBounds(0, 10))(&result3) 126 | fmt.Println("result:", result3, "cnt:", cnt, "err:", err) 127 | 128 | // queryStructsByCond with count 129 | param := &TUser{ 130 | // Name: "", 131 | } 132 | 133 | res := make([]*TUser, 0) 134 | cnt2, err := gb.Select("userMapper.queryStructsByCond", param, RowBounds(0, 100))(&res) 135 | fmt.Println("queryStructsByCond", cnt2, res, err) 136 | } 137 | 138 | func TestGoBatisWithCodeConf(t *testing.T) { 139 | ds1 := NewDataSourceBuilder(). 
140 | DataSource("ds1"). 141 | DriverName("mysql"). 142 | DataSourceName("root:123456@tcp(127.0.0.1:3306)/test?charset=utf8"). 143 | MaxLifeTime(120). 144 | MaxOpenConns(10). 145 | MaxIdleConns(5). 146 | Build() 147 | 148 | option := NewDSOption(). 149 | DS([]*DataSource{ds1}). 150 | Mappers([]string{"examples/mapper/userMapper.xml"}). 151 | ShowSQL(true) 152 | Init(option) 153 | 154 | if nil == conf { 155 | LOG.Error("db config == nil") 156 | return 157 | } 158 | 159 | gb := Get("ds1") 160 | 161 | //result := make(map[string]interface{}) 162 | //result := make([]interface{}, 0) 163 | //var result interface{} 164 | //result := make([]TUser, 0) 165 | var result *TUser 166 | _, err := gb.Select("userMapper.findById", map[string]interface{}{ 167 | "id": 2, 168 | })(&result) 169 | 170 | fmt.Println("result:", result, "err:", err) 171 | 172 | u := &TUser{ 173 | Name: "wenj1991", 174 | Password: NullString{ 175 | String: "654321", 176 | Valid: true, 177 | }, 178 | } 179 | 180 | id, a, err := gb.Insert("userMapper.saveUser", u) 181 | fmt.Println("id:", id, "affected:", a, "err:", err) 182 | 183 | uu := &TUser{ 184 | Id: 1, 185 | Name: "wenj1993", 186 | Password: NullString{ 187 | String: "654321", 188 | Valid: true, 189 | }, 190 | } 191 | 192 | // test set 193 | affected, err := gb.Update("userMapper.updateByCond", uu) 194 | fmt.Println("updateByCond:", affected, err) 195 | 196 | param := &TUser{ 197 | Name: "wenj1993", 198 | } 199 | 200 | // test where 201 | res := make([]*TUser, 0) 202 | _, err = gb.Select("userMapper.queryStructsByCond", param)(&res) 203 | fmt.Println("queryStructsByCond", res, err) 204 | 205 | // test trim 206 | res2 := make([]*TUser, 0) 207 | _, err = gb.Select("userMapper.queryStructsByCond2", param)(&res2) 208 | fmt.Println("queryStructsByCond", res2, err) 209 | 210 | affected, err = gb.Delete("userMapper.deleteById", map[string]interface{}{ 211 | "id": 3, 212 | }) 213 | fmt.Println("delete affected:", affected, "err:", err) 214 | } 215 | 
-------------------------------------------------------------------------------- /logger.go: -------------------------------------------------------------------------------- 1 | package gobatis 2 | 3 | import ( 4 | "fmt" 5 | "runtime" 6 | "strings" 7 | "sync" 8 | "time" 9 | ) 10 | 11 | func now() string { 12 | date := time.Now().Format("2006-01-02 15:04:06") 13 | return date 14 | } 15 | 16 | func getCallers() []string { 17 | callers := make([]string, 0) 18 | for i := 0; true; i++ { 19 | _, file, line, ok := runtime.Caller(i) 20 | if !ok { 21 | break 22 | } 23 | 24 | id := strings.LastIndex(file, "/") + 1 25 | caller := fmt.Sprintf("%s:%d", (string)(([]byte(file))[id:]), line) 26 | callers = append(callers, caller) 27 | } 28 | 29 | return callers 30 | } 31 | 32 | // 如果想定制logger可以实现此接口,否则日志将使用默认打印 33 | type ILogger interface { 34 | SetLevel(level LogLevel) 35 | Info(format string, v ...interface{}) 36 | Debug(format string, v ...interface{}) 37 | Warn(format string, v ...interface{}) 38 | Error(format string, v ...interface{}) 39 | Fatal(format string, v ...interface{}) 40 | } 41 | 42 | type LogLevel int 43 | 44 | // ALL < DEBUG < INFO < WARN < ERROR < FATAL < OFF 45 | const ( 46 | LOG_LEVEL_DEBUG LogLevel = iota 47 | LOG_LEVEL_INFO 48 | LOG_LEVEL_WARN 49 | LOG_LEVEL_ERROR 50 | LOG_LEVEL_FATAL 51 | LOG_LEVEL_OFF 52 | ) 53 | 54 | type OutType int 55 | 56 | const ( 57 | OutTypeFile OutType = iota 58 | OutTypeStd 59 | ) 60 | 61 | type iOut interface { 62 | getOutType() OutType 63 | println(msg string) 64 | Close() 65 | } 66 | 67 | type logger struct { 68 | out iOut 69 | logLevel LogLevel 70 | mu sync.Mutex 71 | callStepDepth int 72 | } 73 | 74 | var defaultLogLevel = LOG_LEVEL_DEBUG 75 | 76 | type stdLogger struct{ mu sync.Mutex } 77 | 78 | func (sl *stdLogger) println(v string) { 79 | sl.mu.Lock() 80 | defer sl.mu.Unlock() 81 | 82 | fmt.Println(v) 83 | } 84 | 85 | func (sl *stdLogger) getOutType() OutType { 86 | return OutTypeStd 87 | } 88 | 89 | func (sl *stdLogger) 
Close() {} 90 | 91 | var defLog = &logger{logLevel: defaultLogLevel, out: &stdLogger{}, callStepDepth: 0} 92 | 93 | func (l *logger) getPrefix(flag string) string { 94 | prefix := fmt.Sprintf("%s [%5s] - ", now(), flag) 95 | callers := getCallers() 96 | if len(callers) >= 6 { 97 | prefix = fmt.Sprintf("%s [%5s] [%s] - ", now(), flag, callers[3+l.callStepDepth]) 98 | } 99 | 100 | return prefix 101 | } 102 | 103 | func (l *logger) SetCallStepDepth(stepDepth int) { 104 | l.mu.Lock() 105 | defer l.mu.Unlock() 106 | 107 | l.callStepDepth = stepDepth 108 | } 109 | 110 | func (l *logger) SetLevel(level LogLevel) { 111 | l.mu.Lock() 112 | defer l.mu.Unlock() 113 | 114 | l.logLevel = level 115 | } 116 | 117 | func (l *logger) Info(format string, v ...interface{}) { 118 | if l.logLevel <= LOG_LEVEL_INFO { 119 | logStr := fmt.Sprintf(l.getPrefix("INFO")+format, v...) 120 | 121 | l.out.println(logStr) 122 | } 123 | } 124 | 125 | func (l *logger) Debug(format string, v ...interface{}) { 126 | if l.logLevel <= LOG_LEVEL_DEBUG { 127 | logStr := fmt.Sprintf(l.getPrefix("DEBUG")+format, v...) 128 | l.out.println(logStr) 129 | } 130 | } 131 | 132 | func (l *logger) Warn(format string, v ...interface{}) { 133 | if l.logLevel <= LOG_LEVEL_WARN { 134 | logStr := fmt.Sprintf(l.getPrefix("WARN")+format, v...) 135 | l.out.println(logStr) 136 | } 137 | } 138 | 139 | func (l *logger) Error(format string, v ...interface{}) { 140 | if l.logLevel <= LOG_LEVEL_ERROR { 141 | logStr := fmt.Sprintf(l.getPrefix("ERROR")+format, v...) 142 | l.out.println(logStr) 143 | } 144 | } 145 | 146 | func (l *logger) Fatal(format string, v ...interface{}) { 147 | if l.logLevel <= LOG_LEVEL_FATAL { 148 | logStr := fmt.Sprintf(l.getPrefix("FATAL")+format, v...) 
149 | l.out.println(logStr) 150 | } 151 | } 152 | -------------------------------------------------------------------------------- /mapper.go: -------------------------------------------------------------------------------- 1 | package gobatis 2 | 3 | type mappedStmt struct { 4 | dbType DBType 5 | sqlSource iSqlSource 6 | resultType ResultType 7 | } 8 | -------------------------------------------------------------------------------- /nilable_structs.go: -------------------------------------------------------------------------------- 1 | package gobatis 2 | 3 | import ( 4 | "database/sql" 5 | ) 6 | 7 | type NullBool = sql.NullBool 8 | type NullFloat64 = sql.NullFloat64 9 | type NullInt64 = sql.NullInt64 10 | type NullString = sql.NullString 11 | type NullTime = sql.NullTime 12 | -------------------------------------------------------------------------------- /option.go: -------------------------------------------------------------------------------- 1 | package gobatis 2 | 3 | type OptionType int 4 | 5 | const ( 6 | OptionTypeFile OptionType = 1 7 | OptionTypeDS OptionType = 2 8 | OptionTypeDB OptionType = 3 9 | ) 10 | 11 | type IOption interface { 12 | Type() OptionType 13 | ToDBConf() *DBConfig 14 | } 15 | -------------------------------------------------------------------------------- /option_db.go: -------------------------------------------------------------------------------- 1 | package gobatis 2 | 3 | type DBOption struct { 4 | dbs map[string]*GoBatisDB 5 | showSQL bool 6 | mappers []string 7 | } 8 | 9 | var _ IOption = &DBOption{} 10 | 11 | func NewDBOption() *DBOption { 12 | return &DBOption{} 13 | } 14 | 15 | func NewDBOption_(dbs map[string]*GoBatisDB, showSQL bool, mappers []string) *DBOption { 16 | return &DBOption{ 17 | dbs: dbs, 18 | showSQL: showSQL, 19 | mappers: mappers, 20 | } 21 | } 22 | 23 | func (ds *DBOption) DB(dbs map[string]*GoBatisDB) *DBOption { 24 | ds.dbs = dbs 25 | return ds 26 | } 27 | 28 | func (ds *DBOption) ShowSQL(showSQL bool) 
*DBOption { 29 | ds.showSQL = showSQL 30 | return ds 31 | } 32 | 33 | func (ds *DBOption) Mappers(mappers []string) *DBOption { 34 | ds.mappers = mappers 35 | return ds 36 | } 37 | 38 | func (ds *DBOption) Type() OptionType { 39 | return OptionTypeDB 40 | } 41 | 42 | func (ds *DBOption) ToDBConf() *DBConfig { 43 | dbconf := NewDBConfigBuilder(). 44 | DB(ds.dbs). 45 | ShowSQL(ds.showSQL). 46 | Mappers(ds.mappers). 47 | Build() 48 | return dbconf 49 | } 50 | -------------------------------------------------------------------------------- /option_ds.go: -------------------------------------------------------------------------------- 1 | package gobatis 2 | 3 | type DSOption struct { 4 | dss []*DataSource 5 | showSQL bool 6 | mappers []string 7 | } 8 | 9 | var _ IOption = &DSOption{} 10 | 11 | func NewDSOption() *DSOption { 12 | return &DSOption{} 13 | } 14 | 15 | func NewDSOption_(dss []*DataSource, showSQL bool, mappers []string) *DSOption { 16 | return &DSOption{ 17 | dss: dss, 18 | showSQL: showSQL, 19 | mappers: mappers, 20 | } 21 | } 22 | 23 | func (ds *DSOption) DS(dss []*DataSource) *DSOption { 24 | ds.dss = dss 25 | return ds 26 | } 27 | 28 | func (ds *DSOption) ShowSQL(showSQL bool) *DSOption { 29 | ds.showSQL = showSQL 30 | return ds 31 | } 32 | 33 | func (ds *DSOption) Mappers(mappers []string) *DSOption { 34 | ds.mappers = mappers 35 | return ds 36 | } 37 | 38 | func (ds *DSOption) Type() OptionType { 39 | return OptionTypeDS 40 | } 41 | 42 | func (ds *DSOption) ToDBConf() *DBConfig { 43 | dbconf := NewDBConfigBuilder(). 44 | DS(ds.dss). 45 | ShowSQL(ds.showSQL). 46 | Mappers(ds.mappers). 
47 | Build() 48 | return dbconf 49 | } 50 | -------------------------------------------------------------------------------- /option_file.go: -------------------------------------------------------------------------------- 1 | package gobatis 2 | 3 | import ( 4 | "io/ioutil" 5 | "os" 6 | ) 7 | 8 | type FileOption struct { 9 | path string 10 | } 11 | 12 | var _ IOption = &FileOption{} 13 | 14 | // NewFileOption db config file path, default: db.yml 15 | func NewFileOption(pt ...string) *FileOption { 16 | path := "db.yml" 17 | if len(pt) > 0 { 18 | path = pt[0] 19 | } 20 | return &FileOption{ 21 | path: path, 22 | } 23 | } 24 | 25 | func (f *FileOption) Type() OptionType { 26 | return OptionTypeFile 27 | } 28 | 29 | func (f *FileOption) ToDBConf() *DBConfig { 30 | file, err := os.Open(f.path) 31 | if nil != err { 32 | panic("Open db conf err:" + err.Error()) 33 | } 34 | 35 | r, err := ioutil.ReadAll(file) 36 | if nil != err { 37 | panic("Read db conf err:" + err.Error()) 38 | } 39 | 40 | dbConf := buildDbConfig(string(r)) 41 | return dbConf 42 | } 43 | -------------------------------------------------------------------------------- /params_test.go: -------------------------------------------------------------------------------- 1 | package gobatis 2 | 3 | import ( 4 | "github.com/stretchr/testify/assert" 5 | "reflect" 6 | "testing" 7 | ) 8 | 9 | type TestStruct struct { 10 | } 11 | 12 | func TestVal(t *testing.T) { 13 | paramsInt := 1 14 | v := reflect.ValueOf(paramsInt) 15 | assert.True(t, v.Kind() == reflect.Int, "test fail: params is not int") 16 | 17 | paramsInt64 := int64(1) 18 | v = reflect.ValueOf(paramsInt64) 19 | assert.True(t, v.Kind() == reflect.Int64, "test fail: params is not int64") 20 | 21 | paramsString := "" 22 | v = reflect.ValueOf(paramsString) 23 | assert.True(t, v.Kind() == reflect.String, "test fail: params is not string") 24 | 25 | paramsSlice := []int{1, 2, 3} 26 | v = reflect.ValueOf(paramsSlice) 27 | assert.True(t, v.Kind() == reflect.Slice, 
"test fail: params is not slice") 28 | 29 | paramsStruct := TestStruct{} 30 | v = reflect.ValueOf(paramsStruct) 31 | assert.True(t, v.Kind() == reflect.Struct, "test fail: params is not struct") 32 | 33 | paramsPtr := &TestStruct{} 34 | v = reflect.ValueOf(paramsPtr) 35 | assert.True(t, v.Kind() == reflect.Ptr, "test fail: params is not ptr") 36 | v = v.Elem() 37 | assert.True(t, v.Kind() == reflect.Struct, "test fail: params is not struct") 38 | 39 | paramsStructs := []*TestStruct{{}, {}, {}} 40 | v = reflect.ValueOf(paramsStructs) 41 | assert.True(t, v.Kind() == reflect.Slice, "test fail: params is not slice") 42 | assert.True(t, v.Len() == 3, "test fail: params len != 3") 43 | v0 := v.Index(0) 44 | assert.True(t, v0.Kind() == reflect.Ptr, "test fail: ele is not ptr") 45 | v0 = v0.Elem() 46 | assert.True(t, v0.Kind() == reflect.Struct, "test fail: ele is not struct") 47 | } 48 | -------------------------------------------------------------------------------- /parser_xml.go: -------------------------------------------------------------------------------- 1 | package gobatis 2 | 3 | import ( 4 | "encoding/xml" 5 | "io" 6 | "strings" 7 | ) 8 | 9 | type ElemType string 10 | 11 | const ( 12 | eleTpText ElemType = "text" // 静态文本节点 13 | eleTpNode ElemType = "node" // 节点子节点 14 | ) 15 | 16 | type node struct { 17 | Id string 18 | Namespace string 19 | Name string 20 | Attrs map[string]xml.Attr 21 | Elements []element 22 | } 23 | 24 | func (n *node) getAttr(attr string) string { 25 | res := "" 26 | at, ok := n.Attrs[attr] 27 | if ok { 28 | res = at.Value 29 | } 30 | 31 | return res 32 | } 33 | 34 | type element struct { 35 | ElementType ElemType 36 | Val interface{} 37 | } 38 | 39 | func parse(r io.Reader) *node { 40 | parser := xml.NewDecoder(r) 41 | var root node 42 | namespace := "" 43 | 44 | st := NewStack() 45 | for { 46 | token, err := parser.Token() 47 | if err != nil { 48 | break 49 | } 50 | switch t := token.(type) { 51 | case xml.StartElement: //tag start 52 | 
elmt := xml.StartElement(t) 53 | name := elmt.Name.Local 54 | attr := elmt.Attr 55 | attrMap := make(map[string]xml.Attr) 56 | for _, val := range attr { 57 | attrMap[val.Name.Local] = val 58 | } 59 | node := node{ 60 | Name: name, 61 | Attrs: attrMap, 62 | Elements: make([]element, 0), 63 | } 64 | 65 | id := node.getAttr("id") 66 | node.Id = id 67 | 68 | if namespace == "" { 69 | namespace = node.getAttr("namespace") 70 | } 71 | 72 | st.Push(node) 73 | 74 | case xml.EndElement: //tag end 75 | if st.Len() > 0 { 76 | //cur node 77 | n := st.Pop().(node) 78 | 79 | // set namespace 80 | if namespace != "" { 81 | n.Namespace = namespace + "." 82 | } 83 | 84 | if st.Len() > 0 { //if the root node then append to element 85 | e := element{ 86 | ElementType: eleTpNode, 87 | Val: n, 88 | } 89 | 90 | pn := st.Pop().(node) 91 | els := pn.Elements 92 | els = append(els, e) 93 | pn.Elements = els 94 | st.Push(pn) 95 | } else { //else root = n 96 | root = n 97 | } 98 | } 99 | case xml.CharData: //tag content 100 | if st.Len() > 0 { 101 | n := st.Pop().(node) 102 | 103 | bytes := xml.CharData(t) 104 | content := strings.TrimSpace(string(bytes)) 105 | if content != "" { 106 | e := element{ 107 | ElementType: eleTpText, 108 | Val: content, 109 | } 110 | els := n.Elements 111 | els = append(els, e) 112 | n.Elements = els 113 | } 114 | 115 | st.Push(n) 116 | } 117 | 118 | case xml.Comment: 119 | case xml.ProcInst: 120 | case xml.Directive: 121 | default: 122 | } 123 | } 124 | 125 | if st.Len() != 0 { 126 | panic("Parse xml error, there is tag no close, please check your xml config!") 127 | } 128 | 129 | return &root 130 | } 131 | -------------------------------------------------------------------------------- /parser_xml_test.go: -------------------------------------------------------------------------------- 1 | package gobatis 2 | 3 | import ( 4 | "github.com/stretchr/testify/assert" 5 | "strings" 6 | "testing" 7 | ) 8 | 9 | func TestXmlNode_parse(t *testing.T) { 10 | xmlStr := ` 
11 | 12 | 13 | 23 | 24 | UPDATE t_gap SET gap = #{gap} WHERE id = #{id} 25 | 26 | 27 | ` 28 | r := strings.NewReader(xmlStr) 29 | rn := parse(r) 30 | assert.NotNil(t, rn, "Parse xml result is nil") 31 | } 32 | -------------------------------------------------------------------------------- /proc_params.go: -------------------------------------------------------------------------------- 1 | package gobatis 2 | 3 | import ( 4 | "reflect" 5 | "strconv" 6 | "time" 7 | ) 8 | 9 | // parameters process util 10 | // @params 11 | // param interface{} : sql query params 12 | // @return 13 | // map[string]interface{} : return the convert map 14 | func paramProcess(param interface{}) map[string]interface{} { 15 | v := reflect.ValueOf(param) 16 | if v.Kind() == reflect.Ptr { 17 | v = v.Elem() 18 | } 19 | 20 | res := make(map[string]interface{}) 21 | switch v.Kind() { 22 | case reflect.Array, reflect.Slice: 23 | LOG.Warn("Foreach tag collection element must not be slice or array") 24 | res = listToMap(param) 25 | case reflect.Struct: 26 | res = structToMap(param) 27 | case reflect.Map: 28 | res = param.(map[string]interface{}) 29 | default: 30 | res["0"] = param 31 | } 32 | 33 | return res 34 | } 35 | 36 | // convert list to map 37 | // @params 38 | // arr interface{} : list param 39 | // @return 40 | // map[string]interface{} : return the convert map 41 | func listToMap(arr interface{}) map[string]interface{} { 42 | res := make(map[string]interface{}) 43 | objVal := reflect.ValueOf(arr) 44 | if objVal.Kind() != reflect.Array && objVal.Kind() != reflect.Slice { 45 | return res 46 | } 47 | 48 | res["list"] = arr 49 | 50 | for i := 0; i < objVal.Len(); i++ { 51 | res[strconv.Itoa(i)] = objVal.Index(i).Interface() 52 | } 53 | 54 | return res 55 | } 56 | 57 | // convert struct to map 58 | // @params 59 | // s interface{} : struct param 60 | // @return 61 | // map[string]interface{} : return the convert map 62 | func structToMap(s interface{}) map[string]interface{} { 63 | objVal := 
reflect.ValueOf(s) 64 | if objVal.Kind() == reflect.Ptr { 65 | objVal = objVal.Elem() 66 | } 67 | 68 | res := make(map[string]interface{}) 69 | 70 | tp := objVal.Type() 71 | switch tp.Name() { 72 | case "Time": 73 | res["0"] = nil 74 | if nil != s { 75 | res["0"] = s.(time.Time).Format("2006-01-02 15:04:05") 76 | } 77 | case "NullString": 78 | res["0"] = nil 79 | if nil != s { 80 | ns := s.(NullString) 81 | if ns.Valid { 82 | str, _ := ns.Value() 83 | res["0"] = str 84 | } 85 | } 86 | case "NullInt64": 87 | res["0"] = nil 88 | if nil != s { 89 | ns := s.(NullInt64) 90 | if ns.Valid { 91 | str, _ := ns.Value() 92 | res["0"] = str 93 | } 94 | } 95 | case "NullBool": 96 | res["0"] = nil 97 | if nil != s { 98 | ns := s.(NullBool) 99 | if ns.Valid { 100 | str, _ := ns.Value() 101 | res["0"] = str 102 | } 103 | } 104 | case "NullFloat64": 105 | res["0"] = nil 106 | if nil != s { 107 | ns := s.(NullFloat64) 108 | if ns.Valid { 109 | str, _ := ns.Value() 110 | res["0"] = str 111 | } 112 | } 113 | case "NullTime": 114 | res["0"] = nil 115 | if nil != s { 116 | ns := s.(NullTime) 117 | if ns.Valid { 118 | str, _ := ns.Value() 119 | res["0"] = str 120 | } 121 | } 122 | default: 123 | objType := objVal.Type() 124 | for i := 0; i < objVal.NumField(); i++ { 125 | fieldVal := objVal.Field(i) 126 | if fieldVal.CanInterface() { 127 | field := objType.Field(i) 128 | 129 | data, ok := fieldToVal(fieldVal.Interface()) 130 | if ok { 131 | res[field.Name] = data 132 | // 同时可以使用tag做参数名 https://github.com/wenj91/gobatis/issues/43 133 | tag := field.Tag.Get("field") 134 | if tag != "" && tag != "-" { 135 | res[tag] = data 136 | } 137 | } 138 | } 139 | } 140 | } 141 | 142 | return res 143 | } 144 | 145 | func fieldToVal(field interface{}) (interface{}, bool) { 146 | objVal := reflect.ValueOf(field) 147 | 148 | k := objVal.Kind() 149 | switch k { 150 | case reflect.Chan, reflect.Func, reflect.Map, reflect.Ptr, reflect.UnsafePointer, 151 | reflect.Interface, reflect.Slice: 152 | if 
objVal.IsNil() { 153 | return nil, false 154 | } 155 | } 156 | 157 | if objVal.Kind() == reflect.Ptr { 158 | objVal = objVal.Elem() 159 | } 160 | 161 | tp := objVal.Type() 162 | switch tp.Name() { 163 | case "Time": 164 | return field.(time.Time).Format("2006-01-02 15:04:05"), true 165 | case "NullString": 166 | ns := field.(NullString) 167 | if ns.Valid { 168 | str, _ := ns.Value() 169 | return str, true 170 | } 171 | case "NullInt64": 172 | ni64 := field.(NullInt64) 173 | if ni64.Valid { 174 | i, _ := ni64.Value() 175 | return i, true 176 | } 177 | case "NullBool": 178 | nb := field.(NullBool) 179 | if nb.Valid { 180 | b, _ := nb.Value() 181 | return b, true 182 | } 183 | case "NullFloat64": 184 | nf := field.(NullFloat64) 185 | if nf.Valid { 186 | f, _ := nf.Value() 187 | return f, true 188 | } 189 | case "NullTime": 190 | nt := field.(NullTime) 191 | if nt.Valid { 192 | t, _ := nt.Value() 193 | return t.(time.Time).Format("2006-01-02 15:04:05"), true 194 | } 195 | default: 196 | return field, true 197 | } 198 | 199 | return nil, false 200 | } 201 | -------------------------------------------------------------------------------- /proc_params_test.go: -------------------------------------------------------------------------------- 1 | package gobatis 2 | 3 | import ( 4 | "github.com/stretchr/testify/assert" 5 | "testing" 6 | "time" 7 | ) 8 | 9 | type TestStruct2 struct { 10 | T NullTime 11 | S NullString 12 | Id int64 13 | } 14 | 15 | func TestParams(t *testing.T) { 16 | paramB := "" 17 | res := paramProcess(paramB) 18 | assert.NotNil(t, res["0"], "test fail: res[0] == nil") 19 | assert.Equal(t, res["0"], "", "test fail: res[0] != ''") 20 | 21 | paramM := map[string]interface{}{ 22 | "id": nil, 23 | "name": "wenj91", 24 | } 25 | res = paramProcess(paramM) 26 | assert.Nil(t, res["id"], "test fail: res['id'] != nil") 27 | assert.NotNil(t, res["name"], "test fail: res['name'] == nil") 28 | assert.Equal(t, res["name"], "wenj91", "test fail: res['name'] != 'wenj91'") 
29 | 30 | paramNil := NullString{"str", true} 31 | res = paramProcess(paramNil) 32 | assert.NotNil(t, res["0"], "test fail: res['0'] == nil") 33 | assert.Equal(t, res["0"], "str", "test fail: res['0'] != 'str'") 34 | 35 | tt, _ := time.Parse("2006-01-02 15:04:05", "2006-01-02 15:04:05") 36 | paramS := &TestStruct2{ 37 | T: NullTime{tt, true}, 38 | } 39 | res = paramProcess(paramS) 40 | assert.NotNil(t, res["T"], "test fail: res['T'] == nil") 41 | assert.Equal(t, res["T"], "2006-01-02 15:04:05", "test fail: res['T'] != '2006-01-02 15:04:05'") 42 | assert.Nil(t, res["S"], "test fail: res['S'] != nil") 43 | assert.Equal(t, res["Id"], int64(0), "test fail: res['Id'] != 0") 44 | } 45 |
--------------------------------------------------------------------------------
/proc_res.go:
--------------------------------------------------------------------------------
1 | package gobatis
2 |
3 | import (
4 | 	"database/sql"
5 | 	"errors"
6 | 	"reflect"
7 | )
8 |
9 | // resultTypeProc copies the rows of a query result into res; the shape res
10 | // must have depends on the ResultType the proc is registered for.
11 | type resultTypeProc = func(rows *sql.Rows, res interface{}) error
12 |
13 | // resSetProcMap selects the result-set processor for each ResultType.
14 | // The array result types reuse the slice processors.
15 | var resSetProcMap = map[ResultType]resultTypeProc{
16 | 	resultTypeMap:     resMapProc,
17 | 	resultTypeMaps:    resMapsProc,
18 | 	resultTypeSlice:   resSliceProc,
19 | 	resultTypeArray:   resSliceProc,
20 | 	resultTypeSlices:  resSlicesProc,
21 | 	resultTypeArrays:  resSlicesProc,
22 | 	resultTypeValue:   resValueProc,
23 | 	resultTypeStructs: resStructsProc,
24 | 	resultTypeStruct:  resStructProc,
25 | }
26 |
27 | // resStructProc scans a single-row result into a struct.
28 | //
29 | // res must be the address of a nil struct pointer (var r *XXX; pass &r).
30 | // When rows are returned, *res is set to a new struct filled from the first
31 | // row; surplus rows are ignored with a warning. When no row is returned, *res
32 | // stays nil and only a warning is logged.
33 | //
34 | // FIXME (open question from the original author): "no row" and "too many
35 | // rows" are deliberately not errors. The error result is reserved for invalid
36 | // arguments and query failures, so a caller that tests only err to detect a
37 | // missing row cannot confuse it with an SQL or parameter error.
38 | func resStructProc(rows *sql.Rows, res interface{}) error {
39 | 	resVal := reflect.ValueOf(res)
40 | 	if resVal.Kind() != reflect.Ptr {
41 | 		return errors.New("struct query result must be ptr")
42 | 	}
43 |
44 | 	target := resVal.Elem()
45 | 	// Expect a **Struct whose inner pointer is still nil. The struct check
46 | 	// keeps rowsToStructs from panicking in NumField on e.g. **int.
47 | 	if target.Kind() != reflect.Ptr ||
48 | 		!target.IsValid() ||
49 | 		target.Elem().Kind() != reflect.Invalid ||
50 | 		target.Type().Elem().Kind() != reflect.Struct {
51 | 		tips := `
52 | var res *XXX
53 | queryParams := make(map[string]interface{})
54 | queryParams["id"] = id
55 | gb.Select("selectXXXById", queryParams)(&res)
56 |
57 | Tips: "(&res)" --> don't forget "&"
58 | `
59 | 		return errors.New("Struct query result must be a struct ptr, " +
60 | 			"and params res is the address of ptr, e.g. " + tips)
61 | 	}
62 |
63 | 	arr, err := rowsToStructs(rows, target.Type().Elem())
64 | 	if nil != err {
65 | 		return err
66 | 	}
67 |
68 | 	if len(arr) == 0 {
69 | 		LOG.Warn("Struct query result is nil")
70 | 		return nil
71 | 	}
72 |
73 | 	if len(arr) > 1 {
74 | 		LOG.Warn("Struct query result more than one row")
75 | 	}
76 | 	target.Set(reflect.ValueOf(arr[0]))
77 |
78 | 	return nil
79 | }
80 |
81 | // resStructsProc appends one element per row to the slice that res points
82 | // to. Elements may be structs or pointers to structs.
83 | func resStructsProc(rows *sql.Rows, res interface{}) error {
84 | 	sliceVal := reflect.ValueOf(res)
85 | 	if sliceVal.Kind() != reflect.Ptr {
86 | 		return errors.New("structs query result must be ptr")
87 | 	}
88 |
89 | 	slice := reflect.Indirect(sliceVal)
90 | 	// Arrays used to pass this check and then panic in reflect.Append,
91 | 	// which only accepts slices.
92 | 	if slice.Kind() != reflect.Slice {
93 | 		return errors.New("structs query result must be slice")
94 | 	}
95 |
96 | 	elemType := slice.Type().Elem()
97 | 	isPtr := elemType.Kind() == reflect.Ptr
98 | 	structType := elemType
99 | 	if isPtr {
100 | 		structType = elemType.Elem()
101 | 	}
102 |
103 | 	if structType.Kind() != reflect.Struct {
104 | 		return errors.New("structs query results item must be struct")
105 | 	}
106 |
107 | 	arr, err := rowsToStructs(rows, structType)
108 | 	if nil != err {
109 | 		return err
110 | 	}
111 |
112 | 	// rowsToStructs yields *struct values; dereference them for []struct.
113 | 	items := make([]reflect.Value, 0, len(arr))
114 | 	for _, item := range arr {
115 | 		v := reflect.ValueOf(item)
116 | 		if !isPtr {
117 | 			v = v.Elem()
118 | 		}
119 | 		items = append(items, v)
120 | 	}
121 | 	slice.Set(reflect.Append(slice, items...))
122 |
123 | 	return nil
124 | }
125 |
126 | // resValueProc stores the only column of a single-row result in *res, after
127 | // passing it through dataToFieldVal with *res's type. An empty result or a
128 | // NULL value leaves *res untouched.
129 | func resValueProc(rows *sql.Rows, res interface{}) error {
130 | 	resPtr := reflect.ValueOf(res)
131 | 	if resPtr.Kind() != reflect.Ptr || resPtr.IsNil() {
132 | 		return errors.New("value query result must be ptr")
133 | 	}
134 |
135 | 	arr, err := rowsToSlices(rows)
136 | 	if nil != err {
137 | 		return err
138 | 	}
139 |
140 | 	if len(arr) > 1 {
141 | 		return errors.New("value query result more than one row")
142 | 	}
143 |
144 | 	// An empty result set used to panic on arr[0].
145 | 	if len(arr) == 0 {
146 | 		return nil
147 | 	}
148 |
149 | 	cols := arr[0].([]interface{})
150 | 	if len(cols) > 1 {
151 | 		return errors.New("value query result more than one col")
152 | 	}
153 |
154 | 	if len(cols) == 1 && nil != cols[0] {
155 | 		value := resPtr.Elem()
156 | 		val := dataToFieldVal(cols[0], value.Type(), "val")
157 | 		value.Set(reflect.ValueOf(val))
158 | 	}
159 |
160 | 	return nil
161 | }
162 |
163 | // resSlicesProc appends one inner slice per row, holding that row's columns,
164 | // to the slice of slices that res points to. NULL columns become the zero
165 | // value of the inner element type.
166 | func resSlicesProc(rows *sql.Rows, res interface{}) error {
167 | 	resPtr := reflect.ValueOf(res)
168 | 	if resPtr.Kind() != reflect.Ptr {
169 | 		return errors.New("slices query result must be ptr")
170 | 	}
171 |
172 | 	value := reflect.Indirect(resPtr)
173 | 	if value.Kind() != reflect.Slice {
174 | 		return errors.New("slices query result must be slice ptr")
175 | 	}
176 |
177 | 	rowType := value.Type().Elem()
178 | 	if rowType.Kind() != reflect.Slice {
179 | 		return errors.New("slices query result must be slice of slices ptr")
180 | 	}
181 |
182 | 	arr, err := rowsToSlices(rows)
183 | 	if nil != err {
184 | 		return err
185 | 	}
186 |
187 | 	for _, item := range arr {
188 | 		cols := item.([]interface{})
189 | 		row := reflect.MakeSlice(rowType, 0, len(cols))
190 | 		for _, val := range cols {
191 | 			// reflect.ValueOf(nil) is the zero Value, which makes Append
192 | 			// panic; a NULL column gets the element type's zero value.
193 | 			elem := reflect.Zero(rowType.Elem())
194 | 			if nil != val {
195 | 				elem = reflect.ValueOf(val)
196 | 			}
197 | 			row = reflect.Append(row, elem)
198 | 		}
199 | 		value.Set(reflect.Append(value, row))
200 | 	}
201 |
202 | 	return nil
203 | }
204 |
180 | func resSliceProc(rows *sql.Rows, res interface{}) error { 181 | resPtr := reflect.ValueOf(res) 182 | if resPtr.Kind() != reflect.Ptr { 183 | return errors.New("slice query result must be ptr") 184 | } 185 | 186 | value := reflect.Indirect(resPtr) 187 | if value.Kind() != reflect.Slice { 188 | return errors.New("slice query result must be slice ptr") 189 | } 190 | 191 | arr, err := rowsToSlices(rows) 192 | if nil != err { 193 | return err 194 | } 195 | 196 | if len(arr) > 1 { 197 | return errors.New("slice query result more than one row") 198 | } 199 | 200 | if len(arr) > 0 { 201 | tempResSlice := arr[0].([]interface{}) 202 | for _, v := range tempResSlice { 203 | value.Set(reflect.Append(value, reflect.ValueOf(v))) 204 | } 205 | } 206 | 207 | return nil 208 | } 209 | 210 | func resMapProc(rows *sql.Rows, res interface{}) error { 211 | resBean := reflect.ValueOf(res) 212 | if resBean.Kind() == reflect.Ptr { 213 | return errors.New("map query result can not be ptr") 214 | } 215 | 216 | if resBean.Kind() != reflect.Map { 217 | return errors.New("map query result must be map") 218 | } 219 | 220 | arr, err := rowsToMaps(rows) 221 | if nil != err { 222 | return err 223 | } 224 | 225 | if len(arr) > 1 { 226 | return errors.New("map query result more than one row") 227 | } 228 | 229 | if len(arr) > 0 { 230 | resMap := res.(map[string]interface{}) 231 | tempResMap := arr[0].(map[string]interface{}) 232 | for k, v := range tempResMap { 233 | resMap[k] = v 234 | } 235 | } 236 | 237 | return nil 238 | } 239 | 240 | func resMapsProc(rows *sql.Rows, res interface{}) error { 241 | resPtr := reflect.ValueOf(res) 242 | if resPtr.Kind() != reflect.Ptr { 243 | return errors.New("maps query result must be ptr") 244 | } 245 | 246 | value := reflect.Indirect(resPtr) 247 | if value.Kind() != reflect.Slice { 248 | return errors.New("maps query result must be slice ptr") 249 | } 250 | arr, err := rowsToMaps(rows) 251 | if nil != err { 252 | return err 253 | } 254 | 255 | for i := 0; 
i < len(arr); i++ { 256 | value.Set(reflect.Append(value, reflect.ValueOf(arr[i]))) 257 | } 258 | 259 | return nil 260 | } 261 | 262 | func rowsToMaps(rows *sql.Rows) ([]interface{}, error) { 263 | res := make([]interface{}, 0) 264 | for rows.Next() { 265 | resMap := make(map[string]interface{}) 266 | cols, err := rows.Columns() 267 | if nil != err { 268 | LOG.Error("rows to maps err:%v", err) 269 | return res, err 270 | } 271 | 272 | vals := make([]interface{}, len(cols)) 273 | scanArgs := make([]interface{}, len(cols)) 274 | for i := range vals { 275 | scanArgs[i] = &vals[i] 276 | } 277 | 278 | err = rows.Scan(scanArgs...) 279 | if nil != err { 280 | LOG.Error("rows scan err:%v", err) 281 | return nil, err 282 | } 283 | 284 | for i := 0; i < len(cols); i++ { 285 | val := vals[i] 286 | if nil != val { 287 | v := reflect.ValueOf(val) 288 | if v.Kind() == reflect.Slice || v.Kind() == reflect.Array { 289 | val = string(val.([]uint8)) 290 | } 291 | } 292 | resMap[cols[i]] = val 293 | } 294 | 295 | res = append(res, resMap) 296 | } 297 | 298 | return res, nil 299 | } 300 | 301 | func rowsToSlices(rows *sql.Rows) ([]interface{}, error) { 302 | res := make([]interface{}, 0) 303 | for rows.Next() { 304 | resSlice := make([]interface{}, 0) 305 | cols, err := rows.Columns() 306 | if nil != err { 307 | LOG.Error("rows to slices err:%v", err) 308 | return nil, err 309 | } 310 | 311 | vals := make([]interface{}, len(cols)) 312 | scanArgs := make([]interface{}, len(cols)) 313 | for i := range vals { 314 | scanArgs[i] = &vals[i] 315 | } 316 | 317 | err = rows.Scan(scanArgs...) 
318 | if nil != err { 319 | LOG.Error("rows scan err:%v", err) 320 | return nil, err 321 | } 322 | 323 | for i := 0; i < len(cols); i++ { 324 | val := vals[i] 325 | if nil != val { 326 | v := reflect.ValueOf(val) 327 | if v.Kind() == reflect.Slice || v.Kind() == reflect.Array { 328 | val = string(val.([]uint8)) 329 | } 330 | } 331 | resSlice = append(resSlice, val) 332 | } 333 | 334 | res = append(res, resSlice) 335 | } 336 | 337 | return res, nil 338 | } 339 | 340 | func rowsToStructs(rows *sql.Rows, resultType reflect.Type) ([]interface{}, error) { 341 | fieldsMapper := make(map[string]string) 342 | fields := resultType.NumField() 343 | for i := 0; i < fields; i++ { 344 | field := resultType.Field(i) 345 | fieldsMapper[field.Name] = field.Name 346 | tag := field.Tag.Get("field") 347 | if tag != "" { 348 | fieldsMapper[tag] = field.Name 349 | } 350 | } 351 | 352 | res := make([]interface{}, 0) 353 | for rows.Next() { 354 | cols, err := rows.Columns() 355 | if nil != err { 356 | LOG.Error("rows.Columns() err:%v", err) 357 | return nil, err 358 | } 359 | 360 | vals := make([]interface{}, len(cols)) 361 | scanArgs := make([]interface{}, len(cols)) 362 | for i := range vals { 363 | scanArgs[i] = &vals[i] 364 | } 365 | 366 | err = rows.Scan(scanArgs...) 
367 | if nil != err { 368 | return nil, err 369 | } 370 | 371 | obj := reflect.New(resultType).Elem() 372 | objPtr := reflect.Indirect(obj) 373 | for i := 0; i < len(cols); i++ { 374 | colName := cols[i] 375 | fieldName := fieldsMapper[colName] 376 | field := objPtr.FieldByName(fieldName) 377 | // 设置相关字段的值,并判断是否可设值 378 | if field.CanSet() && vals[i] != nil { 379 | //获取字段类型并设值 380 | ft := field.Type() 381 | isPtr := false 382 | if ft.Kind() == reflect.Ptr { 383 | isPtr = true 384 | ft = ft.Elem() 385 | } 386 | 387 | data := dataToFieldVal(vals[i], ft, fieldName) 388 | if nil != data { 389 | // 数据库返回类型与字段类型不符合的情况下提醒用户 390 | dt := reflect.TypeOf(data) 391 | if dt.Name() != ft.Name() { 392 | warnInfo := "[WARN] fieldType != dataType, filedName:" + fieldName + 393 | " fieldType:" + ft.Name() + 394 | " dataType:" + dt.Name() 395 | LOG.Warn(warnInfo) 396 | } 397 | 398 | if isPtr { 399 | data = dataToPtr(data, ft, fieldName) 400 | val := reflect.ValueOf(data) 401 | field.Set(val) 402 | } else { 403 | val := reflect.ValueOf(data) 404 | field.Set(val) 405 | } 406 | 407 | } 408 | } 409 | } 410 | 411 | if objPtr.CanInterface() { 412 | res = append(res, objPtr.Addr().Interface()) 413 | } 414 | } 415 | 416 | return res, nil 417 | } 418 | -------------------------------------------------------------------------------- /sql_source.go: -------------------------------------------------------------------------------- 1 | package gobatis 2 | 3 | import ( 4 | "strings" 5 | ) 6 | 7 | // gobatis的核心, 从配置到sql, 参数映射...... 
8 | type boundSql struct {
9 | 	sqlStr        string
10 | 	paramMappings []string
11 | 	params        map[string]interface{}
12 | 	extParams     map[string]interface{}
13 | }
14 |
15 | // iSqlSource produces the bound SQL for one execution with params.
16 | type iSqlSource interface {
17 | 	getBoundSql(params map[string]interface{}) *boundSql
18 | }
19 |
20 | // dynamicSqlSource is SQL described by a tree of nodes (if, where, foreach,
21 | // ...) that is rendered anew for every call.
22 | type dynamicSqlSource struct {
23 | 	sqlNode iSqlNode
24 | }
25 |
26 | // getBoundSql renders the node tree against params, then resolves the ${}
27 | // and #{} tokens of the rendered text like a static source. Parameters the
28 | // nodes add while rendering (e.g. foreach items) end up in extParams.
29 | func (d *dynamicSqlSource) getBoundSql(params map[string]interface{}) *boundSql {
30 | 	ctx := newDynamicContext(params)
31 | 	d.sqlNode.build(ctx)
32 |
33 | 	sss := staticSqlSource{
34 | 		sqlStr: ctx.toSql(),
35 | 	}
36 |
37 | 	bs := sss.getBoundSql(params)
38 | 	bs.extParams = ctx.params
39 |
40 | 	return bs
41 | }
42 |
43 | // staticSqlSource is SQL text whose only dynamic parts are ${name} tokens,
44 | // inlined verbatim, and #{name} tokens, bound as "?" placeholders.
45 | type staticSqlSource struct {
46 | 	sqlStr        string
47 | 	paramMappings []string
48 | }
49 |
50 | // getBoundSql resolves the tokens of ss against params. It works on a copy:
51 | // writing the result back into ss used to bake the first call's ${} values
52 | // into the template, so later calls on the same source ignored their params.
53 | func (ss *staticSqlSource) getBoundSql(params map[string]interface{}) *boundSql {
54 | 	work := &staticSqlSource{
55 | 		sqlStr:        ss.sqlStr,
56 | 		paramMappings: append(make([]string, 0, len(ss.paramMappings)), ss.paramMappings...),
57 | 	}
58 | 	work.dollarTokenHandler(params)
59 | 	work.tokenHandler(params)
60 | 	return &boundSql{
61 | 		sqlStr:        work.sqlStr,
62 | 		paramMappings: work.paramMappings,
63 | 		params:        params,
64 | 	}
65 | }
66 |
67 | // dollarTokenHandler inlines every ${name} token with the text of
68 | // params[name]: strings as-is, other values via valToString. The text is not
69 | // escaped, so ${} must only carry trusted values; use #{} for user input.
70 | // It panics when a referenced param is missing or nil.
71 | func (ss *staticSqlSource) dollarTokenHandler(params map[string]interface{}) {
72 | 	if !strings.Contains(ss.sqlStr, "$") {
73 | 		return
74 | 	}
75 |
76 | 	ss.sqlStr = replaceSqlTokens(ss.sqlStr, "${", func(name string) string {
77 | 		item, ok := params[name]
78 | 		if !ok {
79 | 			LOG.Error("param %s, not found", name)
80 | 			panic("params:" + name + " not found")
81 | 		}
82 |
83 | 		// A bare item.(string) used to crash on non-string values such as
84 | 		// a numeric LIMIT.
85 | 		if s, isStr := item.(string); isStr {
86 | 			return s
87 | 		}
88 | 		if nil == item {
89 | 			LOG.Error("param %s, is nil", name)
90 | 			panic("params:" + name + " is nil")
91 | 		}
92 | 		return valToString(item)
93 | 	})
94 | }
95 |
96 | // tokenHandler replaces every #{name} token with a "?" placeholder and
97 | // appends name to ss.paramMappings in order of appearance.
98 | func (ss *staticSqlSource) tokenHandler(params map[string]interface{}) {
99 | 	ss.sqlStr = replaceSqlTokens(ss.sqlStr, "#{", func(name string) string {
100 | 		ss.paramMappings = append(ss.paramMappings, name)
101 | 		return "?"
102 | 	})
103 | }
104 |
105 | // replaceSqlTokens rewrites each "<open>name}" token of sqlStr (open being
106 | // "${" or "#{") with replace(name), name being the text between the
107 | // delimiters with surrounding whitespace trimmed. Text from an unclosed token
108 | // onwards is kept as-is and a warning is logged. The result is trimmed.
109 | func replaceSqlTokens(sqlStr, open string, replace func(name string) string) string {
110 | 	var sb strings.Builder
111 | 	rest := sqlStr
112 | 	for {
113 | 		begin := strings.Index(rest, open)
114 | 		if begin < 0 {
115 | 			break
116 | 		}
117 |
118 | 		nameStart := begin + len(open)
119 | 		end := strings.IndexByte(rest[nameStart:], '}')
120 | 		if end < 0 {
121 | 			LOG.Warn("token not close")
122 | 			break
123 | 		}
124 |
125 | 		sb.WriteString(rest[:begin])
126 | 		sb.WriteString(replace(strings.TrimSpace(rest[nameStart : nameStart+end])))
127 | 		rest = rest[nameStart+end+1:]
128 | 	}
129 |
130 | 	sb.WriteString(rest)
131 | 	return strings.TrimSpace(sb.String())
132 | }
133 |
--------------------------------------------------------------------------------
/sql_source_test.go:
--------------------------------------------------------------------------------
1 | package gobatis
2 |
3 | import (
4 | 	"fmt"
5 | 	"testing"
6 |
7 | 	"github.com/stretchr/testify/assert"
8 | )
9 |
10 | func TestStaticSqlSource_getBoundSql(t *testing.T) {
11 | 	tpl := "select * from t_gap where id = #{id} and gap = #{gap}"
12 | 	sss := &staticSqlSource{
13 | 		sqlStr:        tpl,
14 | 		paramMappings: make([]string, 0),
15 | 	}
16 |
17 | 	bs := sss.getBoundSql(map[string]interface{}{
18 | 		"id":  1,
19 | 		"gap": 10,
20 | 	})
21 |
22 | 	expc := "select * from t_gap where id = ? and gap = ?"
23 | 	assert.Equal(t, expc, bs.sqlStr, "test failed, actual:"+bs.sqlStr)
24 | 	assert.Equal(t, []string{"id", "gap"}, bs.paramMappings)
25 | 	assert.Equal(t, 1, bs.params["id"], "test failed, actual:"+fmt.Sprintf("%v", bs.params["id"]))
26 | 	assert.Equal(t, 10, bs.params["gap"], "test failed, actual:"+fmt.Sprintf("%v", bs.params["gap"]))
27 | 	// getBoundSql must leave the template untouched for the next call.
28 | 	assert.Equal(t, tpl, sss.sqlStr)
29 | }
30 |
31 | func TestStaticSqlSource_dollarToken(t *testing.T) {
32 | 	sss := &staticSqlSource{sqlStr: "select * from ${table} limit ${n}"}
33 |
34 | 	bs := sss.getBoundSql(map[string]interface{}{"table": "t_a", "n": 10})
35 | 	assert.Equal(t, "select * from t_a limit 10", bs.sqlStr)
36 |
37 | 	// A second call with other params must not reuse the first call's values.
38 | 	bs = sss.getBoundSql(map[string]interface{}{"table": "t_b", "n": 5})
39 | 	assert.Equal(t, "select * from t_b limit 5", bs.sqlStr)
40 | }
41 |
42 | func TestDynamicSqlSource_getBoundSql(t *testing.T) {
43 | 	params := map[string]interface{}{
44 | 		"name":   "Sean",
45 | 		"age":    18,
46 | 		"code":   18,
47 | 		"array":  []map[string]interface{}{{"idea": "11"}, {"idea": "22"}, {"idea": "33"}},
48 | 		"array1": []string{"11", "22", "33"},
49 | 		"array2": []s{{A: "aa"}, {A: "bb"}, {A: "cc"}},
50 | 	}
51 |
52 | 	msn := &mixedSqlNode{
53 | 		sqlNodes: []iSqlNode{
54 | 			&textSqlNode{
55 | 				content: "select 1 from t_gap ",
56 | 			},
57 |
58 | 			&whereSqlNode{
59 | 				sqlNodes: []iSqlNode{
60 | 					&trimSqlNode{
61 | 						prefixOverrides: "and",
62 | 						sqlNodes: []iSqlNode{
63 | 							&ifSqlNode{
64 | 								test: "age == 18",
65 | 								sqlNode: &textSqlNode{
66 | 									content: "and age = #{age}",
67 | 								},
68 | 							},
69 | 							&ifSqlNode{
70 | 								test: "name == 'Sean'",
71 | 								sqlNode: &textSqlNode{
72 | 									content: "and name = #{name}",
73 | 								},
74 | 							},
75 | 						},
76 | 					},
77 | 					&chooseNode{
78 | 						sqlNodes: []iSqlNode{
79 | 							&ifSqlNode{
80 | 								test: "code == 18",
81 | 								sqlNode: &textSqlNode{
82 | 									content: "and code = 'cctv'",
83 | 								},
84 | 							},
85 | 						},
86 | 					},
87 | 				},
88 | 			},
89 |
90 | 			&foreachSqlNode{
91 | 				sqlNode: &mixedSqlNode{
92 | 					sqlNodes: []iSqlNode{
93 | 						&textSqlNode{
94 | 							content: "#{ item.A }",
95 | 						},
96 | 					},
97 | 				},
98 | 				item:       "item",
99 | 				open:       "and id in (",
100 | 				close:      ")",
101 | 				separator:  ",",
102 | 				collection: "array2",
103 | 			},
104 | 		},
105 | 	}
106 |
107 | 	ds := dynamicSqlSource{
108 | 		sqlNode: msn,
109 | 	}
110 |
111 | 	bs := ds.getBoundSql(params)
112 |
113 | 	expc := "select 1 from t_gap where age = ? and name = ? and code = 'cctv' and id in ( ? , ? , ? )"
114 | 	assert.Equal(t, expc, bs.sqlStr, "test failed, actual:"+bs.sqlStr)
115 | 	assert.Equal(t, "Sean", bs.params["name"], "test failed, actual:"+fmt.Sprintf("%v", bs.params["name"]))
116 | 	assert.Equal(t, "aa", bs.extParams["_ls_item_p_item0.A"], "test failed, actual:"+fmt.Sprintf("%v", bs.extParams["_ls_item_p_item0.A"]))
117 | 	assert.Equal(t, "bb", bs.extParams["_ls_item_p_item1.A"], "test failed, actual:"+fmt.Sprintf("%v", bs.extParams["_ls_item_p_item1.A"]))
118 | 	assert.Equal(t, "cc", bs.extParams["_ls_item_p_item2.A"], "test failed, actual:"+fmt.Sprintf("%v", bs.extParams["_ls_item_p_item2.A"]))
119 | }
120 |
--------------------------------------------------------------------------------
/util.go:
--------------------------------------------------------------------------------
1 | package gobatis
2 |
3 | import "time"
4 |
5 | // PI64 returns a pointer to a copy of i.
6 | func PI64(i int64) *int64 {
7 | 	return &i
8 | }
9 |
10 | // PS returns a pointer to a copy of s.
11 | func PS(s string) *string {
12 | 	return &s
13 | }
14 |
15 | // PF64 returns a pointer to a copy of f.
16 | func PF64(f float64) *float64 {
17 | 	return &f
18 | }
19 |
20 | // PT returns a pointer to a copy of t.
21 | func PT(t time.Time) *time.Time {
22 | 	return &t
23 | }
24 |
25 | // PB returns a pointer to a copy of b.
26 | func PB(b bool) *bool {
27 | 	return &b
28 | }
29 |
30 | // NI64 wraps i in a NullInt64 marked Valid.
31 | func NI64(i int64) NullInt64 {
32 | 	return NullInt64{Int64: i, Valid: true}
33 | }
34 |
35 | // NS wraps s in a NullString marked Valid.
36 | func NS(s string) NullString {
37 | 	return NullString{String: s, Valid: true}
38 | }
39 |
40 | // NF64 wraps f in a NullFloat64 marked Valid.
41 | func NF64(f float64) NullFloat64 {
42 | 	return NullFloat64{Float64: f, Valid: true}
43 | }
44 |
45 | // NT wraps t in a NullTime marked Valid.
46 | func NT(t time.Time) NullTime {
47 | 	return NullTime{Time: t, Valid: true}
48 | }
49 |
50 | // NB wraps b in a NullBool marked Valid.
51 | func NB(b bool) NullBool {
52 | 	return NullBool{Bool: b, Valid: true}
53 | }
54 |
--------------------------------------------------------------------------------
/util_builder.go:
--------------------------------------------------------------------------------
1 | package gobatis
2 |
3 | import (
4 | 	"io"
5 | 	"strings"
6 |
7 | 	"gopkg.in/yaml.v2"
8 | )
9 |
10 | // buildMapperConfig parses a mapper XML document and indexes its statements.
11 | //
12 | // <select>, <update>, <insert> and <delete> elements are registered with
13 | // conf.put and <sql> fragments with conf.putSql, under "namespace.id" (just
14 | // "id" when the mapper has no namespace). It panics when the root element is
15 | // not <mapper>, when an element lacks an id, or when an id is registered twice.
16 | func buildMapperConfig(r io.Reader) *mapperConfig {
17 | 	rootNode := parse(r)
18 |
19 | 	conf := &mapperConfig{
20 | 		mappedStmts: make(map[string]*node),
21 | 		mappedSql:   make(map[string]*node),
22 | 		cache:       make(map[string]*mappedStmt),
23 | 	}
24 |
25 | 	if rootNode.Name != "mapper" {
26 | 		LOG.Error("Mapper xml must start with `mapper` tag, please check your xml mapperConfig!")
27 | 		panic("Mapper xml must start with `mapper` tag, please check your xml mapperConfig!")
28 | 	}
29 |
30 | 	namespace := ""
31 | 	if val, ok := rootNode.Attrs["namespace"]; ok {
32 | 		if nStr := strings.TrimSpace(val.Value); nStr != "" {
33 | 			namespace = nStr + "."
34 | 		}
35 | 	}
36 |
37 | 	for _, elem := range rootNode.Elements {
38 | 		if elem.ElementType != eleTpNode {
39 | 			continue
40 | 		}
41 |
42 | 		childNode := elem.Val.(node)
43 | 		isSql := false
44 | 		switch childNode.Name {
45 | 		case "select", "update", "insert", "delete":
46 | 		case "sql":
47 | 			isSql = true
48 | 		default:
49 | 			continue
50 | 		}
51 |
52 | 		if childNode.Id == "" {
53 | 			// The old message glued name and text together ("No id for:selectId must...").
54 | 			msg := "No id for: " + childNode.Name + ", id must not be empty, please check your xml mapperConfig!"
55 | 			LOG.Error("%s", msg)
56 | 			panic(msg)
57 | 		}
58 |
59 | 		fid := namespace + childNode.Id
60 | 		var ok bool
61 | 		if isSql {
62 | 			ok = conf.putSql(fid, &childNode)
63 | 		} else {
64 | 			ok = conf.put(fid, &childNode)
65 | 		}
66 | 		if !ok {
67 | 			msg := "Repeat id for: " + fid + ", please check your xml mapperConfig!"
68 | 			LOG.Error("%s", msg)
69 | 			panic(msg)
70 | 		}
71 | 	}
72 |
73 | 	return conf
74 | }
75 |
76 | // buildDbConfig decodes the YAML document ymlStr into a DBConfig and panics
77 | // when it cannot be decoded.
78 | func buildDbConfig(ymlStr string) *DBConfig {
79 | 	dbconf := &DBConfig{}
80 | 	// Decode into the struct itself. Passing &dbconf (a **DBConfig) let a
81 | 	// YAML null document reset the pointer, returning a nil config.
82 | 	if err := yaml.Unmarshal([]byte(ymlStr), dbconf); err != nil {
83 | 		panic("error: " + err.Error())
84 | 	}
85 |
86 | 	return dbconf
87 | }
88 |
--------------------------------------------------------------------------------
/val.go:
--------------------------------------------------------------------------------
1 | package gobatis
2 |
3 | import (
4 | 	"bytes"
5 | 	"encoding/binary"
6 | 	"errors"
7 | 	"fmt"
8 | 	"reflect"
9 | 	"strconv"
10 | 	"time"
11 | )
12 |
13 | // stringToVal converts a textual database value (data must be a string) to
14 | // the basic Go type of tp's kind; see textToVal for the conversion rules.
15 | func stringToVal(data interface{}, tp reflect.Type) interface{} {
16 | 	str := data.(string)
17 | 	return textToVal(str, []byte(str), data, tp)
18 | }
19 |
20 | // bytesToVal is stringToVal for []byte values. Kinds without a conversion
21 | // return data itself, still as []byte.
22 | func bytesToVal(data interface{}, tp reflect.Type) interface{} {
23 | 	raw := data.([]uint8)
24 | 	return textToVal(string(raw), raw, data, tp)
25 | }
26 |
27 | // textToVal converts str to the basic Go type of tp's kind (int8 for an
28 | // int8 kind, uint for a uint kind, ...), returning data unchanged for kinds
29 | // it does not handle (string, struct, ...).
30 | //
31 | // Parse errors are deliberately ignored and yield 0, as before. Integers are
32 | // parsed as base 10 and then truncated to the target width. A bool is true
33 | // for the values strconv.ParseBool accepts as true ("1", "t", "true", ...).
34 | // Complex kinds are decoded as big-endian binary from raw, the value's bytes.
35 | func textToVal(str string, raw []byte, data interface{}, tp reflect.Type) interface{} {
36 | 	switch tp.Kind() {
37 | 	case reflect.Bool:
38 | 		// Only "1" used to count as true, so "true"/"t" read back as false.
39 | 		b, err := strconv.ParseBool(str)
40 | 		return err == nil && b
41 | 	case reflect.Int:
42 | 		i, _ := strconv.ParseInt(str, 10, 64)
43 | 		return int(i)
44 | 	case reflect.Int8:
45 | 		i, _ := strconv.ParseInt(str, 10, 64)
46 | 		return int8(i)
47 | 	case reflect.Int16:
48 | 		i, _ := strconv.ParseInt(str, 10, 64)
49 | 		return int16(i)
50 | 	case reflect.Int32:
51 | 		i, _ := strconv.ParseInt(str, 10, 64)
52 | 		return int32(i)
53 | 	case reflect.Int64:
54 | 		i, _ := strconv.ParseInt(str, 10, 64)
55 | 		return i
56 | 	case reflect.Uint:
57 | 		// Unsigned kinds are parsed as base 10 like the signed ones (base 0
58 | 		// read "010" as octal 8); Uint also used to yield an int32, which a
59 | 		// uint field cannot hold.
60 | 		u, _ := strconv.ParseUint(str, 10, 64)
61 | 		return uint(u)
62 | 	case reflect.Uint8:
63 | 		u, _ := strconv.ParseUint(str, 10, 64)
64 | 		return uint8(u)
65 | 	case reflect.Uint16:
66 | 		u, _ := strconv.ParseUint(str, 10, 64)
67 | 		return uint16(u)
68 | 	case reflect.Uint32:
69 | 		u, _ := strconv.ParseUint(str, 10, 64)
70 | 		return uint32(u)
71 | 	case reflect.Uint64:
72 | 		u, _ := strconv.ParseUint(str, 10, 64)
73 | 		return u
74 | 	case reflect.Uintptr:
75 | 		u, _ := strconv.ParseUint(str, 10, 64)
76 | 		return uintptr(u)
77 | 	case reflect.Float32:
78 | 		f, _ := strconv.ParseFloat(str, 64)
79 | 		return float32(f)
80 | 	case reflect.Float64:
81 | 		f, _ := strconv.ParseFloat(str, 64)
82 | 		return f
83 | 	case reflect.Complex64:
84 | 		// stringToVal used to assert data.([]uint8) here and panic on its
85 | 		// string input.
86 | 		var c complex64
87 | 		_ = binary.Read(bytes.NewReader(raw), binary.BigEndian, &c)
88 | 		return c
89 | 	case reflect.Complex128:
90 | 		var c complex128
91 | 		_ = binary.Read(bytes.NewReader(raw), binary.BigEndian, &c)
92 | 		return c
93 | 	}
94 | 	return data
95 | }
96 |
97 | // valToInt64 converts bools (1/0), integers, floats (truncated toward zero)
98 | // and decimal strings to int64; named types of those kinds are accepted too.
99 | // Anything else, including nil, yields an "unsupported type" error.
100 | func valToInt64(data interface{}) (int64, error) {
101 | 	// Kind-based reflection instead of data.(int)-style assertions, which
102 | 	// panicked on named types such as `type Status int` and on nil.
103 | 	v := reflect.ValueOf(data)
104 | 	switch v.Kind() {
105 | 	case reflect.Bool:
106 | 		if v.Bool() {
107 | 			return 1, nil
108 | 		}
109 | 		return 0, nil
110 | 	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
111 | 		return v.Int(), nil
112 | 	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
113 | 		return int64(v.Uint()), nil
114 | 	case reflect.Float32, reflect.Float64:
115 | 		return int64(v.Float()), nil
116 | 	case reflect.String:
117 | 		return strconv.ParseInt(v.String(), 10, 64)
118 | 	default:
119 | 		return 0, errors.New("unsupported type")
120 | 	}
121 | }
122 |
123 | // valToString formats bools, integers, floats, complex numbers and strings
124 | // (named types of those kinds included) as text. Floats use the shortest
125 | // decimal that round-trips at their own precision. Other kinds, and nil,
126 | // log a warning and yield "".
127 | func valToString(data interface{}) string {
128 | 	v := reflect.ValueOf(data)
129 | 	switch v.Kind() {
130 | 	case reflect.Bool:
131 | 		return strconv.FormatBool(v.Bool())
132 | 	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
133 | 		return strconv.FormatInt(v.Int(), 10)
134 | 	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
135 | 		return strconv.FormatUint(v.Uint(), 10)
136 | 	case reflect.Float32:
137 | 		// bitSize 32: formatting at 64 bits printed float32(0.1) as
138 | 		// "0.10000000149011612".
139 | 		return strconv.FormatFloat(v.Float(), 'f', -1, 32)
140 | 	case reflect.Float64:
141 | 		return strconv.FormatFloat(v.Float(), 'f', -1, 64)
142 | 	case reflect.Complex64:
143 | 		return fmt.Sprint(complex64(v.Complex()))
144 | 	case reflect.Complex128:
145 | 		return fmt.Sprint(v.Complex())
146 | 	case reflect.String:
147 | 		// Strings used to fall through to the warning and come back as "".
148 | 		return v.String()
149 | 	case reflect.Invalid:
150 | 		LOG.Warn("[WARN]no process for type:nil")
151 | 		return ""
152 | 	default:
153 | 		LOG.Warn("[WARN]no process for type:" + v.Type().Name())
154 | 		return ""
155 | 	}
156 | }
157 |
220 | func valUpcast(data interface{}, typeName string) interface{} { 221 | d := data 222 | switch typeName { 223 | case "bool": 224 | 225 | case "int": 226 | switch data.(type) { 227 | case bool: 228 | if data.(bool) { 229 | d = int(1) 230 | } else { 231 | d = int(0) 232 | } 233 | case int8: 234 | d = int(data.(int8)) 235 | case int16: 236 | d = int(data.(int16)) 237 | case int32: 238 | d = int(data.(int32)) 239 | case int64: 240 | d = int(data.(int64)) 241 | case uint: 242 | d = int(data.(uint)) 243 | case uint8: 244 | d = int(data.(uint8)) 245 | case uint16: 246 | d = int(data.(uint16)) 247 | case uint32: 248 | d = int(data.(uint32)) 249 | case uint64: 250 | d = int(data.(uint64)) 251 | case uintptr: 252 | d = int(data.(uintptr)) 253 | case float32: 254 | d = int(data.(float32)) 255 | case float64: 256 | d = int(data.(float64)) 257 | } 258 | case "int8": 259 | switch data.(type) { 260 | case bool: 261 | if data.(bool) { 262 | d = int8(1) 263 | } else { 264 | d = int8(0) 265 | } 266 | case int: 267 | d = int8(data.(int)) 268 | case
int16: 269 | d = int8(data.(int16)) 270 | case int32: 271 | d = int8(data.(int32)) 272 | case int64: 273 | d = int8(data.(int64)) 274 | case uint: 275 | d = int8(data.(uint)) 276 | case uint8: 277 | d = int8(data.(uint8)) 278 | case uint16: 279 | d = int8(data.(uint16)) 280 | case uint32: 281 | d = int8(data.(uint32)) 282 | case uint64: 283 | d = int8(data.(uint64)) 284 | case uintptr: 285 | d = int8(data.(uintptr)) 286 | case float32: 287 | d = int8(data.(float32)) 288 | case float64: 289 | d = int8(data.(float64)) 290 | } 291 | case "int16": 292 | switch data.(type) { 293 | case bool: 294 | if data.(bool) { 295 | d = int16(1) 296 | } else { 297 | d = int16(0) 298 | } 299 | case int: 300 | d = int16(data.(int)) 301 | case int8: 302 | d = int16(data.(int8)) 303 | case int32: 304 | d = int16(data.(int32)) 305 | case int64: 306 | d = int16(data.(int64)) 307 | case uint: 308 | d = int16(data.(uint)) 309 | case uint8: 310 | d = int16(data.(uint8)) 311 | case uint16: 312 | d = int16(data.(uint16)) 313 | case uint32: 314 | d = int16(data.(uint32)) 315 | case uint64: 316 | d = int16(data.(uint64)) 317 | case uintptr: 318 | d = int16(data.(uintptr)) 319 | case float32: 320 | d = int16(data.(float32)) 321 | case float64: 322 | d = int16(data.(float64)) 323 | } 324 | case "int32": 325 | switch data.(type) { 326 | case bool: 327 | if data.(bool) { 328 | d = int32(1) 329 | } else { 330 | d = int32(0) 331 | } 332 | case int: 333 | d = int32(data.(int)) 334 | case int8: 335 | d = int32(data.(int8)) 336 | case int16: 337 | d = int32(data.(int16)) 338 | case int64: 339 | d = int32(data.(int64)) 340 | case uint: 341 | d = int32(data.(uint)) 342 | case uint8: 343 | d = int32(data.(uint8)) 344 | case uint16: 345 | d = int32(data.(uint16)) 346 | case uint32: 347 | d = int32(data.(uint32)) 348 | case uint64: 349 | d = int32(data.(uint64)) 350 | case uintptr: 351 | d = int32(data.(uintptr)) 352 | case float32: 353 | d = int32(data.(float32)) 354 | case float64: 355 | d = 
int32(data.(float64)) 356 | } 357 | case "int64": 358 | switch data.(type) { 359 | case bool: 360 | if data.(bool) { 361 | d = int64(1) 362 | } else { 363 | d = int64(0) 364 | } 365 | case int: 366 | d = int64(data.(int)) 367 | case int8: 368 | d = int64(data.(int8)) 369 | case int16: 370 | d = int64(data.(int16)) 371 | case int32: 372 | d = int64(data.(int32)) 373 | case uint: 374 | d = int64(data.(uint)) 375 | case uint8: 376 | d = int64(data.(uint8)) 377 | case uint16: 378 | d = int64(data.(uint16)) 379 | case uint32: 380 | d = int64(data.(uint32)) 381 | case uint64: 382 | d = int64(data.(uint64)) 383 | case uintptr: 384 | d = int64(data.(uintptr)) 385 | case float32: 386 | d = int64(data.(float32)) 387 | case float64: 388 | d = int64(data.(float64)) 389 | } 390 | case "uint": 391 | switch data.(type) { 392 | case bool: 393 | if data.(bool) { 394 | d = uint(1) 395 | } else { 396 | d = uint(0) 397 | } 398 | case int: 399 | d = uint(data.(int)) 400 | case int8: 401 | d = uint(data.(int8)) 402 | case int16: 403 | d = uint(data.(int16)) 404 | case int32: 405 | d = uint(data.(int32)) 406 | case int64: 407 | d = uint(data.(int64)) 408 | case uint8: 409 | d = uint(data.(uint8)) 410 | case uint16: 411 | d = uint(data.(uint16)) 412 | case uint32: 413 | d = uint(data.(uint32)) 414 | case uint64: 415 | d = uint(data.(uint64)) 416 | case uintptr: 417 | d = uint(data.(uintptr)) 418 | case float32: 419 | d = uint(data.(float32)) 420 | case float64: 421 | d = uint(data.(float64)) 422 | } 423 | case "uint8": 424 | switch data.(type) { 425 | case bool: 426 | if data.(bool) { 427 | d = uint8(1) 428 | } else { 429 | d = uint8(0) 430 | } 431 | case int: 432 | d = uint8(data.(int)) 433 | case int8: 434 | d = uint8(data.(int8)) 435 | case int16: 436 | d = uint8(data.(int16)) 437 | case int32: 438 | d = uint8(data.(int32)) 439 | case int64: 440 | d = uint8(data.(int64)) 441 | case uint: 442 | d = uint8(data.(uint)) 443 | case uint16: 444 | d = uint8(data.(uint16)) 445 | case uint32: 
446 | d = uint8(data.(uint32)) 447 | case uint64: 448 | d = uint8(data.(uint64)) 449 | case uintptr: 450 | d = uint8(data.(uintptr)) 451 | case float32: 452 | d = uint8(data.(float32)) 453 | case float64: 454 | d = uint8(data.(float64)) 455 | } 456 | case "uint16": 457 | switch data.(type) { 458 | case bool: 459 | if data.(bool) { 460 | d = uint16(1) 461 | } else { 462 | d = uint16(0) 463 | } 464 | case int: 465 | d = uint16(data.(int)) 466 | case int8: 467 | d = uint16(data.(int8)) 468 | case int16: 469 | d = uint16(data.(int16)) 470 | case int32: 471 | d = uint16(data.(int32)) 472 | case int64: 473 | d = uint16(data.(int64)) 474 | case uint: 475 | d = uint16(data.(uint)) 476 | case uint8: 477 | d = uint16(data.(uint8)) 478 | case uint32: 479 | d = uint16(data.(uint32)) 480 | case uint64: 481 | d = uint16(data.(uint64)) 482 | case uintptr: 483 | d = uint16(data.(uintptr)) 484 | case float32: 485 | d = uint16(data.(float32)) 486 | case float64: 487 | d = uint16(data.(float64)) 488 | } 489 | case "uint32": 490 | switch data.(type) { 491 | case bool: 492 | if data.(bool) { 493 | d = uint32(1) 494 | } else { 495 | d = uint32(0) 496 | } 497 | case int: 498 | d = uint32(data.(int)) 499 | case int8: 500 | d = uint32(data.(int8)) 501 | case int16: 502 | d = uint32(data.(int16)) 503 | case int32: 504 | d = uint32(data.(int32)) 505 | case int64: 506 | d = uint32(data.(int64)) 507 | case uint: 508 | d = uint32(data.(uint)) 509 | case uint8: 510 | d = uint32(data.(uint8)) 511 | case uint16: 512 | d = uint32(data.(uint16)) 513 | case uint64: 514 | d = uint32(data.(uint64)) 515 | case uintptr: 516 | d = uint32(data.(uintptr)) 517 | case float32: 518 | d = uint32(data.(float32)) 519 | case float64: 520 | d = uint32(data.(float64)) 521 | } 522 | case "uint64": 523 | switch data.(type) { 524 | case bool: 525 | if data.(bool) { 526 | d = uint64(1) 527 | } else { 528 | d = uint64(0) 529 | } 530 | case int: 531 | d = uint64(data.(int)) 532 | case int8: 533 | d = uint64(data.(int8)) 
534 | case int16: 535 | d = uint64(data.(int16)) 536 | case int32: 537 | d = uint64(data.(int32)) 538 | case int64: 539 | d = uint64(data.(int64)) 540 | case uint: 541 | d = uint64(data.(uint)) 542 | case uint8: 543 | d = uint64(data.(uint8)) 544 | case uint16: 545 | d = uint64(data.(uint16)) 546 | case uint32: 547 | d = uint64(data.(uint32)) 548 | case uintptr: 549 | d = uint64(data.(uintptr)) 550 | case float32: 551 | d = uint64(data.(float32)) 552 | case float64: 553 | d = uint64(data.(float64)) 554 | } 555 | case "uintptr": 556 | switch data.(type) { 557 | case bool: 558 | if data.(bool) { 559 | d = uintptr(1) 560 | } else { 561 | d = uintptr(0) 562 | } 563 | case int: 564 | d = uintptr(data.(int)) 565 | case int8: 566 | d = uintptr(data.(int8)) 567 | case int16: 568 | d = uintptr(data.(int16)) 569 | case int32: 570 | d = uintptr(data.(int32)) 571 | case int64: 572 | d = uintptr(data.(int64)) 573 | case uint: 574 | d = uintptr(data.(uint)) 575 | case uint8: 576 | d = uintptr(data.(uint8)) 577 | case uint16: 578 | d = uintptr(data.(uint16)) 579 | case uint32: 580 | d = uintptr(data.(uint32)) 581 | case uint64: 582 | d = uintptr(data.(uint64)) 583 | case float32: 584 | d = uintptr(data.(float32)) 585 | case float64: 586 | d = uintptr(data.(float64)) 587 | } 588 | case "float32": 589 | switch data.(type) { 590 | case bool: 591 | if data.(bool) { 592 | d = float32(1) 593 | } else { 594 | d = float32(0) 595 | } 596 | case int: 597 | d = float32(data.(int)) 598 | case int8: 599 | d = float32(data.(int8)) 600 | case int16: 601 | d = float32(data.(int16)) 602 | case int32: 603 | d = float32(data.(int32)) 604 | case int64: 605 | d = float32(data.(int64)) 606 | case uint: 607 | d = float32(data.(uint)) 608 | case uint8: 609 | d = float32(data.(uint8)) 610 | case uint16: 611 | d = float32(data.(uint16)) 612 | case uint32: 613 | d = float32(data.(uint32)) 614 | case uint64: 615 | d = float32(data.(uint64)) 616 | case uintptr: 617 | d = float32(data.(uintptr)) 618 | case 
float64: 619 | d = float32(data.(float64)) 620 | } 621 | case "float64": 622 | switch data.(type) { 623 | case bool: 624 | if data.(bool) { 625 | d = float64(1) 626 | } else { 627 | d = float64(0) 628 | } 629 | case int: 630 | d = float64(data.(int)) 631 | case int8: 632 | d = float64(data.(int8)) 633 | case int16: 634 | d = float64(data.(int16)) 635 | case int32: 636 | d = float64(data.(int32)) 637 | case int64: 638 | d = float64(data.(int64)) 639 | case uint: 640 | d = float64(data.(uint)) 641 | case uint8: 642 | d = float64(data.(uint8)) 643 | case uint16: 644 | d = float64(data.(uint16)) 645 | case uint32: 646 | d = float64(data.(uint32)) 647 | case uint64: 648 | d = float64(data.(uint64)) 649 | case uintptr: 650 | d = float64(data.(uintptr)) 651 | case float32: 652 | d = float64(data.(float32)) 653 | } 654 | case "complex64": 655 | case "complex128": 656 | 657 | } 658 | 659 | return d 660 | } 661 | // dataToPtr wraps data in a pointer matching the field type named tp.Name(): builtin scalars, string and Time are asserted to that exact type and a pointer to a copy is returned; Null* wrappers are parsed from []byte/string/valToString input into a pointer to a valid wrapper. Other type names, and nil data for Null* types, are returned unchanged; a recovered panic (failed assertion or parse) is logged and yields nil. 662 | func dataToPtr(data interface{}, tp reflect.Type, fieldName string) interface{} { 663 | defer func() { 664 | if err := recover(); nil != err { 665 | LOG.Warn("[WARN] data to field val panic, fieldName:", fieldName, " err:", err) 666 | } 667 | }() 668 | 669 | typeName := tp.Name() 670 | switch { 671 | case typeName == "bool": 672 | d := data.(bool) 673 | data = &d 674 | case typeName == "int": 675 | d := data.(int) 676 | data = &d 677 | case typeName == "int8": 678 | d := data.(int8) 679 | data = &d 680 | case typeName == "int16": 681 | d := data.(int16) 682 | data = &d 683 | case typeName == "int32": 684 | d := data.(int32) 685 | data = &d 686 | case typeName == "int64": 687 | d := data.(int64) 688 | data = &d 689 | case typeName == "uint": 690 | d := data.(uint) 691 | data = &d 692 | case typeName == "uint8": 693 | d := data.(uint8) 694 | data = &d 695 | case typeName == "uint16": 696 | d := data.(uint16) 697 | data = &d 698 | case typeName == "uint32": 699 | d := data.(uint32) 700 | data = &d 701 | case typeName == "uint64": 702 | d := data.(uint64) 703 |
data = &d 704 | case typeName == "uintptr": 705 | d := data.(uintptr) 706 | data = &d 707 | case typeName == "float32": 708 | d := data.(float32) 709 | data = &d 710 | case typeName == "float64": 711 | d := data.(float64) 712 | data = &d 713 | case typeName == "complex64": 714 | d := data.(complex64) 715 | data = &d 716 | case typeName == "complex128": 717 | d := data.(complex128) 718 | data = &d 719 | case typeName == "string": 720 | d := data.(string) 721 | data = &d 722 | case typeName == "Time": 723 | d := data.(time.Time) 724 | data = &d 725 | case typeName == "NullString": 726 | if nil != data { 727 | if reflect.TypeOf(data).Kind() == reflect.Slice || 728 | reflect.TypeOf(data).Kind() == reflect.Array { 729 | data = string(data.([]byte)) 730 | } else { 731 | data = valToString(data) 732 | } 733 | data = &NullString{String: data.(string), Valid: true} 734 | } 735 | case typeName == "NullInt64": 736 | if nil != data { 737 | if reflect.TypeOf(data).Kind() == reflect.Slice || 738 | reflect.TypeOf(data).Kind() == reflect.Array { 739 | data = string(data.([]byte)) 740 | } else { 741 | data = valToString(data) 742 | } 743 | 744 | i, err := strconv.ParseInt(data.(string), 10, 64) 745 | if err != nil { 746 | panic("ParseInt err:" + err.Error()) 747 | } 748 | data = &NullInt64{Int64: i, Valid: true} 749 | } 750 | case typeName == "NullBool": 751 | if nil != data { 752 | if reflect.TypeOf(data).Kind() == reflect.Slice || 753 | reflect.TypeOf(data).Kind() == reflect.Array { 754 | data = string(data.([]byte)) 755 | } else { 756 | data = valToString(data) 757 | } 758 | // ParseBool also accepts "1", "t", "TRUE" etc.; unparseable input deliberately maps to false 759 | b, _ := strconv.ParseBool(data.(string)) 760 | // always a pointer, consistent with the other Null* cases 761 | data = &NullBool{Bool: b, Valid: true} 762 | } 763 | case typeName == "NullFloat64": 764 | if nil != data { 765 | if reflect.TypeOf(data).Kind() == reflect.Slice || 766 | reflect.TypeOf(data).Kind() == reflect.Array { 767 | data = string(data.([]byte)) 768 | } else { 769 | data = valToString(data) 770 | } 771 | 772
| f64, err := strconv.ParseFloat(data.(string), 64) 773 | if err != nil { 774 | panic("ParseFloat err:" + err.Error()) 775 | } 776 | 777 | data = &NullFloat64{Float64: f64, Valid: true} 778 | } 779 | case typeName == "NullTime": 780 | if nil != data { 781 | var t time.Time 782 | dt, ok := data.(time.Time) 783 | if !ok { 784 | if reflect.TypeOf(data).Kind() == reflect.Slice || 785 | reflect.TypeOf(data).Kind() == reflect.Array { 786 | data = string(data.([]byte)) 787 | } else { 788 | data = valToString(data) 789 | } 790 | 791 | tt, err := time.Parse("2006-01-02 15:04:05", data.(string)) 792 | if err != nil { 793 | panic("time.Parse err:" + err.Error()) 794 | } 795 | 796 | t = tt 797 | } else { 798 | t = dt 799 | } 800 | 801 | data = &NullTime{Time: t, Valid: true} 802 | } 803 | } 804 | 805 | return data 806 | } 807 | // dataToFieldVal converts data into a value for a field whose type name is tp.Name(). It returns nil when data is nil, when the type name is not handled, or when a panic (failed assertion or parse) is recovered and logged. 808 | func dataToFieldVal(data interface{}, tp reflect.Type, fieldName string) interface{} { 809 | defer func() { 810 | if err := recover(); nil != err { 811 | LOG.Warn("[WARN] data to field val panic, fieldName:", fieldName, " err:", err) 812 | } 813 | }() 814 | 815 | typeName := tp.Name() 816 | switch { 817 | case typeName == "bool" || 818 | typeName == "int" || 819 | typeName == "int8" || 820 | typeName == "int16" || 821 | typeName == "int32" || 822 | typeName == "int64" || 823 | typeName == "uint" || 824 | typeName == "uint8" || 825 | typeName == "uint16" || 826 | typeName == "uint32" || 827 | typeName == "uint64" || 828 | typeName == "uintptr" || 829 | typeName == "float32" || 830 | typeName == "float64" || 831 | typeName == "complex64" || 832 | typeName == "complex128": 833 | if nil != data { 834 | dataTp := reflect.TypeOf(data) 835 | if dataTp.Kind() == reflect.Slice || 836 | dataTp.Kind() == reflect.Array { 837 | data = bytesToVal(data, tp) 838 | } 839 | 840 | dataTp = reflect.TypeOf(data) 841 | if dataTp.Kind() == reflect.String { 842 | data = stringToVal(data, tp) 843 | } 844 | 845 | data = valUpcast(data, typeName) 846 | 847 | return data 848
| } 849 | case typeName == "string": 850 | if nil != data { 851 | if reflect.TypeOf(data).Kind() == reflect.Slice || 852 | reflect.TypeOf(data).Kind() == reflect.Array { 853 | return string(data.([]byte)) 854 | } 855 | 856 | data = valToString(data) 857 | return data.(string) 858 | } 859 | case typeName == "Time": 860 | if nil != data { 861 | if reflect.TypeOf(data).Kind() == reflect.Slice || 862 | reflect.TypeOf(data).Kind() == reflect.Array { 863 | data = string(data.([]byte)) 864 | } else { 865 | data = valToString(data) 866 | } 867 | 868 | tm, err := time.Parse("2006-01-02 15:04:05", data.(string)) 869 | if err != nil { 870 | panic("time.Parse err:" + err.Error()) 871 | } 872 | return tm 873 | } 874 | case typeName == "NullString": 875 | if nil != data { 876 | if reflect.TypeOf(data).Kind() == reflect.Slice || 877 | reflect.TypeOf(data).Kind() == reflect.Array { 878 | data = string(data.([]byte)) 879 | } else { 880 | data = valToString(data) 881 | } 882 | return NullString{String: data.(string), Valid: true} 883 | } 884 | case typeName == "NullInt64": 885 | if nil != data { 886 | if reflect.TypeOf(data).Kind() == reflect.Slice || 887 | reflect.TypeOf(data).Kind() == reflect.Array { 888 | data = string(data.([]byte)) 889 | } else { 890 | data = valToString(data) 891 | } 892 | 893 | i, err := strconv.ParseInt(data.(string), 10, 64) 894 | if err != nil { 895 | panic("ParseInt err:" + err.Error()) 896 | } 897 | return NullInt64{Int64: i, Valid: true} 898 | } 899 | case typeName == "NullBool": 900 | if nil != data { 901 | if reflect.TypeOf(data).Kind() == reflect.Slice || 902 | reflect.TypeOf(data).Kind() == reflect.Array { 903 | data = string(data.([]byte)) 904 | } else { 905 | data = valToString(data) 906 | } 907 | // ParseBool also accepts "1", "t", "TRUE" etc.; unparseable input deliberately maps to false 908 | b, _ := strconv.ParseBool(data.(string)) 909 | return NullBool{Bool: b, Valid: true} 910 | 911 | } 912 | case typeName == "NullFloat64": 913 | if nil != data { 914 | if reflect.TypeOf(data).Kind() ==
reflect.Slice || 915 | reflect.TypeOf(data).Kind() == reflect.Array { 916 | data = string(data.([]byte)) 917 | } else { 918 | data = valToString(data) 919 | } 920 | 921 | f64, err := strconv.ParseFloat(data.(string), 64) 922 | if err != nil { 923 | panic("ParseFloat err:" + err.Error()) 924 | } 925 | 926 | return NullFloat64{Float64: f64, Valid: true} 927 | } 928 | case typeName == "NullTime": 929 | if nil != data { 930 | var t time.Time 931 | dt, ok := data.(time.Time) 932 | if !ok { 933 | if reflect.TypeOf(data).Kind() == reflect.Slice || 934 | reflect.TypeOf(data).Kind() == reflect.Array { 935 | data = string(data.([]byte)) 936 | } else { 937 | data = valToString(data) 938 | } 939 | 940 | tt, err := time.Parse("2006-01-02 15:04:05", data.(string)) 941 | if err != nil { 942 | panic("time.Parse err:" + err.Error()) 943 | } 944 | 945 | t = tt 946 | } else { 947 | t = dt 948 | } 949 | 950 | return NullTime{Time: t, Valid: true} 951 | } 952 | } 953 | 954 | return nil 955 | } 956 | -------------------------------------------------------------------------------- /xmltag.go: -------------------------------------------------------------------------------- 1 | package gobatis 2 | 3 | import ( 4 | "fmt" 5 | "reflect" 6 | "regexp" 7 | "strings" 8 | ) 9 | // dynamicContext accumulates the SQL text produced by sql nodes together with the params they read and write. 10 | type dynamicContext struct { 11 | sqlStr string 12 | params map[string]interface{} 13 | } 14 | // newDynamicContext returns a context over params with an empty SQL buffer. 15 | func newDynamicContext(params map[string]interface{}) *dynamicContext { 16 | return &dynamicContext{ 17 | params: params, 18 | } 19 | } 20 | // appendSql appends sqlStr followed by a single space separator. 21 | func (d *dynamicContext) appendSql(sqlStr string) { 22 | d.sqlStr += sqlStr + " " 23 | } 24 | // toSql returns the accumulated SQL with leading and trailing whitespace removed. 25 | func (d *dynamicContext) toSql() string { 26 | return strings.TrimSpace(d.sqlStr) 27 | } 28 | 29 | // [ref](http://www.mybatis.org/mybatis-3/dynamic-sql.html) 30 | type iSqlNode interface { 31 | build(ctx *dynamicContext) bool 32 | } 33 | 34 | // mixed node 35 | type mixedSqlNode struct { 36 | sqlNodes []iSqlNode 37 | } 38 | 39 | var _ iSqlNode = &mixedSqlNode{} 40 | // build renders every child node in order and always reports true. 41 | func (m
*mixedSqlNode) build(ctx *dynamicContext) bool { 42 | for i := 0; i < len(m.sqlNodes); i++ { 43 | sqlNode := m.sqlNodes[i] 44 | sqlNode.build(ctx) 45 | } 46 | 47 | return true 48 | } 49 | 50 | // if node 51 | type ifSqlNode struct { 52 | test string 53 | sqlNode iSqlNode 54 | } 55 | 56 | var _ iSqlNode = &ifSqlNode{} 57 | // build emits the child node only when the test expression evaluates to true, and reports whether it did. 58 | func (i *ifSqlNode) build(ctx *dynamicContext) bool { 59 | if ok := eval(i.test, ctx.params); ok { 60 | i.sqlNode.build(ctx) 61 | return true 62 | } 63 | 64 | return false 65 | } 66 | 67 | // text node 68 | type textSqlNode struct { 69 | content string 70 | } 71 | 72 | var _ iSqlNode = &textSqlNode{} 73 | // build appends the literal content and always reports true. 74 | func (t *textSqlNode) build(ctx *dynamicContext) bool { 75 | ctx.appendSql(t.content) 76 | return true 77 | } 78 | 79 | // for node 80 | const listItemPrefix = "_ls_item_p_" 81 | 82 | type foreachSqlNode struct { 83 | sqlNode iSqlNode 84 | collection string 85 | open string 86 | close string 87 | separator string 88 | item string 89 | index string 90 | } 91 | 92 | var _ iSqlNode = &foreachSqlNode{} 93 | // build expands sqlNode once per element of params[f.collection], wrapped in open/close and joined by separator. It returns false without emitting anything when the collection is missing or is not a slice/array, and returns false when an element is itself a slice/array. 94 | func (f *foreachSqlNode) build(ctx *dynamicContext) bool { 95 | collection, ok := ctx.params[f.collection] 96 | if !ok { 97 | LOG.Warn("No collection for foreach tag:" + f.collection) 98 | return false 99 | } 100 | 101 | val := reflect.ValueOf(collection) 102 | if val.Kind() != reflect.Slice && val.Kind() != reflect.Array { 103 | LOG.Info("Foreach tag collection must be slice or array") 104 | return false 105 | } 106 | 107 | // emit open only after validation so a rejected collection leaves no dangling open text 108 | ctx.appendSql(f.open) 109 | 110 | for i := 0; i < val.Len(); i++ { 111 | v := val.Index(i) 112 | if v.Kind() == reflect.Ptr { 113 | v = v.Elem() 114 | } 115 | 116 | // convert struct map val to params 117 | keys := make([]string, 0) 118 | params := make(map[string]interface{}) 119 | switch v.Kind() { 120 | case reflect.Array, reflect.Slice: 121 | LOG.Info("Foreach tag collection element must not be slice or array") 122 | return false 123 | case reflect.Struct: 124 | m := f.structToMap(v.Interface()) 125 |
for k, v := range m { 126 | key := f.item + "." + k 127 | keys = append(keys, key) 128 | params[key] = v 129 | } 130 | case reflect.Map: 131 | iter := v.MapRange() // any map type, not only map[string]interface{}; keys are formatted with fmt.Sprint 132 | for iter.Next() { 133 | key := f.item + "." + fmt.Sprint(iter.Key().Interface()) 134 | keys = append(keys, key) 135 | params[key] = iter.Value().Interface() 136 | } 137 | default: 138 | // scalar element: bound only under the bare item name below 139 | } 140 | 141 | keys = append(keys, f.item) // the bare item is iteration-local for every element kind; don't leak it into ctx.params 142 | params[f.item] = v.Interface() 143 | 144 | tempCtx := &dynamicContext{ 145 | params: params, 146 | } 147 | 148 | f.sqlNode.build(tempCtx) 149 | f.tokenHandler(tempCtx, i) 150 | 151 | if i != 0 { 152 | ctx.appendSql(f.separator) 153 | } 154 | 155 | ctx.appendSql(strings.TrimSpace(tempCtx.sqlStr)) // the fragment already ends with a separator space; trim so it is not doubled 156 | 157 | // del temp param 158 | for _, k := range keys { 159 | delete(tempCtx.params, k) 160 | } 161 | 162 | // sync tempCtx params to ctx 163 | for k, v := range tempCtx.params { 164 | ctx.params[k] = v 165 | } 166 | } 167 | ctx.appendSql(f.close) 168 | 169 | return true 170 | } 171 | // tokenHandler rewrites every #{...} placeholder in ctx.sqlStr whose expression starts with the item name to use the per-iteration name listItemPrefix+f.item+index (e.g. "item.A" -> "_ls_item_p_item0.A"), and copies the param bound to the original expression under the new name. An unclosed placeholder is logged. 172 | func (f *foreachSqlNode) tokenHandler(ctx *dynamicContext, index int) { 173 | sqlStr := ctx.sqlStr 174 | re := regexp.MustCompile("^\\s*" + regexp.QuoteMeta(f.item) + "\\b\\s*") // compiled once per call; \b keeps "item" from matching "itemName" 175 | finalSqlStr := "" 176 | itemStr := "" 177 | start := 0 178 | for i := 0; i < len(sqlStr); i++ { 179 | if start > 0 { 180 | itemStr += string(sqlStr[i]) 181 | } 182 | 183 | if i != 0 && i < len(sqlStr) { 184 | if string([]byte{sqlStr[i-1], sqlStr[i]}) == "#{" { 185 | start = i 186 | } 187 | } 188 | 189 | if start != 0 && i < len(sqlStr)-1 && sqlStr[i+1] == '}' { 190 | finalSqlStr += sqlStr[:start+1] 191 | sqlStr = sqlStr[i+2:] 192 | 193 | // replace a leading item reference with the per-iteration alias 194 | itemPrefix := listItemPrefix + f.item + fmt.Sprintf("%d", index) 195 | s := re.ReplaceAllLiteralString(itemStr, itemPrefix) 196 | s = strings.TrimSpace(s) 197 | if strings.Contains(s, itemPrefix) { 198 | itemKey := strings.TrimSpace(itemStr) 199 | if v, ok := ctx.params[itemKey]; ok { 200 | ctx.params[s] = v 201 | } 202 | } 203 | 204 | finalSqlStr += s + "}" 205 | 206 | i = 0 207 | start = 0 208 |
itemStr = "" 209 | } 210 | } 211 | 212 | if start != 0 { 213 | LOG.Warn("WARN: token not close, SqlStr:" + ctx.sqlStr + " At:" + fmt.Sprintf("%d", start)) 214 | } 215 | 216 | finalSqlStr += sqlStr 217 | ctx.sqlStr = finalSqlStr 218 | } 219 | // structToMap delegates to the package-level structToMap helper. 220 | func (f *foreachSqlNode) structToMap(s interface{}) map[string]interface{} { 221 | return structToMap(s) 222 | } 223 | 224 | // set node 225 | type setSqlNode struct { 226 | sqlNodes []iSqlNode 227 | } 228 | // build joins the non-empty child fragments with " , " (stripping each fragment's trailing comma) and emits them after a set keyword; nothing is emitted when every fragment is empty. 229 | func (s *setSqlNode) build(ctx *dynamicContext) bool { 230 | 231 | sqlStr := "" 232 | for _, sqlNode := range s.sqlNodes { 233 | tempCtx := &dynamicContext{ 234 | params: ctx.params, 235 | } 236 | sqlNode.build(tempCtx) 237 | part := strings.TrimSuffix(strings.TrimSpace(tempCtx.sqlStr), ",") // MyBatis-style "col = #{v}," fragments must not produce ", ," 238 | if sqlStr != "" && part != "" { 239 | sqlStr += " , " 240 | } 241 | sqlStr += part 242 | 243 | for k, v := range tempCtx.params { 244 | ctx.params[k] = v 245 | } 246 | } 247 | 248 | if sqlStr != "" { 249 | ctx.appendSql("set") // appendSql adds the separating space, as for "where" 250 | sqlStr = strings.TrimSpace(sqlStr) 251 | sqlStr = strings.TrimSuffix(sqlStr, ",") 252 | ctx.appendSql(sqlStr) 253 | } 254 | 255 | return true 256 | } 257 | 258 | // trim node 259 | type trimSqlNode struct { 260 | prefix string // prefix: text prepended to the trimmed content 261 | prefixOverrides string // prefixOverrides: removed once from the start of the content 262 | suffixOverrides string // suffixOverrides: removed once from the end of the content 263 | suffix string // suffix: text appended to the trimmed content 264 | sqlNodes []iSqlNode 265 | } 266 | // build renders the children, removes prefixOverrides/suffixOverrides once, then adds prefix/suffix; nothing is emitted when the children produce an empty string. 267 | func (t *trimSqlNode) build(ctx *dynamicContext) bool { 268 | tempCtx := &dynamicContext{ 269 | params: ctx.params, 270 | } 271 | 272 | for _, sqlNode := range t.sqlNodes { 273 | if tempCtx.sqlStr != "" && !strings.HasSuffix(tempCtx.sqlStr, " ") { 274 | tempCtx.sqlStr += " " 275 | } 276 | sqlNode.build(tempCtx) 277 | } 278 | 279 | if tempCtx.sqlStr != "" { 280 | sqlStr := strings.TrimSpace(tempCtx.sqlStr) 281 | 282 | preOv := strings.TrimSpace(t.prefixOverrides) 283 | if preOv != "" { 284 | sqlStr = strings.TrimSpace(strings.TrimPrefix(sqlStr, preOv)) 285 | } 286 | 287 | suffOv := strings.TrimSpace(t.suffixOverrides) 288 | if suffOv != "" { 289 | sqlStr =
strings.TrimSpace(strings.TrimSuffix(sqlStr, suffOv)) 290 | } 291 | 292 | pre := strings.TrimSpace(t.prefix) 293 | if pre != "" { 294 | sqlStr = pre + " " + sqlStr 295 | } 296 | 297 | suff := strings.TrimSpace(t.suffix) 298 | if suff != "" { 299 | sqlStr += " " + suff 300 | } 301 | 302 | ctx.appendSql(sqlStr) 303 | } 304 | 305 | for k, v := range tempCtx.params { 306 | ctx.params[k] = v 307 | } 308 | 309 | return true 310 | } 311 | 312 | // where node 313 | type whereSqlNode struct { 314 | sqlNodes []iSqlNode 315 | } 316 | // build renders the children and, if they produce text, emits "where" followed by that text with one leading AND/OR keyword removed. 317 | func (w *whereSqlNode) build(ctx *dynamicContext) bool { 318 | tempCtx := &dynamicContext{ 319 | params: ctx.params, 320 | } 321 | 322 | for _, sqlNode := range w.sqlNodes { 323 | if tempCtx.sqlStr != "" && !strings.HasSuffix(tempCtx.sqlStr, " ") { 324 | tempCtx.sqlStr += " " 325 | } 326 | sqlNode.build(tempCtx) 327 | } 328 | 329 | if tempCtx.sqlStr != "" { 330 | sqlStr := strings.TrimSpace(tempCtx.sqlStr) 331 | // drop one leading AND/OR (any case) that is followed by whitespace, as MyBatis' where element does 332 | if i := strings.IndexAny(sqlStr, " \t\r\n"); i > 0 && (strings.EqualFold(sqlStr[:i], "and") || strings.EqualFold(sqlStr[:i], "or")) { 333 | sqlStr = strings.TrimSpace(sqlStr[i:]) 334 | } 335 | 336 | ctx.appendSql("where") 337 | ctx.appendSql(sqlStr) 338 | } 339 | 340 | for k, v := range tempCtx.params { 341 | ctx.params[k] = v 342 | } 343 | 344 | return true 345 | } 346 | 347 | // choose node 348 | type chooseNode struct { 349 | sqlNodes []iSqlNode 350 | otherwise iSqlNode 351 | } 352 | // build renders the first child whose build reports true, else the otherwise node if set; it reports whether any branch was rendered. 353 | func (c *chooseNode) build(ctx *dynamicContext) bool { 354 | for _, n := range c.sqlNodes { 355 | if n.build(ctx) { 356 | return true 357 | } 358 | } 359 | if nil != c.otherwise { 360 | c.otherwise.build(ctx) 361 | return true 362 | } 363 | return false 364 | } 365 | 366 | // include 367 | -------------------------------------------------------------------------------- /xmltag_test.go: -------------------------------------------------------------------------------- 1 | package gobatis 2 | 3 | import ( 4 | "fmt" 5 | "github.com/stretchr/testify/assert" 6 | "strings" 7 | "testing" 8 | ) 9 | // s is a sample element type for the struct-based foreach collection in TestMixedSqlNode_build. 10 | type s
struct { 11 | A string 12 | B string 13 | } 14 | // TestTextSqlNode_build checks that a text node emits its content verbatim. 15 | func TestTextSqlNode_build(t *testing.T) { 16 | 17 | ctx := &dynamicContext{ 18 | params: map[string]interface{}{}, 19 | } 20 | 21 | textSqlNode := &textSqlNode{ 22 | content: "select 1 from t_gap", 23 | } 24 | 25 | textSqlNode.build(ctx) 26 | 27 | expc := "select 1 from t_gap" 28 | assert.Equal(t, expc, ctx.toSql(), "test failed, actual:"+ctx.toSql()) 29 | } 30 | // TestIfSqlNode_True_build checks that a passing test expression emits the child content. 31 | func TestIfSqlNode_True_build(t *testing.T) { 32 | ctx := &dynamicContext{ 33 | params: map[string]interface{}{ 34 | "name": "wenj91", 35 | }, 36 | } 37 | 38 | ifSqlNode := &ifSqlNode{ 39 | test: "name == 'wenj91'", 40 | sqlNode: &textSqlNode{ 41 | content: "select 1 from t_gap", 42 | }, 43 | } 44 | 45 | ifSqlNode.build(ctx) 46 | 47 | expc := "select 1 from t_gap" 48 | assert.Equal(t, expc, ctx.toSql(), "test failed, actual:"+ctx.toSql()) 49 | } 50 | // TestIfSqlNode_False_build checks that a failing test expression emits nothing. 51 | func TestIfSqlNode_False_build(t *testing.T) { 52 | ctx := &dynamicContext{ 53 | params: map[string]interface{}{ 54 | "name": "wenj91", 55 | }, 56 | } 57 | 58 | ifSqlNode := &ifSqlNode{ 59 | test: "name != 'wenj91'", 60 | sqlNode: &textSqlNode{ 61 | content: "select 1 from t_gap", 62 | }, 63 | } 64 | 65 | ifSqlNode.build(ctx) 66 | 67 | expc := "" 68 | assert.Equal(t, expc, ctx.toSql(), "test failed, actual:"+ctx.toSql()) 69 | } 70 | // TestForeachSqlNode_build checks placeholder renaming and param binding for a scalar slice. 71 | func TestForeachSqlNode_build(t *testing.T) { 72 | ctx := newDynamicContext(map[string]interface{}{ 73 | "array": []int{1, 2, 3}, 74 | }) 75 | 76 | f := &foreachSqlNode{ 77 | sqlNode: &mixedSqlNode{ 78 | sqlNodes: []iSqlNode{ 79 | &textSqlNode{ 80 | content: "#{ item }", 81 | }, 82 | }, 83 | }, 84 | item: "item", 85 | open: "select 1 from t_gap where id in (", 86 | close: ")", 87 | separator: ",", 88 | collection: "array", 89 | } 90 | 91 | f.build(ctx) 92 | 93 | expc := "select 1 from t_gap where id in ( #{_ls_item_p_item0} , #{_ls_item_p_item1} , #{_ls_item_p_item2} )" 94 | assert.Equal(t, expc, ctx.toSql(), "test failed, actual:"+ctx.toSql()) 95 |
assert.Equal(t, 1, ctx.params["_ls_item_p_item0"], "test failed, actual:"+fmt.Sprintf("%d", ctx.params["_ls_item_p_item0"])) 96 | assert.Equal(t, 2, ctx.params["_ls_item_p_item1"], "test failed, actual:"+fmt.Sprintf("%d", ctx.params["_ls_item_p_item1"])) 97 | assert.Equal(t, 3, ctx.params["_ls_item_p_item2"], "test failed, actual:"+fmt.Sprintf("%d", ctx.params["_ls_item_p_item2"])) 98 | } 99 | // TestMixedSqlNode_build checks a text node, an if node and a struct-element foreach composed together. 100 | func TestMixedSqlNode_build(t *testing.T) { 101 | params := map[string]interface{}{ 102 | "name": "wenj91", 103 | "array": []map[string]interface{}{{"idea": "11"}, {"idea": "22"}, {"idea": "33"}}, 104 | "array1": []string{"11", "22", "33"}, 105 | "array2": []s{{A: "aa"}, {A: "bb"}, {A: "cc"}}, 106 | } 107 | 108 | mixedSqlNode := &mixedSqlNode{ 109 | sqlNodes: []iSqlNode{ 110 | &textSqlNode{ 111 | content: "select 1 from t_gap where 1 = 1", 112 | }, 113 | &ifSqlNode{ 114 | test: "name == 'wenj91'", 115 | sqlNode: &textSqlNode{ 116 | content: "and name = #{name}", 117 | }, 118 | }, 119 | &foreachSqlNode{ 120 | sqlNode: &mixedSqlNode{ 121 | sqlNodes: []iSqlNode{ 122 | &ifSqlNode{ 123 | test: "item.B == nil", 124 | sqlNode: &textSqlNode{ 125 | content: "1, ", 126 | }, 127 | }, 128 | &textSqlNode{ 129 | content: "#{ item.A }", 130 | }, 131 | }, 132 | }, 133 | item: "item", 134 | open: "and id in (", 135 | close: ")", 136 | separator: ",", 137 | collection: "array2", 138 | }, 139 | }, 140 | } 141 | 142 | ctx := newDynamicContext(params) 143 | 144 | mixedSqlNode.build(ctx) 145 | 146 | expc := "select 1 from t_gap where 1 = 1 and name = #{name} and id in ( #{_ls_item_p_item0.A} , #{_ls_item_p_item1.A} , #{_ls_item_p_item2.A} )" 147 | assert.Equal(t, expc, ctx.toSql(), "test failed, actual:"+ctx.toSql()) 148 | assert.Equal(t, "aa", ctx.params["_ls_item_p_item0.A"], "test failed, actual:"+fmt.Sprintf("%s", ctx.params["_ls_item_p_item0.A"])) 149 | assert.Equal(t, "bb", ctx.params["_ls_item_p_item1.A"], "test failed, actual:"+fmt.Sprintf("%s", ctx.params["_ls_item_p_item1.A"])) 150
| assert.Equal(t, "cc", ctx.params["_ls_item_p_item2.A"], "test failed, actual:"+fmt.Sprintf("%s", ctx.params["_ls_item_p_item2.A"])) 151 | } 152 | // TestSetSqlNode_build checks that set fragments are joined with " , " after the set keyword. 153 | func TestSetSqlNode_build(t *testing.T) { 154 | params := map[string]interface{}{ 155 | "name": "wenj91", 156 | "name2": "wenj91", 157 | } 158 | 159 | setSqlNode := &setSqlNode{ 160 | sqlNodes: []iSqlNode{ 161 | &ifSqlNode{ 162 | test: "name == 'wenj91'", 163 | sqlNode: &textSqlNode{ 164 | content: "name = #{name}", 165 | }, 166 | }, 167 | &ifSqlNode{ 168 | test: "name2 == 'wenj91'", 169 | sqlNode: &textSqlNode{ 170 | content: "name2 = #{name2}", 171 | }, 172 | }, 173 | }, 174 | } 175 | 176 | ctx := newDynamicContext(params) 177 | 178 | setSqlNode.build(ctx) 179 | 180 | expc := "set name = #{name} , name2 = #{name2}" 181 | assert.Equal(t, expc, ctx.toSql(), "test failed, actual:"+ctx.toSql()) 182 | assert.Equal(t, "wenj91", ctx.params["name"], "test failed, actual:"+fmt.Sprintf("%s", ctx.params["name"])) 183 | assert.Equal(t, "wenj91", ctx.params["name2"], "test failed, actual:"+fmt.Sprintf("%s", ctx.params["name2"])) 184 | } 185 | // TestTrimSqlNode_build checks that a leading prefixOverrides keyword is removed. 186 | func TestTrimSqlNode_build(t *testing.T) { 187 | params := map[string]interface{}{ 188 | "name": "wenj91", 189 | "name2": "wenj91", 190 | } 191 | 192 | trimSqlNode := &trimSqlNode{ 193 | prefixOverrides: "and", 194 | suffixOverrides: ",", 195 | sqlNodes: []iSqlNode{ 196 | &ifSqlNode{ 197 | test: "name == 'wenj91'", 198 | sqlNode: &textSqlNode{ 199 | content: "and name = #{name}", 200 | }, 201 | }, 202 | &ifSqlNode{ 203 | test: "name2 == 'wenj91'", 204 | sqlNode: &textSqlNode{ 205 | content: "and name2 = #{name2}", 206 | }, 207 | }, 208 | }, 209 | } 210 | 211 | ctx := newDynamicContext(params) 212 | 213 | trimSqlNode.build(ctx) 214 | 215 | expc := "name = #{name} and name2 = #{name2}" 216 | assert.Equal(t, expc, ctx.toSql(), "test failed, actual:"+ctx.toSql()) 217 | assert.Equal(t, "wenj91", ctx.params["name"], "test failed, actual:"+fmt.Sprintf("%s", ctx.params["name"])) 218 |
assert.Equal(t, "wenj91", ctx.params["name2"], "test failed, actual:"+fmt.Sprintf("%s", ctx.params["name2"])) 219 | } 220 | // TestWhereSqlNode_build checks that the leading AND is dropped after the where keyword. 221 | func TestWhereSqlNode_build(t *testing.T) { 222 | params := map[string]interface{}{ 223 | "name": "wenj91", 224 | "name2": "wenj91", 225 | } 226 | 227 | whereSqlNode := &whereSqlNode{ 228 | sqlNodes: []iSqlNode{ 229 | &ifSqlNode{ 230 | test: "name == 'wenj91'", 231 | sqlNode: &textSqlNode{ 232 | content: "and name = #{name}", 233 | }, 234 | }, 235 | &ifSqlNode{ 236 | test: "name2 == 'wenj91'", 237 | sqlNode: &textSqlNode{ 238 | content: "and name2 = #{name2}", 239 | }, 240 | }, 241 | }, 242 | } 243 | 244 | ctx := newDynamicContext(params) 245 | 246 | whereSqlNode.build(ctx) 247 | 248 | expc := "where name = #{name} and name2 = #{name2}" 249 | assert.Equal(t, expc, ctx.toSql(), "test failed, actual:"+ctx.toSql()) 250 | assert.Equal(t, "wenj91", ctx.params["name"], "test failed, actual:"+fmt.Sprintf("%s", ctx.params["name"])) 251 | assert.Equal(t, "wenj91", ctx.params["name2"], "test failed, actual:"+fmt.Sprintf("%s", ctx.params["name2"])) 252 | } 253 | // TestChooseSqlNode_build checks that otherwise is rendered when no branch matches. 254 | func TestChooseSqlNode_build(t *testing.T) { 255 | params := map[string]interface{}{ 256 | "name": "aa", 257 | } 258 | 259 | choose := chooseNode{ 260 | sqlNodes: []iSqlNode{ 261 | &ifSqlNode{ 262 | test: "name == 'sean'", 263 | sqlNode: &textSqlNode{ 264 | content: "and name = 'sean' ", 265 | }, 266 | }, 267 | &ifSqlNode{ 268 | test: "name == 'Sean'", 269 | sqlNode: &textSqlNode{ 270 | content: "and name = #{name} ", 271 | }, 272 | }, 273 | }, 274 | otherwise: &mixedSqlNode{ 275 | sqlNodes: []iSqlNode{ 276 | &ifSqlNode{ 277 | test: "name == 'aa'", 278 | sqlNode: &textSqlNode{ 279 | content: "and name = 'aa' ", 280 | }, 281 | }, 282 | }, 283 | }, 284 | } 285 | 286 | ctx := newDynamicContext(params) 287 | choose.build(ctx) 288 | expc := "and name = 'aa'" 289 | assert.Equal(t, expc, strings.Trim(ctx.toSql(), " "), "test failed, actual:"+ctx.toSql()) 290 | assert.Equal(t,
"aa", ctx.params["name"], "test failed, actual:"+fmt.Sprintf("%s", ctx.params["name"])) 291 | } 292 | --------------------------------------------------------------------------------