├── .github └── workflows │ └── mps.yml ├── .gitignore ├── LICENSE ├── README.md ├── README_ZH.md ├── _examples ├── basic-auth │ └── main.go ├── cascade-proxy │ └── main.go ├── generateCert │ ├── openssl-gen.sh │ └── openssl.cnf ├── mitm-proxy │ ├── README.md │ ├── ca.crt │ ├── ca.key │ └── main.go ├── reverse-proxy │ └── main.go ├── simple-http-proxy │ └── main.go └── websocket-proxy │ └── main.go ├── cert ├── cert.go ├── container.go └── mem_provider.go ├── chunked.go ├── context.go ├── counter_encryptor.go ├── counter_encryptor_test.go ├── filter.go ├── filter_group.go ├── forward_handler.go ├── forward_handler_test.go ├── go.mod ├── go.sum ├── handle.go ├── http_proxy.go ├── http_proxy_test.go ├── middleware.go ├── middleware ├── basicAuth.go └── singleHostReverseProxy.go ├── mitm_handler.go ├── mitm_handler_test.go ├── mps.go ├── pool ├── buffer.go ├── conn_container.go ├── conn_options.go └── conn_provider.go ├── reverse_handler.go ├── reverse_handler_test.go ├── transport.go ├── tunnel_handler.go ├── tunnel_handler_test.go ├── websocket_handler.go └── websocket_handler_test.go /.github/workflows/mps.yml: -------------------------------------------------------------------------------- 1 | name: MPS 2 | 3 | on: 4 | push: 5 | branches: [ master ] 6 | pull_request: 7 | branches: [ master ] 8 | 9 | jobs: 10 | 11 | build: 12 | name: Build 13 | runs-on: ubuntu-latest 14 | steps: 15 | 16 | - name: Set up Go 1.x 17 | uses: actions/setup-go@v2 18 | with: 19 | go-version: ^1.16 20 | id: go 21 | 22 | - name: Check out code into the Go module directory 23 | uses: actions/checkout@v2 24 | 25 | - name: Get dependencies 26 | run: | 27 | go get -v -t -d ./... 28 | if [ -f Gopkg.toml ]; then 29 | curl https://raw.githubusercontent.com/golang/dep/master/install.sh | sh 30 | dep ensure 31 | fi 32 | 33 | - name: Test 34 | run: go test -v . 
35 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Binaries for programs and plugins 2 | *.exe 3 | *.exe~ 4 | *.dll 5 | *.so 6 | *.dylib 7 | 8 | # Test binary, built with `go test -c` 9 | *.test 10 | 11 | # Output of the go coverage tool, specifically when used with LiteIDE 12 | *.out 13 | 14 | # Dependency directories (remove the comment below to include it) 15 | # vendor/ 16 | .DS_Store 17 | .idea 18 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2020, Telanflow 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | 1. Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | 2. Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | 3. Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 30 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

2 |
MPS
3 |

4 | 5 | English | [🇨🇳中文](README_ZH.md) 6 | 7 | ## 📖 Introduction 8 | ![MPS](https://github.com/telanflow/mps/workflows/MPS/badge.svg) 9 | ![stars](https://img.shields.io/github/stars/telanflow/mps) 10 | ![GitHub release (latest SemVer)](https://img.shields.io/github/v/release/telanflow/mps) 11 | ![GitHub go.mod Go version](https://img.shields.io/github/go-mod/go-version/telanflow/mps) 12 | [![license](https://img.shields.io/github/license/telanflow/mps)](https://github.com/telanflow/mps/blob/master/LICENSE) 13 | 14 | MPS (middle-proxy-server) is a high-performance middle proxy library. It supports HTTP, HTTPS, WebSocket, ForwardProxy, ReverseProxy, TunnelProxy, and MitmProxy. 15 | 16 | ## 🚀 Features 17 | - [X] Http Proxy 18 | - [X] Https Proxy 19 | - [X] Forward Proxy 20 | - [X] Reverse Proxy 21 | - [X] Tunnel Proxy 22 | - [X] Mitm Proxy (Man-in-the-middle) 23 | - [X] WebSocket Proxy 24 | 25 | ## 🧰 Install 26 | ``` 27 | go get -u github.com/telanflow/mps 28 | ``` 29 | 30 | ## 🛠 How to use 31 | A simple proxy service 32 | 33 | ```go 34 | package main 35 | 36 | import ( 37 | "github.com/telanflow/mps" 38 | "log" 39 | "net/http" 40 | ) 41 | 42 | func main() { 43 | proxy := mps.NewHttpProxy() 44 | log.Fatal(http.ListenAndServe(":8080", proxy)) 45 | } 46 | ``` 47 | 48 | More [examples](https://github.com/telanflow/mps/tree/master/_examples) 49 | 50 | ## 🧬 Middleware 51 | Middleware can intercept requests and responses.
52 | we have several middleware implementations built in, including [BasicAuth](https://github.com/telanflow/mps/tree/master/middleware) 53 | 54 | ```go 55 | func main() { 56 | proxy := mps.NewHttpProxy() 57 | 58 | proxy.Use(mps.MiddlewareFunc(func(req *http.Request, ctx *mps.Context) (*http.Response, error) { 59 | log.Printf("[INFO] middleware -- %s %s", req.Method, req.URL) 60 | return ctx.Next(req) 61 | })) 62 | 63 | proxy.UseFunc(func(req *http.Request, ctx *mps.Context) (*http.Response, error) { 64 | log.Printf("[INFO] middleware -- %s %s", req.Method, req.URL) 65 | resp, err := ctx.Next(req) 66 | if err != nil { 67 | return nil, err 68 | } 69 | log.Printf("[INFO] resp -- %d", resp.StatusCode) 70 | return resp, err 71 | }) 72 | 73 | log.Fatal(http.ListenAndServe(":8080", proxy)) 74 | } 75 | ``` 76 | 77 | ## ♻️ Filters 78 | Filters can filter requests and responses for unified processing. 79 | It is based on middleware implementation. 80 | 81 | ```go 82 | func main() { 83 | proxy := mps.NewHttpProxy() 84 | 85 | // request Filter Group 86 | reqGroup := proxy.OnRequest(mps.FilterHostMatches(regexp.MustCompile("^.*$"))) 87 | reqGroup.DoFunc(func(req *http.Request, ctx *mps.Context) (*http.Request, *http.Response) { 88 | log.Printf("[INFO] req -- %s %s", req.Method, req.URL) 89 | return req, nil 90 | }) 91 | 92 | // response Filter Group 93 | respGroup := proxy.OnResponse() 94 | respGroup.DoFunc(func(resp *http.Response, err error, ctx *mps.Context) (*http.Response, error) { 95 | if err != nil { 96 | log.Printf("[ERRO] resp -- %s %v", ctx.Request.Method, err) 97 | return nil, err 98 | } 99 | 100 | log.Printf("[INFO] resp -- %d", resp.StatusCode) 101 | return resp, err 102 | }) 103 | 104 | log.Fatal(http.ListenAndServe(":8080", proxy)) 105 | } 106 | ``` 107 | 108 | ## 📄 License 109 | Source code in `MPS` is available under the [BSD 3 License](/LICENSE). 
110 | -------------------------------------------------------------------------------- /README_ZH.md: -------------------------------------------------------------------------------- 1 |

2 |
MPS
3 |

4 | 5 | [English](README.md) | 🇨🇳中文 6 | 7 | ## 📖 介绍 8 | ![MPS](https://github.com/telanflow/mps/workflows/MPS/badge.svg) 9 | ![stars](https://img.shields.io/github/stars/telanflow/mps) 10 | ![GitHub release (latest SemVer)](https://img.shields.io/github/v/release/telanflow/mps) 11 | ![GitHub go.mod Go version](https://img.shields.io/github/go-mod/go-version/telanflow/mps) 12 | [![license](https://img.shields.io/github/license/telanflow/mps)](https://github.com/telanflow/mps/blob/master/LICENSE) 13 | 14 | MPS 是一个高性能的中间代理扩展库,支持 HTTP、HTTPS、WebSocket、正向代理、反向代理、隧道代理、中间人代理 等代理方式。 15 | 16 | ## 🚀 特性 17 | - [X] Http代理 18 | - [X] Https代理 19 | - [X] 正向代理 20 | - [X] 反向代理 21 | - [X] 隧道代理 22 | - [X] 中间人代理 (MITM) 23 | - [X] WebSocket代理 24 | 25 | ## 🧰 安装 26 | ``` 27 | go get -u github.com/telanflow/mps 28 | ``` 29 | 30 | ## 🛠 如何使用 31 | 一个简单的HTTP代理服务 32 | 33 | ```go 34 | package main 35 | 36 | import ( 37 | "github.com/telanflow/mps" 38 | "log" 39 | "net/http" 40 | ) 41 | 42 | func main() { 43 | proxy := mps.NewHttpProxy() 44 | log.Fatal(http.ListenAndServe(":8080", proxy)) 45 | } 46 | ``` 47 | 48 | 更多 [范例](https://github.com/telanflow/mps/tree/master/_examples) 49 | 50 | ## 🧬 中间件 51 | 中间件可以拦截请求和响应,我们内置实现了多个中间件,包括 [BasicAuth](https://github.com/telanflow/mps/tree/master/middleware) 52 | 53 | ```go 54 | func main() { 55 | proxy := mps.NewHttpProxy() 56 | 57 | proxy.Use(mps.MiddlewareFunc(func(req *http.Request, ctx *mps.Context) (*http.Response, error) { 58 | log.Printf("[INFO] middleware -- %s %s", req.Method, req.URL) 59 | return ctx.Next(req) 60 | })) 61 | 62 | proxy.UseFunc(func(req *http.Request, ctx *mps.Context) (*http.Response, error) { 63 | log.Printf("[INFO] middleware -- %s %s", req.Method, req.URL) 64 | resp, err := ctx.Next(req) 65 | if err != nil { 66 | return nil, err 67 | } 68 | log.Printf("[INFO] resp -- %d", resp.StatusCode) 69 | return resp, err 70 | }) 71 | 72 | log.Fatal(http.ListenAndServe(":8080", proxy)) 73 | } 74 | ``` 75 | 76 | ## ♻️ 过滤器 77 | 过滤器可以对请求和响应进行筛选,统一进行处理。 78
| 它基于中间件实现。 79 | 80 | ```go 81 | func main() { 82 | proxy := mps.NewHttpProxy() 83 | 84 | // request Filter Group 85 | reqGroup := proxy.OnRequest(mps.FilterHostMatches(regexp.MustCompile("^.*$"))) 86 | reqGroup.DoFunc(func(req *http.Request, ctx *mps.Context) (*http.Request, *http.Response) { 87 | log.Printf("[INFO] req -- %s %s", req.Method, req.URL) 88 | return req, nil 89 | }) 90 | 91 | // response Filter Group 92 | respGroup := proxy.OnResponse() 93 | respGroup.DoFunc(func(resp *http.Response, err error, ctx *mps.Context) (*http.Response, error) { 94 | if err != nil { 95 | log.Printf("[ERRO] resp -- %s %v", ctx.Request.Method, err) 96 | return nil, err 97 | } 98 | 99 | log.Printf("[INFO] resp -- %d", resp.StatusCode) 100 | return resp, err 101 | }) 102 | 103 | log.Fatal(http.ListenAndServe(":8080", proxy)) 104 | } 105 | ``` 106 | 107 | ## 📄 开源许可 108 | `MPS`中的源代码在[BSD 3 License](/LICENSE)下可用。 109 | -------------------------------------------------------------------------------- /_examples/basic-auth/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "io" 5 | "log" 6 | "net/http" 7 | "net/url" 8 | "time" 9 | 10 | "github.com/telanflow/mps" 11 | "github.com/telanflow/mps/middleware" 12 | ) 13 | 14 | // A simple BasicAuth example: client -> proxy (localhost:8081, proxy credentials mps/mps) -> endpoint (localhost:8080, endpoint credentials test/test) 15 | func main() { 16 | // endPoint server 17 | go http.ListenAndServe("localhost:8080", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 18 | // Basic Authentication. NOTE(review): RFC 7235 expects a WWW-Authenticate header on 401 responses; none is set below 19 | usr, pwd, ok := r.BasicAuth() 20 | if !ok { 21 | w.WriteHeader(401) 22 | w.Write([]byte("401 Authentication Required")) 23 | return 24 | } 25 | if usr != "test" || pwd != "test" { 26 | w.WriteHeader(401) 27 | w.Write([]byte("401 Authentication Required")) 28 | return 29 | } 30 | w.Write([]byte("successful endPoint")) 31 | })) 32 | 33 | // proxy server 34 | proxy := mps.NewHttpProxy() 35 | // proxy BasicAuth 36 | proxy.Use(middleware.BasicAuth("mps_realm", func(username, password
string) bool { 37 | return username == "mps" && password == "mps" 38 | })) 39 | proxy.UseFunc(func(req *http.Request, ctx *mps.Context) (*http.Response, error) { 40 | // set endPoint BasicAuth 41 | // Or you can set the endPoint BasicAuth on the client 42 | req.SetBasicAuth("test", "test") 43 | return ctx.Next(req) 44 | }) 45 | go http.ListenAndServe("localhost:8081", proxy) 46 | 47 | // wait proxy started (a fixed sleep, not a readiness check) 48 | time.Sleep(2 * time.Second) 49 | 50 | // send request 51 | // request ==> proxy ==> http://localhost:8080 52 | // response <== proxy <== http://localhost:8080 53 | request, _ := http.NewRequest(http.MethodGet, "http://localhost:8080", nil) // error ignored: the URL is a constant literal 54 | http.DefaultClient.Transport = &http.Transport{ 55 | Proxy: func(req *http.Request) (*url.URL, error) { 56 | // set proxy server BasicAuth 57 | middleware.SetBasicAuth(req, "mps", "mps") 58 | 59 | // set endPoint BasicAuth 60 | // Or you can set the endPoint to BasicAuth on the proxy server 61 | //req.SetBasicAuth("test", "test") 62 | 63 | return url.Parse("http://localhost:8081") 64 | }, 65 | } 66 | resp, err := http.DefaultClient.Do(request) 67 | if err != nil { 68 | log.Fatal(err) 69 | } 70 | defer resp.Body.Close() 71 | 72 | body, _ := io.ReadAll(resp.Body) // NOTE(review): read error is ignored 73 | log.Println(string(body)) 74 | } 75 | -------------------------------------------------------------------------------- /_examples/cascade-proxy/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "io" 5 | "log" 6 | "net/http" 7 | "net/url" 8 | "time" 9 | 10 | "github.com/telanflow/mps" 11 | "github.com/telanflow/mps/middleware" 12 | ) 13 | 14 | // A simple example of cascading proxy.
15 | // It implements BasicAuth on each hop: client -> proxy2 (localhost:9992, foo_2/bar_2) -> proxy1 (localhost:9991, foo_1/bar_1) -> endpoint (localhost:9990) 16 | func main() { 17 | // endPoint server 18 | go http.ListenAndServe("localhost:9990", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 19 | _, _ = w.Write([]byte("successful endPoint server")) 20 | })) 21 | 22 | // proxy server 1 23 | proxy1 := mps.NewHttpProxy() 24 | proxy1.Ctx.KeepProxyHeaders = true 25 | proxy1.Use(middleware.BasicAuth("mps_realm_1", func(username, password string) bool { 26 | return username == "foo_1" && password == "bar_1" 27 | })) 28 | go http.ListenAndServe("localhost:9991", proxy1) 29 | 30 | // proxy server 2 31 | proxy2 := mps.NewHttpProxy() 32 | proxy2.Ctx.KeepProxyHeaders = true 33 | proxy2.Use(middleware.BasicAuth("mps_realm_2", func(username, password string) bool { 34 | return username == "foo_2" && password == "bar_2" 35 | })) 36 | proxy2.Transport().Proxy = func(req *http.Request) (*url.URL, error) { // proxy2 sends upstream traffic through proxy1 and authenticates to it as foo_1 37 | middleware.SetBasicAuth(req, "foo_1", "bar_1") 38 | return url.Parse("http://localhost:9991") 39 | } 40 | go http.ListenAndServe("localhost:9992", proxy2) 41 | 42 | // wait proxy server started (a fixed sleep, not a readiness check) 43 | time.Sleep(2 * time.Second) 44 | 45 | // send request 46 | // request ==> proxy2 ==> proxy1 ==> http://localhost:9990 47 | // response <== proxy2 <== proxy1 <== http://localhost:9990 48 | req, _ := http.NewRequest(http.MethodGet, "http://localhost:9990/", nil) // error ignored: the URL is a constant literal 49 | http.DefaultClient.Transport = &http.Transport{ 50 | Proxy: func(r *http.Request) (*url.URL, error) { 51 | middleware.SetBasicAuth(r, "foo_2", "bar_2") 52 | return url.Parse("http://localhost:9992") 53 | }, 54 | } 55 | resp, err := http.DefaultClient.Do(req) 56 | if err != nil { 57 | log.Fatal(err) 58 | } 59 | 60 | body, _ := io.ReadAll(resp.Body) // NOTE(review): read error is ignored 61 | resp.Body.Close() 62 | 63 | log.Println(resp.Header) 64 | log.Println(string(body)) 65 | } 66 | -------------------------------------------------------------------------------- /_examples/generateCert/openssl-gen.sh:
-------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -ex # -e: abort on the first failing command; -x: echo each command as it runs 3 | # generate CA's key 4 | openssl genrsa -aes256 -passout pass:1 -out ca.key 4096 # 4096-bit RSA key, temporarily encrypted with the passphrase "1" 5 | openssl rsa -passin pass:1 -in ca.key -out ca.key.tmp # re-write the key with no output cipher option, i.e. without the passphrase 6 | mv ca.key.tmp ca.key # replace the passphrase-protected key with the re-written one 7 | # create a self-signed (-x509) CA certificate valid for 7300 days (~20 years), using the [ v3_ca ] extensions from openssl.cnf 8 | openssl req -config openssl.cnf -key ca.key -new -x509 -days 7300 -sha256 -extensions v3_ca -out ca.crt 9 | -------------------------------------------------------------------------------- /_examples/generateCert/openssl.cnf: -------------------------------------------------------------------------------- 1 | [ ca ] 2 | default_ca = CA_default 3 | [ CA_default ] 4 | default_md = sha256 5 | [ v3_ca ] 6 | subjectKeyIdentifier=hash 7 | authorityKeyIdentifier=keyid:always,issuer 8 | basicConstraints = critical,CA:true 9 | [ req ] 10 | distinguished_name = req_distinguished_name 11 | [ req_distinguished_name ] 12 | countryName = Country Name (2 letter code) 13 | countryName_default = CN 14 | countryName_min = 2 15 | countryName_max = 2 16 | 17 | stateOrProvinceName = State or Province Name (full name) 18 | stateOrProvinceName_default = ZheJiang 19 | 20 | localityName = Locality Name (eg, city) 21 | localityName_default = HangZhou 22 | 23 | 0.organizationName = Organization Name (eg, company) 24 | 0.organizationName_default = mps 25 | 26 | # we can do this but it is not needed normally :-) 27 | #1.organizationName = Second Organization Name (eg, company) 28 | #1.organizationName_default = World Wide Web Pty Ltd 29 | 30 | organizationalUnitName = Organizational Unit Name (eg, section) 31 | organizationalUnitName_default = mps 32 | 33 | commonName = Common Name (e.g.
server FQDN or YOUR name) 34 | commonName_default = mps.github.io 35 | commonName_max = 64 36 | 37 | emailAddress = Email Address 38 | emailAddress_default = telanflow@gmail.com 39 | emailAddress_max = 64 40 | -------------------------------------------------------------------------------- /_examples/mitm-proxy/README.md: -------------------------------------------------------------------------------- 1 | ## Mitm Proxy 2 | 3 | This example implements an HTTPS man-in-the-middle (MITM) proxy. 4 | 5 | You can go to the `_examples/generateCert` directory to regenerate the certificate files. 6 | 7 | ## Steps 8 | 9 | 1. Go to `_examples/generateCert` and run `openssl-gen.sh` to generate the certificate files. 10 | 11 | 2. Import the `ca.crt` certificate file into your system's trusted root certificates. 12 | 13 | If you want to use the Go client to make HTTPS requests, you need to configure the certificate, 14 | for example: 15 | ```go 16 | func main() { 17 | // Load ca.crt file 18 | certPEMBlock, err := os.ReadFile("ca.crt") 19 | if err != nil { 20 | panic("failed to load ca.crt file") 21 | } 22 | 23 | // client cert pool 24 | clientCertPool := x509.NewCertPool() 25 | ok := clientCertPool.AppendCertsFromPEM(certPEMBlock) 26 | if !ok { 27 | panic("failed to parse root certificate") 28 | } 29 | 30 | // set Transport 31 | http.DefaultClient.Transport = &http.Transport{ 32 | Proxy: func(r *http.Request) (*url.URL, error) { 33 | // mitm proxy server address. eg.
"http://localhost:8080" 34 | return url.Parse("http://localhost:8080") 35 | }, 36 | TLSClientConfig: &tls.Config{ 37 | Certificates: []tls.Certificate{cert.DefaultCertificate}, 38 | ClientAuth: tls.RequireAndVerifyClientCert, 39 | RootCAs: clientCertPool, 40 | }, 41 | } 42 | 43 | // To send request 44 | req, _ := http.NewRequest(http.MethodGet, "https://example.com", nil) 45 | resp, err := http.DefaultClient.Do(req) 46 | if err != nil { 47 | t.Fatal(err) 48 | } 49 | defer resp.Body.Close() 50 | } 51 | ``` 52 | 53 | 54 | -------------------------------------------------------------------------------- /_examples/mitm-proxy/ca.crt: -------------------------------------------------------------------------------- 1 | -----BEGIN CERTIFICATE----- 2 | MIIF7jCCA9agAwIBAgIJAJ+vhP71rZgIMA0GCSqGSIb3DQEBCwUAMIGLMQswCQYD 3 | VQQGEwJDTjERMA8GA1UECAwIWmhlSmlhbmcxETAPBgNVBAcMCEhhbmdaaG91MQww 4 | CgYDVQQKDANtcHMxDDAKBgNVBAsMA21wczEWMBQGA1UEAwwNbXBzLmdpdGh1Yi5p 5 | bzEiMCAGCSqGSIb3DQEJARYTdGVsYW5mbG93QGdtYWlsLmNvbTAeFw0yMDA4MTkw 6 | NjM0NDBaFw00MDA4MTQwNjM0NDBaMIGLMQswCQYDVQQGEwJDTjERMA8GA1UECAwI 7 | WmhlSmlhbmcxETAPBgNVBAcMCEhhbmdaaG91MQwwCgYDVQQKDANtcHMxDDAKBgNV 8 | BAsMA21wczEWMBQGA1UEAwwNbXBzLmdpdGh1Yi5pbzEiMCAGCSqGSIb3DQEJARYT 9 | dGVsYW5mbG93QGdtYWlsLmNvbTCCAiIwDQYJKoZIhvcNAQEBBQADggIPADCCAgoC 10 | ggIBALuMT+elSDg9C+1ZKFe2I+1DcDShxY7xzowwxYx59x9EEhdlCTurWf32za7s 11 | V7yzhhzsG9w65uJZHXGkmjzl5iDBepL4PMoifEY9gs3W7xsKPJqtQ0c1wvT1yv8H 12 | geSsnmMHmsz4RG+YakyluvuAIwU0QTbpGeQWmP0Hy6aMD9yP3b3p37I+3Ok8QIcC 13 | 0fExLx+9t/EL8PHKcjX0g813cdmyKel9QdGTmjjQ8QSkWRfBPhENhyvazu87X94f 14 | Ru9Gvd1scrJsY0ty6QRGiCoxtaSc7POjRg1PiC3PapiMjr1Sx3Q01sdb6sQwZ03/ 15 | OlWAD8T/kIoJo+C+hRdlYjlaumYiUPbNtVksYldVQZ1Bva/oTcLCYV30GP3jlw26 16 | UC2PXmT4yBJQDdEPrj0LqcTrlGslIdE8KQA/oBetpwdqGTV05oqKI1E3M8CY/QEV 17 | rW+yJXaVq/+ukRKaGWRwJNcZC4iPOUjHaAjaN1XtR0vxPG4Brw2MFy4p3SeEQ8mJ 18 | cvQLnaLzsEYc6iu/ntGdnYMpFMCLNGm+l/+B7dP+KiuseA67Q5BOanFp5wrEQ7ND 19 | 
8TdCZUe9LzRnNvyZXi5N5wMxQ/tCVR/a27MLN66mxrG0KyzSunCiYqiLyoC4eQJ9 20 | RKjOkHWQ4V4y2NFTqKhgr/Ns1PBiH/4XKQHdRYoWVV4cL+6ZAgMBAAGjUzBRMB0G 21 | A1UdDgQWBBTNarhuZvbH55oZyKX7bEGgP4D9wDAfBgNVHSMEGDAWgBTNarhuZvbH 22 | 55oZyKX7bEGgP4D9wDAPBgNVHRMBAf8EBTADAQH/MA0GCSqGSIb3DQEBCwUAA4IC 23 | AQAEusi8rHwPZwMObn/u/PJQ0K+7Ep/K0nqLGLLpMudj9AVYYV50ZG47lv7ibmQS 24 | LpgfxzJ0fCk5rM6Wg0+KxqLDUP2GotUlc31mIQACVXVbzipx1A06ThBthOUJXaZY 25 | +T6ggKkacLZOG35Af99qOzDIZU9FcunYyaQi8qUCY7VT+Uynu090k7ynvy6S7vut 26 | rGavfDD7oAWEpRyQen8MpUoqhY3PeQJ3VKUp7lcgsiP5f8tLwQAtx7Qvh9li95ml 27 | Kq/SFyA0qPZH6eUrQLvNr4InGwbaoKNAvN99rSHEk2wixuxG2CqsvKW7EQ4IRUcS 28 | H9EBfUEm6EH7xWsyxrlYNzh0k/dYfvy4je1YwkpNAl8BSV8PdFtpW0//FZa04KfG 29 | D9D+wb5MxOmnULl818cYHrg5QIB/uqet21ib85CB/oRvR9JF0D4hfC0gy3IZz+8L 30 | DBhKjf9KNP2631nqtuC7vrShLSFa+02D3bi2srgznIcORtkBHpHgZefAOI7dPJzM 31 | C/ENEorEayK0rdbc116grj7rfGY9yWg/7fD9B0lvrD2DwXTXTd94laDPbpIJIikN 32 | 6EFIJp3IjGKbH+HWUY4TxeRDH8bEq+wjTA0rrIRnI2z03VlX+ZK/UTVdnPxV9Pdd 33 | 3FuvmED63INzIdoY/DgWgUF2DShFbALKfBR3Yev/mkbCqg== 34 | -----END CERTIFICATE----- 35 | -------------------------------------------------------------------------------- /_examples/mitm-proxy/ca.key: -------------------------------------------------------------------------------- 1 | -----BEGIN RSA PRIVATE KEY----- 2 | MIIJKgIBAAKCAgEAu4xP56VIOD0L7VkoV7Yj7UNwNKHFjvHOjDDFjHn3H0QSF2UJ 3 | O6tZ/fbNruxXvLOGHOwb3Drm4lkdcaSaPOXmIMF6kvg8yiJ8Rj2CzdbvGwo8mq1D 4 | RzXC9PXK/weB5KyeYweazPhEb5hqTKW6+4AjBTRBNukZ5BaY/QfLpowP3I/dvenf 5 | sj7c6TxAhwLR8TEvH7238Qvw8cpyNfSDzXdx2bIp6X1B0ZOaONDxBKRZF8E+EQ2H 6 | K9rO7ztf3h9G70a93WxysmxjS3LpBEaIKjG1pJzs86NGDU+ILc9qmIyOvVLHdDTW 7 | x1vqxDBnTf86VYAPxP+Qigmj4L6FF2ViOVq6ZiJQ9s21WSxiV1VBnUG9r+hNwsJh 8 | XfQY/eOXDbpQLY9eZPjIElAN0Q+uPQupxOuUayUh0TwpAD+gF62nB2oZNXTmiooj 9 | UTczwJj9ARWtb7IldpWr/66REpoZZHAk1xkLiI85SMdoCNo3Ve1HS/E8bgGvDYwX 10 | LindJ4RDyYly9AudovOwRhzqK7+e0Z2dgykUwIs0ab6X/4Ht0/4qK6x4DrtDkE5q 11 | cWnnCsRDs0PxN0JlR70vNGc2/JleLk3nAzFD+0JVH9rbsws3rqbGsbQrLNK6cKJi 12 | 
qIvKgLh5An1EqM6QdZDhXjLY0VOoqGCv82zU8GIf/hcpAd1FihZVXhwv7pkCAwEA 13 | AQKCAgBThCgQ/4kpggXNq+ZLKNDW1zEgPum6vfM8ent+EtH5Glb0FAoIiEWK0lzF 14 | iHmJjmgqePnvGEu4f/acpLAKblYMQBxVVjW7zZ+Jp9qXzx6q6+QQ/Rb4nvgyHUJI 15 | Tw+IxVXCw6ArpmLTTwwHFcYuOOFfb+WajjL5XxbBlrcZc0Wc8nPMHll/Bn9ZXXte 16 | o+LZhQ13FQTUUnz5Ly2s2TXYSVhpmO0RDLZCnXgP1Pt/FbCW43bAIUYQQV/lKIuI 17 | XmU4KEhkUebBjYKqFoGtZbs9DuXUaA0ccZjAVKpPvA274Nuvcy1ekikSndvtgaB/ 18 | Gyje6igbkbLLxX80laKuyHb1E3HtRZ22YI9t3GIGaKWYzU43nOnAiTp0o40EPnqt 19 | lF4ZEsN7eh0gGEkPnsw24Os5QyJJw4qeXfeVGU9Ig/kFOdf6n+8t2031imcFyScW 20 | BBb9gwj/hef3FFG/lucwoMXSDcIkKdITCr3aH1VxIxulAPQzjOK6rfnGR2kFnKoP 21 | RQcnQS3+mM7HqWppgy18St+MIiW95Cd+6WecUqAM25s9VOnIW2BbzHzKTfaNC4kR 22 | Ou5ruK5FYA36Da/gC0pEEuloV6bP0KVr1IutHNB656U6gl+g5ir0/rdCsrPqJRHP 23 | wHf5Vflb2lrIVFb/jIXqi0GsC7rbcvpkTD6zDcMOzXE69j+IAQKCAQEA7Mo3JO3c 24 | 8Q4S6sCIP8jYTkt//E4VADtNN+IgcSJEtPKQkcP8CvRlt4wkp0aaXQoV99aNzVHZ 25 | j22xOrJENZEjV7xB51R0RDWRe5sVWmVlXOVmePS6ta8WIwQG9RqNlTyEOyYpxhK7 26 | XJh0v29/CxF0UsCQYjYAxwDgE4fY11vIN/q7+YluTlA3Rwei4EvagoFftAGD4jRB 27 | 9ECnH2GiuCS6Mi5PQ90jj9XMCm5WcCXwFG8OKJWf/8uKClruL/yes3+G7lic9mof 28 | 8TbyBrD6bk8OHjp0X42ReCBNPLuDH3G/TWRDw8LtHc6LDLW0K16cm7+VMYF8x9yP 29 | 0BPuO3NixUtfEwKCAQEAysNqV1Ml8+Q+JaCCCOmSK65ad4XUepaarNbJOVuLywNI 30 | ok8VzSjJ3y/wsq9JlrhBMir3SZhjK/KqNb/o7zkDlp9/AAmppdR/iMdC7okmLMWn 31 | j3UM3UvqtH9nhBs+rSzlEMq6Qt6F77rOKB3nmHkcexZIPsC4Ct3ZWvuOCXkHQytT 32 | RLN4b9iqhvw9Zy4Wq8CSKEY9SqXvkaNa8YTyoyPqYozpEaJwAMv0mjNtEqyfic3S 33 | KBRMTTo0SQcLpn/Ee2OL+2EIjy9QgDMOGRDs0fs/xuwZFqeao3dr90jS2vL3atVN 34 | IxiGAI4Hplb2TaGyNTu/1yoXJXKgiel/o91qtmE1IwKCAQEA3m2No1j1JFLuHipB 35 | UnleBx4Q2XaXb6JFBOubQerI05jPiL2q8rdlHSe9/ovp0N/6hta6WVY7oemOg+6U 36 | +CSgKHglCCJjHPec85lYU5PPxZWPzqtFAAm6J6ZOysromHlCVTWiI/fQnEhx0qnv 37 | kvwQYvOULU1BKa5+zpnbbWFAEKWtEdixD0t2wXhA3aUjW1ggCD0sH76q/cAFvQrA 38 | CW4moaCywLLoBuL0ShAfjjV08hzoFeOHaodN4jBMcjNA+KggnaALwcUqwDG24+Y3 39 | OIt2XZrXWjLnpQniw9v4bf8xjodSyH9AsbElGQlOdzbmsb8jbF+QUUW0qecu8BWR 40 | 
gHculQKCAQEAmfhGekVTnp6FasE1vVrQeocNf5GKxgQzNGhtqTaRMvotX8M6RO5i 41 | TS70Ulu1P9Ru/Y+O9L3ZIPhGtEYktfPPe8NmBztPLfPtXIojk0tmR71X/iHeQPVz 42 | JtlQXArsT0i2MUggpMKhZmeuQNxkj234aKeE+NITb30DnolDVIIpN6Jgutyl6hjX 43 | dWV5oy5mXMoAssCTrmnPQAKR/rD8J1IQnAFwwslcz94Qwj+m5fVbuKMooPK49jPq 44 | nEHTYP3I0AHJvHv0qfY95PvgCrzFeLaXuZBzhLaFQPhgbglIxKaXpvKOfsYSi71O 45 | pcuHgW/2CWJzzQnTRcaDjfZXzLFIZXHvjQKCAQEAwnl07PTmCWFuZGGNQPM8DO5q 46 | s5TFy2c48vOlHSlqnatJoXc5rqga7GpI1fymRBDAUjRPYr+NdyXHQHaqrzKK54Dj 47 | Ca/nDJrn5830GKqpqEG2m2P1gbpmuFb9H63ugNCnC68ktu1S8672hC+/hgz3TNxB 48 | S1BNQ6PywAy8UfxJQa4S/l8QjKuOYl87/d7Ud4/UfRa4nVSA8/XF5fg/Cv21mATp 49 | iluyD/oMrP+uTGtWJlOhOgj8lH1alHdY8PogQ6ZlXJqL2JkgW0/25WjLXIZS88DQ 50 | CQp28IBolmUdAChYTq0V6NbCinqviE/mDMNpp09qCxMXrGabBB4j6Xm35vlivg== 51 | -----END RSA PRIVATE KEY----- 52 | -------------------------------------------------------------------------------- /_examples/mitm-proxy/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "errors" 5 | "github.com/telanflow/mps" 6 | "log" 7 | "net/http" 8 | "os" 9 | "os/signal" 10 | "regexp" 11 | "syscall" 12 | ) 13 | 14 | // A simple mitm proxy server. Run it from the repository root so the relative CA paths below resolve. 15 | func main() { 16 | quitSignChan := make(chan os.Signal, 1) // buffered: signal.Notify never blocks, so an unbuffered channel can miss signals 17 | 18 | // create proxy server 19 | proxy := mps.NewHttpProxy() 20 | 21 | // Load cert file 22 | // The Connect request is processed using MitmHandler 23 | mitmHandler, err := mps.NewMitmHandlerWithCertFile(proxy.Ctx, "./_examples/mitm-proxy/ca.crt", "./_examples/mitm-proxy/ca.key") 24 | if err != nil { 25 | log.Panic(err) 26 | } 27 | proxy.HandleConnect = mitmHandler 28 | 29 | // Middleware 30 | proxy.UseFunc(func(req *http.Request, ctx *mps.Context) (*http.Response, error) { 31 | log.Printf("[INFO] middleware -- %s %s", req.Method, req.URL) 32 | return ctx.Next(req) 33 | }) 34 | 35 | // Filter 36 | reqGroup := proxy.OnRequest(mps.FilterHostMatches(regexp.MustCompile("^.*$"))) 37 | reqGroup.DoFunc(func(req *http.Request, ctx
*mps.Context) (*http.Request, *http.Response) { 38 | log.Printf("[INFO] req -- %s %s", req.Method, req.URL) 39 | return req, nil 40 | }) 41 | respGroup := proxy.OnResponse() 42 | respGroup.DoFunc(func(resp *http.Response, err error, ctx *mps.Context) (*http.Response, error) { 43 | if err != nil { 44 | log.Printf("[ERRO] resp -- %s %v", ctx.Request.Method, err) 45 | return resp, err 46 | } 47 | log.Printf("[INFO] resp -- %d", resp.StatusCode) 48 | return resp, err 49 | }) 50 | 51 | // Started proxy server 52 | srv := http.Server{ 53 | Addr: "localhost:8080", 54 | Handler: proxy, 55 | } 56 | go func() { 57 | log.Printf("MitmProxy started listen: http://%s", srv.Addr) 58 | err := srv.ListenAndServe() 59 | if errors.Is(err, http.ErrServerClosed) { 60 | return 61 | } 62 | if err != nil { 63 | quitSignChan <- syscall.SIGKILL 64 | log.Fatalf("MitmProxy start fail: %v", err) 65 | } 66 | }() 67 | 68 | // quit signal 69 | signal.Notify(quitSignChan, syscall.SIGINT, syscall.SIGTERM, syscall.SIGQUIT) // SIGKILL cannot be caught, so it is not registered 70 | 71 | <-quitSignChan 72 | _ = srv.Close() 73 | log.Println("MitmProxy server stop!") // normal shutdown: log.Fatal would exit with status 1 74 | } 75 | -------------------------------------------------------------------------------- /_examples/reverse-proxy/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "errors" 5 | "log" 6 | "net/http" 7 | "net/url" 8 | "os" 9 | "os/signal" 10 | "syscall" 11 | 12 | "github.com/telanflow/mps" 13 | "github.com/telanflow/mps/middleware" 14 | ) 15 | 16 | // A simple reverse proxy server 17 | func main() { 18 | targetURL, _ := url.Parse("https://www.google.com") 19 | quitSignChan := make(chan os.Signal, 1) // buffered: signal.Notify never blocks, so an unbuffered channel can miss signals 20 | 21 | // reverse proxy server 22 | proxy := mps.NewReverseHandler() 23 | proxy.UseFunc(middleware.SingleHostReverseProxy(targetURL)) 24 | 25 | reqGroup := proxy.OnRequest() 26 | reqGroup.DoFunc(func(req *http.Request, ctx *mps.Context) (*http.Request, *http.Response) { 27 | log.Printf("[INFO] req -- %s
%s", req.Method, req.Host) 28 | return req, nil 29 | }) 30 | 31 | respGroup := proxy.OnResponse() 32 | respGroup.DoFunc(func(resp *http.Response, err error, ctx *mps.Context) (*http.Response, error) { 33 | if err != nil { 34 | log.Printf("[ERRO] resp -- %s %v", ctx.Request.Method, err) 35 | return nil, err 36 | } 37 | log.Printf("[INFO] resp -- %d", resp.StatusCode) 38 | return resp, err 39 | }) 40 | 41 | // started proxy server 42 | srv := http.Server{ 43 | Addr: "localhost:8080", 44 | Handler: proxy, 45 | } 46 | go func() { 47 | log.Printf("ReverseProxy started listen: http://%s", srv.Addr) 48 | err := srv.ListenAndServe() 49 | if errors.Is(err, http.ErrServerClosed) { 50 | return 51 | } 52 | if err != nil { 53 | quitSignChan <- syscall.SIGKILL 54 | log.Fatalf("ReverseProxy start fail: %v", err) 55 | } 56 | }() 57 | 58 | // quit signal 59 | signal.Notify(quitSignChan, syscall.SIGINT, syscall.SIGTERM, syscall.SIGQUIT) // SIGKILL cannot be caught, so it is not registered 60 | 61 | <-quitSignChan 62 | _ = srv.Close() 63 | log.Println("ReverseProxy server stop!") // normal shutdown: log.Fatal would exit with status 1 64 | } 65 | -------------------------------------------------------------------------------- /_examples/simple-http-proxy/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "errors" 5 | "log" 6 | "net/http" 7 | "os" 8 | "os/signal" 9 | "regexp" 10 | "syscall" 11 | 12 | "github.com/telanflow/mps" 13 | ) 14 | 15 | // A simple http proxy server 16 | func main() { 17 | quitSignChan := make(chan os.Signal, 1) // buffered: signal.Notify never blocks, so an unbuffered channel can miss signals 18 | 19 | // create a http proxy server 20 | proxy := mps.NewHttpProxy() 21 | proxy.UseFunc(func(req *http.Request, ctx *mps.Context) (*http.Response, error) { 22 | log.Printf("[INFO] middleware -- %s %s", req.Method, req.URL) 23 | return ctx.Next(req) 24 | }) 25 | 26 | // Filter Request 27 | reqGroup := proxy.OnRequest(mps.FilterHostMatches(regexp.MustCompile("^.*$"))) 28 | reqGroup.DoFunc(func(req *http.Request, ctx *mps.Context) (*http.Request, *http.Response) { 29 |
log.Printf("[INFO] req -- %s %s", req.Method, req.URL) 30 | return req, nil 31 | }) 32 | 33 | // Filter Response 34 | respGroup := proxy.OnResponse() 35 | respGroup.DoFunc(func(resp *http.Response, err error, ctx *mps.Context) (*http.Response, error) { 36 | if err != nil { 37 | log.Printf("[ERRO] resp -- %s %v", ctx.Request.Method, err) 38 | return resp, err 39 | } 40 | 41 | log.Printf("[INFO] resp -- %d", resp.StatusCode) 42 | return resp, err 43 | }) 44 | 45 | // Start server 46 | srv := &http.Server{ 47 | Addr: "localhost:8080", 48 | Handler: proxy, 49 | } 50 | go func() { 51 | log.Printf("HttpProxy started listen: http://%s", srv.Addr) 52 | err := srv.ListenAndServe() 53 | if errors.Is(err, http.ErrServerClosed) { 54 | return 55 | } 56 | if err != nil { 57 | quitSignChan <- syscall.SIGKILL 58 | log.Fatalf("HttpProxy start fail: %v", err) 59 | } 60 | }() 61 | 62 | // quit signal 63 | signal.Notify(quitSignChan, syscall.SIGINT, syscall.SIGTERM, syscall.SIGQUIT) // SIGKILL cannot be caught, so it is not registered 64 | 65 | <-quitSignChan 66 | _ = srv.Close() 67 | log.Println("HttpProxy server stop!") // normal shutdown: log.Fatal would exit with status 1 68 | } 69 | -------------------------------------------------------------------------------- /_examples/websocket-proxy/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "errors" 5 | "log" 6 | "net/http" 7 | "net/url" 8 | "os" 9 | "os/signal" 10 | "syscall" 11 | 12 | "github.com/gorilla/websocket" 13 | "github.com/telanflow/mps" 14 | ) 15 | 16 | var ( 17 | upgrader = websocket.Upgrader{} 18 | endPointAddr = "localhost:9990" 19 | ) 20 | 21 | // run a endPoint websocket server 22 | func runWebsocketServer() { 23 | http.ListenAndServe(endPointAddr, http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { 24 | c, err := upgrader.Upgrade(rw, req, nil) 25 | if err != nil { 26 | return 27 | } 28 | defer c.Close() 29 | for { 30 | mt, message, err := c.ReadMessage() 31 | if err != nil { 32 | break 33 | } 34 | err =
c.WriteMessage(mt, message) 35 | if err != nil { 36 | break 37 | } 38 | } 39 | })) 40 | } 41 | 42 | // A simple proxy websocket server 43 | func main() { 44 | // quit signal 45 | quitSignChan := make(chan os.Signal, 1) // buffered: signal.Notify never blocks, so an unbuffered channel can miss signals 46 | signal.Notify(quitSignChan, syscall.SIGINT, syscall.SIGTERM, syscall.SIGQUIT) // SIGKILL cannot be caught, so it is not registered 47 | 48 | // start endPoint websocket server 49 | go runWebsocketServer() 50 | 51 | // start proxy websocket server 52 | websocketHandler := mps.NewWebsocketHandler() 53 | websocketHandler.Transport().Proxy = func(request *http.Request) (*url.URL, error) { 54 | // endPoint websocket server 55 | return url.Parse("ws://" + endPointAddr) 56 | } 57 | srv := &http.Server{ 58 | Addr: "localhost:8080", 59 | Handler: websocketHandler, 60 | } 61 | go func() { 62 | log.Printf("WebsocketProxy started listen: ws://%s", srv.Addr) 63 | err := srv.ListenAndServe() 64 | if errors.Is(err, http.ErrServerClosed) { 65 | return 66 | } 67 | if err != nil { 68 | quitSignChan <- syscall.SIGKILL 69 | log.Fatalf("WebsocketProxy start fail: %v", err) 70 | } 71 | }() 72 | 73 | <-quitSignChan 74 | _ = srv.Close() 75 | log.Println("WebsocketProxy server stop!") // normal shutdown: log.Fatal would exit with status 1 76 | } 77 | -------------------------------------------------------------------------------- /cert/cert.go: -------------------------------------------------------------------------------- 1 | package cert 2 | 3 | import "crypto/tls" 4 | 5 | const CertPEM = `-----BEGIN CERTIFICATE----- 6 | MIIF7jCCA9agAwIBAgIJAP/+a5pIA2lJMA0GCSqGSIb3DQEBCwUAMIGLMQswCQYD 7 | VQQGEwJDTjERMA8GA1UECAwIWmhlSmlhbmcxETAPBgNVBAcMCEhhbmdaaG91MQww 8 | CgYDVQQKDANtcHMxDDAKBgNVBAsMA21wczEWMBQGA1UEAwwNbXBzLmdpdGh1Yi5p 9 | bzEiMCAGCSqGSIb3DQEJARYTdGVsYW5mbG93QGdtYWlsLmNvbTAeFw0yMDA4MDYx 10 | MTE4MThaFw00MDA4MDExMTE4MThaMIGLMQswCQYDVQQGEwJDTjERMA8GA1UECAwI 11 | WmhlSmlhbmcxETAPBgNVBAcMCEhhbmdaaG91MQwwCgYDVQQKDANtcHMxDDAKBgNV 12 | BAsMA21wczEWMBQGA1UEAwwNbXBzLmdpdGh1Yi5pbzEiMCAGCSqGSIb3DQEJARYT 13 | 
dGVsYW5mbG93QGdtYWlsLmNvbTCCAiIwDQYJKoZIhvcNAQEBBQADggIPADCCAgoC 14 | ggIBANJZU0vyrS7aROi5+0e6AR4VBulFEjoivLrYaa1Pl1ENHHTgfjjmnLf2+22G 15 | ImMp95RUDYIT2tZ2GhksLJil+fJEvv7HMihsWYYTjGzr5u3kPke0+fB/7dbRYJ+h 16 | FvlsLEkItYPT9iBHryStu5CRV3P1VNtR9/7FF8YdX3kOqMQASnHQhBYNZ7av2OuR 17 | 3pDPLD0PKccqMeTXW+yMsB+z0L03RQQG3LOmi/7nWogvqVrnuwP7JbybOtHEvLO0 18 | rLEoAdXwdCCSAHdBCz2qat/I9CubGlKdUlgVw8eXVWZeYJeVOOQy8f7L9AEPoc5k 19 | uXpEyRPCzpo/T/6KSxi2oxaEI4BSZUtyxRS/Laezdgs+GnKkjO56Z3lMPCvwwLFO 20 | DNdtxA3OgLIvcZSA9zWPgoOSVQ0nCIQl3L3qEJ/TqyUWkcPINhiLNgnVSdu1dQ7q 21 | rFZegmi5RAKAyl0M1rSlmTAB3Q/Mf4BMzPaNUajW7bjx4MbU9LxknVlRUb1vv2Jv 22 | Pd6mUm0vLy6P/zl8/pZRpcnn91omFJ+PgZMoRzUPTBNDrgUEeXNsLzaLHg6t4fLb 23 | xd1QMsg99Upo643Q/Hb8Xfz2ogm82jRURXkiHhQgxPjUvk76N4obNW9noMlZEUpF 24 | /68/WwMc2CrWvWZ1HKWfpJDN6C2hjOqWvVWBng6LssVZdIBnAgMBAAGjUzBRMB0G 25 | A1UdDgQWBBSPklwhHPcnDnP8tNSQ2i+VAs/gvjAfBgNVHSMEGDAWgBSPklwhHPcn 26 | DnP8tNSQ2i+VAs/gvjAPBgNVHRMBAf8EBTADAQH/MA0GCSqGSIb3DQEBCwUAA4IC 27 | AQA6KnQfPV1gS53ZwakZAzE3XEDx+ef1C0iFZx282PWcIwnBPYbkswTt8RJj5806 28 | MiKyBtHSDN3Agde5LP6C2BhCx2GeguUDcTDPY0PGj5/TaES2gRiu6rsKkQJhSUTs 29 | RSPekDT30yCJMJSz/1QnOqGXwToSpyr5rsWxAyYGAAz0fSAwpJ2XuE77vHCk+p3k 30 | zIOjWxrkkLTSxoqIhOjmq8hO5qvwojudUTn0PBcAakND6r5csWLR2i+Am7u36WJc 31 | m8dX8L+SXmWYw85Fs4tLDwUsORFuJclY32g7fKOvrgV2rhHXwwWw8pjpbjs0AAKA 32 | +Gk3QQMT0cF4FpnH8VBdK82/nOtcbNvWz994K18kEzGUzN5Dq7zi+7n99HH4rwjO 33 | 2eyWNl7hsatvtKfcpwHhIt63EHG3owhc+Wf0iG3i6BR1b9jO3XLJmIlgkQ+4pSpV 34 | +9BB2HNfklUOvJdsWwSxdhChavHJokl0rdGRf8weWhqhLkUCC3z2lIIylXgay71W 35 | ++48SxdaMbiqnuEZrt4cMlOt+KAxvtEl1krWLAIi6URHLSvdER7w/Dtkvg+PPb1t 36 | yLiugEvmBIpagsbw3zirMza8Rg1CchRB+0sRGVSE4ppRe5EiIAe4aZUKufIk0TtS 37 | 8yO67j6Lx35sfhqzg/Jl31HOk42M8MpZqAGy13Cyw1kGmg== 38 | -----END CERTIFICATE-----` 39 | 40 | const KeyPEM = `-----BEGIN RSA PRIVATE KEY----- 41 | MIIJKgIBAAKCAgEA0llTS/KtLtpE6Ln7R7oBHhUG6UUSOiK8uthprU+XUQ0cdOB+ 42 | OOact/b7bYYiYyn3lFQNghPa1nYaGSwsmKX58kS+/scyKGxZhhOMbOvm7eQ+R7T5 43 | 
8H/t1tFgn6EW+WwsSQi1g9P2IEevJK27kJFXc/VU21H3/sUXxh1feQ6oxABKcdCE 44 | Fg1ntq/Y65HekM8sPQ8pxyox5Ndb7IywH7PQvTdFBAbcs6aL/udaiC+pWue7A/sl 45 | vJs60cS8s7SssSgB1fB0IJIAd0ELPapq38j0K5saUp1SWBXDx5dVZl5gl5U45DLx 46 | /sv0AQ+hzmS5ekTJE8LOmj9P/opLGLajFoQjgFJlS3LFFL8tp7N2Cz4acqSM7npn 47 | eUw8K/DAsU4M123EDc6Asi9xlID3NY+Cg5JVDScIhCXcveoQn9OrJRaRw8g2GIs2 48 | CdVJ27V1DuqsVl6CaLlEAoDKXQzWtKWZMAHdD8x/gEzM9o1RqNbtuPHgxtT0vGSd 49 | WVFRvW+/Ym893qZSbS8vLo//OXz+llGlyef3WiYUn4+BkyhHNQ9ME0OuBQR5c2wv 50 | NoseDq3h8tvF3VAyyD31SmjrjdD8dvxd/PaiCbzaNFRFeSIeFCDE+NS+Tvo3ihs1 51 | b2egyVkRSkX/rz9bAxzYKta9ZnUcpZ+kkM3oLaGM6pa9VYGeDouyxVl0gGcCAwEA 52 | AQKCAgArbEc2wXUg2+wnwuTtrKc4Z4zSsPCPUcZ2J+DA51JMaBF8yy8jXe/yRikn 53 | Ne55XBuA4k0bki+14BGJKsZWCMVtTuXCwKpJD/z3Iaf2gEheyaRVtzV1gWM+2mBA 54 | 88dDXCJUPVkDSslfZozwXHEA6hAMnxOSZvxz+onq2vtviSgrtgeoMSxjRQco/mog 55 | Ty+L40i1niC4vawpGpAeZ/ifwsYPmY5Ew4niCDqUN3xH6tbiLj48Fyd2JPFihmOS 56 | EXUo6SJf4NCIPLud4q6IX1rKsbg+HDm13kY2at/MnyABDvCPuj1RVncAa2gGpAx6 57 | B+8GH5cG3ks6KmHAIRpZkrJeHo8ZOZOVsfokBDjTEWUn3sQH8RrVXtNnm/W66dAl 58 | m4LnKyBWvyVaOHn65Jq06XTaUrT/9MY2RDmLPehzhcPczcZZn8RQikrarStQuHk2 59 | DOiiCvjSVnh+O13RdCKBMXfG4A482LucFnSSweiuFrXDU87GO++jKaZoYPeXsQul 60 | jlTNFUyr7zHO5gqVf+JzboRG+pwashiFZBCfVqu/h9Abx7BTOfIXi5k5f0r1elZw 61 | hQwJT0WgJKX4MjehojABNi+t4i6xqcsmCB9D68FBONSLHY4qpa1s9rhHEzd6BTJS 62 | Fg4GPFooxQ2lwhjAyz2ZhG6HbF3QNuV6HgoimkcYlLiVowx5qQKCAQEA7XiDMQoA 63 | 4J+ZVlseGPinIVpOuAY8ehcd25I9xLmaqvk0CGeGBvpC125KGnG42m4l133AicOa 64 | +Dz0yU0UudvfvnJEyPd6ojpAYMxX5/MH+85hJ8ARPwQQ+K+TlIPKH/8jXhS8D25D 65 | pcvY4MJftwiuBhFTYlveNnmbH7QCAge/lS0BSnlOMGUI/yBg4DSOt0sAZOblWNlv 66 | 1FKG3aKCdd3atY1VvTyGnqpUqFiuU6ENwSbam77hTuQHjE3rULESj/wHIxpj/2Gr 67 | VsjD/o29h2jjUApseUQqi55TllBj60K/DGuTiKSj8PXmXaBy20MDpef9HEYOVD/j 68 | lsB4aPHqqn6+BQKCAQEA4sMNrA7xAfFRLJphzXd/6k9ISbKTjcp57f4Akib3FqCo 69 | BJPz1F8cQJ5BHZLBJum/jyfbPEd36owr5bn/JlMqXsPgzb3eik/ZoIA/woucFosh 70 | 8MrebpARuSMmNtC0F2VfEDG4G/p7c+/aYWLPJJbte0XmIvsIVXwtlAt6k+HCAVAW 71 | 
PA+MLAelEC0gtHOk1ea9NN2VCfsfpsbw/4GSUlL6Efev5ufXyH2z+tPmZyusPnBS 72 | fAGZ78d000mH3RVsGN3o22Tzv8Hpx62MV8U08TvCESZsjggX2lCpVIsi3GRDR5wj 73 | NFU7LOEy7TljYMES8GIyNc8U9csIgyh2+WLSjcSkewKCAQEAvuFa2uVGlUfUkoSF 74 | ad8dQIL9uZBRtnW0a1Vezy299GaB+6tzIVKyvcYKTL1SsElPo6qSRGp1u8oLnW+X 75 | FFp3u/bP8ZZz/cjDDMvUcT56EV7v22rYsgWLusou33cb1qJYBHy4OdMRD0kO2IOF 76 | OnQApiHxG6Pqt3ECTvZ7krQ1vCxD2GAviFj+ZUzacf3tJcpk07aBbezBpjJ789V3 77 | 9lRRRBQKciUftJQHnpZB8jkH/FVF7WD+bFKA+rd7Sg47dH9KIV5KOPKCLi0M1iWK 78 | zjhyV1k5njQ72qR2XeHanzW0qcAjA/gLS1ntRR7+k96HJSmX2804IWKFhxzI7Npg 79 | HZHpHQKCAQEAiabMGuErDfnWQ9QngJmE7dBY2lvr1EvP/leNMysyHOtDcxv5DLb7 80 | qIIolvIqDBwi65zPKeVcduXGE/r3VuVvN/2B7oLOn3lfa13O1qL3CnxFCy2rHsSX 81 | 7aHXpbjFSdqAfY0g7OL9o+A62Zkok1aHLKi+zgdDBNmPtWnObAzEPxXFmYn6lhPB 82 | 8HLkgoYczrf1rSzBN0DY8t2bGA8oqo6yPMv1XJ7qT0t3QND28TQCqBh5CcvTDUov 83 | sb7WGa/SYbn7i4rZqFLnPg4svm7492NGKDEB/qoNCLqkP60CaXT3nnW6rR77//9o 84 | cba/i9FIVOHXBvEBET/BmBStPC/wDp0LFwKCAQEA1sZKLH3IhWMbdqOcWbO8H1VC 85 | 4NT74peijfTjJ1JxcLllyv1H0MXW9qXG0Sksmy41CdyPBZUQuYbzi78p8S1aNIlx 86 | sj1VGbGIHk+YNMJYHBlTBpn9hjDjXP3tZHtVHRzZN0rjpFV76ODpxatvKxryPl3X 87 | mAPMhTvwnxnQ2rNF7RBPC8H8qJVBbG98k9HCqHolbNiMhV2Iow2SKSuX57IjT0cY 88 | 7mxps5zU94dTyJARNZaP7nlGHv1qx2ihRqxksIWxFetQ+U1JrM/14aeFUtS23HM+ 89 | MJQxyIWYaidDHJzHy1MiZBZ5dpC2hwqNSjLq/OoDEV2cAYHEMnLnXCiz7ZulUA== 90 | -----END RSA PRIVATE KEY-----` 91 | 92 | // default certificate 93 | var DefaultCertificate, _ = tls.X509KeyPair([]byte(CertPEM), []byte(KeyPEM)) 94 | -------------------------------------------------------------------------------- /cert/container.go: -------------------------------------------------------------------------------- 1 | package cert 2 | 3 | import "crypto/tls" 4 | 5 | // certificate storage Container 6 | type Container interface { 7 | 8 | // Get the certificate for host 9 | Get(host string) (*tls.Certificate, error) 10 | 11 | // Set the certificate for host 12 | Set(host string, cert *tls.Certificate) error 13 | } 14 | 
-------------------------------------------------------------------------------- /cert/mem_provider.go: -------------------------------------------------------------------------------- 1 | package cert 2 | 3 | import ( 4 | "crypto/tls" 5 | "fmt" 6 | "strings" 7 | "sync" 8 | ) 9 | 10 | var DefaultMemProvider = NewMemProvider() 11 | 12 | // MemProvider is a simple in-memory certificate cache; Get and Set are safe for concurrent use. 13 | type MemProvider struct { 14 | cache map[string]*tls.Certificate 15 | rw sync.RWMutex 16 | } 17 | 18 | // NewMemProvider creates an empty MemProvider. 19 | func NewMemProvider() *MemProvider { 20 | return &MemProvider{ 21 | cache: make(map[string]*tls.Certificate), 22 | rw: sync.RWMutex{}, 23 | } 24 | } 25 | 26 | // Get returns the certificate cached for host (surrounding whitespace is trimmed), or an error if there is none. 27 | func (m *MemProvider) Get(host string) (cert *tls.Certificate, err error) { 28 | m.rw.RLock() 29 | defer m.rw.RUnlock() 30 | if c, ok := m.cache[strings.TrimSpace(host)]; ok { 31 | return c, nil 32 | } 33 | return nil, fmt.Errorf("cert not exist") 34 | } 35 | 36 | // Set stores cert as the certificate for host (surrounding whitespace is trimmed). 37 | func (m *MemProvider) Set(host string, cert *tls.Certificate) error { 38 | host = strings.TrimSpace(host) 39 | m.rw.Lock() 40 | m.cache[host] = cert 41 | m.rw.Unlock() 42 | return nil 43 | } 44 | -------------------------------------------------------------------------------- /chunked.go: -------------------------------------------------------------------------------- 1 | // Taken from $GOROOT/src/pkg/net/http/chunked 2 | // needed to write https responses to client. 3 | package mps 4 | 5 | import ( 6 | "io" 7 | "strconv" 8 | ) 9 | 10 | // newChunkedWriter returns a new chunkedWriter that translates writes into HTTP 11 | // "chunked" format before writing them to w. Closing the returned chunkedWriter 12 | // sends the final 0-length chunk that marks the end of the stream. 13 | // 14 | // newChunkedWriter is not needed by normal applications. The http 15 | // package adds chunking automatically if handlers don't set a 16 | // Content-Length header.
Using newChunkedWriter inside a handler 17 | // would result in double chunking or chunking with a Content-Length 18 | // length, both of which are wrong. 19 | func newChunkedWriter(w io.Writer) io.WriteCloser { 20 | return &chunkedWriter{w} 21 | } 22 | 23 | // Writing to chunkedWriter translates to writing in HTTP chunked Transfer 24 | // Encoding wire format to the underlying Wire chunkedWriter. 25 | type chunkedWriter struct { 26 | Wire io.Writer 27 | } 28 | 29 | // Write the contents of data as one chunk to Wire. 30 | // NOTE: Note that the corresponding chunk-writing procedure in Conn.Write has 31 | // a bug since it does not check for success of io.WriteString 32 | func (cw *chunkedWriter) Write(data []byte) (n int, err error) { 33 | // Don't send 0-length data. It looks like EOF for chunked encoding. 34 | if len(data) == 0 { 35 | return 0, nil 36 | } 37 | 38 | head := strconv.FormatInt(int64(len(data)), 16) + "\r\n" 39 | 40 | if _, err = io.WriteString(cw.Wire, head); err != nil { 41 | return 0, err 42 | } 43 | if n, err = cw.Wire.Write(data); err != nil { 44 | return 45 | } 46 | if n != len(data) { 47 | err = io.ErrShortWrite 48 | return 49 | } 50 | _, err = io.WriteString(cw.Wire, "\r\n") 51 | return 52 | } 53 | // Close writes only the last-chunk line "0\r\n"; the CRLF that ends the chunked body is not written here. 54 | func (cw *chunkedWriter) Close() error { 55 | _, err := io.WriteString(cw.Wire, "0\r\n") 56 | return err 57 | } 58 | -------------------------------------------------------------------------------- /context.go: -------------------------------------------------------------------------------- 1 | package mps 2 | 3 | import ( 4 | "context" 5 | "crypto/tls" 6 | "errors" 7 | "net" 8 | "net/http" 9 | "time" 10 | ) 11 | 12 | var ( 13 | // RequestNilErr is returned by Context.Next when the request is nil 14 | RequestNilErr = errors.New("request is nil") 15 | // MethodNotSupportErr is returned by Context.Next for CONNECT requests 16 | MethodNotSupportErr = errors.New("request method not support") 17 | // RequestWebsocketUpgradeErr is returned by Context.Next for websocket upgrade requests 18 | RequestWebsocketUpgradeErr = errors.New("websocket upgrade") 19 | ) 20 | 21 | // Context for the
request 22 | // which contains Middleware, Transport, and other values 23 | type Context struct { 24 | // context.Context 25 | Context context.Context 26 | 27 | // Request context-dependent requests 28 | Request *http.Request 29 | 30 | // Response is associated with Request 31 | Response *http.Response 32 | 33 | // Transport is used for global HTTP requests, and it will be reused. 34 | Transport *http.Transport 35 | 36 | // In some cases it is not always necessary to remove the proxy headers. 37 | // For example, cascade proxy 38 | KeepProxyHeaders bool 39 | 40 | // In some cases it is not always necessary to reset the headers. 41 | KeepClientHeaders bool 42 | 43 | // KeepDestinationHeaders indicates the proxy should retain any headers 44 | // present in the http.Response before proxying 45 | KeepDestinationHeaders bool 46 | 47 | // middlewares ACTS on Request and Response. 48 | // It's going to be reused by the Context 49 | // mi is the index subscript of the middlewares traversal 50 | // the default value for the index is -1 51 | mi int 52 | middlewares []Middleware 53 | } 54 | 55 | // NewContext creates a Context with its own Transport, no middlewares and all Keep* options disabled 56 | func NewContext() *Context { 57 | return &Context{ 58 | Context: context.Background(), 59 | // Cannot reuse one Transport because multiple proxy can collide with each other 60 | Transport: &http.Transport{ 61 | DialContext: (&net.Dialer{ 62 | Timeout: 15 * time.Second, 63 | KeepAlive: 30 * time.Second, 64 | // DualStack is deprecated and ignored: RFC 6555 Fast Fallback is enabled by default 65 | }).DialContext, 66 | MaxIdleConns: 100, 67 | IdleConnTimeout: 90 * time.Second, 68 | TLSHandshakeTimeout: 10 * time.Second, 69 | ExpectContinueTimeout: 1 * time.Second, 70 | TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, // NOTE: upstream TLS certificates are NOT verified 71 | Proxy: http.ProxyFromEnvironment, 72 | }, 73 | Request: nil, 74 | Response: nil, 75 | KeepProxyHeaders: false, 76 | KeepClientHeaders: false, 77 | KeepDestinationHeaders: false, 78 | mi: -1, 79 | middlewares: make([]Middleware, 0), 80 | } 81 | } 82 | 83 | // Use registers an
Middleware to proxy 84 | func (ctx *Context) Use(middleware ...Middleware) { 85 | if ctx.middlewares == nil { 86 | ctx.middlewares = make([]Middleware, 0) 87 | } 88 | ctx.middlewares = append(ctx.middlewares, middleware...) 89 | } 90 | 91 | // UseFunc registers an MiddlewareFunc to proxy 92 | func (ctx *Context) UseFunc(fns ...MiddlewareFunc) { 93 | if ctx.middlewares == nil { 94 | ctx.middlewares = make([]Middleware, 0) 95 | } 96 | for _, fn := range fns { 97 | ctx.middlewares = append(ctx.middlewares, fn) 98 | } 99 | } 100 | 101 | // Next to exec middlewares 102 | // Execute the next middleware as a linked list. "ctx.Next(req)" 103 | // eg: 104 | // 105 | // func Handle(req *http.Request, ctx *Context) (*http.Response, error) { 106 | // // You can do anything to modify the http.Request ... 107 | // resp, err := ctx.Next(req) 108 | // // You can do anything to modify the http.Response ... 109 | // return resp, err 110 | // } 111 | // 112 | // Alternatively, you can simply return the response without executing `ctx.Next()`, 113 | // which will interrupt subsequent middleware execution.
114 | func (ctx *Context) Next(req *http.Request) (*http.Response, error) { 115 | var ( 116 | total = len(ctx.middlewares) 117 | err error 118 | ) 119 | ctx.mi++ 120 | if ctx.mi >= total { 121 | ctx.mi = -1 122 | // Final request coverage 123 | ctx.Request = req 124 | if req == nil { 125 | return nil, RequestNilErr 126 | } 127 | // To make the middleware available to the tunnel proxy, 128 | // no response is obtained when the request method is equal to Connect 129 | if req.Method == http.MethodConnect { 130 | return nil, MethodNotSupportErr 131 | } 132 | // Is it a Websocket requests 133 | if isWebSocketRequest(req) { 134 | return nil, RequestWebsocketUpgradeErr 135 | } 136 | // explicitly close the request body to avoid data races in certain RoundTripper implementations 137 | // see https://github.com/golang/go/issues/61596#issuecomment-1652345131 138 | // req.Body is nil for client-style requests without a body (e.g. rebuilt by a middleware via http.NewRequest) 139 | if req.Body != nil { 140 | defer req.Body.Close() 141 | } 142 | return ctx.RoundTrip(req) 143 | } 144 | 145 | middleware := ctx.middlewares[ctx.mi] 146 | ctx.Response, err = middleware.Handle(req, ctx) 147 | ctx.mi = -1 148 | return ctx.Response, err 149 | } 150 | 151 | // RoundTrip implements the RoundTripper interface. 152 | // 153 | // For higher-level HTTP client support (such as handling of cookies 154 | // and redirects), see Get, Post, and the Client type. 155 | // 156 | // Like the RoundTripper interface, the error types returned 157 | // by RoundTrip are unspecified. 158 | func (ctx *Context) RoundTrip(req *http.Request) (*http.Response, error) { 159 | // These Headers must be reset when a client Request is issued to reuse a Request 160 | if !ctx.KeepClientHeaders { 161 | ResetClientHeaders(req) 162 | } 163 | 164 | // In some cases it is not always necessary to remove the Proxy Header.
165 | // For example, cascade proxy 166 | if !ctx.KeepProxyHeaders { 167 | RemoveProxyHeaders(req) 168 | } 169 | 170 | if ctx.Transport != nil { 171 | return ctx.Transport.RoundTrip(req) 172 | } 173 | return DefaultTransport.RoundTrip(req) 174 | } 175 | 176 | // WithRequest returns a new Context for req sharing ctx's Context, Transport, Keep* flags and middlewares, with the middleware index reset 177 | func (ctx *Context) WithRequest(req *http.Request) *Context { 178 | return &Context{ 179 | Context: ctx.Context, 180 | Request: req, 181 | Response: nil, 182 | KeepProxyHeaders: ctx.KeepProxyHeaders, 183 | KeepClientHeaders: ctx.KeepClientHeaders, 184 | KeepDestinationHeaders: ctx.KeepDestinationHeaders, 185 | Transport: ctx.Transport, 186 | mi: -1, 187 | middlewares: ctx.middlewares, 188 | } 189 | } 190 | 191 | // ResetClientHeaders clears RequestURI and Accept-Encoding so an incoming server request can be reissued by a client Transport 192 | func ResetClientHeaders(r *http.Request) { 193 | // this must be reset when serving a request with the client 194 | r.RequestURI = "" 195 | // If no Accept-Encoding header exists, Transport will add the headers it can accept 196 | // and would wrap the response body with the relevant reader. 197 | r.Header.Del("Accept-Encoding") 198 | } 199 | 200 | // RemoveProxyHeaders removes hop-by-hop headers. These are removed when sent to the backend. 201 | // As of RFC 7230, hop-by-hop headers are required to appear in the 202 | // Connection header field. These are the headers defined by the 203 | // obsoleted RFC 2616 (section 13.5.1) and are used for backward 204 | // compatibility.
205 | func RemoveProxyHeaders(r *http.Request) { 206 | // RFC 2616 (section 13.5.1) 207 | // https://www.ietf.org/rfc/rfc2616.txt 208 | r.Header.Del("Proxy-Connection") 209 | r.Header.Del("Proxy-Authenticate") 210 | r.Header.Del("Proxy-Authorization") 211 | // Connection, Authenticate and Authorization are single hop Header: 212 | // http://www.w3.org/Protocols/rfc2616/rfc2616.txt 213 | // 14.10 Connection 214 | // The Connection general-header field allows the sender to specify 215 | // options that are desired for that particular connection and MUST NOT 216 | // be communicated by proxies over further connections. 217 | 218 | // When server reads http request it sets req.Close to true if 219 | // "Connection" header contains "close". 220 | // https://github.com/golang/go/blob/master/src/net/http/request.go#L1080 221 | // Later, transfer.go adds "Connection: close" back when req.Close is true 222 | // https://github.com/golang/go/blob/master/src/net/http/transfer.go#L275 223 | // That's why tests that checks "Connection: close" removal fail 224 | if r.Header.Get("Connection") == "close" { 225 | r.Close = false 226 | } 227 | r.Header.Del("Connection") 228 | } 229 | -------------------------------------------------------------------------------- /counter_encryptor.go: -------------------------------------------------------------------------------- 1 | package mps 2 | 3 | import ( 4 | "crypto/aes" 5 | "crypto/cipher" 6 | "crypto/ecdsa" 7 | "crypto/rsa" 8 | "crypto/sha256" 9 | "crypto/x509" 10 | "errors" 11 | ) 12 | // CounterEncryptorRand is a deterministic io.Reader producing an AES keystream derived from a private key and a seed. 13 | type CounterEncryptorRand struct { 14 | cipher cipher.Block 15 | counter []byte 16 | rand []byte 17 | ix int 18 | } 19 | // NewCounterEncryptorRand accepts *rsa.PrivateKey or *ecdsa.PrivateKey; equal (key, seed) pairs yield identical streams. 20 | func NewCounterEncryptorRand(key interface{}, seed []byte) (r CounterEncryptorRand, err error) { 21 | var keyBytes []byte 22 | switch key := key.(type) { 23 | case *rsa.PrivateKey: 24 | keyBytes = x509.MarshalPKCS1PrivateKey(key) 25 | case *ecdsa.PrivateKey: 26 | if keyBytes, err = x509.MarshalECPrivateKey(key); err != nil { 27
| return 28 | } 29 | default: 30 | err = errors.New("only RSA and ECDSA keys supported") 31 | return 32 | } 33 | // Derive the AES-256 key and the initial counter with sha256.Sum256. Note that 34 | // hash.Hash.Sum(b) does not hash b: it only appends the current digest to b. 35 | keyDigest := sha256.Sum256(keyBytes) 36 | if r.cipher, err = aes.NewCipher(keyDigest[:]); err != nil { 37 | return 38 | } 39 | r.counter = make([]byte, r.cipher.BlockSize()) 40 | if seed != nil { 41 | seedDigest := sha256.Sum256(seed) 42 | copy(r.counter, seedDigest[:r.cipher.BlockSize()]) 43 | } 44 | r.rand = make([]byte, r.cipher.BlockSize()) 45 | r.ix = len(r.rand) 46 | return 47 | } 48 | 49 | // Seed overwrites the counter with b, which must be exactly one cipher block long. 50 | func (c *CounterEncryptorRand) Seed(b []byte) { 51 | if len(b) != len(c.counter) { 52 | panic("SetCounter: wrong counter size") 53 | } 54 | copy(c.counter, b) 55 | } 56 | 57 | // refill encrypts the counter into c.rand, then increments the counter (counter[0] first, carrying into later bytes). 58 | func (c *CounterEncryptorRand) refill() { 59 | c.cipher.Encrypt(c.rand, c.counter) 60 | for i := 0; i < len(c.counter); i++ { 61 | if c.counter[i]++; c.counter[i] != 0 { 62 | break 63 | } 64 | } 65 | c.ix = 0 66 | } 67 | 68 | // Read copies unread keystream bytes into b, at most one cipher block per call; it never returns an error. 69 | func (c *CounterEncryptorRand) Read(b []byte) (n int, err error) { 70 | if c.ix == len(c.rand) { 71 | c.refill() 72 | } 73 | if n = len(c.rand) - c.ix; n > len(b) { 74 | n = len(b) 75 | } 76 | copy(b, c.rand[c.ix:c.ix+n]) 77 | c.ix += n 78 | return 79 | } 80 | -------------------------------------------------------------------------------- /counter_encryptor_test.go: -------------------------------------------------------------------------------- 1 | package mps 2 | 3 | import ( 4 | "bytes" 5 | "crypto/rsa" 6 | "encoding/binary" 7 | "io" 8 | "math" 9 | "math/rand" 10 | "testing" 11 | ) 12 | 13 | type RandSeedReader struct { 14 | r rand.Rand 15 | } 16 | 17 | func (r *RandSeedReader) Read(b []byte) (n int, err error) { 18 | for i := range b { 19 | b[i] = byte(r.r.Int() & 0xFF) 20 | } 21 | return len(b), nil 22 | } 23 | 24 | func TestCounterEncDifferentConsecutive(t *testing.T) { 25 | k, err := rsa.GenerateKey(&RandSeedReader{*rand.New(rand.NewSource(0xFF43109))}, 128) 26 | fatalOnErr(err, "rsa.GenerateKey", t) 27 | c, err := NewCounterEncryptorRand(k, []byte("the quick brown fox run over the lazy dog")) 28
| fatalOnErr(err, "NewCounterEncryptorRandFromKey", t) 29 | for i := 0; i < 100*1000; i++ { 30 | var a, b int64 31 | binary.Read(&c, binary.BigEndian, &a) 32 | binary.Read(&c, binary.BigEndian, &b) 33 | if a == b { 34 | t.Fatal("two consecutive equal int64", a, b) 35 | } 36 | } 37 | } 38 | 39 | func TestCounterEncIdenticalStreams(t *testing.T) { 40 | k, err := rsa.GenerateKey(&RandSeedReader{*rand.New(rand.NewSource(0xFF43109))}, 128) 41 | fatalOnErr(err, "rsa.GenerateKey", t) 42 | c1, err := NewCounterEncryptorRand(k, []byte("the quick brown fox run over the lazy dog")) 43 | fatalOnErr(err, "NewCounterEncryptorRandFromKey", t) 44 | c2, err := NewCounterEncryptorRand(k, []byte("the quick brown fox run over the lazy dog")) 45 | fatalOnErr(err, "NewCounterEncryptorRandFromKey", t) 46 | nout := 1000 47 | out1, out2 := make([]byte, nout), make([]byte, nout) 48 | io.ReadFull(&c1, out1) 49 | tmp := out2[:] 50 | rand.Seed(0xFF43109) 51 | for len(tmp) > 0 { 52 | n := 1 + rand.Intn(256) 53 | if n > len(tmp) { 54 | n = len(tmp) 55 | } 56 | n, err := c2.Read(tmp[:n]) 57 | fatalOnErr(err, "CounterEncryptorRand.Read", t) 58 | tmp = tmp[n:] 59 | } 60 | if !bytes.Equal(out1, out2) { 61 | t.Error("identical CSPRNG does not produce the same output") 62 | } 63 | } 64 | 65 | func stddev(data []int) float64 { 66 | var sum, sum_sqr float64 = 0, 0 67 | for _, h := range data { 68 | sum += float64(h) 69 | sum_sqr += float64(h) * float64(h) 70 | } 71 | n := float64(len(data)) 72 | variance := (sum_sqr - ((sum * sum) / n)) / (n - 1) 73 | return math.Sqrt(variance) 74 | } 75 | 76 | func TestCounterEncStreamHistogram(t *testing.T) { 77 | k, err := rsa.GenerateKey(&RandSeedReader{*rand.New(rand.NewSource(0xFF43109))}, 128) 78 | fatalOnErr(err, "rsa.GenerateKey", t) 79 | c, err := NewCounterEncryptorRand(k, []byte("the quick brown fox run over the lazy dog")) 80 | fatalOnErr(err, "NewCounterEncryptorRandFromKey", t) 81 | nout := 100 * 1000 82 | out := make([]byte, nout) 83 | io.ReadFull(&c,
out) 84 | refhist := make([]int, 512) 85 | for i := 0; i < nout; i++ { 86 | refhist[rand.Intn(256)]++ 87 | } 88 | hist := make([]int, 512) 89 | for _, b := range out { 90 | hist[int(b)]++ 91 | } 92 | refstddev, stddev := stddev(refhist), stddev(hist) 93 | // due to lack of time, I guestimate 94 | t.Logf("ref:%v - act:%v = %v", refstddev, stddev, math.Abs(refstddev-stddev)) 95 | if math.Abs(refstddev-stddev) >= 1 { 96 | t.Errorf("stddev of ref histogram different than regular PRNG: %v %v", refstddev, stddev) 97 | } 98 | } 99 | 100 | func fatalOnErr(err error, msg string, t *testing.T) { 101 | if err != nil { 102 | t.Fatal(msg, err) 103 | } 104 | } 105 | -------------------------------------------------------------------------------- /filter.go: -------------------------------------------------------------------------------- 1 | package mps 2 | 3 | import ( 4 | "net/http" 5 | "regexp" 6 | "strings" 7 | ) 8 | 9 | // Filter is a request interceptor 10 | type Filter interface { 11 | Match(req *http.Request) bool 12 | } 13 | 14 | // A wrapper that would convert a function to a Filter interface type 15 | type FilterFunc func(req *http.Request) bool 16 | 17 | // Filter.Match(req) <=> FilterFunc(req) 18 | func (f FilterFunc) Match(req *http.Request) bool { 19 | return f(req) 20 | } 21 | 22 | // FilterHostMatches returns a Filter testing whether req.Host matches any of the given regexps 23 | func FilterHostMatches(regexps ...*regexp.Regexp) Filter { 24 | return FilterFunc(func(req *http.Request) bool { 25 | for _, re := range regexps { 26 | if re.MatchString(req.Host) { 27 | return true 28 | } 29 | } 30 | return false 31 | }) 32 | } 33 | 34 | // FilterHostIs returns a Filter testing whether req.URL.Host is equal 35 | // to one of the given strings 36 | func FilterHostIs(hosts ...string) Filter { 37 | hostSet := make(map[string]bool) 38 | for _, h := range hosts { 39 | hostSet[h] = true 40 | } 41 | return FilterFunc(func(req *http.Request) bool { 42 | _, ok := hostSet[req.URL.Host] 43 | return ok 44 |
}) 45 | } 46 | 47 | // FilterUrlMatches returns a Filter testing whether the destination URL 48 | // of the request matches the given regexp, with or without prefix 49 | func FilterUrlMatches(re *regexp.Regexp) Filter { 50 | return FilterFunc(func(req *http.Request) bool { 51 | return re.MatchString(req.URL.Path) || 52 | re.MatchString(req.URL.Host+req.URL.Path) 53 | }) 54 | } 55 | 56 | // FilterUrlHasPrefix returns a Filter checking whether the destination URL the proxy client has requested 57 | // has the given prefix, with or without the scheme and host. 58 | // For example FilterUrlHasPrefix("host/x") matches requests of the form 'GET host/x' and requests to 59 | // url 'http://host/x'; FilterUrlHasPrefix("http://host/x") matches requests to url 'http://host/x' as well. 60 | func FilterUrlHasPrefix(prefix string) Filter { 61 | return FilterFunc(func(req *http.Request) bool { 62 | return strings.HasPrefix(req.URL.Path, prefix) || 63 | strings.HasPrefix(req.URL.Host+req.URL.Path, prefix) || 64 | strings.HasPrefix(req.URL.Scheme+"://"+req.URL.Host+req.URL.Path, prefix) 65 | }) 66 | } 67 | 68 | // FilterUrlIs returns a Filter, testing whether or not the request URL is one of the given strings 69 | // with or without the host prefix. 70 | // FilterUrlIs("google.com/","foo") will match requests 'GET /' to 'google.com', requests 'GET google.com/' to 71 | // any host, and requests of the form 'GET foo'.
72 | func FilterUrlIs(urls ...string) Filter { 73 | urlSet := make(map[string]bool) 74 | for _, u := range urls { 75 | urlSet[u] = true 76 | } 77 | return FilterFunc(func(req *http.Request) bool { 78 | _, pathOk := urlSet[req.URL.Path] 79 | _, hostAndOk := urlSet[req.URL.Host+req.URL.Path] 80 | return pathOk || hostAndOk 81 | }) 82 | } 83 | -------------------------------------------------------------------------------- /filter_group.go: -------------------------------------------------------------------------------- 1 | package mps 2 | 3 | import "net/http" 4 | 5 | type FilterGroup interface { 6 | Handle() 7 | } 8 | 9 | // ReqFilterGroup is a request filter group: handlers registered with Do run only for requests matching every filter 10 | type ReqFilterGroup struct { 11 | ctx *Context 12 | filters []Filter 13 | } 14 | 15 | func (cond *ReqFilterGroup) DoFunc(fn func(req *http.Request, ctx *Context) (*http.Request, *http.Response)) { 16 | cond.Do(RequestHandleFunc(fn)) 17 | } 18 | 19 | func (cond *ReqFilterGroup) Do(h RequestHandle) { 20 | cond.ctx.UseFunc(func(req *http.Request, ctx *Context) (*http.Response, error) { 21 | total := len(cond.filters) 22 | for i := 0; i < total; i++ { 23 | if !cond.filters[i].Match(req) { 24 | return ctx.Next(req) 25 | } 26 | } 27 | 28 | req, resp := h.HandleRequest(req, ctx) 29 | if resp != nil { 30 | return resp, nil 31 | } 32 | 33 | return ctx.Next(req) 34 | }) 35 | } 36 | 37 | // RespFilterGroup is a response filter group: handlers registered with Do receive the result of the rest of the chain, only for requests matching every filter 38 | type RespFilterGroup struct { 39 | ctx *Context 40 | filters []Filter 41 | } 42 | 43 | func (cond *RespFilterGroup) DoFunc(fn func(resp *http.Response, err error, ctx *Context) (*http.Response, error)) { 44 | cond.Do(ResponseHandleFunc(fn)) 45 | } 46 | 47 | func (cond *RespFilterGroup) Do(h ResponseHandle) { 48 | cond.ctx.UseFunc(func(req *http.Request, ctx *Context) (*http.Response, error) { 49 | total := len(cond.filters) 50 | for i := 0; i < total; i++ { 51 | if !cond.filters[i].Match(req) { 52 | return ctx.Next(req) 53 | } 54 | } 55 | resp, err :=
ctx.Next(req) 56 | return h.HandleResponse(resp, err, ctx) 57 | }) 58 | } 59 | -------------------------------------------------------------------------------- /forward_handler.go: -------------------------------------------------------------------------------- 1 | package mps 2 | 3 | import ( 4 | "bytes" 5 | "io" 6 | "net/http" 7 | "net/http/httputil" 8 | "strconv" 9 | 10 | "github.com/telanflow/mps/pool" 11 | ) 12 | 13 | // ForwardHandler The forward proxy type. Implements http.Handler. 14 | type ForwardHandler struct { 15 | Ctx *Context 16 | BufferPool httputil.BufferPool 17 | } 18 | 19 | // NewForwardHandler Create a forward proxy 20 | func NewForwardHandler() *ForwardHandler { 21 | return &ForwardHandler{ 22 | Ctx: NewContext(), 23 | BufferPool: pool.DefaultBuffer, 24 | } 25 | } 26 | 27 | // NewForwardHandlerWithContext Create a ForwardHandler with Context 28 | func NewForwardHandlerWithContext(ctx *Context) *ForwardHandler { 29 | return &ForwardHandler{ 30 | Ctx: ctx, 31 | BufferPool: pool.DefaultBuffer, 32 | } 33 | } 34 | 35 | // ServeHTTP implements http.Handler (standard net/http function).
You can use it alone. It runs the middleware chain for req, buffers the response body and replies with an exact Content-Length. 36 | func (forward *ForwardHandler) ServeHTTP(rw http.ResponseWriter, req *http.Request) { 37 | // Copying a Context preserves the Transport, Middleware 38 | ctx := forward.Ctx.WithRequest(req) 39 | resp, err := ctx.Next(req) 40 | if err != nil { 41 | http.Error(rw, err.Error(), http.StatusBadGateway) 42 | return 43 | } 44 | if resp == nil { 45 | // A middleware may end the chain with (nil, nil); there is no response to forward. 46 | http.Error(rw, "empty response", http.StatusBadGateway) 47 | return 48 | } 49 | // Responses built by middleware may leave Body or Header nil. 50 | if resp.Body == nil { 51 | resp.Body = http.NoBody 52 | } 53 | if resp.Header == nil { 54 | resp.Header = make(http.Header) 55 | } 56 | defer resp.Body.Close() 57 | 58 | var ( 59 | // Body buffer 60 | buffer = new(bytes.Buffer) 61 | // Body size 62 | bufferSize int64 63 | ) 64 | 65 | buf := forward.buffer().Get() 66 | bufferSize, err = io.CopyBuffer(buffer, resp.Body, buf) 67 | forward.buffer().Put(buf) 68 | if err != nil { 69 | http.Error(rw, err.Error(), http.StatusBadGateway) 70 | return 71 | } 72 | 73 | resp.ContentLength = bufferSize 74 | resp.Header.Set("Content-Length", strconv.FormatInt(bufferSize, 10)) 75 | // Read KeepDestinationHeaders from the per-request ctx so a middleware can override it. 76 | copyHeaders(rw.Header(), resp.Header, ctx.KeepDestinationHeaders) 77 | rw.WriteHeader(resp.StatusCode) 78 | // The status line is already sent, so a failed body write (e.g. client gone) cannot be reported. 79 | _, _ = buffer.WriteTo(rw) 80 | } 81 | 82 | // Use registers an Middleware to proxy 83 | func (forward *ForwardHandler) Use(middleware ...Middleware) { 84 | forward.Ctx.Use(middleware...) 85 | } 86 | 87 | // UseFunc registers an MiddlewareFunc to proxy 88 | func (forward *ForwardHandler) UseFunc(fus ...MiddlewareFunc) { 89 | forward.Ctx.UseFunc(fus...)
90 | } 91 | 92 | // OnRequest filter requests through Filters 93 | func (forward *ForwardHandler) OnRequest(filters ...Filter) *ReqFilterGroup { 94 | return &ReqFilterGroup{ctx: forward.Ctx, filters: filters} 95 | } 96 | 97 | // OnResponse filter response through Filters 98 | func (forward *ForwardHandler) OnResponse(filters ...Filter) *RespFilterGroup { 99 | return &RespFilterGroup{ctx: forward.Ctx, filters: filters} 100 | } 101 | 102 | // Transport returns the http.Transport of the handler's Context. 103 | func (forward *ForwardHandler) Transport() *http.Transport { 104 | return forward.Ctx.Transport 105 | } 106 | 107 | // buffer returns the configured BufferPool, or pool.DefaultBuffer when it is nil. 108 | func (forward *ForwardHandler) buffer() httputil.BufferPool { 109 | if forward.BufferPool != nil { 110 | return forward.BufferPool 111 | } 112 | return pool.DefaultBuffer 113 | } 114 | -------------------------------------------------------------------------------- /forward_handler_test.go: -------------------------------------------------------------------------------- 1 | package mps 2 | 3 | import ( 4 | "io" 5 | "net/http" 6 | "net/http/httptest" 7 | "net/url" 8 | "strconv" 9 | "testing" 10 | 11 | "github.com/stretchr/testify/assert" 12 | ) 13 | 14 | func TestNewForwardHandler_ContentLength(t *testing.T) { 15 | srv := newTestServer() 16 | defer srv.Close() 17 | 18 | forwardHandler := NewForwardHandler() 19 | proxySrv := httptest.NewServer(forwardHandler) 20 | defer proxySrv.Close() 21 | 22 | resp, err := HttpGet(srv.URL, func(r *http.Request) (*url.URL, error) { 23 | return url.Parse(proxySrv.URL) 24 | }) 25 | if err != nil { 26 | t.Fatal(err) 27 | } 28 | defer resp.Body.Close() 29 | 30 | body, _ := io.ReadAll(resp.Body) 31 | bodySize := len(body) 32 | contentLength, _ := strconv.Atoi(resp.Header.Get("Content-Length")) 33 | 34 | asserts := assert.New(t) 35 | asserts.Equal(resp.StatusCode, 200, "statusCode should be equal 200") 36 | asserts.Equal(bodySize, contentLength, "Content-Length should be equal "+strconv.Itoa(bodySize)) 37 | asserts.Equal(int64(bodySize), resp.ContentLength) 38
| } 39 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/telanflow/mps 2 | 3 | go 1.20 4 | 5 | require ( 6 | github.com/gorilla/websocket v1.5.0 7 | github.com/stretchr/testify v1.8.4 8 | ) 9 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 2 | github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= 3 | github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 4 | github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc= 5 | github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= 6 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= 7 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 8 | github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= 9 | github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= 10 | github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= 11 | github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= 12 | github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= 13 | github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= 14 | github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= 15 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= 16 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod
h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= 17 | gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 18 | gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= 19 | gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 20 | -------------------------------------------------------------------------------- /handle.go: -------------------------------------------------------------------------------- 1 | package mps 2 | 3 | import "net/http" 4 | 5 | type Handle interface { 6 | RequestHandle 7 | ResponseHandle 8 | } 9 | 10 | type RequestHandle interface { 11 | HandleRequest(req *http.Request, ctx *Context) (*http.Request, *http.Response) 12 | } 13 | 14 | // A wrapper that would convert a function to a RequestHandle interface type 15 | type RequestHandleFunc func(req *http.Request, ctx *Context) (*http.Request, *http.Response) 16 | 17 | // HandleRequest calls f(req, ctx), so a RequestHandleFunc satisfies RequestHandle 18 | func (f RequestHandleFunc) HandleRequest(req *http.Request, ctx *Context) (*http.Request, *http.Response) { 19 | return f(req, ctx) 20 | } 21 | 22 | type ResponseHandle interface { 23 | HandleResponse(resp *http.Response, err error, ctx *Context) (*http.Response, error) 24 | } 25 | 26 | // A wrapper that would convert a function to a ResponseHandle interface type 27 | type ResponseHandleFunc func(resp *http.Response, err error, ctx *Context) (*http.Response, error) 28 | 29 | // HandleResponse calls f(resp, err, ctx), so a ResponseHandleFunc satisfies ResponseHandle 30 | func (f ResponseHandleFunc) HandleResponse(resp *http.Response, err error, ctx *Context) (*http.Response, error) { 31 | return f(resp, err, ctx) 32 | } 33 | -------------------------------------------------------------------------------- /http_proxy.go: -------------------------------------------------------------------------------- 1 | package mps 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | "net" 7 | "net/http" 8 | ) 9 | 10 |
// HttpProxy is the basic proxy type. It implements http.Handler. 11 | type HttpProxy struct { 12 | // Handles Connect requests use the TunnelHandler by default 13 | HandleConnect http.Handler 14 | 15 | // HTTP requests use the ForwardHandler by default 16 | HttpHandler http.Handler 17 | 18 | // Requests with a relative URL (reverse proxying) use the ReverseHandler by default 19 | ReverseHandler http.Handler 20 | 21 | // Client request Context 22 | Ctx *Context 23 | } 24 | 25 | func NewHttpProxy() *HttpProxy { 26 | // default Context with Proxy 27 | ctx := NewContext() 28 | return &HttpProxy{ 29 | Ctx: ctx, 30 | // default handles Connect method 31 | HandleConnect: &TunnelHandler{Ctx: ctx}, 32 | // default handles HTTP request 33 | HttpHandler: &ForwardHandler{Ctx: ctx}, 34 | // default Reverse proxy 35 | ReverseHandler: &ReverseHandler{Ctx: ctx}, 36 | } 37 | } 38 | 39 | // ServeHTTP implements http.Handler: CONNECT requests go to HandleConnect, requests with a relative URL to ReverseHandler, all others to HttpHandler. 40 | func (proxy *HttpProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { 41 | if req.Method == http.MethodConnect { 42 | proxy.HandleConnect.ServeHTTP(rw, req) 43 | return 44 | } 45 | 46 | // reverse proxy http request for example: 47 | // GET / HTTP/1.1 48 | // Host: www.example.com 49 | // Connection: keep-alive 50 | // 51 | // forward proxy http request for example : 52 | // GET http://www.example.com/ HTTP/1.1 53 | // Host: www.example.com 54 | // Proxy-Connection: keep-alive 55 | // 56 | // Determines whether the path is absolute 57 | if !req.URL.IsAbs() { 58 | proxy.ReverseHandler.ServeHTTP(rw, req) 59 | } else { 60 | proxy.HttpHandler.ServeHTTP(rw, req) 61 | } 62 | } 63 | 64 | // Use registers an Middleware to proxy 65 | func (proxy *HttpProxy) Use(middleware ...Middleware) { 66 | proxy.Ctx.Use(middleware...) 67 | } 68 | 69 | // UseFunc registers an MiddlewareFunc to proxy 70 | func (proxy *HttpProxy) UseFunc(fus ...MiddlewareFunc) { 71 | proxy.Ctx.UseFunc(fus...)
72 | } 73 | 74 | // OnRequest filter requests through Filters 75 | func (proxy *HttpProxy) OnRequest(filters ...Filter) *ReqFilterGroup { 76 | return &ReqFilterGroup{ctx: proxy.Ctx, filters: filters} 77 | } 78 | 79 | // OnResponse filter response through Filters 80 | func (proxy *HttpProxy) OnResponse(filters ...Filter) *RespFilterGroup { 81 | return &RespFilterGroup{ctx: proxy.Ctx, filters: filters} 82 | } 83 | 84 | // Transport returns the http.Transport held by the proxy's Context. 85 | func (proxy *HttpProxy) Transport() *http.Transport { 86 | return proxy.Ctx.Transport 87 | } 88 | 89 | // hijacker takes over the client connection behind rw through http.Hijacker. 90 | func hijacker(rw http.ResponseWriter) (conn net.Conn, err error) { 91 | hij, ok := rw.(http.Hijacker) 92 | if !ok { 93 | err = errors.New("not a hijacker") 94 | return 95 | } 96 | 97 | conn, _, err = hij.Hijack() 98 | if err != nil { 99 | err = fmt.Errorf("cannot hijack connection %w", err) 100 | } 101 | return 102 | } 103 | 104 | // copyHeaders adds every value of src to dst; unless keepDestHeaders is set, dst is cleared first. 105 | func copyHeaders(dst, src http.Header, keepDestHeaders bool) { 106 | if !keepDestHeaders { 107 | for k := range dst { 108 | dst.Del(k) 109 | } 110 | } 111 | for k, vs := range src { 112 | for _, v := range vs { 113 | dst.Add(k, v) 114 | } 115 | } 116 | } 117 | -------------------------------------------------------------------------------- /http_proxy_test.go: -------------------------------------------------------------------------------- 1 | package mps 2 | 3 | import ( 4 | "bytes" 5 | "io" 6 | "net/http" 7 | "net/http/httptest" 8 | "net/url" 9 | "testing" 10 | 11 | "github.com/stretchr/testify/assert" 12 | ) 13 | 14 | func newTestServer() *httptest.Server { 15 | return httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { 16 | query := req.URL.Query() 17 | text := []byte("hello world") 18 | if query.Get("text") != "" { 19 | text = []byte(query.Get("text")) 20 | } 21 | 22 
| rw.Header().Set("Server", "MPS proxy server") 23 | _, _ = rw.Write(text) 24 | })) 25 | } 26 | 27 | func
HttpGet(rawurl string, proxy func(r *http.Request) (*url.URL, error)) (*http.Response, error) { 28 | req, _ := http.NewRequest(http.MethodGet, rawurl, nil) 29 | client := &http.Client{Transport: &http.Transport{ 30 | Proxy: proxy, 31 | }} 32 | return client.Do(req) 33 | } 34 | 35 | func TestNewHttpProxy(t *testing.T) { 36 | srv := newTestServer() 37 | defer srv.Close() 38 | 39 | proxy := NewHttpProxy() 40 | proxySrv := httptest.NewServer(proxy) 41 | defer proxySrv.Close() 42 | 43 | resp, err := HttpGet(srv.URL, func(r *http.Request) (*url.URL, error) { 44 | return url.Parse(proxySrv.URL) 45 | }) 46 | if err != nil { 47 | t.Fatal(err) 48 | } 49 | 50 | body, _ := io.ReadAll(resp.Body) 51 | resp.Body.Close() 52 | 53 | asserts := assert.New(t) 54 | asserts.Equal(200, resp.StatusCode, "statusCode should be equal 200") 55 | asserts.Equal(int64(len(body)), resp.ContentLength) 56 | } 57 | 58 | func TestMiddlewareFunc(t *testing.T) { 59 | // target server 60 | srv := newTestServer() 61 | defer srv.Close() 62 | 63 | // proxy server 64 | proxy := NewHttpProxy() 65 | 66 | // use Middleware 67 | proxy.UseFunc(func(req *http.Request, ctx *Context) (*http.Response, error) { 68 | resp, err := ctx.Next(req) 69 | if err != nil { 70 | return nil, err 71 | } 72 | _ = resp.Body.Close() // the upstream body is replaced below and must still be closed 73 | var buf bytes.Buffer 74 | buf.WriteString("middleware") 75 | resp.Body = io.NopCloser(&buf) 76 | 77 | // 78 | // You have to reset Content-Length, if you change the Body.
79 | //resp.ContentLength = int64(buf.Len()) 80 | //resp.Header.Set("Content-Length", strconv.Itoa(buf.Len())) 81 | 82 | return resp, nil 83 | }) 84 | proxySrv := httptest.NewServer(proxy) 85 | defer proxySrv.Close() 86 | 87 | // send request 88 | resp, err := HttpGet(srv.URL, func(r *http.Request) (*url.URL, error) { 89 | return url.Parse(proxySrv.URL) 90 | }) 91 | if err != nil { 92 | t.Fatal(err) 93 | } 94 | 95 | body, _ := io.ReadAll(resp.Body) 96 | resp.Body.Close() 97 | 98 | asserts := assert.New(t) 99 | asserts.Equal(200, resp.StatusCode) 100 | asserts.Equal(int64(len(body)), resp.ContentLength) 101 | asserts.Equal("middleware", string(body)) 102 | } 103 | -------------------------------------------------------------------------------- /middleware.go: -------------------------------------------------------------------------------- 1 | package mps 2 | 3 | import "net/http" 4 | 5 | // Middleware will "tamper" with the request coming to the proxy server 6 | type Middleware interface { 7 | // Handle execute the next middleware as a linked list. "ctx.Next(req)" 8 | // eg: 9 | // func Handle(req *http.Request, ctx *Context) (*http.Response, error) { 10 | // // You can do anything to modify the http.Request ... 11 | // resp, err := ctx.Next(req) 12 | // // You can do anything to modify the http.Response ... 13 | // return resp, err 14 | // } 15 | // 16 | // Alternatively, you can simply return the response without executing `ctx.Next()`, 17 | // which will interrupt subsequent middleware execution.
18 | Handle(req *http.Request, ctx *Context) (*http.Response, error) 19 | } 20 | 21 | // MiddlewareFunc is an adapter that allows an ordinary function to be used as a Middleware. 22 | type MiddlewareFunc func(req *http.Request, ctx *Context) (*http.Response, error) 23 | 24 | // Handle calls f(req, ctx). 25 | func (f MiddlewareFunc) Handle(req *http.Request, ctx *Context) (*http.Response, error) { 26 | return f(req, ctx) 27 | } 28 | -------------------------------------------------------------------------------- /middleware/basicAuth.go: -------------------------------------------------------------------------------- 1 | package middleware 2 | 3 | import ( 4 | "bytes" 5 | "encoding/base64" 6 | "io" 7 | "net/http" 8 | "strings" 9 | 10 | "github.com/telanflow/mps" 11 | ) 12 | 13 | // proxy Authorization header 14 | const proxyAuthorization = "Proxy-Authorization" 15 | 16 | // BasicAuth returns a middleware that checks the Proxy-Authorization Basic credentials with fn 17 | // and answers 407 when they are missing, malformed or rejected. Register it via HttpProxy.UseFunc. 18 | func BasicAuth(realm string, fn func(username, password string) bool) mps.MiddlewareFunc { 19 | return func(req *http.Request, ctx *mps.Context) (*http.Response, error) { 20 | auth := req.Header.Get(proxyAuthorization) 21 | if auth == "" { 22 | return BasicUnauthorized(req, realm), nil 23 | } 24 | // parses an Basic Authentication string. 25 | usr, pwd, ok := parseBasicAuth(auth) 26 | if !ok { 27 | return BasicUnauthorized(req, realm), nil 28 | } 29 | if !fn(usr, pwd) { 30 | return BasicUnauthorized(req, realm), nil 31 | } 32 | // Authorization passed 33 | return ctx.Next(req) 34 | } 35 | } 36 | 37 | // SetBasicAuth sets the request's Proxy-Authorization header to use HTTP 38 | // Basic Authentication with the provided username and password. 39 | // 40 | // With HTTP Basic Authentication the provided username and password 41 | // are not encrypted.
42 | // 43 | // Some protocols may impose additional requirements on pre-escaping the 44 | // username and password. For instance, when used with OAuth2, both arguments 45 | // must be URL encoded first with url.QueryEscape. 46 | func SetBasicAuth(req *http.Request, username, password string) { 47 | req.Header.Set(proxyAuthorization, "Basic "+basicAuth(username, password)) 48 | } 49 | 50 | // See 2 (end of page 4) https://www.ietf.org/rfc/rfc2617.txt 51 | // "To receive authorization, the client sends the userid and password, 52 | // separated by a single colon (":") character, within a base64 53 | // encoded string in the credentials." 54 | // It is not meant to be urlencoded. 55 | func basicAuth(username, password string) string { 56 | auth := username + ":" + password 57 | return base64.StdEncoding.EncodeToString([]byte(auth)) 58 | } 59 | 60 | // parseBasicAuth parses an HTTP Basic Authentication string. 61 | // "Basic QWxhZGRpbjpvcGVuIHNlc2FtZQ==" returns ("Aladdin", "open sesame", true). 62 | func parseBasicAuth(auth string) (username, password string, ok bool) { 63 | const prefix = "Basic " 64 | // Case insensitive prefix match. See Issue 22736.
65 | if len(auth) < len(prefix) || !strings.EqualFold(auth[:len(prefix)], prefix) { 66 | return 67 | } 68 | c, err := base64.StdEncoding.DecodeString(auth[len(prefix):]) 69 | if err != nil { 70 | return 71 | } 72 | cs := string(c) 73 | s := strings.IndexByte(cs, ':') 74 | if s < 0 { 75 | return 76 | } 77 | return cs[:s], cs[s+1:], true 78 | } 79 | 80 | // BasicUnauthorized returns a 407 Proxy Authentication Required response to req 81 | // that challenges the client for Basic credentials in realm. 82 | func BasicUnauthorized(req *http.Request, realm string) *http.Response { 83 | const unauthorizedMsg = "407 Proxy Authentication Required" 84 | // RFC 7235 requires realm to be a quoted-string: escape '\' and '"' and wrap it in quotes. 85 | quotedRealm := `"` + strings.NewReplacer(`\`, `\\`, `"`, `\"`).Replace(realm) + `"` 86 | return &http.Response{ 87 | StatusCode: 407, 88 | ProtoMajor: 1, 89 | ProtoMinor: 1, 90 | Request: req, 91 | Header: http.Header{ 92 | "Proxy-Authenticate": []string{"Basic realm=" + quotedRealm}, 93 | "Proxy-Connection": []string{"close"}, 94 | }, 95 | Body: io.NopCloser(bytes.NewBuffer([]byte(unauthorizedMsg))), 96 | ContentLength: int64(len(unauthorizedMsg)), 97 | } 98 | } 99 | -------------------------------------------------------------------------------- /middleware/singleHostReverseProxy.go: -------------------------------------------------------------------------------- 1 | package middleware 2 | 3 | import ( 4 | "net/http" 5 | "net/url" 6 | "strings" 7 | 8 | "github.com/telanflow/mps" 9 | ) 10 | 11 | // SingleHostReverseProxy returns a mps.MiddlewareFunc that routes 12 | // URLs to the scheme, host, and base path provided in target. If the 13 | // target's path is "/base" and the incoming request was for "/dir", 14 | // the target request will be for /base/dir. 15 | // Note: it DOES rewrite the Host header (req.Host is set to target.Host). 16 | // target's query is merged ahead of the request's query, and a request without 17 | // a User-Agent header gets an empty one so that no default value is added.
18 | | func SingleHostReverseProxy(target *url.URL) mps.MiddlewareFunc { 19 | targetQuery := target.RawQuery 20 | return func(req *http.Request, ctx *mps.Context) (*http.Response, error) { 21 | // changed request Host 22 | req.Host = target.Host 23 | // changed request URL 24 | req.URL.Scheme = target.Scheme 25 | req.URL.Host = target.Host 26 | req.URL.Path = singleJoiningSlash(target.Path, req.URL.Path) 27 | if targetQuery == "" || req.URL.RawQuery == "" { 28 | req.URL.RawQuery = targetQuery + req.URL.RawQuery 29 | } else { 30 | req.URL.RawQuery = targetQuery + "&" + req.URL.RawQuery 31 | } 32 | if _, ok := req.Header["User-Agent"]; !ok { 33 | // explicitly disable User-Agent so it's not set to default value 34 | req.Header.Set("User-Agent", "") 35 | } 36 | return ctx.Next(req) 37 | } 38 | } 39 | 40 | func singleJoiningSlash(a, b string) string { 41 | aslash := strings.HasSuffix(a, "/") 42 | bslash := strings.HasPrefix(b, "/") 43 | switch { 44 | case aslash && bslash: 45 | return a + b[1:] 46 | case !aslash && !bslash: 47 | return a + "/" + b 48 | } 49 | return a + b 50 | } 51 | -------------------------------------------------------------------------------- /mitm_handler.go: -------------------------------------------------------------------------------- 1 | package mps 2 | 3 | import ( 4 | "bufio" 5 | "bytes" 6 | "crypto" 7 | "crypto/ecdsa" 8 | "crypto/elliptic" 9 | "crypto/rsa" 10 | "crypto/sha1" 11 | "crypto/tls" 12 | "crypto/x509" 13 | "crypto/x509/pkix" 14 | "errors" 15 | "fmt" 16 | "io" 17 | "math/big" 18 | "net" 19 | "net/http" 20 | "net/http/httputil" 21 | "net/url" 22 | "regexp" 23 | "sort" 24 | "strconv" 25 | "strings" 26 | "time" 27 | 28 | "github.com/telanflow/mps/cert" 29 | "github.com/telanflow/mps/pool" 30 | ) 31 | 32 | var ( 33 | HttpMitmOk = []byte("HTTP/1.0 200 Connection Established\r\n\r\n") 34 | httpsRegexp = regexp.MustCompile("^https://") 35 | ) 36 | 37 | // MitmHandler The Man-in-the-middle proxy type. Implements http.Handler.
38 | type MitmHandler struct { 39 | Ctx *Context 40 | BufferPool httputil.BufferPool 41 | Certificate tls.Certificate 42 | // CertContainer is certificate storage container 43 | CertContainer cert.Container 44 | } 45 | 46 | // NewMitmHandler Create a mitmHandler, use default cert. 47 | func NewMitmHandler() *MitmHandler { 48 | return &MitmHandler{ 49 | Ctx: NewContext(), 50 | BufferPool: pool.DefaultBuffer, 51 | Certificate: cert.DefaultCertificate, 52 | CertContainer: cert.NewMemProvider(), 53 | } 54 | } 55 | 56 | // NewMitmHandlerWithContext Create a MitmHandler, use default cert. 57 | func NewMitmHandlerWithContext(ctx *Context) *MitmHandler { 58 | return &MitmHandler{ 59 | Ctx: ctx, 60 | BufferPool: pool.DefaultBuffer, 61 | Certificate: cert.DefaultCertificate, 62 | CertContainer: cert.NewMemProvider(), 63 | } 64 | } 65 | 66 | // NewMitmHandlerWithCert Create a MitmHandler with cert pem block 67 | func NewMitmHandlerWithCert(ctx *Context, certPEMBlock, keyPEMBlock []byte) (*MitmHandler, error) { 68 | certificate, err := tls.X509KeyPair(certPEMBlock, keyPEMBlock) 69 | if err != nil { 70 | return nil, err 71 | } 72 | return &MitmHandler{ 73 | Ctx: ctx, 74 | BufferPool: pool.DefaultBuffer, 75 | Certificate: certificate, 76 | CertContainer: cert.NewMemProvider(), 77 | }, nil 78 | } 79 | 80 | // NewMitmHandlerWithCertFile Create a MitmHandler with cert file 81 | func NewMitmHandlerWithCertFile(ctx *Context, certFile, keyFile string) (*MitmHandler, error) { 82 | certificate, err := tls.LoadX509KeyPair(certFile, keyFile) 83 | if err != nil { 84 | return nil, err 85 | } 86 | return &MitmHandler{ 87 | Ctx: ctx, 88 | BufferPool: pool.DefaultBuffer, 89 | Certificate: certificate, 90 | CertContainer: cert.NewMemProvider(), 91 | }, nil 92 | } 93 | 94 | // ServeHTTP implements http.Handler;
MitmHandler can also be used on its own. 95 | func (mitm *MitmHandler) ServeHTTP(rw http.ResponseWriter, req *http.Request) { 96 | // execution middleware 97 | ctx := mitm.Ctx.WithRequest(req) 98 | resp, err := ctx.Next(req) 99 | if err != nil && !errors.Is(err, MethodNotSupportErr) { 100 | if resp == nil { 101 | // Report the failure; returning without writing would send an empty 200 OK. 102 | http.Error(rw, err.Error(), 502) 103 | return 104 | } 105 | defer resp.Body.Close() 106 | copyHeaders(rw.Header(), resp.Header, mitm.Ctx.KeepDestinationHeaders) 107 | rw.WriteHeader(resp.StatusCode) 108 | buf := mitm.buffer().Get() 109 | _, _ = io.CopyBuffer(rw, resp.Body, buf) 110 | mitm.buffer().Put(buf) 111 | return 112 | } 113 | 114 | // get hijacker connection 115 | clientConn, err := hijacker(rw) 116 | if err != nil { 117 | http.Error(rw, err.Error(), 502) 118 | return 119 | } 120 | 121 | tlsConfig, err := mitm.TLSConfigFromCA(req.URL.Host) 122 | if err != nil { 123 | ConnError(clientConn) 124 | return 125 | } 126 | 127 | _, _ = clientConn.Write(HttpMitmOk) 128 | 129 | // Serve the intercepted TLS session in its own goroutine so ServeHTTP returns after the hijack. TODO: allow Server.Close() to shut these connections down gracefully. 130 | go mitm.transmit(clientConn, req, tlsConfig) 131 | } 132 | 133 | func (mitm *MitmHandler) transmit(clientConn net.Conn, originalReq *http.Request, tlsConfig *tls.Config) { 134 | // TODO: cache connections to the remote website 135 | rawClientTls := tls.Server(clientConn, tlsConfig) 136 | if err := rawClientTls.Handshake(); err != nil { 137 | ConnError(clientConn) 138 | _ = rawClientTls.Close() 139 | return 140 | } 141 | defer rawClientTls.Close() 142 | 143 | clientTlsReader := bufio.NewReader(rawClientTls) 144 | for !isEof(clientTlsReader) { 145 | req, err := http.ReadRequest(clientTlsReader) 146 | if err != nil { 147 | break 148 | } 149 | 150 | // since we're converting the request, need to carry over the
original connecting IP as well 151 | req.RemoteAddr = originalReq.RemoteAddr 152 | 153 | if !httpsRegexp.MatchString(req.URL.String()) { 154 | req.URL, err = url.Parse("https://" + originalReq.Host + req.URL.String()) 155 | } 156 | if err != nil { 157 | return 158 | } 159 | 160 | var resp *http.Response 161 | 162 | // Copying a Context preserves the Transport, Middleware 163 | ctx := mitm.Ctx.WithRequest(req) 164 | resp, err = ctx.Next(req) 165 | if err != nil { 166 | return 167 | } 168 | 169 | var ( 170 | // Body buffer 171 | buffer = new(bytes.Buffer) 172 | // Body size 173 | bufferSize int64 174 | ) 175 | 176 | buf := mitm.buffer().Get() 177 | bufferSize, err = io.CopyBuffer(buffer, resp.Body, buf) 178 | mitm.buffer().Put(buf) 179 | if err != nil { 180 | _ = resp.Body.Close() 181 | return 182 | } 183 | _ = resp.Body.Close() 184 | 185 | // reset Content-Length 186 | resp.ContentLength = bufferSize 187 | resp.Header.Set("Content-Length", strconv.FormatInt(bufferSize, 10)) 188 | resp.Body = io.NopCloser(buffer) 189 | err = resp.Write(rawClientTls) 190 | if err != nil { 191 | return 192 | } 193 | } 194 | } 195 | 196 | // Use registers a Middleware to proxy 197 | func (mitm *MitmHandler) Use(middleware ...Middleware) { 198 | mitm.Ctx.Use(middleware...) 199 | } 200 | 201 | // UseFunc registers an MiddlewareFunc to proxy 202 | func (mitm *MitmHandler) UseFunc(fus ...MiddlewareFunc) { 203 | mitm.Ctx.UseFunc(fus...)
204 | } 205 | 206 | // OnRequest filter requests through Filters 207 | func (mitm *MitmHandler) OnRequest(filters ...Filter) *ReqFilterGroup { 208 | return &ReqFilterGroup{ctx: mitm.Ctx, filters: filters} 209 | } 210 | 211 | // OnResponse filter response through Filters 212 | func (mitm *MitmHandler) OnResponse(filters ...Filter) *RespFilterGroup { 213 | return &RespFilterGroup{ctx: mitm.Ctx, filters: filters} 214 | } 215 | 216 | // buffer returns the handler's BufferPool, or pool.DefaultBuffer when none is set. 217 | func (mitm *MitmHandler) buffer() httputil.BufferPool { 218 | if mitm.BufferPool != nil { 219 | return mitm.BufferPool 220 | } 221 | return pool.DefaultBuffer 222 | } 223 | 224 | // certContainer returns the handler's CertContainer, or cert.DefaultMemProvider when none is set. 225 | func (mitm *MitmHandler) certContainer() cert.Container { 226 | if mitm.CertContainer != nil { 227 | return mitm.CertContainer 228 | } 229 | return cert.DefaultMemProvider 230 | } 231 | 232 | // Transport returns the http.Transport held by the handler's Context. 233 | func (mitm *MitmHandler) Transport() *http.Transport { 234 | return mitm.Ctx.Transport 235 | } 236 | 237 | func (mitm *MitmHandler) TLSConfigFromCA(host string) (*tls.Config, error) { 238 | host = stripPort(host) 239 | 240 | // Returned existing certificate for the host 241 | crt, err := mitm.certContainer().Get(host) 242 | if err == nil && crt != nil { 243 | return &tls.Config{ 244 | InsecureSkipVerify: true, 245 | Certificates: []tls.Certificate{*crt}, 246 | }, nil 247 | } 248 | 249 | // Issue a certificate for host 250 | crt, err = signHost(mitm.Certificate, []string{host}) 251 | if err != nil { 252 | err = fmt.Errorf("cannot sign host certificate with provided CA: %w", err) 253 | return nil, err 254 | } 255 | 256 | // Set certificate to container 257 | _ = mitm.certContainer().Set(host, crt) 258 | 259 | return &tls.Config{ 260 | InsecureSkipVerify: true, 261 | Certificates: []tls.Certificate{*crt}, 262 | }, nil 263 | } 264 | 265 | // 
signHost issues a certificate for hosts signed by ca, with a new RSA-2048 or ECDSA P-256 key matching ca's key type. 266 | func signHost(ca tls.Certificate, hosts []string) (cert *tls.Certificate, err error) { 267 | // Use the provided ca for certificate generation.
268 | var x509ca *x509.Certificate 269 | x509ca, err = x509.ParseCertificate(ca.Certificate[0]) 270 | if err != nil { 271 | return 272 | } 273 | 274 | start := time.Unix(time.Now().Unix()-2592000, 0) // 2592000 = 30 day 275 | end := time.Unix(time.Now().Unix()+31536000, 0) // 31536000 = 365 day 276 | 277 | var random CounterEncryptorRand 278 | random, err = NewCounterEncryptorRand(ca.PrivateKey, hashHosts(hosts)) 279 | if err != nil { 280 | return 281 | } 282 | 283 | var pk crypto.Signer 284 | switch ca.PrivateKey.(type) { 285 | case *rsa.PrivateKey: 286 | pk, err = rsa.GenerateKey(&random, 2048) 287 | case *ecdsa.PrivateKey: 288 | pk, err = ecdsa.GenerateKey(elliptic.P256(), &random) 289 | default: 290 | err = fmt.Errorf("unsupported key type %T", ca.PrivateKey) 291 | } 292 | if err != nil { 293 | return 294 | } 295 | 296 | // certificate template 297 | serial := big.NewInt(mpsRand.Int63()) 298 | tpl := x509.Certificate{ 299 | SerialNumber: serial, 300 | Issuer: x509ca.Subject, 301 | Subject: pkix.Name{ 302 | Organization: []string{"MPS untrusted MITM proxy Inc"}, 303 | }, 304 | NotBefore: start, 305 | NotAfter: end, 306 | KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, 307 | ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, 308 | BasicConstraintsValid: true, 309 | EmailAddresses: x509ca.EmailAddresses, 310 | } 311 | 312 | total := len(hosts) 313 | for i := 0; i < total; i++ { 314 | if ip := net.ParseIP(hosts[i]); ip != nil { 315 | tpl.IPAddresses = append(tpl.IPAddresses, ip) 316 | } else { 317 | tpl.DNSNames = append(tpl.DNSNames, hosts[i]) 318 | tpl.Subject.CommonName = hosts[i] 319 | } 320 | } 321 | 322 | var der []byte 323 | der, err = x509.CreateCertificate(&random, &tpl, x509ca, pk.Public(), ca.PrivateKey) 324 | if err != nil { 325 | return 326 | } 327 | 328 | cert = &tls.Certificate{ 329 | Certificate: [][]byte{der, ca.Certificate[0]}, 330 | PrivateKey: pk, 331 | } 332 | return 333 | } 334 |
335 | // stripPort removes an optional port from host ("example.com:443" -> "example.com") 336 | // and the brackets of an IPv6 literal ("[::1]:443" and "[::1]" both -> "::1"). 337 | func stripPort(s string) string { 338 | if host, _, err := net.SplitHostPort(s); err == nil { 339 | return host 340 | } 341 | // No port (or not a host:port form): only strip the brackets of an IPv6 literal. 342 | return strings.TrimSuffix(strings.TrimPrefix(s, "["), "]") 343 | } 344 | 345 | func hashHosts(lst []string) []byte { 346 | c := make([]string, len(lst)) 347 | copy(c, lst) 348 | sort.Strings(c) 349 | h := sha1.New() 350 | h.Write([]byte(strings.Join(c, ","))) 351 | return h.Sum(nil) 352 | } 353 | 354 | // cloneTLSConfig returns a shallow clone of cfg, or a new tls.Config with 355 | // InsecureSkipVerify set if cfg is nil. This is safe to call even if cfg is 356 | // in active use by a TLS client or server. 357 | func cloneTLSConfig(cfg *tls.Config) *tls.Config { 358 | if cfg == nil { 359 | return &tls.Config{ 360 | InsecureSkipVerify: true, 361 | } 362 | } 363 | return cfg.Clone() 364 | } 365 | 366 | // isEof reports whether r is at io.EOF; other Peek errors report false. 367 | func isEof(r *bufio.Reader) bool { 368 | _, err := r.Peek(1) 369 | return err == io.EOF 370 | } 371 | -------------------------------------------------------------------------------- /mitm_handler_test.go: -------------------------------------------------------------------------------- 1 | package mps 2 | 3 | import ( 4 | "crypto/tls" 5 | "crypto/x509" 6 | "io" 7 | "net/http" 8 | "net/http/httptest" 9 | "net/url" 10 | "testing" 11 | 12 | "github.com/stretchr/testify/assert" 13 | "github.com/telanflow/mps/cert" 14 | ) 15 | 16 | func TestNewMitmHandler(t *testing.T) { 17 | mitmHandler := NewMitmHandler() 18 | mitmSrv := httptest.NewServer(mitmHandler) 19 | defer mitmSrv.Close() 20 | 21 | clientCertPool := x509.NewCertPool() 22 | ok := clientCertPool.AppendCertsFromPEM([]byte(cert.CertPEM)) 23 | if !ok { 24 | t.Fatal("failed to parse root 
certificate")
25 | } 26 | 27 | req, _ := http.NewRequest(http.MethodGet, "https://httpbin.org/get", nil) 28 | client := &http.Client{Transport: &http.Transport{ 29 | Proxy: func(r *http.Request) (*url.URL, error) { 30 | return url.Parse(mitmSrv.URL) 31 | }, 32 | TLSClientConfig: &tls.Config{ 33 | Certificates: []tls.Certificate{cert.DefaultCertificate}, 34 | ClientAuth: tls.RequireAndVerifyClientCert, 35 | RootCAs: clientCertPool, 36 | }, 37 | }} 38 | 39 | resp, err := client.Do(req) 40 | if err != nil { 41 | t.Fatal(err) 42 | } 43 | defer resp.Body.Close() 44 | 45 | body, _ := io.ReadAll(resp.Body) 46 | 47 | asserts := assert.New(t) 48 | asserts.Equal(200, resp.StatusCode, "response status code not equal 200") 49 | asserts.Equal(int64(len(body)), resp.ContentLength) 50 | } 51 | -------------------------------------------------------------------------------- /mps.go: -------------------------------------------------------------------------------- 1 | package mps 2 | 3 | import ( 4 | "math/rand" 5 | "sync" 6 | "time" 7 | ) 8 | 9 | var (
10 | // global random numbers for MPS. Go v1.20 11 | // mpsRand is shared by concurrently running handlers (signHost draws certificate 12 | // serial numbers from it), so it is backed by a mutex-guarded source: sources created 13 | // by rand.NewSource are not safe for concurrent use. NOTE: only methods built on 14 | // Int63 become safe; do not call mpsRand.Read or mpsRand.Seed concurrently. 15 | mpsRand = rand.New(&lockedRandSource{src: rand.NewSource(time.Now().UnixNano())}) 16 | ) 17 | 18 | // lockedRandSource is a rand.Source that serializes every call with a mutex. 19 | type lockedRandSource struct { 20 | mu sync.Mutex 21 | src rand.Source 22 | } 23 | 24 | // Int63 implements rand.Source. 25 | func (s *lockedRandSource) Int63() int64 { 26 | s.mu.Lock() 27 | defer s.mu.Unlock() 28 | return s.src.Int63() 29 | } 30 | 31 | // Seed implements rand.Source. 32 | func (s *lockedRandSource) Seed(seed int64) { 33 | s.mu.Lock() 34 | defer s.mu.Unlock() 35 | s.src.Seed(seed) 36 | } 37 | -------------------------------------------------------------------------------- /pool/buffer.go: -------------------------------------------------------------------------------- 1 | package pool 2 | 3 | import "sync" 4 | 5 | // DefaultBuffer is a shared Buffer whose newly allocated slices are 2048 bytes long. 6 | var DefaultBuffer = NewBuffer(2048) 7 | 8 | // Buffer is a sync.Pool-backed pool of byte slices; it has the Get/Put method set of httputil.BufferPool. 9 | type Buffer struct { 10 | pl *sync.Pool 11 | size int 12 | } 13 | 14 | // NewBuffer creates a Buffer whose newly allocated slices are size bytes long. 15 | func NewBuffer(size int) *Buffer { 16 | bufPool := &Buffer{ 17 | pl: nil, 18 | size: size, 19 | } 20 | bufPool.pl = &sync.Pool{ 21 | New: bufPool.newPl, 22 | } 23 | return bufPool 24 | } 25 | 26 | // Get returns a slice from the pool, allocating a new one when the pool is empty. 27 | func (b *Buffer) Get() []byte { 28 | return b.pl.Get().([]byte) 29 | } 30 | 31 | // Put hands buf back to the pool for reuse. 
32 | func (b *Buffer) Put(buf []byte) { 33 | b.pl.Put(buf) 34 | } 35 | 36 | func (b *Buffer) newPl() interface{} { 37 | return make([]byte, b.size) 38 | } 39 | -------------------------------------------------------------------------------- /pool/conn_container.go: -------------------------------------------------------------------------------- 1 | package pool 2 | 3 | import "net" 4 | 5 | // ConnContainer is the interface of a connection pool. 6 | type ConnContainer interface { 7 | // Get returns an idle net.Conn for addr. 8 | Get(addr string) (net.Conn, error) 9 | 10 | // Put places an idle net.Conn into the pool. 11 | Put(conn net.Conn) error 12 | 13 | // Release shuts the connection pool down. 14 | Release() error 15 | } 16 | -------------------------------------------------------------------------------- /pool/conn_options.go: -------------------------------------------------------------------------------- 1 | package pool 2 | 3 | import "time" 4 | 5 | var DefaultConnOptions = &ConnOptions{ 6 | IdleMaxCap: 30, 7 | Timeout: 90 * time.Second, 8 | } 9 | 10 | // ConnOptions configures a ConnProvider. 11 | type ConnOptions struct { 12 | // IdleMaxCap is the maximum number of idle connections kept per remote address 13 | IdleMaxCap int 14 | 15 | // Timeout specifies how long the connection will timeout 16 | Timeout time.Duration 17 | } 18 |
-------------------------------------------------------------------------------- /pool/conn_provider.go: -------------------------------------------------------------------------------- 1 | package pool 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | "io" 7 | "net" 8 | "sync" 9 | "sync/atomic" 10 | "time" 11 | ) 12 | 13 | var DefaultConnProvider = NewConnProvider(DefaultConnOptions) 14 | 15 | // ConnProvider is a connection pool, it implements ConnContainer 16 | type ConnProvider struct { 17 | mu sync.RWMutex 18 | idleConnMap map[string]chan net.Conn 19 | options *ConnOptions 20 | closed int32 21 | } 22 | 23 | // Create a ConnProvider 24 | func NewConnProvider(opt *ConnOptions) *ConnProvider { 25 | return &ConnProvider{ 26 | options: opt, 27 | mu: sync.RWMutex{}, 28 | idleConnMap: make(map[string]chan net.Conn), 29 | } 30 | } 31 | 32 | // Get returned a idle net.Conn 33 | func (p *ConnProvider) Get(addr string) (net.Conn, error) { 34 | closed := atomic.LoadInt32(&p.closed) 35 | if closed == 1 { 36 | return nil, errors.New("pool is closed") 37 | } 38 | 39 | p.mu.Lock() 40 | if _, ok := p.idleConnMap[addr]; !ok { 41 | p.mu.Unlock() 42 | return nil, errors.New("no idle conn") 43 | } 44 | p.mu.Unlock() 45 | 46 | RETRY: 47 | select { 48 | case conn := <-p.idleConnMap[addr]: 49 | // Getting a net.Conn requires verifying that the net.Conn is valid 50 | _, err := conn.Read([]byte{}) 51 | if err != nil || err == io.EOF { 52 | // conn is close Or timeout 53 | _ = conn.Close() 54 | goto RETRY 55 | } 56 | return conn, nil 57 | default: 58 | return nil, errors.New("no idle conn") 59 | } 60 | } 61 | 62 | // Put place a idle net.Conn into the pool 63 | func (p *ConnProvider) Put(conn net.Conn) error { 64 | closed := atomic.LoadInt32(&p.closed) 65 | if closed == 1 { 66 | return errors.New("pool is closed") 67 | } 68 | 69 | addr := conn.RemoteAddr().String() 70 | 71 | p.mu.Lock() 72 | if _, ok := p.idleConnMap[addr]; !ok { 73 | p.idleConnMap[addr] = make(chan net.Conn, 
p.options.IdleMaxCap) 74 | } 75 | p.mu.Unlock() 76 | 77 | // set conn timeout 78 | // The timeout will be verified at the next `Get()` 79 | err := conn.SetDeadline(time.Now().Add(p.options.Timeout)) 80 | if err != nil { 81 | return err 82 | } 83 | 84 | select { 85 | case p.idleConnMap[addr] <- conn: 86 | return nil 87 | default: 88 | return fmt.Errorf("beyond max capacity") 89 | } 90 | } 91 | 92 | // Release connection pool 93 | func (p *ConnProvider) Release() error { 94 | closed := atomic.LoadInt32(&p.closed) 95 | if closed == 1 { 96 | return errors.New("pool is closed") 97 | } 98 | 99 | atomic.StoreInt32(&p.closed, 1) 100 | for _, connChan := range p.idleConnMap { 101 | close(connChan) 102 | for conn, ok := <-connChan; ok; { 103 | _ = conn.Close() 104 | } 105 | } 106 | return nil 107 | } 108 | -------------------------------------------------------------------------------- /reverse_handler.go: -------------------------------------------------------------------------------- 1 | package mps 2 | 3 | import ( 4 | "bytes" 5 | "io" 6 | "net/http" 7 | "net/http/httputil" 8 | "strconv" 9 | 10 | "github.com/telanflow/mps/pool" 11 | ) 12 | 13 | // ReverseHandler is a reverse proxy server implementation 14 | type ReverseHandler struct { 15 | Ctx *Context 16 | BufferPool httputil.BufferPool 17 | } 18 | 19 | // NewReverseHandler Create a reverse proxy 20 | func NewReverseHandler() *ReverseHandler { 21 | return &ReverseHandler{ 22 | Ctx: NewContext(), 23 | BufferPool: pool.DefaultBuffer, 24 | } 25 | } 26 | 27 | // Standard net/http function. 
You can use it alone 28 | func (reverse *ReverseHandler) ServeHTTP(rw http.ResponseWriter, req *http.Request) { 29 | // Copying a Context preserves the Transport, Middleware 30 | ctx := reverse.Ctx.WithRequest(req) 31 | resp, err := ctx.Next(req) 32 | if err != nil { 33 | http.Error(rw, err.Error(), 502) 34 | return 35 | } 36 | defer resp.Body.Close() 37 | 38 | var ( 39 | // Body buffer 40 | buffer = new(bytes.Buffer) 41 | // Body size 42 | bufferSize int64 43 | ) 44 | 45 | buf := reverse.buffer().Get() 46 | bufferSize, err = io.CopyBuffer(buffer, resp.Body, buf) 47 | reverse.buffer().Put(buf) 48 | if err != nil { 49 | http.Error(rw, err.Error(), 502) 50 | return 51 | } 52 | 53 | resp.ContentLength = bufferSize 54 | resp.Header.Set("Content-Length", strconv.Itoa(int(bufferSize))) 55 | copyHeaders(rw.Header(), resp.Header, reverse.Ctx.KeepDestinationHeaders) 56 | rw.WriteHeader(resp.StatusCode) 57 | _, err = buffer.WriteTo(rw) 58 | } 59 | 60 | // Use registers an Middleware to proxy 61 | func (reverse *ReverseHandler) Use(middleware ...Middleware) { 62 | reverse.Ctx.Use(middleware...) 63 | } 64 | 65 | // UseFunc registers an MiddlewareFunc to proxy 66 | func (reverse *ReverseHandler) UseFunc(fus ...MiddlewareFunc) { 67 | reverse.Ctx.UseFunc(fus...) 
68 | } 69 | 70 | // OnRequest filter requests through Filters 71 | func (reverse *ReverseHandler) OnRequest(filters ...Filter) *ReqFilterGroup { 72 | return &ReqFilterGroup{ctx: reverse.Ctx, filters: filters} 73 | } 74 | 75 | // OnResponse filter response through Filters 76 | func (reverse *ReverseHandler) OnResponse(filters ...Filter) *RespFilterGroup { 77 | return &RespFilterGroup{ctx: reverse.Ctx, filters: filters} 78 | } 79 | 80 | // Get buffer pool 81 | func (reverse *ReverseHandler) buffer() httputil.BufferPool { 82 | if reverse.BufferPool != nil { 83 | return reverse.BufferPool 84 | } 85 | return pool.DefaultBuffer 86 | } 87 | 88 | // Transport 89 | func (reverse *ReverseHandler) Transport() *http.Transport { 90 | return reverse.Ctx.Transport 91 | } 92 | -------------------------------------------------------------------------------- /reverse_handler_test.go: -------------------------------------------------------------------------------- 1 | package mps 2 | 3 | import ( 4 | "io" 5 | "net/http" 6 | "net/http/httptest" 7 | "net/url" 8 | "strconv" 9 | "testing" 10 | 11 | "github.com/stretchr/testify/assert" 12 | ) 13 | 14 | func TestNewReverseHandler(t *testing.T) { 15 | srv := newTestServer() 16 | defer srv.Close() 17 | 18 | reverseHandler := NewReverseHandler() 19 | proxySrv := httptest.NewServer(reverseHandler) 20 | defer proxySrv.Close() 21 | 22 | resp, err := HttpGet(srv.URL, func(r *http.Request) (*url.URL, error) { 23 | return url.Parse(proxySrv.URL) 24 | }) 25 | if err != nil { 26 | t.Fatal(err) 27 | } 28 | defer resp.Body.Close() 29 | 30 | body, _ := io.ReadAll(resp.Body) 31 | bodySize := len(body) 32 | contentLength, _ := strconv.Atoi(resp.Header.Get("Content-Length")) 33 | 34 | asserts := assert.New(t) 35 | asserts.Equal(resp.StatusCode, 200, "statusCode should be equal 200") 36 | asserts.Equal(bodySize, contentLength, "Content-Length should be equal "+strconv.Itoa(bodySize)) 37 | asserts.Equal(int64(bodySize), resp.ContentLength) 38 | } 39 | 
--------------------------------------------------------------------------------
/transport.go:
--------------------------------------------------------------------------------
1 | package mps
2 | 
3 | import (
4 | 	"crypto/tls"
5 | 	"net"
6 | 	"net/http"
7 | 	"time"
8 | )
9 | 
10 | // DefaultTransport holds the default http.Transport settings.
11 | //
12 | // The dialer does not set DualStack: that field has been deprecated since
13 | // Go 1.12, where RFC 6555 "Fast Fallback" became the default dialing
14 | // behavior, so leaving it out changes nothing.
15 | //
16 | // NOTE(review): InsecureSkipVerify disables verification of upstream TLS
17 | // certificates for every request sent through this transport. Confirm this
18 | // is an intended default rather than a development convenience.
19 | var DefaultTransport = &http.Transport{
20 | 	DialContext: (&net.Dialer{
21 | 		Timeout:   15 * time.Second,
22 | 		KeepAlive: 30 * time.Second,
23 | 	}).DialContext,
24 | 	ForceAttemptHTTP2:     true,
25 | 	MaxIdleConns:          100,
26 | 	IdleConnTimeout:       90 * time.Second,
27 | 	TLSHandshakeTimeout:   10 * time.Second,
28 | 	ExpectContinueTimeout: 1 * time.Second,
29 | 	TLSClientConfig:       &tls.Config{InsecureSkipVerify: true},
30 | 	Proxy:                 http.ProxyFromEnvironment,
31 | }
32 | 
--------------------------------------------------------------------------------
/tunnel_handler.go:
--------------------------------------------------------------------------------
1 | package mps
2 | 
3 | import (
4 | 	"context"
5 | 	"io"
6 | 	"net"
7 | 	"net/http"
8 | 	"net/http/httputil"
9 | 	"regexp"
10 | 	"time"
11 | 
12 | 	"github.com/telanflow/mps/pool"
13 | )
14 | 
15 | // HttpTunnelOk and HttpTunnelFail are the raw responses written to a
16 | // hijacked client when a tunnel is established or cannot be set up.
17 | var (
18 | 	HttpTunnelOk   = []byte("HTTP/1.0 200 Connection Established\r\n\r\n")
19 | 	HttpTunnelFail = []byte("HTTP/1.1 502 Bad Gateway\r\n\r\n")
20 | 	hasPort        = regexp.MustCompile(`:\d+$`)
21 | )
22 | 
23 | // TunnelHandler is the tunnel proxy type. It implements http.Handler.
24 | type TunnelHandler struct {
25 | 	Ctx           *Context
26 | 	BufferPool    httputil.BufferPool
27 | 	ConnContainer pool.ConnContainer
28 | }
29 | 
30 | // NewTunnelHandler creates a tunnel handler.
31 | func NewTunnelHandler() *TunnelHandler {
32 | 	return &TunnelHandler{
33 | 		Ctx:        NewContext(),
34 | 		BufferPool: pool.DefaultBuffer,
35 | 	}
36 | }
37 | 
38 | // NewTunnelHandlerWithContext creates a tunnel handler with the given Context.
39 | func NewTunnelHandlerWithContext(ctx *Context) *TunnelHandler {
40 | 	return &TunnelHandler{
41 | 		Ctx:        ctx,
42 | 		BufferPool: pool.DefaultBuffer,
43 | 	}
44 | }
45 | 
46 | // ServeHTTP implements http.Handler, so the handler can also be used on its
47 | // own.
48 | //
49 | // The request first runs through the middleware chain. If that fails with an
50 | // error other than MethodNotSupportErr, the middleware's response is relayed
51 | // to the client, or 502 Bad Gateway is returned when there is none.
52 | // Otherwise the client connection is hijacked and bytes are copied in both
53 | // directions between it and the target (or the cascade proxy returned by
54 | // Transport.Proxy) until either side finishes.
55 | func (tunnel *TunnelHandler) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
56 | 	// execution middleware
57 | 	ctx := tunnel.Ctx.WithRequest(req)
58 | 	resp, err := ctx.Next(req)
59 | 	if err != nil && err != MethodNotSupportErr {
60 | 		if resp == nil {
61 | 			// Returning without writing anything would make net/http answer
62 | 			// with an empty 200, which a client reads as "tunnel established".
63 | 			http.Error(rw, err.Error(), http.StatusBadGateway)
64 | 			return
65 | 		}
66 | 		copyHeaders(rw.Header(), resp.Header, tunnel.Ctx.KeepDestinationHeaders)
67 | 		rw.WriteHeader(resp.StatusCode)
68 | 		if resp.Body != nil {
69 | 			buf := tunnel.buffer().Get()
70 | 			_, _ = io.CopyBuffer(rw, resp.Body, buf)
71 | 			tunnel.buffer().Put(buf)
72 | 			_ = resp.Body.Close()
73 | 		}
74 | 		return
75 | 	}
76 | 	// A response returned alongside a nil or MethodNotSupportErr error is not
77 | 	// used below; close its body as net/http requires for every response.
78 | 	if resp != nil && resp.Body != nil {
79 | 		_ = resp.Body.Close()
80 | 	}
81 | 
82 | 	// hijacker connection
83 | 	proxyClient, err := hijacker(rw)
84 | 	if err != nil {
85 | 		http.Error(rw, err.Error(), http.StatusBadGateway)
86 | 		return
87 | 	}
88 | 
89 | 	targetAddr := hostAndPort(req.URL.Host)
90 | 	isCascadeProxy := false
91 | 	if tunnel.Ctx.Transport != nil && tunnel.Ctx.Transport.Proxy != nil {
92 | 		proxyURL, proxyErr := tunnel.Ctx.Transport.Proxy(req)
93 | 		if proxyErr != nil {
94 | 			ConnError(proxyClient)
95 | 			return
96 | 		}
97 | 		if proxyURL != nil {
98 | 			// connect addr eg. "localhost:80"
99 | 			targetAddr = hostAndPort(proxyURL.Host)
100 | 			isCascadeProxy = true
101 | 		}
102 | 	}
103 | 
104 | 	// Reuse an idle connection from the ConnContainer when one is available,
105 | 	// otherwise dial the target.
106 | 	targetConn, err := tunnel.connContainer().Get(targetAddr)
107 | 	if err != nil {
108 | 		targetConn, err = tunnel.ConnectDial("tcp", targetAddr)
109 | 		if err != nil {
110 | 			ConnError(proxyClient)
111 | 			return
112 | 		}
113 | 	}
114 | 
115 | 	// The cascade proxy needs to forward the request
116 | 	if isCascadeProxy {
117 | 		// The cascade proxy needs to send it as-is
118 | 		_ = req.Write(targetConn)
119 | 	} else {
120 | 		// Tell client that the tunnel is ready
121 | 		_, _ = proxyClient.Write(HttpTunnelOk)
122 | 	}
123 | 
124 | 	// Copy in both directions. Whichever direction stops first closes both
125 | 	// connections so the other copy unblocks, and ServeHTTP returns only
126 | 	// after both copies have stopped.
127 | 	//
128 | 	// targetConn is closed rather than put back into the ConnContainer: after
129 | 	// carrying the client's raw tunnel traffic it is in an arbitrary protocol
130 | 	// state (or already closed by one side) and must not be handed to another
131 | 	// client.
132 | 	done := make(chan struct{})
133 | 	go func() {
134 | 		defer close(done)
135 | 		buf := tunnel.buffer().Get()
136 | 		_, _ = io.CopyBuffer(targetConn, proxyClient, buf)
137 | 		tunnel.buffer().Put(buf)
138 | 		_ = proxyClient.Close()
139 | 		_ = targetConn.Close()
140 | 	}()
141 | 	buf := tunnel.buffer().Get()
142 | 	_, _ = io.CopyBuffer(proxyClient, targetConn, buf)
143 | 	tunnel.buffer().Put(buf)
144 | 	_ = proxyClient.Close()
145 | 	_ = targetConn.Close()
146 | 	<-done
147 | }
148 | 
149 | // Use registers a Middleware with the proxy.
150 | func (tunnel *TunnelHandler) Use(middleware ...Middleware) {
151 | 	tunnel.Ctx.Use(middleware...)
152 | }
153 | 
154 | // UseFunc registers a MiddlewareFunc with the proxy.
155 | func (tunnel *TunnelHandler) UseFunc(fus ...MiddlewareFunc) {
156 | 	tunnel.Ctx.UseFunc(fus...)
157 | }
158 | 
159 | // OnRequest filter requests through Filters
160 | func (tunnel *TunnelHandler) OnRequest(filters ...Filter) *ReqFilterGroup {
161 | 	return &ReqFilterGroup{ctx: tunnel.Ctx, filters: filters}
162 | }
163 | 
164 | // OnResponse filter response through Filters
165 | func (tunnel *TunnelHandler) OnResponse(filters ...Filter) *RespFilterGroup {
166 | 	return &RespFilterGroup{ctx: tunnel.Ctx, filters: filters}
167 | }
168 | 
169 | // ConnectDial dials addr with the Transport's DialContext when one is
170 | // configured, and with net.DialTimeout (30s timeout) otherwise.
171 | func (tunnel *TunnelHandler) ConnectDial(network, addr string) (net.Conn, error) {
172 | 	if tunnel.Ctx.Transport != nil && tunnel.Ctx.Transport.DialContext != nil {
173 | 		return tunnel.Ctx.Transport.DialContext(tunnel.context(), network, addr)
174 | 	}
175 | 	return net.DialTimeout(network, addr, 30*time.Second)
176 | }
177 | 
178 | // Transport returns the http.Transport of the handler's Context.
179 | func (tunnel *TunnelHandler) Transport() *http.Transport {
180 | 	return tunnel.Ctx.Transport
181 | }
182 | 
183 | // context returns the handler's Context.Context, or context.Background()
184 | // when none is set.
185 | func (tunnel *TunnelHandler) context() context.Context {
186 | 	if tunnel.Ctx.Context != nil {
187 | 		return tunnel.Ctx.Context
188 | 	}
189 | 	return context.Background()
190 | }
191 | 
192 | // buffer returns the configured BufferPool, or pool.DefaultBuffer when unset.
193 | func (tunnel *TunnelHandler) buffer() httputil.BufferPool {
194 | 	if tunnel.BufferPool != nil {
195 | 		return tunnel.BufferPool
196 | 	}
197 | 	return pool.DefaultBuffer
198 | }
199 | 
200 | // connContainer returns the configured ConnContainer, or
201 | // pool.DefaultConnProvider when unset.
202 | func (tunnel *TunnelHandler) connContainer() pool.ConnContainer {
203 | 	if tunnel.ConnContainer != nil {
204 | 		return tunnel.ConnContainer
205 | 	}
206 | 	return pool.DefaultConnProvider
207 | }
208 | 
209 | // hostAndPort appends ":80" to addr unless it already ends in ":<digits>".
210 | func hostAndPort(addr string) string {
211 | 	if !hasPort.MatchString(addr) {
212 | 		addr += ":80"
213 | 	}
214 | 	return addr
215 | }
216 | 
217 | // ConnError writes HttpTunnelFail to w and closes it.
218 | func ConnError(w net.Conn) {
219 | 	_, _ = w.Write(HttpTunnelFail)
220 | 	_ = w.Close()
221 | }
222 | 
--------------------------------------------------------------------------------
/tunnel_handler_test.go:
--------------------------------------------------------------------------------
1 | package mps
2 | 
3 | import (
4 | 	"net/http"
5 | 	"net/http/httptest"
6 | 	"net/url"
7 | 	"testing"
8 | 
9 | 	"github.com/stretchr/testify/assert"
10 | )
11 | 
12 | func TestNewTunnelHandler(t *testing.T) {
13 | 	srv := newTestServer()
14 | 	defer srv.Close()
15 | 
16 | 	tunnel := NewTunnelHandler()
17 | 	//tunnel.Transport().Proxy = func(r *http.Request) (*url.URL, error) {
18 | 	//	return url.Parse("http://127.0.0.1:7890")
19 | 	//}
20 | 	tunnelSrv := httptest.NewServer(tunnel)
21 | 	defer tunnelSrv.Close()
22 | 
23 | 	resp, err := HttpGet(srv.URL, func(r *http.Request) (*url.URL, error) {
24 | 		return url.Parse(tunnelSrv.URL)
25 | 	})
26 | 	if err != nil {
27 | 		t.Fatal(err)
28 | 	}
29 | 	resp.Body.Close()
30 | 
31 | 	asserts := assert.New(t)
32 | 	asserts.Equal(resp.StatusCode, 200)
33 | }
34 | 
--------------------------------------------------------------------------------
/websocket_handler.go:
--------------------------------------------------------------------------------
1 | package mps
2 | 
3 | import (
4 | 	"bufio"
5 | 	"context"
6 | 	"io"
7 | 	"net"
8 | 	"net/http"
9 | 	"net/http/httputil"
10 | 	"strings"
11 | 	"time"
12 | 
13 | 	"github.com/telanflow/mps/pool"
14 | )
15 | 
16 | // WebsocketHandler is the websocket proxy type. It implements http.Handler.
17 | type WebsocketHandler struct {
18 | 	Ctx        *Context
19 | 	BufferPool httputil.BufferPool
20 | }
21 | 
22 | // NewWebsocketHandler creates a websocket handler.
23 | func NewWebsocketHandler() *WebsocketHandler {
24 | 	return &WebsocketHandler{
25 | 		Ctx:        NewContext(),
26 | 		BufferPool: pool.DefaultBuffer,
27 | 	}
28 | }
29 | 
30 | // NewWebsocketHandlerWithContext creates a websocket handler with the given
31 | // Context.
32 | func NewWebsocketHandlerWithContext(ctx *Context) *WebsocketHandler {
33 | 	return &WebsocketHandler{
34 | 		Ctx:        ctx,
35 | 		BufferPool: pool.DefaultBuffer,
36 | 	}
37 | }
38 | 
39 | // ServeHTTP implements http.Handler, so the handler can also be used on its
40 | // own. Requests that are not websocket upgrades are ignored.
41 | //
42 | // The client connection is hijacked, the handshake request is forwarded to
43 | // the target (or to the address returned by Transport.Proxy), the handshake
44 | // response is relayed back, and bytes are then copied in both directions
45 | // until either side finishes.
46 | func (ws *WebsocketHandler) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
47 | 	// Whether to upgrade to Websocket
48 | 	if !isWebSocketRequest(req) {
49 | 		return
50 | 	}
51 | 
52 | 	// hijacker connection
53 | 	clientConn, err := hijacker(rw)
54 | 	if err != nil {
55 | 		http.Error(rw, err.Error(), http.StatusBadGateway)
56 | 		return
57 | 	}
58 | 	// After hijacking, net/http no longer manages the connection; close it
59 | 	// on every exit path so a failed handshake does not leak it.
60 | 	defer clientConn.Close()
61 | 
62 | 	targetAddr := hostAndPort(req.URL.Host)
63 | 	if ws.Ctx.Transport != nil && ws.Ctx.Transport.Proxy != nil {
64 | 		proxyURL, proxyErr := ws.Ctx.Transport.Proxy(req)
65 | 		if proxyErr != nil {
66 | 			ConnError(clientConn)
67 | 			return
68 | 		}
69 | 		if proxyURL != nil {
70 | 			// connect addr eg. "localhost:443"
71 | 			targetAddr = hostAndPort(proxyURL.Host)
72 | 		}
73 | 	}
74 | 
75 | 	targetConn, err := ws.ConnectDial("tcp", targetAddr)
76 | 	if err != nil {
77 | 		// Tell the client the upstream is unreachable instead of leaving it
78 | 		// waiting for a handshake response.
79 | 		ConnError(clientConn)
80 | 		return
81 | 	}
82 | 	defer targetConn.Close()
83 | 
84 | 	// Perform handshake
85 | 	// write handshake request to target
86 | 	if err = req.Write(targetConn); err != nil {
87 | 		return
88 | 	}
89 | 
90 | 	// Read handshake response from target
91 | 	targetReader := bufio.NewReader(targetConn)
92 | 	resp, err := http.ReadResponse(targetReader, req)
93 | 	if err != nil {
94 | 		return
95 | 	}
96 | 
97 | 	// Proxy handshake back to client
98 | 	if err = resp.Write(clientConn); err != nil {
99 | 		return
100 | 	}
101 | 
102 | 	// Proxy ws connection. Whichever direction stops first closes both
103 | 	// connections so the other copy unblocks; ServeHTTP waits for both.
104 | 	// The target -> client direction reads from targetReader, not
105 | 	// targetConn: ReadResponse may already have buffered bytes the target
106 | 	// sent right after the handshake, and reading targetConn directly would
107 | 	// drop them.
108 | 	done := make(chan struct{})
109 | 	go func() {
110 | 		defer close(done)
111 | 		buf := ws.buffer().Get()
112 | 		_, _ = io.CopyBuffer(targetConn, clientConn, buf)
113 | 		ws.buffer().Put(buf)
114 | 		_ = clientConn.Close()
115 | 		_ = targetConn.Close()
116 | 	}()
117 | 	buf := ws.buffer().Get()
118 | 	_, _ = io.CopyBuffer(clientConn, targetReader, buf)
119 | 	ws.buffer().Put(buf)
120 | 	_ = clientConn.Close()
121 | 	_ = targetConn.Close()
122 | 	<-done
123 | }
124 | 
125 | // ConnectDial dials addr with the Transport's DialContext when one is
126 | // configured, and with net.DialTimeout (30s timeout) otherwise.
127 | func (ws *WebsocketHandler) ConnectDial(network, addr string) (net.Conn, error) {
128 | 	if ws.Ctx.Transport != nil && ws.Ctx.Transport.DialContext != nil {
129 | 		return ws.Ctx.Transport.DialContext(ws.context(), network, addr)
130 | 	}
131 | 	return net.DialTimeout(network, addr, 30*time.Second)
132 | }
133 | 
134 | // Transport returns the http.Transport of the handler's Context.
135 | func (ws *WebsocketHandler) Transport() *http.Transport {
136 | 	return ws.Ctx.Transport
137 | }
138 | 
139 | // context returns the handler's Context.Context, or context.Background()
140 | // when none is set.
141 | func (ws *WebsocketHandler) context() context.Context {
142 | 	if ws.Ctx.Context != nil {
143 | 		return ws.Ctx.Context
144 | 	}
145 | 	return context.Background()
146 | }
147 | 
148 | // buffer returns the configured BufferPool, or pool.DefaultBuffer when unset.
149 | func (ws *WebsocketHandler) buffer() httputil.BufferPool {
150 | 	if ws.BufferPool != nil {
151 | 		return ws.BufferPool
152 | 	}
153 | 	return pool.DefaultBuffer
154 | }
155 | 
156 | // isWebSocketRequest reports whether req asks for an upgrade to websocket.
157 | func isWebSocketRequest(req *http.Request) bool {
158 | 	return headerContains(req.Header, "Connection", "upgrade") &&
159 | 		headerContains(req.Header, "Upgrade", "websocket")
160 | }
161 | 
162 | // headerContains reports whether any comma-separated token of header[name]
163 | // equals value, ignoring case and surrounding whitespace.
164 | func headerContains(header http.Header, name string, value string) bool {
165 | 	for _, v := range header[name] {
166 | 		for _, s := range strings.Split(v, ",") {
167 | 			if strings.EqualFold(value, strings.TrimSpace(s)) {
168 | 				return true
169 | 			}
170 | 		}
171 | 	}
172 | 	return false
173 | }
174 | 
--------------------------------------------------------------------------------
/websocket_handler_test.go:
--------------------------------------------------------------------------------
1 | package mps
2 | 
3 | import (
4 | 	"log"
5 | 	"net/http"
6 | 	"net/http/httptest"
7 | 	"net/url"
8 | 	"strings"
9 | 	"testing"
10 | 
11 | 	"github.com/gorilla/websocket"
12 | )
13 | 
14 | var upgrader = websocket.Upgrader{}
15 | 
16 | // create a test websocket server
17 | func newTestWebsocketServer() *httptest.Server {
18 | 	return httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
19 | 		c, err := upgrader.Upgrade(rw, req, nil)
20 | 		if err != nil {
21 | 			return
22 | 		}
23 | 		defer c.Close()
24 | 		for {
25 | 			mt, message, err := c.ReadMessage()
26 | 			if err != nil {
27 | 				break
28 | 			}
29 | 			err = c.WriteMessage(mt, message)
30 | 			if err != nil {
31 | 				break
32 | 			}
33 | 		}
34 | 	}))
35 | }
36 | 
37 | func TestNewWebsocketHandler(t *testing.T) {
38 | 	// create endPoint websocket server
39 | 	srv := newTestWebsocketServer()
40 | 	defer srv.Close()
41 | 
42 | 	// Convert http://127.0.0.1 to ws://127.0.0.1
43 | 	endPoint := "ws" + strings.TrimPrefix(srv.URL, "http")
44 | 	log.Printf("endPoint: %s", endPoint)
45 | 
46 | 	// create a proxy websocket server
47 | 	wsHandler := NewWebsocketHandler()
48 | 	wsHandler.Transport().Proxy = func(request *http.Request) (*url.URL, error) {
49 | 		return url.Parse(endPoint)
50 | 	}
51 | 	proxySrv := httptest.NewServer(wsHandler)
52 | 	defer proxySrv.Close()
53 | 
54 | 	proxyWs := "ws" + strings.TrimPrefix(proxySrv.URL, "http")
55 | 	log.Printf("proxy: %s", proxyWs)
56 | 
57 | 	// Connect to the proxy websocket server
58 | 	client, _, err := websocket.DefaultDialer.Dial(proxyWs, nil)
59 | 	if err != nil {
60 | 		t.Fatalf("%v", err)
61 | 	}
62 | 	defer client.Close()
63 | 
64 | 	// Send message to server, read response and check to see if it's what we expect.
65 | 	for i := 0; i < 5; i++ {
66 | 		if err := client.WriteMessage(websocket.TextMessage, []byte("hello")); err != nil {
67 | 			t.Fatalf("send fail: %v", err)
68 | 		}
69 | 
70 | 		_, p, err := client.ReadMessage()
71 | 		if err != nil {
72 | 			t.Fatalf("read fail: %v", err)
73 | 		}
74 | 
75 | 		log.Printf("recv: %s", string(p))
76 | 		if string(p) != "hello" {
77 | 			t.Fatalf("bad message")
78 | 		}
79 | 	}
80 | }
81 | 
--------------------------------------------------------------------------------