├── src └── main │ ├── resources │ ├── META-INF │ │ └── services │ │ │ └── qupath.lib.gui.extensions.QuPathExtension │ └── qupath │ │ └── ext │ │ └── wsinfer │ │ └── ui │ │ ├── progress_dialog.fxml │ │ ├── wsinferstyles.css │ │ ├── strings.properties │ │ └── wsinfer_control.fxml │ └── java │ └── qupath │ └── ext │ └── wsinfer │ ├── models │ ├── package-info.java │ ├── WSInferModelCollection.java │ ├── WSInferTransform.java │ ├── WSInferModelConfiguration.java │ ├── WSInferModelLocal.java │ ├── WSInferUtils.java │ └── WSInferModel.java │ ├── ui │ ├── package-info.java │ ├── WSInferPrefs.java │ ├── WSInferCommand.java │ ├── WSInferProgressDialog.java │ ├── PytorchManager.java │ └── WSInferController.java │ ├── ProgressListener.java │ ├── ProgressLogger.java │ ├── WSInferExtension.java │ ├── MpsSupport.java │ ├── TileLoader.java │ └── WSInfer.java ├── gradle └── wrapper │ ├── gradle-wrapper.jar │ └── gradle-wrapper.properties ├── .github └── workflows │ ├── github_release.yml │ ├── build.yml │ ├── scijava-maven.yml │ └── update-gradle.yml ├── settings.gradle.kts ├── .gitignore ├── CHANGELOG.md ├── README.md ├── gradlew.bat ├── gradlew └── LICENSE /src/main/resources/META-INF/services/qupath.lib.gui.extensions.QuPathExtension: -------------------------------------------------------------------------------- 1 | qupath.ext.wsinfer.WSInferExtension 2 | -------------------------------------------------------------------------------- /gradle/wrapper/gradle-wrapper.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qupath/qupath-extension-wsinfer/HEAD/gradle/wrapper/gradle-wrapper.jar -------------------------------------------------------------------------------- /src/main/java/qupath/ext/wsinfer/models/package-info.java: -------------------------------------------------------------------------------- 1 | /** 2 | * Classes to access and cache WSInfer models. 
3 | */ 4 | package qupath.ext.wsinfer.models; -------------------------------------------------------------------------------- /src/main/java/qupath/ext/wsinfer/ui/package-info.java: -------------------------------------------------------------------------------- 1 | /** 2 | * Main user interface classes for the WSInfer extension. 3 | */ 4 | package qupath.ext.wsinfer.ui; -------------------------------------------------------------------------------- /.github/workflows/github_release.yml: -------------------------------------------------------------------------------- 1 | name: Make draft release 2 | 3 | on: 4 | workflow_dispatch: 5 | 6 | jobs: 7 | build: 8 | permissions: 9 | contents: write 10 | uses: qupath/actions/.github/workflows/github-release.yml@main 11 | -------------------------------------------------------------------------------- /gradle/wrapper/gradle-wrapper.properties: -------------------------------------------------------------------------------- 1 | distributionBase=GRADLE_USER_HOME 2 | distributionPath=wrapper/dists 3 | distributionUrl=https\://services.gradle.org/distributions/gradle-8.14.2-bin.zip 4 | networkTimeout=10000 5 | validateDistributionUrl=true 6 | zipStoreBase=GRADLE_USER_HOME 7 | zipStorePath=wrapper/dists 8 | -------------------------------------------------------------------------------- /.github/workflows/build.yml: -------------------------------------------------------------------------------- 1 | name: Run gradle build 2 | 3 | on: 4 | push: 5 | branches: 6 | - "main" 7 | pull_request: 8 | branches: 9 | - "main" 10 | workflow_dispatch: 11 | 12 | permissions: 13 | contents: read 14 | 15 | jobs: 16 | build: 17 | uses: qupath/actions/.github/workflows/gradle.yml@main 18 | -------------------------------------------------------------------------------- /settings.gradle.kts: -------------------------------------------------------------------------------- 1 | pluginManagement { 2 | repositories { 3 | gradlePluginPortal() 4 | maven 
{ 5 | url = uri("https://maven.scijava.org/content/repositories/releases") 6 | } 7 | } 8 | } 9 | 10 | qupath { 11 | version = "0.6.0" 12 | } 13 | 14 | plugins { 15 | id("io.github.qupath.qupath-extension-settings") version "0.2.1" 16 | } 17 | -------------------------------------------------------------------------------- /.github/workflows/scijava-maven.yml: -------------------------------------------------------------------------------- 1 | name: Publish release to SciJava Maven 2 | 3 | on: 4 | workflow_dispatch: 5 | inputs: 6 | release: 7 | type: boolean 8 | description: Should this be a release? (if not, do a snapshot) 9 | required: true 10 | 11 | jobs: 12 | build: 13 | name: Publish release 14 | uses: qupath/actions/.github/workflows/scijava-maven.yml@main 15 | secrets: inherit 16 | with: 17 | release: ${{ inputs.release }} 18 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Javadocs 2 | docs/ 3 | 4 | # Maven 5 | deploy/ 6 | target/ 7 | log/ 8 | 9 | # IntelliJ 10 | .idea/ 11 | *.iml 12 | out/ 13 | 14 | # Gradle 15 | # Use local properties (e.g. 
to set a specific JDK) 16 | gradle.properties 17 | build/ 18 | .gradle/ 19 | gradle.properties 20 | 21 | # Eclipse 22 | .settings/ 23 | .project 24 | .classpath 25 | 26 | # VSCode 27 | .vscode/ 28 | 29 | # Mac 30 | .DS_Store 31 | 32 | # Java 33 | hs_err*.log 34 | 35 | # Other 36 | *.tmp 37 | *.bak 38 | *.swp 39 | *~.nib 40 | *thumbs.db 41 | bin/ 42 | -------------------------------------------------------------------------------- /src/main/resources/qupath/ext/wsinfer/ui/progress_dialog.fxml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 17 | 43 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 87 | 88 | 89 | 90 | 91 | 92 | 93 | 94 | 95 | 98 | 99 | 100 | 106 | 107 | 108 | 109 | 110 | 111 | 112 | 113 | 114 | 115 | 116 | 117 | 118 | 119 | 122 | 125 | 126 | 127 | 128 | 129 | 130 | 131 | 132 | 133 | 134 | 135 | 136 | 137 | 138 | 139 | 145 | 146 | 147 | 148 | 149 | 150 | 151 | 152 | 153 | 154 | 155 | 156 | 157 | 164 | 165 | 166 | 167 | 177 | 178 | 179 | 180 | 181 | 182 | 183 | 184 | 192 | 193 | 194 | 195 | 196 | 197 | 216 | 217 | 218 | 219 | 220 | 221 | 222 | 223 | 224 | 225 | 226 | 227 | -------------------------------------------------------------------------------- /src/main/java/qupath/ext/wsinfer/WSInfer.java: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright 2023 University of Edinburgh 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | package qupath.ext.wsinfer; 18 | 19 | import ai.djl.Application; 20 | import ai.djl.Device; 21 | import ai.djl.MalformedModelException; 22 | import ai.djl.inference.Predictor; 23 | import ai.djl.modality.Classifications; 24 | import ai.djl.modality.cv.Image; 25 | import ai.djl.modality.cv.transform.Normalize; 26 | import ai.djl.modality.cv.transform.ToTensor; 27 | import ai.djl.repository.zoo.Criteria; 28 | import ai.djl.repository.zoo.ModelNotFoundException; 29 | import ai.djl.repository.zoo.ZooModel; 30 | import ai.djl.translate.Pipeline; 31 | import ai.djl.translate.Transform; 32 | import ai.djl.translate.TranslateException; 33 | import ai.djl.translate.Translator; 34 | import org.slf4j.Logger; 35 | import org.slf4j.LoggerFactory; 36 | import qupath.ext.wsinfer.models.WSInferModel; 37 | import qupath.ext.wsinfer.models.WSInferModelConfiguration; 38 | import qupath.ext.wsinfer.models.WSInferTransform; 39 | import qupath.ext.wsinfer.models.WSInferUtils; 40 | import qupath.ext.wsinfer.ui.WSInferPrefs; 41 | import qupath.lib.gui.tools.GuiTools; 42 | import qupath.lib.images.ImageData; 43 | import qupath.lib.images.servers.ImageServer; 44 | import qupath.lib.images.servers.PixelCalibration; 45 | import qupath.lib.objects.PathObject; 46 | import qupath.lib.objects.classes.PathClass; 47 | import qupath.lib.objects.utils.Tiler; 48 | import qupath.lib.scripting.QP; 49 | 50 | import java.awt.image.BufferedImage; 51 | import java.io.IOException; 52 | import java.util.ArrayList; 53 | import java.util.Collection; 54 | import java.util.LinkedHashSet; 55 | import java.util.List; 56 | import java.util.Objects; 57 | import java.util.ResourceBundle; 58 | import java.util.stream.Collectors; 59 | 60 | /** 61 | * Main class to run inference with WSInfer. 
62 | */ 63 | public class WSInfer { 64 | 65 | private static final Logger logger = LoggerFactory.getLogger(WSInfer.class); 66 | private static final ResourceBundle resources = ResourceBundle.getBundle("qupath.ext.wsinfer.ui.strings"); 67 | 68 | /** 69 | * Run inference on the current image data using the given model. 70 | * @param wsiModel 71 | * @throws ModelNotFoundException 72 | * @throws MalformedModelException 73 | * @throws IOException 74 | * @throws InterruptedException 75 | * @throws TranslateException 76 | */ 77 | public static void runInference(WSInferModel wsiModel) throws ModelNotFoundException, MalformedModelException, IOException, InterruptedException, TranslateException { 78 | runInference(QP.getCurrentImageData(), wsiModel); 79 | } 80 | 81 | /** 82 | * Run inference on the current image data using the specified model. 83 | * @param modelName name of the model to use for inference 84 | * @throws ModelNotFoundException 85 | * @throws MalformedModelException 86 | * @throws IOException 87 | * @throws InterruptedException 88 | * @throws TranslateException 89 | */ 90 | public static void runInference(String modelName) throws ModelNotFoundException, MalformedModelException, IOException, InterruptedException, TranslateException { 91 | var model = loadModel(modelName); 92 | runInference(QP.getCurrentImageData(), model); 93 | } 94 | 95 | /** 96 | * Run inference on the specified image data using the specified model. 
97 | * @param imageData image data to run inference on 98 | * @param modelName name of the model to use for inference 99 | * @throws ModelNotFoundException 100 | * @throws MalformedModelException 101 | * @throws IOException 102 | * @throws InterruptedException 103 | * @throws TranslateException 104 | */ 105 | public static void runInference(ImageData imageData, String modelName) throws ModelNotFoundException, MalformedModelException, IOException, InterruptedException, TranslateException { 106 | var model = loadModel(modelName); 107 | runInference(imageData, model); 108 | } 109 | 110 | /** 111 | * Get a model from the model collection. 112 | * @param modelName the name of the model to fetch 113 | * @return the model; this should never be null, because an exception is thrown if the model is not found. 114 | * @throws IllegalArgumentException if no model is found with the given name. 115 | */ 116 | private static WSInferModel loadModel(String modelName) throws IllegalArgumentException { 117 | Objects.requireNonNull(modelName, "Model name cannot be null"); 118 | var modelCollection = WSInferUtils.getModelCollection(); 119 | var model = modelCollection.getModels().getOrDefault(modelName, null); 120 | if (model == null && modelName.contains("/")) { 121 | String shortModelName = modelName.substring(modelName.lastIndexOf("/") + 1); 122 | model = modelCollection.getModels().getOrDefault(shortModelName, null); 123 | if (model != null) 124 | logger.warn("Using short model name {} instead of {}", shortModelName, modelName); 125 | } 126 | if (model == null) { 127 | throw new IllegalArgumentException("No model found with name: " + modelName); 128 | } 129 | return model; 130 | } 131 | 132 | /** 133 | * Run inference on the specified image data using the given model. 
134 | * @param imageData image data to run inference on 135 | * @param wsiModel model to use for inference 136 | * @throws InterruptedException 137 | * @throws ModelNotFoundException 138 | * @throws MalformedModelException 139 | * @throws IOException 140 | * @throws TranslateException 141 | */ 142 | public static void runInference(ImageData imageData, WSInferModel wsiModel) throws InterruptedException, ModelNotFoundException, MalformedModelException, IOException, TranslateException { 143 | runInference(imageData, wsiModel, new ProgressLogger(logger)); 144 | } 145 | 146 | /** 147 | * Run inference on the specified image data using the given model with a custom progress listener. 148 | * @param imageData image data to run inference on (required) 149 | * @param wsiModel model to use for inference (required) 150 | * @param progressListener the progress listener to report what is happening (required) 151 | * @throws InterruptedException 152 | * @throws ModelNotFoundException 153 | * @throws MalformedModelException 154 | * @throws IOException 155 | * @throws TranslateException 156 | */ 157 | public static void runInference(ImageData imageData, WSInferModel wsiModel, ProgressListener progressListener) throws InterruptedException, ModelNotFoundException, MalformedModelException, IOException, TranslateException { 158 | Objects.requireNonNull(wsiModel, "Model cannot be null"); 159 | if (imageData == null) { 160 | GuiTools.showNoImageError(resources.getString("title")); 161 | } 162 | 163 | // Try to get some tiles we can use 164 | var tiles = getTilesForInference(imageData, wsiModel.getConfiguration()); 165 | if (tiles.isEmpty()) { 166 | logger.warn("No tiles to process!"); 167 | return; 168 | } 169 | 170 | Device device = getDevice(); 171 | 172 | Pipeline pipeline = new Pipeline(); 173 | int resize = -1; 174 | for (WSInferTransform transform: wsiModel.getConfiguration().getTransform()) { 175 | switch(transform.getName()) { 176 | case "Resize": 177 | // Ideally we'd resize 
with the pipeline, but unfortunately that fails with MPS devides - 178 | // so instead we need to resize first 179 | // int size = ((Double) transform.getArguments().get("size")).intValue(); 180 | // builder.addTransform(new Resize(size, size, Image.Interpolation.BILINEAR)); 181 | resize = ((Number)transform.getArguments().get("size")).intValue(); 182 | logger.debug("Requesting resize to {}", resize); 183 | break; 184 | case "ToTensor": 185 | pipeline.add(createToTensorTransform(device)); 186 | break; 187 | case "Normalize": 188 | pipeline.add(createNormalizeTransform(transform)); 189 | break; 190 | default: 191 | logger.warn("Ignoring unknown transform: {}", transform.getName()); 192 | break; 193 | } 194 | } 195 | 196 | boolean applySoftmax = true; 197 | Translator translator = buildTranslator(wsiModel, pipeline, applySoftmax); 198 | Criteria criteria = buildCriteria(wsiModel, translator, device); 199 | List classNames = wsiModel.getConfiguration().getClassNames(); 200 | long startTime = System.currentTimeMillis(); 201 | 202 | ImageServer server = imageData.getServer(); 203 | double downsample = wsiModel.getConfiguration().getSpacingMicronPerPixel() / (double)server.getPixelCalibration().getAveragedPixelSize(); 204 | int width = (int) Math.round(wsiModel.getConfiguration().getPatchSizePixels() * downsample); 205 | int height = (int) Math.round(wsiModel.getConfiguration().getPatchSizePixels() * downsample); 206 | 207 | // Number of workers who will be busy fetching tiles for us while we're busy inferring 208 | int nWorkers = Math.max(1, WSInferPrefs.numWorkersProperty().getValue()); 209 | 210 | // Set batch size 211 | // Previously, this *had* to be 1 for MPS - but since DJL 0.24.0 that doesn't seem necessary any more 212 | int batchSize = Math.max(1, WSInferPrefs.batchSizeProperty().getValue()); 213 | 214 | // Number of tiles each worker should prefetch 215 | int numPrefetch = (int)Math.max(2, Math.ceil((double)batchSize * 2 / nWorkers)); 216 | 217 | int nTiles = 
tiles.size(); 218 | logger.info("Running {} for {} tiles", wsiModel.getName(), nTiles); 219 | 220 | var tileLoader = TileLoader.builder() 221 | .batchSize(batchSize) 222 | .numWorkers(nWorkers) 223 | .numPrefetch(numPrefetch) 224 | .server(server) 225 | .tileSize(width, height) 226 | .downsample(downsample) 227 | .tiles(tiles) 228 | .resizeTile(resize, resize) 229 | .build(); 230 | 231 | int completedTiles = 0; 232 | int totalTiles = tiles.size(); 233 | updateProgressForTiles(progressListener, completedTiles, totalTiles, startTime); 234 | 235 | try (ZooModel model = criteria.loadModel()) { 236 | try (Predictor predictor = model.newPredictor()) { 237 | var batchQueue = tileLoader.getBatchQueue(); 238 | int pendingWorkers = nWorkers; 239 | while (pendingWorkers > 0 && !Thread.currentThread().isInterrupted()) { 240 | var batch = batchQueue.take(); 241 | if (batch.isEmpty()) { 242 | // We stop when all the tile workers have returned an empty batch 243 | pendingWorkers--; 244 | continue; 245 | } 246 | 247 | List inputs = batch.getInputs(); 248 | List pathObjectBatch = batch.getTiles(); 249 | List predictions = predictor.batchPredict(inputs); 250 | 251 | for (int i = 0; i < inputs.size(); i++) { 252 | PathObject pathObject = pathObjectBatch.get(i); 253 | Classifications classifications = predictions.get(i); 254 | for (String c : classNames) { 255 | double prob = classifications.get(c).getProbability(); 256 | pathObject.getMeasurements().put(c, prob); 257 | } 258 | // Set class based upon probability 259 | String name = classifications.topK(1).get(0).getClassName(); 260 | if (name == null) 261 | pathObject.resetPathClass(); 262 | else 263 | pathObject.setPathClass(PathClass.fromString(name)); 264 | } 265 | completedTiles += inputs.size(); 266 | updateProgressForTiles(progressListener, completedTiles, totalTiles, startTime); 267 | } 268 | } 269 | long endTime = System.currentTimeMillis(); 270 | long duration = endTime - startTime; 271 | 
updateProgressForTiles(progressListener, completedTiles, totalTiles, startTime); 272 | 273 | imageData.getHierarchy().fireObjectClassificationsChangedEvent(WSInfer.class, tiles); 274 | long durationSeconds = duration/1000; 275 | String seconds = durationSeconds == 1 ? "second" : "seconds"; 276 | logger.info("Finished {} tiles in {} {} ({} ms per tile)", nTiles, durationSeconds, seconds, duration/nTiles); 277 | } catch (InterruptedException e) { 278 | logger.error("Model inference interrupted {}", wsiModel.getName(), e); 279 | progressListener.updateProgress("Inference interrupted!", 1.0); 280 | throw e; 281 | } catch (IOException | ModelNotFoundException | MalformedModelException | TranslateException e) { 282 | logger.error("Error running model {}", wsiModel.getName(), e); 283 | progressListener.updateProgress("Inference failed!", 1.0); 284 | throw e; 285 | } 286 | } 287 | 288 | private static void updateProgressForTiles(ProgressListener progress, int completedTiles, int totalTiles, long startTime) { 289 | double timeSeconds = (System.currentTimeMillis() - startTime) / 1000.0; 290 | if (completedTiles == totalTiles) 291 | progress.updateProgress( 292 | String.format(resources.getString("ui.processing-completed"), completedTiles, totalTiles, completedTiles/timeSeconds), 293 | (double)completedTiles/totalTiles); 294 | else { 295 | progress.updateProgress( 296 | String.format(resources.getString("ui.processing-progress"), completedTiles, totalTiles, completedTiles / timeSeconds), 297 | (double)completedTiles / totalTiles); 298 | } 299 | } 300 | 301 | /** 302 | * Check if a specified device corresponds to using the Metal Performance Shaders (MPS) backend (Apple Silicon) 303 | * @param device 304 | * @return 305 | */ 306 | private static boolean isMPS(Device device) { 307 | return device != null && device.getDeviceType().toLowerCase().startsWith("mps"); 308 | } 309 | 310 | 311 | private static Translator buildTranslator(WSInferModel wsiModel, Pipeline pipeline, boolean 
applySoftmax) { 312 | // We should use ImageClassificationTranslator.builder() in the future if this is updated to work with MPS 313 | // (See javadocs for MpsSupport.WSInferClassificationTranslator for details) 314 | // ImageClassificationTranslator.Builder builder = ImageClassificationTranslator.builder() 315 | return MpsSupport.WSInferClassificationTranslator.builder() 316 | .optSynset(wsiModel.getConfiguration().getClassNames()) 317 | .optApplySoftmax(applySoftmax) 318 | .setPipeline(pipeline) 319 | .build(); 320 | } 321 | 322 | private static Criteria buildCriteria(WSInferModel wsiModel, Translator translator, Device device) { 323 | return Criteria.builder() 324 | .optApplication(Application.CV.IMAGE_CLASSIFICATION) 325 | .optModelPath(wsiModel.getTorchScriptFile().toPath()) 326 | .optEngine("PyTorch") 327 | .setTypes(Image.class, Classifications.class) 328 | .optTranslator(translator) 329 | .optDevice(device) 330 | .build(); 331 | } 332 | 333 | 334 | private static Device getDevice() { 335 | String deviceName = WSInferPrefs.deviceProperty().get(); 336 | switch (deviceName) { 337 | case "gpu": 338 | return Device.gpu(); 339 | case "cpu": 340 | return Device.cpu(); 341 | default: 342 | logger.info("Attempting to set device to {}", deviceName); 343 | return Device.fromName(deviceName); 344 | } 345 | } 346 | 347 | 348 | private static Transform createToTensorTransform(Device device) { 349 | logger.debug("Creating ToTensor transform"); 350 | if (isMPS(device)) 351 | return new MpsSupport.ToTensor32(); 352 | else 353 | return new ToTensor(); 354 | } 355 | 356 | private static Transform createNormalizeTransform(WSInferTransform transform) { 357 | ArrayList mean = (ArrayList) transform.getArguments().get("mean"); 358 | ArrayList sd = (ArrayList) transform.getArguments().get("std"); 359 | logger.debug("Creating Normalize transform (mean={}, sd={})", mean, sd); 360 | float[] meanArr = new float[mean.size()]; 361 | float[] sdArr = new float[mean.size()]; 362 | for (int 
i = 0; i < mean.size(); i++) { 363 | meanArr[i] = mean.get(i).floatValue(); 364 | sdArr[i] = sd.get(i).floatValue(); 365 | } 366 | return new Normalize(meanArr, sdArr); 367 | } 368 | 369 | 370 | private static List getTilesForInference(ImageData imageData, WSInferModelConfiguration config) { 371 | // Here, we permit detections to be used instead of tiles 372 | var selectedObjects = imageData.getHierarchy().getSelectionModel().getSelectedObjects(); 373 | var selectedTiles = selectedObjects.stream() 374 | .filter(p -> p.isTile() || p.isDetection()) 375 | .collect(Collectors.toList()); 376 | if (!selectedTiles.isEmpty()) { 377 | return selectedTiles; 378 | } 379 | 380 | // If we have annotations selected, create tiles inside them 381 | Collection selectedAnnotations = selectedObjects.stream() 382 | .filter(PathObject::isAnnotation) 383 | .toList(); 384 | if (selectedAnnotations.isEmpty()) { 385 | throw new IllegalArgumentException(resources.getString("No tiles or annotations selected!")); 386 | } 387 | 388 | var annotationSet = new LinkedHashSet<>(selectedAnnotations); // We want this later 389 | double tileWidth, tileHeight; 390 | PixelCalibration cal = imageData.getServer().getPixelCalibration(); 391 | if (cal.hasPixelSizeMicrons()) { 392 | double tileSizeMicrons = config.getPatchSizePixels() * config.getSpacingMicronPerPixel(); 393 | tileWidth = (int)(tileSizeMicrons / cal.getPixelWidthMicrons() + .5); 394 | tileHeight = (int)(tileSizeMicrons / cal.getPixelHeightMicrons() + .5); 395 | } else { 396 | logger.warn("Pixel calibration not available, so using pixels instead of microns"); 397 | tileWidth = Math.round(config.getPatchSizePixels()); 398 | tileHeight = tileWidth; 399 | } 400 | 401 | var tiler = Tiler.builder((int)tileWidth, (int)tileHeight) 402 | .cropTiles(false) 403 | .filterByCentroid(false) 404 | .alignCenter() 405 | .build(); 406 | for (var annotation: selectedAnnotations) { 407 | var tiles = tiler.createTiles(annotation.getROI()); 408 | // add tiles to 
the hierarchy 409 | annotation.clearChildObjects(); 410 | for (int i = 0; i < tiles.size(); i++) { 411 | var tile = tiles.get(i); 412 | tile.setName("Tile " + i); 413 | annotation.addChildObject(tile); 414 | } 415 | annotation.setLocked(true); 416 | imageData.getHierarchy().fireHierarchyChangedEvent(annotation); 417 | } 418 | // We want our new tiles to be selected... but we also want to ensure that any tile object 419 | // has a selected annotation as a parent (in case there were other tiles already) 420 | return imageData.getHierarchy().getTileObjects() 421 | .stream() 422 | .filter(t -> annotationSet.contains(t.getParent())) 423 | .collect(Collectors.toList()); 424 | 425 | } 426 | 427 | } 428 | -------------------------------------------------------------------------------- /src/main/java/qupath/ext/wsinfer/ui/WSInferController.java: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright 2023 University of Edinburgh 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
15 | */ 16 | 17 | package qupath.ext.wsinfer.ui; 18 | 19 | import javafx.application.Platform; 20 | import javafx.beans.binding.Bindings; 21 | import javafx.beans.binding.BooleanBinding; 22 | import javafx.beans.binding.ObjectBinding; 23 | import javafx.beans.binding.StringBinding; 24 | import javafx.beans.property.BooleanProperty; 25 | import javafx.beans.property.IntegerProperty; 26 | import javafx.beans.property.ObjectProperty; 27 | import javafx.beans.property.SimpleBooleanProperty; 28 | import javafx.beans.property.SimpleIntegerProperty; 29 | import javafx.beans.property.SimpleObjectProperty; 30 | import javafx.beans.property.StringProperty; 31 | import javafx.beans.value.ObservableValue; 32 | import javafx.concurrent.Task; 33 | import javafx.concurrent.Worker; 34 | import javafx.event.ActionEvent; 35 | import javafx.fxml.FXML; 36 | import javafx.scene.control.Button; 37 | import javafx.scene.control.ChoiceBox; 38 | import javafx.scene.control.ContentDisplay; 39 | import javafx.scene.control.Label; 40 | import javafx.scene.control.Slider; 41 | import javafx.scene.control.Spinner; 42 | import javafx.scene.control.TextField; 43 | import javafx.scene.control.ToggleButton; 44 | import javafx.scene.input.MouseEvent; 45 | import javafx.scene.web.WebView; 46 | import javafx.stage.Stage; 47 | import javafx.util.StringConverter; 48 | import org.commonmark.ext.front.matter.YamlFrontMatterExtension; 49 | import org.commonmark.ext.front.matter.YamlFrontMatterVisitor; 50 | import org.commonmark.parser.Parser; 51 | import org.commonmark.renderer.html.HtmlRenderer; 52 | import org.controlsfx.control.PopOver; 53 | import org.controlsfx.control.SearchableComboBox; 54 | import org.controlsfx.control.action.Action; 55 | import org.controlsfx.control.action.ActionUtils; 56 | import org.slf4j.Logger; 57 | import org.slf4j.LoggerFactory; 58 | import qupath.ext.wsinfer.WSInfer; 59 | import qupath.ext.wsinfer.models.WSInferModel; 60 | import 
qupath.ext.wsinfer.models.WSInferModelCollection; 61 | import qupath.ext.wsinfer.models.WSInferUtils; 62 | import qupath.fx.dialogs.Dialogs; 63 | import qupath.fx.dialogs.FileChoosers; 64 | import qupath.fx.utils.FXUtils; 65 | import qupath.lib.common.ThreadTools; 66 | import qupath.lib.gui.QuPathGUI; 67 | import qupath.lib.gui.commands.Commands; 68 | import qupath.lib.gui.tools.GuiTools; 69 | import qupath.lib.gui.tools.WebViews; 70 | import qupath.lib.images.ImageData; 71 | import qupath.lib.objects.PathAnnotationObject; 72 | import qupath.lib.objects.PathObject; 73 | import qupath.lib.objects.PathTileObject; 74 | import qupath.lib.objects.hierarchy.PathObjectHierarchy; 75 | import qupath.lib.objects.hierarchy.events.PathObjectSelectionListener; 76 | import qupath.lib.plugins.workflow.DefaultScriptableWorkflowStep; 77 | 78 | import java.awt.image.BufferedImage; 79 | import java.io.File; 80 | import java.io.IOException; 81 | import java.nio.file.Files; 82 | import java.util.Arrays; 83 | import java.util.Collection; 84 | import java.util.Objects; 85 | import java.util.ResourceBundle; 86 | import java.util.Set; 87 | import java.util.concurrent.ExecutorService; 88 | import java.util.concurrent.Executors; 89 | 90 | 91 | /** 92 | * Controller for the WSInfer user interface, which is defined in wsinfer_control.fxml 93 | * and loaded by WSInferCommand. 
*/
public class WSInferController {

    private static final Logger logger = LoggerFactory.getLogger(WSInferController.class);

    private QuPathGUI qupath;
    // Mirrors the image data of the current QuPath viewer (bound in initialize())
    private final ObjectProperty<ImageData<BufferedImage>> imageDataProperty = new SimpleObjectProperty<>();
    private MessageTextHelper messageTextHelper;

    @FXML
    private Label labelMessage;
    @FXML
    private SearchableComboBox<WSInferModel> modelChoiceBox;
    @FXML
    private Button runButton;
    @FXML
    private Button downloadButton;
    @FXML
    private Button infoButton;
    @FXML
    private ChoiceBox<String> deviceChoices;
    @FXML
    private ToggleButton toggleSelectAllAnnotations;
    @FXML
    private ToggleButton toggleSelectAllDetections;
    @FXML
    private ToggleButton toggleDetectionFill;
    @FXML
    private ToggleButton toggleDetections;
    @FXML
    private ToggleButton toggleAnnotations;
    @FXML
    private Slider sliderOpacity;
    @FXML
    private Spinner<Integer> spinnerNumWorkers;
    @FXML
    private Spinner<Integer> spinnerBatchSize;
    @FXML
    private TextField tfModelDirectory;

    private final WebView infoWebView = WebViews.create(true);
    private final PopOver infoPopover = new PopOver(infoWebView);

    private static final ResourceBundle resources = ResourceBundle.getBundle("qupath.ext.wsinfer.ui.strings");

    private Stage measurementMapsStage;

    // Single-threaded executor, so at most one inference task runs at a time
    private final ExecutorService pool = Executors.newSingleThreadExecutor(ThreadTools.createThreadFactory("wsinfer", true));

    // The inference task that has been submitted and has not yet finished (null if none)
    private final ObjectProperty<Task<?>> pendingTask = new SimpleObjectProperty<>();
    // True while a download started via downloadModel() is in progress; only modified on the FX thread
    private final BooleanProperty downloadPending = new SimpleBooleanProperty(false);

    @FXML
    private void initialize() {
        logger.info("Initializing...");

        this.qupath = QuPathGUI.getInstance();
        this.imageDataProperty.bind(qupath.imageDataProperty());

        refreshAvailableModels();
        configureModelSelection();

        configureSelectionButtons();
        configureDisplayToggleButtons();
        configureOpacitySlider();

        configureAvailableDevices();
        configureModelDirectory();
        configureNumWorkers();
        configureBatchSize();

        configureMessageLabel();
        configureButtons();

        configurePendingTaskProperty();
    }

    /**
     * Bind the message label text to the helper, and switch its style class whenever a warning is shown or cleared.
     */
    private void configureMessageLabel() {
        messageTextHelper = new MessageTextHelper();
        labelMessage.textProperty().bind(messageTextHelper.messageLabelText);
        updateMessageLabelStyle(messageTextHelper.hasWarning.get());
        messageTextHelper.hasWarning.addListener((observable, oldValue, newValue) -> updateMessageLabelStyle(newValue));
    }

    /**
     * Apply the warning or standard style class to the message label.
     * @param hasWarning true if a warning is currently displayed
     */
    private void updateMessageLabelStyle(boolean hasWarning) {
        labelMessage.getStyleClass().setAll(hasWarning ? "warning-message" : "standard-message");
    }

    /**
     * Populate the available models & set the converter used to display their names.
     * <p>
     * This may be called more than once, so it deliberately does not register any listeners
     * (doing so would add a duplicate listener on every call).
     */
    void refreshAvailableModels() {
        WSInferModelCollection models = WSInferUtils.getModelCollection();
        modelChoiceBox.getItems().setAll(models.getModels().values());
        modelChoiceBox.setConverter(new ModelStringConverter(models));
    }

    /**
     * Hide the info popover whenever the selected model changes, so it never shows information for another model.
     * Called once from {@link #initialize()}.
     */
    private void configureModelSelection() {
        modelChoiceBox.getSelectionModel().selectedItemProperty().addListener(
                (v, o, n) -> infoPopover.hide());
    }

    private static boolean checkFileExists(File file) {
        return file != null && file.isFile();
    }

    /**
     * Populate the device choices, selecting the device stored in the preferences if it is available
     * (or the first device otherwise), and write any new selection back to the preferences.
     */
    private void configureAvailableDevices() {
        var available = PytorchManager.getAvailableDevices();
        deviceChoices.getItems().setAll(available);
        var selected = WSInferPrefs.deviceProperty().get();
        if (available.contains(selected)) {
            deviceChoices.getSelectionModel().select(selected);
        } else {
            deviceChoices.getSelectionModel().selectFirst();
        }
        // Don't bind property for now, since this would cause trouble if the WSInferPrefs.deviceProperty() is
        // changed elsewhere
        deviceChoices.getSelectionModel().selectedItemProperty().addListener(
                (value, oldValue, newValue) -> WSInferPrefs.deviceProperty().set(newValue));
    }

    /**
     * Bind the disabled state of the run, download and info buttons.
     */
    private void configureButtons() {
        var selectedModel = modelChoiceBox.getSelectionModel().selectedItemProperty();
        // Disable the run button while a task is pending, or we have no image/model, or a warning is shown
        runButton.disableProperty().bind(
                imageDataProperty.isNull()
                        .or(pendingTask.isNotNull())
                        .or(selectedModel.isNull())
                        .or(messageTextHelper.warningText.isNotEmpty())
        );
        // Disable the download button if no model is selected, a download is already running, or the model is
        // available. Model validity is not observable, so downloadPending also triggers re-evaluation when a
        // download finishes.
        downloadButton.disableProperty().bind(
                Bindings.createBooleanBinding(
                        () -> {
                            var model = selectedModel.get();
                            return model == null || downloadPending.get() || model.isValid();
                        },
                        selectedModel,
                        downloadPending)
        );
        infoButton.disableProperty().bind(
                Bindings.createBooleanBinding(
                        () -> {
                            var n = selectedModel.get();
                            return n == null || !n.isValid() || !checkFileExists(n.getReadMeFile());
                        },
                        selectedModel,
                        downloadPending)
        );
    }

    private void configurePendingTaskProperty() {
        // Submit any new pending tasks to the thread pool
        pendingTask.addListener((observable, oldValue, newValue) -> {
            if (newValue != null) {
                pool.execute(newValue);
            }
        });
    }

    private void configureDisplayToggleButtons() {
        var actions = qupath.getOverlayActions();
        configureActionToggleButton(actions.FILL_DETECTIONS, toggleDetectionFill);
        configureActionToggleButton(actions.SHOW_DETECTIONS, toggleDetections);
        configureActionToggleButton(actions.SHOW_ANNOTATIONS, toggleAnnotations);
    }

    private void configureOpacitySlider() {
        var opacityProperty = qupath.getOverlayOptions().opacityProperty();
        sliderOpacity.valueProperty().bindBidirectional(opacityProperty);
    }

    private void configureSelectionButtons() {
        toggleSelectAllAnnotations.disableProperty().bind(imageDataProperty.isNull());
        overrideToggleSelected(toggleSelectAllAnnotations);
        toggleSelectAllDetections.disableProperty().bind(imageDataProperty.isNull());
        overrideToggleSelected(toggleSelectAllDetections);
    }

    // Hack to prevent the toggle buttons from staying selected
    // This allows us to use a segmented button with the appearance of regular, non-toggle buttons
    private void overrideToggleSelected(ToggleButton button) {
        button.selectedProperty().addListener((value, oldValue, newValue) -> button.setSelected(false));
    }

    /**
     * Configure a toggle button for showing/hiding or filling/unfilling objects.
     * @param action the QuPath overlay action the button should trigger
     * @param button the button to configure (shown as a graphic only)
     */
    private void configureActionToggleButton(Action action, ToggleButton button) {
        ActionUtils.configureButton(action, button);
        button.setContentDisplay(ContentDisplay.GRAPHIC_ONLY);
    }

    private void configureModelDirectory() {
        tfModelDirectory.textProperty().bindBidirectional(WSInferPrefs.modelDirectoryProperty());
    }

    private void configureNumWorkers() {
        spinnerNumWorkers.getValueFactory().valueProperty().bindBidirectional(WSInferPrefs.numWorkersProperty());
    }

    private void configureBatchSize() {
        spinnerBatchSize.getValueFactory().valueProperty().bindBidirectional(WSInferPrefs.batchSizeProperty());
    }

    /**
     * Try to run inference on the current image using the current model and parameters.
     */
    public void runInference() {
        var imageData = this.imageDataProperty.get();
        if (imageData == null) {
            Dialogs.showErrorMessage(resources.getString("title"), resources.getString("error.no-imagedata"));
            return;
        }
        // Check we have a model - and are willing to download it if needed
        var selectedModel = modelChoiceBox.getSelectionModel().getSelectedItem();
        if (selectedModel == null) {
            // Shouldn't happen
            Dialogs.showErrorMessage(resources.getString("title"), resources.getString("error.no-model"));
            return;
        }
        if (!selectedModel.isValid()) {
            if (!Dialogs.showConfirmDialog(resources.getString("title"), resources.getString("ui.model-popup"))) {
                Dialogs.showWarningNotification(resources.getString("title"), resources.getString("ui.model-not-downloaded"));
                return;
            }
        }

        if (!PytorchManager.hasPyTorchEngine()) {
            if (!Dialogs.showConfirmDialog(resources.getString("title"), resources.getString("ui.pytorch"))) {
                Dialogs.showWarningNotification(resources.getString("title"), resources.getString("ui.pytorch-popup"));
                return;
            }
        }
        submitInferenceTask(imageData, selectedModel);
    }

    /**
     * Create an inference task and submit it via {@link #pendingTask}.
     * The pending task is reset to null when the task finishes in any way.
     */
    private void submitInferenceTask(ImageData<BufferedImage> imageData, WSInferModel model) {
        var task = new WSInferTask(imageData, model);
        // Reset the pending task when it completes (either successfully or not).
        // The listener is attached before submission so that no state change can be missed.
        task.stateProperty().addListener((observable, oldValue, newValue) -> {
            if (Set.of(Worker.State.CANCELLED, Worker.State.SUCCEEDED, Worker.State.FAILED).contains(newValue)) {
                if (pendingTask.get() == task)
                    pendingTask.set(null);
            }
        });
        pendingTask.set(task);
    }

    private static void addToHistoryWorkflow(ImageData<BufferedImage> imageData, String modelName) {
        imageData.getHistoryWorkflow()
                .addStep(
                        new DefaultScriptableWorkflowStep(
                                resources.getString("workflow.title"),
                                WSInfer.class.getName() + ".runInference(\"" + modelName + "\")"
                        ));
    }

    private static void showDownloadingModelNotification(String modelName) {
        Dialogs.showPlainNotification(
                resources.getString("title"),
                String.format(resources.getString("ui.popup.fetching"), modelName));
    }

    private static void showModelAvailableNotification(String modelName) {
        Dialogs.showPlainNotification(
                resources.getString("title"),
                String.format(resources.getString("ui.popup.available"), modelName));
    }

    /**
     * Download the selected model on a background (virtual) thread, unless it is already available.
     * Only one download started here runs at a time; {@link #downloadPending} is true while it runs.
     */
    public void downloadModel() {
        var model = modelChoiceBox.getSelectionModel().getSelectedItem();
        if (model == null || downloadPending.get()) {
            return;
        }
        if (model.isValid()) {
            showModelAvailableNotification(model.getName());
            return;
        }

        // Set the flag before starting the thread, so the background thread can never reset it first
        downloadPending.set(true);
        Thread.ofVirtual().start(() -> {
            try {
                model.removeCache();
                showDownloadingModelNotification(model.getName());
                model.downloadModel();
                showModelAvailableNotification(model.getName());
            } catch (IOException e) {
                logger.error("Error downloading model {}", model.getName(), e);
                Platform.runLater(() -> Dialogs.showErrorMessage(resources.getString("title"), resources.getString("error.downloading")));
            } finally {
                // downloadPending drives UI bindings, so it must be updated on the FX application thread -
                // and it must be reset even if the download failed
                Platform.runLater(() -> downloadPending.set(false));
            }
        });
    }

    /**
     * Open the model directory in the system file browser when double-clicked.
     * @param event the mouse event; only double-clicks are handled
     */
    public void handleModelDirectoryLabelClick(MouseEvent event) {
        if (event.getClickCount() != 2)
            return;
        var path = WSInferPrefs.modelDirectoryProperty().get();
        if (path == null || path.isEmpty())
            return;
        var file = new File(path);
        if (file.exists())
            GuiTools.browseDirectory(file);
        else
            logger.debug("Can't browse directory for {}", file);
    }

    public void promptForModelDirectory() {
        promptToUpdateDirectory(WSInferPrefs.modelDirectoryProperty());
    }

    /**
     * Prompt the user to choose a directory, starting from the current value of {@code dirPath} when it exists,
     * and store the absolute path of the chosen directory (if any) in {@code dirPath}.
     */
    private void promptToUpdateDirectory(StringProperty dirPath) {
        var modelDirPath = dirPath.get();
        var dir = modelDirPath == null || modelDirPath.isEmpty() ? null : new File(modelDirPath);
        if (dir != null) {
            if (dir.isFile())
                dir = dir.getParentFile();
            else if (!dir.exists())
                dir = null;
        }
        var newDir = FileChoosers.promptForDirectory(
                FXUtils.getWindow(tfModelDirectory), // Get window from any node here
                resources.getString("ui.model-directory.choose-directory"),
                dir);
        if (newDir == null)
            return;
        dirPath.set(newDir.getAbsolutePath());
    }

    /**
     * Toggle a popover showing the rendered README of the selected model.
     * Any YAML front matter is moved to a 'Metadata' section at the end. If the README does not start
     * with a heading, the model name and description are prepended.
     */
    public void showInfo() throws IOException {
        if (infoPopover.isShowing()) {
            infoPopover.hide();
            return;
        }
        WSInferModel model = modelChoiceBox.getSelectionModel().getSelectedItem();
        if (model == null) {
            logger.debug("No model selected - cannot show info");
            return;
        }
        var file = model.getReadMeFile();
        if (!checkFileExists(file)) {
            logger.warn("ReadMe file not available: {}", file);
            return;
        }
        try {
            var markdown = Files.readString(file.toPath());

            // Parse the initial markdown only, to extract any YAML front matter
            var parser = Parser.builder()
                    .extensions(
                            Arrays.asList(YamlFrontMatterExtension.create())
                    ).build();
            var doc = parser.parse(markdown);
            var visitor = new YamlFrontMatterVisitor();
            doc.accept(visitor);
            var metadata = visitor.getData();

            // If we have YAML metadata, remove from the start (since it renders weirdly) and append at the end
            if (!metadata.isEmpty()) {
                doc.getFirstChild().unlink();
                var sb = new StringBuilder();
                sb.append("----\n\n");
                sb.append("### Metadata\n\n");
                for (var entry : metadata.entrySet()) {
                    sb.append("\n* **").append(entry.getKey()).append("**: ").append(entry.getValue());
                }
                doc.appendChild(parser.parse(sb.toString()));
            }

            // If the markdown doesn't start with a title, prepend the model title & description (if available)
            if (!markdown.startsWith("#")) {
                var sb = new StringBuilder();
                sb.append("## ").append(model.getName()).append("\n\n");
                var description = model.getDescription();
                if (description != null && !description.isEmpty()) {
                    sb.append("_").append(description).append("_").append("\n\n");
                }
                sb.append("----\n\n");
                doc.prependChild(parser.parse(sb.toString()));
            }

            infoWebView.getEngine().loadContent(
                    HtmlRenderer.builder().build().render(doc));
            infoPopover.show(infoButton);
        } catch (IOException e) {
            logger.error("Error parsing readme file", e);
        }
    }

    @FXML
    private void selectAllAnnotations() {
        Commands.selectObjectsByClass(imageDataProperty.get(), PathAnnotationObject.class);
    }

    @FXML
    private void selectAllTiles() {
        Commands.selectObjectsByClass(imageDataProperty.get(), PathTileObject.class);
    }

    @FXML
    private void openMeasurementMaps(ActionEvent event) {
        // Try to use existing action, to avoid creating a new stage
        // TODO: Replace this if QuPath v0.5.0 provides direct access to the action
        //       since that should be more robust (and also cope with language changes)
        var action = qupath.lookupActionByText("Show measurement maps");
        if (action != null) {
            action.handle(event);
            return;
        }
        // Fallback in case we couldn't get the action
        if (measurementMapsStage == null) {
            logger.warn("Creating a new measurement map stage");
            measurementMapsStage = Commands.createMeasurementMapDialog(QuPathGUI.getInstance());
        }
        measurementMapsStage.show();
    }

    @FXML
    private void openDetectionTable() {
        Commands.showDetectionMeasurementTable(qupath, imageDataProperty.get());
    }

    /**
     * Wrapper for an inference task, which can be submitted to the thread pool.
     */
    private static class WSInferTask extends Task<Void> {

        private final ImageData<BufferedImage> imageData;
        private final WSInferModel model;
        private final WSInferProgressDialog progressListener;

        private WSInferTask(ImageData<BufferedImage> imageData, WSInferModel model) {
            this.imageData = imageData;
            this.model = model;
            // If the user tries to close the progress dialog, ask whether the task should be stopped;
            // consuming the event keeps the dialog open when they agree (cancellation then closes it)
            this.progressListener = new WSInferProgressDialog(QuPathGUI.getInstance().getStage(), e -> {
                if (Dialogs.showYesNoDialog(getDialogTitle(), resources.getString("ui.stop-tasks"))) {
                    cancel(true);
                    e.consume();
                }
            });
            this.stateProperty().addListener(this::handleStateChange);
        }

        private void handleStateChange(ObservableValue<? extends Worker.State> value, Worker.State oldValue, Worker.State newValue) {
            if (progressListener != null && newValue == Worker.State.CANCELLED)
                progressListener.cancel();
        }

        private String getDialogTitle() {
            try {
                return ResourceBundle.getBundle("qupath.ext.wsinfer.ui.strings")
                        .getString("title");
            } catch (Exception e) {
                logger.debug("Exception attempting to request title resource");
                return "WSInfer";
            }
        }

        @Override
        protected Void call() {
            // Ensure PyTorch engine is available
            if (!PytorchManager.hasPyTorchEngine()) {
                Platform.runLater(() -> Dialogs.showInfoNotification(getDialogTitle(), resources.getString("ui.pytorch-downloading")));
                PytorchManager.getEngineOnline();
            }
            // Ensure model is available - any prompts allowing the user to cancel
            // should have been displayed already
            if (!model.isValid()) {
                showDownloadingModelNotification(model.getName());
                try {
                    model.downloadModel();
                } catch (IOException e) {
                    logger.error("Error downloading model {}", model.getName(), e);
                    Platform.runLater(() -> Dialogs.showErrorMessage(resources.getString("title"), resources.getString("error.downloading")));
                    return null;
                }
                showModelAvailableNotification(model.getName());
            }
            // Run inference
            try {
                WSInfer.runInference(imageData, model, progressListener);
                addToHistoryWorkflow(imageData, model.getName());
            } catch (InterruptedException e) {
                // Preserve the interrupt status rather than silently clearing it
                Thread.currentThread().interrupt();
                Platform.runLater(() -> Dialogs.showErrorNotification(getDialogTitle(), e.getLocalizedMessage()));
            } catch (Exception e) {
                logger.error("Error running inference", e);
                Platform.runLater(() -> Dialogs.showErrorMessage(getDialogTitle(), e.getLocalizedMessage()));
            }
            return null;
        }
    }


    /**
     * Helper class for determining which text to display in the message label.
     */
    private class MessageTextHelper {

        private final SelectedObjectCounter selectedObjectCounter;

        /**
         * Text to display a warning (because inference can't be run)
         */
        private StringBinding warningText;
        /**
         * Text to display the number of selected objects (usually when inference can be run)
         */
        private StringBinding selectedObjectText;
        /**
         * Text to display in the message label (either the warning or the selected object text)
         */
        private StringBinding messageLabelText;

        /**
         * Binding to check if the warning is empty.
         * Retained here because otherwise code that attaches a listener to {@code warningText.isEmpty()} would need to
         * retain a reference to the binding to prevent garbage collection.
         */
        private BooleanBinding hasWarning;

        MessageTextHelper() {
            this.selectedObjectCounter = new SelectedObjectCounter(imageDataProperty);
            configureMessageTextBindings();
        }

        private void configureMessageTextBindings() {
            this.warningText = createWarningTextBinding();
            this.selectedObjectText = createSelectedObjectTextBinding();
            this.messageLabelText = Bindings.createStringBinding(() -> {
                var warning = warningText.get();
                if (warning == null || warning.isEmpty())
                    return selectedObjectText.get();
                else
                    return warning;
            }, warningText, selectedObjectText);
            this.hasWarning = warningText.isEmpty().not();
        }

        private StringBinding createSelectedObjectTextBinding() {
            return Bindings.createStringBinding(this::getSelectedObjectText,
                    selectedObjectCounter.numSelectedAnnotations,
                    selectedObjectCounter.numSelectedDetections);
        }

        private String getSelectedObjectText() {
            int nAnnotations = selectedObjectCounter.numSelectedAnnotations.get();
            int nDetections = selectedObjectCounter.numSelectedDetections.get();
            if (nAnnotations == 1)
                return resources.getString("ui.selection.annotations-single");
            else if (nAnnotations > 1)
                return String.format(resources.getString("ui.selection.annotations-multiple"), nAnnotations);
            else if (nDetections == 1)
                return resources.getString("ui.selection.detections-single");
            else if (nDetections > 1)
                return String.format(resources.getString("ui.selection.detections-multiple"), nDetections);
            else
                return resources.getString("ui.selection.empty");
        }

        private StringBinding createWarningTextBinding() {
            return Bindings.createStringBinding(this::getWarningText,
                    imageDataProperty,
                    modelChoiceBox.getSelectionModel().selectedItemProperty(),
                    selectedObjectCounter.numSelectedAnnotations,
                    selectedObjectCounter.numSelectedDetections);
        }

        private String getWarningText() {
            if (imageDataProperty.get() == null)
                return resources.getString("ui.error.no-image");
            if (modelChoiceBox.getSelectionModel().isEmpty())
                return resources.getString("ui.error.no-model");
            if (selectedObjectCounter.numSelectedAnnotations.get() == 0 && selectedObjectCounter.numSelectedDetections.get() == 0)
                return resources.getString("ui.error.no-selection");
            return null;
        }
    }


    /**
     * Helper class for maintaining a count of selected annotations and detections,
     * determined from an ImageData property (whose value may change).
     * This addresses the awkwardness of attaching/detaching listeners.
     */
    private static class SelectedObjectCounter {

        private final ObjectProperty<ImageData<BufferedImage>> imageDataProperty = new SimpleObjectProperty<>();

        private final PathObjectSelectionListener selectionListener = this::selectedPathObjectChanged;

        private final ObservableValue<PathObjectHierarchy> hierarchyProperty;

        private final IntegerProperty numSelectedAnnotations = new SimpleIntegerProperty();
        private final IntegerProperty numSelectedDetections = new SimpleIntegerProperty();

        SelectedObjectCounter(ObservableValue<ImageData<BufferedImage>> imageDataProperty) {
            this.imageDataProperty.bind(imageDataProperty);
            this.hierarchyProperty = createHierarchyBinding();
            hierarchyProperty.addListener((observable, oldValue, newValue) -> updateHierarchy(oldValue, newValue));
            updateHierarchy(null, hierarchyProperty.getValue());
        }

        private ObjectBinding<PathObjectHierarchy> createHierarchyBinding() {
            return Bindings.createObjectBinding(() -> {
                        var imageData = imageDataProperty.get();
                        return imageData == null ? null : imageData.getHierarchy();
                    },
                    imageDataProperty);
        }

        /**
         * Move the selection listener from the old hierarchy (if any) to the new one (if any), then recount.
         */
        private void updateHierarchy(PathObjectHierarchy oldValue, PathObjectHierarchy newValue) {
            if (oldValue == newValue)
                return;
            if (oldValue != null)
                oldValue.getSelectionModel().removePathObjectSelectionListener(selectionListener);
            if (newValue != null)
                newValue.getSelectionModel().addPathObjectSelectionListener(selectionListener);
            updateSelectedObjectCounts();
        }

        private void selectedPathObjectChanged(PathObject pathObjectSelected, PathObject previousObject, Collection<PathObject> allSelected) {
            updateSelectedObjectCounts();
        }

        /**
         * Recount selected annotations and detections; always performed on the FX application thread.
         */
        private void updateSelectedObjectCounts() {
            if (!Platform.isFxApplicationThread()) {
                Platform.runLater(this::updateSelectedObjectCounts);
                return;
            }
            var hierarchy = hierarchyProperty.getValue();
            if (hierarchy == null) {
                numSelectedAnnotations.set(0);
                numSelectedDetections.set(0);
            } else {
                var selected = hierarchy.getSelectionModel().getSelectedObjects();
                numSelectedAnnotations.set(
                        (int)selected
                                .stream().filter(PathObject::isAnnotation)
                                .count()
                );
                numSelectedDetections.set(
                        (int)selected
                                .stream().filter(PathObject::isDetection)
                                .count()
                );
            }
        }

    }

    /**
     * Converts between models and the keys under which they are stored in the model collection.
     */
    private static class ModelStringConverter extends StringConverter<WSInferModel> {

        private final WSInferModelCollection models;

        private ModelStringConverter(WSInferModelCollection models) {
            Objects.requireNonNull(models, "Models cannot be null");
            this.models = models;
        }

        @Override
        public String toString(WSInferModel object) {
            for (var entry : models.getModels().entrySet()) {
                if (entry.getValue() == object)
                    return entry.getKey();
            }
            return "";
        }

        @Override
        public WSInferModel fromString(String string) {
            return models.getModels().get(string);
        }
    }

}
--------------------------------------------------------------------------------