├── Matlab
    ├── saved_recons
    │   └── readme.txt
    ├── helper_functions
    │   ├── soft_thresh.m
    │   ├── Ltrans.m
    │   ├── false_color_calib.mat
    │   ├── power_iteration.m
    │   ├── false_color_function.m
    │   ├── tlv.m
    │   ├── tv3dApproxHaar.m
    │   ├── fista_spectral_3d.m
    │   └── imshow3D.m
    ├── SampleData
    │   └── readme.txt
    └── reconstruction_demo.m
├── Python
    ├── SampleData
    │   └── readme.txt
    ├── helper_functions
    │   ├── false_color_calib.mat
    │   ├── tv_approx_haar_cp.py
    │   ├── tv_approx_haar_np.py
    │   └── helper_functions.py
    └── fista_spectral_cupy.py
├── environment.yml
├── .gitignore
├── LICENSE
└── README.md


/Matlab/saved_recons/readme.txt:
--------------------------------------------------------------------------------
1 | Saved reconstructions will be saved here.

--------------------------------------------------------------------------------
/Matlab/helper_functions/soft_thresh.m:
--------------------------------------------------------------------------------
1 | function out = soft_thresh(x,tau)
2 | % SOFT_THRESH  Elementwise soft-thresholding (shrinkage).
3 | %   out = sign(x) .* max(abs(x) - tau, 0), the proximal operator of
4 | %   tau*||x||_1. Works elementwise on arrays of any size.
5 |
6 | out = max(abs(x)-tau,0);
7 | out = out.*sign(x);

--------------------------------------------------------------------------------
/Matlab/SampleData/readme.txt:
--------------------------------------------------------------------------------
1 | The sample data can be found here: https://drive.google.com/drive/folders/1dmfzkTLFZZFUYW8GC6Vn6SOuZiZq47SS?usp=sharing

--------------------------------------------------------------------------------
/Python/SampleData/readme.txt:
--------------------------------------------------------------------------------
1 | The sample data can be found here: https://drive.google.com/drive/folders/1dmfzkTLFZZFUYW8GC6Vn6SOuZiZq47SS?usp=sharing

--------------------------------------------------------------------------------
/Matlab/helper_functions/Ltrans.m:
--------------------------------------------------------------------------------
1 | function P=Ltrans(X)
2 | % LTRANS  Forward differences of X along rows and columns.
3 | %   P{1}(i,j) = X(i,j) - X(i+1,j)    (size (m-1) x n)
4 | %   P{2}(i,j) = X(i,j) - X(i,j+1)    (size m x (n-1))
5 | %   NOTE(review): for N-D input, [m,n]=size(X) folds the trailing
6 | %   dimensions into n, so the column differences also cross the seams
7 | %   between slices.
8 |
9 | [m,n]=size(X);
10 |
11 | P{1}=X(1:m-1,:)-X(2:m,:);
12 | P{2}=X(:,1:n-1)-X(:,2:n);
13 |

--------------------------------------------------------------------------------
/Matlab/helper_functions/false_color_calib.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Waller-Lab/SpectralDiffuserCam/HEAD/Matlab/helper_functions/false_color_calib.mat

--------------------------------------------------------------------------------
/Python/helper_functions/false_color_calib.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Waller-Lab/SpectralDiffuserCam/HEAD/Python/helper_functions/false_color_calib.mat

--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | name: SpectralDiffuserCam
2 | channels:
3 |   - defaults
4 | dependencies:
5 |   - anaconda
6 |   - cupy
7 |

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Jupyter Notebook
2 | .ipynb_checkpoints
3 | */.ipynb_checkpoints/*
4 | ipython_config.py
5 |
6 | # pyenv
7 | .python-version
8 |
9 | # Ignore image files and numpy files
10 | *.png
11 | *.jpg
12 | *.jpeg
13 | *.tiff
14 | *.npy
15 | calibration.mat
16 | __pycache__

--------------------------------------------------------------------------------
/Matlab/helper_functions/power_iteration.m:
--------------------------------------------------------------------------------
1 | function eig_b = power_iteration(A, sample_vect, num_iters)
2 | % POWER_ITERATION  Estimate the dominant eigenvalue of a linear operator.
3 | %   A            function handle; must map an array of size(sample_vect)
4 | %                to an array of the same size
5 | %   sample_vect  only its size is used, to shape the random start vector
6 | %   num_iters    number of power iterations
7 | %   eig_b        Rayleigh quotient bk'*A(bk) / (bk'*bk)
8 | bk = rand(size(sample_vect));
9 | for i=1:num_iters
10 |     bk1 = A(bk);
11 |     % 2-norm of all entries (Frobenius). norm() of a matrix would compute
12 |     % the spectral norm via an SVD; the scale does not affect the final
13 |     % Rayleigh quotient, so the cheap norm is sufficient.
14 |     bk1_norm = norm(bk1(:));
15 |     bk = bk1./bk1_norm;
16 | end
17 |
18 | Mx = A(bk);
19 | xx = transpose(bk(:))*bk(:);
20 | eig_b = (transpose(bk(:))*Mx(:))/xx;
21 | end
--------------------------------------------------------------------------------
/Matlab/helper_functions/false_color_function.m:
--------------------------------------------------------------------------------
1 | function false_color = false_color_function(x)
2 | % FALSE_COLOR_FUNCTION  Render a spectral cube as a false-color RGB image.
3 | %   x            Ny x Nx x C cube (one slice per spectral channel)
4 | %   false_color  Ny x Nx x 3 image, scaled so that its maximum is 1
5 | %   Slice i is weighted by red(i), green(i) and blue(i) from
6 | %   false_color_calib.mat; the blue channel gets an extra gain of 2.5.
7 |
8 | % Load into a struct so the calibration does not add variables implicitly.
9 | calib = load('false_color_calib.mat');
10 | red = calib.red;
11 | green = calib.green;
12 | blue = calib.blue;
13 |
14 | scaling = [1,1,2.5];
15 | % Weight only the slices present in both the cube and the calibration.
16 | n_channels = min(size(x, 3), numel(red));
17 | false_color = zeros(size(x, 1), size(x, 2), 3);
18 | for i = 1:n_channels
19 |     slice = x(:,:,i);
20 |     false_color(:,:,1) = false_color(:,:,1) + red(i)*scaling(1)*slice;
21 |     false_color(:,:,2) = false_color(:,:,2) + green(i)*scaling(2)*slice;
22 |     false_color(:,:,3) = false_color(:,:,3) + blue(i)*scaling(3)*slice;
23 | end
24 |
25 | % Normalize to a maximum of 1 (skipped for an all-zero image to avoid NaNs).
26 | peak = max(false_color(:));
27 | if peak ~= 0
28 |     false_color = false_color/peak;
29 | end
30 | return

--------------------------------------------------------------------------------
/Matlab/helper_functions/tlv.m:
--------------------------------------------------------------------------------
1 | function out=tlv(X,type, gpu)
2 | %This function computes the total variation of an input image X
3 | %
4 | % INPUT
5 | %
6 | % X............................. An image
7 | % type .................... Type of total variation function. Either 'iso'
8 | %                           (isotropic) or 'l1' (nonisotropic)
9 | % gpu ..................... (optional, default false) if true, the 'iso'
10 | %                           work array is created with gpuArray
11 | %
12 | % OUTPUT
13 | % out ....................... The total variation of X.
14 | %
15 | % NOTE(review): [m,n]=size(X) folds trailing dimensions into n for N-D input.
16 | if nargin < 3
17 |     gpu = false;
18 | end
19 | [m,n]=size(X);
20 | P=Ltrans(X);
21 |
22 | switch type
23 |     case 'iso'
24 |         D=zeros(m,n);
25 |         if gpu
26 |             D = gpuArray(D);
27 |         end
28 |         D(1:m-1,:)=P{1}.^2;
29 |         D(:,1:n-1)=D(:,1:n-1)+P{2}.^2;
30 |         out=sum(sum(sqrt(D)));
31 |     case 'l1'
32 |         out=sum(sum(abs(P{1})))+sum(sum(abs(P{2})));
33 |     otherwise
34 |         error('Invalid total variation type. Should be either "iso" or "l1"');
35 | end
36 |

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | BSD 3-Clause License
2 |
3 | Copyright (c) 2020, Waller Lab
4 | All rights reserved.
5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | 1. Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | 2. Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | 3. Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 30 | -------------------------------------------------------------------------------- /Python/helper_functions/tv_approx_haar_cp.py: -------------------------------------------------------------------------------- 1 | import cupy as np 2 | 3 | def soft_py(x, tau): 4 | threshed = np.maximum(np.abs(x)-tau, 0) 5 | threshed = threshed*np.sign(x) 6 | return threshed 7 | 8 | def ht3(x, ax, shift, thresh): 9 | C = 1./np.sqrt(2.) 
10 | 11 | if shift == True: 12 | x = np.roll(x, -1, axis = ax) 13 | if ax == 0: 14 | w1 = C*(x[1::2,:,:] + x[0::2, :, :]) 15 | w2 = soft_py(C*(x[1::2,:,:] - x[0::2, :, :]), thresh) 16 | elif ax == 1: 17 | w1 = C*(x[:, 1::2,:] + x[:, 0::2, :]) 18 | w2 = soft_py(C*(x[:,1::2,:] - x[:,0::2, :]), thresh) 19 | elif ax == 2: 20 | w1 = C*(x[:,:,1::2] + x[:,:, 0::2]) 21 | w2 = soft_py(C*(x[:,:,1::2] - x[:,:,0::2]), thresh) 22 | return w1, w2 23 | 24 | def iht3(w1, w2, ax, shift, shape): 25 | 26 | C = 1./np.sqrt(2.) 27 | y = np.zeros(shape) 28 | 29 | x1 = C*(w1 - w2); x2 = C*(w1 + w2); 30 | if ax == 0: 31 | y[0::2, :, :] = x1 32 | y[1::2, :, :] = x2 33 | 34 | if ax == 1: 35 | y[:, 0::2, :] = x1 36 | y[:, 1::2, :] = x2 37 | if ax == 2: 38 | y[:, :, 0::2] = x1 39 | y[:, :, 1::2] = x2 40 | 41 | 42 | if shift == True: 43 | y = np.roll(y, 1, axis = ax) 44 | return y 45 | 46 | 47 | def iht3_py2(w1, w2, ax, shift, shape): 48 | 49 | C = 1./np.sqrt(2.) 50 | y = np.zeros(shape) 51 | 52 | x1 = C*(w1 - w2); x2 = C*(w1 + w2); 53 | 54 | ind = ax + 2; 55 | y = np.reshape(np.concatenate([np.expand_dims(x1, ind), np.expand_dims(x2, ind)], axis = ind), shape) 56 | 57 | 58 | if shift == True: 59 | y = np.roll(y, 1, axis = ax+1) 60 | return y 61 | 62 | def tv3dApproxHaar(x, tau, alpha): 63 | D = 3 64 | fact = np.sqrt(2)*2 65 | 66 | thresh = D*tau*fact 67 | 68 | 69 | y = np.zeros_like(x) 70 | for ax in range(0,len(x.shape)): 71 | if ax ==2: 72 | t_scale = alpha 73 | else: 74 | t_scale = 1; 75 | 76 | w0, w1 = ht3(x, ax, False, thresh*t_scale) 77 | w2, w3 = ht3(x, ax, True, thresh*t_scale) 78 | 79 | t1 = iht3(w0, w1, ax, False, x.shape) 80 | t2 = iht3(w2, w3, ax, True, x.shape) 81 | y = y + t1 + t2 82 | 83 | y = y/(2*D) 84 | return y 85 | 86 | 87 | 88 | 89 | -------------------------------------------------------------------------------- /Python/helper_functions/tv_approx_haar_np.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | def 
soft_py(x, tau): 4 | threshed = np.maximum(np.abs(x)-tau, 0) 5 | threshed = threshed*np.sign(x) 6 | return threshed 7 | 8 | def ht3(x, ax, shift, thresh): 9 | C = 1./np.sqrt(2.) 10 | 11 | if shift == True: 12 | x = np.roll(x, -1, axis = ax) 13 | if ax == 0: 14 | w1 = C*(x[1::2,:,:] + x[0::2, :, :]) 15 | w2 = soft_py(C*(x[1::2,:,:] - x[0::2, :, :]), thresh) 16 | elif ax == 1: 17 | w1 = C*(x[:, 1::2,:] + x[:, 0::2, :]) 18 | w2 = soft_py(C*(x[:,1::2,:] - x[:,0::2, :]), thresh) 19 | elif ax == 2: 20 | w1 = C*(x[:,:,1::2] + x[:,:, 0::2]) 21 | w2 = soft_py(C*(x[:,:,1::2] - x[:,:,0::2]), thresh) 22 | return w1, w2 23 | 24 | def iht3(w1, w2, ax, shift, shape): 25 | 26 | C = 1./np.sqrt(2.) 27 | y = np.zeros(shape) 28 | 29 | x1 = C*(w1 - w2); x2 = C*(w1 + w2); 30 | if ax == 0: 31 | y[0::2, :, :] = x1 32 | y[1::2, :, :] = x2 33 | 34 | if ax == 1: 35 | y[:, 0::2, :] = x1 36 | y[:, 1::2, :] = x2 37 | if ax == 2: 38 | y[:, :, 0::2] = x1 39 | y[:, :, 1::2] = x2 40 | 41 | 42 | if shift == True: 43 | y = np.roll(y, 1, axis = ax) 44 | return y 45 | 46 | 47 | def iht3_py2(w1, w2, ax, shift, shape): 48 | 49 | C = 1./np.sqrt(2.) 
50 | y = np.zeros(shape) 51 | 52 | x1 = C*(w1 - w2); x2 = C*(w1 + w2); 53 | 54 | ind = ax + 2; 55 | y = np.reshape(np.concatenate([np.expand_dims(x1, ind), np.expand_dims(x2, ind)], axis = ind), shape) 56 | 57 | 58 | if shift == True: 59 | y = np.roll(y, 1, axis = ax+1) 60 | return y 61 | 62 | def tv3dApproxHaar(x, tau, alpha): 63 | D = 3 64 | fact = np.sqrt(2)*2 65 | 66 | thresh = D*tau*fact 67 | 68 | 69 | y = np.zeros_like(x) 70 | for ax in range(0,len(x.shape)): 71 | if ax ==2: 72 | t_scale = alpha 73 | else: 74 | t_scale = 1; 75 | 76 | w0, w1 = ht3(x, ax, False, thresh*t_scale) 77 | w2, w3 = ht3(x, ax, True, thresh*t_scale) 78 | 79 | t1 = iht3(w0, w1, ax, False, x.shape) 80 | t2 = iht3(w2, w3, ax, True, x.shape) 81 | y = y + t1 + t2 82 | 83 | y = y/(2*D) 84 | return y 85 | 86 | 87 | 88 | 89 | -------------------------------------------------------------------------------- /Matlab/helper_functions/tv3dApproxHaar.m: -------------------------------------------------------------------------------- 1 | function y = tv3dApproxHaar(x, tau, alphay, alphaz) 2 | D = 3; 3 | gamma = 1; %step size 4 | thresh = sqrt(2) * 2 * D * tau * gamma; 5 | y = zeros(size(x), 'like', x); 6 | for axis = 1 : 3 7 | if axis == 3 8 | t_scale = alphaz; 9 | elseif axis ==1 10 | t_scale = alphay; 11 | else 12 | t_scale = 1; 13 | end 14 | y = y + iht3(ht3(x, axis, false, thresh*t_scale), axis, false); 15 | y = y + iht3(ht3(x, axis, true, thresh*t_scale), axis, true); 16 | end 17 | y = y / (2 * D); 18 | return 19 | 20 | function w = ht3(x, ax, shift, thresh) 21 | s = size(x); 22 | w = zeros(s, 'like', x); 23 | C = 1 / sqrt(2); 24 | if shift 25 | x = circshift(x, -1, ax); 26 | end 27 | m = floor(s(ax) / 2); 28 | if ax == 1 29 | w(1:m, :, :) = C * (x(2:2:end, :, :) + x(1:2:end, :, :)); 30 | w((m + 1):end, :, :) = hs_soft(C * (x(2:2:end, :, :) - x(1:2:end, :, :)), thresh); 31 | %w((m + 1):end, :, :) = hs_soft(w((m + 1):end, :, :), thresh); 32 | elseif ax == 2 33 | w(:, 1:m, :) = C * (x(:, 2:2:end, 
:) + x(:, 1:2:end, :)); 34 | w(:, (m + 1):end, :) = C * (x(:, 2:2:end, :) - x(:, 1:2:end, :)); 35 | w(:, (m + 1):end, :) = hs_soft(w(:, (m + 1):end, :), thresh); 36 | else 37 | w(:, :, 1:m) = C * (x(:, :, 2:2:end) + x(:, :, 1:2:end)); 38 | w(:, :, (m + 1):end) = C * (x(:, :, 2:2:end) - x(:, :, 1:2:end)); 39 | w(:, :, (m + 1):end) = hs_soft(w(:, :, (m + 1):end), thresh); 40 | end 41 | return 42 | 43 | function y = iht3(w, ax, shift) 44 | s = size(w); 45 | y = zeros(s, 'like', w); 46 | C = 1 / sqrt(2); 47 | m = floor(s(ax) / 2); 48 | if ax == 1 49 | y(1:2:end, :, :) = C * (w(1:m, :, :) - w((m + 1):end, :, :)); 50 | y(2:2:end, :, :) = C * (w(1:m, :, :) + w((m + 1):end, :, :)); 51 | elseif ax == 2 52 | y(:, 1:2:end, :) = C * (w(:, 1:m, :) - w(:, (m + 1):end, :)); 53 | y(:, 2:2:end, :) = C * (w(:, 1:m, :) + w(:, (m + 1):end, :)); 54 | else 55 | y(:, :, 1:2:end) = C * (w(:, :, 1:m) - w(:, :, (m + 1):end)); 56 | y(:, :, 2:2:end) = C * (w(:, :, 1:m) + w(:, :, (m + 1):end)); 57 | end 58 | if shift 59 | y = circshift(y, 1, ax); 60 | end 61 | return 62 | 63 | function threshed = hs_soft(x,tau) 64 | 65 | threshed = max(abs(x)-tau,0); 66 | threshed = threshed.*sign(x); 67 | return -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

2 | 3 | # [Spectral DiffuserCam](https://waller-lab.github.io/SpectralDiffuserCam/) 4 | 5 | 6 | 7 | ## Paper 8 | [Spectral DiffuserCam: lensless snapshot hyperspectral imaging with a spectral filter array](https://www.osapublishing.org/optica/abstract.cfm?uri=optica-7-10-1298) 9 | 10 | Please cite the following paper when using this code or data: 11 | 12 | 13 | ``` 14 | @article{Monakhova:20, 15 | author = {Kristina Monakhova and Kyrollos Yanny and Neerja Aggarwal and Laura Waller}, 16 | journal = {Optica}, 17 | number = {10}, 18 | pages = {1298--1307}, 19 | publisher = {OSA}, 20 | title = {Spectral DiffuserCam: lensless snapshot hyperspectral imaging with a spectral filter array}, 21 | volume = {7}, 22 | month = {Oct}, 23 | year = {2020}, 24 | url = {http://www.osapublishing.org/optica/abstract.cfm?URI=optica-7-10-1298}, 25 | doi = {10.1364/OPTICA.397214} 26 | } 27 | 28 | ``` 29 | 30 | 31 | ## Contents 32 | 33 | 1. [Data](#Data) 34 | 2. [Setup](#Setup) 35 | 3. [Description](#Description) 36 | 37 | ## Data 38 | Sample data (needed to run the code) can be found [here](https://drive.google.com/drive/folders/1dmfzkTLFZZFUYW8GC6Vn6SOuZiZq47SS?usp=sharing). 39 | 40 | This includes the following files: 41 | * calibration.mat - includes the calibrated point spread function, filter function, and wavelength list 42 | * four sample raw measurements 43 | 44 | 45 | ## Setup 46 | Clone this project using: 47 | ``` 48 | git clone https://github.com/Waller-Lab/SpectralDiffuserCam.git 49 | ``` 50 | 51 | The dependencies can be installed by using: 52 | ``` 53 | conda env create -f environment.yml 54 | conda activate SpectralDiffuserCam 55 | ``` 56 | 57 | Please place the downloaded data in the SampleData folder inside the Python and/or Matlab folders. 58 | 59 | [Reconstruction Demo.ipynb](https://github.com/Waller-Lab/SpectralDiffuserCam/blob/master/Python/Reconstruction%20Demo.ipynb) contains an example reconstruction in Python.
60 |
61 | [reconstruction_demo.m](https://github.com/Waller-Lab/SpectralDiffuserCam/blob/master/Matlab/reconstruction_demo.m) contains an example reconstruction in Matlab.
62 |
63 | We recommend running this code on a GPU, but it can also be run on a CPU (much slower!).
64 |
65 | ## Description
66 | This repository contains code in both Python and Matlab that is needed to process raw Spectral DiffuserCam images and reconstruct 3D hyperspectral volumes from the raw 2D measurements. Four example raw images are provided, along with the calibrated point spread function and spectral filter function. Both the Python and Matlab versions support GPU acceleration. In Python, this is accomplished using cupy. We use FISTA for our reconstructions with a 3D total variation prior.

--------------------------------------------------------------------------------
/Matlab/reconstruction_demo.m:
--------------------------------------------------------------------------------
1 | %% Run Reconstruction for Spectral DiffuserCam
2 | % Last update: 9/29/2020
3 |
4 | addpath('helper_functions/')
5 | addpath('SampleData/')
6 |
7 | %% Load in Calibration data and PSF
8 | % calibration.mat supplies mask, psf and wavs (used below)
9 | load('calibration.mat')
10 | wavelengths = wavs;
11 |
12 | im = double(imread('meas_dice.png'));
13 |
14 | %% Pre-process images and data
15 | % crop all data to the valid mask-pixels
16 | c1 = 100; c2 = 419;
17 | c3 = 80;  c4 = 539;
18 |
19 | mask = mask(c1:c2, c3:c4, :);
20 | psf = psf(c1:c2, c3:c4);
21 | im = im(c1:c2, c3:c4);
22 |
23 | % normalize PSF
24 | psf = psf/norm(psf, 'fro');
25 |
26 | % Zero a 5x5 window around the brightest pixel of the summed mask (a pixel
27 | % defect) in both mask and image; the window is clamped to the array bounds.
28 | mask_sum = sum(mask, 3);
29 | [~, idx] = max(mask_sum(:));
30 | [row, col] = ind2sub(size(mask_sum), idx);
31 | rows = max(row-2, 1):min(row+2, size(mask_sum, 1));
32 | cols = max(col-2, 1):min(col+2, size(mask_sum, 2));
33 | mask(rows, cols, :) = 0;
34 |
35 | im = im/max(im(:));
36 | im(rows, cols) = 0;
37 |
38 | % tv3dApproxHaar pairs neighbouring samples along every axis, so the number
39 | % of spectral channels must be even: drop the last channel if it is odd.
40 | if mod(size(mask, 3), 2) ~= 0
41 |     mask = mask(:, :, 1:end-1);
42 |     wavelengths = wavelengths(1:end-1);
43 | end
44 |
45 | %% Use single precision and put everything on GPU (if using GPU)
46 | opts.use_gpu = 1; % Change to 0 to put on CPU
47 | psf = single(psf);
48 | mask = single(mask);
49 | im = single(im);
50 | if opts.use_gpu
51 |     psf = gpuArray(psf);
52 |     mask = gpuArray(mask);
53 |     im = gpuArray(im);
54 | end
55 |
56 | %% Define reconstruction options (leave these alone for defaults)
57 | name = 'thordog';                % label used in the save-folder name
58 | opts.fista_iters = 500;          % Number of FISTA iterations
59 | opts.denoise_method = 'tv';      % options: 'tv', 'non-neg', 'native', 'tv_lowrank'
60 | % NOTE(review): 'tv_lowrank' calls soft_thresh_lowrank, which is not among
61 | % the files in helper_functions/ -- confirm it is on the path before use.
62 |
63 | opts.tv_lambda = .003;           % TV tuning parameter (higher is more TV)
64 | opts.tv_lambday = 1;             % TV tuning parameter in y (compared to x)
65 | opts.tv_lambdaw = .01;           % TV tuning parameter in lambda (compared to x)
66 | opts.lowrank_lambda = .00005;    % Tuning parameter for the low-rank constraint
67 |
68 | opts.display_every = 1;          % how often to display the reconstruction
69 | opts.save_data = 1;              % save the intermediate recon images
70 | opts.save_data_freq = 100;       % save data every 100 iterations
71 |
72 | % Folder to save the results in (iteration count formatted as an integer)
73 | filename_save = sprintf('saved_recons/%s_%s_lambda_%f_iterations_%d_', name, opts.denoise_method, opts.tv_lambda, opts.fista_iters);
74 | opts.save_data_path = filename_save;
75 |
76 | %% Run inverse solver
77 | [xout, loss_list] = fista_spectral_3d(im, psf, mask, opts);
78 |
79 | %% Display: bring the result to the CPU and flip both image axes
80 | xout = fliplr(flipud(gather(xout)));
81 | false_color = false_color_function(xout);
82 | figure(), imshow(false_color); title('False-color reconstruction')
83 |
84 | figure(), imshow3D(xout);

--------------------------------------------------------------------------------
/Python/helper_functions/helper_functions.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import
matplotlib.pyplot as plt 3 | import matplotlib 4 | import scipy.io 5 | from IPython.core.display import display, HTML 6 | from ipywidgets import interact, widgets, fixed 7 | 8 | import sys 9 | sys.path.append('helper_functions/') 10 | 11 | 12 | def plotf2(r, img, ttl, sz): 13 | #fig = plt.figure(figsize=(2, 2)); 14 | #plt.figure(figsize=(20, 20)); 15 | plt.title(ttl+' {}'.format(r)) 16 | plt.imshow(img[:,:,r], cmap="gray", vmin = 0, vmax = np.max(img)); 17 | plt.axis('off'); 18 | fig = plt.gcf() 19 | fig.set_size_inches(sz) 20 | plt.show(); 21 | #display(fig) 22 | #clear_output(wait=True) 23 | return 24 | 25 | def plt3D(img, title = '', size = (5,5)): 26 | #fig = plt.figure(figsize=sz); 27 | interact(plotf2, 28 | r=widgets.IntSlider(min=0,max=np.shape(img)[-1]-1,step=1,value=1), 29 | img = fixed(img), 30 | continuous_update= False, 31 | ttl = fixed(title), 32 | sz = fixed(size)); 33 | 34 | def crop(x): 35 | DIMS0 = x.shape[0]//2 # Image Dimensions 36 | DIMS1 = x.shape[1]//2 # Image Dimensions 37 | 38 | PAD_SIZE0 = int((DIMS0)//2) # Pad size 39 | PAD_SIZE1 = int((DIMS1)//2) # Pad size 40 | 41 | C01 = PAD_SIZE0; C02 = PAD_SIZE0 + DIMS0 # Crop indices 42 | C11 = PAD_SIZE1; C12 = PAD_SIZE1 + DIMS1 # Crop indices 43 | return x[C01:C02, C11:C12,:] 44 | 45 | def pre_plot(x): 46 | x = np.fliplr(np.flipud(x)) 47 | x = x/np.max(x) 48 | x = np.clip(x, 0,1) 49 | return x 50 | 51 | 52 | def stack_rgb_opt(reflArray, opt = 'helper_functions/false_color_calib.mat', scaling = [1,1,2.5]): 53 | 54 | color_dict = scipy.io.loadmat(opt) 55 | red = color_dict['red']; green = color_dict['green']; blue = color_dict['blue'] 56 | 57 | reflArray = reflArray/np.max(reflArray) 58 | 59 | red_channel = np.zeros((reflArray.shape[0], reflArray.shape[1])) 60 | green_channel = np.zeros((reflArray.shape[0], reflArray.shape[1])) 61 | blue_channel = np.zeros((reflArray.shape[0], reflArray.shape[1])) 62 | 63 | for i in range(0,64): 64 | red_channel = red_channel + reflArray[:,:,i]*red[0,i]*scaling[0] 65 
| green_channel = green_channel + reflArray[:,:,i]*green[0,i]*scaling[1] 66 | blue_channel = blue_channel + reflArray[:,:,i]*blue[0,i]*scaling[2] 67 | 68 | red_channel = red_channel/64. 69 | green_channel = green_channel/64. 70 | blue_channel = blue_channel/64. 71 | 72 | stackedRGB = np.stack((red_channel,green_channel,blue_channel),axis=2) 73 | 74 | return stackedRGB 75 | 76 | def preprocess(mask, psf, im): 77 | 78 | # Crop indices 79 | c1 = 100; c2 = 420; c3 = 80; c4 = 540 80 | 81 | # Crop and normalize mask 82 | mask = mask[c1:c2, c3:c4, :] 83 | mask = mask/np.max(mask) 84 | 85 | # Crop and normalize PSF 86 | psf = psf[c1:c2, c3:c4] 87 | psf = psf/np.linalg.norm(psf) 88 | 89 | # Remove defective pixels in mask calibration 90 | mask_sum = np.sum(mask, 2) 91 | ind = np.unravel_index((np.argmax(mask_sum, axis = None)), mask_sum.shape) 92 | mask[ind[0]-2:ind[0]+2, ind[1]-2:ind[1]+2, :] = 0 93 | 94 | # Remove defective pixels in measurement 95 | im = im[c1:c2, c3:c4] 96 | im = im/np.max(im) 97 | im[ind[0]-2:ind[0]+2, ind[1]-2:ind[1]+2] = 0 98 | return mask, psf, im -------------------------------------------------------------------------------- /Matlab/helper_functions/fista_spectral_3d.m: -------------------------------------------------------------------------------- 1 | function [xout, loss_list]=fista_spectral_3d(input_image, psf, mask, opts) 2 | %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% 3 | % FISTA implementation for Spectral DiffuserCam 4 | % Last update: 9/29/2020 5 | % 6 | % 7 | % Inputs 8 | % input_image ............. input, measured image 9 | % psf ..................... PSF, same size as input image 10 | % opts..................... options file 11 | % opts.fista_iters ....... Number of FISTA iterations 12 | % opts.denoise_method .... Either 'non-neg' or 'tv' 13 | % opts.tv_lambda ......... amount of tv 14 | % opts.tv_iters .......... number of inner-loop iterations 15 | % 16 | % Outputs 17 | % xout .................... 
deblurred image
18 | % loss_list ............... list of losses
19 | %
20 | % Model: the reconstruction lives on a zero-padded grid; each spectral
21 | % slice is convolved with the PSF, cropped back to the sensor size,
22 | % weighted by its slice of mask and summed over wavelength (Hfor).
23 |
24 | figure(2020)
25 |
26 | [Ny, Nx, n_filters] = size(mask); %Get problem size
27 |
28 | % Setup convolutional forward op: pad by half the image size on each side
29 | p1 = floor(Ny/2);
30 | p2 = floor(Nx/2);
31 | Py = Ny + 2*p1;    % padded grid size
32 | Px = Nx + 2*p2;
33 | pad2d = @(x)padarray(x,[p1,p2],'both'); %2D padding
34 | crop2d = @(x)x(p1+1:end-p1,p2+1:end-p2,:); %2D cropping
35 |
36 | % 2-D transfer function of the PSF. fft2/ifft2 act on every 2-D slice of an
37 | % N-D array, which is all the per-slice convolution needs.
38 | Hs = fft2(ifftshift(pad2d(psf)));
39 | Hs_conj = conj(Hs);
40 |
41 | % Forward model: padded volume -> 2-D measurement
42 | Hfor = @(x)sum(mask.*crop2d(real(ifft2(Hs.*fft2(x)))),3);
43 | % Adjoint: 2-D residual -> padded volume
44 | Hadj = @(x)real(ifft2(Hs_conj.*fft2(pad2d(repmat(x, [1,1,n_filters]).*mask))));
45 | % Forward model followed by padding maps the padded 2-D grid onto itself,
46 | % as power_iteration requires.
47 | Hfor_modified = @(x)pad2d(Hfor(x));
48 |
49 | % Step size is 1/L; the factor 45 on the eigenvalue estimate is a fixed
50 | % safety margin (NOTE(review): empirical -- confirm).
51 | maxeig = power_iteration(Hfor_modified, pad2d(psf), 10);
52 | L = maxeig*45;
53 |
54 | lambda = opts.tv_lambda;
55 | tv_type = 'iso';           % TV flavor used in the reported loss
56 | use_gpu = opts.use_gpu;
57 |
58 | % Proximal operator (individual proximal maps are averaged when combined)
59 | % and the loss reported for monitoring.
60 | switch opts.denoise_method
61 |     case 'tv'
62 |         prox = @(x)(1/2*(max(x,0) + tv3dApproxHaar(x, opts.tv_lambda/L, opts.tv_lambday, opts.tv_lambdaw)));
63 |         loss = @(err, x) norm(err,'fro')^2 + 2*lambda/L*tlv(x, tv_type, use_gpu);
64 |     case 'tv_lowrank'
65 |         % NOTE(review): soft_thresh_lowrank is not among the files in
66 |         % helper_functions/ -- confirm it is on the path.
67 |         prox = @(x)(1/3*(max(x,0) + tv3dApproxHaar(x, opts.tv_lambda/L, opts.tv_lambday, opts.tv_lambdaw) + soft_thresh_lowrank(x, opts.lowrank_lambda)));
68 |         loss = @(err, x) norm(err,'fro')^2 + 2*lambda/L*tlv(x, tv_type, use_gpu);
69 |     case 'native'
70 |         prox = @(x)(1/2*(max(x,0) + soft_thresh(x, opts.tv_lambda/L)));
71 |         loss = @(err, x) norm(err,'fro')^2;
72 |     case 'non-neg'
73 |         prox = @(x)max(x,0);
74 |         loss = @(err, x) norm(err,'fro')^2;
75 |     otherwise
76 |         error('fista_spectral_3d:unknownMethod', 'Unknown opts.denoise_method ''%s''.', opts.denoise_method);
77 | end
78 |
79 | if opts.save_data == 1
80 |     if ~exist(opts.save_data_path, 'dir')
81 |         mkdir(opts.save_data_path)
82 |     end
83 |     y_save = gather(input_image);
84 |     mask_save = gather(mask);
85 |     psf_save = gather(psf);
86 |
87 |     filename = sprintf('%s/params.mat', opts.save_data_path);
88 |     save(filename, 'y_save', 'psf_save', 'mask_save', 'opts', 'L');
89 | end
90 |
91 | %% Start FISTA
92 | % Iterates live on the padded grid and share psf's class (e.g. a single
93 | % gpuArray) from the first iteration.
94 | xk = zeros(Py, Px, n_filters, 'like', psf);
95 | vk = xk;
96 | tk = 1.0;
97 |
98 | loss_list = [];
99 |
100 | for i=1:opts.fista_iters
101 |     xold = xk;
102 |     vold = vk;
103 |     told = tk;
104 |
105 |     resid = Hfor(vold) - input_image;   % data residual (named so the builtin error() stays usable)
106 |     grads = Hadj(resid);
107 |
108 |     xk = prox(vold - 1/L*grads);
109 |
110 |     % FISTA momentum: t_{k+1} = (1 + sqrt(1 + 4*t_k^2)) / 2
111 |     tk = (1 + sqrt(1+4*told^2))/2;
112 |     vk = xk + (told-1)/tk *(xk- xold);
113 |
114 |     loss_i = loss(resid, xk);
115 |     loss_list = [loss_list, loss_i];
116 |
117 |     if mod(i,opts.display_every) == 0
118 |         subplot(1,4,1)
119 |
120 |         x_print = gather(xk);
121 |         x_print = fliplr(flipud(x_print));
122 |         false_color = false_color_function(x_print);
123 |         imagesc(false_color);
124 |         title('False color image')
125 |
126 |         subplot(1,4,2)
127 |         imagesc(reshape(sum(x_print, 2), Py, n_filters));
128 |         title('Y-Lambda sum projection')
129 |
130 |         subplot(1,4,3)
131 |         imagesc(reshape(sum(x_print, 1), Px, n_filters));
132 |         title('X-Lambda sum projection')
133 |
134 |         subplot(1,4,4),
135 |         semilogy(loss_list);
136 |         title('Loss')
137 |         drawnow
138 |     end
139 |
140 |     if mod(i, opts.save_data_freq) ==0 && opts.save_data == 1
141 |         fprintf('Saving iteration %d to %s\n', i, opts.save_data_path);
142 |
143 |         filename = sprintf('%s/iter%i.mat', opts.save_data_path, i);
144 |         x_save = gather(xk);
145 |         save(filename, 'x_save');
146 |     end
147 | end
148 |
149 | xout = crop2d(xk);
150 |
151 | end
152 |

--------------------------------------------------------------------------------
/Python/fista_spectral_cupy.py:
--------------------------------------------------------------------------------
1 | import sys
2 | global device
3 |
device= sys.argv[1] 4 | sys.path.append('helper_functions/') 5 | 6 | if device == 'GPU': 7 | import cupy as np 8 | import tv_approx_haar_cp as tv 9 | 10 | print('device = ', device, ', using GPU and cupy') 11 | else: 12 | import numpy as np 13 | import tv_approx_haar_np as tv 14 | print('device = ', device, ', using CPU and numpy') 15 | 16 | import helper_functions.helper_functions as fc 17 | import numpy as numpy 18 | import matplotlib.pyplot as plt 19 | 20 | 21 | 22 | class fista_spectral_numpy(): 23 | def __init__(self, h, mask): 24 | 25 | ## Initialize constants 26 | self.DIMS0 = h.shape[0] # Image Dimensions 27 | self.DIMS1 = h.shape[1] # Image Dimensions 28 | 29 | self.spectral_channels = mask.shape[-1] # Number of spectral channels 30 | 31 | self.py = int((self.DIMS0)//2) # Pad size 32 | self.px = int((self.DIMS1)//2) # Pad size 33 | 34 | # FFT of point spread function 35 | self.H = np.expand_dims(np.fft.fft2((np.fft.ifftshift(self.pad(h), axes = (0,1))), axes = (0,1)), -1) 36 | self.Hconj = np.conj(self.H) 37 | 38 | self.mask = mask 39 | 40 | # Calculate the eigenvalue to set the step size 41 | maxeig = self.power_iteration(self.Hpower, (self.DIMS0*2, self.DIMS1*2), 10) 42 | self.L = maxeig*45 43 | 44 | 45 | self.prox_method = 'tv' # options: 'non-neg', 'tv', 'native' 46 | 47 | # Define soft-thresholding constants 48 | self.tau = .5 # Native sparsity tuning parameter 49 | self.tv_lambda = 0.00005 # TV tuning parameter 50 | self.tv_lambdaw = 0.00005 # TV tuning parameter for wavelength 51 | self.lowrank_lambda = 0.00005 # Low rank tuning parameter 52 | 53 | 54 | # Number of iterations of FISTA 55 | self.iters = 500 56 | 57 | self.show_recon_progress = True # Display the intermediate results 58 | self.print_every = 20 # Sets how often to print the image 59 | 60 | self.l_data = [] 61 | self.l_tv = [] 62 | 63 | # Power iteration to calculate eigenvalue 64 | def power_iteration(self, A, sample_vect_shape, num_iters): 65 | bk = 
np.random.randn(sample_vect_shape[0], sample_vect_shape[1]) 66 | for i in range(0, num_iters): 67 | bk1 = A(bk) 68 | bk1_norm = np.linalg.norm(bk1) 69 | 70 | bk = bk1/bk1_norm 71 | Mx = A(bk) 72 | xx = np.transpose(np.dot(bk.ravel(), bk.ravel())) 73 | eig_b = np.transpose(bk.ravel()).dot(Mx.ravel())/xx 74 | 75 | return eig_b 76 | 77 | # Helper functions for forward model 78 | def crop(self,x): 79 | return x[self.py:-self.py, self.px:-self.px] 80 | 81 | def pad(self,x): 82 | if len(x.shape) == 2: 83 | out = np.pad(x, ([self.py, self.py], [self.px,self.px]), mode = 'constant') 84 | elif len(x.shape) == 3: 85 | out = np.pad(x, ([self.py, self.py], [self.px,self.px], [0, 0]), mode = 'constant') 86 | return out 87 | 88 | def Hpower(self, x): 89 | x = np.fft.ifft2(self.H* np.fft.fft2(np.expand_dims(x,-1), axes = (0,1)), axes = (0,1)) 90 | x = np.sum(self.mask* self.crop(np.real(x)), 2) 91 | x = self.pad(x) 92 | return x 93 | 94 | def Hfor(self, x): 95 | x = np.fft.ifft2(self.H* np.fft.fft2(x, axes = (0,1)), axes = (0,1)) 96 | x = np.sum(self.mask* self.crop(np.real(x)), 2) 97 | return x 98 | 99 | def Hadj(self, x): 100 | x = np.expand_dims(x,-1) 101 | x = x*self.mask 102 | x = self.pad(x) 103 | 104 | x = np.fft.fft2(x, axes = (0,1)) 105 | x = np.fft.ifft2(self.Hconj*x, axes = (0,1)) 106 | x = np.real(x) 107 | return x 108 | 109 | def soft_thresh(self, x, tau): 110 | out = np.maximum(np.abs(x)- tau, 0) 111 | out = out*np.sign(x) 112 | return out 113 | 114 | def prox(self,x): 115 | if self.prox_method == 'tv': 116 | x = 0.5*(np.maximum(x,0) + tv.tv3dApproxHaar(x, self.tv_lambda/self.L, self.tv_lambdaw)) 117 | if self.prox_method == 'native': 118 | x = np.maximum(x,0) + self.soft_thresh(x, self.tau) 119 | if self.prox_method == 'non-neg': 120 | x = np.maximum(x,0) 121 | return x 122 | 123 | def tv(self, x): 124 | d = np.zeros_like(x) 125 | d[0:-1,:] = (x[0:-1,:] - x[1:, :])**2 126 | d[:,0:-1] = d[:,0:-1] + (x[:,0:-1] - x[:,1:])**2 127 | return np.sum(np.sqrt(d)) 128 | 129 | 
def loss(self,x,err): 130 | if self.prox_method == 'tv': 131 | self.l_data.append(np.linalg.norm(err)**2) 132 | self.l_tv.append(2*self.tv_lambda/self.L * self.tv(x)) 133 | 134 | l = np.linalg.norm(err)**2 + 2*self.tv_lambda/self.L * self.tv(x) 135 | if self.prox_method == 'native': 136 | l = np.linalg.norm(err)**2 + 2*self.tv_lambda/self.L * np.linalg.norm(x.ravel(), 1) 137 | if self.prox_method == 'non-neg': 138 | l = np.linalg.norm(err)**2 139 | return l 140 | 141 | # Main FISTA update 142 | def fista_update(self, vk, tk, xk, inputs): 143 | 144 | error = self.Hfor(vk) - inputs 145 | grads = self.Hadj(error) 146 | 147 | xup = self.prox(vk - 1/self.L * grads) 148 | tup = 1 + np.sqrt(1 + 4*tk**2)/2 149 | vup = xup + (tk-1)/tup * (xup-xk) 150 | 151 | return vup, tup, xup, self.loss(xup, error) 152 | 153 | 154 | # Run FISTA 155 | def run(self, inputs): 156 | 157 | # Initialize variables to zero 158 | xk = np.zeros((self.DIMS0*2, self.DIMS1*2, self.spectral_channels)) 159 | vk = np.zeros((self.DIMS0*2, self.DIMS1*2, self.spectral_channels)) 160 | tk = 1.0 161 | 162 | llist = [] 163 | 164 | # Start FISTA loop 165 | for i in range(0,self.iters): 166 | vk, tk, xk, l = self.fista_update(vk, tk, xk, inputs) 167 | 168 | llist.append(l) 169 | 170 | # Print out the intermediate results and the loss 171 | if self.show_recon_progress== True and i%self.print_every == 0: 172 | print('iteration: ', i, ' loss: ', l) 173 | if device == 'GPU': 174 | out_img = np.asnumpy(self.crop(xk)) 175 | else: 176 | out_img = self.crop(xk) 177 | fc_img = fc.pre_plot(fc.stack_rgb_opt(out_img)) 178 | plt.figure(figsize = (10,3)) 179 | plt.subplot(1,2,1), plt.imshow(fc_img/numpy.max(fc_img)); plt.title('Reconstruction') 180 | plt.subplot(1,2,2), plt.plot(llist); plt.title('Loss') 181 | plt.show() 182 | self.out_img = out_img 183 | xout = self.crop(xk) 184 | return xout, llist 185 | -------------------------------------------------------------------------------- /Matlab/helper_functions/imshow3D.m: 
-------------------------------------------------------------------------------- 1 | function imshow3D( Img, disprange ) 2 | %IMSHOW3D displays 3D grayscale or RGB images in slice by slice fashion 3 | %with mouse based slice browsing and window and level adjustment control. 4 | % 5 | % Usage: 6 | % imshow3D ( Image ) 7 | % imshow3D ( Image , [] ) 8 | % imshow3D ( Image , [LOW HIGH] ) 9 | % 10 | % Image: 3D image MxNxKxC (K slices of MxN images) C is either 1 11 | % (for grayscale images) or 3 (for RGB images) 12 | % [LOW HIGH]: display range that controls the display intensity range of 13 | % a grayscale image (default: the full range of an integer/logical class, otherwise [0 max(Image(:))]) 14 | % 15 | % Use the scroll bar or mouse scroll wheel to switch between slices. To 16 | % adjust window and level values keep the mouse right button pressed and 17 | % drag the mouse up and down (for level adjustment) or right and left (for 18 | % window adjustment). Window and level adjustment control works only for 19 | % grayscale images. 20 | % 21 | % "Auto W/L" button adjusts the window and level automatically for grayscale 22 | % images 23 | % 24 | % While "Fine Tune" check box is checked the window/level adjustment gets 25 | % 16 times less sensitive to mouse movement, to make it easier to control 26 | % display intensity range. 27 | % 28 | % Note: The sensitivity of mouse based window and level adjustment is set 29 | % based on the user defined display intensity range; the wider the range 30 | % the more sensitivity to mouse drag. 31 | % 32 | % Note: IMSHOW3DFULL is a newer version of IMSHOW3D (also available on 33 | % MathWorks) that displays 3D grayscale or RGB images from three 34 | % perpendicular views (i.e. axial, sagittal, and coronal).
35 | % 36 | % Example 37 | % -------- 38 | % % Display an image (MRI example) 39 | % load mri 40 | % Image = squeeze(D); 41 | % figure, 42 | % imshow3D(Image) 43 | % 44 | % % Display the image, adjust the display range 45 | % figure, 46 | % imshow3D(Image,[20 100]); 47 | % 48 | % See also IMSHOW. 49 | 50 | % 51 | % - Maysam Shahedi (mshahedi@gmail.com) 52 | % - Released: 1.0.0 Date: 2013/04/15 53 | % - Revision: 1.1.0 Date: 2013/04/19 54 | % - Revision: 1.5.0 Date: 2016/09/22 55 | % 56 | 57 | sno = size(Img,3); % number of slices 58 | S = round(sno/2); % start on the middle slice; S is shared with the slider and scroll callbacks 59 | 60 | global InitialCoord; % pointer location at the last right-click/drag event (set in mouseClick, read in WinLevAdj) 61 | 62 | MinV = 0; % default for classes not matched below (e.g. double): window spans [0, max(Img(:))] 63 | MaxV = max(Img(:)); 64 | LevV = (double( MaxV) + double(MinV)) / 2; 65 | Win = double(MaxV) - double(MinV); 66 | WLAdjCoe = (Win + 1)/1024; % mouse-drag sensitivity, proportional to the window width 67 | FineTuneC = [1 1/16]; % Regular/Fine-tune mode coefficients 68 | 69 | if isa(Img,'uint8') % integer classes: use the full class range (intN/uintN(+/-Inf) saturate to the class limits) 70 | MaxV = uint8(Inf); 71 | MinV = uint8(-Inf); 72 | LevV = (double( MaxV) + double(MinV)) / 2; 73 | Win = double(MaxV) - double(MinV); 74 | WLAdjCoe = (Win + 1)/1024; 75 | elseif isa(Img,'uint16') 76 | MaxV = uint16(Inf); 77 | MinV = uint16(-Inf); 78 | LevV = (double( MaxV) + double(MinV)) / 2; 79 | Win = double(MaxV) - double(MinV); 80 | WLAdjCoe = (Win + 1)/1024; 81 | elseif isa(Img,'uint32') 82 | MaxV = uint32(Inf); 83 | MinV = uint32(-Inf); 84 | LevV = (double( MaxV) + double(MinV)) / 2; 85 | Win = double(MaxV) - double(MinV); 86 | WLAdjCoe = (Win + 1)/1024; 87 | elseif isa(Img,'uint64') 88 | MaxV = uint64(Inf); 89 | MinV = uint64(-Inf); 90 | LevV = (double( MaxV) + double(MinV)) / 2; 91 | Win = double(MaxV) - double(MinV); 92 | WLAdjCoe = (Win + 1)/1024; 93 | elseif isa(Img,'int8') 94 | MaxV = int8(Inf); 95 | MinV = int8(-Inf); 96 | LevV = (double( MaxV) + double(MinV)) / 2; 97 | Win = double(MaxV) - double(MinV); 98 | WLAdjCoe = (Win + 1)/1024; 99 | elseif isa(Img,'int16') 100 | MaxV = int16(Inf); 101 | MinV = int16(-Inf); 102 | LevV = (double( MaxV) + double(MinV)) / 2; 103 | Win = double(MaxV) - double(MinV); 104 |
WLAdjCoe = (Win + 1)/1024; 105 | elseif isa(Img,'int32') 106 | MaxV = int32(Inf); 107 | MinV = int32(-Inf); 108 | LevV = (double( MaxV) + double(MinV)) / 2; 109 | Win = double(MaxV) - double(MinV); 110 | WLAdjCoe = (Win + 1)/1024; 111 | elseif isa(Img,'int64') 112 | MaxV = int64(Inf); 113 | MinV = int64(-Inf); 114 | LevV = (double( MaxV) + double(MinV)) / 2; 115 | Win = double(MaxV) - double(MinV); 116 | WLAdjCoe = (Win + 1)/1024; 117 | elseif isa(Img,'logical') % logical images: fixed window of width 1 centred at 0.5 118 | MaxV = 0; 119 | MinV = 1; % NOTE(review): MaxV/MinV look swapped; harmless because LevV/Win are set explicitly here and MaxV/MinV are not read after this block 120 | LevV =0.5; 121 | Win = 1; 122 | WLAdjCoe = 0.1; 123 | end 124 | 125 | SFntSz = 9; % font sizes: slice label, level/window labels and values, button, checkbox 126 | LFntSz = 10; 127 | WFntSz = 10; 128 | LVFntSz = 9; 129 | WVFntSz = 9; 130 | BtnSz = 10; 131 | ChBxSz = 10; 132 | 133 | if (nargin < 2) % initial display range [Rmin Rmax]: defaults unless a non-empty disprange is supplied 134 | [Rmin Rmax] = WL2R(Win, LevV); 135 | elseif numel(disprange) == 0 136 | [Rmin Rmax] = WL2R(Win, LevV); 137 | else 138 | LevV = (double(disprange(2)) + double(disprange(1))) / 2; 139 | Win = double(disprange(2)) - double(disprange(1)); 140 | WLAdjCoe = (Win + 1)/1024; 141 | [Rmin Rmax] = WL2R(Win, LevV); 142 | end 143 | 144 | axes('position',[0,0.2,1,0.8]), imshow(squeeze(Img(:,:,S,:)), [Rmin Rmax]) % image axes use the top 80% of the figure; controls sit in the strip below 145 | 146 | FigPos = get(gcf,'Position'); % control positions below are laid out from the figure width 147 | S_Pos = [50 45 uint16(FigPos(3)-100)+1 20]; 148 | Stxt_Pos = [50 65 uint16(FigPos(3)-100)+1 15]; 149 | Wtxt_Pos = [50 20 60 20]; 150 | Wval_Pos = [110 20 60 20]; 151 | Ltxt_Pos = [175 20 45 20]; 152 | Lval_Pos = [220 20 60 20]; 153 | BtnStPnt = uint16(FigPos(3)-250)+1; 154 | if BtnStPnt < 300 155 | BtnStPnt = 300; 156 | end 157 | Btn_Pos = [BtnStPnt 20 100 20]; 158 | ChBx_Pos = [BtnStPnt+110 20 100 20]; 159 | 160 | if sno > 1 % the slice slider is only created for multi-slice input 161 | shand = uicontrol('Style', 'slider','Min',1,'Max',sno,'Value',S,'SliderStep',[1/(sno-1) 10/(sno-1)],'Position', S_Pos,'Callback', {@SliceSlider, Img}); 162 | stxthand = uicontrol('Style', 'text','Position', Stxt_Pos,'String',sprintf('Slice# %d / %d',S, sno), 'BackgroundColor', [0.8 0.8 0.8], 'FontSize', SFntSz); 163 | else 164 | stxthand = uicontrol('Style', 'text','Position',
Stxt_Pos,'String','2D image', 'BackgroundColor', [0.8 0.8 0.8], 'FontSize', SFntSz); 165 | end 166 | ltxthand = uicontrol('Style', 'text','Position', Ltxt_Pos,'String','Level: ', 'BackgroundColor', [0.8 0.8 0.8], 'FontSize', LFntSz); 167 | wtxthand = uicontrol('Style', 'text','Position', Wtxt_Pos,'String','Window: ', 'BackgroundColor', [0.8 0.8 0.8], 'FontSize', WFntSz); 168 | lvalhand = uicontrol('Style', 'edit','Position', Lval_Pos,'String',sprintf('%6.0f',LevV), 'BackgroundColor', [1 1 1], 'FontSize', LVFntSz,'Callback', @WinLevChanged); 169 | wvalhand = uicontrol('Style', 'edit','Position', Wval_Pos,'String',sprintf('%6.0f',Win), 'BackgroundColor', [1 1 1], 'FontSize', WVFntSz,'Callback', @WinLevChanged); 170 | Btnhand = uicontrol('Style', 'pushbutton','Position', Btn_Pos,'String','Auto W/L', 'FontSize', BtnSz, 'Callback' , @AutoAdjust); 171 | ChBxhand = uicontrol('Style', 'checkbox','Position', ChBx_Pos,'String','Fine Tune', 'BackgroundColor', [0.8 0.8 0.8], 'FontSize', ChBxSz); 172 | 173 | set (gcf, 'WindowScrollWheelFcn', @mouseScroll); % wire up scroll, click, release and resize callbacks 174 | set (gcf, 'ButtonDownFcn', @mouseClick); 175 | set(get(gca,'Children'),'ButtonDownFcn', @mouseClick); 176 | set(gcf,'WindowButtonUpFcn', @mouseRelease) 177 | set(gcf,'ResizeFcn', @figureResized) 178 | 179 | 180 | % -=< Figure resize callback function >=- 181 | function figureResized(object, eventdata) 182 | FigPos = get(gcf,'Position'); 183 | S_Pos = [50 45 uint16(FigPos(3)-100)+1 20]; 184 | Stxt_Pos = [50 65 uint16(FigPos(3)-100)+1 15]; 185 | BtnStPnt = uint16(FigPos(3)-250)+1; 186 | if BtnStPnt < 300 187 | BtnStPnt = 300; 188 | end 189 | Btn_Pos = [BtnStPnt 20 100 20]; 190 | ChBx_Pos = [BtnStPnt+110 20 100 20]; 191 | if sno > 1 192 | set(shand,'Position', S_Pos); 193 | end 194 | set(stxthand,'Position', Stxt_Pos); 195 | set(ltxthand,'Position', Ltxt_Pos); 196 | set(wtxthand,'Position', Wtxt_Pos); 197 | set(lvalhand,'Position', Lval_Pos); 198 | set(wvalhand,'Position', Wval_Pos); 199 | set(Btnhand,'Position', Btn_Pos);
200 | set(ChBxhand,'Position', ChBx_Pos); 201 | end 202 | 203 | % -=< Slice slider callback function >=- 204 | function SliceSlider (hObj,event, Img) 205 | S = round(get(hObj,'Value')); % updates the shared slice index S 206 | set(get(gca,'children'),'cdata',squeeze(Img(:,:,S,:))) 207 | caxis([Rmin Rmax]) 208 | if sno > 1 209 | set(stxthand, 'String', sprintf('Slice# %d / %d',S, sno)); 210 | else 211 | set(stxthand, 'String', '2D image'); 212 | end 213 | end 214 | 215 | % -=< Mouse scroll wheel callback function >=- 216 | function mouseScroll (object, eventdata) 217 | UPDN = eventdata.VerticalScrollCount; 218 | S = S - UPDN; % scrolling down (positive count) moves to lower slice numbers; clamped to [1, sno] below 219 | if (S < 1) 220 | S = 1; 221 | elseif (S > sno) 222 | S = sno; 223 | end 224 | if sno > 1 225 | set(shand,'Value',S); 226 | set(stxthand, 'String', sprintf('Slice# %d / %d',S, sno)); 227 | else 228 | set(stxthand, 'String', '2D image'); 229 | end 230 | set(get(gca,'children'),'cdata',squeeze(Img(:,:,S,:))) 231 | end 232 | 233 | % -=< Mouse button released callback function >=- 234 | function mouseRelease (object,eventdata) 235 | set(gcf, 'WindowButtonMotionFcn', '') 236 | end 237 | 238 | % -=< Mouse click callback function >=- 239 | function mouseClick (object, eventdata) 240 | MouseStat = get(gcbf, 'SelectionType'); 241 | if (MouseStat(1) == 'a') % RIGHT CLICK ('alt' selection type) 242 | InitialCoord = get(0,'PointerLocation'); 243 | set(gcf, 'WindowButtonMotionFcn', @WinLevAdj); 244 | end 245 | end 246 | 247 | % -=< Window and level mouse adjustment >=- 248 | function WinLevAdj(varargin) 249 | PosDiff = get(0,'PointerLocation') - InitialCoord; % pointer movement since the previous event 250 | 251 | Win = Win + PosDiff(1) * WLAdjCoe * FineTuneC(get(ChBxhand,'Value')+1); % horizontal drag changes the window; the Fine Tune checkbox value (0/1) selects the 1 or 1/16 factor 252 | LevV = LevV - PosDiff(2) * WLAdjCoe * FineTuneC(get(ChBxhand,'Value')+1); % vertical drag changes the level 253 | if (Win < 1) 254 | Win = 1; 255 | end 256 | 257 | [Rmin, Rmax] = WL2R(Win,LevV); 258 | caxis([Rmin, Rmax]) 259 | set(lvalhand, 'String', sprintf('%6.0f',LevV)); 260 | set(wvalhand, 'String', sprintf('%6.0f',Win)); 261 | InitialCoord = get(0,'PointerLocation'); 262 | end 263 | 264 | % -=<
Window and level text adjustment >=- 265 | function WinLevChanged(varargin) 266 | 267 | LevV = str2double(get(lvalhand, 'string')); % NOTE(review): non-numeric text yields NaN, which is not rejected here and reaches WL2R/caxis 268 | Win = str2double(get(wvalhand, 'string')); 269 | if (Win < 1) 270 | Win = 1; 271 | end 272 | 273 | [Rmin, Rmax] = WL2R(Win,LevV); 274 | caxis([Rmin, Rmax]) 275 | end 276 | 277 | % -=< Window and level to range conversion >=- 278 | function [Rmn Rmx] = WL2R(W,L) % window width W and level (centre) L -> display range [Rmn, Rmx], forcing Rmx > Rmn 279 | Rmn = L - (W/2); 280 | Rmx = L + (W/2); 281 | if (Rmn >= Rmx) 282 | Rmx = Rmn + 1; 283 | end 284 | end 285 | 286 | % -=< Window and level auto adjustment callback function >=- 287 | function AutoAdjust(object,eventdata) 288 | Win = double(max(Img(:))-min(Img(:))); % window spans the data range of Img (at least 1); NOTE(review): for integer classes the subtraction happens in that class and saturates (e.g. int8 data spanning -100..100 gives 127) 289 | Win (Win < 1) = 1; 290 | LevV = double(min(Img(:)) + (Win/2)); 291 | [Rmin, Rmax] = WL2R(Win,LevV); 292 | caxis([Rmin, Rmax]) 293 | set(lvalhand, 'String', sprintf('%6.0f',LevV)); 294 | set(wvalhand, 'String', sprintf('%6.0f',Win)); 295 | end 296 | 297 | end 298 | % -=< Maysam Shahedi (mshahedi@gmail.com), September 22, 2016>=- --------------------------------------------------------------------------------