├── .gitignore
├── README.md
├── assets
├── api.png
├── cli.png
├── gui.png
└── predictions.jpg
├── pom.xml
└── src
└── main
├── java
└── com
│ └── example
│ ├── app
│ ├── ConfigReader.java
│ ├── api
│ │ ├── DetectionController.java
│ │ ├── ErrorCode.java
│ │ ├── ErrorResponse.java
│ │ ├── GlobalExceptionHandler.java
│ │ ├── ServingApplication.java
│ │ └── UploadFileException.java
│ ├── cli
│ │ └── CLI_App.java
│ ├── swing
│ │ └── SwingApp.java
│ └── video
│ │ └── VideoApp.java
│ └── yolo
│ ├── Detection.java
│ ├── ImageUtil.java
│ ├── ModelFactory.java
│ ├── NotImplementedException.java
│ ├── Yolo.java
│ ├── YoloV5.java
│ └── YoloV8.java
└── resources
├── coco.names
├── dog.jpg
├── dog.png
├── kite.jpg
├── model.properties
├── yolov5s.onnx
├── yolov5s.sparseml.onnx
└── yolov8s.onnx
/.gitignore:
--------------------------------------------------------------------------------
1 | target/
2 | !.mvn/wrapper/maven-wrapper.jar
3 | !**/src/main/**/target/
4 | !**/src/test/**/target/
5 |
6 | ### IntelliJ IDEA ###
7 | .idea/modules.xml
8 | .idea/jarRepositories.xml
9 | .idea/compiler.xml
10 | .idea/libraries/
11 | *.iws
12 | *.iml
13 | *.ipr
14 |
15 | ### Eclipse ###
16 | .apt_generated
17 | .classpath
18 | .factorypath
19 | .project
20 | .settings
21 | .springBeans
22 | .sts4-cache
23 |
24 | ### NetBeans ###
25 | /nbproject/private/
26 | /nbbuild/
27 | /dist/
28 | /nbdist/
29 | /.nb-gradle/
30 | build/
31 | !**/src/main/**/build/
32 | !**/src/test/**/build/
33 |
34 | ### VS Code ###
35 | .vscode/
36 |
37 | ### Mac OS ###
38 | .DS_Store
39 | .idea
40 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # ONNX Runtime Java Example: YOLOv5 and YOLOv8
2 |
3 | - An object detection example using ONNX Runtime and YOLO (v5 and v8)
4 |
5 | ![Predictions](assets/predictions.jpg)
6 |
7 | ```bash
8 | mvn clean compile
9 |
10 | # GUI App
11 | # mvn exec:java -Dexec.mainClass="com.example.app.swing.SwingApp" -Dexec.classpathScope=test
12 |
13 | # CLI APP
14 | # mvn exec:java -Dexec.mainClass="com.example.app.cli.CLI_App" -Dexec.classpathScope=test
15 |
16 | # API
17 | mvn spring-boot:run
18 | ```
19 |
20 | ## Configs
21 |
22 | - See `src/main/resources/model.properties`
23 |
24 | ```properties
25 | modelName=yolov5
26 | modelPath=yolov5s.onnx
27 | #modelName=yolov8
28 | #modelPath=yolov8s.onnx
29 | labelPath=coco.names
30 | confThreshold=0.25
31 | nmsThreshold=0.45
32 | gpuDeviceId=-1
33 | ```
34 |
35 | ## GUI Demo
36 |
37 |
38 | ![GUI demo](assets/gui.png)
39 |
40 | ## CLI Demo
41 |
42 | ![CLI demo](assets/cli.png)
43 |
44 | ## Rest API Demo
45 |
46 | - Endpoint: `POST /detection`
47 | - `multipart/form-data`
48 | - `name="uploadFile"`
49 |
50 | ![REST API demo](assets/api.png)
51 |
52 | ## Updates
53 |
54 | ### Feb 25, 2023
55 |
56 | - Added YOLOv8 support
57 | - Model settings are now configured through the properties file
58 |
--------------------------------------------------------------------------------
/assets/api.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jhgan00/java-ort-example-yolov5/5f6ea876f8663d439fd1e986d317368c3463d449/assets/api.png
--------------------------------------------------------------------------------
/assets/cli.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jhgan00/java-ort-example-yolov5/5f6ea876f8663d439fd1e986d317368c3463d449/assets/cli.png
--------------------------------------------------------------------------------
/assets/gui.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jhgan00/java-ort-example-yolov5/5f6ea876f8663d439fd1e986d317368c3463d449/assets/gui.png
--------------------------------------------------------------------------------
/assets/predictions.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jhgan00/java-ort-example-yolov5/5f6ea876f8663d439fd1e986d317368c3463d449/assets/predictions.jpg
--------------------------------------------------------------------------------
/pom.xml:
--------------------------------------------------------------------------------
1 |
2 |
5 | 4.0.0
6 |
7 | org.example
8 | yolov5
9 | 1.0-SNAPSHOT
10 |
11 |
12 | 18
13 | 18
14 | UTF-8
15 |
16 |
17 |
18 |
19 | org.springframework.boot
20 | spring-boot-starter-web
21 | 2.7.5
22 |
23 |
24 | com.microsoft.onnxruntime
25 | onnxruntime
26 | 1.13.1
27 |
28 |
29 | org.openpnp
30 | opencv
31 | 4.5.5-1
32 |
33 |
34 | com.google.code.gson
35 | gson
36 | 2.10
37 |
38 |
39 | org.projectlombok
40 | lombok
41 | 1.18.24
42 | compile
43 |
44 |
45 |
46 |
47 |
48 |
49 | org.springframework.boot
50 | spring-boot-maven-plugin
51 |
52 |
53 | --spring.profiles.active=dev
54 |
55 |
56 | 2.7.5
57 |
58 |
59 |
60 |
61 |
--------------------------------------------------------------------------------
/src/main/java/com/example/app/ConfigReader.java:
--------------------------------------------------------------------------------
1 | package com.example.app;
2 |
3 | import java.io.IOException;
4 | import java.io.InputStream;
5 | import java.util.Properties;
6 |
/**
 * Loads {@link Properties} files from the classpath.
 */
public class ConfigReader {

    /**
     * Reads a properties file from the classpath.
     *
     * @param propertiesFilePath classpath resource name such as {@code "model.properties"}; leading
     *                           {@code "./"} or {@code "/"} segments are ignored because
     *                           {@link ClassLoader#getResourceAsStream(String)} does not resolve them
     *                           (notably inside a jar)
     * @return the loaded properties
     * @throws IllegalArgumentException if no such resource exists on the classpath
     * @throws RuntimeException         wrapping the {@link IOException} if the resource cannot be read
     */
    public Properties readProperties(String propertiesFilePath) {
        String resourceName = toResourceName(propertiesFilePath);
        Properties properties = new Properties();
        // try-with-resources: the stream used to be left open after loading.
        try (InputStream inputStream = getClass().getClassLoader().getResourceAsStream(resourceName)) {
            if (inputStream == null) {
                // Previously this surfaced as an uninformative NullPointerException from Properties.load.
                throw new IllegalArgumentException("Properties file not found on classpath: " + propertiesFilePath);
            }
            properties.load(inputStream);
            return properties;
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    /** Strips leading "./" and "/" segments so the path is a valid class-loader resource name. */
    private static String toResourceName(String path) {
        String name = path;
        while (name.startsWith("./") || name.startsWith("/")) {
            name = name.startsWith("./") ? name.substring(2) : name.substring(1);
        }
        return name;
    }
}
21 |
--------------------------------------------------------------------------------
/src/main/java/com/example/app/api/DetectionController.java:
--------------------------------------------------------------------------------
1 | package com.example.app.api;
2 |
3 | import ai.onnxruntime.OrtException;
4 | import com.example.yolo.Detection;
5 | import com.example.yolo.ModelFactory;
6 | import com.example.yolo.NotImplementedException;
7 | import com.example.yolo.Yolo;
8 | import org.opencv.core.Mat;
9 | import org.opencv.core.MatOfByte;
10 | import org.opencv.imgcodecs.Imgcodecs;
11 | import org.slf4j.Logger;
12 | import org.slf4j.LoggerFactory;
13 | import org.springframework.web.bind.annotation.PostMapping;
14 | import org.springframework.web.bind.annotation.RestController;
15 | import org.springframework.web.multipart.MultipartFile;
16 |
17 | import java.io.IOException;
18 | import java.util.Arrays;
19 | import java.util.List;
20 |
21 |
22 | @RestController
23 | public class DetectionController {
24 | static private final List mimeTypes = Arrays.asList("image/png", "image/jpeg");
25 | private final Yolo inferenceSession;
26 | private final Logger LOGGER = LoggerFactory.getLogger(DetectionController.class);
27 |
28 | private DetectionController() throws OrtException, IOException, NotImplementedException {
29 | ModelFactory modelFactory = new ModelFactory();
30 | this.inferenceSession = modelFactory.getModel("model.properties");
31 | }
32 |
33 | @PostMapping(value = "/detection", consumes = {"multipart/form-data"}, produces = {"application/json"})
34 | public List detection(MultipartFile uploadFile) throws OrtException, IOException {
35 | if (!mimeTypes.contains(uploadFile.getContentType())) throw new UploadFileException(ErrorCode.INVALID_MIME_TYPE);
36 | byte[] bytes = uploadFile.getBytes();
37 | Mat img = Imgcodecs.imdecode(new MatOfByte(bytes), Imgcodecs.IMREAD_COLOR);
38 | List result = inferenceSession.run(img);
39 | LOGGER.info("POST 200");
40 | return result;
41 | }
42 | }
43 |
--------------------------------------------------------------------------------
/src/main/java/com/example/app/api/ErrorCode.java:
--------------------------------------------------------------------------------
1 | package com.example.app.api;
2 |
3 | import lombok.AllArgsConstructor;
4 | import lombok.Getter;
5 | import org.springframework.http.HttpStatus;
6 |
7 |
8 | @AllArgsConstructor
9 | @Getter
10 | public enum ErrorCode {
11 | INVALID_MIME_TYPE(HttpStatus.UNSUPPORTED_MEDIA_TYPE, "invalid mime type"),
12 | INVALID_UPLOAD_FILE(HttpStatus.BAD_REQUEST, "invalid file");
13 | private final HttpStatus httpStatus;
14 | private final String message;
15 |
16 | }
17 |
--------------------------------------------------------------------------------
/src/main/java/com/example/app/api/ErrorResponse.java:
--------------------------------------------------------------------------------
1 | package com.example.app.api;
2 |
3 | import lombok.Builder;
4 | import lombok.Getter;
5 | import org.springframework.http.ResponseEntity;
6 |
7 | import java.time.LocalDateTime;
8 |
9 | @Getter
10 | @Builder
11 | public class ErrorResponse {
12 | private final LocalDateTime timestamp = LocalDateTime.now();
13 | private final int status;
14 | private final String error;
15 | private final String code;
16 | private final String message;
17 |
18 | public static ResponseEntity toResponseEntity(ErrorCode errorCode) {
19 | return ResponseEntity
20 | .status(errorCode.getHttpStatus())
21 | .body(ErrorResponse.builder()
22 | .status(errorCode.getHttpStatus().value())
23 | .error(errorCode.getHttpStatus().name())
24 | .code(errorCode.name())
25 | .message(errorCode.getMessage())
26 | .build()
27 | );
28 | }
29 |
30 | }
31 |
--------------------------------------------------------------------------------
/src/main/java/com/example/app/api/GlobalExceptionHandler.java:
--------------------------------------------------------------------------------
1 | package com.example.app.api;
2 |
3 | import lombok.extern.slf4j.Slf4j;
4 | import org.springframework.http.ResponseEntity;
5 | import org.springframework.web.bind.annotation.ExceptionHandler;
6 | import org.springframework.web.bind.annotation.RestControllerAdvice;
7 |
8 | @Slf4j
9 | @RestControllerAdvice
10 | public class GlobalExceptionHandler {
11 |
12 | @ExceptionHandler(value = {UploadFileException.class})
13 | protected ResponseEntity handleUploadFileException(UploadFileException e) {
14 | log.error("handleUploadFileException throw UploadFileException : {}", e.getErrorCode());
15 | return ErrorResponse.toResponseEntity(e.getErrorCode());
16 | }
17 | }
18 |
--------------------------------------------------------------------------------
/src/main/java/com/example/app/api/ServingApplication.java:
--------------------------------------------------------------------------------
1 | package com.example.app.api;
2 |
3 | import org.springframework.boot.SpringApplication;
4 | import org.springframework.boot.autoconfigure.SpringBootApplication;
5 |
6 |
7 | @SpringBootApplication
8 | public class ServingApplication {
9 | public static void main(String[] args) {
10 | SpringApplication.run(ServingApplication.class, args);
11 | }
12 |
13 | }
14 |
--------------------------------------------------------------------------------
/src/main/java/com/example/app/api/UploadFileException.java:
--------------------------------------------------------------------------------
1 | package com.example.app.api;
2 |
3 | import lombok.AllArgsConstructor;
4 | import lombok.Getter;
5 |
6 | @Getter
7 | @AllArgsConstructor
8 | public class UploadFileException extends RuntimeException {
9 | private final ErrorCode errorCode;
10 | }
11 |
--------------------------------------------------------------------------------
/src/main/java/com/example/app/cli/CLI_App.java:
--------------------------------------------------------------------------------
1 | package com.example.app.cli;
2 |
3 | import ai.onnxruntime.OrtException;
4 | import com.example.yolo.*;
5 | import com.google.gson.Gson;
6 | import org.opencv.core.Mat;
7 | import org.opencv.imgcodecs.Imgcodecs;
8 |
9 | import java.io.*;
10 | import java.util.List;
11 |
12 | public class CLI_App {
13 |
14 | private Yolo inferenceSession;
15 | private Gson gson;
16 |
17 | public CLI_App() {
18 |
19 | ModelFactory modelFactory = new ModelFactory();
20 |
21 | try {
22 | this.inferenceSession = modelFactory.getModel("./model.properties");
23 | this.gson = new Gson();
24 | } catch (OrtException | IOException exception) {
25 | exception.printStackTrace();
26 | System.exit(1);
27 | }
28 | }
29 |
30 | public static void main(String[] args) {
31 |
32 | CLI_App app = new CLI_App();
33 |
34 | while (true) {
35 |
36 | System.out.print("Enter image path (enter 'q' or 'Q' to exit): ");
37 |
38 | BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
39 | String input = null;
40 |
41 | try {
42 | input = br.readLine();
43 | } catch (IOException e) {
44 | e.printStackTrace();
45 | System.exit(1);
46 | }
47 |
48 | if ("q".equals(input) | "Q".equals(input)) {
49 | System.out.println("Exit");
50 | System.exit(0);
51 | }
52 |
53 | File f = new File(input);
54 | if (!f.exists()) {
55 | System.out.println("File does not exists: " + input);
56 | continue;
57 | }
58 | if (f.isDirectory()) {
59 | System.out.println(input + " is a directory");
60 | continue;
61 | }
62 |
63 | Mat img = Imgcodecs.imread(input, Imgcodecs.IMREAD_COLOR);
64 | if (img.dataAddr() == 0) {
65 | System.out.println("Could not open image: " + input);
66 | continue;
67 | }
68 | // run detection
69 | try {
70 | List detectionList = app.inferenceSession.run(img);
71 | ImageUtil.drawPredictions(img, detectionList);
72 | System.out.println(app.gson.toJson(detectionList));
73 | Imgcodecs.imwrite("predictions.jpg", img);
74 | } catch (OrtException ortException) {
75 | ortException.printStackTrace();
76 | }
77 |
78 | }
79 |
80 | }
81 |
82 | }
83 |
--------------------------------------------------------------------------------
/src/main/java/com/example/app/swing/SwingApp.java:
--------------------------------------------------------------------------------
1 | package com.example.app.swing;
2 |
3 | import ai.onnxruntime.OrtException;
4 | import com.example.yolo.Detection;
5 | import com.example.yolo.ImageUtil;
6 | import com.example.yolo.ModelFactory;
7 | import com.example.yolo.Yolo;
8 | import org.opencv.core.Mat;
9 | import org.opencv.core.MatOfByte;
10 | import org.opencv.imgcodecs.Imgcodecs;
11 |
12 | import javax.swing.*;
13 | import javax.swing.filechooser.FileNameExtensionFilter;
14 | import java.awt.*;
15 | import java.awt.event.ActionEvent;
16 | import java.awt.event.ActionListener;
17 | import java.io.IOException;
18 | import java.util.List;
19 |
20 |
21 | public class SwingApp extends JFrame implements ActionListener {
22 |
23 | private final JMenuItem openItem;
24 | private final JLabel label;
25 | private Yolo inferenceSession;
26 |
27 | public SwingApp() {
28 |
29 | ModelFactory modelFactory = new ModelFactory();
30 |
31 | try {
32 | this.inferenceSession = modelFactory.getModel("model.properties");
33 | } catch (OrtException | IOException exception) {
34 | exception.printStackTrace();
35 | System.exit(1);
36 | }
37 |
38 | setTitle("Yolo Demo");
39 | setSize(640, 640);
40 | JMenuBar mbar = new JMenuBar();
41 | JMenu m = new JMenu("File");
42 | openItem = new JMenuItem("Open");
43 | openItem.addActionListener(this);
44 | m.add(openItem);
45 | mbar.add(m);
46 | setJMenuBar(mbar);
47 | label = new JLabel();
48 | Container contentPane = getContentPane();
49 | contentPane.add(label, "Center");
50 | setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE);
51 | setVisible(true);
52 | setResizable(true);
53 |
54 | }
55 |
56 |
57 | public static void main(String[] args) {
58 | new SwingApp();
59 | }
60 |
61 | @Override
62 | public void actionPerformed(ActionEvent e) {
63 | Object source = e.getSource();
64 |
65 | if (source == openItem) {
66 |
67 | JFileChooser chooser = new JFileChooser();
68 | chooser.setFileFilter(new FileNameExtensionFilter(
69 | "Image files",
70 | "png", "jpg", "jpeg")
71 | );
72 |
73 | int r = chooser.showOpenDialog(this);
74 |
75 | if (r == JFileChooser.APPROVE_OPTION) {
76 | String name = chooser.getSelectedFile().getPath();
77 |
78 | // open image file
79 | Mat img = Imgcodecs.imread(name);
80 | int MAX_SIZE = 720;
81 | if (Math.max(img.width(), img.height()) > MAX_SIZE) {
82 | ImageUtil.resizeWithPadding(img, img, MAX_SIZE, MAX_SIZE);
83 | }
84 |
85 | // run detection
86 | try {
87 | List detectionList = inferenceSession.run(img);
88 | ImageUtil.drawPredictions(img, detectionList);
89 | } catch (OrtException ortException) {
90 | ortException.printStackTrace();
91 | }
92 |
93 | // Displaying in the GUI
94 | MatOfByte matOfByte = new MatOfByte();
95 | Imgcodecs.imencode(".png", img, matOfByte);
96 |
97 | ImageIcon imageIcon = new ImageIcon(matOfByte.toArray());
98 | imageIcon.getImage().flush();
99 | label.setIcon(imageIcon);
100 | pack();
101 |
102 | }
103 | }
104 | }
105 | }
106 |
--------------------------------------------------------------------------------
/src/main/java/com/example/app/video/VideoApp.java:
--------------------------------------------------------------------------------
1 | package com.example.app.video;
2 |
3 | import ai.onnxruntime.OrtException;
4 | import com.example.yolo.Detection;
5 | import com.example.yolo.ImageUtil;
6 | import com.example.yolo.ModelFactory;
7 | import com.example.yolo.Yolo;
8 | import org.opencv.core.Mat;
9 | import org.opencv.core.MatOfByte;
10 | import org.opencv.imgcodecs.Imgcodecs;
11 | import org.opencv.videoio.VideoCapture;
12 | import org.opencv.videoio.Videoio;
13 |
14 | import javax.swing.*;
15 | import java.io.IOException;
16 | import java.util.List;
17 |
18 |
19 | public class VideoApp {
20 |
21 | private Yolo inferenceSession;
22 |
23 | public VideoApp() {
24 | ModelFactory modelFactory = new ModelFactory();
25 | try {
26 | this.inferenceSession = modelFactory.getModel("model.properties");
27 | } catch (OrtException | IOException exception) {
28 | exception.printStackTrace();
29 | System.exit(1);
30 | }
31 | }
32 |
33 | public static void main(String[] args) {
34 | nu.pattern.OpenCV.loadLocally();
35 | VideoApp app = new VideoApp();
36 |
37 | VideoCapture cap = new VideoCapture(0);
38 | cap.set(Videoio.CAP_PROP_FRAME_WIDTH, 1280);
39 | cap.set(Videoio.CAP_PROP_FRAME_HEIGHT, 800);
40 |
41 | JFrame jframe = new JFrame("Title");
42 | jframe.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE);
43 | JLabel vidpanel = new JLabel();
44 | jframe.setContentPane(vidpanel);
45 | jframe.setVisible(true);
46 | jframe.setResizable(true);
47 | jframe.setSize(1280, 800);
48 |
49 | if (!cap.isOpened()) {
50 | System.exit(1);
51 | }
52 |
53 | Mat frame = new Mat();
54 | MatOfByte matofByte = new MatOfByte();
55 |
56 | while ( cap.read(frame) ) {
57 |
58 | ImageUtil.resizeWithPadding(frame, frame, 1280, 800);
59 |
60 | // run detection
61 | try {
62 | List detectionList = app.inferenceSession.run(frame);
63 | ImageUtil.drawPredictions(frame, detectionList);
64 | } catch (OrtException ortException) {
65 | ortException.printStackTrace();
66 | }
67 | Imgcodecs.imencode(".jpg", frame, matofByte);
68 | ImageIcon imageIcon = new ImageIcon(matofByte.toArray());
69 | vidpanel.setIcon(imageIcon);
70 | vidpanel.repaint();
71 | }
72 |
73 | cap.release();
74 | }
75 |
76 | }
77 |
--------------------------------------------------------------------------------
/src/main/java/com/example/yolo/Detection.java:
--------------------------------------------------------------------------------
1 | package com.example.yolo;
2 |
3 | public record Detection(String label, float[] bbox, float confidence) {
4 |
5 | }
6 |
--------------------------------------------------------------------------------
/src/main/java/com/example/yolo/ImageUtil.java:
--------------------------------------------------------------------------------
1 | package com.example.yolo;
2 |
3 | import org.opencv.core.*;
4 | import org.opencv.imgproc.Imgproc;
5 |
6 | import java.util.List;
7 |
8 |
9 | public class ImageUtil {
10 |
11 | public static Mat resizeWithPadding(Mat src, int width, int height) {
12 |
13 | Mat dst = new Mat();
14 | int oldW = src.width();
15 | int oldH = src.height();
16 |
17 | double r = Math.min((double) width / oldW, (double) height / oldH);
18 |
19 | int newUnpadW = (int) Math.round(oldW * r);
20 | int newUnpadH = (int) Math.round(oldH * r);
21 |
22 | int dw = (width - newUnpadW) / 2;
23 | int dh = (height - newUnpadH) / 2;
24 |
25 | int top = (int) Math.round(dh - 0.1);
26 | int bottom = (int) Math.round(dh + 0.1);
27 | int left = (int) Math.round(dw - 0.1);
28 | int right = (int) Math.round(dw + 0.1);
29 |
30 | Imgproc.resize(src, dst, new Size(newUnpadW, newUnpadH));
31 | Core.copyMakeBorder(dst, dst, top, bottom, left, right, Core.BORDER_CONSTANT);
32 |
33 | return dst;
34 |
35 | }
36 |
37 | public static void resizeWithPadding(Mat src, Mat dst, int width, int height) {
38 |
39 | int oldW = src.width();
40 | int oldH = src.height();
41 |
42 | double r = Math.min((double) width / oldW, (double) height / oldH);
43 |
44 | int newUnpadW = (int) Math.round(oldW * r);
45 | int newUnpadH = (int) Math.round(oldH * r);
46 |
47 | int dw = (width - newUnpadW) / 2;
48 | int dh = (height - newUnpadH) / 2;
49 |
50 | int top = (int) Math.round(dh - 0.1);
51 | int bottom = (int) Math.round(dh + 0.1);
52 | int left = (int) Math.round(dw - 0.1);
53 | int right = (int) Math.round(dw + 0.1);
54 |
55 | Imgproc.resize(src, dst, new Size(newUnpadW, newUnpadH));
56 | Core.copyMakeBorder(dst, dst, top, bottom, left, right, Core.BORDER_CONSTANT);
57 |
58 | }
59 |
60 | public static void whc2cwh(float[] src, float[] dst, int start) {
61 | int j = start;
62 | for (int ch = 0; ch < 3; ++ch) {
63 | for (int i = ch; i < src.length; i += 3) {
64 | dst[j] = src[i];
65 | j++;
66 | }
67 | }
68 | }
69 |
70 | public static float[] whc2cwh(float[] src) {
71 | float[] chw = new float[src.length];
72 | int j = 0;
73 | for (int ch = 0; ch < 3; ++ch) {
74 | for (int i = ch; i < src.length; i += 3) {
75 | chw[j] = src[i];
76 | j++;
77 | }
78 | }
79 | return chw;
80 | }
81 |
82 | public static byte[] whc2cwh(byte[] src) {
83 | byte[] chw = new byte[src.length];
84 | int j = 0;
85 | for (int ch = 0; ch < 3; ++ch) {
86 | for (int i = ch; i < src.length; i += 3) {
87 | chw[j] = src[i];
88 | j++;
89 | }
90 | }
91 | return chw;
92 | }
93 |
94 | public static void drawPredictions(Mat img, List detectionList) {
95 | // debugging image
96 | for (Detection detection : detectionList) {
97 |
98 | float[] bbox = detection.bbox();
99 | Scalar color = new Scalar(249, 218, 60);
100 | Imgproc.rectangle(img, //Matrix obj of the image
101 | new Point(bbox[0], bbox[1]), //p1
102 | new Point(bbox[2], bbox[3]), //p2
103 | color, //Scalar object for color
104 | 2 //Thickness of the line
105 | );
106 | Imgproc.putText(
107 | img,
108 | detection.label(),
109 | new Point(bbox[0] - 1, bbox[1] - 5),
110 | Imgproc.FONT_HERSHEY_SIMPLEX,
111 | .5, color,
112 | 1);
113 | }
114 | }
115 |
116 | }
117 |
--------------------------------------------------------------------------------
/src/main/java/com/example/yolo/ModelFactory.java:
--------------------------------------------------------------------------------
1 | package com.example.yolo;
2 |
3 | import ai.onnxruntime.OrtException;
4 | import com.example.app.ConfigReader;
5 | import org.springframework.util.ResourceUtils;
6 |
7 | import java.io.File;
8 | import java.io.IOException;
9 | import java.util.Properties;
10 |
11 | public class ModelFactory {
12 |
13 |
14 | public Yolo getModel(String propertiesFilePath) throws IOException, OrtException, NotImplementedException {
15 |
16 | ConfigReader configReader = new ConfigReader();
17 | Properties properties = configReader.readProperties(propertiesFilePath);
18 |
19 | String modelName = properties.getProperty("modelName");
20 | File file = ResourceUtils.getFile("classpath:" + properties.getProperty("modelPath"));
21 | File file2 = ResourceUtils.getFile("classpath:coco.names");
22 | String modelPath = String.valueOf((file));
23 | String labelPath = String.valueOf((file2));
24 | float confThreshold = Float.parseFloat(properties.getProperty("confThreshold"));
25 | float nmsThreshold = Float.parseFloat(properties.getProperty("nmsThreshold"));
26 | int gpuDeviceId = Integer.parseInt(properties.getProperty("gpuDeviceId"));
27 |
28 | if (modelName.equalsIgnoreCase("yolov5")) {
29 | return new YoloV5(modelPath, labelPath, confThreshold, nmsThreshold, gpuDeviceId);
30 | }
31 | else if (modelName.equalsIgnoreCase("yolov8")) {
32 | return new YoloV8(modelPath, labelPath, confThreshold, nmsThreshold, gpuDeviceId);
33 | }
34 | else {
35 | throw new NotImplementedException();
36 | }
37 |
38 | }
39 | }
40 |
--------------------------------------------------------------------------------
/src/main/java/com/example/yolo/NotImplementedException.java:
--------------------------------------------------------------------------------
1 | package com.example.yolo;
2 |
3 |
/**
 * Thrown when a requested feature has no implementation, e.g. an unsupported model name.
 */
public class NotImplementedException extends RuntimeException {
    private static final long serialVersionUID = 1L;

    /** Creates the exception without a detail message. */
    public NotImplementedException() {}

    /**
     * Creates the exception with a detail message describing what is not implemented.
     *
     * @param message the detail message
     */
    public NotImplementedException(String message) {
        super(message);
    }
}
7 |
--------------------------------------------------------------------------------
/src/main/java/com/example/yolo/Yolo.java:
--------------------------------------------------------------------------------
1 | package com.example.yolo;
2 |
3 | import ai.onnxruntime.*;
4 | import org.opencv.core.Mat;
5 |
6 | import java.io.BufferedReader;
7 | import java.io.FileReader;
8 | import java.io.IOException;
9 | import java.util.*;
10 | import java.util.stream.Collectors;
11 |
12 | public abstract class Yolo {
13 |
14 | public static final int INPUT_SIZE = 640;
15 | public static final int NUM_INPUT_ELEMENTS = 3 * 640 * 640;
16 | public static final long[] INPUT_SHAPE = {1, 3, 640, 640};
17 | public float confThreshold;
18 | public float nmsThreshold;
19 | public OnnxJavaType inputType;
20 | protected final OrtEnvironment env;
21 | protected final OrtSession session;
22 | protected final String inputName;
23 | public ArrayList labelNames;
24 |
25 | OnnxTensor inputTensor;
26 |
27 | public Yolo(String modelPath, String labelPath, float confThreshold, float nmsThreshold, int gpuDeviceId) throws OrtException, IOException {
28 | nu.pattern.OpenCV.loadLocally();
29 |
30 | this.env = OrtEnvironment.getEnvironment();
31 | var sessionOptions = new OrtSession.SessionOptions();
32 | sessionOptions.addCPU(false);
33 | sessionOptions.setOptimizationLevel(OrtSession.SessionOptions.OptLevel.ALL_OPT);
34 |
35 | if (gpuDeviceId >= 0) sessionOptions.addCUDA(gpuDeviceId);
36 | this.session = this.env.createSession(modelPath, sessionOptions);
37 |
38 | Map inputMetaMap = this.session.getInputInfo();
39 | this.inputName = this.session.getInputNames().iterator().next();
40 | NodeInfo inputMeta = inputMetaMap.get(this.inputName);
41 | this.inputType = ((TensorInfo) inputMeta.getInfo()).type;
42 |
43 | this.confThreshold = confThreshold;
44 | this.nmsThreshold = nmsThreshold;
45 |
46 | BufferedReader br = new BufferedReader(new FileReader(labelPath));
47 | String line;
48 | this.labelNames = new ArrayList<>();
49 | while ((line = br.readLine()) != null) {
50 | this.labelNames.add(line);
51 | }
52 |
53 | }
54 | public abstract List run(Mat img) throws OrtException;
55 |
56 | private float computeIOU(float[] box1, float[] box2) {
57 |
58 | float area1 = (box1[2] - box1[0]) * (box1[3] - box1[1]);
59 | float area2 = (box2[2] - box2[0]) * (box2[3] - box2[1]);
60 |
61 | float left = Math.max(box1[0], box2[0]);
62 | float top = Math.max(box1[1], box2[1]);
63 | float right = Math.min(box1[2], box2[2]);
64 | float bottom = Math.min(box1[3], box2[3]);
65 |
66 | float interArea = Math.max(right - left, 0) * Math.max(bottom - top, 0);
67 | float unionArea = area1 + area2 - interArea;
68 | return Math.max(interArea / unionArea, 1e-8f);
69 |
70 | }
71 |
72 | protected List nonMaxSuppression(List bboxes, float iouThreshold) {
73 |
74 | // output boxes
75 | List bestBboxes = new ArrayList<>();
76 |
77 | // confidence 순 정렬
78 | bboxes.sort(Comparator.comparing(a -> a[4]));
79 |
80 | // standard nms
81 | while (!bboxes.isEmpty()) {
82 | float[] bestBbox = bboxes.remove(bboxes.size() - 1); // 현재 가장 confidence 가 높은 박스 pop
83 | bestBboxes.add(bestBbox);
84 | bboxes = bboxes.stream().filter(a -> computeIOU(a, bestBbox) < iouThreshold).collect(Collectors.toList());
85 | }
86 |
87 | return bestBboxes;
88 | }
89 |
90 | protected void xywh2xyxy(float[] bbox) {
91 | float x = bbox[0];
92 | float y = bbox[1];
93 | float w = bbox[2];
94 | float h = bbox[3];
95 |
96 | bbox[0] = x - w * 0.5f;
97 | bbox[1] = y - h * 0.5f;
98 | bbox[2] = x + w * 0.5f;
99 | bbox[3] = y + h * 0.5f;
100 | }
101 |
102 | protected void scaleCoords(float[] bbox, float orgW, float orgH, float padW, float padH, float gain) {
103 | // xmin, ymin, xmax, ymax -> (xmin_org, ymin_org, xmax_org, ymax_org)
104 | bbox[0] = Math.max(0, Math.min(orgW - 1, (bbox[0] - padW) / gain));
105 | bbox[1] = Math.max(0, Math.min(orgH - 1, (bbox[1] - padH) / gain));
106 | bbox[2] = Math.max(0, Math.min(orgW - 1, (bbox[2] - padW) / gain));
107 | bbox[3] = Math.max(0, Math.min(orgH - 1, (bbox[3] - padH) / gain));
108 | }
109 |
110 | static int argmax(float[] a) {
111 | float re = -Float.MAX_VALUE;
112 | int arg = -1;
113 | for (int i = 0; i < a.length; i++) {
114 | if (a[i] >= re) {
115 | re = a[i];
116 | arg = i;
117 | }
118 | }
119 | return arg;
120 | }
121 |
122 | }
123 |
--------------------------------------------------------------------------------
/src/main/java/com/example/yolo/YoloV5.java:
--------------------------------------------------------------------------------
1 | package com.example.yolo;
2 |
3 | import ai.onnxruntime.*;
4 | import org.opencv.core.CvType;
5 | import org.opencv.core.Mat;
6 | import org.opencv.imgproc.Imgproc;
7 |
8 | import java.io.IOException;
9 | import java.nio.ByteBuffer;
10 | import java.nio.FloatBuffer;
11 | import java.util.*;
12 |
13 | public class YoloV5 extends Yolo {
14 |
15 | public YoloV5(String modelPath, String labelPath, float confThreshold, float nmsThreshold, int gpuDeviceId) throws OrtException, IOException {
16 | super(modelPath, labelPath, confThreshold, nmsThreshold, gpuDeviceId);
17 | }
18 |
19 | public List run(Mat img) throws OrtException {
20 |
21 | float orgW = (float) img.size().width;
22 | float orgH = (float) img.size().height;
23 |
24 | float gain = Math.min((float) INPUT_SIZE / orgW, (float) INPUT_SIZE / orgH);
25 | float padW = (INPUT_SIZE - orgW * gain) * 0.5f;
26 | float padH = (INPUT_SIZE - orgH * gain) * 0.5f;
27 |
28 | // preprocessing
29 | Map inputContainer = this.preprocess(img);
30 |
31 | // Run inference
32 | float[][] predictions;
33 |
34 | OrtSession.Result results = this.session.run(inputContainer);
35 | predictions = ((float[][][]) results.get(0).getValue())[0];
36 |
37 | // postprocessing
38 | return postprocess(predictions, orgW, orgH, padW, padH, gain);
39 | }
40 |
41 |
42 | public Map preprocess(Mat img) throws OrtException {
43 |
44 | // Resizing
45 | Mat resizedImg = new Mat();
46 | ImageUtil.resizeWithPadding(img, resizedImg, INPUT_SIZE, INPUT_SIZE);
47 |
48 | // BGR -> RGB
49 | Imgproc.cvtColor(resizedImg, resizedImg, Imgproc.COLOR_BGR2RGB);
50 |
51 | Map container = new HashMap<>();
52 |
53 | // if model is quantized
54 | if (this.inputType.equals(OnnxJavaType.UINT8)) {
55 | byte[] whc = new byte[NUM_INPUT_ELEMENTS];
56 | resizedImg.get(0, 0, whc);
57 | byte[] chw = ImageUtil.whc2cwh(whc);
58 | ByteBuffer inputBuffer = ByteBuffer.wrap(chw);
59 | inputTensor = OnnxTensor.createTensor(this.env, inputBuffer, INPUT_SHAPE, this.inputType);
60 | } else {
61 | // Normalization
62 | resizedImg.convertTo(resizedImg, CvType.CV_32FC1, 1. / 255);
63 | float[] whc = new float[NUM_INPUT_ELEMENTS];
64 | resizedImg.get(0, 0, whc);
65 | float[] chw = ImageUtil.whc2cwh(whc);
66 | FloatBuffer inputBuffer = FloatBuffer.wrap(chw);
67 | inputTensor = OnnxTensor.createTensor(this.env, inputBuffer, INPUT_SHAPE);
68 | }
69 |
70 | // To OnnxTensor
71 | container.put(this.inputName, inputTensor);
72 |
73 | return container;
74 | }
75 |
76 | public List postprocess(float[][] outputs, float orgW, float orgH, float padW, float padH, float gain) {
77 |
78 | // predictions
79 | Map> class2Bbox = new HashMap<>();
80 |
81 | for (float[] bbox : outputs) {
82 |
83 | float conf = bbox[4];
84 | if (conf < this.confThreshold) continue;
85 |
86 | float[] conditionalProbabilities = Arrays.copyOfRange(bbox, 5, 85);
87 | int label = Yolo.argmax(conditionalProbabilities);
88 |
89 | // xywh to (x1, y1, x2, y2)
90 | xywh2xyxy(bbox);
91 |
92 | // skip invalid predictions
93 | if (bbox[0] >= bbox[2] || bbox[1] >= bbox[3]) continue;
94 |
95 | // xmin, ymin, xmax, ymax -> (xmin_org, ymin_org, xmax_org, ymax_org)
96 | scaleCoords(bbox, orgW, orgH, padW, padH, gain);
97 | class2Bbox.putIfAbsent(label, new ArrayList<>());
98 | class2Bbox.get(label).add(bbox);
99 | }
100 |
101 | // Apply Non-max suppression for each class
102 | List detections = new ArrayList<>();
103 | for (Map.Entry> entry : class2Bbox.entrySet()) {
104 | int label = entry.getKey();
105 | List bboxes = entry.getValue();
106 | bboxes = nonMaxSuppression(bboxes, this.nmsThreshold);
107 | for (float[] bbox : bboxes) {
108 | String labelString = this.labelNames.get(label);
109 | detections.add(new Detection(labelString, Arrays.copyOfRange(bbox, 0, 4), bbox[4]));
110 | }
111 | }
112 |
113 | return detections;
114 | }
115 | }
--------------------------------------------------------------------------------
/src/main/java/com/example/yolo/YoloV8.java:
--------------------------------------------------------------------------------
1 | package com.example.yolo;
2 |
3 | import ai.onnxruntime.*;
4 | import org.opencv.core.CvType;
5 | import org.opencv.core.Mat;
6 | import org.opencv.imgproc.Imgproc;
7 |
8 | import java.io.IOException;
9 | import java.nio.ByteBuffer;
10 | import java.nio.FloatBuffer;
11 | import java.util.*;
12 |
13 | public class YoloV8 extends Yolo{
14 |
15 | public YoloV8(String modelPath, String labelPath, float confThreshold, float nmsThreshold, int gpuDeviceId) throws OrtException, IOException {
16 | super(modelPath, labelPath, confThreshold, nmsThreshold, gpuDeviceId);
17 | }
18 |
19 | public List run(Mat img) throws OrtException {
20 |
21 | float orgW = (float) img.size().width;
22 | float orgH = (float) img.size().height;
23 |
24 | float gain = Math.min((float) INPUT_SIZE / orgW, (float) INPUT_SIZE / orgH);
25 | float padW = (INPUT_SIZE - orgW * gain) * 0.5f;
26 | float padH = (INPUT_SIZE - orgH * gain) * 0.5f;
27 |
28 | // preprocessing
29 | Map inputContainer = this.preprocess(img);
30 |
31 | // Run inference
32 | float[][] predictions;
33 |
34 | OrtSession.Result results = this.session.run(inputContainer);
35 | predictions = ((float[][][]) results.get(0).getValue())[0];
36 |
37 | // postprocessing
38 | return postprocess(predictions, orgW, orgH, padW, padH, gain);
39 | }
40 |
41 |
42 | public Map preprocess(Mat img) throws OrtException {
43 |
44 | // Resizing
45 | Mat resizedImg = new Mat();
46 | ImageUtil.resizeWithPadding(img, resizedImg, INPUT_SIZE, INPUT_SIZE);
47 |
48 | // BGR -> RGB
49 | Imgproc.cvtColor(resizedImg, resizedImg, Imgproc.COLOR_BGR2RGB);
50 |
51 | Map container = new HashMap<>();
52 |
53 | // if model is quantized
54 | if (this.inputType.equals(OnnxJavaType.UINT8)) {
55 | byte[] whc = new byte[NUM_INPUT_ELEMENTS];
56 | resizedImg.get(0, 0, whc);
57 | byte[] chw = ImageUtil.whc2cwh(whc);
58 | ByteBuffer inputBuffer = ByteBuffer.wrap(chw);
59 | inputTensor = OnnxTensor.createTensor(this.env, inputBuffer, INPUT_SHAPE, this.inputType);
60 | } else {
61 | // Normalization
62 | resizedImg.convertTo(resizedImg, CvType.CV_32FC1, 1. / 255);
63 | float[] whc = new float[NUM_INPUT_ELEMENTS];
64 | resizedImg.get(0, 0, whc);
65 | float[] chw = ImageUtil.whc2cwh(whc);
66 | FloatBuffer inputBuffer = FloatBuffer.wrap(chw);
67 | inputTensor = OnnxTensor.createTensor(this.env, inputBuffer, INPUT_SHAPE);
68 | }
69 |
70 | // To OnnxTensor
71 | container.put(this.inputName, inputTensor);
72 |
73 | return container;
74 | }
75 |
76 | public List postprocess(float[][] outputs, float orgW, float orgH, float padW, float padH, float gain) {
77 |
78 | // predictions
79 | outputs = transposeMatrix(outputs);
80 | Map> class2Bbox = new HashMap<>();
81 |
82 | for (float[] bbox : outputs) {
83 |
84 |
85 | float[] conditionalProbabilities = Arrays.copyOfRange(bbox, 4, 84);
86 | int label = Yolo.argmax(conditionalProbabilities);
87 | float conf = conditionalProbabilities[label];
88 | if (conf < this.confThreshold) continue;
89 |
90 | bbox[4] = conf;
91 |
92 | // xywh to (x1, y1, x2, y2)
93 | xywh2xyxy(bbox);
94 |
95 | // skip invalid predictions
96 | if (bbox[0] >= bbox[2] || bbox[1] >= bbox[3]) continue;
97 |
98 | // xmin, ymin, xmax, ymax -> (xmin_org, ymin_org, xmax_org, ymax_org)
99 | scaleCoords(bbox, orgW, orgH, padW, padH, gain);
100 | class2Bbox.putIfAbsent(label, new ArrayList<>());
101 | class2Bbox.get(label).add(bbox);
102 | }
103 |
104 | // Apply Non-max suppression for each class
105 | List detections = new ArrayList<>();
106 | for (Map.Entry> entry : class2Bbox.entrySet()) {
107 | int label = entry.getKey();
108 | List bboxes = entry.getValue();
109 | bboxes = nonMaxSuppression(bboxes, this.nmsThreshold);
110 | for (float[] bbox : bboxes) {
111 | String labelString = this.labelNames.get(label);
112 | detections.add(new Detection(labelString, Arrays.copyOfRange(bbox, 0, 4), bbox[4]));
113 | }
114 | }
115 |
116 | return detections;
117 | }
118 |
119 | public static float[][] transposeMatrix(float [][] m){
120 | float[][] temp = new float[m[0].length][m.length];
121 | for (int i = 0; i < m.length; i++)
122 | for (int j = 0; j < m[0].length; j++)
123 | temp[j][i] = m[i][j];
124 | return temp;
125 | }
126 | }
127 |
--------------------------------------------------------------------------------
/src/main/resources/coco.names:
--------------------------------------------------------------------------------
1 | person
2 | bicycle
3 | car
4 | motorbike
5 | aeroplane
6 | bus
7 | train
8 | truck
9 | boat
10 | traffic light
11 | fire hydrant
12 | stop sign
13 | parking meter
14 | bench
15 | bird
16 | cat
17 | dog
18 | horse
19 | sheep
20 | cow
21 | elephant
22 | bear
23 | zebra
24 | giraffe
25 | backpack
26 | umbrella
27 | handbag
28 | tie
29 | suitcase
30 | frisbee
31 | skis
32 | snowboard
33 | sports ball
34 | kite
35 | baseball bat
36 | baseball glove
37 | skateboard
38 | surfboard
39 | tennis racket
40 | bottle
41 | wine glass
42 | cup
43 | fork
44 | knife
45 | spoon
46 | bowl
47 | banana
48 | apple
49 | sandwich
50 | orange
51 | broccoli
52 | carrot
53 | hot dog
54 | pizza
55 | donut
56 | cake
57 | chair
58 | sofa
59 | pottedplant
60 | bed
61 | diningtable
62 | toilet
63 | tvmonitor
64 | laptop
65 | mouse
66 | remote
67 | keyboard
68 | cell phone
69 | microwave
70 | oven
71 | toaster
72 | sink
73 | refrigerator
74 | book
75 | clock
76 | vase
77 | scissors
78 | teddy bear
79 | hair drier
80 | toothbrush
--------------------------------------------------------------------------------
/src/main/resources/dog.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jhgan00/java-ort-example-yolov5/5f6ea876f8663d439fd1e986d317368c3463d449/src/main/resources/dog.jpg
--------------------------------------------------------------------------------
/src/main/resources/dog.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jhgan00/java-ort-example-yolov5/5f6ea876f8663d439fd1e986d317368c3463d449/src/main/resources/dog.png
--------------------------------------------------------------------------------
/src/main/resources/kite.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jhgan00/java-ort-example-yolov5/5f6ea876f8663d439fd1e986d317368c3463d449/src/main/resources/kite.jpg
--------------------------------------------------------------------------------
/src/main/resources/model.properties:
--------------------------------------------------------------------------------
1 | modelName=yolov5
2 | modelPath=yolov5s.onnx
3 | #modelName=yolov8
4 | #modelPath=yolov8s.onnx
5 | labelPath=coco.names
6 | confThreshold=0.25
7 | nmsThreshold=0.45
8 | gpuDeviceId=-1
9 |
--------------------------------------------------------------------------------
/src/main/resources/yolov5s.onnx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jhgan00/java-ort-example-yolov5/5f6ea876f8663d439fd1e986d317368c3463d449/src/main/resources/yolov5s.onnx
--------------------------------------------------------------------------------
/src/main/resources/yolov5s.sparseml.onnx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jhgan00/java-ort-example-yolov5/5f6ea876f8663d439fd1e986d317368c3463d449/src/main/resources/yolov5s.sparseml.onnx
--------------------------------------------------------------------------------
/src/main/resources/yolov8s.onnx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jhgan00/java-ort-example-yolov5/5f6ea876f8663d439fd1e986d317368c3463d449/src/main/resources/yolov8s.onnx
--------------------------------------------------------------------------------