├── .gitignore ├── EEG └── eeg_data.mat ├── big_images ├── 1.jpg ├── 10.jpg ├── 2.jpg ├── 3.jpg ├── 4.jpg ├── 5.jpg ├── 6.jpg ├── 7.jpg ├── 8.jpg └── 9.jpg ├── large_images ├── 1.jpg ├── 2.jpg ├── 3.jpg └── 4.JPG ├── medium_images ├── 1.jpg ├── 10.jpg ├── 2.jpg ├── 3.jpg ├── 4.jpg ├── 5.jpg ├── 6.jpg ├── 7.jpg ├── 8.jpg └── 9.jpg ├── Code_image ├── coding.png ├── filters_chau.png ├── coding_wrt_beta.png ├── objective_func.png ├── reconstruction.png ├── reconstr_err_wrt_beta.png ├── compare_beta.py ├── visualize_results_image.py ├── main_image.py └── CSC_image.py ├── small_images ├── image1.jpg ├── image2.jpg └── image3.jpg ├── single_image └── houses_big.jpg ├── Code_Signal ├── Result │ ├── codes_0.png │ ├── filters.png │ ├── codes_10.png │ ├── codes_20.png │ ├── codes_30.png │ ├── codes_40.png │ ├── codes_0_fit.png │ ├── objective_func.png │ ├── reconstruction.png │ └── original_signal.png ├── .ipynb_checkpoints │ └── Untitled-checkpoint.ipynb ├── main_signal_unknown_filter.py ├── main_signal_known_filter.py ├── visualize_code fitting.py ├── visualize_results_signal.py └── CSC_signal.py └── README.md /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | *.npz 3 | -------------------------------------------------------------------------------- /EEG/eeg_data.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PC91/fast-flexible-conv-sparse-coding/HEAD/EEG/eeg_data.mat -------------------------------------------------------------------------------- /big_images/1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PC91/fast-flexible-conv-sparse-coding/HEAD/big_images/1.jpg -------------------------------------------------------------------------------- /big_images/10.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/PC91/fast-flexible-conv-sparse-coding/HEAD/big_images/10.jpg -------------------------------------------------------------------------------- /big_images/2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PC91/fast-flexible-conv-sparse-coding/HEAD/big_images/2.jpg -------------------------------------------------------------------------------- /big_images/3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PC91/fast-flexible-conv-sparse-coding/HEAD/big_images/3.jpg -------------------------------------------------------------------------------- /big_images/4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PC91/fast-flexible-conv-sparse-coding/HEAD/big_images/4.jpg -------------------------------------------------------------------------------- /big_images/5.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PC91/fast-flexible-conv-sparse-coding/HEAD/big_images/5.jpg -------------------------------------------------------------------------------- /big_images/6.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PC91/fast-flexible-conv-sparse-coding/HEAD/big_images/6.jpg -------------------------------------------------------------------------------- /big_images/7.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PC91/fast-flexible-conv-sparse-coding/HEAD/big_images/7.jpg -------------------------------------------------------------------------------- /big_images/8.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/PC91/fast-flexible-conv-sparse-coding/HEAD/big_images/8.jpg -------------------------------------------------------------------------------- /big_images/9.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PC91/fast-flexible-conv-sparse-coding/HEAD/big_images/9.jpg -------------------------------------------------------------------------------- /large_images/1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PC91/fast-flexible-conv-sparse-coding/HEAD/large_images/1.jpg -------------------------------------------------------------------------------- /large_images/2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PC91/fast-flexible-conv-sparse-coding/HEAD/large_images/2.jpg -------------------------------------------------------------------------------- /large_images/3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PC91/fast-flexible-conv-sparse-coding/HEAD/large_images/3.jpg -------------------------------------------------------------------------------- /large_images/4.JPG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PC91/fast-flexible-conv-sparse-coding/HEAD/large_images/4.JPG -------------------------------------------------------------------------------- /medium_images/1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PC91/fast-flexible-conv-sparse-coding/HEAD/medium_images/1.jpg -------------------------------------------------------------------------------- /medium_images/10.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/PC91/fast-flexible-conv-sparse-coding/HEAD/medium_images/10.jpg -------------------------------------------------------------------------------- /medium_images/2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PC91/fast-flexible-conv-sparse-coding/HEAD/medium_images/2.jpg -------------------------------------------------------------------------------- /medium_images/3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PC91/fast-flexible-conv-sparse-coding/HEAD/medium_images/3.jpg -------------------------------------------------------------------------------- /medium_images/4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PC91/fast-flexible-conv-sparse-coding/HEAD/medium_images/4.jpg -------------------------------------------------------------------------------- /medium_images/5.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PC91/fast-flexible-conv-sparse-coding/HEAD/medium_images/5.jpg -------------------------------------------------------------------------------- /medium_images/6.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PC91/fast-flexible-conv-sparse-coding/HEAD/medium_images/6.jpg -------------------------------------------------------------------------------- /medium_images/7.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PC91/fast-flexible-conv-sparse-coding/HEAD/medium_images/7.jpg -------------------------------------------------------------------------------- /medium_images/8.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/PC91/fast-flexible-conv-sparse-coding/HEAD/medium_images/8.jpg -------------------------------------------------------------------------------- /medium_images/9.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PC91/fast-flexible-conv-sparse-coding/HEAD/medium_images/9.jpg -------------------------------------------------------------------------------- /Code_image/coding.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PC91/fast-flexible-conv-sparse-coding/HEAD/Code_image/coding.png -------------------------------------------------------------------------------- /small_images/image1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PC91/fast-flexible-conv-sparse-coding/HEAD/small_images/image1.jpg -------------------------------------------------------------------------------- /small_images/image2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PC91/fast-flexible-conv-sparse-coding/HEAD/small_images/image2.jpg -------------------------------------------------------------------------------- /small_images/image3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PC91/fast-flexible-conv-sparse-coding/HEAD/small_images/image3.jpg -------------------------------------------------------------------------------- /Code_image/filters_chau.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PC91/fast-flexible-conv-sparse-coding/HEAD/Code_image/filters_chau.png -------------------------------------------------------------------------------- /single_image/houses_big.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/PC91/fast-flexible-conv-sparse-coding/HEAD/single_image/houses_big.jpg -------------------------------------------------------------------------------- /Code_Signal/Result/codes_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PC91/fast-flexible-conv-sparse-coding/HEAD/Code_Signal/Result/codes_0.png -------------------------------------------------------------------------------- /Code_Signal/Result/filters.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PC91/fast-flexible-conv-sparse-coding/HEAD/Code_Signal/Result/filters.png -------------------------------------------------------------------------------- /Code_image/coding_wrt_beta.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PC91/fast-flexible-conv-sparse-coding/HEAD/Code_image/coding_wrt_beta.png -------------------------------------------------------------------------------- /Code_image/objective_func.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PC91/fast-flexible-conv-sparse-coding/HEAD/Code_image/objective_func.png -------------------------------------------------------------------------------- /Code_image/reconstruction.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PC91/fast-flexible-conv-sparse-coding/HEAD/Code_image/reconstruction.png -------------------------------------------------------------------------------- /Code_Signal/Result/codes_10.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/PC91/fast-flexible-conv-sparse-coding/HEAD/Code_Signal/Result/codes_10.png -------------------------------------------------------------------------------- /Code_Signal/Result/codes_20.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PC91/fast-flexible-conv-sparse-coding/HEAD/Code_Signal/Result/codes_20.png -------------------------------------------------------------------------------- /Code_Signal/Result/codes_30.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PC91/fast-flexible-conv-sparse-coding/HEAD/Code_Signal/Result/codes_30.png -------------------------------------------------------------------------------- /Code_Signal/Result/codes_40.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PC91/fast-flexible-conv-sparse-coding/HEAD/Code_Signal/Result/codes_40.png -------------------------------------------------------------------------------- /Code_Signal/Result/codes_0_fit.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PC91/fast-flexible-conv-sparse-coding/HEAD/Code_Signal/Result/codes_0_fit.png -------------------------------------------------------------------------------- /Code_Signal/Result/objective_func.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PC91/fast-flexible-conv-sparse-coding/HEAD/Code_Signal/Result/objective_func.png -------------------------------------------------------------------------------- /Code_Signal/Result/reconstruction.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PC91/fast-flexible-conv-sparse-coding/HEAD/Code_Signal/Result/reconstruction.png 
-------------------------------------------------------------------------------- /Code_image/reconstr_err_wrt_beta.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PC91/fast-flexible-conv-sparse-coding/HEAD/Code_image/reconstr_err_wrt_beta.png -------------------------------------------------------------------------------- /Code_Signal/.ipynb_checkpoints/Untitled-checkpoint.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [], 3 | "metadata": {}, 4 | "nbformat": 4, 5 | "nbformat_minor": 0 6 | } 7 | -------------------------------------------------------------------------------- /Code_Signal/Result/original_signal.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PC91/fast-flexible-conv-sparse-coding/HEAD/Code_Signal/Result/original_signal.png -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Code for the paper "Fast and Flexible Convolutional Sparse Coding" in Python 2 | Re-implementation of the paper Fast and Flexible Convolutional Sparse Coding (of 3 authors Felix Heide, Wolfgang Heidrich, Gordon Wetzstein) in Python 3 | https://www.cv-foundation.org/openaccess/content_cvpr_2015/app/3B_075.pdf 4 | -------------------------------------------------------------------------------- /Code_image/compare_beta.py: -------------------------------------------------------------------------------- 1 | """ 2 | This code is used to visualize the reconstruction and codes w.r.t. 
beta 3 | """ 4 | import numpy as np 5 | import matplotlib.pyplot as plt 6 | import matplotlib.cm as cm 7 | 8 | A = np.load('reconstr_err_wrt_beta.npz') 9 | 10 | list_reconstr_err_wrt_beta=A['list_reconstr_err_wrt_beta'] 11 | 12 | #Draw the reconstruction errors 13 | fig0 = plt.figure(0) 14 | plt.plot(np.arange(0.5, 5.1, 0.5), list_reconstr_err_wrt_beta[0:10]) 15 | plt.xlabel('Beta') 16 | plt.ylabel('Reconstruction error') 17 | plt.savefig('reconstr_err_wrt_beta.png') 18 | 19 | 20 | nrow = 10 21 | ncol = 11 22 | 23 | #Draw the 10 codes and the reconstructed image of the 10 original images 24 | fig1 = plt.figure(1) 25 | cbar_ax_code = fig1.add_axes([0, 0.15, 0.05, 0.7]) 26 | cbar_ax_recon = fig1.add_axes([0.92, 0.15, 0.05, 0.7]) 27 | for i in range(nrow): 28 | I = np.load('medium_10_5_5_beta_' + str((i + 1) / 2.0) + '.npz') 29 | z = I['z'] 30 | Dz = I['Dz'] 31 | for j in range(ncol): 32 | fig1.add_subplot(nrow, ncol, i * ncol + j + 1) 33 | if (j == ncol - 1): 34 | #reconstructed image 35 | im_recon = plt.imshow(Dz[0], cmap = cm.Greys_r) 36 | else: 37 | #sparse codes 38 | im_code = plt.imshow(abs(z[0][j]), cmap = cm.Greys_r, vmin = np.min(abs(z)), vmax = np.max(abs(z))) 39 | plt.axis('off') 40 | plt.subplots_adjust(left=0.08, bottom=0.0, right=0.9, top=1.0, wspace=0.15, hspace=0.05) 41 | fig1.colorbar(im_code, cax=cbar_ax_code) 42 | fig1.colorbar(im_recon, cax=cbar_ax_recon) 43 | mng = plt.get_current_fig_manager() 44 | mng.window.showMaximized() 45 | plt.savefig('coding_wrt_beta.png') -------------------------------------------------------------------------------- /Code_image/visualize_results_image.py: -------------------------------------------------------------------------------- 1 | """ 2 | This code is used to visualize stored results from one time of running the method 3 | """ 4 | import numpy as np 5 | import matplotlib.pyplot as plt 6 | import matplotlib.cm as cm 7 | 8 | #Load the result struct 9 | A = np.load('medium_10_5_5.npz') 10 | #Kernel 11 | d=A['d'] 
12 | #Sparse codes 13 | z=A['z'] 14 | #Reconstructed images 15 | Dz=A['Dz'] 16 | #List of objective function's values over iterations 17 | list_obj_val_z = A['list_obj_val_z'] 18 | 19 | #Draw the objective function's values over iterations 20 | fig0 = plt.figure(0) 21 | plt.plot(list_obj_val_z[1:10]) 22 | plt.xticks(np.arange(0, 9), np.arange(2, 11)) 23 | plt.xlabel('Current iteration') 24 | plt.ylabel('Objective function') 25 | plt.savefig('objective_func.png') 26 | 27 | nrow = 1 28 | ncol = 10 29 | 30 | #Draw 10 reconstructed images 31 | fig1 = plt.figure(1) 32 | for i in range(1): 33 | for j in range(10): 34 | fig1.add_subplot(nrow, ncol, i * ncol + j + 1) 35 | plt.imshow(Dz[i * ncol + j], cmap = cm.Greys_r, vmin = np.min(Dz), vmax = np.max(Dz)) 36 | plt.axis('off') 37 | plt.savefig('reconstruction.png', bbox_inches='tight') 38 | 39 | #Draw 10 kernels 40 | fig2 = plt.figure(2) 41 | for i in range(nrow): 42 | for j in range(ncol): 43 | fig2.add_subplot(nrow, ncol, i * ncol + j + 1) 44 | plt.imshow(d[i * ncol + j], cmap = cm.Greys_r, vmin = np.min(d), vmax = np.max(d)) 45 | plt.axis('off') 46 | plt.colorbar() 47 | plt.savefig('filters_chau.png', bbox_inches='tight') 48 | 49 | #Draw 10 codes corresponding to 10 kernels of the first image 50 | fig3 = plt.figure(3) 51 | for i in range(nrow): 52 | for j in range(ncol): 53 | fig3.add_subplot(nrow, ncol, i * ncol + j + 1) 54 | plt.imshow(abs(z[0][i * ncol + j]), cmap = cm.Greys_r, vmin = np.min(abs(z[0])), vmax = np.max(abs(z[0]))) 55 | plt.axis('off') 56 | plt.savefig('coding.png', bbox_inches='tight') -------------------------------------------------------------------------------- /Code_Signal/main_signal_unknown_filter.py: -------------------------------------------------------------------------------- 1 | """ 2 | This code is used to find a kernel and the corresponding codes for signal data 3 | """ 4 | import numpy as np 5 | from scipy.io import loadmat 6 | import CSC_signal as CSC 7 | import time 8 | 9 | start = 
time.clock() 10 | 11 | #Set this variable the data path to the EEG dataset 12 | #ata_path = '/home/chau/Dropbox/PRIM/EEG/' 13 | data_path = '/Users/huydinh/Dropbox/PRIM/EEG/' 14 | 15 | #Set this variable the path for the result folder 16 | result_path = '/home/chau/Dropbox/PRIM/Code_Signal/Result/' 17 | 18 | #Load data to be run 19 | mat = loadmat(data_path + 'eeg_data.mat') 20 | eeg_signal = mat['eeg_signal'] 21 | eeg_label = mat['eeg_label'] 22 | 23 | #Extract the wanted subset of data 24 | eeg_signal = np.concatenate((eeg_signal[0:1,:], eeg_signal[20:21,:],\ 25 | eeg_signal[40:41,:], eeg_signal[60:61,:],\ 26 | eeg_signal[80:81,:]), axis=0) 27 | 28 | n_signal = eeg_signal.shape[0] #the number of signals 29 | n_sample = eeg_signal.shape[1] #dimension of each signal 30 | 31 | #Normalize the data to be run 32 | b = eeg_signal 33 | b = ((b.T - np.mean(b, axis = 1)) / np.std(b, axis = 1)).T 34 | 35 | #Define the parameters 36 | #128 kernels with size of 201 37 | size_kernel = [128, 201] 38 | 39 | #Optim options 40 | max_it = 30 #the number of iterations 41 | tol = np.float64(1e-3) #the stop threshold for the algorithm 42 | 43 | #RUN THE ALGORITHM 44 | [d, z, Dz, list_obj_val, list_obj_val_filter, list_obj_val_z, reconstr_err] = \ 45 | CSC.learn_conv_sparse_coder(b, size_kernel, max_it, tol) 46 | 47 | #Save into an external struct .npz the kernel, the codes, the reconstruction 48 | #and list of objective function's values 49 | np.savez(result_path + 'result_eeg.npz',\ 50 | d=d, z=z, Dz=Dz, list_obj_val=list_obj_val,\ 51 | list_obj_val_filter=list_obj_val_filter, list_obj_val_z=list_obj_val_z) 52 | 53 | end = time.clock() 54 | 55 | print end - start -------------------------------------------------------------------------------- /Code_Signal/main_signal_known_filter.py: -------------------------------------------------------------------------------- 1 | """ 2 | This code is used to find codes for signal data when the filter is already known 3 | """ 4 | import numpy as 
np 5 | from scipy.io import loadmat 6 | import CSC_signal as CSC 7 | import time 8 | 9 | start = time.clock() 10 | 11 | #Set this variable the data path to the EEG dataset 12 | #data_path = '/home/chau/Dropbox/PRIM/EEG/' 13 | data_path = '/Users/huydinh/Dropbox/PRIM/EEG/' 14 | 15 | #Set this variable the path for the result folder 16 | #result_path = '/home/chau/Dropbox/PRIM/Code_Signal/Result/' 17 | result_path = '/Users/huydinh/Dropbox/PRIM/Code_Signal/Result/' 18 | 19 | #Load the available result from previous times of running 20 | output = np.load(result_path + 'result_eeg_50_128_201.npz') 21 | 22 | #d is the known filter 23 | d = output['d'] 24 | #list of objective function's values after each iteration 25 | list_obj_val = output['list_obj_val'] 26 | #list of objective function's values after each d update 27 | list_obj_val_filter = output['list_obj_val_filter'] 28 | #list of objective function's values after each z update 29 | list_obj_val_z = output['list_obj_val_z'] 30 | 31 | 32 | #Load data to be run 33 | mat = loadmat(data_path + 'eeg_data.mat') 34 | eeg_signal = mat['eeg_signal'] 35 | eeg_label = mat['eeg_label'] 36 | 37 | #Extract the wanted subset of data 38 | eeg_signal = eeg_signal[99,:] 39 | eeg_signal = eeg_signal[None, :] 40 | 41 | n_signal = 1 #the number of signals 42 | n_sample = eeg_signal.shape[1] #dimension of each signal 43 | 44 | #Normalize the data to be run 45 | b = eeg_signal 46 | b = ((b.T - np.mean(b, axis = 1)) / np.std(b, axis = 1)).T 47 | 48 | #Define the parameters 49 | #Size of d, i.e. 
128 kernels with size of 201 50 | size_kernel = [128, 201] 51 | 52 | #Optim options 53 | max_it = 30 #the number of iterations 54 | tol = np.float64(1e-3) #the stop threshold for the algorithm 55 | 56 | #RUN THE ALGORITHM 57 | [temp, z, Dz, list_obj_val, list_obj_val_filter, list_obj_val_z, reconstr_err] = \ 58 | CSC.learn_conv_sparse_coder(b, size_kernel, max_it, tol,\ 59 | known_d=d) 60 | 61 | end = time.clock() 62 | 63 | print end - start -------------------------------------------------------------------------------- /Code_image/main_image.py: -------------------------------------------------------------------------------- 1 | """ 2 | This code is used to find a kernel and the corresponding codes for image data 3 | """ 4 | import numpy as np 5 | from PIL import Image 6 | import CSC_image as CSC 7 | import os 8 | import time 9 | 10 | def listdir_nohidden(path): 11 | for f in os.listdir(path): 12 | if not f.startswith('.'): 13 | yield f 14 | 15 | start = time.clock() 16 | 17 | #Set this variable the data path to the image dataset 18 | #path = '/home/chau/Dropbox/PRIM/small_images/' 19 | path = '/Users/huydinh/Dropbox/PRIM/medium_images/' 20 | 21 | #Get the images 22 | fileList_gen = listdir_nohidden(path) 23 | fileList = list(fileList_gen) 24 | sample_image = Image.open(path + fileList[0], 'r') 25 | 26 | b = np.zeros((len(fileList), sample_image.size[1], sample_image.size[0]), dtype=np.float64) 27 | 28 | #Normalize each image to zero mean and unit variance 29 | idx = 0 30 | for infile in fileList: 31 | image = Image.open(path + infile, 'r').convert('L') 32 | b[idx,:] = np.asarray(image.getdata(),dtype=np.float64).\ 33 | reshape((image.size[1], image.size[0])) 34 | b[idx,:] = b[idx,:] / 255; 35 | b[idx,:] = (b[idx,:] - np.mean(b[idx,:])) / np.std(b[idx,:]) 36 | 37 | idx += 1 38 | 39 | #Define the parameters 40 | #10 kernels with size of 5 x 5 41 | size_kernel = [10, 5, 5] 42 | 43 | #Optim options 44 | max_it = 20 #number of iterations 45 | tol = np.float64(1e-3) 
#stop threshold 46 | 47 | #List of testing beta 48 | list_reconstr_err_wrt_beta = np.zeros(10) 49 | 50 | #Run the algorithm with beta from 0.5 to 5.0 by the increase of 0.5 51 | for beta in np.arange(0.5, 5.1, 0.5): 52 | [d, z, Dz, list_obj_val, list_obj_val_filter, list_obj_val_z, reconstr_err] = \ 53 | CSC.learn_conv_sparse_coder(b, size_kernel, max_it, tol, beta) 54 | list_reconstr_err_wrt_beta[int(round(beta * 2)) - 1] = reconstr_err  # beta is a float; NumPy rejects float indices, so convert to an int slot 0..9 55 | 56 | np.savez('medium_10_5_5_beta_' + str(beta) + '.npz', d=d, z=z, Dz=Dz, 57 | list_obj_val=list_obj_val, 58 | list_obj_val_filter=list_obj_val_filter, 59 | list_obj_val_z=list_obj_val_z) 60 | 61 | #Save the list of reconstruction errors with respect to beta 62 | np.savez('reconstr_err_wrt_beta.npz', list_reconstr_err_wrt_beta=list_reconstr_err_wrt_beta) 63 | end = time.clock() 64 | 65 | print end - start -------------------------------------------------------------------------------- /Code_Signal/visualize_code fitting.py: -------------------------------------------------------------------------------- 1 | """ 2 | This code is used to visualize stored results from one time of running the method 3 | """ 4 | import numpy as np 5 | from scipy.io import loadmat 6 | import matplotlib.pyplot as plt 7 | 8 | plt.close("all") 9 | 10 | #Set this variable the data path to the EEG dataset 11 | data_path = '/home/chau/Dropbox/PRIM/EEG/' 12 | #data_path = '/Users/huydinh/Dropbox/PRIM/EEG/' 13 | 14 | #Set this variable the path for the result folder 15 | result_path = '/home/chau/Dropbox/PRIM/Code_Signal/Result/' 16 | #result_path = '/Users/huydinh/Dropbox/PRIM/Code_Signal/Result/' 17 | 18 | #Load data to be run 19 | mat = loadmat(data_path + 'eeg_data.mat') 20 | eeg_signal = mat['eeg_signal'] 21 | eeg_label = mat['eeg_label'] 22 | 23 | #Load the available result from previous times of running 24 | output = np.load(result_path + 'result_eeg_50_128_201.npz') 25 | #Kernel 26 | d = output['d'] 27 | #Sparse codes 28 | z = output['z'] 29 | #Reconstruction
30 | Dz = output['Dz'] 31 | #list of objective function's values after each iteration 32 | list_obj_val =output['list_obj_val'] 33 | #list of objective function's values after each d update 34 | list_obj_val_filter = output['list_obj_val_filter'] 35 | #list of objective function's values after each z update 36 | list_obj_val_z = output['list_obj_val_z'] 37 | 38 | #Only take the main part of the reconstruction, not the added border part by the kernel 39 | Dz = Dz[:,100:6100] 40 | #Only take the main part of the codes, not the added border part by the kernel 41 | z = z[:,:,100:6100] 42 | #Make a signal become 0 if it is very close to 0 43 | z[np.max(z,axis=2)<0.005] = 0 44 | 45 | #Number of signals 46 | n_signal = eeg_signal.shape[0] 47 | #Dimention of each signal 48 | n_sample = eeg_signal.shape[1] 49 | #Number of kernels 50 | n_filter = d.shape[0] 51 | 52 | #Normalize the data 53 | b = eeg_signal 54 | b = ((b.T - np.mean(b, axis = 1)) / np.std(b, axis = 1)).T 55 | 56 | #Extract the corresponding original data for the stored result 57 | b = np.concatenate((b[0:10,:], b[20:30,:],\ 58 | b[40:50,:], b[60:70,:],\ 59 | b[80:90,:]), axis=0) 60 | 61 | #Draw the first singla and its first 10 codes corresponding to the first 10 kernels 62 | colors = ['r','g','b'] 63 | nrow = 11 64 | id_signal = 0 65 | fig4 = plt.figure(114) 66 | fig4.add_subplot(nrow, 1, 1) 67 | plt.plot(b[0]) 68 | plt.axis('off') 69 | for i in range(nrow-1): 70 | fig4.add_subplot(nrow, 1, i+2) 71 | plt.plot(z[id_signal, i, :], color=colors[np.mod(i,3)]) 72 | plt.axis('off') 73 | plt.subplots_adjust(left=0.0, bottom=0.0, right=1.0, top=1.0, wspace=0.15, hspace=0.05) 74 | figManager = plt.get_current_fig_manager() 75 | figManager.window.showMaximized() 76 | fig4.canvas.set_window_title('codes_' + str(id_signal)) -------------------------------------------------------------------------------- /Code_Signal/visualize_results_signal.py: 
-------------------------------------------------------------------------------- 1 | """ 2 | This code is used to visualize stored results from one time of running the method 3 | """ 4 | import numpy as np 5 | from scipy.io import loadmat 6 | import matplotlib.pyplot as plt 7 | 8 | plt.close("all") 9 | 10 | #Set this variable the data path to the EEG dataset 11 | data_path = '/home/chau/Dropbox/PRIM/EEG/' 12 | #data_path = '/Users/huydinh/Dropbox/PRIM/EEG/' 13 | 14 | #Set this variable the path for the result folder 15 | result_path = '/home/chau/Dropbox/PRIM/Code_Signal/Result/' 16 | #result_path = '/Users/huydinh/Dropbox/PRIM/Code_Signal/Result/' 17 | 18 | #Load data to be run 19 | mat = loadmat(data_path + 'eeg_data.mat') 20 | eeg_signal = mat['eeg_signal'] 21 | eeg_label = mat['eeg_label'] 22 | 23 | #Load the available result from previous times of running 24 | output = np.load(result_path + 'result_eeg_50_128_201.npz') 25 | #Kernel 26 | d = output['d'] 27 | #Sparse codes 28 | z = output['z'] 29 | #Reconstruction 30 | Dz = output['Dz'] 31 | #list of objective function's values after each iteration 32 | list_obj_val =output['list_obj_val'] 33 | #list of objective function's values after each d update 34 | list_obj_val_filter = output['list_obj_val_filter'] 35 | #list of objective function's values after each z update 36 | list_obj_val_z = output['list_obj_val_z'] 37 | 38 | #Only take the main part of the reconstruction, not the added part by the kernel 39 | Dz = Dz[:,100:6100] 40 | #Make a signal become 0 if it is very close to 0 41 | z[np.max(z,axis=2)<0.005] = 0 42 | 43 | #Number of signals 44 | n_signal = eeg_signal.shape[0] 45 | #Dimention of each signal 46 | n_sample = eeg_signal.shape[1] 47 | #Number of kernels 48 | n_filter = d.shape[0] 49 | 50 | #Normalize the data 51 | b = eeg_signal 52 | b = ((b.T - np.mean(b, axis = 1)) / np.std(b, axis = 1)).T 53 | 54 | #Extract the corresponding original data for the stored result 55 | b = np.concatenate((b[0:10,:], 
b[20:30,:],\ 56 | b[40:50,:], b[60:70,:],\ 57 | b[80:90,:]), axis=0) 58 | 59 | #Draw the objective function's values over iterations 60 | fig0 = plt.figure(0) 61 | plt.plot(list_obj_val_z[5:30]) 62 | plt.xticks(np.arange(0, 25, 2), np.arange(6, 31, 2)) 63 | plt.xlabel('Current iteration') 64 | plt.ylabel('Objective function') 65 | plt.savefig(result_path+'objective_func.png') 66 | 67 | 68 | #Draw 5 sample original signals 69 | fig1 = plt.figure(1) 70 | for i in range(5): 71 | fig1.add_subplot(10, 1, i+1) 72 | plt.plot(b[i * 10, :]) 73 | plt.axis('off') 74 | plt.subplots_adjust(left=0.0, bottom=0.0, right=1.0, top=1.0, hspace=0.05) 75 | figManager = plt.get_current_fig_manager() 76 | fig1.canvas.set_window_title('original signal') 77 | plt.savefig(result_path+'original_signal.png') 78 | 79 | #Draw 5 corresponding reconstructed signals 80 | fig2 = plt.figure(2) 81 | for i in range(5): 82 | fig2.add_subplot(10, 1, i+1) 83 | plt.plot(Dz[i * 10, :]) 84 | plt.axis('off') 85 | plt.subplots_adjust(left=0.0, bottom=0.0, right=1.0, top=1.0, hspace=0.05) 86 | figManager = plt.get_current_fig_manager() 87 | fig2.canvas.set_window_title('reconstruction') 88 | plt.savefig(result_path+'reconstruction.png') 89 | 90 | 91 | #Draw 128 kernels 92 | nrow = 16 93 | ncol = 8 94 | fig3 = plt.figure(3) 95 | for i in range(nrow): 96 | for j in range(ncol): 97 | if (i * ncol + j < n_filter): 98 | fig3.add_subplot(nrow, ncol, i * ncol + j + 1) 99 | plt.plot(d[i*ncol+j, :]) 100 | plt.axis('off') 101 | plt.subplots_adjust(left=0.0, bottom=0.0, right=1.0, top=1.0, wspace=0.15, hspace=0.05) 102 | figManager = plt.get_current_fig_manager() 103 | fig3.canvas.set_window_title('filters') 104 | plt.savefig(result_path+'filters.png') 105 | 106 | #Draw 128 codes corresponding to 128 kernels of signal id_signal, herer the 0th signal 107 | nrow = 16 108 | ncol = 8 109 | id_signal = 0 110 | fig4 = plt.figure(114) 111 | for i in range(nrow): 112 | for j in range(ncol): 113 | if (i * ncol + j < n_filter): 
114 | fig4.add_subplot(nrow, ncol, i * ncol + j + 1) 115 | plt.plot(z[id_signal, i * ncol + j, :]) 116 | plt.axis('off') 117 | plt.subplots_adjust(left=0.0, bottom=0.0, right=1.0, top=1.0, wspace=0.15, hspace=0.05) 118 | figManager = plt.get_current_fig_manager() 119 | fig4.canvas.set_window_title('codes_' + str(id_signal)) 120 | plt.savefig(result_path+'codes_' + str(id_signal) + '.png') 121 | 122 | #Draw 128 codes corresponding to 128 kernels of signal id_signal, herer the 10th signal 123 | nrow = 16 124 | ncol = 8 125 | id_signal = 10 126 | fig5 = plt.figure(115) 127 | for i in range(nrow): 128 | for j in range(ncol): 129 | if (i * ncol + j < n_filter): 130 | fig5.add_subplot(nrow, ncol, i * ncol + j + 1) 131 | plt.plot(z[id_signal, i * ncol + j, :]) 132 | plt.axis('off') 133 | plt.subplots_adjust(left=0.0, bottom=0.0, right=1.0, top=1.0, wspace=0.15, hspace=0.05) 134 | figManager = plt.get_current_fig_manager() 135 | fig5.canvas.set_window_title('codes_' + str(id_signal)) 136 | plt.savefig(result_path+'codes_' + str(id_signal) + '.png') 137 | 138 | #Draw 128 codes corresponding to 128 kernels of signal id_signal, herer the 20th signal 139 | nrow = 16 140 | ncol = 8 141 | id_signal = 20 142 | fig6 = plt.figure(116) 143 | for i in range(nrow): 144 | for j in range(ncol): 145 | if (i * ncol + j < n_filter): 146 | fig6.add_subplot(nrow, ncol, i * ncol + j + 1) 147 | plt.plot(z[id_signal, i * ncol + j, :]) 148 | plt.axis('off') 149 | 150 | plt.subplots_adjust(left=0.0, bottom=0.0, right=1.0, top=1.0, wspace=0.15, hspace=0.05) 151 | figManager = plt.get_current_fig_manager() 152 | fig6.canvas.set_window_title('codes_' + str(id_signal)) 153 | plt.savefig(result_path+'codes_' + str(id_signal) + '.png') 154 | 155 | #Draw 128 codes corresponding to 128 kernels of signal id_signal, herer the 30th signal 156 | nrow = 16 157 | ncol = 8 158 | id_signal = 30 159 | fig7 = plt.figure(117) 160 | for i in range(nrow): 161 | for j in range(ncol): 162 | if (i * ncol + j < 
n_filter): 163 | fig7.add_subplot(nrow, ncol, i * ncol + j + 1) 164 | plt.plot(z[id_signal, i * ncol + j, :]) 165 | plt.axis('off') 166 | plt.subplots_adjust(left=0.0, bottom=0.0, right=1.0, top=1.0, wspace=0.15, hspace=0.05) 167 | figManager = plt.get_current_fig_manager() 168 | fig7.canvas.set_window_title('codes_' + str(id_signal)) 169 | plt.savefig(result_path+'codes_' + str(id_signal) + '.png') 170 | 171 | #Draw 128 codes corresponding to 128 kernels of signal id_signal, herer the 40th signal 172 | nrow = 16 173 | ncol = 8 174 | id_signal = 40 175 | fig8 = plt.figure(118) 176 | for i in range(nrow): 177 | for j in range(ncol): 178 | if (i * ncol + j < n_filter): 179 | fig8.add_subplot(nrow, ncol, i * ncol + j + 1) 180 | plt.plot(z[id_signal, i * ncol + j, :]) 181 | plt.axis('off') 182 | 183 | plt.subplots_adjust(left=0.0, bottom=0.0, right=1.0, top=1.0, wspace=0.15, hspace=0.05) 184 | figManager = plt.get_current_fig_manager() 185 | fig8.canvas.set_window_title('codes_' + str(id_signal)) 186 | plt.savefig(result_path+'codes_' + str(id_signal) + '.png') -------------------------------------------------------------------------------- /Code_Signal/CSC_signal.py: -------------------------------------------------------------------------------- 1 | """ 2 | - This code implements a solver for the paper "Fast and Flexible Convolutional Sparse Coding" on signal data. 3 | - The goal of this solver is to find the common filters, the codes for each signal series in the dataset 4 | and a reconstruction for the dataset. 
5 | - The common filters (or kernels) and the codes for each image is denoted as d and z respectively 6 | - We denote the step to solve the filter as d-step and the codes z-step 7 | """ 8 | import numpy as np 9 | from scipy import linalg 10 | from scipy.fftpack import fft, ifft 11 | 12 | real_type = 'float64' 13 | imaginary_type = 'complex128' 14 | 15 | def learn_conv_sparse_coder(b, size_kernel, max_it, tol, 16 | known_d=None, 17 | beta=np.float64(1.0)): 18 | """ 19 | Main function to solve the convolutional sparse coding 20 | Parameters for this function 21 | - b : the signal dataset with size (num_signals, length) 22 | - size_kernel : the size of each kernel (num_kernels, length) 23 | - beta : the trade-off between sparsity and reconstruction error (as mentioned in the paper) 24 | - max_it : the maximum iterations of the outer loop 25 | - tol : the minimal difference in filters and codes after each iteration to continue 26 | - known_d : the predefined filters (if possible) 27 | 28 | Important variables used in the code: 29 | - u_D, u_Z : pair of proximal values for d-step and z-step 30 | - d_D, d_Z : pair of Lagrange multipliers in the ADMM algo for d-step and z-step 31 | - v_D, v_Z : pair of initial value pairs (Zd, d) for d-step and (Dz, z) for z-step 32 | """ 33 | 34 | k = size_kernel[0] 35 | n = b.shape[0] 36 | 37 | psf_radius = int(np.floor(size_kernel[1]/2)) 38 | 39 | size_x = [n, b.shape[1] + 2*psf_radius] 40 | size_z = [n, k, size_x[1]] 41 | size_k_full = [k, size_x[1]] 42 | 43 | #M is MtM, Mtb is Mtx, the matrix M is zero-padded in 2*psf_radius rows and cols 44 | M = np.pad(np.ones(b.shape, dtype = real_type), 45 | ((0,0), (psf_radius, psf_radius)), 46 | mode='constant', constant_values=0) 47 | Mtb = np.pad(b, ((0, 0), (psf_radius, psf_radius)),\ 48 | mode='constant', constant_values=0) 49 | 50 | 51 | """Penalty parameters, including the calculation of augmented Lagrange multipliers""" 52 | lambda_residual=np.float64(1.0) 53 | 
lambda_prior=np.float64(1.0) 54 | lambdas = [lambda_residual, lambda_prior] 55 | gamma_heuristic = 60 * lambda_prior * 1/np.amax(b) 56 | gammas_D = [gamma_heuristic / 5000, gamma_heuristic] 57 | gammas_Z = [gamma_heuristic / 500, gamma_heuristic] 58 | rho = gammas_D[1]/gammas_D[0] 59 | 60 | """Initialize variables for the d-step""" 61 | if known_d is None: 62 | varsize_D = [size_x, size_k_full] 63 | xi_D = [np.zeros(varsize_D[0], dtype = real_type), 64 | np.zeros(varsize_D[1], dtype = real_type)] 65 | 66 | xi_D_hat = [np.zeros(varsize_D[0], dtype = imaginary_type), 67 | np.zeros(varsize_D[1], dtype = imaginary_type)] 68 | 69 | u_D = [np.zeros(varsize_D[0], dtype = real_type), 70 | np.zeros(varsize_D[1], dtype = real_type)] 71 | 72 | #Lagrange multipliers 73 | d_D = [np.zeros(varsize_D[0], dtype = real_type), 74 | np.zeros(varsize_D[1], dtype = real_type)] 75 | 76 | v_D = [np.zeros(varsize_D[0], dtype = real_type), 77 | np.zeros(varsize_D[1], dtype = real_type)] 78 | 79 | d = np.random.normal(size=size_kernel) 80 | else: 81 | d = known_d 82 | 83 | #Initial the filters and its fft after being rolled to fit the frequency 84 | d = np.pad(d, ((0, 0), 85 | (0, size_x[1] - size_kernel[1])), 86 | mode='constant', constant_values=0) 87 | d = np.roll(d, -int(psf_radius), axis=1) 88 | d_hat = fft(d) 89 | 90 | """Initialize variables for the z-step""" 91 | varsize_Z = [size_x, size_z] 92 | xi_Z = [np.zeros(varsize_Z[0], dtype = real_type), 93 | np.zeros(varsize_Z[1], dtype = real_type)] 94 | 95 | xi_Z_hat = [np.zeros(varsize_Z[0], dtype = imaginary_type), 96 | np.zeros(varsize_Z[1], dtype = imaginary_type)] 97 | 98 | u_Z = [np.zeros(varsize_Z[0], dtype = real_type), 99 | np.zeros(varsize_Z[1], dtype = real_type)] 100 | 101 | #Lagrange multipliers 102 | d_Z = [np.zeros(varsize_Z[0], dtype = real_type), 103 | np.zeros(varsize_Z[1], dtype = real_type)] 104 | 105 | v_Z = [np.zeros(varsize_Z[0], dtype = real_type), 106 | np.zeros(varsize_Z[1], dtype = real_type)] 107 | 108 | 
#Initial the codes and its fft 109 | z = np.random.normal(size=size_z) 110 | z_hat = fft(z) 111 | 112 | 113 | """Initial objective function (usually very large)""" 114 | obj_val = obj_func(z_hat, d_hat, b, 115 | lambda_residual, lambda_prior, 116 | psf_radius, size_z, size_x) 117 | 118 | #Back-and-forth local iteration for d and z 119 | if known_d is None: 120 | max_it_d = 10 121 | else: 122 | max_it_d = 0 123 | 124 | max_it_z = 10 125 | 126 | obj_val_filter = obj_val 127 | obj_val_z = obj_val 128 | 129 | list_obj_val = np.zeros(max_it) 130 | list_obj_val_filter = np.zeros(max_it) 131 | list_obj_val_z = np.zeros(max_it) 132 | 133 | """Start the main algorithm""" 134 | for i in range(max_it): 135 | 136 | """D-STEP""" 137 | if known_d is None: 138 | #Precompute what is necessary for later 139 | [zhat_mat, zhat_inv_mat] = precompute_D_step(z_hat, size_z, rho) 140 | d_old = d 141 | 142 | for i_d in range(max_it_d): 143 | 144 | #Compute v = [Zd, d] 145 | d_hat_dot_z_hat = np.multiply(d_hat, z_hat) 146 | v_D[0] = np.real(ifft(np.sum(d_hat_dot_z_hat, axis=1).reshape(size_x))) 147 | v_D[1] = d 148 | 149 | #Compute proximal updates 150 | u = v_D[0] - d_D[0] 151 | theta = lambdas[0]/gammas_D[0] 152 | u_D[0] = np.divide((Mtb + 1.0/theta * u), (M + 1.0/theta * np.ones(size_x))) 153 | 154 | u = v_D[1] - d_D[1] 155 | u_D[1] = KernelConstraintProj(u, size_k_full, psf_radius) 156 | 157 | #Update Langrange multipliers 158 | d_D[0] = d_D[0] + (u_D[0] - v_D[0]) 159 | d_D[1] = d_D[1] + (u_D[1] - v_D[1]) 160 | 161 | #Compute new xi=u+d and transform to fft 162 | xi_D[0] = u_D[0] + d_D[0] 163 | xi_D[1] = u_D[1] + d_D[1] 164 | xi_D_hat[0] = fft(xi_D[0]) 165 | xi_D_hat[1] = fft(xi_D[1]) 166 | 167 | #Solve convolutional inverse 168 | d_hat = solve_conv_term_D(zhat_mat, zhat_inv_mat, xi_D_hat, rho, size_z) 169 | d = np.real(ifft(d_hat)) 170 | 171 | if (i_d == max_it_d - 1): 172 | obj_val = obj_func(z_hat, d_hat, b, 173 | lambda_residual, lambda_prior, 174 | psf_radius, size_z, size_x) 175 | 
176 | print('--> Obj %3.3f'% obj_val) 177 | 178 | 179 | obj_val_filter = obj_val 180 | 181 | #Debug progress 182 | d_diff = d - d_old 183 | d_comp = d 184 | 185 | if (i_d == max_it_d - 1): 186 | obj_val = obj_func(z_hat, d_hat, b, 187 | lambda_residual, lambda_prior, 188 | psf_radius, size_z, size_x) 189 | print('Iter D %d, Obj %3.3f, Diff %5.5f'% 190 | (i, obj_val, linalg.norm(d_diff) / linalg.norm(d_comp))) 191 | 192 | 193 | """Z-STEP""" 194 | #Precompute what is necessary for later 195 | [dhat_flat, dhatTdhat_flat] = precompute_Z_step(d_hat, size_x) 196 | dhatT_flat = np.ma.conjugate(dhat_flat.T) 197 | 198 | z_old = z 199 | 200 | for i_z in range(max_it_z): 201 | 202 | #Compute v = [Dz,z] 203 | d_hat_dot_z_hat = np.multiply(d_hat, z_hat) 204 | v_Z[0] = np.real(ifft(np.sum(d_hat_dot_z_hat, axis=1).reshape(size_x))) 205 | v_Z[1] = z 206 | 207 | #Compute proximal updates 208 | u = v_Z[0] - d_Z[0] 209 | theta = lambdas[0]/gammas_Z[0] 210 | u_Z[0] = np.divide((Mtb + 1.0/theta * u ), (M + 1.0/theta * np.ones(size_x))) 211 | 212 | u = v_Z[1] - d_Z[1] 213 | theta = lambdas[1]/gammas_Z[1] * np.ones(u.shape) 214 | u_Z[1] = np.multiply(np.maximum(0, 1 - np.divide(theta, np.abs(u))), u) 215 | 216 | #Update Langrange multipliers 217 | d_Z[0] = d_Z[0] + (u_Z[0] - v_Z[0]) 218 | d_Z[1] = d_Z[1] + (u_Z[1] - v_Z[1]) 219 | 220 | #Compute new xi=u+d and transform to fft 221 | xi_Z[0] = u_Z[0] + d_Z[0] 222 | xi_Z[1] = u_Z[1] + d_Z[1] 223 | 224 | xi_Z_hat[0] = fft(xi_Z[0]) 225 | xi_Z_hat[1] = fft(xi_Z[1]) 226 | 227 | #Solve convolutional inverse 228 | z_hat = solve_conv_term_Z(dhatT_flat, dhatTdhat_flat, xi_Z_hat, gammas_Z, size_z) 229 | z = np.real(ifft(z_hat)) 230 | 231 | if (i_z == max_it_z - 1): 232 | obj_val = obj_func(z_hat, d_hat, b, 233 | lambda_residual, lambda_prior, 234 | psf_radius, size_z, size_x) 235 | 236 | print('--> Obj %3.3f'% obj_val) 237 | 238 | obj_val_z = obj_val 239 | 240 | list_obj_val[i] = obj_val 241 | list_obj_val_filter[i] = obj_val_filter 242 | 
list_obj_val_z[i] = obj_val_z 243 | 244 | #Debug progress 245 | z_diff = z - z_old 246 | z_comp = z 247 | 248 | print('Iter Z %d, Obj %3.3f, Diff %5.5f'% 249 | (i, obj_val, linalg.norm(z_diff) / linalg.norm(z_comp))) 250 | 251 | #Termination 252 | if (linalg.norm(z_diff) / linalg.norm(z_comp) < tol and 253 | linalg.norm(d_diff) / linalg.norm(d_comp) < tol): 254 | break 255 | 256 | """Final estimate""" 257 | z_res = z 258 | 259 | d_res = d 260 | d_res = np.roll(d_res, psf_radius, axis=1) 261 | d_res = d_res[:, 0:psf_radius*2+1] 262 | 263 | fft_dot = np.multiply(d_hat, z_hat) 264 | Dz = np.real(ifft(np.sum(fft_dot, axis=1).reshape(size_x))) 265 | 266 | obj_val = obj_func(z_hat, d_hat, b, 267 | lambda_residual, lambda_prior, 268 | psf_radius, size_z, size_x) 269 | print('Final objective function %f'% (obj_val)) 270 | 271 | reconstr_err = reconstruction_err(z_hat, d_hat, b, psf_radius, size_x) 272 | print('Final reconstruction error %f'% reconstr_err) 273 | 274 | return [d_res, z_res, Dz, list_obj_val, list_obj_val_filter, list_obj_val_z, reconstr_err] 275 | 276 | 277 | def KernelConstraintProj(u, size_k_full, psf_radius): 278 | """Computes the proximal operator for kernel by projection""" 279 | 280 | #Get support 281 | u_proj = u 282 | u_proj = np.roll(u_proj, psf_radius, axis=1) 283 | u_proj = u_proj[:, 0:psf_radius*2+1] 284 | 285 | #Normalize 286 | u_sum = np.sum(np.power(u_proj, 2), axis=1) 287 | u_norm = np.tile(u_sum, [u_proj.shape[1], 1]).transpose(1, 0) 288 | u_proj[u_norm >= 1] = u_proj[u_norm >= 1] / np.sqrt(u_norm[u_norm >= 1]) 289 | 290 | #Now shift back and pad again 291 | u_proj = np.pad(u_proj, ((0,0), 292 | (0, size_k_full[1] - (2*psf_radius+1))), 293 | mode='constant', constant_values=0) 294 | 295 | u_proj = np.roll(u_proj, -psf_radius, axis=1) 296 | 297 | return u_proj 298 | 299 | def precompute_D_step(z_hat, size_z, rho): 300 | """Computes to cache the values of Z^.T and (Z^.T*Z^ + rho*I)^-1 as in algorithm""" 301 | 302 | n = size_z[0] 303 | k = 
size_z[1] 304 | 305 | zhat_mat = np.transpose(z_hat, [2, 0, 1]) 306 | zhat_inv_mat = np.zeros((zhat_mat.shape[0], k, k), dtype=imaginary_type) 307 | inv_rho_z_hat_z_hat_t = np.zeros((zhat_mat.shape[0], n, n), dtype=imaginary_type) 308 | z_hat_mat_t = np.transpose(np.ma.conjugate(zhat_mat), [0, 2, 1]) 309 | 310 | #Compute Z_hat * Z_hat^T for each pixel 311 | z_hat_z_hat_t = np.einsum('knm,kmj->knj',zhat_mat, z_hat_mat_t) 312 | 313 | for i in range(zhat_mat.shape[0]): 314 | z_hat_z_hat_t_plus_rho = z_hat_z_hat_t[i] 315 | z_hat_z_hat_t_plus_rho.flat[::n + 1] += rho 316 | inv_rho_z_hat_z_hat_t[i] = linalg.pinv(z_hat_z_hat_t_plus_rho) 317 | 318 | zhat_inv_mat = 1.0/rho * (np.eye(k) - 319 | np.einsum('knm,kmj->knj', 320 | np.einsum('knm,kmj->knj', 321 | z_hat_mat_t, 322 | inv_rho_z_hat_z_hat_t), 323 | zhat_mat)) 324 | 325 | print ('Done precomputing for D') 326 | return [zhat_mat, zhat_inv_mat] 327 | 328 | 329 | def precompute_Z_step(dhat, size_x): 330 | """Computes to cache the values of D^.T and D^.T*D^ as in algorithm""" 331 | 332 | dhat_flat = dhat.T 333 | dhatTdhat_flat = np.sum(np.multiply(np.ma.conjugate(dhat_flat), dhat_flat), axis=1) 334 | print ('Done precomputing for Z') 335 | return [dhat_flat, dhatTdhat_flat] 336 | 337 | 338 | def solve_conv_term_D(zhat_mat, zhat_inv_mat, xi_hat, rho, size_z): 339 | """Solves argmin(||Zd - x1||_2^2 + rho * ||d - x2||_2^2""" 340 | 341 | k = size_z[1] 342 | sx = size_z[2] 343 | 344 | #Reshape to array per frequency 345 | xi_hat_0_flat = np.expand_dims(xi_hat[0].T, axis=2) 346 | xi_hat_1_flat = np.expand_dims(xi_hat[1].T, axis=2) 347 | 348 | x = np.zeros((zhat_mat.shape[0], k), dtype=imaginary_type) 349 | z_hat_mat_t = np.ma.conjugate(zhat_mat.transpose(0,2,1)) 350 | x = np.einsum("ijk, ikl -> ijl", zhat_inv_mat, 351 | np.einsum("ijk, ikl -> ijl", 352 | z_hat_mat_t, 353 | xi_hat_0_flat) + 354 | rho * xi_hat_1_flat) \ 355 | .reshape(sx, k) 356 | 357 | #Reshape to get back the new D^ 358 | d_hat = x.T 359 | 360 | return d_hat 361 
| 362 | 363 | def solve_conv_term_Z(dhatT, dhatTdhat, xi_hat, gammas, size_z ): 364 | """Solves argmin(||Dz - x1||_2^2 + rho * ||z - x2||_2^2""" 365 | 366 | sx = size_z[2] 367 | rho = gammas[1]/gammas[0] 368 | 369 | #Compute b 370 | xi_hat_0_rep = np.expand_dims(xi_hat[0], axis = 1) 371 | xi_hat_1_rep = xi_hat[1] 372 | 373 | b = np.multiply(dhatT, xi_hat_0_rep) + rho * xi_hat_1_rep 374 | 375 | scInverse = np.divide(np.ones((1, sx)), rho*np.ones((1, sx)) + dhatTdhat.T) 376 | 377 | dhatT_dot_b = np.multiply(np.ma.conjugate(dhatT), b) 378 | dhatTb_rep = np.expand_dims(np.sum(dhatT_dot_b, axis=1), axis=1) 379 | x = 1.0/rho * (b - np.multiply(np.multiply(scInverse, dhatT), dhatTb_rep)) 380 | 381 | z_hat = x 382 | return z_hat 383 | 384 | 385 | def obj_func(z_hat, d_hat, b, lambda_residual, lambdas, \ 386 | psf_radius, size_z, size_x): 387 | """Computes the objective function including the data-fitting and regularization terms""" 388 | #Data-fitting term 389 | d_hat_dot_z_hat = np.multiply(d_hat, z_hat) 390 | Dz = np.real(ifft(np.sum(d_hat_dot_z_hat, axis=1).reshape(size_x))) 391 | 392 | f_z = lambda_residual * 1.0/2.0 * \ 393 | np.power(linalg.norm(np.reshape(Dz[:, psf_radius:(Dz.shape[1] - psf_radius)] - b, \ 394 | -1, 1)), 2) 395 | #Regularizer 396 | z = ifft(z_hat) 397 | g_z = lambdas * np.sum(np.abs(z)) 398 | 399 | f_val = f_z + g_z 400 | return f_val 401 | 402 | 403 | def reconstruction_err(z_hat, d_hat, b, psf_radius, size_x): 404 | """Computes the reconstruction error from the data-fitting term""" 405 | 406 | d_hat_dot_z_hat = np.multiply(d_hat, z_hat) 407 | Dz = np.real(ifft(np.sum(d_hat_dot_z_hat, axis=1).reshape(size_x))) 408 | 409 | err = 1.0/2.0 * \ 410 | np.power(linalg.norm(np.reshape(Dz[:, psf_radius:(Dz.shape[1] - psf_radius)] - b, \ 411 | -1, 1)), 2) 412 | 413 | return err -------------------------------------------------------------------------------- /Code_image/CSC_image.py: 
-------------------------------------------------------------------------------- 1 | """ 2 | - This code implements a solver for the paper "Fast and Flexible Convolutional Sparse Coding". 3 | - The goal of this solver is to find the common filters (or kernel), the codes for each image 4 | in the dataset and a reconstruction for the dataset. 5 | - The common filters and the codes for each image is denoted as d and z respectively 6 | - We denote the step to solve the filter is d-step and the codes z-step 7 | """ 8 | 9 | import numpy as np 10 | from scipy import linalg 11 | from scipy.fftpack import fft2, ifft2 12 | 13 | real_type = 'float64' 14 | imaginary_type = 'complex128' 15 | 16 | def learn_conv_sparse_coder(b, size_kernel, max_it, tol, beta=np.float64(1.0)): 17 | """ 18 | Main function to solve the convolutional sparse coding 19 | Parameters for this function 20 | - b : the image dataset with size (num_images, height, width) 21 | - size_kernel : the size of each kernel (num_kernels, height, width) 22 | - beta : the trade-off between sparsity and reconstruction error (as mentioned in the paper) 23 | - max_it : the maximum iterations of the outer loop 24 | - tol : the minimal difference in filters and codes after each iteration to continue 25 | 26 | Important variables used in the code: 27 | - u_D, u_Z : pair of proximal values for d-step and z-step 28 | - d_D, d_Z : pair of Lagrange multipliers in the ADMM algo for d-step and z-step 29 | - v_D, v_Z : pair of initial value pairs (Zd, d) for d-step and (Dz, z) for z-step 30 | """ 31 | 32 | psf_s = size_kernel[1] 33 | k = size_kernel[0] 34 | n = b.shape[0] 35 | 36 | psf_radius = int(np.floor(psf_s/2)) 37 | 38 | size_x = [n, b.shape[1] + 2*psf_radius, b.shape[2] + 2*psf_radius] 39 | size_z = [n, k, size_x[1], size_x[2]] 40 | size_k_full = [k, size_x[1], size_x[2]] 41 | 42 | #M is MtM, Mtb is Mtx, the matrix M is zero-padded in 2*psf_radius rows and cols 43 | M = np.pad(np.ones(b.shape, dtype = real_type), 44 | 
((0,0), (psf_radius, psf_radius), (psf_radius, psf_radius)), 45 | mode='constant', constant_values=0) 46 | Mtb = np.pad(b, ((0, 0), (psf_radius, psf_radius), (psf_radius, psf_radius)),\ 47 | mode='constant', constant_values=0) 48 | 49 | 50 | """Penalty parameters, including the calculation of augmented Lagrange multipliers""" 51 | lambda_residual=np.float64(1.0) 52 | lambda_prior=np.float64(1.0) 53 | lambdas = [lambda_residual, lambda_prior] 54 | gamma_heuristic = 60 * lambda_prior * 1/np.amax(b) 55 | gammas_D = [gamma_heuristic / 5000, gamma_heuristic] 56 | gammas_Z = [gamma_heuristic / 500, gamma_heuristic] 57 | rho = gammas_D[1]/gammas_D[0] 58 | 59 | """Initialize variables for the d-step""" 60 | varsize_D = [size_x, size_k_full] 61 | xi_D = [np.zeros(varsize_D[0], dtype = real_type), 62 | np.zeros(varsize_D[1], dtype = real_type)] 63 | 64 | xi_D_hat = [np.zeros(varsize_D[0], dtype = imaginary_type), 65 | np.zeros(varsize_D[1], dtype = imaginary_type)] 66 | 67 | u_D = [np.zeros(varsize_D[0], dtype = real_type), 68 | np.zeros(varsize_D[1], dtype = real_type)] 69 | 70 | #Lagrange multipliers 71 | d_D = [np.zeros(varsize_D[0], dtype = real_type), 72 | np.zeros(varsize_D[1], dtype = real_type)] 73 | 74 | v_D = [np.zeros(varsize_D[0], dtype = real_type), 75 | np.zeros(varsize_D[1], dtype = real_type)] 76 | 77 | #Initial the filters and its fft after being rolled to fit the frequency 78 | d = np.random.normal(size=size_kernel) 79 | d = np.pad(d, ((0, 0), 80 | (0, size_x[1] - size_kernel[1]), 81 | (0, size_x[2] - size_kernel[2])), 82 | mode='constant', constant_values=0) 83 | d = np.roll(d, -int(psf_radius), axis=1) 84 | d = np.roll(d, -int(psf_radius), axis=2) 85 | d_hat = fft2(d) 86 | 87 | """Initialize variables for the z-step""" 88 | varsize_Z = [size_x, size_z] 89 | xi_Z = [np.zeros(varsize_Z[0], dtype = real_type), 90 | np.zeros(varsize_Z[1], dtype = real_type)] 91 | 92 | xi_Z_hat = [np.zeros(varsize_Z[0], dtype = imaginary_type), 93 | np.zeros(varsize_Z[1], 
dtype = imaginary_type)] 94 | 95 | u_Z = [np.zeros(varsize_Z[0], dtype = real_type), 96 | np.zeros(varsize_Z[1], dtype = real_type)] 97 | 98 | #Lagrange multipliers 99 | d_Z = [np.zeros(varsize_Z[0], dtype = real_type), 100 | np.zeros(varsize_Z[1], dtype = real_type)] 101 | 102 | v_Z = [np.zeros(varsize_Z[0], dtype = real_type), 103 | np.zeros(varsize_Z[1], dtype = real_type)] 104 | 105 | #Initial the codes and its fft 106 | z = np.random.normal(size=size_z) 107 | z_hat = fft2(z) 108 | 109 | 110 | """Initial objective function (usually very large)""" 111 | obj_val = obj_func(z_hat, d_hat, b, 112 | lambda_residual, lambda_prior, 113 | psf_radius, size_z, size_x) 114 | 115 | #Back-and-forth local iteration for d and z 116 | max_it_d = 10 117 | max_it_z = 10 118 | 119 | obj_val_filter = obj_val 120 | obj_val_z = obj_val 121 | 122 | list_obj_val = np.zeros(max_it) 123 | list_obj_val_filter = np.zeros(max_it) 124 | list_obj_val_z = np.zeros(max_it) 125 | 126 | """Start the main algorithm""" 127 | for i in range(max_it): 128 | 129 | """D-STEP""" 130 | #Precompute what is necessary for later 131 | [zhat_mat, zhat_inv_mat] = precompute_D_step(z_hat, size_z, rho) 132 | 133 | obj_val_min = min(obj_val_filter, obj_val_z) 134 | 135 | d_old = d 136 | d_hat_old = d_hat 137 | 138 | #Main loop to update filters 139 | for i_d in range(max_it_d): 140 | 141 | #Compute v = [Zd, d] 142 | d_hat_dot_z_hat = np.multiply(d_hat, z_hat) 143 | v_D[0] = np.real(ifft2(np.sum(d_hat_dot_z_hat, axis=1).reshape(size_x))) 144 | v_D[1] = d 145 | 146 | #Compute proximal updates 147 | u = v_D[0] - d_D[0] 148 | theta = lambdas[0]/gammas_D[0] 149 | u_D[0] = np.divide((Mtb + 1.0/theta * u), (M + 1.0/theta * np.ones(size_x))) 150 | 151 | u = v_D[1] - d_D[1] 152 | u_D[1] = KernelConstraintProj(u, size_k_full, psf_radius) 153 | 154 | #Update Langrange multipliers 155 | d_D[0] = d_D[0] + (u_D[0] - v_D[0]) 156 | d_D[1] = d_D[1] + (u_D[1] - v_D[1]) 157 | 158 | #Compute new xi=u+d and transform to fft 159 | 
xi_D[0] = u_D[0] + d_D[0] 160 | xi_D[1] = u_D[1] + d_D[1] 161 | xi_D_hat[0] = fft2(xi_D[0]) 162 | xi_D_hat[1] = fft2(xi_D[1]) 163 | 164 | #Solve convolutional inverse 165 | d_hat = solve_conv_term_D(zhat_mat, zhat_inv_mat, xi_D_hat, rho, size_z) 166 | d = np.real(ifft2(d_hat)) 167 | 168 | if (i_d == max_it_d - 1): 169 | obj_val = obj_func(z_hat, d_hat, b, 170 | lambda_residual, lambda_prior, 171 | psf_radius, size_z, size_x) 172 | 173 | print('--> Obj %3.3f'% obj_val) 174 | 175 | obj_val_filter = obj_val 176 | 177 | #Debug progress 178 | d_diff = d - d_old 179 | d_comp = d 180 | 181 | obj_val = obj_func(z_hat, d_hat, b, 182 | lambda_residual, lambda_prior, 183 | psf_radius, size_z, size_x) 184 | print('Iter D %d, Obj %3.3f, Diff %5.5f'% 185 | (i, obj_val, linalg.norm(d_diff) / linalg.norm(d_comp))) 186 | 187 | 188 | """Z-STEP""" 189 | #Precompute what is necessary for later 190 | [dhat_flat, dhatTdhat_flat] = precompute_Z_step(d_hat, size_x) 191 | dhatT_flat = np.ma.conjugate(dhat_flat.T) 192 | 193 | z_old = z 194 | z_hat_old = z_hat 195 | 196 | for i_z in range(max_it_z): 197 | 198 | #Compute v = [Dz,z] 199 | d_hat_dot_z_hat = np.multiply(d_hat, z_hat) 200 | v_Z[0] = np.real(ifft2(np.sum(d_hat_dot_z_hat, axis=1).reshape(size_x))) 201 | v_Z[1] = z 202 | 203 | #Compute proximal updates 204 | u = v_Z[0] - d_Z[0] 205 | theta = lambdas[0]/gammas_Z[0] 206 | u_Z[0] = np.divide((Mtb + 1.0/theta * u ), (M + 1.0/theta * np.ones(size_x))) 207 | 208 | u = v_Z[1] - d_Z[1] 209 | theta = lambdas[1]/gammas_Z[1] * np.ones(u.shape) 210 | u_Z[1] = np.multiply(np.maximum(0, 1 - np.divide(theta*beta, np.abs(u))), u) 211 | 212 | #Update Langrange multipliers 213 | d_Z[0] = d_Z[0] + (u_Z[0] - v_Z[0]) 214 | d_Z[1] = d_Z[1] + (u_Z[1] - v_Z[1]) 215 | 216 | #Compute new xi=u+d and transform to fft 217 | xi_Z[0] = u_Z[0] + d_Z[0] 218 | xi_Z[1] = u_Z[1] + d_Z[1] 219 | 220 | xi_Z_hat[0] = fft2(xi_Z[0]) 221 | xi_Z_hat[1] = fft2(xi_Z[1]) 222 | 223 | #Solve convolutional inverse 224 | z_hat = 
solve_conv_term_Z(dhatT_flat, dhatTdhat_flat, xi_Z_hat, gammas_Z, size_z) 225 | z = np.real(ifft2(z_hat)) 226 | 227 | if (i_z == max_it_z - 1): 228 | obj_val = obj_func(z_hat, d_hat, b, 229 | lambda_residual, lambda_prior, 230 | psf_radius, size_z, size_x) 231 | 232 | print('--> Obj %3.3f'% obj_val) 233 | 234 | obj_val_z = obj_val 235 | 236 | if (obj_val_min <= obj_val_filter and obj_val_min <= obj_val_z): 237 | z_hat = z_hat_old 238 | z = np.real(ifft2(z_hat)) 239 | 240 | d_hat = d_hat_old 241 | d = np.real(ifft2(d_hat)) 242 | 243 | obj_val = obj_func(z_hat, d_hat, b, 244 | lambda_residual, lambda_prior, 245 | psf_radius, size_z, size_x) 246 | break 247 | 248 | list_obj_val[i] = obj_val 249 | list_obj_val_filter[i] = obj_val_filter 250 | list_obj_val_z[i] = obj_val_z 251 | 252 | #Debug progress 253 | z_diff = z - z_old 254 | z_comp = z 255 | print('Iter Z %d, Obj %3.3f, Diff %5.5f'% 256 | (i, obj_val, linalg.norm(z_diff) / linalg.norm(z_comp))) 257 | 258 | #Termination 259 | if (linalg.norm(z_diff) / linalg.norm(z_comp) < tol and 260 | linalg.norm(d_diff) / linalg.norm(d_comp) < tol): 261 | break 262 | 263 | """Final estimate""" 264 | z_res = z 265 | 266 | d_res = d 267 | d_res = np.roll(d_res, psf_radius, axis=1) 268 | d_res = np.roll(d_res, psf_radius, axis=2) 269 | d_res = d_res[:, 0:psf_radius*2+1, 0:psf_radius*2+1] 270 | 271 | fft_dot = np.multiply(d_hat, z_hat) 272 | Dz = np.real(ifft2(np.sum(fft_dot, axis=1).reshape(size_x))) 273 | 274 | obj_val = obj_func(z_hat, d_hat, b, 275 | lambda_residual, lambda_prior, 276 | psf_radius, size_z, size_x) 277 | print('Final objective function %f'% (obj_val)) 278 | 279 | reconstr_err = reconstruction_err(z_hat, d_hat, b, psf_radius, size_x) 280 | print('Final reconstruction error %f'% reconstr_err) 281 | 282 | return [d_res, z_res, Dz, list_obj_val, list_obj_val_filter, list_obj_val_z, reconstr_err] 283 | 284 | 285 | def KernelConstraintProj(u, size_k_full, psf_radius): 286 | """Computes the proximal operator for kernel 
by projection""" 287 | 288 | #Get support 289 | u_proj = u 290 | u_proj = np.roll(u_proj, psf_radius, axis=1) 291 | u_proj = np.roll(u_proj, psf_radius, axis=2) 292 | u_proj = u_proj[:, 0:psf_radius*2+1, 0:psf_radius*2+1] 293 | 294 | #Normalize 295 | u_sum = np.sum(np.sum(np.power(u_proj, 2), axis=1), axis=1) 296 | u_norm = np.tile(u_sum, [u_proj.shape[1], u_proj.shape[2], 1]).transpose(2,0,1) 297 | u_proj[u_norm >= 1] = u_proj[u_norm >= 1] / np.sqrt(u_norm[u_norm >= 1]) 298 | 299 | #Now shift back and pad again 300 | u_proj = np.pad(u_proj, ((0,0), 301 | (0, size_k_full[1] - (2*psf_radius+1)), 302 | (0, size_k_full[2] - (2*psf_radius+1))), 303 | mode='constant', constant_values=0) 304 | 305 | u_proj = np.roll(u_proj, -psf_radius, axis=1) 306 | u_proj = np.roll(u_proj, -psf_radius, axis=2) 307 | 308 | return u_proj 309 | 310 | 311 | def precompute_D_step(z_hat, size_z, rho): 312 | """Computes to cache the values of Z^.T and (Z^.T*Z^ + rho*I)^-1 as in algorithm""" 313 | 314 | n = size_z[0] 315 | k = size_z[1] 316 | 317 | zhat_mat = np.transpose(z_hat.transpose(0,1,3,2).reshape(n, k, -1), [2, 0, 1]) 318 | zhat_inv_mat = np.zeros((zhat_mat.shape[0], k, k), dtype=imaginary_type) 319 | inv_rho_z_hat_z_hat_t = np.zeros((zhat_mat.shape[0], n, n), dtype=imaginary_type) 320 | z_hat_mat_t = np.transpose(np.ma.conjugate(zhat_mat), [0, 2, 1]) 321 | 322 | #Compute Z_hat * Z_hat^T for each pixel 323 | z_hat_z_hat_t = np.einsum('knm,kmj->knj',zhat_mat, z_hat_mat_t) 324 | 325 | for i in range(zhat_mat.shape[0]): 326 | z_hat_z_hat_t_plus_rho = z_hat_z_hat_t[i] 327 | z_hat_z_hat_t_plus_rho.flat[::n + 1] += rho 328 | inv_rho_z_hat_z_hat_t[i] = linalg.pinv(z_hat_z_hat_t_plus_rho) 329 | 330 | zhat_inv_mat = 1.0/rho * (np.eye(k) - 331 | np.einsum('knm,kmj->knj', 332 | np.einsum('knm,kmj->knj', 333 | z_hat_mat_t, 334 | inv_rho_z_hat_z_hat_t), 335 | zhat_mat)) 336 | 337 | print ('Done precomputing for D') 338 | return [zhat_mat, zhat_inv_mat] 339 | 340 | 341 | def precompute_Z_step(dhat, 
size_x): 342 | """Computes to cache the values of D^.T and D^.T*D^ as in algorithm""" 343 | 344 | dhat_flat = dhat.transpose(0,2,1).reshape((-1, size_x[1] * size_x[2])).T 345 | dhatTdhat_flat = np.sum(np.multiply(np.ma.conjugate(dhat_flat), dhat_flat), axis=1) 346 | print ('Done precomputing for Z') 347 | return [dhat_flat, dhatTdhat_flat] 348 | 349 | 350 | def solve_conv_term_D(zhat_mat, zhat_inv_mat, xi_hat, rho, size_z): 351 | """Solves argmin(||Zd - x1||_2^2 + rho * ||d - x2||_2^2""" 352 | 353 | n = size_z[0] 354 | k = size_z[1] 355 | sy = size_z[2] 356 | sx = size_z[3] 357 | 358 | xi_hat_0_flat = np.expand_dims(np.reshape(xi_hat[0].transpose(0,2,1), 359 | (n, sx * sy)).T, 360 | axis=2) 361 | xi_hat_1_flat = np.expand_dims(np.reshape(xi_hat[1].transpose(0,2,1), 362 | (k, sx * sy)).T, 363 | axis=2) 364 | 365 | x = np.zeros((zhat_mat.shape[0], k), dtype=imaginary_type) 366 | z_hat_mat_t = np.ma.conjugate(zhat_mat.transpose(0,2,1)) 367 | x = np.einsum("ijk, ikl -> ijl", zhat_inv_mat, 368 | np.einsum("ijk, ikl -> ijl", 369 | z_hat_mat_t, 370 | xi_hat_0_flat) + 371 | rho * xi_hat_1_flat) \ 372 | .reshape(sx * sy, k) 373 | 374 | #Reshape to get back the new D^ 375 | d_hat = np.reshape(x.T, (k, sx, sy)).transpose(0,2,1) 376 | 377 | return d_hat 378 | 379 | def solve_conv_term_Z(dhatT, dhatTdhat, xi_hat, gammas, size_z): 380 | """Solves argmin(||Dz - x1||_2^2 + rho * ||z - x2||_2^2""" 381 | 382 | n = size_z[0] 383 | k = size_z[1] 384 | sy = size_z[2] 385 | sx = size_z[3] 386 | 387 | rho = gammas[1]/gammas[0] 388 | 389 | xi_hat_0_rep = xi_hat[0].transpose([0,2,1]).reshape(n, 1, sy*sx) 390 | xi_hat_1_rep = xi_hat[1].transpose([0,1,3,2]).reshape(n, k, sy*sx) 391 | 392 | b = np.multiply(dhatT, xi_hat_0_rep) + rho * xi_hat_1_rep 393 | 394 | scInverse = np.divide(np.ones((1, sx*sy)), rho*np.ones((1, sx*sy)) + dhatTdhat.T) 395 | 396 | dhatT_dot_b = np.multiply(np.ma.conjugate(dhatT), b) 397 | dhatTb_rep = np.expand_dims(np.sum(dhatT_dot_b, axis=1), axis=1) 398 | x = 1.0/rho * 
(b - np.multiply(np.multiply(scInverse, dhatT), dhatTb_rep)) 399 | 400 | #Final transpose gives z_hat 401 | temp_size = np.zeros(4) 402 | temp_size[0] = size_z[0] 403 | temp_size[1] = size_z[1] 404 | temp_size[3] = size_z[2] 405 | temp_size[2] = size_z[3] 406 | z_hat = x.reshape(temp_size).transpose(0,1,3,2) 407 | 408 | return z_hat 409 | 410 | 411 | def obj_func(z_hat, d_hat, b, lambda_residual, lambdas, \ 412 | psf_radius, size_z, size_x): 413 | """Computes the objective function including the data-fitting and regularization terms""" 414 | #Data-fitting term 415 | d_hat_dot_z_hat = np.multiply(d_hat, z_hat) 416 | Dz = np.real(ifft2(np.sum(d_hat_dot_z_hat, axis=1).reshape(size_x))) 417 | 418 | f_z = lambda_residual * 1.0/2.0 * \ 419 | np.power(linalg.norm(np.reshape(Dz[:, psf_radius:(Dz.shape[1] - psf_radius), \ 420 | psf_radius:(Dz.shape[2] - psf_radius)] - b, \ 421 | -1, 1)), 2) 422 | #Regularizer 423 | z = ifft2(z_hat) 424 | g_z = lambdas * np.sum(np.abs(z)) 425 | 426 | f_val = f_z + g_z 427 | return f_val 428 | 429 | 430 | def reconstruction_err(z_hat, d_hat, b, psf_radius, size_x): 431 | """Computes the reconstruction error from the data-fitting term""" 432 | 433 | d_hat_dot_z_hat = np.multiply(d_hat, z_hat) 434 | Dz = np.real(ifft2(np.sum(d_hat_dot_z_hat, axis=1).reshape(size_x))) 435 | 436 | err = 1.0/2.0 * \ 437 | np.power(linalg.norm(np.reshape(Dz[:, psf_radius:(Dz.shape[1] - psf_radius), \ 438 | psf_radius:(Dz.shape[2] - psf_radius)] - b, \ 439 | -1, 1)), 2) 440 | return err --------------------------------------------------------------------------------