├── Data
    └── IMAGES_Vanhateren.mat
├── LICENSE
├── README.md
├── figures
    └── sparse_coding.png
├── sparse_coding_torch_Demo.ipynb
├── sparsify_PyTorch.py
└── utility.py


/Data/IMAGES_Vanhateren.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yubeic/Sparse-Coding/92297d5c76ebf83763dc9dad61b5f08ac96d8f11/Data/IMAGES_Vanhateren.mat


--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2021 Yubei Chen

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Sparse Coding PyTorch

This code implements sparse coding in PyTorch with the positive-only option. For the positive-only option, I only constrain the sparse coefficients to be non-negative. This choice is related to, but different from, non-negative sparse coding or non-negative matrix factorization. The optimization solvers used in this code are ISTA and FISTA. To demo the code, whitened natural images are adapted from: http://www.rctn.org/bruno/sparsenet/

### A Sample of the Learned Dictionary

[Figure: figures/sparse_coding.png]

In order to use this code, Python 3 and PyTorch 0.4 or above are required. Please follow the steps in the following Jupyter notebook:

```bash
sparse_coding_torch_Demo.ipynb
```

The following are some useful references:

### Sparse Coding
```bibtex
@article{olshausen1996emergence,
  title={Emergence of simple-cell receptive field properties by learning a sparse code for natural images},
  author={Olshausen, Bruno A and Field, David J},
  journal={Nature},
  volume={381},
  number={6583},
  pages={607},
  year={1996},
  publisher={Nature Publishing Group}
}
```
```bibtex
@inproceedings{olshausen2013highly,
  title={Highly overcomplete sparse coding},
  author={Olshausen, Bruno A},
  booktitle={Human Vision and Electronic Imaging XVIII},
  volume={8651},
  pages={86510S},
  year={2013},
  organization={International Society for Optics and Photonics}
}
```

### FISTA Algorithm
```bibtex
@article{beck2009fast,
  title={A fast iterative shrinkage-thresholding algorithm for linear inverse problems},
  author={Beck, Amir and Teboulle, Marc},
  journal={SIAM journal on imaging sciences},
  volume={2},
  number={1},
  pages={183--202},
  year={2009},
  publisher={SIAM}
}
```

### Positive-Only Sparse Coding
```bibtex
@inproceedings{hoyer2002non,
  title={Non-negative sparse coding},
  author={Hoyer, Patrik O},
  booktitle={Proceedings of the 12th IEEE Workshop on Neural Networks for Signal Processing},
  pages={557--565},
  year={2002},
  organization={IEEE}
}
```

### The Sparse Manifold Transform
```bibtex
@inproceedings{DBLP:conf/nips/ChenPO18,
  author    = {Yubei Chen and
               Dylan M. Paiton and
               Bruno A. Olshausen},
  title     = {The Sparse Manifold Transform},
  booktitle = {Advances in Neural Information Processing Systems 31: Annual Conference
               on Neural Information Processing Systems 2018, NeurIPS 2018, December
               3-8, 2018, Montr{\'{e}}al, Canada},
  pages     = {10534--10545},
  year      = {2018}
}
```


--------------------------------------------------------------------------------
/figures/sparse_coding.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yubeic/Sparse-Coding/92297d5c76ebf83763dc9dad61b5f08ac96d8f11/figures/sparse_coding.png


--------------------------------------------------------------------------------
/sparsify_PyTorch.py:
--------------------------------------------------------------------------------
# Yubei Chen, Sparse Manifold Transform Lib Ver 0.1
"""
This file contains multiple methods to sparsify the coefficients
"""
import time
import numpy as np
import numpy.linalg as la
#import utility
from IPython.display import clear_output # This is to clean the output info to make the process cleaner

import torch
import torch.nn.functional as F
# cupy is imported lazily in _defaultStepSize, since it is only needed when useMAGMA is False


def quadraticBasisUpdate(basis, Res, ahat, lowestActivation, HessianDiag, stepSize = 0.001,constraint = 'L2', Noneg = False):
    """
    This function updates the basis functions based on the Hessian matrix of the activation.
    It's very similar to Newton's method. But since the Hessian matrix of the activation function is often ill-conditioned, we take the pseudo-inverse.

    Note: currently, we can just use the inverse of the activation energy.
    A better idea for this method would be calculating the local Lipschitz constant for each basis function.
    The stepSize should be smaller than 1.0 * min(activation) to be stable.
    """
    dBasis = stepSize*torch.mm(Res, ahat.t())/ahat.size(1)
    dBasis = dBasis.div_(HessianDiag+lowestActivation)
    basis = basis.add_(dBasis)
    if Noneg:
        basis = basis.clamp(min = 0.)
    if constraint == 'L2':
        basis = basis.div_(basis.norm(2,0))
    return basis

def _defaultStepSize(basis, useMAGMA=True):
    # Step size 1/L, where L is the largest eigenvalue of basis * basis^T.
    # MAGMA uses CPU-GPU hybrid method to solve SVD problems, which is great for single task. When running multiple jobs, this flag should be turned off to leave the svd computation on only GPU.
    gram = torch.mm(basis,basis.t())
    if useMAGMA:
        if hasattr(torch, 'linalg') and hasattr(torch.linalg, 'eigvalsh'):
            # torch.symeig is deprecated in newer PyTorch releases
            L = torch.max(torch.linalg.eigvalsh(gram))
        else:
            L = torch.max(torch.symeig(gram,eigenvectors=False)[0])
        return 1./L
    import cupy as cp # optional dependency, only needed on this code path
    eta = 1./cp.linalg.eigvalsh(cp.asarray(gram.cpu().numpy())).max().get().reshape(1)
    return torch.from_numpy(eta.astype('float32')).to(basis.device)

def ISTA_PN(I,basis,lambd,num_iter,eta=None, useMAGMA=True):
    # This is a positive-negative PyTorch-Ver ISTA solver
    batch_size=I.size(1)
    M = basis.size(1)
    if eta is None:
        eta = _defaultStepSize(basis, useMAGMA)

    # Allocate on the same device and with the same dtype as the input
    Res = I.new_zeros(I.size())
    ahat = I.new_zeros((M,batch_size))

    for t in range(num_iter):
        ahat = ahat.add(eta * basis.t().mm(Res))
        # Soft thresholding: shrink the magnitude, keep the sign
        ahat_sign = torch.sign(ahat)
        ahat.abs_()
        ahat.sub_(eta * lambd).clamp_(min = 0.)
        ahat.mul_(ahat_sign)
        Res = I - torch.mm(basis,ahat)
    return ahat, Res

def FISTA(I,basis,lambd,num_iter,eta=None, useMAGMA=True):
    # This is a positive-only PyTorch-Ver FISTA solver
    batch_size=I.size(1)
    M = basis.size(1)
    if eta is None:
        eta = _defaultStepSize(basis, useMAGMA)

    tk_n = 1.
    tk = 1.
    Res = I.new_zeros(I.size())
    ahat = I.new_zeros((M,batch_size))
    ahat_y = I.new_zeros((M,batch_size))

    for t in range(num_iter):
        tk = tk_n
        tk_n = (1+np.sqrt(1+4*tk**2))/2
        ahat_pre = ahat
        Res = I - torch.mm(basis,ahat_y)
        ahat_y = ahat_y.add(eta * basis.t().mm(Res))
        ahat = ahat_y.sub(eta * lambd).clamp(min = 0.)
        # Momentum step on the auxiliary variable
        ahat_y = ahat.add(ahat.sub(ahat_pre).mul((tk-1)/(tk_n)))
    Res = I - torch.mm(basis,ahat)
    return ahat, Res

def ISTA(I,basis,lambd,num_iter,eta=None, useMAGMA=True):
    # This is a positive-only PyTorch-Ver ISTA solver
    batch_size=I.size(1)
    M = basis.size(1)
    if eta is None:
        eta = _defaultStepSize(basis, useMAGMA)

    Res = I.new_zeros(I.size())
    ahat = I.new_zeros((M,batch_size))

    for t in range(num_iter):
        ahat = ahat.add(eta * basis.t().mm(Res))
        ahat = ahat.sub(eta * lambd).clamp(min = 0.)
        Res = I - torch.mm(basis,ahat)
    return ahat, Res



--------------------------------------------------------------------------------
/utility.py:
--------------------------------------------------------------------------------
# Yubei Chen, Sparse Manifold Coding Lib Ver 0.1
"""
This file contains multiple utility functions
"""
import numpy as np
import matplotlib.pyplot as plt
import time
from mpl_toolkits.mplot3d import Axes3D
import numpy.linalg as la
from numpy import random


def imshow(im,ax=None,nonBlock=False, title=None, vmin=None, vmax=None):
    axflag = True
    if ax is None:
        fig = plt.figure()
        ax = fig.gca()
        axflag = False
    axp = ax.imshow(im,cmap='gray',interpolation='none', vmin=vmin, vmax=vmax)
    if title is not None:
        ax.set_title(str(title))
    # `not` instead of `~`: on a Python bool, `~` yields -1 or -2, which are both truthy
    if (not nonBlock) and (not axflag):
        cbaxes = fig.add_axes([0.9, 0.1, 0.03, 0.8]) # This is the position for the colorbar
        cb = plt.colorbar(axp, cax = cbaxes)

def displayVecArry(basis,X=1,Y=1,ax='none',title=-1,nonBlock=False, equal_contrast = False,boundary='none'):
    axflag = True
    if ax == 'none':
        fig = plt.figure()
        ax = fig.gca()
        axflag = False
    basisTemp = basis.copy()
    # Promote a single vector to one column before any column-wise indexing
    if len(basisTemp.shape) == 1:
        basisTemp = np.reshape(basisTemp,[basisTemp.shape[0],1])
    if equal_contrast:
        #basis_mean = basisTemp.mean(1)
        for i in range(basisTemp.shape[-1]):
            basisTemp[:,i] = basisTemp[:,i] - basisTemp[:,i].mean()
            basisTemp[:,i] = basisTemp[:,i]/(basisTemp[:,i].max()-basisTemp[:,i].min())

    #if nonBlock:
    #    plt.ion()
    #else:
    #    plt.ioff()
    SHAPE = basisTemp.shape
    PATCH_SIZE = int(np.sqrt(SHAPE[0]))
    img = np.empty([(PATCH_SIZE+1)*X-1,(PATCH_SIZE+1)*Y-1])
    img.fill(np.min(basisTemp))
    # Any non-string value is used as the boundary; comparing an array to 'none' is ambiguous
    if not isinstance(boundary, str):
        img.fill(np.min(boundary))
    for i in range(X):
        for j in range(Y):
            img[(PATCH_SIZE+1)*i:(PATCH_SIZE+1)*(i+1)-1,\
                (PATCH_SIZE+1)*j:(PATCH_SIZE+1)*(j+1)-1] = \
                np.reshape(basisTemp[:,i*Y+j],[PATCH_SIZE,PATCH_SIZE])
    ax.imshow(img,cmap='gray',interpolation='none')
    if title!=-1:
        ax.set_title(str(title))
    ax.get_yaxis().set_visible(False)
    ax.get_xaxis().set_visible(False)
    ax.spines["right"].set_color("none")
    ax.spines["top"].set_color("none")
    ax.spines["left"].set_color("none")
    ax.spines["bottom"].set_color("none")
    #time.sleep(0.05)
    #if not nonBlock:
    #    plt.show()

def plotFrameOff(ax):
    ax.get_yaxis().set_visible(False)
    ax.get_xaxis().set_visible(False)
    ax.spines["right"].set_color("none")
    ax.spines["top"].set_color("none")
    ax.spines["left"].set_color("none")
    ax.spines["bottom"].set_color("none")

def createImage_FromVecArry(basis,X=1,Y=1):
    basisTemp = basis.copy()
    if len(basisTemp.shape) == 1:
        basisTemp = np.reshape(basisTemp,[basisTemp.shape[0],1])
    SHAPE = basisTemp.shape
    PATCH_SIZE = int(np.sqrt(SHAPE[0]))
    img = np.empty([(PATCH_SIZE+1)*X-1,(PATCH_SIZE+1)*Y-1])
    img.fill(np.min(basisTemp))
    for i in range(X):
        for j in range(Y):
            img[(PATCH_SIZE+1)*i:(PATCH_SIZE+1)*(i+1)-1,\
                (PATCH_SIZE+1)*j:(PATCH_SIZE+1)*(j+1)-1] = \
                np.reshape(basisTemp[:,i*Y+j],[PATCH_SIZE,PATCH_SIZE])
    return img

def displayVec3D(basisFunction, ax = 'none', title=-1, nonBlock=False):
    """
    Show a single basis function in wireframe. The basis function needs to be
    reshaped into the right size.
    """
    axflag = True
    if ax == 'none':
        fig = plt.figure()
        ax = fig.add_subplot(111, projection='3d')
        axflag = False
    #if nonBlock:
    #    plt.ion()
    #else:
    #    plt.ioff()
    SHAPE = (basisFunction.T).shape
    #figure = plt.figure(fig)
    y = range(SHAPE[0])
    x = range(SHAPE[1])
    X, Y = np.meshgrid(x, y)
    ax.plot_wireframe(X, Y, basisFunction.T)
    if (not nonBlock) and (not axflag):
        fig.show()
    #plt.show()

def displayVecSeq(VecSeq):
    return


def displayFourierRecenter(fourierIm, ax = 'none', nonBlock=False):
    axflag = True
    if ax == 'none':
        fig = plt.figure()
        ax = fig.gca()
        axflag = False
    # Built-in int: the np.int alias is no longer available in recent NumPy releases
    dim0Shift = int(np.floor((fourierIm.shape[0]-1)/2.))
    dim1Shift = int(np.floor((fourierIm.shape[1]-1)/2.))
    tempIm = np.roll(fourierIm,dim0Shift,axis=0)
    tempIm = np.roll(tempIm,dim1Shift,axis=1)
    ax.imshow(tempIm,interpolation='none', cmap = 'gray')
    if (not nonBlock) and (not axflag):
        fig.show()
    return tempIm


def saveBasisParas(filename, _basis, _basis_size, _lambd, _iterations):
    np.savez(filename, basis=_basis, basis_size=_basis_size, lambd=_lambd, \
        iterations = _iterations)


def saveLattice(filename, _basis, _basis_size, _manifoldCoordinate, _settingParas):
    np.savez(filename, basis=_basis, basis_size=_basis_size, \
        manifoldCoordinate=_manifoldCoordinate, settingParas=_settingParas)


def loadLattice(filename):
    loaded = np.load(filename)
    return loaded['basis'], loaded['basis_size'], loaded['manifoldCoordinate'], loaded['settingParas']


def loadBasisParas(filename):
    loaded = np.load(filename)
    return loaded['basis'], loaded['basis_size'], loaded['lambd'], loaded['iterations']


def normalizeL2(vecArry):
    for i in range(vecArry.shape[1]):
        vecArry[:,i] = vecArry[:,i]/la.norm(vecArry[:,i],2)
    return vecArry

def normalizeL1(vecArry):
    for i in range(vecArry.shape[1]):
        vecArry[:,i] = vecArry[:,i]/la.norm(vecArry[:,i],1)
    return vecArry

def boundLinf(vecArry,bound):
    for i in range(vecArry.shape[1]):
        for j in range(vecArry.shape[0]):
            vecArry[j,i] = np.sign(vecArry[j,i])*np.min([np.abs(vecArry[j,i]),bound])
    return vecArry

def seq_filtering(Seq, filt):
    # This function temporally filters a time series
    # Each column of Seq is a vector at a particular time step
    #TODO: Please finish this simple function
    return -1


def sampleRandom(data, num):
    """
    Currently data cannot be a 1d array
    """
    dataSize = data.shape
    dataNum = dataSize[-1]
    sampleSize = np.array(dataSize)
    sampleSize[-1] = num
    sample = np.zeros(sampleSize)
    batch = np.floor(random.rand(num)*dataNum)
    batch = batch.astype(int)
    for i in range(num):
        sample[:,i] = data[:,batch[i]]
    return sample

def sampleRandomWithParas(data,paras,num):
    """
    Currently data cannot be a 1d array
    """
    dataSize = data.shape
    paraSize = paras.shape
    dataNum = dataSize[-1]
    sampleSize = np.array(dataSize)
    sampleParaSize = np.array(paraSize)
    sampleSize[-1] = num
    sampleParaSize[-1] = num
    sample = np.zeros(sampleSize)
    sampleParas = np.zeros(sampleParaSize)
    batch = np.floor(random.rand(num)*dataNum)
    # astype returns a new array, so the result must be assigned back
    batch = batch.astype(int)
    for i in range(num):
        sample[:,i] = data[:,batch[i]]
        sampleParas[:,i] = paras[:,batch[i]]
    return sample, sampleParas


def errorMsg(msg):
    """
    This function will output an error message msg
    """
    try:
        raise Exception(msg)
    except Exception as inst:
        print(inst)

def generalized_norm_square(V1,M):
    """
    This function returns a vector of the squared generalized norm of each column in V1 with respect to M.
    """
    return np.diag(np.dot(V1.T,np.dot(M,V1)))

def generalized_norm(V1,M):
    """
    This function returns a vector of the generalized norm of each column in V1 with respect to M.
    """
    return np.sqrt(np.diag(np.dot(V1.T,np.dot(M,V1))))

def patch_translation(patch,xshift,yshift):
    """
    Apply a pixel-level translation to a given patch. It is not an in-place function.
    """
    xdim = patch.shape[0]
    ydim = patch.shape[1]
    patch_new = patch.copy()
    patch_new[...] = 0
    # Source and destination start offsets; shifts are assumed to be integers
    locx = int(max(-xshift,0))
    locy = int(max(-yshift,0))
    locx_new = int(max(xshift,0))
    locy_new = int(max(yshift,0))
    patch_new[locx_new:xdim+xshift,locy_new:ydim+yshift] = patch[locx:xdim-xshift,locy:ydim-yshift]
    return patch_new


--------------------------------------------------------------------------------
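A minimal sketch of the positive-only ISTA iteration used by `ISTA` in sparsify_PyTorch.py, re-expressed in NumPy so it can be tried without a GPU. This is an illustration under stated assumptions, not the repository's code: the function name `ista_nonneg`, the synthetic dictionary, and the chosen `lambd` and `num_iter` values are all hypothetical, and the residual is initialized to the data instead of zeros.

```python
import numpy as np


def ista_nonneg(I, basis, lambd, num_iter, eta=None):
    """Positive-only ISTA sketch (hypothetical helper, NumPy only).

    I:     (N, batch) data, one sample per column
    basis: (N, M) dictionary, one basis function per column
    """
    if eta is None:
        # Step size 1/L, with L the largest eigenvalue of basis @ basis.T
        eta = 1.0 / np.linalg.eigvalsh(basis @ basis.T).max()
    ahat = np.zeros((basis.shape[1], I.shape[1]))
    res = I - basis @ ahat
    for _ in range(num_iter):
        # Gradient step on the reconstruction error, then one-sided shrinkage
        ahat = np.maximum(ahat + eta * (basis.T @ res) - eta * lambd, 0.0)
        res = I - basis @ ahat
    return ahat, res


# Synthetic example (assumed data): each sample is the sum of four unit-norm atoms
rng = np.random.default_rng(0)
basis = rng.standard_normal((16, 32))
basis /= np.linalg.norm(basis, axis=0)
a_true = np.zeros((32, 5))
a_true[::8, :] = 1.0
I = basis @ a_true

lambd = 0.05
ahat, res = ista_nonneg(I, basis, lambd=lambd, num_iter=200)
print("nonzero coefficients per sample:", (ahat > 0).sum(axis=0))
print("relative residual:", np.linalg.norm(res) / np.linalg.norm(I))
```

With step size 1/L the objective `0.5 * ||I - basis @ ahat||^2 + lambd * sum(ahat)` does not increase from one iteration to the next, which is why the fixed step is derived from the largest eigenvalue instead of being hand-tuned.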