├── .gitignore ├── CMakeLists.txt ├── LICENSE ├── README.md ├── bus.jpg ├── jit_extract.py ├── mnist.cpp ├── net.py ├── six.jpg ├── train.py └── yolov5.cpp /.gitignore: -------------------------------------------------------------------------------- 1 | alchemistic_directory 2 | *.pyc 3 | .vscode 4 | *.pth 5 | ./data 6 | build 7 | data 8 | *.zip 9 | libtorch 10 | opencv 11 | -------------------------------------------------------------------------------- /CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.0 FATAL_ERROR) 2 | 3 | project(${TARGET}) 4 | 5 | find_package(Torch REQUIRED PATHS ./libtorch) 6 | find_package(OpenCV REQUIRED PATHS ./opencv/build REQUIRED) 7 | 8 | add_executable(${TARGET} ${TARGET}.cpp) 9 | target_link_libraries(${TARGET} "${TORCH_LIBRARIES}" "${OpenCV_LIBS}") 10 | 11 | set_property(TARGET ${TARGET} PROPERTY CXX_STANDARD 14) 12 | set(CMAKE_CXX_STANDARD_REQUIRED TRUE) 13 | 14 | set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS}") 15 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Chenglu 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # torchscript-demos 2 | 3 | A brief introduction to TorchScript through MNIST and YOLOv5 examples. 4 | 5 | ## Requirements 6 | 7 | Any x86 CPU and UNIX-like system should work. 8 | 9 | For training in Python 10 | 11 | * Python==3.7+ 12 | * PyTorch==1.8.1 13 | * MineTorch==0.6.12 14 | 15 | For inference in C++ 16 | 17 | * cmake 18 | * LibTorch 19 | * OpenCV 20 | 21 | ## Installation 22 | 23 | This guide covers only the LibTorch and OpenCV installation and assumes everything else is already installed. Everything will be installed inside the repo directory, so uninstalling is just a matter of removing the whole directory. 24 | 25 | 1. Clone this repo. 26 | 27 | ```bash 28 | git clone https://github.com/louis-she/torchscript-demos.git 29 | cd torchscript-demos 30 | ``` 31 | 32 | 2. Install OpenCV 33 | 34 | ```bash 35 | # In repo base dir 36 | git clone --branch 3.4 --depth 1 https://github.com/opencv/opencv.git 37 | mkdir opencv/build && cd opencv/build 38 | cmake .. 39 | make -j 4 40 | ``` 41 | 42 | 3. Download LibTorch 43 | 44 | ```bash 45 | # In repo base dir 46 | wget https://download.pytorch.org/libtorch/nightly/cpu/libtorch-shared-with-deps-latest.zip 47 | unzip libtorch-shared-with-deps-latest.zip 48 | ``` 49 | 50 | 4. Clone YOLOv5 (Optional) 51 | 52 | > Required if you want to try the YOLOv5 example. 
53 | 54 | ```bash 55 | git clone https://github.com/ultralytics/yolov5 56 | ``` 57 | 58 | ## MNIST 59 | 60 | **1. Training** 61 | 62 | ```bash 63 | # In repo base dir 64 | python3 train.py 65 | ``` 66 | 67 | Use `ctrl + C` to stop the training process; the training log and graphs can be found at `./alchemistic_directory/baseline`. 68 | 69 | **2. Build binary** 70 | 71 | ```bash 72 | # In repo base dir 73 | 74 | # Dump TorchScript Module file `./jit_module.pth` 75 | python3 jit_extract.py 76 | 77 | # Build C++ binary `./build/mnist` 78 | mkdir build && cd build 79 | cmake .. -DTARGET=mnist 80 | make 81 | ``` 82 | 83 | **3. Make inference with the binary** 84 | 85 | ```bash 86 | # In repo base dir 87 | ./build/mnist jit_module.pth six.jpg 88 | 89 | # Output: The number is: 6 90 | ``` 91 | 92 | ## YOLOv5 93 | 94 | **1. Build binary** 95 | 96 | ```bash 97 | # In repo base dir 98 | 99 | # Build C++ binary `./build/yolov5` 100 | mkdir build && cd build 101 | cmake .. -DTARGET=yolov5 102 | make 103 | ``` 104 | 105 | **2. Export TorchScript module** 106 | 107 | ``` 108 | # In YOLOv5 base dir 109 | python export.py --weights yolov5s.pt --img 640 --batch 1 110 | 111 | # Then copy yolov5s.torchscript to the base dir of this repo 112 | ``` 113 | 114 | **3. 
Make inference with the binary** 115 | 116 | ```bash 117 | # In base dir 118 | 119 | ./build/yolov5 yolov5s.torchscript bus.jpg 120 | ``` 121 | -------------------------------------------------------------------------------- /bus.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/louis-she/torchscript-demos/5400801f4e212018232dc896436dbf264b265c98/bus.jpg -------------------------------------------------------------------------------- /jit_extract.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from net import Net 3 | 4 | model_to_load = "./alchemistic_directory/baseline/models/best.pth.tar" 5 | 6 | net = Net() 7 | net.load_state_dict(torch.load(model_to_load)["state_dict"]) 8 | net.eval() 9 | 10 | jit_module = torch.jit.trace(net, torch.rand(1, 1, 28, 28)) 11 | torch.jit.save(jit_module, "jit_module.pth") 12 | -------------------------------------------------------------------------------- /mnist.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | 6 | #include 7 | // #include 8 | #include 9 | #include 10 | 11 | torch::Tensor toTensor(cv::Mat img) { 12 | torch::Tensor tensor_image = torch::from_blob(img.data, { img.rows, img.cols }, torch::kUInt8); 13 | auto tensor_image_normed = (tensor_image / 255.0).sub_(0.5).div_(0.5); 14 | return tensor_image_normed; 15 | }; 16 | 17 | int main(int argc, char** argv) { 18 | // load jit module 19 | auto module = torch::jit::load(argv[1]); 20 | 21 | // load input image 22 | auto image = cv::imread(argv[2], cv::COLOR_BGR2GRAY); 23 | 24 | // preprocessing 25 | cv::Mat resized_image; 26 | cv::resize(image, resized_image, cv::Size(28, 28)); 27 | auto input_tensor = toTensor(resized_image); 28 | input_tensor.unsqueeze_(0).unsqueeze_(0); 29 | 30 | // forward 31 | std::vector inputs; 32 | inputs.push_back(input_tensor); 33 
| torch::Tensor output = module.forward(inputs).toTensor(); 34 | 35 | // get result 36 | int result = output.argmax().item(); 37 | std::cout << "The number is: " << result << std::endl; 38 | return 0; 39 | } 40 | -------------------------------------------------------------------------------- /net.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | from torch.nn import functional as F 3 | 4 | 5 | class Net(nn.Module): 6 | def __init__(self): 7 | super(Net, self).__init__() 8 | self.conv1 = nn.Conv2d(1, 10, kernel_size=5) 9 | self.conv2 = nn.Conv2d(10, 20, kernel_size=5) 10 | self.conv2_drop = nn.Dropout2d() 11 | self.fc1 = nn.Linear(320, 50) 12 | self.fc2 = nn.Linear(50, 10) 13 | 14 | def forward(self, x): 15 | x = F.relu(F.max_pool2d(self.conv1(x), 2)) 16 | x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2)) 17 | x = x.view(-1, 320) 18 | x = F.relu(self.fc1(x)) 19 | x = F.dropout(x, training=self.training) 20 | x = self.fc2(x) 21 | return F.log_softmax(x, dim=1) 22 | -------------------------------------------------------------------------------- /six.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/louis-she/torchscript-demos/5400801f4e212018232dc896436dbf264b265c98/six.jpg -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import optim 3 | 4 | from minetorch.miner import Miner 5 | from minetorch.metrics import MultiClassesClassificationMetricWithLogic 6 | from torchvision import datasets, transforms 7 | 8 | from net import Net 9 | 10 | 11 | train_loader = torch.utils.data.DataLoader( 12 | datasets.MNIST( 13 | "./data", 14 | train=True, 15 | download=True, 16 | transform=transforms.Compose( 17 | [transforms.ToTensor(), transforms.Normalize(0.5, 0.5)] 18 | ), 19 | ), 20 | 
batch_size=128, 21 | shuffle=True, 22 | ) 23 | 24 | val_loader = torch.utils.data.DataLoader( 25 | datasets.MNIST( 26 | "./data", 27 | train=False, 28 | transform=transforms.Compose( 29 | [transforms.ToTensor(), transforms.Normalize(0.5, 0.5)] 30 | ), 31 | ), 32 | batch_size=128, 33 | shuffle=True, 34 | ) 35 | 36 | model = Net() 37 | 38 | trainer = Miner( 39 | alchemistic_directory="./alchemistic_directory", 40 | code="baseline", 41 | model=model, 42 | optimizer=optim.SGD(model.parameters(), lr=0.01), 43 | train_dataloader=train_loader, 44 | val_dataloader=val_loader, 45 | loss_func=torch.nn.CrossEntropyLoss(), 46 | plugins=[MultiClassesClassificationMetricWithLogic()], 47 | ) 48 | 49 | trainer.train() 50 | -------------------------------------------------------------------------------- /yolov5.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | 8 | #include 9 | #include 10 | #include 11 | 12 | using namespace std; 13 | using namespace torch::indexing; 14 | 15 | class Yolo 16 | { 17 | public: 18 | Yolo(std::string model_path, float threshold) : m_model_path(model_path), m_threshold(threshold) 19 | { 20 | m_module = torch::jit::load(model_path); 21 | } 22 | 23 | torch::Tensor nms(torch::Tensor dets, float thres) 24 | { 25 | auto x1 = dets.index({Ellipsis, 0}); 26 | auto y1 = dets.index({Ellipsis, 1}); 27 | auto x2 = dets.index({Ellipsis, 2}); 28 | auto y2 = dets.index({Ellipsis, 3}); 29 | auto scores = dets.index({Ellipsis, 4}); 30 | 31 | auto areas = (x2 - x1 + 1) * (y2 - y1 + 1); 32 | auto order = scores.argsort(-1, true); 33 | 34 | torch::Tensor i; 35 | vector keep; 36 | 37 | while (order.sizes()[0] > 0) 38 | { 39 | i = order[0]; 40 | keep.push_back(i); 41 | auto xx1 = torch::maximum(x1.index({i}), x1.index({order.index({Slice(1, None, None)})})); 42 | auto yy1 = torch::maximum(y1.index({i}), y1.index({order.index({Slice(1, None, None)})})); 43 | 
auto xx2 = torch::minimum(x2.index({i}), x2.index({order.index({Slice(1, None, None)})})); 44 | auto yy2 = torch::minimum(y2.index({i}), y2.index({order.index({Slice(1, None, None)})})); 45 | 46 | auto w = torch::maximum(torch::zeros_like(xx2), xx2 - xx1 + 1); 47 | auto h = torch::maximum(torch::zeros_like(yy2), yy2 - yy1 + 1); 48 | 49 | auto inter = w * h; 50 | auto ovr = inter / (areas.index({i}) + areas.index({order.index({Slice(1, None, None)})}) - inter); 51 | auto inds = torch::where(ovr <= thres)[0]; 52 | order = order.index({inds + 1}); 53 | } 54 | 55 | return torch::stack(torch::TensorList(keep)); 56 | } 57 | 58 | vector> predict(char *png_buffer, size_t png_buffer_length) 59 | { 60 | vector> ret; 61 | auto image = cv::imdecode(cv::Mat(1, png_buffer_length, CV_8UC1, png_buffer), CV_LOAD_IMAGE_UNCHANGED); 62 | if (image.data == NULL) { 63 | return ret; 64 | } 65 | cv::cvtColor(image, image, cv::COLOR_BGR2RGB); 66 | 67 | // preprocessing 68 | auto h_gain = 640.0 / image.size[0]; 69 | auto w_gain = 640.0 / image.size[1]; 70 | 71 | cv::resize(image, image, cv::Size(640, 640), cv::INTER_LINEAR); 72 | auto input_tensor = to_tensor(image); 73 | input_tensor.unsqueeze_(0); 74 | 75 | // inference 76 | std::vector inputs; 77 | inputs.push_back(input_tensor); 78 | 79 | auto output = m_module.forward(inputs).toTuple()->elements()[0].toTensor().squeeze(); 80 | // processing output to N x 6 where 6 is (cx, cy, w, h, confidence, class_num) 81 | vector processed_output_vec; 82 | 83 | processed_output_vec.push_back(output.index({Ellipsis, 0})); 84 | processed_output_vec.push_back(output.index({Ellipsis, 1})); 85 | processed_output_vec.push_back(output.index({Ellipsis, 2})); 86 | processed_output_vec.push_back(output.index({Ellipsis, 3})); 87 | 88 | auto max_indices_output = torch::max(output.index({Ellipsis, Slice(5, None, None)}), 1); 89 | auto class_scores = std::get<0>(max_indices_output) * output.index({Ellipsis, 4}); 90 | auto class_nums = 
std::get<1>(max_indices_output); 91 | 92 | processed_output_vec.push_back(class_scores); 93 | processed_output_vec.push_back(class_nums); 94 | 95 | auto processed_output = torch::stack(torch::TensorList(processed_output_vec), 1); 96 | auto filtered_output = processed_output.index({processed_output.index({Ellipsis, 4}) > m_threshold, Ellipsis}); 97 | if (filtered_output.sizes()[0] == 0) { 98 | return ret; 99 | } 100 | 101 | filtered_output.index_put_({Ellipsis, 0}, filtered_output.index({Ellipsis, 0}) / w_gain); 102 | filtered_output.index_put_({Ellipsis, 1}, filtered_output.index({Ellipsis, 1}) / h_gain); 103 | filtered_output.index_put_({Ellipsis, 2}, filtered_output.index({Ellipsis, 2}) / w_gain); 104 | filtered_output.index_put_({Ellipsis, 3}, filtered_output.index({Ellipsis, 3}) / h_gain); 105 | 106 | // change coords from center x, center y, width, height to xyxy 107 | auto filtered_output_shadow = filtered_output.clone(); 108 | filtered_output.index_put_({Ellipsis, 0}, (filtered_output_shadow.index({Ellipsis, 0}) - filtered_output_shadow.index({Ellipsis, 2}) / 2)); 109 | filtered_output.index_put_({Ellipsis, 1}, (filtered_output_shadow.index({Ellipsis, 1}) - filtered_output_shadow.index({Ellipsis, 3}) / 2)); 110 | filtered_output.index_put_({Ellipsis, 2}, (filtered_output_shadow.index({Ellipsis, 0}) + filtered_output_shadow.index({Ellipsis, 2}) / 2)); 111 | filtered_output.index_put_({Ellipsis, 3}, (filtered_output_shadow.index({Ellipsis, 1}) + filtered_output_shadow.index({Ellipsis, 3}) / 2)); 112 | 113 | auto inds = nms(filtered_output, 0.3); 114 | filtered_output = filtered_output.index({inds}); 115 | 116 | for (auto i = 0; i < filtered_output.sizes()[0]; i++) 117 | { 118 | vector box; 119 | box.push_back(filtered_output.index({i, 0}).item()); 120 | box.push_back(filtered_output.index({i, 1}).item()); 121 | box.push_back(filtered_output.index({i, 2}).item()); 122 | box.push_back(filtered_output.index({i, 3}).item()); 123 | 
box.push_back(filtered_output.index({i, 5}).item()); 124 | ret.push_back(box); 125 | } 126 | return ret; 127 | } 128 | 129 | torch::Tensor to_tensor(cv::Mat img) 130 | { 131 | auto tensor_image = torch::from_blob(img.data, {img.rows, img.cols, 3}, torch::kUInt8); 132 | tensor_image = tensor_image.permute({2, 0, 1}); 133 | auto tensor_image_normed = tensor_image / 255.0; 134 | return tensor_image_normed; 135 | }; 136 | 137 | private: 138 | float m_threshold; 139 | std::string m_model_path; 140 | torch::jit::Module m_module; 141 | }; 142 | 143 | int main(int argc, char **argv) 144 | { 145 | // example of read image from file then inference 146 | ifstream file_img(argv[2], ios::binary); 147 | file_img.seekg(0, std::ios::end); 148 | int buffer_length = file_img.tellg(); 149 | file_img.seekg(0, std::ios::beg); 150 | 151 | // Read image data into memory 152 | char *buffer = new char[buffer_length]; 153 | file_img.read(buffer, buffer_length); 154 | 155 | Yolo yolo(argv[1], 0.6f); 156 | auto preds = yolo.predict(buffer, buffer_length); 157 | 158 | auto output_image = cv::imread(argv[2], cv::COLOR_BGR2RGB); 159 | for (auto r : preds) 160 | { 161 | auto rec = cv::Rect(r[0], r[1], r[2] - r[0], r[3] - r[1]); 162 | cv::rectangle(output_image, rec, cv::Scalar(0, 255, 0), 2, 8, 0); 163 | } 164 | 165 | cv::imwrite("output.png", output_image); 166 | cout<<"See output.png"<