├── .gitignore
├── LICENSE
├── README.md
├── data
│   └── IMAGES.mat
├── src
│   ├── model
│   │   ├── ImageDataset.py
│   │   ├── SparseNet.py
│   │   └── __init__.py
│   ├── scripts
│   │   ├── __init__.py
│   │   ├── plotting.py
│   │   └── train.py
│   └── utils
│       ├── __init__.py
│       └── cmd_line.py
└── trained_models
    └── RF.png

/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/

# IDE
*.idea/

# Mac
*.DS_Store
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2019 Preston Jiang

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Sparse Coding

![](./trained_models/RF.png)

This is a PyTorch implementation of Olshausen and Field's sparse coding model. The Iterative Shrinkage/Thresholding Algorithm
(ISTA) is used to infer the neuronal responses (sparse codes) for each input batch, and the gradients for the receptive fields
are computed with PyTorch's autograd.
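
For intuition, the snippet below is a minimal, self-contained sketch of the inference step that `SparseNet.ista_` performs: a gradient step on the reconstruction error followed by soft thresholding (the proximal operator of the L1 penalty). The shapes, iteration count, and constants here are illustrative assumptions, not the package's entry points.

```python
import torch
import torch.nn.functional as F

# Illustrative shapes: 250 flattened 10x10 patches and 400 basis functions.
I = torch.randn(250, 100)                        # image patches, one per row
U = F.normalize(torch.randn(100, 400), dim=0)    # unit-norm receptive fields (columns)
R = torch.zeros(250, 400)                        # responses (sparse codes)
lr, lmda = 0.1, 5e-3

for _ in range(100):
    residual = I - R @ U.T                       # reconstruction error
    R = R + lr * residual @ U                    # gradient step on the squared error
    R = F.relu(R - lmda) - F.relu(-R - lmda)     # soft thresholding (sparsity prox)
```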

## Run
To run the program:
```bash
cd src/scripts
python train.py
```
To see the list of available hyperparameters:
```bash
python train.py -h
```
A checkpoint of the model is saved to `trained_models` every 10 epochs. To view the TensorBoard logs:
```bash
tensorboard --logdir=runs
```
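
Because checkpoints are written with `torch.save` on the whole module, they can be loaded back directly for inspection. A minimal sketch (the checkpoint name is a placeholder; run it from the repository root so that `src` is importable):

```python
import torch
import matplotlib.pyplot as plt
from src.scripts.plotting import plot_rf

# Load a saved checkpoint (placeholder file name); on newer PyTorch versions
# you may need torch.load(..., weights_only=False) to unpickle the module.
model = torch.load("trained_models/ckpt-100.pth", map_location="cpu")
# U.weight has shape (M*M, K); transpose and reshape to (K, M, M) receptive fields.
rf = model.U.weight.T.reshape(model.K, model.M, model.M).detach().numpy()
fig = plot_rf(rf, model.K, model.M)
plt.show()
```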

## Coming soon
* FISTA (fast ISTA)

## References
* Olshausen, B. A., & Field, D. J. (1996). Emergence of simple-cell receptive field properties by learning a sparse code for natural images. Nature, 381(6583), 607–609. https://doi.org/10.1038/381607a0
* IMAGES.mat is downloaded from Olshausen's original MATLAB implementation website: http://www.rctn.org/bruno/sparsenet/
--------------------------------------------------------------------------------
/data/IMAGES.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lpjiang97/sparse-coding/64951ef2ecd63efdb1c4f6fc1296c79da0840329/data/IMAGES.mat
--------------------------------------------------------------------------------
/src/model/ImageDataset.py:
--------------------------------------------------------------------------------
import torch
from torch.utils.data import Dataset
import numpy as np
from scipy.io import loadmat


class NatPatchDataset(Dataset):

    def __init__(self, N:int, width:int, height:int, border:int=4, fpath:str='../../data/IMAGES.mat'):
        super(NatPatchDataset, self).__init__()
        self.N = N
        self.width = width
        self.height = height
        self.border = border
        self.fpath = fpath
        # holder
        self.images = None
        # initialize patches
        self.extract_patches_()

    def __len__(self):
        return self.images.shape[0]

    def __getitem__(self, idx):
        return self.images[idx]

    def extract_patches_(self):
        # load mat
        X = loadmat(self.fpath)
        X = X['IMAGES']
        img_size = X.shape[0]
        n_img = X.shape[2]
        self.images = torch.zeros((self.N * n_img, self.width, self.height))
        # for every image
        counter = 0
        for i in range(n_img):
            img = X[:, :, i]
            for j in range(self.N):
                x = np.random.randint(self.border, img_size - self.width - self.border)
                y = np.random.randint(self.border, img_size - self.height - self.border)
                crop = torch.tensor(img[x:x+self.width, y:y+self.height])
                self.images[counter, :, :] = crop - crop.mean()
                counter += 1
--------------------------------------------------------------------------------
/src/model/SparseNet.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F


class SparseNet(nn.Module):

    def __init__(self, K:int, M:int, R_lr:float=0.1, lmda:float=5e-3, device=None):
        super(SparseNet, self).__init__()
        self.K = K
        self.M = M
        self.R_lr = R_lr
        self.lmda = lmda
        # synaptic weights
        self.device = torch.device("cpu") if device is None else device
        self.U = nn.Linear(self.K, self.M ** 2, bias=False).to(self.device)
        # responses
        self.R = None
        self.normalize_weights()

    def ista_(self, img_batch):
        # create R
        self.R = torch.zeros((img_batch.shape[0], self.K), requires_grad=True, device=self.device)
        converged = False
        # update R
        optim = torch.optim.SGD([{'params': self.R, "lr": self.R_lr}])
        # train
        while not converged:
            old_R = self.R.clone().detach()
            # pred
            pred = self.U(self.R)
            # loss
            loss = ((img_batch - pred) ** 2).sum()
            loss.backward()
            # update R in place
            optim.step()
            # zero grad
            self.zero_grad()
            # prox
            self.R.data = SparseNet.soft_thresholding_(self.R, self.lmda)
            # convergence
            converged = torch.norm(self.R - old_R) / torch.norm(old_R) < 0.01

    @staticmethod
    def soft_thresholding_(x, alpha):
        with torch.no_grad():
            rtn = F.relu(x - alpha) - F.relu(-x - alpha)
        return rtn.data

    def zero_grad(self):
        self.R.grad.zero_()
        self.U.zero_grad()

    def normalize_weights(self):
        with torch.no_grad():
            self.U.weight.data = F.normalize(self.U.weight.data, dim=0)

    def forward(self, img_batch):
        # first fit
        self.ista_(img_batch)
        # now predict again
        pred = self.U(self.R)
        return pred
--------------------------------------------------------------------------------
/src/model/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lpjiang97/sparse-coding/64951ef2ecd63efdb1c4f6fc1296c79da0840329/src/model/__init__.py
--------------------------------------------------------------------------------
/src/scripts/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lpjiang97/sparse-coding/64951ef2ecd63efdb1c4f6fc1296c79da0840329/src/scripts/__init__.py
--------------------------------------------------------------------------------
/src/scripts/plotting.py:
--------------------------------------------------------------------------------
import numpy as np
import matplotlib.pyplot as plt


def plot_rf(rf, out_dim, M):
    rf = rf.reshape(out_dim, -1)
    # normalize
    rf = rf.T / np.abs(rf).max(axis=1)
    rf = rf.T
    rf = rf.reshape(out_dim, M, M)
    # plotting
    n = int(np.ceil(np.sqrt(rf.shape[0])))
    fig, axes = plt.subplots(nrows=n, ncols=n, sharex=True, sharey=True)
    fig.set_size_inches(10, 10)
    for i in range(rf.shape[0]):
        ax = axes[i // n][i % n]
        ax.imshow(rf[i], cmap='gray', vmin=-1, vmax=1)
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_aspect('equal')
    for j in range(rf.shape[0], n * n):
        ax = axes[j // n][j % n]
        ax.imshow(np.ones_like(rf[0]) * -1, cmap='gray', vmin=-1, vmax=1)
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_aspect('equal')
    fig.subplots_adjust(wspace=0.0, hspace=0.0)
    return fig
--------------------------------------------------------------------------------
/src/scripts/train.py:
--------------------------------------------------------------------------------
import os
import sys
sys.path.insert(0, os.path.abspath('../../.'))
from tqdm import tqdm
import torch
from src.model.SparseNet import SparseNet
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from src.model.ImageDataset import NatPatchDataset
from src.utils.cmd_line import parse_args
from src.scripts.plotting import plot_rf


# save to tensorboard
board = SummaryWriter("../../runs/sparse-net")
arg = parse_args()
# use cuda if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# create net
sparse_net = SparseNet(arg.n_neuron, arg.size, R_lr=arg.r_learning_rate, lmda=arg.reg, device=device)
# load data
dataloader = DataLoader(NatPatchDataset(arg.batch_size, arg.size, arg.size), batch_size=250)
# train
optim = torch.optim.SGD([{'params': sparse_net.U.weight, "lr": arg.learning_rate}])
for e in range(arg.epoch):
    running_loss = 0
    c = 0
    for img_batch in tqdm(dataloader, desc='training', total=len(dataloader)):
        img_batch = img_batch.reshape(img_batch.shape[0], -1).to(device)
        # update
        pred = sparse_net(img_batch)
        loss = ((img_batch - pred) ** 2).sum()
        running_loss += loss.item()
        loss.backward()
        # update U
        optim.step()
        # zero grad
        sparse_net.zero_grad()
        # norm
        sparse_net.normalize_weights()
        c += 1
        board.add_scalar('Loss', running_loss / c, e * len(dataloader) + c)
    if e % 5 == 4:
        # plotting
        fig = plot_rf(sparse_net.U.weight.T.reshape(arg.n_neuron, arg.size, arg.size).cpu().data.numpy(), arg.n_neuron, arg.size)
        board.add_figure('RF', fig, global_step=e * len(dataloader) + c)
    if e % 10 == 9:
        # save checkpoint
        torch.save(sparse_net, f"../../trained_models/ckpt-{e+1}.pth")
--------------------------------------------------------------------------------
/src/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lpjiang97/sparse-coding/64951ef2ecd63efdb1c4f6fc1296c79da0840329/src/utils/__init__.py
--------------------------------------------------------------------------------
/src/utils/cmd_line.py:
--------------------------------------------------------------------------------
import argparse


parser = argparse.ArgumentParser(description="Sparse coding")
# model
parser.add_argument('-N', '--batch_size', default=2000, type=int, help="The number of patches sampled from each image")
parser.add_argument('-K', '--n_neuron', default=400, type=int, help="The number of neurons")
parser.add_argument('-M', '--size', default=10, type=int, help="The size of the receptive field")
# training
parser.add_argument('-e', '--epoch', default=100, type=int, help="Number of epochs")
parser.add_argument('-lr', '--learning_rate', default=1e-2, type=float, help="Learning rate for the dictionary U")
parser.add_argument('-rlr', '--r_learning_rate', default=1e-2, type=float, help="Learning rate for ISTA")
parser.add_argument('-lmda', '--reg', default=5e-3, type=float, help="Sparsity regularization strength (lambda)")


# Parse arguments
def parse_args():
    return parser.parse_args()
--------------------------------------------------------------------------------
/trained_models/RF.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lpjiang97/sparse-coding/64951ef2ecd63efdb1c4f6fc1296c79da0840329/trained_models/RF.png
--------------------------------------------------------------------------------