├── .github
└── workflows
│ └── mps.yml
├── .gitignore
├── LICENSE
├── README.md
├── README_ZH.md
├── _examples
├── basic-auth
│ └── main.go
├── cascade-proxy
│ └── main.go
├── generateCert
│ ├── openssl-gen.sh
│ └── openssl.cnf
├── mitm-proxy
│ ├── README.md
│ ├── ca.crt
│ ├── ca.key
│ └── main.go
├── reverse-proxy
│ └── main.go
├── simple-http-proxy
│ └── main.go
└── websocket-proxy
│ └── main.go
├── cert
├── cert.go
├── container.go
└── mem_provider.go
├── chunked.go
├── context.go
├── counter_encryptor.go
├── counter_encryptor_test.go
├── filter.go
├── filter_group.go
├── forward_handler.go
├── forward_handler_test.go
├── go.mod
├── go.sum
├── handle.go
├── http_proxy.go
├── http_proxy_test.go
├── middleware.go
├── middleware
├── basicAuth.go
└── singleHostReverseProxy.go
├── mitm_handler.go
├── mitm_handler_test.go
├── mps.go
├── pool
├── buffer.go
├── conn_container.go
├── conn_options.go
└── conn_provider.go
├── reverse_handler.go
├── reverse_handler_test.go
├── transport.go
├── tunnel_handler.go
├── tunnel_handler_test.go
├── websocket_handler.go
└── websocket_handler_test.go
/.github/workflows/mps.yml:
--------------------------------------------------------------------------------
1 | name: MPS
2 |
3 | on:
4 | push:
5 | branches: [ master ]
6 | pull_request:
7 | branches: [ master ]
8 |
9 | jobs:
10 |
11 | build:
12 | name: Build
13 | runs-on: ubuntu-latest
14 | steps:
15 |
16 | - name: Set up Go 1.x
17 | uses: actions/setup-go@v2
18 | with:
19 | go-version: ^1.16
20 | id: go
21 |
22 | - name: Check out code into the Go module directory
23 | uses: actions/checkout@v2
24 |
25 | - name: Get dependencies
26 | run: |
27 | go get -v -t -d ./...
28 | if [ -f Gopkg.toml ]; then
29 | curl https://raw.githubusercontent.com/golang/dep/master/install.sh | sh
30 | dep ensure
31 | fi
32 |
33 | - name: Test
34 | run: go test -v .
35 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Binaries for programs and plugins
2 | *.exe
3 | *.exe~
4 | *.dll
5 | *.so
6 | *.dylib
7 |
8 | # Test binary, built with `go test -c`
9 | *.test
10 |
11 | # Output of the go coverage tool, specifically when used with LiteIDE
12 | *.out
13 |
14 | # Dependency directories (remove the comment below to include it)
15 | # vendor/
16 | .DS_Store
17 | .idea
18 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | BSD 3-Clause License
2 |
3 | Copyright (c) 2020, Telanflow
4 | All rights reserved.
5 |
6 | Redistribution and use in source and binary forms, with or without
7 | modification, are permitted provided that the following conditions are met:
8 |
9 | 1. Redistributions of source code must retain the above copyright notice, this
10 | list of conditions and the following disclaimer.
11 |
12 | 2. Redistributions in binary form must reproduce the above copyright notice,
13 | this list of conditions and the following disclaimer in the documentation
14 | and/or other materials provided with the distribution.
15 |
16 | 3. Neither the name of the copyright holder nor the names of its
17 | contributors may be used to endorse or promote products derived from
18 | this software without specific prior written permission.
19 |
20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
MPS
3 |
4 |
5 | English | [🇨🇳中文](README_ZH.md)
6 |
7 | ## 📖 Introduction
8 | 
9 | 
10 | 
11 | 
12 | [](https://github.com/telanflow/mps/LICENSE)
13 |
14 | MPS (middle-proxy-server) is a high-performance middle proxy library. It supports HTTP, HTTPS, WebSocket, ForwardProxy, ReverseProxy, TunnelProxy, and MitmProxy.
15 |
16 | ## 🚀 Features
17 | - [X] Http Proxy
18 | - [X] Https Proxy
19 | - [X] Forward Proxy
20 | - [X] Reverse Proxy
21 | - [X] Tunnel Proxy
22 | - [X] Mitm Proxy (Man-in-the-middle)
23 | - [X] WebSocket Proxy
24 |
25 | ## 🧰 Install
26 | ```
27 | go get -u github.com/telanflow/mps
28 | ```
29 |
30 | ## 🛠 How to use
31 | A simple proxy service
32 |
33 | ```go
34 | package main
35 |
36 | import (
37 | "github.com/telanflow/mps"
38 | "log"
39 | "net/http"
40 | )
41 |
42 | func main() {
43 | proxy := mps.NewHttpProxy()
44 | log.Fatal(http.ListenAndServe(":8080", proxy))
45 | }
46 | ```
47 |
48 | More [examples](https://github.com/telanflow/mps/tree/master/_examples)
49 |
50 | ## 🧬 Middleware
51 | Middleware can intercept requests and responses.
52 | Several middleware implementations are built in, including [BasicAuth](https://github.com/telanflow/mps/tree/master/middleware).
53 |
54 | ```go
55 | func main() {
56 | proxy := mps.NewHttpProxy()
57 |
58 | proxy.Use(mps.MiddlewareFunc(func(req *http.Request, ctx *mps.Context) (*http.Response, error) {
59 | log.Printf("[INFO] middleware -- %s %s", req.Method, req.URL)
60 | return ctx.Next(req)
61 | }))
62 |
63 | proxy.UseFunc(func(req *http.Request, ctx *mps.Context) (*http.Response, error) {
64 | log.Printf("[INFO] middleware -- %s %s", req.Method, req.URL)
65 | resp, err := ctx.Next(req)
66 | if err != nil {
67 | return nil, err
68 | }
69 | log.Printf("[INFO] resp -- %d", resp.StatusCode)
70 | return resp, err
71 | })
72 |
73 | log.Fatal(http.ListenAndServe(":8080", proxy))
74 | }
75 | ```
76 |
77 | ## ♻️ Filters
78 | Filters select requests and responses so that they can be processed in a uniform way.
79 | They are implemented on top of middleware.
80 |
81 | ```go
82 | func main() {
83 | proxy := mps.NewHttpProxy()
84 |
85 | // request Filter Group
86 | reqGroup := proxy.OnRequest(mps.FilterHostMatches(regexp.MustCompile("^.*$")))
87 | reqGroup.DoFunc(func(req *http.Request, ctx *mps.Context) (*http.Request, *http.Response) {
88 | log.Printf("[INFO] req -- %s %s", req.Method, req.URL)
89 | return req, nil
90 | })
91 |
92 | // response Filter Group
93 | respGroup := proxy.OnResponse()
94 | respGroup.DoFunc(func(resp *http.Response, err error, ctx *mps.Context) (*http.Response, error) {
95 | if err != nil {
96 | log.Printf("[ERRO] resp -- %s %v", ctx.Request.Method, err)
97 | return nil, err
98 | }
99 |
100 | log.Printf("[INFO] resp -- %d", resp.StatusCode)
101 | return resp, err
102 | })
103 |
104 | log.Fatal(http.ListenAndServe(":8080", proxy))
105 | }
106 | ```
107 |
108 | ## 📄 License
109 | Source code in `MPS` is available under the [BSD 3-Clause License](/LICENSE).
110 |
--------------------------------------------------------------------------------
/README_ZH.md:
--------------------------------------------------------------------------------
1 |
2 |
MPS
3 |
4 |
5 | [English](README.md) | 🇨🇳中文
6 |
7 | ## 📖 介绍
8 | 
9 | 
10 | 
11 | 
12 | [](https://github.com/telanflow/mps/LICENSE)
13 |
14 | MPS 是一个高性能的中间代理扩展库,支持 HTTP、HTTPS、Websocket、正向代理、反向代理、隧道代理、中间人代理 等代理方式。
15 |
16 | ## 🚀 特性
17 | - [X] Http代理
18 | - [X] Https代理
19 | - [X] 正向代理
20 | - [X] 反向代理
21 | - [X] 隧道代理
22 | - [X] 中间人代理 (MITM)
23 | - [X] WebSocket代理
24 |
25 | ## 🧰 安装
26 | ```
27 | go get -u github.com/telanflow/mps
28 | ```
29 |
30 | ## 🛠 如何使用
31 | 一个简单的HTTP代理服务
32 |
33 | ```go
34 | package main
35 |
36 | import (
37 | "github.com/telanflow/mps"
38 | "log"
39 | "net/http"
40 | )
41 |
42 | func main() {
43 | proxy := mps.NewHttpProxy()
44 | log.Fatal(http.ListenAndServe(":8080", proxy))
45 | }
46 | ```
47 |
48 | 更多 [范例](https://github.com/telanflow/mps/tree/master/_examples)
49 |
50 | ## 🧬 中间件
51 | 中间件可以拦截请求和响应,我们内置实现了多个中间件,包括 [BasicAuth](https://github.com/telanflow/mps/tree/master/middleware)
52 |
53 | ```go
54 | func main() {
55 | proxy := mps.NewHttpProxy()
56 |
57 | proxy.Use(mps.MiddlewareFunc(func(req *http.Request, ctx *mps.Context) (*http.Response, error) {
58 | log.Printf("[INFO] middleware -- %s %s", req.Method, req.URL)
59 | return ctx.Next(req)
60 | }))
61 |
62 | proxy.UseFunc(func(req *http.Request, ctx *mps.Context) (*http.Response, error) {
63 | log.Printf("[INFO] middleware -- %s %s", req.Method, req.URL)
64 | resp, err := ctx.Next(req)
65 | if err != nil {
66 | return nil, err
67 | }
68 | log.Printf("[INFO] resp -- %d", resp.StatusCode)
69 | return resp, err
70 | })
71 |
72 | log.Fatal(http.ListenAndServe(":8080", proxy))
73 | }
74 | ```
75 |
76 | ## ♻️ 过滤器
77 | 过滤器可以对请求和响应进行筛选,统一进行处理。
78 | 它基于中间件实现。
79 |
80 | ```go
81 | func main() {
82 | proxy := mps.NewHttpProxy()
83 |
84 | // request Filter Group
85 | reqGroup := proxy.OnRequest(mps.FilterHostMatches(regexp.MustCompile("^.*$")))
86 | reqGroup.DoFunc(func(req *http.Request, ctx *mps.Context) (*http.Request, *http.Response) {
87 | log.Printf("[INFO] req -- %s %s", req.Method, req.URL)
88 | return req, nil
89 | })
90 |
91 | // response Filter Group
92 | respGroup := proxy.OnResponse()
93 | respGroup.DoFunc(func(resp *http.Response, err error, ctx *mps.Context) (*http.Response, error) {
94 | if err != nil {
95 | log.Printf("[ERRO] resp -- %s %v", ctx.Request.Method, err)
96 | return nil, err
97 | }
98 |
99 | log.Printf("[INFO] resp -- %d", resp.StatusCode)
100 | return resp, err
101 | })
102 |
103 | log.Fatal(http.ListenAndServe(":8080", proxy))
104 | }
105 | ```
106 |
107 | ## 📄 开源许可
108 | `MPS`中的源代码在[BSD 3 License](/LICENSE)下可用。
109 |
--------------------------------------------------------------------------------
/_examples/basic-auth/main.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | "io"
5 | "log"
6 | "net/http"
7 | "net/url"
8 | "time"
9 |
10 | "github.com/telanflow/mps"
11 | "github.com/telanflow/mps/middleware"
12 | )
13 |
14 | // main demonstrates BasicAuth at both hops: the client authenticates to the
15 | // proxy (mps/mps) and the proxy authenticates to the endPoint (test/test).
16 | func main() {
17 | 	// endPoint server: rejects any request without the test/test credentials
18 | 	go http.ListenAndServe("localhost:8080", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
19 | 		// Basic Authentication: missing and wrong credentials get the same 401
20 | 		usr, pwd, ok := r.BasicAuth()
21 | 		if !ok || usr != "test" || pwd != "test" {
22 | 			w.WriteHeader(401)
23 | 			w.Write([]byte("401 Authentication Required"))
24 | 			return
25 | 		}
26 | 		w.Write([]byte("successful endPoint"))
27 | 	}))
28 |
29 | 	// proxy server
30 | 	proxy := mps.NewHttpProxy()
31 | 	// proxy BasicAuth: the validator accepts only mps/mps
32 | 	proxy.Use(middleware.BasicAuth("mps_realm", func(username, password string) bool {
33 | 		return username == "mps" && password == "mps"
34 | 	}))
35 | 	proxy.UseFunc(func(req *http.Request, ctx *mps.Context) (*http.Response, error) {
36 | 		// set endPoint BasicAuth
37 | 		// Or you can set the endPoint BasicAuth on the client
38 | 		req.SetBasicAuth("test", "test")
39 | 		return ctx.Next(req)
40 | 	})
41 | 	go http.ListenAndServe("localhost:8081", proxy)
42 |
43 | 	// wait proxy started
44 | 	time.Sleep(2 * time.Second)
45 |
46 | 	// send request
47 | 	// request ==> proxy ==> http://localhost:8080
48 | 	// response <== proxy <== http://localhost:8080
49 | 	http.DefaultClient.Transport = &http.Transport{
50 | 		Proxy: func(req *http.Request) (*url.URL, error) {
51 | 			// set proxy server BasicAuth
52 | 			middleware.SetBasicAuth(req, "mps", "mps")
53 |
54 | 			// set endPoint BasicAuth
55 | 			// Or you can set the endPoint to BasicAuth on the proxy server
56 | 			//req.SetBasicAuth("test", "test")
57 |
58 | 			return url.Parse("http://localhost:8081")
59 | 		},
60 | 	}
61 | 	// http.Get builds the GET request itself, so no construction error is dropped
62 | 	resp, err := http.Get("http://localhost:8080")
63 | 	if err != nil {
64 | 		log.Fatal(err)
65 | 	}
66 |
67 | 	// read and close before checking the error: log.Fatal would skip a deferred Close
68 | 	body, err := io.ReadAll(resp.Body)
69 | 	resp.Body.Close()
70 | 	if err != nil {
71 | 		log.Fatal(err)
72 | 	}
73 | 	log.Println(string(body))
74 | }
75 |
--------------------------------------------------------------------------------
/_examples/cascade-proxy/main.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | "io"
5 | "log"
6 | "net/http"
7 | "net/url"
8 | "time"
9 |
10 | "github.com/telanflow/mps"
11 | "github.com/telanflow/mps/middleware"
12 | )
13 |
14 | // main chains two BasicAuth-protected proxies in front of an endPoint server:
15 | // client ==> proxy2 ==> proxy1 ==> endPoint.
16 | func main() {
17 | 	// endPoint server
18 | 	go http.ListenAndServe("localhost:9990", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
19 | 		_, _ = w.Write([]byte("successful endPoint server"))
20 | 	}))
21 |
22 | 	// proxy server 1: the validator accepts only foo_1/bar_1
23 | 	proxy1 := mps.NewHttpProxy()
24 | 	proxy1.Ctx.KeepProxyHeaders = true
25 | 	proxy1.Use(middleware.BasicAuth("mps_realm_1", func(username, password string) bool {
26 | 		return username == "foo_1" && password == "bar_1"
27 | 	}))
28 | 	go http.ListenAndServe("localhost:9991", proxy1)
29 |
30 | 	// proxy server 2: accepts only foo_2/bar_2 and sends its traffic through proxy server 1
31 | 	proxy2 := mps.NewHttpProxy()
32 | 	proxy2.Ctx.KeepProxyHeaders = true
33 | 	proxy2.Use(middleware.BasicAuth("mps_realm_2", func(username, password string) bool {
34 | 		return username == "foo_2" && password == "bar_2"
35 | 	}))
36 | 	proxy2.Transport().Proxy = func(req *http.Request) (*url.URL, error) {
37 | 		middleware.SetBasicAuth(req, "foo_1", "bar_1")
38 | 		return url.Parse("http://localhost:9991")
39 | 	}
40 | 	go http.ListenAndServe("localhost:9992", proxy2)
41 |
42 | 	// wait proxy server started
43 | 	time.Sleep(2 * time.Second)
44 |
45 | 	// send request (http.Get builds the GET request itself, so no construction error is dropped)
46 | 	// request ==> proxy2 ==> proxy1 ==> http://localhost:9990
47 | 	// response <== proxy2 <== proxy1 <== http://localhost:9990
48 | 	http.DefaultClient.Transport = &http.Transport{
49 | 		Proxy: func(r *http.Request) (*url.URL, error) {
50 | 			middleware.SetBasicAuth(r, "foo_2", "bar_2")
51 | 			return url.Parse("http://localhost:9992")
52 | 		},
53 | 	}
54 | 	resp, err := http.Get("http://localhost:9990/")
55 | 	if err != nil {
56 | 		log.Fatal(err)
57 | 	}
58 | 	body, err := io.ReadAll(resp.Body)
59 | 	resp.Body.Close()
60 | 	if err != nil {
61 | 		log.Fatal(err)
62 | 	}
63 | 	log.Println(resp.Header)
64 | 	log.Println(string(body))
65 | }
66 |
--------------------------------------------------------------------------------
/_examples/generateCert/openssl-gen.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | set -ex
3 | # Generate the CA's 4096-bit RSA key (written AES-256 encrypted), then rewrite it without the passphrase.
4 | openssl genrsa -aes256 -passout pass:1 -out ca.key 4096
5 | openssl rsa -passin pass:1 -in ca.key -out ca.key.tmp
6 | mv ca.key.tmp ca.key
7 | # Create the self-signed CA certificate (valid for 7300 days) with the v3_ca extensions from openssl.cnf.
8 | openssl req -config openssl.cnf -key ca.key -new -x509 -days 7300 -sha256 -extensions v3_ca -out ca.crt
9 |
--------------------------------------------------------------------------------
/_examples/generateCert/openssl.cnf:
--------------------------------------------------------------------------------
1 | [ ca ]
2 | default_ca = CA_default
3 | [ CA_default ]
4 | default_md = sha256
5 | [ v3_ca ]
6 | subjectKeyIdentifier=hash
7 | authorityKeyIdentifier=keyid:always,issuer
8 | basicConstraints = critical,CA:true
9 | [ req ]
10 | distinguished_name = req_distinguished_name
11 | [ req_distinguished_name ]
12 | countryName = Country Name (2 letter code)
13 | countryName_default = CN
14 | countryName_min = 2
15 | countryName_max = 2
16 |
17 | stateOrProvinceName = State or Province Name (full name)
18 | stateOrProvinceName_default = ZheJiang
19 |
20 | localityName = Locality Name (eg, city)
21 | localityName_default = HangZhou
22 |
23 | 0.organizationName = Organization Name (eg, company)
24 | 0.organizationName_default = mps
25 |
26 | # we can do this but it is not needed normally :-)
27 | #1.organizationName = Second Organization Name (eg, company)
28 | #1.organizationName_default = World Wide Web Pty Ltd
29 |
30 | organizationalUnitName = Organizational Unit Name (eg, section)
31 | organizationalUnitName_default = mps
32 |
33 | commonName = Common Name (e.g. server FQDN or YOUR name)
34 | commonName_default = mps.github.io
35 | commonName_max = 64
36 |
37 | emailAddress = Email Address
38 | emailAddress_default = telanflow@gmail.com
39 | emailAddress_max = 64
40 |
--------------------------------------------------------------------------------
/_examples/mitm-proxy/README.md:
--------------------------------------------------------------------------------
1 | ## Mitm Proxy
2 |
3 | This example implements an HTTPS man-in-the-middle (MITM) proxy.
4 |
5 | You can go to the `_examples/generateCert` directory to regenerate the certificate files.
6 |
7 | ## Steps
8 |
9 | 1. Go to `_examples/generateCert` to generate the certificate file.
10 |
11 | 2. Import the `ca.crt` certificate file into your system.
12 |
13 | If you want to use the Go client to make HTTPS requests, you need to configure the certificate,
14 | for example:
15 | ```go
16 | func main() {
17 | // Load ca.crt file
18 | certPEMBlock, err := os.ReadFile("ca.crt")
19 | if err != nil {
20 | panic("failed to load ca.crt file")
21 | }
22 |
23 | // client cert pool
24 | clientCertPool := x509.NewCertPool()
25 | ok := clientCertPool.AppendCertsFromPEM(certPEMBlock)
26 | if !ok {
27 | panic("failed to parse root certificate")
28 | }
29 |
30 | // set Transport
31 | http.DefaultClient.Transport = &http.Transport{
32 | Proxy: func(r *http.Request) (*url.URL, error) {
33 | // mitm proxy server address. eg. "http://localhost:8080"
34 | return url.Parse("http://localhost:8080")
35 | },
36 | TLSClientConfig: &tls.Config{
37 | Certificates: []tls.Certificate{cert.DefaultCertificate},
38 | ClientAuth: tls.RequireAndVerifyClientCert,
39 | RootCAs: clientCertPool,
40 | },
41 | }
42 |
43 | // To send request
44 | req, _ := http.NewRequest(http.MethodGet, "https://example.com", nil)
45 | resp, err := http.DefaultClient.Do(req)
46 | if err != nil {
47 | panic(err)
48 | }
49 | defer resp.Body.Close()
50 | }
51 | ```
52 |
53 |
54 |
--------------------------------------------------------------------------------
/_examples/mitm-proxy/ca.crt:
--------------------------------------------------------------------------------
1 | -----BEGIN CERTIFICATE-----
2 | MIIF7jCCA9agAwIBAgIJAJ+vhP71rZgIMA0GCSqGSIb3DQEBCwUAMIGLMQswCQYD
3 | VQQGEwJDTjERMA8GA1UECAwIWmhlSmlhbmcxETAPBgNVBAcMCEhhbmdaaG91MQww
4 | CgYDVQQKDANtcHMxDDAKBgNVBAsMA21wczEWMBQGA1UEAwwNbXBzLmdpdGh1Yi5p
5 | bzEiMCAGCSqGSIb3DQEJARYTdGVsYW5mbG93QGdtYWlsLmNvbTAeFw0yMDA4MTkw
6 | NjM0NDBaFw00MDA4MTQwNjM0NDBaMIGLMQswCQYDVQQGEwJDTjERMA8GA1UECAwI
7 | WmhlSmlhbmcxETAPBgNVBAcMCEhhbmdaaG91MQwwCgYDVQQKDANtcHMxDDAKBgNV
8 | BAsMA21wczEWMBQGA1UEAwwNbXBzLmdpdGh1Yi5pbzEiMCAGCSqGSIb3DQEJARYT
9 | dGVsYW5mbG93QGdtYWlsLmNvbTCCAiIwDQYJKoZIhvcNAQEBBQADggIPADCCAgoC
10 | ggIBALuMT+elSDg9C+1ZKFe2I+1DcDShxY7xzowwxYx59x9EEhdlCTurWf32za7s
11 | V7yzhhzsG9w65uJZHXGkmjzl5iDBepL4PMoifEY9gs3W7xsKPJqtQ0c1wvT1yv8H
12 | geSsnmMHmsz4RG+YakyluvuAIwU0QTbpGeQWmP0Hy6aMD9yP3b3p37I+3Ok8QIcC
13 | 0fExLx+9t/EL8PHKcjX0g813cdmyKel9QdGTmjjQ8QSkWRfBPhENhyvazu87X94f
14 | Ru9Gvd1scrJsY0ty6QRGiCoxtaSc7POjRg1PiC3PapiMjr1Sx3Q01sdb6sQwZ03/
15 | OlWAD8T/kIoJo+C+hRdlYjlaumYiUPbNtVksYldVQZ1Bva/oTcLCYV30GP3jlw26
16 | UC2PXmT4yBJQDdEPrj0LqcTrlGslIdE8KQA/oBetpwdqGTV05oqKI1E3M8CY/QEV
17 | rW+yJXaVq/+ukRKaGWRwJNcZC4iPOUjHaAjaN1XtR0vxPG4Brw2MFy4p3SeEQ8mJ
18 | cvQLnaLzsEYc6iu/ntGdnYMpFMCLNGm+l/+B7dP+KiuseA67Q5BOanFp5wrEQ7ND
19 | 8TdCZUe9LzRnNvyZXi5N5wMxQ/tCVR/a27MLN66mxrG0KyzSunCiYqiLyoC4eQJ9
20 | RKjOkHWQ4V4y2NFTqKhgr/Ns1PBiH/4XKQHdRYoWVV4cL+6ZAgMBAAGjUzBRMB0G
21 | A1UdDgQWBBTNarhuZvbH55oZyKX7bEGgP4D9wDAfBgNVHSMEGDAWgBTNarhuZvbH
22 | 55oZyKX7bEGgP4D9wDAPBgNVHRMBAf8EBTADAQH/MA0GCSqGSIb3DQEBCwUAA4IC
23 | AQAEusi8rHwPZwMObn/u/PJQ0K+7Ep/K0nqLGLLpMudj9AVYYV50ZG47lv7ibmQS
24 | LpgfxzJ0fCk5rM6Wg0+KxqLDUP2GotUlc31mIQACVXVbzipx1A06ThBthOUJXaZY
25 | +T6ggKkacLZOG35Af99qOzDIZU9FcunYyaQi8qUCY7VT+Uynu090k7ynvy6S7vut
26 | rGavfDD7oAWEpRyQen8MpUoqhY3PeQJ3VKUp7lcgsiP5f8tLwQAtx7Qvh9li95ml
27 | Kq/SFyA0qPZH6eUrQLvNr4InGwbaoKNAvN99rSHEk2wixuxG2CqsvKW7EQ4IRUcS
28 | H9EBfUEm6EH7xWsyxrlYNzh0k/dYfvy4je1YwkpNAl8BSV8PdFtpW0//FZa04KfG
29 | D9D+wb5MxOmnULl818cYHrg5QIB/uqet21ib85CB/oRvR9JF0D4hfC0gy3IZz+8L
30 | DBhKjf9KNP2631nqtuC7vrShLSFa+02D3bi2srgznIcORtkBHpHgZefAOI7dPJzM
31 | C/ENEorEayK0rdbc116grj7rfGY9yWg/7fD9B0lvrD2DwXTXTd94laDPbpIJIikN
32 | 6EFIJp3IjGKbH+HWUY4TxeRDH8bEq+wjTA0rrIRnI2z03VlX+ZK/UTVdnPxV9Pdd
33 | 3FuvmED63INzIdoY/DgWgUF2DShFbALKfBR3Yev/mkbCqg==
34 | -----END CERTIFICATE-----
35 |
--------------------------------------------------------------------------------
/_examples/mitm-proxy/ca.key:
--------------------------------------------------------------------------------
1 | -----BEGIN RSA PRIVATE KEY-----
2 | MIIJKgIBAAKCAgEAu4xP56VIOD0L7VkoV7Yj7UNwNKHFjvHOjDDFjHn3H0QSF2UJ
3 | O6tZ/fbNruxXvLOGHOwb3Drm4lkdcaSaPOXmIMF6kvg8yiJ8Rj2CzdbvGwo8mq1D
4 | RzXC9PXK/weB5KyeYweazPhEb5hqTKW6+4AjBTRBNukZ5BaY/QfLpowP3I/dvenf
5 | sj7c6TxAhwLR8TEvH7238Qvw8cpyNfSDzXdx2bIp6X1B0ZOaONDxBKRZF8E+EQ2H
6 | K9rO7ztf3h9G70a93WxysmxjS3LpBEaIKjG1pJzs86NGDU+ILc9qmIyOvVLHdDTW
7 | x1vqxDBnTf86VYAPxP+Qigmj4L6FF2ViOVq6ZiJQ9s21WSxiV1VBnUG9r+hNwsJh
8 | XfQY/eOXDbpQLY9eZPjIElAN0Q+uPQupxOuUayUh0TwpAD+gF62nB2oZNXTmiooj
9 | UTczwJj9ARWtb7IldpWr/66REpoZZHAk1xkLiI85SMdoCNo3Ve1HS/E8bgGvDYwX
10 | LindJ4RDyYly9AudovOwRhzqK7+e0Z2dgykUwIs0ab6X/4Ht0/4qK6x4DrtDkE5q
11 | cWnnCsRDs0PxN0JlR70vNGc2/JleLk3nAzFD+0JVH9rbsws3rqbGsbQrLNK6cKJi
12 | qIvKgLh5An1EqM6QdZDhXjLY0VOoqGCv82zU8GIf/hcpAd1FihZVXhwv7pkCAwEA
13 | AQKCAgBThCgQ/4kpggXNq+ZLKNDW1zEgPum6vfM8ent+EtH5Glb0FAoIiEWK0lzF
14 | iHmJjmgqePnvGEu4f/acpLAKblYMQBxVVjW7zZ+Jp9qXzx6q6+QQ/Rb4nvgyHUJI
15 | Tw+IxVXCw6ArpmLTTwwHFcYuOOFfb+WajjL5XxbBlrcZc0Wc8nPMHll/Bn9ZXXte
16 | o+LZhQ13FQTUUnz5Ly2s2TXYSVhpmO0RDLZCnXgP1Pt/FbCW43bAIUYQQV/lKIuI
17 | XmU4KEhkUebBjYKqFoGtZbs9DuXUaA0ccZjAVKpPvA274Nuvcy1ekikSndvtgaB/
18 | Gyje6igbkbLLxX80laKuyHb1E3HtRZ22YI9t3GIGaKWYzU43nOnAiTp0o40EPnqt
19 | lF4ZEsN7eh0gGEkPnsw24Os5QyJJw4qeXfeVGU9Ig/kFOdf6n+8t2031imcFyScW
20 | BBb9gwj/hef3FFG/lucwoMXSDcIkKdITCr3aH1VxIxulAPQzjOK6rfnGR2kFnKoP
21 | RQcnQS3+mM7HqWppgy18St+MIiW95Cd+6WecUqAM25s9VOnIW2BbzHzKTfaNC4kR
22 | Ou5ruK5FYA36Da/gC0pEEuloV6bP0KVr1IutHNB656U6gl+g5ir0/rdCsrPqJRHP
23 | wHf5Vflb2lrIVFb/jIXqi0GsC7rbcvpkTD6zDcMOzXE69j+IAQKCAQEA7Mo3JO3c
24 | 8Q4S6sCIP8jYTkt//E4VADtNN+IgcSJEtPKQkcP8CvRlt4wkp0aaXQoV99aNzVHZ
25 | j22xOrJENZEjV7xB51R0RDWRe5sVWmVlXOVmePS6ta8WIwQG9RqNlTyEOyYpxhK7
26 | XJh0v29/CxF0UsCQYjYAxwDgE4fY11vIN/q7+YluTlA3Rwei4EvagoFftAGD4jRB
27 | 9ECnH2GiuCS6Mi5PQ90jj9XMCm5WcCXwFG8OKJWf/8uKClruL/yes3+G7lic9mof
28 | 8TbyBrD6bk8OHjp0X42ReCBNPLuDH3G/TWRDw8LtHc6LDLW0K16cm7+VMYF8x9yP
29 | 0BPuO3NixUtfEwKCAQEAysNqV1Ml8+Q+JaCCCOmSK65ad4XUepaarNbJOVuLywNI
30 | ok8VzSjJ3y/wsq9JlrhBMir3SZhjK/KqNb/o7zkDlp9/AAmppdR/iMdC7okmLMWn
31 | j3UM3UvqtH9nhBs+rSzlEMq6Qt6F77rOKB3nmHkcexZIPsC4Ct3ZWvuOCXkHQytT
32 | RLN4b9iqhvw9Zy4Wq8CSKEY9SqXvkaNa8YTyoyPqYozpEaJwAMv0mjNtEqyfic3S
33 | KBRMTTo0SQcLpn/Ee2OL+2EIjy9QgDMOGRDs0fs/xuwZFqeao3dr90jS2vL3atVN
34 | IxiGAI4Hplb2TaGyNTu/1yoXJXKgiel/o91qtmE1IwKCAQEA3m2No1j1JFLuHipB
35 | UnleBx4Q2XaXb6JFBOubQerI05jPiL2q8rdlHSe9/ovp0N/6hta6WVY7oemOg+6U
36 | +CSgKHglCCJjHPec85lYU5PPxZWPzqtFAAm6J6ZOysromHlCVTWiI/fQnEhx0qnv
37 | kvwQYvOULU1BKa5+zpnbbWFAEKWtEdixD0t2wXhA3aUjW1ggCD0sH76q/cAFvQrA
38 | CW4moaCywLLoBuL0ShAfjjV08hzoFeOHaodN4jBMcjNA+KggnaALwcUqwDG24+Y3
39 | OIt2XZrXWjLnpQniw9v4bf8xjodSyH9AsbElGQlOdzbmsb8jbF+QUUW0qecu8BWR
40 | gHculQKCAQEAmfhGekVTnp6FasE1vVrQeocNf5GKxgQzNGhtqTaRMvotX8M6RO5i
41 | TS70Ulu1P9Ru/Y+O9L3ZIPhGtEYktfPPe8NmBztPLfPtXIojk0tmR71X/iHeQPVz
42 | JtlQXArsT0i2MUggpMKhZmeuQNxkj234aKeE+NITb30DnolDVIIpN6Jgutyl6hjX
43 | dWV5oy5mXMoAssCTrmnPQAKR/rD8J1IQnAFwwslcz94Qwj+m5fVbuKMooPK49jPq
44 | nEHTYP3I0AHJvHv0qfY95PvgCrzFeLaXuZBzhLaFQPhgbglIxKaXpvKOfsYSi71O
45 | pcuHgW/2CWJzzQnTRcaDjfZXzLFIZXHvjQKCAQEAwnl07PTmCWFuZGGNQPM8DO5q
46 | s5TFy2c48vOlHSlqnatJoXc5rqga7GpI1fymRBDAUjRPYr+NdyXHQHaqrzKK54Dj
47 | Ca/nDJrn5830GKqpqEG2m2P1gbpmuFb9H63ugNCnC68ktu1S8672hC+/hgz3TNxB
48 | S1BNQ6PywAy8UfxJQa4S/l8QjKuOYl87/d7Ud4/UfRa4nVSA8/XF5fg/Cv21mATp
49 | iluyD/oMrP+uTGtWJlOhOgj8lH1alHdY8PogQ6ZlXJqL2JkgW0/25WjLXIZS88DQ
50 | CQp28IBolmUdAChYTq0V6NbCinqviE/mDMNpp09qCxMXrGabBB4j6Xm35vlivg==
51 | -----END RSA PRIVATE KEY-----
52 |
--------------------------------------------------------------------------------
/_examples/mitm-proxy/main.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | "errors"
5 | "github.com/telanflow/mps"
6 | "log"
7 | "net/http"
8 | "os"
9 | "os/signal"
10 | "regexp"
11 | "syscall"
12 | )
13 |
14 | // main runs a MITM proxy on localhost:8080 until SIGINT, SIGTERM or SIGQUIT is received.
15 | func main() {
16 | 	// buffered: signal.Notify never blocks on send, so an unbuffered channel can miss the signal
17 | 	quitSignChan := make(chan os.Signal, 1)
18 | 	// create proxy server
19 | 	proxy := mps.NewHttpProxy()
20 |
21 | 	// Load cert file (paths are relative to the repository root)
22 | 	// The Connect request is processed using MitmHandler
23 | 	mitmHandler, err := mps.NewMitmHandlerWithCertFile(proxy.Ctx, "./_examples/mitm-proxy/ca.crt", "./_examples/mitm-proxy/ca.key")
24 | 	if err != nil {
25 | 		log.Panic(err)
26 | 	}
27 | 	proxy.HandleConnect = mitmHandler
28 |
29 | 	// Middleware
30 | 	proxy.UseFunc(func(req *http.Request, ctx *mps.Context) (*http.Response, error) {
31 | 		log.Printf("[INFO] middleware -- %s %s", req.Method, req.URL)
32 | 		return ctx.Next(req)
33 | 	})
34 |
35 | 	// Filter
36 | 	reqGroup := proxy.OnRequest(mps.FilterHostMatches(regexp.MustCompile("^.*$")))
37 | 	reqGroup.DoFunc(func(req *http.Request, ctx *mps.Context) (*http.Request, *http.Response) {
38 | 		log.Printf("[INFO] req -- %s %s", req.Method, req.URL)
39 | 		return req, nil
40 | 	})
41 | 	respGroup := proxy.OnResponse()
42 | 	respGroup.DoFunc(func(resp *http.Response, err error, ctx *mps.Context) (*http.Response, error) {
43 | 		if err != nil {
44 | 			log.Printf("[ERRO] resp -- %s %v", ctx.Request.Method, err)
45 | 			return resp, err
46 | 		}
47 | 		log.Printf("[INFO] resp -- %d", resp.StatusCode)
48 | 		return resp, err
49 | 	})
50 |
51 | 	// Started proxy server
52 | 	srv := http.Server{
53 | 		Addr:    "localhost:8080",
54 | 		Handler: proxy,
55 | 	}
56 | 	go func() {
57 | 		log.Printf("MitmProxy started listen: http://%s", srv.Addr)
58 | 		err := srv.ListenAndServe()
59 | 		if errors.Is(err, http.ErrServerClosed) {
60 | 			return
61 | 		}
62 | 		if err != nil {
63 | 			// log.Fatalf exits the process itself; there is no need to wake main first
64 | 			log.Fatalf("MitmProxy start fail: %v", err)
65 | 		}
66 | 	}()
67 |
68 | 	// quit signal (SIGKILL cannot be caught, so it is not registered)
69 | 	signal.Notify(quitSignChan, syscall.SIGINT, syscall.SIGTERM, syscall.SIGQUIT)
70 |
71 | 	<-quitSignChan
72 | 	_ = srv.Close()
73 | 	log.Println("MitmProxy server stop!")
74 | }
75 |
--------------------------------------------------------------------------------
/_examples/reverse-proxy/main.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | "errors"
5 | "log"
6 | "net/http"
7 | "net/url"
8 | "os"
9 | "os/signal"
10 | "syscall"
11 |
12 | "github.com/telanflow/mps"
13 | "github.com/telanflow/mps/middleware"
14 | )
15 |
16 | // main runs a reverse proxy on localhost:8080, with https://www.google.com as the SingleHostReverseProxy target, until a quit signal arrives.
17 | func main() {
18 | 	// the literal is a well-formed URL, so the parse error is deliberately ignored
19 | 	targetURL, _ := url.Parse("https://www.google.com")
20 | 	quitSignChan := make(chan os.Signal, 1) // buffered: signal.Notify never blocks on send
21 | 	// reverse proxy server
22 | 	proxy := mps.NewReverseHandler()
23 | 	proxy.UseFunc(middleware.SingleHostReverseProxy(targetURL))
24 |
25 | 	reqGroup := proxy.OnRequest()
26 | 	reqGroup.DoFunc(func(req *http.Request, ctx *mps.Context) (*http.Request, *http.Response) {
27 | 		log.Printf("[INFO] req -- %s %s", req.Method, req.Host)
28 | 		return req, nil
29 | 	})
30 |
31 | 	respGroup := proxy.OnResponse()
32 | 	respGroup.DoFunc(func(resp *http.Response, err error, ctx *mps.Context) (*http.Response, error) {
33 | 		if err != nil {
34 | 			log.Printf("[ERRO] resp -- %s %v", ctx.Request.Method, err)
35 | 			return nil, err
36 | 		}
37 | 		log.Printf("[INFO] resp -- %d", resp.StatusCode)
38 | 		return resp, err
39 | 	})
40 |
41 | 	// started proxy server
42 | 	srv := http.Server{
43 | 		Addr:    "localhost:8080",
44 | 		Handler: proxy,
45 | 	}
46 | 	go func() {
47 | 		log.Printf("ReverseProxy started listen: http://%s", srv.Addr)
48 | 		err := srv.ListenAndServe()
49 | 		if errors.Is(err, http.ErrServerClosed) {
50 | 			return
51 | 		}
52 | 		if err != nil {
53 | 			// log.Fatalf exits the process itself; there is no need to wake main first
54 | 			log.Fatalf("ReverseProxy start fail: %v", err)
55 | 		}
56 | 	}()
57 |
58 | 	// quit signal (SIGKILL cannot be caught, so it is not registered)
59 | 	signal.Notify(quitSignChan, syscall.SIGINT, syscall.SIGTERM, syscall.SIGQUIT)
60 |
61 | 	<-quitSignChan
62 | 	_ = srv.Close()
63 | 	log.Println("ReverseProxy server stop!")
64 | }
65 |
--------------------------------------------------------------------------------
/_examples/simple-http-proxy/main.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | "errors"
5 | "log"
6 | "net/http"
7 | "os"
8 | "os/signal"
9 | "regexp"
10 | "syscall"
11 |
12 | "github.com/telanflow/mps"
13 | )
14 |
15 | // main runs an HTTP proxy on localhost:8080 until SIGINT, SIGTERM or SIGQUIT is received.
16 | func main() {
17 | 	// buffered: signal.Notify never blocks on send, so an unbuffered channel can miss the signal
18 | 	quitSignChan := make(chan os.Signal, 1)
19 | 	// create a http proxy server
20 | 	proxy := mps.NewHttpProxy()
21 | 	proxy.UseFunc(func(req *http.Request, ctx *mps.Context) (*http.Response, error) {
22 | 		log.Printf("[INFO] middleware -- %s %s", req.Method, req.URL)
23 | 		return ctx.Next(req)
24 | 	})
25 |
26 | 	// Filter Request
27 | 	reqGroup := proxy.OnRequest(mps.FilterHostMatches(regexp.MustCompile("^.*$")))
28 | 	reqGroup.DoFunc(func(req *http.Request, ctx *mps.Context) (*http.Request, *http.Response) {
29 | 		log.Printf("[INFO] req -- %s %s", req.Method, req.URL)
30 | 		return req, nil
31 | 	})
32 |
33 | 	// Filter Response
34 | 	respGroup := proxy.OnResponse()
35 | 	respGroup.DoFunc(func(resp *http.Response, err error, ctx *mps.Context) (*http.Response, error) {
36 | 		if err != nil {
37 | 			log.Printf("[ERRO] resp -- %s %v", ctx.Request.Method, err)
38 | 			return resp, err
39 | 		}
40 |
41 | 		log.Printf("[INFO] resp -- %d", resp.StatusCode)
42 | 		return resp, err
43 | 	})
44 |
45 | 	// Start server
46 | 	srv := &http.Server{
47 | 		Addr:    "localhost:8080",
48 | 		Handler: proxy,
49 | 	}
50 | 	go func() {
51 | 		log.Printf("HttpProxy started listen: http://%s", srv.Addr)
52 | 		err := srv.ListenAndServe()
53 | 		if errors.Is(err, http.ErrServerClosed) {
54 | 			return
55 | 		}
56 | 		if err != nil {
57 | 			// log.Fatalf exits the process itself; there is no need to wake main first
58 | 			log.Fatalf("HttpProxy start fail: %v", err)
59 | 		}
60 | 	}()
61 |
62 | 	// quit signal (SIGKILL cannot be caught, so it is not registered)
63 | 	signal.Notify(quitSignChan, syscall.SIGINT, syscall.SIGTERM, syscall.SIGQUIT)
64 |
65 | 	<-quitSignChan
66 | 	_ = srv.Close()
67 | 	log.Println("HttpProxy server stop!")
68 | }
69 |
--------------------------------------------------------------------------------
/_examples/websocket-proxy/main.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | "errors"
5 | "log"
6 | "net/http"
7 | "net/url"
8 | "os"
9 | "os/signal"
10 | "syscall"
11 |
12 | "github.com/gorilla/websocket"
13 | "github.com/telanflow/mps"
14 | )
15 |
16 | var (
17 | upgrader = websocket.Upgrader{} // zero-value upgrader used by runWebsocketServer to upgrade incoming requests
18 | endPointAddr = "localhost:9990" // address the endPoint server listens on and the proxy transport targets
19 | )
20 |
21 | // runWebsocketServer serves the endPoint websocket server on endPointAddr; it echoes every message back and exits the process if the listener fails.
22 | func runWebsocketServer() {
23 | 	log.Fatal(http.ListenAndServe(endPointAddr, http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
24 | 		c, err := upgrader.Upgrade(rw, req, nil)
25 | 		if err != nil {
26 | 			return
27 | 		}
28 | 		defer c.Close()
29 | 		for {
30 | 			mt, message, err := c.ReadMessage()
31 | 			if err != nil {
32 | 				break
33 | 			}
34 | 			err = c.WriteMessage(mt, message) // echo with the same message type
35 | 			if err != nil {
36 | 				break
37 | 			}
38 | 		}
39 | 	})))
40 | }
41 |
42 | // main runs a websocket proxy on localhost:8080 in front of the local endPoint server until a quit signal arrives.
43 | func main() {
44 | 	// quit signal; buffered because signal.Notify never blocks on send (SIGKILL cannot be caught, so it is not registered)
45 | 	quitSignChan := make(chan os.Signal, 1)
46 | 	signal.Notify(quitSignChan, syscall.SIGINT, syscall.SIGTERM, syscall.SIGQUIT)
47 |
48 | 	// start endPoint websocket server
49 | 	go runWebsocketServer()
50 |
51 | 	// start proxy websocket server
52 | 	websocketHandler := mps.NewWebsocketHandler()
53 | 	websocketHandler.Transport().Proxy = func(request *http.Request) (*url.URL, error) {
54 | 		// endPoint websocket server
55 | 		return url.Parse("ws://" + endPointAddr)
56 | 	}
57 | 	srv := &http.Server{
58 | 		Addr:    "localhost:8080",
59 | 		Handler: websocketHandler,
60 | 	}
61 | 	go func() {
62 | 		log.Printf("WebsocketProxy started listen: ws://%s", srv.Addr)
63 | 		err := srv.ListenAndServe()
64 | 		if errors.Is(err, http.ErrServerClosed) {
65 | 			return
66 | 		}
67 | 		if err != nil {
68 | 			// log.Fatalf exits the process itself; there is no need to wake main first
69 | 			log.Fatalf("WebsocketProxy start fail: %v", err)
70 | 		}
71 | 	}()
72 |
73 | 	<-quitSignChan
74 | 	_ = srv.Close()
75 | 	log.Println("WebsocketProxy server stop!")
76 | }
77 |
--------------------------------------------------------------------------------
/cert/cert.go:
--------------------------------------------------------------------------------
1 | package cert
2 |
3 | import "crypto/tls"
4 |
5 | const CertPEM = `-----BEGIN CERTIFICATE-----
6 | MIIF7jCCA9agAwIBAgIJAP/+a5pIA2lJMA0GCSqGSIb3DQEBCwUAMIGLMQswCQYD
7 | VQQGEwJDTjERMA8GA1UECAwIWmhlSmlhbmcxETAPBgNVBAcMCEhhbmdaaG91MQww
8 | CgYDVQQKDANtcHMxDDAKBgNVBAsMA21wczEWMBQGA1UEAwwNbXBzLmdpdGh1Yi5p
9 | bzEiMCAGCSqGSIb3DQEJARYTdGVsYW5mbG93QGdtYWlsLmNvbTAeFw0yMDA4MDYx
10 | MTE4MThaFw00MDA4MDExMTE4MThaMIGLMQswCQYDVQQGEwJDTjERMA8GA1UECAwI
11 | WmhlSmlhbmcxETAPBgNVBAcMCEhhbmdaaG91MQwwCgYDVQQKDANtcHMxDDAKBgNV
12 | BAsMA21wczEWMBQGA1UEAwwNbXBzLmdpdGh1Yi5pbzEiMCAGCSqGSIb3DQEJARYT
13 | dGVsYW5mbG93QGdtYWlsLmNvbTCCAiIwDQYJKoZIhvcNAQEBBQADggIPADCCAgoC
14 | ggIBANJZU0vyrS7aROi5+0e6AR4VBulFEjoivLrYaa1Pl1ENHHTgfjjmnLf2+22G
15 | ImMp95RUDYIT2tZ2GhksLJil+fJEvv7HMihsWYYTjGzr5u3kPke0+fB/7dbRYJ+h
16 | FvlsLEkItYPT9iBHryStu5CRV3P1VNtR9/7FF8YdX3kOqMQASnHQhBYNZ7av2OuR
17 | 3pDPLD0PKccqMeTXW+yMsB+z0L03RQQG3LOmi/7nWogvqVrnuwP7JbybOtHEvLO0
18 | rLEoAdXwdCCSAHdBCz2qat/I9CubGlKdUlgVw8eXVWZeYJeVOOQy8f7L9AEPoc5k
19 | uXpEyRPCzpo/T/6KSxi2oxaEI4BSZUtyxRS/Laezdgs+GnKkjO56Z3lMPCvwwLFO
20 | DNdtxA3OgLIvcZSA9zWPgoOSVQ0nCIQl3L3qEJ/TqyUWkcPINhiLNgnVSdu1dQ7q
21 | rFZegmi5RAKAyl0M1rSlmTAB3Q/Mf4BMzPaNUajW7bjx4MbU9LxknVlRUb1vv2Jv
22 | Pd6mUm0vLy6P/zl8/pZRpcnn91omFJ+PgZMoRzUPTBNDrgUEeXNsLzaLHg6t4fLb
23 | xd1QMsg99Upo643Q/Hb8Xfz2ogm82jRURXkiHhQgxPjUvk76N4obNW9noMlZEUpF
24 | /68/WwMc2CrWvWZ1HKWfpJDN6C2hjOqWvVWBng6LssVZdIBnAgMBAAGjUzBRMB0G
25 | A1UdDgQWBBSPklwhHPcnDnP8tNSQ2i+VAs/gvjAfBgNVHSMEGDAWgBSPklwhHPcn
26 | DnP8tNSQ2i+VAs/gvjAPBgNVHRMBAf8EBTADAQH/MA0GCSqGSIb3DQEBCwUAA4IC
27 | AQA6KnQfPV1gS53ZwakZAzE3XEDx+ef1C0iFZx282PWcIwnBPYbkswTt8RJj5806
28 | MiKyBtHSDN3Agde5LP6C2BhCx2GeguUDcTDPY0PGj5/TaES2gRiu6rsKkQJhSUTs
29 | RSPekDT30yCJMJSz/1QnOqGXwToSpyr5rsWxAyYGAAz0fSAwpJ2XuE77vHCk+p3k
30 | zIOjWxrkkLTSxoqIhOjmq8hO5qvwojudUTn0PBcAakND6r5csWLR2i+Am7u36WJc
31 | m8dX8L+SXmWYw85Fs4tLDwUsORFuJclY32g7fKOvrgV2rhHXwwWw8pjpbjs0AAKA
32 | +Gk3QQMT0cF4FpnH8VBdK82/nOtcbNvWz994K18kEzGUzN5Dq7zi+7n99HH4rwjO
33 | 2eyWNl7hsatvtKfcpwHhIt63EHG3owhc+Wf0iG3i6BR1b9jO3XLJmIlgkQ+4pSpV
34 | +9BB2HNfklUOvJdsWwSxdhChavHJokl0rdGRf8weWhqhLkUCC3z2lIIylXgay71W
35 | ++48SxdaMbiqnuEZrt4cMlOt+KAxvtEl1krWLAIi6URHLSvdER7w/Dtkvg+PPb1t
36 | yLiugEvmBIpagsbw3zirMza8Rg1CchRB+0sRGVSE4ppRe5EiIAe4aZUKufIk0TtS
37 | 8yO67j6Lx35sfhqzg/Jl31HOk42M8MpZqAGy13Cyw1kGmg==
38 | -----END CERTIFICATE-----`
39 |
40 | const KeyPEM = `-----BEGIN RSA PRIVATE KEY-----
41 | MIIJKgIBAAKCAgEA0llTS/KtLtpE6Ln7R7oBHhUG6UUSOiK8uthprU+XUQ0cdOB+
42 | OOact/b7bYYiYyn3lFQNghPa1nYaGSwsmKX58kS+/scyKGxZhhOMbOvm7eQ+R7T5
43 | 8H/t1tFgn6EW+WwsSQi1g9P2IEevJK27kJFXc/VU21H3/sUXxh1feQ6oxABKcdCE
44 | Fg1ntq/Y65HekM8sPQ8pxyox5Ndb7IywH7PQvTdFBAbcs6aL/udaiC+pWue7A/sl
45 | vJs60cS8s7SssSgB1fB0IJIAd0ELPapq38j0K5saUp1SWBXDx5dVZl5gl5U45DLx
46 | /sv0AQ+hzmS5ekTJE8LOmj9P/opLGLajFoQjgFJlS3LFFL8tp7N2Cz4acqSM7npn
47 | eUw8K/DAsU4M123EDc6Asi9xlID3NY+Cg5JVDScIhCXcveoQn9OrJRaRw8g2GIs2
48 | CdVJ27V1DuqsVl6CaLlEAoDKXQzWtKWZMAHdD8x/gEzM9o1RqNbtuPHgxtT0vGSd
49 | WVFRvW+/Ym893qZSbS8vLo//OXz+llGlyef3WiYUn4+BkyhHNQ9ME0OuBQR5c2wv
50 | NoseDq3h8tvF3VAyyD31SmjrjdD8dvxd/PaiCbzaNFRFeSIeFCDE+NS+Tvo3ihs1
51 | b2egyVkRSkX/rz9bAxzYKta9ZnUcpZ+kkM3oLaGM6pa9VYGeDouyxVl0gGcCAwEA
52 | AQKCAgArbEc2wXUg2+wnwuTtrKc4Z4zSsPCPUcZ2J+DA51JMaBF8yy8jXe/yRikn
53 | Ne55XBuA4k0bki+14BGJKsZWCMVtTuXCwKpJD/z3Iaf2gEheyaRVtzV1gWM+2mBA
54 | 88dDXCJUPVkDSslfZozwXHEA6hAMnxOSZvxz+onq2vtviSgrtgeoMSxjRQco/mog
55 | Ty+L40i1niC4vawpGpAeZ/ifwsYPmY5Ew4niCDqUN3xH6tbiLj48Fyd2JPFihmOS
56 | EXUo6SJf4NCIPLud4q6IX1rKsbg+HDm13kY2at/MnyABDvCPuj1RVncAa2gGpAx6
57 | B+8GH5cG3ks6KmHAIRpZkrJeHo8ZOZOVsfokBDjTEWUn3sQH8RrVXtNnm/W66dAl
58 | m4LnKyBWvyVaOHn65Jq06XTaUrT/9MY2RDmLPehzhcPczcZZn8RQikrarStQuHk2
59 | DOiiCvjSVnh+O13RdCKBMXfG4A482LucFnSSweiuFrXDU87GO++jKaZoYPeXsQul
60 | jlTNFUyr7zHO5gqVf+JzboRG+pwashiFZBCfVqu/h9Abx7BTOfIXi5k5f0r1elZw
61 | hQwJT0WgJKX4MjehojABNi+t4i6xqcsmCB9D68FBONSLHY4qpa1s9rhHEzd6BTJS
62 | Fg4GPFooxQ2lwhjAyz2ZhG6HbF3QNuV6HgoimkcYlLiVowx5qQKCAQEA7XiDMQoA
63 | 4J+ZVlseGPinIVpOuAY8ehcd25I9xLmaqvk0CGeGBvpC125KGnG42m4l133AicOa
64 | +Dz0yU0UudvfvnJEyPd6ojpAYMxX5/MH+85hJ8ARPwQQ+K+TlIPKH/8jXhS8D25D
65 | pcvY4MJftwiuBhFTYlveNnmbH7QCAge/lS0BSnlOMGUI/yBg4DSOt0sAZOblWNlv
66 | 1FKG3aKCdd3atY1VvTyGnqpUqFiuU6ENwSbam77hTuQHjE3rULESj/wHIxpj/2Gr
67 | VsjD/o29h2jjUApseUQqi55TllBj60K/DGuTiKSj8PXmXaBy20MDpef9HEYOVD/j
68 | lsB4aPHqqn6+BQKCAQEA4sMNrA7xAfFRLJphzXd/6k9ISbKTjcp57f4Akib3FqCo
69 | BJPz1F8cQJ5BHZLBJum/jyfbPEd36owr5bn/JlMqXsPgzb3eik/ZoIA/woucFosh
70 | 8MrebpARuSMmNtC0F2VfEDG4G/p7c+/aYWLPJJbte0XmIvsIVXwtlAt6k+HCAVAW
71 | PA+MLAelEC0gtHOk1ea9NN2VCfsfpsbw/4GSUlL6Efev5ufXyH2z+tPmZyusPnBS
72 | fAGZ78d000mH3RVsGN3o22Tzv8Hpx62MV8U08TvCESZsjggX2lCpVIsi3GRDR5wj
73 | NFU7LOEy7TljYMES8GIyNc8U9csIgyh2+WLSjcSkewKCAQEAvuFa2uVGlUfUkoSF
74 | ad8dQIL9uZBRtnW0a1Vezy299GaB+6tzIVKyvcYKTL1SsElPo6qSRGp1u8oLnW+X
75 | FFp3u/bP8ZZz/cjDDMvUcT56EV7v22rYsgWLusou33cb1qJYBHy4OdMRD0kO2IOF
76 | OnQApiHxG6Pqt3ECTvZ7krQ1vCxD2GAviFj+ZUzacf3tJcpk07aBbezBpjJ789V3
77 | 9lRRRBQKciUftJQHnpZB8jkH/FVF7WD+bFKA+rd7Sg47dH9KIV5KOPKCLi0M1iWK
78 | zjhyV1k5njQ72qR2XeHanzW0qcAjA/gLS1ntRR7+k96HJSmX2804IWKFhxzI7Npg
79 | HZHpHQKCAQEAiabMGuErDfnWQ9QngJmE7dBY2lvr1EvP/leNMysyHOtDcxv5DLb7
80 | qIIolvIqDBwi65zPKeVcduXGE/r3VuVvN/2B7oLOn3lfa13O1qL3CnxFCy2rHsSX
81 | 7aHXpbjFSdqAfY0g7OL9o+A62Zkok1aHLKi+zgdDBNmPtWnObAzEPxXFmYn6lhPB
82 | 8HLkgoYczrf1rSzBN0DY8t2bGA8oqo6yPMv1XJ7qT0t3QND28TQCqBh5CcvTDUov
83 | sb7WGa/SYbn7i4rZqFLnPg4svm7492NGKDEB/qoNCLqkP60CaXT3nnW6rR77//9o
84 | cba/i9FIVOHXBvEBET/BmBStPC/wDp0LFwKCAQEA1sZKLH3IhWMbdqOcWbO8H1VC
85 | 4NT74peijfTjJ1JxcLllyv1H0MXW9qXG0Sksmy41CdyPBZUQuYbzi78p8S1aNIlx
86 | sj1VGbGIHk+YNMJYHBlTBpn9hjDjXP3tZHtVHRzZN0rjpFV76ODpxatvKxryPl3X
87 | mAPMhTvwnxnQ2rNF7RBPC8H8qJVBbG98k9HCqHolbNiMhV2Iow2SKSuX57IjT0cY
88 | 7mxps5zU94dTyJARNZaP7nlGHv1qx2ihRqxksIWxFetQ+U1JrM/14aeFUtS23HM+
89 | MJQxyIWYaidDHJzHy1MiZBZ5dpC2hwqNSjLq/OoDEV2cAYHEMnLnXCiz7ZulUA==
90 | -----END RSA PRIVATE KEY-----`
91 |
92 | // default certificate
93 | var DefaultCertificate, _ = tls.X509KeyPair([]byte(CertPEM), []byte(KeyPEM))
94 |
--------------------------------------------------------------------------------
/cert/container.go:
--------------------------------------------------------------------------------
1 | package cert
2 |
3 | import "crypto/tls"
4 |
// Container is a store of TLS certificates keyed by host.
type Container interface {
	// Get returns the certificate stored for host.
	Get(host string) (*tls.Certificate, error)

	// Set stores cert as the certificate for host.
	Set(host string, cert *tls.Certificate) error
}
14 |
--------------------------------------------------------------------------------
/cert/mem_provider.go:
--------------------------------------------------------------------------------
1 | package cert
2 |
3 | import (
4 | "crypto/tls"
5 | "fmt"
6 | "strings"
7 | "sync"
8 | )
9 |
// DefaultMemProvider is a package-level MemProvider created at start-up.
var DefaultMemProvider = NewMemProvider()

// MemProvider is a simple in-memory certificate cache keyed by host.
// Both Get and Set take the lock, so a MemProvider may be shared between
// goroutines.
type MemProvider struct {
	cache map[string]*tls.Certificate
	rw    sync.RWMutex // guards cache
}

// NewMemProvider creates an empty MemProvider.
func NewMemProvider() *MemProvider {
	return &MemProvider{
		cache: make(map[string]*tls.Certificate),
	}
}

// Get returns the certificate cached for host. Surrounding white space in
// host is ignored. When nothing is cached for host, cert is nil and err is
// non-nil.
func (m *MemProvider) Get(host string) (cert *tls.Certificate, err error) {
	host = strings.TrimSpace(host)

	// Read under the shared lock: Set mutates the map, and an unguarded map
	// read that runs concurrently with that write is a data race.
	m.rw.RLock()
	cert, ok := m.cache[host]
	m.rw.RUnlock()

	if !ok {
		return nil, fmt.Errorf("cert not exist")
	}
	return cert, nil
}

// Set caches cert for host, replacing any previous entry. Surrounding white
// space in host is ignored. It always returns nil.
func (m *MemProvider) Set(host string, cert *tls.Certificate) error {
	host = strings.TrimSpace(host)

	m.rw.Lock()
	m.cache[host] = cert
	m.rw.Unlock()
	return nil
}
44 |
--------------------------------------------------------------------------------
/chunked.go:
--------------------------------------------------------------------------------
1 | // Taken from $GOROOT/src/pkg/net/http/chunked
2 | // needed to write https responses to client.
3 | package mps
4 |
5 | import (
6 | "io"
7 | "strconv"
8 | )
9 |
// newChunkedWriter wraps w in a writer that emits every Write in HTTP
// "chunked" wire format. Closing the returned writer sends the final
// zero-length chunk that marks the end of the stream.
//
// Normal applications do not need it: the http package adds chunking on its
// own when a handler sets no Content-Length header, so using it inside a
// handler would chunk twice, or chunk alongside a Content-Length, both of
// which are wrong.
func newChunkedWriter(w io.Writer) io.WriteCloser {
	return &chunkedWriter{Wire: w}
}

// chunkedWriter translates writes into HTTP chunked transfer encoding and
// forwards them to Wire.
type chunkedWriter struct {
	Wire io.Writer
}

// Write sends data to Wire as exactly one chunk: the hexadecimal length, CRLF,
// the payload, CRLF. Empty data is skipped because a zero-length chunk would
// read as end-of-stream. n is the number of payload bytes Wire accepted.
func (cw *chunkedWriter) Write(data []byte) (n int, err error) {
	size := len(data)
	if size == 0 {
		return 0, nil
	}

	header := strconv.FormatInt(int64(size), 16) + "\r\n"
	if _, err := io.WriteString(cw.Wire, header); err != nil {
		return 0, err
	}

	written, err := cw.Wire.Write(data)
	switch {
	case err != nil:
		return written, err
	case written != size:
		return written, io.ErrShortWrite
	}

	_, err = io.WriteString(cw.Wire, "\r\n")
	return written, err
}

// Close writes the terminating zero-length chunk line to Wire.
func (cw *chunkedWriter) Close() error {
	_, err := io.WriteString(cw.Wire, "0\r\n")
	return err
}
58 |
--------------------------------------------------------------------------------
/context.go:
--------------------------------------------------------------------------------
1 | package mps
2 |
3 | import (
4 | "context"
5 | "crypto/tls"
6 | "errors"
7 | "net"
8 | "net/http"
9 | "time"
10 | )
11 |
// RequestNilErr is returned by Context.Next when the request reaching the end
// of the middleware chain is nil.
var RequestNilErr = errors.New("request is nil")

// MethodNotSupportErr is returned by Context.Next for CONNECT requests, for
// which no response is fetched.
var MethodNotSupportErr = errors.New("request method not support")

// RequestWebsocketUpgradeErr is returned by Context.Next when the request is
// a websocket upgrade.
var RequestWebsocketUpgradeErr = errors.New("websocket upgrade")
20 |
21 | // Context for the request
22 | // which contains Middleware, Transport, and other values
23 | type Context struct {
24 | // context.Context
25 | Context context.Context
26 |
27 | // Request context-dependent requests
28 | Request *http.Request
29 |
30 | // Response is associated with Request
31 | Response *http.Response
32 |
33 | // Transport is used for global HTTP requests, and it will be reused.
34 | Transport *http.Transport
35 |
36 | // In some cases it is not always necessary to remove the proxy headers.
37 | // For example, cascade proxy
38 | KeepProxyHeaders bool
39 |
40 | // In some cases it is not always necessary to reset the headers.
41 | KeepClientHeaders bool
42 |
43 | // KeepDestinationHeaders indicates the proxy should retain any headers
44 | // present in the http.Response before proxying
45 | KeepDestinationHeaders bool
46 |
47 | // middlewares ACTS on Request and Response.
48 | // It's going to be reused by the Context
49 | // mi is the index subscript of the middlewares traversal
50 | // the default value for the index is -1
51 | mi int
52 | middlewares []Middleware
53 | }
54 |
55 | // NewContext create http request Context
56 | func NewContext() *Context {
57 | return &Context{
58 | Context: context.Background(),
59 | // Cannot reuse one Transport because multiple proxy can collide with each other
60 | Transport: &http.Transport{
61 | DialContext: (&net.Dialer{
62 | Timeout: 15 * time.Second,
63 | KeepAlive: 30 * time.Second,
64 | DualStack: true,
65 | }).DialContext,
66 | MaxIdleConns: 100,
67 | IdleConnTimeout: 90 * time.Second,
68 | TLSHandshakeTimeout: 10 * time.Second,
69 | ExpectContinueTimeout: 1 * time.Second,
70 | TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
71 | Proxy: http.ProxyFromEnvironment,
72 | },
73 | Request: nil,
74 | Response: nil,
75 | KeepProxyHeaders: false,
76 | KeepClientHeaders: false,
77 | KeepDestinationHeaders: false,
78 | mi: -1,
79 | middlewares: make([]Middleware, 0),
80 | }
81 | }
82 |
83 | // Use registers an Middleware to proxy
84 | func (ctx *Context) Use(middleware ...Middleware) {
85 | if ctx.middlewares == nil {
86 | ctx.middlewares = make([]Middleware, 0)
87 | }
88 | ctx.middlewares = append(ctx.middlewares, middleware...)
89 | }
90 |
91 | // UseFunc registers an MiddlewareFunc to proxy
92 | func (ctx *Context) UseFunc(fns ...MiddlewareFunc) {
93 | if ctx.middlewares == nil {
94 | ctx.middlewares = make([]Middleware, 0)
95 | }
96 | for _, fn := range fns {
97 | ctx.middlewares = append(ctx.middlewares, fn)
98 | }
99 | }
100 |
101 | // Next to exec middlewares
102 | // Execute the next middleware as a linked list. "ctx.Next(req)"
103 | // eg:
104 | //
105 | // func Handle(req *http.Request, ctx *Context) (*http.Response, error) {
106 | // // You can do anything to modify the http.Request ...
107 | // resp, err := ctx.Next(req)
108 | // // You can do anything to modify the http.Response ...
109 | // return resp, err
110 | // }
111 | //
112 | // Alternatively, you can simply return the response without executing `ctx.Next()`,
113 | // which will interrupt subsequent middleware execution.
114 | func (ctx *Context) Next(req *http.Request) (*http.Response, error) {
115 | var (
116 | total = len(ctx.middlewares)
117 | err error
118 | )
119 | ctx.mi++
120 | if ctx.mi >= total {
121 | ctx.mi = -1
122 | // Final request coverage
123 | ctx.Request = req
124 | if req == nil {
125 | return nil, RequestNilErr
126 | }
127 | // To make the middleware available to the tunnel proxy,
128 | // no response is obtained when the request method is equal to Connect
129 | if req.Method == http.MethodConnect {
130 | return nil, MethodNotSupportErr
131 | }
132 | // Is it a Websocket requests
133 | if isWebSocketRequest(req) {
134 | return nil, RequestWebsocketUpgradeErr
135 | }
136 |
137 | return func() (*http.Response, error) {
138 | // explicitly discard request body to avoid data races in certain RoundTripper implementations
139 | // see https://github.com/golang/go/issues/61596#issuecomment-1652345131
140 | defer req.Body.Close()
141 | return ctx.RoundTrip(req)
142 | }()
143 | }
144 |
145 | middleware := ctx.middlewares[ctx.mi]
146 | ctx.Response, err = middleware.Handle(req, ctx)
147 | ctx.mi = -1
148 | return ctx.Response, err
149 | }
150 |
151 | // RoundTrip implements the RoundTripper interface.
152 | //
153 | // For higher-level HTTP client support (such as handling of cookies
154 | // and redirects), see Get, Post, and the Client type.
155 | //
156 | // Like the RoundTripper interface, the error types returned
157 | // by RoundTrip are unspecified.
158 | func (ctx *Context) RoundTrip(req *http.Request) (*http.Response, error) {
159 | // These Headers must be reset when a client Request is issued to reuse a Request
160 | if !ctx.KeepClientHeaders {
161 | ResetClientHeaders(req)
162 | }
163 |
164 | // In some cases it is not always necessary to remove the Proxy Header.
165 | // For example, cascade proxy
166 | if !ctx.KeepProxyHeaders {
167 | RemoveProxyHeaders(req)
168 | }
169 |
170 | if ctx.Transport != nil {
171 | return ctx.Transport.RoundTrip(req)
172 | }
173 | return DefaultTransport.RoundTrip(req)
174 | }
175 |
176 | // WithRequest get the Context of the request
177 | func (ctx *Context) WithRequest(req *http.Request) *Context {
178 | return &Context{
179 | Context: ctx.Context,
180 | Request: req,
181 | Response: nil,
182 | KeepProxyHeaders: ctx.KeepProxyHeaders,
183 | KeepClientHeaders: ctx.KeepClientHeaders,
184 | KeepDestinationHeaders: ctx.KeepDestinationHeaders,
185 | Transport: ctx.Transport,
186 | mi: -1,
187 | middlewares: ctx.middlewares,
188 | }
189 | }
190 |
// ResetClientHeaders prepares a request for being issued again with a client:
// the fields below must be reset when a Request is reused that way.
func ResetClientHeaders(r *http.Request) {
	// If no Accept-Encoding header exists, Transport will add the headers it
	// can accept and would wrap the response body with the relevant reader.
	r.Header.Del("Accept-Encoding")

	// RequestURI must be empty when serving a request with the client.
	r.RequestURI = ""
}
199 |
// RemoveProxyHeaders strips the hop-by-hop headers that must not be sent on
// to the backend: Proxy-Connection, Proxy-Authenticate, Proxy-Authorization
// and Connection.
//
// As of RFC 7230, hop-by-hop headers are required to appear in the Connection
// header field. The ones removed here are those defined by the obsoleted
// RFC 2616 (section 13.5.1, https://www.ietf.org/rfc/rfc2616.txt) and are
// handled for backward compatibility.
func RemoveProxyHeaders(r *http.Request) {
	for _, name := range [...]string{
		"Proxy-Connection",
		"Proxy-Authenticate",
		"Proxy-Authorization",
	} {
		r.Header.Del(name)
	}

	// RFC 2616, 14.10 Connection: the Connection general-header field allows
	// the sender to specify options that are desired for that particular
	// connection and MUST NOT be communicated by proxies over further
	// connections.
	//
	// When the server reads a request it sets req.Close to true if the
	// "Connection" header contains "close"
	// (https://github.com/golang/go/blob/master/src/net/http/request.go#L1080).
	// Later, transfer.go adds "Connection: close" back when req.Close is true
	// (https://github.com/golang/go/blob/master/src/net/http/transfer.go#L275),
	// so the flag has to be cleared together with the header.
	if r.Header.Get("Connection") == "close" {
		r.Close = false
	}
	r.Header.Del("Connection")
}
229 |
--------------------------------------------------------------------------------
/counter_encryptor.go:
--------------------------------------------------------------------------------
1 | package mps
2 |
3 | import (
4 | "crypto/aes"
5 | "crypto/cipher"
6 | "crypto/ecdsa"
7 | "crypto/rsa"
8 | "crypto/sha256"
9 | "crypto/x509"
10 | "errors"
11 | )
12 |
// CounterEncryptorRand is a deterministic byte stream produced by AES in
// counter mode: each refill encrypts the current counter block into rand and
// then increments the counter. It implements io.Reader.
type CounterEncryptorRand struct {
	cipher  cipher.Block
	counter []byte // current counter block, incremented after every refill
	rand    []byte // most recently generated block
	ix      int    // index of the first byte of rand not handed out yet
}

// NewCounterEncryptorRand builds a stream whose AES key is derived from the
// private key and whose initial counter is derived from seed. The same key
// and seed always yield the same stream.
//
// key must be an *rsa.PrivateKey or an *ecdsa.PrivateKey; any other type
// yields an error. A nil seed leaves the initial counter all zero.
func NewCounterEncryptorRand(key interface{}, seed []byte) (r CounterEncryptorRand, err error) {
	var keyBytes []byte
	switch key := key.(type) {
	case *rsa.PrivateKey:
		keyBytes = x509.MarshalPKCS1PrivateKey(key)
	case *ecdsa.PrivateKey:
		if keyBytes, err = x509.MarshalECPrivateKey(key); err != nil {
			return CounterEncryptorRand{}, err
		}
	default:
		return CounterEncryptorRand{}, errors.New("only RSA and ECDSA keys supported")
	}

	// Hash the whole encoding. hash.Hash.Sum(b) only appends the digest to b,
	// so slicing its result would take the leading bytes of the DER encoding
	// itself (largely a fixed header) rather than of a digest.
	keyDigest := sha256.Sum256(keyBytes)
	block, err := aes.NewCipher(keyDigest[:aes.BlockSize])
	if err != nil {
		return CounterEncryptorRand{}, err
	}

	r.cipher = block
	r.counter = make([]byte, block.BlockSize())
	if seed != nil {
		// Same reasoning as for the key: digest the full seed so that seeds
		// sharing a common prefix still start from different counters.
		seedDigest := sha256.Sum256(seed)
		copy(r.counter, seedDigest[:])
	}
	r.rand = make([]byte, block.BlockSize())
	// Mark the buffer as exhausted so that the first Read triggers a refill.
	r.ix = len(r.rand)
	return r, nil
}

// Seed overwrites the counter with b. It panics when len(b) differs from the
// counter size.
func (c *CounterEncryptorRand) Seed(b []byte) {
	if len(b) != len(c.counter) {
		panic("SetCounter: wrong counter size")
	}
	copy(c.counter, b)
}

// refill encrypts the counter into rand, increments the counter and rewinds
// the read index.
func (c *CounterEncryptorRand) refill() {
	c.cipher.Encrypt(c.rand, c.counter)
	// Increment with carry, starting at byte 0: stop at the first byte that
	// does not wrap around to zero.
	for i := range c.counter {
		c.counter[i]++
		if c.counter[i] != 0 {
			break
		}
	}
	c.ix = 0
}

// Read copies up to len(b) bytes of the stream into b. At most the remainder
// of the current block is returned per call, so n may be smaller than len(b);
// err is always nil.
func (c *CounterEncryptorRand) Read(b []byte) (n int, err error) {
	if c.ix == len(c.rand) {
		c.refill()
	}
	n = len(c.rand) - c.ix
	if n > len(b) {
		n = len(b)
	}
	copy(b, c.rand[c.ix:c.ix+n])
	c.ix += n
	return n, nil
}
74 |
--------------------------------------------------------------------------------
/counter_encryptor_test.go:
--------------------------------------------------------------------------------
1 | package mps
2 |
3 | import (
4 | "bytes"
5 | "crypto/rsa"
6 | "encoding/binary"
7 | "io"
8 | "math"
9 | "math/rand"
10 | "testing"
11 | )
12 |
// RandSeedReader is an io.Reader that yields bytes drawn from a math/rand
// generator.
type RandSeedReader struct {
	r rand.Rand
}

// Read fills all of b with the low byte of successive r.Int() values. It
// always returns len(b) and a nil error.
func (r *RandSeedReader) Read(b []byte) (n int, err error) {
	for i := 0; i < len(b); i++ {
		b[i] = byte(r.r.Int() & 0xFF)
	}
	return len(b), nil
}
23 |
24 | func TestCounterEncDifferentConsecutive(t *testing.T) {
25 | k, err := rsa.GenerateKey(&RandSeedReader{*rand.New(rand.NewSource(0xFF43109))}, 128)
26 | fatalOnErr(err, "rsa.GenerateKey", t)
27 | c, err := NewCounterEncryptorRand(k, []byte("the quick brown fox run over the lazy dog"))
28 | fatalOnErr(err, "NewCounterEncryptorRandFromKey", t)
29 | for i := 0; i < 100*1000; i++ {
30 | var a, b int64
31 | binary.Read(&c, binary.BigEndian, &a)
32 | binary.Read(&c, binary.BigEndian, &b)
33 | if a == b {
34 | t.Fatal("two consecutive equal int64", a, b)
35 | }
36 | }
37 | }
38 |
39 | func TestCounterEncIdenticalStreams(t *testing.T) {
40 | k, err := rsa.GenerateKey(&RandSeedReader{*rand.New(rand.NewSource(0xFF43109))}, 128)
41 | fatalOnErr(err, "rsa.GenerateKey", t)
42 | c1, err := NewCounterEncryptorRand(k, []byte("the quick brown fox run over the lazy dog"))
43 | fatalOnErr(err, "NewCounterEncryptorRandFromKey", t)
44 | c2, err := NewCounterEncryptorRand(k, []byte("the quick brown fox run over the lazy dog"))
45 | fatalOnErr(err, "NewCounterEncryptorRandFromKey", t)
46 | nout := 1000
47 | out1, out2 := make([]byte, nout), make([]byte, nout)
48 | io.ReadFull(&c1, out1)
49 | tmp := out2[:]
50 | rand.Seed(0xFF43109)
51 | for len(tmp) > 0 {
52 | n := 1 + rand.Intn(256)
53 | if n > len(tmp) {
54 | n = len(tmp)
55 | }
56 | n, err := c2.Read(tmp[:n])
57 | fatalOnErr(err, "CounterEncryptorRand.Read", t)
58 | tmp = tmp[n:]
59 | }
60 | if !bytes.Equal(out1, out2) {
61 | t.Error("identical CSPRNG does not produce the same output")
62 | }
63 | }
64 |
// stddev returns the sample standard deviation (n-1 denominator) of data,
// computed from the running sum and sum of squares.
func stddev(data []int) float64 {
	var total, totalSquares float64
	for i := range data {
		v := float64(data[i])
		total += v
		totalSquares += v * v
	}
	count := float64(len(data))
	variance := (totalSquares - ((total * total) / count)) / (count - 1)
	return math.Sqrt(variance)
}
75 |
76 | func TestCounterEncStreamHistogram(t *testing.T) {
77 | k, err := rsa.GenerateKey(&RandSeedReader{*rand.New(rand.NewSource(0xFF43109))}, 128)
78 | fatalOnErr(err, "rsa.GenerateKey", t)
79 | c, err := NewCounterEncryptorRand(k, []byte("the quick brown fox run over the lazy dog"))
80 | fatalOnErr(err, "NewCounterEncryptorRandFromKey", t)
81 | nout := 100 * 1000
82 | out := make([]byte, nout)
83 | io.ReadFull(&c, out)
84 | refhist := make([]int, 512)
85 | for i := 0; i < nout; i++ {
86 | refhist[rand.Intn(256)]++
87 | }
88 | hist := make([]int, 512)
89 | for _, b := range out {
90 | hist[int(b)]++
91 | }
92 | refstddev, stddev := stddev(refhist), stddev(hist)
93 | // due to lack of time, I guestimate
94 | t.Logf("ref:%v - act:%v = %v", refstddev, stddev, math.Abs(refstddev-stddev))
95 | if math.Abs(refstddev-stddev) >= 1 {
96 | t.Errorf("stddev of ref histogram different than regular PRNG: %v %v", refstddev, stddev)
97 | }
98 | }
99 |
// fatalOnErr aborts the test with msg and err when err is non-nil; it does
// nothing otherwise.
func fatalOnErr(err error, msg string, t *testing.T) {
	if err == nil {
		return
	}
	t.Fatal(msg, err)
}
105 |
--------------------------------------------------------------------------------
/filter.go:
--------------------------------------------------------------------------------
1 | package mps
2 |
3 | import (
4 | "net/http"
5 | "regexp"
6 | "strings"
7 | )
8 |
// Filter decides whether a request is selected.
type Filter interface {
	// Match reports whether req is selected by the filter.
	Match(req *http.Request) bool
}

// FilterFunc adapts a plain function to the Filter interface.
type FilterFunc func(req *http.Request) bool

// Match calls f(req).
func (f FilterFunc) Match(req *http.Request) bool {
	return f(req)
}
21 |
22 | // FilterHostMatches for request.Host
23 | func FilterHostMatches(regexps ...*regexp.Regexp) Filter {
24 | return FilterFunc(func(req *http.Request) bool {
25 | for _, re := range regexps {
26 | if re.MatchString(req.Host) {
27 | return true
28 | }
29 | }
30 | return false
31 | })
32 | }
33 |
34 | // FilterHostIs returns a Filter, testing whether the host to which the request is directed to equal
35 | // to one of the given strings
36 | func FilterHostIs(hosts ...string) Filter {
37 | hostSet := make(map[string]bool)
38 | for _, h := range hosts {
39 | hostSet[h] = true
40 | }
41 | return FilterFunc(func(req *http.Request) bool {
42 | _, ok := hostSet[req.URL.Host]
43 | return ok
44 | })
45 | }
46 |
47 | // FilterUrlMatches returns a Filter testing whether the destination URL
48 | // of the request matches the given regexp, with or without prefix
49 | func FilterUrlMatches(re *regexp.Regexp) Filter {
50 | return FilterFunc(func(req *http.Request) bool {
51 | return re.MatchString(req.URL.Path) ||
52 | re.MatchString(req.URL.Host+req.URL.Path)
53 | })
54 | }
55 |
56 | // FilterUrlHasPrefix returns a Filter checking wether the destination URL the proxy client has requested
57 | // has the given prefix, with or without the host.
58 | // For example FilterUrlHasPrefix("host/x") will match requests of the form 'GET host/x', and will match
59 | // requests to url 'http://host/x'
60 | func FilterUrlHasPrefix(prefix string) Filter {
61 | return FilterFunc(func(req *http.Request) bool {
62 | return strings.HasPrefix(req.URL.Path, prefix) ||
63 | strings.HasPrefix(req.URL.Host+req.URL.Path, prefix) ||
64 | strings.HasPrefix(req.URL.Scheme+req.URL.Host+req.URL.Path, prefix)
65 | })
66 | }
67 |
68 | // FilterUrlIs returns a Filter, testing whether or not the request URL is one of the given strings
69 | // with or without the host prefix.
70 | // FilterUrlIs("google.com/","foo") will match requests 'GET /' to 'google.com', requests `'GET google.com/' to
71 | // any host, and requests of the form 'GET foo'.
72 | func FilterUrlIs(urls ...string) Filter {
73 | urlSet := make(map[string]bool)
74 | for _, u := range urls {
75 | urlSet[u] = true
76 | }
77 | return FilterFunc(func(req *http.Request) bool {
78 | _, pathOk := urlSet[req.URL.Path]
79 | _, hostAndOk := urlSet[req.URL.Host+req.URL.Path]
80 | return pathOk || hostAndOk
81 | })
82 | }
83 |
--------------------------------------------------------------------------------
/filter_group.go:
--------------------------------------------------------------------------------
1 | package mps
2 |
3 | import "net/http"
4 |
// FilterGroup is the interface of types that expose a Handle method.
type FilterGroup interface {
	Handle()
}
8 |
9 | // ReqFilterGroup ReqCondition is a request filter group
10 | type ReqFilterGroup struct {
11 | ctx *Context
12 | filters []Filter
13 | }
14 |
15 | func (cond *ReqFilterGroup) DoFunc(fn func(req *http.Request, ctx *Context) (*http.Request, *http.Response)) {
16 | cond.Do(RequestHandleFunc(fn))
17 | }
18 |
19 | func (cond *ReqFilterGroup) Do(h RequestHandle) {
20 | cond.ctx.UseFunc(func(req *http.Request, ctx *Context) (*http.Response, error) {
21 | total := len(cond.filters)
22 | for i := 0; i < total; i++ {
23 | if !cond.filters[i].Match(req) {
24 | return ctx.Next(req)
25 | }
26 | }
27 |
28 | req, resp := h.HandleRequest(req, ctx)
29 | if resp != nil {
30 | return resp, nil
31 | }
32 |
33 | return ctx.Next(req)
34 | })
35 | }
36 |
37 | // RespFilterGroup ReqCondition is a response filter group
38 | type RespFilterGroup struct {
39 | ctx *Context
40 | filters []Filter
41 | }
42 |
43 | func (cond *RespFilterGroup) DoFunc(fn func(resp *http.Response, err error, ctx *Context) (*http.Response, error)) {
44 | cond.Do(ResponseHandleFunc(fn))
45 | }
46 |
47 | func (cond *RespFilterGroup) Do(h ResponseHandle) {
48 | cond.ctx.UseFunc(func(req *http.Request, ctx *Context) (*http.Response, error) {
49 | total := len(cond.filters)
50 | for i := 0; i < total; i++ {
51 | if !cond.filters[i].Match(req) {
52 | return ctx.Next(req)
53 | }
54 | }
55 | resp, err := ctx.Next(req)
56 | return h.HandleResponse(resp, err, ctx)
57 | })
58 | }
59 |
--------------------------------------------------------------------------------
/forward_handler.go:
--------------------------------------------------------------------------------
1 | package mps
2 |
3 | import (
4 | "bytes"
5 | "io"
6 | "net/http"
7 | "net/http/httputil"
8 | "strconv"
9 |
10 | "github.com/telanflow/mps/pool"
11 | )
12 |
13 | // ForwardHandler The forward proxy type. Implements http.Handler.
14 | type ForwardHandler struct {
15 | Ctx *Context
16 | BufferPool httputil.BufferPool
17 | }
18 |
19 | // NewForwardHandler Create a forward proxy
20 | func NewForwardHandler() *ForwardHandler {
21 | return &ForwardHandler{
22 | Ctx: NewContext(),
23 | BufferPool: pool.DefaultBuffer,
24 | }
25 | }
26 |
27 | // NewForwardHandlerWithContext Create a ForwardHandler with Context
28 | func NewForwardHandlerWithContext(ctx *Context) *ForwardHandler {
29 | return &ForwardHandler{
30 | Ctx: ctx,
31 | BufferPool: pool.DefaultBuffer,
32 | }
33 | }
34 |
35 | // Standard net/http function. You can use it alone
36 | func (forward *ForwardHandler) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
37 | // Copying a Context preserves the Transport, Middleware
38 | ctx := forward.Ctx.WithRequest(req)
39 | resp, err := ctx.Next(req)
40 | if err != nil {
41 | http.Error(rw, err.Error(), 502)
42 | return
43 | }
44 | defer resp.Body.Close()
45 |
46 | var (
47 | // Body buffer
48 | buffer = new(bytes.Buffer)
49 | // Body size
50 | bufferSize int64
51 | )
52 |
53 | buf := forward.buffer().Get()
54 | bufferSize, err = io.CopyBuffer(buffer, resp.Body, buf)
55 | forward.buffer().Put(buf)
56 | if err != nil {
57 | http.Error(rw, err.Error(), 502)
58 | return
59 | }
60 |
61 | resp.ContentLength = bufferSize
62 | resp.Header.Set("Content-Length", strconv.Itoa(int(bufferSize)))
63 | copyHeaders(rw.Header(), resp.Header, forward.Ctx.KeepDestinationHeaders)
64 | rw.WriteHeader(resp.StatusCode)
65 | _, err = buffer.WriteTo(rw)
66 | }
67 |
68 | // Use registers an Middleware to proxy
69 | func (forward *ForwardHandler) Use(middleware ...Middleware) {
70 | forward.Ctx.Use(middleware...)
71 | }
72 |
73 | // UseFunc registers an MiddlewareFunc to proxy
74 | func (forward *ForwardHandler) UseFunc(fus ...MiddlewareFunc) {
75 | forward.Ctx.UseFunc(fus...)
76 | }
77 |
78 | // OnRequest filter requests through Filters
79 | func (forward *ForwardHandler) OnRequest(filters ...Filter) *ReqFilterGroup {
80 | return &ReqFilterGroup{ctx: forward.Ctx, filters: filters}
81 | }
82 |
83 | // OnResponse filter response through Filters
84 | func (forward *ForwardHandler) OnResponse(filters ...Filter) *RespFilterGroup {
85 | return &RespFilterGroup{ctx: forward.Ctx, filters: filters}
86 | }
87 |
// Transport returns the http.Transport stored on the handler's Context.
func (forward *ForwardHandler) Transport() *http.Transport {
	return forward.Ctx.Transport
}
92 |
93 | // Get buffer pool
94 | func (forward *ForwardHandler) buffer() httputil.BufferPool {
95 | if forward.BufferPool != nil {
96 | return forward.BufferPool
97 | }
98 | return pool.DefaultBuffer
99 | }
100 |
--------------------------------------------------------------------------------
/forward_handler_test.go:
--------------------------------------------------------------------------------
1 | package mps
2 |
3 | import (
4 | "io"
5 | "net/http"
6 | "net/http/httptest"
7 | "net/url"
8 | "strconv"
9 | "testing"
10 |
11 | "github.com/stretchr/testify/assert"
12 | )
13 |
14 | func TestNewForwardHandler_ContentLength(t *testing.T) {
15 | srv := newTestServer()
16 | defer srv.Close()
17 |
18 | forwardHandler := NewForwardHandler()
19 | proxySrv := httptest.NewServer(forwardHandler)
20 | defer proxySrv.Close()
21 |
22 | resp, err := HttpGet(srv.URL, func(r *http.Request) (*url.URL, error) {
23 | return url.Parse(proxySrv.URL)
24 | })
25 | if err != nil {
26 | t.Fatal(err)
27 | }
28 | defer resp.Body.Close()
29 |
30 | body, _ := io.ReadAll(resp.Body)
31 | bodySize := len(body)
32 | contentLength, _ := strconv.Atoi(resp.Header.Get("Content-Length"))
33 |
34 | asserts := assert.New(t)
35 | asserts.Equal(resp.StatusCode, 200, "statusCode should be equal 200")
36 | asserts.Equal(bodySize, contentLength, "Content-Length should be equal "+strconv.Itoa(bodySize))
37 | asserts.Equal(int64(bodySize), resp.ContentLength)
38 | }
39 |
--------------------------------------------------------------------------------
/go.mod:
--------------------------------------------------------------------------------
1 | module github.com/telanflow/mps
2 |
3 | go 1.20
4 |
5 | require (
6 | github.com/gorilla/websocket v1.5.0
7 | github.com/stretchr/testify v1.8.4
8 | )
9 |
--------------------------------------------------------------------------------
/go.sum:
--------------------------------------------------------------------------------
1 | github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
2 | github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
3 | github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
4 | github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc=
5 | github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
6 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
7 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
8 | github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
9 | github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
10 | github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
11 | github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
12 | github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
13 | github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
14 | github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
15 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
16 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
17 | gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
18 | gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
19 | gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
20 |
--------------------------------------------------------------------------------
/handle.go:
--------------------------------------------------------------------------------
1 | package mps
2 |
3 | import "net/http"
4 |
// Handle combines RequestHandle and ResponseHandle: an implementation can
// process both the request and the response.
type Handle interface {
	RequestHandle
	ResponseHandle
}
9 |
// RequestHandle processes a request within a Context and returns a request
// and a response.
// NOTE(review): presumably a non-nil response short-circuits the request;
// confirm against the filter group implementation.
type RequestHandle interface {
	HandleRequest(req *http.Request, ctx *Context) (*http.Request, *http.Response)
}
13 |
// RequestHandleFunc is an adapter that lets an ordinary function be used as
// a RequestHandle.
type RequestHandleFunc func(req *http.Request, ctx *Context) (*http.Request, *http.Response)

// HandleRequest calls f(req, ctx), which makes RequestHandleFunc satisfy
// RequestHandle.
func (f RequestHandleFunc) HandleRequest(req *http.Request, ctx *Context) (*http.Request, *http.Response) {
	return f(req, ctx)
}
21 |
// ResponseHandle processes a response (or the error produced while
// obtaining it) within a Context and returns the response and error to use
// from then on.
type ResponseHandle interface {
	HandleResponse(resp *http.Response, err error, ctx *Context) (*http.Response, error)
}
25 |
// ResponseHandleFunc is an adapter that lets an ordinary function be used as
// a ResponseHandle.
type ResponseHandleFunc func(resp *http.Response, err error, ctx *Context) (*http.Response, error)

// HandleResponse calls f(resp, err, ctx), which makes ResponseHandleFunc
// satisfy ResponseHandle.
func (f ResponseHandleFunc) HandleResponse(resp *http.Response, err error, ctx *Context) (*http.Response, error) {
	return f(resp, err, ctx)
}
33 |
--------------------------------------------------------------------------------
/http_proxy.go:
--------------------------------------------------------------------------------
1 | package mps
2 |
3 | import (
4 | "errors"
5 | "fmt"
6 | "net"
7 | "net/http"
8 | )
9 |
// HttpProxy is the basic proxy type. It implements http.Handler and
// dispatches each request to one of three handlers (see ServeHTTP).
type HttpProxy struct {
	// HandleConnect serves CONNECT requests. NewHttpProxy sets it to a
	// TunnelHandler.
	HandleConnect http.Handler

	// HttpHandler serves requests whose URL is absolute (forward proxy
	// requests). NewHttpProxy sets it to a ForwardHandler.
	HttpHandler http.Handler

	// ReverseHandler serves requests whose URL is not absolute (reverse
	// proxy requests). NewHttpProxy sets it to a ReverseHandler.
	ReverseHandler http.Handler

	// Ctx is the Context on which middleware and filters are registered;
	// NewHttpProxy shares it with all three default handlers.
	Ctx *Context
}
24 |
25 | func NewHttpProxy() *HttpProxy {
26 | // default Context with Proxy
27 | ctx := NewContext()
28 | return &HttpProxy{
29 | Ctx: ctx,
30 | // default handles Connect method
31 | HandleConnect: &TunnelHandler{Ctx: ctx},
32 | // default handles HTTP request
33 | HttpHandler: &ForwardHandler{Ctx: ctx},
34 | // default Reverse proxy
35 | ReverseHandler: &ReverseHandler{Ctx: ctx},
36 | }
37 | }
38 |
39 | // Standard net/http function.
40 | func (proxy *HttpProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
41 | if req.Method == http.MethodConnect {
42 | proxy.HandleConnect.ServeHTTP(rw, req)
43 | return
44 | }
45 |
46 | // reverse proxy http request for example:
47 | // GET / HTTP/1.1
48 | // Host: www.example.com
49 | // Connection: keep-alive
50 | //
51 | // forward proxy http request for example :
52 | // GET http://www.example.com/ HTTP/1.1
53 | // Host: www.example.com
54 | // Proxy-Connection: keep-alive
55 | //
56 | // Determines whether the path is absolute
57 | if !req.URL.IsAbs() {
58 | proxy.ReverseHandler.ServeHTTP(rw, req)
59 | } else {
60 | proxy.HttpHandler.ServeHTTP(rw, req)
61 | }
62 | }
63 |
// Use registers one or more Middleware by forwarding them to the proxy's
// Context (proxy.Ctx.Use).
func (proxy *HttpProxy) Use(middleware ...Middleware) {
	proxy.Ctx.Use(middleware...)
}
68 |
// UseFunc registers one or more MiddlewareFunc by forwarding them to the
// proxy's Context (proxy.Ctx.UseFunc).
func (proxy *HttpProxy) UseFunc(fus ...MiddlewareFunc) {
	proxy.Ctx.UseFunc(fus...)
}
73 |
74 | // OnRequest filter requests through Filters
75 | func (proxy *HttpProxy) OnRequest(filters ...Filter) *ReqFilterGroup {
76 | return &ReqFilterGroup{ctx: proxy.Ctx, filters: filters}
77 | }
78 |
79 | // OnResponse filter response through Filters
80 | func (proxy *HttpProxy) OnResponse(filters ...Filter) *RespFilterGroup {
81 | return &RespFilterGroup{ctx: proxy.Ctx, filters: filters}
82 | }
83 |
// Transport returns the http.Transport stored on the proxy's Context.
func (proxy *HttpProxy) Transport() *http.Transport {
	return proxy.Ctx.Transport
}
88 |
// hijacker takes over the client connection behind rw.
//
// It returns an error when rw does not implement http.Hijacker or when the
// hijack itself fails. In the latter case the underlying error is wrapped,
// so callers can inspect it with errors.Is / errors.As.
func hijacker(rw http.ResponseWriter) (net.Conn, error) {
	hij, ok := rw.(http.Hijacker)
	if !ok {
		return nil, errors.New("not a hijacker")
	}

	conn, _, err := hij.Hijack()
	if err != nil {
		return nil, fmt.Errorf("cannot hijack connection %w", err)
	}
	return conn, nil
}
103 |
// copyHeaders appends every header value of src to dst.
//
// When keepDestHeaders is false, dst is emptied first. Entries are removed
// with the built-in delete rather than Header.Del: Del canonicalizes the key
// before deleting, so an entry stored under a non-canonical key (for example
// "x-custom") would otherwise survive.
func copyHeaders(dst, src http.Header, keepDestHeaders bool) {
	if !keepDestHeaders {
		for k := range dst {
			delete(dst, k)
		}
	}
	for k, vs := range src {
		for _, v := range vs {
			dst.Add(k, v)
		}
	}
}
116 |
--------------------------------------------------------------------------------
/http_proxy_test.go:
--------------------------------------------------------------------------------
1 | package mps
2 |
3 | import (
4 | "bytes"
5 | "io"
6 | "net/http"
7 | "net/http/httptest"
8 | "net/url"
9 | "testing"
10 |
11 | "github.com/stretchr/testify/assert"
12 | )
13 |
// newTestServer starts an httptest server that answers every request with
// the value of the "text" query parameter, or "hello world" when the
// parameter is absent or empty. It also sets a Server response header.
func newTestServer() *httptest.Server {
	echo := func(rw http.ResponseWriter, req *http.Request) {
		payload := "hello world"
		if custom := req.URL.Query().Get("text"); custom != "" {
			payload = custom
		}

		rw.Header().Set("Server", "MPS proxy server")
		_, _ = rw.Write([]byte(payload))
	}
	return httptest.NewServer(http.HandlerFunc(echo))
}
26 |
// HttpGet issues a GET for rawurl through a client whose transport uses the
// given proxy function (nil means no proxy).
//
// A dedicated http.Client is used so that http.DefaultClient is not mutated
// as a side effect, and an invalid rawurl is returned as an error instead of
// being ignored.
func HttpGet(rawurl string, proxy func(r *http.Request) (*url.URL, error)) (*http.Response, error) {
	req, err := http.NewRequest(http.MethodGet, rawurl, nil)
	if err != nil {
		return nil, err
	}
	client := &http.Client{
		Transport: &http.Transport{
			Proxy: proxy,
		},
	}
	return client.Do(req)
}
34 |
35 | func TestNewHttpProxy(t *testing.T) {
36 | srv := newTestServer()
37 | defer srv.Close()
38 |
39 | proxy := NewHttpProxy()
40 | proxySrv := httptest.NewServer(proxy)
41 | defer proxySrv.Close()
42 |
43 | resp, err := HttpGet(srv.URL, func(r *http.Request) (*url.URL, error) {
44 | return url.Parse(proxySrv.URL)
45 | })
46 | if err != nil {
47 | t.Fatal(err)
48 | }
49 |
50 | body, _ := io.ReadAll(resp.Body)
51 | resp.Body.Close()
52 |
53 | asserts := assert.New(t)
54 | asserts.Equal(resp.StatusCode, 200, "statusCode should be equal 200")
55 | asserts.Equal(int64(len(body)), resp.ContentLength)
56 | }
57 |
58 | func TestMiddlewareFunc(t *testing.T) {
59 | // target server
60 | srv := newTestServer()
61 | defer srv.Close()
62 |
63 | // proxy server
64 | proxy := NewHttpProxy()
65 |
66 | // use Middleware
67 | proxy.UseFunc(func(req *http.Request, ctx *Context) (*http.Response, error) {
68 | resp, err := ctx.Next(req)
69 | if err != nil {
70 | return nil, err
71 | }
72 |
73 | var buf bytes.Buffer
74 | buf.WriteString("middleware")
75 | resp.Body = io.NopCloser(&buf)
76 |
77 | //
78 | // You have to reset Content-Length, if you change the Body.
79 | //resp.ContentLength = int64(buf.Len())
80 | //resp.Header.Set("Content-Length", strconv.Itoa(buf.Len()))
81 |
82 | return resp, nil
83 | })
84 | proxySrv := httptest.NewServer(proxy)
85 | defer proxySrv.Close()
86 |
87 | // send request
88 | resp, err := HttpGet(srv.URL, func(r *http.Request) (*url.URL, error) {
89 | return url.Parse(proxySrv.URL)
90 | })
91 | if err != nil {
92 | t.Fatal(err)
93 | }
94 |
95 | body, _ := io.ReadAll(resp.Body)
96 | resp.Body.Close()
97 |
98 | asserts := assert.New(t)
99 | asserts.Equal(resp.StatusCode, 200)
100 | asserts.Equal(int64(len(body)), resp.ContentLength)
101 | asserts.Equal(string(body), "middleware")
102 | }
103 |
--------------------------------------------------------------------------------
/middleware.go:
--------------------------------------------------------------------------------
1 | package mps
2 |
3 | import "net/http"
4 |
// Middleware intercepts a request passing through the proxy server and may
// modify the request, the response, or both.
type Middleware interface {
	// Handle runs this middleware. Calling ctx.Next(req) hands the request
	// on to the rest of the chain, e.g.:
	//
	//	func Handle(req *http.Request, ctx *Context) (*http.Response, error) {
	//		// You can do anything to modify the http.Request ...
	//		resp, err := ctx.Next(req)
	//		// You can do anything to modify the http.Response ...
	//		return resp, err
	//	}
	//
	// Alternatively, return a response without calling ctx.Next, which
	// interrupts the execution of the remaining middleware.
	Handle(req *http.Request, ctx *Context) (*http.Response, error)
}
20 |
// MiddlewareFunc is an adapter that lets an ordinary function be used as a
// Middleware.
type MiddlewareFunc func(req *http.Request, ctx *Context) (*http.Response, error)

// Handle calls f(req, ctx), which makes MiddlewareFunc satisfy Middleware.
func (f MiddlewareFunc) Handle(req *http.Request, ctx *Context) (*http.Response, error) {
	return f(req, ctx)
}
28 |
--------------------------------------------------------------------------------
/middleware/basicAuth.go:
--------------------------------------------------------------------------------
1 | package middleware
2 |
3 | import (
4 | "bytes"
5 | "encoding/base64"
6 | "io"
7 | "net/http"
8 | "strings"
9 |
10 | "github.com/telanflow/mps"
11 | )
12 |
// proxyAuthorization is the name of the request header that carries the
// client's credentials for the proxy.
const proxyAuthorization = "Proxy-Authorization"
15 |
16 | // BasicAuth returns a HTTP Basic Authentication middleware for requests
17 | // You probably want to use mps.BasicAuth(proxy) to enable authentication for all proxy activities
18 | func BasicAuth(realm string, fn func(username, password string) bool) mps.MiddlewareFunc {
19 | return func(req *http.Request, ctx *mps.Context) (*http.Response, error) {
20 | auth := req.Header.Get(proxyAuthorization)
21 | if auth == "" {
22 | return BasicUnauthorized(req, realm), nil
23 | }
24 | // parses an Basic Authentication string.
25 | usr, pwd, ok := parseBasicAuth(auth)
26 | if !ok {
27 | return BasicUnauthorized(req, realm), nil
28 | }
29 | if !fn(usr, pwd) {
30 | return BasicUnauthorized(req, realm), nil
31 | }
32 | // Authorization passed
33 | return ctx.Next(req)
34 | }
35 | }
36 |
37 | // SetBasicAuth sets the request's Authorization header to use HTTP
38 | // Basic Authentication with the provided username and password.
39 | //
40 | // With HTTP Basic Authentication the provided username and password
41 | // are not encrypted.
42 | //
43 | // Some protocols may impose additional requirements on pre-escaping the
44 | // username and password. For instance, when used with OAuth2, both arguments
45 | // must be URL encoded first with url.QueryEscape.
46 | func SetBasicAuth(req *http.Request, username, password string) {
47 | req.Header.Set(proxyAuthorization, "Basic "+basicAuth(username, password))
48 | }
49 |
// basicAuth returns the base64 encoding of "username:password".
//
// See 2 (end of page 4) https://www.ietf.org/rfc/rfc2617.txt
// "To receive authorization, the client sends the userid and password,
// separated by a single colon (":") character, within a base64
// encoded string in the credentials."
// It is not meant to be urlencoded.
func basicAuth(username, password string) string {
	pair := []byte(username + ":" + password)
	return base64.StdEncoding.EncodeToString(pair)
}
59 |
// parseBasicAuth parses an HTTP Basic Authentication string.
// "Basic QWxhZGRpbjpvcGVuIHNlc2FtZQ==" returns ("Aladdin", "open sesame", true).
//
// ok is false, and both strings are empty, when the scheme is not "Basic"
// (compared case-insensitively), the payload is not valid base64, or the
// decoded payload contains no colon.
func parseBasicAuth(auth string) (username, password string, ok bool) {
	const scheme = "Basic "
	if len(auth) < len(scheme) {
		return "", "", false
	}
	// Case insensitive prefix match. See Issue 22736.
	if !strings.EqualFold(auth[:len(scheme)], scheme) {
		return "", "", false
	}

	decoded, err := base64.StdEncoding.DecodeString(auth[len(scheme):])
	if err != nil {
		return "", "", false
	}

	// Split at the first colon only: the password may itself contain colons.
	pair := string(decoded)
	sep := strings.IndexByte(pair, ':')
	if sep < 0 {
		return "", "", false
	}
	return pair[:sep], pair[sep+1:], true
}
79 |
// BasicUnauthorized builds the "407 Proxy Authentication Required" response
// that challenges the client for Basic credentials in the given realm.
//
// The realm is sent as a quoted-string (Basic realm="..."), which is the
// form the HTTP authentication framework requires senders to generate. A
// realm that the caller has already wrapped in double quotes is used as is;
// otherwise backslashes and double quotes inside it are escaped.
func BasicUnauthorized(req *http.Request, realm string) *http.Response {
	const unauthorizedMsg = "407 Proxy Authentication Required"

	alreadyQuoted := len(realm) >= 2 &&
		strings.HasPrefix(realm, `"`) && strings.HasSuffix(realm, `"`)
	if !alreadyQuoted {
		escaped := strings.NewReplacer(`\`, `\\`, `"`, `\"`).Replace(realm)
		realm = `"` + escaped + `"`
	}

	return &http.Response{
		StatusCode: http.StatusProxyAuthRequired,
		ProtoMajor: 1,
		ProtoMinor: 1,
		Request:    req,
		Header: http.Header{
			"Proxy-Authenticate": []string{"Basic realm=" + realm},
			"Proxy-Connection":   []string{"close"},
		},
		Body:          io.NopCloser(bytes.NewBuffer([]byte(unauthorizedMsg))),
		ContentLength: int64(len(unauthorizedMsg)),
	}
}
96 |
--------------------------------------------------------------------------------
/middleware/singleHostReverseProxy.go:
--------------------------------------------------------------------------------
1 | package middleware
2 |
3 | import (
4 | "net/http"
5 | "net/url"
6 | "strings"
7 |
8 | "github.com/telanflow/mps"
9 | )
10 |
11 | // SingleHostReverseProxy returns a mps.Middleware
12 | // URLs to the scheme, host, and base path provided in target. If the
13 | // target's path is "/base" and the incoming request was for "/dir",
14 | // the target request will be for /base/dir.
15 | // SingleHostReverseProxy does not rewrite the Host header.
16 | // To rewrite Host headers, use ReverseProxy directly with a custom
17 | // Director policy.
18 | func SingleHostReverseProxy(target *url.URL) mps.MiddlewareFunc {
19 | targetQuery := target.RawQuery
20 | return func(req *http.Request, ctx *mps.Context) (*http.Response, error) {
21 | // changed request Host
22 | req.Host = target.Host
23 | // changed request URL
24 | req.URL.Scheme = target.Scheme
25 | req.URL.Host = target.Host
26 | req.URL.Path = singleJoiningSlash(target.Path, req.URL.Path)
27 | if targetQuery == "" || req.URL.RawQuery == "" {
28 | req.URL.RawQuery = targetQuery + req.URL.RawQuery
29 | } else {
30 | req.URL.RawQuery = targetQuery + "&" + req.URL.RawQuery
31 | }
32 | if _, ok := req.Header["User-Agent"]; !ok {
33 | // explicitly disable User-Agent so it's not set to default value
34 | req.Header.Set("User-Agent", "")
35 | }
36 | return ctx.Next(req)
37 | }
38 | }
39 |
// singleJoiningSlash concatenates a and b so that exactly one "/" separates
// them, regardless of whether a ends with a slash or b starts with one.
func singleJoiningSlash(a, b string) string {
	endsWithSlash := strings.HasSuffix(a, "/")
	startsWithSlash := strings.HasPrefix(b, "/")

	if endsWithSlash != startsWithSlash {
		// exactly one side already provides the separator
		return a + b
	}
	if endsWithSlash {
		// both provide one: drop the duplicate
		return a + b[1:]
	}
	// neither provides one
	return a + "/" + b
}
51 |
--------------------------------------------------------------------------------
/mitm_handler.go:
--------------------------------------------------------------------------------
1 | package mps
2 |
3 | import (
4 | "bufio"
5 | "bytes"
6 | "crypto"
7 | "crypto/ecdsa"
8 | "crypto/elliptic"
9 | "crypto/rsa"
10 | "crypto/sha1"
11 | "crypto/tls"
12 | "crypto/x509"
13 | "crypto/x509/pkix"
14 | "errors"
15 | "fmt"
16 | "io"
17 | "math/big"
18 | "net"
19 | "net/http"
20 | "net/http/httputil"
21 | "net/url"
22 | "regexp"
23 | "sort"
24 | "strconv"
25 | "strings"
26 | "time"
27 |
28 | "github.com/telanflow/mps/cert"
29 | "github.com/telanflow/mps/pool"
30 | )
31 |
var (
	// HttpMitmOk is the raw reply written to the client connection before
	// the intercepted TLS session starts.
	HttpMitmOk = []byte("HTTP/1.0 200 Connection Established\r\n\r\n")
	// httpsRegexp matches strings that begin with "https://".
	httpsRegexp = regexp.MustCompile("^https://")
)
36 |
// MitmHandler is the man-in-the-middle proxy type. It implements
// http.Handler.
type MitmHandler struct {
	// Ctx holds the middleware chain and transport; a copy bound to each
	// request is created with Ctx.WithRequest.
	Ctx *Context
	// BufferPool supplies scratch buffers for body copies. When nil,
	// pool.DefaultBuffer is used.
	BufferPool httputil.BufferPool
	// Certificate is the CA certificate used to sign per-host certificates.
	Certificate tls.Certificate
	// CertContainer is certificate storage container. When nil,
	// cert.DefaultMemProvider is used.
	CertContainer cert.Container
}
45 |
46 | // NewMitmHandler Create a mitmHandler, use default cert.
47 | func NewMitmHandler() *MitmHandler {
48 | return &MitmHandler{
49 | Ctx: NewContext(),
50 | BufferPool: pool.DefaultBuffer,
51 | Certificate: cert.DefaultCertificate,
52 | CertContainer: cert.NewMemProvider(),
53 | }
54 | }
55 |
56 | // NewMitmHandlerWithContext Create a MitmHandler, use default cert.
57 | func NewMitmHandlerWithContext(ctx *Context) *MitmHandler {
58 | return &MitmHandler{
59 | Ctx: ctx,
60 | BufferPool: pool.DefaultBuffer,
61 | Certificate: cert.DefaultCertificate,
62 | CertContainer: cert.NewMemProvider(),
63 | }
64 | }
65 |
66 | // NewMitmHandlerWithCert Create a MitmHandler with cert pem block
67 | func NewMitmHandlerWithCert(ctx *Context, certPEMBlock, keyPEMBlock []byte) (*MitmHandler, error) {
68 | certificate, err := tls.X509KeyPair(certPEMBlock, keyPEMBlock)
69 | if err != nil {
70 | return nil, err
71 | }
72 | return &MitmHandler{
73 | Ctx: ctx,
74 | BufferPool: pool.DefaultBuffer,
75 | Certificate: certificate,
76 | CertContainer: cert.NewMemProvider(),
77 | }, nil
78 | }
79 |
80 | // NewMitmHandlerWithCertFile Create a MitmHandler with cert file
81 | func NewMitmHandlerWithCertFile(ctx *Context, certFile, keyFile string) (*MitmHandler, error) {
82 | certificate, err := tls.LoadX509KeyPair(certFile, keyFile)
83 | if err != nil {
84 | return nil, err
85 | }
86 | return &MitmHandler{
87 | Ctx: ctx,
88 | BufferPool: pool.DefaultBuffer,
89 | Certificate: certificate,
90 | CertContainer: cert.NewMemProvider(),
91 | }, nil
92 | }
93 |
94 | // Standard net/http function. You can use it alone
95 | func (mitm *MitmHandler) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
96 | // execution middleware
97 | ctx := mitm.Ctx.WithRequest(req)
98 | resp, err := ctx.Next(req)
99 | if err != nil && !errors.Is(err, MethodNotSupportErr) {
100 | if resp != nil {
101 | copyHeaders(rw.Header(), resp.Header, mitm.Ctx.KeepDestinationHeaders)
102 | rw.WriteHeader(resp.StatusCode)
103 | buf := mitm.buffer().Get()
104 | _, err = io.CopyBuffer(rw, resp.Body, buf)
105 | mitm.buffer().Put(buf)
106 | }
107 | return
108 | }
109 |
110 | // get hijacker connection
111 | clientConn, err := hijacker(rw)
112 | if err != nil {
113 | http.Error(rw, err.Error(), 502)
114 | return
115 | }
116 |
117 | // this goes in a separate goroutine, so that the net/http server won't think we're
118 | // still handling the request even after hijacking the connection. Those HTTP CONNECT
119 | // request can take forever, and the server will be stuck when "closed".
120 | // TODO: Allow Server.Close() mechanism to shut down this connection as nicely as possible
121 | tlsConfig, err := mitm.TLSConfigFromCA(req.URL.Host)
122 | if err != nil {
123 | ConnError(clientConn)
124 | return
125 | }
126 |
127 | _, _ = clientConn.Write(HttpMitmOk)
128 |
129 | // data transmit
130 | go mitm.transmit(clientConn, req, tlsConfig)
131 | }
132 |
133 | func (mitm *MitmHandler) transmit(clientConn net.Conn, originalReq *http.Request, tlsConfig *tls.Config) {
134 | // TODO: cache connections to the remote website
135 | rawClientTls := tls.Server(clientConn, tlsConfig)
136 | if err := rawClientTls.Handshake(); err != nil {
137 | ConnError(clientConn)
138 | _ = rawClientTls.Close()
139 | return
140 | }
141 | defer rawClientTls.Close()
142 |
143 | clientTlsReader := bufio.NewReader(rawClientTls)
144 | for !isEof(clientTlsReader) {
145 | req, err := http.ReadRequest(clientTlsReader)
146 | if err != nil {
147 | break
148 | }
149 |
150 | // since we're converting the request, need to carry over the original connecting IP as well
151 | req.RemoteAddr = originalReq.RemoteAddr
152 |
153 | if !httpsRegexp.MatchString(req.URL.String()) {
154 | req.URL, err = url.Parse("https://" + originalReq.Host + req.URL.String())
155 | }
156 | if err != nil {
157 | return
158 | }
159 |
160 | var resp *http.Response
161 |
162 | // Copying a Context preserves the Transport, Middleware
163 | ctx := mitm.Ctx.WithRequest(req)
164 | resp, err = ctx.Next(req)
165 | if err != nil {
166 | return
167 | }
168 |
169 | var (
170 | // Body buffer
171 | buffer = new(bytes.Buffer)
172 | // Body size
173 | bufferSize int64
174 | )
175 |
176 | buf := mitm.buffer().Get()
177 | bufferSize, err = io.CopyBuffer(buffer, resp.Body, buf)
178 | mitm.buffer().Put(buf)
179 | if err != nil {
180 | _ = resp.Body.Close()
181 | return
182 | }
183 | _ = resp.Body.Close()
184 |
185 | // reset Content-Length
186 | resp.ContentLength = bufferSize
187 | resp.Header.Set("Content-Length", strconv.Itoa(int(bufferSize)))
188 | resp.Body = io.NopCloser(buffer)
189 | err = resp.Write(rawClientTls)
190 | if err != nil {
191 | return
192 | }
193 | }
194 | }
195 |
// Use registers one or more Middleware by forwarding them to the handler's
// Context (mitm.Ctx.Use).
func (mitm *MitmHandler) Use(middleware ...Middleware) {
	mitm.Ctx.Use(middleware...)
}
200 |
// UseFunc registers one or more MiddlewareFunc by forwarding them to the
// handler's Context (mitm.Ctx.UseFunc).
func (mitm *MitmHandler) UseFunc(fus ...MiddlewareFunc) {
	mitm.Ctx.UseFunc(fus...)
}
205 |
206 | // OnRequest filter requests through Filters
207 | func (mitm *MitmHandler) OnRequest(filters ...Filter) *ReqFilterGroup {
208 | return &ReqFilterGroup{ctx: mitm.Ctx, filters: filters}
209 | }
210 |
211 | // OnResponse filter response through Filters
212 | func (mitm *MitmHandler) OnResponse(filters ...Filter) *RespFilterGroup {
213 | return &RespFilterGroup{ctx: mitm.Ctx, filters: filters}
214 | }
215 |
216 | // Get buffer pool
217 | func (mitm *MitmHandler) buffer() httputil.BufferPool {
218 | if mitm.BufferPool != nil {
219 | return mitm.BufferPool
220 | }
221 | return pool.DefaultBuffer
222 | }
223 |
224 | // Get cert.Container instance
225 | func (mitm *MitmHandler) certContainer() cert.Container {
226 | if mitm.CertContainer != nil {
227 | return mitm.CertContainer
228 | }
229 | return cert.DefaultMemProvider
230 | }
231 |
// Transport returns the http.Transport stored on the handler's Context.
func (mitm *MitmHandler) Transport() *http.Transport {
	return mitm.Ctx.Transport
}
236 |
237 | func (mitm *MitmHandler) TLSConfigFromCA(host string) (*tls.Config, error) {
238 | host = stripPort(host)
239 |
240 | // Returned existing certificate for the host
241 | crt, err := mitm.certContainer().Get(host)
242 | if err == nil && crt != nil {
243 | return &tls.Config{
244 | InsecureSkipVerify: true,
245 | Certificates: []tls.Certificate{*crt},
246 | }, nil
247 | }
248 |
249 | // Issue a certificate for host
250 | crt, err = signHost(mitm.Certificate, []string{host})
251 | if err != nil {
252 | err = fmt.Errorf("cannot sign host certificate with provided CA: %v", err)
253 | return nil, err
254 | }
255 |
256 | // Set certificate to container
257 | _ = mitm.certContainer().Set(host, crt)
258 |
259 | return &tls.Config{
260 | InsecureSkipVerify: true,
261 | Certificates: []tls.Certificate{*crt},
262 | }, nil
263 | }
264 |
265 | // sign host
266 | func signHost(ca tls.Certificate, hosts []string) (cert *tls.Certificate, err error) {
267 | // Use the provided ca for certificate generation.
268 | var x509ca *x509.Certificate
269 | x509ca, err = x509.ParseCertificate(ca.Certificate[0])
270 | if err != nil {
271 | return
272 | }
273 |
274 | start := time.Unix(time.Now().Unix()-2592000, 0) // 2592000 = 30 day
275 | end := time.Unix(time.Now().Unix()+31536000, 0) // 31536000 = 365 day
276 |
277 | var random CounterEncryptorRand
278 | random, err = NewCounterEncryptorRand(ca.PrivateKey, hashHosts(hosts))
279 | if err != nil {
280 | return
281 | }
282 |
283 | var pk crypto.Signer
284 | switch ca.PrivateKey.(type) {
285 | case *rsa.PrivateKey:
286 | pk, err = rsa.GenerateKey(&random, 2048)
287 | case *ecdsa.PrivateKey:
288 | pk, err = ecdsa.GenerateKey(elliptic.P256(), &random)
289 | default:
290 | err = fmt.Errorf("unsupported key type %T", ca.PrivateKey)
291 | }
292 | if err != nil {
293 | return
294 | }
295 |
296 | // certificate template
297 | serial := big.NewInt(mpsRand.Int63())
298 | tpl := x509.Certificate{
299 | SerialNumber: serial,
300 | Issuer: x509ca.Subject,
301 | Subject: pkix.Name{
302 | Organization: []string{"MPS untrusted MITM proxy Inc"},
303 | },
304 | NotBefore: start,
305 | NotAfter: end,
306 | KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
307 | ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
308 | BasicConstraintsValid: true,
309 | EmailAddresses: x509ca.EmailAddresses,
310 | }
311 |
312 | total := len(hosts)
313 | for i := 0; i < total; i++ {
314 | if ip := net.ParseIP(hosts[i]); ip != nil {
315 | tpl.IPAddresses = append(tpl.IPAddresses, ip)
316 | } else {
317 | tpl.DNSNames = append(tpl.DNSNames, hosts[i])
318 | tpl.Subject.CommonName = hosts[i]
319 | }
320 | }
321 |
322 | var der []byte
323 | der, err = x509.CreateCertificate(&random, &tpl, x509ca, pk.Public(), ca.PrivateKey)
324 | if err != nil {
325 | return
326 | }
327 |
328 | cert = &tls.Certificate{
329 | Certificate: [][]byte{der, ca.Certificate[0]},
330 | PrivateKey: pk,
331 | }
332 | return
333 | }
334 |
// stripPort returns the host part of s, without any port and without the
// square brackets of an IPv6 literal.
//
//	"example.com:443"             -> "example.com"
//	"example.com"                 -> "example.com"
//	"[2606:4700:4700::1111]:443"  -> "2606:4700:4700::1111"
//	"[::1]"                       -> "::1"
//	"::1"                         -> "::1"
//
// Inputs that net.SplitHostPort rejects (no port, or a bare IPv6 literal)
// are returned unchanged apart from the removal of enclosing brackets.
func stripPort(s string) string {
	if host, _, err := net.SplitHostPort(s); err == nil {
		return host
	}
	// No port present: only unwrap a bracketed IPv6 literal.
	if len(s) >= 2 && s[0] == '[' && s[len(s)-1] == ']' {
		return s[1 : len(s)-1]
	}
	return s
}
356 |
// hashHosts returns the SHA-1 digest of the host names joined with commas
// after sorting. The result therefore does not depend on the order of lst,
// and lst itself is left unmodified.
func hashHosts(lst []string) []byte {
	sorted := append(make([]string, 0, len(lst)), lst...)
	sort.Strings(sorted)

	digest := sha1.New()
	digest.Write([]byte(strings.Join(sorted, ",")))
	return digest.Sum(nil)
}
365 |
// cloneTLSConfig returns a clone of cfg (via cfg.Clone), or, if cfg is nil,
// a new tls.Config with InsecureSkipVerify set to true.
func cloneTLSConfig(cfg *tls.Config) *tls.Config {
	if cfg != nil {
		return cfg.Clone()
	}
	fallback := new(tls.Config)
	fallback.InsecureSkipVerify = true
	return fallback
}
377 |
// isEof reports whether peeking one byte from r yields io.EOF. Any other
// outcome, including other errors, is reported as false.
func isEof(r *bufio.Reader) bool {
	_, err := r.Peek(1)
	return err == io.EOF
}
385 |
--------------------------------------------------------------------------------
/mitm_handler_test.go:
--------------------------------------------------------------------------------
1 | package mps
2 |
3 | import (
4 | "crypto/tls"
5 | "crypto/x509"
6 | "io"
7 | "net/http"
8 | "net/http/httptest"
9 | "net/url"
10 | "testing"
11 |
12 | "github.com/stretchr/testify/assert"
13 | "github.com/telanflow/mps/cert"
14 | )
15 |
16 | func TestNewMitmHandler(t *testing.T) {
17 | mitmHandler := NewMitmHandler()
18 | mitmSrv := httptest.NewServer(mitmHandler)
19 | defer mitmSrv.Close()
20 |
21 | clientCertPool := x509.NewCertPool()
22 | ok := clientCertPool.AppendCertsFromPEM([]byte(cert.CertPEM))
23 | if !ok {
24 | panic("failed to parse root certificate")
25 | }
26 |
27 | req, _ := http.NewRequest(http.MethodGet, "https://httpbin.org/get", nil)
28 | http.DefaultClient.Transport = &http.Transport{
29 | Proxy: func(r *http.Request) (*url.URL, error) {
30 | return url.Parse(mitmSrv.URL)
31 | },
32 | TLSClientConfig: &tls.Config{
33 | Certificates: []tls.Certificate{cert.DefaultCertificate},
34 | ClientAuth: tls.RequireAndVerifyClientCert,
35 | RootCAs: clientCertPool,
36 | },
37 | }
38 |
39 | resp, err := http.DefaultClient.Do(req)
40 | if err != nil {
41 | t.Fatal(err)
42 | }
43 | defer resp.Body.Close()
44 |
45 | body, _ := io.ReadAll(resp.Body)
46 |
47 | asserts := assert.New(t)
48 | asserts.Equal(resp.StatusCode, 200, "response status code not equal 200")
49 | asserts.Equal(int64(len(body)), resp.ContentLength)
50 | }
51 |
--------------------------------------------------------------------------------
/mps.go:
--------------------------------------------------------------------------------
1 | package mps
2 |
3 | import (
4 | "math/rand"
5 | "time"
6 | )
7 |
var (
	// mpsRand is the package-wide pseudo-random number generator for MPS,
	// seeded with the current time when the package is initialised.
	// NOTE(review): a *rand.Rand built with rand.NewSource is not safe for
	// concurrent use by multiple goroutines; confirm that callers serialise
	// access or that calls never overlap.
	mpsRand = rand.New(rand.NewSource(time.Now().UnixNano()))
)
12 |
--------------------------------------------------------------------------------
/pool/buffer.go:
--------------------------------------------------------------------------------
1 | package pool
2 |
3 | import "sync"
4 |
// DefaultBuffer is the shared pool of 2048-byte buffers.
var DefaultBuffer = NewBuffer(2048)

// Buffer is a pool of byte slices that all have the same length. Its Get
// and Put methods match the httputil.BufferPool method set.
type Buffer struct {
	pl   *sync.Pool
	size int
}

// NewBuffer creates a Buffer whose slices are size bytes long.
func NewBuffer(size int) *Buffer {
	bufPool := &Buffer{size: size}
	bufPool.pl = &sync.Pool{
		New: bufPool.newPl,
	}
	return bufPool
}

// Get returns a slice of exactly the configured size, either recycled from
// the pool or newly allocated. Its contents are unspecified.
func (b *Buffer) Get() []byte {
	return b.pl.Get().([]byte)
}

// Put returns buf to the pool for reuse.
//
// A slice whose capacity is below the configured size is dropped rather
// than pooled, and a pooled slice is resliced to the configured size, so
// that Get never hands out a short or empty buffer (io.CopyBuffer, for
// instance, panics on an empty one).
func (b *Buffer) Put(buf []byte) {
	if cap(buf) < b.size {
		return
	}
	b.pl.Put(buf[:b.size])
}

// newPl allocates a fresh slice for the underlying sync.Pool.
func (b *Buffer) newPl() interface{} {
	return make([]byte, b.size)
}
34 |
--------------------------------------------------------------------------------
/pool/conn_container.go:
--------------------------------------------------------------------------------
1 | package pool
2 |
3 | import "net"
4 |
// ConnContainer is the connection pool interface: it hands out idle
// connections by address, takes connections back, and can be released.
type ConnContainer interface {
	// Get returns an idle net.Conn for addr, or an error when none is
	// available.
	Get(addr string) (net.Conn, error)

	// Put places an idle net.Conn into the pool. A non-nil error means the
	// pool did not accept the connection.
	Put(conn net.Conn) error

	// Release releases the connection pool.
	Release() error
}
16 |
--------------------------------------------------------------------------------
/pool/conn_options.go:
--------------------------------------------------------------------------------
1 | package pool
2 |
3 | import "time"
4 |
// ConnOptions holds the tunables of a ConnProvider.
type ConnOptions struct {
	// IdleMaxCap is the maximum number of idle connections kept for a
	// single address.
	IdleMaxCap int

	// Timeout is how long a connection may sit idle before it times out.
	Timeout time.Duration
}

// DefaultConnOptions is the preset option set: up to 30 idle connections
// per address and a 90 second idle timeout.
var DefaultConnOptions = &ConnOptions{
	Timeout:    90 * time.Second,
	IdleMaxCap: 30,
}
18 |
--------------------------------------------------------------------------------
/pool/conn_provider.go:
--------------------------------------------------------------------------------
1 | package pool
2 |
3 | import (
4 | "errors"
5 | "fmt"
6 | "io"
7 | "net"
8 | "sync"
9 | "sync/atomic"
10 | "time"
11 | )
12 |
13 | var DefaultConnProvider = NewConnProvider(DefaultConnOptions)
14 |
15 | // ConnProvider is a connection pool, it implements ConnContainer
16 | type ConnProvider struct {
17 | mu sync.RWMutex
18 | idleConnMap map[string]chan net.Conn
19 | options *ConnOptions
20 | closed int32
21 | }
22 |
23 | // Create a ConnProvider
24 | func NewConnProvider(opt *ConnOptions) *ConnProvider {
25 | return &ConnProvider{
26 | options: opt,
27 | mu: sync.RWMutex{},
28 | idleConnMap: make(map[string]chan net.Conn),
29 | }
30 | }
31 |
32 | // Get returned a idle net.Conn
33 | func (p *ConnProvider) Get(addr string) (net.Conn, error) {
34 | closed := atomic.LoadInt32(&p.closed)
35 | if closed == 1 {
36 | return nil, errors.New("pool is closed")
37 | }
38 |
39 | p.mu.Lock()
40 | if _, ok := p.idleConnMap[addr]; !ok {
41 | p.mu.Unlock()
42 | return nil, errors.New("no idle conn")
43 | }
44 | p.mu.Unlock()
45 |
46 | RETRY:
47 | select {
48 | case conn := <-p.idleConnMap[addr]:
49 | // Getting a net.Conn requires verifying that the net.Conn is valid
50 | _, err := conn.Read([]byte{})
51 | if err != nil || err == io.EOF {
52 | // conn is close Or timeout
53 | _ = conn.Close()
54 | goto RETRY
55 | }
56 | return conn, nil
57 | default:
58 | return nil, errors.New("no idle conn")
59 | }
60 | }
61 |
62 | // Put place a idle net.Conn into the pool
63 | func (p *ConnProvider) Put(conn net.Conn) error {
64 | closed := atomic.LoadInt32(&p.closed)
65 | if closed == 1 {
66 | return errors.New("pool is closed")
67 | }
68 |
69 | addr := conn.RemoteAddr().String()
70 |
71 | p.mu.Lock()
72 | if _, ok := p.idleConnMap[addr]; !ok {
73 | p.idleConnMap[addr] = make(chan net.Conn, p.options.IdleMaxCap)
74 | }
75 | p.mu.Unlock()
76 |
77 | // set conn timeout
78 | // The timeout will be verified at the next `Get()`
79 | err := conn.SetDeadline(time.Now().Add(p.options.Timeout))
80 | if err != nil {
81 | return err
82 | }
83 |
84 | select {
85 | case p.idleConnMap[addr] <- conn:
86 | return nil
87 | default:
88 | return fmt.Errorf("beyond max capacity")
89 | }
90 | }
91 |
92 | // Release connection pool
93 | func (p *ConnProvider) Release() error {
94 | closed := atomic.LoadInt32(&p.closed)
95 | if closed == 1 {
96 | return errors.New("pool is closed")
97 | }
98 |
99 | atomic.StoreInt32(&p.closed, 1)
100 | for _, connChan := range p.idleConnMap {
101 | close(connChan)
102 | for conn, ok := <-connChan; ok; {
103 | _ = conn.Close()
104 | }
105 | }
106 | return nil
107 | }
108 |
--------------------------------------------------------------------------------
/reverse_handler.go:
--------------------------------------------------------------------------------
1 | package mps
2 |
3 | import (
4 | "bytes"
5 | "io"
6 | "net/http"
7 | "net/http/httputil"
8 | "strconv"
9 |
10 | "github.com/telanflow/mps/pool"
11 | )
12 |
// ReverseHandler is a reverse proxy server implementation.
type ReverseHandler struct {
	// Ctx is the base Context; ServeHTTP derives a per-request Context
	// from it with WithRequest.
	Ctx *Context
	// BufferPool supplies the scratch buffer used while copying the
	// upstream body. pool.DefaultBuffer is used when it is nil.
	BufferPool httputil.BufferPool
}
18 |
19 | // NewReverseHandler Create a reverse proxy
20 | func NewReverseHandler() *ReverseHandler {
21 | return &ReverseHandler{
22 | Ctx: NewContext(),
23 | BufferPool: pool.DefaultBuffer,
24 | }
25 | }
26 |
27 | // Standard net/http function. You can use it alone
28 | func (reverse *ReverseHandler) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
29 | // Copying a Context preserves the Transport, Middleware
30 | ctx := reverse.Ctx.WithRequest(req)
31 | resp, err := ctx.Next(req)
32 | if err != nil {
33 | http.Error(rw, err.Error(), 502)
34 | return
35 | }
36 | defer resp.Body.Close()
37 |
38 | var (
39 | // Body buffer
40 | buffer = new(bytes.Buffer)
41 | // Body size
42 | bufferSize int64
43 | )
44 |
45 | buf := reverse.buffer().Get()
46 | bufferSize, err = io.CopyBuffer(buffer, resp.Body, buf)
47 | reverse.buffer().Put(buf)
48 | if err != nil {
49 | http.Error(rw, err.Error(), 502)
50 | return
51 | }
52 |
53 | resp.ContentLength = bufferSize
54 | resp.Header.Set("Content-Length", strconv.Itoa(int(bufferSize)))
55 | copyHeaders(rw.Header(), resp.Header, reverse.Ctx.KeepDestinationHeaders)
56 | rw.WriteHeader(resp.StatusCode)
57 | _, err = buffer.WriteTo(rw)
58 | }
59 |
// Use registers the given Middleware on the handler's Context by
// delegating to Ctx.Use.
func (reverse *ReverseHandler) Use(middleware ...Middleware) {
	reverse.Ctx.Use(middleware...)
}
64 |
// UseFunc registers the given MiddlewareFunc values on the handler's
// Context by delegating to Ctx.UseFunc.
func (reverse *ReverseHandler) UseFunc(fus ...MiddlewareFunc) {
	reverse.Ctx.UseFunc(fus...)
}
69 |
// OnRequest returns a ReqFilterGroup that is bound to the handler's Context
// and carries the given filters.
func (reverse *ReverseHandler) OnRequest(filters ...Filter) *ReqFilterGroup {
	return &ReqFilterGroup{ctx: reverse.Ctx, filters: filters}
}
74 |
// OnResponse returns a RespFilterGroup that is bound to the handler's
// Context and carries the given filters.
func (reverse *ReverseHandler) OnResponse(filters ...Filter) *RespFilterGroup {
	return &RespFilterGroup{ctx: reverse.Ctx, filters: filters}
}
79 |
80 | // Get buffer pool
81 | func (reverse *ReverseHandler) buffer() httputil.BufferPool {
82 | if reverse.BufferPool != nil {
83 | return reverse.BufferPool
84 | }
85 | return pool.DefaultBuffer
86 | }
87 |
// Transport returns the http.Transport held by the handler's Context.
func (reverse *ReverseHandler) Transport() *http.Transport {
	return reverse.Ctx.Transport
}
92 |
--------------------------------------------------------------------------------
/reverse_handler_test.go:
--------------------------------------------------------------------------------
1 | package mps
2 |
3 | import (
4 | "io"
5 | "net/http"
6 | "net/http/httptest"
7 | "net/url"
8 | "strconv"
9 | "testing"
10 |
11 | "github.com/stretchr/testify/assert"
12 | )
13 |
14 | func TestNewReverseHandler(t *testing.T) {
15 | srv := newTestServer()
16 | defer srv.Close()
17 |
18 | reverseHandler := NewReverseHandler()
19 | proxySrv := httptest.NewServer(reverseHandler)
20 | defer proxySrv.Close()
21 |
22 | resp, err := HttpGet(srv.URL, func(r *http.Request) (*url.URL, error) {
23 | return url.Parse(proxySrv.URL)
24 | })
25 | if err != nil {
26 | t.Fatal(err)
27 | }
28 | defer resp.Body.Close()
29 |
30 | body, _ := io.ReadAll(resp.Body)
31 | bodySize := len(body)
32 | contentLength, _ := strconv.Atoi(resp.Header.Get("Content-Length"))
33 |
34 | asserts := assert.New(t)
35 | asserts.Equal(resp.StatusCode, 200, "statusCode should be equal 200")
36 | asserts.Equal(bodySize, contentLength, "Content-Length should be equal "+strconv.Itoa(bodySize))
37 | asserts.Equal(int64(bodySize), resp.ContentLength)
38 | }
39 |
--------------------------------------------------------------------------------
/transport.go:
--------------------------------------------------------------------------------
1 | package mps
2 |
3 | import (
4 | "crypto/tls"
5 | "net"
6 | "net/http"
7 | "time"
8 | )
9 |
// DefaultTransport is the preset http.Transport of this package.
//
// SECURITY: TLSClientConfig sets InsecureSkipVerify, so the certificates of
// upstream servers are NOT verified. Supply a Transport with a stricter
// tls.Config when upstream authenticity matters.
var DefaultTransport = &http.Transport{
	// Dialer.DualStack is intentionally not set: it is deprecated and has
	// no effect, Fast Fallback is enabled by default.
	DialContext: (&net.Dialer{
		Timeout:   15 * time.Second,
		KeepAlive: 30 * time.Second,
	}).DialContext,
	ForceAttemptHTTP2:     true,
	MaxIdleConns:          100,
	IdleConnTimeout:       90 * time.Second,
	TLSHandshakeTimeout:   10 * time.Second,
	ExpectContinueTimeout: 1 * time.Second,
	TLSClientConfig:       &tls.Config{InsecureSkipVerify: true},
	Proxy:                 http.ProxyFromEnvironment,
}
25 |
--------------------------------------------------------------------------------
/tunnel_handler.go:
--------------------------------------------------------------------------------
1 | package mps
2 |
3 | import (
4 | "context"
5 | "io"
6 | "net"
7 | "net/http"
8 | "net/http/httputil"
9 | "net/url"
10 | "regexp"
11 | "time"
12 |
13 | "github.com/telanflow/mps/pool"
14 | )
15 |
// HttpTunnelOk is the raw reply written to a client once its tunnel is ready.
var HttpTunnelOk = []byte("HTTP/1.0 200 Connection Established\r\n\r\n")

// HttpTunnelFail is the raw reply written to a client when the tunnel
// cannot be set up.
var HttpTunnelFail = []byte("HTTP/1.1 502 Bad Gateway\r\n\r\n")

// hasPort matches an address that ends in a colon followed by digits.
var hasPort = regexp.MustCompile(`:\d+$`)
21 |
// TunnelHandler is the tunnel proxy type. Implements http.Handler.
type TunnelHandler struct {
	// Ctx is the base Context; ServeHTTP derives a per-request Context
	// from it with WithRequest.
	Ctx *Context
	// BufferPool supplies the copy buffers. pool.DefaultBuffer is used
	// when it is nil.
	BufferPool httputil.BufferPool
	// ConnContainer is the pool target connections are taken from and
	// returned to. pool.DefaultConnProvider is used when it is nil.
	ConnContainer pool.ConnContainer
}
28 |
29 | // NewTunnelHandler Create a tunnel handler
30 | func NewTunnelHandler() *TunnelHandler {
31 | return &TunnelHandler{
32 | Ctx: NewContext(),
33 | BufferPool: pool.DefaultBuffer,
34 | }
35 | }
36 |
37 | // NewTunnelHandlerWithContext Create a tunnel handler with Context
38 | func NewTunnelHandlerWithContext(ctx *Context) *TunnelHandler {
39 | return &TunnelHandler{
40 | Ctx: ctx,
41 | BufferPool: pool.DefaultBuffer,
42 | }
43 | }
44 |
45 | // Standard net/http function. You can use it alone
46 | func (tunnel *TunnelHandler) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
47 | // execution middleware
48 | ctx := tunnel.Ctx.WithRequest(req)
49 | resp, err := ctx.Next(req)
50 | if err != nil && err != MethodNotSupportErr {
51 | if resp != nil {
52 | copyHeaders(rw.Header(), resp.Header, tunnel.Ctx.KeepDestinationHeaders)
53 | rw.WriteHeader(resp.StatusCode)
54 | buf := tunnel.buffer().Get()
55 | _, err = io.CopyBuffer(rw, resp.Body, buf)
56 | tunnel.buffer().Put(buf)
57 | }
58 | return
59 | }
60 |
61 | // hijacker connection
62 | proxyClient, err := hijacker(rw)
63 | if err != nil {
64 | http.Error(rw, err.Error(), 502)
65 | return
66 | }
67 |
68 | var (
69 | u *url.URL = nil
70 | targetConn net.Conn = nil
71 | targetAddr = hostAndPort(req.URL.Host)
72 | isCascadeProxy = false
73 | )
74 | if tunnel.Ctx.Transport != nil && tunnel.Ctx.Transport.Proxy != nil {
75 | u, err = tunnel.Ctx.Transport.Proxy(req)
76 | if err != nil {
77 | ConnError(proxyClient)
78 | return
79 | }
80 | if u != nil {
81 | // connect addr eg. "localhost:80"
82 | targetAddr = hostAndPort(u.Host)
83 | isCascadeProxy = true
84 | }
85 | }
86 |
87 | // connect to targetAddr
88 | targetConn, err = tunnel.connContainer().Get(targetAddr)
89 | if err != nil {
90 | targetConn, err = tunnel.ConnectDial("tcp", targetAddr)
91 | if err != nil {
92 | ConnError(proxyClient)
93 | return
94 | }
95 | }
96 |
97 | // If the ConnContainer is exists,
98 | // When io.CopyBuffer is complete,
99 | // put the idle connection into the ConnContainer so can reuse it next time
100 | defer func() {
101 | err := tunnel.connContainer().Put(targetConn)
102 | if err != nil {
103 | // put conn fail, conn must be closed
104 | _ = targetConn.Close()
105 | }
106 | }()
107 |
108 | // The cascade proxy needs to forward the request
109 | if isCascadeProxy {
110 | // The cascade proxy needs to send it as-is
111 | _ = req.Write(targetConn)
112 | } else {
113 | // Tell client that the tunnel is ready
114 | _, _ = proxyClient.Write(HttpTunnelOk)
115 | }
116 |
117 | go func() {
118 | buf := tunnel.buffer().Get()
119 | _, _ = io.CopyBuffer(targetConn, proxyClient, buf)
120 | tunnel.buffer().Put(buf)
121 | _ = proxyClient.Close()
122 | }()
123 | buf := tunnel.buffer().Get()
124 | _, _ = io.CopyBuffer(proxyClient, targetConn, buf)
125 | tunnel.buffer().Put(buf)
126 | }
127 |
// Use registers the given Middleware on the handler's Context by
// delegating to Ctx.Use.
func (tunnel *TunnelHandler) Use(middleware ...Middleware) {
	tunnel.Ctx.Use(middleware...)
}
132 |
// UseFunc registers the given MiddlewareFunc values on the handler's
// Context by delegating to Ctx.UseFunc.
func (tunnel *TunnelHandler) UseFunc(fus ...MiddlewareFunc) {
	tunnel.Ctx.UseFunc(fus...)
}
137 |
// OnRequest returns a ReqFilterGroup that is bound to the handler's Context
// and carries the given filters.
func (tunnel *TunnelHandler) OnRequest(filters ...Filter) *ReqFilterGroup {
	return &ReqFilterGroup{ctx: tunnel.Ctx, filters: filters}
}
142 |
// OnResponse returns a RespFilterGroup that is bound to the handler's
// Context and carries the given filters.
func (tunnel *TunnelHandler) OnResponse(filters ...Filter) *RespFilterGroup {
	return &RespFilterGroup{ctx: tunnel.Ctx, filters: filters}
}
147 |
148 | func (tunnel *TunnelHandler) ConnectDial(network, addr string) (net.Conn, error) {
149 | if tunnel.Ctx.Transport != nil && tunnel.Ctx.Transport.DialContext != nil {
150 | return tunnel.Ctx.Transport.DialContext(tunnel.context(), network, addr)
151 | }
152 | return net.DialTimeout(network, addr, 30*time.Second)
153 | }
154 |
// Transport returns the http.Transport held by the handler's Context.
func (tunnel *TunnelHandler) Transport() *http.Transport {
	return tunnel.Ctx.Transport
}
159 |
160 | // get a context.Context
161 | func (tunnel *TunnelHandler) context() context.Context {
162 | if tunnel.Ctx.Context != nil {
163 | return tunnel.Ctx.Context
164 | }
165 | return context.Background()
166 | }
167 |
168 | // Get buffer pool
169 | func (tunnel *TunnelHandler) buffer() httputil.BufferPool {
170 | if tunnel.BufferPool != nil {
171 | return tunnel.BufferPool
172 | }
173 | return pool.DefaultBuffer
174 | }
175 |
176 | // Get a conn pool
177 | func (tunnel *TunnelHandler) connContainer() pool.ConnContainer {
178 | if tunnel.ConnContainer != nil {
179 | return tunnel.ConnContainer
180 | }
181 | return pool.DefaultConnProvider
182 | }
183 |
// hostAndPort returns addr in "host:port" form, adding the default port 80
// when addr carries none.
//
// An address that already has a non-empty port is returned unchanged. IPv6
// literals are handled with or without brackets, so "::1" and "[::1]" both
// become "[::1]:80", and a trailing colon with an empty port ("host:") is
// completed to "host:80".
func hostAndPort(addr string) string {
	if host, port, err := net.SplitHostPort(addr); err == nil {
		if port != "" {
			return addr
		}
		return net.JoinHostPort(host, "80")
	}

	// No port present. Strip the brackets of an IPv6 literal first, because
	// JoinHostPort adds them itself for hosts that contain a colon.
	host := addr
	if n := len(host); n >= 2 && host[0] == '[' && host[n-1] == ']' {
		host = host[1 : n-1]
	}
	return net.JoinHostPort(host, "80")
}
190 |
191 | func ConnError(w net.Conn) {
192 | _, _ = w.Write(HttpTunnelFail)
193 | _ = w.Close()
194 | }
195 |
--------------------------------------------------------------------------------
/tunnel_handler_test.go:
--------------------------------------------------------------------------------
1 | package mps
2 |
3 | import (
4 | "net/http"
5 | "net/http/httptest"
6 | "net/url"
7 | "testing"
8 |
9 | "github.com/stretchr/testify/assert"
10 | )
11 |
12 | func TestNewTunnelHandler(t *testing.T) {
13 | srv := newTestServer()
14 | defer srv.Close()
15 |
16 | tunnel := NewTunnelHandler()
17 | //tunnel.Transport().Proxy = func(r *http.Request) (*url.URL, error) {
18 | // return url.Parse("http://127.0.0.1:7890")
19 | //}
20 | tunnelSrv := httptest.NewServer(tunnel)
21 | defer tunnelSrv.Close()
22 |
23 | resp, err := HttpGet(srv.URL, func(r *http.Request) (*url.URL, error) {
24 | return url.Parse(tunnelSrv.URL)
25 | })
26 | if err != nil {
27 | t.Fatal(err)
28 | }
29 | resp.Body.Close()
30 |
31 | asserts := assert.New(t)
32 | asserts.Equal(resp.StatusCode, 200)
33 | }
34 |
--------------------------------------------------------------------------------
/websocket_handler.go:
--------------------------------------------------------------------------------
1 | package mps
2 |
3 | import (
4 | "bufio"
5 | "context"
6 | "io"
7 | "net"
8 | "net/http"
9 | "net/http/httputil"
10 | "net/url"
11 | "strings"
12 | "time"
13 |
14 | "github.com/telanflow/mps/pool"
15 | )
16 |
// WebsocketHandler is the websocket proxy type. Implements http.Handler.
type WebsocketHandler struct {
	// Ctx provides the Transport (Proxy and DialContext) and the
	// context.Context used when connecting to the target.
	Ctx *Context
	// BufferPool supplies the copy buffers. pool.DefaultBuffer is used
	// when it is nil.
	BufferPool httputil.BufferPool
}
22 |
23 | // NewWebsocketHandler Create a websocket handler
24 | func NewWebsocketHandler() *WebsocketHandler {
25 | return &WebsocketHandler{
26 | Ctx: NewContext(),
27 | BufferPool: pool.DefaultBuffer,
28 | }
29 | }
30 |
31 | // NewWebsocketHandlerWithContext Create a tunnel handler with Context
32 | func NewWebsocketHandlerWithContext(ctx *Context) *WebsocketHandler {
33 | return &WebsocketHandler{
34 | Ctx: ctx,
35 | BufferPool: pool.DefaultBuffer,
36 | }
37 | }
38 |
39 | // Standard net/http function. You can use it alone
40 | func (ws *WebsocketHandler) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
41 | // Whether to upgrade to Websocket
42 | if !isWebSocketRequest(req) {
43 | return
44 | }
45 |
46 | // hijacker connection
47 | clientConn, err := hijacker(rw)
48 | if err != nil {
49 | http.Error(rw, err.Error(), 502)
50 | return
51 | }
52 |
53 | var (
54 | u *url.URL
55 | targetAddr = hostAndPort(req.URL.Host)
56 | )
57 | if ws.Ctx.Transport != nil && ws.Ctx.Transport.Proxy != nil {
58 | u, err = ws.Ctx.Transport.Proxy(req)
59 | if err != nil {
60 | ConnError(clientConn)
61 | return
62 | }
63 | if u != nil {
64 | // connect addr eg. "localhost:443"
65 | targetAddr = hostAndPort(u.Host)
66 | }
67 | }
68 |
69 | targetConn, err := ws.ConnectDial("tcp", targetAddr)
70 | if err != nil {
71 | return
72 | }
73 | defer targetConn.Close()
74 |
75 | // Perform handshake
76 | // write handshake request to target
77 | err = req.Write(targetConn)
78 | if err != nil {
79 | return
80 | }
81 |
82 | // Read handshake response from target
83 | targetReader := bufio.NewReader(targetConn)
84 | resp, err := http.ReadResponse(targetReader, req)
85 | if err != nil {
86 | return
87 | }
88 |
89 | // Proxy handshake back to client
90 | err = resp.Write(clientConn)
91 | if err != nil {
92 | return
93 | }
94 |
95 | // Proxy ws connection
96 | go func() {
97 | buf := ws.buffer().Get()
98 | _, _ = io.CopyBuffer(targetConn, clientConn, buf)
99 | ws.buffer().Put(buf)
100 | _ = clientConn.Close()
101 | }()
102 | buf := ws.buffer().Get()
103 | _, _ = io.CopyBuffer(clientConn, targetConn, buf)
104 | ws.buffer().Put(buf)
105 | }
106 |
107 | func (ws *WebsocketHandler) ConnectDial(network, addr string) (net.Conn, error) {
108 | if ws.Ctx.Transport != nil && ws.Ctx.Transport.DialContext != nil {
109 | return ws.Ctx.Transport.DialContext(ws.context(), network, addr)
110 | }
111 | return net.DialTimeout(network, addr, 30*time.Second)
112 | }
113 |
// Transport returns the http.Transport held by the handler's Context.
func (ws *WebsocketHandler) Transport() *http.Transport {
	return ws.Ctx.Transport
}
118 |
119 | // context returned a context.Context
120 | func (ws *WebsocketHandler) context() context.Context {
121 | if ws.Ctx.Context != nil {
122 | return ws.Ctx.Context
123 | }
124 | return context.Background()
125 | }
126 |
127 | // buffer returned a httputil.BufferPool
128 | func (ws *WebsocketHandler) buffer() httputil.BufferPool {
129 | if ws.BufferPool != nil {
130 | return ws.BufferPool
131 | }
132 | return pool.DefaultBuffer
133 | }
134 |
135 | // isWebSocketRequest to upgrade to a Websocket request
136 | func isWebSocketRequest(req *http.Request) bool {
137 | return headerContains(req.Header, "Connection", "upgrade") &&
138 | headerContains(req.Header, "Upgrade", "websocket")
139 | }
140 |
// headerContains reports whether any value stored under the exact key name
// in header contains value as one of its comma-separated tokens. Tokens are
// trimmed of surrounding whitespace and compared case-insensitively.
func headerContains(header http.Header, name string, value string) bool {
	for _, line := range header[name] {
		tokens := strings.Split(line, ",")
		for i := range tokens {
			if token := strings.TrimSpace(tokens[i]); strings.EqualFold(token, value) {
				return true
			}
		}
	}
	return false
}
151 |
--------------------------------------------------------------------------------
/websocket_handler_test.go:
--------------------------------------------------------------------------------
1 | package mps
2 |
3 | import (
4 | "log"
5 | "net/http"
6 | "net/http/httptest"
7 | "net/url"
8 | "strings"
9 | "testing"
10 |
11 | "github.com/gorilla/websocket"
12 | )
13 |
14 | var upgrader = websocket.Upgrader{}
15 |
16 | // create a test websocket server
17 | func newTestWebsocketServer() *httptest.Server {
18 | return httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
19 | c, err := upgrader.Upgrade(rw, req, nil)
20 | if err != nil {
21 | return
22 | }
23 | defer c.Close()
24 | for {
25 | mt, message, err := c.ReadMessage()
26 | if err != nil {
27 | break
28 | }
29 | err = c.WriteMessage(mt, message)
30 | if err != nil {
31 | break
32 | }
33 | }
34 | }))
35 | }
36 |
37 | func TestNewWebsocketHandler(t *testing.T) {
38 | // create endPoint websocket server
39 | srv := newTestWebsocketServer()
40 | defer srv.Close()
41 |
42 | // Convert http://127.0.0.1 to ws://127.0.0.1
43 | endPoint := "ws" + strings.TrimPrefix(srv.URL, "http")
44 | log.Printf("endPoint: %s", endPoint)
45 |
46 | // create a proxy websocket server
47 | wsHandler := NewWebsocketHandler()
48 | wsHandler.Transport().Proxy = func(request *http.Request) (*url.URL, error) {
49 | return url.Parse(endPoint)
50 | }
51 | proxySrv := httptest.NewServer(wsHandler)
52 | defer proxySrv.Close()
53 |
54 | proxyWs := "ws" + strings.TrimPrefix(proxySrv.URL, "http")
55 | log.Printf("proxy: %s", proxyWs)
56 |
57 | // Connect to the proxy websocket server
58 | client, _, err := websocket.DefaultDialer.Dial(proxyWs, nil)
59 | if err != nil {
60 | t.Fatalf("%v", err)
61 | }
62 | defer client.Close()
63 |
64 | // Send message to server, read response and check to see if it's what we expect.
65 | for i := 0; i < 5; i++ {
66 | if err := client.WriteMessage(websocket.TextMessage, []byte("hello")); err != nil {
67 | t.Fatalf("send fail: %v", err)
68 | }
69 |
70 | _, p, err := client.ReadMessage()
71 | if err != nil {
72 | t.Fatalf("read fail: %v", err)
73 | }
74 |
75 | log.Printf("recv: %s", string(p))
76 | if string(p) != "hello" {
77 | t.Fatalf("bad message")
78 | }
79 | }
80 | }
81 |
--------------------------------------------------------------------------------