├── .gitignore ├── LICENSE ├── PyTorchNMTF ├── __init__.py ├── models │ ├── NMF.py │ └── __init__.py └── objectives.py ├── README.MD └── examples.ipynb /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | 3 | __pycache__/ 4 | 5 | *.py[cod] 6 | 7 | *$py.class 8 | 9 | 10 | # C extensions 11 | 12 | *.so 13 | 14 | 15 | # Distribution / packaging 16 | 17 | .Python 18 | 19 | env/ 20 | 21 | build/ 22 | 23 | develop-eggs/ 24 | 25 | dist/ 26 | 27 | downloads/ 28 | 29 | eggs/ 30 | 31 | .eggs/ 32 | 33 | lib/ 34 | 35 | lib64/ 36 | 37 | parts/ 38 | 39 | sdist/ 40 | 41 | var/ 42 | 43 | *.egg-info/ 44 | 45 | .installed.cfg 46 | 47 | *.egg 48 | 49 | 50 | # PyInstaller 51 | 52 | # Usually these files are written by a python script from a template 53 | 54 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 55 | 56 | *.manifest 57 | 58 | *.spec 59 | 60 | 61 | # Installer logs 62 | 63 | pip-log.txt 64 | 65 | pip-delete-this-directory.txt 66 | 67 | 68 | # Unit test / coverage reports 69 | 70 | htmlcov/ 71 | 72 | .tox/ 73 | 74 | .coverage 75 | 76 | .coverage.* 77 | 78 | .cache 79 | 80 | nosetests.xml 81 | 82 | coverage.xml 83 | 84 | *,cover 85 | 86 | .hypothesis/ 87 | 88 | 89 | # Translations 90 | 91 | *.mo 92 | 93 | *.pot 94 | 95 | 96 | # Django stuff: 97 | 98 | *.log 99 | 100 | local_settings.py 101 | 102 | 103 | # Flask stuff: 104 | 105 | instance/ 106 | 107 | .webassets-cache 108 | 109 | 110 | # Scrapy stuff: 111 | 112 | .scrapy 113 | 114 | 115 | # Sphinx documentation 116 | 117 | docs/_build/ 118 | 119 | 120 | # PyBuilder 121 | 122 | target/ 123 | 124 | 125 | # IPython Notebook 126 | 127 | .ipynb_checkpoints 128 | 129 | 130 | # pyenv 131 | 132 | .python-version 133 | 134 | 135 | # celery beat schedule file 136 | 137 | celerybeat-schedule 138 | 139 | 140 | # dotenv 141 | 142 | .env 143 | 144 | 145 | # virtualenv 146 | 147 | venv/ 148 | 149 | ENV/ 150 | 151 | 152 | # Spyder project settings 153 | 154 | .spyderproject 155 | 156 | 157 | # Rope project settings 158 | 159 | .ropeproject 160 | 161 | # dev 162 | dev 163 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Konstantin Sozykin 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /PyTorchNMTF/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | from . import objectives 3 | from . import models -------------------------------------------------------------------------------- /PyTorchNMTF/models/NMF.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from ipypb import ipb 4 | from collections import defaultdict 5 | from .. import objectives 6 | 7 | class NMF(nn.Module): 8 | 9 | def __init__(self,X, 10 | k = 10, solver = 'mu', n_iter = 10, eps = 1e-7, 11 | alpha = 0.99, 12 | loss = 'l2', 13 | lr = 1e-2, verbose = False): 14 | super(NMF, self).__init__() 15 | self.n_iter = n_iter 16 | self.k = k 17 | self.loss = loss 18 | self.lr = lr 19 | self.alpha = alpha 20 | self.verbose = verbose 21 | self.solver = solver 22 | self.eps = eps 23 | self.decomposed = False 24 | self.is_cuda = False 25 | self.report = defaultdict(list) 26 | self.__initfact__(X) 27 | 28 | def __initfact__(self,X): 29 | self.n,self.m = X.shape 30 | self.X = torch.from_numpy(X) 31 | self.scale = torch.sqrt(torch.mean(self.X) / self.k) 32 | W = torch.abs(torch.rand([self.n,self.k])*self.scale) 33 | H = torch.abs(torch.rand([self.k,self.m])*self.scale) 34 | self.W = torch.nn.Parameter(W) 35 | self.H = torch.nn.Parameter(H) 36 | # for autograd solver 37 | self.opt = torch.optim.RMSprop([self.W,self.H], alpha=self.alpha, lr=self.lr,weight_decay=1e-6) 38 | if self.loss == 'l2': 39 | self.loss_fn = objectives.l2 40 | elif self.loss == 'kl': 41 | self.loss_fn = objectives.kl_dev 42 | 43 | def to(self,device): 44 | self.is_cuda = (device == 'cuda') 45 | if self.is_cuda: 46 | self.X = self.X.to('cuda') 47 | return super(NMF, self).to(device) 48 | 49 | def plus(self,X): 50 | X[X < 0] = self.eps 51 | return X 52 | 53 | def __mu__(self,epoch): 54 | """ 55 | multiplicative update, explisit form. 56 | """ 57 | W,H = self.W,self.H 58 | WT = torch.transpose(W,0,1) 59 | HT = torch.transpose(H,0,1) 60 | XHT = self.X @ HT 61 | WHHT = W @ H @ HT 62 | W = W * (XHT)/(WHHT+self.eps) 63 | WTX = WT @ self.X 64 | WTWH = WT @ W @ H 65 | H = H * (WTX)/(WTWH+self.eps) 66 | self.W = torch.nn.Parameter(W) 67 | self.H = torch.nn.Parameter(H) 68 | l = self.loss_fn(self.X,self.W @ self.H) 69 | return l.item() 70 | 71 | def __autograd__(self,epoch): 72 | """ 73 | autograd update, with gradient projection 74 | """ 75 | self.opt.zero_grad() 76 | l = self.loss_fn(self.X,self.W @ self.H) 77 | l.backward() 78 | self.opt.step() 79 | ## grad projection 80 | for p in self.parameters(): 81 | p.data = self.plus(p.data) 82 | return l.item() 83 | 84 | 85 | def __update__(self,epoch): 86 | if self.solver == 'mu': 87 | l = self.__mu__(epoch) 88 | elif self.solver == 'autograd': 89 | l = self.__autograd__(epoch) 90 | else: 91 | raise NotImplementedError 92 | self.report['epoch'].append(epoch) 93 | self.report['loss'].append(l) 94 | if self.verbose and epoch % 500 == 0: 95 | print("%d\tloss: %.4f"%(epoch,l)) 96 | assert self.is_nonneg() 97 | 98 | def fit(self): 99 | it = range(self.n_iter) 100 | if self.verbose: 101 | it = ipb(it) 102 | for e in it: 103 | self.__update__(e) 104 | self.decomposed = True 105 | return self 106 | 107 | def show_report(self): 108 | return pd.DataFrame(self.report) 109 | 110 | def is_nonneg(self): 111 | return bool(torch.all(self.W >= 0) and torch.all(self.H >= 0)) 112 | 113 | def forward(self, H): 114 | return self.W @ H 115 | 116 | def fit_transform(self): 117 | if not self.decomposed: 118 | self.fit(X) 119 | return [self.W, self.H] 120 | 121 | def mse(self): 122 | err = ( (self.X - self.W @ self.H) ** 2 ).mean().detach().cpu().numpy() 123 | return float(err) 124 | 125 | def inverse_trainsform(self): 126 | return (self.W @ self.H).detach() -------------------------------------------------------------------------------- /PyTorchNMTF/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .NMF import NMF -------------------------------------------------------------------------------- /PyTorchNMTF/objectives.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn 3 | 4 | def l2(x,y): 5 | return torch.nn.MSELoss()(x,y) 6 | 7 | def kl_dev(x,y): 8 | return (x * torch.log(x/y) - x + y).mean() -------------------------------------------------------------------------------- /README.MD: -------------------------------------------------------------------------------- 1 | ### TorchNMTF 2 | 3 | A collection of Nonnegative Matrix and Tensor Factorizations models with *further* applications to Multi-modal Data Analysis and Blind Source Separation and Sensor Fusion Problems, implemented with Pytorch. 4 | 5 | ### Requirements 6 | 7 | - PyTorch == 1.0.0 and it's dependences 8 | - skimage 9 | - Pandas 10 | - [ipypb](https://github.com/Evfro/ipypb) 11 | 12 | ### Usage 13 | 14 | ``` 15 | see examples.ipynb 16 | ``` 17 | 18 | Here is a result of the inverse transform of a factorized image 19 | ![Chelsea the cat.](https://pp.userapi.com/c856132/v856132226/5a0f/fo90ZyRSXeY.jpg) 20 | 21 | 22 | ### Refereces 23 | 24 | 1. Andrzej Cichocki Rafal Zdunek Anh Huy Phan Shun‐Ichi Amari, Nonnegative Matrix and Tensor Factorizations 25 | 26 | 2. https://perso.telecom-paristech.fr/essid/teach/NMF_tutorial_ICME-2014.pdf 27 | 28 | 3. http://pmelchior.net/blog/proximal-matrix-factorization-in-pytorch.html 29 | 30 | ## TODO 31 | 32 | - [x] Basic NMF model with 2 solvers 33 | - [ ] ALS, HALS solvers, and advanced objectives 34 | - [ ] More Examples 35 | - [ ] Nonegative Tensor Factorisation 36 | - [ ] Tests 37 | - [ ] Setup.py 38 | - [ ] Nonrandom initialization. -------------------------------------------------------------------------------- /examples.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import os\n", 10 | "os.environ['CUDA_VISIBLE_DEVICES']='6'\n", 11 | "import numpy as np\n", 12 | "import pandas as pd\n", 13 | "import skimage.io\n", 14 | "import skimage.data\n", 15 | "from skimage.color import rgb2gray\n", 16 | "from PyTorchNMTF.models import NMF" 17 | ] 18 | }, 19 | { 20 | "cell_type": "code", 21 | "execution_count": null, 22 | "metadata": {}, 23 | "outputs": [], 24 | "source": [ 25 | "%pylab inline" 26 | ] 27 | }, 28 | { 29 | "cell_type": "markdown", 30 | "metadata": {}, 31 | "source": [ 32 | "## Data" 33 | ] 34 | }, 35 | { 36 | "cell_type": "code", 37 | "execution_count": null, 38 | "metadata": {}, 39 | "outputs": [], 40 | "source": [ 41 | "im = rgb2gray(skimage.data.chelsea())*255.\n", 42 | "im = im.astype('float32')" 43 | ] 44 | }, 45 | { 46 | "cell_type": "markdown", 47 | "metadata": {}, 48 | "source": [ 49 | "## Test Image Reconstruction" 50 | ] 51 | }, 52 | { 53 | "cell_type": "code", 54 | "execution_count": null, 55 | "metadata": {}, 56 | "outputs": [], 57 | "source": [ 58 | "device = 'cuda'" 59 | ] 60 | }, 61 | { 62 | "cell_type": "code", 63 | "execution_count": null, 64 | "metadata": {}, 65 | "outputs": [], 66 | "source": [ 67 | "n_iter = 2500\n", 68 | "k = 256" 69 | ] 70 | }, 71 | { 72 | "cell_type": "code", 73 | "execution_count": null, 74 | "metadata": {}, 75 | "outputs": [], 76 | "source": [ 77 | "nmf_mu = NMF(im,n_iter=n_iter,k=k,solver='mu',verbose=True,loss='kl')\n", 78 | "nmf_mu = nmf_mu.to(device)\n", 79 | "nmf_mu.fit()\n", 80 | "im_recon_mu = nmf_mu.inverse_trainsform().cpu().numpy()" 81 | ] 82 | }, 83 | { 84 | "cell_type": "code", 85 | "execution_count": null, 86 | "metadata": {}, 87 | "outputs": [], 88 | "source": [ 89 | "nmf_autograd = NMF(im,n_iter=n_iter,k=k,solver='autograd',verbose=True,loss='kl')\n", 90 | "nmf_autograd = nmf_autograd.to(device)\n", 91 | "nmf_autograd.fit()\n", 92 | "im_recon_autograd = nmf_autograd.inverse_trainsform().cpu().numpy()" 93 | ] 94 | }, 95 | { 96 | "cell_type": "code", 97 | "execution_count": null, 98 | "metadata": {}, 99 | "outputs": [], 100 | "source": [ 101 | "figure(figsize=(12,12),facecolor='white')\n", 102 | "plt.subplot(221)\n", 103 | "imshow(im,cmap='gray')\n", 104 | "title('original image')\n", 105 | "plt.subplot(222)\n", 106 | "imshow(im_recon_mu,cmap='gray')\n", 107 | "title('mu solver, mse %.4f' % nmf_mu.mse() )\n", 108 | "plt.subplot(223)\n", 109 | "opt_name = nmf_autograd.opt.__class__.__name__\n", 110 | "title('autograd (%s) solver, mse %.4f' % (opt_name,nmf_autograd.mse()))\n", 111 | "imshow(im_recon_autograd,cmap='gray')" 112 | ] 113 | }, 114 | { 115 | "cell_type": "code", 116 | "execution_count": null, 117 | "metadata": {}, 118 | "outputs": [], 119 | "source": [ 120 | "assert nmf_autograd.is_nonneg() and nmf_mu.is_nonneg()" 121 | ] 122 | } 123 | ], 124 | "metadata": { 125 | "kernelspec": { 126 | "display_name": "Python [conda env:bob_py3]", 127 | "language": "python", 128 | "name": "conda-env-bob_py3-py" 129 | }, 130 | "language_info": { 131 | "codemirror_mode": { 132 | "name": "ipython", 133 | "version": 3 134 | }, 135 | "file_extension": ".py", 136 | "mimetype": "text/x-python", 137 | "name": "python", 138 | "nbconvert_exporter": "python", 139 | "pygments_lexer": "ipython3", 140 | "version": "3.6.8" 141 | } 142 | }, 143 | "nbformat": 4, 144 | "nbformat_minor": 2 145 | } 146 | --------------------------------------------------------------------------------