├── Data
│   └── IMAGES_Vanhateren.mat
├── LICENSE
├── README.md
├── figures
│   └── sparse_coding.png
├── sparse_coding_torch_Demo.ipynb
├── sparsify_PyTorch.py
└── utility.py
/Data/IMAGES_Vanhateren.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yubeic/Sparse-Coding/92297d5c76ebf83763dc9dad61b5f08ac96d8f11/Data/IMAGES_Vanhateren.mat
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 Yubei Chen
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Sparse Coding PyTorch
2 |
3 | This code implements sparse coding in PyTorch with a positive-only option. For the positive-only option, I only constrain the sparse coefficients to be non-negative. This choice is related to, but different from, non-negative sparse coding and non-negative matrix factorization. The optimization solvers used in this code are ISTA and FISTA. For the demo, whitened natural images are adapted from: http://www.rctn.org/bruno/sparsenet/
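The signed and positive-only solvers differ only in the shrinkage applied after each gradient step. Below is a minimal NumPy sketch of that idea, not the repository's PyTorch code, assuming (as in `ISTA` and `ISTA_PN` in `sparsify_PyTorch.py`) that columns of `I` are image patches and columns of `basis` are unit-norm dictionary elements. The names `shrink_signed`, `shrink_positive` and `ista_sketch` are hypothetical, and the random data only stands in for whitened patches.

```python
import numpy as np


def shrink_signed(a, thresh):
    # Soft-thresholding: proximal operator of thresh * ||a||_1
    return np.sign(a) * np.maximum(np.abs(a) - thresh, 0.0)


def shrink_positive(a, thresh):
    # One-sided shrinkage: proximal operator of thresh * ||a||_1 plus the constraint a >= 0
    return np.maximum(a - thresh, 0.0)


def ista_sketch(I, basis, lambd, num_iter, positive=True):
    """Minimize 0.5 * ||I - basis @ A||_F^2 + lambd * ||A||_1 (optionally with A >= 0)."""
    # Step size 1/L, where L is the largest eigenvalue of basis @ basis.T
    eta = 1.0 / np.linalg.eigvalsh(basis @ basis.T).max()
    shrink = shrink_positive if positive else shrink_signed
    A = np.zeros((basis.shape[1], I.shape[1]))
    for _ in range(num_iter):
        Res = I - basis @ A                               # current residual
        A = shrink(A + eta * basis.T @ Res, eta * lambd)  # gradient step, then shrinkage
    return A, I - basis @ A


# Usage on random stand-in data: 64-dimensional "patches", 2x overcomplete dictionary
rng = np.random.default_rng(0)
basis = rng.standard_normal((64, 128))
basis /= np.linalg.norm(basis, axis=0)
I = rng.standard_normal((64, 10))
A_pos, Res_pos = ista_sketch(I, basis, lambd=0.1, num_iter=200)
A_pn, Res_pn = ista_sketch(I, basis, lambd=0.1, num_iter=200, positive=False)
```

With `positive=True` every coefficient is non-negative, while the signed variant keeps negative coefficients as well; only the `shrink` function changes.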
4 |
5 | ### A Sample of the Learned Dictionary
6 |
7 | ![A sample of the learned dictionary](figures/sparse_coding.png)
8 |
9 |
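A dictionary like this one is learned by alternating sparse inference with a basis update. The following is a minimal NumPy sketch of that loop under stated assumptions: random data stands in for whitened image patches, inference is positive-only ISTA, and the basis update is a simplified version of `quadraticBasisUpdate` (a gradient step on the reconstruction error, each column divided by its mean activation energy plus a small constant, then L2 renormalization). The helper names `infer_positive` and `update_basis`, the step size and the constants are hypothetical, not the settings of the demo notebook.

```python
import numpy as np


def infer_positive(I, basis, lambd, num_iter=100):
    # Positive-only ISTA: gradient step on 0.5 * ||I - basis @ A||^2, then max(. - eta*lambd, 0)
    eta = 1.0 / np.linalg.eigvalsh(basis @ basis.T).max()
    A = np.zeros((basis.shape[1], I.shape[1]))
    for _ in range(num_iter):
        A = np.maximum(A + eta * basis.T @ (I - basis @ A) - eta * lambd, 0.0)
    return A


def update_basis(basis, Res, A, step_size=0.1, lowest_activation=1e-3):
    # Negative gradient of the reconstruction error w.r.t. basis, averaged over the batch
    d_basis = step_size * (Res @ A.T) / A.shape[1]
    # Scale each column by the inverse of its mean activation energy (diagonal Hessian proxy)
    energy = (A ** 2).mean(axis=1)
    basis = basis + d_basis / (energy + lowest_activation)
    # Keep every basis function at unit L2 norm
    return basis / np.linalg.norm(basis, axis=0)


rng = np.random.default_rng(0)
patches = rng.standard_normal((64, 1000))  # stand-in for whitened 8x8 image patches
basis = rng.standard_normal((64, 128))
basis /= np.linalg.norm(basis, axis=0)
for _ in range(20):
    batch = patches[:, rng.choice(patches.shape[1], 100, replace=False)]
    A = infer_positive(batch, basis, lambd=0.1)
    Res = batch - basis @ A
    basis = update_basis(basis, Res, A)
```

Dividing by the activation energy gives rarely used basis functions a larger relative step, which is the role `HessianDiag + lowestActivation` plays in `quadraticBasisUpdate`.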
10 | To use this code, Python 3 and PyTorch 0.4 or above are required. Please follow the steps in the demo notebook, which can be opened with:
11 |
12 | ```bash
13 | jupyter notebook sparse_coding_torch_Demo.ipynb
14 | ```
15 |
16 | The following are some useful references:
17 |
18 | ### Sparse Coding
19 | ```bibtex
20 | @article{olshausen1996emergence,
21 | title={Emergence of simple-cell receptive field properties by learning a sparse code for natural images},
22 | author={Olshausen, Bruno A and Field, David J},
23 | journal={Nature},
24 | volume={381},
25 | number={6583},
26 | pages={607},
27 | year={1996},
28 | publisher={Nature Publishing Group}
29 | }
30 | ```
31 | ```bibtex
32 | @inproceedings{olshausen2013highly,
33 | title={Highly overcomplete sparse coding},
34 | author={Olshausen, Bruno A},
35 | booktitle={Human Vision and Electronic Imaging XVIII},
36 | volume={8651},
37 | pages={86510S},
38 | year={2013},
39 | organization={International Society for Optics and Photonics}
40 | }
41 | ```
42 |
43 | ### FISTA Algorithm
44 | ```bibtex
45 | @article{beck2009fast,
46 | title={A fast iterative shrinkage-thresholding algorithm for linear inverse problems},
47 | author={Beck, Amir and Teboulle, Marc},
48 | journal={SIAM journal on imaging sciences},
49 | volume={2},
50 | number={1},
51 | pages={183--202},
52 | year={2009},
53 | publisher={SIAM}
54 | }
55 | ```
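FISTA differs from ISTA only by a momentum (extrapolation) step between proximal gradient steps. Below is a minimal NumPy sketch of the positive-only variant that follows the same update order as `FISTA` in `sparsify_PyTorch.py` but is not taken from it; `fista_positive` and `objective` are hypothetical names and the data are random.

```python
import numpy as np


def fista_positive(I, basis, lambd, num_iter):
    """Positive-only FISTA for 0.5 * ||I - basis @ A||_F^2 + lambd * ||A||_1 with A >= 0."""
    eta = 1.0 / np.linalg.eigvalsh(basis @ basis.T).max()  # step size 1/L
    A = np.zeros((basis.shape[1], I.shape[1]))
    Y = A.copy()  # extrapolated point
    t = 1.0
    for _ in range(num_iter):
        A_prev = A
        # Proximal gradient step taken at the extrapolated point Y
        A = np.maximum(Y + eta * basis.T @ (I - basis @ Y) - eta * lambd, 0.0)
        # Momentum weights t_k from Beck & Teboulle
        t_next = (1.0 + np.sqrt(1.0 + 4.0 * t * t)) / 2.0
        Y = A + ((t - 1.0) / t_next) * (A - A_prev)
        t = t_next
    return A, I - basis @ A


def objective(I, basis, A, lambd):
    # Reconstruction error plus L1 sparsity penalty
    return 0.5 * np.sum((I - basis @ A) ** 2) + lambd * np.abs(A).sum()


rng = np.random.default_rng(1)
basis = rng.standard_normal((64, 128))
basis /= np.linalg.norm(basis, axis=0)
I = rng.standard_normal((64, 10))
A, Res = fista_positive(I, basis, lambd=0.1, num_iter=200)
```

Unlike ISTA, FISTA is not guaranteed to decrease the objective at every iteration, but its worst-case convergence rate in objective value improves from O(1/k) to O(1/k^2).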
56 |
57 | ### Positive-Only Sparse Coding
58 | ```bibtex
59 | @inproceedings{hoyer2002non,
60 | title={Non-negative sparse coding},
61 | author={Hoyer, Patrik O},
62 | booktitle={Proceedings of the 12th IEEE Workshop on Neural Networks for Signal Processing},
63 | pages={557--565},
64 | year={2002},
65 | organization={IEEE}
66 | }
67 | ```
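The one-sided shrinkage used by the positive-only solvers (subtracting `eta * lambd`, then `clamp(min = 0.)`) follows from the proximal step of the constrained problem. For a single coefficient, let $v$ be its value after the gradient step $a + \eta\, \Phi^\top (I - \Phi a)$, where $\Phi$ is the basis, and let $\tau = \eta \lambda$:

$$
\arg\min_{a \ge 0}\ \tfrac{1}{2}(a - v)^2 + \tau a = \max(v - \tau,\ 0),
\qquad
\arg\min_{a \in \mathbb{R}}\ \tfrac{1}{2}(a - v)^2 + \tau |a| = \operatorname{sign}(v)\,\max(|v| - \tau,\ 0).
$$

On $a \ge 0$ the penalty $\tau |a|$ equals $\tau a$, so the objective is a parabola with its vertex at $v - \tau$; when $v - \tau < 0$ it is increasing on $[0, \infty)$ and the constrained minimum is $a = 0$. The second expression is the usual soft-thresholding applied by the signed `ISTA_PN` solver.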
68 |
69 | ### The Sparse Manifold Transform
70 | ```bibtex
71 | @inproceedings{DBLP:conf/nips/ChenPO18,
72 | author = {Yubei Chen and
73 | Dylan M. Paiton and
74 | Bruno A. Olshausen},
75 | title = {The Sparse Manifold Transform},
76 | booktitle = {Advances in Neural Information Processing Systems 31: Annual Conference
77 | on Neural Information Processing Systems 2018, NeurIPS 2018, December
78 | 3-8, 2018, Montr{\'{e}}al, Canada},
79 | pages = {10534--10545},
80 | year = {2018}
81 | }
82 | ```
83 |
84 |
--------------------------------------------------------------------------------
/figures/sparse_coding.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yubeic/Sparse-Coding/92297d5c76ebf83763dc9dad61b5f08ac96d8f11/figures/sparse_coding.png
--------------------------------------------------------------------------------
/sparsify_PyTorch.py:
--------------------------------------------------------------------------------
# Yubei Chen, Sparse Manifold Transform Lib Ver 0.1
"""
This file contains multiple methods to sparsify the coefficients.
"""
import time
import numpy as np
import numpy.linalg as la
#import utility
from IPython.display import clear_output # This is to clean the output info to make the process cleaner

import torch
import torch.nn.functional as F
# cupy is only needed when useMAGMA=False; it is imported lazily in _inverse_lipschitz.


def quadraticBasisUpdate(basis, Res, ahat, lowestActivation, HessianDiag, stepSize=0.001, constraint='L2', Noneg=False):
    """
    This function updates the basis functions based on the Hessian of the activations.
    It is very similar to Newton's method, but since the Hessian of the activations is often
    ill-conditioned, it is not inverted directly: each column of the gradient step is divided
    by HessianDiag + lowestActivation instead.

    Note: currently, HessianDiag can simply be the activation energy of each basis function.
    A better idea for this method would be calculating the local Lipschitz constant for each basis function.
    The stepSize should be smaller than 1.0 * min(activation) to be stable.
    """
    # Negative gradient of 0.5*||I - basis @ ahat||^2 w.r.t. basis, averaged over the batch
    dBasis = stepSize * torch.mm(Res, ahat.t()) / ahat.size(1)
    dBasis = dBasis.div_(HessianDiag + lowestActivation)
    basis = basis.add_(dBasis)
    if Noneg:
        basis = basis.clamp(min=0.)
    if constraint == 'L2':
        # Renormalize every basis function (column) to unit L2 norm
        basis = basis.div_(basis.norm(2, 0))
    return basis

def _inverse_lipschitz(basis, useMAGMA=True):
    """
    Return the step size eta = 1/L, where L is the largest eigenvalue of basis @ basis.T,
    i.e. the Lipschitz constant of the gradient of 0.5*||I - basis @ ahat||^2.

    MAGMA uses a CPU-GPU hybrid method to solve the eigenvalue problem, which is great for a
    single task. When running multiple jobs, useMAGMA should be turned off to leave the
    eigenvalue computation on the GPU only (this path requires CuPy).
    """
    gram = torch.mm(basis, basis.t())
    if useMAGMA:
        if hasattr(torch, 'linalg') and hasattr(torch.linalg, 'eigvalsh'):
            L = torch.linalg.eigvalsh(gram).max()  # PyTorch >= 1.8 (torch.symeig was removed later)
        else:
            L = torch.max(torch.symeig(gram, eigenvectors=False)[0])  # older PyTorch
        return 1. / L
    import cupy as cp
    L = cp.linalg.eigvalsh(cp.asarray(gram.cpu().numpy())).max().get()
    return torch.tensor(1. / float(L), dtype=basis.dtype, device=basis.device)

def ISTA_PN(I, basis, lambd, num_iter, eta=None, useMAGMA=True):
    # This is a positive-negative (signed) PyTorch-Ver ISTA solver
    batch_size = I.size(1)
    M = basis.size(1)
    if eta is None:
        eta = _inverse_lipschitz(basis, useMAGMA)

    # Start from the all-zero code, whose residual is the input itself
    ahat = torch.zeros(M, batch_size, dtype=basis.dtype, device=basis.device)
    Res = I.clone()

    for t in range(num_iter):
        # Gradient step on 0.5*||I - basis @ ahat||^2
        ahat = ahat.add(eta * basis.t().mm(Res))
        # Signed soft-thresholding: shrink |ahat| by eta*lambd and keep the sign
        ahat_sign = torch.sign(ahat)
        ahat.abs_()
        ahat.sub_(eta * lambd).clamp_(min=0.)
        ahat.mul_(ahat_sign)
        Res = I - torch.mm(basis, ahat)
    return ahat, Res

def FISTA(I, basis, lambd, num_iter, eta=None, useMAGMA=True):
    # This is a positive-only PyTorch-Ver FISTA solver
    batch_size = I.size(1)
    M = basis.size(1)
    if eta is None:
        eta = _inverse_lipschitz(basis, useMAGMA)

    tk_n = 1.
    tk = 1.
    ahat = torch.zeros(M, batch_size, dtype=basis.dtype, device=basis.device)
    ahat_y = torch.zeros_like(ahat)  # extrapolated point

    for t in range(num_iter):
        tk = tk_n
        tk_n = (1 + np.sqrt(1 + 4 * tk ** 2)) / 2
        ahat_pre = ahat
        # Gradient step taken at the extrapolated point
        Res = I - torch.mm(basis, ahat_y)
        ahat_y = ahat_y.add(eta * basis.t().mm(Res))
        # Positive-only shrinkage: prox of eta*lambd*||a||_1 plus the constraint a >= 0
        ahat = ahat_y.sub(eta * lambd).clamp(min=0.)
        # Momentum (extrapolation) step
        ahat_y = ahat.add(ahat.sub(ahat_pre).mul((tk - 1) / tk_n))
    Res = I - torch.mm(basis, ahat)
    return ahat, Res

def ISTA(I, basis, lambd, num_iter, eta=None, useMAGMA=True):
    # This is a positive-only PyTorch-Ver ISTA solver
    batch_size = I.size(1)
    M = basis.size(1)
    if eta is None:
        eta = _inverse_lipschitz(basis, useMAGMA)

    # Start from the all-zero code, whose residual is the input itself
    ahat = torch.zeros(M, batch_size, dtype=basis.dtype, device=basis.device)
    Res = I.clone()

    for t in range(num_iter):
        # Gradient step on 0.5*||I - basis @ ahat||^2
        ahat = ahat.add(eta * basis.t().mm(Res))
        # Positive-only shrinkage
        ahat = ahat.sub(eta * lambd).clamp(min=0.)
        Res = I - torch.mm(basis, ahat)
    return ahat, Res
--------------------------------------------------------------------------------
/utility.py:
--------------------------------------------------------------------------------
# Yubei Chen, Sparse Manifold Coding Lib Ver 0.1
"""
This file contains multiple utility functions.
"""
import numpy as np
import matplotlib.pyplot as plt
import time
from mpl_toolkits.mplot3d import Axes3D  # registers the '3d' projection on older matplotlib
import numpy.linalg as la
from numpy import random


def imshow(im, ax=None, nonBlock=False, title=None, vmin=None, vmax=None):
    axflag = True
    if ax is None:
        fig = plt.figure()
        ax = fig.gca()
        axflag = False
    axp = ax.imshow(im, cmap='gray', interpolation='none', vmin=vmin, vmax=vmax)
    if title is not None:
        ax.set_title(str(title))
    # Use `not`: ~False == -1 and ~True == -2 are both truthy
    if (not nonBlock) and (not axflag):
        cbaxes = fig.add_axes([0.9, 0.1, 0.03, 0.8])  # This is the position for the colorbar
        cb = plt.colorbar(axp, cax=cbaxes)

def displayVecArry(basis, X=1, Y=1, ax='none', title=-1, nonBlock=False, equal_contrast=False, boundary='none'):
    axflag = True
    if ax == 'none':
        fig = plt.figure()
        ax = fig.gca()
        axflag = False
    basisTemp = basis.copy()
    # Turn a single vector into one column before any per-column processing
    if len(basisTemp.shape) == 1:
        basisTemp = np.reshape(basisTemp, [basisTemp.shape[0], 1])
    if equal_contrast:
        # Zero-mean each basis function and scale it to a unit peak-to-peak range
        for i in range(basisTemp.shape[-1]):
            basisTemp[:, i] = basisTemp[:, i] - basisTemp[:, i].mean()
            basisTemp[:, i] = basisTemp[:, i] / (basisTemp[:, i].max() - basisTemp[:, i].min())

    SHAPE = basisTemp.shape
    PATCH_SIZE = int(np.sqrt(SHAPE[0]))
    # Tile X-by-Y square patches separated by one-pixel borders
    img = np.empty([(PATCH_SIZE + 1) * X - 1, (PATCH_SIZE + 1) * Y - 1])
    img.fill(np.min(basisTemp))
    if boundary != 'none':
        img.fill(np.min(boundary))
    for i in range(X):
        for j in range(Y):
            img[(PATCH_SIZE + 1) * i:(PATCH_SIZE + 1) * (i + 1) - 1,
                (PATCH_SIZE + 1) * j:(PATCH_SIZE + 1) * (j + 1) - 1] = \
                np.reshape(basisTemp[:, i * Y + j], [PATCH_SIZE, PATCH_SIZE])
    ax.imshow(img, cmap='gray', interpolation='none')
    if title != -1:
        ax.set_title(str(title))
    plotFrameOff(ax)

def plotFrameOff(ax):
    ax.get_yaxis().set_visible(False)
    ax.get_xaxis().set_visible(False)
    ax.spines["right"].set_color("none")
    ax.spines["top"].set_color("none")
    ax.spines["left"].set_color("none")
    ax.spines["bottom"].set_color("none")

def createImage_FromVecArry(basis, X=1, Y=1):
    basisTemp = basis.copy()
    if len(basisTemp.shape) == 1:
        basisTemp = np.reshape(basisTemp, [basisTemp.shape[0], 1])
    SHAPE = basisTemp.shape
    PATCH_SIZE = int(np.sqrt(SHAPE[0]))
    img = np.empty([(PATCH_SIZE + 1) * X - 1, (PATCH_SIZE + 1) * Y - 1])
    img.fill(np.min(basisTemp))
    for i in range(X):
        for j in range(Y):
            img[(PATCH_SIZE + 1) * i:(PATCH_SIZE + 1) * (i + 1) - 1,
                (PATCH_SIZE + 1) * j:(PATCH_SIZE + 1) * (j + 1) - 1] = \
                np.reshape(basisTemp[:, i * Y + j], [PATCH_SIZE, PATCH_SIZE])
    return img

def displayVec3D(basisFunction, ax='none', title=-1, nonBlock=False):
    """
    Show a single basis function as a wireframe. The basis function needs to be
    reshaped into the right (2-D) size first.
    """
    axflag = True
    if ax == 'none':
        fig = plt.figure()
        ax = fig.add_subplot(111, projection='3d')
        axflag = False
    SHAPE = (basisFunction.T).shape
    y = range(SHAPE[0])
    x = range(SHAPE[1])
    X, Y = np.meshgrid(x, y)
    ax.plot_wireframe(X, Y, basisFunction.T)
    if title != -1:
        ax.set_title(str(title))
    if (not nonBlock) and (not axflag):
        fig.show()

def displayVecSeq(VecSeq):
    return


def displayFourierRecenter(fourierIm, ax='none', nonBlock=False):
    axflag = True
    if ax == 'none':
        fig = plt.figure()
        ax = fig.gca()
        axflag = False
    # Shift the zero frequency to the center (np.int was removed in NumPy 1.24)
    dim0Shift = int(np.floor((fourierIm.shape[0] - 1) / 2.))
    dim1Shift = int(np.floor((fourierIm.shape[1] - 1) / 2.))
    tempIm = np.roll(fourierIm, dim0Shift, axis=0)
    tempIm = np.roll(tempIm, dim1Shift, axis=1)
    ax.imshow(tempIm, interpolation='none', cmap='gray')
    if (not nonBlock) and (not axflag):
        fig.show()
    return tempIm


def saveBasisParas(filename, _basis, _basis_size, _lambd, _iterations):
    np.savez(filename, basis=_basis, basis_size=_basis_size, lambd=_lambd,
             iterations=_iterations)


def saveLattice(filename, _basis, _basis_size, _manifoldCoordinate, _settingParas):
    np.savez(filename, basis=_basis, basis_size=_basis_size,
             manifoldCoordinate=_manifoldCoordinate, settingParas=_settingParas)


def loadLattice(filename):
    loaded = np.load(filename)
    return loaded['basis'], loaded['basis_size'], loaded['manifoldCoordinate'], loaded['settingParas']


def loadBasisParas(filename):
    loaded = np.load(filename)
    return loaded['basis'], loaded['basis_size'], loaded['lambd'], loaded['iterations']


def normalizeL2(vecArry):
    for i in range(vecArry.shape[1]):
        vecArry[:, i] = vecArry[:, i] / la.norm(vecArry[:, i], 2)
    return vecArry

def normalizeL1(vecArry):
    for i in range(vecArry.shape[1]):
        vecArry[:, i] = vecArry[:, i] / la.norm(vecArry[:, i], 1)
    return vecArry

def boundLinf(vecArry, bound):
    # Clip every entry to [-bound, bound] in place
    for i in range(vecArry.shape[1]):
        for j in range(vecArry.shape[0]):
            vecArry[j, i] = np.sign(vecArry[j, i]) * np.min([np.abs(vecArry[j, i]), bound])
    return vecArry

def seq_filtering(Seq, filt):
    # This function temporally filters a time series
    # Each column of Seq is a vector at a particular time step
    #TODO: Please finish this simple function
    return -1


def sampleRandom(data, num):
    """
    Draw num random columns of data (with replacement).
    Currently data can not be a 1-D array.
    """
    dataSize = data.shape
    dataNum = dataSize[-1]
    sampleSize = np.array(dataSize)
    sampleSize[-1] = num
    sample = np.zeros(sampleSize)
    batch = np.floor(random.rand(num) * dataNum).astype(int)
    for i in range(num):
        sample[:, i] = data[:, batch[i]]
    return sample

def sampleRandomWithParas(data, paras, num):
    """
    Draw num random columns of data (with replacement) together with the matching columns of paras.
    Currently data can not be a 1-D array.
    """
    dataSize = data.shape
    paraSize = paras.shape
    dataNum = dataSize[-1]
    sampleSize = np.array(dataSize)
    sampleParaSize = np.array(paraSize)
    sampleSize[-1] = num
    sampleParaSize[-1] = num
    sample = np.zeros(sampleSize)
    sampleParas = np.zeros(sampleParaSize)
    # astype returns a new array, so the result must be assigned to be usable as indices
    batch = np.floor(random.rand(num) * dataNum).astype(int)
    for i in range(num):
        sample[:, i] = data[:, batch[i]]
        sampleParas[:, i] = paras[:, batch[i]]
    return sample, sampleParas


def errorMsg(msg):
    """
    This function will output an error message msg
    """
    try:
        raise Exception(msg)
    except Exception as inst:
        print(inst)

def generalized_norm_square(V1, M):
    """
    This function returns a vector of the squared generalized norm of each column in V1 with respect to M.
    """
    return np.diag(np.dot(V1.T, np.dot(M, V1)))

def generalized_norm(V1, M):
    """
    This function returns a vector of the generalized norm of each column in V1 with respect to M.
    """
    return np.sqrt(np.diag(np.dot(V1.T, np.dot(M, V1))))

def patch_translation(patch, xshift, yshift):
    """
    Apply an integer pixel-level translation to a given patch, filling uncovered pixels with 0.
    It is not an in-place function.
    """
    xdim = patch.shape[0]
    ydim = patch.shape[1]
    patch_new = patch.copy()
    patch_new[...] = 0
    # Source and destination offsets (np.Infinity was removed in NumPy 2.0)
    locx = int(max(-xshift, 0))
    locy = int(max(-yshift, 0))
    locx_new = int(max(xshift, 0))
    locy_new = int(max(yshift, 0))
    patch_new[locx_new:xdim + xshift, locy_new:ydim + yshift] = patch[locx:xdim - xshift, locy:ydim - yshift]
    return patch_new
--------------------------------------------------------------------------------