├── .DS_Store
├── .gitignore
├── LICENSE
├── README.rst
├── bin
│   └── velorama
├── datasets
│   ├── .DS_Store
│   ├── dataset_A.h5ad
│   ├── dataset_B.h5ad
│   ├── dataset_C.h5ad
│   └── dataset_D.h5ad
├── pyproject.toml
└── velorama
    ├── __init__.py
    ├── models.py
    ├── run.py
    ├── train.py
    └── utils.py
/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rs239/velorama/44bcd5d52737a6c2266d397e903d10e98df8e4d5/.DS_Store
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 | results/
131 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 Rohit Singh
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.rst:
--------------------------------------------------------------------------------
1 |
2 | Velorama - Gene regulatory network inference for RNA velocity and pseudotime data
3 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
4 |
5 | .. image:: http://cb.csail.mit.edu/cb/velorama/velorama_v5.png
6 | :width: 600
7 |
8 | Velorama is a Python library for inferring gene regulatory networks from single-cell RNA-seq data.
9 |
10 | **It is designed for the case where RNA velocity or pseudotime data is available.**
11 | Here are some of the analyses that you can do with Velorama:
12 |
13 | - infer temporally causal regulator-target links from RNA velocity cell-to-cell transition matrices.
14 | - infer regulatory interactions over branching/merging trajectories using just pseudotime data, without having to manually separate the branches.
15 | - estimate the relative speed of various regulators (i.e., how quickly they act on their targets).
16 |
17 | Velorama offers support for both pseudotime and RNA velocity data.
18 |
19 |
20 | Velorama is based on a Granger causal approach and models the differentiation landscape as a directed acyclic graph (DAG) of cells, rather than as a linear total ordering required by previous approaches.
21 |
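   | For intuition (an illustrative sketch, not tool output): if a progenitor cell ``P`` branches into two lineages, the DAG of cells keeps the branches separate, e.g. with edges ::
   |
   |     P → B1a → B1b
   |     P → B2a → B2b
   |
   | whereas a pseudotime-based total ordering would be forced to interleave cells from the two branches into a single sequence even though neither branch precedes the other.
   |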
22 | =================
23 | Installation
24 | =================
25 |
26 | To install Velorama, follow the instructions below.
27 |
28 | From within a Conda/Mamba environment: ::
29 |
30 | git clone https://github.com/rs239/velorama.git
31 | cd ./velorama
32 | export SKLEARN_ALLOW_DEPRECATED_SKLEARN_PACKAGE_INSTALL=True
33 | pip install .
34 |
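   | After installation, the ``velorama`` console script declared in ``pyproject.toml`` should be available on your path. As a quick sanity check (the help flag is provided automatically by ``argparse``), list the available options with: ::
   |
   |     velorama --help
   |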
35 | =================
36 | API Example Usage
37 | =================
38 |
39 | Velorama is currently offered as a command-line tool that operates on ``AnnData`` objects. [Ed. Note: We are working on a clean API compatible with the scanpy ecosystem.] First, prepare an AnnData object of the dataset to be analyzed with Velorama. If you have RNA velocity data, make sure it is stored in ``layers`` as required by `CellRank <https://cellrank.readthedocs.io/>`_ and `scVelo <https://scvelo.readthedocs.io/>`_, so that transition probabilities can be computed. We recommend performing standard single-cell normalization (i.e., normalize counts to the median per-cell transcript count and log-transform the normalized counts plus a pseudocount). Next, annotate the candidate regulators and targets in the ``var`` DataFrame of the ``AnnData`` object as follows. ::
40 |
41 | adata.var['is_reg'] = [n in regulator_genes for n in adata.var.index.values]
42 | adata.var['is_target'] = [n in target_genes for n in adata.var.index.values]
43 |
44 | Here, ``regulator_genes`` is the set of gene symbols or IDs for the candidate regulators, and ``target_genes`` is the corresponding set for the candidate target genes. The annotated AnnData object should then be saved as ``{dataset}.h5ad``.
45 |
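   | For example (a minimal sketch with hypothetical gene and file names; substitute your own regulator and target lists), the lists above could be defined and the annotated object written out with standard ``scanpy``/``anndata`` calls. ::
   |
   |     import scanpy as sc
   |
   |     adata = sc.read_h5ad('raw_counts.h5ad')    # hypothetical input file
   |     regulator_genes = ['Pdx1', 'Neurog3', 'Pax4']    # hypothetical candidate regulators
   |     target_genes = [g for g in adata.var.index.values if g not in regulator_genes]
   |
   |     adata.var['is_reg'] = [n in regulator_genes for n in adata.var.index.values]
   |     adata.var['is_target'] = [n in target_genes for n in adata.var.index.values]
   |     adata.write('mydata.h5ad')    # i.e. {dataset}.h5ad with dataset='mydata'
   |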
46 | We provide an example dataset here: `mouse endocrinogenesis `_. This dataset is from the scVelo vignette and is based on the study by `Bergen et al. (2020) `_.
47 |
48 | The command below runs Velorama, saving the inferred Granger causal interactions and interaction speeds to the specified output directory. ::
49 |
50 | velorama -ds $dataset -dyn $dynamics -dev $device -l $L -hd $hidden -rd $rd
51 |
52 | Here, ``$dataset`` is the name of the dataset associated with the saved AnnData object. ``$dynamics`` can be "rna_velocity" or "pseudotime", depending on which data the user desires to use to construct the DAG. ``$device`` is chosen to be either "cuda" or "cpu". ``$rd`` is the name of the root directory that contains the saved AnnData object and where the outputs will be saved. Among the optional arguments, ``$L`` refers to the maximum number of lags to consider (default=5). ``$hidden`` indicates the dimensionality of the hidden layers (default=32).
53 |
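   | For example, with the hypothetical ``mydata.h5ad`` prepared above stored in the current directory, RNA velocity dynamics, a GPU, and the default lag and hidden dimensionality, the invocation would be: ::
   |
   |     velorama -ds mydata -dyn rna_velocity -dev cuda -l 5 -hd 32 -rd .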
54 |
55 | We encourage you to report issues at our `GitHub page`_; you can also open pull requests there to contribute your enhancements.
56 | If Velorama is useful for your research, please consider citing `bioRxiv (2022)`_.
57 |
58 | .. _bioRxiv (2022): https://www.biorxiv.org/content/10.1101/2022.10.18.512766v3
59 | .. _Github page: https://github.com/rs239/velorama
60 |
--------------------------------------------------------------------------------
/bin/velorama:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 |
3 | import velorama
4 |
5 | velorama.execute_cmdline()
6 |
7 |
8 |
9 |
10 |
--------------------------------------------------------------------------------
/datasets/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rs239/velorama/44bcd5d52737a6c2266d397e903d10e98df8e4d5/datasets/.DS_Store
--------------------------------------------------------------------------------
/datasets/dataset_A.h5ad:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rs239/velorama/44bcd5d52737a6c2266d397e903d10e98df8e4d5/datasets/dataset_A.h5ad
--------------------------------------------------------------------------------
/datasets/dataset_B.h5ad:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rs239/velorama/44bcd5d52737a6c2266d397e903d10e98df8e4d5/datasets/dataset_B.h5ad
--------------------------------------------------------------------------------
/datasets/dataset_C.h5ad:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rs239/velorama/44bcd5d52737a6c2266d397e903d10e98df8e4d5/datasets/dataset_C.h5ad
--------------------------------------------------------------------------------
/datasets/dataset_D.h5ad:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rs239/velorama/44bcd5d52737a6c2266d397e903d10e98df8e4d5/datasets/dataset_D.h5ad
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [project]
2 | name = "velorama"
3 | version = "0.0.3"
4 | description = "Gene regulatory network inference for RNA velocity and pseudotime data"
5 | authors = [
6 |     { name = "Anish Mudide" }, { name = "Alex Wu" }, { name = "Rohit Singh" }
7 | ]
8 | license = { text = "MIT" }
9 | dependencies = [
10 | "numpy>=1.26.4",
11 | "scipy>=1.11.4",
12 | "pandas>=2.2.3",
13 | "scikit-learn>=1.5.2",
14 | "cellrank>=2.0.6",
15 | "scvelo>=0.3.2",
16 | "scanpy>=1.10.4",
17 | "anndata>=0.11.1",
18 | "torch==1.13.0",
19 | "ray[tune]==2.6.0",
20 | "matplotlib>=3.9.2",
21 | "h5py>=3.12.1",
22 | "tqdm>=4.67.1",
23 | "networkx>=3.4.2",
24 | "seaborn>=0.13.2",
25 | "statsmodels>=0.14.4",
26 | "schema_learn>=0.1.5.5",
27 | "umap-learn>=0.5.7"
28 | ]
29 | requires-python = ">=3.10"
30 |
31 | [build-system]
32 | requires = [
33 | "setuptools>=75.1.0",
34 | "wheel>=0.44.0",
35 | "pip>=24.2",
36 | "build"
37 | ]
38 | build-backend = "setuptools.build_meta"
39 |
40 | [tool.setuptools]
41 | packages = ["velorama"]
42 | zip-safe = false
43 |
44 | [project.scripts]
45 | velorama = "velorama:execute_cmdline"
--------------------------------------------------------------------------------
/velorama/__init__.py:
--------------------------------------------------------------------------------
1 |
2 | from .models import VeloramaMLP
3 | from .train import train_model
4 | from .run import execute_cmdline
5 |
6 | __all__ = ['VeloramaMLP', 'train_model', 'execute_cmdline']
7 |
8 |
9 | import pkgutil, pathlib, importlib
10 |
11 | # from pkgutil import iter_modules
12 | # from pathlib import Path
13 | # from importlib import import_module
14 |
15 | # https://julienharbulot.com/python-dynamical-import.html
16 | # iterate through the modules in the current package
17 | #
18 | # # package_dir = pathlib.Path(__file__).resolve().parent
19 | # # for (_, module_name, _) in pkgutil.iter_modules([package_dir]):
20 | # # if 'datasets' in module_name:
21 | # # module = importlib.import_module(f"{__name__}.{module_name}")
22 |
--------------------------------------------------------------------------------
/velorama/models.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 |
3 | ### Authors: Anish Mudide (amudide), Alex Wu (alexw16), Rohit Singh (rs239)
4 | ### 2022
5 | ### MIT Licence
6 | ###
7 | ### Credit: parts of this code make use of the code from Tank et al.'s "Neural Granger Causality"
8 | ### - https://github.com/iancovert/Neural-GC
9 | ### - https://arxiv.org/abs/1802.05842
10 |
11 | import torch
12 | import torch.nn as nn
13 |
14 | from .utils import activation_helper
15 |
16 | class VeloramaMLP(nn.Module):
17 |
18 | def __init__(self, n_targets, n_regs, lag, hidden, device, activation):
19 | super(VeloramaMLP, self).__init__()
20 | self.activation = activation_helper(activation)
21 | self.hidden = hidden
22 | self.lag = lag
23 | self.device = device
24 |
25 | # set up first layer
26 | layer = nn.Conv1d(n_regs, hidden[0]*n_targets, lag)
27 | modules = [layer]
28 |
29 | # set up subsequent layers
30 | for d_in, d_out in zip(hidden, hidden[1:] + [1]):
31 | layer = nn.Conv1d(d_in*n_targets, d_out*n_targets, 1, groups=n_targets)
32 | modules.append(layer)
33 |
34 | # Register parameters.
35 | self.layers = nn.ModuleList(modules)
36 |
37 | def forward(self,AX):
38 |
39 | # first layer
40 | ret = 0
41 | for i in range(self.lag):
42 | ret = ret + torch.matmul(AX[i], self.layers[0].weight[:, :, self.lag - 1 - i].T)
43 | ret = ret + self.layers[0].bias
44 |
45 | # subsequent layers
46 | ret = ret.T
47 | for i, fc in enumerate(self.layers):
48 | if i == 0:
49 | continue
50 | ret = self.activation(ret)
51 | ret = fc(ret)
52 |
53 | return ret.T
54 |
55 | def GC(self, threshold=True, ignore_lag=True):
56 | '''
57 | Extract learned Granger causality.
58 |
59 | Args:
60 | threshold: return norm of weights, or whether norm is nonzero.
61 | ignore_lag: if true, calculate norm of weights jointly for all lags.
62 |
63 | Returns:
64 |       GC: (n_targets x n_regs) or (n_targets x n_regs x lag) matrix. In the first
65 |         case, entry (i, j) indicates whether regulator j is Granger causal for
66 |         target i. In the second case, entry (i, j, k) indicates whether it is
67 |         Granger causal at lag k.
68 | '''
69 |
70 | W = self.layers[0].weight
71 | W = W.reshape(-1,self.hidden[0],W.shape[1],W.shape[2])
72 |
73 | if ignore_lag:
74 | GC = torch.norm(W, dim=(1, 3))
75 | else:
76 | GC = torch.norm(W, dim=1)
77 |
78 | if threshold:
79 | return (GC > 0).int()
80 | else:
81 | return GC
82 |
83 | class VeloramaMLPTarget(nn.Module):
84 |
85 | def __init__(self, n_targets, n_regs, lag, hidden, device, activation):
86 | super(VeloramaMLPTarget, self).__init__()
87 | self.activation = activation_helper(activation)
88 | self.hidden = hidden
89 | self.lag = lag
90 | self.device = device
91 |
92 | # set up first layer
93 | layer = nn.Conv1d(n_regs, hidden[0]*n_targets, lag)
94 | modules = [layer]
95 |
96 | # set up first layers (target variables)
97 | target_modules = [nn.Conv1d(n_targets, hidden[0]*n_targets,1,groups=n_targets,bias=False)
98 | for _ in range(lag)]
99 | self.target_layers = nn.ModuleList(target_modules)
100 |
101 | # set up subsequent layers
102 | for d_in, d_out in zip(hidden, hidden[1:] + [1]):
103 | layer = nn.Conv1d(d_in*n_targets, d_out*n_targets, 1, groups=n_targets)
104 | modules.append(layer)
105 |
106 | # Register parameters.
107 | self.layers = nn.ModuleList(modules)
108 |
109 | def forward(self,AX,AY):
110 |
111 | # first layer
112 | ret = 0
113 | for i in range(self.lag):
114 | ret = ret + torch.matmul(AX[i], self.layers[0].weight[:, :, self.lag - 1 - i].T)
115 | ret = ret + self.layers[0].bias
116 |
117 | # include contributions of target variables
118 | ret = ret.T
119 | for i in range(self.lag):
120 | ret = ret + self.target_layers[self.lag - 1 - i](AY[i].T)
121 |
122 | # subsequent layers
123 | for i, fc in enumerate(self.layers):
124 | if i == 0:
125 | continue
126 | ret = self.activation(ret)
127 | ret = fc(ret)
128 |
129 | return ret.T
130 |
131 | def GC(self, threshold=True, ignore_lag=True):
132 | '''
133 | Extract learned Granger causality.
134 |
135 | Args:
136 | threshold: return norm of weights, or whether norm is nonzero.
137 | ignore_lag: if true, calculate norm of weights jointly for all lags.
138 |
139 | Returns:
140 |       GC: (n_targets x n_regs) or (n_targets x n_regs x lag) matrix. In the first
141 |         case, entry (i, j) indicates whether regulator j is Granger causal for
142 |         target i. In the second case, entry (i, j, k) indicates whether it is
143 |         Granger causal at lag k.
144 | '''
145 |
146 | W = self.layers[0].weight
147 | W = W.reshape(-1,self.hidden[0],W.shape[1],W.shape[2])
148 |
149 | if ignore_lag:
150 | GC = torch.norm(W, dim=(1, 3))
151 | else:
152 | GC = torch.norm(W, dim=1)
153 |
154 | if threshold:
155 | return (GC > 0).int()
156 | else:
157 | return GC
158 |
159 | def prox_update(network, lam, lr, penalty):
160 | '''
161 | Perform in place proximal update on first layer weight matrix.
162 | Args:
163 | network: MLP network.
164 | lam: regularization parameter.
165 | lr: learning rate.
166 | penalty: one of GL (group lasso), GSGL (group sparse group lasso),
167 | H (hierarchical).
168 | '''
169 | W = network.layers[0].weight
170 | hidden, p, lag = W.shape
171 | if penalty == 'GL':
172 | norm = torch.norm(W, dim=(0, 2), keepdim=True)
173 | W.data = ((W / torch.clamp(norm, min=(lr * lam)))
174 | * torch.clamp(norm - (lr * lam), min=0.0))
175 | elif penalty == 'GSGL':
176 | norm = torch.norm(W, dim=0, keepdim=True)
177 | W.data = ((W / torch.clamp(norm, min=(lr * lam)))
178 | * torch.clamp(norm - (lr * lam), min=0.0))
179 | norm = torch.norm(W, dim=(0, 2), keepdim=True)
180 | W.data = ((W / torch.clamp(norm, min=(lr * lam)))
181 | * torch.clamp(norm - (lr * lam), min=0.0))
182 | elif penalty == 'H':
183 | # Lowest indices along third axis touch most lagged values.
184 | for i in range(lag):
185 | norm = torch.norm(W[:, :, :(i + 1)], dim=(0, 2), keepdim=True)
186 | W.data[:, :, :(i+1)] = (
187 | (W.data[:, :, :(i+1)] / torch.clamp(norm, min=(lr * lam)))
188 | * torch.clamp(norm - (lr * lam), min=0.0))
189 | else:
190 | raise ValueError('unsupported penalty: %s' % penalty)
191 |
192 | def prox_update_target(network, lam, lr, penalty):
193 | '''
194 | Perform in place proximal update on first layer weight matrix.
195 | Args:
196 | network: MLP network.
197 | lam: regularization parameter.
198 | lr: learning rate.
199 | penalty: one of GL (group lasso), GSGL (group sparse group lasso),
200 | H (hierarchical).
201 | '''
202 | W = network.layers[0].weight
203 | hidden, p, lag = W.shape
204 |
205 | if penalty == 'GL':
206 | norm = torch.norm(W, dim=(0, 2), keepdim=True)
207 | W.data = ((W / torch.clamp(norm, min=(lr * lam)))
208 | * torch.clamp(norm - (lr * lam), min=0.0))
209 | elif penalty == 'GSGL':
210 | norm = torch.norm(W, dim=0, keepdim=True)
211 | W.data = ((W / torch.clamp(norm, min=(lr * lam)))
212 | * torch.clamp(norm - (lr * lam), min=0.0))
213 | norm = torch.norm(W, dim=(0, 2), keepdim=True)
214 | W.data = ((W / torch.clamp(norm, min=(lr * lam)))
215 | * torch.clamp(norm - (lr * lam), min=0.0))
216 | elif penalty == 'H':
217 | # Lowest indices along third axis touch most lagged values.
218 | for i in range(lag):
219 | W = network.layers[0].weight
220 | target_W = torch.stack([network.target_layers[j].weight for j
221 | in range(len(network.target_layers))]).squeeze(-1)
222 | target_W = torch.swapaxes(torch.swapaxes(target_W,0,1),1,2)
223 | W_concat = torch.cat([W.data[:,:,:(i+1)],target_W.data[:,:,:(i+1)]],dim=1)
224 | norm = torch.norm(W_concat[:,:,:(i+1)], dim=(0, 2), keepdim=True)
225 |
226 | # update regulator weights
227 | W.data[:, :, :(i+1)] = (
228 | (W.data[:, :, :(i+1)] / torch.clamp(norm[:,0:-1], min=(lr * lam)))
229 | * torch.clamp(norm[:,0:-1] - (lr * lam), min=0.0))
230 |
231 | # update target weights
232 | for j in range(i+1):
233 | W_t = network.target_layers[j].weight
234 | W_t.data = ((W_t.data / torch.clamp(norm[:,-1:], min=(lr * lam)))
235 | * torch.clamp(norm[:,-1:] - (lr * lam), min=0.0))
236 | else:
237 | raise ValueError('unsupported penalty: %s' % penalty)
238 |
239 | # def prox_update_new(network, lam, lr, penalty):
240 | # '''
241 | # Perform in place proximal update on first layer weight matrix.
242 |
243 | # Args:
244 | # network: MLP network.
245 | # lam: regularization parameter.
246 | # lr: learning rate.
247 | # penalty: one of GL (group lasso), GSGL (group sparse group lasso),
248 | # H (hierarchical).
249 | # '''
250 |
251 | # W = network.layers[0].weight
252 | # hidden, p, lag = W.shape
253 |
254 | # W_copy = torch.clone(W)
255 | # W_copy = W.reshape(-1,network.hidden[0],W.shape[1],W.shape[2])
256 |
257 | # if penalty == 'GL':
258 |
259 | # norm = torch.norm(W_copy[:, :, :, :(i + 1)], dim=(1, 3), keepdim=True)
260 | # W_copy.data = W_copy
261 |
262 | # norm = torch.norm(W, dim=(0, 2), keepdim=True)
263 | # W.data = ((W / torch.clamp(norm, min=(lr * lam)))
264 | # * torch.clamp(norm - (lr * lam), min=0.0))
265 | # # elif penalty == 'GSGL':
266 | # # norm = torch.norm(W, dim=0, keepdim=True)
267 | # # W.data = ((W / torch.clamp(norm, min=(lr * lam)))
268 | # # * torch.clamp(norm - (lr * lam), min=0.0))
269 | # # norm = torch.norm(W, dim=(0, 2), keepdim=True)
270 | # # W.data = ((W / torch.clamp(norm, min=(lr * lam)))
271 | # # * torch.clamp(norm - (lr * lam), min=0.0))
272 | # elif penalty == 'H':
273 | # # Lowest indices along third axis touch most lagged values.
274 | # for i in range(lag):
275 | # norm = torch.norm(W_copy[:, :, :, :(i + 1)], dim=(1, 3), keepdim=True)
276 | # W_copy.data[:, :, :, :(i+1)] = (
277 | # (W_copy.data[:, :, :, :(i+1)] / torch.clamp(norm, min=(lr * lam)))
278 | # * torch.clamp(norm - (lr * lam), min=0.0))
279 | # W.data = W_copy.data.reshape(W.shape)
280 |
281 | # else:
282 | # raise ValueError('unsupported penalty: %s' % penalty)
283 |
284 | def regularize(network, lam, penalty):
285 | '''
286 | Calculate regularization term for first layer weight matrix.
287 | Args:
288 | network: MLP network.
289 | penalty: one of GL (group lasso), GSGL (group sparse group lasso),
290 | H (hierarchical).
291 | '''
292 | W = network.layers[0].weight
293 | hidden, p, lag = W.shape
294 | if penalty == 'GL':
295 | return lam * torch.sum(torch.norm(W, dim=(0, 2)))
296 | elif penalty == 'GSGL':
297 | return lam * (torch.sum(torch.norm(W, dim=(0, 2)))
298 | + torch.sum(torch.norm(W, dim=0)))
299 | elif penalty == 'H':
300 | # Lowest indices along third axis touch most lagged values.
301 | return lam * sum([torch.sum(torch.norm(W[:, :, :(i+1)], dim=(0, 2)))
302 | for i in range(lag)])
303 | else:
304 | raise ValueError('unsupported penalty: %s' % penalty)
305 |
306 | def regularize_target(network, lam, penalty):
307 | '''
308 | Calculate regularization term for first layer weight matrix.
309 | Args:
310 | network: MLP network.
311 | penalty: one of GL (group lasso), GSGL (group sparse group lasso),
312 | H (hierarchical).
313 | '''
314 | W = network.layers[0].weight
315 |
316 | hidden, p, lag = W.shape
317 | if penalty == 'GL':
318 | return lam * torch.sum(torch.norm(W, dim=(0, 2)))
319 | elif penalty == 'GSGL':
320 | return lam * (torch.sum(torch.norm(W, dim=(0, 2)))
321 | + torch.sum(torch.norm(W, dim=0)))
322 | elif penalty == 'H':
323 | # Lowest indices along third axis touch most lagged values.
324 | # return lam * sum([torch.sum(torch.norm(W[:, :, :(i+1)], dim=(0, 2)))
325 | # for i in range(lag)])
326 | target_W = torch.stack([network.target_layers[i].weight for i
327 | in range(len(network.target_layers))]).squeeze(-1)
328 | target_W = torch.swapaxes(torch.swapaxes(target_W,0,1),1,2)
329 |
330 | return lam * sum([torch.sum(torch.norm(torch.cat([W.data[:,:,:(i+1)],target_W.data[:,:,:(i+1)]],dim=1),
331 | dim=(0, 2))) for i in range(lag)])
332 | else:
333 | raise ValueError('unsupported penalty: %s' % penalty)
334 |
335 | # def regularize_new(network, lam, penalty):
336 | # '''
337 | # Calculate regularization term for first layer weight matrix.
338 |
339 | # Args:
340 | # network: MLP network.
341 | # penalty: one of GL (group lasso), GSGL (group sparse group lasso),
342 | # H (hierarchical).
343 | # '''
344 | # W = network.layers[0].weight
345 | # hidden, p, lag = W.shape
346 | # W = W.reshape(-1,network.hidden[0],W.shape[1],W.shape[2])
347 |
348 | # if penalty == 'GL':
349 | # return lam * torch.sum(torch.norm(W, dim=(1, 3))).sum()
350 | # # elif penalty == 'GSGL':
351 | # # return lam * (torch.sum(torch.norm(W, dim=(0, 3)))
352 | # # + torch.sum(torch.norm(W, dim=0)))
353 | # elif penalty == 'H':
354 | # # Lowest indices along third axis touch most lagged values.
355 | # return lam * sum([torch.sum(torch.norm(W[:, :, :, :(i+1)], dim=(1, 3)))
356 | # for i in range(lag)]).sum()
357 | # else:
358 | # raise ValueError('unsupported penalty: %s' % penalty)
359 |
360 | def ridge_regularize(network, lam):
361 | '''Apply ridge penalty at all subsequent layers.'''
362 | return lam * sum([torch.sum(fc.weight ** 2) for fc in network.layers[1:]])
363 |
364 |
365 | def restore_parameters(model, best_model):
366 | '''Move parameter values from best_model to model.'''
367 | for params, best_params in zip(model.parameters(), best_model.parameters()):
368 | params.data = best_params
369 |
--------------------------------------------------------------------------------
/velorama/run.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 |
3 | ### Authors: Anish Mudide (amudide), Alex Wu (alexw16), Rohit Singh (rs239)
4 | ### 2022
5 | ### MIT Licence
6 | ###
7 |
8 | import os
9 | import numpy as np
10 | import scanpy as sc
11 | import argparse
12 | import time
13 | import ray
14 | from ray import tune
15 | import statistics
16 | import scvelo as scv
17 | import pandas as pd
18 | import shutil
19 |
20 | from .models import *
21 | from .train import *
22 | from .utils import *
23 | from .utils import move_files
24 |
25 |
26 | def execute_cmdline():
27 |
28 | parser = argparse.ArgumentParser()
29 | parser.add_argument('-n','--name',dest='name',type=str,default='velorama_run',help='substring to have in our output files')
30 | parser.add_argument('-ds','--dataset',dest='dataset',type=str)
31 | parser.add_argument('-dyn','--dyn',dest='dynamics',type=str,default='pseudotime',
32 | choices=['pseudotime','rna_velocity','pseudotime_time','pseudotime_precomputed'])
33 | parser.add_argument('-ptloc','--ptloc',dest='ptloc',type=str,default='pseudotime')
34 | parser.add_argument('-dev','--device',dest='device',type=str,default='cpu')
35 | parser.add_argument('-s','--seed',dest='seed',type=int,default=0,help='Random seed. Set to 0,1,2 etc.')
36 | 	parser.add_argument('-lmr','--lam_ridge',dest='lam_ridge',type=float,default=0., help='Currently unsupported')
37 | parser.add_argument('-p','--penalty',dest='penalty',type=str,default='H')
38 | parser.add_argument('-l','--lag',dest='lag',type=int,default=5)
39 | parser.add_argument('-hd', '--hidden',dest='hidden',type=int,default=32)
40 | parser.add_argument('-mi','--max_iter',dest='max_iter',type=int,default=1000)
41 | parser.add_argument('-lr','--learning_rate',dest='learning_rate',type=float,default=0.01)
42 | parser.add_argument('-pr','--proba',dest='proba',type=int,default=1)
43 | parser.add_argument('-ce','--check_every',dest='check_every',type=int,default=10)
44 | parser.add_argument('-rd','--root_dir',dest='root_dir',type=str)
45 | parser.add_argument('-sd','--save_dir',dest='save_dir',type=str,default='./results')
46 | parser.add_argument('-ls','--lam_start',dest='lam_start',type=float,default=-2)
47 | parser.add_argument('-le','--lam_end',type=float,default=1)
48 | parser.add_argument('-xn','--x_norm',dest='x_norm',type=str,default='zscore') # ,choices=['none','zscore','to_count:zscore','zscore_pca','maxmin','fill_zscore'])
49 | parser.add_argument('-nl','--num_lambdas',dest='num_lambdas',type=int,default=19)
50 | parser.add_argument('-rt','--reg_target',dest='reg_target',type=int,default=1)
51 | parser.add_argument('-nn','--n_neighbors',dest='n_neighbors',type=int,default=30)
52 | parser.add_argument('-vm','--velo_mode',dest='velo_mode',type=str,default='stochastic')
53 | parser.add_argument('-ts','--time_series',dest='time_series',type=int,default=0)
54 | parser.add_argument('-nc','--n_comps',dest='n_comps',type=int,default=50)
55 |
56 | args = parser.parse_args()
57 |
58 | if not os.path.exists(args.save_dir):
59 | os.mkdir(args.save_dir)
60 |
61 | adata = sc.read(os.path.join(args.root_dir,'{}.h5ad'.format(args.dataset)))
62 |
63 | if not args.reg_target:
64 | adata.var['is_target'] = True
65 | adata.var['is_reg'] = True
66 |
67 | target_genes = adata.var.index.values[adata.var['is_target']]
68 | reg_genes = adata.var.index.values[adata.var['is_reg']]
69 |
70 | if args.x_norm == 'zscore':
71 |
72 | print('Normalizing data: 0 mean, 1 SD')
73 | X_orig = adata[:,adata.var['is_reg']].X.toarray().copy()
74 | std = X_orig.std(0)
75 | std[std == 0] = 1
76 | X = torch.FloatTensor(X_orig-X_orig.mean(0))/std
77 | if 'De-noised' not in args.dataset:
78 | X = torch.clip(X,-5,5)
79 |
80 | Y_orig = adata[:,adata.var['is_target']].X.toarray().copy()
81 | std = Y_orig.std(0)
82 | std[std == 0] = 1
83 | Y = torch.FloatTensor(Y_orig-Y_orig.mean(0))/std
84 | if 'De-noised' not in args.dataset:
85 | Y = torch.clip(Y,-5,5)
86 |
87 | elif args.x_norm == 'magic_zscore':
88 |
89 | import magic
90 | from scipy.sparse import issparse
91 |
92 | X = adata.X.toarray() if issparse(adata.X) else adata.X
93 | X = pd.DataFrame(X,columns=adata.var.index.values)
94 | magic_operator = magic.MAGIC()
95 | X_magic = magic_operator.fit_transform(X).astype(np.float32)
96 |
97 | X_orig = X_magic.values[:,adata.var['is_reg'].values]
98 | std = X_orig.std(0)
99 | std[std == 0] = 1
100 | X = torch.FloatTensor(X_orig-X_orig.mean(0))/std
101 | # X = torch.clip(X,-5,5)
102 |
103 | Y_orig = X_magic.values[:,adata.var['is_target'].values]
104 | std = Y_orig.std(0)
105 | std[std == 0] = 1
106 | Y = torch.FloatTensor(Y_orig-Y_orig.mean(0))/std
107 |
108 | elif args.x_norm == 'fill_zscore':
109 | X_orig = adata[:,adata.var['is_reg']].X.toarray().copy()
110 | X_df = pd.DataFrame(X_orig)
111 | X_df[X_df < 1e-9] = np.nan
112 | X_df = X_df.fillna(X_df.median())
113 | X_orig = X_df.values
114 | std = X_orig.std(0)
115 | std[std == 0] = 1
116 | X = torch.FloatTensor(X_orig-X_orig.mean(0))/std
117 | # X = torch.clip(X,-5,5)
118 |
119 | Y_orig = adata[:,adata.var['is_target']].X.toarray().copy()
120 | Y_df = pd.DataFrame(Y_orig)
121 | Y_df[Y_df < 1e-9] = np.nan
122 | Y_df = Y_df.fillna(Y_df.median())
123 | Y_orig = Y_df.values
124 | std = Y_orig.std(0)
125 | std[std == 0] = 1
126 | Y = torch.FloatTensor(Y_orig-Y_orig.mean(0))/std
127 | # Y = torch.clip(Y,-5,5)
128 |
129 |
130 | elif args.x_norm == 'to_count:zscore':
131 |
132 | print('Use counts: 0 mean, 1 SD')
133 | X_orig = adata[:,adata.var['is_reg']].X.toarray().copy()
134 | X_orig = 2**X_orig-1
135 | std = X_orig.std(0)
136 | std[std == 0] = 1
137 | X = torch.FloatTensor(X_orig-X_orig.mean(0))/std
138 |
139 | Y_orig = adata[:,adata.var['is_target']].X.toarray().copy()
140 | Y_orig = 2**Y_orig-1
141 | std = Y_orig.std(0)
142 | std[std == 0] = 1
143 | Y = torch.FloatTensor(Y_orig-Y_orig.mean(0))/std
144 |
145 | elif args.x_norm == 'zscore_pca':
146 |
147 | print('PCA + normalizing data: 0 mean, 1 SD')
148 |
149 | sc.tl.pca(adata,n_comps=100)
150 | adata.X = adata.obsm['X_pca'].dot(adata.varm['PCs'].T)
151 | X_orig = adata[:,adata.var['is_reg']].X.toarray().copy()
152 | std = X_orig.std(0)
153 | std[std == 0] = 1
154 | X = torch.FloatTensor(X_orig-X_orig.mean(0))/std
155 | X = torch.clip(X,-5,5)
156 |
157 | Y_orig = adata[:,adata.var['is_target']].X.toarray().copy()
158 | std = Y_orig.std(0)
159 | std[std == 0] = 1
160 | Y = torch.FloatTensor(Y_orig-Y_orig.mean(0))/std
161 | Y = torch.clip(Y,-5,5)
162 |
163 | elif args.x_norm == 'maxmin':
164 |
165 | X_orig = adata[:,adata.var['is_reg']].X.toarray().copy()
166 | X_min = X_orig.min(0)
167 | X_max = X_orig.max(0)
168 | X = torch.FloatTensor((X_orig-X_min)/(X_max-X_min))
169 | X -= X.mean(0)
170 |
171 | Y_orig = adata[:,adata.var['is_target']].X.toarray().copy()
172 | Y_min = Y_orig.min(0)
173 | Y_max = Y_orig.max(0)
174 | Y = torch.FloatTensor((Y_orig-Y_min)/(Y_max-Y_min))
175 | Y -= Y.mean(0)
176 |
177 | else:
178 | assert args.x_norm == 'none'
179 | X = torch.FloatTensor(adata[:,adata.var['is_reg']].X.toarray())
180 | Y = torch.FloatTensor(adata[:,adata.var['is_target']].X.toarray())
181 |
182 | print('# of Regs: {}, # of Targets: {}'.format(X.shape[1],Y.shape[1]))
183 |
184 | print('Constructing DAG...')
185 |
186 | if 'De-noised' in args.dataset:
187 | sc.pp.normalize_total(adata, target_sum=1e4)
188 | sc.pp.log1p(adata)
189 |
190 | sc.pp.scale(adata)
191 | A = construct_dag(adata,dynamics=args.dynamics,ptloc=args.ptloc,proba=args.proba,
192 | n_neighbors=args.n_neighbors,velo_mode=args.velo_mode,
193 | use_time=args.time_series,n_comps=args.n_comps)
194 | A = torch.FloatTensor(A)
195 | AX = calculate_diffusion_lags(A,X,args.lag)
196 |
197 | if args.reg_target:
198 | AY = calculate_diffusion_lags(A,Y,args.lag)
199 | else:
200 | AY = None
201 |
202 | dir_name = '{}.seed{}.h{}.{}.lag{}.{}'.format(args.name,args.seed,args.hidden,args.penalty,args.lag,args.dynamics)
203 |
204 | if not os.path.exists(os.path.join(args.save_dir,dir_name)):
205 | os.mkdir(os.path.join(args.save_dir,dir_name))
206 |
207 | ray.init(object_store_memory=10**9)
208 |
209 | total_start = time.time()
210 | lam_list = np.logspace(args.lam_start, args.lam_end, num=args.num_lambdas).tolist()
211 |
212 | config = {'name': args.name,
213 | 'AX': AX,
214 | 'AY': AY,
215 | 'Y': Y,
216 | 'seed': args.seed,
217 | 'lr': args.learning_rate,
218 | 'lam': tune.grid_search(lam_list),
219 | 'lam_ridge': args.lam_ridge,
220 | 'penalty': args.penalty,
221 | 'lag': args.lag,
222 | 'hidden': [args.hidden],
223 | 'max_iter': args.max_iter,
224 | 'device': args.device,
225 | 'lookback': 5,
226 | 'check_every': args.check_every,
227 | 'verbose': True,
228 | 'dynamics': args.dynamics,
229 | 'results_dir': args.save_dir,
230 | 'dir_name': dir_name,
231 | 'reg_target': args.reg_target}
232 |
233 | ngpu = 0.2 if (args.device == 'cuda') else 0
234 | resources_per_trial = {"cpu": 1, "gpu": ngpu, "memory": 2 * 1024 * 1024 * 1024}
235 | analysis = tune.run(train_model,resources_per_trial=resources_per_trial,config=config,
236 | local_dir=os.path.join(args.root_dir,args.save_dir))
237 |
238 | target_dir = os.path.join(args.save_dir, dir_name)
239 | base_dir = args.save_dir
240 | move_files(base_dir, target_dir)
241 |
242 | # aggregate results
243 | lam_list = [np.round(lam,4) for lam in lam_list]
244 | all_lags = load_gc_interactions(args.name,args.save_dir,lam_list,hidden_dim=args.hidden,
245 | lag=args.lag,penalty=args.penalty,
246 | dynamics=args.dynamics,seed=args.seed,ignore_lag=False)
247 |
248 | gc_mat = estimate_interactions(all_lags,lag=args.lag)
249 | gc_df = pd.DataFrame(gc_mat.cpu().data.numpy(),index=target_genes,columns=reg_genes)
250 | gc_df.to_csv(os.path.join(args.save_dir,'{}.{}.velorama.interactions.tsv'.format(args.name,args.dynamics)),sep='\t')
251 |
252 | lag_mat = estimate_lags(all_lags,lag=args.lag)
253 | lag_df = pd.DataFrame(lag_mat.cpu().data.numpy(),index=target_genes,columns=reg_genes)
254 | lag_df.to_csv(os.path.join(args.save_dir,'{}.{}.velorama.lags.tsv'.format(args.name,args.dynamics)),sep='\t')
255 |
256 | print('Total time:',time.time()-total_start)
257 | np.savetxt(os.path.join(args.save_dir,dir_name + '.time.txt'),np.array([time.time()-total_start]))
258 |
259 | if __name__ == "__main__":
260 | execute_cmdline()
--------------------------------------------------------------------------------
/velorama/train.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 |
3 | ### Authors: Anish Mudide (amudide), Alex Wu (alexw16), Rohit Singh (rs239)
4 | ### 2022
5 | ### MIT Licence
6 | ###
7 |
8 | import os
9 | import numpy as np
10 | import torch
11 | import torch.nn as nn
12 | from copy import deepcopy
13 | import time
14 |
15 | from .models import *
16 | from .utils import *
17 |
18 | def train_model(config):
19 |
20 | AX = config["AX"]
21 | AY = config["AY"]
22 |
23 | Y = config["Y"]
24 |
25 | name = config["name"]
26 |
27 | seed = config["seed"]
28 | lr = config["lr"]
29 | lam = config["lam"]
30 | lam_ridge = config["lam_ridge"]
31 | penalty = config["penalty"]
32 | lag = config["lag"]
33 | hidden = config["hidden"]
34 | max_iter = config["max_iter"]
35 | device = config["device"]
36 | lookback = config["lookback"]
37 | check_every = config["check_every"]
38 | verbose = config["verbose"]
39 | dynamics = config['dynamics']
40 |
41 | results_dir = config['results_dir']
42 | dir_name = config['dir_name']
43 | reg_target = config['reg_target']
44 |
45 | np.random.seed(seed)
46 | torch.manual_seed(seed)
47 |
48 | file_name = '{}.seed{}.lam{}.h{}.{}.lag{}.{}'.format(name,seed,np.round(lam,4),
49 | hidden[0],penalty,lag,dynamics)
50 | gc_path1 = os.path.join(results_dir,dir_name,file_name + '.pt')
51 | gc_path2 = os.path.join(results_dir,dir_name,file_name + '.ignore_lag.pt')
52 |
53 | if not os.path.exists(gc_path1) and not os.path.exists(gc_path2):
54 |
55 | num_regs = AX.shape[-1]
56 | num_targets = Y.shape[1]
57 |
58 | AX = AX.to(device)
59 | Y = Y.to(device)
60 |
61 | if reg_target:
62 | vmlp = VeloramaMLPTarget(num_targets, num_regs, lag=lag, hidden=hidden,
63 | device=device, activation='relu')
64 | AY = AY.to(device)
65 | else:
66 | vmlp = VeloramaMLP(num_targets, num_regs, lag=lag, hidden=hidden,
67 | device=device, activation='relu')
68 |
69 | vmlp.to(device)
70 |
71 | '''Train model with ISTA.'''
72 | lag = vmlp.lag
73 | loss_fn = nn.MSELoss(reduction='none')
74 | train_loss_list = []
75 |
76 | # For early stopping.
77 | best_it = None
78 | best_loss = np.inf
79 | best_model = None
80 |
81 | # Calculate smooth error.
82 | if reg_target:
83 | preds = vmlp(AX,AY)
84 | else:
85 | preds = vmlp(AX)
86 |
87 | loss = loss_fn(preds,Y).mean(0).sum()
88 |
89 | # print('LOSS:---',loss)
90 |
91 | ridge = ridge_regularize(vmlp, lam_ridge)
92 | smooth = loss + ridge
93 |
94 | variable_usage_list = []
95 | loss_list = []
96 |
97 | # For early stopping.
98 | train_loss_list = []
99 | best_it = None
100 | best_loss = np.inf
101 | best_model = None
102 |
103 | for it in range(max_iter):
104 |
105 | start = time.time()
106 |
107 | # Take gradient step.
108 | smooth.backward()
109 | for param in vmlp.parameters():
110 | param.data = param - lr * param.grad
111 |
112 | # Take prox step.
113 | if lam > 0:
114 | # for net in vmlp.networks:
115 | if reg_target:
116 | prox_update_target(vmlp, lam, lr, penalty)
117 | else:
118 | prox_update(vmlp, lam, lr, penalty)
119 |
120 | vmlp.zero_grad()
121 |
122 | # Calculate loss for next iteration.
123 | if reg_target:
124 | preds = vmlp(AX,AY)
125 | else:
126 | preds = vmlp(AX)
127 |
128 | loss = loss_fn(preds,Y).mean(0).sum()
129 |
130 | ridge = ridge_regularize(vmlp, lam_ridge)
131 | smooth = loss + ridge
132 |
133 | # Check progress.
134 | if (it + 1) % check_every == 0:
135 |
136 | if reg_target:
137 | nonsmooth = regularize_target(vmlp, lam, penalty)
138 | else:
139 | nonsmooth = regularize(vmlp, lam, penalty)
140 | mean_loss = (smooth + nonsmooth).detach()/Y.shape[1]
141 |
142 | variable_usage = torch.mean(vmlp.GC(ignore_lag=False).float())
143 | variable_usage_list.append(variable_usage)
144 | loss_list.append(mean_loss)
145 |
146 | # Check for early stopping.
147 | if mean_loss < best_loss:
148 | best_loss = mean_loss
149 | best_it = it
150 | best_model = deepcopy(vmlp)
151 |
152 | if verbose:
153 | print('Lam={}: Iter {}, {} sec'.format(lam,it+1,np.round(time.time()-start,2)),
154 | '-----','Loss: %.2f' % mean_loss,', Variable usage = %.2f%%' % (100 * variable_usage)) # ,
155 | # '|||','%.3f' % loss_crit,'%.3f' % variable_usage_crit)
156 |
157 | elif (it - best_it) == lookback * check_every:
158 | if verbose:
159 | print('EARLY STOP: Lam={}, Iter {}'.format(lam,it + 1))
160 | break
161 |
162 | if verbose:
163 | print('Lam={}: Completed in {} iterations'.format(lam,it+1))
164 |
165 | # Restore best model.
166 | restore_parameters(vmlp, best_model)
167 |
168 | if not os.path.exists(results_dir):
169 | os.mkdir(results_dir)
170 | if not os.path.exists(os.path.join(results_dir,dir_name)):
171 | os.mkdir(os.path.join(results_dir,dir_name))
172 |
173 | file_name = '{}.seed{}.lam{}.h{}.{}.lag{}.{}'.format(name,seed,np.round(lam,4),
174 | hidden[0],penalty,lag,dynamics)
175 | GC_lag = vmlp.GC(threshold=False, ignore_lag=False).cpu()
176 | torch.save(GC_lag, os.path.join(results_dir,dir_name,file_name + '.pt'))
177 |
178 | GC_lag = vmlp.GC(threshold=False, ignore_lag=True).cpu()
179 | torch.save(GC_lag, os.path.join(results_dir,dir_name,file_name + '.ignore_lag.pt'))
--------------------------------------------------------------------------------
/velorama/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | import torch
4 | from scipy.stats import f
5 | from scipy.sparse import csr_matrix
6 | import statistics
7 | import scanpy as sc
8 | import scanpy.external as sce
9 | from anndata import AnnData
10 | import cellrank as cr
11 | import scvelo as scv
12 | import schema
13 | from torch.nn.functional import normalize
14 | import shutil
15 |
16 | from cellrank.kernels import VelocityKernel
17 |
18 | import torch
19 | import torch.nn as nn
20 |
21 | def construct_dag(adata,dynamics='rna_velocity',ptloc='pseudotime',velo_mode='stochastic',proba=True,
22 | n_neighbors=30,n_comps=50,use_time=False):
23 |
24 | """Constructs the adjacency matrix for a DAG.
25 | Parameters
26 | ----------
27 | adata: 'AnnData'
28 | AnnData object with rows corresponding to cells and columns corresponding
29 | to genes.
30 | dynamics: {'rna_velocity','pseudotime','pseudotime_precomputed'}
31 | (default: rna_velocity)
32 | Dynamics used to orient and/or weight edges in the DAG of cells.
33 | If 'pseudotime_precomputed', the precomputed pseudotime values must be
34 | included as an observation category named 'pseudotime' in the included
35 | AnnData object (e.g., adata.obs['pseudotime'] = [list of float]).
36 | velo_mode: {'stochastic','deterministic','dynamical'} (default: 'stochastic')
37 | RNA velocity estimation using either the steady-state/deterministic,
38 | stochastic, or dynamical model of transcriptional dynamics from scVelo
39 | (Bergen et al., 2020).
40 | proba: 'bool' (default: True)
41 | Whether to use the transition probabilities from CellRank (Lange et al., 2022)
42 | in weighting the edges of the DAG or to discretize these probabilities by
43 | retaining only the top half of edges per cell.
44 | 	n_neighbors: 'int' (default: 30)
45 | Number of nearest neighbors to use in constructing a k-nearest
46 | neighbor graph for constructing a DAG if a custom DAG is not provided.
47 | n_comps: 'int' (default: 50)
48 | 		Number of principal components to compute and use for representing
49 | the gene expression profiles of cells.
50 | use_time: 'bool' (default: False)
51 | Whether to integrate time stamps in constructing the DAG. If True, time
52 | stamps must be included as an observation category named 'time' in the
53 | included AnnData object (e.g., adata.obs['time'] = [list of float]).
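   |
   |     Example
   |     -------
   |     Illustrative sketch (assumes `adata` already carries whatever the chosen
   |     dynamics requires, e.g. spliced/unspliced layers for 'rna_velocity'):
   |
   |         A = construct_dag(adata, dynamics='rna_velocity', n_neighbors=30)
   |         # A is an (n_cells x n_cells) torch.FloatTensor encoding the weighted,
   |         # normalized DAG over cells, as consumed by calculate_diffusion_lags.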
54 | """
55 |
56 | sc.tl.pca(adata, n_comps=n_comps, svd_solver='arpack')
57 | if use_time:
58 | sqp = schema.SchemaQP(min_desired_corr=0.9,mode='affine',
59 | params= {'decomposition_model': 'pca',
60 | 'num_top_components': adata.obsm['X_pca'].shape[1]})
61 | adata.obsm['X_rep'] = sqp.fit_transform(adata.obsm['X_pca'],[adata.obs['time']],
62 | ['numeric'],[1])
63 | else:
64 | adata.obsm['X_rep'] = adata.obsm['X_pca']
65 |
66 | if dynamics == 'pseudotime':
67 | A = construct_dag_pseudotime(adata.obsm['X_rep'],adata.uns['iroot'],
68 | n_neighbors=n_neighbors).T
69 |
70 | elif dynamics == 'pseudotime_precomputed':
71 | A = construct_dag_pseudotime(adata.obsm['X_rep'],adata.uns['iroot'],
72 | n_neighbors=n_neighbors,
73 | pseudotime_algo='precomputed',
74 | precomputed_pseudotime=adata.obs[ptloc].values).T
75 |
76 | elif dynamics == 'rna_velocity':
77 | scv.pp.moments(adata, n_neighbors=n_neighbors, use_rep='X_rep')
78 | if velo_mode == 'dynamical':
79 | scv.tl.recover_dynamics(adata)
80 | scv.tl.velocity(adata,mode=velo_mode)
81 | scv.tl.velocity_graph(adata)
82 | vk = VelocityKernel(adata).compute_transition_matrix()
83 | A = vk.transition_matrix
84 | A = A.toarray()
85 |
86 | # if proba is False (0), it won't use the probabilistic
87 | # transition matrix
88 | if not proba:
89 | for i in range(len(A)):
90 | nzeros = []
91 | for j in range(len(A)):
92 | if A[i][j] > 0:
93 | nzeros.append(A[i][j])
94 | m = statistics.median(nzeros)
95 | for j in range(len(A)):
96 | if A[i][j] < m:
97 | A[i][j] = 0
98 | else:
99 | A[i][j] = 1
100 |
101 | for i in range(len(A)):
102 | for j in range(len(A)):
103 | if A[i][j] > 0 and A[j][i] > 0 and A[i][j] > A[j][i]:
104 | A[j][i] = 0
105 |
106 | A = construct_S(torch.FloatTensor(A))
107 |
108 | return A
109 |
110 | def construct_dag_pseudotime(joint_feature_embeddings,iroot,n_neighbors=15,pseudotime_algo='dpt',
111 | precomputed_pseudotime=None):
112 |
113 | """Constructs the adjacency matrix for a DAG using pseudotime.
114 | Parameters
115 | ----------
116 | joint_feature_embeddings: 'numpy.ndarray' (default: None)
117 | Matrix of low dimensional embeddings with rows corresponding
118 | to observations and columns corresponding to feature embeddings
119 | for constructing a DAG if a custom DAG is not provided.
120 | iroot: 'int' (default: None)
121 | Index of root cell for inferring pseudotime for constructing a DAG
122 | if a custom DAG is not provided.
123 | n_neighbors: 'int' (default: 15)
124 | Number of nearest neighbors to use in constructing a k-nearest
125 | neighbor graph for constructing a DAG if a custom DAG is not provided.
126 | pseudotime_algo: {'dpt','palantir'}
127 | Pseudotime algorithm to use for constructing a DAG if a custom DAG
128 | is not provided. 'dpt' and 'palantir' perform the diffusion pseudotime
129 | (Haghverdi et al., 2016) and Palantir (Setty et al., 2019) algorithms,
130 | respectively.
131 | precomputed_pseudotime: 'numpy.ndarray' or List (default: None)
132 | Precomputed pseudotime values for all cells.
133 | """
134 |
135 | pseudotime,knn_graph = infer_knngraph_pseudotime(joint_feature_embeddings,iroot,
136 | n_neighbors=n_neighbors,pseudotime_algo=pseudotime_algo,
137 | precomputed_pseudotime=precomputed_pseudotime)
138 | dag_adjacency_matrix = dag_orient_edges(knn_graph,pseudotime)
139 |
140 | return dag_adjacency_matrix
141 |
142 | def infer_knngraph_pseudotime(joint_feature_embeddings,iroot,n_neighbors=15,pseudotime_algo='dpt',
143 | precomputed_pseudotime=None):
144 |
145 | adata = AnnData(joint_feature_embeddings)
146 | adata.obsm['X_joint'] = joint_feature_embeddings
147 | adata.uns['iroot'] = iroot
148 |
149 | if pseudotime_algo == 'dpt':
150 | sc.pp.neighbors(adata,use_rep='X_joint',n_neighbors=n_neighbors)
151 | sc.tl.dpt(adata)
152 | adata.obs['pseudotime'] = adata.obs['dpt_pseudotime'].values
153 | knn_graph = adata.obsp['distances'].astype(bool).astype(float)
154 | elif pseudotime_algo == 'precomputed':
155 | sc.pp.neighbors(adata,use_rep='X_joint',n_neighbors=n_neighbors)
156 | adata.obs['pseudotime'] = precomputed_pseudotime
157 | knn_graph = adata.obsp['distances'].astype(bool).astype(float)
158 | elif pseudotime_algo == 'palantir':
159 | sc.pp.neighbors(adata,use_rep='X_joint',n_neighbors=n_neighbors)
160 | sce.tl.palantir(adata, knn=n_neighbors,use_adjacency_matrix=True,
161 | distances_key='distances')
162 | pr_res = sce.tl.palantir_results(adata,
163 | early_cell=adata.obs.index.values[adata.uns['iroot']],
164 | ms_data='X_palantir_multiscale')
165 | adata.obs['pseudotime'] = pr_res.pseudotime
166 | knn_graph = adata.obsp['distances'].astype(bool).astype(float)
167 |
168 | return adata.obs['pseudotime'].values,knn_graph
169 |
170 | def dag_orient_edges(adjacency_matrix,pseudotime):
171 |
172 | A = adjacency_matrix.astype(bool).astype(float)
173 | D = -1*np.sign(pseudotime[:,None] - pseudotime).T
174 | D = (D == 1).astype(float)
175 | D = (A.toarray()*D).astype(bool).astype(float)
176 |
177 | return D
178 |
179 | def construct_S(D):
180 |
181 | S = D.clone()
182 | D_sum = D.sum(0)
183 | D_sum[D_sum == 0] = 1
184 |
185 | S = (S/D_sum)
186 | S = S.T
187 |
188 | return S
189 |
190 | def normalize_adjacency(D):
191 |
192 | S = D.clone()
193 | D_sum = D.sum(0)
194 | D_sum[D_sum == 0] = 1
195 |
196 | S = (S/D_sum)
197 |
198 | return S
199 |
200 | def seq2dag(N):
201 | A = torch.zeros(N, N)
202 | for i in range(N - 1):
203 | A[i][i + 1] = 1
204 | return A
205 |
206 | def activation_helper(activation, dim=None):
207 | if activation == 'sigmoid':
208 | act = nn.Sigmoid()
209 | elif activation == 'tanh':
210 | act = nn.Tanh()
211 | elif activation == 'relu':
212 | act = nn.ReLU()
213 | elif activation == 'leakyrelu':
214 | act = nn.LeakyReLU()
215 | elif activation is None:
216 | def act(x):
217 | return x
218 | else:
219 | raise ValueError('unsupported activation: %s' % activation)
220 | return act
221 |
222 | def calculate_diffusion_lags(A,X,lag):
223 |
224 | if A == "linear":
225 | A = seq2dag(X.shape[1])
226 |
227 | ax = []
228 | cur = A
229 | for _ in range(lag):
230 | ax.append(torch.matmul(cur, X))
231 | cur = torch.matmul(A, cur)
232 | for i in range(len(cur)):
233 | cur[i][i] = 0
234 |
235 | return torch.stack(ax)
236 |
237 | def load_gc_interactions(name,results_dir,lam_list,hidden_dim=16,lag=5,penalty='H',
238 | dynamics='rna_velocity',seed=0,ignore_lag=False):
239 |
240 | config_name = '{}.seed{}.h{}.{}.lag{}.{}'.format(name,seed,hidden_dim,penalty,lag,dynamics)
241 |
242 | all_lags = []
243 | for lam in lam_list:
244 | if ignore_lag:
245 | file_name = '{}.seed{}.lam{}.h{}.{}.lag{}.{}.ignore_lag.pt'.format(name,seed,lam,hidden_dim,penalty,lag,dynamics)
246 | file_path = os.path.join(results_dir,config_name,file_name)
247 | gc_lag = torch.load(file_path)
248 | gc_lag = gc_lag.unsqueeze(-1)
249 | else:
250 | file_name = '{}.seed{}.lam{}.h{}.{}.lag{}.{}.pt'.format(name,seed,lam,hidden_dim,penalty,lag,dynamics)
251 | file_path = os.path.join(results_dir,config_name,file_name)
252 | gc_lag = torch.load(file_path)
253 | all_lags.append(gc_lag.detach())
254 |
255 | all_lags = torch.stack(all_lags)
256 |
257 | return all_lags
258 |
259 | def lor(x, y):
260 | return x + y
261 |
262 | def estimate_interactions(all_lags,lag=5,lower_thresh=0.01,upper_thresh=0.95,
263 | binarize=False,l2_norm=False):
264 |
265 | all_interactions = []
266 | for i in range(len(all_lags)):
267 | for j in range(lag):
268 |
269 | nnz_percent = (all_lags[i,:,:,j] != 0).float().mean().data.numpy()
270 |
271 | if nnz_percent > lower_thresh and nnz_percent < upper_thresh:
272 | interactions = all_lags[i,:,:,j]
273 |
274 | if l2_norm:
275 | interactions = normalize(interactions,dim=(0,1))
276 | if binarize:
277 | interactions = (interactions != 0).float()
278 |
279 | all_interactions.append(interactions)
280 | return torch.stack(all_interactions).mean(0)
281 |
282 | def estimate_lags(all_lags,lag,lower_thresh=0.01,upper_thresh=1.):
283 |
284 | retained_interactions = []
285 | for i in range(len(all_lags)):
286 | nnz_percent = (all_lags[i] != 0).float().mean().data.numpy()
287 | if nnz_percent > lower_thresh and nnz_percent < upper_thresh:
288 | retained_interactions.append(all_lags[i])
289 | retained_interactions = torch.stack(retained_interactions)
290 |
291 | est_lags = normalize(retained_interactions,p=1,dim=-1).mean(0)
292 | return (est_lags*(torch.arange(lag)+1)).sum(-1)
293 |
294 | def move_files(base_dir, target_dir):
295 | os.makedirs(target_dir, exist_ok=True) # Ensure the target directory exists
296 | # Walk through all directories recursively
297 | for root, dirs, files in os.walk(base_dir):
298 |         if 'train' in root: # Check if 'train' is in the directory path
299 | for file in files:
300 |                 if file.endswith('.pt'): # Check if the file ends with '.pt'
301 | source_path = os.path.join(root, file)
302 | shutil.move(source_path, target_dir) # Move the file
303 | print(f"Moved: {source_path} to {target_dir}")
--------------------------------------------------------------------------------