├── .gitignore ├── .travis.yml ├── LICENSE ├── README.md ├── cli ├── cli.go ├── cli_test.go ├── cmd.go ├── cmd_test.go ├── utils.go └── utils_test.go ├── examples ├── client_example │ └── client_example.go └── package_example │ └── package_example.go ├── log └── pid.log ├── main.go ├── parser ├── analyzer.go ├── analyzer_test.go ├── ast.go ├── ast_test.go ├── dependency │ ├── bson │ │ ├── bson_test.go │ │ ├── common.go │ │ ├── custom_test.go │ │ ├── marshal.go │ │ ├── marshal_test.go │ │ ├── marshal_util.go │ │ ├── unmarshal.go │ │ ├── unmarshal_test.go │ │ └── unmarshal_util.go │ ├── bytes2 │ │ ├── chunked_writer.go │ │ └── cw_test.go │ ├── hack │ │ ├── hack.go │ │ └── hack_test.go │ └── sqltypes │ │ ├── sqltypes.go │ │ └── type_test.go ├── filter.go ├── parse_test.go ├── parsed_query.go ├── parsed_query_test.go ├── rewriter.go ├── rewriter_test.go ├── sql.go ├── sql.y ├── token.go └── tracked_buffer.go ├── pgproxy.conf ├── pgproxy.png ├── proxy ├── formate.go ├── proxy.go └── proxy_test.go └── version /.gitignore: -------------------------------------------------------------------------------- 1 | # Compiled Object files, Static and Dynamic libs (Shared Objects) 2 | *.o 3 | *.a 4 | *.so 5 | 6 | # Folders 7 | _obj 8 | _test 9 | 10 | # Architecture specific extensions/prefixes 11 | *.[568vq] 12 | [568vq].out 13 | 14 | *.cgo1.go 15 | *.cgo2.c 16 | _cgo_defun.c 17 | _cgo_gotypes.go 18 | _cgo_export.* 19 | 20 | _testmain.go 21 | 22 | *.exe 23 | *.test 24 | *.prof 25 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: go 2 | 3 | install: 4 | - go get -d -t -v ./... && go build -v ./... 5 | 6 | go: 7 | - 1.6 8 | - 1.7 9 | - 1.8 10 | - tip 11 | 12 | script: 13 | - go vet ./... 
14 | - go test -v -coverprofile=coverage.txt -covermode=atomic 15 | - go build 16 | 17 | after_success: 18 | - bash <(curl -s https://codecov.io/bash) -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 
34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. 
Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "{}" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright {yyyy} {name of copyright owner} 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ![pgproxy](./pgproxy.png) 2 | 3 | # pgproxy 4 | [![Build Status](https://travis-ci.org/wgliang/pgproxy.svg?branch=master)](https://travis-ci.org/wgliang/pgproxy) 5 | [![codecov](https://codecov.io/gh/wgliang/pgproxy/branch/master/graph/badge.svg)](https://codecov.io/gh/wgliang/pgproxy) 6 | [![GoDoc](https://godoc.org/github.com/wgliang/pgproxy?status.svg)](https://godoc.org/github.com/wgliang/pgproxy) 7 | [![Code Health](https://landscape.io/github/wgliang/pgproxy/master/landscape.svg?style=flat)](https://landscape.io/github/wgliang/pgproxy/master) 8 | [![Code Issues](https://www.quantifiedcode.com/api/v1/project/98b2cb0efd774c5fa8f9299c4f96a8c5/badge.svg)](https://www.quantifiedcode.com/app/project/98b2cb0efd774c5fa8f9299c4f96a8c5) 9 | [![Go Report Card](https://goreportcard.com/badge/github.com/wgliang/pgproxy)](https://goreportcard.com/report/github.com/wgliang/pgproxy) 10 | 
[![License](https://img.shields.io/badge/LICENSE-Apache2.0-ff69b4.svg)](http://www.apache.org/licenses/LICENSE-2.0.html) 11 | 12 | pgproxy is a PostgreSQL proxy server that redirects client connections through a pipe and lets you filter the SQL statements they send. Beyond SQL statement analysis, it is planned to support multi-database backup, distributed databases and other deployment schemes. 13 | 14 | You can use it for: 15 | 16 | * database read/write separation 17 | * database disaster recovery 18 | * database proxying 19 | * SQL statement rewriting 20 | * filtering dangerous SQL 21 | * monitoring database operations 22 | * SQL request rate limiting and merging 23 | 24 | ## Installation 25 | 26 | ``` 27 | $ go get -u github.com/wgliang/pgproxy 28 | ``` 29 | 30 | ## Usage 31 | 32 | ### As a separate application 33 | 34 | Start or shut down the proxy server: 35 | ``` 36 | $ pgproxy start/stop 37 | ``` 38 | 39 | Use pgproxy on the command line: 40 | ``` 41 | $ pgproxy cli 42 | ``` 43 | 44 | Note: the pgproxy shell is used just like a native database command line. 45 | 46 | ### As a package 47 | 48 | [package_example](https://github.com/wgliang/pgproxy/blob/master/examples/package_example/package_example.go) 49 | 50 | ``` 51 | package main 52 | 53 | import ( 54 | "fmt" 55 | "os" 56 | "os/signal" 57 | "syscall" 58 | 59 | "github.com/wgliang/pgproxy/cli" 60 | ) 61 | 62 | func main() { 63 | // call proxy 64 | cli.Main("../pgproxy.conf", []string{"pgproxy", "start"}) 65 | 66 | // 捕获ctrl-c,平滑退出 67 | chExit := make(chan os.Signal, 1) 68 | signal.Notify(chExit, syscall.SIGINT, syscall.SIGTERM, syscall.SIGKILL) 69 | select { 70 | case <-chExit: 71 | fmt.Println("Example EXITING...Bye.") 72 | } 73 | } 74 | 75 | ``` 76 | 77 | ## Support 78 | 79 | SELECT, DELETE and UPDATE statements are supported; keywords are case-insensitive. 80 | 81 | SQL standard support: 82 | 83 | The parser is forked from the [sqlparser](https://github.com/youtube/vitess/tree/master/go/vt/sqlparser) package in YouTube's Vitess.
84 | 85 | In pgproxy, database tables are like MySQL (5.6, 5.7) relational tables, and you can use relational modeling schemes (normalization) to structure your schema. It supports almost all MySQL (5.6, 5.7) scalar data types. It also provides full SQL support within a shard, including JOIN statements. Some PostgreSQL operations are not supported; for details, see [supported types and keywords](https://github.com/wgliang/pgproxy/blob/master/parser/token.go#L37). 86 | 87 | 88 | ## Credits 89 | 90 | Package parser is based on [sqlparser](https://github.com/xwb1989/sqlparser). 91 | -------------------------------------------------------------------------------- /cli/cli.go: -------------------------------------------------------------------------------- 1 | // Copyright 2017 wgliang. All rights reserved. 2 | // Use of this source code is governed by Apache 3 | // license that can be found in the LICENSE file. 4 | 5 | // Package cli provides virtual command-line access 6 | // in pgproxy include start,cli and stop action.
7 | package cli 8 | 9 | import ( 10 | "flag" 11 | "fmt" 12 | "os" 13 | "strconv" 14 | "syscall" 15 | "time" 16 | 17 | "github.com/golang/glog" 18 | "github.com/wgliang/pgproxy/parser" 19 | "github.com/wgliang/pgproxy/proxy" 20 | ) 21 | 22 | var ( 23 | connStr string 24 | pc ProxyConfig 25 | ) 26 | 27 | // pgproxy Main 28 | func Main(config interface{}, pargs interface{}) { 29 | var proxyconf = flag.String("config", "pgproxy.conf", "configuration file for pgproxy") 30 | 31 | flag.Parse() 32 | defer glog.Flush() 33 | 34 | var args []string 35 | if nil != config { 36 | pc, connStr = readConfig(config.(string)) 37 | args = pargs.([]string) 38 | } else { 39 | pc, connStr = readConfig(*proxyconf) 40 | args = os.Args 41 | fmt.Println(args) 42 | } 43 | 44 | if len(args) < 2 { 45 | glog.Errorln("needed one parameters:", args) 46 | help() 47 | return 48 | } else { 49 | if args[1] == "start" { 50 | glog.Infoln("Starting pgproxy...") 51 | info(pc.ServerConfig.ProxyAddr) 52 | logDir() 53 | saveCurrentPid() 54 | proxy.Start(pc.ServerConfig.ProxyAddr, pc.DB["master"].Addr, parser.Filter, parser.Return) 55 | glog.Infoln("Started pgproxy successfully.") 56 | } else if args[1] == "cli" { 57 | Command() 58 | } else if args[1] == "stop" { 59 | stop() 60 | } else { 61 | help() 62 | } 63 | } 64 | } 65 | 66 | // print pgproxy help 67 | func help() { 68 | fmt.Println(" pgproxy is a proxy-server for database postgresql.") 69 | fmt.Println(" start :start pgproxy server.") 70 | fmt.Println(" stop :stop pgproxy server.") 71 | fmt.Println(" version :pgproxy version.") 72 | fmt.Println(" info :pgproxy info.") 73 | } 74 | 75 | // print pgproxy infomation 76 | func info(proxyhost string) { 77 | fmt.Println(Logo) 78 | hostname, err := os.Hostname() 79 | if err != nil { 80 | hostname = "" 81 | } 82 | pid := strconv.Itoa(os.Getpid()) 83 | starttime := time.Now().Format("2006-01-02 03:04:05 PM") 84 | fmt.Println(" ", VERSION) 85 | fmt.Println(" Host: " + hostname) 86 | fmt.Println(" Pid:", 
string(pid)) 87 | fmt.Println(" Proxy:", proxyhost) 88 | fmt.Println(" Starttime:", starttime) 89 | fmt.Println() 90 | } 91 | 92 | // set log dir 93 | func logDir() { 94 | _, err := os.Stat("./log") 95 | if err != nil && os.IsNotExist(err) { 96 | err := os.MkdirAll("./log", 0777) 97 | if err != nil { 98 | glog.Fatalln(err) 99 | } else { 100 | glog.Infoln("glog and process pid in ./log") 101 | } 102 | } 103 | } 104 | 105 | // save current pgproxy pid 106 | func saveCurrentPid() { 107 | // pid file 108 | filepath := "./log/pid.log" 109 | fout, err := os.OpenFile(filepath, os.O_CREATE|os.O_RDWR, 0777) 110 | if err != nil { 111 | glog.Errorln(err) 112 | return 113 | } 114 | defer fout.Close() 115 | // write current pid 116 | fout.WriteString(strconv.Itoa(os.Getpid())) 117 | } 118 | 119 | // get current pgproxy pid 120 | func getCurrentPid() int { 121 | // pid file 122 | filepath := "./log/pid.log" 123 | fin, err := os.OpenFile(filepath, os.O_RDONLY, 0777) 124 | if err != nil { 125 | glog.Errorln(err) 126 | return 0 127 | } 128 | defer fin.Close() 129 | // read current pid 130 | buf := make([]byte, 1024) 131 | 132 | n, _ := fin.Read(buf) 133 | if 0 >= n { 134 | return 0 135 | } else { 136 | pid, err := strconv.Atoi(string(buf[0:n])) 137 | if err != nil { 138 | glog.Errorln(err) 139 | return 0 140 | } else { 141 | return pid 142 | } 143 | } 144 | } 145 | 146 | // stop pgproxy 147 | func stop() { 148 | pid := getCurrentPid() 149 | if pid != 0 { 150 | err := syscall.Kill(pid, syscall.SIGTERM) 151 | if err != nil { 152 | glog.Errorln(err) 153 | } else { 154 | glog.Infoln("pgproxy exit successfully!") 155 | } 156 | } 157 | fmt.Printf("pgproxy(%d) Exit,thanks.\n", pid) 158 | } 159 | -------------------------------------------------------------------------------- /cli/cli_test.go: -------------------------------------------------------------------------------- 1 | package cli 2 | 3 | import ( 4 | "testing" 5 | ) 6 | 7 | func Test_Main(t *testing.T) { 8 | Main(nil, nil) 9 | } 
10 | -------------------------------------------------------------------------------- /cli/cmd.go: -------------------------------------------------------------------------------- 1 | // Copyright 2017 wgliang. All rights reserved. 2 | // Use of this source code is governed by Apache 3 | // license that can be found in the LICENSE file. 4 | 5 | // Package cli provides virtual command-line access 6 | // in pgproxy include start,cli and stop action. 7 | package cli 8 | 9 | import ( 10 | "bufio" 11 | "fmt" 12 | "os" 13 | "strings" 14 | "time" 15 | 16 | "github.com/golang/glog" 17 | "github.com/jmoiron/sqlx" 18 | _ "github.com/lib/pq" 19 | "github.com/wgliang/pgproxy/proxy" 20 | ) 21 | 22 | type Client struct { 23 | db *sqlx.DB 24 | timestamp int64 25 | } 26 | 27 | // Command line access to pgproxy and provide a friendly display 28 | // interface. 29 | func Command() { 30 | client := new(Client) 31 | var err error 32 | client.db, err = sqlx.Open("postgres", connStr) 33 | if err != nil { 34 | glog.Fatalln(err) 35 | } 36 | client.timestamp = time.Now().Unix() 37 | 38 | // Set connections num 39 | client.db.SetMaxIdleConns(1) 40 | client.db.SetMaxOpenConns(10) 41 | client.db.SetConnMaxLifetime(60 * time.Second) 42 | defer func() { 43 | client.db.Close() 44 | if err != nil { 45 | glog.Errorln(err) 46 | } 47 | }() 48 | fmt.Printf(" pgproxy (%s)\n", VERSION) 49 | fmt.Println(" Login in:", time.Unix(client.timestamp, 0).Format("2006-01-02 03:04:05 PM")) 50 | fmt.Println(` Type "help" for help.`) 51 | running := true 52 | reader := bufio.NewReader(os.Stdin) 53 | for running { 54 | // Sleep some Nanoseconds wait for event have been deal. 
55 | time.Sleep(300000 * time.Nanosecond) 56 | fmt.Print("pgproxy#") 57 | data, _, _ := reader.ReadLine() 58 | command := string(data) 59 | if command == "quit" { 60 | fmt.Println("pgproxy Exit!") 61 | return 62 | } 63 | client.Request(command) 64 | } 65 | return 66 | } 67 | 68 | // Client request switcher,for different types of sql statement 69 | // calls different requests. 70 | func (c *Client) Request(sql string) { 71 | index := strings.Index(sql, " ") 72 | if index == -1 { 73 | index = len(sql) 74 | } 75 | // Choose right function for requests. 76 | switch strings.ToLower(sql[0:index]) { 77 | case "select": 78 | rows, err := c.db.Query(sql) 79 | if err != nil { 80 | glog.Errorln(err) 81 | } else { 82 | proxy.RowsFormater(rows) 83 | } 84 | case "insert", "delete", "update": 85 | res, err := c.db.Exec(sql) 86 | if err != nil { 87 | glog.Errorln(err) 88 | } else { 89 | proxy.ResultFormater(res) 90 | } 91 | case `\d`, `\l`, `\q`: 92 | // res := c.db.Exec(sql) 93 | } 94 | } 95 | -------------------------------------------------------------------------------- /cli/cmd_test.go: -------------------------------------------------------------------------------- 1 | package cli 2 | 3 | import ( 4 | "testing" 5 | ) 6 | 7 | func Test_Command(t *testing.T) { 8 | // Command() 9 | } 10 | -------------------------------------------------------------------------------- /cli/utils.go: -------------------------------------------------------------------------------- 1 | // Copyright 2017 wgliang. All rights reserved. 2 | // Use of this source code is governed by Apache 3 | // license that can be found in the LICENSE file. 4 | 5 | // Package cli provides virtual command-line access 6 | // in pgproxy include start,cli and stop action. 
7 | package cli 8 | 9 | import ( 10 | "fmt" 11 | "os" 12 | "strings" 13 | "syscall" 14 | 15 | "github.com/bbangert/toml" 16 | "github.com/golang/glog" 17 | ) 18 | 19 | const Logo = ` 20 | ____ ____ _____ _________ _ ____ __ 21 | / __ \/ __ '/ __ \/ ___/ __ \| |/_/ / / / 22 | / /_/ / /_/ / /_/ / / / /_/ /> " 580 | AST_LE = "<=" 581 | AST_GE = ">=" 582 | AST_NE = "!=" 583 | AST_NSE = "<=>" 584 | AST_IN = "in" 585 | AST_NOT_IN = "not in" 586 | AST_LIKE = "like" 587 | AST_NOT_LIKE = "not like" 588 | ) 589 | 590 | func (node *ComparisonExpr) Format(buf *TrackedBuffer) { 591 | buf.Myprintf("%v %s %v", node.Left, node.Operator, node.Right) 592 | } 593 | 594 | // RangeCond represents a BETWEEN or a NOT BETWEEN expression. 595 | type RangeCond struct { 596 | Operator string 597 | Left ValExpr 598 | From, To ValExpr 599 | } 600 | 601 | // RangeCond.Operator 602 | const ( 603 | AST_BETWEEN = "between" 604 | AST_NOT_BETWEEN = "not between" 605 | ) 606 | 607 | func (node *RangeCond) Format(buf *TrackedBuffer) { 608 | buf.Myprintf("%v %s %v and %v", node.Left, node.Operator, node.From, node.To) 609 | } 610 | 611 | // NullCheck represents an IS NULL or an IS NOT NULL expression. 612 | type NullCheck struct { 613 | Operator string 614 | Expr ValExpr 615 | } 616 | 617 | // NullCheck.Operator 618 | const ( 619 | AST_IS_NULL = "is null" 620 | AST_IS_NOT_NULL = "is not null" 621 | ) 622 | 623 | func (node *NullCheck) Format(buf *TrackedBuffer) { 624 | buf.Myprintf("%v %s", node.Expr, node.Operator) 625 | } 626 | 627 | // ExistsExpr represents an EXISTS expression. 628 | type ExistsExpr struct { 629 | Subquery *Subquery 630 | } 631 | 632 | func (node *ExistsExpr) Format(buf *TrackedBuffer) { 633 | buf.Myprintf("exists %v", node.Subquery) 634 | } 635 | 636 | // ValExpr represents a value expression. 
637 | type ValExpr interface { 638 | IValExpr() 639 | Expr 640 | } 641 | 642 | func (StrVal) IValExpr() {} 643 | func (NumVal) IValExpr() {} 644 | func (ValArg) IValExpr() {} 645 | func (*NullVal) IValExpr() {} 646 | func (*ColName) IValExpr() {} 647 | func (ValTuple) IValExpr() {} 648 | func (*Subquery) IValExpr() {} 649 | func (ListArg) IValExpr() {} 650 | func (*BinaryExpr) IValExpr() {} 651 | func (*UnaryExpr) IValExpr() {} 652 | func (*FuncExpr) IValExpr() {} 653 | func (*CaseExpr) IValExpr() {} 654 | 655 | // StrVal represents a string value. 656 | type StrVal []byte 657 | 658 | func (node StrVal) Format(buf *TrackedBuffer) { 659 | s := sqltypes.MakeString([]byte(node)) 660 | s.EncodeSql(buf) 661 | } 662 | 663 | // NumVal represents a number. 664 | type NumVal []byte 665 | 666 | func (node NumVal) Format(buf *TrackedBuffer) { 667 | buf.Myprintf("%s", []byte(node)) 668 | } 669 | 670 | // ValArg represents a named bind var argument. 671 | type ValArg []byte 672 | 673 | func (node ValArg) Format(buf *TrackedBuffer) { 674 | buf.WriteArg(string(node)) 675 | } 676 | 677 | // NullVal represents a NULL value. 678 | type NullVal struct{} 679 | 680 | func (node *NullVal) Format(buf *TrackedBuffer) { 681 | buf.Myprintf("null") 682 | } 683 | 684 | // ColName represents a column name. 685 | type ColName struct { 686 | Name, Qualifier []byte 687 | } 688 | 689 | func (node *ColName) Format(buf *TrackedBuffer) { 690 | if node.Qualifier != nil { 691 | escape(buf, node.Qualifier) 692 | buf.Myprintf(".") 693 | } 694 | escape(buf, node.Name) 695 | } 696 | 697 | func escape(buf *TrackedBuffer, name []byte) { 698 | if _, ok := keywords[string(name)]; ok { 699 | buf.Myprintf("`%s`", name) 700 | } else { 701 | buf.Myprintf("%s", name) 702 | } 703 | } 704 | 705 | // ColTuple represents a list of column values. 706 | // It can be ValTuple, Subquery, ListArg. 
707 | type ColTuple interface { 708 | IColTuple() 709 | ValExpr 710 | } 711 | 712 | func (ValTuple) IColTuple() {} 713 | func (*Subquery) IColTuple() {} 714 | func (ListArg) IColTuple() {} 715 | 716 | // ValTuple represents a tuple of actual values. 717 | type ValTuple ValExprs 718 | 719 | func (node ValTuple) Format(buf *TrackedBuffer) { 720 | buf.Myprintf("(%v)", ValExprs(node)) 721 | } 722 | 723 | // ValExprs represents a list of value expressions. 724 | // It's not a valid expression because it's not parenthesized. 725 | type ValExprs []ValExpr 726 | 727 | func (node ValExprs) Format(buf *TrackedBuffer) { 728 | var prefix string 729 | for _, n := range node { 730 | buf.Myprintf("%s%v", prefix, n) 731 | prefix = ", " 732 | } 733 | } 734 | 735 | // Subquery represents a subquery. 736 | type Subquery struct { 737 | Select SelectStatement 738 | } 739 | 740 | func (node *Subquery) Format(buf *TrackedBuffer) { 741 | buf.Myprintf("(%v)", node.Select) 742 | } 743 | 744 | // ListArg represents a named list argument. 745 | type ListArg []byte 746 | 747 | func (node ListArg) Format(buf *TrackedBuffer) { 748 | buf.WriteArg(string(node)) 749 | } 750 | 751 | // BinaryExpr represents a binary value expression. 752 | type BinaryExpr struct { 753 | Operator byte 754 | Left, Right Expr 755 | } 756 | 757 | // BinaryExpr.Operator 758 | const ( 759 | AST_BITAND = '&' 760 | AST_BITOR = '|' 761 | AST_BITXOR = '^' 762 | AST_PLUS = '+' 763 | AST_MINUS = '-' 764 | AST_MULT = '*' 765 | AST_DIV = '/' 766 | AST_MOD = '%' 767 | ) 768 | 769 | func (node *BinaryExpr) Format(buf *TrackedBuffer) { 770 | buf.Myprintf("%v%c%v", node.Left, node.Operator, node.Right) 771 | } 772 | 773 | // UnaryExpr represents a unary value expression. 
774 | type UnaryExpr struct { 775 | Operator byte 776 | Expr Expr 777 | } 778 | 779 | // UnaryExpr.Operator 780 | const ( 781 | AST_UPLUS = '+' 782 | AST_UMINUS = '-' 783 | AST_TILDA = '~' 784 | ) 785 | 786 | func (node *UnaryExpr) Format(buf *TrackedBuffer) { 787 | buf.Myprintf("%c%v", node.Operator, node.Expr) 788 | } 789 | 790 | // FuncExpr represents a function call. 791 | type FuncExpr struct { 792 | Name []byte 793 | Distinct bool 794 | Exprs SelectExprs 795 | } 796 | 797 | func (node *FuncExpr) Format(buf *TrackedBuffer) { 798 | var distinct string 799 | if node.Distinct { 800 | distinct = "distinct " 801 | } 802 | buf.Myprintf("%s(%s%v)", node.Name, distinct, node.Exprs) 803 | } 804 | 805 | // Aggregates is a map of all aggregate functions. 806 | var Aggregates = map[string]bool{ 807 | "avg": true, 808 | "bit_and": true, 809 | "bit_or": true, 810 | "bit_xor": true, 811 | "count": true, 812 | "group_concat": true, 813 | "max": true, 814 | "min": true, 815 | "std": true, 816 | "stddev_pop": true, 817 | "stddev_samp": true, 818 | "stddev": true, 819 | "sum": true, 820 | "var_pop": true, 821 | "var_samp": true, 822 | "variance": true, 823 | } 824 | 825 | func (node *FuncExpr) IsAggregate() bool { 826 | return Aggregates[string(node.Name)] 827 | } 828 | 829 | // CaseExpr represents a CASE expression. 830 | type CaseExpr struct { 831 | Expr ValExpr 832 | Whens []*When 833 | Else ValExpr 834 | } 835 | 836 | func (node *CaseExpr) Format(buf *TrackedBuffer) { 837 | buf.Myprintf("case ") 838 | if node.Expr != nil { 839 | buf.Myprintf("%v ", node.Expr) 840 | } 841 | for _, when := range node.Whens { 842 | buf.Myprintf("%v ", when) 843 | } 844 | if node.Else != nil { 845 | buf.Myprintf("else %v ", node.Else) 846 | } 847 | buf.Myprintf("end") 848 | } 849 | 850 | // When represents a WHEN sub-expression. 
851 | type When struct { 852 | Cond BoolExpr 853 | Val ValExpr 854 | } 855 | 856 | func (node *When) Format(buf *TrackedBuffer) { 857 | buf.Myprintf("when %v then %v", node.Cond, node.Val) 858 | } 859 | 860 | // GroupBy represents a GROUP BY clause. 861 | type GroupBy []ValExpr 862 | 863 | func (node GroupBy) Format(buf *TrackedBuffer) { 864 | prefix := " group by " 865 | for _, n := range node { 866 | buf.Myprintf("%s%v", prefix, n) 867 | prefix = ", " 868 | } 869 | } 870 | 871 | // OrderBy represents an ORDER By clause. 872 | type OrderBy []*Order 873 | 874 | func (node OrderBy) Format(buf *TrackedBuffer) { 875 | prefix := " order by " 876 | for _, n := range node { 877 | buf.Myprintf("%s%v", prefix, n) 878 | prefix = ", " 879 | } 880 | } 881 | 882 | // Order represents an ordering expression. 883 | type Order struct { 884 | Expr ValExpr 885 | Direction string 886 | } 887 | 888 | // Order.Direction 889 | const ( 890 | AST_ASC = "asc" 891 | AST_DESC = "desc" 892 | ) 893 | 894 | func (node *Order) Format(buf *TrackedBuffer) { 895 | buf.Myprintf("%v %s", node.Expr, node.Direction) 896 | } 897 | 898 | // Limit represents a LIMIT clause. 899 | type Limit struct { 900 | Offset, Rowcount ValExpr 901 | } 902 | 903 | func (node *Limit) Format(buf *TrackedBuffer) { 904 | if node == nil { 905 | return 906 | } 907 | buf.Myprintf(" limit ") 908 | if node.Offset != nil { 909 | buf.Myprintf("%v, ", node.Offset) 910 | } 911 | buf.Myprintf("%v", node.Rowcount) 912 | } 913 | 914 | // Limits returns the values of the LIMIT clause as interfaces. 915 | // The returned values can be nil for absent field, string for 916 | // bind variable names, or int64 for an actual number. 917 | // Otherwise, it's an error. 
918 | func (node *Limit) Limits() (offset, rowcount interface{}, err error) { 919 | if node == nil { 920 | return nil, nil, nil 921 | } 922 | switch v := node.Offset.(type) { 923 | case NumVal: 924 | o, err := strconv.ParseInt(string(v), 0, 64) 925 | if err != nil { 926 | return nil, nil, err 927 | } 928 | if o < 0 { 929 | return nil, nil, fmt.Errorf("negative offset: %d", o) 930 | } 931 | offset = o 932 | case ValArg: 933 | offset = string(v) 934 | case nil: 935 | // pass 936 | default: 937 | return nil, nil, fmt.Errorf("unexpected node for offset: %+v", v) 938 | } 939 | switch v := node.Rowcount.(type) { 940 | case NumVal: 941 | rc, err := strconv.ParseInt(string(v), 0, 64) 942 | if err != nil { 943 | return nil, nil, err 944 | } 945 | if rc < 0 { 946 | return nil, nil, fmt.Errorf("negative limit: %d", rc) 947 | } 948 | rowcount = rc 949 | case ValArg: 950 | rowcount = string(v) 951 | default: 952 | return nil, nil, fmt.Errorf("unexpected node for rowcount: %+v", v) 953 | } 954 | return offset, rowcount, nil 955 | } 956 | 957 | // Values represents a VALUES clause. 958 | type Values []RowTuple 959 | 960 | func (node Values) Format(buf *TrackedBuffer) { 961 | prefix := "values " 962 | for _, n := range node { 963 | buf.Myprintf("%s%v", prefix, n) 964 | prefix = ", " 965 | } 966 | } 967 | 968 | // RowTuple represents a row of values. It can be ValTuple, Subquery. 969 | type RowTuple interface { 970 | IRowTuple() 971 | ValExpr 972 | } 973 | 974 | func (ValTuple) IRowTuple() {} 975 | func (*Subquery) IRowTuple() {} 976 | 977 | // UpdateExprs represents a list of update expressions. 978 | type UpdateExprs []*UpdateExpr 979 | 980 | func (node UpdateExprs) Format(buf *TrackedBuffer) { 981 | var prefix string 982 | for _, n := range node { 983 | buf.Myprintf("%s%v", prefix, n) 984 | prefix = ", " 985 | } 986 | } 987 | 988 | // UpdateExpr represents an update expression. 
989 | type UpdateExpr struct { 990 | Name *ColName 991 | Expr ValExpr 992 | } 993 | 994 | func (node *UpdateExpr) Format(buf *TrackedBuffer) { 995 | buf.Myprintf("%v = %v", node.Name, node.Expr) 996 | } 997 | 998 | // OnDup represents an ON DUPLICATE KEY clause. 999 | type OnDup UpdateExprs 1000 | 1001 | func (node OnDup) Format(buf *TrackedBuffer) { 1002 | if node == nil { 1003 | return 1004 | } 1005 | buf.Myprintf(" on duplicate key update %v", UpdateExprs(node)) 1006 | } 1007 | 1008 | //ast keywords added for create table parsing 1009 | 1010 | const ( 1011 | //other keywords 1012 | AST_UNSIGNED = "unsigned" 1013 | AST_ZEROFILL = "zerofill" 1014 | 1015 | //datatypes 1016 | AST_BIT = "bit" 1017 | AST_TINYINT = "tinyint" 1018 | AST_SMALLINT = "smallint" 1019 | AST_MEDIUMINT = "mediumint" 1020 | AST_INT = "int" 1021 | AST_INTEGER = "integer" 1022 | AST_BIGINT = "bigint" 1023 | 1024 | AST_REAL = "real" 1025 | AST_DOUBLE = "double" 1026 | AST_FLOAT = "float" 1027 | AST_DECIMAL = "decimal" 1028 | AST_NUMERIC = "numeric" 1029 | 1030 | AST_CHAR = "char" 1031 | AST_VARCHAR = "varchar" 1032 | AST_TEXT = "text" 1033 | 1034 | AST_DATE = "date" 1035 | AST_TIME = "time" 1036 | AST_TIMESTAMP = "timestamp" 1037 | AST_DATETIME = "datetime" 1038 | AST_YEAR = "year" 1039 | 1040 | AST_PRIMARY_KEY = "primary key" 1041 | 1042 | AST_UNIQUE_KEY = "unique key" 1043 | AST_AUTO_INCREMENT = "auto_increment" 1044 | AST_NOT_NULL = "not null" 1045 | AST_DEFAULT = "default" 1046 | AST_KEY = "key" 1047 | ) 1048 | -------------------------------------------------------------------------------- /parser/ast_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2014, Google Inc. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 
4 | 5 | package parser 6 | 7 | import "testing" 8 | 9 | func TestLimits(t *testing.T) { 10 | var l *Limit 11 | o, r, err := l.Limits() 12 | if o != nil || r != nil || err != nil { 13 | t.Errorf("got %v, %v, %v, want nils", o, r, err) 14 | } 15 | 16 | l = &Limit{Offset: NumVal([]byte("aa"))} 17 | _, _, err = l.Limits() 18 | wantErr := "strconv.ParseInt: parsing \"aa\": invalid syntax" 19 | if err == nil || err.Error() != wantErr { 20 | t.Errorf("got %v, want %s", err, wantErr) 21 | } 22 | 23 | l = &Limit{Offset: NumVal([]byte("2"))} 24 | _, _, err = l.Limits() 25 | wantErr = "unexpected node for rowcount: " 26 | if err == nil || err.Error() != wantErr { 27 | t.Errorf("got %v, want %s", err, wantErr) 28 | } 29 | 30 | l = &Limit{Offset: StrVal([]byte("2"))} 31 | _, _, err = l.Limits() 32 | wantErr = "unexpected node for offset: [50]" 33 | if err == nil || err.Error() != wantErr { 34 | t.Errorf("got %v, want %s", err, wantErr) 35 | } 36 | 37 | l = &Limit{Offset: NumVal([]byte("2")), Rowcount: NumVal([]byte("aa"))} 38 | _, _, err = l.Limits() 39 | wantErr = "strconv.ParseInt: parsing \"aa\": invalid syntax" 40 | if err == nil || err.Error() != wantErr { 41 | t.Errorf("got %v, want %s", err, wantErr) 42 | } 43 | 44 | l = &Limit{Offset: NumVal([]byte("2")), Rowcount: NumVal([]byte("3"))} 45 | o, r, err = l.Limits() 46 | if o.(int64) != 2 || r.(int64) != 3 || err != nil { 47 | t.Errorf("got %v %v %v, want 2, 3, nil", o, r, err) 48 | } 49 | 50 | l = &Limit{Offset: ValArg([]byte(":a")), Rowcount: NumVal([]byte("3"))} 51 | o, r, err = l.Limits() 52 | if o.(string) != ":a" || r.(int64) != 3 || err != nil { 53 | t.Errorf("got %v %v %v, want :a, 3, nil", o, r, err) 54 | } 55 | 56 | l = &Limit{Offset: nil, Rowcount: NumVal([]byte("3"))} 57 | o, r, err = l.Limits() 58 | if o != nil || r.(int64) != 3 || err != nil { 59 | t.Errorf("got %v %v %v, want nil, 3, nil", o, r, err) 60 | } 61 | 62 | l = &Limit{Offset: nil, Rowcount: ValArg([]byte(":a"))} 63 | o, r, err = l.Limits() 64 | if 
o != nil || r.(string) != ":a" || err != nil { 65 | t.Errorf("got %v %v %v, want nil, :a, nil", o, r, err) 66 | } 67 | 68 | l = &Limit{Offset: NumVal([]byte("-2")), Rowcount: NumVal([]byte("0"))} 69 | _, _, err = l.Limits() 70 | wantErr = "negative offset: -2" 71 | if err == nil || err.Error() != wantErr { 72 | t.Errorf("got %v, want %s", err, wantErr) 73 | } 74 | 75 | l = &Limit{Offset: NumVal([]byte("2")), Rowcount: NumVal([]byte("-2"))} 76 | _, _, err = l.Limits() 77 | wantErr = "negative limit: -2" 78 | if err == nil || err.Error() != wantErr { 79 | t.Errorf("got %v, want %s", err, wantErr) 80 | } 81 | } 82 | -------------------------------------------------------------------------------- /parser/dependency/bson/bson_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2012, Google Inc. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | 5 | package bson 6 | 7 | import ( 8 | "bytes" 9 | "reflect" 10 | "testing" 11 | "time" 12 | 13 | "github.com/wgliang/pgproxy/parser/dependency/bytes2" 14 | ) 15 | 16 | type alltypes struct { 17 | Bytes []byte 18 | Float64 float64 19 | String string 20 | Bool bool 21 | Time time.Time 22 | Int64 int64 23 | Int32 int32 24 | Int int 25 | Uint64 uint64 26 | Uint32 uint32 27 | Uint uint 28 | Strings []string 29 | Nil interface{} 30 | } 31 | 32 | func (a *alltypes) UnmarshalBson(buf *bytes.Buffer, kind byte) { 33 | VerifyObject(kind) 34 | Next(buf, 4) 35 | 36 | kind = NextByte(buf) 37 | for kind != EOO { 38 | key := ReadCString(buf) 39 | switch key { 40 | case "Bytes": 41 | verifyKind("Bytes", Binary, kind) 42 | a.Bytes = DecodeBinary(buf, kind) 43 | case "Float64": 44 | verifyKind("Float64", Number, kind) 45 | a.Float64 = DecodeFloat64(buf, kind) 46 | case "String": 47 | verifyKind("String", Binary, kind) 48 | // Put an easter egg here to verify the function is called 49 | a.String = 
DecodeString(buf, kind) + "1" 50 | case "Bool": 51 | verifyKind("Bool", Boolean, kind) 52 | a.Bool = DecodeBool(buf, kind) 53 | case "Time": 54 | verifyKind("Time", Datetime, kind) 55 | a.Time = DecodeTime(buf, kind) 56 | case "Int32": 57 | verifyKind("Int32", Int, kind) 58 | a.Int32 = DecodeInt32(buf, kind) 59 | case "Int": 60 | verifyKind("Int", Long, kind) 61 | a.Int = DecodeInt(buf, kind) 62 | case "Int64": 63 | verifyKind("Int64", Long, kind) 64 | a.Int64 = DecodeInt64(buf, kind) 65 | case "Uint64": 66 | verifyKind("Uint64", Ulong, kind) 67 | a.Uint64 = DecodeUint64(buf, kind) 68 | case "Uint32": 69 | verifyKind("Uint32", Ulong, kind) 70 | a.Uint32 = DecodeUint32(buf, kind) 71 | case "Uint": 72 | verifyKind("Uint", Ulong, kind) 73 | a.Uint = DecodeUint(buf, kind) 74 | case "Strings": 75 | verifyKind("Strings", Array, kind) 76 | a.Strings = DecodeStringArray(buf, kind) 77 | case "Nil": 78 | verifyKind("Nil", Null, kind) 79 | default: 80 | Skip(buf, kind) 81 | } 82 | kind = NextByte(buf) 83 | } 84 | } 85 | 86 | func verifyKind(tag string, want, got byte) { 87 | if want != got { 88 | panic(NewBsonError("Decode %s, kind is %v, want %v", tag, got, want)) 89 | } 90 | } 91 | 92 | // TODO(sougou): Revisit usefulness of this test 93 | func TestUnmarshalUtil(t *testing.T) { 94 | a := alltypes{ 95 | Bytes: []byte("bytes"), 96 | Float64: float64(64), 97 | String: "string", 98 | Bool: true, 99 | Time: time.Unix(1136243045, 0).UTC(), 100 | Int64: int64(-0x8000000000000000), 101 | Int32: int32(-0x80000000), 102 | Int: int(-0x80000000), 103 | Uint64: uint64(0xFFFFFFFFFFFFFFFF), 104 | Uint32: uint32(0xFFFFFFFF), 105 | Uint: uint(0xFFFFFFFF), 106 | Strings: []string{"a", "b"}, 107 | Nil: nil, 108 | } 109 | got := verifyMarshal(t, a) 110 | var out alltypes 111 | verifyUnmarshal(t, got, &out) 112 | // Verify easter egg 113 | if out.String != "string1" { 114 | t.Errorf("got %s, want %s", out.String, "string1") 115 | } 116 | out.String = "string" 117 | if !reflect.DeepEqual(a, out) 
{ 118 | t.Errorf("got\n%+v, want\n%+v", out, a) 119 | } 120 | 121 | b := alltypes{Bytes: []byte(""), Strings: []string{"a"}} 122 | got = verifyMarshal(t, b) 123 | var outb alltypes 124 | verifyUnmarshal(t, got, &outb) 125 | if outb.Bytes == nil || len(outb.Bytes) != 0 { 126 | t.Errorf("got %q, want nil", string(outb.Bytes)) 127 | } 128 | } 129 | 130 | func TestTypes(t *testing.T) { 131 | in := map[string]interface{}{ 132 | "bytes": []byte("bytes"), 133 | "float64": float64(64), 134 | "string": "string", 135 | "bool": true, 136 | "time": time.Unix(1136243045, 0).UTC(), 137 | "int64": int64(-0x8000000000000000), 138 | "int32": int32(-0x80000000), 139 | "int": int(-0x80000000), 140 | "uint64": uint64(0xFFFFFFFFFFFFFFFF), 141 | "uint32": uint32(0xFFFFFFFF), 142 | "uint": uint(0xFFFFFFFF), 143 | "slice": []interface{}{1, nil}, 144 | "nil": nil, 145 | } 146 | marshalled := verifyMarshal(t, in) 147 | got := make(map[string]interface{}) 148 | verifyUnmarshal(t, marshalled, &got) 149 | 150 | want := map[string]interface{}{ 151 | "bytes": []byte("bytes"), 152 | "float64": float64(64), 153 | "string": []byte("string"), 154 | "bool": true, 155 | "time": time.Unix(1136243045, 0).UTC(), 156 | "int64": int64(-0x8000000000000000), 157 | "int32": int32(-0x80000000), 158 | "int": int64(-0x80000000), 159 | "uint64": uint64(0xFFFFFFFFFFFFFFFF), 160 | "uint32": uint64(0xFFFFFFFF), 161 | "uint": uint64(0xFFFFFFFF), 162 | "slice": []interface{}{int64(1), nil}, 163 | "nil": nil, 164 | } 165 | // We do the range so the errors are more precise. 
166 | for k, v := range got { 167 | if !reflect.DeepEqual(v, want[k]) { 168 | t.Errorf("got \n%+v, want \n%+v", v, want[k]) 169 | } 170 | } 171 | } 172 | 173 | // test that we are calling the right encoding method 174 | // if we use the reflection code, this will fail as reflection 175 | // cannot access the non-exported field 176 | type PrivateStruct struct { 177 | veryPrivate uint64 178 | } 179 | 180 | func (ps *PrivateStruct) MarshalBson(buf *bytes2.ChunkedWriter, key string) { 181 | EncodeOptionalPrefix(buf, Object, key) 182 | lenWriter := NewLenWriter(buf) 183 | 184 | EncodeUint64(buf, "Type", ps.veryPrivate) 185 | 186 | lenWriter.Close() 187 | } 188 | 189 | func (ps *PrivateStruct) UnmarshalBson(buf *bytes.Buffer, kind byte) { 190 | VerifyObject(kind) 191 | Next(buf, 4) 192 | 193 | for kind := NextByte(buf); kind != EOO; kind = NextByte(buf) { 194 | key := ReadCString(buf) 195 | switch key { 196 | case "Type": 197 | verifyKind("Type", Ulong, kind) 198 | ps.veryPrivate = DecodeUint64(buf, kind) 199 | default: 200 | Skip(buf, kind) 201 | } 202 | } 203 | } 204 | 205 | // an array can use non-pointers for custom marshaler 206 | type PrivateStructList struct { 207 | List []PrivateStruct 208 | } 209 | 210 | // the map has to be using pointers, so the custom marshaler is used 211 | type PrivateStructMap struct { 212 | Map map[string]*PrivateStruct 213 | } 214 | 215 | type PrivateStructStruct struct { 216 | Inner *PrivateStruct 217 | } 218 | 219 | func TestCustomStruct(t *testing.T) { 220 | // This should use the custom marshaler & unmarshaler 221 | s := PrivateStruct{1} 222 | got := verifyMarshal(t, &s) 223 | want := "\x13\x00\x00\x00?Type\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00" 224 | if string(got) != want { 225 | t.Errorf("got %q, want %q", string(got), want) 226 | } 227 | var s2 PrivateStruct 228 | verifyUnmarshal(t, got, &s2) 229 | if s2 != s { 230 | t.Errorf("got \n%+v, want \n%+v", s2, s) 231 | } 232 | 233 | // This should use the custom marshaler & 
unmarshaler 234 | sl := PrivateStructList{make([]PrivateStruct, 1)} 235 | sl.List[0] = s 236 | got = verifyMarshal(t, &sl) 237 | want = "&\x00\x00\x00\x04List\x00\x1b\x00\x00\x00\x030\x00\x13\x00\x00\x00?Type\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" 238 | if string(got) != want { 239 | t.Errorf("got %q, want %q", string(got), want) 240 | } 241 | var sl2 PrivateStructList 242 | verifyUnmarshal(t, got, &sl2) 243 | if !reflect.DeepEqual(sl2, sl) { 244 | t.Errorf("got \n%+v, want \n%+v", sl2, sl) 245 | } 246 | 247 | // This should use the custom marshaler & unmarshaler 248 | smp := make(map[string]*PrivateStruct) 249 | smp["first"] = &s 250 | got = verifyMarshal(t, smp) 251 | want = "\x1f\x00\x00\x00\x03first\x00\x13\x00\x00\x00?Type\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00" 252 | if string(got) != want { 253 | t.Errorf("got %q, want %q", string(got), want) 254 | } 255 | smp2 := make(map[string]*PrivateStruct) 256 | verifyUnmarshal(t, got, &smp2) 257 | if !reflect.DeepEqual(smp2, smp) { 258 | t.Errorf("got \n%+v, want \n%+v", smp2, smp) 259 | } 260 | 261 | // This should not use the custom unmarshaler 262 | sm := make(map[string]PrivateStruct) 263 | sm["first"] = s 264 | sm2 := make(map[string]PrivateStruct) 265 | verifyUnmarshal(t, got, &sm2) 266 | if reflect.DeepEqual(sm2, sm) { 267 | t.Errorf("got \n%+v, want \n%+v", sm2, sm) 268 | } 269 | 270 | // This should not use the custom marshaler 271 | got = verifyMarshal(t, sm) 272 | want = "\x11\x00\x00\x00\x03first\x00\x05\x00\x00\x00\x00\x00" 273 | if string(got) != want { 274 | t.Errorf("got %q, want %q", string(got), want) 275 | } 276 | 277 | // This should not use the custom marshaler (or crash) 278 | nilinner := PrivateStructStruct{} 279 | got = verifyMarshal(t, &nilinner) 280 | want = "\f\x00\x00\x00\nInner\x00\x00" 281 | if string(got) != want { 282 | t.Errorf("got %q, want %q", string(got), want) 283 | } 284 | } 285 | 286 | type HasPrivate struct { 287 | private string 288 | Public string 289 | } 290 | 
291 | func TestIgnorePrivateFields(t *testing.T) { 292 | v := HasPrivate{private: "private", Public: "public"} 293 | marshaled := verifyMarshal(t, v) 294 | unmarshaled := new(HasPrivate) 295 | Unmarshal(marshaled, unmarshaled) 296 | if unmarshaled.Public != "Public" && unmarshaled.private != "" { 297 | t.Errorf("private fields were not ignored: %+v", unmarshaled) 298 | } 299 | } 300 | 301 | type LotsMoreFields struct { 302 | CommonField1 string 303 | ExtraField1 float64 304 | ExtraField2 string 305 | ExtraField3 HasPrivate 306 | ExtraField4 []string 307 | CommonField2 string 308 | ExtraField5 []byte 309 | ExtraField6 bool 310 | ExtraField7 time.Time 311 | ExtraField8 *int 312 | ExtraField9 int32 313 | ExtraField10 int64 314 | ExtraField11 uint64 315 | } 316 | 317 | type LotsFewerFields struct { 318 | CommonField1 string 319 | CommonField2 string 320 | } 321 | 322 | func TestSkipUnknownFields(t *testing.T) { 323 | v := LotsMoreFields{ 324 | CommonField1: "value1", 325 | ExtraField1: 1.0, 326 | ExtraField2: "abcd", 327 | ExtraField3: HasPrivate{private: "private", Public: "public"}, 328 | ExtraField4: []string{"s1", "s2"}, 329 | CommonField2: "value3", 330 | ExtraField5: []byte("abcd"), 331 | ExtraField6: true, 332 | ExtraField7: time.Now(), 333 | } 334 | marshaled := verifyMarshal(t, v) 335 | unmarshaled := LotsFewerFields{} 336 | verifyUnmarshal(t, marshaled, &unmarshaled) 337 | want := LotsFewerFields{ 338 | CommonField1: "value1", 339 | CommonField2: "value3", 340 | } 341 | if unmarshaled != want { 342 | t.Errorf("got \n%+v, want \n%+v", unmarshaled, want) 343 | } 344 | } 345 | 346 | func TestEncodeFieldNil(t *testing.T) { 347 | buf := bytes2.NewChunkedWriter(DefaultBufferSize) 348 | EncodeField(buf, "Val", nil) 349 | got := string(buf.Bytes()) 350 | want := "\nVal\x00" 351 | if got != want { 352 | t.Errorf("nil encode: got %q, want %q", got, want) 353 | } 354 | } 355 | 356 | func TestStream(t *testing.T) { 357 | buf := bytes.NewBuffer(nil) 358 | err := 
MarshalToStream(buf, 1) 359 | if err != nil { 360 | t.Error(err) 361 | return 362 | } 363 | want := "\x14\x00\x00\x00\x12_Val_\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00" 364 | got := buf.String() 365 | if got != want { 366 | t.Errorf("got \n%q, want %q", got, want) 367 | } 368 | readbuf := bytes.NewBuffer(buf.Bytes()) 369 | var out int64 370 | err = UnmarshalFromStream(readbuf, &out) 371 | if err != nil { 372 | t.Error(err) 373 | return 374 | } 375 | if out != 1 { 376 | t.Errorf("got %d, want 1", out) 377 | } 378 | err = MarshalToStream(buf, make(chan int)) 379 | want = "unexpected type chan int" 380 | got = err.Error() 381 | if got != want { 382 | t.Errorf("got \n%q, want %q", got, want) 383 | } 384 | } 385 | 386 | var testMap map[string]interface{} 387 | var testBlob []byte 388 | 389 | func init() { 390 | testMap = map[string]interface{}{ 391 | "bytes": []byte("bytes"), 392 | "float64": float64(64), 393 | "string": "string", 394 | "bool": true, 395 | "time": time.Unix(1136243045, 0), 396 | "int64": int64(-0x8000000000000000), 397 | "int32": int32(-0x80000000), 398 | "int": int(-0x80000000), 399 | "uint64": uint64(0xFFFFFFFFFFFFFFFF), 400 | "uint32": uint32(0xFFFFFFFF), 401 | "uint": uint(0xFFFFFFFF), 402 | "slice": []interface{}{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 15, 16, nil}, 403 | "nil": nil, 404 | } 405 | testBlob, _ = Marshal(testMap) 406 | } 407 | 408 | func BenchmarkMarshal(b *testing.B) { 409 | for i := 0; i < b.N; i++ { 410 | _, err := Marshal(testMap) 411 | if err != nil { 412 | b.Fatal(err) 413 | } 414 | } 415 | } 416 | 417 | func BenchmarkUnmarshal(b *testing.B) { 418 | for i := 0; i < b.N; i++ { 419 | v := make(map[string]interface{}) 420 | err := Unmarshal(testBlob, &v) 421 | if err != nil { 422 | b.Fatal(err) 423 | } 424 | } 425 | } 426 | 427 | func BenchmarkEncodeField(b *testing.B) { 428 | values := []interface{}{1, 2, 3, 4, 5, 6, 7, 8, 9, 10} 429 | for i := 0; i < b.N; i++ { 430 | buf := bytes2.NewChunkedWriter(2048) 431 | EncodeField(buf, 
"Val", values) 432 | buf.Reset() 433 | } 434 | } 435 | 436 | func BenchmarkEncodeInterface(b *testing.B) { 437 | values := []interface{}{1, 2, 3, 4, 5, 6, 7, 8, 9, 10} 438 | for i := 0; i < b.N; i++ { 439 | buf := bytes2.NewChunkedWriter(2048) 440 | EncodeInterface(buf, "Val", values) 441 | buf.Reset() 442 | } 443 | } 444 | 445 | func verifyMarshal(t *testing.T, val interface{}) []byte { 446 | got, err := Marshal(val) 447 | if err != nil { 448 | t.Errorf("Marshal error for %+v: %v\n", val, err) 449 | } 450 | return got 451 | } 452 | 453 | func verifyUnmarshal(t *testing.T, buf []byte, val interface{}) { 454 | if err := Unmarshal(buf, val); err != nil { 455 | t.Errorf("Unmarshal error for %+v: %v\n", val, err) 456 | } 457 | } 458 | -------------------------------------------------------------------------------- /parser/dependency/bson/common.go: -------------------------------------------------------------------------------- 1 | // Copyright 2012, Google Inc. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | 5 | // Package bson implements encoding and decoding of BSON objects. 6 | package bson 7 | 8 | import ( 9 | "encoding/binary" 10 | "fmt" 11 | "reflect" 12 | "time" 13 | ) 14 | 15 | // Pack is the BSON binary packing protocol. 16 | // It's little endian. 17 | var Pack = binary.LittleEndian 18 | 19 | var ( 20 | timeType = reflect.TypeOf(time.Time{}) 21 | bytesType = reflect.TypeOf([]byte(nil)) 22 | ) 23 | 24 | // Words size in bytes. 
25 | const ( 26 | WORD32 = 4 27 | WORD64 = 8 28 | ) 29 | 30 | const ( 31 | EOO = 0x00 32 | Number = 0x01 33 | String = 0x02 34 | Object = 0x03 35 | Array = 0x04 36 | Binary = 0x05 37 | Undefined = 0x06 // deprecated 38 | OID = 0x07 // unsupported 39 | Boolean = 0x08 40 | Datetime = 0x09 41 | Null = 0x0A 42 | Regex = 0x0B // unsupported 43 | Ref = 0x0C // deprecated 44 | Code = 0x0D // unsupported 45 | Symbol = 0x0E // unsupported 46 | CodeWithScope = 0x0F // unsupported 47 | Int = 0x10 48 | Timestamp = 0x11 // unsupported 49 | Long = 0x12 50 | Ulong = 0x3F // nonstandard extension 51 | MinKey = 0xFF // unsupported 52 | MaxKey = 0x7F // unsupported 53 | ) 54 | 55 | const ( 56 | // MAGICTAG is the tag used to embed simple types inside 57 | // a bson document. 58 | MAGICTAG = "_Val_" 59 | ) 60 | 61 | type BsonError struct { 62 | Message string 63 | } 64 | 65 | func NewBsonError(format string, args ...interface{}) BsonError { 66 | return BsonError{fmt.Sprintf(format, args...)} 67 | } 68 | 69 | func (err BsonError) Error() string { 70 | return err.Message 71 | } 72 | 73 | func handleError(err *error) { 74 | if x := recover(); x != nil { 75 | *err = x.(BsonError) 76 | } 77 | } 78 | -------------------------------------------------------------------------------- /parser/dependency/bson/custom_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2012, Google Inc. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 
4 | 5 | package bson 6 | 7 | import ( 8 | "bytes" 9 | "fmt" 10 | "reflect" 11 | "testing" 12 | "time" 13 | 14 | "github.com/wgliang/pgproxy/parser/dependency/bytes2" 15 | ) 16 | 17 | const ( 18 | bsonValNil = "\nVal\x00" 19 | bsonValBytes = "\x05Val\x00\x04\x00\x00\x00\x00test" 20 | bsonValInt64 = "\x12Val\x00\x01\x00\x00\x00\x00\x00\x00\x00" 21 | bsonValInt32 = "\x10Val\x00\x01\x00\x00\x00" 22 | bsonValUint64 = "?Val\x00\x01\x00\x00\x00\x00\x00\x00\x00" 23 | bsonValFloat64 = "\x01Val\x00\x00\x00\x00\x00\x00\x00\xf0?" 24 | bsonValBool = "\bVal\x00\x01" 25 | bsonValMap = "\x03Val\x00\x13\x00\x00\x00\x12Val1\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00" 26 | bsonValSlice = "\x04Val\x00\x10\x00\x00\x00\x120\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00" 27 | bsonValTime = "\tVal\x00\x88\xf2\\\x8d\b\x01\x00\x00" 28 | ) 29 | 30 | var interfaceMarshalCases = []struct { 31 | desc string 32 | in interface{} 33 | out string 34 | }{ 35 | {"nil", nil, bsonValNil}, 36 | {"string", "test", bsonValBytes}, 37 | {"[]byte", []byte("test"), bsonValBytes}, 38 | {"int64", int64(1), bsonValInt64}, 39 | {"int32", int32(1), bsonValInt32}, 40 | {"int", int(1), bsonValInt64}, 41 | {"uint64", uint64(1), bsonValUint64}, 42 | {"uint32", uint32(1), bsonValUint64}, 43 | {"uint", uint(1), bsonValUint64}, 44 | {"float64", float64(1.0), bsonValFloat64}, 45 | {"bool", true, bsonValBool}, 46 | {"nil map", map[string]interface{}(nil), bsonValNil}, 47 | {"map", map[string]interface{}{"Val1": 1}, bsonValMap}, 48 | {"nil slice", []interface{}(nil), bsonValNil}, 49 | {"slice", []interface{}{1}, bsonValSlice}, 50 | {"time", time.Unix(1136243045, 0).UTC(), bsonValTime}, 51 | } 52 | 53 | func TestInterfaceMarshal(t *testing.T) { 54 | for _, tcase := range interfaceMarshalCases { 55 | buf := bytes2.NewChunkedWriter(DefaultBufferSize) 56 | EncodeInterface(buf, "Val", tcase.in) 57 | got := string(buf.Bytes()) 58 | if got != tcase.out { 59 | t.Errorf("%s: got \n%q, want \n%q", tcase.desc, got, tcase.out) 60 | } 61 | } 62 
| } 63 | 64 | func TestInterfaceMarshalFailure(t *testing.T) { 65 | want := "don't know how to marshal chan int" 66 | func() { 67 | defer func() { 68 | if x := recover(); x != nil { 69 | got := x.(BsonError).Error() 70 | if got != want { 71 | t.Errorf("got %s, want %s", got, want) 72 | } 73 | return 74 | } 75 | }() 76 | buf := bytes2.NewChunkedWriter(DefaultBufferSize) 77 | EncodeInterface(buf, "Val", make(chan int)) 78 | t.Errorf("got no error, want %s", want) 79 | }() 80 | } 81 | 82 | const ( 83 | bsonString = "\x05\x00\x00\x00test\x00" 84 | bsonBinary = "\x04\x00\x00\x00\x00test" 85 | bsonInt = "\x01\x00\x00\x00" 86 | bsonLong = "\x01\x00\x00\x00\x00\x00\x00\x00" 87 | bsonNumber = "\x00\x00\x00\x00\x00\x00\xf0?" 88 | bsonDatetime = "\x88\xf2\\\x8d\b\x01\x00\x00" 89 | bsonBoolean = "\x01" 90 | bsonObject = "\x14\x00\x00\x00\x05Val2\x00\x04\x00\x00\x00\x00test\x00" 91 | bsonObjectNull = "\v\x00\x00\x00\nVal2\x00\x00" 92 | bsonArray = "\x11\x00\x00\x00\x050\x00\x04\x00\x00\x00\x00test\x00" 93 | bsonArrayNull = "\x13\x00\x00\x00\n0\x00\x121\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00" 94 | bsonStringArray = "\x1f\x00\x00\x00\x050\x00\x05\x00\x00\x00\x00test1\x051\x00\x05\x00\x00\x00\x00test2\x00" 95 | ) 96 | 97 | func stringDecoder(buf *bytes.Buffer, kind byte) interface{} { return DecodeString(buf, kind) } 98 | func binaryDecoder(buf *bytes.Buffer, kind byte) interface{} { return DecodeBinary(buf, kind) } 99 | func int64Decoder(buf *bytes.Buffer, kind byte) interface{} { return DecodeInt64(buf, kind) } 100 | func int32Decoder(buf *bytes.Buffer, kind byte) interface{} { return DecodeInt32(buf, kind) } 101 | func intDecoder(buf *bytes.Buffer, kind byte) interface{} { return DecodeInt(buf, kind) } 102 | func uint64Decoder(buf *bytes.Buffer, kind byte) interface{} { return DecodeUint64(buf, kind) } 103 | func uint32Decoder(buf *bytes.Buffer, kind byte) interface{} { return DecodeUint32(buf, kind) } 104 | func uintDecoder(buf *bytes.Buffer, kind byte) interface{} { return 
DecodeUint(buf, kind) } 105 | func float64Decoder(buf *bytes.Buffer, kind byte) interface{} { return DecodeFloat64(buf, kind) } 106 | func boolDecoder(buf *bytes.Buffer, kind byte) interface{} { return DecodeBool(buf, kind) } 107 | func timeDecoder(buf *bytes.Buffer, kind byte) interface{} { return DecodeTime(buf, kind) } 108 | func interfaceDecoder(buf *bytes.Buffer, kind byte) interface{} { return DecodeInterface(buf, kind) } 109 | func mapDecoder(buf *bytes.Buffer, kind byte) interface{} { return DecodeMap(buf, kind) } 110 | func arrayDecoder(buf *bytes.Buffer, kind byte) interface{} { return DecodeArray(buf, kind) } 111 | func skipDecoder(buf *bytes.Buffer, kind byte) interface{} { Skip(buf, kind); return nil } 112 | func stringArrayDecoder(buf *bytes.Buffer, kind byte) interface{} { return DecodeStringArray(buf, kind) } 113 | 114 | var customUnmarshalCases = []struct { 115 | desc string 116 | in string 117 | kind byte 118 | decoder func(buf *bytes.Buffer, kind byte) interface{} 119 | out interface{} 120 | }{ 121 | {"String->string", bsonString, String, stringDecoder, "test"}, 122 | {"Binary->string", bsonBinary, Binary, stringDecoder, "test"}, 123 | {"Null->string", "", Null, stringDecoder, ""}, 124 | {"String->bytes", bsonString, String, binaryDecoder, []byte("test")}, 125 | {"Binary->bytes", bsonBinary, Binary, binaryDecoder, []byte("test")}, 126 | {"Null->bytes", "", Null, binaryDecoder, []byte(nil)}, 127 | {"Int->int64", bsonInt, Int, int64Decoder, int64(1)}, 128 | {"Long->int64", bsonLong, Long, int64Decoder, int64(1)}, 129 | {"Ulong->int64", bsonLong, Ulong, int64Decoder, int64(1)}, 130 | {"Null->int64", "", Null, int64Decoder, int64(0)}, 131 | {"Int->int32", bsonInt, Int, int32Decoder, int32(1)}, 132 | {"Null->int32", "", Null, int32Decoder, int32(0)}, 133 | {"Int->int", bsonInt, Int, intDecoder, int(1)}, 134 | {"Long->int", bsonLong, Long, intDecoder, int(1)}, 135 | {"Ulong->int", bsonLong, Ulong, intDecoder, int(1)}, 136 | {"Null->int", "", Null, 
intDecoder, int(0)}, 137 | {"Int->uint64", bsonInt, Int, uint64Decoder, uint64(1)}, 138 | {"Long->uint64", bsonLong, Long, uint64Decoder, uint64(1)}, 139 | {"Ulong->uint64", bsonLong, Ulong, uint64Decoder, uint64(1)}, 140 | {"Null->uint64", "", Null, uint64Decoder, uint64(0)}, 141 | {"Int->uint32", bsonInt, Int, uint32Decoder, uint32(1)}, 142 | {"Ulong->uint32", bsonLong, Ulong, uint32Decoder, uint32(1)}, 143 | {"Null->uint32", "", Null, uint32Decoder, uint32(0)}, 144 | {"Int->uint", bsonInt, Int, uintDecoder, uint(1)}, 145 | {"Long->uint", bsonLong, Long, uintDecoder, uint(1)}, 146 | {"Ulong->uint", bsonLong, Ulong, uintDecoder, uint(1)}, 147 | {"Null->uint", "", Null, uintDecoder, uint(0)}, 148 | {"Number->float64", bsonNumber, Number, float64Decoder, float64(1.0)}, 149 | {"Null->float64", "", Null, float64Decoder, float64(0.0)}, 150 | {"Boolean->bool", bsonBoolean, Boolean, boolDecoder, true}, 151 | {"Null->bool", "", Null, boolDecoder, false}, 152 | {"Datetime->time.Time", bsonDatetime, Datetime, timeDecoder, time.Unix(1136243045, 0).UTC()}, 153 | {"Null->time.Time", "", Null, timeDecoder, time.Time{}}, 154 | {"Number->interface{}", bsonNumber, Number, interfaceDecoder, float64(1.0)}, 155 | {"String->interface{}", bsonString, String, interfaceDecoder, "test"}, 156 | {"Object->interface{}", bsonObject, Object, interfaceDecoder, map[string]interface{}{"Val2": []byte("test")}}, 157 | {"Object->interface{} with null element", bsonObjectNull, Object, interfaceDecoder, map[string]interface{}{"Val2": nil}}, 158 | {"Array->interface{}", bsonArray, Array, interfaceDecoder, []interface{}{[]byte("test")}}, 159 | {"Array->interface{} with null element", bsonArrayNull, Array, interfaceDecoder, []interface{}{nil, int64(1)}}, 160 | {"Binary->interface{}", bsonBinary, Binary, interfaceDecoder, []byte("test")}, 161 | {"Boolean->interface{}", bsonBoolean, Boolean, interfaceDecoder, true}, 162 | {"Datetime->interface{}", bsonDatetime, Datetime, interfaceDecoder, 
time.Unix(1136243045, 0).UTC()}, 163 | {"Int->interface{}", bsonInt, Int, interfaceDecoder, int32(1)}, 164 | {"Long->interface{}", bsonLong, Long, interfaceDecoder, int64(1)}, 165 | {"Ulong->interface{}", bsonLong, Ulong, interfaceDecoder, uint64(1)}, 166 | {"Null->interface{}", "", Null, interfaceDecoder, nil}, 167 | {"Null->map[string]interface{}", "", Null, mapDecoder, map[string]interface{}(nil)}, 168 | {"Null->[]interface{}", "", Null, arrayDecoder, []interface{}(nil)}, 169 | {"Number->Skip", bsonNumber, Number, skipDecoder, nil}, 170 | {"String->Skip", bsonString, String, skipDecoder, nil}, 171 | {"Object->Skip", bsonObject, Object, skipDecoder, nil}, 172 | {"Object->Skip with null element", bsonObjectNull, Object, skipDecoder, nil}, 173 | {"Array->Skip", bsonArray, Array, skipDecoder, nil}, 174 | {"Array->Skip with null element", bsonArrayNull, Array, skipDecoder, nil}, 175 | {"Binary->Skip", bsonBinary, Binary, skipDecoder, nil}, 176 | {"Boolean->Skip", bsonBoolean, Boolean, skipDecoder, nil}, 177 | {"Datetime->Skip", bsonDatetime, Datetime, skipDecoder, nil}, 178 | {"Int->Skip", bsonInt, Int, skipDecoder, nil}, 179 | {"Long->Skip", bsonLong, Long, skipDecoder, nil}, 180 | {"Ulong->Skip", bsonLong, Ulong, skipDecoder, nil}, 181 | {"Null->Skip", "", Null, skipDecoder, nil}, 182 | {"Null->map[string]interface{}", "", Null, mapDecoder, map[string]interface{}(nil)}, 183 | {"Null->[]interface{}", "", Null, arrayDecoder, []interface{}(nil)}, 184 | {"Array->[]string", bsonStringArray, Array, stringArrayDecoder, []string{"test1", "test2"}}, 185 | {"Null->[]string", "", Null, stringArrayDecoder, []string(nil)}, 186 | } 187 | 188 | func TestCustomUnmarshal(t *testing.T) { 189 | for _, tcase := range customUnmarshalCases { 190 | buf := bytes.NewBuffer([]byte(tcase.in)) 191 | got := tcase.decoder(buf, tcase.kind) 192 | if !reflect.DeepEqual(got, tcase.out) { 193 | t.Errorf("%s: received: %v, want %v", tcase.desc, got, tcase.out) 194 | } 195 | if buf.Len() != 0 { 196 | 
t.Errorf("%s: %d unread bytes from %q, want 0", tcase.desc, buf.Len(), tcase.in) 197 | } 198 | } 199 | } 200 | 201 | var customUnmarshalFailureCases = []struct { 202 | typ string 203 | decoder func(buf *bytes.Buffer, kind byte) interface{} 204 | valid []byte 205 | }{ 206 | {"string", stringDecoder, []byte{String, Binary, Null}}, 207 | {"[]byte", binaryDecoder, []byte{String, Binary, Null}}, 208 | {"int64", int64Decoder, []byte{Int, Long, Ulong, Null}}, 209 | {"int32", int32Decoder, []byte{Int, Null}}, 210 | {"int", intDecoder, []byte{Int, Long, Ulong, Null}}, 211 | {"uint64", uint64Decoder, []byte{Int, Long, Ulong, Null}}, 212 | {"uint32", uint32Decoder, []byte{Int, Ulong, Null}}, 213 | {"uint", uintDecoder, []byte{Int, Long, Ulong, Null}}, 214 | {"float64", float64Decoder, []byte{Number, Null}}, 215 | {"bool", boolDecoder, []byte{Boolean, Int, Long, Ulong, Null}}, 216 | {"time.Time", timeDecoder, []byte{Datetime, Null}}, 217 | {"interface{}", interfaceDecoder, []byte{Number, String, Object, Array, Binary, Boolean, Datetime, Null, Int, Long, Ulong}}, 218 | {"map", mapDecoder, []byte{Object, Null}}, 219 | {"slice", arrayDecoder, []byte{Array, Null}}, 220 | {"[]string", stringArrayDecoder, []byte{Array, Null}}, 221 | {"skip", skipDecoder, []byte{Number, String, Object, Array, Binary, Boolean, Datetime, Null, Int, Long, Ulong}}, 222 | } 223 | 224 | func TestCustomUnmarshalFailures(t *testing.T) { 225 | allKinds := []byte{EOO, Number, String, Object, Array, Binary, Boolean, Datetime, Null, Int, Long, Ulong} 226 | for _, tcase := range customUnmarshalFailureCases { 227 | for _, kind := range allKinds { 228 | want := fmt.Sprintf("unexpected kind %v for %s", kind, tcase.typ) 229 | func() { 230 | defer func() { 231 | if x := recover(); x != nil { 232 | got := x.(BsonError).Error() 233 | if got != want { 234 | t.Errorf("got %s, want %s", got, want) 235 | } 236 | return 237 | } 238 | }() 239 | for _, valid := range tcase.valid { 240 | if kind == valid { 241 | return 242 | } 
243 | } 244 | tcase.decoder(nil, kind) 245 | t.Errorf("got no error, want %s", want) 246 | }() 247 | } 248 | } 249 | } 250 | -------------------------------------------------------------------------------- /parser/dependency/bson/marshal.go: -------------------------------------------------------------------------------- 1 | // Copyright 2012, Google Inc. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | 5 | package bson 6 | 7 | import ( 8 | "io" 9 | "math" 10 | "reflect" 11 | "strconv" 12 | "time" 13 | 14 | "github.com/wgliang/pgproxy/parser/dependency/bytes2" 15 | ) 16 | 17 | // LenWriter records the current write position on the buffer 18 | // and can later be used to record the number of bytes written 19 | // in conformance to BSON spec 20 | type LenWriter struct { 21 | buf *bytes2.ChunkedWriter 22 | off int 23 | b []byte 24 | } 25 | 26 | // NewLenWriter returns a LenWriter that reserves the 27 | // bytes buf so they can store the length later. 28 | func NewLenWriter(buf *bytes2.ChunkedWriter) LenWriter { 29 | off := buf.Len() 30 | b := buf.Reserve(WORD32) 31 | return LenWriter{buf, off, b} 32 | } 33 | 34 | // Close closes the current object being encoded by 35 | // writing bson's EOO byte and recording the length. 36 | func (lw LenWriter) Close() { 37 | lw.buf.WriteByte(EOO) 38 | Pack.PutUint32(lw.b, uint32(lw.buf.Len()-lw.off)) 39 | } 40 | 41 | // Marshaler is the interface that needs to be 42 | // satisfied by types that want to implement a custom 43 | // marshaler. 44 | // When being invoked as a top level object, key will 45 | // be "". In such cases, MarshalBson must not encode 46 | // any prefix. 47 | type Marshaler interface { 48 | MarshalBson(buf *bytes2.ChunkedWriter, key string) 49 | } 50 | 51 | func canMarshal(val reflect.Value) Marshaler { 52 | // Check the Marshaler interface on T. 
53 | if marshaler, ok := val.Interface().(Marshaler); ok { 54 | // Don't call custom marshaler for nil values. 55 | switch val.Kind() { 56 | case reflect.Ptr, reflect.Interface, reflect.Map, reflect.Slice: 57 | if val.IsNil() { 58 | return nil 59 | } 60 | } 61 | return marshaler 62 | } 63 | // Check the Marshaler interface on *T. 64 | if val.CanAddr() { 65 | if marshaler, ok := val.Addr().Interface().(Marshaler); ok { 66 | return marshaler 67 | } 68 | } 69 | return nil 70 | } 71 | 72 | // DefaultBufferSize is the default allocation size for ChunkedWriter. 73 | const DefaultBufferSize = 1024 74 | 75 | // MarshalToStream marshals val into writer. 76 | func MarshalToStream(writer io.Writer, val interface{}) (err error) { 77 | buf := bytes2.NewChunkedWriter(DefaultBufferSize) 78 | if err = MarshalToBuffer(buf, val); err != nil { 79 | return err 80 | } 81 | _, err = buf.WriteTo(writer) 82 | return err 83 | } 84 | 85 | // Marshal marshals val into encoded. 86 | func Marshal(val interface{}) (encoded []byte, err error) { 87 | buf := bytes2.NewChunkedWriter(DefaultBufferSize) 88 | err = MarshalToBuffer(buf, val) 89 | return buf.Bytes(), err 90 | } 91 | 92 | // MarshalToBuffer marshals val into buf. This is the most efficient 93 | // function to use, especially when marshaling large nested objects. 
94 | func MarshalToBuffer(buf *bytes2.ChunkedWriter, val interface{}) (err error) { 95 | defer handleError(&err) 96 | if val == nil { 97 | return NewBsonError("cannot marshal nil") 98 | } 99 | 100 | v := reflect.Indirect(reflect.ValueOf(val)) 101 | if marshaler := canMarshal(v); marshaler != nil { 102 | marshaler.MarshalBson(buf, "") 103 | return 104 | } 105 | 106 | switch v.Kind() { 107 | case reflect.String, 108 | reflect.Int64, reflect.Int32, reflect.Int, 109 | reflect.Uint64, reflect.Uint32, reflect.Uint, 110 | reflect.Float64, reflect.Bool: 111 | EncodeSimple(buf, v.Interface()) 112 | case reflect.Struct: 113 | if v.Type() == timeType { 114 | EncodeSimple(buf, v.Interface()) 115 | } else { 116 | encodeStructContent(buf, v) 117 | } 118 | case reflect.Map: 119 | encodeMapContent(buf, v) 120 | case reflect.Slice, reflect.Array: 121 | if v.Type() == bytesType { 122 | EncodeSimple(buf, v.Interface()) 123 | } else { 124 | encodeSliceContent(buf, v) 125 | } 126 | default: 127 | return NewBsonError("unexpected type %v", v.Type()) 128 | } 129 | return nil 130 | } 131 | 132 | // EncodeSimple marshals simple objects that cannot be 133 | // encoded as a top level bson document. 134 | func EncodeSimple(buf *bytes2.ChunkedWriter, val interface{}) { 135 | lenWriter := NewLenWriter(buf) 136 | EncodeField(buf, MAGICTAG, val) 137 | lenWriter.Close() 138 | } 139 | 140 | // EncodeField encodes val using the supplied key as embedded tag. 141 | // Unlike EncodeInterface, EncodeField can handle complex objects 142 | // like structs, pointers, etc. But it is slower. 
143 | func EncodeField(buf *bytes2.ChunkedWriter, key string, val interface{}) { 144 | encodeField(buf, key, reflect.ValueOf(val)) 145 | } 146 | 147 | func encodeField(buf *bytes2.ChunkedWriter, key string, val reflect.Value) { 148 | // nil interfaces show up as invalid 149 | if !val.IsValid() { 150 | EncodePrefix(buf, Null, key) 151 | return 152 | } 153 | if marshaler := canMarshal(val); marshaler != nil { 154 | marshaler.MarshalBson(buf, key) 155 | return 156 | } 157 | 158 | switch val.Kind() { 159 | case reflect.String: 160 | EncodeString(buf, key, val.String()) 161 | case reflect.Int64: 162 | EncodeInt64(buf, key, val.Int()) 163 | case reflect.Int32: 164 | EncodeInt32(buf, key, int32(val.Int())) 165 | case reflect.Int: 166 | EncodeInt(buf, key, int(val.Int())) 167 | case reflect.Uint64: 168 | EncodeUint64(buf, key, uint64(val.Uint())) 169 | case reflect.Uint32: 170 | EncodeUint32(buf, key, uint32(val.Uint())) 171 | case reflect.Uint: 172 | EncodeUint(buf, key, uint(val.Uint())) 173 | case reflect.Float64: 174 | EncodeFloat64(buf, key, val.Float()) 175 | case reflect.Bool: 176 | EncodeBool(buf, key, val.Bool()) 177 | case reflect.Struct: 178 | if val.Type() == timeType { 179 | EncodeTime(buf, key, val.Interface().(time.Time)) 180 | } else { 181 | encodeStruct(buf, key, val) 182 | } 183 | case reflect.Map: 184 | encodeMap(buf, key, val) 185 | case reflect.Slice: 186 | if val.Type() == bytesType { 187 | EncodeBinary(buf, key, val.Interface().([]byte)) 188 | } else { 189 | encodeSlice(buf, key, val) 190 | } 191 | case reflect.Ptr, reflect.Interface: 192 | if val.IsNil() { 193 | EncodePrefix(buf, Null, key) 194 | } else { 195 | encodeField(buf, key, val.Elem()) 196 | } 197 | default: 198 | panic(NewBsonError("don't know how to marshal %v", val.Type())) 199 | } 200 | } 201 | 202 | // EncodeOptionalPrefix encodes the key as prefix if it's not empty. 203 | // If it is empty, then it's a no-op, with the assumption that 204 | // it's a top level object. 
205 | func EncodeOptionalPrefix(buf *bytes2.ChunkedWriter, etype byte, key string) { 206 | if key == "" { 207 | return 208 | } 209 | EncodePrefix(buf, etype, key) 210 | } 211 | 212 | // EncodePrefix encodes key as prefix for the next object or value. 213 | func EncodePrefix(buf *bytes2.ChunkedWriter, etype byte, key string) { 214 | b := buf.Reserve(len(key) + 2) 215 | b[0] = etype 216 | copy(b[1:], key) 217 | b[len(b)-1] = 0 218 | } 219 | 220 | // EncodeString encodes a string. 221 | func EncodeString(buf *bytes2.ChunkedWriter, key string, val string) { 222 | // Encode strings as binary; go strings are not necessarily unicode 223 | EncodePrefix(buf, Binary, key) 224 | putUint32(buf, uint32(len(val))) 225 | buf.WriteByte(0) 226 | buf.WriteString(val) 227 | } 228 | 229 | // EncodeBinary encodes a []byte as binary. 230 | func EncodeBinary(buf *bytes2.ChunkedWriter, key string, val []byte) { 231 | EncodePrefix(buf, Binary, key) 232 | putUint32(buf, uint32(len(val))) 233 | buf.WriteByte(0) 234 | buf.Write(val) 235 | } 236 | 237 | // EncodeInt64 encodes an int64. 238 | func EncodeInt64(buf *bytes2.ChunkedWriter, key string, val int64) { 239 | EncodePrefix(buf, Long, key) 240 | putUint64(buf, uint64(val)) 241 | } 242 | 243 | // EncodeInt32 encodes an int32. 244 | func EncodeInt32(buf *bytes2.ChunkedWriter, key string, val int32) { 245 | EncodePrefix(buf, Int, key) 246 | putUint32(buf, uint32(val)) 247 | } 248 | 249 | // EncodeInt encodes an int. 250 | func EncodeInt(buf *bytes2.ChunkedWriter, key string, val int) { 251 | EncodeInt64(buf, key, int64(val)) 252 | } 253 | 254 | // EncodeUint64 encodes an uint64. 255 | func EncodeUint64(buf *bytes2.ChunkedWriter, key string, val uint64) { 256 | EncodePrefix(buf, Ulong, key) 257 | putUint64(buf, val) 258 | } 259 | 260 | // EncodeUint32 encodes an uint32. 261 | func EncodeUint32(buf *bytes2.ChunkedWriter, key string, val uint32) { 262 | EncodeUint64(buf, key, uint64(val)) 263 | } 264 | 265 | // EncodeUint encodes an uint. 
266 | func EncodeUint(buf *bytes2.ChunkedWriter, key string, val uint) { 267 | EncodeUint64(buf, key, uint64(val)) 268 | } 269 | 270 | // EncodeFloat64 encodes a float64. 271 | func EncodeFloat64(buf *bytes2.ChunkedWriter, key string, val float64) { 272 | EncodePrefix(buf, Number, key) 273 | bits := math.Float64bits(val) 274 | putUint64(buf, bits) 275 | } 276 | 277 | // EncodeBool encodes a bool. 278 | func EncodeBool(buf *bytes2.ChunkedWriter, key string, val bool) { 279 | EncodePrefix(buf, Boolean, key) 280 | if val { 281 | buf.WriteByte(1) 282 | } else { 283 | buf.WriteByte(0) 284 | } 285 | } 286 | 287 | // EncodeTime encodes a time.Time. 288 | func EncodeTime(buf *bytes2.ChunkedWriter, key string, val time.Time) { 289 | EncodePrefix(buf, Datetime, key) 290 | mtime := val.UnixNano() / 1e6 291 | putUint64(buf, uint64(mtime)) 292 | } 293 | 294 | func encodeStruct(buf *bytes2.ChunkedWriter, key string, val reflect.Value) { 295 | EncodePrefix(buf, Object, key) 296 | encodeStructContent(buf, val) 297 | } 298 | 299 | func encodeStructContent(buf *bytes2.ChunkedWriter, val reflect.Value) { 300 | lenWriter := NewLenWriter(buf) 301 | t := val.Type() 302 | for i := 0; i < t.NumField(); i++ { 303 | key := t.Field(i).Name 304 | 305 | // NOTE(szopa): Ignore private fields (copied from 306 | // encoding/json). Yes, it feels like a hack. 307 | if t.Field(i).PkgPath != "" { 308 | continue 309 | } 310 | encodeField(buf, key, val.Field(i)) 311 | } 312 | lenWriter.Close() 313 | } 314 | 315 | func encodeMap(buf *bytes2.ChunkedWriter, key string, val reflect.Value) { 316 | EncodePrefix(buf, Object, key) 317 | encodeMapContent(buf, val) 318 | } 319 | 320 | // a map seems to lose the 'CanAddr' property. 
So if we want 321 | // to use a custom marshaler with a struct pointer receiver, like: 322 | // func (ps *PrivateStruct) MarshalBson(buf *bytes2.ChunkedWriter, key string) { 323 | // the map has to be using pointers, i.e: 324 | // map[string]*PrivateStruct 325 | // and not: 326 | // map[string]PrivateStruct 327 | // (see unit test) 328 | func encodeMapContent(buf *bytes2.ChunkedWriter, val reflect.Value) { 329 | lenWriter := NewLenWriter(buf) 330 | mt := val.Type() 331 | if mt.Key().Kind() != reflect.String { 332 | panic(NewBsonError("can't marshall maps with non-string key types")) 333 | } 334 | keys := val.MapKeys() 335 | for _, k := range keys { 336 | key := k.String() 337 | encodeField(buf, key, val.MapIndex(k)) 338 | } 339 | lenWriter.Close() 340 | } 341 | 342 | func encodeSlice(buf *bytes2.ChunkedWriter, key string, val reflect.Value) { 343 | EncodePrefix(buf, Array, key) 344 | encodeSliceContent(buf, val) 345 | } 346 | 347 | func encodeSliceContent(buf *bytes2.ChunkedWriter, val reflect.Value) { 348 | lenWriter := NewLenWriter(buf) 349 | for i := 0; i < val.Len(); i++ { 350 | encodeField(buf, Itoa(i), val.Index(i)) 351 | } 352 | lenWriter.Close() 353 | } 354 | 355 | func putUint32(buf *bytes2.ChunkedWriter, val uint32) { 356 | Pack.PutUint32(buf.Reserve(WORD32), val) 357 | } 358 | 359 | func putUint64(buf *bytes2.ChunkedWriter, val uint64) { 360 | Pack.PutUint64(buf.Reserve(WORD64), val) 361 | } 362 | 363 | var intStrMap [intAliasSize + 1]string 364 | 365 | const ( 366 | intAliasSize = 1024 367 | ) 368 | 369 | func init() { 370 | for i := 0; i <= intAliasSize; i++ { 371 | intStrMap[i] = strconv.Itoa(i) 372 | } 373 | } 374 | 375 | func Itoa(i int) string { 376 | if i <= intAliasSize { 377 | return intStrMap[i] 378 | } 379 | return strconv.Itoa(i) 380 | } 381 | -------------------------------------------------------------------------------- /parser/dependency/bson/marshal_test.go: -------------------------------------------------------------------------------- 
1 | // Copyright 2012, Google Inc. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | 5 | package bson 6 | 7 | import ( 8 | "testing" 9 | "time" 10 | 11 | "github.com/wgliang/pgproxy/parser/dependency/bytes2" 12 | ) 13 | 14 | type String1 string 15 | 16 | func (cs String1) MarshalBson(buf *bytes2.ChunkedWriter, key string) { 17 | // Hardcode value to verify that function is called 18 | EncodeString(buf, key, "test") 19 | } 20 | 21 | type String2 string 22 | 23 | func (cs *String2) MarshalBson(buf *bytes2.ChunkedWriter, key string) { 24 | // Hardcode value to verify that function is called 25 | EncodeString(buf, key, "test") 26 | } 27 | 28 | var marshaltest = []struct { 29 | desc string 30 | in interface{} 31 | out string 32 | }{{ 33 | "struct encode", 34 | struct{ Val string }{"test"}, 35 | "\x13\x00\x00\x00\x05Val\x00\x04\x00\x00\x00\x00test\x00", 36 | }, { 37 | "struct encode nil", 38 | struct{ Val *int }{}, 39 | "\n\x00\x00\x00\nVal\x00\x00", 40 | }, { 41 | "struct encode nil interface", 42 | struct{ Val interface{} }{}, 43 | "\n\x00\x00\x00\nVal\x00\x00", 44 | }, { 45 | "map encode", 46 | map[string]string{"Val": "test"}, 47 | "\x13\x00\x00\x00\x05Val\x00\x04\x00\x00\x00\x00test\x00", 48 | }, { 49 | "embedded map encode", 50 | struct{ Inner map[string]string }{map[string]string{"Val": "test"}}, 51 | "\x1f\x00\x00\x00\x03Inner\x00\x13\x00\x00\x00\x05Val\x00\x04\x00\x00\x00\x00test\x00\x00", 52 | }, { 53 | "embedded map encode nil", 54 | struct{ Inner map[string]string }{}, 55 | "\x11\x00\x00\x00\x03Inner\x00\x05\x00\x00\x00\x00\x00", 56 | }, { 57 | "slice encode", 58 | []string{"test1", "test2"}, 59 | "\x1f\x00\x00\x00\x050\x00\x05\x00\x00\x00\x00test1\x051\x00\x05\x00\x00\x00\x00test2\x00", 60 | }, { 61 | "embedded slice encode", 62 | struct{ Inner []string }{[]string{"test1", "test2"}}, 63 | 
"+\x00\x00\x00\x04Inner\x00\x1f\x00\x00\x00\x050\x00\x05\x00\x00\x00\x00test1\x051\x00\x05\x00\x00\x00\x00test2\x00\x00", 64 | }, { 65 | "embedded slice encode nil", 66 | struct{ Inner []string }{}, 67 | "\x11\x00\x00\x00\x04Inner\x00\x05\x00\x00\x00\x00\x00", 68 | }, { 69 | "array encode", 70 | [2]string{"test1", "test2"}, 71 | "\x1f\x00\x00\x00\x050\x00\x05\x00\x00\x00\x00test1\x051\x00\x05\x00\x00\x00\x00test2\x00", 72 | }, { 73 | "string encode", 74 | "test", 75 | "\x15\x00\x00\x00\x05_Val_\x00\x04\x00\x00\x00\x00test\x00", 76 | }, { 77 | "bytes encode", 78 | []byte("test"), 79 | "\x15\x00\x00\x00\x05_Val_\x00\x04\x00\x00\x00\x00test\x00", 80 | }, { 81 | "int64 encode", 82 | int64(1), 83 | "\x14\x00\x00\x00\x12_Val_\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00", 84 | }, { 85 | "int32 encode", 86 | int32(1), 87 | "\x10\x00\x00\x00\x10_Val_\x00\x01\x00\x00\x00\x00", 88 | }, { 89 | "int encode", 90 | int(1), 91 | "\x14\x00\x00\x00\x12_Val_\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00", 92 | }, { 93 | "unit64 encode", 94 | uint64(1), 95 | "\x14\x00\x00\x00?_Val_\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00", 96 | }, { 97 | "uint32 encode", 98 | uint32(1), 99 | "\x14\x00\x00\x00?_Val_\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00", 100 | }, { 101 | "uint encode", 102 | uint(1), 103 | "\x14\x00\x00\x00?_Val_\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00", 104 | }, { 105 | "float encode", 106 | float64(1.0), 107 | "\x14\x00\x00\x00\x01_Val_\x00\x00\x00\x00\x00\x00\x00\xf0?\x00", 108 | }, { 109 | "bool encode", 110 | true, 111 | "\r\x00\x00\x00\b_Val_\x00\x01\x00", 112 | }, { 113 | "time encode", 114 | time.Unix(1136243045, 0).UTC(), 115 | "\x14\x00\x00\x00\t_Val_\x00\x88\xf2\\\x8d\b\x01\x00\x00\x00", 116 | }, { 117 | 118 | // Following encodes are for reference. They're used for 119 | // the decode tests. 
120 | "embedded Object encode", 121 | struct{ Val struct{ Val2 string } }{struct{ Val2 string }{"test"}}, 122 | "\x1e\x00\x00\x00\x03Val\x00\x14\x00\x00\x00\x05Val2\x00\x04\x00\x00\x00\x00test\x00\x00", 123 | }, { 124 | "embedded Object encode nil element", 125 | struct{ Val struct{ Val2 *int64 } }{struct{ Val2 *int64 }{nil}}, 126 | "\x15\x00\x00\x00\x03Val\x00\v\x00\x00\x00\nVal2\x00\x00\x00", 127 | }, { 128 | "embedded Array encode", 129 | struct{ Val []string }{Val: []string{"test"}}, 130 | "\x1b\x00\x00\x00\x04Val\x00\x11\x00\x00\x00\x050\x00\x04\x00\x00\x00\x00test\x00\x00", 131 | }, { 132 | "Array encode nil element", 133 | struct{ Val []*int64 }{Val: []*int64{nil, newint64(1)}}, 134 | "\x1d\x00\x00\x00\x04Val\x00\x13\x00\x00\x00\n0\x00\x121\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00", 135 | }, { 136 | "embedded Number encode", 137 | struct{ Val float64 }{1.0}, 138 | "\x12\x00\x00\x00\x01Val\x00\x00\x00\x00\x00\x00\x00\xf0?\x00", 139 | }, { 140 | "embedded Binary encode", 141 | struct{ Val string }{"test"}, 142 | "\x13\x00\x00\x00\x05Val\x00\x04\x00\x00\x00\x00test\x00", 143 | }, { 144 | "embedded Boolean encode", 145 | struct{ Val bool }{true}, 146 | "\v\x00\x00\x00\bVal\x00\x01\x00", 147 | }, { 148 | "embedded Datetime encode", 149 | struct{ Val time.Time }{time.Unix(1136243045, 0).UTC()}, 150 | "\x12\x00\x00\x00\tVal\x00\x88\xf2\\\x8d\b\x01\x00\x00\x00", 151 | }, { 152 | "embedded Null encode", 153 | struct{ Val *int }{}, 154 | "\n\x00\x00\x00\nVal\x00\x00", 155 | }, { 156 | "embedded Int encode", 157 | struct{ Val int32 }{1}, 158 | "\x0e\x00\x00\x00\x10Val\x00\x01\x00\x00\x00\x00", 159 | }, { 160 | "embedded Long encode", 161 | struct{ Val int64 }{1}, 162 | "\x12\x00\x00\x00\x12Val\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00", 163 | }, { 164 | "embedded Ulong encode", 165 | struct{ Val uint64 }{1}, 166 | "\x12\x00\x00\x00?Val\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00", 167 | }, { 168 | "embedded non-pointer encode with custom marshaler", 169 | struct{ Val 
String1 }{String1("foo")}, 170 | "\x13\x00\x00\x00\x05Val\x00\x04\x00\x00\x00\x00test\x00", 171 | }, { 172 | "embedded pointer encode with custom marshaler", 173 | struct{ Val *String1 }{func(cs String1) *String1 { return &cs }(String1("foo"))}, 174 | "\x13\x00\x00\x00\x05Val\x00\x04\x00\x00\x00\x00test\x00", 175 | }, { 176 | "embedded nil pointer encode with custom marshaler", 177 | struct{ Val *String1 }{}, 178 | "\n\x00\x00\x00\nVal\x00\x00", 179 | }, { 180 | "embedded pointer encode with custom pointer marshaler", 181 | struct{ Val *String2 }{func(cs String2) *String2 { return &cs }(String2("foo"))}, 182 | "\x13\x00\x00\x00\x05Val\x00\x04\x00\x00\x00\x00test\x00", 183 | }, { 184 | "embedded addressable encode with custom pointer marshaler", 185 | &struct{ Val String2 }{String2("foo")}, 186 | "\x13\x00\x00\x00\x05Val\x00\x04\x00\x00\x00\x00test\x00", 187 | }, { 188 | "embedded non-addressable encode with custom pointer marshaler", 189 | struct{ Val String2 }{String2("foo")}, 190 | "\x12\x00\x00\x00\x05Val\x00\x03\x00\x00\x00\x00foo\x00", 191 | }} 192 | 193 | func TestMarshal(t *testing.T) { 194 | for _, tcase := range marshaltest { 195 | got := verifyMarshal(t, tcase.in) 196 | if string(got) != tcase.out { 197 | t.Errorf("%s: encoded: \n%q, want\n%q", tcase.desc, got, tcase.out) 198 | } 199 | } 200 | } 201 | 202 | var marshalErrorCases = []struct { 203 | desc string 204 | in interface{} 205 | out string 206 | }{{ 207 | "nil input", 208 | nil, 209 | "cannot marshal nil", 210 | }, { 211 | "chan input", 212 | make(chan int), 213 | "unexpected type chan int", 214 | }, { 215 | "embedded chan input", 216 | struct{ Val chan int }{}, 217 | "don't know how to marshal chan int", 218 | }, { 219 | "map with int key", 220 | map[int]int{}, 221 | "can't marshall maps with non-string key types", 222 | }} 223 | 224 | func TestMarshalErrors(t *testing.T) { 225 | for _, tcase := range marshalErrorCases { 226 | _, err := Marshal(tcase.in) 227 | got := "" 228 | if err != nil { 229 | 
got = err.Error() 230 | } 231 | if got != tcase.out { 232 | t.Errorf("%s: received: %q, want %q", tcase.desc, got, tcase.out) 233 | } 234 | } 235 | } 236 | -------------------------------------------------------------------------------- /parser/dependency/bson/marshal_util.go: -------------------------------------------------------------------------------- 1 | // Copyright 2012, Google Inc. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | 5 | // Utility functions for custom encoders 6 | 7 | package bson 8 | 9 | import ( 10 | "time" 11 | 12 | "github.com/wgliang/pgproxy/parser/dependency/bytes2" 13 | ) 14 | 15 | // EncodeInterface bson encodes an interface{}. Elements 16 | // can be basic bson encodable types, or []interface{}, 17 | // or map[string]interface{}, whose elements have to in 18 | // turn be bson encodable. 19 | func EncodeInterface(buf *bytes2.ChunkedWriter, key string, val interface{}) { 20 | if val == nil { 21 | EncodePrefix(buf, Null, key) 22 | return 23 | } 24 | switch val := val.(type) { 25 | case string: 26 | EncodeString(buf, key, val) 27 | case []byte: 28 | EncodeBinary(buf, key, val) 29 | case int64: 30 | EncodeInt64(buf, key, val) 31 | case int32: 32 | EncodeInt32(buf, key, val) 33 | case int: 34 | EncodeInt(buf, key, val) 35 | case uint64: 36 | EncodeUint64(buf, key, val) 37 | case uint32: 38 | EncodeUint32(buf, key, val) 39 | case uint: 40 | EncodeUint(buf, key, val) 41 | case float64: 42 | EncodeFloat64(buf, key, val) 43 | case bool: 44 | EncodeBool(buf, key, val) 45 | case map[string]interface{}: 46 | if val == nil { 47 | EncodePrefix(buf, Null, key) 48 | return 49 | } 50 | EncodePrefix(buf, Object, key) 51 | lenWriter := NewLenWriter(buf) 52 | for k, v := range val { 53 | EncodeInterface(buf, k, v) 54 | } 55 | lenWriter.Close() 56 | case []interface{}: 57 | if val == nil { 58 | EncodePrefix(buf, Null, key) 59 | return 60 | } 61 | EncodePrefix(buf, 
Array, key) 62 | lenWriter := NewLenWriter(buf) 63 | for i, v := range val { 64 | EncodeInterface(buf, Itoa(i), v) 65 | } 66 | lenWriter.Close() 67 | case time.Time: 68 | EncodeTime(buf, key, val) 69 | default: 70 | panic(NewBsonError("don't know how to marshal %T", val)) 71 | } 72 | } 73 | 74 | // EncodeStringArray bson encodes a []string 75 | func EncodeStringArray(buf *bytes2.ChunkedWriter, name string, values []string) { 76 | if values == nil { 77 | EncodePrefix(buf, Null, name) 78 | return 79 | } 80 | EncodePrefix(buf, Array, name) 81 | lenWriter := NewLenWriter(buf) 82 | for i, val := range values { 83 | EncodeString(buf, Itoa(i), val) 84 | } 85 | lenWriter.Close() 86 | } 87 | -------------------------------------------------------------------------------- /parser/dependency/bson/unmarshal.go: -------------------------------------------------------------------------------- 1 | // Copyright 2012, Google Inc. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | 5 | package bson 6 | 7 | import ( 8 | "bytes" 9 | "fmt" 10 | "io" 11 | "reflect" 12 | "time" 13 | ) 14 | 15 | // Unmarshaler is the interface that needs to be satisfied 16 | // by types that want to perform custom unmarshaling. 17 | // If kind is EOO, then the type is being unmarshalled 18 | // as a top level object. Otherwise, it's an embedded 19 | // object, and kind will need to be type-checked 20 | // before unmarshaling. 21 | type Unmarshaler interface { 22 | UnmarshalBson(buf *bytes.Buffer, kind byte) 23 | } 24 | 25 | func (builder *valueBuilder) canUnMarshal() Unmarshaler { 26 | // Don't use custom unmarshalers for map values. 27 | // It loses symmetry. 
28 | if builder.map_.IsValid() { 29 | return nil 30 | } 31 | if builder.val.CanAddr() { 32 | if unmarshaler, ok := builder.val.Addr().Interface().(Unmarshaler); ok { 33 | return unmarshaler 34 | } 35 | } 36 | return nil 37 | } 38 | 39 | // Unmarshal unmarshals b into val. 40 | func Unmarshal(b []byte, val interface{}) (err error) { 41 | return UnmarshalFromBuffer(bytes.NewBuffer(b), val) 42 | } 43 | 44 | // UnmarshalFromStream unmarshals from reader into val. 45 | func UnmarshalFromStream(reader io.Reader, val interface{}) (err error) { 46 | lenbuf := make([]byte, 4) 47 | var n int 48 | n, err = io.ReadFull(reader, lenbuf) 49 | if err != nil { 50 | return err 51 | } 52 | if n != 4 { 53 | return io.ErrUnexpectedEOF 54 | } 55 | length := Pack.Uint32(lenbuf) 56 | b := make([]byte, length) 57 | Pack.PutUint32(b, length) 58 | n, err = io.ReadFull(reader, b[4:]) 59 | if err != nil { 60 | if err == io.EOF { 61 | return io.ErrUnexpectedEOF 62 | } 63 | return err 64 | } 65 | if n != int(length-4) { 66 | return io.ErrUnexpectedEOF 67 | } 68 | return UnmarshalFromBuffer(bytes.NewBuffer(b), val) 69 | } 70 | 71 | // UnmarshalFromBuffer unmarshals from buf into val. 
72 | func UnmarshalFromBuffer(buf *bytes.Buffer, val interface{}) (err error) { 73 | defer handleError(&err) 74 | if val == nil { 75 | Skip(buf, Object) 76 | return nil 77 | } 78 | 79 | if unmarshaler, ok := val.(Unmarshaler); ok { 80 | unmarshaler.UnmarshalBson(buf, EOO) 81 | return nil 82 | } 83 | sb, err := topLevelBuilder(val) 84 | if err != nil { 85 | return err 86 | } 87 | decodeDocument(buf, sb, EOO) 88 | sb.save() 89 | return nil 90 | } 91 | 92 | func decodeDocument(buf *bytes.Buffer, builder *valueBuilder, kind byte) { 93 | if kind != EOO && kind != Object && kind != Array { 94 | panic(NewBsonError("unexpected kind: %v", kind)) 95 | } 96 | Next(buf, 4) 97 | for kind := NextByte(buf); kind != EOO; kind = NextByte(buf) { 98 | b2 := builder.initField(ReadCString(buf), kind) 99 | if b2 == nil { 100 | Skip(buf, kind) 101 | continue 102 | } 103 | if unmarshaler := b2.canUnMarshal(); unmarshaler != nil { 104 | unmarshaler.UnmarshalBson(buf, kind) 105 | continue 106 | } 107 | switch b2.val.Kind() { 108 | case reflect.String: 109 | b2.setString(DecodeString(buf, kind)) 110 | case reflect.Int64: 111 | b2.setInt(DecodeInt64(buf, kind)) 112 | case reflect.Int32: 113 | b2.setInt(int64(DecodeInt32(buf, kind))) 114 | case reflect.Int: 115 | b2.setInt(int64(DecodeInt(buf, kind))) 116 | case reflect.Uint64: 117 | b2.setUint(DecodeUint64(buf, kind)) 118 | case reflect.Uint32: 119 | b2.setUint(uint64(DecodeUint32(buf, kind))) 120 | case reflect.Uint: 121 | b2.setUint(uint64(DecodeUint(buf, kind))) 122 | case reflect.Float64: 123 | b2.setFloat(DecodeFloat64(buf, kind)) 124 | case reflect.Bool: 125 | b2.setBool(DecodeBool(buf, kind)) 126 | case reflect.Struct: 127 | if b2.val.Type() == timeType { 128 | b2.setTime(DecodeTime(buf, kind)) 129 | } else { 130 | decodeDocument(buf, b2, kind) 131 | } 132 | case reflect.Map, reflect.Array: 133 | decodeDocument(buf, b2, kind) 134 | case reflect.Slice: 135 | if b2.val.Type() == bytesType { 136 | b2.setBytes(DecodeBinary(buf, kind)) 137 
| } else { 138 | decodeDocument(buf, b2, kind) 139 | } 140 | case reflect.Interface: 141 | b2.setInterface(DecodeInterface(buf, kind)) 142 | default: 143 | panic(NewBsonError("cannot unmarshal into %v", b2.val.Kind())) 144 | } 145 | b2.save() 146 | } 147 | } 148 | 149 | // Maps & interface values will not give you a reference to their underlying object. 150 | // You can only update them through their Set methods. 151 | type valueBuilder struct { 152 | val reflect.Value 153 | 154 | // if map_.IsValid(), write val to map_ using key. 155 | map_ reflect.Value 156 | key reflect.Value 157 | 158 | // index tracks current index if val is an array. 159 | index int 160 | } 161 | 162 | // topLevelBuilder returns a valid unmarshalable valueBuilder or an error 163 | func topLevelBuilder(val interface{}) (sb *valueBuilder, err error) { 164 | ival := reflect.ValueOf(val) 165 | if ival.Kind() != reflect.Ptr { 166 | return nil, fmt.Errorf("expecting pointer value, received %v", ival.Type()) 167 | } 168 | return newValueBuilder(ival.Elem()), nil 169 | } 170 | 171 | // newValuebuilder returns a valueBuilder for val. It perorms all 172 | // necessary memory allocations. 173 | func newValueBuilder(val reflect.Value) *valueBuilder { 174 | for val.Kind() == reflect.Ptr { 175 | if val.IsNil() { 176 | val.Set(reflect.New(val.Type().Elem())) 177 | } 178 | val = val.Elem() 179 | } 180 | switch val.Kind() { 181 | case reflect.Map: 182 | if val.IsNil() { 183 | val.Set(reflect.MakeMap(val.Type())) 184 | } 185 | case reflect.Slice: 186 | if val.IsNil() { 187 | val.Set(reflect.MakeSlice(val.Type(), 0, 8)) 188 | } 189 | } 190 | return &valueBuilder{val: val} 191 | } 192 | 193 | // mapValueBuilder returns a valueBuilder that represents a map value. 194 | // You need to call save after building the value to make sure it gets 195 | // saved to the map. 
196 | func mapValueBuilder(typ reflect.Type, map_ reflect.Value, key reflect.Value) *valueBuilder { 197 | if typ.Kind() == reflect.Ptr { 198 | addr := reflect.New(typ.Elem()) 199 | map_.SetMapIndex(key, addr) 200 | return newValueBuilder(addr.Elem()) 201 | } 202 | builder := newValueBuilder(reflect.New(typ).Elem()) 203 | builder.map_ = map_ 204 | builder.key = key 205 | return builder 206 | } 207 | 208 | // save saves the built value into the map. 209 | func (builder *valueBuilder) save() { 210 | if builder.map_.IsValid() { 211 | builder.map_.SetMapIndex(builder.key, builder.val) 212 | } 213 | } 214 | 215 | // initField returns a valueBuilder based on the requested key. 216 | // If the key is a the magic tag _Val_, it returns itself. 217 | // If builder is a struct, it looks for a field of that name. 218 | // If builder is a map, it creates an entry for that key. 219 | // If buider is an array, it ignores the key and returns the next 220 | // element of the array. 221 | // If builder is a slice, it returns a newly appended element. 222 | // If the key cannot be resolved, it returns null. 223 | // If kind is Null, it initializes the field to the zero value. 224 | // Otherwise, it allocates memory as needed. 
225 | func (builder *valueBuilder) initField(k string, kind byte) *valueBuilder { 226 | if k == MAGICTAG { 227 | if kind == Null { 228 | setZero(builder.val) 229 | return nil 230 | } 231 | return builder 232 | } 233 | switch builder.val.Kind() { 234 | case reflect.Struct: 235 | t := builder.val.Type() 236 | for i := 0; i < t.NumField(); i++ { 237 | if t.Field(i).Name == k { 238 | if kind == Null { 239 | setZero(builder.val.Field(i)) 240 | return nil 241 | } 242 | return newValueBuilder(builder.val.Field(i)) 243 | } 244 | } 245 | return nil 246 | case reflect.Map: 247 | t := builder.val.Type() 248 | if t.Key().Kind() != reflect.String { 249 | panic(NewBsonError("map index is not a string: %s", k)) 250 | } 251 | key := reflect.ValueOf(k) 252 | if kind == Null { 253 | zero := reflect.Zero(t.Elem()) 254 | builder.val.SetMapIndex(key, zero) 255 | return nil 256 | } 257 | return mapValueBuilder(t.Elem(), builder.val, key) 258 | case reflect.Array: 259 | if builder.index >= builder.val.Len() { 260 | panic(NewBsonError("array index %v out of bounds", builder.index)) 261 | } 262 | ind := builder.index 263 | builder.index++ 264 | if kind == Null { 265 | setZero(builder.val.Index(ind)) 266 | return nil 267 | } 268 | return newValueBuilder(builder.val.Index(ind)) 269 | case reflect.Slice: 270 | zero := reflect.Zero(builder.val.Type().Elem()) 271 | builder.val.Set(reflect.Append(builder.val, zero)) 272 | if kind == Null { 273 | return nil 274 | } 275 | return newValueBuilder(builder.val.Index(builder.val.Len() - 1)) 276 | } 277 | // Failsafe: this code is actually unreachable. 
278 | panic(NewBsonError("internal error: unindexable type %v", builder.val.Type())) 279 | } 280 | 281 | func setZero(v reflect.Value) { 282 | v.Set(reflect.Zero(v.Type())) 283 | } 284 | 285 | func (builder *valueBuilder) setInt(i int64) { 286 | builder.val.SetInt(i) 287 | } 288 | 289 | func (builder *valueBuilder) setUint(u uint64) { 290 | builder.val.SetUint(u) 291 | } 292 | 293 | func (builder *valueBuilder) setFloat(f float64) { 294 | builder.val.SetFloat(f) 295 | } 296 | 297 | func (builder *valueBuilder) setString(s string) { 298 | builder.val.SetString(s) 299 | } 300 | 301 | func (builder *valueBuilder) setBool(tf bool) { 302 | builder.val.SetBool(tf) 303 | } 304 | 305 | func (builder *valueBuilder) setTime(t time.Time) { 306 | builder.val.Set(reflect.ValueOf(t)) 307 | } 308 | 309 | func (builder *valueBuilder) setBytes(b []byte) { 310 | builder.val.Set(reflect.ValueOf(b)) 311 | } 312 | 313 | func (builder *valueBuilder) setInterface(i interface{}) { 314 | builder.val.Set(reflect.ValueOf(i)) 315 | } 316 | -------------------------------------------------------------------------------- /parser/dependency/bson/unmarshal_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2012, Google Inc. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 

package bson

import (
	"reflect"
	"testing"
	"time"
)

// The newX helpers return a pointer to a fresh copy of v, so table entries
// can spell pointer-typed inputs and expectations inline.
func newstring(v string) *string             { return &v }
func newint64(v int64) *int64                { return &v }
func newint32(v int32) *int32                { return &v }
func newint(v int) *int                      { return &v }
func newuint64(v uint64) *uint64             { return &v }
func newuint32(v uint32) *uint32             { return &v }
func newuint(v uint) *uint                   { return &v }
func newfloat64(v float64) *float64          { return &v }
func newbool(v bool) *bool                   { return &v }
func newtime(v time.Time) *time.Time         { return &v }
func newinterface(v interface{}) *interface{} { return &v }

// unmarshaltest lists decoding cases for TestUnmarshal. in is a raw BSON
// document written as a Go string literal; it is decoded into out, which must
// then be reflect.DeepEqual to want. Where out starts with a non-zero value
// (e.g. the "from Null" cases) the case checks that decoding overwrites it.
var unmarshaltest = []struct {
	desc string
	in   string
	out  interface{}
	want interface{}
}{{

	// top level decodes
	"top level nil decode",
	"\x13\x00\x00\x00\x05Val\x00\x04\x00\x00\x00\x00test\x00",
	nil,
	nil,
}, {
	"top level struct decode",
	"\x13\x00\x00\x00\x05Val\x00\x04\x00\x00\x00\x00test\x00",
	&struct{ Val string }{},
	&struct{ Val string }{"test"},
}, {
	"top level map decode",
	"\x13\x00\x00\x00\x05Val\x00\x04\x00\x00\x00\x00test\x00",
	&map[string]string{},
	&map[string]string{"Val": "test"},
}, {
	"top level slice decode",
	"\x13\x00\x00\x00\x05Val\x00\x04\x00\x00\x00\x00test\x00",
	&[]string{},
	&[]string{"test"},
}, {
	"top level array decode",
	"\x13\x00\x00\x00\x05Val\x00\x04\x00\x00\x00\x00test\x00",
	&[2]string{},
	&[2]string{"test", ""},
}, {
	"top level string decode",
	"\x15\x00\x00\x00\x05_Val_\x00\x04\x00\x00\x00\x00test\x00",
	newstring(""),
	newstring("test"),
}, {
	"top level string decode from Null",
	"\x0c\x00\x00\x00\n_Val_\x00\x00",
	newstring("test"),
	newstring(""),
}, {
	"top level bytes decode",
	"\x15\x00\x00\x00\x05_Val_\x00\x04\x00\x00\x00\x00test\x00",
	&[]byte{},
	&[]byte{'t', 'e', 's', 't'},
}, {
	"top level int64 decode",
	"\x14\x00\x00\x00\x12_Val_\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00",
	newint64(0),
	newint64(1),
}, {
	"top level int32 decode",
	"\x10\x00\x00\x00\x10_Val_\x00\x01\x00\x00\x00\x00",
	newint32(0),
	newint32(1),
}, {
	"top level int decode",
	"\x10\x00\x00\x00\x10_Val_\x00\x01\x00\x00\x00\x00",
	newint(0),
	newint(1),
}, {
	"top level uint64 decode",
	"\x14\x00\x00\x00?_Val_\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00",
	newuint64(0),
	newuint64(1),
}, {
	"top level uint32 decode",
	"\x14\x00\x00\x00?_Val_\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00",
	newuint32(0),
	newuint32(1),
}, {
	"top level uint decode",
	"\x14\x00\x00\x00\x12_Val_\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00",
	newuint(0),
	newuint(1),
}, {
	"top level float64 decode",
	"\x14\x00\x00\x00\x01_Val_\x00\x00\x00\x00\x00\x00\x00\xf0?\x00",
	newfloat64(0),
	newfloat64(1.0),
}, {
	"top level bool decode",
	"\r\x00\x00\x00\b_Val_\x00\x01\x00",
	newbool(false),
	newbool(true),
}, {
	"top level time decode",
	"\x14\x00\x00\x00\t_Val_\x00\x88\xf2\\\x8d\b\x01\x00\x00\x00",
	newtime(time.Now()),
	newtime(time.Unix(1136243045, 0).UTC()),
}, {
	"top level interface decode",
	"\x14\x00\x00\x00\x01_Val_\x00\x00\x00\x00\x00\x00\x00\xf0?\x00",
	newinterface(nil),
	newinterface(float64(1.0)),
}, {

	// embedded decodes
	"struct decode from Object",
	"\x1e\x00\x00\x00\x03Val\x00\x14\x00\x00\x00\x05Val2\x00\x04\x00\x00\x00\x00test\x00\x00",
	&struct{ Val struct{ Val2 string } }{},
	&struct{ Val struct{ Val2 string } }{struct{ Val2 string }{"test"}},
}, {
	"struct decode from Null",
	"\n\x00\x00\x00\nVal\x00\x00",
	&struct{ Val struct{ Val2 string } }{struct{ Val2 string }{"test"}},
	&struct{ Val struct{ Val2 string } }{},
}, {
	"map decode from Object",
	"\x1e\x00\x00\x00\x03Val\x00\x14\x00\x00\x00\x05Val2\x00\x04\x00\x00\x00\x00test\x00\x00",
	&struct{ Val map[string]string }{},
	&struct{ Val map[string]string }{map[string]string{"Val2": "test"}},
}, {
	"map decode from Null",
	"\n\x00\x00\x00\nVal\x00\x00",
	&struct{ Val map[string]string }{map[string]string{"Val2": "test"}},
	&struct{ Val map[string]string }{},
}, {
	"map decode from Null element",
	"\x15\x00\x00\x00\x03Val\x00\v\x00\x00\x00\nVal2\x00\x00\x00",
	&struct{ Val map[string]string }{},
	&struct{ Val map[string]string }{map[string]string{"Val2": ""}},
}, {
	"slice decode from Array",
	"\x1b\x00\x00\x00\x04Val\x00\x11\x00\x00\x00\x050\x00\x04\x00\x00\x00\x00test\x00\x00",
	&struct{ Val []string }{},
	&struct{ Val []string }{[]string{"test"}},
}, {
	"slice decode from Null",
	"\n\x00\x00\x00\nVal\x00\x00",
	&struct{ Val []string }{[]string{"test"}},
	&struct{ Val []string }{},
}, {
	"slice decode from Null element",
	"\x1d\x00\x00\x00\x04Val\x00\x13\x00\x00\x00\n0\x00\x121\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00",
	&struct{ Val []*int64 }{},
	&struct{ Val []*int64 }{[]*int64{nil, newint64(1)}},
}, {
	"array decode from Array",
	"\x1b\x00\x00\x00\x04Val\x00\x11\x00\x00\x00\x050\x00\x04\x00\x00\x00\x00test\x00\x00",
	&struct{ Val [2]string }{},
	&struct{ Val [2]string }{[2]string{"test", ""}},
}, {
	"array decode from Null",
	"\n\x00\x00\x00\nVal\x00\x00",
	&struct{ Val [2]string }{[2]string{"test", ""}},
	&struct{ Val [2]string }{},
}, {
	"array decode from Null element",
	"\x1d\x00\x00\x00\x04Val\x00\x13\x00\x00\x00\n0\x00\x121\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00",
	&struct{ Val [2]*int64 }{},
	&struct{ Val [2]*int64 }{[2]*int64{nil, newint64(1)}},
}, {
	"string decode from String",
	"\x13\x00\x00\x00\x02Val\x00\x05\x00\x00\x00test\x00\x00",
	&struct{ Val string }{},
	&struct{ Val string }{"test"},
}, {
	"string decode from Binary",
	"\x13\x00\x00\x00\x05Val\x00\x04\x00\x00\x00\x00test\x00",
	&struct{ Val string }{},
	&struct{ Val string }{"test"},
}, {
	"string decode from Null",
	"\n\x00\x00\x00\nVal\x00\x00",
	&struct{ Val string }{"test"},
	&struct{ Val string }{},
}, {
	"bytes decode from String",
	"\x13\x00\x00\x00\x02Val\x00\x05\x00\x00\x00test\x00\x00",
	&struct{ Val []byte }{},
	&struct{ Val []byte }{[]byte("test")},
}, {
	"bytes decode from Binary",
	"\x13\x00\x00\x00\x05Val\x00\x04\x00\x00\x00\x00test\x00",
	&struct{ Val []byte }{},
	&struct{ Val []byte }{[]byte("test")},
}, {
	"bytes decode from Null",
	"\n\x00\x00\x00\nVal\x00\x00",
	&struct{ Val []byte }{[]byte("test")},
	&struct{ Val []byte }{},
}, {
	"int64 decode from Int",
	"\x0e\x00\x00\x00\x10Val\x00\x01\x00\x00\x00\x00",
	&struct{ Val int64 }{},
	&struct{ Val int64 }{1},
}, {
	"int64 decode from Long",
	"\x12\x00\x00\x00\x12Val\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00",
	&struct{ Val int64 }{},
	&struct{ Val int64 }{1},
}, {
	"int64 decode from Ulong",
	"\x12\x00\x00\x00?Val\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00",
	&struct{ Val int64 }{},
	&struct{ Val int64 }{1},
}, {
	"int64 decode from Null",
	"\n\x00\x00\x00\nVal\x00\x00",
	&struct{ Val int64 }{1},
	&struct{ Val int64 }{},
}, {
	"int32 decode from Int",
	"\x0e\x00\x00\x00\x10Val\x00\x01\x00\x00\x00\x00",
	&struct{ Val int32 }{},
	&struct{ Val int32 }{1},
}, {
	"int32 decode from Null",
	"\n\x00\x00\x00\nVal\x00\x00",
	&struct{ Val int32 }{1},
	&struct{ Val int32 }{},
}, {
	"int decode from Long",
	"\x12\x00\x00\x00\x12Val\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00",
	&struct{ Val int }{},
	&struct{ Val int }{1},
}, {
	"int decode from Int",
	"\x0e\x00\x00\x00\x10Val\x00\x01\x00\x00\x00\x00",
	&struct{ Val int }{},
	&struct{ Val int }{1},
}, {
	"int decode from Ulong",
	"\x12\x00\x00\x00?Val\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00",
	&struct{ Val int }{},
	&struct{ Val int }{1},
}, {
	"int decode from Null",
	"\n\x00\x00\x00\nVal\x00\x00",
	&struct{ Val int }{1},
	&struct{ Val int }{},
}, {
	"uint64 decode from Int",
	"\x0e\x00\x00\x00\x10Val\x00\x01\x00\x00\x00\x00",
	&struct{ Val uint64 }{},
	&struct{ Val uint64 }{1},
}, {
	"uint64 decode from Long",
	"\x12\x00\x00\x00\x12Val\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00",
	&struct{ Val uint64 }{},
	&struct{ Val uint64 }{1},
}, {
	"uint64 decode from Ulong",
	"\x12\x00\x00\x00?Val\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00",
	&struct{ Val uint64 }{},
	&struct{ Val uint64 }{1},
}, {
	"uint64 decode from Null",
	"\n\x00\x00\x00\nVal\x00\x00",
	&struct{ Val uint64 }{1},
	&struct{ Val uint64 }{},
}, {
	"uint32 decode from Int",
	"\x0e\x00\x00\x00\x10Val\x00\x01\x00\x00\x00\x00",
	&struct{ Val uint32 }{},
	&struct{ Val uint32 }{1},
}, {
	"uint32 decode from Ulong",
	"\x12\x00\x00\x00?Val\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00",
	&struct{ Val uint32 }{},
	&struct{ Val uint32 }{1},
}, {
	"uint32 decode from Null",
	"\n\x00\x00\x00\nVal\x00\x00",
	&struct{ Val uint32 }{1},
	&struct{ Val uint32 }{},
}, {
	"uint decode from Int",
	"\x0e\x00\x00\x00\x10Val\x00\x01\x00\x00\x00\x00",
	&struct{ Val uint }{},
	&struct{ Val uint }{1},
}, {
	"uint decode from Long",
	"\x12\x00\x00\x00\x12Val\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00",
	&struct{ Val uint }{},
	&struct{ Val uint }{1},
}, {
	"uint decode from Ulong",
	"\x12\x00\x00\x00?Val\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00",
	&struct{ Val uint }{},
	&struct{ Val uint }{1},
}, {
	"uint decode from Null",
	"\n\x00\x00\x00\nVal\x00\x00",
	&struct{ Val uint }{1},
	&struct{ Val uint }{},
}, {
	"float64 decode from Number",
	"\x12\x00\x00\x00\x01Val\x00\x00\x00\x00\x00\x00\x00\xf0?\x00",
	&struct{ Val float64 }{},
	&struct{ Val float64 }{1.0},
}, {
	"float64 decode from Null",
	"\n\x00\x00\x00\nVal\x00\x00",
	&struct{ Val float64 }{1.0},
	&struct{ Val float64 }{},
}, {
	"bool decode from Boolean",
	"\v\x00\x00\x00\bVal\x00\x01\x00",
	&struct{ Val bool }{},
	&struct{ Val bool }{true},
}, {
	"bool decode from Int",
	"\x0e\x00\x00\x00\x10Val\x00\x01\x00\x00\x00\x00",
	&struct{ Val bool }{},
	&struct{ Val bool }{true},
}, {
	"bool decode from Long",
	"\x12\x00\x00\x00\x12Val\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00",
	&struct{ Val bool }{},
	&struct{ Val bool }{true},
}, {
	"bool decode from Ulong",
	"\x12\x00\x00\x00?Val\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00",
	&struct{ Val bool }{},
	&struct{ Val bool }{true},
}, {
	"bool decode from Null",
	"\n\x00\x00\x00\nVal\x00\x00",
	&struct{ Val bool }{true},
	&struct{ Val bool }{},
}, {
	"time decode from Datetime",
	"\x12\x00\x00\x00\tVal\x00\x88\xf2\\\x8d\b\x01\x00\x00\x00",
	&struct{ Val time.Time }{},
	&struct{ Val time.Time }{time.Unix(1136243045, 0).UTC()},
}, {
	"time decode from Null",
	"\n\x00\x00\x00\nVal\x00\x00",
	&struct{ Val time.Time }{time.Unix(1136243045, 0).UTC()},
	&struct{ Val time.Time }{},
}, {
	"interface decode from Number",
	"\x12\x00\x00\x00\x01Val\x00\x00\x00\x00\x00\x00\x00\xf0?\x00",
	&struct{ Val interface{} }{},
	&struct{ Val interface{} }{float64(1.0)},
}, {
	"interface decode from String",
	"\x13\x00\x00\x00\x02Val\x00\x05\x00\x00\x00test\x00\x00",
	&struct{ Val interface{} }{},
	&struct{ Val interface{} }{"test"},
}, {
	"interface decode from Binary",
	"\x13\x00\x00\x00\x05Val\x00\x04\x00\x00\x00\x00test\x00",
	&struct{ Val interface{} }{},
	&struct{ Val interface{} }{[]byte("test")},
}, {
	"interface decode from Boolean",
	"\v\x00\x00\x00\bVal\x00\x01\x00",
	&struct{ Val interface{} }{},
	&struct{ Val interface{} }{true},
}, {
	"interface decode from Datetime",
	"\x12\x00\x00\x00\tVal\x00\x88\xf2\\\x8d\b\x01\x00\x00\x00",
	&struct{ Val interface{} }{},
	&struct{ Val interface{} }{time.Unix(1136243045, 0).UTC()},
}, {
	"interface decode from Int",
	"\x0e\x00\x00\x00\x10Val\x00\x01\x00\x00\x00\x00",
	&struct{ Val interface{} }{},
	&struct{ Val interface{} }{int32(1)},
}, {
	"interface decode from Long",
	"\x12\x00\x00\x00\x12Val\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00",
	&struct{ Val interface{} }{},
	&struct{ Val interface{} }{int64(1)},
}, {
	"interface decode from Ulong",
	"\x12\x00\x00\x00?Val\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00",
	&struct{ Val interface{} }{},
	&struct{ Val interface{} }{uint64(1)},
}, {
	"interface decode from Object",
	"\x1e\x00\x00\x00\x03Val\x00\x14\x00\x00\x00\x05Val2\x00\x04\x00\x00\x00\x00test\x00\x00",
	&struct{ Val interface{} }{},
	&struct{ Val interface{} }{map[string]interface{}{"Val2": []byte("test")}},
}, {
	"interface decode from Object with Null element",
	"\x15\x00\x00\x00\x03Val\x00\v\x00\x00\x00\nVal2\x00\x00\x00",
	&struct{ Val interface{} }{},
	&struct{ Val interface{} }{map[string]interface{}{"Val2": nil}},
}, {
	"interface decode from Array",
	"\x1b\x00\x00\x00\x04Val\x00\x11\x00\x00\x00\x050\x00\x04\x00\x00\x00\x00test\x00\x00",
	&struct{ Val interface{} }{},
	&struct{ Val interface{} }{[]interface{}{[]byte("test")}},
}, {
	"interface decode from Array null element",
	"\x1d\x00\x00\x00\x04Val\x00\x13\x00\x00\x00\n0\x00\x121\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00",
	&struct{ Val interface{} }{},
	&struct{ Val interface{} }{[]interface{}{nil, int64(1)}},
}, {
	"interface decode from Null",
	"\n\x00\x00\x00\nVal\x00\x00",
	&struct{ Val interface{} }{uint64(1)},
	&struct{ Val interface{} }{},
}, {
	"pointer decode from Int",
	"\x0e\x00\x00\x00\x10Val\x00\x01\x00\x00\x00\x00",
	&struct{ Val *int64 }{},
	&struct{ Val *int64 }{newint64(1)},
}}

// TestUnmarshal runs every unmarshaltest case and compares the decoded out
// value with want.
// NOTE(review): verifyUnmarshal is defined in another test file; it presumably
// decodes in into out and fails the test on a decoding error.
func TestUnmarshal(t *testing.T) {
	for _, tcase := range unmarshaltest {
		verifyUnmarshal(t, []byte(tcase.in), tcase.out)
		if !reflect.DeepEqual(tcase.out, tcase.want) {
			// Dereference the pointers so the failure shows values, not
			// addresses.
			out := reflect.ValueOf(tcase.out).Elem().Interface()
			want := reflect.ValueOf(tcase.want).Elem().Interface()
			t.Errorf("%s: decoded: \n%#v, want\n%#v", tcase.desc, out, want)
		}
	}
}

// unmarshalErrorCases lists inputs that Unmarshal must reject; want is the
// exact error text expected.
var unmarshalErrorCases = []struct {
	desc string
	in   string
	out  interface{}
	want string
}{{
	"non pointer input",
	"",
	10,
	"expecting pointer value, received int",
}, {
	"invalid bson kind",
	"\x0e\x00\x00\x00\x10Val\x00\x01\x00\x00\x00\x00",
	&struct{ Val struct{ Val2 int } }{},
	"unexpected kind: 16",
}, {
	"map with int key",
	"\x0e\x00\x00\x00\x10Val\x00\x01\x00\x00\x00\x00",
	&map[int]int{},
	"map index is not a string: Val",
}, {
	"small array",
	"\x1f\x00\x00\x00\x050\x00\x05\x00\x00\x00\x00test1\x051\x00\x05\x00\x00\x00\x00test2\x00",
	&[1]string{},
	"array index 1 out of bounds",
}, {
	"chan in struct",
	"\x0e\x00\x00\x00\x10Val\x00\x01\x00\x00\x00\x00",
	&struct{ Val chan int }{},
	"cannot unmarshal into chan",
}}

// TestUnmarshalErrors checks that Unmarshal returns exactly the expected error
// text for each unmarshalErrorCases entry.
func TestUnmarshalErrors(t *testing.T) {
	for _, tcase := range unmarshalErrorCases {
		err := Unmarshal([]byte(tcase.in), tcase.out)
		got := ""
		if err != nil {
			got = err.Error()
		}
		if got != tcase.want {
			t.Errorf("%s: received: %q, want %q", tcase.desc, got, tcase.want)
		}
	}
}
--------------------------------------------------------------------------------
/parser/dependency/bson/unmarshal_util.go:
--------------------------------------------------------------------------------
// Copyright 2012, Google Inc. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

// Utility functions for custom decoders

package bson

import (
	"bytes"
	"math"
	"time"

	"github.com/wgliang/pgproxy/parser/dependency/hack"
)

// VerifyObject verifies kind to make sure it's
// either a top level document (EOO) or an Object.
// It panics with a BSON error for any other kind.
// TODO(sougou): deprecate this function.
func VerifyObject(kind byte) {
	if kind != EOO && kind != Object {
		panic(NewBsonError("unexpected kind: %v", kind))
	}
}

// DecodeString decodes a string from buf.
27 | // Allowed types: String, Binary, Null, 28 | func DecodeString(buf *bytes.Buffer, kind byte) string { 29 | switch kind { 30 | case String: 31 | l := int(Pack.Uint32(Next(buf, 4))) 32 | s := Next(buf, l-1) 33 | NextByte(buf) 34 | return string(s) 35 | case Binary: 36 | l := int(Pack.Uint32(Next(buf, 4))) 37 | NextByte(buf) 38 | return string(Next(buf, l)) 39 | case Null: 40 | return "" 41 | } 42 | panic(NewBsonError("unexpected kind %v for string", kind)) 43 | } 44 | 45 | // DecodeBinary decodes a []byte from buf. 46 | // Allowed types: String, Binary, Null. 47 | func DecodeBinary(buf *bytes.Buffer, kind byte) []byte { 48 | switch kind { 49 | case String: 50 | l := int(Pack.Uint32(Next(buf, 4))) 51 | b := Next(buf, l-1) 52 | NextByte(buf) 53 | return b 54 | case Binary: 55 | l := int(Pack.Uint32(Next(buf, 4))) 56 | NextByte(buf) 57 | return Next(buf, l) 58 | case Null: 59 | return nil 60 | } 61 | panic(NewBsonError("unexpected kind %v for []byte", kind)) 62 | } 63 | 64 | // DecodeInt64 decodes a int64 from buf. 65 | // Allowed types: Int, Long, Ulong, Null. 66 | func DecodeInt64(buf *bytes.Buffer, kind byte) int64 { 67 | switch kind { 68 | case Int: 69 | return int64(int32(Pack.Uint32(Next(buf, 4)))) 70 | case Long, Ulong: 71 | return int64(Pack.Uint64(Next(buf, 8))) 72 | case Null: 73 | return 0 74 | } 75 | panic(NewBsonError("unexpected kind %v for int64", kind)) 76 | } 77 | 78 | // DecodeInt32 decodes a int32 from buf. 79 | // Allowed types: Int, Null. 80 | func DecodeInt32(buf *bytes.Buffer, kind byte) int32 { 81 | switch kind { 82 | case Int: 83 | return int32(Pack.Uint32(Next(buf, 4))) 84 | case Null: 85 | return 0 86 | } 87 | panic(NewBsonError("unexpected kind %v for int32", kind)) 88 | } 89 | 90 | // DecodeInt decodes a int64 from buf. 91 | // Allowed types: Int, Long, Ulong, Null. 
92 | func DecodeInt(buf *bytes.Buffer, kind byte) int { 93 | switch kind { 94 | case Int: 95 | return int(Pack.Uint32(Next(buf, 4))) 96 | case Long, Ulong: 97 | return int(Pack.Uint64(Next(buf, 8))) 98 | case Null: 99 | return 0 100 | } 101 | panic(NewBsonError("unexpected kind %v for int", kind)) 102 | } 103 | 104 | // DecodeUint64 decodes a uint64 from buf. 105 | // Allowed types: Int, Long, Ulong, Null. 106 | func DecodeUint64(buf *bytes.Buffer, kind byte) uint64 { 107 | switch kind { 108 | case Int: 109 | return uint64(Pack.Uint32(Next(buf, 4))) 110 | case Long, Ulong: 111 | return Pack.Uint64(Next(buf, 8)) 112 | case Null: 113 | return 0 114 | } 115 | panic(NewBsonError("unexpected kind %v for uint64", kind)) 116 | } 117 | 118 | // DecodeUint32 decodes a uint32 from buf. 119 | // Allowed types: Int, Long, Null. 120 | func DecodeUint32(buf *bytes.Buffer, kind byte) uint32 { 121 | switch kind { 122 | case Int: 123 | return Pack.Uint32(Next(buf, 4)) 124 | case Ulong: 125 | return uint32(Pack.Uint64(Next(buf, 8))) 126 | case Null: 127 | return 0 128 | } 129 | panic(NewBsonError("unexpected kind %v for uint32", kind)) 130 | } 131 | 132 | // DecodeUint decodes a uint64 from buf. 133 | // Allowed types: Int, Long, Ulong, Null. 134 | func DecodeUint(buf *bytes.Buffer, kind byte) uint { 135 | switch kind { 136 | case Int: 137 | return uint(Pack.Uint32(Next(buf, 4))) 138 | case Long, Ulong: 139 | return uint(Pack.Uint64(Next(buf, 8))) 140 | case Null: 141 | return 0 142 | } 143 | panic(NewBsonError("unexpected kind %v for uint", kind)) 144 | } 145 | 146 | // DecodeFloat64 decodes a float64 from buf. 147 | // Allowed types: Number, Null. 148 | func DecodeFloat64(buf *bytes.Buffer, kind byte) float64 { 149 | switch kind { 150 | case Number: 151 | return float64(math.Float64frombits(Pack.Uint64(Next(buf, 8)))) 152 | case Null: 153 | return 0 154 | } 155 | panic(NewBsonError("unexpected kind %v for float64", kind)) 156 | } 157 | 158 | // DecodeBool decodes a bool from buf. 
159 | // Allowed types: Boolean, Int, Long, Ulong, Null. 160 | func DecodeBool(buf *bytes.Buffer, kind byte) bool { 161 | switch kind { 162 | case Boolean: 163 | b, _ := buf.ReadByte() 164 | return (b != 0) 165 | case Int: 166 | return (Pack.Uint32(Next(buf, 4)) != 0) 167 | case Long, Ulong: 168 | return (Pack.Uint64(Next(buf, 8)) != 0) 169 | case Null: 170 | return false 171 | default: 172 | panic(NewBsonError("unexpected kind %v for bool", kind)) 173 | } 174 | } 175 | 176 | // DecodeBinary decodes a time.Time from buf. 177 | // Allowed types: Datetime, Null. 178 | func DecodeTime(buf *bytes.Buffer, kind byte) time.Time { 179 | switch kind { 180 | case Datetime: 181 | ui64 := Pack.Uint64(Next(buf, 8)) 182 | return time.Unix(0, int64(ui64)*1e6).UTC() 183 | case Null: 184 | return time.Time{} 185 | } 186 | panic(NewBsonError("unexpected kind %v for time.Time", kind)) 187 | } 188 | 189 | // DecodeInterface decodes the next object into an interface. 190 | // Object is decoded as map[string]interface{}. 191 | // Array is decoded as []interface{} 192 | func DecodeInterface(buf *bytes.Buffer, kind byte) interface{} { 193 | switch kind { 194 | case Number: 195 | return DecodeFloat64(buf, kind) 196 | case String: 197 | return DecodeString(buf, kind) 198 | case Object: 199 | return DecodeMap(buf, kind) 200 | case Array: 201 | return DecodeArray(buf, kind) 202 | case Binary: 203 | return DecodeBinary(buf, kind) 204 | case Boolean: 205 | return DecodeBool(buf, kind) 206 | case Datetime: 207 | return DecodeTime(buf, kind) 208 | case Null: 209 | return nil 210 | case Int: 211 | return DecodeInt32(buf, kind) 212 | case Long: 213 | return DecodeInt64(buf, kind) 214 | case Ulong: 215 | return DecodeUint64(buf, kind) 216 | } 217 | panic(NewBsonError("unexpected kind %v for interface{}", kind)) 218 | } 219 | 220 | // DecodeMap decodes a map[string]interface{} from buf. 221 | // Allowed types: Object, Null. 
222 | func DecodeMap(buf *bytes.Buffer, kind byte) map[string]interface{} { 223 | switch kind { 224 | case Object: 225 | // valid 226 | case Null: 227 | return nil 228 | default: 229 | panic(NewBsonError("unexpected kind %v for map", kind)) 230 | } 231 | 232 | result := make(map[string]interface{}) 233 | Next(buf, 4) 234 | for kind := NextByte(buf); kind != EOO; kind = NextByte(buf) { 235 | key := ReadCString(buf) 236 | if kind == Null { 237 | result[key] = nil 238 | continue 239 | } 240 | result[key] = DecodeInterface(buf, kind) 241 | } 242 | return result 243 | } 244 | 245 | // DecodeMap decodes a []interface{} from buf. 246 | // Allowed types: Array, Null. 247 | func DecodeArray(buf *bytes.Buffer, kind byte) []interface{} { 248 | switch kind { 249 | case Array: 250 | // valid 251 | case Null: 252 | return nil 253 | default: 254 | panic(NewBsonError("unexpected kind %v for slice", kind)) 255 | } 256 | 257 | result := make([]interface{}, 0, 8) 258 | Next(buf, 4) 259 | for kind := NextByte(buf); kind != EOO; kind = NextByte(buf) { 260 | ReadCString(buf) 261 | if kind == Null { 262 | result = append(result, nil) 263 | continue 264 | } 265 | result = append(result, DecodeInterface(buf, kind)) 266 | } 267 | return result 268 | } 269 | 270 | // DecodeMap decodes a []string from buf. 271 | // Allowed types: Array, Null. 272 | func DecodeStringArray(buf *bytes.Buffer, kind byte) []string { 273 | switch kind { 274 | case Array: 275 | // valid 276 | case Null: 277 | return nil 278 | default: 279 | panic(NewBsonError("unexpected kind %v for []string", kind)) 280 | } 281 | 282 | result := make([]string, 0, 8) 283 | Next(buf, 4) 284 | for kind := NextByte(buf); kind != EOO; kind = NextByte(buf) { 285 | if kind != Binary { 286 | panic(NewBsonError("unexpected kind %v for string", kind)) 287 | } 288 | SkipIndex(buf) 289 | result = append(result, DecodeString(buf, kind)) 290 | } 291 | return result 292 | } 293 | 294 | // Skip will skip the next field we don't want to read. 
295 | func Skip(buf *bytes.Buffer, kind byte) { 296 | switch kind { 297 | case Number, Datetime, Long, Ulong: 298 | Next(buf, 8) 299 | case String: 300 | // length of a string includes the 0 at the end, but not the size 301 | l := int(Pack.Uint32(Next(buf, 4))) 302 | Next(buf, l) 303 | case Object, Array: 304 | // the encoded length includes the 4 bytes for the size 305 | l := int(Pack.Uint32(Next(buf, 4))) 306 | if l < 4 { 307 | panic(NewBsonError("Object or Array should at least be 4 bytes long")) 308 | } 309 | Next(buf, l-4) 310 | case Binary: 311 | // length of a binary doesn't include the subtype 312 | l := int(Pack.Uint32(Next(buf, 4))) 313 | Next(buf, l+1) 314 | case Boolean: 315 | buf.ReadByte() 316 | case Int: 317 | Next(buf, 4) 318 | case Null: 319 | // no op 320 | default: 321 | panic(NewBsonError("unexpected kind %v for skip", kind)) 322 | } 323 | } 324 | 325 | // SkipIndex must be used to skip indexes in arrays. 326 | func SkipIndex(buf *bytes.Buffer) { 327 | ReadCString(buf) 328 | } 329 | 330 | // ReadCString reads the the bson document tag. 331 | func ReadCString(buf *bytes.Buffer) string { 332 | index := bytes.IndexByte(buf.Bytes(), 0) 333 | if index < 0 { 334 | panic(NewBsonError("unexpected EOF")) 335 | } 336 | // Read including null termination, but 337 | // return the string without the null. 338 | return hack.String(Next(buf, index+1)[:index]) 339 | } 340 | 341 | // Next returns the next n bytes from buf. 342 | func Next(buf *bytes.Buffer, n int) []byte { 343 | b := buf.Next(n) 344 | if len(b) != n { 345 | panic(NewBsonError("unexpected EOF")) 346 | } 347 | return b[:n:n] 348 | } 349 | 350 | // NextByte returns the next byte from buf. 
351 | func NextByte(buf *bytes.Buffer) byte { 352 | return Next(buf, 1)[0] 353 | } 354 | -------------------------------------------------------------------------------- /parser/dependency/bytes2/chunked_writer.go: -------------------------------------------------------------------------------- 1 | // Copyright 2012, Google Inc. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | 5 | // Package bytes2 provides alternate implementations of functionality similar 6 | // to go's bytes package. 7 | package bytes2 8 | 9 | import ( 10 | "fmt" 11 | "io" 12 | "unicode/utf8" 13 | 14 | "github.com/wgliang/pgproxy/parser/dependency/hack" 15 | ) 16 | 17 | // ChunkedWriter has the same interface as bytes.Buffer's write functions. 18 | // It additionally provides a Reserve function that returns a []byte that 19 | // the caller can directly change. 20 | type ChunkedWriter struct { 21 | bufs [][]byte 22 | } 23 | 24 | func NewChunkedWriter(chunkSize int) *ChunkedWriter { 25 | cw := &ChunkedWriter{make([][]byte, 1)} 26 | cw.bufs[0] = make([]byte, 0, chunkSize) 27 | return cw 28 | } 29 | 30 | // Bytes This function can get expensive for large buffers. 31 | func (cw *ChunkedWriter) Bytes() (b []byte) { 32 | if len(cw.bufs) == 1 { 33 | return cw.bufs[0] 34 | } 35 | b = make([]byte, 0, cw.Len()) 36 | for _, buf := range cw.bufs { 37 | b = append(b, buf...) 
38 | } 39 | return b 40 | } 41 | 42 | func (cw *ChunkedWriter) Len() int { 43 | l := 0 44 | for _, buf := range cw.bufs { 45 | l += len(buf) 46 | } 47 | return l 48 | } 49 | 50 | func (cw *ChunkedWriter) Reset() { 51 | b := cw.bufs[0][:0] 52 | cw.bufs = make([][]byte, 1) 53 | cw.bufs[0] = b 54 | } 55 | 56 | func (cw *ChunkedWriter) Truncate(n int) { 57 | for i, buf := range cw.bufs { 58 | if n > len(buf) { 59 | n -= len(buf) 60 | continue 61 | } 62 | cw.bufs[i] = buf[:n] 63 | cw.bufs = cw.bufs[:i+1] 64 | return 65 | } 66 | panic("bytes.ChunkedBuffer: truncation out of range") 67 | } 68 | 69 | func (cw *ChunkedWriter) Write(p []byte) (n int, err error) { 70 | return cw.WriteString(hack.String(p)) 71 | } 72 | 73 | func (cw *ChunkedWriter) WriteString(p string) (n int, err error) { 74 | n = len(p) 75 | lastbuf := cw.bufs[len(cw.bufs)-1] 76 | for { 77 | available := cap(lastbuf) - len(lastbuf) 78 | required := len(p) 79 | if available >= required { 80 | cw.bufs[len(cw.bufs)-1] = append(lastbuf, p...) 81 | return 82 | } 83 | cw.bufs[len(cw.bufs)-1] = append(lastbuf, p[:available]...) 
84 | p = p[available:] 85 | lastbuf = make([]byte, 0, cap(cw.bufs[0])) 86 | cw.bufs = append(cw.bufs, lastbuf) 87 | } 88 | } 89 | 90 | func (cw *ChunkedWriter) Reserve(n int) (b []byte) { 91 | if n > cap(cw.bufs[0]) { 92 | panic(fmt.Sprintf("bytes.ChunkedBuffer: Reserve request too high: %d > %d", n, cap(cw.bufs[0]))) 93 | } 94 | lastbuf := cw.bufs[len(cw.bufs)-1] 95 | if n > cap(lastbuf)-len(lastbuf) { 96 | b = make([]byte, n, cap(cw.bufs[0])) 97 | cw.bufs = append(cw.bufs, b) 98 | return b 99 | } 100 | l := len(lastbuf) 101 | b = lastbuf[l : n+l] 102 | cw.bufs[len(cw.bufs)-1] = lastbuf[:n+l] 103 | return b 104 | } 105 | 106 | func (cw *ChunkedWriter) WriteByte(c byte) error { 107 | cw.Reserve(1)[0] = c 108 | return nil 109 | } 110 | 111 | func (cw *ChunkedWriter) WriteRune(r rune) (n int, err error) { 112 | n = utf8.EncodeRune(cw.Reserve(utf8.RuneLen(r)), r) 113 | return n, nil 114 | } 115 | 116 | func (cw *ChunkedWriter) WriteTo(w io.Writer) (n int64, err error) { 117 | for _, buf := range cw.bufs { 118 | m, err := w.Write(buf) 119 | n += int64(m) 120 | if err != nil { 121 | return n, err 122 | } 123 | if m != len(buf) { 124 | return n, io.ErrShortWrite 125 | } 126 | } 127 | cw.Reset() 128 | return n, nil 129 | } 130 | -------------------------------------------------------------------------------- /parser/dependency/bytes2/cw_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2012, Google Inc. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 
4 | 5 | // Package bytes2 gives you alternate implementations of functionality 6 | // similar to go's bytes package 7 | 8 | package bytes2 9 | 10 | import ( 11 | "testing" 12 | ) 13 | 14 | func TestWrite(t *testing.T) { 15 | cw := NewChunkedWriter(5) 16 | cw.Write([]byte("1234")) 17 | if string(cw.Bytes()) != "1234" { 18 | t.Errorf("Expecting 1234, received %s", cw.Bytes()) 19 | } 20 | cw.WriteString("56") 21 | if string(cw.Bytes()) != "123456" { 22 | t.Errorf("Expecting 123456, received %s", cw.Bytes()) 23 | } 24 | if cw.Len() != 6 { 25 | t.Errorf("Expecting 6, received %d", cw.Len()) 26 | } 27 | } 28 | 29 | func TestTruncate(t *testing.T) { 30 | cw := NewChunkedWriter(3) 31 | cw.WriteString("123456789") 32 | cw.Truncate(8) 33 | if string(cw.Bytes()) != "12345678" { 34 | t.Errorf("Expecting 12345678, received %s", cw.Bytes()) 35 | } 36 | cw.Truncate(5) 37 | if string(cw.Bytes()) != "12345" { 38 | t.Errorf("Expecting 12345, received %s", cw.Bytes()) 39 | } 40 | cw.Truncate(2) 41 | if string(cw.Bytes()) != "12" { 42 | t.Errorf("Expecting 12345, received %s", cw.Bytes()) 43 | } 44 | cw.Reset() 45 | if cw.Len() != 0 { 46 | t.Errorf("Expecting 0, received %d", cw.Len()) 47 | } 48 | } 49 | 50 | func TestReserve(t *testing.T) { 51 | cw := NewChunkedWriter(4) 52 | b := cw.Reserve(2) 53 | b[0] = '1' 54 | b[1] = '2' 55 | cw.WriteByte('3') 56 | b = cw.Reserve(2) 57 | b[0] = '4' 58 | b[1] = '5' 59 | if string(cw.Bytes()) != "12345" { 60 | t.Errorf("Expecting 12345, received %s", cw.Bytes()) 61 | } 62 | } 63 | 64 | func TestWriteTo(t *testing.T) { 65 | cw1 := NewChunkedWriter(4) 66 | cw1.WriteString("123456789") 67 | cw2 := NewChunkedWriter(5) 68 | cw1.WriteTo(cw2) 69 | if string(cw2.Bytes()) != "123456789" { 70 | t.Errorf("Expecting 123456789, received %s", cw2.Bytes()) 71 | } 72 | } 73 | -------------------------------------------------------------------------------- /parser/dependency/hack/hack.go: 
-------------------------------------------------------------------------------- 1 | // Copyright 2012, Google Inc. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | 5 | // Package hack gives you some efficient functionality at the cost of 6 | // breaking some Go rules. 7 | package hack 8 | 9 | import ( 10 | "reflect" 11 | "unsafe" 12 | ) 13 | 14 | // StringArena lets you consolidate allocations for a group of strings 15 | // that have similar life length 16 | type StringArena struct { 17 | buf []byte 18 | str string 19 | } 20 | 21 | // NewStringArena creates an arena of the specified size. 22 | func NewStringArena(size int) *StringArena { 23 | sa := &StringArena{buf: make([]byte, 0, size)} 24 | pbytes := (*reflect.SliceHeader)(unsafe.Pointer(&sa.buf)) 25 | pstring := (*reflect.StringHeader)(unsafe.Pointer(&sa.str)) 26 | pstring.Data = pbytes.Data 27 | pstring.Len = pbytes.Cap 28 | return sa 29 | } 30 | 31 | // NewString copies a byte slice into the arena and returns it as a string. 32 | // If the arena is full, it returns a traditional go string. 33 | func (sa *StringArena) NewString(b []byte) string { 34 | if len(sa.buf)+len(b) > cap(sa.buf) { 35 | return string(b) 36 | } 37 | start := len(sa.buf) 38 | sa.buf = append(sa.buf, b...) 39 | return sa.str[start : start+len(b)] 40 | } 41 | 42 | // SpaceLeft returns the amount of space left in the arena. 43 | func (sa *StringArena) SpaceLeft() int { 44 | return cap(sa.buf) - len(sa.buf) 45 | } 46 | 47 | // String force casts a []byte to a string. 
48 | // USE AT YOUR OWN RISK 49 | func String(b []byte) (s string) { 50 | pbytes := (*reflect.SliceHeader)(unsafe.Pointer(&b)) 51 | pstring := (*reflect.StringHeader)(unsafe.Pointer(&s)) 52 | pstring.Data = pbytes.Data 53 | pstring.Len = pbytes.Len 54 | return 55 | } 56 | 57 | // StringPointer returns &s[0], which is not allowed in go 58 | func StringPointer(s string) unsafe.Pointer { 59 | pstring := (*reflect.StringHeader)(unsafe.Pointer(&s)) 60 | return unsafe.Pointer(pstring.Data) 61 | } 62 | -------------------------------------------------------------------------------- /parser/dependency/hack/hack_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2012, Google Inc. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | 5 | package hack 6 | 7 | import ( 8 | "testing" 9 | ) 10 | 11 | func TestStringArena(t *testing.T) { 12 | sarena := NewStringArena(10) 13 | buf1 := []byte("01234") 14 | buf2 := []byte("5678") 15 | buf3 := []byte("ab") 16 | buf4 := []byte("9") 17 | 18 | s1 := sarena.NewString(buf1) 19 | checkint(t, len(sarena.buf), 5) 20 | checkint(t, sarena.SpaceLeft(), 5) 21 | checkstring(t, s1, "01234") 22 | 23 | s2 := sarena.NewString(buf2) 24 | checkint(t, len(sarena.buf), 9) 25 | checkint(t, sarena.SpaceLeft(), 1) 26 | checkstring(t, s2, "5678") 27 | 28 | // s3 will be allocated outside of sarena 29 | s3 := sarena.NewString(buf3) 30 | checkint(t, len(sarena.buf), 9) 31 | checkint(t, sarena.SpaceLeft(), 1) 32 | checkstring(t, s3, "ab") 33 | 34 | // s4 should still fit in sarena 35 | s4 := sarena.NewString(buf4) 36 | checkint(t, len(sarena.buf), 10) 37 | checkint(t, sarena.SpaceLeft(), 0) 38 | checkstring(t, s4, "9") 39 | 40 | sarena.buf[0] = 'A' 41 | checkstring(t, s1, "A1234") 42 | 43 | sarena.buf[5] = 'B' 44 | checkstring(t, s2, "B678") 45 | 46 | sarena.buf[9] = 'C' 47 | // s3 will not change 48 | checkstring(t, s3, 
"ab") 49 | checkstring(t, s4, "C") 50 | checkstring(t, sarena.str, "A1234B678C") 51 | } 52 | 53 | func checkstring(t *testing.T, actual, expected string) { 54 | if actual != expected { 55 | t.Errorf("received %s, expecting %s", actual, expected) 56 | } 57 | } 58 | 59 | func checkint(t *testing.T, actual, expected int) { 60 | if actual != expected { 61 | t.Errorf("received %d, expecting %d", actual, expected) 62 | } 63 | } 64 | -------------------------------------------------------------------------------- /parser/dependency/sqltypes/sqltypes.go: -------------------------------------------------------------------------------- 1 | // Copyright 2012, Google Inc. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | 5 | // Package sqltypes implements interfaces and types that represent SQL values. 6 | package sqltypes 7 | 8 | import ( 9 | "bytes" 10 | "encoding/base64" 11 | "encoding/gob" 12 | "encoding/json" 13 | "fmt" 14 | "strconv" 15 | "time" 16 | 17 | "github.com/wgliang/pgproxy/parser/dependency/bson" 18 | "github.com/wgliang/pgproxy/parser/dependency/bytes2" 19 | "github.com/wgliang/pgproxy/parser/dependency/hack" 20 | ) 21 | 22 | var ( 23 | NULL = Value{} 24 | DONTESCAPE = byte(255) 25 | nullstr = []byte("null") 26 | ) 27 | 28 | // BinWriter interface is used for encoding values. 29 | // Types like bytes.Buffer conform to this interface. 30 | // We expect the writer objects to be in-memory buffers. 31 | // So, we don't expect the write operations to fail. 32 | type BinWriter interface { 33 | Write([]byte) (int, error) 34 | WriteByte(byte) error 35 | } 36 | 37 | // Value can store any SQL value. NULL is stored as nil. 38 | type Value struct { 39 | Inner InnerValue 40 | } 41 | 42 | // Numeric represents non-fractional SQL number. 
43 | type Numeric []byte 44 | 45 | // Fractional represents fractional types like float and decimal 46 | // It's functionally equivalent to Numeric other than how it's constructed 47 | type Fractional []byte 48 | 49 | // String represents any SQL type that needs to be represented using quotes. 50 | type String []byte 51 | 52 | // MakeNumeric makes a Numeric from a []byte without validation. 53 | func MakeNumeric(b []byte) Value { 54 | return Value{Numeric(b)} 55 | } 56 | 57 | // MakeFractional makes a Fractional value from a []byte without validation. 58 | func MakeFractional(b []byte) Value { 59 | return Value{Fractional(b)} 60 | } 61 | 62 | // MakeString makes a String value from a []byte. 63 | func MakeString(b []byte) Value { 64 | return Value{String(b)} 65 | } 66 | 67 | // Raw returns the raw bytes. All types are currently implemented as []byte. 68 | func (v Value) Raw() []byte { 69 | if v.Inner == nil { 70 | return nil 71 | } 72 | return v.Inner.raw() 73 | } 74 | 75 | // String returns the raw value as a string 76 | func (v Value) String() string { 77 | if v.Inner == nil { 78 | return "" 79 | } 80 | return hack.String(v.Inner.raw()) 81 | } 82 | 83 | // ParseInt64 will parse a Numeric value into an int64 84 | func (v Value) ParseInt64() (val int64, err error) { 85 | if v.Inner == nil { 86 | return 0, fmt.Errorf("value is null") 87 | } 88 | n, ok := v.Inner.(Numeric) 89 | if !ok { 90 | return 0, fmt.Errorf("value is not Numeric") 91 | } 92 | return strconv.ParseInt(string(n.raw()), 10, 64) 93 | } 94 | 95 | // ParseUint64 will parse a Numeric value into a uint64 96 | func (v Value) ParseUint64() (val uint64, err error) { 97 | if v.Inner == nil { 98 | return 0, fmt.Errorf("value is null") 99 | } 100 | n, ok := v.Inner.(Numeric) 101 | if !ok { 102 | return 0, fmt.Errorf("value is not Numeric") 103 | } 104 | return strconv.ParseUint(string(n.raw()), 10, 64) 105 | } 106 | 107 | // EncodeSql encodes the value into an SQL statement. Can be binary. 
108 | func (v Value) EncodeSql(b BinWriter) { 109 | if v.Inner == nil { 110 | if _, err := b.Write(nullstr); err != nil { 111 | panic(err) 112 | } 113 | } else { 114 | v.Inner.encodeSql(b) 115 | } 116 | } 117 | 118 | // EncodeAscii encodes the value using 7-bit clean ascii bytes. 119 | func (v Value) EncodeAscii(b BinWriter) { 120 | if v.Inner == nil { 121 | if _, err := b.Write(nullstr); err != nil { 122 | panic(err) 123 | } 124 | } else { 125 | v.Inner.encodeAscii(b) 126 | } 127 | } 128 | 129 | func (v Value) MarshalBson(buf *bytes2.ChunkedWriter, key string) { 130 | if key == "" { 131 | lenWriter := bson.NewLenWriter(buf) 132 | defer lenWriter.Close() 133 | key = bson.MAGICTAG 134 | } 135 | if v.IsNull() { 136 | bson.EncodePrefix(buf, bson.Null, key) 137 | } else { 138 | bson.EncodeBinary(buf, key, v.Raw()) 139 | } 140 | } 141 | 142 | func (v *Value) UnmarshalBson(buf *bytes.Buffer, kind byte) { 143 | if kind == bson.EOO { 144 | bson.Next(buf, 4) 145 | kind = bson.NextByte(buf) 146 | bson.ReadCString(buf) 147 | } 148 | if kind != bson.Null { 149 | *v = MakeString(bson.DecodeBinary(buf, kind)) 150 | } 151 | } 152 | 153 | func (v Value) IsNull() bool { 154 | return v.Inner == nil 155 | } 156 | 157 | func (v Value) IsNumeric() (ok bool) { 158 | if v.Inner != nil { 159 | _, ok = v.Inner.(Numeric) 160 | } 161 | return ok 162 | } 163 | 164 | func (v Value) IsFractional() (ok bool) { 165 | if v.Inner != nil { 166 | _, ok = v.Inner.(Fractional) 167 | } 168 | return ok 169 | } 170 | 171 | func (v Value) IsString() (ok bool) { 172 | if v.Inner != nil { 173 | _, ok = v.Inner.(String) 174 | } 175 | return ok 176 | } 177 | 178 | // MarshalJSON should only be used for testing. 179 | // It's not a complete implementation. 180 | func (v Value) MarshalJSON() ([]byte, error) { 181 | return json.Marshal(v.Inner) 182 | } 183 | 184 | // UnmarshalJSON should only be used for testing. 185 | // It's not a complete implementation. 
186 | func (v *Value) UnmarshalJSON(b []byte) error { 187 | if len(b) == 0 { 188 | return fmt.Errorf("error unmarshaling empty bytes") 189 | } 190 | var val interface{} 191 | var err error 192 | switch b[0] { 193 | case '-': 194 | var ival int64 195 | err = json.Unmarshal(b, &ival) 196 | val = ival 197 | case '"': 198 | var bval []byte 199 | err = json.Unmarshal(b, &bval) 200 | val = bval 201 | case 'n': // null 202 | err = json.Unmarshal(b, &val) 203 | default: 204 | var uval uint64 205 | err = json.Unmarshal(b, &uval) 206 | val = uval 207 | } 208 | if err != nil { 209 | return err 210 | } 211 | *v, err = BuildValue(val) 212 | return err 213 | } 214 | 215 | // InnerValue defines methods that need to be supported by all non-null value types. 216 | type InnerValue interface { 217 | raw() []byte 218 | encodeSql(BinWriter) 219 | encodeAscii(BinWriter) 220 | } 221 | 222 | func BuildValue(goval interface{}) (v Value, err error) { 223 | switch bindVal := goval.(type) { 224 | case nil: 225 | // no op 226 | case int: 227 | v = Value{Numeric(strconv.AppendInt(nil, int64(bindVal), 10))} 228 | case int32: 229 | v = Value{Numeric(strconv.AppendInt(nil, int64(bindVal), 10))} 230 | case int64: 231 | v = Value{Numeric(strconv.AppendInt(nil, int64(bindVal), 10))} 232 | case uint: 233 | v = Value{Numeric(strconv.AppendUint(nil, uint64(bindVal), 10))} 234 | case uint32: 235 | v = Value{Numeric(strconv.AppendUint(nil, uint64(bindVal), 10))} 236 | case uint64: 237 | v = Value{Numeric(strconv.AppendUint(nil, uint64(bindVal), 10))} 238 | case float64: 239 | v = Value{Fractional(strconv.AppendFloat(nil, bindVal, 'f', -1, 64))} 240 | case string: 241 | v = Value{String([]byte(bindVal))} 242 | case []byte: 243 | v = Value{String(bindVal)} 244 | case time.Time: 245 | v = Value{String([]byte(bindVal.Format("2006-01-02 15:04:05")))} 246 | case Numeric, Fractional, String: 247 | v = Value{bindVal.(InnerValue)} 248 | case Value: 249 | v = bindVal 250 | default: 251 | return Value{}, 
fmt.Errorf("unsupported bind variable type %T: %v", goval, goval) 252 | } 253 | return v, nil 254 | } 255 | 256 | // BuildNumeric builds a Numeric type that represents any whole number. 257 | // It normalizes the representation to ensure 1:1 mapping between the 258 | // number and its representation. 259 | func BuildNumeric(val string) (n Value, err error) { 260 | if val[0] == '-' || val[0] == '+' { 261 | signed, err := strconv.ParseInt(val, 0, 64) 262 | if err != nil { 263 | return Value{}, err 264 | } 265 | n = Value{Numeric(strconv.AppendInt(nil, signed, 10))} 266 | } else { 267 | unsigned, err := strconv.ParseUint(val, 0, 64) 268 | if err != nil { 269 | return Value{}, err 270 | } 271 | n = Value{Numeric(strconv.AppendUint(nil, unsigned, 10))} 272 | } 273 | return n, nil 274 | } 275 | 276 | func (n Numeric) raw() []byte { 277 | return []byte(n) 278 | } 279 | 280 | func (n Numeric) encodeSql(b BinWriter) { 281 | if _, err := b.Write(n.raw()); err != nil { 282 | panic(err) 283 | } 284 | } 285 | 286 | func (n Numeric) encodeAscii(b BinWriter) { 287 | if _, err := b.Write(n.raw()); err != nil { 288 | panic(err) 289 | } 290 | } 291 | 292 | func (n Numeric) MarshalJSON() ([]byte, error) { 293 | return n.raw(), nil 294 | } 295 | 296 | func (f Fractional) raw() []byte { 297 | return []byte(f) 298 | } 299 | 300 | func (f Fractional) encodeSql(b BinWriter) { 301 | if _, err := b.Write(f.raw()); err != nil { 302 | panic(err) 303 | } 304 | } 305 | 306 | func (f Fractional) encodeAscii(b BinWriter) { 307 | if _, err := b.Write(f.raw()); err != nil { 308 | panic(err) 309 | } 310 | } 311 | 312 | func (s String) raw() []byte { 313 | return []byte(s) 314 | } 315 | 316 | func (s String) encodeSql(b BinWriter) { 317 | writebyte(b, '\'') 318 | for _, ch := range s.raw() { 319 | if encodedChar := SqlEncodeMap[ch]; encodedChar == DONTESCAPE { 320 | writebyte(b, ch) 321 | } else { 322 | writebyte(b, '\\') 323 | writebyte(b, encodedChar) 324 | } 325 | } 326 | writebyte(b, '\'') 327 | 
} 328 | 329 | func (s String) encodeAscii(b BinWriter) { 330 | writebyte(b, '\'') 331 | encoder := base64.NewEncoder(base64.StdEncoding, b) 332 | encoder.Write(s.raw()) 333 | encoder.Close() 334 | writebyte(b, '\'') 335 | } 336 | 337 | func writebyte(b BinWriter, c byte) { 338 | if err := b.WriteByte(c); err != nil { 339 | panic(err) 340 | } 341 | } 342 | 343 | // SqlEncodeMap specifies how to escape binary data with '\'. 344 | // Complies to http://dev.mysql.com/doc/refman/5.1/en/string-syntax.html 345 | var SqlEncodeMap [256]byte 346 | 347 | // SqlDecodeMap is the reverse of SqlEncodeMap 348 | var SqlDecodeMap [256]byte 349 | 350 | var encodeRef = map[byte]byte{ 351 | '\x00': '0', 352 | '\'': '\'', 353 | '"': '"', 354 | '\b': 'b', 355 | '\n': 'n', 356 | '\r': 'r', 357 | '\t': 't', 358 | 26: 'Z', // ctl-Z 359 | '\\': '\\', 360 | } 361 | 362 | func init() { 363 | for i := range SqlEncodeMap { 364 | SqlEncodeMap[i] = DONTESCAPE 365 | SqlDecodeMap[i] = DONTESCAPE 366 | } 367 | for i := range SqlEncodeMap { 368 | if to, ok := encodeRef[byte(i)]; ok { 369 | SqlEncodeMap[byte(i)] = to 370 | SqlDecodeMap[to] = byte(i) 371 | } 372 | } 373 | gob.Register(Numeric(nil)) 374 | gob.Register(Fractional(nil)) 375 | gob.Register(String(nil)) 376 | } 377 | -------------------------------------------------------------------------------- /parser/dependency/sqltypes/type_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2012, Google Inc. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 
4 | 5 | package sqltypes 6 | 7 | import ( 8 | "bytes" 9 | "testing" 10 | "time" 11 | ) 12 | 13 | func TestNull(t *testing.T) { 14 | n := Value{} 15 | if !n.IsNull() { 16 | t.Errorf("value is not null") 17 | } 18 | if n.String() != "" { 19 | t.Errorf("Expecting '', got %s", n.String()) 20 | } 21 | b := bytes.NewBuffer(nil) 22 | n.EncodeSql(b) 23 | if b.String() != "null" { 24 | t.Errorf("Expecting null, got %s", b.String()) 25 | } 26 | n.EncodeAscii(b) 27 | if b.String() != "nullnull" { 28 | t.Errorf("Expecting nullnull, got %s", b.String()) 29 | } 30 | js, err := n.MarshalJSON() 31 | if err != nil { 32 | t.Errorf("Unexpected error: %s", err) 33 | } 34 | if string(js) != "null" { 35 | t.Errorf("Expecting null, received %s", js) 36 | } 37 | } 38 | 39 | func TestNumeric(t *testing.T) { 40 | n := Value{Numeric([]byte("1234"))} 41 | b := bytes.NewBuffer(nil) 42 | n.EncodeSql(b) 43 | if b.String() != "1234" { 44 | t.Errorf("Expecting 1234, got %s", b.String()) 45 | } 46 | n.EncodeAscii(b) 47 | if b.String() != "12341234" { 48 | t.Errorf("Expecting 12341234, got %s", b.String()) 49 | } 50 | js, err := n.MarshalJSON() 51 | if err != nil { 52 | t.Errorf("Unexpected error: %s", err) 53 | } 54 | if string(js) != "1234" { 55 | t.Errorf("Expecting 1234, received %s", js) 56 | } 57 | } 58 | 59 | func TestTime(t *testing.T) { 60 | date := time.Date(1999, 1, 2, 3, 4, 5, 0, time.UTC) 61 | v, _ := BuildValue(date) 62 | if !v.IsString() || v.String() != "1999-01-02 03:04:05" { 63 | t.Errorf("Expecting 1999-01-02 03:04:05, got %s", v.String()) 64 | } 65 | 66 | b := &bytes.Buffer{} 67 | v.EncodeSql(b) 68 | if b.String() != "'1999-01-02 03:04:05'" { 69 | t.Errorf("Expecting '1999-01-02 03:04:05', got %s", b.String()) 70 | } 71 | } 72 | 73 | const ( 74 | INVALIDNEG = "-9223372036854775809" 75 | MINNEG = "-9223372036854775808" 76 | MAXPOS = "18446744073709551615" 77 | INVALIDPOS = "18446744073709551616" 78 | NEGFLOAT = "1.234" 79 | POSFLOAT = "-1.234" 80 | ) 81 | 82 | func 
TestBuildNumeric(t *testing.T) { 83 | var n Value 84 | var err error 85 | n, err = BuildNumeric(MINNEG) 86 | if err != nil { 87 | t.Errorf("Unexpected error: %s", err) 88 | } 89 | if n.String() != MINNEG { 90 | t.Errorf("Expecting %v, received %s", MINNEG, n.Raw()) 91 | } 92 | n, err = BuildNumeric(MAXPOS) 93 | if err != nil { 94 | t.Errorf("Unexpected error: %s", err) 95 | } 96 | if n.String() != MAXPOS { 97 | t.Errorf("Expecting %v, received %s", MAXPOS, n.Raw()) 98 | } 99 | n, err = BuildNumeric("0xA") 100 | if err != nil { 101 | t.Errorf("Unexpected error: %s", err) 102 | } 103 | if n.String() != "10" { 104 | t.Errorf("Expecting %v, received %s", 10, n.Raw()) 105 | } 106 | n, err = BuildNumeric("012") 107 | if err != nil { 108 | t.Errorf("Unexpected error: %s", err) 109 | } 110 | if string(n.Raw()) != "10" { 111 | t.Errorf("Expecting %v, received %s", 10, n.Raw()) 112 | } 113 | if n, err = BuildNumeric(INVALIDNEG); err == nil { 114 | t.Errorf("Expecting error") 115 | } 116 | if n, err = BuildNumeric(INVALIDPOS); err == nil { 117 | t.Errorf("Expecting error") 118 | } 119 | if n, err = BuildNumeric(NEGFLOAT); err == nil { 120 | t.Errorf("Expecting error") 121 | } 122 | if n, err = BuildNumeric(POSFLOAT); err == nil { 123 | t.Errorf("Expecting error") 124 | } 125 | } 126 | 127 | const ( 128 | HARDSQL = "\x00'\"\b\n\r\t\x1A\\" 129 | HARDESCAPED = "'\\0\\'\\\"\\b\\n\\r\\t\\Z\\\\'" 130 | HARDASCII = "'ACciCAoNCRpc'" 131 | ) 132 | 133 | func TestString(t *testing.T) { 134 | s := Value{String([]byte(HARDSQL))} 135 | b := bytes.NewBuffer(nil) 136 | s.EncodeSql(b) 137 | if b.String() != HARDESCAPED { 138 | t.Errorf("Expecting %s, received %s", HARDESCAPED, b.String()) 139 | } 140 | b = bytes.NewBuffer(nil) 141 | s.EncodeAscii(b) 142 | if b.String() != HARDASCII { 143 | t.Errorf("Expecting %s, received %#v", HARDASCII, b.String()) 144 | } 145 | s = Value{String([]byte("abcd"))} 146 | js, err := s.MarshalJSON() 147 | if err != nil { 148 | t.Errorf("Unexpected error: %s", 
err) 149 | } 150 | if string(js) != "\"YWJjZA==\"" { 151 | t.Errorf("Expecting \"YWJjZA==\", received %s", js) 152 | } 153 | } 154 | 155 | func TestBuildValue(t *testing.T) { 156 | v, err := BuildValue(nil) 157 | if err != nil { 158 | t.Errorf("%v", err) 159 | } 160 | if !v.IsNull() { 161 | t.Errorf("Expecting null") 162 | } 163 | n64, err := v.ParseUint64() 164 | if err == nil || err.Error() != "value is null" { 165 | t.Errorf("%v", err) 166 | } 167 | v, err = BuildValue(int(-1)) 168 | if err != nil { 169 | t.Errorf("%v", err) 170 | } 171 | if !v.IsNumeric() || v.String() != "-1" { 172 | t.Errorf("Expecting -1, received %T: %s", v.Inner, v.String()) 173 | } 174 | v, err = BuildValue(int32(-1)) 175 | if err != nil { 176 | t.Errorf("%v", err) 177 | } 178 | if !v.IsNumeric() || v.String() != "-1" { 179 | t.Errorf("Expecting -1, received %T: %s", v.Inner, v.String()) 180 | } 181 | v, err = BuildValue(int64(-1)) 182 | if err != nil { 183 | t.Errorf("%v", err) 184 | } 185 | if !v.IsNumeric() || v.String() != "-1" { 186 | t.Errorf("Expecting -1, received %T: %s", v.Inner, v.String()) 187 | } 188 | n64, err = v.ParseUint64() 189 | if err == nil { 190 | t.Errorf("-1 shouldn't convert into uint64") 191 | } 192 | i64, err := v.ParseInt64() 193 | if i64 != -1 { 194 | t.Errorf("want -1, got %d", i64) 195 | } 196 | if err != nil { 197 | t.Errorf("%v", err) 198 | } 199 | v, err = BuildValue(uint(1)) 200 | if err != nil { 201 | t.Errorf("%v", err) 202 | } 203 | if !v.IsNumeric() || v.String() != "1" { 204 | t.Errorf("Expecting 1, received %T: %s", v.Inner, v.String()) 205 | } 206 | v, err = BuildValue(uint32(1)) 207 | if err != nil { 208 | t.Errorf("%v", err) 209 | } 210 | if !v.IsNumeric() || v.String() != "1" { 211 | t.Errorf("Expecting 1, received %T: %s", v.Inner, v.String()) 212 | } 213 | v, err = BuildValue(uint64(1)) 214 | if err != nil { 215 | t.Errorf("%v", err) 216 | } 217 | n64, err = v.ParseUint64() 218 | if err != nil { 219 | t.Errorf("%v", err) 220 | } 221 | if n64 
!= 1 { 222 | t.Errorf("Expecting 1, got %v", n64) 223 | } 224 | if !v.IsNumeric() || v.String() != "1" { 225 | t.Errorf("Expecting 1, received %T: %s", v.Inner, v.String()) 226 | } 227 | v, err = BuildValue(1.23) 228 | if err != nil { 229 | t.Errorf("%v", err) 230 | } 231 | if !v.IsFractional() || v.String() != "1.23" { 232 | t.Errorf("Expecting 1.23, received %T: %s", v.Inner, v.String()) 233 | } 234 | n64, err = v.ParseUint64() 235 | if err == nil { 236 | t.Errorf("1.23 shouldn't convert into uint64") 237 | } 238 | v, err = BuildValue("abcd") 239 | if err != nil { 240 | t.Errorf("%v", err) 241 | } 242 | if !v.IsString() || v.String() != "abcd" { 243 | t.Errorf("Expecting abcd, received %T: %s", v.Inner, v.String()) 244 | } 245 | v, err = BuildValue([]byte("abcd")) 246 | if err != nil { 247 | t.Errorf("%v", err) 248 | } 249 | if !v.IsString() || v.String() != "abcd" { 250 | t.Errorf("Expecting abcd, received %T: %s", v.Inner, v.String()) 251 | } 252 | n64, err = v.ParseUint64() 253 | if err == nil || err.Error() != "value is not Numeric" { 254 | t.Errorf("%v", err) 255 | } 256 | v, err = BuildValue(time.Date(2012, time.February, 24, 23, 19, 43, 10, time.UTC)) 257 | if err != nil { 258 | t.Errorf("%v", err) 259 | } 260 | if !v.IsString() || v.String() != "2012-02-24 23:19:43" { 261 | t.Errorf("Expecting 2012-02-24 23:19:43, received %T: %s", v.Inner, v.String()) 262 | } 263 | v, err = BuildValue(Numeric([]byte("123"))) 264 | if err != nil { 265 | t.Errorf("%v", err) 266 | } 267 | if !v.IsNumeric() || v.String() != "123" { 268 | t.Errorf("Expecting 123, received %T: %s", v.Inner, v.String()) 269 | } 270 | v, err = BuildValue(Fractional([]byte("12.3"))) 271 | if err != nil { 272 | t.Errorf("%v", err) 273 | } 274 | if !v.IsFractional() || v.String() != "12.3" { 275 | t.Errorf("Expecting 12.3, received %T: %s", v.Inner, v.String()) 276 | } 277 | v, err = BuildValue(String([]byte("abc"))) 278 | if err != nil { 279 | t.Errorf("%v", err) 280 | } 281 | if !v.IsString() || 
v.String() != "abc" { 282 | t.Errorf("Expecting abc, received %T: %s", v.Inner, v.String()) 283 | } 284 | v, err = BuildValue(float32(1.23)) 285 | if err == nil { 286 | t.Errorf("Did not receive error") 287 | } 288 | v1 := Value{String([]byte("ab"))} 289 | v, err = BuildValue(v1) 290 | if err != nil { 291 | t.Errorf("%v", err) 292 | } 293 | if !v.IsString() || v.String() != "ab" { 294 | t.Errorf("Expecting ab, received %T: %s", v.Inner, v.String()) 295 | } 296 | v, err = BuildValue(float32(1.23)) 297 | if err == nil { 298 | t.Errorf("Did not receive error") 299 | } 300 | } 301 | 302 | // Ensure DONTESCAPE is not escaped 303 | func TestEncode(t *testing.T) { 304 | if SqlEncodeMap[DONTESCAPE] != DONTESCAPE { 305 | t.Errorf("Encode fail: %v", SqlEncodeMap[DONTESCAPE]) 306 | } 307 | if SqlDecodeMap[DONTESCAPE] != DONTESCAPE { 308 | t.Errorf("Decode fail: %v", SqlDecodeMap[DONTESCAPE]) 309 | } 310 | } 311 | -------------------------------------------------------------------------------- /parser/filter.go: -------------------------------------------------------------------------------- 1 | // Copyright 2017 wgliang. All rights reserved. 2 | // Use of this source code is governed by Apache 3 | // license that can be found in the LICENSE file. 4 | 5 | // Package parser provides filtering rules if you need. 6 | package parser 7 | 8 | import ( 9 | "fmt" 10 | "strings" 11 | 12 | "github.com/golang/glog" 13 | ) 14 | 15 | // Callback function from proxy to postgresql for rewrite 16 | // request or sql. 17 | type Callback func(get []byte) bool 18 | 19 | // Extracte sql statement from string 20 | func Extracte(str []byte) string { 21 | return string(str)[5:] 22 | } 23 | 24 | // ReWrite SQL test 25 | func ReWriteSQL(str []byte) []byte { 26 | return append(str[0:5], []byte(strings.Replace(Extracte(str), "20", "10", -1))...) 
27 | } 28 | 29 | // GetQueryModificada calllback 30 | func GetQueryModificada(queryOriginal string) string { 31 | if queryOriginal[:5] != "power" { 32 | 33 | return queryOriginal 34 | } 35 | return "select * from clientes limit 1;" 36 | } 37 | 38 | func Filter(str []byte) bool { 39 | sql := Extracte(str) 40 | tree, err := Parse(sql) 41 | if err != nil { 42 | glog.Errorln(err) 43 | return false 44 | } 45 | 46 | switch tree.(type) { 47 | case *Select: 48 | return ParseSelect(tree.(*Select)) 49 | case *Delete: 50 | return ParseDelete(tree.(*Delete)) 51 | case *Insert: 52 | return ParseInsert(tree.(*Insert)) 53 | case *Update: 54 | return ParseUpdate(tree.(*Update)) 55 | } 56 | return false 57 | } 58 | 59 | func Return(str []byte) bool { 60 | fmt.Println(string(str)) 61 | return true 62 | } 63 | 64 | func ParseSelect(sql *Select) bool { 65 | return !Is_SELECT_ALL(sql) && !Is_ORDER_BY_RAND(sql) 66 | } 67 | 68 | func Is_SELECT_ALL(sql *Select) bool { 69 | buf := NewTrackedBuffer(nil) 70 | sql.SelectExprs.Format(buf) 71 | if "*" == buf.String() { 72 | return true 73 | } 74 | return false 75 | } 76 | 77 | func Is_ORDER_BY_RAND(sql *Select) bool { 78 | buf := NewTrackedBuffer(nil) 79 | sql.OrderBy.Format(buf) 80 | if "rand()" == strings.ToLower(buf.String()) { 81 | return true 82 | } 83 | return false 84 | } 85 | 86 | func ParseDelete(sql *Delete) bool { 87 | return !Is_BIG_DELETE(sql) 88 | } 89 | 90 | func Is_BIG_DELETE(sql *Delete) bool { 91 | buf := NewTrackedBuffer(nil) 92 | sql.Limit.Format(buf) 93 | if "1000" < buf.String() { 94 | return true 95 | } 96 | return false 97 | } 98 | 99 | func ParseInsert(sql *Insert) bool { 100 | return !Is_BIG_INSERT(sql) 101 | } 102 | 103 | func Is_BIG_INSERT(sql *Insert) bool { 104 | buf := NewTrackedBuffer(nil) 105 | sql.Rows.Format(buf) 106 | if "1000" < buf.String() { 107 | return true 108 | } 109 | return false 110 | } 111 | 112 | func ParseUpdate(sql *Update) bool { 113 | return true 114 | } 115 | 
-------------------------------------------------------------------------------- /parser/parse_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2012, Google Inc. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | 5 | package parser 6 | 7 | import ( 8 | "fmt" 9 | "github.com/stretchr/testify/assert" 10 | "testing" 11 | ) 12 | 13 | func TestGen(t *testing.T) { 14 | _, err := Parse("select :a from a where a in (:b)") 15 | if err != nil { 16 | t.Error(err) 17 | } 18 | } 19 | 20 | func TestParse(t *testing.T) { 21 | sql := "select a from (select * from table1 where table1.a = 'tom') as t1, table2, table3 as t3, table4 left join table5 where t1.k = '1'" 22 | _, err := Parse(sql) 23 | if err != nil { 24 | t.Fatal(err) 25 | } 26 | } 27 | 28 | func TestParseInsert(t *testing.T) { 29 | sql := "INSERT INTO t3 VALUES (8, 10, 'baz')" 30 | _, err := Parse(sql) 31 | assert.Nil(t, err) 32 | } 33 | 34 | func TestCreatTable1(t *testing.T) { 35 | sql := `create table t1 ( 36 | ID int primary key, 37 | LastName varchar(255), 38 | FirstName varchar(255) 39 | )` 40 | tree, err := Parse(sql) 41 | if err != nil { 42 | t.Fatal(err) 43 | } 44 | s := String(tree) 45 | 46 | assert.Equal(t, sql, s) 47 | } 48 | 49 | func TestCreatTable2(t *testing.T) { 50 | sql := `create table t1 ( 51 | ID int primary key not null auto_increment, 52 | LastName varchar(255), 53 | FirstName varchar(255) 54 | )` 55 | tree, err := Parse(sql) 56 | if err != nil { 57 | t.Fatal(err) 58 | } 59 | s := String(tree) 60 | 61 | assert.Equal(t, sql, s) 62 | } 63 | 64 | func TestCreatTable3(t *testing.T) { 65 | sql := `create table t1 ( 66 | ID int unique key not null auto_increment, 67 | LastName varchar(255), 68 | FirstName varchar(255) 69 | )` 70 | tree, err := Parse(sql) 71 | if err != nil { 72 | t.Fatal(err) 73 | } 74 | s := String(tree) 75 | 76 | assert.Equal(t, sql, s) 77 | } 
78 | 79 | func TestCreatTable4(t *testing.T) { 80 | for_precision := []string{"real", "double", "float", "decimal", "numeric"} 81 | for _, p := range for_precision { 82 | precision := "(32, 8)" 83 | for i := 0; i < 2; i++ { 84 | data_type := p 85 | if i == 0 { 86 | data_type += precision 87 | } 88 | sql := fmt.Sprintf(`create table t1 ( 89 | ID int unique key not null auto_increment, 90 | LastName varchar(255), 91 | FirstName varchar(255), 92 | Balance %s%s 93 | )`, p, precision) 94 | tree, err := Parse(sql) 95 | assert.Nil(t, err) 96 | s := String(tree) 97 | 98 | assert.Equal(t, sql, s) 99 | } 100 | } 101 | 102 | for_length := []string{"bit", "tinyint", "smallint", "mediumint", "int", "integer", "bigint", "decimal", "numeric"} 103 | for _, p := range for_length { 104 | length := "(32)" 105 | for i := 0; i < 2; i++ { 106 | data_type := p 107 | if i == 0 { 108 | data_type += length 109 | } 110 | sql := fmt.Sprintf(`create table t1 ( 111 | ID int unique key not null auto_increment, 112 | LastName varchar(255), 113 | FirstName varchar(255), 114 | Balance %s%s 115 | )`, p, length) 116 | tree, err := Parse(sql) 117 | assert.Nil(t, err) 118 | s := String(tree) 119 | 120 | assert.Equal(t, sql, s) 121 | } 122 | } 123 | } 124 | 125 | func TestCreatTable5(t *testing.T) { 126 | for_time := []string{"date", "time", "timestamp", "datetime", "year"} 127 | for _, time := range for_time { 128 | sql := fmt.Sprintf(`create table t1 ( 129 | ID int unique key not null auto_increment, 130 | LastName varchar(255), 131 | FirstName varchar(255), 132 | LastUpdated %s 133 | )`, time) 134 | tree, err := Parse(sql) 135 | assert.Nil(t, err, "fail to parse:\n%s", sql) 136 | s := String(tree) 137 | assert.Equal(t, sql, s) 138 | } 139 | } 140 | 141 | func BenchmarkParse1(b *testing.B) { 142 | sql := "select 'abcd', 20, 30.0, eid from a where 1=eid and name='3'" 143 | for i := 0; i < b.N; i++ { 144 | _, err := Parse(sql) 145 | if err != nil { 146 | b.Fatal(err) 147 | } 148 | } 149 | } 150 | 151 | 
func BenchmarkParse2(b *testing.B) { 152 | sql := "select aaaa, bbb, ccc, ddd, eeee, ffff, gggg, hhhh, iiii from tttt, ttt1, ttt3 where aaaa = bbbb and bbbb = cccc and dddd+1 = eeee group by fff, gggg having hhhh = iiii and iiii = jjjj order by kkkk, llll limit 3, 4" 153 | for i := 0; i < b.N; i++ { 154 | _, err := Parse(sql) 155 | if err != nil { 156 | b.Fatal(err) 157 | } 158 | } 159 | } 160 | 161 | type testCase struct { 162 | file string 163 | lineno int 164 | input string 165 | output string 166 | } 167 | -------------------------------------------------------------------------------- /parser/parsed_query.go: -------------------------------------------------------------------------------- 1 | // Copyright 2012, Google Inc. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | 5 | package parser 6 | 7 | import ( 8 | "bytes" 9 | "encoding/json" 10 | "errors" 11 | "fmt" 12 | 13 | "github.com/wgliang/pgproxy/parser/dependency/sqltypes" 14 | ) 15 | 16 | type bindLocation struct { 17 | offset, length int 18 | } 19 | 20 | type ParsedQuery struct { 21 | Query string 22 | bindLocations []bindLocation 23 | } 24 | 25 | type EncoderFunc func(value interface{}) ([]byte, error) 26 | 27 | func (pq *ParsedQuery) GenerateQuery(bindVariables map[string]interface{}) ([]byte, error) { 28 | if len(pq.bindLocations) == 0 { 29 | return []byte(pq.Query), nil 30 | } 31 | buf := bytes.NewBuffer(make([]byte, 0, len(pq.Query))) 32 | current := 0 33 | for _, loc := range pq.bindLocations { 34 | buf.WriteString(pq.Query[current:loc.offset]) 35 | name := pq.Query[loc.offset : loc.offset+loc.length] 36 | supplied, _, err := FetchBindVar(name, bindVariables) 37 | if err != nil { 38 | return nil, err 39 | } 40 | if err := EncodeValue(buf, supplied); err != nil { 41 | return nil, err 42 | } 43 | current = loc.offset + loc.length 44 | } 45 | buf.WriteString(pq.Query[current:]) 46 | return buf.Bytes(), nil 47 | } 
48 | 49 | func (pq *ParsedQuery) MarshalJSON() ([]byte, error) { 50 | return json.Marshal(pq.Query) 51 | } 52 | 53 | func EncodeValue(buf *bytes.Buffer, value interface{}) error { 54 | switch bindVal := value.(type) { 55 | case nil: 56 | buf.WriteString("null") 57 | case []sqltypes.Value: 58 | for i := 0; i < len(bindVal); i++ { 59 | if i != 0 { 60 | buf.WriteString(", ") 61 | } 62 | if err := EncodeValue(buf, bindVal[i]); err != nil { 63 | return err 64 | } 65 | } 66 | case [][]sqltypes.Value: 67 | for i := 0; i < len(bindVal); i++ { 68 | if i != 0 { 69 | buf.WriteString(", ") 70 | } 71 | buf.WriteByte('(') 72 | if err := EncodeValue(buf, bindVal[i]); err != nil { 73 | return err 74 | } 75 | buf.WriteByte(')') 76 | } 77 | case []interface{}: 78 | buf.WriteByte('(') 79 | for i, v := range bindVal { 80 | if i != 0 { 81 | buf.WriteString(", ") 82 | } 83 | if err := EncodeValue(buf, v); err != nil { 84 | return err 85 | } 86 | } 87 | buf.WriteByte(')') 88 | case TupleEqualityList: 89 | if err := bindVal.Encode(buf); err != nil { 90 | return err 91 | } 92 | default: 93 | v, err := sqltypes.BuildValue(bindVal) 94 | if err != nil { 95 | return err 96 | } 97 | v.EncodeSql(buf) 98 | } 99 | return nil 100 | } 101 | 102 | type TupleEqualityList struct { 103 | Columns []string 104 | Rows [][]sqltypes.Value 105 | } 106 | 107 | func (tpl *TupleEqualityList) Encode(buf *bytes.Buffer) error { 108 | if len(tpl.Rows) == 0 { 109 | return errors.New("cannot encode with 0 rows") 110 | } 111 | if len(tpl.Columns) == 1 { 112 | return tpl.encodeAsIN(buf) 113 | } 114 | return tpl.encodeAsEquality(buf) 115 | } 116 | 117 | func (tpl *TupleEqualityList) encodeAsIN(buf *bytes.Buffer) error { 118 | buf.WriteString(tpl.Columns[0]) 119 | buf.WriteString(" in (") 120 | for i, r := range tpl.Rows { 121 | if len(r) != 1 { 122 | return errors.New("values don't match column count") 123 | } 124 | if i != 0 { 125 | buf.WriteString(", ") 126 | } 127 | if err := EncodeValue(buf, r); err != nil { 128 | 
return err 129 | } 130 | } 131 | buf.WriteByte(')') 132 | return nil 133 | } 134 | 135 | func (tpl *TupleEqualityList) encodeAsEquality(buf *bytes.Buffer) error { 136 | for i, r := range tpl.Rows { 137 | if i != 0 { 138 | buf.WriteString(" or ") 139 | } 140 | buf.WriteString("(") 141 | for j, c := range tpl.Columns { 142 | if j != 0 { 143 | buf.WriteString(" and ") 144 | } 145 | buf.WriteString(c) 146 | buf.WriteString(" = ") 147 | if err := EncodeValue(buf, r[j]); err != nil { 148 | return err 149 | } 150 | } 151 | buf.WriteByte(')') 152 | } 153 | return nil 154 | } 155 | 156 | func FetchBindVar(name string, bindVariables map[string]interface{}) (val interface{}, isList bool, err error) { 157 | name = name[1:] 158 | if name[0] == ':' { 159 | name = name[1:] 160 | isList = true 161 | } 162 | supplied, ok := bindVariables[name] 163 | if !ok { 164 | return nil, false, fmt.Errorf("missing bind var %s", name) 165 | } 166 | list, gotList := supplied.([]interface{}) 167 | if isList { 168 | if !gotList { 169 | return nil, false, fmt.Errorf("unexpected list arg type %T for key %s", supplied, name) 170 | } 171 | if len(list) == 0 { 172 | return nil, false, fmt.Errorf("empty list supplied for %s", name) 173 | } 174 | return list, true, nil 175 | } 176 | if gotList { 177 | return nil, false, fmt.Errorf("unexpected arg type %T for key %s", supplied, name) 178 | } 179 | return supplied, false, nil 180 | } 181 | -------------------------------------------------------------------------------- /parser/parsed_query_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2012, Google Inc. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 
4 | 5 | package parser 6 | 7 | import ( 8 | "testing" 9 | 10 | "github.com/wgliang/pgproxy/parser/dependency/sqltypes" 11 | ) 12 | 13 | func TestParsedQuery(t *testing.T) { 14 | tcases := []struct { 15 | desc string 16 | query string 17 | bindVars map[string]interface{} 18 | output string 19 | }{ 20 | { 21 | "no subs", 22 | "select * from a where id = 2", 23 | map[string]interface{}{ 24 | "id": 1, 25 | }, 26 | "select * from a where id = 2", 27 | }, { 28 | "simple bindvar sub", 29 | "select * from a where id1 = :id1 and id2 = :id2", 30 | map[string]interface{}{ 31 | "id1": 1, 32 | "id2": nil, 33 | }, 34 | "select * from a where id1 = 1 and id2 = null", 35 | }, { 36 | "missing bind var", 37 | "select * from a where id1 = :id1 and id2 = :id2", 38 | map[string]interface{}{ 39 | "id1": 1, 40 | }, 41 | "missing bind var id2", 42 | }, { 43 | "unencodable bind var", 44 | "select * from a where id1 = :id", 45 | map[string]interface{}{ 46 | "id": make([]int, 1), 47 | }, 48 | "unsupported bind variable type []int: [0]", 49 | }, { 50 | "list inside bind vars", 51 | "select * from a where id in (:vals)", 52 | map[string]interface{}{ 53 | "vals": []sqltypes.Value{ 54 | sqltypes.MakeNumeric([]byte("1")), 55 | sqltypes.MakeString([]byte("aa")), 56 | }, 57 | }, 58 | "select * from a where id in (1, 'aa')", 59 | }, { 60 | "two lists inside bind vars", 61 | "select * from a where id in (:vals)", 62 | map[string]interface{}{ 63 | "vals": [][]sqltypes.Value{ 64 | []sqltypes.Value{ 65 | sqltypes.MakeNumeric([]byte("1")), 66 | sqltypes.MakeString([]byte("aa")), 67 | }, 68 | []sqltypes.Value{ 69 | sqltypes.Value{}, 70 | sqltypes.MakeString([]byte("bb")), 71 | }, 72 | }, 73 | }, 74 | "select * from a where id in ((1, 'aa'), (null, 'bb'))", 75 | }, { 76 | "list bind vars", 77 | "select * from a where id in ::vals", 78 | map[string]interface{}{ 79 | "vals": []interface{}{ 80 | 1, 81 | "aa", 82 | }, 83 | }, 84 | "select * from a where id in (1, 'aa')", 85 | }, { 86 | "list bind vars single 
argument", 87 | "select * from a where id in ::vals", 88 | map[string]interface{}{ 89 | "vals": []interface{}{ 90 | 1, 91 | }, 92 | }, 93 | "select * from a where id in (1)", 94 | }, { 95 | "list bind vars 0 arguments", 96 | "select * from a where id in ::vals", 97 | map[string]interface{}{ 98 | "vals": []interface{}{}, 99 | }, 100 | "empty list supplied for vals", 101 | }, { 102 | "non-list bind var supplied", 103 | "select * from a where id in ::vals", 104 | map[string]interface{}{ 105 | "vals": 1, 106 | }, 107 | "unexpected list arg type int for key vals", 108 | }, { 109 | "list bind var for non-list", 110 | "select * from a where id = :vals", 111 | map[string]interface{}{ 112 | "vals": []interface{}{1}, 113 | }, 114 | "unexpected arg type []interface {} for key vals", 115 | }, { 116 | "single column tuple equality", 117 | // We have to use an incorrect construct to get around the parser. 118 | "select * from a where b = :equality", 119 | map[string]interface{}{ 120 | "equality": TupleEqualityList{ 121 | Columns: []string{"pk"}, 122 | Rows: [][]sqltypes.Value{ 123 | []sqltypes.Value{sqltypes.MakeNumeric([]byte("1"))}, 124 | []sqltypes.Value{sqltypes.MakeString([]byte("aa"))}, 125 | }, 126 | }, 127 | }, 128 | "select * from a where b = pk in (1, 'aa')", 129 | }, { 130 | "multi column tuple equality", 131 | "select * from a where b = :equality", 132 | map[string]interface{}{ 133 | "equality": TupleEqualityList{ 134 | Columns: []string{"pk1", "pk2"}, 135 | Rows: [][]sqltypes.Value{ 136 | []sqltypes.Value{ 137 | sqltypes.MakeNumeric([]byte("1")), 138 | sqltypes.MakeString([]byte("aa")), 139 | }, 140 | []sqltypes.Value{ 141 | sqltypes.MakeNumeric([]byte("2")), 142 | sqltypes.MakeString([]byte("bb")), 143 | }, 144 | }, 145 | }, 146 | }, 147 | "select * from a where b = (pk1 = 1 and pk2 = 'aa') or (pk1 = 2 and pk2 = 'bb')", 148 | }, { 149 | "0 rows", 150 | "select * from a where b = :equality", 151 | map[string]interface{}{ 152 | "equality": TupleEqualityList{ 153 | 
Columns: []string{"pk"}, 154 | Rows: [][]sqltypes.Value{}, 155 | }, 156 | }, 157 | "cannot encode with 0 rows", 158 | }, { 159 | "values don't match column count", 160 | "select * from a where b = :equality", 161 | map[string]interface{}{ 162 | "equality": TupleEqualityList{ 163 | Columns: []string{"pk"}, 164 | Rows: [][]sqltypes.Value{ 165 | []sqltypes.Value{ 166 | sqltypes.MakeNumeric([]byte("1")), 167 | sqltypes.MakeString([]byte("aa")), 168 | }, 169 | }, 170 | }, 171 | }, 172 | "values don't match column count", 173 | }, 174 | } 175 | 176 | for _, tcase := range tcases { 177 | tree, err := Parse(tcase.query) 178 | if err != nil { 179 | t.Errorf("parse failed for %s: %v", tcase.desc, err) 180 | continue 181 | } 182 | buf := NewTrackedBuffer(nil) 183 | buf.Myprintf("%v", tree) 184 | pq := buf.ParsedQuery() 185 | bytes, err := pq.GenerateQuery(tcase.bindVars) 186 | var got string 187 | if err != nil { 188 | got = err.Error() 189 | } else { 190 | got = string(bytes) 191 | } 192 | if got != tcase.output { 193 | t.Errorf("for test case: %s, got: '%s', want '%s'", tcase.desc, got, tcase.output) 194 | } 195 | } 196 | } 197 | -------------------------------------------------------------------------------- /parser/rewriter.go: -------------------------------------------------------------------------------- 1 | package parser 2 | 3 | import ( 4 | _ "fmt" 5 | "reflect" 6 | ) 7 | 8 | var typeOfBytes = reflect.TypeOf([]byte(nil)) 9 | var typeOfSQLNode = reflect.TypeOf((*SQLNode)(nil)).Elem() 10 | 11 | type Rewriter func([]byte) []byte 12 | 13 | func Rewrite(node SQLNode, rewriter Rewriter) { 14 | rewrite(reflect.ValueOf(node), rewriter) 15 | } 16 | func rewrite(nodeVal reflect.Value, rewriter Rewriter) { 17 | if !nodeVal.IsValid() { 18 | return 19 | } 20 | nodeTyp := nodeVal.Type() 21 | switch nodeTyp.Kind() { 22 | case reflect.Slice: 23 | if nodeTyp == typeOfBytes && !nodeVal.IsNil() { 24 | val := rewriter(nodeVal.Bytes()) //use rewriter to rewrite the bytes 25 | 
nodeVal.SetBytes(val) 26 | } else if nodeTyp.Implements(typeOfSQLNode) { 27 | for i := 0; i < nodeVal.Len(); i++ { 28 | m := nodeVal.Index(i) 29 | rewrite(m, rewriter) 30 | } 31 | } 32 | case reflect.Struct: 33 | for i := 0; i < nodeVal.NumField(); i++ { 34 | f := nodeVal.Field(i) 35 | rewrite(f, rewriter) 36 | } 37 | case reflect.Ptr, reflect.Interface: 38 | rewrite(nodeVal.Elem(), rewriter) 39 | } 40 | } 41 | -------------------------------------------------------------------------------- /parser/rewriter_test.go: -------------------------------------------------------------------------------- 1 | package parser 2 | 3 | import ( 4 | "fmt" 5 | "github.com/stretchr/testify/assert" 6 | "testing" 7 | ) 8 | 9 | func TestRewriteQuery(t *testing.T) { 10 | sql := "select distinct table1.* from table1 as t1" 11 | tree, _ := Parse(sql) 12 | 13 | rewriter := func(origin []byte) []byte { 14 | s := string(origin) 15 | if s == "table1" { 16 | s = fmt.Sprintf("%s%s%s", "_", s, "_") 17 | } 18 | return []byte(s) 19 | } 20 | 21 | Rewrite(tree, rewriter) 22 | 23 | expected := "select distinct _table1_.* from _table1_ as t1" 24 | actual := String(tree) 25 | 26 | assert.Equal(t, expected, actual) 27 | } 28 | -------------------------------------------------------------------------------- /parser/token.go: -------------------------------------------------------------------------------- 1 | // Copyright 2012, Google Inc. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | 5 | package parser 6 | 7 | import ( 8 | "bytes" 9 | "fmt" 10 | "strings" 11 | 12 | "github.com/wgliang/pgproxy/parser/dependency/sqltypes" 13 | ) 14 | 15 | const EOFCHAR = 0x100 16 | 17 | // Tokenizer is the struct used to generate SQL 18 | // tokens for the parser. 
19 | type Tokenizer struct { 20 | InStream *strings.Reader 21 | AllowComments bool 22 | ForceEOF bool 23 | lastChar uint16 24 | Position int 25 | errorToken []byte 26 | LastError string 27 | posVarIndex int 28 | ParseTree Statement 29 | } 30 | 31 | // NewStringTokenizer creates a new Tokenizer for the 32 | // sql string. 33 | func NewStringTokenizer(sql string) *Tokenizer { 34 | return &Tokenizer{InStream: strings.NewReader(sql)} 35 | } 36 | 37 | var keywords = map[string]int{ 38 | "all": ALL, 39 | "alter": ALTER, 40 | "analyze": ANALYZE, 41 | "and": AND, 42 | "as": AS, 43 | "asc": ASC, 44 | "between": BETWEEN, 45 | "by": BY, 46 | "case": CASE, 47 | "create": CREATE, 48 | "cross": CROSS, 49 | "default": DEFAULT, 50 | "delete": DELETE, 51 | "desc": DESC, 52 | "describe": DESCRIBE, 53 | "distinct": DISTINCT, 54 | "drop": DROP, 55 | "duplicate": DUPLICATE, 56 | "else": ELSE, 57 | "end": END, 58 | "except": EXCEPT, 59 | "exists": EXISTS, 60 | "explain": EXPLAIN, 61 | "for": FOR, 62 | "force": FORCE, 63 | "from": FROM, 64 | "group": GROUP, 65 | "having": HAVING, 66 | "if": IF, 67 | "ignore": IGNORE, 68 | "in": IN, 69 | "index": INDEX, 70 | "inner": INNER, 71 | "insert": INSERT, 72 | "intersect": INTERSECT, 73 | "into": INTO, 74 | "is": IS, 75 | "join": JOIN, 76 | "key": KEY, 77 | "left": LEFT, 78 | "like": LIKE, 79 | "limit": LIMIT, 80 | "lock": LOCK, 81 | "minus": MINUS, 82 | "natural": NATURAL, 83 | "not": NOT, 84 | "null": NULL, 85 | "on": ON, 86 | "or": OR, 87 | "order": ORDER, 88 | "outer": OUTER, 89 | "rename": RENAME, 90 | "right": RIGHT, 91 | "select": SELECT, 92 | "set": SET, 93 | "show": SHOW, 94 | "straight_join": STRAIGHT_JOIN, 95 | "table": TABLE, 96 | "then": THEN, 97 | "to": TO, 98 | "union": UNION, 99 | "unique": UNIQUE, 100 | "update": UPDATE, 101 | "use": USE, 102 | "using": USING, 103 | "values": VALUES, 104 | "view": VIEW, 105 | "when": WHEN, 106 | "where": WHERE, 107 | 108 | //keywords for creat table 109 | 110 | //datatypes 111 | "bit": BIT, 112 | 
"tinyint": TINYINT, 113 | "smallint": SMALLINT, 114 | "mediumint": MEDIUMINT, 115 | "int": INT, 116 | "integer": INTEGER, 117 | "bigint": BIGINT, 118 | "real": REAL, 119 | "double": DOUBLE, 120 | "float": FLOAT, 121 | "decimal": DECIMAL, 122 | "numeric": NUMERIC, 123 | 124 | "char": CHAR, 125 | "varchar": VARCHAR, 126 | "text": TEXT, 127 | 128 | "date": DATE, 129 | "time": TIME, 130 | "timestamp": TIMESTAMP, 131 | "datetime": DATETIME, 132 | "year": YEAR, 133 | 134 | //other keywords 135 | "unsigned": UNSIGNED, 136 | "zerofill": ZEROFILL, 137 | "primary": PRIMARY, 138 | "auto_increment": AUTO_INCREMENT, 139 | } 140 | 141 | // Lex returns the next token form the Tokenizer. 142 | // This function is used by go yacc. 143 | func (tkn *Tokenizer) Lex(lval *yySymType) int { 144 | typ, val := tkn.Scan() 145 | for typ == COMMENT { 146 | if tkn.AllowComments { 147 | break 148 | } 149 | typ, val = tkn.Scan() 150 | } 151 | switch typ { 152 | case ID, STRING, NUMBER, VALUE_ARG, LIST_ARG, COMMENT: 153 | lval.bytes = val 154 | } 155 | tkn.errorToken = val 156 | return typ 157 | } 158 | 159 | // Error is called by go yacc if there's a parsing error. 160 | func (tkn *Tokenizer) Error(err string) { 161 | buf := bytes.NewBuffer(make([]byte, 0, 32)) 162 | if tkn.errorToken != nil { 163 | fmt.Fprintf(buf, "%s at position %v near %s", err, tkn.Position, tkn.errorToken) 164 | } else { 165 | fmt.Fprintf(buf, "%s at position %v", err, tkn.Position) 166 | } 167 | tkn.LastError = buf.String() 168 | } 169 | 170 | // Scan scans the tokenizer for the next token and returns 171 | // the token type and an optional value. 
172 | func (tkn *Tokenizer) Scan() (int, []byte) { 173 | if tkn.ForceEOF { 174 | return 0, nil 175 | } 176 | 177 | if tkn.lastChar == 0 { 178 | tkn.next() 179 | } 180 | tkn.skipBlank() 181 | switch ch := tkn.lastChar; { 182 | case isLetter(ch): 183 | return tkn.scanIdentifier() 184 | case isDigit(ch): 185 | return tkn.scanNumber(false) 186 | case ch == ':': 187 | return tkn.scanBindVar() 188 | default: 189 | tkn.next() 190 | switch ch { 191 | case EOFCHAR: 192 | return 0, nil 193 | case '=', ',', ';', '(', ')', '+', '*', '%', '&', '|', '^', '~': 194 | return int(ch), nil 195 | case '?': 196 | tkn.posVarIndex++ 197 | buf := new(bytes.Buffer) 198 | fmt.Fprintf(buf, ":v%d", tkn.posVarIndex) 199 | return VALUE_ARG, buf.Bytes() 200 | case '.': 201 | if isDigit(tkn.lastChar) { 202 | return tkn.scanNumber(true) 203 | } else { 204 | return int(ch), nil 205 | } 206 | case '/': 207 | switch tkn.lastChar { 208 | case '/': 209 | tkn.next() 210 | return tkn.scanCommentType1("//") 211 | case '*': 212 | tkn.next() 213 | return tkn.scanCommentType2() 214 | default: 215 | return int(ch), nil 216 | } 217 | case '-': 218 | if tkn.lastChar == '-' { 219 | tkn.next() 220 | return tkn.scanCommentType1("--") 221 | } else { 222 | return int(ch), nil 223 | } 224 | case '<': 225 | switch tkn.lastChar { 226 | case '>': 227 | tkn.next() 228 | return NE, nil 229 | case '=': 230 | tkn.next() 231 | switch tkn.lastChar { 232 | case '>': 233 | tkn.next() 234 | return NULL_SAFE_EQUAL, nil 235 | default: 236 | return LE, nil 237 | } 238 | default: 239 | return int(ch), nil 240 | } 241 | case '>': 242 | if tkn.lastChar == '=' { 243 | tkn.next() 244 | return GE, nil 245 | } else { 246 | return int(ch), nil 247 | } 248 | case '!': 249 | if tkn.lastChar == '=' { 250 | tkn.next() 251 | return NE, nil 252 | } else { 253 | return LEX_ERROR, []byte("!") 254 | } 255 | case '\'', '"': 256 | return tkn.scanString(ch, STRING) 257 | case '`': 258 | return tkn.scanLiteralIdentifier() 259 | default: 260 | return 
LEX_ERROR, []byte{byte(ch)} 261 | } 262 | } 263 | } 264 | 265 | func (tkn *Tokenizer) skipBlank() { 266 | ch := tkn.lastChar 267 | for ch == ' ' || ch == '\n' || ch == '\r' || ch == '\t' { 268 | tkn.next() 269 | ch = tkn.lastChar 270 | } 271 | } 272 | 273 | func (tkn *Tokenizer) scanIdentifier() (int, []byte) { 274 | buffer := bytes.NewBuffer(make([]byte, 0, 8)) 275 | buffer.WriteByte(byte(tkn.lastChar)) 276 | for tkn.next(); isLetter(tkn.lastChar) || isDigit(tkn.lastChar); tkn.next() { 277 | buffer.WriteByte(byte(tkn.lastChar)) 278 | } 279 | lowered := bytes.ToLower(buffer.Bytes()) 280 | if keywordId, found := keywords[string(lowered)]; found { 281 | return keywordId, lowered 282 | } 283 | return ID, buffer.Bytes() 284 | } 285 | 286 | func (tkn *Tokenizer) scanLiteralIdentifier() (int, []byte) { 287 | buffer := bytes.NewBuffer(make([]byte, 0, 8)) 288 | buffer.WriteByte(byte(tkn.lastChar)) 289 | if !isLetter(tkn.lastChar) { 290 | return LEX_ERROR, buffer.Bytes() 291 | } 292 | for tkn.next(); isLetter(tkn.lastChar) || isDigit(tkn.lastChar); tkn.next() { 293 | buffer.WriteByte(byte(tkn.lastChar)) 294 | } 295 | if tkn.lastChar != '`' { 296 | return LEX_ERROR, buffer.Bytes() 297 | } 298 | tkn.next() 299 | return ID, buffer.Bytes() 300 | } 301 | 302 | func (tkn *Tokenizer) scanBindVar() (int, []byte) { 303 | buffer := bytes.NewBuffer(make([]byte, 0, 8)) 304 | buffer.WriteByte(byte(tkn.lastChar)) 305 | token := VALUE_ARG 306 | tkn.next() 307 | if tkn.lastChar == ':' { 308 | token = LIST_ARG 309 | buffer.WriteByte(byte(tkn.lastChar)) 310 | tkn.next() 311 | } 312 | if !isLetter(tkn.lastChar) { 313 | return LEX_ERROR, buffer.Bytes() 314 | } 315 | for isLetter(tkn.lastChar) || isDigit(tkn.lastChar) || tkn.lastChar == '.' 
{ 316 | buffer.WriteByte(byte(tkn.lastChar)) 317 | tkn.next() 318 | } 319 | return token, buffer.Bytes() 320 | } 321 | 322 | func (tkn *Tokenizer) scanMantissa(base int, buffer *bytes.Buffer) { 323 | for digitVal(tkn.lastChar) < base { 324 | tkn.ConsumeNext(buffer) 325 | } 326 | } 327 | 328 | func (tkn *Tokenizer) scanNumber(seenDecimalPoint bool) (int, []byte) { 329 | buffer := bytes.NewBuffer(make([]byte, 0, 8)) 330 | if seenDecimalPoint { 331 | buffer.WriteByte('.') 332 | tkn.scanMantissa(10, buffer) 333 | goto exponent 334 | } 335 | 336 | if tkn.lastChar == '0' { 337 | // int or float 338 | tkn.ConsumeNext(buffer) 339 | if tkn.lastChar == 'x' || tkn.lastChar == 'X' { 340 | // hexadecimal int 341 | tkn.ConsumeNext(buffer) 342 | tkn.scanMantissa(16, buffer) 343 | } else { 344 | // octal int or float 345 | seenDecimalDigit := false 346 | tkn.scanMantissa(8, buffer) 347 | if tkn.lastChar == '8' || tkn.lastChar == '9' { 348 | // illegal octal int or float 349 | seenDecimalDigit = true 350 | tkn.scanMantissa(10, buffer) 351 | } 352 | if tkn.lastChar == '.' || tkn.lastChar == 'e' || tkn.lastChar == 'E' { 353 | goto fraction 354 | } 355 | // octal int 356 | if seenDecimalDigit { 357 | return LEX_ERROR, buffer.Bytes() 358 | } 359 | } 360 | goto exit 361 | } 362 | 363 | // decimal int or float 364 | tkn.scanMantissa(10, buffer) 365 | 366 | fraction: 367 | if tkn.lastChar == '.' 
{ 368 | tkn.ConsumeNext(buffer) 369 | tkn.scanMantissa(10, buffer) 370 | } 371 | 372 | exponent: 373 | if tkn.lastChar == 'e' || tkn.lastChar == 'E' { 374 | tkn.ConsumeNext(buffer) 375 | if tkn.lastChar == '+' || tkn.lastChar == '-' { 376 | tkn.ConsumeNext(buffer) 377 | } 378 | tkn.scanMantissa(10, buffer) 379 | } 380 | 381 | exit: 382 | return NUMBER, buffer.Bytes() 383 | } 384 | 385 | func (tkn *Tokenizer) scanString(delim uint16, typ int) (int, []byte) { 386 | buffer := bytes.NewBuffer(make([]byte, 0, 8)) 387 | for { 388 | ch := tkn.lastChar 389 | tkn.next() 390 | if ch == delim { 391 | if tkn.lastChar == delim { 392 | tkn.next() 393 | } else { 394 | break 395 | } 396 | } else if ch == '\\' { 397 | if tkn.lastChar == EOFCHAR { 398 | return LEX_ERROR, buffer.Bytes() 399 | } 400 | if decodedChar := sqltypes.SqlDecodeMap[byte(tkn.lastChar)]; decodedChar == sqltypes.DONTESCAPE { 401 | ch = tkn.lastChar 402 | } else { 403 | ch = uint16(decodedChar) 404 | } 405 | tkn.next() 406 | } 407 | if ch == EOFCHAR { 408 | return LEX_ERROR, buffer.Bytes() 409 | } 410 | buffer.WriteByte(byte(ch)) 411 | } 412 | return typ, buffer.Bytes() 413 | } 414 | 415 | func (tkn *Tokenizer) scanCommentType1(prefix string) (int, []byte) { 416 | buffer := bytes.NewBuffer(make([]byte, 0, 8)) 417 | buffer.WriteString(prefix) 418 | for tkn.lastChar != EOFCHAR { 419 | if tkn.lastChar == '\n' { 420 | tkn.ConsumeNext(buffer) 421 | break 422 | } 423 | tkn.ConsumeNext(buffer) 424 | } 425 | return COMMENT, buffer.Bytes() 426 | } 427 | 428 | func (tkn *Tokenizer) scanCommentType2() (int, []byte) { 429 | buffer := bytes.NewBuffer(make([]byte, 0, 8)) 430 | buffer.WriteString("/*") 431 | for { 432 | if tkn.lastChar == '*' { 433 | tkn.ConsumeNext(buffer) 434 | if tkn.lastChar == '/' { 435 | tkn.ConsumeNext(buffer) 436 | break 437 | } 438 | continue 439 | } 440 | if tkn.lastChar == EOFCHAR { 441 | return LEX_ERROR, buffer.Bytes() 442 | } 443 | tkn.ConsumeNext(buffer) 444 | } 445 | return COMMENT, 
buffer.Bytes() 446 | } 447 | 448 | func (tkn *Tokenizer) ConsumeNext(buffer *bytes.Buffer) { 449 | if tkn.lastChar == EOFCHAR { 450 | // This should never happen. 451 | panic("unexpected EOF") 452 | } 453 | buffer.WriteByte(byte(tkn.lastChar)) 454 | tkn.next() 455 | } 456 | 457 | func (tkn *Tokenizer) next() { 458 | if ch, err := tkn.InStream.ReadByte(); err != nil { 459 | // Only EOF is possible. 460 | tkn.lastChar = EOFCHAR 461 | } else { 462 | tkn.lastChar = uint16(ch) 463 | } 464 | tkn.Position++ 465 | } 466 | 467 | func isLetter(ch uint16) bool { 468 | return 'a' <= ch && ch <= 'z' || 'A' <= ch && ch <= 'Z' || ch == '_' || ch == '@' 469 | } 470 | 471 | func digitVal(ch uint16) int { 472 | switch { 473 | case '0' <= ch && ch <= '9': 474 | return int(ch) - '0' 475 | case 'a' <= ch && ch <= 'f': 476 | return int(ch) - 'a' + 10 477 | case 'A' <= ch && ch <= 'F': 478 | return int(ch) - 'A' + 10 479 | } 480 | return 16 // larger than any legal digit val 481 | } 482 | 483 | func isDigit(ch uint16) bool { 484 | return '0' <= ch && ch <= '9' 485 | } 486 | -------------------------------------------------------------------------------- /parser/tracked_buffer.go: -------------------------------------------------------------------------------- 1 | // Copyright 2012, Google Inc. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | 5 | package parser 6 | 7 | import ( 8 | "bytes" 9 | "fmt" 10 | ) 11 | 12 | // TrackedBuffer is used to rebuild a query from the ast. 13 | // bindLocations keeps track of locations in the buffer that 14 | // use bind variables for efficient future substitutions. 15 | // nodeFormatter is the formatting function the buffer will 16 | // use to format a node. By default(nil), it's FormatNode. 17 | // But you can supply a different formatting function if you 18 | // want to generate a query that's different from the default. 
19 | type TrackedBuffer struct { 20 | *bytes.Buffer 21 | bindLocations []bindLocation 22 | nodeFormatter func(buf *TrackedBuffer, node SQLNode) 23 | } 24 | 25 | func NewTrackedBuffer(nodeFormatter func(buf *TrackedBuffer, node SQLNode)) *TrackedBuffer { 26 | buf := &TrackedBuffer{ 27 | Buffer: bytes.NewBuffer(make([]byte, 0, 128)), 28 | bindLocations: make([]bindLocation, 0, 4), 29 | nodeFormatter: nodeFormatter, 30 | } 31 | return buf 32 | } 33 | 34 | // Myprintf mimics fmt.Fprintf(buf, ...), but limited to Node(%v), 35 | // Node.Value(%s) and string(%s). It also allows a %a for a value argument, in 36 | // which case it adds tracking info for future substitutions. 37 | // 38 | // The name must be something other than the usual Printf() to avoid "go vet" 39 | // warnings due to our custom format specifiers. 40 | func (buf *TrackedBuffer) Myprintf(format string, values ...interface{}) { 41 | end := len(format) 42 | fieldnum := 0 43 | for i := 0; i < end; { 44 | lasti := i 45 | for i < end && format[i] != '%' { 46 | i++ 47 | } 48 | if i > lasti { 49 | buf.WriteString(format[lasti:i]) 50 | } 51 | if i >= end { 52 | break 53 | } 54 | i++ // '%' 55 | switch format[i] { 56 | case 'c': 57 | switch v := values[fieldnum].(type) { 58 | case byte: 59 | buf.WriteByte(v) 60 | case rune: 61 | buf.WriteRune(v) 62 | default: 63 | panic(fmt.Sprintf("unexpected type %T", v)) 64 | } 65 | case 's': 66 | switch v := values[fieldnum].(type) { 67 | case []byte: 68 | buf.Write(v) 69 | case string: 70 | buf.WriteString(v) 71 | default: 72 | panic(fmt.Sprintf("unexpected type %T", v)) 73 | } 74 | case 'v': 75 | node := values[fieldnum].(SQLNode) 76 | if buf.nodeFormatter == nil { 77 | node.Format(buf) 78 | } else { 79 | buf.nodeFormatter(buf, node) 80 | } 81 | case 'a': 82 | buf.WriteArg(values[fieldnum].(string)) 83 | default: 84 | panic("unexpected") 85 | } 86 | fieldnum++ 87 | i++ 88 | } 89 | } 90 | 91 | // WriteArg writes a value argument into the buffer. 
arg should not contain 92 | // the ':' prefix. It also adds tracking info for future substitutions. 93 | func (buf *TrackedBuffer) WriteArg(arg string) { 94 | buf.bindLocations = append(buf.bindLocations, bindLocation{ 95 | offset: buf.Len(), 96 | length: len(arg), 97 | }) 98 | buf.WriteString(arg) 99 | } 100 | 101 | func (buf *TrackedBuffer) ParsedQuery() *ParsedQuery { 102 | return &ParsedQuery{Query: buf.String(), bindLocations: buf.bindLocations} 103 | } 104 | 105 | func (buf *TrackedBuffer) HasBindVars() bool { 106 | return len(buf.bindLocations) != 0 107 | } 108 | -------------------------------------------------------------------------------- /pgproxy.conf: -------------------------------------------------------------------------------- 1 | # pgproxy配置文件 2 | 3 | # 默认服务配置 4 | [ServerConfig] 5 | ProxyAddr = "127.0.0.1:9090" # proxy服务地址,接受数据库连接 6 | 7 | # 默认数据库配置 8 | # master 为默认主数据库,slave[N]为从数据库 9 | [DB] 10 | [DB.master] # 需要代理的pg数据库名称 11 | Addr = "10.16.93.15:5432" # 数据库地址 12 | User = "postgres" # 数据库用户名 13 | Password = "Ilove360!" # 数据库密码 14 | DbName = "skylar" # 数据库名称 15 | 16 | [DB.slave1] 17 | Addr = "127.0.0.1:5433" 18 | User = "postgres" 19 | Password = "" 20 | DbName = "db" -------------------------------------------------------------------------------- /pgproxy.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wgliang/pgproxy/e19ae27c28d7d454fad8c80f2f72aca5633a42e3/pgproxy.png -------------------------------------------------------------------------------- /proxy/formate.go: -------------------------------------------------------------------------------- 1 | // Copyright 2017 wgliang. All rights reserved. 2 | // Use of this source code is governed by Apache 3 | // license that can be found in the LICENSE file. 4 | 5 | // Package proxy provides proxy service and redirects requests 6 | // form proxy.Addr to remote.Addr. 
7 | package proxy 8 | 9 | import ( 10 | "database/sql" 11 | "fmt" 12 | "os" 13 | "strconv" 14 | 15 | "github.com/golang/glog" 16 | "github.com/olekukonko/tablewriter" 17 | ) 18 | 19 | // Parse query's results and formate it,then will be print 20 | // in command line such as: 21 | // +---------+----------------+----------+ 22 | // | ID | IP | NAME | 23 | // +---------+----------------+----------+ 24 | // | 1 | 180.17.95.2 | Jack | 25 | // | 2 | 180.17.95.3 | Wong | 26 | // | 3 | 180.17.95.4 | Lin | 27 | // | 4 | 180.17.95.5 | Trump | 28 | // +---------+----------------+----------+ 29 | // else error 30 | func RowsFormater(rows *sql.Rows) { 31 | cols, err := rows.Columns() 32 | if err != nil { 33 | glog.Errorln(err) 34 | } 35 | table := tablewriter.NewWriter(os.Stdout) 36 | table.SetHeader(cols) 37 | data := make([][]string, 1) 38 | count := 0 39 | for rows.Next() { 40 | columns := make([]interface{}, len(cols)) 41 | columnPointers := make([]interface{}, len(cols)) 42 | for i, _ := range columns { 43 | columnPointers[i] = &columns[i] 44 | } 45 | 46 | // Scan the result into the column pointers... 47 | if err := rows.Scan(columnPointers...); err != nil { 48 | fmt.Println(err) 49 | } 50 | 51 | // Create our map, and retrieve the value for each column from the pointers slice, 52 | // storing it in the map with the name of the column as the key. 
53 | row := make([]string, 0) 54 | for i, _ := range cols { 55 | val := columnPointers[i].(*interface{}) 56 | row = append(row, interface2String(*val)) 57 | } 58 | 59 | data = append(data, row) 60 | count = count + 1 61 | } 62 | for _, v := range data { 63 | table.Append(v) 64 | } 65 | table.Render() 66 | if count > 0 { 67 | fmt.Printf("(%d rows of records)\n", count) 68 | } 69 | } 70 | 71 | // Parse exec's results and formate it,then will be print 72 | // in command line such as: 73 | // OK, [n] rows affected 74 | // else error 75 | func ResultFormater(res sql.Result) { 76 | rowsAffected, err := res.RowsAffected() 77 | if err != nil { 78 | fmt.Println(err) 79 | } else { 80 | fmt.Printf("OK, %d rows affected\n", rowsAffected) 81 | } 82 | } 83 | 84 | // Convert type interface{} into string just for friendly display. 85 | func interface2String(input interface{}) string { 86 | switch input.(type) { 87 | case string: 88 | return input.(string) 89 | case int64: 90 | return strconv.FormatInt(input.(int64), 10) 91 | case []byte: 92 | return string(input.([]byte)) 93 | default: 94 | return "" 95 | } 96 | } 97 | -------------------------------------------------------------------------------- /proxy/proxy.go: -------------------------------------------------------------------------------- 1 | // Copyright 2017 wgliang. All rights reserved. 2 | // Use of this source code is governed by Apache 3 | // license that can be found in the LICENSE file. 4 | 5 | // Package proxy provides proxy service and redirects requests 6 | // form proxy.Addr to remote.Addr. 7 | package proxy 8 | 9 | import ( 10 | // "bytes" 11 | // "encoding/binary" 12 | "errors" 13 | "fmt" 14 | "io" 15 | "net" 16 | 17 | "github.com/golang/glog" 18 | "github.com/wgliang/pgproxy/parser" 19 | ) 20 | 21 | var ( 22 | connid = uint64(0) // Self-increasing ConnectID. 
23 | ) 24 | 25 | // Start proxy server needed receive and proxyHost, all 26 | // the request or database's sql of receive will redirect 27 | // to remoteHost. 28 | func Start(proxyHost, remoteHost string, filterCallback, returnCallBack parser.Callback) { 29 | defer glog.Flush() 30 | glog.Infof("Proxying from %v to %v\n", proxyHost, remoteHost) 31 | 32 | proxyAddr := getResolvedAddresses(proxyHost) 33 | remoteAddr := getResolvedAddresses(remoteHost) 34 | listener := getListener(proxyAddr) 35 | 36 | for { 37 | conn, err := listener.AcceptTCP() 38 | if err != nil { 39 | glog.Errorf("Failed to accept connection '%s'\n", err) 40 | continue 41 | } 42 | connid++ 43 | 44 | p := &Proxy{ 45 | lconn: conn, 46 | laddr: proxyAddr, 47 | raddr: remoteAddr, 48 | erred: false, 49 | errsig: make(chan bool), 50 | prefix: fmt.Sprintf("Connection #%03d ", connid), 51 | connId: connid, 52 | } 53 | go p.service(filterCallback, returnCallBack) 54 | } 55 | } 56 | 57 | // ResolvedAddresses of host. 58 | func getResolvedAddresses(host string) *net.TCPAddr { 59 | addr, err := net.ResolveTCPAddr("tcp", host) 60 | if err != nil { 61 | glog.Fatalln("ResolveTCPAddr of host:", err) 62 | } 63 | return addr 64 | } 65 | 66 | // Listener of a net.TCPAddr. 67 | func getListener(addr *net.TCPAddr) *net.TCPListener { 68 | listener, err := net.ListenTCP("tcp", addr) 69 | if err != nil { 70 | glog.Fatalf("ListenTCP of %s error:%v", addr, err) 71 | } 72 | return listener 73 | } 74 | 75 | // Proxy - Manages a Proxy connection, piping data between proxy and remote. 76 | type Proxy struct { 77 | sentBytes uint64 78 | receivedBytes uint64 79 | laddr, raddr *net.TCPAddr 80 | lconn, rconn *net.TCPConn 81 | erred bool 82 | errsig chan bool 83 | prefix string 84 | connId uint64 85 | } 86 | 87 | // New - Create a new Proxy instance. Takes over local connection passed in, 88 | // and closes it when finished. 
89 | func New(conn *net.TCPConn, proxyAddr, remoteAddr *net.TCPAddr, connid uint64) *Proxy { 90 | return &Proxy{ 91 | lconn: conn, 92 | laddr: proxyAddr, 93 | raddr: remoteAddr, 94 | erred: false, 95 | errsig: make(chan bool), 96 | prefix: fmt.Sprintf("Connection #%03d ", connid), 97 | connId: connid, 98 | } 99 | } 100 | 101 | // proxy.err 102 | func (p *Proxy) err(s string, err error) { 103 | if p.erred { 104 | return 105 | } 106 | if err != io.EOF { 107 | glog.Errorf(p.prefix+s, err) 108 | } 109 | p.errsig <- true 110 | p.erred = true 111 | } 112 | 113 | // Proxy.service open connection to remote and service proxying data. 114 | func (p *Proxy) service(filterCallback, returnCallBack parser.Callback) { 115 | defer p.lconn.Close() 116 | // connect to remote server 117 | rconn, err := net.DialTCP("tcp", nil, p.raddr) 118 | if err != nil { 119 | p.err("Remote connection failed: %s", err) 120 | return 121 | } 122 | p.rconn = rconn 123 | defer p.rconn.Close() 124 | // proxying data 125 | go p.handleIncomingConnection(p.lconn, p.rconn, filterCallback) 126 | go p.handleResponseConnection(p.rconn, p.lconn, returnCallBack) 127 | // wait for close... 
128 | <-p.errsig 129 | } 130 | 131 | // Proxy.handleIncomingConnection 132 | func (p *Proxy) handleIncomingConnection(src, dst *net.TCPConn, Callback parser.Callback) { 133 | // directional copy (64k buffer) 134 | buff := make([]byte, 0xffff) 135 | 136 | for { 137 | n, err := src.Read(buff) 138 | if err != nil { 139 | p.err("Read failed '%s'\n", err) 140 | return 141 | } 142 | b, err := getModifiedBuffer(buff[:n], Callback) 143 | if err != nil { 144 | p.err("%s\n", err) 145 | err = dst.Close() 146 | if err != nil { 147 | glog.Errorln(err) 148 | } 149 | return 150 | } 151 | 152 | n, err = dst.Write(b) 153 | if err != nil { 154 | p.err("Write failed '%s'\n", err) 155 | return 156 | } 157 | } 158 | } 159 | 160 | // Proxy.handleResponseConnection 161 | func (p *Proxy) handleResponseConnection(src, dst *net.TCPConn, Callback parser.Callback) { 162 | // directional copy (64k buffer) 163 | buff := make([]byte, 0xffff) 164 | 165 | for { 166 | n, err := src.Read(buff) 167 | if err != nil { 168 | p.err("Read failed '%s'\n", err) 169 | return 170 | } 171 | b := setResponseBuffer(p.erred, buff[:n], Callback) 172 | 173 | n, err = dst.Write(b) 174 | if err != nil { 175 | p.err("Write failed '%s'\n", err) 176 | return 177 | } 178 | } 179 | } 180 | 181 | // ModifiedBuffer when is local and will call filterCallback function 182 | func getModifiedBuffer(buffer []byte, filterCallback parser.Callback) (b []byte, err error) { 183 | if len(buffer) > 0 && string(buffer[0]) == "Q" { 184 | if !filterCallback(buffer) { 185 | return nil, errors.New(fmt.Sprintf("Do not meet the rules of the sql statement %s", string(buffer[1:]))) 186 | } 187 | } 188 | 189 | return buffer, nil 190 | } 191 | 192 | // ResponseBuffer when is local and will call returnCallback function 193 | func setResponseBuffer(iserr bool, buffer []byte, filterCallback parser.Callback) (b []byte) { 194 | if len(buffer) > 0 && string(buffer[0]) == "Q" { 195 | if !filterCallback(buffer) { 196 | return nil 197 | } 198 | } 199 | 
200 | return buffer 201 | } 202 | -------------------------------------------------------------------------------- /proxy/proxy_test.go: -------------------------------------------------------------------------------- 1 | package proxy 2 | 3 | import ( 4 | "fmt" 5 | "net" 6 | "os" 7 | "testing" 8 | "time" 9 | 10 | "github.com/jmoiron/sqlx" 11 | _ "github.com/lib/pq" 12 | "github.com/wgliang/pgproxy/parser" 13 | ) 14 | 15 | var ( 16 | testProxyHost = "127.0.0.1:9090" 17 | testRemoteHost = "127.0.0.1:5432" 18 | ) 19 | 20 | func Benchmark_Start(b *testing.B) { 21 | go Start(testProxyHost, testRemoteHost, parser.GetQueryModificada) 22 | time.Sleep(3 * time.Second) 23 | 24 | db, err := sqlx.Open("postgres", "host=127.0.0.1 user=postgres password=xxxxx dbname=db port=9090 sslmode=disable") 25 | if err != nil { 26 | b.Error(err) 27 | } 28 | db.SetMaxIdleConns(1) 29 | db.SetMaxOpenConns(100) 30 | 31 | for i := 0; i < b.N; i++ { 32 | sql := fmt.Sprintf("select id from client where id = %d", i) 33 | fmt.Println(sql) 34 | rows, err := db.Query(sql) 35 | if err != nil { 36 | b.Error(err) 37 | } else { 38 | for rows.Next() { 39 | var n int 40 | err = rows.Scan(&n) 41 | if err != nil { 42 | b.Error(err) 43 | } else { 44 | if n != i { 45 | b.Errorf("result is not match,n=%d but id=%d", n, i) 46 | } 47 | } 48 | } 49 | } 50 | } 51 | db.Close() 52 | os.Exit(0) 53 | } 54 | 55 | func Test_Start(t *testing.T) { 56 | go Start(testProxyHost, testRemoteHost, parser.GetQueryModificada) 57 | time.Sleep(3 * time.Second) 58 | 59 | db, err := sqlx.Open("postgres", "host=127.0.0.1 user=postgres password=xxxxx dbname=db port=9090 sslmode=disable") 60 | if err != nil { 61 | t.Error(err) 62 | } 63 | db.SetMaxIdleConns(1) 64 | db.SetMaxOpenConns(100) 65 | 66 | rows, err := db.Query("select id from client where id = 8 ") 67 | if err != nil { 68 | t.Error(err) 69 | } else { 70 | for rows.Next() { 71 | var n int32 72 | err = rows.Scan(&n) 73 | if err != nil { 74 | t.Error(err) 75 | } else { 76 | if n 
!= 8 { 77 | t.Errorf("result is not match,n=%d but id=8", n) 78 | } 79 | } 80 | } 81 | } 82 | db.Close() 83 | os.Exit(0) 84 | } 85 | 86 | func Test_getResolvedAddresses(t *testing.T) { 87 | getResolvedAddresses("127.0.0.1:9090", "127.0.0.1:8080") 88 | } 89 | 90 | func Test_getListener(t *testing.T) { 91 | paddr, err := net.ResolveTCPAddr("tcp", "127.0.0.1:9090") 92 | if err != nil { 93 | t.Fatal(err) 94 | } 95 | getListener(paddr) 96 | } 97 | -------------------------------------------------------------------------------- /version: -------------------------------------------------------------------------------- 1 | VERSION = "0.0.1" --------------------------------------------------------------------------------