├── LICENSE ├── docs ├── README.md └── screenshot.png ├── example └── example.go ├── go.mod ├── go.sum ├── layer ├── extension │ └── extension.go └── layer.go └── postinvoke.go /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Aidan Steele 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy of 6 | this software and associated documentation files (the "Software"), to deal in 7 | the Software without restriction, including without limitation the rights to use, 8 | copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the 9 | Software, and to permit persons to whom the Software is furnished to do so, 10 | subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, 19 | WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN 20 | CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 21 | -------------------------------------------------------------------------------- /docs/README.md: -------------------------------------------------------------------------------- 1 | # `postinvoke` for Go on AWS Lambda 2 | 3 | AWS Lambda is a neat service. And it got neater with the launch of [external extensions][ext], 4 | which allow you to have background processes that can continue running (e.g. for 5 | cleanup) after a Lambda function has returned its response. But sometimes you 6 | want to do cleanup in the function process itself. Here's a library to achieve that. 7 | 8 | ## Usage 9 | 10 | First, create a Lambda function. You need to include the supporting `postinvoke` 11 | external extension in a layer. Here's an example that you can use (uncomment 12 | as appropriate for your architecture): 13 | 14 | ```yaml 15 | Transform: AWS::Serverless-2016-10-31 16 | 17 | Resources: 18 | Example: 19 | Type: AWS::Serverless::Function 20 | Properties: 21 | CodeUri: ./example/bootstrap 22 | Runtime: provided.al2 23 | Handler: unused 24 | Timeout: 30 25 | MemorySize: 512 26 | Architectures: [arm64] 27 | Layers: 28 | - !Sub arn:aws:lambda:${AWS::Region}:514202201242:layer:postinvoke-arm64:1 29 | # Architectures: [x86_64] 30 | # Layers: 31 | # - !Sub arn:aws:lambda:${AWS::Region}:514202201242:layer:postinvoke-x86_64:1 32 | ``` 33 | 34 | Second, here's the code for an example Lambda function: 35 | 36 | ```go 37 | package main 38 | 39 | import ( 40 | "context" 41 | "encoding/json" 42 | "fmt" 43 | "github.com/aidansteele/postinvoke" 44 | "github.com/aws/aws-lambda-go/lambda" 45 | "time" 46 | ) 47 | 48 | func main() { 49 | postinvoke.Shutdown(func() { 50 | // this will be executed when the lambda environment shuts down. 51 | // you get 300ms to clean up - be quick! 52 | fmt.Println("bye y'all!") 53 | }) 54 | 55 | h := lambda.NewHandler(handle) 56 | h = postinvoke.WrapHandler(h, nil) // don't forget this line 57 | lambda.StartHandler(h) 58 | } 59 | 60 | func handle(ctx context.Context, input json.RawMessage) (json.RawMessage, error) { 61 | // this is where you do your lambda thing 62 | fmt.Println(string(input)) 63 | 64 | postinvoke.Run(ctx, func() { 65 | // and here is where you can do your post-invoke cleanup. this code runs 66 | // after the lambda function has returned its response to the caller 67 | fmt.Println("first on stack") 68 | }) 69 | 70 | postinvoke.Run(ctx, func() { 71 | // you can have multiple post-invoke methods 72 | // executed sequentially, in a defer-like stack 73 | fmt.Println("second on stack") 74 | time.Sleep(3 * time.Second) 75 | fmt.Println("second on complete") 76 | }) 77 | 78 | return input, nil 79 | } 80 | ``` 81 | 82 | This is what is logged by the above Lambda function. Note from the timestamps that 83 | the environment is shut down about six minutes later. 84 | 85 | ![screenshot](/docs/screenshot.png) 86 | 87 | [ext]: https://aws.amazon.com/blogs/compute/introducing-aws-lambda-extensions-in-preview/ 88 | -------------------------------------------------------------------------------- /docs/screenshot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aidansteele/postinvoke/dde44544a6f5741c4b5aec26c512d37e43117a34/docs/screenshot.png -------------------------------------------------------------------------------- /example/example.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "context" 5 | "encoding/json" 6 | "fmt" 7 | "github.com/aidansteele/postinvoke" 8 | "github.com/aws/aws-lambda-go/lambda" 9 | "time" 10 | ) 11 | 12 | func main() { 13 | postinvoke.Shutdown(func() { 14 | // this will be executed when the lambda environment shuts down. 15 | // you get 300ms to clean up - be quick! 16 | fmt.Println("bye y'all!") 17 | }) 18 | 19 | h := lambda.NewHandler(handle) 20 | h = postinvoke.WrapHandler(h, nil) // don't forget this line 21 | lambda.StartHandler(h) 22 | } 23 | 24 | func handle(ctx context.Context, input json.RawMessage) (json.RawMessage, error) { 25 | // this is where you do your lambda thing 26 | fmt.Println(string(input)) 27 | 28 | postinvoke.Run(ctx, func() { 29 | // and here is where you can do your post-invoke cleanup. this code runs 30 | // after the lambda function has returned its response to the caller 31 | fmt.Println("first on stack") 32 | }) 33 | 34 | postinvoke.Run(ctx, func() { 35 | // you can have multiple post-invoke methods 36 | // executed sequentially, in a defer-like stack 37 | fmt.Println("second on stack") 38 | time.Sleep(3 * time.Second) 39 | fmt.Println("second on complete") 40 | }) 41 | 42 | return input, nil 43 | } 44 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/aidansteele/postinvoke 2 | 3 | go 1.17 4 | 5 | require github.com/aws/aws-lambda-go v1.28.0 // indirect 6 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= 2 | github.com/aws/aws-lambda-go v1.28.0 h1:fZiik1PZqW2IyAN4rj+Y0UBaO1IDFlsNo9Zz/XnArK4= 3 | github.com/aws/aws-lambda-go v1.28.0/go.mod h1:jJmlefzPfGnckuHdXX7/80O3BvUUi12XOkbv4w9SGLU= 4 | github.com/cpuguy83/go-md2man/v2 v2.0.0-20190314233015-f79a8a8ca69d/go.mod h1:maD7wRr/U5Z6m/iR4s+kqSMx2CaBsrgA7czyZG/E6dU= 5 | github.com/cpuguy83/go-md2man/v2 v2.0.0/go.mod h1:maD7wRr/U5Z6m/iR4s+kqSMx2CaBsrgA7czyZG/E6dU= 6 | github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 7 | github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 8 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 9 | github.com/russross/blackfriday/v2 v2.0.1/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= 10 | github.com/shurcooL/sanitized_anchor_name v1.0.0/go.mod h1:1NzhyTcUVG4SuEtjjoZeVRXNmyL/1OwPU0+IJeTBvfc= 11 | github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= 12 | github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= 13 | github.com/urfave/cli/v2 v2.2.0/go.mod h1:SE9GqnLQmjVa0iPEY0f1w3ygNIYcIJ0OKPMoW2caLfQ= 14 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= 15 | gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= 16 | gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 17 | gopkg.in/yaml.v3 v3.0.0-20200615113413-eeeca48fe776/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 18 | -------------------------------------------------------------------------------- /layer/extension/extension.go: -------------------------------------------------------------------------------- 1 | // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | // SPDX-License-Identifier: MIT-0 3 | 4 | package extension 5 | 6 | import ( 7 | "bytes" 8 | "context" 9 | "encoding/json" 10 | "fmt" 11 | "io/ioutil" 12 | "net/http" 13 | ) 14 | 15 | // RegisterResponse is the body of the response for /register 16 | type RegisterResponse struct { 17 | FunctionName string `json:"functionName"` 18 | FunctionVersion string `json:"functionVersion"` 19 | Handler string `json:"handler"` 20 | } 21 | 22 | // NextEventResponse is the response for /event/next 23 | type NextEventResponse struct { 24 | EventType EventType `json:"eventType"` 25 | DeadlineMs int64 `json:"deadlineMs"` 26 | RequestID string `json:"requestId"` 27 | InvokedFunctionArn string `json:"invokedFunctionArn"` 28 | Tracing Tracing `json:"tracing"` 29 | } 30 | 31 | // Tracing is part of the response for /event/next 32 | type Tracing struct { 33 | Type string `json:"type"` 34 | Value string `json:"value"` 35 | } 36 | 37 | // EventType represents the type of events recieved from /event/next 38 | type EventType string 39 | 40 | const ( 41 | // Invoke is a lambda invoke 42 | Invoke EventType = "INVOKE" 43 | 44 | // Shutdown is a shutdown event for the environment 45 | Shutdown EventType = "SHUTDOWN" 46 | 47 | extensionNameHeader = "Lambda-Extension-Name" 48 | extensionIdentiferHeader = "Lambda-Extension-Identifier" 49 | ) 50 | 51 | // Client is a simple client for the Lambda Extensions API 52 | type Client struct { 53 | baseURL string 54 | httpClient *http.Client 55 | extensionID string 56 | } 57 | 58 | // NewClient returns a Lambda Extensions API client 59 | func NewClient(awsLambdaRuntimeAPI string) *Client { 60 | baseURL := fmt.Sprintf("http://%s/2020-01-01/extension", awsLambdaRuntimeAPI) 61 | return &Client{ 62 | baseURL: baseURL, 63 | httpClient: &http.Client{}, 64 | } 65 | } 66 | 67 | // Register will register the extension with the Extensions API 68 | func (e *Client) Register(ctx context.Context, filename string) (*RegisterResponse, error) { 69 | const action = "/register" 70 | url := e.baseURL + action 71 | 72 | reqBody, err := json.Marshal(map[string]interface{}{ 73 | "events": []EventType{Invoke, Shutdown}, 74 | }) 75 | if err != nil { 76 | return nil, err 77 | } 78 | httpReq, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(reqBody)) 79 | if err != nil { 80 | return nil, err 81 | } 82 | httpReq.Header.Set(extensionNameHeader, filename) 83 | httpRes, err := e.httpClient.Do(httpReq) 84 | if err != nil { 85 | return nil, err 86 | } 87 | if httpRes.StatusCode != 200 { 88 | return nil, fmt.Errorf("request failed with status %s", httpRes.Status) 89 | } 90 | defer httpRes.Body.Close() 91 | body, err := ioutil.ReadAll(httpRes.Body) 92 | if err != nil { 93 | return nil, err 94 | } 95 | res := RegisterResponse{} 96 | err = json.Unmarshal(body, &res) 97 | if err != nil { 98 | return nil, err 99 | } 100 | e.extensionID = httpRes.Header.Get(extensionIdentiferHeader) 101 | return &res, nil 102 | } 103 | 104 | // NextEvent blocks while long polling for the next lambda invoke or shutdown 105 | func (e *Client) NextEvent(ctx context.Context) (*NextEventResponse, error) { 106 | const action = "/event/next" 107 | url := e.baseURL + action 108 | 109 | httpReq, err := http.NewRequestWithContext(ctx, "GET", url, nil) 110 | if err != nil { 111 | return nil, err 112 | } 113 | httpReq.Header.Set(extensionIdentiferHeader, e.extensionID) 114 | httpRes, err := e.httpClient.Do(httpReq) 115 | if err != nil { 116 | return nil, err 117 | } 118 | if httpRes.StatusCode != 200 { 119 | return nil, fmt.Errorf("request failed with status %s", httpRes.Status) 120 | } 121 | defer httpRes.Body.Close() 122 | body, err := ioutil.ReadAll(httpRes.Body) 123 | if err != nil { 124 | return nil, err 125 | } 126 | res := NextEventResponse{} 127 | err = json.Unmarshal(body, &res) 128 | if err != nil { 129 | return nil, err 130 | } 131 | return &res, nil 132 | } 133 | -------------------------------------------------------------------------------- /layer/layer.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "context" 5 | "github.com/aidansteele/postinvoke" 6 | "github.com/aidansteele/postinvoke/layer/extension" 7 | "net/http" 8 | "os" 9 | "path/filepath" 10 | ) 11 | 12 | func main() { 13 | ctx := context.Background() 14 | 15 | c := extension.NewClient(os.Getenv("AWS_LAMBDA_RUNTIME_API")) 16 | name := filepath.Base(os.Args[0]) // extension name has to match the filename 17 | 18 | _, err := c.Register(ctx, name) 19 | if err != nil { 20 | panic(err) 21 | } 22 | 23 | ch := make(chan struct{}) 24 | 25 | http.HandleFunc("/done", func(w http.ResponseWriter, r *http.Request) { 26 | ch <- struct{}{} 27 | }) 28 | 29 | http.HandleFunc("/check", func(w http.ResponseWriter, r *http.Request) { 30 | // used by sdk to check if the extension is running 31 | w.WriteHeader(200) 32 | }) 33 | 34 | go http.ListenAndServe(postinvoke.Address, nil) 35 | loop(ctx, c, ch) 36 | } 37 | 38 | func loop(ctx context.Context, c *extension.Client, ch chan struct{}) { 39 | for { 40 | select { 41 | case <-ctx.Done(): 42 | return 43 | default: 44 | _, err := c.NextEvent(ctx) 45 | if err != nil { 46 | panic(err) 47 | } 48 | 49 | // wait for sdk to say its done 50 | <-ch 51 | } 52 | } 53 | } 54 | -------------------------------------------------------------------------------- /postinvoke.go: -------------------------------------------------------------------------------- 1 | package postinvoke 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "github.com/aws/aws-lambda-go/lambda" 7 | "net/http" 8 | "os/signal" 9 | "syscall" 10 | ) 11 | 12 | const Address = "127.0.0.1:1339" 13 | 14 | type handler struct { 15 | inner lambda.Handler 16 | client *http.Client 17 | stack []func() 18 | } 19 | 20 | type contextKey string 21 | 22 | const contextKeyWrapper = contextKey("contextKeyWrapper") 23 | 24 | type Options struct { 25 | Client *http.Client 26 | } 27 | 28 | func WrapHandler(inner lambda.Handler, opts *Options) lambda.Handler { 29 | if opts == nil { 30 | opts = &Options{} 31 | } 32 | 33 | c := opts.Client 34 | if c == nil { 35 | c = http.DefaultClient 36 | } 37 | 38 | h := &handler{inner: inner, client: c} 39 | 40 | get, err := h.client.Get(fmt.Sprintf("http://%s/check", Address)) 41 | if err != nil || get.StatusCode != 200 { 42 | panic("postinvoke: unable to connect to extension - did you forget to add the lambda layer?") 43 | } 44 | 45 | return h 46 | } 47 | 48 | func (h *handler) Invoke(ctx context.Context, payload []byte) ([]byte, error) { 49 | ctx = context.WithValue(ctx, contextKeyWrapper, h) 50 | defer func() { 51 | go h.after() 52 | }() 53 | 54 | return h.inner.Invoke(ctx, payload) 55 | } 56 | 57 | func (h *handler) after() { 58 | for i := len(h.stack) - 1; i >= 0; i-- { 59 | fn := h.stack[i] 60 | fn() 61 | } 62 | 63 | // empty the stack, keep the memory 64 | h.stack = h.stack[:0] 65 | 66 | _, err := h.client.Post(fmt.Sprintf("http://%s/done", Address), "", nil) 67 | if err != nil { 68 | panic(err) 69 | } 70 | } 71 | 72 | func Run(ctx context.Context, fn func()) { 73 | wrapper, ok := ctx.Value(contextKeyWrapper).(*handler) 74 | if !ok { 75 | panic("postinvoke: context unavailable - did you forget to wrap your lambda handler?") 76 | } 77 | 78 | wrapper.stack = append(wrapper.stack, fn) 79 | } 80 | 81 | func Shutdown(fn func()) { 82 | go func() { 83 | ctx := context.Background() 84 | ctx, _ = signal.NotifyContext(ctx, syscall.SIGTERM) 85 | <-ctx.Done() 86 | fn() 87 | }() 88 | } 89 | --------------------------------------------------------------------------------