├── .DS_Store ├── .gitignore ├── LICENSE ├── README.rst ├── bin └── velorama ├── datasets ├── .DS_Store ├── dataset_A.h5ad ├── dataset_B.h5ad ├── dataset_C.h5ad └── dataset_D.h5ad ├── pyproject.toml └── velorama ├── __init__.py ├── models.py ├── run.py ├── train.py └── utils.py /.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rs239/velorama/44bcd5d52737a6c2266d397e903d10e98df8e4d5/.DS_Store -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | results/ 131 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Rohit Singh 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.rst: -------------------------------------------------------------------------------- 1 | 2 | Velorama - Gene regulatory network inference for RNA velocity and pseudotime data 3 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 4 | 5 | .. image:: http://cb.csail.mit.edu/cb/velorama/velorama_v5.png 6 | :width: 600 7 | 8 | Velorama is a Python library for inferring gene regulatory networks from single-cell RNA-seq data 9 | 10 | **It is designed for the case where RNA velocity or pseudotime data is available.** 11 | Here are some of the analyses that you can do with Velorama: 12 | 13 | - infer temporally-causal regulator-target links from RNA velocity cell-to-cell transition matrices. 14 | - infer over branching/merging trajectories using just pseudotime data without having to manually separate them. 15 | - estimate the relative speed of various regulators (i.e., how quickly they act on the target). 16 | 17 | Velorama offers support for both pseudotime and RNA velocity data. 18 | 19 | 20 | Velorama is based on a Granger causal approach and models the differentiation landscape as a directed acyclic graph (DAG) of cells, rather than as a linear total ordering required by previous approaches. 
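Concretely, the DAG is represented as a normalized cell-by-cell transition matrix ``A``, and each candidate regulator's expression is repeatedly diffused along the DAG to form lagged views of it; a sparse neural Granger model of every target gene is then fit on those lagged views (see ``velorama/models.py`` and ``velorama/train.py``). The snippet below is a minimal sketch of the lag construction; the packaged implementation is ``calculate_diffusion_lags`` in ``velorama/utils.py``, and the function name used here is purely illustrative. ::

    import torch

    def diffusion_lags(A, X, lag):
        # A: (cells x cells) DAG transition matrix, X: (cells x regulators) expression
        ax, cur = [], A
        for _ in range(lag):
            ax.append(cur @ X)     # regulator expression diffused one more step along the DAG
            cur = A @ cur          # advance the diffusion operator
            cur.fill_diagonal_(0)  # drop self-loops, as the library implementation does
        return torch.stack(ax)     # shape: (lag, cells, regulators)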
21 |
22 | =================
23 | Installation
24 | =================
25 |
26 | To install Velorama, follow the instructions below.
27 |
28 | Using Conda/Mamba: ::
29 |
30 |     git clone https://github.com/rs239/velorama.git
31 |     cd ./velorama
32 |     export SKLEARN_ALLOW_DEPRECATED_SKLEARN_PACKAGE_INSTALL=True
33 |     pip install .
34 |
35 | =================
36 | API Example Usage
37 | =================
38 |
39 | Velorama is currently offered as a command line tool that operates on ``AnnData`` objects. [Ed. Note: We are working on a clean API compatible with the scanpy ecosystem.] First, prepare an AnnData object of the dataset to be analyzed with Velorama. If you have RNA velocity data, make sure it is in the ``layers`` as required by `CellRank `_ and `scVelo `_, so that transition probabilities can be computed. We recommend performing standard single-cell normalization procedures (i.e., normalize counts to the median per-cell transcript count and log transform the normalized counts plus a pseudocount). Next, annotate the candidate regulators and targets in the ``var`` DataFrame of the ``AnnData`` object as follows. ::
40 |
41 |     adata.var['is_reg'] = [n in regulator_genes for n in adata.var.index.values]
42 |     adata.var['is_target'] = [n in target_genes for n in adata.var.index.values]
43 |
44 | Here ``regulator_genes`` is the set of gene symbols or IDs for the candidate regulators, while ``target_genes`` is the set of gene symbols or IDs for the candidate target genes. This AnnData object should be saved as ``{dataset}.h5ad``.
45 |
46 | We provide an example dataset here: `mouse endocrinogenesis `_. This dataset is from the scVelo vignette and is based on the study by `Bergen et al. (2020) `_.
47 |
48 | The command below runs Velorama, which saves the inferred Granger causal interactions and interaction speeds to the specified directory. ::
49 |
50 |     velorama -ds $dataset -dyn $dynamics -dev $device -l $L -hd $hidden -rd $rd
51 |
52 | Here, ``$dataset`` is the name of the dataset associated with the saved AnnData object. ``$dynamics`` can be "rna_velocity" or "pseudotime", depending on which data should be used to construct the DAG. ``$device`` should be either "cuda" or "cpu". ``$rd`` is the name of the root directory that contains the saved AnnData object and where the outputs will be saved. Among the optional arguments, ``$L`` refers to the maximum number of lags to consider (default=5). ``$hidden`` indicates the dimensionality of the hidden layers (default=32).
53 |
54 |
55 | We encourage you to report issues at our `Github page`_; you can also open pull requests there to contribute your enhancements.
56 | If Velorama is useful for your research, please consider citing `bioRxiv (2022)`_.
57 |
58 | .. _bioRxiv (2022): https://www.biorxiv.org/content/10.1101/2022.10.18.512766v3
59 | ..
_Github page: https://github.com/rs239/velorama 60 | -------------------------------------------------------------------------------- /bin/velorama: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import velorama 4 | 5 | velorama.execute_cmdline() 6 | 7 | 8 | 9 | 10 | -------------------------------------------------------------------------------- /datasets/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rs239/velorama/44bcd5d52737a6c2266d397e903d10e98df8e4d5/datasets/.DS_Store -------------------------------------------------------------------------------- /datasets/dataset_A.h5ad: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rs239/velorama/44bcd5d52737a6c2266d397e903d10e98df8e4d5/datasets/dataset_A.h5ad -------------------------------------------------------------------------------- /datasets/dataset_B.h5ad: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rs239/velorama/44bcd5d52737a6c2266d397e903d10e98df8e4d5/datasets/dataset_B.h5ad -------------------------------------------------------------------------------- /datasets/dataset_C.h5ad: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rs239/velorama/44bcd5d52737a6c2266d397e903d10e98df8e4d5/datasets/dataset_C.h5ad -------------------------------------------------------------------------------- /datasets/dataset_D.h5ad: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rs239/velorama/44bcd5d52737a6c2266d397e903d10e98df8e4d5/datasets/dataset_D.h5ad -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "velorama" 3 | version = "0.0.3" 4 | description = "Gene regulatory network inference for RNA velocity and pseudotime data" 5 | authors = [ 6 | { name = "Your Name" } 7 | ] 8 | license = { text = "MIT" } 9 | dependencies = [ 10 | "numpy>=1.26.4", 11 | "scipy>=1.11.4", 12 | "pandas>=2.2.3", 13 | "scikit-learn>=1.5.2", 14 | "cellrank>=2.0.6", 15 | "scvelo>=0.3.2", 16 | "scanpy>=1.10.4", 17 | "anndata>=0.11.1", 18 | "torch==1.13.0", 19 | "ray[tune]==2.6.0", 20 | "matplotlib>=3.9.2", 21 | "h5py>=3.12.1", 22 | "tqdm>=4.67.1", 23 | "networkx>=3.4.2", 24 | "seaborn>=0.13.2", 25 | "statsmodels>=0.14.4", 26 | "schema_learn>=0.1.5.5", 27 | "umap-learn>=0.5.7" 28 | ] 29 | requires-python = ">=3.10" 30 | 31 | [build-system] 32 | requires = [ 33 | "setuptools>=75.1.0", 34 | "wheel>=0.44.0", 35 | "pip>=24.2", 36 | "build" 37 | ] 38 | build-backend = "setuptools.build_meta" 39 | 40 | [tool.setuptools] 41 | packages = ["velorama"] 42 | zip-safe = false 43 | 44 | [project.scripts] 45 | velorama = "velorama:execute_cmdline" -------------------------------------------------------------------------------- /velorama/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | from .models import VeloramaMLP 3 | from .train import train_model 4 | from .run import execute_cmdline 5 | 6 | __all__ = ['VeloramaMLP', 'train_model', 'execute_cmdline'] 7 | 8 | 9 | import pkgutil, pathlib, importlib 10 | 11 | # from pkgutil import iter_modules 12 | # from pathlib import Path 13 
| # from importlib import import_module 14 | 15 | # https://julienharbulot.com/python-dynamical-import.html 16 | # iterate through the modules in the current package 17 | # 18 | # # package_dir = pathlib.Path(__file__).resolve().parent 19 | # # for (_, module_name, _) in pkgutil.iter_modules([package_dir]): 20 | # # if 'datasets' in module_name: 21 | # # module = importlib.import_module(f"{__name__}.{module_name}") 22 | -------------------------------------------------------------------------------- /velorama/models.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | ### Authors: Anish Mudide (amudide), Alex Wu (alexw16), Rohit Singh (rs239) 4 | ### 2022 5 | ### MIT Licence 6 | ### 7 | ### Credit: parts of this code make use of the code from Tank et al.'s "Neural Granger Causality" 8 | ### - https://github.com/iancovert/Neural-GC 9 | ### - https://arxiv.org/abs/1802.05842 10 | 11 | import torch 12 | import torch.nn as nn 13 | 14 | from .utils import activation_helper 15 | 16 | class VeloramaMLP(nn.Module): 17 | 18 | def __init__(self, n_targets, n_regs, lag, hidden, device, activation): 19 | super(VeloramaMLP, self).__init__() 20 | self.activation = activation_helper(activation) 21 | self.hidden = hidden 22 | self.lag = lag 23 | self.device = device 24 | 25 | # set up first layer 26 | layer = nn.Conv1d(n_regs, hidden[0]*n_targets, lag) 27 | modules = [layer] 28 | 29 | # set up subsequent layers 30 | for d_in, d_out in zip(hidden, hidden[1:] + [1]): 31 | layer = nn.Conv1d(d_in*n_targets, d_out*n_targets, 1, groups=n_targets) 32 | modules.append(layer) 33 | 34 | # Register parameters. 35 | self.layers = nn.ModuleList(modules) 36 | 37 | def forward(self,AX): 38 | 39 | # first layer 40 | ret = 0 41 | for i in range(self.lag): 42 | ret = ret + torch.matmul(AX[i], self.layers[0].weight[:, :, self.lag - 1 - i].T) 43 | ret = ret + self.layers[0].bias 44 | 45 | # subsequent layers 46 | ret = ret.T 47 | for i, fc in enumerate(self.layers): 48 | if i == 0: 49 | continue 50 | ret = self.activation(ret) 51 | ret = fc(ret) 52 | 53 | return ret.T 54 | 55 | def GC(self, threshold=True, ignore_lag=True): 56 | ''' 57 | Extract learned Granger causality. 58 | 59 | Args: 60 | threshold: return norm of weights, or whether norm is nonzero. 61 | ignore_lag: if true, calculate norm of weights jointly for all lags. 62 | 63 | Returns: 64 | GC: (p x p) or (p x p x lag) matrix. In first case, entry (i, j) 65 | indicates whether variable j is Granger causal of variable i. In 66 | second case, entry (i, j, k) indicates whether it's Granger causal 67 | at lag k. 
68 | ''' 69 | 70 | W = self.layers[0].weight 71 | W = W.reshape(-1,self.hidden[0],W.shape[1],W.shape[2]) 72 | 73 | if ignore_lag: 74 | GC = torch.norm(W, dim=(1, 3)) 75 | else: 76 | GC = torch.norm(W, dim=1) 77 | 78 | if threshold: 79 | return (GC > 0).int() 80 | else: 81 | return GC 82 | 83 | class VeloramaMLPTarget(nn.Module): 84 | 85 | def __init__(self, n_targets, n_regs, lag, hidden, device, activation): 86 | super(VeloramaMLPTarget, self).__init__() 87 | self.activation = activation_helper(activation) 88 | self.hidden = hidden 89 | self.lag = lag 90 | self.device = device 91 | 92 | # set up first layer 93 | layer = nn.Conv1d(n_regs, hidden[0]*n_targets, lag) 94 | modules = [layer] 95 | 96 | # set up first layers (target variables) 97 | target_modules = [nn.Conv1d(n_targets, hidden[0]*n_targets,1,groups=n_targets,bias=False) 98 | for _ in range(lag)] 99 | self.target_layers = nn.ModuleList(target_modules) 100 | 101 | # set up subsequent layers 102 | for d_in, d_out in zip(hidden, hidden[1:] + [1]): 103 | layer = nn.Conv1d(d_in*n_targets, d_out*n_targets, 1, groups=n_targets) 104 | modules.append(layer) 105 | 106 | # Register parameters. 107 | self.layers = nn.ModuleList(modules) 108 | 109 | def forward(self,AX,AY): 110 | 111 | # first layer 112 | ret = 0 113 | for i in range(self.lag): 114 | ret = ret + torch.matmul(AX[i], self.layers[0].weight[:, :, self.lag - 1 - i].T) 115 | ret = ret + self.layers[0].bias 116 | 117 | # include contributions of target variables 118 | ret = ret.T 119 | for i in range(self.lag): 120 | ret = ret + self.target_layers[self.lag - 1 - i](AY[i].T) 121 | 122 | # subsequent layers 123 | for i, fc in enumerate(self.layers): 124 | if i == 0: 125 | continue 126 | ret = self.activation(ret) 127 | ret = fc(ret) 128 | 129 | return ret.T 130 | 131 | def GC(self, threshold=True, ignore_lag=True): 132 | ''' 133 | Extract learned Granger causality. 134 | 135 | Args: 136 | threshold: return norm of weights, or whether norm is nonzero. 137 | ignore_lag: if true, calculate norm of weights jointly for all lags. 138 | 139 | Returns: 140 | GC: (p x p) or (p x p x lag) matrix. In first case, entry (i, j) 141 | indicates whether variable j is Granger causal of variable i. In 142 | second case, entry (i, j, k) indicates whether it's Granger causal 143 | at lag k. 144 | ''' 145 | 146 | W = self.layers[0].weight 147 | W = W.reshape(-1,self.hidden[0],W.shape[1],W.shape[2]) 148 | 149 | if ignore_lag: 150 | GC = torch.norm(W, dim=(1, 3)) 151 | else: 152 | GC = torch.norm(W, dim=1) 153 | 154 | if threshold: 155 | return (GC > 0).int() 156 | else: 157 | return GC 158 | 159 | def prox_update(network, lam, lr, penalty): 160 | ''' 161 | Perform in place proximal update on first layer weight matrix. 162 | Args: 163 | network: MLP network. 164 | lam: regularization parameter. 165 | lr: learning rate. 166 | penalty: one of GL (group lasso), GSGL (group sparse group lasso), 167 | H (hierarchical). 
168 | ''' 169 | W = network.layers[0].weight 170 | hidden, p, lag = W.shape 171 | if penalty == 'GL': 172 | norm = torch.norm(W, dim=(0, 2), keepdim=True) 173 | W.data = ((W / torch.clamp(norm, min=(lr * lam))) 174 | * torch.clamp(norm - (lr * lam), min=0.0)) 175 | elif penalty == 'GSGL': 176 | norm = torch.norm(W, dim=0, keepdim=True) 177 | W.data = ((W / torch.clamp(norm, min=(lr * lam))) 178 | * torch.clamp(norm - (lr * lam), min=0.0)) 179 | norm = torch.norm(W, dim=(0, 2), keepdim=True) 180 | W.data = ((W / torch.clamp(norm, min=(lr * lam))) 181 | * torch.clamp(norm - (lr * lam), min=0.0)) 182 | elif penalty == 'H': 183 | # Lowest indices along third axis touch most lagged values. 184 | for i in range(lag): 185 | norm = torch.norm(W[:, :, :(i + 1)], dim=(0, 2), keepdim=True) 186 | W.data[:, :, :(i+1)] = ( 187 | (W.data[:, :, :(i+1)] / torch.clamp(norm, min=(lr * lam))) 188 | * torch.clamp(norm - (lr * lam), min=0.0)) 189 | else: 190 | raise ValueError('unsupported penalty: %s' % penalty) 191 | 192 | def prox_update_target(network, lam, lr, penalty): 193 | ''' 194 | Perform in place proximal update on first layer weight matrix. 195 | Args: 196 | network: MLP network. 197 | lam: regularization parameter. 198 | lr: learning rate. 199 | penalty: one of GL (group lasso), GSGL (group sparse group lasso), 200 | H (hierarchical). 201 | ''' 202 | W = network.layers[0].weight 203 | hidden, p, lag = W.shape 204 | 205 | if penalty == 'GL': 206 | norm = torch.norm(W, dim=(0, 2), keepdim=True) 207 | W.data = ((W / torch.clamp(norm, min=(lr * lam))) 208 | * torch.clamp(norm - (lr * lam), min=0.0)) 209 | elif penalty == 'GSGL': 210 | norm = torch.norm(W, dim=0, keepdim=True) 211 | W.data = ((W / torch.clamp(norm, min=(lr * lam))) 212 | * torch.clamp(norm - (lr * lam), min=0.0)) 213 | norm = torch.norm(W, dim=(0, 2), keepdim=True) 214 | W.data = ((W / torch.clamp(norm, min=(lr * lam))) 215 | * torch.clamp(norm - (lr * lam), min=0.0)) 216 | elif penalty == 'H': 217 | # Lowest indices along third axis touch most lagged values. 218 | for i in range(lag): 219 | W = network.layers[0].weight 220 | target_W = torch.stack([network.target_layers[j].weight for j 221 | in range(len(network.target_layers))]).squeeze(-1) 222 | target_W = torch.swapaxes(torch.swapaxes(target_W,0,1),1,2) 223 | W_concat = torch.cat([W.data[:,:,:(i+1)],target_W.data[:,:,:(i+1)]],dim=1) 224 | norm = torch.norm(W_concat[:,:,:(i+1)], dim=(0, 2), keepdim=True) 225 | 226 | # update regulator weights 227 | W.data[:, :, :(i+1)] = ( 228 | (W.data[:, :, :(i+1)] / torch.clamp(norm[:,0:-1], min=(lr * lam))) 229 | * torch.clamp(norm[:,0:-1] - (lr * lam), min=0.0)) 230 | 231 | # update target weights 232 | for j in range(i+1): 233 | W_t = network.target_layers[j].weight 234 | W_t.data = ((W_t.data / torch.clamp(norm[:,-1:], min=(lr * lam))) 235 | * torch.clamp(norm[:,-1:] - (lr * lam), min=0.0)) 236 | else: 237 | raise ValueError('unsupported penalty: %s' % penalty) 238 | 239 | # def prox_update_new(network, lam, lr, penalty): 240 | # ''' 241 | # Perform in place proximal update on first layer weight matrix. 242 | 243 | # Args: 244 | # network: MLP network. 245 | # lam: regularization parameter. 246 | # lr: learning rate. 247 | # penalty: one of GL (group lasso), GSGL (group sparse group lasso), 248 | # H (hierarchical). 
249 | # ''' 250 | 251 | # W = network.layers[0].weight 252 | # hidden, p, lag = W.shape 253 | 254 | # W_copy = torch.clone(W) 255 | # W_copy = W.reshape(-1,network.hidden[0],W.shape[1],W.shape[2]) 256 | 257 | # if penalty == 'GL': 258 | 259 | # norm = torch.norm(W_copy[:, :, :, :(i + 1)], dim=(1, 3), keepdim=True) 260 | # W_copy.data = W_copy 261 | 262 | # norm = torch.norm(W, dim=(0, 2), keepdim=True) 263 | # W.data = ((W / torch.clamp(norm, min=(lr * lam))) 264 | # * torch.clamp(norm - (lr * lam), min=0.0)) 265 | # # elif penalty == 'GSGL': 266 | # # norm = torch.norm(W, dim=0, keepdim=True) 267 | # # W.data = ((W / torch.clamp(norm, min=(lr * lam))) 268 | # # * torch.clamp(norm - (lr * lam), min=0.0)) 269 | # # norm = torch.norm(W, dim=(0, 2), keepdim=True) 270 | # # W.data = ((W / torch.clamp(norm, min=(lr * lam))) 271 | # # * torch.clamp(norm - (lr * lam), min=0.0)) 272 | # elif penalty == 'H': 273 | # # Lowest indices along third axis touch most lagged values. 274 | # for i in range(lag): 275 | # norm = torch.norm(W_copy[:, :, :, :(i + 1)], dim=(1, 3), keepdim=True) 276 | # W_copy.data[:, :, :, :(i+1)] = ( 277 | # (W_copy.data[:, :, :, :(i+1)] / torch.clamp(norm, min=(lr * lam))) 278 | # * torch.clamp(norm - (lr * lam), min=0.0)) 279 | # W.data = W_copy.data.reshape(W.shape) 280 | 281 | # else: 282 | # raise ValueError('unsupported penalty: %s' % penalty) 283 | 284 | def regularize(network, lam, penalty): 285 | ''' 286 | Calculate regularization term for first layer weight matrix. 287 | Args: 288 | network: MLP network. 289 | penalty: one of GL (group lasso), GSGL (group sparse group lasso), 290 | H (hierarchical). 291 | ''' 292 | W = network.layers[0].weight 293 | hidden, p, lag = W.shape 294 | if penalty == 'GL': 295 | return lam * torch.sum(torch.norm(W, dim=(0, 2))) 296 | elif penalty == 'GSGL': 297 | return lam * (torch.sum(torch.norm(W, dim=(0, 2))) 298 | + torch.sum(torch.norm(W, dim=0))) 299 | elif penalty == 'H': 300 | # Lowest indices along third axis touch most lagged values. 301 | return lam * sum([torch.sum(torch.norm(W[:, :, :(i+1)], dim=(0, 2))) 302 | for i in range(lag)]) 303 | else: 304 | raise ValueError('unsupported penalty: %s' % penalty) 305 | 306 | def regularize_target(network, lam, penalty): 307 | ''' 308 | Calculate regularization term for first layer weight matrix. 309 | Args: 310 | network: MLP network. 311 | penalty: one of GL (group lasso), GSGL (group sparse group lasso), 312 | H (hierarchical). 313 | ''' 314 | W = network.layers[0].weight 315 | 316 | hidden, p, lag = W.shape 317 | if penalty == 'GL': 318 | return lam * torch.sum(torch.norm(W, dim=(0, 2))) 319 | elif penalty == 'GSGL': 320 | return lam * (torch.sum(torch.norm(W, dim=(0, 2))) 321 | + torch.sum(torch.norm(W, dim=0))) 322 | elif penalty == 'H': 323 | # Lowest indices along third axis touch most lagged values. 324 | # return lam * sum([torch.sum(torch.norm(W[:, :, :(i+1)], dim=(0, 2))) 325 | # for i in range(lag)]) 326 | target_W = torch.stack([network.target_layers[i].weight for i 327 | in range(len(network.target_layers))]).squeeze(-1) 328 | target_W = torch.swapaxes(torch.swapaxes(target_W,0,1),1,2) 329 | 330 | return lam * sum([torch.sum(torch.norm(torch.cat([W.data[:,:,:(i+1)],target_W.data[:,:,:(i+1)]],dim=1), 331 | dim=(0, 2))) for i in range(lag)]) 332 | else: 333 | raise ValueError('unsupported penalty: %s' % penalty) 334 | 335 | # def regularize_new(network, lam, penalty): 336 | # ''' 337 | # Calculate regularization term for first layer weight matrix. 
338 | 339 | # Args: 340 | # network: MLP network. 341 | # penalty: one of GL (group lasso), GSGL (group sparse group lasso), 342 | # H (hierarchical). 343 | # ''' 344 | # W = network.layers[0].weight 345 | # hidden, p, lag = W.shape 346 | # W = W.reshape(-1,network.hidden[0],W.shape[1],W.shape[2]) 347 | 348 | # if penalty == 'GL': 349 | # return lam * torch.sum(torch.norm(W, dim=(1, 3))).sum() 350 | # # elif penalty == 'GSGL': 351 | # # return lam * (torch.sum(torch.norm(W, dim=(0, 3))) 352 | # # + torch.sum(torch.norm(W, dim=0))) 353 | # elif penalty == 'H': 354 | # # Lowest indices along third axis touch most lagged values. 355 | # return lam * sum([torch.sum(torch.norm(W[:, :, :, :(i+1)], dim=(1, 3))) 356 | # for i in range(lag)]).sum() 357 | # else: 358 | # raise ValueError('unsupported penalty: %s' % penalty) 359 | 360 | def ridge_regularize(network, lam): 361 | '''Apply ridge penalty at all subsequent layers.''' 362 | return lam * sum([torch.sum(fc.weight ** 2) for fc in network.layers[1:]]) 363 | 364 | 365 | def restore_parameters(model, best_model): 366 | '''Move parameter values from best_model to model.''' 367 | for params, best_params in zip(model.parameters(), best_model.parameters()): 368 | params.data = best_params 369 | -------------------------------------------------------------------------------- /velorama/run.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | ### Authors: Anish Mudide (amudide), Alex Wu (alexw16), Rohit Singh (rs239) 4 | ### 2022 5 | ### MIT Licence 6 | ### 7 | 8 | import os 9 | import numpy as np 10 | import scanpy as sc 11 | import argparse 12 | import time 13 | import ray 14 | from ray import tune 15 | import statistics 16 | import scvelo as scv 17 | import pandas as pd 18 | import shutil 19 | 20 | from .models import * 21 | from .train import * 22 | from .utils import * 23 | from .utils import move_files 24 | 25 | 26 | def execute_cmdline(): 27 | 28 | parser = argparse.ArgumentParser() 29 | parser.add_argument('-n','--name',dest='name',type=str,default='velorama_run',help='substring to have in our output files') 30 | parser.add_argument('-ds','--dataset',dest='dataset',type=str) 31 | parser.add_argument('-dyn','--dyn',dest='dynamics',type=str,default='pseudotime', 32 | choices=['pseudotime','rna_velocity','pseudotime_time','pseudotime_precomputed']) 33 | parser.add_argument('-ptloc','--ptloc',dest='ptloc',type=str,default='pseudotime') 34 | parser.add_argument('-dev','--device',dest='device',type=str,default='cpu') 35 | parser.add_argument('-s','--seed',dest='seed',type=int,default=0,help='Random seed. 
Set to 0,1,2 etc.') 36 | parser.add_argument('-lmr','--lam_ridge',dest='lam_ridge',type=float,default=0., help='Currenty unsupported') 37 | parser.add_argument('-p','--penalty',dest='penalty',type=str,default='H') 38 | parser.add_argument('-l','--lag',dest='lag',type=int,default=5) 39 | parser.add_argument('-hd', '--hidden',dest='hidden',type=int,default=32) 40 | parser.add_argument('-mi','--max_iter',dest='max_iter',type=int,default=1000) 41 | parser.add_argument('-lr','--learning_rate',dest='learning_rate',type=float,default=0.01) 42 | parser.add_argument('-pr','--proba',dest='proba',type=int,default=1) 43 | parser.add_argument('-ce','--check_every',dest='check_every',type=int,default=10) 44 | parser.add_argument('-rd','--root_dir',dest='root_dir',type=str) 45 | parser.add_argument('-sd','--save_dir',dest='save_dir',type=str,default='./results') 46 | parser.add_argument('-ls','--lam_start',dest='lam_start',type=float,default=-2) 47 | parser.add_argument('-le','--lam_end',type=float,default=1) 48 | parser.add_argument('-xn','--x_norm',dest='x_norm',type=str,default='zscore') # ,choices=['none','zscore','to_count:zscore','zscore_pca','maxmin','fill_zscore']) 49 | parser.add_argument('-nl','--num_lambdas',dest='num_lambdas',type=int,default=19) 50 | parser.add_argument('-rt','--reg_target',dest='reg_target',type=int,default=1) 51 | parser.add_argument('-nn','--n_neighbors',dest='n_neighbors',type=int,default=30) 52 | parser.add_argument('-vm','--velo_mode',dest='velo_mode',type=str,default='stochastic') 53 | parser.add_argument('-ts','--time_series',dest='time_series',type=int,default=0) 54 | parser.add_argument('-nc','--n_comps',dest='n_comps',type=int,default=50) 55 | 56 | args = parser.parse_args() 57 | 58 | if not os.path.exists(args.save_dir): 59 | os.mkdir(args.save_dir) 60 | 61 | adata = sc.read(os.path.join(args.root_dir,'{}.h5ad'.format(args.dataset))) 62 | 63 | if not args.reg_target: 64 | adata.var['is_target'] = True 65 | adata.var['is_reg'] = True 66 | 67 | target_genes = adata.var.index.values[adata.var['is_target']] 68 | reg_genes = adata.var.index.values[adata.var['is_reg']] 69 | 70 | if args.x_norm == 'zscore': 71 | 72 | print('Normalizing data: 0 mean, 1 SD') 73 | X_orig = adata[:,adata.var['is_reg']].X.toarray().copy() 74 | std = X_orig.std(0) 75 | std[std == 0] = 1 76 | X = torch.FloatTensor(X_orig-X_orig.mean(0))/std 77 | if 'De-noised' not in args.dataset: 78 | X = torch.clip(X,-5,5) 79 | 80 | Y_orig = adata[:,adata.var['is_target']].X.toarray().copy() 81 | std = Y_orig.std(0) 82 | std[std == 0] = 1 83 | Y = torch.FloatTensor(Y_orig-Y_orig.mean(0))/std 84 | if 'De-noised' not in args.dataset: 85 | Y = torch.clip(Y,-5,5) 86 | 87 | elif args.x_norm == 'magic_zscore': 88 | 89 | import magic 90 | from scipy.sparse import issparse 91 | 92 | X = adata.X.toarray() if issparse(adata.X) else adata.X 93 | X = pd.DataFrame(X,columns=adata.var.index.values) 94 | magic_operator = magic.MAGIC() 95 | X_magic = magic_operator.fit_transform(X).astype(np.float32) 96 | 97 | X_orig = X_magic.values[:,adata.var['is_reg'].values] 98 | std = X_orig.std(0) 99 | std[std == 0] = 1 100 | X = torch.FloatTensor(X_orig-X_orig.mean(0))/std 101 | # X = torch.clip(X,-5,5) 102 | 103 | Y_orig = X_magic.values[:,adata.var['is_target'].values] 104 | std = Y_orig.std(0) 105 | std[std == 0] = 1 106 | Y = torch.FloatTensor(Y_orig-Y_orig.mean(0))/std 107 | 108 | elif args.x_norm == 'fill_zscore': 109 | X_orig = adata[:,adata.var['is_reg']].X.toarray().copy() 110 | X_df = pd.DataFrame(X_orig) 111 | X_df[X_df < 
1e-9] = np.nan 112 | X_df = X_df.fillna(X_df.median()) 113 | X_orig = X_df.values 114 | std = X_orig.std(0) 115 | std[std == 0] = 1 116 | X = torch.FloatTensor(X_orig-X_orig.mean(0))/std 117 | # X = torch.clip(X,-5,5) 118 | 119 | Y_orig = adata[:,adata.var['is_target']].X.toarray().copy() 120 | Y_df = pd.DataFrame(Y_orig) 121 | Y_df[Y_df < 1e-9] = np.nan 122 | Y_df = Y_df.fillna(Y_df.median()) 123 | Y_orig = Y_df.values 124 | std = Y_orig.std(0) 125 | std[std == 0] = 1 126 | Y = torch.FloatTensor(Y_orig-Y_orig.mean(0))/std 127 | # Y = torch.clip(Y,-5,5) 128 | 129 | 130 | elif args.x_norm == 'to_count:zscore': 131 | 132 | print('Use counts: 0 mean, 1 SD') 133 | X_orig = adata[:,adata.var['is_reg']].X.toarray().copy() 134 | X_orig = 2**X_orig-1 135 | std = X_orig.std(0) 136 | std[std == 0] = 1 137 | X = torch.FloatTensor(X_orig-X_orig.mean(0))/std 138 | 139 | Y_orig = adata[:,adata.var['is_target']].X.toarray().copy() 140 | Y_orig = 2**Y_orig-1 141 | std = Y_orig.std(0) 142 | std[std == 0] = 1 143 | Y = torch.FloatTensor(Y_orig-Y_orig.mean(0))/std 144 | 145 | elif args.x_norm == 'zscore_pca': 146 | 147 | print('PCA + normalizing data: 0 mean, 1 SD') 148 | 149 | sc.tl.pca(adata,n_comps=100) 150 | adata.X = adata.obsm['X_pca'].dot(adata.varm['PCs'].T) 151 | X_orig = adata[:,adata.var['is_reg']].X.toarray().copy() 152 | std = X_orig.std(0) 153 | std[std == 0] = 1 154 | X = torch.FloatTensor(X_orig-X_orig.mean(0))/std 155 | X = torch.clip(X,-5,5) 156 | 157 | Y_orig = adata[:,adata.var['is_target']].X.toarray().copy() 158 | std = Y_orig.std(0) 159 | std[std == 0] = 1 160 | Y = torch.FloatTensor(Y_orig-Y_orig.mean(0))/std 161 | Y = torch.clip(Y,-5,5) 162 | 163 | elif args.x_norm == 'maxmin': 164 | 165 | X_orig = adata[:,adata.var['is_reg']].X.toarray().copy() 166 | X_min = X_orig.min(0) 167 | X_max = X_orig.max(0) 168 | X = torch.FloatTensor((X_orig-X_min)/(X_max-X_min)) 169 | X -= X.mean(0) 170 | 171 | Y_orig = adata[:,adata.var['is_target']].X.toarray().copy() 172 | Y_min = Y_orig.min(0) 173 | Y_max = Y_orig.max(0) 174 | Y = torch.FloatTensor((Y_orig-Y_min)/(Y_max-Y_min)) 175 | Y -= Y.mean(0) 176 | 177 | else: 178 | assert args.x_norm == 'none' 179 | X = torch.FloatTensor(adata[:,adata.var['is_reg']].X.toarray()) 180 | Y = torch.FloatTensor(adata[:,adata.var['is_target']].X.toarray()) 181 | 182 | print('# of Regs: {}, # of Targets: {}'.format(X.shape[1],Y.shape[1])) 183 | 184 | print('Constructing DAG...') 185 | 186 | if 'De-noised' in args.dataset: 187 | sc.pp.normalize_total(adata, target_sum=1e4) 188 | sc.pp.log1p(adata) 189 | 190 | sc.pp.scale(adata) 191 | A = construct_dag(adata,dynamics=args.dynamics,ptloc=args.ptloc,proba=args.proba, 192 | n_neighbors=args.n_neighbors,velo_mode=args.velo_mode, 193 | use_time=args.time_series,n_comps=args.n_comps) 194 | A = torch.FloatTensor(A) 195 | AX = calculate_diffusion_lags(A,X,args.lag) 196 | 197 | if args.reg_target: 198 | AY = calculate_diffusion_lags(A,Y,args.lag) 199 | else: 200 | AY = None 201 | 202 | dir_name = '{}.seed{}.h{}.{}.lag{}.{}'.format(args.name,args.seed,args.hidden,args.penalty,args.lag,args.dynamics) 203 | 204 | if not os.path.exists(os.path.join(args.save_dir,dir_name)): 205 | os.mkdir(os.path.join(args.save_dir,dir_name)) 206 | 207 | ray.init(object_store_memory=10**9) 208 | 209 | total_start = time.time() 210 | lam_list = np.logspace(args.lam_start, args.lam_end, num=args.num_lambdas).tolist() 211 | 212 | config = {'name': args.name, 213 | 'AX': AX, 214 | 'AY': AY, 215 | 'Y': Y, 216 | 'seed': args.seed, 217 | 'lr': 
args.learning_rate, 218 | 'lam': tune.grid_search(lam_list), 219 | 'lam_ridge': args.lam_ridge, 220 | 'penalty': args.penalty, 221 | 'lag': args.lag, 222 | 'hidden': [args.hidden], 223 | 'max_iter': args.max_iter, 224 | 'device': args.device, 225 | 'lookback': 5, 226 | 'check_every': args.check_every, 227 | 'verbose': True, 228 | 'dynamics': args.dynamics, 229 | 'results_dir': args.save_dir, 230 | 'dir_name': dir_name, 231 | 'reg_target': args.reg_target} 232 | 233 | ngpu = 0.2 if (args.device == 'cuda') else 0 234 | resources_per_trial = {"cpu": 1, "gpu": ngpu, "memory": 2 * 1024 * 1024 * 1024} 235 | analysis = tune.run(train_model,resources_per_trial=resources_per_trial,config=config, 236 | local_dir=os.path.join(args.root_dir,args.save_dir)) 237 | 238 | target_dir = os.path.join(args.save_dir, dir_name) 239 | base_dir = args.save_dir 240 | move_files(base_dir, target_dir) 241 | 242 | # aggregate results 243 | lam_list = [np.round(lam,4) for lam in lam_list] 244 | all_lags = load_gc_interactions(args.name,args.save_dir,lam_list,hidden_dim=args.hidden, 245 | lag=args.lag,penalty=args.penalty, 246 | dynamics=args.dynamics,seed=args.seed,ignore_lag=False) 247 | 248 | gc_mat = estimate_interactions(all_lags,lag=args.lag) 249 | gc_df = pd.DataFrame(gc_mat.cpu().data.numpy(),index=target_genes,columns=reg_genes) 250 | gc_df.to_csv(os.path.join(args.save_dir,'{}.{}.velorama.interactions.tsv'.format(args.name,args.dynamics)),sep='\t') 251 | 252 | lag_mat = estimate_lags(all_lags,lag=args.lag) 253 | lag_df = pd.DataFrame(lag_mat.cpu().data.numpy(),index=target_genes,columns=reg_genes) 254 | lag_df.to_csv(os.path.join(args.save_dir,'{}.{}.velorama.lags.tsv'.format(args.name,args.dynamics)),sep='\t') 255 | 256 | print('Total time:',time.time()-total_start) 257 | np.savetxt(os.path.join(args.save_dir,dir_name + '.time.txt'),np.array([time.time()-total_start])) 258 | 259 | if __name__ == "__main__": 260 | execute_cmdline() -------------------------------------------------------------------------------- /velorama/train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | ### Authors: Anish Mudide (amudide), Alex Wu (alexw16), Rohit Singh (rs239) 4 | ### 2022 5 | ### MIT Licence 6 | ### 7 | 8 | 9 | import numpy as np 10 | import torch 11 | import torch.nn as nn 12 | from copy import deepcopy 13 | import time 14 | 15 | from .models import * 16 | from .utils import * 17 | 18 | def train_model(config): 19 | 20 | AX = config["AX"] 21 | AY = config["AY"] 22 | 23 | Y = config["Y"] 24 | 25 | name = config["name"] 26 | 27 | seed = config["seed"] 28 | lr = config["lr"] 29 | lam = config["lam"] 30 | lam_ridge = config["lam_ridge"] 31 | penalty = config["penalty"] 32 | lag = config["lag"] 33 | hidden = config["hidden"] 34 | max_iter = config["max_iter"] 35 | device = config["device"] 36 | lookback = config["lookback"] 37 | check_every = config["check_every"] 38 | verbose = config["verbose"] 39 | dynamics = config['dynamics'] 40 | 41 | results_dir = config['results_dir'] 42 | dir_name = config['dir_name'] 43 | reg_target = config['reg_target'] 44 | 45 | np.random.seed(seed) 46 | torch.manual_seed(seed) 47 | 48 | file_name = '{}.seed{}.lam{}.h{}.{}.lag{}.{}'.format(name,seed,np.round(lam,4), 49 | hidden[0],penalty,lag,dynamics) 50 | gc_path1 = os.path.join(results_dir,dir_name,file_name + '.pt') 51 | gc_path2 = os.path.join(results_dir,dir_name,file_name + '.ignore_lag.pt') 52 | 53 | if not os.path.exists(gc_path1) and not os.path.exists(gc_path2): 
54 | 55 | num_regs = AX.shape[-1] 56 | num_targets = Y.shape[1] 57 | 58 | AX = AX.to(device) 59 | Y = Y.to(device) 60 | 61 | if reg_target: 62 | vmlp = VeloramaMLPTarget(num_targets, num_regs, lag=lag, hidden=hidden, 63 | device=device, activation='relu') 64 | AY = AY.to(device) 65 | else: 66 | vmlp = VeloramaMLP(num_targets, num_regs, lag=lag, hidden=hidden, 67 | device=device, activation='relu') 68 | 69 | vmlp.to(device) 70 | 71 | '''Train model with ISTA.''' 72 | lag = vmlp.lag 73 | loss_fn = nn.MSELoss(reduction='none') 74 | train_loss_list = [] 75 | 76 | # For early stopping. 77 | best_it = None 78 | best_loss = np.inf 79 | best_model = None 80 | 81 | # Calculate smooth error. 82 | if reg_target: 83 | preds = vmlp(AX,AY) 84 | else: 85 | preds = vmlp(AX) 86 | 87 | loss = loss_fn(preds,Y).mean(0).sum() 88 | 89 | # print('LOSS:---',loss) 90 | 91 | ridge = ridge_regularize(vmlp, lam_ridge) 92 | smooth = loss + ridge 93 | 94 | variable_usage_list = [] 95 | loss_list = [] 96 | 97 | # For early stopping. 98 | train_loss_list = [] 99 | best_it = None 100 | best_loss = np.inf 101 | best_model = None 102 | 103 | for it in range(max_iter): 104 | 105 | start = time.time() 106 | 107 | # Take gradient step. 108 | smooth.backward() 109 | for param in vmlp.parameters(): 110 | param.data = param - lr * param.grad 111 | 112 | # Take prox step. 113 | if lam > 0: 114 | # for net in vmlp.networks: 115 | if reg_target: 116 | prox_update_target(vmlp, lam, lr, penalty) 117 | else: 118 | prox_update(vmlp, lam, lr, penalty) 119 | 120 | vmlp.zero_grad() 121 | 122 | # Calculate loss for next iteration. 123 | if reg_target: 124 | preds = vmlp(AX,AY) 125 | else: 126 | preds = vmlp(AX) 127 | 128 | loss = loss_fn(preds,Y).mean(0).sum() 129 | 130 | ridge = ridge_regularize(vmlp, lam_ridge) 131 | smooth = loss + ridge 132 | 133 | # Check progress. 134 | if (it + 1) % check_every == 0: 135 | 136 | if reg_target: 137 | nonsmooth = regularize_target(vmlp, lam, penalty) 138 | else: 139 | nonsmooth = regularize(vmlp, lam, penalty) 140 | mean_loss = (smooth + nonsmooth).detach()/Y.shape[1] 141 | 142 | variable_usage = torch.mean(vmlp.GC(ignore_lag=False).float()) 143 | variable_usage_list.append(variable_usage) 144 | loss_list.append(mean_loss) 145 | 146 | # Check for early stopping. 147 | if mean_loss < best_loss: 148 | best_loss = mean_loss 149 | best_it = it 150 | best_model = deepcopy(vmlp) 151 | 152 | if verbose: 153 | print('Lam={}: Iter {}, {} sec'.format(lam,it+1,np.round(time.time()-start,2)), 154 | '-----','Loss: %.2f' % mean_loss,', Variable usage = %.2f%%' % (100 * variable_usage)) # , 155 | # '|||','%.3f' % loss_crit,'%.3f' % variable_usage_crit) 156 | 157 | elif (it - best_it) == lookback * check_every: 158 | if verbose: 159 | print('EARLY STOP: Lam={}, Iter {}'.format(lam,it + 1)) 160 | break 161 | 162 | if verbose: 163 | print('Lam={}: Completed in {} iterations'.format(lam,it+1)) 164 | 165 | # Restore best model. 
166 | restore_parameters(vmlp, best_model) 167 | 168 | if not os.path.exists(results_dir): 169 | os.mkdir(results_dir) 170 | if not os.path.exists(os.path.join(results_dir,dir_name)): 171 | os.mkdir(os.path.join(results_dir,dir_name)) 172 | 173 | file_name = '{}.seed{}.lam{}.h{}.{}.lag{}.{}'.format(name,seed,np.round(lam,4), 174 | hidden[0],penalty,lag,dynamics) 175 | GC_lag = vmlp.GC(threshold=False, ignore_lag=False).cpu() 176 | torch.save(GC_lag, os.path.join(results_dir,dir_name,file_name + '.pt')) 177 | 178 | GC_lag = vmlp.GC(threshold=False, ignore_lag=True).cpu() 179 | torch.save(GC_lag, os.path.join(results_dir,dir_name,file_name + '.ignore_lag.pt')) -------------------------------------------------------------------------------- /velorama/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import torch 4 | from scipy.stats import f 5 | from scipy.sparse import csr_matrix 6 | import statistics 7 | import scanpy as sc 8 | import scanpy.external as sce 9 | from anndata import AnnData 10 | import cellrank as cr 11 | import scvelo as scv 12 | import schema 13 | from torch.nn.functional import normalize 14 | import shutil 15 | 16 | from cellrank.kernels import VelocityKernel 17 | 18 | import torch 19 | import torch.nn as nn 20 | 21 | def construct_dag(adata,dynamics='rna_velocity',ptloc='pseudotime',velo_mode='stochastic',proba=True, 22 | n_neighbors=30,n_comps=50,use_time=False): 23 | 24 | """Constructs the adjacency matrix for a DAG. 25 | Parameters 26 | ---------- 27 | adata: 'AnnData' 28 | AnnData object with rows corresponding to cells and columns corresponding 29 | to genes. 30 | dynamics: {'rna_velocity','pseudotime','pseudotime_precomputed'} 31 | (default: rna_velocity) 32 | Dynamics used to orient and/or weight edges in the DAG of cells. 33 | If 'pseudotime_precomputed', the precomputed pseudotime values must be 34 | included as an observation category named 'pseudotime' in the included 35 | AnnData object (e.g., adata.obs['pseudotime'] = [list of float]). 36 | velo_mode: {'stochastic','deterministic','dynamical'} (default: 'stochastic') 37 | RNA velocity estimation using either the steady-state/deterministic, 38 | stochastic, or dynamical model of transcriptional dynamics from scVelo 39 | (Bergen et al., 2020). 40 | proba: 'bool' (default: True) 41 | Whether to use the transition probabilities from CellRank (Lange et al., 2022) 42 | in weighting the edges of the DAG or to discretize these probabilities by 43 | retaining only the top half of edges per cell. 44 | n_neighbors: 'int' (default: 15) 45 | Number of nearest neighbors to use in constructing a k-nearest 46 | neighbor graph for constructing a DAG if a custom DAG is not provided. 47 | n_comps: 'int' (default: 50) 48 | Number of principle components to compute and use for representing 49 | the gene expression profiles of cells. 50 | use_time: 'bool' (default: False) 51 | Whether to integrate time stamps in constructing the DAG. If True, time 52 | stamps must be included as an observation category named 'time' in the 53 | included AnnData object (e.g., adata.obs['time'] = [list of float]). 
54 | """ 55 | 56 | sc.tl.pca(adata, n_comps=n_comps, svd_solver='arpack') 57 | if use_time: 58 | sqp = schema.SchemaQP(min_desired_corr=0.9,mode='affine', 59 | params= {'decomposition_model': 'pca', 60 | 'num_top_components': adata.obsm['X_pca'].shape[1]}) 61 | adata.obsm['X_rep'] = sqp.fit_transform(adata.obsm['X_pca'],[adata.obs['time']], 62 | ['numeric'],[1]) 63 | else: 64 | adata.obsm['X_rep'] = adata.obsm['X_pca'] 65 | 66 | if dynamics == 'pseudotime': 67 | A = construct_dag_pseudotime(adata.obsm['X_rep'],adata.uns['iroot'], 68 | n_neighbors=n_neighbors).T 69 | 70 | elif dynamics == 'pseudotime_precomputed': 71 | A = construct_dag_pseudotime(adata.obsm['X_rep'],adata.uns['iroot'], 72 | n_neighbors=n_neighbors, 73 | pseudotime_algo='precomputed', 74 | precomputed_pseudotime=adata.obs[ptloc].values).T 75 | 76 | elif dynamics == 'rna_velocity': 77 | scv.pp.moments(adata, n_neighbors=n_neighbors, use_rep='X_rep') 78 | if velo_mode == 'dynamical': 79 | scv.tl.recover_dynamics(adata) 80 | scv.tl.velocity(adata,mode=velo_mode) 81 | scv.tl.velocity_graph(adata) 82 | vk = VelocityKernel(adata).compute_transition_matrix() 83 | A = vk.transition_matrix 84 | A = A.toarray() 85 | 86 | # if proba is False (0), it won't use the probabilistic 87 | # transition matrix 88 | if not proba: 89 | for i in range(len(A)): 90 | nzeros = [] 91 | for j in range(len(A)): 92 | if A[i][j] > 0: 93 | nzeros.append(A[i][j]) 94 | m = statistics.median(nzeros) 95 | for j in range(len(A)): 96 | if A[i][j] < m: 97 | A[i][j] = 0 98 | else: 99 | A[i][j] = 1 100 | 101 | for i in range(len(A)): 102 | for j in range(len(A)): 103 | if A[i][j] > 0 and A[j][i] > 0 and A[i][j] > A[j][i]: 104 | A[j][i] = 0 105 | 106 | A = construct_S(torch.FloatTensor(A)) 107 | 108 | return A 109 | 110 | def construct_dag_pseudotime(joint_feature_embeddings,iroot,n_neighbors=15,pseudotime_algo='dpt', 111 | precomputed_pseudotime=None): 112 | 113 | """Constructs the adjacency matrix for a DAG using pseudotime. 114 | Parameters 115 | ---------- 116 | joint_feature_embeddings: 'numpy.ndarray' (default: None) 117 | Matrix of low dimensional embeddings with rows corresponding 118 | to observations and columns corresponding to feature embeddings 119 | for constructing a DAG if a custom DAG is not provided. 120 | iroot: 'int' (default: None) 121 | Index of root cell for inferring pseudotime for constructing a DAG 122 | if a custom DAG is not provided. 123 | n_neighbors: 'int' (default: 15) 124 | Number of nearest neighbors to use in constructing a k-nearest 125 | neighbor graph for constructing a DAG if a custom DAG is not provided. 126 | pseudotime_algo: {'dpt','palantir'} 127 | Pseudotime algorithm to use for constructing a DAG if a custom DAG 128 | is not provided. 'dpt' and 'palantir' perform the diffusion pseudotime 129 | (Haghverdi et al., 2016) and Palantir (Setty et al., 2019) algorithms, 130 | respectively. 131 | precomputed_pseudotime: 'numpy.ndarray' or List (default: None) 132 | Precomputed pseudotime values for all cells. 
133 | """ 134 | 135 | pseudotime,knn_graph = infer_knngraph_pseudotime(joint_feature_embeddings,iroot, 136 | n_neighbors=n_neighbors,pseudotime_algo=pseudotime_algo, 137 | precomputed_pseudotime=precomputed_pseudotime) 138 | dag_adjacency_matrix = dag_orient_edges(knn_graph,pseudotime) 139 | 140 | return dag_adjacency_matrix 141 | 142 | def infer_knngraph_pseudotime(joint_feature_embeddings,iroot,n_neighbors=15,pseudotime_algo='dpt', 143 | precomputed_pseudotime=None): 144 | 145 | adata = AnnData(joint_feature_embeddings) 146 | adata.obsm['X_joint'] = joint_feature_embeddings 147 | adata.uns['iroot'] = iroot 148 | 149 | if pseudotime_algo == 'dpt': 150 | sc.pp.neighbors(adata,use_rep='X_joint',n_neighbors=n_neighbors) 151 | sc.tl.dpt(adata) 152 | adata.obs['pseudotime'] = adata.obs['dpt_pseudotime'].values 153 | knn_graph = adata.obsp['distances'].astype(bool).astype(float) 154 | elif pseudotime_algo == 'precomputed': 155 | sc.pp.neighbors(adata,use_rep='X_joint',n_neighbors=n_neighbors) 156 | adata.obs['pseudotime'] = precomputed_pseudotime 157 | knn_graph = adata.obsp['distances'].astype(bool).astype(float) 158 | elif pseudotime_algo == 'palantir': 159 | sc.pp.neighbors(adata,use_rep='X_joint',n_neighbors=n_neighbors) 160 | sce.tl.palantir(adata, knn=n_neighbors,use_adjacency_matrix=True, 161 | distances_key='distances') 162 | pr_res = sce.tl.palantir_results(adata, 163 | early_cell=adata.obs.index.values[adata.uns['iroot']], 164 | ms_data='X_palantir_multiscale') 165 | adata.obs['pseudotime'] = pr_res.pseudotime 166 | knn_graph = adata.obsp['distances'].astype(bool).astype(float) 167 | 168 | return adata.obs['pseudotime'].values,knn_graph 169 | 170 | def dag_orient_edges(adjacency_matrix,pseudotime): 171 | 172 | A = adjacency_matrix.astype(bool).astype(float) 173 | D = -1*np.sign(pseudotime[:,None] - pseudotime).T 174 | D = (D == 1).astype(float) 175 | D = (A.toarray()*D).astype(bool).astype(float) 176 | 177 | return D 178 | 179 | def construct_S(D): 180 | 181 | S = D.clone() 182 | D_sum = D.sum(0) 183 | D_sum[D_sum == 0] = 1 184 | 185 | S = (S/D_sum) 186 | S = S.T 187 | 188 | return S 189 | 190 | def normalize_adjacency(D): 191 | 192 | S = D.clone() 193 | D_sum = D.sum(0) 194 | D_sum[D_sum == 0] = 1 195 | 196 | S = (S/D_sum) 197 | 198 | return S 199 | 200 | def seq2dag(N): 201 | A = torch.zeros(N, N) 202 | for i in range(N - 1): 203 | A[i][i + 1] = 1 204 | return A 205 | 206 | def activation_helper(activation, dim=None): 207 | if activation == 'sigmoid': 208 | act = nn.Sigmoid() 209 | elif activation == 'tanh': 210 | act = nn.Tanh() 211 | elif activation == 'relu': 212 | act = nn.ReLU() 213 | elif activation == 'leakyrelu': 214 | act = nn.LeakyReLU() 215 | elif activation is None: 216 | def act(x): 217 | return x 218 | else: 219 | raise ValueError('unsupported activation: %s' % activation) 220 | return act 221 | 222 | def calculate_diffusion_lags(A,X,lag): 223 | 224 | if A == "linear": 225 | A = seq2dag(X.shape[1]) 226 | 227 | ax = [] 228 | cur = A 229 | for _ in range(lag): 230 | ax.append(torch.matmul(cur, X)) 231 | cur = torch.matmul(A, cur) 232 | for i in range(len(cur)): 233 | cur[i][i] = 0 234 | 235 | return torch.stack(ax) 236 | 237 | def load_gc_interactions(name,results_dir,lam_list,hidden_dim=16,lag=5,penalty='H', 238 | dynamics='rna_velocity',seed=0,ignore_lag=False): 239 | 240 | config_name = '{}.seed{}.h{}.{}.lag{}.{}'.format(name,seed,hidden_dim,penalty,lag,dynamics) 241 | 242 | all_lags = [] 243 | for lam in lam_list: 244 | if ignore_lag: 245 | file_name = 
'{}.seed{}.lam{}.h{}.{}.lag{}.{}.ignore_lag.pt'.format(name,seed,lam,hidden_dim,penalty,lag,dynamics) 246 | file_path = os.path.join(results_dir,config_name,file_name) 247 | gc_lag = torch.load(file_path) 248 | gc_lag = gc_lag.unsqueeze(-1) 249 | else: 250 | file_name = '{}.seed{}.lam{}.h{}.{}.lag{}.{}.pt'.format(name,seed,lam,hidden_dim,penalty,lag,dynamics) 251 | file_path = os.path.join(results_dir,config_name,file_name) 252 | gc_lag = torch.load(file_path) 253 | all_lags.append(gc_lag.detach()) 254 | 255 | all_lags = torch.stack(all_lags) 256 | 257 | return all_lags 258 | 259 | def lor(x, y): 260 | return x + y 261 | 262 | def estimate_interactions(all_lags,lag=5,lower_thresh=0.01,upper_thresh=0.95, 263 | binarize=False,l2_norm=False): 264 | 265 | all_interactions = [] 266 | for i in range(len(all_lags)): 267 | for j in range(lag): 268 | 269 | nnz_percent = (all_lags[i,:,:,j] != 0).float().mean().data.numpy() 270 | 271 | if nnz_percent > lower_thresh and nnz_percent < upper_thresh: 272 | interactions = all_lags[i,:,:,j] 273 | 274 | if l2_norm: 275 | interactions = normalize(interactions,dim=(0,1)) 276 | if binarize: 277 | interactions = (interactions != 0).float() 278 | 279 | all_interactions.append(interactions) 280 | return torch.stack(all_interactions).mean(0) 281 | 282 | def estimate_lags(all_lags,lag,lower_thresh=0.01,upper_thresh=1.): 283 | 284 | retained_interactions = [] 285 | for i in range(len(all_lags)): 286 | nnz_percent = (all_lags[i] != 0).float().mean().data.numpy() 287 | if nnz_percent > lower_thresh and nnz_percent < upper_thresh: 288 | retained_interactions.append(all_lags[i]) 289 | retained_interactions = torch.stack(retained_interactions) 290 | 291 | est_lags = normalize(retained_interactions,p=1,dim=-1).mean(0) 292 | return (est_lags*(torch.arange(lag)+1)).sum(-1) 293 | 294 | def move_files(base_dir, target_dir): 295 | os.makedirs(target_dir, exist_ok=True) # Ensure the target directory exists 296 | # Walk through all directories recursively 297 | for root, dirs, files in os.walk(base_dir): 298 | if 'train' in root: # Check if ‘train’ is in the directory path 299 | for file in files: 300 | if file.endswith('.pt'): # Check if the file ends with ‘.pt’ 301 | source_path = os.path.join(root, file) 302 | shutil.move(source_path, target_dir) # Move the file 303 | print(f"Moved: {source_path} to {target_dir}") --------------------------------------------------------------------------------
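Once ``run.py`` finishes, the inferred interactions and lags are written as tab-separated tables that can be read back with pandas for downstream analysis. The snippet below is a minimal sketch assuming the default ``--name`` (``velorama_run``), ``--save_dir`` (``./results``), and ``-dyn`` (``pseudotime``) settings; rows are target genes and columns are candidate regulators, matching how ``gc_df`` and ``lag_df`` are saved in ``velorama/run.py``. ::

    import pandas as pd

    # Granger-causal interaction scores: rows = target genes, columns = candidate regulators.
    interactions = pd.read_csv('./results/velorama_run.pseudotime.velorama.interactions.tsv',
                               sep='\t', index_col=0)

    # Estimated interaction lags (a proxy for regulator speed), same orientation.
    lags = pd.read_csv('./results/velorama_run.pseudotime.velorama.lags.tsv',
                       sep='\t', index_col=0)

    # Rank candidate regulator-target pairs by interaction score.
    pairs = (interactions.stack()
             .rename_axis(['target', 'regulator'])
             .reset_index(name='score')
             .sort_values('score', ascending=False))
    print(pairs.head(20))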