├── .github └── workflows │ ├── codeql-analysis.yml │ └── go.yml ├── .gitignore ├── LICENSE ├── Procfile ├── README.md ├── cmd └── demo │ └── main.go ├── go.mod ├── go.sum ├── pkg ├── fragment │ ├── definition.go │ └── definition_test.go ├── middleware │ └── logging │ │ ├── logging.go │ │ └── logging_test.go ├── multiplexer │ ├── headers.go │ ├── headers_test.go │ ├── multiplexer.go │ ├── multiplexer_test.go │ ├── requestable.go │ ├── result.go │ └── tripper.go ├── routeimporter │ ├── config.go │ ├── config_test.go │ ├── http.go │ ├── http_test.go │ ├── json.go │ └── json_test.go └── secretfilter │ ├── secretfilter.go │ └── secretfilter_test.go ├── response_builder.go ├── route.go ├── route_test.go ├── server.go ├── server_test.go ├── stitch_structure.go └── stitch_structure_test.go /.github/workflows/codeql-analysis.yml: -------------------------------------------------------------------------------- 1 | # For most projects, this workflow file will not need changing; you simply need 2 | # to commit it to your repository. 3 | # 4 | # You may wish to alter this file to override the set of languages analyzed, 5 | # or to provide custom queries or build logic. 6 | # 7 | # ******** NOTE ******** 8 | # We have attempted to detect the languages in your repository. Please check 9 | # the `language` matrix defined below to confirm you have the correct set of 10 | # supported CodeQL languages. 
11 | # 12 | name: "CodeQL" 13 | 14 | on: 15 | push: 16 | branches: [ main ] 17 | pull_request: 18 | # The branches below must be a subset of the branches above 19 | branches: [ main ] 20 | schedule: 21 | - cron: '17 7 * * 4' 22 | 23 | jobs: 24 | analyze: 25 | name: Analyze 26 | runs-on: ubuntu-latest 27 | permissions: 28 | actions: read 29 | contents: read 30 | security-events: write 31 | 32 | strategy: 33 | fail-fast: false 34 | matrix: 35 | language: [ 'go' ] 36 | # CodeQL supports [ 'cpp', 'csharp', 'go', 'java', 'javascript', 'python' ] 37 | # Learn more: 38 | # https://docs.github.com/en/free-pro-team@latest/github/finding-security-vulnerabilities-and-errors-in-your-code/configuring-code-scanning#changing-the-languages-that-are-analyzed 39 | 40 | steps: 41 | - name: Checkout repository 42 | uses: actions/checkout@v4 43 | 44 | # Initializes the CodeQL tools for scanning. 45 | - name: Initialize CodeQL 46 | uses: github/codeql-action/init@v3 47 | with: 48 | languages: ${{ matrix.language }} 49 | # If you wish to specify custom queries, you can do so here or in a config file. 50 | # By default, queries listed here will override any specified in a config file. 51 | # Prefix the list here with "+" to use these queries and those in the config file. 52 | # queries: ./path/to/local/query, your-org/your-repo/queries@main 53 | 54 | # Autobuild attempts to build any compiled languages (C/C++, C#, Go, or Java). 55 | # If this step fails, then you should remove it and run the build manually (see below) 56 | - name: Autobuild 57 | uses: github/codeql-action/autobuild@v3 58 | 59 | # ℹ️ Command-line programs to run using the OS shell.
60 | # 📚 https://git.io/JvXDl 61 | 62 | # ✏️ If the Autobuild fails above, remove it and uncomment the following three lines 63 | # and modify them (or add more) to build your code if your project 64 | # uses a compiled language 65 | 66 | #- run: | 67 | # make bootstrap 68 | # make release 69 | 70 | - name: Perform CodeQL Analysis 71 | uses: github/codeql-action/analyze@v3 72 | -------------------------------------------------------------------------------- /.github/workflows/go.yml: -------------------------------------------------------------------------------- 1 | name: Go 2 | 3 | on: 4 | push: 5 | branches: [main] 6 | pull_request: 7 | branches: [main] 8 | 9 | jobs: 10 | build: 11 | runs-on: ubuntu-latest 12 | steps: 13 | - uses: actions/checkout@v4 14 | 15 | - name: Set up Go 16 | uses: actions/setup-go@v4 17 | with: 18 | go-version: 1.21.1 19 | 20 | - name: Build 21 | run: go build -v ./... 22 | 23 | - name: Test 24 | run: go test -v -race ./... 25 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | view-proxy 2 | log 3 | /demo 4 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Blake Williams 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial
portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /Procfile: -------------------------------------------------------------------------------- 1 | web: demo 2 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # viewproxy 2 | 3 | `viewproxy` is a Go service that makes multiple requests to an application in parallel, fetching HTML content and stitching it together to serve to a user. 4 | 5 | This is alpha software, and is currently a proof of concept used in conjunction with Rails and ViewComponent as a performance optimization. 6 | 7 | ## Usage 8 | 9 | See `cmd/demo/main.go` for an example of how to use the package. 10 | 11 | To use `viewproxy`: 12 | 13 | ```go 14 | import "github.com/blakewilliams/viewproxy" 15 | import "github.com/blakewilliams/viewproxy/pkg/fragment" 16 | 17 | // Create and configure a new Server instance. WithPassThrough proxies requests that match no route to the target. 18 | server, err := viewproxy.NewServer(target, viewproxy.WithPassThrough(target)) 19 | if err != nil { panic(err) } 20 | server.Addr = "localhost:3005" 21 | server.ProxyTimeout = time.Duration(5) * time.Second 22 | 23 | // Define a route with a :name parameter that will be forwarded to the target host. 24 | // This will make a layout request and 3 fragment requests, one for the header, hello, and footer.
25 | 26 | // GET http://localhost:3000/_view_fragments/layouts/my_layout?name=world 27 | myPage := fragment.Define("my_layout", fragment.WithChildren(fragment.Children{ 28 | "header": fragment.Define("header"), // GET http://localhost:3000/_view_fragments/header?name=world 29 | "hello": fragment.Define("hello"), // GET http://localhost:3000/_view_fragments/hello?name=world 30 | "footer": fragment.Define("footer"), // GET http://localhost:3000/_view_fragments/footer?name=world 31 | })) 32 | server.Get("/hello/:name", myPage) 33 | 34 | server.ListenAndServe() 35 | ``` 36 | 37 | Each child fragment is replaced in the parent fragment via a special tag, 38 | ``. For example, the `header` fragment will be inserted into the 39 | `my_layout` fragment by looking for the following content: ``. 40 | 41 | ## Demo Usage 42 | 43 | - The server binds to port `3005` by default; set the `PORT` environment variable to use a different port. 44 | - The target server can be set via the `TARGET` environment variable. 45 | - The default is `http://localhost:3000/` 46 | - `viewproxy` requests each fragment's path from that target, filling in route parameters. For example, the demo serves `/hello/world` by fetching `http://localhost:3000/layout/world` and its child fragments. 47 | 48 | To run the demo, run `go build ./cmd/demo && ./demo` 49 | 50 | ## Tracing with OpenTelemetry 51 | 52 | You can use tracing to learn which fragment(s) are slowest for a given page, so you know where to optimize. 53 | 54 | To set up distributed tracing via [OpenTelemetry](https://opentelemetry.io), [configure a tracing provider](https://opentelemetry.io/docs/instrumentation/go/getting-started/) in your application that uses viewproxy, and viewproxy will use the default trace provider to create spans. 55 | 56 | ### Tracing attributes via fragment metadata 57 | 58 | Each fragment can be configured with a static map of key/value pairs, which will be set as tracing attributes when each fragment is fetched.
59 | 60 | ```go 61 | layout := fragment.Define("my_layout", fragment.WithChildren(fragment.Children{ 62 | "header": fragment.Define("header", fragment.WithMetadata(map[string]string{"page": "homepage"})), // spans will have a "page" attribute with value "homepage" 63 | })) 64 | server.Get("/hello/:name", layout) 65 | ``` 66 | 67 | ## Philosophy 68 | 69 | `viewproxy` is a simple service designed to sit between a browser request and a web application. It is used to break pages down into fragments that can be rendered in parallel for faster response times. 70 | 71 | - `viewproxy` is not coupled to a specific application framework, but _is_ being driven by close integration with Rails applications. 72 | - `viewproxy` should rely on the strengths of Rails (or another target application framework) when possible. 73 | - `viewproxy` itself and its client APIs should focus on developer happiness and productivity. 74 | 75 | ## Development 76 | 77 | Run the tests: 78 | 79 | ```sh 80 | go test ./... 81 | ``` 82 | -------------------------------------------------------------------------------- /cmd/demo/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "fmt" 5 | "io" 6 | "log" 7 | "net/http" 8 | "os" 9 | "strconv" 10 | "time" 11 | 12 | "github.com/blakewilliams/viewproxy" 13 | "github.com/blakewilliams/viewproxy/pkg/fragment" 14 | "github.com/blakewilliams/viewproxy/pkg/middleware/logging" 15 | "github.com/blakewilliams/viewproxy/pkg/multiplexer" 16 | ) 17 | 18 | func main() { 19 | target := getTarget() 20 | server, err := viewproxy.NewServer(target, viewproxy.WithPassThrough(target)) 21 | 22 | if err != nil { 23 | panic(err) 24 | } 25 | 26 | server.Addr = fmt.Sprintf("localhost:%d", getPort()) 27 | server.ProxyTimeout = time.Duration(5) * time.Second 28 | server.Logger = buildLogger() 29 | 30 | server.Get( 31 | "/hello/:name", 32 | fragment.Define("/layout/:name", fragment.WithChildren(fragment.Children{ 33 | "body":
fragment.Define("/body/:name", fragment.WithChildren(fragment.Children{ 34 | "header": fragment.Define("/header/:name", fragment.WithMetadata(map[string]string{"title": "Hello"})), 35 | "message": fragment.Define("/message/:name"), 36 | })), 37 | })), 38 | ) 39 | 40 | server.AroundResponse = func(h http.Handler) http.Handler { 41 | return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { 42 | // Strip etag header from response 43 | rw.Header().Del("etag") 44 | h.ServeHTTP(rw, r) 45 | }) 46 | } 47 | 48 | // setup middleware 49 | server.AroundRequest = func(handler http.Handler) http.Handler { 50 | handler = logging.Middleware(server, server.Logger)(handler) 51 | 52 | return handler 53 | } 54 | 55 | server.MultiplexerTripper = logging.NewLogTripper( 56 | server.Logger, 57 | server.SecretFilter, 58 | multiplexer.NewStandardTripper(&http.Client{}), 59 | ) 60 | 61 | server.ListenAndServe() 62 | } 63 | 64 | func buildLogger() *log.Logger { 65 | file, err := os.OpenFile("log", os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0666) 66 | if err != nil { 67 | log.Fatal(err) 68 | } 69 | // The file is intentionally never closed: the returned logger writes to it for the 70 | // rest of the process, so a deferred Close here would discard all file output. 71 | return log.New(io.MultiWriter(os.Stdout, file), "", log.Ldate|log.Ltime) 72 | } 73 | 74 | func getPort() int { 75 | if value, ok := os.LookupEnv("PORT"); ok { 76 | port, err := strconv.Atoi(value) 77 | 78 | if err != nil { 79 | panic(err) 80 | } 81 | 82 | return port 83 | } 84 | 85 | return 3005 86 | } 87 | 88 | func getTarget() string { 89 | if value, ok := os.LookupEnv("TARGET"); ok { 90 | return value 91 | } 92 | 93 | return "http://localhost:3000/" 94 | } 95 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/blakewilliams/viewproxy 2 | 3 | go 1.20 4 | 5 | require ( 6 | github.com/stretchr/testify v1.8.4 7 | go.opentelemetry.io/otel v1.19.0 8 | go.opentelemetry.io/otel/trace v1.19.0 9 | ) 10 | 11 | require (
12 | github.com/davecgh/go-spew v1.1.1 // indirect 13 | github.com/go-logr/logr v1.2.4 // indirect 14 | github.com/go-logr/stdr v1.2.2 // indirect 15 | github.com/kr/pretty v0.3.1 // indirect 16 | github.com/pmezard/go-difflib v1.0.0 // indirect 17 | github.com/rogpeppe/go-internal v1.10.0 // indirect 18 | go.opentelemetry.io/otel/metric v1.19.0 // indirect 19 | gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect 20 | gopkg.in/yaml.v3 v3.0.1 // indirect 21 | ) 22 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= 2 | github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= 3 | github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 4 | github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= 5 | github.com/go-logr/logr v1.2.4 h1:g01GSCwiDw2xSZfjJ2/T9M+S6pFdcNtFYsp+Y43HYDQ= 6 | github.com/go-logr/logr v1.2.4/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= 7 | github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag= 8 | github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE= 9 | github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= 10 | github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= 11 | github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= 12 | github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= 13 | github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= 14 | github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= 15 | github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= 16 | github.com/kr/text v0.2.0/go.mod 
h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= 17 | github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA= 18 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= 19 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 20 | github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= 21 | github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ= 22 | github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncjaFoBhdsK/akog= 23 | github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= 24 | github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= 25 | go.opentelemetry.io/otel v1.19.0 h1:MuS/TNf4/j4IXsZuJegVzI1cwut7Qc00344rgH7p8bs= 26 | go.opentelemetry.io/otel v1.19.0/go.mod h1:i0QyjOq3UPoTzff0PJB2N66fb4S0+rSbSB15/oyH9fY= 27 | go.opentelemetry.io/otel/metric v1.19.0 h1:aTzpGtV0ar9wlV4Sna9sdJyII5jTVJEvKETPiOKwvpE= 28 | go.opentelemetry.io/otel/metric v1.19.0/go.mod h1:L5rUsV9kM1IxCj1MmSdS+JQAcVm319EUrDVLrt7jqt8= 29 | go.opentelemetry.io/otel/trace v1.19.0 h1:DFVQmlVbfVeOuBRrwdtaehRrWiL1JoVs9CPIQ1Dzxpg= 30 | go.opentelemetry.io/otel/trace v1.19.0/go.mod h1:mfaSyvGyEJEI0nyV2I4qhNQnbBOUUmYZpYojqMnX2vo= 31 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= 32 | gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= 33 | gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= 34 | gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= 35 | gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 36 | -------------------------------------------------------------------------------- /pkg/fragment/definition.go: 
-------------------------------------------------------------------------------- 1 | package fragment 2 | 3 | import ( 4 | "fmt" 5 | "net/url" 6 | "strings" 7 | 8 | "github.com/blakewilliams/viewproxy/pkg/multiplexer" 9 | ) 10 | 11 | type Children = map[string]*Definition 12 | type Collection = []*Definition 13 | type DefinitionOption = func(*Definition) 14 | 15 | type Definition struct { 16 | Path string 17 | routeParts []string 18 | dynamicParts []string 19 | Metadata map[string]string 20 | IgnoreValidation bool 21 | children map[string]*Definition 22 | } 23 | 24 | func Define(path string, options ...DefinitionOption) *Definition { 25 | safePath := strings.TrimPrefix(path, "/") 26 | definition := &Definition{ 27 | Path: path, 28 | routeParts: strings.Split(safePath, "/"), 29 | Metadata: make(map[string]string), 30 | children: make(map[string]*Definition), 31 | } 32 | 33 | dynamicParts := make([]string, 0) 34 | for _, part := range definition.routeParts { 35 | if strings.HasPrefix(part, ":") { 36 | dynamicParts = append(dynamicParts, part) 37 | } 38 | } 39 | definition.dynamicParts = dynamicParts 40 | 41 | for _, option := range options { 42 | option(definition) 43 | } 44 | 45 | return definition 46 | } 47 | 48 | func (d *Definition) Children() map[string]*Definition { 49 | return d.children 50 | } 51 | 52 | func (d *Definition) Child(name string) *Definition { 53 | return d.children[name] 54 | } 55 | 56 | func WithChildren(children Children) DefinitionOption { 57 | return func(definition *Definition) { 58 | for name, child := range children { 59 | definition.children[name] = child 60 | } 61 | } 62 | } 63 | 64 | func WithChild(name string, child *Definition) DefinitionOption { 65 | return func(definition *Definition) { 66 | // TODO error if overwriting? 
67 | definition.children[name] = child 68 | } 69 | } 70 | 71 | func WithoutValidation() DefinitionOption { 72 | return func(definition *Definition) { 73 | definition.IgnoreValidation = true 74 | } 75 | } 76 | 77 | func WithMetadata(metadata map[string]string) DefinitionOption { 78 | return func(definition *Definition) { 79 | definition.Metadata = metadata 80 | } 81 | } 82 | 83 | func (d *Definition) DynamicParts() []string { 84 | return d.dynamicParts 85 | } 86 | 87 | func (d *Definition) Requestable(target *url.URL, pathParams map[string]string, query url.Values) (*Request, error) { 88 | var path strings.Builder 89 | 90 | for _, part := range d.routeParts { 91 | path.WriteByte('/') 92 | 93 | if strings.HasPrefix(part, ":") { 94 | if replacement, ok := pathParams[part]; ok { 95 | path.WriteString(replacement) 96 | } else { 97 | return nil, fmt.Errorf("no parameter was provided for %s in route %s", part, d.Path) 98 | } 99 | } else { 100 | path.WriteString(part) 101 | } 102 | } 103 | 104 | requestURL, err := buildURL(target, path.String(), query.Encode()) 105 | if err != nil { 106 | return nil, err 107 | } 108 | 109 | templateURL, err := buildURL(target, strings.Join(d.routeParts, "/"), "") 110 | if err != nil { 111 | return nil, err 112 | } 113 | 114 | return &Request{ 115 | RequestURL: requestURL, 116 | Definition: d, 117 | templateURL: templateURL, 118 | }, nil 119 | } 120 | 121 | func buildURL(base *url.URL, path string, query string) (*url.URL, error) { 122 | unescapedPath, err := url.PathUnescape(path) 123 | if err != nil { 124 | return nil, fmt.Errorf("could not encode url: %w", err) 125 | } 126 | 127 | u := *base // clone the url 128 | u.RawQuery = query 129 | u.Path = unescapedPath // Set unescaped path which treats %2f as a / 130 | u.RawPath = path // Set RawPath which lets go correlate %2f to / in the Path, and escape correctly when calling String() 131 | 132 | return &u, nil 133 | } 134 | 135 | type Request struct { 136 | RequestURL *url.URL 137 | 
Definition *Definition 138 | templateURL *url.URL 139 | } 140 | 141 | var _ multiplexer.Requestable = &Request{} 142 | 143 | func (fr *Request) URL() string { return fr.RequestURL.String() } 144 | func (fr *Request) TemplateURL() string { return fr.templateURL.String() } 145 | func (fr *Request) Metadata() map[string]string { return fr.Definition.Metadata } 146 | -------------------------------------------------------------------------------- /pkg/fragment/definition_test.go: -------------------------------------------------------------------------------- 1 | package fragment 2 | 3 | import ( 4 | "net/url" 5 | "testing" 6 | 7 | "github.com/stretchr/testify/require" 8 | ) 9 | 10 | var target, _ = url.Parse("http://fake.net") 11 | 12 | func TestFragment_IntoRequestable(t *testing.T) { 13 | definition := Define("/hello/:name") 14 | requestable, err := definition.Requestable( 15 | target, 16 | map[string]string{":name": "fox.mulder"}, 17 | url.Values{}, 18 | ) 19 | require.NoError(t, err) 20 | 21 | require.Equal(t, "http://fake.net/hello/fox.mulder", requestable.URL()) 22 | require.Equal(t, "http://fake.net/hello/:name", requestable.TemplateURL()) 23 | } 24 | 25 | func TestFragment_IntoRequestable_MissingDynamicPart(t *testing.T) { 26 | definition := Define("/hello/:name") 27 | _, err := definition.Requestable( 28 | target, 29 | map[string]string{}, 30 | url.Values{}, 31 | ) 32 | require.Error(t, err) 33 | require.EqualError(t, err, "no parameter was provided for :name in route /hello/:name") 34 | } 35 | 36 | func TestFragment_IntoRequestable_HandlesURLEncodings(t *testing.T) { 37 | definition := Define("/hello/:name") 38 | requestable, err := definition.Requestable( 39 | target, 40 | map[string]string{":name": "mulder%2fscully"}, 41 | url.Values{}, 42 | ) 43 | require.NoError(t, err) 44 | require.Equal(t, "http://fake.net/hello/mulder%2fscully", requestable.URL()) 45 | require.Equal(t, "http://fake.net/hello/:name", requestable.TemplateURL()) 46 | } 47 | 
-------------------------------------------------------------------------------- /pkg/middleware/logging/logging.go: -------------------------------------------------------------------------------- 1 | package logging 2 | 3 | import ( 4 | "net/http" 5 | "time" 6 | 7 | "github.com/blakewilliams/viewproxy" 8 | "github.com/blakewilliams/viewproxy/pkg/multiplexer" 9 | secretfilter "github.com/blakewilliams/viewproxy/pkg/secretfilter" 10 | ) 11 | 12 | type logger interface { 13 | Print(v ...interface{}) 14 | Printf(format string, v ...interface{}) 15 | } 16 | 17 | type ResponseWrapper struct { 18 | responseWriter http.ResponseWriter 19 | StatusCode int 20 | } 21 | 22 | func (rw *ResponseWrapper) Header() http.Header { 23 | return rw.responseWriter.Header() 24 | } 25 | 26 | func (rw *ResponseWrapper) Write(p []byte) (int, error) { 27 | return rw.responseWriter.Write(p) 28 | } 29 | 30 | func (rw *ResponseWrapper) WriteHeader(statusCode int) { 31 | rw.StatusCode = statusCode 32 | rw.responseWriter.WriteHeader(statusCode) 33 | } 34 | 35 | func Middleware(server *viewproxy.Server, l logger) func(http.Handler) http.Handler { 36 | return func(next http.Handler) http.Handler { 37 | return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 38 | start := time.Now() 39 | route := viewproxy.RouteFromContext(r.Context()) 40 | 41 | if route != nil { 42 | l.Printf("Handling %s", r.URL.Path) 43 | } else if server.PassThroughEnabled() { 44 | l.Printf("Proxying %s", r.URL.Path) 45 | } else { 46 | l.Printf("Proxying is disabled and no route matches %s", r.URL.Path) 47 | } 48 | 49 | wrapper := &ResponseWrapper{responseWriter: w, StatusCode: 200} // use default 200 to initialize 50 | next.ServeHTTP(wrapper, r) 51 | 52 | duration := time.Since(start) 53 | 54 | if route != nil { 55 | l.Printf("Rendered %d in %dms for %s", wrapper.StatusCode, duration.Milliseconds(), r.URL.Path) 56 | } else if server.PassThroughEnabled() { 57 | l.Printf("Proxied %d in %dms for %s", 
wrapper.StatusCode, duration.Milliseconds(), r.URL.Path) 58 | } 59 | }) 60 | } 61 | } 62 | 63 | type logTripper struct { 64 | logger logger 65 | secretFilter secretfilter.Filter 66 | tripper multiplexer.Tripper 67 | } 68 | 69 | func NewLogTripper(l logger, sf secretfilter.Filter, tripper multiplexer.Tripper) multiplexer.Tripper { 70 | return &logTripper{logger: l, secretFilter: sf, tripper: tripper} 71 | } 72 | 73 | func (t *logTripper) Request(r *http.Request) (*http.Response, error) { 74 | start := time.Now() 75 | res, err := t.tripper.Request(r) 76 | duration := time.Since(start) 77 | requestable := multiplexer.RequestableFromContext(r.Context()) 78 | 79 | if err != nil { 80 | if requestable != nil { 81 | // TODO fragment.URL is full path 82 | safeUrl := t.secretFilter.FilterURLString(requestable.URL()) 83 | t.logger.Printf("Fragment exception in %dms for %s\nerror: %s", duration.Milliseconds(), safeUrl, err) 84 | } else { 85 | safeUrl := t.secretFilter.FilterURL(r.URL) 86 | t.logger.Printf("Proxy exception in %dms for %s\nerror: %s", duration.Milliseconds(), safeUrl, err) 87 | } 88 | return nil, err 89 | } 90 | 91 | // If fragment is nil, we are proxying 92 | if requestable != nil { 93 | // TODO fragment.URL is full path 94 | safeUrl := t.secretFilter.FilterURLString(requestable.URL()) 95 | t.logger.Printf("Fragment %d in %dms for %s", res.StatusCode, duration.Milliseconds(), safeUrl) 96 | } else { 97 | safeUrl := t.secretFilter.FilterURL(r.URL) 98 | t.logger.Printf("Proxy request %d in %dms for %s", res.StatusCode, duration.Milliseconds(), safeUrl) 99 | } 100 | 101 | return res, err 102 | } 103 | -------------------------------------------------------------------------------- /pkg/middleware/logging/logging_test.go: -------------------------------------------------------------------------------- 1 | package logging 2 | 3 | import ( 4 | "fmt" 5 | "net/http" 6 | "net/http/httptest" 7 | "regexp" 8 | "strings" 9 | "sync" 10 | "testing" 11 | 12 | 
"github.com/blakewilliams/viewproxy" 13 | "github.com/blakewilliams/viewproxy/pkg/fragment" 14 | "github.com/blakewilliams/viewproxy/pkg/multiplexer" 15 | "github.com/blakewilliams/viewproxy/pkg/secretfilter" 16 | "github.com/stretchr/testify/require" 17 | ) 18 | 19 | type SliceLogger struct { 20 | logs []string 21 | mu sync.Mutex 22 | } 23 | 24 | func (l *SliceLogger) Print(v ...interface{}) { 25 | l.mu.Lock() 26 | defer l.mu.Unlock() 27 | l.logs = append(l.logs, fmt.Sprint(v...)) 28 | } 29 | 30 | func (l *SliceLogger) Printf(line string, args ...interface{}) { 31 | l.mu.Lock() 32 | defer l.mu.Unlock() 33 | l.logs = append(l.logs, fmt.Sprintf(line, args...)) 34 | } 35 | 36 | func TestLoggingMiddleware(t *testing.T) { 37 | targetServer := startTargetServer() 38 | viewProxyServer, err := viewproxy.NewServer(targetServer.URL) 39 | require.NoError(t, err) 40 | 41 | viewProxyServer.Get( 42 | "/hello/:name", 43 | fragment.Define( 44 | "/layouts/test_layout/:name", 45 | fragment.WithChild("body", fragment.Define("/body/:name")), 46 | ), 47 | ) 48 | 49 | log := &SliceLogger{logs: make([]string, 0)} 50 | viewProxyServer.AroundRequest = func(handler http.Handler) http.Handler { 51 | handler = Middleware(viewProxyServer, log)(handler) 52 | 53 | return handler 54 | } 55 | 56 | // Regular request with fragments 57 | r := httptest.NewRequest("GET", "/hello/world", nil) 58 | w := httptest.NewRecorder() 59 | viewProxyServer.CreateHandler().ServeHTTP(w, r) 60 | resp := w.Result() 61 | require.Equal(t, 200, resp.StatusCode) 62 | 63 | require.Equal(t, "Handling /hello/world", log.logs[0]) 64 | require.Regexp(t, regexp.MustCompile(`Rendered 200 in \d+ms for /hello/world`), log.logs[1]) 65 | 66 | // Proxying disabled 67 | r = httptest.NewRequest("GET", "/fake", nil) 68 | w = httptest.NewRecorder() 69 | viewProxyServer.CreateHandler().ServeHTTP(w, r) 70 | resp = w.Result() 71 | require.Equal(t, 404, resp.StatusCode) 72 | 73 | require.Equal(t, "Proxying is disabled and no route matches 
/fake", log.logs[2]) 74 | } 75 | 76 | func TestLogTripperFragments(t *testing.T) { 77 | targetServer := startTargetServer() 78 | viewProxyServer, err := viewproxy.NewServer(targetServer.URL, viewproxy.WithPassThrough(targetServer.URL)) 79 | require.NoError(t, err) 80 | 81 | viewProxyServer.Get( 82 | "/hello/:name", 83 | fragment.Define("/layouts/test_layout/:name", fragment.WithChild("body", fragment.Define("/body/:name"))), 84 | ) 85 | 86 | log := &SliceLogger{logs: make([]string, 0)} 87 | viewProxyServer.MultiplexerTripper = NewLogTripper(log, secretfilter.New(), multiplexer.NewStandardTripper(&http.Client{})) 88 | 89 | r := httptest.NewRequest("GET", "/hello/world", nil) 90 | w := httptest.NewRecorder() 91 | viewProxyServer.CreateHandler().ServeHTTP(w, r) 92 | resp := w.Result() 93 | require.Equal(t, 200, resp.StatusCode) 94 | 95 | require.Regexp(t, regexp.MustCompile(`Fragment 200 in \d+ms for http:\/\/.*`), log.logs[0]) 96 | require.Regexp(t, regexp.MustCompile(`Fragment 200 in \d+ms for http:\/\/.*`), log.logs[1]) 97 | } 98 | 99 | func startTargetServer() *httptest.Server { 100 | instance := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 101 | parts := strings.Split(r.URL.Path, "/") 102 | name := parts[len(parts)-1] 103 | 104 | if strings.HasPrefix(r.URL.Path, "/layouts/test_layout/") { 105 | w.WriteHeader(http.StatusOK) 106 | w.Write([]byte("")) 107 | } else if strings.HasPrefix(r.URL.Path, "/header") { 108 | w.WriteHeader(http.StatusOK) 109 | w.Write([]byte("")) 110 | } else if strings.HasPrefix(r.URL.Path, "/body/") { 111 | w.WriteHeader(http.StatusOK) 112 | w.Write([]byte(fmt.Sprintf("hello %s", name))) 113 | } else { 114 | w.WriteHeader(http.StatusNotFound) 115 | w.Write([]byte("404 not found")) 116 | } 117 | }) 118 | 119 | testServer := httptest.NewServer(instance) 120 | return testServer 121 | } 122 | -------------------------------------------------------------------------------- /pkg/multiplexer/headers.go: 
-------------------------------------------------------------------------------- 1 | package multiplexer 2 | 3 | import ( 4 | "fmt" 5 | "net" 6 | "net/http" 7 | ) 8 | 9 | // HopByHopHeaders lists hop-by-hop headers a proxy must not forward (https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers); the header is "Trailer" ("Trailers" is an RFC 2616 erratum). 10 | var HopByHopHeaders []string = []string{ 11 | "Connection", 12 | "Keep-Alive", 13 | "Proxy-Authenticate", 14 | "Proxy-Authorization", 15 | "TE", 16 | "Trailer", 17 | "Transfer-Encoding", 18 | "Upgrade", 19 | } 20 | 21 | // HeadersFromRequest copies req's headers for an upstream request, dropping hop-by-hop headers and setting Host and X-Forwarded-*. TODO: also drop headers named in the Connection header. 22 | func HeadersFromRequest(req *http.Request) http.Header { 23 | newHeaders := make(http.Header) 24 | 25 | for name, values := range req.Header { 26 | newHeaders[name] = append([]string(nil), values...) // copy so the result never shares backing arrays with req.Header 27 | } 28 | 29 | for _, hopByHopHeader := range HopByHopHeaders { 30 | newHeaders.Del(hopByHopHeader) 31 | } 32 | 33 | // Set Forwarded-For headers since we act as a proxy 34 | host := forwardedForFromRequest(req) 35 | if val := newHeaders.Get("X-Forwarded-For"); val != "" { 36 | newHeader := fmt.Sprintf("%s, %s", val, host) 37 | newHeaders.Set("X-Forwarded-For", newHeader) 38 | } else { 39 | newHeaders.Set("X-Forwarded-For", host) 40 | } 41 | 42 | // go strips the host header for some reason 43 | // https://github.com/golang/go/blob/master/src/net/http/server.go#L999 44 | newHeaders.Set("Host", req.Host) 45 | 46 | if val := newHeaders.Get("X-Forwarded-Host"); val == "" { 47 | newHeaders.Set("X-Forwarded-Host", req.Host) 48 | } 49 | if val := newHeaders.Get("X-Forwarded-Proto"); val == "" { 50 | newHeaders.Set("X-Forwarded-Proto", req.Proto) // NOTE(review): req.Proto is the protocol version (e.g. "HTTP/1.1"), not a scheme like "http"/"https"; headers_test pins this, so confirm intent before changing. 51 | } 52 | 53 | return newHeaders 54 | } 55 | 56 | func forwardedForFromRequest(req *http.Request) string { 57 | host, _, err := net.SplitHostPort(req.RemoteAddr) 58 | 59 | if err != nil { 60 | return req.RemoteAddr 61 | } 62 | 63 | return host 64 | } 65 | 66 | func WithDefaultHeaders(next http.Handler) http.Handler { 67 | return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { 68 | results :=
ResultsFromContext(r.Context()) 69 | 70 | if results != nil && len(results.Results()) > 0 { 71 | headers := results.Results()[0].HeadersWithoutProxyHeaders() 72 | for name, values := range headers { 73 | for _, value := range values { 74 | rw.Header().Add(name, value) 75 | } 76 | } 77 | 78 | rw.Header().Del("Content-Length") 79 | } 80 | 81 | next.ServeHTTP(rw, r) 82 | }) 83 | } 84 | -------------------------------------------------------------------------------- /pkg/multiplexer/headers_test.go: -------------------------------------------------------------------------------- 1 | package multiplexer 2 | 3 | import ( 4 | "net/http" 5 | "testing" 6 | 7 | "github.com/stretchr/testify/require" 8 | ) 9 | 10 | func TestPreservesForwardedHeaders(t *testing.T) { 11 | headers := http.Header{} 12 | headers.Add("X-Forwarded-For", "1.2.3.4") 13 | headers.Add("X-Forwarded-Host", "example.com") 14 | headers.Add("X-Forwarded-Proto", "httpz") 15 | fakeHTTPRequest := &http.Request{Header: headers} 16 | fakeHTTPRequest.RemoteAddr = "1.3.5.7" 17 | 18 | newHeaders := HeadersFromRequest(fakeHTTPRequest) 19 | 20 | // append X-Forwarded-For 21 | require.Equal(t, "1.2.3.4, 1.3.5.7", newHeaders.Get("X-Forwarded-For")) 22 | 23 | // preserve X-Forwarded-Host and X-Forwarded-Proto 24 | require.Equal(t, "example.com", newHeaders.Get("X-Forwarded-Host")) 25 | require.Equal(t, "httpz", newHeaders.Get("X-Forwarded-Proto")) 26 | } 27 | 28 | func TestSetsDefaultForwardedHeaders(t *testing.T) { 29 | fakeHTTPRequest := &http.Request{} 30 | fakeHTTPRequest.Proto = "httpz" 31 | fakeHTTPRequest.Host = "example.com" 32 | fakeHTTPRequest.RemoteAddr = "1.3.5.7" 33 | 34 | newHeaders := HeadersFromRequest(fakeHTTPRequest) 35 | 36 | // append X-Forwarded-For 37 | require.Equal(t, "1.3.5.7", newHeaders.Get("X-Forwarded-For")) 38 | 39 | // set default X-Forwarded-Host and X-Forwarded-Proto 40 | require.Equal(t, "example.com", newHeaders.Get("X-Forwarded-Host")) 41 | require.Equal(t, "httpz", 
newHeaders.Get("X-Forwarded-Proto")) 42 | } 43 | -------------------------------------------------------------------------------- /pkg/multiplexer/multiplexer.go: -------------------------------------------------------------------------------- 1 | package multiplexer 2 | 3 | import ( 4 | "compress/gzip" 5 | "context" 6 | "crypto/hmac" 7 | "crypto/sha256" 8 | "encoding/hex" 9 | "errors" 10 | "fmt" 11 | "io" 12 | "io/ioutil" 13 | "net/http" 14 | "net/url" 15 | "sync" 16 | "time" 17 | 18 | "github.com/blakewilliams/viewproxy/pkg/secretfilter" 19 | "go.opentelemetry.io/otel" 20 | "go.opentelemetry.io/otel/attribute" 21 | "go.opentelemetry.io/otel/trace" 22 | ) 23 | 24 | type TimeoutError struct { 25 | inner error 26 | } 27 | 28 | func (et *TimeoutError) Error() string { 29 | return fmt.Sprintf("multiplexer timed out: %s", et.inner) 30 | } 31 | 32 | func (et *TimeoutError) Unwrap() error { 33 | return et.inner 34 | } 35 | 36 | var _ error = &TimeoutError{} 37 | 38 | func newTimeoutError(inner error) *TimeoutError { 39 | return &TimeoutError{inner: inner} 40 | } 41 | 42 | type ErrRequestCanceled struct { 43 | inner error 44 | } 45 | 46 | func (ec *ErrRequestCanceled) Error() string { 47 | return fmt.Sprintf("multiplexer request was canceled: %s", ec.inner) 48 | } 49 | 50 | func (ec *ErrRequestCanceled) Unwrap() error { 51 | return ec.inner 52 | } 53 | 54 | func newCancellationError(inner error) *ErrRequestCanceled { 55 | return &ErrRequestCanceled{inner: inner} 56 | } 57 | 58 | type Request struct { 59 | ctx context.Context 60 | Header http.Header 61 | requestables []Requestable 62 | Timeout time.Duration 63 | HmacSecret string 64 | Non2xxErrors bool 65 | Tripper Tripper 66 | SecretFilter secretfilter.Filter 67 | } 68 | 69 | func NewRequest(tripper Tripper) *Request { 70 | return &Request{ 71 | ctx: context.TODO(), 72 | requestables: []Requestable{}, 73 | Timeout: time.Duration(10) * time.Second, 74 | HmacSecret: "", 75 | Non2xxErrors: true, 76 | Header: http.Header{}, 
77 | Tripper: tripper, 78 | } 79 | } 80 | 81 | func (r *Request) WithHeadersFromRequest(req *http.Request) { 82 | for key, values := range HeadersFromRequest(req) { 83 | for _, value := range values { 84 | r.Header.Add(key, value) 85 | } 86 | } 87 | } 88 | 89 | func (r *Request) WithRequestable(requestable Requestable) { 90 | r.requestables = append(r.requestables, requestable) 91 | } 92 | 93 | func (r *Request) Do(ctx context.Context) ([]*Result, error) { 94 | tracer := otel.Tracer("multiplexer") 95 | var span trace.Span 96 | ctx, span = tracer.Start(ctx, "fetch_urls") 97 | defer span.End() 98 | 99 | ctx, cancel := context.WithTimeout(ctx, r.Timeout) 100 | defer cancel() 101 | 102 | reqCount := len(r.requestables) 103 | wg := sync.WaitGroup{} 104 | wg.Add(reqCount) 105 | errCh := make(chan error, reqCount) 106 | results := make([]*Result, reqCount) 107 | 108 | for i, f := range r.requestables { 109 | reqCtx := context.WithValue(ctx, RequestableContextKey{}, f) 110 | 111 | go func(ctx context.Context, requestable Requestable, i int, wg *sync.WaitGroup) { 112 | defer wg.Done() 113 | var span trace.Span 114 | ctx, span = tracer.Start(ctx, "fetch_url") 115 | for key, value := range requestable.Metadata() { 116 | span.SetAttributes(attribute.String(key, value)) 117 | } 118 | defer span.End() 119 | 120 | headersForRequest := r.Header 121 | if r.HmacSecret != "" { 122 | headersForRequest = r.headersWithHmac(requestable.URL()) 123 | } 124 | 125 | result, err := r.fetchUrl(ctx, "GET", requestable, headersForRequest, nil) 126 | 127 | if err != nil { 128 | errCh <- r.filterError(requestable.TemplateURL(), err) 129 | } 130 | 131 | results[i] = result 132 | }(reqCtx, f, i, &wg) 133 | } 134 | 135 | // wait for all responses to complete 136 | done := make(chan struct{}) 137 | go (func(wg *sync.WaitGroup) { 138 | defer close(done) 139 | wg.Wait() 140 | })(&wg) 141 | 142 | select { 143 | case err := <-errCh: 144 | cancel() 145 | return make([]*Result, 0), err 146 | case <-done: 
147 | return results, nil 148 | case <-ctx.Done(): 149 | switch { 150 | case errors.Is(ctx.Err(), context.Canceled): 151 | return make([]*Result, 0), newCancellationError(ctx.Err()) 152 | case errors.Is(ctx.Err(), context.DeadlineExceeded): 153 | return make([]*Result, 0), newTimeoutError(ctx.Err()) 154 | default: 155 | return make([]*Result, 0), ctx.Err() 156 | } 157 | } 158 | } 159 | 160 | func (r *Request) fetchUrl(ctx context.Context, method string, requestable Requestable, headers http.Header, body io.ReadCloser) (*Result, error) { 161 | start := time.Now() 162 | 163 | req, err := http.NewRequestWithContext(ctx, method, requestable.URL(), body) 164 | 165 | if err != nil { 166 | return nil, err 167 | } 168 | 169 | for name, values := range headers { 170 | for _, value := range values { 171 | req.Header.Add(name, value) 172 | } 173 | } 174 | 175 | resp, err := r.Tripper.Request(req) 176 | 177 | if err != nil { 178 | return nil, err 179 | } 180 | 181 | defer resp.Body.Close() 182 | duration := time.Since(start) 183 | 184 | var responseBody []byte 185 | 186 | if resp.Header.Get("Content-Encoding") == "gzip" { 187 | gzipReader, err := gzip.NewReader(resp.Body) 188 | if err != nil { 189 | return nil, err 190 | } 191 | defer gzipReader.Close() 192 | 193 | responseBody, err = ioutil.ReadAll(gzipReader) 194 | 195 | if err != nil { 196 | return nil, err 197 | } 198 | } else { 199 | responseBody, err = ioutil.ReadAll(resp.Body) 200 | 201 | if err != nil { 202 | return nil, err 203 | } 204 | } 205 | 206 | result := &Result{ 207 | Url: requestable.URL(), 208 | Duration: duration, 209 | HttpResponse: resp, 210 | Body: responseBody, 211 | StatusCode: resp.StatusCode, 212 | } 213 | 214 | if r.Non2xxErrors && (resp.StatusCode < 200 || resp.StatusCode > 299) { 215 | return nil, newResultError(requestable.TemplateURL(), r, result) 216 | } 217 | 218 | return result, nil 219 | } 220 | 221 | func (r *Request) headersWithHmac(url string) http.Header { 222 | newHeaders := 
http.Header{} 223 | for name, value := range r.Header { 224 | newHeaders[name] = value 225 | } 226 | 227 | timestamp := fmt.Sprintf("%d", time.Now().Unix()) 228 | 229 | mac := hmac.New(sha256.New, []byte(r.HmacSecret)) 230 | mac.Write( 231 | []byte(fmt.Sprintf("%s,%s", pathFromFullUrl(url), timestamp)), 232 | ) 233 | 234 | newHeaders.Set("Authorization", hex.EncodeToString(mac.Sum(nil))) 235 | newHeaders.Set("X-Authorization-Time", timestamp) 236 | 237 | return newHeaders 238 | } 239 | 240 | func (r *Request) filterError(errURL string, err error) error { 241 | var urlErr *url.Error 242 | if errors.As(err, &urlErr) { 243 | return r.SecretFilter.FilterURLError(errURL, urlErr) 244 | } 245 | 246 | return err 247 | } 248 | 249 | func pathFromFullUrl(fullUrl string) string { 250 | targetUrl, _ := url.Parse(fullUrl) 251 | 252 | if targetUrl.RawQuery != "" { 253 | return fmt.Sprintf("%s?%s", targetUrl.Path, targetUrl.RawQuery) 254 | } else { 255 | return targetUrl.Path 256 | } 257 | } 258 | -------------------------------------------------------------------------------- /pkg/multiplexer/multiplexer_test.go: -------------------------------------------------------------------------------- 1 | package multiplexer 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | "fmt" 7 | "net" 8 | "net/http" 9 | "testing" 10 | "time" 11 | 12 | "github.com/blakewilliams/viewproxy/pkg/secretfilter" 13 | "github.com/stretchr/testify/require" 14 | ) 15 | 16 | var defaultTimeout = time.Duration(5) * time.Second 17 | 18 | type fakeRequestable struct { 19 | templateURL string 20 | url string 21 | } 22 | 23 | func (ff *fakeRequestable) URL() string { return ff.url } 24 | func (ff *fakeRequestable) TemplateURL() string { return ff.templateURL } 25 | func (ff *fakeRequestable) Metadata() map[string]string { return make(map[string]string) } 26 | func newFakeRequestable(url string) *fakeRequestable { 27 | return &fakeRequestable{url: url, templateURL: url} 28 | } 29 | 30 | var _ Requestable = 
&fakeRequestable{} 31 | 32 | func TestRequestDoReturnsMultipleResponsesInOrder(t *testing.T) { 33 | server := startServer(t) 34 | urls := []string{"http://localhost:9990?fragment=header", "http://localhost:9990?fragment=footer"} 35 | 36 | r := newRequest() 37 | r.WithRequestable(newFakeRequestable(urls[0])) 38 | r.WithRequestable(newFakeRequestable(urls[1])) 39 | r.Timeout = defaultTimeout 40 | results, err := r.Do(context.TODO()) 41 | 42 | require.Nil(t, err) 43 | 44 | require.Equal(t, 2, len(results), "Expected 2 results") 45 | 46 | require.Equal(t, 200, results[0].StatusCode) 47 | require.Equal(t, "", string(results[0].Body), "Expected first result body to be opening body tag") 48 | require.Equal(t, urls[0], results[0].Url) 49 | require.Greater(t, results[0].Duration, time.Duration(0)) 50 | 51 | require.Equal(t, 200, results[1].StatusCode) 52 | require.Equal(t, "", string(results[1].Body), "Expected last result body to be closing body tag") 53 | require.Equal(t, urls[1], results[1].Url) 54 | require.Greater(t, results[1].Duration, time.Duration(0)) 55 | 56 | server.Close() 57 | } 58 | 59 | func TestRequestDoForwardsHeaders(t *testing.T) { 60 | server := startServer(t) 61 | headers := http.Header{} 62 | headers.Add("X-Name", "viewproxy") 63 | 64 | fakeHTTPRequest := &http.Request{Header: headers} 65 | 66 | r := newRequest() 67 | r.WithRequestable(newFakeRequestable("http://localhost:9990?fragment=echo_headers")) 68 | r.WithHeadersFromRequest(fakeHTTPRequest) 69 | r.Timeout = defaultTimeout 70 | results, err := r.Do(context.TODO()) 71 | 72 | require.Nil(t, err) 73 | 74 | require.Contains(t, string(results[0].Body), "X-Name:viewproxy", "Expected X-Name header to be present") 75 | 76 | server.Close() 77 | } 78 | 79 | func TestFetch404ReturnsError(t *testing.T) { 80 | server := startServer(t) 81 | 82 | r := newRequest() 83 | r.WithRequestable(newFakeRequestable("http://localhost:9990/wowomg")) 84 | r.Timeout = defaultTimeout 85 | results, err := r.Do(context.TODO()) 
86 | 87 | var resultErr *ResultError 88 | require.ErrorAs(t, err, &resultErr) 89 | require.Equal(t, 404, resultErr.Result.StatusCode) 90 | require.Equal(t, "http://localhost:9990/wowomg", resultErr.Result.Url) 91 | require.Equal(t, 0, len(results), "Expected 0 results") 92 | 93 | server.Close() 94 | } 95 | 96 | func TestResultErrorMessagesFilterUrls(t *testing.T) { 97 | server := startServer(t) 98 | 99 | r := newRequest() 100 | req := newFakeRequestable("http://localhost:9990/wowomg?foo=bar") 101 | req.templateURL = "http://localhost:9990/:name" 102 | r.WithRequestable(req) 103 | r.Timeout = defaultTimeout 104 | _, err := r.Do(context.TODO()) 105 | 106 | var resultErr *ResultError 107 | require.ErrorAs(t, err, &resultErr) 108 | require.Equal(t, "status: 404 url: http://localhost:9990/:name?foo=FILTERED", resultErr.Error()) 109 | 110 | server.Close() 111 | } 112 | 113 | func TestRequestErrorMessagesFilterUrls(t *testing.T) { 114 | server := startServer(t) 115 | 116 | r := newRequest() 117 | req := newFakeRequestable("http://localhost:9990/wowomg?fragment=bad_gateway&foo=bar") 118 | req.templateURL = "http://localhost:9990/:name?fragment=bad_gateway&foo=bar" 119 | r.WithRequestable(req) 120 | r.Timeout = defaultTimeout 121 | _, err := r.Do(context.TODO()) 122 | 123 | require.Error(t, err) 124 | require.Equal(t, "Get \"http://localhost:9990/:name?foo=FILTERED&fragment=FILTERED\": EOF", err.Error()) 125 | 126 | server.Close() 127 | } 128 | 129 | func TestFetch500ReturnsError(t *testing.T) { 130 | server := startServer(t) 131 | start := time.Now() 132 | 133 | urls := []string{"http://localhost:9990/?fragment=oops", "http://localhost:9990?fragment=slow"} 134 | r := newRequest() 135 | r.WithRequestable(newFakeRequestable(urls[0])) 136 | r.WithRequestable(newFakeRequestable(urls[1])) 137 | results, err := r.Do(context.TODO()) 138 | 139 | duration := time.Since(start) 140 | 141 | require.Less(t, duration, time.Duration(3)*time.Second) 142 | var resultErr *ResultError 143 | 
require.ErrorAs(t, err, &resultErr) 144 | require.Equal(t, 500, resultErr.Result.StatusCode) 145 | require.Equal(t, "http://localhost:9990/?fragment=oops", resultErr.Result.Url) 146 | require.Equal(t, 0, len(results), "Expected 0 results") 147 | 148 | server.Close() 149 | } 150 | 151 | func TestFetchTimeout(t *testing.T) { 152 | server := startServer(t) 153 | start := time.Now() 154 | 155 | r := newRequest() 156 | r.WithRequestable(newFakeRequestable("http://localhost:9990?fragment=slow")) 157 | r.Timeout = time.Duration(100) * time.Millisecond 158 | _, err := r.Do(context.Background()) 159 | duration := time.Since(start) 160 | 161 | require.EqualError(t, err, "multiplexer timed out: context deadline exceeded") 162 | require.Less(t, duration, time.Duration(120)*time.Millisecond) 163 | 164 | server.Close() 165 | } 166 | 167 | func TestFetchCancelled(t *testing.T) { 168 | server := startServer(t) 169 | defer server.Close() 170 | 171 | r := newRequest() 172 | r.WithRequestable(newFakeRequestable("http://localhost:9990?fragment=slow")) 173 | 174 | ctx, cancel := context.WithCancel(context.Background()) 175 | cancel() 176 | 177 | _, err := r.Do(ctx) 178 | 179 | require.EqualError(t, err, "multiplexer request was canceled: context canceled") 180 | } 181 | 182 | func TestCanIgnoreNon2xxErrors(t *testing.T) { 183 | server := startServer(t) 184 | 185 | r := newRequest() 186 | r.WithRequestable(newFakeRequestable("http://localhost:9990/?fragment=oops")) 187 | r.Non2xxErrors = false 188 | 189 | results, err := r.Do(context.Background()) 190 | 191 | require.Nil(t, err) 192 | require.Len(t, results, 1) 193 | require.Equal(t, 500, results[0].StatusCode) 194 | 195 | server.Close() 196 | } 197 | 198 | func startServer(t *testing.T) *http.Server { 199 | var testServer *http.Server 200 | 201 | instance := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 202 | params := r.URL.Query() 203 | fragment := params.Get("fragment") 204 | 205 | if fragment == "header" { 206 | 
w.Write([]byte("")) 207 | } else if fragment == "body" { 208 | w.Write([]byte(fmt.Sprintf("hello %s", params.Get("name")))) 209 | } else if fragment == "footer" { 210 | w.Write([]byte("")) 211 | } else if fragment == "slow" { 212 | time.Sleep(time.Duration(3) * time.Second) 213 | w.Write([]byte("")) 214 | } else if fragment == "oops" { 215 | w.WriteHeader(http.StatusInternalServerError) 216 | w.Write([]byte("500")) 217 | } else if fragment == "echo_headers" { 218 | for name, values := range r.Header { 219 | for _, value := range values { 220 | w.Write( 221 | []byte(fmt.Sprintf("%s:%s\n", name, value)), 222 | ) 223 | } 224 | } 225 | } else if fragment == "bad_gateway" { 226 | testServer.Close() 227 | } else { 228 | w.WriteHeader(http.StatusNotFound) 229 | w.Write([]byte("Not found")) 230 | } 231 | }) 232 | 233 | listener, err := net.Listen("tcp", "localhost:9990") 234 | require.NoError(t, err) 235 | 236 | testServer = &http.Server{Handler: instance} 237 | go func() { 238 | if err := testServer.Serve(listener); err != nil && err != http.ErrServerClosed { 239 | require.NoError(t, err) 240 | } 241 | }() 242 | 243 | return testServer 244 | } 245 | 246 | func TestTimeoutError(t *testing.T) { 247 | originalError := errors.New("omg") 248 | err := newTimeoutError(originalError) 249 | 250 | require.Equal(t, "multiplexer timed out: omg", err.Error()) 251 | require.Equal(t, originalError, err.Unwrap()) 252 | } 253 | 254 | func newRequest() *Request { 255 | r := NewRequest(NewStandardTripper(&http.Client{})) 256 | r.SecretFilter = secretfilter.New() 257 | return r 258 | } 259 | -------------------------------------------------------------------------------- /pkg/multiplexer/requestable.go: -------------------------------------------------------------------------------- 1 | package multiplexer 2 | 3 | import "context" 4 | 5 | type RequestableContextKey struct{} 6 | 7 | type Requestable interface { 8 | URL() string 9 | TemplateURL() string 10 | Metadata() map[string]string 11 | } 
12 | 13 | func RequestableFromContext(ctx context.Context) Requestable { 14 | if ctx == nil { 15 | return nil 16 | } 17 | 18 | if requestable := ctx.Value(RequestableContextKey{}); requestable != nil { 19 | requestable := requestable.(Requestable) 20 | return requestable 21 | } 22 | return nil 23 | } 24 | -------------------------------------------------------------------------------- /pkg/multiplexer/result.go: -------------------------------------------------------------------------------- 1 | package multiplexer 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "net/http" 7 | "time" 8 | ) 9 | 10 | type ResultError struct { 11 | Result *Result 12 | msg string 13 | } 14 | 15 | type Results interface { 16 | Error() error 17 | Results() []*Result 18 | } 19 | 20 | func newResultError(errURL string, req *Request, res *Result) *ResultError { 21 | safeUrl := req.SecretFilter.FilterURLStringThrough(res.Url, errURL) 22 | msg := fmt.Sprintf("status: %d url: %s", res.StatusCode, safeUrl) 23 | 24 | return &ResultError{Result: res, msg: msg} 25 | } 26 | 27 | func (re *ResultError) Error() string { 28 | return re.msg 29 | } 30 | 31 | type Result struct { 32 | Url string 33 | Duration time.Duration 34 | HttpResponse *http.Response 35 | Body []byte 36 | StatusCode int 37 | } 38 | 39 | func (r *Result) Header() http.Header { 40 | return r.HttpResponse.Header 41 | } 42 | 43 | func (r *Result) HeadersWithoutProxyHeaders() http.Header { 44 | headers := make(http.Header) 45 | 46 | for name, values := range r.Header() { 47 | headers[name] = values 48 | } 49 | 50 | for _, hopByHopHeader := range HopByHopHeaders { 51 | headers.Del(hopByHopHeader) 52 | } 53 | 54 | return headers 55 | } 56 | 57 | type resultsWrapper struct { 58 | err error 59 | results []*Result 60 | startTime time.Time 61 | } 62 | 63 | func (r *resultsWrapper) Results() []*Result { 64 | return r.results 65 | } 66 | 67 | func (r *resultsWrapper) Error() error { 68 | return r.err 69 | } 70 | 71 | func (r *resultsWrapper) 
StartTime() time.Time { 72 | return r.startTime 73 | } 74 | 75 | type resultsContextKey struct{} 76 | 77 | func ResultsFromContext(ctx context.Context) Results { 78 | if ctx == nil { 79 | return nil 80 | } 81 | 82 | if results := ctx.Value(resultsContextKey{}); results != nil { 83 | return results.(Results) 84 | } 85 | return nil 86 | } 87 | 88 | func ContextWithResults(ctx context.Context, results []*Result, err error) context.Context { 89 | return context.WithValue(ctx, resultsContextKey{}, &resultsWrapper{results: results, err: err}) 90 | } 91 | -------------------------------------------------------------------------------- /pkg/multiplexer/tripper.go: -------------------------------------------------------------------------------- 1 | package multiplexer 2 | 3 | import "net/http" 4 | 5 | type Tripper interface { 6 | Request(r *http.Request) (*http.Response, error) 7 | } 8 | 9 | type standardTripper struct { 10 | client *http.Client 11 | } 12 | 13 | // Creates a new instance of a Tripper. The passed in client is modified to 14 | // have no cookie jar and to not follow redirects. 
15 | func NewStandardTripper(client *http.Client) Tripper { 16 | client.Jar = nil 17 | client.CheckRedirect = func(req *http.Request, via []*http.Request) error { 18 | return http.ErrUseLastResponse 19 | } 20 | 21 | return &standardTripper{client: client} 22 | } 23 | 24 | func (t *standardTripper) Request(r *http.Request) (*http.Response, error) { 25 | return t.client.Do(r) 26 | } 27 | -------------------------------------------------------------------------------- /pkg/routeimporter/config.go: -------------------------------------------------------------------------------- 1 | package routeimporter 2 | 3 | import ( 4 | "github.com/blakewilliams/viewproxy" 5 | "github.com/blakewilliams/viewproxy/pkg/fragment" 6 | ) 7 | 8 | type ConfigFragment struct { 9 | Path string 10 | Metadata map[string]string 11 | IgnoreValidation bool 12 | Children map[string]ConfigFragment 13 | } 14 | 15 | type ConfigRouteEntry struct { 16 | Path string 17 | Root ConfigFragment `json:"root"` 18 | Metadata map[string]string `json:"metadata"` 19 | IgnoreValidation bool 20 | } 21 | 22 | func LoadRoutes(server *viewproxy.Server, routeEntries []ConfigRouteEntry) error { 23 | for _, routeEntry := range routeEntries { 24 | root := createFragment(routeEntry.Root) 25 | 26 | err := server.Get( 27 | routeEntry.Path, 28 | root, 29 | viewproxy.WithRouteMetadata(routeEntry.Metadata), 30 | ) 31 | 32 | if err != nil { 33 | return err 34 | } 35 | } 36 | 37 | return nil 38 | } 39 | 40 | func createFragment(template ConfigFragment) *fragment.Definition { 41 | f := fragment.Define(template.Path, fragment.WithMetadata(template.Metadata)) 42 | f.IgnoreValidation = template.IgnoreValidation 43 | 44 | for name, child := range template.Children { 45 | fragment.WithChild(name, createFragment(child))(f) 46 | } 47 | 48 | return f 49 | } 50 | -------------------------------------------------------------------------------- /pkg/routeimporter/config_test.go: 
-------------------------------------------------------------------------------- 1 | package routeimporter 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/blakewilliams/viewproxy" 7 | "github.com/stretchr/testify/require" 8 | ) 9 | 10 | func TestLoadRoutesError(t *testing.T) { 11 | server, err := viewproxy.NewServer("localhost:9999") 12 | require.NoError(t, err) 13 | 14 | entry := ConfigRouteEntry{ 15 | Path: "/foo/bar", 16 | Root: ConfigFragment{Path: "/layout/:name"}, 17 | } 18 | 19 | err = LoadRoutes(server, []ConfigRouteEntry{entry}) 20 | require.Error(t, err) 21 | } 22 | -------------------------------------------------------------------------------- /pkg/routeimporter/http.go: -------------------------------------------------------------------------------- 1 | package routeimporter 2 | 3 | import ( 4 | "context" 5 | "crypto/hmac" 6 | "crypto/sha256" 7 | "encoding/hex" 8 | "encoding/json" 9 | "fmt" 10 | "io" 11 | "net/http" 12 | "net/url" 13 | "time" 14 | 15 | "github.com/blakewilliams/viewproxy" 16 | ) 17 | 18 | func LoadHttp(ctx context.Context, server *viewproxy.Server, path string) error { 19 | var routeEntries []ConfigRouteEntry 20 | 21 | target, err := url.Parse(server.Target()) 22 | 23 | if err != nil { 24 | return fmt.Errorf("could not parse target: %w", err) 25 | } 26 | 27 | target.Path = path 28 | req, err := http.NewRequestWithContext(ctx, http.MethodGet, target.String(), nil) 29 | 30 | if err != nil { 31 | return fmt.Errorf("Could not create a request when loading config: %w", err) 32 | } 33 | 34 | if server.HmacSecret != "" { 35 | setHmacHeaders(req, server.HmacSecret) 36 | } 37 | 38 | resp, err := http.DefaultClient.Do(req) 39 | 40 | if err != nil { 41 | return fmt.Errorf("could not fetch JSON configuration: %w", err) 42 | } 43 | 44 | routesJson, err := io.ReadAll(resp.Body) 45 | 46 | if err != nil { 47 | return fmt.Errorf("could not read route config response body: %w", err) 48 | } 49 | 50 | if err := json.Unmarshal(routesJson, 
&routeEntries); err != nil { 51 | return fmt.Errorf("could not unmarshal route config json: %w", err) 52 | } 53 | 54 | if err = LoadRoutes(server, routeEntries); err != nil { 55 | return fmt.Errorf("could not load routes into server: %w", err) 56 | } 57 | 58 | return ctx.Err() 59 | } 60 | 61 | func setHmacHeaders(r *http.Request, hmacSecret string) { 62 | timestamp := fmt.Sprintf("%d", time.Now().Unix()) 63 | 64 | mac := hmac.New(sha256.New, []byte(hmacSecret)) 65 | mac.Write( 66 | []byte(fmt.Sprintf("%s,%s", r.URL.Path, timestamp)), 67 | ) 68 | 69 | r.Header.Set("Authorization", hex.EncodeToString(mac.Sum(nil))) 70 | r.Header.Set("X-Authorization-Time", timestamp) 71 | } 72 | -------------------------------------------------------------------------------- /pkg/routeimporter/http_test.go: -------------------------------------------------------------------------------- 1 | package routeimporter 2 | 3 | import ( 4 | "context" 5 | "crypto/hmac" 6 | "crypto/sha256" 7 | "encoding/hex" 8 | "fmt" 9 | "io/ioutil" 10 | "log" 11 | "net/http" 12 | "net/http/httptest" 13 | "testing" 14 | "time" 15 | 16 | "github.com/blakewilliams/viewproxy" 17 | "github.com/stretchr/testify/require" 18 | ) 19 | 20 | var jsonConfig = []byte(`[ 21 | { 22 | "path": "/users/new", 23 | "metadata": { 24 | "controller": "sessions" 25 | }, 26 | "root": { 27 | "path": "/_viewproxy/users/new/layout", 28 | "children": { 29 | "content": { 30 | "path": "/_viewproxy/users/new/content" 31 | } 32 | } 33 | } 34 | } 35 | ]`) 36 | 37 | func TestLoadHttp(t *testing.T) { 38 | targetServer := startTargetServer() 39 | defer targetServer.CloseClientConnections() 40 | defer targetServer.Close() 41 | 42 | viewproxyServer, err := viewproxy.NewServer(targetServer.URL) 43 | require.NoError(t, err) 44 | viewproxyServer.Logger = log.New(ioutil.Discard, "", log.Ldate|log.Ltime) 45 | 46 | err = LoadHttp(context.TODO(), viewproxyServer, "/_viewproxy_routes") 47 | require.NoError(t, err) 48 | 49 | 
requireJsonConfigRoutesLoaded(t, viewproxyServer.Routes()) 50 | } 51 | 52 | func TestLoadHttp_ContextTimeout(t *testing.T) { 53 | targetServer := startTargetServer() 54 | defer targetServer.CloseClientConnections() 55 | defer targetServer.Close() 56 | 57 | viewproxyServer, err := viewproxy.NewServer(targetServer.URL) 58 | require.NoError(t, err) 59 | viewproxyServer.Logger = log.New(ioutil.Discard, "", log.Ldate|log.Ltime) 60 | 61 | start := time.Now() 62 | ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*20) 63 | defer cancel() 64 | 65 | err = LoadHttp(ctx, viewproxyServer, "/_viewproxy_routes?sleepy=1") 66 | require.Error(t, err) 67 | 68 | <-ctx.Done() 69 | duration := time.Now().Sub(start) 70 | 71 | require.LessOrEqual(t, duration, time.Millisecond*40) 72 | } 73 | 74 | func TestLoadHttp_HMAC(t *testing.T) { 75 | hmacSecret := "abc123" 76 | 77 | instance := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 78 | authorization := r.Header.Get("Authorization") 79 | timestamp := r.Header.Get("X-Authorization-Time") 80 | 81 | require.NotEmpty(t, authorization) 82 | require.NotEmpty(t, timestamp) 83 | 84 | mac := hmac.New(sha256.New, []byte(hmacSecret)) 85 | mac.Write( 86 | []byte(fmt.Sprintf("%s,%s", r.URL.Path, timestamp)), 87 | ) 88 | 89 | require.Equal(t, hex.EncodeToString(mac.Sum(nil)), authorization) 90 | 91 | w.Write(jsonConfig) 92 | }) 93 | 94 | testServer := httptest.NewServer(instance) 95 | defer testServer.CloseClientConnections() 96 | defer testServer.Close() 97 | 98 | viewproxyServer, err := viewproxy.NewServer(testServer.URL) 99 | require.NoError(t, err) 100 | viewproxyServer.HmacSecret = hmacSecret 101 | viewproxyServer.Logger = log.New(ioutil.Discard, "", log.Ldate|log.Ltime) 102 | 103 | err = LoadHttp(context.TODO(), viewproxyServer, "/_viewproxy_routes") 104 | require.NoError(t, err) 105 | 106 | requireJsonConfigRoutesLoaded(t, viewproxyServer.Routes()) 107 | } 108 | 109 | func startTargetServer() 
*httptest.Server { 110 | instance := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 111 | if r.URL.Query().Get("sleepy") == "1" { 112 | time.Sleep(time.Millisecond * 500) 113 | } 114 | 115 | if r.URL.Path == "/_viewproxy_routes" { 116 | w.Header().Set("Content-Type", "text/json") 117 | w.WriteHeader(http.StatusOK) 118 | 119 | w.Write(jsonConfig) 120 | } else { 121 | w.WriteHeader(http.StatusNotFound) 122 | w.Write([]byte("target: 404 not found")) 123 | } 124 | }) 125 | 126 | testServer := httptest.NewServer(instance) 127 | return testServer 128 | } 129 | 130 | func requireJsonConfigRoutesLoaded(t *testing.T, routes []viewproxy.Route) { 131 | require.Len(t, routes, 1) 132 | route := routes[0] 133 | 134 | require.Equal(t, "/users/new", route.Path) 135 | require.Equal(t, "sessions", route.Metadata["controller"]) 136 | require.Len(t, route.FragmentsToRequest(), 2) 137 | 138 | require.Contains(t, "/_viewproxy/users/new/layout", route.RootFragment.Path) 139 | require.Contains(t, "/_viewproxy/users/new/content", route.RootFragment.Child("content").Path) 140 | } 141 | -------------------------------------------------------------------------------- /pkg/routeimporter/json.go: -------------------------------------------------------------------------------- 1 | package routeimporter 2 | 3 | import ( 4 | "encoding/json" 5 | "fmt" 6 | "io/ioutil" 7 | "os" 8 | 9 | "github.com/blakewilliams/viewproxy" 10 | ) 11 | // LoadJSONFile reads the JSON route config at filepath and registers its routes on server via LoadJSON. 12 | func LoadJSONFile(server *viewproxy.Server, filepath string) error { 13 | file, err := os.Open(filepath) 14 | 15 | if err != nil { 16 | return fmt.Errorf("could not open config file: %w", err) 17 | } 18 | defer file.Close() // release the descriptor on every return path below 19 | routesJSON, err := ioutil.ReadAll(file) 20 | 21 | if err != nil { 22 | return fmt.Errorf("could not read config file: %w", err) 23 | } 24 | 25 | err = LoadJSON(server, routesJSON) 26 | 27 | if err != nil { 28 | return fmt.Errorf("could not load config: %w", err) 29 | } 30 | 31 | return nil 32 | } 33 | // LoadJSON unmarshals routesJSON into ConfigRouteEntry values and registers them on server with LoadRoutes. 34 | func LoadJSON(server 
*viewproxy.Server, routesJSON []byte) error { 35 | var routeEntries []ConfigRouteEntry 36 | 37 | if err := json.Unmarshal(routesJSON, &routeEntries); err != nil { 38 | return fmt.Errorf("could not unmarshal in loadJSON: %w", err) 39 | } 40 | 41 | err := LoadRoutes(server, routeEntries) 42 | 43 | if err != nil { 44 | return fmt.Errorf("could not load routes in loadJSON: %w", err) 45 | } 46 | 47 | return nil 48 | } 49 | -------------------------------------------------------------------------------- /pkg/routeimporter/json_test.go: -------------------------------------------------------------------------------- 1 | package routeimporter 2 | 3 | import ( 4 | "io/ioutil" 5 | "log" 6 | "os" 7 | "testing" 8 | 9 | "github.com/blakewilliams/viewproxy" 10 | "github.com/stretchr/testify/require" 11 | ) 12 | 13 | func TestLoadJSONFile(t *testing.T) { 14 | viewproxyServer, err := viewproxy.NewServer("http://fake.net") 15 | require.NoError(t, err) 16 | viewproxyServer.Logger = log.New(ioutil.Discard, "", log.Ldate|log.Ltime) 17 | 18 | // Load routes from config 19 | file, err := ioutil.TempFile(os.TempDir(), "config.json") 20 | require.NoError(t, err) 21 | defer os.Remove(file.Name()) 22 | 23 | _, err = file.Write([]byte(jsonConfig)) 24 | require.NoError(t, err) 25 | require.NoError(t, file.Close()) 26 | 27 | // Assert the load itself so a failure is reported here rather than as a route mismatch below. 28 | require.NoError(t, LoadJSONFile(viewproxyServer, file.Name())) 29 | 30 | requireJsonConfigRoutesLoaded(t, viewproxyServer.Routes()) 31 | } 32 | -------------------------------------------------------------------------------- /pkg/secretfilter/secretfilter.go: -------------------------------------------------------------------------------- 1 | package secretfilter 2 | 3 | import ( 4 | "net/url" 5 | "strings" 6 | ) 7 | 8 | type Filter interface { 9 | Allow(string) 10 | IsAllowed(string) bool 11 | FilterURL(url *url.URL) *url.URL 12 | FilterURLString(url string) string 13 | FilterURLStringThrough(source string, target string) string 14 | FilterQueryParams(params url.Values) url.Values 15 | FilterURLError(errURL string, err 
*url.Error) *url.Error 16 | } 17 | 18 | type mapKey struct{} 19 | 20 | type secretFilter struct { 21 | allowedMap map[string]mapKey 22 | } 23 | 24 | var _ Filter = &secretFilter{} 25 | 26 | func New() Filter { 27 | return &secretFilter{allowedMap: make(map[string]mapKey)} 28 | } 29 | 30 | func (l *secretFilter) Allow(key string) { 31 | l.allowedMap[strings.ToLower(key)] = mapKey{} 32 | } 33 | 34 | func (l *secretFilter) IsAllowed(key string) bool { 35 | if _, ok := l.allowedMap[strings.ToLower(key)]; ok { 36 | return true 37 | } 38 | 39 | return false 40 | } 41 | 42 | func (l *secretFilter) FilterURLString(urlString string) string { 43 | parsedUrl, err := url.Parse(urlString) 44 | 45 | if err != nil { 46 | return "FILTEREDINVALIDURL" 47 | } 48 | 49 | return l.FilterURL(parsedUrl).String() 50 | } 51 | // FilterURL returns a copy of originalUrl whose userinfo, and every query value whose key was not passed to Allow, is replaced with "FILTERED"; originalUrl itself is left unchanged. 52 | func (l *secretFilter) FilterURL(originalUrl *url.URL) *url.URL { 53 | clonedUrl := *originalUrl // struct copy instead of re-parsing String(), which could fail and yield nil; User is only ever replaced, never mutated 54 | 55 | if clonedUrl.User != nil { 56 | clonedUrl.User = url.UserPassword("FILTERED", "FILTERED") 57 | } 58 | 59 | filteredParams := l.FilterQueryParams(clonedUrl.Query()) 60 | clonedUrl.RawQuery = filteredParams.Encode() 61 | 62 | return &clonedUrl 63 | } 64 | 65 | func (l *secretFilter) FilterQueryParams(query url.Values) url.Values { 66 | filteredQueryParams := make(url.Values, len(query)) 67 | 68 | for key, values := range query { 69 | for _, value := range values { 70 | if l.IsAllowed(key) { 71 | filteredQueryParams.Add(key, value) 72 | } else { 73 | filteredQueryParams.Add(key, "FILTERED") 74 | } 75 | } 76 | } 77 | 78 | return filteredQueryParams 79 | } 80 | 81 | func (l *secretFilter) FilterURLStringThrough(source string, target string) string { 82 | // Copy query params from source to target 83 | parsedSource, parseErr := url.Parse(source) 84 | if parseErr == nil { 85 | parsedTarget, parseErr := url.Parse(target) 86 | if parseErr == nil { 87 | parsedTarget.RawQuery = parsedSource.RawQuery 88 | target = parsedTarget.String() 89 | } 90 | }
91 | 92 | return l.FilterURLString(target) 93 | } 94 | 95 | func (l *secretFilter) FilterURLError(errURL string, err *url.Error) *url.Error { 96 | return &url.Error{ 97 | Op: err.Op, 98 | URL: l.FilterURLStringThrough(err.URL, errURL), 99 | Err: err.Err, 100 | } 101 | } 102 | -------------------------------------------------------------------------------- /pkg/secretfilter/secretfilter_test.go: -------------------------------------------------------------------------------- 1 | package secretfilter 2 | 3 | import ( 4 | "io" 5 | "net/url" 6 | "testing" 7 | 8 | "github.com/stretchr/testify/require" 9 | ) 10 | 11 | func TestSecretFilter_FilterUrl(t *testing.T) { 12 | original, err := url.Parse("http://localhost/foo?a=1") 13 | require.NoError(t, err) 14 | 15 | filter := New() 16 | filtered := filter.FilterURL(original) 17 | 18 | require.Equal(t, "http://localhost/foo?a=1", original.String()) 19 | require.Equal(t, "http://localhost/foo?a=FILTERED", filtered.String()) 20 | } 21 | 22 | func TestSecretFilter_FilterUrlUserInfo(t *testing.T) { 23 | original, err := url.Parse("http://foo:password@localhost/foo?a=1") 24 | require.NoError(t, err) 25 | 26 | filter := New() 27 | filtered := filter.FilterURL(original) 28 | 29 | require.Equal(t, "http://FILTERED:FILTERED@localhost/foo?a=FILTERED", filtered.String()) 30 | } 31 | 32 | func TestSecretFilter_FilterUrlString(t *testing.T) { 33 | tests := map[string]struct { 34 | input string 35 | allow []string 36 | want string 37 | }{ 38 | "no allowed parameters": { 39 | input: "http://localhost/foo/bar?a=1&b=2", 40 | allow: []string{}, 41 | want: "http://localhost/foo/bar?a=FILTERED&b=FILTERED", 42 | }, 43 | "allowed param": { 44 | input: "http://localhost/foo/bar?a=1&b=2", 45 | allow: []string{"a"}, 46 | want: "http://localhost/foo/bar?a=1&b=FILTERED", 47 | }, 48 | "path only url": { 49 | input: "/foo/bar?a=1&b=2", 50 | allow: []string{"a"}, 51 | want: "/foo/bar?a=1&b=FILTERED", 52 | }, 53 | "mixed capitalization parameters": { 54 | 
input: "/foo/bar?A=1&b=2", 55 | allow: []string{"a"}, 56 | want: "/foo/bar?A=1&b=FILTERED", 57 | }, 58 | "invalid url": { 59 | input: "http://%41:8080/", 60 | allow: []string{}, 61 | want: "FILTEREDINVALIDURL", 62 | }, 63 | } 64 | 65 | for name, tc := range tests { 66 | t.Run(name, func(t *testing.T) { 67 | filter := New() 68 | for _, value := range tc.allow { 69 | filter.Allow(value) 70 | } 71 | 72 | require.Equal(t, tc.want, filter.FilterURLString(tc.input)) 73 | }) 74 | } 75 | } 76 | 77 | func TestSecretFilter_FilterQueryParams(t *testing.T) { 78 | tests := map[string]struct { 79 | input url.Values 80 | allow []string 81 | want url.Values 82 | }{ 83 | "no allowed params": { 84 | input: map[string][]string{"a": {"1"}, "b": {"2"}}, 85 | allow: []string{}, 86 | want: map[string][]string{"a": {"FILTERED"}, "b": {"FILTERED"}}, 87 | }, 88 | 89 | "allowed params": { 90 | input: map[string][]string{"a": {"1"}, "b": {"2"}}, 91 | allow: []string{"a", "b"}, 92 | want: map[string][]string{"a": {"1"}, "b": {"2"}}, 93 | }, 94 | } 95 | 96 | for name, tc := range tests { 97 | t.Run(name, func(t *testing.T) { 98 | filter := New() 99 | for _, value := range tc.allow { 100 | filter.Allow(value) 101 | } 102 | 103 | require.Equal(t, tc.want, filter.FilterQueryParams(tc.input)) 104 | }) 105 | } 106 | } 107 | 108 | func TestSecretFilter_FilterUrlError(t *testing.T) { 109 | original := &url.Error{ 110 | Op: "Get", 111 | Err: io.EOF, 112 | URL: "http://localhost/foo?a=1", 113 | } 114 | 115 | filter := New() 116 | filtered := filter.FilterURLError("http://localhost/:foo", original) 117 | 118 | require.Equal(t, "http://localhost/:foo?a=FILTERED", filtered.URL) 119 | require.Equal(t, "Get", filtered.Op) 120 | require.Equal(t, io.EOF, filtered.Err) 121 | } 122 | -------------------------------------------------------------------------------- /response_builder.go: -------------------------------------------------------------------------------- 1 | package viewproxy 2 | 3 | import ( 4 | 
"bytes" 5 | "compress/gzip" 6 | "fmt" 7 | "net/http" 8 | "strconv" 9 | "time" 10 | 11 | "github.com/blakewilliams/viewproxy/pkg/multiplexer" 12 | ) 13 | 14 | type responseBuilder struct { 15 | writer http.ResponseWriter 16 | server Server 17 | body []byte 18 | StatusCode int 19 | } 20 | 21 | func newResponseBuilder(server Server, w http.ResponseWriter) *responseBuilder { 22 | return &responseBuilder{server: server, writer: w, StatusCode: 200} 23 | } 24 | 25 | func (rb *responseBuilder) SetFragments(route *Route, results []*multiplexer.Result) { 26 | resultMap := mapResultsToFragmentKey(route, results) 27 | rb.body = stitch(route.structure, resultMap) 28 | } 29 | 30 | func (rb *responseBuilder) SetDuration(duration int64) { 31 | outputHtml := bytes.Replace(rb.body, []byte(""), []byte(strconv.FormatInt(duration, 10)), 1) 32 | rb.body = outputHtml 33 | } 34 | // Write sends StatusCode and then the body, gzip-compressing the body when the response's Content-Encoding header is "gzip". 35 | func (rb *responseBuilder) Write() { 36 | rb.writer.WriteHeader(rb.StatusCode) 37 | 38 | if rb.writer.Header().Get("Content-Encoding") == "gzip" { 39 | var b bytes.Buffer 40 | gzipWriter := gzip.NewWriter(&b) 41 | 42 | _, err := gzipWriter.Write(rb.body) 43 | if err != nil { 44 | rb.server.Logger.Printf("Could not write to gzip buffer: %s", err) 45 | } 46 | // Close flushes the remaining compressed bytes and the gzip footer, so its error must be captured too. 47 | err = gzipWriter.Close() 48 | if err != nil { 49 | rb.server.Logger.Printf("Could not close gzip buffer: %s", err) 50 | } 51 | 52 | rb.writer.Write(b.Bytes()) 53 | } else { 54 | rb.writer.Write(rb.body) 55 | } 56 | } 57 | 58 | func withDefaultErrorHandler(next http.Handler) http.Handler { 59 | return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { 60 | results := multiplexer.ResultsFromContext(r.Context()) 61 | 62 | if results != nil && results.Error() != nil { 63 | rw.WriteHeader(http.StatusInternalServerError) 64 | rw.Write([]byte("500 internal server error")) 65 | } else { 66 | next.ServeHTTP(rw, r) 67 | } 68 | }) 69 | } 70 | 71 | func withCombinedFragments(s *Server) http.Handler { 72 | return http.HandlerFunc(func(rw 
http.ResponseWriter, r *http.Request) { 73 | route := RouteFromContext(r.Context()) 74 | results := multiplexer.ResultsFromContext(r.Context()) 75 | 76 | if results != nil && results.Error() == nil { 77 | resBuilder := newResponseBuilder(*s, rw) 78 | resBuilder.SetFragments(route, results.Results()) 79 | elapsed := time.Since(startTimeFromContext(r.Context())) 80 | resBuilder.SetDuration(elapsed.Milliseconds()) 81 | resBuilder.Write() 82 | } 83 | }) 84 | } 85 | 86 | func stitch(structure *stitchStructure, results map[string]*multiplexer.Result) []byte { 87 | childContent := make(map[string][]byte) 88 | 89 | for _, childBuild := range structure.DependentStructures() { 90 | childContent[childBuild.ReplacementID()] = stitch(childBuild, results) 91 | } 92 | 93 | self := results[structure.Key()].Body 94 | 95 | // handle edge fragments 96 | if len(childContent) == 0 { 97 | return self 98 | } 99 | 100 | for replacementKey, content := range childContent { 101 | directive := []byte(fmt.Sprintf("", replacementKey)) 102 | self = bytes.Replace(self, directive, content, 1) 103 | } 104 | 105 | return self 106 | } 107 | 108 | func mapResultsToFragmentKey(route *Route, results []*multiplexer.Result) map[string]*multiplexer.Result { 109 | resultMap := make(map[string]*multiplexer.Result, len(route.FragmentOrder())) 110 | 111 | for i, key := range route.FragmentOrder() { 112 | resultMap[key] = results[i] 113 | } 114 | 115 | return resultMap 116 | } 117 | -------------------------------------------------------------------------------- /route.go: -------------------------------------------------------------------------------- 1 | package viewproxy 2 | 3 | import ( 4 | "fmt" 5 | "reflect" 6 | "sort" 7 | "strings" 8 | 9 | "github.com/blakewilliams/viewproxy/pkg/fragment" 10 | ) 11 | 12 | type RouteValidationError struct { 13 | Route *Route 14 | Fragment *fragment.Definition 15 | } 16 | 17 | func (rve *RouteValidationError) Error() string { 18 | if len(rve.Route.dynamicParts) > 0 { 19 | 
return fmt.Sprintf( 20 | "dynamic route %s has mismatched fragment route %s", 21 | rve.Route.Path, 22 | rve.Fragment.Path, 23 | ) 24 | } else { 25 | return fmt.Sprintf( 26 | "static route %s has mismatched fragment route %s", 27 | rve.Route.Path, 28 | rve.Fragment.Path, 29 | ) 30 | } 31 | } 32 | 33 | type Route struct { 34 | Path string 35 | Parts []string 36 | dynamicParts []string 37 | RootFragment *fragment.Definition 38 | Metadata map[string]string 39 | // memoized version of the mapping used to stitch fragments back together 40 | structure *stitchStructure 41 | // memoized version of fragments to request 42 | fragmentsToRequest []*fragment.Definition 43 | // memoized version mapping fragment names to multiplexer.Result order 44 | fragmentOrder []string 45 | } 46 | 47 | func newRoute(path string, metadata map[string]string, root *fragment.Definition) *Route { 48 | route := &Route{ 49 | Path: path, 50 | Parts: strings.Split(path, "/"), 51 | Metadata: metadata, 52 | RootFragment: root, 53 | } 54 | 55 | dynamicParts := make([]string, 0) 56 | for _, part := range route.Parts { 57 | if strings.HasPrefix(part, ":") { 58 | dynamicParts = append(dynamicParts, part) 59 | } 60 | } 61 | route.dynamicParts = dynamicParts 62 | route.structure = stitchStructureFor(root) 63 | 64 | route.memoizeFragments() 65 | 66 | return route 67 | } 68 | 69 | // Validate returns a RouteValidationError for the first fragment whose dynamic parts differ from the route's, skipping fragments that ignore validation.
70 | func (r *Route) Validate() error { 71 | for _, fragment := range r.FragmentsToRequest() { 72 | if !fragment.IgnoreValidation && !compareStringSlice(r.dynamicParts, fragment.DynamicParts()) { 73 | return &RouteValidationError{Route: r, Fragment: fragment} 74 | } 75 | } 76 | 77 | return nil 78 | } 79 | // FragmentOrder returns the sorted fragment keys; multiplexer results are matched to fragments in this order. 80 | func (r *Route) FragmentOrder() []string { 81 | return r.fragmentOrder 82 | } 83 | // FragmentsToRequest returns the route's fragments in the same order as FragmentOrder. 84 | func (r *Route) FragmentsToRequest() []*fragment.Definition { 85 | return r.fragmentsToRequest 86 | } 87 | // compareStringSlice reports whether both slices hold the same elements regardless of order. It sorts copies, so the callers' slices (such as a fragment's dynamic parts) are not reordered. 88 | func compareStringSlice(first []string, other []string) bool { 89 | first, other = append([]string(nil), first...), append([]string(nil), other...) 90 | sort.Strings(first) 91 | sort.Strings(other) 92 | return reflect.DeepEqual(first, other) 93 | } 94 | // dynamicPartsFromRequest maps each ":name" segment of the route to the segment at the same position in path; positions that path does not have are skipped. 95 | func (r *Route) dynamicPartsFromRequest(path string) map[string]string { 96 | dynamicParts := make(map[string]string) 97 | routeParts := strings.Split(path, "/") 98 | 99 | for i, part := range r.Parts { 100 | if i < len(routeParts) && strings.HasPrefix(part, ":") { 101 | dynamicParts[part] = routeParts[i] 102 | } 103 | } 104 | 105 | return dynamicParts 106 | } 107 | // matchParts reports whether pathParts has as many segments as the route and every static route segment matches exactly. 108 | func (r *Route) matchParts(pathParts []string) bool { 109 | if len(r.Parts) != len(pathParts) { 110 | return false 111 | } 112 | 113 | for i := 0; i < len(r.Parts); i++ { 114 | if r.Parts[i] != pathParts[i] && !strings.HasPrefix(r.Parts[i], ":") { 115 | return false 116 | } 117 | } 118 | 119 | return true 120 | } 121 | // parametersFor maps each dynamic segment name (without its leading ':') to its value in pathParts; pathParts is assumed to have passed matchParts. 122 | func (r *Route) parametersFor(pathParts []string) map[string]string { 123 | parameters := make(map[string]string) 124 | 125 | for i := 0; i < len(r.Parts); i++ { 126 | if strings.HasPrefix(r.Parts[i], ":") { 127 | paramName := r.Parts[i][1:] 128 | parameters[paramName] = pathParts[i] 129 | } 130 | } 131 | 132 | return parameters 133 | } 134 | // memoizeFragments caches the sorted fragment keys and the matching fragment list returned by FragmentOrder and FragmentsToRequest. 135 | func (r *Route) memoizeFragments() { 136 | mapping := fragmentMapping(r.RootFragment) 137 | 138 | keys := make([]string, 0, len(mapping)) 139 | 140 | for key := range mapping { 141 | keys = append(keys, key) 142 | } 143 | 144 | sort.Strings(keys) 145 | 146 | r.fragmentOrder = 
keys 147 | 148 | fragments := make([]*fragment.Definition, 0, len(keys)) 149 | for _, key := range keys { 150 | fragments = append(fragments, mapping[key]) 151 | } 152 | 153 | r.fragmentsToRequest = fragments 154 | } 155 | 156 | // fragmentMapping returns a map of fragment keys and their fragments. 157 | // 158 | // Fragment keys consist of each parent's name separated by a `.`. The top-level 159 | // fragment is always named root and child fragments are named after their key 160 | // in the parent's `Children` map. e.g. `root.layout.header` 161 | func fragmentMapping(f *fragment.Definition) map[string]*fragment.Definition { 162 | mapping := make(map[string]*fragment.Definition) 163 | mapping["root"] = f 164 | 165 | for name, child := range f.Children() { 166 | mapChildFragment("root", name, child, mapping) 167 | } 168 | 169 | return mapping 170 | } 171 | 172 | func mapChildFragment(prefix string, name string, f *fragment.Definition, mapping map[string]*fragment.Definition) { 173 | key := prefix + "." 
+ name 174 | mapping[key] = f 175 | 176 | for name, child := range f.Children() { 177 | mapChildFragment(key, name, child, mapping) 178 | } 179 | } 180 | -------------------------------------------------------------------------------- /route_test.go: -------------------------------------------------------------------------------- 1 | package viewproxy 2 | 3 | import ( 4 | "reflect" 5 | "strings" 6 | "testing" 7 | 8 | fragment "github.com/blakewilliams/viewproxy/pkg/fragment" 9 | "github.com/stretchr/testify/require" 10 | ) 11 | 12 | func TestRoute_MatchParts(t *testing.T) { 13 | tests := map[string]struct { 14 | routePath string 15 | providedUrl string 16 | want bool 17 | }{ 18 | "root": {routePath: "/", providedUrl: "/", want: true}, 19 | "mismatched root route": {routePath: "/", providedUrl: "/hello-world", want: false}, 20 | "matching static routes": {routePath: "/hello/world", providedUrl: "/hello/world", want: true}, 21 | "mismatched static routes": {routePath: "/hello/world", providedUrl: "/hello/false", want: false}, 22 | "valid dynamic route": {routePath: "/hello/:name", providedUrl: "/hello/world", want: true}, 23 | "invalid dynamic route": {routePath: "/hello/:name", providedUrl: "/hello/world/wow", want: false}, 24 | } 25 | 26 | for name, test := range tests { 27 | t.Run(name, func(t *testing.T) { 28 | route := newRoute(test.routePath, map[string]string{}, fragment.Define("")) 29 | providedUrlParts := strings.Split(test.providedUrl, "/") 30 | got := route.matchParts(providedUrlParts) 31 | 32 | if got != test.want { 33 | t.Fatalf("expected route %s matching URL %s to be %v, got %v", test.routePath, test.providedUrl, test.want, got) 34 | } 35 | }) 36 | } 37 | } 38 | 39 | func TestRoute_ParametersFor(t *testing.T) { 40 | tests := map[string]struct { 41 | routePath string 42 | providedUrl string 43 | want map[string]string 44 | }{ 45 | "simple": {routePath: "/", providedUrl: "/", want: map[string]string{}}, 46 | "multi false": {routePath: "/hello/:name", providedUrl: "/hello/world", want: 
map[string]string{"name": "world"}}, 47 | } 48 | 49 | for name, test := range tests { 50 | t.Run(name, func(t *testing.T) { 51 | route := newRoute(test.routePath, map[string]string{}, fragment.Define("")) 52 | providedUrlParts := strings.Split(test.providedUrl, "/") 53 | got := route.parametersFor(providedUrlParts) 54 | 55 | if !reflect.DeepEqual(got, test.want) { 56 | t.Fatalf("expected route %v with URL %s to have params: %v\n got: %v", test.routePath, test.providedUrl, test.want, got) 57 | } 58 | }) 59 | } 60 | } 61 | 62 | func TestRoute_Validate(t *testing.T) { 63 | testCases := map[string]struct { 64 | routePath string 65 | root *fragment.Definition 66 | errorString string 67 | valid bool 68 | }{ 69 | "static routes": { 70 | routePath: "/foo", 71 | root: fragment.Define("/foo/layout", fragment.WithChild( 72 | "body", fragment.Define("body"), 73 | )), 74 | }, 75 | "dynamic route matching": { 76 | routePath: "/hello/:name", 77 | root: fragment.Define("/_viewproxy/hello/:name/layout", fragment.WithChild( 78 | "body", fragment.Define("/_viewproxy/hello/:name/body"), 79 | )), 80 | }, 81 | "dynamic route matching with different order": { 82 | routePath: "/:greeting/:name", 83 | root: fragment.Define("/_viewproxy/:greeting/:name/layout", fragment.WithChild( 84 | "body", fragment.Define("/_viewproxy/hello/:name/:greeting/body"), 85 | )), 86 | }, 87 | "dynamic route layout not matching": { 88 | routePath: "/hello/:name", 89 | root: fragment.Define("/_viewproxy/hello/:login/layout", fragment.WithChild( 90 | "body", fragment.Define("/_viewproxy/hello/:name/body"), 91 | )), 92 | errorString: "dynamic route /hello/:name has mismatched fragment route /_viewproxy/hello/:login/layout", 93 | }, 94 | "dynamic route layout not matching without validation": { 95 | routePath: "/hello/:name", 96 | root: fragment.Define("/_viewproxy/hello/:login/layout", fragment.WithoutValidation(), fragment.WithChild( 97 | "body", fragment.Define("/_viewproxy/hello/:name/body"), 98 | )), 99 | },
100 | "dynamic route body not matching": { 101 | routePath: "/hello/:name", 102 | root: fragment.Define("/_viewproxy/hello/:name/layout", fragment.WithChild( 103 | "body", fragment.Define("/_viewproxy/hello/:login/body"), 104 | )), 105 | errorString: "dynamic route /hello/:name has mismatched fragment route /_viewproxy/hello/:login/body", 106 | }, 107 | "dynamic route body not matching without validation": { 108 | routePath: "/hello/:name", 109 | root: fragment.Define("/_viewproxy/hello/:name/layout", fragment.WithChild( 110 | "body", fragment.Define("/_viewproxy/hello/:login/body", fragment.WithoutValidation()), 111 | )), 112 | }, 113 | "static route with dynamic layout": { 114 | routePath: "/foo", 115 | root: fragment.Define("/_viewproxy/hello/:name/layout", fragment.WithChild( 116 | "body", fragment.Define("body"), 117 | )), 118 | errorString: "static route /foo has mismatched fragment route /_viewproxy/hello/:name/layout", 119 | }, 120 | "static route with dynamic body": { 121 | routePath: "/foo", 122 | root: fragment.Define("/_viewproxy/foo/layout", fragment.WithChild( 123 | "body", fragment.Define("/_viewproxy/hello/:name/body"), 124 | )), 125 | errorString: "static route /foo has mismatched fragment route /_viewproxy/hello/:name/body", 126 | }, 127 | } 128 | for name, tc := range testCases { 129 | t.Run(name, func(t *testing.T) { 130 | route := newRoute(tc.routePath, map[string]string{}, tc.root) 131 | 132 | err := route.Validate() 133 | 134 | if tc.errorString == "" { 135 | require.NoError(t, err) 136 | } else { 137 | require.EqualError(t, err, tc.errorString) 138 | } 139 | }) 140 | } 141 | } 142 | 143 | func TestFragmentMapping(t *testing.T) { 144 | header := fragment.Define("header") 145 | footer := fragment.Define("footer") 146 | body := fragment.Define("body", fragment.WithChild("header", header), fragment.WithChild("footer", footer)) 147 | 148 | root := fragment.Define( 149 | "/hello/:name", 150 | fragment.WithChild("body", body), 151 | ) 152 | 153 | 
mapping := fragmentMapping(root) 154 | 155 | require.Equal(t, footer, mapping["root.body.footer"]) 156 | require.Equal(t, header, mapping["root.body.header"]) 157 | require.Equal(t, body, mapping["root.body"]) 158 | require.Equal(t, root, mapping["root"]) 159 | } 160 | -------------------------------------------------------------------------------- /server.go: -------------------------------------------------------------------------------- 1 | package viewproxy 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "log" 7 | "net" 8 | "net/http" 9 | "net/http/httputil" 10 | "net/url" 11 | "strings" 12 | "time" 13 | 14 | "github.com/blakewilliams/viewproxy/pkg/fragment" 15 | "github.com/blakewilliams/viewproxy/pkg/multiplexer" 16 | "github.com/blakewilliams/viewproxy/pkg/secretfilter" 17 | "go.opentelemetry.io/otel" 18 | "go.opentelemetry.io/otel/propagation" 19 | "go.opentelemetry.io/otel/trace" 20 | ) 21 | 22 | const ( 23 | HeaderViewProxyOriginalPath = "X-Viewproxy-Original-Path" 24 | ) 25 | 26 | // Re-export ResultError for convenience 27 | type ResultError = multiplexer.ResultError 28 | 29 | type logger interface { 30 | Fatal(v ...interface{}) 31 | Fatalf(format string, v ...interface{}) 32 | Fatalln(v ...interface{}) 33 | Panic(v ...interface{}) 34 | Panicf(format string, v ...interface{}) 35 | Panicln(v ...interface{}) 36 | Print(v ...interface{}) 37 | Printf(format string, v ...interface{}) 38 | Println(v ...interface{}) 39 | } 40 | 41 | type Server struct { 42 | Addr string 43 | // Sets the maximum duration for requests made to the target server 44 | ProxyTimeout time.Duration 45 | // Sets the maximum duration for reading the entire request, including the body 46 | ReadTimeout time.Duration 47 | // Sets the maximum duration before timing out writes of the response 48 | WriteTimeout time.Duration 49 | // Ignores incoming request's trailing slashes when trying to match a 50 | // request URL to a route. 
This only applies to routes that are not declared 51 | // with an explicit trailing slash. 52 | IgnoreTrailingSlash bool 53 | routes []Route 54 | target string 55 | targetURL *url.URL 56 | httpServer *http.Server 57 | reverseProxy *httputil.ReverseProxy 58 | Logger logger 59 | passThrough bool 60 | SecretFilter secretfilter.Filter 61 | // Sets the secret used to generate an HMAC that can be used by the target 62 | // server to validate that a request came from viewproxy. 63 | // 64 | // When set, two headers are sent to the target URL for fragment and layout 65 | // requests. The `X-Authorization-Timestamp` header, which is a timestamp 66 | // generated at the start of the request, and `X-Authorization`, which is a 67 | // hex encoded HMAC of `urlPathWithQueryParams,timestamp`. 68 | HmacSecret string 69 | // The tripper the multiplexer uses to fetch fragments from the target server. 70 | // NewServer defaults it to a standard tripper wrapping an http.Client; 71 | // pass-through requests use the reverse proxy instead. 72 | MultiplexerTripper multiplexer.Tripper 73 | // A function to wrap the entire request handling with other middleware 74 | AroundRequest func(http.Handler) http.Handler 75 | // A function to wrap around the generating of the response after the fragment 76 | // requests have completed or errored 77 | AroundResponse func(http.Handler) http.Handler 78 | } 79 | 80 | type ServerOption = func(*Server) error 81 | 82 | type routeContextKey struct{} 83 | type parametersContextKey struct{} 84 | type startTimeKey struct{} 85 | 86 | const defaultTimeout = 10 * time.Second 87 | 88 | func emptyMiddleware(h http.Handler) http.Handler { return h } 89 | 90 | // NewServer returns a new Server that will make requests to the given target argument.
91 | func NewServer(target string, opts ...ServerOption) (*Server, error) { 92 | targetURL, err := url.Parse(target) 93 | 94 | if err != nil { 95 | return nil, err 96 | } 97 | 98 | server := &Server{ 99 | MultiplexerTripper: multiplexer.NewStandardTripper(&http.Client{}), 100 | Logger: log.Default(), 101 | SecretFilter: secretfilter.New(), 102 | Addr: "localhost:3005", 103 | ProxyTimeout: defaultTimeout, 104 | ReadTimeout: defaultTimeout, 105 | WriteTimeout: defaultTimeout, 106 | passThrough: false, 107 | AroundRequest: emptyMiddleware, 108 | AroundResponse: emptyMiddleware, 109 | IgnoreTrailingSlash: true, 110 | target: target, 111 | targetURL: targetURL, 112 | routes: make([]Route, 0), 113 | } 114 | 115 | for _, fn := range opts { 116 | err := fn(server) 117 | 118 | if err != nil { 119 | return nil, fmt.Errorf("viewproxy.ServerOption error: %w", err) 120 | } 121 | } 122 | 123 | return server, nil 124 | } 125 | // WithPassThrough makes the server reverse-proxy requests that match no route to passthroughTarget. 126 | func WithPassThrough(passthroughTarget string) ServerOption { 127 | return func(server *Server) error { 128 | targetURL, err := url.Parse(passthroughTarget) 129 | 130 | if err != nil { 131 | return fmt.Errorf("WithPassThrough error: %w", err) 132 | } 133 | 134 | server.passThrough = true 135 | server.reverseProxy = httputil.NewSingleHostReverseProxy(targetURL) 136 | 137 | return nil 138 | } 139 | } 140 | // PassThroughEnabled reports whether the WithPassThrough option was applied. 141 | func (s *Server) PassThroughEnabled() bool { 142 | return s.passThrough 143 | } 144 | // GetOption customizes a Route before Get validates and registers it. 145 | type GetOption = func(*Route) 146 | // WithRouteMetadata replaces the route's Metadata with metadata. 147 | func WithRouteMetadata(metadata map[string]string) GetOption { 148 | return func(route *Route) { 149 | route.Metadata = metadata 150 | } 151 | } 152 | // Get registers a route for path backed by the root fragment, returning the route's Validate error if validation fails. 153 | func (s *Server) Get(path string, root *fragment.Definition, opts ...GetOption) error { 154 | route := newRoute(path, map[string]string{}, root) 155 | 156 | for _, opt := range opts { 157 | opt(route) 158 | } 159 | 160 | err := route.Validate() 161 | if err != nil { 162 | return err 163 | } 164 | 165 | s.routes = append(s.routes, *route) 166 | 167 | return nil
168 | } 169 | 170 | // Target returns the configured http target. 171 | func (s *Server) Target() string { 172 | return s.target 173 | } 174 | 175 | // Routes returns the routes defined on the server; the slice is the server's own, not a copy. 176 | func (s *Server) Routes() []Route { 177 | return s.routes 178 | } 179 | // Shutdown gracefully shuts down the HTTP server created by ListenAndServe or Serve; httpServer is nil until one of those has run. 180 | func (s *Server) Shutdown(ctx context.Context) error { 181 | return s.httpServer.Shutdown(ctx) 182 | } 183 | // Close immediately closes the HTTP server created by ListenAndServe or Serve, discarding any close error. 184 | func (s *Server) Close() { 185 | s.httpServer.Close() 186 | } 187 | 188 | // TODO this should probably be a tree structure for faster lookups 189 | func (s *Server) MatchingRoute(path string) (*Route, map[string]string) { 190 | if s.IgnoreTrailingSlash && path != "/" { 191 | path = strings.TrimRight(path, "/") 192 | } 193 | parts := strings.Split(path, "/") 194 | 195 | for _, route := range s.routes { 196 | if route.matchParts(parts) { 197 | parameters := route.parametersFor(parts) 198 | return &route, parameters 199 | } 200 | } 201 | 202 | return nil, nil 203 | } 204 | // rootHandler extracts propagated trace context, starts a "ServeHTTP" span, and stores the matched route and parameters in the request context before calling next. 205 | func (s *Server) rootHandler(next http.Handler) http.Handler { 206 | return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 207 | ctx := r.Context() 208 | ctx = otel.GetTextMapPropagator().Extract(ctx, propagation.HeaderCarrier(r.Header)) 209 | 210 | tracer := otel.Tracer("server") 211 | var span trace.Span 212 | ctx, span = tracer.Start(ctx, "ServeHTTP") 213 | defer span.End() 214 | 215 | route, parameters := s.MatchingRoute(r.URL.EscapedPath()) 216 | 217 | if route != nil { 218 | ctx = context.WithValue(ctx, routeContextKey{}, route) 219 | ctx = context.WithValue(ctx, parametersContextKey{}, parameters) 220 | } 221 | 222 | next.ServeHTTP(w, r.WithContext(ctx)) 223 | }) 224 | } 225 | // requestHandler serves matched routes via handleRequest and everything else via handlePassThrough. 226 | func (s *Server) requestHandler() http.Handler { 227 | responseHandler := s.createResponseHandler() 228 | 229 | return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 230 | ctx := r.Context() 231 | route := RouteFromContext(ctx) 232 | if route != nil { 233 | parameters
:= ParametersFromContext(ctx) 234 | s.handleRequest(w, r, route, parameters, ctx, responseHandler) 235 | } else { 236 | s.handlePassThrough(w, r) 237 | } 238 | }) 239 | } 240 | // CreateHandler returns the complete handler chain: rootHandler wrapping AroundRequest wrapping requestHandler. 241 | func (s *Server) CreateHandler() http.Handler { 242 | return s.rootHandler(s.AroundRequest(s.requestHandler())) 243 | } 244 | // createResponseHandler wraps withCombinedFragments in withDefaultErrorHandler, then AroundResponse, then multiplexer.WithDefaultHeaders (outermost). 245 | func (s *Server) createResponseHandler() http.Handler { 246 | handler := withCombinedFragments(s) 247 | handler = withDefaultErrorHandler(handler) 248 | handler = s.AroundResponse(handler) 249 | handler = multiplexer.WithDefaultHeaders(handler) 250 | 251 | return handler 252 | } 253 | // newRequest returns a multiplexer request configured with the server's tripper, secret filter, and proxy timeout. 254 | func (s *Server) newRequest() *multiplexer.Request { 255 | req := multiplexer.NewRequest(s.MultiplexerTripper) 256 | req.SecretFilter = s.SecretFilter 257 | req.Timeout = s.ProxyTimeout 258 | return req 259 | } 260 | // handleRequest requests every fragment of route from the target server, then passes the results (or error) to handler through the request context. 261 | func (s *Server) handleRequest(w http.ResponseWriter, r *http.Request, route *Route, parameters map[string]string, ctx context.Context, handler http.Handler) { 262 | startTime := time.Now() 263 | req := s.newRequest() 264 | req.HmacSecret = s.HmacSecret 265 | 266 | for _, f := range route.FragmentsToRequest() { 267 | query := url.Values{} 268 | 269 | for name, values := range r.URL.Query() { 270 | if query.Get(name) == "" { 271 | for _, value := range values { 272 | query.Add(name, value) 273 | } 274 | } 275 | } 276 | 277 | dynamicParts := route.dynamicPartsFromRequest(r.URL.EscapedPath()) 278 | requestable, err := f.Requestable(s.targetURL, dynamicParts, query) 279 | if err != nil { 280 | // This can be caused due to invalid encoding; it is checked before requestable is dereferenced below. 281 | panic(err) 282 | } 283 | 284 | if len(r.URL.Query()) > 0 { 285 | requestable.RequestURL.RawQuery = query.Encode() 286 | } 287 | req.WithRequestable(requestable) 288 | } 289 | 290 | req.WithHeadersFromRequest(r) 291 | req.Header.Set(HeaderViewProxyOriginalPath, r.URL.RequestURI()) 292 | results, err := req.Do(ctx) 293 | 294 | handlerCtx := context.WithValue(r.Context(), startTimeKey{}, startTime) 295 | handlerCtx =
multiplexer.ContextWithResults(handlerCtx, results, err) 296 | handler.ServeHTTP(w, r.WithContext(handlerCtx)) 297 | } 298 | // handlePassThrough forwards r to the reverse proxy when pass-through is enabled and otherwise responds with a 404. 299 | func (s *Server) handlePassThrough(w http.ResponseWriter, r *http.Request) { 300 | if s.passThrough { 301 | s.reverseProxy.ServeHTTP(w, r) 302 | } else { 303 | w.WriteHeader(404) 304 | w.Write([]byte("404 not found")) 305 | } 306 | } 307 | // RouteFromContext returns the *Route stored in ctx by the server, or nil if there is none. 308 | func RouteFromContext(ctx context.Context) *Route { 309 | if ctx == nil { 310 | return nil 311 | } 312 | 313 | if route := ctx.Value(routeContextKey{}); route != nil { 314 | return route.(*Route) 315 | } 316 | return nil 317 | } 318 | // ParametersFromContext returns the route parameters stored in ctx by the server, or nil if there are none. 319 | func ParametersFromContext(ctx context.Context) map[string]string { 320 | if ctx == nil { 321 | return nil 322 | } 323 | 324 | if parameters := ctx.Value(parametersContextKey{}); parameters != nil { 325 | return parameters.(map[string]string) 326 | } 327 | return nil 328 | } 329 | // startTimeFromContext returns the request start time stored by handleRequest, or the zero time. 330 | func startTimeFromContext(ctx context.Context) time.Time { 331 | if ctx == nil { 332 | return time.Time{} 333 | } 334 | 335 | if startTime := ctx.Value(startTimeKey{}); startTime != nil { 336 | return startTime.(time.Time) 337 | } 338 | return time.Time{} 339 | } 340 | // FragmentRouteFromContext returns the fragment definition behind the requestable in ctx, or nil when ctx holds no *fragment.Request. 341 | func FragmentRouteFromContext(ctx context.Context) *fragment.Definition { 342 | requestable := multiplexer.RequestableFromContext(ctx) 343 | 344 | if requestable == nil { 345 | return nil 346 | } 347 | 348 | if fragmentRequest, ok := requestable.(*fragment.Request); ok { 349 | return fragmentRequest.Definition 350 | } 351 | 352 | return nil 353 | } 354 | // ListenAndServe starts serving on s.Addr with the configured handler and timeouts. 355 | func (s *Server) ListenAndServe() error { 356 | return s.configureServer(func() error { 357 | s.Logger.Printf("Listening on %v", s.Addr) 358 | return s.httpServer.ListenAndServe() 359 | }) 360 | } 361 | // Serve starts serving on listener with the configured handler and timeouts. 362 | func (s *Server) Serve(listener net.Listener) error { 363 | return s.configureServer(func() error { 364 | s.Logger.Printf("Listening on %v", listener.Addr()) 365 | return s.httpServer.Serve(listener) 366 | }) 367 | } 368 | // configureServer builds s.httpServer from the current settings and then runs serveFn. 369 | func (s *Server) 
configureServer(serveFn func() error) error { 370 | s.httpServer = &http.Server{ 371 | Addr: s.Addr, 372 | Handler: s.CreateHandler(), 373 | ReadTimeout: s.ReadTimeout, 374 | WriteTimeout: s.WriteTimeout, 375 | MaxHeaderBytes: 1 << 20, 376 | } 377 | 378 | return serveFn() 379 | } 380 | -------------------------------------------------------------------------------- /server_test.go: -------------------------------------------------------------------------------- 1 | package viewproxy 2 | 3 | import ( 4 | "bytes" 5 | "compress/gzip" 6 | "context" 7 | "crypto/hmac" 8 | "crypto/sha256" 9 | "encoding/hex" 10 | "fmt" 11 | "io" 12 | "io/ioutil" 13 | "log" 14 | "net/http" 15 | "net/http/httptest" 16 | "net/url" 17 | "os" 18 | "strings" 19 | "sync" 20 | "testing" 21 | "time" 22 | 23 | "github.com/blakewilliams/viewproxy/pkg/fragment" 24 | "github.com/blakewilliams/viewproxy/pkg/multiplexer" 25 | "github.com/stretchr/testify/require" 26 | ) 27 | 28 | var targetServer *httptest.Server 29 | 30 | func TestMain(m *testing.M) { 31 | targetServer = startTargetServer() 32 | code := m.Run() // os.Exit skips deferred calls, so the target server is closed explicitly below 33 | targetServer.CloseClientConnections() 34 | targetServer.Close() 35 | os.Exit(code) 36 | } 37 | 38 | func TestServer(t *testing.T) { 39 | viewProxyServer := newServer(t, targetServer.URL) 40 | viewProxyServer.Addr = "localhost:9997" 41 | viewProxyServer.Logger = log.New(ioutil.Discard, "", log.Ldate|log.Ltime) 42 | 43 | viewProxyServer.AroundResponse = func(next http.Handler) http.Handler { 44 | return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { 45 | rw.Header().Del("etag") 46 | next.ServeHTTP(rw, r) 47 | }) 48 | } 49 | 50 | // layout is shared and has no :name fragment 51 | root := fragment.Define( 52 | "/layouts/test_layout", fragment.WithoutValidation(), 53 | fragment.WithChild("header", fragment.Define("/header/:name")), 54 | fragment.WithChild("body", fragment.Define("/body/:name")), 55 | fragment.WithChild("footer", fragment.Define("/footer/:name")), 56 | ) 57 | 58 |
err := viewProxyServer.Get("/hello/:name", root) 59 | require.NoError(t, err) 60 | viewProxyServer.Logger = log.New(os.Stdout, "", log.Ldate|log.Ltime) 61 | 62 | go func() { 63 | if err := viewProxyServer.ListenAndServe(); err != nil && err != http.ErrServerClosed { 64 | panic(err) 65 | } 66 | }() 67 | 68 | resp, err := http.Get(fmt.Sprintf("http://localhost:9997%s", "/hello/world")) 69 | require.NoError(t, err) 70 | body, err := ioutil.ReadAll(resp.Body) 71 | require.NoError(t, err) 72 | 73 | require.Equal(t, "hello world", string(body)) 74 | } 75 | 76 | func TestServerRoot(t *testing.T) { 77 | instance := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 78 | w.WriteHeader(500) 79 | w.Write([]byte("Internal server error")) 80 | }) 81 | testServer := httptest.NewServer(instance) 82 | 83 | viewProxyServer := newServer(t, testServer.URL) 84 | viewProxyServer.Addr = "localhost:9997" 85 | viewProxyServer.Logger = log.New(ioutil.Discard, "", log.Ldate|log.Ltime) 86 | 87 | root := fragment.Define("/") 88 | 89 | err := viewProxyServer.Get("/", root) 90 | require.NoError(t, err) 91 | viewProxyServer.Logger = log.New(os.Stdout, "", log.Ldate|log.Ltime) 92 | 93 | r := httptest.NewRequest("GET", "/", nil) 94 | w := httptest.NewRecorder() 95 | 96 | viewProxyServer.CreateHandler().ServeHTTP(w, r) 97 | 98 | resp := w.Result() 99 | 100 | require.Equal(t, 500, resp.StatusCode) 101 | } 102 | 103 | func TestQueryParamForwardingServer(t *testing.T) { 104 | viewProxyServer := newServer(t, targetServer.URL) 105 | viewProxyServer.Logger = log.New(ioutil.Discard, "", log.Ldate|log.Ltime) 106 | 107 | root := fragment.Define("/layouts/test_layout", 108 | fragment.WithoutValidation(), 109 | fragment.WithChild("header", fragment.Define("/header/:name")), 110 | fragment.WithChild("body", fragment.Define("/body/:name")), 111 | fragment.WithChild("footer", fragment.Define("/footer/:name")), 112 | ) 113 | viewProxyServer.Get("/hello/:name", root) 114 | 115 | r := 
httptest.NewRequest("GET", "/hello/world?important=true&name=override", nil) 116 | w := httptest.NewRecorder() 117 | 118 | viewProxyServer.CreateHandler().ServeHTTP(w, r) 119 | 120 | resp := w.Result() 121 | 122 | body, err := ioutil.ReadAll(resp.Body) 123 | require.Nil(t, err) 124 | expected := "hello world!" 125 | 126 | require.Equal(t, expected, string(body)) 127 | require.Equal(t, "viewproxy", resp.Header.Get("x-name"), "Expected response to have an X-Name header") 128 | } 129 | 130 | func TestServer_EscapedNamedFragments(t *testing.T) { 131 | viewProxyServer := newServer(t, targetServer.URL) 132 | 133 | root := fragment.Define("/layouts/test_layout", 134 | fragment.WithoutValidation(), 135 | fragment.WithChild("header", fragment.Define("/header/:name")), 136 | fragment.WithChild("body", fragment.Define("/body/:name")), 137 | fragment.WithChild("footer", fragment.Define("/footer/:name")), 138 | ) 139 | err := viewProxyServer.Get("/hello/:name", root) 140 | require.NoError(t, err) 141 | 142 | r := httptest.NewRequest("GET", "/hello/world%2fvoltron", nil) 143 | w := httptest.NewRecorder() 144 | 145 | viewProxyServer.CreateHandler().ServeHTTP(w, r) 146 | 147 | resp := w.Result() 148 | 149 | body, err := ioutil.ReadAll(resp.Body) 150 | require.Nil(t, err) 151 | expected := "hello world/voltron" 152 | 153 | require.Equal(t, expected, string(body)) 154 | } 155 | 156 | func TestPassThroughEnabled(t *testing.T) { 157 | viewProxyServer := newServer(t, targetServer.URL, WithPassThrough(targetServer.URL)) 158 | viewProxyServer.Logger = log.New(ioutil.Discard, "", log.Ldate|log.Ltime) 159 | 160 | r := httptest.NewRequest("GET", "/oops", nil) 161 | w := httptest.NewRecorder() 162 | 163 | viewProxyServer.CreateHandler().ServeHTTP(w, r) 164 | 165 | resp := w.Result() 166 | body, err := ioutil.ReadAll(resp.Body) 167 | require.Nil(t, err) 168 | 169 | require.Equal(t, 500, resp.StatusCode) 170 | require.Equal(t, "Something went wrong", string(body)) 171 | } 172 | 173 | func 
TestPassThroughDisabled(t *testing.T) { 174 | viewProxyServer := newServer(t, targetServer.URL) 175 | 176 | r := httptest.NewRequest("GET", "/hello/world", nil) 177 | w := httptest.NewRecorder() 178 | 179 | viewProxyServer.CreateHandler().ServeHTTP(w, r) 180 | 181 | resp := w.Result() 182 | body, err := ioutil.ReadAll(resp.Body) 183 | require.Nil(t, err) 184 | 185 | require.Equal(t, 404, resp.StatusCode) 186 | require.Equal(t, "404 not found", string(body)) 187 | } 188 | 189 | func TestPassThroughPostRequest(t *testing.T) { 190 | ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) 191 | defer cancel() 192 | done := make(chan struct{}) 193 | 194 | server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 195 | defer close(done) 196 | 197 | body, err := io.ReadAll(r.Body) 198 | 199 | require.Nil(t, err) 200 | require.Equal(t, http.MethodPost, r.Method) 201 | require.Equal(t, "hello", string(body)) 202 | })) 203 | 204 | viewProxyServer := newServer(t, server.URL, WithPassThrough(server.URL)) 205 | 206 | r := httptest.NewRequest("POST", "/hello/world", strings.NewReader("hello")) 207 | w := httptest.NewRecorder() 208 | 209 | viewProxyServer.CreateHandler().ServeHTTP(w, r) 210 | 211 | select { 212 | case <-done: 213 | server.Close() 214 | case <-ctx.Done(): 215 | require.Fail(t, ctx.Err().Error()) 216 | } 217 | } 218 | 219 | func TestFragmentSendsVerifiableHmacWhenSet(t *testing.T) { 220 | done := make(chan struct{}) 221 | secret := "6ccd9547b7042e0f1101ce68931d6b2c" 222 | 223 | server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 224 | defer close(done) 225 | 226 | time := r.Header.Get("X-Authorization-Time") 227 | require.NotEqual(t, "", time, "Expected X-Authorization-Time header to be present") 228 | 229 | key := fmt.Sprintf("%s,%s", r.URL.Path, time) 230 | 231 | mac := hmac.New(sha256.New, []byte(secret)) 232 | mac.Write( 233 | []byte(key), 234 | ) 235 | 236 | 
authorization := r.Header.Get("Authorization") 237 | require.NotEqual(t, "", authorization, "Expected Authorization header to be present") 238 | 239 | expected := hex.EncodeToString(mac.Sum(nil)) 240 | 241 | require.Equal(t, expected, authorization) 242 | 243 | w.WriteHeader(http.StatusOK) 244 | })) 245 | 246 | viewProxyServer := newServer(t, server.URL) 247 | err := viewProxyServer.Get("/hello/:name", fragment.Define("/foo/:name")) 248 | require.NoError(t, err) 249 | viewProxyServer.HmacSecret = secret 250 | 251 | r := httptest.NewRequest("GET", "/hello/world", strings.NewReader("hello")) 252 | w := httptest.NewRecorder() 253 | 254 | viewProxyServer.CreateHandler().ServeHTTP(w, r) 255 | 256 | <-done 257 | 258 | server.Close() 259 | } 260 | 261 | func TestSupportsGzip(t *testing.T) { 262 | server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 263 | var b bytes.Buffer 264 | 265 | gzWriter := gzip.NewWriter(&b) 266 | 267 | if strings.HasPrefix(r.URL.Path, "/layout") { 268 | gzWriter.Write([]byte(``)) 269 | } else if strings.HasPrefix(r.URL.Path, "/fragment") { 270 | gzWriter.Write([]byte("wow gzipped!")) 271 | } else { 272 | panic("Unexpected URL") 273 | } 274 | 275 | gzWriter.Close() 276 | 277 | w.Header().Set("Content-Encoding", "gzip") 278 | w.WriteHeader(http.StatusOK) 279 | w.Write(b.Bytes()) 280 | })) 281 | 282 | viewProxyServer := newServer(t, server.URL) 283 | viewProxyServer.Get( 284 | "/hello/:name", 285 | fragment.Define("/layout/:name", fragment.WithChild("fragment", fragment.Define("/fragment/:name"))), 286 | ) 287 | 288 | r := httptest.NewRequest("GET", "/hello/world", nil) 289 | r.Header.Set("Accept-Encoding", "gzip") 290 | w := httptest.NewRecorder() 291 | 292 | viewProxyServer.CreateHandler().ServeHTTP(w, r) 293 | 294 | resp := w.Result() 295 | 296 | gzReader, err := gzip.NewReader(resp.Body) 297 | require.Nil(t, err) 298 | 299 | body, err := ioutil.ReadAll(gzReader) 300 | require.Nil(t, err) 301 | 302 | 
require.Equal(t, "wow gzipped!", string(body)) 303 | 304 | server.Close() 305 | } 306 | 307 | func TestAroundRequestCallback(t *testing.T) { 308 | done := make(chan struct{}) 309 | 310 | server := newServer(t, targetServer.URL) 311 | err := server.Get( 312 | "/hello/:name", 313 | fragment.Define("/layout/:name", fragment.WithChild("fragment", fragment.Define("/fragment/:name"))), 314 | ) 315 | require.NoError(t, err) 316 | server.AroundRequest = func(next http.Handler) http.Handler { 317 | return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 318 | defer close(done) 319 | w.Header().Set("x-viewproxy", "true") 320 | require.Equal(t, "/hello/:name", RouteFromContext(r.Context()).Path) 321 | require.Equal(t, "192.168.1.1", r.RemoteAddr) 322 | next.ServeHTTP(w, r) 323 | }) 324 | } 325 | 326 | w := httptest.NewRecorder() 327 | r := httptest.NewRequest("GET", "/hello/world", nil) 328 | r.RemoteAddr = "192.168.1.1" 329 | 330 | server.CreateHandler().ServeHTTP(w, r) 331 | 332 | resp := w.Result() 333 | 334 | require.Equal(t, "true", resp.Header.Get("x-viewproxy")) 335 | 336 | <-done 337 | } 338 | 339 | func TestErrorHandler(t *testing.T) { 340 | ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) 341 | defer cancel() 342 | done := make(chan struct{}) 343 | 344 | server := newServer(t, targetServer.URL) 345 | err := server.Get( 346 | "/hello/:name", 347 | fragment.Define("/definitely_missing_and_not_defined/:name"), 348 | ) 349 | require.NoError(t, err) 350 | server.AroundRequest = func(next http.Handler) http.Handler { 351 | return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 352 | w.Header().Set("x-viewproxy", "true") 353 | require.Equal(t, "192.168.1.1", r.RemoteAddr) 354 | next.ServeHTTP(w, r) 355 | }) 356 | } 357 | server.AroundResponse = func(h http.Handler) http.Handler { 358 | return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { 359 | rw.Header().Set("error-header", "true") 360 | 361 | 
defer close(done) 362 | 363 | results := multiplexer.ResultsFromContext(r.Context()) 364 | require.NotNil(t, results) 365 | 366 | var resultErr *ResultError 367 | require.ErrorAs(t, results.Error(), &resultErr) 368 | require.Equal( 369 | t, 370 | fmt.Sprintf("%s/definitely_missing_and_not_defined/world", targetServer.URL), 371 | resultErr.Result.Url, 372 | ) 373 | require.Equal(t, 404, resultErr.Result.StatusCode) 374 | }) 375 | } 376 | 377 | fakeWriter := httptest.NewRecorder() 378 | fakeRequest := httptest.NewRequest("GET", "/hello/world", nil) 379 | fakeRequest.RemoteAddr = "192.168.1.1" 380 | 381 | server.CreateHandler().ServeHTTP(fakeWriter, fakeRequest) 382 | 383 | require.Equal(t, "true", fakeWriter.Header().Get("x-viewproxy")) 384 | require.Equal(t, "true", fakeWriter.Header().Get("error-header")) 385 | 386 | select { 387 | case <-done: 388 | case <-ctx.Done(): 389 | require.Fail(t, ctx.Err().Error()) 390 | } 391 | } 392 | 393 | type contextTestTripper struct { 394 | route *Route 395 | requestables []multiplexer.Requestable 396 | mu sync.Mutex 397 | } 398 | 399 | func (t *contextTestTripper) Request(r *http.Request) (*http.Response, error) { 400 | t.mu.Lock() 401 | defer t.mu.Unlock() 402 | 403 | t.route = RouteFromContext(r.Context()) 404 | t.requestables = append(t.requestables, multiplexer.RequestableFromContext(r.Context())) 405 | return http.DefaultClient.Do(r) 406 | } 407 | 408 | func TestRoundTripperContext(t *testing.T) { 409 | viewProxyServer, err := NewServer(targetServer.URL) 410 | require.NoError(t, err) 411 | viewProxyServer.Logger = log.New(ioutil.Discard, "", log.Ldate|log.Ltime) 412 | tripper := &contextTestTripper{} 413 | viewProxyServer.MultiplexerTripper = tripper 414 | 415 | root := fragment.Define( 416 | "/layouts/test_layout", fragment.WithoutValidation(), 417 | fragment.WithChild("header", fragment.Define("/header/:name")), 418 | fragment.WithChild("body", fragment.Define("/body/:name")), 419 | fragment.WithChild("footer", 
fragment.Define("/footer/:name")), 420 | ) 421 | 422 | err = viewProxyServer.Get("/hello/:name", root) 423 | require.NoError(t, err) 424 | 425 | r := httptest.NewRequest("GET", "/hello/world?important=true&name=override", nil) 426 | w := httptest.NewRecorder() 427 | 428 | viewProxyServer.CreateHandler().ServeHTTP(w, r) 429 | 430 | resp := w.Result() 431 | 432 | require.Equal(t, 200, resp.StatusCode) 433 | require.Equal(t, 4, len(tripper.requestables)) 434 | require.NotNil(t, tripper.route) 435 | } 436 | 437 | func TestIgnoreTrailingSlash(t *testing.T) { 438 | viewProxyServer, err := NewServer(targetServer.URL) 439 | require.NoError(t, err) 440 | viewProxyServer.Logger = log.New(ioutil.Discard, "", log.Ldate|log.Ltime) 441 | viewProxyServer.IgnoreTrailingSlash = true 442 | 443 | root := fragment.Define( 444 | "/layouts/test_layout", fragment.WithoutValidation(), 445 | fragment.WithChild("header", fragment.Define("/header/:name")), 446 | fragment.WithChild("body", fragment.Define("/body/:name")), 447 | fragment.WithChild("footer", fragment.Define("/footer/:name")), 448 | ) 449 | 450 | err = viewProxyServer.Get("/hello/:name", root) 451 | require.NoError(t, err) 452 | 453 | r := httptest.NewRequest("GET", "/hello/world/?important=true&name=override", nil) 454 | w := httptest.NewRecorder() 455 | 456 | viewProxyServer.CreateHandler().ServeHTTP(w, r) 457 | resp := w.Result() 458 | require.Equal(t, 200, resp.StatusCode) 459 | 460 | r = httptest.NewRequest("GET", "/hello/world/?important=true&name=override", nil) 461 | w = httptest.NewRecorder() 462 | 463 | viewProxyServer.IgnoreTrailingSlash = false 464 | viewProxyServer.CreateHandler().ServeHTTP(w, r) 465 | resp = w.Result() 466 | require.Equal(t, 404, resp.StatusCode) 467 | } 468 | 469 | func TestWithPassThrough_Error(t *testing.T) { 470 | _, err := NewServer(targetServer.URL, WithPassThrough("%invalid%")) 471 | 472 | require.Error(t, err) 473 | require.Contains(t, err.Error(), "viewproxy.ServerOption error") 474 | 
require.Contains(t, err.Error(), "WithPassThrough error") 475 | } 476 | 477 | func BenchmarkServer(b *testing.B) { 478 | viewProxyServer := newServer(b, targetServer.URL) 479 | viewProxyServer.Addr = "localhost:9997" 480 | viewProxyServer.Logger = log.New(ioutil.Discard, "", log.Ldate|log.Ltime) 481 | 482 | viewProxyServer.AroundResponse = func(next http.Handler) http.Handler { 483 | return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { 484 | rw.Header().Del("etag") 485 | next.ServeHTTP(rw, r) 486 | }) 487 | } 488 | 489 | root := fragment.Define( 490 | "/layouts/test_layout", fragment.WithoutValidation(), fragment.WithChildren(fragment.Children{ 491 | "header": fragment.Define("/header/:name"), 492 | "body": fragment.Define("/body/:name"), 493 | "name": fragment.Define("/footer/:name"), 494 | }), 495 | ) 496 | viewProxyServer.Get("/hello/:name", root) 497 | 498 | b.ResetTimer() 499 | 500 | for i := 0; i < b.N; i++ { 501 | r := httptest.NewRequest("GET", "/hello/world", nil) 502 | w := httptest.NewRecorder() 503 | 504 | viewProxyServer.CreateHandler().ServeHTTP(w, r) 505 | } 506 | } 507 | 508 | func startTargetServer() *httptest.Server { 509 | instance := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 510 | parts := strings.Split(r.URL.EscapedPath(), "/") 511 | name, err := url.PathUnescape(parts[len(parts)-1]) 512 | 513 | w.Header().Set("EtAg", "1234") 514 | w.Header().Set("X-Name", "viewproxy") 515 | 516 | if err != nil { 517 | w.WriteHeader(http.StatusInternalServerError) 518 | w.Write([]byte(err.Error())) 519 | } 520 | 521 | if r.URL.Path == "/layouts/test_layout" { 522 | w.WriteHeader(http.StatusOK) 523 | w.Write([]byte(``)) 524 | } else if strings.HasPrefix(r.URL.Path, "/header/") { 525 | w.WriteHeader(http.StatusOK) 526 | w.Write([]byte("")) 527 | } else if strings.HasPrefix(r.URL.Path, "/body/") { 528 | w.WriteHeader(http.StatusOK) 529 | w.Write([]byte(fmt.Sprintf("hello %s", name))) 530 | if r.URL.Query().Get("important") 
!= "" { 531 | w.Write([]byte("!")) 532 | } 533 | } else if strings.HasPrefix(r.URL.Path, "/footer/") { 534 | w.WriteHeader(http.StatusOK) 535 | w.Write([]byte("")) 536 | } else if r.URL.Path == "/oops" { 537 | w.WriteHeader(http.StatusInternalServerError) 538 | w.Write([]byte("Something went wrong")) 539 | } else { 540 | w.WriteHeader(http.StatusNotFound) 541 | w.Write([]byte("target: 404 not found")) 542 | } 543 | }) 544 | 545 | testServer := httptest.NewServer(instance) 546 | return testServer 547 | } 548 | 549 | func newServer(tb testing.TB, target string, opts ...ServerOption) *Server { 550 | server, err := NewServer(target, opts...) 551 | require.NoError(tb, err) 552 | 553 | return server 554 | } 555 | -------------------------------------------------------------------------------- /stitch_structure.go: -------------------------------------------------------------------------------- 1 | package viewproxy 2 | 3 | import "github.com/blakewilliams/viewproxy/pkg/fragment" 4 | 5 | type stitchStructure struct { 6 | key string 7 | replacementID string 8 | dependentStructures []*stitchStructure 9 | } 10 | 11 | func (s *stitchStructure) Key() string { 12 | return s.key 13 | } 14 | 15 | func (s *stitchStructure) ReplacementID() string { 16 | return s.replacementID 17 | } 18 | 19 | func (s *stitchStructure) DependentStructures() []*stitchStructure { 20 | return s.dependentStructures 21 | } 22 | 23 | func stitchStructureFor(d *fragment.Definition) *stitchStructure { 24 | structure := &stitchStructure{key: "root"} 25 | 26 | for name, child := range d.Children() { 27 | structure.dependentStructures = append(structure.dependentStructures, childStitchStructure("root", name, child)) 28 | } 29 | 30 | return structure 31 | } 32 | 33 | func childStitchStructure(prefix string, name string, d *fragment.Definition) *stitchStructure { 34 | key := prefix + "." 
+ name 35 | buildInfo := &stitchStructure{key: key, replacementID: name} 36 | 37 | for name, child := range d.Children() { 38 | buildInfo.dependentStructures = append(buildInfo.dependentStructures, childStitchStructure(key, name, child)) 39 | } 40 | 41 | return buildInfo 42 | } 43 | -------------------------------------------------------------------------------- /stitch_structure_test.go: -------------------------------------------------------------------------------- 1 | package viewproxy 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/blakewilliams/viewproxy/pkg/fragment" 7 | "github.com/stretchr/testify/require" 8 | ) 9 | 10 | func TestStitchStructure(t *testing.T) { 11 | rootFragment := fragment.Define("layout", fragment.WithChildren(fragment.Children{ 12 | "header": fragment.Define("header"), 13 | "body": fragment.Define("body", fragment.WithChildren(fragment.Children{ 14 | "main": fragment.Define("main"), 15 | "sidebar": fragment.Define("sidebar"), 16 | })), 17 | })) 18 | 19 | structure := stitchStructureFor(rootFragment) 20 | 21 | var headerStructure *stitchStructure 22 | var bodyStructure *stitchStructure 23 | 24 | // Maps are not ordered, so we need to find the correct fragment order here 25 | if structure.DependentStructures()[0].Key() == "root.header" { 26 | headerStructure = structure.DependentStructures()[0] 27 | bodyStructure = structure.DependentStructures()[1] 28 | } else { 29 | headerStructure = structure.DependentStructures()[1] 30 | bodyStructure = structure.DependentStructures()[0] 31 | } 32 | 33 | require.Equal(t, "root", structure.Key()) 34 | 35 | require.Equal(t, "root.header", headerStructure.Key()) 36 | require.Equal(t, "header", headerStructure.ReplacementID()) 37 | 38 | require.Equal(t, "root.body", bodyStructure.Key()) 39 | require.Equal(t, "body", bodyStructure.ReplacementID()) 40 | 41 | var bodyKeys []string 42 | var bodyReplacementIDs []string 43 | for _, structure := range bodyStructure.DependentStructures() { 44 | bodyKeys = 
append(bodyKeys, structure.Key()) 45 | bodyReplacementIDs = append(bodyReplacementIDs, structure.ReplacementID()) 46 | } 47 | 48 | require.Contains(t, bodyKeys, "root.body.main") 49 | require.Contains(t, bodyKeys, "root.body.sidebar") 50 | 51 | require.Contains(t, bodyReplacementIDs, "main") 52 | require.Contains(t, bodyReplacementIDs, "sidebar") 53 | } 54 | --------------------------------------------------------------------------------