├── README.md ├── images ├── OD-bus.jpg ├── OD-test.jpg ├── PE-bus.jpg ├── PE-test.jpg ├── bus.jpg └── test.jpg ├── pom.xml └── src └── main ├── java └── cn │ └── halashuo │ ├── ObjectDetection.java │ ├── PoseEstimation.java │ ├── config │ ├── ODConfig.java │ └── PEConfig.java │ ├── domain │ ├── KeyPoint.java │ ├── ODResult.java │ └── PEResult.java │ └── utils │ ├── Letterbox.java │ └── NMS.java └── resources └── lib └── opencv-460.jar /README.md: -------------------------------------------------------------------------------- 1 | ## Downloading ONNX files 2 | 3 | A few pre-trained ONNX files are available: [click here to download](http://pan.halashuo.cn/?dir=data/ONNX). After downloading a model, or converting your own trained model to ONNX, put it in the `resources\model` folder; alternatively, you can change where the ONNX file is loaded from by modifying `PEConfig.java` or `ODConfig.java`. 4 | 5 | ## Configure the environment 6 | 7 | This project mainly uses ONNX Runtime and OpenCV. 8 | 9 | Before using this project, copy the `opencv_java460.dll` produced by the OpenCV installation into the `C:\Windows\System32` folder; otherwise a `no opencv_java460 in java.library.path` error will be thrown. 
10 | 11 | ## Predict the outcome 12 | 13 | ### Object Detection 14 | 15 | ![](images/OD-bus.jpg) 16 | 17 | ![](images/OD-test.jpg) 18 | 19 | ### Pose Estimation 20 | 21 | ![](images/PE-bus.jpg) 22 | 23 | ![](images/PE-test.jpg) 24 | 25 | -------------------------------------------------------------------------------- /images/OD-bus.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PanBohao/yolov7-java-onnx/1dea6bb08f8600d1ba41b7229095123f4e36a48e/images/OD-bus.jpg -------------------------------------------------------------------------------- /images/OD-test.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PanBohao/yolov7-java-onnx/1dea6bb08f8600d1ba41b7229095123f4e36a48e/images/OD-test.jpg -------------------------------------------------------------------------------- /images/PE-bus.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PanBohao/yolov7-java-onnx/1dea6bb08f8600d1ba41b7229095123f4e36a48e/images/PE-bus.jpg -------------------------------------------------------------------------------- /images/PE-test.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PanBohao/yolov7-java-onnx/1dea6bb08f8600d1ba41b7229095123f4e36a48e/images/PE-test.jpg -------------------------------------------------------------------------------- /images/bus.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PanBohao/yolov7-java-onnx/1dea6bb08f8600d1ba41b7229095123f4e36a48e/images/bus.jpg -------------------------------------------------------------------------------- /images/test.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/PanBohao/yolov7-java-onnx/1dea6bb08f8600d1ba41b7229095123f4e36a48e/images/test.jpg -------------------------------------------------------------------------------- /pom.xml: -------------------------------------------------------------------------------- 1 | 2 | 5 | 4.0.0 6 | 7 | cn.halashuo 8 | yolov7 9 | 1.0 10 | 11 | 12 | 8 13 | 8 14 | 15 | 16 | 17 | 18 | com.microsoft.onnxruntime 19 | onnxruntime 20 | 1.12.1 21 | 22 | 23 | 24 | org.opencv 25 | opencv 26 | 4.6.0 27 | system 28 | ${project.basedir}/src/main/resources/lib/opencv-460.jar 29 | 30 | 31 | 32 | -------------------------------------------------------------------------------- /src/main/java/cn/halashuo/ObjectDetection.java: --------------------------------------------------------------------------------
package cn.halashuo;

import ai.onnxruntime.OnnxTensor;
import ai.onnxruntime.OrtEnvironment;
import ai.onnxruntime.OrtException;
import ai.onnxruntime.OrtSession;
import cn.halashuo.domain.ODResult;
import cn.halashuo.config.ODConfig;
import cn.halashuo.utils.Letterbox;
import org.opencv.core.Core;
import org.opencv.core.Mat;
import org.opencv.core.Point;
import org.opencv.core.Scalar;
import org.opencv.highgui.HighGui;
import org.opencv.imgcodecs.Imgcodecs;
import org.opencv.imgproc.Imgproc;

import java.nio.FloatBuffer;
import java.util.Arrays;
import java.util.HashMap;

/**
 * Command-line demo: runs the ONNX object-detection model configured in
 * {@link ODConfig} on one image, draws every returned box with its label and
 * score, saves the annotated image and shows it in a window.
 */
public class ObjectDetection {

    static {
        // The OpenCV native library must be loaded before any OpenCV class is used.
        System.loadLibrary(Core.NATIVE_LIBRARY_NAME);
    }

    /**
     * Entry point of the demo.
     *
     * @param args ignored
     * @throws OrtException          if the model cannot be loaded or executed
     * @throws IllegalStateException if the input image cannot be read
     */
    public static void main(String[] args) throws OrtException {
        // Labels and per-class colours.
        ODConfig odConfig = new ODConfig();

        // Read the image; imread returns an empty Mat instead of throwing on failure.
        Mat img = Imgcodecs.imread(ODConfig.picPath);
        if (img.empty()) {
            throw new IllegalStateException("Cannot read image: " + ODConfig.picPath);
        }
        Imgproc.cvtColor(img, img, Imgproc.COLOR_BGR2RGB);

        // Box thickness and font size scale with the image size. The thickness is
        // clamped to at least 1 because the integer division yields 0 for small images.
        int minDwDh = Math.min(img.width(), img.height());
        int thickness = Math.max(1, minDwDh / ODConfig.lineThicknessRatio);
        double fontSize = minDwDh / ODConfig.fontSizeRatio;
        int fontFace = Imgproc.FONT_HERSHEY_SIMPLEX;
        Scalar fontColor = new Scalar(255, 255, 255);

        // Resize and pad a copy of the image; the original is kept for drawing.
        Letterbox letterbox = new Letterbox();
        Mat image = letterbox.letterbox(img.clone());
        double ratio = letterbox.getRatio();
        double dw = letterbox.getDw();
        double dh = letterbox.getDh();

        float[] pixels = toChwFloats(image);
        long[] shape = {1L, image.channels(), image.rows(), image.cols()};
        float[][] outputData = detect(pixels, shape);

        for (float[] row : outputData) {
            ODResult odResult = new ODResult(row);
            System.out.println(odResult);
            // Map the box from letterboxed coordinates back to the original image.
            Point topLeft = new Point((odResult.getX0() - dw) / ratio, (odResult.getY0() - dh) / ratio);
            Point bottomRight = new Point((odResult.getX1() - dw) / ratio, (odResult.getY1() - dh) / ratio);
            Scalar color = new Scalar(odConfig.getColor(odResult.getClsId()));
            Imgproc.rectangle(img, topLeft, bottomRight, color, thickness);
            // Label slightly above the top-left corner of the box.
            String boxName = odConfig.getName(odResult.getClsId()) + ": " + odResult.getScore();
            Point boxNameLoc = new Point(topLeft.x, topLeft.y - 3);
            Imgproc.putText(img, boxName, boxNameLoc, fontFace, fontSize, fontColor, thickness);
        }

        Imgproc.cvtColor(img, img, Imgproc.COLOR_RGB2BGR);

        // Save and display the annotated image.
        if (!Imgcodecs.imwrite(ODConfig.savePicPath, img)) {
            System.err.println("Failed to write image: " + ODConfig.savePicPath);
        }
        HighGui.imshow("Display Image", img);
        // Block until a key is pressed.
        HighGui.waitKey();
    }

    /**
     * Converts an interleaved (height, width, channel) image into a planar
     * (channel, height, width) float array scaled to [0, 1].
     */
    private static float[] toChwFloats(Mat image) {
        int rows = image.rows();
        int cols = image.cols();
        int channels = image.channels();
        int planeSize = rows * cols;
        float[] pixels = new float[channels * planeSize];
        for (int row = 0; row < rows; row++) {
            for (int col = 0; col < cols; col++) {
                // Mat.get takes (row, col); the previous code passed them swapped,
                // which only worked because the letterboxed image is square.
                double[] pixel = image.get(row, col);
                for (int c = 0; c < channels; c++) {
                    pixels[planeSize * c + row * cols + col] = (float) pixel[c] / 255.0f;
                }
            }
        }
        return pixels;
    }

    /**
     * Loads the model, prints its input description, runs it on one tensor and
     * returns the first output. All native ONNX Runtime resources are closed
     * before returning.
     */
    private static float[][] detect(float[] pixels, long[] shape) throws OrtException {
        OrtEnvironment environment = OrtEnvironment.getEnvironment();
        try (OrtSession.SessionOptions sessionOptions = new OrtSession.SessionOptions();
             OrtSession session = environment.createSession(ODConfig.modelPath, sessionOptions)) {
            for (String inputName : session.getInputInfo().keySet()) {
                System.out.println("input name = " + inputName);
                System.out.println(session.getInputInfo().get(inputName).getInfo().toString());
            }
            try (OnnxTensor tensor = OnnxTensor.createTensor(environment, FloatBuffer.wrap(pixels), shape)) {
                HashMap<String, OnnxTensor> inputs = new HashMap<>();
                inputs.put(session.getInputInfo().keySet().iterator().next(), tensor);
                try (OrtSession.Result output = session.run(inputs)) {
                    return (float[][]) output.get(0).getValue();
                }
            }
        }
    }
}
-------------------------------------------------------------------------------- /src/main/java/cn/halashuo/PoseEstimation.java: -------------------------------------------------------------------------------- 1 | package cn.halashuo; 2 | 3 | import ai.onnxruntime.OnnxTensor; 4 | import ai.onnxruntime.OrtEnvironment; 5 | import ai.onnxruntime.OrtException; 6 | import ai.onnxruntime.OrtSession; 7 | import cn.halashuo.domain.KeyPoint; 8 | import cn.halashuo.domain.PEResult; 9 | import cn.halashuo.utils.Letterbox; 10 | import cn.halashuo.utils.NMS; 11 | import cn.halashuo.config.PEConfig; 12 | import org.opencv.core.*; 13 | import org.opencv.highgui.HighGui; 14 | import org.opencv.imgcodecs.Imgcodecs; 15 | import org.opencv.imgproc.Imgproc; 16 | 17 | import java.nio.FloatBuffer; 18 | import java.util.ArrayList; 19 | import java.util.HashMap; 20 | import java.util.List; 21 | 22 | public class PoseEstimation
{ 23 | static 24 | { 25 | //在使用OpenCV前必须加载Core.NATIVE_LIBRARY_NAME类,否则会报错 26 | System.loadLibrary(Core.NATIVE_LIBRARY_NAME); 27 | } 28 | 29 | public static void main(String[] args) throws OrtException { 30 | 31 | // 加载ONNX模型 32 | OrtEnvironment environment = OrtEnvironment.getEnvironment(); 33 | OrtSession.SessionOptions sessionOptions = new OrtSession.SessionOptions(); 34 | OrtSession session = environment.createSession(PEConfig.modelPath, sessionOptions); 35 | // 输出基本信息 36 | session.getInputInfo().keySet().forEach(x -> { 37 | try { 38 | System.out.println("input name = " + x); 39 | System.out.println(session.getInputInfo().get(x).getInfo().toString()); 40 | } catch (OrtException e) { 41 | throw new RuntimeException(e); 42 | } 43 | }); 44 | 45 | // 读取 image 46 | Mat img = Imgcodecs.imread(PEConfig.picPath); 47 | Imgproc.cvtColor(img, img, Imgproc.COLOR_BGR2RGB); 48 | Mat image = img.clone(); 49 | 50 | // 在这里先定义下线的粗细、关键的半径(按比例设置大小粗细比较好一些) 51 | int minDwDh = Math.min(img.width(), img.height()); 52 | int thickness = minDwDh / PEConfig.lineThicknessRatio; 53 | int radius = minDwDh / PEConfig.dotRadiusRatio; 54 | 55 | // 更改 image 尺寸 56 | Letterbox letterbox = new Letterbox(); 57 | letterbox.setNewShape(new Size(960, 960)); 58 | letterbox.setStride(64); 59 | image = letterbox.letterbox(image); 60 | double ratio = letterbox.getRatio(); 61 | double dw = letterbox.getDw(); 62 | double dh = letterbox.getDh(); 63 | int rows = letterbox.getHeight(); 64 | int cols = letterbox.getWidth(); 65 | int channels = image.channels(); 66 | 67 | // 将Mat对象的像素值赋值给Float[]对象 68 | float[] pixels = new float[channels * rows * cols]; 69 | for (int i = 0; i < rows; i++) { 70 | for (int j = 0; j < cols; j++) { 71 | double[] pixel = image.get(j, i); 72 | for (int k = 0; k < channels; k++) { 73 | // 这样设置相当于同时做了image.transpose((2, 0, 1))操作 74 | pixels[rows * cols * k + j * cols + i] = (float) pixel[k] / 255.0f; 75 | } 76 | } 77 | } 78 | 79 | // 创建OnnxTensor对象 80 | long[] shape = {1L, (long) 
channels, (long) rows, (long) cols}; 81 | OnnxTensor tensor = OnnxTensor.createTensor(environment, FloatBuffer.wrap(pixels), shape); 82 | HashMap stringOnnxTensorHashMap = new HashMap<>(); 83 | stringOnnxTensorHashMap.put(session.getInputInfo().keySet().iterator().next(), tensor); 84 | 85 | // 运行模型 86 | OrtSession.Result output = session.run(stringOnnxTensorHashMap); 87 | 88 | // 得到结果 89 | float[][] outputData = ((float[][][]) output.get(0).getValue())[0]; 90 | 91 | List peResults = new ArrayList<>(); 92 | for (float[] outputDatum : outputData) { 93 | PEResult result = new PEResult(outputDatum); 94 | if (result.getScore() > PEConfig.personScoreThreshold) { 95 | peResults.add(result); 96 | } 97 | } 98 | 99 | // 对结果进行非极大值抑制 100 | peResults = NMS.nms(peResults, PEConfig.IoUThreshold); 101 | 102 | for (PEResult peResult: peResults) { 103 | System.out.println(peResult); 104 | // 画框 105 | Point topLeft = new Point((peResult.getX0()-dw)/ratio, (peResult.getY0()-dh)/ratio); 106 | Point bottomRight = new Point((peResult.getX1()-dw)/ratio, (peResult.getY1()-dh)/ratio); 107 | Imgproc.rectangle(img, topLeft, bottomRight, new Scalar(255,0,0), thickness); 108 | List keyPoints = peResult.getKeyPointList(); 109 | // 画点 110 | keyPoints.forEach(keyPoint->{ 111 | if (keyPoint.getScore()>PEConfig.keyPointScoreThreshold) { 112 | Point center = new Point((keyPoint.getX()-dw)/ratio, (keyPoint.getY()-dh)/ratio); 113 | Scalar color = PEConfig.poseKptColor.get(keyPoint.getId()); 114 | Imgproc.circle(img, center, radius, color, -1); //-1表示实心 115 | } 116 | }); 117 | // 画线 118 | for (int i = 0; i< PEConfig.skeleton.length; i++){ 119 | int indexPoint1 = PEConfig.skeleton[i][0]-1; 120 | int indexPoint2 = PEConfig.skeleton[i][1]-1; 121 | if ( keyPoints.get(indexPoint1).getScore()>PEConfig.keyPointScoreThreshold && 122 | keyPoints.get(indexPoint2).getScore()>PEConfig.keyPointScoreThreshold ) { 123 | Scalar coler = PEConfig.poseLimbColor.get(i); 124 | Point point1 = new Point( 125 | 
(keyPoints.get(indexPoint1).getX()-dw)/ratio, 126 | (keyPoints.get(indexPoint1).getY()-dh)/ratio 127 | ); 128 | Point point2 = new Point( 129 | (keyPoints.get(indexPoint2).getX()-dw)/ratio, 130 | (keyPoints.get(indexPoint2).getY()-dh)/ratio 131 | ); 132 | Imgproc.line(img, point1, point2, coler, thickness); 133 | } 134 | } 135 | } 136 | Imgproc.cvtColor(img, img, Imgproc.COLOR_RGB2BGR); 137 | // 保存图像 138 | Imgcodecs.imwrite(PEConfig.savePicPath, img); 139 | HighGui.imshow("Display Image", img); 140 | // 等待按下任意键继续执行程序 141 | HighGui.waitKey(); 142 | 143 | } 144 | } -------------------------------------------------------------------------------- /src/main/java/cn/halashuo/config/ODConfig.java: -------------------------------------------------------------------------------- 1 | package cn.halashuo.config; 2 | 3 | import java.util.Random; 4 | import java.util.ArrayList; 5 | import java.util.Arrays; 6 | import java.util.List; 7 | import java.util.Map; 8 | import java.util.HashMap; 9 | 10 | public final class ODConfig { 11 | 12 | public static final String modelPath = "src\\main\\resources\\model\\yolov7-d6.onnx"; 13 | public static final String picPath = "images\\test.jpg";// 要预测图片位置 14 | public static final String savePicPath = "images\\OD-test.jpg";// 预测结果保存位置 15 | public static final Integer lineThicknessRatio = 333; 16 | public static final Double fontSizeRatio = 1145.14; 17 | private static final List names = new ArrayList<>(Arrays.asList( 18 | "person", "bicycle", "car", "motorcycle", "airplane", "bus", "train", 19 | "truck", "boat", "traffic light", "fire hydrant", "stop sign", "parking meter", 20 | "bench", "bird", "cat", "dog", "horse", "sheep", "cow", "elephant", "bear", 21 | "zebra", "giraffe", "backpack", "umbrella", "handbag", "tie", "suitcase", 22 | "frisbee", "skis", "snowboard", "sports ball", "kite", "baseball bat", 23 | "baseball glove", "skateboard", "surfboard", "tennis racket", "bottle", 24 | "wine glass", "cup", "fork", "knife", "spoon", "bowl", 
"banana", "apple", 25 | "sandwich", "orange", "broccoli", "carrot", "hot dog", "pizza", "donut", 26 | "cake", "chair", "couch", "potted plant", "bed", "dining table", "toilet", 27 | "tv", "laptop", "mouse", "remote", "keyboard", "cell phone", "microwave", 28 | "oven", "toaster", "sink", "refrigerator", "book", "clock", "vase", "scissors", 29 | "teddy bear", "hair drier", "toothbrush")); 30 | 31 | private final Map colors; 32 | 33 | public ODConfig() { 34 | this.colors = new HashMap<>(); 35 | names.forEach(name->{ 36 | Random random = new Random(); 37 | double[] color = {random.nextDouble()*256, random.nextDouble()*256, random.nextDouble()*256}; 38 | colors.put(name, color); 39 | }); 40 | } 41 | 42 | public String getName(int clsId) { 43 | return names.get(clsId); 44 | } 45 | 46 | public double[] getColor(int clsId) { 47 | return colors.get(getName(clsId)); 48 | } 49 | } -------------------------------------------------------------------------------- /src/main/java/cn/halashuo/config/PEConfig.java: -------------------------------------------------------------------------------- 1 | package cn.halashuo.config; 2 | 3 | import org.opencv.core.Scalar; 4 | 5 | import java.util.ArrayList; 6 | import java.util.Arrays; 7 | import java.util.List; 8 | 9 | public final class PEConfig { 10 | 11 | public static final String modelPath = "src\\main\\resources\\model\\yolov7-w6-pose.onnx"; 12 | public static final String picPath = "images\\test.jpg";// 要预测图片位置 13 | public static final String savePicPath = "images\\PE-test.jpg";// 预测结果保存位置 14 | public static final float IoUThreshold = 0.65f; 15 | public static final float personScoreThreshold = 0.25f; 16 | public static final float keyPointScoreThreshold = 0.45f; 17 | // 根据图像大小按比例控制点大小及线粗 18 | public static final Integer dotRadiusRatio = 168; 19 | public static final Integer lineThicknessRatio = 333; 20 | 21 | 22 | public static final List palette= new ArrayList<>(Arrays.asList( 23 | new Scalar( 255, 128, 0 ), 24 | new Scalar( 255, 
153, 51 ), 25 | new Scalar( 255, 178, 102 ), 26 | new Scalar( 230, 230, 0 ), 27 | new Scalar( 255, 153, 255 ), 28 | new Scalar( 153, 204, 255 ), 29 | new Scalar( 255, 102, 255 ), 30 | new Scalar( 255, 51, 255 ), 31 | new Scalar( 102, 178, 255 ), 32 | new Scalar( 51, 153, 255 ), 33 | new Scalar( 255, 153, 153 ), 34 | new Scalar( 255, 102, 102 ), 35 | new Scalar( 255, 51, 51 ), 36 | new Scalar( 153, 255, 153 ), 37 | new Scalar( 102, 255, 102 ), 38 | new Scalar( 51, 255, 51 ), 39 | new Scalar( 0, 255, 0 ), 40 | new Scalar( 0, 0, 255 ), 41 | new Scalar( 255, 0, 0 ), 42 | new Scalar( 255, 255, 255 ) 43 | )); 44 | 45 | public static final int[][] skeleton = { 46 | {16, 14}, {14, 12}, {17, 15}, {15, 13}, {12, 13}, {6, 12}, 47 | {7, 13}, {6, 7}, {6, 8}, {7, 9}, {8, 10}, {9, 11}, {2, 3}, 48 | {1, 2}, {1, 3}, {2, 4}, {3, 5}, {4, 6}, {5, 7} 49 | }; 50 | 51 | public static final List poseLimbColor = new ArrayList<>(Arrays.asList( 52 | palette.get(9), palette.get(9), palette.get(9), palette.get(9), palette.get(7), 53 | palette.get(7), palette.get(7), palette.get(0), palette.get(0), palette.get(0), 54 | palette.get(0), palette.get(0), palette.get(16), palette.get(16), palette.get(16), 55 | palette.get(16), palette.get(16), palette.get(16), palette.get(16))); 56 | 57 | public static final List poseKptColor = new ArrayList<>(Arrays.asList( 58 | palette.get(16), palette.get(16), palette.get(16), palette.get(16), palette.get(16), 59 | palette.get(0), palette.get(0), palette.get(0), palette.get(0), palette.get(0), 60 | palette.get(0), palette.get(9), palette.get(9), palette.get(9), palette.get(9), 61 | palette.get(9), palette.get(9))); 62 | 63 | } -------------------------------------------------------------------------------- /src/main/java/cn/halashuo/domain/KeyPoint.java: -------------------------------------------------------------------------------- 1 | package cn.halashuo.domain; 2 | 3 | public class KeyPoint { 4 | private final Integer id; 5 | private final Float x; 6 | 
private final Float y; 7 | private final Float score; 8 | 9 | public KeyPoint(Integer id, Float x, Float y, Float score) { 10 | this.id = id; 11 | this.x = x; 12 | this.y = y; 13 | this.score = score; 14 | } 15 | 16 | public Integer getId() { 17 | return id; 18 | } 19 | 20 | public Float getX() { 21 | return x; 22 | } 23 | 24 | public Float getY() { 25 | return y; 26 | } 27 | 28 | public Float getScore() { 29 | return score; 30 | } 31 | 32 | @Override 33 | public String toString() { 34 | return " 第 " + (id+1) + " 个关键点: " + 35 | " x=" + x + 36 | " y=" + y + 37 | " c=" + score + 38 | "\n"; 39 | } 40 | } 41 | -------------------------------------------------------------------------------- /src/main/java/cn/halashuo/domain/ODResult.java: -------------------------------------------------------------------------------- 1 | package cn.halashuo.domain; 2 | 3 | import java.text.DecimalFormat; 4 | 5 | public class ODResult { 6 | private final Integer batchId; 7 | private final Float x0; 8 | private final Float y0; 9 | private final Float x1; 10 | private final Float y1; 11 | private final Integer clsId; 12 | private final Float score; 13 | 14 | public ODResult(float[] x) { 15 | this.batchId = (int) x[0]; 16 | this.x0 = x[1]; 17 | this.y0 = x[2]; 18 | this.x1 = x[3]; 19 | this.y1 = x[4]; 20 | this.clsId = (int) x[5]; 21 | this.score = x[6]; 22 | } 23 | 24 | public Integer getBatchId() { 25 | return batchId; 26 | } 27 | 28 | public Float getX0() { 29 | return x0; 30 | } 31 | 32 | public Float getY0() { 33 | return y0; 34 | } 35 | 36 | public Float getX1() { 37 | return x1; 38 | } 39 | 40 | public Float getY1() { 41 | return y1; 42 | } 43 | 44 | public Integer getClsId() { 45 | return clsId; 46 | } 47 | 48 | public String getScore() { 49 | DecimalFormat df = new DecimalFormat("0.00%"); 50 | return df.format(this.score); 51 | } 52 | 53 | @Override 54 | public String toString() { 55 | return "物体: " + 56 | " \t batchId=" + batchId + 57 | " \t x0=" + x0 + 58 | " \t y0=" + y0 + 59 | 
" \t x1=" + x1 + 60 | " \t y1=" + y1 + 61 | " \t clsId=" + clsId + 62 | " \t score=" + getScore() + 63 | " \t ;"; 64 | } 65 | } -------------------------------------------------------------------------------- /src/main/java/cn/halashuo/domain/PEResult.java: --------------------------------------------------------------------------------
package cn.halashuo.domain;

import java.util.ArrayList;
import java.util.List;

/**
 * One person candidate produced by the pose-estimation model: a bounding box
 * in corner form, a confidence score, a class id and the person's key points.
 */
public class PEResult {

    private final Float x0;
    private final Float y0;
    private final Float x1;
    private final Float y1;
    private final Float score;
    private final Integer clsId;
    private final List<KeyPoint> keyPointList;

    /**
     * Parses one raw output row.
     *
     * @param peResult row laid out as {@code [centerX, centerY, width, height,
     *                 score, classId, (x, y, score) per key point ...]}
     */
    public PEResult(float[] peResult) {
        float x = peResult[0];
        float y = peResult[1];
        float w = peResult[2] / 2.0f;
        float h = peResult[3] / 2.0f;
        // Convert centre/size to top-left and bottom-right corners.
        this.x0 = x - w;
        this.y0 = y - h;
        this.x1 = x + w;
        this.y1 = y + h;
        this.score = peResult[4];
        this.clsId = (int) peResult[5];
        this.keyPointList = new ArrayList<>();
        // After the six box fields, every key point occupies three values.
        // NOTE(review): the loop body and the six getters below were lost when this
        // file was extracted; they are restored from the KeyPoint constructor, the
        // three-values-per-point layout and the getters used by callers. Confirm
        // against the upstream repository.
        int keyPointNum = (peResult.length - 6) / 3;
        for (int i = 0; i < keyPointNum; i++) {
            int base = 6 + 3 * i;
            this.keyPointList.add(new KeyPoint(i, peResult[base], peResult[base + 1], peResult[base + 2]));
        }
    }

    public Float getX0() {
        return x0;
    }

    public Float getY0() {
        return y0;
    }

    public Float getX1() {
        return x1;
    }

    public Float getY1() {
        return y1;
    }

    public Float getScore() {
        return score;
    }

    public Integer getClsId() {
        return clsId;
    }

    public List<KeyPoint> getKeyPointList() {
        return keyPointList;
    }

    @Override
    public String toString() {
        StringBuilder result = new StringBuilder("PEResult:" +
                " x0=" + x0 +
                ", y0=" + y0 +
                ", x1=" + x1 +
                ", y1=" + y1 +
                ", score=" + score +
                ", clsId=" + clsId +
                "\n");
        for (KeyPoint x : keyPointList) {
            result.append(x.toString());
        }
        return result.toString();
    }
}

-------------------------------------------------------------------------------- /src/main/java/cn/halashuo/utils/Letterbox.java: -------------------------------------------------------------------------------- 1 | package cn.halashuo.utils; 2 | 3 | import org.opencv.core.Core; 4 | import org.opencv.core.Mat; 5 | import org.opencv.core.Size; 6 | import org.opencv.imgproc.Imgproc; 7 | 8 | public class Letterbox { 9 | 10 | private Size 
newShape = new Size(1280, 1280); 11 | private final double[] color = new double[]{114,114,114}; 12 | private final Boolean auto = false; 13 | private final Boolean scaleUp = true; 14 | private Integer stride = 32; 15 | 16 | private double ratio; 17 | private double dw; 18 | private double dh; 19 | 20 | public double getRatio() { 21 | return ratio; 22 | } 23 | 24 | public double getDw() { 25 | return dw; 26 | } 27 | 28 | public Integer getWidth() { 29 | return (int) this.newShape.width; 30 | } 31 | 32 | public Integer getHeight() { 33 | return (int) this.newShape.height; 34 | } 35 | 36 | public double getDh() { 37 | return dh; 38 | } 39 | 40 | public void setNewShape(Size newShape) { 41 | this.newShape = newShape; 42 | } 43 | 44 | public void setStride(Integer stride) { 45 | this.stride = stride; 46 | } 47 | 48 | public Mat letterbox(Mat im) { // 调整图像大小和填充图像,使满足步长约束,并记录参数 49 | 50 | int[] shape = {im.rows(), im.cols()}; // 当前形状 [height, width] 51 | // Scale ratio (new / old) 52 | double r = Math.min(this.newShape.height / shape[0], this.newShape.width / shape[1]); 53 | if (!this.scaleUp) { // 仅缩小,不扩大(一且为了mAP) 54 | r = Math.min(r, 1.0); 55 | } 56 | // Compute padding 57 | Size newUnpad = new Size(Math.round(shape[1] * r), Math.round(shape[0] * r)); 58 | double dw = this.newShape.width - newUnpad.width, dh = this.newShape.height - newUnpad.height; // wh 填充 59 | if (this.auto) { // 最小矩形 60 | dw = dw % this.stride; 61 | dh = dh % this.stride; 62 | } 63 | dw /= 2; // 填充的时候两边都填充一半,使图像居于中心 64 | dh /= 2; 65 | if (shape[1] != newUnpad.width || shape[0] != newUnpad.height) { // resize 66 | Imgproc.resize(im, im, newUnpad, 0, 0, Imgproc.INTER_LINEAR); 67 | } 68 | int top = (int) Math.round(dh - 0.1), bottom = (int) Math.round(dh + 0.1); 69 | int left = (int) Math.round(dw - 0.1), right = (int) Math.round(dw + 0.1); 70 | // 将图像填充为正方形 71 | Core.copyMakeBorder(im, im, top, bottom, left, right, Core.BORDER_CONSTANT, new org.opencv.core.Scalar(this.color)); 72 | this.ratio = r; 73 | 
this.dh = dh; 74 | this.dw = dw; 75 | return im; 76 | } 77 | } -------------------------------------------------------------------------------- /src/main/java/cn/halashuo/utils/NMS.java: --------------------------------------------------------------------------------
package cn.halashuo.utils;

import cn.halashuo.domain.PEResult;

import java.util.ArrayList;
import java.util.List;

/** Greedy non-maximum suppression over pose-estimation candidates. */
public class NMS {

    /**
     * Returns the candidates that survive non-maximum suppression.
     *
     * <p>Candidates are visited from highest to lowest score; a candidate is kept
     * only if its IoU with every box already kept is at most
     * {@code iouThreshold}. The previous implementation compared each box with
     * the lower-scored boxes after it and discarded the higher-scored one, so
     * only the weakest box of an overlapping cluster survived.
     *
     * @param boxes        candidates to filter; not modified
     * @param iouThreshold maximum allowed overlap between two kept boxes
     * @return a new list of kept boxes, ordered by descending score
     */
    public static List<PEResult> nms(List<PEResult> boxes, float iouThreshold) {
        // Sort a copy so the caller's list keeps its order.
        List<PEResult> sorted = new ArrayList<>(boxes);
        sorted.sort((b1, b2) -> Float.compare(b2.getScore(), b1.getScore()));

        List<PEResult> kept = new ArrayList<>();
        for (PEResult candidate : sorted) {
            boolean suppressed = false;
            for (PEResult winner : kept) {
                if (getIntersectionOverUnion(candidate, winner) > iouThreshold) {
                    suppressed = true;
                    break;
                }
            }
            if (!suppressed) {
                kept.add(candidate);
            }
        }
        return kept;
    }

    /** Intersection-over-union of two boxes; 0 when the union has no area. */
    private static float getIntersectionOverUnion(PEResult box1, PEResult box2) {
        float x1 = Math.max(box1.getX0(), box2.getX0());
        float y1 = Math.max(box1.getY0(), box2.getY0());
        float x2 = Math.min(box1.getX1(), box2.getX1());
        float y2 = Math.min(box1.getY1(), box2.getY1());
        float intersectionArea = Math.max(0, x2 - x1) * Math.max(0, y2 - y1);
        float box1Area = (box1.getX1() - box1.getX0()) * (box1.getY1() - box1.getY0());
        float box2Area = (box2.getX1() - box2.getX0()) * (box2.getY1() - box2.getY0());
        float unionArea = box1Area + box2Area - intersectionArea;
        if (unionArea <= 0) {
            // Degenerate boxes: avoid 0/0 producing NaN.
            return 0f;
        }
        return intersectionArea / unionArea;
    }
}
-------------------------------------------------------------------------------- /src/main/resources/lib/opencv-460.jar: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/PanBohao/yolov7-java-onnx/1dea6bb08f8600d1ba41b7229095123f4e36a48e/src/main/resources/lib/opencv-460.jar --------------------------------------------------------------------------------