├── .gitignore ├── LICENSE ├── README.md ├── certs.go ├── cmd └── main │ └── main.go ├── credits.go ├── messageserv.go ├── proxy.go ├── proxyhttp.go ├── proxyhttp_test.go ├── proxylistener.go ├── proxymessages.go ├── schema.go ├── search.go ├── search_test.go ├── signer.go ├── sqlitestorage.go ├── sqlitestorage_test.go ├── storage.go ├── testutil.go ├── util.go └── webui.go /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Rob Glew 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | The Puppy Proxy 2 | =============== 3 | 4 | What is this? 5 | ------------- 6 | Puppy is a Go library that can be used to create proxies that intercept and modify the HTTP and websocket messages passing through them. Puppy itself does not provide an interactive interface; instead, it provides an API for building proxies in Go. If you want a useful tool that uses Puppy, try [Pappy](https://github.com/roglew/pappy-proxy). 7 | 8 | Puppy was originally intended as a starting point for writing a tool similar to [Burp Suite](https://portswigger.net/burp/) and as a base for writing other HTTP proxy software. 9 | 10 | Features 11 | -------- 12 | 13 | * Intercept and modify any HTTP messages passing through the proxy 14 | * Websocket support 15 | * Use a custom CA certificate to strip TLS from HTTPS connections 16 | * Built-in IPC API 17 | * Support for transparent request redirection 18 | * Built-in support for writing messages to a SQLite database 19 | * Flexible history search 20 | 21 | Example 22 | ------- 23 | 24 | The following example creates a simple proxy which listens on port 8080. In order to send HTTPS traffic through the proxy, you must add the generated server.pem certificate to your browser as a trusted CA.
25 | 26 | ```go 27 | package main 28 | 29 | import ( 30 | "fmt" 31 | "net" 32 | "os" 33 | "path" 34 | "puppy" 35 | ) 36 | 37 | func checkerr(err error) { 38 | if err != nil { 39 | panic(err) 40 | } 41 | } 42 | 43 | func main() { 44 | // Create the proxy without a logger 45 | iproxy := puppy.NewInterceptingProxy(nil) 46 | 47 | // Load the CA certs 48 | ex, err := os.Executable() 49 | checkerr(err) 50 | certFile := path.Dir(ex) + "/server.pem" 51 | pkeyFile := path.Dir(ex) + "/server.key" 52 | err = iproxy.LoadCACertificates(certFile, pkeyFile) 53 | if err != nil { 54 | // Try generating the certs in case they're missing 55 | _, err := puppy.GenerateCACertsToDisk(certFile, pkeyFile) 56 | checkerr(err) 57 | err = iproxy.LoadCACertificates(certFile, pkeyFile) 58 | checkerr(err) 59 | } 60 | 61 | // Listen on port 8080 62 | listener, err := net.Listen("tcp", "127.0.0.1:8080") 63 | checkerr(err) 64 | iproxy.AddListener(listener) 65 | 66 | // Wait for exit 67 | fmt.Println("Proxy is running on localhost:8080") 68 | select {} 69 | } 70 | ``` 71 | 72 | Next, we will demonstrate editing messages by turning the proxy into a cloud2butt proxy which will replace every instance of the word "cloud" with the word "butt". 
This is done by writing a function that takes in a request and a response and returns a new response then adding it to the proxy: 73 | 74 | ```go 75 | package main 76 | 77 | import ( 78 | "bytes" 79 | "fmt" 80 | "net" 81 | "os" 82 | "path" 83 | "puppy" 84 | ) 85 | 86 | func checkerr(err error) { 87 | if err != nil { 88 | panic(err) 89 | } 90 | } 91 | 92 | func main() { 93 | // Create the proxy without a logger 94 | iproxy := puppy.NewInterceptingProxy(nil) 95 | 96 | // Load the CA certs 97 | ex, err := os.Executable() 98 | checkerr(err) 99 | certFile := path.Dir(ex) + "/server.pem" 100 | pkeyFile := path.Dir(ex) + "/server.key" 101 | err = iproxy.LoadCACertificates(certFile, pkeyFile) 102 | if err != nil { 103 | // Try generating the certs in case they're missing 104 | _, err := puppy.GenerateCACertsToDisk(certFile, pkeyFile) 105 | checkerr(err) 106 | err = iproxy.LoadCACertificates(certFile, pkeyFile) 107 | checkerr(err) 108 | } 109 | 110 | // Cloud2Butt interceptor 111 | var cloud2butt = func(req *puppy.ProxyRequest, rsp *puppy.ProxyResponse) (*puppy.ProxyResponse, error) { 112 | newBody := rsp.BodyBytes() 113 | newBody = bytes.Replace(newBody, []byte("cloud"), []byte("butt"), -1) 114 | newBody = bytes.Replace(newBody, []byte("Cloud"), []byte("Butt"), -1) 115 | rsp.SetBodyBytes(newBody) 116 | return rsp, nil 117 | } 118 | iproxy.AddRspInterceptor(cloud2butt) 119 | 120 | // Listen on port 8080 121 | listener, err := net.Listen("tcp", "127.0.0.1:8080") 122 | checkerr(err) 123 | iproxy.AddListener(listener) 124 | 125 | // Wait for exit 126 | fmt.Println("Proxy is running on localhost:8080") 127 | select {} 128 | } 129 | ``` 130 | 131 | For more information, check out the documentation. 
-------------------------------------------------------------------------------- /certs.go: -------------------------------------------------------------------------------- 1 | package puppy 2 | 3 | import ( 4 | "crypto/rand" 5 | "crypto/rsa" 6 | "crypto/sha1" 7 | "crypto/x509" 8 | "crypto/x509/pkix" 9 | "encoding/pem" 10 | "fmt" 11 | "math/big" 12 | "os" 13 | "time" 14 | ) 15 | 16 | // A certificate/private key pair 17 | type CAKeyPair struct { 18 | Certificate []byte 19 | PrivateKey *rsa.PrivateKey 20 | } 21 | 22 | func bigIntHash(n *big.Int) []byte { 23 | h := sha1.New() 24 | h.Write(n.Bytes()) 25 | return h.Sum(nil) 26 | } 27 | 28 | // GenerateCACerts generates a random CAKeyPair 29 | func GenerateCACerts() (*CAKeyPair, error) { 30 | key, err := rsa.GenerateKey(rand.Reader, 2048) 31 | if err != nil { 32 | return nil, fmt.Errorf("error generating private key: %s", err.Error()) 33 | } 34 | 35 | serial := new(big.Int) 36 | b := make([]byte, 20) 37 | _, err = rand.Read(b) 38 | if err != nil { 39 | return nil, fmt.Errorf("error generating serial: %s", err.Error()) 40 | } 41 | serial.SetBytes(b) 42 | 43 | end, err := time.Parse("2006-01-02", "2049-12-31") 44 | template := x509.Certificate{ 45 | SerialNumber: serial, 46 | Subject: pkix.Name{ 47 | CommonName: "Puppy Proxy", 48 | Organization: []string{"Puppy Proxy"}, 49 | }, 50 | NotBefore: time.Now().Add(-5 * time.Minute).UTC(), 51 | NotAfter: end, 52 | 53 | SubjectKeyId: bigIntHash(key.N), 54 | KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign | x509.KeyUsageCRLSign, 55 | BasicConstraintsValid: true, 56 | IsCA: true, 57 | MaxPathLenZero: true, 58 | } 59 | 60 | derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &key.PublicKey, key) 61 | if err != nil { 62 | return nil, fmt.Errorf("error generating certificate: %s", err.Error()) 63 | } 64 | 65 | return &CAKeyPair{ 66 | Certificate: derBytes, 67 | PrivateKey: key, 68 | }, nil 69 | } 70 | 71 | // 
Generate a pair of certificates and write them to the disk. Returns the generated keypair 72 | func GenerateCACertsToDisk(CertificateFile string, PrivateKeyFile string) (*CAKeyPair, error) { 73 | pair, err := GenerateCACerts() 74 | if err != nil { 75 | return nil, err 76 | } 77 | 78 | pkeyFile, err := os.OpenFile(PrivateKeyFile, os.O_RDWR|os.O_CREATE, 0600) 79 | if err != nil { 80 | return nil, err 81 | } 82 | pkeyFile.Write(pair.PrivateKeyPEM()) 83 | if err := pkeyFile.Close(); err != nil { 84 | return nil, err 85 | } 86 | 87 | certFile, err := os.OpenFile(CertificateFile, os.O_RDWR|os.O_CREATE, 0600) 88 | if err != nil { 89 | return nil, err 90 | } 91 | 92 | certFile.Write(pair.CACertPEM()) 93 | if err := certFile.Close(); err != nil { 94 | return nil, err 95 | } 96 | 97 | return pair, nil 98 | } 99 | 100 | 101 | // PrivateKeyPEM returns the private key of the CAKeyPair PEM encoded 102 | func (pair *CAKeyPair) PrivateKeyPEM() []byte { 103 | return pem.EncodeToMemory( 104 | &pem.Block{ 105 | Type: "BEGIN PRIVATE KEY", 106 | Bytes: x509.MarshalPKCS1PrivateKey(pair.PrivateKey), 107 | }, 108 | ) 109 | } 110 | 111 | // PrivateKeyPEM returns the CA cert of the CAKeyPair PEM encoded 112 | func (pair *CAKeyPair) CACertPEM() []byte { 113 | return pem.EncodeToMemory( 114 | &pem.Block{ 115 | Type: "CERTIFICATE", 116 | Bytes: pair.Certificate, 117 | }, 118 | ) 119 | } 120 | -------------------------------------------------------------------------------- /cmd/main/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "errors" 5 | "flag" 6 | "fmt" 7 | "io/ioutil" 8 | "log" 9 | "net" 10 | "os" 11 | "os/signal" 12 | "strings" 13 | "syscall" 14 | "time" 15 | 16 | "puppy" 17 | ) 18 | 19 | var logBanner string = ` 20 | ======================================== 21 | PUPPYSTARTEDPUPPYSTARTEDPUPPYSTARTEDPUPP 22 | .--. .---. 23 | /:. '. .' .. '._.---. 24 | /:::-. \.-"""-;' .-:::. 
.::\ 25 | /::'| '\/ _ _ \' '\:' ::::| 26 | __.' | / (o|o) \ ''. ':/ 27 | / .:. / | ___ | '---' 28 | | ::::' /: (._.) .:\ 29 | \ .=' |:' :::| 30 | '""' \ .-. ':/ 31 | '---'|I|'---' 32 | jgs '-' 33 | PUPPYSTARTEDPUPPYSTARTEDPUPPYSTARTEDPUPP 34 | ======================================== 35 | ` 36 | 37 | type listenArg struct { 38 | Type string 39 | Addr string 40 | } 41 | 42 | func quitErr(msg string) { 43 | os.Stderr.WriteString(msg) 44 | os.Stderr.WriteString("\n") 45 | os.Exit(1) 46 | } 47 | 48 | func checkErr(err error) { 49 | if err != nil { 50 | quitErr(err.Error()) 51 | } 52 | } 53 | 54 | func parseListenString(lstr string) (*listenArg, error) { 55 | args := strings.SplitN(lstr, ":", 2) 56 | if len(args) != 2 { 57 | return nil, errors.New("invalid listener. Must be in the form of \"tye:addr\"") 58 | } 59 | argStruct := &listenArg{ 60 | Type: strings.ToLower(args[0]), 61 | Addr: args[1], 62 | } 63 | if argStruct.Type != "tcp" && argStruct.Type != "unix" { 64 | return nil, fmt.Errorf("invalid listener type: %s", argStruct.Type) 65 | } 66 | return argStruct, nil 67 | } 68 | 69 | func unixAddr() string { 70 | return fmt.Sprintf("%s/proxy.%d.%d.sock", os.TempDir(), os.Getpid(), time.Now().UnixNano()) 71 | } 72 | 73 | var mln net.Listener 74 | var logger *log.Logger 75 | 76 | func cleanup() { 77 | if mln != nil { 78 | mln.Close() 79 | } 80 | } 81 | 82 | var MainLogger *log.Logger 83 | 84 | func main() { 85 | defer cleanup() 86 | // Handle signals 87 | sigc := make(chan os.Signal, 1) 88 | signal.Notify(sigc, os.Interrupt, os.Kill, syscall.SIGTERM) 89 | go func() { 90 | <-sigc 91 | if logger != nil { 92 | logger.Println("Caught signal. Cleaning up.") 93 | } 94 | cleanup() 95 | os.Exit(0) 96 | }() 97 | 98 | msgListenStr := flag.String("msglisten", "", "Listener for the message handler. 
Examples: \"tcp::8080\", \"tcp:127.0.0.1:8080\", \"unix:/tmp/foobar\"") 99 | autoListen := flag.Bool("msgauto", false, "Automatically pick and open a unix or tcp socket for the message listener") 100 | debugFlag := flag.Bool("dbg", false, "Enable debug logging") 101 | flag.Parse() 102 | 103 | if *debugFlag { 104 | logfile, err := os.OpenFile("log.log", os.O_RDWR|os.O_CREATE|os.O_APPEND, 0666) 105 | checkErr(err) 106 | logger = log.New(logfile, "[*] ", log.Lshortfile) 107 | } else { 108 | logger = log.New(ioutil.Discard, "[*] ", log.Lshortfile) 109 | log.SetFlags(0) 110 | } 111 | MainLogger = logger 112 | 113 | // Parse arguments to structs 114 | if *msgListenStr == "" && *autoListen == false { 115 | quitErr("message listener address or `--msgauto` required") 116 | } 117 | if *msgListenStr != "" && *autoListen == true { 118 | quitErr("only one of listener address or `--msgauto` can be used") 119 | } 120 | 121 | // Create the message listener 122 | var listenStr string 123 | if *msgListenStr != "" { 124 | msgAddr, err := parseListenString(*msgListenStr) 125 | checkErr(err) 126 | if msgAddr.Type == "tcp" { 127 | var err error 128 | mln, err = net.Listen("tcp", msgAddr.Addr) 129 | checkErr(err) 130 | } else if msgAddr.Type == "unix" { 131 | var err error 132 | mln, err = net.Listen("unix", msgAddr.Addr) 133 | checkErr(err) 134 | } else { 135 | quitErr("unsupported listener type:" + msgAddr.Type) 136 | } 137 | listenStr = fmt.Sprintf("%s:%s", msgAddr.Type, msgAddr.Addr) 138 | } else { 139 | fpath := unixAddr() 140 | ulisten, err := net.Listen("unix", fpath) 141 | if err == nil { 142 | mln = ulisten 143 | listenStr = fmt.Sprintf("unix:%s", fpath) 144 | } else { 145 | tcplisten, err := net.Listen("tcp", "127.0.0.1:0") 146 | if err != nil { 147 | quitErr("unable to open any messaging ports") 148 | } 149 | mln = tcplisten 150 | listenStr = fmt.Sprintf("tcp:%s", tcplisten.Addr().String()) 151 | } 152 | } 153 | 154 | // Set up the intercepting proxy 155 | iproxy := 
puppy.NewInterceptingProxy(logger) 156 | iproxy.AddHTTPHandler("puppy", puppy.CreateWebUIHandler()) 157 | 158 | // Create a message server and have it serve for the iproxy 159 | mserv := puppy.NewProxyMessageListener(logger, iproxy) 160 | logger.Print(logBanner) 161 | fmt.Println(listenStr) 162 | mserv.Serve(mln) // serve until killed 163 | } 164 | -------------------------------------------------------------------------------- /credits.go: -------------------------------------------------------------------------------- 1 | package puppy 2 | 3 | /* 4 | List of info that is used to display credits 5 | */ 6 | 7 | type creditItem struct { 8 | projectName string 9 | url string 10 | author string 11 | year string 12 | licenseType string 13 | longCopyright string 14 | } 15 | 16 | var lib_credits = []creditItem{ 17 | creditItem{ 18 | "goproxy", 19 | "https://github.com/elazarl/goproxy", 20 | "Elazar Leibovich", 21 | "2012", 22 | "3-Clause BSD", 23 | `Copyright (c) 2012 Elazar Leibovich. All rights reserved. 24 | 25 | Redistribution and use in source and binary forms, with or without 26 | modification, are permitted provided that the following conditions are 27 | met: 28 | 29 | * Redistributions of source code must retain the above copyright 30 | notice, this list of conditions and the following disclaimer. 31 | * Redistributions in binary form must reproduce the above 32 | copyright notice, this list of conditions and the following disclaimer 33 | in the documentation and/or other materials provided with the 34 | distribution. 35 | * Neither the name of Elazar Leibovich. nor the names of its 36 | contributors may be used to endorse or promote products derived from 37 | this software without specific prior written permission. 
38 | 39 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 40 | "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 41 | LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 42 | A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT 43 | OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 44 | SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 45 | LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 46 | DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 47 | THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 48 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 49 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.`, 50 | }, 51 | 52 | creditItem{ 53 | "golang-set", 54 | "https://github.com/deckarep/golang-set", 55 | "Ralph Caraveo", 56 | "2013", 57 | "MIT", 58 | `Open Source Initiative OSI - The MIT License (MIT):Licensing 59 | 60 | The MIT License (MIT) 61 | Copyright (c) 2013 Ralph Caraveo (deckarep@gmail.com) 62 | 63 | Permission is hereby granted, free of charge, to any person obtaining a copy of 64 | this software and associated documentation files (the "Software"), to deal in 65 | the Software without restriction, including without limitation the rights to 66 | use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies 67 | of the Software, and to permit persons to whom the Software is furnished to do 68 | so, subject to the following conditions: 69 | 70 | The above copyright notice and this permission notice shall be included in all 71 | copies or substantial portions of the Software. 72 | 73 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 74 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 75 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 76 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 77 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 78 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 79 | SOFTWARE.`, 80 | }, 81 | 82 | creditItem{ 83 | "Gorilla WebSocket", 84 | "https://github.com/gorilla/websocket", 85 | "Gorilla WebSocket Authors", 86 | "2013", 87 | "2-Clause BSD", 88 | `Copyright (c) 2013 The Gorilla WebSocket Authors. All rights reserved. 89 | 90 | Redistribution and use in source and binary forms, with or without 91 | modification, are permitted provided that the following conditions are met: 92 | 93 | Redistributions of source code must retain the above copyright notice, this 94 | list of conditions and the following disclaimer. 95 | 96 | Redistributions in binary form must reproduce the above copyright notice, 97 | this list of conditions and the following disclaimer in the documentation 98 | and/or other materials provided with the distribution. 99 | 100 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 101 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 102 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 103 | DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 104 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 105 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 106 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 107 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 108 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 109 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.`, 110 | }, 111 | } 112 | -------------------------------------------------------------------------------- /messageserv.go: -------------------------------------------------------------------------------- 1 | package puppy 2 | 3 | import ( 4 | "bufio" 5 | "encoding/json" 6 | "fmt" 7 | "io" 8 | "log" 9 | "net" 10 | "strings" 11 | ) 12 | 13 | /* 14 | Message Server 15 | */ 16 | 17 | // A handler to handle a JSON message 18 | type MessageHandler func(message []byte, conn net.Conn, logger *log.Logger, iproxy *InterceptingProxy) 19 | 20 | // A listener that handles reading JSON messages and sending them to the correct handler 21 | type MessageListener struct { 22 | handlers map[string]MessageHandler 23 | iproxy *InterceptingProxy 24 | Logger *log.Logger 25 | } 26 | 27 | type commandData struct { 28 | Command string 29 | } 30 | 31 | type errorMessage struct { 32 | Success bool 33 | Reason string 34 | } 35 | 36 | // NewMessageListener creates a new message listener associated with the given intercepting proxy 37 | func NewMessageListener(l *log.Logger, iproxy *InterceptingProxy) *MessageListener { 38 | m := &MessageListener{ 39 | handlers: make(map[string]MessageHandler), 40 | iproxy: iproxy, 41 | Logger: l, 42 | } 43 | return m 44 | } 45 | 46 | // AddHandler will have the listener call the given handler when the "Command" parameter matches the given value 47 | func (l *MessageListener) AddHandler(command string, handler 
MessageHandler) { 48 | l.handlers[strings.ToLower(command)] = handler 49 | } 50 | 51 | func (l *MessageListener) Handle(message []byte, conn net.Conn) error { 52 | var c commandData 53 | if err := json.Unmarshal(message, &c); err != nil { 54 | return fmt.Errorf("error parsing message: %s", err.Error()) 55 | } 56 | 57 | handler, ok := l.handlers[strings.ToLower(c.Command)] 58 | if !ok { 59 | return fmt.Errorf("unknown command: %s", c.Command) 60 | } 61 | 62 | l.Logger.Printf("Calling handler for \"%s\"...", c.Command) 63 | handler(message, conn, l.Logger, l.iproxy) 64 | return nil 65 | } 66 | 67 | // Serve will have the listener serve messages on the given listener 68 | func (l *MessageListener) Serve(nl net.Listener) { 69 | for { 70 | conn, err := nl.Accept() 71 | if err != nil { 72 | // Listener closed 73 | break 74 | } 75 | 76 | reader := bufio.NewReader(conn) 77 | go func() { 78 | for { 79 | m, err := ReadMessage(reader) 80 | l.Logger.Printf("> %s\n", m) 81 | if err != nil { 82 | if err != io.EOF { 83 | ErrorResponse(conn, "error reading message") 84 | } 85 | return 86 | } 87 | err = l.Handle(m, conn) 88 | if err != nil { 89 | ErrorResponse(conn, err.Error()) 90 | } 91 | } 92 | }() 93 | } 94 | } 95 | 96 | // Error response writes an error message to the given writer 97 | func ErrorResponse(w io.Writer, reason string) { 98 | var m errorMessage 99 | m.Success = false 100 | m.Reason = reason 101 | MessageResponse(w, m) 102 | } 103 | 104 | // MessageResponse writes a response to a given writer 105 | func MessageResponse(w io.Writer, m interface{}) { 106 | b, err := json.Marshal(&m) 107 | if err != nil { 108 | panic(err) 109 | } 110 | w.Write(b) 111 | w.Write([]byte("\n")) 112 | } 113 | 114 | // ReadMessage reads a message from the given reader 115 | func ReadMessage(r *bufio.Reader) ([]byte, error) { 116 | m, err := r.ReadBytes('\n') 117 | if err != nil { 118 | return nil, err 119 | } 120 | return m, nil 121 | } 122 | 
-------------------------------------------------------------------------------- /proxy.go: -------------------------------------------------------------------------------- 1 | // Puppy provices an interface to create a proxy to intercept and modify HTTP and websocket messages passing through the proxy 2 | package puppy 3 | 4 | import ( 5 | "crypto/tls" 6 | "encoding/base64" 7 | "fmt" 8 | "io/ioutil" 9 | "log" 10 | "net" 11 | "net/http" 12 | "sync" 13 | "time" 14 | 15 | "github.com/gorilla/websocket" 16 | ) 17 | 18 | var getNextSubId = IdCounter() 19 | var getNextStorageId = IdCounter() 20 | 21 | // ProxyWebUIHandler is a function that can be used for handling web requests intended to be handled by the proxy 22 | type ProxyWebUIHandler func(http.ResponseWriter, *http.Request, *InterceptingProxy) 23 | 24 | type savedStorage struct { 25 | storage MessageStorage 26 | description string 27 | } 28 | 29 | type GlobalStorageWatcher interface { 30 | // Callback for when a new request is saved 31 | NewRequestSaved(storageId int, ms MessageStorage, req *ProxyRequest) 32 | // Callback for when a request is updated 33 | RequestUpdated(storageId int, ms MessageStorage, req *ProxyRequest) 34 | // Callback for when a request is deleted 35 | RequestDeleted(storageId int, ms MessageStorage, DbId string) 36 | 37 | // Callback for when a new response is saved 38 | NewResponseSaved(storageId int, ms MessageStorage, rsp *ProxyResponse) 39 | // Callback for when a response is updated 40 | ResponseUpdated(storageId int, ms MessageStorage, rsp *ProxyResponse) 41 | // Callback for when a response is deleted 42 | ResponseDeleted(storageId int, ms MessageStorage, DbId string) 43 | 44 | // Callback for when a new wsmessage is saved 45 | NewWSMessageSaved(storageId int, ms MessageStorage, req *ProxyRequest, wsm *ProxyWSMessage) 46 | // Callback for when a wsmessage is updated 47 | WSMessageUpdated(storageId int, ms MessageStorage, req *ProxyRequest, wsm *ProxyWSMessage) 48 | // Callback for 
when a wsmessage is deleted 49 | WSMessageDeleted(storageId int, ms MessageStorage, DbId string) 50 | } 51 | 52 | type globalWatcher struct { 53 | watchers []GlobalStorageWatcher 54 | } 55 | 56 | type globalWatcherShim struct { 57 | storageId int 58 | globWatcher *globalWatcher 59 | logger *log.Logger 60 | } 61 | 62 | // InterceptingProxy is a struct which represents a proxy which can intercept and modify HTTP and websocket messages 63 | type InterceptingProxy struct { 64 | slistener *ProxyListener 65 | server *http.Server 66 | mtx sync.Mutex 67 | logger *log.Logger 68 | proxyStorage int 69 | netDial NetDialer 70 | 71 | usingProxy bool 72 | proxyHost string 73 | proxyPort int 74 | proxyIsSOCKS bool 75 | proxyCreds *ProxyCredentials 76 | 77 | requestInterceptor RequestInterceptor 78 | responseInterceptor ResponseInterceptor 79 | wSInterceptor WSInterceptor 80 | scopeChecker RequestChecker 81 | scopeQuery MessageQuery 82 | 83 | reqSubs []*ReqIntSub 84 | rspSubs []*RspIntSub 85 | wsSubs []*WSIntSub 86 | 87 | httpHandlers map[string]ProxyWebUIHandler 88 | 89 | messageStorage map[int]*savedStorage 90 | globWatcher *globalWatcher 91 | } 92 | 93 | // ProxyCredentials are a username/password combination used to represent an HTTP BasicAuth session 94 | type ProxyCredentials struct { 95 | Username string 96 | Password string 97 | } 98 | 99 | // RequestInterceptor is a function that takes in a ProxyRequest and returns a modified ProxyRequest or nil to represent dropping the request 100 | type RequestInterceptor func(req *ProxyRequest) (*ProxyRequest, error) 101 | 102 | // ResponseInterceptor is a function that takes in a ProxyResponse and the original request and returns a modified ProxyResponse or nil to represent dropping the response 103 | type ResponseInterceptor func(req *ProxyRequest, rsp *ProxyResponse) (*ProxyResponse, error) 104 | 105 | // WSInterceptor is a function that takes in a ProxyWSMessage and the ProxyRequest/ProxyResponse which made up its handshake and 
returns and returns a modified ProxyWSMessage or nil to represent dropping the message. A WSInterceptor should be able to modify messages originating from both the client and the remote server. 106 | type WSInterceptor func(req *ProxyRequest, rsp *ProxyResponse, msg *ProxyWSMessage) (*ProxyWSMessage, error) 107 | 108 | // ReqIntSub represents an active HTTP request interception session in an InterceptingProxy 109 | type ReqIntSub struct { 110 | id int 111 | Interceptor RequestInterceptor 112 | } 113 | 114 | // RspIntSub represents an active HTTP response interception session in an InterceptingProxy 115 | type RspIntSub struct { 116 | id int 117 | Interceptor ResponseInterceptor 118 | } 119 | 120 | // WSIntSub represents an active websocket interception session in an InterceptingProxy 121 | type WSIntSub struct { 122 | id int 123 | Interceptor WSInterceptor 124 | } 125 | 126 | // SerializeHeader serializes the ProxyCredentials into a value that can be included in an Authorization header 127 | func (creds *ProxyCredentials) SerializeHeader() string { 128 | toEncode := []byte(fmt.Sprintf("%s:%s", creds.Username, creds.Password)) 129 | encoded := base64.StdEncoding.EncodeToString(toEncode) 130 | return fmt.Sprintf("Basic %s", encoded) 131 | } 132 | 133 | // NewInterceptingProxy will create a new InterceptingProxy and have it log using the provided logger. 
If logger is nil, the proxy will log to ioutil.Discard 134 | func NewInterceptingProxy(logger *log.Logger) *InterceptingProxy { 135 | var iproxy InterceptingProxy 136 | var useLogger *log.Logger 137 | if logger != nil { 138 | useLogger = logger 139 | } else { 140 | useLogger = log.New(ioutil.Discard, "[*] ", log.Lshortfile) 141 | } 142 | 143 | iproxy.messageStorage = make(map[int]*savedStorage) 144 | iproxy.slistener = NewProxyListener(useLogger) 145 | iproxy.server = newProxyServer(useLogger, &iproxy) 146 | iproxy.logger = useLogger 147 | iproxy.httpHandlers = make(map[string]ProxyWebUIHandler) 148 | iproxy.globWatcher = &globalWatcher{ 149 | watchers: make([]GlobalStorageWatcher, 0), 150 | } 151 | 152 | go func() { 153 | iproxy.server.Serve(iproxy.slistener) 154 | }() 155 | return &iproxy 156 | } 157 | 158 | // Close closes all listeners being used by the proxy. Does not shut down internal HTTP server because there is no way to gracefully shut down an http server yet. 159 | func (iproxy *InterceptingProxy) Close() { 160 | // Will throw errors when the server finally shuts down and tries to call iproxy.slistener.Close a second time 161 | iproxy.mtx.Lock() 162 | defer iproxy.mtx.Unlock() 163 | iproxy.slistener.Close() 164 | //iproxy.server.Close() // Coming eventually... 
I hope 165 | } 166 | 167 | // LoadCACertificates loads a private/public key pair which should be used when generating self-signed certs for TLS connections 168 | func (iproxy *InterceptingProxy) LoadCACertificates(certFile, keyFile string) error { 169 | caCert, err := tls.LoadX509KeyPair(certFile, keyFile) 170 | if err != nil { 171 | return fmt.Errorf("could not load certificate pair: %s", err.Error()) 172 | } 173 | 174 | iproxy.SetCACertificate(&caCert) 175 | return nil 176 | } 177 | 178 | // SetCACertificate sets certificate which should be used when generating self-signed certs for TLS connections 179 | func (iproxy *InterceptingProxy) SetCACertificate(caCert *tls.Certificate) { 180 | if iproxy.slistener == nil { 181 | panic("intercepting proxy does not have a proxy listener") 182 | } 183 | iproxy.slistener.SetCACertificate(caCert) 184 | } 185 | 186 | // GetCACertificate returns certificate used to self-sign certificates for TLS connections 187 | func (iproxy *InterceptingProxy) GetCACertificate() *tls.Certificate { 188 | return iproxy.slistener.GetCACertificate() 189 | } 190 | 191 | // AddListener will have the proxy listen for HTTP connections on a listener. Proxy will attempt to strip TLS from the connection 192 | func (iproxy *InterceptingProxy) AddListener(l net.Listener) { 193 | iproxy.mtx.Lock() 194 | defer iproxy.mtx.Unlock() 195 | iproxy.slistener.AddListener(l) 196 | } 197 | 198 | // Have the proxy listen for HTTP connections on a listener and transparently redirect them to the destination. Listeners added this way can only redirect requests to a single destination. However, it does not rely on the client being aware that it is using an HTTP proxy. 
199 | func (iproxy *InterceptingProxy) AddTransparentListener(l net.Listener, destHost string, destPort int, useTLS bool) { 200 | iproxy.mtx.Lock() 201 | defer iproxy.mtx.Unlock() 202 | iproxy.slistener.AddTransparentListener(l, destHost, destPort, useTLS) 203 | } 204 | 205 | // RemoveListner will have the proxy stop listening to a listener 206 | func (iproxy *InterceptingProxy) RemoveListener(l net.Listener) { 207 | iproxy.mtx.Lock() 208 | defer iproxy.mtx.Unlock() 209 | iproxy.slistener.RemoveListener(l) 210 | } 211 | 212 | // GetMessageStorage takes in a storage ID and returns the storage associated with the ID 213 | func (iproxy *InterceptingProxy) GetMessageStorage(id int) (MessageStorage, string) { 214 | iproxy.mtx.Lock() 215 | defer iproxy.mtx.Unlock() 216 | savedStorage, ok := iproxy.messageStorage[id] 217 | if !ok { 218 | return nil, "" 219 | } 220 | return savedStorage.storage, savedStorage.description 221 | } 222 | 223 | // AddMessageStorage associates a MessageStorage with the proxy and returns an ID to be used when referencing the storage in the future 224 | func (iproxy *InterceptingProxy) AddMessageStorage(storage MessageStorage, description string) int { 225 | iproxy.mtx.Lock() 226 | defer iproxy.mtx.Unlock() 227 | id := getNextStorageId() 228 | iproxy.messageStorage[id] = &savedStorage{storage, description} 229 | 230 | shim := &globalWatcherShim{ 231 | storageId: id, 232 | globWatcher: iproxy.globWatcher, 233 | logger: iproxy.logger, 234 | } 235 | storage.Watch(shim) 236 | return id 237 | } 238 | 239 | // CloseMessageStorage closes a message storage associated with the proxy 240 | func (iproxy *InterceptingProxy) CloseMessageStorage(id int) { 241 | iproxy.mtx.Lock() 242 | defer iproxy.mtx.Unlock() 243 | savedStorage, ok := iproxy.messageStorage[id] 244 | if !ok { 245 | return 246 | } 247 | delete(iproxy.messageStorage, id) 248 | savedStorage.storage.Close() 249 | } 250 | 251 | // SavedStorage represents a storage associated with the proxy 252 | 
type SavedStorage struct { 253 | Id int 254 | Storage MessageStorage 255 | Description string 256 | } 257 | 258 | // ListMessageStorage returns a list of storages associated with the proxy 259 | func (iproxy *InterceptingProxy) ListMessageStorage() []*SavedStorage { 260 | iproxy.mtx.Lock() 261 | defer iproxy.mtx.Unlock() 262 | 263 | r := make([]*SavedStorage, 0) 264 | for id, ss := range iproxy.messageStorage { 265 | r = append(r, &SavedStorage{id, ss.storage, ss.description}) 266 | } 267 | return r 268 | } 269 | 270 | func (iproxy *InterceptingProxy) getRequestSubs() []*ReqIntSub { 271 | iproxy.mtx.Lock() 272 | defer iproxy.mtx.Unlock() 273 | return iproxy.reqSubs 274 | } 275 | 276 | func (iproxy *InterceptingProxy) getResponseSubs() []*RspIntSub { 277 | iproxy.mtx.Lock() 278 | defer iproxy.mtx.Unlock() 279 | return iproxy.rspSubs 280 | } 281 | 282 | func (iproxy *InterceptingProxy) getWSSubs() []*WSIntSub { 283 | iproxy.mtx.Lock() 284 | defer iproxy.mtx.Unlock() 285 | return iproxy.wsSubs 286 | } 287 | 288 | // LoadScope loads the scope from the given storage and applies it to the proxy 289 | func (iproxy *InterceptingProxy) LoadScope(storageId int) error { 290 | // Try and set the scope 291 | savedStorage, ok := iproxy.messageStorage[storageId] 292 | if !ok { 293 | return fmt.Errorf("proxy has no associated storage") 294 | } 295 | iproxy.logger.Println("loading scope") 296 | if scope, err := savedStorage.storage.LoadQuery("__scope"); err == nil { 297 | if err := iproxy.setScopeQuery(scope); err != nil { 298 | iproxy.logger.Println("error setting scope:", err.Error()) 299 | } 300 | } else { 301 | iproxy.logger.Println("error loading scope:", err.Error()) 302 | } 303 | return nil 304 | } 305 | 306 | // GetScopeChecker creates a RequestChecker which checks if a request matches the proxy's current scope 307 | func (iproxy *InterceptingProxy) GetScopeChecker() RequestChecker { 308 | iproxy.mtx.Lock() 309 | defer iproxy.mtx.Unlock() 310 | return iproxy.scopeChecker 
311 | } 312 | 313 | // SetScopeChecker has the proxy use a specific RequestChecker to check if a request is in scope. If the checker returns true for a request it is considered in scope. Otherwise it is considered out of scope. 314 | func (iproxy *InterceptingProxy) SetScopeChecker(checker RequestChecker) error { 315 | iproxy.mtx.Lock() 316 | defer iproxy.mtx.Unlock() 317 | savedStorage, ok := iproxy.messageStorage[iproxy.proxyStorage] 318 | if !ok { 319 | savedStorage = nil 320 | } 321 | iproxy.scopeChecker = checker 322 | iproxy.scopeQuery = nil 323 | emptyQuery := make(MessageQuery, 0) 324 | if savedStorage != nil { 325 | savedStorage.storage.SaveQuery("__scope", emptyQuery) // Assume it clears it I guess 326 | } 327 | return nil 328 | } 329 | 330 | // GetScopeQuery returns the query associated with the proxy's scope. If the scope was set using SetScopeChecker, nil is returned 331 | func (iproxy *InterceptingProxy) GetScopeQuery() MessageQuery { 332 | iproxy.mtx.Lock() 333 | defer iproxy.mtx.Unlock() 334 | return iproxy.scopeQuery 335 | } 336 | 337 | // SetScopeQuery sets the scope of the proxy to include any request which matches the given MessageQuery 338 | func (iproxy *InterceptingProxy) SetScopeQuery(query MessageQuery) error { 339 | iproxy.mtx.Lock() 340 | defer iproxy.mtx.Unlock() 341 | return iproxy.setScopeQuery(query) 342 | } 343 | 344 | func (iproxy *InterceptingProxy) setScopeQuery(query MessageQuery) error { 345 | checker, err := CheckerFromMessageQuery(query) 346 | if err != nil { 347 | return err 348 | } 349 | savedStorage, ok := iproxy.messageStorage[iproxy.proxyStorage] 350 | if !ok { 351 | savedStorage = nil 352 | } 353 | iproxy.scopeChecker = checker 354 | iproxy.scopeQuery = query 355 | if savedStorage != nil { 356 | if err = savedStorage.storage.SaveQuery("__scope", query); err != nil { 357 | return fmt.Errorf("could not save scope to storage: %s", err.Error()) 358 | } 359 | } 360 | 361 | return nil 362 | } 363 | 364 | // ClearScope removes 
all scope checks from the proxy so that all requests passing through the proxy will be considered in-scope 365 | func (iproxy *InterceptingProxy) ClearScope() error { 366 | iproxy.mtx.Lock() 367 | defer iproxy.mtx.Unlock() 368 | iproxy.scopeChecker = nil 369 | iproxy.scopeChecker = nil 370 | emptyQuery := make(MessageQuery, 0) 371 | savedStorage, ok := iproxy.messageStorage[iproxy.proxyStorage] 372 | if !ok { 373 | savedStorage = nil 374 | } 375 | if savedStorage != nil { 376 | if err := savedStorage.storage.SaveQuery("__scope", emptyQuery); err != nil { 377 | return fmt.Errorf("could not clear scope in storage: %s", err.Error()) 378 | } 379 | } 380 | return nil 381 | } 382 | 383 | // SetNetDial sets the NetDialer that should be used to create outgoing connections when submitting HTTP requests. Overwrites the request's NetDialer 384 | func (iproxy *InterceptingProxy) SetNetDial(dialer NetDialer) { 385 | iproxy.mtx.Lock() 386 | defer iproxy.mtx.Unlock() 387 | iproxy.netDial = dialer 388 | } 389 | 390 | // NetDial returns the dialer currently being used to create outgoing connections when submitting HTTP requests 391 | func (iproxy *InterceptingProxy) NetDial() NetDialer { 392 | iproxy.mtx.Lock() 393 | defer iproxy.mtx.Unlock() 394 | return iproxy.netDial 395 | } 396 | 397 | // ClearUpstreamProxy stops the proxy from using an upstream proxy for future connections 398 | func (iproxy *InterceptingProxy) ClearUpstreamProxy() { 399 | iproxy.mtx.Lock() 400 | defer iproxy.mtx.Unlock() 401 | iproxy.usingProxy = false 402 | iproxy.proxyHost = "" 403 | iproxy.proxyPort = 0 404 | iproxy.proxyIsSOCKS = false 405 | } 406 | 407 | // SetUpstreamProxy causes the proxy to begin using an upstream HTTP proxy for submitted HTTP requests 408 | func (iproxy *InterceptingProxy) SetUpstreamProxy(proxyHost string, proxyPort int, creds *ProxyCredentials) { 409 | iproxy.mtx.Lock() 410 | defer iproxy.mtx.Unlock() 411 | iproxy.usingProxy = true 412 | iproxy.proxyHost = proxyHost 413 | 
iproxy.proxyPort = proxyPort 414 | iproxy.proxyIsSOCKS = false 415 | iproxy.proxyCreds = creds 416 | } 417 | 418 | // SetUpstreamSOCKSProxy causes the proxy to begin using an upstream SOCKS proxy for submitted HTTP requests 419 | func (iproxy *InterceptingProxy) SetUpstreamSOCKSProxy(proxyHost string, proxyPort int, creds *ProxyCredentials) { 420 | iproxy.mtx.Lock() 421 | defer iproxy.mtx.Unlock() 422 | iproxy.usingProxy = true 423 | iproxy.proxyHost = proxyHost 424 | iproxy.proxyPort = proxyPort 425 | iproxy.proxyIsSOCKS = true 426 | iproxy.proxyCreds = creds 427 | } 428 | 429 | // SubmitRequest submits a ProxyRequest. Does not automatically save the request/results to proxy storage 430 | func (iproxy *InterceptingProxy) SubmitRequest(req *ProxyRequest) error { 431 | oldDial := req.NetDial 432 | defer func() { req.NetDial = oldDial }() 433 | req.NetDial = iproxy.NetDial() 434 | 435 | if iproxy.usingProxy { 436 | if iproxy.proxyIsSOCKS { 437 | return SubmitRequestSOCKSProxy(req, iproxy.proxyHost, iproxy.proxyPort, iproxy.proxyCreds) 438 | } else { 439 | return SubmitRequestProxy(req, iproxy.proxyHost, iproxy.proxyPort, iproxy.proxyCreds) 440 | } 441 | } 442 | return SubmitRequest(req) 443 | } 444 | 445 | // WSDial dials a remote server and submits the given request to initiate the handshake 446 | func (iproxy *InterceptingProxy) WSDial(req *ProxyRequest) (*WSSession, error) { 447 | oldDial := req.NetDial 448 | defer func() { req.NetDial = oldDial }() 449 | req.NetDial = iproxy.NetDial() 450 | 451 | if iproxy.usingProxy { 452 | if iproxy.proxyIsSOCKS { 453 | return WSDialSOCKSProxy(req, iproxy.proxyHost, iproxy.proxyPort, iproxy.proxyCreds) 454 | } else { 455 | return WSDialProxy(req, iproxy.proxyHost, iproxy.proxyPort, iproxy.proxyCreds) 456 | } 457 | } 458 | return WSDial(req) 459 | } 460 | 461 | // AddReqInterceptor adds a RequestInterceptor to the proxy which will be used to modify HTTP requests as they pass through the proxy. 
Returns a struct representing the active interceptor. 462 | func (iproxy *InterceptingProxy) AddReqInterceptor(f RequestInterceptor) *ReqIntSub { 463 | iproxy.mtx.Lock() 464 | defer iproxy.mtx.Unlock() 465 | 466 | sub := &ReqIntSub{ 467 | id: getNextSubId(), 468 | Interceptor: f, 469 | } 470 | iproxy.reqSubs = append(iproxy.reqSubs, sub) 471 | return sub 472 | } 473 | 474 | // RemoveReqInterceptor removes an active request interceptor from the proxy 475 | func (iproxy *InterceptingProxy) RemoveReqInterceptor(sub *ReqIntSub) { 476 | iproxy.mtx.Lock() 477 | defer iproxy.mtx.Unlock() 478 | 479 | for i, checkSub := range iproxy.reqSubs { 480 | if checkSub.id == sub.id { 481 | iproxy.reqSubs = append(iproxy.reqSubs[:i], iproxy.reqSubs[i+1:]...) 482 | return 483 | } 484 | } 485 | } 486 | 487 | // AddRspInterceptor adds a ResponseInterceptor to the proxy which will be used to modify HTTP responses as they pass through the proxy. Returns a struct representing the active interceptor. 488 | func (iproxy *InterceptingProxy) AddRspInterceptor(f ResponseInterceptor) *RspIntSub { 489 | iproxy.mtx.Lock() 490 | defer iproxy.mtx.Unlock() 491 | 492 | sub := &RspIntSub{ 493 | id: getNextSubId(), 494 | Interceptor: f, 495 | } 496 | iproxy.rspSubs = append(iproxy.rspSubs, sub) 497 | return sub 498 | } 499 | 500 | // RemoveRspInterceptor removes an active response interceptor from the proxy 501 | func (iproxy *InterceptingProxy) RemoveRspInterceptor(sub *RspIntSub) { 502 | iproxy.mtx.Lock() 503 | defer iproxy.mtx.Unlock() 504 | 505 | for i, checkSub := range iproxy.rspSubs { 506 | if checkSub.id == sub.id { 507 | iproxy.rspSubs = append(iproxy.rspSubs[:i], iproxy.rspSubs[i+1:]...) 508 | return 509 | } 510 | } 511 | } 512 | 513 | // AddWSInterceptor adds a WSInterceptor to the proxy which will be used to modify both incoming and outgoing websocket messages as they pass through the proxy. Returns a struct representing the active interceptor. 
514 | func (iproxy *InterceptingProxy) AddWSInterceptor(f WSInterceptor) *WSIntSub { 515 | iproxy.mtx.Lock() 516 | defer iproxy.mtx.Unlock() 517 | 518 | sub := &WSIntSub{ 519 | id: getNextSubId(), 520 | Interceptor: f, 521 | } 522 | iproxy.wsSubs = append(iproxy.wsSubs, sub) 523 | return sub 524 | } 525 | 526 | // RemoveWSInterceptor removes an active websocket interceptor (matched by its id) from the proxy. It does nothing if sub is not currently active. 527 | func (iproxy *InterceptingProxy) RemoveWSInterceptor(sub *WSIntSub) { 528 | iproxy.mtx.Lock() 529 | defer iproxy.mtx.Unlock() 530 | 531 | for i, checkSub := range iproxy.wsSubs { 532 | if checkSub.id == sub.id { 533 | iproxy.wsSubs = append(iproxy.wsSubs[:i], iproxy.wsSubs[i+1:]...) 534 | return 535 | } 536 | } 537 | } 538 | 539 | // GlobalStorageWatch registers watcher to be notified of changes to every message storage added with AddMessageStorage. Always returns nil. 540 | func (iproxy *InterceptingProxy) GlobalStorageWatch(watcher GlobalStorageWatcher) error { 541 | iproxy.mtx.Lock() 542 | defer iproxy.mtx.Unlock() 543 | iproxy.globWatcher.watchers = append(iproxy.globWatcher.watchers, watcher) 544 | return nil 545 | } 546 | 547 | // GlobalStorageEndWatch unregisters watcher, removing every registered entry that compares equal to it. Always returns nil. 548 | func (iproxy *InterceptingProxy) GlobalStorageEndWatch(watcher GlobalStorageWatcher) error { 549 | iproxy.mtx.Lock() 550 | defer iproxy.mtx.Unlock() 551 | var newWatched = make([]GlobalStorageWatcher, 0) 552 | for _, testWatcher := range iproxy.globWatcher.watchers { 553 | if (testWatcher != watcher) { 554 | newWatched = append(newWatched, testWatcher) 555 | } 556 | } 557 | iproxy.globWatcher.watchers = newWatched 558 | return nil 559 | } 560 | 561 | // SetProxyStorage sets which storage should be used to store messages as they pass through the proxy, then loads that storage's saved scope with LoadScope. Returns an error if no storage with storageId is associated with the proxy. NOTE(review): proxyStorage is assigned before the ID is validated, so an unknown ID is still recorded when the error is returned, and the error from LoadScope is discarded. 562 | func (iproxy *InterceptingProxy) SetProxyStorage(storageId int) error { 563 | iproxy.mtx.Lock() 564 | defer iproxy.mtx.Unlock() 565 | 566 | iproxy.proxyStorage = storageId 567 | 568 | _, ok := iproxy.messageStorage[iproxy.proxyStorage] 569 | if !ok { 570 | return fmt.Errorf("no storage with id %d", storageId) 571 | } 572 | 573 | iproxy.LoadScope(storageId) 574
| return nil 575 | } 576 | 577 | // GetProxyStorage returns the storage being used to save messages as they pass through the proxy, or nil if the current proxy storage ID is not associated with the proxy 578 | func (iproxy *InterceptingProxy) GetProxyStorage() MessageStorage { 579 | iproxy.mtx.Lock() 580 | defer iproxy.mtx.Unlock() 581 | 582 | savedStorage, ok := iproxy.messageStorage[iproxy.proxyStorage] 583 | if !ok { 584 | return nil 585 | } 586 | return savedStorage.storage 587 | } 588 | 589 | // AddHTTPHandler causes the proxy to pass requests whose Host is host to handler instead of forwarding them (see ServeHTTP). This can be used, for example, to create an internal web interface. Be careful with what actions are allowed through the interface because the interface could be vulnerable to cross-site request forgery attacks. 590 | func (iproxy *InterceptingProxy) AddHTTPHandler(host string, handler ProxyWebUIHandler) { 591 | iproxy.mtx.Lock() 592 | defer iproxy.mtx.Unlock() 593 | iproxy.httpHandlers[host] = handler 594 | } 595 | 596 | // GetHTTPHandler returns the HTTP handler registered for host, or an error if there is none 597 | func (iproxy *InterceptingProxy) GetHTTPHandler(host string) (ProxyWebUIHandler, error) { 598 | iproxy.mtx.Lock() 599 | defer iproxy.mtx.Unlock() 600 | handler, ok := iproxy.httpHandlers[host] 601 | if !ok { 602 | return nil, fmt.Errorf("no handler for host %s", host) 603 | } 604 | return handler, nil 605 | } 606 | 607 | // RemoveHTTPHandler removes the HTTP handler for a given host 608 | func (iproxy *InterceptingProxy) RemoveHTTPHandler(host string) { 609 | iproxy.mtx.Lock() 610 | defer iproxy.mtx.Unlock() 611 | delete(iproxy.httpHandlers, host) 612 | } 613 | 614 | // ParseProxyRequest converts an http.Request read from a connection from a ProxyListener into a ProxyRequest, taking the destination host, port and TLS flag from the address encoded in r.RemoteAddr. NOTE(review): when DecodeRemoteAddr fails this returns nil, nil and drops err; returning nil, err looks intended. 615 | func ParseProxyRequest(r *http.Request) (*ProxyRequest, error) { 616 | host, port, useTLS, err := DecodeRemoteAddr(r.RemoteAddr) 617 | if err != nil { 618 | return nil, nil 619 | } 620 | pr := NewProxyRequest(r, host, port, useTLS) 621 | return pr, nil 622 | } 623 | 624 | // BlankResponse writes a
blank 200 response to an http.ResponseWriter. Used when a request/response is dropped. NOTE(review): "Cache-control" is Set twice, so the second call replaces "no-cache" and only "no-store" is sent. 625 | func BlankResponse(w http.ResponseWriter) { 626 | w.Header().Set("Connection", "close") 627 | w.Header().Set("Cache-control", "no-cache") 628 | w.Header().Set("Pragma", "no-cache") 629 | w.Header().Set("Cache-control", "no-store") 630 | w.Header().Set("X-Frame-Options", "DENY") 631 | w.WriteHeader(200) 632 | } 633 | 634 | // ErrResponse writes err's message as a 500 (Internal Server Error) response to the given http.ResponseWriter. Used to give proxy error information to the browser. NOTE(review): as in BlankResponse, the second "Cache-control" Set overrides the first. 635 | func ErrResponse(w http.ResponseWriter, err error) { 636 | w.Header().Set("Connection", "close") 637 | w.Header().Set("Cache-control", "no-cache") 638 | w.Header().Set("Pragma", "no-cache") 639 | w.Header().Set("Cache-control", "no-store") 640 | w.Header().Set("X-Frame-Options", "DENY") 641 | http.Error(w, err.Error(), http.StatusInternalServerError) 642 | } 643 | 644 | // ServeHTTP implements http.Handler for the proxy. Requests whose Host has a handler registered with AddHTTPHandler are passed to that handler. Otherwise, when in scope, the request is saved to the proxy storage and run through the request interceptors. It is then either relayed as a websocket session or submitted upstream, with the (possibly intercepted) response written back to w. 645 | func (iproxy *InterceptingProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) { 646 | handler, err := iproxy.GetHTTPHandler(r.Host) 647 | if err == nil { 648 | handler(w, r, iproxy) 649 | return 650 | } 651 | 652 | req, _ := ParseProxyRequest(r) // NOTE(review): the error is discarded; if parsing fails req is nil and the next line panics 653 | iproxy.logger.Println("Received request to", req.FullURL().String()) 654 | req.StripProxyHeaders() 655 | 656 | ms := iproxy.GetProxyStorage() 657 | scopeChecker := iproxy.GetScopeChecker() 658 | 659 | // Helper functions 660 | checkScope := func(req *ProxyRequest) bool { 661 | if scopeChecker != nil { 662 | return scopeChecker(req) 663 | } 664 | return true 665 | } 666 | 667 | saveIfExists := func(req *ProxyRequest) error { 668 | if ms != nil && checkScope(req) { 669 | if err := UpdateRequest(ms, req); err != nil { 670 | return err 671 | } 672 | } 673 | return nil 674 | } 675 | 676 | /* 677 | functions that run messages through the proxy's interceptors. 678 | each returns the new message (nil if an interceptor dropped it), whether it was
modified or dropped, and an error 679 | */ 680 | 681 | mangleRequest := func(req *ProxyRequest) (*ProxyRequest, bool, error) { 682 | newReq := req.Clone() 683 | reqSubs := iproxy.getRequestSubs() 684 | for _, sub := range reqSubs { 685 | var err error = nil 686 | newReq, err = sub.Interceptor(newReq) 687 | if err != nil { 688 | e := fmt.Errorf("error with request interceptor: %s", err) 689 | return nil, false, e 690 | } 691 | if newReq == nil { 692 | break 693 | } 694 | } 695 | 696 | if newReq != nil { 697 | newReq.StartDatetime = time.Now() 698 | if !req.Eq(newReq) { 699 | iproxy.logger.Println("Request modified by interceptor") 700 | return newReq, true, nil 701 | } 702 | } else { 703 | return nil, true, nil 704 | } 705 | return req, false, nil 706 | } 707 | 708 | mangleResponse := func(req *ProxyRequest, rsp *ProxyResponse) (*ProxyResponse, bool, error) { 709 | reqCopy := req.Clone() 710 | newRsp := rsp.Clone() 711 | rspSubs := iproxy.getResponseSubs() 712 | iproxy.logger.Printf("%d interceptors", len(rspSubs)) 713 | for _, sub := range rspSubs { 714 | iproxy.logger.Println("mangling rsp...") 715 | var err error = nil 716 | newRsp, err = sub.Interceptor(reqCopy, newRsp) 717 | if err != nil { 718 | e := fmt.Errorf("error with response interceptor: %s", err) 719 | return nil, false, e 720 | } 721 | if newRsp == nil { 722 | break 723 | } 724 | } 725 | 726 | if newRsp != nil { 727 | if !rsp.Eq(newRsp) { 728 | iproxy.logger.Println("Response for", req.FullURL(), "modified by interceptor") 729 | // it was mangled 730 | return newRsp, true, nil 731 | } 732 | } else { 733 | // it was dropped 734 | return nil, true, nil 735 | } 736 | 737 | // it wasn't changed 738 | return rsp, false, nil 739 | } 740 | 741 | mangleWS := func(req *ProxyRequest, rsp *ProxyResponse, ws *ProxyWSMessage) (*ProxyWSMessage, bool, error) { 742 | newMsg := ws.Clone() 743 | reqCopy := req.Clone() 744 | rspCopy := rsp.Clone() 745 | wsSubs := iproxy.getWSSubs() 746 | for _, sub := range wsSubs { 747 | var err error =
nil 748 | newMsg, err = sub.Interceptor(reqCopy, rspCopy, newMsg) 749 | if err != nil { 750 | e := fmt.Errorf("error with ws interceptor: %s", err) 751 | return nil, false, e 752 | } 753 | if newMsg == nil { 754 | break 755 | } 756 | } 757 | 758 | if newMsg != nil { 759 | if !ws.Eq(newMsg) { 760 | newMsg.Timestamp = time.Now() 761 | newMsg.Direction = ws.Direction // any Direction change made by an interceptor is discarded 762 | iproxy.logger.Println("Message modified by interceptor") 763 | return newMsg, true, nil 764 | } 765 | } else { 766 | return nil, true, nil 767 | } 768 | return ws, false, nil 769 | } 770 | 771 | req.StartDatetime = time.Now() 772 | 773 | if checkScope(req) { 774 | if err := saveIfExists(req); err != nil { 775 | ErrResponse(w, err) 776 | return 777 | } 778 | newReq, mangled, err := mangleRequest(req) 779 | if err != nil { 780 | ErrResponse(w, err) 781 | return 782 | } 783 | if mangled { 784 | if newReq == nil { // dropped by an interceptor: save it without a response and answer with a blank response 785 | req.ServerResponse = nil 786 | if err := saveIfExists(req); err != nil { 787 | ErrResponse(w, err) 788 | return 789 | } 790 | BlankResponse(w) 791 | return 792 | } 793 | newReq.Unmangled = req 794 | req = newReq 795 | req.StartDatetime = time.Now() 796 | if err := saveIfExists(req); err != nil { 797 | ErrResponse(w, err) 798 | return 799 | } 800 | } 801 | } 802 | 803 | if req.IsWSUpgrade() { 804 | iproxy.logger.Println("Detected websocket request.
Upgrading...") 805 | 806 | rc, err := iproxy.WSDial(req) 807 | if err != nil { 808 | iproxy.logger.Println("error dialing ws server:", err) 809 | http.Error(w, fmt.Sprintf("error dialing websocket server: %s", err.Error()), http.StatusInternalServerError) 810 | return 811 | } 812 | defer rc.Close() 813 | req.EndDatetime = time.Now() 814 | if err := saveIfExists(req); err != nil { 815 | ErrResponse(w, err) 816 | return 817 | } 818 | 819 | var upgrader = websocket.Upgrader{ 820 | CheckOrigin: func(r *http.Request) bool { // accepts every Origin 821 | return true 822 | }, 823 | } 824 | 825 | lc, err := upgrader.Upgrade(w, r, nil) 826 | if err != nil { 827 | iproxy.logger.Println("error upgrading connection:", err) 828 | http.Error(w, fmt.Sprintf("error upgrading connection: %s", err.Error()), http.StatusInternalServerError) 829 | return 830 | } 831 | defer lc.Close() 832 | 833 | var wg sync.WaitGroup 834 | var reqMtx sync.Mutex // guards the append in addWSMessage. NOTE(review): saveIfExists(req) below is called from both relay goroutines without this lock 835 | addWSMessage := func(req *ProxyRequest, wsm *ProxyWSMessage) { 836 | reqMtx.Lock() 837 | defer reqMtx.Unlock() 838 | req.WSMessages = append(req.WSMessages, wsm) 839 | } 840 | 841 | // Relay messages from the server to the client 842 | wg.Add(1) 843 | go func() { 844 | for { 845 | mtype, msg, err := rc.ReadMessage() 846 | if err != nil { 847 | lc.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")) 848 | iproxy.logger.Println("error with receiving server message:", err) 849 | wg.Done() 850 | return 851 | } 852 | pws, err := NewProxyWSMessage(mtype, msg, ToClient) 853 | if err != nil { 854 | iproxy.logger.Println("error creating ws object:", err.Error()) 855 | continue 856 | } 857 | pws.Timestamp = time.Now() 858 | 859 | if checkScope(req) { 860 | newMsg, mangled, err := mangleWS(req, req.ServerResponse, pws) 861 | if err != nil { 862 | iproxy.logger.Println("error mangling ws:", err) 863 | return // NOTE(review): returns without calling wg.Done(), so wg.Wait() below never returns 864 | } 865 | if mangled { 866 | if newMsg == nil { 867 | continue 868 | } else { 869 | newMsg.Unmangled = pws 870 | pws = newMsg 871 |
pws.Request = nil 872 | } 873 | } 874 | } 875 | 876 | addWSMessage(req, pws) 877 | if err := saveIfExists(req); err != nil { 878 | iproxy.logger.Println("error saving request:", err) 879 | continue // NOTE(review): the message is not forwarded if saving fails 880 | } 881 | lc.WriteMessage(pws.Type, pws.Message) 882 | } 883 | }() 884 | 885 | // Relay messages from the client to the server 886 | wg.Add(1) 887 | go func() { 888 | for { 889 | mtype, msg, err := lc.ReadMessage() 890 | if err != nil { 891 | rc.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")) 892 | iproxy.logger.Println("error with receiving client message:", err) 893 | wg.Done() 894 | return 895 | } 896 | pws, err := NewProxyWSMessage(mtype, msg, ToServer) 897 | if err != nil { 898 | iproxy.logger.Println("error creating ws object:", err.Error()) 899 | continue 900 | } 901 | pws.Timestamp = time.Now() 902 | 903 | if checkScope(req) { 904 | newMsg, mangled, err := mangleWS(req, req.ServerResponse, pws) 905 | if err != nil { 906 | iproxy.logger.Println("error mangling ws:", err) 907 | return // NOTE(review): returns without calling wg.Done(), so wg.Wait() below never returns 908 | } 909 | if mangled { 910 | if newMsg == nil { 911 | continue 912 | } else { 913 | newMsg.Unmangled = pws 914 | pws = newMsg 915 | pws.Request = nil 916 | } 917 | } 918 | } 919 | 920 | addWSMessage(req, pws) 921 | if err := saveIfExists(req); err != nil { 922 | iproxy.logger.Println("error saving request:", err) 923 | continue 924 | } 925 | rc.WriteMessage(pws.Type, pws.Message) 926 | } 927 | }() 928 | wg.Wait() // returns only after both relay goroutines have called wg.Done() 929 | iproxy.logger.Println("Websocket session complete!") 930 | } else { 931 | err := iproxy.SubmitRequest(req) 932 | if err != nil { 933 | http.Error(w, fmt.Sprintf("error submitting request: %s", err.Error()), http.StatusInternalServerError) 934 | return 935 | } 936 | req.EndDatetime = time.Now() 937 | if err := saveIfExists(req); err != nil { 938 | ErrResponse(w, err) 939 | return 940 | } 941 | 942 | if checkScope(req) { 943 | newRsp, mangled, err := mangleResponse(req, req.ServerResponse) 944 | if err != nil { 945 | http.Error(w,
err.Error(), http.StatusInternalServerError) 946 | return 947 | } 948 | if mangled { 949 | if newRsp == nil { // dropped by an interceptor 950 | req.ServerResponse = nil 951 | if err := saveIfExists(req); err != nil { 952 | ErrResponse(w, err) 953 | return 954 | } 955 | BlankResponse(w) 956 | return 957 | } 958 | newRsp.Unmangled = req.ServerResponse 959 | req.ServerResponse = newRsp 960 | if err := saveIfExists(req); err != nil { 961 | ErrResponse(w, err) 962 | return 963 | } 964 | } 965 | } 966 | // Copy the (possibly intercepted) server response to the client 967 | for k, v := range req.ServerResponse.Header { 968 | for _, vv := range v { 969 | w.Header().Add(k, vv) 970 | } 971 | } 972 | w.WriteHeader(req.ServerResponse.StatusCode) 973 | w.Write(req.ServerResponse.BodyBytes()) 974 | return 975 | } 976 | } 977 | // newProxyServer returns an http.Server that dispatches every request to iproxy and logs server errors to logger 978 | func newProxyServer(logger *log.Logger, iproxy *InterceptingProxy) *http.Server { 979 | server := &http.Server{ 980 | Handler: iproxy, 981 | ErrorLog: logger, 982 | } 983 | return server 984 | } 985 | 986 | // globalWatcherShim implements StorageWatcher for a single storage by forwarding each event, tagged with storageId, to every registered GlobalStorageWatcher. NOTE(review): globWatcher.watchers is read here without iproxy.mtx, while GlobalStorageWatch/GlobalStorageEndWatch modify it under that lock. 987 | func (watcher *globalWatcherShim) NewRequestSaved(ms MessageStorage, req *ProxyRequest) { 988 | for _, w := range watcher.globWatcher.watchers { 989 | w.NewRequestSaved(watcher.storageId, ms, req) 990 | } 991 | } 992 | 993 | func (watcher *globalWatcherShim) RequestUpdated(ms MessageStorage, req *ProxyRequest) { 994 | for _, w := range watcher.globWatcher.watchers { 995 | w.RequestUpdated(watcher.storageId, ms, req) 996 | } 997 | } 998 | 999 | func (watcher *globalWatcherShim) RequestDeleted(ms MessageStorage, DbId string) { 1000 | for _, w := range watcher.globWatcher.watchers { 1001 | w.RequestDeleted(watcher.storageId, ms, DbId) 1002 | } 1003 | } 1004 | 1005 | func (watcher *globalWatcherShim) NewResponseSaved(ms MessageStorage, rsp *ProxyResponse) { 1006 | for _, w := range watcher.globWatcher.watchers { 1007 | w.NewResponseSaved(watcher.storageId, ms, rsp) 1008 | } 1009 | } 1010 | 1011 | func (watcher *globalWatcherShim) ResponseUpdated(ms MessageStorage, rsp *ProxyResponse) { 1012 | for _, w :=
range watcher.globWatcher.watchers { 1013 | w.ResponseUpdated(watcher.storageId, ms, rsp) 1014 | } 1015 | } 1016 | 1017 | func (watcher *globalWatcherShim) ResponseDeleted(ms MessageStorage, DbId string) { 1018 | for _, w := range watcher.globWatcher.watchers { 1019 | w.ResponseDeleted(watcher.storageId, ms, DbId) 1020 | } 1021 | } 1022 | 1023 | func (watcher *globalWatcherShim) NewWSMessageSaved(ms MessageStorage, req *ProxyRequest, wsm *ProxyWSMessage) { 1024 | for _, w := range watcher.globWatcher.watchers { 1025 | w.NewWSMessageSaved(watcher.storageId, ms, req, wsm) 1026 | } 1027 | } 1028 | 1029 | func (watcher *globalWatcherShim) WSMessageUpdated(ms MessageStorage, req *ProxyRequest, wsm *ProxyWSMessage) { 1030 | for _, w := range watcher.globWatcher.watchers { 1031 | w.WSMessageUpdated(watcher.storageId, ms, req, wsm) 1032 | } 1033 | } 1034 | 1035 | func (watcher *globalWatcherShim) WSMessageDeleted(ms MessageStorage, DbId string) { 1036 | for _, w := range watcher.globWatcher.watchers { 1037 | w.WSMessageDeleted(watcher.storageId, ms, DbId) 1038 | } 1039 | } 1040 | 1041 | -------------------------------------------------------------------------------- /proxyhttp.go: -------------------------------------------------------------------------------- 1 | package puppy 2 | 3 | /* 4 | Wrappers around http.Request and http.Response to add helper functions needed by the proxy 5 | */ 6 | 7 | import ( 8 | "bufio" 9 | "bytes" 10 | "crypto/tls" 11 | "fmt" 12 | "io" 13 | "io/ioutil" 14 | "net" 15 | "net/http" 16 | "net/url" 17 | "reflect" 18 | "strconv" 19 | "strings" 20 | "time" 21 | 22 | "github.com/deckarep/golang-set" 23 | "github.com/gorilla/websocket" 24 | "golang.org/x/net/proxy" 25 | ) 26 | // ToServer and ToClient are the possible values of ProxyWSMessage.Direction. 27 | const ( 28 | ToServer = iota // message read from the client and sent to the server 29 | ToClient // message read from the server and sent to the client 30 | ) 31 | 32 | // NetDialer is a dialer used to create a net.Conn from a network and address 33 | type NetDialer func(network, addr string) (net.Conn, error) 34 | 35 | // ProxyResponse is an http.Response with additional fields for use
within the proxy 36 | type ProxyResponse struct { 37 | http.Response 38 | bodyBytes []byte 39 | 40 | // Id used to reference this response in its associated MessageStorage. Blank string means it is not saved in any MessageStorage 41 | DbId string 42 | 43 | // If this response was modified by the proxy, Unmangled is the response before it was modified. If the response was not modified, Unmangled is nil. 44 | Unmangled *ProxyResponse 45 | } 46 | 47 | // ProxyRequest is an http.Request with additional fields for use within the proxy 48 | type ProxyRequest struct { 49 | http.Request 50 | 51 | // Host where this request is intended to be sent when submitted 52 | DestHost string 53 | // Port that should be used when this request is submitted 54 | DestPort int 55 | // Whether TLS should be used when this request is submitted 56 | DestUseTLS bool 57 | 58 | // Response received from the server when this request was submitted. If the request does not have any associated response, ServerResponse is nil. 59 | ServerResponse *ProxyResponse 60 | // If the request was the handshake for a websocket session, WSMessages will be a slice of all the messages that were sent over that session. 61 | WSMessages []*ProxyWSMessage 62 | // If the request was modified by the proxy, Unmangled will point to the unmodified version of the request. Otherwise it is nil.
63 | Unmangled *ProxyRequest 64 | 65 | // ID used to reference this request in its associated MessageStorage 66 | DbId string 67 | // The time at which this request was submitted 68 | StartDatetime time.Time 69 | // The time at which the response to this request was received 70 | EndDatetime time.Time 71 | 72 | bodyBytes []byte 73 | tags mapset.Set 74 | 75 | // The dialer that should be used when this request is submitted 76 | NetDial NetDialer 77 | } 78 | 79 | // WSSession is an extension of websocket.Conn to contain a reference to the ProxyRequest used for the websocket handshake 80 | type WSSession struct { 81 | websocket.Conn 82 | 83 | // Request used for handshake 84 | Request *ProxyRequest 85 | } 86 | 87 | // ProxyWSMessage represents one message in a websocket session 88 | type ProxyWSMessage struct { 89 | // The type of websocket message (the message type read from / written to the websocket connection) 90 | Type int 91 | 92 | // The contents of the message 93 | Message []byte 94 | 95 | // The direction of the message. Either ToServer or ToClient 96 | Direction int 97 | // If the message was modified by the proxy, points to the original unmodified message 98 | Unmangled *ProxyWSMessage 99 | // The time at which the message was sent (if sent to the server) or received (if received from the server) 100 | Timestamp time.Time 101 | // The request used for the handshake of the websocket session this message belongs to 102 | Request *ProxyRequest 103 | 104 | // ID used to reference this message in its associated MessageStorage 105 | DbId string 106 | } 107 | 108 | // PerformConnect submits a CONNECT request for destHost:destPort over conn and returns an error unless the reply has status 200. NOTE(review): the error from conn.Write is ignored and rsp.Body is never closed. 109 | func PerformConnect(conn net.Conn, destHost string, destPort int) error { 110 | connStr := []byte(fmt.Sprintf("CONNECT %s:%d HTTP/1.1\r\nHost: %s\r\nProxy-Connection: Keep-Alive\r\n\r\n", destHost, destPort, destHost)) 111 | conn.Write(connStr) 112 | rsp, err := http.ReadResponse(bufio.NewReader(conn), nil) 113 | if err != nil { 114 | return fmt.Errorf("error
performing CONNECT handshake: %s", err.Error()) 115 | } 116 | if rsp.StatusCode != 200 { 117 | return fmt.Errorf("error performing CONNECT handshake") 118 | } 119 | return nil 120 | } 121 | 122 | // NewProxyRequest creates a new proxy request with the given destination. If r is nil, a default "GET /" request with Host set to destHost and a Puppy-Proxy User-Agent is created; otherwise r is written out and re-parsed. Panics if the re-parse fails. 123 | func NewProxyRequest(r *http.Request, destHost string, destPort int, destUseTLS bool) *ProxyRequest { 124 | var retReq *ProxyRequest 125 | if r != nil { 126 | // Write/reread the request to make sure we get all the extra headers Go adds into req.Header 127 | buf := bytes.NewBuffer(make([]byte, 0)) 128 | r.Write(buf) 129 | httpReq2, err := http.ReadRequest(bufio.NewReader(buf)) 130 | if err != nil { 131 | panic(err) 132 | } 133 | 134 | retReq = &ProxyRequest{ 135 | *httpReq2, 136 | destHost, 137 | destPort, 138 | destUseTLS, 139 | nil, 140 | make([]*ProxyWSMessage, 0), 141 | nil, 142 | "", 143 | time.Unix(0, 0), 144 | time.Unix(0, 0), 145 | make([]byte, 0), 146 | mapset.NewSet(), 147 | nil, 148 | } 149 | } else { 150 | newReq, _ := http.NewRequest("GET", "/", nil) // Ignore error since this should be run the same every time and shouldn't error 151 | newReq.Header.Set("User-Agent", "Puppy-Proxy/1.0") 152 | newReq.Host = destHost 153 | retReq = &ProxyRequest{ 154 | *newReq, 155 | destHost, 156 | destPort, 157 | destUseTLS, 158 | nil, 159 | make([]*ProxyWSMessage, 0), 160 | nil, 161 | "", 162 | time.Unix(0, 0), 163 | time.Unix(0, 0), 164 | make([]byte, 0), 165 | mapset.NewSet(), 166 | nil, 167 | } 168 | } 169 | 170 | // Load the body 171 | bodyBuf, _ := ioutil.ReadAll(retReq.Body) 172 | retReq.SetBodyBytes(bodyBuf) 173 | return retReq 174 | } 175 | 176 | // ProxyRequestFromBytes parses a slice of bytes containing a well-formed HTTP request into a ProxyRequest.
Does NOT correct incorrect Content-Length headers 177 | func ProxyRequestFromBytes(b []byte, destHost string, destPort int, destUseTLS bool) (*ProxyRequest, error) { 178 | buf := bytes.NewBuffer(b) 179 | httpReq, err := http.ReadRequest(bufio.NewReader(buf)) 180 | if err != nil { 181 | return nil, err 182 | } 183 | 184 | return NewProxyRequest(httpReq, destHost, destPort, destUseTLS), nil 185 | } 186 | 187 | // NewProxyResponse creates a new ProxyResponse given an http.Response. r is written out (with r.Close temporarily cleared) and re-parsed; panics if the re-parse fails. 188 | func NewProxyResponse(r *http.Response) *ProxyResponse { 189 | // Write/reread the response to make sure we get all the extra headers Go adds into its Header 190 | oldClose := r.Close 191 | r.Close = false 192 | buf := bytes.NewBuffer(make([]byte, 0)) 193 | r.Write(buf) 194 | r.Close = oldClose 195 | httpRsp2, err := http.ReadResponse(bufio.NewReader(buf), nil) 196 | if err != nil { 197 | panic(err) 198 | } 199 | httpRsp2.Close = false 200 | retRsp := &ProxyResponse{ 201 | *httpRsp2, 202 | make([]byte, 0), 203 | "", 204 | nil, 205 | } 206 | 207 | bodyBuf, _ := ioutil.ReadAll(retRsp.Body) 208 | retRsp.SetBodyBytes(bodyBuf) 209 | return retRsp 210 | } 211 | 212 | // ProxyResponseFromBytes parses a ProxyResponse from a slice of bytes containing a well-formed HTTP response.
Does NOT correct incorrect Content-Length headers 213 | func ProxyResponseFromBytes(b []byte) (*ProxyResponse, error) { 214 | buf := bytes.NewBuffer(b) 215 | httpRsp, err := http.ReadResponse(bufio.NewReader(buf), nil) 216 | if err != nil { 217 | return nil, err 218 | } 219 | return NewProxyResponse(httpRsp), nil 220 | } 221 | 222 | // NewProxyWSMessage creates a new ProxyWSMessage given a type, message, and direction. The returned error is always nil. 223 | func NewProxyWSMessage(mtype int, message []byte, direction int) (*ProxyWSMessage, error) { 224 | return &ProxyWSMessage{ 225 | Type: mtype, 226 | Message: message, 227 | Direction: direction, 228 | Unmangled: nil, 229 | Timestamp: time.Unix(0, 0), 230 | DbId: "", 231 | }, nil 232 | } 233 | 234 | // DestScheme returns the scheme used by the request (ws, wss, http, or https) 235 | func (req *ProxyRequest) DestScheme() string { 236 | if req.IsWSUpgrade() { 237 | if req.DestUseTLS { 238 | return "wss" 239 | } else { 240 | return "ws" 241 | } 242 | } else { 243 | if req.DestUseTLS { 244 | return "https" 245 | } else { 246 | return "http" 247 | } 248 | } 249 | } 250 | 251 | // FullURL returns a copy of req.URL with the scheme set from DestScheme and the host (and port, if any) taken from req.Host 252 | func (req *ProxyRequest) FullURL() *url.URL { 253 | var u url.URL 254 | u = *(req.URL) // Copy the original req.URL 255 | u.Host = req.Host 256 | u.Scheme = req.DestScheme() 257 | return &u 258 | } 259 | 260 | // DestURL is the same as req.FullURL() but uses DestHost and DestPort for the host and port of the URL. The port is omitted when it is 443 with TLS or 80 without. 261 | func (req *ProxyRequest) DestURL() *url.URL { 262 | var u url.URL 263 | u = *(req.URL) // Copy the original req.URL 264 | u.Scheme = req.DestScheme() 265 | 266 | if req.DestUseTLS && req.DestPort == 443 || 267 | !req.DestUseTLS && req.DestPort == 80 { 268 | u.Host = req.DestHost 269 | } else { 270 | u.Host = fmt.Sprintf("%s:%d", req.DestHost, req.DestPort) 271 | } 272 | return &u 273 | } 274 | 275 | // Submit submits the request over the given connection.
Does not take into account DestHost, DestPort, or DestUseTLS 276 | func (req *ProxyRequest) Submit(conn net.Conn) error { 277 | return req.submit(conn, false, nil) 278 | } 279 | 280 | // Submit submits the request in proxy form over the given connection for use with an upstream HTTP proxy. Does not take into account DestHost, DestPort, or DestUseTLS 281 | func (req *ProxyRequest) SubmitProxy(conn net.Conn, creds *ProxyCredentials) error { 282 | return req.submit(conn, true, creds) 283 | } 284 | 285 | func (req *ProxyRequest) submit(conn net.Conn, forProxy bool, proxyCreds *ProxyCredentials) error { 286 | // Write the request to the connection 287 | req.StartDatetime = time.Now() 288 | if forProxy { 289 | if req.DestUseTLS { 290 | req.URL.Scheme = "https" 291 | } else { 292 | req.URL.Scheme = "http" 293 | } 294 | req.URL.Opaque = "" 295 | 296 | if err := req.RepeatableProxyWrite(conn, proxyCreds); err != nil { 297 | return err 298 | } 299 | } else { 300 | if err := req.RepeatableWrite(conn); err != nil { 301 | return err 302 | } 303 | } 304 | 305 | // Read a response from the server 306 | httpRsp, err := http.ReadResponse(bufio.NewReader(conn), nil) 307 | if err != nil { 308 | return fmt.Errorf("error reading response: %s", err.Error()) 309 | } 310 | req.EndDatetime = time.Now() 311 | 312 | prsp := NewProxyResponse(httpRsp) 313 | req.ServerResponse = prsp 314 | return nil 315 | } 316 | 317 | // WSDial performs a websocket handshake over the given connection. 
Does not take into account DestHost, DestPort, or DestUseTLS 318 | func (req *ProxyRequest) WSDial(conn net.Conn) (*WSSession, error) { 319 | if !req.IsWSUpgrade() { 320 | return nil, fmt.Errorf("could not start websocket session: request is not a websocket handshake request") 321 | } 322 | 323 | upgradeHeaders := make(http.Header) 324 | for k, v := range req.Header { 325 | for _, vv := range v { 326 | if !(k == "Upgrade" || 327 | k == "Connection" || 328 | k == "Sec-Websocket-Key" || 329 | k == "Sec-Websocket-Version" || 330 | k == "Sec-Websocket-Extensions" || 331 | k == "Sec-Websocket-Protocol") { 332 | upgradeHeaders.Add(k, vv) 333 | } 334 | } 335 | } 336 | 337 | dialer := &websocket.Dialer{} 338 | dialer.NetDial = func(network, address string) (net.Conn, error) { 339 | return conn, nil 340 | } 341 | 342 | wsconn, rsp, err := dialer.Dial(req.DestURL().String(), upgradeHeaders) 343 | if err != nil { 344 | return nil, fmt.Errorf("could not dial WebSocket server: %s", err) 345 | } 346 | req.ServerResponse = NewProxyResponse(rsp) 347 | wsession := &WSSession{ 348 | *wsconn, 349 | req, 350 | } 351 | return wsession, nil 352 | } 353 | 354 | // WSDial dials the target server and performs a websocket handshake over the new connection. Uses destination information from the request. 355 | func WSDial(req *ProxyRequest) (*WSSession, error) { 356 | return wsDial(req, false, "", 0, nil, false) 357 | } 358 | 359 | // WSDialProxy dials the HTTP proxy server, performs a CONNECT handshake to connect to the remote server, then performs a websocket handshake over the new connection. Uses destination information from the request. 360 | func WSDialProxy(req *ProxyRequest, proxyHost string, proxyPort int, creds *ProxyCredentials) (*WSSession, error) { 361 | return wsDial(req, true, proxyHost, proxyPort, creds, false) 362 | } 363 | 364 | // WSDialSOCKSProxy connects to the target host through the SOCKS proxy and performs a websocket handshake over the new connection. 
Uses destination information from the request. 365 | func WSDialSOCKSProxy(req *ProxyRequest, proxyHost string, proxyPort int, creds *ProxyCredentials) (*WSSession, error) { 366 | return wsDial(req, true, proxyHost, proxyPort, creds, true) 367 | } 368 | 369 | func wsDial(req *ProxyRequest, useProxy bool, proxyHost string, proxyPort int, proxyCreds *ProxyCredentials, proxyIsSOCKS bool) (*WSSession, error) { 370 | var conn net.Conn 371 | var dialer NetDialer 372 | var err error 373 | 374 | if req.NetDial != nil { 375 | dialer = req.NetDial 376 | } else { 377 | dialer = net.Dial 378 | } 379 | 380 | if useProxy { 381 | if proxyIsSOCKS { 382 | var socksCreds *proxy.Auth 383 | if proxyCreds != nil { 384 | socksCreds = &proxy.Auth{ 385 | User: proxyCreds.Username, 386 | Password: proxyCreds.Password, 387 | } 388 | } 389 | socksDialer, err := proxy.SOCKS5("tcp", fmt.Sprintf("%s:%d", proxyHost, proxyPort), socksCreds, proxy.Direct) 390 | if err != nil { 391 | return nil, fmt.Errorf("error creating SOCKS dialer: %s", err.Error()) 392 | } 393 | conn, err = socksDialer.Dial("tcp", fmt.Sprintf("%s:%d", req.DestHost, req.DestPort)) 394 | if err != nil { 395 | return nil, fmt.Errorf("error dialing host: %s", err.Error()) 396 | } 397 | defer conn.Close() 398 | } else { 399 | conn, err = dialer("tcp", fmt.Sprintf("%s:%d", proxyHost, proxyPort)) 400 | if err != nil { 401 | return nil, fmt.Errorf("error dialing proxy: %s", err.Error()) 402 | } 403 | 404 | // always perform a CONNECT for websocket regardless of SSL 405 | if err := PerformConnect(conn, req.DestHost, req.DestPort); err != nil { 406 | return nil, err 407 | } 408 | } 409 | } else { 410 | conn, err = dialer("tcp", fmt.Sprintf("%s:%d", req.DestHost, req.DestPort)) 411 | if err != nil { 412 | return nil, fmt.Errorf("error dialing host: %s", err.Error()) 413 | } 414 | } 415 | 416 | if req.DestUseTLS { 417 | tls_conn := tls.Client(conn, &tls.Config{ 418 | InsecureSkipVerify: true, 419 | }) 420 | conn = tls_conn 421 | } 422 | 
423 | return req.WSDial(conn) 424 | } 425 | 426 | // IsWSUpgrade returns whether the request is used to initiate a websocket handshake 427 | func (req *ProxyRequest) IsWSUpgrade() bool { 428 | for k, v := range req.Header { 429 | for _, vv := range v { 430 | if strings.ToLower(k) == "upgrade" && strings.Contains(vv, "websocket") { 431 | return true 432 | } 433 | } 434 | } 435 | return false 436 | } 437 | 438 | // StripProxyHeaders removes headers associated with requests made to a proxy from the request 439 | func (req *ProxyRequest) StripProxyHeaders() { 440 | if !req.IsWSUpgrade() { 441 | req.Header.Del("Connection") 442 | } 443 | req.Header.Del("Accept-Encoding") 444 | req.Header.Del("Proxy-Connection") 445 | req.Header.Del("Proxy-Authenticate") 446 | req.Header.Del("Proxy-Authorization") 447 | } 448 | 449 | // Eq checks whether the request is the same as another request and has the same destination information 450 | func (req *ProxyRequest) Eq(other *ProxyRequest) bool { 451 | if req.StatusLine() != other.StatusLine() || 452 | !reflect.DeepEqual(req.Header, other.Header) || 453 | bytes.Compare(req.BodyBytes(), other.BodyBytes()) != 0 || 454 | req.DestHost != other.DestHost || 455 | req.DestPort != other.DestPort || 456 | req.DestUseTLS != other.DestUseTLS { 457 | return false 458 | } 459 | 460 | return true 461 | } 462 | 463 | // Clone returns a request with the same contents and destination information as the original 464 | func (req *ProxyRequest) Clone() *ProxyRequest { 465 | buf := bytes.NewBuffer(make([]byte, 0)) 466 | req.RepeatableWrite(buf) 467 | newReq, err := ProxyRequestFromBytes(buf.Bytes(), req.DestHost, req.DestPort, req.DestUseTLS) 468 | if err != nil { 469 | panic(err) 470 | } 471 | newReq.DestHost = req.DestHost 472 | newReq.DestPort = req.DestPort 473 | newReq.DestUseTLS = req.DestUseTLS 474 | newReq.Header = copyHeader(req.Header) 475 | return newReq 476 | } 477 | 478 | // DeepClone returns a request with the same contents, destination, and 
storage information information as the original along with a deep clone of the associated response, the unmangled version of the request, and any websocket messages 479 | func (req *ProxyRequest) DeepClone() *ProxyRequest { 480 | // Returns a request with the same request, response, and associated websocket messages 481 | newReq := req.Clone() 482 | newReq.DbId = req.DbId 483 | 484 | if req.Unmangled != nil { 485 | newReq.Unmangled = req.Unmangled.DeepClone() 486 | } 487 | 488 | if req.ServerResponse != nil { 489 | newReq.ServerResponse = req.ServerResponse.DeepClone() 490 | } 491 | 492 | for _, wsm := range req.WSMessages { 493 | newReq.WSMessages = append(newReq.WSMessages, wsm.DeepClone()) 494 | } 495 | 496 | return newReq 497 | } 498 | 499 | func (req *ProxyRequest) resetBodyReader() { 500 | // yes I know this method isn't the most efficient, I'll fix it if it causes problems later 501 | req.Body = ioutil.NopCloser(bytes.NewBuffer(req.BodyBytes())) 502 | } 503 | 504 | // RepeatableWrite is the same as http.Request.Write except that it can be safely called multiple times 505 | func (req *ProxyRequest) RepeatableWrite(w io.Writer) error { 506 | defer req.resetBodyReader() 507 | return req.Write(w) 508 | } 509 | 510 | // RepeatableWrite is the same as http.Request.ProxyWrite except that it can be safely called multiple times 511 | func (req *ProxyRequest) RepeatableProxyWrite(w io.Writer, proxyCreds *ProxyCredentials) error { 512 | defer req.resetBodyReader() 513 | if proxyCreds != nil { 514 | authHeader := proxyCreds.SerializeHeader() 515 | req.Header.Set("Proxy-Authorization", authHeader) 516 | defer func() { req.Header.Del("Proxy-Authorization") }() 517 | } 518 | return req.WriteProxy(w) 519 | } 520 | 521 | // BodyBytes returns the bytes of the request body 522 | func (req *ProxyRequest) BodyBytes() []byte { 523 | return DuplicateBytes(req.bodyBytes) 524 | } 525 | 526 | // SetBodyBytes sets the bytes of the request body and updates the Content-Length header 527 
| func (req *ProxyRequest) SetBodyBytes(bs []byte) { 528 | req.bodyBytes = bs 529 | req.resetBodyReader() 530 | 531 | // Parse the form if we can, ignore errors 532 | req.ParseMultipartForm(1024 * 1024 * 1024) // 1GB for no good reason 533 | req.ParseForm() 534 | req.resetBodyReader() 535 | req.Header.Set("Content-Length", strconv.Itoa(len(bs))) 536 | } 537 | 538 | // FullMessage returns a slice of bytes containing the full HTTP message for the request 539 | func (req *ProxyRequest) FullMessage() []byte { 540 | buf := bytes.NewBuffer(make([]byte, 0)) 541 | req.RepeatableWrite(buf) 542 | return buf.Bytes() 543 | } 544 | 545 | // PostParameters attempts to parse POST parameters from the body of the request 546 | func (req *ProxyRequest) PostParameters() (url.Values, error) { 547 | vals, err := url.ParseQuery(string(req.BodyBytes())) 548 | if err != nil { 549 | return nil, err 550 | } 551 | return vals, nil 552 | } 553 | 554 | // SetPostParameter sets the value of a post parameter in the message body. If the body does not contain well-formed data, it is deleted replaced with a well-formed body containing only the new parameter 555 | func (req *ProxyRequest) SetPostParameter(key string, value string) { 556 | req.PostForm.Set(key, value) 557 | req.SetBodyBytes([]byte(req.PostForm.Encode())) 558 | } 559 | 560 | // AddPostParameter adds a post parameter to the body of the request even if a duplicate exists. If the body does not contain well-formed data, it is deleted replaced with a well-formed body containing only the new parameter 561 | func (req *ProxyRequest) AddPostParameter(key string, value string) { 562 | req.PostForm.Add(key, value) 563 | req.SetBodyBytes([]byte(req.PostForm.Encode())) 564 | } 565 | 566 | // DeletePostParameter removes a parameter from the body of the request. 
If the body does not contain well-formed data, it is deleted replaced with a well-formed body containing only the new parameter 567 | func (req *ProxyRequest) DeletePostParameter(key string) { 568 | req.PostForm.Del(key) 569 | req.SetBodyBytes([]byte(req.PostForm.Encode())) 570 | } 571 | 572 | // SetURLParameter sets the value of a URL parameter and updates ProxyRequest.URL 573 | func (req *ProxyRequest) SetURLParameter(key string, value string) { 574 | q := req.URL.Query() 575 | q.Set(key, value) 576 | req.URL.RawQuery = q.Encode() 577 | req.ParseForm() 578 | } 579 | 580 | // URLParameters returns the values of the request's URL parameters 581 | func (req *ProxyRequest) URLParameters() url.Values { 582 | vals := req.URL.Query() 583 | return vals 584 | } 585 | 586 | // AddURLParameter adds a URL parameter to the request ignoring duplicates 587 | func (req *ProxyRequest) AddURLParameter(key string, value string) { 588 | q := req.URL.Query() 589 | q.Add(key, value) 590 | req.URL.RawQuery = q.Encode() 591 | req.ParseForm() 592 | } 593 | 594 | // DeleteURLParameter removes a URL parameter from the request 595 | func (req *ProxyRequest) DeleteURLParameter(key string) { 596 | q := req.URL.Query() 597 | q.Del(key) 598 | req.URL.RawQuery = q.Encode() 599 | req.ParseForm() 600 | } 601 | 602 | // AddTag adds a tag to the request 603 | func (req *ProxyRequest) AddTag(tag string) { 604 | req.tags.Add(tag) 605 | } 606 | 607 | // CheckTag returns whether the request has a given tag 608 | func (req *ProxyRequest) CheckTag(tag string) bool { 609 | return req.tags.Contains(tag) 610 | } 611 | 612 | // RemoveTag removes a tag from the request 613 | func (req *ProxyRequest) RemoveTag(tag string) { 614 | req.tags.Remove(tag) 615 | } 616 | 617 | // ClearTag removes all of the tags associated with the request 618 | func (req *ProxyRequest) ClearTags() { 619 | req.tags.Clear() 620 | } 621 | 622 | // Tags returns a slice containing all of the tags associated with the request 623 | func 
(req *ProxyRequest) Tags() []string { 624 | items := req.tags.ToSlice() 625 | retslice := make([]string, 0) 626 | for _, item := range items { 627 | str, ok := item.(string) 628 | if ok { 629 | retslice = append(retslice, str) 630 | } 631 | } 632 | return retslice 633 | } 634 | 635 | // HTTPPath returns the path of the associated with the request 636 | func (req *ProxyRequest) HTTPPath() string { 637 | // The path used in the http request 638 | u := *req.URL 639 | u.Scheme = "" 640 | u.Host = "" 641 | u.Opaque = "" 642 | u.User = nil 643 | return u.String() 644 | } 645 | 646 | // StatusLine returns the status line associated with the request 647 | func (req *ProxyRequest) StatusLine() string { 648 | return fmt.Sprintf("%s %s %s", req.Method, req.HTTPPath(), req.Proto) 649 | } 650 | 651 | // HeaderSection returns the header section of the request without the additional \r\n at the end 652 | func (req *ProxyRequest) HeaderSection() string { 653 | retStr := req.StatusLine() 654 | retStr += "\r\n" 655 | for k, vs := range req.Header { 656 | for _, v := range vs { 657 | retStr += fmt.Sprintf("%s: %s\r\n", k, v) 658 | } 659 | } 660 | return retStr 661 | } 662 | 663 | func (rsp *ProxyResponse) resetBodyReader() { 664 | // yes I know this method isn't the most efficient, I'll fix it if it causes problems later 665 | rsp.Body = ioutil.NopCloser(bytes.NewBuffer(rsp.BodyBytes())) 666 | } 667 | 668 | // RepeatableWrite is the same as http.Response.Write except that it can safely be called multiple times 669 | func (rsp *ProxyResponse) RepeatableWrite(w io.Writer) error { 670 | defer rsp.resetBodyReader() 671 | return rsp.Write(w) 672 | } 673 | 674 | // BodyBytes returns the bytes contained in the body of the response 675 | func (rsp *ProxyResponse) BodyBytes() []byte { 676 | return DuplicateBytes(rsp.bodyBytes) 677 | } 678 | 679 | // SetBodyBytes sets the bytes in the body of the response and updates the Content-Length header 680 | func (rsp *ProxyResponse) SetBodyBytes(bs 
[]byte) { 681 | rsp.bodyBytes = bs 682 | rsp.resetBodyReader() 683 | rsp.Header.Set("Content-Length", strconv.Itoa(len(bs))) 684 | } 685 | 686 | // Clone returns a response with the same status line, headers, and body as the response 687 | func (rsp *ProxyResponse) Clone() *ProxyResponse { 688 | buf := bytes.NewBuffer(make([]byte, 0)) 689 | rsp.RepeatableWrite(buf) 690 | newRsp, err := ProxyResponseFromBytes(buf.Bytes()) 691 | if err != nil { 692 | panic(err) 693 | } 694 | return newRsp 695 | } 696 | 697 | // DeepClone returns a response with the same status line, headers, and body as the original response along with a deep clone of its unmangled version if it exists 698 | func (rsp *ProxyResponse) DeepClone() *ProxyResponse { 699 | newRsp := rsp.Clone() 700 | newRsp.DbId = rsp.DbId 701 | if rsp.Unmangled != nil { 702 | newRsp.Unmangled = rsp.Unmangled.DeepClone() 703 | } 704 | return newRsp 705 | } 706 | 707 | // Eq returns whether the response has the same contents as another response 708 | func (rsp *ProxyResponse) Eq(other *ProxyResponse) bool { 709 | if rsp.StatusLine() != other.StatusLine() || 710 | !reflect.DeepEqual(rsp.Header, other.Header) || 711 | bytes.Compare(rsp.BodyBytes(), other.BodyBytes()) != 0 { 712 | return false 713 | } 714 | return true 715 | } 716 | 717 | // FullMessage returns the full HTTP message of the response 718 | func (rsp *ProxyResponse) FullMessage() []byte { 719 | buf := bytes.NewBuffer(make([]byte, 0)) 720 | rsp.RepeatableWrite(buf) 721 | return buf.Bytes() 722 | } 723 | 724 | // Returns the status text to be used in the http request 725 | func (rsp *ProxyResponse) HTTPStatus() string { 726 | // The status text to be used in the http request. 
Relies on being the same implementation as http.Response 727 | text := rsp.Status 728 | if text == "" { 729 | text = http.StatusText(rsp.StatusCode) 730 | if text == "" { 731 | text = "status code " + strconv.Itoa(rsp.StatusCode) 732 | } 733 | } else { 734 | // Just to reduce stutter, if user set rsp.Status to "200 OK" and StatusCode to 200. 735 | // Not important. 736 | text = strings.TrimPrefix(text, strconv.Itoa(rsp.StatusCode)+" ") 737 | } 738 | return text 739 | } 740 | 741 | // StatusLine returns the status line of the response 742 | func (rsp *ProxyResponse) StatusLine() string { 743 | // Status line, stolen from net/http/response.go 744 | return fmt.Sprintf("HTTP/%d.%d %03d %s", rsp.ProtoMajor, rsp.ProtoMinor, rsp.StatusCode, rsp.HTTPStatus()) 745 | } 746 | 747 | // HeaderSection returns the header section of the response (without the extra trailing \r\n) 748 | func (rsp *ProxyResponse) HeaderSection() string { 749 | retStr := rsp.StatusLine() 750 | retStr += "\r\n" 751 | for k, vs := range rsp.Header { 752 | for _, v := range vs { 753 | retStr += fmt.Sprintf("%s: %s\r\n", k, v) 754 | } 755 | } 756 | return retStr 757 | } 758 | 759 | func (msg *ProxyWSMessage) String() string { 760 | var dirStr string 761 | if msg.Direction == ToClient { 762 | dirStr = "ToClient" 763 | } else { 764 | dirStr = "ToServer" 765 | } 766 | return fmt.Sprintf("{WS Message msg=\"%s\", type=%d, dir=%s}", string(msg.Message), msg.Type, dirStr) 767 | } 768 | 769 | // Clone returns a copy of the original message. 
It will have the same type, message, direction, timestamp, and request 770 | func (msg *ProxyWSMessage) Clone() *ProxyWSMessage { 771 | var retMsg ProxyWSMessage 772 | retMsg.Type = msg.Type 773 | retMsg.Message = msg.Message 774 | retMsg.Direction = msg.Direction 775 | retMsg.Timestamp = msg.Timestamp 776 | retMsg.Request = msg.Request 777 | return &retMsg 778 | } 779 | 780 | // DeepClone returns a clone of the original message and a deep clone of the unmangled version if it exists 781 | func (msg *ProxyWSMessage) DeepClone() *ProxyWSMessage { 782 | retMsg := msg.Clone() 783 | retMsg.DbId = msg.DbId 784 | if msg.Unmangled != nil { 785 | retMsg.Unmangled = msg.Unmangled.DeepClone() 786 | } 787 | return retMsg 788 | } 789 | 790 | // Eq checks if the message has the same type, direction, and message as another message 791 | func (msg *ProxyWSMessage) Eq(other *ProxyWSMessage) bool { 792 | if msg.Type != other.Type || 793 | msg.Direction != other.Direction || 794 | bytes.Compare(msg.Message, other.Message) != 0 { 795 | return false 796 | } 797 | return true 798 | } 799 | 800 | func copyHeader(hd http.Header) http.Header { 801 | var ret http.Header = make(http.Header) 802 | for k, vs := range hd { 803 | for _, v := range vs { 804 | ret.Add(k, v) 805 | } 806 | } 807 | return ret 808 | } 809 | 810 | func submitRequest(req *ProxyRequest, useProxy bool, proxyHost string, 811 | proxyPort int, proxyCreds *ProxyCredentials, proxyIsSOCKS bool) error { 812 | var dialer NetDialer = req.NetDial 813 | if dialer == nil { 814 | dialer = net.Dial 815 | } 816 | 817 | var conn net.Conn 818 | var err error 819 | var proxyFormat bool = false 820 | if useProxy { 821 | if proxyIsSOCKS { 822 | var socksCreds *proxy.Auth 823 | if proxyCreds != nil { 824 | socksCreds = &proxy.Auth{ 825 | User: proxyCreds.Username, 826 | Password: proxyCreds.Password, 827 | } 828 | } 829 | socksDialer, err := proxy.SOCKS5("tcp", fmt.Sprintf("%s:%d", proxyHost, proxyPort), socksCreds, proxy.Direct) 830 | if err 
!= nil { 831 | return fmt.Errorf("error creating SOCKS dialer: %s", err.Error()) 832 | } 833 | conn, err = socksDialer.Dial("tcp", fmt.Sprintf("%s:%d", req.DestHost, req.DestPort)) 834 | if err != nil { 835 | return fmt.Errorf("error dialing host: %s", err.Error()) 836 | } 837 | defer conn.Close() 838 | } else { 839 | conn, err = dialer("tcp", fmt.Sprintf("%s:%d", proxyHost, proxyPort)) 840 | if err != nil { 841 | return fmt.Errorf("error dialing proxy: %s", err.Error()) 842 | } 843 | defer conn.Close() 844 | if req.DestUseTLS { 845 | if err := PerformConnect(conn, req.DestHost, req.DestPort); err != nil { 846 | return err 847 | } 848 | proxyFormat = false 849 | } else { 850 | proxyFormat = true 851 | } 852 | } 853 | } else { 854 | conn, err = dialer("tcp", fmt.Sprintf("%s:%d", req.DestHost, req.DestPort)) 855 | if err != nil { 856 | return fmt.Errorf("error dialing host: %s", err.Error()) 857 | } 858 | defer conn.Close() 859 | } 860 | 861 | if req.DestUseTLS { 862 | tls_conn := tls.Client(conn, &tls.Config{ 863 | InsecureSkipVerify: true, 864 | }) 865 | conn = tls_conn 866 | } 867 | 868 | if proxyFormat { 869 | return req.SubmitProxy(conn, proxyCreds) 870 | } else { 871 | return req.Submit(conn) 872 | } 873 | } 874 | 875 | // SubmitRequest opens a connection to the request's DestHost:DestPort, using TLS if DestUseTLS is set, submits the request, and sets req.Response with the response when a response is received 876 | func SubmitRequest(req *ProxyRequest) error { 877 | return submitRequest(req, false, "", 0, nil, false) 878 | } 879 | 880 | // SubmitRequestProxy connects to the given HTTP proxy, performs neccessary handshakes, and submits the request to its destination. 
req.Response will be set once a response is received 881 | func SubmitRequestProxy(req *ProxyRequest, proxyHost string, proxyPort int, creds *ProxyCredentials) error { 882 | return submitRequest(req, true, proxyHost, proxyPort, creds, false) 883 | } 884 | 885 | // SubmitRequestProxy connects to the given SOCKS proxy, performs neccessary handshakes, and submits the request to its destination. req.Response will be set once a response is received 886 | func SubmitRequestSOCKSProxy(req *ProxyRequest, proxyHost string, proxyPort int, creds *ProxyCredentials) error { 887 | return submitRequest(req, true, proxyHost, proxyPort, creds, true) 888 | } 889 | -------------------------------------------------------------------------------- /proxyhttp_test.go: -------------------------------------------------------------------------------- 1 | package puppy 2 | 3 | import ( 4 | "net/url" 5 | "runtime" 6 | "testing" 7 | // "bytes" 8 | // "net/http" 9 | // "bufio" 10 | // "os" 11 | ) 12 | 13 | type statusLiner interface { 14 | StatusLine() string 15 | } 16 | 17 | func checkStr(t *testing.T, result, expected string) { 18 | if result != expected { 19 | _, f, ln, _ := runtime.Caller(1) 20 | t.Errorf("Failed search test at %s:%d. 
Expected '%s', got '%s'", f, ln, expected, result) 21 | } 22 | } 23 | 24 | func checkStatusline(t *testing.T, msg statusLiner, expected string) { 25 | result := msg.StatusLine() 26 | checkStr(t, expected, result) 27 | } 28 | 29 | func TestStatusline(t *testing.T) { 30 | req := testReq() 31 | checkStr(t, req.StatusLine(), "POST /?foo=bar HTTP/1.1") 32 | 33 | req.Method = "GET" 34 | checkStr(t, req.StatusLine(), "GET /?foo=bar HTTP/1.1") 35 | 36 | req.URL.Fragment = "foofrag" 37 | checkStr(t, req.StatusLine(), "GET /?foo=bar#foofrag HTTP/1.1") 38 | 39 | req.URL.User = url.UserPassword("foo", "bar") 40 | checkStr(t, req.StatusLine(), "GET /?foo=bar#foofrag HTTP/1.1") 41 | 42 | req.URL.Scheme = "http" 43 | checkStr(t, req.StatusLine(), "GET /?foo=bar#foofrag HTTP/1.1") 44 | 45 | req.URL.Opaque = "foobaropaque" 46 | checkStr(t, req.StatusLine(), "GET /?foo=bar#foofrag HTTP/1.1") 47 | req.URL.Opaque = "" 48 | 49 | req.URL.Host = "foobarhost" 50 | checkStr(t, req.StatusLine(), "GET /?foo=bar#foofrag HTTP/1.1") 51 | 52 | // rsp.Status is actually "200 OK" but the "200 " gets stripped from the front 53 | rsp := req.ServerResponse 54 | checkStr(t, rsp.StatusLine(), "HTTP/1.1 200 OK") 55 | 56 | rsp.StatusCode = 404 57 | checkStr(t, rsp.StatusLine(), "HTTP/1.1 404 200 OK") 58 | 59 | rsp.Status = "is not there plz" 60 | checkStr(t, rsp.StatusLine(), "HTTP/1.1 404 is not there plz") 61 | 62 | // Same as with "200 OK" 63 | rsp.Status = "404 is not there plz" 64 | checkStr(t, rsp.StatusLine(), "HTTP/1.1 404 is not there plz") 65 | } 66 | 67 | func TestEq(t *testing.T) { 68 | req1 := testReq() 69 | req2 := testReq() 70 | 71 | // Requests 72 | 73 | if !req1.Eq(req2) { 74 | t.Error("failed eq") 75 | } 76 | 77 | if !req2.Eq(req1) { 78 | t.Error("failed eq") 79 | } 80 | 81 | req1.Header = map[string][]string{ 82 | "Foo": []string{"Bar", "Baz"}, 83 | "Foo2": []string{"Bar2", "Baz2"}, 84 | "Cookie": []string{"cookie=cocks"}, 85 | } 86 | req2.Header = map[string][]string{ 87 | "Foo": 
[]string{"Bar", "Baz"}, 88 | "Foo2": []string{"Bar2", "Baz2"}, 89 | "Cookie": []string{"cookie=cocks"}, 90 | } 91 | 92 | if !req1.Eq(req2) { 93 | t.Error("failed eq") 94 | } 95 | 96 | req2.Header = map[string][]string{ 97 | "Foo": []string{"Baz", "Bar"}, 98 | "Foo2": []string{"Bar2", "Baz2"}, 99 | "Cookie": []string{"cookie=cocks"}, 100 | } 101 | if req1.Eq(req2) { 102 | t.Error("failed eq") 103 | } 104 | 105 | req2.Header = map[string][]string{ 106 | "Foo": []string{"Bar", "Baz"}, 107 | "Foo2": []string{"Bar2", "Baz2"}, 108 | "Cookie": []string{"cookiee=cocks"}, 109 | } 110 | if req1.Eq(req2) { 111 | t.Error("failed eq") 112 | } 113 | 114 | req2 = testReq() 115 | req2.URL.Host = "foobar" 116 | if req1.Eq(req2) { 117 | t.Error("failed eq") 118 | } 119 | req2 = testReq() 120 | 121 | // Responses 122 | 123 | if !req1.ServerResponse.Eq(req2.ServerResponse) { 124 | t.Error("failed eq") 125 | } 126 | 127 | if !req2.ServerResponse.Eq(req1.ServerResponse) { 128 | t.Error("failed eq") 129 | } 130 | 131 | req2.ServerResponse.StatusCode = 404 132 | if req1.ServerResponse.Eq(req2.ServerResponse) { 133 | t.Error("failed eq") 134 | } 135 | 136 | } 137 | 138 | func TestDeepClone(t *testing.T) { 139 | req1 := testReq() 140 | req2 := req1.DeepClone() 141 | 142 | if !req1.Eq(req2) { 143 | t.Errorf("cloned request does not match original.\nExpected:\n%s\n-----\nGot:\n%s\n-----", 144 | string(req1.FullMessage()), string(req2.FullMessage())) 145 | } 146 | 147 | if !req1.ServerResponse.Eq(req2.ServerResponse) { 148 | t.Errorf("cloned response does not match original.\nExpected:\n%s\n-----\nGot:\n%s\n-----", 149 | string(req1.ServerResponse.FullMessage()), string(req2.ServerResponse.FullMessage())) 150 | } 151 | 152 | rsp1 := req1.ServerResponse.Clone() 153 | rsp1.Status = "foobarbaz" 154 | rsp2 := rsp1.Clone() 155 | if !rsp1.Eq(rsp2) { 156 | t.Errorf("cloned response does not match original.\nExpected:\n%s\n-----\nGot:\n%s\n-----", 157 | string(rsp1.FullMessage()), 
string(rsp2.FullMessage())) 158 | } 159 | 160 | rsp1 = req1.ServerResponse.Clone() 161 | rsp1.ProtoMinor = 7 162 | rsp2 = rsp1.Clone() 163 | if !rsp1.Eq(rsp2) { 164 | t.Errorf("cloned response does not match original.\nExpected:\n%s\n-----\nGot:\n%s\n-----", 165 | string(rsp1.FullMessage()), string(rsp2.FullMessage())) 166 | } 167 | 168 | rsp1 = req1.ServerResponse.Clone() 169 | rsp1.StatusCode = 234 170 | rsp2 = rsp1.Clone() 171 | if !rsp1.Eq(rsp2) { 172 | t.Errorf("cloned response does not match original.\nExpected:\n%s\n-----\nGot:\n%s\n-----", 173 | string(rsp1.FullMessage()), string(rsp2.FullMessage())) 174 | } 175 | } 176 | 177 | // func TestFromBytes(t *testing.T) { 178 | // rsp, err := ProxyResponseFromBytes([]byte("HTTP/1.1 200 OK\r\nFoo: Bar\r\n\r\nAAAA")) 179 | // if err != nil { 180 | // panic(err) 181 | // } 182 | // checkStr(t, string(rsp.BodyBytes()), "AAAA") 183 | // checkStr(t, string(rsp.Header.Get("Content-Length")[0]), "4") 184 | 185 | // //rspbytes := []byte("HTTP/1.0 200 OK\r\nServer: BaseHTTP/0.3 Python/2.7.11\r\nDate: Fri, 10 Mar 2017 18:21:27 GMT\r\n\r\nCLIENT VALUES:\nclient_address=('127.0.0.1', 62069) (1.0.0.127.in-addr.arpa)\ncommand=GET\npath=/?foo=foobar\nreal path=/\nquery=foo=foobar\nrequest_version=HTTP/1.1\n\nSERVER VALUES:\nserver_version=BaseHTTP/0.3\nsys_version=Python/2.7.11\nprotocol_version=HTTP/1.0") 186 | // rspbytes := []byte("HTTP/1.0 200 OK\r\n\r\nAAAA") 187 | // buf := bytes.NewBuffer(rspbytes) 188 | // httpRsp, err := http.ReadResponse(bufio.NewReader(buf), nil) 189 | // httpRsp.Close = false 190 | // //rsp2 := NewProxyResponse(httpRsp) 191 | // buf2 := bytes.NewBuffer(make([]byte, 0)) 192 | // httpRsp.Write(buf2) 193 | // httpRsp2, err := http.ReadResponse(bufio.NewReader(buf2), nil) 194 | // // fmt.Println(string(rsp2.FullMessage())) 195 | // // fmt.Println(rsp2.Header) 196 | // // if len(rsp2.Header["Connection"]) > 1 { 197 | // // t.Errorf("too many connection headers") 198 | // // } 199 | // } 200 | 
-------------------------------------------------------------------------------- /proxylistener.go: -------------------------------------------------------------------------------- 1 | package puppy 2 | 3 | import ( 4 | "bufio" 5 | "bytes" 6 | "crypto/tls" 7 | "fmt" 8 | "io/ioutil" 9 | "log" 10 | "net" 11 | "net/http" 12 | "strconv" 13 | "strings" 14 | "sync" 15 | "time" 16 | 17 | "github.com/deckarep/golang-set" 18 | ) 19 | 20 | const ( 21 | ProxyStopped = iota 22 | ProxyStarting 23 | ProxyRunning 24 | ) 25 | 26 | var getNextConnId = IdCounter() 27 | var getNextListenerId = IdCounter() 28 | 29 | type internalAddr struct{} 30 | 31 | func (internalAddr) Network() string { 32 | return "" 33 | } 34 | 35 | func (internalAddr) String() string { 36 | return "" 37 | } 38 | 39 | /* 40 | ProxyConn which is the same as a net.Conn but implements Peek() and variales to store target host data 41 | */ 42 | type ProxyConn interface { 43 | net.Conn 44 | 45 | Id() int 46 | Logger() *log.Logger 47 | 48 | // Set the CA certificate to be used to sign TLS connections 49 | SetCACertificate(*tls.Certificate) 50 | 51 | // If the connection tries to start TLS, attempt to strip it so that further reads will get the decrypted text, otherwise it will just pass the plaintext 52 | StartMaybeTLS(hostname string) (bool, error) 53 | 54 | // Have all requests produced by this connection have the given destination information. 
Removes the need for requests generated by this connection to be aware they are being submitted through a proxy 55 | SetTransparentMode(destHost string, destPort int, useTLS bool) 56 | 57 | // End transparent mode 58 | EndTransparentMode() 59 | } 60 | 61 | type proxyAddr struct { 62 | Host string 63 | Port int // can probably do a uint16 or something but whatever 64 | UseTLS bool 65 | } 66 | 67 | type proxyConn struct { 68 | Addr *proxyAddr 69 | logger *log.Logger 70 | id int 71 | conn net.Conn // Wrapped connection 72 | readReq *http.Request // A replaced request 73 | caCert *tls.Certificate 74 | mtx sync.Mutex 75 | 76 | transparentMode bool 77 | } 78 | 79 | // Encode the destination information to be stored in the remote address 80 | func EncodeRemoteAddr(host string, port int, useTLS bool) string { 81 | var tlsInt int 82 | if useTLS { 83 | tlsInt = 1 84 | } else { 85 | tlsInt = 0 86 | } 87 | return fmt.Sprintf("%s/%d/%d", host, port, tlsInt) 88 | } 89 | 90 | // Decode destination information from a remote address 91 | func DecodeRemoteAddr(addrStr string) (host string, port int, useTLS bool, err error) { 92 | parts := strings.Split(addrStr, "/") 93 | if len(parts) != 3 { 94 | err = fmt.Errorf("Error parsing addrStr: %s", addrStr) 95 | return 96 | } 97 | 98 | host = parts[0] 99 | 100 | port, err = strconv.Atoi(parts[1]) 101 | if err != nil { 102 | return 103 | } 104 | 105 | useTLSInt, err := strconv.Atoi(parts[2]) 106 | if err != nil { 107 | return 108 | } 109 | 110 | if useTLSInt == 0 { 111 | useTLS = false 112 | } else { 113 | useTLS = true 114 | } 115 | 116 | return 117 | } 118 | 119 | func (a *proxyAddr) Network() string { 120 | return EncodeRemoteAddr(a.Host, a.Port, a.UseTLS) 121 | } 122 | 123 | func (a *proxyAddr) String() string { 124 | return EncodeRemoteAddr(a.Host, a.Port, a.UseTLS) 125 | } 126 | 127 | //// bufferedConn and wrappers 128 | type bufferedConn struct { 129 | reader *bufio.Reader 130 | net.Conn // Embed conn 131 | } 132 | 133 | func (c 
bufferedConn) Peek(n int) ([]byte, error) { 134 | return c.reader.Peek(n) 135 | } 136 | 137 | func (c bufferedConn) Read(p []byte) (int, error) { 138 | return c.reader.Read(p) 139 | } 140 | 141 | //// Implement net.Conn 142 | 143 | func (c *proxyConn) Read(b []byte) (n int, err error) { 144 | if c.readReq != nil { 145 | buf := new(bytes.Buffer) 146 | c.readReq.Write(buf) 147 | s := buf.String() 148 | n = 0 149 | for n = 0; n < len(b) && n < len(s); n++ { 150 | b[n] = s[n] 151 | } 152 | c.readReq = nil 153 | return n, nil 154 | } 155 | if c.conn == nil { 156 | return 0, fmt.Errorf("ProxyConn %d does not have an active connection", c.Id()) 157 | } 158 | return c.conn.Read(b) 159 | } 160 | 161 | func (c *proxyConn) Write(b []byte) (n int, err error) { 162 | return c.conn.Write(b) 163 | } 164 | 165 | func (c *proxyConn) Close() error { 166 | return c.conn.Close() 167 | } 168 | 169 | func (c *proxyConn) SetDeadline(t time.Time) error { 170 | return c.conn.SetDeadline(t) 171 | } 172 | 173 | func (c *proxyConn) SetReadDeadline(t time.Time) error { 174 | return c.conn.SetReadDeadline(t) 175 | } 176 | 177 | func (c *proxyConn) SetWriteDeadline(t time.Time) error { 178 | return c.conn.SetWriteDeadline(t) 179 | } 180 | 181 | func (c *proxyConn) LocalAddr() net.Addr { 182 | return c.conn.LocalAddr() 183 | } 184 | 185 | func (c *proxyConn) RemoteAddr() net.Addr { 186 | // RemoteAddr encodes the destination server for this connection 187 | return c.Addr 188 | } 189 | 190 | //// Implement ProxyConn 191 | 192 | func (pconn *proxyConn) Id() int { 193 | pconn.mtx.Lock() 194 | defer pconn.mtx.Unlock() 195 | 196 | return pconn.id 197 | } 198 | 199 | func (pconn *proxyConn) Logger() *log.Logger { 200 | pconn.mtx.Lock() 201 | defer pconn.mtx.Unlock() 202 | 203 | return pconn.logger 204 | } 205 | 206 | func (pconn *proxyConn) SetCACertificate(cert *tls.Certificate) { 207 | pconn.mtx.Lock() 208 | defer pconn.mtx.Unlock() 209 | 210 | pconn.caCert = cert 211 | } 212 | 213 | func (pconn 
*proxyConn) StartMaybeTLS(hostname string) (bool, error) { 214 | // Prepares to start doing TLS if the client starts. Returns whether TLS was started 215 | 216 | // Wrap the ProxyConn's net.Conn in a bufferedConn 217 | pconn.mtx.Lock() 218 | defer pconn.mtx.Unlock() 219 | 220 | bufConn := bufferedConn{bufio.NewReader(pconn.conn), pconn.conn} 221 | usingTLS := false 222 | 223 | // Guess if we're doing TLS 224 | byte, err := bufConn.Peek(1) 225 | if err != nil { 226 | return false, err 227 | } 228 | if byte[0] == '\x16' { 229 | usingTLS = true 230 | } 231 | 232 | if usingTLS { 233 | if err != nil { 234 | return false, err 235 | } 236 | 237 | cert, err := signHost(*pconn.caCert, []string{hostname}) 238 | if err != nil { 239 | return false, err 240 | } 241 | 242 | config := &tls.Config{ 243 | InsecureSkipVerify: true, 244 | Certificates: []tls.Certificate{cert}, 245 | } 246 | tlsConn := tls.Server(bufConn, config) 247 | pconn.conn = tlsConn 248 | return true, nil 249 | } else { 250 | pconn.conn = bufConn 251 | return false, nil 252 | } 253 | } 254 | 255 | func (pconn *proxyConn) SetTransparentMode(destHost string, destPort int, useTLS bool) { 256 | pconn.mtx.Lock() 257 | defer pconn.mtx.Unlock() 258 | 259 | pconn.Addr = &proxyAddr{Host: destHost, 260 | Port: destPort, 261 | UseTLS: useTLS, 262 | } 263 | pconn.transparentMode = true 264 | } 265 | 266 | func (pconn *proxyConn) EndTransparentMode() { 267 | pconn.mtx.Lock() 268 | defer pconn.mtx.Unlock() 269 | 270 | pconn.transparentMode = false 271 | } 272 | 273 | func newProxyConn(c net.Conn, l *log.Logger) *proxyConn { 274 | // converts a connection into a proxyConn 275 | a := proxyAddr{Host: "", Port: -1, UseTLS: false} 276 | p := proxyConn{Addr: &a, logger: l, conn: c, readReq: nil} 277 | p.id = getNextConnId() 278 | p.transparentMode = false 279 | return &p 280 | } 281 | 282 | func (pconn *proxyConn) returnRequest(req *http.Request) { 283 | pconn.mtx.Lock() 284 | defer pconn.mtx.Unlock() 285 | 286 | pconn.readReq = 
req 287 | } 288 | 289 | /* 290 | Implements net.Listener. Listeners can be added. Will accept 291 | connections on each listener and read HTTP messages from the 292 | connection. Will attempt to spoof TLS from incoming HTTP 293 | requests. Accept() returns a ProxyConn which transmists one 294 | unencrypted HTTP request and contains the intended destination for 295 | each request in the RemoteAddr. 296 | */ 297 | type ProxyListener struct { 298 | net.Listener 299 | 300 | // The current state of the listener 301 | State int 302 | 303 | inputListeners mapset.Set 304 | mtx sync.Mutex 305 | logger *log.Logger 306 | outputConns chan ProxyConn 307 | inputConns chan *inputConn 308 | outputConnDone chan struct{} 309 | inputConnDone chan struct{} 310 | listenWg sync.WaitGroup 311 | caCert *tls.Certificate 312 | } 313 | 314 | type inputConn struct { 315 | listener *ProxyListener 316 | conn net.Conn 317 | 318 | transparentMode bool 319 | transparentAddr *proxyAddr 320 | } 321 | 322 | type listenerData struct { 323 | Id int 324 | Listener net.Listener 325 | } 326 | 327 | func newListenerData(listener net.Listener) *listenerData { 328 | l := listenerData{} 329 | l.Id = getNextListenerId() 330 | l.Listener = listener 331 | return &l 332 | } 333 | 334 | // NewProxyListener creates and starts a new proxy listener that will log to the given logger 335 | func NewProxyListener(logger *log.Logger) *ProxyListener { 336 | var useLogger *log.Logger 337 | if logger != nil { 338 | useLogger = logger 339 | } else { 340 | useLogger = log.New(ioutil.Discard, "[*] ", log.Lshortfile) 341 | } 342 | l := ProxyListener{logger: useLogger, State: ProxyStarting} 343 | l.inputListeners = mapset.NewSet() 344 | 345 | l.outputConns = make(chan ProxyConn) 346 | l.inputConns = make(chan *inputConn) 347 | l.outputConnDone = make(chan struct{}) 348 | l.inputConnDone = make(chan struct{}) 349 | 350 | // Translate connections 351 | l.listenWg.Add(1) 352 | go func() { 353 | l.logger.Println("Starting connection 
translator...") 354 | defer l.listenWg.Done() 355 | for { 356 | select { 357 | case <-l.outputConnDone: 358 | l.logger.Println("Output channel closed. Shutting down translator.") 359 | return 360 | case inconn := <-l.inputConns: 361 | go func() { 362 | err := l.translateConn(inconn) 363 | if err != nil { 364 | l.logger.Println("Could not translate connection:", err) 365 | } 366 | }() 367 | } 368 | } 369 | }() 370 | 371 | l.State = ProxyRunning 372 | l.logger.Println("Proxy Started") 373 | 374 | return &l 375 | } 376 | 377 | // Accept accepts a new connection from any of its listeners 378 | func (listener *ProxyListener) Accept() (net.Conn, error) { 379 | if listener.outputConns == nil || 380 | listener.inputConns == nil || 381 | listener.outputConnDone == nil || 382 | listener.inputConnDone == nil { 383 | return nil, fmt.Errorf("Listener not initialized! Cannot accept connection.") 384 | 385 | } 386 | select { 387 | case <-listener.outputConnDone: 388 | listener.logger.Println("Cannot accept connection, ProxyListener is closed") 389 | return nil, fmt.Errorf("Connection is closed") 390 | case c := <-listener.outputConns: 391 | listener.logger.Println("Connection", c.Id(), "accepted from ProxyListener") 392 | return c, nil 393 | } 394 | } 395 | 396 | // Close closes all of the listeners associated with the ProxyListener 397 | func (listener *ProxyListener) Close() error { 398 | listener.mtx.Lock() 399 | defer listener.mtx.Unlock() 400 | 401 | listener.logger.Println("Closing ProxyListener...") 402 | listener.State = ProxyStopped 403 | close(listener.outputConnDone) 404 | close(listener.inputConnDone) 405 | close(listener.outputConns) 406 | close(listener.inputConns) 407 | 408 | it := listener.inputListeners.Iterator() 409 | for elem := range it.C { 410 | l := elem.(*listenerData) 411 | l.Listener.Close() 412 | listener.logger.Println("Closed listener", l.Id) 413 | } 414 | listener.logger.Println("ProxyListener closed") 415 | listener.listenWg.Wait() 416 | return nil 
417 | } 418 | 419 | func (listener *ProxyListener) Addr() net.Addr { 420 | return internalAddr{} 421 | } 422 | 423 | // AddListener adds a listener for the ProxyListener to listen on 424 | func (listener *ProxyListener) AddListener(inlisten net.Listener) error { 425 | listener.mtx.Lock() 426 | defer listener.mtx.Unlock() 427 | return listener.addListener(inlisten, false, nil) 428 | } 429 | 430 | // AddTransparentListener is the same as AddListener, but all of the connections will be in transparent mode 431 | func (listener *ProxyListener) AddTransparentListener(inlisten net.Listener, destHost string, destPort int, useTLS bool) error { 432 | listener.mtx.Lock() 433 | defer listener.mtx.Unlock() 434 | addr := &proxyAddr{ 435 | Host: destHost, 436 | Port: destPort, 437 | UseTLS: useTLS, 438 | } 439 | return listener.addListener(inlisten, true, addr) 440 | } 441 | 442 | func (listener *ProxyListener) addListener(inlisten net.Listener, transparentMode bool, destAddr *proxyAddr) error { 443 | listener.logger.Println("Adding listener to ProxyListener:", inlisten) 444 | il := newListenerData(inlisten) 445 | l := listener 446 | listener.listenWg.Add(1) 447 | go func() { 448 | defer l.listenWg.Done() 449 | for { 450 | c, err := il.Listener.Accept() 451 | if err != nil { 452 | // TODO: verify that the connection is actually closed and not some other error 453 | l.logger.Println("Listener", il.Id, "closed") 454 | return 455 | } 456 | l.logger.Println("Received conn form listener", il.Id) 457 | newConn := &inputConn{ 458 | conn: c, 459 | listener: nil, 460 | transparentMode: transparentMode, 461 | transparentAddr: destAddr, 462 | } 463 | l.inputConns <- newConn 464 | } 465 | }() 466 | listener.inputListeners.Add(il) 467 | l.logger.Println("Listener", il.Id, "added to ProxyListener") 468 | return nil 469 | } 470 | 471 | // RemoveListener closes a listener and removes it from the ProxyListener. Does not kill active connections. 
472 | func (listener *ProxyListener) RemoveListener(inlisten net.Listener) error { 473 | listener.mtx.Lock() 474 | defer listener.mtx.Unlock() 475 | 476 | listener.inputListeners.Remove(inlisten) 477 | inlisten.Close() 478 | listener.logger.Println("Listener removed:", inlisten) 479 | return nil 480 | } 481 | 482 | // TKTK working here 483 | // Take in a connection, strip TLS, get destination info, and push a ProxyConn to the listener.outputConnection channel 484 | func (listener *ProxyListener) translateConn(inconn *inputConn) error { 485 | pconn := newProxyConn(inconn.conn, listener.logger) 486 | pconn.SetCACertificate(listener.GetCACertificate()) 487 | if inconn.transparentMode { 488 | pconn.SetTransparentMode(inconn.transparentAddr.Host, 489 | inconn.transparentAddr.Port, 490 | inconn.transparentAddr.UseTLS) 491 | } 492 | 493 | var host string = "" 494 | var port int = -1 495 | var useTLS bool = false 496 | 497 | request, err := http.ReadRequest(bufio.NewReader(pconn)) 498 | if err != nil { 499 | listener.logger.Println(err) 500 | return err 501 | } 502 | 503 | // Get parsed host and port 504 | parsed_host, sport, err := net.SplitHostPort(request.URL.Host) 505 | if err != nil { 506 | // Assume that that URL.Host is the hostname and doesn't contain a port 507 | host = request.URL.Host 508 | port = -1 509 | } else { 510 | parsed_port, err := strconv.Atoi(sport) 511 | if err != nil { 512 | // Assume that that URL.Host is the hostname and doesn't contain a port 513 | return fmt.Errorf("Error parsing hostname: %s", err) 514 | } 515 | host = parsed_host 516 | port = parsed_port 517 | } 518 | 519 | // Handle CONNECT and TLS 520 | if request.Method == "CONNECT" { 521 | // Respond that we connected 522 | resp := http.Response{Status: "Connection established", Proto: "HTTP/1.1", ProtoMajor: 1, StatusCode: 200} 523 | err := resp.Write(inconn.conn) 524 | if err != nil { 525 | listener.logger.Println("Could not write CONNECT response:", err) 526 | return err 527 | } 528 | 
529 | usedTLS, err := pconn.StartMaybeTLS(host) 530 | if err != nil { 531 | listener.logger.Println("Error starting maybeTLS:", err) 532 | return err 533 | } 534 | useTLS = usedTLS 535 | } else { 536 | // Put the request back 537 | pconn.returnRequest(request) 538 | useTLS = false 539 | } 540 | 541 | // Guess the port if we have to 542 | if port == -1 { 543 | if useTLS { 544 | port = 443 545 | } else { 546 | port = 80 547 | } 548 | } 549 | 550 | if !pconn.transparentMode { 551 | pconn.Addr.Host = host 552 | pconn.Addr.Port = port 553 | pconn.Addr.UseTLS = useTLS 554 | } 555 | 556 | var useTLSStr string 557 | if pconn.Addr.UseTLS { 558 | useTLSStr = "YES" 559 | } else { 560 | useTLSStr = "NO" 561 | } 562 | pconn.Logger().Printf("Received connection to: Host='%s', Port=%d, UseTls=%s", pconn.Addr.Host, pconn.Addr.Port, useTLSStr) 563 | 564 | // Put the conn in the output channel 565 | listener.outputConns <- pconn 566 | return nil 567 | } 568 | 569 | // SetCACertificate sets which certificate the listener should be used when spoofing TLS 570 | func (listener *ProxyListener) SetCACertificate(caCert *tls.Certificate) { 571 | listener.mtx.Lock() 572 | defer listener.mtx.Unlock() 573 | 574 | listener.caCert = caCert 575 | } 576 | 577 | // SetCACertificate gets which certificate the listener is using when spoofing TLS 578 | func (listener *ProxyListener) GetCACertificate() *tls.Certificate { 579 | listener.mtx.Lock() 580 | defer listener.mtx.Unlock() 581 | 582 | return listener.caCert 583 | } 584 | -------------------------------------------------------------------------------- /schema.go: -------------------------------------------------------------------------------- 1 | package puppy 2 | 3 | import ( 4 | "database/sql" 5 | "encoding/json" 6 | "fmt" 7 | "log" 8 | "runtime" 9 | "sort" 10 | "strings" 11 | ) 12 | 13 | type schemaUpdater func(tx *sql.Tx) error 14 | 15 | type tableNameRow struct { 16 | name string 17 | } 18 | 19 | var schemaUpdaters = []schemaUpdater{ 20 | 
schema8, 21 | schema9, 22 | schema10, 23 | } 24 | 25 | func UpdateSchema(db *sql.DB, logger *log.Logger) error { 26 | currSchemaVersion := 0 27 | var tableName string 28 | if err := db.QueryRow("SELECT name FROM sqlite_master WHERE type='table' AND name='schema_meta';").Scan(&tableName); err == sql.ErrNoRows { 29 | logger.Println("No datafile schema, initializing schema") 30 | currSchemaVersion = -1 31 | } else if err != nil { 32 | return err 33 | } else { 34 | svr := new(int) 35 | if err := db.QueryRow("SELECT version FROM schema_meta;").Scan(svr); err != nil { 36 | return err 37 | } 38 | currSchemaVersion = *svr 39 | if currSchemaVersion-7 < len(schemaUpdaters) { 40 | logger.Println("Schema out of date. Updating...") 41 | } 42 | } 43 | 44 | if currSchemaVersion >= 0 && currSchemaVersion < 8 { 45 | return fmt.Errorf("This is a PappyProxy datafile that is not the most recent schema version supported by PappyProxy. Load this datafile using the most recent version of Pappy to upgrade the schema and try importing it again.") 46 | } 47 | 48 | var updaterInd = 0 49 | if currSchemaVersion > 0 { 50 | updaterInd = currSchemaVersion - 7 51 | } 52 | 53 | if currSchemaVersion-7 < len(schemaUpdaters) { 54 | tx, err := db.Begin() 55 | if err != nil { 56 | return err 57 | } 58 | for i := updaterInd; i < len(schemaUpdaters); i++ { 59 | logger.Printf("Updating schema to version %d...", i+8) 60 | err := schemaUpdaters[i](tx) 61 | if err != nil { 62 | logger.Println("Error updating schema:", err) 63 | logger.Println("Rolling back") 64 | tx.Rollback() 65 | return err 66 | } 67 | } 68 | logger.Printf("Schema update successful") 69 | tx.Commit() 70 | } 71 | return nil 72 | } 73 | 74 | func execute(tx *sql.Tx, cmd string) error { 75 | err := executeNoDebug(tx, cmd) 76 | if err != nil { 77 | _, f, ln, _ := runtime.Caller(1) 78 | return fmt.Errorf("sql error at %s:%d: %s", f, ln, err.Error()) 79 | } 80 | return nil 81 | } 82 | 83 | func executeNoDebug(tx *sql.Tx, cmd string) error { 84 | 
stmt, err := tx.Prepare(cmd) 85 | defer stmt.Close() 86 | if err != nil { 87 | return err 88 | } 89 | 90 | if _, err := tx.Stmt(stmt).Exec(); err != nil { 91 | return err 92 | } 93 | return nil 94 | } 95 | 96 | func executeMultiple(tx *sql.Tx, cmds []string) error { 97 | for _, cmd := range cmds { 98 | err := executeNoDebug(tx, cmd) 99 | if err != nil { 100 | _, f, ln, _ := runtime.Caller(1) 101 | return fmt.Errorf("sql error at %s:%d: %s", f, ln, err.Error()) 102 | } 103 | } 104 | return nil 105 | } 106 | 107 | /* 108 | SCHEMA 8 / INITIAL 109 | */ 110 | 111 | func schema8(tx *sql.Tx) error { 112 | // Create a schema that is the same as pappy's last version 113 | 114 | cmds := []string{ 115 | 116 | ` 117 | CREATE TABLE schema_meta ( 118 | version INTEGER NOT NULL 119 | ); 120 | `, 121 | 122 | ` 123 | INSERT INTO "schema_meta" VALUES(8); 124 | `, 125 | 126 | ` 127 | CREATE TABLE responses ( 128 | id INTEGER PRIMARY KEY AUTOINCREMENT, 129 | full_response BLOB NOT NULL, 130 | unmangled_id INTEGER REFERENCES responses(id) 131 | ); 132 | `, 133 | 134 | ` 135 | CREATE TABLE scope ( 136 | filter_order INTEGER NOT NULL, 137 | filter_string TEXT NOT NULL 138 | ); 139 | `, 140 | 141 | ` 142 | CREATE TABLE tags ( 143 | id INTEGER PRIMARY KEY AUTOINCREMENT, 144 | tag TEXT NOT NULL 145 | ); 146 | `, 147 | 148 | ` 149 | CREATE TABLE tagged ( 150 | reqid INTEGER, 151 | tagid INTEGER 152 | ); 153 | `, 154 | 155 | ` 156 | CREATE TABLE "requests" ( 157 | id INTEGER PRIMARY KEY AUTOINCREMENT, 158 | full_request BLOB NOT NULL, 159 | submitted INTEGER NOT NULL, 160 | response_id INTEGER REFERENCES responses(id), 161 | unmangled_id INTEGER REFERENCES requests(id), 162 | port INTEGER, 163 | is_ssl INTEGER, 164 | host TEXT, 165 | plugin_data TEXT, 166 | start_datetime REAL, 167 | end_datetime REAL 168 | ); 169 | `, 170 | 171 | ` 172 | CREATE TABLE saved_contexts ( 173 | id INTEGER PRIMARY KEY AUTOINCREMENT, 174 | context_name TEXT UNIQUE, 175 | filter_strings TEXT 176 | ); 177 | `, 178 | 
179 | ` 180 | CREATE TABLE websocket_messages ( 181 | id INTEGER PRIMARY KEY AUTOINCREMENT, 182 | parent_request INTEGER REFERENCES requests(id), 183 | unmangled_id INTEGER REFERENCES websocket_messages(id), 184 | is_binary INTEGER, 185 | direction INTEGER, 186 | time_sent REAL, 187 | contents BLOB 188 | ); 189 | `, 190 | 191 | ` 192 | CREATE INDEX ind_start_time ON requests(start_datetime); 193 | `, 194 | } 195 | 196 | err := executeMultiple(tx, cmds) 197 | if err != nil { 198 | return err 199 | } 200 | 201 | return nil 202 | } 203 | 204 | /* 205 | SCHEMA 9 206 | */ 207 | 208 | func pappyFilterToStrArgList(f string) ([]string, error) { 209 | parts := strings.Split(f, " ") 210 | 211 | // Validate the arguments 212 | goArgs, err := CheckArgsStrToGo(parts) 213 | if err != nil { 214 | return nil, fmt.Errorf("error converting filter string \"%s\": %s", f, err) 215 | } 216 | 217 | strArgs, err := CheckArgsGoToStr(goArgs) 218 | if err != nil { 219 | return nil, fmt.Errorf("error converting filter string \"%s\": %s", f, err) 220 | } 221 | 222 | return strArgs, nil 223 | } 224 | 225 | func pappyListToStrMessageQuery(f []string) (StrMessageQuery, error) { 226 | retFilter := make(StrMessageQuery, len(f)) 227 | 228 | for i, s := range f { 229 | strArgs, err := pappyFilterToStrArgList(s) 230 | if err != nil { 231 | return nil, err 232 | } 233 | 234 | newPhrase := make(StrQueryPhrase, 1) 235 | newPhrase[0] = strArgs 236 | 237 | retFilter[i] = newPhrase 238 | } 239 | 240 | return retFilter, nil 241 | } 242 | 243 | type s9ScopeStr struct { 244 | Order int64 245 | Filter string 246 | } 247 | 248 | type s9ScopeSort []*s9ScopeStr 249 | 250 | func (ls s9ScopeSort) Len() int { 251 | return len(ls) 252 | } 253 | 254 | func (ls s9ScopeSort) Swap(i int, j int) { 255 | ls[i], ls[j] = ls[j], ls[i] 256 | } 257 | 258 | func (ls s9ScopeSort) Less(i int, j int) bool { 259 | return ls[i].Order < ls[j].Order 260 | } 261 | 262 | func schema9(tx *sql.Tx) error { 263 | /* 264 | Converts the 
floating point timestamps into integers representing nanoseconds from jan 1 1970 265 | */ 266 | 267 | // Rename the old requests table 268 | if err := execute(tx, "ALTER TABLE requests RENAME TO requests_old"); err != nil { 269 | return err 270 | } 271 | 272 | if err := execute(tx, "ALTER TABLE websocket_messages RENAME TO websocket_messages_old"); err != nil { 273 | return err 274 | } 275 | 276 | // Create new requests table with integer datetime 277 | cmds := []string{` 278 | CREATE TABLE "requests" ( 279 | id INTEGER PRIMARY KEY AUTOINCREMENT, 280 | full_request BLOB NOT NULL, 281 | submitted INTEGER NOT NULL, 282 | response_id INTEGER REFERENCES responses(id), 283 | unmangled_id INTEGER REFERENCES requests(id), 284 | port INTEGER, 285 | is_ssl INTEGER, 286 | host TEXT, 287 | plugin_data TEXT, 288 | start_datetime INTEGER, 289 | end_datetime INTEGER 290 | ); 291 | `, 292 | 293 | ` 294 | INSERT INTO requests 295 | SELECT id, full_request, submitted, response_id, unmangled_id, port, is_ssl, host, plugin_data, 0, 0 296 | FROM requests_old 297 | `, 298 | 299 | ` 300 | CREATE TABLE websocket_messages ( 301 | id INTEGER PRIMARY KEY AUTOINCREMENT, 302 | parent_request INTEGER REFERENCES requests(id), 303 | unmangled_id INTEGER REFERENCES websocket_messages(id), 304 | is_binary INTEGER, 305 | direction INTEGER, 306 | time_sent INTEGER, 307 | contents BLOB 308 | ); 309 | `, 310 | 311 | ` 312 | INSERT INTO websocket_messages 313 | SELECT id, parent_request, unmangled_id, is_binary, direction, 0, contents 314 | FROM websocket_messages_old 315 | `, 316 | } 317 | if err := executeMultiple(tx, cmds); err != nil { 318 | return err 319 | } 320 | 321 | // Update time values to use unix time nanoseconds 322 | rows, err := tx.Query("SELECT id, start_datetime, end_datetime FROM requests_old;") 323 | if err != nil { 324 | return err 325 | } 326 | defer rows.Close() 327 | 328 | var reqid int64 329 | var startDT sql.NullFloat64 330 | var endDT sql.NullFloat64 331 | var newStartDT 
int64 332 | var newEndDT int64 333 | 334 | for rows.Next() { 335 | if err := rows.Scan(&reqid, &startDT, &endDT); err != nil { 336 | return err 337 | } 338 | 339 | if startDT.Valid { 340 | // Convert to nanoseconds 341 | newStartDT = int64(startDT.Float64 * 1000000000) 342 | } else { 343 | newStartDT = 0 344 | } 345 | 346 | if endDT.Valid { 347 | newEndDT = int64(endDT.Float64 * 1000000000) 348 | } else { 349 | newEndDT = 0 350 | } 351 | 352 | // Save the new value 353 | stmt, err := tx.Prepare("UPDATE requests SET start_datetime=?, end_datetime=? WHERE id=?") 354 | if err != nil { 355 | return err 356 | } 357 | defer stmt.Close() 358 | 359 | if _, err := tx.Stmt(stmt).Exec(newStartDT, newEndDT, reqid); err != nil { 360 | return err 361 | } 362 | } 363 | 364 | // Update websocket time values to use unix time nanoseconds 365 | rows, err = tx.Query("SELECT id, time_sent FROM websocket_messages_old;") 366 | if err != nil { 367 | return err 368 | } 369 | defer rows.Close() 370 | 371 | var wsid int64 372 | var sentDT sql.NullFloat64 373 | var newSentDT int64 374 | 375 | for rows.Next() { 376 | if err := rows.Scan(&wsid, &sentDT); err != nil { 377 | return err 378 | } 379 | 380 | if sentDT.Valid { 381 | // Convert to nanoseconds 382 | newSentDT = int64(startDT.Float64 * 1000000000) 383 | } else { 384 | newSentDT = 0 385 | } 386 | 387 | // Save the new value 388 | stmt, err := tx.Prepare("UPDATE websocket_messages SET time_sent=? 
WHERE id=?") 389 | if err != nil { 390 | return err 391 | } 392 | defer stmt.Close() 393 | 394 | if _, err := tx.Stmt(stmt).Exec(newSentDT, reqid); err != nil { 395 | return err 396 | } 397 | } 398 | err = rows.Err() 399 | if err != nil { 400 | return err 401 | } 402 | 403 | if err := execute(tx, "DROP TABLE requests_old"); err != nil { 404 | return err 405 | } 406 | 407 | if err := execute(tx, "DROP TABLE websocket_messages_old"); err != nil { 408 | return err 409 | } 410 | 411 | // Update saved contexts 412 | rows, err = tx.Query("SELECT id, context_name, filter_strings FROM saved_contexts") 413 | if err != nil { 414 | return err 415 | } 416 | defer rows.Close() 417 | 418 | var contextId int64 419 | var contextName sql.NullString 420 | var filterStrings sql.NullString 421 | 422 | for rows.Next() { 423 | if err := rows.Scan(&contextId, &contextName, &filterStrings); err != nil { 424 | return err 425 | } 426 | 427 | if !contextName.Valid { 428 | continue 429 | } 430 | 431 | if !filterStrings.Valid { 432 | continue 433 | } 434 | 435 | if contextName.String == "__scope" { 436 | // hopefully this doesn't break anything critical, but we want to store the scope 437 | // as a saved context now with the name __scope 438 | continue 439 | } 440 | 441 | var pappyFilters []string 442 | err = json.Unmarshal([]byte(filterStrings.String), &pappyFilters) 443 | if err != nil { 444 | return err 445 | } 446 | 447 | newFilter, err := pappyListToStrMessageQuery(pappyFilters) 448 | if err != nil { 449 | // We're just ignoring filters that we can't convert :| 450 | continue 451 | } 452 | 453 | newFilterStr, err := json.Marshal(newFilter) 454 | if err != nil { 455 | return err 456 | } 457 | 458 | stmt, err := tx.Prepare("UPDATE saved_contexts SET filter_strings=? 
WHERE id=?") 459 | if err != nil { 460 | return err 461 | } 462 | defer stmt.Close() 463 | 464 | if _, err := tx.Stmt(stmt).Exec(newFilterStr, contextId); err != nil { 465 | return err 466 | } 467 | } 468 | err = rows.Err() 469 | if err != nil { 470 | return err 471 | } 472 | 473 | // Move scope to a saved context 474 | rows, err = tx.Query("SELECT filter_order, filter_string FROM scope") 475 | if err != nil { 476 | return err 477 | } 478 | defer rows.Close() 479 | 480 | var filterOrder sql.NullInt64 481 | var filterString sql.NullString 482 | 483 | vals := make([]*s9ScopeStr, 0) 484 | for rows.Next() { 485 | if err := rows.Scan(&filterOrder, &filterString); err != nil { 486 | return err 487 | } 488 | 489 | if !filterOrder.Valid { 490 | continue 491 | } 492 | 493 | if !filterString.Valid { 494 | continue 495 | } 496 | 497 | vals = append(vals, &s9ScopeStr{filterOrder.Int64, filterString.String}) 498 | } 499 | err = rows.Err() 500 | if err != nil { 501 | return err 502 | } 503 | 504 | // Put the scope in the right order 505 | sort.Sort(s9ScopeSort(vals)) 506 | 507 | // Convert it into a list of filters 508 | filterList := make([]string, len(vals)) 509 | for i, ss := range vals { 510 | filterList[i] = ss.Filter 511 | } 512 | 513 | newScopeStrFilter, err := pappyListToStrMessageQuery(filterList) 514 | if err != nil { 515 | // We'll only convert the scope if we can, otherwise we'll drop it 516 | err := execute(tx, `INSERT INTO saved_contexts (context_name, filter_strings) VALUES("__scope", "[]")`) 517 | if err != nil { 518 | return err 519 | } 520 | } else { 521 | stmt, err := tx.Prepare(`INSERT INTO saved_contexts (context_name, filter_strings) VALUES("__scope", ?)`) 522 | if err != nil { 523 | return err 524 | } 525 | defer stmt.Close() 526 | 527 | newScopeFilterStr, err := json.Marshal(newScopeStrFilter) 528 | if err != nil { 529 | return err 530 | } 531 | 532 | if _, err := tx.Stmt(stmt).Exec(newScopeFilterStr); err != nil { 533 | return err 534 | } 535 | } 536 | 
537 | if err := execute(tx, "DROP TABLE scope"); err != nil { 538 | return err 539 | } 540 | 541 | // Update schema number 542 | if err := execute(tx, `UPDATE schema_meta SET version=9`); err != nil { 543 | return err 544 | } 545 | return nil 546 | } 547 | 548 | func schema10(tx *sql.Tx) error { 549 | /* 550 | Create a "plugin data" table to let applications store app-specific data in the datafile 551 | */ 552 | cmds := []string{` 553 | CREATE TABLE plugin_data ( 554 | id INTEGER PRIMARY KEY AUTOINCREMENT, 555 | key TEXT UNIQUE, 556 | value STRING 557 | ); 558 | CREATE INDEX plugin_key_ind ON plugin_data(key); 559 | `, 560 | 561 | `UPDATE schema_meta SET version=10`, 562 | } 563 | 564 | err := executeMultiple(tx, cmds) 565 | if err != nil { 566 | return err 567 | } 568 | 569 | return nil 570 | } 571 | -------------------------------------------------------------------------------- /search.go: -------------------------------------------------------------------------------- 1 | package puppy 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | "net/http" 7 | "net/url" 8 | "regexp" 9 | "strconv" 10 | "strings" 11 | "time" 12 | ) 13 | 14 | type SearchField int 15 | type StrComparer int 16 | 17 | type strFieldGetter func(req *ProxyRequest) ([]string, error) 18 | type kvFieldGetter func(req *ProxyRequest) ([]*PairValue, error) 19 | 20 | type RequestChecker func(req *ProxyRequest) bool 21 | 22 | // Searchable fields 23 | const ( 24 | FieldAll SearchField = iota 25 | 26 | FieldRequestBody 27 | FieldResponseBody 28 | FieldAllBody 29 | FieldWSMessage 30 | 31 | FieldRequestHeaders 32 | FieldResponseHeaders 33 | FieldBothHeaders 34 | 35 | FieldMethod 36 | FieldHost 37 | FieldPath 38 | FieldURL 39 | FieldStatusCode 40 | FieldTag 41 | 42 | FieldBothParam 43 | FieldURLParam 44 | FieldPostParam 45 | FieldResponseCookie 46 | FieldRequestCookie 47 | FieldBothCookie 48 | 49 | FieldAfter 50 | FieldBefore 51 | FieldTimeRange 52 | 53 | FieldInvert 54 | 55 | FieldId 56 | ) 57 | 58 | // Operators 
for string values 59 | const ( 60 | StrIs StrComparer = iota 61 | StrContains 62 | StrContainsRegexp 63 | 64 | StrLengthGreaterThan 65 | StrLengthLessThan 66 | StrLengthEqualTo 67 | ) 68 | 69 | // A struct representing the data to be searched for a pair such as a header or url param 70 | type PairValue struct { 71 | key string 72 | value string 73 | } 74 | 75 | // A list of queries. Will match if any queries match the request 76 | type QueryPhrase [][]interface{} 77 | // A list of phrases. Will match if all the phrases match the request 78 | type MessageQuery []QueryPhrase 79 | 80 | // A list of queries in string form. Will match if any queries match the request 81 | type StrQueryPhrase [][]string 82 | // A list of phrases in string form. Will match if all the phrases match the request 83 | type StrMessageQuery []StrQueryPhrase 84 | 85 | // Return a function that returns whether a request matches the given conditions 86 | func NewRequestChecker(args ...interface{}) (RequestChecker, error) { 87 | // Generates a request checker from the given search arguments 88 | if len(args) == 0 { 89 | return nil, errors.New("search requires a search field") 90 | } 91 | 92 | field, ok := args[0].(SearchField) 93 | if !ok { 94 | return nil, fmt.Errorf("first argument must hava a type of SearchField") 95 | } 96 | 97 | switch field { 98 | 99 | // Normal string fields 100 | case FieldAll, FieldRequestBody, FieldResponseBody, FieldAllBody, FieldWSMessage, FieldMethod, FieldHost, FieldPath, FieldStatusCode, FieldTag, FieldId: 101 | getter, err := createstrFieldGetter(field) 102 | if err != nil { 103 | return nil, fmt.Errorf("error performing search: %s", err.Error()) 104 | } 105 | 106 | if len(args) != 3 { 107 | return nil, errors.New("searches through strings must have one checker and one value") 108 | } 109 | 110 | comparer, ok := args[1].(StrComparer) 111 | if !ok { 112 | return nil, errors.New("comparer must be a StrComparer") 113 | } 114 | 115 | return genStrFieldChecker(getter, 
comparer, args[2]) 116 | 117 | // Normal key/value fields 118 | case FieldRequestHeaders, FieldResponseHeaders, FieldBothHeaders, FieldBothParam, FieldURLParam, FieldPostParam, FieldResponseCookie, FieldRequestCookie, FieldBothCookie: 119 | getter, err := createKvPairGetter(field) 120 | if err != nil { 121 | return nil, fmt.Errorf("error performing search: %s", err.Error()) 122 | } 123 | 124 | if len(args) == 3 { 125 | // Get comparer and value out of function arguments 126 | comparer, ok := args[1].(StrComparer) 127 | if !ok { 128 | return nil, errors.New("comparer must be a StrComparer") 129 | } 130 | 131 | // Create a strFieldGetter out of our key/value getter 132 | strgetter := func(req *ProxyRequest) ([]string, error) { 133 | pairs, err := getter(req) 134 | if err != nil { 135 | return nil, err 136 | } 137 | return pairsToStrings(pairs), nil 138 | } 139 | 140 | // return a str field checker using our new str getter 141 | return genStrFieldChecker(strgetter, comparer, args[2]) 142 | } else if len(args) == 5 { 143 | // Get comparer and value out of function arguments 144 | comparer1, ok := args[1].(StrComparer) 145 | if !ok { 146 | return nil, errors.New("first comparer must be a StrComparer") 147 | } 148 | 149 | val1, ok := args[2].(string) 150 | if !ok { 151 | return nil, errors.New("first val must be a list of bytes") 152 | } 153 | 154 | comparer2, ok := args[3].(StrComparer) 155 | if !ok { 156 | return nil, errors.New("second comparer must be a StrComparer") 157 | } 158 | 159 | val2, ok := args[4].(string) 160 | if !ok { 161 | return nil, errors.New("second val must be a list of bytes") 162 | } 163 | 164 | // Create a checker out of our getter, comparers, and vals 165 | return genKvFieldChecker(getter, comparer1, val1, comparer2, val2) 166 | } else { 167 | return nil, errors.New("invalid number of arguments for a key/value search") 168 | } 169 | 170 | // Other fields 171 | case FieldAfter: 172 | if len(args) != 2 { 173 | return nil, errors.New("searching by 
'after' takes exactly on parameter") 174 | } 175 | 176 | val, ok := args[1].(time.Time) 177 | if !ok { 178 | return nil, errors.New("search argument must be a time.Time") 179 | } 180 | 181 | return func(req *ProxyRequest) bool { 182 | return req.StartDatetime.After(val) 183 | }, nil 184 | 185 | case FieldBefore: 186 | if len(args) != 2 { 187 | return nil, errors.New("searching by 'before' takes exactly one parameter") 188 | } 189 | 190 | val, ok := args[1].(time.Time) 191 | if !ok { 192 | return nil, errors.New("search argument must be a time.Time") 193 | } 194 | 195 | return func(req *ProxyRequest) bool { 196 | return req.StartDatetime.Before(val) 197 | }, nil 198 | 199 | case FieldTimeRange: 200 | if len(args) != 3 { 201 | return nil, errors.New("searching by time range takes exactly two parameters") 202 | } 203 | 204 | begin, ok := args[1].(time.Time) 205 | if !ok { 206 | return nil, errors.New("search arguments must be a time.Time") 207 | } 208 | 209 | end, ok := args[2].(time.Time) 210 | if !ok { 211 | return nil, errors.New("search arguments must be a time.Time") 212 | } 213 | 214 | return func(req *ProxyRequest) bool { 215 | return req.StartDatetime.After(begin) && req.StartDatetime.Before(end) 216 | }, nil 217 | 218 | case FieldInvert: 219 | orig, err := NewRequestChecker(args[1:]...) 
220 | if err != nil { 221 | return nil, fmt.Errorf("error with query to invert: %s", err.Error()) 222 | } 223 | return func(req *ProxyRequest) bool { 224 | return !orig(req) 225 | }, nil 226 | 227 | default: 228 | return nil, errors.New("invalid field") 229 | } 230 | } 231 | 232 | func createstrFieldGetter(field SearchField) (strFieldGetter, error) { 233 | switch field { 234 | case FieldAll: 235 | return func(req *ProxyRequest) ([]string, error) { 236 | strs := make([]string, 0) 237 | strs = append(strs, string(req.FullMessage())) 238 | 239 | if req.ServerResponse != nil { 240 | strs = append(strs, string(req.ServerResponse.FullMessage())) 241 | } 242 | 243 | for _, wsm := range req.WSMessages { 244 | strs = append(strs, string(wsm.Message)) 245 | } 246 | 247 | return strs, nil 248 | }, nil 249 | case FieldRequestBody: 250 | return func(req *ProxyRequest) ([]string, error) { 251 | strs := make([]string, 0) 252 | strs = append(strs, string(req.BodyBytes())) 253 | return strs, nil 254 | }, nil 255 | case FieldResponseBody: 256 | return func(req *ProxyRequest) ([]string, error) { 257 | strs := make([]string, 0) 258 | if req.ServerResponse != nil { 259 | strs = append(strs, string(req.ServerResponse.BodyBytes())) 260 | } 261 | return strs, nil 262 | }, nil 263 | case FieldAllBody: 264 | return func(req *ProxyRequest) ([]string, error) { 265 | strs := make([]string, 0) 266 | strs = append(strs, string(req.BodyBytes())) 267 | if req.ServerResponse != nil { 268 | strs = append(strs, string(req.ServerResponse.BodyBytes())) 269 | } 270 | return strs, nil 271 | }, nil 272 | case FieldWSMessage: 273 | return func(req *ProxyRequest) ([]string, error) { 274 | strs := make([]string, 0) 275 | 276 | for _, wsm := range req.WSMessages { 277 | strs = append(strs, string(wsm.Message)) 278 | } 279 | 280 | return strs, nil 281 | }, nil 282 | case FieldMethod: 283 | return func(req *ProxyRequest) ([]string, error) { 284 | strs := make([]string, 0) 285 | strs = append(strs, req.Method) 
286 | return strs, nil 287 | }, nil 288 | case FieldHost: 289 | return func(req *ProxyRequest) ([]string, error) { 290 | strs := make([]string, 0) 291 | strs = append(strs, req.DestHost) 292 | strs = append(strs, req.Host) 293 | return strs, nil 294 | }, nil 295 | case FieldPath: 296 | return func(req *ProxyRequest) ([]string, error) { 297 | strs := make([]string, 0) 298 | strs = append(strs, req.URL.Path) 299 | return strs, nil 300 | }, nil 301 | case FieldURL: 302 | return func(req *ProxyRequest) ([]string, error) { 303 | strs := make([]string, 0) 304 | strs = append(strs, req.FullURL().String()) 305 | return strs, nil 306 | }, nil 307 | case FieldStatusCode: 308 | return func(req *ProxyRequest) ([]string, error) { 309 | strs := make([]string, 0) 310 | if req.ServerResponse != nil { 311 | strs = append(strs, strconv.Itoa(req.ServerResponse.StatusCode)) 312 | } 313 | return strs, nil 314 | }, nil 315 | case FieldTag: 316 | return func(req *ProxyRequest) ([]string, error) { 317 | return req.Tags(), nil 318 | }, nil 319 | case FieldId: 320 | return func(req *ProxyRequest) ([]string, error) { 321 | strs := make([]string, 1) 322 | strs[0] = req.DbId 323 | return strs, nil 324 | }, nil 325 | default: 326 | return nil, errors.New("field is not a string") 327 | } 328 | } 329 | 330 | func genStrChecker(cmp StrComparer, argval interface{}) (func(str string) bool, error) { 331 | switch cmp { 332 | case StrContains: 333 | val, ok := argval.(string) 334 | if !ok { 335 | return nil, errors.New("argument must be a string") 336 | } 337 | return func(str string) bool { 338 | if strings.Contains(str, val) { 339 | return true 340 | } 341 | return false 342 | }, nil 343 | case StrIs: 344 | val, ok := argval.(string) 345 | if !ok { 346 | return nil, errors.New("argument must be a string") 347 | } 348 | return func(str string) bool { 349 | if str == val { 350 | return true 351 | } 352 | return false 353 | }, nil 354 | case StrContainsRegexp: 355 | val, ok := argval.(string) 356 | if 
!ok { 357 | return nil, errors.New("argument must be a string") 358 | } 359 | regex, err := regexp.Compile(string(val)) 360 | if err != nil { 361 | return nil, fmt.Errorf("could not compile regular expression: %s", err.Error()) 362 | } 363 | return func(str string) bool { 364 | return regex.MatchString(string(str)) 365 | }, nil 366 | case StrLengthGreaterThan: 367 | val, ok := argval.(int) 368 | if !ok { 369 | return nil, errors.New("argument must be an integer") 370 | } 371 | return func(str string) bool { 372 | if len(str) > val { 373 | return true 374 | } 375 | return false 376 | }, nil 377 | case StrLengthLessThan: 378 | val, ok := argval.(int) 379 | if !ok { 380 | return nil, errors.New("argument must be an integer") 381 | } 382 | return func(str string) bool { 383 | if len(str) < val { 384 | return true 385 | } 386 | return false 387 | }, nil 388 | case StrLengthEqualTo: 389 | val, ok := argval.(int) 390 | if !ok { 391 | return nil, errors.New("argument must be an integer") 392 | } 393 | return func(str string) bool { 394 | if len(str) == val { 395 | return true 396 | } 397 | return false 398 | }, nil 399 | default: 400 | return nil, errors.New("invalid comparer") 401 | } 402 | } 403 | 404 | func genStrFieldChecker(strGetter strFieldGetter, cmp StrComparer, val interface{}) (RequestChecker, error) { 405 | getter := strGetter 406 | comparer, err := genStrChecker(cmp, val) 407 | if err != nil { 408 | return nil, err 409 | } 410 | 411 | return func(req *ProxyRequest) bool { 412 | strs, err := getter(req) 413 | if err != nil { 414 | panic(err) 415 | } 416 | for _, str := range strs { 417 | if comparer(str) { 418 | return true 419 | } 420 | } 421 | return false 422 | }, nil 423 | } 424 | 425 | func pairValuesFromHeader(header http.Header) []*PairValue { 426 | // Returns a list of pair values from a http.Header 427 | pairs := make([]*PairValue, 0) 428 | for k, vs := range header { 429 | for _, v := range vs { 430 | pair := &PairValue{string(k), string(v)} 431 | 
pairs = append(pairs, pair) 432 | } 433 | } 434 | return pairs 435 | } 436 | 437 | func pairValuesFromURLQuery(values url.Values) []*PairValue { 438 | // Returns a list of pair values from a http.Header 439 | pairs := make([]*PairValue, 0) 440 | for k, vs := range values { 441 | for _, v := range vs { 442 | pair := &PairValue{string(k), string(v)} 443 | pairs = append(pairs, pair) 444 | } 445 | } 446 | return pairs 447 | } 448 | 449 | func pairValuesFromCookies(cookies []*http.Cookie) []*PairValue { 450 | pairs := make([]*PairValue, 0) 451 | for _, c := range cookies { 452 | pair := &PairValue{string(c.Name), string(c.Value)} 453 | pairs = append(pairs, pair) 454 | } 455 | return pairs 456 | } 457 | 458 | func pairsToStrings(pairs []*PairValue) []string { 459 | // Converts a list of pairs into a list of strings containing all keys and values 460 | // k1: v1, k2: v2 -> ["k1", "v1", "k2", "v2"] 461 | strs := make([]string, 0) 462 | for _, p := range pairs { 463 | strs = append(strs, p.key) 464 | strs = append(strs, p.value) 465 | } 466 | return strs 467 | } 468 | 469 | func createKvPairGetter(field SearchField) (kvFieldGetter, error) { 470 | switch field { 471 | case FieldRequestHeaders: 472 | return func(req *ProxyRequest) ([]*PairValue, error) { 473 | return pairValuesFromHeader(req.Header), nil 474 | }, nil 475 | case FieldResponseHeaders: 476 | return func(req *ProxyRequest) ([]*PairValue, error) { 477 | var pairs []*PairValue 478 | if req.ServerResponse != nil { 479 | pairs = pairValuesFromHeader(req.ServerResponse.Header) 480 | } else { 481 | pairs = make([]*PairValue, 0) 482 | } 483 | return pairs, nil 484 | }, nil 485 | case FieldBothHeaders: 486 | return func(req *ProxyRequest) ([]*PairValue, error) { 487 | pairs := pairValuesFromHeader(req.Header) 488 | if req.ServerResponse != nil { 489 | pairs = append(pairs, pairValuesFromHeader(req.ServerResponse.Header)...) 
490 | } 491 | return pairs, nil 492 | }, nil 493 | case FieldBothParam: 494 | return func(req *ProxyRequest) ([]*PairValue, error) { 495 | pairs := pairValuesFromURLQuery(req.URL.Query()) 496 | params, err := req.PostParameters() 497 | if err == nil { 498 | pairs = append(pairs, pairValuesFromURLQuery(params)...) 499 | } 500 | return pairs, nil 501 | }, nil 502 | case FieldURLParam: 503 | return func(req *ProxyRequest) ([]*PairValue, error) { 504 | return pairValuesFromURLQuery(req.URL.Query()), nil 505 | }, nil 506 | case FieldPostParam: 507 | return func(req *ProxyRequest) ([]*PairValue, error) { 508 | params, err := req.PostParameters() 509 | if err != nil { 510 | return nil, err 511 | } 512 | return pairValuesFromURLQuery(params), nil 513 | }, nil 514 | case FieldResponseCookie: 515 | return func(req *ProxyRequest) ([]*PairValue, error) { 516 | pairs := make([]*PairValue, 0) 517 | if req.ServerResponse != nil { 518 | cookies := req.ServerResponse.Cookies() 519 | pairs = append(pairs, pairValuesFromCookies(cookies)...) 520 | } 521 | return pairs, nil 522 | }, nil 523 | case FieldRequestCookie: 524 | return func(req *ProxyRequest) ([]*PairValue, error) { 525 | return pairValuesFromCookies(req.Cookies()), nil 526 | }, nil 527 | case FieldBothCookie: 528 | return func(req *ProxyRequest) ([]*PairValue, error) { 529 | pairs := pairValuesFromCookies(req.Cookies()) 530 | if req.ServerResponse != nil { 531 | cookies := req.ServerResponse.Cookies() 532 | pairs = append(pairs, pairValuesFromCookies(cookies)...) 
533 | } 534 | return pairs, nil 535 | }, nil 536 | default: 537 | return nil, errors.New("not implemented") 538 | } 539 | } 540 | 541 | func genKvFieldChecker(kvGetter kvFieldGetter, cmp1 StrComparer, val1 string, 542 | cmp2 StrComparer, val2 string) (RequestChecker, error) { 543 | getter := kvGetter 544 | cmpfunc1, err := genStrChecker(cmp1, val1) 545 | if err != nil { 546 | return nil, err 547 | } 548 | 549 | cmpfunc2, err := genStrChecker(cmp2, val2) 550 | if err != nil { 551 | return nil, err 552 | } 553 | 554 | return func(req *ProxyRequest) bool { 555 | pairs, err := getter(req) 556 | if err != nil { 557 | return false 558 | } 559 | 560 | for _, p := range pairs { 561 | if cmpfunc1(p.key) && cmpfunc2(p.value) { 562 | return true 563 | } 564 | } 565 | return false 566 | }, nil 567 | } 568 | 569 | func checkerFromPhrase(phrase QueryPhrase) (RequestChecker, error) { 570 | checkers := make([]RequestChecker, len(phrase)) 571 | for i, args := range phrase { 572 | newChecker, err := NewRequestChecker(args...) 
573 | if err != nil { 574 | return nil, fmt.Errorf("error with search %d: %s", i, err.Error()) 575 | } 576 | checkers[i] = newChecker 577 | } 578 | 579 | ret := func(req *ProxyRequest) bool { 580 | for _, checker := range checkers { 581 | if checker(req) { 582 | return true 583 | } 584 | } 585 | return false 586 | } 587 | 588 | return ret, nil 589 | } 590 | 591 | // Creates a RequestChecker from a MessageQuery 592 | func CheckerFromMessageQuery(query MessageQuery) (RequestChecker, error) { 593 | checkers := make([]RequestChecker, len(query)) 594 | for i, phrase := range query { 595 | newChecker, err := checkerFromPhrase(phrase) 596 | if err != nil { 597 | return nil, fmt.Errorf("error with phrase %d: %s", i, err.Error()) 598 | } 599 | checkers[i] = newChecker 600 | } 601 | 602 | ret := func(req *ProxyRequest) bool { 603 | for _, checker := range checkers { 604 | if !checker(req) { 605 | return false 606 | } 607 | } 608 | return true 609 | } 610 | 611 | return ret, nil 612 | } 613 | 614 | /* 615 | StringSearch conversions 616 | */ 617 | 618 | func fieldGoToString(field SearchField) (string, error) { 619 | switch field { 620 | case FieldAll: 621 | return "all", nil 622 | case FieldRequestBody: 623 | return "reqbody", nil 624 | case FieldResponseBody: 625 | return "rspbody", nil 626 | case FieldAllBody: 627 | return "body", nil 628 | case FieldWSMessage: 629 | return "wsmessage", nil 630 | case FieldRequestHeaders: 631 | return "reqheader", nil 632 | case FieldResponseHeaders: 633 | return "rspheader", nil 634 | case FieldBothHeaders: 635 | return "header", nil 636 | case FieldMethod: 637 | return "method", nil 638 | case FieldHost: 639 | return "host", nil 640 | case FieldPath: 641 | return "path", nil 642 | case FieldURL: 643 | return "url", nil 644 | case FieldStatusCode: 645 | return "statuscode", nil 646 | case FieldBothParam: 647 | return "param", nil 648 | case FieldURLParam: 649 | return "urlparam", nil 650 | case FieldPostParam: 651 | return "postparam", nil 
652 | case FieldResponseCookie: 653 | return "rspcookie", nil 654 | case FieldRequestCookie: 655 | return "reqcookie", nil 656 | case FieldBothCookie: 657 | return "cookie", nil 658 | case FieldTag: 659 | return "tag", nil 660 | case FieldAfter: 661 | return "after", nil 662 | case FieldBefore: 663 | return "before", nil 664 | case FieldTimeRange: 665 | return "timerange", nil 666 | case FieldInvert: 667 | return "invert", nil 668 | case FieldId: 669 | return "dbid", nil 670 | default: 671 | return "", errors.New("invalid field") 672 | } 673 | } 674 | 675 | func fieldStrToGo(field string) (SearchField, error) { 676 | switch strings.ToLower(field) { 677 | case "all": 678 | return FieldAll, nil 679 | case "reqbody", "reqbd", "qbd", "qdata", "qdt": 680 | return FieldRequestBody, nil 681 | case "rspbody", "rspbd", "sbd", "sdata", "sdt": 682 | return FieldResponseBody, nil 683 | case "body", "bd", "data", "dt": 684 | return FieldAllBody, nil 685 | case "wsmessage", "wsm": 686 | return FieldWSMessage, nil 687 | case "reqheader", "reqhd", "qhd": 688 | return FieldRequestHeaders, nil 689 | case "rspheader", "rsphd", "shd": 690 | return FieldResponseHeaders, nil 691 | case "header", "hd": 692 | return FieldBothHeaders, nil 693 | case "method", "verb", "vb": 694 | return FieldMethod, nil 695 | case "host", "domain", "hs", "dm": 696 | return FieldHost, nil 697 | case "path", "pt": 698 | return FieldPath, nil 699 | case "url": 700 | return FieldURL, nil 701 | case "statuscode", "sc": 702 | return FieldStatusCode, nil 703 | case "param", "pm": 704 | return FieldBothParam, nil 705 | case "urlparam", "uparam": 706 | return FieldURLParam, nil 707 | case "postparam", "pparam": 708 | return FieldPostParam, nil 709 | case "rspcookie", "rspck", "sck": 710 | return FieldResponseCookie, nil 711 | case "reqcookie", "reqck", "qck": 712 | return FieldRequestCookie, nil 713 | case "cookie", "ck": 714 | return FieldBothCookie, nil 715 | case "tag": 716 | return FieldTag, nil 717 | case 
"after", "af": 718 | return FieldAfter, nil 719 | case "before", "b4": 720 | return FieldBefore, nil 721 | case "timerange": 722 | return FieldTimeRange, nil 723 | case "invert", "inv": 724 | return FieldInvert, nil 725 | case "dbid": 726 | return FieldId, nil 727 | default: 728 | return 0, fmt.Errorf("invalid field: %s", field) 729 | } 730 | } 731 | 732 | // Converts a StrComparer and a value into a comparer and value that can be used in string queries 733 | func cmpValGoToStr(comparer StrComparer, val interface{}) (string, string, error) { 734 | var cmpStr string 735 | switch comparer { 736 | case StrIs: 737 | cmpStr = "is" 738 | val, ok := val.(string) 739 | if !ok { 740 | return "", "", errors.New("val must be a string") 741 | } 742 | return cmpStr, val, nil 743 | case StrContains: 744 | cmpStr = "contains" 745 | val, ok := val.(string) 746 | if !ok { 747 | return "", "", errors.New("val must be a string") 748 | } 749 | return cmpStr, val, nil 750 | case StrContainsRegexp: 751 | cmpStr = "containsregexp" 752 | val, ok := val.(string) 753 | if !ok { 754 | return "", "", errors.New("val must be a string") 755 | } 756 | return cmpStr, val, nil 757 | case StrLengthGreaterThan: 758 | cmpStr = "lengt" 759 | val, ok := val.(int) 760 | if !ok { 761 | return "", "", errors.New("val must be an int") 762 | } 763 | return cmpStr, strconv.Itoa(val), nil 764 | case StrLengthLessThan: 765 | cmpStr = "lenlt" 766 | val, ok := val.(int) 767 | if !ok { 768 | return "", "", errors.New("val must be an int") 769 | } 770 | return cmpStr, strconv.Itoa(val), nil 771 | case StrLengthEqualTo: 772 | cmpStr = "leneq" 773 | val, ok := val.(int) 774 | if !ok { 775 | return "", "", errors.New("val must be an int") 776 | } 777 | return cmpStr, strconv.Itoa(val), nil 778 | default: 779 | return "", "", errors.New("invalid comparer") 780 | } 781 | } 782 | 783 | func cmpValStrToGo(strArgs []string) (StrComparer, interface{}, error) { 784 | if len(strArgs) != 2 { 785 | return 0, "", 
fmt.Errorf("parsing a comparer/val requires one comparer and one value. Got %d arguments.", len(strArgs)) 786 | } 787 | 788 | switch strArgs[0] { 789 | case "is": 790 | return StrIs, strArgs[1], nil 791 | case "contains", "ct": 792 | return StrContains, strArgs[1], nil 793 | case "containsregexp", "ctr": 794 | return StrContainsRegexp, strArgs[1], nil 795 | case "lengt": 796 | i, err := strconv.Atoi(strArgs[1]) 797 | if err != nil { 798 | return 0, nil, err 799 | } 800 | return StrLengthGreaterThan, i, nil 801 | case "lenlt": 802 | i, err := strconv.Atoi(strArgs[1]) 803 | if err != nil { 804 | return 0, nil, err 805 | } 806 | return StrLengthLessThan, i, nil 807 | case "leneq": 808 | i, err := strconv.Atoi(strArgs[1]) 809 | if err != nil { 810 | return 0, nil, err 811 | } 812 | return StrLengthEqualTo, i, nil 813 | default: 814 | return 0, "", fmt.Errorf("invalid comparer: %s", strArgs[0]) 815 | } 816 | } 817 | 818 | func CheckArgsStrToGo(strArgs []string) ([]interface{}, error) { 819 | args := make([]interface{}, 0) 820 | if len(strArgs) == 0 { 821 | return nil, errors.New("missing field") 822 | } 823 | 824 | // Parse the field 825 | field, err := fieldStrToGo(strArgs[0]) 826 | if err != nil { 827 | return nil, err 828 | } 829 | args = append(args, field) 830 | 831 | remaining := strArgs[1:] 832 | // Parse the query arguments 833 | switch args[0] { 834 | // Normal string fields 835 | case FieldAll, FieldRequestBody, FieldResponseBody, FieldAllBody, FieldWSMessage, FieldMethod, FieldHost, FieldPath, FieldStatusCode, FieldTag, FieldId: 836 | if len(remaining) != 2 { 837 | return nil, errors.New("string field searches require one comparer and one value") 838 | } 839 | 840 | cmp, val, err := cmpValStrToGo(remaining) 841 | if err != nil { 842 | return nil, err 843 | } 844 | args = append(args, cmp) 845 | args = append(args, val) 846 | // Normal key/value fields 847 | case FieldRequestHeaders, FieldResponseHeaders, FieldBothHeaders, FieldBothParam, FieldURLParam, 
FieldPostParam, FieldResponseCookie, FieldRequestCookie, FieldBothCookie: 848 | if len(remaining) == 2 { 849 | cmp, val, err := cmpValStrToGo(remaining) 850 | if err != nil { 851 | return nil, err 852 | } 853 | args = append(args, cmp) 854 | args = append(args, val) 855 | } else if len(remaining) == 4 { 856 | cmp, val, err := cmpValStrToGo(remaining[0:2]) 857 | if err != nil { 858 | return nil, err 859 | } 860 | args = append(args, cmp) 861 | args = append(args, val) 862 | 863 | cmp, val, err = cmpValStrToGo(remaining[2:4]) 864 | if err != nil { 865 | return nil, err 866 | } 867 | args = append(args, cmp) 868 | args = append(args, val) 869 | } else { 870 | return nil, errors.New("key/value field searches require either one comparer and one value or two comparer/value pairs") 871 | } 872 | 873 | // Other fields 874 | case FieldAfter, FieldBefore: 875 | if len(remaining) != 1 { 876 | return nil, errors.New("before/after take exactly one argument") 877 | } 878 | nanoseconds, err := strconv.ParseInt(remaining[0], 10, 64) 879 | if err != nil { 880 | return nil, errors.New("error parsing time") 881 | } 882 | timeVal := time.Unix(0, nanoseconds) 883 | args = append(args, timeVal) 884 | case FieldTimeRange: 885 | if len(remaining) != 2 { 886 | return nil, errors.New("time range takes exactly two arguments") 887 | } 888 | startNanoseconds, err := strconv.ParseInt(remaining[0], 10, 64) 889 | if err != nil { 890 | return nil, errors.New("error parsing start time") 891 | } 892 | startTimeVal := time.Unix(0, startNanoseconds) 893 | args = append(args, startTimeVal) 894 | 895 | endNanoseconds, err := strconv.ParseInt(remaining[1], 10, 64) 896 | if err != nil { 897 | return nil, errors.New("error parsing end time") 898 | } 899 | endTimeVal := time.Unix(0, endNanoseconds) 900 | args = append(args, endTimeVal) 901 | case FieldInvert: 902 | remainingArgs, err := CheckArgsStrToGo(remaining) 903 | if err != nil { 904 | return nil, fmt.Errorf("error with query to invert: %s", 
err.Error()) 905 | } 906 | args = append(args, remainingArgs...) 907 | default: 908 | return nil, fmt.Errorf("field not yet implemented: %s", strArgs[0]) 909 | } 910 | 911 | return args, nil 912 | } 913 | 914 | func CheckArgsGoToStr(args []interface{}) ([]string, error) { 915 | if len(args) == 0 { 916 | return nil, errors.New("no arguments") 917 | } 918 | 919 | retargs := make([]string, 0) 920 | 921 | field, ok := args[0].(SearchField) 922 | if !ok { 923 | return nil, errors.New("first argument is not a field") 924 | } 925 | 926 | strField, err := fieldGoToString(field) 927 | if err != nil { 928 | return nil, err 929 | } 930 | retargs = append(retargs, strField) 931 | 932 | switch field { 933 | case FieldAll, FieldRequestBody, FieldResponseBody, FieldAllBody, FieldWSMessage, FieldMethod, FieldHost, FieldPath, FieldStatusCode, FieldTag, FieldId: 934 | if len(args) != 3 { 935 | return nil, errors.New("string fields require exactly two arguments") 936 | } 937 | 938 | comparer, ok := args[1].(StrComparer) 939 | if !ok { 940 | return nil, errors.New("comparer must be a StrComparer") 941 | } 942 | 943 | cmpStr, valStr, err := cmpValGoToStr(comparer, args[2]) 944 | if err != nil { 945 | return nil, err 946 | } 947 | retargs = append(retargs, cmpStr) 948 | retargs = append(retargs, valStr) 949 | return retargs, nil 950 | 951 | case FieldRequestHeaders, FieldResponseHeaders, FieldBothHeaders, FieldBothParam, FieldURLParam, FieldPostParam, FieldResponseCookie, FieldRequestCookie, FieldBothCookie: 952 | if len(args) == 3 { 953 | comparer, ok := args[1].(StrComparer) 954 | if !ok { 955 | return nil, errors.New("comparer must be a StrComparer") 956 | } 957 | 958 | cmpStr, valStr, err := cmpValGoToStr(comparer, args[2]) 959 | if err != nil { 960 | return nil, err 961 | } 962 | retargs = append(retargs, cmpStr) 963 | retargs = append(retargs, valStr) 964 | 965 | return retargs, nil 966 | } else if len(args) == 5 { 967 | comparer1, ok := args[1].(StrComparer) 968 | if !ok { 969 | 
return nil, errors.New("comparer1 must be a StrComparer") 970 | } 971 | 972 | cmpStr1, valStr1, err := cmpValGoToStr(comparer1, args[2]) 973 | if err != nil { 974 | return nil, err 975 | } 976 | retargs = append(retargs, cmpStr1) 977 | retargs = append(retargs, valStr1) 978 | 979 | comparer2, ok := args[1].(StrComparer) 980 | if !ok { 981 | return nil, errors.New("comparer2 must be a StrComparer") 982 | } 983 | 984 | cmpStr2, valStr2, err := cmpValGoToStr(comparer2, args[2]) 985 | if err != nil { 986 | return nil, err 987 | } 988 | retargs = append(retargs, cmpStr2) 989 | retargs = append(retargs, valStr2) 990 | 991 | return retargs, nil 992 | } else { 993 | return nil, errors.New("key/value queries take exactly two or four arguments") 994 | } 995 | 996 | case FieldAfter, FieldBefore: 997 | if len(args) != 2 { 998 | return nil, errors.New("before/after fields require exactly one argument") 999 | } 1000 | 1001 | time, ok := args[1].(time.Time) 1002 | if !ok { 1003 | return nil, errors.New("argument must have a type of time.Time") 1004 | } 1005 | nanoseconds := time.UnixNano() 1006 | retargs = append(retargs, strconv.FormatInt(nanoseconds, 10)) 1007 | return retargs, nil 1008 | 1009 | case FieldTimeRange: 1010 | if len(args) != 3 { 1011 | return nil, errors.New("time range fields require exactly two arguments") 1012 | } 1013 | 1014 | time1, ok := args[1].(time.Time) 1015 | if !ok { 1016 | return nil, errors.New("arguments must have a type of time.Time") 1017 | } 1018 | nanoseconds1 := time1.UnixNano() 1019 | retargs = append(retargs, strconv.FormatInt(nanoseconds1, 10)) 1020 | 1021 | time2, ok := args[2].(time.Time) 1022 | if !ok { 1023 | return nil, errors.New("arguments must have a type of time.Time") 1024 | } 1025 | nanoseconds2 := time2.UnixNano() 1026 | retargs = append(retargs, strconv.FormatInt(nanoseconds2, 10)) 1027 | return retargs, nil 1028 | 1029 | case FieldInvert: 1030 | strs, err := CheckArgsGoToStr(args[1:]) 1031 | if err != nil { 1032 | return nil, 
err 1033 | } 1034 | retargs = append(retargs, strs...) 1035 | return retargs, nil 1036 | 1037 | default: 1038 | return nil, fmt.Errorf("invalid field") 1039 | } 1040 | } 1041 | 1042 | func strPhraseToGoPhrase(phrase StrQueryPhrase) (QueryPhrase, error) { 1043 | goPhrase := make(QueryPhrase, len(phrase)) 1044 | for i, strArgs := range phrase { 1045 | var err error 1046 | goPhrase[i], err = CheckArgsStrToGo(strArgs) 1047 | if err != nil { 1048 | return nil, fmt.Errorf("Error with argument set %d: %s", i, err.Error()) 1049 | } 1050 | } 1051 | return goPhrase, nil 1052 | } 1053 | 1054 | func goPhraseToStrPhrase(phrase QueryPhrase) (StrQueryPhrase, error) { 1055 | strPhrase := make(StrQueryPhrase, len(phrase)) 1056 | for i, goArgs := range phrase { 1057 | var err error 1058 | strPhrase[i], err = CheckArgsGoToStr(goArgs) 1059 | if err != nil { 1060 | return nil, fmt.Errorf("Error with argument set %d: %s", i, err.Error()) 1061 | } 1062 | } 1063 | return strPhrase, nil 1064 | } 1065 | 1066 | // Converts a StrMessageQuery into a MessageQuery 1067 | func StrQueryToMsgQuery(query StrMessageQuery) (MessageQuery, error) { 1068 | goQuery := make(MessageQuery, len(query)) 1069 | for i, phrase := range query { 1070 | var err error 1071 | goQuery[i], err = strPhraseToGoPhrase(phrase) 1072 | if err != nil { 1073 | return nil, fmt.Errorf("Error with phrase %d: %s", i, err.Error()) 1074 | } 1075 | } 1076 | return goQuery, nil 1077 | } 1078 | 1079 | // Converts a MessageQuery into a StrMessageQuery 1080 | func MsgQueryToStrQuery(query MessageQuery) (StrMessageQuery, error) { 1081 | strQuery := make(StrMessageQuery, len(query)) 1082 | for i, phrase := range query { 1083 | var err error 1084 | strQuery[i], err = goPhraseToStrPhrase(phrase) 1085 | if err != nil { 1086 | return nil, fmt.Errorf("Error with phrase %d: %s", i, err.Error()) 1087 | } 1088 | } 1089 | return strQuery, nil 1090 | } 1091 | -------------------------------------------------------------------------------- 
/search_test.go: -------------------------------------------------------------------------------- 1 | package puppy 2 | 3 | import ( 4 | "runtime" 5 | "strconv" 6 | "testing" 7 | ) 8 | 9 | func checkSearch(t *testing.T, req *ProxyRequest, expected bool, args ...interface{}) { 10 | checker, err := NewRequestChecker(args...) 11 | if err != nil { 12 | t.Error(err.Error()) 13 | } 14 | result := checker(req) 15 | if result != expected { 16 | _, f, ln, _ := runtime.Caller(1) 17 | t.Errorf("Failed search test at %s:%d. Expected %s, got %s", f, ln, strconv.FormatBool(expected), strconv.FormatBool(result)) 18 | } 19 | } 20 | 21 | func TestAllSearch(t *testing.T) { 22 | checker, err := NewRequestChecker(FieldAll, StrContains, "foo") 23 | if err != nil { 24 | t.Error(err.Error()) 25 | } 26 | req := testReq() 27 | if !checker(req) { 28 | t.Error("Failed to match FieldAll, StrContains") 29 | } 30 | } 31 | 32 | func TestBodySearch(t *testing.T) { 33 | req := testReq() 34 | 35 | checkSearch(t, req, true, FieldAllBody, StrContains, "foo") 36 | checkSearch(t, req, true, FieldAllBody, StrContains, "oo=b") 37 | checkSearch(t, req, true, FieldAllBody, StrContains, "BBBB") 38 | checkSearch(t, req, false, FieldAllBody, StrContains, "FOO") 39 | 40 | checkSearch(t, req, true, FieldResponseBody, StrContains, "BBBB") 41 | checkSearch(t, req, false, FieldResponseBody, StrContains, "foo") 42 | 43 | checkSearch(t, req, false, FieldRequestBody, StrContains, "BBBB") 44 | checkSearch(t, req, true, FieldRequestBody, StrContains, "foo") 45 | } 46 | 47 | func TestHeaderSearch(t *testing.T) { 48 | req := testReq() 49 | 50 | checkSearch(t, req, true, FieldBothHeaders, StrContains, "Foo") 51 | checkSearch(t, req, true, FieldBothHeaders, StrContains, "Bar") 52 | checkSearch(t, req, true, FieldBothHeaders, StrContains, "Foo", StrContains, "Bar") 53 | checkSearch(t, req, false, FieldBothHeaders, StrContains, "Bar", StrContains, "Bar") 54 | checkSearch(t, req, false, FieldBothHeaders, StrContains, "Foo", 
StrContains, "Foo") 55 | } 56 | 57 | func TestRegexpSearch(t *testing.T) { 58 | req := testReq() 59 | 60 | checkSearch(t, req, true, FieldRequestBody, StrContainsRegexp, "o.b") 61 | checkSearch(t, req, true, FieldRequestBody, StrContainsRegexp, "baz$") 62 | checkSearch(t, req, true, FieldRequestBody, StrContainsRegexp, "^f.+z") 63 | checkSearch(t, req, false, FieldRequestBody, StrContainsRegexp, "^baz") 64 | } 65 | -------------------------------------------------------------------------------- /signer.go: -------------------------------------------------------------------------------- 1 | package puppy 2 | 3 | /* 4 | Copyright (c) 2012 Elazar Leibovich. All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are 8 | met: 9 | 10 | * Redistributions of source code must retain the above copyright 11 | notice, this list of conditions and the following disclaimer. 12 | * Redistributions in binary form must reproduce the above 13 | copyright notice, this list of conditions and the following disclaimer 14 | in the documentation and/or other materials provided with the 15 | distribution. 16 | * Neither the name of Elazar Leibovich. nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 21 | "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 22 | LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 23 | A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT 24 | OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 25 | SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 26 | LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 27 | DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 28 | THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 29 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 30 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 31 | */ 32 | 33 | /* 34 | Signer code used here was taken from: 35 | https://github.com/elazarl/goproxy 36 | */ 37 | 38 | import ( 39 | "crypto/aes" 40 | "crypto/cipher" 41 | "crypto/rsa" 42 | "crypto/sha1" 43 | "crypto/sha256" 44 | "crypto/tls" 45 | "crypto/x509" 46 | "crypto/x509/pkix" 47 | "errors" 48 | "math/big" 49 | "net" 50 | "runtime" 51 | "sort" 52 | "time" 53 | ) 54 | 55 | /* 56 | counterecryptor.go 57 | */ 58 | 59 | type counterEncryptorRand struct { 60 | cipher cipher.Block 61 | counter []byte 62 | rand []byte 63 | ix int 64 | } 65 | 66 | func newCounterEncryptorRandFromKey(key interface{}, seed []byte) (r counterEncryptorRand, err error) { 67 | var keyBytes []byte 68 | switch key := key.(type) { 69 | case *rsa.PrivateKey: 70 | keyBytes = x509.MarshalPKCS1PrivateKey(key) 71 | default: 72 | err = errors.New("only RSA keys supported") 73 | return 74 | } 75 | h := sha256.New() 76 | if r.cipher, err = aes.NewCipher(h.Sum(keyBytes)[:aes.BlockSize]); err != nil { 77 | return 78 | } 79 | r.counter = make([]byte, r.cipher.BlockSize()) 80 | if seed != nil { 81 | copy(r.counter, h.Sum(seed)[:r.cipher.BlockSize()]) 82 | } 83 | r.rand = make([]byte, r.cipher.BlockSize()) 84 | r.ix = len(r.rand) 85 | return 86 | } 87 | 88 | func (c *counterEncryptorRand) Seed(b []byte) { 89 | if len(b) != len(c.counter) { 90 | panic("SetCounter: wrong counter size") 91 | } 92 | copy(c.counter, b) 93 | } 94 | 95 | func (c *counterEncryptorRand) 
refill() { 96 | c.cipher.Encrypt(c.rand, c.counter) 97 | for i := 0; i < len(c.counter); i++ { 98 | if c.counter[i]++; c.counter[i] != 0 { 99 | break 100 | } 101 | } 102 | c.ix = 0 103 | } 104 | 105 | func (c *counterEncryptorRand) Read(b []byte) (n int, err error) { 106 | if c.ix == len(c.rand) { 107 | c.refill() 108 | } 109 | if n = len(c.rand) - c.ix; n > len(b) { 110 | n = len(b) 111 | } 112 | copy(b, c.rand[c.ix:c.ix+n]) 113 | c.ix += n 114 | return 115 | } 116 | 117 | /* 118 | signer.go 119 | */ 120 | 121 | func hashSorted(lst []string) []byte { 122 | c := make([]string, len(lst)) 123 | copy(c, lst) 124 | sort.Strings(c) 125 | h := sha1.New() 126 | for _, s := range c { 127 | h.Write([]byte(s + ",")) 128 | } 129 | return h.Sum(nil) 130 | } 131 | 132 | func hashSortedBigInt(lst []string) *big.Int { 133 | rv := new(big.Int) 134 | rv.SetBytes(hashSorted(lst)) 135 | return rv 136 | } 137 | 138 | var goproxySignerVersion = ":goroxy1" 139 | 140 | func signHost(ca tls.Certificate, hosts []string) (cert tls.Certificate, err error) { 141 | var x509ca *x509.Certificate 142 | 143 | // Use the provided ca and not the global GoproxyCa for certificate generation. 144 | if x509ca, err = x509.ParseCertificate(ca.Certificate[0]); err != nil { 145 | return 146 | } 147 | start := time.Unix(0, 0) 148 | end, err := time.Parse("2006-01-02", "2049-12-31") 149 | if err != nil { 150 | panic(err) 151 | } 152 | hash := hashSorted(append(hosts, goproxySignerVersion, ":"+runtime.Version())) 153 | serial := new(big.Int) 154 | serial.SetBytes(hash) 155 | template := x509.Certificate{ 156 | // TODO(elazar): instead of this ugly hack, just encode the certificate and hash the binary form. 
157 | SerialNumber: serial, 158 | Issuer: x509ca.Subject, 159 | Subject: pkix.Name{ 160 | Organization: []string{"GoProxy untrusted MITM proxy Inc"}, 161 | }, 162 | NotBefore: start, 163 | NotAfter: end, 164 | 165 | KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, 166 | ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, 167 | BasicConstraintsValid: true, 168 | } 169 | for _, h := range hosts { 170 | if ip := net.ParseIP(h); ip != nil { 171 | template.IPAddresses = append(template.IPAddresses, ip) 172 | } else { 173 | template.DNSNames = append(template.DNSNames, h) 174 | } 175 | } 176 | var csprng counterEncryptorRand 177 | if csprng, err = newCounterEncryptorRandFromKey(ca.PrivateKey, hash); err != nil { 178 | return 179 | } 180 | var certpriv *rsa.PrivateKey 181 | if certpriv, err = rsa.GenerateKey(&csprng, 1024); err != nil { 182 | return 183 | } 184 | var derBytes []byte 185 | if derBytes, err = x509.CreateCertificate(&csprng, &template, x509ca, &certpriv.PublicKey, ca.PrivateKey); err != nil { 186 | return 187 | } 188 | return tls.Certificate{ 189 | Certificate: [][]byte{derBytes, ca.Certificate[0]}, 190 | PrivateKey: certpriv, 191 | }, nil 192 | } 193 | -------------------------------------------------------------------------------- /sqlitestorage_test.go: -------------------------------------------------------------------------------- 1 | package puppy 2 | 3 | import ( 4 | "runtime" 5 | "testing" 6 | "time" 7 | ) 8 | 9 | func testStorage() *SQLiteStorage { 10 | s, _ := InMemoryStorage(NullLogger()) 11 | return s 12 | } 13 | 14 | func checkTags(t *testing.T, result, expected []string) { 15 | _, f, ln, _ := runtime.Caller(1) 16 | 17 | if len(result) != len(expected) { 18 | t.Errorf("Failed tag test at %s:%d. Expected %s, got %s", f, ln, expected, result) 19 | return 20 | } 21 | 22 | for i, a := range result { 23 | b := expected[i] 24 | if a != b { 25 | t.Errorf("Failed tag test at %s:%d. 
Expected %s, got %s", f, ln, expected, result) 26 | return 27 | } 28 | } 29 | } 30 | 31 | func TestTagging(t *testing.T) { 32 | req := testReq() 33 | storage := testStorage() 34 | defer storage.Close() 35 | 36 | err := SaveNewRequest(storage, req) 37 | testErr(t, err) 38 | req1, err := storage.LoadRequest(req.DbId) 39 | testErr(t, err) 40 | checkTags(t, req1.Tags(), []string{}) 41 | 42 | req.AddTag("foo") 43 | req.AddTag("bar") 44 | err = UpdateRequest(storage, req) 45 | testErr(t, err) 46 | req2, err := storage.LoadRequest(req.DbId) 47 | testErr(t, err) 48 | checkTags(t, req2.Tags(), []string{"foo", "bar"}) 49 | 50 | req.RemoveTag("foo") 51 | err = UpdateRequest(storage, req) 52 | testErr(t, err) 53 | req3, err := storage.LoadRequest(req.DbId) 54 | testErr(t, err) 55 | checkTags(t, req3.Tags(), []string{"bar"}) 56 | } 57 | 58 | func TestTime(t *testing.T) { 59 | req := testReq() 60 | req.StartDatetime = time.Unix(0, 1234567) 61 | req.EndDatetime = time.Unix(0, 2234567) 62 | storage := testStorage() 63 | defer storage.Close() 64 | 65 | err := SaveNewRequest(storage, req) 66 | testErr(t, err) 67 | 68 | req1, err := storage.LoadRequest(req.DbId) 69 | testErr(t, err) 70 | tstart := req1.StartDatetime.UnixNano() 71 | tend := req1.EndDatetime.UnixNano() 72 | 73 | if tstart != 1234567 { 74 | t.Errorf("Start time not saved properly. Expected 1234567, got %d", tstart) 75 | } 76 | 77 | if tend != 2234567 { 78 | t.Errorf("End time not saved properly. Expected 1234567, got %d", tend) 79 | } 80 | } 81 | -------------------------------------------------------------------------------- /storage.go: -------------------------------------------------------------------------------- 1 | package puppy 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | ) 7 | 8 | // An interface that represents something that can be used to store data from the proxy 9 | type MessageStorage interface { 10 | 11 | // Close the storage 12 | Close() 13 | 14 | // Update an existing request in the storage. 
Requires that it has already been saved 15 | UpdateRequest(req *ProxyRequest) error 16 | // Save a new instance of the request in the storage regardless of if it has already been saved 17 | SaveNewRequest(req *ProxyRequest) error 18 | // Load a request given a unique id 19 | LoadRequest(reqid string) (*ProxyRequest, error) 20 | // Load the unmangled version of a request given a unique id 21 | LoadUnmangledRequest(reqid string) (*ProxyRequest, error) 22 | // Delete a request given a unique id 23 | DeleteRequest(reqid string) error 24 | 25 | // Update an existing response in the storage. Requires that it has already been saved 26 | UpdateResponse(rsp *ProxyResponse) error 27 | // Save a new instance of the response in the storage regardless of if it has already been saved 28 | SaveNewResponse(rsp *ProxyResponse) error 29 | // Load a response given a unique id 30 | LoadResponse(rspid string) (*ProxyResponse, error) 31 | // Load the unmangled version of a response given a unique id 32 | LoadUnmangledResponse(rspid string) (*ProxyResponse, error) 33 | // Delete a response given a unique id 34 | DeleteResponse(rspid string) error 35 | 36 | // Update an existing websocket message in the storage. 
Requires that it has already been saved 37 | UpdateWSMessage(req *ProxyRequest, wsm *ProxyWSMessage) error 38 | // Save a new instance of the websocket message in the storage regardless of if it has already been saved 39 | SaveNewWSMessage(req *ProxyRequest, wsm *ProxyWSMessage) error 40 | // Load a websocket message given a unique id 41 | LoadWSMessage(wsmid string) (*ProxyWSMessage, error) 42 | // Load the unmangled version of a websocket message given a unique id 43 | LoadUnmangledWSMessage(wsmid string) (*ProxyWSMessage, error) 44 | // Delete a websocket message given a unique id 45 | DeleteWSMessage(wsmid string) error 46 | 47 | // Get list of the keys for all of the stored requests 48 | RequestKeys() ([]string, error) 49 | 50 | // A function to perform a search of requests in the storage. Same arguments as NewRequestChecker 51 | Search(limit int64, args ...interface{}) ([]*ProxyRequest, error) 52 | 53 | // A function to naively check every function in storage with the given function and return the ones that match 54 | CheckRequests(limit int64, checker RequestChecker) ([]*ProxyRequest, error) 55 | 56 | // Return a list of all the queries stored in the MessageStorage 57 | AllSavedQueries() ([]*SavedQuery, error) 58 | // Save a query in the storage with a given name. 
If the name is already in storage, it should be overwritten 59 | SaveQuery(name string, query MessageQuery) error 60 | // Load a query by name from the storage 61 | LoadQuery(name string) (MessageQuery, error) 62 | // Delete a query by name from the storage 63 | DeleteQuery(name string) error 64 | 65 | // Add a storage watcher to make callbacks to on message saves 66 | Watch(watcher StorageWatcher) error 67 | // Remove a storage watcher from the storage 68 | EndWatch(watcher StorageWatcher) error 69 | 70 | // Set/get plugin values 71 | SetPluginValue(key string, value string) error 72 | GetPluginValue(key string) (string, error) 73 | } 74 | 75 | type StorageWatcher interface { 76 | // Callback for when a new request is saved 77 | NewRequestSaved(ms MessageStorage, req *ProxyRequest) 78 | // Callback for when a request is updated 79 | RequestUpdated(ms MessageStorage, req *ProxyRequest) 80 | // Callback for when a request is deleted 81 | RequestDeleted(ms MessageStorage, DbId string) 82 | 83 | // Callback for when a new response is saved 84 | NewResponseSaved(ms MessageStorage, rsp *ProxyResponse) 85 | // Callback for when a response is updated 86 | ResponseUpdated(ms MessageStorage, rsp *ProxyResponse) 87 | // Callback for when a response is deleted 88 | ResponseDeleted(ms MessageStorage, DbId string) 89 | 90 | // Callback for when a new wsmessage is saved 91 | NewWSMessageSaved(ms MessageStorage, req *ProxyRequest, wsm *ProxyWSMessage) 92 | // Callback for when a wsmessage is updated 93 | WSMessageUpdated(ms MessageStorage, req *ProxyRequest, wsm *ProxyWSMessage) 94 | // Callback for when a wsmessage is deleted 95 | WSMessageDeleted(ms MessageStorage, DbId string) 96 | } 97 | 98 | // An error to be returned if a query is not supported 99 | const QueryNotSupported = ConstErr("custom query not supported") 100 | 101 | // A type representing a search query that is stored in a MessageStorage 102 | type SavedQuery struct { 103 | Name string 104 | Query MessageQuery 105 
| } 106 | 107 | /* 108 | General storage functions 109 | */ 110 | 111 | // Save a new request and new versions of all its dependant messages (response, websocket messages, and unmangled versions of everything). 112 | func SaveNewRequest(ms MessageStorage, req *ProxyRequest) error { 113 | if req.ServerResponse != nil { 114 | if err := SaveNewResponse(ms, req.ServerResponse); err != nil { 115 | return fmt.Errorf("error saving server response to request: %s", err.Error()) 116 | } 117 | } 118 | 119 | if req.Unmangled != nil { 120 | if req.DbId != "" && req.DbId == req.Unmangled.DbId { 121 | return errors.New("request has same DbId as unmangled version") 122 | } 123 | if err := SaveNewRequest(ms, req.Unmangled); err != nil { 124 | return fmt.Errorf("error saving unmangled version of request: %s", err.Error()) 125 | } 126 | } 127 | 128 | if err := ms.SaveNewRequest(req); err != nil { 129 | return fmt.Errorf("error saving new request: %s", err.Error()) 130 | } 131 | 132 | for _, wsm := range req.WSMessages { 133 | if err := SaveNewWSMessage(ms, req, wsm); err != nil { 134 | return fmt.Errorf("error saving request's ws message: %s", err.Error()) 135 | } 136 | } 137 | 138 | return nil 139 | } 140 | 141 | // Update a request and all its dependent messages. If the request has a DbId it will be updated, otherwise it will be inserted into the database and have its DbId updated. 
Same for all dependent messages 142 | func UpdateRequest(ms MessageStorage, req *ProxyRequest) error { 143 | if req.ServerResponse != nil { 144 | if err := UpdateResponse(ms, req.ServerResponse); err != nil { 145 | return fmt.Errorf("error saving server response to request: %s", err.Error()) 146 | } 147 | } 148 | 149 | if req.Unmangled != nil { 150 | if req.DbId != "" && req.DbId == req.Unmangled.DbId { 151 | return errors.New("request has same DbId as unmangled version") 152 | } 153 | if err := UpdateRequest(ms, req.Unmangled); err != nil { 154 | return fmt.Errorf("error saving unmangled version of request: %s", err.Error()) 155 | } 156 | } 157 | 158 | if req.DbId == "" { 159 | if err := ms.SaveNewRequest(req); err != nil { 160 | return fmt.Errorf("error saving new request: %s", err.Error()) 161 | } 162 | } else { 163 | if err := ms.UpdateRequest(req); err != nil { 164 | return fmt.Errorf("error updating request: %s", err.Error()) 165 | } 166 | } 167 | 168 | for _, wsm := range req.WSMessages { 169 | if err := UpdateWSMessage(ms, req, wsm); err != nil { 170 | return fmt.Errorf("error saving request's ws message: %s", err.Error()) 171 | } 172 | } 173 | 174 | return nil 175 | } 176 | 177 | // Save a new response/unmangled response to the message storage regardless of the existence of a DbId 178 | func SaveNewResponse(ms MessageStorage, rsp *ProxyResponse) error { 179 | if rsp.Unmangled != nil { 180 | if rsp.DbId != "" && rsp.DbId == rsp.Unmangled.DbId { 181 | return errors.New("response has same DbId as unmangled version") 182 | } 183 | if err := SaveNewResponse(ms, rsp.Unmangled); err != nil { 184 | return fmt.Errorf("error saving unmangled version of response: %s", err.Error()) 185 | } 186 | } 187 | 188 | return ms.SaveNewResponse(rsp) 189 | } 190 | 191 | // Update a response and its unmangled version in the database. 
If it has a DbId, it will be updated, otherwise a new version will be saved in the database 192 | func UpdateResponse(ms MessageStorage, rsp *ProxyResponse) error { 193 | if rsp.Unmangled != nil { 194 | if rsp.DbId != "" && rsp.DbId == rsp.Unmangled.DbId { 195 | return errors.New("response has same DbId as unmangled version") 196 | } 197 | if err := UpdateResponse(ms, rsp.Unmangled); err != nil { 198 | return fmt.Errorf("error saving unmangled version of response: %s", err.Error()) 199 | } 200 | } 201 | 202 | if rsp.DbId == "" { 203 | return ms.SaveNewResponse(rsp) 204 | } else { 205 | return ms.UpdateResponse(rsp) 206 | } 207 | } 208 | 209 | // Save a new websocket emssage/unmangled version to the message storage regardless of the existence of a DbId 210 | func SaveNewWSMessage(ms MessageStorage, req *ProxyRequest, wsm *ProxyWSMessage) error { 211 | if wsm.Unmangled != nil { 212 | if wsm.DbId != "" && wsm.DbId == wsm.Unmangled.DbId { 213 | return errors.New("websocket message has same DbId as unmangled version") 214 | } 215 | if err := SaveNewWSMessage(ms, nil, wsm.Unmangled); err != nil { 216 | return fmt.Errorf("error saving unmangled version of websocket message: %s", err.Error()) 217 | } 218 | } 219 | 220 | return ms.SaveNewWSMessage(req, wsm) 221 | } 222 | 223 | // Update a websocket message and its unmangled version in the database. 
If it has a DbId, it will be updated, otherwise a new version will be saved in the database 224 | func UpdateWSMessage(ms MessageStorage, req *ProxyRequest, wsm *ProxyWSMessage) error { 225 | if wsm.Unmangled != nil { 226 | if wsm.DbId != "" && wsm.Unmangled.DbId == wsm.DbId { 227 | return errors.New("websocket message has same DbId as unmangled version") 228 | } 229 | if err := UpdateWSMessage(ms, nil, wsm.Unmangled); err != nil { 230 | return fmt.Errorf("error saving unmangled version of websocket message: %s", err.Error()) 231 | } 232 | } 233 | 234 | if wsm.DbId == "" { 235 | return ms.SaveNewWSMessage(req, wsm) 236 | } else { 237 | return ms.UpdateWSMessage(req, wsm) 238 | } 239 | } 240 | -------------------------------------------------------------------------------- /testutil.go: -------------------------------------------------------------------------------- 1 | package puppy 2 | 3 | import ( 4 | "runtime" 5 | "testing" 6 | ) 7 | 8 | func testReq() *ProxyRequest { 9 | testReq, _ := ProxyRequestFromBytes( 10 | []byte("POST /?foo=bar HTTP/1.1\r\nFoo: Bar\r\nCookie: cookie=choco\r\nContent-Length: 7\r\n\r\nfoo=baz"), 11 | "foobaz", 12 | 80, 13 | false, 14 | ) 15 | 16 | testRsp, _ := ProxyResponseFromBytes( 17 | []byte("HTTP/1.1 200 OK\r\nSet-Cookie: cockie=cocks\r\nContent-Length: 4\r\n\r\nBBBB"), 18 | ) 19 | 20 | testReq.ServerResponse = testRsp 21 | 22 | return testReq 23 | } 24 | 25 | func testErr(t *testing.T, err error) { 26 | if err != nil { 27 | _, f, ln, _ := runtime.Caller(1) 28 | t.Errorf("Failed test with error at %s:%d. Error: %s", f, ln, err) 29 | } 30 | 31 | } 32 | -------------------------------------------------------------------------------- /util.go: -------------------------------------------------------------------------------- 1 | package puppy 2 | 3 | import ( 4 | "io/ioutil" 5 | "log" 6 | "sync" 7 | ) 8 | 9 | // A type that can be used to create specific error types. 
ie `const QueryNotSupported = ConstErr("custom query not supported")` 10 | type ConstErr string 11 | 12 | func (e ConstErr) Error() string { return string(e) } 13 | 14 | func DuplicateBytes(bs []byte) []byte { 15 | retBs := make([]byte, len(bs)) 16 | copy(retBs, bs) 17 | return retBs 18 | } 19 | 20 | func IdCounter() func() int { 21 | var nextId int = 1 22 | var nextIdMtx sync.Mutex 23 | return func() int { 24 | nextIdMtx.Lock() 25 | defer nextIdMtx.Unlock() 26 | ret := nextId 27 | nextId++ 28 | return ret 29 | } 30 | } 31 | 32 | func NullLogger() *log.Logger { 33 | return log.New(ioutil.Discard, "", log.Lshortfile) 34 | } 35 | 36 | // A helper type to sort requests by submission time: ie sort.Sort(ReqSort(reqs)) 37 | type ReqSort []*ProxyRequest 38 | 39 | func (reql ReqSort) Len() int { 40 | return len(reql) 41 | } 42 | 43 | func (reql ReqSort) Swap(i int, j int) { 44 | reql[i], reql[j] = reql[j], reql[i] 45 | } 46 | 47 | func (reql ReqSort) Less(i int, j int) bool { 48 | return reql[j].StartDatetime.After(reql[i].StartDatetime) 49 | } 50 | 51 | // A helper type to sort websocket messages by timestamp: ie sort.Sort(WSSort(req.WSMessages)) 52 | type WSSort []*ProxyWSMessage 53 | 54 | func (wsml WSSort) Len() int { 55 | return len(wsml) 56 | } 57 | 58 | func (wsml WSSort) Swap(i int, j int) { 59 | wsml[i], wsml[j] = wsml[j], wsml[i] 60 | } 61 | 62 | func (wsml WSSort) Less(i int, j int) bool { 63 | return wsml[j].Timestamp.After(wsml[i].Timestamp) 64 | } 65 | -------------------------------------------------------------------------------- /webui.go: -------------------------------------------------------------------------------- 1 | package puppy 2 | 3 | import ( 4 | "encoding/pem" 5 | "html/template" 6 | "net/http" 7 | "strings" 8 | ) 9 | 10 | 11 | func responseHeaders(w http.ResponseWriter) { 12 | w.Header().Set("Connection", "close") 13 | w.Header().Set("Cache-control", "no-cache") 14 | w.Header().Set("Pragma", "no-cache") 15 | w.Header().Set("Cache-control", 
"no-store") 16 | w.Header().Set("X-Frame-Options", "DENY") 17 | } 18 | 19 | // Generate a proxy-compatible web handler that allows users to download certificates and view responses stored in the storage used by the proxyin the browser 20 | func CreateWebUIHandler() ProxyWebUIHandler { 21 | var masterSrc string = ` 22 | 23 | 24 | {{block "title" .}}Puppy Proxy{{end}} 25 | {{block "head" .}}{{end}} 26 | 27 | 28 | {{block "body" .}}{{end}} 29 | 30 | 31 | ` 32 | var masterTpl *template.Template 33 | 34 | var homeSrc string = ` 35 | {{define "title"}}Puppy Home{{end}} 36 | {{define "body"}} 37 |

Welcome to Puppy

38 |

41 | {{end}} 42 | ` 43 | var homeTpl *template.Template 44 | 45 | var certsSrc string = ` 46 | {{define "title"}}CA Certificate{{end}} 47 | {{define "body"}} 48 |

Downlad this CA cert and add it to your browser to intercept HTTPS messages

49 |

Download

50 | {{end}} 51 | ` 52 | var certsTpl *template.Template 53 | 54 | var rspviewSrc string = ` 55 | {{define "title"}}Response Viewer{{end}} 56 | {{define "head"}} 57 | 63 | {{end}} 64 | {{define "body"}} 65 |

Enter a response ID below to view it in the browser

66 | 67 | {{end}} 68 | ` 69 | var rspviewTpl *template.Template 70 | 71 | var err error 72 | masterTpl, err = template.New("master").Parse(masterSrc) 73 | if err != nil { 74 | panic(err) 75 | } 76 | 77 | homeTpl, err = template.Must(masterTpl.Clone()).Parse(homeSrc) 78 | if err != nil { 79 | panic(err) 80 | } 81 | 82 | certsTpl, err = template.Must(masterTpl.Clone()).Parse(certsSrc) 83 | if err != nil { 84 | panic(err) 85 | } 86 | 87 | rspviewTpl, err = template.Must(masterTpl.Clone()).Parse(rspviewSrc) 88 | if err != nil { 89 | panic(err) 90 | } 91 | 92 | var WebUIRootHandler = func(w http.ResponseWriter, r *http.Request, iproxy *InterceptingProxy) { 93 | err := homeTpl.Execute(w, nil) 94 | if err != nil { 95 | http.Error(w, err.Error(), http.StatusInternalServerError) 96 | return 97 | } 98 | } 99 | 100 | var WebUICertsHandler = func(w http.ResponseWriter, r *http.Request, iproxy *InterceptingProxy, path []string) { 101 | if len(path) > 0 && path[0] == "download" { 102 | cert := iproxy.GetCACertificate() 103 | if cert == nil { 104 | w.Write([]byte("no active certs to download")) 105 | return 106 | } 107 | 108 | pemData := pem.EncodeToMemory( 109 | &pem.Block{ 110 | Type: "CERTIFICATE", 111 | Bytes: cert.Certificate[0], 112 | }, 113 | ) 114 | w.Header().Set("Content-Type", "application/octet-stream") 115 | w.Header().Set("Content-Disposition", "attachment; filename=\"cert.pem\"") 116 | w.Write(pemData) 117 | return 118 | } 119 | err := certsTpl.Execute(w, nil) 120 | if err != nil { 121 | http.Error(w, err.Error(), http.StatusInternalServerError) 122 | return 123 | } 124 | } 125 | 126 | var viewResponseHeaders = func(w http.ResponseWriter) { 127 | w.Header().Del("Cookie") 128 | } 129 | 130 | var WebUIRspHandler = func(w http.ResponseWriter, r *http.Request, iproxy *InterceptingProxy, path []string) { 131 | if len(path) > 0 { 132 | reqid := path[0] 133 | ms := iproxy.GetProxyStorage() 134 | req, err := ms.LoadRequest(reqid) 135 | if err != nil { 136 | http.Error(w, 
err.Error(), http.StatusInternalServerError) 137 | return 138 | } 139 | rsp := req.ServerResponse 140 | for k, v := range rsp.Header { 141 | for _, vv := range v { 142 | w.Header().Add(k, vv) 143 | } 144 | } 145 | viewResponseHeaders(w) 146 | w.WriteHeader(rsp.StatusCode) 147 | w.Write(rsp.BodyBytes()) 148 | return 149 | } 150 | err := rspviewTpl.Execute(w, nil) 151 | if err != nil { 152 | http.Error(w, err.Error(), http.StatusInternalServerError) 153 | return 154 | } 155 | } 156 | 157 | return func(w http.ResponseWriter, r *http.Request, iproxy *InterceptingProxy) { 158 | responseHeaders(w) 159 | parts := strings.Split(r.URL.Path, "/") 160 | switch parts[1] { 161 | case "": 162 | WebUIRootHandler(w, r, iproxy) 163 | case "certs": 164 | WebUICertsHandler(w, r, iproxy, parts[2:]) 165 | case "rsp": 166 | WebUIRspHandler(w, r, iproxy, parts[2:]) 167 | } 168 | } 169 | 170 | } 171 | --------------------------------------------------------------------------------