concatenated = //
105 | Views.concatenate(1, flippedTiles);
106 | return Views.permute(concatenated, 1, 0);
107 | }
108 | }
109 |
--------------------------------------------------------------------------------
/src/main/java/sc/fiji/imageFocus/Main.java:
--------------------------------------------------------------------------------
1 | /*-
2 | * #%L
3 | * ImageJ plugin to analyze focus quality of microscope images.
4 | * %%
5 | * Copyright (C) 2017 - 2020 Google, Inc. and Board of Regents
6 | * of the University of Wisconsin-Madison.
7 | * %%
8 | * Licensed under the Apache License, Version 2.0 (the "License");
9 | * you may not use this file except in compliance with the License.
10 | * You may obtain a copy of the License at
11 | *
12 | * http://www.apache.org/licenses/LICENSE-2.0
13 | *
14 | * Unless required by applicable law or agreed to in writing, software
15 | * distributed under the License is distributed on an "AS IS" BASIS,
16 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17 | * See the License for the specific language governing permissions and
18 | * limitations under the License.
19 | * #L%
20 | */
21 |
22 | package sc.fiji.imageFocus;
23 |
24 | import java.io.File;
25 | import java.io.IOException;
26 |
27 | import net.imagej.ImageJ;
28 |
29 | import org.scijava.widget.FileWidget;
30 |
31 | /**
32 | * Test drive for the {@link MicroscopeImageFocusQualityClassifier} command.
33 | *
34 | * @author Curtis Rueden
35 | */
36 | public final class Main {
37 |
38 | public static void main(final String[] args) throws IOException {
39 | // Launch ImageJ.
40 | final ImageJ ij = new ImageJ();
41 | ij.launch(args);
42 |
43 | // Load an image.
44 | final File file = ij.ui().chooseFile(null, FileWidget.OPEN_STYLE);
45 | if (file == null) return;
46 | final Object data = ij.io().open(file.getAbsolutePath());
47 | ij.ui().show(data);
48 |
49 | // Run the command.
50 | ij.command().run(MicroscopeImageFocusQualityClassifier.class, true);
51 | }
52 | }
53 |
--------------------------------------------------------------------------------
/src/main/java/sc/fiji/imageFocus/MicroscopeImageFocusQualityClassifier.java:
--------------------------------------------------------------------------------
1 | /*-
2 | * #%L
3 | * ImageJ plugin to analyze focus quality of microscope images.
4 | * %%
5 | * Copyright (C) 2017 - 2020 Google, Inc. and Board of Regents
6 | * of the University of Wisconsin-Madison.
7 | * %%
8 | * Licensed under the Apache License, Version 2.0 (the "License");
9 | * you may not use this file except in compliance with the License.
10 | * You may obtain a copy of the License at
11 | *
12 | * http://www.apache.org/licenses/LICENSE-2.0
13 | *
14 | * Unless required by applicable law or agreed to in writing, software
15 | * distributed under the License is distributed on an "AS IS" BASIS,
16 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17 | * See the License for the specific language governing permissions and
18 | * limitations under the License.
19 | * #L%
20 | */
21 |
22 | package sc.fiji.imageFocus;
23 |
24 | import ij.ImagePlus;
25 | import ij.gui.Line;
26 | import ij.gui.Overlay;
27 | import ij.gui.Roi;
28 | import ij.gui.TextRoi;
29 |
30 | import java.awt.Color;
31 | import java.io.IOException;
32 | import java.util.Arrays;
33 | import java.util.List;
34 |
35 | import net.imagej.Dataset;
36 | import net.imagej.DatasetService;
37 | import net.imagej.ImgPlus;
38 | import net.imagej.axis.Axes;
39 | import net.imagej.axis.AxisType;
40 | import net.imagej.display.ColorTables;
41 | import net.imagej.tensorflow.TensorFlowService;
42 | import net.imagej.tensorflow.Tensors;
43 | import net.imglib2.RandomAccess;
44 | import net.imglib2.RandomAccessibleInterval;
45 | import net.imglib2.display.ColorTable8;
46 | import net.imglib2.img.Img;
47 | import net.imglib2.type.numeric.RealType;
48 | import net.imglib2.type.numeric.integer.UnsignedShortType;
49 | import net.imglib2.type.numeric.real.FloatType;
50 |
51 | import org.scijava.Initializable;
52 | import org.scijava.ItemIO;
53 | import org.scijava.command.Command;
54 | import org.scijava.command.Previewable;
55 | import org.scijava.io.http.HTTPLocation;
56 | import org.scijava.log.LogService;
57 | import org.scijava.plugin.Parameter;
58 | import org.scijava.plugin.Plugin;
59 | import org.scijava.widget.NumberWidget;
60 | import org.tensorflow.SavedModelBundle;
61 | import org.tensorflow.Tensor;
62 | import org.tensorflow.framework.MetaGraphDef;
63 | import org.tensorflow.framework.SignatureDef;
64 | import org.tensorflow.framework.TensorInfo;
65 |
66 | /**
67 | * Command to apply the Microscopy image focus quality classifier model on an
68 | * input (16-bit, greyscale image).
69 | *
70 | * The model has been trained on raw 16-bit microscope images with
71 | * integer-valued inputs in {@code [0, 65535]}, where the black level is usually
72 | * in {@code [0, ~1000]} and the typical brightness of cells is in
73 | * {@code [~1000, ~10,000]}.
74 | *
75 | *
76 | * This command optionally produces a multi-channel image of probability values
77 | * where each channel corresponds to one focus class, and optionally adds an
78 | * overlay annotation to the input image to visualize the most likely focus
79 | * class of each region.
80 | *
81 | */
82 | @Plugin(type = Command.class,
83 | menuPath = "Plugins>Classification>Microscope Image Focus Quality")
84 | public class MicroscopeImageFocusQualityClassifier>
85 | implements Command, Initializable, Previewable
86 | {
87 |
88 | private static final String MODEL_URL =
89 | "https://downloads.imagej.net/fiji/models/microscope-image-quality-model.zip";
90 |
91 | private static final String MODEL_NAME = "microscope-image-quality";
92 |
93 | // Same as the tag used in export_saved_model in the Python code.
94 | private static final String MODEL_TAG = "inference";
95 |
96 | // Same as
97 | // tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
98 | // in Python. Perhaps this should be an exported constant in TensorFlow's Java
99 | // API.
100 | private static final String DEFAULT_SERVING_SIGNATURE_DEF_KEY =
101 | "serving_default";
102 |
103 | private static final int TILE_SIZE = 84;
104 |
105 | @Parameter
106 | private TensorFlowService tensorFlowService;
107 |
108 | @Parameter
109 | private DatasetService datasetService;
110 |
111 | @Parameter
112 | private LogService log;
113 |
114 | @Parameter(label = "Microscope Image")
115 | private Img originalImage;
116 |
117 | /**
118 | * ImageJ 1.x version of the image to process.
119 | *
120 | * Only used for overlaying patches.
121 | *
122 | */
123 | @Parameter(required = false)
124 | private ImagePlus originalImagePlus;
125 |
126 | private Overlay originalOverlay;
127 |
128 | @Parameter(label = "Number of tiles in X", persist = false,
129 | callback = "refreshTilePreview", min = "1",
130 | description = "The number of tiles to process in the X direction. " +
131 | "The smaller this value, the less
of the image will be covered " +
132 | "horizontally, but the faster the processing will be.")
133 | private long tileCountX = 1;
134 |
135 | @Parameter(label = "Number of tiles in Y", persist = false,
136 | callback = "refreshTilePreview", min = "1",
137 | description = "The number of tiles to process in the Y direction. " +
138 | "The smaller this value, the less
of the image will be covered " +
139 | "vertically, but the faster the processing will be.")
140 | private long tileCountY = 1;
141 |
142 | @Parameter(label = "Generate probability image",
143 | description = "When checked, a multi-channel image will be created " +
144 | "with one channel per focus level,
and each value corresponding " +
145 | "to the probability of that sample being at that focus level.")
146 | private boolean createProbabilityImage = true;
147 |
148 | @Parameter(label = "Overlay probability patches",
149 | description = "When checked, each classified region of the image " +
150 | "will be overlaid with a color
whose hue denotes the most likely " +
151 | "focus level and whose brightness denotes
the confidence " +
152 | "(i.e., probability) of the region being at that level.")
153 | private boolean overlayPatches = true;
154 |
155 | @Parameter(label = "Show patches as solid rectangles",
156 | description = "When checked, overlaid probability patches will be " +
157 | "filled semi-transparent
and solid; when unchecked, they will be " +
158 | "drawn as hollow boundary boxes.")
159 | private boolean solidPatches;
160 |
161 | @Parameter(label = "Displayed patch border width", //
162 | min = "1", max = "10", style = NumberWidget.SCROLL_BAR_STYLE,
163 | description = "When drawing probability patches as boundary boxes, " +
164 | "this option controls the box thickness.")
165 | private int borderWidth = 4;
166 |
167 | @Parameter(type = ItemIO.OUTPUT)
168 | private Dataset probDataset;
169 |
170 | @Override
171 | public void run() {
172 | try {
173 | validateFormat(originalImage);
174 | final RandomAccessibleInterval tiledImage = //
175 | Images.tile(originalImage, tileCountX, tileCountY, TILE_SIZE, TILE_SIZE);
176 | final RandomAccessibleInterval normalizedImage = //
177 | Images.normalize(tiledImage);
178 |
179 | final long loadModelStart = System.nanoTime();
180 | final HTTPLocation source = new HTTPLocation(MODEL_URL);
181 | try (final SavedModelBundle model = //
182 | tensorFlowService.loadModel(source, MODEL_NAME, MODEL_TAG)) {
183 | final long loadModelEnd = System.nanoTime();
184 | log.info(String.format(
185 | "Loaded microscope focus image quality model in %dms", (loadModelEnd -
186 | loadModelStart) / 1000000));
187 |
188 | // Extract names from the model signature.
189 | // The strings "input", "probabilities" and "patches" are meant to be
190 | // in sync with the model exporter (export_saved_model()) in Python.
191 | final SignatureDef sig = MetaGraphDef.parseFrom(model.metaGraphDef())
192 | .getSignatureDefOrThrow(DEFAULT_SERVING_SIGNATURE_DEF_KEY);
193 | try (final Tensor> inputTensor = Tensors.tensor(normalizedImage)) {
194 | // Run the model.
195 | final long runModelStart = System.nanoTime();
196 | final List> fetches = model.session().runner() //
197 | .feed(opName(sig.getInputsOrThrow("input")), inputTensor) //
198 | .fetch(opName(sig.getOutputsOrThrow("probabilities"))) //
199 | .fetch(opName(sig.getOutputsOrThrow("patches"))) //
200 | .run();
201 | final long runModelEnd = System.nanoTime();
202 | log.info(String.format("Ran image through model in %dms", //
203 | (runModelEnd - runModelStart) / 1000000));
204 |
205 | // Process the results.
206 | try (final Tensor> probabilities = fetches.get(0);
207 | final Tensor> patches = fetches.get(1))
208 | {
209 | processPatches(probabilities, patches);
210 | }
211 | }
212 | }
213 | }
214 | catch (final Exception exc) {
215 | // Use the LogService to report the error.
216 | log.error(exc);
217 | }
218 | }
219 |
220 | @Override
221 | public void initialize() {
222 | if (originalImage == null) return;
223 | tileCountX = Math.max(1, originalImage.dimension(0) / TILE_SIZE);
224 | tileCountY = Math.max(1, originalImage.dimension(1) / TILE_SIZE);
225 | if (originalImagePlus != null) {
226 | originalOverlay = originalImagePlus.getOverlay();
227 | }
228 | refreshTilePreview();
229 | }
230 |
231 | @Override
232 | public void preview() {
233 | // NB: No action needed.
234 | }
235 |
236 | @Override
237 | public void cancel() {
238 | if (originalImagePlus != null && originalOverlay != null) {
239 | originalImagePlus.setOverlay(originalOverlay);
240 | }
241 | }
242 |
243 | /** Callback method for {@link #tileCountX} and {@link #tileCountY}. */
244 | private void refreshTilePreview() {
245 | if (originalImagePlus == null) return;
246 |
247 | final int tileOpacity = 64;
248 | final Color evenColor = new Color(100, 255, 255, tileOpacity);
249 | final Color oddColor = new Color(255, 255, 100, tileOpacity);
250 |
251 | final long w = originalImage.dimension(0);
252 | final long h = originalImage.dimension(1);
253 |
254 | final Overlay overlay = new Overlay();
255 | for (long y = 0; y < tileCountY; y++) {
256 | final long offsetY = Images.offset(y, tileCountY, TILE_SIZE, h);
257 | for (long x = 0; x < tileCountX; x++) {
258 | final long offsetX = Images.offset(x, tileCountX, TILE_SIZE, w);
259 | final Roi tile = new Roi(offsetX, offsetY, TILE_SIZE, TILE_SIZE);
260 | tile.setFillColor((x + y) % 2 == 0 ? evenColor : oddColor);
261 | overlay.add(tile);
262 | }
263 | }
264 | originalImagePlus.setOverlay(overlay);
265 | }
266 |
267 | private void validateFormat(final Img image) throws IOException {
268 | final int ndims = image.numDimensions();
269 | if (ndims != 2) {
270 | final long[] dims = new long[ndims];
271 | image.dimensions(dims);
272 | throw new IOException("Can only process 2D images, not an image with " +
273 | ndims + " dimensions (" + Arrays.toString(dims) + ")");
274 | }
275 | if (!(image.firstElement() instanceof UnsignedShortType)) {
276 | throw new IOException("Can only process uint16 images. " +
277 | "Please convert your image first via Image > Type > 16-bit.");
278 | }
279 | }
280 |
281 | private void processPatches(final Tensor> probabilities,
282 | final Tensor> patches)
283 | {
284 | // Extract probability values.
285 | final long[] probShape = probabilities.shape();
286 | log.debug("Probabilities shape: " + Arrays.toString(probShape));
287 | final int probPatchCount = (int) probShape[0];
288 | final int classCount = (int) probShape[1];
289 | final float[][] probValues = new float[probPatchCount][classCount];
290 | probabilities.copyTo(probValues);
291 |
292 | // Extract and validate patch layout.
293 | final long[] patchShape = patches.shape();
294 | log.debug("Patches shape: " + Arrays.toString(patchShape));
295 | assert patchShape.length == 4;
296 | final int patchCount = (int) patchShape[0];
297 | assert patchCount == probPatchCount;
298 | final int patchHeight = (int) patchShape[1];
299 | final int patchWidth = (int) patchShape[2];
300 | assert patchShape[3] == 1;
301 |
302 | // Dump probabilities to the log.
303 | for (int i = 0; i < probShape[0]; ++i) {
304 | log.info(String.format("Patch %02d probabilities: %s", i, //
305 | Arrays.toString(probValues[i])));
306 | }
307 |
308 | // Synthesize matched-size image with computed probabilities.
309 | if (createProbabilityImage) {
310 | createProbabilityImage(classCount, probValues, patchHeight, patchWidth);
311 | }
312 |
313 | // Add ImageJ 1.x overlay to the active image.
314 | if (overlayPatches && originalImagePlus != null) {
315 | addOverlay(probValues, patchWidth, patchHeight);
316 | }
317 | }
318 |
319 | private void createProbabilityImage(final int classCount,
320 | final float[][] probValues, final int patchHeight, final int patchWidth)
321 | {
322 | // Create probability image.
323 | final long width = originalImage.dimension(0);
324 | final long height = originalImage.dimension(1);
325 | final long[] dims = { width, height, classCount };
326 | final AxisType[] axes = { Axes.X, Axes.Y, Axes.CHANNEL };
327 | final FloatType type = new FloatType();
328 | probDataset = datasetService.create(type, dims, "Probabilities", axes,
329 | false);
330 |
331 | // Set the probability image to normalized grayscale.
332 | probDataset.initializeColorTables(classCount);
333 | for (int c = 0; c < classCount; c++) {
334 | probDataset.setColorTable(ColorTables.GRAYS, c);
335 | probDataset.setChannelMinimum(c, 0);
336 | probDataset.setChannelMaximum(c, 1);
337 | }
338 |
339 | final ImgPlus probImg = probDataset.typedImg(type);
340 |
341 | // Cover the probability image with NaNs.
342 | // Real values will be written only to tile-covered areas.
343 | for (final FloatType sample : probImg) {
344 | sample.set(Float.NaN);
345 | }
346 |
347 | // Populate the probability image's sample values.
348 | final RandomAccess access = probImg.randomAccess();
349 | for (int t = 0; t < probValues.length; t++) {
350 | for (int c = 0; c < probValues[t].length; c++) {
351 | // Compute tile coordinates from probability value index.
352 | final long tx = t % tileCountX;
353 | final long ty = t / tileCountX;
354 |
355 | // Compute offset of tile in original image.
356 | final long offsetX = Images.offset(tx, tileCountX, patchWidth, width);
357 | final long offsetY = Images.offset(ty, tileCountY, patchHeight, height);
358 |
359 | // Copy the current value to every sample within the tile.
360 | final float value = probValues[t][c];
361 | access.setPosition(c, 2);
362 | access.setPosition(offsetY, 1);
363 | for (int y = 0; y < TILE_SIZE; y++) {
364 | access.setPosition(offsetX, 0);
365 | for (int x = 0; x < TILE_SIZE; x++) {
366 | access.get().set(value);
367 | access.fwd(0);
368 | }
369 | access.fwd(1);
370 | }
371 | }
372 | }
373 | }
374 |
375 | private void addOverlay(final float[][] probValues, //
376 | final int patchWidth, final int patchHeight)
377 | {
378 | final int patchCount = probValues.length;
379 | final int classCount = probValues[0].length;
380 |
381 | final Overlay overlay = new Overlay();
382 |
383 | final int strokeWidth = solidPatches ? 0 : borderWidth;
384 | final ColorTable8 lut = ColorTables.SPECTRUM;
385 | final int lutMaxIndex = 172;
386 |
387 | final long width = originalImage.dimension(0);
388 | final long height = originalImage.dimension(1);
389 |
390 | for (int p = 0; p < patchCount; p++) {
391 | final long tx = p % tileCountX;
392 | final long ty = p / tileCountX;
393 | final long offsetX = Images.offset(tx, tileCountX, patchWidth, width) + strokeWidth / 2;
394 | final long offsetY = Images.offset(ty, tileCountY, patchHeight, height) + strokeWidth / 2;
395 |
396 | final Roi roi = new Roi(offsetX, offsetY, //
397 | patchWidth - strokeWidth, patchHeight - strokeWidth);
398 | final int classIndex = maxIndex(probValues[p]);
399 | final double confidence = probValues[p][classIndex];
400 |
401 | // NB: We scale to (0, 172) here instead of (0, 255) to avoid the high
402 | // indices looping from blue and purple back into red where we started.
403 | final int lutIndex = lutMaxIndex * classIndex / (classCount - 1);
404 |
405 | final int r = (int) (lut.get(0, lutIndex) * confidence);
406 | final int g = (int) (lut.get(1, lutIndex) * confidence);
407 | final int b = (int) (lut.get(2, lutIndex) * confidence);
408 | final int opacity = solidPatches ? 128 : 255;
409 | final Color color = new Color(r, g, b, opacity);
410 | if (solidPatches) {
411 | roi.setFillColor(color);
412 | }
413 | else {
414 | roi.setStrokeColor(color);
415 | roi.setStrokeWidth(strokeWidth);
416 | }
417 | overlay.add(roi);
418 | }
419 |
420 | // Add color bar with legend.
421 |
422 | final int barHeight = 24, barPad = 5;
423 | final int barY = originalImagePlus.getHeight() - barHeight - barPad;
424 |
425 | final TextRoi labelGood = new TextRoi(barPad, barY, "In focus");
426 | labelGood.setStrokeColor(Color.white);
427 | overlay.add(labelGood);
428 |
429 | final int barOffset = 2 * barPad + (int) labelGood.getBounds().getWidth();
430 |
431 | final TextRoi labelBad = new TextRoi(barOffset + lutMaxIndex + barPad,
432 | barY, "Out of focus");
433 | labelBad.setStrokeColor(Color.white);
434 | overlay.add(labelBad);
435 |
436 | for (int i = 0; i < lutMaxIndex; i++) {
437 | final int barX = barOffset + i;
438 | final Roi line = new Line(barX, barY, barX, barY + barHeight);
439 | final int r = lut.get(0, i);
440 | final int g = lut.get(1, i);
441 | final int b = lut.get(2, i);
442 | line.setStrokeColor(new Color(r, g, b));
443 | overlay.add(line);
444 | }
445 |
446 | originalImagePlus.setOverlay(overlay);
447 | }
448 |
449 | private static int maxIndex(final float[] values) {
450 | float max = values[0];
451 | int index = 0;
452 | for (int i = 1; i < values.length; i++) {
453 | if (values[i] > max) {
454 | max = values[i];
455 | index = i;
456 | }
457 | }
458 | return index;
459 | }
460 |
461 | /**
462 | * The SignatureDef inputs and outputs contain names of the form
463 | * {@code :}, where for this model,
464 | * {@code } is always 0. This function trims the {@code :0}
465 | * suffix to get the operation name.
466 | */
467 | private static String opName(final TensorInfo t) {
468 | final String n = t.getName();
469 | if (n.endsWith(":0")) {
470 | return n.substring(0, n.lastIndexOf(":0"));
471 | }
472 | return n;
473 | }
474 | }
475 |
--------------------------------------------------------------------------------