├── .github └── workflows │ └── go.yml ├── .gitignore ├── LICENSE ├── README.md ├── args.go ├── args_test.go ├── builder.go ├── builder_test.go ├── cond.go ├── cond_test.go ├── createtable.go ├── createtable_test.go ├── cte.go ├── cte_test.go ├── ctequery.go ├── delete.go ├── delete_test.go ├── doc.go ├── fieldmapper.go ├── flavor.go ├── flavor_test.go ├── go.mod ├── go.sum ├── injection.go ├── insert.go ├── insert_test.go ├── interpolate.go ├── interpolate_test.go ├── modifiers.go ├── modifiers_test.go ├── select.go ├── select_test.go ├── stringbuilder.go ├── struct.go ├── struct_test.go ├── structfields.go ├── union.go ├── union_test.go ├── update.go ├── update_test.go ├── whereclause.go └── whereclause_test.go /.github/workflows/go.yml: -------------------------------------------------------------------------------- 1 | name: Go 2 | 3 | on: 4 | push: 5 | branches: [master] 6 | pull_request: 7 | branches: [master] 8 | 9 | jobs: 10 | build: 11 | name: Build 12 | runs-on: ubuntu-latest 13 | steps: 14 | - name: Set up Go 1.x 15 | uses: actions/setup-go@v2 16 | with: 17 | go-version: ^1.13 18 | 19 | - name: Check out code into the Go module directory 20 | uses: actions/checkout@v2 21 | 22 | - name: Get dependencies 23 | run: | 24 | go mod download 25 | go get 26 | 27 | - name: Test 28 | run: go test -v -coverprofile=covprofile.cov ./... 
29 | 30 | - name: Send coverage 31 | env: 32 | COVERALLS_TOKEN: ${{ secrets.GITHUB_TOKEN }} 33 | run: | 34 | go get github.com/mattn/goveralls 35 | go run github.com/mattn/goveralls -coverprofile=covprofile.cov -service=github 36 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Compiled Object files, Static and Dynamic libs (Shared Objects) 2 | *.o 3 | *.a 4 | *.so 5 | 6 | # Folders 7 | _obj 8 | _test 9 | 10 | # Architecture specific extensions/prefixes 11 | *.[568vq] 12 | [568vq].out 13 | 14 | *.cgo1.go 15 | *.cgo2.c 16 | _cgo_defun.c 17 | _cgo_gotypes.go 18 | _cgo_export.* 19 | 20 | _testmain.go 21 | 22 | *.exe 23 | *.test 24 | *.prof 25 | 26 | # Intellij 27 | *.iml 28 | .idea/ 29 | 30 | # VS Code 31 | debug 32 | debug_test 33 | .vscode/ 34 | 35 | # Mac 36 | .DS_Store 37 | 38 | # go work 39 | go.work 40 | go.work.sum 41 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2018 Huan Du 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy 4 | of this software and associated documentation files (the "Software"), to deal 5 | in the Software without restriction, including without limitation the rights 6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 | copies of the Software, and to permit persons to whom the Software is 8 | furnished to do so, subject to the following conditions: 9 | 10 | The above copyright notice and this permission notice shall be included in 11 | all copies or substantial portions of the Software. 12 | 13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 19 | THE SOFTWARE. 20 | -------------------------------------------------------------------------------- /args.go: -------------------------------------------------------------------------------- 1 | // Copyright 2018 Huan Du. All rights reserved. 2 | // Licensed under the MIT license that can be found in the LICENSE file. 3 | 4 | package sqlbuilder 5 | 6 | import ( 7 | "database/sql" 8 | "fmt" 9 | "sort" 10 | "strconv" 11 | "strings" 12 | ) 13 | 14 | // Args stores arguments associated with a SQL. 15 | type Args struct { 16 | // The default flavor used by `Args#Compile` 17 | Flavor Flavor 18 | 19 | indexBase int 20 | argValues []interface{} 21 | namedArgs map[string]int 22 | sqlNamedArgs map[string]int 23 | onlyNamed bool 24 | } 25 | 26 | func init() { 27 | // Predefine some $n args to avoid additional memory allocation. 28 | predefinedArgs = make([]string, 0, maxPredefinedArgs) 29 | 30 | for i := 0; i < maxPredefinedArgs; i++ { 31 | predefinedArgs = append(predefinedArgs, fmt.Sprintf("$%v", i)) 32 | } 33 | } 34 | 35 | const maxPredefinedArgs = 64 36 | 37 | var predefinedArgs []string 38 | 39 | // Add adds an arg to Args and returns a placeholder. 
40 | func (args *Args) Add(arg interface{}) string { 41 | idx := args.add(arg) 42 | 43 | if idx < maxPredefinedArgs { 44 | return predefinedArgs[idx] 45 | } 46 | 47 | return fmt.Sprintf("$%v", idx) 48 | } 49 | 50 | func (args *Args) add(arg interface{}) int { 51 | idx := len(args.argValues) + args.indexBase 52 | 53 | switch a := arg.(type) { 54 | case sql.NamedArg: 55 | if args.sqlNamedArgs == nil { 56 | args.sqlNamedArgs = map[string]int{} 57 | } 58 | 59 | if p, ok := args.sqlNamedArgs[a.Name]; ok { 60 | arg = args.argValues[p] 61 | break 62 | } 63 | 64 | args.sqlNamedArgs[a.Name] = idx 65 | case namedArgs: 66 | if args.namedArgs == nil { 67 | args.namedArgs = map[string]int{} 68 | } 69 | 70 | if p, ok := args.namedArgs[a.name]; ok { 71 | arg = args.argValues[p] 72 | break 73 | } 74 | 75 | // Find out the real arg and add it to args. 76 | idx = args.add(a.arg) 77 | args.namedArgs[a.name] = idx 78 | return idx 79 | } 80 | 81 | args.argValues = append(args.argValues, arg) 82 | return idx 83 | } 84 | 85 | // Compile compiles builder's format to standard sql and returns associated args. 86 | // 87 | // The format string uses a special syntax to represent arguments. 88 | // 89 | // $? refers successive arguments passed in the call. It works similar as `%v` in `fmt.Sprintf`. 90 | // $0 $1 ... $n refers nth-argument passed in the call. Next $? will use arguments n+1. 91 | // ${name} refers a named argument created by `Named` with `name`. 92 | // $$ is a "$" string. 93 | func (args *Args) Compile(format string, initialValue ...interface{}) (query string, values []interface{}) { 94 | return args.CompileWithFlavor(format, args.Flavor, initialValue...) 95 | } 96 | 97 | // CompileWithFlavor compiles builder's format to standard sql with flavor and returns associated args. 98 | // 99 | // See doc for `Compile` to learn details. 
100 | func (args *Args) CompileWithFlavor(format string, flavor Flavor, initialValue ...interface{}) (query string, values []interface{}) { 101 | idx := strings.IndexRune(format, '$') 102 | offset := 0 103 | ctx := &argsCompileContext{ 104 | stringBuilder: newStringBuilder(), 105 | Flavor: flavor, 106 | Values: initialValue, 107 | } 108 | 109 | if ctx.Flavor == invalidFlavor { 110 | ctx.Flavor = DefaultFlavor 111 | } 112 | 113 | for idx >= 0 && len(format) > 0 { 114 | if idx > 0 { 115 | ctx.WriteString(format[:idx]) 116 | } 117 | 118 | format = format[idx+1:] 119 | 120 | // Treat the $ at the end of format is a normal $ rune. 121 | if len(format) == 0 { 122 | ctx.WriteRune('$') 123 | break 124 | } 125 | 126 | if r := format[0]; r == '$' { 127 | ctx.WriteRune('$') 128 | format = format[1:] 129 | } else if r == '{' { 130 | format = args.compileNamed(ctx, format) 131 | } else if !args.onlyNamed && '0' <= r && r <= '9' { 132 | format, offset = args.compileDigits(ctx, format, offset) 133 | } else if !args.onlyNamed && r == '?' { 134 | format, offset = args.compileSuccessive(ctx, format[1:], offset) 135 | } else { 136 | // For unknown $ expression format, treat it as a normal $ rune. 137 | ctx.WriteRune('$') 138 | } 139 | 140 | idx = strings.IndexRune(format, '$') 141 | } 142 | 143 | if len(format) > 0 { 144 | ctx.WriteString(format) 145 | } 146 | 147 | query = ctx.String() 148 | values = args.mergeSQLNamedArgs(ctx) 149 | return 150 | } 151 | 152 | // Value returns the value of the arg. 153 | // The arg must be the value returned by `Add`. 154 | func (args *Args) Value(arg string) interface{} { 155 | _, values := args.Compile(arg) 156 | 157 | if len(values) == 0 { 158 | return nil 159 | } 160 | 161 | return values[0] 162 | } 163 | 164 | func (args *Args) compileNamed(ctx *argsCompileContext, format string) string { 165 | i := 1 166 | 167 | for ; i < len(format) && format[i] != '}'; i++ { 168 | // Nothing. 169 | } 170 | 171 | // Invalid $ format. Ignore it. 
172 | if i == len(format) { 173 | return format 174 | } 175 | 176 | name := format[1:i] 177 | format = format[i+1:] 178 | 179 | if p, ok := args.namedArgs[name]; ok { 180 | format, _ = args.compileSuccessive(ctx, format, p-args.indexBase) 181 | } 182 | 183 | return format 184 | } 185 | 186 | func (args *Args) compileDigits(ctx *argsCompileContext, format string, offset int) (string, int) { 187 | i := 1 188 | 189 | for ; i < len(format) && '0' <= format[i] && format[i] <= '9'; i++ { 190 | // Nothing. 191 | } 192 | 193 | digits := format[:i] 194 | format = format[i:] 195 | 196 | if pointer, err := strconv.Atoi(digits); err == nil { 197 | return args.compileSuccessive(ctx, format, pointer-args.indexBase) 198 | } 199 | 200 | return format, offset 201 | } 202 | 203 | func (args *Args) compileSuccessive(ctx *argsCompileContext, format string, offset int) (string, int) { 204 | if offset < 0 || offset >= len(args.argValues) { 205 | ctx.WriteString("/* INVALID ARG $") 206 | ctx.WriteString(strconv.Itoa(offset)) 207 | ctx.WriteString(" */") 208 | return format, offset 209 | } 210 | 211 | arg := args.argValues[offset] 212 | ctx.WriteValue(arg) 213 | 214 | return format, offset + 1 215 | } 216 | 217 | func (args *Args) mergeSQLNamedArgs(ctx *argsCompileContext) []interface{} { 218 | if len(args.sqlNamedArgs) == 0 && len(ctx.NamedArgs) == 0 { 219 | return ctx.Values 220 | } 221 | 222 | values := ctx.Values 223 | existingNames := make(map[string]struct{}, len(ctx.NamedArgs)) 224 | 225 | // Add all named args to values. 226 | // Remove duplicated named args in this step. 227 | for _, arg := range ctx.NamedArgs { 228 | if _, ok := existingNames[arg.Name]; !ok { 229 | existingNames[arg.Name] = struct{}{} 230 | values = append(values, arg) 231 | } 232 | } 233 | 234 | // Stabilize the sequence to make it easier to write test cases. 
235 | ints := make([]int, 0, len(args.sqlNamedArgs)) 236 | 237 | for n, p := range args.sqlNamedArgs { 238 | if _, ok := existingNames[n]; ok { 239 | continue 240 | } 241 | 242 | ints = append(ints, p) 243 | } 244 | 245 | sort.Ints(ints) 246 | 247 | for _, i := range ints { 248 | values = append(values, args.argValues[i]) 249 | } 250 | 251 | return values 252 | } 253 | 254 | func parseNamedArgs(initialValue []interface{}) (values []interface{}, namedValues []sql.NamedArg) { 255 | if len(initialValue) == 0 { 256 | values = initialValue 257 | return 258 | } 259 | 260 | // sql.NamedArgs must be placed at the end of the initial value. 261 | size := len(initialValue) 262 | i := size 263 | 264 | for ; i > 0; i-- { 265 | switch initialValue[i-1].(type) { 266 | case sql.NamedArg: 267 | continue 268 | } 269 | 270 | break 271 | } 272 | 273 | if i == size { 274 | values = initialValue 275 | return 276 | } 277 | 278 | values = initialValue[:i] 279 | namedValues = make([]sql.NamedArg, 0, size-i) 280 | 281 | for ; i < size; i++ { 282 | namedValues = append(namedValues, initialValue[i].(sql.NamedArg)) 283 | } 284 | 285 | return 286 | } 287 | 288 | type argsCompileContext struct { 289 | *stringBuilder 290 | 291 | Flavor Flavor 292 | Values []interface{} 293 | NamedArgs []sql.NamedArg 294 | } 295 | 296 | func (ctx *argsCompileContext) WriteValue(arg interface{}) { 297 | switch a := arg.(type) { 298 | case Builder: 299 | s, values := a.BuildWithFlavor(ctx.Flavor, ctx.Values...) 300 | ctx.WriteString(s) 301 | 302 | // Add all values to ctx. 303 | // Named args must be located at the end of values. 304 | values, namedArgs := parseNamedArgs(values) 305 | ctx.Values = values 306 | ctx.NamedArgs = append(ctx.NamedArgs, namedArgs...) 
307 | 308 | case sql.NamedArg: 309 | ctx.WriteRune('@') 310 | ctx.WriteString(a.Name) 311 | ctx.NamedArgs = append(ctx.NamedArgs, a) 312 | 313 | case rawArgs: 314 | ctx.WriteString(a.expr) 315 | 316 | case listArgs: 317 | if a.isTuple { 318 | ctx.WriteRune('(') 319 | } 320 | 321 | if len(a.args) > 0 { 322 | ctx.WriteValue(a.args[0]) 323 | } 324 | 325 | for i := 1; i < len(a.args); i++ { 326 | ctx.WriteString(", ") 327 | ctx.WriteValue(a.args[i]) 328 | } 329 | 330 | if a.isTuple { 331 | ctx.WriteRune(')') 332 | } 333 | 334 | case condBuilder: 335 | a.Builder(ctx) 336 | 337 | default: 338 | switch ctx.Flavor { 339 | case MySQL, SQLite, CQL, ClickHouse, Presto, Informix, Doris: 340 | ctx.WriteRune('?') 341 | case PostgreSQL: 342 | fmt.Fprintf(ctx, "$%d", len(ctx.Values)+1) 343 | case SQLServer: 344 | fmt.Fprintf(ctx, "@p%d", len(ctx.Values)+1) 345 | case Oracle: 346 | fmt.Fprintf(ctx, ":%d", len(ctx.Values)+1) 347 | default: 348 | panic(fmt.Errorf("Args.CompileWithFlavor: invalid flavor %v (%v)", ctx.Flavor, int(ctx.Flavor))) 349 | } 350 | 351 | ctx.Values = append(ctx.Values, arg) 352 | } 353 | } 354 | 355 | func (ctx *argsCompileContext) WriteValues(values []interface{}, sep string) { 356 | if len(values) == 0 { 357 | return 358 | } 359 | 360 | ctx.WriteValue(values[0]) 361 | 362 | for _, v := range values[1:] { 363 | ctx.WriteString(sep) 364 | ctx.WriteValue(v) 365 | } 366 | } 367 | -------------------------------------------------------------------------------- /args_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2018 Huan Du. All rights reserved. 2 | // Licensed under the MIT license that can be found in the LICENSE file. 
3 | 4 | package sqlbuilder 5 | 6 | import ( 7 | "bytes" 8 | "database/sql" 9 | "fmt" 10 | "strings" 11 | "testing" 12 | 13 | "github.com/huandu/go-assert" 14 | ) 15 | 16 | func TestArgs(t *testing.T) { 17 | a := assert.New(t) 18 | start := sql.Named("start", 1234567890) 19 | end := sql.Named("end", 1234599999) 20 | named1 := Named("named1", "foo") 21 | named2 := Named("named2", "bar") 22 | 23 | cases := map[string][]interface{}{ 24 | "abc ? def\n[123]": {"abc $? def", 123}, 25 | "abc ? def\n[456]": {"abc $0 def", 456}, 26 | "abc /* INVALID ARG $1 */ def\n[]": {"abc $1 def", 123}, 27 | "abc def \n[]": {"abc ${unknown} def ", 123}, 28 | "abc $ def\n[]": {"abc $$ def", 123}, 29 | "abcdef$\n[]": {"abcdef$", 123}, 30 | "abc ? ? ? ? def\n[123 456 123 456]": {"abc $? $? $0 $? def", 123, 456, 789}, 31 | "abc ? raw ? raw def\n[123 123]": {"abc $? $? $0 $? def", 123, Raw("raw"), 789}, 32 | "abc $-1 $a def\n[]": {"abc $-1 $a def", 123}, 33 | 34 | "abc ? def ? ?\n[foo bar foo]": {"abc ${named1} def ${named2} ${named1}", named2, named1, named2}, 35 | "@end @start @end\n[{{} end 1234599999} {{} start 1234567890}]": {"$? $? $?", end, start, end}, 36 | } 37 | 38 | for expected, c := range cases { 39 | args := new(Args) 40 | 41 | for i := 1; i < len(c); i++ { 42 | args.Add(c[i]) 43 | } 44 | 45 | sql, values := args.Compile(c[0].(string)) 46 | actual := fmt.Sprintf("%v\n%v", sql, values) 47 | 48 | a.Equal(actual, expected) 49 | } 50 | 51 | old := DefaultFlavor 52 | defer func() { 53 | DefaultFlavor = old 54 | }() 55 | 56 | DefaultFlavor = PostgreSQL 57 | 58 | // PostgreSQL flavor compiled sql. 59 | for expected, c := range cases { 60 | args := new(Args) 61 | 62 | for i := 1; i < len(c); i++ { 63 | args.Add(c[i]) 64 | } 65 | 66 | sql, values := args.Compile(c[0].(string)) 67 | actual := fmt.Sprintf("%v\n%v", sql, values) 68 | expected = toPostgreSQL(expected) 69 | 70 | a.Equal(actual, expected) 71 | } 72 | 73 | DefaultFlavor = SQLServer 74 | 75 | // SQLServer flavor compiled sql. 
76 | for expected, c := range cases { 77 | args := new(Args) 78 | 79 | for i := 1; i < len(c); i++ { 80 | args.Add(c[i]) 81 | } 82 | 83 | sql, values := args.Compile(c[0].(string)) 84 | actual := fmt.Sprintf("%v\n%v", sql, values) 85 | expected = toSQLServerSQL(expected) 86 | 87 | a.Equal(actual, expected) 88 | } 89 | 90 | DefaultFlavor = CQL 91 | 92 | for expected, c := range cases { 93 | args := new(Args) 94 | 95 | for i := 1; i < len(c); i++ { 96 | args.Add(c[i]) 97 | } 98 | 99 | sql, values := args.Compile(c[0].(string)) 100 | actual := fmt.Sprintf("%v\n%v", sql, values) 101 | 102 | a.Equal(actual, expected) 103 | } 104 | } 105 | 106 | func toPostgreSQL(sql string) string { 107 | parts := strings.Split(sql, "?") 108 | buf := &bytes.Buffer{} 109 | buf.WriteString(parts[0]) 110 | 111 | for i, p := range parts[1:] { 112 | fmt.Fprintf(buf, "$%v", i+1) 113 | buf.WriteString(p) 114 | } 115 | 116 | return buf.String() 117 | } 118 | 119 | func toSQLServerSQL(sql string) string { 120 | parts := strings.Split(sql, "?") 121 | buf := &bytes.Buffer{} 122 | buf.WriteString(parts[0]) 123 | 124 | for i, p := range parts[1:] { 125 | fmt.Fprintf(buf, "@p%v", i+1) 126 | buf.WriteString(p) 127 | } 128 | 129 | return buf.String() 130 | } 131 | 132 | func TestArgsAdd(t *testing.T) { 133 | a := assert.New(t) 134 | args := &Args{} 135 | 136 | for i := 0; i < maxPredefinedArgs*2; i++ { 137 | actual := args.Add(i) 138 | a.Equal(actual, fmt.Sprintf("$%v", i)) 139 | } 140 | } 141 | 142 | func TestArgsValue(t *testing.T) { 143 | a := assert.New(t) 144 | args := &Args{} 145 | 146 | v1 := 123 147 | arg1 := args.Add(v1) 148 | argInvalid := "invalid" 149 | argLooselyTyped := arg1 + "something else" 150 | 151 | a.Equal(v1, args.Value(arg1)) 152 | a.Equal(nil, args.Value(argInvalid)) 153 | a.Equal(v1, args.Value(argLooselyTyped)) 154 | } 155 | -------------------------------------------------------------------------------- /builder.go: 
-------------------------------------------------------------------------------- 1 | // Copyright 2018 Huan Du. All rights reserved. 2 | // Licensed under the MIT license that can be found in the LICENSE file. 3 | 4 | package sqlbuilder 5 | 6 | import ( 7 | "fmt" 8 | ) 9 | 10 | // Builder is a general SQL builder. 11 | // It's used by Args to create nested SQL like the `IN` expression in 12 | // `SELECT * FROM t1 WHERE id IN (SELECT id FROM t2)`. 13 | type Builder interface { 14 | Build() (sql string, args []interface{}) 15 | BuildWithFlavor(flavor Flavor, initialArg ...interface{}) (sql string, args []interface{}) 16 | Flavor() Flavor 17 | } 18 | 19 | type compiledBuilder struct { 20 | args *Args 21 | format string 22 | } 23 | 24 | var _ Builder = new(compiledBuilder) 25 | 26 | func (cb *compiledBuilder) Build() (sql string, args []interface{}) { 27 | return cb.args.Compile(cb.format) 28 | } 29 | 30 | func (cb *compiledBuilder) BuildWithFlavor(flavor Flavor, initialArg ...interface{}) (sql string, args []interface{}) { 31 | return cb.args.CompileWithFlavor(cb.format, flavor, initialArg...) 32 | } 33 | 34 | // Flavor returns flavor of builder 35 | // Always returns DefaultFlavor 36 | func (cb *compiledBuilder) Flavor() Flavor { 37 | return cb.args.Flavor 38 | } 39 | 40 | type flavoredBuilder struct { 41 | builder Builder 42 | flavor Flavor 43 | } 44 | 45 | func (fb *flavoredBuilder) Build() (sql string, args []interface{}) { 46 | return fb.builder.BuildWithFlavor(fb.flavor) 47 | } 48 | 49 | func (fb *flavoredBuilder) BuildWithFlavor(flavor Flavor, initialArg ...interface{}) (sql string, args []interface{}) { 50 | return fb.builder.BuildWithFlavor(flavor, initialArg...) 51 | } 52 | 53 | // Flavor returns flavor of builder 54 | func (fb *flavoredBuilder) Flavor() Flavor { 55 | return fb.flavor 56 | } 57 | 58 | // WithFlavor creates a new Builder based on builder with a default flavor. 
59 | func WithFlavor(builder Builder, flavor Flavor) Builder { 60 | return &flavoredBuilder{ 61 | builder: builder, 62 | flavor: flavor, 63 | } 64 | } 65 | 66 | // Buildf creates a Builder from a format string using `fmt.Sprintf`-like syntax. 67 | // As all arguments will be converted to a string internally, e.g. "$0", 68 | // only `%v` and `%s` are valid. 69 | func Buildf(format string, arg ...interface{}) Builder { 70 | args := &Args{ 71 | Flavor: DefaultFlavor, 72 | } 73 | vars := make([]interface{}, 0, len(arg)) 74 | 75 | for _, a := range arg { 76 | vars = append(vars, args.Add(a)) 77 | } 78 | 79 | return &compiledBuilder{ 80 | args: args, 81 | format: fmt.Sprintf(Escape(format), vars...), 82 | } 83 | } 84 | 85 | // Build creates a Builder from a format string. 86 | // The format string uses special syntax to represent arguments. 87 | // See doc in `Args#Compile` for syntax details. 88 | func Build(format string, arg ...interface{}) Builder { 89 | args := &Args{ 90 | Flavor: DefaultFlavor, 91 | } 92 | 93 | for _, a := range arg { 94 | args.Add(a) 95 | } 96 | 97 | return &compiledBuilder{ 98 | args: args, 99 | format: format, 100 | } 101 | } 102 | 103 | // BuildNamed creates a Builder from a format string. 104 | // The format string uses `${key}` to refer the value of named by key. 105 | func BuildNamed(format string, named map[string]interface{}) Builder { 106 | args := &Args{ 107 | Flavor: DefaultFlavor, 108 | onlyNamed: true, 109 | } 110 | 111 | for n, v := range named { 112 | args.Add(Named(n, v)) 113 | } 114 | 115 | return &compiledBuilder{ 116 | args: args, 117 | format: format, 118 | } 119 | } 120 | -------------------------------------------------------------------------------- /builder_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2018 Huan Du. All rights reserved. 2 | // Licensed under the MIT license that can be found in the LICENSE file. 
3 | 4 | package sqlbuilder 5 | 6 | import ( 7 | "database/sql" 8 | "fmt" 9 | "testing" 10 | 11 | "github.com/huandu/go-assert" 12 | ) 13 | 14 | func ExampleBuildf() { 15 | sb := NewSelectBuilder() 16 | sb.Select("id").From("user") 17 | 18 | explain := Buildf("EXPLAIN %v LEFT JOIN SELECT * FROM banned WHERE state IN (%v, %v)", sb, 1, 2) 19 | s, args := explain.Build() 20 | fmt.Println(s) 21 | fmt.Println(args) 22 | 23 | // Output: 24 | // EXPLAIN SELECT id FROM user LEFT JOIN SELECT * FROM banned WHERE state IN (?, ?) 25 | // [1 2] 26 | } 27 | 28 | func ExampleBuild() { 29 | sb := NewSelectBuilder() 30 | sb.Select("id").From("user").Where(sb.In("status", 1, 2)) 31 | 32 | b := Build("EXPLAIN $? LEFT JOIN SELECT * FROM $? WHERE created_at > $? AND state IN (${states}) AND modified_at BETWEEN $2 AND $?", 33 | sb, Raw("banned"), 1514458225, 1514544625, Named("states", List([]int{3, 4, 5}))) 34 | s, args := b.Build() 35 | 36 | fmt.Println(s) 37 | fmt.Println(args) 38 | 39 | // Output: 40 | // EXPLAIN SELECT id FROM user WHERE status IN (?, ?) LEFT JOIN SELECT * FROM banned WHERE created_at > ? AND state IN (?, ?, ?) AND modified_at BETWEEN ? AND ? 41 | // [1 2 1514458225 3 4 5 1514458225 1514544625] 42 | } 43 | 44 | func ExampleBuildNamed() { 45 | b := BuildNamed("SELECT * FROM ${table} WHERE status IN (${status}) AND name LIKE ${name} AND created_at > ${time} AND modified_at < ${time} + 86400", 46 | map[string]interface{}{ 47 | "time": sql.Named("start", 1234567890), 48 | "status": List([]int{1, 2, 5}), 49 | "name": "Huan%", 50 | "table": Raw("user"), 51 | }) 52 | s, args := b.Build() 53 | 54 | fmt.Println(s) 55 | fmt.Println(args) 56 | 57 | // Output: 58 | // SELECT * FROM user WHERE status IN (?, ?, ?) AND name LIKE ? 
AND created_at > @start AND modified_at < @start + 86400 59 | // [1 2 5 Huan% {{} start 1234567890}] 60 | } 61 | 62 | func ExampleWithFlavor() { 63 | sql, args := WithFlavor(Buildf("SELECT * FROM foo WHERE id = %v", 1234), PostgreSQL).Build() 64 | 65 | fmt.Println(sql) 66 | fmt.Println(args) 67 | 68 | // Explicitly use MySQL as the flavor. 69 | sql, args = WithFlavor(Buildf("SELECT * FROM foo WHERE id = %v", 1234), PostgreSQL).BuildWithFlavor(MySQL) 70 | 71 | fmt.Println(sql) 72 | fmt.Println(args) 73 | 74 | // Explicitly use MySQL as the informix. 75 | sql, args = WithFlavor(Buildf("SELECT * FROM foo WHERE id = %v", 1234), Informix).Build() 76 | 77 | fmt.Println(sql) 78 | fmt.Println(args) 79 | 80 | // Output: 81 | // SELECT * FROM foo WHERE id = $1 82 | // [1234] 83 | // SELECT * FROM foo WHERE id = ? 84 | // [1234] 85 | // SELECT * FROM foo WHERE id = ? 86 | // [1234] 87 | } 88 | 89 | func TestBuildWithPostgreSQL(t *testing.T) { 90 | a := assert.New(t) 91 | sb1 := PostgreSQL.NewSelectBuilder() 92 | sb1.Select("col1", "col2").From("t1").Where(sb1.E("id", 1234), sb1.G("level", 2)) 93 | 94 | sb2 := PostgreSQL.NewSelectBuilder() 95 | sb2.Select("col3", "col4").From("t2").Where(sb2.E("id", 4567), sb2.LE("level", 5)) 96 | 97 | // Use DefaultFlavor (MySQL) instead of PostgreSQL. 98 | sql, args := Build("SELECT $1 AS col5 LEFT JOIN $0 LEFT JOIN $2", sb1, 7890, sb2).Build() 99 | 100 | a.Equal(sql, "SELECT ? AS col5 LEFT JOIN SELECT col1, col2 FROM t1 WHERE id = ? AND level > ? LEFT JOIN SELECT col3, col4 FROM t2 WHERE id = ? 
AND level <= ?") 101 | a.Equal(args, []interface{}{7890, 1234, 2, 4567, 5}) 102 | 103 | old := DefaultFlavor 104 | DefaultFlavor = PostgreSQL 105 | defer func() { 106 | DefaultFlavor = old 107 | }() 108 | 109 | sql, args = Build("SELECT $1 AS col5 LEFT JOIN $0 LEFT JOIN $2", sb1, 7890, sb2).Build() 110 | 111 | a.Equal(sql, "SELECT $1 AS col5 LEFT JOIN SELECT col1, col2 FROM t1 WHERE id = $2 AND level > $3 LEFT JOIN SELECT col3, col4 FROM t2 WHERE id = $4 AND level <= $5") 112 | a.Equal(args, []interface{}{7890, 1234, 2, 4567, 5}) 113 | } 114 | 115 | func TestBuildWithCQL(t *testing.T) { 116 | a := assert.New(t) 117 | 118 | ib1 := CQL.NewInsertBuilder() 119 | ib1.InsertInto("t1").Cols("col1", "col2").Values(1, 2) 120 | 121 | ib2 := CQL.NewInsertBuilder() 122 | ib2.InsertInto("t2").Cols("col3", "col4").Values(3, 4) 123 | 124 | old := DefaultFlavor 125 | DefaultFlavor = CQL 126 | defer func() { 127 | DefaultFlavor = old 128 | }() 129 | 130 | sql, args := Build("BEGIN BATCH USING TIMESTAMP $0 $1; $2; APPLY BATCH;", 1481124356754405, ib1, ib2).Build() 131 | 132 | a.Equal(sql, "BEGIN BATCH USING TIMESTAMP ? 
INSERT INTO t1 (col1, col2) VALUES (?, ?); INSERT INTO t2 (col3, col4) VALUES (?, ?); APPLY BATCH;") 133 | a.Equal(args, []interface{}{1481124356754405, 1, 2, 3, 4}) 134 | } 135 | 136 | func TestBuilderGetFlavor(t *testing.T) { 137 | a := assert.New(t) 138 | 139 | defaultBuilder := Build("SELECT * FROM foo WHERE id = $0", 1234) 140 | a.Equal(DefaultFlavor, defaultBuilder.Flavor()) 141 | 142 | buildfBuilder := Buildf("SELECT * FROM foo WHERE id = %v", 1234) 143 | a.Equal(DefaultFlavor, buildfBuilder.Flavor()) 144 | 145 | namedBuilder := Buildf("SELECT * FROM ${table} WHERE id = 1234", map[string]interface{}{ 146 | "table": "foo", 147 | }) 148 | a.Equal(DefaultFlavor, namedBuilder.Flavor()) 149 | 150 | flavoredBuilder := WithFlavor(Build("SELECT * FROM foo WHERE id = $0", 1234), PostgreSQL) 151 | a.Equal(PostgreSQL, flavoredBuilder.Flavor()) 152 | 153 | } 154 | -------------------------------------------------------------------------------- /cond.go: -------------------------------------------------------------------------------- 1 | // Copyright 2018 Huan Du. All rights reserved. 2 | // Licensed under the MIT license that can be found in the LICENSE file. 3 | 4 | package sqlbuilder 5 | 6 | const ( 7 | lparen = "(" 8 | rparen = ")" 9 | opOR = " OR " 10 | opAND = " AND " 11 | opNOT = "NOT " 12 | ) 13 | 14 | const minIndexBase = 256 15 | 16 | // Cond provides several helper methods to build conditions. 17 | type Cond struct { 18 | Args *Args 19 | } 20 | 21 | // NewCond returns a new Cond. 22 | func NewCond() *Cond { 23 | return &Cond{ 24 | Args: &Args{ 25 | // Based on the discussion in #174, users may call this method to create 26 | // `Cond` for building various conditions, which is a misuse, but we 27 | // cannot completely prevent this error. To facilitate users in 28 | // identifying the issue when they make mistakes and to avoid 29 | // unexpected stackoverflows, the base index for `Args` is 30 | // deliberately set to a larger non-zero value here. 
This can 31 | // significantly reduce the likelihood of issues and allows for 32 | // timely error notification to users. 33 | indexBase: minIndexBase, 34 | }, 35 | } 36 | } 37 | 38 | // Equal is used to construct the expression "field = value". 39 | func (c *Cond) Equal(field string, value interface{}) string { 40 | if len(field) == 0 { 41 | return "" 42 | } 43 | 44 | return c.Var(condBuilder{ 45 | Builder: func(ctx *argsCompileContext) { 46 | ctx.WriteString(field) 47 | ctx.WriteString(" = ") 48 | ctx.WriteValue(value) 49 | }, 50 | }) 51 | } 52 | 53 | // E is an alias of Equal. 54 | func (c *Cond) E(field string, value interface{}) string { 55 | return c.Equal(field, value) 56 | } 57 | 58 | // EQ is an alias of Equal. 59 | func (c *Cond) EQ(field string, value interface{}) string { 60 | return c.Equal(field, value) 61 | } 62 | 63 | // NotEqual is used to construct the expression "field <> value". 64 | func (c *Cond) NotEqual(field string, value interface{}) string { 65 | if len(field) == 0 { 66 | return "" 67 | } 68 | 69 | return c.Var(condBuilder{ 70 | Builder: func(ctx *argsCompileContext) { 71 | ctx.WriteString(field) 72 | ctx.WriteString(" <> ") 73 | ctx.WriteValue(value) 74 | }, 75 | }) 76 | } 77 | 78 | // NE is an alias of NotEqual. 79 | func (c *Cond) NE(field string, value interface{}) string { 80 | return c.NotEqual(field, value) 81 | } 82 | 83 | // NEQ is an alias of NotEqual. 84 | func (c *Cond) NEQ(field string, value interface{}) string { 85 | return c.NotEqual(field, value) 86 | } 87 | 88 | // GreaterThan is used to construct the expression "field > value". 89 | func (c *Cond) GreaterThan(field string, value interface{}) string { 90 | if len(field) == 0 { 91 | return "" 92 | } 93 | 94 | return c.Var(condBuilder{ 95 | Builder: func(ctx *argsCompileContext) { 96 | ctx.WriteString(field) 97 | ctx.WriteString(" > ") 98 | ctx.WriteValue(value) 99 | }, 100 | }) 101 | } 102 | 103 | // G is an alias of GreaterThan. 
104 | func (c *Cond) G(field string, value interface{}) string { 105 | return c.GreaterThan(field, value) 106 | } 107 | 108 | // GT is an alias of GreaterThan. 109 | func (c *Cond) GT(field string, value interface{}) string { 110 | return c.GreaterThan(field, value) 111 | } 112 | 113 | // GreaterEqualThan is used to construct the expression "field >= value". 114 | func (c *Cond) GreaterEqualThan(field string, value interface{}) string { 115 | if len(field) == 0 { 116 | return "" 117 | } 118 | 119 | return c.Var(condBuilder{ 120 | Builder: func(ctx *argsCompileContext) { 121 | ctx.WriteString(field) 122 | ctx.WriteString(" >= ") 123 | ctx.WriteValue(value) 124 | }, 125 | }) 126 | } 127 | 128 | // GE is an alias of GreaterEqualThan. 129 | func (c *Cond) GE(field string, value interface{}) string { 130 | return c.GreaterEqualThan(field, value) 131 | } 132 | 133 | // GTE is an alias of GreaterEqualThan. 134 | func (c *Cond) GTE(field string, value interface{}) string { 135 | return c.GreaterEqualThan(field, value) 136 | } 137 | 138 | // LessThan is used to construct the expression "field < value". 139 | func (c *Cond) LessThan(field string, value interface{}) string { 140 | if len(field) == 0 { 141 | return "" 142 | } 143 | 144 | return c.Var(condBuilder{ 145 | Builder: func(ctx *argsCompileContext) { 146 | ctx.WriteString(field) 147 | ctx.WriteString(" < ") 148 | ctx.WriteValue(value) 149 | }, 150 | }) 151 | } 152 | 153 | // L is an alias of LessThan. 154 | func (c *Cond) L(field string, value interface{}) string { 155 | return c.LessThan(field, value) 156 | } 157 | 158 | // LT is an alias of LessThan. 159 | func (c *Cond) LT(field string, value interface{}) string { 160 | return c.LessThan(field, value) 161 | } 162 | 163 | // LessEqualThan is used to construct the expression "field <= value". 
164 | func (c *Cond) LessEqualThan(field string, value interface{}) string { 165 | if len(field) == 0 { 166 | return "" 167 | } 168 | return c.Var(condBuilder{ 169 | Builder: func(ctx *argsCompileContext) { 170 | ctx.WriteString(field) 171 | ctx.WriteString(" <= ") 172 | ctx.WriteValue(value) 173 | }, 174 | }) 175 | } 176 | 177 | // LE is an alias of LessEqualThan. 178 | func (c *Cond) LE(field string, value interface{}) string { 179 | return c.LessEqualThan(field, value) 180 | } 181 | 182 | // LTE is an alias of LessEqualThan. 183 | func (c *Cond) LTE(field string, value interface{}) string { 184 | return c.LessEqualThan(field, value) 185 | } 186 | 187 | // In is used to construct the expression "field IN (value...)". 188 | func (c *Cond) In(field string, values ...interface{}) string { 189 | if len(field) == 0 { 190 | return "" 191 | } 192 | 193 | // Empty values means "false". 194 | if len(values) == 0 { 195 | return "0 = 1" 196 | } 197 | 198 | return c.Var(condBuilder{ 199 | Builder: func(ctx *argsCompileContext) { 200 | ctx.WriteString(field) 201 | ctx.WriteString(" IN (") 202 | ctx.WriteValues(values, ", ") 203 | ctx.WriteString(")") 204 | }, 205 | }) 206 | } 207 | 208 | // NotIn is used to construct the expression "field NOT IN (value...)". 209 | func (c *Cond) NotIn(field string, values ...interface{}) string { 210 | if len(field) == 0 || len(values) == 0 { 211 | return "" 212 | } 213 | 214 | return c.Var(condBuilder{ 215 | Builder: func(ctx *argsCompileContext) { 216 | ctx.WriteString(field) 217 | ctx.WriteString(" NOT IN (") 218 | ctx.WriteValues(values, ", ") 219 | ctx.WriteString(")") 220 | }, 221 | }) 222 | } 223 | 224 | // Like is used to construct the expression "field LIKE value". 
225 | func (c *Cond) Like(field string, value interface{}) string { 226 | if len(field) == 0 { 227 | return "" 228 | } 229 | 230 | return c.Var(condBuilder{ 231 | Builder: func(ctx *argsCompileContext) { 232 | ctx.WriteString(field) 233 | ctx.WriteString(" LIKE ") 234 | ctx.WriteValue(value) 235 | }, 236 | }) 237 | } 238 | 239 | // ILike is used to construct the expression "field ILIKE value". 240 | // 241 | // When the database system does not support the ILIKE operator, 242 | // the ILike method will return "LOWER(field) LIKE LOWER(value)" 243 | // to simulate the behavior of the ILIKE operator. 244 | func (c *Cond) ILike(field string, value interface{}) string { 245 | if len(field) == 0 { 246 | return "" 247 | } 248 | 249 | return c.Var(condBuilder{ 250 | Builder: func(ctx *argsCompileContext) { 251 | switch ctx.Flavor { 252 | case PostgreSQL, SQLite: 253 | ctx.WriteString(field) 254 | ctx.WriteString(" ILIKE ") 255 | ctx.WriteValue(value) 256 | 257 | default: 258 | // Use LOWER to simulate ILIKE. 259 | ctx.WriteString("LOWER(") 260 | ctx.WriteString(field) 261 | ctx.WriteString(") LIKE LOWER(") 262 | ctx.WriteValue(value) 263 | ctx.WriteString(")") 264 | } 265 | }, 266 | }) 267 | } 268 | 269 | // NotLike is used to construct the expression "field NOT LIKE value". 270 | func (c *Cond) NotLike(field string, value interface{}) string { 271 | if len(field) == 0 { 272 | return "" 273 | } 274 | 275 | return c.Var(condBuilder{ 276 | Builder: func(ctx *argsCompileContext) { 277 | ctx.WriteString(field) 278 | ctx.WriteString(" NOT LIKE ") 279 | ctx.WriteValue(value) 280 | }, 281 | }) 282 | } 283 | 284 | // NotILike is used to construct the expression "field NOT ILIKE value". 285 | // 286 | // When the database system does not support the ILIKE operator, 287 | // the NotILike method will return "LOWER(field) NOT LIKE LOWER(value)" 288 | // to simulate the behavior of the ILIKE operator. 
289 | func (c *Cond) NotILike(field string, value interface{}) string { 290 | if len(field) == 0 { 291 | return "" 292 | } 293 | 294 | return c.Var(condBuilder{ 295 | Builder: func(ctx *argsCompileContext) { 296 | switch ctx.Flavor { 297 | case PostgreSQL, SQLite: 298 | ctx.WriteString(field) 299 | ctx.WriteString(" NOT ILIKE ") 300 | ctx.WriteValue(value) 301 | 302 | default: 303 | // Use LOWER to simulate ILIKE. 304 | ctx.WriteString("LOWER(") 305 | ctx.WriteString(field) 306 | ctx.WriteString(") NOT LIKE LOWER(") 307 | ctx.WriteValue(value) 308 | ctx.WriteString(")") 309 | } 310 | }, 311 | }) 312 | } 313 | 314 | // IsNull is used to construct the expression "field IS NULL". 315 | func (c *Cond) IsNull(field string) string { 316 | if len(field) == 0 { 317 | return "" 318 | } 319 | 320 | return c.Var(condBuilder{ 321 | Builder: func(ctx *argsCompileContext) { 322 | ctx.WriteString(field) 323 | ctx.WriteString(" IS NULL") 324 | }, 325 | }) 326 | } 327 | 328 | // IsNotNull is used to construct the expression "field IS NOT NULL". 329 | func (c *Cond) IsNotNull(field string) string { 330 | if len(field) == 0 { 331 | return "" 332 | } 333 | return c.Var(condBuilder{ 334 | Builder: func(ctx *argsCompileContext) { 335 | ctx.WriteString(field) 336 | ctx.WriteString(" IS NOT NULL") 337 | }, 338 | }) 339 | } 340 | 341 | // Between is used to construct the expression "field BETWEEN lower AND upper". 342 | func (c *Cond) Between(field string, lower, upper interface{}) string { 343 | if len(field) == 0 { 344 | return "" 345 | } 346 | 347 | return c.Var(condBuilder{ 348 | Builder: func(ctx *argsCompileContext) { 349 | ctx.WriteString(field) 350 | ctx.WriteString(" BETWEEN ") 351 | ctx.WriteValue(lower) 352 | ctx.WriteString(" AND ") 353 | ctx.WriteValue(upper) 354 | }, 355 | }) 356 | } 357 | 358 | // NotBetween is used to construct the expression "field NOT BETWEEN lower AND upper". 
359 | func (c *Cond) NotBetween(field string, lower, upper interface{}) string { 360 | if len(field) == 0 { 361 | return "" 362 | } 363 | 364 | return c.Var(condBuilder{ 365 | Builder: func(ctx *argsCompileContext) { 366 | ctx.WriteString(field) 367 | ctx.WriteString(" NOT BETWEEN ") 368 | ctx.WriteValue(lower) 369 | ctx.WriteString(" AND ") 370 | ctx.WriteValue(upper) 371 | }, 372 | }) 373 | } 374 | 375 | // Or is used to construct the expression OR logic like "expr1 OR expr2 OR expr3". 376 | func (c *Cond) Or(orExpr ...string) string { 377 | orExpr = filterEmptyStrings(orExpr) 378 | 379 | if len(orExpr) == 0 { 380 | return "" 381 | } 382 | 383 | exprByteLen := estimateStringsBytes(orExpr) 384 | if exprByteLen == 0 { 385 | return "" 386 | } 387 | 388 | buf := newStringBuilder() 389 | 390 | // Ensure that there is only 1 memory allocation. 391 | size := len(lparen) + len(rparen) + (len(orExpr)-1)*len(opOR) + exprByteLen 392 | buf.Grow(size) 393 | 394 | buf.WriteString(lparen) 395 | buf.WriteStrings(orExpr, opOR) 396 | buf.WriteString(rparen) 397 | return buf.String() 398 | } 399 | 400 | // And is used to construct the expression AND logic like "expr1 AND expr2 AND expr3". 401 | func (c *Cond) And(andExpr ...string) string { 402 | andExpr = filterEmptyStrings(andExpr) 403 | 404 | if len(andExpr) == 0 { 405 | return "" 406 | } 407 | 408 | exprByteLen := estimateStringsBytes(andExpr) 409 | if exprByteLen == 0 { 410 | return "" 411 | } 412 | 413 | buf := newStringBuilder() 414 | 415 | // Ensure that there is only 1 memory allocation. 416 | size := len(lparen) + len(rparen) + (len(andExpr)-1)*len(opAND) + exprByteLen 417 | buf.Grow(size) 418 | 419 | buf.WriteString(lparen) 420 | buf.WriteStrings(andExpr, opAND) 421 | buf.WriteString(rparen) 422 | return buf.String() 423 | } 424 | 425 | // Not is used to construct the expression "NOT expr". 
426 | func (c *Cond) Not(notExpr string) string { 427 | if len(notExpr) == 0 { 428 | return "" 429 | } 430 | buf := newStringBuilder() 431 | 432 | // Ensure that there is only 1 memory allocation. 433 | size := len(opNOT) + len(notExpr) 434 | buf.Grow(size) 435 | 436 | buf.WriteString(opNOT) 437 | buf.WriteString(notExpr) 438 | return buf.String() 439 | } 440 | 441 | // Exists is used to construct the expression "EXISTS (subquery)". 442 | func (c *Cond) Exists(subquery interface{}) string { 443 | return c.Var(condBuilder{ 444 | Builder: func(ctx *argsCompileContext) { 445 | ctx.WriteString("EXISTS (") 446 | ctx.WriteValue(subquery) 447 | ctx.WriteString(")") 448 | }, 449 | }) 450 | } 451 | 452 | // NotExists is used to construct the expression "NOT EXISTS (subquery)". 453 | func (c *Cond) NotExists(subquery interface{}) string { 454 | return c.Var(condBuilder{ 455 | Builder: func(ctx *argsCompileContext) { 456 | ctx.WriteString("NOT EXISTS (") 457 | ctx.WriteValue(subquery) 458 | ctx.WriteString(")") 459 | }, 460 | }) 461 | } 462 | 463 | // Any is used to construct the expression "field op ANY (value...)". 464 | func (c *Cond) Any(field, op string, values ...interface{}) string { 465 | if len(field) == 0 || len(op) == 0 { 466 | return "" 467 | } 468 | 469 | // Empty values means "false". 470 | if len(values) == 0 { 471 | return "0 = 1" 472 | } 473 | 474 | return c.Var(condBuilder{ 475 | Builder: func(ctx *argsCompileContext) { 476 | ctx.WriteString(field) 477 | ctx.WriteString(" ") 478 | ctx.WriteString(op) 479 | ctx.WriteString(" ANY (") 480 | ctx.WriteValues(values, ", ") 481 | ctx.WriteString(")") 482 | }, 483 | }) 484 | } 485 | 486 | // All is used to construct the expression "field op ALL (value...)". 487 | func (c *Cond) All(field, op string, values ...interface{}) string { 488 | if len(field) == 0 || len(op) == 0 { 489 | return "" 490 | } 491 | 492 | // Empty values means "false". 
493 | if len(values) == 0 { 494 | return "0 = 1" 495 | } 496 | 497 | return c.Var(condBuilder{ 498 | Builder: func(ctx *argsCompileContext) { 499 | ctx.WriteString(field) 500 | ctx.WriteString(" ") 501 | ctx.WriteString(op) 502 | ctx.WriteString(" ALL (") 503 | ctx.WriteValues(values, ", ") 504 | ctx.WriteString(")") 505 | }, 506 | }) 507 | } 508 | 509 | // Some is used to construct the expression "field op SOME (value...)". 510 | func (c *Cond) Some(field, op string, values ...interface{}) string { 511 | if len(field) == 0 || len(op) == 0 { 512 | return "" 513 | } 514 | 515 | // Empty values means "false". 516 | if len(values) == 0 { 517 | return "0 = 1" 518 | } 519 | 520 | return c.Var(condBuilder{ 521 | Builder: func(ctx *argsCompileContext) { 522 | ctx.WriteString(field) 523 | ctx.WriteString(" ") 524 | ctx.WriteString(op) 525 | ctx.WriteString(" SOME (") 526 | ctx.WriteValues(values, ", ") 527 | ctx.WriteString(")") 528 | }, 529 | }) 530 | } 531 | 532 | // IsDistinctFrom is used to construct the expression "field IS DISTINCT FROM value". 533 | // 534 | // When the database system does not support the IS DISTINCT FROM operator, 535 | // the NotILike method will return "NOT field <=> value" for MySQL or a 536 | // "CASE ... WHEN ... ELSE ... END" expression to simulate the behavior of 537 | // the IS DISTINCT FROM operator. 
538 | func (c *Cond) IsDistinctFrom(field string, value interface{}) string { 539 | if len(field) == 0 { 540 | return "" 541 | } 542 | 543 | return c.Var(condBuilder{ 544 | Builder: func(ctx *argsCompileContext) { 545 | switch ctx.Flavor { 546 | case PostgreSQL, SQLite, SQLServer: 547 | ctx.WriteString(field) 548 | ctx.WriteString(" IS DISTINCT FROM ") 549 | ctx.WriteValue(value) 550 | 551 | case MySQL: 552 | ctx.WriteString("NOT ") 553 | ctx.WriteString(field) 554 | ctx.WriteString(" <=> ") 555 | ctx.WriteValue(value) 556 | 557 | default: 558 | // CASE 559 | // WHEN field IS NULL AND value IS NULL THEN 0 560 | // WHEN field IS NOT NULL AND value IS NOT NULL AND field = value THEN 0 561 | // ELSE 1 562 | // END = 1 563 | ctx.WriteString("CASE WHEN ") 564 | ctx.WriteString(field) 565 | ctx.WriteString(" IS NULL AND ") 566 | ctx.WriteValue(value) 567 | ctx.WriteString(" IS NULL THEN 0 WHEN ") 568 | ctx.WriteString(field) 569 | ctx.WriteString(" IS NOT NULL AND ") 570 | ctx.WriteValue(value) 571 | ctx.WriteString(" IS NOT NULL AND ") 572 | ctx.WriteString(field) 573 | ctx.WriteString(" = ") 574 | ctx.WriteValue(value) 575 | ctx.WriteString(" THEN 0 ELSE 1 END = 1") 576 | } 577 | }, 578 | }) 579 | } 580 | 581 | // IsNotDistinctFrom is used to construct the expression "field IS NOT DISTINCT FROM value". 582 | // 583 | // When the database system does not support the IS NOT DISTINCT FROM operator, 584 | // the NotILike method will return "field <=> value" for MySQL or a 585 | // "CASE ... WHEN ... ELSE ... END" expression to simulate the behavior of 586 | // the IS NOT DISTINCT FROM operator. 
587 | func (c *Cond) IsNotDistinctFrom(field string, value interface{}) string { 588 | if len(field) == 0 { 589 | return "" 590 | } 591 | 592 | return c.Var(condBuilder{ 593 | Builder: func(ctx *argsCompileContext) { 594 | switch ctx.Flavor { 595 | case PostgreSQL, SQLite, SQLServer: 596 | ctx.WriteString(field) 597 | ctx.WriteString(" IS NOT DISTINCT FROM ") 598 | ctx.WriteValue(value) 599 | 600 | case MySQL: 601 | ctx.WriteString(field) 602 | ctx.WriteString(" <=> ") 603 | ctx.WriteValue(value) 604 | 605 | default: 606 | // CASE 607 | // WHEN field IS NULL AND value IS NULL THEN 1 608 | // WHEN field IS NOT NULL AND value IS NOT NULL AND field = value THEN 1 609 | // ELSE 0 610 | // END = 1 611 | ctx.WriteString("CASE WHEN ") 612 | ctx.WriteString(field) 613 | ctx.WriteString(" IS NULL AND ") 614 | ctx.WriteValue(value) 615 | ctx.WriteString(" IS NULL THEN 1 WHEN ") 616 | ctx.WriteString(field) 617 | ctx.WriteString(" IS NOT NULL AND ") 618 | ctx.WriteValue(value) 619 | ctx.WriteString(" IS NOT NULL AND ") 620 | ctx.WriteString(field) 621 | ctx.WriteString(" = ") 622 | ctx.WriteValue(value) 623 | ctx.WriteString(" THEN 1 ELSE 0 END = 1") 624 | } 625 | }, 626 | }) 627 | } 628 | 629 | // Var returns a placeholder for value. 630 | func (c *Cond) Var(value interface{}) string { 631 | return c.Args.Add(value) 632 | } 633 | 634 | type condBuilder struct { 635 | Builder func(ctx *argsCompileContext) 636 | } 637 | 638 | func estimateStringsBytes(strs []string) (n int) { 639 | for _, s := range strs { 640 | n += len(s) 641 | } 642 | 643 | return 644 | } 645 | -------------------------------------------------------------------------------- /cond_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2018 Huan Du. All rights reserved. 2 | // Licensed under the MIT license that can be found in the LICENSE file. 
3 | 4 | package sqlbuilder 5 | 6 | import ( 7 | "strings" 8 | "testing" 9 | 10 | "github.com/huandu/go-assert" 11 | ) 12 | 13 | type TestPair struct { 14 | Expected string 15 | Actual string 16 | } 17 | 18 | func newTestPair(expected string, fn func(c *Cond) string) *TestPair { 19 | cond := newCond() 20 | format := fn(cond) 21 | sql, _ := cond.Args.CompileWithFlavor(format, PostgreSQL) 22 | return &TestPair{ 23 | Expected: expected, 24 | Actual: sql, 25 | } 26 | } 27 | 28 | func TestCond(t *testing.T) { 29 | a := assert.New(t) 30 | cases := []*TestPair{ 31 | newTestPair("$a = $1", func(c *Cond) string { return c.Equal("$a", 123) }), 32 | newTestPair("$b = $1", func(c *Cond) string { return c.E("$b", 123) }), 33 | newTestPair("$c = $1", func(c *Cond) string { return c.EQ("$c", 123) }), 34 | newTestPair("$a <> $1", func(c *Cond) string { return c.NotEqual("$a", 123) }), 35 | newTestPair("$b <> $1", func(c *Cond) string { return c.NE("$b", 123) }), 36 | newTestPair("$c <> $1", func(c *Cond) string { return c.NEQ("$c", 123) }), 37 | newTestPair("$a > $1", func(c *Cond) string { return c.GreaterThan("$a", 123) }), 38 | newTestPair("$b > $1", func(c *Cond) string { return c.G("$b", 123) }), 39 | newTestPair("$c > $1", func(c *Cond) string { return c.GT("$c", 123) }), 40 | newTestPair("$a >= $1", func(c *Cond) string { return c.GreaterEqualThan("$a", 123) }), 41 | newTestPair("$b >= $1", func(c *Cond) string { return c.GE("$b", 123) }), 42 | newTestPair("$c >= $1", func(c *Cond) string { return c.GTE("$c", 123) }), 43 | newTestPair("$a < $1", func(c *Cond) string { return c.LessThan("$a", 123) }), 44 | newTestPair("$b < $1", func(c *Cond) string { return c.L("$b", 123) }), 45 | newTestPair("$c < $1", func(c *Cond) string { return c.LT("$c", 123) }), 46 | newTestPair("$a <= $1", func(c *Cond) string { return c.LessEqualThan("$a", 123) }), 47 | newTestPair("$b <= $1", func(c *Cond) string { return c.LE("$b", 123) }), 48 | newTestPair("$c <= $1", func(c *Cond) string { 
return c.LTE("$c", 123) }), 49 | newTestPair("$a IN ($1, $2, $3)", func(c *Cond) string { return c.In("$a", 1, 2, 3) }), 50 | newTestPair("0 = 1", func(c *Cond) string { return c.In("$a") }), 51 | newTestPair("$a NOT IN ($1, $2, $3)", func(c *Cond) string { return c.NotIn("$a", 1, 2, 3) }), 52 | newTestPair("$a LIKE $1", func(c *Cond) string { return c.Like("$a", "%Huan%") }), 53 | newTestPair("$a ILIKE $1", func(c *Cond) string { return c.ILike("$a", "%Huan%") }), 54 | newTestPair("$a NOT LIKE $1", func(c *Cond) string { return c.NotLike("$a", "%Huan%") }), 55 | newTestPair("$a NOT ILIKE $1", func(c *Cond) string { return c.NotILike("$a", "%Huan%") }), 56 | newTestPair("$a IS NULL", func(c *Cond) string { return c.IsNull("$a") }), 57 | newTestPair("$a IS NOT NULL", func(c *Cond) string { return c.IsNotNull("$a") }), 58 | newTestPair("$a BETWEEN $1 AND $2", func(c *Cond) string { return c.Between("$a", 123, 456) }), 59 | newTestPair("$a NOT BETWEEN $1 AND $2", func(c *Cond) string { return c.NotBetween("$a", 123, 456) }), 60 | newTestPair("NOT 1 = 1", func(c *Cond) string { return c.Not("1 = 1") }), 61 | newTestPair("EXISTS ($1)", func(c *Cond) string { return c.Exists(1) }), 62 | newTestPair("NOT EXISTS ($1)", func(c *Cond) string { return c.NotExists(1) }), 63 | newTestPair("$a > ANY ($1, $2)", func(c *Cond) string { return c.Any("$a", ">", 1, 2) }), 64 | newTestPair("0 = 1", func(c *Cond) string { return c.Any("$a", ">") }), 65 | newTestPair("$a < ALL ($1)", func(c *Cond) string { return c.All("$a", "<", 1) }), 66 | newTestPair("0 = 1", func(c *Cond) string { return c.All("$a", "<") }), 67 | newTestPair("$a > SOME ($1, $2, $3)", func(c *Cond) string { return c.Some("$a", ">", 1, 2, 3) }), 68 | newTestPair("0 = 1", func(c *Cond) string { return c.Some("$a", ">") }), 69 | newTestPair("$a IS DISTINCT FROM $1", func(c *Cond) string { return c.IsDistinctFrom("$a", 1) }), 70 | newTestPair("$a IS NOT DISTINCT FROM $1", func(c *Cond) string { return 
c.IsNotDistinctFrom("$a", 1) }), 71 | newTestPair("$1", func(c *Cond) string { return c.Var(123) }), 72 | } 73 | 74 | for _, f := range cases { 75 | a.Equal(f.Actual, f.Expected) 76 | } 77 | } 78 | 79 | func TestOrCond(t *testing.T) { 80 | a := assert.New(t) 81 | cases := []*TestPair{ 82 | newTestPair("(1 = 1 OR 2 = 2 OR 3 = 3)", func(c *Cond) string { return c.Or("1 = 1", "2 = 2", "3 = 3") }), 83 | 84 | newTestPair("(1 = 1 OR 2 = 2)", func(c *Cond) string { return c.Or("", "1 = 1", "2 = 2") }), 85 | newTestPair("(1 = 1 OR 2 = 2)", func(c *Cond) string { return c.Or("1 = 1", "2 = 2", "") }), 86 | newTestPair("(1 = 1 OR 2 = 2)", func(c *Cond) string { return c.Or("1 = 1", "", "2 = 2") }), 87 | 88 | newTestPair("(1 = 1)", func(c *Cond) string { return c.Or("1 = 1", "", "") }), 89 | newTestPair("(1 = 1)", func(c *Cond) string { return c.Or("", "1 = 1", "") }), 90 | newTestPair("(1 = 1)", func(c *Cond) string { return c.Or("", "", "1 = 1") }), 91 | newTestPair("(1 = 1)", func(c *Cond) string { return c.Or("1 = 1") }), 92 | 93 | {Expected: "", Actual: newCond().Or("")}, 94 | {Expected: "", Actual: newCond().Or()}, 95 | {Expected: "", Actual: newCond().Or("", "", "")}, 96 | } 97 | 98 | for _, f := range cases { 99 | a.Equal(f.Actual, f.Expected) 100 | } 101 | } 102 | 103 | func TestAndCond(t *testing.T) { 104 | a := assert.New(t) 105 | cases := []*TestPair{ 106 | newTestPair("(1 = 1 AND 2 = 2 AND 3 = 3)", func(c *Cond) string { return c.And("1 = 1", "2 = 2", "3 = 3") }), 107 | 108 | newTestPair("(1 = 1 AND 2 = 2)", func(c *Cond) string { return c.And("", "1 = 1", "2 = 2") }), 109 | newTestPair("(1 = 1 AND 2 = 2)", func(c *Cond) string { return c.And("1 = 1", "2 = 2", "") }), 110 | newTestPair("(1 = 1 AND 2 = 2)", func(c *Cond) string { return c.And("1 = 1", "", "2 = 2") }), 111 | 112 | newTestPair("(1 = 1)", func(c *Cond) string { return c.And("1 = 1", "", "") }), 113 | newTestPair("(1 = 1)", func(c *Cond) string { return c.And("", "1 = 1", "") }), 114 | newTestPair("(1 
= 1)", func(c *Cond) string { return c.And("", "", "1 = 1") }), 115 | newTestPair("(1 = 1)", func(c *Cond) string { return c.And("1 = 1") }), 116 | 117 | {Expected: "", Actual: newCond().And("")}, 118 | {Expected: "", Actual: newCond().And()}, 119 | {Expected: "", Actual: newCond().And("", "", "")}, 120 | } 121 | 122 | for _, f := range cases { 123 | a.Equal(f.Actual, f.Expected) 124 | } 125 | } 126 | 127 | func TestEmptyCond(t *testing.T) { 128 | a := assert.New(t) 129 | cases := []string{ 130 | newCond().Equal("", 123), 131 | newCond().NotEqual("", 123), 132 | newCond().GreaterThan("", 123), 133 | newCond().GreaterEqualThan("", 123), 134 | newCond().LessThan("", 123), 135 | newCond().LessEqualThan("", 123), 136 | newCond().In("", 1, 2, 3), 137 | newCond().NotIn("", 1, 2, 3), 138 | newCond().NotIn("a"), 139 | newCond().Like("", "%Huan%"), 140 | newCond().ILike("", "%Huan%"), 141 | newCond().NotLike("", "%Huan%"), 142 | newCond().NotILike("", "%Huan%"), 143 | newCond().IsNull(""), 144 | newCond().IsNotNull(""), 145 | newCond().Between("", 123, 456), 146 | newCond().NotBetween("", 123, 456), 147 | newCond().Not(""), 148 | 149 | newCond().Any("", "", 1, 2), 150 | newCond().Any("", ">", 1, 2), 151 | newCond().Any("$a", "", 1, 2), 152 | 153 | newCond().All("", "", 1), 154 | newCond().All("", ">", 1), 155 | newCond().All("$a", "", 1), 156 | 157 | newCond().Some("", "", 1, 2, 3), 158 | newCond().Some("", ">", 1, 2, 3), 159 | newCond().Some("$a", "", 1, 2, 3), 160 | 161 | newCond().IsDistinctFrom("", 1), 162 | newCond().IsNotDistinctFrom("", 1), 163 | } 164 | 165 | expected := "" 166 | for _, actual := range cases { 167 | a.Equal(actual, expected) 168 | } 169 | } 170 | 171 | func TestCondWithFlavor(t *testing.T) { 172 | a := assert.New(t) 173 | cond := &Cond{ 174 | Args: &Args{}, 175 | } 176 | format := strings.Join([]string{ 177 | cond.ILike("f1", 1), 178 | cond.NotILike("f2", 2), 179 | cond.IsDistinctFrom("f3", 3), 180 | cond.IsNotDistinctFrom("f4", 4), 181 | }, "\n") 
182 | expectedResults := map[Flavor]string{ 183 | PostgreSQL: `f1 ILIKE $1 184 | f2 NOT ILIKE $2 185 | f3 IS DISTINCT FROM $3 186 | f4 IS NOT DISTINCT FROM $4`, 187 | MySQL: `LOWER(f1) LIKE LOWER(?) 188 | LOWER(f2) NOT LIKE LOWER(?) 189 | NOT f3 <=> ? 190 | f4 <=> ?`, 191 | SQLite: `f1 ILIKE ? 192 | f2 NOT ILIKE ? 193 | f3 IS DISTINCT FROM ? 194 | f4 IS NOT DISTINCT FROM ?`, 195 | Presto: `LOWER(f1) LIKE LOWER(?) 196 | LOWER(f2) NOT LIKE LOWER(?) 197 | CASE WHEN f3 IS NULL AND ? IS NULL THEN 0 WHEN f3 IS NOT NULL AND ? IS NOT NULL AND f3 = ? THEN 0 ELSE 1 END = 1 198 | CASE WHEN f4 IS NULL AND ? IS NULL THEN 1 WHEN f4 IS NOT NULL AND ? IS NOT NULL AND f4 = ? THEN 1 ELSE 0 END = 1`, 199 | } 200 | 201 | for flavor, expected := range expectedResults { 202 | actual, _ := cond.Args.CompileWithFlavor(format, flavor) 203 | a.Equal(actual, expected) 204 | } 205 | } 206 | 207 | func TestCondExpr(t *testing.T) { 208 | a := assert.New(t) 209 | cond := &Cond{ 210 | Args: &Args{}, 211 | } 212 | sb1 := Select("1 = 1") 213 | sb2 := Select("FALSE") 214 | formats := []string{ 215 | cond.And(), 216 | cond.Or(), 217 | cond.And(cond.Var(sb1), cond.Var(sb2)), 218 | cond.Or(cond.Var(sb1), cond.Var(sb2)), 219 | cond.Not(cond.Or(cond.Var(sb1), cond.And(cond.Var(sb1), cond.Var(sb2)))), 220 | } 221 | expectResults := []string{ 222 | "", 223 | "", 224 | "(SELECT 1 = 1 AND SELECT FALSE)", 225 | "(SELECT 1 = 1 OR SELECT FALSE)", 226 | "NOT (SELECT 1 = 1 OR (SELECT 1 = 1 AND SELECT FALSE))", 227 | } 228 | 229 | for i, expected := range expectResults { 230 | actual, values := cond.Args.Compile(formats[i]) 231 | a.Equal(len(values), 0) 232 | a.Equal(actual, expected) 233 | } 234 | } 235 | 236 | func TestCondMisuse(t *testing.T) { 237 | a := assert.New(t) 238 | 239 | cond := NewCond() 240 | sb := Select("*"). 241 | From("t1"). 
242 | Where(cond.Equal("a", 123)) 243 | sql, args := sb.Build() 244 | 245 | a.Equal(sql, "SELECT * FROM t1 WHERE /* INVALID ARG $256 */") 246 | a.Equal(args, nil) 247 | } 248 | 249 | func newCond() *Cond { 250 | args := &Args{} 251 | return &Cond{ 252 | Args: args, 253 | } 254 | } 255 | -------------------------------------------------------------------------------- /createtable.go: -------------------------------------------------------------------------------- 1 | // Copyright 2018 Huan Du. All rights reserved. 2 | // Licensed under the MIT license that can be found in the LICENSE file. 3 | 4 | package sqlbuilder 5 | 6 | import ( 7 | "strings" 8 | ) 9 | 10 | const ( 11 | createTableMarkerInit injectionMarker = iota 12 | createTableMarkerAfterCreate 13 | createTableMarkerAfterDefine 14 | createTableMarkerAfterOption 15 | ) 16 | 17 | // NewCreateTableBuilder creates a new CREATE TABLE builder. 18 | func NewCreateTableBuilder() *CreateTableBuilder { 19 | return DefaultFlavor.NewCreateTableBuilder() 20 | } 21 | 22 | func newCreateTableBuilder() *CreateTableBuilder { 23 | args := &Args{} 24 | return &CreateTableBuilder{ 25 | verb: "CREATE TABLE", 26 | args: args, 27 | injection: newInjection(), 28 | marker: createTableMarkerInit, 29 | } 30 | } 31 | 32 | // CreateTableBuilder is a builder to build CREATE TABLE. 33 | type CreateTableBuilder struct { 34 | verb string 35 | ifNotExists bool 36 | table string 37 | defs [][]string 38 | options [][]string 39 | 40 | args *Args 41 | 42 | injection *injection 43 | marker injectionMarker 44 | } 45 | 46 | var _ Builder = new(CreateTableBuilder) 47 | 48 | // CreateTable sets the table name in CREATE TABLE. 49 | func CreateTable(table string) *CreateTableBuilder { 50 | return DefaultFlavor.NewCreateTableBuilder().CreateTable(table) 51 | } 52 | 53 | // CreateTable sets the table name in CREATE TABLE. 
54 | func (ctb *CreateTableBuilder) CreateTable(table string) *CreateTableBuilder { 55 | ctb.table = Escape(table) 56 | ctb.marker = createTableMarkerAfterCreate 57 | return ctb 58 | } 59 | 60 | // CreateTempTable sets the table name and changes the verb of ctb to CREATE TEMPORARY TABLE. 61 | func (ctb *CreateTableBuilder) CreateTempTable(table string) *CreateTableBuilder { 62 | ctb.verb = "CREATE TEMPORARY TABLE" 63 | ctb.table = Escape(table) 64 | ctb.marker = createTableMarkerAfterCreate 65 | return ctb 66 | } 67 | 68 | // IfNotExists adds IF NOT EXISTS before table name in CREATE TABLE. 69 | func (ctb *CreateTableBuilder) IfNotExists() *CreateTableBuilder { 70 | ctb.ifNotExists = true 71 | return ctb 72 | } 73 | 74 | // Define adds definition of a column or index in CREATE TABLE. 75 | func (ctb *CreateTableBuilder) Define(def ...string) *CreateTableBuilder { 76 | ctb.defs = append(ctb.defs, def) 77 | ctb.marker = createTableMarkerAfterDefine 78 | return ctb 79 | } 80 | 81 | // Option adds a table option in CREATE TABLE. 82 | func (ctb *CreateTableBuilder) Option(opt ...string) *CreateTableBuilder { 83 | ctb.options = append(ctb.options, opt) 84 | ctb.marker = createTableMarkerAfterOption 85 | return ctb 86 | } 87 | 88 | // NumDefine returns the number of definitions in CREATE TABLE. 89 | func (ctb *CreateTableBuilder) NumDefine() int { 90 | return len(ctb.defs) 91 | } 92 | 93 | // String returns the compiled INSERT string. 94 | func (ctb *CreateTableBuilder) String() string { 95 | s, _ := ctb.Build() 96 | return s 97 | } 98 | 99 | // Build returns compiled CREATE TABLE string and args. 100 | // They can be used in `DB#Query` of package `database/sql` directly. 101 | func (ctb *CreateTableBuilder) Build() (sql string, args []interface{}) { 102 | return ctb.BuildWithFlavor(ctb.args.Flavor) 103 | } 104 | 105 | // BuildWithFlavor returns compiled CREATE TABLE string and args with flavor and initial args. 
106 | // They can be used in `DB#Query` of package `database/sql` directly. 107 | func (ctb *CreateTableBuilder) BuildWithFlavor(flavor Flavor, initialArg ...interface{}) (sql string, args []interface{}) { 108 | buf := newStringBuilder() 109 | ctb.injection.WriteTo(buf, createTableMarkerInit) 110 | 111 | if len(ctb.verb) > 0 { 112 | buf.WriteLeadingString(ctb.verb) 113 | } 114 | 115 | if ctb.ifNotExists { 116 | buf.WriteLeadingString("IF NOT EXISTS") 117 | } 118 | 119 | if len(ctb.table) > 0 { 120 | buf.WriteLeadingString(ctb.table) 121 | } 122 | 123 | ctb.injection.WriteTo(buf, createTableMarkerAfterCreate) 124 | 125 | if len(ctb.defs) > 0 { 126 | buf.WriteLeadingString("(") 127 | 128 | defs := make([]string, 0, len(ctb.defs)) 129 | 130 | for _, def := range ctb.defs { 131 | defs = append(defs, strings.Join(def, " ")) 132 | } 133 | 134 | buf.WriteStrings(defs, ", ") 135 | buf.WriteRune(')') 136 | 137 | ctb.injection.WriteTo(buf, createTableMarkerAfterDefine) 138 | } 139 | 140 | if len(ctb.options) > 0 { 141 | opts := make([]string, 0, len(ctb.options)) 142 | 143 | for _, opt := range ctb.options { 144 | opts = append(opts, strings.Join(opt, " ")) 145 | } 146 | 147 | buf.WriteLeadingString(strings.Join(opts, ", ")) 148 | ctb.injection.WriteTo(buf, createTableMarkerAfterOption) 149 | } 150 | 151 | return ctb.args.CompileWithFlavor(buf.String(), flavor, initialArg...) 152 | } 153 | 154 | // SetFlavor sets the flavor of compiled sql. 155 | func (ctb *CreateTableBuilder) SetFlavor(flavor Flavor) (old Flavor) { 156 | old = ctb.args.Flavor 157 | ctb.args.Flavor = flavor 158 | return 159 | } 160 | 161 | // Flavor returns flavor of builder 162 | func (ctb *CreateTableBuilder) Flavor() Flavor { 163 | return ctb.args.Flavor 164 | } 165 | 166 | // Var returns a placeholder for value. 167 | func (ctb *CreateTableBuilder) Var(arg interface{}) string { 168 | return ctb.args.Add(arg) 169 | } 170 | 171 | // SQL adds an arbitrary sql to current position. 
172 | func (ctb *CreateTableBuilder) SQL(sql string) *CreateTableBuilder { 173 | ctb.injection.SQL(ctb.marker, sql) 174 | return ctb 175 | } 176 | -------------------------------------------------------------------------------- /createtable_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2018 Huan Du. All rights reserved. 2 | // Licensed under the MIT license that can be found in the LICENSE file. 3 | 4 | package sqlbuilder 5 | 6 | import ( 7 | "fmt" 8 | "testing" 9 | 10 | "github.com/huandu/go-assert" 11 | ) 12 | 13 | func ExampleCreateTable() { 14 | sql := CreateTable("demo.user").IfNotExists(). 15 | Define("id", "BIGINT(20)", "NOT NULL", "AUTO_INCREMENT", "PRIMARY KEY", `COMMENT "user id"`). 16 | String() 17 | 18 | fmt.Println(sql) 19 | 20 | // Output: 21 | // CREATE TABLE IF NOT EXISTS demo.user (id BIGINT(20) NOT NULL AUTO_INCREMENT PRIMARY KEY COMMENT "user id") 22 | } 23 | 24 | func ExampleCreateTableBuilder() { 25 | ctb := NewCreateTableBuilder() 26 | ctb.CreateTable("demo.user").IfNotExists() 27 | ctb.Define("id", "BIGINT(20)", "NOT NULL", "AUTO_INCREMENT", "PRIMARY KEY", `COMMENT "user id"`) 28 | ctb.Define("name", "VARCHAR(255)", "NOT NULL", `COMMENT "user name"`) 29 | ctb.Define("created_at", "DATETIME", "NOT NULL", `COMMENT "user create time"`) 30 | ctb.Define("modified_at", "DATETIME", "NOT NULL", `COMMENT "user modify time"`) 31 | ctb.Define("KEY", "idx_name_modified_at", "name, modified_at") 32 | ctb.Option("DEFAULT CHARACTER SET", "utf8mb4") 33 | 34 | fmt.Println(ctb) 35 | 36 | // Output: 37 | // CREATE TABLE IF NOT EXISTS demo.user (id BIGINT(20) NOT NULL AUTO_INCREMENT PRIMARY KEY COMMENT "user id", name VARCHAR(255) NOT NULL COMMENT "user name", created_at DATETIME NOT NULL COMMENT "user create time", modified_at DATETIME NOT NULL COMMENT "user modify time", KEY idx_name_modified_at name, modified_at) DEFAULT CHARACTER SET utf8mb4 38 | } 39 | 40 | func 
ExampleCreateTableBuilder_tempTable() { 41 | ctb := NewCreateTableBuilder() 42 | ctb.CreateTempTable("demo.user").IfNotExists() 43 | ctb.Define("id", "BIGINT(20)", "NOT NULL", "AUTO_INCREMENT", "PRIMARY KEY", `COMMENT "user id"`) 44 | ctb.Define("name", "VARCHAR(255)", "NOT NULL", `COMMENT "user name"`) 45 | ctb.Define("created_at", "DATETIME", "NOT NULL", `COMMENT "user create time"`) 46 | ctb.Define("modified_at", "DATETIME", "NOT NULL", `COMMENT "user modify time"`) 47 | ctb.Define("KEY", "idx_name_modified_at", "name, modified_at") 48 | ctb.Option("DEFAULT CHARACTER SET", "utf8mb4") 49 | 50 | fmt.Println(ctb) 51 | 52 | // Output: 53 | // CREATE TEMPORARY TABLE IF NOT EXISTS demo.user (id BIGINT(20) NOT NULL AUTO_INCREMENT PRIMARY KEY COMMENT "user id", name VARCHAR(255) NOT NULL COMMENT "user name", created_at DATETIME NOT NULL COMMENT "user create time", modified_at DATETIME NOT NULL COMMENT "user modify time", KEY idx_name_modified_at name, modified_at) DEFAULT CHARACTER SET utf8mb4 54 | } 55 | 56 | func ExampleCreateTableBuilder_SQL() { 57 | ctb := NewCreateTableBuilder() 58 | ctb.SQL(`/* before */`) 59 | ctb.CreateTempTable("demo.user").IfNotExists() 60 | ctb.SQL("/* after create */") 61 | ctb.Define("id", "BIGINT(20)", "NOT NULL", "AUTO_INCREMENT", "PRIMARY KEY", `COMMENT "user id"`) 62 | ctb.Define("name", "VARCHAR(255)", "NOT NULL", `COMMENT "user name"`) 63 | ctb.SQL("/* after define */") 64 | ctb.Option("DEFAULT CHARACTER SET", "utf8mb4") 65 | ctb.SQL(ctb.Var(Build("AS SELECT * FROM old.user WHERE name LIKE $?", "%Huan%"))) 66 | 67 | sql, args := ctb.Build() 68 | fmt.Println(sql) 69 | fmt.Println(args) 70 | 71 | // Output: 72 | // /* before */ CREATE TEMPORARY TABLE IF NOT EXISTS demo.user /* after create */ (id BIGINT(20) NOT NULL AUTO_INCREMENT PRIMARY KEY COMMENT "user id", name VARCHAR(255) NOT NULL COMMENT "user name") /* after define */ DEFAULT CHARACTER SET utf8mb4 AS SELECT * FROM old.user WHERE name LIKE ? 
73 | // [%Huan%] 74 | } 75 | 76 | func ExampleCreateTableBuilder_NumDefine() { 77 | ctb := NewCreateTableBuilder() 78 | ctb.CreateTable("demo.user").IfNotExists() 79 | ctb.Define("id", "BIGINT(20)", "NOT NULL", "AUTO_INCREMENT", "PRIMARY KEY", `COMMENT "user id"`) 80 | ctb.Define("name", "VARCHAR(255)", "NOT NULL", `COMMENT "user name"`) 81 | ctb.Define("created_at", "DATETIME", "NOT NULL", `COMMENT "user create time"`) 82 | ctb.Define("modified_at", "DATETIME", "NOT NULL", `COMMENT "user modify time"`) 83 | ctb.Define("KEY", "idx_name_modified_at", "name, modified_at") 84 | ctb.Option("DEFAULT CHARACTER SET", "utf8mb4") 85 | 86 | // Count the number of definitions. 87 | fmt.Println(ctb.NumDefine()) 88 | 89 | // Output: 90 | // 5 91 | } 92 | 93 | func TestCreateTableGetFlavor(t *testing.T) { 94 | a := assert.New(t) 95 | ctb := newCreateTableBuilder() 96 | 97 | ctb.SetFlavor(PostgreSQL) 98 | flavor := ctb.Flavor() 99 | a.Equal(PostgreSQL, flavor) 100 | 101 | ctbClick := ClickHouse.NewCreateTableBuilder() 102 | flavor = ctbClick.Flavor() 103 | a.Equal(ClickHouse, flavor) 104 | } 105 | -------------------------------------------------------------------------------- /cte.go: -------------------------------------------------------------------------------- 1 | // Copyright 2024 Huan Du. All rights reserved. 2 | // Licensed under the MIT license that can be found in the LICENSE file. 3 | 4 | package sqlbuilder 5 | 6 | const ( 7 | cteMarkerInit injectionMarker = iota 8 | cteMarkerAfterWith 9 | ) 10 | 11 | // With creates a new CTE builder with default flavor. 12 | func With(tables ...*CTEQueryBuilder) *CTEBuilder { 13 | return DefaultFlavor.NewCTEBuilder().With(tables...) 14 | } 15 | 16 | // WithRecursive creates a new recursive CTE builder with default flavor. 17 | func WithRecursive(tables ...*CTEQueryBuilder) *CTEBuilder { 18 | return DefaultFlavor.NewCTEBuilder().WithRecursive(tables...) 
19 | } 20 | 21 | func newCTEBuilder() *CTEBuilder { 22 | return &CTEBuilder{ 23 | args: &Args{}, 24 | injection: newInjection(), 25 | } 26 | } 27 | 28 | // CTEBuilder is a CTE (Common Table Expression) builder. 29 | type CTEBuilder struct { 30 | recursive bool 31 | queries []*CTEQueryBuilder 32 | queryBuilderVars []string 33 | 34 | args *Args 35 | 36 | injection *injection 37 | marker injectionMarker 38 | } 39 | 40 | var _ Builder = new(CTEBuilder) 41 | 42 | // With sets the CTE name and columns. 43 | func (cteb *CTEBuilder) With(queries ...*CTEQueryBuilder) *CTEBuilder { 44 | queryBuilderVars := make([]string, 0, len(queries)) 45 | 46 | for _, query := range queries { 47 | queryBuilderVars = append(queryBuilderVars, cteb.args.Add(query)) 48 | } 49 | 50 | cteb.queries = queries 51 | cteb.queryBuilderVars = queryBuilderVars 52 | cteb.marker = cteMarkerAfterWith 53 | return cteb 54 | } 55 | 56 | // WithRecursive sets the CTE name and columns and turns on the RECURSIVE keyword. 57 | func (cteb *CTEBuilder) WithRecursive(queries ...*CTEQueryBuilder) *CTEBuilder { 58 | cteb.With(queries...).recursive = true 59 | return cteb 60 | } 61 | 62 | // Select creates a new SelectBuilder to build a SELECT statement using this CTE. 63 | func (cteb *CTEBuilder) Select(col ...string) *SelectBuilder { 64 | sb := cteb.args.Flavor.NewSelectBuilder() 65 | return sb.With(cteb).Select(col...) 66 | } 67 | 68 | // DeleteFrom creates a new DeleteBuilder to build a DELETE statement using this CTE. 69 | func (cteb *CTEBuilder) DeleteFrom(table string) *DeleteBuilder { 70 | db := cteb.args.Flavor.NewDeleteBuilder() 71 | return db.With(cteb).DeleteFrom(table) 72 | } 73 | 74 | // Update creates a new UpdateBuilder to build an UPDATE statement using this CTE. 75 | func (cteb *CTEBuilder) Update(table string) *UpdateBuilder { 76 | ub := cteb.args.Flavor.NewUpdateBuilder() 77 | return ub.With(cteb).Update(table) 78 | } 79 | 80 | // String returns the compiled CTE string. 
81 | func (cteb *CTEBuilder) String() string { 82 | sql, _ := cteb.Build() 83 | return sql 84 | } 85 | 86 | // Build returns compiled CTE string and args. 87 | func (cteb *CTEBuilder) Build() (sql string, args []interface{}) { 88 | return cteb.BuildWithFlavor(cteb.args.Flavor) 89 | } 90 | 91 | // BuildWithFlavor builds a CTE with the specified flavor and initial arguments. 92 | func (cteb *CTEBuilder) BuildWithFlavor(flavor Flavor, initialArg ...interface{}) (sql string, args []interface{}) { 93 | buf := newStringBuilder() 94 | cteb.injection.WriteTo(buf, cteMarkerInit) 95 | 96 | if len(cteb.queryBuilderVars) > 0 { 97 | buf.WriteLeadingString("WITH ") 98 | if cteb.recursive { 99 | buf.WriteString("RECURSIVE ") 100 | } 101 | buf.WriteStrings(cteb.queryBuilderVars, ", ") 102 | } 103 | 104 | cteb.injection.WriteTo(buf, cteMarkerAfterWith) 105 | return cteb.args.CompileWithFlavor(buf.String(), flavor, initialArg...) 106 | } 107 | 108 | // SetFlavor sets the flavor of compiled sql. 109 | func (cteb *CTEBuilder) SetFlavor(flavor Flavor) (old Flavor) { 110 | old = cteb.args.Flavor 111 | cteb.args.Flavor = flavor 112 | return 113 | } 114 | 115 | // Flavor returns flavor of builder 116 | func (cteb *CTEBuilder) Flavor() Flavor { 117 | return cteb.args.Flavor 118 | } 119 | 120 | // SQL adds an arbitrary sql to current position. 121 | func (cteb *CTEBuilder) SQL(sql string) *CTEBuilder { 122 | cteb.injection.SQL(cteb.marker, sql) 123 | return cteb 124 | } 125 | 126 | // TableNames returns all table names in a CTE. 127 | func (cteb *CTEBuilder) TableNames() []string { 128 | if len(cteb.queryBuilderVars) == 0 { 129 | return nil 130 | } 131 | 132 | tableNames := make([]string, 0, len(cteb.queries)) 133 | 134 | for _, query := range cteb.queries { 135 | tableNames = append(tableNames, query.TableName()) 136 | } 137 | 138 | return tableNames 139 | } 140 | 141 | // tableNamesForFrom returns a list of table names which should be automatically added to FROM clause. 
142 | // It's not public, as this feature is designed only for SelectBuilder/UpdateBuilder/DeleteBuilder right now. 143 | func (cteb *CTEBuilder) tableNamesForFrom() []string { 144 | cnt := 0 145 | 146 | // ShouldAddToTableList() unlikely returns true. 147 | // Count it before allocating any memory for better performance. 148 | for _, query := range cteb.queries { 149 | if query.ShouldAddToTableList() { 150 | cnt++ 151 | } 152 | } 153 | 154 | if cnt == 0 { 155 | return nil 156 | } 157 | 158 | tableNames := make([]string, 0, cnt) 159 | 160 | for _, query := range cteb.queries { 161 | if query.ShouldAddToTableList() { 162 | tableNames = append(tableNames, query.TableName()) 163 | } 164 | } 165 | 166 | return tableNames 167 | } 168 | -------------------------------------------------------------------------------- /cte_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2024 Huan Du. All rights reserved. 2 | // Licensed under the MIT license that can be found in the LICENSE file. 

package sqlbuilder

import (
	"fmt"
	"testing"

	"github.com/huandu/go-assert"
)

// ExampleWith shows that tables declared with CTETable are appended to the FROM
// clause of the SELECT built from the CTE.
func ExampleWith() {
	sb := With(
		CTETable("users", "id", "name").As(
			Select("id", "name").From("users").Where("name IS NOT NULL"),
		),
		CTETable("devices").As(
			Select("device_id").From("devices"),
		),
	).Select("users.id", "orders.id", "devices.device_id").Join(
		"orders",
		"users.id = orders.user_id",
		"devices.device_id = orders.device_id",
	)

	fmt.Println(sb)

	// Output:
	// WITH users (id, name) AS (SELECT id, name FROM users WHERE name IS NOT NULL), devices AS (SELECT device_id FROM devices) SELECT users.id, orders.id, devices.device_id FROM users, devices JOIN orders ON users.id = orders.user_id AND devices.device_id = orders.device_id
}

// ExampleWithRecursive builds a WITH RECURSIVE query. A CTEQuery (unlike a CTETable)
// is not added to FROM, so the CTE is referenced through an explicit JOIN.
func ExampleWithRecursive() {
	sb := WithRecursive(
		CTEQuery("source_accounts", "id", "parent_id").As(
			UnionAll(
				Select("p.id", "p.parent_id").
					From("accounts AS p").
					Where("p.id = 2"), // Show orders for account 2 and all its child accounts
				Select("c.id", "c.parent_id").
					From("accounts AS c").
					Join("source_accounts AS sa", "c.parent_id = sa.id"),
			),
		),
	).Select("o.id", "o.date", "o.amount").
		From("orders AS o").
		Join("source_accounts", "o.account_id = source_accounts.id")

	fmt.Println(sb)

	// Output:
	// WITH RECURSIVE source_accounts (id, parent_id) AS ((SELECT p.id, p.parent_id FROM accounts AS p WHERE p.id = 2) UNION ALL (SELECT c.id, c.parent_id FROM accounts AS c JOIN source_accounts AS sa ON c.parent_id = sa.id)) SELECT o.id, o.date, o.amount FROM orders AS o JOIN source_accounts ON o.account_id = source_accounts.id
}

// ExampleCTEBuilder attaches a standalone CTEBuilder to a SelectBuilder via With.
// The printed args show the CTE query's arg (10) before the outer query's arg (200).
func ExampleCTEBuilder() {
	usersBuilder := Select("id", "name", "level").From("users")
	usersBuilder.Where(
		usersBuilder.GreaterEqualThan("level", 10),
	)
	cteb := With(
		CTETable("valid_users").As(usersBuilder),
	)
	fmt.Println(cteb)

	sb := Select("valid_users.id", "valid_users.name", "orders.id").
		From("users").With(cteb).
		Join("orders", "users.id = orders.user_id")
	sb.Where(
		sb.LessEqualThan("orders.price", 200),
		"valid_users.level < orders.min_level",
	).OrderBy("orders.price").Desc()

	sql, args := sb.Build()
	fmt.Println(sql)
	fmt.Println(args)
	fmt.Println(sb.TableNames())

	// Output:
	// WITH valid_users AS (SELECT id, name, level FROM users WHERE level >= ?)
	// WITH valid_users AS (SELECT id, name, level FROM users WHERE level >= ?) SELECT valid_users.id, valid_users.name, orders.id FROM users, valid_users JOIN orders ON users.id = orders.user_id WHERE orders.price <= ? AND valid_users.level < orders.min_level ORDER BY orders.price DESC
	// [10 200]
	// [users valid_users]
}

// ExampleCTEBuilder_update shows the same UPDATE rendered per flavor. MySQL lists
// the CTE table after UPDATE; PostgreSQL puts it in a FROM clause.
func ExampleCTEBuilder_update() {
	builder := With(
		CTETable("users", "user_id").As(
			Select("user_id").From("vip_users"),
		),
	).Update("orders").Set(
		"orders.transport_fee = 0",
	).Where(
		"users.user_id = orders.user_id",
	)

	sqlForMySQL, _ := builder.BuildWithFlavor(MySQL)
	sqlForPostgreSQL, _ := builder.BuildWithFlavor(PostgreSQL)

	fmt.Println(sqlForMySQL)
	fmt.Println(sqlForPostgreSQL)

	// Output:
	// WITH users (user_id) AS (SELECT user_id FROM vip_users) UPDATE orders, users SET orders.transport_fee = 0 WHERE users.user_id = orders.user_id
	// WITH users (user_id) AS (SELECT user_id FROM vip_users) UPDATE orders SET orders.transport_fee = 0 FROM users WHERE users.user_id = orders.user_id
}

// ExampleCTEBuilder_delete shows a CTETable appended to the DELETE FROM table list.
func ExampleCTEBuilder_delete() {
	sql := With(
		CTETable("users", "user_id").As(
			Select("user_id").From("cheaters"),
		),
	).DeleteFrom("awards").Where(
		"users.user_id = awards.user_id",
	).String()

	fmt.Println(sql)

	// Output:
	// WITH users (user_id) AS (SELECT user_id FROM cheaters) DELETE FROM awards, users WHERE users.user_id = awards.user_id
}

// TestCTEBuilder checks where SQL() injections land relative to WITH and to each
// part of a CTE query. Statement order matters: each SQL() call uses the marker
// left by the preceding builder call.
func TestCTEBuilder(t *testing.T) {
	a := assert.New(t)
	cteb := newCTEBuilder()
	ctetb := newCTEQueryBuilder()
	cteb.SQL("/* init */")
	cteb.With(ctetb)
	cteb.SQL("/* after with */")

	ctetb.SQL("/* table init */")
	ctetb.Table("t", "a", "b")
	ctetb.SQL("/* after table */")

	ctetb.As(Select("a", "b").From("t"))
	ctetb.SQL("/* after table as */")

	a.Equal(cteb.TableNames(), []string{ctetb.TableName()})

	sql, args := cteb.Build()
	a.Equal(sql, "/* init */ WITH /* table init */ t (a, b) /* after table */ AS (SELECT a, b FROM t) /* after table as */ /* after with */")
	a.Assert(args == nil)

	sql = ctetb.String()
	a.Equal(sql, "/* table init */ t (a, b) /* after table */ AS (SELECT a, b FROM t) /* after table as */")
}

// TestRecursiveCTEBuilder mirrors TestCTEBuilder with the recursive flag set
// directly, expecting "WITH RECURSIVE" in the output.
func TestRecursiveCTEBuilder(t *testing.T) {
	a := assert.New(t)
	cteb := newCTEBuilder()
	cteb.recursive = true
	ctetb := newCTEQueryBuilder()
	cteb.SQL("/* init */")
	cteb.With(ctetb)
	cteb.SQL("/* after with */")

	ctetb.SQL("/* table init */")
	ctetb.Table("t", "a", "b")
	ctetb.SQL("/* after table */")

	ctetb.As(Select("a", "b").From("t"))
	ctetb.SQL("/* after table as */")

	sql, args := cteb.Build()
	a.Equal(sql, "/* init */ WITH RECURSIVE /* table init */ t (a, b) /* after table */ AS (SELECT a, b FROM t) /* after table as */ /* after with */")
	a.Assert(args == nil)

	sql = ctetb.String()
	a.Equal(sql, "/* table init */ t (a, b) /* after table */ AS (SELECT a, b FROM t) /* after table as */")
}

// TestCTEGetFlavor checks that Flavor reflects both SetFlavor and the flavor
// used to create the CTEBuilder.
func TestCTEGetFlavor(t *testing.T) {
	a := assert.New(t)
	cteb := newCTEBuilder()

	cteb.SetFlavor(PostgreSQL)
	flavor := cteb.Flavor()
	a.Equal(PostgreSQL, flavor)

	ctebClick := ClickHouse.NewCTEBuilder()
	flavor = ctebClick.Flavor()
	a.Equal(ClickHouse, flavor)
}

// TestCTEQueryBuilderGetFlavor does the same for CTEQueryBuilder.
func TestCTEQueryBuilderGetFlavor(t *testing.T) {
	a := assert.New(t)
	ctetb := newCTEQueryBuilder()

	ctetb.SetFlavor(PostgreSQL)
	flavor := ctetb.Flavor()
	a.Equal(PostgreSQL, flavor)

	ctetbClick := ClickHouse.NewCTEQueryBuilder()
	flavor = ctetbClick.Flavor()
	a.Equal(ClickHouse, flavor)
}
--------------------------------------------------------------------------------
/ctequery.go:
--------------------------------------------------------------------------------
// Copyright 2024 Huan Du. All rights reserved.
2 | // Licensed under the MIT license that can be found in the LICENSE file. 3 | 4 | package sqlbuilder 5 | 6 | const ( 7 | cteQueryMarkerInit injectionMarker = iota 8 | cteQueryMarkerAfterTable 9 | cteQueryMarkerAfterAs 10 | ) 11 | 12 | // CTETable creates a new CTE query builder with default flavor, marking it as a table. 13 | // 14 | // The resulting CTE query can be used in a `SelectBuilder“, where its table name will be 15 | // automatically included in the FROM clause. 16 | func CTETable(name string, cols ...string) *CTEQueryBuilder { 17 | return DefaultFlavor.NewCTEQueryBuilder().AddToTableList().Table(name, cols...) 18 | } 19 | 20 | // CTEQuery creates a new CTE query builder with default flavor. 21 | func CTEQuery(name string, cols ...string) *CTEQueryBuilder { 22 | return DefaultFlavor.NewCTEQueryBuilder().Table(name, cols...) 23 | } 24 | 25 | func newCTEQueryBuilder() *CTEQueryBuilder { 26 | return &CTEQueryBuilder{ 27 | args: &Args{}, 28 | injection: newInjection(), 29 | } 30 | } 31 | 32 | // CTEQueryBuilder is a builder to build one table in CTE (Common Table Expression). 33 | type CTEQueryBuilder struct { 34 | name string 35 | cols []string 36 | builderVar string 37 | 38 | // if true, this query's table name will be automatically added to the table list 39 | // in FROM clause of SELECT statement. 40 | autoAddToTableList bool 41 | 42 | args *Args 43 | 44 | injection *injection 45 | marker injectionMarker 46 | } 47 | 48 | var _ Builder = new(CTEQueryBuilder) 49 | 50 | // CTETableBuilder is an alias of CTEQueryBuilder for backward compatibility. 51 | // 52 | // Deprecated: use CTEQueryBuilder instead. 53 | type CTETableBuilder = CTEQueryBuilder 54 | 55 | // Table sets the table name and columns in a CTE table. 
56 | func (ctetb *CTEQueryBuilder) Table(name string, cols ...string) *CTEQueryBuilder { 57 | ctetb.name = name 58 | ctetb.cols = cols 59 | ctetb.marker = cteQueryMarkerAfterTable 60 | return ctetb 61 | } 62 | 63 | // As sets the builder to select data. 64 | func (ctetb *CTEQueryBuilder) As(builder Builder) *CTEQueryBuilder { 65 | ctetb.builderVar = ctetb.args.Add(builder) 66 | ctetb.marker = cteQueryMarkerAfterAs 67 | return ctetb 68 | } 69 | 70 | // AddToTableList sets flag to add table name to table list in FROM clause of SELECT statement. 71 | func (ctetb *CTEQueryBuilder) AddToTableList() *CTEQueryBuilder { 72 | ctetb.autoAddToTableList = true 73 | return ctetb 74 | } 75 | 76 | // ShouldAddToTableList returns flag to add table name to table list in FROM clause of SELECT statement. 77 | func (ctetb *CTEQueryBuilder) ShouldAddToTableList() bool { 78 | return ctetb.autoAddToTableList 79 | } 80 | 81 | // String returns the compiled CTE string. 82 | func (ctetb *CTEQueryBuilder) String() string { 83 | sql, _ := ctetb.Build() 84 | return sql 85 | } 86 | 87 | // Build returns compiled CTE string and args. 88 | func (ctetb *CTEQueryBuilder) Build() (sql string, args []interface{}) { 89 | return ctetb.BuildWithFlavor(ctetb.args.Flavor) 90 | } 91 | 92 | // BuildWithFlavor builds a CTE with the specified flavor and initial arguments. 
93 | func (ctetb *CTEQueryBuilder) BuildWithFlavor(flavor Flavor, initialArg ...interface{}) (sql string, args []interface{}) { 94 | buf := newStringBuilder() 95 | ctetb.injection.WriteTo(buf, cteQueryMarkerInit) 96 | 97 | if ctetb.name != "" { 98 | buf.WriteLeadingString(ctetb.name) 99 | 100 | if len(ctetb.cols) > 0 { 101 | buf.WriteLeadingString("(") 102 | buf.WriteStrings(ctetb.cols, ", ") 103 | buf.WriteString(")") 104 | } 105 | 106 | ctetb.injection.WriteTo(buf, cteQueryMarkerAfterTable) 107 | } 108 | 109 | if ctetb.builderVar != "" { 110 | buf.WriteLeadingString("AS (") 111 | buf.WriteString(ctetb.builderVar) 112 | buf.WriteRune(')') 113 | 114 | ctetb.injection.WriteTo(buf, cteQueryMarkerAfterAs) 115 | } 116 | 117 | return ctetb.args.CompileWithFlavor(buf.String(), flavor, initialArg...) 118 | } 119 | 120 | // SetFlavor sets the flavor of compiled sql. 121 | func (ctetb *CTEQueryBuilder) SetFlavor(flavor Flavor) (old Flavor) { 122 | old = ctetb.args.Flavor 123 | ctetb.args.Flavor = flavor 124 | return 125 | } 126 | 127 | // Flavor returns flavor of builder 128 | func (ctetb *CTEQueryBuilder) Flavor() Flavor { 129 | return ctetb.args.Flavor 130 | } 131 | 132 | // SQL adds an arbitrary sql to current position. 133 | func (ctetb *CTEQueryBuilder) SQL(sql string) *CTEQueryBuilder { 134 | ctetb.injection.SQL(ctetb.marker, sql) 135 | return ctetb 136 | } 137 | 138 | // TableName returns the CTE table name. 139 | func (ctetb *CTEQueryBuilder) TableName() string { 140 | return ctetb.name 141 | } 142 | -------------------------------------------------------------------------------- /delete.go: -------------------------------------------------------------------------------- 1 | // Copyright 2018 Huan Du. All rights reserved. 2 | // Licensed under the MIT license that can be found in the LICENSE file. 
3 | 4 | package sqlbuilder 5 | 6 | const ( 7 | deleteMarkerInit injectionMarker = iota 8 | deleteMarkerAfterWith 9 | deleteMarkerAfterDeleteFrom 10 | deleteMarkerAfterWhere 11 | deleteMarkerAfterOrderBy 12 | deleteMarkerAfterLimit 13 | ) 14 | 15 | // NewDeleteBuilder creates a new DELETE builder. 16 | func NewDeleteBuilder() *DeleteBuilder { 17 | return DefaultFlavor.NewDeleteBuilder() 18 | } 19 | 20 | func newDeleteBuilder() *DeleteBuilder { 21 | args := &Args{} 22 | proxy := &whereClauseProxy{} 23 | return &DeleteBuilder{ 24 | whereClauseProxy: proxy, 25 | whereClauseExpr: args.Add(proxy), 26 | 27 | Cond: Cond{ 28 | Args: args, 29 | }, 30 | args: args, 31 | injection: newInjection(), 32 | } 33 | } 34 | 35 | // DeleteBuilder is a builder to build DELETE. 36 | type DeleteBuilder struct { 37 | *WhereClause 38 | Cond 39 | 40 | whereClauseProxy *whereClauseProxy 41 | whereClauseExpr string 42 | 43 | cteBuilderVar string 44 | cteBuilder *CTEBuilder 45 | 46 | tables []string 47 | orderByCols []string 48 | order string 49 | limitVar string 50 | 51 | args *Args 52 | 53 | injection *injection 54 | marker injectionMarker 55 | } 56 | 57 | var _ Builder = new(DeleteBuilder) 58 | 59 | // DeleteFrom sets table name in DELETE. 60 | func DeleteFrom(table ...string) *DeleteBuilder { 61 | return DefaultFlavor.NewDeleteBuilder().DeleteFrom(table...) 62 | } 63 | 64 | // With sets WITH clause (the Common Table Expression) before DELETE. 65 | func (db *DeleteBuilder) With(builder *CTEBuilder) *DeleteBuilder { 66 | db.marker = deleteMarkerAfterWith 67 | db.cteBuilderVar = db.Var(builder) 68 | db.cteBuilder = builder 69 | return db 70 | } 71 | 72 | // DeleteFrom sets table name in DELETE. 73 | func (db *DeleteBuilder) DeleteFrom(table ...string) *DeleteBuilder { 74 | db.tables = table 75 | db.marker = deleteMarkerAfterDeleteFrom 76 | return db 77 | } 78 | 79 | // TableNames returns all table names in this DELETE statement. 
80 | func (db *DeleteBuilder) TableNames() []string { 81 | var additionalTableNames []string 82 | 83 | if db.cteBuilder != nil { 84 | additionalTableNames = db.cteBuilder.tableNamesForFrom() 85 | } 86 | 87 | var tableNames []string 88 | 89 | if len(db.tables) > 0 && len(additionalTableNames) > 0 { 90 | tableNames = make([]string, len(db.tables)+len(additionalTableNames)) 91 | copy(tableNames, db.tables) 92 | copy(tableNames[len(db.tables):], additionalTableNames) 93 | } else if len(db.tables) > 0 { 94 | tableNames = db.tables 95 | } else if len(additionalTableNames) > 0 { 96 | tableNames = additionalTableNames 97 | } 98 | 99 | return tableNames 100 | } 101 | 102 | // Where sets expressions of WHERE in DELETE. 103 | func (db *DeleteBuilder) Where(andExpr ...string) *DeleteBuilder { 104 | if len(andExpr) == 0 || estimateStringsBytes(andExpr) == 0 { 105 | return db 106 | } 107 | 108 | if db.WhereClause == nil { 109 | db.WhereClause = NewWhereClause() 110 | } 111 | 112 | db.WhereClause.AddWhereExpr(db.args, andExpr...) 113 | db.marker = deleteMarkerAfterWhere 114 | return db 115 | } 116 | 117 | // AddWhereClause adds all clauses in the whereClause to SELECT. 118 | func (db *DeleteBuilder) AddWhereClause(whereClause *WhereClause) *DeleteBuilder { 119 | if db.WhereClause == nil { 120 | db.WhereClause = NewWhereClause() 121 | } 122 | 123 | db.WhereClause.AddWhereClause(whereClause) 124 | return db 125 | } 126 | 127 | // OrderBy sets columns of ORDER BY in DELETE. 128 | func (db *DeleteBuilder) OrderBy(col ...string) *DeleteBuilder { 129 | db.orderByCols = col 130 | db.marker = deleteMarkerAfterOrderBy 131 | return db 132 | } 133 | 134 | // Asc sets order of ORDER BY to ASC. 135 | func (db *DeleteBuilder) Asc() *DeleteBuilder { 136 | db.order = "ASC" 137 | db.marker = deleteMarkerAfterOrderBy 138 | return db 139 | } 140 | 141 | // Desc sets order of ORDER BY to DESC. 
142 | func (db *DeleteBuilder) Desc() *DeleteBuilder { 143 | db.order = "DESC" 144 | db.marker = deleteMarkerAfterOrderBy 145 | return db 146 | } 147 | 148 | // Limit sets the LIMIT in DELETE. 149 | func (db *DeleteBuilder) Limit(limit int) *DeleteBuilder { 150 | if limit < 0 { 151 | db.limitVar = "" 152 | return db 153 | } 154 | 155 | db.limitVar = db.Var(limit) 156 | db.marker = deleteMarkerAfterLimit 157 | return db 158 | } 159 | 160 | // String returns the compiled DELETE string. 161 | func (db *DeleteBuilder) String() string { 162 | s, _ := db.Build() 163 | return s 164 | } 165 | 166 | // Build returns compiled DELETE string and args. 167 | // They can be used in `DB#Query` of package `database/sql` directly. 168 | func (db *DeleteBuilder) Build() (sql string, args []interface{}) { 169 | return db.BuildWithFlavor(db.args.Flavor) 170 | } 171 | 172 | // BuildWithFlavor returns compiled DELETE string and args with flavor and initial args. 173 | // They can be used in `DB#Query` of package `database/sql` directly. 
func (db *DeleteBuilder) BuildWithFlavor(flavor Flavor, initialArg ...interface{}) (sql string, args []interface{}) {
	buf := newStringBuilder()
	db.injection.WriteTo(buf, deleteMarkerInit)

	// The CTE was registered as an arg by With; its placeholder is written here and
	// resolved by CompileWithFlavor below.
	if db.cteBuilder != nil {
		buf.WriteLeadingString(db.cteBuilderVar)
		db.injection.WriteTo(buf, deleteMarkerAfterWith)
	}

	// DeleteFrom tables followed by CTE tables flagged with AddToTableList.
	tableNames := db.TableNames()

	if len(tableNames) > 0 {
		buf.WriteLeadingString("DELETE FROM ")
		buf.WriteStrings(tableNames, ", ")
	}

	// Written even when no table is set, unlike the later markers.
	db.injection.WriteTo(buf, deleteMarkerAfterDeleteFrom)

	if db.WhereClause != nil {
		// The proxy is an arg registered in newDeleteBuilder. Point it at the current
		// WhereClause only for this build. The deferred reset runs after the return
		// expression (CompileWithFlavor) has been evaluated, so compilation still sees it.
		db.whereClauseProxy.WhereClause = db.WhereClause
		defer func() {
			db.whereClauseProxy.WhereClause = nil
		}()

		buf.WriteLeadingString(db.whereClauseExpr)
		db.injection.WriteTo(buf, deleteMarkerAfterWhere)
	}

	// ORDER BY and its after-marker injections are emitted only when columns are set;
	// Asc/Desc alone produce nothing.
	if len(db.orderByCols) > 0 {
		buf.WriteLeadingString("ORDER BY ")
		buf.WriteStrings(db.orderByCols, ", ")

		if db.order != "" {
			buf.WriteRune(' ')
			buf.WriteString(db.order)
		}

		db.injection.WriteTo(buf, deleteMarkerAfterOrderBy)
	}

	if len(db.limitVar) > 0 {
		buf.WriteLeadingString("LIMIT ")
		buf.WriteString(db.limitVar)

		db.injection.WriteTo(buf, deleteMarkerAfterLimit)
	}

	return db.args.CompileWithFlavor(buf.String(), flavor, initialArg...)
}

// SetFlavor sets the flavor of compiled sql and returns the previous flavor.
func (db *DeleteBuilder) SetFlavor(flavor Flavor) (old Flavor) {
	old = db.args.Flavor
	db.args.Flavor = flavor
	return
}

// Flavor returns flavor of builder
func (db *DeleteBuilder) Flavor() Flavor {
	return db.args.Flavor
}

// SQL adds an arbitrary sql to current position.
237 | func (db *DeleteBuilder) SQL(sql string) *DeleteBuilder { 238 | db.injection.SQL(db.marker, sql) 239 | return db 240 | } 241 | -------------------------------------------------------------------------------- /delete_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2018 Huan Du. All rights reserved. 2 | // Licensed under the MIT license that can be found in the LICENSE file. 3 | 4 | package sqlbuilder 5 | 6 | import ( 7 | "fmt" 8 | "testing" 9 | 10 | "github.com/huandu/go-assert" 11 | ) 12 | 13 | func ExampleDeleteFrom() { 14 | sql := DeleteFrom("demo.user"). 15 | Where( 16 | "status = 1", 17 | ). 18 | Limit(10). 19 | String() 20 | 21 | fmt.Println(sql) 22 | 23 | // Output: 24 | // DELETE FROM demo.user WHERE status = 1 LIMIT ? 25 | } 26 | 27 | func ExampleDeleteBuilder() { 28 | db := NewDeleteBuilder() 29 | db.DeleteFrom("demo.user") 30 | db.Where( 31 | db.GreaterThan("id", 1234), 32 | db.Like("name", "%Du"), 33 | db.Or( 34 | db.IsNull("id_card"), 35 | db.In("status", 1, 2, 5), 36 | ), 37 | "modified_at > created_at + "+db.Var(86400), // It's allowed to write arbitrary SQL. 38 | ) 39 | 40 | sql, args := db.Build() 41 | fmt.Println(sql) 42 | fmt.Println(args) 43 | 44 | // Output: 45 | // DELETE FROM demo.user WHERE id > ? AND name LIKE ? AND (id_card IS NULL OR status IN (?, ?, ?)) AND modified_at > created_at + ? 46 | // [1234 %Du 1 2 5 86400] 47 | } 48 | 49 | func ExampleDeleteBuilder_SQL() { 50 | db := NewDeleteBuilder() 51 | db.SQL(`/* before */`) 52 | db.DeleteFrom("demo.user") 53 | db.SQL("PARTITION (p0)") 54 | db.Where( 55 | db.GreaterThan("id", 1234), 56 | ) 57 | db.SQL("/* after where */") 58 | db.OrderBy("id") 59 | db.SQL("/* after order by */") 60 | db.Limit(10) 61 | db.SQL("/* after limit */") 62 | 63 | sql, args := db.Build() 64 | fmt.Println(sql) 65 | fmt.Println(args) 66 | 67 | // Output: 68 | // /* before */ DELETE FROM demo.user PARTITION (p0) WHERE id > ? 
/* after where */ ORDER BY id /* after order by */ LIMIT ? /* after limit */ 69 | // [1234 10] 70 | } 71 | 72 | func ExampleDeleteBuilder_With() { 73 | sql := With( 74 | CTEQuery("users").As( 75 | Select("id", "name").From("users").Where("name IS NULL"), 76 | ), 77 | ).DeleteFrom("orders").Where( 78 | "users.id = orders.user_id", 79 | ).String() 80 | 81 | fmt.Println(sql) 82 | 83 | // Output: 84 | // WITH users AS (SELECT id, name FROM users WHERE name IS NULL) DELETE FROM orders WHERE users.id = orders.user_id 85 | } 86 | 87 | func TestDeleteBuilderGetFlavor(t *testing.T) { 88 | a := assert.New(t) 89 | db := newDeleteBuilder() 90 | 91 | db.SetFlavor(PostgreSQL) 92 | flavor := db.Flavor() 93 | a.Equal(PostgreSQL, flavor) 94 | 95 | dbClick := ClickHouse.NewDeleteBuilder() 96 | flavor = dbClick.Flavor() 97 | a.Equal(ClickHouse, flavor) 98 | } 99 | -------------------------------------------------------------------------------- /doc.go: -------------------------------------------------------------------------------- 1 | // Copyright 2018 Huan Du. All rights reserved. 2 | // Licensed under the MIT license that can be found in the LICENSE file. 3 | 4 | // Package sqlbuilder is a flexible and powerful tool to build SQL string and associated args. 5 | package sqlbuilder 6 | -------------------------------------------------------------------------------- /fieldmapper.go: -------------------------------------------------------------------------------- 1 | package sqlbuilder 2 | 3 | import ( 4 | "reflect" 5 | 6 | "github.com/huandu/xstrings" 7 | ) 8 | 9 | var ( 10 | // DefaultFieldMapper is the default field name to table column name mapper func. 11 | // It's nil by default which means field name will be kept as it is. 12 | // 13 | // If a Struct has its own mapper func, the DefaultFieldMapper is ignored in this Struct. 14 | // Field tag has precedence over all kinds of field mapper functions. 
15 | // 16 | // Field mapper is called only once on a Struct when the Struct is used to create builder for the first time. 17 | DefaultFieldMapper FieldMapperFunc 18 | 19 | // DefaultGetAlias is the default alias and dbtag get func 20 | DefaultGetAlias GetAliasFunc 21 | ) 22 | 23 | func init() { 24 | DefaultGetAlias = func(field *reflect.StructField) (alias string, dbtag string) { 25 | dbtag = field.Tag.Get(DBTag) 26 | alias = dbtag 27 | return 28 | } 29 | } 30 | 31 | // FieldMapperFunc is a func to map struct field names to column names, 32 | // which will be used in query as columns. 33 | type FieldMapperFunc func(name string) string 34 | 35 | // SnakeCaseMapper is a field mapper which can convert field name from CamelCase to snake_case. 36 | // 37 | // For instance, it will convert "MyField" to "my_field". 38 | // 39 | // SnakeCaseMapper uses package "xstrings" to do the conversion. 40 | // See https://pkg.go.dev/github.com/huandu/xstrings#ToSnakeCase for conversion rules. 41 | func SnakeCaseMapper(field string) string { 42 | return xstrings.ToSnakeCase(field) 43 | } 44 | 45 | // GetAliasFunc is a func to get alias and dbtag 46 | type GetAliasFunc func(field *reflect.StructField) (alias string, dbtag string) 47 | -------------------------------------------------------------------------------- /flavor.go: -------------------------------------------------------------------------------- 1 | // Copyright 2018 Huan Du. All rights reserved. 2 | // Licensed under the MIT license that can be found in the LICENSE file. 3 | 4 | package sqlbuilder 5 | 6 | import ( 7 | "errors" 8 | "fmt" 9 | ) 10 | 11 | // Supported flavors. 12 | const ( 13 | invalidFlavor Flavor = iota 14 | 15 | MySQL 16 | PostgreSQL 17 | SQLite 18 | SQLServer 19 | CQL 20 | ClickHouse 21 | Presto 22 | Oracle 23 | Informix 24 | Doris 25 | ) 26 | 27 | var ( 28 | // DefaultFlavor is the default flavor for all builders. 
29 | DefaultFlavor = MySQL 30 | ) 31 | 32 | var ( 33 | // ErrInterpolateNotImplemented means the method or feature is not implemented right now. 34 | ErrInterpolateNotImplemented = errors.New("go-sqlbuilder: interpolation for this flavor is not implemented") 35 | 36 | // ErrInterpolateMissingArgs means there are some args missing in query, so it's not possible to 37 | // prepare a query with such args. 38 | ErrInterpolateMissingArgs = errors.New("go-sqlbuilder: not enough args when interpolating") 39 | 40 | // ErrInterpolateUnsupportedArgs means that some types of the args are not supported. 41 | ErrInterpolateUnsupportedArgs = errors.New("go-sqlbuilder: unsupported args when interpolating") 42 | ) 43 | 44 | // Flavor is the flag to control the format of compiled sql. 45 | type Flavor int 46 | 47 | // String returns the name of f. 48 | func (f Flavor) String() string { 49 | switch f { 50 | case MySQL: 51 | return "MySQL" 52 | case PostgreSQL: 53 | return "PostgreSQL" 54 | case SQLite: 55 | return "SQLite" 56 | case SQLServer: 57 | return "SQLServer" 58 | case CQL: 59 | return "CQL" 60 | case ClickHouse: 61 | return "ClickHouse" 62 | case Presto: 63 | return "Presto" 64 | case Oracle: 65 | return "Oracle" 66 | case Informix: 67 | return "Informix" 68 | case Doris: 69 | return "Doris" 70 | } 71 | 72 | return "" 73 | } 74 | 75 | // Interpolate parses sql returned by `Args#Compile` or `Builder`, 76 | // and interpolate args to replace placeholders in the sql. 77 | // 78 | // If there are some args missing in sql, e.g. the number of placeholders are larger than len(args), 79 | // returns ErrMissingArgs error. 80 | func (f Flavor) Interpolate(sql string, args []interface{}) (string, error) { 81 | switch f { 82 | case MySQL: 83 | return mysqlInterpolate(sql, args...) 84 | case PostgreSQL: 85 | return postgresqlInterpolate(sql, args...) 86 | case SQLite: 87 | return sqliteInterpolate(sql, args...) 88 | case SQLServer: 89 | return sqlserverInterpolate(sql, args...) 
90 | case CQL: 91 | return cqlInterpolate(sql, args...) 92 | case ClickHouse: 93 | return clickhouseInterpolate(sql, args...) 94 | case Presto: 95 | return prestoInterpolate(sql, args...) 96 | case Oracle: 97 | return oracleInterpolate(sql, args...) 98 | case Informix: 99 | return informixInterpolate(sql, args...) 100 | case Doris: 101 | return dorisInterpolate(sql, args...) 102 | } 103 | 104 | return "", ErrInterpolateNotImplemented 105 | } 106 | 107 | // NewCreateTableBuilder creates a new CREATE TABLE builder with flavor. 108 | func (f Flavor) NewCreateTableBuilder() *CreateTableBuilder { 109 | b := newCreateTableBuilder() 110 | b.SetFlavor(f) 111 | return b 112 | } 113 | 114 | // NewDeleteBuilder creates a new DELETE builder with flavor. 115 | func (f Flavor) NewDeleteBuilder() *DeleteBuilder { 116 | b := newDeleteBuilder() 117 | b.SetFlavor(f) 118 | return b 119 | } 120 | 121 | // NewInsertBuilder creates a new INSERT builder with flavor. 122 | func (f Flavor) NewInsertBuilder() *InsertBuilder { 123 | b := newInsertBuilder() 124 | b.SetFlavor(f) 125 | return b 126 | } 127 | 128 | // NewSelectBuilder creates a new SELECT builder with flavor. 129 | func (f Flavor) NewSelectBuilder() *SelectBuilder { 130 | b := newSelectBuilder() 131 | b.SetFlavor(f) 132 | return b 133 | } 134 | 135 | // NewUpdateBuilder creates a new UPDATE builder with flavor. 136 | func (f Flavor) NewUpdateBuilder() *UpdateBuilder { 137 | b := newUpdateBuilder() 138 | b.SetFlavor(f) 139 | return b 140 | } 141 | 142 | // NewUnionBuilder creates a new UNION builder with flavor. 143 | func (f Flavor) NewUnionBuilder() *UnionBuilder { 144 | b := newUnionBuilder() 145 | b.SetFlavor(f) 146 | return b 147 | } 148 | 149 | // NewCTEBuilder creates a new CTE builder with flavor. 150 | func (f Flavor) NewCTEBuilder() *CTEBuilder { 151 | b := newCTEBuilder() 152 | b.SetFlavor(f) 153 | return b 154 | } 155 | 156 | // NewCTETableBuilder creates a new CTE table builder with flavor. 
157 | func (f Flavor) NewCTEQueryBuilder() *CTEQueryBuilder { 158 | b := newCTEQueryBuilder() 159 | b.SetFlavor(f) 160 | return b 161 | } 162 | 163 | // Quote adds quote for name to make sure the name can be used safely as table name or field name. 164 | // 165 | // - For MySQL, ClickHouse and Doris, use back quote (`) to quote name; 166 | // - For PostgreSQL, SQL Server, SQLite, Presto, Oracle and Informix, use double quote (") to quote name; 167 | // - For CQL, use single quote (') to quote name (NOTE(review): CQL quotes identifiers with double quotes and uses single quotes for string literals - confirm); any other flavor returns name unchanged. 168 | func (f Flavor) Quote(name string) string { 169 | switch f { 170 | case MySQL, ClickHouse, Doris: 171 | return fmt.Sprintf("`%s`", name) 172 | case PostgreSQL, SQLServer, SQLite, Presto, Oracle, Informix: 173 | return fmt.Sprintf(`"%s"`, name) 174 | case CQL: 175 | return fmt.Sprintf("'%s'", name) 176 | } 177 | 178 | return name 179 | } 180 | 181 | // PrepareInsertIgnore prepares ib to build an insert-ignore style statement for ib's own flavor (the receiver f is not consulted), then sets ib's table and moves its marker to right after INSERT INTO; it panics on an unknown flavor. 182 | func (f Flavor) PrepareInsertIgnore(table string, ib *InsertBuilder) { 183 | switch ib.args.Flavor { 184 | case MySQL, Oracle: // NOTE(review): Oracle has no INSERT IGNORE statement (it offers the IGNORE_ROW_ON_DUPKEY_INDEX hint instead); confirm Oracle accepts this output. 185 | ib.verb = "INSERT IGNORE" 186 | 187 | case PostgreSQL: 188 | // see https://www.postgresql.org/docs/current/sql-insert.html 189 | ib.verb = "INSERT" 190 | // add sql statement at the end after values, i.e. INSERT INTO ...
ON CONFLICT DO NOTHING 191 | ib.marker = insertMarkerAfterValues 192 | ib.SQL("ON CONFLICT DO NOTHING") 193 | 194 | case SQLite: 195 | // see https://www.sqlite.org/lang_insert.html 196 | ib.verb = "INSERT OR IGNORE" 197 | 198 | case ClickHouse, CQL, SQLServer, Presto, Informix, Doris: 199 | // These flavors have no insert-ignore syntax, so a plain INSERT is emitted. 200 | ib.verb = "INSERT" 201 | 202 | default: 203 | // panic if the db flavor is not supported 204 | panic(fmt.Errorf("unsupported db flavor: %s", ib.args.Flavor.String())) 205 | } 206 | 207 | // Set the table and reset the marker right after insert into 208 | ib.table = Escape(table) 209 | ib.marker = insertMarkerAfterInsertInto 210 | } 211 | -------------------------------------------------------------------------------- /flavor_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2018 Huan Du. All rights reserved. 2 | // Licensed under the MIT license that can be found in the LICENSE file. 3 | 4 | package sqlbuilder 5 | 6 | import ( 7 | "fmt" 8 | "testing" 9 | 10 | "github.com/huandu/go-assert" 11 | ) 12 | 13 | func TestFlavor(t *testing.T) { 14 | a := assert.New(t) 15 | cases := map[Flavor]string{ 16 | 0: "", 17 | MySQL: "MySQL", 18 | PostgreSQL: "PostgreSQL", 19 | SQLite: "SQLite", 20 | SQLServer: "SQLServer", 21 | CQL: "CQL", 22 | ClickHouse: "ClickHouse", Presto: "Presto", 23 | Oracle: "Oracle", 24 | Informix: "Informix", 25 | Doris: "Doris", 26 | } 27 | 28 | for f, expected := range cases { 29 | actual := f.String() 30 | a.Equal(actual, expected) 31 | } 32 | } 33 | 34 | func ExampleFlavor() { 35 | // Create a flavored builder.
36 | sb := PostgreSQL.NewSelectBuilder() 37 | sb.Select("name").From("user").Where( 38 | sb.E("id", 1234), 39 | sb.G("rank", 3), 40 | ) 41 | sql, args := sb.Build() 42 | 43 | fmt.Println(sql) 44 | fmt.Println(args) 45 | 46 | // Output: 47 | // SELECT name FROM user WHERE id = $1 AND rank > $2 48 | // [1234 3] 49 | } 50 | 51 | func ExampleFlavor_Interpolate_mySQL() { 52 | sb := MySQL.NewSelectBuilder() 53 | sb.Select("name").From("user").Where( 54 | sb.NE("id", 1234), 55 | sb.E("name", "Charmy Liu"), 56 | sb.Like("desc", "%mother's day%"), 57 | ) 58 | sql, args := sb.Build() 59 | query, err := MySQL.Interpolate(sql, args) 60 | 61 | fmt.Println(query) 62 | fmt.Println(err) 63 | 64 | // Output: 65 | // SELECT name FROM user WHERE id <> 1234 AND name = 'Charmy Liu' AND desc LIKE '%mother\'s day%' 66 | // 67 | } 68 | 69 | func ExampleFlavor_Interpolate_postgreSQL() { 70 | // Only the last `$1` is interpolated. 71 | // Others are not interpolated as they are inside dollar quote (the `$$`). 72 | query, err := PostgreSQL.Interpolate(` 73 | CREATE FUNCTION dup(in int, out f1 int, out f2 text) AS $$ 74 | SELECT $1, CAST($1 AS text) || ' is text' 75 | $$ 76 | LANGUAGE SQL; 77 | 78 | SELECT * FROM dup($1);`, []interface{}{42}) 79 | 80 | fmt.Println(query) 81 | fmt.Println(err) 82 | 83 | // Output: 84 | // 85 | // CREATE FUNCTION dup(in int, out f1 int, out f2 text) AS $$ 86 | // SELECT $1, CAST($1 AS text) || ' is text' 87 | // $$ 88 | // LANGUAGE SQL; 89 | // 90 | // SELECT * FROM dup(42); 91 | // 92 | } 93 | 94 | func ExampleFlavor_Interpolate_sqlite() { 95 | sb := SQLite.NewSelectBuilder() 96 | sb.Select("name").From("user").Where( 97 | sb.NE("id", 1234), 98 | sb.E("name", "Charmy Liu"), 99 | sb.Like("desc", "%mother's day%"), 100 | ) 101 | sql, args := sb.Build() 102 | query, err := SQLite.Interpolate(sql, args) 103 | 104 | fmt.Println(query) 105 | fmt.Println(err) 106 | 107 | // Output: 108 | // SELECT name FROM user WHERE id <> 1234 AND name = 'Charmy Liu' AND desc LIKE
'%mother\'s day%' 109 | // 110 | } 111 | 112 | func ExampleFlavor_Interpolate_sqlServer() { 113 | sb := SQLServer.NewSelectBuilder() 114 | sb.Select("name").From("user").Where( 115 | sb.NE("id", 1234), 116 | sb.E("name", "Charmy Liu"), 117 | sb.Like("desc", "%mother's day%"), 118 | ) 119 | sql, args := sb.Build() 120 | query, err := SQLServer.Interpolate(sql, args) 121 | 122 | fmt.Println(query) 123 | fmt.Println(err) 124 | 125 | // Output: 126 | // SELECT name FROM user WHERE id <> 1234 AND name = N'Charmy Liu' AND desc LIKE N'%mother\'s day%' 127 | // 128 | } 129 | 130 | func ExampleFlavor_Interpolate_cql() { 131 | sb := CQL.NewSelectBuilder() 132 | sb.Select("name").From("user").Where( 133 | sb.E("id", 1234), 134 | sb.E("name", "Charmy Liu"), 135 | ) 136 | sql, args := sb.Build() 137 | query, err := CQL.Interpolate(sql, args) 138 | 139 | fmt.Println(query) 140 | fmt.Println(err) 141 | 142 | // Output: 143 | // SELECT name FROM user WHERE id = 1234 AND name = 'Charmy Liu' 144 | // 145 | } 146 | 147 | func ExampleFlavor_Interpolate_oracle() { 148 | sb := Oracle.NewSelectBuilder() 149 | sb.Select("name").From("user").Where( 150 | sb.E("id", 1234), 151 | sb.E("name", "Charmy Liu"), 152 | sb.E("enabled", true), 153 | ) 154 | sql, args := sb.Build() 155 | query, err := Oracle.Interpolate(sql, args) 156 | 157 | fmt.Println(query) 158 | fmt.Println(err) 159 | 160 | // Output: 161 | // SELECT name FROM user WHERE id = 1234 AND name = 'Charmy Liu' AND enabled = 1 162 | // 163 | } 164 | 165 | func ExampleFlavor_Interpolate_informix() { 166 | sb := Informix.NewSelectBuilder() 167 | sb.Select("name").From("user").Where( 168 | sb.NE("id", 1234), 169 | sb.E("name", "Charmy Liu"), 170 | sb.E("enabled", true), 171 | ) 172 | sql, args := sb.Build() 173 | query, err := Informix.Interpolate(sql, args) 174 | 175 | fmt.Println(query) 176 | fmt.Println(err) 177 | 178 | // Output: 179 | // SELECT name FROM user WHERE id <> 1234 AND name = 'Charmy Liu' AND enabled = TRUE 180 | // 181 | }
182 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/huandu/go-sqlbuilder 2 | 3 | go 1.13 4 | 5 | require ( 6 | github.com/huandu/go-assert v1.1.6 7 | github.com/huandu/xstrings v1.4.0 8 | ) 9 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= 2 | github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 3 | github.com/huandu/go-assert v1.1.6 h1:oaAfYxq9KNDi9qswn/6aE0EydfxSa+tWZC1KabNitYs= 4 | github.com/huandu/go-assert v1.1.6/go.mod h1:JuIfbmYG9ykwvuxoJ3V8TB5QP+3+ajIA54Y44TmkMxs= 5 | github.com/huandu/xstrings v1.4.0 h1:D17IlohoQq4UcpqD7fDk80P7l+lwAmlFaBHgOipl2FU= 6 | github.com/huandu/xstrings v1.4.0/go.mod h1:y5/lhBue+AyNmUVz9RLU9xbLR0o4KIIExikq4ovT0aE= 7 | -------------------------------------------------------------------------------- /injection.go: -------------------------------------------------------------------------------- 1 | // Copyright 2018 Huan Du. All rights reserved. 2 | // Licensed under the MIT license that can be found in the LICENSE file. 3 | 4 | package sqlbuilder 5 | 6 | // injection is a helper type to manage injected SQLs in all builders. 7 | type injection struct { 8 | markerSQLs map[injectionMarker][]string // SQL snippets grouped by the marker they were added at 9 | } 10 | 11 | type injectionMarker int // identifies a position in a builder's output where injected SQL is written 12 | 13 | // newInjection creates a new injection. 14 | func newInjection() *injection { 15 | return &injection{ 16 | markerSQLs: map[injectionMarker][]string{}, 17 | } 18 | } 19 | 20 | // SQL adds sql to injection's sql list. 21 | // SQLs added at the same marker keep their insertion order; the builder decides where each marker is written.
22 | func (injection *injection) SQL(marker injectionMarker, sql string) { 23 | injection.markerSQLs[marker] = append(injection.markerSQLs[marker], sql) 24 | } 25 | 26 | // WriteTo joins all SQL strings at the same marker value with blank (" ") 27 | // and writes the joined value to buf. 28 | func (injection *injection) WriteTo(buf *stringBuilder, marker injectionMarker) { 29 | sqls := injection.markerSQLs[marker] 30 | 31 | if len(sqls) == 0 { 32 | return 33 | } 34 | 35 | buf.WriteLeadingString("") 36 | buf.WriteStrings(sqls, " ") 37 | } 38 | -------------------------------------------------------------------------------- /insert.go: -------------------------------------------------------------------------------- 1 | // Copyright 2018 Huan Du. All rights reserved. 2 | // Licensed under the MIT license that can be found in the LICENSE file. 3 | 4 | package sqlbuilder 5 | 6 | import ( 7 | "fmt" 8 | "strings" 9 | ) 10 | 11 | const ( 12 | insertMarkerInit injectionMarker = iota 13 | insertMarkerAfterInsertInto 14 | insertMarkerAfterCols 15 | insertMarkerAfterValues 16 | insertMarkerAfterSelect 17 | insertMarkerAfterReturning 18 | ) 19 | 20 | // NewInsertBuilder creates a new INSERT builder. 21 | func NewInsertBuilder() *InsertBuilder { 22 | return DefaultFlavor.NewInsertBuilder() 23 | } 24 | 25 | func newInsertBuilder() *InsertBuilder { 26 | args := &Args{} 27 | return &InsertBuilder{ 28 | verb: "INSERT", 29 | args: args, 30 | injection: newInjection(), 31 | } 32 | } 33 | 34 | // InsertBuilder is a builder to build INSERT. 35 | type InsertBuilder struct { 36 | verb string 37 | table string 38 | cols []string 39 | values [][]string 40 | returning []string 41 | 42 | args *Args 43 | 44 | injection *injection 45 | marker injectionMarker 46 | 47 | sbHolder string 48 | } 49 | 50 | var _ Builder = new(InsertBuilder) 51 | 52 | // InsertInto sets table name in INSERT. 
53 | func InsertInto(table string) *InsertBuilder { 54 | return DefaultFlavor.NewInsertBuilder().InsertInto(table) 55 | } 56 | 57 | // InsertInto sets table name in INSERT. 58 | func (ib *InsertBuilder) InsertInto(table string) *InsertBuilder { 59 | ib.table = Escape(table) 60 | ib.marker = insertMarkerAfterInsertInto 61 | return ib 62 | } 63 | 64 | // InsertIgnoreInto sets table name in INSERT IGNORE using a DefaultFlavor builder. 65 | func InsertIgnoreInto(table string) *InsertBuilder { 66 | return DefaultFlavor.NewInsertBuilder().InsertIgnoreInto(table) 67 | } 68 | 69 | // InsertIgnoreInto sets table name in INSERT IGNORE; the exact statement form depends on ib's flavor (see Flavor.PrepareInsertIgnore). 70 | func (ib *InsertBuilder) InsertIgnoreInto(table string) *InsertBuilder { 71 | ib.args.Flavor.PrepareInsertIgnore(table, ib) 72 | return ib 73 | } 74 | 75 | // ReplaceInto sets table name and changes the verb of ib to REPLACE. 76 | // REPLACE INTO is a MySQL extension to the SQL standard. 77 | func ReplaceInto(table string) *InsertBuilder { 78 | return DefaultFlavor.NewInsertBuilder().ReplaceInto(table) 79 | } 80 | 81 | // ReplaceInto sets table name and changes the verb of ib to REPLACE. 82 | // REPLACE INTO is a MySQL extension to the SQL standard. 83 | func (ib *InsertBuilder) ReplaceInto(table string) *InsertBuilder { 84 | ib.verb = "REPLACE" 85 | ib.table = Escape(table) 86 | ib.marker = insertMarkerAfterInsertInto 87 | return ib 88 | } 89 | 90 | // Cols sets columns in INSERT. Column names are escaped with EscapeAll. 91 | func (ib *InsertBuilder) Cols(col ...string) *InsertBuilder { 92 | ib.cols = EscapeAll(col...) 93 | ib.marker = insertMarkerAfterCols 94 | return ib 95 | } 96 | 97 | // Select returns a new SelectBuilder to build a SELECT statement inside the INSERT INTO. Only the last call takes effect, and the sub-select replaces any Values when building. 
98 | func (ib *InsertBuilder) Select(col ...string) *SelectBuilder { 99 | sb := Select(col...) 100 | ib.sbHolder = ib.args.Add(sb) 101 | return sb 102 | } 103 | 104 | // Values adds a list of values for a row in INSERT.
105 | func (ib *InsertBuilder) Values(value ...interface{}) *InsertBuilder { 106 | placeholders := make([]string, 0, len(value)) 107 | 108 | for _, v := range value { 109 | placeholders = append(placeholders, ib.args.Add(v)) 110 | } 111 | 112 | ib.values = append(ib.values, placeholders) 113 | ib.marker = insertMarkerAfterValues 114 | return ib 115 | } 116 | 117 | // Returning sets returning columns. Column names are escaped like Cols, so a literal `$` survives compilation. 118 | // Only PostgreSQL and SQLite emit RETURNING; other flavors, e.g. MySQL, ignore it. 119 | func (ib *InsertBuilder) Returning(col ...string) *InsertBuilder { 120 | ib.returning = EscapeAll(col...) 121 | ib.marker = insertMarkerAfterReturning 122 | return ib 123 | } 124 | 125 | // NumValue returns the number of values to insert. 126 | func (ib *InsertBuilder) NumValue() int { 127 | return len(ib.values) 128 | } 129 | 130 | // String returns the compiled INSERT string; the compiled args are discarded. 131 | func (ib *InsertBuilder) String() string { 132 | s, _ := ib.Build() 133 | return s 134 | } 135 | 136 | // Build returns compiled INSERT string and args. 137 | // They can be used in `DB#Query` of package `database/sql` directly. 138 | func (ib *InsertBuilder) Build() (sql string, args []interface{}) { 139 | return ib.BuildWithFlavor(ib.args.Flavor) 140 | } 141 | 142 | // BuildWithFlavor returns compiled INSERT string and args with flavor and initial args. 143 | // They can be used in `DB#Query` of package `database/sql` directly. NOTE(review): the multi-row Oracle INSERT ALL form is selected by ib.args.Flavor rather than by flavor; confirm that is intended.
144 | func (ib *InsertBuilder) BuildWithFlavor(flavor Flavor, initialArg ...interface{}) (sql string, args []interface{}) { 145 | buf := newStringBuilder() 146 | ib.injection.WriteTo(buf, insertMarkerInit) 147 | 148 | if len(ib.values) > 1 && ib.args.Flavor == Oracle { 149 | buf.WriteLeadingString(ib.verb) 150 | buf.WriteString(" ALL") 151 | 152 | for _, v := range ib.values { 153 | if len(ib.table) > 0 { 154 | buf.WriteString(" INTO ") 155 | buf.WriteString(ib.table) 156 | } 157 | ib.injection.WriteTo(buf, insertMarkerAfterInsertInto) 158 | if len(ib.cols) > 0 { 159 | buf.WriteLeadingString("(") 160 | buf.WriteStrings(ib.cols, ", ") 161 | buf.WriteString(")") 162 | 163 | ib.injection.WriteTo(buf, insertMarkerAfterCols) 164 | } 165 | 166 | buf.WriteLeadingString("VALUES ") 167 | values := make([]string, 0, len(ib.values)) 168 | values = append(values, fmt.Sprintf("(%v)", strings.Join(v, ", "))) 169 | buf.WriteStrings(values, ", ") 170 | } 171 | 172 | buf.WriteString(" SELECT 1 from DUAL") 173 | 174 | ib.injection.WriteTo(buf, insertMarkerAfterValues) 175 | 176 | return ib.args.CompileWithFlavor(buf.String(), flavor, initialArg...) 177 | } 178 | 179 | if len(ib.table) > 0 { 180 | buf.WriteLeadingString(ib.verb) 181 | buf.WriteString(" INTO ") 182 | buf.WriteString(ib.table) 183 | } 184 | 185 | ib.injection.WriteTo(buf, insertMarkerAfterInsertInto) 186 | 187 | if len(ib.cols) > 0 { 188 | buf.WriteLeadingString("(") 189 | buf.WriteStrings(ib.cols, ", ") 190 | buf.WriteString(")") 191 | 192 | ib.injection.WriteTo(buf, insertMarkerAfterCols) 193 | } 194 | 195 | if ib.sbHolder != "" { 196 | buf.WriteString(" ") 197 | buf.WriteString(ib.sbHolder) 198 | 199 | ib.injection.WriteTo(buf, insertMarkerAfterSelect) 200 | return ib.args.CompileWithFlavor(buf.String(), flavor, initialArg...) 
201 | } 202 | 203 | if len(ib.values) > 0 { 204 | buf.WriteLeadingString("VALUES ") 205 | values := make([]string, 0, len(ib.values)) 206 | 207 | for _, v := range ib.values { 208 | values = append(values, fmt.Sprintf("(%v)", strings.Join(v, ", "))) 209 | } 210 | 211 | buf.WriteStrings(values, ", ") 212 | } 213 | 214 | ib.injection.WriteTo(buf, insertMarkerAfterValues) 215 | 216 | if flavor == PostgreSQL || flavor == SQLite { 217 | if len(ib.returning) > 0 { 218 | buf.WriteLeadingString("RETURNING ") 219 | buf.WriteStrings(ib.returning, ", ") 220 | } 221 | 222 | ib.injection.WriteTo(buf, insertMarkerAfterReturning) 223 | } 224 | 225 | return ib.args.CompileWithFlavor(buf.String(), flavor, initialArg...) 226 | } 227 | 228 | // SetFlavor sets the flavor of compiled sql. 229 | func (ib *InsertBuilder) SetFlavor(flavor Flavor) (old Flavor) { 230 | old = ib.args.Flavor 231 | ib.args.Flavor = flavor 232 | return 233 | } 234 | 235 | // Flavor returns flavor of builder 236 | func (ib *InsertBuilder) Flavor() Flavor { 237 | return ib.args.Flavor 238 | } 239 | 240 | // Var returns a placeholder for value. 241 | func (ib *InsertBuilder) Var(arg interface{}) string { 242 | return ib.args.Add(arg) 243 | } 244 | 245 | // SQL adds an arbitrary sql to current position. 246 | func (ib *InsertBuilder) SQL(sql string) *InsertBuilder { 247 | ib.injection.SQL(ib.marker, sql) 248 | return ib 249 | } 250 | -------------------------------------------------------------------------------- /insert_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2018 Huan Du. All rights reserved. 2 | // Licensed under the MIT license that can be found in the LICENSE file. 3 | 4 | package sqlbuilder 5 | 6 | import ( 7 | "fmt" 8 | "testing" 9 | 10 | "github.com/huandu/go-assert" 11 | ) 12 | 13 | func ExampleInsertInto() { 14 | sql, args := InsertInto("demo.user"). 15 | Cols("id", "name", "status"). 16 | Values(4, "Sample", 2). 
17 | Build() 18 | 19 | fmt.Println(sql) 20 | fmt.Println(args) 21 | 22 | // Output: 23 | // INSERT INTO demo.user (id, name, status) VALUES (?, ?, ?) 24 | // [4 Sample 2] 25 | } 26 | 27 | func ExampleInsertIgnoreInto() { 28 | sql, args := InsertIgnoreInto("demo.user"). 29 | Cols("id", "name", "status"). 30 | Values(4, "Sample", 2). 31 | Build() 32 | 33 | fmt.Println(sql) 34 | fmt.Println(args) 35 | 36 | // Output: 37 | // INSERT IGNORE INTO demo.user (id, name, status) VALUES (?, ?, ?) 38 | // [4 Sample 2] 39 | } 40 | 41 | func ExampleReplaceInto() { 42 | sql, args := ReplaceInto("demo.user"). 43 | Cols("id", "name", "status"). 44 | Values(4, "Sample", 2). 45 | Build() 46 | 47 | fmt.Println(sql) 48 | fmt.Println(args) 49 | 50 | // Output: 51 | // REPLACE INTO demo.user (id, name, status) VALUES (?, ?, ?) 52 | // [4 Sample 2] 53 | } 54 | 55 | func ExampleInsertBuilder() { 56 | ib := NewInsertBuilder() 57 | ib.InsertInto("demo.user") 58 | ib.Cols("id", "name", "status", "created_at", "updated_at") 59 | ib.Values(1, "Huan Du", 1, Raw("UNIX_TIMESTAMP(NOW())")) 60 | ib.Values(2, "Charmy Liu", 1, 1234567890) 61 | 62 | sql, args := ib.Build() 63 | fmt.Println(sql) 64 | fmt.Println(args) 65 | 66 | // Output: 67 | // INSERT INTO demo.user (id, name, status, created_at, updated_at) VALUES (?, ?, ?, UNIX_TIMESTAMP(NOW())), (?, ?, ?, ?) 
68 | // [1 Huan Du 1 2 Charmy Liu 1 1234567890] 69 | } 70 | 71 | func ExampleInsertBuilder_flavorOracle() { 72 | ib := Oracle.NewInsertBuilder() 73 | ib.InsertInto("demo.user") 74 | ib.Cols("id", "name", "status") 75 | ib.Values(1, "Huan Du", 1) 76 | ib.Values(2, "Charmy Liu", 1) 77 | 78 | sql, args := ib.Build() 79 | fmt.Println(sql) 80 | fmt.Println(args) 81 | 82 | // Output: 83 | // INSERT ALL INTO demo.user (id, name, status) VALUES (:1, :2, :3) INTO demo.user (id, name, status) VALUES (:4, :5, :6) SELECT 1 from DUAL 84 | // [1 Huan Du 1 2 Charmy Liu 1] 85 | } 86 | 87 | func ExampleInsertBuilder_insertIgnore() { 88 | ib := NewInsertBuilder() 89 | ib.InsertIgnoreInto("demo.user") 90 | ib.Cols("id", "name", "status", "created_at", "updated_at") 91 | ib.Values(1, "Huan Du", 1, Raw("UNIX_TIMESTAMP(NOW())")) 92 | ib.Values(2, "Charmy Liu", 1, 1234567890) 93 | 94 | sql, args := ib.Build() 95 | fmt.Println(sql) 96 | fmt.Println(args) 97 | 98 | // Output: 99 | // INSERT IGNORE INTO demo.user (id, name, status, created_at, updated_at) VALUES (?, ?, ?, UNIX_TIMESTAMP(NOW())), (?, ?, ?, ?) 
100 | // [1 Huan Du 1 2 Charmy Liu 1 1234567890] 101 | } 102 | 103 | func ExampleInsertBuilder_insertIgnore_postgres() { 104 | ib := PostgreSQL.NewInsertBuilder() 105 | ib.InsertIgnoreInto("demo.user") 106 | ib.Cols("id", "name", "status", "created_at") 107 | ib.Values(1, "Huan Du", 1, Raw("UNIX_TIMESTAMP(NOW())")) 108 | ib.Values(2, "Charmy Liu", 1, 1234567890) 109 | 110 | sql, args := ib.Build() 111 | fmt.Println(sql) 112 | fmt.Println(args) 113 | 114 | // Output: 115 | // INSERT INTO demo.user (id, name, status, created_at) VALUES ($1, $2, $3, UNIX_TIMESTAMP(NOW())), ($4, $5, $6, $7) ON CONFLICT DO NOTHING 116 | // [1 Huan Du 1 2 Charmy Liu 1 1234567890] 117 | } 118 | 119 | func ExampleInsertBuilder_insertIgnore_sqlite() { 120 | ib := SQLite.NewInsertBuilder() 121 | ib.InsertIgnoreInto("demo.user") 122 | ib.Cols("id", "name", "status", "created_at") 123 | ib.Values(1, "Huan Du", 1, Raw("UNIX_TIMESTAMP(NOW())")) 124 | ib.Values(2, "Charmy Liu", 1, 1234567890) 125 | 126 | sql, args := ib.Build() 127 | fmt.Println(sql) 128 | fmt.Println(args) 129 | 130 | // Output: 131 | // INSERT OR IGNORE INTO demo.user (id, name, status, created_at) VALUES (?, ?, ?, UNIX_TIMESTAMP(NOW())), (?, ?, ?, ?) 132 | // [1 Huan Du 1 2 Charmy Liu 1 1234567890] 133 | } 134 | 135 | func ExampleInsertBuilder_insertIgnore_clickhouse() { 136 | ib := ClickHouse.NewInsertBuilder() 137 | ib.InsertIgnoreInto("demo.user") 138 | ib.Cols("id", "name", "status", "created_at") 139 | ib.Values(1, "Huan Du", 1, Raw("UNIX_TIMESTAMP(NOW())")) 140 | ib.Values(2, "Charmy Liu", 1, 1234567890) 141 | 142 | sql, args := ib.Build() 143 | fmt.Println(sql) 144 | fmt.Println(args) 145 | 146 | // Output: 147 | // INSERT INTO demo.user (id, name, status, created_at) VALUES (?, ?, ?, UNIX_TIMESTAMP(NOW())), (?, ?, ?, ?) 
148 | // [1 Huan Du 1 2 Charmy Liu 1 1234567890] 149 | } 150 | 151 | func ExampleInsertBuilder_replaceInto() { 152 | ib := NewInsertBuilder() 153 | ib.ReplaceInto("demo.user") 154 | ib.Cols("id", "name", "status", "created_at", "updated_at") 155 | ib.Values(1, "Huan Du", 1, Raw("UNIX_TIMESTAMP(NOW())")) 156 | ib.Values(2, "Charmy Liu", 1, 1234567890) 157 | 158 | sql, args := ib.Build() 159 | fmt.Println(sql) 160 | fmt.Println(args) 161 | 162 | // Output: 163 | // REPLACE INTO demo.user (id, name, status, created_at, updated_at) VALUES (?, ?, ?, UNIX_TIMESTAMP(NOW())), (?, ?, ?, ?) 164 | // [1 Huan Du 1 2 Charmy Liu 1 1234567890] 165 | } 166 | 167 | func ExampleInsertBuilder_SQL() { 168 | ib := NewInsertBuilder() 169 | ib.SQL("/* before */") 170 | ib.InsertInto("demo.user") 171 | ib.SQL("PARTITION (p0)") 172 | ib.Cols("id", "name", "status", "created_at") 173 | ib.SQL("/* after cols */") 174 | ib.Values(3, "Shawn Du", 1, 1234567890) 175 | ib.SQL(ib.Var(Build("ON DUPLICATE KEY UPDATE status = $?", 1))) 176 | 177 | sql, args := ib.Build() 178 | fmt.Println(sql) 179 | fmt.Println(args) 180 | 181 | // Output: 182 | // /* before */ INSERT INTO demo.user PARTITION (p0) (id, name, status, created_at) /* after cols */ VALUES (?, ?, ?, ?) ON DUPLICATE KEY UPDATE status = ? 183 | // [3 Shawn Du 1 1234567890 1] 184 | } 185 | 186 | func ExampleInsertBuilder_subSelect() { 187 | ib := NewInsertBuilder() 188 | ib.InsertInto("demo.user") 189 | ib.Cols("id", "name") 190 | sb := ib.Select("id", "name").From("demo.test") 191 | sb.Where(sb.EQ("id", 1)) 192 | 193 | sql, args := ib.Build() 194 | fmt.Println(sql) 195 | fmt.Println(args) 196 | 197 | // Output: 198 | // INSERT INTO demo.user (id, name) SELECT id, name FROM demo.test WHERE id = ? 
199 | // [1] 200 | } 201 | 202 | func ExampleInsertBuilder_subSelect_oracle() { 203 | ib := Oracle.NewInsertBuilder() 204 | ib.InsertInto("demo.user") 205 | ib.Cols("id", "name") 206 | sb := ib.Select("id", "name").From("demo.test") 207 | sb.Where(sb.EQ("id", 1)) 208 | 209 | sql, args := ib.Build() 210 | fmt.Println(sql) 211 | fmt.Println(args) 212 | 213 | // Output: 214 | // INSERT INTO demo.user (id, name) SELECT id, name FROM demo.test WHERE id = :1 215 | // [1] 216 | } 217 | 218 | func ExampleInsertBuilder_subSelect_informix() { 219 | ib := Informix.NewInsertBuilder() 220 | ib.InsertInto("demo.user") 221 | ib.Cols("id", "name") 222 | sb := ib.Select("id", "name").From("demo.test") 223 | sb.Where(sb.EQ("id", 1)) 224 | 225 | sql, args := ib.Build() 226 | fmt.Println(sql) 227 | fmt.Println(args) 228 | 229 | // Output: 230 | // INSERT INTO demo.user (id, name) SELECT id, name FROM demo.test WHERE id = ? 231 | // [1] 232 | } 233 | 234 | func ExampleInsertBuilder_NumValue() { 235 | ib := NewInsertBuilder() 236 | ib.InsertInto("demo.user") 237 | ib.Cols("id", "name") 238 | ib.Values(1, "Huan Du") 239 | ib.Values(2, "Charmy Liu") 240 | 241 | // Count the number of values. 242 | fmt.Println(ib.NumValue()) 243 | 244 | // Output: 245 | // 2 246 | } 247 | 248 | func ExampleInsertBuilder_Returning() { 249 | sql, args := InsertInto("user"). 250 | Cols("name").Values("Huan Du"). 251 | Returning("id"). 252 | BuildWithFlavor(PostgreSQL) 253 | 254 | fmt.Println(sql) 255 | fmt.Println(args) 256 | 257 | // Output: 258 | // INSERT INTO user (name) VALUES ($1) RETURNING id 259 | // [Huan Du] 260 | } 261 | 262 | func TestInsertBuilderReturning(test *testing.T) { 263 | a := assert.New(test) 264 | ib := InsertInto("user"). 265 | Cols("name").Values("Huan Du"). 
266 | Returning("id") 267 | 268 | sql, _ := ib.BuildWithFlavor(MySQL) 269 | a.Equal("INSERT INTO user (name) VALUES (?)", sql) 270 | 271 | sql, _ = ib.BuildWithFlavor(PostgreSQL) 272 | a.Equal("INSERT INTO user (name) VALUES ($1) RETURNING id", sql) 273 | 274 | sql, _ = ib.BuildWithFlavor(SQLite) 275 | a.Equal("INSERT INTO user (name) VALUES (?) RETURNING id", sql) 276 | 277 | sql, _ = ib.BuildWithFlavor(SQLServer) 278 | a.Equal("INSERT INTO user (name) VALUES (@p1)", sql) 279 | 280 | sql, _ = ib.BuildWithFlavor(CQL) 281 | a.Equal("INSERT INTO user (name) VALUES (?)", sql) 282 | 283 | sql, _ = ib.BuildWithFlavor(ClickHouse) 284 | a.Equal("INSERT INTO user (name) VALUES (?)", sql) 285 | 286 | sql, _ = ib.BuildWithFlavor(Presto) 287 | a.Equal("INSERT INTO user (name) VALUES (?)", sql) 288 | 289 | sql, _ = ib.BuildWithFlavor(Oracle) 290 | a.Equal("INSERT INTO user (name) VALUES (:1)", sql) 291 | 292 | sql, _ = ib.BuildWithFlavor(Informix) 293 | a.Equal("INSERT INTO user (name) VALUES (?)", sql) 294 | } 295 | 296 | func TestInsertBuilderGetFlavor(t *testing.T) { 297 | a := assert.New(t) 298 | ib := newInsertBuilder() 299 | 300 | ib.SetFlavor(PostgreSQL) 301 | flavor := ib.Flavor() 302 | a.Equal(PostgreSQL, flavor) 303 | 304 | ibClick := ClickHouse.NewInsertBuilder() 305 | flavor = ibClick.Flavor() 306 | a.Equal(ClickHouse, flavor) 307 | } 308 | -------------------------------------------------------------------------------- /interpolate_test.go: -------------------------------------------------------------------------------- 1 | package sqlbuilder 2 | 3 | import ( 4 | "database/sql/driver" 5 | "errors" 6 | "fmt" 7 | "strconv" 8 | "testing" 9 | "time" 10 | 11 | "github.com/huandu/go-assert" 12 | ) 13 | 14 | type errorValuer int 15 | 16 | var ErrErrorValuer = errors.New("error valuer") 17 | 18 | func (v errorValuer) Value() (driver.Value, error) { 19 | return 0, ErrErrorValuer 20 | } 21 | 22 | func TestFlavorInterpolate(t *testing.T) { 23 | dt := time.Date(2019, 4, 24, 
12, 23, 34, 123456789, time.FixedZone("CST", 8*60*60)) // 2019-04-24 12:23:34.987654321 CST 24 | _, errOutOfRange := strconv.ParseInt("12345678901234567890", 10, 32) 25 | byteArr := [...]byte{'f', 'o', 'o'} 26 | cases := []struct { 27 | Flavor Flavor 28 | SQL string 29 | Args []interface{} 30 | Query string 31 | Err error 32 | }{ 33 | { 34 | MySQL, 35 | "SELECT * FROM a WHERE name = ? AND state IN (?, ?, ?, ?, ?)", []interface{}{"I'm fine", 42, int8(8), int16(-16), int32(32), int64(64)}, 36 | "SELECT * FROM a WHERE name = 'I\\'m fine' AND state IN (42, 8, -16, 32, 64)", nil, 37 | }, 38 | { 39 | MySQL, 40 | "SELECT * FROM `a?` WHERE name = \"?\" AND state IN (?, '?', ?, ?, ?, ?, ?)", []interface{}{"\r\n\b\t\x1a\x00\\\"'", uint(42), uint8(8), uint16(16), uint32(32), uint64(64), "useless"}, 41 | "SELECT * FROM `a?` WHERE name = \"?\" AND state IN ('\\r\\n\\b\\t\\Z\\0\\\\\\\"\\'', '?', 42, 8, 16, 32, 64)", nil, 42 | }, 43 | { 44 | MySQL, 45 | "SELECT ?, ?, ?, ?, ?, ?, ?, ?, ?", []interface{}{true, false, float32(1.234567), float64(9.87654321), []byte(nil), []byte("I'm bytes"), dt, time.Time{}, nil}, 46 | "SELECT TRUE, FALSE, 1.234567, 9.87654321, NULL, _binary'I\\'m bytes', '2019-04-24 12:23:34.123457', '0000-00-00', NULL", nil, 47 | }, 48 | { 49 | MySQL, 50 | "SELECT '\\'?', \"\\\"?\", `\\`?`, \\?", []interface{}{MySQL}, 51 | "SELECT '\\'?', \"\\\"?\", `\\`?`, \\'MySQL'", nil, 52 | }, 53 | { 54 | MySQL, 55 | "SELECT ?", []interface{}{byteArr}, 56 | "SELECT _binary'foo'", nil, 57 | }, 58 | { 59 | MySQL, 60 | "SELECT ?", nil, 61 | "", ErrInterpolateMissingArgs, 62 | }, 63 | { 64 | MySQL, 65 | "SELECT ?", []interface{}{complex(1, 2)}, 66 | "", ErrInterpolateUnsupportedArgs, 67 | }, 68 | { 69 | MySQL, 70 | "SELECT ?", []interface{}{[]complex128{complex(1, 2)}}, 71 | "", ErrInterpolateUnsupportedArgs, 72 | }, 73 | { 74 | MySQL, 75 | "SELECT ?", []interface{}{errorValuer(1)}, 76 | "", ErrErrorValuer, 77 | }, 78 | 79 | { 80 | PostgreSQL, 81 | "SELECT * FROM a WHERE name = $3 
AND state IN ($2, $4, $1, $6, $5)", []interface{}{"I'm fine", 42, int8(8), int16(-16), int32(32), int64(64)}, 82 | "SELECT * FROM a WHERE name = 8 AND state IN (42, -16, E'I\\'m fine', 64, 32)", nil, 83 | }, 84 | { 85 | PostgreSQL, 86 | "SELECT * FROM $abc$$1$abc$1$1 WHERE name = \"$1\" AND state IN ($2, '$1', $3, $6, $5, $4, $2) $3", []interface{}{"\r\n\b\t\x1a\x00\\\"'", uint(42), uint8(8), uint16(16), uint32(32), uint64(64), "useless"}, 87 | "SELECT * FROM $abc$$1$abc$1E'\\r\\n\\b\\t\\Z\\0\\\\\\\"\\'' WHERE name = \"$1\" AND state IN (42, '$1', 8, 64, 32, 16, 42) 8", nil, 88 | }, 89 | { 90 | PostgreSQL, 91 | "SELECT $1, $2, $3, $4, $5, $6, $7, $8, $9, $11, $a", []interface{}{true, false, float32(1.234567), float64(9.87654321), []byte(nil), []byte("I'm bytes"), dt, time.Time{}, nil, 10, 11, 12}, 92 | "SELECT TRUE, FALSE, 1.234567, 9.87654321, NULL, E'\\\\x49276D206279746573'::bytea, '2019-04-24 12:23:34.123457 CST', '0000-00-00', NULL, 11, $a", nil, 93 | }, 94 | { 95 | PostgreSQL, 96 | "SELECT '\\'$1', \"\\\"$1\", `$1`, \\$1a, $$1$$, $a $b$ $a $ $1$b$1$1 $a$ $", []interface{}{MySQL}, 97 | "SELECT '\\'$1', \"\\\"$1\", `E'MySQL'`, \\E'MySQL'a, $$1$$, $a $b$ $a $ $1$b$1E'MySQL' $a$ $", nil, 98 | }, 99 | { 100 | PostgreSQL, 101 | "SELECT * FROM a WHERE name = 'Huan''Du''$1' AND desc = $1", []interface{}{"c'mon"}, 102 | "SELECT * FROM a WHERE name = 'Huan''Du''$1' AND desc = E'c\\'mon'", nil, 103 | }, 104 | { 105 | PostgreSQL, 106 | "SELECT $1", nil, 107 | "", ErrInterpolateMissingArgs, 108 | }, 109 | { 110 | PostgreSQL, 111 | "SELECT $1", []interface{}{complex(1, 2)}, 112 | "", ErrInterpolateUnsupportedArgs, 113 | }, 114 | { 115 | PostgreSQL, 116 | "SELECT $12345678901234567890", nil, 117 | "", errOutOfRange, 118 | }, 119 | 120 | { 121 | SQLite, 122 | "SELECT * FROM a WHERE name = ? 
AND state IN (?, ?, ?, ?, ?)", []interface{}{"I'm fine", 42, int8(8), int16(-16), int32(32), int64(64)}, 123 | "SELECT * FROM a WHERE name = 'I\\'m fine' AND state IN (42, 8, -16, 32, 64)", nil, 124 | }, 125 | { 126 | SQLite, 127 | "SELECT * FROM `a?` WHERE name = \"?\" AND state IN (?, '?', ?, ?, ?, ?, ?)", []interface{}{"\r\n\b\t\x1a\x00\\\"'", uint(42), uint8(8), uint16(16), uint32(32), uint64(64), "useless"}, 128 | "SELECT * FROM `a?` WHERE name = \"?\" AND state IN ('\\r\\n\\b\\t\\Z\\0\\\\\\\"\\'', '?', 42, 8, 16, 32, 64)", nil, 129 | }, 130 | { 131 | SQLite, 132 | "SELECT ?, ?, ?, ?, ?, ?, ?, ?, ?", []interface{}{true, false, float32(1.234567), float64(9.87654321), []byte(nil), []byte("I'm bytes"), dt, time.Time{}, nil}, 133 | "SELECT TRUE, FALSE, 1.234567, 9.87654321, NULL, X'49276D206279746573', '2019-04-24 12:23:34.123', '0000-00-00', NULL", nil, 134 | }, 135 | { 136 | SQLite, 137 | "SELECT '\\'?', \"\\\"?\", `\\`?`, \\?", []interface{}{SQLite}, 138 | "SELECT '\\'?', \"\\\"?\", `\\`?`, \\'SQLite'", nil, 139 | }, 140 | 141 | { 142 | SQLServer, 143 | "SELECT * FROM a WHERE name = @p1 AND state IN (@p3, @P2, @p4, @P6, @p5)", []interface{}{"I'm fine", 42, int8(8), int16(-16), int32(32), int64(64)}, 144 | "SELECT * FROM a WHERE name = N'I\\'m fine' AND state IN (8, 42, -16, 64, 32)", nil, 145 | }, 146 | { 147 | SQLServer, 148 | "SELECT * FROM \"a@p1\" WHERE name = '@p1' AND state IN (@p2, '@p1', @p1, @p3, @p4, @p5, @p6)", []interface{}{"\r\n\b\t\x1a\x00\\\"'", uint(42), uint8(8), uint16(16), uint32(32), uint64(64), "useless"}, 149 | "SELECT * FROM \"a@p1\" WHERE name = '@p1' AND state IN (42, '@p1', N'\\r\\n\\b\\t\\Z\\0\\\\\\\"\\'', 8, 16, 32, 64)", nil, 150 | }, 151 | { 152 | SQLServer, 153 | "SELECT @p1, @p2, @p3, @p4, @p5, @p6, @p7, @p8, @p9", []interface{}{true, false, float32(1.234567), float64(9.87654321), []byte(nil), []byte("I'm bytes"), dt, time.Time{}, nil}, 154 | "SELECT TRUE, FALSE, 1.234567, 9.87654321, NULL, 0x49276D206279746573, '2019-04-24 
12:23:34.123457 +08:00', '0000-00-00', NULL", nil, 155 | }, 156 | { 157 | SQLServer, 158 | "SELECT '\\'@p1', \"\\\"@p1\", \\@p1, @abc", []interface{}{SQLServer}, 159 | "SELECT '\\'@p1', \"\\\"@p1\", \\N'SQLServer', @abc", nil, 160 | }, 161 | { 162 | SQLServer, 163 | "SELECT @p1", nil, 164 | "", ErrInterpolateMissingArgs, 165 | }, 166 | { 167 | CQL, 168 | "SELECT * FROM a WHERE name = ? AND state IN (?, ?, ?, ?, ?)", []interface{}{"I'm fine", 42, int8(8), int16(-16), int32(32), int64(64)}, 169 | "SELECT * FROM a WHERE name = 'I''m fine' AND state IN (42, 8, -16, 32, 64)", nil, 170 | }, 171 | { 172 | CQL, 173 | "SELECT * FROM `a?` WHERE name = \"?\" AND state IN (?, '?', ?, ?, ?, ?, ?)", []interface{}{"\r\n\b\t\x1a\x00\\\"'", uint(42), uint8(8), uint16(16), uint32(32), uint64(64), "useless"}, 174 | "SELECT * FROM `a?` WHERE name = \"?\" AND state IN ('\\r\\n\\b\\t\\Z\\0\\\\\\\"''', '?', 42, 8, 16, 32, 64)", nil, 175 | }, 176 | { 177 | CQL, 178 | "SELECT ?, ?, ?, ?, ?, ?, ?, ?, ?", []interface{}{true, false, float32(1.234567), float64(9.87654321), []byte(nil), []byte("I'm bytes"), dt, time.Time{}, nil}, 179 | "SELECT TRUE, FALSE, 1.234567, 9.87654321, NULL, 0x49276D206279746573, '2019-04-24 12:23:34.123457+0800', '0000-00-00', NULL", nil, 180 | }, 181 | { 182 | CQL, 183 | "SELECT '\\'?', \"\\\"?\", `\\`?`, \\?", []interface{}{CQL}, 184 | "SELECT '\\'?', \"\\\"?\", `\\`?`, \\'CQL'", nil, 185 | }, 186 | { 187 | CQL, 188 | "SELECT ?", nil, 189 | "", ErrInterpolateMissingArgs, 190 | }, 191 | { 192 | CQL, 193 | "SELECT ?", []interface{}{complex(1, 2)}, 194 | "", ErrInterpolateUnsupportedArgs, 195 | }, 196 | { 197 | ClickHouse, 198 | "SELECT * FROM a WHERE name = ? 
AND state IN (?, ?, ?, ?, ?)", []interface{}{"I'm fine", 42, int8(8), int16(-16), int32(32), int64(64)}, 199 | "SELECT * FROM a WHERE name = 'I\\'m fine' AND state IN (42, 8, -16, 32, 64)", nil, 200 | }, 201 | { 202 | ClickHouse, 203 | "SELECT * FROM `a?` WHERE name = \"?\" AND state IN (?, '?', ?, ?, ?, ?, ?)", []interface{}{"\r\n\b\t\x1a\x00\\\"'", uint(42), uint8(8), uint16(16), uint32(32), uint64(64), "useless"}, 204 | "SELECT * FROM `a?` WHERE name = \"?\" AND state IN ('\\r\\n\\b\\t\\Z\\0\\\\\\\"\\'', '?', 42, 8, 16, 32, 64)", nil, 205 | }, 206 | { 207 | ClickHouse, 208 | "SELECT ?, ?, ?, ?, ?, ?, ?, ?, ?", []interface{}{true, false, float32(1.234567), 9.87654321, []byte(nil), []byte("I'm bytes"), dt, time.Time{}, nil}, 209 | "SELECT TRUE, FALSE, 1.234567, 9.87654321, NULL, unhex('49276D206279746573'), '2019-04-24 12:23:34.123457', '0000-00-00', NULL", nil, 210 | }, 211 | { 212 | ClickHouse, 213 | "SELECT '\\'?', \"\\\"?\", `\\`?`, \\?", []interface{}{MySQL}, 214 | "SELECT '\\'?', \"\\\"?\", `\\`?`, \\'MySQL'", nil, 215 | }, 216 | { 217 | ClickHouse, 218 | "SELECT ?", []interface{}{byteArr}, 219 | "SELECT unhex('666F6F')", nil, 220 | }, 221 | { 222 | ClickHouse, 223 | "SELECT ?", nil, 224 | "", ErrInterpolateMissingArgs, 225 | }, 226 | { 227 | ClickHouse, 228 | "SELECT ?", []interface{}{complex(1, 2)}, 229 | "", ErrInterpolateUnsupportedArgs, 230 | }, 231 | { 232 | ClickHouse, 233 | "SELECT ?", []interface{}{[]complex128{complex(1, 2)}}, 234 | "", ErrInterpolateUnsupportedArgs, 235 | }, 236 | { 237 | ClickHouse, 238 | "SELECT ?", []interface{}{errorValuer(1)}, 239 | "", ErrErrorValuer, 240 | }, 241 | { 242 | Presto, 243 | "SELECT * FROM a WHERE name = ? 
AND state IN (?, ?, ?, ?, ?)", []interface{}{"I'm fine", 42, int8(8), int16(-16), int32(32), int64(64)}, 244 | "SELECT * FROM a WHERE name = 'I\\'m fine' AND state IN (42, 8, -16, 32, 64)", nil, 245 | }, 246 | { 247 | Presto, 248 | "SELECT * FROM `a?` WHERE name = \"?\" AND state IN (?, '?', ?, ?, ?, ?, ?)", []interface{}{"\r\n\b\t\x1a\x00\\\"'", uint(42), uint8(8), uint16(16), uint32(32), uint64(64), "useless"}, 249 | "SELECT * FROM `a?` WHERE name = \"?\" AND state IN ('\\r\\n\\b\\t\\Z\\0\\\\\\\"\\'', '?', 42, 8, 16, 32, 64)", nil, 250 | }, 251 | { 252 | Presto, 253 | "SELECT ?, ?, ?, ?, ?, ?, ?, ?, ?", []interface{}{true, false, float32(1.234567), 9.87654321, []byte(nil), []byte("I'm bytes"), dt, time.Time{}, nil}, 254 | "SELECT TRUE, FALSE, 1.234567, 9.87654321, NULL, from_hex('49276D206279746573'), '2019-04-24 12:23:34.123', '0000-00-00', NULL", nil, 255 | }, 256 | { 257 | Presto, 258 | "SELECT '\\'?', \"\\\"?\", `\\`?`, \\?", []interface{}{MySQL}, 259 | "SELECT '\\'?', \"\\\"?\", `\\`?`, \\'MySQL'", nil, 260 | }, 261 | { 262 | Presto, 263 | "SELECT ?", []interface{}{byteArr}, 264 | "SELECT from_hex('666F6F')", nil, 265 | }, 266 | { 267 | Presto, 268 | "SELECT ?", nil, 269 | "", ErrInterpolateMissingArgs, 270 | }, 271 | { 272 | Presto, 273 | "SELECT ?", []interface{}{complex(1, 2)}, 274 | "", ErrInterpolateUnsupportedArgs, 275 | }, 276 | { 277 | Presto, 278 | "SELECT ?", []interface{}{[]complex128{complex(1, 2)}}, 279 | "", ErrInterpolateUnsupportedArgs, 280 | }, 281 | { 282 | Presto, 283 | "SELECT ?", []interface{}{errorValuer(1)}, 284 | "", ErrErrorValuer, 285 | }, 286 | 287 | { 288 | Oracle, 289 | "SELECT * FROM a WHERE name = :3 AND state IN (:2, :4, :1, :6, :5)", []interface{}{"I'm fine", 42, int8(8), int16(-16), int32(32), int64(64)}, 290 | "SELECT * FROM a WHERE name = 8 AND state IN (42, -16, 'I\\'m fine', 64, 32)", nil, 291 | }, 292 | { 293 | Oracle, 294 | "SELECT * FROM :abc::1:abc:1:1 WHERE name = \":1\" AND state IN (:2, ':1', :3, :6, :5, :4, :2) 
:3", []interface{}{"\r\n\b\t\x1a\x00\\\"'", uint(42), uint8(8), uint16(16), uint32(32), uint64(64), "useless"}, 295 | "SELECT * FROM :abc::1:abc:1'\\r\\n\\b\\t\\Z\\0\\\\\\\"\\'' WHERE name = \":1\" AND state IN (42, ':1', 8, 64, 32, 16, 42) 8", nil, 296 | }, 297 | { 298 | Oracle, 299 | "SELECT :1, :2, :3, :4, :5, :6, :7, :8, :9, :11, :a", []interface{}{true, false, float32(1.234567), float64(9.87654321), []byte(nil), []byte("I'm bytes"), dt, time.Time{}, nil, 10, 11, 12}, 300 | "SELECT 1, 0, 1.234567, 9.87654321, NULL, hextoraw('49276D206279746573'), to_timestamp('2019-04-24 12:23:34.123457', 'YYYY-MM-DD HH24:MI:SS.FF'), '0000-00-00', NULL, 11, :a", nil, 301 | }, 302 | { 303 | Oracle, 304 | "SELECT '\\':1', \"\\\":1\", `:1`, \\:1a, ::1::, :a :b: :a : :1:b:1:1 :a: :", []interface{}{Oracle}, 305 | "SELECT '\\':1', \"\\\":1\", `'Oracle'`, \\'Oracle'a, ::1::, :a :b: :a : :1:b:1'Oracle' :a: :", nil, 306 | }, 307 | { 308 | Oracle, 309 | "SELECT * FROM a WHERE name = 'Huan''Du'':1' AND desc = :1", []interface{}{"c'mon"}, 310 | "SELECT * FROM a WHERE name = 'Huan''Du'':1' AND desc = 'c\\'mon'", nil, 311 | }, 312 | { 313 | Oracle, 314 | "SELECT :1", nil, 315 | "", ErrInterpolateMissingArgs, 316 | }, 317 | { 318 | Oracle, 319 | "SELECT :1", []interface{}{complex(1, 2)}, 320 | "", ErrInterpolateUnsupportedArgs, 321 | }, 322 | { 323 | Oracle, 324 | "SELECT :12345678901234567890", nil, 325 | "", errOutOfRange, 326 | }, 327 | { 328 | Informix, 329 | "SELECT * FROM a WHERE name = ? 
AND state IN (?, ?, ?, ?, ?)", []interface{}{"I'm fine", 42, int8(8), int16(-16), int32(32), int64(64)}, 330 | "SELECT * FROM a WHERE name = 'I\\'m fine' AND state IN (42, 8, -16, 32, 64)", nil, 331 | }, 332 | { 333 | Informix, 334 | "SELECT * FROM `a?` WHERE name = \"?\" AND state IN (?, '?', ?, ?, ?, ?, ?)", []interface{}{"\r\n\b\t\x1a\x00\\\"'", uint(42), uint8(8), uint16(16), uint32(32), uint64(64), "useless"}, 335 | "SELECT * FROM `a?` WHERE name = \"?\" AND state IN ('\\r\\n\\b\\t\\Z\\0\\\\\\\"\\'', '?', 42, 8, 16, 32, 64)", nil, 336 | }, 337 | { 338 | Informix, 339 | "SELECT ?, ?, ?, ?, ?, ?, ?, ?, ?", []interface{}{true, false, float32(1.234567), float64(9.87654321), []byte(nil), []byte("I'm bytes"), dt, time.Time{}, nil}, 340 | // "SELECT TRUE, FALSE, 1.234567, 9.87654321, NULL, _binary'I\\'m bytes', '2019-04-24 12:23:34.123457', '0000-00-00', NULL", nil, 341 | "", ErrInterpolateUnsupportedArgs, 342 | }, 343 | { 344 | Informix, 345 | "SELECT '\\'?', \"\\\"?\", `\\`?`, \\?", []interface{}{Informix}, 346 | "SELECT '\\'?', \"\\\"?\", `\\`?`, \\'Informix'", nil, 347 | }, 348 | { 349 | Informix, 350 | "SELECT ?", []interface{}{byteArr}, 351 | // "SELECT _binary'foo'", nil, 352 | "", ErrInterpolateUnsupportedArgs, 353 | }, 354 | { 355 | Informix, 356 | "SELECT ?", nil, 357 | "", ErrInterpolateMissingArgs, 358 | }, 359 | { 360 | Informix, 361 | "SELECT ?", []interface{}{complex(1, 2)}, 362 | "", ErrInterpolateUnsupportedArgs, 363 | }, 364 | { 365 | Informix, 366 | "SELECT ?", []interface{}{[]complex128{complex(1, 2)}}, 367 | "", ErrInterpolateUnsupportedArgs, 368 | }, 369 | { 370 | Informix, 371 | "SELECT ?", []interface{}{errorValuer(1)}, 372 | "", ErrErrorValuer, 373 | }, 374 | { 375 | Doris, 376 | "SELECT ?", []interface{}{errorValuer(1)}, 377 | "", ErrErrorValuer, 378 | }, 379 | } 380 | 381 | for idx, c := range cases { 382 | t.Run(fmt.Sprintf("%s: %s", c.Flavor.String(), c.Query), func(t *testing.T) { 383 | a := assert.New(t) 384 | a.Use(&idx, &c) 385 | 
query, err := c.Flavor.Interpolate(c.SQL, c.Args) 386 | 387 | a.Equal(query, c.Query) 388 | a.Assert(err == c.Err || err.Error() == c.Err.Error()) 389 | }) 390 | } 391 | } 392 | -------------------------------------------------------------------------------- /modifiers.go: -------------------------------------------------------------------------------- 1 | // Copyright 2018 Huan Du. All rights reserved. 2 | // Licensed under the MIT license that can be found in the LICENSE file. 3 | 4 | package sqlbuilder 5 | 6 | import ( 7 | "reflect" 8 | "strings" 9 | ) 10 | 11 | // Escape replaces `$` with `$$` in ident. 12 | func Escape(ident string) string { 13 | return strings.Replace(ident, "$", "$$", -1) 14 | } 15 | 16 | // EscapeAll replaces `$` with `$$` in all strings of ident. 17 | func EscapeAll(ident ...string) []string { 18 | escaped := make([]string, 0, len(ident)) 19 | 20 | for _, i := range ident { 21 | escaped = append(escaped, Escape(i)) 22 | } 23 | 24 | return escaped 25 | } 26 | 27 | // Flatten recursively extracts values in slices and returns 28 | // a flattened []interface{} with all values. 29 | // If slices is not a slice, return `[]interface{}{slices}`. 30 | func Flatten(slices interface{}) (flattened []interface{}) { 31 | v := reflect.ValueOf(slices) 32 | slices, flattened = flatten(v) 33 | 34 | if slices != nil { 35 | return []interface{}{slices} 36 | } 37 | 38 | return flattened 39 | } 40 | 41 | func flatten(v reflect.Value) (elem interface{}, flattened []interface{}) { 42 | k := v.Kind() 43 | 44 | for k == reflect.Interface { 45 | v = v.Elem() 46 | k = v.Kind() 47 | } 48 | 49 | if k != reflect.Slice && k != reflect.Array { 50 | if !v.IsValid() || !v.CanInterface() { 51 | return 52 | } 53 | 54 | elem = v.Interface() 55 | return elem, nil 56 | } 57 | 58 | for i, l := 0, v.Len(); i < l; i++ { 59 | e, f := flatten(v.Index(i)) 60 | 61 | if e == nil { 62 | flattened = append(flattened, f...) 
63 | } else { 64 | flattened = append(flattened, e) 65 | } 66 | } 67 | 68 | return 69 | } 70 | 71 | type rawArgs struct { 72 | expr string 73 | } 74 | 75 | // Raw marks the expr as a raw value which will not be added to args. 76 | func Raw(expr string) interface{} { 77 | return rawArgs{expr} 78 | } 79 | 80 | type listArgs struct { 81 | args []interface{} 82 | isTuple bool 83 | } 84 | 85 | // List marks arg as a list of data. 86 | // If arg is `[]int{1, 2, 3}`, it will be compiled to `?, ?, ?` with args `[1 2 3]`. 87 | func List(arg interface{}) interface{} { 88 | return listArgs{ 89 | args: Flatten(arg), 90 | } 91 | } 92 | 93 | // Tuple wraps values into a tuple and can be used as a single value. 94 | func Tuple(values ...interface{}) interface{} { 95 | return listArgs{ 96 | args: values, 97 | isTuple: true, 98 | } 99 | } 100 | 101 | // TupleNames joins names with tuple format. 102 | // The names is not escaped. Use `EscapeAll` to escape them if necessary. 103 | func TupleNames(names ...string) string { 104 | buf := newStringBuilder() 105 | buf.WriteRune('(') 106 | buf.WriteStrings(names, ", ") 107 | buf.WriteRune(')') 108 | 109 | return buf.String() 110 | } 111 | 112 | type namedArgs struct { 113 | name string 114 | arg interface{} 115 | } 116 | 117 | // Named creates a named argument. 118 | // Unlike `sql.Named`, this named argument works only with `Build` or `BuildNamed` for convenience 119 | // and will be replaced to a `?` after `Compile`. 120 | func Named(name string, arg interface{}) interface{} { 121 | return namedArgs{ 122 | name: name, 123 | arg: arg, 124 | } 125 | } 126 | -------------------------------------------------------------------------------- /modifiers_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2018 Huan Du. All rights reserved. 2 | // Licensed under the MIT license that can be found in the LICENSE file. 
3 | 4 | package sqlbuilder 5 | 6 | import ( 7 | "fmt" 8 | "testing" 9 | 10 | "github.com/huandu/go-assert" 11 | ) 12 | 13 | func TestEscape(t *testing.T) { 14 | a := assert.New(t) 15 | cases := map[string]string{ 16 | "foo": "foo", 17 | "$foo": "$$foo", 18 | "$$$": "$$$$$$", 19 | } 20 | var inputs, expects []string 21 | 22 | for s, expected := range cases { 23 | inputs = append(inputs, s) 24 | expects = append(expects, expected) 25 | actual := Escape(s) 26 | 27 | a.Equal(actual, expected) 28 | } 29 | 30 | actuals := EscapeAll(inputs...) 31 | a.Equal(actuals, expects) 32 | } 33 | 34 | func TestFlatten(t *testing.T) { 35 | a := assert.New(t) 36 | cases := [][2]interface{}{ 37 | { 38 | "foo", 39 | []interface{}{"foo"}, 40 | }, 41 | { 42 | []int{1, 2, 3}, 43 | []interface{}{1, 2, 3}, 44 | }, 45 | { 46 | []interface{}{"abc", []int{1, 2, 3}, [3]string{"def", "ghi"}}, 47 | []interface{}{"abc", 1, 2, 3, "def", "ghi", ""}, 48 | }, 49 | } 50 | 51 | for _, c := range cases { 52 | input, expected := c[0], c[1] 53 | actual := Flatten(input) 54 | 55 | a.Equal(actual, expected) 56 | } 57 | } 58 | 59 | func TestTuple(t *testing.T) { 60 | a := assert.New(t) 61 | cases := []struct { 62 | values []interface{} 63 | expected string 64 | }{ 65 | { 66 | nil, 67 | "()", 68 | }, 69 | { 70 | []interface{}{1, "bar", nil, Tuple("foo", Tuple(2, "baz"))}, 71 | "(1, 'bar', NULL, ('foo', (2, 'baz')))", 72 | }, 73 | } 74 | 75 | for _, c := range cases { 76 | sql, args := Build("$?", Tuple(c.values...)).Build() 77 | actual, err := DefaultFlavor.Interpolate(sql, args) 78 | a.NilError(err) 79 | a.Equal(actual, c.expected) 80 | } 81 | } 82 | 83 | func ExampleTuple() { 84 | sb := Select("id", "name").From("user") 85 | sb.Where( 86 | sb.In( 87 | TupleNames("type", "status"), 88 | Tuple("web", 1), 89 | Tuple("app", 1), 90 | Tuple("app", 2), 91 | ), 92 | ) 93 | sql, args := sb.Build() 94 | 95 | fmt.Println(sql) 96 | fmt.Println(args) 97 | 98 | // Output: 99 | // SELECT id, name FROM user WHERE (type, 
status) IN ((?, ?), (?, ?), (?, ?)) 100 | // [web 1 app 1 app 2] 101 | } 102 | -------------------------------------------------------------------------------- /select.go: -------------------------------------------------------------------------------- 1 | // Copyright 2018 Huan Du. All rights reserved. 2 | // Licensed under the MIT license that can be found in the LICENSE file. 3 | 4 | package sqlbuilder 5 | 6 | import ( 7 | "fmt" 8 | "strings" 9 | ) 10 | 11 | const ( 12 | selectMarkerInit injectionMarker = iota 13 | selectMarkerAfterWith 14 | selectMarkerAfterSelect 15 | selectMarkerAfterFrom 16 | selectMarkerAfterJoin 17 | selectMarkerAfterWhere 18 | selectMarkerAfterGroupBy 19 | selectMarkerAfterOrderBy 20 | selectMarkerAfterLimit 21 | selectMarkerAfterFor 22 | ) 23 | 24 | // JoinOption is the option in JOIN. 25 | type JoinOption string 26 | 27 | // Join options. 28 | const ( 29 | FullJoin JoinOption = "FULL" 30 | FullOuterJoin JoinOption = "FULL OUTER" 31 | InnerJoin JoinOption = "INNER" 32 | LeftJoin JoinOption = "LEFT" 33 | LeftOuterJoin JoinOption = "LEFT OUTER" 34 | RightJoin JoinOption = "RIGHT" 35 | RightOuterJoin JoinOption = "RIGHT OUTER" 36 | ) 37 | 38 | // NewSelectBuilder creates a new SELECT builder. 39 | func NewSelectBuilder() *SelectBuilder { 40 | return DefaultFlavor.NewSelectBuilder() 41 | } 42 | 43 | func newSelectBuilder() *SelectBuilder { 44 | args := &Args{} 45 | proxy := &whereClauseProxy{} 46 | return &SelectBuilder{ 47 | whereClauseProxy: proxy, 48 | whereClauseExpr: args.Add(proxy), 49 | 50 | Cond: Cond{ 51 | Args: args, 52 | }, 53 | args: args, 54 | injection: newInjection(), 55 | } 56 | } 57 | 58 | // SelectBuilder is a builder to build SELECT. 
59 | type SelectBuilder struct { 60 | *WhereClause 61 | Cond 62 | 63 | whereClauseProxy *whereClauseProxy 64 | whereClauseExpr string 65 | 66 | cteBuilderVar string 67 | cteBuilder *CTEBuilder 68 | 69 | distinct bool 70 | tables []string 71 | selectCols []string 72 | joinOptions []JoinOption 73 | joinTables []string 74 | joinExprs [][]string 75 | havingExprs []string 76 | groupByCols []string 77 | orderByCols []string 78 | order string 79 | limitVar string 80 | offsetVar string 81 | forWhat string 82 | 83 | args *Args 84 | 85 | injection *injection 86 | marker injectionMarker 87 | } 88 | 89 | var _ Builder = new(SelectBuilder) 90 | 91 | // Select sets columns in SELECT. 92 | func Select(col ...string) *SelectBuilder { 93 | return DefaultFlavor.NewSelectBuilder().Select(col...) 94 | } 95 | 96 | // TableNames returns all table names in this SELECT statement. 97 | func (sb *SelectBuilder) TableNames() []string { 98 | var additionalTableNames []string 99 | 100 | if sb.cteBuilder != nil { 101 | additionalTableNames = sb.cteBuilder.tableNamesForFrom() 102 | } 103 | 104 | var tableNames []string 105 | 106 | if len(sb.tables) > 0 && len(additionalTableNames) > 0 { 107 | tableNames = make([]string, len(sb.tables)+len(additionalTableNames)) 108 | copy(tableNames, sb.tables) 109 | copy(tableNames[len(sb.tables):], additionalTableNames) 110 | } else if len(sb.tables) > 0 { 111 | tableNames = sb.tables 112 | } else if len(additionalTableNames) > 0 { 113 | tableNames = additionalTableNames 114 | } 115 | 116 | return tableNames 117 | } 118 | 119 | // With sets WITH clause (the Common Table Expression) before SELECT. 120 | func (sb *SelectBuilder) With(builder *CTEBuilder) *SelectBuilder { 121 | sb.marker = selectMarkerAfterWith 122 | sb.cteBuilderVar = sb.Var(builder) 123 | sb.cteBuilder = builder 124 | return sb 125 | } 126 | 127 | // Select sets columns in SELECT. 
128 | func (sb *SelectBuilder) Select(col ...string) *SelectBuilder { 129 | sb.selectCols = col 130 | sb.marker = selectMarkerAfterSelect 131 | return sb 132 | } 133 | 134 | // SelectMore adds more columns in SELECT. 135 | func (sb *SelectBuilder) SelectMore(col ...string) *SelectBuilder { 136 | sb.selectCols = append(sb.selectCols, col...) 137 | sb.marker = selectMarkerAfterSelect 138 | return sb 139 | } 140 | 141 | // Distinct marks this SELECT as DISTINCT. 142 | func (sb *SelectBuilder) Distinct() *SelectBuilder { 143 | sb.distinct = true 144 | sb.marker = selectMarkerAfterSelect 145 | return sb 146 | } 147 | 148 | // From sets table names in SELECT. 149 | func (sb *SelectBuilder) From(table ...string) *SelectBuilder { 150 | sb.tables = table 151 | sb.marker = selectMarkerAfterFrom 152 | return sb 153 | } 154 | 155 | // Join sets expressions of JOIN in SELECT. 156 | // 157 | // It builds a JOIN expression like 158 | // 159 | // JOIN table ON onExpr[0] AND onExpr[1] ... 160 | func (sb *SelectBuilder) Join(table string, onExpr ...string) *SelectBuilder { 161 | sb.marker = selectMarkerAfterJoin 162 | return sb.JoinWithOption("", table, onExpr...) 163 | } 164 | 165 | // JoinWithOption sets expressions of JOIN with an option. 166 | // 167 | // It builds a JOIN expression like 168 | // 169 | // option JOIN table ON onExpr[0] AND onExpr[1] ... 170 | // 171 | // Here is a list of supported options. 
172 | // - FullJoin: FULL JOIN 173 | // - FullOuterJoin: FULL OUTER JOIN 174 | // - InnerJoin: INNER JOIN 175 | // - LeftJoin: LEFT JOIN 176 | // - LeftOuterJoin: LEFT OUTER JOIN 177 | // - RightJoin: RIGHT JOIN 178 | // - RightOuterJoin: RIGHT OUTER JOIN 179 | func (sb *SelectBuilder) JoinWithOption(option JoinOption, table string, onExpr ...string) *SelectBuilder { 180 | sb.joinOptions = append(sb.joinOptions, option) 181 | sb.joinTables = append(sb.joinTables, table) 182 | sb.joinExprs = append(sb.joinExprs, onExpr) 183 | sb.marker = selectMarkerAfterJoin 184 | return sb 185 | } 186 | 187 | // Where sets expressions of WHERE in SELECT. 188 | func (sb *SelectBuilder) Where(andExpr ...string) *SelectBuilder { 189 | if len(andExpr) == 0 || estimateStringsBytes(andExpr) == 0 { 190 | return sb 191 | } 192 | 193 | if sb.WhereClause == nil { 194 | sb.WhereClause = NewWhereClause() 195 | } 196 | 197 | sb.WhereClause.AddWhereExpr(sb.args, andExpr...) 198 | sb.marker = selectMarkerAfterWhere 199 | return sb 200 | } 201 | 202 | // AddWhereClause adds all clauses in the whereClause to SELECT. 203 | func (sb *SelectBuilder) AddWhereClause(whereClause *WhereClause) *SelectBuilder { 204 | if sb.WhereClause == nil { 205 | sb.WhereClause = NewWhereClause() 206 | } 207 | 208 | sb.WhereClause.AddWhereClause(whereClause) 209 | return sb 210 | } 211 | 212 | // Having sets expressions of HAVING in SELECT. 213 | func (sb *SelectBuilder) Having(andExpr ...string) *SelectBuilder { 214 | sb.havingExprs = append(sb.havingExprs, andExpr...) 215 | sb.marker = selectMarkerAfterGroupBy 216 | return sb 217 | } 218 | 219 | // GroupBy sets columns of GROUP BY in SELECT. 220 | func (sb *SelectBuilder) GroupBy(col ...string) *SelectBuilder { 221 | sb.groupByCols = append(sb.groupByCols, col...) 222 | sb.marker = selectMarkerAfterGroupBy 223 | return sb 224 | } 225 | 226 | // OrderBy sets columns of ORDER BY in SELECT. 
227 | func (sb *SelectBuilder) OrderBy(col ...string) *SelectBuilder { 228 | sb.orderByCols = append(sb.orderByCols, col...) 229 | sb.marker = selectMarkerAfterOrderBy 230 | return sb 231 | } 232 | 233 | // Asc sets order of ORDER BY to ASC. 234 | func (sb *SelectBuilder) Asc() *SelectBuilder { 235 | sb.order = "ASC" 236 | sb.marker = selectMarkerAfterOrderBy 237 | return sb 238 | } 239 | 240 | // Desc sets order of ORDER BY to DESC. 241 | func (sb *SelectBuilder) Desc() *SelectBuilder { 242 | sb.order = "DESC" 243 | sb.marker = selectMarkerAfterOrderBy 244 | return sb 245 | } 246 | 247 | // Limit sets the LIMIT in SELECT. 248 | func (sb *SelectBuilder) Limit(limit int) *SelectBuilder { 249 | if limit < 0 { 250 | sb.limitVar = "" 251 | return sb 252 | } 253 | 254 | sb.limitVar = sb.Var(limit) 255 | sb.marker = selectMarkerAfterLimit 256 | return sb 257 | } 258 | 259 | // Offset sets the LIMIT offset in SELECT. 260 | func (sb *SelectBuilder) Offset(offset int) *SelectBuilder { 261 | if offset < 0 { 262 | sb.offsetVar = "" 263 | return sb 264 | } 265 | 266 | sb.offsetVar = sb.Var(offset) 267 | sb.marker = selectMarkerAfterLimit 268 | return sb 269 | } 270 | 271 | // ForUpdate adds FOR UPDATE at the end of SELECT statement. 272 | func (sb *SelectBuilder) ForUpdate() *SelectBuilder { 273 | sb.forWhat = "UPDATE" 274 | sb.marker = selectMarkerAfterFor 275 | return sb 276 | } 277 | 278 | // ForShare adds FOR SHARE at the end of SELECT statement. 279 | func (sb *SelectBuilder) ForShare() *SelectBuilder { 280 | sb.forWhat = "SHARE" 281 | sb.marker = selectMarkerAfterFor 282 | return sb 283 | } 284 | 285 | // As returns an AS expression. 286 | func (sb *SelectBuilder) As(name, alias string) string { 287 | return fmt.Sprintf("%s AS %s", name, alias) 288 | } 289 | 290 | // BuilderAs returns an AS expression wrapping a complex SQL. 291 | // According to SQL syntax, SQL built by builder is surrounded by parens. 
292 | func (sb *SelectBuilder) BuilderAs(builder Builder, alias string) string { 293 | return fmt.Sprintf("(%s) AS %s", sb.Var(builder), alias) 294 | } 295 | 296 | // LateralAs returns a LATERAL derived table expression wrapping a complex SQL. 297 | func (sb *SelectBuilder) LateralAs(builder Builder, alias string) string { 298 | return fmt.Sprintf("LATERAL (%s) AS %s", sb.Var(builder), alias) 299 | } 300 | 301 | // NumCol returns the number of columns to select. 302 | func (sb *SelectBuilder) NumCol() int { 303 | return len(sb.selectCols) 304 | } 305 | 306 | // String returns the compiled SELECT string. 307 | func (sb *SelectBuilder) String() string { 308 | s, _ := sb.Build() 309 | return s 310 | } 311 | 312 | // Build returns compiled SELECT string and args. 313 | // They can be used in `DB#Query` of package `database/sql` directly. 314 | func (sb *SelectBuilder) Build() (sql string, args []interface{}) { 315 | return sb.BuildWithFlavor(sb.args.Flavor) 316 | } 317 | 318 | // BuildWithFlavor returns compiled SELECT string and args with flavor and initial args. 319 | // They can be used in `DB#Query` of package `database/sql` directly. 
320 | func (sb *SelectBuilder) BuildWithFlavor(flavor Flavor, initialArg ...interface{}) (sql string, args []interface{}) { 321 | buf := newStringBuilder() 322 | sb.injection.WriteTo(buf, selectMarkerInit) 323 | 324 | oraclePage := flavor == Oracle && (len(sb.limitVar) > 0 || len(sb.offsetVar) > 0) 325 | 326 | if sb.cteBuilderVar != "" { 327 | buf.WriteLeadingString(sb.cteBuilderVar) 328 | sb.injection.WriteTo(buf, selectMarkerAfterWith) 329 | } 330 | 331 | if len(sb.selectCols) > 0 { 332 | buf.WriteLeadingString("SELECT ") 333 | 334 | if sb.distinct { 335 | buf.WriteString("DISTINCT ") 336 | } 337 | 338 | if oraclePage { 339 | var selectCols = make([]string, 0, len(sb.selectCols)) 340 | for i := range sb.selectCols { 341 | cols := strings.SplitN(sb.selectCols[i], ".", 2) 342 | 343 | if len(cols) == 1 { 344 | selectCols = append(selectCols, cols[0]) 345 | } else { 346 | selectCols = append(selectCols, cols[1]) 347 | } 348 | } 349 | buf.WriteStrings(selectCols, ", ") 350 | } else { 351 | buf.WriteStrings(sb.selectCols, ", ") 352 | } 353 | } 354 | 355 | sb.injection.WriteTo(buf, selectMarkerAfterSelect) 356 | 357 | if oraclePage { 358 | if len(sb.selectCols) > 0 { 359 | buf.WriteLeadingString("FROM (SELECT ") 360 | 361 | if sb.distinct { 362 | buf.WriteString("DISTINCT ") 363 | } 364 | 365 | var selectCols = make([]string, 0, len(sb.selectCols)+1) 366 | selectCols = append(selectCols, "ROWNUM r") 367 | 368 | for i := range sb.selectCols { 369 | cols := strings.SplitN(sb.selectCols[i], ".", 2) 370 | if len(cols) == 1 { 371 | selectCols = append(selectCols, cols[0]) 372 | } else { 373 | selectCols = append(selectCols, cols[1]) 374 | } 375 | } 376 | 377 | buf.WriteStrings(selectCols, ", ") 378 | buf.WriteLeadingString("FROM (SELECT ") 379 | buf.WriteStrings(sb.selectCols, ", ") 380 | } 381 | } 382 | 383 | tableNames := sb.TableNames() 384 | 385 | if len(tableNames) > 0 { 386 | buf.WriteLeadingString("FROM ") 387 | buf.WriteStrings(tableNames, ", ") 388 | } 389 | 390 | 
sb.injection.WriteTo(buf, selectMarkerAfterFrom) 391 | 392 | for i := range sb.joinTables { 393 | if option := sb.joinOptions[i]; option != "" { 394 | buf.WriteLeadingString(string(option)) 395 | } 396 | 397 | buf.WriteLeadingString("JOIN ") 398 | buf.WriteString(sb.joinTables[i]) 399 | 400 | if exprs := filterEmptyStrings(sb.joinExprs[i]); len(exprs) > 0 { 401 | buf.WriteString(" ON ") 402 | buf.WriteStrings(exprs, " AND ") 403 | } 404 | } 405 | 406 | if len(sb.joinTables) > 0 { 407 | sb.injection.WriteTo(buf, selectMarkerAfterJoin) 408 | } 409 | 410 | if sb.WhereClause != nil { 411 | sb.whereClauseProxy.WhereClause = sb.WhereClause 412 | defer func() { 413 | sb.whereClauseProxy.WhereClause = nil 414 | }() 415 | 416 | buf.WriteLeadingString(sb.whereClauseExpr) 417 | sb.injection.WriteTo(buf, selectMarkerAfterWhere) 418 | } 419 | 420 | if len(sb.groupByCols) > 0 { 421 | buf.WriteLeadingString("GROUP BY ") 422 | buf.WriteStrings(sb.groupByCols, ", ") 423 | 424 | if havingExprs := filterEmptyStrings(sb.havingExprs); len(havingExprs) > 0 { 425 | buf.WriteString(" HAVING ") 426 | buf.WriteStrings(havingExprs, " AND ") 427 | } 428 | 429 | sb.injection.WriteTo(buf, selectMarkerAfterGroupBy) 430 | } 431 | 432 | if len(sb.orderByCols) > 0 { 433 | buf.WriteLeadingString("ORDER BY ") 434 | buf.WriteStrings(sb.orderByCols, ", ") 435 | 436 | if sb.order != "" { 437 | buf.WriteRune(' ') 438 | buf.WriteString(sb.order) 439 | } 440 | 441 | sb.injection.WriteTo(buf, selectMarkerAfterOrderBy) 442 | } 443 | 444 | switch flavor { 445 | case MySQL, SQLite, ClickHouse: 446 | if len(sb.limitVar) > 0 { 447 | buf.WriteLeadingString("LIMIT ") 448 | buf.WriteString(sb.limitVar) 449 | 450 | if len(sb.offsetVar) > 0 { 451 | buf.WriteLeadingString("OFFSET ") 452 | buf.WriteString(sb.offsetVar) 453 | } 454 | } 455 | 456 | case CQL: 457 | if len(sb.limitVar) > 0 { 458 | buf.WriteLeadingString("LIMIT ") 459 | buf.WriteString(sb.limitVar) 460 | } 461 | 462 | case PostgreSQL: 463 | if 
len(sb.limitVar) > 0 { 464 | buf.WriteLeadingString("LIMIT ") 465 | buf.WriteString(sb.limitVar) 466 | } 467 | 468 | if len(sb.offsetVar) > 0 { 469 | buf.WriteLeadingString("OFFSET ") 470 | buf.WriteString(sb.offsetVar) 471 | } 472 | 473 | case Presto: 474 | // There might be a hidden constraint in Presto requiring offset to be set before limit. 475 | // The select statement documentation (https://prestodb.io/docs/current/sql/select.html) 476 | // puts offset before limit, and Trino, which is based on Presto, seems 477 | // to require this specific order. 478 | if len(sb.offsetVar) > 0 { 479 | buf.WriteLeadingString("OFFSET ") 480 | buf.WriteString(sb.offsetVar) 481 | } 482 | 483 | if len(sb.limitVar) > 0 { 484 | buf.WriteLeadingString("LIMIT ") 485 | buf.WriteString(sb.limitVar) 486 | } 487 | 488 | case SQLServer: 489 | // If ORDER BY is not set, sort column #1 by default. 490 | // It's required to make OFFSET...FETCH work. 491 | if len(sb.orderByCols) == 0 && (len(sb.limitVar) > 0 || len(sb.offsetVar) > 0) { 492 | buf.WriteLeadingString("ORDER BY 1") 493 | } 494 | 495 | if len(sb.offsetVar) > 0 { 496 | buf.WriteLeadingString("OFFSET ") 497 | buf.WriteString(sb.offsetVar) 498 | buf.WriteString(" ROWS") 499 | } 500 | 501 | if len(sb.limitVar) > 0 { 502 | if len(sb.offsetVar) == 0 { 503 | buf.WriteLeadingString("OFFSET 0 ROWS") 504 | } 505 | 506 | buf.WriteLeadingString("FETCH NEXT ") 507 | buf.WriteString(sb.limitVar) 508 | buf.WriteString(" ROWS ONLY") 509 | } 510 | 511 | case Oracle: 512 | if oraclePage { 513 | buf.WriteString(") ") 514 | 515 | if len(sb.tables) > 0 { 516 | buf.WriteStrings(sb.tables, ", ") 517 | } 518 | 519 | buf.WriteString(") WHERE ") 520 | 521 | if len(sb.limitVar) > 0 { 522 | buf.WriteString("r BETWEEN ") 523 | 524 | if len(sb.offsetVar) > 0 { 525 | buf.WriteString(sb.offsetVar) 526 | buf.WriteString(" + 1 AND ") 527 | buf.WriteString(sb.limitVar) 528 | buf.WriteString(" + ") 529 | buf.WriteString(sb.offsetVar) 530 | } else { 531 | 
buf.WriteString("1 AND ") 532 | buf.WriteString(sb.limitVar) 533 | buf.WriteString(" + 1") 534 | } 535 | } else { 536 | // As oraclePage is true, sb.offsetVar must not be empty. 537 | buf.WriteString("r >= ") 538 | buf.WriteString(sb.offsetVar) 539 | buf.WriteString(" + 1") 540 | } 541 | } 542 | 543 | case Informix: 544 | // [SKIP N] FIRST M 545 | // M must be greater than 0 546 | if len(sb.limitVar) > 0 { 547 | if len(sb.offsetVar) > 0 { 548 | buf.WriteLeadingString("SKIP ") 549 | buf.WriteString(sb.offsetVar) 550 | } 551 | 552 | buf.WriteLeadingString("FIRST ") 553 | buf.WriteString(sb.limitVar) 554 | } 555 | 556 | case Doris: 557 | // #192: Doris doesn't support ? in OFFSET and LIMIT. 558 | if len(sb.limitVar) > 0 { 559 | buf.WriteLeadingString("LIMIT ") 560 | buf.WriteString(fmt.Sprint(sb.args.Value(sb.limitVar))) 561 | 562 | if len(sb.offsetVar) > 0 { 563 | buf.WriteLeadingString("OFFSET ") 564 | buf.WriteString(fmt.Sprint(sb.args.Value(sb.offsetVar))) 565 | } 566 | } 567 | } 568 | 569 | if len(sb.limitVar) > 0 { 570 | sb.injection.WriteTo(buf, selectMarkerAfterLimit) 571 | } 572 | 573 | if sb.forWhat != "" { 574 | buf.WriteLeadingString("FOR ") 575 | buf.WriteString(sb.forWhat) 576 | 577 | sb.injection.WriteTo(buf, selectMarkerAfterFor) 578 | } 579 | 580 | return sb.args.CompileWithFlavor(buf.String(), flavor, initialArg...) 581 | } 582 | 583 | // SetFlavor sets the flavor of compiled sql. 584 | func (sb *SelectBuilder) SetFlavor(flavor Flavor) (old Flavor) { 585 | old = sb.args.Flavor 586 | sb.args.Flavor = flavor 587 | return 588 | } 589 | 590 | // Flavor returns flavor of builder 591 | func (sb *SelectBuilder) Flavor() Flavor { 592 | return sb.args.Flavor 593 | } 594 | 595 | // SQL adds an arbitrary sql to current position. 
func (sb *SelectBuilder) SQL(sql string) *SelectBuilder {
	sb.injection.SQL(sb.marker, sql)
	return sb
}

--------------------------------------------------------------------------------
/select_test.go:
--------------------------------------------------------------------------------
// Copyright 2018 Huan Du. All rights reserved.
// Licensed under the MIT license that can be found in the LICENSE file.

package sqlbuilder

import (
	"database/sql"
	"fmt"
	"testing"

	"github.com/huandu/go-assert"
)

func ExampleSelect() {
	// Build a SQL to query a HIVE table using MySQL-like SQL syntax.
	sql, args := Select("columns[0] id", "columns[1] name", "columns[2] year").
		From(MySQL.Quote("all-users.csv")).
		Limit(100).
		Build()

	fmt.Println(sql)
	fmt.Println(args)

	// Output:
	// SELECT columns[0] id, columns[1] name, columns[2] year FROM `all-users.csv` LIMIT ?
	// [100]
}

// ExampleSelectBuilder demonstrates DISTINCT, column aliases, nested AND/OR
// conditions, a sub-query inside NOT IN, GROUP BY/HAVING, ORDER BY and
// LIMIT/OFFSET in a single statement.
func ExampleSelectBuilder() {
	sb := NewSelectBuilder()
	sb.Distinct().Select("id", "name", sb.As("COUNT(*)", "t"))
	sb.From("demo.user")
	sb.Where(
		sb.GreaterThan("id", 1234),
		sb.Like("name", "%Du"),
		sb.Or(
			sb.IsNull("id_card"),
			sb.In("status", 1, 2, 5),
		),
		sb.NotIn(
			"id",
			NewSelectBuilder().Select("id").From("banned"),
		), // Nested SELECT.
		"modified_at > created_at + "+sb.Var(86400), // It's allowed to write arbitrary SQL.
	)
	sb.GroupBy("status").Having(sb.NotIn("status", 4, 5))
	sb.OrderBy("modified_at").Asc()
	sb.Limit(10).Offset(5)

	s, args := sb.Build()
	fmt.Println(s)
	fmt.Println(args)

	// Output:
	// SELECT DISTINCT id, name, COUNT(*) AS t FROM demo.user WHERE id > ? AND name LIKE ? AND (id_card IS NULL OR status IN (?, ?, ?)) AND id NOT IN (SELECT id FROM banned) AND modified_at > created_at + ? GROUP BY status HAVING status NOT IN (?, ?) ORDER BY modified_at ASC LIMIT ? OFFSET ?
	// [1234 %Du 1 2 5 86400 4 5 10 5]
}

// ExampleSelectBuilder_advancedUsage demonstrates sql.Named arguments, which
// are rendered as @name placeholders, and a sub-query used as a table via
// BuilderAs.
func ExampleSelectBuilder_advancedUsage() {
	sb := NewSelectBuilder()
	innerSb := NewSelectBuilder()

	// Named arguments are supported.
	start := sql.Named("start", 1234567890)
	end := sql.Named("end", 1234599999)
	level := sql.Named("level", 20)

	sb.Select("id", "name")
	sb.From(
		sb.BuilderAs(innerSb, "user"),
	)
	sb.Where(
		sb.In("status", Flatten([]int{1, 2, 3})...),
		sb.Between("created_at", start, end),
	)
	sb.OrderBy("modified_at").Desc()

	innerSb.Select("*")
	innerSb.From("banned")
	innerSb.Where(
		innerSb.GreaterThan("level", level),
		innerSb.LessEqualThan("updated_at", end),
		innerSb.NotIn("name", Flatten([]string{"Huan Du", "Charmy Liu"})...),
	)

	s, args := sb.Build()
	fmt.Println(s)
	fmt.Println(args)

	// Output:
	// SELECT id, name FROM (SELECT * FROM banned WHERE level > @level AND updated_at <= @end AND name NOT IN (?, ?)) AS user WHERE status IN (?, ?, ?) AND created_at BETWEEN @start AND @end ORDER BY modified_at DESC
	// [Huan Du Charmy Liu 1 2 3 {{} level 20} {{} end 1234599999} {{} start 1234567890}]
}

// ExampleSelectBuilder_join demonstrates a plain JOIN and a JOIN with an
// explicit option, both with extra ON conditions.
func ExampleSelectBuilder_join() {
	sb := NewSelectBuilder()
	sb.Select("u.id", "u.name", "c.type", "p.nickname")
	sb.From("user u")
	sb.Join("contract c",
		"u.id = c.user_id",
		sb.In("c.status", 1, 2, 5),
	)
	sb.JoinWithOption(RightOuterJoin, "person p",
		"u.id = p.user_id",
		sb.Like("p.surname", "%Du"),
	)
	sb.Where(
		"u.modified_at > u.created_at + " + sb.Var(86400), // It's allowed to write arbitrary SQL.
	)

	sql, args := sb.Build()
	fmt.Println(sql)
	fmt.Println(args)

	// Output:
	// SELECT u.id, u.name, c.type, p.nickname FROM user u JOIN contract c ON u.id = c.user_id AND c.status IN (?, ?, ?) RIGHT OUTER JOIN person p ON u.id = p.user_id AND p.surname LIKE ? WHERE u.modified_at > u.created_at + ?
	// [1 2 5 %Du 86400]
}

// ExampleSelectBuilder_limit_offset compares how every flavor renders
// LIMIT/OFFSET for each combination of set/unset values.
// Doris inlines the values as literals instead of using placeholders.
//
// NOTE(review): Oracle case #4 renders `r BETWEEN 1 AND :1 + 1`. ROWNUM starts
// at 1, so this selects limit+1 rows, while the offset cases select exactly
// limit rows. This looks like an off-by-one in SelectBuilder.BuildWithFlavor;
// confirm before relying on it.
func ExampleSelectBuilder_limit_offset() {
	flavors := []Flavor{MySQL, PostgreSQL, SQLite, SQLServer, CQL, ClickHouse, Presto, Oracle, Informix, Doris}
	results := make([][]string, len(flavors))
	sb := NewSelectBuilder()
	saveResults := func() {
		for i, f := range flavors {
			s, _ := sb.BuildWithFlavor(f)
			results[i] = append(results[i], s)
		}
	}

	sb.Select("*")
	sb.From("user")

	// Case #1: limit < 0 and offset < 0
	//
	// All: No limit or offset in query.
	sb.Limit(-1)
	sb.Offset(-1)
	saveResults()

	// Case #2: limit < 0 and offset >= 0
	//
	// MySQL and SQLite: Ignore offset if the limit is not set.
	// PostgreSQL: Offset can be set without limit.
	// SQLServer: Offset can be set without limit.
	// CQL: Ignore offset.
	// Oracle: Offset can be set without limit.
	// Presto: Offset can be set without limit.
	// ClickHouse, Informix and Doris: Ignore offset if the limit is not set.
	sb.Limit(-1)
	sb.Offset(0)
	saveResults()

	// Case #3: limit >= 0 and offset >= 0
	//
	// CQL: Ignore offset.
	// All others: Set both limit and offset.
	sb.Limit(1)
	sb.Offset(0)
	saveResults()

	// Case #4: limit >= 0 and offset < 0
	//
	// All: Set limit in query.
	sb.Limit(1)
	sb.Offset(-1)
	saveResults()

	// Case #5: limit >= 0 and offset >= 0 order by id
	//
	// CQL: Ignore offset.
	// All others: Set both limit and offset.
	sb.Limit(1)
	sb.Offset(1)
	sb.OrderBy("id")
	saveResults()

	for i, result := range results {
		fmt.Println()
		fmt.Println(flavors[i])

		for n, s := range result {
			fmt.Printf("#%d: %s\n", n+1, s)
		}
	}

	// Output:
	//
	// MySQL
	// #1: SELECT * FROM user
	// #2: SELECT * FROM user
	// #3: SELECT * FROM user LIMIT ? OFFSET ?
	// #4: SELECT * FROM user LIMIT ?
	// #5: SELECT * FROM user ORDER BY id LIMIT ? OFFSET ?
	//
	// PostgreSQL
	// #1: SELECT * FROM user
	// #2: SELECT * FROM user OFFSET $1
	// #3: SELECT * FROM user LIMIT $1 OFFSET $2
	// #4: SELECT * FROM user LIMIT $1
	// #5: SELECT * FROM user ORDER BY id LIMIT $1 OFFSET $2
	//
	// SQLite
	// #1: SELECT * FROM user
	// #2: SELECT * FROM user
	// #3: SELECT * FROM user LIMIT ? OFFSET ?
	// #4: SELECT * FROM user LIMIT ?
	// #5: SELECT * FROM user ORDER BY id LIMIT ? OFFSET ?
	//
	// SQLServer
	// #1: SELECT * FROM user
	// #2: SELECT * FROM user ORDER BY 1 OFFSET @p1 ROWS
	// #3: SELECT * FROM user ORDER BY 1 OFFSET @p1 ROWS FETCH NEXT @p2 ROWS ONLY
	// #4: SELECT * FROM user ORDER BY 1 OFFSET 0 ROWS FETCH NEXT @p1 ROWS ONLY
	// #5: SELECT * FROM user ORDER BY id OFFSET @p1 ROWS FETCH NEXT @p2 ROWS ONLY
	//
	// CQL
	// #1: SELECT * FROM user
	// #2: SELECT * FROM user
	// #3: SELECT * FROM user LIMIT ?
	// #4: SELECT * FROM user LIMIT ?
	// #5: SELECT * FROM user ORDER BY id LIMIT ?
	//
	// ClickHouse
	// #1: SELECT * FROM user
	// #2: SELECT * FROM user
	// #3: SELECT * FROM user LIMIT ? OFFSET ?
	// #4: SELECT * FROM user LIMIT ?
	// #5: SELECT * FROM user ORDER BY id LIMIT ? OFFSET ?
	//
	// Presto
	// #1: SELECT * FROM user
	// #2: SELECT * FROM user OFFSET ?
	// #3: SELECT * FROM user OFFSET ? LIMIT ?
	// #4: SELECT * FROM user LIMIT ?
	// #5: SELECT * FROM user ORDER BY id OFFSET ? LIMIT ?
	//
	// Oracle
	// #1: SELECT * FROM user
	// #2: SELECT * FROM (SELECT ROWNUM r, * FROM (SELECT * FROM user) user) WHERE r >= :1 + 1
	// #3: SELECT * FROM (SELECT ROWNUM r, * FROM (SELECT * FROM user) user) WHERE r BETWEEN :1 + 1 AND :2 + :3
	// #4: SELECT * FROM (SELECT ROWNUM r, * FROM (SELECT * FROM user) user) WHERE r BETWEEN 1 AND :1 + 1
	// #5: SELECT * FROM (SELECT ROWNUM r, * FROM (SELECT * FROM user ORDER BY id) user) WHERE r BETWEEN :1 + 1 AND :2 + :3
	//
	// Informix
	// #1: SELECT * FROM user
	// #2: SELECT * FROM user
	// #3: SELECT * FROM user SKIP ? FIRST ?
	// #4: SELECT * FROM user FIRST ?
	// #5: SELECT * FROM user ORDER BY id SKIP ? FIRST ?
	//
	// Doris
	// #1: SELECT * FROM user
	// #2: SELECT * FROM user
	// #3: SELECT * FROM user LIMIT 1 OFFSET 0
	// #4: SELECT * FROM user LIMIT 1
	// #5: SELECT * FROM user ORDER BY id LIMIT 1 OFFSET 1
}

// ExampleSelectBuilder_ForUpdate demonstrates a trailing FOR UPDATE clause.
func ExampleSelectBuilder_ForUpdate() {
	sb := newSelectBuilder()
	sb.Select("*").From("user").Where(
		sb.Equal("id", 1234),
	).ForUpdate()

	sql, args := sb.Build()
	fmt.Println(sql)
	fmt.Println(args)

	// Output:
	// SELECT * FROM user WHERE id = ? FOR UPDATE
	// [1234]
}

func ExampleSelectBuilder_varInCols() {
	// Column name may contain some characters, e.g. the $ sign, which have special meanings in builders.
	// It's recommended to call Escape() or EscapeAll() to escape the name.

	sb := NewSelectBuilder()
	v := sb.Var("foo")
	sb.Select(Escape("colHasA$Sign"), v)
	sb.From("table")

	s, args := sb.Build()
	fmt.Println(s)
	fmt.Println(args)

	// Output:
	// SELECT colHasA$Sign, ? FROM table
	// [foo]
}

// ExampleSelectBuilder_SQL shows where each SQL() call is injected: right
// after the clause that was set most recently.
func ExampleSelectBuilder_SQL() {
	sb := NewSelectBuilder()
	sb.SQL("/* before */")
	sb.Select("u.id", "u.name", "c.type", "p.nickname")
	sb.SQL("/* after select */")
	sb.From("user u")
	sb.SQL("/* after from */")
	sb.Join("contract c",
		"u.id = c.user_id",
	)
	sb.JoinWithOption(RightOuterJoin, "person p",
		"u.id = p.user_id",
	)
	sb.SQL("/* after join */")
	sb.Where(
		"u.modified_at > u.created_at",
	)
	sb.SQL("/* after where */")
	sb.OrderBy("id")
	sb.SQL("/* after order by */")
	sb.Limit(10)
	sb.SQL("/* after limit */")
	sb.ForShare()
	sb.SQL("/* after for */")

	s := sb.String()
	fmt.Println(s)

	// Output:
	// /* before */ SELECT u.id, u.name, c.type, p.nickname /* after select */ FROM user u /* after from */ JOIN contract c ON u.id = c.user_id RIGHT OUTER JOIN person p ON u.id = p.user_id /* after join */ WHERE u.modified_at > u.created_at /* after where */ ORDER BY id /* after order by */ LIMIT ? /* after limit */ FOR SHARE /* after for */
}

// Example for issue #115.
func ExampleSelectBuilder_customSELECT() {
	sb := NewSelectBuilder()

	// Set a custom SELECT clause.
	sb.SQL("SELECT id, name FROM user").Where(
		sb.In("id", 1, 2, 3),
	)

	s, args := sb.Build()
	fmt.Println(s)
	fmt.Println(args)

	// Output:
	// SELECT id, name FROM user WHERE id IN (?, ?, ?)
	// [1 2 3]
}

// ExampleSelectBuilder_NumCol shows that NumCol counts the selected columns.
func ExampleSelectBuilder_NumCol() {
	sb := NewSelectBuilder()
	sb.Select("id", "name", "created_at")
	sb.From("demo.user")
	sb.Where(
		sb.GreaterThan("id", 1234),
	)

	// Count the number of columns.
	fmt.Println(sb.NumCol())

	// Output:
	// 3
}

// ExampleSelectBuilder_With demonstrates CTEs: a CTEQuery is only defined in
// WITH, while a CTETable is also added to the FROM clause.
func ExampleSelectBuilder_With() {
	sql := With(
		CTEQuery("users").As(
			Select("id", "name").From("users").Where("prime IS NOT NULL"),
		),

		// The CTE table orders will be added to table list of FROM clause automatically.
		CTETable("orders").As(
			Select("id", "user_id").From("orders"),
		),
	).Select("orders.id").Join("users", "orders.user_id = users.id").Limit(10).String()

	fmt.Println(sql)

	// Output:
	// WITH users AS (SELECT id, name FROM users WHERE prime IS NOT NULL), orders AS (SELECT id, user_id FROM orders) SELECT orders.id FROM orders JOIN users ON orders.user_id = users.id LIMIT ?
}

// TestSelectBuilderSelectMore checks that SelectMore appends columns to the
// SELECT list without disturbing SQL injected at other positions.
func TestSelectBuilderSelectMore(t *testing.T) {
	a := assert.New(t)
	sb := Select("id").SQL("/* first */").Where(
		"name IS NOT NULL",
	).SQL("/* second */").SelectMore("name").SQL("/* third */")
	a.Equal(sb.String(), "SELECT id, name /* first */ /* third */ WHERE name IS NOT NULL /* second */")
}

// TestSelectBuilderGetFlavor checks Flavor after SetFlavor and after creating
// a builder from a specific flavor.
func TestSelectBuilderGetFlavor(t *testing.T) {
	a := assert.New(t)
	sb := newSelectBuilder()

	sb.SetFlavor(PostgreSQL)
	flavor := sb.Flavor()
	a.Equal(PostgreSQL, flavor)

	sbClick := ClickHouse.NewSelectBuilder()
	flavor = sbClick.Flavor()
	a.Equal(ClickHouse, flavor)
}

func ExampleSelectBuilder_LateralAs() {
	// Demo SQL comes from a sample on https://dev.mysql.com/doc/refman/8.4/en/lateral-derived-tables.html.
	sb := Select(
		"salesperson.name",
		"max_sale.amount",
		"max_sale.customer_name",
	)
	sb.From(
		"salesperson",
		sb.LateralAs(
			Select("amount", "customer_name").
				From("all_sales").
				Where(
					"all_sales.salesperson_id = salesperson.id",
				).
				OrderBy("amount").Desc().Limit(1),
			"max_sale",
		),
	)

	fmt.Println(sb)

	// Output:
	// SELECT salesperson.name, max_sale.amount, max_sale.customer_name FROM salesperson, LATERAL (SELECT amount, customer_name FROM all_sales WHERE all_sales.salesperson_id = salesperson.id ORDER BY amount DESC LIMIT ?) AS max_sale
}

--------------------------------------------------------------------------------
/stringbuilder.go:
--------------------------------------------------------------------------------
// Copyright 2023 Huan Du. All rights reserved.
// Licensed under the MIT license that can be found in the LICENSE file.

package sqlbuilder

import (
	"io"
	"strings"
)

// stringBuilder wraps a strings.Builder and adds helpers, such as
// WriteLeadingString and WriteStrings, used when assembling SQL text.
type stringBuilder struct {
	builder *strings.Builder
}

// stringBuilder must be usable wherever an io.Writer is expected.
var _ io.Writer = new(stringBuilder)

// newStringBuilder returns an empty stringBuilder ready for use.
func newStringBuilder() *stringBuilder {
	return &stringBuilder{
		builder: &strings.Builder{},
	}
}

// WriteLeadingString writes s to the internal buffer.
// If the buffer is not empty, a blank (" ") is written before s.
25 | func (sb *stringBuilder) WriteLeadingString(s string) { 26 | if sb.builder.Len() > 0 { 27 | sb.builder.WriteString(" ") 28 | } 29 | 30 | sb.builder.WriteString(s) 31 | } 32 | 33 | func (sb *stringBuilder) WriteString(s string) { 34 | sb.builder.WriteString(s) 35 | } 36 | 37 | func (sb *stringBuilder) WriteStrings(ss []string, sep string) { 38 | if len(ss) == 0 { 39 | return 40 | } 41 | 42 | firstAdded := false 43 | if len(ss[0]) != 0 { 44 | sb.WriteString(ss[0]) 45 | firstAdded = true 46 | } 47 | 48 | for _, s := range ss[1:] { 49 | if len(s) != 0 { 50 | if firstAdded { 51 | sb.WriteString(sep) 52 | } 53 | sb.WriteString(s) 54 | firstAdded = true 55 | } 56 | } 57 | } 58 | 59 | func (sb *stringBuilder) WriteRune(r rune) { 60 | sb.builder.WriteRune(r) 61 | } 62 | 63 | func (sb *stringBuilder) Write(data []byte) (int, error) { 64 | return sb.builder.Write(data) 65 | } 66 | 67 | func (sb *stringBuilder) String() string { 68 | return sb.builder.String() 69 | } 70 | 71 | func (sb *stringBuilder) Reset() { 72 | sb.builder.Reset() 73 | } 74 | 75 | func (sb *stringBuilder) Grow(n int) { 76 | sb.builder.Grow(n) 77 | } 78 | 79 | // filterEmptyStrings removes empty strings from ss. 80 | // As ss rarely contains empty strings, filterEmptyStrings tries to avoid allocation if possible. 
81 | func filterEmptyStrings(ss []string) []string { 82 | emptyStrings := 0 83 | 84 | for _, s := range ss { 85 | if len(s) == 0 { 86 | emptyStrings++ 87 | } 88 | } 89 | 90 | if emptyStrings == 0 { 91 | return ss 92 | } 93 | 94 | filtered := make([]string, 0, len(ss)-emptyStrings) 95 | 96 | for _, s := range ss { 97 | if len(s) != 0 { 98 | filtered = append(filtered, s) 99 | } 100 | } 101 | 102 | return filtered 103 | } 104 | -------------------------------------------------------------------------------- /structfields.go: -------------------------------------------------------------------------------- 1 | package sqlbuilder 2 | 3 | import ( 4 | "fmt" 5 | "reflect" 6 | "strings" 7 | "sync" 8 | ) 9 | 10 | type structFields struct { 11 | noTag *structTaggedFields 12 | tagged map[string]*structTaggedFields 13 | } 14 | 15 | type structTaggedFields struct { 16 | // All columns for SELECT. 17 | ForRead []*structField 18 | colsForRead map[string]*structField 19 | 20 | // All columns which can be used in INSERT and UPDATE. 
21 | ForWrite []*structField 22 | colsForWrite map[string]struct{} 23 | } 24 | 25 | type structField struct { 26 | Name string 27 | Alias string 28 | As string 29 | Tags []string 30 | IsQuoted bool 31 | DBTag string 32 | Field reflect.StructField 33 | 34 | omitEmptyTags omitEmptyTagMap 35 | } 36 | 37 | type structFieldsParser func() *structFields 38 | 39 | func makeDefaultFieldsParser(t reflect.Type) structFieldsParser { 40 | return makeFieldsParser(t, nil, true) 41 | } 42 | 43 | func makeCustomFieldsParser(t reflect.Type, mapper FieldMapperFunc) structFieldsParser { 44 | return makeFieldsParser(t, mapper, false) 45 | } 46 | 47 | func makeFieldsParser(t reflect.Type, mapper FieldMapperFunc, useDefault bool) structFieldsParser { 48 | var once sync.Once 49 | sfs := &structFields{ 50 | noTag: makeStructTaggedFields(), 51 | tagged: map[string]*structTaggedFields{}, 52 | } 53 | 54 | return func() *structFields { 55 | once.Do(func() { 56 | if useDefault { 57 | mapper = DefaultFieldMapper 58 | } 59 | 60 | sfs.parse(t, mapper, "") 61 | }) 62 | 63 | return sfs 64 | } 65 | } 66 | 67 | func (sfs *structFields) parse(t reflect.Type, mapper FieldMapperFunc, prefix string) { 68 | l := t.NumField() 69 | var anonymous []reflect.StructField 70 | 71 | for i := 0; i < l; i++ { 72 | field := t.Field(i) 73 | 74 | // Skip unexported fields that are not embedded structs. 75 | if field.PkgPath != "" && !field.Anonymous { 76 | continue 77 | } 78 | 79 | if field.Anonymous { 80 | ft := field.Type 81 | 82 | // If field is an anonymous struct or pointer to struct, parse it later. 83 | if k := ft.Kind(); k == reflect.Struct || (k == reflect.Ptr && ft.Elem().Kind() == reflect.Struct) { 84 | anonymous = append(anonymous, field) 85 | continue 86 | } 87 | } 88 | 89 | // Parse DBTag. 
90 | alias, dbtag := DefaultGetAlias(&field) 91 | 92 | if alias == "-" { 93 | continue 94 | } 95 | 96 | if alias == "" { 97 | alias = field.Name 98 | if mapper != nil { 99 | alias = mapper(alias) 100 | } 101 | } 102 | 103 | // Parse FieldOpt. 104 | fieldopt := field.Tag.Get(FieldOpt) 105 | opts := optRegex.FindAllString(fieldopt, -1) 106 | isQuoted := false 107 | omitEmptyTags := omitEmptyTagMap{} 108 | 109 | for _, opt := range opts { 110 | optMap := getOptMatchedMap(opt) 111 | 112 | switch optMap[optName] { 113 | case fieldOptOmitEmpty: 114 | tags := getTagsFromOptParams(optMap[optParams]) 115 | 116 | for _, tag := range tags { 117 | omitEmptyTags[tag] = struct{}{} 118 | } 119 | 120 | case fieldOptWithQuote: 121 | isQuoted = true 122 | } 123 | } 124 | 125 | // Parse FieldAs. 126 | fieldas := field.Tag.Get(FieldAs) 127 | 128 | // Parse FieldTag. 129 | fieldtag := field.Tag.Get(FieldTag) 130 | tags := splitTags(fieldtag) 131 | 132 | // Make struct field. 133 | structField := &structField{ 134 | Name: field.Name, 135 | Alias: alias, 136 | As: fieldas, 137 | Tags: tags, 138 | IsQuoted: isQuoted, 139 | DBTag: dbtag, 140 | Field: field, 141 | omitEmptyTags: omitEmptyTags, 142 | } 143 | 144 | // Make sure all fields can be added to noTag without conflict. 145 | sfs.noTag.Add(structField) 146 | 147 | for _, tag := range tags { 148 | sfs.taggedFields(tag).Add(structField) 149 | } 150 | } 151 | 152 | for _, field := range anonymous { 153 | ft := dereferencedType(field.Type) 154 | sfs.parse(ft, mapper, prefix+field.Name+".") 155 | } 156 | } 157 | 158 | func (sfs *structFields) FilterTags(with, without []string) *structTaggedFields { 159 | if len(with) == 0 && len(without) == 0 { 160 | return sfs.noTag 161 | } 162 | 163 | // Simply return the tagged fields. 164 | if len(with) == 1 && len(without) == 0 { 165 | return sfs.tagged[with[0]] 166 | } 167 | 168 | // Find out all with and without fields. 
169 | taggedFields := makeStructTaggedFields() 170 | filteredReadFields := make(map[string]struct{}, len(sfs.noTag.colsForRead)) 171 | 172 | for _, tag := range without { 173 | if field, ok := sfs.tagged[tag]; ok { 174 | for k := range field.colsForRead { 175 | filteredReadFields[k] = struct{}{} 176 | } 177 | } 178 | } 179 | 180 | if len(with) == 0 { 181 | for _, field := range sfs.noTag.ForRead { 182 | k := field.Key() 183 | 184 | if _, ok := filteredReadFields[k]; !ok { 185 | taggedFields.Add(field) 186 | } 187 | } 188 | } else { 189 | for _, tag := range with { 190 | if fields, ok := sfs.tagged[tag]; ok { 191 | for _, field := range fields.ForRead { 192 | k := field.Key() 193 | 194 | if _, ok := filteredReadFields[k]; !ok { 195 | taggedFields.Add(field) 196 | } 197 | } 198 | } 199 | } 200 | } 201 | 202 | return taggedFields 203 | } 204 | 205 | func (sfs *structFields) taggedFields(tag string) *structTaggedFields { 206 | fields, ok := sfs.tagged[tag] 207 | 208 | if !ok { 209 | fields = makeStructTaggedFields() 210 | sfs.tagged[tag] = fields 211 | } 212 | 213 | return fields 214 | } 215 | 216 | func makeStructTaggedFields() *structTaggedFields { 217 | return &structTaggedFields{ 218 | colsForRead: map[string]*structField{}, 219 | colsForWrite: map[string]struct{}{}, 220 | } 221 | } 222 | 223 | // Add a new field to stfs. 224 | // If field's key exists in stfs.fields, the field is ignored. 225 | func (stfs *structTaggedFields) Add(field *structField) { 226 | key := field.Key() 227 | 228 | if _, ok := stfs.colsForRead[key]; !ok { 229 | stfs.colsForRead[key] = field 230 | stfs.ForRead = append(stfs.ForRead, field) 231 | } 232 | 233 | key = field.Alias 234 | 235 | if _, ok := stfs.colsForWrite[key]; !ok { 236 | stfs.colsForWrite[key] = struct{}{} 237 | stfs.ForWrite = append(stfs.ForWrite, field) 238 | } 239 | } 240 | 241 | // Cols returns the fields whose key is one of cols. 242 | // If any column in cols doesn't exist in sfs.fields, returns nil. 
243 | func (stfs *structTaggedFields) Cols(cols []string) []*structField { 244 | fields := make([]*structField, 0, len(cols)) 245 | 246 | for _, col := range cols { 247 | field := stfs.colsForRead[col] 248 | 249 | if field == nil { 250 | return nil 251 | } 252 | 253 | fields = append(fields, field) 254 | } 255 | 256 | return fields 257 | } 258 | 259 | // Key returns the key name to identify a field. 260 | func (sf *structField) Key() string { 261 | if sf.As != "" { 262 | return sf.As 263 | } 264 | 265 | if sf.Alias != "" { 266 | return sf.Alias 267 | } 268 | 269 | return sf.Name 270 | } 271 | 272 | // NameForSelect returns the name for SELECT. 273 | func (sf *structField) NameForSelect(flavor Flavor) string { 274 | if sf.As == "" { 275 | return sf.Quote(flavor) 276 | } 277 | 278 | return fmt.Sprintf("%s AS %s", sf.Quote(flavor), sf.As) 279 | } 280 | 281 | // Quote the Alias in sf with flavor. 282 | func (sf *structField) Quote(flavor Flavor) string { 283 | if !sf.IsQuoted { 284 | return sf.Alias 285 | } 286 | 287 | return flavor.Quote(sf.Alias) 288 | } 289 | 290 | // ShouldOmitEmpty returns true only if any one of tags is in the omitted tags map. 291 | func (sf *structField) ShouldOmitEmpty(tags ...string) (ret bool) { 292 | omit := sf.omitEmptyTags 293 | 294 | if len(omit) == 0 { 295 | return 296 | } 297 | 298 | // Always check default tag. 
299 | if _, ret = omit[""]; ret { 300 | return 301 | } 302 | 303 | for _, tag := range tags { 304 | if _, ret = omit[tag]; ret { 305 | return 306 | } 307 | } 308 | 309 | return 310 | } 311 | 312 | type omitEmptyTagMap map[string]struct{} 313 | 314 | func getOptMatchedMap(opt string) (res map[string]string) { 315 | res = map[string]string{} 316 | sm := optRegex.FindStringSubmatch(opt) 317 | 318 | for i, name := range optRegex.SubexpNames() { 319 | if name != "" { 320 | res[name] = sm[i] 321 | } 322 | } 323 | 324 | return 325 | } 326 | 327 | func getTagsFromOptParams(opts string) (tags []string) { 328 | tags = splitTags(opts) 329 | 330 | if len(tags) == 0 { 331 | tags = append(tags, "") 332 | } 333 | 334 | return 335 | } 336 | 337 | func splitTags(fieldtag string) (tags []string) { 338 | parts := strings.Split(fieldtag, ",") 339 | 340 | for _, v := range parts { 341 | tag := strings.TrimSpace(v) 342 | 343 | if tag == "" { 344 | continue 345 | } 346 | 347 | tags = append(tags, tag) 348 | } 349 | 350 | return 351 | } 352 | -------------------------------------------------------------------------------- /union.go: -------------------------------------------------------------------------------- 1 | // Copyright 2018 Huan Du. All rights reserved. 2 | // Licensed under the MIT license that can be found in the LICENSE file. 3 | 4 | package sqlbuilder 5 | 6 | const ( 7 | unionDistinct = " UNION " // Default union type is DISTINCT. 8 | unionAll = " UNION ALL " 9 | ) 10 | 11 | const ( 12 | unionMarkerInit injectionMarker = iota 13 | unionMarkerAfterUnion 14 | unionMarkerAfterOrderBy 15 | unionMarkerAfterLimit 16 | ) 17 | 18 | // NewUnionBuilder creates a new UNION builder. 19 | func NewUnionBuilder() *UnionBuilder { 20 | return DefaultFlavor.NewUnionBuilder() 21 | } 22 | 23 | func newUnionBuilder() *UnionBuilder { 24 | return &UnionBuilder{ 25 | args: &Args{}, 26 | injection: newInjection(), 27 | } 28 | } 29 | 30 | // UnionBuilder is a builder to build UNION. 
31 | type UnionBuilder struct { 32 | opt string 33 | builderVars []string 34 | orderByCols []string 35 | order string 36 | limitVar string 37 | offsetVar string 38 | 39 | args *Args 40 | 41 | injection *injection 42 | marker injectionMarker 43 | } 44 | 45 | var _ Builder = new(UnionBuilder) 46 | 47 | // Union unions all builders together using UNION operator. 48 | func Union(builders ...Builder) *UnionBuilder { 49 | return DefaultFlavor.NewUnionBuilder().Union(builders...) 50 | } 51 | 52 | // Union unions all builders together using UNION operator. 53 | func (ub *UnionBuilder) Union(builders ...Builder) *UnionBuilder { 54 | return ub.union(unionDistinct, builders...) 55 | } 56 | 57 | // UnionAll unions all builders together using UNION ALL operator. 58 | func UnionAll(builders ...Builder) *UnionBuilder { 59 | return DefaultFlavor.NewUnionBuilder().UnionAll(builders...) 60 | } 61 | 62 | // UnionAll unions all builders together using UNION ALL operator. 63 | func (ub *UnionBuilder) UnionAll(builders ...Builder) *UnionBuilder { 64 | return ub.union(unionAll, builders...) 65 | } 66 | 67 | func (ub *UnionBuilder) union(opt string, builders ...Builder) *UnionBuilder { 68 | builderVars := make([]string, 0, len(builders)) 69 | 70 | for _, b := range builders { 71 | builderVars = append(builderVars, ub.Var(b)) 72 | } 73 | 74 | ub.opt = opt 75 | ub.builderVars = builderVars 76 | ub.marker = unionMarkerAfterUnion 77 | return ub 78 | } 79 | 80 | // OrderBy sets columns of ORDER BY in SELECT. 81 | func (ub *UnionBuilder) OrderBy(col ...string) *UnionBuilder { 82 | ub.orderByCols = col 83 | ub.marker = unionMarkerAfterOrderBy 84 | return ub 85 | } 86 | 87 | // Asc sets order of ORDER BY to ASC. 88 | func (ub *UnionBuilder) Asc() *UnionBuilder { 89 | ub.order = "ASC" 90 | ub.marker = unionMarkerAfterOrderBy 91 | return ub 92 | } 93 | 94 | // Desc sets order of ORDER BY to DESC. 
95 | func (ub *UnionBuilder) Desc() *UnionBuilder { 96 | ub.order = "DESC" 97 | ub.marker = unionMarkerAfterOrderBy 98 | return ub 99 | } 100 | 101 | // Limit sets the LIMIT in SELECT. 102 | func (ub *UnionBuilder) Limit(limit int) *UnionBuilder { 103 | if limit < 0 { 104 | ub.limitVar = "" 105 | return ub 106 | } 107 | 108 | ub.limitVar = ub.Var(limit) 109 | ub.marker = unionMarkerAfterLimit 110 | return ub 111 | } 112 | 113 | // Offset sets the LIMIT offset in SELECT. 114 | func (ub *UnionBuilder) Offset(offset int) *UnionBuilder { 115 | if offset < 0 { 116 | ub.offsetVar = "" 117 | return ub 118 | } 119 | 120 | ub.offsetVar = ub.Var(offset) 121 | ub.marker = unionMarkerAfterLimit 122 | return ub 123 | } 124 | 125 | // String returns the compiled SELECT string. 126 | func (ub *UnionBuilder) String() string { 127 | s, _ := ub.Build() 128 | return s 129 | } 130 | 131 | // Build returns compiled SELECT string and args. 132 | // They can be used in `DB#Query` of package `database/sql` directly. 133 | func (ub *UnionBuilder) Build() (sql string, args []interface{}) { 134 | return ub.BuildWithFlavor(ub.args.Flavor) 135 | } 136 | 137 | // BuildWithFlavor returns compiled SELECT string and args with flavor and initial args. 138 | // They can be used in `DB#Query` of package `database/sql` directly. 
139 | func (ub *UnionBuilder) BuildWithFlavor(flavor Flavor, initialArg ...interface{}) (sql string, args []interface{}) { 140 | buf := newStringBuilder() 141 | ub.injection.WriteTo(buf, unionMarkerInit) 142 | 143 | if len(ub.builderVars) > 0 { 144 | needParen := flavor != SQLite 145 | 146 | if needParen { 147 | buf.WriteLeadingString("(") 148 | buf.WriteString(ub.builderVars[0]) 149 | buf.WriteRune(')') 150 | } else { 151 | buf.WriteLeadingString(ub.builderVars[0]) 152 | } 153 | 154 | for _, b := range ub.builderVars[1:] { 155 | buf.WriteString(ub.opt) 156 | 157 | if needParen { 158 | buf.WriteRune('(') 159 | } 160 | 161 | buf.WriteString(b) 162 | 163 | if needParen { 164 | buf.WriteRune(')') 165 | } 166 | } 167 | } 168 | 169 | ub.injection.WriteTo(buf, unionMarkerAfterUnion) 170 | 171 | if len(ub.orderByCols) > 0 { 172 | buf.WriteLeadingString("ORDER BY ") 173 | buf.WriteStrings(ub.orderByCols, ", ") 174 | 175 | if ub.order != "" { 176 | buf.WriteRune(' ') 177 | buf.WriteString(ub.order) 178 | } 179 | 180 | ub.injection.WriteTo(buf, unionMarkerAfterOrderBy) 181 | } 182 | 183 | if len(ub.limitVar) > 0 { 184 | buf.WriteLeadingString("LIMIT ") 185 | buf.WriteString(ub.limitVar) 186 | 187 | } 188 | 189 | if ((MySQL == flavor || Informix == flavor) && len(ub.limitVar) > 0) || PostgreSQL == flavor { 190 | if len(ub.offsetVar) > 0 { 191 | buf.WriteLeadingString("OFFSET ") 192 | buf.WriteString(ub.offsetVar) 193 | } 194 | } 195 | 196 | if len(ub.limitVar) > 0 { 197 | ub.injection.WriteTo(buf, unionMarkerAfterLimit) 198 | } 199 | 200 | return ub.args.CompileWithFlavor(buf.String(), flavor, initialArg...) 201 | } 202 | 203 | // SetFlavor sets the flavor of compiled sql. 
func (ub *UnionBuilder) SetFlavor(flavor Flavor) (old Flavor) {
	old = ub.args.Flavor
	ub.args.Flavor = flavor
	return
}

// Flavor returns flavor of builder.
func (ub *UnionBuilder) Flavor() Flavor {
	return ub.args.Flavor
}

// Var adds arg to the builder's args and returns its placeholder.
func (ub *UnionBuilder) Var(arg interface{}) string {
	return ub.args.Add(arg)
}

// SQL injects sql at the current position, i.e. after the clause that was
// set most recently (tracked by ub.marker).
func (ub *UnionBuilder) SQL(sql string) *UnionBuilder {
	ub.injection.SQL(ub.marker, sql)
	return ub
}

--------------------------------------------------------------------------------
/union_test.go:
--------------------------------------------------------------------------------
// Copyright 2018 Huan Du. All rights reserved.
// Licensed under the MIT license that can be found in the LICENSE file.

package sqlbuilder

import (
	"fmt"
	"testing"

	"github.com/huandu/go-assert"
)

// ExampleUnion builds a UNION of two SELECTs. Each member is parenthesized and
// ORDER BY applies to the whole result.
func ExampleUnion() {
	sb1 := NewSelectBuilder()
	sb1.Select("id", "name", "created_at")
	sb1.From("demo.user")
	sb1.Where(
		sb1.GreaterThan("id", 1234),
	)

	sb2 := newSelectBuilder()
	sb2.Select("id", "avatar")
	sb2.From("demo.user_profile")
	sb2.Where(
		sb2.In("status", 1, 2, 5),
	)

	ub := Union(sb1, sb2)
	ub.OrderBy("created_at").Desc()

	sql, args := ub.Build()
	fmt.Println(sql)
	fmt.Println(args)

	// Output:
	// (SELECT id, name, created_at FROM demo.user WHERE id > ?) UNION (SELECT id, avatar FROM demo.user_profile WHERE status IN (?, ?, ?)) ORDER BY created_at DESC
	// [1234 1 2 5]
}

// ExampleUnionAll mixes a SelectBuilder with a raw statement created by Build,
// and applies ORDER BY and LIMIT/OFFSET to the union.
func ExampleUnionAll() {
	sb := NewSelectBuilder()
	sb.Select("id", "name", "created_at")
	sb.From("demo.user")
	sb.Where(
		sb.GreaterThan("id", 1234),
	)

	ub := UnionAll(sb, Build("TABLE demo.user_profile"))
	ub.OrderBy("created_at").Asc()
	ub.Limit(100).Offset(5)

	sql, args := ub.Build()
	fmt.Println(sql)
	fmt.Println(args)

	// Output:
	// (SELECT id, name, created_at FROM demo.user WHERE id > ?) UNION ALL (TABLE demo.user_profile) ORDER BY created_at ASC LIMIT ? OFFSET ?
	// [1234 100 5]
}

// ExampleUnionBuilder_SQL shows where each SQL() call is injected relative to
// the UNION, ORDER BY and LIMIT clauses.
func ExampleUnionBuilder_SQL() {
	sb1 := NewSelectBuilder()
	sb1.Select("id", "name", "created_at")
	sb1.From("demo.user")

	sb2 := newSelectBuilder()
	sb2.Select("id", "avatar")
	sb2.From("demo.user_profile")

	ub := NewUnionBuilder()
	ub.SQL("/* before */")
	ub.Union(sb1, sb2)
	ub.SQL("/* after union */")
	ub.OrderBy("created_at").Desc()
	ub.SQL("/* after order by */")
	ub.Limit(100).Offset(5)
	ub.SQL("/* after limit */")

	sql := ub.String()
	fmt.Println(sql)

	// Output:
	// /* before */ (SELECT id, name, created_at FROM demo.user) UNION (SELECT id, avatar FROM demo.user_profile) /* after union */ ORDER BY created_at DESC /* after order by */ LIMIT ? OFFSET ? /* after limit */
}

// TestUnionForSQLite checks that SQLite member queries are not wrapped in
// parentheses.
func TestUnionForSQLite(t *testing.T) {
	a := assert.New(t)
	sb1 := Select("id", "name").From("users").Where("created_at > DATE('now', '-15 days')")
	sb2 := Select("id", "nick_name").From("user_extras").Where("status IN (1, 2, 3)")
	sql, _ := UnionAll(sb1, sb2).OrderBy("id").BuildWithFlavor(SQLite)

	a.Equal(sql, "SELECT id, name FROM users WHERE created_at > DATE('now', '-15 days') UNION ALL SELECT id, nick_name FROM user_extras WHERE status IN (1, 2, 3) ORDER BY id")
}

// TestUnionBuilderGetFlavor checks Flavor after SetFlavor and after creating a
// builder from a specific flavor.
func TestUnionBuilderGetFlavor(t *testing.T) {
	a := assert.New(t)
	ub := newUnionBuilder()

	ub.SetFlavor(PostgreSQL)
	flavor := ub.Flavor()
	a.Equal(PostgreSQL, flavor)

	ubClick := ClickHouse.NewUnionBuilder()
	flavor = ubClick.Flavor()
	a.Equal(ClickHouse, flavor)
}

--------------------------------------------------------------------------------
/update.go:
--------------------------------------------------------------------------------
// Copyright 2018 Huan Du. All rights reserved.
// Licensed under the MIT license that can be found in the LICENSE file.

package sqlbuilder

import (
	"fmt"
)

const (
	updateMarkerInit injectionMarker = iota
	updateMarkerAfterWith
	updateMarkerAfterUpdate
	updateMarkerAfterSet
	updateMarkerAfterWhere
	updateMarkerAfterOrderBy
	updateMarkerAfterLimit
)

// NewUpdateBuilder creates a new UPDATE builder.
21 | func NewUpdateBuilder() *UpdateBuilder { 22 | return DefaultFlavor.NewUpdateBuilder() 23 | } 24 | 25 | func newUpdateBuilder() *UpdateBuilder { 26 | args := &Args{} 27 | proxy := &whereClauseProxy{} 28 | return &UpdateBuilder{ 29 | whereClauseProxy: proxy, 30 | whereClauseExpr: args.Add(proxy), 31 | 32 | Cond: Cond{ 33 | Args: args, 34 | }, 35 | args: args, 36 | injection: newInjection(), 37 | } 38 | } 39 | 40 | // UpdateBuilder is a builder to build UPDATE. 41 | type UpdateBuilder struct { 42 | *WhereClause 43 | Cond 44 | 45 | whereClauseProxy *whereClauseProxy 46 | whereClauseExpr string 47 | 48 | cteBuilderVar string 49 | cteBuilder *CTEBuilder 50 | 51 | tables []string 52 | assignments []string 53 | orderByCols []string 54 | order string 55 | limitVar string 56 | 57 | args *Args 58 | 59 | injection *injection 60 | marker injectionMarker 61 | } 62 | 63 | var _ Builder = new(UpdateBuilder) 64 | 65 | // Update sets table name in UPDATE. 66 | func Update(table ...string) *UpdateBuilder { 67 | return DefaultFlavor.NewUpdateBuilder().Update(table...) 68 | } 69 | 70 | // With sets WITH clause (the Common Table Expression) before UPDATE. 71 | func (ub *UpdateBuilder) With(builder *CTEBuilder) *UpdateBuilder { 72 | ub.marker = updateMarkerAfterWith 73 | ub.cteBuilderVar = ub.Var(builder) 74 | ub.cteBuilder = builder 75 | return ub 76 | } 77 | 78 | // Update sets table name in UPDATE. 79 | func (ub *UpdateBuilder) Update(table ...string) *UpdateBuilder { 80 | ub.tables = table 81 | ub.marker = updateMarkerAfterUpdate 82 | return ub 83 | } 84 | 85 | // TableNames returns all table names in this UPDATE statement. 
86 | func (ub *UpdateBuilder) TableNames() (tableNames []string) { 87 | var additionalTableNames []string 88 | 89 | if ub.cteBuilder != nil { 90 | additionalTableNames = ub.cteBuilder.tableNamesForFrom() 91 | } 92 | 93 | if len(ub.tables) > 0 && len(additionalTableNames) > 0 { 94 | tableNames = make([]string, len(ub.tables)+len(additionalTableNames)) 95 | copy(tableNames, ub.tables) 96 | copy(tableNames[len(ub.tables):], additionalTableNames) 97 | } else if len(ub.tables) > 0 { 98 | tableNames = ub.tables 99 | } else if len(additionalTableNames) > 0 { 100 | tableNames = additionalTableNames 101 | } 102 | 103 | return tableNames 104 | } 105 | 106 | // Set sets the assignments in SET. 107 | func (ub *UpdateBuilder) Set(assignment ...string) *UpdateBuilder { 108 | ub.assignments = assignment 109 | ub.marker = updateMarkerAfterSet 110 | return ub 111 | } 112 | 113 | // SetMore appends the assignments in SET. 114 | func (ub *UpdateBuilder) SetMore(assignment ...string) *UpdateBuilder { 115 | ub.assignments = append(ub.assignments, assignment...) 116 | ub.marker = updateMarkerAfterSet 117 | return ub 118 | } 119 | 120 | // Where sets expressions of WHERE in UPDATE. 121 | func (ub *UpdateBuilder) Where(andExpr ...string) *UpdateBuilder { 122 | if len(andExpr) == 0 || estimateStringsBytes(andExpr) == 0 { 123 | return ub 124 | } 125 | 126 | if ub.WhereClause == nil { 127 | ub.WhereClause = NewWhereClause() 128 | } 129 | 130 | ub.WhereClause.AddWhereExpr(ub.args, andExpr...) 131 | ub.marker = updateMarkerAfterWhere 132 | return ub 133 | } 134 | 135 | // AddWhereClause adds all clauses in the whereClause to SELECT. 136 | func (ub *UpdateBuilder) AddWhereClause(whereClause *WhereClause) *UpdateBuilder { 137 | if ub.WhereClause == nil { 138 | ub.WhereClause = NewWhereClause() 139 | } 140 | 141 | ub.WhereClause.AddWhereClause(whereClause) 142 | return ub 143 | } 144 | 145 | // Assign represents SET "field = value" in UPDATE. 
146 | func (ub *UpdateBuilder) Assign(field string, value interface{}) string { 147 | return fmt.Sprintf("%s = %s", Escape(field), ub.args.Add(value)) 148 | } 149 | 150 | // Incr represents SET "field = field + 1" in UPDATE. 151 | func (ub *UpdateBuilder) Incr(field string) string { 152 | f := Escape(field) 153 | return fmt.Sprintf("%s = %s + 1", f, f) 154 | } 155 | 156 | // Decr represents SET "field = field - 1" in UPDATE. 157 | func (ub *UpdateBuilder) Decr(field string) string { 158 | f := Escape(field) 159 | return fmt.Sprintf("%s = %s - 1", f, f) 160 | } 161 | 162 | // Add represents SET "field = field + value" in UPDATE. 163 | func (ub *UpdateBuilder) Add(field string, value interface{}) string { 164 | f := Escape(field) 165 | return fmt.Sprintf("%s = %s + %s", f, f, ub.args.Add(value)) 166 | } 167 | 168 | // Sub represents SET "field = field - value" in UPDATE. 169 | func (ub *UpdateBuilder) Sub(field string, value interface{}) string { 170 | f := Escape(field) 171 | return fmt.Sprintf("%s = %s - %s", f, f, ub.args.Add(value)) 172 | } 173 | 174 | // Mul represents SET "field = field * value" in UPDATE. 175 | func (ub *UpdateBuilder) Mul(field string, value interface{}) string { 176 | f := Escape(field) 177 | return fmt.Sprintf("%s = %s * %s", f, f, ub.args.Add(value)) 178 | } 179 | 180 | // Div represents SET "field = field / value" in UPDATE. 181 | func (ub *UpdateBuilder) Div(field string, value interface{}) string { 182 | f := Escape(field) 183 | return fmt.Sprintf("%s = %s / %s", f, f, ub.args.Add(value)) 184 | } 185 | 186 | // OrderBy sets columns of ORDER BY in UPDATE. 187 | func (ub *UpdateBuilder) OrderBy(col ...string) *UpdateBuilder { 188 | ub.orderByCols = col 189 | ub.marker = updateMarkerAfterOrderBy 190 | return ub 191 | } 192 | 193 | // Asc sets order of ORDER BY to ASC. 
194 | func (ub *UpdateBuilder) Asc() *UpdateBuilder { 195 | ub.order = "ASC" 196 | ub.marker = updateMarkerAfterOrderBy 197 | return ub 198 | } 199 | 200 | // Desc sets order of ORDER BY to DESC. 201 | func (ub *UpdateBuilder) Desc() *UpdateBuilder { 202 | ub.order = "DESC" 203 | ub.marker = updateMarkerAfterOrderBy 204 | return ub 205 | } 206 | 207 | // Limit sets the LIMIT in UPDATE. 208 | func (ub *UpdateBuilder) Limit(limit int) *UpdateBuilder { 209 | if limit < 0 { 210 | ub.limitVar = "" 211 | return ub 212 | } 213 | 214 | ub.limitVar = ub.Var(limit) 215 | ub.marker = updateMarkerAfterLimit 216 | return ub 217 | } 218 | 219 | // NumAssignment returns the number of assignments to update. 220 | func (ub *UpdateBuilder) NumAssignment() int { 221 | return len(ub.assignments) 222 | } 223 | 224 | // String returns the compiled UPDATE string. 225 | func (ub *UpdateBuilder) String() string { 226 | s, _ := ub.Build() 227 | return s 228 | } 229 | 230 | // Build returns compiled UPDATE string and args. 231 | // They can be used in `DB#Query` of package `database/sql` directly. 232 | func (ub *UpdateBuilder) Build() (sql string, args []interface{}) { 233 | return ub.BuildWithFlavor(ub.args.Flavor) 234 | } 235 | 236 | // BuildWithFlavor returns compiled UPDATE string and args with flavor and initial args. 237 | // They can be used in `DB#Query` of package `database/sql` directly. 238 | func (ub *UpdateBuilder) BuildWithFlavor(flavor Flavor, initialArg ...interface{}) (sql string, args []interface{}) { 239 | buf := newStringBuilder() 240 | ub.injection.WriteTo(buf, updateMarkerInit) 241 | 242 | if ub.cteBuilder != nil { 243 | buf.WriteLeadingString(ub.cteBuilderVar) 244 | ub.injection.WriteTo(buf, updateMarkerAfterWith) 245 | } 246 | 247 | switch flavor { 248 | case MySQL: 249 | // CTE table names should be written after UPDATE keyword in MySQL. 
250 | tableNames := ub.TableNames() 251 | 252 | if len(tableNames) > 0 { 253 | buf.WriteLeadingString("UPDATE ") 254 | buf.WriteStrings(tableNames, ", ") 255 | } 256 | 257 | default: 258 | if len(ub.tables) > 0 { 259 | buf.WriteLeadingString("UPDATE ") 260 | buf.WriteStrings(ub.tables, ", ") 261 | } 262 | } 263 | 264 | ub.injection.WriteTo(buf, updateMarkerAfterUpdate) 265 | 266 | if assignments := filterEmptyStrings(ub.assignments); len(assignments) > 0 { 267 | buf.WriteLeadingString("SET ") 268 | buf.WriteStrings(assignments, ", ") 269 | } 270 | 271 | ub.injection.WriteTo(buf, updateMarkerAfterSet) 272 | 273 | if flavor != MySQL { 274 | // For ISO SQL, CTE table names should be written after FROM keyword. 275 | if ub.cteBuilder != nil { 276 | cteTableNames := ub.cteBuilder.tableNamesForFrom() 277 | 278 | if len(cteTableNames) > 0 { 279 | buf.WriteLeadingString("FROM ") 280 | buf.WriteStrings(cteTableNames, ", ") 281 | } 282 | } 283 | } 284 | 285 | if ub.WhereClause != nil { 286 | ub.whereClauseProxy.WhereClause = ub.WhereClause 287 | defer func() { 288 | ub.whereClauseProxy.WhereClause = nil 289 | }() 290 | 291 | buf.WriteLeadingString(ub.whereClauseExpr) 292 | ub.injection.WriteTo(buf, updateMarkerAfterWhere) 293 | } 294 | 295 | if len(ub.orderByCols) > 0 { 296 | buf.WriteLeadingString("ORDER BY ") 297 | buf.WriteStrings(ub.orderByCols, ", ") 298 | 299 | if ub.order != "" { 300 | buf.WriteLeadingString(ub.order) 301 | } 302 | 303 | ub.injection.WriteTo(buf, updateMarkerAfterOrderBy) 304 | } 305 | 306 | if len(ub.limitVar) > 0 { 307 | buf.WriteLeadingString("LIMIT ") 308 | buf.WriteString(ub.limitVar) 309 | 310 | ub.injection.WriteTo(buf, updateMarkerAfterLimit) 311 | } 312 | 313 | return ub.args.CompileWithFlavor(buf.String(), flavor, initialArg...) 314 | } 315 | 316 | // SetFlavor sets the flavor of compiled sql. 
//
// It returns the flavor that was in effect before the call.
func (ub *UpdateBuilder) SetFlavor(flavor Flavor) (old Flavor) {
	old = ub.args.Flavor
	ub.args.Flavor = flavor
	return
}

// Flavor returns the flavor of the builder.
func (ub *UpdateBuilder) Flavor() Flavor {
	return ub.args.Flavor
}

// SQL injects an arbitrary SQL fragment at the current position, i.e. right
// after the clause most recently configured on the builder (With, Update,
// Set, Where, OrderBy/Asc/Desc or Limit).
func (ub *UpdateBuilder) SQL(sql string) *UpdateBuilder {
	ub.injection.SQL(ub.marker, sql)
	return ub
}

--------------------------------------------------------------------------------
/update_test.go:
--------------------------------------------------------------------------------
// Copyright 2018 Huan Du. All rights reserved.
// Licensed under the MIT license that can be found in the LICENSE file.

package sqlbuilder

import (
	"fmt"
	"testing"

	"github.com/huandu/go-assert"
)

func ExampleUpdate() {
	sql := Update("demo.user").
		Set(
			"visited = visited + 1",
		).
		Where(
			"id = 1234",
		).
		String()

	fmt.Println(sql)

	// Output:
	// UPDATE demo.user SET visited = visited + 1 WHERE id = 1234
}

func ExampleUpdateBuilder() {
	ub := NewUpdateBuilder()
	ub.Update("demo.user")
	ub.Set(
		ub.Assign("type", "sys"),
		ub.Incr("credit"),
		"modified_at = UNIX_TIMESTAMP(NOW())", // It's allowed to write arbitrary SQL.
	)
	ub.Where(
		ub.GreaterThan("id", 1234),
		ub.Like("name", "%Du"),
		ub.Or(
			ub.IsNull("id_card"),
			ub.In("status", 1, 2, 5),
		),
		"modified_at > created_at + "+ub.Var(86400), // It's allowed to write arbitrary SQL.
	)
	ub.OrderBy("id").Asc()

	sql, args := ub.Build()
	fmt.Println(sql)
	fmt.Println(args)

	// Output:
	// UPDATE demo.user SET type = ?, credit = credit + 1, modified_at = UNIX_TIMESTAMP(NOW()) WHERE id > ? AND name LIKE ? AND (id_card IS NULL OR status IN (?, ?, ?)) AND modified_at > created_at + ? ORDER BY id ASC
	// [sys 1234 %Du 1 2 5 86400]
}

// TestUpdateAssignments checks the raw assignment strings and the compiled
// args. Each key is "<assignment>|<args>". The raw strings still contain the
// internal "$N" placeholders; $0 is taken by the WHERE clause proxy registered
// in newUpdateBuilder, so the first value added becomes $1.
func TestUpdateAssignments(t *testing.T) {
	a := assert.New(t)
	cases := map[string]func(ub *UpdateBuilder) string{
		"f = f + 1|[]":     func(ub *UpdateBuilder) string { return ub.Incr("f") },
		"f = f - 1|[]":     func(ub *UpdateBuilder) string { return ub.Decr("f") },
		"f = f + $1|[123]": func(ub *UpdateBuilder) string { return ub.Add("f", 123) },
		"f = f - $1|[123]": func(ub *UpdateBuilder) string { return ub.Sub("f", 123) },
		"f = f * $1|[123]": func(ub *UpdateBuilder) string { return ub.Mul("f", 123) },
		"f = f / $1|[123]": func(ub *UpdateBuilder) string { return ub.Div("f", 123) },
	}

	for expected, f := range cases {
		ub := NewUpdateBuilder()
		s := f(ub)
		ub.Set(s)
		_, args := ub.Build()
		actual := fmt.Sprintf("%v|%v", s, args)

		a.Equal(actual, expected)
	}
}

func ExampleUpdateBuilder_SetMore() {
	ub := NewUpdateBuilder()
	ub.Update("demo.user")
	ub.Set(
		ub.Assign("type", "sys"),
		ub.Incr("credit"),
	)
	ub.SetMore(
		"modified_at = UNIX_TIMESTAMP(NOW())", // It's allowed to write arbitrary SQL.
	)

	sql, args := ub.Build()
	fmt.Println(sql)
	fmt.Println(args)

	// Output:
	// UPDATE demo.user SET type = ?, credit = credit + 1, modified_at = UNIX_TIMESTAMP(NOW())
	// [sys]
}

func ExampleUpdateBuilder_SQL() {
	ub := NewUpdateBuilder()
	ub.SQL("/* before */")
	ub.Update("demo.user")
	ub.SQL("/* after update */")
	ub.Set(
		ub.Assign("type", "sys"),
	)
	ub.SQL("/* after set */")
	ub.OrderBy("id").Desc()
	ub.SQL("/* after order by */")
	ub.Limit(10)
	ub.SQL("/* after limit */")

	sql := ub.String()
	fmt.Println(sql)

	// Output:
	// /* before */ UPDATE demo.user /* after update */ SET type = ? /* after set */ ORDER BY id DESC /* after order by */ LIMIT ? /* after limit */
}

func ExampleUpdateBuilder_NumAssignment() {
	ub := NewUpdateBuilder()
	ub.Update("demo.user")
	ub.Set(
		ub.Assign("type", "sys"),
		ub.Incr("credit"),
		"modified_at = UNIX_TIMESTAMP(NOW())",
	)

	// Count the number of assignments.
	fmt.Println(ub.NumAssignment())

	// Output:
	// 3
}

func ExampleUpdateBuilder_With() {
	sql := With(
		CTETable("users").As(
			Select("id", "name").From("users").Where("prime IS NOT NULL"),
		),
	).Update("orders").Set(
		"orders.transport_fee = 0",
	).Where(
		"users.id = orders.user_id",
	).String()

	fmt.Println(sql)

	// Output:
	// WITH users AS (SELECT id, name FROM users WHERE prime IS NOT NULL) UPDATE orders, users SET orders.transport_fee = 0 WHERE users.id = orders.user_id
}

func TestUpdateBuilderGetFlavor(t *testing.T) {
	a := assert.New(t)
	ub := newUpdateBuilder()

	ub.SetFlavor(PostgreSQL)
	flavor := ub.Flavor()
	a.Equal(PostgreSQL, flavor)

	ubClick := ClickHouse.NewUpdateBuilder()
	flavor = ubClick.Flavor()
	a.Equal(ClickHouse, flavor)
}

--------------------------------------------------------------------------------
/whereclause.go:
--------------------------------------------------------------------------------
// Copyright 2018 Huan Du. All rights reserved.
// Licensed under the MIT license that can be found in the LICENSE file.

package sqlbuilder

// WhereClause is a Builder for WHERE clause.
// All builders which support `WHERE` clause have an anonymous `WhereClause` field,
// in which the conditions are stored.
//
// WhereClause can be shared among multiple builders.
// However, it is not thread-safe.
type WhereClause struct {
	flavor  Flavor   // used by Build when the clause is compiled on its own
	clauses []clause // expression groups; their SQL is joined with " AND "
}

var _ Builder = new(WhereClause)

// NewWhereClause creates a new, empty WhereClause.
func NewWhereClause() *WhereClause {
	return &WhereClause{}
}

// CopyWhereClause creates a copy of the whereClause.
25 | func CopyWhereClause(whereClause *WhereClause) *WhereClause { 26 | clauses := make([]clause, len(whereClause.clauses)) 27 | copy(clauses, whereClause.clauses) 28 | 29 | return &WhereClause{ 30 | flavor: whereClause.flavor, 31 | clauses: clauses, 32 | } 33 | } 34 | 35 | type clause struct { 36 | args *Args 37 | andExprs []string 38 | } 39 | 40 | func (c *clause) Build(flavor Flavor, initialArg ...interface{}) (sql string, args []interface{}) { 41 | exprs := filterEmptyStrings(c.andExprs) 42 | 43 | if len(exprs) == 0 { 44 | return 45 | } 46 | 47 | buf := newStringBuilder() 48 | buf.WriteStrings(exprs, " AND ") 49 | sql, args = c.args.CompileWithFlavor(buf.String(), flavor, initialArg...) 50 | return 51 | } 52 | 53 | // whereClauseProxy is a proxy for WhereClause. 54 | // It's useful when the WhereClause in a build can be changed. 55 | type whereClauseProxy struct { 56 | *WhereClause 57 | } 58 | 59 | var _ Builder = new(whereClauseProxy) 60 | 61 | // BuildWithFlavor builds a WHERE clause with the specified flavor and initial arguments. 62 | func (wc *WhereClause) BuildWithFlavor(flavor Flavor, initialArg ...interface{}) (sql string, args []interface{}) { 63 | if len(wc.clauses) == 0 { 64 | return "", nil 65 | } 66 | 67 | buf := newStringBuilder() 68 | buf.WriteLeadingString("WHERE ") 69 | 70 | sql, args = wc.clauses[0].Build(flavor, initialArg...) 71 | buf.WriteString(sql) 72 | 73 | for _, clause := range wc.clauses[1:] { 74 | buf.WriteString(" AND ") 75 | sql, args = clause.Build(flavor, args...) 76 | buf.WriteString(sql) 77 | } 78 | 79 | return buf.String(), args 80 | } 81 | 82 | // Build returns compiled WHERE clause string and args. 83 | func (wc *WhereClause) Build() (sql string, args []interface{}) { 84 | return wc.BuildWithFlavor(wc.flavor) 85 | } 86 | 87 | // SetFlavor sets the flavor of compiled sql. 88 | // When the WhereClause belongs to a builder, the flavor of the builder will be used when building SQL. 
89 | func (wc *WhereClause) SetFlavor(flavor Flavor) (old Flavor) { 90 | old = wc.flavor 91 | wc.flavor = flavor 92 | return 93 | } 94 | 95 | // Flavor returns flavor of clause 96 | func (wc *WhereClause) Flavor() Flavor { 97 | return wc.flavor 98 | } 99 | 100 | // AddWhereExpr adds an AND expression to WHERE clause with the specified arguments. 101 | func (wc *WhereClause) AddWhereExpr(args *Args, andExpr ...string) *WhereClause { 102 | if len(andExpr) == 0 { 103 | return wc 104 | } 105 | 106 | andExprsBytesLen := estimateStringsBytes(andExpr) 107 | 108 | if andExprsBytesLen == 0 { 109 | return wc 110 | } 111 | 112 | // Merge with last clause if possible. 113 | if len(wc.clauses) > 0 { 114 | lastClause := &wc.clauses[len(wc.clauses)-1] 115 | 116 | if lastClause.args == args { 117 | lastClause.andExprs = append(lastClause.andExprs, andExpr...) 118 | return wc 119 | } 120 | } 121 | 122 | wc.clauses = append(wc.clauses, clause{ 123 | args: args, 124 | andExprs: andExpr, 125 | }) 126 | return wc 127 | } 128 | 129 | // AddWhereClause adds all clauses in the whereClause to the wc. 130 | func (wc *WhereClause) AddWhereClause(whereClause *WhereClause) *WhereClause { 131 | if whereClause == nil { 132 | return wc 133 | } 134 | 135 | wc.clauses = append(wc.clauses, whereClause.clauses...) 136 | return wc 137 | } 138 | -------------------------------------------------------------------------------- /whereclause_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2018 Huan Du. All rights reserved. 2 | // Licensed under the MIT license that can be found in the LICENSE file. 3 | 4 | package sqlbuilder 5 | 6 | import ( 7 | "fmt" 8 | "testing" 9 | 10 | "github.com/huandu/go-assert" 11 | ) 12 | 13 | func ExampleWhereClause() { 14 | // Build a SQL to select a user from database. 
15 | sb := Select("name", "level").From("users") 16 | sb.Where( 17 | sb.Equal("id", 1234), 18 | ) 19 | sql, args := sb.Build() 20 | fmt.Println(sql) 21 | fmt.Println(args) 22 | 23 | // Query database with the sql and update this user's level... 24 | 25 | ub := Update("users") 26 | ub.Set( 27 | ub.Add("level", 10), 28 | ) 29 | 30 | // The WHERE clause of UPDATE should be the same as the WHERE clause of SELECT. 31 | ub.WhereClause = sb.WhereClause 32 | 33 | sql, args = ub.Build() 34 | fmt.Println(sql) 35 | fmt.Println(args) 36 | 37 | // Output: 38 | // SELECT name, level FROM users WHERE id = ? 39 | // [1234] 40 | // UPDATE users SET level = level + ? WHERE id = ? 41 | // [10 1234] 42 | } 43 | 44 | func ExampleWhereClause_sharedAmongBuilders() { 45 | // A WhereClause can be shared among builders. 46 | // However, as it's not thread-safe, don't use it in a concurrent environment. 47 | sb1 := Select("level").From("users") 48 | sb2 := Select("status").From("users") 49 | 50 | // Share the same WhereClause between sb1 and sb2. 51 | whereClause := NewWhereClause() 52 | sb1.WhereClause = whereClause 53 | sb2.WhereClause = whereClause 54 | 55 | // The Where method in sb1 and sb2 will update the same WhereClause. 56 | // When we call sb1.Where(), the WHERE clause in sb2 will also be updated. 57 | sb1.Where( 58 | sb1.Like("name", "Charmy%"), 59 | ) 60 | 61 | // We can get a copy of the WhereClause. 62 | // The copy is independent from the original. 63 | sb3 := Select("name").From("users") 64 | sb3.WhereClause = CopyWhereClause(whereClause) 65 | 66 | // Adding more expressions to sb1 and sb2 will not affect sb3. 67 | sb2.Where( 68 | sb2.In("status", 1, 2, 3), 69 | ) 70 | 71 | // Adding more expressions to sb3 will not affect sb1 and sb2. 
72 | sb3.Where( 73 | sb3.GreaterEqualThan("level", 10), 74 | ) 75 | 76 | sql1, args1 := sb1.Build() 77 | sql2, args2 := sb2.Build() 78 | sql3, args3 := sb3.Build() 79 | 80 | fmt.Println(sql1) 81 | fmt.Println(args1) 82 | fmt.Println(sql2) 83 | fmt.Println(args2) 84 | fmt.Println(sql3) 85 | fmt.Println(args3) 86 | 87 | // Output: 88 | // SELECT level FROM users WHERE name LIKE ? AND status IN (?, ?, ?) 89 | // [Charmy% 1 2 3] 90 | // SELECT status FROM users WHERE name LIKE ? AND status IN (?, ?, ?) 91 | // [Charmy% 1 2 3] 92 | // SELECT name FROM users WHERE name LIKE ? AND level >= ? 93 | // [Charmy% 10] 94 | } 95 | 96 | func ExampleWhereClause_clearWhereClause() { 97 | db := DeleteFrom("users") 98 | db.Where( 99 | db.GreaterThan("level", 10), 100 | ) 101 | 102 | sql, args := db.Build() 103 | fmt.Println(sql) 104 | fmt.Println(args) 105 | 106 | // Clear WHERE clause. 107 | db.WhereClause = nil 108 | sql, args = db.Build() 109 | fmt.Println(sql) 110 | fmt.Println(args) 111 | 112 | db.Where( 113 | db.Equal("id", 1234), 114 | ) 115 | sql, args = db.Build() 116 | fmt.Println(sql) 117 | fmt.Println(args) 118 | 119 | // Output: 120 | // DELETE FROM users WHERE level > ? 121 | // [10] 122 | // DELETE FROM users 123 | // [] 124 | // DELETE FROM users WHERE id = ? 125 | // [1234] 126 | } 127 | 128 | func ExampleWhereClause_AddWhereExpr() { 129 | // WhereClause can be used as a standalone builder to build WHERE clause. 130 | // It's recommended to use it with Cond. 131 | whereClause := NewWhereClause() 132 | cond := NewCond() 133 | 134 | whereClause.AddWhereExpr( 135 | cond.Args, 136 | cond.In("name", "Charmy", "Huan"), 137 | cond.LessEqualThan("level", 10), 138 | ) 139 | 140 | // Set the flavor of the WhereClause to PostgreSQL. 141 | whereClause.SetFlavor(PostgreSQL) 142 | 143 | sql, args := whereClause.Build() 144 | fmt.Println(sql) 145 | fmt.Println(args) 146 | 147 | // Use this WhereClause in another builder. 
148 | sb := MySQL.NewSelectBuilder() 149 | sb.Select("name", "level").From("users") 150 | sb.WhereClause = whereClause 151 | 152 | // The flavor of sb overrides the flavor of the WhereClause. 153 | sql, args = sb.Build() 154 | fmt.Println(sql) 155 | fmt.Println(args) 156 | 157 | // Output: 158 | // WHERE name IN ($1, $2) AND level <= $3 159 | // [Charmy Huan 10] 160 | // SELECT name, level FROM users WHERE name IN (?, ?) AND level <= ? 161 | // [Charmy Huan 10] 162 | } 163 | 164 | func ExampleWhereClause_AddWhereClause() { 165 | sb := Select("level").From("users") 166 | sb.Where( 167 | sb.Equal("id", 1234), 168 | ) 169 | 170 | sql, args := sb.Build() 171 | fmt.Println(sql) 172 | fmt.Println(args) 173 | 174 | ub := Update("users") 175 | ub.Set( 176 | ub.Add("level", 10), 177 | ) 178 | 179 | // Copy the WHERE clause of sb into ub and add more expressions. 180 | ub.AddWhereClause(sb.WhereClause).Where( 181 | ub.Equal("deleted", 0), 182 | ) 183 | 184 | sql, args = ub.Build() 185 | fmt.Println(sql) 186 | fmt.Println(args) 187 | 188 | // Output: 189 | // SELECT level FROM users WHERE id = ? 190 | // [1234] 191 | // UPDATE users SET level = level + ? WHERE id = ? AND deleted = ? 192 | // [10 1234 0] 193 | } 194 | 195 | func TestWhereClauseSharedInstances(t *testing.T) { 196 | a := assert.New(t) 197 | sb := Select("*").From("t") 198 | ub := Update("t").Set("foo = 1") 199 | db := DeleteFrom("t") 200 | 201 | whereClause := NewWhereClause() 202 | sb.WhereClause = whereClause 203 | ub.WhereClause = whereClause 204 | db.WhereClause = whereClause 205 | sb.Where(sb.Equal("id", 123)) 206 | a.Equal(sb.String(), "SELECT * FROM t WHERE id = ?") 207 | a.Equal(ub.String(), "UPDATE t SET foo = 1 WHERE id = ?") 208 | a.Equal(db.String(), "DELETE FROM t WHERE id = ?") 209 | 210 | // Add more WhereClause. 
211 | cond := NewCond() 212 | moreWhereClause := NewWhereClause().AddWhereExpr( 213 | cond.Args, 214 | cond.GreaterEqualThan("credit", 100), 215 | ) 216 | 217 | // The moreWhereClause is added to whereClause. 218 | // All builders sharing the same WhereClause will have the same new cluase. 219 | sb.AddWhereClause(moreWhereClause) 220 | a.Equal(sb.String(), "SELECT * FROM t WHERE id = ? AND credit >= ?") 221 | a.Equal(ub.String(), "UPDATE t SET foo = 1 WHERE id = ? AND credit >= ?") 222 | a.Equal(db.String(), "DELETE FROM t WHERE id = ? AND credit >= ?") 223 | 224 | // Copied WhereClause is independent from the original. 225 | ub.WhereClause = CopyWhereClause(whereClause) 226 | ub.Where(ub.GreaterEqualThan("level", 10)) 227 | db.Where(db.In("status", 1, 2)) 228 | a.Equal(sb.String(), "SELECT * FROM t WHERE id = ? AND credit >= ? AND status IN (?, ?)") 229 | a.Equal(ub.String(), "UPDATE t SET foo = 1 WHERE id = ? AND credit >= ? AND level >= ?") 230 | a.Equal(db.String(), "DELETE FROM t WHERE id = ? AND credit >= ? AND status IN (?, ?)") 231 | 232 | // Clear the WhereClause and add new where clause and expressions. 233 | db.WhereClause = nil 234 | db.AddWhereClause(ub.WhereClause) 235 | db.AddWhereExpr(db.Args, db.Equal("deleted", 0)) 236 | a.Equal(sb.String(), "SELECT * FROM t WHERE id = ? AND credit >= ? AND status IN (?, ?)") 237 | a.Equal(ub.String(), "UPDATE t SET foo = 1 WHERE id = ? AND credit >= ? AND level >= ?") 238 | a.Equal(db.String(), "DELETE FROM t WHERE id = ? AND credit >= ? AND level >= ? AND deleted = ?") 239 | 240 | // Nested WhereClause. 241 | ub.Where(ub.NotIn("id", sb)) 242 | sb.Where(sb.NotEqual("flag", "normal")) 243 | a.Equal(ub.String(), "UPDATE t SET foo = 1 WHERE id = ? AND credit >= ? AND level >= ? AND id NOT IN (SELECT * FROM t WHERE id = ? AND credit >= ? AND status IN (?, ?) 
AND flag <> ?)") 244 | } 245 | 246 | func TestEmptyWhereExpr(t *testing.T) { 247 | a := assert.New(t) 248 | blankExprs := []string{"", ""} 249 | sb := Select("*").From("t").Where(blankExprs...) 250 | ub := Update("t").Set("foo = 1").Where(blankExprs...) 251 | db := DeleteFrom("t").Where(blankExprs...) 252 | 253 | a.Equal(sb.String(), "SELECT * FROM t") 254 | a.Equal(ub.String(), "UPDATE t SET foo = 1") 255 | a.Equal(db.String(), "DELETE FROM t") 256 | } 257 | 258 | func TestEmptyStringsWhere(t *testing.T) { 259 | a := assert.New(t) 260 | emptyExpr := []string{"", "", ""} 261 | 262 | sb := Select("*").From("t").Where(emptyExpr...) 263 | ub := Update("t").Set("foo = 1").Where(emptyExpr...) 264 | db := DeleteFrom("t").Where(emptyExpr...) 265 | 266 | a.Equal(sb.String(), "SELECT * FROM t") 267 | a.Equal(ub.String(), "UPDATE t SET foo = 1") 268 | a.Equal(db.String(), "DELETE FROM t") 269 | } 270 | 271 | func TestEmptyAddWhereExpr(t *testing.T) { 272 | a := assert.New(t) 273 | var emptyExpr []string 274 | sb := Select("*").From("t") 275 | ub := Update("t").Set("foo = 1") 276 | db := DeleteFrom("t") 277 | 278 | cond := NewCond() 279 | whereClause := NewWhereClause().AddWhereExpr( 280 | cond.Args, 281 | emptyExpr..., 282 | ) 283 | 284 | sb.AddWhereClause(whereClause) 285 | ub.AddWhereClause(whereClause) 286 | db.AddWhereClause(whereClause) 287 | 288 | a.Equal(sb.String(), "SELECT * FROM t ") 289 | a.Equal(ub.String(), "UPDATE t SET foo = 1 ") 290 | a.Equal(db.String(), "DELETE FROM t ") 291 | } 292 | 293 | func TestEmptyStringsWhereAddWhereExpr(t *testing.T) { 294 | a := assert.New(t) 295 | emptyExpr := []string{"", "", ""} 296 | sb := Select("*").From("t") 297 | ub := Update("t").Set("foo = 1") 298 | db := DeleteFrom("t") 299 | 300 | cond := NewCond() 301 | whereClause := NewWhereClause().AddWhereExpr( 302 | cond.Args, 303 | emptyExpr..., 304 | ) 305 | 306 | sb.AddWhereClause(whereClause) 307 | ub.AddWhereClause(whereClause) 308 | db.AddWhereClause(whereClause) 309 | 310 
| a.Equal(sb.String(), "SELECT * FROM t ") 311 | a.Equal(ub.String(), "UPDATE t SET foo = 1 ") 312 | a.Equal(db.String(), "DELETE FROM t ") 313 | } 314 | 315 | func TestWhereClauseGetFlavor(t *testing.T) { 316 | a := assert.New(t) 317 | wc := NewWhereClause() 318 | wc.SetFlavor(PostgreSQL) 319 | flavor := wc.Flavor() 320 | a.Equal(PostgreSQL, flavor) 321 | } 322 | 323 | func TestWhereClauseCopyGetFlavor(t *testing.T) { 324 | a := assert.New(t) 325 | 326 | wc := NewWhereClause() 327 | wc.SetFlavor(PostgreSQL) 328 | 329 | wcCopy := CopyWhereClause(wc) 330 | flavor := wcCopy.Flavor() 331 | a.Equal(PostgreSQL, flavor) 332 | } 333 | --------------------------------------------------------------------------------