├── README.md
├── pom.xml
└── src
└── main
├── java
└── sky
│ └── tf
│ ├── GarbageClassificationModel.java
│ ├── IModel.java
│ ├── ImageClassIndex.java
│ ├── ImageClassificationMain.java
│ ├── ImageDataPreprocessing.java
│ ├── ModelParams.java
│ ├── PredictionResult.java
│ ├── TurboJpegLoader.java
│ ├── multiplemodels
│ ├── ImageFlatMapMultipleModels.java
│ ├── ImageModelLoaderMultipleModels.java
│ ├── ImagePredictSupportiveMultipleModels.java
│ └── OpenVinoModelGeneratorMultipleModels.java
│ └── threemodels
│ ├── ImageFlatMap3Models.java
│ ├── ImageModelLoader3Models.java
│ ├── ImagePredictSupportive3Models.java
│ ├── OpenVinoModelGenerator3Models.java
│ └── ValidationDatasetAnalyzer.java
└── resources
└── log4j.properties
/README.md:
--------------------------------------------------------------------------------
1 | # Apache Flink极客挑战赛——垃圾图片分类—复赛Java code
2 |
3 | # 1. How to build
4 |
5 | #### 1.1 安装 garbage_image_util 包到本地maven仓库
6 | ```bash
7 | sh install_jar.sh
8 | ```
9 | 注:garbage_image_util 由天池提供,详见天池比赛的readme
10 |
11 |
12 | #### 1.2 build package (jar)
13 | ```bash
14 | mvn clean package
15 | ```
16 |
17 |
18 | # 2.许可声明
19 | 你可以使用此代码用于学习和研究,但务必不要将此代码用于任何商业用途和比赛项目。
20 |
--------------------------------------------------------------------------------
/pom.xml:
--------------------------------------------------------------------------------
1 |
19 |
21 | 4.0.0
22 |
23 | org.skypeace
24 | sky-flink-phase2
25 | 0.1
26 | jar
27 |
28 | Flink Quickstart Job
29 | https://tianchi.aliyun.com/
30 |
31 |
32 | UTF-8
33 | 1.8.1
34 | 1.8
35 | 2.11
36 | 2.7.7
37 | ${java.version}
38 | ${java.version}
39 |
40 |
41 |
42 |
43 | apache.snapshots
44 | Apache Development Snapshot Repository
45 | https://repository.apache.org/content/repositories/snapshots/
46 |
47 | false
48 |
49 |
50 | true
51 |
52 |
53 |
54 | ossrh
55 | sonatype repository
56 | https://oss.sonatype.org/content/groups/public/
57 |
58 | true
59 |
60 |
61 | true
62 |
63 |
64 |
65 | central
66 | https://repo1.maven.org/maven2/
67 |
68 | true
69 |
70 |
71 | true
72 |
73 |
74 |
75 | ome-releases
76 | https://artifacts.openmicroscopy.org/artifactory/ome.releases/
77 |
78 | true
79 |
80 |
81 | true
82 |
83 |
84 |
85 |
86 |
87 |
88 |
89 | com.alibaba.tianchi
90 | garbage_image_util
91 | 1.0-SNAPSHOT
92 |
93 |
94 | xml-apis
95 | xml-apis
96 |
97 |
98 |
99 |
100 | ome
101 | turbojpeg
102 | 6.2.1
103 |
104 |
105 |
106 |
107 | org.apache.flink
108 | flink-java
109 | ${flink.version}
110 | provided
111 |
112 |
113 | org.apache.flink
114 | flink-streaming-java_${scala.binary.version}
115 | ${flink.version}
116 | provided
117 |
118 |
119 | org.apache.flink
120 | flink-core
121 | ${flink.version}
122 | provided
123 |
124 |
125 | org.scala-lang
126 | scala-library
127 | 2.11.11
128 | provided
129 |
130 |
131 |
132 | com.intel.analytics.zoo
133 | analytics-zoo-bigdl_0.9.0-spark_2.4.3
134 | 0.6.0-SNAPSHOT
135 | jar
136 |
137 |
138 | com.intel.analytics.zoo
139 | zoo-core-openvino-java-linux
140 | 0.6.0-SNAPSHOT
141 |
142 |
143 | com.intel.analytics.zoo
144 | zoo-core-mkl-linux
145 | 0.6.0-SNAPSHOT
146 |
147 |
148 | com.intel.analytics.zoo
149 | zoo-core-pmem-java-linux
150 | 0.6.0-SNAPSHOT
151 |
152 |
153 |
154 |
155 | org.apache.spark
156 | spark-core_2.11
157 | 2.4.3
158 | compile
159 |
160 |
161 | org.apache.spark
162 | spark-network-common_2.11
163 | 2.4.3
164 | runtime
165 |
166 |
167 | org.apache.spark
168 | spark-network-shuffle_2.11
169 | 2.4.3
170 | runtime
171 |
172 |
173 | org.apache.spark
174 | spark-mllib_2.11
175 | 2.4.3
176 | runtime
177 |
178 |
179 |
180 | org.slf4j
181 | slf4j-log4j12
182 | 1.7.7
183 | runtime
184 |
185 |
186 | log4j
187 | log4j
188 | 1.2.17
189 | runtime
190 |
191 |
192 |
193 | org.json4s
194 | json4s-core_2.11
195 | 3.5.3
196 | jar
197 | compile
198 |
199 |
200 |
201 | org.json4s
202 | json4s-jackson_2.11
203 | 3.5.3
204 | jar
205 | compile
206 |
207 |
208 |
209 | com.google.protobuf
210 | protobuf-java
211 | 3.9.1
212 | jar
213 | compile
214 |
215 |
216 |
217 | org.apache.hadoop
218 | hadoop-common
219 | ${hadoop.version}
220 |
221 |
222 | org.apache.hadoop
223 | hadoop-hdfs
224 | ${hadoop.version}
225 |
226 |
227 | org.apache.hadoop
228 | hadoop-client
229 | ${hadoop.version}
230 |
231 |
232 |
233 |
234 |
235 |
236 |
237 |
238 |
239 | org.apache.maven.plugins
240 | maven-compiler-plugin
241 | 3.1
242 |
243 | ${java.version}
244 | ${java.version}
245 |
246 |
247 |
248 |
249 |
250 |
251 | org.apache.maven.plugins
252 | maven-shade-plugin
253 | 3.0.0
254 |
255 |
256 |
257 | package
258 |
259 | shade
260 |
261 |
262 |
263 |
264 | org.apache.flink:*
265 | org.slf4j:*
266 | log4j:*
267 | com.google.code.findbugs:jsr305
268 |
269 |
270 |
271 |
272 |
274 | *:*
275 |
276 | META-INF/*.SF
277 | META-INF/*.DSA
278 | META-INF/*.RSA
279 |
280 |
281 |
282 |
283 |
284 | sky.tf.ImageClassificationMain
285 |
286 |
287 |
288 |
289 |
290 |
291 |
292 |
293 |
294 |
295 |
296 |
297 |
298 | org.eclipse.m2e
299 | lifecycle-mapping
300 | 1.0.0
301 |
302 |
303 |
304 |
305 |
306 | org.apache.maven.plugins
307 | maven-shade-plugin
308 | [3.0.0,)
309 |
310 | shade
311 |
312 |
313 |
314 |
315 |
316 |
317 |
318 |
319 | org.apache.maven.plugins
320 | maven-compiler-plugin
321 | [3.1,)
322 |
323 | testCompile
324 | compile
325 |
326 |
327 |
328 |
329 |
330 |
331 |
332 |
333 |
334 |
335 |
336 |
337 |
338 |
339 |
340 |
341 |
342 |
343 |
344 | add-dependencies-for-IDEA
345 |
346 |
347 |
348 | idea.version
349 |
350 |
351 |
352 |
353 |
354 | org.apache.flink
355 | flink-core
356 | ${flink.version}
357 | compile
358 |
359 |
360 | org.apache.flink
361 | flink-java
362 | ${flink.version}
363 | compile
364 |
365 |
366 | org.apache.flink
367 | flink-streaming-java_${scala.binary.version}
368 | ${flink.version}
369 | compile
370 |
371 |
372 | com.intel.analytics.zoo
373 | zoo-core-tfnet-linux
374 | 0.6.0-SNAPSHOT
375 | jar
376 |
377 |
378 |
379 |
380 |
381 |
382 |
383 |
--------------------------------------------------------------------------------
/src/main/java/sky/tf/GarbageClassificationModel.java:
--------------------------------------------------------------------------------
1 | package sky.tf;
2 |
3 | import com.intel.analytics.zoo.pipeline.inference.*;
4 |
5 | import java.util.List;
6 | import java.util.concurrent.LinkedBlockingQueue;
7 |
8 | /**
9 | * @author SkyPeace
10 | * The models for predict. For this case, use Zoo OpenVINOModel directly instead of InferenceModel.
11 | */
12 | public class GarbageClassificationModel implements IModel, java.io.Serializable
13 | {
14 | // Support Zoo InferenceModel and Zoo OpenVINOModel directly
15 | // For this case, the final decision is use Zoo OpenVINOModel directly instead of Zoo InferenceModel.
16 | private OpenVINOModel openVINOModel; // primary engine; null when the experimental InferenceModel constructor was used
17 | private LinkedBlockingQueue referenceQueue = new LinkedBlockingQueue(12); // crude reference counter: one queued token per holder; capacity 12 bounds concurrent holders
18 |
19 | public GarbageClassificationModel(byte[] modelXml, byte[] modelBin)
20 | {
21 | this.openVINOModel = this.loadOpenVINOModel(modelXml, modelBin);
22 | }
23 |
24 | private OpenVINOModel loadOpenVINOModel(byte[] modelXml, byte[] modelBin){
25 | long beginTime = System.currentTimeMillis();
26 | System.out.println(String.format("LOAD_OPENVION_MODEL_FROM_BYTES BEGIN %s", beginTime)); // NOTE(review): "OPENVION" typo in log tag — kept, downstream log parsing may rely on it
27 | OpenVINOModel model = OpenVinoInferenceSupportive$.MODULE$.loadOpenVinoIR(modelXml, modelBin,
28 | com.intel.analytics.zoo.pipeline.inference.DeviceType.CPU(), 0);
29 | //OpenVinoInferenceSupportive.loadOpenVinoIRFromTempDir("modelName", "tempDir");
30 | long endTime = System.currentTimeMillis();
31 | System.out.println(String.format("LOAD_OPENVION_MODEL_FROM_BYTES END %s (Cost: %s)",
32 | endTime, (endTime - beginTime)));
33 | return model;
34 | }
35 |
36 | /**
37 | * Predict by inputs. Prefers the OpenVINO engine; falls back to the experimental InferenceModel.
38 | * @param inputs batched tensors (generics stripped in this dump — likely List<List<JTensor>>; TODO confirm)
39 | * @return per-input prediction tensors from the underlying engine
40 | * @throws Exception when neither engine was initialized or prediction fails
41 | */
42 | public List> predict(List> inputs) throws Exception {
43 | if(openVINOModel!=null)
44 | return openVINOModel.predict(inputs);
45 | else if(inferenceModel!=null)
46 | return inferenceModel.doPredict(inputs);
47 | else
48 | throw new RuntimeException("inferenceModel and openVINOModel are both null.");
49 | }
50 |
51 | public void addRefernce() throws InterruptedException // NOTE: misspelling kept for caller compatibility; registers one model holder
52 | {
53 | referenceQueue.put(1); // blocks if more than 12 holders register (queue capacity)
54 | }
55 |
56 | /**
57 | * Release model. Each call drops one holder; the native model is freed only when the last holder releases.
58 | * @throws Exception when the underlying engine release fails
59 | */
60 | public synchronized void release() throws Exception
61 | {
62 | if(referenceQueue.peek()==null) // no registered holders: nothing to release
63 | return;
64 | referenceQueue.poll(); // drop one holder token
65 | if(referenceQueue.peek()==null) // last holder gone -> actually free the native model
66 | {
67 | if (openVINOModel != null) {
68 | System.out.println("Release openVINOModel ...");
69 | openVINOModel.release();
70 | System.out.println("openVINOModel released");
71 | }else if (inferenceModel != null) {
72 | System.out.println("Release inferenceModel ...");
73 | inferenceModel.doRelease();
74 | System.out.println("inferenceModel released");
75 | }
76 | else
77 | throw new RuntimeException("openVINOModel and inferenceModel are both null.");
78 | }
79 | }
80 |
81 |
82 | //Below code is reserved for test. Do not delete them.
83 | private InferenceModel inferenceModel; // experimental TF-backed engine; only set by the (bytes, ModelParams) constructor
84 | /**
85 | * Reserved for experimental test
86 | * @param savedModelBytes TF saved-model bytes loaded into a Zoo InferenceModel
87 | */
88 | public GarbageClassificationModel(byte[] savedModelBytes, ModelParams modelParams) {
89 | this.inferenceModel = this.loadInferenceModel(savedModelBytes, modelParams);
90 | }
91 |
92 | /**
93 | * Reserved for experimental test
94 | * @param savedModelBytes TF saved-model bytes
95 | * @return a loaded Zoo InferenceModel configured from modelParams (shape, means, scale, input name)
96 | */
97 | private InferenceModel loadInferenceModel(byte[] savedModelBytes, ModelParams modelParams)
98 | {
99 | long beginTime = System.currentTimeMillis();
100 | System.out.println(String.format("loadInferenceModel BEGIN %s", beginTime));
101 | InferenceModel model = new InferenceModel(1);
102 | model.doLoadTF(savedModelBytes, modelParams.getInputShape(), false,
103 | modelParams.getMeanValues(), modelParams.getScale(), modelParams.getInputName());
104 | long endTime = System.currentTimeMillis();
105 | System.out.println(String.format("loadInferenceModel END %s (Cost: %s)",
106 | endTime, (endTime - beginTime)));
107 | return model;
108 | }
109 | }
110 |
--------------------------------------------------------------------------------
/src/main/java/sky/tf/IModel.java:
--------------------------------------------------------------------------------
1 | package sky.tf;
2 |
3 | import com.intel.analytics.zoo.pipeline.inference.JTensor;
4 |
5 | import java.util.List;
6 |
7 | public interface IModel { // minimal prediction-engine abstraction (implemented by GarbageClassificationModel in this package)
8 | List> predict(List> inputs) throws Exception; // batch prediction; generics stripped in this dump (likely List<List<JTensor>>) — TODO confirm
9 | void release() throws Exception; // release the underlying model resources
10 | }
11 |
--------------------------------------------------------------------------------
/src/main/java/sky/tf/ImageClassIndex.java:
--------------------------------------------------------------------------------
1 | package sky.tf;
2 |
3 | import java.io.*;
4 | import java.util.HashMap;
5 | import java.util.Map;
6 |
7 | /**
8 | * @author SkyPeace
9 | * Image class index
10 | */
/**
 * @author SkyPeace
 * Image class index: maps a class id (String) to its human-readable label.
 * Singleton; the map is lazily loaded once from {modelInferencePath}/labels.txt
 * (one "id:label" entry per line).
 */
public class ImageClassIndex {

    // class id -> human-readable label; null until loadClassIndexMap() succeeds
    private Map<String, String> mapClassIndex;
    private static ImageClassIndex instance = new ImageClassIndex();

    private ImageClassIndex()
    {
    }

    /**
     * Returns the singleton instance.
     * NOTE: the misspelled name ("Insatnce") is kept for caller compatibility.
     */
    public static ImageClassIndex getInsatnce()
    {
        return instance;
    }

    /**
     * Load the class-index map from {modelInferencePath}/labels.txt.
     * Loads at most once; subsequent calls are no-ops. Synchronized so
     * concurrent first callers do not load twice.
     * @param modelInferencePath directory containing labels.txt
     */
    public synchronized void loadClassIndexMap(String modelInferencePath)
    {
        String labelFilePath = modelInferencePath + File.separator + "labels.txt";
        if(this.mapClassIndex == null) {
            long beginTime = System.currentTimeMillis();
            System.out.println(String.format("READ_CLASS_INDEX BEGIN %s", beginTime));
            this.mapClassIndex = this.read(labelFilePath);
            long endTime = System.currentTimeMillis();
            System.out.println(String.format("READ_CLASS_INDEX END %s (Cost: %s)", endTime, (endTime - beginTime)));
        }
    }

    /**
     * Look up the human-readable label for a class id.
     * @param id class index id (left side of ':' in labels.txt)
     * @return the label, or null when unknown or when the map was never loaded
     *         (previously this threw NullPointerException before loading)
     */
    public String getImageHumanstring(String id)
    {
        Map<String, String> map = this.mapClassIndex;
        return (map == null) ? null : map.get(id);
    }

    /**
     * Read the label file (one "id:label" entry per line).
     * Fixes vs. original: reader closed via try-with-resources (the original
     * leaked the stream when an exception occurred mid-read), explicit UTF-8
     * decoding instead of the platform default, malformed lines (no ':') are
     * skipped instead of throwing ArrayIndexOutOfBoundsException, and the
     * IOException cause is preserved when rethrowing.
     * @param filePath path of labels.txt
     * @return id -> label map
     * @throws RuntimeException when the file cannot be read
     */
    private Map<String, String> read(String filePath)
    {
        Map<String, String> map = new HashMap<String, String>();
        try (BufferedReader br = new BufferedReader(
                new InputStreamReader(new FileInputStream(filePath), "UTF-8"))) {
            String line;
            while ((line = br.readLine()) != null) {
                String[] columns = line.split(":");
                if (columns.length >= 2)
                    map.put(columns[0], columns[1]);
            }
        } catch (IOException ex) {
            String errMsg = "Read class index file FAILED. " + ex.getMessage();
            System.out.println("ERROR: " + errMsg);
            throw new RuntimeException(errMsg, ex);
        }
        return map;
    }

}
69 |
--------------------------------------------------------------------------------
/src/main/java/sky/tf/ImageClassificationMain.java:
--------------------------------------------------------------------------------
1 | package sky.tf;
2 | import com.alibaba.tianchi.garbage_image_util.ImageClassSink;
3 | import com.alibaba.tianchi.garbage_image_util.ImageDirSource;
4 | import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
5 | import sky.tf.threemodels.ImageFlatMap3Models;
6 | import sky.tf.threemodels.OpenVinoModelGenerator3Models;
7 | /**
8 | * The main class of Image inference for garbage classification
9 | * @author SkyPeace
10 | */
11 | public class ImageClassificationMain {
12 | public static void main(String[] args) throws Exception {
13 | StreamExecutionEnvironment flinkEnv = StreamExecutionEnvironment.getExecutionEnvironment();
14 |
15 | //Step 1. Generate optimized OpenVino models (Three type of models) for later prediction.
16 | OpenVinoModelGenerator3Models modelGenerator = new OpenVinoModelGenerator3Models();
17 | modelGenerator.execute();
18 | ImageFlatMap3Models imageFlatMap = new ImageFlatMap3Models(modelGenerator.getModel1Params(),
19 | modelGenerator.getModel2Params(), modelGenerator.getModel3Params());
20 |
21 | ImageDirSource source = new ImageDirSource();
22 | //IMPORTANT: Operator chaining maybe hit the score log issue (Tianchi ENV) when parallelism is set to N_N_x.
23 | //Use statistic tag PREDICT_PROCESS to get prediction's real elapsed time
24 | //flinkEnv.disableOperatorChaining();
25 | flinkEnv.addSource(source).setParallelism(1)
26 | .flatMap(imageFlatMap).setParallelism(2)
27 | .addSink(new ImageClassSink()).setParallelism(2);
28 | flinkEnv.execute("Image inference for garbage classification-PhaseII-1.0 - SkyPeace");
29 | }
30 | }
31 |
--------------------------------------------------------------------------------
/src/main/java/sky/tf/ImageDataPreprocessing.java:
--------------------------------------------------------------------------------
1 | package sky.tf;
2 |
3 | import com.alibaba.tianchi.garbage_image_util.ImageData;
4 | import com.intel.analytics.bigdl.transform.vision.image.opencv.OpenCVMat;
5 | import com.intel.analytics.zoo.pipeline.inference.JTensor;
6 | import org.libjpegturbo.turbojpeg.TJ;
7 | import org.libjpegturbo.turbojpeg.TJDecompressor;
8 | import org.opencv.core.*;
9 | import org.opencv.imgcodecs.Imgcodecs;
10 | import org.opencv.imgproc.Imgproc;
11 |
12 | import java.util.ArrayList;
13 | import java.util.List;
14 |
15 | /**
16 | * @author SkyPeace
17 | * The class for image data preprocessing.
18 | */
19 | public class ImageDataPreprocessing {
20 | public static final int IMAGE_DECODER_OPENCV = 0; // decode with OpenCV Imgcodecs
21 | public static final int IMAGE_DECODER_TURBOJPEG = 1; // decode with libjpeg-turbo (faster, see decodeByTurboJpeg)
22 |
23 | public static final int PREPROCESSING_VGG = 11; // resize to smallest side then center-crop
24 | public static final int PREPROCESSING_INCEPTION = 12; // central-crop (0.875) then resize
25 | public static final int PREPROCESSING_TIANCHI = 13; // plain resize only
26 | private ModelParams modelParams; // supplies input shape/size for the tensors produced below
27 |
28 | public ImageDataPreprocessing(ModelParams modelParams)
29 | {
30 | this.modelParams = modelParams;
31 | }
32 |
33 | /**
34 | * Get RGB Mat data
35 | * @param imageData wrapper carrying the raw encoded image bytes
36 | * @param decodeType IMAGE_DECODER_OPENCV or IMAGE_DECODER_TURBOJPEG
37 | * @return decoded image as an RGB Mat
38 | * @throws Exception on unsupported decodeType or decoder failure
39 | */
40 | public static Mat getRGBMat(ImageData imageData, int decodeType) throws Exception
41 | {
42 | //********************** Decode image **********************//
43 | long beginTime = System.currentTimeMillis();
44 | System.out.println(String.format("%s IMAGE_BYTES = %s", "###", imageData.getImage().length));
45 | System.out.println(String.format("%s IMAGE_BYTES_IMDECODE BEGIN %s", "###", beginTime));
46 | Mat matRGB = null;
47 | if(decodeType == IMAGE_DECODER_OPENCV)
48 | matRGB = decodeByOpenCV(imageData);
49 | else if(decodeType == IMAGE_DECODER_TURBOJPEG)
50 | matRGB = decodeByTurboJpeg(imageData);
51 | else
52 | throw new Exception(String.format("Not support such decodeType: %s", decodeType));
53 | System.out.println(String.format("%s IMAGE_BYTES_IMDECODE END %s (Cost: %s)",
54 | "###", System.currentTimeMillis(), (System.currentTimeMillis() - beginTime)));
55 | return matRGB;
56 | }
57 |
58 | /**
59 | * Decode image data by turbojpeg. It is more fast than OpenCV decoder.
60 | * Please refer to https://libjpeg-turbo.org/About/Performance for performance comparison.
61 | * @param imageData wrapper carrying JPEG bytes
62 | * @return RGB Mat (8UC3) built from the decompressed pixel buffer
63 | * @throws Exception when turbojpeg fails to decompress
64 | */
65 | private static Mat decodeByTurboJpeg(ImageData imageData) throws Exception
66 | {
67 | TJDecompressor tjd = new TJDecompressor(imageData.getImage());
68 | //TJ.FLAG_FASTDCT can get more performance. TJ.FLAG_ACCURATEDCT can get more accuracy.
69 | byte bytes[] = tjd.decompress(tjd.getWidth(), 0, tjd.getHeight(), TJ.PF_RGB, TJ.FLAG_ACCURATEDCT);
70 | Mat matRGB = new Mat(tjd.getHeight(), tjd.getWidth(), CvType.CV_8UC3); // PF_RGB buffer -> 3-channel byte Mat
71 | matRGB.put(0, 0, bytes);
72 | return matRGB;
73 | }
74 |
75 | /**
76 | * Decode image data by OpenCV
77 | * @param imageData wrapper carrying encoded image bytes
78 | * @return RGB Mat (imdecode yields BGR; converted here)
79 | */
80 | private static Mat decodeByOpenCV(ImageData imageData)
81 | {
82 | Mat srcmat = new MatOfByte(imageData.getImage());
83 | Mat mat = Imgcodecs.imdecode(srcmat, Imgcodecs.CV_LOAD_IMAGE_COLOR);
84 | Mat matRGB = new MatOfByte();
85 | Imgproc.cvtColor(mat, matRGB, Imgproc.COLOR_BGR2RGB); // OpenCV decodes as BGR; models expect RGB
86 | System.out.println(String.format("%s image_width, image_height, image_depth, image_dims = %s, %s, %s, %s",
87 | "###", mat.width(), mat.height(), mat.depth(), mat.dims()));
88 | return matRGB;
89 | }
90 |
91 | /**
92 | * Preprocess image data and convert to float array (RGB) for JTensor list need.
93 | * @param matRGB decoded RGB image
94 | * @param preprocessType one of PREPROCESSING_VGG / PREPROCESSING_INCEPTION / PREPROCESSING_TIANCHI
95 | * @param enableMultipleVoters when true, also emits flipped/offset crops as extra "voter" inputs
96 | * @return one JTensor batch per produced crop
97 | * @throws Exception on unsupported preprocessType
98 | */
99 | public List> doPreProcessing(Mat matRGB, int preprocessType, boolean enableMultipleVoters) throws Exception
100 | {
101 | //********************** Resize, crop and other pre-processing **********************//
102 | long beginTime = System.currentTimeMillis();
103 | System.out.println(String.format("%s IMAGE_RESIZE BEGIN %s", "###", beginTime));
104 | List matList = new ArrayList();
105 | int image_size = modelParams.getInputShape()[2]; // NOTE(review): assumes inputShape[2] is the spatial size (H==W) — confirm layout
106 |
107 | //VGG_PREPROCESS, INCEPTION_PREPROCESS and PREPROCESSING_TIANCHI
108 | if(preprocessType==PREPROCESSING_VGG) {
109 | matList = this.resizeAndCenterCropImage(matRGB, 256, image_size, image_size, enableMultipleVoters);
110 | }else if(preprocessType==PREPROCESSING_INCEPTION) {
111 | matList = this.cropAndResizeImage(matRGB, image_size, image_size, enableMultipleVoters);
112 | }else if(preprocessType==PREPROCESSING_TIANCHI) {
113 | matList = this.resize(matRGB, image_size);
114 | }
115 | else
116 | throw new Exception(String.format("Not support such preprocessType: %s", preprocessType));
117 | System.out.println(String.format("%s IMAGE_RESIZE END %s (Cost: %s)",
118 | "###", System.currentTimeMillis(), (System.currentTimeMillis() - beginTime)));
119 |
120 | //********** Convert Mat to float array. R channel, G channel, B channel **********//
121 | beginTime = System.currentTimeMillis();
122 | System.out.println(String.format("%s CONVERT_TO_RGB_FLOAT BEGIN %s", "###", beginTime));
123 | List> inputs = new ArrayList>();
124 | for(int i=0;i(); // NOTE(review): line garbled in this dump — loop bound and per-crop list allocation lost with stripped generics
126 | float data[] = new float[modelParams.getInputSize()];
127 | List rgbMatList = new ArrayList();
128 | //Split to R channel, G channel and B channel
129 | org.opencv.core.Core.split(matList.get(i), rgbMatList);
130 | MatOfByte matFloat = new MatOfByte();
131 | //VConcat R,G,B and convert to float array. Native operation, quick enough (AVG less than 1ms).
132 | org.opencv.core.Core.vconcat(rgbMatList, matFloat);
133 | OpenCVMat.toFloatPixels(matFloat, data);
134 | //System.arraycopy(data, 0, datat, i*data.length, data.length);
135 |
136 | /* Below code is the example of worst performance, DO NOT coding as that.
137 | for (int row = 0; row < 224; row++) {
138 | for (int col = 0; col < 224; col++) {
139 | data[(col + row * 224) + 224 * 224 * 0] = (float) (dst.get(row, col)[0]);
140 | data[(col + row * 224) + 224 * 224 * 1] = (float) (dst.get(row, col)[1]);
141 | data[(col + row * 224) + 224 * 224 * 2] = (float) (dst.get(row, col)[2]);
142 | }
143 | }
144 | */
145 |
146 | //Create a JTensor
147 | JTensor tensor = new JTensor();
148 | tensor.setData(data);
149 | tensor.setShape(modelParams.getInputShape());
150 | list.add(tensor);
151 | inputs.add(list);
152 | }
153 | System.out.println(String.format("%s CONVERT_TO_RGB_FLOAT END %s (Cost: %s)",
154 | "###", System.currentTimeMillis(), (System.currentTimeMillis() - beginTime)));
155 | return inputs;
156 | }
157 |
158 | /**
159 | * Just resize to specified size. Use for TianChi model
160 | * @param src decoded RGB image
161 | * @param image_size target width and height (square)
162 | * @return single-element list with the resized Mat
163 | */
164 | private List resize(Mat src, int image_size)
165 | {
166 | Mat resizedMat = new MatOfByte();
167 | Imgproc.resize(src, resizedMat, new Size(image_size, image_size), 0, 0, Imgproc.INTER_CUBIC);
168 | List matList = new ArrayList();
169 | matList.add(resizedMat);
170 | return matList;
171 | }
172 |
173 | /**
174 | * Resize and center crop. For VGG_PREPROCESS
175 | * @param src decoded RGB image
176 | * @param smallestSide length the shorter image side is scaled to before cropping
177 | * @param outWidth crop width
178 | * @param outHeight crop height
179 | * @param enableMultipleVoters also emit flipped/bottom/left/right crops as extra voters
180 | */
181 | private List resizeAndCenterCropImage(Mat src, int smallestSide, int outWidth, int outHeight, boolean enableMultipleVoters)
182 | {
183 | //OpenCV resize. Use INTER_LINEAR, INTER_CUBIC or INTER_LANCZOS4
184 | //Imgproc.resize(src, dst, new Size(outWidth, outHeight), 0, 0, Imgproc.INTER_CUBIC);
185 | float scale = 0f;
186 | if(src.height()>src.width())
187 | scale = (float)smallestSide / src.width(); // scale so that the SHORTER side becomes smallestSide
188 | else
189 | scale = (float)smallestSide / src.height();
190 | float newHeight = src.height() * scale;
191 | float newWidth = src.width() * scale;
192 | Mat resizedMat = new MatOfByte();
193 |
194 | Imgproc.resize(src, resizedMat, new Size(newWidth, newHeight), 0, 0, Imgproc.INTER_LINEAR);
195 | int offsetHeight = (int)((resizedMat.height() - outHeight) / 2);
196 | int offsetWidth = (int)((resizedMat.width() - outWidth) / 2);
197 |
198 | //Center
199 | Mat matCenter = resizedMat.submat(offsetHeight, offsetHeight + outHeight, offsetWidth, offsetWidth + outWidth);
200 | List matList = new ArrayList();
201 | matList.add(matCenter);
202 |
203 | if(enableMultipleVoters) {
204 | //matFlip
205 | Mat matCenterFlip = new MatOfByte();
206 | Core.flip(matCenter, matCenterFlip, 0); // OpenCV flipCode 0 = flip around the x-axis (vertical flip)
207 | //Mat matTop = resizedMat.submat(0, outHeight, offsetWidth, offsetWidth + outWidth);
208 | Mat matBottom = resizedMat.submat((int) resizedMat.height() - outHeight, (int) resizedMat.height(), offsetWidth, offsetWidth + outWidth);
209 | Mat matLeft = resizedMat.submat(offsetHeight, offsetHeight + outHeight, 0, outWidth);
210 | Mat matRight = resizedMat.submat(offsetHeight, offsetHeight + outHeight, (int) resizedMat.width() - outWidth, (int) resizedMat.width());
211 | matList.add(matCenterFlip);
212 | //matList.add(matTop);
213 | matList.add(matBottom);
214 | matList.add(matRight); // note: right is intentionally added before left (voter order)
215 | matList.add(matLeft);
216 | }
217 | return matList;
218 | }
219 |
220 | /**
221 | * Crop and resize. For INCEPTION_PREPROCESS
222 | * @param src decoded RGB image
223 | * @param outWidth output width
224 | * @param outHeight output height
225 | * @param enableMultipleVoters also emit flipped/bottom/left/right crops as extra voters
226 | * @return list of preprocessed Mats (1 or 5 entries)
227 | */
228 | private List cropAndResizeImage(Mat src, int outWidth, int outHeight, boolean enableMultipleVoters)
229 | {
230 | float scale = 0.875f; // standard Inception central-crop fraction
231 | float newHeight = src.height() * scale;
232 | float newWidth = src.width() * scale;
233 |
234 | int offsetHeight = (int)((src.height() - newHeight) / 2);
235 | int offsetWidth = (int)((src.width() - newWidth) / 2);
236 | Mat centerCropMat = src.submat(offsetHeight, offsetHeight + (int)newHeight, offsetWidth, offsetWidth + (int)newWidth);
237 |
238 | Mat resizedMat = new MatOfByte();
239 | Imgproc.resize(centerCropMat, resizedMat, new Size(outWidth, outHeight), 0, 0, Imgproc.INTER_LINEAR);
240 | List matList = new ArrayList();
241 | matList.add(resizedMat);
242 |
243 | if(enableMultipleVoters) {
244 | Mat matCenterFlip = new MatOfByte();
245 | Core.flip(resizedMat, matCenterFlip, 0); // OpenCV flipCode 0 = flip around the x-axis (vertical flip)
246 | matList.add(matCenterFlip);
247 |
248 | Mat matBottom = src.submat(src.height()-(int)newHeight, src.height(), offsetWidth, offsetWidth + (int)newWidth);
249 | Mat resizedMatBottom = new MatOfByte();
250 | Imgproc.resize(matBottom, resizedMatBottom, new Size(outWidth, outHeight), 0, 0, Imgproc.INTER_LINEAR);
251 | matList.add(resizedMatBottom);
252 |
253 | Mat matLeft = src.submat(offsetHeight, offsetHeight + (int) newHeight, 0, (int) newWidth);
254 | Mat resizedMatLeft = new MatOfByte();
255 | Imgproc.resize(matLeft, resizedMatLeft, new Size(outWidth, outHeight), 0, 0, Imgproc.INTER_LINEAR);
256 | matList.add(resizedMatLeft);
257 |
258 | Mat matRight = src.submat(offsetHeight, offsetHeight + (int)newHeight, src.width() - (int)newWidth, src.width());
259 | Mat resizedMatRight = new MatOfByte();
260 | Imgproc.resize(matRight, resizedMatRight, new Size(outWidth, outHeight), 0, 0, Imgproc.INTER_LINEAR);
261 | matList.add(resizedMatRight);
262 | }
263 | return matList;
264 | }
265 |
266 | /**
267 | * Another implements of INCEPTION_PREPROCESS. For experimental only.
268 | * Produces a float Mat scaled to [-1, 1] (x/255 - 0.5, then *2).
269 | * @param src decoded RGB image
270 | * @param outWidth output width
271 | * @param outHeight output height
272 | * @param dst output float Mat, written in place
273 | */
274 | private void cropAndResizeImageReserved(Mat src, int outWidth, int outHeight, Mat dst)
275 | {
276 | Mat srcFloat = new MatOfFloat();
277 | src.convertTo(srcFloat, CvType.CV_32F, 1.0f/255.0f); // bytes -> float in [0,1]
278 | float scale = 0.875f;
279 | float newHeight = srcFloat.height() * scale;
280 | float newWidth = srcFloat.width() * scale;
281 |
282 | int offsetHeight = (int)((srcFloat.height() - newHeight) / 2);
283 | int offsetWidth = (int)((srcFloat.width() - newWidth) / 2);
284 | Mat cropMat = srcFloat.submat(offsetHeight, offsetHeight + (int)newHeight, offsetWidth, offsetWidth + (int)newWidth);
285 |
286 | Mat resizedMat = new MatOfFloat();
287 | Imgproc.resize(cropMat, resizedMat, new Size(outWidth, outHeight), 0, 0, Imgproc.INTER_LINEAR);
288 |
289 | Mat resizedMat1 = new MatOfFloat();
290 | Core.subtract(resizedMat, Scalar.all(0.5), resizedMat1); // center to [-0.5, 0.5]
291 | Core.multiply(resizedMat1, Scalar.all(2.0), dst); // stretch to [-1, 1]
292 | }
293 |
294 | }
294 |
--------------------------------------------------------------------------------
/src/main/java/sky/tf/ModelParams.java:
--------------------------------------------------------------------------------
1 | package sky.tf;
2 |
3 | /**
4 | * @author SkyPeace
5 | * The class for configure model's parameters.
6 | */
7 | public class ModelParams implements java.io.Serializable
8 | {
9 | private String modelType; //resnet, inception
10 | private String modelName; // human-readable model identifier
11 | private String inputName; // name of the graph input tensor fed at load time
12 | private int[] inputShape; // tensor input shape; index 2 is read as the image size elsewhere — TODO confirm NCHW layout
13 | private int inputSize; // NOTE(review): field appears unused by the visible accessors (getInputSize is truncated in this dump)
14 | private float[] meanValues; // per-channel mean values passed to model loading
15 | private float scale; // input scaling factor passed to model loading
16 |
17 | private String optimizedModelDir; // directory holding the optimized (OpenVINO) model artifacts
18 |
19 | public String getModelName() {
20 | return modelName;
21 | }
22 |
23 | public void setModelName(String modelName) {
24 | this.modelName = modelName;
25 | }
26 |
27 | public String getOptimizedModelDir() {
28 | return optimizedModelDir;
29 | }
30 |
31 | public void setOptimizedModelDir(String optimizedModelDir) {
32 | this.optimizedModelDir = optimizedModelDir;
33 | }
34 |
35 | public String getModelType() {
36 | return modelType;
37 | }
38 |
39 | public void setModelType(String modelType) {
40 | this.modelType = modelType;
41 | }
42 |
43 | public String getInputName() {
44 | return inputName;
45 | }
46 |
47 | public void setInputName(String inputName) {
48 | this.inputName = inputName;
49 | }
50 |
51 | public int[] getInputShape() {
52 | return inputShape;
53 | }
54 |
55 | public void setInputShape(int[] inputShape) {
56 | this.inputShape = inputShape;
57 | } // NOTE(review): remainder of this class (getInputSize, meanValues/scale accessors) is truncated in this dump
58 |
59 | public int getInputSize() {
60 | int size = 1;
61 | for(int i=1;i {
21 | private String imageModelPath = System.getenv(ConfigConstant.IMAGE_MODEL_PATH);
22 | private String imageModelPathPackagePath = System.getenv(ConfigConstant.IMAGE_MODEL_PACKAGE_PATH);
23 | private String modelInferencePath = System.getenv("MODEL_INFERENCE_PATH");
24 | private List modelParamsList;
25 | private transient List modelList;
26 | private transient ImagePredictSupportiveMultipleModels supportive;
27 |
28 | public ImageFlatMapMultipleModels(List modelParamsList) // one ModelParams per model to load in open(); generics stripped in this dump (List<ModelParams>)
29 | {
30 | this.modelParamsList = modelParamsList; // stored (non-transient) so Flink can ship it to subtasks — TODO confirm serialization path
31 | }
32 |
33 | @Override
34 | public void open(Configuration parameters) throws Exception // per-subtask init: log env, load all models, warm up
35 | {
36 | //For troubleshooting use.
37 | System.out.println(String.format("ImageFlatMap.open(): imageModelPath is %s", this.imageModelPath));
38 | System.out.println(String.format("ImageFlatMap.open(): modelInferencePath is %s", this.modelInferencePath));
39 | System.out.println(String.format("ImageFlatMap.open(): imageModelPathPackagePath is %s", this.imageModelPathPackagePath));
40 |
41 | //Step2: Load optimized OpenVino model from files (HDFS).
42 | // Cost about 1 seconds each model, quick enough. But it is not good solution to use too many models in client.
43 | modelList = new ArrayList(); // transient field: rebuilt on every open() after deserialization
44 | for(ModelParams modelParams:modelParamsList) {
45 | GarbageClassificationModel model =
46 | ImageModelLoaderMultipleModels.getInstance().loadOpenVINOModelOnce(modelParams); // loader is a singleton; "Once" semantics not visible here — TODO confirm caching behavior
47 | modelList.add(model);
48 | }
49 |
50 | //First time warm-dummy check
51 | ImageClassIndex.getInsatnce().loadClassIndexMap(modelInferencePath); // loads id->label map once per JVM (synchronized singleton)
52 | this.supportive = new ImagePredictSupportiveMultipleModels(modelList, modelParamsList);
53 | this.supportive.firstTimeDummyCheck(); // presumably a warm-up prediction so the first real record is not penalized — TODO confirm
54 | }
55 |
56 | @Override
57 | public void flatMap(ImageData value, Collector out)
58 | throws Exception
59 | {
60 | IdLabel idLabel = new IdLabel();
61 | idLabel.setId(value.getId());
62 |
63 | long beginTime = System.currentTimeMillis();
64 | System.out.println(String.format("PREDICT_PROCESS BEGIN %s", beginTime));
65 |
66 | String imageLabelString = supportive.predictHumanString(value);
67 |
68 | long endTime = System.currentTimeMillis();
69 | System.out.println(String.format("PREDICT_PROCESS END %s (Cost: %s)", endTime, (endTime - beginTime)));
70 |
71 | //Check whether elapsed time >= threshold. Logging it for review.
72 | if((endTime - beginTime)>498)
73 | System.out.println(String.format("PREDICT_PROCESS MAYBE EXCEED 500ms %s (Cost: %s)",
74 | endTime, (endTime - beginTime)));
75 |
76 | idLabel.setLabel(imageLabelString);
77 | out.collect(idLabel);
78 | }
79 |
80 | @Override
81 | public void close() throws Exception
82 | {
83 | System.out.println(String.format("getNumberOfParallelSubtasks: %s",
84 | this.getRuntimeContext().getNumberOfParallelSubtasks()));
85 | for(GarbageClassificationModel model:modelList) {
86 | model.release();
87 | }
88 | }
89 |
90 | }
91 |
92 |
--------------------------------------------------------------------------------
/src/main/java/sky/tf/multiplemodels/ImageModelLoaderMultipleModels.java:
--------------------------------------------------------------------------------
1 | package sky.tf.multiplemodels;
2 |
3 | import com.google.common.io.Files;
4 | import com.intel.analytics.zoo.pipeline.inference.OpenVinoInferenceSupportive$;
5 | import org.apache.commons.io.FileUtils;
6 | import org.apache.flink.api.java.tuple.Tuple2;
7 | import org.apache.hadoop.fs.FileSystem;
8 | import org.apache.hadoop.fs.*;
9 | import org.apache.hadoop.hdfs.DistributedFileSystem;
10 | import org.slf4j.Logger;
11 | import org.slf4j.LoggerFactory;
12 | import sky.tf.GarbageClassificationModel;
13 | import sky.tf.ModelParams;
14 |
15 | import java.io.*;
16 | import java.net.URI;
17 | import java.nio.channels.Channels;
18 | import java.nio.channels.FileChannel;
19 | import java.nio.channels.ReadableByteChannel;
20 | import java.util.ArrayList;
21 | import java.util.HashMap;
22 | import java.util.List;
23 | import java.util.Map;
24 |
25 |
26 | /**
27 | * @author SkyPeace
28 | * The model loader. Experimental only.
29 | */
30 | public class ImageModelLoaderMultipleModels {
31 | private Logger logger = LoggerFactory.getLogger(ImageModelLoaderMultipleModels.class);
32 | private static ImageModelLoaderMultipleModels instance = new ImageModelLoaderMultipleModels();
33 | private volatile Map modelsMap =
34 | new HashMap();
35 |
36 | private ImageModelLoaderMultipleModels() {}
37 |
38 | public static ImageModelLoaderMultipleModels getInstance()
39 | {
40 | return instance;
41 | }
42 |
43 | /**
44 | * Generate optimized OpenVino model's data (bytes). -- First step.
45 | * @param savedModelPath
46 | * @return
47 | */
48 | public List> generateOpenVinoModelData(String savedModelPath, ModelParams modelParams) throws Exception
49 | {
50 | List> modelData = new ArrayList>();
51 | File modelFile = new File(savedModelPath);
52 | String optimizeModelPath = null;
53 | if(modelFile.isDirectory())
54 | optimizeModelPath =
55 | optimizeModelFromModelDir(savedModelPath, modelParams);
56 | else
57 | optimizeModelPath =
58 | optimizeModelFromModelPackage(savedModelPath, modelParams);
59 | byte[] xml = readDFSFile(optimizeModelPath + File.separator + "saved_model.xml");
60 | byte[] bin = readDFSFile(optimizeModelPath + File.separator + "saved_model.bin");
61 | logger.info("Size of optimized saved_model.xml: " + xml.length);
62 | logger.info("Size of optimized saved_model.bin: " + bin.length);
63 | Tuple2 xmlData = new Tuple2("xml", xml);
64 | Tuple2 binData = new Tuple2("bin", bin);
65 | modelData.add(xmlData);
66 | modelData.add(binData);
67 | return modelData;
68 | }
69 |
70 | /**
71 | * Optimize model from saved model dir. Return optimized model temp directory.
72 | * @param savedModelDir
73 | * @return
74 | * @throws Exception
75 | */
76 | private String optimizeModelFromModelDir(String savedModelDir, ModelParams modelParams) throws Exception
77 | {
78 | File tmpDir = Files.createTempDir();
79 | String optimizeModelTmpDir = tmpDir.getCanonicalPath();
80 | //OpenVinoInferenceSupportive.optimizeTFImageClassificationModel();
81 | OpenVinoInferenceSupportive$.MODULE$.optimizeTFImageClassificationModel(
82 | savedModelDir + File.separator + "SavedModel", modelParams.getInputShape(), false,
83 | modelParams.getMeanValues(), modelParams.getScale(), modelParams.getInputName(), optimizeModelTmpDir);
84 | return optimizeModelTmpDir;
85 | }
86 |
87 | /**
88 | * Optimize model from saved model package (tar.gz file). Return optimized model temp dir.
89 | * @param savedModelPackagePath
90 | * @return
91 | * @throws Exception
92 | */
93 | private String optimizeModelFromModelPackage(String savedModelPackagePath, ModelParams modelParams) throws Exception
94 | {
95 | byte[] savedModelBytes = readDFSFile(savedModelPackagePath);
96 | File tmpDir = Files.createTempDir();
97 | String tempDirPath = tmpDir.getCanonicalPath();
98 | String tarFileName = "saved-model.tar";
99 | File tarFile = new File(tempDirPath + File.separator + tarFileName);
100 | ByteArrayInputStream tarFileInputStream = new ByteArrayInputStream(savedModelBytes);
101 | ReadableByteChannel tarFileSrc = Channels.newChannel(tarFileInputStream);
102 | FileChannel tarFileDest = (new FileOutputStream(tarFile)).getChannel();
103 | tarFileDest.transferFrom(tarFileSrc, 0L, 9223372036854775807L);
104 | tarFileDest.close();
105 | tarFileSrc.close();
106 | String tarFileAbsolutePath = tarFile.getAbsolutePath();
107 | String modelRootDir = tempDirPath + File.separator + "saved-model";
108 | File modelRootDirFile = new File(modelRootDir);
109 | FileUtils.forceMkdir(modelRootDirFile);
110 | //tar -xvf -C
111 | Process proc = Runtime.getRuntime().exec(new String[]{"tar", "-xvf", tarFileAbsolutePath, "-C", modelRootDir});
112 | //Runtime.getRuntime().exec(new String[]{"ls", "-l", modelRootDir});
113 | BufferedReader insertReader = new BufferedReader(new InputStreamReader(proc.getInputStream()));
114 | BufferedReader errorReader = new BufferedReader(new InputStreamReader(proc.getErrorStream()));
115 | proc.waitFor();
116 | if(insertReader!=null)
117 | insertReader.close();
118 | if(errorReader!=null)
119 | errorReader.close();
120 | File[] files = modelRootDirFile.listFiles();
121 | String savedModelTmpDir = files[0].getAbsolutePath();
122 | logger.info("Saved model temp dir will be used for optimization: " + savedModelTmpDir);
123 | String optimizeModelTmpDir = optimizeModelFromModelDir(savedModelTmpDir, modelParams);
124 | return optimizeModelTmpDir;
125 | }
126 |
127 | /**
128 | * Save optimized OpenVino model's data (bytes) into HDFS files
129 | * The specified dir is also the parent dir of saved model package.
130 | * @param openVinoModelData
131 | * @param modelParams
132 | * @return
133 | */
134 | public void saveToOpenVinoModelFile(List> openVinoModelData, ModelParams modelParams) throws Exception
135 | {
136 | writeOptimizedModelToDFS(openVinoModelData, modelParams);
137 | }
138 |
139 | /**
140 | * Load OpenVino model with singleton pattern. -- Second step.
141 | * In this case, use this solution by default because it only cost about 2 seconds to load in Map.open().
142 | * @param modelParams
143 | * @return
144 | */
145 | public synchronized GarbageClassificationModel loadOpenVINOModelOnce(ModelParams modelParams) throws Exception
146 | {
147 | GarbageClassificationModel model = modelsMap.get(modelParams.getModelName());
148 | if(model == null)
149 | {
150 | model = this.getModel(modelParams);
151 | modelsMap.put(modelParams.getModelName(), model);
152 | }
153 | model.addRefernce();
154 | return model;
155 | }
156 |
157 | private GarbageClassificationModel getModel(ModelParams modelParams) throws Exception
158 | {
159 | List> openVinoModelData =
160 | getModelDataFromOptimizedModelDir(modelParams);
161 | byte[] modelXml = openVinoModelData.get(0).f1;
162 | byte[] modelBin = openVinoModelData.get(1).f1;
163 | return new GarbageClassificationModel(modelXml, modelBin);
164 | }
165 |
166 | /**
167 | * Get OpenVino model data from optimized model files
168 | * @param modelParams
169 | * @return
170 | */
171 | private List> getModelDataFromOptimizedModelDir(ModelParams modelParams) throws Exception
172 | {
173 | String optimizedModelDir = modelParams.getOptimizedModelDir();
174 | String modelName = modelParams.getModelName();
175 | List> modelData = new ArrayList>();
176 | byte[] xml = readDFSFile(optimizedModelDir + File.separator + "optimized_openvino_" + modelName + "_KEEPME.xml");
177 | byte[] bin = readDFSFile(optimizedModelDir + File.separator + "optimized_openvino_" + modelName + "_KEEPME.bin");
178 | logger.info("Size of optimized_openvino_" + modelName + "_KEEPME.xml: " + xml.length);
179 | logger.info("Size of optimized_openvino_" + modelName + "_KEEPME.bin: " + bin.length);
180 | Tuple2 xmlData = new Tuple2("xml", xml);
181 | Tuple2 binData = new Tuple2("bin", bin);
182 | modelData.add(xmlData);
183 | modelData.add(binData);
184 | return modelData;
185 | }
186 |
187 | /**
188 | * Read saved model files (HDFS) into byte array.
189 | * @param filePath
190 | * @return
191 | */
192 | private byte[] readDFSFile(String filePath) throws Exception
193 | {
194 | try {
195 | long beginTime = System.currentTimeMillis();
196 | System.out.println(String.format("READ_MODEL_FILE from %s BEGIN %s", filePath, beginTime));
197 | Path imageRoot = new Path(filePath);
198 | org.apache.hadoop.conf.Configuration hadoopConfig = new org.apache.hadoop.conf.Configuration();
199 | hadoopConfig.set("fs.hdfs.impl", DistributedFileSystem.class.getName());
200 | hadoopConfig.set("fs.file.impl", LocalFileSystem.class.getName());
201 | FileSystem fileSystem = FileSystem.get(new URI(filePath), hadoopConfig);
202 | FileStatus fileStatus = fileSystem.getFileStatus(imageRoot);
203 | //RemoteIterator it = fileSystem.listFiles(imageRoot, false);
204 | long fileLength = fileStatus.getLen();
205 | FSDataInputStream in = fileSystem.open(imageRoot);
206 | byte[] buffer = new byte[(int) fileLength];
207 | in.readFully(buffer);
208 | in.close();
209 | long endTime = System.currentTimeMillis();
210 | System.out.println(String.format("READ_MODEL_FILE from %s END %s (Cost: %s)",
211 | filePath, endTime, (endTime - beginTime)));
212 | fileSystem.close();
213 | return buffer;
214 | }catch(Exception ex)
215 | {
216 | String msg = "Read DFS file FAILED. " + ex.getMessage();
217 | logger.error(msg, ex);
218 | throw ex;
219 | }
220 | }
221 |
222 | /**
223 | * Write the optimized model's data to files (HDFS)
224 | * @param openVinoModelData
225 | * @param modelParams
226 | */
227 | private void writeOptimizedModelToDFS(List> openVinoModelData, ModelParams modelParams) throws Exception
228 | {
229 | String optimizedModelDir = modelParams.getOptimizedModelDir();
230 | String modelName = modelParams.getModelName();
231 | try {
232 | long beginTime = System.currentTimeMillis();
233 | System.out.println(String.format("WRITE_OPTIMIZED_MODEL_FILE %s BEGIN %s", optimizedModelDir, beginTime));
234 | org.apache.hadoop.conf.Configuration hadoopConfig = new org.apache.hadoop.conf.Configuration();
235 | hadoopConfig.set("fs.hdfs.impl", DistributedFileSystem.class.getName());
236 | hadoopConfig.set("fs.file.impl", LocalFileSystem.class.getName());
237 | FileSystem fileSystem = FileSystem.get(new URI(optimizedModelDir), hadoopConfig);
238 |
239 | String xmlFileName = optimizedModelDir + File.separator + "optimized_openvino_"+ modelName + "_KEEPME.xml";
240 | System.out.println(String.format("Model xmlFileName: %s", xmlFileName));
241 | Path xmlFilePath = new Path(xmlFileName);
242 | FSDataOutputStream xmlFileOut = fileSystem.create(xmlFilePath, true);
243 | xmlFileOut.write(openVinoModelData.get(0).f1);
244 | xmlFileOut.flush();
245 | xmlFileOut.close();
246 |
247 | String binFileName = optimizedModelDir + File.separator + "optimized_openvino_"+ modelName + "_KEEPME.bin";
248 | System.out.println(String.format("Model binFileName: %s", xmlFileName));
249 | Path binFilePath = new Path(binFileName);
250 | FSDataOutputStream binFileOut = fileSystem.create(binFilePath, true);
251 | binFileOut.write(openVinoModelData.get(1).f1);
252 | binFileOut.flush();
253 | binFileOut.close();
254 |
255 | long endTime = System.currentTimeMillis();
256 | System.out.println(String.format("WRITE_OPTIMIZED_MODEL_FILE %s END %s (Cost: %s)",
257 | optimizedModelDir, endTime, (endTime - beginTime)));
258 | fileSystem.close();
259 | }catch(Exception ex)
260 | {
261 | String msg = "Write DFS file FAILED. " + ex.getMessage();
262 | logger.error(msg, ex);
263 | throw ex;
264 | }
265 | }
266 |
267 | }
268 |
--------------------------------------------------------------------------------
/src/main/java/sky/tf/multiplemodels/ImagePredictSupportiveMultipleModels.java:
--------------------------------------------------------------------------------
1 | package sky.tf.multiplemodels;
2 |
3 | import com.alibaba.tianchi.garbage_image_util.ImageData;
4 | import com.intel.analytics.bigdl.opencv.OpenCV;
5 | import com.intel.analytics.zoo.pipeline.inference.JTensor;
6 | import org.opencv.core.Mat;
7 | import sky.tf.*;
8 |
9 | import java.util.*;
10 |
11 | /**
12 | * @author SkyPeace
13 | * The supportive for image classification prediction. Experimental only.
14 | */
15 | public class ImagePredictSupportiveMultipleModels
16 | {
17 | //private Logger logger = LoggerFactory.getLogger(ImagePredictSupportiveMultipleModels.class);
18 | private List modelParamsList;
19 | private List modelList;
20 |
21 | public ImagePredictSupportiveMultipleModels(List modelList, List modelParamsList)
22 | {
23 | this.modelList = modelList;
24 | this.modelParamsList = modelParamsList;
25 | }
26 |
27 | /**
28 | * Do predict and transfer the label ID to human string.
29 | * @param imageData
30 | * @return
31 | * @throws Exception
32 | */
33 | public String predictHumanString(ImageData imageData) throws Exception
34 | {
35 | Integer predictId = predict(imageData);
36 | String humanString = this.getImageHumanString(predictId);
37 | if (humanString == null) {
38 | humanString = "class_index_not_available";
39 | throw new Exception(humanString);
40 | }
41 | return humanString;
42 | }
43 |
44 | /**
45 | * Using multiple models do predict.
46 | * @param imageData
47 | * @return
48 | * @throws Exception
49 | */
50 | private Integer predict(ImageData imageData) throws Exception {
51 | //Decode jpeg
52 | //Use turbojpeg as jpeg decoder. It is more fast than OpenCV decoder (OpenCV use libjpeg as decoder).
53 | //Refer to https://libjpeg-turbo.org/About/Performance for performance comparison.
54 | Mat matRGB = ImageDataPreprocessing.getRGBMat(imageData, ImageDataPreprocessing.IMAGE_DECODER_TURBOJPEG);
55 |
56 | long t1 = System.currentTimeMillis();
57 | System.out.println(String.format("%s DO_PREDICT BEGIN %s", "###", t1));
58 | List modelsPredictionResults = new ArrayList();
59 | for(int i=0;i> inputs = dataPreprocessing.doPreProcessing(matRGB, preprocessingType, enableMultipleVoters);
87 | List> tensorResults = model.predict(inputs);
88 | List pResults = this.convertToPredictionResults(tensorResults);
89 | PredictionResult pResult = this.getSingleResult(pResults);
90 | return pResult;
91 | }
92 |
93 | /**
94 | * Get single result
95 | * @param pResults
96 | * @return
97 | */
98 | private PredictionResult getSingleResult(List pResults)
99 | {
100 | if(pResults.size()==1)
101 | return pResults.get(0);
102 | System.out.println(String.format("There are %s multiple results", pResults.size()));
103 | Map distinctResults = new HashMap();
104 | for(PredictionResult pResult:pResults)
105 | {
106 | Integer predictionId = pResult.getPredictionId();
107 | PredictionResult existResult = distinctResults.get(predictionId);
108 | if(existResult == null){
109 | distinctResults.put(predictionId, pResult);
110 | }else{
111 | existResult.setCount(existResult.getCount() + 1);
112 | existResult.setProbability(Math.max(pResult.getProbability(), existResult.getProbability()));
113 | distinctResults.put(predictionId, existResult);
114 | }
115 | }
116 | PredictionResult maxCountResult = null;
117 | for(Integer key:distinctResults.keySet())
118 | {
119 | PredictionResult pResult = distinctResults.get(key);
120 | float ratio = (float)pResult.getCount() / pResults.size();
121 | if(ratio>0.5f)
122 | {
123 | System.out.println("Majority get agreement.");
124 | maxCountResult = pResult;
125 | return maxCountResult;
126 | }else if(pResult.getCount()>1)
127 | maxCountResult = pResult;
128 | }
129 | if(maxCountResult!=null) {
130 | System.out.println("Few models get agreement.");
131 | return maxCountResult;
132 | }
133 |
134 | System.out.println("The models can NOT get agreement.");
135 | float maxProbability = pResults.get(0).getProbability();
136 | int maxIdx = 0;
137 | for (int i = 1; i < pResults.size(); i++) {
138 | if (pResults.get(i).getProbability() > maxProbability) {
139 | maxProbability = pResults.get(i).getProbability();
140 | maxIdx = i;
141 | }
142 | }
143 | return pResults.get(maxIdx);
144 | }
145 |
146 | private List convertToPredictionResults(List> result) throws Exception
147 | {
148 | long beginTime = System.currentTimeMillis();
149 | List pResults = new ArrayList();
150 | System.out.println("result.size(): " + result.size());
151 | System.out.println("result.get(0).size(): " + result.get(0).size());
152 | for(int j=0; j 0) {
156 | float maxProbability = predictData[0];
157 | int maxNo = 0;
158 | for (int i = 1; i < predictData.length; i++) {
159 | if (predictData[i] > maxProbability) {
160 | maxProbability = predictData[i];
161 | maxNo = i;
162 | }
163 | }
164 | PredictionResult pResult = new PredictionResult();
165 | pResult.setPredictionId(maxNo);
166 | pResult.setProbability(maxProbability);
167 | pResult.setCount(1);
168 | pResults.add(pResult);
169 | }else{
170 | throw new Exception("ERROR: predictData.length=0");
171 | }
172 | }
173 | long endTime = System.currentTimeMillis();
174 | if((endTime-beginTime)>5)
175 | System.out.println(String.format("%s CONVERT_TO_PREDICTION_RESULT END %s (Cost: %s)",
176 | "###", endTime, (endTime - beginTime)));
177 | return pResults;
178 | }
179 |
180 | /**
181 | * Load turbojpeg library.
182 | * @throws Exception
183 | */
184 | private void loadTurboJpeg()
185 | {
186 | long beginTime = System.currentTimeMillis();
187 | if(!TurboJpegLoader.isTurbojpegLoaded()) {
188 | throw new RuntimeException("LOAD_TURBOJPEG library failed. Please check.");
189 | }
190 | long endTime = System.currentTimeMillis();
191 | if((endTime - beginTime) > 1)
192 | System.out.println(String.format("LOAD_TURBOJPEG END %s (Cost: %s)", endTime, (endTime - beginTime)));
193 | }
194 |
195 | /**
196 | * Load OpenCV library.
197 | * @throws Exception
198 | */
199 | private void loadOpenCV()
200 | {
201 | long beginTime = System.currentTimeMillis();
202 | if(!OpenCV.isOpenCVLoaded()) {
203 | throw new RuntimeException("LOAD_OPENCV library failed. Please check.");
204 | }
205 | long endTime = System.currentTimeMillis();
206 | if((endTime - beginTime) > 10)
207 | System.out.println(String.format("LOAD_OPENCV END %s (Cost: %s)", endTime, (endTime - beginTime)));
208 |
209 | //if(!OpenvinoNativeLoader.load())
210 | // throw new RuntimeException("LOAD_Openvino library failed. Please check.");
211 | //if(!TFNetNative.isLoaded())
212 | // throw new RuntimeException("LOAD_TFNetNative library failed. Please check.");
213 | }
214 |
215 | /**
216 | * Get image class human string.
217 | * @return
218 | */
219 | private String getImageHumanString(Integer id)
220 | {
221 | return ImageClassIndex.getInsatnce().getImageHumanstring(String.valueOf(id));
222 | }
223 |
224 | /**
225 | * For first time warm-dummy check only.
226 | * @throws Exception
227 | */
228 | public void firstTimeDummyCheck() throws Exception
229 | {
230 | for(int i=0; i();
250 | list.add(tensor);
251 | List> inputs = new ArrayList>();
252 | inputs.add(list);
253 | pModel.predict(inputs);
254 |
255 | long endTime = System.currentTimeMillis();
256 | System.out.println(String.format("FIRST_PREDICT END %s (Cost: %s)", endTime, (endTime - beginTime)));
257 | }
258 |
259 | }
260 |
--------------------------------------------------------------------------------
/src/main/java/sky/tf/multiplemodels/OpenVinoModelGeneratorMultipleModels.java:
--------------------------------------------------------------------------------
1 | package sky.tf.multiplemodels;
2 |
3 | import com.alibaba.tianchi.garbage_image_util.ConfigConstant;
4 | import org.apache.flink.api.java.tuple.Tuple2;
5 | import sky.tf.ModelParams;
6 |
7 | import java.io.File;
8 | import java.util.ArrayList;
9 | import java.util.List;
10 |
11 | /**
12 | * The class use for generate optimized OpenVino models (Multiple different models for prediction). Experimental only.
13 | * @author SkyPeace
14 | */
15 | public class OpenVinoModelGeneratorMultipleModels {
16 | private String imageModelPath = System.getenv(ConfigConstant.IMAGE_MODEL_PATH);
17 | private String imageModelPackagePath = System.getenv(ConfigConstant.IMAGE_MODEL_PACKAGE_PATH);
18 | private String modelInferencePath = System.getenv("MODEL_INFERENCE_PATH");
19 | private List modelParamsList;
20 |
21 | public void execute()
22 | {
23 | if (null == this.modelInferencePath) {
24 | throw new RuntimeException("ImageFlatMap(): Not set MODEL_INFERENCE_PATH environmental variable");
25 | }
26 | if (null == this.imageModelPackagePath) {
27 | throw new RuntimeException("ImageFlatMap(): Not set imageModelPathPackagePath environmental variable");
28 | }
29 | String imageModelPackageDir =
30 | imageModelPackagePath.substring(0, imageModelPackagePath.lastIndexOf(File.separator));
31 | if(!imageModelPackageDir.equalsIgnoreCase(modelInferencePath))
32 | {
33 | System.out.println("WARN: modelInferencePath NOT EQUAL imageModelPathPackageDir");
34 | System.out.println(String.format("modelInferencePath: %s", modelInferencePath));
35 | System.out.println(String.format("imageModelPackageDir: %s", imageModelPackageDir));
36 | System.out.println(String.format("imageModelPath: %s", imageModelPath));
37 | }
38 |
39 | try {
40 | //Generate optimized OpenVino model
41 | this.modelParamsList = new ArrayList();
42 | //Here only generate 5 models for test.
43 | for(int i=0;i<5;i++) {
44 | String tfModelPath = modelInferencePath + File.separator + "SavedModel/model" + i;
45 | String optimizedOpenVinoModelDir = imageModelPackageDir;
46 | ModelParams modelParams = new ModelParams();
47 | modelParams.setModelType("inception");
48 | modelParams.setModelName("model" + i);
49 | modelParams.setInputName("input_1");
50 | modelParams.setInputShape(new int[]{1, 299, 299, 3});
51 | modelParams.setMeanValues(new float[]{127.5f, 127.5f, 127.5f});
52 | modelParams.setScale(127.5f);
53 | modelParams.setOptimizedModelDir(optimizedOpenVinoModelDir);
54 | //Call Zoo API generate optimized OpenVino Model
55 | List> optimizedOpenVinoModelData =
56 | ImageModelLoaderMultipleModels.getInstance().generateOpenVinoModelData(tfModelPath, modelParams);
57 | ImageModelLoaderMultipleModels.getInstance().
58 | saveToOpenVinoModelFile(optimizedOpenVinoModelData, modelParams);
59 | modelParamsList.add(modelParams);
60 | }
61 | }catch(Exception ex)
62 | {
63 | ex.printStackTrace();
64 | throw new RuntimeException("WARN: OpenVinoModelGenerator.execute() FAILED.");
65 | }
66 | }
67 |
68 | public List getModelParamsList()
69 | {
70 | return this.modelParamsList;
71 | }
72 | }
73 |
--------------------------------------------------------------------------------
/src/main/java/sky/tf/threemodels/ImageFlatMap3Models.java:
--------------------------------------------------------------------------------
1 | package sky.tf.threemodels;
2 |
3 | import com.alibaba.tianchi.garbage_image_util.ConfigConstant;
4 | import com.alibaba.tianchi.garbage_image_util.IdLabel;
5 | import com.alibaba.tianchi.garbage_image_util.ImageData;
6 | import org.apache.flink.api.common.functions.RichFlatMapFunction;
7 | import org.apache.flink.configuration.Configuration;
8 | import org.apache.flink.util.Collector;
9 | import sky.tf.GarbageClassificationModel;
10 | import sky.tf.ImageClassIndex;
11 | import sky.tf.ModelParams;
12 |
13 | /**
14 | * @author SkyPeace
15 | * The Map operator for predict the class of image.
16 | */
17 | public class ImageFlatMap3Models extends RichFlatMapFunction {
18 | private String imageModelPath = System.getenv(ConfigConstant.IMAGE_MODEL_PATH);
19 | private String imageModelPathPackagePath = System.getenv(ConfigConstant.IMAGE_MODEL_PACKAGE_PATH);
20 | private String modelInferencePath = System.getenv("MODEL_INFERENCE_PATH");
21 | private ModelParams model1Params;
22 | private ModelParams model2Params;
23 | private ModelParams model3Params;
24 | private transient GarbageClassificationModel model1;
25 | private transient GarbageClassificationModel model2;
26 | private transient GarbageClassificationModel model3;
27 | private transient ImagePredictSupportive3Models supportive;
28 |
29 | public ImageFlatMap3Models(ModelParams model1Params, ModelParams model2Params, ModelParams model3Params)
30 | {
31 | this.model1Params = model1Params;
32 | this.model2Params = model2Params;
33 | this.model3Params = model3Params;
34 | }
35 |
36 | @Override
37 | public void open(Configuration parameters) throws Exception
38 | {
39 | //For troubleshooting use.
40 | System.out.println(String.format("ImageFlatMap.open(): imageModelPath is %s", this.imageModelPath));
41 | System.out.println(String.format("ImageFlatMap.open(): modelInferencePath is %s", this.modelInferencePath));
42 | System.out.println(String.format("ImageFlatMap.open(): imageModelPathPackagePath is %s", this.imageModelPathPackagePath));
43 | System.out.println(String.format("ImageFlatMap.open(): optimizedOpenVinoModel1Dir is %s",
44 | this.model1Params.getOptimizedModelDir()));
45 | System.out.println(String.format("ImageFlatMap.open(): optimizedOpenVinoModel2Dir is %s",
46 | this.model2Params.getOptimizedModelDir()));
47 | System.out.println(String.format("ImageFlatMap.open(): optimizedOpenVinoModel3Dir is %s",
48 | this.model3Params.getOptimizedModelDir()));
49 |
50 | //Step2: Load optimized OpenVino model from files (HDFS). (Cost about 3 seconds total models, quickly enough)
51 | this.model1 = ImageModelLoader3Models.getInstance().loadOpenVINOModel1Once(this.model1Params);
52 | this.model2 = ImageModelLoader3Models.getInstance().loadOpenVINOModel2Once(this.model2Params);
53 | this.model3 = ImageModelLoader3Models.getInstance().loadOpenVINOModel3Once(this.model3Params);
54 |
55 | ImageClassIndex.getInsatnce().loadClassIndexMap(modelInferencePath);
56 | ValidationDatasetAnalyzer validationTrainer = new ValidationDatasetAnalyzer();
57 | validationTrainer.searchOptimizedThreshold(modelInferencePath);
58 | float preferableThreshold1 = validationTrainer.getPreferableThreshold1();
59 | float preferableThreshold2 = validationTrainer.getPreferableThreshold2();
60 | this.supportive = new ImagePredictSupportive3Models(this.model1, this.model1Params,
61 | this.model2, this.model2Params, this.model3, this.model3Params,
62 | preferableThreshold1, preferableThreshold2);
63 | //First time warm-dummy check
64 | this.supportive.firstTimeDummyCheck();
65 | }
66 |
67 | @Override
68 | public void flatMap(ImageData value, Collector out)
69 | throws Exception
70 | {
71 | IdLabel idLabel = new IdLabel();
72 | idLabel.setId(value.getId());
73 |
74 | long beginTime = System.currentTimeMillis();
75 | System.out.println(String.format("PREDICT_PROCESS BEGIN %s", beginTime));
76 |
77 | String imageLabelString = supportive.predictHumanString(value);
78 |
79 | long endTime = System.currentTimeMillis();
80 | System.out.println(String.format("PREDICT_PROCESS END %s (Cost: %s)", endTime, (endTime - beginTime)));
81 |
82 | //Check whether elapsed time >= threshold. Logging it for review.
83 | if((endTime - beginTime)>495)
84 | System.out.println(String.format("PREDICT_PROCESS MAYBE EXCEED THRESHOLD %s (Cost: %s)",
85 | endTime, (endTime - beginTime)));
86 |
87 | idLabel.setLabel(imageLabelString);
88 | out.collect(idLabel);
89 | }
90 |
91 | @Override
92 | public void close() throws Exception
93 | {
94 | System.out.println(String.format("getNumberOfParallelSubtasks: %s",
95 | this.getRuntimeContext().getNumberOfParallelSubtasks()));
96 | if(model1!=null)
97 | model1.release();
98 | if(model2!=null)
99 | model2.release();
100 | if(model3!=null)
101 | model3.release();
102 | }
103 |
104 | }
105 |
106 |
--------------------------------------------------------------------------------
/src/main/java/sky/tf/threemodels/ImageModelLoader3Models.java:
--------------------------------------------------------------------------------
1 | package sky.tf.threemodels;
2 |
3 | import com.google.common.io.Files;
4 | import com.intel.analytics.zoo.pipeline.inference.OpenVinoInferenceSupportive$;
5 | import org.apache.commons.io.FileUtils;
6 | import org.apache.flink.api.java.tuple.Tuple2;
7 | import org.apache.hadoop.fs.*;
8 | import org.apache.hadoop.hdfs.DistributedFileSystem;
9 | import org.slf4j.Logger;
10 | import org.slf4j.LoggerFactory;
11 | import sky.tf.GarbageClassificationModel;
12 | import sky.tf.ModelParams;
13 |
14 | import java.io.*;
15 | import java.net.URI;
16 | import java.nio.channels.Channels;
17 | import java.nio.channels.FileChannel;
18 | import java.nio.channels.ReadableByteChannel;
19 | import java.util.ArrayList;
20 | import java.util.List;
21 |
22 |
23 | /**
24 | * @author SkyPeace
25 | * The model loader
26 | */
27 | public class ImageModelLoader3Models {
28 | private Logger logger = LoggerFactory.getLogger(ImageModelLoader3Models.class);
29 | private static ImageModelLoader3Models instance = new ImageModelLoader3Models();
30 | private volatile GarbageClassificationModel model1;
31 | private volatile GarbageClassificationModel model2;
32 | private volatile GarbageClassificationModel model3;
33 |
34 | private ImageModelLoader3Models() {}
35 |
36 | public static ImageModelLoader3Models getInstance()
37 | {
38 | return instance;
39 | }
40 |
41 | /**
42 | * Generate optimized OpenVino model's data (bytes). -- First step.
43 | * @param savedModelPath
44 | * @return
45 | */
46 | public List> generateOpenVinoModelData(String savedModelPath, ModelParams modelParams) throws Exception
47 | {
48 | List> modelData = new ArrayList>();
49 | File modelFile = new File(savedModelPath);
50 | String optimizeModelPath = null;
51 | if(modelFile.isDirectory())
52 | optimizeModelPath =
53 | optimizeModelFromModelDir(savedModelPath, modelParams);
54 | else
55 | optimizeModelPath =
56 | optimizeModelFromModelPackage(savedModelPath, modelParams);
57 | byte[] xml = readDFSFile(optimizeModelPath + File.separator + "saved_model.xml");
58 | byte[] bin = readDFSFile(optimizeModelPath + File.separator + "saved_model.bin");
59 | logger.info("Size of optimized saved_model.xml: " + xml.length);
60 | logger.info("Size of optimized saved_model.bin: " + bin.length);
61 | Tuple2 xmlData = new Tuple2("xml", xml);
62 | Tuple2 binData = new Tuple2("bin", bin);
63 | modelData.add(xmlData);
64 | modelData.add(binData);
65 | return modelData;
66 | }
67 |
68 | /**
69 | * Optimize model from saved model dir. Return optimized model temp directory.
70 | * @param savedModelDir
71 | * @return
72 | * @throws Exception
73 | */
74 | private String optimizeModelFromModelDir(String savedModelDir, ModelParams modelParams) throws Exception
75 | {
76 | File tmpDir = Files.createTempDir();
77 | String optimizeModelTmpDir = tmpDir.getCanonicalPath();
78 | //OpenVinoInferenceSupportive.optimizeTFImageClassificationModel();
79 | OpenVinoInferenceSupportive$.MODULE$.optimizeTFImageClassificationModel(
80 | savedModelDir + File.separator + "SavedModel", modelParams.getInputShape(), false,
81 | modelParams.getMeanValues(), modelParams.getScale(), modelParams.getInputName(), optimizeModelTmpDir);
82 | return optimizeModelTmpDir;
83 | }
84 |
85 | /**
86 | * Optimize model from saved model package (tar.gz file). Return optimized model temp dir.
87 | * @param savedModelPackagePath
88 | * @return
89 | * @throws Exception
90 | */
91 | private String optimizeModelFromModelPackage(String savedModelPackagePath, ModelParams modelParams) throws Exception
92 | {
93 | byte[] savedModelBytes = readDFSFile(savedModelPackagePath);
94 | File tmpDir = Files.createTempDir();
95 | String tempDirPath = tmpDir.getCanonicalPath();
96 | String tarFileName = "saved-model.tar";
97 | File tarFile = new File(tempDirPath + File.separator + tarFileName);
98 | ByteArrayInputStream tarFileInputStream = new ByteArrayInputStream(savedModelBytes);
99 | ReadableByteChannel tarFileSrc = Channels.newChannel(tarFileInputStream);
100 | FileChannel tarFileDest = (new FileOutputStream(tarFile)).getChannel();
101 | tarFileDest.transferFrom(tarFileSrc, 0L, 9223372036854775807L);
102 | tarFileDest.close();
103 | tarFileSrc.close();
104 | String tarFileAbsolutePath = tarFile.getAbsolutePath();
105 | String modelRootDir = tempDirPath + File.separator + "saved-model";
106 | File modelRootDirFile = new File(modelRootDir);
107 | FileUtils.forceMkdir(modelRootDirFile);
108 | //tar -xvf -C
109 | Process proc = Runtime.getRuntime().exec(new String[]{"tar", "-xvf", tarFileAbsolutePath, "-C", modelRootDir});
110 | //Runtime.getRuntime().exec(new String[]{"ls", "-l", modelRootDir});
111 | BufferedReader insertReader = new BufferedReader(new InputStreamReader(proc.getInputStream()));
112 | BufferedReader errorReader = new BufferedReader(new InputStreamReader(proc.getErrorStream()));
113 | proc.waitFor();
114 | if(insertReader!=null)
115 | insertReader.close();
116 | if(errorReader!=null)
117 | errorReader.close();
118 | File[] files = modelRootDirFile.listFiles();
119 | String savedModelTmpDir = files[0].getAbsolutePath();
120 | logger.info("Saved model temp dir will be used for optimization: " + savedModelTmpDir);
121 | String optimizeModelTmpDir = optimizeModelFromModelDir(savedModelTmpDir, modelParams);
122 | return optimizeModelTmpDir;
123 | }
124 |
125 | /**
126 | * Save optimized OpenVino model's data (bytes) into HDFS files
127 | * The specified dir is also the parent dir of saved model package.
128 | * @param openVinoModelData
129 | * @param modelParams
130 | * @return
131 | */
132 | public void saveToOpenVinoModelFile(List> openVinoModelData, ModelParams modelParams) throws Exception
133 | {
134 | writeOptimizedModelToDFS(openVinoModelData, modelParams);
135 | }
136 |
137 | /**
138 | * Load OpenVino model with singleton pattern. -- Second step.
139 | * In this case, use this solution by default because it only cost about 2 seconds to load in Map.open().
140 | * @param modelParams
141 | * @return
142 | */
143 | public synchronized GarbageClassificationModel loadOpenVINOModel1Once(ModelParams modelParams) throws Exception
144 | {
145 | if(this.model1 == null)
146 | {
147 | this.model1 = this.getModel(modelParams);
148 | }
149 | this.model1.addRefernce();
150 | return this.model1;
151 | }
152 |
153 | public synchronized GarbageClassificationModel loadOpenVINOModel2Once(ModelParams modelParams) throws Exception
154 | {
155 | if(this.model2 == null)
156 | {
157 | this.model2 = this.getModel(modelParams);
158 | }
159 | this.model2.addRefernce();
160 | return this.model2;
161 | }
162 |
163 | public synchronized GarbageClassificationModel loadOpenVINOModel3Once(ModelParams modelParams) throws Exception
164 | {
165 | if(this.model3 == null)
166 | {
167 | this.model3 = this.getModel(modelParams);
168 | }
169 | this.model3.addRefernce();
170 | return this.model3;
171 | }
172 |
173 | private GarbageClassificationModel getModel(ModelParams modelParams) throws Exception
174 | {
175 | List> openVinoModelData =
176 | getModelDataFromOptimizedModelDir(modelParams);
177 | byte[] modelXml = openVinoModelData.get(0).f1;
178 | byte[] modelBin = openVinoModelData.get(1).f1;
179 | return new GarbageClassificationModel(modelXml, modelBin);
180 | }
181 |
182 | /**
183 | * Get OpenVino model data from optimized model files
184 | * @param modelParams
185 | * @return
186 | */
187 | private List> getModelDataFromOptimizedModelDir(ModelParams modelParams) throws Exception
188 | {
189 | String optimizedModelDir = modelParams.getOptimizedModelDir();
190 | String modelName = modelParams.getModelName();
191 | List> modelData = new ArrayList>();
192 | byte[] xml = readDFSFile(optimizedModelDir + File.separator + "optimized_openvino_" + modelName + "_KEEPME.xml");
193 | byte[] bin = readDFSFile(optimizedModelDir + File.separator + "optimized_openvino_" + modelName + "_KEEPME.bin");
194 | logger.info("Size of optimized_openvino_" + modelName + "_KEEPME.xml: " + xml.length);
195 | logger.info("Size of optimized_openvino_" + modelName + "_KEEPME.bin: " + bin.length);
196 | Tuple2 xmlData = new Tuple2("xml", xml);
197 | Tuple2 binData = new Tuple2("bin", bin);
198 | modelData.add(xmlData);
199 | modelData.add(binData);
200 | return modelData;
201 | }
202 |
203 | /**
204 | * Read saved model files (HDFS) into byte array.
205 | * @param filePath
206 | * @return
207 | */
208 | private byte[] readDFSFile(String filePath) throws Exception
209 | {
210 | try {
211 | long beginTime = System.currentTimeMillis();
212 | System.out.println(String.format("READ_MODEL_FILE from %s BEGIN %s", filePath, beginTime));
213 | Path imageRoot = new Path(filePath);
214 | org.apache.hadoop.conf.Configuration hadoopConfig = new org.apache.hadoop.conf.Configuration();
215 | hadoopConfig.set("fs.hdfs.impl", DistributedFileSystem.class.getName());
216 | hadoopConfig.set("fs.file.impl", LocalFileSystem.class.getName());
217 | FileSystem fileSystem = FileSystem.get(new URI(filePath), hadoopConfig);
218 | FileStatus fileStatus = fileSystem.getFileStatus(imageRoot);
219 | //RemoteIterator it = fileSystem.listFiles(imageRoot, false);
220 | long fileLength = fileStatus.getLen();
221 | FSDataInputStream in = fileSystem.open(imageRoot);
222 | byte[] buffer = new byte[(int) fileLength];
223 | in.readFully(buffer);
224 | in.close();
225 | long endTime = System.currentTimeMillis();
226 | System.out.println(String.format("READ_MODEL_FILE from %s END %s (Cost: %s)",
227 | filePath, endTime, (endTime - beginTime)));
228 | fileSystem.close();
229 | return buffer;
230 | }catch(Exception ex)
231 | {
232 | String msg = "Read DFS file FAILED. " + ex.getMessage();
233 | logger.error(msg, ex);
234 | throw ex;
235 | }
236 | }
237 |
238 | /**
239 | * Write the optimized model's data to files (HDFS)
240 | * @param openVinoModelData
241 | * @param modelParams
242 | */
243 | private void writeOptimizedModelToDFS(List> openVinoModelData, ModelParams modelParams) throws Exception
244 | {
245 | String optimizedModelDir = modelParams.getOptimizedModelDir();
246 | String modelName = modelParams.getModelName();
247 | try {
248 | long beginTime = System.currentTimeMillis();
249 | System.out.println(String.format("WRITE_OPTIMIZED_MODEL_FILE %s BEGIN %s", optimizedModelDir, beginTime));
250 | org.apache.hadoop.conf.Configuration hadoopConfig = new org.apache.hadoop.conf.Configuration();
251 | hadoopConfig.set("fs.hdfs.impl", DistributedFileSystem.class.getName());
252 | hadoopConfig.set("fs.file.impl", LocalFileSystem.class.getName());
253 | FileSystem fileSystem = FileSystem.get(new URI(optimizedModelDir), hadoopConfig);
254 |
255 | String xmlFileName = optimizedModelDir + File.separator + "optimized_openvino_"+ modelName + "_KEEPME.xml";
256 | System.out.println(String.format("Model xmlFileName: %s", xmlFileName));
257 | Path xmlFilePath = new Path(xmlFileName);
258 | FSDataOutputStream xmlFileOut = fileSystem.create(xmlFilePath, true);
259 | xmlFileOut.write(openVinoModelData.get(0).f1);
260 | xmlFileOut.flush();
261 | xmlFileOut.close();
262 |
263 | String binFileName = optimizedModelDir + File.separator + "optimized_openvino_"+ modelName + "_KEEPME.bin";
264 | System.out.println(String.format("Model binFileName: %s", xmlFileName));
265 | Path binFilePath = new Path(binFileName);
266 | FSDataOutputStream binFileOut = fileSystem.create(binFilePath, true);
267 | binFileOut.write(openVinoModelData.get(1).f1);
268 | binFileOut.flush();
269 | binFileOut.close();
270 |
271 | long endTime = System.currentTimeMillis();
272 | System.out.println(String.format("WRITE_OPTIMIZED_MODEL_FILE %s END %s (Cost: %s)",
273 | optimizedModelDir, endTime, (endTime - beginTime)));
274 | fileSystem.close();
275 | }catch(Exception ex)
276 | {
277 | String msg = "Write DFS file FAILED. " + ex.getMessage();
278 | logger.error(msg, ex);
279 | throw ex;
280 | }
281 | }
282 |
283 | }
284 |
--------------------------------------------------------------------------------
/src/main/java/sky/tf/threemodels/ImagePredictSupportive3Models.java:
--------------------------------------------------------------------------------
1 | package sky.tf.threemodels;
2 |
3 | import com.alibaba.tianchi.garbage_image_util.ImageData;
4 | import com.intel.analytics.bigdl.opencv.OpenCV;
5 | import com.intel.analytics.zoo.pipeline.inference.JTensor;
6 | import org.opencv.core.*;
7 | import org.slf4j.Logger;
8 | import org.slf4j.LoggerFactory;
9 | import sky.tf.*;
10 |
11 | import java.util.*;
12 |
13 | /**
14 | * @author SkyPeace
15 | * The supportive for image classification prediction.
16 | */
17 | public class ImagePredictSupportive3Models
18 | {
19 | private Logger logger = LoggerFactory.getLogger(ImagePredictSupportive3Models.class);
20 | final long MAX_PROCESSING_TIME = 400; //ms. 400
21 |
22 | private GarbageClassificationModel model1;
23 | private ModelParams model1Params;
24 | private GarbageClassificationModel model2;
25 | private ModelParams model2Params;
26 | private GarbageClassificationModel model3;
27 | private ModelParams model3Params;
28 |
29 | private float preferableThreshold1 = 0.94f;
30 | private float preferableThreshold2 = 0.81f;
31 |
32 | public ImagePredictSupportive3Models(GarbageClassificationModel model1, ModelParams model1Params,
33 | GarbageClassificationModel model2, ModelParams model2Params,
34 | GarbageClassificationModel model3, ModelParams model3Params,
35 | float preferableThreshold1, float preferableThreshold2)
36 | {
37 | this.model1 = model1;
38 | this.model1Params = model1Params;
39 | this.model2 = model2;
40 | this.model2Params = model2Params;
41 | this.model3 = model3;
42 | this.model3Params = model3Params;
43 | this.preferableThreshold1 = preferableThreshold1;
44 | this.preferableThreshold2 = preferableThreshold2;
45 | }
46 |
47 | /**
48 | * Do predict and transfer the label ID to human string.
49 | * @param imageData
50 | * @return
51 | * @throws Exception
52 | */
53 | public String predictHumanString(ImageData imageData) throws Exception
54 | {
55 | Integer predictId = predict(imageData);
56 | String humanString = this.getImageHumanString(predictId);
57 | if (humanString == null) {
58 | humanString = "class_index_not_available";
59 | throw new Exception(humanString);
60 | }
61 | return humanString;
62 | }
63 |
64 | /**
65 | * Do predict using 3 different models
66 | * @param imageData
67 | * @return
68 | * @throws Exception
69 | */
70 | public Integer predict(ImageData imageData) throws Exception {
71 |
72 | long processBeginTime = System.currentTimeMillis();
73 | //Decode jpeg
74 | //Use turbojpeg as jpeg decoder. It is more fast than OpenCV decoder (OpenCV use libjpeg as decoder).
75 | //Refer to https://libjpeg-turbo.org/About/Performance for performance comparison.
76 | Mat matRGB = ImageDataPreprocessing.getRGBMat(imageData, ImageDataPreprocessing.IMAGE_DECODER_TURBOJPEG);
77 |
78 | //Model1 do prediction
79 | ImageDataPreprocessing dataPreprocessing = new ImageDataPreprocessing(model1Params);
80 | List> inputs = dataPreprocessing.doPreProcessing(matRGB, ImageDataPreprocessing.PREPROCESSING_VGG, false);
81 | long t1 = System.currentTimeMillis();
82 | System.out.println(String.format("%s MODEL1_DO_PREDICT BEGIN %s", "###", t1));
83 | List> tensorResults = this.model1.predict(inputs);
84 | if (tensorResults == null && tensorResults.get(0) == null || tensorResults.get(0).size() == 0) {
85 | throw new Exception(String.format("ERROR: %s Model1 predict result is null.", imageData.getId()));
86 | }
87 | long t2 = System.currentTimeMillis();
88 | System.out.println(String.format("%s MODEL1_DO_PREDICT END %s (Cost: %s)", "###", t2, (t2 - t1)));
89 |
90 | List primaryResults = this.convertToPredictionResults(tensorResults);
91 | PredictionResult primaryResult = this.getSingleResult(primaryResults);
92 | Integer primaryPredictionId = primaryResult.getPredictionId();
93 |
94 | if (primaryResult.getProbability() >= preferableThreshold1) {
95 | System.out.println("Model1's HP predictionId saved time.");
96 | return primaryPredictionId;
97 | }
98 |
99 | long timeUsed = System.currentTimeMillis() - processBeginTime;
100 | if(timeUsed>=MAX_PROCESSING_TIME)
101 | {
102 | System.out.println("There maybe no enough time to try multiple modles to do predict.");
103 | return primaryPredictionId;
104 | }
105 |
106 | System.out.println("There is enough time to try multiple modles to do predict.");
107 |
108 | //Use model2 do prediction
109 | dataPreprocessing = new ImageDataPreprocessing(model2Params);
110 | List> secondaryInputs = dataPreprocessing.doPreProcessing(matRGB, ImageDataPreprocessing.PREPROCESSING_INCEPTION, false);
111 | t1 = System.currentTimeMillis();
112 | System.out.println(String.format("%s MODEL2_DO_PREDICT BEGIN %s", "###", t1));
113 | List> secondaryTensorResults = this.model2.predict(secondaryInputs);
114 | if (secondaryTensorResults == null && secondaryTensorResults.get(0) == null || secondaryTensorResults.get(0).size() == 0) {
115 | throw new Exception(String.format("ERROR: %s Model2 predict result is null.", imageData.getId()));
116 | }
117 | t2 = System.currentTimeMillis();
118 | System.out.println(String.format("%s MODEL2_DO_PREDICT END %s (Cost: %s)", "###", t2, (t2 - t1)));
119 |
120 | List secondaryResults = this.convertToPredictionResults(secondaryTensorResults);
121 | PredictionResult secondaryResult = this.getSingleResult(secondaryResults);
122 | Integer secondaryPredictionId = secondaryResult.getPredictionId();
123 |
124 | if (secondaryResult.getProbability() >= preferableThreshold1)
125 | {
126 | if (!secondaryPredictionId.equals(primaryPredictionId)) {
127 | System.out.println("Force use model2's HP predictionId.");
128 | } else {
129 | System.out.println("Model2's HP predictionId saved time.");
130 | }
131 | return secondaryPredictionId;
132 | }
133 |
134 | //Use model3 do prediction
135 | dataPreprocessing = new ImageDataPreprocessing(model3Params);
136 | List> thirdInputs = dataPreprocessing.doPreProcessing(matRGB, ImageDataPreprocessing.PREPROCESSING_INCEPTION, false);
137 | t1 = System.currentTimeMillis();
138 | System.out.println(String.format("%s MODEL3_DO_PREDICT BEGIN %s", "###", t1));
139 | List> thirdTensorResults = this.model3.predict(thirdInputs);
140 | if (thirdTensorResults == null && thirdTensorResults.get(0) == null || thirdTensorResults.get(0).size() == 0) {
141 | throw new Exception(String.format("ERROR: %s Model3 predict result is null.", imageData.getId()));
142 | }
143 | t2 = System.currentTimeMillis();
144 | System.out.println(String.format("%s MODEL3_DO_PREDICT END %s (Cost: %s)", "###", t2, (t2 - t1)));
145 | List thirdResults = this.convertToPredictionResults(thirdTensorResults);
146 | PredictionResult thirdResult = this.getSingleResult(thirdResults);
147 | Integer thirdPredictionId = thirdResult.getPredictionId();
148 |
149 | /*
150 | if (thirdResult.getProbability() >= preferableThreshold1)
151 | {
152 | if (!thirdPredictionId.equals(primaryPredictionId)) {
153 | System.out.println("Force use model3's HP predictionId.");
154 | } else {
155 | System.out.println("Model3's HP predictionId saved time.");
156 | }
157 | return thirdPredictionId;
158 | }
159 | */
160 |
161 | if(thirdPredictionId.equals(primaryPredictionId)&&thirdPredictionId.equals(secondaryPredictionId))
162 | {
163 | System.out.println("Model1, model2 and model3 get agreement.");
164 | return primaryPredictionId;
165 | }
166 | if (primaryPredictionId.equals(secondaryPredictionId)) {
167 | System.out.println("Model1 and model2 get agreement.");
168 | return primaryPredictionId;
169 | }
170 | if (primaryPredictionId.equals(thirdPredictionId)) {
171 | System.out.println("Model1 and model3 get agreement.");
172 | return primaryPredictionId;
173 | }
174 | if (thirdPredictionId.equals(secondaryPredictionId)) {
175 | System.out.println("Model3 and model2 get agreement.");
176 | return secondaryPredictionId;
177 | }
178 | System.out.println("Model1, model2 and model3 does NOT get any agreement.");
179 |
180 | if (secondaryResult.getProbability() >= preferableThreshold2)
181 | {
182 | System.out.println("Model2(HP) is the last choice.");
183 | return secondaryPredictionId;
184 | }
185 | if(thirdResult.getProbability() >= preferableThreshold2){
186 | System.out.println("Model3(LP) is the last choice.");
187 | return thirdPredictionId;
188 | }
189 |
190 | System.out.println("Model1(LP) is the last choice.");
191 | return primaryPredictionId;
192 | }
193 |
194 | /**
195 | * Get single result
196 | * @param pResults
197 | * @return
198 | */
199 | private PredictionResult getSingleResult(List pResults)
200 | {
201 | if(pResults.size()==1)
202 | return pResults.get(0);
203 | System.out.println(String.format("There are %s multiple results", pResults.size()));
204 | Map distinctResults = new HashMap();
205 | for(PredictionResult pResult:pResults)
206 | {
207 | Integer predictionId = pResult.getPredictionId();
208 | PredictionResult existResult = distinctResults.get(predictionId);
209 | if(existResult == null){
210 | distinctResults.put(predictionId, pResult);
211 | }else{
212 | existResult.setCount(existResult.getCount() + 1);
213 | existResult.setProbability( Math.max(pResult.getProbability(), existResult.getProbability()));
214 | distinctResults.put(predictionId, existResult);
215 | }
216 | }
217 | PredictionResult maxResult = null;
218 | for(Integer key:distinctResults.keySet())
219 | {
220 | PredictionResult pResult = distinctResults.get(key);
221 | float ratio = (float)pResult.getCount() / pResults.size();
222 | if(ratio>0.5f)
223 | {
224 | System.out.println("Majority win.");
225 | maxResult = pResult;
226 | break;
227 | }
228 | }
229 | if(maxResult==null)
230 | maxResult = pResults.get(0);
231 | return maxResult;
232 | }
233 |
234 | /**
235 | * Convert JTensor list to prediciton results.
236 | * @param result
237 | * @return
238 | * @throws Exception
239 | */
240 | private List convertToPredictionResults(List> result) throws Exception
241 | {
242 | long beginTime = System.currentTimeMillis();
243 | List pResults = new ArrayList();
244 | System.out.println("result.size(): " + result.size());
245 | System.out.println("result.get(0).size(): " + result.get(0).size());
246 | for(int j=0; j 0) {
250 | float maxProbability = predictData[0];
251 | int maxNo = 0;
252 | for (int i = 1; i < predictData.length; i++) {
253 | if (predictData[i] > maxProbability) {
254 | maxProbability = predictData[i];
255 | maxNo = i;
256 | }
257 | }
258 | PredictionResult pResult = new PredictionResult();
259 | pResult.setPredictionId(maxNo);
260 | pResult.setProbability(maxProbability);
261 | pResult.setCount(1);
262 | pResults.add(pResult);
263 | }else{
264 | throw new Exception("ERROR: predictData.length=0");
265 | }
266 | }
267 | long endTime = System.currentTimeMillis();
268 | if((endTime-beginTime)>5)
269 | System.out.println(String.format("%s CONVERT_TO_PREDICTION_RESULT END %s (Cost: %s)",
270 | "###", endTime, (endTime - beginTime)));
271 | return pResults;
272 | }
273 |
274 | /**
275 | * Load turbojpeg library.
276 | * @throws Exception
277 | */
278 | private void loadTurboJpeg()
279 | {
280 | long beginTime = System.currentTimeMillis();
281 | if(!TurboJpegLoader.isTurbojpegLoaded()) {
282 | throw new RuntimeException("LOAD_TURBOJPEG library failed. Please check.");
283 | }
284 | long endTime = System.currentTimeMillis();
285 | if((endTime - beginTime) > 1)
286 | System.out.println(String.format("LOAD_TURBOJPEG END %s (Cost: %s)", endTime, (endTime - beginTime)));
287 | }
288 |
289 | /**
290 | * Load OpenCV library.
291 | * @throws Exception
292 | */
293 | private void loadOpenCV()
294 | {
295 | long beginTime = System.currentTimeMillis();
296 | if(!OpenCV.isOpenCVLoaded()) {
297 | throw new RuntimeException("LOAD_OPENCV library failed. Please check.");
298 | }
299 | long endTime = System.currentTimeMillis();
300 | if((endTime - beginTime) > 10)
301 | System.out.println(String.format("LOAD_OPENCV END %s (Cost: %s)", endTime, (endTime - beginTime)));
302 |
303 | //if(!OpenvinoNativeLoader.load())
304 | // throw new RuntimeException("LOAD_Openvino library failed. Please check.");
305 | //if(!TFNetNative.isLoaded())
306 | // throw new RuntimeException("LOAD_TFNetNative library failed. Please check.");
307 | }
308 |
309 | /**
310 | * Get image class human string
311 | * @return
312 | */
313 | private String getImageHumanString(Integer id)
314 | {
315 | return ImageClassIndex.getInsatnce().getImageHumanstring(String.valueOf(id));
316 | }
317 |
318 | /**
319 | * For first time warm-dummy check only.
320 | * @throws Exception
321 | */
322 | public void firstTimeDummyCheck() throws Exception
323 | {
324 | this.checkModel(model1, model1Params);
325 | if(this.model2!=null)
326 | this.checkModel(model2, model2Params);
327 | if(this.model3!=null)
328 | this.checkModel(model3, model3Params);
329 | this.getImageHumanString(-1);
330 | this.loadOpenCV();
331 | this.loadTurboJpeg();
332 | }
333 |
334 | /**
335 | * For first time warm-dummy check only.
336 | * @param pModel
337 | * @param pModelParams
338 | * @throws Exception
339 | */
340 | private void checkModel(GarbageClassificationModel pModel, ModelParams pModelParams) throws Exception
341 | {
342 | if(pModel == null) {
343 | throw new Exception("ERROR: The model is null. Aborted the predict.");
344 | }
345 | long beginTime = System.currentTimeMillis();
346 | System.out.println(String.format("FIRST_PREDICT BEGIN %s", beginTime));
347 |
348 | JTensor tensor = new JTensor();
349 | tensor.setData(new float[pModelParams.getInputSize()]);
350 | tensor.setShape(pModelParams.getInputShape());
351 | List list = new ArrayList();
352 | list.add(tensor);
353 | List> inputs = new ArrayList>();
354 | inputs.add(list);
355 | pModel.predict(inputs);
356 |
357 | long endTime = System.currentTimeMillis();
358 | System.out.println(String.format("FIRST_PREDICT END %s (Cost: %s)", endTime, (endTime - beginTime)));
359 | }
360 |
361 | }
362 |
--------------------------------------------------------------------------------
/src/main/java/sky/tf/threemodels/OpenVinoModelGenerator3Models.java:
--------------------------------------------------------------------------------
1 | package sky.tf.threemodels;
2 |
3 | import com.alibaba.tianchi.garbage_image_util.ConfigConstant;
4 | import org.apache.flink.api.java.tuple.Tuple2;
5 | import sky.tf.ModelParams;
6 |
7 | import java.io.File;
8 | import java.util.List;
9 |
10 | /**
11 | * The class use for generate optimized OpenVino models (Three type of models for prediction)
12 | * @author SkyPeace
13 | */
14 | public class OpenVinoModelGenerator3Models {
15 | private String imageModelPath = System.getenv(ConfigConstant.IMAGE_MODEL_PATH);
16 | private String imageModelPackagePath = System.getenv(ConfigConstant.IMAGE_MODEL_PACKAGE_PATH);
17 | private String modelInferencePath = System.getenv("MODEL_INFERENCE_PATH");
18 | private ModelParams model1Params;
19 | private ModelParams model2Params;
20 | private ModelParams model3Params;
21 |
22 | public void execute()
23 | {
24 | if (null == this.modelInferencePath) {
25 | throw new RuntimeException("ImageFlatMap(): Not set MODEL_INFERENCE_PATH environmental variable");
26 | }
27 | if (null == this.imageModelPackagePath) {
28 | throw new RuntimeException("ImageFlatMap(): Not set imageModelPathPackagePath environmental variable");
29 | }
30 | String imageModelPackageDir =
31 | imageModelPackagePath.substring(0, imageModelPackagePath.lastIndexOf(File.separator));
32 | if(!imageModelPackageDir.equalsIgnoreCase(modelInferencePath))
33 | {
34 | System.out.println("WARN: modelInferencePath NOT EQUAL imageModelPathPackageDir");
35 | System.out.println(String.format("modelInferencePath: %s", modelInferencePath));
36 | System.out.println(String.format("imageModelPackageDir: %s", imageModelPackageDir));
37 | System.out.println(String.format("imageModelPath: %s", imageModelPath));
38 | }
39 |
40 | try {
41 | //Generate optimized OpenVino model1. resnet_v1_101
42 | String tfModel1Path = modelInferencePath + File.separator + "SavedModel/model1";
43 | String optimizedOpenVinoModel1Dir = imageModelPackageDir;
44 | model1Params = new ModelParams();
45 | model1Params.setModelType("resnet");
46 | model1Params.setModelName("model1");
47 | model1Params.setInputName("input_1");
48 | model1Params.setInputShape(new int[]{1, 224, 224, 3});
49 | model1Params.setMeanValues(new float[]{123.68f,116.78f,103.94f});
50 | model1Params.setScale(1.0f);
51 | model1Params.setOptimizedModelDir(optimizedOpenVinoModel1Dir);
52 | //Call Zoo API generate optimized OpenVino Model
53 | List> optimizedOpenVinoModelData =
54 | ImageModelLoader3Models.getInstance().generateOpenVinoModelData(tfModel1Path, model1Params);
55 | //Write optimized model's bytes into files (HDFS). The optimized model files parent dir is same as TF model package.
56 | ImageModelLoader3Models.getInstance().
57 | saveToOpenVinoModelFile(optimizedOpenVinoModelData, model1Params);
58 |
59 | //Generate optimized OpenVino model2. inception_v4
60 | String tfModel2Path = modelInferencePath + File.separator + "SavedModel/model2";
61 | String optimizedOpenVinoMode2Dir = imageModelPackageDir;
62 | model2Params = new ModelParams();
63 | model2Params.setModelType("inception");
64 | model2Params.setModelName("model2");
65 | model2Params.setInputName("input_1");
66 | model2Params.setInputShape(new int[]{1, 299, 299, 3});
67 | model2Params.setMeanValues(new float[]{127.5f,127.5f,127.5f});
68 | model2Params.setScale(127.5f);
69 | model2Params.setOptimizedModelDir(optimizedOpenVinoMode2Dir);
70 | //Call Zoo API generate optimized OpenVino Model
71 | optimizedOpenVinoModelData =
72 | ImageModelLoader3Models.getInstance().generateOpenVinoModelData(tfModel2Path, model2Params);
73 | ImageModelLoader3Models.getInstance().
74 | saveToOpenVinoModelFile(optimizedOpenVinoModelData, model2Params);
75 |
76 | //Generate optimized OpenVino model3. inception_v3
77 | String tfModel3Path = this.modelInferencePath + File.separator + "SavedModel/model3";
78 | String optimizedOpenVinoMode3Dir = imageModelPackageDir;
79 | model3Params = new ModelParams();
80 | model3Params.setModelType("inception");
81 | model3Params.setModelName("model3");
82 | model3Params.setInputName("input_1");
83 | model3Params.setInputShape(new int[]{1, 299, 299, 3});
84 | model3Params.setMeanValues(new float[]{127.5f,127.5f,127.5f});
85 | model3Params.setScale(127.5f);
86 | model3Params.setOptimizedModelDir(optimizedOpenVinoMode3Dir);
87 | //Call Zoo API generate optimized OpenVino Model
88 | optimizedOpenVinoModelData =
89 | ImageModelLoader3Models.getInstance().generateOpenVinoModelData(tfModel3Path, model3Params);
90 | ImageModelLoader3Models.getInstance().
91 | saveToOpenVinoModelFile(optimizedOpenVinoModelData, model3Params);
92 |
93 | }catch(Exception ex)
94 | {
95 | ex.printStackTrace();
96 | throw new RuntimeException("WARN: OpenVinoModelGenerator.execute() FAILED.");
97 | }
98 | }
99 |
100 | public ModelParams getModel1Params()
101 | {
102 | return model1Params;
103 | }
104 |
105 | public ModelParams getModel2Params()
106 | {
107 | return model2Params;
108 | }
109 |
110 | public ModelParams getModel3Params()
111 | {
112 | return model3Params;
113 | }
114 | }
115 |
--------------------------------------------------------------------------------
/src/main/java/sky/tf/threemodels/ValidationDatasetAnalyzer.java:
--------------------------------------------------------------------------------
1 | package sky.tf.threemodels;
2 |
3 | import sky.tf.PredictionResult;
4 |
5 | import java.io.BufferedReader;
6 | import java.io.File;
7 | import java.io.FileReader;
8 | import java.util.ArrayList;
9 | import java.util.List;
10 |
11 | /**
12 | * Experimental. Analysis validation data, search preferable probability thresholds for chose prediction result.
13 | */
14 | public class ValidationDatasetAnalyzer {
15 | private float preferableThreshold1 = 0.94f;
16 | private float preferableThreshold2 = 0.81f;
17 |
18 | public void searchOptimizedThreshold(String validationPredictionDir)
19 | {
20 | String model1ValidationFilePath = validationPredictionDir + File.separator + "model1_validation_record.txt";
21 | String model2ValidationFilePath = validationPredictionDir + File.separator + "model2_validation_record.txt";
22 | String model3ValidationFilePath = validationPredictionDir + File.separator + "model3_validation_record.txt";
23 | List labelList = getValidationLabels(model1ValidationFilePath);
24 | List model1Results = getPredictionResults(model1ValidationFilePath);
25 | List model2Results = getPredictionResults(model2ValidationFilePath);
26 | List model3Results = getPredictionResults(model3ValidationFilePath);
27 | if(labelList.size()==0) {
28 | System.out.println("WARN: Validation labels count is 0.");
29 | return;
30 | }
31 | int model1Success=0;
32 | int model2Success=0;
33 | int model3Success=0;
34 | for (int i = 0; i < labelList.size(); i++) {
35 | if (labelList.get(i).equals(model1Results.get(i).getPredictionId())) {
36 | model1Success++;
37 | }
38 | if (labelList.get(i).equals(model2Results.get(i).getPredictionId())) {
39 | model2Success++;
40 | }
41 | if (labelList.get(i).equals(model3Results.get(i).getPredictionId())) {
42 | model3Success++;
43 | }
44 | }
45 | float model1SucessRatio = (float) model1Success / (float)labelList.size();
46 | float model2SucessRatio = (float) model2Success / (float)labelList.size();
47 | float model3SucessRatio = (float) model3Success / (float)labelList.size();
48 | System.out.println("The model1 success ratio: " + model1SucessRatio);
49 | System.out.println("The model2 success ratio: " + model2SucessRatio);
50 | System.out.println("The model3 success ratio: " + model3SucessRatio);
51 |
52 | float totalRatio = model1SucessRatio*model2SucessRatio*model3SucessRatio +
53 | model1SucessRatio*model2SucessRatio*(1f-model3SucessRatio) +
54 | model1SucessRatio*(1f-model2SucessRatio)*model3SucessRatio +
55 | (1f-model1SucessRatio)*model2SucessRatio*model3SucessRatio;
56 | System.out.println("The total success ratio: " + totalRatio);
57 | long t1 = System.currentTimeMillis();
58 | float tmaxThreshold1 = 0f;
59 | float tmaxThreshold2 = 0f;
60 | float maxSuccessRatio = 0f;
61 | for (int j = 0; j < 20; j++) {
62 | float threshold1 = 0.01f * j + 0.80f;
63 | for (int k = 0; k < 80; k++) {
64 | float threshold2 = 0.01f * k + 0.2f;
65 | int success = 0;
66 | List finalResults = new ArrayList();
67 | for (int i = 0; i < labelList.size(); i++) {
68 | PredictionResult model1Result = model1Results.get(i);
69 | PredictionResult model2Result = model2Results.get(i);
70 | PredictionResult model3Result = model3Results.get(i);
71 |
72 | if (model1Result.getProbability() >= threshold1) {
73 | finalResults.add(model1Result.getPredictionId());
74 | continue;
75 | }
76 |
77 | if (model2Result.getProbability() >= threshold1)
78 | {
79 | finalResults.add(model2Result.getPredictionId());
80 | continue;
81 | }
82 |
83 | /*
84 | if (model3Result.getProbability() >= threshold1)
85 | {
86 | finalResults.add(model3Result.getPredictionId());
87 | continue;
88 | }
89 | */
90 |
91 | if (model1Result.getPredictionId().equals(model2Result.getPredictionId())
92 | && model1Result.getPredictionId().equals(model3Result.getPredictionId())) {
93 | finalResults.add(model1Result.getPredictionId());
94 | continue;
95 | }
96 | if (model2Result.getPredictionId().equals(model3Result.getPredictionId())) {
97 | finalResults.add(model2Result.getPredictionId());
98 | continue;
99 | }
100 | if (model1Result.getPredictionId().equals(model2Result.getPredictionId())) {
101 | finalResults.add(model1Result.getPredictionId());
102 | continue;
103 | }
104 | if (model1Result.getPredictionId().equals(model3Result.getPredictionId())) {
105 | finalResults.add(model1Result.getPredictionId());
106 | continue;
107 | }
108 |
109 | if (model2Result.getProbability() >= threshold2)
110 | {
111 | finalResults.add(model2Result.getPredictionId());
112 | continue;
113 | }
114 | if (model3Result.getProbability() >= threshold2) {
115 | finalResults.add(model3Result.getPredictionId());
116 | continue;
117 | }
118 |
119 | finalResults.add(model1Result.getPredictionId());
120 |
121 | }
122 | for (int i = 0; i < labelList.size(); i++) {
123 | if (labelList.get(i).equals(finalResults.get(i))) {
124 | success++;
125 | }
126 | }
127 | float successRatio = (float) success / (float)labelList.size();
128 | if(successRatio > maxSuccessRatio) {
129 | maxSuccessRatio = successRatio;
130 | tmaxThreshold2 = threshold2;
131 | tmaxThreshold1 = threshold1;
132 | }
133 | }
134 | }
135 | long t2 = System.currentTimeMillis();
136 | System.out.println("Val elapsed time: " + (t2-t1));
137 | System.out.println("Max threshold1: " + tmaxThreshold1);
138 | System.out.println("Max threshold2: " + tmaxThreshold2);
139 | System.out.println("Max successRation: " + maxSuccessRatio);
140 | this.preferableThreshold1 = tmaxThreshold1;
141 | this.preferableThreshold2 = tmaxThreshold2;
142 | }
143 |
144 | /**
145 | * Get validation data prediction results.
146 | */
147 | private List getPredictionResults(String filePath)
148 | {
149 | List pResults = new ArrayList();
150 | try {
151 | FileReader fr = new FileReader(filePath);
152 | BufferedReader br = new BufferedReader(fr);
153 | String line;
154 | int count1 = 0;
155 | while ((line = br.readLine()) != null) {
156 | String columns[] = line.split(",");
157 | //Format like as: 69 XXX => 69:YYY(P=0.74892)
158 | String aString = columns[0];
159 | String aStrings[] = aString.split("=>");
160 | String src = aStrings[0].replace(" ","");
161 | String tgt = aStrings[1].trim();
162 | //tgt = tgt.replace(":", "");
163 | String pString = tgt.substring(tgt.indexOf("P=") +2 );
164 | tgt = tgt.substring(0, tgt.indexOf(":"));
165 | pString = pString.replace(")","");
166 | float P = Float.parseFloat(pString);
167 | PredictionResult result = new PredictionResult();
168 | result.setPredictionId(Integer.valueOf(tgt));
169 | result.setProbability(P);
170 | pResults.add(result);
171 | count1++;
172 | }
173 | //System.out.println("Count1: " + count1);
174 | fr.close();
175 | }catch (Exception ex)
176 | {
177 | String errMsg = "Read model error index file FAILED. " + ex.getMessage();
178 | System.out.println("WRAN: " + errMsg);
179 | //throw new RuntimeException(errMsg);
180 | }
181 | return pResults;
182 | }
183 |
184 | /**
185 | * Get validation dataset labels.
186 | * @param filePath
187 | * @return
188 | */
189 | private List getValidationLabels(String filePath)
190 | {
191 | List pLabels = new ArrayList();
192 | try {
193 | FileReader fr = new FileReader(filePath);
194 | BufferedReader br = new BufferedReader(fr);
195 | String line;
196 | int count1 = 0;
197 | while ((line = br.readLine()) != null) {
198 | String columns[] = line.split(",");
199 | //Format like as: 69 XXX => 69:YYY(P=0.74892)
200 | String aString = columns[0];
201 | String sourceId = aString.substring(0, aString.indexOf(" "));
202 | pLabels.add(Integer.valueOf(sourceId));
203 | }
204 | System.out.println("Count1: " + count1);
205 | fr.close();
206 | }catch (Exception ex)
207 | {
208 | String errMsg = "Read model error index file FAILED. " + ex.getMessage();
209 | System.out.println("WRAN: " + errMsg);
210 | //throw new RuntimeException(errMsg);
211 | }
212 | return pLabels;
213 | }
214 |
215 | public float getPreferableThreshold1()
216 | {
217 | return this.preferableThreshold1;
218 | }
219 | public float getPreferableThreshold2()
220 | {
221 | return this.preferableThreshold2;
222 | }
223 |
224 | public static void main(String args[])
225 | {
226 | ValidationDatasetAnalyzer validationDatasetTrainer = new ValidationDatasetAnalyzer();
227 | validationDatasetTrainer.searchOptimizedThreshold("/Users/xyz/PycharmProjects/Inference");
228 | System.out.println(validationDatasetTrainer.getPreferableThreshold1());
229 | System.out.println(validationDatasetTrainer.getPreferableThreshold2());
230 | }
231 |
232 | }
233 |
--------------------------------------------------------------------------------
/src/main/resources/log4j.properties:
--------------------------------------------------------------------------------
1 | ################################################################################
2 | # Licensed to the Apache Software Foundation (ASF) under one
3 | # or more contributor license agreements. See the NOTICE file
4 | # distributed with this work for additional information
5 | # regarding copyright ownership. The ASF licenses this file
6 | # to you under the Apache License, Version 2.0 (the
7 | # "License"); you may not use this file except in compliance
8 | # with the License. You may obtain a copy of the License at
9 | #
10 | # http://www.apache.org/licenses/LICENSE-2.0
11 | #
12 | # Unless required by applicable law or agreed to in writing, software
13 | # distributed under the License is distributed on an "AS IS" BASIS,
14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 | # See the License for the specific language governing permissions and
16 | # limitations under the License.
17 | ################################################################################
18 |
19 | log4j.rootLogger=INFO, console
20 |
21 | log4j.appender.console=org.apache.log4j.ConsoleAppender
22 | log4j.appender.console.layout=org.apache.log4j.PatternLayout
23 | #log4j.appender.console.layout.ConversionPattern=%d{HH:mm:ss,SSS} %-5p %-60c %x - %m%n
24 | log4j.appender.console.layout.ConversionPattern=[%p] %d{yyyy-MM-dd HH:mm:ss,SSS} %-c %m %n
25 |
--------------------------------------------------------------------------------