├── .dockerignore ├── .gitignore ├── Dockerfile ├── Dockerfile.amd64 ├── Dockerfile.arm32 ├── Dockerfile.arm64 ├── Dockerfile.base.cuda ├── Dockerfile.builder ├── Dockerfile.noavx ├── LICENSE.md ├── Makefile ├── README.md ├── builder.sh ├── cmd ├── api.go ├── client.go ├── root.go └── version.go ├── conf ├── defaults.go ├── logger.go ├── signal.go └── version.go ├── config.arm.yaml ├── config.yaml ├── detector ├── auth.go ├── dconfig │ └── dconfig.go ├── detector.go ├── regions.go ├── tensorflow │ └── tensorflow.go └── tflite │ ├── detector.go │ ├── go-tflite │ ├── callback.go │ ├── delegates │ │ ├── delegates.go │ │ └── edgetpu │ │ │ ├── edgetpu.go │ │ │ └── edgetpu.go.h │ ├── tflite.go │ ├── tflite.go.h │ ├── tflite_experimental.go │ ├── tflite_experimental.go.h │ ├── tflite_test.go │ ├── tflite_type.go │ └── type_string.go │ └── ppm.go ├── examples ├── grpcclient-single.go ├── grpcclient-stream.go └── rtspdetector.go ├── fetch_models.sh ├── go.mod ├── go.sum ├── main.go ├── odrpc ├── odrpc.go ├── raw.go ├── rpc.pb.go ├── rpc.pb.gw.go ├── rpc.proto └── rpc.swagger.json ├── server ├── error.go ├── jsonpb.go ├── routes.go ├── rpc │ ├── rpc.go │ ├── version.pb.go │ ├── version.pb.gw.go │ ├── version.proto │ └── version.swagger.json ├── server.go └── version.go └── tf_arm_toolchain_patch.sh /.dockerignore: -------------------------------------------------------------------------------- 1 | .vscode 2 | doods 3 | *.txt 4 | *.png 5 | *.bmp 6 | *.ppm 7 | *.tflite 8 | *.pb 9 | *.weights 10 | *.cfg 11 | *.zip 12 | private 13 | models/** 14 | Dockerfile.*.* 15 | Dockerfile.* 16 | README.md 17 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .vscode 2 | doods 3 | *.txt 4 | *.png 5 | *.bmp 6 | *.ppm 7 | *.jpg 8 | *.tflite 9 | *.zip 10 | private 11 | models 12 | example.yaml 13 | libedgetpu 
-------------------------------------------------------------------------------- /Dockerfile: --------------------------------------------------------------------------------
# Multi-stage build for DOODS (native amd64 build).
#   Stage "base":    compile tensorflow / tensorflow lite / OpenCV, fetch edgetpu, install Go.
#   Stage "builder": compile the doods binary against the libraries from "base".
#   Final stage:     minimal runtime image with shared libs, binary, config and sample models.
FROM ubuntu:18.04 as base

# Install reqs with cross compile support
RUN apt-get update && apt-get install -y --no-install-recommends \
    pkg-config zip zlib1g-dev unzip wget bash-completion git curl \
    build-essential patch g++ python python-future python-numpy python-six python3 \
    cmake ca-certificates \
    libc6-dev libstdc++6 libusb-1.0-0

# Install protoc
RUN wget https://github.com/protocolbuffers/protobuf/releases/download/v3.12.3/protoc-3.12.3-linux-x86_64.zip && \
    unzip protoc-3.12.3-linux-x86_64.zip -d /usr/local && \
    rm /usr/local/readme.txt && \
    rm protoc-3.12.3-linux-x86_64.zip

# Version Configuration
ARG BAZEL_VERSION="2.0.0"
ARG TF_VERSION="f394a768719a55b5c351ed1ecab2ec6f16f99dd4"
# NOTE(review): the sibling Dockerfile.amd64/arm32/arm64 all pin OpenCV 4.5.0 —
# confirm whether 4.3.0 here is intentional or should be brought in line.
ARG OPENCV_VERSION="4.3.0"
ARG GO_VERSION="1.14.3"

# Install bazel
ENV BAZEL_VERSION $BAZEL_VERSION
RUN wget https://github.com/bazelbuild/bazel/releases/download/${BAZEL_VERSION}/bazel_${BAZEL_VERSION}-linux-x86_64.deb && \
    dpkg -i bazel_${BAZEL_VERSION}-linux-x86_64.deb && \
    rm bazel_${BAZEL_VERSION}-linux-x86_64.deb

# Download tensorflow sources (TF_VERSION is a commit hash, so a full clone +
# checkout is used instead of the single-branch clone kept below for reference)
ENV TF_VERSION $TF_VERSION
#RUN cd /opt && git clone https://github.com/tensorflow/tensorflow.git --branch $TF_VERSION --single-branch
RUN cd /opt && git clone https://github.com/tensorflow/tensorflow.git && cd /opt/tensorflow && git checkout ${TF_VERSION}

# Configure tensorflow (all optional subsystems off; answer configure prompts with defaults)
ENV TF_NEED_GDR=0 TF_NEED_AWS=0 TF_NEED_GCP=0 TF_NEED_CUDA=0 TF_NEED_HDFS=0 TF_NEED_OPENCL_SYCL=0 TF_NEED_VERBS=0 TF_NEED_MPI=0 TF_NEED_MKL=0 TF_NEED_JEMALLOC=1 TF_ENABLE_XLA=0 TF_NEED_S3=0 TF_NEED_KAFKA=0 TF_NEED_IGNITE=0 TF_NEED_ROCM=0
RUN cd /opt/tensorflow && yes '' | ./configure

# Tensorflow build flags
# FIX(review): dropped the leading "-c opt" from BAZEL_COPT_FLAGS — every
# `bazel build` invocation below already passes "-c opt" explicitly, so the
# compilation mode was being supplied twice.
ENV BAZEL_COPT_FLAGS="--config monolithic --copt=-march=native --copt=-O3 --copt=-fomit-frame-pointer --incompatible_no_support_tools_in_action_inputs=false --config=noaws --config=nohdfs"
ENV BAZEL_EXTRA_FLAGS="--host_linkopt=-lm"

# Compile and build tensorflow lite (C++ and C shared libraries + flatbuffers headers)
RUN cd /opt/tensorflow && \
    bazel build -c opt $BAZEL_COPT_FLAGS --verbose_failures $BAZEL_EXTRA_FLAGS //tensorflow/lite:libtensorflowlite.so && \
    install bazel-bin/tensorflow/lite/libtensorflowlite.so /usr/local/lib/libtensorflowlite.so && \
    bazel build -c opt $BAZEL_COPT_FLAGS --verbose_failures $BAZEL_EXTRA_FLAGS //tensorflow/lite/c:libtensorflowlite_c.so && \
    install bazel-bin/tensorflow/lite/c/libtensorflowlite_c.so /usr/local/lib/libtensorflowlite_c.so && \
    mkdir -p /usr/local/include/flatbuffers && cp bazel-tensorflow/external/flatbuffers/include/flatbuffers/* /usr/local/include/flatbuffers

# Compile and install tensorflow shared library (plus .1/.2 soname links)
RUN cd /opt/tensorflow && \
    bazel build -c opt $BAZEL_COPT_FLAGS --verbose_failures $BAZEL_EXTRA_FLAGS //tensorflow:libtensorflow.so && \
    install bazel-bin/tensorflow/libtensorflow.so /usr/local/lib/libtensorflow.so && \
    ln -rs /usr/local/lib/libtensorflow.so /usr/local/lib/libtensorflow.so.1 && \
    ln -rs /usr/local/lib/libtensorflow.so /usr/local/lib/libtensorflow.so.2

# cleanup so the cache directory isn't huge
RUN cd /opt/tensorflow && \
    bazel clean && rm -Rf /root/.cache

# Install GOCV (OpenCV + contrib modules, built from source)
ENV OPENCV_VERSION $OPENCV_VERSION
RUN cd /tmp && \
    curl -Lo opencv.zip https://github.com/opencv/opencv/archive/${OPENCV_VERSION}.zip && \
    unzip -q opencv.zip && \
    curl -Lo opencv_contrib.zip https://github.com/opencv/opencv_contrib/archive/${OPENCV_VERSION}.zip && \
    unzip -q opencv_contrib.zip && \
    rm opencv.zip opencv_contrib.zip && \
    cd opencv-${OPENCV_VERSION} && \
    mkdir build && cd build && \
    cmake -D CMAKE_BUILD_TYPE=RELEASE \
          -D CMAKE_INSTALL_PREFIX=/usr/local \
          -D OPENCV_EXTRA_MODULES_PATH=../../opencv_contrib-${OPENCV_VERSION}/modules \
          -D WITH_JASPER=OFF \
          -D WITH_QT=OFF \
          -D WITH_GTK=OFF \
          -D BUILD_DOCS=OFF \
          -D BUILD_EXAMPLES=OFF \
          -D BUILD_TESTS=OFF \
          -D BUILD_PERF_TESTS=OFF \
          -D BUILD_opencv_java=NO \
          -D BUILD_opencv_python=NO \
          -D BUILD_opencv_python2=NO \
          -D BUILD_opencv_python3=NO \
          -D OPENCV_GENERATE_PKGCONFIG=ON .. && \
    make -j $(nproc --all) && \
    make preinstall && make install && \
    cd /tmp && rm -rf opencv*

# Fetch the edgetpu library locally (prebuilt, expected in the build context)
ADD libedgetpu/out/throttled/k8/libedgetpu.so.1.0 /usr/local/lib/libedgetpu.so.1.0
RUN ln -rs /usr/local/lib/libedgetpu.so.1.0 /usr/local/lib/libedgetpu.so.1 && \
    ln -rs /usr/local/lib/libedgetpu.so.1.0 /usr/local/lib/libedgetpu.so && \
    mkdir -p /usr/local/include/libedgetpu
ADD libedgetpu/tflite/public/edgetpu.h /usr/local/include/libedgetpu/edgetpu.h
ADD libedgetpu/tflite/public/edgetpu_c.h /usr/local/include/libedgetpu/edgetpu_c.h

# Configure the Go version to be used
ENV GO_ARCH "amd64"
ENV GOARCH=amd64

# Install Go
ENV GO_VERSION $GO_VERSION
RUN curl -kLo go${GO_VERSION}.linux-${GO_ARCH}.tar.gz https://dl.google.com/go/go${GO_VERSION}.linux-${GO_ARCH}.tar.gz && \
    tar -C /usr/local -xzf go${GO_VERSION}.linux-${GO_ARCH}.tar.gz && \
    rm go${GO_VERSION}.linux-${GO_ARCH}.tar.gz

FROM ubuntu:18.04 as builder

RUN apt-get update && apt-get install -y --no-install-recommends \
    pkg-config zip zlib1g-dev unzip wget bash-completion git curl \
    build-essential patch g++ python python-future python3 ca-certificates \
    libc6-dev libstdc++6 libusb-1.0-0

# Copy all libraries, includes and go
COPY --from=base /usr/local/. /usr/local/.
COPY --from=base /opt/tensorflow /opt/tensorflow

ENV GOOS=linux
ENV CGO_ENABLED=1
ENV CGO_CFLAGS=-I/opt/tensorflow
ENV PATH /usr/local/go/bin:/go/bin:${PATH}
ENV GOPATH /go

# Create the build directory
RUN mkdir /build
WORKDIR /build
ADD . .
RUN make

FROM ubuntu:18.04

RUN apt-get update && \
    apt-get install -y --no-install-recommends libusb-1.0 libc++-7-dev wget unzip ca-certificates libdc1394-22 libavcodec57 libavformat57 && \
    apt-get clean && \
    rm -rf /var/lib/apt/lists/*
RUN mkdir -p /opt/doods
WORKDIR /opt/doods
COPY --from=builder /usr/local/lib/. /usr/local/lib/.
COPY --from=builder /build/doods /opt/doods/doods
RUN ldconfig

# Download sample models
RUN mkdir models
RUN wget https://storage.googleapis.com/download.tensorflow.org/models/tflite/coco_ssd_mobilenet_v1_1.0_quant_2018_06_29.zip && unzip coco_ssd_mobilenet_v1_1.0_quant_2018_06_29.zip && rm coco_ssd_mobilenet_v1_1.0_quant_2018_06_29.zip && mv detect.tflite models/coco_ssd_mobilenet_v1_1.0_quant.tflite && rm labelmap.txt
RUN wget https://dl.google.com/coral/canned_models/coco_labels.txt && mv coco_labels.txt models/coco_labels0.txt
# FIX(review): added --no-same-owner for consistency with Dockerfile.amd64 /
# Dockerfile.arm32, which pass it on the identical extraction.
RUN wget http://download.tensorflow.org/models/object_detection/faster_rcnn_inception_v2_coco_2018_01_28.tar.gz && tar -zxvf faster_rcnn_inception_v2_coco_2018_01_28.tar.gz faster_rcnn_inception_v2_coco_2018_01_28/frozen_inference_graph.pb --strip=1 --no-same-owner && mv frozen_inference_graph.pb models/faster_rcnn_inception_v2_coco_2018_01_28.pb && rm faster_rcnn_inception_v2_coco_2018_01_28.tar.gz
RUN wget https://raw.githubusercontent.com/amikelive/coco-labels/master/coco-labels-2014_2017.txt && mv coco-labels-2014_2017.txt models/coco_labels1.txt
ADD config.yaml config.yaml

CMD ["/opt/doods/doods", "-c", "/opt/doods/config.yaml", "api"]
-------------------------------------------------------------------------------- /Dockerfile.amd64: --------------------------------------------------------------------------------
# DOODS image for amd64, built in three stages:
#   "base"    — tensorflow (+lite), OpenCV, edgetpu library and the Go toolchain
#   "builder" — the doods binary itself
#   final     — slim runtime with the shared libraries, binary and demo models
FROM ubuntu:18.04 as base

# Build prerequisites (cross-compile capable toolchain and helpers)
RUN apt-get update && apt-get install -y --no-install-recommends \
    pkg-config zip zlib1g-dev unzip wget bash-completion git curl \
    build-essential patch g++ python python-future python-numpy python-six python3 \
    cmake ca-certificates \
    libc6-dev libstdc++6 libusb-1.0-0

# protoc, unpacked straight into /usr/local
RUN wget https://github.com/protocolbuffers/protobuf/releases/download/v3.12.3/protoc-3.12.3-linux-x86_64.zip && \
    unzip protoc-3.12.3-linux-x86_64.zip -d /usr/local && \
    rm /usr/local/readme.txt && \
    rm protoc-3.12.3-linux-x86_64.zip

# Pinned tool/library versions (TF_VERSION is a git commit hash)
ARG BAZEL_VERSION="2.0.0"
ARG TF_VERSION="f394a768719a55b5c351ed1ecab2ec6f16f99dd4"
ARG OPENCV_VERSION="4.5.0"
ARG GO_VERSION="1.14.3"

# bazel, from the upstream .deb release
ENV BAZEL_VERSION $BAZEL_VERSION
RUN wget https://github.com/bazelbuild/bazel/releases/download/${BAZEL_VERSION}/bazel_${BAZEL_VERSION}-linux-x86_64.deb && \
    dpkg -i bazel_${BAZEL_VERSION}-linux-x86_64.deb && \
    rm bazel_${BAZEL_VERSION}-linux-x86_64.deb

# tensorflow sources at the pinned commit (single-branch variant kept for reference)
ENV TF_VERSION $TF_VERSION
#RUN cd /opt && git clone https://github.com/tensorflow/tensorflow.git --branch $TF_VERSION --single-branch
RUN cd /opt && git clone https://github.com/tensorflow/tensorflow.git && cd /opt/tensorflow && git checkout ${TF_VERSION}

# Non-interactive tensorflow configure: disable every optional backend,
# then feed empty answers to the remaining prompts
ENV TF_NEED_GDR=0 TF_NEED_AWS=0 TF_NEED_GCP=0 TF_NEED_CUDA=0 TF_NEED_HDFS=0 TF_NEED_OPENCL_SYCL=0 TF_NEED_VERBS=0 TF_NEED_MPI=0 TF_NEED_MKL=0 TF_NEED_JEMALLOC=1 TF_ENABLE_XLA=0 TF_NEED_S3=0 TF_NEED_KAFKA=0 TF_NEED_IGNITE=0 TF_NEED_ROCM=0
RUN cd /opt/tensorflow && yes '' | ./configure

# Shared bazel flags: capped resources, monolithic link, AVX/SSE4.2 codegen
ENV BAZEL_COPT_FLAGS="--local_resources 16000,16,1 --config monolithic --copt=-O3 --copt=-fomit-frame-pointer --copt=-mfpmath=both --copt=-mavx --copt=-msse4.2 --incompatible_no_support_tools_in_action_inputs=false --config=noaws --config=nohdfs"
ENV BAZEL_EXTRA_FLAGS="--host_linkopt=-lm"

# tensorflow lite: C++ and C shared libraries, plus the flatbuffers headers
RUN cd /opt/tensorflow && \
    bazel build -c opt $BAZEL_COPT_FLAGS --verbose_failures $BAZEL_EXTRA_FLAGS //tensorflow/lite:libtensorflowlite.so && \
    install bazel-bin/tensorflow/lite/libtensorflowlite.so /usr/local/lib/libtensorflowlite.so && \
    bazel build -c opt $BAZEL_COPT_FLAGS --verbose_failures $BAZEL_EXTRA_FLAGS //tensorflow/lite/c:libtensorflowlite_c.so && \
    install bazel-bin/tensorflow/lite/c/libtensorflowlite_c.so /usr/local/lib/libtensorflowlite_c.so && \
    mkdir -p /usr/local/include/flatbuffers && cp bazel-tensorflow/external/flatbuffers/include/flatbuffers/* /usr/local/include/flatbuffers

# Full tensorflow shared library, with .1/.2 soname symlinks
RUN cd /opt/tensorflow && \
    bazel build -c opt $BAZEL_COPT_FLAGS --verbose_failures $BAZEL_EXTRA_FLAGS //tensorflow:libtensorflow.so && \
    install bazel-bin/tensorflow/libtensorflow.so /usr/local/lib/libtensorflow.so && \
    ln -rs /usr/local/lib/libtensorflow.so /usr/local/lib/libtensorflow.so.1 && \
    ln -rs /usr/local/lib/libtensorflow.so /usr/local/lib/libtensorflow.so.2

# Drop the bazel cache so the layer stays small
RUN cd /opt/tensorflow && \
    bazel clean && rm -Rf /root/.cache

# OpenCV + contrib (for GoCV), built from source without GUI/python/java bindings
ENV OPENCV_VERSION $OPENCV_VERSION
RUN cd /tmp && \
    curl -Lo opencv.zip https://github.com/opencv/opencv/archive/${OPENCV_VERSION}.zip && \
    unzip -q opencv.zip && \
    curl -Lo opencv_contrib.zip https://github.com/opencv/opencv_contrib/archive/${OPENCV_VERSION}.zip && \
    unzip -q opencv_contrib.zip && \
    rm opencv.zip opencv_contrib.zip && \
    cd opencv-${OPENCV_VERSION} && \
    mkdir build && cd build && \
    cmake -D CMAKE_BUILD_TYPE=RELEASE \
          -D CMAKE_INSTALL_PREFIX=/usr/local \
          -D OPENCV_EXTRA_MODULES_PATH=../../opencv_contrib-${OPENCV_VERSION}/modules \
          -D WITH_JASPER=OFF \
          -D WITH_QT=OFF \
          -D WITH_GTK=OFF \
          -D BUILD_DOCS=OFF \
          -D BUILD_EXAMPLES=OFF \
          -D BUILD_TESTS=OFF \
          -D BUILD_PERF_TESTS=OFF \
          -D BUILD_opencv_java=NO \
          -D BUILD_opencv_python=NO \
          -D BUILD_opencv_python2=NO \
          -D BUILD_opencv_python3=NO \
          -D OPENCV_GENERATE_PKGCONFIG=ON .. && \
    make -j $(nproc --all) && \
    make preinstall && make install && \
    cd /tmp && rm -rf opencv*

# Prebuilt edgetpu runtime + headers from the build context
ADD libedgetpu/out/throttled/k8/libedgetpu.so.1.0 /usr/local/lib/libedgetpu.so.1.0
RUN ln -rs /usr/local/lib/libedgetpu.so.1.0 /usr/local/lib/libedgetpu.so.1 && \
    ln -rs /usr/local/lib/libedgetpu.so.1.0 /usr/local/lib/libedgetpu.so && \
    mkdir -p /usr/local/include/libedgetpu
ADD libedgetpu/tflite/public/edgetpu.h /usr/local/include/libedgetpu/edgetpu.h
ADD libedgetpu/tflite/public/edgetpu_c.h /usr/local/include/libedgetpu/edgetpu_c.h

# Go toolchain configuration
ENV GO_ARCH "amd64"
ENV GOARCH=amd64

# Go toolchain installation
ENV GO_VERSION $GO_VERSION
RUN curl -kLo go${GO_VERSION}.linux-${GO_ARCH}.tar.gz https://dl.google.com/go/go${GO_VERSION}.linux-${GO_ARCH}.tar.gz && \
    tar -C /usr/local -xzf go${GO_VERSION}.linux-${GO_ARCH}.tar.gz && \
    rm go${GO_VERSION}.linux-${GO_ARCH}.tar.gz

FROM ubuntu:18.04 as builder

RUN apt-get update && apt-get install -y --no-install-recommends \
    pkg-config zip zlib1g-dev unzip wget bash-completion git curl \
    build-essential patch g++ python python-future python3 ca-certificates \
    libc6-dev libstdc++6 libusb-1.0-0

# Bring over everything the base stage produced (libs, headers, Go, TF sources)
COPY --from=base /usr/local/. /usr/local/.
COPY --from=base /opt/tensorflow /opt/tensorflow

ENV GOOS=linux
ENV CGO_ENABLED=1
ENV CGO_CFLAGS=-I/opt/tensorflow
ENV PATH /usr/local/go/bin:/go/bin:${PATH}
ENV GOPATH /go

# Compile doods
RUN mkdir /build
WORKDIR /build
ADD . .
RUN make

FROM ubuntu:18.04

RUN apt-get update && \
    apt-get install -y --no-install-recommends libusb-1.0 libc++-7-dev wget unzip ca-certificates libdc1394-22 libavcodec57 libavformat57 && \
    apt-get clean && \
    rm -rf /var/lib/apt/lists/*
RUN mkdir -p /opt/doods
WORKDIR /opt/doods
COPY --from=builder /usr/local/lib/. /usr/local/lib/.
COPY --from=builder /build/doods /opt/doods/doods
RUN ldconfig

# Sample detection models and label files
RUN mkdir models
RUN wget https://storage.googleapis.com/download.tensorflow.org/models/tflite/coco_ssd_mobilenet_v1_1.0_quant_2018_06_29.zip && unzip coco_ssd_mobilenet_v1_1.0_quant_2018_06_29.zip && rm coco_ssd_mobilenet_v1_1.0_quant_2018_06_29.zip && mv detect.tflite models/coco_ssd_mobilenet_v1_1.0_quant.tflite && rm labelmap.txt
RUN wget https://dl.google.com/coral/canned_models/coco_labels.txt && mv coco_labels.txt models/coco_labels0.txt
RUN wget http://download.tensorflow.org/models/object_detection/faster_rcnn_inception_v2_coco_2018_01_28.tar.gz && tar -zxvf faster_rcnn_inception_v2_coco_2018_01_28.tar.gz faster_rcnn_inception_v2_coco_2018_01_28/frozen_inference_graph.pb --strip=1 --no-same-owner && mv frozen_inference_graph.pb models/faster_rcnn_inception_v2_coco_2018_01_28.pb && rm faster_rcnn_inception_v2_coco_2018_01_28.tar.gz
RUN wget https://raw.githubusercontent.com/amikelive/coco-labels/master/coco-labels-2014_2017.txt && mv coco-labels-2014_2017.txt models/coco_labels1.txt
ADD config.yaml config.yaml

CMD ["/opt/doods/doods", "-c", "/opt/doods/config.yaml", "api"]
-------------------------------------------------------------------------------- /Dockerfile.arm32: -------------------------------------------------------------------------------- 1 | FROM debian:buster as base 2 | 3 | # Install reqs with cross compile support 4 | RUN dpkg --add-architecture armhf && \ 5 | apt-get update && apt-get install -y --no-install-recommends \ 6 | pkg-config zip zlib1g-dev unzip wget bash-completion git curl \ 7 | build-essential patch g++ python python-future python3 \ 8 | python-numpy python-six \ 9 | cmake ca-certificates \ 10 | libc6-dev:armhf libstdc++6:armhf libusb-1.0-0:armhf 11 | 12 | # Install protoc 13 | RUN wget https://github.com/protocolbuffers/protobuf/releases/download/v3.12.3/protoc-3.12.3-linux-x86_64.zip && \ 14 | unzip protoc-3.12.3-linux-x86_64.zip -d /usr/local && \ 15 | rm /usr/local/readme.txt && \ 16 | rm protoc-3.12.3-linux-x86_64.zip 17 | 18 | # Version Configuration 19 | ARG BAZEL_VERSION="2.0.0" 20 | ARG TF_VERSION="f394a768719a55b5c351ed1ecab2ec6f16f99dd4" 21 | ARG OPENCV_VERSION="4.5.0" 22 | ARG GO_VERSION="1.14.3" 23 | 24 | # Install bazel 25 | ENV BAZEL_VERSION $BAZEL_VERSION 26 | RUN wget https://github.com/bazelbuild/bazel/releases/download/${BAZEL_VERSION}/bazel_${BAZEL_VERSION}-linux-x86_64.deb && \ 27 | dpkg -i bazel_${BAZEL_VERSION}-linux-x86_64.deb && \ 28 | rm bazel_${BAZEL_VERSION}-linux-x86_64.deb 29 | 30 | # Download tensorflow sources 31 | ENV TF_VERSION $TF_VERSION 32 | #RUN cd /opt && git clone https://github.com/tensorflow/tensorflow.git --branch $TF_VERSION --single-branch 33 | RUN cd /opt && git clone https://github.com/tensorflow/tensorflow.git && cd /opt/tensorflow && git checkout ${TF_VERSION} 34 | 35 | # Download and configure the toolchain and patch tensorflow as needed 36 | ENV CROSSTOOL_COMPILER="yes" 37 | ENV CROSSTOOL_URL="https://releases.linaro.org/components/toolchain/binaries/5.5-2017.10/arm-linux-gnueabihf/gcc-linaro-5.5.0-2017.10-x86_64_arm-linux-gnueabihf.tar.xz" 38 | ENV 
CROSSTOOL_DIR="gcc-linaro-5.5.0-2017.10-x86_64_arm-linux-gnueabihf" 39 | ENV CROSSTOOL_NAME="arm-linux-gnueabihf" 40 | COPY tf_arm_toolchain_patch.sh /opt/tf_arm_toolchain_patch.sh 41 | RUN cd /opt/tensorflow && /opt/tf_arm_toolchain_patch.sh 42 | 43 | # Configure tensorflow 44 | ENV TF_NEED_GDR=0 TF_NEED_AWS=0 TF_NEED_GCP=0 TF_NEED_CUDA=0 TF_NEED_HDFS=0 TF_NEED_OPENCL_SYCL=0 TF_NEED_VERBS=0 TF_NEED_MPI=0 TF_NEED_MKL=0 TF_NEED_JEMALLOC=1 TF_ENABLE_XLA=0 TF_NEED_S3=0 TF_NEED_KAFKA=0 TF_NEED_IGNITE=0 TF_NEED_ROCM=0 45 | RUN cd /opt/tensorflow && yes '' | ./configure 46 | 47 | # Tensorflow build flags 48 | ENV BAZEL_COPT_FLAGS="--local_resources 16000,16,1 --copt=-march=armv7-a --copt=-mfpu=neon-vfpv4 --copt=-mfloat-abi=hard --copt=-O3 --copt=-fno-tree-pre --copt=-fpermissive --copt=-std=c++11 --copt=-DS_IREAD=S_IRUSR --copt=-DS_IWRITE=S_IWUSR --copt=-U__GCC_HAVE_SYNC_COMPARE_AND_SWAP_1 --copt=-U__GCC_HAVE_SYNC_COMPARE_AND_SWAP_2 --copt=-U__GCC_HAVE_SYNC_COMPARE_AND_SWAP_8 --config=monolithic --copt=-funsafe-math-optimizations --copt=-ftree-vectorize --copt=-fomit-frame-pointer --copt=-DRASPBERRY_PI --noincompatible_strict_action_env --config=noaws --config=nohdfs --define tensorflow_mkldnn_contraction_kernel=0 --define=raspberry_pi_with_neon=true" 49 | ENV BAZEL_EXTRA_FLAGS="--cpu=armeabi --host_linkopt=-lm --crosstool_top=@local_config_arm_compiler//:toolchain" 50 | 51 | # Compile and build tensorflow lite 52 | RUN cd /opt/tensorflow && \ 53 | bazel build -c opt $BAZEL_COPT_FLAGS --verbose_failures $BAZEL_EXTRA_FLAGS //tensorflow/lite:libtensorflowlite.so && \ 54 | install bazel-bin/tensorflow/lite/libtensorflowlite.so /usr/local/lib/libtensorflowlite.so && \ 55 | bazel build -c opt $BAZEL_COPT_FLAGS --verbose_failures $BAZEL_EXTRA_FLAGS //tensorflow/lite/c:libtensorflowlite_c.so && \ 56 | install bazel-bin/tensorflow/lite/c/libtensorflowlite_c.so /usr/local/lib/libtensorflowlite_c.so && \ 57 | mkdir -p /usr/local/include/flatbuffers && cp 
bazel-tensorflow/external/flatbuffers/include/flatbuffers/* /usr/local/include/flatbuffers 58 | 59 | # Compile and install tensorflow shared library 60 | RUN cd /opt/tensorflow && \ 61 | bazel build -c opt $BAZEL_COPT_FLAGS --verbose_failures $BAZEL_EXTRA_FLAGS //tensorflow:libtensorflow.so && \ 62 | install bazel-bin/tensorflow/libtensorflow.so /usr/local/lib/libtensorflow.so && \ 63 | ln -rs /usr/local/lib/libtensorflow.so /usr/local/lib/libtensorflow.so.1 && \ 64 | ln -rs /usr/local/lib/libtensorflow.so /usr/local/lib/libtensorflow.so.2 65 | 66 | # cleanup so the cache directory isn't huge 67 | RUN cd /opt/tensorflow && \ 68 | bazel clean && rm -Rf /root/.cache 69 | 70 | # Download and configure the build environment for gcc 6 which is needed to compile everything else 71 | RUN mkdir -p /tmp/sysroot/lib && mkdir -p /tmp/sysroot/usr/lib && \ 72 | cd /tmp && \ 73 | wget --no-check-certificate https://releases.linaro.org/components/toolchain/binaries/6.3-2017.05/arm-linux-gnueabihf/gcc-linaro-6.3.1-2017.05-x86_64_arm-linux-gnueabihf.tar.xz -O toolchain.tar.xz && \ 74 | tar xf toolchain.tar.xz -C /opt/toolchain/ && \ 75 | rm toolchain.tar.xz && \ 76 | cp -r /opt/toolchain/gcc-linaro-6.3.1-2017.05-x86_64_arm-linux-gnueabihf/arm-linux-gnueabihf/libc/* /tmp/sysroot/ 77 | RUN mkdir -p /tmp/debs && cd /tmp/debs && apt-get update && apt-get download libc6:armhf libc6-dev:armhf && \ 78 | ar x libc6_*.deb && tar xvf data.tar.xz && \ 79 | ar x libc6-dev*.deb && tar xvf data.tar.xz && \ 80 | cp -R usr /tmp/sysroot && cp -R lib /tmp/sysroot && rm -Rf /tmp/debs && \ 81 | mkdir -p /tmp/debs && cd /tmp/debs && \ 82 | apt-get download libusb-1.0-0:armhf libudev1:armhf zlib1g-dev:armhf zlib1g:armhf && \ 83 | ar x libusb-1.0*.deb && tar xvf data.tar.xz && \ 84 | ar x libudev1*.deb && tar xvf data.tar.xz && \ 85 | ar x zlib1g_*.deb && tar xvf data.tar.xz && \ 86 | ar x zlib1g-dev*.deb && tar xvf data.tar.xz && rm usr/lib/arm-linux-gnueabihf/libz.so && \ 87 | cp -r 
lib/arm-linux-gnueabihf/* /tmp/sysroot/lib && \ 88 | cp -r usr/lib/arm-linux-gnueabihf/* /tmp/sysroot/usr/lib && \ 89 | cp -r usr/include/* /tmp/sysroot/usr/include && \ 90 | ln -rs /tmp/sysroot/lib/libusb-1.0.so.0.1.0 /tmp/sysroot/lib/libusb-1.0.so && \ 91 | ln -rs /tmp/sysroot/lib/libudev.so.1.6.13 /tmp/sysroot/lib/libudev.so && \ 92 | ln -rs /tmp/sysroot/lib/libz.so.1.2.11 /tmp/sysroot/lib/libz.so && \ 93 | ln -s /usr/local /tmp/sysroot/usr/local && \ 94 | cd /tmp && rm -Rf /tmp/debs 95 | ENV CC="/opt/toolchain/gcc-linaro-6.3.1-2017.05-x86_64_arm-linux-gnueabihf/bin/arm-linux-gnueabihf-gcc" 96 | ENV CXX="/opt/toolchain/gcc-linaro-6.3.1-2017.05-x86_64_arm-linux-gnueabihf/bin/arm-linux-gnueabihf-g++" 97 | ENV LDFLAGS="-v -L /lib -L /usr/lib --sysroot /tmp/sysroot" 98 | ENV CFLAGS="-L /lib -L /usr/lib --sysroot /tmp/sysroot" 99 | ENV CXXFLAGS="-L /lib -L /usr/lib --sysroot /tmp/sysroot" 100 | 101 | # Install GOCV 102 | ENV OPENCV_VERSION $OPENCV_VERSION 103 | RUN cd /tmp && \ 104 | curl -Lo opencv.zip https://github.com/opencv/opencv/archive/${OPENCV_VERSION}.zip && \ 105 | unzip -q opencv.zip && \ 106 | curl -Lo opencv_contrib.zip https://github.com/opencv/opencv_contrib/archive/${OPENCV_VERSION}.zip && \ 107 | unzip -q opencv_contrib.zip && \ 108 | rm opencv.zip opencv_contrib.zip && \ 109 | cd opencv-${OPENCV_VERSION} && \ 110 | mkdir build && cd build && \ 111 | cmake -D CMAKE_BUILD_TYPE=RELEASE \ 112 | -D CMAKE_INSTALL_PREFIX=/usr/local \ 113 | -D OPENCV_EXTRA_MODULES_PATH=../../opencv_contrib-${OPENCV_VERSION}/modules \ 114 | -D WITH_JASPER=OFF \ 115 | -D WITH_QT=OFF \ 116 | -D WITH_GTK=OFF \ 117 | -D BUILD_DOCS=OFF \ 118 | -D BUILD_EXAMPLES=OFF \ 119 | -D BUILD_TESTS=OFF \ 120 | -D BUILD_PERF_TESTS=OFF \ 121 | -D BUILD_opencv_java=NO \ 122 | -D BUILD_opencv_python=NO \ 123 | -D BUILD_opencv_python2=NO \ 124 | -D BUILD_opencv_python3=NO \ 125 | -D OPENCV_GENERATE_PKGCONFIG=ON .. 
&& \ 126 | make -j $(nproc --all) && \ 127 | make preinstall && make install && \ 128 | cd /tmp && rm -rf opencv* 129 | 130 | # Fetch the edgetpu library locally 131 | ADD libedgetpu/out/throttled/armv7a/libedgetpu.so.1.0 /tmp/sysroot/usr/lib/libedgetpu.so.1.0 132 | RUN ln -rs /tmp/sysroot/usr/lib/libedgetpu.so.1.0 /tmp/sysroot/usr/lib/libedgetpu.so.1 && \ 133 | ln -rs /tmp/sysroot/usr/lib/libedgetpu.so.1.0 /tmp/sysroot/usr/lib/libedgetpu.so && \ 134 | mkdir -p /tmp/sysroot/usr/include/libedgetpu 135 | ADD libedgetpu/tflite/public/edgetpu.h /tmp/sysroot/usr/include/libedgetpu/edgetpu.h 136 | ADD libedgetpu/tflite/public/edgetpu_c.h /tmp/sysroot/usr/include/libedgetpu/edgetpu_c.h 137 | 138 | # Install Go 139 | ENV GO_VERSION $GO_VERSION 140 | ARG GO_ARCH="amd64" 141 | ENV GOOS=linux 142 | RUN curl -Lo go${GO_VERSION}.linux-$GO_ARCH.tar.gz https://dl.google.com/go/go${GO_VERSION}.linux-$GO_ARCH.tar.gz && \ 143 | tar -C /usr/local -xzf go${GO_VERSION}.linux-$GO_ARCH.tar.gz && \ 144 | rm go${GO_VERSION}.linux-$GO_ARCH.tar.gz 145 | ENV PATH /usr/local/go/bin:/go/bin:${PATH} 146 | ENV GOPATH /go 147 | 148 | # Start compile 149 | WORKDIR /build 150 | ADD . . 
151 | 152 | # Install/Compile tools 153 | ENV CGO_ENABLED=0 154 | RUN make tools 155 | 156 | # Compile DOODS 157 | ENV GOARCH=arm 158 | ENV CGO_ENABLED=1 159 | ENV CGO_LDFLAGS="-v -L /lib -L /usr/lib -L /usr/local/lib --sysroot /tmp/sysroot -ledgetpu" 160 | ENV CGO_CFLAGS="-L /lib -L /usr/lib -L /usr/local/lib -I /opt/tensorflow --sysroot /tmp/sysroot" 161 | ENV CGO_CXXFLAGS="-L /lib -L /usr/lib -L /usr/local/lib -I /opt/tensorflow --sysroot /tmp/sysroot" 162 | RUN make 163 | 164 | # Start creating the new root directory 165 | WORKDIR /tmp/newroot 166 | RUN mkdir -p /tmp/newroot/lib && mkdir -p /tmp/newroot/usr/lib && \ 167 | cp -r /tmp/sysroot/lib/* /tmp/newroot/lib && \ 168 | cp -r /tmp/sysroot/usr/lib/* /tmp/newroot/usr/lib 169 | 170 | # Copy doods executable 171 | RUN mkdir -p /tmp/newroot/opt/doods && \ 172 | cp /build/doods /tmp/newroot/opt/doods/doods 173 | WORKDIR /tmp/newroot/opt/doods 174 | 175 | # Download sample models 176 | RUN mkdir models 177 | RUN wget https://storage.googleapis.com/download.tensorflow.org/models/tflite/coco_ssd_mobilenet_v1_1.0_quant_2018_06_29.zip && unzip coco_ssd_mobilenet_v1_1.0_quant_2018_06_29.zip && rm coco_ssd_mobilenet_v1_1.0_quant_2018_06_29.zip && mv detect.tflite models/coco_ssd_mobilenet_v1_1.0_quant.tflite && rm labelmap.txt 178 | RUN wget https://dl.google.com/coral/canned_models/coco_labels.txt && mv coco_labels.txt models/coco_labels0.txt 179 | RUN wget http://download.tensorflow.org/models/object_detection/faster_rcnn_inception_v2_coco_2018_01_28.tar.gz && tar -zxvf faster_rcnn_inception_v2_coco_2018_01_28.tar.gz faster_rcnn_inception_v2_coco_2018_01_28/frozen_inference_graph.pb --strip=1 --no-same-owner && mv frozen_inference_graph.pb models/faster_rcnn_inception_v2_coco_2018_01_28.pb && rm faster_rcnn_inception_v2_coco_2018_01_28.tar.gz 180 | RUN wget https://raw.githubusercontent.com/amikelive/coco-labels/master/coco-labels-2014_2017.txt && mv coco-labels-2014_2017.txt models/coco_labels1.txt 181 | ADD 
config.arm.yaml config.yaml 182 | 183 | FROM arm32v7/debian:buster-slim 184 | # Copy the pre-built root filesystem 185 | COPY --from=base /tmp/newroot/. /. 186 | COPY --from=base /usr/local/. /usr/local/. 187 | # Needed because we can't run ldconfig 188 | ENV LD_LIBRARY_PATH=/usr/local/lib 189 | 190 | WORKDIR /opt/doods 191 | CMD ["/opt/doods/doods", "-c", "/opt/doods/config.yaml", "api"] 192 | -------------------------------------------------------------------------------- /Dockerfile.arm64: -------------------------------------------------------------------------------- 1 | FROM debian:buster as base 2 | 3 | # Install reqs with cross compile support 4 | RUN dpkg --add-architecture arm64 && \ 5 | apt-get update && apt-get install -y --no-install-recommends \ 6 | pkg-config zip zlib1g-dev unzip wget bash-completion git curl \ 7 | build-essential patch g++ python python-future python3 \ 8 | python-numpy python-six \ 9 | cmake ca-certificates \ 10 | libc6-dev:arm64 libstdc++6:arm64 libusb-1.0-0:arm64 11 | 12 | # Install protoc 13 | RUN wget https://github.com/protocolbuffers/protobuf/releases/download/v3.12.3/protoc-3.12.3-linux-x86_64.zip && \ 14 | unzip protoc-3.12.3-linux-x86_64.zip -d /usr/local && \ 15 | rm /usr/local/readme.txt && \ 16 | rm protoc-3.12.3-linux-x86_64.zip 17 | 18 | # Version Configuration 19 | ARG BAZEL_VERSION="2.0.0" 20 | ARG TF_VERSION="f394a768719a55b5c351ed1ecab2ec6f16f99dd4" 21 | ARG OPENCV_VERSION="4.5.0" 22 | ARG GO_VERSION="1.14.3" 23 | 24 | # Install bazel 25 | ENV BAZEL_VERSION $BAZEL_VERSION 26 | RUN wget https://github.com/bazelbuild/bazel/releases/download/${BAZEL_VERSION}/bazel_${BAZEL_VERSION}-linux-x86_64.deb && \ 27 | dpkg -i bazel_${BAZEL_VERSION}-linux-x86_64.deb && \ 28 | rm bazel_${BAZEL_VERSION}-linux-x86_64.deb 29 | 30 | # Download tensorflow sources 31 | ENV TF_VERSION $TF_VERSION 32 | #RUN cd /opt && git clone https://github.com/tensorflow/tensorflow.git --branch $TF_VERSION --single-branch 33 | RUN cd /opt && git clone 
https://github.com/tensorflow/tensorflow.git && cd /opt/tensorflow && git checkout ${TF_VERSION} 34 | 35 | # Download and configure the toolchain and patch tensorflow as needed 36 | ENV CROSSTOOL_COMPILER="yes" 37 | ENV CROSSTOOL_URL="https://releases.linaro.org/components/toolchain/binaries/5.5-2017.10/aarch64-linux-gnu/gcc-linaro-5.5.0-2017.10-x86_64_aarch64-linux-gnu.tar.xz" 38 | ENV CROSSTOOL_DIR="gcc-linaro-5.5.0-2017.10-x86_64_aarch64-linux-gnu" 39 | ENV CROSSTOOL_NAME="aarch64-linux-gnu" 40 | COPY tf_arm_toolchain_patch.sh /opt/tf_arm_toolchain_patch.sh 41 | RUN cd /opt/tensorflow && /opt/tf_arm_toolchain_patch.sh 42 | 43 | # Configure tensorflow 44 | ENV TF_NEED_GDR=0 TF_NEED_AWS=0 TF_NEED_GCP=0 TF_NEED_CUDA=0 TF_NEED_HDFS=0 TF_NEED_OPENCL_SYCL=0 TF_NEED_VERBS=0 TF_NEED_MPI=0 TF_NEED_MKL=0 TF_NEED_JEMALLOC=1 TF_ENABLE_XLA=0 TF_NEED_S3=0 TF_NEED_KAFKA=0 TF_NEED_IGNITE=0 TF_NEED_ROCM=0 45 | RUN cd /opt/tensorflow && yes '' | ./configure 46 | 47 | # Tensorflow build flags for aarch64 (odroid c2 used for example) 48 | ENV BAZEL_COPT_FLAGS="--local_resources 16000,16,1 --copt=-march=armv8-a+crc+simd --copt=-mtune=cortex-a53 --copt=-O3 --copt=-flax-vector-conversions --copt=-std=c++11 --copt=-DS_IREAD=S_IRUSR --copt=-DS_IWRITE=S_IWUSR --config=monolithic --copt=-funsafe-math-optimizations --copt=-ftree-vectorize --copt=-fomit-frame-pointer --copt=-DRASPBERRY_PI --noincompatible_strict_action_env --config=v2 --config=noaws --config=nohdfs --define tensorflow_mkldnn_contraction_kernel=0 --define=raspberry_pi_with_neon=true" 49 | ENV BAZEL_EXTRA_FLAGS="--cpu=armeabi --host_linkopt=-lm --crosstool_top=@local_config_arm_compiler//:toolchain" 50 | 51 | # Compile and build tensorflow lite 52 | RUN cd /opt/tensorflow && \ 53 | bazel build -c opt $BAZEL_COPT_FLAGS --verbose_failures $BAZEL_EXTRA_FLAGS //tensorflow/lite:libtensorflowlite.so && \ 54 | install bazel-bin/tensorflow/lite/libtensorflowlite.so /usr/local/lib/libtensorflowlite.so && \ 55 | bazel build -c opt 
$BAZEL_COPT_FLAGS --verbose_failures $BAZEL_EXTRA_FLAGS //tensorflow/lite/c:libtensorflowlite_c.so && \ 56 | install bazel-bin/tensorflow/lite/c/libtensorflowlite_c.so /usr/local/lib/libtensorflowlite_c.so && \ 57 | mkdir -p /usr/local/include/flatbuffers && cp bazel-tensorflow/external/flatbuffers/include/flatbuffers/* /usr/local/include/flatbuffers 58 | 59 | # Compile and install tensorflow shared library 60 | RUN cd /opt/tensorflow && \ 61 | bazel build -c opt $BAZEL_COPT_FLAGS --verbose_failures $BAZEL_EXTRA_FLAGS //tensorflow:libtensorflow.so && \ 62 | install bazel-bin/tensorflow/libtensorflow.so /usr/local/lib/libtensorflow.so && \ 63 | ln -rs /usr/local/lib/libtensorflow.so /usr/local/lib/libtensorflow.so.1 && \ 64 | ln -rs /usr/local/lib/libtensorflow.so /usr/local/lib/libtensorflow.so.2 65 | 66 | # cleanup so the cache directory isn't huge 67 | RUN cd /opt/tensorflow && \ 68 | bazel clean && rm -Rf /root/.cache 69 | 70 | # Download and configure the build environment for gcc 6 which is needed to compile everything else 71 | RUN mkdir -p /tmp/sysroot/lib && mkdir -p /tmp/sysroot/usr/lib && \ 72 | cd /tmp && \ 73 | wget --no-check-certificate https://releases.linaro.org/components/toolchain/binaries/6.3-2017.05/aarch64-linux-gnu/gcc-linaro-6.3.1-2017.05-x86_64_aarch64-linux-gnu.tar.xz -O toolchain.tar.xz && \ 74 | tar xf toolchain.tar.xz -C /opt/toolchain/ && \ 75 | rm toolchain.tar.xz && \ 76 | cp -r /opt/toolchain/gcc-linaro-6.3.1-2017.05-x86_64_aarch64-linux-gnu/aarch64-linux-gnu/libc/* /tmp/sysroot/ 77 | RUN mkdir -p /tmp/debs && cd /tmp/debs && apt-get update && apt-get download libc6:arm64 libc6-dev:arm64 && \ 78 | ar x libc6_*.deb && tar xvf data.tar.xz && \ 79 | ar x libc6-dev*.deb && tar xvf data.tar.xz && \ 80 | cp -R usr /tmp/sysroot && cp -R lib /tmp/sysroot && rm -Rf /tmp/debs && \ 81 | mkdir -p /tmp/debs && cd /tmp/debs && \ 82 | apt-get download libusb-1.0-0:arm64 libudev1:arm64 zlib1g-dev:arm64 zlib1g:arm64 && \ 83 | ar x libusb-1.0*.deb && 
tar xvf data.tar.xz && \ 84 | ar x libudev1*.deb && tar xvf data.tar.xz && \ 85 | ar x zlib1g_*.deb && tar xvf data.tar.xz && \ 86 | ar x zlib1g-dev*.deb && tar xvf data.tar.xz && rm usr/lib/aarch64-linux-gnu/libz.so && \ 87 | cp -r lib/aarch64-linux-gnu/* /tmp/sysroot/lib && \ 88 | cp -r usr/lib/aarch64-linux-gnu/* /tmp/sysroot/usr/lib && \ 89 | cp -r usr/include/* /tmp/sysroot/usr/include && \ 90 | ln -rs /tmp/sysroot/lib/libusb-1.0.so.0.1.0 /tmp/sysroot/lib/libusb-1.0.so && \ 91 | ln -rs /tmp/sysroot/lib/libudev.so.1.6.13 /tmp/sysroot/lib/libudev.so && \ 92 | ln -rs /tmp/sysroot/lib/libz.so.1.2.11 /tmp/sysroot/lib/libz.so && \ 93 | ln -s /usr/local /tmp/sysroot/usr/local && \ 94 | cd /tmp && rm -Rf /tmp/debs 95 | ENV CC="/opt/toolchain/gcc-linaro-6.3.1-2017.05-x86_64_aarch64-linux-gnu/bin/aarch64-linux-gnu-gcc" 96 | ENV CXX="/opt/toolchain/gcc-linaro-6.3.1-2017.05-x86_64_aarch64-linux-gnu/bin/aarch64-linux-gnu-g++" 97 | ENV LDFLAGS="-v -L /lib -L /usr/lib --sysroot /tmp/sysroot" 98 | ENV CFLAGS="-L /lib -L /usr/lib -D PNG_ARM_NEON_OPT=0 --sysroot /tmp/sysroot" 99 | ENV CXXFLAGS="-L /lib -L /usr/lib -D PNG_ARM_NEON_OPT=0 --sysroot /tmp/sysroot" 100 | 101 | # Install GOCV 102 | ENV OPENCV_VERSION $OPENCV_VERSION 103 | RUN cd /tmp && \ 104 | curl -Lo opencv.zip https://github.com/opencv/opencv/archive/${OPENCV_VERSION}.zip && \ 105 | unzip -q opencv.zip && \ 106 | curl -Lo opencv_contrib.zip https://github.com/opencv/opencv_contrib/archive/${OPENCV_VERSION}.zip && \ 107 | unzip -q opencv_contrib.zip && \ 108 | rm opencv.zip opencv_contrib.zip && \ 109 | cd opencv-${OPENCV_VERSION} && \ 110 | mkdir build && cd build && \ 111 | cmake -D CMAKE_BUILD_TYPE=RELEASE \ 112 | -D CMAKE_INSTALL_PREFIX=/usr/local \ 113 | -D OPENCV_EXTRA_MODULES_PATH=../../opencv_contrib-${OPENCV_VERSION}/modules \ 114 | -D WITH_JASPER=OFF \ 115 | -D WITH_QT=OFF \ 116 | -D WITH_GTK=OFF \ 117 | -D WITH_IPP=OFF \ 118 | -D BUILD_DOCS=OFF \ 119 | -D BUILD_EXAMPLES=OFF \ 120 | -D BUILD_TESTS=OFF \ 
121 | -D BUILD_PERF_TESTS=OFF \ 122 | -D BUILD_opencv_java=NO \ 123 | -D BUILD_opencv_python=NO \ 124 | -D BUILD_opencv_python2=NO \ 125 | -D BUILD_opencv_python3=NO \ 126 | -D OPENCV_GENERATE_PKGCONFIG=ON .. && \ 127 | make -j $(nproc --all) && \ 128 | make preinstall && make install && \ 129 | cd /tmp && rm -rf opencv* 130 | 131 | # Fetch the edgetpu library locally 132 | ADD libedgetpu/out/throttled/aarch64/libedgetpu.so.1.0 /tmp/sysroot/usr/lib/libedgetpu.so.1.0 133 | RUN ln -rs /tmp/sysroot/usr/lib/libedgetpu.so.1.0 /tmp/sysroot/usr/lib/libedgetpu.so.1 && \ 134 | ln -rs /tmp/sysroot/usr/lib/libedgetpu.so.1.0 /tmp/sysroot/usr/lib/libedgetpu.so && \ 135 | mkdir -p /tmp/sysroot/usr/include/libedgetpu 136 | ADD libedgetpu/tflite/public/edgetpu.h /tmp/sysroot/usr/include/libedgetpu/edgetpu.h 137 | ADD libedgetpu/tflite/public/edgetpu_c.h /tmp/sysroot/usr/include/libedgetpu/edgetpu_c.h 138 | 139 | # Install Go 140 | ENV GO_VERSION $GO_VERSION 141 | ARG GO_ARCH="amd64" 142 | ENV GOOS=linux 143 | RUN curl -Lo go${GO_VERSION}.linux-$GO_ARCH.tar.gz https://dl.google.com/go/go${GO_VERSION}.linux-$GO_ARCH.tar.gz && \ 144 | tar -C /usr/local -xzf go${GO_VERSION}.linux-$GO_ARCH.tar.gz && \ 145 | rm go${GO_VERSION}.linux-$GO_ARCH.tar.gz 146 | ENV PATH /usr/local/go/bin:/go/bin:${PATH} 147 | ENV GOPATH /go 148 | 149 | # Start compile 150 | WORKDIR /build 151 | ADD . . 
152 | 153 | # Install/Compile tools 154 | ENV CGO_ENABLED=0 155 | RUN make tools 156 | 157 | # Compile DOODS 158 | ENV GOARCH=arm64 159 | ENV CGO_ENABLED=1 160 | ENV CGO_LDFLAGS="-v -L /lib -L /usr/lib -L /usr/local/lib --sysroot /tmp/sysroot -ledgetpu" 161 | ENV CGO_CFLAGS="-L /lib -L /usr/lib -L /usr/local/lib -I /opt/tensorflow --sysroot /tmp/sysroot" 162 | ENV CGO_CXXFLAGS="-L /lib -L /usr/lib -L /usr/local/lib -I /opt/tensorflow --sysroot /tmp/sysroot" 163 | RUN make 164 | 165 | # Start creating the new root directory 166 | WORKDIR /tmp/newroot 167 | RUN mkdir -p /tmp/newroot/lib && mkdir -p /tmp/newroot/usr/lib && \ 168 | cp -r /tmp/sysroot/lib/* /tmp/newroot/lib && \ 169 | cp -r /tmp/sysroot/usr/lib/* /tmp/newroot/usr/lib 170 | 171 | # Setup doods and sample config with models 172 | RUN mkdir -p /tmp/newroot/opt/doods && \ 173 | cp /build/doods /tmp/newroot/opt/doods/doods 174 | WORKDIR /tmp/newroot/opt/doods 175 | 176 | # Download sample models 177 | RUN mkdir models 178 | RUN wget https://storage.googleapis.com/download.tensorflow.org/models/tflite/coco_ssd_mobilenet_v1_1.0_quant_2018_06_29.zip && unzip coco_ssd_mobilenet_v1_1.0_quant_2018_06_29.zip && rm coco_ssd_mobilenet_v1_1.0_quant_2018_06_29.zip && mv detect.tflite models/coco_ssd_mobilenet_v1_1.0_quant.tflite && rm labelmap.txt 179 | RUN wget https://dl.google.com/coral/canned_models/coco_labels.txt && mv coco_labels.txt models/coco_labels0.txt 180 | RUN wget http://download.tensorflow.org/models/object_detection/faster_rcnn_inception_v2_coco_2018_01_28.tar.gz && tar -zxvf faster_rcnn_inception_v2_coco_2018_01_28.tar.gz faster_rcnn_inception_v2_coco_2018_01_28/frozen_inference_graph.pb --strip=1 --no-same-owner && mv frozen_inference_graph.pb models/faster_rcnn_inception_v2_coco_2018_01_28.pb && rm faster_rcnn_inception_v2_coco_2018_01_28.tar.gz 181 | RUN wget https://raw.githubusercontent.com/amikelive/coco-labels/master/coco-labels-2014_2017.txt && mv coco-labels-2014_2017.txt 
models/coco_labels1.txt 182 | ADD config.arm.yaml config.yaml 183 | 184 | FROM arm64v8/debian:buster-slim 185 | # Copy the pre-built root filesystem 186 | COPY --from=base /tmp/newroot/. /. 187 | COPY --from=base /usr/local/. /usr/local/. 188 | # Needed because we can't run ldconfig 189 | ENV LD_LIBRARY_PATH=/usr/local/lib 190 | 191 | WORKDIR /opt/doods 192 | CMD ["/opt/doods/doods", "-c", "/opt/doods/config.yaml", "api"] 193 | -------------------------------------------------------------------------------- /Dockerfile.base.cuda: -------------------------------------------------------------------------------- 1 | FROM nvidia/cuda:10.2-cudnn7-devel-ubuntu18.04 as builder 2 | 3 | # Install reqs with cross compile support 4 | RUN apt-get update && apt-get install -y --no-install-recommends \ 5 | pkg-config zip zlib1g-dev unzip wget bash-completion git curl \ 6 | build-essential patch g++ python python-future python-numpy python-six python3 \ 7 | cmake ca-certificates \ 8 | libc6-dev libstdc++6 libusb-1.0-0 9 | 10 | # Install protoc 11 | RUN wget https://github.com/protocolbuffers/protobuf/releases/download/v3.9.1/protoc-3.9.1-linux-x86_64.zip && \ 12 | unzip protoc-3.9.1-linux-x86_64.zip -d /usr/local && \ 13 | rm /usr/local/readme.txt && \ 14 | rm protoc-3.9.1-linux-x86_64.zip 15 | 16 | # Install bazel 17 | RUN wget https://github.com/bazelbuild/bazel/releases/download/0.27.1/bazel_0.27.1-linux-x86_64.deb && \ 18 | dpkg -i bazel_0.27.1-linux-x86_64.deb && \ 19 | rm bazel_0.27.1-linux-x86_64.deb 20 | 21 | # Download tensorflow sources 22 | ARG TF_VERSION="v2.1.0" 23 | ENV TF_VERSION $TF_VERSION 24 | RUN cd /opt && git clone https://github.com/tensorflow/tensorflow.git --branch $TF_VERSION --single-branch 25 | 26 | # Configure tensorflow 27 | ENV TF_NEED_GDR=0 TF_NEED_AWS=0 TF_NEED_GCP=0 TF_NEED_CUDA=1 TF_NEED_HDFS=0 TF_NEED_OPENCL_SYCL=0 TF_NEED_VERBS=0 TF_NEED_MPI=0 TF_NEED_MKL=0 TF_NEED_JEMALLOC=1 TF_ENABLE_XLA=0 TF_NEED_S3=0 TF_NEED_KAFKA=0 TF_NEED_IGNITE=0 
TF_NEED_ROCM=0 28 | RUN cd /opt/tensorflow && yes '' | ./configure 29 | 30 | # Tensorflow build flags for rpi 31 | ENV BAZEL_COPT_FLAGS="--local_resources 16000,16,1 --config monolithic --copt=-O3 --copt=-fomit-frame-pointer --copt=-mfpmath=both --copt=-mavx --copt=-msse4.2 --incompatible_no_support_tools_in_action_inputs=false --config=noaws --config=nohdfs" 32 | ENV BAZEL_EXTRA_FLAGS="" 33 | 34 | # Patch to make it work with cuda 10.2 35 | RUN cd /opt/tensorflow && sed -i '/"--bin2c-path=%s" % bin2c.dirname,/d' third_party/nccl/build_defs.bzl.tpl 36 | 37 | # Compile and build tensorflow lite 38 | RUN cd /opt/tensorflow && \ 39 | bazel build -c opt $BAZEL_COPT_FLAGS --verbose_failures $BAZEL_EXTRA_FLAGS //tensorflow/lite:libtensorflowlite.so && \ 40 | install bazel-bin/tensorflow/lite/libtensorflowlite.so /usr/local/lib/libtensorflowlite.so && \ 41 | bazel build -c opt $BAZEL_COPT_FLAGS --verbose_failures $BAZEL_EXTRA_FLAGS //tensorflow/lite/experimental/c:libtensorflowlite_c.so && \ 42 | install bazel-bin/tensorflow/lite/experimental/c/libtensorflowlite_c.so /usr/local/lib/libtensorflowlite_c.so && \ 43 | mkdir -p /usr/local/include/flatbuffers && cp bazel-tensorflow/external/flatbuffers/include/flatbuffers/* /usr/local/include/flatbuffers 44 | 45 | # Compile and install tensorflow shared library 46 | RUN cd /opt/tensorflow && \ 47 | bazel build -c opt $BAZEL_COPT_FLAGS --verbose_failures $BAZEL_EXTRA_FLAGS //tensorflow:libtensorflow.so && \ 48 | install bazel-bin/tensorflow/libtensorflow.so /usr/local/lib/libtensorflow.so && \ 49 | ln -rs /usr/local/lib/libtensorflow.so /usr/local/lib/libtensorflow.so.1 50 | 51 | # cleanup so the cache directory isn't huge 52 | RUN cd /opt/tensorflow && \ 53 | bazel clean && rm -Rf /root/.cache 54 | 55 | # Install GOCV 56 | ARG OPENCV_VERSION="4.5.0" 57 | ENV OPENCV_VERSION $OPENCV_VERSION 58 | RUN cd /tmp && \ 59 | curl -Lo opencv.zip https://github.com/opencv/opencv/archive/${OPENCV_VERSION}.zip && \ 60 | unzip -q opencv.zip 
&& \ 61 | curl -Lo opencv_contrib.zip https://github.com/opencv/opencv_contrib/archive/${OPENCV_VERSION}.zip && \ 62 | unzip -q opencv_contrib.zip && \ 63 | rm opencv.zip opencv_contrib.zip && \ 64 | cd opencv-${OPENCV_VERSION} && \ 65 | mkdir build && cd build && \ 66 | cmake -D CMAKE_BUILD_TYPE=RELEASE \ 67 | -D CMAKE_INSTALL_PREFIX=/usr/local \ 68 | -D OPENCV_EXTRA_MODULES_PATH=../../opencv_contrib-${OPENCV_VERSION}/modules \ 69 | -D WITH_JASPER=OFF \ 70 | -D WITH_QT=OFF \ 71 | -D WITH_GTK=OFF \ 72 | -D BUILD_DOCS=OFF \ 73 | -D BUILD_EXAMPLES=OFF \ 74 | -D BUILD_TESTS=OFF \ 75 | -D BUILD_PERF_TESTS=OFF \ 76 | -D BUILD_opencv_java=NO \ 77 | -D BUILD_opencv_python=NO \ 78 | -D BUILD_opencv_python2=NO \ 79 | -D BUILD_opencv_python3=NO \ 80 | -D OPENCV_GENERATE_PKGCONFIG=ON .. && \ 81 | make -j $(nproc --all) && \ 82 | make preinstall && make install && \ 83 | cd /tmp && rm -rf opencv* 84 | 85 | # Download the edgetpu library and install it 86 | RUN cd /tmp && git clone https://github.com/google-coral/edgetpu.git && \ 87 | install edgetpu/libedgetpu/throttled/k8/libedgetpu.so.1.0 /usr/local/lib/libedgetpu.so.1.0 && \ 88 | ln -rs /usr/local/lib/libedgetpu.so.1.0 /usr/local/lib/libedgetpu.so.1 && \ 89 | ln -rs /usr/local/lib/libedgetpu.so.1.0 /usr/local/lib/libedgetpu.so && \ 90 | mkdir -p /usr/local/include/libedgetpu && \ 91 | install edgetpu/libedgetpu/edgetpu.h /usr/local/include/libedgetpu/edgetpu.h && \ 92 | install edgetpu/libedgetpu/edgetpu_c.h /usr/local/include/libedgetpu/edgetpu_c.h && \ 93 | rm -Rf edgetpu 94 | 95 | # Configure the Go version to be used 96 | ENV GO_ARCH "amd64" 97 | ENV GOARCH=amd64 98 | 99 | # Install Go 100 | ENV GO_VERSION "1.14.2" 101 | RUN curl -kLo go${GO_VERSION}.linux-${GO_ARCH}.tar.gz https://dl.google.com/go/go${GO_VERSION}.linux-${GO_ARCH}.tar.gz && \ 102 | tar -C /usr/local -xzf go${GO_VERSION}.linux-${GO_ARCH}.tar.gz && \ 103 | rm go${GO_VERSION}.linux-${GO_ARCH}.tar.gz 104 | 105 | FROM 
nvidia/cuda:10.2-cudnn7-runtime-ubuntu18.04 as build 106 | 107 | RUN apt-get update && apt-get install -y --no-install-recommends \ 108 | pkg-config zip zlib1g-dev unzip wget bash-completion git curl \ 109 | build-essential patch g++ python python-future python3 ca-certificates \ 110 | libc6-dev libstdc++6 libusb-1.0-0 111 | 112 | # Copy all libraries, includes and go 113 | COPY --from=builder /usr/local/. /usr/local/. 114 | COPY --from=builder /opt/tensorflow /opt/tensorflow 115 | 116 | ENV GOOS=linux 117 | ENV CGO_ENABLED=1 118 | ENV CGO_CFLAGS=-I/opt/tensorflow 119 | ENV PATH /usr/local/go/bin:/go/bin:${PATH} 120 | ENV GOPATH /go 121 | 122 | # Create the build directory 123 | RUN mkdir /build 124 | WORKDIR /build 125 | ADD . . 126 | RUN make 127 | 128 | FROM nvidia/cuda:10.2-cudnn7-runtime-ubuntu18.04 129 | 130 | RUN apt-get update && \ 131 | apt-get install -y --no-install-recommends libusb-1.0 libc++-7-dev wget unzip ca-certificates libdc1394-22 libavcodec57 libavformat57 && \ 132 | apt-get clean && \ 133 | rm -rf /var/lib/apt/lists/* 134 | RUN mkdir -p /opt/doods 135 | WORKDIR /opt/doods 136 | COPY --from=build /usr/local/lib/. /usr/local/lib/. 
137 | COPY --from=build /build/doods /opt/doods/doods 138 | ADD config.yaml /opt/doods/config.yaml 139 | ENV LD_LIBRARY_PATH=/usr/local/cuda-10.2/compat 140 | RUN ldconfig 141 | 142 | # Download sample models 143 | RUN mkdir models 144 | RUN wget https://storage.googleapis.com/download.tensorflow.org/models/tflite/coco_ssd_mobilenet_v1_1.0_quant_2018_06_29.zip && unzip coco_ssd_mobilenet_v1_1.0_quant_2018_06_29.zip && rm coco_ssd_mobilenet_v1_1.0_quant_2018_06_29.zip && mv detect.tflite models/coco_ssd_mobilenet_v1_1.0_quant.tflite && rm labelmap.txt 145 | RUN wget https://dl.google.com/coral/canned_models/coco_labels.txt && mv coco_labels.txt models/coco_labels0.txt 146 | RUN wget http://download.tensorflow.org/models/object_detection/faster_rcnn_inception_v2_coco_2018_01_28.tar.gz && tar -zxvf faster_rcnn_inception_v2_coco_2018_01_28.tar.gz faster_rcnn_inception_v2_coco_2018_01_28/frozen_inference_graph.pb --strip=1 --no-same-owner && mv frozen_inference_graph.pb models/faster_rcnn_inception_v2_coco_2018_01_28.pb && rm faster_rcnn_inception_v2_coco_2018_01_28.tar.gz 147 | RUN wget https://raw.githubusercontent.com/amikelive/coco-labels/master/coco-labels-2014_2017.txt && mv coco-labels-2014_2017.txt models/coco_labels1.txt 148 | 149 | CMD ["/opt/doods/doods", "-c", "/opt/doods/config.yaml", "api"] 150 | -------------------------------------------------------------------------------- /Dockerfile.builder: -------------------------------------------------------------------------------- 1 | FROM ubuntu:18.04 as base 2 | 3 | # Install reqs with cross compile support 4 | RUN apt-get update && apt-get install -y --no-install-recommends \ 5 | pkg-config zip zlib1g-dev unzip wget bash-completion git curl \ 6 | build-essential patch g++ python python-future python-numpy python-six python3 \ 7 | cmake ca-certificates \ 8 | libc6-dev libstdc++6 libusb-1.0-0 9 | 10 | # Install protoc 11 | RUN wget 
https://github.com/protocolbuffers/protobuf/releases/download/v3.12.3/protoc-3.12.3-linux-x86_64.zip && \ 12 | unzip protoc-3.12.3-linux-x86_64.zip -d /usr/local && \ 13 | rm /usr/local/readme.txt && \ 14 | rm protoc-3.12.3-linux-x86_64.zip 15 | 16 | # Version Configuration 17 | ARG BAZEL_VERSION="2.0.0" 18 | ARG TF_VERSION="f394a768719a55b5c351ed1ecab2ec6f16f99dd4" 19 | ARG OPENCV_VERSION="4.5.0" 20 | ARG GO_VERSION="1.14.3" 21 | 22 | # Install bazel 23 | ENV BAZEL_VERSION $BAZEL_VERSION 24 | RUN wget https://github.com/bazelbuild/bazel/releases/download/${BAZEL_VERSION}/bazel_${BAZEL_VERSION}-linux-x86_64.deb && \ 25 | dpkg -i bazel_${BAZEL_VERSION}-linux-x86_64.deb && \ 26 | rm bazel_${BAZEL_VERSION}-linux-x86_64.deb 27 | 28 | # Download tensorflow sources 29 | ENV TF_VERSION $TF_VERSION 30 | #RUN cd /opt && git clone https://github.com/tensorflow/tensorflow.git --branch $TF_VERSION --single-branch 31 | RUN cd /opt && git clone https://github.com/tensorflow/tensorflow.git && cd /opt/tensorflow && git checkout ${TF_VERSION} 32 | 33 | # Configure tensorflow 34 | ENV TF_NEED_GDR=0 TF_NEED_AWS=0 TF_NEED_GCP=0 TF_NEED_CUDA=0 TF_NEED_HDFS=0 TF_NEED_OPENCL_SYCL=0 TF_NEED_VERBS=0 TF_NEED_MPI=0 TF_NEED_MKL=0 TF_NEED_JEMALLOC=1 TF_ENABLE_XLA=0 TF_NEED_S3=0 TF_NEED_KAFKA=0 TF_NEED_IGNITE=0 TF_NEED_ROCM=0 35 | RUN cd /opt/tensorflow && yes '' | ./configure 36 | 37 | # Tensorflow build flags 38 | ENV BAZEL_COPT_FLAGS="--local_resources 16000,16,1 -c opt --config monolithic --copt=-march=native --copt=-O3 --copt=-fomit-frame-pointer --incompatible_no_support_tools_in_action_inputs=false --config=noaws --config=nohdfs" 39 | ENV BAZEL_EXTRA_FLAGS="--host_linkopt=-lm" 40 | 41 | # Compile and build tensorflow lite 42 | RUN cd /opt/tensorflow && \ 43 | bazel build -c opt $BAZEL_COPT_FLAGS --verbose_failures $BAZEL_EXTRA_FLAGS //tensorflow/lite:libtensorflowlite.so && \ 44 | install bazel-bin/tensorflow/lite/libtensorflowlite.so /usr/local/lib/libtensorflowlite.so && \ 45 | bazel 
build -c opt $BAZEL_COPT_FLAGS --verbose_failures $BAZEL_EXTRA_FLAGS //tensorflow/lite/c:libtensorflowlite_c.so && \ 46 | install bazel-bin/tensorflow/lite/c/libtensorflowlite_c.so /usr/local/lib/libtensorflowlite_c.so && \ 47 | mkdir -p /usr/local/include/flatbuffers && cp bazel-tensorflow/external/flatbuffers/include/flatbuffers/* /usr/local/include/flatbuffers 48 | 49 | # Compile and install tensorflow shared library 50 | RUN cd /opt/tensorflow && \ 51 | bazel build -c opt $BAZEL_COPT_FLAGS --verbose_failures $BAZEL_EXTRA_FLAGS //tensorflow:libtensorflow.so && \ 52 | install bazel-bin/tensorflow/libtensorflow.so /usr/local/lib/libtensorflow.so && \ 53 | ln -rs /usr/local/lib/libtensorflow.so /usr/local/lib/libtensorflow.so.1 && \ 54 | ln -rs /usr/local/lib/libtensorflow.so /usr/local/lib/libtensorflow.so.2 55 | 56 | # cleanup so the cache directory isn't huge 57 | RUN cd /opt/tensorflow && \ 58 | bazel clean && rm -Rf /root/.cache 59 | 60 | # Install GOCV 61 | ENV OPENCV_VERSION $OPENCV_VERSION 62 | RUN cd /tmp && \ 63 | curl -Lo opencv.zip https://github.com/opencv/opencv/archive/${OPENCV_VERSION}.zip && \ 64 | unzip -q opencv.zip && \ 65 | curl -Lo opencv_contrib.zip https://github.com/opencv/opencv_contrib/archive/${OPENCV_VERSION}.zip && \ 66 | unzip -q opencv_contrib.zip && \ 67 | rm opencv.zip opencv_contrib.zip && \ 68 | cd opencv-${OPENCV_VERSION} && \ 69 | mkdir build && cd build && \ 70 | cmake -D CMAKE_BUILD_TYPE=RELEASE \ 71 | -D CMAKE_INSTALL_PREFIX=/usr/local \ 72 | -D OPENCV_EXTRA_MODULES_PATH=../../opencv_contrib-${OPENCV_VERSION}/modules \ 73 | -D WITH_JASPER=OFF \ 74 | -D WITH_QT=OFF \ 75 | -D WITH_GTK=OFF \ 76 | -D BUILD_DOCS=OFF \ 77 | -D BUILD_EXAMPLES=OFF \ 78 | -D BUILD_TESTS=OFF \ 79 | -D BUILD_PERF_TESTS=OFF \ 80 | -D BUILD_opencv_java=NO \ 81 | -D BUILD_opencv_python=NO \ 82 | -D BUILD_opencv_python2=NO \ 83 | -D BUILD_opencv_python3=NO \ 84 | -D OPENCV_GENERATE_PKGCONFIG=ON .. 
&& \ 85 | make -j $(nproc --all) && \ 86 | make preinstall && make install && \ 87 | cd /tmp && rm -rf opencv* 88 | 89 | # Fetch the edgetpu library locally 90 | ADD libedgetpu/out/throttled/k8/libedgetpu.so.1.0 /usr/local/lib/libedgetpu.so.1.0 91 | RUN ln -rs /usr/local/lib/libedgetpu.so.1.0 /usr/local/lib/libedgetpu.so.1 && \ 92 | ln -rs /usr/local/lib/libedgetpu.so.1.0 /usr/local/lib/libedgetpu.so && \ 93 | mkdir -p /usr/local/include/libedgetpu 94 | ADD libedgetpu/tflite/public/edgetpu.h /usr/local/include/libedgetpu/edgetpu.h 95 | ADD libedgetpu/tflite/public/edgetpu_c.h /usr/local/include/libedgetpu/edgetpu_c.h 96 | 97 | # Configure the Go version to be used 98 | ENV GO_ARCH "amd64" 99 | ENV GOARCH=amd64 100 | 101 | # Install Go 102 | ENV GO_VERSION $GO_VERSION 103 | RUN curl -kLo go${GO_VERSION}.linux-${GO_ARCH}.tar.gz https://dl.google.com/go/go${GO_VERSION}.linux-${GO_ARCH}.tar.gz && \ 104 | tar -C /usr/local -xzf go${GO_VERSION}.linux-${GO_ARCH}.tar.gz && \ 105 | rm go${GO_VERSION}.linux-${GO_ARCH}.tar.gz 106 | 107 | FROM ubuntu:18.04 as builder 108 | 109 | RUN apt-get update && apt-get install -y --no-install-recommends \ 110 | pkg-config zip zlib1g-dev unzip wget bash-completion git curl \ 111 | build-essential patch g++ python python-future python3 ca-certificates \ 112 | libc6-dev libstdc++6 libusb-1.0-0 113 | 114 | # Copy all libraries, includes and go 115 | COPY --from=base /usr/local/. /usr/local/. 
116 | COPY --from=base /opt/tensorflow /opt/tensorflow 117 | 118 | ENV GOOS=linux 119 | ENV CGO_ENABLED=1 120 | ENV CGO_CFLAGS=-I/opt/tensorflow 121 | ENV PATH /usr/local/go/bin:/go/bin:${PATH} 122 | ENV GOPATH /go 123 | 124 | # Switch to /build 125 | WORKDIR /build 126 | -------------------------------------------------------------------------------- /Dockerfile.noavx: -------------------------------------------------------------------------------- 1 | FROM ubuntu:18.04 as base 2 | 3 | # Install reqs with cross compile support 4 | RUN apt-get update && apt-get install -y --no-install-recommends \ 5 | pkg-config zip zlib1g-dev unzip wget bash-completion git curl \ 6 | build-essential patch g++ python python-future python-numpy python-six python3 \ 7 | cmake ca-certificates \ 8 | libc6-dev libstdc++6 libusb-1.0-0 9 | 10 | # Install protoc 11 | RUN wget https://github.com/protocolbuffers/protobuf/releases/download/v3.12.3/protoc-3.12.3-linux-x86_64.zip && \ 12 | unzip protoc-3.12.3-linux-x86_64.zip -d /usr/local && \ 13 | rm /usr/local/readme.txt && \ 14 | rm protoc-3.12.3-linux-x86_64.zip 15 | 16 | # Version Configuration 17 | ARG BAZEL_VERSION="2.0.0" 18 | ARG TF_VERSION="f394a768719a55b5c351ed1ecab2ec6f16f99dd4" 19 | ARG OPENCV_VERSION="4.5.0" 20 | ARG GO_VERSION="1.14.3" 21 | 22 | # Install bazel 23 | ENV BAZEL_VERSION $BAZEL_VERSION 24 | RUN wget https://github.com/bazelbuild/bazel/releases/download/${BAZEL_VERSION}/bazel_${BAZEL_VERSION}-linux-x86_64.deb && \ 25 | dpkg -i bazel_${BAZEL_VERSION}-linux-x86_64.deb && \ 26 | rm bazel_${BAZEL_VERSION}-linux-x86_64.deb 27 | 28 | # Download tensorflow sources 29 | ENV TF_VERSION $TF_VERSION 30 | #RUN cd /opt && git clone https://github.com/tensorflow/tensorflow.git --branch $TF_VERSION --single-branch 31 | RUN cd /opt && git clone https://github.com/tensorflow/tensorflow.git && cd /opt/tensorflow && git checkout ${TF_VERSION} 32 | 33 | # Configure tensorflow 34 | ENV TF_NEED_GDR=0 TF_NEED_AWS=0 TF_NEED_GCP=0 
TF_NEED_CUDA=0 TF_NEED_HDFS=0 TF_NEED_OPENCL_SYCL=0 TF_NEED_VERBS=0 TF_NEED_MPI=0 TF_NEED_MKL=0 TF_NEED_JEMALLOC=1 TF_ENABLE_XLA=0 TF_NEED_S3=0 TF_NEED_KAFKA=0 TF_NEED_IGNITE=0 TF_NEED_ROCM=0 35 | RUN cd /opt/tensorflow && yes '' | ./configure 36 | 37 | # Tensorflow build flags 38 | ENV BAZEL_COPT_FLAGS="--local_resources 16000,16,1 --config monolithic --copt=-O3 --copt=-fomit-frame-pointer --config=noaws --config=nohdfs" 39 | ENV BAZEL_EXTRA_FLAGS="--host_linkopt=-lm" 40 | 41 | # Compile and build tensorflow lite 42 | RUN cd /opt/tensorflow && \ 43 | bazel build -c opt $BAZEL_COPT_FLAGS --verbose_failures $BAZEL_EXTRA_FLAGS //tensorflow/lite:libtensorflowlite.so && \ 44 | install bazel-bin/tensorflow/lite/libtensorflowlite.so /usr/local/lib/libtensorflowlite.so && \ 45 | bazel build -c opt $BAZEL_COPT_FLAGS --verbose_failures $BAZEL_EXTRA_FLAGS //tensorflow/lite/c:libtensorflowlite_c.so && \ 46 | install bazel-bin/tensorflow/lite/c/libtensorflowlite_c.so /usr/local/lib/libtensorflowlite_c.so && \ 47 | mkdir -p /usr/local/include/flatbuffers && cp bazel-tensorflow/external/flatbuffers/include/flatbuffers/* /usr/local/include/flatbuffers 48 | 49 | # Compile and install tensorflow shared library 50 | RUN cd /opt/tensorflow && \ 51 | bazel build -c opt $BAZEL_COPT_FLAGS --verbose_failures $BAZEL_EXTRA_FLAGS //tensorflow:libtensorflow.so && \ 52 | install bazel-bin/tensorflow/libtensorflow.so /usr/local/lib/libtensorflow.so && \ 53 | ln -rs /usr/local/lib/libtensorflow.so /usr/local/lib/libtensorflow.so.1 && \ 54 | ln -rs /usr/local/lib/libtensorflow.so /usr/local/lib/libtensorflow.so.2 55 | 56 | # cleanup so the cache directory isn't huge 57 | RUN cd /opt/tensorflow && \ 58 | bazel clean && rm -Rf /root/.cache 59 | 60 | # Install GOCV 61 | ENV OPENCV_VERSION $OPENCV_VERSION 62 | RUN cd /tmp && \ 63 | curl -Lo opencv.zip https://github.com/opencv/opencv/archive/${OPENCV_VERSION}.zip && \ 64 | unzip -q opencv.zip && \ 65 | curl -Lo opencv_contrib.zip 
https://github.com/opencv/opencv_contrib/archive/${OPENCV_VERSION}.zip && \ 66 | unzip -q opencv_contrib.zip && \ 67 | rm opencv.zip opencv_contrib.zip && \ 68 | cd opencv-${OPENCV_VERSION} && \ 69 | mkdir build && cd build && \ 70 | cmake -D CMAKE_BUILD_TYPE=RELEASE \ 71 | -D CMAKE_INSTALL_PREFIX=/usr/local \ 72 | -D OPENCV_EXTRA_MODULES_PATH=../../opencv_contrib-${OPENCV_VERSION}/modules \ 73 | -D WITH_JASPER=OFF \ 74 | -D WITH_QT=OFF \ 75 | -D WITH_GTK=OFF \ 76 | -D BUILD_DOCS=OFF \ 77 | -D BUILD_EXAMPLES=OFF \ 78 | -D BUILD_TESTS=OFF \ 79 | -D BUILD_PERF_TESTS=OFF \ 80 | -D BUILD_opencv_java=NO \ 81 | -D BUILD_opencv_python=NO \ 82 | -D BUILD_opencv_python2=NO \ 83 | -D BUILD_opencv_python3=NO \ 84 | -D OPENCV_GENERATE_PKGCONFIG=ON .. && \ 85 | make -j $(nproc --all) && \ 86 | make preinstall && make install && \ 87 | cd /tmp && rm -rf opencv* 88 | 89 | # Fetch the edgetpu library locally 90 | ADD libedgetpu/out/throttled/k8/libedgetpu.so.1.0 /usr/local/lib/libedgetpu.so.1.0 91 | RUN ln -rs /usr/local/lib/libedgetpu.so.1.0 /usr/local/lib/libedgetpu.so.1 && \ 92 | ln -rs /usr/local/lib/libedgetpu.so.1.0 /usr/local/lib/libedgetpu.so && \ 93 | mkdir -p /usr/local/include/libedgetpu 94 | ADD libedgetpu/tflite/public/edgetpu.h /usr/local/include/libedgetpu/edgetpu.h 95 | ADD libedgetpu/tflite/public/edgetpu_c.h /usr/local/include/libedgetpu/edgetpu_c.h 96 | 97 | # Configure the Go version to be used 98 | ENV GO_ARCH "amd64" 99 | ENV GOARCH=amd64 100 | 101 | # Install Go 102 | ENV GO_VERSION $GO_VERSION 103 | RUN curl -kLo go${GO_VERSION}.linux-${GO_ARCH}.tar.gz https://dl.google.com/go/go${GO_VERSION}.linux-${GO_ARCH}.tar.gz && \ 104 | tar -C /usr/local -xzf go${GO_VERSION}.linux-${GO_ARCH}.tar.gz && \ 105 | rm go${GO_VERSION}.linux-${GO_ARCH}.tar.gz 106 | 107 | FROM ubuntu:18.04 as builder 108 | 109 | RUN apt-get update && apt-get install -y --no-install-recommends \ 110 | pkg-config zip zlib1g-dev unzip wget bash-completion git curl \ 111 | build-essential patch 
g++ python python-future python3 ca-certificates \ 112 | libc6-dev libstdc++6 libusb-1.0-0 113 | 114 | # Copy all libraries, includes and go 115 | COPY --from=base /usr/local/. /usr/local/. 116 | COPY --from=base /opt/tensorflow /opt/tensorflow 117 | 118 | ENV GOOS=linux 119 | ENV CGO_ENABLED=1 120 | ENV CGO_CFLAGS=-I/opt/tensorflow 121 | ENV PATH /usr/local/go/bin:/go/bin:${PATH} 122 | ENV GOPATH /go 123 | 124 | # Create the build directory 125 | RUN mkdir /build 126 | WORKDIR /build 127 | ADD . . 128 | RUN make 129 | 130 | FROM ubuntu:18.04 131 | 132 | RUN apt-get update && \ 133 | apt-get install -y --no-install-recommends libusb-1.0 libc++-7-dev wget unzip ca-certificates libdc1394-22 libavcodec57 libavformat57 && \ 134 | apt-get clean && \ 135 | rm -rf /var/lib/apt/lists/* 136 | RUN mkdir -p /opt/doods 137 | WORKDIR /opt/doods 138 | COPY --from=builder /usr/local/lib/. /usr/local/lib/. 139 | COPY --from=builder /build/doods /opt/doods/doods 140 | RUN ldconfig 141 | 142 | # Download sample models 143 | RUN mkdir models 144 | RUN wget https://storage.googleapis.com/download.tensorflow.org/models/tflite/coco_ssd_mobilenet_v1_1.0_quant_2018_06_29.zip && unzip coco_ssd_mobilenet_v1_1.0_quant_2018_06_29.zip && rm coco_ssd_mobilenet_v1_1.0_quant_2018_06_29.zip && mv detect.tflite models/coco_ssd_mobilenet_v1_1.0_quant.tflite && rm labelmap.txt 145 | RUN wget https://dl.google.com/coral/canned_models/coco_labels.txt && mv coco_labels.txt models/coco_labels0.txt 146 | RUN wget http://download.tensorflow.org/models/object_detection/faster_rcnn_inception_v2_coco_2018_01_28.tar.gz && tar -zxvf faster_rcnn_inception_v2_coco_2018_01_28.tar.gz faster_rcnn_inception_v2_coco_2018_01_28/frozen_inference_graph.pb --strip=1 --no-same-owner && mv frozen_inference_graph.pb models/faster_rcnn_inception_v2_coco_2018_01_28.pb && rm faster_rcnn_inception_v2_coco_2018_01_28.tar.gz 147 | RUN wget https://raw.githubusercontent.com/amikelive/coco-labels/master/coco-labels-2014_2017.txt && 
mv coco-labels-2014_2017.txt models/coco_labels1.txt 148 | ADD config.yaml config.yaml 149 | 150 | CMD ["/opt/doods/doods", "-c", "/opt/doods/config.yaml", "api"] 151 | -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright (c) 2018 Zach Brown 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 6 | 7 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 8 | 9 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
-------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | EXECUTABLE := doods 2 | GITVERSION := $(shell git describe --dirty --always --tags --long) 3 | GOPATH ?= ${HOME}/go 4 | TAG ?= latest 5 | PACKAGENAME := $(shell go list -m -f '{{.Path}}') 6 | TOOLS := ${GOPATH}/src/github.com/gogo/protobuf/proto \ 7 | ${GOPATH}/bin/protoc-gen-gogoslick \ 8 | ${GOPATH}/bin/protoc-gen-grpc-gateway \ 9 | ${GOPATH}/bin/protoc-gen-swagger 10 | export PROTOBUF_INCLUDES = -I. -I/usr/include -I${GOPATH}/src -I$(shell go list -e -f '{{.Dir}}' .) -I$(shell go list -e -f '{{.Dir}}' github.com/grpc-ecosystem/grpc-gateway/runtime)/../third_party/googleapis 11 | PROTOS := ./server/rpc/version.pb.gw.go \ 12 | ./odrpc/rpc.pb.gw.go 13 | 14 | .PHONY: default 15 | default: ${EXECUTABLE} 16 | 17 | # This is all the tools required to compile, test and handle protobufs 18 | tools: ${TOOLS} 19 | 20 | ${GOPATH}/src/github.com/gogo/protobuf/proto: 21 | GO111MODULE=off go get github.com/gogo/protobuf/proto 22 | 23 | ${GOPATH}/bin/protoc-gen-gogoslick: 24 | go get github.com/gogo/protobuf/protoc-gen-gogoslick 25 | 26 | ${GOPATH}/bin/protoc-gen-grpc-gateway: 27 | go get github.com/grpc-ecosystem/grpc-gateway/protoc-gen-grpc-gateway 28 | 29 | ${GOPATH}/bin/protoc-gen-swagger: 30 | go get github.com/grpc-ecosystem/grpc-gateway/protoc-gen-swagger 31 | 32 | # Handle all grpc endpoint protobufs 33 | %.pb.gw.go: %.proto 34 | protoc ${PROTOBUF_INCLUDES} --gogoslick_out=paths=source_relative,plugins=grpc:. --grpc-gateway_out=paths=source_relative,logtostderr=true:. --swagger_out=logtostderr=true:. $*.proto 35 | 36 | # Handle any non-specific protobufs 37 | %.pb.go: %.proto 38 | protoc ${PROTOBUF_INCLUDES} --gogoslick_out=paths=source_relative,plugins=grpc:. $*.proto 39 | 40 | .PHONY: ${EXECUTABLE} 41 | ${EXECUTABLE}: tools ${PROTOS} 42 | # Compiling... 
43 | go build -ldflags "-X ${PACKAGENAME}/conf.Executable=${EXECUTABLE} -X ${PACKAGENAME}/conf.GitVersion=${GITVERSION}" -o ${EXECUTABLE} 44 | 45 | .PHONY: test 46 | test: tools ${PROTOS} 47 | go test -cover ./... 48 | 49 | deps: 50 | # Fetching dependancies... 51 | go get -d -v # Adding -u here will break CI 52 | 53 | docker: 54 | docker build -t docker.io/snowzach/doods:local -f Dockerfile . 55 | 56 | docker-images: docker-noavx docker-amd64 docker-arm32 docker-arm64 57 | docker manifest push --purge snowzach/doods:latest 58 | docker manifest create snowzach/doods:latest snowzach/doods:noavx snowzach/doods:arm32 snowzach/doods:arm64 59 | docker manifest push snowzach/doods:latest 60 | 61 | .PHONY: docker-noavx 62 | docker-noavx: 63 | docker build -t docker.io/snowzach/doods:noavx -f Dockerfile.noavx . 64 | docker push docker.io/snowzach/doods:noavx 65 | 66 | .PHONY: docker-amd64 67 | docker-amd64: 68 | docker build -t docker.io/snowzach/doods:amd64 -f Dockerfile.amd64 . 69 | docker push docker.io/snowzach/doods:amd64 70 | 71 | .PHONY: docker-arm32 72 | docker-arm32: 73 | docker build -t docker.io/snowzach/doods:arm32 -f Dockerfile.arm32 . 74 | docker push docker.io/snowzach/doods:arm32 75 | 76 | .PHONY: docker-arm64 77 | docker-arm64: 78 | docker build -t docker.io/snowzach/doods:arm64 -f Dockerfile.arm64 . 79 | docker push docker.io/snowzach/doods:arm64 80 | 81 | .PHONY: docker-builder 82 | docker-builder: 83 | docker build -t docker.io/snowzach/doods:builder -f Dockerfile.builder . 84 | 85 | .PHONY: libedgetpu 86 | libedgetpu: 87 | git clone https://github.com/google-coral/libedgetpu || true 88 | bash -c 'cd libedgetpu; DOCKER_CPUS="k8 armv7a aarch64" DOCKER_TARGETS=libedgetpu make docker-build' 89 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DEPRECATED!!! 2 | DOODS is now deprecated in favor of DOODS2... 
Now with more Python... 3 | 4 | https://github.com/snowzach/doods2 5 | 6 | # DOODS 7 | Dedicated Open Object Detection Service - Yes, it's a backronym... 8 | 9 | DOODS is a GRPC/REST service that detects objects in images. It's designed to be very easy to use, run as a container and available remotely. 10 | 11 | ## API 12 | The API uses gRPC to communicate but it has a REST gateway built in for ease of use. It supports both a single call RPC and a streaming interface. 13 | It supports very basic pre-shared key authentication if you wish to protect it. It also supports TLS encryption but is disabled by default. 14 | It uses the content-type header to automatically determine if you are connecting in REST mode or GRPC mode. It listens on port 8080 by default. 15 | 16 | ### GRPC Endpoints 17 | The protobuf API definitations are in the `odrpc/odrpc.proto` file. There are 3 endpoints. 18 | 19 | - GetDetector - Get the list of configured detectors. 20 | - Detect - Detect objects in an image - Data should be passed as raw bytes in GRPC. 21 | - DetectStream - Detect objects in a stream of images 22 | 23 | ### REST/JSON 24 | The services are available via rest API at these endpoints 25 | * `GET /version` - Get the version 26 | * `GET /detectors` - Get the list of configured detectors 27 | * `POST /detect` - Detect objects in an image 28 | 29 | For `POST /detect` it expects JSON in the following format. 30 | ``` 31 | { 32 | "detector_name": "default", 33 | "data": "", 34 | "file": " 35 | "detect": { 36 | "*": 50 37 | } 38 | } 39 | ``` 40 | 41 | The result is returned as: 42 | ``` 43 | { 44 | "id": "test", 45 | "detections": [ 46 | { 47 | "top": 0, 48 | "left": 0.05, 49 | "bottom": .8552, 50 | "right": 0.9441, 51 | "label": "person", 52 | "confidence": 87.890625 53 | } 54 | ] 55 | } 56 | ``` 57 | 58 | You can specify regions for specific detections: 59 | For `POST /detect` it expects JSON in the following format. 
If you specify covers than the detection region must completely cover 60 | the region you specify. If covers is false, if any detection is inside any part of the region it will trigger. 61 | ``` 62 | { 63 | "detector_name": "default", 64 | "data": "", 65 | "file": "", 66 | "regions": [ 67 | { 68 | "top": 0, 69 | "left": 0, 70 | "bottom": 1, 71 | "right": 1, 72 | "detect": { 73 | "person": 50, 74 | "*": 90 75 | }, 76 | "covers": true 77 | } 78 | ] 79 | } 80 | ``` 81 | 82 | This will perform a detection using the detector called default. (If omitted, it will use one called default if it exists) 83 | The `data`, when using the REST interface is base64 encoded image data. DOODS can decode png, bmp and jpg. 84 | You can also pass `file` in place of data to read the file from the machine DOODS is running on. `file` will override data. 85 | The `detect` object allows you to specify the list of objects to detect as defined in the labels file. You can give a min percentage match. 86 | You can also use "*" which will match anything with a minimum percentage. 87 | 88 | Example 1-Liner to call the API using curl with image data: 89 | ``` 90 | echo "{\"detector_name\":\"default\", \"detect\":{\"*\":60}, \"data\":\"`cat grace_hopper.png|base64 -w0`\"}" > /tmp/postdata.json && curl -d@/tmp/postdata.json -H "Content-Type: application/json" -X POST http://localhost:8080/detect 91 | ``` 92 | 93 | Another example 1-Liner specifying a region: 94 | ``` 95 | echo "{\"detector_name\":\"default\", \"regions\":[{\"top\":0,\"left\":0,\"bottom\":1,\"right\":1,\"detect\":{\"person\":40}}], \"data\":\"`cat grace_hopper.png|base64 -w0`\"}" > /tmp/postdata.json && curl -d@/tmp/postdata.json -H "Content-Type: application/json" -X POST http://localhost:8087/detect 96 | ``` 97 | 98 | ## Detectors 99 | You should optimally pass image data in the requested size for the detector. If not, it will be automatically resized. 100 | It can read BMP, PNG and JPG as well as PPM. 
For detectors that do not specify a size (inception) you do not need to resize 101 | 102 | ### TFLite 103 | If you pass PPM image data in the right dimensions, it can be fed directly into tensorflow lite. This skips a couple steps for speed. 104 | You can also specify `hwAccel: true` in the config and it will enable Coral EdgeTPU hardware acceleration. 105 | You must also provide it an appropriate EdgeTPU model file. There are none included with the base image. 106 | 107 | ## Compiling 108 | This is designed as a go module aware program and thus requires go 1.12 or better. It also relies heavily on CGO. The easiest way to compile it 109 | is to use the Dockerfile which will build a functioning docker image. It's a little large but it includes 2 models. 110 | 111 | ## Configuration 112 | The configuration can be specified in a number of ways. By default you can create a json file and call it with the -c option 113 | you can also specify environment variables that align with the config file values. 
114 | 115 | Example: 116 | ```json 117 | { 118 | "logger": { 119 | "level": "debug" 120 | } 121 | } 122 | ``` 123 | Can be set via an environment variable: 124 | ``` 125 | LOGGER_LEVEL=debug 126 | ``` 127 | 128 | ### Options: 129 | | Setting | Description | Default | 130 | | ------------------------- | --------------------------------------------------- | ------------ | 131 | | logger.level | The default logging level | "info" | 132 | | logger.encoding | Logging format (console or json) | "console" | 133 | | logger.color | Enable color in console mode | true | 134 | | logger.disable_caller | Hide the caller source file and line number | false | 135 | | logger.disable_stacktrace | Hide a stacktrace on debug logs | true | 136 | | --- | --- | --- | 137 | | server.host | The host address to listen on (blank=all addresses) | "" | 138 | | server.port | The port number to listen on | 8080 | 139 | | server.tls | Enable https/tls | false | 140 | | server.devcert | Generate a development cert | false | 141 | | server.certfile | The HTTPS/TLS server certificate | "server.crt" | 142 | | server.keyfile | The HTTPS/TLS server key file | "server.key" | 143 | | server.log_requests | Log API requests | true | 144 | | server.profiler_enabled | Enable the profiler | false | 145 | | server.profiler_path | Where should the profiler be available | "/debug" | 146 | | --- | --- | --- | 147 | | pidfile | Write a pidfile (only if specified) | "" | 148 | | profiler.enabled | Enable the debug pprof interface | "false" | 149 | | profiler.host | The profiler host address to listen on | "" | 150 | | profiler.port | The profiler port to listen on | "6060" | 151 | | --- | --- | --- | 152 | | doods.auth_key | A pre-shared auth key. Disabled if blank | "" | 153 | | doods.detectors | The detector configurations | | 154 | 155 | ### TLS/HTTPS 156 | You can enable https by setting the config option server.tls = true and pointing it to your keyfile and certfile. 
157 | To create a self-signed cert: `openssl req -new -newkey rsa:2048 -days 3650 -nodes -x509 -keyout server.key -out server.crt` 158 | You will need to mount these in the container and adjust the config to find them. 159 | 160 | ### Detector Config 161 | Detector config must be done with a configuration file. The default config includes one Tensorflow Lite mobilenet detector and the Tensorflow Inception model. 162 | This is the default config with the exception of the threads and concurrent are tuned a bit for the architecture they are running on. 163 | ``` 164 | doods: 165 | detectors: 166 | - name: default 167 | type: tflite 168 | modelFile: models/coco_ssd_mobilenet_v1_1.0_quant.tflite 169 | labelFile: models/coco_labels0.txt 170 | numThreads: 4 171 | numConcurrent: 4 172 | hwAccel: false 173 | timeout: 2m 174 | - name: tensorflow 175 | type: tensorflow 176 | modelFile: models/faster_rcnn_inception_v2_coco_2018_01_28.pb 177 | labelFile: models/coco_labels1.txt 178 | numThreads: 4 179 | numConcurrent: 4 180 | hwAccel: false 181 | timeout: 2m 182 | ``` 183 | The default models are downloaded from google: coco_ssd_mobilenet_v1_1.0_quant_2018_06_29 and faster_rcnn_inception_v2_coco_2018_01_28.pb 184 | 185 | [default/tflite model labels](https://dl.google.com/coral/canned_models/coco_labels.txt) 186 | 187 | [tensorflow model labels](https://raw.githubusercontent.com/amikelive/coco-labels/master/coco-labels-2014_2017.txt) 188 | 189 | The `numThreads` option is the number of threads that will be available for compatible operations in a model 190 | The `numConcurrent` option sets the number of models that will be able to run at the same time. This should be 1 unless you have a beefy machine. 191 | The `hwAccel` option is used to specify that a hardware device should be used. The only device supported is the edgetpu currently 192 | If `timeout` is set than a detector (namely an edgetpu) that hangs for longer than the timeout will cause doods to error and exit. 
Generally this error is not recoverable and Doods needs to be restarted. 193 | 194 | ### Detector Types Supported 195 | * tflite - Tensorflow lite models - Supports Coral EdgeTPU if hwAccel: true and appropriate model is used 196 | * tensorflow - Tensorflow 197 | 198 | EdgeTPU models can be downloaded from here: https://coral.ai/models/ (Use the Object Detection Models) 199 | 200 | ## Examples - Clients 201 | See the examples directory for sample clients 202 | 203 | ## Docker 204 | To run the container in docker you need to map port 8080. If you want to update the models, you need to map model files and a config to use them. 205 | `docker run -it -p 8080:8080 snowzach/doods:latest` 206 | 207 | There is a script called `fetch_models.sh` that you can download and run to create a models directory and download several models and outputs an `example.yaml` config file. 208 | You could then run: `docker run -it -v ./models:/opt/doods/models -v ./example.yaml:/opt/doods/config.yaml -p 8080:8080 snowzach/doods:latest` 209 | 210 | ### Coral EdgeTPU 211 | If you want to run it in docker using the Coral EdgeTPU, you need to pass the device to the container with: `--device /dev/bus/usb` 212 | Example: `docker run -it --device /dev/bus/usb -p 8080:8080 snowzach/doods:latest` 213 | 214 | ## Misc 215 | Special thanks to https://github.com/mattn/go-tflite as I would have never been able to figure out all the CGO stuff. I really wanted to write this in Go but I'm not good enough at C++/CGO to do it. Most of the tflite code is taken from that repo and customized for this tool. 216 | 217 | And special thanks to @lhelontra, @marianopeck and @PINTO0309 for help in building tensorflow and binaries for bazel on the arm. 
218 | 219 | ## Docker Images 220 | There are several published Docker images that you can use 221 | 222 | * latest - This is a multi-arch image that points to the arm32 image, arm64 and noavx image 223 | * noavx - 64 bit x86 image that should be a highly compatible with any cpu. 224 | * arm64 - Arm 64 bit image 225 | * arm32 - Arm 32 bit/arm7 image optimized for the Raspberry Pi 226 | * amd64 - 64 bit x86 image with all the fancy cpu features like avx and sse4.2 227 | * cuda - Support for NVidia GPUs 228 | 229 | ## CUDA Support 230 | There is now NVidia GPU support with an docker image tagged cuda, to run: 231 | `docker run -it --gpus all -p 8080:8080 snowzach/doods:cuda` 232 | For whatever reason, it can take a good 60-80 seconds before the model finishes loading. 233 | 234 | ## Compiling 235 | You can compile it yourself using the plain `Dockerfile` which should pick the optimal CPU flags for your architecture. 236 | Make the `snowzach/doods:local` image with this command: 237 | ``` 238 | $ make libedgetpu 239 | $ make docker 240 | ``` 241 | You only need to make libedgetpu once, it will download and compile it for all architectures. I hope to streamline that process into the main dockerfile at some point 242 | 243 | [![paypal](https://www.paypalobjects.com/en_US/i/btn/btn_donateCC_LG.gif)](https://www.paypal.com/cgi-bin/webscr?cmd=_s-xclick&hosted_button_id=QG353JUXA6BFW&source=url) 244 | -------------------------------------------------------------------------------- /builder.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Check for edgetpu and specify devices if found 3 | DOCKER_EXTRA="" 4 | if `lsusb | egrep "(1a6e:089a|18d1:9302)" > /dev/null`; then 5 | DOCKER_EXTRA="--device /dev/bus/usb " 6 | echo "EdgeTPU detected..." 
7 | fi 8 | 9 | docker run -it -v $PWD:/build -p 8090:8080 ${DOCKER_EXTRA} snowzach/doods:builder bash 10 | -------------------------------------------------------------------------------- /cmd/api.go: -------------------------------------------------------------------------------- 1 | package cmd 2 | 3 | import ( 4 | cli "github.com/spf13/cobra" 5 | "go.uber.org/zap" 6 | 7 | "github.com/snowzach/doods/conf" 8 | "github.com/snowzach/doods/detector" 9 | "github.com/snowzach/doods/odrpc" 10 | "github.com/snowzach/doods/server" 11 | ) 12 | 13 | func init() { 14 | rootCmd.AddCommand(apiCmd) 15 | } 16 | 17 | var ( 18 | apiCmd = &cli.Command{ 19 | Use: "api", 20 | Short: "Start API", 21 | Long: `Start API`, 22 | Run: func(cmd *cli.Command, args []string) { // Initialize the databse 23 | 24 | // Create the detector mux server 25 | d := detector.New() 26 | 27 | // Create the server 28 | s, err := server.New() 29 | if err != nil { 30 | logger.Fatalw("Could not create server", 31 | "error", err, 32 | ) 33 | } 34 | 35 | // Register the RPC server and it's GRPC Gateway for when it starts 36 | odrpc.RegisterOdrpcServer(s.GRPCServer(), d) 37 | s.GWReg(odrpc.RegisterOdrpcHandlerFromEndpoint) 38 | 39 | err = s.ListenAndServe() 40 | if err != nil { 41 | logger.Fatalw("Could not start server", 42 | "error", err, 43 | ) 44 | } 45 | 46 | <-conf.Stop.Chan() // Wait until StopChan 47 | conf.Stop.Wait() // Wait until everyone cleans up 48 | zap.L().Sync() // Flush the logger 49 | 50 | }, 51 | } 52 | ) 53 | -------------------------------------------------------------------------------- /cmd/client.go: -------------------------------------------------------------------------------- 1 | package cmd 2 | 3 | import ( 4 | "context" 5 | "crypto/tls" 6 | "fmt" 7 | "net" 8 | 9 | emptypb "github.com/golang/protobuf/ptypes/empty" 10 | cli "github.com/spf13/cobra" 11 | config "github.com/spf13/viper" 12 | "go.uber.org/zap" 13 | "google.golang.org/grpc" 14 | "google.golang.org/grpc/credentials" 15 | 
16 | "github.com/snowzach/doods/server/rpc" 17 | ) 18 | 19 | var () 20 | 21 | func init() { 22 | 23 | rootCmd.AddCommand(&cli.Command{ 24 | Use: "client", 25 | Short: "CLI Client", 26 | Long: `CLI Client`, 27 | Run: func(cmd *cli.Command, args []string) { // Initialize the databse 28 | 29 | dialOptions := []grpc.DialOption{ 30 | grpc.WithBlock(), 31 | } 32 | if config.GetBool("server.tls") { 33 | dialOptions = append(dialOptions, grpc.WithTransportCredentials(credentials.NewTLS(&tls.Config{InsecureSkipVerify: true}))) 34 | } else { 35 | dialOptions = append(dialOptions, grpc.WithInsecure()) 36 | } 37 | 38 | // Set up a connection to the gRPC server. 39 | conn, err := grpc.Dial(net.JoinHostPort(config.GetString("server.host"), config.GetString("server.port")), dialOptions...) 40 | if err != nil { 41 | logger.Fatalw("Could not connect", "error", err) 42 | } 43 | defer conn.Close() 44 | 45 | // gRPC version Client 46 | versionClient := rpc.NewVersionRPCClient(conn) 47 | 48 | // Make RPC call 49 | version, err := versionClient.Version(context.Background(), &emptypb.Empty{}) 50 | if err != nil { 51 | logger.Fatalw("Could not call Version", "error", err) 52 | } 53 | 54 | fmt.Printf("Version: %s\n", version.Version) 55 | 56 | // // gRPC thing Client 57 | // thingClient := rpc.NewThingRPCClient(conn) 58 | 59 | // // Make RPC call 60 | // things, err := thingClient.ThingFind(context.Background(), &emptypb.Empty{}) 61 | // if err != nil { 62 | // logger.Fatalw("Could not call ThingFind", "error", err) 63 | // } 64 | 65 | // // Pretty print it as JSON 66 | // b, err := json.MarshalIndent(things, "", " ") 67 | // if err != nil { 68 | // logger.Fatalw("Could not convert to JSON", "error", err) 69 | // } 70 | // fmt.Println(string(b)) 71 | 72 | zap.L().Sync() // Flush the logger 73 | 74 | }, 75 | }) 76 | } 77 | -------------------------------------------------------------------------------- /cmd/root.go: 
-------------------------------------------------------------------------------- 1 | package cmd 2 | 3 | import ( 4 | "fmt" 5 | "net" 6 | "os" 7 | 8 | "net/http" 9 | _ "net/http/pprof" // Import for pprof 10 | 11 | cli "github.com/spf13/cobra" 12 | config "github.com/spf13/viper" 13 | "go.uber.org/zap" 14 | 15 | "github.com/snowzach/doods/conf" 16 | ) 17 | 18 | var ( 19 | 20 | // Config and global logger 21 | configFile string 22 | pidFile string 23 | logger *zap.SugaredLogger 24 | 25 | // The Root Cli Handler 26 | rootCmd = &cli.Command{ 27 | Version: conf.GitVersion, 28 | Use: conf.Executable, 29 | PersistentPreRunE: func(cmd *cli.Command, args []string) error { 30 | // Create Pid File 31 | pidFile = config.GetString("pidfile") 32 | if pidFile != "" { 33 | file, err := os.OpenFile(pidFile, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0666) 34 | if err != nil { 35 | return fmt.Errorf("Could not create pid file: %s Error:%v", pidFile, err) 36 | } 37 | defer file.Close() 38 | _, err = fmt.Fprintf(file, "%d\n", os.Getpid()) 39 | if err != nil { 40 | return fmt.Errorf("Could not create pid file: %s Error:%v", pidFile, err) 41 | } 42 | } 43 | return nil 44 | }, 45 | PersistentPostRun: func(cmd *cli.Command, args []string) { 46 | // Remove Pid file 47 | if pidFile != "" { 48 | os.Remove(pidFile) 49 | } 50 | }, 51 | } 52 | ) 53 | 54 | // Execute starts the program 55 | func Execute() { 56 | // Run the program 57 | if err := rootCmd.Execute(); err != nil { 58 | fmt.Fprintf(os.Stderr, "%s\n", err.Error()) 59 | } 60 | } 61 | 62 | // This is the main initializer handling cli, config and log 63 | func init() { 64 | // Initialize configuration 65 | cli.OnInitialize(initConfig, initLogger, initProfiler) 66 | rootCmd.PersistentFlags().StringVarP(&configFile, "config", "c", "", "Config file") 67 | } 68 | 69 | // initConfig reads in config file and ENV variables if set. 70 | func initConfig() { 71 | 72 | // If a config file is found, read it in. 
73 | if configFile != "" { 74 | config.SetConfigFile(configFile) 75 | err := config.ReadInConfig() 76 | if err != nil { 77 | fmt.Fprintf(os.Stderr, "Could not read config file: %s ERROR: %s\n", configFile, err.Error()) 78 | os.Exit(1) 79 | } 80 | 81 | } 82 | } 83 | 84 | func initLogger() { 85 | conf.InitLogger() 86 | logger = zap.S().With("package", "cmd") 87 | } 88 | 89 | // Profiler can explicitly listen on address/port 90 | func initProfiler() { 91 | if config.GetBool("profiler.enabled") { 92 | hostPort := net.JoinHostPort(config.GetString("profiler.host"), config.GetString("profiler.port")) 93 | go http.ListenAndServe(hostPort, nil) 94 | logger.Infof("Profiler enabled on http://%s", hostPort) 95 | } 96 | } 97 | -------------------------------------------------------------------------------- /cmd/version.go: -------------------------------------------------------------------------------- 1 | package cmd 2 | 3 | import ( 4 | "fmt" 5 | 6 | cli "github.com/spf13/cobra" 7 | 8 | "github.com/snowzach/doods/conf" 9 | ) 10 | 11 | // Version command 12 | func init() { 13 | rootCmd.AddCommand(&cli.Command{ 14 | Use: "version", 15 | Short: "Show version", 16 | Long: `Show version`, 17 | Run: func(cmd *cli.Command, args []string) { 18 | fmt.Println(conf.Executable + " - " + conf.GitVersion) 19 | }, 20 | }) 21 | } 22 | -------------------------------------------------------------------------------- /conf/defaults.go: -------------------------------------------------------------------------------- 1 | package conf 2 | 3 | import ( 4 | "net/http" 5 | "strings" 6 | 7 | config "github.com/spf13/viper" 8 | 9 | "github.com/snowzach/doods/detector/dconfig" 10 | ) 11 | 12 | func init() { 13 | // Sets up the config file, environment etc 14 | config.SetTypeByDefaultValue(true) // If a default value is []string{"a"} an environment variable of "a b" will end up []string{"a","b"} 15 | config.AutomaticEnv() // Automatically use environment variables where available 16 | 
config.SetEnvKeyReplacer(strings.NewReplacer(".", "_")) // Environement variables use underscores instead of periods 17 | 18 | // Logger Defaults 19 | config.SetDefault("logger.level", "info") 20 | config.SetDefault("logger.encoding", "console") 21 | config.SetDefault("logger.color", true) 22 | config.SetDefault("logger.dev_mode", true) 23 | config.SetDefault("logger.disable_caller", false) 24 | config.SetDefault("logger.disable_stacktrace", true) 25 | 26 | // Pidfile 27 | config.SetDefault("pidfile", "") 28 | 29 | // Profiler config 30 | config.SetDefault("profiler.enabled", false) 31 | config.SetDefault("profiler.host", "") 32 | config.SetDefault("profiler.port", "6060") 33 | 34 | // Server Configuration 35 | config.SetDefault("server.host", "") 36 | config.SetDefault("server.port", "8080") 37 | config.SetDefault("server.tls", false) 38 | config.SetDefault("server.devcert", false) 39 | config.SetDefault("server.certfile", "server.crt") 40 | config.SetDefault("server.keyfile", "server.key") 41 | config.SetDefault("server.max_msg_size", 64000000) 42 | config.SetDefault("server.log_requests", true) 43 | config.SetDefault("server.profiler_enabled", false) 44 | config.SetDefault("server.profiler_path", "/debug") 45 | config.SetDefault("server.allowed_origins", []string{"*"}) 46 | config.SetDefault("server.allowed_methods", []string{http.MethodHead, http.MethodOptions, http.MethodGet, http.MethodPost, http.MethodPut, http.MethodDelete, http.MethodPatch}) 47 | config.SetDefault("server.allowed_headers", []string{"*"}) 48 | config.SetDefault("server.allowed_credentials", false) 49 | config.SetDefault("server.max_age", 300) 50 | 51 | // Main settings 52 | config.SetDefault("doods.auth_key", "") 53 | config.SetDefault("doods.detectors", []*dconfig.DetectorConfig{}) 54 | 55 | } 56 | -------------------------------------------------------------------------------- /conf/logger.go: -------------------------------------------------------------------------------- 1 | package 
conf 2 | 3 | import ( 4 | "github.com/blendle/zapdriver" 5 | config "github.com/spf13/viper" 6 | "go.uber.org/zap" 7 | "go.uber.org/zap/zapcore" 8 | ) 9 | 10 | func InitLogger() { 11 | 12 | logConfig := zap.NewProductionConfig() 13 | 14 | // Log Level 15 | var logLevel zapcore.Level 16 | if err := logLevel.Set(config.GetString("logger.level")); err != nil { 17 | zap.S().Fatalw("Could not determine logger.level", "error", err) 18 | } 19 | logConfig.Level.SetLevel(logLevel) 20 | 21 | // Handle different logger encodings 22 | loggerEncoding := config.GetString("logger.encoding") 23 | switch loggerEncoding { 24 | case "stackdriver": 25 | logConfig.Encoding = "json" 26 | logConfig.EncoderConfig = zapdriver.NewDevelopmentEncoderConfig() 27 | default: 28 | logConfig.Encoding = loggerEncoding 29 | // Enable Color 30 | if config.GetBool("logger.color") { 31 | logConfig.EncoderConfig.EncodeLevel = zapcore.CapitalColorLevelEncoder 32 | } 33 | logConfig.DisableStacktrace = config.GetBool("logger.disable_stacktrace") 34 | // Use sane timestamp when logging to console 35 | if logConfig.Encoding == "console" { 36 | logConfig.EncoderConfig.EncodeTime = zapcore.ISO8601TimeEncoder 37 | } 38 | 39 | // JSON Fields 40 | logConfig.EncoderConfig.MessageKey = "msg" 41 | logConfig.EncoderConfig.LevelKey = "level" 42 | logConfig.EncoderConfig.CallerKey = "caller" 43 | } 44 | 45 | // Settings 46 | logConfig.Development = config.GetBool("logger.dev_mode") 47 | logConfig.DisableCaller = config.GetBool("logger.disable_caller") 48 | 49 | // Build the logger 50 | globalLogger, _ := logConfig.Build() 51 | zap.ReplaceGlobals(globalLogger) 52 | 53 | } 54 | -------------------------------------------------------------------------------- /conf/signal.go: -------------------------------------------------------------------------------- 1 | package conf 2 | 3 | import ( 4 | "context" 5 | "os" 6 | "os/signal" 7 | "sync" 8 | 9 | "go.uber.org/zap" 10 | ) 11 | 12 | type stop struct { 13 | // c is a channel 
that is closed when we are stopping 14 | c chan struct{} 15 | // WaitGroup is a embedded WaitGroup that will wait before exiting cleanly to allow for cleanup 16 | sync.WaitGroup 17 | // Context will be canceled when stop is called 18 | Context context.Context 19 | cancel context.CancelFunc 20 | } 21 | 22 | var ( 23 | // Global Stop instance 24 | Stop = &stop{ 25 | c: make(chan struct{}), 26 | } 27 | // Handle signals 28 | signalChannel = make(chan os.Signal, 1) 29 | ) 30 | 31 | // Handles all incoming signals 32 | func init() { 33 | 34 | Stop.Context, Stop.cancel = context.WithCancel(context.Background()) 35 | 36 | // Stop flag will indicate if Ctrl-C/Interrupt has been sent to the process 37 | signal.Notify(signalChannel, os.Interrupt) 38 | 39 | // Handke signals 40 | go func() { 41 | for { 42 | for sig := range signalChannel { 43 | switch sig { 44 | case os.Interrupt: 45 | zap.S().Info("Received Interrupt...") 46 | close(Stop.c) 47 | Stop.cancel() 48 | return 49 | } 50 | } 51 | } 52 | }() 53 | 54 | } 55 | 56 | // Chan returns a read only channel that is closed when the program should exit 57 | func (s *stop) Chan() <-chan struct{} { 58 | return s.c 59 | } 60 | 61 | // Bool returns t/f if we should stop 62 | func (s *stop) Bool() bool { 63 | select { 64 | case <-s.c: 65 | return true 66 | default: 67 | return false 68 | } 69 | } 70 | 71 | // This will force a stop 72 | func (s *stop) Stop() { 73 | close(Stop.c) 74 | Stop.cancel() 75 | } 76 | -------------------------------------------------------------------------------- /conf/version.go: -------------------------------------------------------------------------------- 1 | package conf 2 | 3 | var ( 4 | // Executable is overridden by Makefile with executable name 5 | Executable = "NoExecutable" 6 | // GitVersion is overridden by Makefile with git information 7 | GitVersion = "NoGitVersion" 8 | ) 9 | -------------------------------------------------------------------------------- /config.arm.yaml: 
-------------------------------------------------------------------------------- 1 | doods: 2 | detectors: 3 | - name: default 4 | type: tflite 5 | modelFile: models/coco_ssd_mobilenet_v1_1.0_quant.tflite 6 | labelFile: models/coco_labels0.txt 7 | numThreads: 4 8 | numConcurrent: 4 9 | hwAccel: false 10 | timeout: 2m 11 | - name: tensorflow 12 | type: tensorflow 13 | modelFile: models/faster_rcnn_inception_v2_coco_2018_01_28.pb 14 | labelFile: models/coco_labels1.txt 15 | numThreads: 4 16 | numConcurrent: 1 17 | hwAccel: false 18 | timeout: 2m 19 | -------------------------------------------------------------------------------- /config.yaml: -------------------------------------------------------------------------------- 1 | doods: 2 | detectors: 3 | - name: default 4 | type: tflite 5 | modelFile: models/coco_ssd_mobilenet_v1_1.0_quant.tflite 6 | labelFile: models/coco_labels0.txt 7 | numThreads: 4 8 | numConcurrent: 4 9 | hwAccel: false 10 | timeout: 2m 11 | - name: tensorflow 12 | type: tensorflow 13 | modelFile: models/faster_rcnn_inception_v2_coco_2018_01_28.pb 14 | labelFile: models/coco_labels1.txt 15 | numThreads: 4 16 | numConcurrent: 4 17 | hwAccel: false 18 | timeout: 2m 19 | -------------------------------------------------------------------------------- /detector/auth.go: -------------------------------------------------------------------------------- 1 | package detector 2 | 3 | import ( 4 | "context" 5 | "strings" 6 | 7 | "google.golang.org/grpc/codes" 8 | "google.golang.org/grpc/metadata" 9 | "google.golang.org/grpc/status" 10 | 11 | "github.com/snowzach/doods/odrpc" 12 | ) 13 | 14 | // AuthFuncOverride will handle authentication 15 | func (m *Mux) AuthFuncOverride(ctx context.Context, fullMethodName string) (context.Context, error) { 16 | 17 | // Auth disabled 18 | if m.authKey == "" { 19 | return ctx, nil 20 | } 21 | 22 | // Get request metadata 23 | md, ok := metadata.FromIncomingContext(ctx) 24 | if !ok { 25 | return ctx, 
status.Errorf(codes.PermissionDenied, "Permission Denied") 26 | } 27 | 28 | // Get the user pubKeyString 29 | if mdfirst(md, odrpc.DoodsAuthKeyHeader) != m.authKey { 30 | return ctx, status.Errorf(codes.PermissionDenied, "Invalid Login") 31 | } 32 | 33 | return ctx, nil 34 | 35 | } 36 | 37 | func mdfirst(md metadata.MD, key string) string { 38 | val := md.Get(strings.ToLower(key)) 39 | if len(val) > 0 { 40 | return val[0] 41 | } 42 | return "" 43 | } 44 | -------------------------------------------------------------------------------- /detector/dconfig/dconfig.go: -------------------------------------------------------------------------------- 1 | package dconfig 2 | 3 | import ( 4 | "time" 5 | ) 6 | 7 | // Detector config is used for parsing configuration data from the config file 8 | type DetectorConfig struct { 9 | Name string `json:"name"` 10 | Type string `json:"type"` 11 | ModelFile string `json:"model_file"` 12 | LabelFile string `json:"label_file"` 13 | NumThreads int `json:"num_threads"` 14 | NumConcurrent int `json:"num_concurrent"` 15 | HWAccel bool `json:"hw_accel"` 16 | Timeout time.Duration `json:"timeout"` 17 | } 18 | -------------------------------------------------------------------------------- /detector/detector.go: -------------------------------------------------------------------------------- 1 | package detector 2 | 3 | import ( 4 | "context" 5 | "io/ioutil" 6 | "sync" 7 | 8 | // We will support these formats 9 | _ "image/gif" 10 | _ "image/jpeg" 11 | _ "image/png" 12 | 13 | _ "github.com/lmittmann/ppm" 14 | _ "golang.org/x/image/bmp" 15 | 16 | emptypb "github.com/golang/protobuf/ptypes/empty" 17 | config "github.com/spf13/viper" 18 | "go.uber.org/zap" 19 | "google.golang.org/grpc/codes" 20 | "google.golang.org/grpc/status" 21 | 22 | "github.com/snowzach/doods/conf" 23 | "github.com/snowzach/doods/detector/dconfig" 24 | "github.com/snowzach/doods/detector/tensorflow" 25 | "github.com/snowzach/doods/detector/tflite" 26 | 
"github.com/snowzach/doods/odrpc" 27 | ) 28 | 29 | // Detector is the interface to object detectors 30 | type Detector interface { 31 | Config() *odrpc.Detector 32 | Detect(ctx context.Context, request *odrpc.DetectRequest) (*odrpc.DetectResponse, error) 33 | Shutdown() 34 | } 35 | 36 | // Mux handles and routes requests to the configured detectors 37 | type Mux struct { 38 | detectors map[string]Detector 39 | authKey string 40 | logger *zap.SugaredLogger 41 | } 42 | 43 | // Create a new mux 44 | func New() *Mux { 45 | 46 | m := &Mux{ 47 | detectors: make(map[string]Detector), 48 | authKey: config.GetString("doods.auth_key"), 49 | logger: zap.S().With("package", "detector"), 50 | } 51 | 52 | // Get the detectors config 53 | var detectorConfig []*dconfig.DetectorConfig 54 | config.UnmarshalKey("doods.detectors", &detectorConfig) 55 | 56 | // Create the detectors 57 | for _, c := range detectorConfig { 58 | var d Detector 59 | var err error 60 | 61 | m.logger.Debugw("Configuring detector", "config", c) 62 | 63 | switch c.Type { 64 | case "tflite": 65 | d, err = tflite.New(c) 66 | case "tensorflow": 67 | d, err = tensorflow.New(c) 68 | default: 69 | m.logger.Errorw("Could not initialize detector", "name", c.Name, "type", c.Type) 70 | continue 71 | } 72 | 73 | if err != nil { 74 | m.logger.Errorf("Could not initialize detector %s: %v", c.Name, err) 75 | continue 76 | } 77 | 78 | dc := d.Config() 79 | m.logger.Infow("Configured Detector", "name", dc.Name, "type", dc.Type, "model", dc.Model, "labels", len(dc.Labels), "width", dc.Width, "height", dc.Height) 80 | m.detectors[c.Name] = d 81 | } 82 | 83 | if len(m.detectors) == 0 { 84 | m.logger.Fatalf("No detectors configured") 85 | } 86 | 87 | return m 88 | 89 | } 90 | 91 | // GetDetectors returns the configured detectors 92 | func (m *Mux) GetDetectors(ctx context.Context, _ *emptypb.Empty) (*odrpc.GetDetectorsResponse, error) { 93 | detectors := make([]*odrpc.Detector, 0) 94 | for _, d := range m.detectors { 95 | 
detectors = append(detectors, d.Config()) 96 | } 97 | return &odrpc.GetDetectorsResponse{ 98 | Detectors: detectors, 99 | }, nil 100 | } 101 | 102 | // Shutdown deallocates/shuts down any detectors 103 | func (m *Mux) Shutdown() { 104 | for _, d := range m.detectors { 105 | d.Shutdown() 106 | } 107 | } 108 | 109 | // Run a detection 110 | func (m *Mux) Detect(ctx context.Context, request *odrpc.DetectRequest) (*odrpc.DetectResponse, error) { 111 | 112 | if request.DetectorName == "" { 113 | request.DetectorName = "default" 114 | } 115 | 116 | detector, ok := m.detectors[request.DetectorName] 117 | if !ok { 118 | return nil, status.Errorf(codes.NotFound, "not found") 119 | } 120 | 121 | // If file is specified, load the data from a file 122 | var err error 123 | if len(request.File) != 0 { 124 | request.Data, err = ioutil.ReadFile(request.File) 125 | if err != nil { 126 | return nil, status.Errorf(codes.NotFound, "could not open file %s", request.File) 127 | } 128 | } 129 | 130 | response, err := detector.Detect(ctx, request) 131 | if err != nil { 132 | return response, err 133 | } 134 | 135 | m.FilterResponse(request, response) 136 | 137 | return response, nil 138 | 139 | } 140 | 141 | // Handle a stream of detections 142 | func (m *Mux) DetectStream(stream odrpc.Odrpc_DetectStreamServer) error { 143 | 144 | // Handle cancel 145 | ctx, cancel := context.WithCancel(stream.Context()) 146 | go func() { 147 | select { 148 | case <-ctx.Done(): 149 | case <-conf.Stop.Chan(): 150 | cancel() 151 | } 152 | }() 153 | 154 | var send sync.Mutex 155 | var ret error 156 | for ctx.Err() == nil { 157 | 158 | request, err := stream.Recv() 159 | if err != nil { 160 | return nil 161 | } 162 | 163 | m.logger.Info("Stream Request") 164 | 165 | go func(request *odrpc.DetectRequest) { 166 | 167 | response, err := m.Detect(ctx, request) 168 | if err != nil { 169 | // A non-fatal error 170 | if status.Code(err) == codes.Internal { 171 | send.Lock() 172 | ret = err 173 | cancel() 174 | 
send.Unlock() 175 | return 176 | } else { 177 | response = &odrpc.DetectResponse{ 178 | Id: request.Id, 179 | Error: err.Error(), 180 | } 181 | } 182 | } 183 | 184 | send.Lock() 185 | stream.Send(response) 186 | send.Unlock() 187 | 188 | }(request) 189 | 190 | } 191 | 192 | return ret 193 | 194 | } 195 | -------------------------------------------------------------------------------- /detector/regions.go: -------------------------------------------------------------------------------- 1 | package detector 2 | 3 | import ( 4 | "fmt" 5 | 6 | "github.com/snowzach/doods/odrpc" 7 | ) 8 | 9 | func (m *Mux) FilterResponse(request *odrpc.DetectRequest, response *odrpc.DetectResponse) { 10 | 11 | // No filters, return everything 12 | if len(request.Detect) == 0 && len(request.Regions) == 0 { 13 | return 14 | } 15 | 16 | temp := response.Detections[:0] 17 | 18 | detectionsLoop: 19 | for _, detection := range response.Detections { 20 | // Cleanup the bounds 21 | if detection.Top < 0 { 22 | detection.Top = 0 23 | } 24 | if detection.Left < 0 { 25 | detection.Left = 0 26 | } 27 | if detection.Bottom > 1 { 28 | detection.Bottom = 1 29 | } 30 | if detection.Right > 1 { 31 | detection.Right = 1 32 | } 33 | 34 | // We have this class listed explicitly 35 | if score, ok := request.Detect[detection.Label]; ok { 36 | if detection.Confidence >= score { 37 | temp = append(temp, detection) 38 | continue 39 | } 40 | // Wildcard class 41 | } else if score, ok := request.Detect["*"]; ok { 42 | if detection.Confidence >= score { 43 | temp = append(temp, detection) 44 | continue 45 | } 46 | } 47 | 48 | for _, region := range request.Regions { 49 | var inRegion bool 50 | if region.Covers { 51 | if detection.Top >= region.Top && detection.Left >= region.Left && detection.Bottom <= region.Bottom && detection.Right <= region.Right { 52 | inRegion = true 53 | } 54 | } else { 55 | if detection.Top <= region.Bottom && detection.Left <= region.Right && detection.Bottom >= region.Top && 
// FilterResponse removes detections from response that do not satisfy the
// request's filters. A detection is kept when its confidence meets the
// threshold for its label (or the "*" wildcard) in request.Detect, or when it
// lies in one of request.Regions and meets that region's thresholds.
// Bounds are clamped to the 0..1 range as a side effect. When the request has
// no filters at all, everything is kept. Filtering is done in place, reusing
// the backing array of response.Detections.
func (m *Mux) FilterResponse(request *odrpc.DetectRequest, response *odrpc.DetectResponse) {

	// No filters, return everything
	if len(request.Detect) == 0 && len(request.Regions) == 0 {
		return
	}

	// Filter in place: temp shares the backing array with response.Detections.
	temp := response.Detections[:0]

detectionsLoop:
	for _, detection := range response.Detections {
		// Cleanup the bounds: clamp coordinates into the normalized 0..1 box.
		if detection.Top < 0 {
			detection.Top = 0
		}
		if detection.Left < 0 {
			detection.Left = 0
		}
		if detection.Bottom > 1 {
			detection.Bottom = 1
		}
		if detection.Right > 1 {
			detection.Right = 1
		}

		// We have this class listed explicitly in the global filter
		if score, ok := request.Detect[detection.Label]; ok {
			if detection.Confidence >= score {
				temp = append(temp, detection)
				continue
			}
			// Wildcard class
		} else if score, ok := request.Detect["*"]; ok {
			if detection.Confidence >= score {
				temp = append(temp, detection)
				continue
			}
		}

		// Not kept globally; check each region.
		for _, region := range request.Regions {
			var inRegion bool
			if region.Covers {
				// Region must fully contain the detection box.
				if detection.Top >= region.Top && detection.Left >= region.Left && detection.Bottom <= region.Bottom && detection.Right <= region.Right {
					inRegion = true
				}
			} else {
				// Any overlap between detection and region counts.
				if detection.Top <= region.Bottom && detection.Left <= region.Right && detection.Bottom >= region.Top && detection.Right >= region.Left {
					inRegion = true
				}
			}
			if inRegion {
				// We have this class listed explicitly for this region
				if score, ok := region.Detect[detection.Label]; ok {
					if detection.Confidence >= score {
						temp = append(temp, detection)
						continue detectionsLoop
					}
					// Wildcard class
				} else if score, ok := region.Detect["*"]; ok {
					if detection.Confidence >= score {
						temp = append(temp, detection)
						continue detectionsLoop
					}
				}
			}
		}
	}

	response.Detections = temp
	for _, detection := range response.Detections {
		m.logger.Debugw("Detection", "id", request.Id, "label", detection.Label, "confidence", detection.Confidence, "location", fmt.Sprintf("%f,%f,%f,%f", detection.Top, detection.Left, detection.Bottom, detection.Right))
	}
}
make(chan *tf.Session, c.NumConcurrent), 42 | } 43 | 44 | d.config.Name = c.Name 45 | d.config.Type = c.Type 46 | d.config.Model = c.ModelFile 47 | d.config.Labels = make([]string, 0) 48 | d.config.Width = -1 49 | d.config.Height = -1 50 | 51 | // Load labels 52 | f, err := os.Open(c.LabelFile) 53 | if err != nil { 54 | return nil, fmt.Errorf("could not load label", "error", err) 55 | } 56 | defer f.Close() 57 | scanner := bufio.NewScanner(f) 58 | for x := 1; scanner.Scan(); x++ { 59 | fields := strings.SplitAfterN(scanner.Text(), " ", 2) 60 | if len(fields) == 1 { 61 | d.labels[x] = fields[0] 62 | d.config.Labels = append(d.config.Labels, fields[0]) 63 | } else if len(fields) == 2 { 64 | if y, err := strconv.Atoi(strings.TrimSpace(fields[0])); err == nil { 65 | d.labels[y] = strings.TrimSpace(fields[1]) 66 | d.config.Labels = append(d.config.Labels, strings.TrimSpace(fields[1])) 67 | } 68 | } 69 | } 70 | 71 | // Raw model data 72 | modelData, err := ioutil.ReadFile(c.ModelFile) 73 | if err != nil { 74 | return nil, fmt.Errorf("Could not read model file %s: %v", c.ModelFile, err) 75 | } 76 | 77 | d.graph = tf.NewGraph() 78 | if err := d.graph.Import(modelData, ""); err != nil { 79 | return nil, fmt.Errorf("Could not import model: %v", err) 80 | } 81 | 82 | // Create sessions 83 | for x := 0; x < c.NumConcurrent; x++ { 84 | s, err := tf.NewSession(d.graph, nil) 85 | if err != nil { 86 | return nil, fmt.Errorf("Could not create session: %v", err) 87 | } 88 | d.pool <- s 89 | } 90 | 91 | return d, nil 92 | 93 | } 94 | 95 | func (d *detector) Config() *odrpc.Detector { 96 | return &d.config 97 | } 98 | 99 | func (d *detector) Shutdown() { 100 | close(d.pool) 101 | for { 102 | sess := <-d.pool 103 | if sess == nil { 104 | break 105 | } 106 | sess.Close() 107 | } 108 | } 109 | 110 | func (d *detector) Detect(ctx context.Context, request *odrpc.DetectRequest) (*odrpc.DetectResponse, error) { 111 | 112 | sess := <-d.pool 113 | conf.Stop.Add(1) // Wait until detection 
complete before stopping 114 | defer func() { 115 | d.pool <- sess 116 | conf.Stop.Done() 117 | }() 118 | 119 | // Determine the image type 120 | _, imgType, err := image.DecodeConfig(bytes.NewReader(request.Data)) 121 | if err != nil { 122 | return nil, status.Errorf(codes.InvalidArgument, "could not decode image: %v", err) 123 | } 124 | 125 | // If the image is not a supported type, convert it to bmp 126 | if imgType != "png" && imgType != "gif" && imgType != "jpeg" && imgType != "bmp" { 127 | 128 | img, _, err := image.Decode(bytes.NewReader(request.Data)) 129 | if err != nil { 130 | return nil, status.Errorf(codes.InvalidArgument, "could not decode image: %v", err) 131 | } 132 | 133 | // Encode as raw BMP 134 | err = bmp.Encode(bytes.NewBuffer(request.Data), img) 135 | if err != nil { 136 | return nil, status.Errorf(codes.Internal, "could not encode bmp: %v", err) 137 | } 138 | imgType = "bmp" 139 | 140 | } 141 | 142 | scope := op.NewScope() 143 | imgInput := op.Placeholder(scope, tf.String) 144 | 145 | var decodeOutput tf.Output 146 | switch imgType { 147 | case "gif": 148 | decodeOutput = op.DecodeGif(scope, imgInput) 149 | case "jpeg": 150 | decodeOutput = op.DecodeJpeg(scope, imgInput) 151 | case "png": 152 | decodeOutput = op.DecodePng(scope, imgInput) 153 | case "bmp": 154 | decodeOutput = op.DecodeBmp(scope, imgInput) 155 | } 156 | 157 | imgOutput := op.ExpandDims(scope, decodeOutput, op.Const(scope.SubScope("make_batch"), int32(0))) 158 | graph, err := scope.Finalize() 159 | 160 | imgTensor, err := tf.NewTensor(string(request.Data)) // FIX: Convert back to string 161 | if err != nil { 162 | return nil, status.Errorf(codes.Internal, "could not create input tensor: %v", err) 163 | } 164 | 165 | // Execute that graph to decode this one image 166 | imgSess, err := tf.NewSession(graph, nil) 167 | if err != nil { 168 | return nil, status.Errorf(codes.Internal, "could not create image session: %v", err) 169 | } 170 | 171 | // Run the detection 172 | 
decodedImgTensor, err := imgSess.Run(map[tf.Output]*tf.Tensor{imgInput: imgTensor}, []tf.Output{imgOutput}, nil) 173 | if err != nil { 174 | return nil, status.Errorf(codes.Internal, "error converting image: %v", err) 175 | } 176 | 177 | // Get all the input and output operations 178 | inputop := d.graph.Operation("image_tensor") 179 | // Output ops 180 | o1 := d.graph.Operation("detection_boxes") 181 | o2 := d.graph.Operation("detection_scores") 182 | o3 := d.graph.Operation("detection_classes") 183 | o4 := d.graph.Operation("num_detections") 184 | 185 | start := time.Now() 186 | 187 | output, err := sess.Run( 188 | map[tf.Output]*tf.Tensor{ 189 | inputop.Output(0): decodedImgTensor[0], 190 | }, 191 | []tf.Output{ 192 | o1.Output(0), 193 | o2.Output(0), 194 | o3.Output(0), 195 | o4.Output(0), 196 | }, 197 | nil) 198 | if err != nil { 199 | return nil, status.Errorf(codes.Internal, "could not run detection: %v", err) 200 | } 201 | 202 | scores := output[1].Value().([][]float32)[0] 203 | classes := output[2].Value().([][]float32)[0] 204 | locations := output[0].Value().([][][]float32)[0] 205 | count := int(output[3].Value().([]float32)[0]) 206 | 207 | d.logger.Debugw("Detection", "scores", scores, "classes", classes, "locations", locations, "count", count) 208 | 209 | detections := make([]*odrpc.Detection, 0) 210 | for i := 0; i < count; i++ { 211 | // Get the label 212 | label, ok := d.labels[int(classes[i])] 213 | if !ok { 214 | d.logger.Warnw("Missing label", "index", classes[i]) 215 | label = "unknown" 216 | } 217 | 218 | detections = append(detections, &odrpc.Detection{ 219 | Top: locations[i][0], 220 | Left: locations[i][1], 221 | Bottom: locations[i][2], 222 | Right: locations[i][3], 223 | Label: label, 224 | Confidence: scores[i] * 100.0, 225 | }) 226 | } 227 | 228 | d.logger.Infow("Detection Complete", "id", request.Id, "duration", time.Since(start), "detections", len(detections)) 229 | 230 | return &odrpc.DetectResponse{ 231 | Id: request.Id, 232 | 
Detections: detections, 233 | }, nil 234 | } 235 | -------------------------------------------------------------------------------- /detector/tflite/detector.go: -------------------------------------------------------------------------------- 1 | package tflite 2 | 3 | import ( 4 | "bufio" 5 | "context" 6 | "fmt" 7 | "image" 8 | "os" 9 | "strconv" 10 | "strings" 11 | "time" 12 | 13 | "go.uber.org/zap" 14 | "gocv.io/x/gocv" 15 | "google.golang.org/grpc/codes" 16 | "google.golang.org/grpc/status" 17 | 18 | "github.com/snowzach/doods/conf" 19 | "github.com/snowzach/doods/detector/dconfig" 20 | "github.com/snowzach/doods/odrpc" 21 | 22 | "github.com/snowzach/doods/detector/tflite/go-tflite" 23 | "github.com/snowzach/doods/detector/tflite/go-tflite/delegates/edgetpu" 24 | ) 25 | 26 | const ( 27 | OutputFormat_4_TFLite_Detection_PostProcess = iota 28 | OutputFormat_2_identity 29 | OutputFormat_1_scores 30 | ) 31 | 32 | type detector struct { 33 | config odrpc.Detector 34 | logger *zap.SugaredLogger 35 | 36 | labels map[int]string 37 | model *tflite.Model 38 | inputType tflite.TensorType 39 | outputFormat int 40 | pool chan *tflInterpreter 41 | 42 | devices []edgetpu.Device 43 | numThreads int 44 | hwAccel bool 45 | timeout time.Duration 46 | } 47 | 48 | type tflInterpreter struct { 49 | device *edgetpu.Device 50 | *tflite.Interpreter 51 | } 52 | 53 | func New(c *dconfig.DetectorConfig) (*detector, error) { 54 | 55 | d := &detector{ 56 | labels: make(map[int]string), 57 | logger: zap.S().With("package", "detector.tflite", "name", c.Name), 58 | pool: make(chan *tflInterpreter, c.NumConcurrent), 59 | numThreads: c.NumThreads, 60 | hwAccel: c.HWAccel, 61 | timeout: c.Timeout, 62 | } 63 | 64 | d.config.Name = c.Name 65 | d.config.Type = c.Type 66 | d.config.Model = c.ModelFile 67 | d.config.Labels = make([]string, 0) 68 | 69 | // Create the model 70 | d.model = tflite.NewModelFromFile(d.config.Model) 71 | if d.model == nil { 72 | return nil, fmt.Errorf("could not load model 
%s", d.config.Model) 73 | } 74 | 75 | // Load labels 76 | f, err := os.Open(c.LabelFile) 77 | if err != nil { 78 | return nil, fmt.Errorf("could not load label", "error", err) 79 | } 80 | defer f.Close() 81 | scanner := bufio.NewScanner(f) 82 | for x := 1; scanner.Scan(); x++ { 83 | fields := strings.SplitAfterN(scanner.Text(), " ", 2) 84 | if len(fields) == 1 { 85 | d.labels[x] = fields[0] 86 | d.config.Labels = append(d.config.Labels, fields[0]) 87 | } else if len(fields) == 2 { 88 | if y, err := strconv.Atoi(strings.TrimSpace(fields[0])); err == nil { 89 | d.labels[y] = strings.TrimSpace(fields[1]) 90 | d.config.Labels = append(d.config.Labels, strings.TrimSpace(fields[1])) 91 | } 92 | } 93 | } 94 | 95 | // If we are using edgetpu, make sure we have one 96 | if d.hwAccel { 97 | 98 | // Get the list of devices 99 | d.devices, err = edgetpu.DeviceList() 100 | if err != nil { 101 | return nil, fmt.Errorf("Could not fetch edgetpu device list: %v", err) 102 | } 103 | if len(d.devices) == 0 { 104 | return nil, fmt.Errorf("no edgetpu devices detected") 105 | } 106 | c.NumConcurrent = len(d.devices) 107 | d.config.Type = "tflite-edgetpu" 108 | 109 | // Enforce a timeout for edgetpu devices if not set 110 | if d.timeout == 0 { 111 | d.timeout = 30 * time.Second 112 | } 113 | 114 | } 115 | 116 | // Create the pool of interpreters 117 | var interpreter *tflInterpreter 118 | for x := 0; x < c.NumConcurrent; x++ { 119 | 120 | interpreter = new(tflInterpreter) 121 | 122 | // Get a device if there is one 123 | if d.hwAccel && len(d.devices) > x { 124 | interpreter.device = &d.devices[x] 125 | } 126 | 127 | interpreter.Interpreter, err = d.newInterpreter(interpreter.device) 128 | if err != nil { 129 | return nil, err 130 | } 131 | 132 | d.pool <- interpreter 133 | } 134 | 135 | // Get the settings from the input tensor 136 | if inputCount := interpreter.GetInputTensorCount(); inputCount != 1 { 137 | return nil, fmt.Errorf("unsupported input tensor count: %d", inputCount) 138 | 
} 139 | input := interpreter.GetInputTensor(0) 140 | if input.Name() != "normalized_input_image_tensor" && input.Name() != "image" && input.Name() != "input_1" { 141 | return nil, fmt.Errorf("unsupported input tensor name: %s", input.Name()) 142 | } 143 | d.config.Height = int32(input.Dim(1)) 144 | d.config.Width = int32(input.Dim(2)) 145 | d.config.Channels = int32(input.Dim(3)) 146 | d.inputType = input.Type() 147 | if d.inputType != tflite.UInt8 { 148 | return nil, fmt.Errorf("unsupported tensor input type: %s", d.inputType) 149 | } 150 | 151 | // Dump output tensor information 152 | count := interpreter.GetOutputTensorCount() 153 | for x := 0; x < count; x++ { 154 | tensor := interpreter.GetOutputTensor(x) 155 | numDims := tensor.NumDims() 156 | d.logger.Debugw("Tensor Output", "n", x, "name", tensor.Name(), "type", tensor.Type(), "num_dims", numDims, "byte_size", tensor.ByteSize(), "quant", tensor.QuantizationParams(), "shape", tensor.Shape()) 157 | if numDims > 1 { 158 | for y := 0; y < numDims; y++ { 159 | d.logger.Debugw("Tensor Dim", "n", x, "dim", y, "dim_size", tensor.Dim(y)) 160 | } 161 | } 162 | } 163 | 164 | if count == 4 && interpreter.GetOutputTensor(0).Name() == "TFLite_Detection_PostProcess" { 165 | d.outputFormat = OutputFormat_4_TFLite_Detection_PostProcess 166 | } else if count == 2 && interpreter.GetOutputTensor(0).Name() == "Identity" { 167 | d.outputFormat = OutputFormat_2_identity 168 | } else if count == 1 && interpreter.GetOutputTensor(0).Name() == "scores" { 169 | d.outputFormat = OutputFormat_1_scores 170 | // Check the output types 171 | tensor := interpreter.GetOutputTensor(0) 172 | if tensor.Type() != tflite.UInt8 { 173 | return nil, fmt.Errorf("unsupported tensor output type: %s", tensor.Type()) 174 | } 175 | // Ensure the length of the labels match the detection size 176 | for x := int(tensor.ByteSize()) - len(d.labels); x > 0; x-- { 177 | d.labels[x] = "unknown" 178 | } 179 | } else { 180 | return nil, fmt.Errorf("unsupported 
output tensor count: %d", count) 181 | } 182 | 183 | return d, nil 184 | } 185 | 186 | func (d *detector) newInterpreter(device *edgetpu.Device) (*tflite.Interpreter, error) { 187 | // Options 188 | options := tflite.NewInterpreterOptions() 189 | options.SetNumThread(d.numThreads) 190 | options.SetErrorReporter(func(msg string, user_data interface{}) { 191 | d.logger.Warnw("Error", "message", msg, "user_data", user_data) 192 | }, nil) 193 | 194 | // Use edgetpu 195 | if device != nil { 196 | etpuInstance := edgetpu.New(*device) 197 | if etpuInstance == nil { 198 | return nil, fmt.Errorf("could not initialize edgetpu %s", device.Path) 199 | } 200 | options.AddDelegate(etpuInstance) 201 | } 202 | 203 | interpreter := tflite.NewInterpreter(d.model, options) 204 | if interpreter == nil { 205 | return nil, fmt.Errorf("Could not create interpreter") 206 | } 207 | 208 | // Allocate 209 | status := interpreter.AllocateTensors() 210 | if status != tflite.OK { 211 | return nil, fmt.Errorf("interpreter allocate failed") 212 | } 213 | 214 | return interpreter, nil 215 | } 216 | 217 | func (d *detector) Config() *odrpc.Detector { 218 | return &d.config 219 | } 220 | 221 | func (d *detector) Shutdown() { 222 | close(d.pool) 223 | for { 224 | interpreter := <-d.pool 225 | if interpreter == nil { 226 | break 227 | } 228 | interpreter.Delete() 229 | } 230 | } 231 | 232 | func (d *detector) Detect(ctx context.Context, request *odrpc.DetectRequest) (*odrpc.DetectResponse, error) { 233 | 234 | var data []byte 235 | 236 | start := time.Now() 237 | 238 | // If this is ppm data, move it right to tensorflow 239 | if ppmInfo := FindPPMData(request.Data); ppmInfo != nil && int32(ppmInfo.Width) == d.config.Width && int32(ppmInfo.Height) == d.config.Height { 240 | // Dump data right to data input 241 | data = request.Data[ppmInfo.Offset:] 242 | } else { 243 | 244 | img, err := gocv.IMDecode(request.Data, gocv.IMReadColor) 245 | if err != nil { 246 | return nil, 
status.Errorf(codes.InvalidArgument, "could not decode image: %v", err) 247 | } else if img.Empty() { 248 | return nil, status.Errorf(codes.InvalidArgument, "could not read image") 249 | } 250 | defer img.Close() 251 | 252 | // Resize it if necessary 253 | dx := int32(img.Cols()) 254 | dy := int32(img.Rows()) 255 | 256 | d.logger.Debugw("Decoded Image", "id", request.Id, "width", dx, "height", dy, "duration", time.Now().Sub(start)) 257 | if dx != d.config.Width || dy != d.config.Height { 258 | gocv.Resize(img, &img, image.Point{X: int(d.config.Width), Y: int(d.config.Height)}, 0, 0, gocv.InterpolationNearestNeighbor) 259 | d.logger.Debugw("Resized Image", "id", request.Id, "width", d.config.Width, "height", d.config.Height, "duration", time.Now().Sub(start)) 260 | } 261 | 262 | // Convert to RGB 263 | gocv.CvtColor(img, &img, gocv.ColorBGRToRGB) 264 | 265 | // Convert to 8-bit unsigned 3 channel if it isn't 266 | if img.Type() != gocv.MatTypeCV8UC3 { 267 | d.logger.Debug("Converted Colorspace", "before", img.Type(), gocv.MatTypeCV8UC3) 268 | img.ConvertTo(&img, gocv.MatTypeCV8UC3) 269 | } 270 | 271 | data = img.ToBytes() 272 | } 273 | 274 | d.logger.Debugw("Image pre-processing complete", "duration", time.Now().Sub(start)) 275 | 276 | // Get an interpreter from the pool 277 | interpreter := <-d.pool 278 | conf.Stop.Add(1) // Wait until detection complete before stopping 279 | defer func() { 280 | d.pool <- interpreter 281 | conf.Stop.Done() 282 | }() 283 | 284 | // Build the tensor input 285 | input := interpreter.GetInputTensor(0) 286 | input.CopyFromBuffer(data) 287 | 288 | inferenceStart := time.Now() 289 | 290 | // Perform the detection 291 | var invokeStatus tflite.Status 292 | complete := make(chan struct{}) 293 | go func() { 294 | invokeStatus = interpreter.Invoke() 295 | close(complete) 296 | }() 297 | 298 | // Wait for complete or timeout if there is one set 299 | if d.timeout > 0 { 300 | select { 301 | case <-complete: 302 | // We're done 303 | case 
<-time.After(d.timeout): 304 | // The detector is hung, it needs to be reinitialized 305 | d.logger.Errorw("Detector timeout", zap.Any("device", interpreter.device)) 306 | conf.Stop.Stop() // Exit after all threads complete 307 | return nil, status.Errorf(codes.Internal, "detect failed") 308 | } 309 | } 310 | <-complete // Complete no timeout 311 | 312 | // Capture Errors 313 | if invokeStatus != tflite.OK { 314 | d.logger.Errorw("Detector error", "id", request.Id, "status", invokeStatus, zap.Any("device", interpreter.device)) 315 | return &odrpc.DetectResponse{ 316 | Id: request.Id, 317 | Error: "detector error", 318 | }, nil 319 | } 320 | 321 | d.logger.Debugw("Inference complete", "inference_time", time.Now().Sub(inferenceStart), "duration", time.Now().Sub(start)) 322 | 323 | detections := make([]*odrpc.Detection, 0) 324 | 325 | switch d.outputFormat { 326 | case OutputFormat_4_TFLite_Detection_PostProcess: 327 | // Parse results 328 | var countResult float32 329 | interpreter.GetOutputTensor(3).CopyToBuffer(&countResult) 330 | count := int(countResult) 331 | 332 | // Check for a sane count value 333 | if count < 0 || count > 100 { 334 | d.logger.Errorw("Detector invalid results", "id", request.Id, "count", count, zap.Any("device", interpreter.device)) 335 | return &odrpc.DetectResponse{ 336 | Id: request.Id, 337 | Error: "detector invalid result", 338 | }, nil 339 | } 340 | 341 | locations := make([]float32, count*4, count*4) 342 | classes := make([]float32, count, count) 343 | scores := make([]float32, count, count) 344 | 345 | if count > 0 { 346 | interpreter.GetOutputTensor(0).CopyToBuffer(&locations[0]) 347 | interpreter.GetOutputTensor(1).CopyToBuffer(&classes[0]) 348 | interpreter.GetOutputTensor(2).CopyToBuffer(&scores[0]) 349 | } 350 | 351 | for i := 0; i < count; i++ { 352 | // Get the label 353 | label, ok := d.labels[int(classes[i])] 354 | if !ok { 355 | d.logger.Warnw("Missing label", "index", classes[i]) 356 | label = "unknown" 357 | } 358 | 359 | 
detections = append(detections, &odrpc.Detection{ 360 | Top: locations[(i * 4)], 361 | Left: locations[(i*4)+1], 362 | Bottom: locations[(i*4)+2], 363 | Right: locations[(i*4)+3], 364 | Label: label, 365 | Confidence: scores[i] * 100.0, 366 | }) 367 | } 368 | 369 | case OutputFormat_2_identity: 370 | 371 | // https://github.com/guichristmann/edge-tpu-tiny-yolo 372 | test := make([]float32, 12000) 373 | interpreter.GetOutputTensor(0).CopyToBuffer(&test[0]) 374 | d.logger.Warnw("RESULTS", "test", test) 375 | 376 | case OutputFormat_1_scores: 377 | scores := make([]uint8, len(d.labels), len(d.labels)) 378 | interpreter.GetOutputTensor(0).CopyToBuffer(scores) 379 | 380 | for i := range scores { 381 | // Get the label 382 | label, ok := d.labels[i] 383 | if !ok { 384 | d.logger.Warnw("Missing label", "index", i) 385 | label = "unknown" 386 | } 387 | 388 | detections = append(detections, &odrpc.Detection{ 389 | Top: 0.0, 390 | Left: 0.0, 391 | Bottom: 1.0, 392 | Right: 1.0, 393 | Label: label, 394 | Confidence: 100.0 * (float32(scores[i]) / 255.0), 395 | }) 396 | } 397 | } 398 | 399 | d.logger.Infow("Detection Complete", "id", request.Id, "duration", time.Since(start), "detections", len(detections), zap.Any("device", interpreter.device)) 400 | 401 | return &odrpc.DetectResponse{ 402 | Id: request.Id, 403 | Detections: detections, 404 | }, nil 405 | } 406 | -------------------------------------------------------------------------------- /detector/tflite/go-tflite/callback.go: -------------------------------------------------------------------------------- 1 | package tflite 2 | 3 | import "C" 4 | import ( 5 | "unsafe" 6 | 7 | "github.com/mattn/go-pointer" 8 | ) 9 | 10 | type callbackInfo struct { 11 | user_data interface{} 12 | f func(msg string, user_data interface{}) 13 | } 14 | 15 | //export _go_error_reporter 16 | func _go_error_reporter(user_data unsafe.Pointer, msg *C.char) { 17 | cb := pointer.Restore(user_data).(*callbackInfo) 18 | cb.f(C.GoString(msg), 
const (
	// The Device Types, mirroring the edgetpu_device_type C enum.
	TypeApexPCI DeviceType = C.EDGETPU_APEX_PCI
	TypeApexUSB DeviceType = C.EDGETPU_APEX_USB
)

// DeviceType identifies how an EdgeTPU is attached (PCI or USB).
type DeviceType uint32

// Device describes a single EdgeTPU device as reported by the runtime.
type Device struct {
	Type DeviceType
	Path string
}

// There are no options
type DelegateOptions struct {
}

// Delegate is the tflite delegate
type Delegate struct {
	d *C.TfLiteDelegate
}

// New creates a tflite delegate for the given EdgeTPU device.
// Returns nil when the underlying C call fails.
// NOTE(review): the C.CString allocated for the path is never freed here;
// presumably edgetpu_create_delegate copies or retains it — confirm before
// changing.
func New(device Device) delegates.Delegater {
	var d *C.TfLiteDelegate
	d = C.edgetpu_create_delegate(uint32(device.Type), C.CString(device.Path), nil, 0)
	if d == nil {
		return nil
	}
	return &Delegate{
		d: d,
	}
}

// Delete the delegate
func (etpu *Delegate) Delete() {
	C.edgetpu_free_delegate(etpu.d)
}

// Return a pointer to the underlying TfLiteDelegate (delegates.Delegater).
func (etpu *Delegate) Ptr() unsafe.Pointer {
	return unsafe.Pointer(etpu.d)
}

// Version fetches the EdgeTPU runtime version information
// NOTE(review): the returned C string is freed here; confirm edgetpu_version
// actually transfers ownership to the caller.
func Version() (string, error) {

	version := C.edgetpu_version()
	if version == nil {
		return "", fmt.Errorf("could not get version")
	}
	defer C.free(unsafe.Pointer(version))
	return C.GoString(version), nil

}

// Verbosity sets the edgetpu verbosity
func Verbosity(v int) {
	C.edgetpu_verbosity(C.int(v))
}

// DeviceList fetches a list of devices
func DeviceList() ([]Device, error) {

	// Fetch the list of devices
	var numDevices C.size_t
	cDevices := C.edgetpu_list_devices(&numDevices)

	if cDevices == nil {
		return []Device{}, nil
	}

	// Cast the result to a Go slice (bounded view over the C array; 1024 is
	// just an upper bound for the fake array type, numDevices is the real cap)
	deviceSlice := (*[1024]C.struct_edgetpu_device)(unsafe.Pointer(cDevices))[:numDevices:numDevices]

	// Convert the list to go struct
	var devices []Device
	for i := C.size_t(0); i < numDevices; i++ {
		devices = append(devices, Device{
			Type: DeviceType(deviceSlice[i]._type),
			Path: C.GoString(deviceSlice[i].path),
		})
	}

	// Free the list
	C.edgetpu_free_devices(cDevices)

	return devices, nil
}
//go:generate stringer -type TensorType,Status -output type_string.go .

// Model is TfLiteModel.
type Model struct {
	m *C.TfLiteModel
}

// NewModel create new Model from buffer.
// NOTE(review): C.CBytes allocates a copy that is never freed here;
// presumably TfLiteModelCreate requires the buffer to outlive the model —
// confirm before adding a free.
func NewModel(model_data []byte) *Model {
	m := C.TfLiteModelCreate(C.CBytes(model_data), C.size_t(len(model_data)))
	if m == nil {
		return nil
	}
	return &Model{m: m}
}

// NewModelFromFile create new Model from file data.
func NewModelFromFile(model_path string) *Model {
	ptr := C.CString(model_path)
	defer C.free(unsafe.Pointer(ptr))

	m := C.TfLiteModelCreateFromFile(ptr)
	if m == nil {
		return nil
	}
	return &Model{m: m}
}

// Delete delete instance of model. Safe to call on a nil receiver.
func (m *Model) Delete() {
	if m != nil {
		C.TfLiteModelDelete(m.m)
	}
}

// InterpreterOptions implement TfLiteInterpreterOptions.
type InterpreterOptions struct {
	o *C.TfLiteInterpreterOptions
}

// NewInterpreterOptions create new InterpreterOptions.
func NewInterpreterOptions() *InterpreterOptions {
	o := C.TfLiteInterpreterOptionsCreate()
	if o == nil {
		return nil
	}
	return &InterpreterOptions{o: o}
}

// SetNumThread set number of threads.
func (o *InterpreterOptions) SetNumThread(num_threads int) {
	C.TfLiteInterpreterOptionsSetNumThreads(o.o, C.int32_t(num_threads))
}

// SetErrorReporter set a function of reporter. The callback info is saved via
// go-pointer so the C side can hand it back to _go_error_reporter.
func (o *InterpreterOptions) SetErrorReporter(f func(string, interface{}), user_data interface{}) {
	C._TfLiteInterpreterOptionsSetErrorReporter(o.o, pointer.Save(&callbackInfo{
		user_data: user_data,
		f:         f,
	}))
}

// AddDelegate attaches a delegate (e.g. edgetpu) to the options.
func (o *InterpreterOptions) AddDelegate(d delegates.Delegater) {
	C.TfLiteInterpreterOptionsAddDelegate(o.o, (*C.TfLiteDelegate)(d.Ptr()))
}

// Delete delete instance of InterpreterOptions. Safe to call on a nil receiver.
func (o *InterpreterOptions) Delete() {
	if o != nil {
		C.TfLiteInterpreterOptionsDelete(o.o)
	}
}

// Interpreter implement TfLiteInterpreter.
type Interpreter struct {
	i *C.TfLiteInterpreter
}

// NewInterpreter create new Interpreter. options may be nil.
func NewInterpreter(model *Model, options *InterpreterOptions) *Interpreter {
	var o *C.TfLiteInterpreterOptions
	if options != nil {
		o = options.o
	}
	i := C.TfLiteInterpreterCreate(model.m, o)
	if i == nil {
		return nil
	}
	return &Interpreter{i: i}
}

// Delete delete instance of Interpreter. Safe to call on a nil receiver.
func (i *Interpreter) Delete() {
	if i != nil {
		C.TfLiteInterpreterDelete(i.i)
	}
}

// Tensor implement TfLiteTensor.
type Tensor struct {
	t *C.TfLiteTensor
}

// GetInputTensorCount return number of input tensors.
func (i *Interpreter) GetInputTensorCount() int {
	return int(C.TfLiteInterpreterGetInputTensorCount(i.i))
}

// GetInputTensor return input tensor specified by index, or nil.
func (i *Interpreter) GetInputTensor(index int) *Tensor {
	t := C.TfLiteInterpreterGetInputTensor(i.i, C.int32_t(index))
	if t == nil {
		return nil
	}
	return &Tensor{t: t}
}
137 | type Status int 138 | 139 | const ( 140 | OK Status = 0 141 | FAILED Status = 1 142 | // Error indicates a client-side failure (e.g. nil interpreter). Previously it was declared with no value, which in a Go const block repeats the previous spec and made Error == FAILED == 1; give it a distinct value so callers can tell the two apart. 142 | Error Status = 2 143 | ) 144 | 145 | // ResizeInputTensor resizes the input tensor specified by index to dims. Returns Error if dims is empty. 146 | func (i *Interpreter) ResizeInputTensor(index int, dims []int) Status { 146 | if len(dims) == 0 { 146 | return Error 146 | } 147 | // The C API expects a contiguous array of C int (32-bit). Go's int is 64-bit on most platforms, so casting &dims[0] directly makes C read interleaved halves of the values; convert element-by-element first. 147 | dims32 := make([]C.int, len(dims)) 147 | for j, d := range dims { 147 | dims32[j] = C.int(d) 147 | } 147 | s := C.TfLiteInterpreterResizeInputTensor(i.i, C.int32_t(index), &dims32[0], C.int32_t(len(dims32))) 148 | return Status(s) 149 | } 150 | 151 | // AllocateTensors allocates tensors for the interpreter. Returns Error if the interpreter is nil. 152 | func (i *Interpreter) AllocateTensors() Status { 153 | if i != nil { 154 | s := C.TfLiteInterpreterAllocateTensors(i.i) 155 | return Status(s) 156 | } 157 | return Error 158 | } 159 | 160 | // Invoke invoke the task. 161 | func (i *Interpreter) Invoke() Status { 162 | s := C.TfLiteInterpreterInvoke(i.i) 163 | return Status(s) 164 | } 165 | 166 | // GetOutputTensorCount return number of output tensors. 167 | func (i *Interpreter) GetOutputTensorCount() int { 168 | return int(C.TfLiteInterpreterGetOutputTensorCount(i.i)) 169 | } 170 | 171 | // GetOutputTensor return output tensor specified by index. 172 | func (i *Interpreter) GetOutputTensor(index int) *Tensor { 173 | t := C.TfLiteInterpreterGetOutputTensor(i.i, C.int32_t(index)) 174 | if t == nil { 175 | return nil 176 | } 177 | return &Tensor{t: t} 178 | } 179 | 180 | // TensorType is types of the tensor. 181 | type TensorType int 182 | 183 | const ( 184 | NoType TensorType = 0 185 | Float32 TensorType = 1 186 | Int32 TensorType = 2 187 | UInt8 TensorType = 3 188 | Int64 TensorType = 4 189 | String TensorType = 5 190 | Bool TensorType = 6 191 | Int16 TensorType = 7 192 | Complex64 TensorType = 8 193 | Int8 TensorType = 9 194 | ) 195 | 196 | // Type return TensorType. 197 | func (t *Tensor) Type() TensorType { 198 | return TensorType(C.TfLiteTensorType(t.t)) 199 | } 200 | 201 | // NumDims return number of dimensions. 
202 | func (t *Tensor) NumDims() int { 203 | return int(C.TfLiteTensorNumDims(t.t)) 204 | } 205 | 206 | // Dim return dimension of the element specified by index. 207 | func (t *Tensor) Dim(index int) int { 208 | return int(C.TfLiteTensorDim(t.t, C.int32_t(index))) 209 | } 210 | 211 | // ByteSize return byte size of the tensor. 212 | func (t *Tensor) ByteSize() uint { 213 | return uint(C.TfLiteTensorByteSize(t.t)) 214 | } 215 | 216 | // Data return pointer of buffer. 217 | func (t *Tensor) Data() unsafe.Pointer { 218 | return C.TfLiteTensorData(t.t) 219 | } 220 | 221 | // Name return name of the tensor. 222 | func (t *Tensor) Name() string { 223 | return C.GoString(C.TfLiteTensorName(t.t)) 224 | } 225 | 226 | // Shape return shape of the tensor. 227 | func (t *Tensor) Shape() []int { 228 | shape := make([]int, t.NumDims()) 229 | for i := 0; i < t.NumDims(); i++ { 230 | shape[i] = t.Dim(i) 231 | } 232 | return shape 233 | } 234 | 235 | // QuantizationParams implement TfLiteQuantizationParams. 236 | type QuantizationParams struct { 237 | Scale float64 238 | ZeroPoint int 239 | } 240 | 241 | // QuantizationParams return quantization parameters of the tensor. 242 | func (t *Tensor) QuantizationParams() QuantizationParams { 243 | q := C.TfLiteTensorQuantizationParams(t.t) 244 | return QuantizationParams{ 245 | Scale: float64(q.scale), 246 | ZeroPoint: int(q.zero_point), 247 | } 248 | } 249 | 250 | // CopyFromBuffer write buffer to the tensor. 251 | func (t *Tensor) CopyFromBuffer(b interface{}) Status { 252 | return Status(C.TfLiteTensorCopyFromBuffer(t.t, unsafe.Pointer(reflect.ValueOf(b).Pointer()), C.size_t(t.ByteSize()))) 253 | } 254 | 255 | // CopyToBuffer write buffer from the tensor. 
256 | func (t *Tensor) CopyToBuffer(b interface{}) Status { 257 | return Status(C.TfLiteTensorCopyToBuffer(t.t, unsafe.Pointer(reflect.ValueOf(b).Pointer()), C.size_t(t.ByteSize()))) 258 | } 259 | -------------------------------------------------------------------------------- /detector/tflite/go-tflite/tflite.go.h: -------------------------------------------------------------------------------- 1 | #ifndef GO_TFLITE_H 2 | #define GO_TFLITE_H 3 | 4 | #define _GNU_SOURCE 5 | #include 6 | #include 7 | #include 8 | #include 9 | 10 | extern void _go_error_reporter(void*, char*); 11 | 12 | static void 13 | _error_reporter(void *user_data, const char* format, va_list args) { 14 | char *ptr; 15 | if (asprintf(&ptr, format, args)) {} 16 | _go_error_reporter(user_data, ptr); 17 | free(ptr); 18 | } 19 | 20 | static void 21 | _TfLiteInterpreterOptionsSetErrorReporter(TfLiteInterpreterOptions* options, void* user_data) { 22 | TfLiteInterpreterOptionsSetErrorReporter(options, _error_reporter, user_data); 23 | } 24 | #endif 25 | -------------------------------------------------------------------------------- /detector/tflite/go-tflite/tflite_experimental.go.h: -------------------------------------------------------------------------------- 1 | #ifndef GO_TFLITE_H 2 | #define GO_TFLITE_H 3 | 4 | #define _GNU_SOURCE 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | #endif 11 | -------------------------------------------------------------------------------- /detector/tflite/go-tflite/tflite_test.go: -------------------------------------------------------------------------------- 1 | package tflite 2 | 3 | import ( 4 | "testing" 5 | ) 6 | 7 | func TestXOR(t *testing.T) { 8 | model := NewModelFromFile("testdata/xor_model.tflite") 9 | if model == nil { 10 | t.Fatal("cannot load model") 11 | } 12 | defer model.Delete() 13 | 14 | options := NewInterpreterOptions() 15 | defer options.Delete() 16 | 17 | interpreter := NewInterpreter(model, options) 18 | defer 
interpreter.Delete() 19 | 20 | interpreter.AllocateTensors() 21 | 22 | tests := []struct { 23 | input []float32 24 | want int 25 | }{ 26 | {input: []float32{0, 0}, want: 0}, 27 | {input: []float32{0, 1}, want: 1}, 28 | {input: []float32{1, 0}, want: 1}, 29 | {input: []float32{1, 1}, want: 0}, 30 | } 31 | 32 | for _, test := range tests { 33 | input := interpreter.GetInputTensor(0) 34 | float32s := input.Float32s() 35 | float32s[0], float32s[1] = test.input[0], test.input[1] 36 | interpreter.Invoke() 37 | 38 | output := interpreter.GetOutputTensor(0) 39 | float32s = output.Float32s() 40 | got := int(float32s[0] + 0.5) 41 | 42 | if got != test.want { 43 | t.Fatalf("want %v but got %v", test.want, got) 44 | } 45 | } 46 | } 47 | -------------------------------------------------------------------------------- /detector/tflite/go-tflite/tflite_type.go: -------------------------------------------------------------------------------- 1 | package tflite 2 | 3 | /* 4 | #ifndef GO_TFLITE_H 5 | #include "tflite.go.h" 6 | #endif 7 | */ 8 | import "C" 9 | import "errors" 10 | 11 | var ( 12 | // ErrTypeMismatch is type mismatch. 13 | ErrTypeMismatch = errors.New("type mismatch") 14 | // ErrBadTensor is bad tensor. 15 | ErrBadTensor = errors.New("bad tensor") 16 | ) 17 | 18 | // SetInt32s sets int32s. 19 | func (t *Tensor) SetInt32s(v []int32) error { 20 | if t.Type() != Int32 { 21 | return ErrTypeMismatch 22 | } 23 | ptr := C.TfLiteTensorData(t.t) 24 | if ptr == nil { 25 | return ErrBadTensor 26 | } 27 | n := t.ByteSize() / 4 28 | to := (*((*[1<<29 - 1]int32)(ptr)))[:n] 29 | copy(to, v) 30 | return nil 31 | } 32 | 33 | // Int32s returns int32s. 34 | func (t *Tensor) Int32s() []int32 { 35 | if t.Type() != Int32 { 36 | return nil 37 | } 38 | ptr := C.TfLiteTensorData(t.t) 39 | if ptr == nil { 40 | return nil 41 | } 42 | n := t.ByteSize() / 4 43 | return (*((*[1<<29 - 1]int32)(ptr)))[:n] 44 | } 45 | 46 | // SetFloat32s sets float32s. 
47 | func (t *Tensor) SetFloat32s(v []float32) error { 48 | if t.Type() != Float32 { 49 | return ErrTypeMismatch 50 | } 51 | ptr := C.TfLiteTensorData(t.t) 52 | if ptr == nil { 53 | return ErrBadTensor 54 | } 55 | n := t.ByteSize() / 4 56 | to := (*((*[1<<29 - 1]float32)(ptr)))[:n] 57 | copy(to, v) 58 | return nil 59 | } 60 | 61 | // Float32s returns float32s. 62 | func (t *Tensor) Float32s() []float32 { 63 | if t.Type() != Float32 { 64 | return nil 65 | } 66 | ptr := C.TfLiteTensorData(t.t) 67 | if ptr == nil { 68 | return nil 69 | } 70 | n := t.ByteSize() / 4 71 | return (*((*[1<<29 - 1]float32)(ptr)))[:n] 72 | } 73 | 74 | // Float32At returns float32 value located in the dimension. 75 | func (t *Tensor) Float32At(at ...int) float32 { 76 | pos := 0 77 | for i := 0; i < t.NumDims(); i++ { 78 | pos = pos*t.Dim(i) + at[i] 79 | } 80 | return t.Float32s()[pos] 81 | } 82 | 83 | // SetUint8s sets uint8s. 84 | func (t *Tensor) SetUint8s(v []uint8) error { 85 | if t.Type() != UInt8 { 86 | return ErrTypeMismatch 87 | } 88 | ptr := C.TfLiteTensorData(t.t) 89 | if ptr == nil { 90 | return ErrBadTensor 91 | } 92 | n := t.ByteSize() 93 | to := (*((*[1<<29 - 1]uint8)(ptr)))[:n] 94 | copy(to, v) 95 | return nil 96 | } 97 | 98 | // UInt8s returns uint8s. 99 | func (t *Tensor) UInt8s() []uint8 { 100 | if t.Type() != UInt8 { 101 | return nil 102 | } 103 | ptr := C.TfLiteTensorData(t.t) 104 | if ptr == nil { 105 | return nil 106 | } 107 | n := t.ByteSize() 108 | return (*((*[1<<29 - 1]uint8)(ptr)))[:n] 109 | } 110 | 111 | // SetInt64s sets int64s. 112 | func (t *Tensor) SetInt64s(v []int64) error { 113 | if t.Type() != Int64 { 114 | return ErrTypeMismatch 115 | } 116 | ptr := C.TfLiteTensorData(t.t) 117 | if ptr == nil { 118 | return ErrBadTensor 119 | } 120 | n := t.ByteSize() / 8 121 | to := (*((*[1<<28 - 1]int64)(ptr)))[:n] 122 | copy(to, v) 123 | return nil 124 | } 125 | 126 | // Int64s returns int64s. 
127 | func (t *Tensor) Int64s() []int64 { 128 | if t.Type() != Int64 { 129 | return nil 130 | } 131 | ptr := C.TfLiteTensorData(t.t) 132 | if ptr == nil { 133 | return nil 134 | } 135 | n := t.ByteSize() / 8 136 | return (*((*[1<<28 - 1]int64)(ptr)))[:n] 137 | } 138 | 139 | // SetInt16s sets int16s. 140 | func (t *Tensor) SetInt16s(v []int16) error { 141 | if t.Type() != Int16 { 142 | return ErrTypeMismatch 143 | } 144 | ptr := C.TfLiteTensorData(t.t) 145 | if ptr == nil { 146 | return ErrBadTensor 147 | } 148 | n := t.ByteSize() / 2 149 | to := (*((*[1<<29 - 1]int16)(ptr)))[:n] 150 | copy(to, v) 151 | return nil 152 | } 153 | 154 | // Int16s returns int16s. 155 | func (t *Tensor) Int16s() []int16 { 156 | if t.Type() != Int16 { 157 | return nil 158 | } 159 | ptr := C.TfLiteTensorData(t.t) 160 | if ptr == nil { 161 | return nil 162 | } 163 | n := t.ByteSize() / 2 164 | return (*((*[1<<29 - 1]int16)(ptr)))[:n] 165 | } 166 | 167 | // SetInt8s sets int8s. 168 | func (t *Tensor) SetInt8s(v []int8) error { 169 | if t.Type() != Int8 { 170 | return ErrTypeMismatch 171 | } 172 | ptr := C.TfLiteTensorData(t.t) 173 | if ptr == nil { 174 | return ErrBadTensor 175 | } 176 | n := t.ByteSize() 177 | to := (*((*[1<<29 - 1]int8)(ptr)))[:n] 178 | copy(to, v) 179 | return nil 180 | } 181 | 182 | // Int8s returns int8s. 183 | func (t *Tensor) Int8s() []int8 { 184 | if t.Type() != Int8 { 185 | return nil 186 | } 187 | ptr := C.TfLiteTensorData(t.t) 188 | if ptr == nil { 189 | return nil 190 | } 191 | n := t.ByteSize() 192 | return (*((*[1<<29 - 1]int8)(ptr)))[:n] 193 | } 194 | 195 | // String returns name of tensor. 
196 | func (t *Tensor) String() string { 197 | return t.Name() 198 | } 199 | -------------------------------------------------------------------------------- /detector/tflite/go-tflite/type_string.go: -------------------------------------------------------------------------------- 1 | // Code generated by "stringer -type TensorType,Status -output type_string.go ."; DO NOT EDIT. 2 | 3 | package tflite 4 | 5 | import "strconv" 6 | 7 | const _TensorType_name = "NoTypeFloat32Int32UInt8Int64StringBoolInt16Complex64Int8" 8 | 9 | var _TensorType_index = [...]uint8{0, 6, 13, 18, 23, 28, 34, 38, 43, 52, 56} 10 | 11 | func (i TensorType) String() string { 12 | if i < 0 || i >= TensorType(len(_TensorType_index)-1) { 13 | return "TensorType(" + strconv.FormatInt(int64(i), 10) + ")" 14 | } 15 | return _TensorType_name[_TensorType_index[i]:_TensorType_index[i+1]] 16 | } 17 | 18 | const _Status_name = "OKFAILED" 19 | 20 | var _Status_index = [...]uint8{0, 2, 8} 21 | 22 | func (i Status) String() string { 23 | if i < 0 || i >= Status(len(_Status_index)-1) { 24 | return "Status(" + strconv.FormatInt(int64(i), 10) + ")" 25 | } 26 | return _Status_name[_Status_index[i]:_Status_index[i+1]] 27 | } 28 | -------------------------------------------------------------------------------- /detector/tflite/ppm.go: -------------------------------------------------------------------------------- 1 | package tflite 2 | 3 | import ( 4 | "strconv" 5 | ) 6 | 7 | type PPMInfo struct { 8 | Width int 9 | Height int 10 | Offset int 11 | } 12 | 13 | func isSpace(b byte) bool { 14 | switch b { 15 | case ' ': 16 | return true 17 | case '\t': 18 | return true 19 | case '\n': 20 | return true 21 | case '\r': 22 | return true 23 | } 24 | return false 25 | } 26 | 27 | func FindPPMData(data []byte) *PPMInfo { 28 | 29 | i := new(PPMInfo) 30 | 31 | // Get the next header token 32 | getToken := func() string { 33 | // Get Token 34 | token := make([]byte, 0) 35 | for i.Offset < len(data) && !isSpace(data[i.Offset]) 
{ 36 | token = append(token, data[i.Offset]) 37 | i.Offset++ 38 | } 39 | // Eat Spaces 40 | for i.Offset < len(data) && isSpace(data[i.Offset]) { 41 | i.Offset++ 42 | } 43 | return string(token) 44 | } 45 | 46 | // First Token 47 | if getToken() != "P6" { 48 | return nil 49 | } 50 | 51 | var err error 52 | 53 | i.Width, err = strconv.Atoi(getToken()) 54 | if err != nil { 55 | return nil 56 | } 57 | 58 | i.Height, err = strconv.Atoi(getToken()) 59 | if err != nil { 60 | return nil 61 | } 62 | 63 | if maxVal, err := strconv.Atoi(getToken()); err != nil || maxVal != 255 { 64 | return nil 65 | } 66 | 67 | return i 68 | } 69 | -------------------------------------------------------------------------------- /examples/grpcclient-single.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "io/ioutil" 7 | "log" 8 | "os" 9 | "sync" 10 | "time" 11 | 12 | "google.golang.org/grpc" 13 | "google.golang.org/grpc/metadata" 14 | 15 | "github.com/snowzach/doods/odrpc" 16 | ) 17 | 18 | func main() { 19 | 20 | if len(os.Args) < 4 { 21 | fmt.Println("How to run:\n\tgrpcclient-single [source file] [doods server] [detector]") 22 | return 23 | } 24 | 25 | // parse args 26 | sourceFile := os.Args[1] 27 | server := os.Args[2] 28 | detector := os.Args[3] 29 | 30 | dialOptions := []grpc.DialOption{ 31 | grpc.WithBlock(), 32 | grpc.WithInsecure(), 33 | } 34 | 35 | // Set up a connection to the gRPC server. 36 | conn, err := grpc.Dial(server, dialOptions...) 
37 | if err != nil { 38 | log.Fatalf("Could not connect: %v", err) 39 | } 40 | defer conn.Close() 41 | 42 | // gRPC version Client 43 | client := odrpc.NewOdrpcClient(conn) 44 | 45 | // Read the source image, failing fast if it is missing or unreadable (matches the check done in grpcclient-stream.go; previously err was silently ignored and a nil payload was sent) 45 | img, err := ioutil.ReadFile(sourceFile) 46 | if err != nil { 46 | log.Fatalf("Could not load %s %v", sourceFile, err) 46 | } 47 | request := &odrpc.DetectRequest{ 48 | Data: img, 49 | DetectorName: detector, 50 | Detect: map[string]float32{ 51 | "*": 50, // 52 | }, 53 | } 54 | 55 | // Authentication information - ignored if not required 56 | ctx := metadata.AppendToOutgoingContext(context.Background(), odrpc.DoodsAuthKeyHeader, "test123") 57 | 58 | start := time.Now() 59 | var wg sync.WaitGroup 60 | for x := 0; x < 200; x++ { 61 | wg.Add(1) 62 | go func() { 63 | response, err := client.Detect(ctx, request) 64 | if err != nil { 65 | log.Printf("Error: %v", err) 66 | } else { 67 | log.Printf("Processed: %v", response) 68 | } 69 | wg.Done() 70 | }() 71 | } 72 | wg.Wait() 73 | log.Printf("Done. Took: %v", time.Since(start).Seconds()) 74 | } 75 | -------------------------------------------------------------------------------- /examples/grpcclient-stream.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "io" 7 | "io/ioutil" 8 | "log" 9 | "os" 10 | "sync" 11 | "time" 12 | 13 | "google.golang.org/grpc" 14 | "google.golang.org/grpc/metadata" 15 | 16 | "github.com/snowzach/doods/odrpc" 17 | ) 18 | 19 | func main() { 20 | 21 | if len(os.Args) < 4 { 22 | fmt.Println("How to run:\n\tgrpcclient-stream [source file] [doods server] [detector]") 23 | return 24 | } 25 | 26 | // parse args 27 | sourceFile := os.Args[1] 28 | server := os.Args[2] 29 | detector := os.Args[3] 30 | 31 | dialOptions := []grpc.DialOption{ 32 | grpc.WithBlock(), 33 | grpc.WithInsecure(), 34 | } 35 | 36 | // Set up a connection to the gRPC server. 37 | conn, err := grpc.Dial(server, dialOptions...) 
38 | if err != nil { 39 | log.Fatalf("Could not connect: %v", err) 40 | } 41 | defer conn.Close() 42 | 43 | // gRPC version Client 44 | client := odrpc.NewOdrpcClient(conn) 45 | 46 | // Create the request 47 | img, err := ioutil.ReadFile(sourceFile) 48 | if err != nil { 49 | log.Fatalf("Could not load %s %v", sourceFile, err) 50 | } 51 | 52 | // Authentication information - ignored if not requried 53 | ctx := metadata.AppendToOutgoingContext(context.Background(), odrpc.DoodsAuthKeyHeader, "test123") 54 | // Open Stream 55 | stream, err := client.DetectStream(ctx) 56 | if err != nil { 57 | log.Fatalf("Could not stream: %v", err) 58 | } 59 | 60 | start := time.Now() 61 | var wg sync.WaitGroup 62 | doneSend := make(chan struct{}) 63 | 64 | // Send requests 65 | go func() { 66 | for x := 0; x < 200; x++ { 67 | wg.Add(1) 68 | request := &odrpc.DetectRequest{ 69 | Id: fmt.Sprintf("%d", x), 70 | DetectorName: detector, 71 | Data: img, 72 | Detect: map[string]float32{ 73 | "*": 50, // 74 | }, 75 | } 76 | if err := stream.Send(request); err != nil { 77 | log.Fatalf("could not stream send %v", err) 78 | } 79 | } 80 | close(doneSend) 81 | }() 82 | 83 | // Parse results 84 | go func() { 85 | for { 86 | response, err := stream.Recv() 87 | if err == io.EOF { 88 | break 89 | } 90 | if err != nil { 91 | log.Fatalf("can not receive %v", err) 92 | } 93 | wg.Done() 94 | log.Printf("Processed: %v", response) 95 | } 96 | }() 97 | 98 | // Wait until done sending and done receiving then close the stream 99 | <-doneSend 100 | wg.Wait() 101 | if err := stream.CloseSend(); err != nil { 102 | log.Fatal(err.Error()) 103 | } 104 | log.Printf("Done. 
Took: %v", time.Since(start).Seconds()) 105 | } 106 | -------------------------------------------------------------------------------- /examples/rtspdetector.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "image" 7 | "image/color" 8 | "io" 9 | "log" 10 | "net/http" 11 | _ "net/http/pprof" 12 | "os" 13 | "sync" 14 | "time" 15 | 16 | empty "github.com/golang/protobuf/ptypes/empty" 17 | "github.com/snowzach/mjpeg" 18 | "gocv.io/x/gocv" 19 | "google.golang.org/grpc" 20 | 21 | "github.com/snowzach/doods/odrpc" 22 | ) 23 | 24 | var ( 25 | deviceID int 26 | err error 27 | capture *gocv.VideoCapture 28 | stream *mjpeg.Stream 29 | ) 30 | 31 | func main() { 32 | if len(os.Args) < 5 { 33 | fmt.Println("How to run:\n\trtspdetector [source url] [host:port] [doods server] [detector]") 34 | return 35 | } 36 | 37 | // parse args 38 | source := os.Args[1] 39 | host := os.Args[2] 40 | server := os.Args[3] 41 | detector := os.Args[4] 42 | 43 | // open webcam 44 | capture, err = gocv.OpenVideoCapture(source) 45 | if err != nil { 46 | fmt.Printf("Error opening capture device: %v: %v\n", source, err) 47 | return 48 | } 49 | defer capture.Close() 50 | 51 | // create the mjpeg stream 52 | stream = mjpeg.NewStream(50 * time.Millisecond) 53 | 54 | // start capturing 55 | go mjpegCapture(server, detector) 56 | 57 | fmt.Println("Capturing. Point your browser to " + host) 58 | 59 | // start http server 60 | http.Handle("/", stream) 61 | log.Fatal(http.ListenAndServe(host, nil)) 62 | } 63 | 64 | func mjpegCapture(server string, detectorName string) { 65 | 66 | dialOptions := []grpc.DialOption{ 67 | grpc.WithBlock(), 68 | grpc.WithInsecure(), 69 | grpc.WithTimeout(5 * time.Second), 70 | grpc.WithMaxMsgSize(64000000), 71 | } 72 | 73 | // Set up a connection to the gRPC server. 74 | conn, err := grpc.Dial(server, dialOptions...) 
75 | if err != nil { 76 | log.Fatalf("Could not connect to doods: %v", err) 77 | } 78 | defer conn.Close() 79 | 80 | // gRPC version Client 81 | client := odrpc.NewOdrpcClient(conn) 82 | 83 | // Fetch the detectors available 84 | detectorsResponse, err := client.GetDetectors(context.Background(), &empty.Empty{}) 85 | if err != nil { 86 | log.Fatalf("Could not get detectors: %v", err) 87 | } 88 | // Find our requested detector 89 | var detector *odrpc.Detector 90 | for _, d := range detectorsResponse.Detectors { 91 | if d.Name == detectorName { 92 | detector = d 93 | break 94 | } 95 | } 96 | if detector == nil { 97 | log.Fatalf("Could not find detector: %s\n", detectorName) 98 | } 99 | 100 | // Start the stream 101 | detectStream, err := client.DetectStream(context.Background()) 102 | if err != nil { 103 | log.Fatalf("Could not stream: %v", err) 104 | } 105 | 106 | img := gocv.NewMat() 107 | defer img.Close() 108 | detectImg := gocv.NewMat() 109 | defer detectImg.Close() 110 | 111 | // color for the rect for detectins 112 | green := color.RGBA{0, 255, 0, 0} 113 | var rs = make([]image.Rectangle, 0) 114 | var labels = make([]string, 0) 115 | var confidences = make([]float32, 0) 116 | var m sync.Mutex 117 | var detectorReady bool = true 118 | 119 | for { 120 | // Read an image 121 | if ok := capture.Read(&img); !ok { 122 | fmt.Printf("Device closed: %v\n", deviceID) 123 | return 124 | } 125 | if img.Empty() { 126 | continue 127 | } 128 | height := img.Rows() 129 | width := img.Cols() 130 | 131 | // Setup detection 132 | request := &odrpc.DetectRequest{ 133 | DetectorName: detector.Name, 134 | Detect: map[string]float32{ 135 | "*": 90, // 136 | }, 137 | } 138 | 139 | m.Lock() 140 | if detectorReady { 141 | 142 | // If the detector requires a specific size, resize before setting the data to the detector 143 | if detector.Width > 0 { 144 | gocv.Resize(img.Region(image.Rectangle{Min: image.Point{X: 0, Y: 0}, Max: image.Point{X: width, Y: height}}), &detectImg, 
image.Point{X: int(detector.Width), Y: int(detector.Height)}, 0.0, 0.0, gocv.InterpolationNearestNeighbor) 145 | request.Data, err = gocv.IMEncode(".ppm", detectImg) 146 | if err != nil { 147 | continue 148 | } 149 | } else { 150 | request.Data, err = gocv.IMEncode(".ppm", img) 151 | if err != nil { 152 | continue 153 | } 154 | } 155 | 156 | detectorReady = false 157 | if err := detectStream.Send(request); err != nil { 158 | log.Fatalf("could not stream send %v", err) 159 | } 160 | go func() { 161 | response, err := detectStream.Recv() 162 | if err == io.EOF { 163 | log.Fatalf("can not receive %v", err) 164 | } 165 | if err != nil { 166 | log.Fatalf("can not receive %v", err) 167 | } 168 | log.Printf("Processed: %v", response) 169 | 170 | m.Lock() 171 | detections := len(response.Detections) 172 | rs = make([]image.Rectangle, detections, detections) 173 | labels = make([]string, detections, detections) 174 | confidences = make([]float32, detections, detections) 175 | for x := 0; x < detections; x++ { 176 | rs[x] = image.Rectangle{ 177 | Min: image.Point{X: int(response.Detections[x].Left * float32(width)), Y: int(response.Detections[x].Top * float32(height))}, 178 | Max: image.Point{X: int(response.Detections[x].Right * float32(width)), Y: int(response.Detections[x].Bottom * float32(height))}, 179 | } 180 | labels[x] = response.Detections[x].Label 181 | confidences[x] = response.Detections[x].Confidence 182 | } 183 | detectorReady = true 184 | m.Unlock() 185 | }() 186 | } 187 | 188 | // Keep drawing the same rectangles until a new detection is ready 189 | for x := 0; x < len(rs); x++ { 190 | gocv.Rectangle(&img, rs[x], green, 1) 191 | pt := image.Pt(rs[x].Min.X, rs[x].Min.Y-2) 192 | gocv.PutText(&img, fmt.Sprintf("%s %0.0f", labels[x], confidences[x]), pt, gocv.FontHersheyPlain, 1.5, green, 1) 193 | } 194 | m.Unlock() 195 | 196 | // re-encode with boxes 197 | request.Data, err = gocv.IMEncode(".jpg", img) 198 | if err != nil { 199 | continue 200 | } 201 | 202 | // 
buf, _ := gocv.IMEncode(".jpg", img) 203 | stream.UpdateJPEG(request.Data) 204 | 205 | } 206 | } 207 | -------------------------------------------------------------------------------- /fetch_models.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | mkdir -p models 3 | # coco_ssd_mobilenet_v1_1.0_quant_2018_06_29 4 | wget https://storage.googleapis.com/download.tensorflow.org/models/tflite/coco_ssd_mobilenet_v1_1.0_quant_2018_06_29.zip && unzip coco_ssd_mobilenet_v1_1.0_quant_2018_06_29.zip && rm coco_ssd_mobilenet_v1_1.0_quant_2018_06_29.zip && mv detect.tflite models/coco_ssd_mobilenet_v1_1.0_quant.tflite && rm labelmap.txt 5 | wget https://dl.google.com/coral/canned_models/coco_labels.txt && mv coco_labels.txt models/coco_labels0.txt 6 | # mobilenet_ssd_v2_coco_quant_postprocess_edgetpu 7 | wget https://dl.google.com/coral/canned_models/mobilenet_ssd_v2_coco_quant_postprocess_edgetpu.tflite && mv mobilenet_ssd_v2_coco_quant_postprocess_edgetpu.tflite models/mobilenet_ssd_v2_coco_quant_postprocess_edgetpu.tflite 8 | # faster_rcnn_inception_v2_coco_2018_01_28 9 | wget http://download.tensorflow.org/models/object_detection/faster_rcnn_inception_v2_coco_2018_01_28.tar.gz && tar -zxvf faster_rcnn_inception_v2_coco_2018_01_28.tar.gz faster_rcnn_inception_v2_coco_2018_01_28/frozen_inference_graph.pb --strip=1 && mv frozen_inference_graph.pb models/faster_rcnn_inception_v2_coco_2018_01_28.pb && rm faster_rcnn_inception_v2_coco_2018_01_28.tar.gz 10 | wget https://raw.githubusercontent.com/amikelive/coco-labels/master/coco-labels-2014_2017.txt && mv coco-labels-2014_2017.txt models/coco_labels1.txt 11 | 12 | cat << EOF > example.yaml 13 | doods: 14 | detectors: 15 | - name: default 16 | type: tflite 17 | modelFile: models/coco_ssd_mobilenet_v1_1.0_quant.tflite 18 | labelFile: models/coco_labels0.txt 19 | numThreads: 0 20 | numConcurrent: 4 21 | - name: edgetpu 22 | type: tflite 23 | modelFile: 
models/mobilenet_ssd_v2_coco_quant_postprocess_edgetpu.tflite 24 | labelFile: models/coco_labels0.txt 25 | numThreads: 0 26 | numConcurrent: 4 27 | hwAccel: true 28 | - name: tensorflow 29 | type: tensorflow 30 | modelFile: models/faster_rcnn_inception_v2_coco_2018_01_28.pb 31 | labelFile: models/coco_labels1.txt 32 | width: 224 33 | height: 224 34 | numThreads: 0 35 | numConcurrent: 4 36 | EOF -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/snowzach/doods 2 | 3 | go 1.13 4 | 5 | require ( 6 | github.com/blendle/zapdriver v1.3.1 7 | github.com/go-chi/chi v1.5.0 8 | github.com/go-chi/cors v1.1.1 9 | github.com/go-chi/render v1.0.1 10 | github.com/gogo/protobuf v1.3.1 11 | github.com/golang/protobuf v1.4.3 12 | github.com/grpc-ecosystem/go-grpc-middleware v1.2.2 13 | github.com/grpc-ecosystem/grpc-gateway v1.16.0 14 | github.com/lmittmann/ppm v1.0.0 15 | github.com/mattn/go-pointer v0.0.1 16 | github.com/snowzach/certtools v1.0.2 17 | github.com/spf13/cobra v1.1.1 18 | github.com/spf13/viper v1.7.1 19 | github.com/tensorflow/tensorflow v2.0.3+incompatible 20 | go.uber.org/zap v1.16.0 21 | gocv.io/x/gocv v0.25.1-0.20201108120252-7f525fdbcb78 22 | golang.org/x/image v0.0.0-20200927104501-e162460cd6b5 23 | golang.org/x/net v0.0.0-20201110031124-69a78807bb2b 24 | google.golang.org/genproto v0.0.0-20201119123407-9b1e624d6bc4 25 | google.golang.org/grpc v1.33.2 26 | ) 27 | 28 | //replace github.com/tensorflow/tensorflow v2.3.1+incompatible => github.com/tensorflow/tensorflow v2.0.3+incompatible 29 | -------------------------------------------------------------------------------- /main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "github.com/snowzach/doods/cmd" 5 | ) 6 | 7 | func main() { 8 | cmd.Execute() 9 | } 10 | 
-------------------------------------------------------------------------------- /odrpc/odrpc.go: -------------------------------------------------------------------------------- 1 | package odrpc 2 | 3 | const ( 4 | DoodsAuthKeyHeader = "doods-auth-key" 5 | ) 6 | -------------------------------------------------------------------------------- /odrpc/raw.go: -------------------------------------------------------------------------------- 1 | package odrpc 2 | 3 | import ( 4 | "bytes" 5 | "encoding/base64" 6 | ) 7 | 8 | // Images are byte arrays 9 | type Raw []byte 10 | 11 | // MarshalJSON for Raw fields is represented as base64 12 | func (r *Raw) MarshalJSON() ([]byte, error) { 13 | if r == nil || *r == nil || len(*r) == 0 { 14 | return []byte(`""`), nil 15 | } 16 | return []byte(`"` + base64.StdEncoding.EncodeToString(*r) + `"`), nil 17 | } 18 | 19 | // UnmarshalJSON for Raw fields is parsed as base64 20 | func (r *Raw) UnmarshalJSON(in []byte) error { 21 | // Empty input yields an empty value. (The previous code declared `var ret *[]byte` and wrote through it, a guaranteed nil-pointer panic on this path.) 22 | if len(in) == 0 { 23 | *r = Raw{} 24 | return nil 25 | } 26 | // Remove the beginning and ending " 27 | in = bytes.Trim(in, `"`) 28 | buf := make([]byte, base64.StdEncoding.DecodedLen(len(in))) 29 | // DecodedLen is only an upper bound; slice to the actual decoded count n so padded base64 input does not leave trailing zero bytes in the result. 30 | n, err := base64.StdEncoding.Decode(buf, in) 31 | *r = buf[:n] 32 | return err 33 | } 34 | -------------------------------------------------------------------------------- /odrpc/rpc.pb.gw.go: -------------------------------------------------------------------------------- 1 | // Code generated by protoc-gen-grpc-gateway. DO NOT EDIT. 2 | // source: odrpc/rpc.proto 3 | 4 | /* 5 | Package odrpc is a reverse proxy. 6 | 7 | It translates gRPC into RESTful JSON APIs.
8 | */ 9 | package odrpc 10 | 11 | import ( 12 | "context" 13 | "io" 14 | "net/http" 15 | 16 | "github.com/golang/protobuf/proto" 17 | "github.com/golang/protobuf/ptypes/empty" 18 | "github.com/grpc-ecosystem/grpc-gateway/runtime" 19 | "github.com/grpc-ecosystem/grpc-gateway/utilities" 20 | "google.golang.org/grpc" 21 | "google.golang.org/grpc/codes" 22 | "google.golang.org/grpc/grpclog" 23 | "google.golang.org/grpc/status" 24 | ) 25 | 26 | var _ codes.Code 27 | var _ io.Reader 28 | var _ status.Status 29 | var _ = runtime.String 30 | var _ = utilities.NewDoubleArray 31 | 32 | func request_Odrpc_GetDetectors_0(ctx context.Context, marshaler runtime.Marshaler, client OdrpcClient, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) { 33 | var protoReq empty.Empty 34 | var metadata runtime.ServerMetadata 35 | 36 | msg, err := client.GetDetectors(ctx, &protoReq, grpc.Header(&metadata.HeaderMD), grpc.Trailer(&metadata.TrailerMD)) 37 | return msg, metadata, err 38 | 39 | } 40 | 41 | func local_request_Odrpc_GetDetectors_0(ctx context.Context, marshaler runtime.Marshaler, server OdrpcServer, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) { 42 | var protoReq empty.Empty 43 | var metadata runtime.ServerMetadata 44 | 45 | msg, err := server.GetDetectors(ctx, &protoReq) 46 | return msg, metadata, err 47 | 48 | } 49 | 50 | func request_Odrpc_Detect_0(ctx context.Context, marshaler runtime.Marshaler, client OdrpcClient, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) { 51 | var protoReq DetectRequest 52 | var metadata runtime.ServerMetadata 53 | 54 | newReader, berr := utilities.IOReaderFactory(req.Body) 55 | if berr != nil { 56 | return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", berr) 57 | } 58 | if err := marshaler.NewDecoder(newReader()).Decode(&protoReq); err != nil && err != io.EOF { 59 | return nil, metadata, 
status.Errorf(codes.InvalidArgument, "%v", err) 60 | } 61 | 62 | msg, err := client.Detect(ctx, &protoReq, grpc.Header(&metadata.HeaderMD), grpc.Trailer(&metadata.TrailerMD)) 63 | return msg, metadata, err 64 | 65 | } 66 | 67 | func local_request_Odrpc_Detect_0(ctx context.Context, marshaler runtime.Marshaler, server OdrpcServer, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) { 68 | var protoReq DetectRequest 69 | var metadata runtime.ServerMetadata 70 | 71 | newReader, berr := utilities.IOReaderFactory(req.Body) 72 | if berr != nil { 73 | return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", berr) 74 | } 75 | if err := marshaler.NewDecoder(newReader()).Decode(&protoReq); err != nil && err != io.EOF { 76 | return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err) 77 | } 78 | 79 | msg, err := server.Detect(ctx, &protoReq) 80 | return msg, metadata, err 81 | 82 | } 83 | 84 | func request_Odrpc_Detect_1(ctx context.Context, marshaler runtime.Marshaler, client OdrpcClient, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) { 85 | var protoReq DetectRequest 86 | var metadata runtime.ServerMetadata 87 | 88 | newReader, berr := utilities.IOReaderFactory(req.Body) 89 | if berr != nil { 90 | return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", berr) 91 | } 92 | if err := marshaler.NewDecoder(newReader()).Decode(&protoReq); err != nil && err != io.EOF { 93 | return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err) 94 | } 95 | 96 | var ( 97 | val string 98 | ok bool 99 | err error 100 | _ = err 101 | ) 102 | 103 | val, ok = pathParams["detector_name"] 104 | if !ok { 105 | return nil, metadata, status.Errorf(codes.InvalidArgument, "missing parameter %s", "detector_name") 106 | } 107 | 108 | protoReq.DetectorName, err = runtime.String(val) 109 | 110 | if err != nil { 111 | return nil, metadata, status.Errorf(codes.InvalidArgument, "type 
mismatch, parameter: %s, error: %v", "detector_name", err) 112 | } 113 | 114 | msg, err := client.Detect(ctx, &protoReq, grpc.Header(&metadata.HeaderMD), grpc.Trailer(&metadata.TrailerMD)) 115 | return msg, metadata, err 116 | 117 | } 118 | 119 | func local_request_Odrpc_Detect_1(ctx context.Context, marshaler runtime.Marshaler, server OdrpcServer, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) { 120 | var protoReq DetectRequest 121 | var metadata runtime.ServerMetadata 122 | 123 | newReader, berr := utilities.IOReaderFactory(req.Body) 124 | if berr != nil { 125 | return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", berr) 126 | } 127 | if err := marshaler.NewDecoder(newReader()).Decode(&protoReq); err != nil && err != io.EOF { 128 | return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err) 129 | } 130 | 131 | var ( 132 | val string 133 | ok bool 134 | err error 135 | _ = err 136 | ) 137 | 138 | val, ok = pathParams["detector_name"] 139 | if !ok { 140 | return nil, metadata, status.Errorf(codes.InvalidArgument, "missing parameter %s", "detector_name") 141 | } 142 | 143 | protoReq.DetectorName, err = runtime.String(val) 144 | 145 | if err != nil { 146 | return nil, metadata, status.Errorf(codes.InvalidArgument, "type mismatch, parameter: %s, error: %v", "detector_name", err) 147 | } 148 | 149 | msg, err := server.Detect(ctx, &protoReq) 150 | return msg, metadata, err 151 | 152 | } 153 | 154 | // RegisterOdrpcHandlerServer registers the http handlers for service Odrpc to "mux". 155 | // UnaryRPC :call OdrpcServer directly. 156 | // StreamingRPC :currently unsupported pending https://github.com/grpc/grpc-go/issues/906. 
157 | func RegisterOdrpcHandlerServer(ctx context.Context, mux *runtime.ServeMux, server OdrpcServer) error { 158 | 159 | mux.Handle("GET", pattern_Odrpc_GetDetectors_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { 160 | ctx, cancel := context.WithCancel(req.Context()) 161 | defer cancel() 162 | inboundMarshaler, outboundMarshaler := runtime.MarshalerForRequest(mux, req) 163 | rctx, err := runtime.AnnotateIncomingContext(ctx, mux, req) 164 | if err != nil { 165 | runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) 166 | return 167 | } 168 | resp, md, err := local_request_Odrpc_GetDetectors_0(rctx, inboundMarshaler, server, req, pathParams) 169 | ctx = runtime.NewServerMetadataContext(ctx, md) 170 | if err != nil { 171 | runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) 172 | return 173 | } 174 | 175 | forward_Odrpc_GetDetectors_0(ctx, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...) 176 | 177 | }) 178 | 179 | mux.Handle("POST", pattern_Odrpc_Detect_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { 180 | ctx, cancel := context.WithCancel(req.Context()) 181 | defer cancel() 182 | inboundMarshaler, outboundMarshaler := runtime.MarshalerForRequest(mux, req) 183 | rctx, err := runtime.AnnotateIncomingContext(ctx, mux, req) 184 | if err != nil { 185 | runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) 186 | return 187 | } 188 | resp, md, err := local_request_Odrpc_Detect_0(rctx, inboundMarshaler, server, req, pathParams) 189 | ctx = runtime.NewServerMetadataContext(ctx, md) 190 | if err != nil { 191 | runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) 192 | return 193 | } 194 | 195 | forward_Odrpc_Detect_0(ctx, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...) 
// RegisterOdrpcHandlerFromEndpoint is same as RegisterOdrpcHandler but
// automatically dials to "endpoint" and closes the connection when "ctx" gets done.
// (Code generated by protoc-gen-grpc-gateway — comments added in review only.)
func RegisterOdrpcHandlerFromEndpoint(ctx context.Context, mux *runtime.ServeMux, endpoint string, opts []grpc.DialOption) (err error) {
	conn, err := grpc.Dial(endpoint, opts...)
	if err != nil {
		return err
	}
	defer func() {
		// If handler registration below failed, close the freshly-dialed
		// connection immediately; the named return `err` is inspected here.
		if err != nil {
			if cerr := conn.Close(); cerr != nil {
				grpclog.Infof("Failed to close conn to %s: %v", endpoint, cerr)
			}
			return
		}
		// On success, keep the connection open for the handlers and close
		// it only once the caller's context is cancelled.
		go func() {
			<-ctx.Done()
			if cerr := conn.Close(); cerr != nil {
				grpclog.Infof("Failed to close conn to %s: %v", endpoint, cerr)
			}
		}()
	}()

	return RegisterOdrpcHandler(ctx, mux, conn)
}
249 | func RegisterOdrpcHandler(ctx context.Context, mux *runtime.ServeMux, conn *grpc.ClientConn) error { 250 | return RegisterOdrpcHandlerClient(ctx, mux, NewOdrpcClient(conn)) 251 | } 252 | 253 | // RegisterOdrpcHandlerClient registers the http handlers for service Odrpc 254 | // to "mux". The handlers forward requests to the grpc endpoint over the given implementation of "OdrpcClient". 255 | // Note: the gRPC framework executes interceptors within the gRPC handler. If the passed in "OdrpcClient" 256 | // doesn't go through the normal gRPC flow (creating a gRPC client etc.) then it will be up to the passed in 257 | // "OdrpcClient" to call the correct interceptors. 258 | func RegisterOdrpcHandlerClient(ctx context.Context, mux *runtime.ServeMux, client OdrpcClient) error { 259 | 260 | mux.Handle("GET", pattern_Odrpc_GetDetectors_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { 261 | ctx, cancel := context.WithCancel(req.Context()) 262 | defer cancel() 263 | inboundMarshaler, outboundMarshaler := runtime.MarshalerForRequest(mux, req) 264 | rctx, err := runtime.AnnotateContext(ctx, mux, req) 265 | if err != nil { 266 | runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) 267 | return 268 | } 269 | resp, md, err := request_Odrpc_GetDetectors_0(rctx, inboundMarshaler, client, req, pathParams) 270 | ctx = runtime.NewServerMetadataContext(ctx, md) 271 | if err != nil { 272 | runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) 273 | return 274 | } 275 | 276 | forward_Odrpc_GetDetectors_0(ctx, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...) 
277 | 278 | }) 279 | 280 | mux.Handle("POST", pattern_Odrpc_Detect_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { 281 | ctx, cancel := context.WithCancel(req.Context()) 282 | defer cancel() 283 | inboundMarshaler, outboundMarshaler := runtime.MarshalerForRequest(mux, req) 284 | rctx, err := runtime.AnnotateContext(ctx, mux, req) 285 | if err != nil { 286 | runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) 287 | return 288 | } 289 | resp, md, err := request_Odrpc_Detect_0(rctx, inboundMarshaler, client, req, pathParams) 290 | ctx = runtime.NewServerMetadataContext(ctx, md) 291 | if err != nil { 292 | runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) 293 | return 294 | } 295 | 296 | forward_Odrpc_Detect_0(ctx, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...) 297 | 298 | }) 299 | 300 | mux.Handle("POST", pattern_Odrpc_Detect_1, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { 301 | ctx, cancel := context.WithCancel(req.Context()) 302 | defer cancel() 303 | inboundMarshaler, outboundMarshaler := runtime.MarshalerForRequest(mux, req) 304 | rctx, err := runtime.AnnotateContext(ctx, mux, req) 305 | if err != nil { 306 | runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) 307 | return 308 | } 309 | resp, md, err := request_Odrpc_Detect_1(rctx, inboundMarshaler, client, req, pathParams) 310 | ctx = runtime.NewServerMetadataContext(ctx, md) 311 | if err != nil { 312 | runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) 313 | return 314 | } 315 | 316 | forward_Odrpc_Detect_1(ctx, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...) 
317 | 318 | }) 319 | 320 | return nil 321 | } 322 | 323 | var ( 324 | pattern_Odrpc_GetDetectors_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0}, []string{"detectors"}, "", runtime.AssumeColonVerbOpt(true))) 325 | 326 | pattern_Odrpc_Detect_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0}, []string{"detect"}, "", runtime.AssumeColonVerbOpt(true))) 327 | 328 | pattern_Odrpc_Detect_1 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 1, 0, 4, 1, 5, 1}, []string{"detect", "detector_name"}, "", runtime.AssumeColonVerbOpt(true))) 329 | ) 330 | 331 | var ( 332 | forward_Odrpc_GetDetectors_0 = runtime.ForwardResponseMessage 333 | 334 | forward_Odrpc_Detect_0 = runtime.ForwardResponseMessage 335 | 336 | forward_Odrpc_Detect_1 = runtime.ForwardResponseMessage 337 | ) 338 | -------------------------------------------------------------------------------- /odrpc/rpc.proto: -------------------------------------------------------------------------------- 1 | syntax="proto3"; 2 | package odrpc; 3 | 4 | import "google/api/annotations.proto"; 5 | import "google/protobuf/empty.proto"; 6 | import "github.com/gogo/protobuf/gogoproto/gogo.proto"; 7 | 8 | option go_package = "github.com/snowzach/doods/odrpc"; 9 | 10 | service odrpc { 11 | 12 | // Get Config 13 | rpc GetDetectors(google.protobuf.Empty) returns (GetDetectorsResponse) { 14 | option (google.api.http) = { 15 | get: "/detectors" 16 | }; 17 | } 18 | 19 | // Process an request 20 | rpc Detect(DetectRequest) returns (DetectResponse) { 21 | option (google.api.http) = { 22 | post: "/detect" 23 | body: "*" 24 | additional_bindings: { 25 | post: "/detect/{detector_name}" 26 | body: "*" 27 | } 28 | }; 29 | } 30 | 31 | // Process stream requests 32 | rpc DetectStream(stream DetectRequest) returns (stream DetectResponse){ 33 | } 34 | 35 | } 36 | 37 | message GetDetectorsResponse { 38 | repeated Detector detectors = 1; 39 | } 40 | 41 | message Detector { 42 | // The name for this config 43 | string name = 1; 44 | 
// The name for this config 45 | string type = 2; 46 | // Model Name 47 | string model = 3; 48 | // Labels 49 | repeated string labels = 4; 50 | // The detection width 51 | int32 width = 5; 52 | // The detection height 53 | int32 height = 6; 54 | // The detection channels 55 | int32 channels = 7; 56 | } 57 | 58 | // The Process Request 59 | message DetectRequest { 60 | // The ID for the request. 61 | string id = 1; 62 | // The ID for the request. 63 | string detector_name = 2; 64 | // The image data 65 | bytes data = 3 [(gogoproto.casttype) = "Raw",(gogoproto.jsontag) = "data"]; 66 | // A filename 67 | string file = 4; 68 | // What to detect 69 | map detect = 5; 70 | // Sub regions for detection 71 | repeated DetectRegion regions = 6; 72 | } 73 | 74 | message DetectRegion { 75 | // Coordinates 76 | float top = 1 [(gogoproto.jsontag) = "top"]; 77 | float left = 2 [(gogoproto.jsontag) = "left"]; 78 | float bottom = 3 [(gogoproto.jsontag) = "bottom"]; 79 | float right = 4 [(gogoproto.jsontag) = "right"]; 80 | // What to detect 81 | map detect = 5; 82 | bool covers = 6; 83 | } 84 | 85 | // Area for detection 86 | message Detection { 87 | // Coordinates 88 | float top = 1 [(gogoproto.jsontag) = "top"]; 89 | float left = 2 [(gogoproto.jsontag) = "left"]; 90 | float bottom = 3 [(gogoproto.jsontag) = "bottom"]; 91 | float right = 4 [(gogoproto.jsontag) = "right"]; 92 | string label = 5 [(gogoproto.jsontag) = "label"]; 93 | float confidence = 6 [(gogoproto.jsontag) = "confidence"]; 94 | } 95 | 96 | message DetectResponse { 97 | // The id for the response 98 | string id = 1; 99 | // The detected areas 100 | repeated Detection detections = 2; 101 | // If there was an error (streaming endpoint only) 102 | string error = 3; 103 | } 104 | -------------------------------------------------------------------------------- /odrpc/rpc.swagger.json: -------------------------------------------------------------------------------- 1 | { 2 | "swagger": "2.0", 3 | "info": { 4 | "title": 
"odrpc/rpc.proto", 5 | "version": "version not set" 6 | }, 7 | "schemes": [ 8 | "http", 9 | "https" 10 | ], 11 | "consumes": [ 12 | "application/json" 13 | ], 14 | "produces": [ 15 | "application/json" 16 | ], 17 | "paths": { 18 | "/detect": { 19 | "post": { 20 | "summary": "Process an request", 21 | "operationId": "Detect", 22 | "responses": { 23 | "200": { 24 | "description": "A successful response.", 25 | "schema": { 26 | "$ref": "#/definitions/odrpcDetectResponse" 27 | } 28 | } 29 | }, 30 | "parameters": [ 31 | { 32 | "name": "body", 33 | "in": "body", 34 | "required": true, 35 | "schema": { 36 | "$ref": "#/definitions/odrpcDetectRequest" 37 | } 38 | } 39 | ], 40 | "tags": [ 41 | "odrpc" 42 | ] 43 | } 44 | }, 45 | "/detect/{detector_name}": { 46 | "post": { 47 | "summary": "Process an request", 48 | "operationId": "Detect2", 49 | "responses": { 50 | "200": { 51 | "description": "A successful response.", 52 | "schema": { 53 | "$ref": "#/definitions/odrpcDetectResponse" 54 | } 55 | } 56 | }, 57 | "parameters": [ 58 | { 59 | "name": "detector_name", 60 | "description": "The ID for the request.", 61 | "in": "path", 62 | "required": true, 63 | "type": "string" 64 | }, 65 | { 66 | "name": "body", 67 | "in": "body", 68 | "required": true, 69 | "schema": { 70 | "$ref": "#/definitions/odrpcDetectRequest" 71 | } 72 | } 73 | ], 74 | "tags": [ 75 | "odrpc" 76 | ] 77 | } 78 | }, 79 | "/detectors": { 80 | "get": { 81 | "summary": "Get Config", 82 | "operationId": "GetDetectors", 83 | "responses": { 84 | "200": { 85 | "description": "A successful response.", 86 | "schema": { 87 | "$ref": "#/definitions/odrpcGetDetectorsResponse" 88 | } 89 | } 90 | }, 91 | "tags": [ 92 | "odrpc" 93 | ] 94 | } 95 | } 96 | }, 97 | "definitions": { 98 | "odrpcDetectRegion": { 99 | "type": "object", 100 | "properties": { 101 | "top": { 102 | "type": "number", 103 | "format": "float", 104 | "title": "Coordinates" 105 | }, 106 | "left": { 107 | "type": "number", 108 | "format": "float" 109 | }, 110 
| "bottom": { 111 | "type": "number", 112 | "format": "float" 113 | }, 114 | "right": { 115 | "type": "number", 116 | "format": "float" 117 | }, 118 | "detect": { 119 | "type": "object", 120 | "additionalProperties": { 121 | "type": "number", 122 | "format": "float" 123 | }, 124 | "title": "What to detect" 125 | }, 126 | "covers": { 127 | "type": "boolean", 128 | "format": "boolean" 129 | } 130 | } 131 | }, 132 | "odrpcDetectRequest": { 133 | "type": "object", 134 | "properties": { 135 | "id": { 136 | "type": "string", 137 | "description": "The ID for the request." 138 | }, 139 | "detector_name": { 140 | "type": "string", 141 | "description": "The ID for the request." 142 | }, 143 | "data": { 144 | "type": "string", 145 | "format": "byte", 146 | "title": "The image data" 147 | }, 148 | "file": { 149 | "type": "string", 150 | "title": "A filename" 151 | }, 152 | "detect": { 153 | "type": "object", 154 | "additionalProperties": { 155 | "type": "number", 156 | "format": "float" 157 | }, 158 | "title": "What to detect" 159 | }, 160 | "regions": { 161 | "type": "array", 162 | "items": { 163 | "$ref": "#/definitions/odrpcDetectRegion" 164 | }, 165 | "title": "Sub regions for detection" 166 | } 167 | }, 168 | "title": "The Process Request" 169 | }, 170 | "odrpcDetectResponse": { 171 | "type": "object", 172 | "properties": { 173 | "id": { 174 | "type": "string", 175 | "title": "The id for the response" 176 | }, 177 | "detections": { 178 | "type": "array", 179 | "items": { 180 | "$ref": "#/definitions/odrpcDetection" 181 | }, 182 | "title": "The detected areas" 183 | }, 184 | "error": { 185 | "type": "string", 186 | "title": "If there was an error (streaming endpoint only)" 187 | } 188 | } 189 | }, 190 | "odrpcDetection": { 191 | "type": "object", 192 | "properties": { 193 | "top": { 194 | "type": "number", 195 | "format": "float", 196 | "title": "Coordinates" 197 | }, 198 | "left": { 199 | "type": "number", 200 | "format": "float" 201 | }, 202 | "bottom": { 203 | "type": 
"number", 204 | "format": "float" 205 | }, 206 | "right": { 207 | "type": "number", 208 | "format": "float" 209 | }, 210 | "label": { 211 | "type": "string" 212 | }, 213 | "confidence": { 214 | "type": "number", 215 | "format": "float" 216 | } 217 | }, 218 | "title": "Area for detection" 219 | }, 220 | "odrpcDetector": { 221 | "type": "object", 222 | "properties": { 223 | "name": { 224 | "type": "string", 225 | "title": "The name for this config" 226 | }, 227 | "type": { 228 | "type": "string", 229 | "title": "The name for this config" 230 | }, 231 | "model": { 232 | "type": "string", 233 | "title": "Model Name" 234 | }, 235 | "labels": { 236 | "type": "array", 237 | "items": { 238 | "type": "string" 239 | }, 240 | "title": "Labels" 241 | }, 242 | "width": { 243 | "type": "integer", 244 | "format": "int32", 245 | "title": "The detection width" 246 | }, 247 | "height": { 248 | "type": "integer", 249 | "format": "int32", 250 | "title": "The detection height" 251 | }, 252 | "channels": { 253 | "type": "integer", 254 | "format": "int32", 255 | "title": "The detection channels" 256 | } 257 | } 258 | }, 259 | "odrpcGetDetectorsResponse": { 260 | "type": "object", 261 | "properties": { 262 | "detectors": { 263 | "type": "array", 264 | "items": { 265 | "$ref": "#/definitions/odrpcDetector" 266 | } 267 | } 268 | } 269 | }, 270 | "protobufAny": { 271 | "type": "object", 272 | "properties": { 273 | "type_url": { 274 | "type": "string" 275 | }, 276 | "value": { 277 | "type": "string", 278 | "format": "byte" 279 | } 280 | } 281 | }, 282 | "runtimeStreamError": { 283 | "type": "object", 284 | "properties": { 285 | "grpc_code": { 286 | "type": "integer", 287 | "format": "int32" 288 | }, 289 | "http_code": { 290 | "type": "integer", 291 | "format": "int32" 292 | }, 293 | "message": { 294 | "type": "string" 295 | }, 296 | "http_status": { 297 | "type": "string" 298 | }, 299 | "details": { 300 | "type": "array", 301 | "items": { 302 | "$ref": "#/definitions/protobufAny" 303 | } 304 
| } 305 | } 306 | } 307 | }, 308 | "x-stream-definitions": { 309 | "odrpcDetectResponse": { 310 | "type": "object", 311 | "properties": { 312 | "result": { 313 | "$ref": "#/definitions/odrpcDetectResponse" 314 | }, 315 | "error": { 316 | "$ref": "#/definitions/runtimeStreamError" 317 | } 318 | }, 319 | "title": "Stream result of odrpcDetectResponse" 320 | } 321 | } 322 | } 323 | -------------------------------------------------------------------------------- /server/error.go: -------------------------------------------------------------------------------- 1 | package server 2 | 3 | import ( 4 | "net/http" 5 | 6 | "github.com/go-chi/render" 7 | ) 8 | 9 | // ErrResponse is a generic struct for returning a standard error document 10 | type ErrResponse struct { 11 | Err error `json:"-"` // low-level runtime error 12 | HTTPStatusCode int `json:"-"` // http response status code 13 | 14 | StatusText string `json:"status"` // user-level status message 15 | AppCode int64 `json:"code,omitempty"` // application-specific error code 16 | ErrorText string `json:"error,omitempty"` // application-level error message, for debugging 17 | } 18 | 19 | // ErrNotFound is a pre-built not-found error 20 | var ErrNotFound = &ErrResponse{HTTPStatusCode: 404, StatusText: "Resource not found."} 21 | 22 | // Render is the Renderer for ErrResponse struct 23 | func (e *ErrResponse) Render(w http.ResponseWriter, r *http.Request) error { 24 | render.Status(r, e.HTTPStatusCode) 25 | return nil 26 | } 27 | 28 | // ErrInvalidRequest is used to indicate an error on user input (with wrapped error) 29 | func ErrInvalidRequest(err error) render.Renderer { 30 | var errorText string 31 | if err != nil { 32 | errorText = err.Error() 33 | } 34 | return &ErrResponse{ 35 | Err: err, 36 | HTTPStatusCode: http.StatusBadRequest, 37 | StatusText: "Invalid request.", 38 | ErrorText: errorText, 39 | } 40 | } 41 | 42 | // ErrInternalLog will log an error and return a generic server error to the user 43 | func (s 
// JSONMarshaler is a gRPC-gateway Marshaler that delegates every operation to
// the standard library encoding/json package, bypassing the default jsonpb
// marshaling (so field names and encodings follow encoding/json semantics,
// including custom MarshalJSON/UnmarshalJSON implementations such as odrpc.Raw).
type JSONMarshaler struct{}

// Marshal serializes v with encoding/json.
func (jm *JSONMarshaler) Marshal(v interface{}) ([]byte, error) {
	return json.Marshal(v)
}

// Unmarshal deserializes data into v with encoding/json.
func (jm *JSONMarshaler) Unmarshal(data []byte, v interface{}) error {
	return json.Unmarshal(data, v)
}

// NewDecoder returns a streaming decoder reading from r.
func (jm *JSONMarshaler) NewDecoder(r io.Reader) gwruntime.Decoder {
	return json.NewDecoder(r)
}

// NewEncoder returns a streaming encoder writing to w.
func (jm *JSONMarshaler) NewEncoder(w io.Writer) gwruntime.Encoder {
	return json.NewEncoder(w)
}

// ContentType reports the MIME type the gateway should set on responses.
func (jm *JSONMarshaler) ContentType() string {
	return "application/json"
}
{}) 14 | 15 | // Register RPC Services 16 | rpc.RegisterVersionRPCServer(s.grpcServer, s) 17 | s.GWReg(rpc.RegisterVersionRPCHandlerFromEndpoint) 18 | 19 | } 20 | -------------------------------------------------------------------------------- /server/rpc/rpc.go: -------------------------------------------------------------------------------- 1 | package rpc 2 | 3 | // Empty file to prevent issues 4 | -------------------------------------------------------------------------------- /server/rpc/version.pb.go: -------------------------------------------------------------------------------- 1 | // Code generated by protoc-gen-gogo. DO NOT EDIT. 2 | // source: server/rpc/version.proto 3 | 4 | package rpc 5 | 6 | import ( 7 | context "context" 8 | fmt "fmt" 9 | proto "github.com/gogo/protobuf/proto" 10 | empty "github.com/golang/protobuf/ptypes/empty" 11 | _ "google.golang.org/genproto/googleapis/api/annotations" 12 | grpc "google.golang.org/grpc" 13 | io "io" 14 | math "math" 15 | reflect "reflect" 16 | strings "strings" 17 | ) 18 | 19 | // Reference imports to suppress errors if they are not otherwise used. 20 | var _ = proto.Marshal 21 | var _ = fmt.Errorf 22 | var _ = math.Inf 23 | 24 | // This is a compile-time assertion to ensure that this generated file 25 | // is compatible with the proto package it is being compiled against. 26 | // A compilation error at this line likely means your copy of the 27 | // proto package needs to be updated. 
28 | const _ = proto.GoGoProtoPackageIsVersion2 // please upgrade the proto package 29 | 30 | type VersionResponse struct { 31 | Version string `protobuf:"bytes,1,opt,name=version,proto3" json:"version,omitempty"` 32 | } 33 | 34 | func (m *VersionResponse) Reset() { *m = VersionResponse{} } 35 | func (*VersionResponse) ProtoMessage() {} 36 | func (*VersionResponse) Descriptor() ([]byte, []int) { 37 | return fileDescriptor_281c0eaee4e6b80d, []int{0} 38 | } 39 | func (m *VersionResponse) XXX_Unmarshal(b []byte) error { 40 | return m.Unmarshal(b) 41 | } 42 | func (m *VersionResponse) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { 43 | if deterministic { 44 | return xxx_messageInfo_VersionResponse.Marshal(b, m, deterministic) 45 | } else { 46 | b = b[:cap(b)] 47 | n, err := m.MarshalTo(b) 48 | if err != nil { 49 | return nil, err 50 | } 51 | return b[:n], nil 52 | } 53 | } 54 | func (m *VersionResponse) XXX_Merge(src proto.Message) { 55 | xxx_messageInfo_VersionResponse.Merge(m, src) 56 | } 57 | func (m *VersionResponse) XXX_Size() int { 58 | return m.Size() 59 | } 60 | func (m *VersionResponse) XXX_DiscardUnknown() { 61 | xxx_messageInfo_VersionResponse.DiscardUnknown(m) 62 | } 63 | 64 | var xxx_messageInfo_VersionResponse proto.InternalMessageInfo 65 | 66 | func (m *VersionResponse) GetVersion() string { 67 | if m != nil { 68 | return m.Version 69 | } 70 | return "" 71 | } 72 | 73 | func init() { 74 | proto.RegisterType((*VersionResponse)(nil), "server.VersionResponse") 75 | } 76 | 77 | func init() { proto.RegisterFile("server/rpc/version.proto", fileDescriptor_281c0eaee4e6b80d) } 78 | 79 | var fileDescriptor_281c0eaee4e6b80d = []byte{ 80 | // 266 bytes of a gzipped FileDescriptorProto 81 | 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xe2, 0x92, 0x28, 0x4e, 0x2d, 0x2a, 82 | 0x4b, 0x2d, 0xd2, 0x2f, 0x2a, 0x48, 0xd6, 0x2f, 0x4b, 0x2d, 0x2a, 0xce, 0xcc, 0xcf, 0xd3, 0x2b, 83 | 0x28, 0xca, 0x2f, 0xc9, 0x17, 0x62, 0x83, 0xc8, 0x48, 0xc9, 0xa4, 
0xe7, 0xe7, 0xa7, 0xe7, 0xa4, 84 | 0xea, 0x27, 0x16, 0x64, 0xea, 0x27, 0xe6, 0xe5, 0xe5, 0x97, 0x24, 0x96, 0x64, 0xe6, 0xe7, 0x15, 85 | 0x43, 0x54, 0x49, 0x49, 0x43, 0x65, 0xc1, 0xbc, 0xa4, 0xd2, 0x34, 0xfd, 0xd4, 0xdc, 0x82, 0x92, 86 | 0x4a, 0x88, 0xa4, 0x92, 0x36, 0x17, 0x7f, 0x18, 0xc4, 0xcc, 0xa0, 0xd4, 0xe2, 0x82, 0xfc, 0xbc, 87 | 0xe2, 0x54, 0x21, 0x09, 0x2e, 0x76, 0xa8, 0x35, 0x12, 0x8c, 0x0a, 0x8c, 0x1a, 0x9c, 0x41, 0x30, 88 | 0xae, 0x51, 0x14, 0x17, 0x17, 0x4c, 0x71, 0x80, 0xb3, 0x90, 0x0f, 0x17, 0x3b, 0x94, 0x27, 0x24, 89 | 0xa6, 0x07, 0xb1, 0x43, 0x0f, 0x66, 0x87, 0x9e, 0x2b, 0xc8, 0x0e, 0x29, 0x71, 0x3d, 0x88, 0x0b, 90 | 0xf5, 0xd0, 0xec, 0x50, 0x12, 0x68, 0xba, 0xfc, 0x64, 0x32, 0x13, 0x97, 0x10, 0x07, 0xcc, 0x47, 91 | 0x4e, 0x51, 0x17, 0x1e, 0xca, 0x31, 0xdc, 0x78, 0x28, 0xc7, 0xf0, 0xe1, 0xa1, 0x1c, 0x63, 0xc3, 92 | 0x23, 0x39, 0xc6, 0x15, 0x8f, 0xe4, 0x18, 0x4f, 0x3c, 0x92, 0x63, 0xbc, 0xf0, 0x48, 0x8e, 0xf1, 93 | 0xc1, 0x23, 0x39, 0xc6, 0x17, 0x8f, 0xe4, 0x18, 0x3e, 0x3c, 0x92, 0x63, 0x9c, 0xf0, 0x58, 0x8e, 94 | 0xe1, 0xc2, 0x63, 0x39, 0x86, 0x1b, 0x8f, 0xe5, 0x18, 0xa2, 0x54, 0xd2, 0x33, 0x4b, 0x32, 0x4a, 95 | 0x93, 0xf4, 0x92, 0xf3, 0x73, 0xf5, 0x8b, 0xf3, 0xf2, 0xcb, 0xab, 0x12, 0x93, 0x33, 0xf4, 0x53, 96 | 0xf2, 0xf3, 0x53, 0x8a, 0xf5, 0x11, 0xa1, 0x96, 0xc4, 0x06, 0x76, 0x96, 0x31, 0x20, 0x00, 0x00, 97 | 0xff, 0xff, 0x6f, 0x71, 0xae, 0xb3, 0x4a, 0x01, 0x00, 0x00, 98 | } 99 | 100 | func (this *VersionResponse) Equal(that interface{}) bool { 101 | if that == nil { 102 | return this == nil 103 | } 104 | 105 | that1, ok := that.(*VersionResponse) 106 | if !ok { 107 | that2, ok := that.(VersionResponse) 108 | if ok { 109 | that1 = &that2 110 | } else { 111 | return false 112 | } 113 | } 114 | if that1 == nil { 115 | return this == nil 116 | } else if this == nil { 117 | return false 118 | } 119 | if this.Version != that1.Version { 120 | return false 121 | } 122 | return true 123 | } 124 | func (this *VersionResponse) GoString() string { 125 | if 
this == nil { 126 | return "nil" 127 | } 128 | s := make([]string, 0, 5) 129 | s = append(s, "&rpc.VersionResponse{") 130 | s = append(s, "Version: "+fmt.Sprintf("%#v", this.Version)+",\n") 131 | s = append(s, "}") 132 | return strings.Join(s, "") 133 | } 134 | func valueToGoStringVersion(v interface{}, typ string) string { 135 | rv := reflect.ValueOf(v) 136 | if rv.IsNil() { 137 | return "nil" 138 | } 139 | pv := reflect.Indirect(rv).Interface() 140 | return fmt.Sprintf("func(v %v) *%v { return &v } ( %#v )", typ, typ, pv) 141 | } 142 | 143 | // Reference imports to suppress errors if they are not otherwise used. 144 | var _ context.Context 145 | var _ grpc.ClientConn 146 | 147 | // This is a compile-time assertion to ensure that this generated file 148 | // is compatible with the grpc package it is being compiled against. 149 | const _ = grpc.SupportPackageIsVersion4 150 | 151 | // VersionRPCClient is the client API for VersionRPC service. 152 | // 153 | // For semantics around ctx use and closing/ending streaming RPCs, please refer to https://godoc.org/google.golang.org/grpc#ClientConn.NewStream. 154 | type VersionRPCClient interface { 155 | Version(ctx context.Context, in *empty.Empty, opts ...grpc.CallOption) (*VersionResponse, error) 156 | } 157 | 158 | type versionRPCClient struct { 159 | cc *grpc.ClientConn 160 | } 161 | 162 | func NewVersionRPCClient(cc *grpc.ClientConn) VersionRPCClient { 163 | return &versionRPCClient{cc} 164 | } 165 | 166 | func (c *versionRPCClient) Version(ctx context.Context, in *empty.Empty, opts ...grpc.CallOption) (*VersionResponse, error) { 167 | out := new(VersionResponse) 168 | err := c.cc.Invoke(ctx, "/server.VersionRPC/Version", in, out, opts...) 169 | if err != nil { 170 | return nil, err 171 | } 172 | return out, nil 173 | } 174 | 175 | // VersionRPCServer is the server API for VersionRPC service. 
176 | type VersionRPCServer interface { 177 | Version(context.Context, *empty.Empty) (*VersionResponse, error) 178 | } 179 | 180 | func RegisterVersionRPCServer(s *grpc.Server, srv VersionRPCServer) { 181 | s.RegisterService(&_VersionRPC_serviceDesc, srv) 182 | } 183 | 184 | func _VersionRPC_Version_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { 185 | in := new(empty.Empty) 186 | if err := dec(in); err != nil { 187 | return nil, err 188 | } 189 | if interceptor == nil { 190 | return srv.(VersionRPCServer).Version(ctx, in) 191 | } 192 | info := &grpc.UnaryServerInfo{ 193 | Server: srv, 194 | FullMethod: "/server.VersionRPC/Version", 195 | } 196 | handler := func(ctx context.Context, req interface{}) (interface{}, error) { 197 | return srv.(VersionRPCServer).Version(ctx, req.(*empty.Empty)) 198 | } 199 | return interceptor(ctx, in, info, handler) 200 | } 201 | 202 | var _VersionRPC_serviceDesc = grpc.ServiceDesc{ 203 | ServiceName: "server.VersionRPC", 204 | HandlerType: (*VersionRPCServer)(nil), 205 | Methods: []grpc.MethodDesc{ 206 | { 207 | MethodName: "Version", 208 | Handler: _VersionRPC_Version_Handler, 209 | }, 210 | }, 211 | Streams: []grpc.StreamDesc{}, 212 | Metadata: "server/rpc/version.proto", 213 | } 214 | 215 | func (m *VersionResponse) Marshal() (dAtA []byte, err error) { 216 | size := m.Size() 217 | dAtA = make([]byte, size) 218 | n, err := m.MarshalTo(dAtA) 219 | if err != nil { 220 | return nil, err 221 | } 222 | return dAtA[:n], nil 223 | } 224 | 225 | func (m *VersionResponse) MarshalTo(dAtA []byte) (int, error) { 226 | var i int 227 | _ = i 228 | var l int 229 | _ = l 230 | if len(m.Version) > 0 { 231 | dAtA[i] = 0xa 232 | i++ 233 | i = encodeVarintVersion(dAtA, i, uint64(len(m.Version))) 234 | i += copy(dAtA[i:], m.Version) 235 | } 236 | return i, nil 237 | } 238 | 239 | func encodeVarintVersion(dAtA []byte, offset int, v uint64) int { 240 | for v >= 1<<7 
{ 241 | dAtA[offset] = uint8(v&0x7f | 0x80) 242 | v >>= 7 243 | offset++ 244 | } 245 | dAtA[offset] = uint8(v) 246 | return offset + 1 247 | } 248 | func (m *VersionResponse) Size() (n int) { 249 | if m == nil { 250 | return 0 251 | } 252 | var l int 253 | _ = l 254 | l = len(m.Version) 255 | if l > 0 { 256 | n += 1 + l + sovVersion(uint64(l)) 257 | } 258 | return n 259 | } 260 | 261 | func sovVersion(x uint64) (n int) { 262 | for { 263 | n++ 264 | x >>= 7 265 | if x == 0 { 266 | break 267 | } 268 | } 269 | return n 270 | } 271 | func sozVersion(x uint64) (n int) { 272 | return sovVersion(uint64((x << 1) ^ uint64((int64(x) >> 63)))) 273 | } 274 | func (this *VersionResponse) String() string { 275 | if this == nil { 276 | return "nil" 277 | } 278 | s := strings.Join([]string{`&VersionResponse{`, 279 | `Version:` + fmt.Sprintf("%v", this.Version) + `,`, 280 | `}`, 281 | }, "") 282 | return s 283 | } 284 | func valueToStringVersion(v interface{}) string { 285 | rv := reflect.ValueOf(v) 286 | if rv.IsNil() { 287 | return "nil" 288 | } 289 | pv := reflect.Indirect(rv).Interface() 290 | return fmt.Sprintf("*%v", pv) 291 | } 292 | func (m *VersionResponse) Unmarshal(dAtA []byte) error { 293 | l := len(dAtA) 294 | iNdEx := 0 295 | for iNdEx < l { 296 | preIndex := iNdEx 297 | var wire uint64 298 | for shift := uint(0); ; shift += 7 { 299 | if shift >= 64 { 300 | return ErrIntOverflowVersion 301 | } 302 | if iNdEx >= l { 303 | return io.ErrUnexpectedEOF 304 | } 305 | b := dAtA[iNdEx] 306 | iNdEx++ 307 | wire |= uint64(b&0x7F) << shift 308 | if b < 0x80 { 309 | break 310 | } 311 | } 312 | fieldNum := int32(wire >> 3) 313 | wireType := int(wire & 0x7) 314 | if wireType == 4 { 315 | return fmt.Errorf("proto: VersionResponse: wiretype end group for non-group") 316 | } 317 | if fieldNum <= 0 { 318 | return fmt.Errorf("proto: VersionResponse: illegal tag %d (wire type %d)", fieldNum, wire) 319 | } 320 | switch fieldNum { 321 | case 1: 322 | if wireType != 2 { 323 | return 
fmt.Errorf("proto: wrong wireType = %d for field Version", wireType) 324 | } 325 | var stringLen uint64 326 | for shift := uint(0); ; shift += 7 { 327 | if shift >= 64 { 328 | return ErrIntOverflowVersion 329 | } 330 | if iNdEx >= l { 331 | return io.ErrUnexpectedEOF 332 | } 333 | b := dAtA[iNdEx] 334 | iNdEx++ 335 | stringLen |= uint64(b&0x7F) << shift 336 | if b < 0x80 { 337 | break 338 | } 339 | } 340 | intStringLen := int(stringLen) 341 | if intStringLen < 0 { 342 | return ErrInvalidLengthVersion 343 | } 344 | postIndex := iNdEx + intStringLen 345 | if postIndex < 0 { 346 | return ErrInvalidLengthVersion 347 | } 348 | if postIndex > l { 349 | return io.ErrUnexpectedEOF 350 | } 351 | m.Version = string(dAtA[iNdEx:postIndex]) 352 | iNdEx = postIndex 353 | default: 354 | iNdEx = preIndex 355 | skippy, err := skipVersion(dAtA[iNdEx:]) 356 | if err != nil { 357 | return err 358 | } 359 | if skippy < 0 { 360 | return ErrInvalidLengthVersion 361 | } 362 | if (iNdEx + skippy) < 0 { 363 | return ErrInvalidLengthVersion 364 | } 365 | if (iNdEx + skippy) > l { 366 | return io.ErrUnexpectedEOF 367 | } 368 | iNdEx += skippy 369 | } 370 | } 371 | 372 | if iNdEx > l { 373 | return io.ErrUnexpectedEOF 374 | } 375 | return nil 376 | } 377 | func skipVersion(dAtA []byte) (n int, err error) { 378 | l := len(dAtA) 379 | iNdEx := 0 380 | for iNdEx < l { 381 | var wire uint64 382 | for shift := uint(0); ; shift += 7 { 383 | if shift >= 64 { 384 | return 0, ErrIntOverflowVersion 385 | } 386 | if iNdEx >= l { 387 | return 0, io.ErrUnexpectedEOF 388 | } 389 | b := dAtA[iNdEx] 390 | iNdEx++ 391 | wire |= (uint64(b) & 0x7F) << shift 392 | if b < 0x80 { 393 | break 394 | } 395 | } 396 | wireType := int(wire & 0x7) 397 | switch wireType { 398 | case 0: 399 | for shift := uint(0); ; shift += 7 { 400 | if shift >= 64 { 401 | return 0, ErrIntOverflowVersion 402 | } 403 | if iNdEx >= l { 404 | return 0, io.ErrUnexpectedEOF 405 | } 406 | iNdEx++ 407 | if dAtA[iNdEx-1] < 0x80 { 408 | break 409 | 
} 410 | } 411 | return iNdEx, nil 412 | case 1: 413 | iNdEx += 8 414 | return iNdEx, nil 415 | case 2: 416 | var length int 417 | for shift := uint(0); ; shift += 7 { 418 | if shift >= 64 { 419 | return 0, ErrIntOverflowVersion 420 | } 421 | if iNdEx >= l { 422 | return 0, io.ErrUnexpectedEOF 423 | } 424 | b := dAtA[iNdEx] 425 | iNdEx++ 426 | length |= (int(b) & 0x7F) << shift 427 | if b < 0x80 { 428 | break 429 | } 430 | } 431 | if length < 0 { 432 | return 0, ErrInvalidLengthVersion 433 | } 434 | iNdEx += length 435 | if iNdEx < 0 { 436 | return 0, ErrInvalidLengthVersion 437 | } 438 | return iNdEx, nil 439 | case 3: 440 | for { 441 | var innerWire uint64 442 | var start int = iNdEx 443 | for shift := uint(0); ; shift += 7 { 444 | if shift >= 64 { 445 | return 0, ErrIntOverflowVersion 446 | } 447 | if iNdEx >= l { 448 | return 0, io.ErrUnexpectedEOF 449 | } 450 | b := dAtA[iNdEx] 451 | iNdEx++ 452 | innerWire |= (uint64(b) & 0x7F) << shift 453 | if b < 0x80 { 454 | break 455 | } 456 | } 457 | innerWireType := int(innerWire & 0x7) 458 | if innerWireType == 4 { 459 | break 460 | } 461 | next, err := skipVersion(dAtA[start:]) 462 | if err != nil { 463 | return 0, err 464 | } 465 | iNdEx = start + next 466 | if iNdEx < 0 { 467 | return 0, ErrInvalidLengthVersion 468 | } 469 | } 470 | return iNdEx, nil 471 | case 4: 472 | return iNdEx, nil 473 | case 5: 474 | iNdEx += 4 475 | return iNdEx, nil 476 | default: 477 | return 0, fmt.Errorf("proto: illegal wireType %d", wireType) 478 | } 479 | } 480 | panic("unreachable") 481 | } 482 | 483 | var ( 484 | ErrInvalidLengthVersion = fmt.Errorf("proto: negative length found during unmarshaling") 485 | ErrIntOverflowVersion = fmt.Errorf("proto: integer overflow") 486 | ) 487 | -------------------------------------------------------------------------------- /server/rpc/version.pb.gw.go: -------------------------------------------------------------------------------- 1 | // Code generated by protoc-gen-grpc-gateway. DO NOT EDIT. 
// source: server/rpc/version.proto

/*
Package rpc is a reverse proxy.

It translates gRPC into RESTful JSON APIs.
*/
package rpc

import (
	"io"
	"net/http"

	"github.com/golang/protobuf/proto"
	"github.com/golang/protobuf/ptypes/empty"
	"github.com/grpc-ecosystem/grpc-gateway/runtime"
	"github.com/grpc-ecosystem/grpc-gateway/utilities"
	"golang.org/x/net/context"
	"google.golang.org/grpc"
	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/grpclog"
	"google.golang.org/grpc/status"
)

// Blank references keep the imports used even when the generated code below
// does not otherwise touch them.
var _ codes.Code
var _ io.Reader
var _ status.Status
var _ = runtime.String
var _ = utilities.NewDoubleArray

// request_VersionRPC_Version_0 calls VersionRPC.Version over gRPC on behalf
// of the HTTP gateway, capturing response header/trailer metadata so it can
// be forwarded to the HTTP client.
func request_VersionRPC_Version_0(ctx context.Context, marshaler runtime.Marshaler, client VersionRPCClient, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) {
	var protoReq empty.Empty
	var metadata runtime.ServerMetadata

	msg, err := client.Version(ctx, &protoReq, grpc.Header(&metadata.HeaderMD), grpc.Trailer(&metadata.TrailerMD))
	return msg, metadata, err

}

// RegisterVersionRPCHandlerFromEndpoint is same as RegisterVersionRPCHandler but
// automatically dials to "endpoint" and closes the connection when "ctx" gets done.
func RegisterVersionRPCHandlerFromEndpoint(ctx context.Context, mux *runtime.ServeMux, endpoint string, opts []grpc.DialOption) (err error) {
	conn, err := grpc.Dial(endpoint, opts...)
	if err != nil {
		return err
	}
	defer func() {
		// If registration below failed, close the connection immediately;
		// otherwise close it in the background once ctx is done.
		if err != nil {
			if cerr := conn.Close(); cerr != nil {
				grpclog.Infof("Failed to close conn to %s: %v", endpoint, cerr)
			}
			return
		}
		go func() {
			<-ctx.Done()
			if cerr := conn.Close(); cerr != nil {
				grpclog.Infof("Failed to close conn to %s: %v", endpoint, cerr)
			}
		}()
	}()

	return RegisterVersionRPCHandler(ctx, mux, conn)
}

// RegisterVersionRPCHandler registers the http handlers for service VersionRPC to "mux".
// The handlers forward requests to the grpc endpoint over "conn".
func RegisterVersionRPCHandler(ctx context.Context, mux *runtime.ServeMux, conn *grpc.ClientConn) error {
	return RegisterVersionRPCHandlerClient(ctx, mux, NewVersionRPCClient(conn))
}

// RegisterVersionRPCHandlerClient registers the http handlers for service VersionRPC
// to "mux". The handlers forward requests to the grpc endpoint over the given implementation of "VersionRPCClient".
// Note: the gRPC framework executes interceptors within the gRPC handler. If the passed in "VersionRPCClient"
// doesn't go through the normal gRPC flow (creating a gRPC client etc.) then it will be up to the passed in
// "VersionRPCClient" to call the correct interceptors.
func RegisterVersionRPCHandlerClient(ctx context.Context, mux *runtime.ServeMux, client VersionRPCClient) error {

	mux.Handle("GET", pattern_VersionRPC_Version_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) {
		// Per-request context: cancelled when this handler returns.
		ctx, cancel := context.WithCancel(req.Context())
		defer cancel()
		// NOTE(review): http.CloseNotifier has been deprecated since Go 1.11;
		// the request context above already cancels on client disconnect.
		// This is generated code — fix by regenerating with a newer
		// protoc-gen-grpc-gateway rather than editing by hand.
		if cn, ok := w.(http.CloseNotifier); ok {
			go func(done <-chan struct{}, closed <-chan bool) {
				select {
				case <-done:
				case <-closed:
					cancel()
				}
			}(ctx.Done(), cn.CloseNotify())
		}
		inboundMarshaler, outboundMarshaler := runtime.MarshalerForRequest(mux, req)
		// Copy inbound HTTP request metadata into the outgoing gRPC context.
		rctx, err := runtime.AnnotateContext(ctx, mux, req)
		if err != nil {
			runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err)
			return
		}
		resp, md, err := request_VersionRPC_Version_0(rctx, inboundMarshaler, client, req, pathParams)
		ctx = runtime.NewServerMetadataContext(ctx, md)
		if err != nil {
			runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err)
			return
		}

		forward_VersionRPC_Version_0(ctx, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...)

	})

	return nil
}

var (
	// HTTP route for VersionRPC.Version: GET /version (from the
	// google.api.http annotation in server/rpc/version.proto).
	pattern_VersionRPC_Version_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0}, []string{"version"}, ""))
)

var (
	forward_VersionRPC_Version_0 = runtime.ForwardResponseMessage
)
--------------------------------------------------------------------------------
/server/rpc/version.proto:
--------------------------------------------------------------------------------
syntax="proto3";
package server;

import "google/api/annotations.proto";
import "google/protobuf/empty.proto";

option go_package = "github.com/snowzach/doods/server/rpc";

// VersionRPC reports the server build version.
service VersionRPC {

  // Version returns the server version; exposed over REST as GET /version
  // via grpc-gateway.
  rpc Version(google.protobuf.Empty) returns (VersionResponse) {
    option (google.api.http) = {
      get: "/version"
    };
  }

}

message VersionResponse {
  string version = 1;
}
--------------------------------------------------------------------------------
/server/rpc/version.swagger.json:
--------------------------------------------------------------------------------
{
  "swagger": "2.0",
  "info": {
    "title": "server/rpc/version.proto",
    "version": "version not set"
  },
  "schemes": [
    "http",
    "https"
  ],
  "consumes": [
    "application/json"
  ],
  "produces": [
    "application/json"
  ],
  "paths": {
    "/version": {
      "get": {
        "operationId": "Version",
        "responses": {
          "200": {
            "description": "",
            "schema": {
              "$ref": "#/definitions/serverVersionResponse"
            }
          }
        },
        "tags": [
          "VersionRPC"
        ]
      }
    }
  },
  "definitions": {
    "serverVersionResponse": {
      "type": "object",
      "properties": {
        "version": {
          "type": "string"
        }
      }
    }
  }
}
--------------------------------------------------------------------------------
/server/server.go:
--------------------------------------------------------------------------------
package server

import (
	"context"
	"crypto/tls"
	"crypto/x509"
	"fmt"
	"log"
	"net"
	"net/http"
	"strconv"
	"strings"
	"time"

	"golang.org/x/net/http2"
	"golang.org/x/net/http2/h2c"

	"github.com/blendle/zapdriver"
	"github.com/go-chi/chi"
	"github.com/go-chi/chi/middleware"
	"github.com/go-chi/cors"
	"github.com/go-chi/render"
	grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware"
	grpc_auth "github.com/grpc-ecosystem/go-grpc-middleware/auth"
	gwruntime "github.com/grpc-ecosystem/grpc-gateway/runtime"
	"github.com/snowzach/certtools"
	"github.com/snowzach/certtools/autocert"
	config "github.com/spf13/viper"
	"go.uber.org/zap"
	"go.uber.org/zap/zapcore"
	"google.golang.org/grpc"
	"google.golang.org/grpc/credentials"
	"google.golang.org/grpc/reflection"

	"github.com/snowzach/doods/conf"
	"github.com/snowzach/doods/odrpc"
)

// Server is the GRPC server
// It serves gRPC and plain HTTP (REST via grpc-gateway) on a single listener.
type Server struct {
	logger     *zap.SugaredLogger
	router     chi.Router   // chi router handling plain HTTP requests
	server     *http.Server // the combined HTTP/gRPC server
	grpcServer *grpc.Server
	gwRegFuncs []gwRegFunc // gateway registrations deferred until ListenAndServe
}

// When starting to listen, we will register gateway functions
type gwRegFunc func(ctx context.Context, mux *gwruntime.ServeMux, endpoint string, opts []grpc.DialOption) error

// This is the default authentication function, it requires no authentication
func authenticate(ctx context.Context) (context.Context, error) {
	return ctx, nil
}

// New will setup the server
func New() (*Server, error) {

	// This router is used for http requests only, setup all of our middleware
	r := chi.NewRouter()
	r.Use(middleware.RequestID)
	r.Use(middleware.Recoverer)
	r.Use(render.SetContentType(render.ContentTypeJSON))

	// Log Requests - Use appropriate format depending on the encoding
	if config.GetBool("server.log_requests") {
		switch config.GetString("logger.encoding") {
		case "stackdriver":
			// Structured Stackdriver (GCP) style request log.
			r.Use(func(next http.Handler) http.Handler {
				return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
					start := time.Now()
					var requestID string
					if reqID := r.Context().Value(middleware.RequestIDKey); reqID != nil {
						requestID = reqID.(string)
					}
					// Wrap the writer to capture status code and bytes written.
					ww := middleware.NewWrapResponseWriter(w, r.ProtoMajor)
					// Parse the request
					next.ServeHTTP(ww, r)
					// Don't log the version endpoint, it's too noisy
					if r.RequestURI == "/version" {
						return
					}
					// If the remote IP is being proxied, use the real IP
					remoteIP := r.Header.Get("X-Forwarded-For")
					if remoteIP == "" {
						remoteIP = r.RemoteAddr
					}
					zap.L().Info("HTTP Request", []zapcore.Field{
						zapdriver.HTTP(&zapdriver.HTTPPayload{
							RequestMethod: r.Method,
							RequestURL:    r.RequestURI,
							RequestSize:   strconv.FormatInt(r.ContentLength, 10),
							Status:        ww.Status(),
							ResponseSize:  strconv.Itoa(ww.BytesWritten()),
							UserAgent:     r.UserAgent(),
							RemoteIP:      remoteIP,
							Referer:       r.Referer(),
							Latency:       fmt.Sprintf("%fs", time.Since(start).Seconds()),
							Protocol:      r.Proto,
						}),
						zap.String("request-id", requestID),
					}...)
				})
			})
		default:
			// Generic structured request log.
			r.Use(func(next http.Handler) http.Handler {
				return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
					start := time.Now()
					var requestID string
					if reqID := r.Context().Value(middleware.RequestIDKey); reqID != nil {
						requestID = reqID.(string)
					}
					ww := middleware.NewWrapResponseWriter(w, r.ProtoMajor)
					next.ServeHTTP(ww, r)

					// Don't log the version endpoint, it's too noisy
					if r.RequestURI == "/version" {
						return
					}

					latency := time.Since(start)

					fields := []zapcore.Field{
						zap.Int("status", ww.Status()),
						zap.Duration("took", latency),
						zap.String("request", r.RequestURI),
						zap.String("method", r.Method),
						zap.String("package", "server.request"),
					}
					if requestID != "" {
						fields = append(fields, zap.String("request-id", requestID))
					}
					// If we have an x-Forwarded-For header, use that for the remote
					if forwardedFor := r.Header.Get("X-Forwarded-For"); forwardedFor != "" {
						fields = append(fields, zap.String("remote", forwardedFor))
					} else {
						fields = append(fields, zap.String("remote", r.RemoteAddr))
					}
					zap.L().Info("HTTP Request", fields...)
				})
			})
		}
	}

	// CORS Config
	r.Use(cors.New(cors.Options{
		AllowedOrigins:   config.GetStringSlice("server.cors.allowed_origins"),
		AllowedMethods:   config.GetStringSlice("server.cors.allowed_methods"),
		AllowedHeaders:   config.GetStringSlice("server.cors.allowed_headers"),
		AllowCredentials: config.GetBool("server.cors.allowed_credentials"),
		MaxAge:           config.GetInt("server.cors.max_age"),
	}).Handler)

	// GRPC Interceptors - authenticate (a no-op by default) runs on every call
	streamInterceptors := []grpc.StreamServerInterceptor{
		grpc_auth.StreamServerInterceptor(authenticate),
	}
	unaryInterceptors := []grpc.UnaryServerInterceptor{
		grpc_auth.UnaryServerInterceptor(authenticate),
	}

	// GRPC Server Options
	serverOptions := []grpc.ServerOption{
		grpc_middleware.WithStreamServerChain(streamInterceptors...),
		grpc_middleware.WithUnaryServerChain(unaryInterceptors...),
		grpc.MaxRecvMsgSize(config.GetInt("server.max_msg_size")),
	}

	// Create gRPC Server
	g := grpc.NewServer(serverOptions...)
	// Register reflection service on gRPC server (so people know what we have)
	reflection.Register(g)

	s := &Server{
		logger:     zap.S().With("package", "server"),
		router:     r,
		grpcServer: g,
		gwRegFuncs: make([]gwRegFunc, 0),
	}
	// Single handler multiplexing gRPC (HTTP/2 + grpc content-type) and plain
	// HTTP on the same port.
	s.server = &http.Server{
		Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			if r.ProtoMajor == 2 && strings.Contains(r.Header.Get("Content-Type"), "application/grpc") {
				g.ServeHTTP(w, r)
			} else {
				s.router.ServeHTTP(w, r)
			}
		}),
		ErrorLog: log.New(&errorLogger{logger: s.logger}, "", 0),
	}

	s.SetupRoutes()

	return s, nil

}

// ListenAndServe will listen for requests
func (s *Server) ListenAndServe() error {

	s.server.Addr = net.JoinHostPort(config.GetString("server.host"), config.GetString("server.port"))

	// Listen
	listener, err := net.Listen("tcp", s.server.Addr)
	if err != nil {
		return fmt.Errorf("Could not listen on %s: %v", s.server.Addr, err)
	}

	grpcGatewayDialOptions := []grpc.DialOption{}

	// Enable TLS?
	if config.GetBool("server.tls") {
		var cert tls.Certificate
		if config.GetBool("server.devcert") {
			s.logger.Warn("WARNING: This server is using an insecure development tls certificate. This is for development only!!!")
			cert, err = autocert.New(autocert.InsecureStringReader("localhost"))
			if err != nil {
				return fmt.Errorf("Could not autocert generate server certificate: %v", err)
			}
		} else {
			// Load keys from file
			cert, err = tls.LoadX509KeyPair(config.GetString("server.certfile"), config.GetString("server.keyfile"))
			if err != nil {
				return fmt.Errorf("Could not load server certificate: %v", err)
			}
		}

		// Enabled Certs - TODO Add/Get a cert
		s.server.TLSConfig = &tls.Config{
			Certificates: []tls.Certificate{cert},
			MinVersion:   certtools.SecureTLSMinVersion(),
			CipherSuites: certtools.SecureTLSCipherSuites(),
			NextProtos:   []string{"h2"},
		}
		// Wrap the listener in a TLS Listener
		listener = tls.NewListener(listener, s.server.TLSConfig)

		// Fetch the CommonName from the certificate and generate a cert pool for the grpc gateway to use
		// This essentially figures out whatever certificate we happen to be using and makes it valid for the call between the GRPC gateway and the GRPC endpoint
		x509Cert, err := x509.ParseCertificate(cert.Certificate[0])
		if err != nil {
			return fmt.Errorf("Could not parse x509 public cert from tls certificate: %v", err)
		}
		clientCertPool := x509.NewCertPool()
		clientCertPool.AddCert(x509Cert)
		grpcCreds := credentials.NewClientTLSFromCert(clientCertPool, x509Cert.Subject.CommonName)
		grpcGatewayDialOptions = append(grpcGatewayDialOptions, grpc.WithTransportCredentials(grpcCreds))

	} else {
		// This h2c helper allows using insecure requests to http2/grpc
		s.server.Handler = h2c.NewHandler(s.server.Handler, &http2.Server{})
		grpcGatewayDialOptions = append(grpcGatewayDialOptions, grpc.WithInsecure())
	}

	// Setup the GRPC gateway
	grpcGatewayMux := gwruntime.NewServeMux(
		gwruntime.WithMarshalerOption(gwruntime.MIMEWildcard, &JSONMarshaler{}),
		gwruntime.WithIncomingHeaderMatcher(func(header string) (string, bool) {
			// Pass our headers
			switch strings.ToLower(header) {
			case odrpc.DoodsAuthKeyHeader:
				return header, true
			}
			return header, false
		}),
	)
	// If the main router did not find an endpoint, pass it to the grpcGateway
	s.router.NotFound(func(w http.ResponseWriter, r *http.Request) {
		grpcGatewayMux.ServeHTTP(w, r)
	})

	// Register all the GRPC gateway functions, dialing back to our own listener
	for _, gwrf := range s.gwRegFuncs {
		err = gwrf(context.Background(), grpcGatewayMux, listener.Addr().String(), grpcGatewayDialOptions)
		if err != nil {
			return fmt.Errorf("Could not register HTTP/gRPC gateway: %s", err)
		}
	}

	go func() {
		// NOTE(review): http.Server.Serve always returns a non-nil error on
		// shutdown (http.ErrServerClosed), which this Fatalw would treat as a
		// failure — confirm intended if graceful shutdown is ever added.
		if err = s.server.Serve(listener); err != nil {
			s.logger.Fatalw("API Listen error", "error", err, "address", s.server.Addr)
		}
	}()
	s.logger.Infow("API Listening", "address", s.server.Addr, "tls", config.GetBool("server.tls"), "version", conf.GitVersion)

	// Enable profiler
	if config.GetBool("server.profiler_enabled") && config.GetString("server.profiler_path") != "" {
		zap.S().Debugw("Profiler enabled on API", "path", config.GetString("server.profiler_path"))
		s.router.Mount(config.GetString("server.profiler_path"), middleware.Profiler())
	}

	return nil

}

// GWReg will save a gateway registration function for later when the server is started
func (s *Server) GWReg(gwrf gwRegFunc) {
	s.gwRegFuncs = append(s.gwRegFuncs, gwrf)
}

// GRPCServer will return the grpc server to allow functions to register themselves
func (s *Server) GRPCServer() *grpc.Server {
	return s.grpcServer
}

// errorLogger is used for logging errors from the server
type errorLogger struct
{
	// Target sugared logger for messages written via Write.
	logger *zap.SugaredLogger
}

// ErrorLogger implements an error logging function for the server
// Each buffer written by the net/http error log is forwarded to zap at error
// level; the write never fails.
func (el *errorLogger) Write(b []byte) (int, error) {
	el.logger.Error(string(b))
	return len(b), nil
}

// RenderOrErrInternal will render whatever you pass it (assuming it has Renderer) or prints an internal error
func RenderOrErrInternal(w http.ResponseWriter, r *http.Request, d render.Renderer) {
	if err := render.Render(w, r, d); err != nil {
		render.Render(w, r, ErrInternal(err))
		return
	}
}
--------------------------------------------------------------------------------
/server/version.go:
--------------------------------------------------------------------------------
package server

import (
	"context"

	emptypb "github.com/golang/protobuf/ptypes/empty"

	"github.com/snowzach/doods/conf"
	"github.com/snowzach/doods/server/rpc"
)

// Version returns the version
// It serves rpc.VersionRPC.Version, reporting the build version from
// conf.GitVersion.
func (s *Server) Version(ctx context.Context, _ *emptypb.Empty) (*rpc.VersionResponse, error) {

	return &rpc.VersionResponse{
		Version: conf.GitVersion,
	}, nil

}
--------------------------------------------------------------------------------
/tf_arm_toolchain_patch.sh:
--------------------------------------------------------------------------------
#!/bin/bash
# Downloads a cross-compile toolchain and patches the TensorFlow source tree
# for ARM builds.
# Assumes CROSSTOOL_URL, CROSSTOOL_NAME, CROSSTOOL_DIR (relative name) and
# CROSSTOOL_ROOT come from the environment — TODO confirm; they are referenced
# but never set here.

WORKDIR="$(realpath $(dirname $0))"
CROSSTOOL_DIR="${WORKDIR}/toolchain/${CROSSTOOL_DIR}/"
mkdir -p ${WORKDIR}/toolchain/
# NOTE(review): --no-check-certificate disables TLS verification of the
# toolchain download; consider pinning a checksum instead.
wget --no-check-certificate $CROSSTOOL_URL -O toolchain.tar.xz
tar xf toolchain.tar.xz -C ${WORKDIR}/toolchain/
rm toolchain.tar.xz &>/dev/null

CROSSTOOL_EXTRA_INCLUDE="/usr/local/include/"
CROSSTOOL_VERSION=$($CROSSTOOL_DIR/bin/$CROSSTOOL_NAME-gcc -dumpversion)
# Everything from here to EOF is a literal patch fed to `git apply`; the
# unquoted heredoc expands the $CROSSTOOL_* variables. Do not edit its
# formatting.
git apply << EOF
diff --git a/third_party/toolchains/cpus/arm/BUILD b/third_party/toolchains/cpus/arm/BUILD
index
5d388e918b..5e44b7e702 100644 15 | --- a/third_party/toolchains/cpus/arm/BUILD 16 | +++ b/third_party/toolchains/cpus/arm/BUILD 17 | @@ -64,5 +64,5 @@ cc_toolchain( 18 | strip_files = "arm_linux_all_files", 19 | supports_param_files = 1, 20 | toolchain_config = ":armeabi_config", 21 | - toolchain_identifier = "arm-linux-gnueabihf", 22 | + toolchain_identifier = "$CROSSTOOL_NAME", 23 | ) 24 | diff --git a/arm_compiler.BUILD b/arm_compiler.BUILD 25 | index cffe3fac70..dbde41825c 100644 26 | --- a/arm_compiler.BUILD 27 | +++ b/arm_compiler.BUILD 28 | @@ -3,65 +3,65 @@ package(default_visibility = ["//visibility:public"]) 29 | filegroup( 30 | name = "gcc", 31 | srcs = [ 32 | - "bin/arm-rpi-linux-gnueabihf-gcc", 33 | + "bin/$CROSSTOOL_NAME-gcc", 34 | ], 35 | ) 36 | 37 | filegroup( 38 | name = "ar", 39 | srcs = [ 40 | - "bin/arm-rpi-linux-gnueabihf-ar", 41 | + "bin/$CROSSTOOL_NAME-ar", 42 | ], 43 | ) 44 | 45 | filegroup( 46 | name = "ld", 47 | srcs = [ 48 | - "bin/arm-rpi-linux-gnueabihf-ld", 49 | + "bin/$CROSSTOOL_NAME-ld", 50 | ], 51 | ) 52 | 53 | filegroup( 54 | name = "nm", 55 | srcs = [ 56 | - "bin/arm-rpi-linux-gnueabihf-nm", 57 | + "bin/$CROSSTOOL_NAME-nm", 58 | ], 59 | ) 60 | 61 | filegroup( 62 | name = "objcopy", 63 | srcs = [ 64 | - "bin/arm-rpi-linux-gnueabihf-objcopy", 65 | + "bin/$CROSSTOOL_NAME-objcopy", 66 | ], 67 | ) 68 | 69 | filegroup( 70 | name = "objdump", 71 | srcs = [ 72 | - "bin/arm-rpi-linux-gnueabihf-objdump", 73 | + "bin/$CROSSTOOL_NAME-objdump", 74 | ], 75 | ) 76 | 77 | filegroup( 78 | name = "strip", 79 | srcs = [ 80 | - "bin/arm-rpi-linux-gnueabihf-strip", 81 | + "bin/$CROSSTOOL_NAME-strip", 82 | ], 83 | ) 84 | 85 | filegroup( 86 | name = "as", 87 | srcs = [ 88 | - "bin/arm-rpi-linux-gnueabihf-as", 89 | + "bin/$CROSSTOOL_NAME-as", 90 | ], 91 | ) 92 | 93 | filegroup( 94 | name = "compiler_pieces", 95 | srcs = glob([ 96 | - "arm-linux-gnueabihf/**", 97 | + "$CROSSTOOL_NAME/**", 98 | "libexec/**", 99 | - "lib/gcc/arm-linux-gnueabihf/**", 100 | + 
"lib/gcc/$CROSSTOOL_NAME/**", 101 | "include/**", 102 | ]), 103 | ) 104 | diff --git a/third_party/toolchains/cpus/arm/cc_config.bzl.tpl b/third_party/toolchains/cpus/arm/cc_config.bzl.tpl 105 | index bfe91e711b..da292cdcaf 100644 106 | --- a/third_party/toolchains/cpus/arm/cc_config.bzl.tpl 107 | +++ b/third_party/toolchains/cpus/arm/cc_config.bzl.tpl 108 | @@ -17,7 +17,7 @@ load("@bazel_tools//tools/build_defs/cc:action_names.bzl", "ACTION_NAMES") 109 | 110 | def _impl(ctx): 111 | if (ctx.attr.cpu == "armeabi"): 112 | - toolchain_identifier = "arm-linux-gnueabihf" 113 | + toolchain_identifier = "$CROSSTOOL_NAME" 114 | elif (ctx.attr.cpu == "local"): 115 | toolchain_identifier = "local_linux" 116 | else: 117 | @@ -269,7 +269,6 @@ def _impl(ctx): 118 | "-U_FORTIFY_SOURCE", 119 | "-D_FORTIFY_SOURCE=1", 120 | "-fstack-protector", 121 | - "-DRASPBERRY_PI", 122 | ], 123 | ), 124 | ], 125 | @@ -331,17 +330,23 @@ def _impl(ctx): 126 | flags = [ 127 | "-std=c++11", 128 | "-isystem", 129 | - "%{ARM_COMPILER_PATH}%/lib/gcc/arm-rpi-linux-gnueabihf/6.5.0/include", 130 | + "$CROSSTOOL_DIR/$CROSSTOOL_NAME/include/c++/$CROSSTOOL_VERSION/", 131 | "-isystem", 132 | - "%{ARM_COMPILER_PATH}%/lib/gcc/arm-rpi-linux-gnueabihf/6.5.0/include-fixed", 133 | + "$CROSSTOOL_DIR/$CROSSTOOL_NAME/sysroot/usr/include/", 134 | "-isystem", 135 | - "%{ARM_COMPILER_PATH}%/arm-rpi-linux-gnueabihf/sysroot/usr/include/", 136 | + "$CROSSTOOL_DIR/$CROSSTOOL_NAME/libc/usr/include/", 137 | "-isystem", 138 | - "%{ARM_COMPILER_PATH}%/arm-rpi-linux-gnueabihf/include/c++/6.5.0/", 139 | + "$CROSSTOOL_DIR/lib/gcc/$CROSSTOOL_NAME/$CROSSTOOL_VERSION/include", 140 | + "-isystem", 141 | + "$CROSSTOOL_DIR/lib/gcc/$CROSSTOOL_NAME/$CROSSTOOL_VERSION/include-fixed", 142 | + "-isystem", 143 | + "$CROSSTOOL_ROOT/usr/include", 144 | + "-isystem", 145 | + "$CROSSTOOL_ROOT/usr/include/$CROSSTOOL_NAME", 146 | + "-isystem", 147 | + "$CROSSTOOL_EXTRA_INCLUDE", 148 | "-isystem", 149 | "%{PYTHON_INCLUDE_PATH}%", 150 | - 
"-isystem", 151 | - "/usr/include/", 152 | ], 153 | ), 154 | ], 155 | @@ -559,12 +564,15 @@ def _impl(ctx): 156 | 157 | if (ctx.attr.cpu == "armeabi"): 158 | cxx_builtin_include_directories = [ 159 | - "%{ARM_COMPILER_PATH}%/lib/gcc/arm-rpi-linux-gnueabihf/6.5.0/include", 160 | - "%{ARM_COMPILER_PATH}%/lib/gcc/arm-rpi-linux-gnueabihf/6.5.0/include-fixed", 161 | - "%{ARM_COMPILER_PATH}%/arm-rpi-linux-gnueabihf/sysroot/usr/include/", 162 | - "%{ARM_COMPILER_PATH}%/arm-rpi-linux-gnueabihf/include/c++/6.5.0/", 163 | + "$CROSSTOOL_DIR/$CROSSTOOL_NAME/include/c++/$CROSSTOOL_VERSION/", 164 | + "$CROSSTOOL_DIR/$CROSSTOOL_NAME/sysroot/usr/include/", 165 | + "$CROSSTOOL_DIR/$CROSSTOOL_NAME/libc/usr/include/", 166 | + "$CROSSTOOL_DIR/lib/gcc/$CROSSTOOL_NAME/$CROSSTOOL_VERSION/include", 167 | + "$CROSSTOOL_DIR/lib/gcc/$CROSSTOOL_NAME/$CROSSTOOL_VERSION/include-fixed", 168 | - "/usr/include", 169 | + "$CROSSTOOL_ROOT/usr/include", 170 | - "/tmp/openblas_install/include/", 171 | + "/usr/include/$CROSSTOOL_NAME", 172 | + "$CROSSTOOL_EXTRA_INCLUDE", 173 | + "%{PYTHON_INCLUDE_PATH}%" 174 | ] 175 | elif (ctx.attr.cpu == "local"): 176 | cxx_builtin_include_directories = ["/usr/lib/gcc/", "/usr/local/include", "/usr/include"] 177 | @@ -579,44 +587,44 @@ def _impl(ctx): 178 | tool_paths = [ 179 | tool_path( 180 | name = "ar", 181 | - path = "%{ARM_COMPILER_PATH}%/bin/arm-rpi-linux-gnueabihf-ar", 182 | + path = "$CROSSTOOL_DIR/bin/$CROSSTOOL_NAME-ar", 183 | ), 184 | tool_path(name = "compat-ld", path = "/bin/false"), 185 | tool_path( 186 | name = "cpp", 187 | - path = "%{ARM_COMPILER_PATH}%/bin/arm-rpi-linux-gnueabihf-cpp", 188 | + path = "$CROSSTOOL_DIR/bin/$CROSSTOOL_NAME-cpp", 189 | ), 190 | tool_path( 191 | name = "dwp", 192 | - path = "%{ARM_COMPILER_PATH}%/bin/arm-rpi-linux-gnueabihf-dwp", 193 | + path = "$CROSSTOOL_DIR/bin/$CROSSTOOL_NAME-dwp", 194 | ), 195 | tool_path( 196 | name = "gcc", 197 | - path = "%{ARM_COMPILER_PATH}%/bin/arm-rpi-linux-gnueabihf-gcc", 198 | + path = 
"$CROSSTOOL_DIR/bin/$CROSSTOOL_NAME-gcc", 199 | ), 200 | tool_path( 201 | name = "gcov", 202 | - path = "%{ARM_COMPILER_PATH}%/bin/arm-rpi-linux-gnueabihf-gcov", 203 | + path = "$CROSSTOOL_DIR/bin/$CROSSTOOL_NAME-gcov", 204 | ), 205 | tool_path( 206 | name = "ld", 207 | - path = "%{ARM_COMPILER_PATH}%/bin/arm-rpi-linux-gnueabihf-ld", 208 | + path = "$CROSSTOOL_DIR/bin/$CROSSTOOL_NAME-ld", 209 | ), 210 | tool_path( 211 | name = "nm", 212 | - path = "%{ARM_COMPILER_PATH}%/bin/arm-rpi-linux-gnueabihf-nm", 213 | + path = "$CROSSTOOL_DIR/bin/$CROSSTOOL_NAME-nm", 214 | ), 215 | tool_path( 216 | name = "objcopy", 217 | - path = "%{ARM_COMPILER_PATH}%/bin/arm-rpi-linux-gnueabihf-objcopy", 218 | + path = "$CROSSTOOL_DIR/bin/$CROSSTOOL_NAME-objcopy", 219 | ), 220 | tool_path( 221 | name = "objdump", 222 | - path = "%{ARM_COMPILER_PATH}%/bin/arm-rpi-linux-gnueabihf-objdump", 223 | + path = "$CROSSTOOL_DIR/bin/$CROSSTOOL_NAME-objdump", 224 | ), 225 | tool_path( 226 | name = "strip", 227 | - path = "%{ARM_COMPILER_PATH}%/bin/arm-rpi-linux-gnueabihf-strip", 228 | + path = "$CROSSTOOL_DIR/bin/$CROSSTOOL_NAME-strip", 229 | ), 230 | ] 231 | elif (ctx.attr.cpu == "local"): 232 | EOF 233 | --------------------------------------------------------------------------------