├── .dockerignore ├── .gitignore ├── Dockerfile.benchmark_client ├── Dockerfile.benchmark_server ├── Dockerfile.caffe_server ├── Dockerfile.inference_client ├── Dockerfile.tensorrt_server ├── LICENSE ├── README.md ├── benchmark ├── benchmark.cpp ├── benchmark.h ├── kernel.cu ├── kernel.h └── main.go ├── caffe ├── classification.cpp ├── classification.h ├── gpu_allocator.cpp ├── gpu_allocator.h └── main.go ├── common.h ├── images ├── 1.jpg ├── 2.jpg ├── 3.jpg ├── 4.jpg ├── 5.jpg └── 6.jpg └── tensorrt ├── classification.cpp ├── classification.h ├── gpu_allocator.cpp ├── gpu_allocator.h └── main.go /.dockerignore: -------------------------------------------------------------------------------- 1 | Dockerfile.* 2 | .git 3 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | TensorRT* 2 | -------------------------------------------------------------------------------- /Dockerfile.benchmark_client: -------------------------------------------------------------------------------- 1 | FROM golang:1.7 2 | 3 | MAINTAINER Felix Abecassis "fabecassis@nvidia.com" 4 | 5 | RUN go get github.com/rakyll/hey 6 | 7 | CMD hey -c ${CONCURRENCY} -n ${REQUESTS} http://localhost:8000/benchmark 8 | -------------------------------------------------------------------------------- /Dockerfile.benchmark_server: -------------------------------------------------------------------------------- 1 | FROM nvidia/cuda:7.5-devel 2 | 3 | MAINTAINER Felix Abecassis "fabecassis@nvidia.com" 4 | 5 | RUN apt-get update && apt-get install -y --no-install-recommends --force-yes \ 6 | ca-certificates \ 7 | pkg-config \ 8 | wget && \ 9 | rm -rf /var/lib/apt/lists/* 10 | 11 | # Install golang 12 | ENV GOLANG_VERSION 1.6 13 | RUN wget -O - https://storage.googleapis.com/golang/go${GOLANG_VERSION}.linux-amd64.tar.gz \ 14 | | tar -v -C /usr/local -xz 15 | ENV GOPATH /go 16 | ENV PATH $GOPATH/bin:/usr/local/go/bin:$PATH 17 | 18 | # Build benchmark server 19 | COPY benchmark /go/src/benchmark 20 | COPY common.h /go/src/common.h 21 | #RUN cd /go/src/benchmark && nvcc --shared -O3 kernel.cu -Xcompiler -fPIC -o libkernel.so 22 | RUN cd /go/src/benchmark && \ 23 | nvcc -O3 kernel.cu -c -o kernel.o && \ 24 | ar rcs libkernel.a kernel.o 25 | RUN go get -ldflags="-s" benchmark 26 | 27 | # FIME: entrypoint for contexts per device. 28 | CMD ["benchmark"] 29 | -------------------------------------------------------------------------------- /Dockerfile.caffe_server: -------------------------------------------------------------------------------- 1 | FROM nvidia/cuda:9.0-cudnn7-devel-ubuntu16.04 as build 2 | 3 | ENV CUDA_ARCH_BIN "35 52 60 61 70" 4 | ENV CUDA_ARCH_PTX "70" 5 | 6 | # Install dependencies. 7 | RUN apt-get update && apt-get install -y --no-install-recommends \ 8 | ca-certificates \ 9 | cmake \ 10 | git \ 11 | libatlas-base-dev \ 12 | libatlas-dev \ 13 | libboost-all-dev \ 14 | libgflags-dev \ 15 | libgoogle-glog-dev \ 16 | libhdf5-dev \ 17 | libprotobuf-dev \ 18 | pkg-config \ 19 | protobuf-compiler \ 20 | python-yaml \ 21 | python-six \ 22 | wget && \ 23 | rm -rf /var/lib/apt/lists/* 24 | 25 | # OpenCV 3.3.1 is needed to support custom allocators for GpuMat objects. 
26 | RUN git clone --depth 1 -b 3.3.1 https://github.com/opencv/opencv.git /opencv && \ 27 | mkdir /opencv/build && cd /opencv/build && \ 28 | cmake -DCMAKE_BUILD_TYPE=Release -DBUILD_SHARED_LIBS=ON \ 29 | -DWITH_CUDA=ON -DWITH_CUFFT=OFF -DCUDA_ARCH_BIN="${CUDA_ARCH_BIN}" -DCUDA_ARCH_PTX="${CUDA_ARCH_PTX}" \ 30 | -DWITH_JPEG=ON -DBUILD_JPEG=ON -DWITH_PNG=ON -DBUILD_PNG=ON \ 31 | -DBUILD_TESTS=OFF -DBUILD_EXAMPLES=OFF -DWITH_FFMPEG=OFF -DWITH_GTK=OFF \ 32 | -DWITH_OPENCL=OFF -DWITH_QT=OFF -DWITH_V4L=OFF -DWITH_JASPER=OFF \ 33 | -DWITH_1394=OFF -DWITH_TIFF=OFF -DWITH_OPENEXR=OFF -DWITH_IPP=OFF -DWITH_WEBP=OFF \ 34 | -DBUILD_opencv_superres=OFF -DBUILD_opencv_java=OFF -DBUILD_opencv_python2=OFF \ 35 | -DBUILD_opencv_videostab=OFF -DBUILD_opencv_apps=OFF -DBUILD_opencv_flann=OFF \ 36 | -DBUILD_opencv_ml=OFF -DBUILD_opencv_photo=OFF -DBUILD_opencv_shape=OFF \ 37 | -DBUILD_opencv_cudabgsegm=OFF -DBUILD_opencv_cudaoptflow=OFF -DBUILD_opencv_cudalegacy=OFF \ 38 | -DCUDA_NVCC_FLAGS="-O3" -DCUDA_FAST_MATH=ON .. && \ 39 | make -j"$(nproc)" install && \ 40 | ldconfig && \ 41 | rm -rf /opencv 42 | 43 | # A modified version of Caffe is used to properly handle multithreading and CUDA streams. 44 | RUN git clone --depth 1 -b bvlc_inference https://github.com/flx42/caffe.git /caffe && \ 45 | cd /caffe && \ 46 | cmake -DCMAKE_BUILD_TYPE=Release -DBUILD_SHARED_LIBS=ON \ 47 | -DCUDA_ARCH_NAME=Manual -DCUDA_ARCH_BIN="${CUDA_ARCH_BIN}" -DCUDA_ARCH_PTX="${CUDA_ARCH_PTX}" \ 48 | -DUSE_CUDNN=ON -DUSE_OPENCV=ON -DUSE_LEVELDB=OFF -DUSE_LMDB=OFF \ 49 | -DBUILD_python=OFF -DBUILD_python_layer=OFF -DBUILD_matlab=OFF \ 50 | -DCMAKE_INSTALL_PREFIX=/usr/local \ 51 | -DCUDA_NVCC_FLAGS="-O3" && \ 52 | make -j"$(nproc)" install && \ 53 | ldconfig && \ 54 | make clean 55 | 56 | # Download Caffenet 57 | RUN /caffe/scripts/download_model_binary.py /caffe/models/bvlc_reference_caffenet && \ 58 | /caffe/data/ilsvrc12/get_ilsvrc_aux.sh 59 | 60 | # Install golang 61 | ENV GOLANG_VERSION 1.9.2 62 | RUN wget -nv -O - https://storage.googleapis.com/golang/go${GOLANG_VERSION}.linux-amd64.tar.gz \ 63 | | tar -C /usr/local -xz 64 | ENV GOPATH /go 65 | ENV PATH $GOPATH/bin:/usr/local/go/bin:$PATH 66 | 67 | # Build inference server 68 | COPY caffe /go/src/caffe-server 69 | COPY common.h /go/src/common.h 70 | RUN go get -ldflags="-s -w" caffe-server 71 | 72 | 73 | # We use a multi-stage build to get a smaller image for deployment. 74 | FROM nvidia/cuda:9.0-base-ubuntu16.04 75 | 76 | MAINTAINER Felix Abecassis "fabecassis@nvidia.com" 77 | 78 | RUN apt-get update && apt-get install -y --no-install-recommends \ 79 | libatlas3-base \ 80 | libboost-system1.58.0 \ 81 | libboost-thread1.58.0 \ 82 | libgflags2v5 \ 83 | libgoogle-glog0v5 \ 84 | libhdf5-10 \ 85 | libprotobuf9v5 \ 86 | libcudnn7=7.0.5.15-1+cuda9.0 \ 87 | cuda-cublas-9-0 \ 88 | cuda-curand-9-0 \ 89 | cuda-npp-9-0 && \ 90 | rm -rf /var/lib/apt/lists/ 91 | 92 | # Copy binary and dependencies 93 | COPY --from=build /go/bin/caffe-server /usr/local/bin/caffe-server 94 | COPY --from=build /usr/local/lib /usr/local/lib 95 | RUN ldconfig 96 | 97 | # Copy dataset. If you use your own dataset: delete these lines and mount a volume from the host. 
98 | COPY --from=build /caffe/models/bvlc_reference_caffenet/deploy.prototxt /opt/caffenet/deploy.prototxt 99 | COPY --from=build /caffe/models/bvlc_reference_caffenet/bvlc_reference_caffenet.caffemodel /opt/caffenet/bvlc_reference_caffenet.caffemodel 100 | COPY --from=build /caffe/data/ilsvrc12/imagenet_mean.binaryproto /opt/caffenet/imagenet_mean.binaryproto 101 | COPY --from=build /caffe/data/ilsvrc12/synset_words.txt /opt/caffenet/synset_words.txt 102 | 103 | WORKDIR /opt/caffenet 104 | CMD ["caffe-server", "deploy.prototxt", "bvlc_reference_caffenet.caffemodel", "imagenet_mean.binaryproto", "synset_words.txt"] 105 | -------------------------------------------------------------------------------- /Dockerfile.inference_client: -------------------------------------------------------------------------------- 1 | FROM golang:1.7 2 | 3 | MAINTAINER Felix Abecassis "fabecassis@nvidia.com" 4 | 5 | RUN go get github.com/rakyll/hey 6 | 7 | COPY images /images 8 | 9 | CMD hey -c ${CONCURRENCY} -n ${REQUESTS} -m POST -D /images/2.jpg http://localhost:8000/api/classify 10 | -------------------------------------------------------------------------------- /Dockerfile.tensorrt_server: -------------------------------------------------------------------------------- 1 | FROM nvidia/cuda:8.0-cudnn6-devel-ubuntu16.04 2 | ENV CUDA_ARCH "30 35 52" 3 | 4 | MAINTAINER Felix Abecassis "fabecassis@nvidia.com" 5 | 6 | # Install dependencies. 7 | RUN apt-get update && apt-get install -y --no-install-recommends \ 8 | ca-certificates \ 9 | cmake \ 10 | git \ 11 | libboost-all-dev \ 12 | libgflags-dev \ 13 | libgoogle-glog-dev \ 14 | libprotobuf-dev \ 15 | pkg-config \ 16 | protobuf-compiler \ 17 | python-yaml \ 18 | wget && \ 19 | rm -rf /var/lib/apt/lists/* 20 | 21 | # Install OpenCV 3.2.0 with CUDA support 22 | RUN git clone --depth 1 -b 3.2.0 https://github.com/Itseez/opencv.git /opencv && \ 23 | cd /opencv && \ 24 | cmake -DCMAKE_BUILD_TYPE=Release -DBUILD_SHARED_LIBS=ON \ 25 | -DWITH_CUDA=ON -DCUDA_ARCH_BIN="${CUDA_ARCH}" -DCUDA_ARCH_PTX="${CUDA_ARCH}" \ 26 | -DWITH_JPEG=ON -DBUILD_JPEG=ON -DWITH_PNG=ON -DBUILD_PNG=ON \ 27 | -DBUILD_TESTS=OFF -DBUILD_EXAMPLES=OFF -DWITH_FFMPEG=OFF -DWITH_GTK=OFF \ 28 | -DWITH_OPENCL=OFF -DWITH_QT=OFF -DWITH_V4L=OFF -DWITH_JASPER=OFF \ 29 | -DWITH_1394=OFF -DWITH_TIFF=OFF -DWITH_OPENEXR=OFF -DWITH_IPP=OFF -DWITH_WEBP=OFF \ 30 | -DBUILD_opencv_superres=OFF -DBUILD_opencv_java=OFF -DBUILD_opencv_python2=OFF \ 31 | -DBUILD_opencv_videostab=OFF -DBUILD_opencv_apps=OFF -DBUILD_opencv_flann=OFF \ 32 | -DBUILD_opencv_ml=OFF -DBUILD_opencv_photo=OFF -DBUILD_opencv_shape=OFF \ 33 | -DBUILD_opencv_cudabgsegm=OFF -DBUILD_opencv_cudaoptflow=OFF -DBUILD_opencv_cudalegacy=OFF \ 34 | -DCUDA_NVCC_FLAGS="-O3" -DCUDA_FAST_MATH=ON && \ 35 | make -j"$(nproc)" install && ldconfig && \ 36 | rm -rf /opencv 37 | 38 | # Install golang 39 | ENV GOLANG_VERSION 1.8.1 40 | RUN wget -O - https://storage.googleapis.com/golang/go${GOLANG_VERSION}.linux-amd64.tar.gz \ 41 | | tar -v -C /usr/local -xz 42 | ENV GOPATH /go 43 | ENV PATH $GOPATH/bin:/usr/local/go/bin:$PATH 44 | 45 | # Require the TensorRT archive to be present in the build context. 
46 | ADD TensorRT-2.1.2.x86_64.cuda-8.0-16-04.tar.bz2 /opt/ 47 | 48 | ENV CPLUS_INCLUDE_PATH /opt/TensorRT-2.1.2/include:$CPLUS_INCLUDE_PATH 49 | ENV LD_LIBRARY_PATH /opt/TensorRT-2.1.2/targets/x86_64-linux-gnu/lib:$LD_LIBRARY_PATH 50 | ENV LIBRARY_PATH /opt/TensorRT-2.1.2/targets/x86_64-linux-gnu/lib:$LIBRARY_PATH 51 | 52 | # Copy and build GPU Rest Engine with TensorRT 53 | COPY tensorrt /go/src/tensorrt-server 54 | COPY common.h /go/src/common.h 55 | RUN go get -ldflags="-s" tensorrt-server 56 | 57 | # Download model 58 | RUN git clone -b caffe-0.15 --depth 1 https://github.com/NVIDIA/caffe.git /caffe && \ 59 | /caffe/scripts/download_model_binary.py /caffe/models/bvlc_alexnet && \ 60 | /caffe/data/ilsvrc12/get_ilsvrc_aux.sh 61 | 62 | CMD ["tensorrt-server", "/caffe/models/bvlc_alexnet/deploy.prototxt", \ 63 | "/caffe/models/bvlc_alexnet/bvlc_alexnet.caffemodel", \ 64 | "/caffe/data/ilsvrc12/imagenet_mean.binaryproto", \ 65 | "/caffe/data/ilsvrc12/synset_words.txt"] 66 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2016, NVIDIA CORPORATION. All rights reserved. 2 | 3 | Redistribution and use in source and binary forms, with or without 4 | modification, are permitted provided that the following conditions 5 | are met: 6 | * Redistributions of source code must retain the above copyright 7 | notice, this list of conditions and the following disclaimer. 8 | * Redistributions in binary form must reproduce the above copyright 9 | notice, this list of conditions and the following disclaimer in the 10 | documentation and/or other materials provided with the distribution. 11 | * Neither the name of NVIDIA CORPORATION nor the names of its 12 | contributors may be used to endorse or promote products derived 13 | from this software without specific prior written permission. 14 | 15 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY 16 | EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 17 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR 18 | PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR 19 | CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, 20 | EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, 21 | PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR 22 | PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY 23 | OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 24 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 25 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 26 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Introduction 2 | 3 | This repository shows how to implement a REST server for low-latency image classification (inference) using NVIDIA GPUs. This is an initial demonstration of the [GRE (GPU REST Engine)](https://developer.nvidia.com/gre) software that will allow you to build your own accelerated microservices. 4 | 5 | **This repository is a demo, it is not intended to be a generic solution that can accept any trained model. 
Code customization will be required for your use cases.** 6 | 7 | This demonstration makes use of several technologies with which you may be familiar: 8 | - [Docker](https://www.docker.com/): for bundling all the dependencies of our program and for easier deployment. 9 | - [Go](https://golang.org/): for its efficient builtin HTTP server. 10 | - [Caffe](https://github.com/BVLC/caffe): because it has good performance and a simple C++ API. 11 | - [TensorRT](https://developer.nvidia.com/tensorrt): NVIDIA's high-performance inference engine. 12 | - [cuDNN](https://developer.nvidia.com/cudnn): for accelerating common deep learning primitives on the GPU. 13 | - [OpenCV](http://opencv.org/): to have a simple C++ API for GPU image processing. 14 | 15 | # Building 16 | 17 | ## Prerequisites 18 | - A Kepler or Maxwell NVIDIA GPU with at least 2 GB of memory. 19 | - A Linux system with recent NVIDIA drivers (recommended: 352.79). 20 | - Install the latest version of [Docker](https://docs.docker.com/linux/step_one/). 21 | - Install [nvidia-docker](https://github.com/NVIDIA/nvidia-docker/wiki/Installation-(version-2.0)). 22 | 23 | ## Build command (Caffe) 24 | The command might take a while to execute: 25 | ``` 26 | $ docker build -t inference_server -f Dockerfile.caffe_server . 27 | ``` 28 | To speedup the build you can modify [this line](https://github.com/NVIDIA/gpu-rest-engine/blob/master/Dockerfile.caffe_server#L5) to only build for the GPU architecture that you need. 29 | 30 | ## Build command (TensorRT) 31 | This command requires the TensorRT archive to be present in the current folder. 32 | ``` 33 | $ docker build -t inference_server -f Dockerfile.tensorrt_server . 34 | ``` 35 | 36 | # Testing 37 | 38 | ## Starting the server 39 | Execute the following command and wait a few seconds for the initialization of the classifiers: 40 | ``` 41 | $ docker run --runtime=nvidia --name=server --net=host --rm inference_server 42 | ``` 43 | You can use the environment variable [`NVIDIA_VISIBLE_DEVICES`](https://github.com/NVIDIA/nvidia-docker/wiki/Usage#gpu-isolation) to isolate GPUs for this container. 44 | 45 | ## Single image 46 | Since we used [`--net=host`](https://docs.docker.com/engine/userguide/networking/), we can access our inference server from a terminal on the host using `curl`: 47 | ``` 48 | $ curl -XPOST --data-binary @images/1.jpg http://127.0.0.1:8000/api/classify 49 | [{"confidence":0.9998,"label":"n02328150 Angora, Angora rabbit"},{"confidence":0.0001,"label":"n02325366 wood rabbit, cottontail, cottontail rabbit"},{"confidence":0.0001,"label":"n02326432 hare"},{"confidence":0.0000,"label":"n02085936 Maltese dog, Maltese terrier, Maltese"},{"confidence":0.0000,"label":"n02342885 hamster"}] 50 | ``` 51 | 52 | ## Benchmarking performance 53 | We can benchmark the performance of our classification server using any tool that can generate HTTP load. We included a Dockerfile 54 | for a benchmarking client using [rakyll/hey](https://github.com/rakyll/hey): 55 | ``` 56 | $ docker build -t inference_client -f Dockerfile.inference_client . 
57 | $ docker run -e CONCURRENCY=8 -e REQUESTS=20000 --net=host inference_client 58 | ``` 59 | 60 | If you have `Go` installed on your host, you can also benchmark the server with a client outside of a Docker container: 61 | ``` 62 | $ go get github.com/rakyll/hey 63 | $ hey -n 200000 -m POST -D images/2.jpg http://127.0.0.1:8000/api/classify 64 | ``` 65 | 66 | ## Performance on an NVIDIA DIGITS DevBox 67 | This machine has 4 GeForce GTX Titan X GPUs: 68 | ``` 69 | $ hey -c 8 -n 200000 -m POST -D images/2.jpg http://127.0.0.1:8000/api/classify 70 | Summary: 71 | Total: 100.7775 secs 72 | Slowest: 0.0167 secs 73 | Fastest: 0.0028 secs 74 | Average: 0.0040 secs 75 | Requests/sec: 1984.5690 76 | Total data: 68800000 bytes 77 | Size/request: 344 bytes 78 | [...] 79 | ``` 80 | 81 | As a comparison, Caffe in standalone mode achieves approximately 500 images/second on a single Titan X for inference (`batch=1`). Our 1984.6 requests/second across 4 GPUs works out to roughly 496 images/second per GPU, which shows that our code achieves near-optimal GPU utilization and good multi-GPU scaling, even when adding a REST API on top. A discussion of GPU performance for inference at different batch sizes can be found in our [GPU-Based Deep Learning Inference whitepaper](https://www.nvidia.com/content/tegra/embedded-systems/pdf/jetson_tx1_whitepaper.pdf). 82 | 83 | This inference server is aimed at low-latency applications; to achieve higher throughput, we would need to batch multiple incoming client requests or have clients send multiple images to classify. Batching can be added easily when using the [C++ API](https://github.com/flx42/caffe/commit/be0bff1a84c9e16fb8e8514dc559f2de5ab1a416) of Caffe. An example of this strategy, which Baidu Research calls "Batch Dispatch", can be found in [this article](https://arxiv.org/pdf/1512.02595.pdf); an illustrative sketch of the idea is included in the appendix at the end of this README. 84 | 85 | ## Benchmarking overhead of CUDA kernel calls 86 | Similarly to the inference server, a simple benchmark server is provided for estimating the overhead of calling CUDA kernels from your code. The server simply launches an empty CUDA kernel before responding `200` to the client. The server can be built using the same commands as above: 87 | ``` 88 | $ docker build -t benchmark_server -f Dockerfile.benchmark_server . 89 | $ docker run --runtime=nvidia --name=server --net=host --rm benchmark_server 90 | ``` 91 | And for the client: 92 | ``` 93 | $ docker build -t benchmark_client -f Dockerfile.benchmark_client . 94 | $ docker run -e CONCURRENCY=8 -e REQUESTS=200000 --net=host benchmark_client 95 | [...] 96 | Summary: 97 | Total: 5.8071 secs 98 | Slowest: 0.0127 secs 99 | Fastest: 0.0001 secs 100 | Average: 0.0002 secs 101 | Requests/sec: 34440.3083 102 | ``` 103 | 104 | 105 | ## Contributing 106 | 107 | Feel free to report issues during build or execution. We also welcome suggestions to improve the performance of this application.
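## Appendix: sketch of a batching dispatcher

The server in this repository classifies exactly one image per request. As a rough, illustrative sketch of the "Batch Dispatch" idea mentioned above, the Go snippet below groups incoming requests into small batches before running inference. It is not part of the server code: `classifyBatch` is a hypothetical placeholder for a batched inference call (for example, a cgo wrapper around a Caffe `Forward()` with a batch size greater than 1), and the batch size and delay values are arbitrary.

```go
package main

import (
	"io"
	"io/ioutil"
	"log"
	"net/http"
	"time"
)

// request couples one image payload with a channel used to return its result.
type request struct {
	img    []byte
	result chan string
}

// classifyBatch is a hypothetical placeholder for a batched inference call.
func classifyBatch(imgs [][]byte) []string {
	out := make([]string, len(imgs))
	for i := range imgs {
		out[i] = `[{"confidence":0.0000,"label":"placeholder"}]`
	}
	return out
}

// dispatcher collects requests until the batch is full or a short deadline
// expires, then runs a single batched inference for the whole group.
func dispatcher(in chan request, maxBatch int, maxDelay time.Duration) {
	for {
		first := <-in // block until at least one request arrives
		batch := []request{first}
		deadline := time.After(maxDelay)
	collect:
		for len(batch) < maxBatch {
			select {
			case r := <-in:
				batch = append(batch, r)
			case <-deadline:
				break collect
			}
		}

		imgs := make([][]byte, len(batch))
		for i, r := range batch {
			imgs[i] = r.img
		}
		// One inference call for the whole batch; each caller receives its own entry.
		for i, res := range classifyBatch(imgs) {
			batch[i].result <- res
		}
	}
}

func main() {
	queue := make(chan request, 64)
	go dispatcher(queue, 8, 2*time.Millisecond)

	http.HandleFunc("/api/classify", func(w http.ResponseWriter, r *http.Request) {
		buf, err := ioutil.ReadAll(r.Body)
		if err != nil {
			http.Error(w, err.Error(), http.StatusBadRequest)
			return
		}
		req := request{img: buf, result: make(chan string, 1)}
		queue <- req
		io.WriteString(w, <-req.result)
	})
	log.Fatal(http.ListenAndServe(":8000", nil))
}
```

The trade-off is the usual one: waiting up to `maxDelay` for additional requests adds a small amount of latency in exchange for better GPU throughput, so both parameters would need to be tuned for a real deployment.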
108 | -------------------------------------------------------------------------------- /benchmark/benchmark.cpp: -------------------------------------------------------------------------------- 1 | #include "benchmark.h" 2 | 3 | #include 4 | #include 5 | 6 | #include 7 | 8 | #include "common.h" 9 | #include "kernel.h" 10 | 11 | class BenchmarkContext 12 | { 13 | public: 14 | friend ScopedContext; 15 | 16 | static bool IsCompatible(int device) 17 | { 18 | cudaError_t st = cudaSetDevice(device); 19 | if (st != cudaSuccess) 20 | return false; 21 | 22 | return true; 23 | } 24 | 25 | BenchmarkContext(int device) 26 | : device_(device) 27 | { 28 | cudaError_t st = cudaSetDevice(device_); 29 | if (st != cudaSuccess) 30 | throw std::invalid_argument("could not set CUDA device"); 31 | 32 | st = cudaStreamCreate(&stream_); 33 | if (st != cudaSuccess) 34 | throw std::invalid_argument("could not create CUDA stream"); 35 | } 36 | 37 | ~BenchmarkContext() 38 | { 39 | cudaStreamDestroy(stream_); 40 | } 41 | 42 | cudaStream_t CUDAStream() 43 | { 44 | return stream_; 45 | } 46 | 47 | private: 48 | void Activate() 49 | { 50 | cudaError_t st = cudaSetDevice(device_); 51 | if (st != cudaSuccess) 52 | throw std::invalid_argument("could not set CUDA device"); 53 | } 54 | 55 | void Deactivate() 56 | { 57 | } 58 | 59 | private: 60 | int device_; 61 | cudaStream_t stream_; 62 | }; 63 | 64 | struct benchmark_ctx 65 | { 66 | ContextPool pool; 67 | }; 68 | 69 | constexpr static int kContextsPerDevice = 4; 70 | 71 | benchmark_ctx* benchmark_initialize() 72 | { 73 | try 74 | { 75 | int device_count; 76 | cudaError_t st = cudaGetDeviceCount(&device_count); 77 | if (st != cudaSuccess) 78 | throw std::invalid_argument("could not list CUDA devices"); 79 | 80 | ContextPool pool; 81 | for (int dev = 0; dev < device_count; ++dev) 82 | { 83 | if (!BenchmarkContext::IsCompatible(dev)) 84 | { 85 | std::cerr << "Skipping device: " << dev << std::endl; 86 | continue; 87 | } 88 | 89 | for (int i = 0; i < kContextsPerDevice; ++i) 90 | { 91 | std::unique_ptr context(new BenchmarkContext(dev)); 92 | pool.Push(std::move(context)); 93 | } 94 | } 95 | 96 | if (pool.Size() == 0) 97 | throw std::invalid_argument("no suitable CUDA device"); 98 | 99 | benchmark_ctx* ctx = new benchmark_ctx{std::move(pool)}; 100 | errno = 0; 101 | return ctx; 102 | } 103 | catch (const std::invalid_argument& ex) 104 | { 105 | errno = EINVAL; 106 | return nullptr; 107 | } 108 | } 109 | 110 | void benchmark_execute(benchmark_ctx* ctx) 111 | { 112 | try 113 | { 114 | ScopedContext context(ctx->pool); 115 | cudaStream_t stream = context->CUDAStream(); 116 | kernel_wrapper(stream); 117 | errno = 0; 118 | } 119 | catch (const std::invalid_argument&) 120 | { 121 | errno = EINVAL; 122 | } 123 | } 124 | 125 | void benchmark_destroy(benchmark_ctx* ctx) 126 | { 127 | delete ctx; 128 | } 129 | -------------------------------------------------------------------------------- /benchmark/benchmark.h: -------------------------------------------------------------------------------- 1 | #ifndef BENCHMARK_H 2 | #define BENCHMARK_H 3 | 4 | #ifdef __cplusplus 5 | extern "C" { 6 | #endif 7 | 8 | #include 9 | 10 | typedef struct benchmark_ctx benchmark_ctx; 11 | 12 | benchmark_ctx* benchmark_initialize(); 13 | 14 | void benchmark_execute(benchmark_ctx* ctx); 15 | 16 | void benchmark_destroy(benchmark_ctx* ctx); 17 | 18 | #ifdef __cplusplus 19 | } 20 | #endif 21 | 22 | #endif // BENCHMARK_H 23 | -------------------------------------------------------------------------------- 
/benchmark/kernel.cu: -------------------------------------------------------------------------------- 1 | #include "kernel.h" 2 | 3 | #include 4 | #include 5 | 6 | __global__ void empty_kernel() 7 | { 8 | } 9 | 10 | void kernel_wrapper(cudaStream_t stream) 11 | { 12 | empty_kernel<<<1, 1, 0, stream>>>(); 13 | cudaError_t st = cudaStreamSynchronize(stream); 14 | if (st != cudaSuccess) 15 | throw std::invalid_argument("could not launch CUDA kernel"); 16 | } 17 | -------------------------------------------------------------------------------- /benchmark/kernel.h: -------------------------------------------------------------------------------- 1 | #ifndef KERNEL_H 2 | #define KERNEL_H 3 | 4 | void kernel_wrapper(cudaStream_t stream); 5 | 6 | #endif // KERNEL_H 7 | -------------------------------------------------------------------------------- /benchmark/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | // #cgo pkg-config: cudart-7.5 4 | // #cgo LDFLAGS: -L${SRCDIR} -lkernel 5 | // #cgo CXXFLAGS: -std=c++11 -I.. -O2 -fomit-frame-pointer -Wall 6 | // #include 7 | // #include "benchmark.h" 8 | import "C" 9 | 10 | import ( 11 | "log" 12 | "net/http" 13 | ) 14 | 15 | var ctx *C.benchmark_ctx 16 | 17 | func handleRequest(w http.ResponseWriter, r *http.Request) { 18 | _, err := C.benchmark_execute(ctx) 19 | if err != nil { 20 | http.Error(w, err.Error(), http.StatusInternalServerError) 21 | return 22 | } 23 | } 24 | 25 | func main() { 26 | log.Println("Initializing benchmark context") 27 | var err error 28 | ctx, err = C.benchmark_initialize() 29 | if err != nil { 30 | log.Fatalln("could not initialize benchmark context:", err) 31 | return 32 | } 33 | defer C.benchmark_destroy(ctx) 34 | 35 | log.Println("Adding REST endpoint /benchmark") 36 | http.HandleFunc("/benchmark", handleRequest) 37 | log.Println("Starting server listening on :8000") 38 | log.Fatal(http.ListenAndServe(":8000", nil)) 39 | } 40 | -------------------------------------------------------------------------------- /caffe/classification.cpp: -------------------------------------------------------------------------------- 1 | #include "classification.h" 2 | 3 | #include 4 | #include 5 | 6 | #define USE_CUDNN 1 7 | #include 8 | #include 9 | #include 10 | #include 11 | #include 12 | 13 | #include "common.h" 14 | #include "gpu_allocator.h" 15 | 16 | using namespace caffe; 17 | using std::string; 18 | using GpuMat = cv::cuda::GpuMat; 19 | using namespace cv; 20 | 21 | /* Pair (label, confidence) representing a prediction. */ 22 | typedef std::pair Prediction; 23 | 24 | /* Based on the cpp_classification example of Caffe, but with GPU 25 | * image preprocessing and a simple memory pool. 
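* The GPUAllocator passed to the constructor backs the temporary GpuMat buffers used during preprocessing; it is reset whenever an execution context is acquired, so per-request allocations are served from one preallocated block.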
*/ 26 | class Classifier 27 | { 28 | public: 29 | Classifier(const string& model_file, 30 | const string& trained_file, 31 | const string& mean_file, 32 | const string& label_file, 33 | GPUAllocator* allocator); 34 | 35 | std::vector Classify(const Mat& img, int N = 5); 36 | 37 | private: 38 | void SetMean(const string& mean_file); 39 | 40 | std::vector Predict(const Mat& img); 41 | 42 | void WrapInputLayer(std::vector* input_channels); 43 | 44 | void Preprocess(const Mat& img, 45 | std::vector* input_channels); 46 | 47 | private: 48 | GPUAllocator* allocator_; 49 | std::shared_ptr> net_; 50 | Size input_geometry_; 51 | int num_channels_; 52 | GpuMat mean_; 53 | std::vector labels_; 54 | }; 55 | 56 | Classifier::Classifier(const string& model_file, 57 | const string& trained_file, 58 | const string& mean_file, 59 | const string& label_file, 60 | GPUAllocator* allocator) 61 | : allocator_(allocator) 62 | { 63 | Caffe::set_mode(Caffe::GPU); 64 | 65 | /* Load the network. */ 66 | net_ = std::make_shared>(model_file, TEST); 67 | net_->CopyTrainedLayersFrom(trained_file); 68 | 69 | CHECK_EQ(net_->num_inputs(), 1) << "Network should have exactly one input."; 70 | CHECK_EQ(net_->num_outputs(), 1) << "Network should have exactly one output."; 71 | 72 | Blob* input_layer = net_->input_blobs()[0]; 73 | num_channels_ = input_layer->channels(); 74 | CHECK(num_channels_ == 3 || num_channels_ == 1) 75 | << "Input layer should have 1 or 3 channels."; 76 | input_geometry_ = Size(input_layer->width(), input_layer->height()); 77 | 78 | /* Load the binaryproto mean file. */ 79 | SetMean(mean_file); 80 | 81 | /* Load labels. */ 82 | std::ifstream labels(label_file.c_str()); 83 | CHECK(labels) << "Unable to open labels file " << label_file; 84 | string line; 85 | while (std::getline(labels, line)) 86 | labels_.push_back(string(line)); 87 | 88 | Blob* output_layer = net_->output_blobs()[0]; 89 | CHECK_EQ(labels_.size(), output_layer->channels()) 90 | << "Number of labels is different from the output layer dimension."; 91 | } 92 | 93 | static bool PairCompare(const std::pair& lhs, 94 | const std::pair& rhs) 95 | { 96 | return lhs.first > rhs.first; 97 | } 98 | 99 | /* Return the indices of the top N values of vector v. */ 100 | static std::vector Argmax(const std::vector& v, int N) 101 | { 102 | std::vector> pairs; 103 | for (size_t i = 0; i < v.size(); ++i) 104 | pairs.push_back(std::make_pair(v[i], i)); 105 | std::partial_sort(pairs.begin(), pairs.begin() + N, pairs.end(), PairCompare); 106 | 107 | std::vector result; 108 | for (int i = 0; i < N; ++i) 109 | result.push_back(pairs[i].second); 110 | return result; 111 | } 112 | 113 | /* Return the top N predictions. */ 114 | std::vector Classifier::Classify(const Mat& img, int N) 115 | { 116 | std::vector output = Predict(img); 117 | 118 | N = std::min(labels_.size(), N); 119 | std::vector maxN = Argmax(output, N); 120 | std::vector predictions; 121 | for (int i = 0; i < N; ++i) 122 | { 123 | int idx = maxN[i]; 124 | predictions.push_back(std::make_pair(labels_[idx], output[idx])); 125 | } 126 | 127 | return predictions; 128 | } 129 | 130 | /* Load the mean file in binaryproto format. 
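* The computed mean image is uploaded to GPU memory once, when the classifier is constructed, and reused for every request.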
*/ 131 | void Classifier::SetMean(const string& mean_file) 132 | { 133 | BlobProto blob_proto; 134 | ReadProtoFromBinaryFileOrDie(mean_file.c_str(), &blob_proto); 135 | 136 | /* Convert from BlobProto to Blob */ 137 | Blob mean_blob; 138 | mean_blob.FromProto(blob_proto); 139 | CHECK_EQ(mean_blob.channels(), num_channels_) 140 | << "Number of channels of mean file doesn't match input layer."; 141 | 142 | /* The format of the mean file is planar 32-bit float BGR or grayscale. */ 143 | std::vector channels; 144 | float* data = mean_blob.mutable_cpu_data(); 145 | for (int i = 0; i < num_channels_; ++i) 146 | { 147 | /* Extract an individual channel. */ 148 | Mat channel(mean_blob.height(), mean_blob.width(), CV_32FC1, data); 149 | channels.push_back(channel); 150 | data += mean_blob.height() * mean_blob.width(); 151 | } 152 | 153 | /* Merge the separate channels into a single image. */ 154 | Mat packed_mean; 155 | merge(channels, packed_mean); 156 | 157 | /* Compute the global mean pixel value and create a mean image 158 | * filled with this value. */ 159 | Scalar channel_mean = mean(packed_mean); 160 | Mat host_mean = Mat(input_geometry_, packed_mean.type(), channel_mean); 161 | mean_.upload(host_mean); 162 | } 163 | 164 | std::vector Classifier::Predict(const Mat& img) 165 | { 166 | Blob* input_layer = net_->input_blobs()[0]; 167 | input_layer->Reshape(1, num_channels_, 168 | input_geometry_.height, input_geometry_.width); 169 | /* Forward dimension change to all layers. */ 170 | net_->Reshape(); 171 | 172 | std::vector input_channels; 173 | WrapInputLayer(&input_channels); 174 | 175 | Preprocess(img, &input_channels); 176 | 177 | net_->Forward(); 178 | 179 | /* Copy the output layer to a std::vector */ 180 | Blob* output_layer = net_->output_blobs()[0]; 181 | const float* begin = output_layer->cpu_data(); 182 | const float* end = begin + output_layer->channels(); 183 | return std::vector(begin, end); 184 | } 185 | 186 | /* Wrap the input layer of the network in separate GpuMat objects 187 | * (one per channel). This way we save one memcpy operation and we 188 | * don't need to rely on cudaMemcpy2D. The last preprocessing 189 | * operation will write the separate channels directly to the input 190 | * layer. */ 191 | void Classifier::WrapInputLayer(std::vector* input_channels) 192 | { 193 | Blob* input_layer = net_->input_blobs()[0]; 194 | 195 | int width = input_layer->width(); 196 | int height = input_layer->height(); 197 | float* input_data = input_layer->mutable_gpu_data(); 198 | for (int i = 0; i < input_layer->channels(); ++i) 199 | { 200 | GpuMat channel(height, width, CV_32FC1, input_data); 201 | input_channels->push_back(channel); 202 | input_data += width * height; 203 | } 204 | } 205 | 206 | void Classifier::Preprocess(const Mat& host_img, 207 | std::vector* input_channels) 208 | { 209 | GpuMat img(host_img, allocator_); 210 | /* Convert the input image to the input image format of the network. 
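* Color conversion, resizing, conversion to float and mean subtraction below all run on the GPU.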
*/ 211 | GpuMat sample(allocator_); 212 | if (img.channels() == 3 && num_channels_ == 1) 213 | cuda::cvtColor(img, sample, CV_BGR2GRAY); 214 | else if (img.channels() == 4 && num_channels_ == 1) 215 | cuda::cvtColor(img, sample, CV_BGRA2GRAY); 216 | else if (img.channels() == 4 && num_channels_ == 3) 217 | cuda::cvtColor(img, sample, CV_BGRA2BGR); 218 | else if (img.channels() == 1 && num_channels_ == 3) 219 | cuda::cvtColor(img, sample, CV_GRAY2BGR); 220 | else 221 | sample = img; 222 | 223 | GpuMat sample_resized(allocator_); 224 | if (sample.size() != input_geometry_) 225 | cuda::resize(sample, sample_resized, input_geometry_); 226 | else 227 | sample_resized = sample; 228 | 229 | GpuMat sample_float(allocator_); 230 | if (num_channels_ == 3) 231 | sample_resized.convertTo(sample_float, CV_32FC3); 232 | else 233 | sample_resized.convertTo(sample_float, CV_32FC1); 234 | 235 | GpuMat sample_normalized(allocator_); 236 | cuda::subtract(sample_float, mean_, sample_normalized); 237 | 238 | /* This operation will write the separate BGR planes directly to the 239 | * input layer of the network because it is wrapped by the GpuMat 240 | * objects in input_channels. */ 241 | cuda::split(sample_normalized, *input_channels); 242 | 243 | CHECK(reinterpret_cast(input_channels->at(0).data) 244 | == net_->input_blobs()[0]->gpu_data()) 245 | << "Input channels are not wrapping the input layer of the network."; 246 | } 247 | 248 | /* By using Go as the HTTP server, we have potentially more CPU threads than 249 | * available GPUs and more threads can be added on the fly by the Go 250 | * runtime. Therefore we cannot pin the CPU threads to specific GPUs. Instead, 251 | * when a CPU thread is ready for inference it will try to retrieve an 252 | * execution context from a queue of available GPU contexts and then do a 253 | * cudaSetDevice() to prepare for execution. Multiple contexts can be allocated 254 | * per GPU. 
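* Each context owns its own Caffe instance, GPUAllocator and Classifier, so no mutable state is shared between request threads.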
*/ 255 | class ExecContext 256 | { 257 | public: 258 | friend ScopedContext; 259 | 260 | static bool IsCompatible(int device) 261 | { 262 | cudaError_t st = cudaSetDevice(device); 263 | if (st != cudaSuccess) 264 | return false; 265 | 266 | cuda::DeviceInfo info; 267 | if (!info.isCompatible()) 268 | return false; 269 | 270 | return true; 271 | } 272 | 273 | ExecContext(const string& model_file, 274 | const string& trained_file, 275 | const string& mean_file, 276 | const string& label_file, 277 | int device) 278 | : device_(device) 279 | { 280 | cudaError_t st = cudaSetDevice(device_); 281 | if (st != cudaSuccess) 282 | throw std::invalid_argument("could not set CUDA device"); 283 | 284 | allocator_.reset(new GPUAllocator(1024 * 1024 * 128)); 285 | caffe_context_.reset(new Caffe); 286 | Caffe::Set(caffe_context_.get()); 287 | classifier_.reset(new Classifier(model_file, trained_file, 288 | mean_file, label_file, 289 | allocator_.get())); 290 | Caffe::Set(nullptr); 291 | } 292 | 293 | Classifier* CaffeClassifier() 294 | { 295 | return classifier_.get(); 296 | } 297 | 298 | private: 299 | void Activate() 300 | { 301 | cudaError_t st = cudaSetDevice(device_); 302 | if (st != cudaSuccess) 303 | throw std::invalid_argument("could not set CUDA device"); 304 | allocator_->reset(); 305 | Caffe::Set(caffe_context_.get()); 306 | } 307 | 308 | void Deactivate() 309 | { 310 | Caffe::Set(nullptr); 311 | } 312 | 313 | private: 314 | int device_; 315 | std::unique_ptr allocator_; 316 | std::unique_ptr caffe_context_; 317 | std::unique_ptr classifier_; 318 | }; 319 | 320 | struct classifier_ctx 321 | { 322 | ContextPool pool; 323 | }; 324 | 325 | /* Currently, 2 execution contexts are created per GPU. In other words, 2 326 | * inference tasks can execute in parallel on the same GPU. This helps improve 327 | * GPU utilization since some kernel operations of inference will not fully use 328 | * the GPU. */ 329 | constexpr static int kContextsPerDevice = 2; 330 | 331 | classifier_ctx* classifier_initialize(char* model_file, char* trained_file, 332 | char* mean_file, char* label_file) 333 | { 334 | try 335 | { 336 | ::google::InitGoogleLogging("inference_server"); 337 | 338 | int device_count; 339 | cudaError_t st = cudaGetDeviceCount(&device_count); 340 | if (st != cudaSuccess) 341 | throw std::invalid_argument("could not list CUDA devices"); 342 | 343 | ContextPool pool; 344 | for (int dev = 0; dev < device_count; ++dev) 345 | { 346 | if (!ExecContext::IsCompatible(dev)) 347 | { 348 | LOG(ERROR) << "Skipping device: " << dev; 349 | continue; 350 | } 351 | 352 | for (int i = 0; i < kContextsPerDevice; ++i) 353 | { 354 | std::unique_ptr context(new ExecContext(model_file, trained_file, 355 | mean_file, label_file, dev)); 356 | pool.Push(std::move(context)); 357 | } 358 | } 359 | 360 | if (pool.Size() == 0) 361 | throw std::invalid_argument("no suitable CUDA device"); 362 | 363 | classifier_ctx* ctx = new classifier_ctx{std::move(pool)}; 364 | /* Successful CUDA calls can set errno. 
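* It is reset to 0 here so that the Go side, which inspects errno through cgo, does not report a spurious error after a successful initialization.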
*/ 365 | errno = 0; 366 | return ctx; 367 | } 368 | catch (const std::invalid_argument& ex) 369 | { 370 | LOG(ERROR) << "exception: " << ex.what(); 371 | errno = EINVAL; 372 | return nullptr; 373 | } 374 | } 375 | 376 | const char* classifier_classify(classifier_ctx* ctx, 377 | char* buffer, size_t length) 378 | { 379 | try 380 | { 381 | _InputArray array(buffer, length); 382 | 383 | Mat img = imdecode(array, -1); 384 | if (img.empty()) 385 | throw std::invalid_argument("could not decode image"); 386 | 387 | std::vector predictions; 388 | { 389 | /* In this scope an execution context is acquired for inference and it 390 | * will be automatically released back to the context pool when 391 | * exiting this scope. */ 392 | ScopedContext context(ctx->pool); 393 | auto classifier = context->CaffeClassifier(); 394 | predictions = classifier->Classify(img); 395 | } 396 | 397 | /* Write the top N predictions in JSON format. */ 398 | std::ostringstream os; 399 | os << "["; 400 | for (size_t i = 0; i < predictions.size(); ++i) 401 | { 402 | Prediction p = predictions[i]; 403 | os << "{\"confidence\":" << std::fixed << std::setprecision(4) 404 | << p.second << ","; 405 | os << "\"label\":" << "\"" << p.first << "\"" << "}"; 406 | if (i != predictions.size() - 1) 407 | os << ","; 408 | } 409 | os << "]"; 410 | 411 | errno = 0; 412 | std::string str = os.str(); 413 | return strdup(str.c_str()); 414 | } 415 | catch (const std::invalid_argument&) 416 | { 417 | errno = EINVAL; 418 | return nullptr; 419 | } 420 | } 421 | 422 | void classifier_destroy(classifier_ctx* ctx) 423 | { 424 | delete ctx; 425 | } 426 | -------------------------------------------------------------------------------- /caffe/classification.h: -------------------------------------------------------------------------------- 1 | #ifndef CLASSIFICATION_H 2 | #define CLASSIFICATION_H 3 | 4 | #ifdef __cplusplus 5 | extern "C" { 6 | #endif 7 | 8 | #include 9 | 10 | typedef struct classifier_ctx classifier_ctx; 11 | 12 | classifier_ctx* classifier_initialize(char* model_file, char* trained_file, 13 | char* mean_file, char* label_file); 14 | 15 | const char* classifier_classify(classifier_ctx* ctx, 16 | char* buffer, size_t length); 17 | 18 | void classifier_destroy(classifier_ctx* ctx); 19 | 20 | #ifdef __cplusplus 21 | } 22 | #endif 23 | 24 | #endif // CLASSIFICATION_H 25 | -------------------------------------------------------------------------------- /caffe/gpu_allocator.cpp: -------------------------------------------------------------------------------- 1 | #include "gpu_allocator.h" 2 | 3 | #include 4 | #include 5 | #include 6 | 7 | // Page offset on Kepler/Maxwell 8 | #define ALIGNMENT (128*1024) 9 | 10 | GPUAllocator::GPUAllocator(size_t size) 11 | : total_size_(size), 12 | current_size_(0) 13 | { 14 | cudaError_t rc = cudaMalloc(&base_ptr_, total_size_); 15 | if (rc != cudaSuccess) 16 | throw std::runtime_error("Could not allocate GPU memory"); 17 | 18 | current_ptr_ = base_ptr_; 19 | } 20 | 21 | GPUAllocator::~GPUAllocator() 22 | { 23 | cudaFree(base_ptr_); 24 | } 25 | 26 | static int align_up(unsigned int v, unsigned int alignment) 27 | { 28 | return ((v + alignment - 1) / alignment) * alignment; 29 | } 30 | 31 | cudaError_t GPUAllocator::grow(void** dev_ptr, size_t size) 32 | { 33 | if (current_size_ + size >= total_size_) 34 | return cudaErrorMemoryAllocation; 35 | 36 | *dev_ptr = current_ptr_; 37 | size_t aligned_size = align_up(size, ALIGNMENT); 38 | current_ptr_ = (char*)current_ptr_ + aligned_size; 39 | current_size_ += 
aligned_size; 40 | 41 | return cudaSuccess; 42 | } 43 | 44 | void GPUAllocator::reset() 45 | { 46 | current_ptr_ = base_ptr_; 47 | current_size_ = 0; 48 | } 49 | 50 | bool GPUAllocator::allocate(cv::cuda::GpuMat* mat, int rows, int cols, size_t elemSize) 51 | { 52 | int padded_width = align_up(cols, 16); 53 | int padded_height = align_up(rows, 16); 54 | int total_size = elemSize * padded_width * padded_height; 55 | 56 | cudaError_t status = grow((void**)&mat->data, total_size); 57 | if (status != cudaSuccess) 58 | return false; 59 | 60 | mat->step = padded_width * elemSize; 61 | mat->refcount = new int; 62 | 63 | return true; 64 | } 65 | 66 | void GPUAllocator::free(cv::cuda::GpuMat* mat) 67 | { 68 | delete mat->refcount; 69 | } 70 | -------------------------------------------------------------------------------- /caffe/gpu_allocator.h: -------------------------------------------------------------------------------- 1 | #ifndef GPU_ALLOCATOR_H 2 | #define GPU_ALLOCATOR_H 3 | 4 | #include 5 | #include 6 | #include 7 | 8 | using GpuMat = cv::cuda::GpuMat; 9 | using namespace cv; 10 | 11 | /* A simple linear allocator class to allocate storage for cv::GpuMat objects. 12 | This feature was added in OpenCV 3.0. */ 13 | class GPUAllocator : public GpuMat::Allocator 14 | { 15 | public: 16 | GPUAllocator(size_t size); 17 | 18 | ~GPUAllocator(); 19 | 20 | void reset(); 21 | 22 | public: /* GpuMat::Allocator interface */ 23 | bool allocate(GpuMat* mat, int rows, int cols, size_t elemSize); 24 | 25 | void free(GpuMat* mat); 26 | 27 | private: 28 | cudaError_t grow(void** dev_ptr, size_t size); 29 | 30 | private: 31 | void* base_ptr_; 32 | void* current_ptr_; 33 | size_t total_size_; 34 | size_t current_size_; 35 | }; 36 | 37 | #endif // GPU_ALLOCATOR_H 38 | -------------------------------------------------------------------------------- /caffe/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | // #cgo pkg-config: opencv cudart-9.0 4 | // #cgo LDFLAGS: -Lcaffe/lib -lcaffe -lglog -lboost_system -lboost_thread 5 | // #cgo CXXFLAGS: -std=c++11 -Icaffe/include -I.. 
-O2 -fomit-frame-pointer -Wall 6 | // #include <stdlib.h> 7 | // #include "classification.h" 8 | import "C" 9 | import "unsafe" 10 | 11 | import ( 12 | "io" 13 | "io/ioutil" 14 | "log" 15 | "net/http" 16 | "os" 17 | ) 18 | 19 | var ctx *C.classifier_ctx 20 | 21 | func classify(w http.ResponseWriter, r *http.Request) { 22 | if r.Method != "POST" { 23 | http.Error(w, "", http.StatusMethodNotAllowed) 24 | return 25 | } 26 | 27 | buffer, err := ioutil.ReadAll(r.Body) 28 | if err != nil { 29 | http.Error(w, err.Error(), http.StatusBadRequest) 30 | return 31 | } 32 | 33 | cstr, err := C.classifier_classify(ctx, (*C.char)(unsafe.Pointer(&buffer[0])), C.size_t(len(buffer))) 34 | if err != nil { 35 | http.Error(w, err.Error(), http.StatusBadRequest) 36 | return 37 | } 38 | defer C.free(unsafe.Pointer(cstr)) 39 | io.WriteString(w, C.GoString(cstr)) 40 | } 41 | 42 | func main() { 43 | cmodel := C.CString(os.Args[1]) 44 | ctrained := C.CString(os.Args[2]) 45 | cmean := C.CString(os.Args[3]) 46 | clabel := C.CString(os.Args[4]) 47 | 48 | log.Println("Initializing Caffe classifiers") 49 | var err error 50 | ctx, err = C.classifier_initialize(cmodel, ctrained, cmean, clabel) 51 | if err != nil { 52 | log.Fatalln("could not initialize classifier:", err) 53 | return 54 | } 55 | defer C.classifier_destroy(ctx) 56 | 57 | log.Println("Adding REST endpoint /api/classify") 58 | http.HandleFunc("/api/classify", classify) 59 | log.Println("Starting server listening on :8000") 60 | log.Fatal(http.ListenAndServe(":8000", nil)) 61 | } 62 | -------------------------------------------------------------------------------- /common.h: -------------------------------------------------------------------------------- 1 | #ifndef COMMON_H 2 | #define COMMON_H 3 | 4 | #include <condition_variable> 5 | #include <memory> 6 | #include <mutex> 7 | #include <queue> 8 | 9 | /* A simple threadsafe queue using a mutex and a condition variable. */ 10 | template <typename T> 11 | class Queue 12 | { 13 | public: 14 | Queue() = default; 15 | 16 | Queue(Queue&& other) 17 | { 18 | std::unique_lock<std::mutex> lock(other.mutex_); 19 | queue_ = std::move(other.queue_); 20 | } 21 | 22 | void Push(T value) 23 | { 24 | std::unique_lock<std::mutex> lock(mutex_); 25 | queue_.push(std::move(value)); 26 | lock.unlock(); 27 | cond_.notify_one(); 28 | } 29 | 30 | T Pop() 31 | { 32 | std::unique_lock<std::mutex> lock(mutex_); 33 | cond_.wait(lock, [this]{return !queue_.empty();}); 34 | T value = std::move(queue_.front()); 35 | queue_.pop(); 36 | return value; 37 | } 38 | 39 | size_t Size() 40 | { 41 | std::unique_lock<std::mutex> lock(mutex_); 42 | return queue_.size(); 43 | } 44 | 45 | private: 46 | mutable std::mutex mutex_; 47 | std::queue<T> queue_; 48 | std::condition_variable cond_; 49 | }; 50 | 51 | /* A pool of available contexts is simply implemented as a queue for our example. */ 52 | template <class Context> 53 | using ContextPool = Queue<std::unique_ptr<Context>>; 54 | 55 | /* A RAII class for acquiring an execution context from a context pool.
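* Acquiring a context blocks in Queue::Pop() until one is available, so the number of in-flight GPU requests is bounded by the size of the pool.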
*/ 56 | template 57 | class ScopedContext 58 | { 59 | public: 60 | explicit ScopedContext(ContextPool& pool) 61 | : pool_(pool), context_(pool_.Pop()) 62 | { 63 | context_->Activate(); 64 | } 65 | 66 | ~ScopedContext() 67 | { 68 | context_->Deactivate(); 69 | pool_.Push(std::move(context_)); 70 | } 71 | 72 | Context* operator->() const 73 | { 74 | return context_.get(); 75 | } 76 | 77 | private: 78 | ContextPool& pool_; 79 | std::unique_ptr context_; 80 | }; 81 | 82 | #endif // COMMON_H 83 | -------------------------------------------------------------------------------- /images/1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVIDIA/gpu-rest-engine/30da8c32d31786b1ed0981560f57f695086f1c93/images/1.jpg -------------------------------------------------------------------------------- /images/2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVIDIA/gpu-rest-engine/30da8c32d31786b1ed0981560f57f695086f1c93/images/2.jpg -------------------------------------------------------------------------------- /images/3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVIDIA/gpu-rest-engine/30da8c32d31786b1ed0981560f57f695086f1c93/images/3.jpg -------------------------------------------------------------------------------- /images/4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVIDIA/gpu-rest-engine/30da8c32d31786b1ed0981560f57f695086f1c93/images/4.jpg -------------------------------------------------------------------------------- /images/5.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVIDIA/gpu-rest-engine/30da8c32d31786b1ed0981560f57f695086f1c93/images/5.jpg -------------------------------------------------------------------------------- /images/6.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVIDIA/gpu-rest-engine/30da8c32d31786b1ed0981560f57f695086f1c93/images/6.jpg -------------------------------------------------------------------------------- /tensorrt/classification.cpp: -------------------------------------------------------------------------------- 1 | #include "classification.h" 2 | 3 | #include 4 | #include 5 | #include 6 | #include 7 | 8 | #include 9 | #include 10 | #include 11 | #include 12 | #include 13 | #include 14 | 15 | #include "NvInfer.h" 16 | #include "NvCaffeParser.h" 17 | 18 | #include "common.h" 19 | #include "gpu_allocator.h" 20 | 21 | using namespace nvinfer1; 22 | using namespace nvcaffeparser1; 23 | using std::string; 24 | using GpuMat = cuda::GpuMat; 25 | using namespace cv; 26 | 27 | /* Pair (label, confidence) representing a prediction. 
*/ 28 | typedef std::pair Prediction; 29 | 30 | class Logger : public ILogger 31 | { 32 | void log(Severity severity, const char* msg) override 33 | { 34 | // suppress info-level messages 35 | if (severity != Severity::kINFO) 36 | std::cout << msg << std::endl; 37 | } 38 | }; 39 | static Logger gLogger; 40 | 41 | class InferenceEngine 42 | { 43 | public: 44 | InferenceEngine(const string& model_file, 45 | const string& trained_file); 46 | 47 | ~InferenceEngine(); 48 | 49 | ICudaEngine* Get() const 50 | { 51 | return engine_; 52 | } 53 | 54 | private: 55 | ICudaEngine* engine_; 56 | }; 57 | 58 | InferenceEngine::InferenceEngine(const string& model_file, 59 | const string& trained_file) 60 | { 61 | IBuilder* builder = createInferBuilder(gLogger); 62 | 63 | // parse the caffe model to populate the network, then set the outputs 64 | INetworkDefinition* network = builder->createNetwork(); 65 | 66 | ICaffeParser* parser = createCaffeParser(); 67 | auto blob_name_to_tensor = parser->parse(model_file.c_str(), 68 | trained_file.c_str(), 69 | *network, 70 | nvinfer1::DataType::kFLOAT); 71 | CHECK(blob_name_to_tensor) << "Could not parse the model"; 72 | 73 | // specify which tensors are outputs 74 | network->markOutput(*blob_name_to_tensor->find("prob")); 75 | 76 | // Build the engine 77 | builder->setMaxBatchSize(1); 78 | builder->setMaxWorkspaceSize(1 << 30); 79 | 80 | engine_ = builder->buildCudaEngine(*network); 81 | CHECK(engine_) << "Failed to create inference engine."; 82 | 83 | network->destroy(); 84 | builder->destroy(); 85 | } 86 | 87 | InferenceEngine::~InferenceEngine() 88 | { 89 | engine_->destroy(); 90 | } 91 | 92 | class Classifier 93 | { 94 | public: 95 | Classifier(std::shared_ptr engine, 96 | const string& mean_file, 97 | const string& label_file, 98 | GPUAllocator* allocator); 99 | 100 | ~Classifier(); 101 | 102 | std::vector Classify(const Mat& img, int N = 5); 103 | 104 | private: 105 | void SetModel(); 106 | 107 | void SetMean(const string& mean_file); 108 | 109 | void SetLabels(const string& label_file); 110 | 111 | std::vector Predict(const Mat& img); 112 | 113 | void WrapInputLayer(std::vector* input_channels); 114 | 115 | void Preprocess(const Mat& img, 116 | std::vector* input_channels); 117 | 118 | private: 119 | GPUAllocator* allocator_; 120 | std::shared_ptr engine_; 121 | IExecutionContext* context_; 122 | GpuMat mean_; 123 | std::vector labels_; 124 | DimsCHW input_dim_; 125 | Size input_cv_size_; 126 | float* input_layer_; 127 | DimsCHW output_dim_; 128 | float* output_layer_; 129 | }; 130 | 131 | Classifier::Classifier(std::shared_ptr engine, 132 | const string& mean_file, 133 | const string& label_file, 134 | GPUAllocator* allocator) 135 | : allocator_(allocator), 136 | engine_(engine) 137 | { 138 | SetModel(); 139 | SetMean(mean_file); 140 | SetLabels(label_file); 141 | } 142 | 143 | Classifier::~Classifier() 144 | { 145 | context_->destroy(); 146 | CHECK_EQ(cudaFree(input_layer_), cudaSuccess) << "Could not free input layer"; 147 | CHECK_EQ(cudaFree(output_layer_), cudaSuccess) << "Could not free output layer"; 148 | } 149 | 150 | static bool PairCompare(const std::pair& lhs, 151 | const std::pair& rhs) 152 | { 153 | return lhs.first > rhs.first; 154 | } 155 | 156 | /* Return the indices of the top N values of vector v. 
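* Only the first N entries are ordered (std::partial_sort), which avoids sorting the whole vector.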
*/ 157 | static std::vector Argmax(const std::vector& v, int N) 158 | { 159 | std::vector> pairs; 160 | for (size_t i = 0; i < v.size(); ++i) 161 | pairs.push_back(std::make_pair(v[i], i)); 162 | std::partial_sort(pairs.begin(), pairs.begin() + N, pairs.end(), PairCompare); 163 | 164 | std::vector result; 165 | for (int i = 0; i < N; ++i) 166 | result.push_back(pairs[i].second); 167 | return result; 168 | } 169 | 170 | /* Return the top N predictions. */ 171 | std::vector Classifier::Classify(const Mat& img, int N) 172 | { 173 | std::vector output = Predict(img); 174 | 175 | std::vector maxN = Argmax(output, N); 176 | std::vector predictions; 177 | for (int i = 0; i < N; ++i) 178 | { 179 | int idx = maxN[i]; 180 | predictions.push_back(std::make_pair(labels_[idx], output[idx])); 181 | } 182 | 183 | return predictions; 184 | } 185 | 186 | void Classifier::SetModel() 187 | { 188 | ICudaEngine* engine = engine_->Get(); 189 | 190 | context_ = engine->createExecutionContext(); 191 | CHECK(context_) << "Failed to create execution context."; 192 | 193 | int input_index = engine->getBindingIndex("data"); 194 | input_dim_ = static_cast(engine->getBindingDimensions(input_index)); 195 | input_cv_size_ = Size(input_dim_.w(), input_dim_.h()); 196 | // FIXME: could be wrapped in a thrust or GpuMat object. 197 | size_t input_size = input_dim_.c() * input_dim_.h() * input_dim_.w() * sizeof(float); 198 | cudaError_t st = cudaMalloc(&input_layer_, input_size); 199 | CHECK_EQ(st, cudaSuccess) << "Could not allocate input layer."; 200 | 201 | int output_index = engine->getBindingIndex("prob"); 202 | output_dim_ = static_cast(engine->getBindingDimensions(output_index)); 203 | size_t output_size = output_dim_.c() * output_dim_.h() * output_dim_.w() * sizeof(float); 204 | st = cudaMalloc(&output_layer_, output_size); 205 | CHECK_EQ(st, cudaSuccess) << "Could not allocate output layer."; 206 | } 207 | 208 | void Classifier::SetMean(const string& mean_file) 209 | { 210 | ICaffeParser* parser = createCaffeParser(); 211 | IBinaryProtoBlob* mean_blob = parser->parseBinaryProto(mean_file.c_str()); 212 | parser->destroy(); 213 | CHECK(mean_blob) << "Could not load mean file."; 214 | 215 | DimsNCHW mean_dim = mean_blob->getDimensions(); 216 | int c = mean_dim.c(); 217 | int h = mean_dim.h(); 218 | int w = mean_dim.w(); 219 | CHECK_EQ(c, input_dim_.c()) 220 | << "Number of channels of mean file doesn't match input layer."; 221 | 222 | /* The format of the mean file is planar 32-bit float BGR or grayscale. */ 223 | std::vector channels; 224 | float* data = (float*)mean_blob->getData(); 225 | for (int i = 0; i < c; ++i) 226 | { 227 | /* Extract an individual channel. */ 228 | Mat channel(h, w, CV_32FC1, data); 229 | channels.push_back(channel); 230 | data += h * w; 231 | } 232 | 233 | /* Merge the separate channels into a single image. */ 234 | Mat packed_mean; 235 | merge(channels, packed_mean); 236 | 237 | /* Compute the global mean pixel value and create a mean image 238 | * filled with this value. 
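* Subtracting this uniform image is equivalent to subtracting one constant value per channel.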
*/ 239 | Scalar channel_mean = mean(packed_mean); 240 | Mat host_mean = Mat(input_cv_size_, packed_mean.type(), channel_mean); 241 | mean_.upload(host_mean); 242 | } 243 | 244 | void Classifier::SetLabels(const string& label_file) 245 | { 246 | std::ifstream labels(label_file.c_str()); 247 | CHECK(labels) << "Unable to open labels file " << label_file; 248 | string line; 249 | while (std::getline(labels, line)) 250 | labels_.push_back(string(line)); 251 | } 252 | 253 | std::vector Classifier::Predict(const Mat& img) 254 | { 255 | std::vector input_channels; 256 | WrapInputLayer(&input_channels); 257 | 258 | Preprocess(img, &input_channels); 259 | 260 | void* buffers[2] = { input_layer_, output_layer_ }; 261 | context_->execute(1, buffers); 262 | 263 | size_t output_size = output_dim_.c() * output_dim_.h() * output_dim_.w(); 264 | std::vector output(output_size); 265 | cudaError_t st = cudaMemcpy(output.data(), output_layer_, output_size * sizeof(float), cudaMemcpyDeviceToHost); 266 | if (st != cudaSuccess) 267 | throw std::runtime_error("could not copy output layer back to host"); 268 | 269 | return output; 270 | } 271 | 272 | /* Wrap the input layer of the network in separate Mat objects 273 | * (one per channel). This way we save one memcpy operation and we 274 | * don't need to rely on cudaMemcpy2D. The last preprocessing 275 | * operation will write the separate channels directly to the input 276 | * layer. */ 277 | void Classifier::WrapInputLayer(std::vector* input_channels) 278 | { 279 | int width = input_dim_.w(); 280 | int height = input_dim_.h(); 281 | float* input_data = input_layer_; 282 | for (int i = 0; i < input_dim_.c(); ++i) 283 | { 284 | GpuMat channel(height, width, CV_32FC1, input_data); 285 | input_channels->push_back(channel); 286 | input_data += width * height; 287 | } 288 | } 289 | 290 | void Classifier::Preprocess(const Mat& host_img, 291 | std::vector* input_channels) 292 | { 293 | int num_channels = input_dim_.c(); 294 | GpuMat img(host_img, allocator_); 295 | /* Convert the input image to the input image format of the network. */ 296 | GpuMat sample(allocator_); 297 | if (img.channels() == 3 && num_channels == 1) 298 | cuda::cvtColor(img, sample, CV_BGR2GRAY); 299 | else if (img.channels() == 4 && num_channels == 1) 300 | cuda::cvtColor(img, sample, CV_BGRA2GRAY); 301 | else if (img.channels() == 4 && num_channels == 3) 302 | cuda::cvtColor(img, sample, CV_BGRA2BGR); 303 | else if (img.channels() == 1 && num_channels == 3) 304 | cuda::cvtColor(img, sample, CV_GRAY2BGR); 305 | else 306 | sample = img; 307 | 308 | GpuMat sample_resized(allocator_); 309 | if (sample.size() != input_cv_size_) 310 | cuda::resize(sample, sample_resized, input_cv_size_); 311 | else 312 | sample_resized = sample; 313 | 314 | GpuMat sample_float(allocator_); 315 | if (num_channels == 3) 316 | sample_resized.convertTo(sample_float, CV_32FC3); 317 | else 318 | sample_resized.convertTo(sample_float, CV_32FC1); 319 | 320 | GpuMat sample_normalized(allocator_); 321 | cuda::subtract(sample_float, mean_, sample_normalized); 322 | 323 | /* This operation will write the separate BGR planes directly to the 324 | * input layer of the network because it is wrapped by the Mat 325 | * objects in input_channels. 
326 |     cuda::split(sample_normalized, *input_channels);
327 | 
328 |     CHECK(reinterpret_cast<float*>(input_channels->at(0).data) == input_layer_)
329 |         << "Input channels are not wrapping the input layer of the network.";
330 | }
331 | 
332 | /* By using Go as the HTTP server, we have potentially more CPU threads than
333 |  * available GPUs and more threads can be added on the fly by the Go
334 |  * runtime. Therefore we cannot pin the CPU threads to specific GPUs. Instead,
335 |  * when a CPU thread is ready for inference it will try to retrieve an
336 |  * execution context from a queue of available GPU contexts and then do a
337 |  * cudaSetDevice() to prepare for execution. Multiple contexts can be allocated
338 |  * per GPU. */
339 | class ExecContext
340 | {
341 | public:
342 |     friend ScopedContext<ExecContext>;
343 | 
344 |     static bool IsCompatible(int device)
345 |     {
346 |         cudaError_t st = cudaSetDevice(device);
347 |         if (st != cudaSuccess)
348 |             return false;
349 | 
350 |         cuda::DeviceInfo dev_info;
351 |         if (dev_info.majorVersion() < 3)
352 |             return false;
353 | 
354 |         return true;
355 |     }
356 | 
357 |     ExecContext(std::shared_ptr<InferenceEngine> engine,
358 |                 const string& mean_file,
359 |                 const string& label_file,
360 |                 int device)
361 |         : device_(device)
362 |     {
363 |         cudaError_t st = cudaSetDevice(device_);
364 | 
365 |         if (st != cudaSuccess)
366 |             throw std::invalid_argument("could not set CUDA device");
367 | 
368 |         allocator_.reset(new GPUAllocator(1024 * 1024 * 128));
369 |         classifier_.reset(new Classifier(engine, mean_file, label_file, allocator_.get()));
370 |     }
371 | 
372 |     Classifier* TensorRTClassifier()
373 |     {
374 |         return classifier_.get();
375 |     }
376 | 
377 | private:
378 |     void Activate()
379 |     {
380 |         cudaError_t st = cudaSetDevice(device_);
381 |         if (st != cudaSuccess)
382 |             throw std::invalid_argument("could not set CUDA device");
383 |         allocator_->reset();
384 |     }
385 | 
386 |     void Deactivate()
387 |     {
388 |     }
389 | 
390 | private:
391 |     int device_;
392 |     std::unique_ptr<GPUAllocator> allocator_;
393 |     std::unique_ptr<Classifier> classifier_;
394 | };
395 | 
396 | struct classifier_ctx
397 | {
398 |     ContextPool<ExecContext> pool;
399 | };
400 | 
401 | constexpr static int kContextsPerDevice = 2;
402 | 
403 | classifier_ctx* classifier_initialize(char* model_file, char* trained_file,
404 |                                       char* mean_file, char* label_file)
405 | {
406 |     try
407 |     {
408 |         ::google::InitGoogleLogging("inference_server");
409 | 
410 |         int device_count;
411 |         cudaError_t st = cudaGetDeviceCount(&device_count);
412 |         if (st != cudaSuccess)
413 |             throw std::invalid_argument("could not list CUDA devices");
414 | 
415 |         ContextPool<ExecContext> pool;
416 |         for (int dev = 0; dev < device_count; ++dev)
417 |         {
418 |             if (!ExecContext::IsCompatible(dev))
419 |             {
420 |                 LOG(ERROR) << "Skipping device: " << dev;
421 |                 continue;
422 |             }
423 | 
424 |             std::shared_ptr<InferenceEngine> engine(new InferenceEngine(model_file, trained_file));
425 | 
426 |             for (int i = 0; i < kContextsPerDevice; ++i)
427 |             {
428 |                 std::unique_ptr<ExecContext> context(new ExecContext(engine, mean_file,
429 |                                                                      label_file, dev));
430 |                 pool.Push(std::move(context));
431 |             }
432 |         }
433 | 
434 |         if (pool.Size() == 0)
435 |             throw std::invalid_argument("no suitable CUDA device");
436 | 
437 |         classifier_ctx* ctx = new classifier_ctx{std::move(pool)};
438 |         /* Successful CUDA calls can set errno. */
439 |         errno = 0;
440 |         return ctx;
441 |     }
442 |     catch (const std::invalid_argument& ex)
443 |     {
444 |         LOG(ERROR) << "exception: " << ex.what();
445 |         errno = EINVAL;
446 |         return nullptr;
447 |     }
448 | }
449 | 
450 | const char* classifier_classify(classifier_ctx* ctx,
451 |                                 char* buffer, size_t length)
452 | {
453 |     try
454 |     {
455 |         _InputArray array(buffer, length);
456 | 
457 |         Mat img = imdecode(array, -1);
458 |         if (img.empty())
459 |             throw std::invalid_argument("could not decode image");
460 | 
461 |         std::vector<Prediction> predictions;
462 |         {
463 |             /* In this scope an execution context is acquired for inference and it
464 |              * will be automatically released back to the context pool when
465 |              * exiting this scope. */
466 |             ScopedContext<ExecContext> context(ctx->pool);
467 |             auto classifier = context->TensorRTClassifier();
468 |             predictions = classifier->Classify(img);
469 |         }
470 | 
471 |         /* Write the top N predictions in JSON format. */
472 |         std::ostringstream os;
473 |         os << "[";
474 |         for (size_t i = 0; i < predictions.size(); ++i)
475 |         {
476 |             Prediction p = predictions[i];
477 |             os << "{\"confidence\":" << std::fixed << std::setprecision(4)
478 |                << p.second << ",";
479 |             os << "\"label\":" << "\"" << p.first << "\"" << "}";
480 |             if (i != predictions.size() - 1)
481 |                 os << ",";
482 |         }
483 |         os << "]";
484 | 
485 |         errno = 0;
486 |         std::string str = os.str();
487 |         return strdup(str.c_str());
488 |     }
489 |     catch (const std::invalid_argument&)
490 |     {
491 |         errno = EINVAL;
492 |         return nullptr;
493 |     }
494 | }
495 | 
496 | void classifier_destroy(classifier_ctx* ctx)
497 | {
498 |     delete ctx;
499 | }
500 | 
--------------------------------------------------------------------------------
/tensorrt/classification.h:
--------------------------------------------------------------------------------
1 | #ifndef CLASSIFICATION_H
2 | #define CLASSIFICATION_H
3 | 
4 | #ifdef __cplusplus
5 | extern "C" {
6 | #endif
7 | 
8 | #include <stddef.h>
9 | 
10 | typedef struct classifier_ctx classifier_ctx;
11 | 
12 | classifier_ctx* classifier_initialize(char* model_file, char* trained_file,
13 |                                       char* mean_file, char* label_file);
14 | 
15 | const char* classifier_classify(classifier_ctx* ctx,
16 |                                 char* buffer, size_t length);
17 | 
18 | void classifier_destroy(classifier_ctx* ctx);
19 | 
20 | #ifdef __cplusplus
21 | }
22 | #endif
23 | 
24 | #endif
25 | 
--------------------------------------------------------------------------------
/tensorrt/gpu_allocator.cpp:
--------------------------------------------------------------------------------
1 | #include "gpu_allocator.h"
2 | 
3 | #include <cstddef>
4 | #include <cuda_runtime.h>
5 | #include <stdexcept>
6 | 
7 | // Page offset on Kepler/Maxwell
8 | #define ALIGNMENT (128*1024)
9 | 
10 | GPUAllocator::GPUAllocator(size_t size)
11 |     : total_size_(size),
12 |       current_size_(0)
13 | {
14 |     cudaError_t rc = cudaMalloc(&base_ptr_, total_size_);
15 |     if (rc != cudaSuccess)
16 |         throw std::runtime_error("Could not allocate GPU memory");
17 | 
18 |     current_ptr_ = base_ptr_;
19 | }
20 | 
21 | GPUAllocator::~GPUAllocator()
22 | {
23 |     cudaFree(base_ptr_);
24 | }
25 | 
26 | static int align_up(unsigned int v, unsigned int alignment)
27 | {
28 |     return ((v + alignment - 1) / alignment) * alignment;
29 | }
30 | 
31 | cudaError_t GPUAllocator::grow(void** dev_ptr, size_t size)
32 | {
33 |     if (current_size_ + size >= total_size_)
34 |         return cudaErrorMemoryAllocation;
35 | 
36 |     *dev_ptr = current_ptr_;
37 |     size_t aligned_size = align_up(size, ALIGNMENT);
38 |     current_ptr_ = (char*)current_ptr_ + aligned_size;
39 |     current_size_ += aligned_size;
40 | 
41 |     return cudaSuccess;
42 | }
43 | 
44 | void GPUAllocator::reset()
45 | {
46 |     current_ptr_ = base_ptr_;
47 |     current_size_ = 0;
48 | }
49 | 
50 | bool GPUAllocator::allocate(cv::cuda::GpuMat* mat, int rows, int cols, size_t elemSize)
51 | {
52 |     int padded_width = align_up(cols, 16);
53 |     int padded_height = align_up(rows, 16);
54 |     int total_size = elemSize * padded_width * padded_height;
55 | 
56 |     cudaError_t status = grow((void**)&mat->data, total_size);
57 |     if (status != cudaSuccess)
58 |         return false;
59 | 
60 |     mat->step = padded_width * elemSize;
61 |     mat->refcount = new int;
62 | 
63 |     return true;
64 | }
65 | 
66 | void GPUAllocator::free(cv::cuda::GpuMat* mat)
67 | {
68 |     delete mat->refcount;
69 | }
70 | 
--------------------------------------------------------------------------------
/tensorrt/gpu_allocator.h:
--------------------------------------------------------------------------------
1 | #ifndef GPU_ALLOCATOR_H
2 | #define GPU_ALLOCATOR_H
3 | 
4 | #include <cuda_runtime.h>
5 | #include <opencv2/core/core.hpp>
6 | #include <opencv2/core/cuda.hpp>
7 | 
8 | using GpuMat = cv::cuda::GpuMat;
9 | using namespace cv;
10 | 
11 | /* A simple linear allocator class to allocate storage for cv::GpuMat objects.
12 |    This feature was added in OpenCV 3.0. */
13 | class GPUAllocator : public GpuMat::Allocator
14 | {
15 | public:
16 |     GPUAllocator(size_t size);
17 | 
18 |     ~GPUAllocator();
19 | 
20 |     void reset();
21 | 
22 | public: /* GpuMat::Allocator interface */
23 |     bool allocate(GpuMat* mat, int rows, int cols, size_t elemSize);
24 | 
25 |     void free(GpuMat* mat);
26 | 
27 | private:
28 |     cudaError_t grow(void** dev_ptr, size_t size);
29 | 
30 | private:
31 |     void* base_ptr_;
32 |     void* current_ptr_;
33 |     size_t total_size_;
34 |     size_t current_size_;
35 | };
36 | 
37 | #endif // GPU_ALLOCATOR_H
38 | 
--------------------------------------------------------------------------------
/tensorrt/main.go:
--------------------------------------------------------------------------------
1 | package main
2 | 
3 | // #cgo pkg-config: opencv cudart-8.0
4 | // #cgo LDFLAGS: -lnvinfer -lnvcaffe_parser -lglog -lboost_system -lboost_thread
5 | // #cgo CXXFLAGS: -std=c++11 -I.. -O2 -fomit-frame-pointer -Wall
6 | // #include <stdlib.h>
7 | // #include "classification.h"
8 | import "C"
9 | import "unsafe"
10 | 
11 | import (
12 | 	"io"
13 | 	"io/ioutil"
14 | 	"log"
15 | 	"net/http"
16 | 	"os"
17 | )
18 | 
19 | var ctx *C.classifier_ctx
20 | 
21 | func classify(w http.ResponseWriter, r *http.Request) {
22 | 	if r.Method != "POST" {
23 | 		http.Error(w, "", http.StatusMethodNotAllowed)
24 | 		return
25 | 	}
26 | 
27 | 	buffer, err := ioutil.ReadAll(r.Body)
28 | 	if err != nil {
29 | 		http.Error(w, err.Error(), http.StatusBadRequest)
30 | 		return
31 | 	}
32 | 
33 | 	cstr, err := C.classifier_classify(ctx, (*C.char)(unsafe.Pointer(&buffer[0])), C.size_t(len(buffer)))
34 | 	if err != nil {
35 | 		http.Error(w, err.Error(), http.StatusBadRequest)
36 | 		return
37 | 	}
38 | 	defer C.free(unsafe.Pointer(cstr))
39 | 	io.WriteString(w, C.GoString(cstr))
40 | }
41 | 
42 | func main() {
43 | 	cmodel := C.CString(os.Args[1])
44 | 	ctrained := C.CString(os.Args[2])
45 | 	cmean := C.CString(os.Args[3])
46 | 	clabel := C.CString(os.Args[4])
47 | 
48 | 	log.Println("Initializing TensorRT classifiers")
49 | 	var err error
50 | 	ctx, err = C.classifier_initialize(cmodel, ctrained, cmean, clabel)
51 | 	if err != nil {
52 | 		log.Fatalln("could not initialize classifier:", err)
53 | 		return
54 | 	}
55 | 	defer C.classifier_destroy(ctx)
56 | 
57 | 	log.Println("Adding REST endpoint /api/classify")
58 | 	http.HandleFunc("/api/classify", classify)
59 | 	log.Println("Starting server listening on :8000")
60 | 	log.Fatal(http.ListenAndServe(":8000", nil))
61 | }
62 | 
--------------------------------------------------------------------------------
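
Example usage (not part of the repository): the handler in main.go reads the raw request body, so any HTTP client that POSTs image bytes to /api/classify works; no multipart encoding is needed. The sketch below is a minimal Go client that sends one of the sample images from the images/ directory to a server assumed to be running locally on port 8000 (the default in main.go) and prints the JSON predictions. The image path and host are illustrative assumptions.

package main

import (
	"bytes"
	"fmt"
	"io/ioutil"
	"log"
	"net/http"
)

func main() {
	// Read a sample image shipped with the repository; replace the path as needed.
	img, err := ioutil.ReadFile("images/1.jpg")
	if err != nil {
		log.Fatal(err)
	}

	// POST the raw image bytes to the classification endpoint exposed by main.go.
	resp, err := http.Post("http://localhost:8000/api/classify", "application/octet-stream", bytes.NewReader(img))
	if err != nil {
		log.Fatal(err)
	}
	defer resp.Body.Close()

	// The server replies with a JSON array of {"confidence": ..., "label": ...} objects.
	body, err := ioutil.ReadAll(resp.Body)
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println(string(body))
}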