├── go.mod ├── .gitignore ├── .github └── workflows │ └── go.yml ├── README.md ├── LICENSE ├── sql.go └── sql_test.go /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/yangyin5127/sqlstring 2 | 3 | go 1.16 4 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Binaries for programs and plugins 2 | *.exe 3 | *.exe~ 4 | *.dll 5 | *.so 6 | *.dylib 7 | 8 | # Test binary, built with `go test -c` 9 | *.test 10 | 11 | # Output of the go coverage tool, specifically when used with LiteIDE 12 | *.out 13 | 14 | # Dependency directories (remove the comment below to include it) 15 | # vendor/ 16 | -------------------------------------------------------------------------------- /.github/workflows/go.yml: -------------------------------------------------------------------------------- 1 | # This workflow will build a golang project 2 | # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-go 3 | 4 | name: Go 5 | 6 | on: 7 | push: 8 | branches: [ "main" ] 9 | pull_request: 10 | branches: [ "main" ] 11 | 12 | jobs: 13 | 14 | build: 15 | runs-on: ubuntu-latest 16 | steps: 17 | - uses: actions/checkout@v3 18 | 19 | - name: Set up Go 20 | uses: actions/setup-go@v4 21 | with: 22 | go-version: '1.20' 23 | 24 | - name: Build 25 | run: go build -v ./... 26 | 27 | - name: Test 28 | run: go test -v ./... 
29 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # sqlstring 2 | 3 | Simple SQL value escaping and query formatting for Go, using MySQL string-literal escape sequences. 4 | 5 | [![Go](https://github.com/feiin/sqlstring/actions/workflows/go.yml/badge.svg)](https://github.com/feiin/sqlstring/actions/workflows/go.yml) 6 | [![GoDoc](https://godoc.org/github.com/feiin/sqlstring?status.svg)](https://godoc.org/github.com/feiin/sqlstring) 7 | 8 | ## Escaping SQL values 9 | 10 | ```golang 11 | // Format replaces each ? placeholder with the next argument, escaped 12 | sql := sqlstring.Format("select * from users where name=? and age=? limit ?,?", "t'est", 10, 10, 10) 13 | 14 | fmt.Printf("sql: %s\n", sql) 15 | 16 | // Escape encodes a single value as a SQL literal 17 | sql = "select * from users WHERE name = " + sqlstring.Escape("t'est") 18 | fmt.Printf("sql: %s\n", sql) 19 | 20 | ``` 21 | 22 | ## License 23 | 24 | MIT -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 solar 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /sql.go: -------------------------------------------------------------------------------- 1 | package sqlstring 2 | 3 | import ( 4 | "encoding/hex" 5 | "encoding/json" 6 | "fmt" 7 | "reflect" 8 | "regexp" 9 | "strconv" 10 | "strings" 11 | "time" 12 | ) 13 | 14 | var ( 15 | tmFmtZero = "0000-00-00 00:00:00" 16 | tmFmtWithMS = "2006-01-02 15:04:05.999" 17 | escaper = "'" 18 | nullStr = "NULL" 19 | singleQuoteEscaper = "\\" 20 | escapeRegexp = regexp.MustCompile(`[\0\t\x1a\n\r\"\'\\]`) 21 | 22 | //see href='https://dev.mysql.com/doc/refman/8.0/en/string-literals.html#character-escape-sequences' 23 | characterEscapeMap = map[string]string{ 24 | "\\0": `\\0`, //ASCII NULL 25 | "\b": `\\b`, //backspace 26 | "\t": `\\t`, //tab 27 | "\x1a": `\\Z`, //ASCII 26 (Control+Z); 28 | "\n": `\\n`, //newline character 29 | "\r": `\\r`, //return character 30 | "\"": `\\"`, //quote (") 31 | "'": `\'`, //quote (') 32 | "\\": `\\\\`, //backslash (\) 33 | // "\\%": `\\%`, //% character 34 | // "\\_": `\\_`, //_ character 35 | } 36 | ) 37 | 38 | //Escape escape the val for sql 39 | func Escape(val interface{}) string { 40 | return EscapeInLocation(val, time.Local) 41 | } 42 | 43 | //toSqlString escape the string val for sql 44 | func toSqlString(val string) string { 45 | return escapeRegexp.ReplaceAllStringFunc(val, func(s string) string { 46 | 47 | mVal, ok := characterEscapeMap[s] 48 | if ok { 49 | return mVal 50 | } 51 | return s 52 | }) 53 | } 54 | 55 | func timeToString(t time.Time, loc *time.Location) string { 56 | if t.IsZero() { 57 | return escaper + tmFmtZero + escaper 58 | } 59 | 60 | if loc != nil { 61 | return escaper 
+ t.In(loc).Format(tmFmtWithMS) + escaper 62 | } 63 | return escaper + t.Format(tmFmtWithMS) + escaper 64 | } 65 | 66 | func arrayToString(refValue reflect.Value, loc *time.Location) string { 67 | var res []string 68 | for i := 0; i < refValue.Len(); i++ { 69 | res = append(res, EscapeInLocation(refValue.Index(i).Interface(), loc)) 70 | } 71 | return strings.Join(res, ",") 72 | } 73 | 74 | func bytesToString(b []byte) string { 75 | return "X" + escaper + hex.EncodeToString(b) + escaper 76 | } 77 | 78 | //EscapeInLocation escape the val with time.Location 79 | func EscapeInLocation(val interface{}, loc *time.Location) string { 80 | if val == nil { 81 | return nullStr 82 | } 83 | 84 | switch v := val.(type) { 85 | case bool: 86 | return strconv.FormatBool(v) 87 | case time.Time: 88 | return timeToString(v, loc) 89 | case *time.Time: 90 | if v == nil { 91 | return nullStr 92 | } 93 | return timeToString(*v, loc) 94 | case []byte: 95 | return bytesToString(v) 96 | case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64: 97 | return fmt.Sprintf("%d", v) 98 | case float32, float64: 99 | return fmt.Sprintf("%.6f", v) 100 | 101 | case string: 102 | return escaper + toSqlString(v) + escaper 103 | default: 104 | refValue := reflect.ValueOf(v) 105 | if v == nil || !refValue.IsValid() { 106 | return nullStr 107 | } 108 | 109 | if refValue.Kind() == reflect.Ptr && refValue.IsNil() { 110 | return nullStr 111 | } 112 | 113 | if refValue.Kind() == reflect.Ptr && !refValue.IsZero() { 114 | return EscapeInLocation(reflect.Indirect(refValue).Interface(), loc) 115 | } 116 | 117 | if refValue.Kind() == reflect.Array || refValue.Kind() == reflect.Slice { 118 | //slice or array 119 | return arrayToString(refValue, loc) 120 | } 121 | 122 | stringifyData, err := json.Marshal(v) 123 | if err != nil { 124 | return nullStr 125 | } 126 | return escaper + toSqlString(string(stringifyData)) + escaper 127 | 128 | } 129 | } 130 | 131 | //Format format the sql with args 132 | func 
Format(query string, args ...interface{}) string { 133 | 134 | if len(args) == 0 { 135 | return query 136 | } 137 | 138 | var sql strings.Builder 139 | replaceIndex := 0 140 | for _, v := range query { 141 | if v == '?' { 142 | if len(args) > replaceIndex { 143 | sql.WriteString(Escape(args[replaceIndex])) 144 | replaceIndex++ 145 | continue 146 | } 147 | } 148 | sql.WriteRune(v) 149 | } 150 | return sql.String() 151 | } 152 | 153 | //FormatInLocation format the sql with args 154 | func FormatInLocation(query string, loc *time.Location, args ...interface{}) string { 155 | 156 | if len(args) == 0 { 157 | return query 158 | } 159 | 160 | var sql strings.Builder 161 | replaceIndex := 0 162 | for _, v := range query { 163 | if v == '?' { 164 | if len(args) > replaceIndex { 165 | sql.WriteString(EscapeInLocation(args[replaceIndex], loc)) 166 | replaceIndex++ 167 | continue 168 | } 169 | } 170 | sql.WriteRune(v) 171 | } 172 | return sql.String() 173 | } 174 | 175 | //SetSingleQuoteEscaper set the singleQuoteEscaper 176 | //default:\' , e.g. 
'' 、 \' 177 | func SetSingleQuoteEscaper(escaper string) { 178 | 179 | characterEscapeMap["'"] = escaper 180 | // singleQuoteEscaper = escaper 181 | } 182 | -------------------------------------------------------------------------------- /sql_test.go: -------------------------------------------------------------------------------- 1 | package sqlstring 2 | 3 | import ( 4 | "reflect" 5 | "testing" 6 | "time" 7 | ) 8 | 9 | func TestNULLEscape(t *testing.T) { 10 | result := Escape(nil) 11 | if result != "NULL" { 12 | t.Fatalf("escape error") 13 | } 14 | } 15 | 16 | func Test0Escape(t *testing.T) { 17 | result := Escape(`\0`) 18 | t.Logf("Test0Escape result: %s", result) 19 | if result != `'\\\\0'` { 20 | t.Fatalf("escape error") 21 | } 22 | } 23 | 24 | func TestEmptyStringEscape(t *testing.T) { 25 | result := Escape("") 26 | t.Logf("result :%s", result) 27 | if result != "''" { 28 | t.Fatalf("escape empty string error") 29 | } 30 | } 31 | func TestBoolEscape(t *testing.T) { 32 | 33 | result := Escape(true) 34 | if result != "true" { 35 | t.Fatalf("escape error") 36 | } 37 | 38 | result = Escape(false) 39 | if result != "false" { 40 | t.Fatalf("escape error") 41 | } 42 | } 43 | 44 | func TestTimeToString(t *testing.T) { 45 | bt, _ := time.ParseInLocation("2006-01-02 15:04:05", "2021-01-01 15:00:09", time.Local) 46 | 47 | result := Escape(bt) 48 | t.Logf("result time %s", result) 49 | if result != "'2021-01-01 15:00:09'" { 50 | t.Fatalf("escape time error") 51 | } 52 | 53 | result = Escape(&bt) 54 | t.Logf("result time2 %s", result) 55 | if result != "'2021-01-01 15:00:09'" { 56 | t.Fatalf("escape time error") 57 | } 58 | } 59 | 60 | func TestArrayToString(t *testing.T) { 61 | 62 | var a = []int{1, 2, 3, 4} 63 | val := reflect.ValueOf(a) 64 | result := arrayToString(val, time.Local) 65 | 66 | if result != "1,2,3,4" { 67 | t.Fatalf("escape slice error") 68 | 69 | } 70 | 71 | b := [3]string{"1", "2", "3"} 72 | val = reflect.ValueOf(b) 73 | result = arrayToString(val, 
time.Local) 74 | 75 | if result != "'1','2','3'" { 76 | t.Fatalf("escape arr error") 77 | 78 | } 79 | 80 | } 81 | 82 | func TestStringEscape(t *testing.T) { 83 | s := "hello world" 84 | result := Escape(s) 85 | if result != "'hello world'" { 86 | t.Fatalf("escape string error") 87 | 88 | } 89 | 90 | s = "hello ' world" 91 | result = Escape(s) 92 | if result != "'hello \\' world'" { 93 | t.Fatalf("escape string error") 94 | 95 | } 96 | } 97 | 98 | func TestStringEscape2(t *testing.T) { 99 | s := "hello world" 100 | result := Escape(s) 101 | if result != "'hello world'" { 102 | t.Fatalf("escape string error") 103 | 104 | } 105 | 106 | s = `hello \' world` 107 | t.Logf("TestStringEscape2 raw:%s", s) 108 | result = Escape(s) 109 | t.Logf("TestStringEscape2 result: %s", result) 110 | if result != `'hello \\\\\' world'` { 111 | t.Fatalf("escape string error") 112 | 113 | } 114 | } 115 | 116 | func TestStringCustomEscape(t *testing.T) { 117 | s := "hello world" 118 | SetSingleQuoteEscaper("''") 119 | result := Escape(s) 120 | if result != "'hello world'" { 121 | t.Fatalf("escape string error") 122 | 123 | } 124 | 125 | s = "hello ' world" 126 | result = Escape(s) 127 | t.Logf("TestStringCustomEscape result: %s", result) 128 | if result != "'hello '' world'" { 129 | t.Fatalf("escape string error") 130 | 131 | } 132 | SetSingleQuoteEscaper("\\'") 133 | 134 | } 135 | 136 | func TestBytesEscape(t *testing.T) { 137 | s := []byte{0, 1, 254, 255} 138 | result := Escape(s) 139 | t.Logf("TestBytesEscape result: %s", result) 140 | if result != "X'0001feff'" { 141 | t.Fatalf("escape bytes error") 142 | 143 | } 144 | 145 | } 146 | 147 | func TestIntEscape(t *testing.T) { 148 | var i int = 10 149 | result := Escape(i) 150 | t.Logf("TestBytesEscape result: %s", result) 151 | if result != "10" { 152 | t.Fatalf("escape int error") 153 | 154 | } 155 | 156 | var i2 int8 = 7 157 | result = Escape(i2) 158 | t.Logf("TestBytesEscape result: %s", result) 159 | if result != "7" { 160 | 
t.Fatalf("escape int8 error") 161 | 162 | } 163 | 164 | var i3 int16 = 12 165 | result = Escape(i3) 166 | t.Logf("TestBytesEscape result: %s", result) 167 | if result != "12" { 168 | t.Fatalf("escape int16 error") 169 | 170 | } 171 | 172 | var i4 int32 = 13 173 | result = Escape(i4) 174 | t.Logf("TestBytesEscape result: %s", result) 175 | if result != "13" { 176 | t.Fatalf("escape int32 error") 177 | 178 | } 179 | 180 | var i5 int64 = 14 181 | result = Escape(i5) 182 | t.Logf("TestBytesEscape result: %s", result) 183 | if result != "14" { 184 | t.Fatalf("escape int32 error") 185 | 186 | } 187 | 188 | } 189 | 190 | func TestUIntEscape(t *testing.T) { 191 | var i uint = 10 192 | result := Escape(i) 193 | t.Logf("TestBytesEscape result: %s", result) 194 | if result != "10" { 195 | t.Fatalf("escape int error") 196 | 197 | } 198 | 199 | var i2 uint8 = 7 200 | result = Escape(i2) 201 | t.Logf("TestBytesEscape result: %s", result) 202 | if result != "7" { 203 | t.Fatalf("escape int8 error") 204 | 205 | } 206 | 207 | var i3 uint16 = 12 208 | result = Escape(i3) 209 | t.Logf("TestBytesEscape result: %s", result) 210 | if result != "12" { 211 | t.Fatalf("escape int16 error") 212 | 213 | } 214 | 215 | var i4 uint32 = 13 216 | result = Escape(i4) 217 | t.Logf("TestBytesEscape result: %s", result) 218 | if result != "13" { 219 | t.Fatalf("escape int32 error") 220 | 221 | } 222 | 223 | var i5 uint64 = 14 224 | result = Escape(i5) 225 | t.Logf("TestBytesEscape result: %s", result) 226 | if result != "14" { 227 | t.Fatalf("escape int32 error") 228 | 229 | } 230 | 231 | } 232 | 233 | func TestOtherEscape(t *testing.T) { 234 | x := map[string]string{ 235 | "name": "asd'fsadf", 236 | "key": "test", 237 | } 238 | result := Escape(x) 239 | t.Logf("escape reuslt %s", result) 240 | 241 | if result != `'{\\"key\\":\\"test\\",\\"name\\":\\"asd\'fsadf\\"}'` { 242 | t.Fatalf("escape map error") 243 | 244 | } 245 | 246 | } 247 | 248 | func TestNewlineEscape(t *testing.T) { 249 | s := 
"hello\nworld" 250 | result := Escape(s) 251 | t.Logf("escape newline reuslt: %s", result) 252 | 253 | if result != "'hello\\\\nworld'" { 254 | t.Fatalf("escape string error") 255 | 256 | } 257 | 258 | } 259 | 260 | func TestReturnEscape(t *testing.T) { 261 | s := "hello\rworld" 262 | result := Escape(s) 263 | t.Logf("escape newline reuslt: %s", result) 264 | 265 | if result != "'hello\\\\rworld'" { 266 | t.Fatalf("escape string error") 267 | 268 | } 269 | 270 | } 271 | 272 | func TestTabEscape(t *testing.T) { 273 | s := "hello\tworld" 274 | result := Escape(s) 275 | t.Logf("escape tab reuslt: %s", result) 276 | 277 | if result != `'hello\\tworld'` { 278 | t.Fatalf("escape string error") 279 | 280 | } 281 | 282 | } 283 | 284 | func TestDoubleBackslashEscape(t *testing.T) { 285 | s := "hello\\world" 286 | result := Escape(s) 287 | t.Logf("escape tab reuslt: %s", result) 288 | 289 | if result != `'hello\\\\world'` { 290 | t.Fatalf("escape string error") 291 | 292 | } 293 | 294 | } 295 | 296 | func TestCtrlZEscape(t *testing.T) { 297 | s := "hello\x1aworld" 298 | result := Escape(s) 299 | t.Logf("escape tab reuslt: %s", result) 300 | 301 | if result != `'hello\\Zworld'` { 302 | t.Fatalf("escape string error") 303 | 304 | } 305 | 306 | } 307 | 308 | func TestDoubleQouteEscape(t *testing.T) { 309 | s := "hello \" world" 310 | result := Escape(s) 311 | t.Logf("escape tab reuslt: %s", result) 312 | 313 | if result != `'hello \\" world'` { 314 | t.Fatalf("escape string error") 315 | 316 | } 317 | 318 | } 319 | 320 | func TestFormatSql(t *testing.T) { 321 | 322 | sql := Format("select * from users where name=? and age=? limit ?,?", "t'est", 10, 10, 10) 323 | t.Logf("sql %s", sql) 324 | 325 | if sql != "select * from users where name='t\\'est' and age=10 limit 10,10" { 326 | t.Fatalf("escape format error") 327 | } 328 | 329 | sql = Format("? 
and ?", "a", "b") 330 | t.Logf("sql %s", sql) 331 | 332 | if sql != "'a' and 'b'" { 333 | t.Fatalf("escape format str error") 334 | } 335 | 336 | sql = Format("in (?)", []int{1, 2, 3}) 337 | t.Logf("sql %s", sql) 338 | 339 | if sql != "in (1,2,3)" { 340 | t.Fatalf("escape format arr error") 341 | 342 | } 343 | 344 | sql = Format("in (?)", []interface{}{1, 2, 3}) 345 | t.Logf("sql %s", sql) 346 | 347 | if sql != "in (1,2,3)" { 348 | t.Fatalf("escape format arr error") 349 | 350 | } 351 | 352 | sql = Format("in (?)", []string{"1", "2", "3"}) 353 | t.Logf("sql %s", sql) 354 | 355 | if sql != "in ('1','2','3')" { 356 | t.Fatalf("escape format arr2 error") 357 | 358 | } 359 | 360 | sql = Format("in (?)", []interface{}{"1", "2", "3"}) 361 | t.Logf("sql %s", sql) 362 | 363 | if sql != "in ('1','2','3')" { 364 | t.Fatalf("escape format arr2 error") 365 | 366 | } 367 | 368 | sql = Format("in (?)", []interface{}{1, 2, "3"}) 369 | t.Logf("sql %s", sql) 370 | 371 | if sql != "in (1,2,'3')" { 372 | t.Fatalf("escape format arr error") 373 | 374 | } 375 | 376 | bt, _ := time.ParseInLocation("2006-01-02 15:04:05", "2021-01-01 15:00:09", time.Local) 377 | 378 | sql = Format("a=?", bt) 379 | t.Logf("sql %s", sql) 380 | 381 | if sql != "a='2021-01-01 15:00:09'" { 382 | t.Fatalf("escape format time error") 383 | 384 | } 385 | 386 | sql = Format("select * from users where name=? and age=? limit ?,?", `t\'est`, 10, 10, 10) 387 | 388 | if sql != `'select * from users where name='t\\\\\'est' and age=10 limit 10,10'` { 389 | 390 | t.Logf("sql: %s\n", sql) 391 | } 392 | } 393 | --------------------------------------------------------------------------------