├── .gitignore ├── README.md ├── docker ├── cpu │ └── Dockerfile └── cuda10 │ └── Dockerfile ├── inference-cpp └── cnn-classification │ ├── CMakeLists.txt │ ├── build.sh │ ├── image.jpeg │ ├── infer.cc │ ├── infer.h │ ├── main.cc │ └── server │ ├── CMakeLists.txt │ ├── base64.cc │ ├── base64.h │ ├── build.sh │ ├── crow_all.h │ ├── main.cc │ └── test_api.py ├── models └── resnet │ ├── labels.txt │ └── resnet.py └── utils ├── opencvutils.cc ├── opencvutils.h ├── torchutils.cc └── torchutils.h /.gitignore: -------------------------------------------------------------------------------- 1 | # Pytorch models 2 | *.pth 3 | 4 | # Libtorch 5 | libtorch 6 | *.zip 7 | 8 | # Ignore binaries 9 | predict 10 | 11 | # Byte-compiled / optimized / DLL files 12 | __pycache__/ 13 | *.py[cod] 14 | *$py.class 15 | 16 | # C extensions 17 | *.so 18 | 19 | # Distribution / packaging 20 | .Python 21 | build/ 22 | develop-eggs/ 23 | dist/ 24 | downloads/ 25 | eggs/ 26 | .eggs/ 27 | lib/ 28 | lib64/ 29 | parts/ 30 | sdist/ 31 | var/ 32 | wheels/ 33 | *.egg-info/ 34 | .installed.cfg 35 | *.egg 36 | MANIFEST 37 | 38 | # PyInstaller 39 | # Usually these files are written by a python script from a template 40 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 41 | *.manifest 42 | *.spec 43 | 44 | # Installer logs 45 | pip-log.txt 46 | pip-delete-this-directory.txt 47 | 48 | # Unit test / coverage reports 49 | htmlcov/ 50 | .tox/ 51 | .nox/ 52 | .coverage 53 | .coverage.* 54 | .cache 55 | nosetests.xml 56 | coverage.xml 57 | *.cover 58 | .hypothesis/ 59 | .pytest_cache/ 60 | 61 | # Translations 62 | *.mo 63 | *.pot 64 | 65 | # Django stuff: 66 | *.log 67 | local_settings.py 68 | db.sqlite3 69 | 70 | # Flask stuff: 71 | instance/ 72 | .webassets-cache 73 | 74 | # Scrapy stuff: 75 | .scrapy 76 | 77 | # Sphinx documentation 78 | docs/_build/ 79 | 80 | # PyBuilder 81 | target/ 82 | 83 | # Jupyter Notebook 84 | .ipynb_checkpoints 85 | 86 | # IPython 87 | profile_default/ 88 | ipython_config.py 89 | 90 | # pyenv 91 | .python-version 92 | 93 | # celery beat schedule file 94 | celerybeat-schedule 95 | 96 | # SageMath parsed files 97 | *.sage.py 98 | 99 | # Environments 100 | .env 101 | .venv 102 | env/ 103 | venv/ 104 | ENV/ 105 | env.bak/ 106 | venv.bak/ 107 | 108 | # Spyder project settings 109 | .spyderproject 110 | .spyproject 111 | 112 | # Rope project settings 113 | .ropeproject 114 | 115 | # mkdocs documentation 116 | /site 117 | 118 | # mypy 119 | .mypy_cache/ 120 | .dmypy.json 121 | dmypy.json 122 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Serving PyTorch Models in C++ 2 | 3 | * This repository contains various examples to perform inference using PyTorch C++ API. 4 | * Run `git clone https://github.com/Wizaron/pytorch-cpp-inference` in order to clone this repository. 5 | 6 | ## Environment 7 | 8 | 1. Dockerfiles can be found at `docker` directory. There are two dockerfiles; one for cpu and the other for cuda10. In order to build docker image, you should go to `docker/cpu` or `docker/cuda10` directory and run `docker build -t .`. 9 | 2. After creation of the docker image, you should create a docker container via `docker run -v : -p 8181:8181 -it ` (We will use 8181 to serve our PyTorch C++ model). 10 | 3. Inside docker container, go to the directory that this repository resides. 11 | 4. Download `libtorch` from [PyTorch Website](https://pytorch.org/get-started/locally/) (CPU : `https://download.pytorch.org/libtorch/cpu/libtorch-cxx11-abi-shared-with-deps-1.3.1%2Bcpu.zip` - CUDA10 : `https://download.pytorch.org/libtorch/cu101/libtorch-cxx11-abi-shared-with-deps-1.3.1.zip`). 12 | 5. Unzip libtorch via `unzip`. This will create `libtorch` directory that contains torch shared libraries and headers. 13 | 14 | ## Code Structure 15 | 16 | * `models` directory stores PyTorch models. 17 | * `libtorch` directory stores C++ torch headers and shared libraries to link the model against PyTorch. 18 | * `utils` directory stores various utility function to perform inference in C++. 19 | * `inference-cpp` directory stores codes to perform inference. 20 | 21 | ## Exporting PyTorch ScriptModule 22 | 23 | * In order to export `torch.jit.ScriptModule` of ResNet18 to perform C++ inference, go to `models/resnet` directory and run `python3 resnet.py`. It will download pretrained ResNet18 model on ImageNet and create `models/resnet_model_cpu.pth` and (optionally) `models/resnet_model_gpu.pth` which we will use in C++ inference. 24 | 25 | ## Serving the C++ Model 26 | 27 | * We can either serve the model as a single executable or as a web server. 28 | 29 | ### Single Executable 30 | 31 | * In order to build a single executable for inference: 32 | 1. Go to `inference-cpp/cnn-classification` directory. 33 | 2. Run `./build.sh` in order to build executable, named as `predict`. 34 | 3. Run the executable via `./predict `. 35 | 4. Example: `./predict image.jpeg ../../models/resnet/resnet_model_cpu.pth ../../models/resnet/labels.txt false` 36 | 37 | ### Web Server 38 | 39 | * In order to build a web server for production: 40 | 1. Go to `inference-cpp/cnn-classification/server` directory. 41 | 2. Run `./build.sh` in order to build web server, named as `predict`. 42 | 3. Run the binary via `./predict ` (It will serve the model on `http://localhost:8181/predict`). 43 | 4. Example: `./predict ../../../models/resnet/resnet_model_cpu.pth ../../../models/resnet/labels.txt false` 44 | 4. In order to make a request, open a new tab and run `python test_api.py` (It will make a request to `localhost:8181/predict`). 45 | 46 | ## Acknowledgement 47 | 48 | 1. [pytorch](https://pytorch.org) 49 | 2. [crow](https://github.com/ipkn/crow) 50 | 3. [tensorflow_cpp_object_detection_web_server](https://github.com/CasiaFan/tensorflow_cpp_object_detection_web_server) 51 | -------------------------------------------------------------------------------- /docker/cpu/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM ubuntu:18.04 2 | RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y \ 3 | g++ \ 4 | make \ 5 | cmake \ 6 | wget \ 7 | unzip \ 8 | vim \ 9 | git \ 10 | libopencv-dev \ 11 | libboost-all-dev \ 12 | python3 \ 13 | python3-pip 14 | 15 | RUN pip3 install numpy "pillow<7" 16 | RUN pip3 install torch==1.3.1+cpu torchvision==0.4.2+cpu -f https://download.pytorch.org/whl/torch_stable.html 17 | -------------------------------------------------------------------------------- /docker/cuda10/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM nvidia/cuda:10.0-cudnn7-devel-ubuntu16.04 2 | 3 | RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y \ 4 | g++ \ 5 | make \ 6 | cmake \ 7 | wget \ 8 | unzip \ 9 | vim \ 10 | git \ 11 | libopencv-dev \ 12 | libboost-all-dev \ 13 | python3 \ 14 | python3-pip 15 | 16 | RUN pip3 install numpy "pillow<7" 17 | RUN pip3 install torch torchvision 18 | -------------------------------------------------------------------------------- /inference-cpp/cnn-classification/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.0 FATAL_ERROR) 2 | project(cnn-classification) 3 | 4 | find_package(Torch REQUIRED) 5 | find_package(OpenCV REQUIRED) 6 | 7 | add_executable(cnn-inference main.cc infer.cc ../../utils/opencvutils.cc ../../utils/torchutils.cc) 8 | 9 | target_link_libraries(cnn-inference "${TORCH_LIBRARIES}") 10 | target_link_libraries(cnn-inference "${OpenCV_LIBS}") 11 | 12 | set_property(TARGET cnn-inference PROPERTY CXX_STANDARD 11) 13 | set_property(TARGET cnn-inference PROPERTY OUTPUT_NAME predict) 14 | -------------------------------------------------------------------------------- /inference-cpp/cnn-classification/build.sh: -------------------------------------------------------------------------------- 1 | rm -rf build 2 | 3 | mkdir -p build 4 | cd build 5 | cmake -DCMAKE_PREFIX_PATH=../../libtorch .. 6 | make -j4 7 | cd .. 8 | 9 | mv build/predict . 10 | 11 | rm -rf build 12 | -------------------------------------------------------------------------------- /inference-cpp/cnn-classification/image.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Wizaron/pytorch-cpp-inference/47859e695a29f75952f1290ce4e4318a6f58e6d1/inference-cpp/cnn-classification/image.jpeg -------------------------------------------------------------------------------- /inference-cpp/cnn-classification/infer.cc: -------------------------------------------------------------------------------- 1 | #include "infer.h" 2 | 3 | std::tuple infer( 4 | cv::Mat image, 5 | int image_height, int image_width, 6 | std::vector mean, std::vector std, 7 | std::vector labels, 8 | torch::jit::script::Module model, 9 | bool usegpu) { 10 | 11 | if (image.empty()) { 12 | std::cout << "WARNING: Cannot read image!" << std::endl; 13 | } 14 | 15 | std::string pred = ""; 16 | std::string prob = "0.0"; 17 | 18 | // Predict if image is not empty 19 | if (!image.empty()) { 20 | 21 | // Preprocess image 22 | image = preprocess(image, image_height, image_width, 23 | mean, std); 24 | 25 | // Forward 26 | std::vector probs = forward({image, }, model, usegpu); 27 | 28 | // Postprocess 29 | tie(pred, prob) = postprocess(probs, labels); 30 | } 31 | 32 | return std::make_tuple(pred, prob); 33 | } 34 | -------------------------------------------------------------------------------- /inference-cpp/cnn-classification/infer.h: -------------------------------------------------------------------------------- 1 | #ifndef INFER_H // To make sure you don't declare the function more than once by including the header multiple times. 2 | #define INFER_H 3 | 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include 11 | #include 12 | 13 | #include 14 | #include 15 | #include 16 | 17 | #include 18 | #include 19 | #include 20 | 21 | #include "../../utils/torchutils.h" 22 | #include "../../utils/opencvutils.h" 23 | 24 | std::tuple infer( 25 | cv::Mat, 26 | int, int, 27 | std::vector, std::vector, 28 | std::vector, 29 | torch::jit::script::Module, 30 | bool); 31 | 32 | #endif 33 | -------------------------------------------------------------------------------- /inference-cpp/cnn-classification/main.cc: -------------------------------------------------------------------------------- 1 | #include "infer.h" 2 | 3 | int main(int argc, char **argv) { 4 | 5 | if (argc != 5) { 6 | std::cerr << "usage: predict \n"; 7 | return -1; 8 | } 9 | 10 | std::string image_path = argv[1]; 11 | std::string model_path = argv[2]; 12 | std::string labels_path = argv[3]; 13 | std::string usegpu_str = argv[4]; 14 | bool usegpu; 15 | 16 | if (usegpu_str == "true") { 17 | usegpu = true; 18 | } else { 19 | usegpu = false; 20 | } 21 | 22 | int image_height = 224; 23 | int image_width = 224; 24 | 25 | // Read labels 26 | std::vector labels; 27 | std::string label; 28 | std::ifstream labelsfile (labels_path); 29 | if (labelsfile.is_open()) 30 | { 31 | while (getline(labelsfile, label)) 32 | { 33 | labels.push_back(label); 34 | } 35 | labelsfile.close(); 36 | } 37 | 38 | std::vector mean = {0.485, 0.456, 0.406}; 39 | std::vector std = {0.229, 0.224, 0.225}; 40 | 41 | cv::Mat image = cv::imread(image_path); 42 | torch::jit::script::Module model = read_model(model_path, usegpu); 43 | 44 | std::string pred, prob; 45 | tie(pred, prob) = infer(image, image_height, image_width, mean, std, labels, model, usegpu); 46 | 47 | std::cout << "PREDICTION : " << pred << std::endl; 48 | std::cout << "CONFIDENCE : " << prob << std::endl; 49 | 50 | return 0; 51 | } 52 | -------------------------------------------------------------------------------- /inference-cpp/cnn-classification/server/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.0 FATAL_ERROR) 2 | project(cnn-classification) 3 | 4 | find_package(Torch REQUIRED) 5 | find_package(OpenCV REQUIRED) 6 | find_package(Boost REQUIRED COMPONENTS system thread) 7 | 8 | include_directories(${Boost_INCLUDE_DIRS}) 9 | 10 | add_executable(cnn-inference main.cc base64.cc ../infer.cc ../../../utils/opencvutils.cc ../../../utils/torchutils.cc) 11 | 12 | target_link_libraries(cnn-inference "${TORCH_LIBRARIES}") 13 | target_link_libraries(cnn-inference "${OpenCV_LIBS}") 14 | target_link_libraries(cnn-inference ${Boost_SYSTEM_LIBRARY} ${Boost_THREAD_LIBRARY}) 15 | 16 | set_property(TARGET cnn-inference PROPERTY CXX_STANDARD 11) 17 | set_property(TARGET cnn-inference PROPERTY OUTPUT_NAME predict) 18 | -------------------------------------------------------------------------------- /inference-cpp/cnn-classification/server/base64.cc: -------------------------------------------------------------------------------- 1 | // 2 | // Code source: https://renenyffenegger.ch/notes/development/Base64/Encoding-and-decoding-base-64-with-cpp 3 | // 4 | 5 | #include "base64.h" 6 | #include 7 | 8 | static const std::string base64_chars = 9 | "ABCDEFGHIJKLMNOPQRSTUVWXYZ" 10 | "abcdefghijklmnopqrstuvwxyz" 11 | "0123456789+/"; 12 | 13 | 14 | static inline bool is_base64(unsigned char c) { 15 | return (isalnum(c) || (c == '+') || (c == '/')); 16 | } 17 | 18 | std::string base64_encode(unsigned char const* bytes_to_encode, unsigned int in_len) { 19 | std::string ret; 20 | int i = 0; 21 | int j = 0; 22 | unsigned char char_array_3[3]; 23 | unsigned char char_array_4[4]; 24 | 25 | while (in_len--) { 26 | char_array_3[i++] = *(bytes_to_encode++); 27 | if (i == 3) { 28 | char_array_4[0] = (char_array_3[0] & 0xfc) >> 2; 29 | char_array_4[1] = ((char_array_3[0] & 0x03) << 4) + ((char_array_3[1] & 0xf0) >> 4); 30 | char_array_4[2] = ((char_array_3[1] & 0x0f) << 2) + ((char_array_3[2] & 0xc0) >> 6); 31 | char_array_4[3] = char_array_3[2] & 0x3f; 32 | 33 | for(i = 0; (i <4) ; i++) 34 | ret += base64_chars[char_array_4[i]]; 35 | i = 0; 36 | } 37 | } 38 | 39 | if (i) 40 | { 41 | for(j = i; j < 3; j++) 42 | char_array_3[j] = '\0'; 43 | 44 | char_array_4[0] = ( char_array_3[0] & 0xfc) >> 2; 45 | char_array_4[1] = ((char_array_3[0] & 0x03) << 4) + ((char_array_3[1] & 0xf0) >> 4); 46 | char_array_4[2] = ((char_array_3[1] & 0x0f) << 2) + ((char_array_3[2] & 0xc0) >> 6); 47 | 48 | for (j = 0; (j < i + 1); j++) 49 | ret += base64_chars[char_array_4[j]]; 50 | 51 | while((i++ < 3)) 52 | ret += '='; 53 | 54 | } 55 | 56 | return ret; 57 | 58 | } 59 | 60 | std::string base64_decode(std::string const& encoded_string) { 61 | int in_len = encoded_string.size(); 62 | int i = 0; 63 | int j = 0; 64 | int in_ = 0; 65 | unsigned char char_array_4[4], char_array_3[3]; 66 | std::string ret; 67 | 68 | while (in_len-- && ( encoded_string[in_] != '=') && is_base64(encoded_string[in_])) { 69 | char_array_4[i++] = encoded_string[in_]; in_++; 70 | if (i ==4) { 71 | for (i = 0; i <4; i++) 72 | char_array_4[i] = base64_chars.find(char_array_4[i]); 73 | 74 | char_array_3[0] = ( char_array_4[0] << 2 ) + ((char_array_4[1] & 0x30) >> 4); 75 | char_array_3[1] = ((char_array_4[1] & 0xf) << 4) + ((char_array_4[2] & 0x3c) >> 2); 76 | char_array_3[2] = ((char_array_4[2] & 0x3) << 6) + char_array_4[3]; 77 | 78 | for (i = 0; (i < 3); i++) 79 | ret += char_array_3[i]; 80 | i = 0; 81 | } 82 | } 83 | 84 | if (i) { 85 | for (j = 0; j < i; j++) 86 | char_array_4[j] = base64_chars.find(char_array_4[j]); 87 | 88 | char_array_3[0] = (char_array_4[0] << 2) + ((char_array_4[1] & 0x30) >> 4); 89 | char_array_3[1] = ((char_array_4[1] & 0xf) << 4) + ((char_array_4[2] & 0x3c) >> 2); 90 | 91 | for (j = 0; (j < i - 1); j++) ret += char_array_3[j]; 92 | } 93 | 94 | return ret; 95 | } -------------------------------------------------------------------------------- /inference-cpp/cnn-classification/server/base64.h: -------------------------------------------------------------------------------- 1 | // 2 | // Code source: https://renenyffenegger.ch/notes/development/Base64/Encoding-and-decoding-base-64-with-cpp 3 | 4 | // 5 | 6 | #ifndef PROJECT_BASE64_H 7 | #define PROJECT_BASE64_H 8 | #include 9 | 10 | std::string base64_encode(unsigned char const* , unsigned int len); 11 | std::string base64_decode(std::string const& s); 12 | 13 | #endif //PROJECT_BASE64_H 14 | -------------------------------------------------------------------------------- /inference-cpp/cnn-classification/server/build.sh: -------------------------------------------------------------------------------- 1 | rm -rf build 2 | 3 | mkdir -p build 4 | cd build 5 | cmake -DCMAKE_PREFIX_PATH=../../../libtorch .. 6 | make -j4 7 | cd .. 8 | 9 | mv build/predict . 10 | 11 | rm -rf build 12 | -------------------------------------------------------------------------------- /inference-cpp/cnn-classification/server/main.cc: -------------------------------------------------------------------------------- 1 | #include "../infer.h" 2 | #include "crow_all.h" 3 | #include "base64.h" 4 | 5 | int PORT = 8181; 6 | 7 | int main(int argc, char **argv) { 8 | 9 | if (argc != 4) { 10 | std::cerr << "usage: predict \n"; 11 | return -1; 12 | } 13 | 14 | std::string model_path = argv[1]; 15 | std::string labels_path = argv[2]; 16 | std::string usegpu_str = argv[3]; 17 | bool usegpu; 18 | 19 | if (usegpu_str == "true") { 20 | usegpu = true; 21 | } else { 22 | usegpu = false; 23 | } 24 | 25 | // Set image height and width 26 | int image_height = 224; 27 | int image_width = 224; 28 | 29 | // Read labels 30 | std::vector labels; 31 | std::string label; 32 | std::ifstream labelsfile (labels_path); 33 | if (labelsfile.is_open()) 34 | { 35 | while (getline(labelsfile, label)) 36 | { 37 | labels.push_back(label); 38 | } 39 | labelsfile.close(); 40 | } 41 | 42 | // Define mean and std 43 | std::vector mean = {0.485, 0.456, 0.406}; 44 | std::vector std = {0.229, 0.224, 0.225}; 45 | 46 | // Load Model 47 | torch::jit::script::Module model = read_model(model_path, usegpu); 48 | 49 | // App 50 | crow::SimpleApp app; 51 | CROW_ROUTE(app, "/predict").methods("POST"_method, "GET"_method) 52 | ([&image_height, &image_width, 53 | &mean, &std, &labels, &model, &usegpu](const crow::request& req){ 54 | crow::json::wvalue result; 55 | result["Prediction"] = ""; 56 | result["Confidence"] = ""; 57 | result["Status"] = "Failed"; 58 | std::ostringstream os; 59 | 60 | try { 61 | auto args = crow::json::load(req.body); 62 | 63 | // Get Image 64 | std::string base64_image = args["image"].s(); 65 | std::string decoded_image = base64_decode(base64_image); 66 | std::vector image_data(decoded_image.begin(), decoded_image.end()); 67 | cv::Mat image = cv::imdecode(image_data, cv::IMREAD_UNCHANGED); 68 | 69 | // Predict 70 | std::string pred, prob; 71 | tie(pred, prob) = infer(image, image_height, image_width, mean, std, labels, model, usegpu); 72 | 73 | result["Prediction"] = pred; 74 | result["Confidence"] = prob; 75 | result["Status"] = "Success"; 76 | 77 | os << crow::json::dump(result); 78 | return crow::response{os.str()}; 79 | 80 | } catch (std::exception& e){ 81 | os << crow::json::dump(result); 82 | return crow::response(os.str()); 83 | } 84 | 85 | }); 86 | 87 | app.port(PORT).run(); 88 | return 0; 89 | } 90 | -------------------------------------------------------------------------------- /inference-cpp/cnn-classification/server/test_api.py: -------------------------------------------------------------------------------- 1 | import requests, json, base64 2 | 3 | url = "http://localhost:8181/predict" 4 | 5 | image_path = "../image.jpeg" 6 | 7 | result = requests.post(url, json={"image": base64.b64encode(open(image_path, "rb").read())}).text 8 | 9 | print(json.loads(result)) 10 | -------------------------------------------------------------------------------- /models/resnet/labels.txt: -------------------------------------------------------------------------------- 1 | tench 2 | goldfish 3 | great_white_shark 4 | tiger_shark 5 | hammerhead 6 | electric_ray 7 | stingray 8 | cock 9 | hen 10 | ostrich 11 | brambling 12 | goldfinch 13 | house_finch 14 | junco 15 | indigo_bunting 16 | robin 17 | bulbul 18 | jay 19 | magpie 20 | chickadee 21 | water_ouzel 22 | kite 23 | bald_eagle 24 | vulture 25 | great_grey_owl 26 | European_fire_salamander 27 | common_newt 28 | eft 29 | spotted_salamander 30 | axolotl 31 | bullfrog 32 | tree_frog 33 | tailed_frog 34 | loggerhead 35 | leatherback_turtle 36 | mud_turtle 37 | terrapin 38 | box_turtle 39 | banded_gecko 40 | common_iguana 41 | American_chameleon 42 | whiptail 43 | agama 44 | frilled_lizard 45 | alligator_lizard 46 | Gila_monster 47 | green_lizard 48 | African_chameleon 49 | Komodo_dragon 50 | African_crocodile 51 | American_alligator 52 | triceratops 53 | thunder_snake 54 | ringneck_snake 55 | hognose_snake 56 | green_snake 57 | king_snake 58 | garter_snake 59 | water_snake 60 | vine_snake 61 | night_snake 62 | boa_constrictor 63 | rock_python 64 | Indian_cobra 65 | green_mamba 66 | sea_snake 67 | horned_viper 68 | diamondback 69 | sidewinder 70 | trilobite 71 | harvestman 72 | scorpion 73 | black_and_gold_garden_spider 74 | barn_spider 75 | garden_spider 76 | black_widow 77 | tarantula 78 | wolf_spider 79 | tick 80 | centipede 81 | black_grouse 82 | ptarmigan 83 | ruffed_grouse 84 | prairie_chicken 85 | peacock 86 | quail 87 | partridge 88 | African_grey 89 | macaw 90 | sulphur-crested_cockatoo 91 | lorikeet 92 | coucal 93 | bee_eater 94 | hornbill 95 | hummingbird 96 | jacamar 97 | toucan 98 | drake 99 | red-breasted_merganser 100 | goose 101 | black_swan 102 | tusker 103 | echidna 104 | platypus 105 | wallaby 106 | koala 107 | wombat 108 | jellyfish 109 | sea_anemone 110 | brain_coral 111 | flatworm 112 | nematode 113 | conch 114 | snail 115 | slug 116 | sea_slug 117 | chiton 118 | chambered_nautilus 119 | Dungeness_crab 120 | rock_crab 121 | fiddler_crab 122 | king_crab 123 | American_lobster 124 | spiny_lobster 125 | crayfish 126 | hermit_crab 127 | isopod 128 | white_stork 129 | black_stork 130 | spoonbill 131 | flamingo 132 | little_blue_heron 133 | American_egret 134 | bittern 135 | crane 136 | limpkin 137 | European_gallinule 138 | American_coot 139 | bustard 140 | ruddy_turnstone 141 | red-backed_sandpiper 142 | redshank 143 | dowitcher 144 | oystercatcher 145 | pelican 146 | king_penguin 147 | albatross 148 | grey_whale 149 | killer_whale 150 | dugong 151 | sea_lion 152 | Chihuahua 153 | Japanese_spaniel 154 | Maltese_dog 155 | Pekinese 156 | Shih-Tzu 157 | Blenheim_spaniel 158 | papillon 159 | toy_terrier 160 | Rhodesian_ridgeback 161 | Afghan_hound 162 | basset 163 | beagle 164 | bloodhound 165 | bluetick 166 | black-and-tan_coonhound 167 | Walker_hound 168 | English_foxhound 169 | redbone 170 | borzoi 171 | Irish_wolfhound 172 | Italian_greyhound 173 | whippet 174 | Ibizan_hound 175 | Norwegian_elkhound 176 | otterhound 177 | Saluki 178 | Scottish_deerhound 179 | Weimaraner 180 | Staffordshire_bullterrier 181 | American_Staffordshire_terrier 182 | Bedlington_terrier 183 | Border_terrier 184 | Kerry_blue_terrier 185 | Irish_terrier 186 | Norfolk_terrier 187 | Norwich_terrier 188 | Yorkshire_terrier 189 | wire-haired_fox_terrier 190 | Lakeland_terrier 191 | Sealyham_terrier 192 | Airedale 193 | cairn 194 | Australian_terrier 195 | Dandie_Dinmont 196 | Boston_bull 197 | miniature_schnauzer 198 | giant_schnauzer 199 | standard_schnauzer 200 | Scotch_terrier 201 | Tibetan_terrier 202 | silky_terrier 203 | soft-coated_wheaten_terrier 204 | West_Highland_white_terrier 205 | Lhasa 206 | flat-coated_retriever 207 | curly-coated_retriever 208 | golden_retriever 209 | Labrador_retriever 210 | Chesapeake_Bay_retriever 211 | German_short-haired_pointer 212 | vizsla 213 | English_setter 214 | Irish_setter 215 | Gordon_setter 216 | Brittany_spaniel 217 | clumber 218 | English_springer 219 | Welsh_springer_spaniel 220 | cocker_spaniel 221 | Sussex_spaniel 222 | Irish_water_spaniel 223 | kuvasz 224 | schipperke 225 | groenendael 226 | malinois 227 | briard 228 | kelpie 229 | komondor 230 | Old_English_sheepdog 231 | Shetland_sheepdog 232 | collie 233 | Border_collie 234 | Bouvier_des_Flandres 235 | Rottweiler 236 | German_shepherd 237 | Doberman 238 | miniature_pinscher 239 | Greater_Swiss_Mountain_dog 240 | Bernese_mountain_dog 241 | Appenzeller 242 | EntleBucher 243 | boxer 244 | bull_mastiff 245 | Tibetan_mastiff 246 | French_bulldog 247 | Great_Dane 248 | Saint_Bernard 249 | Eskimo_dog 250 | malamute 251 | Siberian_husky 252 | dalmatian 253 | affenpinscher 254 | basenji 255 | pug 256 | Leonberg 257 | Newfoundland 258 | Great_Pyrenees 259 | Samoyed 260 | Pomeranian 261 | chow 262 | keeshond 263 | Brabancon_griffon 264 | Pembroke 265 | Cardigan 266 | toy_poodle 267 | miniature_poodle 268 | standard_poodle 269 | Mexican_hairless 270 | timber_wolf 271 | white_wolf 272 | red_wolf 273 | coyote 274 | dingo 275 | dhole 276 | African_hunting_dog 277 | hyena 278 | red_fox 279 | kit_fox 280 | Arctic_fox 281 | grey_fox 282 | tabby 283 | tiger_cat 284 | Persian_cat 285 | Siamese_cat 286 | Egyptian_cat 287 | cougar 288 | lynx 289 | leopard 290 | snow_leopard 291 | jaguar 292 | lion 293 | tiger 294 | cheetah 295 | brown_bear 296 | American_black_bear 297 | ice_bear 298 | sloth_bear 299 | mongoose 300 | meerkat 301 | tiger_beetle 302 | ladybug 303 | ground_beetle 304 | long-horned_beetle 305 | leaf_beetle 306 | dung_beetle 307 | rhinoceros_beetle 308 | weevil 309 | fly 310 | bee 311 | ant 312 | grasshopper 313 | cricket 314 | walking_stick 315 | cockroach 316 | mantis 317 | cicada 318 | leafhopper 319 | lacewing 320 | dragonfly 321 | damselfly 322 | admiral 323 | ringlet 324 | monarch 325 | cabbage_butterfly 326 | sulphur_butterfly 327 | lycaenid 328 | starfish 329 | sea_urchin 330 | sea_cucumber 331 | wood_rabbit 332 | hare 333 | Angora 334 | hamster 335 | porcupine 336 | fox_squirrel 337 | marmot 338 | beaver 339 | guinea_pig 340 | sorrel 341 | zebra 342 | hog 343 | wild_boar 344 | warthog 345 | hippopotamus 346 | ox 347 | water_buffalo 348 | bison 349 | ram 350 | bighorn 351 | ibex 352 | hartebeest 353 | impala 354 | gazelle 355 | Arabian_camel 356 | llama 357 | weasel 358 | mink 359 | polecat 360 | black-footed_ferret 361 | otter 362 | skunk 363 | badger 364 | armadillo 365 | three-toed_sloth 366 | orangutan 367 | gorilla 368 | chimpanzee 369 | gibbon 370 | siamang 371 | guenon 372 | patas 373 | baboon 374 | macaque 375 | langur 376 | colobus 377 | proboscis_monkey 378 | marmoset 379 | capuchin 380 | howler_monkey 381 | titi 382 | spider_monkey 383 | squirrel_monkey 384 | Madagascar_cat 385 | indri 386 | Indian_elephant 387 | African_elephant 388 | lesser_panda 389 | giant_panda 390 | barracouta 391 | eel 392 | coho 393 | rock_beauty 394 | anemone_fish 395 | sturgeon 396 | gar 397 | lionfish 398 | puffer 399 | abacus 400 | abaya 401 | academic_gown 402 | accordion 403 | acoustic_guitar 404 | aircraft_carrier 405 | airliner 406 | airship 407 | altar 408 | ambulance 409 | amphibian 410 | analog_clock 411 | apiary 412 | apron 413 | ashcan 414 | assault_rifle 415 | backpack 416 | bakery 417 | balance_beam 418 | balloon 419 | ballpoint 420 | Band_Aid 421 | banjo 422 | bannister 423 | barbell 424 | barber_chair 425 | barbershop 426 | barn 427 | barometer 428 | barrel 429 | barrow 430 | baseball 431 | basketball 432 | bassinet 433 | bassoon 434 | bathing_cap 435 | bath_towel 436 | bathtub 437 | beach_wagon 438 | beacon 439 | beaker 440 | bearskin 441 | beer_bottle 442 | beer_glass 443 | bell_cote 444 | bib 445 | bicycle-built-for-two 446 | bikini 447 | binder 448 | binoculars 449 | birdhouse 450 | boathouse 451 | bobsled 452 | bolo_tie 453 | bonnet 454 | bookcase 455 | bookshop 456 | bottlecap 457 | bow 458 | bow_tie 459 | brass 460 | brassiere 461 | breakwater 462 | breastplate 463 | broom 464 | bucket 465 | buckle 466 | bulletproof_vest 467 | bullet_train 468 | butcher_shop 469 | cab 470 | caldron 471 | candle 472 | cannon 473 | canoe 474 | can_opener 475 | cardigan 476 | car_mirror 477 | carousel 478 | carpenter's_kit 479 | carton 480 | car_wheel 481 | cash_machine 482 | cassette 483 | cassette_player 484 | castle 485 | catamaran 486 | CD_player 487 | cello 488 | cellular_telephone 489 | chain 490 | chainlink_fence 491 | chain_mail 492 | chain_saw 493 | chest 494 | chiffonier 495 | chime 496 | china_cabinet 497 | Christmas_stocking 498 | church 499 | cinema 500 | cleaver 501 | cliff_dwelling 502 | cloak 503 | clog 504 | cocktail_shaker 505 | coffee_mug 506 | coffeepot 507 | coil 508 | combination_lock 509 | computer_keyboard 510 | confectionery 511 | container_ship 512 | convertible 513 | corkscrew 514 | cornet 515 | cowboy_boot 516 | cowboy_hat 517 | cradle 518 | crane 519 | crash_helmet 520 | crate 521 | crib 522 | Crock_Pot 523 | croquet_ball 524 | crutch 525 | cuirass 526 | dam 527 | desk 528 | desktop_computer 529 | dial_telephone 530 | diaper 531 | digital_clock 532 | digital_watch 533 | dining_table 534 | dishrag 535 | dishwasher 536 | disk_brake 537 | dock 538 | dogsled 539 | dome 540 | doormat 541 | drilling_platform 542 | drum 543 | drumstick 544 | dumbbell 545 | Dutch_oven 546 | electric_fan 547 | electric_guitar 548 | electric_locomotive 549 | entertainment_center 550 | envelope 551 | espresso_maker 552 | face_powder 553 | feather_boa 554 | file 555 | fireboat 556 | fire_engine 557 | fire_screen 558 | flagpole 559 | flute 560 | folding_chair 561 | football_helmet 562 | forklift 563 | fountain 564 | fountain_pen 565 | four-poster 566 | freight_car 567 | French_horn 568 | frying_pan 569 | fur_coat 570 | garbage_truck 571 | gasmask 572 | gas_pump 573 | goblet 574 | go-kart 575 | golf_ball 576 | golfcart 577 | gondola 578 | gong 579 | gown 580 | grand_piano 581 | greenhouse 582 | grille 583 | grocery_store 584 | guillotine 585 | hair_slide 586 | hair_spray 587 | half_track 588 | hammer 589 | hamper 590 | hand_blower 591 | hand-held_computer 592 | handkerchief 593 | hard_disc 594 | harmonica 595 | harp 596 | harvester 597 | hatchet 598 | holster 599 | home_theater 600 | honeycomb 601 | hook 602 | hoopskirt 603 | horizontal_bar 604 | horse_cart 605 | hourglass 606 | iPod 607 | iron 608 | jack-o'-lantern 609 | jean 610 | jeep 611 | jersey 612 | jigsaw_puzzle 613 | jinrikisha 614 | joystick 615 | kimono 616 | knee_pad 617 | knot 618 | lab_coat 619 | ladle 620 | lampshade 621 | laptop 622 | lawn_mower 623 | lens_cap 624 | letter_opener 625 | library 626 | lifeboat 627 | lighter 628 | limousine 629 | liner 630 | lipstick 631 | Loafer 632 | lotion 633 | loudspeaker 634 | loupe 635 | lumbermill 636 | magnetic_compass 637 | mailbag 638 | mailbox 639 | maillot 640 | maillot 641 | manhole_cover 642 | maraca 643 | marimba 644 | mask 645 | matchstick 646 | maypole 647 | maze 648 | measuring_cup 649 | medicine_chest 650 | megalith 651 | microphone 652 | microwave 653 | military_uniform 654 | milk_can 655 | minibus 656 | miniskirt 657 | minivan 658 | missile 659 | mitten 660 | mixing_bowl 661 | mobile_home 662 | Model_T 663 | modem 664 | monastery 665 | monitor 666 | moped 667 | mortar 668 | mortarboard 669 | mosque 670 | mosquito_net 671 | motor_scooter 672 | mountain_bike 673 | mountain_tent 674 | mouse 675 | mousetrap 676 | moving_van 677 | muzzle 678 | nail 679 | neck_brace 680 | necklace 681 | nipple 682 | notebook 683 | obelisk 684 | oboe 685 | ocarina 686 | odometer 687 | oil_filter 688 | organ 689 | oscilloscope 690 | overskirt 691 | oxcart 692 | oxygen_mask 693 | packet 694 | paddle 695 | paddlewheel 696 | padlock 697 | paintbrush 698 | pajama 699 | palace 700 | panpipe 701 | paper_towel 702 | parachute 703 | parallel_bars 704 | park_bench 705 | parking_meter 706 | passenger_car 707 | patio 708 | pay-phone 709 | pedestal 710 | pencil_box 711 | pencil_sharpener 712 | perfume 713 | Petri_dish 714 | photocopier 715 | pick 716 | pickelhaube 717 | picket_fence 718 | pickup 719 | pier 720 | piggy_bank 721 | pill_bottle 722 | pillow 723 | ping-pong_ball 724 | pinwheel 725 | pirate 726 | pitcher 727 | plane 728 | planetarium 729 | plastic_bag 730 | plate_rack 731 | plow 732 | plunger 733 | Polaroid_camera 734 | pole 735 | police_van 736 | poncho 737 | pool_table 738 | pop_bottle 739 | pot 740 | potter's_wheel 741 | power_drill 742 | prayer_rug 743 | printer 744 | prison 745 | projectile 746 | projector 747 | puck 748 | punching_bag 749 | purse 750 | quill 751 | quilt 752 | racer 753 | racket 754 | radiator 755 | radio 756 | radio_telescope 757 | rain_barrel 758 | recreational_vehicle 759 | reel 760 | reflex_camera 761 | refrigerator 762 | remote_control 763 | restaurant 764 | revolver 765 | rifle 766 | rocking_chair 767 | rotisserie 768 | rubber_eraser 769 | rugby_ball 770 | rule 771 | running_shoe 772 | safe 773 | safety_pin 774 | saltshaker 775 | sandal 776 | sarong 777 | sax 778 | scabbard 779 | scale 780 | school_bus 781 | schooner 782 | scoreboard 783 | screen 784 | screw 785 | screwdriver 786 | seat_belt 787 | sewing_machine 788 | shield 789 | shoe_shop 790 | shoji 791 | shopping_basket 792 | shopping_cart 793 | shovel 794 | shower_cap 795 | shower_curtain 796 | ski 797 | ski_mask 798 | sleeping_bag 799 | slide_rule 800 | sliding_door 801 | slot 802 | snorkel 803 | snowmobile 804 | snowplow 805 | soap_dispenser 806 | soccer_ball 807 | sock 808 | solar_dish 809 | sombrero 810 | soup_bowl 811 | space_bar 812 | space_heater 813 | space_shuttle 814 | spatula 815 | speedboat 816 | spider_web 817 | spindle 818 | sports_car 819 | spotlight 820 | stage 821 | steam_locomotive 822 | steel_arch_bridge 823 | steel_drum 824 | stethoscope 825 | stole 826 | stone_wall 827 | stopwatch 828 | stove 829 | strainer 830 | streetcar 831 | stretcher 832 | studio_couch 833 | stupa 834 | submarine 835 | suit 836 | sundial 837 | sunglass 838 | sunglasses 839 | sunscreen 840 | suspension_bridge 841 | swab 842 | sweatshirt 843 | swimming_trunks 844 | swing 845 | switch 846 | syringe 847 | table_lamp 848 | tank 849 | tape_player 850 | teapot 851 | teddy 852 | television 853 | tennis_ball 854 | thatch 855 | theater_curtain 856 | thimble 857 | thresher 858 | throne 859 | tile_roof 860 | toaster 861 | tobacco_shop 862 | toilet_seat 863 | torch 864 | totem_pole 865 | tow_truck 866 | toyshop 867 | tractor 868 | trailer_truck 869 | tray 870 | trench_coat 871 | tricycle 872 | trimaran 873 | tripod 874 | triumphal_arch 875 | trolleybus 876 | trombone 877 | tub 878 | turnstile 879 | typewriter_keyboard 880 | umbrella 881 | unicycle 882 | upright 883 | vacuum 884 | vase 885 | vault 886 | velvet 887 | vending_machine 888 | vestment 889 | viaduct 890 | violin 891 | volleyball 892 | waffle_iron 893 | wall_clock 894 | wallet 895 | wardrobe 896 | warplane 897 | washbasin 898 | washer 899 | water_bottle 900 | water_jug 901 | water_tower 902 | whiskey_jug 903 | whistle 904 | wig 905 | window_screen 906 | window_shade 907 | Windsor_tie 908 | wine_bottle 909 | wing 910 | wok 911 | wooden_spoon 912 | wool 913 | worm_fence 914 | wreck 915 | yawl 916 | yurt 917 | web_site 918 | comic_book 919 | crossword_puzzle 920 | street_sign 921 | traffic_light 922 | book_jacket 923 | menu 924 | plate 925 | guacamole 926 | consomme 927 | hot_pot 928 | trifle 929 | ice_cream 930 | ice_lolly 931 | French_loaf 932 | bagel 933 | pretzel 934 | cheeseburger 935 | hotdog 936 | mashed_potato 937 | head_cabbage 938 | broccoli 939 | cauliflower 940 | zucchini 941 | spaghetti_squash 942 | acorn_squash 943 | butternut_squash 944 | cucumber 945 | artichoke 946 | bell_pepper 947 | cardoon 948 | mushroom 949 | Granny_Smith 950 | strawberry 951 | orange 952 | lemon 953 | fig 954 | pineapple 955 | banana 956 | jackfruit 957 | custard_apple 958 | pomegranate 959 | hay 960 | carbonara 961 | chocolate_sauce 962 | dough 963 | meat_loaf 964 | pizza 965 | potpie 966 | burrito 967 | red_wine 968 | espresso 969 | cup 970 | eggnog 971 | alp 972 | bubble 973 | cliff 974 | coral_reef 975 | geyser 976 | lakeside 977 | promontory 978 | sandbar 979 | seashore 980 | valley 981 | volcano 982 | ballplayer 983 | groom 984 | scuba_diver 985 | rapeseed 986 | daisy 987 | yellow_lady's_slipper 988 | corn 989 | acorn 990 | hip 991 | buckeye 992 | coral_fungus 993 | agaric 994 | gyromitra 995 | stinkhorn 996 | earthstar 997 | hen-of-the-woods 998 | bolete 999 | ear 1000 | toilet_tissue 1001 | -------------------------------------------------------------------------------- /models/resnet/resnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision 3 | 4 | # An instance of your model. 5 | model = torchvision.models.resnet18(pretrained=True) 6 | 7 | # Evaluation mode 8 | model.eval() 9 | 10 | # An example input you would normally provide to your model's forward() method. 11 | example = torch.rand(1, 3, 224, 224) 12 | 13 | def export_cpu(model, example): 14 | model = model.to("cpu") 15 | example = example.to("cpu") 16 | 17 | # Use torch.jit.trace to generate a torch.jit.ScriptModule via tracing. 18 | traced_script_module = torch.jit.trace(model, example) 19 | 20 | # Save traced model 21 | traced_script_module.save("resnet_model_cpu.pth") 22 | 23 | def export_gpu(model, example): 24 | model = model.to("cuda") 25 | example = example.to("cuda") 26 | 27 | # Use torch.jit.trace to generate a torch.jit.ScriptModule via tracing. 28 | traced_script_module = torch.jit.trace(model, example) 29 | 30 | # Save traced model 31 | traced_script_module.save("resnet_model_gpu.pth") 32 | 33 | export_cpu(model, example) 34 | 35 | if torch.cuda.is_available(): 36 | export_gpu(model, example) 37 | -------------------------------------------------------------------------------- /utils/opencvutils.cc: -------------------------------------------------------------------------------- 1 | #include "opencvutils.h" 2 | 3 | // Resize an image to a given size to 4 | cv::Mat __resize_to_a_size(cv::Mat image, int new_height, int new_width) { 5 | 6 | // get original image size 7 | int org_image_height = image.rows; 8 | int org_image_width = image.cols; 9 | 10 | // get image area and resized image area 11 | float img_area = float(org_image_height * org_image_width); 12 | float new_area = float(new_height * new_width); 13 | 14 | // resize 15 | cv::Mat image_scaled; 16 | cv::Size scale(new_width, new_height); 17 | 18 | if (new_area >= img_area) { 19 | cv::resize(image, image_scaled, scale, 0, 0, cv::INTER_LANCZOS4); 20 | } else { 21 | cv::resize(image, image_scaled, scale, 0, 0, cv::INTER_AREA); 22 | } 23 | 24 | return image_scaled; 25 | } 26 | 27 | // Normalize an image by subtracting mean and dividing by standard deviation 28 | cv::Mat __normalize_mean_std(cv::Mat image, std::vector mean, std::vector std) { 29 | 30 | // clone 31 | cv::Mat image_normalized = image.clone(); 32 | 33 | // convert to float 34 | image_normalized.convertTo(image_normalized, CV_32FC3); 35 | 36 | // subtract mean 37 | cv::subtract(image_normalized, mean, image_normalized); 38 | 39 | // divide by standard deviation 40 | std::vector img_channels(3); 41 | cv::split(image_normalized, img_channels); 42 | 43 | img_channels[0] = img_channels[0] / std[0]; 44 | img_channels[1] = img_channels[1] / std[1]; 45 | img_channels[2] = img_channels[2] / std[2]; 46 | 47 | cv::merge(img_channels, image_normalized); 48 | 49 | return image_normalized; 50 | } 51 | 52 | // 1. Preprocess 53 | cv::Mat preprocess(cv::Mat image, int new_height, int new_width, 54 | std::vector mean, std::vector std) { 55 | 56 | // Clone 57 | cv::Mat image_proc = image.clone(); 58 | 59 | // Convert from BGR to RGB 60 | cv::cvtColor(image_proc, image_proc, cv::COLOR_BGR2RGB); 61 | 62 | // Resize image 63 | image_proc = __resize_to_a_size(image_proc, new_height, new_width); 64 | 65 | // Convert image to float 66 | image_proc.convertTo(image_proc, CV_32FC3); 67 | 68 | // 3. Normalize to [0, 1] 69 | image_proc = image_proc / 255.0; 70 | 71 | // 4. Subtract mean and divide by std 72 | image_proc = __normalize_mean_std(image_proc, mean, std); 73 | 74 | return image_proc; 75 | } 76 | -------------------------------------------------------------------------------- /utils/opencvutils.h: -------------------------------------------------------------------------------- 1 | #ifndef OPENCVUTILS_H // To make sure you don't declare the function more than once by including the header multiple times. 2 | #define OPENCVUTILS_H 3 | 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | 10 | cv::Mat preprocess(cv::Mat, int, int, 11 | std::vector, 12 | std::vector); 13 | 14 | #endif 15 | -------------------------------------------------------------------------------- /utils/torchutils.cc: -------------------------------------------------------------------------------- 1 | #include "torchutils.h" 2 | 3 | // Convert a vector of images to torch Tensor 4 | torch::Tensor __convert_images_to_tensor(std::vector images) { 5 | 6 | int n_images = images.size(); 7 | int n_channels = images[0].channels(); 8 | int height = images[0].rows; 9 | int width = images[0].cols; 10 | 11 | int image_type = images[0].type(); 12 | 13 | // Image Type must be one of CV_8U, CV_32F, CV_64F 14 | assert((image_type % 8 == 0) || ((image_type - 5) % 8 == 0) || ((image_type - 6) % 8 == 0)); 15 | 16 | std::vector dims = {1, height, width, n_channels}; 17 | std::vector permute_dims = {0, 3, 1, 2}; 18 | 19 | std::vector images_as_tensors; 20 | for (int i = 0; i != n_images; i++) { 21 | cv::Mat image = images[i].clone(); 22 | 23 | torch::Tensor image_as_tensor; 24 | if (image_type % 8 == 0) { 25 | torch::TensorOptions options(torch::kUInt8); 26 | image_as_tensor = torch::from_blob(image.data, torch::IntList(dims), options).clone(); 27 | } else if ((image_type - 5) % 8 == 0) { 28 | torch::TensorOptions options(torch::kFloat32); 29 | image_as_tensor = torch::from_blob(image.data, torch::IntList(dims), options).clone(); 30 | } else if ((image_type - 6) % 8 == 0) { 31 | torch::TensorOptions options(torch::kFloat64); 32 | image_as_tensor = torch::from_blob(image.data, torch::IntList(dims), options).clone(); 33 | } 34 | 35 | image_as_tensor = image_as_tensor.permute(torch::IntList(permute_dims)); 36 | image_as_tensor = image_as_tensor.toType(torch::kFloat32); 37 | images_as_tensors.push_back(image_as_tensor); 38 | } 39 | 40 | torch::Tensor output_tensor = torch::cat(images_as_tensors, 0); 41 | 42 | return output_tensor; 43 | } 44 | 45 | // Predict 46 | torch::Tensor __predict(torch::jit::script::Module model, torch::Tensor tensor) { 47 | 48 | std::vector inputs; 49 | inputs.push_back(tensor); 50 | 51 | // Execute the model and turn its output into a tensor. 52 | torch::NoGradGuard no_grad; 53 | torch::Tensor output = model.forward(inputs).toTensor(); 54 | 55 | torch::DeviceType cpu_device_type = torch::kCPU; 56 | torch::Device cpu_device(cpu_device_type); 57 | output = output.to(cpu_device); 58 | 59 | return output; 60 | } 61 | 62 | // Softmax 63 | std::vector __softmax(std::vector unnorm_probs) { 64 | 65 | int n_classes = unnorm_probs.size(); 66 | 67 | // 1. Partition function 68 | float log_sum_of_exp_unnorm_probs = 0; 69 | for (auto& n : unnorm_probs) { 70 | log_sum_of_exp_unnorm_probs += std::exp(n); 71 | } 72 | log_sum_of_exp_unnorm_probs = std::log(log_sum_of_exp_unnorm_probs); 73 | 74 | // 2. normalize 75 | std::vector probs(n_classes); 76 | for (int class_idx = 0; class_idx != n_classes; class_idx++) { 77 | probs[class_idx] = std::exp(unnorm_probs[class_idx] - log_sum_of_exp_unnorm_probs); 78 | } 79 | 80 | return probs; 81 | } 82 | 83 | // Convert output tensor to vector of floats 84 | std::vector __get_outputs(torch::Tensor output) { 85 | 86 | int ndim = output.ndimension(); 87 | assert(ndim == 2); 88 | 89 | torch::ArrayRef sizes = output.sizes(); 90 | int n_samples = sizes[0]; 91 | int n_classes = sizes[1]; 92 | 93 | assert(n_samples == 1); 94 | 95 | std::vector unnorm_probs(output.data_ptr(), 96 | output.data_ptr() + (n_samples * n_classes)); 97 | 98 | // Softmax 99 | std::vector probs = __softmax(unnorm_probs); 100 | 101 | return probs; 102 | } 103 | 104 | // 1. Read model 105 | torch::jit::script::Module read_model(std::string model_path, bool usegpu) { 106 | 107 | torch::jit::script::Module model = torch::jit::load(model_path); 108 | 109 | if (usegpu) { 110 | torch::DeviceType gpu_device_type = torch::kCUDA; 111 | torch::Device gpu_device(gpu_device_type); 112 | 113 | model.to(gpu_device); 114 | } else { 115 | torch::DeviceType cpu_device_type = torch::kCPU; 116 | torch::Device cpu_device(cpu_device_type); 117 | 118 | model.to(cpu_device); 119 | } 120 | 121 | return model; 122 | } 123 | 124 | // 2. Forward 125 | std::vector forward(std::vector images, 126 | torch::jit::script::Module model, bool usegpu) { 127 | 128 | // 1. Convert OpenCV matrices to torch Tensor 129 | torch::Tensor tensor = __convert_images_to_tensor(images); 130 | 131 | if (usegpu) { 132 | torch::DeviceType gpu_device_type = torch::kCUDA; 133 | torch::Device gpu_device(gpu_device_type); 134 | 135 | tensor = tensor.to(gpu_device); 136 | } else { 137 | torch::DeviceType cpu_device_type = torch::kCPU; 138 | torch::Device cpu_device(cpu_device_type); 139 | 140 | tensor = tensor.to(cpu_device); 141 | } 142 | 143 | // 2. Predict 144 | torch::Tensor output = __predict(model, tensor); 145 | 146 | // 3. Convert torch Tensor to vector of vector of floats 147 | std::vector probs = __get_outputs(output); 148 | 149 | return probs; 150 | } 151 | 152 | // 3. Postprocess 153 | std::tuple postprocess(std::vector probs, 154 | std::vector labels) { 155 | 156 | // 1. Get label and corresponding probability 157 | auto prob = std::max_element(probs.begin(), probs.end()); 158 | auto label_idx = std::distance(probs.begin(), prob); 159 | auto label = labels[label_idx]; 160 | float prob_float = *prob; 161 | 162 | return std::make_tuple(label, std::to_string(prob_float)); 163 | } 164 | -------------------------------------------------------------------------------- /utils/torchutils.h: -------------------------------------------------------------------------------- 1 | #ifndef TORCHUTILS_H // To make sure you don't declare the function more than once by including the header multiple times. 2 | #define TORCHUTILS_H 3 | 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include 11 | #include 12 | 13 | #include 14 | #include 15 | #include 16 | 17 | #include 18 | #include 19 | #include 20 | 21 | torch::jit::script::Module read_model(std::string, bool); 22 | std::vector forward(std::vector, 23 | torch::jit::script::Module, bool); 24 | std::tuple postprocess(std::vector, 25 | std::vector); 26 | 27 | #endif 28 | --------------------------------------------------------------------------------