├── Matlab
├── saved_recons
│ └── readme.txt
├── helper_functions
│ ├── soft_thresh.m
│ ├── Ltrans.m
│ ├── false_color_calib.mat
│ ├── power_iteration.m
│ ├── false_color_function.m
│ ├── tlv.m
│ ├── tv3dApproxHaar.m
│ ├── fista_spectral_3d.m
│ └── imshow3D.m
├── SampleData
│ └── readme.txt
└── reconstruction_demo.m
├── Python
├── SampleData
│ └── readme.txt
├── helper_functions
│ ├── false_color_calib.mat
│ ├── tv_approx_haar_cp.py
│ ├── tv_approx_haar_np.py
│ └── helper_functions.py
└── fista_spectral_cupy.py
├── environment.yml
├── .gitignore
├── LICENSE
└── README.md
/Matlab/saved_recons/readme.txt:
--------------------------------------------------------------------------------
1 | Reconstructions are saved to this folder.
--------------------------------------------------------------------------------
/Matlab/helper_functions/soft_thresh.m:
--------------------------------------------------------------------------------
function out = soft_thresh(x,tau)
% SOFT_THRESH  Elementwise soft-thresholding.
%   out = soft_thresh(x, tau) returns sign(x).*max(abs(x)-tau, 0): every
%   entry is shrunk toward zero by tau, and entries with magnitude <= tau
%   become zero.
magnitude = abs(x) - tau;
out = sign(x) .* max(magnitude, 0);
--------------------------------------------------------------------------------
/Matlab/SampleData/readme.txt:
--------------------------------------------------------------------------------
1 | The sample data can be found here: https://drive.google.com/drive/folders/1dmfzkTLFZZFUYW8GC6Vn6SOuZiZq47SS?usp=sharing
--------------------------------------------------------------------------------
/Python/SampleData/readme.txt:
--------------------------------------------------------------------------------
1 | The sample data can be found here: https://drive.google.com/drive/folders/1dmfzkTLFZZFUYW8GC6Vn6SOuZiZq47SS?usp=sharing
--------------------------------------------------------------------------------
/Matlab/helper_functions/Ltrans.m:
--------------------------------------------------------------------------------
function P=Ltrans(X)
% LTRANS  Forward differences of X along rows and columns.
%   P{1} = X(i,:) - X(i+1,:)   (size (m-1) x n)
%   P{2} = X(:,j) - X(:,j+1)   (size m x (n-1))
%   Because [m,n] = size(X) folds all trailing dimensions into n, an N-D
%   X is treated as an m x (prod of remaining dims) matrix.
[nRows, nCols] = size(X);
rowDiff = X(1:nRows-1, :) - X(2:nRows, :);
colDiff = X(:, 1:nCols-1) - X(:, 2:nCols);
P = {rowDiff, colDiff};
7 |
--------------------------------------------------------------------------------
/Matlab/helper_functions/false_color_calib.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Waller-Lab/SpectralDiffuserCam/HEAD/Matlab/helper_functions/false_color_calib.mat
--------------------------------------------------------------------------------
/Python/helper_functions/false_color_calib.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Waller-Lab/SpectralDiffuserCam/HEAD/Python/helper_functions/false_color_calib.mat
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
# Conda environment for the Python code in this repository.
# Create it with:  conda env create -f environment.yml
# No `prefix:` entry is given, so conda installs the environment into the
# user's default envs directory under `name`.
name: SpectralDiffuserCam
channels:
  - defaults
dependencies:
  - anaconda
  - cupy
8 |
9 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Jupyter Notebook
2 | .ipynb_checkpoints
3 | */.ipynb_checkpoints/*
4 | ipython_config.py
5 |
6 | # pyenv
7 | .python-version
8 |
9 | # Ignore image files and numpy files
10 | *.png
11 | *.jpg
12 | *.jpeg
13 | *.tiff
14 | *.npy
15 | calibration.mat
16 | __pycache__
--------------------------------------------------------------------------------
/Matlab/helper_functions/power_iteration.m:
--------------------------------------------------------------------------------
1 |
function eig_b = power_iteration(A, sample_vect, num_iters)
% POWER_ITERATION  Estimate the dominant eigenvalue of a linear operator.
%   eig_b = power_iteration(A, sample_vect, num_iters) starts from a random
%   array shaped like sample_vect, repeatedly applies the function handle A
%   and renormalizes (num_iters times), then returns the Rayleigh quotient
%   <v, A(v)> / <v, v> of the final iterate.
v = rand(size(sample_vect));
for it = 1:num_iters
    Av = A(v);
    v = Av ./ norm(Av);
end

Av = A(v);
numer = v(:).' * Av(:);
denom = v(:).' * v(:);
eig_b = numer / denom;
end
--------------------------------------------------------------------------------
/Matlab/helper_functions/false_color_function.m:
--------------------------------------------------------------------------------
function false_color = false_color_function(x)
% FALSE_COLOR_FUNCTION  Collapse a hyperspectral cube into an RGB image.
%   false_color = false_color_function(x) weights every spectral slice
%   x(:,:,i) by the calibration curves red(i), green(i), blue(i) stored in
%   false_color_calib.mat, applies a fixed per-channel gain, and normalizes
%   the result so that its maximum equals 1.
%
%   x ........... Ny x Nx x Nc array. Only the first min(Nc, numel(red))
%                 channels are used.
%   false_color . Ny x Nx x 3 array. It is returned un-normalized if its
%                 maximum is 0 (avoids a 0/0 NaN image).

% Load into a struct instead of creating variables implicitly in the
% function workspace.
calib = load('false_color_calib.mat');
red = calib.red;
green = calib.green;
blue = calib.blue;

scaling = [1,1,2.5];   % gain applied to the R, G, B channels
n_channels = min(size(x, 3), numel(red));

false_color = zeros(size(x, 1), size(x, 2), 3);
for i = 1:n_channels
    false_color(:,:,1) = false_color(:,:,1) + red(i)*scaling(1)*x(:,:,i);
    false_color(:,:,2) = false_color(:,:,2) + green(i)*scaling(2)*x(:,:,i);
    false_color(:,:,3) = false_color(:,:,3) + blue(i)*scaling(3)*x(:,:,i);
end

peak = max(false_color(:));
if peak ~= 0
    false_color = false_color/peak;
end
return
--------------------------------------------------------------------------------
/Matlab/helper_functions/tlv.m:
--------------------------------------------------------------------------------
function out=tlv(X,type, gpu)
%TLV  Total variation of an image, or of every slice of an image stack.
%
% INPUT
%
% X ........ A 2-D image, or an N-D array. Differences are taken along
%            dims 1 and 2 only, independently for every trailing slice,
%            so channels are never differenced against each other.
% type ..... Type of total variation function. Either 'iso'
%            (isotropic) or 'l1' (nonisotropic)
% gpu ...... (optional) if true, the 'iso' work buffer is created as a
%            gpuArray. Defaults to isa(X, 'gpuArray').
%
% OUTPUT
% out ...... The total variation of X (summed over all slices).

if nargin < 3 || isempty(gpu)
    gpu = isa(X, 'gpuArray');
end

% Forward differences within each slice. Using three subscripts keeps
% trailing dimensions separate instead of folding them into columns.
dRow = X(1:end-1, :, :) - X(2:end, :, :);
dCol = X(:, 1:end-1, :) - X(:, 2:end, :);

switch type
    case 'iso'
        D = zeros(size(X));
        if gpu
            D = gpuArray(D);
        end
        D(1:end-1, :, :) = dRow.^2;
        D(:, 1:end-1, :) = D(:, 1:end-1, :) + dCol.^2;
        out = sum(sqrt(D(:)));
    case 'l1'
        out = sum(abs(dRow(:))) + sum(abs(dCol(:)));
    otherwise
        error('Invalid total variation type. Should be either "iso" or "l1"');
end
32 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | BSD 3-Clause License
2 |
3 | Copyright (c) 2020, Waller Lab
4 | All rights reserved.
5 |
6 | Redistribution and use in source and binary forms, with or without
7 | modification, are permitted provided that the following conditions are met:
8 |
9 | 1. Redistributions of source code must retain the above copyright notice, this
10 | list of conditions and the following disclaimer.
11 |
12 | 2. Redistributions in binary form must reproduce the above copyright notice,
13 | this list of conditions and the following disclaimer in the documentation
14 | and/or other materials provided with the distribution.
15 |
16 | 3. Neither the name of the copyright holder nor the names of its
17 | contributors may be used to endorse or promote products derived from
18 | this software without specific prior written permission.
19 |
20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 |
--------------------------------------------------------------------------------
/Python/helper_functions/tv_approx_haar_cp.py:
--------------------------------------------------------------------------------
1 | import cupy as np
2 |
def soft_py(x, tau):
    """Elementwise soft-thresholding: sign(x) * max(|x| - tau, 0)."""
    magnitude = np.abs(x) - tau
    return np.sign(x) * np.maximum(magnitude, 0)
7 |
def ht3(x, ax, shift, thresh):
    """Single-level Haar analysis along axis ``ax`` with soft-thresholded detail.

    Args:
        x: array to transform; its length along ``ax`` must be even.
        ax: axis to transform (any valid axis; negative values allowed).
        shift: if True, roll ``x`` by -1 along ``ax`` first, so samples are
            paired as (1,2), (3,4), ..., (n-1,0) instead of (0,1), (2,3), ...
        thresh: soft-threshold applied to the detail coefficients.

    Returns:
        (w1, w2): approximation coefficients C*(odd + even) and
        soft-thresholded detail coefficients C*(odd - even), each half the
        length of ``x`` along ``ax`` (C = 1/sqrt(2)).

    Raises:
        ValueError: if ``ax`` is out of range or the length along it is odd.
    """
    C = 1./np.sqrt(2.)

    if not -x.ndim <= ax < x.ndim:
        raise ValueError('ht3: axis {} is out of range for a {}-D array'.format(ax, x.ndim))
    if x.shape[ax] % 2 != 0:
        raise ValueError('ht3: length along axis {} must be even, got {}'.format(ax, x.shape[ax]))

    if shift == True:
        x = np.roll(x, -1, axis = ax)

    # Index tuples selecting the even (0, 2, ...) and odd (1, 3, ...) samples
    # along `ax` while keeping every other axis whole.
    even = [slice(None)] * x.ndim
    odd = [slice(None)] * x.ndim
    even[ax] = slice(0, None, 2)
    odd[ax] = slice(1, None, 2)
    x_even = x[tuple(even)]
    x_odd = x[tuple(odd)]

    w1 = C*(x_odd + x_even)
    w2 = soft_py(C*(x_odd - x_even), thresh)
    return w1, w2
23 |
def iht3(w1, w2, ax, shift, shape):
    """Inverse of ht3: rebuild samples along axis ``ax`` from Haar coefficients.

    Args:
        w1: approximation coefficients (half length along ``ax``).
        w2: detail coefficients, same shape as ``w1``.
        ax: axis to reconstruct (any valid axis; negative values allowed).
        shift: if True, roll the result by +1 along ``ax`` to undo the -1
            roll applied by ``ht3(..., shift=True)``.
        shape: shape of the reconstructed array.

    Returns:
        Array of ``shape`` whose even/odd samples along ``ax`` are
        C*(w1 - w2) and C*(w1 + w2) (C = 1/sqrt(2)). It uses the dtype of
        those coefficients instead of always float64.

    Raises:
        ValueError: if ``ax`` is out of range for ``shape`` (previously an
        unsupported axis silently produced an all-zero result).
    """
    C = 1./np.sqrt(2.)
    x1 = C*(w1 - w2)
    x2 = C*(w1 + w2)

    ndim = len(shape)
    if not -ndim <= ax < ndim:
        raise ValueError('iht3: axis {} is out of range for a {}-D shape'.format(ax, ndim))

    y = np.zeros(shape, dtype = x1.dtype)

    # Interleave: x1 -> positions 0, 2, ...; x2 -> positions 1, 3, ... along `ax`.
    even = [slice(None)] * ndim
    odd = [slice(None)] * ndim
    even[ax] = slice(0, None, 2)
    odd[ax] = slice(1, None, 2)
    y[tuple(even)] = x1
    y[tuple(odd)] = x2

    if shift == True:
        y = np.roll(y, 1, axis = ax)
    return y
45 |
46 |
def iht3_py2(w1, w2, ax, shift, shape):
    """Inverse Haar step for arrays with one extra leading axis.

    Same reconstruction as iht3 (x1 = C*(w1 - w2), x2 = C*(w1 + w2)), but the
    logical axis ``ax`` corresponds to array axis ``ax + 1``. The x1/x2 pair
    is stacked on a new axis right after that axis and reshaped to ``shape``,
    which interleaves them. If ``shift`` is True the result is rolled by +1
    along array axis ``ax + 1``.

    NOTE(review): presumably intended for batched input with a leading batch
    dimension. Nothing in this module calls it, so confirm before relying on it.
    """
    C = 1./np.sqrt(2.)
    x1 = C*(w1 - w2)
    x2 = C*(w1 + w2)

    ind = ax + 2
    y = np.reshape(np.stack([x1, x2], axis = ind), shape)

    if shift == True:
        y = np.roll(y, 1, axis = ax+1)
    return y
61 |
def tv3dApproxHaar(x, tau, alpha):
    """Approximate TV proximal step via cycle-spun Haar shrinkage.

    For every axis of ``x``, the array is Haar-transformed both unshifted and
    shifted by one sample, and the detail coefficients are soft-thresholded.
    The transform is then inverted and all reconstructions are averaged.
    Along axis 2 the threshold is scaled by ``alpha``; other axes use 1.

    Args:
        x: array to denoise (every axis length must be even for ht3).
        tau: TV weight.
        alpha: relative weight for axis 2.

    Returns:
        Denoised array with the same shape as ``x``.
    """
    n_dims = 3
    threshold = n_dims * tau * (np.sqrt(2) * 2)

    acc = np.zeros_like(x)
    for axis in range(x.ndim):
        scale = alpha if axis == 2 else 1
        lo_plain, hi_plain = ht3(x, axis, False, threshold * scale)
        lo_shift, hi_shift = ht3(x, axis, True, threshold * scale)
        acc = (acc
               + iht3(lo_plain, hi_plain, axis, False, x.shape)
               + iht3(lo_shift, hi_shift, axis, True, x.shape))

    return acc / (2 * n_dims)
85 |
86 |
87 |
88 |
89 |
--------------------------------------------------------------------------------
/Python/helper_functions/tv_approx_haar_np.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
def soft_py(x, tau):
    """Elementwise soft-thresholding: sign(x) * max(|x| - tau, 0)."""
    magnitude = np.abs(x) - tau
    return np.sign(x) * np.maximum(magnitude, 0)
7 |
def ht3(x, ax, shift, thresh):
    """Single-level Haar analysis along axis ``ax`` with soft-thresholded detail.

    Args:
        x: array to transform; its length along ``ax`` must be even.
        ax: axis to transform (any valid axis; negative values allowed).
        shift: if True, roll ``x`` by -1 along ``ax`` first, so samples are
            paired as (1,2), (3,4), ..., (n-1,0) instead of (0,1), (2,3), ...
        thresh: soft-threshold applied to the detail coefficients.

    Returns:
        (w1, w2): approximation coefficients C*(odd + even) and
        soft-thresholded detail coefficients C*(odd - even), each half the
        length of ``x`` along ``ax`` (C = 1/sqrt(2)).

    Raises:
        ValueError: if ``ax`` is out of range or the length along it is odd.
    """
    C = 1./np.sqrt(2.)

    if not -x.ndim <= ax < x.ndim:
        raise ValueError('ht3: axis {} is out of range for a {}-D array'.format(ax, x.ndim))
    if x.shape[ax] % 2 != 0:
        raise ValueError('ht3: length along axis {} must be even, got {}'.format(ax, x.shape[ax]))

    if shift == True:
        x = np.roll(x, -1, axis = ax)

    # Index tuples selecting the even (0, 2, ...) and odd (1, 3, ...) samples
    # along `ax` while keeping every other axis whole.
    even = [slice(None)] * x.ndim
    odd = [slice(None)] * x.ndim
    even[ax] = slice(0, None, 2)
    odd[ax] = slice(1, None, 2)
    x_even = x[tuple(even)]
    x_odd = x[tuple(odd)]

    w1 = C*(x_odd + x_even)
    w2 = soft_py(C*(x_odd - x_even), thresh)
    return w1, w2
23 |
def iht3(w1, w2, ax, shift, shape):
    """Inverse of ht3: rebuild samples along axis ``ax`` from Haar coefficients.

    Args:
        w1: approximation coefficients (half length along ``ax``).
        w2: detail coefficients, same shape as ``w1``.
        ax: axis to reconstruct (any valid axis; negative values allowed).
        shift: if True, roll the result by +1 along ``ax`` to undo the -1
            roll applied by ``ht3(..., shift=True)``.
        shape: shape of the reconstructed array.

    Returns:
        Array of ``shape`` whose even/odd samples along ``ax`` are
        C*(w1 - w2) and C*(w1 + w2) (C = 1/sqrt(2)). It uses the dtype of
        those coefficients instead of always float64.

    Raises:
        ValueError: if ``ax`` is out of range for ``shape`` (previously an
        unsupported axis silently produced an all-zero result).
    """
    C = 1./np.sqrt(2.)
    x1 = C*(w1 - w2)
    x2 = C*(w1 + w2)

    ndim = len(shape)
    if not -ndim <= ax < ndim:
        raise ValueError('iht3: axis {} is out of range for a {}-D shape'.format(ax, ndim))

    y = np.zeros(shape, dtype = x1.dtype)

    # Interleave: x1 -> positions 0, 2, ...; x2 -> positions 1, 3, ... along `ax`.
    even = [slice(None)] * ndim
    odd = [slice(None)] * ndim
    even[ax] = slice(0, None, 2)
    odd[ax] = slice(1, None, 2)
    y[tuple(even)] = x1
    y[tuple(odd)] = x2

    if shift == True:
        y = np.roll(y, 1, axis = ax)
    return y
45 |
46 |
def iht3_py2(w1, w2, ax, shift, shape):
    """Inverse Haar step for arrays with one extra leading axis.

    Same reconstruction as iht3 (x1 = C*(w1 - w2), x2 = C*(w1 + w2)), but the
    logical axis ``ax`` corresponds to array axis ``ax + 1``. The x1/x2 pair
    is stacked on a new axis right after that axis and reshaped to ``shape``,
    which interleaves them. If ``shift`` is True the result is rolled by +1
    along array axis ``ax + 1``.

    NOTE(review): presumably intended for batched input with a leading batch
    dimension. Nothing in this module calls it, so confirm before relying on it.
    """
    C = 1./np.sqrt(2.)
    x1 = C*(w1 - w2)
    x2 = C*(w1 + w2)

    ind = ax + 2
    y = np.reshape(np.stack([x1, x2], axis = ind), shape)

    if shift == True:
        y = np.roll(y, 1, axis = ax+1)
    return y
61 |
def tv3dApproxHaar(x, tau, alpha):
    """Approximate TV proximal step via cycle-spun Haar shrinkage.

    For every axis of ``x``, the array is Haar-transformed both unshifted and
    shifted by one sample, and the detail coefficients are soft-thresholded.
    The transform is then inverted and all reconstructions are averaged.
    Along axis 2 the threshold is scaled by ``alpha``; other axes use 1.

    Args:
        x: array to denoise (every axis length must be even for ht3).
        tau: TV weight.
        alpha: relative weight for axis 2.

    Returns:
        Denoised array with the same shape as ``x``.
    """
    n_dims = 3
    threshold = n_dims * tau * (np.sqrt(2) * 2)

    acc = np.zeros_like(x)
    for axis in range(x.ndim):
        scale = alpha if axis == 2 else 1
        lo_plain, hi_plain = ht3(x, axis, False, threshold * scale)
        lo_shift, hi_shift = ht3(x, axis, True, threshold * scale)
        acc = (acc
               + iht3(lo_plain, hi_plain, axis, False, x.shape)
               + iht3(lo_shift, hi_shift, axis, True, x.shape))

    return acc / (2 * n_dims)
85 |
86 |
87 |
88 |
89 |
--------------------------------------------------------------------------------
/Matlab/helper_functions/tv3dApproxHaar.m:
--------------------------------------------------------------------------------
function y = tv3dApproxHaar(x, tau, alphay, alphaz)
% TV3DAPPROXHAAR  Approximate 3-D TV proximal step via cycle-spun Haar shrinkage.
%   y = tv3dApproxHaar(x, tau, alphay, alphaz) Haar-transforms x along each
%   of its three dimensions, both unshifted and circularly shifted by one
%   sample, and soft-thresholds the detail coefficients (ht3). It then
%   inverts the transforms (iht3) and averages all six reconstructions.
%   The threshold is scaled by alphay along dim 1, by 1 along dim 2 and by
%   alphaz along dim 3.
n_dims = 3;
step = 1; % step size (gamma)
thresh = sqrt(2) * 2 * n_dims * tau * step;

y = zeros(size(x), 'like', x);
for ax = 1 : 3
    switch ax
        case 1
            scale = alphay;
        case 3
            scale = alphaz;
        otherwise
            scale = 1;
    end
    t = thresh * scale;
    y = y + iht3(ht3(x, ax, false, t), ax, false);
    y = y + iht3(ht3(x, ax, true, t), ax, true);
end
y = y / (2 * n_dims);
return
19 |
function w = ht3(x, ax, shift, thresh)
% HT3  One-level Haar analysis of x along dimension ax.
%   The output has the same size as x. The first half along ax holds the
%   approximation C*(x_even + x_odd); the second half holds the
%   soft-thresholded detail C*(x_even - x_odd), with C = 1/sqrt(2). Here
%   x_even are samples 2,4,... and x_odd are samples 1,3,... (1-based).
%   If shift is true, x is first circularly shifted by -1 along ax.
sz = size(x);
w = zeros(sz, 'like', x);
C = 1 / sqrt(2);
if shift
    x = circshift(x, -1, ax);
end
half = floor(sz(ax) / 2);

allIdx = repmat({':'}, 1, 3);
loIdx = allIdx;   loIdx{ax} = 1:half;
hiIdx = allIdx;   hiIdx{ax} = (half + 1):sz(ax);
oddIdx = allIdx;  oddIdx{ax} = 1:2:sz(ax);
evenIdx = allIdx; evenIdx{ax} = 2:2:sz(ax);

w(loIdx{:}) = C * (x(evenIdx{:}) + x(oddIdx{:}));
w(hiIdx{:}) = C * (x(evenIdx{:}) - x(oddIdx{:}));
w(hiIdx{:}) = hs_soft(w(hiIdx{:}), thresh);
return
42 |
function y = iht3(w, ax, shift)
% IHT3  Inverse of ht3 along dimension ax.
%   The first half of w along ax is the approximation and the second half
%   is the detail. Samples 1,3,... are rebuilt as C*(approx - detail) and
%   samples 2,4,... as C*(approx + detail), with C = 1/sqrt(2). If shift is
%   true, the result is circularly shifted by +1 along ax.
sz = size(w);
y = zeros(sz, 'like', w);
C = 1 / sqrt(2);
half = floor(sz(ax) / 2);

allIdx = repmat({':'}, 1, 3);
loIdx = allIdx;   loIdx{ax} = 1:half;
hiIdx = allIdx;   hiIdx{ax} = (half + 1):sz(ax);
oddIdx = allIdx;  oddIdx{ax} = 1:2:sz(ax);
evenIdx = allIdx; evenIdx{ax} = 2:2:sz(ax);

y(oddIdx{:}) = C * (w(loIdx{:}) - w(hiIdx{:}));
y(evenIdx{:}) = C * (w(loIdx{:}) + w(hiIdx{:}));

if shift
    y = circshift(y, 1, ax);
end
return
62 |
function threshed = hs_soft(x,tau)
% HS_SOFT  Elementwise soft-thresholding: sign(x).*max(abs(x)-tau, 0).
shrunk = abs(x) - tau;
threshed = sign(x) .* max(shrunk, 0);
return
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |

2 |
3 | # [Spectral DiffuserCam](https://waller-lab.github.io/SpectralDiffuserCam/)
4 |
5 |
6 |
7 | ## Paper
8 | [Spectral DiffuserCam: lensless snapshot hyperspectral imaging with a spectral filter array](https://www.osapublishing.org/optica/abstract.cfm?uri=optica-7-10-1298)
9 |
10 | Please cite the following paper when using this code or data:
11 |
12 |
13 | ```
14 | @article{Monakhova:20,
15 | author = {Kristina Monakhova and Kyrollos Yanny and Neerja Aggarwal and Laura Waller},
16 | journal = {Optica},
17 | number = {10},
18 | pages = {1298--1307},
19 | publisher = {OSA},
20 | title = {Spectral DiffuserCam: lensless snapshot hyperspectral imaging with a spectral filter array},
21 | volume = {7},
22 | month = {Oct},
23 | year = {2020},
24 | url = {http://www.osapublishing.org/optica/abstract.cfm?URI=optica-7-10-1298},
25 | doi = {10.1364/OPTICA.397214}
26 | }
27 |
28 | ```
29 |
30 |
31 | ## Contents
32 |
33 | 1. [Data](#Data)
34 | 2. [Setup](#Setup)
35 | 3. [Description](#Description)
36 |
37 | ## Data
38 | Sample data (needed to run the code) can be found [here](https://drive.google.com/drive/folders/1dmfzkTLFZZFUYW8GC6Vn6SOuZiZq47SS?usp=sharing)
39 |
40 | This includes the following files:
41 | * calibration.mat - includes the calibrated point spread function, filter function, and wavelength list
42 | * four sample raw measurements
43 |
44 |
45 | ## Setup
46 | Clone this project using:
47 | ```
48 | git clone https://github.com/Waller-Lab/SpectralDiffuserCam.git
49 | ```
50 |
51 | The dependencies can be installed by using:
52 | ```
53 | conda env create -f environment.yml
54 | conda activate SpectralDiffuserCam
55 | ```
56 |
57 | Place the downloaded data in the SampleData folder inside the Python and/or Matlab directories.
58 |
59 | [Reconstruction Demo.ipynb](https://github.com/Waller-Lab/SpectralDiffuserCam/blob/master/Python/Reconstruction%20Demo.ipynb) contains an example reconstruction in Python.
60 |
61 | [reconstruction_demo.m](https://github.com/Waller-Lab/SpectralDiffuserCam/blob/master/Matlab/reconstruction_demo.m) contains an example reconstruction in Matlab.
62 |
63 | We recommend running this code on a GPU, but it can also be run on a CPU (much slower!).
64 |
65 | ## Description
66 | This repository contains code in both Python and Matlab that is needed to process raw Spectral DiffuserCam images and reconstruct 3D hyperspectral volumes from the raw 2D measurements. Four example raw images are provided, along with the calibrated point spread function and spectral filter function. Both the Python and Matlab versions support GPU acceleration. In Python, this is accomplished using cupy. We use FISTA for our reconstructions with a 3D total variation prior.
--------------------------------------------------------------------------------
/Matlab/reconstruction_demo.m:
--------------------------------------------------------------------------------
%% Run Reconstruction for Spectral DiffuserCam
% Last update: 9/29/2020

addpath('helper_functions/')
addpath('SampleData/')

%% Load in Calibration data and PSF
load('calibration.mat')   % provides psf, mask and wavs (used below)
wavelengths = wavs;

im = double(imread('meas_dice.png'));

%% Pre-process images and data
% crop all data to the valid mask-pixels
c1 = 100; c2 = 419;
c3 = 80;  c4 = 539;

mask = mask(c1:c2, c3:c4, :);
psf = psf(c1:c2, c3:c4);
im = double(im(c1:c2, c3:c4));

% normalize PSF
psf = psf/norm(psf, 'fro');

% Remove the pixel defect from mask and image: zero a 5x5 window around the
% brightest pixel of the channel-summed mask.
mask_sum = sum(mask, 3);
[~, idx] = max(mask_sum(:));
[row, col] = ind2sub(size(mask_sum), idx);
mask(row-2:row+2, col-2:col+2, :) = 0;

im = im/max(max(im));
im(row-2:row+2, col-2:col+2, :) = 0;

%% Keep an even number of spectral channels
% NOTE(review): presumably required because tv3dApproxHaar pairs samples
% along every dimension, including wavelength. Confirm.
if mod(size(mask,3),2) ~= 0
    mask = mask(:, :, 1:end-1);
    wavelengths = wavelengths(1:end-1);
end
mask = single(mask);

%% Put everything on GPU (if using GPU)
opts.use_gpu = 1; % Change to 0 to put on CPU
if opts.use_gpu
    psf = gpuArray(single(psf));
    mask = gpuArray(mask);
    im = gpuArray(single(im));
end


%% Define reconstruction options (leave these alone for defaults)
name='thordog';
opts.fista_iters = 500;       % Number of FISTA iterations
opts.denoise_method = 'tv';   % options: 'tv', 'non-neg', 'native', 'tv_lowrank'

opts.tv_lambda = .003;        % TV tuning parameter (higher is more TV)
opts.tv_lambday = 1;          % TV tuning parameter in y (compared to x)
opts.tv_lambdaw = .01;        % TV tuning parameter in lambda (compared to x)
opts.lowrank_lambda = .00005; % Tuning parameter for the low-rank constraint

opts.display_every = 1;       % how often to display the reconstruction
opts.save_data = 1;           % save the intermediate recon images
opts.save_data_freq = 100;    % save data every 100 iterations

% Filename to save the results
filename_save = sprintf('saved_recons/%s_%s_lambda_%f_iterations_%f_', name, opts.denoise_method, opts.tv_lambda, opts.fista_iters);
opts.save_data_path = filename_save;


%% Run inverse solver
[xout, loss_list] = fista_spectral_3d(im, psf, mask, opts);


%% Display results
xout = fliplr(flipud(gather(xout)));
false_color = false_color_function(xout);
figure(), imshow(false_color); title('False-color reconstruction')

figure(), imshow3D(xout);
--------------------------------------------------------------------------------
/Python/helper_functions/helper_functions.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import matplotlib.pyplot as plt
3 | import matplotlib
4 | import scipy.io
5 | from IPython.core.display import display, HTML
6 | from ipywidgets import interact, widgets, fixed
7 |
8 | import sys
9 | sys.path.append('helper_functions/')
10 |
11 |
def plotf2(r, img, ttl, sz):
    """Display slice ``r`` of ``img`` (along the last axis) in grayscale.

    All slices share the intensity range [0, max(img)] so they can be
    compared. The figure is titled "<ttl> <r>" and resized to ``sz`` inches.
    """
    vmax = np.max(img)
    plt.title(ttl + ' {}'.format(r))
    plt.imshow(img[:, :, r], cmap="gray", vmin=0, vmax=vmax)
    plt.axis('off')
    plt.gcf().set_size_inches(sz)
    plt.show()
    return
24 |
def plt3D(img, title = '', size = (5,5)):
    """Interactive viewer that scrolls through the slices of a 3-D cube.

    A slider picks the index along the last axis of ``img``, and each slice
    is drawn with ``plotf2``. The slider only redraws on release: setting
    ``continuous_update`` on the slider itself is required, because extra
    keyword arguments to ``interact`` that are not parameters of ``plotf2``
    are ignored.

    Args:
        img: array whose last axis indexes slices.
        title: title prefix shown above each slice.
        size: figure size in inches.
    """
    n_slices = np.shape(img)[-1]
    slider = widgets.IntSlider(min=0, max=n_slices-1, step=1,
                               value=min(1, n_slices-1),
                               continuous_update=False)
    interact(plotf2,
             r=slider,
             img=fixed(img),
             ttl=fixed(title),
             sz=fixed(size))
33 |
def crop(x):
    """Return the central half of the first two axes of a 3-D array.

    For an (H, W, C) input this keeps rows [H//2//2, H//2//2 + H//2) and the
    analogous columns, i.e. it undoes a zero-pad of half the size per side.
    """
    keep_rows, keep_cols = x.shape[0] // 2, x.shape[1] // 2
    top, left = keep_rows // 2, keep_cols // 2
    return x[top:top + keep_rows, left:left + keep_cols, :]
44 |
def pre_plot(x):
    """Prepare an image for display: rotate by 180 degrees, scale so the
    peak is 1, and clip to [0, 1]."""
    rotated = np.flipud(np.fliplr(x))
    return np.clip(rotated / np.max(rotated), 0, 1)
50 |
51 |
def stack_rgb_opt(reflArray, opt = 'helper_functions/false_color_calib.mat', scaling = [1,1,2.5]):
    """Collapse a 64-channel hyperspectral cube into a false-color RGB image.

    Args:
        reflArray: (H, W, >=64) array; it is first scaled so its maximum is 1.
        opt: path to a .mat file with 'red', 'green' and 'blue' weight rows
            (indexed as [0, channel]).
        scaling: per-output-channel gain (R, G, B).

    Returns:
        (H, W, 3) array: for each color, the weighted sum of the first 64
        channels, times its gain, divided by 64.
    """
    calib = scipy.io.loadmat(opt)
    weights = [calib[key] for key in ('red', 'green', 'blue')]

    cube = reflArray/np.max(reflArray)
    rows, cols = cube.shape[0], cube.shape[1]
    channels = [np.zeros((rows, cols)) for _ in range(3)]

    for band in range(64):
        for c in range(3):
            channels[c] = channels[c] + cube[:, :, band]*weights[c][0, band]*scaling[c]

    channels = [ch/64. for ch in channels]
    return np.stack(channels, axis=2)
75 |
def preprocess(mask, psf, im):
    """Crop calibration data and a measurement to the valid region and normalize.

    Steps:
      * crop mask, psf and im to rows 100:420 and columns 80:540;
      * scale mask and im so their maximum is 1, and psf to unit norm;
      * zero a 5x5 window centred on the brightest pixel of the
        channel-summed mask (a calibration defect) in both mask and im.
        This matches the MATLAB demo's ``row-2:row+2`` inclusive window.

    The inputs are not modified (normalization creates new arrays).

    Returns:
        (mask, psf, im) after cropping and normalization.
    """
    # Crop indices
    c1 = 100; c2 = 420; c3 = 80; c4 = 540

    # Crop and normalize mask
    mask = mask[c1:c2, c3:c4, :]
    mask = mask/np.max(mask)

    # Crop and normalize PSF
    psf = psf[c1:c2, c3:c4]
    psf = psf/np.linalg.norm(psf)

    # Locate the defective pixel. Use a +3 (exclusive) stop so the window
    # is 5x5 and centred. Clamp the start at 0 so a defect near the border
    # does not yield a negative (wrap-around, empty) slice.
    mask_sum = np.sum(mask, 2)
    row, col = np.unravel_index(np.argmax(mask_sum, axis = None), mask_sum.shape)
    rows = slice(max(row - 2, 0), row + 3)
    cols = slice(max(col - 2, 0), col + 3)
    mask[rows, cols, :] = 0

    # Crop, normalize and remove the same defect from the measurement
    im = im[c1:c2, c3:c4]
    im = im/np.max(im)
    im[rows, cols] = 0
    return mask, psf, im
--------------------------------------------------------------------------------
/Matlab/helper_functions/fista_spectral_3d.m:
--------------------------------------------------------------------------------
function [xout, loss_list]=fista_spectral_3d(input_image, psf, mask, opts)
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% FISTA implementation for Spectral DiffuserCam
% Last update: 9/29/2020
%
% Reconstructs a volume x from one 2D measurement b with the forward model
%   b = sum_c mask(:,:,c) .* crop(psf conv x(:,:,c))
% (per-channel convolution on a zero-padded grid, then a masked sum over
% channels). The model is regularized by the prox selected in
% opts.denoise_method.
%
% Inputs
% input_image ............. measured 2D image, same size as mask(:,:,1)
% psf ..................... PSF, same size as input image
% mask .................... spectral filter function, Ny x Nx x n_filters
% opts..................... options struct
%   opts.fista_iters ....... number of FISTA iterations
%   opts.denoise_method .... 'tv', 'tv_lowrank', 'native' or 'non-neg'
%   opts.tv_lambda ......... amount of TV (also the soft-threshold weight for 'native')
%   opts.tv_lambday ........ relative TV weight along dim 1 (y)
%   opts.tv_lambdaw ........ relative TV weight along dim 3 (wavelength)
%   opts.lowrank_lambda .... low-rank weight ('tv_lowrank' only)
%   opts.use_gpu ........... forwarded to tlv when evaluating the TV loss
%   opts.display_every ..... plot progress every this many iterations
%   opts.save_data ......... 1 to save parameters and intermediate iterates
%   opts.save_data_freq .... save the iterate every this many iterations
%   opts.save_data_path .... folder that receives the saved .mat files
%
% Outputs
% xout .................... reconstructed volume, cropped to Ny x Nx x n_filters
% loss_list ............... loss at every iteration


figure(2020)

[Ny, Nx, n_filters] = size(mask); %Get problem size

% Setup convolutional forward op
p1 = floor(Ny/2);
p2 = floor(Nx/2);
pad2d = @(x)padarray(x,[p1,p2],'both'); %2D padding
crop2d = @(x)x(p1+1:end-p1,p2+1:end-p2,:); %2D cropping

Hs = fftn(ifftshift(pad2d(psf))); % spectrum of the padded PSF
Hs_conj = conj(Hs);
Py = size(Hs, 1);  % padded grid size (equals 2*Ny only when Ny is even)
Px = size(Hs, 2);

% Forward model re-padded to the full grid. It is only used to estimate
% the step size via power iteration.
Hfor_modified = @(x)pad2d(sum(mask.*crop2d(real((ifftn(Hs.*fftn((x)))))),3));

Hfor = @(x)sum(mask.*crop2d(real((ifftn(Hs.*fftn((x)))))),3);
Hadj = @(x)real((ifftn(Hs_conj.*fftn((pad2d(repmat(x, [1,1,size(mask, 3)]).*mask))))));

maxeig = power_iteration(Hfor_modified, pad2d(psf), 10);
L = maxeig*45;  % NOTE(review): empirical safety factor on the Lipschitz estimate; origin not documented here


lambda = opts.tv_lambda;
tv_type = 'iso';          % TV flavour used by tlv in the loss
use_gpu = opts.use_gpu;

switch opts.denoise_method
    case 'tv'
        prox = @(x)(1/2*(max(x,0) + (tv3dApproxHaar(x, opts.tv_lambda/L, opts.tv_lambday, opts.tv_lambdaw))));
        loss = @(err, x) norm(err,'fro')^2 + 2*lambda/L*tlv(x, tv_type, use_gpu);
    case 'tv_lowrank'
        % NOTE(review): soft_thresh_lowrank is not among the visible helper
        % functions. Confirm it is on the path before using this option.
        prox = @(x)(1/3*(max(x,0) + tv3dApproxHaar(x, opts.tv_lambda/L, opts.tv_lambday, opts.tv_lambdaw)+ soft_thresh_lowrank(x, opts.lowrank_lambda)));
        loss = @(err, x) norm(err,'fro')^2 + 2*lambda/L*tlv(x, tv_type, use_gpu);
    case 'native'
        prox = @(x) (1/2 * (max(x,0) + soft_thresh(x, opts.tv_lambda/L)));
        loss = @(err, x) norm(err,'fro')^2;
    case 'non-neg'
        prox = @(x)max(x,0);
        loss = @(err, x) norm(err,'fro')^2;
    otherwise
        error('fista_spectral_3d:unknownMethod', ...
            'Unknown opts.denoise_method "%s". Use tv, tv_lowrank, native or non-neg.', ...
            opts.denoise_method);
end


if opts.save_data == 1
    if ~exist(opts.save_data_path, 'dir')
        mkdir(opts.save_data_path)
    end
    y_save = gather(input_image);
    mask_save = gather(mask);
    psf_save = gather(psf);

    filename = sprintf('%s/params.mat', opts.save_data_path);
    save(filename, 'y_save', 'psf_save', 'mask_save', 'opts', 'L');
end

%% Start FISTA
xk = zeros(Py, Px);
vk = zeros(Py, Px);
tk = 1.0;

loss_list = [];

for i=1:opts.fista_iters
    xold = xk;
    told = tk;

    % Gradient step on the data term, evaluated at the momentum point vk.
    % (Named err so the builtin error() is not shadowed.)
    err = Hfor(vk) - input_image;
    grads = Hadj(err);

    xk = prox(vk - 1/L*grads);

    % FISTA momentum: t_{k+1} = (1 + sqrt(1 + 4 t_k^2)) / 2
    tk = (1 + sqrt(1+4*told^2))/2;
    vk = xk + (told-1)/tk *(xk - xold);

    loss_i = loss(err, xk);
    loss_list = [loss_list, loss_i]; %#ok<AGROW>


    if mod(i,opts.display_every) == 0
        subplot(1,4,1)

        x_print = gather(xk);
        x_print = fliplr(flipud((x_print)));
        false_color = false_color_function(x_print);
        imagesc(false_color);
        title('False color image')

        subplot(1,4,2)
        imagesc(reshape(sum(x_print, 2), Py, n_filters));
        title('Y-Lambda sum projection')

        subplot(1,4,3)
        imagesc(reshape(sum(x_print, 1), Px, n_filters));
        title('X-Lambda sum projection')

        subplot(1,4,4),
        semilogy(loss_list);
        title('Loss')
        drawnow
    end

    if mod(i, opts.save_data_freq) ==0 && opts.save_data == 1
        filename = sprintf('%s/iter%i.mat', opts.save_data_path, i);
        fprintf('Saving %s\n', filename);
        x_save = gather(xk);
        save(filename, 'x_save');
    end

end

xout = crop2d(xk);

end
145 |
--------------------------------------------------------------------------------
/Python/fista_spectral_cupy.py:
--------------------------------------------------------------------------------
import sys

# Select the array backend from the first command-line argument:
# 'GPU' -> cupy, anything else -> numpy. Fall back to the CPU backend when no
# argument is given instead of raising IndexError. (Module-level code needs
# no `global` statement; `device` is a module global either way.)
device = sys.argv[1] if len(sys.argv) > 1 else 'CPU'
sys.path.append('helper_functions/')

if device == 'GPU':
    import cupy as np
    import tv_approx_haar_cp as tv

    print('device = ', device, ', using GPU and cupy')
else:
    import numpy as np
    import tv_approx_haar_np as tv
    print('device = ', device, ', using CPU and numpy')

import helper_functions.helper_functions as fc
import numpy as numpy
import matplotlib.pyplot as plt
18 | import matplotlib.pyplot as plt
19 |
20 |
21 |
22 | class fista_spectral_numpy():
23 | def __init__(self, h, mask):
24 |
25 | ## Initialize constants
26 | self.DIMS0 = h.shape[0] # Image Dimensions
27 | self.DIMS1 = h.shape[1] # Image Dimensions
28 |
29 | self.spectral_channels = mask.shape[-1] # Number of spectral channels
30 |
31 | self.py = int((self.DIMS0)//2) # Pad size
32 | self.px = int((self.DIMS1)//2) # Pad size
33 |
34 | # FFT of point spread function
35 | self.H = np.expand_dims(np.fft.fft2((np.fft.ifftshift(self.pad(h), axes = (0,1))), axes = (0,1)), -1)
36 | self.Hconj = np.conj(self.H)
37 |
38 | self.mask = mask
39 |
40 | # Calculate the eigenvalue to set the step size
41 | maxeig = self.power_iteration(self.Hpower, (self.DIMS0*2, self.DIMS1*2), 10)
42 | self.L = maxeig*45
43 |
44 |
45 | self.prox_method = 'tv' # options: 'non-neg', 'tv', 'native'
46 |
47 | # Define soft-thresholding constants
48 | self.tau = .5 # Native sparsity tuning parameter
49 | self.tv_lambda = 0.00005 # TV tuning parameter
50 | self.tv_lambdaw = 0.00005 # TV tuning parameter for wavelength
51 | self.lowrank_lambda = 0.00005 # Low rank tuning parameter
52 |
53 |
54 | # Number of iterations of FISTA
55 | self.iters = 500
56 |
57 | self.show_recon_progress = True # Display the intermediate results
58 | self.print_every = 20 # Sets how often to print the image
59 |
60 | self.l_data = []
61 | self.l_tv = []
62 |
63 | # Power iteration to calculate eigenvalue
64 | def power_iteration(self, A, sample_vect_shape, num_iters):
65 | bk = np.random.randn(sample_vect_shape[0], sample_vect_shape[1])
66 | for i in range(0, num_iters):
67 | bk1 = A(bk)
68 | bk1_norm = np.linalg.norm(bk1)
69 |
70 | bk = bk1/bk1_norm
71 | Mx = A(bk)
72 | xx = np.transpose(np.dot(bk.ravel(), bk.ravel()))
73 | eig_b = np.transpose(bk.ravel()).dot(Mx.ravel())/xx
74 |
75 | return eig_b
76 |
77 | # Helper functions for forward model
78 | def crop(self,x):
79 | return x[self.py:-self.py, self.px:-self.px]
80 |
81 | def pad(self,x):
82 | if len(x.shape) == 2:
83 | out = np.pad(x, ([self.py, self.py], [self.px,self.px]), mode = 'constant')
84 | elif len(x.shape) == 3:
85 | out = np.pad(x, ([self.py, self.py], [self.px,self.px], [0, 0]), mode = 'constant')
86 | return out
87 |
88 | def Hpower(self, x):
89 | x = np.fft.ifft2(self.H* np.fft.fft2(np.expand_dims(x,-1), axes = (0,1)), axes = (0,1))
90 | x = np.sum(self.mask* self.crop(np.real(x)), 2)
91 | x = self.pad(x)
92 | return x
93 |
94 | def Hfor(self, x):
95 | x = np.fft.ifft2(self.H* np.fft.fft2(x, axes = (0,1)), axes = (0,1))
96 | x = np.sum(self.mask* self.crop(np.real(x)), 2)
97 | return x
98 |
99 | def Hadj(self, x):
100 | x = np.expand_dims(x,-1)
101 | x = x*self.mask
102 | x = self.pad(x)
103 |
104 | x = np.fft.fft2(x, axes = (0,1))
105 | x = np.fft.ifft2(self.Hconj*x, axes = (0,1))
106 | x = np.real(x)
107 | return x
108 |
109 | def soft_thresh(self, x, tau):
110 | out = np.maximum(np.abs(x)- tau, 0)
111 | out = out*np.sign(x)
112 | return out
113 |
114 | def prox(self,x):
115 | if self.prox_method == 'tv':
116 | x = 0.5*(np.maximum(x,0) + tv.tv3dApproxHaar(x, self.tv_lambda/self.L, self.tv_lambdaw))
117 | if self.prox_method == 'native':
118 | x = np.maximum(x,0) + self.soft_thresh(x, self.tau)
119 | if self.prox_method == 'non-neg':
120 | x = np.maximum(x,0)
121 | return x
122 |
def tv(self, x):
    """Isotropic total variation of ``x`` over its first two axes.

    At every position, the squared forward differences along axis 0 and
    axis 1 are added (a difference is taken as 0 on the last row/column).
    The square roots of these sums are added over the whole array, including
    any trailing axes.
    """
    sq = np.zeros_like(x)
    sq[:-1] = np.diff(x, axis=0) ** 2
    sq[:, :-1] += np.diff(x, axis=1) ** 2
    return np.sum(np.sqrt(sq))
128 |
def loss(self, x, err):
    """Objective value reported for one FISTA step.

    Args:
        x: current reconstruction (``fista_update`` passes its proximal
            output ``xup``).
        err: residual ``Hfor(vk) - inputs`` from the same step.

    Returns:
        ``||err||^2`` plus the regulariser selected by ``self.prox_method``:
        ``'tv'`` adds ``2*tv_lambda/L * tv(x)`` (both terms are also appended
        to ``self.l_data`` / ``self.l_tv``), ``'native'`` adds
        ``2*tv_lambda/L * ||x||_1``, and ``'non-neg'`` adds nothing.

    Raises:
        ValueError: for any other ``prox_method``. Previously this surfaced as
            an UnboundLocalError.
    """
    # The data term and (below) the TV term are each evaluated once and reused.
    # The original recomputed both the norm and the full-volume TV.
    data_term = np.linalg.norm(err) ** 2
    method = self.prox_method
    if method == 'tv':
        tv_term = 2 * self.tv_lambda / self.L * self.tv(x)
        self.l_data.append(data_term)
        self.l_tv.append(tv_term)
        return data_term + tv_term
    if method == 'native':
        # NOTE(review): weighted by tv_lambda, while prox() thresholds with self.tau; confirm.
        return data_term + 2 * self.tv_lambda / self.L * np.linalg.norm(x.ravel(), 1)
    if method == 'non-neg':
        return data_term
    raise ValueError("Unknown prox_method %r; expected 'tv', 'native' or 'non-neg'" % (method,))
140 |
141 | # Main FISTA update
def fista_update(self, vk, tk, xk, inputs):
    """One FISTA (accelerated proximal-gradient) iteration.

    Args:
        vk: momentum (extrapolated) point at which the gradient is taken.
        tk: current momentum scalar (``run`` starts it at 1.0).
        xk: previous iterate.
        inputs: measurement that ``Hfor(vk)`` is compared against.

    Returns:
        ``(vup, tup, xup, loss)``: the next momentum point, the next momentum
        scalar, the new iterate and ``self.loss(xup, error)``. The reported
        loss pairs the residual at ``vk`` with the regulariser at ``xup``.
    """
    error = self.Hfor(vk) - inputs
    grads = self.Hadj(error)

    # Gradient step of size 1/L followed by the proximal operator.
    xup = self.prox(vk - 1 / self.L * grads)
    # FISTA momentum update t_{k+1} = (1 + sqrt(1 + 4 t_k^2)) / 2.
    # The whole numerator must be halved: the previous
    # `1 + np.sqrt(...)/2` over-grew t_k and deviated from FISTA.
    tup = (1 + np.sqrt(1 + 4 * tk ** 2)) / 2
    vup = xup + (tk - 1) / tup * (xup - xk)

    return vup, tup, xup, self.loss(xup, error)
152 |
153 |
154 | # Run FISTA
def run(self, inputs):
    """Run ``self.iters`` FISTA iterations on the measurement ``inputs``.

    The iterate ``xk`` and the momentum point ``vk`` both start as zero
    arrays of shape ``(2*DIMS0, 2*DIMS1, spectral_channels)``; ``tk`` starts
    at 1.0. If ``self.show_recon_progress == True``, then every
    ``self.print_every`` iterations the loss is printed and a false-colour
    preview plus the loss curve are plotted. The cropped preview is also
    stored in ``self.out_img``.

    Returns:
        ``(xout, llist)``: the cropped final iterate and the list of
        per-iteration losses returned by ``fista_update``.
    """
    volume_shape = (self.DIMS0 * 2, self.DIMS1 * 2, self.spectral_channels)
    xk = np.zeros(volume_shape)
    vk = np.zeros(volume_shape)
    tk = 1.0
    losses = []

    for it in range(self.iters):
        vk, tk, xk, l = self.fista_update(vk, tk, xk, inputs)
        losses.append(l)

        # Progress display is optional; print_every is only consulted when it is enabled.
        if not (self.show_recon_progress == True and it % self.print_every == 0):
            continue

        print('iteration: ', it, ' loss: ', l)
        # np.asnumpy copies to host memory; this assumes np is cupy when device == 'GPU'.
        out_img = np.asnumpy(self.crop(xk)) if device == 'GPU' else self.crop(xk)
        fc_img = fc.pre_plot(fc.stack_rgb_opt(out_img))
        plt.figure(figsize=(10, 3))
        plt.subplot(1, 2, 1)
        plt.imshow(fc_img / numpy.max(fc_img))
        plt.title('Reconstruction')
        plt.subplot(1, 2, 2)
        plt.plot(losses)
        plt.title('Loss')
        plt.show()
        self.out_img = out_img

    return self.crop(xk), losses
185 |
--------------------------------------------------------------------------------
/Matlab/helper_functions/imshow3D.m:
--------------------------------------------------------------------------------
1 | function imshow3D( Img, disprange ) % Img is indexed as Img(:,:,slice,:); disprange is optional
2 | %IMSHOW3D displays 3D grayscale or RGB images in slice by slice fashion
3 | %with mouse based slice browsing and window and level adjustment control.
4 | %
5 | % Usage:
6 | % imshow3D ( Image )
7 | % imshow3D ( Image , [] )
8 | % imshow3D ( Image , [LOW HIGH] )
9 | %
10 | % Image: 3D image MxNxKxC (K slices of MxN images) C is either 1
11 | % (for grayscale images) or 3 (for RGB images)
12 | % [LOW HIGH]: display range that controls the display intensity range of
13 | % a grayscale image (default: the full range of integer classes, [0 1] for logical, otherwise [0 max(Image(:))])
14 | %
15 | % Use the scroll bar or mouse scroll wheel to switch between slices. To
16 | % adjust window and level values keep the mouse right button pressed and
17 | % drag the mouse up and down (for level adjustment) or right and left (for
18 | % window adjustment). Window and level adjustment control works only for
19 | % grayscale images.
20 | %
21 | % "Auto W/L" button adjusts the window and level automatically for grayscale
22 | % images
23 | %
24 | % While "Fine Tune" check box is checked the window/level adjustment gets
25 | % 16 times less sensitive to mouse movement, to make it easier to control
26 | % display intensity range.
27 | %
28 | % Note: The sensitivity of mouse based window and level adjustment is set
29 | % based on the user defined display intensity range; the wider the range
30 | % the more sensitive it is to mouse drag.
31 | %
32 | % Note: IMSHOW3DFULL is a newer version of IMSHOW3D (also available on
33 | % MathWorks) that displays 3D grayscale or RGB images from three
34 | % perpendicular views (i.e. axial, sagittal, and coronal).
35 | %
36 | % Example
37 | % --------
38 | % % Display an image (MRI example)
39 | % load mri
40 | % Image = squeeze(D);
41 | % figure,
42 | % imshow3D(Image)
43 | %
44 | % % Display the image, adjust the display range
45 | % figure,
46 | % imshow3D(Image,[20 100]);
47 | %
48 | % See also IMSHOW.
49 |
50 | %
51 | % - Maysam Shahedi (mshahedi@gmail.com)
52 | % - Released: 1.0.0 Date: 2013/04/15
53 | % - Revision: 1.1.0 Date: 2013/04/19
54 | % - Revision: 1.5.0 Date: 2016/09/22
55 | %
56 |
57 | sno = size(Img,3); % number of slices
58 | S = round(sno/2); % index of the slice currently shown; starts at the middle slice
59 |
60 | global InitialCoord; % pointer location at the last right-click or drag step (set in mouseClick and WinLevAdj)
61 |
62 | MinV = 0; % default window spans [0, max(Img(:))]
63 | MaxV = max(Img(:));
64 | LevV = (double( MaxV) + double(MinV)) / 2; % level = centre of the display range
65 | Win = double(MaxV) - double(MinV); % window = width of the display range
66 | WLAdjCoe = (Win + 1)/1024; % mouse-drag sensitivity, proportional to the window width
67 | FineTuneC = [1 1/16]; % Regular/Fine-tune mode coefficients
68 |
69 | if isa(Img,'uint8') % integer classes: use the full range of the class (xintN(+/-Inf) saturates to its limits)
70 | MaxV = uint8(Inf);
71 | MinV = uint8(-Inf);
72 | LevV = (double( MaxV) + double(MinV)) / 2;
73 | Win = double(MaxV) - double(MinV);
74 | WLAdjCoe = (Win + 1)/1024;
75 | elseif isa(Img,'uint16')
76 | MaxV = uint16(Inf);
77 | MinV = uint16(-Inf);
78 | LevV = (double( MaxV) + double(MinV)) / 2;
79 | Win = double(MaxV) - double(MinV);
80 | WLAdjCoe = (Win + 1)/1024;
81 | elseif isa(Img,'uint32')
82 | MaxV = uint32(Inf);
83 | MinV = uint32(-Inf);
84 | LevV = (double( MaxV) + double(MinV)) / 2;
85 | Win = double(MaxV) - double(MinV);
86 | WLAdjCoe = (Win + 1)/1024;
87 | elseif isa(Img,'uint64')
88 | MaxV = uint64(Inf);
89 | MinV = uint64(-Inf);
90 | LevV = (double( MaxV) + double(MinV)) / 2;
91 | Win = double(MaxV) - double(MinV);
92 | WLAdjCoe = (Win + 1)/1024;
93 | elseif isa(Img,'int8')
94 | MaxV = int8(Inf);
95 | MinV = int8(-Inf);
96 | LevV = (double( MaxV) + double(MinV)) / 2;
97 | Win = double(MaxV) - double(MinV);
98 | WLAdjCoe = (Win + 1)/1024;
99 | elseif isa(Img,'int16')
100 | MaxV = int16(Inf);
101 | MinV = int16(-Inf);
102 | LevV = (double( MaxV) + double(MinV)) / 2;
103 | Win = double(MaxV) - double(MinV);
104 | WLAdjCoe = (Win + 1)/1024;
105 | elseif isa(Img,'int32')
106 | MaxV = int32(Inf);
107 | MinV = int32(-Inf);
108 | LevV = (double( MaxV) + double(MinV)) / 2;
109 | Win = double(MaxV) - double(MinV);
110 | WLAdjCoe = (Win + 1)/1024;
111 | elseif isa(Img,'int64')
112 | MaxV = int64(Inf);
113 | MinV = int64(-Inf);
114 | LevV = (double( MaxV) + double(MinV)) / 2;
115 | Win = double(MaxV) - double(MinV);
116 | WLAdjCoe = (Win + 1)/1024;
117 | elseif isa(Img,'logical')
118 | MaxV = 0; % NOTE(review): MaxV and MinV look swapped for logical images; harmless because neither is read after this if-block
119 | MinV = 1;
120 | LevV =0.5;
121 | Win = 1;
122 | WLAdjCoe = 0.1;
123 | end
124 |
125 | SFntSz = 9; % font sizes: slice text, level and window labels, level and window edit boxes, button, checkbox
126 | LFntSz = 10;
127 | WFntSz = 10;
128 | LVFntSz = 9;
129 | WVFntSz = 9;
130 | BtnSz = 10;
131 | ChBxSz = 10;
132 |
133 | if (nargin < 2) % no display range supplied: use the default window and level computed above
134 | [Rmin Rmax] = WL2R(Win, LevV);
135 | elseif numel(disprange) == 0 % an empty [] also selects the default
136 | [Rmin Rmax] = WL2R(Win, LevV);
137 | else % explicit [LOW HIGH]: level = its centre, window = its width
138 | LevV = (double(disprange(2)) + double(disprange(1))) / 2;
139 | Win = double(disprange(2)) - double(disprange(1));
140 | WLAdjCoe = (Win + 1)/1024;
141 | [Rmin Rmax] = WL2R(Win, LevV);
142 | end
143 |
144 | axes('position',[0,0.2,1,0.8]), imshow(squeeze(Img(:,:,S,:)), [Rmin Rmax]) % image in the upper part of the figure, controls in the bottom strip
145 |
146 | FigPos = get(gcf,'Position'); % the control positions below depend on the figure width
147 | S_Pos = [50 45 uint16(FigPos(3)-100)+1 20];
148 | Stxt_Pos = [50 65 uint16(FigPos(3)-100)+1 15];
149 | Wtxt_Pos = [50 20 60 20];
150 | Wval_Pos = [110 20 60 20];
151 | Ltxt_Pos = [175 20 45 20];
152 | Lval_Pos = [220 20 60 20];
153 | BtnStPnt = uint16(FigPos(3)-250)+1; % x position of the Auto W/L button
154 | if BtnStPnt < 300 % keep the button right of the level edit box, which ends at x = 280
155 | BtnStPnt = 300;
156 | end
157 | Btn_Pos = [BtnStPnt 20 100 20];
158 | ChBx_Pos = [BtnStPnt+110 20 100 20];
159 |
160 | if sno > 1 % the slice slider is only created for multi-slice input
161 | shand = uicontrol('Style', 'slider','Min',1,'Max',sno,'Value',S,'SliderStep',[1/(sno-1) 10/(sno-1)],'Position', S_Pos,'Callback', {@SliceSlider, Img}); % minor step = 1 slice, major step = 10 slices
162 | stxthand = uicontrol('Style', 'text','Position', Stxt_Pos,'String',sprintf('Slice# %d / %d',S, sno), 'BackgroundColor', [0.8 0.8 0.8], 'FontSize', SFntSz);
163 | else
164 | stxthand = uicontrol('Style', 'text','Position', Stxt_Pos,'String','2D image', 'BackgroundColor', [0.8 0.8 0.8], 'FontSize', SFntSz);
165 | end
166 | ltxthand = uicontrol('Style', 'text','Position', Ltxt_Pos,'String','Level: ', 'BackgroundColor', [0.8 0.8 0.8], 'FontSize', LFntSz);
167 | wtxthand = uicontrol('Style', 'text','Position', Wtxt_Pos,'String','Window: ', 'BackgroundColor', [0.8 0.8 0.8], 'FontSize', WFntSz);
168 | lvalhand = uicontrol('Style', 'edit','Position', Lval_Pos,'String',sprintf('%6.0f',LevV), 'BackgroundColor', [1 1 1], 'FontSize', LVFntSz,'Callback', @WinLevChanged);
169 | wvalhand = uicontrol('Style', 'edit','Position', Wval_Pos,'String',sprintf('%6.0f',Win), 'BackgroundColor', [1 1 1], 'FontSize', WVFntSz,'Callback', @WinLevChanged);
170 | Btnhand = uicontrol('Style', 'pushbutton','Position', Btn_Pos,'String','Auto W/L', 'FontSize', BtnSz, 'Callback' , @AutoAdjust);
171 | ChBxhand = uicontrol('Style', 'checkbox','Position', ChBx_Pos,'String','Fine Tune', 'BackgroundColor', [0.8 0.8 0.8], 'FontSize', ChBxSz);
172 |
173 | set (gcf, 'WindowScrollWheelFcn', @mouseScroll); % scroll wheel browses slices
174 | set (gcf, 'ButtonDownFcn', @mouseClick); % a right click starts window and level dragging
175 | set(get(gca,'Children'),'ButtonDownFcn', @mouseClick); % also when clicking on the image itself
176 | set(gcf,'WindowButtonUpFcn', @mouseRelease) % releasing the button stops dragging
177 | set(gcf,'ResizeFcn', @figureResized)
178 | % The callbacks below are nested functions: they read and update the variables of imshow3D (S, Win, LevV, Rmin, Rmax and the control handles) directly.
179 |
180 | % -=< Figure resize callback function >=-
181 | function figureResized(object, eventdata) % recompute the width-dependent positions and re-apply all control positions
182 | FigPos = get(gcf,'Position');
183 | S_Pos = [50 45 uint16(FigPos(3)-100)+1 20];
184 | Stxt_Pos = [50 65 uint16(FigPos(3)-100)+1 15];
185 | BtnStPnt = uint16(FigPos(3)-250)+1;
186 | if BtnStPnt < 300
187 | BtnStPnt = 300;
188 | end
189 | Btn_Pos = [BtnStPnt 20 100 20];
190 | ChBx_Pos = [BtnStPnt+110 20 100 20];
191 | if sno > 1
192 | set(shand,'Position', S_Pos);
193 | end
194 | set(stxthand,'Position', Stxt_Pos);
195 | set(ltxthand,'Position', Ltxt_Pos);
196 | set(wtxthand,'Position', Wtxt_Pos);
197 | set(lvalhand,'Position', Lval_Pos);
198 | set(wvalhand,'Position', Wval_Pos);
199 | set(Btnhand,'Position', Btn_Pos);
200 | set(ChBxhand,'Position', ChBx_Pos);
201 | end
202 |
203 | % -=< Slice slider callback function >=-
204 | function SliceSlider (hObj,event, Img) % show the slice selected by the slider and re-apply the display range
205 | S = round(get(hObj,'Value')); % updates the S shared with imshow3D
206 | set(get(gca,'children'),'cdata',squeeze(Img(:,:,S,:)))
207 | caxis([Rmin Rmax])
208 | if sno > 1
209 | set(stxthand, 'String', sprintf('Slice# %d / %d',S, sno));
210 | else
211 | set(stxthand, 'String', '2D image');
212 | end
213 | end
214 |
215 | % -=< Mouse scroll wheel callback function >=-
216 | function mouseScroll (object, eventdata) % move S by the VerticalScrollCount of the event, clamped to [1, sno]
217 | UPDN = eventdata.VerticalScrollCount;
218 | S = S - UPDN;
219 | if (S < 1)
220 | S = 1;
221 | elseif (S > sno)
222 | S = sno;
223 | end
224 | if sno > 1
225 | set(shand,'Value',S);
226 | set(stxthand, 'String', sprintf('Slice# %d / %d',S, sno));
227 | else
228 | set(stxthand, 'String', '2D image');
229 | end
230 | set(get(gca,'children'),'cdata',squeeze(Img(:,:,S,:)))
231 | end
232 |
233 | % -=< Mouse button released callback function >=-
234 | function mouseRelease (object,eventdata) % stop tracking pointer motion for window and level dragging
235 | set(gcf, 'WindowButtonMotionFcn', '')
236 | end
237 |
238 | % -=< Mouse click callback function >=-
239 | function mouseClick (object, eventdata) % on an alt (right) click, remember the pointer and start tracking motion
240 | MouseStat = get(gcbf, 'SelectionType');
241 | if (MouseStat(1) == 'a') % RIGHT CLICK
242 | InitialCoord = get(0,'PointerLocation');
243 | set(gcf, 'WindowButtonMotionFcn', @WinLevAdj);
244 | end
245 | end
246 |
247 | % -=< Window and level mouse adjustment >=-
248 | function WinLevAdj(varargin) % horizontal drag changes the window, vertical drag changes the level
249 | PosDiff = get(0,'PointerLocation') - InitialCoord;
250 |
251 | Win = Win + PosDiff(1) * WLAdjCoe * FineTuneC(get(ChBxhand,'Value')+1); % Fine Tune checked selects the 1/16 coefficient
252 | LevV = LevV - PosDiff(2) * WLAdjCoe * FineTuneC(get(ChBxhand,'Value')+1);
253 | if (Win < 1) % keep the window at least 1 wide
254 | Win = 1;
255 | end
256 |
257 | [Rmin, Rmax] = WL2R(Win,LevV);
258 | caxis([Rmin, Rmax])
259 | set(lvalhand, 'String', sprintf('%6.0f',LevV));
260 | set(wvalhand, 'String', sprintf('%6.0f',Win));
261 | InitialCoord = get(0,'PointerLocation'); % make the next motion event relative to this one
262 | end
263 |
264 | % -=< Window and level text adjustment >=-
265 | function WinLevChanged(varargin) % apply the window and level values typed into the edit boxes
266 |
267 | LevV = str2double(get(lvalhand, 'string')); % NOTE(review): non-numeric text gives NaN, which is not rejected before caxis
268 | Win = str2double(get(wvalhand, 'string'));
269 | if (Win < 1)
270 | Win = 1;
271 | end
272 |
273 | [Rmin, Rmax] = WL2R(Win,LevV);
274 | caxis([Rmin, Rmax])
275 | end
276 |
277 | % -=< Window and level to range conversion >=-
278 | function [Rmn Rmx] = WL2R(W,L) % window width W and level L give the display range [L-W/2, L+W/2]
279 | Rmn = L - (W/2);
280 | Rmx = L + (W/2);
281 | if (Rmn >= Rmx) % guarantee a non-empty range
282 | Rmx = Rmn + 1;
283 | end
284 | end
285 |
286 | % -=< Window and level auto adjustment callback function >=-
287 | function AutoAdjust(object,eventdata) % fit the window to the data range [min(Img(:)), max(Img(:))]
288 | Win = double(max(Img(:))-min(Img(:)));
289 | Win (Win < 1) = 1;
290 | LevV = double(min(Img(:)) + (Win/2));
291 | [Rmin, Rmax] = WL2R(Win,LevV);
292 | caxis([Rmin, Rmax])
293 | set(lvalhand, 'String', sprintf('%6.0f',LevV));
294 | set(wvalhand, 'String', sprintf('%6.0f',Win));
295 | end
296 |
297 | end
298 | % -=< Maysam Shahedi (mshahedi@gmail.com), September 22, 2016>=-
--------------------------------------------------------------------------------