├── .gitignore ├── ARD_NMF.py ├── LICENSE ├── NMF_functions.py ├── README.md ├── SignatureAnalyzer-GPU.py ├── SupplementalNote.pdf ├── __init__.py ├── example_data ├── POLEMSI_counts_matrix.txt └── POLEMSI_params.txt └── requirements-py3.txt /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | 53 | # Translations 54 | *.mo 55 | *.pot 56 | 57 | # Django stuff: 58 | *.log 59 | local_settings.py 60 | db.sqlite3 61 | db.sqlite3-journal 62 | 63 | # Flask stuff: 64 | instance/ 65 | .webassets-cache 66 | 67 | # Scrapy stuff: 68 | .scrapy 69 | 70 | # Sphinx documentation 71 | docs/_build/ 72 | 73 | # PyBuilder 74 | target/ 75 | 76 | # Jupyter Notebook 77 | .ipynb_checkpoints 78 | 79 | # IPython 80 | profile_default/ 81 | ipython_config.py 82 | 83 | # pyenv 84 | .python-version 85 | 86 | # pipenv 87 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
# [repository-dump residue preserved verbatim: tail of /.gitignore and the /ARD_NMF.py file header]
# 88 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 89 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 90 | # install all needed dependencies. 91 | #Pipfile.lock 92 | 93 | # celery beat schedule file 94 | celerybeat-schedule 95 | 96 | # SageMath parsed files 97 | *.sage.py 98 | 99 | # Environments 100 | .env 101 | .venv 102 | env/ 103 | venv/ 104 | ENV/ 105 | env.bak/ 106 | venv.bak/ 107 | 108 | # Spyder project settings 109 | .spyderproject 110 | .spyproject 111 | 112 | # Rope project settings 113 | .ropeproject 114 | 115 | # mkdocs documentation 116 | /site 117 | 118 | # mypy 119 | .mypy_cache/ 120 | .dmypy.json 121 | dmypy.json 122 | 123 | # Pyre type checker 124 | .pyre/ 125 | -------------------------------------------------------------------------------- /ARD_NMF.py: --------------------------------------------------------------------------------
import pandas as pd
import numpy as np
import sys
from sys import stdout
import argparse
import time
from scipy.special import gamma
import os
import pickle
import math
import torch
from typing import Optional, Tuple, Union
import multiprocessing.connection as mpc

try:
    from .NMF_functions import *
except ImportError:
    # Loaded as a top-level module (e.g. ``from ARD_NMF import ARD_NMF``),
    # where a relative import has no parent package to resolve against.
    from NMF_functions import *


def _is_missing(value):
    """Return True when *value* means "not provided": None, the string 'None', or NaN."""
    if value is None:
        return True
    if isinstance(value, str):
        return value == 'None'
    try:
        return math.isnan(value)
    except TypeError:
        return False


def _to_numpy(value):
    """Convert a (possibly GPU) tensor or a plain number to a numpy array."""
    if torch.is_tensor(value):
        return value.cpu().numpy()
    return np.asarray(value)


class ARD_NMF:
    """
    NMF results class implements both half normal and exponential prior ARD NMF
    implementation based on https://arxiv.org/pdf/1111.6085.pdf
    """
    def __init__(self, dataset, objective, dtype=torch.float32, verbose=True):
        """
        Store the data matrix and derive the quantities shared by every run.

        Args:
            * dataset: pandas DataFrame; rows whose sum is not > 0 are dropped
            * objective: objective name; 'poisson' (case-insensitive) is treated
                specially when phi is set in initalize_data
            * dtype: torch dtype used for the tensors created by this class
            * verbose: print progress messages
        """
        self.eps_ = torch.tensor(1.e-30, dtype=dtype, requires_grad=False)
        self.dataset = dataset
        keep_rows = np.sum(self.dataset, axis=1) > 0
        self.V0 = self.dataset.values[keep_rows, :]
        # shift so the smallest entry is a tiny positive number
        self.V = self.V0 - np.min(self.V0) + 1.e-30
        self.V_max = np.max(self.V)
        self.M = self.V.shape[0]
        self.N = self.V.shape[1]
        self.objective = objective
        self.channel_names = self.dataset.index[keep_rows]
        self.sample_names = self.dataset.columns
        self.dtype = dtype
        self.verbose = verbose
        if self.verbose: print('NMF class initalized.')

    def _prior_constant(self):
        """Return C for the current (prior_H, prior_W) pair; raise ValueError if unsupported."""
        if self.prior_H == 'L1' and self.prior_W == 'L1':
            return self.N + self.M + self.a + 1
        if self.prior_H == 'L2' and self.prior_W == 'L2':
            return (self.N + self.M)*0.5 + self.a + 1
        if self.prior_H == 'L1' and self.prior_W == 'L2':
            return self.N + self.M/2 + self.a + 1
        if self.prior_H == 'L2' and self.prior_W == 'L1':
            return self.N/2 + self.M + self.a + 1
        raise ValueError(
            'priors must each be "L1" or "L2"; got prior_W={!r}, prior_H={!r}'.format(
                self.prior_W, self.prior_H))

    def _default_b(self):
        """Default b as described in Tan and Fevotte (2012) for the current priors."""
        mean_V = np.mean(self.V)
        if self.prior_H == 'L1' and self.prior_W == 'L1':
            return np.sqrt(np.true_divide((self.a - 1)*(self.a - 2) * mean_V, self.K0))
        if self.prior_H == 'L2' and self.prior_W == 'L2':
            return np.true_divide(np.pi * (self.a - 1) * mean_V, 2*self.K0)
        # mixed L1/L2 priors share one expression
        return np.true_divide(mean_V*np.sqrt(2)*gamma(self.a-3/2),
                              self.K0*np.sqrt(np.pi)*gamma(self.a))

    def initalize_data(self,a,phi,b,prior_W,prior_H,Beta,K0,use_val_set,dtype = torch.float32):
        """
        Initializes the factor matrices, hyper-parameters and mask for one run.

        Args:
            * a: hyper-parameter used in C and in the default b
            * phi: dispersion parameter - multiplied by variance of V unless the
                objective is 'poisson' (see Tan & Fevotte 2013)
            * b: hyper-parameter; when None, 'None' or NaN the default from
                Tan and Fevotte (2012) is computed
            * prior_W: "L1" or "L2"
            * prior_H: "L1" or "L2"
            * Beta: accepted for interface compatibility; not read here
            * K0: starting number of components; set to number of input
                features (rows of V) when None, 'None' or NaN
            * use_val_set: when truthy, draw a fixed-seed 0/1 mask that holds
                out ~20% of the entries; otherwise the mask is all ones
            * dtype: dtype of the phi tensor only

        Raises:
            ValueError: if the prior pair is not a combination of "L1"/"L2".
        """
        # when called in a loop (parameter sweep) V was turned into a torch tensor by
        # the previous call, which breaks some numpy functions
        self.V = np.array(self.V)

        # int() so float-typed values read from a parameter table are usable as a shape
        self.K0 = self.M if _is_missing(K0) else int(K0)
        self.number_of_active_components = self.K0

        if self.objective.lower() == 'poisson':
            self.phi = torch.tensor(phi,dtype=dtype,requires_grad=False)
        else:
            self.phi = torch.tensor(np.var(self.V)* phi,dtype=dtype,requires_grad=False)

        if use_val_set:
            torch.manual_seed(0) # get the same mask each time
            # randomly mask ~20% of the entries of V
            self.mask = (torch.rand(self.V.shape) > 0.2).type(self.dtype)
        else:
            self.mask = torch.ones(self.V.shape, dtype=self.dtype)

        self.a = a
        self.prior_W = prior_W
        self.prior_H = prior_H
        self.b = b
        # validates the priors before any further work is done
        self.C = torch.tensor(self._prior_constant(), dtype=self.dtype, requires_grad=False)

        scale = np.sqrt(self.V_max)
        W0 = np.multiply(np.random.uniform(size=[self.M, self.K0])+self.eps_.numpy(), scale)
        H0 = np.multiply(np.random.uniform(size=[self.K0, self.N])+self.eps_.numpy(), scale)
        L0 = np.sum(W0,axis=0) + np.sum(H0,axis=1)

        self.W = torch.tensor(W0, dtype=self.dtype, requires_grad=False)
        self.H = torch.tensor(H0, dtype=self.dtype, requires_grad=False)
        # NOTE(review): Lambda is always float32, independent of self.dtype - confirm intended
        self.Lambda = torch.tensor(L0, dtype=torch.float32, requires_grad=False)

        # bcpu keeps the plain (non-tensor) value for logging
        self.bcpu = self._default_b() if _is_missing(self.b) else self.b
        self.b = torch.tensor(self.bcpu, dtype=self.dtype, requires_grad=False)

        self.V = torch.tensor(self.V,dtype=self.dtype,requires_grad=False)
        if self.verbose: print('NMF data and parameters set.')

    def get_number_of_active_components(self):
        """Count the columns of W whose sum is positive and store the count on the instance."""
        self.number_of_active_components = torch.sum(torch.sum(self.W,0)> 0.0, dtype=self.dtype)


def print_report(iter,report,verbose,tag):
    """
    Prints the report entry stored under key *iter*.

    When verbose, a full line is printed; otherwise a short progress line
    prefixed with *tag* is written to stdout with a carriage return.
    """
    entry = report[iter]
    if verbose:
        print("nit={:>5} K={:>5} | obj={:.2f}\tb_div={:.2f}\tlam={:.2f}\tdel={:.8f}\tsumW={:.2f}\tsumH={:.2f}".format(
            iter,
            entry['K'],
            entry['obj'],
            entry['b_div'],
            entry['lam'],
            entry['del'],
            entry['W_sum'],
            entry['H_sum']
            )
        )
    else:
        stdout.write("\r{}nit={:>5} K={} \tdel={:.8f}".format(
            tag,
            iter,
            entry['K'],
            entry['del']
            )
        )


def _report_entry(W, H, Lambda, cost_, l_, deltrack, active_thresh):
    """Build one report row (numpy values) from the current state."""
    return {
        'K': torch.sum((torch.sum(H,1) * torch.sum(W,0))>active_thresh).cpu().numpy(),
        'obj': cost_.cpu().numpy(),
        'b_div': l_.cpu().numpy(),
        'lam': torch.sum(Lambda).cpu().numpy(),
        'del': _to_numpy(deltrack),
        'W_sum': torch.sum(W).cpu().numpy(),
        'H_sum': torch.sum(H).cpu().numpy()
    }


def run_method_engine(
    results: ARD_NMF,
    a: float,
    phi: float,
    b: float,
    Beta: int,
    W_prior: str,
    H_prior: str,
    K0: int,
    tolerance: float,
    max_iter: int,
    use_val_set: bool,
    report_freq: int = 10,
    active_thresh: float = 1e-5,
    send_end: Union[mpc.Connection, None] = None,
    cuda_int: Union[int, None] = 0,
    verbose: bool = True,
    tag: str = ""
    ) -> Optional[Tuple[np.ndarray, np.ndarray, np.ndarray, pd.DataFrame, np.ndarray, np.ndarray]]:
    """
    Run ARD-NMF Engine.
    ------------------------------------------------------------------------
    Args:
        * results: initialized ARD_NMF class
        * a: shape parameter
        * phi: dispersion parameter
        * b: shape parameter
        * Beta: defined by objective function
        * W_prior: prior on W matrix ("L1" or "L2")
        * H_prior: prior on H matrix ("L1" or "L2")
        * K0: starting number of latent components
        * tolerance: end-point of optimization
        * max_iter: maximum number of iterations for algorithm
        * use_val_set: use validation set for ARD-NMF
            If False, the mask is all ones.
            Otherwise, a 0/1 mask holds out entries as a validation set during
            training and the objective is also reported for that set.
        * report_freq: how often to print updates
        * active_thresh: a component counts as active when
            sum(H row) * sum(W column) exceeds this threshold
        * send_end: mpc.Connection resulting from multiprocessing.Pipe,
            for use in parameter sweep implementation
        * cuda_int: GPU to use. Defaults to 0. If None or if no GPU available,
            will perform decomposition using CPU.
        * verbose: verbose logging
        * tag: prefix of the progress line printed when not verbose

    Returns:
        When send_end is None, the tuple
        (W, H, cost, final_report, Lambda, mask) where final_report is a
        DataFrame indexed by iteration and the rest are numpy arrays.
        When send_end is given, nothing is returned; instead the list
        [W, H, mask, cost, b_div, b_div_val, obj_val, elapsed_seconds]
        is sent through the connection (the two validation entries are None
        when use_val_set is False).
    """
    # initalize the NMF run
    results.initalize_data(a, phi, b, W_prior, H_prior, Beta, K0, use_val_set)
    # specify GPU
    cuda_string = 'cuda:'+str(cuda_int)
    state = (results.W, results.H, results.V, results.Lambda, results.C,
             results.b, results.eps_, results.phi, results.mask)
    # copy data to GPU
    if torch.cuda.device_count() > 0 and cuda_int is not None:
        if verbose: print(" * Using GPU: {}".format(cuda_string))
        state = tuple(t.cuda(cuda_string) for t in state)
    else:
        if verbose: print(" * Using CPU")
    W,H,V,Lambda,C,b0,eps_,phi,mask = state

    # tracking variables
    deltrack = 1000
    report = dict()
    n_iter = 0
    lam_previous = Lambda
    if verbose: print('%%%%%%%%%%%%%%%')
    if verbose: print('a =',results.a)
    if verbose: print('b =',results.bcpu)
    if verbose: print('%%%%%%%%%%%%%%%')

    # set method
    method = NMF_algorithim(Beta, H_prior, W_prior)

    start_time = time.time()
    while deltrack >= tolerance and n_iter < max_iter:
        # compute updates
        H,W,Lambda = method.forward(W,H,V,Lambda,C,b0,eps_,phi,mask)

        # update tracking: largest relative change of any Lambda entry
        deltrack = torch.max(torch.div(torch.abs(Lambda-lam_previous), lam_previous+1e-30))
        lam_previous = Lambda

        # ---------------------------- Reporting ---------------------------- #
        if n_iter % report_freq == 0:
            # objective and cost exclude the validation set when a mask is in use;
            # they are only needed for reporting, so they are not evaluated every step
            l_ = beta_div(Beta,V,W,H,eps_,mask)
            cost_ = calculate_objective_function(Beta,V,W,H,Lambda,C,eps_,phi,results.K0,mask)
            report[n_iter] = _report_entry(W, H, Lambda, cost_, l_, deltrack, active_thresh)
            print_report(n_iter,report,verbose,tag)
        # ------------------------------------------------------------------- #
        n_iter+=1

    # --------------------------- Final Report --------------------------- #
    # evaluated on the final state; also defined when the loop never ran
    l_ = beta_div(Beta,V,W,H,eps_,mask)
    cost_ = calculate_objective_function(Beta,V,W,H,Lambda,C,eps_,phi,results.K0,mask)
    report[n_iter] = _report_entry(W, H, Lambda, cost_, l_, deltrack, active_thresh)

    end_time = time.time()

    # compute validation set performance
    if use_val_set:
        heldout_mask = 1-mask # select heldout values (inverse of mask)
        report[n_iter]['b_div_val'] = beta_div(Beta,V,W,H,eps_,heldout_mask).cpu().numpy()
        report[n_iter]['obj_val'] = calculate_objective_function(
            Beta,V,W,H,Lambda,C,eps_,phi,results.K0,heldout_mask).cpu().numpy()
    else:
        report[n_iter]['b_div_val'] = None
        report[n_iter]['obj_val'] = None

    print_report(n_iter,report,verbose,tag)

    if not verbose:
        stdout.write("\n")
    # ------------------------------------------------------------------- #

    if send_end is not None:
        # validation entries are already numpy arrays, or None without a validation set
        send_end.send([
            W.cpu().numpy(),
            H.cpu().numpy(),
            mask.cpu().numpy(),
            cost_.cpu().numpy(),
            l_.cpu().numpy(),
            report[n_iter]['b_div_val'],
            report[n_iter]['obj_val'],
            end_time-start_time,
        ])
        return None

    final_report = pd.DataFrame.from_dict(report).T
    final_report.index.name = 'iter'
    return W.cpu().numpy(), H.cpu().numpy(), cost_.cpu().numpy(), final_report, Lambda.cpu().numpy(), mask.cpu().numpy()
# [repository-dump residue preserved verbatim: head of /LICENSE]
# -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2018, Broad Institute 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | * Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | * Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | * Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED.
# [repository-dump residue preserved verbatim: tail of /LICENSE and the /NMF_functions.py file header]
# IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 30 | -------------------------------------------------------------------------------- /NMF_functions.py: --------------------------------------------------------------------------------
import torch
import torch.nn as nn

SEloss = nn.MSELoss(reduction = 'sum')


class NMF_algorithim(nn.Module):
    ''' implements ARD NMF from https://arxiv.org/pdf/1111.6085.pdf '''
    def __init__(self,Beta,H_prior,W_prior):
        """
        Select the multiplicative update rules for the given configuration.

        Beta parameterizes the objective function:
            Beta = 1 induces a poisson objective
            Beta = 2 induces a gaussian objective
        Priors on the component matrices are Exponential ('L1') and
        half-normal ('L2').

        Raises:
            ValueError: if Beta is not 1 or 2, or a prior is not 'L1'/'L2'.
        """
        super(NMF_algorithim, self).__init__()

        # the W rule depends on (Beta, W_prior), the H rule on (Beta, H_prior)
        w_rules = {
            (1, 'L1'): update_W_poisson_L1,
            (1, 'L2'): update_W_poisson_L2,
            (2, 'L1'): update_W_gaussian_L1,
            (2, 'L2'): update_W_gaussian_L2,
        }
        h_rules = {
            (1, 'L1'): update_H_poisson_L1,
            (1, 'L2'): update_H_poisson_L2,
            (2, 'L1'): update_H_gaussian_L1,
            (2, 'L2'): update_H_gaussian_L2,
        }
        # the lambda rule depends only on the prior pair, keyed (W_prior, H_prior)
        lambda_rules = {
            ('L1', 'L1'): update_lambda_L1,
            ('L2', 'L1'): update_lambda_L2_L1,
            ('L1', 'L2'): update_lambda_L1_L2,
            ('L2', 'L2'): update_lambda_L2,
        }
        try:
            self.update_W = w_rules[(Beta, W_prior)]
            self.update_H = h_rules[(Beta, H_prior)]
            self.lambda_update = lambda_rules[(W_prior, H_prior)]
        except (KeyError, TypeError):
            # fail here instead of with an AttributeError on the first forward()
            raise ValueError(
                'unsupported configuration: Beta={!r} (expected 1 or 2), '
                'H_prior={!r}, W_prior={!r} (expected "L1" or "L2")'.format(
                    Beta, H_prior, W_prior)) from None

    def forward(self,W, H, V, lambda_, C, b0, eps_, phi, mask):
        """Run one iteration: update H, then W using the new H, then lambda.

        Returns the tuple (H, W, lambda) - note that H comes first.
        """
        h_ = self.update_H(H, W, lambda_, phi, V, eps_, mask)
        w_ = self.update_W(h_, W, lambda_, phi, V, eps_, mask)
        lam_ = self.lambda_update(w_,h_,b0,C,eps_)
        return h_, w_,lam_


def beta_div(Beta,V,W,H,eps_,mask):
    """
    Masked divergence between V and its approximation W @ H + eps_.

    Beta = 2: half the summed squared error over the masked entries.
    Beta = 1: sum of V*log(V/V_ap) - V + V_ap over the masked entries.

    Raises:
        ValueError: for any other Beta.
    """
    V_ap = torch.matmul(W, H).type(V.dtype) + eps_.type(V.dtype)
    if Beta == 2:
        return SEloss(V*mask,V_ap*mask)/2
    if Beta == 1:
        lr = torch.log(torch.div(V, V_ap))
        return torch.sum((V*mask*lr) - V*mask + V_ap*mask)
    raise ValueError('Beta must be 1 (poisson) or 2 (gaussian); got {!r}'.format(Beta))


def calculate_objective_function(Beta,V,W,H,lambda_,C, eps_,phi,K,mask):
    """
    Objective: divergence / phi + C * sum(log(lambda * C)) + K*C*(1 - log(C)).

    If a validation set is being used, this will mask heldout set when evaluating objective function during training.
    Can calculate validation set objective value by passing inverse mask.
    """
    loss = beta_div(Beta,V,W,H,eps_,mask)
    cst = (K*C)*(1.0-torch.log(C))
    return torch.pow(phi,-1)*loss + (C*torch.sum(torch.log(lambda_ * C))) + cst


def update_H_poisson_L1(H, W, lambda_, phi, V, eps_,mask):
    """Multiplicative H update for Beta = 1 with the L1 penalty term phi/lambda."""
    #beta = 1 gamma(beta) = 1
    denom = torch.matmul(W.transpose(1,0), mask) + torch.div(phi, lambda_).reshape(-1,1) + eps_
    V_ap = torch.matmul(W, H) + eps_
    V_res = torch.div(V*mask, V_ap)
    update = torch.div(torch.matmul(W.transpose(1,0), V_res), denom)
    return H * update


def update_H_poisson_L2(H,W,lambda_,phi,V, eps_,mask):
    """Multiplicative H update for Beta = 1 with the L2 penalty term phi*H/lambda (exponent 1/2)."""
    #beta = 1 zeta(beta) = 1/2
    denom = torch.matmul(W.transpose(1,0), mask) + torch.div(phi*H, lambda_.reshape(-1,1)) + eps_
    V_ap = torch.matmul(W, H) + eps_
    update = torch.pow(torch.div(torch.matmul(W.transpose(0,1), torch.div(V*mask, V_ap)), denom),0.5)
    return H * update


def update_H_gaussian_L1(H,W,lambda_,phi,V,eps_,mask):
    """Multiplicative H update for Beta = 2 with the L1 penalty term phi/lambda."""
    #beta = 2 gamma(beta) = 1
    V_ap = torch.matmul(W, H) + eps_
    denom = torch.matmul(W.transpose(0,1),V_ap*mask) + torch.div(phi, lambda_ ).reshape(-1,1) + eps_
    update = torch.div(torch.matmul(W.transpose(0,1),V*mask),denom)
    return H * update


def update_H_gaussian_L2(H,W,lambda_,phi,V,eps_,mask):
    """Multiplicative H update for Beta = 2 with the L2 penalty term phi*H/lambda."""
    #beta = 2 zeta(beta) = 1
    V_ap = torch.matmul(W, H).type(V.dtype) + eps_
    denom = torch.matmul(W.transpose(0,1).type(V.dtype),V_ap*mask) + torch.div(phi * H, lambda_.reshape(-1,1)).type(V.dtype) + eps_
    update = torch.div(torch.matmul(W.transpose(0,1).type(V.dtype),V*mask),denom)
    # cast back to H's own dtype rather than a hard-coded float32
    return H * update.type(H.dtype)


def update_W_poisson_L1(H, W, lambda_, phi, V, eps_,mask):
    """Multiplicative W update for Beta = 1 with the L1 penalty term phi/lambda."""
    #beta = 1 gamma(beta) = 1
    denom = torch.matmul(mask, H.transpose(1,0)) + torch.div(phi, lambda_ ) + eps_
    V_ap = torch.matmul(W, H) + eps_
    V_res = torch.div(V*mask, V_ap)
    update = torch.div(torch.matmul(V_res, H.transpose(0,1)), denom)
    return W * update


def update_W_poisson_L2(H,W,lambda_,phi,V,eps_,mask):
    """Multiplicative W update for Beta = 1 with the L2 penalty term phi*W/lambda (exponent 1/2)."""
    # beta = 1 zeta(beta) = 1/2
    V_ap = torch.matmul(W,H) + eps_
    V_res = torch.div(V*mask, V_ap)
    denom = torch.matmul(mask, H.transpose(1,0)) + torch.div(phi*W,lambda_) + eps_
    update = torch.pow(torch.div(torch.matmul(V_res,H.transpose(0,1)),denom),0.5)
    return W * update


def update_W_gaussian_L1(H,W,lambda_,phi,V,eps_,mask):
    """Multiplicative W update for Beta = 2 with the L1 penalty term phi/lambda."""
    #beta = 2 gamma(beta) = 1
    V_ap = torch.matmul(W,H).type(V.dtype) + eps_
    denom = torch.matmul(V_ap*mask,H.transpose(0,1).type(V.dtype)) + torch.div(phi,lambda_).type(V.dtype) + eps_
    update = torch.div(torch.matmul(V*mask,H.transpose(0,1).type(V.dtype)),denom)
    # cast back to W's own dtype rather than a hard-coded float32
    return W * update.type(W.dtype)


def update_W_gaussian_L2(H,W,lambda_,phi,V,eps_,mask):
    """Multiplicative W update for Beta = 2 with the L2 penalty term phi*W/lambda."""
    #beta = 2 zeta(beta) = 1
    V_ap = torch.matmul(W,H) + eps_
    denom = torch.matmul(V_ap*mask,H.transpose(0,1)) + torch.div(phi*W,lambda_) + eps_
    update = torch.div(torch.matmul(V*mask,H.transpose(0,1)),denom)
    return W * update


# update tolerance value for early stop criteria
def update_del(lambda_, lambda_last):
    """Return the largest relative change max(|lambda - lambda_last| / lambda_last)."""
    # the division must take both operands; torch.max then reduces the ratio
    del_ = torch.max(torch.div(torch.abs(lambda_ - lambda_last), lambda_last))
    return del_


def update_lambda_L1(W,H,b0,C,eps_):
    """Lambda update with L1 terms for both W and H: (sum W + sum H + b0) / C."""
    return torch.div(torch.sum(W,0) + torch.sum(H,1) + b0, C)


def update_lambda_L2(W,H,b0,C,eps_):
    """Lambda update with L2 terms for both W and H: (sum W^2 / 2 + sum H^2 / 2 + b0) / C."""
    return torch.div(0.5*torch.sum(W*W,0) + (0.5*torch.sum(H*H,1))+b0,C)


def update_lambda_L1_L2(W,H,b0,C,eps_):
    """Lambda update with an L1 term for W and an L2 term for H."""
    return torch.div(torch.sum(W,0) + 0.5*torch.sum(H*H,1)+b0,C)


def update_lambda_L2_L1(W,H,b0,C,eps_):
    """Lambda update with an L2 term for W and an L1 term for H."""
    return torch.div(0.5*torch.sum(torch.pow(W,2),0) + torch.sum(H,1)+b0,C)
# [repository-dump residue preserved verbatim: /README.md header and title line]
# 157 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SignatureAnalyzer-GPU
2 | 3 | # Installation 4 | ``` 5 | git clone https://github.com/broadinstitute/SignatureAnalyzer-GPU.git 6 | ``` 7 | To install PyTorch, please use Anaconda (find more details at https://pytorch.org/): 8 | ``` 9 | conda install pytorch torchvision cudatoolkit=9.0 -c pytorch 10 | ``` 11 | 12 | # Setup 13 | For easy setup, you can create a Python virtual environment that matches our own: 14 | ``` 15 | $ virtualenv venv 16 | 17 | $ source venv/bin/activate . 18 | 19 | (venv)$ pip install -r requirements-py3.txt 20 | ``` 21 | 22 | # How to run a single decomposition 23 | SignatureAnalyzer runs on a count matrix (passed to the argument --data) and performs regularized NMF (Bayes NMF). You can specify the regularization you want on the resulting W and H matrices by using the arguments --prior_on_W and --prior_on_H. Passing "L1" is equivalent to an exponential prior and "L2" to a half-normal prior. 24 | 25 | For mathematical details see: 26 | 27 | 1. Tan, V. Y. F. & Févotte, C. Automatic Relevance Determination in Nonnegative Matrix Factorization with the β-Divergence. (2012). (https://arxiv.org/pdf/1111.6085.pdf) 28 | 29 | SignatureAnalyzer-CPU source publications: 30 | 31 | 1. Kim, J. et al. Somatic ERCC2 mutations are associated with a distinct genomic signature in urothelial tumors. Nat. Genet. 48, 600–606 (2016). (https://www.nature.com/articles/ng.3557) 32 | 33 | 2. Kasar, S. et al. Whole-genome sequencing reveals activation-induced cytidine deaminase signatures during indolent chronic lymphocytic leukaemia evolution. Nat. Commun. 6, 8866 (2015). (https://www.nature.com/articles/ncomms9866) 34 | 35 | 36 | Note that as part of this work we derived the form for a mixed prior (e.g. L1 on W and L2 on H); see the supplemental note (SupplementalNote.pdf) in the repo. 
37 | 38 | Example command line for a single run of SignatureAnalyzer-GPU: 39 | ``` 40 | python SignatureAnalyzer-GPU.py --data example_data/POLEMSI_counts_matrix.txt --max_iter=100000 --output_dir POLEMSI_EXAMPLE --prior_on_W L1 --prior_on_H L2 --labeled 41 | ``` 42 | Data should be formatted so that the rows are the categories and the columns are the samples. For a full description of inputs and outputs, please see the repository wiki. 43 | 44 | 45 | # How to run an array of decompositions 46 | The short run time of SignatureAnalyzer-GPU makes it practical to perform a parameter search, or to run the same parameter settings many times in order to find a maximum-likelihood decomposition or to characterize the modal number of clusters/signatures for some setting. To perform such an analysis, simply save the parameter settings you would like to run in a TSV file and pass it to the --parameters_file argument. We provide the parameters file and count matrix used to generate Figure 1B from the manuscript in the example_data directory. 47 | 48 | NOTE: this mode automatically uses however many GPUs are available (one or several); just run the command as usual and the runs are executed in parallel. 49 | ``` 50 | python SignatureAnalyzer-GPU.py --data example_data/POLEMSI_counts_matrix.txt --prior_on_W L1 --prior_on_H L2 --output_dir example_data/POLEMSI_outputs/ --parameters_file example_data/POLEMSI_params.txt --max_iter 20000 --labeled --tolerance 1e-7 51 | ``` 52 | For a full description of inputs and outputs related to parameters_file runs, see the repository wiki. 
53 | -------------------------------------------------------------------------------- /SignatureAnalyzer-GPU.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import numpy as np 3 | import sys 4 | import argparse 5 | import time 6 | from scipy.special import gamma 7 | import os 8 | import pickle 9 | import torch 10 | import NMF_functions 11 | from ARD_NMF import ARD_NMF 12 | import feather 13 | from ARD_NMF import run_method_engine 14 | import torch.nn as nn 15 | import torch.multiprocessing as mp 16 | 17 | def createFolder(directory): 18 | try: 19 | if not os.path.exists(directory): 20 | os.makedirs(directory) 21 | except OSError: 22 | print ('Error: Creating directory. ' + directory) 23 | 24 | def run_parameter_sweep(parameters, data, args, Beta): 25 | output = [] 26 | num_processes = torch.cuda.device_count() 27 | batches = int(len(parameters) / num_processes) 28 | idx = 0 29 | objectives = [] 30 | bdivs = [] 31 | val_objectives = [] 32 | val_bdivs = [] 33 | nsigs = [] 34 | times = [] 35 | while idx <= len(parameters)-num_processes: 36 | print(idx) 37 | pipe_list = [] 38 | processes = [] 39 | for rank in range(num_processes): 40 | recv_end, send_end = mp.Pipe(False) 41 | p = mp.Process(target=run_method_engine, args=( 42 | data, 43 | parameters.iloc[idx+rank]['a'], 44 | parameters.iloc[idx+rank]['phi'], 45 | parameters.iloc[idx+rank]['b'], 46 | Beta, 47 | args.prior_on_W, 48 | args.prior_on_H, 49 | parameters.iloc[idx+rank]['K0'], 50 | args.tolerance, 51 | args.max_iter, 52 | args.use_val_set, 53 | args.report_frequency, 54 | 1e-5, 55 | send_end, 56 | rank, 57 | )) 58 | 59 | pipe_list.append(recv_end) 60 | processes.append(p) 61 | p.start() 62 | 63 | result_list = [x.recv() for x in pipe_list] 64 | for p in processes: 65 | p.join() 66 | 67 | nsig = [write_output( 68 | x[0], 69 | x[1], 70 | data.channel_names, 71 | data.sample_names, 72 | args.output_dir, 73 | parameters['label'][idx+i] 74 | ) for i,x 
def write_output(W, H, mask, channel_names, sample_names, output_directory, label, active_thresh = 1e-5):
    """Write the active components of a factorization as tab-separated files.

    A component is kept when the product of its row sum in ``H`` and its
    column sum in ``W`` exceeds ``active_thresh``. Every kept column of ``W``
    is rescaled to sum to one and the removed weight is multiplied into the
    matching row of ``H``.

    Three files are written into ``output_directory`` (created through
    ``createFolder``): ``<label>_W.txt``, ``<label>_H.txt`` and
    ``<label>_mask.txt``.

    Returns:
        The number of kept components.
    """
    createFolder(output_directory)

    # A component counts as active only if both its loading and its activity
    # carry weight, hence the product of the two sums.
    keep = (np.sum(H, axis=1) * np.sum(W, axis=0)) > active_thresh
    nsig = np.sum(keep)
    W_kept = W[:, keep]
    H_kept = H[keep, :]

    # Normalize W and transfer weight to H matrix
    column_weight = np.sum(W_kept, axis=0)
    W_final = W_kept / column_weight
    H_final = column_weight[:, np.newaxis] * H_kept

    sig_names = ['W' + str(j + 1) for j in range(nsig)]
    frames = (
        ('_W.txt', pd.DataFrame(data=W_final, index=channel_names, columns=sig_names)),
        ('_H.txt', pd.DataFrame(data=H_final, index=sig_names, columns=sample_names)),
        ('_mask.txt', pd.DataFrame(mask, index=channel_names, columns=sample_names)),
    )
    for suffix, frame in frames:
        frame.to_csv(output_directory + '/' + label + suffix, sep='\t')

    return nsig


def _build_parser():
    """Create the command line parser for the ARD NMF entry point."""
    parser = argparse.ArgumentParser(
        description='NMF with some sparsity penalty described https://arxiv.org/pdf/1111.6085.pdf')
    parser.add_argument('--data', help='Data Matrix', required=True)
    parser.add_argument('--feather', help='Input in feather format', required=False, default=False, action='store_true')
    parser.add_argument('--parquet', help='Input in parquet format', required=False, default=False, action='store_true')
    parser.add_argument('--K0', help='Initial K parameter', required=False, default=None, type=int)
    parser.add_argument('--max_iter', help='maximum iterations', required=False, default=10000, type=int)
    parser.add_argument('--del_', help='Early stop condition based on lambda change', required=False, default=1,
                        type=int)
    parser.add_argument('--tolerance', help='Early stop condition based on max lambda entry', required=False, default=1e-6,
                        type=float)
    parser.add_argument('--phi', help='dispersion parameter see paper for discussion of choosing phi '
                        'default = 1', required=False, default=1.0, type=float)
    # The adjacent literals below are joined by the parser; each piece ends
    # with a space so the words do not run together in --help output.
    parser.add_argument('--a', help='Hyperparamter for lambda. We recommend trying various values of a. Smaller values '
                        'will result in sparser results a good starting point might be '
                        'a = log(F+N)', required=False, default=10.0, type=float)
    parser.add_argument('--b', help='Hyperparamter for lambda. Default used is as recommended in Tan and Fevotte 2012',
                        required=False, type=float, default=None)
    parser.add_argument('--objective', help='Defines the data objective. Choose between "poisson" or "gaussian". Defaults to Poisson',
                        required=False, default='poisson', type=str)
    parser.add_argument('--prior_on_W', help='Prior on W matrix "L1" (exponential) or "L2" (half-normal)',
                        required=False, default='L1', type=str)
    parser.add_argument('--prior_on_H', help='Prior on H matrix "L1" (exponential) or "L2" (half-normal)',
                        required=False, default='L1', type=str)
    parser.add_argument('--output_dir', help='output_file_name if run in array mode this correspond to the output directory', required=True)
    parser.add_argument('--labeled', help='Input has row and column labels', required=False, default=False, action='store_true')
    parser.add_argument('--report_frequency', help='Number of iterations between progress reports', required=False,
                        default=100, type=int)
    parser.add_argument('--dtype', help='Floating point accuracy', required=False,
                        default='Float32', type=str)
    parser.add_argument('--parameters_file', help='allows running many different configurations of the NMF method on a multi'
                        'GPU system. To run in this mode provide this argument with a text file with '
                        'the following headers:(a,phi,b,prior_on_W,prior_on_H,Beta,label) label '
                        'indicates the output stem of the results from each run.', required=False, default=None)
    val_set_default_note = ('If neither --force_use_val_set or --force_no_val_set is passed, will default to create and evaluate on '
                            'a held out validation set when parameters_file is provided, and not otherwise.')
    parser.add_argument('--force_use_val_set', dest='use_val_set', action='store_true',
                        help='override detaults and use a validation set no matter what, '
                             'even when parameter search file is not passed. ' + val_set_default_note)
    parser.add_argument('--force_no_val_set', dest='use_val_set', action='store_false',
                        help='override detaults and dont use a validation set no matter what, '
                             'even when parameter search file is passed. ' + val_set_default_note)
    # None means "neither flag given"; main() resolves it from --parameters_file.
    parser.set_defaults(use_val_set=None)
    return parser


def _beta_for_objective(objective):
    """Return the Beta value for an objective name (poisson -> 1, gaussian -> 2).

    Raises:
        ValueError: if the name is neither "poisson" nor "gaussian"
            (compared case-insensitively).
    """
    betas = {'poisson': 1, 'gaussian': 2}
    try:
        return betas[objective.lower()]
    except KeyError:
        raise ValueError('objective parameter should be one of "gaussian" or "poisson"') from None


def _resolve_dtype(name):
    """Translate the ``--dtype`` string into the matching torch dtype.

    Raises:
        ValueError: if the name is not "Float32" or "Float16"
            (compared case-insensitively).
    """
    dtypes = {'float32': torch.float32, 'float16': torch.float16}
    try:
        return dtypes[name.lower()]
    except KeyError:
        raise ValueError('dtype parameter should be one of "Float32" or "Float16"') from None


def _load_dataset(args):
    """Read the input matrix in the format selected on the command line."""
    if args.parquet:
        return pd.read_parquet(args.data)
    if args.feather:
        print('loading feather...')
        # pandas reads feather files itself, so no separate feather module
        # has to be importable.
        return pd.read_feather(args.data)
    if args.labeled:
        return pd.read_csv(args.data, sep='\t', header=0, index_col=0)
    return pd.read_csv(args.data, sep='\t', header=None)


def _single_run_label(output_dir):
    """Return the file-name stem used for the outputs of a single run.

    write_output builds paths as ``output_dir + '/' + label + suffix``, so the
    label must not contain path separators. The last path component is used,
    which equals ``output_dir`` itself when it is a plain directory name.
    """
    import os
    return os.path.basename(os.path.normpath(output_dir))


def main():
    """Run ARD NMF from the command line.

    With ``--parameters_file`` every row of that tab-separated file is run
    through run_parameter_sweep; otherwise a single factorization is run with
    the hyper-parameters given as arguments and written with write_output.

    Invalid ``--objective`` or ``--dtype`` values end the program through
    ``parser.error`` before any data is read.
    """
    parser = _build_parser()
    args = parser.parse_args()

    # Validate cheap arguments first so a typo does not cost a full data load.
    try:
        Beta = _beta_for_objective(args.objective)
        args.dtype = _resolve_dtype(args.dtype)
    except ValueError as err:
        parser.error(str(err))

    torch.multiprocessing.set_start_method('spawn')

    print('Reading data frame from '+ args.data)
    dataset = _load_dataset(args)

    # Forward the requested precision; without it ARD_NMF falls back to its
    # own default and --dtype has no effect.
    data = ARD_NMF(dataset, args.objective, dtype=args.dtype)

    # Unless forced either way, hold out a validation set only for sweeps.
    if args.use_val_set is None:
        args.use_val_set = args.parameters_file is not None

    if args.parameters_file is not None:
        parameters = pd.read_csv(args.parameters_file, sep='\t')
        run_parameter_sweep(parameters, data, args, Beta)
        return

    W, H, cost, final_report, lam, mask = run_method_engine(
        data,
        args.a,
        args.phi,
        args.b,
        Beta,
        args.prior_on_W,
        args.prior_on_H,
        args.K0,
        args.tolerance,
        args.max_iter,
        args.use_val_set,
        args.report_frequency,
    )
    write_output(
        W,
        H,
        mask,
        data.channel_names,
        data.sample_names,
        args.output_dir,
        _single_run_label(args.output_dir),
    )


if __name__ == "__main__":
    main()
None 96 run_17 19 | 10 1 None 96 run_18 20 | 10 1 None 96 run_19 21 | 10 1 None 96 run_20 22 | 10 1 None 96 run_21 23 | -------------------------------------------------------------------------------- /requirements-py3.txt: -------------------------------------------------------------------------------- 1 | numpy==1.22.0 2 | pandas==0.23.0 3 | pyarrow==0.11.1 4 | scikit-image==0.13.1 5 | scikit-learn==0.19.1 6 | scipy==1.1.0 7 | 8 | --------------------------------------------------------------------------------