├── .github
└── workflows
│ └── test.yml
├── History.md
├── LICENSE
├── Readme.md
├── ci.yml
├── context.go
├── gateway.go
├── gateway_test.go
├── go.mod
├── go.sum
├── request.go
├── request_test.go
├── response.go
├── response_test.go
└── v2
├── context.go
├── gateway.go
├── gateway_test.go
├── go.mod
├── go.sum
├── request.go
├── request_test.go
├── response.go
└── response_test.go
/.github/workflows/test.yml:
--------------------------------------------------------------------------------
1 | on: [push, pull_request]
2 | name: Tests
3 | jobs:
4 | test:
5 | strategy:
6 | matrix:
7 | go-version: [1.13.x]
8 | platform: [ubuntu-latest, macos-latest, windows-latest]
9 | runs-on: ${{ matrix.platform }}
10 | steps:
11 | - name: Install Go
12 | uses: actions/setup-go@v1
13 | with:
14 | go-version: ${{ matrix.go-version }}
15 | - name: Checkout code
16 | uses: actions/checkout@v1
17 |
18 | - name: Test v1
19 | run: go test -v -cover ./...
20 |
21 | - name: Test v2
22 | run: go test -v -cover ./...
23 | working-directory: v2
24 |
--------------------------------------------------------------------------------
/History.md:
--------------------------------------------------------------------------------
1 |
2 | v1.1.1 / 2018-08-17
3 | ===================
4 |
5 | * fix passing of request context
6 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | The MIT License
2 |
3 | Copyright (c) 2018 TJ Holowaychuk tj@tjholowaychuk.com
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining
6 | a copy of this software and associated documentation files (the
7 | 'Software'), to deal in the Software without restriction, including
8 | without limitation the rights to use, copy, modify, merge, publish,
9 | distribute, sublicense, and/or sell copies of the Software, and to
10 | permit persons to whom the Software is furnished to do so, subject to
11 | the following conditions:
12 |
13 | The above copyright notice and this permission notice shall be
14 | included in all copies or substantial portions of the Software.
15 |
16 | THE SOFTWARE IS PROVIDED 'AS IS', WITHOUT WARRANTY OF ANY KIND,
17 | EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
18 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
19 | IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
20 | CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
21 | TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
22 | SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
23 |
24 |
--------------------------------------------------------------------------------
/Readme.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | Package gateway provides a drop-in replacement for net/http's `ListenAndServe` for use in [AWS Lambda](https://aws.amazon.com/lambda/) & [API Gateway](https://aws.amazon.com/api-gateway/), simply swap it out for `gateway.ListenAndServe`. Extracted from [Up](https://github.com/apex/up) which provides additional middleware features and operational functionality.
4 |
5 | There are two versions of this library: version 1.x supports the AWS API Gateway 1.0 events used by the original [REST APIs](https://docs.aws.amazon.com/apigateway/latest/developerguide/apigateway-rest-api.html), and version 2.x supports the 2.0 events used by [HTTP APIs](https://docs.aws.amazon.com/apigateway/latest/developerguide/http-api.html). For more information on the options, read [Choosing between HTTP APIs and REST APIs](https://docs.aws.amazon.com/apigateway/latest/developerguide/http-api-vs-rest.html) on the AWS documentation website.
6 |
7 | # Installation
8 |
9 | To install version 1.x for REST APIs.
10 |
11 | ```
12 | go get github.com/apex/gateway
13 | ```
14 |
15 | To install version 2.x for HTTP APIs.
16 |
17 | ```
18 | go get github.com/apex/gateway/v2
19 | ```
20 |
21 | # Example
22 |
23 | ```go
24 | package main
25 |
26 | import (
27 | "fmt"
28 | "log"
29 | "net/http"
30 | "os"
31 |
32 | "github.com/apex/gateway"
33 | )
34 |
35 | func main() {
36 | http.HandleFunc("/", hello)
37 | log.Fatal(gateway.ListenAndServe(":3000", nil))
38 | }
39 |
40 | func hello(w http.ResponseWriter, r *http.Request) {
41 | // example retrieving values from the api gateway proxy request context.
42 | requestContext, ok := gateway.RequestContext(r.Context())
43 | if !ok || requestContext.Authorizer["sub"] == nil {
44 | fmt.Fprint(w, "Hello World from Go")
45 | return
46 | }
47 |
48 | userID := requestContext.Authorizer["sub"].(string)
49 | fmt.Fprintf(w, "Hello %s from Go", userID)
50 | }
51 | ```
52 |
53 | ---
54 |
55 | [](https://godoc.org/github.com/apex/gateway)
56 | 
57 | 
58 |
59 |
60 |
--------------------------------------------------------------------------------
/ci.yml:
--------------------------------------------------------------------------------
1 | version: 0.2
2 |
3 | phases:
4 | install:
5 | commands:
6 | - go get -t ./...
7 | build:
8 | commands:
9 | - go test -cover -v ./...
10 |
--------------------------------------------------------------------------------
/context.go:
--------------------------------------------------------------------------------
1 | package gateway
2 |
3 | import (
4 | "context"
5 |
6 | "github.com/aws/aws-lambda-go/events"
7 | )
8 |
// key is an unexported type for values this package stores in a
// context.Context. No other package can name this type, so keys of this
// type cannot collide with context keys defined elsewhere.
type key int

const (
	// requestContextKey is the context key under which newContext stores
	// the API Gateway proxy RequestContext.
	requestContextKey key = 0
)
14 |
15 | // newContext returns a new Context with specific api gateway proxy values.
16 | func newContext(ctx context.Context, e events.APIGatewayProxyRequest) context.Context {
17 | return context.WithValue(ctx, requestContextKey, e.RequestContext)
18 | }
19 |
20 | // RequestContext returns the APIGatewayProxyRequestContext value stored in ctx.
21 | func RequestContext(ctx context.Context) (events.APIGatewayProxyRequestContext, bool) {
22 | c, ok := ctx.Value(requestContextKey).(events.APIGatewayProxyRequestContext)
23 | return c, ok
24 | }
25 |
--------------------------------------------------------------------------------
/gateway.go:
--------------------------------------------------------------------------------
1 | // Package gateway provides a drop-in replacement for net/http.ListenAndServe for use in AWS Lambda & API Gateway.
2 | package gateway
3 |
4 | import (
5 | "context"
6 | "encoding/json"
7 | "net/http"
8 |
9 | "github.com/aws/aws-lambda-go/events"
10 | "github.com/aws/aws-lambda-go/lambda"
11 | )
12 |
13 | // ListenAndServe is a drop-in replacement for
14 | // http.ListenAndServe for use within AWS Lambda.
15 | //
16 | // ListenAndServe always returns a non-nil error.
17 | func ListenAndServe(addr string, h http.Handler) error {
18 | if h == nil {
19 | h = http.DefaultServeMux
20 | }
21 |
22 | gw := NewGateway(h)
23 |
24 | lambda.StartHandler(gw)
25 |
26 | return nil
27 | }
28 |
29 | // NewGateway creates a gateway using the provided http.Handler enabling use in existing aws-lambda-go
30 | // projects
31 | func NewGateway(h http.Handler) *Gateway {
32 | return &Gateway{h: h}
33 | }
34 |
35 | // Gateway wrap a http handler to enable use as a lambda.Handler
36 | type Gateway struct {
37 | h http.Handler
38 | }
39 |
40 | // Invoke Handler implementation
41 | func (gw *Gateway) Invoke(ctx context.Context, payload []byte) ([]byte, error) {
42 | evt := events.APIGatewayProxyRequest{}
43 |
44 | if err := json.Unmarshal(payload, &evt); err != nil {
45 | return nil, err
46 | }
47 |
48 | r, err := NewRequest(ctx, evt)
49 | if err != nil {
50 | return nil, err
51 | }
52 |
53 | w := NewResponse()
54 | gw.h.ServeHTTP(w, r)
55 |
56 | resp := w.End()
57 |
58 | return json.Marshal(&resp)
59 | }
60 |
--------------------------------------------------------------------------------
/gateway_test.go:
--------------------------------------------------------------------------------
1 | package gateway_test
2 |
3 | import (
4 | "context"
5 | "fmt"
6 | "log"
7 | "net/http"
8 | "testing"
9 |
10 | "github.com/apex/gateway"
11 | "github.com/tj/assert"
12 | )
13 |
14 | func Example() {
15 | http.HandleFunc("/", hello)
16 | log.Fatal(gateway.ListenAndServe(":3000", nil))
17 | }
18 |
19 | func hello(w http.ResponseWriter, r *http.Request) {
20 | fmt.Fprintln(w, "Hello World from Go")
21 | }
22 |
23 | func TestGateway_Invoke(t *testing.T) {
24 |
25 | e := []byte(`{"version": "1.0", "rawPath": "/pets/luna", "requestContext": {"http": {"method": "POST"}}}`)
26 |
27 | gw := gateway.NewGateway(http.HandlerFunc(hello))
28 |
29 | payload, err := gw.Invoke(context.Background(), e)
30 | assert.NoError(t, err)
31 | assert.JSONEq(t, `{"body":"Hello World from Go\n", "headers":{"Content-Type":"text/plain; charset=utf8"}, "multiValueHeaders":{}, "statusCode":200}`, string(payload))
32 | }
33 |
--------------------------------------------------------------------------------
/go.mod:
--------------------------------------------------------------------------------
1 | module github.com/apex/gateway
2 |
3 | go 1.12
4 |
5 | require (
6 | github.com/aws/aws-lambda-go v1.17.0
7 | github.com/pkg/errors v0.9.1
8 | github.com/tj/assert v0.0.3
9 | )
10 |
--------------------------------------------------------------------------------
/go.sum:
--------------------------------------------------------------------------------
1 | github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU=
2 | github.com/aws/aws-lambda-go v1.17.0 h1:Ogihmi8BnpmCNktKAGpNwSiILNNING1MiosnKUfU8m0=
3 | github.com/aws/aws-lambda-go v1.17.0/go.mod h1:FEwgPLE6+8wcGBTe5cJN3JWurd1Ztm9zN4jsXsjzKKw=
4 | github.com/cpuguy83/go-md2man/v2 v2.0.0-20190314233015-f79a8a8ca69d/go.mod h1:maD7wRr/U5Z6m/iR4s+kqSMx2CaBsrgA7czyZG/E6dU=
5 | github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8=
6 | github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
7 | github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
8 | github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
9 | github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
10 | github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
11 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
12 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
13 | github.com/russross/blackfriday/v2 v2.0.1/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
14 | github.com/shurcooL/sanitized_anchor_name v1.0.0/go.mod h1:1NzhyTcUVG4SuEtjjoZeVRXNmyL/1OwPU0+IJeTBvfc=
15 | github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
16 | github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
17 | github.com/stretchr/testify v1.6.0 h1:jlIyCplCJFULU/01vCkhKuTyc3OorI3bJFuw6obfgho=
18 | github.com/stretchr/testify v1.6.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
19 | github.com/stretchr/testify v1.6.1 h1:hDPOHmpOpP40lSULcqw7IrRb/u7w6RpDC9399XyoNd0=
20 | github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
21 | github.com/tj/assert v0.0.1 h1:T7ozLNagrCCKl3wc+a706ztUCn/D6WHCJtkyvqYG+kQ=
22 | github.com/tj/assert v0.0.1/go.mod h1:lsg+GHQ0XplTcWKGxFLf/XPcPxWO8x2ut5jminoR2rA=
23 | github.com/tj/assert v0.0.3 h1:Df/BlaZ20mq6kuai7f5z2TvPFiwC3xaWJSDQNiIS3Rk=
24 | github.com/tj/assert v0.0.3/go.mod h1:Ne6X72Q+TB1AteidzQncjw9PabbMp4PBMZ1k+vd1Pvk=
25 | github.com/urfave/cli/v2 v2.1.1/go.mod h1:SE9GqnLQmjVa0iPEY0f1w3ygNIYcIJ0OKPMoW2caLfQ=
26 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
27 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
28 | gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
29 | gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo=
30 | gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
31 | gopkg.in/yaml.v3 v3.0.0-20200605160147-a5ece683394c h1:grhR+C34yXImVGp7EzNk+DTIk+323eIUWOmEevy6bDo=
32 | gopkg.in/yaml.v3 v3.0.0-20200605160147-a5ece683394c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
33 |
--------------------------------------------------------------------------------
/request.go:
--------------------------------------------------------------------------------
1 | package gateway
2 |
3 | import (
4 | "context"
5 | "encoding/base64"
6 | "fmt"
7 | "net/http"
8 | "net/url"
9 | "strconv"
10 | "strings"
11 |
12 | "github.com/aws/aws-lambda-go/events"
13 | "github.com/pkg/errors"
14 | )
15 |
16 | // NewRequest returns a new http.Request from the given Lambda event.
17 | func NewRequest(ctx context.Context, e events.APIGatewayProxyRequest) (*http.Request, error) {
18 | // path
19 | u, err := url.Parse(e.Path)
20 | if err != nil {
21 | return nil, errors.Wrap(err, "parsing path")
22 | }
23 |
24 | // querystring
25 | q := u.Query()
26 | for k, v := range e.QueryStringParameters {
27 | q.Set(k, v)
28 | }
29 |
30 | for k, values := range e.MultiValueQueryStringParameters {
31 | q[k] = values
32 | }
33 | u.RawQuery = q.Encode()
34 |
35 | // base64 encoded body
36 | body := e.Body
37 | if e.IsBase64Encoded {
38 | b, err := base64.StdEncoding.DecodeString(body)
39 | if err != nil {
40 | return nil, errors.Wrap(err, "decoding base64 body")
41 | }
42 | body = string(b)
43 | }
44 |
45 | // new request
46 | req, err := http.NewRequest(e.HTTPMethod, u.String(), strings.NewReader(body))
47 | if err != nil {
48 | return nil, errors.Wrap(err, "creating request")
49 | }
50 |
51 | // manually set RequestURI because NewRequest is for clients and req.RequestURI is for servers
52 | req.RequestURI = u.RequestURI()
53 |
54 | // remote addr
55 | req.RemoteAddr = e.RequestContext.Identity.SourceIP
56 |
57 | // header fields
58 | for k, v := range e.Headers {
59 | req.Header.Set(k, v)
60 | }
61 |
62 | for k, values := range e.MultiValueHeaders {
63 | req.Header[k] = values
64 | }
65 |
66 | // content-length
67 | if req.Header.Get("Content-Length") == "" && body != "" {
68 | req.Header.Set("Content-Length", strconv.Itoa(len(body)))
69 | }
70 |
71 | // custom fields
72 | req.Header.Set("X-Request-Id", e.RequestContext.RequestID)
73 | req.Header.Set("X-Stage", e.RequestContext.Stage)
74 |
75 | // custom context values
76 | req = req.WithContext(newContext(ctx, e))
77 |
78 | // xray support
79 | if traceID := ctx.Value("x-amzn-trace-id"); traceID != nil {
80 | req.Header.Set("X-Amzn-Trace-Id", fmt.Sprintf("%v", traceID))
81 | }
82 |
83 | // host
84 | req.URL.Host = req.Header.Get("Host")
85 | req.Host = req.URL.Host
86 |
87 | return req, nil
88 | }
89 |
--------------------------------------------------------------------------------
/request_test.go:
--------------------------------------------------------------------------------
1 | package gateway
2 |
3 | import (
4 | "context"
5 | "io/ioutil"
6 | "testing"
7 |
8 | "github.com/aws/aws-lambda-go/events"
9 | "github.com/tj/assert"
10 | )
11 |
12 | func TestNewRequest_path(t *testing.T) {
13 | e := events.APIGatewayProxyRequest{
14 | Path: "/pets/luna",
15 | }
16 |
17 | r, err := NewRequest(context.Background(), e)
18 | assert.NoError(t, err)
19 |
20 | assert.Equal(t, "GET", r.Method)
21 | assert.Equal(t, `/pets/luna`, r.URL.Path)
22 | assert.Equal(t, `/pets/luna`, r.URL.String())
23 | assert.Equal(t, `/pets/luna`, r.RequestURI)
24 | }
25 |
26 | func TestNewRequest_method(t *testing.T) {
27 | e := events.APIGatewayProxyRequest{
28 | HTTPMethod: "DELETE",
29 | Path: "/pets/luna",
30 | }
31 |
32 | r, err := NewRequest(context.Background(), e)
33 | assert.NoError(t, err)
34 |
35 | assert.Equal(t, "DELETE", r.Method)
36 | }
37 |
38 | func TestNewRequest_queryString(t *testing.T) {
39 | e := events.APIGatewayProxyRequest{
40 | HTTPMethod: "GET",
41 | Path: "/pets",
42 | QueryStringParameters: map[string]string{
43 | "order": "desc",
44 | "fields": "name,species",
45 | },
46 | }
47 |
48 | r, err := NewRequest(context.Background(), e)
49 | assert.NoError(t, err)
50 |
51 | assert.Equal(t, `/pets?fields=name%2Cspecies&order=desc`, r.URL.String())
52 | assert.Equal(t, `desc`, r.URL.Query().Get("order"))
53 | }
54 |
55 | func TestNewRequest_multiValueQueryString(t *testing.T) {
56 | e := events.APIGatewayProxyRequest{
57 | HTTPMethod: "GET",
58 | Path: "/pets",
59 | MultiValueQueryStringParameters: map[string][]string{
60 | "multi_fields": []string{"name", "species"},
61 | "multi_arr[]": []string{"arr1", "arr2"},
62 | },
63 | QueryStringParameters: map[string]string{
64 | "order": "desc",
65 | "fields": "name,species",
66 | },
67 | }
68 |
69 | r, err := NewRequest(context.Background(), e)
70 | assert.NoError(t, err)
71 |
72 | assert.Equal(t, `/pets?fields=name%2Cspecies&multi_arr%5B%5D=arr1&multi_arr%5B%5D=arr2&multi_fields=name&multi_fields=species&order=desc`, r.URL.String())
73 | assert.Equal(t, []string{"name", "species"}, r.URL.Query()["multi_fields"])
74 | assert.Equal(t, []string{"arr1", "arr2"}, r.URL.Query()["multi_arr[]"])
75 | assert.Equal(t, `/pets?fields=name%2Cspecies&multi_arr%5B%5D=arr1&multi_arr%5B%5D=arr2&multi_fields=name&multi_fields=species&order=desc`, r.RequestURI)
76 | }
77 |
78 | func TestNewRequest_remoteAddr(t *testing.T) {
79 | e := events.APIGatewayProxyRequest{
80 | HTTPMethod: "GET",
81 | Path: "/pets",
82 | RequestContext: events.APIGatewayProxyRequestContext{
83 | Identity: events.APIGatewayRequestIdentity{
84 | SourceIP: "1.2.3.4",
85 | },
86 | },
87 | }
88 |
89 | r, err := NewRequest(context.Background(), e)
90 | assert.NoError(t, err)
91 |
92 | assert.Equal(t, `1.2.3.4`, r.RemoteAddr)
93 | }
94 |
95 | func TestNewRequest_header(t *testing.T) {
96 | e := events.APIGatewayProxyRequest{
97 | HTTPMethod: "POST",
98 | Path: "/pets",
99 | Body: `{ "name": "Tobi" }`,
100 | Headers: map[string]string{
101 | "Content-Type": "application/json",
102 | "X-Foo": "bar",
103 | "Host": "example.com",
104 | },
105 | RequestContext: events.APIGatewayProxyRequestContext{
106 | RequestID: "1234",
107 | Stage: "prod",
108 | },
109 | }
110 |
111 | r, err := NewRequest(context.Background(), e)
112 | assert.NoError(t, err)
113 |
114 | assert.Equal(t, `example.com`, r.Host)
115 | assert.Equal(t, `prod`, r.Header.Get("X-Stage"))
116 | assert.Equal(t, `1234`, r.Header.Get("X-Request-Id"))
117 | assert.Equal(t, `18`, r.Header.Get("Content-Length"))
118 | assert.Equal(t, `application/json`, r.Header.Get("Content-Type"))
119 | assert.Equal(t, `bar`, r.Header.Get("X-Foo"))
120 | }
121 |
122 | func TestNewRequest_multiHeader(t *testing.T) {
123 | e := events.APIGatewayProxyRequest{
124 | HTTPMethod: "POST",
125 | Path: "/pets",
126 | Body: `{ "name": "Tobi" }`,
127 | MultiValueHeaders: map[string][]string{
128 | "X-APEX": []string{"apex1", "apex2"},
129 | "X-APEX-2": []string{"apex-1", "apex-2"},
130 | },
131 | Headers: map[string]string{
132 | "Content-Type": "application/json",
133 | "X-Foo": "bar",
134 | "Host": "example.com",
135 | },
136 | RequestContext: events.APIGatewayProxyRequestContext{
137 | RequestID: "1234",
138 | Stage: "prod",
139 | },
140 | }
141 |
142 | r, err := NewRequest(context.Background(), e)
143 | assert.NoError(t, err)
144 |
145 | assert.Equal(t, `example.com`, r.Host)
146 | assert.Equal(t, `prod`, r.Header.Get("X-Stage"))
147 | assert.Equal(t, `1234`, r.Header.Get("X-Request-Id"))
148 | assert.Equal(t, `18`, r.Header.Get("Content-Length"))
149 | assert.Equal(t, `application/json`, r.Header.Get("Content-Type"))
150 | assert.Equal(t, `bar`, r.Header.Get("X-Foo"))
151 | assert.Equal(t, []string{"apex1", "apex2"}, r.Header["X-APEX"])
152 | assert.Equal(t, []string{"apex-1", "apex-2"}, r.Header["X-APEX-2"])
153 | }
154 |
155 | func TestNewRequest_body(t *testing.T) {
156 | e := events.APIGatewayProxyRequest{
157 | HTTPMethod: "POST",
158 | Path: "/pets",
159 | Body: `{ "name": "Tobi" }`,
160 | }
161 |
162 | r, err := NewRequest(context.Background(), e)
163 | assert.NoError(t, err)
164 |
165 | b, err := ioutil.ReadAll(r.Body)
166 | assert.NoError(t, err)
167 |
168 | assert.Equal(t, `{ "name": "Tobi" }`, string(b))
169 | }
170 |
171 | func TestNewRequest_bodyBinary(t *testing.T) {
172 | e := events.APIGatewayProxyRequest{
173 | HTTPMethod: "POST",
174 | Path: "/pets",
175 | Body: `aGVsbG8gd29ybGQK`,
176 | IsBase64Encoded: true,
177 | }
178 |
179 | r, err := NewRequest(context.Background(), e)
180 | assert.NoError(t, err)
181 |
182 | b, err := ioutil.ReadAll(r.Body)
183 | assert.NoError(t, err)
184 |
185 | assert.Equal(t, "hello world\n", string(b))
186 | }
187 |
188 | func TestNewRequest_context(t *testing.T) {
189 | e := events.APIGatewayProxyRequest{}
190 | ctx := context.WithValue(context.Background(), "key", "value")
191 | r, err := NewRequest(ctx, e)
192 | assert.NoError(t, err)
193 | v := r.Context().Value("key")
194 | assert.Equal(t, "value", v)
195 | }
196 |
--------------------------------------------------------------------------------
/response.go:
--------------------------------------------------------------------------------
1 | package gateway
2 |
3 | import (
4 | "bytes"
5 | "encoding/base64"
6 | "mime"
7 | "net/http"
8 | "strings"
9 |
10 | "github.com/aws/aws-lambda-go/events"
11 | )
12 |
13 | // ResponseWriter implements the http.ResponseWriter interface
14 | // in order to support the API Gateway Lambda HTTP "protocol".
15 | type ResponseWriter struct {
16 | out events.APIGatewayProxyResponse
17 | buf bytes.Buffer
18 | header http.Header
19 | wroteHeader bool
20 | closeNotifyCh chan bool
21 | }
22 |
23 | // NewResponse returns a new response writer to capture http output.
24 | func NewResponse() *ResponseWriter {
25 | return &ResponseWriter{
26 | closeNotifyCh: make(chan bool, 1),
27 | }
28 | }
29 |
30 | // Header implementation.
31 | func (w *ResponseWriter) Header() http.Header {
32 | if w.header == nil {
33 | w.header = make(http.Header)
34 | }
35 |
36 | return w.header
37 | }
38 |
39 | // Write implementation.
40 | func (w *ResponseWriter) Write(b []byte) (int, error) {
41 | if !w.wroteHeader {
42 | w.WriteHeader(http.StatusOK)
43 | }
44 |
45 | return w.buf.Write(b)
46 | }
47 |
48 | // WriteHeader implementation.
49 | func (w *ResponseWriter) WriteHeader(status int) {
50 | if w.wroteHeader {
51 | return
52 | }
53 |
54 | if w.Header().Get("Content-Type") == "" {
55 | w.Header().Set("Content-Type", "text/plain; charset=utf8")
56 | }
57 |
58 | w.out.StatusCode = status
59 |
60 | h := make(map[string]string)
61 | mvh := make(map[string][]string)
62 |
63 | for k, v := range w.Header() {
64 | if len(v) == 1 {
65 | h[k] = v[0]
66 | } else if len(v) > 1 {
67 | mvh[k] = v
68 | }
69 | }
70 |
71 | w.out.Headers = h
72 | w.out.MultiValueHeaders = mvh
73 | w.wroteHeader = true
74 | }
75 |
76 | // CloseNotify notify when the response is closed
77 | func (w *ResponseWriter) CloseNotify() <-chan bool {
78 | return w.closeNotifyCh
79 | }
80 |
81 | // End the request.
82 | func (w *ResponseWriter) End() events.APIGatewayProxyResponse {
83 | w.out.IsBase64Encoded = isBinary(w.header)
84 |
85 | if w.out.IsBase64Encoded {
86 | w.out.Body = base64.StdEncoding.EncodeToString(w.buf.Bytes())
87 | } else {
88 | w.out.Body = w.buf.String()
89 | }
90 |
91 | // notify end
92 | w.closeNotifyCh <- true
93 |
94 | return w.out
95 | }
96 |
97 | // isBinary returns true if the response reprensents binary.
98 | func isBinary(h http.Header) bool {
99 | switch {
100 | case !isTextMime(h.Get("Content-Type")):
101 | return true
102 | case h.Get("Content-Encoding") == "gzip":
103 | return true
104 | default:
105 | return false
106 | }
107 | }
108 |
109 | // isTextMime returns true if the content type represents textual data.
110 | func isTextMime(kind string) bool {
111 | mt, _, err := mime.ParseMediaType(kind)
112 | if err != nil {
113 | return false
114 | }
115 |
116 | if strings.HasPrefix(mt, "text/") {
117 | return true
118 | }
119 |
120 | switch mt {
121 | case "image/svg+xml", "application/json", "application/xml","application/javascript":
122 | return true
123 | default:
124 | return false
125 | }
126 | }
127 |
--------------------------------------------------------------------------------
/response_test.go:
--------------------------------------------------------------------------------
1 | package gateway
2 |
3 | import (
4 | "bytes"
5 | "testing"
6 |
7 | "github.com/tj/assert"
8 | )
9 |
10 | func Test_JSON_isTextMime(t *testing.T) {
11 | assert.Equal(t, isTextMime("application/json"), true)
12 | assert.Equal(t, isTextMime("application/json; charset=utf-8"), true)
13 | assert.Equal(t, isTextMime("Application/JSON"), true)
14 | }
15 |
16 | func Test_XML_isTextMime(t *testing.T) {
17 | assert.Equal(t, isTextMime("application/xml"), true)
18 | assert.Equal(t, isTextMime("application/xml; charset=utf-8"), true)
19 | assert.Equal(t, isTextMime("ApPlicaTion/xMl"), true)
20 | }
21 |
22 | func TestResponseWriter_Header(t *testing.T) {
23 | w := NewResponse()
24 | w.Header().Set("Foo", "bar")
25 | w.Header().Set("Bar", "baz")
26 |
27 | var buf bytes.Buffer
28 | w.header.Write(&buf)
29 |
30 | assert.Equal(t, "Bar: baz\r\nFoo: bar\r\n", buf.String())
31 | }
32 |
33 | func TestResponseWriter_multiHeader(t *testing.T) {
34 | w := NewResponse()
35 | w.Header().Set("Foo", "bar")
36 | w.Header().Set("Bar", "baz")
37 | w.Header().Add("X-APEX", "apex1")
38 | w.Header().Add("X-APEX", "apex2")
39 |
40 | var buf bytes.Buffer
41 | w.header.Write(&buf)
42 |
43 | assert.Equal(t, "Bar: baz\r\nFoo: bar\r\nX-Apex: apex1\r\nX-Apex: apex2\r\n", buf.String())
44 | }
45 |
46 | func TestResponseWriter_Write_text(t *testing.T) {
47 | types := []string{
48 | "text/x-custom",
49 | "text/plain",
50 | "text/plain; charset=utf-8",
51 | "application/json",
52 | "application/json; charset=utf-8",
53 | "application/xml",
54 | "image/svg+xml",
55 | }
56 |
57 | for _, kind := range types {
58 | t.Run(kind, func(t *testing.T) {
59 | w := NewResponse()
60 | w.Header().Set("Content-Type", kind)
61 | w.Write([]byte("hello world\n"))
62 |
63 | e := w.End()
64 | assert.Equal(t, 200, e.StatusCode)
65 | assert.Equal(t, "hello world\n", e.Body)
66 | assert.Equal(t, kind, e.Headers["Content-Type"])
67 | assert.False(t, e.IsBase64Encoded)
68 | assert.True(t, <-w.CloseNotify())
69 | })
70 | }
71 | }
72 |
73 | func TestResponseWriter_Write_binary(t *testing.T) {
74 | w := NewResponse()
75 | w.Header().Set("Content-Type", "image/png")
76 | w.Write([]byte("data"))
77 |
78 | e := w.End()
79 | assert.Equal(t, 200, e.StatusCode)
80 | assert.Equal(t, "ZGF0YQ==", e.Body)
81 | assert.Equal(t, "image/png", e.Headers["Content-Type"])
82 | assert.True(t, e.IsBase64Encoded)
83 | }
84 |
85 | func TestResponseWriter_Write_gzip(t *testing.T) {
86 | w := NewResponse()
87 | w.Header().Set("Content-Type", "text/plain")
88 | w.Header().Set("Content-Encoding", "gzip")
89 | w.Write([]byte("data"))
90 |
91 | e := w.End()
92 | assert.Equal(t, 200, e.StatusCode)
93 | assert.Equal(t, "ZGF0YQ==", e.Body)
94 | assert.Equal(t, "text/plain", e.Headers["Content-Type"])
95 | assert.True(t, e.IsBase64Encoded)
96 | }
97 |
98 | func TestResponseWriter_WriteHeader(t *testing.T) {
99 | w := NewResponse()
100 | w.WriteHeader(404)
101 | w.Write([]byte("Not Found\n"))
102 |
103 | e := w.End()
104 | assert.Equal(t, 404, e.StatusCode)
105 | assert.Equal(t, "Not Found\n", e.Body)
106 | assert.Equal(t, "text/plain; charset=utf8", e.Headers["Content-Type"])
107 | }
108 |
--------------------------------------------------------------------------------
/v2/context.go:
--------------------------------------------------------------------------------
1 | package gateway
2 |
3 | import (
4 | "context"
5 |
6 | "github.com/aws/aws-lambda-go/events"
7 | )
8 |
// key is an unexported type for values this package stores in a
// context.Context. No other package can name this type, so keys of this
// type cannot collide with context keys defined elsewhere.
type key int

const (
	// requestContextKey is the context key under which newContext stores
	// the API Gateway v2 HTTP RequestContext.
	requestContextKey key = 0
)
14 |
15 | // RequestContext returns the APIGatewayV2HTTPRequestContext value stored in ctx.
16 | func RequestContext(ctx context.Context) (events.APIGatewayV2HTTPRequestContext, bool) {
17 | c, ok := ctx.Value(requestContextKey).(events.APIGatewayV2HTTPRequestContext)
18 | return c, ok
19 | }
20 |
21 | // newContext returns a new Context with specific api gateway v2 values.
22 | func newContext(ctx context.Context, e events.APIGatewayV2HTTPRequest) context.Context {
23 | return context.WithValue(ctx, requestContextKey, e.RequestContext)
24 | }
25 |
--------------------------------------------------------------------------------
/v2/gateway.go:
--------------------------------------------------------------------------------
1 | // Package gateway provides a drop-in replacement for net/http.ListenAndServe for use in AWS Lambda & API Gateway.
2 | package gateway
3 |
4 | import (
5 | "context"
6 | "encoding/json"
7 | "net/http"
8 |
9 | "github.com/aws/aws-lambda-go/events"
10 | "github.com/aws/aws-lambda-go/lambda"
11 | )
12 |
13 | // ListenAndServe is a drop-in replacement for
14 | // http.ListenAndServe for use within AWS Lambda.
15 | //
16 | // ListenAndServe always returns a non-nil error.
17 | func ListenAndServe(addr string, h http.Handler) error {
18 | if h == nil {
19 | h = http.DefaultServeMux
20 | }
21 |
22 | gw := NewGateway(h)
23 |
24 | lambda.StartHandler(gw)
25 |
26 | return nil
27 | }
28 |
29 | // NewGateway creates a gateway using the provided http.Handler enabling use in existing aws-lambda-go
30 | // projects
31 | func NewGateway(h http.Handler) *Gateway {
32 | return &Gateway{h: h}
33 | }
34 |
35 | // Gateway wrap a http handler to enable use as a lambda.Handler
36 | type Gateway struct {
37 | h http.Handler
38 | }
39 |
40 | // Invoke Handler implementation
41 | func (gw *Gateway) Invoke(ctx context.Context, payload []byte) ([]byte, error) {
42 | var evt events.APIGatewayV2HTTPRequest
43 |
44 | if err := json.Unmarshal(payload, &evt); err != nil {
45 | return []byte{}, err
46 | }
47 |
48 | r, err := NewRequest(ctx, evt)
49 | if err != nil {
50 | return []byte{}, err
51 | }
52 |
53 | w := NewResponse()
54 | gw.h.ServeHTTP(w, r)
55 |
56 | resp := w.End()
57 |
58 | return json.Marshal(&resp)
59 | }
60 |
--------------------------------------------------------------------------------
/v2/gateway_test.go:
--------------------------------------------------------------------------------
1 | package gateway_test
2 |
3 | import (
4 | "context"
5 | "fmt"
6 | "log"
7 | "net/http"
8 | "testing"
9 |
10 | "github.com/apex/gateway/v2"
11 | "github.com/tj/assert"
12 | )
13 |
14 | func Example() {
15 | http.HandleFunc("/", hello)
16 | log.Fatal(gateway.ListenAndServe(":3000", nil))
17 | }
18 |
19 | func hello(w http.ResponseWriter, r *http.Request) {
20 | fmt.Fprintln(w, "Hello World from Go")
21 | }
22 |
23 | func TestGateway_Invoke(t *testing.T) {
24 |
25 | e := []byte(`{"version": "2.0", "rawPath": "/pets/luna", "requestContext": {"http": {"method": "POST"}}}`)
26 |
27 | gw := gateway.NewGateway(http.HandlerFunc(hello))
28 |
29 | payload, err := gw.Invoke(context.Background(), e)
30 | assert.NoError(t, err)
31 | assert.JSONEq(t, `{"body":"Hello World from Go\n", "cookies": null, "headers":{"Content-Type":"text/plain; charset=utf8"}, "multiValueHeaders":{}, "statusCode":200}`, string(payload))
32 | }
33 |
--------------------------------------------------------------------------------
/v2/go.mod:
--------------------------------------------------------------------------------
1 | module github.com/apex/gateway/v2
2 |
3 | go 1.12
4 |
5 | require (
6 | github.com/aws/aws-lambda-go v1.17.0
7 | github.com/pkg/errors v0.9.1
8 | github.com/tj/assert v0.0.3
9 | )
10 |
--------------------------------------------------------------------------------
/v2/go.sum:
--------------------------------------------------------------------------------
1 | github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU=
2 | github.com/aws/aws-lambda-go v1.17.0 h1:Ogihmi8BnpmCNktKAGpNwSiILNNING1MiosnKUfU8m0=
3 | github.com/aws/aws-lambda-go v1.17.0/go.mod h1:FEwgPLE6+8wcGBTe5cJN3JWurd1Ztm9zN4jsXsjzKKw=
4 | github.com/cpuguy83/go-md2man/v2 v2.0.0-20190314233015-f79a8a8ca69d/go.mod h1:maD7wRr/U5Z6m/iR4s+kqSMx2CaBsrgA7czyZG/E6dU=
5 | github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8=
6 | github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
7 | github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
8 | github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
9 | github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
10 | github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
11 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
12 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
13 | github.com/russross/blackfriday/v2 v2.0.1/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
14 | github.com/shurcooL/sanitized_anchor_name v1.0.0/go.mod h1:1NzhyTcUVG4SuEtjjoZeVRXNmyL/1OwPU0+IJeTBvfc=
15 | github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
16 | github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
17 | github.com/stretchr/testify v1.6.1 h1:hDPOHmpOpP40lSULcqw7IrRb/u7w6RpDC9399XyoNd0=
18 | github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
19 | github.com/tj/assert v0.0.3 h1:Df/BlaZ20mq6kuai7f5z2TvPFiwC3xaWJSDQNiIS3Rk=
20 | github.com/tj/assert v0.0.3/go.mod h1:Ne6X72Q+TB1AteidzQncjw9PabbMp4PBMZ1k+vd1Pvk=
21 | github.com/urfave/cli/v2 v2.1.1/go.mod h1:SE9GqnLQmjVa0iPEY0f1w3ygNIYcIJ0OKPMoW2caLfQ=
22 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
23 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
24 | gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
25 | gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo=
26 | gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
27 | gopkg.in/yaml.v3 v3.0.0-20200605160147-a5ece683394c h1:grhR+C34yXImVGp7EzNk+DTIk+323eIUWOmEevy6bDo=
28 | gopkg.in/yaml.v3 v3.0.0-20200605160147-a5ece683394c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
29 |
--------------------------------------------------------------------------------
/v2/request.go:
--------------------------------------------------------------------------------
1 | package gateway
2 |
3 | import (
4 | "context"
5 | "encoding/base64"
6 | "fmt"
7 | "net/http"
8 | "net/url"
9 | "strconv"
10 | "strings"
11 |
12 | "github.com/aws/aws-lambda-go/events"
13 | "github.com/pkg/errors"
14 | )
15 |
16 | // NewRequest returns a new http.Request from the given Lambda event.
17 | func NewRequest(ctx context.Context, e events.APIGatewayV2HTTPRequest) (*http.Request, error) {
18 | // path
19 | u, err := url.Parse(e.RawPath)
20 | if err != nil {
21 | return nil, errors.Wrap(err, "parsing path")
22 | }
23 |
24 | u.RawQuery = e.RawQueryString
25 |
26 | // base64 encoded body
27 | body := e.Body
28 | if e.IsBase64Encoded {
29 | b, err := base64.StdEncoding.DecodeString(body)
30 | if err != nil {
31 | return nil, errors.Wrap(err, "decoding base64 body")
32 | }
33 | body = string(b)
34 | }
35 |
36 | // new request
37 | req, err := http.NewRequest(e.RequestContext.HTTP.Method, u.String(), strings.NewReader(body))
38 | if err != nil {
39 | return nil, errors.Wrap(err, "creating request")
40 | }
41 |
42 | // manually set RequestURI because NewRequest is for clients and req.RequestURI is for servers
43 | req.RequestURI = u.RequestURI()
44 |
45 | // remote addr
46 | req.RemoteAddr = e.RequestContext.HTTP.SourceIP
47 |
48 | // header fields
49 | for k, values := range e.Headers {
50 | for _, v := range strings.Split(values, ",") {
51 | req.Header.Add(k, v)
52 | }
53 | }
54 | for _, c := range e.Cookies {
55 | req.Header.Add("Cookie", c)
56 | }
57 |
58 | // content-length
59 | if req.Header.Get("Content-Length") == "" && body != "" {
60 | req.Header.Set("Content-Length", strconv.Itoa(len(body)))
61 | }
62 |
63 | // custom fields
64 | req.Header.Set("X-Request-Id", e.RequestContext.RequestID)
65 | req.Header.Set("X-Stage", e.RequestContext.Stage)
66 |
67 | // custom context values
68 | req = req.WithContext(newContext(ctx, e))
69 |
70 | // xray support
71 | if traceID := ctx.Value("x-amzn-trace-id"); traceID != nil {
72 | req.Header.Set("X-Amzn-Trace-Id", fmt.Sprintf("%v", traceID))
73 | }
74 |
75 | // host
76 | req.URL.Host = req.Header.Get("Host")
77 | req.Host = req.URL.Host
78 |
79 | return req, nil
80 | }
81 |
--------------------------------------------------------------------------------
/v2/request_test.go:
--------------------------------------------------------------------------------
1 | package gateway
2 |
3 | import (
4 | "context"
5 | "io/ioutil"
6 | "strings"
7 | "testing"
8 |
9 | "github.com/aws/aws-lambda-go/events"
10 | "github.com/tj/assert"
11 | )
12 |
13 | func TestDecodeRequest_path(t *testing.T) {
14 | e := events.APIGatewayV2HTTPRequest{
15 | RawPath: "/pets/luna",
16 | }
17 |
18 | r, err := NewRequest(context.Background(), e)
19 | assert.NoError(t, err)
20 |
21 | assert.Equal(t, "GET", r.Method)
22 | assert.Equal(t, `/pets/luna`, r.URL.Path)
23 | assert.Equal(t, `/pets/luna`, r.URL.String())
24 | assert.Equal(t, `/pets/luna`, r.RequestURI)
25 | }
26 |
27 | func TestDecodeRequest_method(t *testing.T) {
28 | e := events.APIGatewayV2HTTPRequest{
29 | RawPath: "/pets/luna",
30 | RequestContext: events.APIGatewayV2HTTPRequestContext{
31 | HTTP: events.APIGatewayV2HTTPRequestContextHTTPDescription{
32 | Method: "DELETE",
33 | Path: "/pets/luna",
34 | },
35 | },
36 | }
37 |
38 | r, err := NewRequest(context.Background(), e)
39 | assert.NoError(t, err)
40 |
41 | assert.Equal(t, "DELETE", r.Method)
42 | }
43 |
44 | func TestDecodeRequest_queryString(t *testing.T) {
45 | e := events.APIGatewayV2HTTPRequest{
46 | RawPath: "/pets",
47 | RawQueryString: "fields=name%2Cspecies&order=desc",
48 | QueryStringParameters: map[string]string{
49 | "order": "desc",
50 | "fields": "name,species",
51 | },
52 | RequestContext: events.APIGatewayV2HTTPRequestContext{
53 | HTTP: events.APIGatewayV2HTTPRequestContextHTTPDescription{
54 | Method: "GET",
55 | Path: "/pets",
56 | },
57 | },
58 | }
59 |
60 | r, err := NewRequest(context.Background(), e)
61 | assert.NoError(t, err)
62 |
63 | assert.Equal(t, `/pets?fields=name%2Cspecies&order=desc`, r.URL.String())
64 | assert.Equal(t, `desc`, r.URL.Query().Get("order"))
65 | }
66 |
67 | func TestDecodeRequest_multiValueQueryString(t *testing.T) {
68 | e := events.APIGatewayV2HTTPRequest{
69 | RawPath: "/pets",
70 | RawQueryString: "fields=name%2Cspecies&multi_arr%5B%5D=arr1&multi_arr%5B%5D=arr2&multi_fields=name&multi_fields=species&order=desc",
71 | QueryStringParameters: map[string]string{
72 | "multi_fields": strings.Join([]string{"name", "species"}, ","),
73 | "multi_arr[]": strings.Join([]string{"arr1", "arr2"}, ","),
74 | "order": "desc",
75 | "fields": "name,species",
76 | },
77 | RequestContext: events.APIGatewayV2HTTPRequestContext{
78 | HTTP: events.APIGatewayV2HTTPRequestContextHTTPDescription{
79 | Method: "GET",
80 | Path: "/pets",
81 | },
82 | },
83 | }
84 |
85 | r, err := NewRequest(context.Background(), e)
86 | assert.NoError(t, err)
87 |
88 | assert.Equal(t, `/pets?fields=name%2Cspecies&multi_arr%5B%5D=arr1&multi_arr%5B%5D=arr2&multi_fields=name&multi_fields=species&order=desc`, r.URL.String())
89 | assert.Equal(t, []string{"name", "species"}, r.URL.Query()["multi_fields"])
90 | assert.Equal(t, []string{"arr1", "arr2"}, r.URL.Query()["multi_arr[]"])
91 | assert.Equal(t, `/pets?fields=name%2Cspecies&multi_arr%5B%5D=arr1&multi_arr%5B%5D=arr2&multi_fields=name&multi_fields=species&order=desc`, r.RequestURI)
92 | }
93 |
94 | func TestDecodeRequest_remoteAddr(t *testing.T) {
95 | e := events.APIGatewayV2HTTPRequest{
96 | RawPath: "/pets",
97 | RequestContext: events.APIGatewayV2HTTPRequestContext{
98 | HTTP: events.APIGatewayV2HTTPRequestContextHTTPDescription{
99 | Method: "GET",
100 | Path: "/pets",
101 | SourceIP: "1.2.3.4",
102 | },
103 | },
104 | }
105 |
106 | r, err := NewRequest(context.Background(), e)
107 | assert.NoError(t, err)
108 |
109 | assert.Equal(t, `1.2.3.4`, r.RemoteAddr)
110 | }
111 |
112 | func TestDecodeRequest_header(t *testing.T) {
113 | e := events.APIGatewayV2HTTPRequest{
114 | RawPath: "/pets",
115 | Body: `{ "name": "Tobi" }`,
116 | Headers: map[string]string{
117 | "Content-Type": "application/json",
118 | "X-Foo": "bar",
119 | "Host": "example.com",
120 | },
121 | RequestContext: events.APIGatewayV2HTTPRequestContext{
122 | RequestID: "1234",
123 | Stage: "prod",
124 | HTTP: events.APIGatewayV2HTTPRequestContextHTTPDescription{
125 | Path: "/pets",
126 | Method: "POST",
127 | },
128 | },
129 | }
130 |
131 | r, err := NewRequest(context.Background(), e)
132 | assert.NoError(t, err)
133 |
134 | assert.Equal(t, `example.com`, r.Host)
135 | assert.Equal(t, `prod`, r.Header.Get("X-Stage"))
136 | assert.Equal(t, `1234`, r.Header.Get("X-Request-Id"))
137 | assert.Equal(t, `18`, r.Header.Get("Content-Length"))
138 | assert.Equal(t, `application/json`, r.Header.Get("Content-Type"))
139 | assert.Equal(t, `bar`, r.Header.Get("X-Foo"))
140 | }
141 |
142 | func TestDecodeRequest_multiHeader(t *testing.T) {
143 | e := events.APIGatewayV2HTTPRequest{
144 | RawPath: "/pets",
145 | Body: `{ "name": "Tobi" }`,
146 | Headers: map[string]string{
147 | "X-APEX": strings.Join([]string{"apex1", "apex2"}, ","),
148 | "X-APEX-2": strings.Join([]string{"apex-1", "apex-2"}, ","),
149 | "Content-Type": "application/json",
150 | "X-Foo": "bar",
151 | "Host": "example.com",
152 | },
153 | RequestContext: events.APIGatewayV2HTTPRequestContext{
154 | RequestID: "1234",
155 | Stage: "prod",
156 | HTTP: events.APIGatewayV2HTTPRequestContextHTTPDescription{
157 | Path: "/pets",
158 | Method: "POST",
159 | },
160 | },
161 | }
162 |
163 | r, err := NewRequest(context.Background(), e)
164 | assert.NoError(t, err)
165 |
166 | assert.Equal(t, `example.com`, r.Host)
167 | assert.Equal(t, `prod`, r.Header.Get("X-Stage"))
168 | assert.Equal(t, `1234`, r.Header.Get("X-Request-Id"))
169 | assert.Equal(t, `18`, r.Header.Get("Content-Length"))
170 | assert.Equal(t, `application/json`, r.Header.Get("Content-Type"))
171 | assert.Equal(t, `bar`, r.Header.Get("X-Foo"))
172 | assert.Equal(t, []string{"apex1", "apex2"}, r.Header["X-Apex"])
173 | assert.Equal(t, []string{"apex-1", "apex-2"}, r.Header["X-Apex-2"])
174 | }
175 |
176 | func TestDecodeRequest_cookie(t *testing.T) {
177 | e := events.APIGatewayV2HTTPRequest{
178 | RawPath: "/pets",
179 | Body: `{ "name": "Tobi" }`,
180 | Headers: map[string]string{},
181 | Cookies: []string{"TEST_COOKIE=TEST-VALUE"},
182 | RequestContext: events.APIGatewayV2HTTPRequestContext{
183 | RequestID: "1234",
184 | Stage: "prod",
185 | HTTP: events.APIGatewayV2HTTPRequestContextHTTPDescription{
186 | Path: "/pets",
187 | Method: "POST",
188 | },
189 | },
190 | }
191 |
192 | r, err := NewRequest(context.Background(), e)
193 | assert.NoError(t, err)
194 |
195 | c, err := r.Cookie("TEST_COOKIE")
196 | assert.NoError(t, err)
197 |
198 | assert.Equal(t, "TEST-VALUE", c.Value)
199 | }
200 |
201 | func TestDecodeRequest_body(t *testing.T) {
202 | e := events.APIGatewayV2HTTPRequest{
203 | RawPath: "/pets",
204 | Body: `{ "name": "Tobi" }`,
205 | RequestContext: events.APIGatewayV2HTTPRequestContext{
206 | HTTP: events.APIGatewayV2HTTPRequestContextHTTPDescription{
207 | Method: "POST",
208 | Path: "/pets",
209 | },
210 | },
211 | }
212 |
213 | r, err := NewRequest(context.Background(), e)
214 | assert.NoError(t, err)
215 |
216 | b, err := ioutil.ReadAll(r.Body)
217 | assert.NoError(t, err)
218 |
219 | assert.Equal(t, `{ "name": "Tobi" }`, string(b))
220 | }
221 |
222 | func TestDecodeRequest_bodyBinary(t *testing.T) {
223 | e := events.APIGatewayV2HTTPRequest{
224 | RawPath: "/pets",
225 | Body: `aGVsbG8gd29ybGQK`,
226 | IsBase64Encoded: true,
227 | RequestContext: events.APIGatewayV2HTTPRequestContext{
228 | HTTP: events.APIGatewayV2HTTPRequestContextHTTPDescription{
229 | Method: "POST",
230 | Path: "/pets",
231 | },
232 | },
233 | }
234 |
235 | r, err := NewRequest(context.Background(), e)
236 | assert.NoError(t, err)
237 |
238 | b, err := ioutil.ReadAll(r.Body)
239 | assert.NoError(t, err)
240 |
241 | assert.Equal(t, "hello world\n", string(b))
242 | }
243 |
244 | func TestDecodeRequest_context(t *testing.T) {
245 | e := events.APIGatewayV2HTTPRequest{}
246 | ctx := context.WithValue(context.Background(), "key", "value")
247 | r, err := NewRequest(ctx, e)
248 | assert.NoError(t, err)
249 | v := r.Context().Value("key")
250 | assert.Equal(t, "value", v)
251 | }
252 |
--------------------------------------------------------------------------------
/v2/response.go:
--------------------------------------------------------------------------------
1 | package gateway
2 |
import (
	"bytes"
	"encoding/base64"
	"mime"
	"net/http"
	"strings"
	"unicode/utf8"

	"github.com/aws/aws-lambda-go/events"
)
12 |
13 | // ResponseWriter implements the http.ResponseWriter interface
14 | // in order to support the API Gateway Lambda HTTP "protocol".
15 | type ResponseWriter struct {
16 | out events.APIGatewayV2HTTPResponse
17 | buf bytes.Buffer
18 | header http.Header
19 | wroteHeader bool
20 | closeNotifyCh chan bool
21 | }
22 |
23 | // NewResponse returns a new response writer to capture http output.
24 | func NewResponse() *ResponseWriter {
25 | return &ResponseWriter{
26 | closeNotifyCh: make(chan bool, 1),
27 | }
28 | }
29 |
30 | // Header implementation.
31 | func (w *ResponseWriter) Header() http.Header {
32 | if w.header == nil {
33 | w.header = make(http.Header)
34 | }
35 |
36 | return w.header
37 | }
38 |
39 | // Write implementation.
40 | func (w *ResponseWriter) Write(b []byte) (int, error) {
41 | if !w.wroteHeader {
42 | w.WriteHeader(http.StatusOK)
43 | }
44 |
45 | return w.buf.Write(b)
46 | }
47 |
48 | // WriteHeader implementation.
49 | func (w *ResponseWriter) WriteHeader(status int) {
50 | if w.wroteHeader {
51 | return
52 | }
53 |
54 | if w.Header().Get("Content-Type") == "" {
55 | w.Header().Set("Content-Type", "text/plain; charset=utf8")
56 | }
57 |
58 | w.out.StatusCode = status
59 |
60 | h := make(map[string]string)
61 | mvh := make(map[string][]string)
62 |
63 | for k, v := range w.Header() {
64 | if len(v) == 1 {
65 | h[k] = v[0]
66 | } else if len(v) > 1 {
67 | mvh[k] = v
68 | }
69 | }
70 |
71 | w.out.Headers = h
72 | w.out.MultiValueHeaders = mvh
73 | w.wroteHeader = true
74 | }
75 |
76 | // CloseNotify notify when the response is closed
77 | func (w *ResponseWriter) CloseNotify() <-chan bool {
78 | return w.closeNotifyCh
79 | }
80 |
81 | // End the request.
82 | func (w *ResponseWriter) End() events.APIGatewayV2HTTPResponse {
83 | w.out.IsBase64Encoded = isBinary(w.header)
84 |
85 | if w.out.IsBase64Encoded {
86 | w.out.Body = base64.StdEncoding.EncodeToString(w.buf.Bytes())
87 | } else {
88 | w.out.Body = w.buf.String()
89 | }
90 |
91 | // see https://aws.amazon.com/blogs/compute/simply-serverless-using-aws-lambda-to-expose-custom-cookies-with-api-gateway/
92 | w.out.Cookies = w.header["Set-Cookie"]
93 | w.header.Del("Set-Cookie")
94 |
95 | // notify end
96 | w.closeNotifyCh <- true
97 |
98 | return w.out
99 | }
100 |
// isBinary reports whether the response represents binary data, i.e. whether
// its body must be base64-encoded. That is the case when the Content-Type is
// missing, malformed or not textual, or when the body is gzip-encoded.
func isBinary(h http.Header) bool {
	switch {
	case !isTextMime(h.Get("Content-Type")):
		return true
	case h.Get("Content-Encoding") == "gzip":
		return true
	default:
		return false
	}
}

// isTextMime reports whether the content type represents textual data.
//
// kind is parsed with mime.ParseMediaType, which lowercases the media type,
// so parameters such as charset are ignored and matching is
// case-insensitive. An empty or malformed value is treated as non-textual.
//
// The following types are treated as textual:
//   - every text/* type;
//   - application/json, application/xml and application/javascript;
//   - any type using the +json or +xml structured syntax suffix (RFC 6839),
//     e.g. application/problem+json, application/atom+xml and image/svg+xml.
func isTextMime(kind string) bool {
	mt, _, err := mime.ParseMediaType(kind)
	if err != nil {
		return false
	}

	if strings.HasPrefix(mt, "text/") {
		return true
	}

	if strings.HasSuffix(mt, "+json") || strings.HasSuffix(mt, "+xml") {
		return true
	}

	switch mt {
	case "application/json", "application/xml", "application/javascript":
		return true
	default:
		return false
	}
}
131 |
--------------------------------------------------------------------------------
/v2/response_test.go:
--------------------------------------------------------------------------------
1 | package gateway
2 |
3 | import (
4 | "bytes"
5 | "testing"
6 |
7 | "github.com/tj/assert"
8 | )
9 |
10 | func Test_JSON_isTextMime(t *testing.T) {
11 | assert.Equal(t, isTextMime("application/json"), true)
12 | assert.Equal(t, isTextMime("application/json; charset=utf-8"), true)
13 | assert.Equal(t, isTextMime("Application/JSON"), true)
14 | }
15 |
16 | func Test_XML_isTextMime(t *testing.T) {
17 | assert.Equal(t, isTextMime("application/xml"), true)
18 | assert.Equal(t, isTextMime("application/xml; charset=utf-8"), true)
19 | assert.Equal(t, isTextMime("ApPlicaTion/xMl"), true)
20 | }
21 |
22 | func TestResponseWriter_Header(t *testing.T) {
23 | w := NewResponse()
24 | w.Header().Set("Foo", "bar")
25 | w.Header().Set("Bar", "baz")
26 |
27 | var buf bytes.Buffer
28 | w.header.Write(&buf)
29 |
30 | assert.Equal(t, "Bar: baz\r\nFoo: bar\r\n", buf.String())
31 | }
32 |
33 | func TestResponseWriter_multiHeader(t *testing.T) {
34 | w := NewResponse()
35 | w.Header().Set("Foo", "bar")
36 | w.Header().Set("Bar", "baz")
37 | w.Header().Add("X-APEX", "apex1")
38 | w.Header().Add("X-APEX", "apex2")
39 |
40 | var buf bytes.Buffer
41 | w.header.Write(&buf)
42 |
43 | assert.Equal(t, "Bar: baz\r\nFoo: bar\r\nX-Apex: apex1\r\nX-Apex: apex2\r\n", buf.String())
44 | }
45 |
46 | func TestResponseWriter_Write_text(t *testing.T) {
47 | types := []string{
48 | "text/x-custom",
49 | "text/plain",
50 | "text/plain; charset=utf-8",
51 | "application/json",
52 | "application/json; charset=utf-8",
53 | "application/xml",
54 | "image/svg+xml",
55 | }
56 |
57 | for _, kind := range types {
58 | t.Run(kind, func(t *testing.T) {
59 | w := NewResponse()
60 | w.Header().Set("Content-Type", kind)
61 | w.Write([]byte("hello world\n"))
62 |
63 | e := w.End()
64 | assert.Equal(t, 200, e.StatusCode)
65 | assert.Equal(t, "hello world\n", e.Body)
66 | assert.Equal(t, kind, e.Headers["Content-Type"])
67 | assert.False(t, e.IsBase64Encoded)
68 | assert.True(t, <-w.CloseNotify())
69 | })
70 | }
71 | }
72 |
73 | func TestResponseWriter_Write_binary(t *testing.T) {
74 | w := NewResponse()
75 | w.Header().Set("Content-Type", "image/png")
76 | w.Write([]byte("data"))
77 |
78 | e := w.End()
79 | assert.Equal(t, 200, e.StatusCode)
80 | assert.Equal(t, "ZGF0YQ==", e.Body)
81 | assert.Equal(t, "image/png", e.Headers["Content-Type"])
82 | assert.True(t, e.IsBase64Encoded)
83 | }
84 |
85 | func TestResponseWriter_Write_gzip(t *testing.T) {
86 | w := NewResponse()
87 | w.Header().Set("Content-Type", "text/plain")
88 | w.Header().Set("Content-Encoding", "gzip")
89 | w.Write([]byte("data"))
90 |
91 | e := w.End()
92 | assert.Equal(t, 200, e.StatusCode)
93 | assert.Equal(t, "ZGF0YQ==", e.Body)
94 | assert.Equal(t, "text/plain", e.Headers["Content-Type"])
95 | assert.True(t, e.IsBase64Encoded)
96 | }
97 |
98 | func TestResponseWriter_WriteHeader(t *testing.T) {
99 | w := NewResponse()
100 | w.WriteHeader(404)
101 | w.Write([]byte("Not Found\n"))
102 |
103 | e := w.End()
104 | assert.Equal(t, 404, e.StatusCode)
105 | assert.Equal(t, "Not Found\n", e.Body)
106 | assert.Equal(t, "text/plain; charset=utf8", e.Headers["Content-Type"])
107 | }
108 |
--------------------------------------------------------------------------------