├── README.md
├── images
│   ├── OD-bus.jpg
│   ├── OD-test.jpg
│   ├── PE-bus.jpg
│   ├── PE-test.jpg
│   ├── bus.jpg
│   └── test.jpg
├── pom.xml
└── src
    └── main
        ├── java
        │   └── cn
        │       └── halashuo
        │           ├── ObjectDetection.java
        │           ├── PoseEstimation.java
        │           ├── config
        │           │   ├── ODConfig.java
        │           │   └── PEConfig.java
        │           ├── domain
        │           │   ├── KeyPoint.java
        │           │   ├── ODResult.java
        │           │   └── PEResult.java
        │           └── utils
        │               ├── Letterbox.java
        │               └── NMS.java
        └── resources
            └── lib
                └── opencv-460.jar
/README.md:
--------------------------------------------------------------------------------
1 | ## Downloading the ONNX models
2 |
3 | Some pre-trained ONNX models are available [here](http://pan.halashuo.cn/?dir=data/ONNX). After downloading a model, or after converting your own trained model to ONNX, put it in the `src/main/resources/model` folder. Alternatively, use a custom location by changing `modelPath` in `PEConfig.java` or `ODConfig.java`.
4 |
5 | ## Configure the environment
6 |
7 | This project mainly depends on ONNX Runtime and OpenCV.
8 |
9 | Before running this project, copy `opencv_java460.dll` from your OpenCV 4.6.0 installation into the `C:\Windows\System32` folder. Otherwise the program fails with the error `no opencv_java460 in java.library.path`.
10 |
11 | ## Prediction results
12 |
13 | ### Object Detection
14 |
15 | 
16 |
17 | 
18 |
19 | ### Pose Estimation
20 |
21 | 
22 |
23 | 
24 |
25 |
--------------------------------------------------------------------------------
/images/OD-bus.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PanBohao/yolov7-java-onnx/1dea6bb08f8600d1ba41b7229095123f4e36a48e/images/OD-bus.jpg
--------------------------------------------------------------------------------
/images/OD-test.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PanBohao/yolov7-java-onnx/1dea6bb08f8600d1ba41b7229095123f4e36a48e/images/OD-test.jpg
--------------------------------------------------------------------------------
/images/PE-bus.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PanBohao/yolov7-java-onnx/1dea6bb08f8600d1ba41b7229095123f4e36a48e/images/PE-bus.jpg
--------------------------------------------------------------------------------
/images/PE-test.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PanBohao/yolov7-java-onnx/1dea6bb08f8600d1ba41b7229095123f4e36a48e/images/PE-test.jpg
--------------------------------------------------------------------------------
/images/bus.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PanBohao/yolov7-java-onnx/1dea6bb08f8600d1ba41b7229095123f4e36a48e/images/bus.jpg
--------------------------------------------------------------------------------
/images/test.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PanBohao/yolov7-java-onnx/1dea6bb08f8600d1ba41b7229095123f4e36a48e/images/test.jpg
--------------------------------------------------------------------------------
/pom.xml:
--------------------------------------------------------------------------------
1 |
2 |
5 | 4.0.0
6 |
7 | cn.halashuo
8 | yolov7
9 | 1.0
10 |
11 |
12 | 8
13 | 8
14 |
15 |
16 |
17 |
18 | com.microsoft.onnxruntime
19 | onnxruntime
20 | 1.12.1
21 |
22 |
23 |
24 | org.opencv
25 | opencv
26 | 4.6.0
27 | system
28 | ${project.basedir}/src/main/resources/lib/opencv-460.jar
29 |
30 |
31 |
32 |
--------------------------------------------------------------------------------
/src/main/java/cn/halashuo/ObjectDetection.java:
--------------------------------------------------------------------------------
1 | package cn.halashuo;
2 |
3 | import ai.onnxruntime.OnnxTensor;
4 | import ai.onnxruntime.OrtEnvironment;
5 | import ai.onnxruntime.OrtException;
6 | import ai.onnxruntime.OrtSession;
7 | import cn.halashuo.domain.ODResult;
8 | import cn.halashuo.config.ODConfig;
9 | import cn.halashuo.utils.Letterbox;
10 | import org.opencv.core.Core;
11 | import org.opencv.core.Mat;
12 | import org.opencv.core.Point;
13 | import org.opencv.core.Scalar;
14 | import org.opencv.highgui.HighGui;
15 | import org.opencv.imgcodecs.Imgcodecs;
16 | import org.opencv.imgproc.Imgproc;
17 |
18 | import java.nio.FloatBuffer;
19 | import java.util.Arrays;
20 | import java.util.HashMap;
21 |
22 | public class ObjectDetection {
23 |
24 | static
25 | {
26 | //在使用OpenCV前必须加载Core.NATIVE_LIBRARY_NAME类,否则会报错
27 | System.loadLibrary(Core.NATIVE_LIBRARY_NAME);
28 | }
29 |
30 | public static void main(String[] args) throws OrtException {
31 |
32 | // 加载ONNX模型
33 | OrtEnvironment environment = OrtEnvironment.getEnvironment();
34 | OrtSession.SessionOptions sessionOptions = new OrtSession.SessionOptions();
35 | OrtSession session = environment.createSession(ODConfig.modelPath, sessionOptions);
36 | // 输出基本信息
37 | session.getInputInfo().keySet().forEach(x-> {
38 | try {
39 | System.out.println("input name = " + x);
40 | System.out.println(session.getInputInfo().get(x).getInfo().toString());
41 | } catch (OrtException e) {
42 | throw new RuntimeException(e);
43 | }
44 | });
45 |
46 | // 加载标签及颜色
47 | ODConfig odConfig = new ODConfig();
48 |
49 | // 读取 image
50 | Mat img = Imgcodecs.imread(ODConfig.picPath);
51 | Imgproc.cvtColor(img, img, Imgproc.COLOR_BGR2RGB);
52 | Mat image = img.clone();
53 |
54 | // 在这里先定义下框的粗细、字的大小、字的类型、字的颜色(按比例设置大小粗细比较好一些)
55 | int minDwDh = Math.min(img.width(), img.height());
56 | int thickness = minDwDh/ODConfig.lineThicknessRatio;
57 | double fontSize = minDwDh/ODConfig.fontSizeRatio;
58 | int fontFace = Imgproc.FONT_HERSHEY_SIMPLEX;
59 | Scalar fontColor = new Scalar(255, 255, 255);
60 |
61 | // 更改 image 尺寸
62 | Letterbox letterbox = new Letterbox();
63 | image = letterbox.letterbox(image);
64 | double ratio = letterbox.getRatio();
65 | double dw = letterbox.getDw();
66 | double dh = letterbox.getDh();
67 | int rows = letterbox.getHeight();
68 | int cols = letterbox.getWidth();
69 | int channels = image.channels();
70 |
71 | // 将Mat对象的像素值赋值给Float[]对象
72 | float[] pixels = new float[channels * rows * cols];
73 | for (int i = 0; i < rows; i++) {
74 | for (int j = 0; j < cols; j++) {
75 | double[] pixel = image.get(j,i);
76 | for (int k = 0; k < channels; k++) {
77 | // 这样设置相当于同时做了image.transpose((2, 0, 1))操作
78 | pixels[rows*cols*k+j*cols+i] = (float) pixel[k]/255.0f;
79 | }
80 | }
81 | }
82 |
83 | // 创建OnnxTensor对象
84 | long[] shape = { 1L, (long)channels, (long)rows, (long)cols };
85 | OnnxTensor tensor = OnnxTensor.createTensor(environment, FloatBuffer.wrap(pixels), shape);
86 | HashMap stringOnnxTensorHashMap = new HashMap<>();
87 | stringOnnxTensorHashMap.put(session.getInputInfo().keySet().iterator().next(), tensor);
88 |
89 | // 运行模型
90 | OrtSession.Result output = session.run(stringOnnxTensorHashMap);
91 |
92 | // 得到结果
93 | float[][] outputData = (float[][]) output.get(0).getValue();
94 | Arrays.stream(outputData).iterator().forEachRemaining(x->{
95 | ODResult odResult = new ODResult(x);
96 | System.out.println(odResult);
97 | // 画框
98 | Point topLeft = new Point((odResult.getX0()-dw)/ratio, (odResult.getY0()-dh)/ratio);
99 | Point bottomRight = new Point((odResult.getX1()-dw)/ratio, (odResult.getY1()-dh)/ratio);
100 | Scalar color = new Scalar(odConfig.getColor(odResult.getClsId()));
101 | Imgproc.rectangle(img, topLeft, bottomRight, color, thickness);
102 | // 框上写文字
103 | String boxName = odConfig.getName(odResult.getClsId()) + ": " + odResult.getScore();
104 | Point boxNameLoc = new Point((odResult.getX0()-dw)/ratio, (odResult.getY0()-dh)/ratio-3);
105 | Imgproc.putText(img, boxName, boxNameLoc, fontFace, fontSize, fontColor, thickness);
106 | });
107 |
108 | Imgproc.cvtColor(img, img, Imgproc.COLOR_RGB2BGR);
109 |
110 | // 保存图像
111 | Imgcodecs.imwrite(ODConfig.savePicPath, img);
112 | HighGui.imshow("Display Image", img);
113 | // 等待按下任意键继续执行程序
114 | HighGui.waitKey();
115 | }
116 | }
117 |
--------------------------------------------------------------------------------
/src/main/java/cn/halashuo/PoseEstimation.java:
--------------------------------------------------------------------------------
1 | package cn.halashuo;
2 |
3 | import ai.onnxruntime.OnnxTensor;
4 | import ai.onnxruntime.OrtEnvironment;
5 | import ai.onnxruntime.OrtException;
6 | import ai.onnxruntime.OrtSession;
7 | import cn.halashuo.domain.KeyPoint;
8 | import cn.halashuo.domain.PEResult;
9 | import cn.halashuo.utils.Letterbox;
10 | import cn.halashuo.utils.NMS;
11 | import cn.halashuo.config.PEConfig;
12 | import org.opencv.core.*;
13 | import org.opencv.highgui.HighGui;
14 | import org.opencv.imgcodecs.Imgcodecs;
15 | import org.opencv.imgproc.Imgproc;
16 |
17 | import java.nio.FloatBuffer;
18 | import java.util.ArrayList;
19 | import java.util.HashMap;
20 | import java.util.List;
21 |
22 | public class PoseEstimation {
23 | static
24 | {
25 | //在使用OpenCV前必须加载Core.NATIVE_LIBRARY_NAME类,否则会报错
26 | System.loadLibrary(Core.NATIVE_LIBRARY_NAME);
27 | }
28 |
29 | public static void main(String[] args) throws OrtException {
30 |
31 | // 加载ONNX模型
32 | OrtEnvironment environment = OrtEnvironment.getEnvironment();
33 | OrtSession.SessionOptions sessionOptions = new OrtSession.SessionOptions();
34 | OrtSession session = environment.createSession(PEConfig.modelPath, sessionOptions);
35 | // 输出基本信息
36 | session.getInputInfo().keySet().forEach(x -> {
37 | try {
38 | System.out.println("input name = " + x);
39 | System.out.println(session.getInputInfo().get(x).getInfo().toString());
40 | } catch (OrtException e) {
41 | throw new RuntimeException(e);
42 | }
43 | });
44 |
45 | // 读取 image
46 | Mat img = Imgcodecs.imread(PEConfig.picPath);
47 | Imgproc.cvtColor(img, img, Imgproc.COLOR_BGR2RGB);
48 | Mat image = img.clone();
49 |
50 | // 在这里先定义下线的粗细、关键的半径(按比例设置大小粗细比较好一些)
51 | int minDwDh = Math.min(img.width(), img.height());
52 | int thickness = minDwDh / PEConfig.lineThicknessRatio;
53 | int radius = minDwDh / PEConfig.dotRadiusRatio;
54 |
55 | // 更改 image 尺寸
56 | Letterbox letterbox = new Letterbox();
57 | letterbox.setNewShape(new Size(960, 960));
58 | letterbox.setStride(64);
59 | image = letterbox.letterbox(image);
60 | double ratio = letterbox.getRatio();
61 | double dw = letterbox.getDw();
62 | double dh = letterbox.getDh();
63 | int rows = letterbox.getHeight();
64 | int cols = letterbox.getWidth();
65 | int channels = image.channels();
66 |
67 | // 将Mat对象的像素值赋值给Float[]对象
68 | float[] pixels = new float[channels * rows * cols];
69 | for (int i = 0; i < rows; i++) {
70 | for (int j = 0; j < cols; j++) {
71 | double[] pixel = image.get(j, i);
72 | for (int k = 0; k < channels; k++) {
73 | // 这样设置相当于同时做了image.transpose((2, 0, 1))操作
74 | pixels[rows * cols * k + j * cols + i] = (float) pixel[k] / 255.0f;
75 | }
76 | }
77 | }
78 |
79 | // 创建OnnxTensor对象
80 | long[] shape = {1L, (long) channels, (long) rows, (long) cols};
81 | OnnxTensor tensor = OnnxTensor.createTensor(environment, FloatBuffer.wrap(pixels), shape);
82 | HashMap stringOnnxTensorHashMap = new HashMap<>();
83 | stringOnnxTensorHashMap.put(session.getInputInfo().keySet().iterator().next(), tensor);
84 |
85 | // 运行模型
86 | OrtSession.Result output = session.run(stringOnnxTensorHashMap);
87 |
88 | // 得到结果
89 | float[][] outputData = ((float[][][]) output.get(0).getValue())[0];
90 |
91 | List peResults = new ArrayList<>();
92 | for (float[] outputDatum : outputData) {
93 | PEResult result = new PEResult(outputDatum);
94 | if (result.getScore() > PEConfig.personScoreThreshold) {
95 | peResults.add(result);
96 | }
97 | }
98 |
99 | // 对结果进行非极大值抑制
100 | peResults = NMS.nms(peResults, PEConfig.IoUThreshold);
101 |
102 | for (PEResult peResult: peResults) {
103 | System.out.println(peResult);
104 | // 画框
105 | Point topLeft = new Point((peResult.getX0()-dw)/ratio, (peResult.getY0()-dh)/ratio);
106 | Point bottomRight = new Point((peResult.getX1()-dw)/ratio, (peResult.getY1()-dh)/ratio);
107 | Imgproc.rectangle(img, topLeft, bottomRight, new Scalar(255,0,0), thickness);
108 | List keyPoints = peResult.getKeyPointList();
109 | // 画点
110 | keyPoints.forEach(keyPoint->{
111 | if (keyPoint.getScore()>PEConfig.keyPointScoreThreshold) {
112 | Point center = new Point((keyPoint.getX()-dw)/ratio, (keyPoint.getY()-dh)/ratio);
113 | Scalar color = PEConfig.poseKptColor.get(keyPoint.getId());
114 | Imgproc.circle(img, center, radius, color, -1); //-1表示实心
115 | }
116 | });
117 | // 画线
118 | for (int i = 0; i< PEConfig.skeleton.length; i++){
119 | int indexPoint1 = PEConfig.skeleton[i][0]-1;
120 | int indexPoint2 = PEConfig.skeleton[i][1]-1;
121 | if ( keyPoints.get(indexPoint1).getScore()>PEConfig.keyPointScoreThreshold &&
122 | keyPoints.get(indexPoint2).getScore()>PEConfig.keyPointScoreThreshold ) {
123 | Scalar coler = PEConfig.poseLimbColor.get(i);
124 | Point point1 = new Point(
125 | (keyPoints.get(indexPoint1).getX()-dw)/ratio,
126 | (keyPoints.get(indexPoint1).getY()-dh)/ratio
127 | );
128 | Point point2 = new Point(
129 | (keyPoints.get(indexPoint2).getX()-dw)/ratio,
130 | (keyPoints.get(indexPoint2).getY()-dh)/ratio
131 | );
132 | Imgproc.line(img, point1, point2, coler, thickness);
133 | }
134 | }
135 | }
136 | Imgproc.cvtColor(img, img, Imgproc.COLOR_RGB2BGR);
137 | // 保存图像
138 | Imgcodecs.imwrite(PEConfig.savePicPath, img);
139 | HighGui.imshow("Display Image", img);
140 | // 等待按下任意键继续执行程序
141 | HighGui.waitKey();
142 |
143 | }
144 | }
--------------------------------------------------------------------------------
/src/main/java/cn/halashuo/config/ODConfig.java:
--------------------------------------------------------------------------------
1 | package cn.halashuo.config;
2 |
3 | import java.util.Random;
4 | import java.util.ArrayList;
5 | import java.util.Arrays;
6 | import java.util.List;
7 | import java.util.Map;
8 | import java.util.HashMap;
9 |
public final class ODConfig {

    /** ONNX model path. Forward slashes work on both Windows and Unix-like systems. */
    public static final String modelPath = "src/main/resources/model/yolov7-d6.onnx";
    /** Image to run detection on. */
    public static final String picPath = "images/test.jpg";
    /** Where the annotated result image is saved. */
    public static final String savePicPath = "images/OD-test.jpg";
    /** Box line thickness is the shorter image side divided by this ratio. */
    public static final Integer lineThicknessRatio = 333;
    /** Label font scale is the shorter image side divided by this ratio. */
    public static final Double fontSizeRatio = 1145.14;

    /** The 80 class names; the list index is the model's class id. */
    private static final List<String> names = new ArrayList<>(Arrays.asList(
            "person", "bicycle", "car", "motorcycle", "airplane", "bus", "train",
            "truck", "boat", "traffic light", "fire hydrant", "stop sign", "parking meter",
            "bench", "bird", "cat", "dog", "horse", "sheep", "cow", "elephant", "bear",
            "zebra", "giraffe", "backpack", "umbrella", "handbag", "tie", "suitcase",
            "frisbee", "skis", "snowboard", "sports ball", "kite", "baseball bat",
            "baseball glove", "skateboard", "surfboard", "tennis racket", "bottle",
            "wine glass", "cup", "fork", "knife", "spoon", "bowl", "banana", "apple",
            "sandwich", "orange", "broccoli", "carrot", "hot dog", "pizza", "donut",
            "cake", "chair", "couch", "potted plant", "bed", "dining table", "toilet",
            "tv", "laptop", "mouse", "remote", "keyboard", "cell phone", "microwave",
            "oven", "toaster", "sink", "refrigerator", "book", "clock", "vase", "scissors",
            "teddy bear", "hair drier", "toothbrush"));

    /** Drawing colour per class name; each component is in [0, 256). */
    private final Map<String, double[]> colors;

    /** Creates a configuration with a freshly generated random colour for every class. */
    public ODConfig() {
        this.colors = new HashMap<>();
        Random random = new Random(); // a single generator for all classes
        for (String name : names) {
            double[] color = {random.nextDouble() * 256, random.nextDouble() * 256, random.nextDouble() * 256};
            colors.put(name, color);
        }
    }

    /**
     * Returns the class name for a model class id.
     *
     * @param clsId class id, in {@code [0, 80)}
     * @return the class name
     * @throws IndexOutOfBoundsException if {@code clsId} is not a known class id
     */
    public String getName(int clsId) {
        if (clsId < 0 || clsId >= names.size()) {
            throw new IndexOutOfBoundsException(
                    "Unknown class id " + clsId + " (expected 0.." + (names.size() - 1) + ")");
        }
        return names.get(clsId);
    }

    /**
     * Returns the drawing colour of a class. The same array instance is returned on every call.
     *
     * @param clsId class id, in {@code [0, 80)}
     * @return three colour components
     * @throws IndexOutOfBoundsException if {@code clsId} is not a known class id
     */
    public double[] getColor(int clsId) {
        return colors.get(getName(clsId));
    }
}
--------------------------------------------------------------------------------
/src/main/java/cn/halashuo/config/PEConfig.java:
--------------------------------------------------------------------------------
1 | package cn.halashuo.config;
2 |
3 | import org.opencv.core.Scalar;
4 |
5 | import java.util.ArrayList;
6 | import java.util.Arrays;
7 | import java.util.List;
8 |
9 | public final class PEConfig {
10 |
11 | public static final String modelPath = "src\\main\\resources\\model\\yolov7-w6-pose.onnx";
12 | public static final String picPath = "images\\test.jpg";// 要预测图片位置
13 | public static final String savePicPath = "images\\PE-test.jpg";// 预测结果保存位置
14 | public static final float IoUThreshold = 0.65f;
15 | public static final float personScoreThreshold = 0.25f;
16 | public static final float keyPointScoreThreshold = 0.45f;
17 | // 根据图像大小按比例控制点大小及线粗
18 | public static final Integer dotRadiusRatio = 168;
19 | public static final Integer lineThicknessRatio = 333;
20 |
21 |
22 | public static final List palette= new ArrayList<>(Arrays.asList(
23 | new Scalar( 255, 128, 0 ),
24 | new Scalar( 255, 153, 51 ),
25 | new Scalar( 255, 178, 102 ),
26 | new Scalar( 230, 230, 0 ),
27 | new Scalar( 255, 153, 255 ),
28 | new Scalar( 153, 204, 255 ),
29 | new Scalar( 255, 102, 255 ),
30 | new Scalar( 255, 51, 255 ),
31 | new Scalar( 102, 178, 255 ),
32 | new Scalar( 51, 153, 255 ),
33 | new Scalar( 255, 153, 153 ),
34 | new Scalar( 255, 102, 102 ),
35 | new Scalar( 255, 51, 51 ),
36 | new Scalar( 153, 255, 153 ),
37 | new Scalar( 102, 255, 102 ),
38 | new Scalar( 51, 255, 51 ),
39 | new Scalar( 0, 255, 0 ),
40 | new Scalar( 0, 0, 255 ),
41 | new Scalar( 255, 0, 0 ),
42 | new Scalar( 255, 255, 255 )
43 | ));
44 |
45 | public static final int[][] skeleton = {
46 | {16, 14}, {14, 12}, {17, 15}, {15, 13}, {12, 13}, {6, 12},
47 | {7, 13}, {6, 7}, {6, 8}, {7, 9}, {8, 10}, {9, 11}, {2, 3},
48 | {1, 2}, {1, 3}, {2, 4}, {3, 5}, {4, 6}, {5, 7}
49 | };
50 |
51 | public static final List poseLimbColor = new ArrayList<>(Arrays.asList(
52 | palette.get(9), palette.get(9), palette.get(9), palette.get(9), palette.get(7),
53 | palette.get(7), palette.get(7), palette.get(0), palette.get(0), palette.get(0),
54 | palette.get(0), palette.get(0), palette.get(16), palette.get(16), palette.get(16),
55 | palette.get(16), palette.get(16), palette.get(16), palette.get(16)));
56 |
57 | public static final List poseKptColor = new ArrayList<>(Arrays.asList(
58 | palette.get(16), palette.get(16), palette.get(16), palette.get(16), palette.get(16),
59 | palette.get(0), palette.get(0), palette.get(0), palette.get(0), palette.get(0),
60 | palette.get(0), palette.get(9), palette.get(9), palette.get(9), palette.get(9),
61 | palette.get(9), palette.get(9)));
62 |
63 | }
--------------------------------------------------------------------------------
/src/main/java/cn/halashuo/domain/KeyPoint.java:
--------------------------------------------------------------------------------
1 | package cn.halashuo.domain;
2 |
/**
 * Immutable pose keypoint: an index, a position and a confidence score.
 * PoseEstimation maps the position back to the original image using the letterbox offset
 * and ratio, so the coordinates are presumably in model-input space.
 */
public class KeyPoint {
    /** Keypoint index; {@link #toString()} reports it one-based ({@code id + 1}). */
    private final Integer id;
    /** Horizontal coordinate. */
    private final Float x;
    /** Vertical coordinate. */
    private final Float y;
    /** Confidence of this keypoint. */
    private final Float score;

    public KeyPoint(Integer id, Float x, Float y, Float score) {
        this.id = id;
        this.x = x;
        this.y = y;
        this.score = score;
    }

    /** @return the keypoint index */
    public Integer getId() {
        return this.id;
    }

    /** @return the horizontal coordinate */
    public Float getX() {
        return this.x;
    }

    /** @return the vertical coordinate */
    public Float getY() {
        return this.y;
    }

    /** @return the confidence score */
    public Float getScore() {
        return this.score;
    }

    /** One line per keypoint, with a one-based index, position and confidence. */
    @Override
    public String toString() {
        StringBuilder line = new StringBuilder();
        line.append(" 第 ").append(id + 1).append(" 个关键点: ");
        line.append(" x=").append(x);
        line.append(" y=").append(y);
        line.append(" c=").append(score);
        line.append("\n");
        return line.toString();
    }
}
41 |
--------------------------------------------------------------------------------
/src/main/java/cn/halashuo/domain/ODResult.java:
--------------------------------------------------------------------------------
1 | package cn.halashuo.domain;
2 |
3 | import java.text.DecimalFormat;
4 |
/**
 * One object-detection row decoded from the model output.
 * The row layout is: batchId, x0, y0, x1, y1, clsId, score.
 */
public class ODResult {
    // Column positions inside a detection row
    private static final int BATCH_COL = 0;
    private static final int X0_COL = 1;
    private static final int Y0_COL = 2;
    private static final int X1_COL = 3;
    private static final int Y1_COL = 4;
    private static final int CLS_COL = 5;
    private static final int SCORE_COL = 6;

    private final Integer batchId;
    private final Float x0;
    private final Float y0;
    private final Float x1;
    private final Float y1;
    private final Integer clsId;
    private final Float score;

    /**
     * Decodes one detection row.
     *
     * @param x a row of at least 7 floats; batchId and clsId are truncated to int
     */
    public ODResult(float[] x) {
        batchId = (int) x[BATCH_COL];
        x0 = x[X0_COL];
        y0 = x[Y0_COL];
        x1 = x[X1_COL];
        y1 = x[Y1_COL];
        clsId = (int) x[CLS_COL];
        score = x[SCORE_COL];
    }

    public Integer getBatchId() {
        return batchId;
    }

    public Float getX0() {
        return x0;
    }

    public Float getY0() {
        return y0;
    }

    public Float getX1() {
        return x1;
    }

    public Float getY1() {
        return y1;
    }

    public Integer getClsId() {
        return clsId;
    }

    /** @return the score as a percentage string with two decimals (pattern {@code 0.00%}) */
    public String getScore() {
        return new DecimalFormat("0.00%").format(score);
    }

    @Override
    public String toString() {
        return new StringBuilder("物体: ")
                .append(" \t batchId=").append(batchId)
                .append(" \t x0=").append(x0)
                .append(" \t y0=").append(y0)
                .append(" \t x1=").append(x1)
                .append(" \t y1=").append(y1)
                .append(" \t clsId=").append(clsId)
                .append(" \t score=").append(getScore())
                .append(" \t ;")
                .toString();
    }
}
--------------------------------------------------------------------------------
/src/main/java/cn/halashuo/domain/PEResult.java:
--------------------------------------------------------------------------------
1 | package cn.halashuo.domain;
2 |
3 | import java.util.ArrayList;
4 | import java.util.List;
5 |
6 | public class PEResult {
7 |
8 | private final Float x0;
9 | private final Float y0;
10 | private final Float x1;
11 | private final Float y1;
12 | private final Float score;
13 | private final Integer clsId;
14 | private final List keyPointList;
15 |
16 | public PEResult(float[] peResult) {
17 | float x = peResult[0];
18 | float y = peResult[1];
19 | float w = peResult[2]/2.0f;
20 | float h = peResult[3]/2.0f;
21 | this.x0 = x-w;
22 | this.y0 = y-h;
23 | this.x1 = x+w;
24 | this.y1 = y+h;
25 | this.score = peResult[4];
26 | this.clsId = (int) peResult[5];
27 | this.keyPointList = new ArrayList<>();
28 | int keyPointNum = (peResult.length-6)/3;
29 | for (int i=0;i getKeyPointList() {
59 | return keyPointList;
60 | }
61 |
62 | @Override
63 | public String toString() {
64 | StringBuilder result = new StringBuilder("PEResult:" +
65 | " x0=" + x0 +
66 | ", y0=" + y0 +
67 | ", x1=" + x1 +
68 | ", y1=" + y1 +
69 | ", score=" + score +
70 | ", clsId=" + clsId +
71 | "\n");
72 | for (KeyPoint x : keyPointList) {
73 | result.append(x.toString());
74 | }
75 | return result.toString();
76 | }
77 | }
78 |
79 |
--------------------------------------------------------------------------------
/src/main/java/cn/halashuo/utils/Letterbox.java:
--------------------------------------------------------------------------------
1 | package cn.halashuo.utils;
2 |
3 | import org.opencv.core.Core;
4 | import org.opencv.core.Mat;
5 | import org.opencv.core.Size;
6 | import org.opencv.imgproc.Imgproc;
7 |
8 | public class Letterbox {
9 |
10 | private Size newShape = new Size(1280, 1280);
11 | private final double[] color = new double[]{114,114,114};
12 | private final Boolean auto = false;
13 | private final Boolean scaleUp = true;
14 | private Integer stride = 32;
15 |
16 | private double ratio;
17 | private double dw;
18 | private double dh;
19 |
20 | public double getRatio() {
21 | return ratio;
22 | }
23 |
24 | public double getDw() {
25 | return dw;
26 | }
27 |
28 | public Integer getWidth() {
29 | return (int) this.newShape.width;
30 | }
31 |
32 | public Integer getHeight() {
33 | return (int) this.newShape.height;
34 | }
35 |
36 | public double getDh() {
37 | return dh;
38 | }
39 |
40 | public void setNewShape(Size newShape) {
41 | this.newShape = newShape;
42 | }
43 |
44 | public void setStride(Integer stride) {
45 | this.stride = stride;
46 | }
47 |
48 | public Mat letterbox(Mat im) { // 调整图像大小和填充图像,使满足步长约束,并记录参数
49 |
50 | int[] shape = {im.rows(), im.cols()}; // 当前形状 [height, width]
51 | // Scale ratio (new / old)
52 | double r = Math.min(this.newShape.height / shape[0], this.newShape.width / shape[1]);
53 | if (!this.scaleUp) { // 仅缩小,不扩大(一且为了mAP)
54 | r = Math.min(r, 1.0);
55 | }
56 | // Compute padding
57 | Size newUnpad = new Size(Math.round(shape[1] * r), Math.round(shape[0] * r));
58 | double dw = this.newShape.width - newUnpad.width, dh = this.newShape.height - newUnpad.height; // wh 填充
59 | if (this.auto) { // 最小矩形
60 | dw = dw % this.stride;
61 | dh = dh % this.stride;
62 | }
63 | dw /= 2; // 填充的时候两边都填充一半,使图像居于中心
64 | dh /= 2;
65 | if (shape[1] != newUnpad.width || shape[0] != newUnpad.height) { // resize
66 | Imgproc.resize(im, im, newUnpad, 0, 0, Imgproc.INTER_LINEAR);
67 | }
68 | int top = (int) Math.round(dh - 0.1), bottom = (int) Math.round(dh + 0.1);
69 | int left = (int) Math.round(dw - 0.1), right = (int) Math.round(dw + 0.1);
70 | // 将图像填充为正方形
71 | Core.copyMakeBorder(im, im, top, bottom, left, right, Core.BORDER_CONSTANT, new org.opencv.core.Scalar(this.color));
72 | this.ratio = r;
73 | this.dh = dh;
74 | this.dw = dw;
75 | return im;
76 | }
77 | }
--------------------------------------------------------------------------------
/src/main/java/cn/halashuo/utils/NMS.java:
--------------------------------------------------------------------------------
1 | package cn.halashuo.utils;
2 |
3 | import cn.halashuo.domain.PEResult;
4 |
5 | import java.util.ArrayList;
6 | import java.util.List;
7 |
8 | public class NMS {
9 |
10 | public static List nms(List boxes, float iouThreshold) {
11 | // 根据score从大到小对List进行排序
12 | boxes.sort((b1, b2) -> Float.compare(b2.getScore(), b1.getScore()));
13 | List resultList = new ArrayList<>();
14 | for (int i = 0; i < boxes.size(); i++) {
15 | PEResult box = boxes.get(i);
16 | boolean keep = true;
17 | // 从i+1开始,遍历之后的所有boxes,移除与box的IOU大于阈值的元素
18 | for (int j = i + 1; j < boxes.size(); j++) {
19 | PEResult otherBox = boxes.get(j);
20 | float iou = getIntersectionOverUnion(box, otherBox);
21 | if (iou > iouThreshold) {
22 | keep = false;
23 | break;
24 | }
25 | }
26 | if (keep) {
27 | resultList.add(box);
28 | }
29 | }
30 | return resultList;
31 | }
32 | private static float getIntersectionOverUnion(PEResult box1, PEResult box2) {
33 | float x1 = Math.max(box1.getX0(), box2.getX0());
34 | float y1 = Math.max(box1.getY0(), box2.getY0());
35 | float x2 = Math.min(box1.getX1(), box2.getX1());
36 | float y2 = Math.min(box1.getY1(), box2.getY1());
37 | float intersectionArea = Math.max(0, x2 - x1) * Math.max(0, y2 - y1);
38 | float box1Area = (box1.getX1() - box1.getX0()) * (box1.getY1() - box1.getY0());
39 | float box2Area = (box2.getX1() - box2.getX0()) * (box2.getY1() - box2.getY0());
40 | float unionArea = box1Area + box2Area - intersectionArea;
41 | return intersectionArea / unionArea;
42 | }
43 | }
--------------------------------------------------------------------------------
/src/main/resources/lib/opencv-460.jar:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PanBohao/yolov7-java-onnx/1dea6bb08f8600d1ba41b7229095123f4e36a48e/src/main/resources/lib/opencv-460.jar
--------------------------------------------------------------------------------