├── src
└── main
│ ├── resources
│ ├── META-INF
│ │ └── services
│ │ │ └── qupath.lib.gui.extensions.QuPathExtension
│ └── qupath
│ │ └── ext
│ │ └── wsinfer
│ │ └── ui
│ │ ├── progress_dialog.fxml
│ │ ├── wsinferstyles.css
│ │ ├── strings.properties
│ │ └── wsinfer_control.fxml
│ └── java
│ └── qupath
│ └── ext
│ └── wsinfer
│ ├── models
│ ├── package-info.java
│ ├── WSInferModelCollection.java
│ ├── WSInferTransform.java
│ ├── WSInferModelConfiguration.java
│ ├── WSInferModelLocal.java
│ ├── WSInferUtils.java
│ └── WSInferModel.java
│ ├── ui
│ ├── package-info.java
│ ├── WSInferPrefs.java
│ ├── WSInferCommand.java
│ ├── WSInferProgressDialog.java
│ ├── PytorchManager.java
│ └── WSInferController.java
│ ├── ProgressListener.java
│ ├── ProgressLogger.java
│ ├── WSInferExtension.java
│ ├── MpsSupport.java
│ ├── TileLoader.java
│ └── WSInfer.java
├── gradle
└── wrapper
│ ├── gradle-wrapper.jar
│ └── gradle-wrapper.properties
├── .github
└── workflows
│ ├── github_release.yml
│ ├── build.yml
│ ├── scijava-maven.yml
│ └── update-gradle.yml
├── settings.gradle.kts
├── .gitignore
├── CHANGELOG.md
├── README.md
├── gradlew.bat
├── gradlew
└── LICENSE
/src/main/resources/META-INF/services/qupath.lib.gui.extensions.QuPathExtension:
--------------------------------------------------------------------------------
1 | qupath.ext.wsinfer.WSInferExtension
2 |
--------------------------------------------------------------------------------
/gradle/wrapper/gradle-wrapper.jar:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/qupath/qupath-extension-wsinfer/HEAD/gradle/wrapper/gradle-wrapper.jar
--------------------------------------------------------------------------------
/src/main/java/qupath/ext/wsinfer/models/package-info.java:
--------------------------------------------------------------------------------
1 | /**
2 | * Classes to access and cache WSInfer models.
3 | */
4 | package qupath.ext.wsinfer.models;
--------------------------------------------------------------------------------
/src/main/java/qupath/ext/wsinfer/ui/package-info.java:
--------------------------------------------------------------------------------
1 | /**
2 | * Main user interface classes for the WSInfer extension.
3 | */
4 | package qupath.ext.wsinfer.ui;
--------------------------------------------------------------------------------
/.github/workflows/github_release.yml:
--------------------------------------------------------------------------------
1 | name: Make draft release
2 |
3 | on:
4 | workflow_dispatch:
5 |
6 | jobs:
7 | build:
8 | permissions:
9 | contents: write
10 | uses: qupath/actions/.github/workflows/github-release.yml@main
11 |
--------------------------------------------------------------------------------
/gradle/wrapper/gradle-wrapper.properties:
--------------------------------------------------------------------------------
1 | distributionBase=GRADLE_USER_HOME
2 | distributionPath=wrapper/dists
3 | distributionUrl=https\://services.gradle.org/distributions/gradle-8.14.2-bin.zip
4 | networkTimeout=10000
5 | validateDistributionUrl=true
6 | zipStoreBase=GRADLE_USER_HOME
7 | zipStorePath=wrapper/dists
8 |
--------------------------------------------------------------------------------
/.github/workflows/build.yml:
--------------------------------------------------------------------------------
1 | name: Run gradle build
2 |
3 | on:
4 | push:
5 | branches:
6 | - "main"
7 | pull_request:
8 | branches:
9 | - "main"
10 | workflow_dispatch:
11 |
12 | permissions:
13 | contents: read
14 |
15 | jobs:
16 | build:
17 | uses: qupath/actions/.github/workflows/gradle.yml@main
18 |
--------------------------------------------------------------------------------
/settings.gradle.kts:
--------------------------------------------------------------------------------
1 | pluginManagement {
2 | repositories {
3 | gradlePluginPortal()
4 | maven {
5 | url = uri("https://maven.scijava.org/content/repositories/releases")
6 | }
7 | }
8 | }
9 |
10 | qupath {
11 | version = "0.6.0"
12 | }
13 |
14 | plugins {
15 | id("io.github.qupath.qupath-extension-settings") version "0.2.1"
16 | }
17 |
--------------------------------------------------------------------------------
/.github/workflows/scijava-maven.yml:
--------------------------------------------------------------------------------
1 | name: Publish release to SciJava Maven
2 |
3 | on:
4 | workflow_dispatch:
5 | inputs:
6 | release:
7 | type: boolean
8 | description: Should this be a release? (if not, do a snapshot)
9 | required: true
10 |
11 | jobs:
12 | build:
13 | name: Publish release
14 | uses: qupath/actions/.github/workflows/scijava-maven.yml@main
15 | secrets: inherit
16 | with:
17 | release: ${{ inputs.release }}
18 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Javadocs
2 | docs/
3 |
4 | # Maven
5 | deploy/
6 | target/
7 | log/
8 |
9 | # IntelliJ
10 | .idea/
11 | *.iml
12 | out/
13 |
14 | # Gradle
15 | # Use local properties (e.g. to set a specific JDK)
16 | gradle.properties
17 | build/
18 | .gradle/
19 | gradle.properties
20 |
21 | # Eclipse
22 | .settings/
23 | .project
24 | .classpath
25 |
26 | # VSCode
27 | .vscode/
28 |
29 | # Mac
30 | .DS_Store
31 |
32 | # Java
33 | hs_err*.log
34 |
35 | # Other
36 | *.tmp
37 | *.bak
38 | *.swp
39 | *~.nib
40 | *thumbs.db
41 | bin/
42 |
--------------------------------------------------------------------------------
/src/main/resources/qupath/ext/wsinfer/ui/progress_dialog.fxml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
17 |
18 |
19 |
20 |
--------------------------------------------------------------------------------
/.github/workflows/update-gradle.yml:
--------------------------------------------------------------------------------
1 | name: Update gradle version
2 |
3 | on:
4 | workflow_dispatch:
5 | inputs:
6 | gradle-version:
7 | description: Gradle version
8 | default: latest
9 | type: string
10 | required: false
11 |
12 | jobs:
13 | update:
14 | env:
15 | GH_TOKEN: ${{ github.token }}
16 | runs-on: ubuntu-latest
17 | defaults:
18 | run:
19 | shell: bash
20 | steps:
21 | - uses: actions/checkout@v4
22 | - name: Update gradlew
23 | run: |
24 | ./gradlew wrapper --gradle-version ${{ inputs.gradle-version }}
25 | - name: Commit and push
26 | run: |
27 | git config user.name "github-actions[bot]"
28 | git config user.email "github-actions[bot]@users.noreply.github.com"
29 | git checkout -b gradle-update
30 | git add .
31 | git commit --allow-empty -m "Update gradle via qupath/actions/.github/workflows/update-gradle.yml"
32 | git push -u origin gradle-update
33 | gh pr create --title "Update gradle via actions" --body "$(./gradlew --version)"
34 |
--------------------------------------------------------------------------------
/src/main/java/qupath/ext/wsinfer/ProgressListener.java:
--------------------------------------------------------------------------------
1 | /**
2 | * Copyright 2023 University of Edinburgh
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 |
17 | package qupath.ext.wsinfer;
18 |
/**
 * Callback used to report progress while inference is running.
 * <p>
 * The interface has a single abstract method, so it can be implemented with a lambda
 * or a method reference.
 */
@FunctionalInterface
public interface ProgressListener {

    /**
     * Notify the listener of a progress update.
     *
     * @param message  optional message to display
     * @param progress optional progress value between 0 and 1;
     *                 if null, only the message should be used
     */
    void updateProgress(String message, Double progress);

}
33 |
--------------------------------------------------------------------------------
/CHANGELOG.md:
--------------------------------------------------------------------------------
1 | # Changelog
2 |
3 | ## v0.3.0
4 |
5 | * Compatibility with QuPath v0.5 (now the minimum version required)
6 | * Enable auto-updating the extension (https://github.com/qupath/qupath-extension-wsinfer/issues/40)
7 | * Support for local models (https://github.com/qupath/qupath-extension-wsinfer/issues/32)
8 | * Show model information, where available
9 | * Code refactoring and UI improvements
10 | * Show tile predictions per second to help with performance tuning
11 | * Support for custom batch sizes
12 | * Can substantially improve performance with GPU acceleration
13 | * Requires PyTorch 2.0.1 or later for Apple Silicon MPS
14 |
15 | ## v0.2.1
16 |
17 | * Fix CPU/GPU support (https://github.com/qupath/qupath-extension-wsinfer/issues/41)
18 |
19 | ## v0.2.0
20 |
21 | * [Preprint on arXiv](https://arxiv.org/abs/2309.04631)
22 | * [Documentation on ReadTheDocs](https://qupath.readthedocs.io/en/0.4/docs/deep/wsinfer.html)
23 | * Faster tiling for complex regions (https://github.com/qupath/qupath-extension-wsinfer/pull/34)
24 | * Searchable combo box for model selection (https://github.com/qupath/qupath-extension-wsinfer/pull/31)
25 | * Improved support for offline use (once models are downloaded) (https://github.com/qupath/qupath-extension-wsinfer/pull/30)
26 |
27 | ## v0.1.0
28 |
29 | * First release!
30 |
--------------------------------------------------------------------------------
/src/main/resources/qupath/ext/wsinfer/ui/wsinferstyles.css:
--------------------------------------------------------------------------------
1 | /* Title */
2 | .wsinfer-title {
3 | -fx-font-size: 150%;
4 | -fx-font-weight: bold;
5 | -fx-text-fill: -fx-text-base-color;
6 | }
7 |
8 | /* Sub-title */
9 | .wsinfer-sub-title {
10 | -fx-text-fill: -fx-text-base-color;
11 | -fx-opacity:0.6;
12 | -fx-font-style: italic;
13 | /*-fx-padding: -7.5;*/
14 | }
15 |
16 | /* message-label-styles */
17 | .warning-message {
18 | -fx-text-fill: -qp-script-error-color;
19 | -fx-opacity:0.8;
20 | -fx-font-style: italic;
21 | }
22 | .standard-message {
23 | -fx-text-fill: -fx-text-base-color;
24 | -fx-opacity:0.6;
25 | -fx-font-style: italic;
26 | }
27 |
28 | /* separator */
29 | .wsinfer .separator {
30 | -fx-line-color: -fx-text-base-color;
31 | -fx-padding: 15;
32 | }
33 |
34 | /* spacing - current use in HBox*/
35 | .standard-spacing {
36 | -fx-spacing: 7.5;
37 | }
38 |
39 | /* spacing - current use in VBox*/
40 | .standard-vertical-spacing {
41 | -fx-spacing: 5;
42 | }
43 |
44 | .standard-padding {
45 | -fx-padding: 2;
46 | }
47 |
48 | /*download button*/
49 | /*.download-btn {
50 | -fx-background-image: url("icons/download.png");
51 | }*/
52 |
53 | .fa-icon {
54 | -fx-font-family: FontAwesome;
55 | -fx-font-size: 14px;
56 | -fx-fill: -fx-text-base-color;
57 | }
58 |
59 |
60 |
--------------------------------------------------------------------------------
/src/main/java/qupath/ext/wsinfer/models/WSInferModelCollection.java:
--------------------------------------------------------------------------------
1 | /**
2 | * Copyright 2023 University of Edinburgh
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 |
17 | package qupath.ext.wsinfer.models;
18 |
19 | import java.util.Collections;
20 | import java.util.Map;
21 |
/**
 * Wrapper for a collection of WSInfer models.
 * <p>
 * See https://github.com/SBU-BMI/wsinfer-zoo/blob/main/wsinfer_zoo/schemas/wsinfer-zoo-registry.schema.json
 * <p>
 * NOTE(review): no constructor or setter assigns {@code models}; presumably it is populated
 * reflectively by a JSON deserializer - confirm against the code that loads the registry.
 */
public class WSInferModelCollection {

    private Map models;

    /**
     * The one synchronized view over {@link #models}, created on first request.
     * A synchronized map locks on the wrapper object itself, so callers only exclude each other
     * when they all go through the same wrapper.
     * Declared transient so that it is not treated as part of the serialized state.
     */
    private transient Map synchronizedModels;

    /**
     * Get a map of model names to models.
     * <p>
     * The same synchronized view is returned on every call. If no models have been set,
     * an empty map is returned instead of failing with a {@link NullPointerException}.
     *
     * @return a synchronized map backed by the models of this collection; never null
     */
    public synchronized Map getModels() {
        if (synchronizedModels == null) {
            if (models == null) {
                // Nothing was loaded: start from an empty backing map rather than throwing
                models = new java.util.LinkedHashMap();
            }
            synchronizedModels = Collections.synchronizedMap(models);
        }
        return synchronizedModels;
    }
}
39 |
--------------------------------------------------------------------------------
/src/main/java/qupath/ext/wsinfer/models/WSInferTransform.java:
--------------------------------------------------------------------------------
1 | /**
2 | * Copyright 2023 University of Edinburgh
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 |
17 | package qupath.ext.wsinfer.models;
18 |
19 | import java.util.Map;
20 |
/**
 * One preprocessing transform that is used together with a WSInfer model.
 * <p>
 * A transform consists of a name and, optionally, a map of arguments.
 */
public class WSInferTransform {

    private String name;
    private Map arguments;

    /**
     * Get the optional arguments associated with this transform.
     *
     * @return the arguments exactly as stored in this object
     */
    public Map getArguments() {
        return this.arguments;
    }

    /**
     * Get the name of this transform.
     *
     * @return the transform name
     */
    public String getName() {
        return this.name;
    }
}
45 |
--------------------------------------------------------------------------------
/src/main/java/qupath/ext/wsinfer/ProgressLogger.java:
--------------------------------------------------------------------------------
1 | /**
2 | * Copyright 2023 University of Edinburgh
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 |
17 | package qupath.ext.wsinfer;
18 |
19 | import org.slf4j.Logger;
20 | import org.slf4j.LoggerFactory;
21 |
22 | /**
23 | * Implementation of {@link ProgressListener} that sends progress messages to a logger.
24 | */
25 | public class ProgressLogger implements ProgressListener {
26 |
27 | private Logger logger;
28 |
29 | public ProgressLogger(Logger logger) {
30 | this.logger = logger == null ? LoggerFactory.getLogger(ProgressLogger.class) : logger;
31 | }
32 |
33 | @Override
34 | public void updateProgress(String message, Double progress) {
35 | if (message == null && progress == null)
36 | return;
37 | if (message == null)
38 | logger.info("{}%", Math.round(progress * 100));
39 | else if (progress == null)
40 | logger.info(message);
41 | else
42 | logger.info("{}: {}%", message, Math.round(progress*100));
43 | }
44 |
45 | }
46 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | [](https://doi.org/10.1038/s41698-024-00499-9)
2 | [](https://qupath.readthedocs.io/en/stable/docs/deep/wsinfer.html)
3 | [](https://forum.image.sc/tag/wsinfer)
4 | [](https://github.com/qupath/qupath-extension-wsinfer/releases/latest)
5 | [](https://github.com/qupath/qupath-extension-wsinfer/releases)
6 |
7 |
8 | # QuPath WSInfer extension
9 |
10 | This repo contains the QuPath extension that adds support for the [WSInfer Model Zoo](https://wsinfer.readthedocs.io).
11 |
12 | Together, QuPath & WSInfer aim to make deep learning-based patch classification easy and interactive for pathology images.
13 |
14 | To find out more, check out the paper:
15 |
16 | > Kaczmarzyk, J.R., O’Callaghan, A., Inglis, F. et al.
17 | > Open and reusable deep learning for pathology with WSInfer and QuPath.
18 | > npj Precis. Onc. 8, 9 (2024). https://doi.org/10.1038/s41698-024-00499-9
19 |
20 | **If you use this extension, please cite both the [WSInfer paper](https://doi.org/10.1038/s41698-024-00499-9) and the [QuPath paper](https://qupath.readthedocs.io/en/0.4/docs/intro/citing.html)!**
21 |
22 | ## Installation
23 |
24 | Download the latest version of the extension from the [releases page](https://github.com/qupath/qupath-extension-wsinfer/releases).
25 |
26 | Then drag & drop the downloaded .jar file onto the main QuPath window to install it.
27 |
28 | See [the QuPath docs](https://qupath.readthedocs.io/en/0.5/docs/intro/extensions.html) for more info about installing & updating extensions.
29 |
30 | ## Usage
31 |
32 | The WSInfer extension adds a new menu item to QuPath's **Extensions** menu, which can be used to open a WSInfer dialog.
33 |
34 | The dialog tries to guide you through the main steps, from top to bottom.
35 | For a step-by-step guide, see [ReadTheDocs](https://qupath.readthedocs.io/en/stable/docs/deep/wsinfer.html).
36 |
37 | And if things don't go to plan or you want to discuss more, please post your questions on the [Scientific Community Image Forum (image.sc)](https://forum.image.sc/tag/wsinfer) - and tag both `wsinfer` and `qupath` so we notice.
38 |
--------------------------------------------------------------------------------
/src/main/java/qupath/ext/wsinfer/models/WSInferModelConfiguration.java:
--------------------------------------------------------------------------------
1 | /**
2 | * Copyright 2023 University of Edinburgh
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 |
17 | package qupath.ext.wsinfer.models;
18 |
19 | import com.google.gson.annotations.SerializedName;
20 |
21 | import java.util.List;
22 |
23 | /**
24 | * Class representing the configuration of a WSInfer model.
25 | *
26 | * This is a Java representation of the JSON configuration file,
27 | * and stores the key information needed to run the model (preprocessing, resolution, output classes).
28 | *
29 | * See https://github.com/SBU-BMI/wsinfer-zoo/blob/main/wsinfer_zoo/schemas/model-config.schema.json
30 | */
31 | public class WSInferModelConfiguration {
32 |
33 | @SerializedName("spec_version")
34 | private String specVersion;
35 |
36 | private String architecture;
37 |
38 | @SerializedName("num_classes")
39 | private int numClasses;
40 | private List class_names;
41 |
42 | @SerializedName("patch_size_pixels")
43 | private int patchSizePixels;
44 |
45 | @SerializedName("spacing_um_px")
46 | private double spacingUmPx;
47 |
48 | private List transform;
49 |
50 | /**
51 | * Output classification names.
52 | * @return
53 | */
54 | public List getClassNames() {
55 | return class_names;
56 | }
57 |
58 | /**
59 | * Transform for preprocessing input patches.
60 | * @return
61 | */
62 | public List getTransform() {
63 | return transform;
64 | }
65 |
66 | /**
67 | * Classification patch size in pixels
68 | * @return
69 | */
70 | public double getPatchSizePixels() {
71 | return patchSizePixels;
72 | }
73 |
74 | /**
75 | * Requested pixel size in microns.
76 | * @return
77 | */
78 | public double getSpacingMicronPerPixel() {
79 | return spacingUmPx;
80 | }
81 | }
82 |
--------------------------------------------------------------------------------
/src/main/java/qupath/ext/wsinfer/ui/WSInferPrefs.java:
--------------------------------------------------------------------------------
1 | /**
2 | * Copyright 2023 University of Edinburgh
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 |
17 | package qupath.ext.wsinfer.ui;
18 |
19 | import javafx.beans.property.Property;
20 | import javafx.beans.property.StringProperty;
21 | import qupath.lib.gui.UserDirectoryManager;
22 | import qupath.lib.gui.prefs.PathPrefs;
23 |
24 | import java.nio.file.Path;
25 | import java.nio.file.Paths;
26 |
27 | /**
28 | * Class to store preferences associated with WSInfer.
29 | */
30 | public class WSInferPrefs {
31 |
32 | private static final StringProperty modelDirectoryProperty = PathPrefs.createPersistentPreference(
33 | "wsinfer.model.dir",
34 | Paths.get(getUserDir().toString(),"wsinfer").toString()
35 | );
36 |
37 | private static final StringProperty deviceProperty = PathPrefs.createPersistentPreference(
38 | "wsinfer.device",
39 | "cpu"
40 | );
41 |
42 | private static final Property numWorkersProperty = PathPrefs.createPersistentPreference(
43 | "wsinfer.numWorkers",
44 | Math.min(4, Runtime.getRuntime().availableProcessors())
45 | ).asObject();
46 |
47 | private static final Property batchSizeProperty = PathPrefs.createPersistentPreference(
48 | "wsinfer.batchSize",
49 | 4
50 | ).asObject();
51 |
52 | /**
53 | * String storing the preferred directory to cache models.
54 | */
55 | public static StringProperty modelDirectoryProperty() {
56 | return modelDirectoryProperty;
57 | }
58 |
59 | /**
60 | * String storing the preferred device to use for model inference.
61 | */
62 | public static StringProperty deviceProperty() {
63 | return deviceProperty;
64 | }
65 |
66 | /**
67 | * Integer storing the preferred number of workers for tile requests.
68 | */
69 | public static Property numWorkersProperty() {
70 | return numWorkersProperty;
71 | }
72 |
73 | /**
74 | * Integer storing the batch size for inference.
75 | */
76 | public static Property batchSizeProperty() {
77 | return batchSizeProperty;
78 | }
79 |
80 | private static Path getUserDir() {
81 | Path userPath = UserDirectoryManager.getInstance().getUserPath();
82 | Path cachePath = Paths.get(System.getProperty("user.dir"), ".cache", "QuPath");
83 | return userPath == null || userPath.toString().isEmpty() ? cachePath : userPath;
84 | }
85 |
86 | }
87 |
--------------------------------------------------------------------------------
/src/main/java/qupath/ext/wsinfer/models/WSInferModelLocal.java:
--------------------------------------------------------------------------------
1 | /**
2 | * Copyright 2023 University of Edinburgh
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 |
17 | package qupath.ext.wsinfer.models;
18 |
19 | import org.apache.commons.compress.utils.FileNameUtils;
20 |
21 | import java.io.File;
22 | import java.io.IOException;
23 | import java.util.Arrays;
24 | import java.util.List;
25 | import java.util.Objects;
26 | import java.util.ResourceBundle;
27 |
28 | public class WSInferModelLocal extends WSInferModel {
29 |
30 | private final File modelDirectory;
31 | private static final ResourceBundle resources = ResourceBundle.getBundle("qupath.ext.wsinfer.ui.strings");
32 | private File torchscriptFile;
33 |
34 | /**
35 | * Try to create a WSInfer model from a user directory.
36 | * @param modelDirectory A user directory containing a
37 | * torchscript.pt file and a config.json file
38 | * @return A {@link WSInferModel} if the directory supplied is valid,
39 | * otherwise nothing.
40 | */
41 | public static WSInferModelLocal createInstance(File modelDirectory) throws IOException {
42 | return new WSInferModelLocal(modelDirectory);
43 | }
44 |
45 | private WSInferModelLocal(File modelDirectory) throws IOException {
46 | this.modelDirectory = modelDirectory;
47 | this.hfRepoId = modelDirectory.getName();
48 | List files = Arrays.asList(Objects.requireNonNull(modelDirectory.listFiles()));
49 | for (File f: files) {
50 | if (FileNameUtils.getExtension(f.toPath()).equals("pt")) {
51 | this.torchscriptFile = f;
52 | }
53 | }
54 | if (!files.contains(getConfigFile())) {
55 | throw new IOException(resources.getString("error.localModel") + ": " + getConfigFile().toString());
56 | }
57 | if (torchscriptFile == null) {
58 | throw new IOException(resources.getString("error.localModel") + ": " + "torchscript model (.pt)");
59 | }
60 | }
61 |
62 | @Override
63 | File getModelDirectory() {
64 | return this.modelDirectory;
65 | }
66 |
67 | @Override
68 | public boolean isValid() {
69 | return getTorchScriptFile().exists() && getConfiguration() != null;
70 | }
71 |
72 | @Override
73 | public File getTorchScriptFile() {
74 | return this.torchscriptFile;
75 | }
76 |
77 | @Override
78 | public synchronized void downloadModel() {}
79 |
80 | @Override
81 | public synchronized void removeCache() {}
82 | }
83 |
--------------------------------------------------------------------------------
/src/main/java/qupath/ext/wsinfer/WSInferExtension.java:
--------------------------------------------------------------------------------
1 | /**
2 | * Copyright 2023 University of Edinburgh
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 |
17 | package qupath.ext.wsinfer;
18 |
19 | import javafx.beans.property.BooleanProperty;
20 | import javafx.scene.control.MenuItem;
21 | import org.slf4j.Logger;
22 | import org.slf4j.LoggerFactory;
23 | import qupath.ext.wsinfer.ui.WSInferCommand;
24 | import qupath.lib.common.Version;
25 | import qupath.lib.gui.QuPathGUI;
26 | import qupath.lib.gui.extensions.GitHubProject;
27 | import qupath.lib.gui.extensions.QuPathExtension;
28 | import qupath.lib.gui.prefs.PathPrefs;
29 |
30 | import java.util.ResourceBundle;
31 |
32 | /**
33 | * QuPath extension to run patch-based deep learning inference with WSInfer.
34 | * See https://wsinfer.readthedocs.io for more info.
35 | */
36 | public class WSInferExtension implements QuPathExtension, GitHubProject {
37 | private final static ResourceBundle resources = ResourceBundle.getBundle("qupath.ext.wsinfer.ui.strings");
38 |
39 | private final static Logger logger = LoggerFactory.getLogger(WSInferExtension.class);
40 |
41 | private final static String EXTENSION_NAME = resources.getString("extension.title");
42 |
43 | private final static String EXTENSION_DESCRIPTION = resources.getString("extension.description");
44 |
45 | private final static Version EXTENSION_QUPATH_VERSION = Version.parse(resources.getString("extension.qupath.version"));
46 |
47 | private boolean isInstalled = false;
48 |
49 | private final BooleanProperty enableExtensionProperty = PathPrefs.createPersistentPreference(
50 | "enableExtension", true);
51 |
52 | @Override
53 | public void installExtension(QuPathGUI qupath) {
54 | if (isInstalled) {
55 | logger.debug("{} is already installed", getName());
56 | return;
57 | }
58 | isInstalled = true;
59 | addMenuItems(qupath);
60 | }
61 |
62 | private void addMenuItems(QuPathGUI qupath) {
63 | var menu = qupath.getMenu("Extensions>" + EXTENSION_NAME, true);
64 | MenuItem menuItem = new MenuItem(resources.getString("title"));
65 | WSInferCommand command = new WSInferCommand(qupath);
66 | menuItem.setOnAction(e -> command.run());
67 | menuItem.disableProperty().bind(enableExtensionProperty.not());
68 | menu.getItems().add(menuItem);
69 | }
70 |
71 | @Override
72 | public String getName() {
73 | return EXTENSION_NAME;
74 | }
75 |
76 | @Override
77 | public String getDescription() {
78 | return EXTENSION_DESCRIPTION;
79 | }
80 |
81 | @Override
82 | public Version getQuPathVersion() {
83 | return EXTENSION_QUPATH_VERSION;
84 | }
85 |
86 | @Override
87 | public GitHubRepo getRepository() {
88 | return GitHubRepo.create(getName(), "qupath", "qupath-extension-wsinfer");
89 | }
90 | }
91 |
--------------------------------------------------------------------------------
/gradlew.bat:
--------------------------------------------------------------------------------
1 | @rem
2 | @rem Copyright 2015 the original author or authors.
3 | @rem
4 | @rem Licensed under the Apache License, Version 2.0 (the "License");
5 | @rem you may not use this file except in compliance with the License.
6 | @rem You may obtain a copy of the License at
7 | @rem
8 | @rem https://www.apache.org/licenses/LICENSE-2.0
9 | @rem
10 | @rem Unless required by applicable law or agreed to in writing, software
11 | @rem distributed under the License is distributed on an "AS IS" BASIS,
12 | @rem WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | @rem See the License for the specific language governing permissions and
14 | @rem limitations under the License.
15 | @rem
16 | @rem SPDX-License-Identifier: Apache-2.0
17 | @rem
18 |
@if "%DEBUG%"=="" @echo off
@rem ##########################################################################
@rem
@rem Gradle startup script for Windows
@rem
@rem ##########################################################################

@rem Set local scope for the variables with windows NT shell
if "%OS%"=="Windows_NT" setlocal

@rem Use the directory containing this script as the application home,
@rem falling back to the current directory if that expands to nothing
set DIRNAME=%~dp0
if "%DIRNAME%"=="" set DIRNAME=.
@rem This is normally unused
set APP_BASE_NAME=%~n0
set APP_HOME=%DIRNAME%

@rem Resolve any "." and ".." in APP_HOME to make it shorter.
for %%i in ("%APP_HOME%") do set APP_HOME=%%~fi

@rem Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script.
set DEFAULT_JVM_OPTS="-Xmx64m" "-Xms64m"

@rem Find java.exe
if defined JAVA_HOME goto findJavaFromJavaHome

@rem JAVA_HOME is not defined, so try java.exe from PATH and continue only if it runs successfully
set JAVA_EXE=java.exe
%JAVA_EXE% -version >NUL 2>&1
if %ERRORLEVEL% equ 0 goto execute

echo. 1>&2
echo ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. 1>&2
echo. 1>&2
echo Please set the JAVA_HOME variable in your environment to match the 1>&2
echo location of your Java installation. 1>&2

goto fail

:findJavaFromJavaHome
@rem Remove any double quotes from JAVA_HOME before building the path to java.exe
set JAVA_HOME=%JAVA_HOME:"=%
set JAVA_EXE=%JAVA_HOME%/bin/java.exe

if exist "%JAVA_EXE%" goto execute

echo. 1>&2
echo ERROR: JAVA_HOME is set to an invalid directory: %JAVA_HOME% 1>&2
echo. 1>&2
echo Please set the JAVA_HOME variable in your environment to match the 1>&2
echo location of your Java installation. 1>&2

goto fail

:execute
@rem Setup the command line

set CLASSPATH=%APP_HOME%\gradle\wrapper\gradle-wrapper.jar


@rem Execute Gradle
"%JAVA_EXE%" %DEFAULT_JVM_OPTS% %JAVA_OPTS% %GRADLE_OPTS% "-Dorg.gradle.appname=%APP_BASE_NAME%" -classpath "%CLASSPATH%" org.gradle.wrapper.GradleWrapperMain %*

:end
@rem End local scope for the variables with windows NT shell
@rem A zero exit code from the Java process skips the fail handler below; anything else falls through into it
if %ERRORLEVEL% equ 0 goto mainEnd

:fail
rem Set variable GRADLE_EXIT_CONSOLE if you need the _script_ return code instead of
rem the _cmd.exe /c_ return code!
set EXIT_CODE=%ERRORLEVEL%
rem This path always reports failure, even if ERRORLEVEL was still 0 on arrival
if %EXIT_CODE% equ 0 set EXIT_CODE=1
if not ""=="%GRADLE_EXIT_CONSOLE%" exit %EXIT_CODE%
exit /b %EXIT_CODE%

:mainEnd
if "%OS%"=="Windows_NT" endlocal

:omega
95 |
--------------------------------------------------------------------------------
/src/main/java/qupath/ext/wsinfer/ui/WSInferCommand.java:
--------------------------------------------------------------------------------
1 | /**
2 | * Copyright 2023 University of Edinburgh
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 |
17 | package qupath.ext.wsinfer.ui;
18 |
19 | import javafx.fxml.FXMLLoader;
20 | import javafx.scene.Scene;
21 | import javafx.scene.layout.BorderPane;
22 | import javafx.scene.layout.VBox;
23 | import javafx.stage.Stage;
24 | import org.slf4j.Logger;
25 | import org.slf4j.LoggerFactory;
26 | import qupath.fx.utils.FXUtils;
27 | import qupath.lib.gui.QuPathGUI;
28 | import qupath.fx.dialogs.Dialogs;
29 |
30 |
31 | import java.io.IOException;
32 | import java.net.URL;
33 | import java.util.ResourceBundle;
34 |
35 | /**
36 | * Command to open the QuPath-WSInfer user interface.
37 | */
38 | public class WSInferCommand implements Runnable {
39 |
40 | private static final Logger logger = LoggerFactory.getLogger(WSInferCommand.class);
41 |
42 | private final QuPathGUI qupath;
43 |
44 | private final ResourceBundle resources = ResourceBundle.getBundle("qupath.ext.wsinfer.ui.strings");
45 |
46 | private Stage stage;
47 | private WSInferController controller;
48 |
49 | public WSInferCommand(QuPathGUI qupath) {
50 | this.qupath = qupath;
51 | }
52 |
53 | @Override
54 | public void run() {
55 | if (stage == null) {
56 | try {
57 | stage = createStage();
58 | stage.show();
59 | FXUtils.retainWindowPosition(stage);
60 | } catch (IOException e) {
61 | Dialogs.showErrorMessage(resources.getString("title"),
62 | resources.getString("error.window"));
63 | logger.error(e.getMessage(), e);
64 | return;
65 | }
66 | } else {
67 | controller.refreshAvailableModels();
68 | stage.show();
69 | }
70 | }
71 |
72 | private Stage createStage() throws IOException {
73 |
74 | URL url = getClass().getResource("wsinfer_control.fxml");
75 | if (url == null) {
76 | throw new IOException("Cannot find URL for WSInfer FXML");
77 | }
78 |
79 | // We need to use the ExtensionClassLoader to load the FXML, since it's in a different module
80 | var loader = new FXMLLoader(url, resources);
81 | loader.setClassLoader(this.getClass().getClassLoader());
82 | VBox root = loader.load();
83 | controller = loader.getController();
84 |
85 | // There's probably a better approach... but wrapping in a border pane
86 | // helped me get the resizing to behave
87 | BorderPane pane = new BorderPane(root);
88 | Scene scene = new Scene(pane);
89 |
90 | Stage stage = new Stage();
91 | stage.initOwner(qupath.getStage());
92 | stage.setTitle(resources.getString("title"));
93 | stage.setScene(scene);
94 | stage.setResizable(false);
95 |
96 | root.heightProperty().addListener((v, o, n) -> handleStageHeightChange());
97 |
98 | return stage;
99 | }
100 |
101 | private void handleStageHeightChange() {
102 | stage.sizeToScene();
103 | // This fixes a bug where the stage would migrate to the corner of a screen if it is
104 | // resized, hidden, then shown again
105 | if (stage.isShowing() && Double.isFinite(stage.getX()) && Double.isFinite(stage.getY()))
106 | FXUtils.retainWindowPosition(stage);
107 | }
108 |
109 |
110 | }
111 |
--------------------------------------------------------------------------------
/src/main/java/qupath/ext/wsinfer/ui/WSInferProgressDialog.java:
--------------------------------------------------------------------------------
1 | /**
2 | * Copyright 2023 University of Edinburgh
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 |
17 | package qupath.ext.wsinfer.ui;
18 |
19 | import javafx.application.Platform;
20 | import javafx.event.ActionEvent;
21 | import javafx.event.EventHandler;
22 | import javafx.fxml.FXML;
23 | import javafx.fxml.FXMLLoader;
24 | import javafx.scene.Scene;
25 | import javafx.scene.control.Button;
26 | import javafx.scene.control.Label;
27 | import javafx.scene.control.ProgressBar;
28 | import javafx.scene.layout.AnchorPane;
29 | import javafx.stage.Modality;
30 | import javafx.stage.Stage;
31 | import javafx.stage.Window;
32 | import org.slf4j.Logger;
33 | import org.slf4j.LoggerFactory;
34 | import qupath.ext.wsinfer.ProgressListener;
35 |
36 | import java.io.IOException;
37 | import java.net.URL;
38 | import java.util.ResourceBundle;
39 |
40 |
41 | /**
42 | * Implementation of {@link ProgressListener} that shows a JavaFX progress bar.
43 | */
44 | class WSInferProgressDialog extends AnchorPane implements ProgressListener {
45 |
46 | private static final Logger logger = LoggerFactory.getLogger(WSInferProgressDialog.class);
47 | private final Stage stage;
48 | @FXML
49 | private ProgressBar progressBar;
50 | @FXML
51 | private Label progressLabel;
52 | @FXML
53 | private Button btnCancel;
54 |
55 | private boolean isCancelled = false;
56 |
57 | public WSInferProgressDialog(Window owner, EventHandler cancelHandler) {
58 | URL url = getClass().getResource("progress_dialog.fxml");
59 | ResourceBundle resources = ResourceBundle.getBundle("qupath.ext.wsinfer.ui.strings");
60 | FXMLLoader loader = new FXMLLoader(url, resources);
61 | loader.setRoot(this);
62 | loader.setController(this);
63 | try {
64 | loader.load();
65 | } catch (IOException e) {
66 | logger.error("Cannot find FXML for class WSInferProgressDialog", e);
67 | }
68 | progressLabel.setLabelFor(progressBar);
69 | stage = new Stage();
70 | stage.setTitle(resources.getString("title"));
71 | stage.setResizable(false);
72 | stage.initOwner(owner);
73 | stage.initModality(Modality.APPLICATION_MODAL);
74 | stage.setScene(new Scene(this));
75 | btnCancel.setOnAction(cancelHandler);
76 | }
77 |
78 |
79 | @Override
80 | public void updateProgress(String message, Double progress) {
81 | if (isCancelled)
82 | return;
83 | if (Platform.isFxApplicationThread()) {
84 | if (message != null)
85 | progressLabel.setText(message);
86 | if (progress != null) {
87 | progressBar.setProgress(progress);
88 | if (progress.doubleValue() >= 1.0) {
89 | stage.hide();
90 | } else {
91 | stage.show();
92 | }
93 | }
94 | } else {
95 | Platform.runLater(() -> updateProgress(message, progress));
96 | }
97 | }
98 |
99 | /**
100 | * Immediately cancel and hide the progress dialog.
101 | * Subsequent calls to updateProgress will be ignored.
102 | */
103 | public void cancel() {
104 | isCancelled = true;
105 | stage.hide();
106 | progressLabel.setText("Cancelled");
107 | progressBar.setProgress(1.0);
108 | }
109 |
110 | }
111 |
--------------------------------------------------------------------------------
/src/main/resources/qupath/ext/wsinfer/ui/strings.properties:
--------------------------------------------------------------------------------
1 | ## General
2 | title = WSInfer
3 | extension.description = Deep learning inference on tiled whole slide images using WSInfer.
4 | extension.title = WSInfer Extension
5 | # Required QuPath version
6 | extension.qupath.version = v0.5.0
7 |
8 | ## Main Window
9 | # Workflow
10 | workflow.title = Run WSInfer model
11 |
12 | # Processing tab
13 | ui.processing.pane = Run Inference
14 | ui.model = Select a model
15 | ui.model.tooltip = Select the model you would like to use
16 | ui.model.download.tooltip = Download the selected model
17 | ui.model-popup = The selected model is not available - do you want to download it?
18 | ui.model-not-downloaded = Model not downloaded - no inference performed
19 | ui.model.info.tooltip = Show model description and citation
20 |
21 | ui.selection.label = Create or select an annotation
22 | ui.selection.alt = or
23 | ui.selection.sub-label = Select all
24 | ui.selection.all-annotations = Annotations
25 | ui.selection.all-annotations.tooltip = Selects all annotations in current image - tiles will be generated within these annotations
26 | ui.selection.all-tiles = Tiles
27 | ui.selection.all-tiles.tooltip = Selects all tiles in current image - inferences will be made using these tiles
28 |
29 | ui.run = Run
30 | ui.run.tooltip = Run the selected model
31 | ui.error.no-selection = No annotation or detection selected
32 | ui.error.no-model = No model selected
33 | ui.error.no-image = No image selected
34 |
35 | ui.selection.annotations-single = 1 annotation selected
36 | ui.selection.annotations-multiple = %d annotations selected
37 | ui.selection.detections-single = 1 detection selected
38 | ui.selection.detections-multiple = %d detections selected
39 | ui.selection.empty = No valid objects selected
40 |
41 | # Results viewing tools tab
42 | ui.results.pane = View Results
43 | ui.results.open-measurement-maps = Measurement Maps
44 | ui.results.maps.tooltip = Open the measurement maps window
45 | ui.results.open-results = Results Table
46 | ui.results.results.tooltip = Open the results table window
47 | ui.results.slider = Overlay Opacity
48 | ui.results.slider.tooltip = Adjust the opacity of the QuPath viewer overlay
49 |
50 | # Hardware tab
51 | ui.options.pane = Additional Options
52 | ui.options.device = Preferred device:
53 | ui.options.device.tooltip = Select the preferred device for model running (choose CPU if other options are not available)
54 | ui.options.directory = Downloaded model directory:
55 | ui.options.directory.tooltip = Choose the directory where models should be stored
56 | ui.options.localModelDirectory = User model directory:
57 | ui.options.localModelDirectory.tooltip = Choose the directory where user-created models can be loaded from
58 | ui.options.pworkers = Number of parallel tile loaders:
59 | ui.options.pworkers.tooltip = Choose the desired number of threads used to request tiles for inference
60 | ui.options.batchSize = Batch size:
61 | ui.options.batchSize.tooltip = Choose the batch size for inference
62 |
63 | # Model directories
64 | ui.model-directory.choose-directory = Choose directory
65 | ui.model-directory.no-local-models = No local models found
66 | ui.model-directory.found-1-local-model = 1 local model found
67 | ui.model-directory.found-n-local-models = %d local models found
68 |
69 | ## Other Windows
70 | # Processing Window and progress pop-ups
71 | ui.processing = Processing tiles
72 | ui.processing-progress = Processing %d/%d tiles (%.1f per second)
73 | ui.processing-completed = Completed %d/%d tiles (%.1f per second)
74 | ui.cancel = Cancel
75 | ui.popup.fetching = Downloading model: %s
76 | ui.popup.available = Model available: %s
77 |
78 | # PyTorch Download Window
79 | ui.pytorch-downloading = Downloading PyTorch engine...
80 | ui.pytorch = PyTorch engine not found - would you like to download it?\n This may take some time (but is only required once).
81 | ui.pytorch-popup = No inference performed - PyTorch engine not found.
82 |
83 | # Stop Tasks Window
84 | ui.stop-tasks = Stop all running tasks?
85 |
86 | ## Errors
87 | error.window = Error initializing WSInfer window.\nAn internet connection is required when running for the first time.
88 | error.no-imagedata = Cannot run WSInfer without ImageData.
89 | error.downloading = Error downloading files
90 | error.localModel = Can't find file in user model directory
91 |
--------------------------------------------------------------------------------
/src/main/java/qupath/ext/wsinfer/ui/PytorchManager.java:
--------------------------------------------------------------------------------
1 | /**
2 | * Copyright 2023 University of Edinburgh
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 |
17 | package qupath.ext.wsinfer.ui;
18 |
19 | import ai.djl.engine.Engine;
20 | import ai.djl.engine.EngineException;
21 | import org.slf4j.Logger;
22 | import org.slf4j.LoggerFactory;
23 | import qupath.lib.common.GeneralTools;
24 |
25 | import java.io.IOException;
26 | import java.util.Collection;
27 | import java.util.LinkedHashSet;
28 | import java.util.List;
29 | import java.util.Set;
30 | import java.util.concurrent.Callable;
31 |
32 | /**
33 | * Helper class to manage access to PyTorch via Deep Java Library.
34 | */
35 | class PytorchManager {
36 |
37 | private static final Logger logger = LoggerFactory.getLogger(PytorchManager.class);
38 |
39 | /**
40 | * Get the available devices for PyTorch, including MPS if Apple Silicon.
41 | * @return
42 | */
43 | static Collection getAvailableDevices() {
44 | Set availableDevices = new LinkedHashSet<>();
45 | try {
46 |
47 | var engine = getEngineOffline();
48 | if (engine != null) {
49 | // This is expected to return GPUs if available, or CPU otherwise
50 | for (var device : engine.getDevices()) {
51 | String name = device.getDeviceType();
52 | availableDevices.add(name);
53 | }
54 | }
55 | // CPU should always be available
56 | availableDevices.add("cpu");
57 |
58 | // If we could use MPS, but don't have it already, add it
59 | if (GeneralTools.isMac() && "aarch64".equals(System.getProperty("os.arch"))) {
60 | availableDevices.add("mps");
61 | }
62 | return availableDevices;
63 | } catch (Exception e) {
64 | logger.info("Unable to load engine", e);
65 | return List.of();
66 | }
67 | }
68 |
69 | /**
70 | * Query if the PyTorch engine is already available, without a need to downlaod.
71 | * @return
72 | */
73 | static boolean hasPyTorchEngine() {
74 | return getEngineOffline() != null;
75 | }
76 |
77 | /**
78 | * Get the PyTorch engine, without automatically downloading it if it isn't available.
79 | * @return the engine if available, or null otherwise
80 | */
81 | static Engine getEngineOffline() {
82 | try {
83 | return callOffline(() -> Engine.getEngine("PyTorch"));
84 | } catch (EngineException | IOException e) {
85 | logger.info("Unable to load PyTorch", e);
86 | return null;
87 | } catch (Exception e) {
88 | logger.error(e.getMessage(), e);
89 | return null;
90 | }
91 | }
92 |
93 | /**
94 | * Get the PyTorch engine, downloading if necessary.
95 | * @return the engine if available, or null if this failed
96 | */
97 | static Engine getEngineOnline() {
98 | try {
99 | return callOnline(() -> Engine.getEngine("PyTorch"));
100 | } catch (Exception e) {
101 | logger.error(e.getMessage(), e);
102 | return null;
103 | }
104 | }
105 |
106 | /**
107 | * Call a function with the "offline" property set to true (to block automatic downloads).
108 | * @param callable Function that'll be called offline.
109 | * @return
110 | * @param
111 | * @throws EngineException If the engine can't be loaded (probably because it hasn't been downloaded)
112 | */
113 | private static T callOffline(Callable callable) throws Exception {
114 | return callWithTempProperty("ai.djl.offline", "true", callable);
115 | }
116 |
117 | /**
118 | * Call a function with the "offline" property set to false (to allow automatic downloads).
119 | * @param callable
120 | * @return
121 | * @param
122 | * @throws Exception
123 | */
124 | private static T callOnline(Callable callable) throws Exception {
125 | return callWithTempProperty("ai.djl.offline", "false", callable);
126 | }
127 |
128 |
129 | private static T callWithTempProperty(String property, String value, Callable callable) throws Exception {
130 | String oldValue = System.getProperty(property);
131 | System.setProperty(property, value);
132 | try {
133 | return callable.call();
134 | } finally {
135 | if (oldValue == null)
136 | System.clearProperty(property);
137 | else
138 | System.setProperty(property, oldValue);
139 | }
140 | }
141 |
142 |
143 | }
144 |
--------------------------------------------------------------------------------
/src/main/java/qupath/ext/wsinfer/MpsSupport.java:
--------------------------------------------------------------------------------
1 | /**
2 | * Copyright 2023 University of Edinburgh
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 |
17 | package qupath.ext.wsinfer;
18 |
19 | import ai.djl.modality.Classifications;
20 | import ai.djl.modality.cv.translator.BaseImageTranslator;
21 | import ai.djl.ndarray.NDArray;
22 | import ai.djl.ndarray.NDList;
23 | import ai.djl.ndarray.NDManager;
24 | import ai.djl.ndarray.types.DataType;
25 | import ai.djl.translate.Transform;
26 | import ai.djl.translate.TranslatorContext;
27 |
28 | import java.io.IOException;
29 | import java.util.ArrayList;
30 | import java.util.List;
31 |
32 | /**
33 | * This class contains two subclasses closely based upon classes in Deep Java Library (Apache v2.0,
34 | * see https://github.com/deepjavalibrary/djl)
35 | *
36 | * They have been adapted as described below to work around issues with MPS devices on Apple Silicon.
37 | *
38 | * The classes here should be removed if future updates to DJL make them unnecessary.
39 | *
40 | * It should be possible to use these classes for other devices, since the changes do not restrict use to MPS.
41 | */
42 | class MpsSupport {
43 |
44 | /**
45 | * This is based upon the {@link ai.djl.modality.cv.transform.ToTensor} class, but adapted to work with MPS devices
46 | * on Apple Silicon.
47 | *
48 | * The reason the original class can't be used is {@code result.div(255.0)} requires a float64, which fails on MPS.
49 | * The only required change is to use {@code result.div(255.0f)} instead.
50 | *