├── .gitignore
├── Dockerfile
├── LICENSE
├── Makefile
├── README.md
├── app
│   └── main.go
├── scripts
│   └── build.sh
└── terraform
    ├── main.tf
    ├── output.tf
    └── variables.tf

/.gitignore:
--------------------------------------------------------------------------------
# Binaries for programs and plugins
*.exe
*.dll
*.so
*.dylib

# Test binary, built with `go test -c`
*.test

# Output of the go coverage tool, specifically when used with LiteIDE
*.out

# Project-local glide cache, RE: https://github.com/Masterminds/glide/issues/736
.glide/

*.h5
*.pb

.DS_Store

.terraform
*.tfstate*
*.zip
--------------------------------------------------------------------------------
/Dockerfile:
--------------------------------------------------------------------------------
FROM tensorflow/tensorflow

RUN mkdir /build

# Install Git and zip.
RUN apt-get update && apt-get install -y git zip

# Install Go.
RUN curl -O https://dl.google.com/go/go1.10.linux-amd64.tar.gz
RUN tar -xzvf go1.10.linux-amd64.tar.gz
RUN mv go /usr/local
ENV GOPATH /go
ENV PATH $GOPATH/bin:/usr/local/go/bin:$PATH

# Install the TensorFlow C library and the Go bindings.
RUN curl -L "https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-cpu-linux-x86_64-1.6.0.tar.gz" | tar -C /build -xz
ENV LIBRARY_PATH=$LIBRARY_PATH:/build/lib
ENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/build/lib
RUN go get github.com/tensorflow/tensorflow/tensorflow/go

WORKDIR $GOPATH/src/app

# Download the pre-trained Inception model.
RUN curl -O "https://storage.googleapis.com/download.tensorflow.org/models/inception5h.zip"
RUN unzip inception5h.zip

COPY app/ ./

# Build the application.
RUN mv tensorflow_inception_graph.pb /build/graph.pb
RUN mv imagenet_comp_graph_label_strings.txt /build/labels.txt
RUN chmod 644 /build/graph.pb /build/labels.txt
RUN go get
RUN go build -o /build/main

CMD ["zip", "-r", "/build.zip", "/build/"]
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2018 Steve Kaliski

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
build:
	./scripts/build.sh

.PHONY: build
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# infer-lambda-api

An example image recognition API using [infer](https://github.com/sjkaliski/infer) and AWS [Lambda](https://aws.amazon.com/lambda/) + [API Gateway](https://aws.amazon.com/api-gateway/).

## Overview

This project is an example of how to quickly deploy a Go image recognition API using Terraform and AWS. It uses [infer](https://github.com/sjkaliski/infer), a Go package that provides abstractions for interacting with TensorFlow models.

## Setup

### Build

Lambda functions require a zipped application containing the executable and any supporting assets. A `Dockerfile` is provided to make the build step easy.

```
$ make build
```

Once this completes, a `deployment.zip` file is placed in the root of the project. The zip includes:

- Lambda function as a Go binary
- TensorFlow shared libraries
- Model file (`graph.pb`)
- Labels file (`labels.txt`)

### Deploy

Once the application has been built, it's ready to be deployed. The application runs as a Lambda function, with API Gateway providing HTTP access. [Terraform](https://www.terraform.io/) configuration is included to make setup easy.

For an in-depth tutorial on using Terraform with Lambda and API Gateway, see [here](https://www.terraform.io/docs/providers/aws/guides/serverless-with-aws-lambda-and-api-gateway.html).

First, make sure you have a valid AWS account and have installed Terraform.

- [AWS Account](https://aws.amazon.com/)
- [Terraform](https://www.terraform.io/)

From there, initialize Terraform and apply the configuration. *Note: this will incur AWS charges.*

```
$ cd terraform
$ terraform init
$ terraform apply
```
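
The `version` variable (see `terraform/variables.tf`) has no default, so Terraform will prompt for a value; any semver string works. It can also be supplied inline; the value below is only an example:

```
$ terraform apply -var 'version=1.0.0'
```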

This creates the following:

- S3 bucket and object containing the deployment zip
- Lambda function
- IAM role
- API Gateway REST API
- API Gateway methods + integrations with Lambda

Together, this results in a Lambda function that is accessible over HTTP via API Gateway.

### Teardown

```
$ terraform destroy
```

A version number will be required. Any semver value works.

## Usage

Locate the API Gateway endpoint.

```
$ cd terraform
$ terraform output endpoint
```

Execute a request; an example using `curl` is shown below.

```
$ curl --upload-file "/path/to/img.png" "ENDPOINT" -H "Content-Type: image/png"
[
  {
    "Class": "thing",
    "Score": 0.875
  },
  ...
]
```
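
The response is plain JSON, so it can also be consumed programmatically. Below is a minimal Go client sketch: the endpoint URL and image path are placeholders, and the `Prediction` struct simply mirrors the `Class`/`Score` fields shown in the example response above.

```
package main

import (
	"bytes"
	"encoding/json"
	"fmt"
	"io/ioutil"
	"net/http"
)

// Prediction mirrors the fields returned by the API.
type Prediction struct {
	Class string
	Score float64
}

func main() {
	// Read the image to classify.
	img, err := ioutil.ReadFile("/path/to/img.png")
	if err != nil {
		panic(err)
	}

	// Send the raw image bytes to the API Gateway endpoint.
	// (curl's --upload-file issues a PUT as well.)
	req, err := http.NewRequest(http.MethodPut, "https://ENDPOINT", bytes.NewReader(img))
	if err != nil {
		panic(err)
	}
	req.Header.Set("Content-Type", "image/png")

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()

	// Decode the list of predictions.
	var predictions []Prediction
	if err := json.NewDecoder(resp.Body).Decode(&predictions); err != nil {
		panic(err)
	}

	for _, p := range predictions {
		fmt.Printf("%s: %.3f\n", p.Class, p.Score)
	}
}
```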

## Notes

- The Lambda function memory size is set to 512 MB. Any lower and the TensorFlow model could not execute.
- Startup time on the first request (a cold start) can often be slow.
--------------------------------------------------------------------------------
/app/main.go:
--------------------------------------------------------------------------------
package main

import (
	"encoding/base64"
	"encoding/json"
	"errors"
	"io/ioutil"
	"net/http"
	"os"
	"strings"

	"github.com/aws/aws-lambda-go/events"
	"github.com/aws/aws-lambda-go/lambda"
	"github.com/sjkaliski/infer"
	tf "github.com/tensorflow/tensorflow/tensorflow/go"
)

var (
	m *infer.Model
)

var (
	errInvalidImage = errors.New("invalid image supplied")
)

// handler decodes the base64-encoded request body, runs inference and
// returns the top predictions as JSON.
func handler(request events.APIGatewayProxyRequest) (events.APIGatewayProxyResponse, error) {
	if len(request.Body) < 1 || !request.IsBase64Encoded {
		return events.APIGatewayProxyResponse{}, errInvalidImage
	}

	reader := base64.NewDecoder(base64.StdEncoding, strings.NewReader(request.Body))
	opts := &infer.ImageOptions{
		IsGray: false,
	}

	predictions, err := m.FromImage(reader, opts)
	if err != nil {
		return events.APIGatewayProxyResponse{}, errInvalidImage
	}

	// Return at most the top ten predictions.
	if len(predictions) > 10 {
		predictions = predictions[:10]
	}

	data, err := json.Marshal(predictions)
	if err != nil {
		return events.APIGatewayProxyResponse{}, err
	}

	return events.APIGatewayProxyResponse{
		Body:       string(data),
		StatusCode: http.StatusOK,
	}, nil
}

// init loads the model graph and labels referenced by the MODEL and LABELS
// environment variables and initializes the infer model.
func init() {
	model, err := ioutil.ReadFile(os.Getenv("MODEL"))
	if err != nil {
		panic(err)
	}

	labelFile, err := ioutil.ReadFile(os.Getenv("LABELS"))
	if err != nil {
		panic(err)
	}
	labels := strings.Split(string(labelFile), "\n")

	graph := tf.NewGraph()
	err = graph.Import(model, "")
	if err != nil {
		panic(err)
	}

	m, err = infer.New(&infer.Model{
		Graph:   graph,
		Classes: labels,
		Input: &infer.Input{
			Key:        "input",
			Dimensions: []int32{224, 224},
		},
		Output: &infer.Output{
			Key:        "output",
			Dimensions: [][]float32{},
		},
	})
	if err != nil {
		panic(err)
	}
}

func main() {
	lambda.Start(handler)
}
--------------------------------------------------------------------------------
/scripts/build.sh:
--------------------------------------------------------------------------------
#!/bin/sh
set -e

# Build the image, run a container in the background (its CMD zips the build
# artifacts into /build.zip), wait for it to finish, then copy the archive out.
docker build -t sjkaliski/infer-lambda .
ID=$(docker run -d sjkaliski/infer-lambda)
docker wait "$ID"
docker cp "$ID":/build.zip deployment.zip
--------------------------------------------------------------------------------
/terraform/main.tf:
--------------------------------------------------------------------------------
provider "aws" {
  region = "${var.region}"
}

# S3 bucket and object. This is where the application package is stored.
resource "aws_s3_bucket" "lambda" {
  bucket = "infer-lambda"
  acl    = "private"

  force_destroy = true
}

resource "aws_s3_bucket_object" "lambda" {
  bucket = "${aws_s3_bucket.lambda.id}"
  key    = "${var.version}/deployment.zip"
  source = "../deployment.zip"
  etag   = "${md5(file("../deployment.zip"))}"
}

# Lambda function, using the S3 object above as its source.
resource "aws_lambda_function" "main" {
  function_name = "InferLambdaExample"

  s3_bucket = "${aws_s3_bucket_object.lambda.bucket}"
  s3_key    = "${aws_s3_bucket_object.lambda.key}"

  handler = "build/main"
  runtime = "go1.x"

  role = "${aws_iam_role.lambda.arn}"

  memory_size = 512

  environment {
    variables = {
      # /var/task is where Lambda unpacks the deployment zip, so these paths
      # point at the files placed under build/ by the Dockerfile.
      LIBRARY_PATH    = "$LIBRARY_PATH:/var/task/build/lib"
      LD_LIBRARY_PATH = "$LD_LIBRARY_PATH:/var/task/build/lib"
      MODEL           = "/var/task/build/graph.pb"
      LABELS          = "/var/task/build/labels.txt"
    }
  }
}

# IAM role to be used across resources.
resource "aws_iam_role" "lambda" {
  name               = "infer_lambda_example"
  assume_role_policy = "${data.aws_iam_policy_document.lambda.json}"
}

data "aws_iam_policy_document" "lambda" {
  statement {
    actions = [
      "sts:AssumeRole",
    ]

    principals {
      type = "Service"

      identifiers = [
        "lambda.amazonaws.com",
      ]
    }
  }
}

# Create a REST API.
resource "aws_api_gateway_rest_api" "main" {
  name = "InferLambdaExample"

  # Treat uploaded images as binary so they reach Lambda base64-encoded.
  binary_media_types = [
    "image/png",
    "image/jpeg",
  ]
}

resource "aws_api_gateway_resource" "proxy" {
  rest_api_id = "${aws_api_gateway_rest_api.main.id}"
  parent_id   = "${aws_api_gateway_rest_api.main.root_resource_id}"
  path_part   = "{proxy+}"
}

resource "aws_api_gateway_method" "proxy" {
  rest_api_id   = "${aws_api_gateway_rest_api.main.id}"
  resource_id   = "${aws_api_gateway_resource.proxy.id}"
  http_method   = "ANY"
  authorization = "NONE"
}

resource "aws_api_gateway_integration" "lambda" {
  rest_api_id = "${aws_api_gateway_rest_api.main.id}"
  resource_id = "${aws_api_gateway_method.proxy.resource_id}"
  http_method = "${aws_api_gateway_method.proxy.http_method}"

  integration_http_method = "POST"
  type                    = "AWS_PROXY"
  uri                     = "${aws_lambda_function.main.invoke_arn}"
}

resource "aws_api_gateway_method" "proxy_root" {
  rest_api_id   = "${aws_api_gateway_rest_api.main.id}"
  resource_id   = "${aws_api_gateway_rest_api.main.root_resource_id}"
  http_method   = "ANY"
  authorization = "NONE"
}

resource "aws_api_gateway_integration" "lambda_root" {
  rest_api_id = "${aws_api_gateway_rest_api.main.id}"
  resource_id = "${aws_api_gateway_method.proxy_root.resource_id}"
  http_method = "${aws_api_gateway_method.proxy_root.http_method}"

  integration_http_method = "POST"
  type                    = "AWS_PROXY"
  uri                     = "${aws_lambda_function.main.invoke_arn}"
}

resource "aws_api_gateway_deployment" "example" {
  depends_on = [
    "aws_api_gateway_integration.lambda",
    "aws_api_gateway_integration.lambda_root",
  ]

  rest_api_id = "${aws_api_gateway_rest_api.main.id}"
  stage_name  = "test"
}

resource "aws_lambda_permission" "apigw" {
  statement_id  = "AllowAPIGatewayInvoke"
  action        = "lambda:InvokeFunction"
  function_name = "${aws_lambda_function.main.arn}"
  principal     = "apigateway.amazonaws.com"
  source_arn    = "${aws_api_gateway_deployment.example.execution_arn}/*/*"
}
--------------------------------------------------------------------------------
"${aws_api_gateway_deployment.example.invoke_url}" 3 | } 4 | 5 | output "version" { 6 | value = "${var.version}" 7 | } 8 | -------------------------------------------------------------------------------- /terraform/variables.tf: -------------------------------------------------------------------------------- 1 | variable "region" { 2 | description = "The AWS region to deploy into." 3 | default = "us-east-1" 4 | } 5 | 6 | variable "version" { 7 | description = "The version (semver) of the function." 8 | } 9 | --------------------------------------------------------------------------------