├── .github ├── build.sh ├── setup.sh └── workflows │ ├── build-main.yml │ └── build-pr.yml ├── .gitignore ├── LICENSE.txt ├── README.md ├── pom.xml └── src ├── main └── java │ └── net │ └── imagej │ └── tensorflow │ ├── CachedModelBundle.java │ ├── DefaultTensorFlowService.java │ ├── GraphBuilder.java │ ├── TensorFlowLibraryStatus.java │ ├── TensorFlowService.java │ ├── TensorFlowVersion.java │ ├── Tensors.java │ ├── demo │ └── LabelImage.java │ ├── ui │ ├── AvailableTensorFlowVersions.java │ ├── DownloadableTensorFlowVersion.java │ ├── TensorFlowInstallationHandler.java │ ├── TensorFlowLibraryManagementCommand.java │ └── TensorFlowLibraryManagementFrame.java │ └── util │ ├── TensorFlowUtil.java │ └── UnpackUtil.java └── test └── java └── net └── imagej └── tensorflow ├── TensorsTest.java ├── VersionBuilderTest.java ├── demo └── LabelImageTest.java └── ui ├── AvailableTensorFlowVersionsTest.java └── TensorFlowVersionTest.java /.github/build.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | curl -fsLO https://raw.githubusercontent.com/scijava/scijava-scripts/master/ci-build.sh 3 | sh ci-build.sh 4 | -------------------------------------------------------------------------------- /.github/setup.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | curl -fsLO https://raw.githubusercontent.com/scijava/scijava-scripts/master/ci-setup-github-actions.sh 3 | sh ci-setup-github-actions.sh 4 | -------------------------------------------------------------------------------- /.github/workflows/build-main.yml: -------------------------------------------------------------------------------- 1 | name: build 2 | 3 | on: 4 | push: 5 | branches: 6 | - master 7 | tags: 8 | - "*-[0-9]+.*" 9 | 10 | jobs: 11 | build: 12 | runs-on: ubuntu-latest 13 | 14 | steps: 15 | - uses: actions/checkout@v2 16 | - name: Set up Java 17 | uses: actions/setup-java@v3 18 | with: 19 | java-version: '8' 
20 | distribution: 'zulu' 21 | cache: 'maven' 22 | - name: Set up CI environment 23 | run: .github/setup.sh 24 | - name: Execute the build 25 | run: .github/build.sh 26 | env: 27 | GPG_KEY_NAME: ${{ secrets.GPG_KEY_NAME }} 28 | GPG_PASSPHRASE: ${{ secrets.GPG_PASSPHRASE }} 29 | MAVEN_USER: ${{ secrets.MAVEN_USER }} 30 | MAVEN_PASS: ${{ secrets.MAVEN_PASS }} 31 | OSSRH_PASS: ${{ secrets.OSSRH_PASS }} 32 | SIGNING_ASC: ${{ secrets.SIGNING_ASC }} 33 | -------------------------------------------------------------------------------- /.github/workflows/build-pr.yml: -------------------------------------------------------------------------------- 1 | name: build PR 2 | 3 | on: 4 | pull_request: 5 | branches: 6 | - master 7 | 8 | jobs: 9 | build: 10 | runs-on: ubuntu-latest 11 | 12 | steps: 13 | - uses: actions/checkout@v2 14 | - name: Set up Java 15 | uses: actions/setup-java@v3 16 | with: 17 | java-version: '8' 18 | distribution: 'zulu' 19 | cache: 'maven' 20 | - name: Set up CI environment 21 | run: .github/setup.sh 22 | - name: Execute the build 23 | run: .github/build.sh 24 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | /models/ 2 | # Files generated by the maven build. 3 | /target/ 4 | # Eclipse metadata files. 5 | /.classpath 6 | /.project 7 | /.settings/ 8 | -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | Copyright (c) 2017 - 2025, Board of Regents of the University of 2 | Wisconsin-Madison and Google, Inc. 3 | All rights reserved. 4 | 5 | Redistribution and use in source and binary forms, with or without modification, 6 | are permitted provided that the following conditions are met: 7 | 8 | 1. 
Redistributions of source code must retain the above copyright notice, this 9 | list of conditions and the following disclaimer. 10 | 11 | 2. Redistributions in binary form must reproduce the above copyright notice, 12 | this list of conditions and the following disclaimer in the documentation 13 | and/or other materials provided with the distribution. 14 | 15 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 16 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 17 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 18 | ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE 19 | LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 20 | CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 21 | SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 22 | INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 23 | CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 24 | ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 25 | POSSIBILITY OF SUCH DAMAGE. 26 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [![](https://github.com/imagej/imagej-tensorflow/actions/workflows/build-main.yml/badge.svg)](https://github.com/imagej/imagej-tensorflow/actions/workflows/build-main.yml) 2 | 3 | # [ImageJ](https://imagej.net/) + [TensorFlow](https://www.tensorflow.org) integration layer 4 | 5 | This component is a library which can translate between ImageJ images and 6 | TensorFlow tensors. 7 | 8 | It also contains a demo ImageJ command for classifying images using a 9 | TensorFlow image model, adapted from the [TensorFlow image recognition 10 | tutorial](https://www.tensorflow.org/tutorials/image_recognition). 
11 | 12 | ## Quickstart 13 | 14 | ```sh 15 | git clone https://github.com/imagej/imagej-tensorflow 16 | cd imagej-tensorflow 17 | mvn -Pexec 18 | ``` 19 | 20 | This requires [Maven](https://maven.apache.org/install.html). Typically `brew 21 | install maven` on OS X, `apt-get install maven` on Ubuntu, or [detailed 22 | instructions](https://maven.apache.org/install.html) otherwise. 23 | -------------------------------------------------------------------------------- /pom.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4.0.0 4 | 5 | 6 | org.scijava 7 | pom-scijava 8 | 40.0.0 9 | 10 | 11 | 12 | net.imagej 13 | imagej-tensorflow 14 | 1.1.9-SNAPSHOT 15 | 16 | ImageJ-TensorFlow 17 | ImageJ/TensorFlow integration. 18 | https://github.com/imagej/imagej-tensorflow 19 | 2017 20 | 21 | ImageJ 22 | https://imagej.net/ 23 | 24 | 25 | 26 | Simplified BSD License 27 | repo 28 | 29 | 30 | 31 | 32 | 33 | ctrueden 34 | Curtis Rueden 35 | https://imagej.net/people/ctrueden 36 | 37 | founder 38 | lead 39 | developer 40 | debugger 41 | reviewer 42 | support 43 | maintainer 44 | 45 | 46 | 47 | HedgehogCode 48 | Benjamin Wilhelm 49 | https://imagej.net/people/HedgehogCode 50 | 51 | developer 52 | debugger 53 | reviewer 54 | support 55 | maintainer 56 | 57 | 58 | 59 | frauzufall 60 | Deborah Schmidt 61 | https://imagej.net/people/frauzufall 62 | 63 | developer 64 | debugger 65 | reviewer 66 | support 67 | maintainer 68 | 69 | 70 | 71 | 72 | 73 | Asim Shankar 74 | asimshankar 75 | founder 76 | 77 | 78 | Samuel Yang 79 | samueljyang 80 | 81 | 82 | Christian Dietz 83 | https://imagej.net/people/dietzc 84 | dietzc 85 | 86 | 87 | 88 | 89 | 90 | Image.sc Forum 91 | https://forum.image.sc/tag/imagej 92 | 93 | 94 | 95 | 96 | scm:git:https://github.com/imagej/imagej-tensorflow 97 | scm:git:git@github.com:imagej/imagej-tensorflow 98 | HEAD 99 | https://github.com/imagej/imagej-tensorflow 100 | 101 | 102 | GitHub Issues 103 | 
https://github.com/imagej/imagej-tensorflow/issues 104 | 105 | 106 | GitHub Actions 107 | https://github.com/imagej/imagej-tensorflow/actions 108 | 109 | 110 | 111 | net.imagej.tensorflow.demo.LabelImage 112 | net.imagej.tensorflow 113 | 114 | bsd_2 115 | Board of Regents of the University of 116 | Wisconsin-Madison and Google, Inc. 117 | 118 | 119 | sign,deploy-to-scijava 120 | 121 | 2.0.0 122 | 123 | 124 | 125 | 126 | scijava.public 127 | https://maven.scijava.org/content/groups/public 128 | 129 | 130 | 131 | 132 | 133 | 134 | net.imagej 135 | imagej 136 | 137 | 138 | net.imagej 139 | imagej-common 140 | 141 | 142 | net.imagej 143 | imagej-updater 144 | 145 | 146 | 147 | 148 | net.imglib2 149 | imglib2 150 | 151 | 152 | 153 | 154 | org.scijava 155 | scijava-common 156 | 157 | 158 | 163 | org.scijava 164 | scijava-io-http 165 | 166 | 167 | 168 | 169 | org.tensorflow 170 | libtensorflow 171 | 172 | 173 | org.tensorflow 174 | tensorflow 175 | 176 | 177 | 178 | 179 | com.miglayout 180 | miglayout-swing 181 | 182 | 183 | org.apache.commons 184 | commons-compress 185 | 186 | 187 | 188 | 189 | junit 190 | junit 191 | test 192 | 193 | 194 | com.google.guava 195 | guava 196 | test 197 | 198 | 199 | 200 | -------------------------------------------------------------------------------- /src/main/java/net/imagej/tensorflow/CachedModelBundle.java: -------------------------------------------------------------------------------- 1 | /*- 2 | * #%L 3 | * ImageJ/TensorFlow integration. 4 | * %% 5 | * Copyright (C) 2017 - 2025 Board of Regents of the University of 6 | * Wisconsin-Madison and Google, Inc. 7 | * %% 8 | * Redistribution and use in source and binary forms, with or without 9 | * modification, are permitted provided that the following conditions are met: 10 | * 11 | * 1. Redistributions of source code must retain the above copyright notice, 12 | * this list of conditions and the following disclaimer. 13 | * 2. 
Redistributions in binary form must reproduce the above copyright notice, 14 | * this list of conditions and the following disclaimer in the documentation 15 | * and/or other materials provided with the distribution. 16 | * 17 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 18 | * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 19 | * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 20 | * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDERS OR CONTRIBUTORS BE 21 | * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 22 | * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 23 | * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 24 | * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 25 | * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 26 | * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 27 | * POSSIBILITY OF SUCH DAMAGE. 28 | * #L% 29 | */ 30 | 31 | package net.imagej.tensorflow; 32 | 33 | import org.tensorflow.SavedModelBundle; 34 | 35 | /** 36 | * A wrapper for {@link SavedModelBundle} remembering if it got closed. 
37 | * @author Deborah Schmidt 38 | */ 39 | public class CachedModelBundle implements AutoCloseable { 40 | private SavedModelBundle model; 41 | private boolean closed = false; 42 | 43 | public CachedModelBundle(String path, String[] tags) { 44 | this.model = SavedModelBundle.load(path, tags); 45 | } 46 | 47 | public SavedModelBundle model() { 48 | return model; 49 | } 50 | 51 | @Override 52 | public void close() { 53 | closed = true; 54 | model.close(); 55 | } 56 | 57 | public boolean isClosed() { 58 | return closed; 59 | } 60 | } 61 | -------------------------------------------------------------------------------- /src/main/java/net/imagej/tensorflow/DefaultTensorFlowService.java: -------------------------------------------------------------------------------- 1 | /*- 2 | * #%L 3 | * ImageJ/TensorFlow integration. 4 | * %% 5 | * Copyright (C) 2017 - 2025 Board of Regents of the University of 6 | * Wisconsin-Madison and Google, Inc. 7 | * %% 8 | * Redistribution and use in source and binary forms, with or without 9 | * modification, are permitted provided that the following conditions are met: 10 | * 11 | * 1. Redistributions of source code must retain the above copyright notice, 12 | * this list of conditions and the following disclaimer. 13 | * 2. Redistributions in binary form must reproduce the above copyright notice, 14 | * this list of conditions and the following disclaimer in the documentation 15 | * and/or other materials provided with the distribution. 16 | * 17 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 18 | * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 19 | * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 20 | * ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDERS OR CONTRIBUTORS BE 21 | * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 22 | * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 23 | * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 24 | * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 25 | * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 26 | * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 27 | * POSSIBILITY OF SUCH DAMAGE. 28 | * #L% 29 | */ 30 | 31 | package net.imagej.tensorflow; 32 | 33 | import java.io.BufferedReader; 34 | import java.io.File; 35 | import java.io.FileInputStream; 36 | import java.io.IOException; 37 | import java.io.InputStreamReader; 38 | import java.nio.charset.StandardCharsets; 39 | import java.text.DecimalFormat; 40 | import java.util.Arrays; 41 | import java.util.HashMap; 42 | import java.util.List; 43 | import java.util.Map; 44 | import java.util.concurrent.ExecutionException; 45 | import java.util.stream.Collectors; 46 | 47 | import net.imagej.tensorflow.util.TensorFlowUtil; 48 | import net.imagej.tensorflow.util.UnpackUtil; 49 | import org.scijava.app.AppService; 50 | import org.scijava.app.StatusService; 51 | import org.scijava.download.DiskLocationCache; 52 | import org.scijava.download.DownloadService; 53 | import org.scijava.event.EventHandler; 54 | import org.scijava.io.location.BytesLocation; 55 | import org.scijava.io.location.Location; 56 | import org.scijava.log.LogService; 57 | import org.scijava.plugin.Parameter; 58 | import org.scijava.plugin.Plugin; 59 | import org.scijava.service.AbstractService; 60 | import org.scijava.service.Service; 61 | import org.scijava.task.Task; 62 | import org.scijava.task.event.TaskEvent; 63 | import org.scijava.util.ByteArray; 64 | import org.scijava.util.FileUtils; 65 | import org.tensorflow.Graph; 66 | import org.tensorflow.SavedModelBundle; 67 | import 
org.tensorflow.TensorFlow; 68 | 69 | /** 70 | * Default implementation of {@link TensorFlowService}. 71 | * 72 | * @author Curtis Rueden 73 | */ 74 | @Plugin(type = Service.class) 75 | public class DefaultTensorFlowService extends AbstractService implements TensorFlowService 76 | { 77 | 78 | @Parameter 79 | private DownloadService downloadService; 80 | 81 | @Parameter 82 | private AppService appService; 83 | 84 | @Parameter 85 | private LogService logService; 86 | 87 | /** Models which are already cached in memory. */ 88 | private final Map models = new HashMap<>(); 89 | 90 | /** Graphs which are already cached in memory. */ 91 | private final Map graphs = new HashMap<>(); 92 | 93 | /** Labels which are already cached in memory. */ 94 | private final Map> labelses = new HashMap<>(); 95 | 96 | /** Disk cache defining where compressed models are stored locally. */ 97 | private DiskLocationCache modelCache; 98 | 99 | private TensorFlowLibraryStatus tfStatus = TensorFlowLibraryStatus.notLoaded(); 100 | 101 | /** The loaded TensorFlow version. Will be initialized once in loadLibrary */ 102 | private TensorFlowVersion tfVersion; 103 | 104 | private static String CACHE_DIR_PROPERTY_KEY = "imagej.tensorflow.models.dir"; 105 | 106 | // -- TensorFlowService methods -- 107 | 108 | @Deprecated 109 | @Override 110 | public SavedModelBundle loadModel(final Location source, 111 | final String modelName, final String... tags) throws IOException 112 | { 113 | // Get a local directory with unpacked model data. 114 | final File modelDir = modelDir(source, modelName); 115 | return SavedModelBundle.load(modelDir.getAbsolutePath(), tags); 116 | } 117 | 118 | @Override 119 | public CachedModelBundle loadCachedModel(final Location source, 120 | final String modelName, final String... tags) throws IOException 121 | { 122 | final String key = modelName + "/" + Arrays.toString(tags); 123 | 124 | // If the model is already cached in memory, return it. 
125 | if (models.containsKey(key)) { 126 | CachedModelBundle model = models.get(key); 127 | if(!model.isClosed()) return model; 128 | } 129 | 130 | // Get a local directory with unpacked model data. 131 | final File modelDir = modelDir(source, modelName); 132 | 133 | // Load the saved model. 134 | final CachedModelBundle model = // 135 | new CachedModelBundle(modelDir.getAbsolutePath(), tags); 136 | 137 | // Cache the result for performance next time. 138 | models.put(key, model); 139 | 140 | return model; 141 | } 142 | 143 | @Override 144 | public Graph loadGraph(final Location source, final String modelName, 145 | final String graphPath) throws IOException 146 | { 147 | final String key = modelName + "/" + graphPath; 148 | 149 | // If the graph is already cached in memory, return it. 150 | if (graphs.containsKey(key)) return graphs.get(key); 151 | 152 | // Get a local directory with unpacked model data. 153 | final File modelDir = modelDir(source, modelName); 154 | 155 | // Read the serialized graph. 156 | final byte[] graphDef = FileUtils.readFile(new File(modelDir, graphPath)); 157 | 158 | // Convert to a TensorFlow Graph object. 159 | final Graph graph = new Graph(); 160 | graph.importGraphDef(graphDef); 161 | 162 | // Cache the result for performance next time. 163 | graphs.put(key, graph); 164 | 165 | return graph; 166 | } 167 | 168 | @Override 169 | public List loadLabels(final Location source, final String modelName, 170 | final String labelsPath) throws IOException 171 | { 172 | final String key = modelName + "/" + labelsPath; 173 | 174 | // If the labels are already cached in memory, return them. 175 | if (labelses.containsKey(key)) return labelses.get(key); 176 | 177 | // Get a local directory with unpacked model data. 178 | final File modelDir = modelDir(source, modelName); 179 | 180 | // Read the labels. 
181 | final File labelsFile = new File(modelDir, labelsPath); 182 | final List labels; 183 | try (final BufferedReader labelsReader = new BufferedReader( 184 | new InputStreamReader(new FileInputStream(labelsFile), 185 | StandardCharsets.UTF_8))) 186 | { 187 | labels = labelsReader.lines().collect(Collectors.toList()); 188 | } 189 | 190 | // Cache the result for performance next time. 191 | labelses.put(key, labels); 192 | 193 | return labels; 194 | } 195 | 196 | /** 197 | * Loads the TensorFlow library. 198 | */ 199 | @Override 200 | public synchronized void loadLibrary() { 201 | // Do not try to load the library twice 202 | if (tfStatus.triedLoading()) 203 | return; 204 | 205 | createCrashFile(); 206 | 207 | boolean error = true; 208 | try { 209 | tfStatus = loadFromLib(); 210 | if (!tfStatus.isFailed() && tfStatus.isLoaded()) { 211 | error = false; 212 | } 213 | } catch (Exception e) {} 214 | 215 | if (error) { 216 | try { 217 | tfStatus = loadFromJAR(); 218 | if (!tfStatus.isFailed() && tfStatus.isLoaded()) { 219 | error = false; 220 | } 221 | } catch (Exception e) {} 222 | } 223 | 224 | if (!error) { 225 | deleteCrashFile(); 226 | } 227 | 228 | if(!tfStatus.isLoaded() && !tfStatus.isFailed()) { 229 | tfStatus = TensorFlowLibraryStatus.failed("Could not find any TensorFlow library."); 230 | } 231 | 232 | // Handle a previous crash 233 | boolean jniCrashed = false; 234 | boolean jarCrashed = false; 235 | if (getCrashFile().exists()) { 236 | logService.warn("Crash file exists: " + getCrashFile().getAbsolutePath()); 237 | if (getNativeVersionFile().exists()) { 238 | // The jni library crashed the JVM: We should test if the jar library works 239 | jniCrashed = true; 240 | tfStatus = TensorFlowLibraryStatus.crashed("TensorFlow native library crashed: "); 241 | logService.warn("The JVM seems to have crashed when loading the TensorFlow native library."); 242 | logService.warn("Trying to delete TensorFlow libraries from " + TensorFlowUtil.getLibDir(getRoot())); 243 | 
TensorFlowUtil.removeNativeLibraries(getRoot(), logService); 244 | logService.warn("Trying to load the library provided by this JAR instead:" + TensorFlowUtil.getTensorFlowJAR()); 245 | } else { 246 | // The jar library crashed the JVM: We should not try to load it again 247 | jarCrashed = true; 248 | tfStatus = TensorFlowLibraryStatus.crashed("TensorFlow JAR crashed: " + TensorFlowUtil.getTensorFlowJAR()); 249 | logService.warn("The JVM seems to have crashed when loading the TensorFlow JAR library: " + TensorFlowUtil.getTensorFlowJAR()); 250 | logService.warn("Please run Edit > Options > TensorFlow and choose another version or delete the crash file to load the JAR library again."); 251 | } 252 | } 253 | } 254 | 255 | /** 256 | * Tries to load the TensorFlow library from Fiji.app/lib folder. 257 | * In case of success, {@link #tfVersion} is set to the loaded TensorFlow version 258 | * @return the {@link TensorFlowLibraryStatus} indicating success or failure of loading the library 259 | */ 260 | private TensorFlowLibraryStatus loadFromLib() { 261 | try { 262 | System.loadLibrary("tensorflow_jni"); 263 | try { 264 | tfVersion = TensorFlowUtil.readNativeVersionFile(getRoot()); 265 | } catch (IOException e) { 266 | // could not read native version file, unknown version origin 267 | logService.warn(e.getMessage()); 268 | tfVersion = new TensorFlowVersion(TensorFlow.version(), null, null, null); 269 | } 270 | return TensorFlowLibraryStatus.loaded("Using native TensorFlow version: " + tfVersion); 271 | } catch (final UnsatisfiedLinkError e) { 272 | if(e.getMessage().contains("no tensorflow_jni")) { 273 | logService.info("No TF library found in " + TensorFlowUtil.getLibDir(getRoot()) + "."); 274 | return TensorFlowLibraryStatus.notLoaded(); 275 | } else { 276 | try { 277 | tfVersion = TensorFlowUtil.readNativeVersionFile(getRoot()); 278 | logService.warn("Could not load native TF library " + tfVersion + " " + e.getMessage()); 279 | } catch (final UnsatisfiedLinkError | 
IOException e1) { 280 | logService.warn("Could not load native TF library (unknown version) " + e.getMessage()); 281 | } 282 | // TODO maybe ask if the native lib should be deleted from /lib 283 | return TensorFlowLibraryStatus.failed( e.getMessage()); 284 | } 285 | } 286 | } 287 | 288 | /** 289 | * Tries to load the TensorFlow library from the TensorFlow JAR in the class path. 290 | * In case of success, {@link #tfVersion} is set to the loaded TensorFlow version 291 | * @return the {@link TensorFlowLibraryStatus} indicating success or failure of loading the library 292 | */ 293 | private TensorFlowLibraryStatus loadFromJAR() { 294 | // Calling a native method will load the library from the jar 295 | try { 296 | tfVersion = getJarVersion(); 297 | return TensorFlowLibraryStatus.loaded("Using default TensorFlow version from JAR: " + tfVersion); 298 | } catch (UnsatisfiedLinkError | NoClassDefFoundError e) { 299 | logService.warn("Could not load JAR TF library: " + e.getMessage()); 300 | return TensorFlowLibraryStatus.failed(e.getMessage()); 301 | } 302 | } 303 | 304 | private File getNativeVersionFile() { 305 | return TensorFlowUtil.getNativeVersionFile(getRoot()); 306 | } 307 | 308 | private File getCrashFile() { 309 | return TensorFlowUtil.getCrashFile(getRoot()); 310 | } 311 | 312 | @Override 313 | public synchronized TensorFlowVersion getTensorFlowVersion() { 314 | return tfVersion; 315 | } 316 | 317 | @Override 318 | public synchronized TensorFlowLibraryStatus getStatus() { 319 | return tfStatus; 320 | } 321 | 322 | @Override 323 | public File loadFile(final Location source, final String modelName, final String filePath) throws IOException { 324 | // Get a local directory with unpacked model data. 325 | final File modelDir = modelDir(source, modelName); 326 | 327 | return new File(modelDir, filePath); 328 | } 329 | 330 | // -- Disposable methods -- 331 | 332 | @Override 333 | public void dispose() { 334 | // Dispose models. 
335 | for (final CachedModelBundle model : models.values()) { 336 | model.close(); 337 | } 338 | models.clear(); 339 | 340 | // Dispose graphs. 341 | for (final Graph graph : graphs.values()) { 342 | graph.close(); 343 | } 344 | graphs.clear(); 345 | 346 | // Dispose labels. 347 | labelses.clear(); 348 | } 349 | 350 | // -- Helper methods -- 351 | 352 | private DiskLocationCache modelCache() { 353 | if (modelCache == null) initModelCache(); 354 | return modelCache; 355 | } 356 | 357 | private synchronized void initModelCache() { 358 | final DiskLocationCache cache = new DiskLocationCache(); 359 | 360 | // Cache the models into $IMAGEJ_DIR/models. 361 | final File cacheBase; 362 | 363 | String modelCacheDir = System.getProperty(CACHE_DIR_PROPERTY_KEY); 364 | if (modelCacheDir != null) { 365 | cacheBase = new File(modelCacheDir); 366 | } else { 367 | final File baseDir = appService.getApp().getBaseDirectory(); 368 | cacheBase = new File(baseDir, "models"); 369 | } 370 | logService.info("Caching TensorFlow models to " + cacheBase.getAbsolutePath()); 371 | 372 | if (!cacheBase.exists()) 373 | cacheBase.mkdirs(); 374 | 375 | cache.setBaseDirectory(cacheBase); 376 | 377 | modelCache = cache; 378 | } 379 | 380 | // TODO - Migrate unpacking logic into the DownloadService proper. 381 | // And consider whether/how to avoid using so much temporary space. 382 | 383 | private File modelDir(final Location source, final String modelName) 384 | throws IOException 385 | { 386 | final File modelDir = new File(modelCache().getBaseDirectory(), modelName); 387 | if (!modelDir.exists()) try { 388 | downloadAndUnpackResource(source, modelDir); 389 | } 390 | catch (final InterruptedException | ExecutionException exc) { 391 | throw new IOException(exc); 392 | } 393 | return modelDir; 394 | } 395 | 396 | /** Downloads and unpacks a zipped resource. 
*/ 397 | private void downloadAndUnpackResource(final Location source, 398 | final File destDir) throws InterruptedException, ExecutionException, 399 | IOException 400 | { 401 | // Allocate a dynamic byte array. 402 | final ByteArray byteArray = new ByteArray(1024 * 1024); 403 | 404 | // Download the compressed model into the byte array. 405 | final BytesLocation bytes = new BytesLocation(byteArray); 406 | final Task task = // 407 | downloadService.download(source, bytes, modelCache()).task(); 408 | final StatusUpdater statusUpdater = new StatusUpdater(task); 409 | context().inject(statusUpdater); 410 | task.waitFor(); 411 | UnpackUtil.unZip(destDir, byteArray, logService, statusUpdater.statusService); 412 | 413 | statusUpdater.clear(); 414 | } 415 | 416 | private void createCrashFile() { 417 | try { 418 | getCrashFile().getParentFile().mkdirs(); 419 | getCrashFile().createNewFile(); 420 | } catch (final IOException e) { 421 | logService.warn(e); 422 | } 423 | } 424 | 425 | private void deleteCrashFile() { 426 | final File crashFile = getCrashFile(); 427 | if(crashFile.exists()) crashFile.delete(); 428 | } 429 | 430 | private TensorFlowVersion getJarVersion() { 431 | TensorFlowVersion version = TensorFlowUtil.versionFromClassPathJAR(); 432 | String loadedVersion = TensorFlow.version(); 433 | if(!version.getVersionNumber().equals(loadedVersion)) { 434 | logService.warn("Loaded TensorFlow version is " + loadedVersion 435 | + " whereas the TensorFlow class in the classpath suggests version " 436 | + version.getVersionNumber() + "."); 437 | return new TensorFlowVersion(loadedVersion, null, null, null); 438 | } 439 | return version; 440 | } 441 | 442 | private String getRoot() { 443 | return appService.getApp().getBaseDirectory().getAbsolutePath(); 444 | } 445 | /** 446 | * A dumb class which passes task events on to the {@link StatusService}. 447 | * Eventually, this sort of logic will be built in to SciJava Common. But for 448 | * the moment, we do it ourselves. 
449 | */ 450 | private class StatusUpdater { 451 | private final DecimalFormat formatter = new DecimalFormat("##.##"); 452 | private final Task task; 453 | 454 | private long lastUpdate; 455 | 456 | @Parameter 457 | private StatusService statusService; 458 | 459 | private StatusUpdater(final Task task) { 460 | this.task = task; 461 | } 462 | 463 | public void update(final String message) { 464 | statusService.showStatus(message); 465 | } 466 | 467 | public void update(final int value, final int max, final String message) { 468 | final long timestamp = System.currentTimeMillis(); 469 | if (timestamp < lastUpdate + 100) return; // Avoid excessive updates. 470 | lastUpdate = timestamp; 471 | 472 | final double percent = 100.0 * value / max; 473 | statusService.showStatus(value, max, message + ": " + // 474 | formatter.format(percent) + "%"); 475 | } 476 | 477 | public void clear() { 478 | statusService.clearStatus(); 479 | } 480 | 481 | @EventHandler 482 | private void onEvent(final TaskEvent evt) { 483 | if (task == evt.getTask()) { 484 | final int value = (int) task.getProgressValue(); 485 | final int max = (int) task.getProgressMaximum(); 486 | final String message = task.getStatusMessage(); 487 | update(value, max, message); 488 | } 489 | } 490 | } 491 | } 492 | -------------------------------------------------------------------------------- /src/main/java/net/imagej/tensorflow/GraphBuilder.java: -------------------------------------------------------------------------------- 1 | /*- 2 | * #%L 3 | * ImageJ/TensorFlow integration. 4 | * %% 5 | * Copyright (C) 2017 - 2025 Board of Regents of the University of 6 | * Wisconsin-Madison and Google, Inc. 7 | * %% 8 | * Redistribution and use in source and binary forms, with or without 9 | * modification, are permitted provided that the following conditions are met: 10 | * 11 | * 1. Redistributions of source code must retain the above copyright notice, 12 | * this list of conditions and the following disclaimer. 13 | * 2. 
Redistributions in binary form must reproduce the above copyright notice, 14 | * this list of conditions and the following disclaimer in the documentation 15 | * and/or other materials provided with the distribution. 16 | * 17 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 18 | * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 19 | * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 20 | * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDERS OR CONTRIBUTORS BE 21 | * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 22 | * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 23 | * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 24 | * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 25 | * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 26 | * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 27 | * POSSIBILITY OF SUCH DAMAGE. 28 | * #L% 29 | */ 30 | 31 | package net.imagej.tensorflow; 32 | 33 | import java.nio.FloatBuffer; 34 | 35 | import org.tensorflow.Graph; 36 | import org.tensorflow.Output; 37 | import org.tensorflow.Tensor; 38 | 39 | // In the fullness of time, equivalents of the methods of this class should be 40 | // auto-generated from the OpDefs linked into libtensorflow_jni.so. That would 41 | // match what is done in other languages like Python, C++ and Go. 
42 | 43 | public class GraphBuilder { 44 | 45 | private final Graph g; 46 | 47 | public GraphBuilder(final Graph g) { 48 | this.g = g; 49 | } 50 | 51 | public Output div(final Output x, final Output y) { 52 | return binaryOp3("Div", x, y, "Div"); 53 | } 54 | 55 | public Output div(final Output x, final Output y, 56 | final String name) 57 | { 58 | return binaryOp3("Div", x, y, name); 59 | } 60 | 61 | public Output sub(final Output x, final Output y) { 62 | return binaryOp("Sub", x, y, "Sub"); 63 | } 64 | 65 | public Output sub(final Output x, final Output y, 66 | final String name) 67 | { 68 | return binaryOp("Sub", x, y, name); 69 | } 70 | 71 | public Output resizeBilinear(final Output images, final Output size) { 72 | return binaryOp3("ResizeBilinear", images, size, "ResizeBilinear"); 73 | } 74 | 75 | public Output resizeBilinear(final Output images, final Output size, 76 | final String name) 77 | { 78 | return binaryOp3("ResizeBilinear", images, size, name); 79 | } 80 | 81 | public Output expandDims(final Output input, final Output dim) { 82 | return binaryOp3("ExpandDims", input, dim, "ExpandDims"); 83 | } 84 | 85 | public Output expandDims(final Output input, final Output dim, 86 | final String name) 87 | { 88 | return binaryOp3("ExpandDims", input, dim, name); 89 | } 90 | 91 | public Output constant(final String name, final Object value) { 92 | try (Tensor t = Tensor.create(value)) { 93 | return constant(name, t); 94 | } 95 | } 96 | 97 | public Output constant(final String name, final float[] value, 98 | final long... 
shape) 99 | { 100 | try (Tensor t = Tensor.create(shape, FloatBuffer.wrap(value))) { 101 | return constant(name, t); 102 | } 103 | } 104 | 105 | public Output constant(final String name, final Tensor value) { 106 | return g.opBuilder("Const", name).setAttr("dtype", value.dataType()) 107 | .setAttr("value", value).build().output(0); 108 | } 109 | 110 | private Output binaryOp(final String type, final Output in1, 111 | final Output in2, final String name) 112 | { 113 | return g.opBuilder(type, name).addInput(in1).addInput(in2).build().output( 114 | 0); 115 | } 116 | 117 | private Output binaryOp3(final String type, final Output in1, 118 | final Output in2, final String name) 119 | { 120 | return g.opBuilder(type, name).addInput(in1).addInput(in2).build().output( 121 | 0); 122 | } 123 | } 124 | -------------------------------------------------------------------------------- /src/main/java/net/imagej/tensorflow/TensorFlowLibraryStatus.java: -------------------------------------------------------------------------------- 1 | /*- 2 | * #%L 3 | * ImageJ/TensorFlow integration. 4 | * %% 5 | * Copyright (C) 2017 - 2025 Board of Regents of the University of 6 | * Wisconsin-Madison and Google, Inc. 7 | * %% 8 | * Redistribution and use in source and binary forms, with or without 9 | * modification, are permitted provided that the following conditions are met: 10 | * 11 | * 1. Redistributions of source code must retain the above copyright notice, 12 | * this list of conditions and the following disclaimer. 13 | * 2. Redistributions in binary form must reproduce the above copyright notice, 14 | * this list of conditions and the following disclaimer in the documentation 15 | * and/or other materials provided with the distribution. 
/**
 * Immutable description of the loading state of the TensorFlow native
 * library.
 * <p>
 * Instances are obtained through the package-private factory methods; each
 * factory sets at most one of the three state flags.
 * </p>
 *
 * @author Deborah Schmidt
 * @author Benjamin Wilhelm
 */
public final class TensorFlowLibraryStatus {

	private final boolean loaded;
	private final boolean crashed;
	private final boolean failed;
	private final String info;

	private TensorFlowLibraryStatus(final boolean isLoaded,
		final boolean hasCrashed, final boolean hasFailed, final String message)
	{
		loaded = isLoaded;
		crashed = hasCrashed;
		failed = hasFailed;
		info = message;
	}

	/** @return a status with no flag set and an empty info message */
	static TensorFlowLibraryStatus notLoaded() {
		final String noInfo = "";
		return new TensorFlowLibraryStatus(false, false, false, noInfo);
	}

	/** @return a status flagged as loaded, carrying {@code info} */
	static TensorFlowLibraryStatus loaded(final String info) {
		return new TensorFlowLibraryStatus(true, false, false, info);
	}

	/** @return a status flagged as crashed, carrying {@code info} */
	static TensorFlowLibraryStatus crashed(final String info) {
		return new TensorFlowLibraryStatus(false, true, false, info);
	}

	/** @return a status flagged as failed, carrying {@code info} */
	static TensorFlowLibraryStatus failed(final String info) {
		return new TensorFlowLibraryStatus(false, false, true, info);
	}

	/**
	 * @return whether there was an attempt to load the TensorFlow library
	 */
	public boolean triedLoading() {
		// Any of the three outcomes implies that loading was attempted.
		if (loaded) return true;
		if (crashed) return true;
		return failed;
	}

	/**
	 * @return whether the TensorFlow library was successfully loaded
	 */
	public boolean isLoaded() {
		return loaded;
	}

	/**
	 * @return whether the TensorFlow library failed to load
	 */
	public boolean isFailed() {
		return failed;
	}

	/**
	 * @return whether a crash log file was found indicating that the TensorFlow
	 *         library crashed the JVM during
	 *         {@link TensorFlowService#loadLibrary()}
	 */
	public boolean isCrashed() {
		return crashed;
	}

	/**
	 * @return a message describing the failed or successful loading attempt of
	 *         the TensorFlow Library
	 */
	public String getInfo() {
		return info;
	}
}
package net.imagej.tensorflow;

import net.imagej.ImageJService;
import org.scijava.io.location.Location;
import org.tensorflow.Graph;
import org.tensorflow.SavedModelBundle;

import java.io.File;
import java.io.IOException;
import java.util.List;

/**
 * Service for working with TensorFlow.
 * <p>
 * Declares operations for loading the TensorFlow library, querying its version
 * and status, and extracting models, graphs, labels and arbitrary files from
 * ZIP archives.
 * </p>
 *
 * @author Curtis Rueden
 */
public interface TensorFlowService extends ImageJService {

	/**
	 * Extracts a persisted model from the given location.
	 *
	 * @param source The location of the model, which must be structured as a ZIP
	 *          archive.
	 * @param modelName The name of the model by which the source should be
	 *          unpacked and cached as needed.
	 * @param tags Optional list of tags passed to
	 *          {@link SavedModelBundle#load(String, String...)}.
	 * @return The extracted TensorFlow {@link SavedModelBundle} object.
	 * @throws IOException If something goes wrong reading or unpacking the
	 *           archive.
	 * @deprecated Use {@link #loadCachedModel(Location, String, String...)}
	 *             instead.
	 */
	@Deprecated
	SavedModelBundle loadModel(Location source, String modelName, String... tags)
		throws IOException;

	/**
	 * Extracts a persisted model from the given location.
	 * Returns it from cache in case it was already loaded.
	 *
	 * @param source The location of the model, which must be structured as a ZIP
	 *          archive.
	 * @param modelName The name of the model by which the source should be
	 *          unpacked and cached as needed.
	 * @param tags Optional list of tags passed to
	 *          {@link SavedModelBundle#load(String, String...)}.
	 * @return The extracted TensorFlow {@link SavedModelBundle} object
	 * wrapped by a {@link CachedModelBundle}.
	 * @throws IOException If something goes wrong reading or unpacking the
	 *           archive.
	 */
	CachedModelBundle loadCachedModel(Location source, String modelName, String... tags)
		throws IOException;

	/**
	 * Extracts a graph from the given location.
	 *
	 * @param source The location of the graph, which must be structured as a ZIP
	 *          archive.
	 * @param modelName The name of the model by which the source should be
	 *          unpacked and cached as needed.
	 * @param graphPath The name of the .pb file inside the ZIP archive containing
	 *          the graph.
	 * @return The extracted TensorFlow {@link Graph} object.
	 * @throws IOException If something goes wrong reading or unpacking the
	 *           archive.
	 */
	Graph loadGraph(Location source, String modelName, String graphPath)
		throws IOException;

	/**
	 * Extracts labels from the given location.
	 *
	 * @param source The location of the labels, which must be structured as a ZIP
	 *          archive.
	 * @param modelName The name of the model by which the source should be
	 *          unpacked and cached as needed.
	 * @param labelsPath The name of the .txt file inside the ZIP archive
	 *          containing the labels.
	 * @return The labels read from the given file.
	 * @throws IOException If something goes wrong reading or unpacking the
	 *           archive.
	 */
	// NOTE(review): the return type is a raw List here; the element type is
	// presumably String and the type argument looks lost — confirm.
	List loadLabels(Location source, String modelName, String labelsPath)
		throws IOException;

	/**
	 * Loads the TensorFlow Library.
	 */
	void loadLibrary();

	/**
	 * @return the TensorFlow version which is currently loaded.
	 * null if no version is loaded.
	 */
	TensorFlowVersion getTensorFlowVersion();

	/**
	 * @return Status information about the TensorFlow library (e.g. whether it crashed the application)
	 */
	TensorFlowLibraryStatus getStatus();

	/**
	 * Extracts a file from the given location.
	 *
	 * @param source The location of the ZIP archive.
	 * @param modelName The name of the model by which the source should be
	 *          unpacked and cached as needed.
	 * @param filePath The name of the file inside the ZIP archive.
	 * @return A {@link File} object.
	 * @throws IOException If something goes wrong reading or unpacking the
	 *           archive.
	 */
	File loadFile(final Location source, final String modelName, final String filePath)
		throws IOException;
}
Redistributions of source code must retain the above copyright notice, 12 | * this list of conditions and the following disclaimer. 13 | * 2. Redistributions in binary form must reproduce the above copyright notice, 14 | * this list of conditions and the following disclaimer in the documentation 15 | * and/or other materials provided with the distribution. 16 | * 17 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 18 | * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 19 | * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 20 | * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDERS OR CONTRIBUTORS BE 21 | * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 22 | * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 23 | * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 24 | * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 25 | * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 26 | * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 27 | * POSSIBILITY OF SUCH DAMAGE. 28 | * #L% 29 | */ 30 | 31 | package net.imagej.tensorflow; 32 | 33 | import java.util.Objects; 34 | import java.util.Optional; 35 | 36 | /** 37 | * A class describing the details of a specific TensorFlow version. 
38 | * 39 | * @author Deborah Schmidt 40 | * @author Benjamin Wilhelm 41 | */ 42 | public class TensorFlowVersion { 43 | 44 | protected final String tfVersion; 45 | protected final Optional supportsGPU; 46 | protected Optional cuda; 47 | protected Optional cudnn; 48 | 49 | /** 50 | * @param version the TensorFlow version number 51 | * @param supportsGPU whether this version runs on the GPU 52 | * @param compatibleCUDA the CUDA version compatible with this TensorFlow 53 | * version 54 | * @param compatibleCuDNN the CuDNN version compatible with this TensorFlow 55 | * version 56 | */ 57 | public TensorFlowVersion(final String version, final Boolean supportsGPU, final String compatibleCUDA, 58 | final String compatibleCuDNN) { 59 | this.tfVersion = version; 60 | this.supportsGPU = Optional.ofNullable(supportsGPU); 61 | this.cuda = Optional.ofNullable(compatibleCUDA); 62 | this.cudnn = Optional.ofNullable(compatibleCuDNN); 63 | } 64 | 65 | public TensorFlowVersion(TensorFlowVersion other) { 66 | this.tfVersion = other.tfVersion; 67 | this.supportsGPU = other.supportsGPU; 68 | this.cuda = other.cuda; 69 | this.cudnn = other.cudnn; 70 | } 71 | 72 | /** 73 | * @return the version number of TensorFlow, e.g. 
0.13.1 74 | */ 75 | public String getVersionNumber() { 76 | return tfVersion; 77 | } 78 | 79 | /** 80 | * @return if the version supports the usage of the GPU 81 | */ 82 | public Optional usesGPU() { 83 | return supportsGPU; 84 | } 85 | 86 | /** 87 | * @return the CUDA version this TensorFlow version is compatible with 88 | */ 89 | public Optional getCompatibleCUDA() { 90 | return cuda; 91 | } 92 | 93 | /** 94 | * @return the CuDNN version this TensorFlow version is compatible with 95 | */ 96 | public Optional getCompatibleCuDNN() { 97 | return cudnn; 98 | } 99 | 100 | @Override 101 | public boolean equals(final Object obj) { 102 | if (!(obj.getClass().equals(this.getClass()))) return false; 103 | final TensorFlowVersion o = (TensorFlowVersion) obj; 104 | return tfVersion.equals(o.tfVersion) // 105 | && supportsGPU.equals(o.supportsGPU); 106 | } 107 | 108 | @Override 109 | public int hashCode() { 110 | return Objects.hash(tfVersion, supportsGPU, cuda, cudnn); 111 | } 112 | 113 | @Override 114 | public String toString() { 115 | String text = "TF " + tfVersion; 116 | if(supportsGPU.isPresent()) text += " " + (supportsGPU.get() ? "GPU" : "CPU"); 117 | if (cuda.isPresent() || cudnn.isPresent()) { 118 | text += " ("; 119 | if (cuda.isPresent()) 120 | text += "CUDA " + cuda.get(); 121 | if (cuda.isPresent() && cudnn.isPresent()) 122 | text += ", "; 123 | if (cudnn.isPresent()) 124 | text += "CuDNN " + cudnn.get(); 125 | text += ")"; 126 | } 127 | return text; 128 | } 129 | } 130 | -------------------------------------------------------------------------------- /src/main/java/net/imagej/tensorflow/Tensors.java: -------------------------------------------------------------------------------- 1 | /*- 2 | * #%L 3 | * ImageJ/TensorFlow integration. 4 | * %% 5 | * Copyright (C) 2017 - 2025 Board of Regents of the University of 6 | * Wisconsin-Madison and Google, Inc. 
7 | * %% 8 | * Redistribution and use in source and binary forms, with or without 9 | * modification, are permitted provided that the following conditions are met: 10 | * 11 | * 1. Redistributions of source code must retain the above copyright notice, 12 | * this list of conditions and the following disclaimer. 13 | * 2. Redistributions in binary form must reproduce the above copyright notice, 14 | * this list of conditions and the following disclaimer in the documentation 15 | * and/or other materials provided with the distribution. 16 | * 17 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 18 | * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 19 | * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 20 | * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDERS OR CONTRIBUTORS BE 21 | * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 22 | * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 23 | * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 24 | * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 25 | * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 26 | * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 27 | * POSSIBILITY OF SUCH DAMAGE. 
28 | * #L% 29 | */ 30 | 31 | package net.imagej.tensorflow; 32 | 33 | import java.nio.ByteBuffer; 34 | import java.nio.DoubleBuffer; 35 | import java.nio.FloatBuffer; 36 | import java.nio.IntBuffer; 37 | import java.nio.LongBuffer; 38 | import java.util.Arrays; 39 | import java.util.stream.IntStream; 40 | 41 | import org.tensorflow.DataType; 42 | import org.tensorflow.Tensor; 43 | import org.tensorflow.types.UInt8; 44 | 45 | import net.imglib2.Cursor; 46 | import net.imglib2.Dimensions; 47 | import net.imglib2.FinalDimensions; 48 | import net.imglib2.IterableInterval; 49 | import net.imglib2.RandomAccess; 50 | import net.imglib2.RandomAccessibleInterval; 51 | import net.imglib2.img.Img; 52 | import net.imglib2.img.ImgView; 53 | import net.imglib2.img.array.ArrayImg; 54 | import net.imglib2.img.array.ArrayImgs; 55 | import net.imglib2.img.basictypeaccess.array.ByteArray; 56 | import net.imglib2.img.basictypeaccess.array.DoubleArray; 57 | import net.imglib2.img.basictypeaccess.array.FloatArray; 58 | import net.imglib2.img.basictypeaccess.array.IntArray; 59 | import net.imglib2.img.basictypeaccess.array.LongArray; 60 | import net.imglib2.type.numeric.RealType; 61 | import net.imglib2.type.numeric.integer.ByteType; 62 | import net.imglib2.type.numeric.integer.IntType; 63 | import net.imglib2.type.numeric.integer.LongType; 64 | import net.imglib2.type.numeric.real.DoubleType; 65 | import net.imglib2.type.numeric.real.FloatType; 66 | import net.imglib2.util.Intervals; 67 | import net.imglib2.util.Util; 68 | import net.imglib2.view.Views; 69 | 70 | /** 71 | * Utility class for working with TensorFlow {@link Tensor} objects. In 72 | * particular, this class provides methods for converting between ImgLib2 data 73 | * structures and TensorFlow {@link Tensor}s. 
74 | * 75 | * @author Curtis Rueden 76 | * @author Christian Dietz 77 | * @author Benjamin Wilhelm 78 | */ 79 | public final class Tensors { 80 | 81 | private Tensors() { 82 | // NB: Prevent instantiation of utility class. 83 | } 84 | 85 | // --------- TENSOR to RAI --------- 86 | 87 | // NB: The following "agnostic" API is somehow bad due to recursive generics. 88 | // The Img sucks because we don't know it's RealType. 89 | // But "? extends RealType" is also a problem because ?s don't match. 90 | // And putting a T param here breaks things downstream: calling code 91 | // has not T context and so you can assign things to improper types. 92 | // public static Img img(final Tensor image) { 93 | // switch (image.dataType()) { 94 | // case BOOL: 95 | // return imgBool(image); 96 | // case DOUBLE: 97 | // return imgDouble(image); 98 | // case FLOAT: 99 | // return imgFloat(image); 100 | // case INT32: 101 | // return imgInt(image); 102 | // case INT64: 103 | // return imgLong(image); 104 | // case STRING: 105 | // return imgString(image); 106 | // case UINT8: 107 | // return imgByte(image); 108 | // default: 109 | // throw new UnsupportedOperationException(); 110 | // } 111 | // } 112 | 113 | /** 114 | * Creates an image of type {@link ByteType} containing the data of a 115 | * TensorFlow Tensor with the data type {@link DataType#UINT8}. 116 | *

117 | * Note that this does _not_ adjust any dimensions. This means that 118 | * the resulting image will have dimensions corresponding to the reversed 119 | * shape of the Tensor. See {@link #imgByteDirect(Tensor)} and 120 | * {@link #imgByte(Tensor, int[])} if you want to handle dimensions 121 | * differently. 122 | *

123 | * Note also that no exception is thrown if the data type is not 124 | * {@link DataType#UINT8} but it will give unexpected results. 125 | *

126 | * @param image The TensorFlow Tensor. 127 | * @return An image containing the data of the Tensor. 128 | */ 129 | public static Img imgByte(final Tensor image) { 130 | final byte[] out = new byte[image.numElements()]; 131 | image.writeTo(ByteBuffer.wrap(out)); 132 | return ArrayImgs.bytes(out, shape(image)); 133 | } 134 | 135 | /** 136 | * Creates an image of type {@link DoubleType} containing the data of a 137 | * TensorFlow Tensor with the data type {@link DataType#DOUBLE}. 138 | *

139 | * Note that this does _not_ adjust any dimensions. This means that 140 | * the resulting image will have dimensions corresponding to the reversed 141 | * shape of the Tensor. See {@link #imgByteDirect(Tensor)} and 142 | * {@link #imgByte(Tensor, int[])} if you want to handle dimensions 143 | * differently. 144 | *

145 | * @param image The TensorFlow Tensor. 146 | * @return An image containing the data of the Tensor. 147 | * @throws IllegalArgumentException if Tensor data type is not double. 148 | */ 149 | public static Img imgDouble(final Tensor image) { 150 | final double[] out = new double[image.numElements()]; 151 | image.writeTo(DoubleBuffer.wrap(out)); 152 | return ArrayImgs.doubles(out, shape(image)); 153 | } 154 | 155 | /** 156 | * Creates an image of type {@link FloatType} containing the data of a 157 | * TensorFlow Tensor with the data type {@link DataType#FLOAT}. 158 | *

159 | * Note that this does _not_ adjust any dimensions. This means that 160 | * the resulting image will have dimensions corresponding to the reversed 161 | * shape of the Tensor. See {@link #imgByteDirect(Tensor)} and 162 | * {@link #imgByte(Tensor, int[])} if you want to handle dimensions 163 | * differently. 164 | *

165 | * @param image The TensorFlow Tensor. 166 | * @return An image containing the data of the Tensor. 167 | * @throws IllegalArgumentException if Tensor data type is not float. 168 | */ 169 | public static Img imgFloat(final Tensor image) { 170 | final float[] out = new float[image.numElements()]; 171 | image.writeTo(FloatBuffer.wrap(out)); 172 | return ArrayImgs.floats(out, shape(image)); 173 | } 174 | 175 | /** 176 | * Creates an image of type {@link IntType} containing the data of a 177 | * TensorFlow Tensor with the data type {@link DataType#INT32}. 178 | *

179 | * Note that this does _not_ adjust any dimensions. This means that 180 | * the resulting image will have dimensions corresponding to the reversed 181 | * shape of the Tensor. See {@link #imgByteDirect(Tensor)} and 182 | * {@link #imgByte(Tensor, int[])} if you want to handle dimensions 183 | * differently. 184 | *

185 | * @param image The TensorFlow Tensor. 186 | * @return An image containing the data of the Tensor. 187 | * @throws IllegalArgumentException if Tensor data type is not int. 188 | */ 189 | public static Img imgInt(final Tensor image) { 190 | final int[] out = new int[image.numElements()]; 191 | image.writeTo(IntBuffer.wrap(out)); 192 | return ArrayImgs.ints(out, shape(image)); 193 | } 194 | 195 | /** 196 | * Creates an image of type {@link LongType} containing the data of a 197 | * TensorFlow Tensor with the data type {@link DataType#INT64}. 198 | *

199 | * Note that this does _not_ adjust any dimensions. This means that 200 | * the resulting image will have dimensions corresponding to the reversed 201 | * shape of the Tensor. See {@link #imgByteDirect(Tensor)} and 202 | * {@link #imgByte(Tensor, int[])} if you want to handle dimensions 203 | * differently. 204 | *

205 | * @param image The TensorFlow Tensor. 206 | * @return An image containing the data of the Tensor. 207 | * @throws IllegalArgumentException if Tensor data type is not long. 208 | */ 209 | public static Img imgLong(final Tensor image) { 210 | final long[] out = new long[image.numElements()]; 211 | image.writeTo(LongBuffer.wrap(out)); 212 | return ArrayImgs.longs(out, shape(image)); 213 | } 214 | 215 | /** 216 | * Creates an image of type {@link ByteType} containing the data of a 217 | * TensorFlow Tensor with the data type {@link DataType#UINT8}. 218 | *

219 | * Note that no exception is thrown if the data type is not {@link DataType#UINT8} 220 | * but it will give unexpected results. 221 | *

222 | * @param image The TensorFlow Tensor. 223 | * @param dimOrder Defines the mapping of the dimensions between the image 224 | * and the Tensor where the index corresponds to the dimension 225 | * in the image and the value corresponds to the dimension in the 226 | * Tensor. TODO Example? 227 | * @return An image containing the data of the Tensor. 228 | */ 229 | public static Img imgByte(final Tensor image, int[] dimOrder) { 230 | return reverseReorder(reverse(imgByte(image)), dimOrder); 231 | } 232 | 233 | /** 234 | * Creates an image of type {@link DoubleType} containing the data of a 235 | * TensorFlow Tensor with the data type {@link DataType#DOUBLE}. 236 | * 237 | * @param image The TensorFlow Tensor. 238 | * @param dimOrder Defines the mapping of the dimensions between the image 239 | * and the Tensor where the index corresponds to the dimension 240 | * in the image and the value corresponds to the dimension in the 241 | * Tensor. 242 | * @return An image containing the data of the Tensor. 243 | * @throws IllegalArgumentException if Tensor data type is not double. 244 | */ 245 | public static Img imgDouble(final Tensor image, int[] dimOrder) { 246 | return reverseReorder(reverse(imgDouble(image)), dimOrder); 247 | } 248 | 249 | /** 250 | * Creates an image of type {@link FloatType} containing the data of a 251 | * TensorFlow Tensor with the data type {@link DataType#FLOAT}. 252 | * 253 | * @param image The TensorFlow Tensor. 254 | * @param dimOrder Defines the mapping of the dimensions between the image 255 | * and the Tensor where the index corresponds to the dimension 256 | * in the image and the value corresponds to the dimension in the 257 | * Tensor. 258 | * @return An image containing the data of the Tensor. 259 | * @throws IllegalArgumentException if Tensor data type is not float. 
260 | */ 261 | public static Img imgFloat(final Tensor image, int[] dimOrder) { 262 | return reverseReorder(reverse(imgFloat(image)), dimOrder); 263 | } 264 | 265 | /** 266 | * Creates an image of type {@link IntType} containing the data of a 267 | * TensorFlow Tensor with the data type {@link DataType#INT32}. 268 | * 269 | * @param image The TensorFlow Tensor. 270 | * @param dimOrder Defines the mapping of the dimensions between the image 271 | * and the Tensor where the index corresponds to the dimension 272 | * in the image and the value corresponds to the dimension in the 273 | * Tensor. 274 | * @return An image containing the data of the Tensor. 275 | * @throws IllegalArgumentException if Tensor data type is not int. 276 | */ 277 | public static Img imgInt(final Tensor image, int[] dimOrder) { 278 | return reverseReorder(reverse(imgInt(image)), dimOrder); 279 | } 280 | 281 | /** 282 | * Creates an image of type {@link LongType} containing the data of a 283 | * TensorFlow Tensor with the data type {@link DataType#INT64}. 284 | * 285 | * @param image The TensorFlow Tensor. 286 | * @param dimOrder Defines the mapping of the dimensions between the image 287 | * and the Tensor where the index corresponds to the dimension 288 | * in the image and the value corresponds to the dimension in the 289 | * Tensor. 290 | * @return An image containing the data of the Tensor. 291 | * @throws IllegalArgumentException if Tensor data type is not long. 292 | */ 293 | public static Img imgLong(final Tensor image, int[] dimOrder) { 294 | return reverseReorder(reverse(imgLong(image)), dimOrder); 295 | } 296 | 297 | /** 298 | * Creates an image of type {@link ByteType} containing the data of a 299 | * TensorFlow Tensor with the data type {@link DataType#UINT8}. 300 | *

301 | * Note that this _does_ adjust the dimensions. This means that 302 | * the resulting image will have dimensions directly corresponding to the 303 | * shape of the Tensor. See {@link #imgByte(Tensor)} and 304 | * {@link #imgByte(Tensor, int[])} if you want to handle dimensions 305 | * differently. 306 | *

307 | * Note also that no exception is thrown if the data type is not 308 | * {@link DataType#UINT8} but it will give unexpected results. 309 | *

310 | * @param image The TensorFlow Tensor. 311 | * @return An image containing the data of the Tensor. 312 | */ 313 | public static Img imgByteDirect(final Tensor image) { 314 | return reverse(imgByte(image)); 315 | } 316 | 317 | /** 318 | * Creates an image of type {@link DoubleType} containing the data of a 319 | * TensorFlow Tensor with the data type {@link DataType#DOUBLE}. 320 | *

321 | * Note that this _does_ adjust the dimensions. This means that 322 | * the resulting image will have dimensions directly corresponding to the 323 | * shape of the Tensor. See {@link #imgDouble(Tensor)} and 324 | * {@link #imgDouble(Tensor, int[])} if you want to handle dimensions 325 | * differently. 326 | *

327 | * @param image The TensorFlow Tensor. 328 | * @return An image containing the data of the Tensor. 329 | * @throws IllegalArgumentException if Tensor data type is not double. 330 | */ 331 | public static Img imgDoubleDirect(final Tensor image) { 332 | return reverse(imgDouble(image)); 333 | } 334 | 335 | /** 336 | * Creates an image of type {@link FloatType} containing the data of a 337 | * TensorFlow Tensor with the data type {@link DataType#FLOAT}. 338 | *

339 | * Note that this _does_ adjust the dimensions. This means that 340 | * the resulting image will have dimensions directly corresponding to the 341 | * shape of the Tensor. See {@link #imgFloat(Tensor)} and 342 | * {@link #imgFloat(Tensor, int[])} if you want to handle dimensions 343 | * differently. 344 | *

345 | * @param image The TensorFlow Tensor. 346 | * @return An image containing the data of the Tensor. 347 | * @throws IllegalArgumentException if Tensor data type is not float. 348 | */ 349 | public static Img imgFloatDirect(final Tensor image) { 350 | return reverse(imgFloat(image)); 351 | } 352 | 353 | /** 354 | * Creates an image of type {@link IntType} containing the data of a 355 | * TensorFlow Tensor with the data type {@link DataType#INT32}. 356 | *

357 | * Note that this _does_ adjust the dimensions. This means that 358 | * the resulting image will have dimensions directly corresponding to the 359 | * shape of the Tensor. See {@link #imgInt(Tensor)} and 360 | * {@link #imgInt(Tensor, int[])} if you want to handle dimensions 361 | * differently. 362 | *

363 | * @param image The TensorFlow Tensor. 364 | * @return An image containing the data of the Tensor. 365 | * @throws IllegalArgumentException if Tensor data type is not int. 366 | */ 367 | public static Img imgIntDirect(final Tensor image) { 368 | return reverse(imgInt(image)); 369 | } 370 | 371 | /** 372 | * Creates an image of type {@link LongType} containing the data of a 373 | * TensorFlow Tensor with the data type {@link DataType#INT64}. 374 | *

375 | * Note that this _does_ adjust the dimensions. This means that 376 | * the resulting image will have dimensions directly corresponding to the 377 | * shape of the Tensor. See {@link #imgLong(Tensor)} and 378 | * {@link #imgLong(Tensor, int[])} if you want to handle dimensions 379 | * differently. 380 | *

381 | * @param image The TensorFlow Tensor. 382 | * @return An image containing the data of the Tensor. 383 | * @throws IllegalArgumentException if Tensor data type is not long. 384 | */ 385 | public static Img imgLongDirect(final Tensor image) { 386 | return reverse(imgLong(image)); 387 | } 388 | 389 | // --------- RAI to TENSOR --------- 390 | 391 | /** 392 | * Creates a TensorFlow Tensor containing data from the given image. 393 | *

394 | * Note that this does _not_ adjust any dimensions. This means that 395 | * the resulting Tensor will have a shape corresponding to the reversed 396 | * dimensions of the image. This is probably what you want because 397 | * TensorFlow uses the dimension order CYX while ImageJ uses XYC. See 398 | * {@link #tensorDirect(RandomAccessibleInterval)} and 399 | * {@link #tensor(RandomAccessibleInterval, int[])} if you want to handle 400 | * dimensions differently. 401 | *

402 | * Also note that this will use the backing RAI's primitive array when one is 403 | * available. Otherwise a copy will be made. 404 | *

405 | * @param image The image which should be put into the Tensor. 406 | * @return A Tensor containing the data of the image. 407 | * @throws IllegalArgumentException if the type of the image is not supported. 408 | * Supported types are {@link ByteType}, {@link DoubleType}, 409 | * {@link FloatType}, {@link IntType} and {@link LongType}. 410 | */ 411 | public static > Tensor tensor( 412 | final RandomAccessibleInterval image) 413 | { 414 | // NB: In the functions called we use reversed dimensions because 415 | // tensorflow iterates the array differently than imglib2. 416 | // Example: array: [ 1, 2, 3, 4, 5, 6 ], shape: [ 3, 2 ] 417 | // imglib2 tensorflow 418 | // | 1 2 3 | | 1 3 5 | 419 | // | 4 5 6 | | 2 4 6 | 420 | 421 | final T type = Util.getTypeFromInterval(image); 422 | if (type instanceof ByteType) { 423 | @SuppressWarnings("unchecked") 424 | final RandomAccessibleInterval typedImage = 425 | (RandomAccessibleInterval) image; 426 | return tensorByte(typedImage); 427 | } 428 | if (type instanceof DoubleType) { 429 | @SuppressWarnings("unchecked") 430 | final RandomAccessibleInterval typedImage = 431 | (RandomAccessibleInterval) image; 432 | return tensorDouble(typedImage); 433 | } 434 | if (type instanceof FloatType) { 435 | @SuppressWarnings("unchecked") 436 | final RandomAccessibleInterval typedImage = 437 | (RandomAccessibleInterval) image; 438 | return tensorFloat(typedImage); 439 | } 440 | if (type instanceof IntType) { 441 | @SuppressWarnings("unchecked") 442 | final RandomAccessibleInterval typedImage = 443 | (RandomAccessibleInterval) image; 444 | return tensorInt(typedImage); 445 | } 446 | if (type instanceof LongType) { 447 | @SuppressWarnings("unchecked") 448 | final RandomAccessibleInterval typedImage = 449 | (RandomAccessibleInterval) image; 450 | return tensorLong(typedImage); 451 | } 452 | throw new IllegalArgumentException("Unsupported image type: " + // 453 | type.getClass().getName()); 454 | } 455 | 456 | // "higher level" routines that 
give ways of adjusting dimension order 457 | 458 | /** 459 | * Creates a TensorFlow Tensor containing data from the given image. 460 | *

461 | * Note that this will use the backing RAI's primitive array when one is 462 | * available and no dimensions were swapped. Otherwise a copy will be made. 463 | *

464 | * @param image The image which should be put into the Tensor. 465 | * @param dimOrder Defines the mapping of the dimensions between the image 466 | * and the Tensor where the index corresponds to the dimension 467 | * in the image and the value corresponds to the dimension in the 468 | * Tensor. TODO Example? 469 | * @return A Tensor containing the data of the image. 470 | * @throws IllegalArgumentException if the type of the image is not supported. 471 | * Supported types are {@link ByteType}, {@link DoubleType}, 472 | * {@link FloatType}, {@link IntType} and {@link LongType}. 473 | */ 474 | public static > Tensor tensor( 475 | final RandomAccessibleInterval image, int[] dimOrder) 476 | { 477 | // TODO Are 2 calls bad? More views are created but they should be smart 478 | // about this 479 | return tensor(reverse(reorder(image, dimOrder))); 480 | } 481 | 482 | /** 483 | * Creates a TensorFlow Tensor containing data from the given image. 484 | *

485 | * Note that this _does_ adjust the dimensions. This means that 486 | * the resulting Tensor will have a shape directly corresponding to the 487 | * dimensions of the image. Make sure the dimensions are as you want 488 | * them in TensorFlow. See {@link #tensor(RandomAccessibleInterval)} and 489 | * {@link #tensor(RandomAccessibleInterval, int[])} if you want to handle 490 | * dimensions differently. 491 | *

492 | * Also note that this will use the backing RAI's primitive array when one is 493 | * available. Otherwise a copy will be made. 494 | *

495 | * @param image The image which should be put into the Tensor. 496 | * @return A Tensor containing the data of the image. 497 | * @throws IllegalArgumentException if the type of the image is not supported. 498 | * Supported types are {@link ByteType}, {@link DoubleType}, 499 | * {@link FloatType}, {@link IntType} and {@link LongType}. 500 | */ 501 | public static > Tensor tensorDirect( 502 | final RandomAccessibleInterval image) 503 | { 504 | return tensor(reverse(image)); 505 | } 506 | 507 | // "low level" methods that do NOT adjust dimensions 508 | 509 | /** 510 | * Creates a TensorFlow Tensor containing data from the given byte image. 511 | *

512 | * Note that this does _not_ adjust any dimensions. This means that 513 | * the resulting Tensor will have a shape corresponding to the reversed 514 | * dimensions of the image. 515 | *

516 | * Also note that this will use the backing RAI's primitive array when one is 517 | * available. Otherwise a copy will be made. 518 | *

519 | * @param image The image which should be put into the Tensor. 520 | * @return A Tensor containing the data of the image. 521 | */ 522 | public static Tensor tensorByte( 523 | final RandomAccessibleInterval image) 524 | { 525 | final byte[] value = byteArray(image); 526 | ByteBuffer buffer = ByteBuffer.wrap(value); 527 | return Tensor.create(UInt8.class, shape(image), buffer); 528 | } 529 | 530 | /** 531 | * Creates a TensorFlow Tensor containing data from the given double image. 532 | *

533 | * Note that this does _not_ adjust any dimensions. This means that 534 | * the resulting Tensor will have a shape corresponding to the reversed 535 | * dimensions of the image. 536 | *

537 | * Also note that this will use the backing RAI's primitive array when one is 538 | * available. Otherwise a copy will be made. 539 | *

540 | * @param image The image which should be put into the Tensor. 541 | * @return A Tensor containing the data of the image. 542 | */ 543 | public static Tensor tensorDouble( 544 | final RandomAccessibleInterval image) 545 | { 546 | final double[] value = doubleArray(image); 547 | DoubleBuffer buffer = DoubleBuffer.wrap(value); 548 | return Tensor.create(shape(image), buffer); 549 | } 550 | 551 | /** 552 | * Creates a TensorFlow Tensor containing data from the given float image. 553 | *

554 | * Note that this does _not_ adjust any dimensions. This means that 555 | * the resulting Tensor will have a shape corresponding to the reversed 556 | * dimensions of the image. 557 | *

558 | * Also note that this will use the backing RAI's primitive array when one is 559 | * available. Otherwise a copy will be made. 560 | *

561 | * @param image The image which should be put into the Tensor. 562 | * @return A Tensor containing the data of the image. 563 | */ 564 | public static Tensor tensorFloat( 565 | final RandomAccessibleInterval image) 566 | { 567 | final float[] value = floatArray(image); 568 | FloatBuffer buffer = FloatBuffer.wrap(value); 569 | return Tensor.create(shape(image), buffer); 570 | } 571 | 572 | /** 573 | * Creates a TensorFlow Tensor containing data from the given int image. 574 | *

575 | * Note that this does _not_ adjust any dimensions. This means that 576 | * the resulting Tensor will have a shape corresponding to the reversed 577 | * dimensions of the image. 578 | *

579 | * Also note that this will use the backing RAI's primitive array when one is 580 | * available. Otherwise a copy will be made. 581 | *

582 | * @param image The image which should be put into the Tensor. 583 | * @return A Tensor containing the data of the image. 584 | */ 585 | public static Tensor tensorInt( 586 | final RandomAccessibleInterval image) 587 | { 588 | final int[] value = intArray(image); 589 | IntBuffer buffer = IntBuffer.wrap(value); 590 | return Tensor.create(shape(image), buffer); 591 | } 592 | 593 | /** 594 | * Creates a TensorFlow Tensor containing data from the given long image. 595 | *

596 | * Note that this does _not_ adjust any dimensions. This means that 597 | * the resulting Tensor will have a shape corresponding to the reversed 598 | * dimensions of the image. 599 | *

600 | * Also note that this will use the backing RAI's primitive array when one is 601 | * available. Otherwise a copy will be made. 602 | *

603 | * @param image The image which should be put into the Tensor. 604 | * @return A Tensor containing the data of the image. 605 | */ 606 | public static Tensor tensorLong( 607 | final RandomAccessibleInterval image) 608 | { 609 | final long[] value = longArray(image); 610 | LongBuffer buffer = LongBuffer.wrap(value); 611 | return Tensor.create(shape(image), buffer); 612 | } 613 | 614 | /** 615 | * Creates a TensorFlow Tensor containing data from the given byte image. 616 | *

617 | * Note that this will use the backing RAI's primitive array when one is 618 | * available and no dimensions were swapped. Otherwise a copy will be made. 619 | *

620 | * @param image The image which should be put into the Tensor. 621 | * @param dimOrder Defines the mapping of the dimensions between the image 622 | * and the Tensor where the index corresponds to the dimension 623 | * in the image and the value corresponds to the dimension in the 624 | * Tensor. TODO Example? 625 | * @return A Tensor containing the data of the image. 626 | */ 627 | public static Tensor tensorByte( 628 | final RandomAccessibleInterval image, final int[] dimOrder) 629 | { 630 | return tensorByte(reverse(reorder(image, dimOrder))); 631 | } 632 | 633 | /** 634 | * Creates a TensorFlow Tensor containing data from the given double image. 635 | *

636 | * Note that this will use the backing RAI's primitive array when one is 637 | * available and no dimensions were swapped. Otherwise a copy will be made. 638 | *

639 | * @param image The image which should be put into the Tensor. 640 | * @param dimOrder Defines the mapping of the dimensions between the image 641 | * and the Tensor where the index corresponds to the dimension 642 | * in the image and the value corresponds to the dimension in the 643 | * Tensor. TODO Example? 644 | * @return A Tensor containing the data of the image. 645 | */ 646 | public static Tensor tensorDouble( 647 | final RandomAccessibleInterval image, final int[] dimOrder) 648 | { 649 | return tensorDouble(reverse(reorder(image, dimOrder))); 650 | } 651 | 652 | /** 653 | * Creates a TensorFlow Tensor containing data from the given float image. 654 | *

655 | * Note that this will use the backing RAI's primitive array when one is 656 | * available and no dimensions were swapped. Otherwise a copy will be made. 657 | *

658 | * @param image The image which should be put into the Tensor. 659 | * @param dimOrder Defines the mapping of the dimensions between the image 660 | * and the Tensor where the index corresponds to the dimension 661 | * in the image and the value corresponds to the dimension in the 662 | * Tensor. TODO Example? 663 | * @return A Tensor containing the data of the image. 664 | */ 665 | public static Tensor tensorFloat( 666 | final RandomAccessibleInterval image, final int[] dimOrder) 667 | { 668 | return tensorFloat(reverse(reorder(image, dimOrder))); 669 | } 670 | 671 | /** 672 | * Creates a TensorFlow Tensor containing data from the given int image. 673 | *

674 | * Note that this will use the backing RAI's primitive array when one is 675 | * available and no dimensions were swapped. Otherwise a copy will be made. 676 | *

677 | * @param image The image which should be put into the Tensor. 678 | * @param dimOrder Defines the mapping of the dimensions between the image 679 | * and the Tensor where the index corresponds to the dimension 680 | * in the image and the value corresponds to the dimension in the 681 | * Tensor. TODO Example? 682 | * @return A Tensor containing the data of the image. 683 | */ 684 | public static Tensor tensorInt( 685 | final RandomAccessibleInterval image, final int[] dimOrder) 686 | { 687 | return tensorInt(reverse(reorder(image, dimOrder))); 688 | } 689 | 690 | /** 691 | * Creates a TensorFlow Tensor containing data from the given long image. 692 | *

693 | * Note that this will use the backing RAI's primitive array when one is 694 | * available and no dimensions were swapped. Otherwise a copy will be made. 695 | *

696 | * @param image The image which should be put into the Tensor. 697 | * @param dimOrder Defines the mapping of the dimensions between the image 698 | * and the Tensor where the index corresponds to the dimension 699 | * in the image and the value corresponds to the dimension in the 700 | * Tensor. TODO Example? 701 | * @return A Tensor containing the data of the image. 702 | */ 703 | public static Tensor tensorLong( 704 | final RandomAccessibleInterval image, final int[] dimOrder) 705 | { 706 | return tensorLong(reverse(reorder(image, dimOrder))); 707 | } 708 | 709 | /** 710 | * Creates a TensorFlow Tensor containing data from the given byte image. 711 | *

712 | * Note that this _does_ adjust the dimensions. This means that 713 | * the resulting Tensor will have a shape directly corresponding to the 714 | * dimensions of the image. Make sure the dimensions are as you want 715 | * them in TensorFlow. See {@link #tensorByte(RandomAccessibleInterval)} and 716 | * {@link #tensorByte(RandomAccessibleInterval, int[])} if you want to handle 717 | * dimensions differently. 718 | *

719 | * Also note that this will use the backing RAI's primitive array when one is 720 | * available. Otherwise a copy will be made. 721 | *

722 | * @param image The image which should be put into the Tensor. 723 | * @return A Tensor containing the data of the image. 724 | */ 725 | public static Tensor tensorByteDirect( 726 | final RandomAccessibleInterval image) 727 | { 728 | return tensorByte(reverse(image)); 729 | } 730 | 731 | /** 732 | * Creates a TensorFlow Tensor containing data from the given double image. 733 | *

734 | * Note that this _does_ adjust the dimensions. This means that 735 | * the resulting Tensor will have a shape directly corresponding to the 736 | * dimensions of the image. Make sure the dimensions are as you want 737 | * them in TensorFlow. See {@link #tensorDouble(RandomAccessibleInterval)} and 738 | * {@link #tensorDouble(RandomAccessibleInterval, int[])} if you want to handle 739 | * dimensions differently. 740 | *

741 | * Also note that this will use the backing RAI's primitive array when one is 742 | * available. Otherwise a copy will be made. 743 | *

744 | * @param image The image which should be put into the Tensor. 745 | * @return A Tensor containing the data of the image. 746 | */ 747 | public static Tensor tensorDoubleDirect( 748 | final RandomAccessibleInterval image) 749 | { 750 | return tensorDouble(reverse(image)); 751 | } 752 | 753 | /** 754 | * Creates a TensorFlow Tensor containing data from the given float image. 755 | *

756 | * Note that this _does_ adjust the dimensions. This means that 757 | * the resulting Tensor will have a shape directly corresponding to the 758 | * dimensions of the image. Make sure the dimensions are as you want 759 | * them in TensorFlow. See {@link #tensorFloat(RandomAccessibleInterval)} and 760 | * {@link #tensorFloat(RandomAccessibleInterval, int[])} if you want to handle 761 | * dimensions differently. 762 | *

763 | * Also note that this will use the backing RAI's primitive array when one is 764 | * available. Otherwise a copy will be made. 765 | *

766 | * @param image The image which should be put into the Tensor. 767 | * @return A Tensor containing the data of the image. 768 | */ 769 | public static Tensor tensorFloatDirect( 770 | final RandomAccessibleInterval image) 771 | { 772 | return tensorFloat(reverse(image)); 773 | } 774 | 775 | /** 776 | * Creates a TensorFlow Tensor containing data from the given int image. 777 | *

778 | * Note that this _does_ adjust the dimensions. This means that 779 | * the resulting Tensor will have a shape directly corresponding to the 780 | * dimensions of the image. Make sure the dimensions are as you want 781 | * them in TensorFlow. See {@link #tensorInt(RandomAccessibleInterval)} and 782 | * {@link #tensorInt(RandomAccessibleInterval, int[])} if you want to handle 783 | * dimensions differently. 784 | *

785 | * Also note that this will use the backing RAI's primitive array when one is 786 | * available. Otherwise a copy will be made. 787 | *

788 | * @param image The image which should be put into the Tensor. 789 | * @return A Tensor containing the data of the image. 790 | */ 791 | public static Tensor tensorIntDirect( 792 | final RandomAccessibleInterval image) 793 | { 794 | return tensorInt(reverse(image)); 795 | } 796 | 797 | /** 798 | * Creates a TensorFlow Tensor containing data from the given long image. 799 | *

800 | * Note that this _does_ adjust the dimensions. This means that 801 | * the resulting Tensor will have a shape directly corresponding to the 802 | * dimensions of the image. Make sure the dimensions are as you want 803 | * them in TensorFlow. See {@link #tensorLong(RandomAccessibleInterval)} and 804 | * {@link #tensorLong(RandomAccessibleInterval, int[])} if you want to handle 805 | * dimensions differently. 806 | *

807 | * Also note that this will use the backing RAI's primitive array when one is 808 | * available. Otherwise a copy will be made. 809 | *

810 | * @param image The image which should be put into the Tensor. 811 | * @return A Tensor containing the data of the image. 812 | */ 813 | public static Tensor tensorLongDirect( 814 | final RandomAccessibleInterval image) 815 | { 816 | return tensorLong(reverse(image)); 817 | } 818 | 819 | // --------- DIMENSIONAL HELPER METHODS --------- 820 | 821 | /** 822 | * Gets the TensorFlow shape of an image. This is the same as the image's 823 | * dimension lengths, but in reversed order. 824 | * 825 | * @param image The image whose shape is desired. 826 | * @return The TensorFlow shape. 827 | */ 828 | private static long[] shape(final Dimensions image) { 829 | long[] shape = new long[image.numDimensions()]; 830 | for (int d = 0; d < shape.length; d++) { 831 | shape[d] = image.dimension(shape.length - d - 1); 832 | } 833 | return shape; 834 | } 835 | 836 | /** 837 | * Gets the imglib dimension length of a tensor. This is the same as the 838 | * tensor's shape but in reversed order. 839 | * 840 | * @param tensor The tensor whose dimension length are desired. 841 | * @return The imglib dimension length. 842 | */ 843 | private static long[] shape(final Tensor tensor) { 844 | return shape(new FinalDimensions(tensor.shape())); 845 | } 846 | 847 | /** Flips all dimensions {@code d0,d1,...,dn -> dn,...,d1,d0}. */ 848 | public static > Img reverse(Img image) 849 | { 850 | RandomAccessibleInterval reversed = reverse((RandomAccessibleInterval) image); 851 | return ImgView.wrap(reversed, image.factory()); 852 | } 853 | 854 | /** Flips all dimensions {@code d0,d1,...,dn -> dn,...,d1,d0}. 
*/ 855 | public static > RandomAccessibleInterval reverse( 856 | RandomAccessibleInterval image) 857 | { 858 | RandomAccessibleInterval reversed = image; 859 | for (int d = 0; d < image.numDimensions() / 2; d++) { 860 | reversed = Views.permute(reversed, d, image.numDimensions() - d - 1); 861 | } 862 | return reversed; 863 | } 864 | 865 | private static > Img reorder( 866 | Img image, int[] dimOrder) 867 | { 868 | RandomAccessibleInterval result = reorder((RandomAccessibleInterval) image, dimOrder); 869 | return ImgView.wrap(result, image.factory()); 870 | } 871 | 872 | private static > RandomAccessibleInterval reorder( 873 | RandomAccessibleInterval image, int[] dimOrder) 874 | { 875 | RandomAccessibleInterval output = image; 876 | 877 | // Array which contains for each dimension information on which dimension it is right now 878 | int[] moved = IntStream.range(0, image.numDimensions()).toArray(); 879 | 880 | // Loop over all dimensions and move it to the right spot 881 | for (int i = 0; i < image.numDimensions(); i++) { 882 | int from = moved[i]; 883 | int to = dimOrder[i]; 884 | 885 | // Move the dimension to the right dimension 886 | output = Views.permute(output, from, to); 887 | 888 | // Now we have to update which dimension was moved where 889 | moved[i] = to; 890 | moved = Arrays.stream(moved).map(v -> v == to ? from : v).toArray(); 891 | } 892 | return output; 893 | } 894 | 895 | private static > Img reverseReorder(Img image, int[] dimOrder) { 896 | int[] reverseDimOrder = new int[dimOrder.length]; 897 | for (int i = 0; i < dimOrder.length; i++) { 898 | reverseDimOrder[dimOrder[i]] = i; 899 | } 900 | return reorder(image, reverseDimOrder); 901 | } 902 | 903 | // --------- HELPER STUFF --------- 904 | 905 | // TODO: consider also putting this outside Tensors, and instead in ImgLib2 core. 906 | // _Maybe_ mix with extractFloatArray? Probably better to keep separate layer though. 
907 | // This is really "extract or create float array" 908 | // So maybe we want a "createFloatArray" which _always_ does the copy. 909 | // And the a nice "extractOrCreate" guy that calls extract followed by create 910 | // as is done here. 911 | 912 | private static byte[] byteArray( 913 | final RandomAccessibleInterval image) 914 | { 915 | final byte[] array = extractByteArray(image); 916 | return array == null ? createByteArray(image) : array; 917 | } 918 | 919 | private static double[] doubleArray( 920 | final RandomAccessibleInterval image) 921 | { 922 | final double[] array = extractDoubleArray(image); 923 | return array == null ? createDoubleArray(image) : array; 924 | } 925 | 926 | private static float[] floatArray( 927 | final RandomAccessibleInterval image) 928 | { 929 | final float[] array = extractFloatArray(image); 930 | return array == null ? createFloatArray(image) : array; 931 | } 932 | 933 | private static int[] intArray( 934 | final RandomAccessibleInterval image) 935 | { 936 | final int[] array = extractIntArray(image); 937 | return array == null ? createIntArray(image) : array; 938 | } 939 | 940 | private static long[] longArray( 941 | final RandomAccessibleInterval image) 942 | { 943 | final long[] array = extractLongArray(image); 944 | return array == null ? 
createLongArray(image) : array; 945 | } 946 | 947 | private static byte[] createByteArray( 948 | final RandomAccessibleInterval image) 949 | { 950 | final long[] dims = Intervals.dimensionsAsLongArray(image); 951 | final ArrayImg dest = ArrayImgs.bytes(dims); 952 | copy(image, dest); 953 | return dest.update(null).getCurrentStorageArray(); 954 | } 955 | 956 | private static double[] createDoubleArray( 957 | final RandomAccessibleInterval image) 958 | { 959 | final long[] dims = Intervals.dimensionsAsLongArray(image); 960 | final ArrayImg dest = ArrayImgs.doubles(dims); 961 | copy(image, dest); 962 | return dest.update(null).getCurrentStorageArray(); 963 | } 964 | 965 | private static float[] createFloatArray( 966 | final RandomAccessibleInterval image) 967 | { 968 | final long[] dims = Intervals.dimensionsAsLongArray(image); 969 | final ArrayImg dest = ArrayImgs.floats(dims); 970 | copy(image, dest); 971 | return dest.update(null).getCurrentStorageArray(); 972 | } 973 | 974 | private static int[] createIntArray( 975 | final RandomAccessibleInterval image) 976 | { 977 | final long[] dims = Intervals.dimensionsAsLongArray(image); 978 | final ArrayImg dest = ArrayImgs.ints(dims); 979 | copy(image, dest); 980 | return dest.update(null).getCurrentStorageArray(); 981 | } 982 | 983 | private static long[] createLongArray( 984 | final RandomAccessibleInterval image) 985 | { 986 | final long[] dims = Intervals.dimensionsAsLongArray(image); 987 | final ArrayImg dest = ArrayImgs.longs(dims); 988 | copy(image, dest); 989 | return dest.update(null).getCurrentStorageArray(); 990 | } 991 | 992 | private static byte[] extractByteArray( 993 | final RandomAccessibleInterval image) 994 | { 995 | if (!(image instanceof ArrayImg)) return null; 996 | @SuppressWarnings("unchecked") 997 | final ArrayImg arrayImg = (ArrayImg) image; 998 | final Object dataAccess = arrayImg.update(null); 999 | return dataAccess instanceof ByteArray ? 
// 1000 | ((ByteArray) dataAccess).getCurrentStorageArray() : null; 1001 | } 1002 | 1003 | private static double[] extractDoubleArray( 1004 | final RandomAccessibleInterval image) 1005 | { 1006 | if (!(image instanceof ArrayImg)) return null; 1007 | @SuppressWarnings("unchecked") 1008 | final ArrayImg arrayImg = (ArrayImg) image; 1009 | final Object dataAccess = arrayImg.update(null); 1010 | return dataAccess instanceof DoubleArray ? // 1011 | ((DoubleArray) dataAccess).getCurrentStorageArray() : null; 1012 | } 1013 | 1014 | private static float[] extractFloatArray( 1015 | final RandomAccessibleInterval image) 1016 | { 1017 | if (!(image instanceof ArrayImg)) return null; 1018 | @SuppressWarnings("unchecked") 1019 | final ArrayImg arrayImg = (ArrayImg) image; 1020 | // GOOD NEWS: float[] rasterization order is dimension-wise! 1021 | // BAD NEWS: it always goes d0,d1,d2,.... is that the order we need? 1022 | // MORE BAD NEWS: As soon as you use Views.permute, image is not ArrayImg anymore. 1023 | // SO: This only works if you give a RAI that happens to be laid out directly as TensorFlow desires. 1024 | final Object dataAccess = arrayImg.update(null); 1025 | return dataAccess instanceof FloatArray ? // 1026 | ((FloatArray) dataAccess).getCurrentStorageArray() : null; 1027 | } 1028 | 1029 | private static int[] extractIntArray( 1030 | final RandomAccessibleInterval image) 1031 | { 1032 | if (!(image instanceof ArrayImg)) return null; 1033 | @SuppressWarnings("unchecked") 1034 | final ArrayImg arrayImg = (ArrayImg) image; 1035 | final Object dataAccess = arrayImg.update(null); 1036 | return dataAccess instanceof IntArray ? 
// 1037 | ((IntArray) dataAccess).getCurrentStorageArray() : null; 1038 | } 1039 | 1040 | private static long[] extractLongArray( 1041 | final RandomAccessibleInterval image) 1042 | { 1043 | if (!(image instanceof ArrayImg)) return null; 1044 | @SuppressWarnings("unchecked") 1045 | final ArrayImg arrayImg = (ArrayImg) image; 1046 | final Object dataAccess = arrayImg.update(null); 1047 | return dataAccess instanceof LongArray ? // 1048 | ((LongArray) dataAccess).getCurrentStorageArray() : null; 1049 | } 1050 | 1051 | private static > void copy( 1052 | final RandomAccessibleInterval source, 1053 | final IterableInterval dest) 1054 | { 1055 | final RandomAccess sourceAccess = source.randomAccess(); 1056 | final Cursor destCursor = dest.localizingCursor(); 1057 | while (destCursor.hasNext()) { 1058 | destCursor.fwd(); 1059 | sourceAccess.setPosition(destCursor); 1060 | destCursor.get().set(sourceAccess.get()); 1061 | } 1062 | } 1063 | } 1064 | -------------------------------------------------------------------------------- /src/main/java/net/imagej/tensorflow/demo/LabelImage.java: -------------------------------------------------------------------------------- 1 | /*- 2 | * #%L 3 | * ImageJ/TensorFlow integration. 4 | * %% 5 | * Copyright (C) 2017 - 2025 Board of Regents of the University of 6 | * Wisconsin-Madison and Google, Inc. 7 | * %% 8 | * Redistribution and use in source and binary forms, with or without 9 | * modification, are permitted provided that the following conditions are met: 10 | * 11 | * 1. Redistributions of source code must retain the above copyright notice, 12 | * this list of conditions and the following disclaimer. 13 | * 2. Redistributions in binary form must reproduce the above copyright notice, 14 | * this list of conditions and the following disclaimer in the documentation 15 | * and/or other materials provided with the distribution. 
16 | * 17 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 18 | * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 19 | * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 20 | * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDERS OR CONTRIBUTORS BE 21 | * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 22 | * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 23 | * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 24 | * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 25 | * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 26 | * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 27 | * POSSIBILITY OF SUCH DAMAGE. 28 | * #L% 29 | */ 30 | 31 | package net.imagej.tensorflow.demo; 32 | 33 | import java.io.IOException; 34 | import java.util.Arrays; 35 | import java.util.Comparator; 36 | import java.util.List; 37 | import java.util.stream.IntStream; 38 | 39 | import net.imagej.Dataset; 40 | import net.imagej.ImageJ; 41 | import net.imagej.tensorflow.GraphBuilder; 42 | import net.imagej.tensorflow.TensorFlowService; 43 | import net.imagej.tensorflow.Tensors; 44 | import net.imglib2.RandomAccessibleInterval; 45 | import net.imglib2.converter.Converters; 46 | import net.imglib2.converter.RealFloatConverter; 47 | import net.imglib2.img.Img; 48 | import net.imglib2.type.numeric.RealType; 49 | import net.imglib2.type.numeric.real.FloatType; 50 | 51 | import org.scijava.ItemIO; 52 | import org.scijava.command.Command; 53 | import org.scijava.io.http.HTTPLocation; 54 | import org.scijava.log.LogService; 55 | import org.scijava.plugin.Parameter; 56 | import org.scijava.plugin.Plugin; 57 | import org.tensorflow.Graph; 58 | import org.tensorflow.Output; 59 | import org.tensorflow.Session; 60 | import org.tensorflow.Tensor; 61 | 62 | /** 63 | * Command to use an 
Inception-image recognition model to label an image. 64 | *

65 | * See the 66 | * TensorFlow 67 | * image recognition tutorial and its 69 | * Java implementation. 70 | */ 71 | @Plugin(type = Command.class, menuPath = "Plugins>Sandbox>TensorFlow Demos>Label Image", 72 | headless = true) 73 | public class LabelImage implements Command { 74 | 75 | private static final String INCEPTION_MODEL_URL = 76 | "https://storage.googleapis.com/download.tensorflow.org/models/inception5h.zip"; 77 | 78 | private static final String MODEL_NAME = "inception5h"; 79 | 80 | @Parameter 81 | private TensorFlowService tensorFlowService; 82 | 83 | @Parameter 84 | private LogService log; 85 | 86 | @Parameter 87 | private Dataset inputImage; 88 | 89 | @Parameter(label = "Min probability (%)", min = "0", max = "100") 90 | private double minPercent = 1; 91 | 92 | @Parameter(type = ItemIO.OUTPUT) 93 | private Img outputImage; 94 | 95 | @Parameter(type = ItemIO.OUTPUT) 96 | private String outputLabels; 97 | 98 | @Override 99 | public void run() { 100 | 101 | tensorFlowService.loadLibrary(); 102 | if(!tensorFlowService.getStatus().isLoaded()) return; 103 | log.info("Version of TensorFlow: " + tensorFlowService.getTensorFlowVersion()); 104 | 105 | try { 106 | final HTTPLocation source = new HTTPLocation(INCEPTION_MODEL_URL); 107 | final Graph graph = tensorFlowService.loadGraph(source, MODEL_NAME, 108 | "tensorflow_inception_graph.pb"); 109 | final List labels = tensorFlowService.loadLabels(source, 110 | MODEL_NAME, "imagenet_comp_graph_label_strings.txt"); 111 | log.info("Loaded graph and " + labels.size() + " labels"); 112 | 113 | try ( 114 | final Tensor inputTensor = loadFromImgLib(inputImage); 115 | final Tensor image = normalizeImage(inputTensor) 116 | ) 117 | { 118 | outputImage = Tensors.imgFloat(image, new int[]{ 2, 1, 3, 0 }); 119 | final float[] labelProbabilities = executeInceptionGraph(graph, image); 120 | 121 | // Sort labels by probability. 
122 | final int labelCount = Math.min(labelProbabilities.length, labels 123 | .size()); 124 | final Integer[] labelIndices = IntStream.range(0, labelCount).boxed() 125 | .toArray(Integer[]::new); 126 | Arrays.sort(labelIndices, 0, labelCount, new Comparator() { 127 | 128 | @Override 129 | public int compare(final Integer i1, final Integer i2) { 130 | final float p1 = labelProbabilities[i1]; 131 | final float p2 = labelProbabilities[i2]; 132 | return p1 == p2 ? 0 : p1 > p2 ? -1 : 1; 133 | } 134 | }); 135 | 136 | // Output labels above the probability threshold. 137 | final double cutoff = minPercent / 100; // % -> probability 138 | final StringBuilder sb = new StringBuilder(); 139 | for (int i = 0; i < labelCount; i++) { 140 | final int index = labelIndices[i]; 141 | final double p = labelProbabilities[index]; 142 | if (p < cutoff) break; 143 | 144 | sb.append(String.format("%s (%.2f%% likely)\n", labels.get(index), 145 | labelProbabilities[index] * 100f)); 146 | } 147 | outputLabels = sb.toString(); 148 | } 149 | } 150 | catch (final Exception exc) { 151 | log.error(exc); 152 | } 153 | } 154 | 155 | @SuppressWarnings({ "rawtypes", "unchecked" }) 156 | private static Tensor loadFromImgLib(final Dataset d) { 157 | return loadFromImgLib((RandomAccessibleInterval) d.getImgPlus()); 158 | } 159 | 160 | private static > Tensor loadFromImgLib( 161 | final RandomAccessibleInterval image) 162 | { 163 | // NB: Assumes XYC ordering. TensorFlow wants YXC. 
164 | RealFloatConverter converter = new RealFloatConverter<>(); 165 | return Tensors.tensorFloat(Converters.convert(image, converter, new FloatType()), new int[]{ 1, 0, 2 }); 166 | } 167 | 168 | // ----------------------------------------------------------------------------------------------------------------- 169 | // All the code below was essentially copied verbatim from: 170 | // https://github.com/tensorflow/tensorflow/blob/e8f2aad0c0502fde74fc629f5b13f04d5d206700/tensorflow/java/src/main/java/org/tensorflow/examples/LabelImage.java 171 | // ----------------------------------------------------------------------------------------------------------------- 172 | private static Tensor normalizeImage(final Tensor t) { 173 | try (Graph g = new Graph()) { 174 | final GraphBuilder b = new GraphBuilder(g); 175 | // Some constants specific to the pre-trained model at: 176 | // https://storage.googleapis.com/download.tensorflow.org/models/inception5h.zip 177 | // 178 | // - The model was trained with images scaled to 224x224 pixels. 179 | // - The colors, represented as R, G, B in 1-byte each were converted to 180 | // float using (value - Mean)/Scale. 181 | final int H = 224; 182 | final int W = 224; 183 | final float mean = 117f; 184 | final float scale = 1f; 185 | 186 | // Since the graph is being constructed once per execution here, we can 187 | // use a constant for the input image. If the graph were to be re-used for 188 | // multiple input images, a placeholder would have been more appropriate. 
189 | final Output input = g.opBuilder("Const", "input")// 190 | .setAttr("dtype", t.dataType())// 191 | .setAttr("value", t).build().output(0); 192 | final Output output = b.div(b.sub(b.resizeBilinear(b.expandDims(// 193 | input, // 194 | b.constant("make_batch", 0)), // 195 | b.constant("size", new int[] { H, W })), // 196 | b.constant("mean", mean)), // 197 | b.constant("scale", scale)); 198 | try (Session s = new Session(g)) { 199 | @SuppressWarnings("unchecked") 200 | Tensor result = (Tensor) s.runner().fetch(output.op().name()).run().get(0); 201 | return result; 202 | } 203 | } 204 | } 205 | 206 | private static float[] executeInceptionGraph(final Graph g, 207 | final Tensor image) 208 | { 209 | try ( 210 | final Session s = new Session(g); 211 | @SuppressWarnings("unchecked") 212 | final Tensor result = (Tensor) s.runner().feed("input", image)// 213 | .fetch("output").run().get(0) 214 | ) 215 | { 216 | final long[] rshape = result.shape(); 217 | if (result.numDimensions() != 2 || rshape[0] != 1) { 218 | throw new RuntimeException(String.format( 219 | "Expected model to produce a [1 N] shaped tensor where N is " + 220 | "the number of labels, instead it produced one with shape %s", 221 | Arrays.toString(rshape))); 222 | } 223 | final int nlabels = (int) rshape[1]; 224 | return result.copyTo(new float[1][nlabels])[0]; 225 | } 226 | } 227 | 228 | public static void main(String[] args) throws IOException { 229 | final ImageJ ij = new ImageJ(); 230 | ij.launch(args); 231 | 232 | // Open an image and display it. 233 | final String imagePath = "http://samples.fiji.sc/new-lenna.jpg"; 234 | final Object dataset = ij.io().open(imagePath); 235 | ij.ui().show(dataset); 236 | 237 | // Launch the "Label Images" command with some sensible defaults. 
238 | ij.command().run(LabelImage.class, true, // 239 | "inputImage", dataset); 240 | } 241 | } 242 | -------------------------------------------------------------------------------- /src/main/java/net/imagej/tensorflow/ui/AvailableTensorFlowVersions.java: -------------------------------------------------------------------------------- 1 | /*- 2 | * #%L 3 | * ImageJ/TensorFlow integration. 4 | * %% 5 | * Copyright (C) 2017 - 2025 Board of Regents of the University of 6 | * Wisconsin-Madison and Google, Inc. 7 | * %% 8 | * Redistribution and use in source and binary forms, with or without 9 | * modification, are permitted provided that the following conditions are met: 10 | * 11 | * 1. Redistributions of source code must retain the above copyright notice, 12 | * this list of conditions and the following disclaimer. 13 | * 2. Redistributions in binary form must reproduce the above copyright notice, 14 | * this list of conditions and the following disclaimer in the documentation 15 | * and/or other materials provided with the distribution. 16 | * 17 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 18 | * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 19 | * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 20 | * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDERS OR CONTRIBUTORS BE 21 | * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 22 | * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 23 | * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 24 | * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 25 | * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 26 | * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 27 | * POSSIBILITY OF SUCH DAMAGE. 
28 | * #L% 29 | */

package net.imagej.tensorflow.ui;

import java.net.MalformedURLException;
import java.net.URL;
import java.util.ArrayList;
import java.util.List;

/**
 * Utility class providing a hardcoded list of available native TensorFlow library versions which can be downloaded from the google servers.
 *
 * @author Deborah Schmidt
 * @author Benjamin Wilhelm
 */
public class AvailableTensorFlowVersions {

	/** Common prefix of every archive URL; mode, OS tag, version and extension follow. */
	private static final String URL_PREFIX =
		"https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow_jni-";

	/**
	 * linux64 releases as { TensorFlow version, CUDA version, CuDNN version }.
	 * Each release is listed with a CPU build followed by a GPU build.
	 */
	private static final String[][] LINUX_RELEASES = {
		{ "1.2.0", "8.0", "5.1" },
		{ "1.3.0", "8.0", "6" },
		{ "1.4.1", "8.0", "6" },
		{ "1.5.0", "9.0", "7" },
		{ "1.6.0", "9.0", "7" },
		{ "1.7.0", "9.0", "7" },
		{ "1.8.0", "9.0", "7" },
		{ "1.9.0", "9.0", "7" },
		{ "1.10.1", "9.0", "7.?" },
		{ "1.11.0", "9.0", ">= 7.2" },
		{ "1.12.0", "9.0", ">= 7.2" },
		{ "1.13.1", "10.0", "7.4" },
		{ "1.14.0", "10.0", ">= 7.4.1" },
		{ "1.15.0", "10.0", ">= 7.4.1" }
	};

	/** win64 releases that are listed with a CPU build only. */
	private static final String[] WINDOWS_CPU_ONLY_RELEASES = {
		"1.2.0", "1.3.0", "1.4.1", "1.5.0", "1.6.0", "1.7.0", "1.8.0", "1.9.0",
		"1.10.0", "1.11.0"
	};

	/**
	 * win64 releases as { TensorFlow version, CUDA version, CuDNN version }.
	 * Each release is listed with a GPU build followed by a CPU build.
	 */
	private static final String[][] WINDOWS_GPU_RELEASES = {
		{ "1.12.0", "9.0", ">= 7.2" },
		{ "1.13.1", "10.0", "7.4" },
		{ "1.14.0", "10.0", ">= 7.4.1" },
		{ "1.15.0", "10.0", ">= 7.4.1" }
	};

	/** macosx releases; all are listed with a CPU build only. */
	private static final String[] MACOSX_RELEASES = {
		"1.2.0", "1.3.0", "1.4.1", "1.5.0", "1.6.0", "1.7.0", "1.8.0", "1.9.0",
		"1.10.1", "1.11.0", "1.12.0", "1.13.1", "1.14.0", "1.15.0"
	};

	/**
	 * @return a list of downloadable native TensorFlow versions for all operating systems, including their URL and optionally their CUDA / CuDNN compatibility
	 */
	public static List<DownloadableTensorFlowVersion> get() {
		final List<DownloadableTensorFlowVersion> versions = new ArrayList<>();
		//linux64
		for (final String[] release : LINUX_RELEASES) {
			final String tf = release[0];
			versions.add(version("linux64", tf, "CPU", url("cpu", "linux", tf, ".tar.gz")));
			versions.add(version("linux64", tf, "GPU", release[1], release[2], url("gpu", "linux", tf, ".tar.gz")));
		}
		//win64
		for (final String tf : WINDOWS_CPU_ONLY_RELEASES) {
			versions.add(version("win64", tf, "CPU", url("cpu", "windows", tf, ".zip")));
		}
		for (final String[] release : WINDOWS_GPU_RELEASES) {
			final String tf = release[0];
			// NB: the GPU entry precedes the CPU entry here; the list order is kept.
			versions.add(version("win64", tf, "GPU", release[1], release[2], url("gpu", "windows", tf, ".zip")));
			versions.add(version("win64", tf, "CPU", url("cpu", "windows", tf, ".zip")));
		}
		//macosx
		for (final String tf : MACOSX_RELEASES) {
			versions.add(version("macosx", tf, "CPU", url("cpu", "darwin", tf, ".tar.gz")));
		}
		return versions;
	}

	/**
	 * Assembles the download address of one archive.
	 *
	 * @param mode {@code "cpu"} or {@code "gpu"}, as spelled in the file name
	 * @param os the operating system tag used in the file name
	 * @param tensorFlowVersion the TensorFlow version number
	 * @param extension the archive extension including the leading dot
	 */
	private static String url(String mode, String os, String tensorFlowVersion, String extension) {
		return URL_PREFIX + mode + "-" + os + "-x86_64-" + tensorFlowVersion + extension;
	}

	/**
	 * Parses a hard-coded address. A malformed one is a programming error in
	 * this class, so it is reported immediately instead of yielding a version
	 * without URL or a {@code null} list entry.
	 */
	private static URL toURL(String url) {
		try {
			return new URL(url);
		} catch (MalformedURLException e) {
			throw new IllegalArgumentException("Malformed TensorFlow download URL: " + url, e);
		}
	}

	/** Creates a version entry without CUDA / CuDNN information. */
	private static DownloadableTensorFlowVersion version(String platform, String tensorFlowVersion, String mode, String url) {
		DownloadableTensorFlowVersion version = new DownloadableTensorFlowVersion(tensorFlowVersion, mode.equals("GPU"));
		version.setURL(toURL(url));
		version.setPlatform(platform);
		return version;
	}

	/** Creates a version entry including its CUDA / CuDNN compatibility. */
	private static DownloadableTensorFlowVersion version(String platform, String tensorFlowVersion, String mode, String cuda, String cudnn, String url) {
		return new DownloadableTensorFlowVersion(toURL(url), tensorFlowVersion, platform, mode.equals("GPU"), cuda, cudnn);
	}
}

-------------------------------------------------------------------------------- /src/main/java/net/imagej/tensorflow/ui/DownloadableTensorFlowVersion.java: 
-------------------------------------------------------------------------------- 1 | /*- 2 | * #%L 3 | * ImageJ/TensorFlow integration. 4 | * %% 5 | * Copyright (C) 2017 - 2025 Board of Regents of the University of 6 | * Wisconsin-Madison and Google, Inc. 7 | * %% 8 | * Redistribution and use in source and binary forms, with or without 9 | * modification, are permitted provided that the following conditions are met: 10 | * 11 | * 1. Redistributions of source code must retain the above copyright notice, 12 | * this list of conditions and the following disclaimer. 13 | * 2. Redistributions in binary form must reproduce the above copyright notice, 14 | * this list of conditions and the following disclaimer in the documentation 15 | * and/or other materials provided with the distribution. 16 | * 17 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 18 | * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 19 | * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 20 | * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDERS OR CONTRIBUTORS BE 21 | * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 22 | * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 23 | * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 24 | * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 25 | * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 26 | * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 27 | * POSSIBILITY OF SUCH DAMAGE. 
28 | * #L% 29 | */ 30 | 31 | package net.imagej.tensorflow.ui; 32 | 33 | import net.imagej.tensorflow.TensorFlowVersion; 34 | 35 | import java.net.URL; 36 | import java.util.Objects; 37 | 38 | /** 39 | * This class extends {@link TensorFlowVersion} with knowledge about the web source of this version, the matching platform and whether it has already been downloaded. 40 | * 41 | * @author Deborah Schmidt 42 | * @author Benjamin Wilhelm 43 | */ 44 | class DownloadableTensorFlowVersion extends TensorFlowVersion { 45 | private URL url; 46 | private String platform; 47 | private String localPath; 48 | 49 | private boolean active = false; 50 | private boolean downloaded = false; 51 | 52 | /** 53 | * @param url the URL of the packed JNI file 54 | * @param version the TensorFlow version number 55 | * @param os the platform this version is associated with 56 | * @param supportsGPU whether this version runs on the GPU 57 | * @param compatibleCUDA the CUDA version compatible with this TensorFlow version 58 | * @param compatibleCuDNN the CuDNN version compatible with this TensorFlow version 59 | */ 60 | DownloadableTensorFlowVersion(URL url, String version, String os, boolean supportsGPU, String compatibleCUDA, String compatibleCuDNN) { 61 | super(version, supportsGPU, compatibleCUDA, compatibleCuDNN); 62 | this.platform = os; 63 | this.url = url; 64 | } 65 | 66 | /** 67 | * @param version the TensorFlow version number 68 | * @param supportsGPU whether this version runs on the GPU 69 | */ 70 | DownloadableTensorFlowVersion(String version, boolean supportsGPU) { 71 | super(version, supportsGPU, null, null); 72 | } 73 | 74 | DownloadableTensorFlowVersion(TensorFlowVersion other) { 75 | super(other); 76 | } 77 | 78 | /** 79 | * @return the URL where an archive of this version can be downloaded from 80 | */ 81 | public URL getURL() { 82 | return url; 83 | } 84 | 85 | void setURL(URL url) { 86 | this.url = url; 87 | } 88 | 89 | /** 90 | * @return the platform this version is 
associated with (linux64, linux32, win64, win32, macosx) 91 | */ 92 | public String getPlatform() { 93 | return platform; 94 | } 95 | 96 | void setPlatform(String platform) { 97 | this.platform = platform; 98 | } 99 | 100 | /** 101 | * @return the path this version is cached to 102 | */ 103 | public String getLocalPath() { 104 | return localPath; 105 | } 106 | 107 | /** 108 | * @return whether this version is currently being used 109 | */ 110 | public boolean isActive() { 111 | return active; 112 | } 113 | 114 | void setActive(boolean active) { 115 | this.active = active; 116 | } 117 | 118 | /** 119 | * @return whether this version is cashed to a local path 120 | */ 121 | public boolean isCached() { 122 | return downloaded; 123 | } 124 | 125 | /** 126 | * @return String describing the origin of this version 127 | */ 128 | public String getOriginDescription() { 129 | if(downloaded) return localPath; 130 | if(url == null) return "unknown source"; 131 | return url.toString(); 132 | } 133 | 134 | /** 135 | * @return the TensorFlow version number in an easily comparable format (e.g. 0.13.1 to 000.013.001) 136 | */ 137 | String getComparableTFVersion() { 138 | String orderableVersion = ""; 139 | String[] split = tfVersion.split("\\."); 140 | for(String part : split) { 141 | orderableVersion += String.format("%03d%n", Integer.valueOf(part)); 142 | } 143 | String gpu = ""; 144 | if(supportsGPU.isPresent()) gpu += "_" + (supportsGPU.get() ? 
"cpu" : "gpu"); 145 | return orderableVersion + gpu; 146 | } 147 | 148 | void setCached(String localPath) { 149 | downloaded = true; 150 | this.localPath = localPath; 151 | } 152 | 153 | void discardCache() { 154 | downloaded = false; 155 | } 156 | 157 | void harvest(DownloadableTensorFlowVersion other) { 158 | if(!cuda.isPresent() && other.cuda.isPresent()) { 159 | cuda = other.cuda; 160 | } 161 | if(!cudnn.isPresent() && other.cudnn.isPresent()) { 162 | cudnn = other.cudnn; 163 | } 164 | if(url == null) url = other.url; 165 | if(localPath == null) localPath = other.localPath; 166 | downloaded = other.downloaded; 167 | } 168 | 169 | @Override 170 | public boolean equals(final Object obj) { 171 | if (!(obj.getClass().equals(this.getClass()))) return false; 172 | final DownloadableTensorFlowVersion o = (DownloadableTensorFlowVersion) obj; 173 | return super.equals(o) && Objects.equals(platform, o.platform); 174 | } 175 | } 176 | -------------------------------------------------------------------------------- /src/main/java/net/imagej/tensorflow/ui/TensorFlowInstallationHandler.java: -------------------------------------------------------------------------------- 1 | /*- 2 | * #%L 3 | * ImageJ/TensorFlow integration. 4 | * %% 5 | * Copyright (C) 2017 - 2025 Board of Regents of the University of 6 | * Wisconsin-Madison and Google, Inc. 7 | * %% 8 | * Redistribution and use in source and binary forms, with or without 9 | * modification, are permitted provided that the following conditions are met: 10 | * 11 | * 1. Redistributions of source code must retain the above copyright notice, 12 | * this list of conditions and the following disclaimer. 13 | * 2. Redistributions in binary form must reproduce the above copyright notice, 14 | * this list of conditions and the following disclaimer in the documentation 15 | * and/or other materials provided with the distribution. 
16 | * 17 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 18 | * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 19 | * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 20 | * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDERS OR CONTRIBUTORS BE 21 | * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 22 | * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 23 | * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 24 | * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 25 | * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 26 | * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 27 | * POSSIBILITY OF SUCH DAMAGE. 28 | * #L% 29 | */ 30 | 31 | package net.imagej.tensorflow.ui; 32 | 33 | import net.imagej.tensorflow.util.TensorFlowUtil; 34 | import net.imagej.tensorflow.util.UnpackUtil; 35 | import org.scijava.app.AppService; 36 | import org.scijava.app.StatusService; 37 | import org.scijava.log.LogService; 38 | import org.scijava.plugin.Parameter; 39 | import org.scijava.plugin.Plugin; 40 | import org.scijava.plugin.SciJavaPlugin; 41 | 42 | import java.io.File; 43 | import java.io.IOException; 44 | import java.io.InputStream; 45 | import java.net.URL; 46 | import java.nio.file.Files; 47 | import java.nio.file.Paths; 48 | import java.nio.file.StandardCopyOption; 49 | 50 | /** 51 | * This class handles instances of {@link DownloadableTensorFlowVersion}. 52 | * It is responsible for downloading, caching, unpacking, and installing them. 
53 | * 54 | * @author Deborah Schmidt 55 | * @author Benjamin Wilhelm 56 | */ 57 | class TensorFlowInstallationHandler { 58 | 59 | @Parameter 60 | private AppService appService; 61 | 62 | @Parameter 63 | private LogService logService; 64 | 65 | @Parameter 66 | private StatusService statusService; 67 | 68 | private static final String DOWNLOADDIR = "downloads/"; 69 | 70 | /** 71 | * Checks for a specific version whether it is downloaded and installed. 72 | * @param version the version which will be checked 73 | */ 74 | void updateCacheStatus(final DownloadableTensorFlowVersion version) { 75 | if(version.getURL() == null) return; 76 | if (version.getURL().getProtocol().equalsIgnoreCase("file")) { 77 | version.setCached(version.getURL().getFile()); 78 | } else { 79 | String path = getDownloadDir() + getNameFromURL(version.getURL()); 80 | if (new File(path).exists()) { 81 | version.setCached(path); 82 | } 83 | } 84 | } 85 | 86 | private String getNameFromURL(URL url) { 87 | String[] parts = url.getPath().split("/"); 88 | if(parts.length > 0) return parts[parts.length - 1]; 89 | else return null; 90 | } 91 | 92 | /** 93 | * Activates a specific TensorFlow library which will then be used after restarting the application. 
94 | * @param version the version being activated 95 | * @throws IOException thrown in case downloading the version failed or the version could not be unpacked 96 | */ 97 | void activateVersion(DownloadableTensorFlowVersion version) throws IOException { 98 | if (!version.isCached()) { 99 | downloadVersion(version.getURL()); 100 | } 101 | updateCacheStatus(version); 102 | if (!version.isActive()) { 103 | installVersion(version); 104 | } 105 | } 106 | 107 | private void downloadVersion(URL url) throws IOException { 108 | createDownloadDir(); 109 | String filename = url.getFile().substring(url.getFile().lastIndexOf("/") + 1); 110 | String localFile = getDownloadDir() + filename; 111 | logService.info("Downloading " + url + " to " + localFile); 112 | InputStream in = url.openStream(); 113 | Files.copy(in, Paths.get(localFile), StandardCopyOption.REPLACE_EXISTING); 114 | } 115 | 116 | private void createDownloadDir() { 117 | File downloadDir = new File(getDownloadDir()); 118 | if (!downloadDir.exists()) { 119 | downloadDir.mkdirs(); 120 | } 121 | } 122 | 123 | private void installVersion(DownloadableTensorFlowVersion version) throws IOException { 124 | 125 | logService.info("Installing " + version); 126 | 127 | File outputDir = new File(TensorFlowUtil.getUpdateLibDir(getRoot()) + version.getPlatform() + File.separator); 128 | 129 | if (version.getLocalPath().contains(".zip")) { 130 | UnpackUtil.unZip(version.getLocalPath(), outputDir, logService, statusService); 131 | TensorFlowUtil.writeNativeVersionFile(getRoot(), version.getPlatform(), version); 132 | } else if (version.getLocalPath().endsWith(".tar.gz")) { 133 | String symLinkOutputDir = TensorFlowUtil.getLibDir(getRoot()) + version.getPlatform() + File.separator; 134 | UnpackUtil.unGZip(version.getLocalPath(), outputDir, symLinkOutputDir, logService, statusService); 135 | TensorFlowUtil.writeNativeVersionFile(getRoot(), version.getPlatform(), version); 136 | } 137 | 138 | if 
(version.getLocalPath().endsWith(".jar")) { 139 | // using default JAR version. 140 | TensorFlowUtil.removeNativeLibraries(getRoot(), logService); 141 | logService.info("Using default JAR TensorFlow version."); 142 | } 143 | 144 | statusService.clearStatus(); 145 | 146 | TensorFlowUtil.getCrashFile(getRoot()).delete(); 147 | } 148 | 149 | private String getRoot() { 150 | return appService.getApp().getBaseDirectory().getAbsolutePath(); 151 | } 152 | 153 | private String getDownloadDir() { 154 | return appService.getApp().getBaseDirectory().getAbsolutePath() + File.separator + DOWNLOADDIR; 155 | } 156 | 157 | } 158 | -------------------------------------------------------------------------------- /src/main/java/net/imagej/tensorflow/ui/TensorFlowLibraryManagementCommand.java: -------------------------------------------------------------------------------- 1 | /*- 2 | * #%L 3 | * ImageJ/TensorFlow integration. 4 | * %% 5 | * Copyright (C) 2017 - 2025 Board of Regents of the University of 6 | * Wisconsin-Madison and Google, Inc. 7 | * %% 8 | * Redistribution and use in source and binary forms, with or without 9 | * modification, are permitted provided that the following conditions are met: 10 | * 11 | * 1. Redistributions of source code must retain the above copyright notice, 12 | * this list of conditions and the following disclaimer. 13 | * 2. Redistributions in binary form must reproduce the above copyright notice, 14 | * this list of conditions and the following disclaimer in the documentation 15 | * and/or other materials provided with the distribution. 16 | * 17 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 18 | * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 19 | * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 20 | * ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDERS OR CONTRIBUTORS BE 21 | * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 22 | * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 23 | * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 24 | * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 25 | * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 26 | * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 27 | * POSSIBILITY OF SUCH DAMAGE. 28 | * #L% 29 | */ 30 | 31 | package net.imagej.tensorflow.ui; 32 | 33 | import net.imagej.ImageJ; 34 | import net.imagej.tensorflow.TensorFlowService; 35 | import net.imagej.tensorflow.TensorFlowVersion; 36 | import net.imagej.tensorflow.util.TensorFlowUtil; 37 | import net.imagej.updater.util.Platforms; 38 | import org.scijava.Context; 39 | import org.scijava.app.AppService; 40 | import org.scijava.command.Command; 41 | import org.scijava.log.LogService; 42 | import org.scijava.plugin.Parameter; 43 | import org.scijava.plugin.Plugin; 44 | 45 | import java.awt.*; 46 | import java.util.ArrayList; 47 | import java.util.List; 48 | 49 | /** 50 | * This command provides an interface for activating a specific Java TensorFlow version. 51 | * The user can choose between the version included via maven and archived native library versions 52 | * which can be downloaded from the Google servers. 
53 | * 54 | * @author Deborah Schmidt 55 | * @author Benjamin Wilhelm 56 | */ 57 | @Plugin(type = Command.class, menuPath = "Edit>Options>TensorFlow...") 58 | public class TensorFlowLibraryManagementCommand implements Command { 59 | 60 | @Parameter 61 | private TensorFlowService tensorFlowService; 62 | 63 | @Parameter 64 | private AppService appService; 65 | 66 | @Parameter 67 | private LogService logService; 68 | 69 | @Parameter 70 | private Context context; 71 | 72 | DownloadableTensorFlowVersion currentVersion; 73 | List availableVersions = new ArrayList<>(); 74 | 75 | String platform = Platforms.current(); 76 | 77 | private TensorFlowInstallationHandler installationHandler; 78 | 79 | @Override 80 | public void run() { 81 | tensorFlowService.loadLibrary(); 82 | installationHandler = new TensorFlowInstallationHandler(); 83 | context.inject(installationHandler); 84 | final TensorFlowLibraryManagementFrame frame = new TensorFlowLibraryManagementFrame(tensorFlowService, installationHandler); 85 | frame.init(); 86 | initAvailableVersions(); 87 | frame.updateChoices(availableVersions); 88 | frame.pack(); 89 | frame.setLocationRelativeTo(null); 90 | frame.setMinimumSize(new Dimension(0, 200)); 91 | frame.setVisible(true); 92 | 93 | } 94 | 95 | private void initAvailableVersions() { 96 | initCurrentVersion(); 97 | if(currentVersion != null) addAvailableVersion(currentVersion); 98 | addAvailableVersion(getTensorFlowJARVersion()); 99 | AvailableTensorFlowVersions.get().forEach(version -> addAvailableVersion(version)); 100 | } 101 | 102 | private void initCurrentVersion() { 103 | TensorFlowVersion current = tensorFlowService.getTensorFlowVersion(); 104 | if(current != null) { 105 | currentVersion = new DownloadableTensorFlowVersion(current); 106 | currentVersion.setPlatform(platform); 107 | currentVersion.setActive(true); 108 | } 109 | } 110 | 111 | private DownloadableTensorFlowVersion getTensorFlowJARVersion() { 112 | DownloadableTensorFlowVersion version = new 
DownloadableTensorFlowVersion(TensorFlowUtil.versionFromClassPathJAR()); 113 | version.setPlatform(platform); 114 | version.setURL(TensorFlowUtil.getTensorFlowJAR()); 115 | return version; 116 | } 117 | 118 | private void addAvailableVersion(DownloadableTensorFlowVersion version) { 119 | if(!version.getPlatform().equals(platform)) return; 120 | installationHandler.updateCacheStatus(version); 121 | for (DownloadableTensorFlowVersion other : availableVersions) { 122 | if (other.equals(version)) { 123 | other.harvest(version); 124 | return; 125 | } 126 | } 127 | availableVersions.add(version); 128 | } 129 | 130 | public static void main(String... args) { 131 | ImageJ ij = new ImageJ(); 132 | ij.ui().showUI(); 133 | ij.command().run(TensorFlowLibraryManagementCommand.class, true); 134 | } 135 | } 136 | -------------------------------------------------------------------------------- /src/main/java/net/imagej/tensorflow/ui/TensorFlowLibraryManagementFrame.java: -------------------------------------------------------------------------------- 1 | /*- 2 | * #%L 3 | * ImageJ/TensorFlow integration. 4 | * %% 5 | * Copyright (C) 2017 - 2025 Board of Regents of the University of 6 | * Wisconsin-Madison and Google, Inc. 7 | * %% 8 | * Redistribution and use in source and binary forms, with or without 9 | * modification, are permitted provided that the following conditions are met: 10 | * 11 | * 1. Redistributions of source code must retain the above copyright notice, 12 | * this list of conditions and the following disclaimer. 13 | * 2. Redistributions in binary form must reproduce the above copyright notice, 14 | * this list of conditions and the following disclaimer in the documentation 15 | * and/or other materials provided with the distribution. 
16 | * 17 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 18 | * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 19 | * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 20 | * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDERS OR CONTRIBUTORS BE 21 | * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 22 | * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 23 | * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 24 | * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 25 | * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 26 | * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 27 | * POSSIBILITY OF SUCH DAMAGE. 28 | * #L% 29 | */ 30 | 31 | package net.imagej.tensorflow.ui; 32 | 33 | import net.imagej.tensorflow.TensorFlowService; 34 | import net.miginfocom.swing.MigLayout; 35 | 36 | import javax.swing.*; 37 | import java.awt.*; 38 | import java.io.IOException; 39 | import java.util.*; 40 | import java.util.List; 41 | 42 | /** 43 | * The UI part of the {@link TensorFlowLibraryManagementCommand} 44 | * 45 | * @author Deborah Schmidt 46 | * @author Benjamin Wilhelm 47 | */ 48 | class TensorFlowLibraryManagementFrame extends JFrame { 49 | 50 | private static final long serialVersionUID = 1L; 51 | 52 | private static final Color LIST_BACKGROUND_COLOR = new Color(250,250,250); 53 | private static final String NOFILTER = "-"; 54 | 55 | private TensorFlowService tensorFlowService; 56 | private TensorFlowInstallationHandler installationHandler; 57 | 58 | private JComboBox gpuChoiceBox; 59 | private JComboBox cudaChoiceBox; 60 | private JComboBox tfChoiceBox; 61 | private JPanel installPanel; 62 | private JTextArea status; 63 | 64 | private List availableVersions = new ArrayList<>(); 65 | private List buttons = new ArrayList<>(); 66 | 67 | 
TensorFlowLibraryManagementFrame(final TensorFlowService tensorFlowInstallationService, final TensorFlowInstallationHandler installationHandler) { 68 | super("TensorFlow library version management"); 69 | this.tensorFlowService = tensorFlowInstallationService; 70 | this.installationHandler = installationHandler; 71 | } 72 | 73 | public void init() { 74 | JPanel panel = new JPanel(); 75 | panel.setLayout(new MigLayout("height 400, wmax 600")); 76 | panel.add(new JLabel("Please select the TensorFlow version you would like to install."), "wrap"); 77 | panel.add(createFilterPanel(), "wrap, span, align right"); 78 | panel.add(createInstallPanel(), "wrap, span, grow"); 79 | panel.add(createStatus(), "span, grow"); 80 | setContentPane(panel); 81 | } 82 | 83 | private Component createFilterPanel() { 84 | gpuChoiceBox = new JComboBox<>(); 85 | gpuChoiceBox.addItem(NOFILTER); 86 | gpuChoiceBox.addItem("GPU"); 87 | gpuChoiceBox.addItem("CPU"); 88 | gpuChoiceBox.addActionListener((e) -> updateFilter()); 89 | cudaChoiceBox = new JComboBox<>(); 90 | cudaChoiceBox.addActionListener((e) -> updateFilter()); 91 | tfChoiceBox = new JComboBox<>(); 92 | tfChoiceBox.addActionListener((e) -> updateFilter()); 93 | JPanel panel = new JPanel(); 94 | panel.setLayout(new MigLayout()); 95 | panel.add(makeLabel("Filter by..")); 96 | panel.add(makeLabel("Mode: ")); 97 | panel.add(gpuChoiceBox); 98 | panel.add(makeLabel("CUDA: ")); 99 | panel.add(cudaChoiceBox); 100 | panel.add(makeLabel("TensorFlow: ")); 101 | panel.add(tfChoiceBox); 102 | return panel; 103 | } 104 | 105 | private Component createStatus() { 106 | status = new JTextArea(); 107 | status.setBorder(BorderFactory.createEmptyBorder()); 108 | status.setEditable(false); 109 | status.setWrapStyleWord(true); 110 | status.setLineWrap(true); 111 | status.setOpaque(false); 112 | return status; 113 | } 114 | 115 | private void updateFilter() { 116 | installPanel.removeAll(); 117 | buttons.forEach(btn -> { 118 | if(filter("", gpuChoiceBox, 
btn)) return; 119 | if(filter("CUDA ", cudaChoiceBox, btn)) return; 120 | if(filter("TF ", tfChoiceBox, btn)) return; 121 | installPanel.add(btn); 122 | }); 123 | installPanel.revalidate(); 124 | installPanel.repaint(); 125 | } 126 | 127 | private boolean filter(String title, JComboBox choiceBox, JRadioButton btn) { 128 | if(choiceBox.getSelectedItem().toString().equals(NOFILTER)) return false; 129 | return !btn.getText().contains(title + choiceBox.getSelectedItem().toString()); 130 | } 131 | 132 | private JLabel makeLabel(String s) { 133 | JLabel label = new JLabel(s); 134 | label.setHorizontalAlignment(SwingConstants.RIGHT); 135 | label.setHorizontalTextPosition(SwingConstants.RIGHT); 136 | return label; 137 | } 138 | 139 | private Component createInstallPanel() { 140 | installPanel = new JPanel(new MigLayout("flowy")); 141 | JScrollPane scroll = new JScrollPane(installPanel); 142 | scroll.setBorder(BorderFactory.createEmptyBorder()); 143 | installPanel.setBackground(LIST_BACKGROUND_COLOR); 144 | return scroll; 145 | } 146 | 147 | void updateChoices(List availableVersions) { 148 | availableVersions.sort(Comparator.comparing(DownloadableTensorFlowVersion::getComparableTFVersion).reversed()); 149 | this.availableVersions = availableVersions; 150 | updateCUDAChoices(); 151 | updateTFChoices(); 152 | ButtonGroup versionGroup = new ButtonGroup(); 153 | installPanel.removeAll(); 154 | for( DownloadableTensorFlowVersion version : availableVersions) { 155 | JRadioButton btn = new JRadioButton(version.toString()); 156 | btn.setToolTipText(version.getOriginDescription()); 157 | if(version.isActive()) { 158 | btn.setSelected(true); 159 | if(tensorFlowService.getStatus().isFailed()) { 160 | btn.setForeground(Color.red); 161 | } 162 | } 163 | btn.setOpaque(false); 164 | versionGroup.add(btn); 165 | buttons.add(btn); 166 | btn.addActionListener(e -> { 167 | if(btn.isSelected()) { 168 | new Thread(() -> activateVersion(version)).start(); 169 | } 170 | }); 171 | } 172 | 
updateFilter(); 173 | updateStatus(); 174 | } 175 | 176 | private void updateStatus() { 177 | status.setText(tensorFlowService.getStatus().getInfo()); 178 | if(!tensorFlowService.getStatus().isLoaded()) { 179 | status.setForeground(Color.red); 180 | } else { 181 | status.setForeground(Color.black); 182 | } 183 | } 184 | 185 | private void activateVersion(DownloadableTensorFlowVersion version) { 186 | if(version.isActive()) { 187 | System.out.println("[WARNING] Cannot activate version, already active: " + version); 188 | return; 189 | } 190 | showWaitMessage(); 191 | try { 192 | installationHandler.activateVersion(version); 193 | } catch (IOException e) { 194 | e.printStackTrace(); 195 | Object[] options = {"Yes", 196 | "No", 197 | "Cancel"}; 198 | int choice = JOptionPane.showOptionDialog(this, 199 | "Error while unpacking library file " + version.getLocalPath() + ".\nShould it be downloaded again?", 200 | "Unpacking library error", 201 | JOptionPane.YES_NO_CANCEL_OPTION, 202 | JOptionPane.QUESTION_MESSAGE, 203 | null, 204 | options, 205 | options[0]); 206 | if(choice == 0) { 207 | version.discardCache(); 208 | try { 209 | installationHandler.activateVersion(version); 210 | } catch (IOException e1) { 211 | e1.printStackTrace(); 212 | } 213 | } else { 214 | status.setText("Installation failed: " + e.getClass().getName() + ": " + e.getMessage()); 215 | return; 216 | } 217 | } 218 | status.setText(""); 219 | JOptionPane.showMessageDialog(null, 220 | "Installed selected TensorFlow version. 
Please restart Fiji to load it.", 221 | "Please restart", 222 | JOptionPane.PLAIN_MESSAGE); 223 | dispose(); 224 | } 225 | 226 | private void showWaitMessage() { 227 | status.setForeground(Color.black); 228 | status.setText("Please wait.."); 229 | } 230 | 231 | private void updateCUDAChoices() { 232 | Set choices = new LinkedHashSet<>(); 233 | for(DownloadableTensorFlowVersion version :availableVersions) { 234 | if(version.getCompatibleCUDA().isPresent()) 235 | choices.add(version.getCompatibleCUDA().get()); 236 | } 237 | cudaChoiceBox.removeAllItems(); 238 | cudaChoiceBox.addItem(NOFILTER); 239 | for(String choice : choices) { 240 | cudaChoiceBox.addItem(choice); 241 | } 242 | } 243 | 244 | private void updateTFChoices() { 245 | Set choices = new LinkedHashSet<>(); 246 | for(DownloadableTensorFlowVersion version :availableVersions) { 247 | if(version.getVersionNumber() != null) choices.add(version.getVersionNumber()); 248 | } 249 | tfChoiceBox.removeAllItems(); 250 | tfChoiceBox.addItem(NOFILTER); 251 | for(String choice : choices) { 252 | tfChoiceBox.addItem(choice); 253 | } 254 | } 255 | 256 | } 257 | -------------------------------------------------------------------------------- /src/main/java/net/imagej/tensorflow/util/TensorFlowUtil.java: -------------------------------------------------------------------------------- 1 | /*- 2 | * #%L 3 | * ImageJ/TensorFlow integration. 4 | * %% 5 | * Copyright (C) 2017 - 2025 Board of Regents of the University of 6 | * Wisconsin-Madison and Google, Inc. 7 | * %% 8 | * Redistribution and use in source and binary forms, with or without 9 | * modification, are permitted provided that the following conditions are met: 10 | * 11 | * 1. Redistributions of source code must retain the above copyright notice, 12 | * this list of conditions and the following disclaimer. 13 | * 2. 
Redistributions in binary form must reproduce the above copyright notice, 14 | * this list of conditions and the following disclaimer in the documentation 15 | * and/or other materials provided with the distribution. 16 | * 17 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 18 | * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 19 | * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 20 | * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDERS OR CONTRIBUTORS BE 21 | * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 22 | * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 23 | * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 24 | * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 25 | * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 26 | * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 27 | * POSSIBILITY OF SUCH DAMAGE. 
28 | * #L% 29 | */ 30 | 31 | package net.imagej.tensorflow.util; 32 | 33 | import net.imagej.tensorflow.TensorFlowVersion; 34 | import net.imagej.updater.util.Platforms; 35 | import org.scijava.log.Logger; 36 | import org.tensorflow.TensorFlow; 37 | 38 | import java.io.BufferedWriter; 39 | import java.io.File; 40 | import java.io.FileWriter; 41 | import java.io.IOException; 42 | import java.net.JarURLConnection; 43 | import java.net.URL; 44 | import java.net.URLClassLoader; 45 | import java.nio.file.Files; 46 | import java.nio.file.Path; 47 | import java.util.regex.Matcher; 48 | import java.util.regex.Pattern; 49 | 50 | /** 51 | * Utility methods for dealing with the Java TensorFlow library 52 | * 53 | * @author Deborah Schmidt 54 | * @author Benjamin Wilhelm 55 | */ 56 | public final class TensorFlowUtil { 57 | 58 | private static final String TFVERSIONFILE = ".tensorflowversion"; 59 | private static final String CRASHFILE = ".crashed"; 60 | private static final String LIBDIR = "lib"; 61 | private static final String UPDATEDIR = "update"; 62 | 63 | private static final String PLATFORM = Platforms.current(); 64 | 65 | private static Pattern jarVersionPattern = Pattern.compile("(?<=(libtensorflow-)).*(?=(.jar))"); 66 | 67 | private TensorFlowUtil(){} 68 | 69 | /** 70 | * @param jar the JAR file of the TensorFlow library 71 | * @return the version which will be loaded by the JAR file 72 | */ 73 | public static TensorFlowVersion getTensorFlowJARVersion(URL jar) { 74 | Matcher matcher = jarVersionPattern.matcher(jar.getPath()); 75 | if(matcher.find()) { 76 | // guess GPU support by looking for tensorflow_jni_gpu in the class path 77 | boolean supportsGPU = false; 78 | ClassLoader cl = ClassLoader.getSystemClassLoader(); 79 | for(URL url: ((URLClassLoader)cl).getURLs()){ 80 | if(url.getFile().contains("libtensorflow_jni_gpu")) { 81 | supportsGPU = true; 82 | break; 83 | } 84 | } 85 | return new TensorFlowVersion(matcher.group(), supportsGPU, null, null); 86 | } 87 | 
return null; 88 | } 89 | 90 | /** 91 | * @return The TensorFlow version included in the class path which will be loaded by default if no other native library is loaded beforehand. 92 | */ 93 | public static TensorFlowVersion versionFromClassPathJAR() { 94 | return getTensorFlowJARVersion(getTensorFlowJAR()); 95 | } 96 | 97 | /** 98 | * @return The JAR file URL shipping TensorFlow in Java 99 | */ 100 | public static URL getTensorFlowJAR() { 101 | URL resource = TensorFlow.class.getResource("TensorFlow.class"); 102 | JarURLConnection connection = null; 103 | try { 104 | connection = (JarURLConnection) resource.openConnection(); 105 | } catch (IOException e) { 106 | e.printStackTrace(); 107 | } 108 | return connection.getJarFileURL(); 109 | } 110 | 111 | /** 112 | * Deletes all TensorFlow native library files in {@link #getLibDir(String)}. 113 | * @param root the root path of ImageJ 114 | * @param logger 115 | */ 116 | public static void removeNativeLibraries(String root, Logger logger) { 117 | final File folder = new File(TensorFlowUtil.getLibDir(root), PLATFORM); 118 | if(!folder.exists()) { 119 | return; 120 | } 121 | final File[] listOfFiles = folder.listFiles(); 122 | for (File file : listOfFiles) { 123 | if (file.getName().toLowerCase().contains("tensorflow")) { 124 | logger.info("Deleting " + file); 125 | file.delete(); 126 | } 127 | } 128 | } 129 | 130 | /** 131 | * Reading the content of the {@link #TFVERSIONFILE} indicating which native version of TensorFlow is installed 132 | * @param root the root path of ImageJ 133 | * @return the current TensorFlow version installed in ImageJ/lib 134 | * @throws IOException in case the util file containing the native version number cannot be read (must not mean there is no usable native version) 135 | */ 136 | public static TensorFlowVersion readNativeVersionFile(String root) throws IOException { 137 | if(getNativeVersionFile(root).exists()) { 138 | Path path = getNativeVersionFile(root).toPath(); 139 | final String 
versionstr = new String(Files.readAllBytes(path)); 140 | final String[] parts = versionstr.split(","); 141 | if(parts.length >= 3) { 142 | String version = parts[1]; 143 | boolean gpuSupport = parts[2].toLowerCase().equals("gpu"); 144 | if(parts.length == 3) { 145 | return new TensorFlowVersion(version, gpuSupport, null, null); 146 | } 147 | if(parts.length == 5) { 148 | String cuda = parts[3]; 149 | String cudnn = parts[4]; 150 | return new TensorFlowVersion(version, gpuSupport, cuda, cudnn); 151 | } 152 | } else { 153 | throw new IOException("Content of " + path + " does not match expected format"); 154 | } 155 | } 156 | // unknown version origin 157 | return new TensorFlowVersion(TensorFlow.version(), null, null, null); 158 | } 159 | 160 | /** 161 | * Writing the content of the {@link #TFVERSIONFILE} indicating which native version of TensorFlow is installed 162 | * @param root the root path of ImageJ 163 | * @param platform the platform of the user (e.g. linux64, win64, macosx) 164 | * @param version the installed TensorFlow version 165 | */ 166 | public static void writeNativeVersionFile(String root, String platform, TensorFlowVersion version) { 167 | // create content 168 | StringBuilder content = new StringBuilder(); 169 | content.append(platform); 170 | content.append(","); 171 | content.append(version.getVersionNumber()); 172 | content.append(","); 173 | if(version.usesGPU().isPresent()) { 174 | content.append(version.usesGPU().get() ? 
"GPU" : "CPU"); 175 | } else { 176 | content.append("?"); 177 | } 178 | if(version.getCompatibleCuDNN().isPresent() || version.getCompatibleCUDA().isPresent()) { 179 | if(version.getCompatibleCUDA().isPresent()) { 180 | content.append(","); 181 | content.append(version.getCompatibleCUDA().get()); 182 | } else { 183 | content.append("?"); 184 | } 185 | if(version.getCompatibleCuDNN().isPresent()) { 186 | content.append(","); 187 | content.append(version.getCompatibleCuDNN().get()); 188 | } else { 189 | content.append("?"); 190 | } 191 | } 192 | //write content to file 193 | try (BufferedWriter writer = new BufferedWriter(new FileWriter(getNativeVersionFile(root)))) { 194 | writer.write(content.toString()); 195 | } catch (IOException e) { 196 | e.printStackTrace(); 197 | } 198 | } 199 | 200 | /** 201 | * @param root the root path of ImageJ 202 | * @return the crash File location indication if loading TensorFlow resulted in a JVM crash 203 | */ 204 | public static File getCrashFile(String root) { 205 | return new File(getLibDir(root) + PLATFORM + File.separator + CRASHFILE); 206 | } 207 | 208 | /** 209 | * @param root the root path of ImageJ 210 | * @return the file indication which native TensorFlow version is currently installed 211 | */ 212 | public static File getNativeVersionFile(String root) { 213 | return new File(getLibDir(root) + PLATFORM + File.separator + TFVERSIONFILE); 214 | } 215 | 216 | /** 217 | * @param root the root path of ImageJ 218 | * @return the directory where native libraries can be placed so that they are loadable from Java 219 | */ 220 | public static String getLibDir(String root) { 221 | return root + File.separator + LIBDIR + File.separator; 222 | } 223 | 224 | /** 225 | * @param root the root path of ImageJ 226 | * @return the directory where native libraries can be placed so that the updater will move them into {@link #getLibDir(String)} 227 | */ 228 | public static String getUpdateLibDir(String root) { 229 | return root + File.separator 
+ UPDATEDIR + File.separator + LIBDIR + File.separator; 230 | } 231 | } 232 | -------------------------------------------------------------------------------- /src/main/java/net/imagej/tensorflow/util/UnpackUtil.java: -------------------------------------------------------------------------------- 1 | /*- 2 | * #%L 3 | * ImageJ/TensorFlow integration. 4 | * %% 5 | * Copyright (C) 2017 - 2025 Board of Regents of the University of 6 | * Wisconsin-Madison and Google, Inc. 7 | * %% 8 | * Redistribution and use in source and binary forms, with or without 9 | * modification, are permitted provided that the following conditions are met: 10 | * 11 | * 1. Redistributions of source code must retain the above copyright notice, 12 | * this list of conditions and the following disclaimer. 13 | * 2. Redistributions in binary form must reproduce the above copyright notice, 14 | * this list of conditions and the following disclaimer in the documentation 15 | * and/or other materials provided with the distribution. 16 | * 17 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 18 | * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 19 | * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 20 | * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDERS OR CONTRIBUTORS BE 21 | * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 22 | * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 23 | * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 24 | * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 25 | * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 26 | * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 27 | * POSSIBILITY OF SUCH DAMAGE. 
28 | * #L% 29 | */ 30 | 31 | package net.imagej.tensorflow.util; 32 | 33 | import org.apache.commons.compress.archivers.tar.TarArchiveEntry; 34 | import org.apache.commons.compress.archivers.tar.TarArchiveInputStream; 35 | import org.scijava.app.StatusService; 36 | import org.scijava.log.LogService; 37 | import org.scijava.util.ByteArray; 38 | 39 | import java.io.*; 40 | import java.nio.file.Files; 41 | import java.nio.file.Path; 42 | import java.util.zip.GZIPInputStream; 43 | import java.util.zip.ZipEntry; 44 | import java.util.zip.ZipInputStream; 45 | 46 | /** 47 | * Utility class for unpacking archives 48 | * TODO: generalize and move it to a place where others can use it as well 49 | * 50 | * @author Deborah schmidt 51 | */ 52 | public final class UnpackUtil { 53 | 54 | private UnpackUtil(){} 55 | 56 | public static void unGZip(String tarGzFile, File output, String symLinkOutputDir, LogService log, StatusService status) throws IOException { 57 | log("Unpacking " + tarGzFile + " to " + output, log , status); 58 | if (!output.exists()) { 59 | output.mkdirs(); 60 | } 61 | 62 | String tarFileName = tarGzFile.replace(".gz", ""); 63 | FileInputStream instream = new FileInputStream(tarGzFile); 64 | GZIPInputStream ginstream = new GZIPInputStream(instream); 65 | FileOutputStream outstream = new FileOutputStream(tarFileName); 66 | byte[] buf = new byte[1024]; 67 | int len; 68 | while ((len = ginstream.read(buf)) > 0) { 69 | outstream.write(buf, 0, len); 70 | 71 | } 72 | ginstream.close(); 73 | outstream.close(); 74 | TarArchiveInputStream myTarFile = new TarArchiveInputStream(new FileInputStream(tarFileName)); 75 | TarArchiveEntry entry; 76 | while ((entry = myTarFile.getNextTarEntry()) != null) { 77 | if (entry.isSymbolicLink() || entry.isLink()) { 78 | Path source = new File(output, entry.getName()).toPath(); 79 | Path target = new File(symLinkOutputDir, entry.getLinkName()).toPath(); 80 | deleteIfExists(source.toAbsolutePath().toString()); 81 | if 
(entry.isSymbolicLink()) { 82 | log("Creating symbolic link: " + source + " -> " + target, log, status); 83 | Files.createSymbolicLink(source, target); 84 | } else { 85 | log("Creating link: " + source + " -> " + target, log, status); 86 | Files.createLink(source, target); 87 | } 88 | } else { 89 | File outEntry = new File(output + "/" + entry.getName()); 90 | if (!outEntry.getParentFile().exists()) { 91 | outEntry.getParentFile().mkdirs(); 92 | } 93 | if (outEntry.isDirectory()) { 94 | continue; 95 | } 96 | log("Writing " + outEntry, log, status); 97 | final byte[] buf1 = new byte[64 * 1024]; 98 | final int size1 = (int) entry.getSize(); 99 | int len1 = 0; 100 | try (final FileOutputStream outEntryStream = new FileOutputStream(outEntry)) { 101 | while (true) { 102 | status.showStatus(len1, size1, "Unpacking " + entry.getName()); 103 | final int r = myTarFile.read(buf1, 0, buf1.length); 104 | if (r < 0) break; // end of entry 105 | len1 += r; 106 | outEntryStream.write(buf1, 0, r); 107 | } 108 | } 109 | } 110 | } 111 | status.clearStatus(); 112 | myTarFile.close(); 113 | File tarFile = new File(tarFileName); 114 | tarFile.delete(); 115 | } 116 | 117 | public static void unZip(String zipFile, File output, LogService log, StatusService status) throws IOException { 118 | log("Unpacking " + zipFile + " to " + output, log, status); 119 | output.mkdirs(); 120 | try (final ZipInputStream zis = new ZipInputStream(new FileInputStream(zipFile))) { 121 | unZip(zis, output, log, status); 122 | } 123 | } 124 | 125 | public static void unZip(File output, ByteArray byteArray, LogService log, StatusService status) throws IOException { 126 | // Extract the contents of the compressed data to the model cache. 
127 | final ByteArrayInputStream bais = new ByteArrayInputStream(// 128 | byteArray.getArray(), 0, byteArray.size()); 129 | output.mkdirs(); 130 | try (final ZipInputStream zis = new ZipInputStream(bais)) { 131 | unZip(zis, output, log, status); 132 | } 133 | } 134 | 135 | private static void unZip(ZipInputStream zis, File output, LogService log, StatusService status) throws IOException { 136 | final byte[] buf = new byte[64 * 1024]; 137 | while (true) { 138 | final ZipEntry entry = zis.getNextEntry(); 139 | if (entry == null) break; // All done! 140 | final String name = entry.getName(); 141 | log("Unpacking " + name, log, status); 142 | final File outFile = new File(output, name); 143 | if (!outFile.toPath().normalize().startsWith(output.toPath().normalize())) { 144 | throw new RuntimeException("Bad zip entry"); 145 | } 146 | if (entry.isDirectory()) { 147 | outFile.mkdirs(); 148 | } 149 | else { 150 | final int size = (int) entry.getSize(); 151 | int len = 0; 152 | try (final FileOutputStream out = new FileOutputStream(outFile)) { 153 | while (true) { 154 | status.showStatus(len, size, "Unpacking " + name); 155 | final int r = zis.read(buf); 156 | if (r < 0) break; // end of entry 157 | len += r; 158 | out.write(buf, 0, r); 159 | } 160 | } 161 | } 162 | } 163 | status.clearStatus(); 164 | } 165 | 166 | private static void deleteIfExists(String filePath) { 167 | File file = new File(filePath); 168 | if(file.exists()) { 169 | file.delete(); 170 | } 171 | } 172 | 173 | private static void log(String msg, LogService log, StatusService status) { 174 | log.info(msg); 175 | status.showStatus(msg); 176 | } 177 | 178 | } 179 | -------------------------------------------------------------------------------- /src/test/java/net/imagej/tensorflow/TensorsTest.java: -------------------------------------------------------------------------------- 1 | /*- 2 | * #%L 3 | * ImageJ/TensorFlow integration. 
4 | * %% 5 | * Copyright (C) 2017 - 2025 Board of Regents of the University of 6 | * Wisconsin-Madison and Google, Inc. 7 | * %% 8 | * Redistribution and use in source and binary forms, with or without 9 | * modification, are permitted provided that the following conditions are met: 10 | * 11 | * 1. Redistributions of source code must retain the above copyright notice, 12 | * this list of conditions and the following disclaimer. 13 | * 2. Redistributions in binary form must reproduce the above copyright notice, 14 | * this list of conditions and the following disclaimer in the documentation 15 | * and/or other materials provided with the distribution. 16 | * 17 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 18 | * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 19 | * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 20 | * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDERS OR CONTRIBUTORS BE 21 | * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 22 | * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 23 | * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 24 | * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 25 | * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 26 | * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 27 | * POSSIBILITY OF SUCH DAMAGE. 
28 | * #L% 29 | */ 30 | 31 | package net.imagej.tensorflow; 32 | 33 | import static org.junit.Assert.assertArrayEquals; 34 | import static org.junit.Assert.assertEquals; 35 | import static org.junit.Assert.fail; 36 | 37 | import java.nio.ByteBuffer; 38 | import java.nio.DoubleBuffer; 39 | import java.nio.FloatBuffer; 40 | import java.nio.IntBuffer; 41 | import java.nio.LongBuffer; 42 | import java.util.ArrayList; 43 | import java.util.List; 44 | import java.util.function.BiConsumer; 45 | import java.util.stream.IntStream; 46 | 47 | import org.junit.Test; 48 | import org.tensorflow.DataType; 49 | import org.tensorflow.Tensor; 50 | import org.tensorflow.types.UInt8; 51 | 52 | import com.google.common.base.Function; 53 | 54 | import net.imglib2.Point; 55 | import net.imglib2.RandomAccess; 56 | import net.imglib2.img.Img; 57 | import net.imglib2.img.array.ArrayImgFactory; 58 | import net.imglib2.type.numeric.RealType; 59 | import net.imglib2.type.numeric.integer.ByteType; 60 | import net.imglib2.type.numeric.integer.IntType; 61 | import net.imglib2.type.numeric.integer.LongType; 62 | import net.imglib2.type.numeric.real.DoubleType; 63 | import net.imglib2.type.numeric.real.FloatType; 64 | import net.imglib2.util.Intervals; 65 | 66 | public class TensorsTest { 67 | 68 | // --------- TENSOR to RAI --------- 69 | 70 | /** Tests the img(Tensor) functions */ 71 | @Test 72 | public void testTensorToImgReverse() { 73 | final long[] shape = new long[] { 20, 10, 3 }; 74 | final int[] mapping = new int[] { 2, 1, 0 }; 75 | final long[] dims = new long[] { 3, 10, 20 }; 76 | final int n = shape.length; 77 | final int size = 600; 78 | 79 | // Get some points to mark 80 | List points = createTestPoints(n); 81 | 82 | // Create Tensors of different type and convert them to images 83 | ByteBuffer dataByte = ByteBuffer.allocateDirect(size); 84 | execForPointsWithBufferIndex(shape, mapping, points, (i, v) -> dataByte.put(i, (byte) (int) v)); 85 | Tensor tensorByte = 
Tensor.create(UInt8.class, shape, dataByte); 86 | Img imgByte = Tensors.imgByte(tensorByte); 87 | 88 | DoubleBuffer dataDouble = DoubleBuffer.allocate(size); 89 | execForPointsWithBufferIndex(shape, mapping, points, (i, v) -> dataDouble.put(i, v)); 90 | Tensor tensorDouble = Tensor.create(shape, dataDouble); 91 | Img imgDouble = Tensors.imgDouble(tensorDouble); 92 | 93 | FloatBuffer dataFloat = FloatBuffer.allocate(size); 94 | execForPointsWithBufferIndex(shape, mapping, points, (i, v) -> dataFloat.put(i, v)); 95 | Tensor tensorFloat = Tensor.create(shape, dataFloat); 96 | Img imgFloat = Tensors.imgFloat(tensorFloat); 97 | 98 | IntBuffer dataInt = IntBuffer.allocate(size); 99 | execForPointsWithBufferIndex(shape, mapping, points, (i, v) -> dataInt.put(i, v)); 100 | Tensor tensorInt = Tensor.create(shape, dataInt); 101 | Img imgInt = Tensors.imgInt(tensorInt); 102 | 103 | LongBuffer dataLong = LongBuffer.allocate(size); 104 | execForPointsWithBufferIndex(shape, mapping, points, (i, v) -> dataLong.put(i, v)); 105 | Tensor tensorLong = Tensor.create(shape, dataLong); 106 | Img imgLong = Tensors.imgLong(tensorLong); 107 | 108 | // Check all created images 109 | checkImage(imgByte, n, dims, points); 110 | checkImage(imgDouble, n, dims, points); 111 | checkImage(imgFloat, n, dims, points); 112 | checkImage(imgInt, n, dims, points); 113 | checkImage(imgLong, n, dims, points); 114 | } 115 | 116 | /** Tests the imgDirect(Tensor) functions */ 117 | @Test 118 | public void testTensorToImgDirect() { 119 | final long[] shape = new long[] { 20, 10, 3 }; 120 | final int[] mapping = new int[] { 0, 1, 2 }; 121 | final int n = shape.length; 122 | final int size = 600; 123 | 124 | // Get some points to mark 125 | List points = createTestPoints(n); 126 | 127 | // Create Tensors of different type and convert them to images 128 | ByteBuffer dataByte = ByteBuffer.allocateDirect(size); 129 | execForPointsWithBufferIndex(shape, mapping, points, (i, v) -> dataByte.put(i, (byte) (int) v)); 
130 | Tensor tensorByte = Tensor.create(UInt8.class, shape, dataByte); 131 | Img imgByte = Tensors.imgByteDirect(tensorByte); 132 | 133 | DoubleBuffer dataDouble = DoubleBuffer.allocate(size); 134 | execForPointsWithBufferIndex(shape, mapping, points, (i, v) -> dataDouble.put(i, v)); 135 | Tensor tensorDouble = Tensor.create(shape, dataDouble); 136 | Img imgDouble = Tensors.imgDoubleDirect(tensorDouble); 137 | 138 | FloatBuffer dataFloat = FloatBuffer.allocate(size); 139 | execForPointsWithBufferIndex(shape, mapping, points, (i, v) -> dataFloat.put(i, v)); 140 | Tensor tensorFloat = Tensor.create(shape, dataFloat); 141 | Img imgFloat = Tensors.imgFloatDirect(tensorFloat); 142 | 143 | IntBuffer dataInt = IntBuffer.allocate(size); 144 | execForPointsWithBufferIndex(shape, mapping, points, (i, v) -> dataInt.put(i, v)); 145 | Tensor tensorInt = Tensor.create(shape, dataInt); 146 | Img imgInt = Tensors.imgIntDirect(tensorInt); 147 | 148 | LongBuffer dataLong = LongBuffer.allocate(size); 149 | execForPointsWithBufferIndex(shape, mapping, points, (i, v) -> dataLong.put(i, v)); 150 | Tensor tensorLong = Tensor.create(shape, dataLong); 151 | Img imgLong = Tensors.imgLongDirect(tensorLong); 152 | 153 | // Check all created images 154 | checkImage(imgByte, n, shape, points); 155 | checkImage(imgDouble, n, shape, points); 156 | checkImage(imgFloat, n, shape, points); 157 | checkImage(imgInt, n, shape, points); 158 | checkImage(imgLong, n, shape, points); 159 | } 160 | 161 | /** Tests the img(Tensor, int[]) functions */ 162 | @Test 163 | public void testTensorToImgMapping() { 164 | final long[] shape = new long[] { 3, 5, 2, 4 }; 165 | final int[] mapping = new int[] { 1, 3, 0, 2 }; // A strange mapping 166 | final long[] dims = new long[] { 5, 4, 3, 2 }; 167 | final int n = shape.length; 168 | final int size = 120; 169 | 170 | // Get some points to mark 171 | List points = createTestPoints(n); 172 | 173 | // Create Tensors of different type and convert them to images 174 | 
ByteBuffer dataByte = ByteBuffer.allocateDirect(size); 175 | execForPointsWithBufferIndex(shape, mapping, points, (i, v) -> dataByte.put(i, (byte) (int) v)); 176 | Tensor tensorByte = Tensor.create(UInt8.class, shape, dataByte); 177 | Img imgByte = Tensors.imgByte(tensorByte, mapping); 178 | 179 | DoubleBuffer dataDouble = DoubleBuffer.allocate(size); 180 | execForPointsWithBufferIndex(shape, mapping, points, (i, v) -> dataDouble.put(i, v)); 181 | Tensor tensorDouble = Tensor.create(shape, dataDouble); 182 | Img imgDouble = Tensors.imgDouble(tensorDouble, mapping); 183 | 184 | FloatBuffer dataFloat = FloatBuffer.allocate(size); 185 | execForPointsWithBufferIndex(shape, mapping, points, (i, v) -> dataFloat.put(i, v)); 186 | Tensor tensorFloat = Tensor.create(shape, dataFloat); 187 | Img imgFloat = Tensors.imgFloat(tensorFloat, mapping); 188 | 189 | IntBuffer dataInt = IntBuffer.allocate(size); 190 | execForPointsWithBufferIndex(shape, mapping, points, (i, v) -> dataInt.put(i, v)); 191 | Tensor tensorInt = Tensor.create(shape, dataInt); 192 | Img imgInt = Tensors.imgInt(tensorInt, mapping); 193 | 194 | LongBuffer dataLong = LongBuffer.allocate(size); 195 | execForPointsWithBufferIndex(shape, mapping, points, (i, v) -> dataLong.put(i, v)); 196 | Tensor tensorLong = Tensor.create(shape, dataLong); 197 | Img imgLong = Tensors.imgLong(tensorLong, mapping); 198 | 199 | // Check all created images 200 | checkImage(imgByte, n, dims, points); 201 | checkImage(imgDouble, n, dims, points); 202 | checkImage(imgFloat, n, dims, points); 203 | checkImage(imgInt, n, dims, points); 204 | checkImage(imgLong, n, dims, points); 205 | } 206 | 207 | /** Checks one image for the dimensions and marked points */ 208 | private > void checkImage(final Img img, final int n, final long[] dims, 209 | final List points) { 210 | assertArrayEquals(dims, Intervals.dimensionsAsLongArray(img)); 211 | assertEquals(n, img.numDimensions()); 212 | checkPoints(img, points); 213 | } 214 | 215 | // --------- 
RAI to TENSOR --------- 216 | 217 | /** Tests the tensor(RAI) function */ 218 | @Test 219 | public void testImgToTensorReverse() { 220 | assertEquals(1, 1); 221 | 222 | final long[] dims = new long[] { 20, 10, 3 }; 223 | final long[] shape = new long[] { 3, 10, 20 }; 224 | final int n = dims.length; 225 | 226 | // ByteType 227 | testImg2TensorReverse(new ArrayImgFactory().create(dims, new ByteType()), n, shape, 228 | DataType.UINT8); 229 | 230 | // DoubleType 231 | testImg2TensorReverse(new ArrayImgFactory().create(dims, new DoubleType()), n, shape, 232 | DataType.DOUBLE); 233 | 234 | // FloatType 235 | testImg2TensorReverse(new ArrayImgFactory().create(dims, new FloatType()), n, shape, 236 | DataType.FLOAT); 237 | 238 | // IntType 239 | testImg2TensorReverse(new ArrayImgFactory().create(dims, new IntType()), n, shape, 240 | DataType.INT32); 241 | 242 | // LongType 243 | testImg2TensorReverse(new ArrayImgFactory().create(dims, new LongType()), n, shape, 244 | DataType.INT64); 245 | } 246 | 247 | /** Tests the tensorDirect(RAI) function */ 248 | @Test 249 | public void testImgToTensorDirect() { 250 | assertEquals(1, 1); 251 | 252 | final long[] dims = new long[] { 20, 10, 3 }; 253 | final int n = dims.length; 254 | 255 | // ByteType 256 | testImg2TensorDirect(new ArrayImgFactory().create(dims, new ByteType()), n, dims, 257 | DataType.UINT8); 258 | 259 | // DoubleType 260 | testImg2TensorDirect(new ArrayImgFactory().create(dims, new DoubleType()), n, dims, 261 | DataType.DOUBLE); 262 | 263 | // FloatType 264 | testImg2TensorDirect(new ArrayImgFactory().create(dims, new FloatType()), n, dims, 265 | DataType.FLOAT); 266 | 267 | // IntType 268 | testImg2TensorDirect(new ArrayImgFactory().create(dims, new IntType()), n, dims, 269 | DataType.INT32); 270 | 271 | // LongType 272 | testImg2TensorDirect(new ArrayImgFactory().create(dims, new LongType()), n, dims, 273 | DataType.INT64); 274 | } 275 | 276 | /** Tests the tensor(RAI, int[]) function */ 277 | @Test 278 | public 
void testImgToTensorMapping() { 279 | assertEquals(1, 1); 280 | 281 | final long[] dims = new long[] { 5, 4, 3, 2 }; 282 | final int[] mapping = new int[] { 1, 3, 0, 2 }; // A strange mapping 283 | final long[] shape = new long[] { 3, 5, 2, 4 }; 284 | final int n = dims.length; 285 | 286 | // ByteType 287 | testImg2TensorMapping(new ArrayImgFactory().create(dims, new ByteType()), mapping, n, shape, 288 | DataType.UINT8); 289 | 290 | // DoubleType 291 | testImg2TensorMapping(new ArrayImgFactory().create(dims, new DoubleType()), mapping, n, 292 | shape, DataType.DOUBLE); 293 | 294 | // FloatType 295 | testImg2TensorMapping(new ArrayImgFactory().create(dims, new FloatType()), mapping, n, shape, 296 | DataType.FLOAT); 297 | 298 | // IntType 299 | testImg2TensorMapping(new ArrayImgFactory().create(dims, new IntType()), mapping, n, shape, 300 | DataType.INT32); 301 | 302 | // LongType 303 | testImg2TensorMapping(new ArrayImgFactory().create(dims, new LongType()), mapping, n, shape, 304 | DataType.INT64); 305 | } 306 | 307 | /** Tests the tensor(RAI) function */ 308 | @Test 309 | public void testImgToTensorReverseTyped() { 310 | assertEquals(1, 1); 311 | 312 | final long[] dims = new long[] { 20, 10, 3 }; 313 | final long[] shape = new long[] { 3, 10, 20 }; 314 | final int n = dims.length; 315 | 316 | // ByteType 317 | testImg2TensorReverseTyped(new ArrayImgFactory().create(dims, new ByteType()), n, shape, 318 | UInt8.class, (i) -> Tensors.tensorByte(i)); 319 | 320 | // DoubleType 321 | testImg2TensorReverseTyped(new ArrayImgFactory().create(dims, new DoubleType()), n, shape, 322 | Double.class, (i) -> Tensors.tensorDouble(i)); 323 | 324 | // FloatType 325 | testImg2TensorReverseTyped(new ArrayImgFactory().create(dims, new FloatType()), n, shape, 326 | Float.class, (i) -> Tensors.tensorFloat(i)); 327 | 328 | // IntType 329 | testImg2TensorReverseTyped(new ArrayImgFactory().create(dims, new IntType()), n, shape, 330 | Integer.class, (i) -> Tensors.tensorInt(i)); 331 | 332 
| // LongType 333 | testImg2TensorReverseTyped(new ArrayImgFactory().create(dims, new LongType()), n, shape, 334 | Long.class, (i) -> Tensors.tensorLong(i)); 335 | } 336 | 337 | /** Tests the tensorDirect(RAI) function */ 338 | @Test 339 | public void testImgToTensorDirectTyped() { 340 | assertEquals(1, 1); 341 | 342 | final long[] dims = new long[] { 20, 10, 3 }; 343 | final int n = dims.length; 344 | 345 | // ByteType 346 | testImg2TensorDirectTyped(new ArrayImgFactory().create(dims, new ByteType()), n, dims, 347 | UInt8.class, (i) -> Tensors.tensorByteDirect(i)); 348 | 349 | // DoubleType 350 | testImg2TensorDirectTyped(new ArrayImgFactory().create(dims, new DoubleType()), n, dims, 351 | Double.class, (i) -> Tensors.tensorDoubleDirect(i)); 352 | 353 | // FloatType 354 | testImg2TensorDirectTyped(new ArrayImgFactory().create(dims, new FloatType()), n, dims, 355 | Float.class, (i) -> Tensors.tensorFloatDirect(i)); 356 | 357 | // IntType 358 | testImg2TensorDirectTyped(new ArrayImgFactory().create(dims, new IntType()), n, dims, 359 | Integer.class, (i) -> Tensors.tensorIntDirect(i)); 360 | 361 | // LongType 362 | testImg2TensorDirectTyped(new ArrayImgFactory().create(dims, new LongType()), n, dims, 363 | Long.class, (i) -> Tensors.tensorLongDirect(i)); 364 | } 365 | 366 | /** Tests the tensor(RAI, int[]) function */ 367 | @Test 368 | public void testImgToTensorMappingTyped() { 369 | assertEquals(1, 1); 370 | 371 | final long[] dims = new long[] { 5, 4, 3, 2 }; 372 | final int[] mapping = new int[] { 1, 3, 0, 2 }; // A strange mapping 373 | final long[] shape = new long[] { 3, 5, 2, 4 }; 374 | final int n = dims.length; 375 | 376 | // ByteType 377 | testImg2TensorMappingTyped(new ArrayImgFactory().create(dims, new ByteType()), mapping, n, shape, 378 | UInt8.class, (i) -> Tensors.tensorByte(i, mapping)); 379 | 380 | // DoubleType 381 | testImg2TensorMappingTyped(new ArrayImgFactory().create(dims, new DoubleType()), mapping, n, shape, 382 | Double.class, (i) -> 
Tensors.tensorDouble(i, mapping)); 383 | 384 | // FloatType 385 | testImg2TensorMappingTyped(new ArrayImgFactory().create(dims, new FloatType()), mapping, n, shape, 386 | Float.class, (i) -> Tensors.tensorFloat(i, mapping)); 387 | 388 | // IntType 389 | testImg2TensorMappingTyped(new ArrayImgFactory().create(dims, new IntType()), mapping, n, shape, 390 | Integer.class, (i) -> Tensors.tensorInt(i, mapping)); 391 | 392 | // LongType 393 | testImg2TensorMappingTyped(new ArrayImgFactory().create(dims, new LongType()), mapping, n, shape, 394 | Long.class, (i) -> Tensors.tensorLong(i, mapping)); 395 | } 396 | 397 | /** Tests the tensor(RAI) function for one image */ 398 | private > void testImg2TensorReverse(final Img img, final int n, final long[] shape, 399 | final DataType t) { 400 | // Put some values to check into the image 401 | List points = createTestPoints(n); 402 | markPoints(img, points); 403 | 404 | Tensor tensor = Tensors.tensor(img); 405 | 406 | assertArrayEquals(shape, tensor.shape()); 407 | assertEquals(n, tensor.numDimensions()); 408 | assertEquals(t, tensor.dataType()); 409 | checkPointsTensor(tensor, IntStream.range(0, n).map(i -> n - 1 - i).toArray(), points); 410 | } 411 | 412 | /** Tests the tensorDirect(RAI) function for one image */ 413 | private > void testImg2TensorDirect(final Img img, final int n, final long[] shape, 414 | final DataType t) { 415 | // Put some values to check into the image 416 | List points = createTestPoints(n); 417 | markPoints(img, points); 418 | 419 | Tensor tensor = Tensors.tensorDirect(img); 420 | 421 | assertArrayEquals(shape, tensor.shape()); 422 | assertEquals(n, tensor.numDimensions()); 423 | assertEquals(t, tensor.dataType()); 424 | checkPointsTensor(tensor, IntStream.range(0, n).toArray(), points); 425 | } 426 | 427 | /** Tests the tensor(RAI, int[]) function for one image */ 428 | private > void testImg2TensorMapping(final Img img, final int[] mapping, 429 | final int n, final long[] shape, final DataType t) { 
430 | // Put some values to check into the image 431 | List points = createTestPoints(n); 432 | markPoints(img, points); 433 | Tensor tensor = Tensors.tensor(img, mapping); 434 | 435 | assertArrayEquals(shape, tensor.shape()); 436 | assertEquals(n, tensor.numDimensions()); 437 | assertEquals(t, tensor.dataType()); 438 | checkPointsTensor(tensor, mapping, points); 439 | } 440 | 441 | /** Tests the tensor(RAI) typed function for one image */ 442 | private , U> void testImg2TensorReverseTyped(final Img img, final int n, final long[] shape, 443 | Class t, final Function, Tensor> converter) { 444 | // Put some values to check into the image 445 | List points = createTestPoints(n); 446 | markPoints(img, points); 447 | 448 | Tensor tensor = converter.apply(img); 449 | 450 | assertArrayEquals(shape, tensor.shape()); 451 | assertEquals(n, tensor.numDimensions()); 452 | assertEquals(DataType.fromClass(t), tensor.dataType()); 453 | checkPointsTensor(tensor, IntStream.range(0, n).map(i -> n - 1 - i).toArray(), points); 454 | } 455 | 456 | /** Tests the tensorDirect(RAI) typed function for one image */ 457 | private , U> void testImg2TensorDirectTyped(final Img img, final int n, final long[] shape, 458 | Class t, final Function, Tensor> converter) { 459 | // Put some values to check into the image 460 | List points = createTestPoints(n); 461 | markPoints(img, points); 462 | 463 | Tensor tensor = converter.apply(img); 464 | 465 | assertArrayEquals(shape, tensor.shape()); 466 | assertEquals(n, tensor.numDimensions()); 467 | assertEquals(DataType.fromClass(t), tensor.dataType()); 468 | checkPointsTensor(tensor, IntStream.range(0, n).toArray(), points); 469 | } 470 | 471 | /** Tests the tensor(RAI, int[]) typed function for one image */ 472 | private , U> void testImg2TensorMappingTyped(final Img img, final int[] mapping, final int n, 473 | final long[] shape, Class t, final Function, Tensor> converter) { 474 | // Put some values to check into the image 475 | List points = 
createTestPoints(n); 476 | markPoints(img, points); 477 | 478 | Tensor tensor = converter.apply(img); 479 | 480 | assertArrayEquals(shape, tensor.shape()); 481 | assertEquals(n, tensor.numDimensions()); 482 | assertEquals(DataType.fromClass(t), tensor.dataType()); 483 | checkPointsTensor(tensor, mapping, points); 484 | } 485 | 486 | /** Creates some interesting points to mark */ 487 | private List createTestPoints(int n) { 488 | List points = new ArrayList<>(); 489 | points.add(new Point(n)); 490 | for (int d = 0; d < n; d++) { 491 | Point p = new Point(n); 492 | p.fwd(d); 493 | points.add(p); 494 | } 495 | return points; 496 | } 497 | 498 | /** Marks a list of points in an image */ 499 | private > void markPoints(final Img img, final List points) { 500 | for (int i = 0; i < points.size(); i++) { 501 | RandomAccess randomAccess = img.randomAccess(); 502 | randomAccess.setPosition(points.get(i)); 503 | randomAccess.get().setReal(i + 1); 504 | } 505 | } 506 | 507 | /** Checks if points in an image are set to the right value */ 508 | private > void checkPoints(final Img img, List points) { 509 | for (int i = 0; i < points.size(); i++) { 510 | RandomAccess randomAccess = img.randomAccess(); 511 | randomAccess.setPosition(points.get(i)); 512 | assertEquals(i + 1, randomAccess.get().getRealFloat(), 0.001); 513 | } 514 | } 515 | 516 | /** 517 | * Calculates the index of a point in a buffer of a Tensor of the given 518 | * shape 519 | */ 520 | private int calcIndex(final long[] shape, final int[] mapping, final Point p) { 521 | assert p.numDimensions() == shape.length; 522 | int n = shape.length; 523 | 524 | // switch index and value of the mapping 525 | int[] switchedMapping = new int[mapping.length]; 526 | IntStream.range(0, mapping.length).forEach(i -> switchedMapping[mapping[i]] = i); 527 | 528 | int index = 0; 529 | for (int dimPreSize = 1, d = n - 1; d >= 0; d--) { 530 | index += dimPreSize * p.getIntPosition(switchedMapping[d]); 531 | dimPreSize *= shape[d]; 532 | } 
533 | return index; 534 | } 535 | 536 | /** Checks the given points for the given tensor */ 537 | private void checkPointsTensor(final Tensor tensor, final int[] mapping, final List points) { 538 | switch (tensor.dataType()) { 539 | case UINT8: { 540 | ByteBuffer buffer = ByteBuffer.allocate(tensor.numElements()); 541 | tensor.writeTo(buffer); 542 | execForPointsWithBufferIndex(tensor.shape(), mapping, points, 543 | (i, v) -> assertEquals(v, buffer.get(i), 0.001)); 544 | } 545 | break; 546 | case DOUBLE: { 547 | DoubleBuffer buffer = DoubleBuffer.allocate(tensor.numElements()); 548 | tensor.writeTo(buffer); 549 | execForPointsWithBufferIndex(tensor.shape(), mapping, points, 550 | (i, v) -> assertEquals(v, buffer.get(i), 0.001)); 551 | } 552 | break; 553 | case FLOAT: { 554 | FloatBuffer buffer = FloatBuffer.allocate(tensor.numElements()); 555 | tensor.writeTo(buffer); 556 | execForPointsWithBufferIndex(tensor.shape(), mapping, points, 557 | (i, v) -> assertEquals(v, buffer.get(i), 0.001)); 558 | } 559 | break; 560 | case INT32: { 561 | IntBuffer buffer = IntBuffer.allocate(tensor.numElements()); 562 | tensor.writeTo(buffer); 563 | execForPointsWithBufferIndex(tensor.shape(), mapping, points, 564 | (i, v) -> assertEquals(v, buffer.get(i), 0.001)); 565 | } 566 | break; 567 | case INT64: { 568 | LongBuffer buffer = LongBuffer.allocate(tensor.numElements()); 569 | tensor.writeTo(buffer); 570 | execForPointsWithBufferIndex(tensor.shape(), mapping, points, 571 | (i, v) -> assertEquals(v, buffer.get(i), 0.001)); 572 | } 573 | break; 574 | default: 575 | // This should not happen because the type is checked before 576 | fail("Tensor has unsupported type."); 577 | } 578 | } 579 | 580 | /** 581 | * Calculates the index of a point in a Tensor buffer and executes the 582 | * BiConsumer with the index and the value of the point (which is the index 583 | * in the list + 1). 
584 | */ 585 | private void execForPointsWithBufferIndex(final long[] shape, final int[] mapping, final List points, 586 | final BiConsumer exec) { 587 | for (int i = 0; i < points.size(); i++) { 588 | exec.accept(calcIndex(shape, mapping, points.get(i)), i + 1); 589 | } 590 | } 591 | } 592 | -------------------------------------------------------------------------------- /src/test/java/net/imagej/tensorflow/VersionBuilderTest.java: -------------------------------------------------------------------------------- 1 | /*- 2 | * #%L 3 | * ImageJ/TensorFlow integration. 4 | * %% 5 | * Copyright (C) 2017 - 2025 Board of Regents of the University of 6 | * Wisconsin-Madison and Google, Inc. 7 | * %% 8 | * Redistribution and use in source and binary forms, with or without 9 | * modification, are permitted provided that the following conditions are met: 10 | * 11 | * 1. Redistributions of source code must retain the above copyright notice, 12 | * this list of conditions and the following disclaimer. 13 | * 2. Redistributions in binary form must reproduce the above copyright notice, 14 | * this list of conditions and the following disclaimer in the documentation 15 | * and/or other materials provided with the distribution. 16 | * 17 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 18 | * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 19 | * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 20 | * ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDERS OR CONTRIBUTORS BE 21 | * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 22 | * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 23 | * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 24 | * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 25 | * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 26 | * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 27 | * POSSIBILITY OF SUCH DAMAGE. 28 | * #L% 29 | */ 30 | package net.imagej.tensorflow; 31 | 32 | import net.imagej.tensorflow.util.TensorFlowUtil; 33 | import org.junit.Test; 34 | import org.tensorflow.TensorFlow; 35 | 36 | import java.net.URL; 37 | 38 | import static org.junit.Assert.assertEquals; 39 | 40 | public class VersionBuilderTest { 41 | @Test 42 | public void testJARVersion() { 43 | URL source = TensorFlow.class.getResource("TensorFlow.class"); 44 | TensorFlowVersion version = TensorFlowUtil.getTensorFlowJARVersion(source); 45 | assertEquals(TensorFlow.version(), version.getVersionNumber()); 46 | } 47 | } 48 | -------------------------------------------------------------------------------- /src/test/java/net/imagej/tensorflow/demo/LabelImageTest.java: -------------------------------------------------------------------------------- 1 | /*- 2 | * #%L 3 | * ImageJ/TensorFlow integration. 4 | * %% 5 | * Copyright (C) 2017 - 2025 Board of Regents of the University of 6 | * Wisconsin-Madison and Google, Inc. 7 | * %% 8 | * Redistribution and use in source and binary forms, with or without 9 | * modification, are permitted provided that the following conditions are met: 10 | * 11 | * 1. Redistributions of source code must retain the above copyright notice, 12 | * this list of conditions and the following disclaimer. 13 | * 2. 
Redistributions in binary form must reproduce the above copyright notice, 14 | * this list of conditions and the following disclaimer in the documentation 15 | * and/or other materials provided with the distribution. 16 | * 17 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 18 | * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 19 | * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 20 | * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDERS OR CONTRIBUTORS BE 21 | * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 22 | * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 23 | * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 24 | * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 25 | * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 26 | * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 27 | * POSSIBILITY OF SUCH DAMAGE. 
28 | * #L% 29 | */ 30 | package net.imagej.tensorflow.demo; 31 | 32 | import net.imagej.Dataset; 33 | import net.imagej.ImageJ; 34 | import net.imagej.axis.Axes; 35 | import net.imagej.axis.AxisType; 36 | import net.imglib2.type.numeric.real.FloatType; 37 | import org.junit.Test; 38 | import org.scijava.command.CommandModule; 39 | 40 | import java.util.concurrent.ExecutionException; 41 | 42 | import static org.junit.Assert.assertNotNull; 43 | 44 | public class LabelImageTest { 45 | 46 | @Test 47 | public void runLabelImageCommand() throws ExecutionException, InterruptedException { 48 | final ImageJ ij = new ImageJ(); 49 | Dataset img = ij.dataset().create(new FloatType(), new long[]{10, 10, 3}, "", new AxisType[]{Axes.X, Axes.Y}); 50 | CommandModule module = ij.command().run(LabelImage.class, false, "inputImage", img).get(); 51 | assertNotNull(module); 52 | String labels = (String) module.getOutput("outputLabels"); 53 | assertNotNull(labels); 54 | System.out.println(labels); 55 | } 56 | 57 | } 58 | -------------------------------------------------------------------------------- /src/test/java/net/imagej/tensorflow/ui/AvailableTensorFlowVersionsTest.java: -------------------------------------------------------------------------------- 1 | /*- 2 | * #%L 3 | * ImageJ/TensorFlow integration. 4 | * %% 5 | * Copyright (C) 2017 - 2025 Board of Regents of the University of 6 | * Wisconsin-Madison and Google, Inc. 7 | * %% 8 | * Redistribution and use in source and binary forms, with or without 9 | * modification, are permitted provided that the following conditions are met: 10 | * 11 | * 1. Redistributions of source code must retain the above copyright notice, 12 | * this list of conditions and the following disclaimer. 13 | * 2. Redistributions in binary form must reproduce the above copyright notice, 14 | * this list of conditions and the following disclaimer in the documentation 15 | * and/or other materials provided with the distribution. 
16 | * 17 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 18 | * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 19 | * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 20 | * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDERS OR CONTRIBUTORS BE 21 | * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 22 | * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 23 | * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 24 | * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 25 | * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 26 | * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 27 | * POSSIBILITY OF SUCH DAMAGE. 28 | * #L% 29 | */ 30 | package net.imagej.tensorflow.ui; 31 | 32 | import org.junit.Test; 33 | 34 | import java.io.IOException; 35 | import java.net.HttpURLConnection; 36 | import java.util.List; 37 | 38 | import static org.junit.Assert.assertEquals; 39 | 40 | public class AvailableTensorFlowVersionsTest { 41 | 42 | @Test 43 | public void testURLs() throws IOException { 44 | List versions = AvailableTensorFlowVersions.get(); 45 | for (DownloadableTensorFlowVersion version : versions) { 46 | System.out.println("Testing " + version.getURL()); 47 | HttpURLConnection huc = (HttpURLConnection) version.getURL().openConnection(); 48 | huc.setRequestMethod("HEAD"); 49 | huc.connect(); 50 | assertEquals(HttpURLConnection.HTTP_OK, huc.getResponseCode()); 51 | } 52 | } 53 | 54 | } 55 | -------------------------------------------------------------------------------- /src/test/java/net/imagej/tensorflow/ui/TensorFlowVersionTest.java: -------------------------------------------------------------------------------- 1 | /*- 2 | * #%L 3 | * ImageJ/TensorFlow integration. 
4 | * %% 5 | * Copyright (C) 2017 - 2025 Board of Regents of the University of 6 | * Wisconsin-Madison and Google, Inc. 7 | * %% 8 | * Redistribution and use in source and binary forms, with or without 9 | * modification, are permitted provided that the following conditions are met: 10 | * 11 | * 1. Redistributions of source code must retain the above copyright notice, 12 | * this list of conditions and the following disclaimer. 13 | * 2. Redistributions in binary form must reproduce the above copyright notice, 14 | * this list of conditions and the following disclaimer in the documentation 15 | * and/or other materials provided with the distribution. 16 | * 17 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 18 | * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 19 | * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 20 | * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDERS OR CONTRIBUTORS BE 21 | * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 22 | * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 23 | * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 24 | * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 25 | * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 26 | * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 27 | * POSSIBILITY OF SUCH DAMAGE. 
28 | * #L% 29 | */ 30 | package net.imagej.tensorflow.ui; 31 | 32 | import net.imagej.tensorflow.TensorFlowVersion; 33 | import org.junit.Test; 34 | 35 | import static org.junit.Assert.assertEquals; 36 | import static org.junit.Assert.assertNotEquals; 37 | 38 | public class TensorFlowVersionTest { 39 | 40 | @Test 41 | public void compareVersions() { 42 | TensorFlowVersion v1 = new TensorFlowVersion("1.4.0", true, null, null); 43 | TensorFlowVersion v2 = new TensorFlowVersion("1.4.0", true, null, null); 44 | assertEquals(v1, v2); 45 | 46 | v1 = new TensorFlowVersion("1.4.0", true, null, null); 47 | v2 = new TensorFlowVersion("1.4.0", false, null, null); 48 | assertNotEquals(v1, v2); 49 | 50 | v1 = new TensorFlowVersion("1.4.0", false, null, null); 51 | v2 = new TensorFlowVersion("1.4.0", false, "x.x", "y.y"); 52 | assertEquals(v1, v2); 53 | 54 | DownloadableTensorFlowVersion v3 = new DownloadableTensorFlowVersion("1.4.0", false); 55 | assertNotEquals(v3, v1); 56 | assertNotEquals(v1, v3); 57 | 58 | } 59 | 60 | } 61 | --------------------------------------------------------------------------------