destCursor = dest.localizingCursor();
1057 | while (destCursor.hasNext()) {
1058 | destCursor.fwd();
1059 | sourceAccess.setPosition(destCursor);
1060 | destCursor.get().set(sourceAccess.get());
1061 | }
1062 | }
1063 | }
1064 |
--------------------------------------------------------------------------------
/src/main/java/net/imagej/tensorflow/demo/LabelImage.java:
--------------------------------------------------------------------------------
1 | /*-
2 | * #%L
3 | * ImageJ/TensorFlow integration.
4 | * %%
5 | * Copyright (C) 2017 - 2025 Board of Regents of the University of
6 | * Wisconsin-Madison and Google, Inc.
7 | * %%
8 | * Redistribution and use in source and binary forms, with or without
9 | * modification, are permitted provided that the following conditions are met:
10 | *
11 | * 1. Redistributions of source code must retain the above copyright notice,
12 | * this list of conditions and the following disclaimer.
13 | * 2. Redistributions in binary form must reproduce the above copyright notice,
14 | * this list of conditions and the following disclaimer in the documentation
15 | * and/or other materials provided with the distribution.
16 | *
17 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
18 | * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
19 | * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
20 | * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDERS OR CONTRIBUTORS BE
21 | * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
22 | * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
23 | * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
24 | * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
25 | * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
26 | * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
27 | * POSSIBILITY OF SUCH DAMAGE.
28 | * #L%
29 | */
30 |
31 | package net.imagej.tensorflow.demo;
32 |
33 | import java.io.IOException;
34 | import java.util.Arrays;
35 | import java.util.Comparator;
36 | import java.util.List;
37 | import java.util.stream.IntStream;
38 |
39 | import net.imagej.Dataset;
40 | import net.imagej.ImageJ;
41 | import net.imagej.tensorflow.GraphBuilder;
42 | import net.imagej.tensorflow.TensorFlowService;
43 | import net.imagej.tensorflow.Tensors;
44 | import net.imglib2.RandomAccessibleInterval;
45 | import net.imglib2.converter.Converters;
46 | import net.imglib2.converter.RealFloatConverter;
47 | import net.imglib2.img.Img;
48 | import net.imglib2.type.numeric.RealType;
49 | import net.imglib2.type.numeric.real.FloatType;
50 |
51 | import org.scijava.ItemIO;
52 | import org.scijava.command.Command;
53 | import org.scijava.io.http.HTTPLocation;
54 | import org.scijava.log.LogService;
55 | import org.scijava.plugin.Parameter;
56 | import org.scijava.plugin.Plugin;
57 | import org.tensorflow.Graph;
58 | import org.tensorflow.Output;
59 | import org.tensorflow.Session;
60 | import org.tensorflow.Tensor;
61 |
62 | /**
63 | * Command to use an Inception-image recognition model to label an image.
64 | *
65 | * See the
66 | * TensorFlow
67 | * image recognition tutorial and its
69 | * Java implementation.
70 | */
71 | @Plugin(type = Command.class, menuPath = "Plugins>Sandbox>TensorFlow Demos>Label Image",
72 | headless = true)
73 | public class LabelImage implements Command {
74 |
75 | private static final String INCEPTION_MODEL_URL =
76 | "https://storage.googleapis.com/download.tensorflow.org/models/inception5h.zip";
77 |
78 | private static final String MODEL_NAME = "inception5h";
79 |
80 | @Parameter
81 | private TensorFlowService tensorFlowService;
82 |
83 | @Parameter
84 | private LogService log;
85 |
86 | @Parameter
87 | private Dataset inputImage;
88 |
89 | @Parameter(label = "Min probability (%)", min = "0", max = "100")
90 | private double minPercent = 1;
91 |
92 | @Parameter(type = ItemIO.OUTPUT)
93 | private Img outputImage;
94 |
95 | @Parameter(type = ItemIO.OUTPUT)
96 | private String outputLabels;
97 |
98 | @Override
99 | public void run() {
100 |
101 | tensorFlowService.loadLibrary();
102 | if(!tensorFlowService.getStatus().isLoaded()) return;
103 | log.info("Version of TensorFlow: " + tensorFlowService.getTensorFlowVersion());
104 |
105 | try {
106 | final HTTPLocation source = new HTTPLocation(INCEPTION_MODEL_URL);
107 | final Graph graph = tensorFlowService.loadGraph(source, MODEL_NAME,
108 | "tensorflow_inception_graph.pb");
109 | final List labels = tensorFlowService.loadLabels(source,
110 | MODEL_NAME, "imagenet_comp_graph_label_strings.txt");
111 | log.info("Loaded graph and " + labels.size() + " labels");
112 |
113 | try (
114 | final Tensor inputTensor = loadFromImgLib(inputImage);
115 | final Tensor image = normalizeImage(inputTensor)
116 | )
117 | {
118 | outputImage = Tensors.imgFloat(image, new int[]{ 2, 1, 3, 0 });
119 | final float[] labelProbabilities = executeInceptionGraph(graph, image);
120 |
121 | // Sort labels by probability.
122 | final int labelCount = Math.min(labelProbabilities.length, labels
123 | .size());
124 | final Integer[] labelIndices = IntStream.range(0, labelCount).boxed()
125 | .toArray(Integer[]::new);
126 | Arrays.sort(labelIndices, 0, labelCount, new Comparator() {
127 |
128 | @Override
129 | public int compare(final Integer i1, final Integer i2) {
130 | final float p1 = labelProbabilities[i1];
131 | final float p2 = labelProbabilities[i2];
132 | return p1 == p2 ? 0 : p1 > p2 ? -1 : 1;
133 | }
134 | });
135 |
136 | // Output labels above the probability threshold.
137 | final double cutoff = minPercent / 100; // % -> probability
138 | final StringBuilder sb = new StringBuilder();
139 | for (int i = 0; i < labelCount; i++) {
140 | final int index = labelIndices[i];
141 | final double p = labelProbabilities[index];
142 | if (p < cutoff) break;
143 |
144 | sb.append(String.format("%s (%.2f%% likely)\n", labels.get(index),
145 | labelProbabilities[index] * 100f));
146 | }
147 | outputLabels = sb.toString();
148 | }
149 | }
150 | catch (final Exception exc) {
151 | log.error(exc);
152 | }
153 | }
154 |
155 | @SuppressWarnings({ "rawtypes", "unchecked" })
156 | private static Tensor loadFromImgLib(final Dataset d) {
157 | return loadFromImgLib((RandomAccessibleInterval) d.getImgPlus());
158 | }
159 |
160 | private static > Tensor loadFromImgLib(
161 | final RandomAccessibleInterval image)
162 | {
163 | // NB: Assumes XYC ordering. TensorFlow wants YXC.
164 | RealFloatConverter converter = new RealFloatConverter<>();
165 | return Tensors.tensorFloat(Converters.convert(image, converter, new FloatType()), new int[]{ 1, 0, 2 });
166 | }
167 |
168 | // -----------------------------------------------------------------------------------------------------------------
169 | // All the code below was essentially copied verbatim from:
170 | // https://github.com/tensorflow/tensorflow/blob/e8f2aad0c0502fde74fc629f5b13f04d5d206700/tensorflow/java/src/main/java/org/tensorflow/examples/LabelImage.java
171 | // -----------------------------------------------------------------------------------------------------------------
172 | private static Tensor normalizeImage(final Tensor t) {
173 | try (Graph g = new Graph()) {
174 | final GraphBuilder b = new GraphBuilder(g);
175 | // Some constants specific to the pre-trained model at:
176 | // https://storage.googleapis.com/download.tensorflow.org/models/inception5h.zip
177 | //
178 | // - The model was trained with images scaled to 224x224 pixels.
179 | // - The colors, represented as R, G, B in 1-byte each were converted to
180 | // float using (value - Mean)/Scale.
181 | final int H = 224;
182 | final int W = 224;
183 | final float mean = 117f;
184 | final float scale = 1f;
185 |
186 | // Since the graph is being constructed once per execution here, we can
187 | // use a constant for the input image. If the graph were to be re-used for
188 | // multiple input images, a placeholder would have been more appropriate.
189 | final Output input = g.opBuilder("Const", "input")//
190 | .setAttr("dtype", t.dataType())//
191 | .setAttr("value", t).build().output(0);
192 | final Output output = b.div(b.sub(b.resizeBilinear(b.expandDims(//
193 | input, //
194 | b.constant("make_batch", 0)), //
195 | b.constant("size", new int[] { H, W })), //
196 | b.constant("mean", mean)), //
197 | b.constant("scale", scale));
198 | try (Session s = new Session(g)) {
199 | @SuppressWarnings("unchecked")
200 | Tensor result = (Tensor) s.runner().fetch(output.op().name()).run().get(0);
201 | return result;
202 | }
203 | }
204 | }
205 |
206 | private static float[] executeInceptionGraph(final Graph g,
207 | final Tensor image)
208 | {
209 | try (
210 | final Session s = new Session(g);
211 | @SuppressWarnings("unchecked")
212 | final Tensor result = (Tensor) s.runner().feed("input", image)//
213 | .fetch("output").run().get(0)
214 | )
215 | {
216 | final long[] rshape = result.shape();
217 | if (result.numDimensions() != 2 || rshape[0] != 1) {
218 | throw new RuntimeException(String.format(
219 | "Expected model to produce a [1 N] shaped tensor where N is " +
220 | "the number of labels, instead it produced one with shape %s",
221 | Arrays.toString(rshape)));
222 | }
223 | final int nlabels = (int) rshape[1];
224 | return result.copyTo(new float[1][nlabels])[0];
225 | }
226 | }
227 |
228 | public static void main(String[] args) throws IOException {
229 | final ImageJ ij = new ImageJ();
230 | ij.launch(args);
231 |
232 | // Open an image and display it.
233 | final String imagePath = "http://samples.fiji.sc/new-lenna.jpg";
234 | final Object dataset = ij.io().open(imagePath);
235 | ij.ui().show(dataset);
236 |
237 | // Launch the "Label Images" command with some sensible defaults.
238 | ij.command().run(LabelImage.class, true, //
239 | "inputImage", dataset);
240 | }
241 | }
242 |
--------------------------------------------------------------------------------
/src/main/java/net/imagej/tensorflow/ui/AvailableTensorFlowVersions.java:
--------------------------------------------------------------------------------
1 | /*-
2 | * #%L
3 | * ImageJ/TensorFlow integration.
4 | * %%
5 | * Copyright (C) 2017 - 2025 Board of Regents of the University of
6 | * Wisconsin-Madison and Google, Inc.
7 | * %%
8 | * Redistribution and use in source and binary forms, with or without
9 | * modification, are permitted provided that the following conditions are met:
10 | *
11 | * 1. Redistributions of source code must retain the above copyright notice,
12 | * this list of conditions and the following disclaimer.
13 | * 2. Redistributions in binary form must reproduce the above copyright notice,
14 | * this list of conditions and the following disclaimer in the documentation
15 | * and/or other materials provided with the distribution.
16 | *
17 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
18 | * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
19 | * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
20 | * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDERS OR CONTRIBUTORS BE
21 | * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
22 | * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
23 | * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
24 | * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
25 | * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
26 | * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
27 | * POSSIBILITY OF SUCH DAMAGE.
28 | * #L%
29 | */
30 |
31 | package net.imagej.tensorflow.ui;
32 |
33 | import java.net.MalformedURLException;
34 | import java.net.URL;
35 | import java.util.ArrayList;
36 | import java.util.List;
37 |
38 | /**
39 | * Utility class providing a hardcoded list of available native TensorFlow library versions which can be downloaded from the google servers.
40 | *
41 | * @author Deborah Schmidt
42 | * @author Benjamin Wilhelm
43 | */
44 | public class AvailableTensorFlowVersions {
45 |
46 | /**
47 | * @return a list of downloadable native TensorFlow versions for all operating systems, including their URL and optionally their CUDA / CuDNN compatibility
48 | */
49 | public static List get() {
50 | List versions = new ArrayList<>();
51 | versions.clear();
52 | //linux64
53 | versions.add(version("linux64", "1.2.0", "CPU", "https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow_jni-cpu-linux-x86_64-1.2.0.tar.gz"));
54 | versions.add(version("linux64", "1.2.0", "GPU", "8.0", "5.1", "https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow_jni-gpu-linux-x86_64-1.2.0.tar.gz"));
55 | versions.add(version("linux64", "1.3.0", "CPU", "https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow_jni-cpu-linux-x86_64-1.3.0.tar.gz"));
56 | versions.add(version("linux64", "1.3.0", "GPU", "8.0", "6", "https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow_jni-gpu-linux-x86_64-1.3.0.tar.gz"));
57 | versions.add(version("linux64", "1.4.1", "CPU", "https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow_jni-cpu-linux-x86_64-1.4.1.tar.gz"));
58 | versions.add(version("linux64", "1.4.1", "GPU", "8.0", "6", "https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow_jni-gpu-linux-x86_64-1.4.1.tar.gz"));
59 | versions.add(version("linux64", "1.5.0", "CPU", "https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow_jni-cpu-linux-x86_64-1.5.0.tar.gz"));
60 | versions.add(version("linux64", "1.5.0", "GPU", "9.0", "7", "https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow_jni-gpu-linux-x86_64-1.5.0.tar.gz"));
61 | versions.add(version("linux64", "1.6.0", "CPU", "https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow_jni-cpu-linux-x86_64-1.6.0.tar.gz"));
62 | versions.add(version("linux64", "1.6.0", "GPU", "9.0", "7", "https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow_jni-gpu-linux-x86_64-1.6.0.tar.gz"));
63 | versions.add(version("linux64", "1.7.0", "CPU", "https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow_jni-cpu-linux-x86_64-1.7.0.tar.gz"));
64 | versions.add(version("linux64", "1.7.0", "GPU", "9.0", "7", "https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow_jni-gpu-linux-x86_64-1.7.0.tar.gz"));
65 | versions.add(version("linux64", "1.8.0", "CPU", "https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow_jni-cpu-linux-x86_64-1.8.0.tar.gz"));
66 | versions.add(version("linux64", "1.8.0", "GPU", "9.0", "7", "https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow_jni-gpu-linux-x86_64-1.8.0.tar.gz"));
67 | versions.add(version("linux64", "1.9.0", "CPU", "https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow_jni-cpu-linux-x86_64-1.9.0.tar.gz"));
68 | versions.add(version("linux64", "1.9.0", "GPU", "9.0", "7", "https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow_jni-gpu-linux-x86_64-1.9.0.tar.gz"));
69 | versions.add(version("linux64", "1.10.1", "CPU", "https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow_jni-cpu-linux-x86_64-1.10.1.tar.gz"));
70 | versions.add(version("linux64", "1.10.1", "GPU", "9.0", "7.?", "https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow_jni-gpu-linux-x86_64-1.10.1.tar.gz"));
71 | versions.add(version("linux64", "1.11.0", "CPU", "https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow_jni-cpu-linux-x86_64-1.11.0.tar.gz"));
72 | versions.add(version("linux64", "1.11.0", "GPU", "9.0", ">= 7.2", "https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow_jni-gpu-linux-x86_64-1.11.0.tar.gz"));
73 | versions.add(version("linux64", "1.12.0", "CPU", "https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow_jni-cpu-linux-x86_64-1.12.0.tar.gz"));
74 | versions.add(version("linux64", "1.12.0", "GPU", "9.0", ">= 7.2", "https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow_jni-gpu-linux-x86_64-1.12.0.tar.gz"));
75 | versions.add(version("linux64", "1.13.1", "CPU", "https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow_jni-cpu-linux-x86_64-1.13.1.tar.gz"));
76 | versions.add(version("linux64", "1.13.1", "GPU", "10.0", "7.4", "https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow_jni-gpu-linux-x86_64-1.13.1.tar.gz"));
77 | versions.add(version("linux64", "1.14.0", "CPU", "https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow_jni-cpu-linux-x86_64-1.14.0.tar.gz"));
78 | versions.add(version("linux64", "1.14.0", "GPU", "10.0", ">= 7.4.1", "https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow_jni-gpu-linux-x86_64-1.14.0.tar.gz"));
79 | versions.add(version("linux64", "1.15.0", "CPU", "https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow_jni-cpu-linux-x86_64-1.15.0.tar.gz"));
80 | versions.add(version("linux64", "1.15.0", "GPU", "10.0", ">= 7.4.1", "https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow_jni-gpu-linux-x86_64-1.15.0.tar.gz"));
81 | //win64
82 | versions.add(version("win64", "1.2.0", "CPU", "https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow_jni-cpu-windows-x86_64-1.2.0.zip"));
83 | versions.add(version("win64", "1.3.0", "CPU", "https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow_jni-cpu-windows-x86_64-1.3.0.zip"));
84 | versions.add(version("win64", "1.4.1", "CPU", "https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow_jni-cpu-windows-x86_64-1.4.1.zip"));
85 | versions.add(version("win64", "1.5.0", "CPU", "https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow_jni-cpu-windows-x86_64-1.5.0.zip"));
86 | versions.add(version("win64", "1.6.0", "CPU", "https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow_jni-cpu-windows-x86_64-1.6.0.zip"));
87 | versions.add(version("win64", "1.7.0", "CPU", "https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow_jni-cpu-windows-x86_64-1.7.0.zip"));
88 | versions.add(version("win64", "1.8.0", "CPU", "https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow_jni-cpu-windows-x86_64-1.8.0.zip"));
89 | versions.add(version("win64", "1.9.0", "CPU", "https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow_jni-cpu-windows-x86_64-1.9.0.zip"));
90 | versions.add(version("win64", "1.10.0", "CPU", "https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow_jni-cpu-windows-x86_64-1.10.0.zip"));
91 | versions.add(version("win64", "1.11.0", "CPU", "https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow_jni-cpu-windows-x86_64-1.11.0.zip"));
92 | versions.add(version("win64", "1.12.0", "GPU", "9.0", ">= 7.2", "https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow_jni-gpu-windows-x86_64-1.12.0.zip"));
93 | versions.add(version("win64", "1.12.0", "CPU", "https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow_jni-cpu-windows-x86_64-1.12.0.zip"));
94 | versions.add(version("win64", "1.13.1", "GPU", "10.0", "7.4", "https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow_jni-gpu-windows-x86_64-1.13.1.zip"));
95 | versions.add(version("win64", "1.13.1", "CPU", "https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow_jni-cpu-windows-x86_64-1.13.1.zip"));
96 | versions.add(version("win64", "1.14.0", "GPU", "10.0", ">= 7.4.1", "https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow_jni-gpu-windows-x86_64-1.14.0.zip"));
97 | versions.add(version("win64", "1.14.0", "CPU", "https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow_jni-cpu-windows-x86_64-1.14.0.zip"));
98 | versions.add(version("win64", "1.15.0", "GPU", "10.0", ">= 7.4.1", "https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow_jni-gpu-windows-x86_64-1.15.0.zip"));
99 | versions.add(version("win64", "1.15.0", "CPU", "https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow_jni-cpu-windows-x86_64-1.15.0.zip"));
100 | //macosx
101 | versions.add(version("macosx", "1.2.0", "CPU", "https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow_jni-cpu-darwin-x86_64-1.2.0.tar.gz"));
102 | versions.add(version("macosx", "1.3.0", "CPU", "https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow_jni-cpu-darwin-x86_64-1.3.0.tar.gz"));
103 | versions.add(version("macosx", "1.4.1", "CPU", "https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow_jni-cpu-darwin-x86_64-1.4.1.tar.gz"));
104 | versions.add(version("macosx", "1.5.0", "CPU", "https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow_jni-cpu-darwin-x86_64-1.5.0.tar.gz"));
105 | versions.add(version("macosx", "1.6.0", "CPU", "https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow_jni-cpu-darwin-x86_64-1.6.0.tar.gz"));
106 | versions.add(version("macosx", "1.7.0", "CPU", "https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow_jni-cpu-darwin-x86_64-1.7.0.tar.gz"));
107 | versions.add(version("macosx", "1.8.0", "CPU", "https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow_jni-cpu-darwin-x86_64-1.8.0.tar.gz"));
108 | versions.add(version("macosx", "1.9.0", "CPU", "https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow_jni-cpu-darwin-x86_64-1.9.0.tar.gz"));
109 | versions.add(version("macosx", "1.10.1", "CPU", "https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow_jni-cpu-darwin-x86_64-1.10.1.tar.gz"));
110 | versions.add(version("macosx", "1.11.0", "CPU", "https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow_jni-cpu-darwin-x86_64-1.11.0.tar.gz"));
111 | versions.add(version("macosx", "1.12.0", "CPU", "https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow_jni-cpu-darwin-x86_64-1.12.0.tar.gz"));
112 | versions.add(version("macosx", "1.13.1", "CPU", "https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow_jni-cpu-darwin-x86_64-1.13.1.tar.gz"));
113 | versions.add(version("macosx", "1.14.0", "CPU", "https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow_jni-cpu-darwin-x86_64-1.14.0.tar.gz"));
114 | versions.add(version("macosx", "1.15.0", "CPU", "https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow_jni-cpu-darwin-x86_64-1.15.0.tar.gz"));
115 | return versions;
116 | }
117 |
118 | private static DownloadableTensorFlowVersion version(String platform, String tensorFlowVersion, String mode, String url) {
119 | DownloadableTensorFlowVersion version = new DownloadableTensorFlowVersion(tensorFlowVersion, mode.equals("GPU"));
120 | try {
121 | version.setURL(new URL(url));
122 | } catch (MalformedURLException e) {
123 | e.printStackTrace();
124 | }
125 | version.setPlatform(platform);
126 | return version;
127 | }
128 |
129 | private static DownloadableTensorFlowVersion version(String platform, String tensorFlowVersion, String mode, String cuda, String cudnn, String url) {
130 | try {
131 | return new DownloadableTensorFlowVersion(new URL(url), tensorFlowVersion, platform, mode.equals("GPU"), cuda, cudnn);
132 | } catch (MalformedURLException e) {
133 | e.printStackTrace();
134 | return null;
135 | }
136 | }
137 | }
138 |
--------------------------------------------------------------------------------
/src/main/java/net/imagej/tensorflow/ui/DownloadableTensorFlowVersion.java:
--------------------------------------------------------------------------------
1 | /*-
2 | * #%L
3 | * ImageJ/TensorFlow integration.
4 | * %%
5 | * Copyright (C) 2017 - 2025 Board of Regents of the University of
6 | * Wisconsin-Madison and Google, Inc.
7 | * %%
8 | * Redistribution and use in source and binary forms, with or without
9 | * modification, are permitted provided that the following conditions are met:
10 | *
11 | * 1. Redistributions of source code must retain the above copyright notice,
12 | * this list of conditions and the following disclaimer.
13 | * 2. Redistributions in binary form must reproduce the above copyright notice,
14 | * this list of conditions and the following disclaimer in the documentation
15 | * and/or other materials provided with the distribution.
16 | *
17 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
18 | * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
19 | * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
20 | * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDERS OR CONTRIBUTORS BE
21 | * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
22 | * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
23 | * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
24 | * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
25 | * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
26 | * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
27 | * POSSIBILITY OF SUCH DAMAGE.
28 | * #L%
29 | */
30 |
31 | package net.imagej.tensorflow.ui;
32 |
33 | import net.imagej.tensorflow.TensorFlowVersion;
34 |
35 | import java.net.URL;
36 | import java.util.Objects;
37 |
38 | /**
39 | * This class extends {@link TensorFlowVersion} with knowledge about the web source of this version, the matching platform and whether it has already been downloaded.
40 | *
41 | * @author Deborah Schmidt
42 | * @author Benjamin Wilhelm
43 | */
44 | class DownloadableTensorFlowVersion extends TensorFlowVersion {
45 | private URL url;
46 | private String platform;
47 | private String localPath;
48 |
49 | private boolean active = false;
50 | private boolean downloaded = false;
51 |
52 | /**
53 | * @param url the URL of the packed JNI file
54 | * @param version the TensorFlow version number
55 | * @param os the platform this version is associated with
56 | * @param supportsGPU whether this version runs on the GPU
57 | * @param compatibleCUDA the CUDA version compatible with this TensorFlow version
58 | * @param compatibleCuDNN the CuDNN version compatible with this TensorFlow version
59 | */
60 | DownloadableTensorFlowVersion(URL url, String version, String os, boolean supportsGPU, String compatibleCUDA, String compatibleCuDNN) {
61 | super(version, supportsGPU, compatibleCUDA, compatibleCuDNN);
62 | this.platform = os;
63 | this.url = url;
64 | }
65 |
66 | /**
67 | * @param version the TensorFlow version number
68 | * @param supportsGPU whether this version runs on the GPU
69 | */
70 | DownloadableTensorFlowVersion(String version, boolean supportsGPU) {
71 | super(version, supportsGPU, null, null);
72 | }
73 |
74 | DownloadableTensorFlowVersion(TensorFlowVersion other) {
75 | super(other);
76 | }
77 |
78 | /**
79 | * @return the URL where an archive of this version can be downloaded from
80 | */
81 | public URL getURL() {
82 | return url;
83 | }
84 |
85 | void setURL(URL url) {
86 | this.url = url;
87 | }
88 |
89 | /**
90 | * @return the platform this version is associated with (linux64, linux32, win64, win32, macosx)
91 | */
92 | public String getPlatform() {
93 | return platform;
94 | }
95 |
96 | void setPlatform(String platform) {
97 | this.platform = platform;
98 | }
99 |
100 | /**
101 | * @return the path this version is cached to
102 | */
103 | public String getLocalPath() {
104 | return localPath;
105 | }
106 |
107 | /**
108 | * @return whether this version is currently being used
109 | */
110 | public boolean isActive() {
111 | return active;
112 | }
113 |
114 | void setActive(boolean active) {
115 | this.active = active;
116 | }
117 |
118 | /**
119 | * @return whether this version is cashed to a local path
120 | */
121 | public boolean isCached() {
122 | return downloaded;
123 | }
124 |
125 | /**
126 | * @return String describing the origin of this version
127 | */
128 | public String getOriginDescription() {
129 | if(downloaded) return localPath;
130 | if(url == null) return "unknown source";
131 | return url.toString();
132 | }
133 |
134 | /**
135 | * @return the TensorFlow version number in an easily comparable format (e.g. 0.13.1 to 000.013.001)
136 | */
137 | String getComparableTFVersion() {
138 | String orderableVersion = "";
139 | String[] split = tfVersion.split("\\.");
140 | for(String part : split) {
141 | orderableVersion += String.format("%03d%n", Integer.valueOf(part));
142 | }
143 | String gpu = "";
144 | if(supportsGPU.isPresent()) gpu += "_" + (supportsGPU.get() ? "cpu" : "gpu");
145 | return orderableVersion + gpu;
146 | }
147 |
148 | void setCached(String localPath) {
149 | downloaded = true;
150 | this.localPath = localPath;
151 | }
152 |
153 | void discardCache() {
154 | downloaded = false;
155 | }
156 |
157 | void harvest(DownloadableTensorFlowVersion other) {
158 | if(!cuda.isPresent() && other.cuda.isPresent()) {
159 | cuda = other.cuda;
160 | }
161 | if(!cudnn.isPresent() && other.cudnn.isPresent()) {
162 | cudnn = other.cudnn;
163 | }
164 | if(url == null) url = other.url;
165 | if(localPath == null) localPath = other.localPath;
166 | downloaded = other.downloaded;
167 | }
168 |
169 | @Override
170 | public boolean equals(final Object obj) {
171 | if (!(obj.getClass().equals(this.getClass()))) return false;
172 | final DownloadableTensorFlowVersion o = (DownloadableTensorFlowVersion) obj;
173 | return super.equals(o) && Objects.equals(platform, o.platform);
174 | }
175 | }
176 |
--------------------------------------------------------------------------------
/src/main/java/net/imagej/tensorflow/ui/TensorFlowInstallationHandler.java:
--------------------------------------------------------------------------------
1 | /*-
2 | * #%L
3 | * ImageJ/TensorFlow integration.
4 | * %%
5 | * Copyright (C) 2017 - 2025 Board of Regents of the University of
6 | * Wisconsin-Madison and Google, Inc.
7 | * %%
8 | * Redistribution and use in source and binary forms, with or without
9 | * modification, are permitted provided that the following conditions are met:
10 | *
11 | * 1. Redistributions of source code must retain the above copyright notice,
12 | * this list of conditions and the following disclaimer.
13 | * 2. Redistributions in binary form must reproduce the above copyright notice,
14 | * this list of conditions and the following disclaimer in the documentation
15 | * and/or other materials provided with the distribution.
16 | *
17 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
18 | * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
19 | * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
20 | * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDERS OR CONTRIBUTORS BE
21 | * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
22 | * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
23 | * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
24 | * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
25 | * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
26 | * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
27 | * POSSIBILITY OF SUCH DAMAGE.
28 | * #L%
29 | */
30 |
31 | package net.imagej.tensorflow.ui;
32 |
33 | import net.imagej.tensorflow.util.TensorFlowUtil;
34 | import net.imagej.tensorflow.util.UnpackUtil;
35 | import org.scijava.app.AppService;
36 | import org.scijava.app.StatusService;
37 | import org.scijava.log.LogService;
38 | import org.scijava.plugin.Parameter;
39 | import org.scijava.plugin.Plugin;
40 | import org.scijava.plugin.SciJavaPlugin;
41 |
42 | import java.io.File;
43 | import java.io.IOException;
44 | import java.io.InputStream;
45 | import java.net.URL;
46 | import java.nio.file.Files;
47 | import java.nio.file.Paths;
48 | import java.nio.file.StandardCopyOption;
49 |
50 | /**
51 | * This class handles instances of {@link DownloadableTensorFlowVersion}.
52 | * It is responsible for downloading, caching, unpacking, and installing them.
53 | *
54 | * @author Deborah Schmidt
55 | * @author Benjamin Wilhelm
56 | */
57 | class TensorFlowInstallationHandler {
58 |
59 | @Parameter
60 | private AppService appService;
61 |
62 | @Parameter
63 | private LogService logService;
64 |
65 | @Parameter
66 | private StatusService statusService;
67 |
68 | private static final String DOWNLOADDIR = "downloads/";
69 |
70 | /**
71 | * Checks for a specific version whether it is downloaded and installed.
72 | * @param version the version which will be checked
73 | */
74 | void updateCacheStatus(final DownloadableTensorFlowVersion version) {
75 | if(version.getURL() == null) return;
76 | if (version.getURL().getProtocol().equalsIgnoreCase("file")) {
77 | version.setCached(version.getURL().getFile());
78 | } else {
79 | String path = getDownloadDir() + getNameFromURL(version.getURL());
80 | if (new File(path).exists()) {
81 | version.setCached(path);
82 | }
83 | }
84 | }
85 |
86 | private String getNameFromURL(URL url) {
87 | String[] parts = url.getPath().split("/");
88 | if(parts.length > 0) return parts[parts.length - 1];
89 | else return null;
90 | }
91 |
92 | /**
93 | * Activates a specific TensorFlow library which will then be used after restarting the application.
94 | * @param version the version being activated
95 | * @throws IOException thrown in case downloading the version failed or the version could not be unpacked
96 | */
97 | void activateVersion(DownloadableTensorFlowVersion version) throws IOException {
98 | if (!version.isCached()) {
99 | downloadVersion(version.getURL());
100 | }
101 | updateCacheStatus(version);
102 | if (!version.isActive()) {
103 | installVersion(version);
104 | }
105 | }
106 |
107 | private void downloadVersion(URL url) throws IOException {
108 | createDownloadDir();
109 | String filename = url.getFile().substring(url.getFile().lastIndexOf("/") + 1);
110 | String localFile = getDownloadDir() + filename;
111 | logService.info("Downloading " + url + " to " + localFile);
112 | InputStream in = url.openStream();
113 | Files.copy(in, Paths.get(localFile), StandardCopyOption.REPLACE_EXISTING);
114 | }
115 |
116 | private void createDownloadDir() {
117 | File downloadDir = new File(getDownloadDir());
118 | if (!downloadDir.exists()) {
119 | downloadDir.mkdirs();
120 | }
121 | }
122 |
123 | private void installVersion(DownloadableTensorFlowVersion version) throws IOException {
124 |
125 | logService.info("Installing " + version);
126 |
127 | File outputDir = new File(TensorFlowUtil.getUpdateLibDir(getRoot()) + version.getPlatform() + File.separator);
128 |
129 | if (version.getLocalPath().contains(".zip")) {
130 | UnpackUtil.unZip(version.getLocalPath(), outputDir, logService, statusService);
131 | TensorFlowUtil.writeNativeVersionFile(getRoot(), version.getPlatform(), version);
132 | } else if (version.getLocalPath().endsWith(".tar.gz")) {
133 | String symLinkOutputDir = TensorFlowUtil.getLibDir(getRoot()) + version.getPlatform() + File.separator;
134 | UnpackUtil.unGZip(version.getLocalPath(), outputDir, symLinkOutputDir, logService, statusService);
135 | TensorFlowUtil.writeNativeVersionFile(getRoot(), version.getPlatform(), version);
136 | }
137 |
138 | if (version.getLocalPath().endsWith(".jar")) {
139 | // using default JAR version.
140 | TensorFlowUtil.removeNativeLibraries(getRoot(), logService);
141 | logService.info("Using default JAR TensorFlow version.");
142 | }
143 |
144 | statusService.clearStatus();
145 |
146 | TensorFlowUtil.getCrashFile(getRoot()).delete();
147 | }
148 |
149 | private String getRoot() {
150 | return appService.getApp().getBaseDirectory().getAbsolutePath();
151 | }
152 |
153 | private String getDownloadDir() {
154 | return appService.getApp().getBaseDirectory().getAbsolutePath() + File.separator + DOWNLOADDIR;
155 | }
156 |
157 | }
158 |
--------------------------------------------------------------------------------
/src/main/java/net/imagej/tensorflow/ui/TensorFlowLibraryManagementCommand.java:
--------------------------------------------------------------------------------
1 | /*-
2 | * #%L
3 | * ImageJ/TensorFlow integration.
4 | * %%
5 | * Copyright (C) 2017 - 2025 Board of Regents of the University of
6 | * Wisconsin-Madison and Google, Inc.
7 | * %%
8 | * Redistribution and use in source and binary forms, with or without
9 | * modification, are permitted provided that the following conditions are met:
10 | *
11 | * 1. Redistributions of source code must retain the above copyright notice,
12 | * this list of conditions and the following disclaimer.
13 | * 2. Redistributions in binary form must reproduce the above copyright notice,
14 | * this list of conditions and the following disclaimer in the documentation
15 | * and/or other materials provided with the distribution.
16 | *
17 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
18 | * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
19 | * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
20 | * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDERS OR CONTRIBUTORS BE
21 | * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
22 | * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
23 | * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
24 | * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
25 | * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
26 | * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
27 | * POSSIBILITY OF SUCH DAMAGE.
28 | * #L%
29 | */
30 |
31 | package net.imagej.tensorflow.ui;
32 |
33 | import net.imagej.ImageJ;
34 | import net.imagej.tensorflow.TensorFlowService;
35 | import net.imagej.tensorflow.TensorFlowVersion;
36 | import net.imagej.tensorflow.util.TensorFlowUtil;
37 | import net.imagej.updater.util.Platforms;
38 | import org.scijava.Context;
39 | import org.scijava.app.AppService;
40 | import org.scijava.command.Command;
41 | import org.scijava.log.LogService;
42 | import org.scijava.plugin.Parameter;
43 | import org.scijava.plugin.Plugin;
44 |
45 | import java.awt.*;
46 | import java.util.ArrayList;
47 | import java.util.List;
48 |
49 | /**
50 |  * This command provides an interface for activating a specific Java TensorFlow version.
51 |  * The user can choose between the version included via maven and archived native library versions
52 |  * which can be downloaded from the Google servers.
53 |  *
54 |  * @author Deborah Schmidt
55 |  * @author Benjamin Wilhelm
56 |  */
57 | @Plugin(type = Command.class, menuPath = "Edit>Options>TensorFlow...")
58 | public class TensorFlowLibraryManagementCommand implements Command {
59 |
60 |     @Parameter
61 |     private TensorFlowService tensorFlowService;
62 |
63 |     @Parameter
64 |     private AppService appService;
65 |
66 |     @Parameter
67 |     private LogService logService;
68 |
69 |     @Parameter
70 |     private Context context;
71 |
72 |     /** The version reported by the TensorFlow service; stays {@code null} if it reports none. */
73 |     DownloadableTensorFlowVersion currentVersion;
74 |     /** The versions offered to the user; only versions for the current platform are added. */
75 |     List<DownloadableTensorFlowVersion> availableVersions = new ArrayList<>();
76 |
77 |     String platform = Platforms.current();
78 |
79 |     private TensorFlowInstallationHandler installationHandler;
80 |
81 |     /**
82 |      * Loads the TensorFlow library, collects the available versions and shows
83 |      * them in a {@link TensorFlowLibraryManagementFrame}.
84 |      */
85 |     @Override
86 |     public void run() {
87 |         tensorFlowService.loadLibrary();
88 |         installationHandler = new TensorFlowInstallationHandler();
89 |         // the handler declares @Parameter fields which have to be injected
90 |         context.inject(installationHandler);
91 |         final TensorFlowLibraryManagementFrame frame = new TensorFlowLibraryManagementFrame(tensorFlowService, installationHandler);
92 |         frame.init();
93 |         initAvailableVersions();
94 |         frame.updateChoices(availableVersions);
95 |         frame.pack();
96 |         frame.setLocationRelativeTo(null);
97 |         frame.setMinimumSize(new Dimension(0, 200));
98 |         frame.setVisible(true);
99 |
100 |     }
101 |
102 |     /**
103 |      * Fills {@link #availableVersions} with the current version (if any), the version
104 |      * of the JAR on the class path and all versions from {@link AvailableTensorFlowVersions}.
105 |      */
106 |     private void initAvailableVersions() {
107 |         initCurrentVersion();
108 |         if (currentVersion != null) addAvailableVersion(currentVersion);
109 |         addAvailableVersion(getTensorFlowJARVersion());
110 |         AvailableTensorFlowVersions.get().forEach(this::addAvailableVersion);
111 |     }
112 |
113 |     /**
114 |      * Wraps the version reported by the TensorFlow service and marks it as active.
115 |      * Does nothing if the service reports no version.
116 |      */
117 |     private void initCurrentVersion() {
118 |         final TensorFlowVersion current = tensorFlowService.getTensorFlowVersion();
119 |         if (current == null) return;
120 |         currentVersion = new DownloadableTensorFlowVersion(current);
121 |         currentVersion.setPlatform(platform);
122 |         currentVersion.setActive(true);
123 |     }
124 |
125 |     /**
126 |      * @return the version of the TensorFlow JAR on the class path, with the JAR as its URL
127 |      */
128 |     private DownloadableTensorFlowVersion getTensorFlowJARVersion() {
129 |         final DownloadableTensorFlowVersion version = new DownloadableTensorFlowVersion(TensorFlowUtil.versionFromClassPathJAR());
130 |         version.setPlatform(platform);
131 |         version.setURL(TensorFlowUtil.getTensorFlowJAR());
132 |         return version;
133 |     }
134 |
135 |     /**
136 |      * Adds a version to {@link #availableVersions} unless it belongs to another platform.
137 |      * If an equal version is already listed, the listed one harvests the new one instead
138 |      * (presumably merging its properties; see {@code DownloadableTensorFlowVersion#harvest}).
139 |      *
140 |      * @param version the version to add
141 |      */
142 |     private void addAvailableVersion(DownloadableTensorFlowVersion version) {
143 |         // compare from the known non-null side: a version without platform is skipped, not an NPE
144 |         if (!platform.equals(version.getPlatform())) return;
145 |         installationHandler.updateCacheStatus(version);
146 |         for (DownloadableTensorFlowVersion other : availableVersions) {
147 |             if (other.equals(version)) {
148 |                 other.harvest(version);
149 |                 return;
150 |             }
151 |         }
152 |         availableVersions.add(version);
153 |     }
154 |
155 |     /**
156 |      * Starts ImageJ, shows its UI and runs this command.
157 |      *
158 |      * @param args ignored
159 |      */
160 |     public static void main(String... args) {
161 |         ImageJ ij = new ImageJ();
162 |         ij.ui().showUI();
163 |         ij.command().run(TensorFlowLibraryManagementCommand.class, true);
164 |     }
165 | }
136 |
--------------------------------------------------------------------------------
/src/main/java/net/imagej/tensorflow/ui/TensorFlowLibraryManagementFrame.java:
--------------------------------------------------------------------------------
1 | /*-
2 | * #%L
3 | * ImageJ/TensorFlow integration.
4 | * %%
5 | * Copyright (C) 2017 - 2025 Board of Regents of the University of
6 | * Wisconsin-Madison and Google, Inc.
7 | * %%
8 | * Redistribution and use in source and binary forms, with or without
9 | * modification, are permitted provided that the following conditions are met:
10 | *
11 | * 1. Redistributions of source code must retain the above copyright notice,
12 | * this list of conditions and the following disclaimer.
13 | * 2. Redistributions in binary form must reproduce the above copyright notice,
14 | * this list of conditions and the following disclaimer in the documentation
15 | * and/or other materials provided with the distribution.
16 | *
17 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
18 | * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
19 | * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
20 | * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDERS OR CONTRIBUTORS BE
21 | * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
22 | * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
23 | * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
24 | * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
25 | * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
26 | * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
27 | * POSSIBILITY OF SUCH DAMAGE.
28 | * #L%
29 | */
30 |
31 | package net.imagej.tensorflow.ui;
32 |
33 | import net.imagej.tensorflow.TensorFlowService;
34 | import net.miginfocom.swing.MigLayout;
35 |
36 | import javax.swing.*;
37 | import java.awt.*;
38 | import java.io.IOException;
39 | import java.util.*;
40 | import java.util.List;
41 |
42 | /**
43 |  * The UI part of the {@link TensorFlowLibraryManagementCommand}.
44 |  * <p>
45 |  * Shows one radio button per available TensorFlow version, filters for mode,
46 |  * CUDA and TensorFlow version, and a status text. Selecting a version which is
47 |  * not active yet installs it on a separate thread.
48 |  *
49 |  * @author Deborah Schmidt
50 |  * @author Benjamin Wilhelm
51 |  */
52 | class TensorFlowLibraryManagementFrame extends JFrame {
53 |
54 |     private static final long serialVersionUID = 1L;
55 |
56 |     private static final Color LIST_BACKGROUND_COLOR = new Color(250,250,250);
57 |     /** Filter entry which disables filtering by the respective property. */
58 |     private static final String NOFILTER = "-";
59 |
60 |     private TensorFlowService tensorFlowService;
61 |     private TensorFlowInstallationHandler installationHandler;
62 |
63 |     private JComboBox<String> gpuChoiceBox;
64 |     private JComboBox<String> cudaChoiceBox;
65 |     private JComboBox<String> tfChoiceBox;
66 |     private JPanel installPanel;
67 |     private JTextArea status;
68 |
69 |     private List<DownloadableTensorFlowVersion> availableVersions = new ArrayList<>();
70 |     /** One radio button per available version; the filters choose which of them are shown. */
71 |     private List<JRadioButton> buttons = new ArrayList<>();
72 |
73 |     TensorFlowLibraryManagementFrame(final TensorFlowService tensorFlowInstallationService, final TensorFlowInstallationHandler installationHandler) {
74 |         super("TensorFlow library version management");
75 |         this.tensorFlowService = tensorFlowInstallationService;
76 |         this.installationHandler = installationHandler;
77 |     }
78 |
79 |     /**
80 |      * Creates the components of this frame. Has to be called before
81 |      * {@link #updateChoices(List)}, which accesses these components.
82 |      */
83 |     public void init() {
84 |         JPanel panel = new JPanel();
85 |         panel.setLayout(new MigLayout("height 400, wmax 600"));
86 |         panel.add(new JLabel("Please select the TensorFlow version you would like to install."), "wrap");
87 |         panel.add(createFilterPanel(), "wrap, span, align right");
88 |         panel.add(createInstallPanel(), "wrap, span, grow");
89 |         panel.add(createStatus(), "span, grow");
90 |         setContentPane(panel);
91 |     }
92 |
93 |     /**
94 |      * @return the panel with the three filter boxes; each box re-applies the filters on change
95 |      */
96 |     private Component createFilterPanel() {
97 |         gpuChoiceBox = new JComboBox<>();
98 |         gpuChoiceBox.addItem(NOFILTER);
99 |         gpuChoiceBox.addItem("GPU");
100 |         gpuChoiceBox.addItem("CPU");
101 |         gpuChoiceBox.addActionListener((e) -> updateFilter());
102 |         cudaChoiceBox = new JComboBox<>();
103 |         cudaChoiceBox.addActionListener((e) -> updateFilter());
104 |         tfChoiceBox = new JComboBox<>();
105 |         tfChoiceBox.addActionListener((e) -> updateFilter());
106 |         JPanel panel = new JPanel();
107 |         panel.setLayout(new MigLayout());
108 |         panel.add(makeLabel("Filter by.."));
109 |         panel.add(makeLabel("Mode: "));
110 |         panel.add(gpuChoiceBox);
111 |         panel.add(makeLabel("CUDA: "));
112 |         panel.add(cudaChoiceBox);
113 |         panel.add(makeLabel("TensorFlow: "));
114 |         panel.add(tfChoiceBox);
115 |         return panel;
116 |     }
117 |
118 |     /**
119 |      * @return the read-only, transparent text area used for status messages
120 |      */
121 |     private Component createStatus() {
122 |         status = new JTextArea();
123 |         status.setBorder(BorderFactory.createEmptyBorder());
124 |         status.setEditable(false);
125 |         status.setWrapStyleWord(true);
126 |         status.setLineWrap(true);
127 |         status.setOpaque(false);
128 |         return status;
129 |     }
130 |
131 |     /**
132 |      * Rebuilds the list of visible radio buttons: a button is shown if its text
133 |      * matches the current selection of all three filter boxes.
134 |      */
135 |     private void updateFilter() {
136 |         // the filter boxes may fire before the install panel has been created
137 |         if (installPanel == null) return;
138 |         installPanel.removeAll();
139 |         for (JRadioButton btn : buttons) {
140 |             if (filter("", gpuChoiceBox, btn)) continue;
141 |             if (filter("CUDA ", cudaChoiceBox, btn)) continue;
142 |             if (filter("TF ", tfChoiceBox, btn)) continue;
143 |             installPanel.add(btn);
144 |         }
145 |         installPanel.revalidate();
146 |         installPanel.repaint();
147 |     }
148 |
149 |     /**
150 |      * @param title the prefix the selected item is preceded by in the button text
151 |      * @param choiceBox the filter box to evaluate
152 |      * @param btn the button to test
153 |      * @return {@code true} if the button has to be hidden, i.e. a filter is selected and
154 |      *         the button text does not contain {@code title} followed by the selected item
155 |      */
156 |     private boolean filter(String title, JComboBox<String> choiceBox, JRadioButton btn) {
157 |         final Object selected = choiceBox.getSelectedItem();
158 |         // no selection (e.g. while the items of the box are being replaced) means no filter
159 |         if (selected == null) return false;
160 |         final String choice = selected.toString();
161 |         if (choice.equals(NOFILTER)) return false;
162 |         return !btn.getText().contains(title + choice);
163 |     }
164 |
165 |     private JLabel makeLabel(String s) {
166 |         JLabel label = new JLabel(s);
167 |         label.setHorizontalAlignment(SwingConstants.RIGHT);
168 |         label.setHorizontalTextPosition(SwingConstants.RIGHT);
169 |         return label;
170 |     }
171 |
172 |     /**
173 |      * @return the scrollable panel holding the version buttons
174 |      */
175 |     private Component createInstallPanel() {
176 |         installPanel = new JPanel(new MigLayout("flowy"));
177 |         JScrollPane scroll = new JScrollPane(installPanel);
178 |         scroll.setBorder(BorderFactory.createEmptyBorder());
179 |         installPanel.setBackground(LIST_BACKGROUND_COLOR);
180 |         return scroll;
181 |     }
182 |
183 |     /**
184 |      * Replaces the displayed versions. The given list is sorted in place by
185 |      * TensorFlow version (descending) and kept by this frame.
186 |      *
187 |      * @param availableVersions the versions to offer
188 |      */
189 |     void updateChoices(List<DownloadableTensorFlowVersion> availableVersions) {
190 |         availableVersions.sort(Comparator.comparing(DownloadableTensorFlowVersion::getComparableTFVersion).reversed());
191 |         this.availableVersions = availableVersions;
192 |         // drop the buttons of a previous call, otherwise every version shows up repeatedly
193 |         buttons.clear();
194 |         updateCUDAChoices();
195 |         updateTFChoices();
196 |         ButtonGroup versionGroup = new ButtonGroup();
197 |         installPanel.removeAll();
198 |         for (DownloadableTensorFlowVersion version : availableVersions) {
199 |             JRadioButton btn = new JRadioButton(version.toString());
200 |             btn.setToolTipText(version.getOriginDescription());
201 |             if (version.isActive()) {
202 |                 btn.setSelected(true);
203 |                 if (tensorFlowService.getStatus().isFailed()) {
204 |                     btn.setForeground(Color.red);
205 |                 }
206 |             }
207 |             btn.setOpaque(false);
208 |             versionGroup.add(btn);
209 |             buttons.add(btn);
210 |             btn.addActionListener(e -> {
211 |                 if (btn.isSelected()) {
212 |                     // installing may download and unpack an archive, keep it off the calling thread
213 |                     new Thread(() -> activateVersion(version)).start();
214 |                 }
215 |             });
216 |         }
217 |         updateFilter();
218 |         updateStatus();
219 |     }
220 |
221 |     /**
222 |      * Shows the status info of the TensorFlow service, in red if the library is not loaded.
223 |      */
224 |     private void updateStatus() {
225 |         status.setText(tensorFlowService.getStatus().getInfo());
226 |         if (!tensorFlowService.getStatus().isLoaded()) {
227 |             status.setForeground(Color.red);
228 |         } else {
229 |             status.setForeground(Color.black);
230 |         }
231 |     }
232 |
233 |     /**
234 |      * Installs a version. On failure the user is asked whether the archive should be
235 |      * downloaded again; the success dialog is only shown (and the frame only closed)
236 |      * if the installation went through.
237 |      *
238 |      * @param version the version to activate
239 |      */
240 |     private void activateVersion(DownloadableTensorFlowVersion version) {
241 |         if (version.isActive()) {
242 |             System.out.println("[WARNING] Cannot activate version, already active: " + version);
243 |             return;
244 |         }
245 |         showWaitMessage();
246 |         try {
247 |             installationHandler.activateVersion(version);
248 |         } catch (IOException e) {
249 |             e.printStackTrace();
250 |             Object[] options = {"Yes",
251 |                     "No",
252 |                     "Cancel"};
253 |             int choice = JOptionPane.showOptionDialog(this,
254 |                     "Error while unpacking library file " + version.getLocalPath() + ".\nShould it be downloaded again?",
255 |                     "Unpacking library error",
256 |                     JOptionPane.YES_NO_CANCEL_OPTION,
257 |                     JOptionPane.QUESTION_MESSAGE,
258 |                     null,
259 |                     options,
260 |                     options[0]);
261 |             // 0 is the index of "Yes" in options
262 |             if (choice != 0) {
263 |                 showFailure(e);
264 |                 return;
265 |             }
266 |             version.discardCache();
267 |             try {
268 |                 installationHandler.activateVersion(version);
269 |             } catch (IOException e1) {
270 |                 e1.printStackTrace();
271 |                 // the retry failed as well: report it instead of announcing a successful installation
272 |                 showFailure(e1);
273 |                 return;
274 |             }
275 |         }
276 |         status.setText("");
277 |         JOptionPane.showMessageDialog(null,
278 |                 "Installed selected TensorFlow version. Please restart Fiji to load it.",
279 |                 "Please restart",
280 |                 JOptionPane.PLAIN_MESSAGE);
281 |         dispose();
282 |     }
283 |
284 |     /**
285 |      * Shows the type and message of a failed installation in the status text.
286 |      */
287 |     private void showFailure(IOException e) {
288 |         status.setText("Installation failed: " + e.getClass().getName() + ": " + e.getMessage());
289 |     }
290 |
291 |     private void showWaitMessage() {
292 |         status.setForeground(Color.black);
293 |         status.setText("Please wait..");
294 |     }
295 |
296 |     /**
297 |      * Fills the CUDA filter box with the CUDA versions of the available versions,
298 |      * in order of first occurrence, preceded by the no-filter entry.
299 |      */
300 |     private void updateCUDAChoices() {
301 |         Set<String> choices = new LinkedHashSet<>();
302 |         for (DownloadableTensorFlowVersion version : availableVersions) {
303 |             if (version.getCompatibleCUDA().isPresent())
304 |                 choices.add(version.getCompatibleCUDA().get());
305 |         }
306 |         cudaChoiceBox.removeAllItems();
307 |         cudaChoiceBox.addItem(NOFILTER);
308 |         for (String choice : choices) {
309 |             cudaChoiceBox.addItem(choice);
310 |         }
311 |     }
312 |
313 |     /**
314 |      * Fills the TensorFlow filter box with the version numbers of the available versions,
315 |      * in order of first occurrence, preceded by the no-filter entry.
316 |      */
317 |     private void updateTFChoices() {
318 |         Set<String> choices = new LinkedHashSet<>();
319 |         for (DownloadableTensorFlowVersion version : availableVersions) {
320 |             if (version.getVersionNumber() != null) choices.add(version.getVersionNumber());
321 |         }
322 |         tfChoiceBox.removeAllItems();
323 |         tfChoiceBox.addItem(NOFILTER);
324 |         for (String choice : choices) {
325 |             tfChoiceBox.addItem(choice);
326 |         }
327 |     }
328 |
329 | }
257 |
--------------------------------------------------------------------------------
/src/main/java/net/imagej/tensorflow/util/TensorFlowUtil.java:
--------------------------------------------------------------------------------
1 | /*-
2 | * #%L
3 | * ImageJ/TensorFlow integration.
4 | * %%
5 | * Copyright (C) 2017 - 2025 Board of Regents of the University of
6 | * Wisconsin-Madison and Google, Inc.
7 | * %%
8 | * Redistribution and use in source and binary forms, with or without
9 | * modification, are permitted provided that the following conditions are met:
10 | *
11 | * 1. Redistributions of source code must retain the above copyright notice,
12 | * this list of conditions and the following disclaimer.
13 | * 2. Redistributions in binary form must reproduce the above copyright notice,
14 | * this list of conditions and the following disclaimer in the documentation
15 | * and/or other materials provided with the distribution.
16 | *
17 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
18 | * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
19 | * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
20 | * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDERS OR CONTRIBUTORS BE
21 | * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
22 | * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
23 | * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
24 | * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
25 | * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
26 | * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
27 | * POSSIBILITY OF SUCH DAMAGE.
28 | * #L%
29 | */
30 |
31 | package net.imagej.tensorflow.util;
32 |
33 | import net.imagej.tensorflow.TensorFlowVersion;
34 | import net.imagej.updater.util.Platforms;
35 | import org.scijava.log.Logger;
36 | import org.tensorflow.TensorFlow;
37 |
38 | import java.io.BufferedWriter;
39 | import java.io.File;
40 | import java.io.FileWriter;
41 | import java.io.IOException;
42 | import java.net.JarURLConnection;
43 | import java.net.URL;
44 | import java.net.URLClassLoader;
45 | import java.nio.file.Files;
46 | import java.nio.file.Path;
47 | import java.util.regex.Matcher;
48 | import java.util.regex.Pattern;
49 |
50 | /**
51 |  * Utility methods for dealing with the Java TensorFlow library
52 |  *
53 |  * @author Deborah Schmidt
54 |  * @author Benjamin Wilhelm
55 |  */
56 | public final class TensorFlowUtil {
57 |
58 |     private static final String TFVERSIONFILE = ".tensorflowversion";
59 |     private static final String CRASHFILE = ".crashed";
60 |     private static final String LIBDIR = "lib";
61 |     private static final String UPDATEDIR = "update";
62 |
63 |     /** Placeholder stored in the version file for a property which is not known. */
64 |     private static final String UNKNOWN = "?";
65 |
66 |     private static final String PLATFORM = Platforms.current();
67 |
68 |     /** Matches the text between "libtensorflow-" and the JAR extension of a JAR path. */
69 |     private static Pattern jarVersionPattern = Pattern.compile("(?<=(libtensorflow-)).*(?=(.jar))");
70 |
71 |     private TensorFlowUtil(){}
72 |
73 |     /**
74 |      * @param jar the JAR file of the TensorFlow library
75 |      * @return the version which will be loaded by the JAR file,
76 |      *         or {@code null} if no version can be derived from the JAR path
77 |      */
78 |     public static TensorFlowVersion getTensorFlowJARVersion(URL jar) {
79 |         Matcher matcher = jarVersionPattern.matcher(jar.getPath());
80 |         if (!matcher.find()) return null;
81 |         // guess GPU support by looking for tensorflow_jni_gpu in the class path
82 |         final boolean supportsGPU = classPathContains("libtensorflow_jni_gpu");
83 |         return new TensorFlowVersion(matcher.group(), supportsGPU, null, null);
84 |     }
85 |
86 |     /**
87 |      * @param marker the text to look for
88 |      * @return whether an element of the class path of the system class loader contains the marker
89 |      */
90 |     private static boolean classPathContains(String marker) {
91 |         final ClassLoader cl = ClassLoader.getSystemClassLoader();
92 |         if (cl instanceof URLClassLoader) {
93 |             for (URL url : ((URLClassLoader) cl).getURLs()) {
94 |                 if (url.getFile().contains(marker)) return true;
95 |             }
96 |             return false;
97 |         }
98 |         // since Java 9 the system class loader is not a URLClassLoader anymore,
99 |         // fall back to the class path the JVM was started with
100 |         return System.getProperty("java.class.path", "").contains(marker);
101 |     }
102 |
103 |     /**
104 |      * @return The TensorFlow version included in the class path which will be loaded by default if no other native library is loaded beforehand.
105 |      */
106 |     public static TensorFlowVersion versionFromClassPathJAR() {
107 |         return getTensorFlowJARVersion(getTensorFlowJAR());
108 |     }
109 |
110 |     /**
111 |      * @return The JAR file URL shipping TensorFlow in Java
112 |      * @throws IllegalStateException if the TensorFlow classes cannot be located or are not loaded from a JAR file
113 |      */
114 |     public static URL getTensorFlowJAR() {
115 |         final URL resource = TensorFlow.class.getResource("TensorFlow.class");
116 |         if (resource == null) {
117 |             throw new IllegalStateException("Cannot locate TensorFlow.class");
118 |         }
119 |         try {
120 |             final java.net.URLConnection connection = resource.openConnection();
121 |             if (connection instanceof JarURLConnection) {
122 |                 return ((JarURLConnection) connection).getJarFileURL();
123 |             }
124 |         } catch (IOException e) {
125 |             // fail with the cause instead of continuing with a missing connection
126 |             throw new IllegalStateException("Cannot access the TensorFlow JAR via " + resource, e);
127 |         }
128 |         throw new IllegalStateException("TensorFlow is not loaded from a JAR file: " + resource);
129 |     }
130 |
131 |     /**
132 |      * Deletes all files in the platform folder of {@link #getLibDir(String)} whose name contains
133 |      * "tensorflow" (case insensitive). This includes the {@link #TFVERSIONFILE}.
134 |      * @param root the root path of ImageJ
135 |      * @param logger reports the deleted files and files which could not be deleted
136 |      */
137 |     public static void removeNativeLibraries(String root, Logger logger) {
138 |         final File folder = new File(TensorFlowUtil.getLibDir(root), PLATFORM);
139 |         // listFiles() yields null if the folder does not exist or cannot be listed
140 |         final File[] listOfFiles = folder.listFiles();
141 |         if (listOfFiles == null) {
142 |             return;
143 |         }
144 |         for (File file : listOfFiles) {
145 |             if (file.getName().toLowerCase().contains("tensorflow")) {
146 |                 logger.info("Deleting " + file);
147 |                 if (!file.delete()) {
148 |                     logger.warn("Could not delete " + file);
149 |                 }
150 |             }
151 |         }
152 |     }
153 |
154 |     /**
155 |      * Reading the content of the {@link #TFVERSIONFILE} indicating which native version of TensorFlow is installed.
156 |      * <p>
157 |      * The expected content is {@code platform,version,mode} or {@code platform,version,mode,cuda,cudnn},
158 |      * with {@code ?} standing for an unknown value. If the file does not exist or has
159 |      * another number of fields (four, or more than five), only the version number
160 |      * reported by TensorFlow itself is returned.
161 |      * @param root the root path of ImageJ
162 |      * @return the current TensorFlow version installed in ImageJ/lib
163 |      * @throws IOException in case the util file containing the native version number cannot be read (must not mean there is no usable native version)
164 |      */
165 |     public static TensorFlowVersion readNativeVersionFile(String root) throws IOException {
166 |         final File versionFile = getNativeVersionFile(root);
167 |         if (versionFile.exists()) {
168 |             final Path path = versionFile.toPath();
169 |             // trim: a trailing line break must not become part of the last field
170 |             final String versionstr = new String(Files.readAllBytes(path)).trim();
171 |             final String[] parts = versionstr.split(",");
172 |             if (parts.length < 3) {
173 |                 throw new IOException("Content of " + path + " does not match expected format");
174 |             }
175 |             final String version = parts[1];
176 |             final Boolean gpuSupport = parseMode(parts[2]);
177 |             if (parts.length == 3) {
178 |                 return new TensorFlowVersion(version, gpuSupport, null, null);
179 |             }
180 |             if (parts.length == 5) {
181 |                 return new TensorFlowVersion(version, gpuSupport, parseOptional(parts[3]), parseOptional(parts[4]));
182 |             }
183 |         }
184 |         // unknown version origin
185 |         return new TensorFlowVersion(TensorFlow.version(), null, null, null);
186 |     }
187 |
188 |     /**
189 |      * @param field the mode field of the version file
190 |      * @return {@code null} for the unknown placeholder, otherwise whether the field reads "gpu" (case insensitive)
191 |      */
192 |     private static Boolean parseMode(String field) {
193 |         if (UNKNOWN.equals(field)) return null;
194 |         return field.equalsIgnoreCase("gpu");
195 |     }
196 |
197 |     /**
198 |      * @param field a field of the version file
199 |      * @return {@code null} for the unknown placeholder, otherwise the field itself
200 |      */
201 |     private static String parseOptional(String field) {
202 |         return UNKNOWN.equals(field) ? null : field;
203 |     }
204 |
205 |     /**
206 |      * Writing the content of the {@link #TFVERSIONFILE} indicating which native version of TensorFlow is installed.
207 |      * <p>
208 |      * The fields are separated by commas; CUDA and cuDNN are written if at least one of them
209 |      * is known, an unknown value is written as {@code ?}. Failures to write are printed
210 |      * to the error stream and not propagated.
211 |      * @param root the root path of ImageJ
212 |      * @param platform the platform of the user (e.g. linux64, win64, macosx)
213 |      * @param version the installed TensorFlow version
214 |      */
215 |     public static void writeNativeVersionFile(String root, String platform, TensorFlowVersion version) {
216 |         // create content
217 |         StringBuilder content = new StringBuilder();
218 |         content.append(platform);
219 |         content.append(",");
220 |         content.append(version.getVersionNumber());
221 |         content.append(",");
222 |         if (version.usesGPU().isPresent()) {
223 |             content.append(version.usesGPU().get() ? "GPU" : "CPU");
224 |         } else {
225 |             content.append(UNKNOWN);
226 |         }
227 |         if (version.getCompatibleCuDNN().isPresent() || version.getCompatibleCUDA().isPresent()) {
228 |             // every field gets its own separator, also the placeholder of a missing value;
229 |             // otherwise the placeholder merges with the preceding field and the file becomes unreadable
230 |             content.append(",");
231 |             content.append(version.getCompatibleCUDA().isPresent() ? version.getCompatibleCUDA().get() : UNKNOWN);
232 |             content.append(",");
233 |             content.append(version.getCompatibleCuDNN().isPresent() ? version.getCompatibleCuDNN().get() : UNKNOWN);
234 |         }
235 |         //write content to file
236 |         final File versionFile = getNativeVersionFile(root);
237 |         final File folder = versionFile.getParentFile();
238 |         if (folder != null) {
239 |             // FileWriter does not create missing folders
240 |             folder.mkdirs();
241 |         }
242 |         try (BufferedWriter writer = new BufferedWriter(new FileWriter(versionFile))) {
243 |             writer.write(content.toString());
244 |         } catch (IOException e) {
245 |             e.printStackTrace();
246 |         }
247 |     }
248 |
249 |     /**
250 |      * @param root the root path of ImageJ
251 |      * @return the crash File location indication if loading TensorFlow resulted in a JVM crash
252 |      */
253 |     public static File getCrashFile(String root) {
254 |         return new File(getLibDir(root) + PLATFORM + File.separator + CRASHFILE);
255 |     }
256 |
257 |     /**
258 |      * @param root the root path of ImageJ
259 |      * @return the file indication which native TensorFlow version is currently installed
260 |      */
261 |     public static File getNativeVersionFile(String root) {
262 |         return new File(getLibDir(root) + PLATFORM + File.separator + TFVERSIONFILE);
263 |     }
264 |
265 |     /**
266 |      * @param root the root path of ImageJ
267 |      * @return the directory where native libraries can be placed so that they are loadable from Java
268 |      */
269 |     public static String getLibDir(String root) {
270 |         return root + File.separator + LIBDIR + File.separator;
271 |     }
272 |
273 |     /**
274 |      * @param root the root path of ImageJ
275 |      * @return the directory where native libraries can be placed so that the updater will move them into {@link #getLibDir(String)}
276 |      */
277 |     public static String getUpdateLibDir(String root) {
278 |         return root + File.separator + UPDATEDIR + File.separator + LIBDIR + File.separator;
279 |     }
280 | }
232 |
--------------------------------------------------------------------------------
/src/main/java/net/imagej/tensorflow/util/UnpackUtil.java:
--------------------------------------------------------------------------------
1 | /*-
2 | * #%L
3 | * ImageJ/TensorFlow integration.
4 | * %%
5 | * Copyright (C) 2017 - 2025 Board of Regents of the University of
6 | * Wisconsin-Madison and Google, Inc.
7 | * %%
8 | * Redistribution and use in source and binary forms, with or without
9 | * modification, are permitted provided that the following conditions are met:
10 | *
11 | * 1. Redistributions of source code must retain the above copyright notice,
12 | * this list of conditions and the following disclaimer.
13 | * 2. Redistributions in binary form must reproduce the above copyright notice,
14 | * this list of conditions and the following disclaimer in the documentation
15 | * and/or other materials provided with the distribution.
16 | *
17 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
18 | * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
19 | * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
20 | * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDERS OR CONTRIBUTORS BE
21 | * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
22 | * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
23 | * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
24 | * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
25 | * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
26 | * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
27 | * POSSIBILITY OF SUCH DAMAGE.
28 | * #L%
29 | */
30 |
31 | package net.imagej.tensorflow.util;
32 |
33 | import org.apache.commons.compress.archivers.tar.TarArchiveEntry;
34 | import org.apache.commons.compress.archivers.tar.TarArchiveInputStream;
35 | import org.scijava.app.StatusService;
36 | import org.scijava.log.LogService;
37 | import org.scijava.util.ByteArray;
38 |
39 | import java.io.*;
40 | import java.nio.file.Files;
41 | import java.nio.file.Path;
42 | import java.util.zip.GZIPInputStream;
43 | import java.util.zip.ZipEntry;
44 | import java.util.zip.ZipInputStream;
45 |
46 | /**
47 | * Utility class for unpacking archives
48 | * TODO: generalize and move it to a place where others can use it as well
49 | *
50 | * @author Deborah Schmidt
51 | */
52 | public final class UnpackUtil {
53 |
54 | private UnpackUtil(){} // private constructor: the class cannot be instantiated from outside
55 |
56 | public static void unGZip(String tarGzFile, File output, String symLinkOutputDir, LogService log, StatusService status) throws IOException {
57 | log("Unpacking " + tarGzFile + " to " + output, log , status);
58 | if (!output.exists()) {
59 | output.mkdirs();
60 | }
61 |
62 | String tarFileName = tarGzFile.replace(".gz", "");
63 | FileInputStream instream = new FileInputStream(tarGzFile);
64 | GZIPInputStream ginstream = new GZIPInputStream(instream);
65 | FileOutputStream outstream = new FileOutputStream(tarFileName);
66 | byte[] buf = new byte[1024];
67 | int len;
68 | while ((len = ginstream.read(buf)) > 0) {
69 | outstream.write(buf, 0, len);
70 |
71 | }
72 | ginstream.close();
73 | outstream.close();
74 | TarArchiveInputStream myTarFile = new TarArchiveInputStream(new FileInputStream(tarFileName));
75 | TarArchiveEntry entry;
76 | while ((entry = myTarFile.getNextTarEntry()) != null) {
77 | if (entry.isSymbolicLink() || entry.isLink()) {
78 | Path source = new File(output, entry.getName()).toPath();
79 | Path target = new File(symLinkOutputDir, entry.getLinkName()).toPath();
80 | deleteIfExists(source.toAbsolutePath().toString());
81 | if (entry.isSymbolicLink()) {
82 | log("Creating symbolic link: " + source + " -> " + target, log, status);
83 | Files.createSymbolicLink(source, target);
84 | } else {
85 | log("Creating link: " + source + " -> " + target, log, status);
86 | Files.createLink(source, target);
87 | }
88 | } else {
89 | File outEntry = new File(output + "/" + entry.getName());
90 | if (!outEntry.getParentFile().exists()) {
91 | outEntry.getParentFile().mkdirs();
92 | }
93 | if (outEntry.isDirectory()) {
94 | continue;
95 | }
96 | log("Writing " + outEntry, log, status);
97 | final byte[] buf1 = new byte[64 * 1024];
98 | final int size1 = (int) entry.getSize();
99 | int len1 = 0;
100 | try (final FileOutputStream outEntryStream = new FileOutputStream(outEntry)) {
101 | while (true) {
102 | status.showStatus(len1, size1, "Unpacking " + entry.getName());
103 | final int r = myTarFile.read(buf1, 0, buf1.length);
104 | if (r < 0) break; // end of entry
105 | len1 += r;
106 | outEntryStream.write(buf1, 0, r);
107 | }
108 | }
109 | }
110 | }
111 | status.clearStatus();
112 | myTarFile.close();
113 | File tarFile = new File(tarFileName);
114 | tarFile.delete();
115 | }
116 |
117 | public static void unZip(String zipFile, File output, LogService log, StatusService status) throws IOException {
118 | log("Unpacking " + zipFile + " to " + output, log, status);
119 | output.mkdirs();
120 | try (final ZipInputStream zis = new ZipInputStream(new FileInputStream(zipFile))) {
121 | unZip(zis, output, log, status);
122 | }
123 | }
124 |
125 | public static void unZip(File output, ByteArray byteArray, LogService log, StatusService status) throws IOException {
126 | // Extract the contents of the compressed data to the model cache.
127 | final ByteArrayInputStream bais = new ByteArrayInputStream(//
128 | byteArray.getArray(), 0, byteArray.size());
129 | output.mkdirs();
130 | try (final ZipInputStream zis = new ZipInputStream(bais)) {
131 | unZip(zis, output, log, status);
132 | }
133 | }
134 |
135 | private static void unZip(ZipInputStream zis, File output, LogService log, StatusService status) throws IOException {
136 | final byte[] buf = new byte[64 * 1024];
137 | while (true) {
138 | final ZipEntry entry = zis.getNextEntry();
139 | if (entry == null) break; // All done!
140 | final String name = entry.getName();
141 | log("Unpacking " + name, log, status);
142 | final File outFile = new File(output, name);
143 | if (!outFile.toPath().normalize().startsWith(output.toPath().normalize())) {
144 | throw new RuntimeException("Bad zip entry");
145 | }
146 | if (entry.isDirectory()) {
147 | outFile.mkdirs();
148 | }
149 | else {
150 | final int size = (int) entry.getSize();
151 | int len = 0;
152 | try (final FileOutputStream out = new FileOutputStream(outFile)) {
153 | while (true) {
154 | status.showStatus(len, size, "Unpacking " + name);
155 | final int r = zis.read(buf);
156 | if (r < 0) break; // end of entry
157 | len += r;
158 | out.write(buf, 0, r);
159 | }
160 | }
161 | }
162 | }
163 | status.clearStatus();
164 | }
165 |
166 | private static void deleteIfExists(String filePath) {
167 | File file = new File(filePath);
168 | if(file.exists()) {
169 | file.delete();
170 | }
171 | }
172 |
173 | private static void log(String msg, LogService log, StatusService status) {
174 | log.info(msg);
175 | status.showStatus(msg);
176 | }
177 |
178 | }
179 |
--------------------------------------------------------------------------------
/src/test/java/net/imagej/tensorflow/TensorsTest.java:
--------------------------------------------------------------------------------
1 | /*-
2 | * #%L
3 | * ImageJ/TensorFlow integration.
4 | * %%
5 | * Copyright (C) 2017 - 2025 Board of Regents of the University of
6 | * Wisconsin-Madison and Google, Inc.
7 | * %%
8 | * Redistribution and use in source and binary forms, with or without
9 | * modification, are permitted provided that the following conditions are met:
10 | *
11 | * 1. Redistributions of source code must retain the above copyright notice,
12 | * this list of conditions and the following disclaimer.
13 | * 2. Redistributions in binary form must reproduce the above copyright notice,
14 | * this list of conditions and the following disclaimer in the documentation
15 | * and/or other materials provided with the distribution.
16 | *
17 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
18 | * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
19 | * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
20 | * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDERS OR CONTRIBUTORS BE
21 | * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
22 | * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
23 | * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
24 | * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
25 | * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
26 | * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
27 | * POSSIBILITY OF SUCH DAMAGE.
28 | * #L%
29 | */
30 |
31 | package net.imagej.tensorflow;
32 |
33 | import static org.junit.Assert.assertArrayEquals;
34 | import static org.junit.Assert.assertEquals;
35 | import static org.junit.Assert.fail;
36 |
37 | import java.nio.ByteBuffer;
38 | import java.nio.DoubleBuffer;
39 | import java.nio.FloatBuffer;
40 | import java.nio.IntBuffer;
41 | import java.nio.LongBuffer;
42 | import java.util.ArrayList;
43 | import java.util.List;
44 | import java.util.function.BiConsumer;
45 | import java.util.stream.IntStream;
46 |
47 | import org.junit.Test;
48 | import org.tensorflow.DataType;
49 | import org.tensorflow.Tensor;
50 | import org.tensorflow.types.UInt8;
51 |
52 | import com.google.common.base.Function;
53 |
54 | import net.imglib2.Point;
55 | import net.imglib2.RandomAccess;
56 | import net.imglib2.img.Img;
57 | import net.imglib2.img.array.ArrayImgFactory;
58 | import net.imglib2.type.numeric.RealType;
59 | import net.imglib2.type.numeric.integer.ByteType;
60 | import net.imglib2.type.numeric.integer.IntType;
61 | import net.imglib2.type.numeric.integer.LongType;
62 | import net.imglib2.type.numeric.real.DoubleType;
63 | import net.imglib2.type.numeric.real.FloatType;
64 | import net.imglib2.util.Intervals;
65 |
66 | public class TensorsTest {
67 |
68 | // --------- TENSOR to RAI ---------
69 |
70 | /** Tests the img(Tensor) functions */
71 | @Test
72 | public void testTensorToImgReverse() {
73 | final long[] shape = new long[] { 20, 10, 3 };
74 | final int[] mapping = new int[] { 2, 1, 0 };
75 | final long[] dims = new long[] { 3, 10, 20 };
76 | final int n = shape.length;
77 | final int size = 600;
78 |
79 | // Get some points to mark
80 | List points = createTestPoints(n);
81 |
82 | // Create Tensors of different type and convert them to images
83 | ByteBuffer dataByte = ByteBuffer.allocateDirect(size);
84 | execForPointsWithBufferIndex(shape, mapping, points, (i, v) -> dataByte.put(i, (byte) (int) v));
85 | Tensor tensorByte = Tensor.create(UInt8.class, shape, dataByte);
86 | Img imgByte = Tensors.imgByte(tensorByte);
87 |
88 | DoubleBuffer dataDouble = DoubleBuffer.allocate(size);
89 | execForPointsWithBufferIndex(shape, mapping, points, (i, v) -> dataDouble.put(i, v));
90 | Tensor tensorDouble = Tensor.create(shape, dataDouble);
91 | Img imgDouble = Tensors.imgDouble(tensorDouble);
92 |
93 | FloatBuffer dataFloat = FloatBuffer.allocate(size);
94 | execForPointsWithBufferIndex(shape, mapping, points, (i, v) -> dataFloat.put(i, v));
95 | Tensor tensorFloat = Tensor.create(shape, dataFloat);
96 | Img imgFloat = Tensors.imgFloat(tensorFloat);
97 |
98 | IntBuffer dataInt = IntBuffer.allocate(size);
99 | execForPointsWithBufferIndex(shape, mapping, points, (i, v) -> dataInt.put(i, v));
100 | Tensor tensorInt = Tensor.create(shape, dataInt);
101 | Img imgInt = Tensors.imgInt(tensorInt);
102 |
103 | LongBuffer dataLong = LongBuffer.allocate(size);
104 | execForPointsWithBufferIndex(shape, mapping, points, (i, v) -> dataLong.put(i, v));
105 | Tensor tensorLong = Tensor.create(shape, dataLong);
106 | Img imgLong = Tensors.imgLong(tensorLong);
107 |
108 | // Check all created images
109 | checkImage(imgByte, n, dims, points);
110 | checkImage(imgDouble, n, dims, points);
111 | checkImage(imgFloat, n, dims, points);
112 | checkImage(imgInt, n, dims, points);
113 | checkImage(imgLong, n, dims, points);
114 | }
115 |
116 | /** Tests the imgDirect(Tensor) functions */
117 | @Test
118 | public void testTensorToImgDirect() {
119 | final long[] shape = new long[] { 20, 10, 3 };
120 | final int[] mapping = new int[] { 0, 1, 2 };
121 | final int n = shape.length;
122 | final int size = 600;
123 |
124 | // Get some points to mark
125 | List points = createTestPoints(n);
126 |
127 | // Create Tensors of different type and convert them to images
128 | ByteBuffer dataByte = ByteBuffer.allocateDirect(size);
129 | execForPointsWithBufferIndex(shape, mapping, points, (i, v) -> dataByte.put(i, (byte) (int) v));
130 | Tensor tensorByte = Tensor.create(UInt8.class, shape, dataByte);
131 | Img imgByte = Tensors.imgByteDirect(tensorByte);
132 |
133 | DoubleBuffer dataDouble = DoubleBuffer.allocate(size);
134 | execForPointsWithBufferIndex(shape, mapping, points, (i, v) -> dataDouble.put(i, v));
135 | Tensor tensorDouble = Tensor.create(shape, dataDouble);
136 | Img imgDouble = Tensors.imgDoubleDirect(tensorDouble);
137 |
138 | FloatBuffer dataFloat = FloatBuffer.allocate(size);
139 | execForPointsWithBufferIndex(shape, mapping, points, (i, v) -> dataFloat.put(i, v));
140 | Tensor tensorFloat = Tensor.create(shape, dataFloat);
141 | Img imgFloat = Tensors.imgFloatDirect(tensorFloat);
142 |
143 | IntBuffer dataInt = IntBuffer.allocate(size);
144 | execForPointsWithBufferIndex(shape, mapping, points, (i, v) -> dataInt.put(i, v));
145 | Tensor tensorInt = Tensor.create(shape, dataInt);
146 | Img imgInt = Tensors.imgIntDirect(tensorInt);
147 |
148 | LongBuffer dataLong = LongBuffer.allocate(size);
149 | execForPointsWithBufferIndex(shape, mapping, points, (i, v) -> dataLong.put(i, v));
150 | Tensor tensorLong = Tensor.create(shape, dataLong);
151 | Img imgLong = Tensors.imgLongDirect(tensorLong);
152 |
153 | // Check all created images
154 | checkImage(imgByte, n, shape, points);
155 | checkImage(imgDouble, n, shape, points);
156 | checkImage(imgFloat, n, shape, points);
157 | checkImage(imgInt, n, shape, points);
158 | checkImage(imgLong, n, shape, points);
159 | }
160 |
161 | /** Tests the img(Tensor, int[]) functions */
162 | @Test
163 | public void testTensorToImgMapping() {
164 | final long[] shape = new long[] { 3, 5, 2, 4 };
165 | final int[] mapping = new int[] { 1, 3, 0, 2 }; // A strange mapping
166 | final long[] dims = new long[] { 5, 4, 3, 2 };
167 | final int n = shape.length;
168 | final int size = 120;
169 |
170 | // Get some points to mark
171 | List points = createTestPoints(n);
172 |
173 | // Create Tensors of different type and convert them to images
174 | ByteBuffer dataByte = ByteBuffer.allocateDirect(size);
175 | execForPointsWithBufferIndex(shape, mapping, points, (i, v) -> dataByte.put(i, (byte) (int) v));
176 | Tensor tensorByte = Tensor.create(UInt8.class, shape, dataByte);
177 | Img imgByte = Tensors.imgByte(tensorByte, mapping);
178 |
179 | DoubleBuffer dataDouble = DoubleBuffer.allocate(size);
180 | execForPointsWithBufferIndex(shape, mapping, points, (i, v) -> dataDouble.put(i, v));
181 | Tensor tensorDouble = Tensor.create(shape, dataDouble);
182 | Img imgDouble = Tensors.imgDouble(tensorDouble, mapping);
183 |
184 | FloatBuffer dataFloat = FloatBuffer.allocate(size);
185 | execForPointsWithBufferIndex(shape, mapping, points, (i, v) -> dataFloat.put(i, v));
186 | Tensor tensorFloat = Tensor.create(shape, dataFloat);
187 | Img imgFloat = Tensors.imgFloat(tensorFloat, mapping);
188 |
189 | IntBuffer dataInt = IntBuffer.allocate(size);
190 | execForPointsWithBufferIndex(shape, mapping, points, (i, v) -> dataInt.put(i, v));
191 | Tensor tensorInt = Tensor.create(shape, dataInt);
192 | Img