├── README.md ├── pom.xml └── src └── main ├── java └── sky │ └── tf │ ├── GarbageClassificationModel.java │ ├── IModel.java │ ├── ImageClassIndex.java │ ├── ImageClassificationMain.java │ ├── ImageDataPreprocessing.java │ ├── ModelParams.java │ ├── PredictionResult.java │ ├── TurboJpegLoader.java │ ├── multiplemodels │ ├── ImageFlatMapMultipleModels.java │ ├── ImageModelLoaderMultipleModels.java │ ├── ImagePredictSupportiveMultipleModels.java │ └── OpenVinoModelGeneratorMultipleModels.java │ └── threemodels │ ├── ImageFlatMap3Models.java │ ├── ImageModelLoader3Models.java │ ├── ImagePredictSupportive3Models.java │ ├── OpenVinoModelGenerator3Models.java │ └── ValidationDatasetAnalyzer.java └── resources └── log4j.properties /README.md: -------------------------------------------------------------------------------- 1 | # Apache Flink极客挑战赛——垃圾图片分类—复赛Java code 2 | 3 | # 1. How to build 4 | 5 | #### 1.1 安装 garbage_image_util 包到本地maven仓库 6 | ```bash 7 | sh install_jar.sh 8 | ``` 9 | 注:garbage_image_util 由天池提供,详见天池比赛的readme 10 | 11 | 12 | #### 1.2 build package (jar) 13 | ```bash 14 | mvn clean package 15 | ``` 16 | 17 | 18 | # 2.许可声明 19 | 你可以使用此代码用于学习和研究,但务必不要将此代码用于任何商业用途和比赛项目。 20 | -------------------------------------------------------------------------------- /pom.xml: -------------------------------------------------------------------------------- 1 | 19 | 21 | 4.0.0 22 | 23 | org.skypeace 24 | sky-flink-phase2 25 | 0.1 26 | jar 27 | 28 | Flink Quickstart Job 29 | https://tianchi.aliyun.com/ 30 | 31 | 32 | UTF-8 33 | 1.8.1 34 | 1.8 35 | 2.11 36 | 2.7.7 37 | ${java.version} 38 | ${java.version} 39 | 40 | 41 | 42 | 43 | apache.snapshots 44 | Apache Development Snapshot Repository 45 | https://repository.apache.org/content/repositories/snapshots/ 46 | 47 | false 48 | 49 | 50 | true 51 | 52 | 53 | 54 | ossrh 55 | sonatype repository 56 | https://oss.sonatype.org/content/groups/public/ 57 | 58 | true 59 | 60 | 61 | true 62 | 63 | 64 | 65 | central 66 | 
https://repo1.maven.org/maven2/ 67 | 68 | true 69 | 70 | 71 | true 72 | 73 | 74 | 75 | ome-releases 76 | https://artifacts.openmicroscopy.org/artifactory/ome.releases/ 77 | 78 | true 79 | 80 | 81 | true 82 | 83 | 84 | 85 | 86 | 87 | 88 | 89 | com.alibaba.tianchi 90 | garbage_image_util 91 | 1.0-SNAPSHOT 92 | 93 | 94 | xml-apis 95 | xml-apis 96 | 97 | 98 | 99 | 100 | ome 101 | turbojpeg 102 | 6.2.1 103 | 104 | 105 | 106 | 107 | org.apache.flink 108 | flink-java 109 | ${flink.version} 110 | provided 111 | 112 | 113 | org.apache.flink 114 | flink-streaming-java_${scala.binary.version} 115 | ${flink.version} 116 | provided 117 | 118 | 119 | org.apache.flink 120 | flink-core 121 | ${flink.version} 122 | provided 123 | 124 | 125 | org.scala-lang 126 | scala-library 127 | 2.11.11 128 | provided 129 | 130 | 131 | 132 | com.intel.analytics.zoo 133 | analytics-zoo-bigdl_0.9.0-spark_2.4.3 134 | 0.6.0-SNAPSHOT 135 | jar 136 | 137 | 138 | com.intel.analytics.zoo 139 | zoo-core-openvino-java-linux 140 | 0.6.0-SNAPSHOT 141 | 142 | 143 | com.intel.analytics.zoo 144 | zoo-core-mkl-linux 145 | 0.6.0-SNAPSHOT 146 | 147 | 148 | com.intel.analytics.zoo 149 | zoo-core-pmem-java-linux 150 | 0.6.0-SNAPSHOT 151 | 152 | 153 | 154 | 155 | org.apache.spark 156 | spark-core_2.11 157 | 2.4.3 158 | compile 159 | 160 | 161 | org.apache.spark 162 | spark-network-common_2.11 163 | 2.4.3 164 | runtime 165 | 166 | 167 | org.apache.spark 168 | spark-network-shuffle_2.11 169 | 2.4.3 170 | runtime 171 | 172 | 173 | org.apache.spark 174 | spark-mllib_2.11 175 | 2.4.3 176 | runtime 177 | 178 | 179 | 180 | org.slf4j 181 | slf4j-log4j12 182 | 1.7.7 183 | runtime 184 | 185 | 186 | log4j 187 | log4j 188 | 1.2.17 189 | runtime 190 | 191 | 192 | 193 | org.json4s 194 | json4s-core_2.11 195 | 3.5.3 196 | jar 197 | compile 198 | 199 | 200 | 201 | org.json4s 202 | json4s-jackson_2.11 203 | 3.5.3 204 | jar 205 | compile 206 | 207 | 208 | 209 | com.google.protobuf 210 | protobuf-java 211 | 3.9.1 212 | jar 213 | 
compile 214 | 215 | 216 | 217 | org.apache.hadoop 218 | hadoop-common 219 | ${hadoop.version} 220 | 221 | 222 | org.apache.hadoop 223 | hadoop-hdfs 224 | ${hadoop.version} 225 | 226 | 227 | org.apache.hadoop 228 | hadoop-client 229 | ${hadoop.version} 230 | 231 | 232 | 233 | 234 | 235 | 236 | 237 | 238 | 239 | org.apache.maven.plugins 240 | maven-compiler-plugin 241 | 3.1 242 | 243 | ${java.version} 244 | ${java.version} 245 | 246 | 247 | 248 | 249 | 250 | 251 | org.apache.maven.plugins 252 | maven-shade-plugin 253 | 3.0.0 254 | 255 | 256 | 257 | package 258 | 259 | shade 260 | 261 | 262 | 263 | 264 | org.apache.flink:* 265 | org.slf4j:* 266 | log4j:* 267 | com.google.code.findbugs:jsr305 268 | 269 | 270 | 271 | 272 | 274 | *:* 275 | 276 | META-INF/*.SF 277 | META-INF/*.DSA 278 | META-INF/*.RSA 279 | 280 | 281 | 282 | 283 | 284 | sky.tf.ImageClassificationMain 285 | 286 | 287 | 288 | 289 | 290 | 291 | 292 | 293 | 294 | 295 | 296 | 297 | 298 | org.eclipse.m2e 299 | lifecycle-mapping 300 | 1.0.0 301 | 302 | 303 | 304 | 305 | 306 | org.apache.maven.plugins 307 | maven-shade-plugin 308 | [3.0.0,) 309 | 310 | shade 311 | 312 | 313 | 314 | 315 | 316 | 317 | 318 | 319 | org.apache.maven.plugins 320 | maven-compiler-plugin 321 | [3.1,) 322 | 323 | testCompile 324 | compile 325 | 326 | 327 | 328 | 329 | 330 | 331 | 332 | 333 | 334 | 335 | 336 | 337 | 338 | 339 | 340 | 341 | 342 | 343 | 344 | add-dependencies-for-IDEA 345 | 346 | 347 | 348 | idea.version 349 | 350 | 351 | 352 | 353 | 354 | org.apache.flink 355 | flink-core 356 | ${flink.version} 357 | compile 358 | 359 | 360 | org.apache.flink 361 | flink-java 362 | ${flink.version} 363 | compile 364 | 365 | 366 | org.apache.flink 367 | flink-streaming-java_${scala.binary.version} 368 | ${flink.version} 369 | compile 370 | 371 | 372 | com.intel.analytics.zoo 373 | zoo-core-tfnet-linux 374 | 0.6.0-SNAPSHOT 375 | jar 376 | 377 | 378 | 379 | 380 | 381 | 382 | 383 | 
-------------------------------------------------------------------------------- /src/main/java/sky/tf/GarbageClassificationModel.java: -------------------------------------------------------------------------------- 1 | package sky.tf; 2 | 3 | import com.intel.analytics.zoo.pipeline.inference.*; 4 | 5 | import java.util.List; 6 | import java.util.concurrent.LinkedBlockingQueue; 7 | 8 | /** 9 | * @author SkyPeace 10 | * The models for predict. For this case, use Zoo OpenVINOModel directly instead of InferenceModel. 11 | */ 12 | public class GarbageClassificationModel implements IModel, java.io.Serializable 13 | { 14 | // Support Zoo InferenceModel and Zoo OpenVINOModel directly 15 | // For this case, the final decision is use Zoo OpenVINOModel directly instead of Zoo InferenceModel. 16 | private OpenVINOModel openVINOModel; 17 | private LinkedBlockingQueue referenceQueue = new LinkedBlockingQueue(12); 18 | 19 | public GarbageClassificationModel(byte[] modelXml, byte[] modelBin) 20 | { 21 | this.openVINOModel = this.loadOpenVINOModel(modelXml, modelBin); 22 | } 23 | 24 | private OpenVINOModel loadOpenVINOModel(byte[] modelXml, byte[] modelBin){ 25 | long beginTime = System.currentTimeMillis(); 26 | System.out.println(String.format("LOAD_OPENVION_MODEL_FROM_BYTES BEGIN %s", beginTime)); 27 | OpenVINOModel model = OpenVinoInferenceSupportive$.MODULE$.loadOpenVinoIR(modelXml, modelBin, 28 | com.intel.analytics.zoo.pipeline.inference.DeviceType.CPU(), 0); 29 | //OpenVinoInferenceSupportive.loadOpenVinoIRFromTempDir("modelName", "tempDir"); 30 | long endTime = System.currentTimeMillis(); 31 | System.out.println(String.format("LOAD_OPENVION_MODEL_FROM_BYTES END %s (Cost: %s)", 32 | endTime, (endTime - beginTime))); 33 | return model; 34 | } 35 | 36 | /** 37 | * Predict by inputs 38 | * @param inputs 39 | * @return 40 | * @throws Exception 41 | */ 42 | public List> predict(List> inputs) throws Exception { 43 | if(openVINOModel!=null) 44 | return 
openVINOModel.predict(inputs); 45 | else if(inferenceModel!=null) 46 | return inferenceModel.doPredict(inputs); 47 | else 48 | throw new RuntimeException("inferenceModel and openVINOModel are both null."); 49 | } 50 | 51 | public void addRefernce() throws InterruptedException 52 | { 53 | referenceQueue.put(1); 54 | } 55 | 56 | /** 57 | * Release model 58 | * @throws Exception 59 | */ 60 | public synchronized void release() throws Exception 61 | { 62 | if(referenceQueue.peek()==null) 63 | return; 64 | referenceQueue.poll(); 65 | if(referenceQueue.peek()==null) 66 | { 67 | if (openVINOModel != null) { 68 | System.out.println("Release openVINOModel ..."); 69 | openVINOModel.release(); 70 | System.out.println("openVINOModel released"); 71 | }else if (inferenceModel != null) { 72 | System.out.println("Release inferenceModel ..."); 73 | inferenceModel.doRelease(); 74 | System.out.println("inferenceModel released"); 75 | } 76 | else 77 | throw new RuntimeException("openVINOModel and inferenceModel are both null."); 78 | } 79 | } 80 | 81 | 82 | //Below code is reserved for test. Do not delete them. 
83 | private InferenceModel inferenceModel; 84 | /** 85 | * Reserved for experimental test 86 | * @param savedModelBytes 87 | */ 88 | public GarbageClassificationModel(byte[] savedModelBytes, ModelParams modelParams) { 89 | this.inferenceModel = this.loadInferenceModel(savedModelBytes, modelParams); 90 | } 91 | 92 | /** 93 | * Reserved for experimental test 94 | * @param savedModelBytes 95 | * @return 96 | */ 97 | private InferenceModel loadInferenceModel(byte[] savedModelBytes, ModelParams modelParams) 98 | { 99 | long beginTime = System.currentTimeMillis(); 100 | System.out.println(String.format("loadInferenceModel BEGIN %s", beginTime)); 101 | InferenceModel model = new InferenceModel(1); 102 | model.doLoadTF(savedModelBytes, modelParams.getInputShape(), false, 103 | modelParams.getMeanValues(), modelParams.getScale(), modelParams.getInputName()); 104 | long endTime = System.currentTimeMillis(); 105 | System.out.println(String.format("loadInferenceModel END %s (Cost: %s)", 106 | endTime, (endTime - beginTime))); 107 | return model; 108 | } 109 | } 110 | -------------------------------------------------------------------------------- /src/main/java/sky/tf/IModel.java: -------------------------------------------------------------------------------- 1 | package sky.tf; 2 | 3 | import com.intel.analytics.zoo.pipeline.inference.JTensor; 4 | 5 | import java.util.List; 6 | 7 | public interface IModel { 8 | List> predict(List> inputs) throws Exception; 9 | void release() throws Exception; 10 | } 11 | -------------------------------------------------------------------------------- /src/main/java/sky/tf/ImageClassIndex.java: -------------------------------------------------------------------------------- 1 | package sky.tf; 2 | 3 | import java.io.*; 4 | import java.util.HashMap; 5 | import java.util.Map; 6 | 7 | /** 8 | * @author SkyPeace 9 | * Image class index 10 | */ 11 | public class ImageClassIndex { 12 | 13 | private Map mapClassIndex; 14 | private static 
ImageClassIndex instance = new ImageClassIndex(); 15 | 16 | private ImageClassIndex() 17 | { 18 | } 19 | public static ImageClassIndex getInsatnce() 20 | { 21 | return instance; 22 | } 23 | 24 | public synchronized void loadClassIndexMap(String modelInferencePath) 25 | { 26 | String labelFilePath = modelInferencePath + File.separator + "labels.txt"; 27 | if(this.mapClassIndex == null) { 28 | long beginTime = System.currentTimeMillis(); 29 | System.out.println(String.format("READ_CLASS_INDEX BEGIN %s", beginTime)); 30 | this.mapClassIndex = this.read(labelFilePath); 31 | long endTime = System.currentTimeMillis(); 32 | System.out.println(String.format("READ_CLASS_INDEX END %s (Cost: %s)", endTime, (endTime - beginTime))); 33 | } 34 | } 35 | 36 | public String getImageHumanstring(String id) 37 | { 38 | return mapClassIndex.get(id); 39 | } 40 | 41 | /** 42 | * Get class index. 43 | * For test only. 44 | * @param filePath 45 | * @return 46 | */ 47 | private Map read(String filePath) 48 | { 49 | Map map = new HashMap(); 50 | try { 51 | FileReader fr = new FileReader(filePath); 52 | BufferedReader br = new BufferedReader(fr); 53 | String line; 54 | while ((line = br.readLine()) != null) { 55 | String columns[] = line.split(":"); 56 | map.put(columns[0], columns[1]); 57 | } 58 | fr.close(); 59 | }catch (IOException ex) 60 | { 61 | String errMsg = "Read class index file FAILED. 
" + ex.getMessage(); 62 | System.out.println("ERROR: " + errMsg); 63 | throw new RuntimeException(errMsg); 64 | } 65 | return map; 66 | } 67 | 68 | } 69 | -------------------------------------------------------------------------------- /src/main/java/sky/tf/ImageClassificationMain.java: -------------------------------------------------------------------------------- 1 | package sky.tf; 2 | import com.alibaba.tianchi.garbage_image_util.ImageClassSink; 3 | import com.alibaba.tianchi.garbage_image_util.ImageDirSource; 4 | import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; 5 | import sky.tf.threemodels.ImageFlatMap3Models; 6 | import sky.tf.threemodels.OpenVinoModelGenerator3Models; 7 | /** 8 | * The main class of Image inference for garbage classification 9 | * @author SkyPeace 10 | */ 11 | public class ImageClassificationMain { 12 | public static void main(String[] args) throws Exception { 13 | StreamExecutionEnvironment flinkEnv = StreamExecutionEnvironment.getExecutionEnvironment(); 14 | 15 | //Step 1. Generate optimized OpenVino models (Three type of models) for later prediction. 16 | OpenVinoModelGenerator3Models modelGenerator = new OpenVinoModelGenerator3Models(); 17 | modelGenerator.execute(); 18 | ImageFlatMap3Models imageFlatMap = new ImageFlatMap3Models(modelGenerator.getModel1Params(), 19 | modelGenerator.getModel2Params(), modelGenerator.getModel3Params()); 20 | 21 | ImageDirSource source = new ImageDirSource(); 22 | //IMPORTANT: Operator chaining maybe hit the score log issue (Tianchi ENV) when parallelism is set to N_N_x. 
23 | //Use statistic tag PREDICT_PROCESS to get prediction's real elapsed time 24 | //flinkEnv.disableOperatorChaining(); 25 | flinkEnv.addSource(source).setParallelism(1) 26 | .flatMap(imageFlatMap).setParallelism(2) 27 | .addSink(new ImageClassSink()).setParallelism(2); 28 | flinkEnv.execute("Image inference for garbage classification-PhaseII-1.0 - SkyPeace"); 29 | } 30 | } 31 | -------------------------------------------------------------------------------- /src/main/java/sky/tf/ImageDataPreprocessing.java: -------------------------------------------------------------------------------- 1 | package sky.tf; 2 | 3 | import com.alibaba.tianchi.garbage_image_util.ImageData; 4 | import com.intel.analytics.bigdl.transform.vision.image.opencv.OpenCVMat; 5 | import com.intel.analytics.zoo.pipeline.inference.JTensor; 6 | import org.libjpegturbo.turbojpeg.TJ; 7 | import org.libjpegturbo.turbojpeg.TJDecompressor; 8 | import org.opencv.core.*; 9 | import org.opencv.imgcodecs.Imgcodecs; 10 | import org.opencv.imgproc.Imgproc; 11 | 12 | import java.util.ArrayList; 13 | import java.util.List; 14 | 15 | /** 16 | * @author SkyPeace 17 | * The class for image data preprocessing. 
18 | */ 19 | public class ImageDataPreprocessing { 20 | public static final int IMAGE_DECODER_OPENCV = 0; 21 | public static final int IMAGE_DECODER_TURBOJPEG = 1; 22 | 23 | public static final int PREPROCESSING_VGG = 11; 24 | public static final int PREPROCESSING_INCEPTION = 12; 25 | public static final int PREPROCESSING_TIANCHI = 13; 26 | private ModelParams modelParams; 27 | 28 | public ImageDataPreprocessing(ModelParams modelParams) 29 | { 30 | this.modelParams = modelParams; 31 | } 32 | 33 | /** 34 | * Get RGB Mat data 35 | * @param imageData 36 | * @param decodeType 37 | * @return 38 | * @throws Exception 39 | */ 40 | public static Mat getRGBMat(ImageData imageData, int decodeType) throws Exception 41 | { 42 | //********************** Decode image **********************// 43 | long beginTime = System.currentTimeMillis(); 44 | System.out.println(String.format("%s IMAGE_BYTES = %s", "###", imageData.getImage().length)); 45 | System.out.println(String.format("%s IMAGE_BYTES_IMDECODE BEGIN %s", "###", beginTime)); 46 | Mat matRGB = null; 47 | if(decodeType == IMAGE_DECODER_OPENCV) 48 | matRGB = decodeByOpenCV(imageData); 49 | else if(decodeType == IMAGE_DECODER_TURBOJPEG) 50 | matRGB = decodeByTurboJpeg(imageData); 51 | else 52 | throw new Exception(String.format("Not support such decodeType: %s", decodeType)); 53 | System.out.println(String.format("%s IMAGE_BYTES_IMDECODE END %s (Cost: %s)", 54 | "###", System.currentTimeMillis(), (System.currentTimeMillis() - beginTime))); 55 | return matRGB; 56 | } 57 | 58 | /** 59 | * Decode image data by turbojpeg. It is more fast than OpenCV decoder. 60 | * Please refer to https://libjpeg-turbo.org/About/Performance for performance comparison. 61 | * @param imageData 62 | * @return 63 | * @throws Exception 64 | */ 65 | private static Mat decodeByTurboJpeg(ImageData imageData) throws Exception 66 | { 67 | TJDecompressor tjd = new TJDecompressor(imageData.getImage()); 68 | //TJ.FLAG_FASTDCT can get more performance. 
TJ.FLAG_ACCURATEDCT can get more accuracy. 69 | byte bytes[] = tjd.decompress(tjd.getWidth(), 0, tjd.getHeight(), TJ.PF_RGB, TJ.FLAG_ACCURATEDCT); 70 | Mat matRGB = new Mat(tjd.getHeight(), tjd.getWidth(), CvType.CV_8UC3); 71 | matRGB.put(0, 0, bytes); 72 | return matRGB; 73 | } 74 | 75 | /** 76 | * Decode image data by OpenCV 77 | * @param imageData 78 | * @return 79 | */ 80 | private static Mat decodeByOpenCV(ImageData imageData) 81 | { 82 | Mat srcmat = new MatOfByte(imageData.getImage()); 83 | Mat mat = Imgcodecs.imdecode(srcmat, Imgcodecs.CV_LOAD_IMAGE_COLOR); 84 | Mat matRGB = new MatOfByte(); 85 | Imgproc.cvtColor(mat, matRGB, Imgproc.COLOR_BGR2RGB); 86 | System.out.println(String.format("%s image_width, image_height, image_depth, image_dims = %s, %s, %s, %s", 87 | "###", mat.width(), mat.height(), mat.depth(), mat.dims())); 88 | return matRGB; 89 | } 90 | 91 | /** 92 | * Preprocess image data and convert to float array (RGB) for JTensor list need. 93 | * @param matRGB 94 | * @param preprocessType 95 | * @param enableMultipleVoters 96 | * @return 97 | * @throws Exception 98 | */ 99 | public List> doPreProcessing(Mat matRGB, int preprocessType, boolean enableMultipleVoters) throws Exception 100 | { 101 | //********************** Resize, crop and other pre-processing **********************// 102 | long beginTime = System.currentTimeMillis(); 103 | System.out.println(String.format("%s IMAGE_RESIZE BEGIN %s", "###", beginTime)); 104 | List matList = new ArrayList(); 105 | int image_size = modelParams.getInputShape()[2]; 106 | 107 | //VGG_PREPROCESS, INCEPTION_PREPROCESS and PREPROCESSING_TIANCHI 108 | if(preprocessType==PREPROCESSING_VGG) { 109 | matList = this.resizeAndCenterCropImage(matRGB, 256, image_size, image_size, enableMultipleVoters); 110 | }else if(preprocessType==PREPROCESSING_INCEPTION) { 111 | matList = this.cropAndResizeImage(matRGB, image_size, image_size, enableMultipleVoters); 112 | }else if(preprocessType==PREPROCESSING_TIANCHI) { 113 | 
matList = this.resize(matRGB, image_size); 114 | } 115 | else 116 | throw new Exception(String.format("Not support such preprocessType: %s", preprocessType)); 117 | System.out.println(String.format("%s IMAGE_RESIZE END %s (Cost: %s)", 118 | "###", System.currentTimeMillis(), (System.currentTimeMillis() - beginTime))); 119 | 120 | //********** Convert Mat to float array. R channel, G channel, B channel **********// 121 | beginTime = System.currentTimeMillis(); 122 | System.out.println(String.format("%s CONVERT_TO_RGB_FLOAT BEGIN %s", "###", beginTime)); 123 | List> inputs = new ArrayList>(); 124 | for(int i=0;i(); 126 | float data[] = new float[modelParams.getInputSize()]; 127 | List rgbMatList = new ArrayList(); 128 | //Split to R channel, G channel and B channel 129 | org.opencv.core.Core.split(matList.get(i), rgbMatList); 130 | MatOfByte matFloat = new MatOfByte(); 131 | //VConcat R,G,B and convert to float array. Native operation, quick enough (AVG less than 1ms). 132 | org.opencv.core.Core.vconcat(rgbMatList, matFloat); 133 | OpenCVMat.toFloatPixels(matFloat, data); 134 | //System.arraycopy(data, 0, datat, i*data.length, data.length); 135 | 136 | /* Below code is the example of worst performance, DO NOT coding as that. 
137 | for (int row = 0; row < 224; row++) { 138 | for (int col = 0; col < 224; col++) { 139 | data[(col + row * 224) + 224 * 224 * 0] = (float) (dst.get(row, col)[0]); 140 | data[(col + row * 224) + 224 * 224 * 1] = (float) (dst.get(row, col)[1]); 141 | data[(col + row * 224) + 224 * 224 * 2] = (float) (dst.get(row, col)[2]); 142 | } 143 | } 144 | */ 145 | 146 | //Create a JTensor 147 | JTensor tensor = new JTensor(); 148 | tensor.setData(data); 149 | tensor.setShape(modelParams.getInputShape()); 150 | list.add(tensor); 151 | inputs.add(list); 152 | } 153 | System.out.println(String.format("%s CONVERT_TO_RGB_FLOAT END %s (Cost: %s)", 154 | "###", System.currentTimeMillis(), (System.currentTimeMillis() - beginTime))); 155 | return inputs; 156 | } 157 | 158 | /** 159 | * Just resize to specified size. Use for TianChi model 160 | * @param src 161 | * @param image_size 162 | * @return 163 | */ 164 | private List resize(Mat src, int image_size) 165 | { 166 | Mat resizedMat = new MatOfByte(); 167 | Imgproc.resize(src, resizedMat, new Size(image_size, image_size), 0, 0, Imgproc.INTER_CUBIC); 168 | List matList = new ArrayList(); 169 | matList.add(resizedMat); 170 | return matList; 171 | } 172 | 173 | /** 174 | * Resize and center crop. For VGG_PREPROCESS 175 | * @param src 176 | * @param smallestSide 177 | * @param outWidth 178 | * @param outHeight 179 | * @param enableMultipleVoters 180 | */ 181 | private List resizeAndCenterCropImage(Mat src, int smallestSide, int outWidth, int outHeight, boolean enableMultipleVoters) 182 | { 183 | //OpenCV resize. 
Use INTER_LINEAR, INTER_CUBIC or INTER_LANCZOS4 184 | //Imgproc.resize(src, dst, new Size(outWidth, outHeight), 0, 0, Imgproc.INTER_CUBIC); 185 | float scale = 0f; 186 | if(src.height()>src.width()) 187 | scale = (float)smallestSide / src.width(); 188 | else 189 | scale = (float)smallestSide / src.height(); 190 | float newHeight = src.height() * scale; 191 | float newWidth = src.width() * scale; 192 | Mat resizedMat = new MatOfByte(); 193 | 194 | Imgproc.resize(src, resizedMat, new Size(newWidth, newHeight), 0, 0, Imgproc.INTER_LINEAR); 195 | int offsetHeight = (int)((resizedMat.height() - outHeight) / 2); 196 | int offsetWidth = (int)((resizedMat.width() - outWidth) / 2); 197 | 198 | //Center 199 | Mat matCenter = resizedMat.submat(offsetHeight, offsetHeight + outHeight, offsetWidth, offsetWidth + outWidth); 200 | List matList = new ArrayList(); 201 | matList.add(matCenter); 202 | 203 | if(enableMultipleVoters) { 204 | //matFlip 205 | Mat matCenterFlip = new MatOfByte(); 206 | Core.flip(matCenter, matCenterFlip, 0); 207 | //Mat matTop = resizedMat.submat(0, outHeight, offsetWidth, offsetWidth + outWidth); 208 | Mat matBottom = resizedMat.submat((int) resizedMat.height() - outHeight, (int) resizedMat.height(), offsetWidth, offsetWidth + outWidth); 209 | Mat matLeft = resizedMat.submat(offsetHeight, offsetHeight + outHeight, 0, outWidth); 210 | Mat matRight = resizedMat.submat(offsetHeight, offsetHeight + outHeight, (int) resizedMat.width() - outWidth, (int) resizedMat.width()); 211 | matList.add(matCenterFlip); 212 | //matList.add(matTop); 213 | matList.add(matBottom); 214 | matList.add(matRight); 215 | matList.add(matLeft); 216 | } 217 | return matList; 218 | } 219 | 220 | /** 221 | * Crop and resize. 
For INCEPTION_PREPROCESS 222 | * @param src 223 | * @param outWidth 224 | * @param outHeight 225 | * @param enableMultipleVoters 226 | * @return 227 | */ 228 | private List cropAndResizeImage(Mat src, int outWidth, int outHeight, boolean enableMultipleVoters) 229 | { 230 | float scale = 0.875f; 231 | float newHeight = src.height() * scale; 232 | float newWidth = src.width() * scale; 233 | 234 | int offsetHeight = (int)((src.height() - newHeight) / 2); 235 | int offsetWidth = (int)((src.width() - newWidth) / 2); 236 | Mat centerCropMat = src.submat(offsetHeight, offsetHeight + (int)newHeight, offsetWidth, offsetWidth + (int)newWidth); 237 | 238 | Mat resizedMat = new MatOfByte(); 239 | Imgproc.resize(centerCropMat, resizedMat, new Size(outWidth, outHeight), 0, 0, Imgproc.INTER_LINEAR); 240 | List matList = new ArrayList(); 241 | matList.add(resizedMat); 242 | 243 | if(enableMultipleVoters) { 244 | Mat matCenterFlip = new MatOfByte(); 245 | Core.flip(resizedMat, matCenterFlip, 0); 246 | matList.add(matCenterFlip); 247 | 248 | Mat matBottom = src.submat(src.height()-(int)newHeight, src.height(), offsetWidth, offsetWidth + (int)newWidth); 249 | Mat resizedMatBottom = new MatOfByte(); 250 | Imgproc.resize(matBottom, resizedMatBottom, new Size(outWidth, outHeight), 0, 0, Imgproc.INTER_LINEAR); 251 | matList.add(resizedMatBottom); 252 | 253 | Mat matLeft = src.submat(offsetHeight, offsetHeight + (int) newHeight, 0, (int) newWidth); 254 | Mat resizedMatLeft = new MatOfByte(); 255 | Imgproc.resize(matLeft, resizedMatLeft, new Size(outWidth, outHeight), 0, 0, Imgproc.INTER_LINEAR); 256 | matList.add(resizedMatLeft); 257 | 258 | Mat matRight = src.submat(offsetHeight, offsetHeight + (int)newHeight, src.width() - (int)newWidth, src.width()); 259 | Mat resizedMatRight = new MatOfByte(); 260 | Imgproc.resize(matRight, resizedMatRight, new Size(outWidth, outHeight), 0, 0, Imgproc.INTER_LINEAR); 261 | matList.add(resizedMatRight); 262 | } 263 | return matList; 264 | } 265 | 266 | 
/** 267 | * Another implements of INCEPTION_PREPROCESS. For experimental only. 268 | * @param src 269 | * @param outWidth 270 | * @param outHeight 271 | * @param dst 272 | */ 273 | private void cropAndResizeImageReserved(Mat src, int outWidth, int outHeight, Mat dst) 274 | { 275 | Mat srcFloat = new MatOfFloat(); 276 | src.convertTo(srcFloat, CvType.CV_32F, 1.0f/255.0f); 277 | float scale = 0.875f; 278 | float newHeight = srcFloat.height() * scale; 279 | float newWidth = srcFloat.width() * scale; 280 | 281 | int offsetHeight = (int)((srcFloat.height() - newHeight) / 2); 282 | int offsetWidth = (int)((srcFloat.width() - newWidth) / 2); 283 | Mat cropMat = srcFloat.submat(offsetHeight, offsetHeight + (int)newHeight, offsetWidth, offsetWidth + (int)newWidth); 284 | 285 | Mat resizedMat = new MatOfFloat(); 286 | Imgproc.resize(cropMat, resizedMat, new Size(outWidth, outHeight), 0, 0, Imgproc.INTER_LINEAR); 287 | 288 | Mat resizedMat1 = new MatOfFloat(); 289 | Core.subtract(resizedMat, Scalar.all(0.5), resizedMat1); 290 | Core.multiply(resizedMat1, Scalar.all(2.0), dst); 291 | } 292 | 293 | } 294 | -------------------------------------------------------------------------------- /src/main/java/sky/tf/ModelParams.java: -------------------------------------------------------------------------------- 1 | package sky.tf; 2 | 3 | /** 4 | * @author SkyPeace 5 | * The class for configure model's parameters. 
6 | */ 7 | public class ModelParams implements java.io.Serializable 8 | { 9 | private String modelType; //resnet, inception 10 | private String modelName; 11 | private String inputName; 12 | private int[] inputShape; 13 | private int inputSize; 14 | private float[] meanValues; 15 | private float scale; 16 | 17 | private String optimizedModelDir; 18 | 19 | public String getModelName() { 20 | return modelName; 21 | } 22 | 23 | public void setModelName(String modelName) { 24 | this.modelName = modelName; 25 | } 26 | 27 | public String getOptimizedModelDir() { 28 | return optimizedModelDir; 29 | } 30 | 31 | public void setOptimizedModelDir(String optimizedModelDir) { 32 | this.optimizedModelDir = optimizedModelDir; 33 | } 34 | 35 | public String getModelType() { 36 | return modelType; 37 | } 38 | 39 | public void setModelType(String modelType) { 40 | this.modelType = modelType; 41 | } 42 | 43 | public String getInputName() { 44 | return inputName; 45 | } 46 | 47 | public void setInputName(String inputName) { 48 | this.inputName = inputName; 49 | } 50 | 51 | public int[] getInputShape() { 52 | return inputShape; 53 | } 54 | 55 | public void setInputShape(int[] inputShape) { 56 | this.inputShape = inputShape; 57 | } 58 | 59 | public int getInputSize() { 60 | int size = 1; 61 | for(int i=1;i { 21 | private String imageModelPath = System.getenv(ConfigConstant.IMAGE_MODEL_PATH); 22 | private String imageModelPathPackagePath = System.getenv(ConfigConstant.IMAGE_MODEL_PACKAGE_PATH); 23 | private String modelInferencePath = System.getenv("MODEL_INFERENCE_PATH"); 24 | private List modelParamsList; 25 | private transient List modelList; 26 | private transient ImagePredictSupportiveMultipleModels supportive; 27 | 28 | public ImageFlatMapMultipleModels(List modelParamsList) 29 | { 30 | this.modelParamsList = modelParamsList; 31 | } 32 | 33 | @Override 34 | public void open(Configuration parameters) throws Exception 35 | { 36 | //For troubleshooting use. 
37 | System.out.println(String.format("ImageFlatMap.open(): imageModelPath is %s", this.imageModelPath)); 38 | System.out.println(String.format("ImageFlatMap.open(): modelInferencePath is %s", this.modelInferencePath)); 39 | System.out.println(String.format("ImageFlatMap.open(): imageModelPathPackagePath is %s", this.imageModelPathPackagePath)); 40 | 41 | //Step2: Load optimized OpenVino model from files (HDFS). 42 | // Cost about 1 seconds each model, quick enough. But it is not good solution to use too many models in client. 43 | modelList = new ArrayList(); 44 | for(ModelParams modelParams:modelParamsList) { 45 | GarbageClassificationModel model = 46 | ImageModelLoaderMultipleModels.getInstance().loadOpenVINOModelOnce(modelParams); 47 | modelList.add(model); 48 | } 49 | 50 | //First time warm-dummy check 51 | ImageClassIndex.getInsatnce().loadClassIndexMap(modelInferencePath); 52 | this.supportive = new ImagePredictSupportiveMultipleModels(modelList, modelParamsList); 53 | this.supportive.firstTimeDummyCheck(); 54 | } 55 | 56 | @Override 57 | public void flatMap(ImageData value, Collector out) 58 | throws Exception 59 | { 60 | IdLabel idLabel = new IdLabel(); 61 | idLabel.setId(value.getId()); 62 | 63 | long beginTime = System.currentTimeMillis(); 64 | System.out.println(String.format("PREDICT_PROCESS BEGIN %s", beginTime)); 65 | 66 | String imageLabelString = supportive.predictHumanString(value); 67 | 68 | long endTime = System.currentTimeMillis(); 69 | System.out.println(String.format("PREDICT_PROCESS END %s (Cost: %s)", endTime, (endTime - beginTime))); 70 | 71 | //Check whether elapsed time >= threshold. Logging it for review. 
72 | if((endTime - beginTime)>498) 73 | System.out.println(String.format("PREDICT_PROCESS MAYBE EXCEED 500ms %s (Cost: %s)", 74 | endTime, (endTime - beginTime))); 75 | 76 | idLabel.setLabel(imageLabelString); 77 | out.collect(idLabel); 78 | } 79 | 80 | @Override 81 | public void close() throws Exception 82 | { 83 | System.out.println(String.format("getNumberOfParallelSubtasks: %s", 84 | this.getRuntimeContext().getNumberOfParallelSubtasks())); 85 | for(GarbageClassificationModel model:modelList) { 86 | model.release(); 87 | } 88 | } 89 | 90 | } 91 | 92 | -------------------------------------------------------------------------------- /src/main/java/sky/tf/multiplemodels/ImageModelLoaderMultipleModels.java: -------------------------------------------------------------------------------- 1 | package sky.tf.multiplemodels; 2 | 3 | import com.google.common.io.Files; 4 | import com.intel.analytics.zoo.pipeline.inference.OpenVinoInferenceSupportive$; 5 | import org.apache.commons.io.FileUtils; 6 | import org.apache.flink.api.java.tuple.Tuple2; 7 | import org.apache.hadoop.fs.FileSystem; 8 | import org.apache.hadoop.fs.*; 9 | import org.apache.hadoop.hdfs.DistributedFileSystem; 10 | import org.slf4j.Logger; 11 | import org.slf4j.LoggerFactory; 12 | import sky.tf.GarbageClassificationModel; 13 | import sky.tf.ModelParams; 14 | 15 | import java.io.*; 16 | import java.net.URI; 17 | import java.nio.channels.Channels; 18 | import java.nio.channels.FileChannel; 19 | import java.nio.channels.ReadableByteChannel; 20 | import java.util.ArrayList; 21 | import java.util.HashMap; 22 | import java.util.List; 23 | import java.util.Map; 24 | 25 | 26 | /** 27 | * @author SkyPeace 28 | * The model loader. Experimental only. 
29 | */ 30 | public class ImageModelLoaderMultipleModels { 31 | private Logger logger = LoggerFactory.getLogger(ImageModelLoaderMultipleModels.class); 32 | private static ImageModelLoaderMultipleModels instance = new ImageModelLoaderMultipleModels(); 33 | private volatile Map modelsMap = 34 | new HashMap(); 35 | 36 | private ImageModelLoaderMultipleModels() {} 37 | 38 | public static ImageModelLoaderMultipleModels getInstance() 39 | { 40 | return instance; 41 | } 42 | 43 | /** 44 | * Generate optimized OpenVino model's data (bytes). -- First step. 45 | * @param savedModelPath 46 | * @return 47 | */ 48 | public List> generateOpenVinoModelData(String savedModelPath, ModelParams modelParams) throws Exception 49 | { 50 | List> modelData = new ArrayList>(); 51 | File modelFile = new File(savedModelPath); 52 | String optimizeModelPath = null; 53 | if(modelFile.isDirectory()) 54 | optimizeModelPath = 55 | optimizeModelFromModelDir(savedModelPath, modelParams); 56 | else 57 | optimizeModelPath = 58 | optimizeModelFromModelPackage(savedModelPath, modelParams); 59 | byte[] xml = readDFSFile(optimizeModelPath + File.separator + "saved_model.xml"); 60 | byte[] bin = readDFSFile(optimizeModelPath + File.separator + "saved_model.bin"); 61 | logger.info("Size of optimized saved_model.xml: " + xml.length); 62 | logger.info("Size of optimized saved_model.bin: " + bin.length); 63 | Tuple2 xmlData = new Tuple2("xml", xml); 64 | Tuple2 binData = new Tuple2("bin", bin); 65 | modelData.add(xmlData); 66 | modelData.add(binData); 67 | return modelData; 68 | } 69 | 70 | /** 71 | * Optimize model from saved model dir. Return optimized model temp directory. 
72 | * @param savedModelDir 73 | * @return 74 | * @throws Exception 75 | */ 76 | private String optimizeModelFromModelDir(String savedModelDir, ModelParams modelParams) throws Exception 77 | { 78 | File tmpDir = Files.createTempDir(); 79 | String optimizeModelTmpDir = tmpDir.getCanonicalPath(); 80 | //OpenVinoInferenceSupportive.optimizeTFImageClassificationModel(); 81 | OpenVinoInferenceSupportive$.MODULE$.optimizeTFImageClassificationModel( 82 | savedModelDir + File.separator + "SavedModel", modelParams.getInputShape(), false, 83 | modelParams.getMeanValues(), modelParams.getScale(), modelParams.getInputName(), optimizeModelTmpDir); 84 | return optimizeModelTmpDir; 85 | } 86 | 87 | /** 88 | * Optimize model from saved model package (tar.gz file). Return optimized model temp dir. 89 | * @param savedModelPackagePath 90 | * @return 91 | * @throws Exception 92 | */ 93 | private String optimizeModelFromModelPackage(String savedModelPackagePath, ModelParams modelParams) throws Exception 94 | { 95 | byte[] savedModelBytes = readDFSFile(savedModelPackagePath); 96 | File tmpDir = Files.createTempDir(); 97 | String tempDirPath = tmpDir.getCanonicalPath(); 98 | String tarFileName = "saved-model.tar"; 99 | File tarFile = new File(tempDirPath + File.separator + tarFileName); 100 | ByteArrayInputStream tarFileInputStream = new ByteArrayInputStream(savedModelBytes); 101 | ReadableByteChannel tarFileSrc = Channels.newChannel(tarFileInputStream); 102 | FileChannel tarFileDest = (new FileOutputStream(tarFile)).getChannel(); 103 | tarFileDest.transferFrom(tarFileSrc, 0L, 9223372036854775807L); 104 | tarFileDest.close(); 105 | tarFileSrc.close(); 106 | String tarFileAbsolutePath = tarFile.getAbsolutePath(); 107 | String modelRootDir = tempDirPath + File.separator + "saved-model"; 108 | File modelRootDirFile = new File(modelRootDir); 109 | FileUtils.forceMkdir(modelRootDirFile); 110 | //tar -xvf -C 111 | Process proc = Runtime.getRuntime().exec(new String[]{"tar", "-xvf", 
tarFileAbsolutePath, "-C", modelRootDir}); 112 | //Runtime.getRuntime().exec(new String[]{"ls", "-l", modelRootDir}); 113 | BufferedReader insertReader = new BufferedReader(new InputStreamReader(proc.getInputStream())); 114 | BufferedReader errorReader = new BufferedReader(new InputStreamReader(proc.getErrorStream())); 115 | proc.waitFor(); 116 | if(insertReader!=null) 117 | insertReader.close(); 118 | if(errorReader!=null) 119 | errorReader.close(); 120 | File[] files = modelRootDirFile.listFiles(); 121 | String savedModelTmpDir = files[0].getAbsolutePath(); 122 | logger.info("Saved model temp dir will be used for optimization: " + savedModelTmpDir); 123 | String optimizeModelTmpDir = optimizeModelFromModelDir(savedModelTmpDir, modelParams); 124 | return optimizeModelTmpDir; 125 | } 126 | 127 | /** 128 | * Save optimized OpenVino model's data (bytes) into HDFS files 129 | * The specified dir is also the parent dir of saved model package. 130 | * @param openVinoModelData 131 | * @param modelParams 132 | * @return 133 | */ 134 | public void saveToOpenVinoModelFile(List> openVinoModelData, ModelParams modelParams) throws Exception 135 | { 136 | writeOptimizedModelToDFS(openVinoModelData, modelParams); 137 | } 138 | 139 | /** 140 | * Load OpenVino model with singleton pattern. -- Second step. 141 | * In this case, use this solution by default because it only cost about 2 seconds to load in Map.open(). 
142 | * @param modelParams 143 | * @return 144 | */ 145 | public synchronized GarbageClassificationModel loadOpenVINOModelOnce(ModelParams modelParams) throws Exception 146 | { 147 | GarbageClassificationModel model = modelsMap.get(modelParams.getModelName()); 148 | if(model == null) 149 | { 150 | model = this.getModel(modelParams); 151 | modelsMap.put(modelParams.getModelName(), model); 152 | } 153 | model.addRefernce(); 154 | return model; 155 | } 156 | 157 | private GarbageClassificationModel getModel(ModelParams modelParams) throws Exception 158 | { 159 | List> openVinoModelData = 160 | getModelDataFromOptimizedModelDir(modelParams); 161 | byte[] modelXml = openVinoModelData.get(0).f1; 162 | byte[] modelBin = openVinoModelData.get(1).f1; 163 | return new GarbageClassificationModel(modelXml, modelBin); 164 | } 165 | 166 | /** 167 | * Get OpenVino model data from optimized model files 168 | * @param modelParams 169 | * @return 170 | */ 171 | private List> getModelDataFromOptimizedModelDir(ModelParams modelParams) throws Exception 172 | { 173 | String optimizedModelDir = modelParams.getOptimizedModelDir(); 174 | String modelName = modelParams.getModelName(); 175 | List> modelData = new ArrayList>(); 176 | byte[] xml = readDFSFile(optimizedModelDir + File.separator + "optimized_openvino_" + modelName + "_KEEPME.xml"); 177 | byte[] bin = readDFSFile(optimizedModelDir + File.separator + "optimized_openvino_" + modelName + "_KEEPME.bin"); 178 | logger.info("Size of optimized_openvino_" + modelName + "_KEEPME.xml: " + xml.length); 179 | logger.info("Size of optimized_openvino_" + modelName + "_KEEPME.bin: " + bin.length); 180 | Tuple2 xmlData = new Tuple2("xml", xml); 181 | Tuple2 binData = new Tuple2("bin", bin); 182 | modelData.add(xmlData); 183 | modelData.add(binData); 184 | return modelData; 185 | } 186 | 187 | /** 188 | * Read saved model files (HDFS) into byte array. 
189 | * @param filePath 190 | * @return 191 | */ 192 | private byte[] readDFSFile(String filePath) throws Exception 193 | { 194 | try { 195 | long beginTime = System.currentTimeMillis(); 196 | System.out.println(String.format("READ_MODEL_FILE from %s BEGIN %s", filePath, beginTime)); 197 | Path imageRoot = new Path(filePath); 198 | org.apache.hadoop.conf.Configuration hadoopConfig = new org.apache.hadoop.conf.Configuration(); 199 | hadoopConfig.set("fs.hdfs.impl", DistributedFileSystem.class.getName()); 200 | hadoopConfig.set("fs.file.impl", LocalFileSystem.class.getName()); 201 | FileSystem fileSystem = FileSystem.get(new URI(filePath), hadoopConfig); 202 | FileStatus fileStatus = fileSystem.getFileStatus(imageRoot); 203 | //RemoteIterator it = fileSystem.listFiles(imageRoot, false); 204 | long fileLength = fileStatus.getLen(); 205 | FSDataInputStream in = fileSystem.open(imageRoot); 206 | byte[] buffer = new byte[(int) fileLength]; 207 | in.readFully(buffer); 208 | in.close(); 209 | long endTime = System.currentTimeMillis(); 210 | System.out.println(String.format("READ_MODEL_FILE from %s END %s (Cost: %s)", 211 | filePath, endTime, (endTime - beginTime))); 212 | fileSystem.close(); 213 | return buffer; 214 | }catch(Exception ex) 215 | { 216 | String msg = "Read DFS file FAILED. 
" + ex.getMessage(); 217 | logger.error(msg, ex); 218 | throw ex; 219 | } 220 | } 221 | 222 | /** 223 | * Write the optimized model's data to files (HDFS) 224 | * @param openVinoModelData 225 | * @param modelParams 226 | */ 227 | private void writeOptimizedModelToDFS(List> openVinoModelData, ModelParams modelParams) throws Exception 228 | { 229 | String optimizedModelDir = modelParams.getOptimizedModelDir(); 230 | String modelName = modelParams.getModelName(); 231 | try { 232 | long beginTime = System.currentTimeMillis(); 233 | System.out.println(String.format("WRITE_OPTIMIZED_MODEL_FILE %s BEGIN %s", optimizedModelDir, beginTime)); 234 | org.apache.hadoop.conf.Configuration hadoopConfig = new org.apache.hadoop.conf.Configuration(); 235 | hadoopConfig.set("fs.hdfs.impl", DistributedFileSystem.class.getName()); 236 | hadoopConfig.set("fs.file.impl", LocalFileSystem.class.getName()); 237 | FileSystem fileSystem = FileSystem.get(new URI(optimizedModelDir), hadoopConfig); 238 | 239 | String xmlFileName = optimizedModelDir + File.separator + "optimized_openvino_"+ modelName + "_KEEPME.xml"; 240 | System.out.println(String.format("Model xmlFileName: %s", xmlFileName)); 241 | Path xmlFilePath = new Path(xmlFileName); 242 | FSDataOutputStream xmlFileOut = fileSystem.create(xmlFilePath, true); 243 | xmlFileOut.write(openVinoModelData.get(0).f1); 244 | xmlFileOut.flush(); 245 | xmlFileOut.close(); 246 | 247 | String binFileName = optimizedModelDir + File.separator + "optimized_openvino_"+ modelName + "_KEEPME.bin"; 248 | System.out.println(String.format("Model binFileName: %s", xmlFileName)); 249 | Path binFilePath = new Path(binFileName); 250 | FSDataOutputStream binFileOut = fileSystem.create(binFilePath, true); 251 | binFileOut.write(openVinoModelData.get(1).f1); 252 | binFileOut.flush(); 253 | binFileOut.close(); 254 | 255 | long endTime = System.currentTimeMillis(); 256 | System.out.println(String.format("WRITE_OPTIMIZED_MODEL_FILE %s END %s (Cost: %s)", 257 | 
optimizedModelDir, endTime, (endTime - beginTime))); 258 | fileSystem.close(); 259 | }catch(Exception ex) 260 | { 261 | String msg = "Write DFS file FAILED. " + ex.getMessage(); 262 | logger.error(msg, ex); 263 | throw ex; 264 | } 265 | } 266 | 267 | } 268 | -------------------------------------------------------------------------------- /src/main/java/sky/tf/multiplemodels/ImagePredictSupportiveMultipleModels.java: -------------------------------------------------------------------------------- 1 | package sky.tf.multiplemodels; 2 | 3 | import com.alibaba.tianchi.garbage_image_util.ImageData; 4 | import com.intel.analytics.bigdl.opencv.OpenCV; 5 | import com.intel.analytics.zoo.pipeline.inference.JTensor; 6 | import org.opencv.core.Mat; 7 | import sky.tf.*; 8 | 9 | import java.util.*; 10 | 11 | /** 12 | * @author SkyPeace 13 | * The supportive for image classification prediction. Experimental only. 14 | */ 15 | public class ImagePredictSupportiveMultipleModels 16 | { 17 | //private Logger logger = LoggerFactory.getLogger(ImagePredictSupportiveMultipleModels.class); 18 | private List modelParamsList; 19 | private List modelList; 20 | 21 | public ImagePredictSupportiveMultipleModels(List modelList, List modelParamsList) 22 | { 23 | this.modelList = modelList; 24 | this.modelParamsList = modelParamsList; 25 | } 26 | 27 | /** 28 | * Do predict and transfer the label ID to human string. 29 | * @param imageData 30 | * @return 31 | * @throws Exception 32 | */ 33 | public String predictHumanString(ImageData imageData) throws Exception 34 | { 35 | Integer predictId = predict(imageData); 36 | String humanString = this.getImageHumanString(predictId); 37 | if (humanString == null) { 38 | humanString = "class_index_not_available"; 39 | throw new Exception(humanString); 40 | } 41 | return humanString; 42 | } 43 | 44 | /** 45 | * Using multiple models do predict. 
46 | * @param imageData the image to classify 47 | * @return the predicted class id 48 | * @throws Exception 49 | */ 50 | private Integer predict(ImageData imageData) throws Exception { 51 | //Decode jpeg 52 | //Use turbojpeg as jpeg decoder. It is faster than the OpenCV decoder (OpenCV uses libjpeg as its decoder). 53 | //Refer to https://libjpeg-turbo.org/About/Performance for performance comparison. 54 | Mat matRGB = ImageDataPreprocessing.getRGBMat(imageData, ImageDataPreprocessing.IMAGE_DECODER_TURBOJPEG); 55 | 56 | long t1 = System.currentTimeMillis(); 57 | System.out.println(String.format("%s DO_PREDICT BEGIN %s", "###", t1)); 58 | List modelsPredictionResults = new ArrayList(); /* NOTE(review): text between '<' and '>' (generic type arguments, loop bounds) appears stripped from this dump; source lines 59-86 (the rest of predict() and the head of another method that returns a PredictionResult) were lost, so the next line is not valid code. */ 59 | for(int i=0;i> inputs = dataPreprocessing.doPreProcessing(matRGB, preprocessingType, enableMultipleVoters); 87 | List> tensorResults = model.predict(inputs); 88 | List pResults = this.convertToPredictionResults(tensorResults); 89 | PredictionResult pResult = this.getSingleResult(pResults); 90 | return pResult; 91 | } 92 | 93 | /** 94 | * Merge several predictions into one: the id with a strict majority (more than half), else an id predicted more than once, else the single prediction with the highest probability. NOTE(review): merging updates count/probability on PredictionResult objects from pResults in place, and the 'more than once' branch keeps the last such id in HashMap iteration order rather than the most frequent one - verify this is intended. 95 | * @param pResults predictions with count 1 each, as built by convertToPredictionResults (must not be empty) 96 | * @return the merged prediction 97 | */ 98 | private PredictionResult getSingleResult(List pResults) 99 | { 100 | if(pResults.size()==1) 101 | return pResults.get(0); 102 | System.out.println(String.format("There are %s multiple results", pResults.size())); 103 | Map distinctResults = new HashMap(); 104 | for(PredictionResult pResult:pResults) 105 | { 106 | Integer predictionId = pResult.getPredictionId(); 107 | PredictionResult existResult = distinctResults.get(predictionId); 108 | if(existResult == null){ 109 | distinctResults.put(predictionId, pResult); 110 | }else{ 111 | existResult.setCount(existResult.getCount() + 1); 112 | existResult.setProbability(Math.max(pResult.getProbability(), existResult.getProbability())); 113 | distinctResults.put(predictionId, existResult); 114 | } 115 | } 116 | PredictionResult maxCountResult = null; 117 | for(Integer key:distinctResults.keySet()) 118 | { 119 | PredictionResult pResult = distinctResults.get(key); 120 | float ratio =
(float)pResult.getCount() / pResults.size(); 121 | if(ratio>0.5f) 122 | { 123 | System.out.println("Majority get agreement."); 124 | maxCountResult = pResult; 125 | return maxCountResult; 126 | }else if(pResult.getCount()>1) 127 | maxCountResult = pResult; 128 | } 129 | if(maxCountResult!=null) { 130 | System.out.println("Few models get agreement."); 131 | return maxCountResult; 132 | } 133 | 134 | System.out.println("The models can NOT get agreement."); 135 | float maxProbability = pResults.get(0).getProbability(); 136 | int maxIdx = 0; 137 | for (int i = 1; i < pResults.size(); i++) { 138 | if (pResults.get(i).getProbability() > maxProbability) { 139 | maxProbability = pResults.get(i).getProbability(); 140 | maxIdx = i; 141 | } 142 | } 143 | return pResults.get(maxIdx); 144 | } 145 | /** Converts model output into PredictionResults, one per iterated entry: the argmax index becomes the prediction id, its value the probability, and count is 1; throws if an entry has no values. NOTE(review): source lines 152-155 (the loop header and the extraction of predictData) are garbled in this dump because text between '<' and '>' was stripped. */ 146 | private List convertToPredictionResults(List> result) throws Exception 147 | { 148 | long beginTime = System.currentTimeMillis(); 149 | List pResults = new ArrayList(); 150 | System.out.println("result.size(): " + result.size()); 151 | System.out.println("result.get(0).size(): " + result.get(0).size()); 152 | for(int j=0; j 0) { 156 | float maxProbability = predictData[0]; 157 | int maxNo = 0; 158 | for (int i = 1; i < predictData.length; i++) { 159 | if (predictData[i] > maxProbability) { 160 | maxProbability = predictData[i]; 161 | maxNo = i; 162 | } 163 | } 164 | PredictionResult pResult = new PredictionResult(); 165 | pResult.setPredictionId(maxNo); 166 | pResult.setProbability(maxProbability); 167 | pResult.setCount(1); 168 | pResults.add(pResult); 169 | }else{ 170 | throw new Exception("ERROR: predictData.length=0"); 171 | } 172 | } 173 | long endTime = System.currentTimeMillis(); 174 | if((endTime-beginTime)>5) 175 | System.out.println(String.format("%s CONVERT_TO_PREDICTION_RESULT END %s (Cost: %s)", 176 | "###", endTime, (endTime - beginTime))); 177 | return pResults; 178 | } 179 | 180 | /** 181 | * Load turbojpeg library.
182 | * @throws Exception 183 | */ 184 | private void loadTurboJpeg() 185 | { 186 | long beginTime = System.currentTimeMillis(); 187 | if(!TurboJpegLoader.isTurbojpegLoaded()) { 188 | throw new RuntimeException("LOAD_TURBOJPEG library failed. Please check."); 189 | } 190 | long endTime = System.currentTimeMillis(); 191 | if((endTime - beginTime) > 1) 192 | System.out.println(String.format("LOAD_TURBOJPEG END %s (Cost: %s)", endTime, (endTime - beginTime))); 193 | } 194 | 195 | /** 196 | * Load OpenCV library. 197 | * @throws Exception 198 | */ 199 | private void loadOpenCV() 200 | { 201 | long beginTime = System.currentTimeMillis(); 202 | if(!OpenCV.isOpenCVLoaded()) { 203 | throw new RuntimeException("LOAD_OPENCV library failed. Please check."); 204 | } 205 | long endTime = System.currentTimeMillis(); 206 | if((endTime - beginTime) > 10) 207 | System.out.println(String.format("LOAD_OPENCV END %s (Cost: %s)", endTime, (endTime - beginTime))); 208 | 209 | //if(!OpenvinoNativeLoader.load()) 210 | // throw new RuntimeException("LOAD_Openvino library failed. Please check."); 211 | //if(!TFNetNative.isLoaded()) 212 | // throw new RuntimeException("LOAD_TFNetNative library failed. Please check."); 213 | } 214 | 215 | /** 216 | * Get image class human string. 217 | * @return 218 | */ 219 | private String getImageHumanString(Integer id) 220 | { 221 | return ImageClassIndex.getInsatnce().getImageHumanstring(String.valueOf(id)); 222 | } 223 | 224 | /** 225 | * For first time warm-dummy check only. 
226 | * @throws Exception 227 | */ 228 | public void firstTimeDummyCheck() throws Exception 229 | { /* NOTE(review): source lines 230-249 were lost in this dump (text between '<' and '>' stripped), so this loop header runs straight into the tail of another method, which wraps one tensor into an input list, calls pModel.predict on it and logs FIRST_PREDICT timing. */ 230 | for(int i=0; i(); 250 | list.add(tensor); 251 | List> inputs = new ArrayList>(); 252 | inputs.add(list); 253 | pModel.predict(inputs); 254 | 255 | long endTime = System.currentTimeMillis(); 256 | System.out.println(String.format("FIRST_PREDICT END %s (Cost: %s)", endTime, (endTime - beginTime))); 257 | } 258 | 259 | } 260 | -------------------------------------------------------------------------------- /src/main/java/sky/tf/multiplemodels/OpenVinoModelGeneratorMultipleModels.java: -------------------------------------------------------------------------------- 1 | package sky.tf.multiplemodels; 2 | 3 | import com.alibaba.tianchi.garbage_image_util.ConfigConstant; 4 | import org.apache.flink.api.java.tuple.Tuple2; 5 | import sky.tf.ModelParams; 6 | 7 | import java.io.File; 8 | import java.util.ArrayList; 9 | import java.util.List; 10 | 11 | /** 12 | * Generates optimized OpenVINO models (several different models used together for prediction). Experimental only.
13 | * @author SkyPeace 14 | */ 15 | public class OpenVinoModelGeneratorMultipleModels { 16 | private String imageModelPath = System.getenv(ConfigConstant.IMAGE_MODEL_PATH); 17 | private String imageModelPackagePath = System.getenv(ConfigConstant.IMAGE_MODEL_PACKAGE_PATH); 18 | private String modelInferencePath = System.getenv("MODEL_INFERENCE_PATH"); 19 | private List modelParamsList; 20 | 21 | public void execute() 22 | { 23 | if (null == this.modelInferencePath) { 24 | throw new RuntimeException("ImageFlatMap(): Not set MODEL_INFERENCE_PATH environmental variable"); 25 | } 26 | if (null == this.imageModelPackagePath) { 27 | throw new RuntimeException("ImageFlatMap(): Not set imageModelPathPackagePath environmental variable"); 28 | } 29 | String imageModelPackageDir = 30 | imageModelPackagePath.substring(0, imageModelPackagePath.lastIndexOf(File.separator)); 31 | if(!imageModelPackageDir.equalsIgnoreCase(modelInferencePath)) 32 | { 33 | System.out.println("WARN: modelInferencePath NOT EQUAL imageModelPathPackageDir"); 34 | System.out.println(String.format("modelInferencePath: %s", modelInferencePath)); 35 | System.out.println(String.format("imageModelPackageDir: %s", imageModelPackageDir)); 36 | System.out.println(String.format("imageModelPath: %s", imageModelPath)); 37 | } 38 | 39 | try { 40 | //Generate optimized OpenVino model 41 | this.modelParamsList = new ArrayList(); 42 | //Here only generate 5 models for test. 
43 | for(int i=0;i<5;i++) { 44 | String tfModelPath = modelInferencePath + File.separator + "SavedModel/model" + i; 45 | String optimizedOpenVinoModelDir = imageModelPackageDir; 46 | ModelParams modelParams = new ModelParams(); 47 | modelParams.setModelType("inception"); 48 | modelParams.setModelName("model" + i); 49 | modelParams.setInputName("input_1"); 50 | modelParams.setInputShape(new int[]{1, 299, 299, 3}); 51 | modelParams.setMeanValues(new float[]{127.5f, 127.5f, 127.5f}); 52 | modelParams.setScale(127.5f); 53 | modelParams.setOptimizedModelDir(optimizedOpenVinoModelDir); 54 | //Call Zoo API generate optimized OpenVino Model 55 | List> optimizedOpenVinoModelData = 56 | ImageModelLoaderMultipleModels.getInstance().generateOpenVinoModelData(tfModelPath, modelParams); 57 | ImageModelLoaderMultipleModels.getInstance(). 58 | saveToOpenVinoModelFile(optimizedOpenVinoModelData, modelParams); 59 | modelParamsList.add(modelParams); 60 | } 61 | }catch(Exception ex) 62 | { 63 | ex.printStackTrace(); 64 | throw new RuntimeException("WARN: OpenVinoModelGenerator.execute() FAILED."); 65 | } 66 | } 67 | 68 | public List getModelParamsList() 69 | { 70 | return this.modelParamsList; 71 | } 72 | } 73 | -------------------------------------------------------------------------------- /src/main/java/sky/tf/threemodels/ImageFlatMap3Models.java: -------------------------------------------------------------------------------- 1 | package sky.tf.threemodels; 2 | 3 | import com.alibaba.tianchi.garbage_image_util.ConfigConstant; 4 | import com.alibaba.tianchi.garbage_image_util.IdLabel; 5 | import com.alibaba.tianchi.garbage_image_util.ImageData; 6 | import org.apache.flink.api.common.functions.RichFlatMapFunction; 7 | import org.apache.flink.configuration.Configuration; 8 | import org.apache.flink.util.Collector; 9 | import sky.tf.GarbageClassificationModel; 10 | import sky.tf.ImageClassIndex; 11 | import sky.tf.ModelParams; 12 | 13 | /** 14 | * @author SkyPeace 15 | * The Map 
operator for predict the class of image. 16 | */ 17 | public class ImageFlatMap3Models extends RichFlatMapFunction { 18 | private String imageModelPath = System.getenv(ConfigConstant.IMAGE_MODEL_PATH); 19 | private String imageModelPathPackagePath = System.getenv(ConfigConstant.IMAGE_MODEL_PACKAGE_PATH); 20 | private String modelInferencePath = System.getenv("MODEL_INFERENCE_PATH"); 21 | private ModelParams model1Params; 22 | private ModelParams model2Params; 23 | private ModelParams model3Params; 24 | private transient GarbageClassificationModel model1; 25 | private transient GarbageClassificationModel model2; 26 | private transient GarbageClassificationModel model3; 27 | private transient ImagePredictSupportive3Models supportive; 28 | 29 | public ImageFlatMap3Models(ModelParams model1Params, ModelParams model2Params, ModelParams model3Params) 30 | { 31 | this.model1Params = model1Params; 32 | this.model2Params = model2Params; 33 | this.model3Params = model3Params; 34 | } 35 | 36 | @Override 37 | public void open(Configuration parameters) throws Exception 38 | { 39 | //For troubleshooting use. 40 | System.out.println(String.format("ImageFlatMap.open(): imageModelPath is %s", this.imageModelPath)); 41 | System.out.println(String.format("ImageFlatMap.open(): modelInferencePath is %s", this.modelInferencePath)); 42 | System.out.println(String.format("ImageFlatMap.open(): imageModelPathPackagePath is %s", this.imageModelPathPackagePath)); 43 | System.out.println(String.format("ImageFlatMap.open(): optimizedOpenVinoModel1Dir is %s", 44 | this.model1Params.getOptimizedModelDir())); 45 | System.out.println(String.format("ImageFlatMap.open(): optimizedOpenVinoModel2Dir is %s", 46 | this.model2Params.getOptimizedModelDir())); 47 | System.out.println(String.format("ImageFlatMap.open(): optimizedOpenVinoModel3Dir is %s", 48 | this.model3Params.getOptimizedModelDir())); 49 | 50 | //Step2: Load optimized OpenVino model from files (HDFS). 
(Cost about 3 seconds total models, quickly enough) 51 | this.model1 = ImageModelLoader3Models.getInstance().loadOpenVINOModel1Once(this.model1Params); 52 | this.model2 = ImageModelLoader3Models.getInstance().loadOpenVINOModel2Once(this.model2Params); 53 | this.model3 = ImageModelLoader3Models.getInstance().loadOpenVINOModel3Once(this.model3Params); 54 | 55 | ImageClassIndex.getInsatnce().loadClassIndexMap(modelInferencePath); 56 | ValidationDatasetAnalyzer validationTrainer = new ValidationDatasetAnalyzer(); 57 | validationTrainer.searchOptimizedThreshold(modelInferencePath); 58 | float preferableThreshold1 = validationTrainer.getPreferableThreshold1(); 59 | float preferableThreshold2 = validationTrainer.getPreferableThreshold2(); 60 | this.supportive = new ImagePredictSupportive3Models(this.model1, this.model1Params, 61 | this.model2, this.model2Params, this.model3, this.model3Params, 62 | preferableThreshold1, preferableThreshold2); 63 | //First time warm-dummy check 64 | this.supportive.firstTimeDummyCheck(); 65 | } 66 | 67 | @Override 68 | public void flatMap(ImageData value, Collector out) 69 | throws Exception 70 | { 71 | IdLabel idLabel = new IdLabel(); 72 | idLabel.setId(value.getId()); 73 | 74 | long beginTime = System.currentTimeMillis(); 75 | System.out.println(String.format("PREDICT_PROCESS BEGIN %s", beginTime)); 76 | 77 | String imageLabelString = supportive.predictHumanString(value); 78 | 79 | long endTime = System.currentTimeMillis(); 80 | System.out.println(String.format("PREDICT_PROCESS END %s (Cost: %s)", endTime, (endTime - beginTime))); 81 | 82 | //Check whether elapsed time >= threshold. Logging it for review. 
83 | if((endTime - beginTime)>495) 84 | System.out.println(String.format("PREDICT_PROCESS MAYBE EXCEED THRESHOLD %s (Cost: %s)", 85 | endTime, (endTime - beginTime))); 86 | 87 | idLabel.setLabel(imageLabelString); 88 | out.collect(idLabel); 89 | } 90 | 91 | @Override 92 | public void close() throws Exception 93 | { 94 | System.out.println(String.format("getNumberOfParallelSubtasks: %s", 95 | this.getRuntimeContext().getNumberOfParallelSubtasks())); 96 | if(model1!=null) 97 | model1.release(); 98 | if(model2!=null) 99 | model2.release(); 100 | if(model3!=null) 101 | model3.release(); 102 | } 103 | 104 | } 105 | 106 | -------------------------------------------------------------------------------- /src/main/java/sky/tf/threemodels/ImageModelLoader3Models.java: -------------------------------------------------------------------------------- 1 | package sky.tf.threemodels; 2 | 3 | import com.google.common.io.Files; 4 | import com.intel.analytics.zoo.pipeline.inference.OpenVinoInferenceSupportive$; 5 | import org.apache.commons.io.FileUtils; 6 | import org.apache.flink.api.java.tuple.Tuple2; 7 | import org.apache.hadoop.fs.*; 8 | import org.apache.hadoop.hdfs.DistributedFileSystem; 9 | import org.slf4j.Logger; 10 | import org.slf4j.LoggerFactory; 11 | import sky.tf.GarbageClassificationModel; 12 | import sky.tf.ModelParams; 13 | 14 | import java.io.*; 15 | import java.net.URI; 16 | import java.nio.channels.Channels; 17 | import java.nio.channels.FileChannel; 18 | import java.nio.channels.ReadableByteChannel; 19 | import java.util.ArrayList; 20 | import java.util.List; 21 | 22 | 23 | /** 24 | * @author SkyPeace 25 | * The model loader 26 | */ 27 | public class ImageModelLoader3Models { 28 | private Logger logger = LoggerFactory.getLogger(ImageModelLoader3Models.class); 29 | private static ImageModelLoader3Models instance = new ImageModelLoader3Models(); 30 | private volatile GarbageClassificationModel model1; 31 | private volatile GarbageClassificationModel model2; 32 | 
private volatile GarbageClassificationModel model3; 33 | 34 | private ImageModelLoader3Models() {} 35 | 36 | public static ImageModelLoader3Models getInstance() 37 | { 38 | return instance; 39 | } 40 | 41 | /** 42 | * Generate optimized OpenVino model's data (bytes). -- First step. 43 | * @param savedModelPath 44 | * @return 45 | */ 46 | public List> generateOpenVinoModelData(String savedModelPath, ModelParams modelParams) throws Exception 47 | { 48 | List> modelData = new ArrayList>(); 49 | File modelFile = new File(savedModelPath); 50 | String optimizeModelPath = null; 51 | if(modelFile.isDirectory()) 52 | optimizeModelPath = 53 | optimizeModelFromModelDir(savedModelPath, modelParams); 54 | else 55 | optimizeModelPath = 56 | optimizeModelFromModelPackage(savedModelPath, modelParams); 57 | byte[] xml = readDFSFile(optimizeModelPath + File.separator + "saved_model.xml"); 58 | byte[] bin = readDFSFile(optimizeModelPath + File.separator + "saved_model.bin"); 59 | logger.info("Size of optimized saved_model.xml: " + xml.length); 60 | logger.info("Size of optimized saved_model.bin: " + bin.length); 61 | Tuple2 xmlData = new Tuple2("xml", xml); 62 | Tuple2 binData = new Tuple2("bin", bin); 63 | modelData.add(xmlData); 64 | modelData.add(binData); 65 | return modelData; 66 | } 67 | 68 | /** 69 | * Optimize model from saved model dir. Return optimized model temp directory. 
70 | * @param savedModelDir 71 | * @return 72 | * @throws Exception 73 | */ 74 | private String optimizeModelFromModelDir(String savedModelDir, ModelParams modelParams) throws Exception 75 | { 76 | File tmpDir = Files.createTempDir(); 77 | String optimizeModelTmpDir = tmpDir.getCanonicalPath(); 78 | //OpenVinoInferenceSupportive.optimizeTFImageClassificationModel(); 79 | OpenVinoInferenceSupportive$.MODULE$.optimizeTFImageClassificationModel( 80 | savedModelDir + File.separator + "SavedModel", modelParams.getInputShape(), false, 81 | modelParams.getMeanValues(), modelParams.getScale(), modelParams.getInputName(), optimizeModelTmpDir); 82 | return optimizeModelTmpDir; 83 | } 84 | 85 | /** 86 | * Optimize model from saved model package (tar.gz file). Return optimized model temp dir. 87 | * @param savedModelPackagePath path/URI of the package; it is read completely into memory via readDFSFile, written to a local temp dir and unpacked there with the external "tar" command 88 | * @return the temp dir produced by optimizeModelFromModelDir for the first entry of the unpacked package 89 | * @throws Exception if reading/copying fails, tar exits with a non-zero code, or the package unpacks to nothing 90 | */ 91 | private String optimizeModelFromModelPackage(String savedModelPackagePath, ModelParams modelParams) throws Exception 92 | { 93 | byte[] savedModelBytes = readDFSFile(savedModelPackagePath); 94 | File tmpDir = Files.createTempDir(); 95 | String tempDirPath = tmpDir.getCanonicalPath(); 96 | String tarFileName = "saved-model.tar"; 97 | File tarFile = new File(tempDirPath + File.separator + tarFileName); 98 | ByteArrayInputStream tarFileInputStream = new ByteArrayInputStream(savedModelBytes); 99 | //try-with-resources: both channels (and the FileOutputStream behind tarFileDest) are closed even if the copy fails. 100 | try (ReadableByteChannel tarFileSrc = Channels.newChannel(tarFileInputStream); FileChannel tarFileDest = (new FileOutputStream(tarFile)).getChannel()) { 101 | tarFileDest.transferFrom(tarFileSrc, 0L, Long.MAX_VALUE); 102 | } 103 | 104 | String tarFileAbsolutePath = tarFile.getAbsolutePath(); 105 | String modelRootDir = tempDirPath + File.separator + "saved-model"; 106 | File modelRootDirFile = new File(modelRootDir); 107 | FileUtils.forceMkdir(modelRootDirFile); 108 | //tar -xvf -C
109 | //Merge stderr into stdout and drain it BEFORE waitFor(): "-v" prints one line per extracted entry, and a full pipe buffer would block tar (and this thread) forever. 110 | //Runtime.getRuntime().exec(new String[]{"ls", "-l", modelRootDir}); 111 | Process proc = new ProcessBuilder("tar", "-xvf", tarFileAbsolutePath, "-C", modelRootDir).redirectErrorStream(true).start(); 112 | StringBuilder tarOutput = new StringBuilder(); 113 | try (BufferedReader outputReader = new BufferedReader(new InputStreamReader(proc.getInputStream()))) { 114 | for (String outputLine = outputReader.readLine(); outputLine != null; outputLine = outputReader.readLine()) 115 | tarOutput.append(outputLine).append('\n'); 116 | } 117 | int tarExitCode = proc.waitFor(); if(tarExitCode != 0) throw new IOException("tar extraction FAILED (exit code " + tarExitCode + "): " + tarOutput); 118 | File[] files = modelRootDirFile.listFiles(); 119 | if(files == null || files.length == 0) throw new IOException("Nothing was extracted from " + tarFileAbsolutePath + " into " + modelRootDir); String savedModelTmpDir = files[0].getAbsolutePath(); 120 | logger.info("Saved model temp dir will be used for optimization: " + savedModelTmpDir); 121 | String optimizeModelTmpDir = optimizeModelFromModelDir(savedModelTmpDir, modelParams); 122 | return optimizeModelTmpDir; 123 | } 124 | 125 | /** 126 | * Save optimized OpenVino model's data (bytes) into HDFS files 127 | * The specified dir is also the parent dir of saved model package. 128 | * @param openVinoModelData 129 | * @param modelParams 130 | * @return 131 | */ 132 | public void saveToOpenVinoModelFile(List> openVinoModelData, ModelParams modelParams) throws Exception 133 | { 134 | writeOptimizedModelToDFS(openVinoModelData, modelParams); 135 | } 136 | 137 | /** 138 | * Load OpenVino model with singleton pattern. -- Second step. 139 | * In this case, use this solution by default because it only cost about 2 seconds to load in Map.open().
140 | * @param modelParams 141 | * @return 142 | */ 143 | public synchronized GarbageClassificationModel loadOpenVINOModel1Once(ModelParams modelParams) throws Exception 144 | { 145 | if(this.model1 == null) 146 | { 147 | this.model1 = this.getModel(modelParams); 148 | } 149 | this.model1.addRefernce(); 150 | return this.model1; 151 | } 152 | 153 | public synchronized GarbageClassificationModel loadOpenVINOModel2Once(ModelParams modelParams) throws Exception 154 | { 155 | if(this.model2 == null) 156 | { 157 | this.model2 = this.getModel(modelParams); 158 | } 159 | this.model2.addRefernce(); 160 | return this.model2; 161 | } 162 | 163 | public synchronized GarbageClassificationModel loadOpenVINOModel3Once(ModelParams modelParams) throws Exception 164 | { 165 | if(this.model3 == null) 166 | { 167 | this.model3 = this.getModel(modelParams); 168 | } 169 | this.model3.addRefernce(); 170 | return this.model3; 171 | } 172 | 173 | private GarbageClassificationModel getModel(ModelParams modelParams) throws Exception 174 | { 175 | List> openVinoModelData = 176 | getModelDataFromOptimizedModelDir(modelParams); 177 | byte[] modelXml = openVinoModelData.get(0).f1; 178 | byte[] modelBin = openVinoModelData.get(1).f1; 179 | return new GarbageClassificationModel(modelXml, modelBin); 180 | } 181 | 182 | /** 183 | * Get OpenVino model data from optimized model files 184 | * @param modelParams 185 | * @return 186 | */ 187 | private List> getModelDataFromOptimizedModelDir(ModelParams modelParams) throws Exception 188 | { 189 | String optimizedModelDir = modelParams.getOptimizedModelDir(); 190 | String modelName = modelParams.getModelName(); 191 | List> modelData = new ArrayList>(); 192 | byte[] xml = readDFSFile(optimizedModelDir + File.separator + "optimized_openvino_" + modelName + "_KEEPME.xml"); 193 | byte[] bin = readDFSFile(optimizedModelDir + File.separator + "optimized_openvino_" + modelName + "_KEEPME.bin"); 194 | logger.info("Size of optimized_openvino_" + modelName +
"_KEEPME.xml: " + xml.length); 195 | logger.info("Size of optimized_openvino_" + modelName + "_KEEPME.bin: " + bin.length); 196 | Tuple2 xmlData = new Tuple2("xml", xml); 197 | Tuple2 binData = new Tuple2("bin", bin); 198 | modelData.add(xmlData); 199 | modelData.add(binData); 200 | return modelData; 201 | } 202 | 203 | /** 204 | * Read a whole file into a byte array. HDFS and local files are supported; the FileSystem implementation is picked from the URI of filePath. 205 | * @param filePath path/URI of the file to read 206 | * @return the complete file content 207 | * @throws Exception any failure (logged, then re-thrown); an IOException if the file is too large for a single byte array */ 208 | private byte[] readDFSFile(String filePath) throws Exception 209 | { 210 | try { 211 | long beginTime = System.currentTimeMillis(); 212 | System.out.println(String.format("READ_MODEL_FILE from %s BEGIN %s", filePath, beginTime)); 213 | Path imageRoot = new Path(filePath); 214 | org.apache.hadoop.conf.Configuration hadoopConfig = new org.apache.hadoop.conf.Configuration(); 215 | hadoopConfig.set("fs.hdfs.impl", DistributedFileSystem.class.getName()); 216 | hadoopConfig.set("fs.file.impl", LocalFileSystem.class.getName()); 217 | //newInstance, not get(): get() returns a JVM-wide cached instance, and closing that instance breaks every other user of the cache. 218 | try (FileSystem fileSystem = FileSystem.newInstance(new URI(filePath), hadoopConfig)) { 219 | long fileLength = fileSystem.getFileStatus(imageRoot).getLen(); 220 | if (fileLength > Integer.MAX_VALUE) throw new IOException(String.format("File %s is too large to read into a byte array (%s bytes)", filePath, fileLength)); 221 | byte[] buffer = new byte[(int) fileLength]; 222 | try (FSDataInputStream in = fileSystem.open(imageRoot)) { in.readFully(buffer); } 223 | //try-with-resources closes the stream and the FileSystem even when reading fails. 224 | long endTime = System.currentTimeMillis(); 225 | System.out.println(String.format("READ_MODEL_FILE from %s END %s (Cost: %s)", 226 | filePath, endTime, (endTime - beginTime))); 227 | return buffer; 228 | } 229 | 230 | }catch(Exception ex) 231 | { 232 | String msg = "Read DFS file FAILED. " + ex.getMessage(); 233 | logger.error(msg, ex); 234 | throw ex; 235 | } 236 | } 237 |
238 | /** 239 | * Write the optimized model's data to files (HDFS or local, picked from the URI of modelParams.getOptimizedModelDir()). 240 | * @param openVinoModelData entry 0 holds the xml bytes, entry 1 the bin bytes; they are written to optimized_openvino_{modelName}_KEEPME.xml / .bin, overwriting existing files 241 | * @param modelParams supplies the target dir (getOptimizedModelDir) and the model name (getModelName) 242 | */ 243 | private void writeOptimizedModelToDFS(List> openVinoModelData, ModelParams modelParams) throws Exception 244 | { 245 | String optimizedModelDir = modelParams.getOptimizedModelDir(); 246 | String modelName = modelParams.getModelName(); 247 | try { 248 | long beginTime = System.currentTimeMillis(); 249 | System.out.println(String.format("WRITE_OPTIMIZED_MODEL_FILE %s BEGIN %s", optimizedModelDir, beginTime)); 250 | org.apache.hadoop.conf.Configuration hadoopConfig = new org.apache.hadoop.conf.Configuration(); 251 | hadoopConfig.set("fs.hdfs.impl", DistributedFileSystem.class.getName()); 252 | hadoopConfig.set("fs.file.impl", LocalFileSystem.class.getName()); 253 | try (FileSystem fileSystem = FileSystem.newInstance(new URI(optimizedModelDir), hadoopConfig)) { //private (uncached) instance, so closing it cannot break other FileSystem.get() users 254 | 255 | String xmlFileName = optimizedModelDir + File.separator + "optimized_openvino_"+ modelName + "_KEEPME.xml"; 256 | System.out.println(String.format("Model xmlFileName: %s", xmlFileName)); 257 | Path xmlFilePath = new Path(xmlFileName); 258 | try (FSDataOutputStream xmlFileOut = fileSystem.create(xmlFilePath, true)) { 259 | xmlFileOut.write(openVinoModelData.get(0).f1); 260 | xmlFileOut.flush(); 261 | } 262 | 263 | String binFileName = optimizedModelDir + File.separator + "optimized_openvino_"+ modelName + "_KEEPME.bin"; 264 | System.out.println(String.format("Model binFileName: %s", binFileName)); 265 | Path binFilePath = new Path(binFileName); 266 | try (FSDataOutputStream binFileOut = fileSystem.create(binFilePath, true)) { 267 | binFileOut.write(openVinoModelData.get(1).f1); 268 | binFileOut.flush(); 269 | } 270 | 271 | long endTime = System.currentTimeMillis(); 272 | System.out.println(String.format("WRITE_OPTIMIZED_MODEL_FILE %s END %s (Cost: %s)", 273 | optimizedModelDir, endTime, (endTime - beginTime)));
274 | } 275 | }catch(Exception ex) 276 | { 277 | String msg = "Write DFS file FAILED. " + ex.getMessage(); 278 | logger.error(msg, ex); 279 | throw ex; 280 | } 281 | } 282 | 283 | } 284 | -------------------------------------------------------------------------------- /src/main/java/sky/tf/threemodels/ImagePredictSupportive3Models.java: -------------------------------------------------------------------------------- 1 | package sky.tf.threemodels; 2 | 3 | import com.alibaba.tianchi.garbage_image_util.ImageData; 4 | import com.intel.analytics.bigdl.opencv.OpenCV; 5 | import com.intel.analytics.zoo.pipeline.inference.JTensor; 6 | import org.opencv.core.*; 7 | import org.slf4j.Logger; 8 | import org.slf4j.LoggerFactory; 9 | import sky.tf.*; 10 | 11 | import java.util.*; 12 | 13 | /** 14 | * @author SkyPeace 15 | * The supportive for image classification prediction. 16 | */ 17 | public class ImagePredictSupportive3Models 18 | { 19 | private Logger logger = LoggerFactory.getLogger(ImagePredictSupportive3Models.class); 20 | final long MAX_PROCESSING_TIME = 400; //ms.
400 21 | 22 | private GarbageClassificationModel model1; 23 | private ModelParams model1Params; 24 | private GarbageClassificationModel model2; 25 | private ModelParams model2Params; 26 | private GarbageClassificationModel model3; 27 | private ModelParams model3Params; 28 | 29 | private float preferableThreshold1 = 0.94f; 30 | private float preferableThreshold2 = 0.81f; 31 | 32 | public ImagePredictSupportive3Models(GarbageClassificationModel model1, ModelParams model1Params, 33 | GarbageClassificationModel model2, ModelParams model2Params, 34 | GarbageClassificationModel model3, ModelParams model3Params, 35 | float preferableThreshold1, float preferableThreshold2) 36 | { 37 | this.model1 = model1; 38 | this.model1Params = model1Params; 39 | this.model2 = model2; 40 | this.model2Params = model2Params; 41 | this.model3 = model3; 42 | this.model3Params = model3Params; 43 | this.preferableThreshold1 = preferableThreshold1; 44 | this.preferableThreshold2 = preferableThreshold2; 45 | } 46 | 47 | /** 48 | * Do predict and transfer the label ID to human string. 49 | * @param imageData 50 | * @return 51 | * @throws Exception 52 | */ 53 | public String predictHumanString(ImageData imageData) throws Exception 54 | { 55 | Integer predictId = predict(imageData); 56 | String humanString = this.getImageHumanString(predictId); 57 | if (humanString == null) { 58 | humanString = "class_index_not_available"; 59 | throw new Exception(humanString); 60 | } 61 | return humanString; 62 | } 63 | 64 | /** 65 | * Do predict using 3 different models 66 | * @param imageData 67 | * @return 68 | * @throws Exception 69 | */ 70 | public Integer predict(ImageData imageData) throws Exception { 71 | 72 | long processBeginTime = System.currentTimeMillis(); 73 | //Decode jpeg 74 | //Use turbojpeg as jpeg decoder. It is more fast than OpenCV decoder (OpenCV use libjpeg as decoder). 75 | //Refer to https://libjpeg-turbo.org/About/Performance for performance comparison. 
76 | Mat matRGB = ImageDataPreprocessing.getRGBMat(imageData, ImageDataPreprocessing.IMAGE_DECODER_TURBOJPEG); 77 | 78 | //Model1 do prediction 79 | ImageDataPreprocessing dataPreprocessing = new ImageDataPreprocessing(model1Params); 80 | List> inputs = dataPreprocessing.doPreProcessing(matRGB, ImageDataPreprocessing.PREPROCESSING_VGG, false); 81 | long t1 = System.currentTimeMillis(); 82 | System.out.println(String.format("%s MODEL1_DO_PREDICT BEGIN %s", "###", t1)); 83 | List> tensorResults = this.model1.predict(inputs); 84 | if (tensorResults == null && tensorResults.get(0) == null || tensorResults.get(0).size() == 0) { 85 | throw new Exception(String.format("ERROR: %s Model1 predict result is null.", imageData.getId())); 86 | } 87 | long t2 = System.currentTimeMillis(); 88 | System.out.println(String.format("%s MODEL1_DO_PREDICT END %s (Cost: %s)", "###", t2, (t2 - t1))); 89 | 90 | List primaryResults = this.convertToPredictionResults(tensorResults); 91 | PredictionResult primaryResult = this.getSingleResult(primaryResults); 92 | Integer primaryPredictionId = primaryResult.getPredictionId(); 93 | 94 | if (primaryResult.getProbability() >= preferableThreshold1) { 95 | System.out.println("Model1's HP predictionId saved time."); 96 | return primaryPredictionId; 97 | } 98 | 99 | long timeUsed = System.currentTimeMillis() - processBeginTime; 100 | if(timeUsed>=MAX_PROCESSING_TIME) 101 | { 102 | System.out.println("There maybe no enough time to try multiple modles to do predict."); 103 | return primaryPredictionId; 104 | } 105 | 106 | System.out.println("There is enough time to try multiple modles to do predict."); 107 | 108 | //Use model2 do prediction 109 | dataPreprocessing = new ImageDataPreprocessing(model2Params); 110 | List> secondaryInputs = dataPreprocessing.doPreProcessing(matRGB, ImageDataPreprocessing.PREPROCESSING_INCEPTION, false); 111 | t1 = System.currentTimeMillis(); 112 | System.out.println(String.format("%s MODEL2_DO_PREDICT BEGIN %s", "###", t1)); 
113 | List> secondaryTensorResults = this.model2.predict(secondaryInputs); 114 | if (secondaryTensorResults == null && secondaryTensorResults.get(0) == null || secondaryTensorResults.get(0).size() == 0) { 115 | throw new Exception(String.format("ERROR: %s Model2 predict result is null.", imageData.getId())); 116 | } 117 | t2 = System.currentTimeMillis(); 118 | System.out.println(String.format("%s MODEL2_DO_PREDICT END %s (Cost: %s)", "###", t2, (t2 - t1))); 119 | 120 | List secondaryResults = this.convertToPredictionResults(secondaryTensorResults); 121 | PredictionResult secondaryResult = this.getSingleResult(secondaryResults); 122 | Integer secondaryPredictionId = secondaryResult.getPredictionId(); 123 | 124 | if (secondaryResult.getProbability() >= preferableThreshold1) 125 | { 126 | if (!secondaryPredictionId.equals(primaryPredictionId)) { 127 | System.out.println("Force use model2's HP predictionId."); 128 | } else { 129 | System.out.println("Model2's HP predictionId saved time."); 130 | } 131 | return secondaryPredictionId; 132 | } 133 | 134 | //Use model3 do prediction 135 | dataPreprocessing = new ImageDataPreprocessing(model3Params); 136 | List> thirdInputs = dataPreprocessing.doPreProcessing(matRGB, ImageDataPreprocessing.PREPROCESSING_INCEPTION, false); 137 | t1 = System.currentTimeMillis(); 138 | System.out.println(String.format("%s MODEL3_DO_PREDICT BEGIN %s", "###", t1)); 139 | List> thirdTensorResults = this.model3.predict(thirdInputs); 140 | if (thirdTensorResults == null && thirdTensorResults.get(0) == null || thirdTensorResults.get(0).size() == 0) { 141 | throw new Exception(String.format("ERROR: %s Model3 predict result is null.", imageData.getId())); 142 | } 143 | t2 = System.currentTimeMillis(); 144 | System.out.println(String.format("%s MODEL3_DO_PREDICT END %s (Cost: %s)", "###", t2, (t2 - t1))); 145 | List thirdResults = this.convertToPredictionResults(thirdTensorResults); 146 | PredictionResult thirdResult = 
this.getSingleResult(thirdResults); 147 | Integer thirdPredictionId = thirdResult.getPredictionId(); 148 | 149 | /* 150 | if (thirdResult.getProbability() >= preferableThreshold1) 151 | { 152 | if (!thirdPredictionId.equals(primaryPredictionId)) { 153 | System.out.println("Force use model3's HP predictionId."); 154 | } else { 155 | System.out.println("Model3's HP predictionId saved time."); 156 | } 157 | return thirdPredictionId; 158 | } 159 | */ 160 | 161 | if(thirdPredictionId.equals(primaryPredictionId)&&thirdPredictionId.equals(secondaryPredictionId)) 162 | { 163 | System.out.println("Model1, model2 and model3 get agreement."); 164 | return primaryPredictionId; 165 | } 166 | if (primaryPredictionId.equals(secondaryPredictionId)) { 167 | System.out.println("Model1 and model2 get agreement."); 168 | return primaryPredictionId; 169 | } 170 | if (primaryPredictionId.equals(thirdPredictionId)) { 171 | System.out.println("Model1 and model3 get agreement."); 172 | return primaryPredictionId; 173 | } 174 | if (thirdPredictionId.equals(secondaryPredictionId)) { 175 | System.out.println("Model3 and model2 get agreement."); 176 | return secondaryPredictionId; 177 | } 178 | System.out.println("Model1, model2 and model3 does NOT get any agreement."); 179 | 180 | if (secondaryResult.getProbability() >= preferableThreshold2) 181 | { 182 | System.out.println("Model2(HP) is the last choice."); 183 | return secondaryPredictionId; 184 | } 185 | if(thirdResult.getProbability() >= preferableThreshold2){ 186 | System.out.println("Model3(LP) is the last choice."); 187 | return thirdPredictionId; 188 | } 189 | 190 | System.out.println("Model1(LP) is the last choice."); 191 | return primaryPredictionId; 192 | } 193 | 194 | /** 195 | * Get single result 196 | * @param pResults 197 | * @return 198 | */ 199 | private PredictionResult getSingleResult(List pResults) 200 | { 201 | if(pResults.size()==1) 202 | return pResults.get(0); 203 | System.out.println(String.format("There are %s 
multiple results", pResults.size())); 204 | Map distinctResults = new HashMap(); 205 | for(PredictionResult pResult:pResults) 206 | { 207 | Integer predictionId = pResult.getPredictionId(); 208 | PredictionResult existResult = distinctResults.get(predictionId); 209 | if(existResult == null){ 210 | distinctResults.put(predictionId, pResult); 211 | }else{ 212 | existResult.setCount(existResult.getCount() + 1); 213 | existResult.setProbability( Math.max(pResult.getProbability(), existResult.getProbability())); 214 | distinctResults.put(predictionId, existResult); 215 | } 216 | } 217 | PredictionResult maxResult = null; 218 | for(Integer key:distinctResults.keySet()) 219 | { 220 | PredictionResult pResult = distinctResults.get(key); 221 | float ratio = (float)pResult.getCount() / pResults.size(); 222 | if(ratio>0.5f) 223 | { 224 | System.out.println("Majority win."); 225 | maxResult = pResult; 226 | break; 227 | } 228 | } 229 | if(maxResult==null) 230 | maxResult = pResults.get(0); 231 | return maxResult; 232 | } 233 | 234 | /** 235 | * Convert JTensor list to prediciton results. 
236 | * @param result 237 | * @return 238 | * @throws Exception 239 | */ 240 | private List convertToPredictionResults(List> result) throws Exception 241 | { 242 | long beginTime = System.currentTimeMillis(); 243 | List pResults = new ArrayList(); 244 | System.out.println("result.size(): " + result.size()); 245 | System.out.println("result.get(0).size(): " + result.get(0).size()); 246 | for(int j=0; j 0) { 250 | float maxProbability = predictData[0]; 251 | int maxNo = 0; 252 | for (int i = 1; i < predictData.length; i++) { 253 | if (predictData[i] > maxProbability) { 254 | maxProbability = predictData[i]; 255 | maxNo = i; 256 | } 257 | } 258 | PredictionResult pResult = new PredictionResult(); 259 | pResult.setPredictionId(maxNo); 260 | pResult.setProbability(maxProbability); 261 | pResult.setCount(1); 262 | pResults.add(pResult); 263 | }else{ 264 | throw new Exception("ERROR: predictData.length=0"); 265 | } 266 | } 267 | long endTime = System.currentTimeMillis(); 268 | if((endTime-beginTime)>5) 269 | System.out.println(String.format("%s CONVERT_TO_PREDICTION_RESULT END %s (Cost: %s)", 270 | "###", endTime, (endTime - beginTime))); 271 | return pResults; 272 | } 273 | 274 | /** 275 | * Load turbojpeg library. 276 | * @throws Exception 277 | */ 278 | private void loadTurboJpeg() 279 | { 280 | long beginTime = System.currentTimeMillis(); 281 | if(!TurboJpegLoader.isTurbojpegLoaded()) { 282 | throw new RuntimeException("LOAD_TURBOJPEG library failed. Please check."); 283 | } 284 | long endTime = System.currentTimeMillis(); 285 | if((endTime - beginTime) > 1) 286 | System.out.println(String.format("LOAD_TURBOJPEG END %s (Cost: %s)", endTime, (endTime - beginTime))); 287 | } 288 | 289 | /** 290 | * Load OpenCV library. 291 | * @throws Exception 292 | */ 293 | private void loadOpenCV() 294 | { 295 | long beginTime = System.currentTimeMillis(); 296 | if(!OpenCV.isOpenCVLoaded()) { 297 | throw new RuntimeException("LOAD_OPENCV library failed. 
Please check."); 298 | } 299 | long endTime = System.currentTimeMillis(); 300 | if((endTime - beginTime) > 10) 301 | System.out.println(String.format("LOAD_OPENCV END %s (Cost: %s)", endTime, (endTime - beginTime))); 302 | 303 | //if(!OpenvinoNativeLoader.load()) 304 | // throw new RuntimeException("LOAD_Openvino library failed. Please check."); 305 | //if(!TFNetNative.isLoaded()) 306 | // throw new RuntimeException("LOAD_TFNetNative library failed. Please check."); 307 | } 308 | 309 | /** 310 | * Get image class human string 311 | * @return 312 | */ 313 | private String getImageHumanString(Integer id) 314 | { 315 | return ImageClassIndex.getInsatnce().getImageHumanstring(String.valueOf(id)); 316 | } 317 | 318 | /** 319 | * For first time warm-dummy check only. 320 | * @throws Exception 321 | */ 322 | public void firstTimeDummyCheck() throws Exception 323 | { 324 | this.checkModel(model1, model1Params); 325 | if(this.model2!=null) 326 | this.checkModel(model2, model2Params); 327 | if(this.model3!=null) 328 | this.checkModel(model3, model3Params); 329 | this.getImageHumanString(-1); 330 | this.loadOpenCV(); 331 | this.loadTurboJpeg(); 332 | } 333 | 334 | /** 335 | * For first time warm-dummy check only. 336 | * @param pModel 337 | * @param pModelParams 338 | * @throws Exception 339 | */ 340 | private void checkModel(GarbageClassificationModel pModel, ModelParams pModelParams) throws Exception 341 | { 342 | if(pModel == null) { 343 | throw new Exception("ERROR: The model is null. 
Aborted the predict."); 344 | } 345 | long beginTime = System.currentTimeMillis(); 346 | System.out.println(String.format("FIRST_PREDICT BEGIN %s", beginTime)); 347 | 348 | JTensor tensor = new JTensor(); 349 | tensor.setData(new float[pModelParams.getInputSize()]); 350 | tensor.setShape(pModelParams.getInputShape()); 351 | List list = new ArrayList(); 352 | list.add(tensor); 353 | List> inputs = new ArrayList>(); 354 | inputs.add(list); 355 | pModel.predict(inputs); 356 | 357 | long endTime = System.currentTimeMillis(); 358 | System.out.println(String.format("FIRST_PREDICT END %s (Cost: %s)", endTime, (endTime - beginTime))); 359 | } 360 | 361 | } 362 | -------------------------------------------------------------------------------- /src/main/java/sky/tf/threemodels/OpenVinoModelGenerator3Models.java: -------------------------------------------------------------------------------- 1 | package sky.tf.threemodels; 2 | 3 | import com.alibaba.tianchi.garbage_image_util.ConfigConstant; 4 | import org.apache.flink.api.java.tuple.Tuple2; 5 | import sky.tf.ModelParams; 6 | 7 | import java.io.File; 8 | import java.util.List; 9 | 10 | /** 11 | * The class use for generate optimized OpenVino models (Three type of models for prediction) 12 | * @author SkyPeace 13 | */ 14 | public class OpenVinoModelGenerator3Models { 15 | private String imageModelPath = System.getenv(ConfigConstant.IMAGE_MODEL_PATH); 16 | private String imageModelPackagePath = System.getenv(ConfigConstant.IMAGE_MODEL_PACKAGE_PATH); 17 | private String modelInferencePath = System.getenv("MODEL_INFERENCE_PATH"); 18 | private ModelParams model1Params; 19 | private ModelParams model2Params; 20 | private ModelParams model3Params; 21 | 22 | public void execute() 23 | { 24 | if (null == this.modelInferencePath) { 25 | throw new RuntimeException("ImageFlatMap(): Not set MODEL_INFERENCE_PATH environmental variable"); 26 | } 27 | if (null == this.imageModelPackagePath) { 28 | throw new RuntimeException("ImageFlatMap(): 
Not set imageModelPathPackagePath environmental variable"); 29 | } 30 | String imageModelPackageDir = 31 | imageModelPackagePath.substring(0, imageModelPackagePath.lastIndexOf(File.separator)); 32 | if(!imageModelPackageDir.equalsIgnoreCase(modelInferencePath)) 33 | { 34 | System.out.println("WARN: modelInferencePath NOT EQUAL imageModelPathPackageDir"); 35 | System.out.println(String.format("modelInferencePath: %s", modelInferencePath)); 36 | System.out.println(String.format("imageModelPackageDir: %s", imageModelPackageDir)); 37 | System.out.println(String.format("imageModelPath: %s", imageModelPath)); 38 | } 39 | 40 | try { 41 | //Generate optimized OpenVino model1. resnet_v1_101 42 | String tfModel1Path = modelInferencePath + File.separator + "SavedModel/model1"; 43 | String optimizedOpenVinoModel1Dir = imageModelPackageDir; 44 | model1Params = new ModelParams(); 45 | model1Params.setModelType("resnet"); 46 | model1Params.setModelName("model1"); 47 | model1Params.setInputName("input_1"); 48 | model1Params.setInputShape(new int[]{1, 224, 224, 3}); 49 | model1Params.setMeanValues(new float[]{123.68f,116.78f,103.94f}); 50 | model1Params.setScale(1.0f); 51 | model1Params.setOptimizedModelDir(optimizedOpenVinoModel1Dir); 52 | //Call Zoo API generate optimized OpenVino Model 53 | List> optimizedOpenVinoModelData = 54 | ImageModelLoader3Models.getInstance().generateOpenVinoModelData(tfModel1Path, model1Params); 55 | //Write optimized model's bytes into files (HDFS). The optimized model files parent dir is same as TF model package. 56 | ImageModelLoader3Models.getInstance(). 57 | saveToOpenVinoModelFile(optimizedOpenVinoModelData, model1Params); 58 | 59 | //Generate optimized OpenVino model2. 
inception_v4 60 | String tfModel2Path = modelInferencePath + File.separator + "SavedModel/model2"; 61 | String optimizedOpenVinoMode2Dir = imageModelPackageDir; 62 | model2Params = new ModelParams(); 63 | model2Params.setModelType("inception"); 64 | model2Params.setModelName("model2"); 65 | model2Params.setInputName("input_1"); 66 | model2Params.setInputShape(new int[]{1, 299, 299, 3}); 67 | model2Params.setMeanValues(new float[]{127.5f,127.5f,127.5f}); 68 | model2Params.setScale(127.5f); 69 | model2Params.setOptimizedModelDir(optimizedOpenVinoMode2Dir); 70 | //Call Zoo API generate optimized OpenVino Model 71 | optimizedOpenVinoModelData = 72 | ImageModelLoader3Models.getInstance().generateOpenVinoModelData(tfModel2Path, model2Params); 73 | ImageModelLoader3Models.getInstance(). 74 | saveToOpenVinoModelFile(optimizedOpenVinoModelData, model2Params); 75 | 76 | //Generate optimized OpenVino model3. inception_v3 77 | String tfModel3Path = this.modelInferencePath + File.separator + "SavedModel/model3"; 78 | String optimizedOpenVinoMode3Dir = imageModelPackageDir; 79 | model3Params = new ModelParams(); 80 | model3Params.setModelType("inception"); 81 | model3Params.setModelName("model3"); 82 | model3Params.setInputName("input_1"); 83 | model3Params.setInputShape(new int[]{1, 299, 299, 3}); 84 | model3Params.setMeanValues(new float[]{127.5f,127.5f,127.5f}); 85 | model3Params.setScale(127.5f); 86 | model3Params.setOptimizedModelDir(optimizedOpenVinoMode3Dir); 87 | //Call Zoo API generate optimized OpenVino Model 88 | optimizedOpenVinoModelData = 89 | ImageModelLoader3Models.getInstance().generateOpenVinoModelData(tfModel3Path, model3Params); 90 | ImageModelLoader3Models.getInstance(). 
91 | saveToOpenVinoModelFile(optimizedOpenVinoModelData, model3Params); 92 | 93 | }catch(Exception ex) 94 | { 95 | ex.printStackTrace(); 96 | throw new RuntimeException("WARN: OpenVinoModelGenerator.execute() FAILED.", ex); //keep the original exception as the cause so callers and logs can see why generation failed 97 | } 98 | } 99 | 100 | /** @return the parameters of model1 configured by execute(); null if execute() has not assigned them yet */ public ModelParams getModel1Params() 101 | { 102 | return model1Params; 103 | } 104 | 105 | /** @return the parameters of model2 configured by execute(); null if execute() has not assigned them yet */ public ModelParams getModel2Params() 106 | { 107 | return model2Params; 108 | } 109 | 110 | /** @return the parameters of model3 configured by execute(); null if execute() has not assigned them yet */ public ModelParams getModel3Params() 111 | { 112 | return model3Params; 113 | } 114 | } 115 | -------------------------------------------------------------------------------- /src/main/java/sky/tf/threemodels/ValidationDatasetAnalyzer.java: -------------------------------------------------------------------------------- 1 | package sky.tf.threemodels; 2 | 3 | import sky.tf.PredictionResult; 4 | 5 | import java.io.BufferedReader; 6 | import java.io.File; 7 | import java.io.FileReader; 8 | import java.util.ArrayList; 9 | import java.util.List; 10 | 11 | /** 12 | * Experimental. Analyzes validation data to search for preferable probability thresholds used to choose the prediction result.
13 | */ 14 | public class ValidationDatasetAnalyzer { 15 | private float preferableThreshold1 = 0.94f; 16 | private float preferableThreshold2 = 0.81f; 17 | 18 | public void searchOptimizedThreshold(String validationPredictionDir) 19 | { 20 | String model1ValidationFilePath = validationPredictionDir + File.separator + "model1_validation_record.txt"; 21 | String model2ValidationFilePath = validationPredictionDir + File.separator + "model2_validation_record.txt"; 22 | String model3ValidationFilePath = validationPredictionDir + File.separator + "model3_validation_record.txt"; 23 | List labelList = getValidationLabels(model1ValidationFilePath); 24 | List model1Results = getPredictionResults(model1ValidationFilePath); 25 | List model2Results = getPredictionResults(model2ValidationFilePath); 26 | List model3Results = getPredictionResults(model3ValidationFilePath); 27 | if(labelList.size()==0) { 28 | System.out.println("WARN: Validation labels count is 0."); 29 | return; 30 | } 31 | int model1Success=0; 32 | int model2Success=0; 33 | int model3Success=0; 34 | for (int i = 0; i < labelList.size(); i++) { 35 | if (labelList.get(i).equals(model1Results.get(i).getPredictionId())) { 36 | model1Success++; 37 | } 38 | if (labelList.get(i).equals(model2Results.get(i).getPredictionId())) { 39 | model2Success++; 40 | } 41 | if (labelList.get(i).equals(model3Results.get(i).getPredictionId())) { 42 | model3Success++; 43 | } 44 | } 45 | float model1SucessRatio = (float) model1Success / (float)labelList.size(); 46 | float model2SucessRatio = (float) model2Success / (float)labelList.size(); 47 | float model3SucessRatio = (float) model3Success / (float)labelList.size(); 48 | System.out.println("The model1 success ratio: " + model1SucessRatio); 49 | System.out.println("The model2 success ratio: " + model2SucessRatio); 50 | System.out.println("The model3 success ratio: " + model3SucessRatio); 51 | 52 | float totalRatio = model1SucessRatio*model2SucessRatio*model3SucessRatio + 53 | 
model1SucessRatio*model2SucessRatio*(1f-model3SucessRatio) + 54 | model1SucessRatio*(1f-model2SucessRatio)*model3SucessRatio + 55 | (1f-model1SucessRatio)*model2SucessRatio*model3SucessRatio; 56 | System.out.println("The total success ratio: " + totalRatio); 57 | long t1 = System.currentTimeMillis(); 58 | float tmaxThreshold1 = 0f; 59 | float tmaxThreshold2 = 0f; 60 | float maxSuccessRatio = 0f; 61 | for (int j = 0; j < 20; j++) { 62 | float threshold1 = 0.01f * j + 0.80f; 63 | for (int k = 0; k < 80; k++) { 64 | float threshold2 = 0.01f * k + 0.2f; 65 | int success = 0; 66 | List finalResults = new ArrayList(); 67 | for (int i = 0; i < labelList.size(); i++) { 68 | PredictionResult model1Result = model1Results.get(i); 69 | PredictionResult model2Result = model2Results.get(i); 70 | PredictionResult model3Result = model3Results.get(i); 71 | 72 | if (model1Result.getProbability() >= threshold1) { 73 | finalResults.add(model1Result.getPredictionId()); 74 | continue; 75 | } 76 | 77 | if (model2Result.getProbability() >= threshold1) 78 | { 79 | finalResults.add(model2Result.getPredictionId()); 80 | continue; 81 | } 82 | 83 | /* 84 | if (model3Result.getProbability() >= threshold1) 85 | { 86 | finalResults.add(model3Result.getPredictionId()); 87 | continue; 88 | } 89 | */ 90 | 91 | if (model1Result.getPredictionId().equals(model2Result.getPredictionId()) 92 | && model1Result.getPredictionId().equals(model3Result.getPredictionId())) { 93 | finalResults.add(model1Result.getPredictionId()); 94 | continue; 95 | } 96 | if (model2Result.getPredictionId().equals(model3Result.getPredictionId())) { 97 | finalResults.add(model2Result.getPredictionId()); 98 | continue; 99 | } 100 | if (model1Result.getPredictionId().equals(model2Result.getPredictionId())) { 101 | finalResults.add(model1Result.getPredictionId()); 102 | continue; 103 | } 104 | if (model1Result.getPredictionId().equals(model3Result.getPredictionId())) { 105 | finalResults.add(model1Result.getPredictionId()); 106 | 
continue; 107 | } 108 | 109 | if (model2Result.getProbability() >= threshold2) 110 | { 111 | finalResults.add(model2Result.getPredictionId()); 112 | continue; 113 | } 114 | if (model3Result.getProbability() >= threshold2) { 115 | finalResults.add(model3Result.getPredictionId()); 116 | continue; 117 | } 118 | 119 | finalResults.add(model1Result.getPredictionId()); 120 | 121 | } 122 | for (int i = 0; i < labelList.size(); i++) { 123 | if (labelList.get(i).equals(finalResults.get(i))) { 124 | success++; 125 | } 126 | } 127 | float successRatio = (float) success / (float)labelList.size(); 128 | if(successRatio > maxSuccessRatio) { 129 | maxSuccessRatio = successRatio; 130 | tmaxThreshold2 = threshold2; 131 | tmaxThreshold1 = threshold1; 132 | } 133 | } 134 | } 135 | long t2 = System.currentTimeMillis(); 136 | System.out.println("Val elapsed time: " + (t2-t1)); 137 | System.out.println("Max threshold1: " + tmaxThreshold1); 138 | System.out.println("Max threshold2: " + tmaxThreshold2); 139 | System.out.println("Max successRation: " + maxSuccessRatio); 140 | this.preferableThreshold1 = tmaxThreshold1; 141 | this.preferableThreshold2 = tmaxThreshold2; 142 | } 143 | 144 | /** 145 | * Get validation data prediction results. 
146 | */ 147 | private List getPredictionResults(String filePath) 148 | { 149 | List pResults = new ArrayList(); 150 | try { 151 | FileReader fr = new FileReader(filePath); 152 | BufferedReader br = new BufferedReader(fr); 153 | String line; 154 | int count1 = 0; 155 | while ((line = br.readLine()) != null) { 156 | String columns[] = line.split(","); 157 | //Format like as: 69 XXX => 69:YYY(P=0.74892) 158 | String aString = columns[0]; 159 | String aStrings[] = aString.split("=>"); 160 | String src = aStrings[0].replace(" ",""); 161 | String tgt = aStrings[1].trim(); 162 | //tgt = tgt.replace(":", ""); 163 | String pString = tgt.substring(tgt.indexOf("P=") +2 ); 164 | tgt = tgt.substring(0, tgt.indexOf(":")); 165 | pString = pString.replace(")",""); 166 | float P = Float.parseFloat(pString); 167 | PredictionResult result = new PredictionResult(); 168 | result.setPredictionId(Integer.valueOf(tgt)); 169 | result.setProbability(P); 170 | pResults.add(result); 171 | count1++; 172 | } 173 | //System.out.println("Count1: " + count1); 174 | fr.close(); 175 | }catch (Exception ex) 176 | { 177 | String errMsg = "Read model error index file FAILED. " + ex.getMessage(); 178 | System.out.println("WRAN: " + errMsg); 179 | //throw new RuntimeException(errMsg); 180 | } 181 | return pResults; 182 | } 183 | 184 | /** 185 | * Get validation dataset labels. 
186 | * @param filePath 187 | * @return 188 | */ 189 | private List getValidationLabels(String filePath) 190 | { 191 | List pLabels = new ArrayList(); 192 | try { 193 | FileReader fr = new FileReader(filePath); 194 | BufferedReader br = new BufferedReader(fr); 195 | String line; 196 | int count1 = 0; 197 | while ((line = br.readLine()) != null) { 198 | String columns[] = line.split(","); 199 | //Format like as: 69 XXX => 69:YYY(P=0.74892) 200 | String aString = columns[0]; 201 | String sourceId = aString.substring(0, aString.indexOf(" ")); 202 | pLabels.add(Integer.valueOf(sourceId)); 203 | } 204 | System.out.println("Count1: " + count1); 205 | fr.close(); 206 | }catch (Exception ex) 207 | { 208 | String errMsg = "Read model error index file FAILED. " + ex.getMessage(); 209 | System.out.println("WRAN: " + errMsg); 210 | //throw new RuntimeException(errMsg); 211 | } 212 | return pLabels; 213 | } 214 | 215 | public float getPreferableThreshold1() 216 | { 217 | return this.preferableThreshold1; 218 | } 219 | public float getPreferableThreshold2() 220 | { 221 | return this.preferableThreshold2; 222 | } 223 | 224 | public static void main(String args[]) 225 | { 226 | ValidationDatasetAnalyzer validationDatasetTrainer = new ValidationDatasetAnalyzer(); 227 | validationDatasetTrainer.searchOptimizedThreshold("/Users/xyz/PycharmProjects/Inference"); 228 | System.out.println(validationDatasetTrainer.getPreferableThreshold1()); 229 | System.out.println(validationDatasetTrainer.getPreferableThreshold2()); 230 | } 231 | 232 | } 233 | -------------------------------------------------------------------------------- /src/main/resources/log4j.properties: -------------------------------------------------------------------------------- 1 | ################################################################################ 2 | # Licensed to the Apache Software Foundation (ASF) under one 3 | # or more contributor license agreements. 
See the NOTICE file 4 | # distributed with this work for additional information 5 | # regarding copyright ownership. The ASF licenses this file 6 | # to you under the Apache License, Version 2.0 (the 7 | # "License"); you may not use this file except in compliance 8 | # with the License. You may obtain a copy of the License at 9 | # 10 | # http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | # See the License for the specific language governing permissions and 16 | # limitations under the License. 17 | ################################################################################ 18 | 19 | log4j.rootLogger=INFO, console 20 | 21 | log4j.appender.console=org.apache.log4j.ConsoleAppender 22 | log4j.appender.console.layout=org.apache.log4j.PatternLayout 23 | #log4j.appender.console.layout.ConversionPattern=%d{HH:mm:ss,SSS} %-5p %-60c %x - %m%n 24 | log4j.appender.console.layout.ConversionPattern=[%p] %d{yyyy-MM-dd HH:mm:ss,SSS} %-c %m %n 25 | --------------------------------------------------------------------------------