├── .dockerignore ├── .editorconfig ├── .gitignore ├── Dockerfile ├── LICENSE ├── README.md ├── cover.jpg ├── go.mod ├── go.sum ├── image_tensor.go ├── main.go └── utilities.go /.dockerignore: -------------------------------------------------------------------------------- 1 | *.md 2 | *.jpg 3 | .directory 4 | .DS_Store 5 | desktop.ini 6 | LICENSE 7 | -------------------------------------------------------------------------------- /.editorconfig: -------------------------------------------------------------------------------- 1 | root = true 2 | 3 | [*] 4 | charset = utf-8 5 | indent_style = space 6 | indent_size = 2 7 | insert_final_newline = true 8 | trim_trailing_whitespace = true 9 | 10 | [*.md] 11 | max_line_length = off 12 | trim_trailing_whitespace = false 13 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | vendor 2 | .DS_Store 3 | desktop.ini 4 | .directory -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM tensorflow/tensorflow:1.12.0 2 | 3 | # Install TensorFlow C library 4 | RUN curl -L \ 5 | "https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-cpu-linux-x86_64-1.12.0.tar.gz" | \ 6 | tar -C "/usr/local" -xz 7 | RUN ldconfig 8 | # Hide some warnings 9 | ENV TF_CPP_MIN_LOG_LEVEL 2 10 | 11 | # Install Go (https://github.com/docker-library/golang/blob/221ee92559f2963c1fe55646d3516f5b8f4c91a4/1.9/stretch/Dockerfile) 12 | RUN apt-get update && apt-get install -y --no-install-recommends \ 13 | g++ \ 14 | gcc \ 15 | libc6-dev \ 16 | make \ 17 | pkg-config \ 18 | wget \ 19 | curl \ 20 | git \ 21 | && rm -rf /var/lib/apt/lists/* 22 | 23 | ENV GOLANG_VERSION 1.11.2 24 | 25 | RUN set -eux; \ 26 | \ 27 | # this "case" statement is generated via "update.sh" 28 | dpkgArch="$(dpkg 
--print-architecture)"; \ 29 | case "${dpkgArch##*-}" in \ 30 | amd64) goRelArch='linux-amd64'; goRelSha256='1dfe664fa3d8ad714bbd15a36627992effd150ddabd7523931f077b3926d736d' ;; \ 31 | armhf) goRelArch='linux-armv6l'; goRelSha256='b9d16a8eb1f7b8fdadd27232f6300aa8b4427e5e4cb148c4be4089db8fb56429' ;; \ 32 | arm64) goRelArch='linux-arm64'; goRelSha256='98a42b9b8d3bacbcc6351a1e39af52eff582d0bc3ac804cd5a97ce497dd84026' ;; \ 33 | i386) goRelArch='linux-386'; goRelSha256='e74f2f37b43b9b1bcf18008a11e0efb8921b41dff399a4f48ac09a4f25729881' ;; \ 34 | ppc64el) goRelArch='linux-ppc64le'; goRelSha256='23291935a299fdfde4b6a988ce3faa0c7a498aab6d56bbafbf1e7476468529a3' ;; \ 35 | s390x) goRelArch='linux-s390x'; goRelSha256='a67ef820ef8cfecc8d68c69dd5bf513aaf647c09b6605570af425bf5fe8a32f0' ;; \ 36 | *) goRelArch='src'; goRelSha256='042fba357210816160341f1002440550e952eb12678f7c9e7e9d389437942550'; \ 37 | echo >&2; echo >&2 "warning: current architecture ($dpkgArch) does not have a corresponding Go binary release; will be building from source"; echo >&2 ;; \ 38 | esac; \ 39 | \ 40 | url="https://golang.org/dl/go${GOLANG_VERSION}.${goRelArch}.tar.gz"; \ 41 | wget -O go.tgz "$url"; \ 42 | echo "${goRelSha256} *go.tgz" | sha256sum -c -; \ 43 | tar -C /usr/local -xzf go.tgz; \ 44 | rm go.tgz; \ 45 | \ 46 | if [ "$goRelArch" = 'src' ]; then \ 47 | echo >&2; \ 48 | echo >&2 'error: UNIMPLEMENTED'; \ 49 | echo >&2 'TODO install golang-any from jessie-backports for GOROOT_BOOTSTRAP (and uninstall after build)'; \ 50 | echo >&2; \ 51 | exit 1; \ 52 | fi; \ 53 | \ 54 | export PATH="/usr/local/go/bin:$PATH"; \ 55 | go version 56 | 57 | ENV GOPATH /go 58 | ENV GO111MODULE=on 59 | ENV PATH $GOPATH/bin:/usr/local/go/bin:$PATH 60 | 61 | RUN mkdir -p "$GOPATH/src" "$GOPATH/bin" && chmod -R 777 "$GOPATH" 62 | 63 | # Download the Inception model (inception5h) 64 | RUN mkdir -p /model && \ 65 | wget "https://storage.googleapis.com/download.tensorflow.org/models/inception5h.zip" -O /model/inception.zip && \ 66 | unzip
/model/inception.zip -d /model && \ 67 | chmod -R 777 /model 68 | 69 | # Set up project directory 70 | WORKDIR "/go/src/github.com/tinrab/go-tensorflow-image-recognition" 71 | COPY . . 72 | 73 | # Install the app 74 | RUN go build -o /usr/bin/app . 75 | 76 | # Run the app 77 | CMD [ "app" ] 78 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2017 Tin Rabzelj 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Image Recognition API in Go using TensorFlow 2 | 3 |

4 | 5 |

6 | 7 | This is the underlying code for the article [Build an Image Recognition API with Go and TensorFlow](https://outcrawl.com/image-recognition-api-go-tensorflow). 8 | 9 | ## Running the service 10 | 11 | Build the image. 12 | 13 | ``` 14 | $ docker build -t localhost/recognition . 15 | ``` 16 | 17 | Run the service in a container. 18 | 19 | ``` 20 | $ docker run -p 8080:8080 --rm localhost/recognition 21 | ``` 22 | 23 | Call the service. 24 | 25 | ``` 26 | $ curl localhost:8080/recognize -F 'image=@./cat.jpg' 27 | { 28 | "filename": "cat.jpg", 29 | "labels": [ 30 | { "label": "tabby", "probability": 0.45087516 }, 31 | { "label": "Egyptian cat", "probability": 0.26096493 }, 32 | { "label": "tiger cat", "probability": 0.23208225 }, 33 | { "label": "lynx", "probability": 0.050698064 }, 34 | { "label": "grey fox", "probability": 0.0019019963 } 35 | ] 36 | } 37 | ``` 38 | -------------------------------------------------------------------------------- /cover.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tinrab/go-tensorflow-image-recognition/fe7a3c7b914cd0cf54b0927760145601a76507fd/cover.jpg -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/tinrab/go-tensorflow-image-recognition 2 | 3 | require ( 4 | github.com/julienschmidt/httprouter v1.2.0 5 | github.com/tensorflow/tensorflow v1.12.0 6 | ) 7 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | github.com/julienschmidt/httprouter v1.2.0 h1:TDTW5Yz1mjftljbcKqRcrYhd4XeOoI98t+9HbQbYf7g= 2 | github.com/julienschmidt/httprouter v1.2.0/go.mod h1:SYymIcj16QtmaHHD7aYtjjsJG7VTCxuUUipMqKk8s4w= 3 | github.com/tensorflow/tensorflow v1.12.0 h1:fT4okrN4BkpgotWmDwS56wM6BdkRpTL0lLMzvkM+bLo= 4 | 
github.com/tensorflow/tensorflow v1.12.0/go.mod h1:itOSERT4trABok4UOoG+X4BoKds9F3rIsySdn+Lvu90=
5 |
--------------------------------------------------------------------------------
/image_tensor.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | 	"bytes"
5 |
6 | 	tf "github.com/tensorflow/tensorflow/tensorflow/go"
7 | 	"github.com/tensorflow/tensorflow/tensorflow/go/op"
8 | )
9 |
10 | // makeTensorFromImage decodes the encoded image held in imageBuffer and
11 | // returns it as a normalized float tensor that can be fed to the model.
12 | // imageFormat selects the decoder: "png" uses the PNG decoder and any other
13 | // value falls back to the JPEG decoder (see makeTransformImageGraph).
14 | func makeTensorFromImage(imageBuffer *bytes.Buffer, imageFormat string) (*tf.Tensor, error) {
15 | 	// The transform graph consumes the raw encoded bytes as a string tensor.
16 | 	tensor, err := tf.NewTensor(imageBuffer.String())
17 | 	if err != nil {
18 | 		return nil, err
19 | 	}
20 | 	graph, input, output, err := makeTransformImageGraph(imageFormat)
21 | 	if err != nil {
22 | 		return nil, err
23 | 	}
24 | 	session, err := tf.NewSession(graph, nil)
25 | 	if err != nil {
26 | 		return nil, err
27 | 	}
28 | 	defer session.Close()
29 | 	normalized, err := session.Run(
30 | 		map[tf.Output]*tf.Tensor{input: tensor},
31 | 		[]tf.Output{output},
32 | 		nil)
33 | 	if err != nil {
34 | 		return nil, err
35 | 	}
36 | 	return normalized[0], nil
37 | }
38 |
39 | // makeTransformImageGraph builds a graph that decodes an encoded image,
40 | // resizes it to HxW with bilinear interpolation and normalizes every pixel
41 | // as (value-Mean)/Scale. It returns the graph, its string input placeholder
42 | // and its float output, which holds a batch containing a single image.
43 | func makeTransformImageGraph(imageFormat string) (graph *tf.Graph, input, output tf.Output, err error) {
44 | 	const (
45 | 		H, W  = 224, 224
46 | 		Mean  = float32(117)
47 | 		Scale = float32(1)
48 | 	)
49 | 	s := op.NewScope()
50 | 	input = op.Placeholder(s, tf.String)
51 | 	// Decode PNG or JPEG into 3 channels
52 | 	var decode tf.Output
53 | 	if imageFormat == "png" {
54 | 		decode = op.DecodePng(s, input, op.DecodePngChannels(3))
55 | 	} else {
56 | 		decode = op.DecodeJpeg(s, input, op.DecodeJpegChannels(3))
57 | 	}
58 | 	// Div and Sub perform (value-Mean)/Scale for each pixel
59 | 	output = op.Div(s,
60 | 		op.Sub(s,
61 | 			// Resize to 224x224 with bilinear interpolation
62 | 			op.ResizeBilinear(s,
63 | 				// Create a batch containing a single image
64 | 				op.ExpandDims(s,
65 | 					// Use decoded pixel values
66 | 					op.Cast(s, decode, tf.Float),
67 | 					op.Const(s.SubScope("make_batch"), int32(0))),
68 | 				op.Const(s.SubScope("size"), []int32{H, W})),
69 | 			op.Const(s.SubScope("mean"), Mean)),
70 | 		op.Const(s.SubScope("scale"), Scale))
71 | 	graph, err = s.Finalize()
72 | 	return graph, input, output, err
73 | }
74 |
--------------------------------------------------------------------------------
/main.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | 	"bufio"
5 | 	"bytes"
6 | 	"fmt"
7 | 	"io"
8 | 	"io/ioutil"
9 | 	"log"
10 | 	"net/http"
11 | 	"os"
12 | 	"path/filepath"
13 | 	"sort"
14 | 	"strings"
15 | 	"time"
16 |
17 | 	"github.com/julienschmidt/httprouter"
18 | 	tf "github.com/tensorflow/tensorflow/tensorflow/go"
19 | )
20 |
21 | // ClassifyResult is the JSON body returned by the /recognize endpoint.
22 | type ClassifyResult struct {
23 | 	Filename string        `json:"filename"`
24 | 	Labels   []LabelResult `json:"labels"`
25 | }
26 |
27 | // LabelResult pairs a label with the probability the model assigned to it.
28 | type LabelResult struct {
29 | 	Label       string  `json:"label"`
30 | 	Probability float32 `json:"probability"`
31 | }
32 |
33 | // maxBestLabels is the maximum number of labels returned for an image.
34 | const maxBestLabels = 5
35 |
36 | var (
37 | 	graphModel   *tf.Graph
38 | 	sessionModel *tf.Session
39 | 	labels       []string
40 | )
41 |
42 | func main() {
43 | 	if err := loadModel(); err != nil {
44 | 		log.Fatal(err)
45 | 	}
46 |
47 | 	r := httprouter.New()
48 | 	r.POST("/recognize", recognizeHandler)
49 |
50 | 	// Explicit timeouts stop slow or stalled clients from holding
51 | 	// connections (and upload buffers) open indefinitely.
52 | 	server := &http.Server{
53 | 		Addr:         ":8080",
54 | 		Handler:      r,
55 | 		ReadTimeout:  30 * time.Second,
56 | 		WriteTimeout: 60 * time.Second,
57 | 	}
58 | 	fmt.Println("Listening on port 8080...")
59 | 	log.Fatal(server.ListenAndServe())
60 | }
61 |
62 | // loadModel reads the Inception graph and its label list from /model and
63 | // initializes graphModel, sessionModel and labels.
64 | func loadModel() error {
65 | 	// Load inception model
66 | 	model, err := ioutil.ReadFile("/model/tensorflow_inception_graph.pb")
67 | 	if err != nil {
68 | 		return err
69 | 	}
70 | 	graphModel = tf.NewGraph()
71 | 	if err := graphModel.Import(model, ""); err != nil {
72 | 		return err
73 | 	}
74 |
75 | 	// Plain assignment (not :=) so the package-level sessionModel is set.
76 | 	sessionModel, err = tf.NewSession(graphModel, nil)
77 | 	if err != nil {
78 | 		return err
79 | 	}
80 |
81 | 	// Load labels; they are separated by newlines and indexed like the
82 | 	// model's output probabilities (see findBestLabels).
83 | 	labelsFile, err := os.Open("/model/imagenet_comp_graph_label_strings.txt")
84 | 	if err != nil {
85 | 		return err
86 | 	}
87 | 	defer labelsFile.Close()
88 | 	scanner := bufio.NewScanner(labelsFile)
89 | 	for scanner.Scan() {
90 | 		labels = append(labels, scanner.Text())
91 | 	}
92 | 	return scanner.Err()
93 | }
94 |
95 | // recognizeHandler handles POST /recognize. It reads the multipart file field
96 | // "image", classifies it with the model and responds with the most probable
97 | // labels as JSON.
98 | func recognizeHandler(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
99 | 	// Read image
100 | 	imageFile, header, err := r.FormFile("image")
101 | 	if err != nil {
102 | 		// header is nil on error, so it must not be used before this check.
103 | 		responseError(w, "Could not read image", http.StatusBadRequest)
104 | 		return
105 | 	}
106 | 	defer imageFile.Close()
107 | 	var imageBuffer bytes.Buffer
108 | 	// Copy image data to a buffer
109 | 	if _, err := io.Copy(&imageBuffer, imageFile); err != nil {
110 | 		responseError(w, "Could not read image", http.StatusBadRequest)
111 | 		return
112 | 	}
113 |
114 | 	// Make tensor; the decoder is chosen from the file extension.
115 | 	tensor, err := makeTensorFromImage(&imageBuffer, imageFormatFromFilename(header.Filename))
116 | 	if err != nil {
117 | 		responseError(w, "Invalid image", http.StatusBadRequest)
118 | 		return
119 | 	}
120 |
121 | 	// Run inference
122 | 	output, err := sessionModel.Run(
123 | 		map[tf.Output]*tf.Tensor{
124 | 			graphModel.Operation("input").Output(0): tensor,
125 | 		},
126 | 		[]tf.Output{
127 | 			graphModel.Operation("output").Output(0),
128 | 		},
129 | 		nil)
130 | 	if err != nil {
131 | 		responseError(w, "Could not run inference", http.StatusInternalServerError)
132 | 		return
133 | 	}
134 | 	// Expect one row of probabilities for the single-image batch; anything
135 | 	// else would make the type assertion or indexing below panic.
136 | 	probabilities, ok := output[0].Value().([][]float32)
137 | 	if !ok || len(probabilities) == 0 {
138 | 		responseError(w, "Could not run inference", http.StatusInternalServerError)
139 | 		return
140 | 	}
141 |
142 | 	// Return best labels
143 | 	responseJSON(w, ClassifyResult{
144 | 		Filename: header.Filename,
145 | 		Labels:   findBestLabels(probabilities[0]),
146 | 	})
147 | }
148 |
149 | // imageFormatFromFilename returns the lower-cased extension of filename
150 | // without its leading dot (e.g. "cat.PNG" -> "png"), or "" if it has none.
151 | func imageFormatFromFilename(filename string) string {
152 | 	return strings.ToLower(strings.TrimPrefix(filepath.Ext(filename), "."))
153 | }
154 |
155 | // ByProbability sorts label results by descending probability.
156 | type ByProbability []LabelResult
157 |
158 | func (a ByProbability) Len() int           { return len(a) }
159 | func (a ByProbability) Swap(i, j int)      { a[i], a[j] = a[j], a[i] }
160 | func (a ByProbability) Less(i, j int) bool { return a[i].Probability > a[j].Probability }
161 |
162 | // findBestLabels pairs each probability with the label at the same index and
163 | // returns at most maxBestLabels of them, most probable first. Probabilities
164 | // without a corresponding label are ignored.
165 | func findBestLabels(probabilities []float32) []LabelResult {
166 | 	n := len(probabilities)
167 | 	if n > len(labels) {
168 | 		n = len(labels)
169 | 	}
170 | 	// Make a list of label/probability pairs
171 | 	resultLabels := make([]LabelResult, 0, n)
172 | 	for i := 0; i < n; i++ {
173 | 		resultLabels = append(resultLabels, LabelResult{Label: labels[i], Probability: probabilities[i]})
174 | 	}
175 | 	// Sort by probability
176 | 	sort.Sort(ByProbability(resultLabels))
177 | 	// Return the top labels; with fewer candidates, return them all instead
178 | 	// of panicking on an out-of-range slice.
179 | 	if len(resultLabels) > maxBestLabels {
180 | 		resultLabels = resultLabels[:maxBestLabels]
181 | 	}
182 | 	return resultLabels
183 | }
184 |
--------------------------------------------------------------------------------
/utilities.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | 	"encoding/json"
5 | 	"log"
6 | 	"net/http"
7 | )
8 |
9 | // responseError writes {"error": message} as JSON with the given status code.
10 | func responseError(w http.ResponseWriter, message string, code int) {
11 | 	w.Header().Set("Content-Type", "application/json")
12 | 	w.WriteHeader(code)
13 | 	if err := json.NewEncoder(w).Encode(map[string]string{"error": message}); err != nil {
14 | 		// The status line is already sent, so the failure can only be logged.
15 | 		log.Printf("writing error response: %v", err)
16 | 	}
17 | }
18 |
19 | // responseJSON writes data as JSON with the default 200 OK status.
20 | func responseJSON(w http.ResponseWriter, data interface{}) {
21 | 	w.Header().Set("Content-Type", "application/json")
22 | 	if err := json.NewEncoder(w).Encode(data); err != nil {
23 | 		// Headers may already be sent, so the failure can only be logged.
24 | 		log.Printf("writing JSON response: %v", err)
25 | 	}
26 | }
27 |
--------------------------------------------------------------------------------