├── .github └── workflows │ └── test.yml ├── History.md ├── LICENSE ├── Readme.md ├── ci.yml ├── context.go ├── gateway.go ├── gateway_test.go ├── go.mod ├── go.sum ├── request.go ├── request_test.go ├── response.go ├── response_test.go └── v2 ├── context.go ├── gateway.go ├── gateway_test.go ├── go.mod ├── go.sum ├── request.go ├── request_test.go ├── response.go └── response_test.go /.github/workflows/test.yml: -------------------------------------------------------------------------------- 1 | on: [push, pull_request] 2 | name: Tests 3 | jobs: 4 | test: 5 | strategy: 6 | matrix: 7 | go-version: [1.13.x] 8 | platform: [ubuntu-latest, macos-latest, windows-latest] 9 | runs-on: ${{ matrix.platform }} 10 | steps: 11 | - name: Install Go 12 | uses: actions/setup-go@v1 13 | with: 14 | go-version: ${{ matrix.go-version }} 15 | - name: Checkout code 16 | uses: actions/checkout@v1 17 | 18 | - name: Test v1 19 | run: go test -v -cover ./... 20 | 21 | - name: Test v2 22 | run: go test -v -cover ./... 
23 | working-directory: v2 24 | -------------------------------------------------------------------------------- /History.md: -------------------------------------------------------------------------------- 1 | 2 | v1.1.1 / 2018-08-17 3 | =================== 4 | 5 | * fix passing of request context 6 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License 2 | 3 | Copyright (c) 2018 TJ Holowaychuk tj@tjholowaychuk.com 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining 6 | a copy of this software and associated documentation files (the 7 | 'Software'), to deal in the Software without restriction, including 8 | without limitation the rights to use, copy, modify, merge, publish, 9 | distribute, sublicense, and/or sell copies of the Software, and to 10 | permit persons to whom the Software is furnished to do so, subject to 11 | the following conditions: 12 | 13 | The above copyright notice and this permission notice shall be 14 | included in all copies or substantial portions of the Software. 15 | 16 | THE SOFTWARE IS PROVIDED 'AS IS', WITHOUT WARRANTY OF ANY KIND, 17 | EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF 18 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 19 | IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY 20 | CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, 21 | TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE 22 | SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
23 | 24 | -------------------------------------------------------------------------------- /Readme.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | Package gateway provides a drop-in replacement for net/http's `ListenAndServe` for use in [AWS Lambda](https://aws.amazon.com/lambda/) & [API Gateway](https://aws.amazon.com/api-gateway/); simply swap it out for `gateway.ListenAndServe`. Extracted from [Up](https://github.com/apex/up), which provides additional middleware features and operational functionality. 4 | 5 | There are two versions of this library: version 1.x supports AWS API Gateway 1.0 events used by the original [REST APIs](https://docs.aws.amazon.com/apigateway/latest/developerguide/apigateway-rest-api.html), and version 2.x supports 2.0 events used by the [HTTP APIs](https://docs.aws.amazon.com/apigateway/latest/developerguide/http-api.html). For help choosing between them, read [Choosing between HTTP APIs and REST APIs](https://docs.aws.amazon.com/apigateway/latest/developerguide/http-api-vs-rest.html) on the AWS documentation website. 6 | 7 | # Installation 8 | 9 | To install version 1.x for REST APIs: 10 | 11 | ``` 12 | go get github.com/apex/gateway 13 | ``` 14 | 15 | To install version 2.x for HTTP APIs: 16 | 17 | ``` 18 | go get github.com/apex/gateway/v2 19 | ``` 20 | 21 | # Example 22 | 23 | ```go 24 | package main 25 | 26 | import ( 27 | "fmt" 28 | "log" 29 | "net/http" 30 | 31 | "github.com/apex/gateway" 32 | ) 33 | 34 | func main() { 35 | http.HandleFunc("/", hello) 36 | log.Fatal(gateway.ListenAndServe(":3000", nil)) 37 | } 38 | 39 | func hello(w http.ResponseWriter, r *http.Request) { 40 | // example retrieving values from the api gateway proxy request context.
41 | requestContext, ok := gateway.RequestContext(r.Context()) 42 | if !ok || requestContext.Authorizer["sub"] == nil { 43 | fmt.Fprint(w, "Hello World from Go") 44 | return 45 | } 46 | 47 | userID := requestContext.Authorizer["sub"].(string) 48 | fmt.Fprintf(w, "Hello %s from Go", userID) 49 | } 50 | ``` 51 | 52 | --- 53 | 54 | [![GoDoc](https://godoc.org/github.com/apex/gateway?status.svg)](https://godoc.org/github.com/apex/gateway) 55 | ![](https://img.shields.io/badge/license-MIT-blue.svg) 56 | ![](https://img.shields.io/badge/status-stable-green.svg) 57 | 58 | 59 | -------------------------------------------------------------------------------- /ci.yml: -------------------------------------------------------------------------------- 1 | version: 0.2 2 | 3 | phases: 4 | install: 5 | commands: 6 | - go get -t ./... 7 | build: 8 | commands: 9 | - go test -cover -v ./... 10 | -------------------------------------------------------------------------------- /context.go: -------------------------------------------------------------------------------- 1 | package gateway 2 | 3 | import ( 4 | "context" 5 | 6 | "github.com/aws/aws-lambda-go/events" 7 | ) 8 | 9 | // key is the type used for any items added to the request context. 10 | type key int 11 | 12 | // requestContextKey is the key for the api gateway proxy `RequestContext`. 13 | const requestContextKey key = iota 14 | 15 | // newContext returns a new Context with specific api gateway proxy values. 16 | func newContext(ctx context.Context, e events.APIGatewayProxyRequest) context.Context { 17 | return context.WithValue(ctx, requestContextKey, e.RequestContext) 18 | } 19 | 20 | // RequestContext returns the APIGatewayProxyRequestContext value stored in ctx.
21 | func RequestContext(ctx context.Context) (events.APIGatewayProxyRequestContext, bool) { 22 | c, ok := ctx.Value(requestContextKey).(events.APIGatewayProxyRequestContext) 23 | return c, ok 24 | } 25 | -------------------------------------------------------------------------------- /gateway.go: -------------------------------------------------------------------------------- 1 | // Package gateway provides a drop-in replacement for net/http.ListenAndServe for use in AWS Lambda & API Gateway. 2 | package gateway 3 | 4 | import ( 5 | "context" 6 | "encoding/json" 7 | "net/http" 8 | 9 | "github.com/aws/aws-lambda-go/events" 10 | "github.com/aws/aws-lambda-go/lambda" 11 | ) 12 | 13 | // ListenAndServe is a drop-in replacement for 14 | // http.ListenAndServe for use within AWS Lambda; addr is ignored and a nil h means http.DefaultServeMux. 15 | // 16 | // Unlike http.ListenAndServe it returns nil, and only if lambda.StartHandler returns, which StartHandler is not expected to do. 17 | func ListenAndServe(addr string, h http.Handler) error { 18 | if h == nil { 19 | h = http.DefaultServeMux 20 | } 21 | 22 | gw := NewGateway(h) 23 | 24 | lambda.StartHandler(gw) 25 | // not expected to be reached: lambda.StartHandler blocks serving invocations 26 | return nil 27 | } 28 | 29 | // NewGateway creates a gateway using the provided http.Handler enabling use in existing aws-lambda-go 30 | // projects. 31 | func NewGateway(h http.Handler) *Gateway { 32 | return &Gateway{h: h} 33 | } 34 | 35 | // Gateway wraps an http.Handler to enable use as a lambda.Handler. 36 | type Gateway struct { 37 | h http.Handler 38 | } 39 | 40 | // Invoke implements lambda.Handler: it decodes an API Gateway proxy event, serves it through the wrapped handler and returns the JSON-encoded response. 41 | func (gw *Gateway) Invoke(ctx context.Context, payload []byte) ([]byte, error) { 42 | evt := events.APIGatewayProxyRequest{} 43 | 44 | if err := json.Unmarshal(payload, &evt); err != nil { 45 | return nil, err 46 | } 47 | 48 | r, err := NewRequest(ctx, evt) 49 | if err != nil { 50 | return nil, err 51 | } 52 | 53 | w := NewResponse() 54 | gw.h.ServeHTTP(w, r) 55 | 56 | resp := w.End() 57 | 58 | return json.Marshal(&resp) 59 | } 60 | -------------------------------------------------------------------------------- /gateway_test.go:
-------------------------------------------------------------------------------- 1 | package gateway_test 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "log" 7 | "net/http" 8 | "testing" 9 | 10 | "github.com/apex/gateway" 11 | "github.com/tj/assert" 12 | ) 13 | 14 | func Example() { 15 | http.HandleFunc("/", hello) 16 | log.Fatal(gateway.ListenAndServe(":3000", nil)) 17 | } 18 | 19 | func hello(w http.ResponseWriter, r *http.Request) { 20 | fmt.Fprintln(w, "Hello World from Go") 21 | } 22 | 23 | func TestGateway_Invoke(t *testing.T) { 24 | 25 | e := []byte(`{"version": "1.0", "rawPath": "/pets/luna", "requestContext": {"http": {"method": "POST"}}}`) 26 | 27 | gw := gateway.NewGateway(http.HandlerFunc(hello)) 28 | 29 | payload, err := gw.Invoke(context.Background(), e) 30 | assert.NoError(t, err) 31 | assert.JSONEq(t, `{"body":"Hello World from Go\n", "headers":{"Content-Type":"text/plain; charset=utf8"}, "multiValueHeaders":{}, "statusCode":200}`, string(payload)) 32 | } 33 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/apex/gateway 2 | 3 | go 1.12 4 | 5 | require ( 6 | github.com/aws/aws-lambda-go v1.17.0 7 | github.com/pkg/errors v0.9.1 8 | github.com/tj/assert v0.0.3 9 | ) 10 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= 2 | github.com/aws/aws-lambda-go v1.17.0 h1:Ogihmi8BnpmCNktKAGpNwSiILNNING1MiosnKUfU8m0= 3 | github.com/aws/aws-lambda-go v1.17.0/go.mod h1:FEwgPLE6+8wcGBTe5cJN3JWurd1Ztm9zN4jsXsjzKKw= 4 | github.com/cpuguy83/go-md2man/v2 v2.0.0-20190314233015-f79a8a8ca69d/go.mod h1:maD7wRr/U5Z6m/iR4s+kqSMx2CaBsrgA7czyZG/E6dU= 5 | github.com/davecgh/go-spew v1.1.0 
h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= 6 | github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 7 | github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= 8 | github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 9 | github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= 10 | github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= 11 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= 12 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 13 | github.com/russross/blackfriday/v2 v2.0.1/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= 14 | github.com/shurcooL/sanitized_anchor_name v1.0.0/go.mod h1:1NzhyTcUVG4SuEtjjoZeVRXNmyL/1OwPU0+IJeTBvfc= 15 | github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= 16 | github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= 17 | github.com/stretchr/testify v1.6.0 h1:jlIyCplCJFULU/01vCkhKuTyc3OorI3bJFuw6obfgho= 18 | github.com/stretchr/testify v1.6.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= 19 | github.com/stretchr/testify v1.6.1 h1:hDPOHmpOpP40lSULcqw7IrRb/u7w6RpDC9399XyoNd0= 20 | github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= 21 | github.com/tj/assert v0.0.1 h1:T7ozLNagrCCKl3wc+a706ztUCn/D6WHCJtkyvqYG+kQ= 22 | github.com/tj/assert v0.0.1/go.mod h1:lsg+GHQ0XplTcWKGxFLf/XPcPxWO8x2ut5jminoR2rA= 23 | github.com/tj/assert v0.0.3 h1:Df/BlaZ20mq6kuai7f5z2TvPFiwC3xaWJSDQNiIS3Rk= 24 | github.com/tj/assert v0.0.3/go.mod h1:Ne6X72Q+TB1AteidzQncjw9PabbMp4PBMZ1k+vd1Pvk= 25 | github.com/urfave/cli/v2 v2.1.1/go.mod h1:SE9GqnLQmjVa0iPEY0f1w3ygNIYcIJ0OKPMoW2caLfQ= 26 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= 27 | gopkg.in/check.v1 
v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= 28 | gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= 29 | gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo= 30 | gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 31 | gopkg.in/yaml.v3 v3.0.0-20200605160147-a5ece683394c h1:grhR+C34yXImVGp7EzNk+DTIk+323eIUWOmEevy6bDo= 32 | gopkg.in/yaml.v3 v3.0.0-20200605160147-a5ece683394c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 33 | -------------------------------------------------------------------------------- /request.go: -------------------------------------------------------------------------------- 1 | package gateway 2 | 3 | import ( 4 | "context" 5 | "encoding/base64" 6 | "fmt" 7 | "net/http" 8 | "net/url" 9 | "strconv" 10 | "strings" 11 | 12 | "github.com/aws/aws-lambda-go/events" 13 | "github.com/pkg/errors" 14 | ) 15 | 16 | // NewRequest returns a new http.Request from the given Lambda event. 
17 | func NewRequest(ctx context.Context, e events.APIGatewayProxyRequest) (*http.Request, error) { 18 | // path 19 | u, err := url.Parse(e.Path) 20 | if err != nil { 21 | return nil, errors.Wrap(err, "parsing path") 22 | } 23 | 24 | // querystring 25 | q := u.Query() 26 | for k, v := range e.QueryStringParameters { 27 | q.Set(k, v) 28 | } 29 | // multi-value parameters replace any single-value parameter with the same key 30 | for k, values := range e.MultiValueQueryStringParameters { 31 | q[k] = values 32 | } 33 | u.RawQuery = q.Encode() 34 | 35 | // base64 encoded body 36 | body := e.Body 37 | if e.IsBase64Encoded { 38 | b, err := base64.StdEncoding.DecodeString(body) 39 | if err != nil { 40 | return nil, errors.Wrap(err, "decoding base64 body") 41 | } 42 | body = string(b) 43 | } 44 | 45 | // new request 46 | req, err := http.NewRequest(e.HTTPMethod, u.String(), strings.NewReader(body)) 47 | if err != nil { 48 | return nil, errors.Wrap(err, "creating request") 49 | } 50 | 51 | // manually set RequestURI because NewRequest is for clients and req.RequestURI is for servers 52 | req.RequestURI = u.RequestURI() 53 | 54 | // remote addr 55 | req.RemoteAddr = e.RequestContext.Identity.SourceIP 56 | 57 | // header fields 58 | for k, v := range e.Headers { 59 | req.Header.Set(k, v) 60 | } 61 | // multi-value headers. NOTE(review): k is stored as given, not canonicalized like Header.Set does, so a lowercase key is invisible to Header.Get and may coexist with the canonical entry set above 62 | for k, values := range e.MultiValueHeaders { 63 | req.Header[k] = values 64 | } 65 | 66 | // content-length 67 | if req.Header.Get("Content-Length") == "" && body != "" { 68 | req.Header.Set("Content-Length", strconv.Itoa(len(body))) 69 | } 70 | 71 | // custom fields 72 | req.Header.Set("X-Request-Id", e.RequestContext.RequestID) 73 | req.Header.Set("X-Stage", e.RequestContext.Stage) 74 | 75 | // custom context values (read back with RequestContext) 76 | req = req.WithContext(newContext(ctx, e)) 77 | 78 | // xray support: forward the trace ID if the caller's ctx carries one 79 | if traceID := ctx.Value("x-amzn-trace-id"); traceID != nil { 80 | req.Header.Set("X-Amzn-Trace-Id", fmt.Sprintf("%v", traceID)) 81 | } 82 | 83 | // host: copied from the Host header; empty when the event has none 84 | req.URL.Host = req.Header.Get("Host") 85 | req.Host = req.URL.Host 86 | 87 | return req, nil 88 | } 89 | -------------------------------------------------------------------------------- /request_test.go: -------------------------------------------------------------------------------- 
-------------------------------------------------------------------------------- /request_test.go: -------------------------------------------------------------------------------- 1 | package gateway 2 | 3 | import ( 4 | "context" 5 | "io/ioutil" 6 | "testing" 7 | 8 | "github.com/aws/aws-lambda-go/events" 9 | "github.com/tj/assert" 10 | ) 11 | 12 | func TestNewRequest_path(t *testing.T) { 13 | e := events.APIGatewayProxyRequest{ 14 | Path: "/pets/luna", 15 | } 16 | 17 | r, err := NewRequest(context.Background(), e) 18 | assert.NoError(t, err) 19 | 20 | assert.Equal(t, "GET", r.Method) 21 | assert.Equal(t, `/pets/luna`, r.URL.Path) 22 | assert.Equal(t, `/pets/luna`, r.URL.String()) 23 | assert.Equal(t, `/pets/luna`, r.RequestURI) 24 | } 25 | 26 | func TestNewRequest_method(t *testing.T) { 27 | e := events.APIGatewayProxyRequest{ 28 | HTTPMethod: "DELETE", 29 | Path: "/pets/luna", 30 | } 31 | 32 | r, err := NewRequest(context.Background(), e) 33 | assert.NoError(t, err) 34 | 35 | assert.Equal(t, "DELETE", r.Method) 36 | } 37 | 38 | func TestNewRequest_queryString(t *testing.T) { 39 | e := events.APIGatewayProxyRequest{ 40 | HTTPMethod: "GET", 41 | Path: "/pets", 42 | QueryStringParameters: map[string]string{ 43 | "order": "desc", 44 | "fields": "name,species", 45 | }, 46 | } 47 | 48 | r, err := NewRequest(context.Background(), e) 49 | assert.NoError(t, err) 50 | 51 | assert.Equal(t, `/pets?fields=name%2Cspecies&order=desc`, r.URL.String()) 52 | assert.Equal(t, `desc`, r.URL.Query().Get("order")) 53 | } 54 | 55 | func TestNewRequest_multiValueQueryString(t *testing.T) { 56 | e := events.APIGatewayProxyRequest{ 57 | HTTPMethod: "GET", 58 | Path: "/pets", 59 | MultiValueQueryStringParameters: map[string][]string{ 60 | "multi_fields": []string{"name", "species"}, 61 | "multi_arr[]": []string{"arr1", "arr2"}, 62 | }, 63 | QueryStringParameters: map[string]string{ 64 | "order": "desc", 65 | "fields": "name,species", 66 | }, 67 | } 68 | 69 | r, err := 
NewRequest(context.Background(), e) 70 | assert.NoError(t, err) 71 | 72 | assert.Equal(t, `/pets?fields=name%2Cspecies&multi_arr%5B%5D=arr1&multi_arr%5B%5D=arr2&multi_fields=name&multi_fields=species&order=desc`, r.URL.String()) 73 | assert.Equal(t, []string{"name", "species"}, r.URL.Query()["multi_fields"]) 74 | assert.Equal(t, []string{"arr1", "arr2"}, r.URL.Query()["multi_arr[]"]) 75 | assert.Equal(t, `/pets?fields=name%2Cspecies&multi_arr%5B%5D=arr1&multi_arr%5B%5D=arr2&multi_fields=name&multi_fields=species&order=desc`, r.RequestURI) 76 | } 77 | 78 | func TestNewRequest_remoteAddr(t *testing.T) { 79 | e := events.APIGatewayProxyRequest{ 80 | HTTPMethod: "GET", 81 | Path: "/pets", 82 | RequestContext: events.APIGatewayProxyRequestContext{ 83 | Identity: events.APIGatewayRequestIdentity{ 84 | SourceIP: "1.2.3.4", 85 | }, 86 | }, 87 | } 88 | 89 | r, err := NewRequest(context.Background(), e) 90 | assert.NoError(t, err) 91 | 92 | assert.Equal(t, `1.2.3.4`, r.RemoteAddr) 93 | } 94 | 95 | func TestNewRequest_header(t *testing.T) { 96 | e := events.APIGatewayProxyRequest{ 97 | HTTPMethod: "POST", 98 | Path: "/pets", 99 | Body: `{ "name": "Tobi" }`, 100 | Headers: map[string]string{ 101 | "Content-Type": "application/json", 102 | "X-Foo": "bar", 103 | "Host": "example.com", 104 | }, 105 | RequestContext: events.APIGatewayProxyRequestContext{ 106 | RequestID: "1234", 107 | Stage: "prod", 108 | }, 109 | } 110 | 111 | r, err := NewRequest(context.Background(), e) 112 | assert.NoError(t, err) 113 | 114 | assert.Equal(t, `example.com`, r.Host) 115 | assert.Equal(t, `prod`, r.Header.Get("X-Stage")) 116 | assert.Equal(t, `1234`, r.Header.Get("X-Request-Id")) 117 | assert.Equal(t, `18`, r.Header.Get("Content-Length")) 118 | assert.Equal(t, `application/json`, r.Header.Get("Content-Type")) 119 | assert.Equal(t, `bar`, r.Header.Get("X-Foo")) 120 | } 121 | 122 | func TestNewRequest_multiHeader(t *testing.T) { 123 | e := events.APIGatewayProxyRequest{ 124 | HTTPMethod: "POST", 
125 | Path: "/pets", 126 | Body: `{ "name": "Tobi" }`, 127 | MultiValueHeaders: map[string][]string{ 128 | "X-APEX": []string{"apex1", "apex2"}, 129 | "X-APEX-2": []string{"apex-1", "apex-2"}, 130 | }, 131 | Headers: map[string]string{ 132 | "Content-Type": "application/json", 133 | "X-Foo": "bar", 134 | "Host": "example.com", 135 | }, 136 | RequestContext: events.APIGatewayProxyRequestContext{ 137 | RequestID: "1234", 138 | Stage: "prod", 139 | }, 140 | } 141 | 142 | r, err := NewRequest(context.Background(), e) 143 | assert.NoError(t, err) 144 | 145 | assert.Equal(t, `example.com`, r.Host) 146 | assert.Equal(t, `prod`, r.Header.Get("X-Stage")) 147 | assert.Equal(t, `1234`, r.Header.Get("X-Request-Id")) 148 | assert.Equal(t, `18`, r.Header.Get("Content-Length")) 149 | assert.Equal(t, `application/json`, r.Header.Get("Content-Type")) 150 | assert.Equal(t, `bar`, r.Header.Get("X-Foo")) 151 | assert.Equal(t, []string{"apex1", "apex2"}, r.Header["X-APEX"]) 152 | assert.Equal(t, []string{"apex-1", "apex-2"}, r.Header["X-APEX-2"]) 153 | } 154 | 155 | func TestNewRequest_body(t *testing.T) { 156 | e := events.APIGatewayProxyRequest{ 157 | HTTPMethod: "POST", 158 | Path: "/pets", 159 | Body: `{ "name": "Tobi" }`, 160 | } 161 | 162 | r, err := NewRequest(context.Background(), e) 163 | assert.NoError(t, err) 164 | 165 | b, err := ioutil.ReadAll(r.Body) 166 | assert.NoError(t, err) 167 | 168 | assert.Equal(t, `{ "name": "Tobi" }`, string(b)) 169 | } 170 | 171 | func TestNewRequest_bodyBinary(t *testing.T) { 172 | e := events.APIGatewayProxyRequest{ 173 | HTTPMethod: "POST", 174 | Path: "/pets", 175 | Body: `aGVsbG8gd29ybGQK`, 176 | IsBase64Encoded: true, 177 | } 178 | 179 | r, err := NewRequest(context.Background(), e) 180 | assert.NoError(t, err) 181 | 182 | b, err := ioutil.ReadAll(r.Body) 183 | assert.NoError(t, err) 184 | 185 | assert.Equal(t, "hello world\n", string(b)) 186 | } 187 | 188 | func TestNewRequest_context(t *testing.T) { 189 | e := 
events.APIGatewayProxyRequest{} 190 | ctx := context.WithValue(context.Background(), "key", "value") 191 | r, err := NewRequest(ctx, e) 192 | assert.NoError(t, err) 193 | v := r.Context().Value("key") 194 | assert.Equal(t, "value", v) 195 | } 196 | -------------------------------------------------------------------------------- /response.go: -------------------------------------------------------------------------------- 1 | package gateway 2 | 3 | import ( 4 | "bytes" 5 | "encoding/base64" 6 | "mime" 7 | "net/http" 8 | "strings" 9 | 10 | "github.com/aws/aws-lambda-go/events" 11 | ) 12 | 13 | // ResponseWriter implements the http.ResponseWriter interface 14 | // in order to support the API Gateway Lambda HTTP "protocol". 15 | type ResponseWriter struct { 16 | out events.APIGatewayProxyResponse 17 | buf bytes.Buffer 18 | header http.Header 19 | wroteHeader bool 20 | closeNotifyCh chan bool 21 | } 22 | 23 | // NewResponse returns a new response writer to capture http output. 24 | func NewResponse() *ResponseWriter { 25 | return &ResponseWriter{ 26 | closeNotifyCh: make(chan bool, 1), 27 | } 28 | } 29 | 30 | // Header implementation. 31 | func (w *ResponseWriter) Header() http.Header { 32 | if w.header == nil { 33 | w.header = make(http.Header) 34 | } 35 | 36 | return w.header 37 | } 38 | 39 | // Write implementation. 40 | func (w *ResponseWriter) Write(b []byte) (int, error) { 41 | if !w.wroteHeader { 42 | w.WriteHeader(http.StatusOK) 43 | } 44 | 45 | return w.buf.Write(b) 46 | } 47 | 48 | // WriteHeader implementation. 
49 | func (w *ResponseWriter) WriteHeader(status int) { 50 | if w.wroteHeader { 51 | return 52 | } 53 | // default the content type when the handler did not set one 54 | if w.Header().Get("Content-Type") == "" { 55 | w.Header().Set("Content-Type", "text/plain; charset=utf8") 56 | } 57 | 58 | w.out.StatusCode = status 59 | // snapshot headers: a single value goes to Headers, several go to MultiValueHeaders; later header changes are not reflected 60 | h := make(map[string]string) 61 | mvh := make(map[string][]string) 62 | 63 | for k, v := range w.Header() { 64 | if len(v) == 1 { 65 | h[k] = v[0] 66 | } else if len(v) > 1 { 67 | mvh[k] = v 68 | } 69 | } 70 | 71 | w.out.Headers = h 72 | w.out.MultiValueHeaders = mvh 73 | w.wroteHeader = true 74 | } 75 | 76 | // CloseNotify returns the channel that receives true when End is called. 77 | func (w *ResponseWriter) CloseNotify() <-chan bool { 78 | return w.closeNotifyCh 79 | } 80 | 81 | // End finalizes the response, defaulting to 200 OK (as net/http does) when the handler wrote nothing; call it at most once, since the close notification is sent on a 1-slot buffered channel. 82 | func (w *ResponseWriter) End() events.APIGatewayProxyResponse { 83 | if !w.wroteHeader { 84 | w.WriteHeader(http.StatusOK) 85 | } 86 | w.out.IsBase64Encoded = isBinary(w.header) 87 | if w.out.IsBase64Encoded { 88 | w.out.Body = base64.StdEncoding.EncodeToString(w.buf.Bytes()) 89 | } else { 90 | w.out.Body = w.buf.String() 91 | } 92 | w.closeNotifyCh <- true // notify end 93 | 94 | return w.out 95 | } 96 | 97 | // isBinary returns true if the response represents binary data: a non-textual content type or a gzip content encoding. 98 | func isBinary(h http.Header) bool { 99 | switch { 100 | case !isTextMime(h.Get("Content-Type")): 101 | return true 102 | case h.Get("Content-Encoding") == "gzip": 103 | return true 104 | default: 105 | return false 106 | } 107 | } 108 | 109 | // isTextMime returns true if the content type represents textual data.
110 | func isTextMime(kind string) bool { 111 | mt, _, err := mime.ParseMediaType(kind) 112 | if err != nil { 113 | return false 114 | } 115 | // ParseMediaType lowercases the media type, so the comparisons below are case-insensitive 116 | if strings.HasPrefix(mt, "text/") { 117 | return true 118 | } 119 | // JSON, XML and JavaScript are textual, as are RFC 6839 "+json"/"+xml" structured-syntax suffix types 120 | switch { 121 | case mt == "application/json", mt == "application/xml", mt == "application/javascript": 122 | return true 123 | default: // e.g. image/svg+xml, application/problem+json 124 | return strings.HasSuffix(mt, "+json") || strings.HasSuffix(mt, "+xml") 125 | } 126 | } 127 | -------------------------------------------------------------------------------- /response_test.go: -------------------------------------------------------------------------------- 1 | package gateway 2 | 3 | import ( 4 | "bytes" 5 | "testing" 6 | 7 | "github.com/tj/assert" 8 | ) 9 | 10 | func Test_JSON_isTextMime(t *testing.T) { 11 | assert.Equal(t, isTextMime("application/json"), true) 12 | assert.Equal(t, isTextMime("application/json; charset=utf-8"), true) 13 | assert.Equal(t, isTextMime("Application/JSON"), true) 14 | } 15 | 16 | func Test_XML_isTextMime(t *testing.T) { 17 | assert.Equal(t, isTextMime("application/xml"), true) 18 | assert.Equal(t, isTextMime("application/xml; charset=utf-8"), true) 19 | assert.Equal(t, isTextMime("ApPlicaTion/xMl"), true) 20 | } 21 | 22 | func TestResponseWriter_Header(t *testing.T) { 23 | w := NewResponse() 24 | w.Header().Set("Foo", "bar") 25 | w.Header().Set("Bar", "baz") 26 | 27 | var buf bytes.Buffer 28 | w.header.Write(&buf) 29 | 30 | assert.Equal(t, "Bar: baz\r\nFoo: bar\r\n", buf.String()) 31 | } 32 | 33 | func TestResponseWriter_multiHeader(t *testing.T) { 34 | w := NewResponse() 35 | w.Header().Set("Foo", "bar") 36 | w.Header().Set("Bar", "baz") 37 | w.Header().Add("X-APEX", "apex1") 38 | w.Header().Add("X-APEX", "apex2") 39 | 40 | var buf bytes.Buffer 41 | w.header.Write(&buf) 42 | 43 | assert.Equal(t, "Bar: baz\r\nFoo: bar\r\nX-Apex: apex1\r\nX-Apex: apex2\r\n", buf.String()) 44 | } 45 | 46 | func TestResponseWriter_Write_text(t *testing.T) { 47 | types := []string{ 48 | "text/x-custom", 49 | "text/plain", 50 |
"text/plain; charset=utf-8", 51 | "application/json", 52 | "application/json; charset=utf-8", 53 | "application/xml", 54 | "image/svg+xml", 55 | } 56 | 57 | for _, kind := range types { 58 | t.Run(kind, func(t *testing.T) { 59 | w := NewResponse() 60 | w.Header().Set("Content-Type", kind) 61 | w.Write([]byte("hello world\n")) 62 | 63 | e := w.End() 64 | assert.Equal(t, 200, e.StatusCode) 65 | assert.Equal(t, "hello world\n", e.Body) 66 | assert.Equal(t, kind, e.Headers["Content-Type"]) 67 | assert.False(t, e.IsBase64Encoded) 68 | assert.True(t, <-w.CloseNotify()) 69 | }) 70 | } 71 | } 72 | 73 | func TestResponseWriter_Write_binary(t *testing.T) { 74 | w := NewResponse() 75 | w.Header().Set("Content-Type", "image/png") 76 | w.Write([]byte("data")) 77 | 78 | e := w.End() 79 | assert.Equal(t, 200, e.StatusCode) 80 | assert.Equal(t, "ZGF0YQ==", e.Body) 81 | assert.Equal(t, "image/png", e.Headers["Content-Type"]) 82 | assert.True(t, e.IsBase64Encoded) 83 | } 84 | 85 | func TestResponseWriter_Write_gzip(t *testing.T) { 86 | w := NewResponse() 87 | w.Header().Set("Content-Type", "text/plain") 88 | w.Header().Set("Content-Encoding", "gzip") 89 | w.Write([]byte("data")) 90 | 91 | e := w.End() 92 | assert.Equal(t, 200, e.StatusCode) 93 | assert.Equal(t, "ZGF0YQ==", e.Body) 94 | assert.Equal(t, "text/plain", e.Headers["Content-Type"]) 95 | assert.True(t, e.IsBase64Encoded) 96 | } 97 | 98 | func TestResponseWriter_WriteHeader(t *testing.T) { 99 | w := NewResponse() 100 | w.WriteHeader(404) 101 | w.Write([]byte("Not Found\n")) 102 | 103 | e := w.End() 104 | assert.Equal(t, 404, e.StatusCode) 105 | assert.Equal(t, "Not Found\n", e.Body) 106 | assert.Equal(t, "text/plain; charset=utf8", e.Headers["Content-Type"]) 107 | } 108 | -------------------------------------------------------------------------------- /v2/context.go: -------------------------------------------------------------------------------- 1 | package gateway 2 | 3 | import ( 4 | "context" 5 | 6 | 
"github.com/aws/aws-lambda-go/events" 7 | ) 8 | 9 | // key is the type used for any items added to the request context. 10 | type key int 11 | 12 | // requestContextKey is the key for the api gateway proxy `RequestContext`. 13 | const requestContextKey key = iota 14 | 15 | // RequestContext returns the APIGatewayV2HTTPRequestContext value stored in ctx. 16 | func RequestContext(ctx context.Context) (events.APIGatewayV2HTTPRequestContext, bool) { 17 | c, ok := ctx.Value(requestContextKey).(events.APIGatewayV2HTTPRequestContext) 18 | return c, ok 19 | } 20 | 21 | // newContext returns a new Context with specific api gateway v2 values. 22 | func newContext(ctx context.Context, e events.APIGatewayV2HTTPRequest) context.Context { 23 | return context.WithValue(ctx, requestContextKey, e.RequestContext) 24 | } 25 | -------------------------------------------------------------------------------- /v2/gateway.go: -------------------------------------------------------------------------------- 1 | // Package gateway provides a drop-in replacement for net/http.ListenAndServe for use in AWS Lambda & API Gateway. 2 | package gateway 3 | 4 | import ( 5 | "context" 6 | "encoding/json" 7 | "net/http" 8 | 9 | "github.com/aws/aws-lambda-go/events" 10 | "github.com/aws/aws-lambda-go/lambda" 11 | ) 12 | 13 | // ListenAndServe is a drop-in replacement for 14 | // http.ListenAndServe for use within AWS Lambda. 15 | // 16 | // ListenAndServe always returns a non-nil error. 
17 | func ListenAndServe(addr string, h http.Handler) error { 18 | if h == nil { 19 | h = http.DefaultServeMux 20 | } 21 | 22 | gw := NewGateway(h) 23 | 24 | lambda.StartHandler(gw) 25 | // not expected to be reached: lambda.StartHandler blocks serving invocations 26 | return nil 27 | } 28 | 29 | // NewGateway creates a gateway using the provided http.Handler enabling use in existing aws-lambda-go 30 | // projects. 31 | func NewGateway(h http.Handler) *Gateway { 32 | return &Gateway{h: h} 33 | } 34 | 35 | // Gateway wraps an http.Handler to enable use as a lambda.Handler. 36 | type Gateway struct { 37 | h http.Handler 38 | } 39 | 40 | // Invoke implements lambda.Handler: it decodes an API Gateway v2 HTTP event, serves it through the wrapped handler and returns the JSON-encoded response (nil payload on error, as in v1). 41 | func (gw *Gateway) Invoke(ctx context.Context, payload []byte) ([]byte, error) { 42 | var evt events.APIGatewayV2HTTPRequest 43 | 44 | if err := json.Unmarshal(payload, &evt); err != nil { 45 | return nil, err 46 | } 47 | 48 | r, err := NewRequest(ctx, evt) 49 | if err != nil { 50 | return nil, err 51 | } 52 | 53 | w := NewResponse() 54 | gw.h.ServeHTTP(w, r) 55 | 56 | resp := w.End() 57 | 58 | return json.Marshal(&resp) 59 | } 60 | -------------------------------------------------------------------------------- /v2/gateway_test.go: -------------------------------------------------------------------------------- 1 | package gateway_test 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "log" 7 | "net/http" 8 | "testing" 9 | 10 | "github.com/apex/gateway/v2" 11 | "github.com/tj/assert" 12 | ) 13 | 14 | func Example() { 15 | http.HandleFunc("/", hello) 16 | log.Fatal(gateway.ListenAndServe(":3000", nil)) 17 | } 18 | 19 | func hello(w http.ResponseWriter, r *http.Request) { 20 | fmt.Fprintln(w, "Hello World from Go") 21 | } 22 | 23 | func TestGateway_Invoke(t *testing.T) { 24 | 25 | e := []byte(`{"version": "2.0", "rawPath": "/pets/luna", "requestContext": {"http": {"method": "POST"}}}`) 26 | 27 | gw := gateway.NewGateway(http.HandlerFunc(hello)) 28 | 29 | payload, err := gw.Invoke(context.Background(), e) 30 | assert.NoError(t, err) 31 | assert.JSONEq(t, `{"body":"Hello
World from Go\n", "cookies": null, "headers":{"Content-Type":"text/plain; charset=utf8"}, "multiValueHeaders":{}, "statusCode":200}`, string(payload)) 32 | } 33 | -------------------------------------------------------------------------------- /v2/go.mod: -------------------------------------------------------------------------------- 1 | module github.com/apex/gateway/v2 2 | 3 | go 1.12 4 | 5 | require ( 6 | github.com/aws/aws-lambda-go v1.17.0 7 | github.com/pkg/errors v0.9.1 8 | github.com/tj/assert v0.0.3 9 | ) 10 | -------------------------------------------------------------------------------- /v2/go.sum: -------------------------------------------------------------------------------- 1 | github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= 2 | github.com/aws/aws-lambda-go v1.17.0 h1:Ogihmi8BnpmCNktKAGpNwSiILNNING1MiosnKUfU8m0= 3 | github.com/aws/aws-lambda-go v1.17.0/go.mod h1:FEwgPLE6+8wcGBTe5cJN3JWurd1Ztm9zN4jsXsjzKKw= 4 | github.com/cpuguy83/go-md2man/v2 v2.0.0-20190314233015-f79a8a8ca69d/go.mod h1:maD7wRr/U5Z6m/iR4s+kqSMx2CaBsrgA7czyZG/E6dU= 5 | github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= 6 | github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 7 | github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= 8 | github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 9 | github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= 10 | github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= 11 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= 12 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 13 | github.com/russross/blackfriday/v2 v2.0.1/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= 14 | github.com/shurcooL/sanitized_anchor_name v1.0.0/go.mod 
h1:1NzhyTcUVG4SuEtjjoZeVRXNmyL/1OwPU0+IJeTBvfc= 15 | github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= 16 | github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= 17 | github.com/stretchr/testify v1.6.1 h1:hDPOHmpOpP40lSULcqw7IrRb/u7w6RpDC9399XyoNd0= 18 | github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= 19 | github.com/tj/assert v0.0.3 h1:Df/BlaZ20mq6kuai7f5z2TvPFiwC3xaWJSDQNiIS3Rk= 20 | github.com/tj/assert v0.0.3/go.mod h1:Ne6X72Q+TB1AteidzQncjw9PabbMp4PBMZ1k+vd1Pvk= 21 | github.com/urfave/cli/v2 v2.1.1/go.mod h1:SE9GqnLQmjVa0iPEY0f1w3ygNIYcIJ0OKPMoW2caLfQ= 22 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= 23 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= 24 | gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= 25 | gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo= 26 | gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 27 | gopkg.in/yaml.v3 v3.0.0-20200605160147-a5ece683394c h1:grhR+C34yXImVGp7EzNk+DTIk+323eIUWOmEevy6bDo= 28 | gopkg.in/yaml.v3 v3.0.0-20200605160147-a5ece683394c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 29 | -------------------------------------------------------------------------------- /v2/request.go: -------------------------------------------------------------------------------- 1 | package gateway 2 | 3 | import ( 4 | "context" 5 | "encoding/base64" 6 | "fmt" 7 | "net/http" 8 | "net/url" 9 | "strconv" 10 | "strings" 11 | 12 | "github.com/aws/aws-lambda-go/events" 13 | "github.com/pkg/errors" 14 | ) 15 | 16 | // NewRequest returns a new http.Request from the given Lambda event. 
17 | func NewRequest(ctx context.Context, e events.APIGatewayV2HTTPRequest) (*http.Request, error) { 18 | // path 19 | u, err := url.Parse(e.RawPath) 20 | if err != nil { 21 | return nil, errors.Wrap(err, "parsing path") 22 | } 23 | 24 | u.RawQuery = e.RawQueryString 25 | 26 | // base64 encoded body 27 | body := e.Body 28 | if e.IsBase64Encoded { 29 | b, err := base64.StdEncoding.DecodeString(body) 30 | if err != nil { 31 | return nil, errors.Wrap(err, "decoding base64 body") 32 | } 33 | body = string(b) 34 | } 35 | 36 | // new request 37 | req, err := http.NewRequest(e.RequestContext.HTTP.Method, u.String(), strings.NewReader(body)) 38 | if err != nil { 39 | return nil, errors.Wrap(err, "creating request") 40 | } 41 | 42 | // manually set RequestURI because NewRequest is for clients and req.RequestURI is for servers 43 | req.RequestURI = u.RequestURI() 44 | 45 | // remote addr 46 | req.RemoteAddr = e.RequestContext.HTTP.SourceIP 47 | 48 | // header fields 49 | for k, values := range e.Headers { 50 | for _, v := range strings.Split(values, ",") { 51 | req.Header.Add(k, v) 52 | } 53 | } 54 | for _, c := range e.Cookies { 55 | req.Header.Add("Cookie", c) 56 | } 57 | 58 | // content-length 59 | if req.Header.Get("Content-Length") == "" && body != "" { 60 | req.Header.Set("Content-Length", strconv.Itoa(len(body))) 61 | } 62 | 63 | // custom fields 64 | req.Header.Set("X-Request-Id", e.RequestContext.RequestID) 65 | req.Header.Set("X-Stage", e.RequestContext.Stage) 66 | 67 | // custom context values 68 | req = req.WithContext(newContext(ctx, e)) 69 | 70 | // xray support 71 | if traceID := ctx.Value("x-amzn-trace-id"); traceID != nil { 72 | req.Header.Set("X-Amzn-Trace-Id", fmt.Sprintf("%v", traceID)) 73 | } 74 | 75 | // host 76 | req.URL.Host = req.Header.Get("Host") 77 | req.Host = req.URL.Host 78 | 79 | return req, nil 80 | } 81 | -------------------------------------------------------------------------------- /v2/request_test.go: 
-------------------------------------------------------------------------------- 1 | package gateway 2 | 3 | import ( 4 | "context" 5 | "io/ioutil" 6 | "strings" 7 | "testing" 8 | 9 | "github.com/aws/aws-lambda-go/events" 10 | "github.com/tj/assert" 11 | ) 12 | 13 | func TestDecodeRequest_path(t *testing.T) { 14 | e := events.APIGatewayV2HTTPRequest{ 15 | RawPath: "/pets/luna", 16 | } 17 | 18 | r, err := NewRequest(context.Background(), e) 19 | assert.NoError(t, err) 20 | // no method in the event, so http.NewRequest falls back to GET 21 | assert.Equal(t, "GET", r.Method) 22 | assert.Equal(t, `/pets/luna`, r.URL.Path) 23 | assert.Equal(t, `/pets/luna`, r.URL.String()) 24 | assert.Equal(t, `/pets/luna`, r.RequestURI) 25 | } 26 | 27 | func TestDecodeRequest_method(t *testing.T) { 28 | e := events.APIGatewayV2HTTPRequest{ 29 | RawPath: "/pets/luna", 30 | RequestContext: events.APIGatewayV2HTTPRequestContext{ 31 | HTTP: events.APIGatewayV2HTTPRequestContextHTTPDescription{ 32 | Method: "DELETE", 33 | Path: "/pets/luna", 34 | }, 35 | }, 36 | } 37 | 38 | r, err := NewRequest(context.Background(), e) 39 | assert.NoError(t, err) 40 | 41 | assert.Equal(t, "DELETE", r.Method) 42 | } 43 | 44 | func TestDecodeRequest_queryString(t *testing.T) { 45 | e := events.APIGatewayV2HTTPRequest{ 46 | RawPath: "/pets", 47 | RawQueryString: "fields=name%2Cspecies&order=desc", 48 | QueryStringParameters: map[string]string{ 49 | "order": "desc", 50 | "fields": "name,species", 51 | }, 52 | RequestContext: events.APIGatewayV2HTTPRequestContext{ 53 | HTTP: events.APIGatewayV2HTTPRequestContextHTTPDescription{ 54 | Method: "GET", 55 | Path: "/pets", 56 | }, 57 | }, 58 | } 59 | 60 | r, err := NewRequest(context.Background(), e) 61 | assert.NoError(t, err) 62 | 63 | assert.Equal(t, `/pets?fields=name%2Cspecies&order=desc`, r.URL.String()) 64 | assert.Equal(t, `desc`, r.URL.Query().Get("order")) 65 | } 66 | 67 | func TestDecodeRequest_multiValueQueryString(t *testing.T) { 68 | e := events.APIGatewayV2HTTPRequest{ 69 | RawPath: "/pets", 70 | RawQueryString:
"fields=name%2Cspecies&multi_arr%5B%5D=arr1&multi_arr%5B%5D=arr2&multi_fields=name&multi_fields=species&order=desc", 71 | QueryStringParameters: map[string]string{ 72 | "multi_fields": strings.Join([]string{"name", "species"}, ","), 73 | "multi_arr[]": strings.Join([]string{"arr1", "arr2"}, ","), 74 | "order": "desc", 75 | "fields": "name,species", 76 | }, 77 | RequestContext: events.APIGatewayV2HTTPRequestContext{ 78 | HTTP: events.APIGatewayV2HTTPRequestContextHTTPDescription{ 79 | Method: "GET", 80 | Path: "/pets", 81 | }, 82 | }, 83 | } 84 | 85 | r, err := NewRequest(context.Background(), e) 86 | assert.NoError(t, err) 87 | 88 | assert.Equal(t, `/pets?fields=name%2Cspecies&multi_arr%5B%5D=arr1&multi_arr%5B%5D=arr2&multi_fields=name&multi_fields=species&order=desc`, r.URL.String()) 89 | assert.Equal(t, []string{"name", "species"}, r.URL.Query()["multi_fields"]) 90 | assert.Equal(t, []string{"arr1", "arr2"}, r.URL.Query()["multi_arr[]"]) 91 | assert.Equal(t, `/pets?fields=name%2Cspecies&multi_arr%5B%5D=arr1&multi_arr%5B%5D=arr2&multi_fields=name&multi_fields=species&order=desc`, r.RequestURI) 92 | } 93 | 94 | func TestDecodeRequest_remoteAddr(t *testing.T) { 95 | e := events.APIGatewayV2HTTPRequest{ 96 | RawPath: "/pets", 97 | RequestContext: events.APIGatewayV2HTTPRequestContext{ 98 | HTTP: events.APIGatewayV2HTTPRequestContextHTTPDescription{ 99 | Method: "GET", 100 | Path: "/pets", 101 | SourceIP: "1.2.3.4", 102 | }, 103 | }, 104 | } 105 | 106 | r, err := NewRequest(context.Background(), e) 107 | assert.NoError(t, err) 108 | 109 | assert.Equal(t, `1.2.3.4`, r.RemoteAddr) 110 | } 111 | 112 | func TestDecodeRequest_header(t *testing.T) { 113 | e := events.APIGatewayV2HTTPRequest{ 114 | RawPath: "/pets", 115 | Body: `{ "name": "Tobi" }`, 116 | Headers: map[string]string{ 117 | "Content-Type": "application/json", 118 | "X-Foo": "bar", 119 | "Host": "example.com", 120 | }, 121 | RequestContext: events.APIGatewayV2HTTPRequestContext{ 122 | RequestID: "1234", 123 |
Stage: "prod", 124 | HTTP: events.APIGatewayV2HTTPRequestContextHTTPDescription{ 125 | Path: "/pets", 126 | Method: "POST", 127 | }, 128 | }, 129 | } 130 | 131 | r, err := NewRequest(context.Background(), e) 132 | assert.NoError(t, err) 133 | 134 | assert.Equal(t, `example.com`, r.Host) 135 | assert.Equal(t, `prod`, r.Header.Get("X-Stage")) 136 | assert.Equal(t, `1234`, r.Header.Get("X-Request-Id")) 137 | assert.Equal(t, `18`, r.Header.Get("Content-Length")) 138 | assert.Equal(t, `application/json`, r.Header.Get("Content-Type")) 139 | assert.Equal(t, `bar`, r.Header.Get("X-Foo")) 140 | } 141 | 142 | func TestDecodeRequest_multiHeader(t *testing.T) { 143 | e := events.APIGatewayV2HTTPRequest{ 144 | RawPath: "/pets", 145 | Body: `{ "name": "Tobi" }`, 146 | Headers: map[string]string{ 147 | "X-APEX": strings.Join([]string{"apex1", "apex2"}, ","), 148 | "X-APEX-2": strings.Join([]string{"apex-1", "apex-2"}, ","), 149 | "Content-Type": "application/json", 150 | "X-Foo": "bar", 151 | "Host": "example.com", 152 | }, 153 | RequestContext: events.APIGatewayV2HTTPRequestContext{ 154 | RequestID: "1234", 155 | Stage: "prod", 156 | HTTP: events.APIGatewayV2HTTPRequestContextHTTPDescription{ 157 | Path: "/pets", 158 | Method: "POST", 159 | }, 160 | }, 161 | } 162 | 163 | r, err := NewRequest(context.Background(), e) 164 | assert.NoError(t, err) 165 | // the comma-joined X-APEX values must come back as separate header values 166 | assert.Equal(t, `example.com`, r.Host) 167 | assert.Equal(t, `prod`, r.Header.Get("X-Stage")) 168 | assert.Equal(t, `1234`, r.Header.Get("X-Request-Id")) 169 | assert.Equal(t, `18`, r.Header.Get("Content-Length")) 170 | assert.Equal(t, `application/json`, r.Header.Get("Content-Type")) 171 | assert.Equal(t, `bar`, r.Header.Get("X-Foo")) 172 | assert.Equal(t, []string{"apex1", "apex2"}, r.Header["X-Apex"]) 173 | assert.Equal(t, []string{"apex-1", "apex-2"}, r.Header["X-Apex-2"]) 174 | } 175 | 176 | func TestDecodeRequest_cookie(t *testing.T) { 177 | e := events.APIGatewayV2HTTPRequest{ 178 | RawPath: "/pets", 179 | Body: `{
"name": "Tobi" }`, 180 | Headers: map[string]string{}, 181 | Cookies: []string{"TEST_COOKIE=TEST-VALUE"}, 182 | RequestContext: events.APIGatewayV2HTTPRequestContext{ 183 | RequestID: "1234", 184 | Stage: "prod", 185 | HTTP: events.APIGatewayV2HTTPRequestContextHTTPDescription{ 186 | Path: "/pets", 187 | Method: "POST", 188 | }, 189 | }, 190 | } 191 | 192 | r, err := NewRequest(context.Background(), e) 193 | assert.NoError(t, err) 194 | 195 | c, err := r.Cookie("TEST_COOKIE") 196 | assert.NoError(t, err) 197 | 198 | assert.Equal(t, "TEST-VALUE", c.Value) 199 | } 200 | 201 | func TestDecodeRequest_body(t *testing.T) { 202 | e := events.APIGatewayV2HTTPRequest{ 203 | RawPath: "/pets", 204 | Body: `{ "name": "Tobi" }`, 205 | RequestContext: events.APIGatewayV2HTTPRequestContext{ 206 | HTTP: events.APIGatewayV2HTTPRequestContextHTTPDescription{ 207 | Method: "POST", 208 | Path: "/pets", 209 | }, 210 | }, 211 | } 212 | 213 | r, err := NewRequest(context.Background(), e) 214 | assert.NoError(t, err) 215 | 216 | b, err := ioutil.ReadAll(r.Body) 217 | assert.NoError(t, err) 218 | 219 | assert.Equal(t, `{ "name": "Tobi" }`, string(b)) 220 | } 221 | 222 | func TestDecodeRequest_bodyBinary(t *testing.T) { 223 | e := events.APIGatewayV2HTTPRequest{ 224 | RawPath: "/pets", 225 | Body: `aGVsbG8gd29ybGQK`, 226 | IsBase64Encoded: true, 227 | RequestContext: events.APIGatewayV2HTTPRequestContext{ 228 | HTTP: events.APIGatewayV2HTTPRequestContextHTTPDescription{ 229 | Method: "POST", 230 | Path: "/pets", 231 | }, 232 | }, 233 | } 234 | 235 | r, err := NewRequest(context.Background(), e) 236 | assert.NoError(t, err) 237 | // "aGVsbG8gd29ybGQK" is the base64 encoding of "hello world\n" 238 | b, err := ioutil.ReadAll(r.Body) 239 | assert.NoError(t, err) 240 | 241 | assert.Equal(t, "hello world\n", string(b)) 242 | } 243 | 244 | func TestDecodeRequest_context(t *testing.T) { 245 | e := events.APIGatewayV2HTTPRequest{} 246 | ctx := context.WithValue(context.Background(), "key", "value") 247 | r, err := NewRequest(ctx, e) 248 | assert.NoError(t, err)
249 | v := r.Context().Value("key") 250 | assert.Equal(t, "value", v) 251 | } 252 | -------------------------------------------------------------------------------- /v2/response.go: -------------------------------------------------------------------------------- 1 | package gateway 2 | 3 | import ( 4 | "bytes" 5 | "encoding/base64" 6 | "mime" 7 | "net/http" 8 | "strings" 9 | 10 | "github.com/aws/aws-lambda-go/events" 11 | ) 12 | 13 | // ResponseWriter implements http.ResponseWriter by buffering the status, headers and body 14 | // of a handler's response so they can be returned as an API Gateway v2 HTTP response. 15 | type ResponseWriter struct { 16 | out events.APIGatewayV2HTTPResponse // response returned by End 17 | buf bytes.Buffer // body bytes written by the handler 18 | header http.Header // handler-visible headers, allocated lazily by Header 19 | wroteHeader bool // set once WriteHeader has run 20 | closeNotifyCh chan bool // capacity 1 (see NewResponse); End sends true on it 21 | } 22 | 23 | // NewResponse returns a ResponseWriter ready to capture a handler's output. 24 | func NewResponse() *ResponseWriter { 25 | return &ResponseWriter{ 26 | closeNotifyCh: make(chan bool, 1), 27 | } 28 | } 29 | 30 | // Header implements http.ResponseWriter, allocating the header map on first use. 31 | func (w *ResponseWriter) Header() http.Header { 32 | if w.header == nil { 33 | w.header = make(http.Header) 34 | } 35 | 36 | return w.header 37 | } 38 | 39 | // Write implements http.ResponseWriter: it sends a 200 status first if WriteHeader has not been called, then appends b to the buffered body. 40 | func (w *ResponseWriter) Write(b []byte) (int, error) { 41 | if !w.wroteHeader { 42 | w.WriteHeader(http.StatusOK) 43 | } 44 | 45 | return w.buf.Write(b) 46 | } 47 | 48 | // WriteHeader implements http.ResponseWriter. Only the first call takes effect: it records status, defaults Content-Type to text/plain and snapshots the headers into the response.
49 | func (w *ResponseWriter) WriteHeader(status int) { 50 | if w.wroteHeader { 51 | return 52 | } 53 | 54 | if w.Header().Get("Content-Type") == "" { 55 | w.Header().Set("Content-Type", "text/plain; charset=utf8") 56 | } 57 | 58 | w.out.StatusCode = status 59 | 60 | h := make(map[string]string) 61 | mvh := make(map[string][]string) 62 | 63 | for k, v := range w.Header() { 64 | if len(v) == 1 { 65 | h[k] = v[0] 66 | } else if len(v) > 1 { 67 | mvh[k] = v 68 | } 69 | } 70 | 71 | w.out.Headers = h 72 | w.out.MultiValueHeaders = mvh 73 | w.wroteHeader = true 74 | } 75 | 76 | // CloseNotify notify when the response is closed 77 | func (w *ResponseWriter) CloseNotify() <-chan bool { 78 | return w.closeNotifyCh 79 | } 80 | 81 | // End the request. 82 | func (w *ResponseWriter) End() events.APIGatewayV2HTTPResponse { 83 | w.out.IsBase64Encoded = isBinary(w.header) 84 | 85 | if w.out.IsBase64Encoded { 86 | w.out.Body = base64.StdEncoding.EncodeToString(w.buf.Bytes()) 87 | } else { 88 | w.out.Body = w.buf.String() 89 | } 90 | 91 | // see https://aws.amazon.com/blogs/compute/simply-serverless-using-aws-lambda-to-expose-custom-cookies-with-api-gateway/ 92 | w.out.Cookies = w.header["Set-Cookie"] 93 | w.header.Del("Set-Cookie") 94 | 95 | // notify end 96 | w.closeNotifyCh <- true 97 | 98 | return w.out 99 | } 100 | 101 | // isBinary returns true if the response reprensents binary. 102 | func isBinary(h http.Header) bool { 103 | switch { 104 | case !isTextMime(h.Get("Content-Type")): 105 | return true 106 | case h.Get("Content-Encoding") == "gzip": 107 | return true 108 | default: 109 | return false 110 | } 111 | } 112 | 113 | // isTextMime returns true if the content type represents textual data. 
114 | func isTextMime(kind string) bool { 115 | mt, _, err := mime.ParseMediaType(kind) 116 | if err != nil { // unparsable or empty media type: treat as not text 117 | return false 118 | } 119 | // ParseMediaType lower-cases the media type, so the checks below are case-insensitive. 120 | if strings.HasPrefix(mt, "text/") { 121 | return true 122 | } 123 | // textual types that live outside the text/ tree 124 | switch mt { 125 | case "image/svg+xml", "application/json", "application/xml", "application/javascript": 126 | return true 127 | default: 128 | return false 129 | } 130 | } 131 | -------------------------------------------------------------------------------- /v2/response_test.go: -------------------------------------------------------------------------------- 1 | package gateway 2 | 3 | import ( 4 | "bytes" 5 | "testing" 6 | 7 | "github.com/tj/assert" 8 | ) 9 | 10 | func Test_JSON_isTextMime(t *testing.T) { 11 | assert.Equal(t, isTextMime("application/json"), true) 12 | assert.Equal(t, isTextMime("application/json; charset=utf-8"), true) 13 | assert.Equal(t, isTextMime("Application/JSON"), true) 14 | } 15 | 16 | func Test_XML_isTextMime(t *testing.T) { 17 | assert.Equal(t, isTextMime("application/xml"), true) 18 | assert.Equal(t, isTextMime("application/xml; charset=utf-8"), true) 19 | assert.Equal(t, isTextMime("ApPlicaTion/xMl"), true) 20 | } 21 | 22 | func TestResponseWriter_Header(t *testing.T) { 23 | w := NewResponse() 24 | w.Header().Set("Foo", "bar") 25 | w.Header().Set("Bar", "baz") 26 | 27 | var buf bytes.Buffer 28 | w.header.Write(&buf) 29 | 30 | assert.Equal(t, "Bar: baz\r\nFoo: bar\r\n", buf.String()) 31 | } 32 | 33 | func TestResponseWriter_multiHeader(t *testing.T) { 34 | w := NewResponse() 35 | w.Header().Set("Foo", "bar") 36 | w.Header().Set("Bar", "baz") 37 | w.Header().Add("X-APEX", "apex1") 38 | w.Header().Add("X-APEX", "apex2") 39 | 40 | var buf bytes.Buffer 41 | w.header.Write(&buf) 42 | 43 | assert.Equal(t, "Bar: baz\r\nFoo: bar\r\nX-Apex: apex1\r\nX-Apex: apex2\r\n", buf.String()) 44 | } 45 | 46 | func TestResponseWriter_Write_text(t *testing.T) { 47 | types := []string{ 48 | "text/x-custom", 49 | "text/plain", 50 |
"text/plain; charset=utf-8", 51 | "application/json", 52 | "application/json; charset=utf-8", 53 | "application/xml", 54 | "image/svg+xml", 55 | } 56 | 57 | for _, kind := range types { 58 | t.Run(kind, func(t *testing.T) { 59 | w := NewResponse() 60 | w.Header().Set("Content-Type", kind) 61 | w.Write([]byte("hello world\n")) 62 | 63 | e := w.End() 64 | assert.Equal(t, 200, e.StatusCode) 65 | assert.Equal(t, "hello world\n", e.Body) 66 | assert.Equal(t, kind, e.Headers["Content-Type"]) 67 | assert.False(t, e.IsBase64Encoded) 68 | assert.True(t, <-w.CloseNotify()) 69 | }) 70 | } 71 | } 72 | 73 | func TestResponseWriter_Write_binary(t *testing.T) { 74 | w := NewResponse() 75 | w.Header().Set("Content-Type", "image/png") 76 | w.Write([]byte("data")) 77 | 78 | e := w.End() 79 | assert.Equal(t, 200, e.StatusCode) 80 | assert.Equal(t, "ZGF0YQ==", e.Body) 81 | assert.Equal(t, "image/png", e.Headers["Content-Type"]) 82 | assert.True(t, e.IsBase64Encoded) 83 | } 84 | 85 | func TestResponseWriter_Write_gzip(t *testing.T) { 86 | w := NewResponse() 87 | w.Header().Set("Content-Type", "text/plain") 88 | w.Header().Set("Content-Encoding", "gzip") 89 | w.Write([]byte("data")) 90 | // gzip-encoded bodies must be base64 encoded even with a textual Content-Type 91 | e := w.End() 92 | assert.Equal(t, 200, e.StatusCode) 93 | assert.Equal(t, "ZGF0YQ==", e.Body) 94 | assert.Equal(t, "text/plain", e.Headers["Content-Type"]) 95 | assert.True(t, e.IsBase64Encoded) 96 | } 97 | 98 | func TestResponseWriter_WriteHeader(t *testing.T) { 99 | w := NewResponse() 100 | w.WriteHeader(404) 101 | w.Write([]byte("Not Found\n")) 102 | 103 | e := w.End() 104 | assert.Equal(t, 404, e.StatusCode) 105 | assert.Equal(t, "Not Found\n", e.Body) 106 | assert.Equal(t, "text/plain; charset=utf8", e.Headers["Content-Type"]) 107 | } 108 | --------------------------------------------------------------------------------