├── .gitignore ├── README.md ├── assets ├── api.png ├── cli.png ├── gui.png └── predictions.jpg ├── pom.xml └── src └── main ├── java └── com │ └── example │ ├── app │ ├── ConfigReader.java │ ├── api │ │ ├── DetectionController.java │ │ ├── ErrorCode.java │ │ ├── ErrorResponse.java │ │ ├── GlobalExceptionHandler.java │ │ ├── ServingApplication.java │ │ └── UploadFileException.java │ ├── cli │ │ └── CLI_App.java │ ├── swing │ │ └── SwingApp.java │ └── video │ │ └── VideoApp.java │ └── yolo │ ├── Detection.java │ ├── ImageUtil.java │ ├── ModelFactory.java │ ├── NotImplementedException.java │ ├── Yolo.java │ ├── YoloV5.java │ └── YoloV8.java └── resources ├── coco.names ├── dog.jpg ├── dog.png ├── kite.jpg ├── model.properties ├── yolov5s.onnx ├── yolov5s.sparseml.onnx └── yolov8s.onnx /.gitignore: -------------------------------------------------------------------------------- 1 | target/ 2 | !.mvn/wrapper/maven-wrapper.jar 3 | !**/src/main/**/target/ 4 | !**/src/test/**/target/ 5 | 6 | ### IntelliJ IDEA ### 7 | .idea/modules.xml 8 | .idea/jarRepositories.xml 9 | .idea/compiler.xml 10 | .idea/libraries/ 11 | *.iws 12 | *.iml 13 | *.ipr 14 | 15 | ### Eclipse ### 16 | .apt_generated 17 | .classpath 18 | .factorypath 19 | .project 20 | .settings 21 | .springBeans 22 | .sts4-cache 23 | 24 | ### NetBeans ### 25 | /nbproject/private/ 26 | /nbbuild/ 27 | /dist/ 28 | /nbdist/ 29 | /.nb-gradle/ 30 | build/ 31 | !**/src/main/**/build/ 32 | !**/src/test/**/build/ 33 | 34 | ### VS Code ### 35 | .vscode/ 36 | 37 | ### Mac OS ### 38 | .DS_Store 39 | .idea 40 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Onnxruntime Java Example : yolov5 2 | 3 | - An object detection example using onnxruntime and YOLO (v5 and v8) 4 | 5 | ![](assets/predictions.jpg) 6 | 7 | ```bash 8 | mvn clean compile 9 | 10 | # GUI App 11 | # mvn exec:java 
-Dexec.mainClass="com.example.app.swing.SwingApp" -Dexec.classpathScope=test 12 | 13 | # CLI App 14 | # mvn exec:java -Dexec.mainClass="com.example.app.cli.CLI_App" -Dexec.classpathScope=test 15 | 16 | # API 17 | mvn spring-boot:run 18 | ``` 19 | 20 | ## Configs 21 | 22 | - See `src/main/resources/model.properties` 23 | 24 | ```properties 25 | modelName=yolov5 26 | modelPath=yolov5s.onnx 27 | #modelName=yolov8 28 | #modelPath=yolov8s.onnx 29 | labelPath=coco.names 30 | confThreshold=0.25 31 | nmsThreshold=0.45 32 | gpuDeviceId=-1 33 | ``` 34 | 35 | ## GUI Demo 36 | 37 | 38 | ![](assets/gui.png) 39 | 40 | ## CLI Demo 41 | 42 | ![](assets/cli.png) 43 | 44 | ## REST API Demo 45 | 46 | - Endpoint: `/detection` 47 | - `multipart/form-data` 48 | - `name="uploadFile"` 49 | 50 | ![](assets/api.png) 51 | 52 | ## Updates 53 | 54 | ### Feb 25, 2023 55 | 56 | - Added YOLOv8 support 57 | - Model settings are now configured through the properties file (`model.properties`) 58 | -------------------------------------------------------------------------------- /assets/api.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jhgan00/java-ort-example-yolov5/5f6ea876f8663d439fd1e986d317368c3463d449/assets/api.png -------------------------------------------------------------------------------- /assets/cli.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jhgan00/java-ort-example-yolov5/5f6ea876f8663d439fd1e986d317368c3463d449/assets/cli.png -------------------------------------------------------------------------------- /assets/gui.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jhgan00/java-ort-example-yolov5/5f6ea876f8663d439fd1e986d317368c3463d449/assets/gui.png -------------------------------------------------------------------------------- /assets/predictions.jpg:
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/jhgan00/java-ort-example-yolov5/5f6ea876f8663d439fd1e986d317368c3463d449/assets/predictions.jpg -------------------------------------------------------------------------------- /pom.xml: -------------------------------------------------------------------------------- 1 | 2 | 5 | 4.0.0 6 | 7 | org.example 8 | yolov5 9 | 1.0-SNAPSHOT 10 | 11 | 12 | 18 13 | 18 14 | UTF-8 15 | 16 | 17 | 18 | 19 | org.springframework.boot 20 | spring-boot-starter-web 21 | 2.7.5 22 | 23 | 24 | com.microsoft.onnxruntime 25 | onnxruntime 26 | 1.13.1 27 | 28 | 29 | org.openpnp 30 | opencv 31 | 4.5.5-1 32 | 33 | 34 | com.google.code.gson 35 | gson 36 | 2.10 37 | 38 | 39 | org.projectlombok 40 | lombok 41 | 1.18.24 42 | compile 43 | 44 | 45 | 46 | 47 | 48 | 49 | org.springframework.boot 50 | spring-boot-maven-plugin 51 | 52 | 53 | --spring.profiles.active=dev 54 | 55 | 56 | 2.7.5 57 | 58 | 59 | 60 | 61 |
--------------------------------------------------------------------------------
/src/main/java/com/example/app/ConfigReader.java:
--------------------------------------------------------------------------------
package com.example.app;

import java.io.IOException;
import java.io.InputStream;
import java.io.UncheckedIOException;
import java.util.Properties;

/**
 * Loads {@link Properties} files that are bundled on the classpath.
 */
public class ConfigReader {

    /**
     * Reads a properties resource from the classpath.
     *
     * @param propertiesFilePath classpath resource name, e.g. {@code "model.properties"}
     * @return the loaded properties
     * @throws IllegalArgumentException if no such resource exists on the classpath
     * @throws UncheckedIOException     if the resource exists but cannot be read
     */
    public Properties readProperties(String propertiesFilePath) {
        Properties properties = new Properties();
        // try-with-resources guarantees the stream is closed; a null resource is simply not closed.
        try (InputStream inputStream = getClass().getClassLoader().getResourceAsStream(propertiesFilePath)) {
            // getResourceAsStream signals a missing resource with null; report it instead of an NPE from load().
            if (inputStream == null) {
                throw new IllegalArgumentException("Properties resource not found on classpath: " + propertiesFilePath);
            }
            properties.load(inputStream);
            return properties;
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }

}
--------------------------------------------------------------------------------
/src/main/java/com/example/app/api/DetectionController.java:
--------------------------------------------------------------------------------
package com.example.app.api;

import ai.onnxruntime.OrtException;
import com.example.yolo.Detection;
import com.example.yolo.ModelFactory;
import com.example.yolo.NotImplementedException;
import com.example.yolo.Yolo;
import org.opencv.core.Mat;
import org.opencv.core.MatOfByte;
import org.opencv.imgcodecs.Imgcodecs;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RequestParam;
import org.springframework.web.bind.annotation.RestController;
import org.springframework.web.multipart.MultipartFile;

import java.io.IOException;
import java.util.List;


/**
 * REST endpoint that runs YOLO object detection on an uploaded PNG or JPEG image.
 */
@RestController
public class DetectionController {
    private static final List<String> MIME_TYPES = List.of("image/png", "image/jpeg");
    private static final Logger LOGGER = LoggerFactory.getLogger(DetectionController.class);

    private final Yolo inferenceSession;

    /** Loads the model described by {@code model.properties} once, when Spring creates this bean. */
    private DetectionController() throws OrtException, IOException, NotImplementedException {
        ModelFactory modelFactory = new ModelFactory();
        this.inferenceSession = modelFactory.getModel("model.properties");
    }

    /**
     * Detects objects in the uploaded image.
     *
     * @param uploadFile multipart part named {@code uploadFile}; the explicit name avoids relying on
     *                   compiled parameter names
     * @return detections in original-image coordinates
     * @throws UploadFileException {@link ErrorCode#INVALID_UPLOAD_FILE} for an empty or undecodable file,
     *                             {@link ErrorCode#INVALID_MIME_TYPE} for a non PNG/JPEG content type
     */
    @PostMapping(value = "/detection", consumes = {"multipart/form-data"}, produces = {"application/json"})
    public List<Detection> detection(@RequestParam("uploadFile") MultipartFile uploadFile) throws OrtException, IOException {
        if (uploadFile.isEmpty()) throw new UploadFileException(ErrorCode.INVALID_UPLOAD_FILE);
        if (!MIME_TYPES.contains(uploadFile.getContentType())) throw new UploadFileException(ErrorCode.INVALID_MIME_TYPE);
        byte[] bytes = uploadFile.getBytes();
        Mat img = Imgcodecs.imdecode(new MatOfByte(bytes), Imgcodecs.IMREAD_COLOR);
        try {
            // imdecode returns an empty Mat (no exception) when the bytes are not a decodable image.
            if (img.empty()) throw new UploadFileException(ErrorCode.INVALID_UPLOAD_FILE);
            List<Detection> result = inferenceSession.run(img);
            LOGGER.info("POST 200");
            return result;
        } finally {
            // Free the native pixel buffer now rather than waiting for GC.
            img.release();
        }
    }
}
--------------------------------------------------------------------------------
/src/main/java/com/example/app/api/ErrorCode.java:
--------------------------------------------------------------------------------
package com.example.app.api;

import lombok.AllArgsConstructor;
import lombok.Getter;
import org.springframework.http.HttpStatus;


/**
 * Error codes returned by the detection API, each paired with an HTTP status and a client-facing message.
 */
@AllArgsConstructor
@Getter
public enum ErrorCode {
    INVALID_MIME_TYPE(HttpStatus.UNSUPPORTED_MEDIA_TYPE, "invalid mime type"),
    INVALID_UPLOAD_FILE(HttpStatus.BAD_REQUEST, "invalid file");

    private final HttpStatus httpStatus;
    private final String message;

}
--------------------------------------------------------------------------------
/src/main/java/com/example/app/api/ErrorResponse.java:
--------------------------------------------------------------------------------
package com.example.app.api;

import lombok.Builder;
import lombok.Getter;
import org.springframework.http.ResponseEntity;

import java.time.LocalDateTime;

/**
 * JSON body returned to clients when a request is rejected.
 */
@Getter
@Builder
public class ErrorResponse {
    // Initialized final field: not a builder parameter, so each instance is stamped when it is built.
    private final LocalDateTime timestamp = LocalDateTime.now();
    private final int status;
    private final String error;
    private final String code;
    private final String message;

    /**
     * Builds an HTTP response whose status and body are derived from {@code errorCode}.
     *
     * @param errorCode the error to report
     * @return a response with {@code errorCode}'s status and an {@link ErrorResponse} body
     */
    public static ResponseEntity<ErrorResponse> toResponseEntity(ErrorCode errorCode) {
        return ResponseEntity
                .status(errorCode.getHttpStatus())
                .body(ErrorResponse.builder()
                        .status(errorCode.getHttpStatus().value())
                        .error(errorCode.getHttpStatus().name())
                        .code(errorCode.name())
                        .message(errorCode.getMessage())
                        .build()
                );
    }

}
--------------------------------------------------------------------------------
/src/main/java/com/example/app/api/GlobalExceptionHandler.java:
--------------------------------------------------------------------------------
package com.example.app.api;

import lombok.extern.slf4j.Slf4j;
import org.springframework.http.ResponseEntity;
import org.springframework.web.bind.annotation.ExceptionHandler;
import org.springframework.web.bind.annotation.RestControllerAdvice;

/**
 * Translates API exceptions into {@link ErrorResponse} bodies.
 */
@Slf4j
@RestControllerAdvice
public class GlobalExceptionHandler {

    /** Converts a rejected upload into the HTTP response described by its {@link ErrorCode}. */
    @ExceptionHandler(value = {UploadFileException.class})
    protected ResponseEntity<ErrorResponse> handleUploadFileException(UploadFileException e) {
        log.error("handleUploadFileException throw UploadFileException : {}", e.getErrorCode());
        return ErrorResponse.toResponseEntity(e.getErrorCode());
    }
}
--------------------------------------------------------------------------------
/src/main/java/com/example/app/api/ServingApplication.java:
--------------------------------------------------------------------------------
package com.example.app.api;

import org.springframework.boot.SpringApplication;
import org.springframework.boot.autoconfigure.SpringBootApplication;


/**
 * Spring Boot entry point for the detection REST API.
 */
@SpringBootApplication
public class ServingApplication {
    public static void main(String[] args) {
        SpringApplication.run(ServingApplication.class, args);
    }

}
--------------------------------------------------------------------------------
/src/main/java/com/example/app/api/UploadFileException.java:
--------------------------------------------------------------------------------
package com.example.app.api;

import lombok.Getter;

/**
 * Thrown when an uploaded file is rejected; carries the {@link ErrorCode} used to build the HTTP response.
 */
@Getter
public class UploadFileException extends RuntimeException {
    private final ErrorCode errorCode;

    public UploadFileException(ErrorCode errorCode) {
        // Pass the code's message up so logs and stack traces say why the upload was rejected.
        super(errorCode == null ? null : errorCode.getMessage());
        this.errorCode = errorCode;
    }
}
--------------------------------------------------------------------------------
/src/main/java/com/example/app/cli/CLI_App.java:
--------------------------------------------------------------------------------
package com.example.app.cli;

import ai.onnxruntime.OrtException;
import com.example.yolo.*;
import com.google.gson.Gson;
import org.opencv.core.Mat;
import org.opencv.imgcodecs.Imgcodecs;

import java.io.*;
import java.util.List;

/**
 * Interactive command-line demo: reads image paths from stdin, prints the detections as JSON and
 * writes an annotated copy of the image to {@code predictions.jpg} in the working directory.
 */
public class CLI_App {

    private Yolo inferenceSession;
    private Gson gson;

    public CLI_App() {

        ModelFactory modelFactory = new ModelFactory();

        try {
            // Plain resource name: a "./" prefix is not a valid classpath resource name inside a jar.
            this.inferenceSession = modelFactory.getModel("model.properties");
            this.gson = new Gson();
        } catch (OrtException | IOException exception) {
            exception.printStackTrace();
            System.exit(1);
        }
    }

    public static void main(String[] args) {

        CLI_App app = new CLI_App();
        // One reader for the whole session: wrapping System.in anew on every iteration can drop
        // input the previous reader had already buffered (e.g. when paths are piped in).
        BufferedReader br = new BufferedReader(new InputStreamReader(System.in));

        while (true) {

            System.out.print("Enter image path (enter 'q' or 'Q' to exit): ");

            String input;
            try {
                input = br.readLine();
            } catch (IOException e) {
                e.printStackTrace();
                System.exit(1);
                return;
            }

            // readLine() returns null at end of input (Ctrl-D, closed pipe); treat it like 'q'.
            if (input == null || "q".equals(input) || "Q".equals(input)) {
                System.out.println("Exit");
                System.exit(0);
                return;
            }

            File f = new File(input);
            if (!f.exists()) {
                System.out.println("File does not exists: " + input);
                continue;
            }
            if (f.isDirectory()) {
                System.out.println(input + " is a directory");
                continue;
            }

            Mat img = Imgcodecs.imread(input, Imgcodecs.IMREAD_COLOR);
            // imread returns an empty Mat (no exception) when the file cannot be decoded.
            if (img.empty()) {
                System.out.println("Could not open image: " + input);
                continue;
            }
            // run detection
            try {
                List<Detection> detectionList = app.inferenceSession.run(img);
                ImageUtil.drawPredictions(img, detectionList);
                System.out.println(app.gson.toJson(detectionList));
                Imgcodecs.imwrite("predictions.jpg", img);
            } catch (OrtException ortException) {
                ortException.printStackTrace();
            }

        }

    }

}
--------------------------------------------------------------------------------
/src/main/java/com/example/app/swing/SwingApp.java:
--------------------------------------------------------------------------------
package com.example.app.swing;

import ai.onnxruntime.OrtException;
import com.example.yolo.Detection;
import com.example.yolo.ImageUtil;
import com.example.yolo.ModelFactory;
import com.example.yolo.Yolo;
import org.opencv.core.Mat;
import org.opencv.core.MatOfByte;
import org.opencv.imgcodecs.Imgcodecs;

import javax.swing.*;
import javax.swing.filechooser.FileNameExtensionFilter;
import java.awt.*;
import java.awt.event.ActionEvent;
import java.awt.event.ActionListener;
import java.io.IOException;
import java.util.List;


/**
 * Swing demo: the File menu's "Open" item picks an image, runs detection on it and shows the annotated result.
 */
public class SwingApp extends JFrame implements ActionListener {

    // Images whose longer side exceeds this many pixels are letterboxed down before detection and display.
    private static final int MAX_SIZE = 720;

    private final JMenuItem openItem;
    private final JLabel label;
    private Yolo inferenceSession;

    public SwingApp() {

        ModelFactory modelFactory = new ModelFactory();

        try {
            this.inferenceSession = modelFactory.getModel("model.properties");
        } catch (OrtException | IOException exception) {
            exception.printStackTrace();
            System.exit(1);
        }

        setTitle("Yolo Demo");
        setSize(640, 640);
        JMenuBar mbar = new JMenuBar();
        JMenu m = new JMenu("File");
        openItem = new JMenuItem("Open");
        openItem.addActionListener(this);
        m.add(openItem);
        mbar.add(m);
        setJMenuBar(mbar);
        label = new JLabel();
        Container contentPane = getContentPane();
        contentPane.add(label, BorderLayout.CENTER);
        setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE);
        setVisible(true);
        setResizable(true);

    }


    public static void main(String[] args) {
        // Swing components must be created on the event dispatch thread.
        SwingUtilities.invokeLater(SwingApp::new);
    }

    @Override
    public void actionPerformed(ActionEvent e) {
        if (e.getSource() != openItem) {
            return;
        }

        JFileChooser chooser = new JFileChooser();
        chooser.setFileFilter(new FileNameExtensionFilter(
                "Image files",
                "png", "jpg", "jpeg")
        );

        if (chooser.showOpenDialog(this) != JFileChooser.APPROVE_OPTION) {
            return;
        }
        String name = chooser.getSelectedFile().getPath();

        // open image file; imread returns an empty Mat (no exception) when decoding fails
        Mat img = Imgcodecs.imread(name);
        if (img.empty()) {
            JOptionPane.showMessageDialog(this, "Could not open image: " + name, "Error", JOptionPane.ERROR_MESSAGE);
            return;
        }
        if (Math.max(img.width(), img.height()) > MAX_SIZE) {
            ImageUtil.resizeWithPadding(img, img, MAX_SIZE, MAX_SIZE);
        }

        // run detection
        try {
            List<Detection> detectionList = inferenceSession.run(img);
            ImageUtil.drawPredictions(img, detectionList);
        } catch (OrtException ortException) {
            ortException.printStackTrace();
        }

        // Displaying in the GUI
        MatOfByte matOfByte = new MatOfByte();
        Imgcodecs.imencode(".png", img, matOfByte);

        ImageIcon imageIcon = new ImageIcon(matOfByte.toArray());
        imageIcon.getImage().flush();
        label.setIcon(imageIcon);
        pack();
    }
}
--------------------------------------------------------------------------------
/src/main/java/com/example/app/video/VideoApp.java:
--------------------------------------------------------------------------------
package com.example.app.video;

import ai.onnxruntime.OrtException;
import com.example.yolo.Detection;
import com.example.yolo.ImageUtil;
import com.example.yolo.ModelFactory;
import com.example.yolo.Yolo;
import org.opencv.core.Mat;
import org.opencv.core.MatOfByte;
import org.opencv.imgcodecs.Imgcodecs;
import org.opencv.videoio.VideoCapture;
import org.opencv.videoio.Videoio;

import javax.swing.*;
import java.io.IOException;
import java.util.List;


/**
 * Webcam demo: reads frames from camera 0, runs detection on each one and shows the annotated stream.
 */
public class VideoApp {

    private static final int FRAME_WIDTH = 1280;
    private static final int FRAME_HEIGHT = 800;

    private Yolo inferenceSession;

    public VideoApp() {
        ModelFactory modelFactory = new ModelFactory();
        try {
            this.inferenceSession = modelFactory.getModel("model.properties");
        } catch (OrtException | IOException exception) {
            exception.printStackTrace();
            System.exit(1);
        }
    }

    public static void main(String[] args) {
        nu.pattern.OpenCV.loadLocally();
        VideoApp app = new VideoApp();

        VideoCapture cap = new VideoCapture(0);
        // Fail with a message, before any window is shown, if the default camera cannot be opened.
        if (!cap.isOpened()) {
            System.err.println("Could not open camera 0");
            System.exit(1);
        }
        cap.set(Videoio.CAP_PROP_FRAME_WIDTH, FRAME_WIDTH);
        cap.set(Videoio.CAP_PROP_FRAME_HEIGHT, FRAME_HEIGHT);

        JFrame jframe = new JFrame("Title");
        jframe.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE);
        JLabel vidpanel = new JLabel();
        jframe.setContentPane(vidpanel);
        jframe.setVisible(true);
        jframe.setResizable(true);
        jframe.setSize(FRAME_WIDTH, FRAME_HEIGHT);

        Mat frame = new Mat();
        MatOfByte matofByte = new MatOfByte();

        try {
            while (cap.read(frame)) {

                ImageUtil.resizeWithPadding(frame, frame, FRAME_WIDTH, FRAME_HEIGHT);

                // run detection
                try {
                    List<Detection> detectionList = app.inferenceSession.run(frame);
                    ImageUtil.drawPredictions(frame, detectionList);
                } catch (OrtException ortException) {
                    ortException.printStackTrace();
                }
                Imgcodecs.imencode(".jpg", frame, matofByte);
                ImageIcon imageIcon = new ImageIcon(matofByte.toArray());
                // Swing components may only be touched from the event dispatch thread.
                SwingUtilities.invokeLater(() -> {
                    vidpanel.setIcon(imageIcon);
                    vidpanel.repaint();
                });
            }
        } finally {
            cap.release();
        }
    }

}
--------------------------------------------------------------------------------
/src/main/java/com/example/yolo/Detection.java:
--------------------------------------------------------------------------------
package com.example.yolo;

/**
 * One detected object.
 *
 * @param label      class name taken from the label file
 * @param bbox       box corners {xmin, ymin, xmax, ymax} in original-image pixel coordinates
 * @param confidence score the model assigned to this box
 */
public record Detection(String label, float[] bbox, float confidence) {

}
--------------------------------------------------------------------------------
/src/main/java/com/example/yolo/ImageUtil.java:
--------------------------------------------------------------------------------
package com.example.yolo;

import org.opencv.core.*;
import org.opencv.imgproc.Imgproc;

import java.util.List;


/**
 * OpenCV helpers shared by the YOLO models and the demo applications.
 */
public class ImageUtil {

    /**
     * Letterboxes {@code src} into a newly allocated Mat of exactly {@code width x height}.
     *
     * @return the padded image; {@code src} is not modified
     */
    public static Mat resizeWithPadding(Mat src, int width, int height) {
        Mat dst = new Mat();
        resizeWithPadding(src, dst, width, height);
        return dst;
    }

    /**
     * Letterbox resize: scales {@code src} to fit inside {@code width x height} keeping its aspect ratio,
     * then pads the borders with the default constant (zero) value so {@code dst} is exactly
     * {@code width x height}.
     */
    public static void resizeWithPadding(Mat src, Mat dst, int width, int height) {

        int oldW = src.width();
        int oldH = src.height();

        double r = Math.min((double) width / oldW, (double) height / oldH);

        int newUnpadW = (int) Math.round(oldW * r);
        int newUnpadH = (int) Math.round(oldH * r);

        // Half of the total padding per axis, kept fractional: with integer division an odd
        // remainder was lost and the output came out one pixel short of width/height.
        double dw = (width - newUnpadW) / 2.0;
        double dh = (height - newUnpadH) / 2.0;

        // The -/+0.1 nudges put the smaller half of an odd remainder on top/left, the larger on bottom/right.
        int top = (int) Math.round(dh - 0.1);
        int bottom = (int) Math.round(dh + 0.1);
        int left = (int) Math.round(dw - 0.1);
        int right = (int) Math.round(dw + 0.1);

        Imgproc.resize(src, dst, new Size(newUnpadW, newUnpadH));
        Core.copyMakeBorder(dst, dst, top, bottom, left, right, Core.BORDER_CONSTANT);

    }

    /**
     * Converts interleaved 3-channel pixels (c0,c1,c2,c0,c1,c2,...) into planar order (all c0, then all c1,
     * then all c2), writing into {@code dst} starting at index {@code start}.
     */
    public static void whc2cwh(float[] src, float[] dst, int start) {
        int j = start;
        for (int ch = 0; ch < 3; ++ch) {
            for (int i = ch; i < src.length; i += 3) {
                dst[j] = src[i];
                j++;
            }
        }
    }

    /** Returns a planar copy of interleaved 3-channel float pixels; see {@link #whc2cwh(float[], float[], int)}. */
    public static float[] whc2cwh(float[] src) {
        float[] chw = new float[src.length];
        whc2cwh(src, chw, 0);
        return chw;
    }

    /** Returns a planar copy of interleaved 3-channel byte pixels (same ordering as the float variant). */
    public static byte[] whc2cwh(byte[] src) {
        byte[] chw = new byte[src.length];
        int j = 0;
        for (int ch = 0; ch < 3; ++ch) {
            for (int i = ch; i < src.length; i += 3) {
                chw[j] = src[i];
                j++;
            }
        }
        return chw;
    }

    /** Draws each detection's box and label onto {@code img} in place. */
    public static void drawPredictions(Mat img, List<Detection> detectionList) {
        Scalar color = new Scalar(249, 218, 60);
        for (Detection detection : detectionList) {

            float[] bbox = detection.bbox();
            Imgproc.rectangle(img,              // image to draw on
                    new Point(bbox[0], bbox[1]),  // top-left corner
                    new Point(bbox[2], bbox[3]),  // bottom-right corner
                    color,
                    2                             // line thickness
            );
            Imgproc.putText(
                    img,
                    detection.label(),
                    new Point(bbox[0] - 1, bbox[1] - 5),
                    Imgproc.FONT_HERSHEY_SIMPLEX,
                    .5, color,
                    1);
        }
    }

}
--------------------------------------------------------------------------------
/src/main/java/com/example/yolo/ModelFactory.java:
--------------------------------------------------------------------------------
package com.example.yolo;

import ai.onnxruntime.OrtException;
import com.example.app.ConfigReader;
import org.springframework.util.ResourceUtils;

import java.io.File;
import java.io.IOException;
import java.util.Properties;

/**
 * Builds a {@link Yolo} detector from a properties file on the classpath.
 */
public class ModelFactory {


    /**
     * Creates the detector described by a properties file.
     *
     * <p>Keys: {@code modelName} ({@code yolov5} or {@code yolov8}), {@code modelPath}, {@code labelPath}
     * (optional, defaults to {@code coco.names}), {@code confThreshold}, {@code nmsThreshold},
     * {@code gpuDeviceId} (a negative value means CPU only).
     *
     * @param propertiesFilePath classpath resource name of the properties file
     * @throws IllegalArgumentException if a required key is missing
     * @throws NotImplementedException  if {@code modelName} is not a supported model
     */
    public Yolo getModel(String propertiesFilePath) throws IOException, OrtException, NotImplementedException {

        ConfigReader configReader = new ConfigReader();
        Properties properties = configReader.readProperties(propertiesFilePath);

        String modelName = requireProperty(properties, "modelName");
        // NOTE(review): ResourceUtils.getFile needs the resource to be a real file on disk, so this
        // fails when the model is packaged inside a jar.
        File modelFile = ResourceUtils.getFile("classpath:" + requireProperty(properties, "modelPath"));
        // Honour the documented labelPath key; coco.names was previously hard-coded, so keep it as the default.
        File labelFile = ResourceUtils.getFile("classpath:" + properties.getProperty("labelPath", "coco.names").trim());
        String modelPath = modelFile.getPath();
        String labelPath = labelFile.getPath();
        float confThreshold = Float.parseFloat(requireProperty(properties, "confThreshold"));
        float nmsThreshold = Float.parseFloat(requireProperty(properties, "nmsThreshold"));
        int gpuDeviceId = Integer.parseInt(requireProperty(properties, "gpuDeviceId"));

        if (modelName.equalsIgnoreCase("yolov5")) {
            return new YoloV5(modelPath, labelPath, confThreshold, nmsThreshold, gpuDeviceId);
        } else if (modelName.equalsIgnoreCase("yolov8")) {
            return new YoloV8(modelPath, labelPath, confThreshold, nmsThreshold, gpuDeviceId);
        } else {
            throw new NotImplementedException("Unsupported modelName: " + modelName);
        }

    }

    /** Returns the trimmed value of {@code key}, failing with a message naming the key if it is absent. */
    private static String requireProperty(Properties properties, String key) {
        String value = properties.getProperty(key);
        if (value == null) {
            throw new IllegalArgumentException("Missing required property: " + key);
        }
        // Properties.load keeps trailing whitespace, which Integer.parseInt and equalsIgnoreCase reject.
        return value.trim();
    }
}
--------------------------------------------------------------------------------
/src/main/java/com/example/yolo/NotImplementedException.java:
--------------------------------------------------------------------------------
package com.example.yolo;


/**
 * Thrown when the configuration asks for a model type that has no implementation.
 */
public class NotImplementedException extends RuntimeException {
    public NotImplementedException() {}

    public NotImplementedException(String message) {
        super(message);
    }
}
--------------------------------------------------------------------------------
/src/main/java/com/example/yolo/Yolo.java:
--------------------------------------------------------------------------------
package com.example.yolo;

import ai.onnxruntime.*;
import org.opencv.core.Mat;

import java.io.BufferedReader;
import java.io.FileReader;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.*;

/**
 * Base class for YOLO detectors backed by an ONNX Runtime session with a 1x3x640x640 input.
 * Subclasses implement {@link #run(Mat)}; this class loads the model and labels and provides box utilities.
 */
public abstract class Yolo {

    public static final int INPUT_SIZE = 640;
    public static final int NUM_INPUT_ELEMENTS = 3 * INPUT_SIZE * INPUT_SIZE;
    public static final long[] INPUT_SHAPE = {1, 3, INPUT_SIZE, INPUT_SIZE};
    public float confThreshold;
    public float nmsThreshold;
    public OnnxJavaType inputType;
    protected final OrtEnvironment env;
    protected final OrtSession session;
    protected final String inputName;
    public ArrayList<String> labelNames;

    // Most recent input tensor created by a subclass's preprocessing step.
    OnnxTensor inputTensor;

    /**
     * Loads the ONNX model and the label file.
     *
     * @param modelPath   filesystem path of the .onnx model
     * @param labelPath   filesystem path of a text file with one class name per line, in class-index order
     * @param gpuDeviceId CUDA device id, or a negative value to run on the CPU only
     */
    public Yolo(String modelPath, String labelPath, float confThreshold, float nmsThreshold, int gpuDeviceId) throws OrtException, IOException {
        nu.pattern.OpenCV.loadLocally();

        this.env = OrtEnvironment.getEnvironment();
        var sessionOptions = new OrtSession.SessionOptions();
        sessionOptions.setOptimizationLevel(OrtSession.SessionOptions.OptLevel.ALL_OPT);

        // ONNX Runtime assigns graph nodes to execution providers in registration order, so CUDA must be
        // registered before the CPU provider; otherwise the CPU claims every node and the GPU stays idle.
        if (gpuDeviceId >= 0) sessionOptions.addCUDA(gpuDeviceId);
        sessionOptions.addCPU(false);
        this.session = this.env.createSession(modelPath, sessionOptions);

        Map<String, NodeInfo> inputMetaMap = this.session.getInputInfo();
        this.inputName = this.session.getInputNames().iterator().next();
        NodeInfo inputMeta = inputMetaMap.get(this.inputName);
        this.inputType = ((TensorInfo) inputMeta.getInfo()).type;

        this.confThreshold = confThreshold;
        this.nmsThreshold = nmsThreshold;

        // try-with-resources so the label file handle is released.
        this.labelNames = new ArrayList<>();
        try (BufferedReader br = new BufferedReader(new FileReader(labelPath, StandardCharsets.UTF_8))) {
            String line;
            while ((line = br.readLine()) != null) {
                this.labelNames.add(line);
            }
        }

    }

    /** Runs detection on a BGR image and returns boxes in that image's pixel coordinates. */
    public abstract List<Detection> run(Mat img) throws OrtException;

    /** Intersection-over-union of two {xmin, ymin, xmax, ymax} boxes. */
    private float computeIOU(float[] box1, float[] box2) {

        float area1 = (box1[2] - box1[0]) * (box1[3] - box1[1]);
        float area2 = (box2[2] - box2[0]) * (box2[3] - box2[1]);

        float left = Math.max(box1[0], box2[0]);
        float top = Math.max(box1[1], box2[1]);
        float right = Math.min(box1[2], box2[2]);
        float bottom = Math.min(box1[3], box2[3]);

        float interArea = Math.max(right - left, 0) * Math.max(bottom - top, 0);
        float unionArea = area1 + area2 - interArea;
        // Guard the denominator: two zero-area boxes previously produced 0/0 = NaN.
        return interArea / Math.max(unionArea, 1e-8f);

    }

    /**
     * Greedy non-maximum suppression.
     *
     * @param bboxes       boxes as {xmin, ymin, xmax, ymax, confidence, ...}; the list is not modified
     * @param iouThreshold boxes overlapping a kept box by at least this IoU are discarded
     * @return kept boxes, highest confidence first
     */
    protected List<float[]> nonMaxSuppression(List<float[]> bboxes, float iouThreshold) {

        // output boxes
        List<float[]> bestBboxes = new ArrayList<>();

        // Sort a copy by confidence (ascending) so the caller's list is left untouched.
        List<float[]> candidates = new ArrayList<>(bboxes);
        candidates.sort(Comparator.comparingDouble((float[] a) -> a[4]));

        // standard nms
        while (!candidates.isEmpty()) {
            // pop the box with the highest remaining confidence
            float[] bestBbox = candidates.remove(candidates.size() - 1);
            bestBboxes.add(bestBbox);
            // keep only boxes whose overlap with it is below the threshold
            candidates.removeIf(a -> !(computeIOU(a, bestBbox) < iouThreshold));
        }

        return bestBboxes;
    }

    /** Converts {cx, cy, w, h} in place to {xmin, ymin, xmax, ymax}. */
    protected void xywh2xyxy(float[] bbox) {
        float x = bbox[0];
        float y = bbox[1];
        float w = bbox[2];
        float h = bbox[3];

        bbox[0] = x - w * 0.5f;
        bbox[1] = y - h * 0.5f;
        bbox[2] = x + w * 0.5f;
        bbox[3] = y + h * 0.5f;
    }

    /**
     * Maps an {xmin, ymin, xmax, ymax} box in place from letterboxed model-input coordinates back to the
     * original image, undoing the padding and scale and clamping to [0, size - 1].
     */
    protected void scaleCoords(float[] bbox, float orgW, float orgH, float padW, float padH, float gain) {
        bbox[0] = Math.max(0, Math.min(orgW - 1, (bbox[0] - padW) / gain));
        bbox[1] = Math.max(0, Math.min(orgH - 1, (bbox[1] - padH) / gain));
        bbox[2] = Math.max(0, Math.min(orgW - 1, (bbox[2] - padW) / gain));
        bbox[3] = Math.max(0, Math.min(orgH - 1, (bbox[3] - padH) / gain));
    }

    /** Index of the largest element (the last one on ties), or -1 for an empty array. */
    static int argmax(float[] a) {
        float re = -Float.MAX_VALUE;
        int arg = -1;
        for (int i = 0; i < a.length; i++) {
            if (a[i] >= re) {
                re = a[i];
                arg = i;
            }
        }
        return arg;
    }

}
--------------------------------------------------------------------------------
/src/main/java/com/example/yolo/YoloV5.java:
--------------------------------------------------------------------------------
package com.example.yolo;

import ai.onnxruntime.*;
import org.opencv.core.CvType;
import org.opencv.core.Mat;
import org.opencv.imgproc.Imgproc;

import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.FloatBuffer;
import java.util.*;

/**
 * YOLOv5 detector. Each output row is read as {cx, cy, w, h, objectness, class scores...}.
 */
public class YoloV5 extends Yolo {

    public YoloV5(String modelPath, String labelPath, float confThreshold, float nmsThreshold, int gpuDeviceId) throws OrtException, IOException {
        super(modelPath, labelPath, confThreshold, nmsThreshold, gpuDeviceId);
    }

    public List<Detection> run(Mat img) throws OrtException {

        float orgW = (float) img.size().width;
        float orgH = (float) img.size().height;

        // Letterbox geometry, used afterwards to map boxes back onto the original image.
        float gain = Math.min((float) INPUT_SIZE / orgW, (float) INPUT_SIZE / orgH);
        float padW = (INPUT_SIZE - orgW * gain) * 0.5f;
        float padH = (INPUT_SIZE - orgH * gain) * 0.5f;

        // preprocessing
        Map<String, OnnxTensor> inputContainer = this.preprocess(img);

        // Run inference. The ONNX Runtime Java API requires results and tensors to be closed explicitly
        // to free their native memory; the predictions are copied into Java arrays before closing.
        float[][] predictions;
        try (OrtSession.Result results = this.session.run(inputContainer)) {
            predictions = ((float[][][]) results.get(0).getValue())[0];
        } finally {
            inputContainer.values().forEach(OnnxTensor::close);
        }

        // postprocessing
        return postprocess(predictions, orgW, orgH, padW, padH, gain);
    }


    /**
     * Letterboxes the BGR image to 640x640, converts it to RGB planar order and wraps it in a tensor
     * (uint8 for quantized models, otherwise float32 scaled to [0, 1]).
     *
     * @return map from the model's input name to the new tensor; the caller must close the tensor
     */
    public Map<String, OnnxTensor> preprocess(Mat img) throws OrtException {

        // Resizing
        Mat resizedImg = new Mat();
        ImageUtil.resizeWithPadding(img, resizedImg, INPUT_SIZE, INPUT_SIZE);

        // BGR -> RGB
        Imgproc.cvtColor(resizedImg, resizedImg, Imgproc.COLOR_BGR2RGB);

        Map<String, OnnxTensor> container = new HashMap<>();
        // Local reference: the shared inputTensor field alone is not safe when one model instance
        // serves concurrent requests (as the REST controller does).
        OnnxTensor tensor;

        // if model is quantized
        if (this.inputType.equals(OnnxJavaType.UINT8)) {
            byte[] whc = new byte[NUM_INPUT_ELEMENTS];
            resizedImg.get(0, 0, whc);
            byte[] chw = ImageUtil.whc2cwh(whc);
            ByteBuffer inputBuffer = ByteBuffer.wrap(chw);
            tensor = OnnxTensor.createTensor(this.env, inputBuffer, INPUT_SHAPE, this.inputType);
        } else {
            // Normalization
            resizedImg.convertTo(resizedImg, CvType.CV_32FC1, 1. / 255);
            float[] whc = new float[NUM_INPUT_ELEMENTS];
            resizedImg.get(0, 0, whc);
            float[] chw = ImageUtil.whc2cwh(whc);
            FloatBuffer inputBuffer = FloatBuffer.wrap(chw);
            tensor = OnnxTensor.createTensor(this.env, inputBuffer, INPUT_SHAPE);
        }
        resizedImg.release();
        this.inputTensor = tensor;

        // To OnnxTensor
        container.put(this.inputName, tensor);

        return container;
    }

    /**
     * Filters raw predictions by confidence, converts them to original-image boxes and applies
     * per-class non-max suppression.
     */
    public List<Detection> postprocess(float[][] outputs, float orgW, float orgH, float padW, float padH, float gain) {

        // predictions grouped by class index
        Map<Integer, List<float[]>> class2Bbox = new HashMap<>();

        for (float[] bbox : outputs) {

            float conf = bbox[4];
            if (conf < this.confThreshold) continue;

            // Class scores follow the 4 box values and objectness; take the count from the row length
            // instead of assuming the 80 COCO classes.
            // NOTE(review): reference YOLOv5 scores boxes by objectness * class probability; only objectness is used here.
            float[] conditionalProbabilities = Arrays.copyOfRange(bbox, 5, bbox.length);
            int label = Yolo.argmax(conditionalProbabilities);

            // xywh to (x1, y1, x2, y2)
            xywh2xyxy(bbox);

            // skip invalid predictions
            if (bbox[0] >= bbox[2] || bbox[1] >= bbox[3]) continue;

            // xmin, ymin, xmax, ymax -> (xmin_org, ymin_org, xmax_org, ymax_org)
            scaleCoords(bbox, orgW, orgH, padW, padH, gain);
            class2Bbox.computeIfAbsent(label, k -> new ArrayList<>()).add(bbox);
        }

        // Apply Non-max suppression for each class
        List<Detection> detections = new ArrayList<>();
        for (Map.Entry<Integer, List<float[]>> entry : class2Bbox.entrySet()) {
            int label = entry.getKey();
            List<float[]> bboxes = nonMaxSuppression(entry.getValue(), this.nmsThreshold);
            String labelString = this.labelNames.get(label);
            for (float[] bbox : bboxes) {
                detections.add(new Detection(labelString, Arrays.copyOfRange(bbox, 0, 4), bbox[4]));
            }
        }

        return detections;
    }
}
--------------------------------------------------------------------------------
/src/main/java/com/example/yolo/YoloV8.java:
--------------------------------------------------------------------------------
1 | package com.example.yolo; 2 | 3 | import
ai.onnxruntime.*; 4 | import org.opencv.core.CvType; 5 | import org.opencv.core.Mat; 6 | import org.opencv.imgproc.Imgproc; 7 | 8 | import java.io.IOException; 9 | import java.nio.ByteBuffer; 10 | import java.nio.FloatBuffer; 11 | import java.util.*; 12 | 13 |
/**
 * YOLOv8 detector built on the shared {@link Yolo} base class.
 *
 * <p>The model output is transposed so that each row is one candidate:
 * {@code [x, y, w, h, classScore_0, classScore_1, ...]}. There is no objectness column, and the
 * confidence is the best class score. Boxes are mapped back to the original image and suppressed
 * per class with NMS.
 */
public class YoloV8 extends Yolo {

    public YoloV8(String modelPath, String labelPath, float confThreshold, float nmsThreshold, int gpuDeviceId) throws OrtException, IOException {
        super(modelPath, labelPath, confThreshold, nmsThreshold, gpuDeviceId);
    }

    /**
     * Runs detection on a single image.
     *
     * <p>Every native resource created for the call is released before returning: the input
     * tensor(s) and the session result. Note that the inherited {@code inputTensor} field still
     * references the (now closed) input tensor afterwards.
     *
     * @param img image to analyse; converted with COLOR_BGR2RGB, so it is expected to be BGR
     * @return detections in original-image coordinates
     * @throws OrtException if tensor creation or inference fails
     */
    public List<Detection> run(Mat img) throws OrtException {
        float orgW = (float) img.size().width;
        float orgH = (float) img.size().height;

        // Uniform scale that fits the image into INPUT_SIZE x INPUT_SIZE, plus the centred padding,
        // used later to map boxes back (presumably matching ImageUtil.resizeWithPadding).
        float gain = Math.min((float) INPUT_SIZE / orgW, (float) INPUT_SIZE / orgH);
        float padW = (INPUT_SIZE - orgW * gain) * 0.5f;
        float padH = (INPUT_SIZE - orgH * gain) * 0.5f;

        Map<String, OnnxTensor> inputContainer = this.preprocess(img);

        float[][] predictions;
        try (OrtSession.Result results = this.session.run(inputContainer)) {
            // getValue() copies the output into Java arrays, so they stay valid after close().
            predictions = ((float[][][]) results.get(0).getValue())[0];
        } finally {
            // Input tensors own native memory; free them once inference has finished.
            for (OnnxTensor tensor : inputContainer.values()) {
                tensor.close();
            }
        }

        return postprocess(predictions, orgW, orgH, padW, padH, gain);
    }

    /**
     * Converts an image into the model's input tensor.
     *
     * <p>Steps:
     * <ul>
     *   <li>Resize with padding to INPUT_SIZE x INPUT_SIZE and convert BGR to RGB.</li>
     *   <li>Reorder the pixels with {@code ImageUtil.whc2cwh} (NOTE: assumed to turn interleaved
     *       pixels into channel-first planes).</li>
     *   <li>Wrap the pixels in a tensor of shape INPUT_SHAPE. UINT8 (quantized) models receive raw
     *       bytes; all other models receive floats scaled by 1/255.</li>
     * </ul>
     * The tensor is also stored in the inherited {@code inputTensor} field.
     *
     * @param img source image (BGR)
     * @return a new map from the model's input name to the created tensor; the caller owns the tensor
     * @throws OrtException if the tensor cannot be created
     */
    public Map<String, OnnxTensor> preprocess(Mat img) throws OrtException {
        Mat resizedImg = new Mat();
        try {
            ImageUtil.resizeWithPadding(img, resizedImg, INPUT_SIZE, INPUT_SIZE);
            Imgproc.cvtColor(resizedImg, resizedImg, Imgproc.COLOR_BGR2RGB);

            if (this.inputType.equals(OnnxJavaType.UINT8)) {
                // Quantized model: feed the raw 0-255 pixel bytes.
                byte[] whc = new byte[NUM_INPUT_ELEMENTS];
                resizedImg.get(0, 0, whc);
                byte[] chw = ImageUtil.whc2cwh(whc);
                inputTensor = OnnxTensor.createTensor(this.env, ByteBuffer.wrap(chw), INPUT_SHAPE, this.inputType);
            } else {
                // Float model: normalise pixel values to [0, 1].
                resizedImg.convertTo(resizedImg, CvType.CV_32FC1, 1. / 255);
                float[] whc = new float[NUM_INPUT_ELEMENTS];
                resizedImg.get(0, 0, whc);
                float[] chw = ImageUtil.whc2cwh(whc);
                inputTensor = OnnxTensor.createTensor(this.env, FloatBuffer.wrap(chw), INPUT_SHAPE);
            }
        } finally {
            // The pixels have been copied into Java arrays; release the Mat's native buffer now.
            resizedImg.release();
        }

        Map<String, OnnxTensor> container = new HashMap<>();
        container.put(this.inputName, inputTensor);
        return container;
    }

    /**
     * Decodes raw YOLOv8 output into detections.
     *
     * <p>Rows are modified in place. The box is converted to corner form, and index 4 is
     * overwritten with the final confidence.
     *
     * @param outputs model output with one column per candidate; transposed here to one row per candidate
     * @param orgW original image width
     * @param orgH original image height
     * @param padW horizontal padding added during preprocessing
     * @param padH vertical padding added during preprocessing
     * @param gain scale applied during preprocessing
     * @return per-class NMS survivors in original-image coordinates
     */
    public List<Detection> postprocess(float[][] outputs, float orgW, float orgH, float padW, float padH, float gain) {
        outputs = transposeMatrix(outputs);
        Map<Integer, List<float[]>> class2Bbox = new HashMap<>();

        for (float[] bbox : outputs) {
            // Slice to the end of the row so models with a different class count also work.
            float[] classScores = Arrays.copyOfRange(bbox, 4, bbox.length);
            int label = Yolo.argmax(classScores);
            float conf = classScores[label];
            if (conf < this.confThreshold) {
                continue;
            }
            // Detection (and presumably nonMaxSuppression) reads the score from index 4;
            // the first class score stored there has already been copied into classScores.
            bbox[4] = conf;

            // (x, y, w, h) -> (x1, y1, x2, y2)
            xywh2xyxy(bbox);

            // Skip degenerate boxes.
            if (bbox[0] >= bbox[2] || bbox[1] >= bbox[3]) {
                continue;
            }

            // Map from model-input coordinates back to the original image.
            scaleCoords(bbox, orgW, orgH, padW, padH, gain);
            class2Bbox.computeIfAbsent(label, k -> new ArrayList<>()).add(bbox);
        }

        // Apply non-max suppression independently for each class.
        List<Detection> detections = new ArrayList<>();
        for (Map.Entry<Integer, List<float[]>> entry : class2Bbox.entrySet()) {
            int label = entry.getKey();
            String labelString = this.labelNames.get(label);
            for (float[] bbox : nonMaxSuppression(entry.getValue(), this.nmsThreshold)) {
                detections.add(new Detection(labelString, Arrays.copyOfRange(bbox, 0, 4), bbox[4]));
            }
        }

        return detections;
    }

    /**
     * Returns the transpose of a rectangular matrix.
     *
     * @param m matrix whose rows all have length {@code m[0].length}
     * @return a new {@code m[0].length x m.length} matrix, or an empty matrix when {@code m} has no rows
     */
    public static float[][] transposeMatrix(float[][] m) {
        if (m.length == 0) {
            // m[0] does not exist; the transpose of an empty matrix is empty.
            return new float[0][0];
        }
        int rows = m.length;
        int cols = m[0].length;
        float[][] transposed = new float[cols][rows];
        for (int i = 0; i < rows; i++) {
            for (int j = 0; j < cols; j++) {
                transposed[j][i] = m[i][j];
            }
        }
        return transposed;
    }
}
127 | -------------------------------------------------------------------------------- /src/main/resources/coco.names: -------------------------------------------------------------------------------- 1 | person 2 | bicycle 3 | car 4 | motorbike 5 | aeroplane 6 | bus 7 | train 8 | truck 9 | boat 10 | traffic light 11 | fire hydrant 12 | stop sign 13 | parking meter 14 | bench 15 | bird 16 | cat 17 | dog 18 | horse 19 | sheep 20 | cow 21 | elephant 22 | bear 23 | zebra 24 | giraffe 25 | backpack 26 | umbrella 27 | handbag 28 | tie 29 | suitcase 30 | frisbee 31 | skis 32 | snowboard 33 | sports ball 34 | kite 35 | baseball bat 36 | baseball glove 37 | skateboard 38 | surfboard 39 | tennis racket 40 | bottle 41 | wine glass 42 | cup 43 | fork 44 | knife 45 | spoon 46 | bowl 47 | banana 48 | apple 49 | sandwich 50 | orange 51 | broccoli 52 | carrot 53 | hot dog 54 | pizza 55 | donut 56 | cake 57 | chair 58 | sofa 59 | pottedplant 60 | bed 61 | diningtable 62 | toilet 63 | tvmonitor 64 | laptop 65 | mouse 66 | remote 67 | keyboard 68 | cell phone 69 | microwave 70 | oven 71 | toaster 72 | sink 73 | refrigerator 74 | book 75 | clock 76 | vase 77 | scissors 78 | teddy bear 79 | hair drier 80 | toothbrush -------------------------------------------------------------------------------- /src/main/resources/dog.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jhgan00/java-ort-example-yolov5/5f6ea876f8663d439fd1e986d317368c3463d449/src/main/resources/dog.jpg -------------------------------------------------------------------------------- /src/main/resources/dog.png: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/jhgan00/java-ort-example-yolov5/5f6ea876f8663d439fd1e986d317368c3463d449/src/main/resources/dog.png -------------------------------------------------------------------------------- /src/main/resources/kite.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jhgan00/java-ort-example-yolov5/5f6ea876f8663d439fd1e986d317368c3463d449/src/main/resources/kite.jpg -------------------------------------------------------------------------------- /src/main/resources/model.properties: -------------------------------------------------------------------------------- 1 | modelName=yolov5 2 | modelPath=yolov5s.onnx 3 | #modelName=yolov8 4 | #modelPath=yolov8s.onnx 5 | labelPath=coco.names 6 | confThreshold=0.25 7 | nmsThreshold=0.45 8 | gpuDeviceId=-1 9 | -------------------------------------------------------------------------------- /src/main/resources/yolov5s.onnx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jhgan00/java-ort-example-yolov5/5f6ea876f8663d439fd1e986d317368c3463d449/src/main/resources/yolov5s.onnx -------------------------------------------------------------------------------- /src/main/resources/yolov5s.sparseml.onnx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jhgan00/java-ort-example-yolov5/5f6ea876f8663d439fd1e986d317368c3463d449/src/main/resources/yolov5s.sparseml.onnx -------------------------------------------------------------------------------- /src/main/resources/yolov8s.onnx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jhgan00/java-ort-example-yolov5/5f6ea876f8663d439fd1e986d317368c3463d449/src/main/resources/yolov8s.onnx --------------------------------------------------------------------------------