├── README.md ├── datasets ├── IMAGES.mat ├── IMAGES_RAW.mat └── Sparse net.url ├── ica.py ├── network.py ├── results ├── ICA.png ├── PCA.png ├── RF.png ├── RF_cauchy_thresholding.png └── error.png ├── sparse-coding.ipynb ├── sparse_coding.ipynb └── train.py /README.md: -------------------------------------------------------------------------------- 1 | # Sparse Coding (Olshausen & Field, 1996) Model 2 | Sparse Coding (Olshausen & Field, 1996) Model with **LCA** (locally competitive algorithm) ([Rozell et al., *Neural Comput*. 2008](https://www.ece.rice.edu/~eld1/papers/Rozell08.pdf)). 3 | 4 | Data is from [Sparsenet](http://www.rctn.org/bruno/sparsenet/). 5 | 6 | ## Requirements 7 | - `Python >= 3.5` 8 | - `numpy`, `matplotlib`, `scipy`, `tqdm`, `sklearn` 9 | 10 | ## Usage 11 | - Run `train.py` or one of the notebooks (`sparse-coding.ipynb`, `sparse_coding.ipynb`; written in Japanese). 12 | - `ica.py` implements ICA and PCA on natural-image patches. 13 | 14 | ## Results 15 | ### Loss function 16 | 17 | 18 | ### Receptive fields (using soft threshold function) 19 | 20 | 21 | ### Receptive fields (using Cauchy threshold function (Mayo et al., 2020)) 22 | 23 | 24 | ### ICA 25 | 26 | 27 | ### PCA 28 | 29 | 30 | ## References 31 | - Olshausen BA, Field DJ. [Emergence of simple-cell receptive field properties by learning a sparse code for natural images](https://www.nature.com/articles/381607a0). *Nature*. 1996;381(6583):607–609. [Data and Code](http://www.rctn.org/bruno/sparsenet/), [pdf](https://courses.cs.washington.edu/courses/cse528/11sp/Olshausen-nature-paper.pdf) 32 | - Rozell CJ, Johnson DH, Baraniuk RG, Olshausen BA. [Sparse coding via thresholding and local competition in neural circuits](http://www.mit.edu/~9.54/fall14/Classes/class07/Palm.pdf). *Neural Comput*. 2008;20(10):2526–2563. 33 | - Mayo P, Holmes R, Achim A. [Iterative Cauchy Thresholding: Regularisation with a heavy-tailed prior](https://arxiv.org/abs/2003.12507). arXiv. 2020. 
# --------------------------------------------------------------------------------
# /datasets/IMAGES.mat:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/takyamamoto/SparseCoding-OlshausenField-Model/6a73f2ca81b690b0d0fddbcfb14f508f597afd2c/datasets/IMAGES.mat
# --------------------------------------------------------------------------------
# /datasets/IMAGES_RAW.mat:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/takyamamoto/SparseCoding-OlshausenField-Model/6a73f2ca81b690b0d0fddbcfb14f508f597afd2c/datasets/IMAGES_RAW.mat
# --------------------------------------------------------------------------------
# /datasets/Sparse net.url:
# --------------------------------------------------------------------------------
# [InternetShortcut]
# URL=http://www.rctn.org/bruno/sparsenet/
# --------------------------------------------------------------------------------
# /ica.py:
# --------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
"""Fit ICA and PCA to random natural-image patches and plot the learned filters.

Run as a script (``python ica.py``); it reads ``datasets/IMAGES_RAW.mat``
(from http://www.rctn.org/bruno/sparsenet/) and writes ``ICA.png`` and
``PCA.png``.
"""

import numpy as np
import scipy.io as sio


def sample_patches(imgs, num_patches, patch_h=16, patch_w=16, rng=None):
    """Crop ``num_patches`` random patches from a stack of images.

    Args:
        imgs: 3-D array indexed as ``imgs[y, x, image_index]``.
        num_patches: number of patches to draw.
        patch_h, patch_w: patch height and width in pixels.
        rng: object with a ``randint`` method (e.g. ``np.random.RandomState``);
            defaults to the global ``np.random`` state.

    Returns:
        Array of shape ``(num_patches, patch_h * patch_w)`` with the dtype of
        ``imgs``; each row is one patch flattened in row-major order, i.e. it
        reshapes back with ``(patch_h, patch_w)``.

    Raises:
        ValueError: if the patch is larger than the images.
    """
    if rng is None:
        rng = np.random
    H, W, num_images = imgs.shape
    if patch_h > H or patch_w > W:
        raise ValueError("patch ({}x{}) is larger than the images ({}x{})"
                         .format(patch_h, patch_w, H, W))

    patches = np.empty((num_patches, patch_h * patch_w), dtype=imgs.dtype)
    for k in range(num_patches):
        idx = rng.randint(0, num_images)
        # Upper-left corner of the crop. randint's upper bound is exclusive,
        # so "+ 1" makes the last valid offset (W - patch_w / H - patch_h)
        # reachable.
        left = rng.randint(0, W - patch_w + 1)
        top = rng.randint(0, H - patch_h + 1)
        patches[k] = imgs[top:top + patch_h, left:left + patch_w, idx].ravel()
    return patches


def plot_filters(filters, patch_shape, title, filename, figsize=(6, 6)):
    """Show each row of ``filters`` as a ``patch_shape`` image on a square grid.

    The grid has ``ceil(sqrt(len(filters)))`` columns (10x10 for 100 filters).
    The figure is saved to ``filename`` and then displayed.
    """
    # Plotting dependencies are imported lazily so the numerical helpers in
    # this module can be imported without them.
    import matplotlib.pyplot as plt
    from tqdm import tqdm

    n = len(filters)
    ncols = int(np.ceil(np.sqrt(n)))
    nrows = int(np.ceil(n / ncols))

    plt.figure(figsize=figsize)
    plt.subplots_adjust(hspace=0.1, wspace=0.1)
    for i in tqdm(range(n)):
        plt.subplot(nrows, ncols, i + 1)
        plt.imshow(np.reshape(filters[i], patch_shape), cmap="gray")
        plt.axis("off")
    plt.suptitle(title, fontsize=20)
    plt.subplots_adjust(top=0.9)
    plt.savefig(filename)
    plt.show()


def main(num_patches=15000, patch_size=(16, 16), n_comp=100):
    """Sample patches from IMAGES_RAW.mat, fit ICA and PCA, and plot the filters.

    Args:
        num_patches: number of random patches to fit on.
        patch_size: ``(height, width)`` of each patch.
        n_comp: number of ICA / PCA components.
    """
    # Imported here so that sample_patches() is usable without scikit-learn.
    from sklearn.decomposition import FastICA, PCA

    # datasets from http://www.rctn.org/bruno/sparsenet/
    # (only IMAGES_RAW.mat is used here; IMAGES.mat is not loaded)
    imgs_raw = sio.loadmat('datasets/IMAGES_RAW.mat')['IMAGESr']
    patch_h, patch_w = patch_size
    patches = sample_patches(imgs_raw, num_patches, patch_h, patch_w)

    # perform ICA
    print("perform ICA")
    ica = FastICA(n_components=n_comp)
    ica.fit(patches)
    plot_filters(ica.components_, (patch_h, patch_w), "ICA", "ICA.png")

    # perform PCA
    print("perform PCA")
    pca = PCA(n_components=n_comp)
    pca.fit(patches)
    plot_filters(pca.components_, (patch_h, patch_w), "PCA", "PCA.png")


if __name__ == "__main__":
    main()

# --------------------------------------------------------------------------------
# /network.py:
# --------------------------------------------------------------------------------
# -*- coding: utf-8 -*-

import numpy as np


class OlshausenField1996Model:
    """Sparse-coding model (Olshausen & Field, 1996).

    A batch ``inputs`` of shape ``(batch_size, num_inputs)`` is reconstructed
    as ``r @ Phi.T``. ``Phi`` has shape ``(num_inputs, num_units)`` with one
    basis function per column. ``r`` has shape ``(batch_size, num_units)``
    and holds the unit activities. Each call performs one thresholded
    gradient step on ``r`` and, when training, one gradient step on ``Phi``.
    """

    def __init__(self, num_inputs, num_units, batch_size,
                 lr_r=1e-2, lr_Phi=1e-2, lmda=5e-3):
        self.lr_r = lr_r  # learning rate of r
        self.lr_Phi = lr_Phi  # learning rate of Phi
        self.lmda = lmda  # regularization parameter (also the threshold)

        self.num_inputs = num_inputs
        self.num_units = num_units
        self.batch_size = batch_size

        # Weights: Gaussian init scaled by 1/sqrt(num_units)
        Phi = np.random.randn(self.num_inputs, self.num_units).astype(np.float32)
        self.Phi = Phi * np.sqrt(1 / self.num_units)

        # activity of neurons (all zeros)
        self.initialize_states()

    def initialize_states(self):
        """Reset the activities ``r`` to zeros of shape (batch_size, num_units)."""
        self.r = np.zeros((self.batch_size, self.num_units))

    def normalize_rows(self):
        """Scale each basis function (each COLUMN of Phi, axis=0) to unit L2 norm.

        Despite the name, columns are normalized; the name is kept for callers.
        Norms are floored at 1e-8 to avoid division by zero.
        """
        norms = np.linalg.norm(self.Phi, ord=2, axis=0, keepdims=True)
        self.Phi = self.Phi / np.maximum(norms, 1e-8)

    # thresholding function of S(x)=|x|
    def soft_thresholding_func(self, x, lmda):
        """Soft threshold: shrink ``x`` toward 0 by ``lmda``, zeroing |x| <= lmda."""
        return np.maximum(x - lmda, 0) - np.maximum(-x - lmda, 0)

    # thresholding function of S(x)=ln(1+x^2)
    def ln_thresholding_func(self, x, lmda):
        """Closed-form (cube-root) thresholding function for S(x)=ln(1+x^2).

        The ``1e-8`` keeps the last term finite when ``h`` is 0.
        """
        f = 9*lmda*x - 2*np.power(x, 3) - 18*x
        g = 3*lmda - np.square(x) + 3
        h = np.cbrt(np.sqrt(np.square(f) + 4*np.power(g, 3)) + f)
        two_croot = np.cbrt(2)  # cubic root of two
        return (1/3)*(x - h / two_croot + two_croot*g / (1e-8+h))

    # thresholding function https://arxiv.org/abs/2003.12507
    def cauchy_thresholding_func(self, x, lmda):
        """Cauchy thresholding (Mayo et al., 2020); returns 0 where |x| < lmda."""
        root = np.sqrt(np.maximum(x**2 - lmda, 0))
        f = 0.5*(x + root)  # branch used for x >= lmda
        g = 0.5*(x - root)  # branch used for x <= -lmda
        return f*(x >= lmda) + g*(x <= -lmda)

    def calculate_total_error(self, error):
        """Return mean squared ``error`` plus ``lmda`` times the mean |r|."""
        recon_error = np.mean(error**2)
        sparsity_r = self.lmda*np.mean(np.abs(self.r))
        return recon_error + sparsity_r

    def __call__(self, inputs, training=True):
        """Run one update of ``r`` and, if ``training``, one update of ``Phi``.

        Returns:
            ``(error, r)``. With ``training=False`` the error is the residual
            ``inputs - r @ Phi.T`` computed with ``r`` from BEFORE this step.
            With ``training=True`` it is recomputed with the updated ``r``
            (before the Phi update).
        """
        # Updates
        error = inputs - self.r @ self.Phi.T

        r = self.r + self.lr_r * error @ self.Phi
        self.r = self.soft_thresholding_func(r, self.lmda)
        # Alternative threshold (presumably how RF_cauchy_thresholding.png was made):
        #self.r = self.cauchy_thresholding_func(r, self.lmda)

        if training:
            error = inputs - self.r @ self.Phi.T
            dPhi = error.T @ self.r
            self.Phi += self.lr_Phi * dPhi

        return error, self.r

# --------------------------------------------------------------------------------
# /results/ICA.png:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/takyamamoto/SparseCoding-OlshausenField-Model/6a73f2ca81b690b0d0fddbcfb14f508f597afd2c/results/ICA.png
# --------------------------------------------------------------------------------
# /results/PCA.png:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/takyamamoto/SparseCoding-OlshausenField-Model/6a73f2ca81b690b0d0fddbcfb14f508f597afd2c/results/PCA.png
# --------------------------------------------------------------------------------
# /results/RF.png:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/takyamamoto/SparseCoding-OlshausenField-Model/6a73f2ca81b690b0d0fddbcfb14f508f597afd2c/results/RF.png
# --------------------------------------------------------------------------------
# /results/RF_cauchy_thresholding.png:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/takyamamoto/SparseCoding-OlshausenField-Model/6a73f2ca81b690b0d0fddbcfb14f508f597afd2c/results/RF_cauchy_thresholding.png
# --------------------------------------------------------------------------------
# /results/error.png:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/takyamamoto/SparseCoding-OlshausenField-Model/6a73f2ca81b690b0d0fddbcfb14f508f597afd2c/results/error.png
# --------------------------------------------------------------------------------
# /train.py:
# --------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
"""Train the Olshausen & Field (1996) sparse-coding model on natural images.

Run as a script (``python train.py``); it reads ``datasets/IMAGES.mat``
(from http://www.rctn.org/bruno/sparsenet/) and writes ``error.png`` (the
training-error curve) and ``RF.png`` (the learned basis functions).
"""

import numpy as np
import matplotlib.pyplot as plt
import network
from tqdm import tqdm
import scipy.io as sio

# Simulation constants
NUM_ITER = 500  # number of training iterations (one weight update each)
NT_MAX = 1000  # maximum number of inference steps per batch
BATCH_SIZE = 250  # patches per batch
SZ = 16  # image patch size (SZ x SZ pixels)
NUM_UNITS = 100  # number of neurons (units)
EPS = 1e-2  # relative change of r below which inference counts as converged


def sample_batch(imgs, batch_size, sz, rng=None):
    """Crop ``batch_size`` random ``sz`` x ``sz`` patches and remove each patch's mean.

    Args:
        imgs: 3-D array indexed as ``imgs[y, x, image_index]``.
        batch_size: number of patches.
        sz: patch side length in pixels.
        rng: object with a ``randint`` method; defaults to ``np.random``.

    Returns:
        Array of shape ``(batch_size, sz * sz)``; each row is one patch
        flattened row-major with its own mean subtracted.
    """
    if rng is None:
        rng = np.random
    H, W, num_images = imgs.shape
    # Upper-left corners of the crops. randint's upper bound is exclusive,
    # so "+ 1" makes the last valid offset (W - sz / H - sz) reachable.
    beginx = rng.randint(0, W - sz + 1, batch_size)
    beginy = rng.randint(0, H - sz + 1, batch_size)

    patches = []
    for i in range(batch_size):
        idx = rng.randint(0, num_images)
        crop = imgs[beginy[i]:beginy[i] + sz, beginx[i]:beginx[i] + sz, idx].flatten()
        patches.append(crop - np.mean(crop))
    return np.array(patches)


def infer_and_learn(model, inputs, nt_max, eps):
    """Iterate inference on ``inputs`` until ``r`` converges, then update the weights once.

    The model's states are reset and its basis functions normalized first.
    If ``r`` does not converge within ``nt_max`` steps, no weight update is
    made.

    Returns:
        ``(error, converged, dr_norm)``: the residual returned by the last
        model call, whether convergence was reached, and the last relative
        change of ``r``.

    Raises:
        ValueError: if ``nt_max`` < 1 (no inference step would run).
    """
    if nt_max < 1:
        raise ValueError("nt_max must be at least 1")

    model.initialize_states()  # Reset states
    model.normalize_rows()  # Normalize weights

    r_tm1 = model.r  # previous r (t minus 1)
    for _ in range(nt_max):
        # Update r without updating the weights
        error, r = model(inputs, training=False)
        # NOTE: ord=2 on a 2-D array is the spectral norm (largest singular
        # value), not the Frobenius norm.
        dr_norm = np.linalg.norm(r - r_tm1, ord=2) / (eps + np.linalg.norm(r_tm1, ord=2))
        r_tm1 = r

        if dr_norm < eps:
            # Converged: one more step with the weight update enabled.
            error, r = model(inputs, training=True)
            return error, True, dr_norm

    return error, False, dr_norm


def plot_error(error_list, filename="error.png"):
    """Plot the per-iteration total error, save it to ``filename``, and show it."""
    plt.figure(figsize=(5, 3))
    plt.ylabel("Error")
    plt.xlabel("Iterations")
    plt.plot(np.arange(len(error_list)), np.array(error_list))
    plt.tight_layout()
    plt.savefig(filename)
    plt.show()


def plot_receptive_fields(Phi, sz, filename="RF.png"):
    """Show each column of ``Phi`` as an ``sz`` x ``sz`` image on a square grid.

    The grid has ``ceil(sqrt(num_units))`` columns (10x10 for 100 units).
    The figure is saved to ``filename`` and then displayed.
    """
    num_units = Phi.shape[1]
    ncols = int(np.ceil(np.sqrt(num_units)))
    nrows = int(np.ceil(num_units / ncols))

    fig = plt.figure(figsize=(8, 8))
    plt.subplots_adjust(hspace=0.1, wspace=0.1)
    for i in tqdm(range(num_units)):
        plt.subplot(nrows, ncols, i + 1)
        plt.imshow(np.reshape(Phi[:, i], (sz, sz)), cmap="gray")
        plt.axis("off")

    fig.suptitle("Receptive fields", fontsize=20)
    plt.subplots_adjust(top=0.9)
    plt.savefig(filename)
    plt.show()


def main():
    """Train the model for NUM_ITER iterations and plot the results."""
    np.random.seed(0)

    # datasets from http://www.rctn.org/bruno/sparsenet/
    imgs = sio.loadmat('datasets/IMAGES.mat')['IMAGES']

    model = network.OlshausenField1996Model(num_inputs=SZ**2, num_units=NUM_UNITS,
                                            batch_size=BATCH_SIZE)

    error_list = []  # total error per iteration
    for iter_ in tqdm(range(NUM_ITER)):
        inputs = sample_batch(imgs, BATCH_SIZE, SZ)
        error, converged, dr_norm = infer_and_learn(model, inputs, NT_MAX, EPS)

        if not converged:
            print("Error at patch:", iter_)
            print(dr_norm)

        error_list.append(model.calculate_total_error(error))

        # Print the mean error over the last 100 iterations (current one included)
        if iter_ % 100 == 99:
            print("\n iter: "+str(iter_+1)+"/"+str(NUM_ITER)+", Moving error:",
                  np.mean(error_list[iter_-99:iter_+1]))

    plot_error(error_list)
    plot_receptive_fields(model.Phi, SZ)


if __name__ == "__main__":
    main()
# --------------------------------------------------------------------------------